TTK
Loading...
Searching...
No Matches
MergeTreeTorchUtils.cpp
Go to the documentation of this file.
2
3using namespace std;
4using namespace ttk;
5
6#ifdef TTK_ENABLE_TORCH
7using namespace torch::indexing;
8
9void mtu::copyTensor(const torch::Tensor &a, torch::Tensor &b) {
10 b = a.detach().clone();
11 b.requires_grad_(a.requires_grad());
12}
13
14void mtu::getDeltaProjTensor(torch::Tensor &diagTensor,
15 torch::Tensor &deltaProjTensor) {
16 deltaProjTensor
17 = (diagTensor.index({Slice(), 0}) + diagTensor.index({Slice(), 1})) / 2.0;
18 deltaProjTensor = deltaProjTensor.reshape({-1, 1});
19 deltaProjTensor = torch::cat({deltaProjTensor, deltaProjTensor}, 1);
20}
21
22void mtu::dataReorderingGivenMatching(const mtu::TorchMergeTree<float> &tree,
23 const mtu::TorchMergeTree<float> &tree2,
24 torch::Tensor &tree1ProjIndexer,
25 torch::Tensor &tree2ReorderingIndexes,
26 torch::Tensor &tree2ReorderedTensor,
27 torch::Tensor &tree2DeltaProjTensor,
28 torch::Tensor &tree1ReorderedTensor,
29 torch::Tensor &tree2ProjIndexer,
30 bool doubleReordering) {
31 // Reorder tree2 tensor
32 torch::Tensor tree2DiagTensor = tree2.tensor.reshape({-1, 2});
33 auto zeros = torch::zeros(
34 {1, 2}, torch::TensorOptions().device(tree2DiagTensor.device()));
35 tree2ReorderedTensor = torch::cat({tree2DiagTensor, zeros});
36 tree2ReorderedTensor = tree2ReorderedTensor.index({tree2ReorderingIndexes});
37
38 // Create tree projection given matching
39 torch::Tensor treeDiagTensor = tree.tensor.reshape({-1, 2});
40 getDeltaProjTensor(treeDiagTensor, tree2DeltaProjTensor);
41 if(!tree2DeltaProjTensor.device().is_cpu())
42 tree1ProjIndexer = tree1ProjIndexer.to(tree2DeltaProjTensor.device());
43 tree2DeltaProjTensor = tree2DeltaProjTensor * tree1ProjIndexer;
44
45 // Double reordering
46 if(doubleReordering) {
47 torch::Tensor tree1DeltaProjTensor;
48 getDeltaProjTensor(tree2DiagTensor, tree1DeltaProjTensor);
49 torch::Tensor tree2ProjIndexerR = tree2ProjIndexer.reshape({-1});
50 tree1DeltaProjTensor = tree1DeltaProjTensor.index({tree2ProjIndexerR});
51 tree1ReorderedTensor = torch::cat({treeDiagTensor, tree1DeltaProjTensor});
52 tree1ReorderedTensor = tree1ReorderedTensor.reshape({-1, 1});
53 torch::Tensor tree2UnmatchedTensor
54 = tree2DiagTensor.index({tree2ProjIndexerR});
55 tree2ReorderedTensor
56 = torch::cat({tree2ReorderedTensor, tree2UnmatchedTensor});
57 tree2DeltaProjTensor = torch::cat(
58 {tree2DeltaProjTensor, torch::zeros_like(tree2UnmatchedTensor)});
59 }
60
61 // Reshape
62 tree2ReorderedTensor = tree2ReorderedTensor.reshape({-1, 1});
63 tree2DeltaProjTensor = tree2DeltaProjTensor.reshape({-1, 1});
64}
65
66void mtu::dataReorderingGivenMatching(const mtu::TorchMergeTree<float> &tree,
67 const mtu::TorchMergeTree<float> &tree2,
68 torch::Tensor &tree1ProjIndexer,
69 torch::Tensor &tree2ReorderingIndexes,
70 torch::Tensor &tree2ReorderedTensor,
71 torch::Tensor &tree2DeltaProjTensor) {
72 torch::Tensor tree1ReorderedTensor;
73 torch::Tensor tree2ProjIndexer;
74 bool doubleReordering = false;
75 dataReorderingGivenMatching(tree, tree2, tree1ProjIndexer,
76 tree2ReorderingIndexes, tree2ReorderedTensor,
77 tree2DeltaProjTensor, tree1ReorderedTensor,
78 tree2ProjIndexer, doubleReordering);
79}
80
81void mtu::dataReorderingGivenMatching(
82 const mtu::TorchMergeTree<float> &tree,
83 const mtu::TorchMergeTree<float> &tree2,
84 std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> &matching,
85 torch::Tensor &tree1ReorderedTensor,
86 torch::Tensor &tree2ReorderedTensor,
87 bool doubleReordering) {
88 // Get tensor matching
89 std::vector<int> tensorMatching;
90 mtu::getTensorMatching(tree, tree2, matching, tensorMatching);
91 torch::Tensor tree2ReorderingIndexes = torch::tensor(tensorMatching);
92 torch::Tensor tree1ProjIndexer
93 = (tree2ReorderingIndexes == -1).reshape({-1, 1});
94 // Reorder tensor
95 torch::Tensor tree2DeltaProjTensor;
96 if(not doubleReordering) {
97 dataReorderingGivenMatching(tree, tree2, tree1ProjIndexer,
98 tree2ReorderingIndexes, tree2ReorderedTensor,
99 tree2DeltaProjTensor);
100 } else {
101 std::vector<int> tensorMatching2;
102 mtu::getInverseTensorMatching(tree, tree2, matching, tensorMatching);
103 torch::Tensor tree1ReorderingIndexes = torch::tensor(tensorMatching);
104 torch::Tensor tree2ProjIndexer
105 = (tree1ReorderingIndexes == -1).reshape({-1, 1});
106 dataReorderingGivenMatching(tree, tree2, tree1ProjIndexer,
107 tree2ReorderingIndexes, tree2ReorderedTensor,
108 tree2DeltaProjTensor, tree1ReorderedTensor,
109 tree2ProjIndexer, doubleReordering);
110 }
111 tree2ReorderedTensor = tree2ReorderedTensor + tree2DeltaProjTensor;
112}
113
114void mtu::dataReorderingGivenMatching(
115 const mtu::TorchMergeTree<float> &tree,
116 const mtu::TorchMergeTree<float> &tree2,
117 std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> &matching,
118 torch::Tensor &tree2ReorderedTensor) {
119 torch::Tensor tree1ReorderedTensor;
120 bool doubleReordering = false;
121 dataReorderingGivenMatching(tree, tree2, matching, tree1ReorderedTensor,
122 tree2ReorderedTensor, doubleReordering);
123}
124
125void mtu::meanBirthShift(torch::Tensor &diagTensor,
126 torch::Tensor &diagBaseTensor) {
127 torch::Tensor birthShiftValue = diagBaseTensor.index({Slice(), 0}).mean()
128 - diagTensor.index({Slice(), 0}).mean();
129 torch::Tensor shiftTensor
130 = torch::full({diagTensor.sizes()[0], 2}, birthShiftValue.item<float>(),
131 torch::TensorOptions().device(diagTensor.device()));
132 diagTensor.index_put_({None}, diagTensor + shiftTensor);
133}
134
135void mtu::meanBirthMaxPersShift(torch::Tensor &tensor,
136 torch::Tensor &baseTensor) {
137 torch::Tensor diagTensor = tensor.reshape({-1, 2});
138 torch::Tensor diagBaseTensor = baseTensor.reshape({-1, 2});
139 // Shift to have same max pers
140 torch::Tensor baseMaxPers
141 = (diagBaseTensor.index({Slice(), 1}) - diagBaseTensor.index({Slice(), 0}))
142 .max();
143 torch::Tensor maxPers
144 = (diagTensor.index({Slice(), 1}) - diagTensor.index({Slice(), 0})).max();
145 torch::Tensor shiftTensor = (baseMaxPers - maxPers) / 2.0;
146 shiftTensor = torch::stack({-shiftTensor, shiftTensor});
147 if(!diagTensor.device().is_cpu())
148 shiftTensor = shiftTensor.to(diagTensor.device());
149 diagTensor.index_put_({None}, diagTensor + shiftTensor);
150 // Shift to have same birth mean
151 meanBirthShift(diagTensor, diagBaseTensor);
152}
153
154void mtu::belowDiagonalPointsShift(torch::Tensor &tensor,
155 torch::Tensor &backupTensor) {
156 torch::Tensor oPDiag = tensor.reshape({-1, 2});
157 torch::Tensor badPointsIndexer
158 = (oPDiag.index({Slice(), 0}) > oPDiag.index({Slice(), 1}));
159 torch::Tensor goodPoints = oPDiag.index({~badPointsIndexer});
160 if(goodPoints.sizes()[0] == 0)
161 goodPoints = backupTensor.reshape({-1, 2});
162 torch::Tensor badPoints = oPDiag.index({badPointsIndexer});
163 // Shift to be above diagonal with median pers
164 torch::Tensor pers
165 = (goodPoints.index({Slice(), 1}) - goodPoints.index({Slice(), 0}))
166 .median();
167 torch::Tensor shiftTensor
168 = torch::full({badPoints.sizes()[0], 1}, pers.item<float>(),
169 torch::TensorOptions().device(badPoints.device()));
170 shiftTensor = (shiftTensor - badPoints.index({Slice(), 1}).reshape({-1, 1})
171 + badPoints.index({Slice(), 0}).reshape({-1, 1}))
172 / 2.0;
173 shiftTensor = torch::cat({-shiftTensor, shiftTensor}, 1);
174 badPoints = badPoints + shiftTensor;
175 // Update tensor
176 oPDiag.index_put_({badPointsIndexer}, badPoints);
177 tensor = oPDiag.reshape({-1, 1}).detach();
178}
179
180void mtu::normalizeVectors(torch::Tensor &originTensor,
181 torch::Tensor &vectorsTensor) {
182 torch::Tensor vSliced = vectorsTensor.index({Slice(2, None)});
183 vSliced.index_put_({None}, vSliced / (originTensor[1] - originTensor[0]));
184}
185
186void mtu::normalizeVectors(mtu::TorchMergeTree<float> &origin,
187 std::vector<std::vector<double>> &vectors) {
188 std::queue<ftm::idNode> queue;
189 queue.emplace(origin.mTree.tree.getRoot());
190 while(!queue.empty()) {
191 ftm::idNode node = queue.front();
192 queue.pop();
193 if(not origin.mTree.tree.isRoot(node))
194 for(unsigned int i = 0; i < 2; ++i)
195 vectors[node][i] /= (origin.tensor[1] - origin.tensor[0]).item<float>();
196 std::vector<ftm::idNode> children;
197 origin.mTree.tree.getChildren(node, children);
198 for(auto &child : children)
199 queue.emplace(child);
200 }
201}
202
203// Work only for persistence diagrams
204bool mtu::isThereMissingPairs(mtu::TorchMergeTree<float> &interpolation) {
205 float maxPers
206 = interpolation.mTree.tree.template getMaximumPersistence<float>();
207 torch::Tensor interTensor = interpolation.tensor;
208 torch::Tensor indexer
209 = torch::abs(interTensor.reshape({-1, 2}).index({Slice(), 0})
210 - interTensor.reshape({-1, 2}).index({Slice(), 1}))
211 > (maxPers * 0.001 / 100.0);
212 torch::Tensor indexed = interTensor.reshape({-1, 2}).index({indexer});
213 bool isMissingPairs
214 = indexed.sizes()[0] > interpolation.mTree.tree.getRealNumberOfNodes();
215 return isMissingPairs;
216}
217#endif
T mean(const T *v, const int &dimension=3)
Definition Statistics.cpp:9
unsigned int idNode
Node index in vect_nodes_.
FiltratedEdge max(const FiltratedEdge &a, const FiltratedEdge &b)
TTK base package defining the standard types.