TTK
Loading...
Searching...
No Matches
MergeTreeTorchUtils.cpp
Go to the documentation of this file.
2
3using namespace std;
4using namespace ttk;
5
6#ifdef TTK_ENABLE_TORCH
7using namespace torch::indexing;
8
9void mtu::copyTensor(torch::Tensor &a, torch::Tensor &b) {
10 b = a.detach().clone();
11 b.requires_grad_(a.requires_grad());
12}
13
14void mtu::getDeltaProjTensor(torch::Tensor &diagTensor,
15 torch::Tensor &deltaProjTensor) {
16 deltaProjTensor
17 = (diagTensor.index({Slice(), 0}) + diagTensor.index({Slice(), 1})) / 2.0;
18 deltaProjTensor = deltaProjTensor.reshape({-1, 1});
19 deltaProjTensor = torch::cat({deltaProjTensor, deltaProjTensor}, 1);
20}
21
22void mtu::dataReorderingGivenMatching(mtu::TorchMergeTree<float> &tree,
23 mtu::TorchMergeTree<float> &tree2,
24 torch::Tensor &tree1ProjIndexer,
25 torch::Tensor &tree2ReorderingIndexes,
26 torch::Tensor &tree2ReorderedTensor,
27 torch::Tensor &tree2DeltaProjTensor,
28 torch::Tensor &tree1ReorderedTensor,
29 torch::Tensor &tree2ProjIndexer,
30 bool doubleReordering) {
31 // Reorder tree2 tensor
32 torch::Tensor tree2DiagTensor = tree2.tensor.reshape({-1, 2});
33 tree2ReorderedTensor = torch::cat({tree2DiagTensor, torch::zeros({1, 2})});
34 tree2ReorderedTensor = tree2ReorderedTensor.index({tree2ReorderingIndexes});
35
36 // Create tree projection given matching
37 torch::Tensor treeDiagTensor = tree.tensor.reshape({-1, 2});
38 getDeltaProjTensor(treeDiagTensor, tree2DeltaProjTensor);
39 tree2DeltaProjTensor = tree2DeltaProjTensor * tree1ProjIndexer;
40
41 // Double reordering
42 if(doubleReordering) {
43 torch::Tensor tree1DeltaProjTensor;
44 getDeltaProjTensor(tree2DiagTensor, tree1DeltaProjTensor);
45 torch::Tensor tree2ProjIndexerR = tree2ProjIndexer.reshape({-1});
46 tree1DeltaProjTensor = tree1DeltaProjTensor.index({tree2ProjIndexerR});
47 tree1ReorderedTensor = torch::cat({treeDiagTensor, tree1DeltaProjTensor});
48 tree1ReorderedTensor = tree1ReorderedTensor.reshape({-1, 1});
49 torch::Tensor tree2UnmatchedTensor
50 = tree2DiagTensor.index({tree2ProjIndexerR});
51 tree2ReorderedTensor
52 = torch::cat({tree2ReorderedTensor, tree2UnmatchedTensor});
53 tree2DeltaProjTensor = torch::cat(
54 {tree2DeltaProjTensor, torch::zeros_like(tree2UnmatchedTensor)});
55 }
56
57 // Reshape
58 tree2ReorderedTensor = tree2ReorderedTensor.reshape({-1, 1});
59 tree2DeltaProjTensor = tree2DeltaProjTensor.reshape({-1, 1});
60}
61
62void mtu::dataReorderingGivenMatching(mtu::TorchMergeTree<float> &tree,
63 mtu::TorchMergeTree<float> &tree2,
64 torch::Tensor &tree1ProjIndexer,
65 torch::Tensor &tree2ReorderingIndexes,
66 torch::Tensor &tree2ReorderedTensor,
67 torch::Tensor &tree2DeltaProjTensor) {
68 torch::Tensor tree1ReorderedTensor;
69 torch::Tensor tree2ProjIndexer;
70 bool doubleReordering = false;
71 dataReorderingGivenMatching(tree, tree2, tree1ProjIndexer,
72 tree2ReorderingIndexes, tree2ReorderedTensor,
73 tree2DeltaProjTensor, tree1ReorderedTensor,
74 tree2ProjIndexer, doubleReordering);
75}
76
77void mtu::dataReorderingGivenMatching(
78 mtu::TorchMergeTree<float> &tree,
79 mtu::TorchMergeTree<float> &tree2,
80 std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> &matching,
81 torch::Tensor &tree1ReorderedTensor,
82 torch::Tensor &tree2ReorderedTensor,
83 bool doubleReordering) {
84 // Get tensor matching
85 std::vector<int> tensorMatching;
86 mtu::getTensorMatching(tree, tree2, matching, tensorMatching);
87 torch::Tensor tree2ReorderingIndexes = torch::tensor(tensorMatching);
88 torch::Tensor tree1ProjIndexer
89 = (tree2ReorderingIndexes == -1).reshape({-1, 1});
90 // Reorder tensor
91 torch::Tensor tree2DeltaProjTensor;
92 if(not doubleReordering) {
93 dataReorderingGivenMatching(tree, tree2, tree1ProjIndexer,
94 tree2ReorderingIndexes, tree2ReorderedTensor,
95 tree2DeltaProjTensor);
96 } else {
97 std::vector<int> tensorMatching2;
98 mtu::getInverseTensorMatching(tree, tree2, matching, tensorMatching);
99 torch::Tensor tree1ReorderingIndexes = torch::tensor(tensorMatching);
100 torch::Tensor tree2ProjIndexer
101 = (tree1ReorderingIndexes == -1).reshape({-1, 1});
102 dataReorderingGivenMatching(tree, tree2, tree1ProjIndexer,
103 tree2ReorderingIndexes, tree2ReorderedTensor,
104 tree2DeltaProjTensor, tree1ReorderedTensor,
105 tree2ProjIndexer, doubleReordering);
106 }
107 tree2ReorderedTensor = tree2ReorderedTensor + tree2DeltaProjTensor;
108}
109
110void mtu::dataReorderingGivenMatching(
111 mtu::TorchMergeTree<float> &tree,
112 mtu::TorchMergeTree<float> &tree2,
113 std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> &matching,
114 torch::Tensor &tree2ReorderedTensor) {
115 torch::Tensor tree1ReorderedTensor;
116 bool doubleReordering = false;
117 dataReorderingGivenMatching(tree, tree2, matching, tree1ReorderedTensor,
118 tree2ReorderedTensor, doubleReordering);
119}
120
121void mtu::meanBirthShift(torch::Tensor &diagTensor,
122 torch::Tensor &diagBaseTensor) {
123 torch::Tensor birthShiftValue = diagBaseTensor.index({Slice(), 0}).mean()
124 - diagTensor.index({Slice(), 0}).mean();
125 torch::Tensor shiftTensor
126 = torch::full({diagTensor.sizes()[0], 2}, birthShiftValue.item<float>());
127 diagTensor.index_put_({None}, diagTensor + shiftTensor);
128}
129
130void mtu::meanBirthMaxPersShift(torch::Tensor &tensor,
131 torch::Tensor &baseTensor) {
132 torch::Tensor diagTensor = tensor.reshape({-1, 2});
133 torch::Tensor diagBaseTensor = baseTensor.reshape({-1, 2});
134 // Shift to have same max pers
135 torch::Tensor baseMaxPers
136 = (diagBaseTensor.index({Slice(), 1}) - diagBaseTensor.index({Slice(), 0}))
137 .max();
138 torch::Tensor maxPers
139 = (diagTensor.index({Slice(), 1}) - diagTensor.index({Slice(), 0})).max();
140 torch::Tensor shiftTensor = (baseMaxPers - maxPers) / 2.0;
141 shiftTensor = torch::stack({-shiftTensor, shiftTensor});
142 diagTensor.index_put_({None}, diagTensor + shiftTensor);
143 // Shift to have same birth mean
144 meanBirthShift(diagTensor, diagBaseTensor);
145}
146
147void mtu::belowDiagonalPointsShift(torch::Tensor &tensor,
148 torch::Tensor &backupTensor) {
149 torch::Tensor oPDiag = tensor.reshape({-1, 2});
150 torch::Tensor badPointsIndexer
151 = (oPDiag.index({Slice(), 0}) > oPDiag.index({Slice(), 1}));
152 torch::Tensor goodPoints = oPDiag.index({~badPointsIndexer});
153 if(goodPoints.sizes()[0] == 0)
154 goodPoints = backupTensor.reshape({-1, 2});
155 torch::Tensor badPoints = oPDiag.index({badPointsIndexer});
156 // Shift to be above diagonal with median pers
157 torch::Tensor pers
158 = (goodPoints.index({Slice(), 1}) - goodPoints.index({Slice(), 0}))
159 .median();
160 torch::Tensor shiftTensor
161 = (torch::full({badPoints.sizes()[0], 1}, pers.item<float>())
162 - badPoints.index({Slice(), 1}).reshape({-1, 1})
163 + badPoints.index({Slice(), 0}).reshape({-1, 1}))
164 / 2.0;
165 shiftTensor = torch::cat({-shiftTensor, shiftTensor}, 1);
166 badPoints = badPoints + shiftTensor;
167 // Update tensor
168 oPDiag.index_put_({badPointsIndexer}, badPoints);
169 tensor = oPDiag.reshape({-1, 1}).detach();
170}
171
172void mtu::normalizeVectors(torch::Tensor &originTensor,
173 torch::Tensor &vectorsTensor) {
174 torch::Tensor vSliced = vectorsTensor.index({Slice(2, None)});
175 vSliced.index_put_({None}, vSliced / (originTensor[1] - originTensor[0]));
176}
177
178void mtu::normalizeVectors(mtu::TorchMergeTree<float> &origin,
179 std::vector<std::vector<double>> &vectors) {
180 std::queue<ftm::idNode> queue;
181 queue.emplace(origin.mTree.tree.getRoot());
182 while(!queue.empty()) {
183 ftm::idNode node = queue.front();
184 queue.pop();
185 if(not origin.mTree.tree.isRoot(node))
186 for(unsigned int i = 0; i < 2; ++i)
187 vectors[node][i] /= (origin.tensor[1] - origin.tensor[0]).item<float>();
188 std::vector<ftm::idNode> children;
189 origin.mTree.tree.getChildren(node, children);
190 for(auto &child : children)
191 queue.emplace(child);
192 }
193}
194
195// Work only for persistence diagrams
196bool mtu::isThereMissingPairs(mtu::TorchMergeTree<float> &interpolation) {
197 float maxPers
198 = interpolation.mTree.tree.template getMaximumPersistence<float>();
199 torch::Tensor interTensor = interpolation.tensor;
200 torch::Tensor indexer
201 = torch::abs(interTensor.reshape({-1, 2}).index({Slice(), 0})
202 - interTensor.reshape({-1, 2}).index({Slice(), 1}))
203 > (maxPers * 0.001 / 100.0);
204 torch::Tensor indexed = interTensor.reshape({-1, 2}).index({indexer});
205 return indexed.sizes()[0] > interpolation.mTree.tree.getRealNumberOfNodes();
206}
207#endif
T mean(const T *v, const int &dimension=3)
Definition Statistics.cpp:9
unsigned int idNode
Node index in vect_nodes_.
The Topology ToolKit.