7using namespace torch::indexing;
9void mtu::copyTensor(torch::Tensor &a, torch::Tensor &b) {
10 b = a.detach().clone();
11 b.requires_grad_(a.requires_grad());
14void mtu::getDeltaProjTensor(torch::Tensor &diagTensor,
15 torch::Tensor &deltaProjTensor) {
17 = (diagTensor.index({Slice(), 0}) + diagTensor.index({Slice(), 1})) / 2.0;
18 deltaProjTensor = deltaProjTensor.reshape({-1, 1});
19 deltaProjTensor = torch::cat({deltaProjTensor, deltaProjTensor}, 1);
22void mtu::dataReorderingGivenMatching(mtu::TorchMergeTree<float> &tree,
23 mtu::TorchMergeTree<float> &tree2,
24 torch::Tensor &tree1ProjIndexer,
25 torch::Tensor &tree2ReorderingIndexes,
26 torch::Tensor &tree2ReorderedTensor,
27 torch::Tensor &tree2DeltaProjTensor,
28 torch::Tensor &tree1ReorderedTensor,
29 torch::Tensor &tree2ProjIndexer,
30 bool doubleReordering) {
32 torch::Tensor tree2DiagTensor = tree2.tensor.reshape({-1, 2});
33 tree2ReorderedTensor = torch::cat({tree2DiagTensor, torch::zeros({1, 2})});
34 tree2ReorderedTensor = tree2ReorderedTensor.index({tree2ReorderingIndexes});
37 torch::Tensor treeDiagTensor = tree.tensor.reshape({-1, 2});
38 getDeltaProjTensor(treeDiagTensor, tree2DeltaProjTensor);
39 tree2DeltaProjTensor = tree2DeltaProjTensor * tree1ProjIndexer;
42 if(doubleReordering) {
43 torch::Tensor tree1DeltaProjTensor;
44 getDeltaProjTensor(tree2DiagTensor, tree1DeltaProjTensor);
45 torch::Tensor tree2ProjIndexerR = tree2ProjIndexer.reshape({-1});
46 tree1DeltaProjTensor = tree1DeltaProjTensor.index({tree2ProjIndexerR});
47 tree1ReorderedTensor = torch::cat({treeDiagTensor, tree1DeltaProjTensor});
48 tree1ReorderedTensor = tree1ReorderedTensor.reshape({-1, 1});
49 torch::Tensor tree2UnmatchedTensor
50 = tree2DiagTensor.index({tree2ProjIndexerR});
52 = torch::cat({tree2ReorderedTensor, tree2UnmatchedTensor});
53 tree2DeltaProjTensor = torch::cat(
54 {tree2DeltaProjTensor, torch::zeros_like(tree2UnmatchedTensor)});
58 tree2ReorderedTensor = tree2ReorderedTensor.reshape({-1, 1});
59 tree2DeltaProjTensor = tree2DeltaProjTensor.reshape({-1, 1});
62void mtu::dataReorderingGivenMatching(mtu::TorchMergeTree<float> &tree,
63 mtu::TorchMergeTree<float> &tree2,
64 torch::Tensor &tree1ProjIndexer,
65 torch::Tensor &tree2ReorderingIndexes,
66 torch::Tensor &tree2ReorderedTensor,
67 torch::Tensor &tree2DeltaProjTensor) {
68 torch::Tensor tree1ReorderedTensor;
69 torch::Tensor tree2ProjIndexer;
70 bool doubleReordering =
false;
71 dataReorderingGivenMatching(tree, tree2, tree1ProjIndexer,
72 tree2ReorderingIndexes, tree2ReorderedTensor,
73 tree2DeltaProjTensor, tree1ReorderedTensor,
74 tree2ProjIndexer, doubleReordering);
77void mtu::dataReorderingGivenMatching(
78 mtu::TorchMergeTree<float> &tree,
79 mtu::TorchMergeTree<float> &tree2,
80 std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> &matching,
81 torch::Tensor &tree1ReorderedTensor,
82 torch::Tensor &tree2ReorderedTensor,
83 bool doubleReordering) {
85 std::vector<int> tensorMatching;
86 mtu::getTensorMatching(tree, tree2, matching, tensorMatching);
87 torch::Tensor tree2ReorderingIndexes = torch::tensor(tensorMatching);
88 torch::Tensor tree1ProjIndexer
89 = (tree2ReorderingIndexes == -1).reshape({-1, 1});
91 torch::Tensor tree2DeltaProjTensor;
92 if(not doubleReordering) {
93 dataReorderingGivenMatching(tree, tree2, tree1ProjIndexer,
94 tree2ReorderingIndexes, tree2ReorderedTensor,
95 tree2DeltaProjTensor);
97 std::vector<int> tensorMatching2;
98 mtu::getInverseTensorMatching(tree, tree2, matching, tensorMatching);
99 torch::Tensor tree1ReorderingIndexes = torch::tensor(tensorMatching);
100 torch::Tensor tree2ProjIndexer
101 = (tree1ReorderingIndexes == -1).reshape({-1, 1});
102 dataReorderingGivenMatching(tree, tree2, tree1ProjIndexer,
103 tree2ReorderingIndexes, tree2ReorderedTensor,
104 tree2DeltaProjTensor, tree1ReorderedTensor,
105 tree2ProjIndexer, doubleReordering);
107 tree2ReorderedTensor = tree2ReorderedTensor + tree2DeltaProjTensor;
110void mtu::dataReorderingGivenMatching(
111 mtu::TorchMergeTree<float> &tree,
112 mtu::TorchMergeTree<float> &tree2,
113 std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> &matching,
114 torch::Tensor &tree2ReorderedTensor) {
115 torch::Tensor tree1ReorderedTensor;
116 bool doubleReordering =
false;
117 dataReorderingGivenMatching(tree, tree2, matching, tree1ReorderedTensor,
118 tree2ReorderedTensor, doubleReordering);
121void mtu::meanBirthShift(torch::Tensor &diagTensor,
122 torch::Tensor &diagBaseTensor) {
123 torch::Tensor birthShiftValue = diagBaseTensor.index({Slice(), 0}).
mean()
124 - diagTensor.index({Slice(), 0}).
mean();
125 torch::Tensor shiftTensor
126 = torch::full({diagTensor.sizes()[0], 2}, birthShiftValue.item<
float>());
127 diagTensor.index_put_({None}, diagTensor + shiftTensor);
130void mtu::meanBirthMaxPersShift(torch::Tensor &tensor,
131 torch::Tensor &baseTensor) {
132 torch::Tensor diagTensor = tensor.reshape({-1, 2});
133 torch::Tensor diagBaseTensor = baseTensor.reshape({-1, 2});
135 torch::Tensor baseMaxPers
136 = (diagBaseTensor.index({Slice(), 1}) - diagBaseTensor.index({Slice(), 0}))
138 torch::Tensor maxPers
139 = (diagTensor.index({Slice(), 1}) - diagTensor.index({Slice(), 0})).max();
140 torch::Tensor shiftTensor = (baseMaxPers - maxPers) / 2.0;
141 shiftTensor = torch::stack({-shiftTensor, shiftTensor});
142 diagTensor.index_put_({None}, diagTensor + shiftTensor);
144 meanBirthShift(diagTensor, diagBaseTensor);
// Moves every row of `tensor` (read as pairs) whose column 0 is strictly
// greater than its column 1 back to a valid position, by subtracting a shift
// from column 0 and adding it to column 1. The caller's `tensor` is then
// replaced by a detached (-1, 1) reshape of the corrected pairs.
// `backupTensor` supplies the reference rows when `tensor` has no valid row.
// NOTE(review): this listing is incomplete. The embedded line numbers jump
// 155 -> 158, 158 -> 160, 163 -> 165 and 166 -> 168, and several lines are
// missing as a result:
//   - the `= (goodPoints ...)` expression below has no left-hand side;
//   - `pers` is used but never declared;
//   - the reduction applied to the gaps and the tail of the `shiftTensor`
//     expression are absent;
//   - the closing brace is absent.
// Restore the missing lines from the upstream source before this compiles.
147void mtu::belowDiagonalPointsShift(torch::Tensor &tensor,
148 torch::Tensor &backupTensor) {
149 torch::Tensor oPDiag = tensor.reshape({-1, 2});
// Mask of the invalid rows: column 0 > column 1.
150 torch::Tensor badPointsIndexer
151 = (oPDiag.index({Slice(), 0}) > oPDiag.index({Slice(), 1}));
// Valid rows; when there are none, the rows of backupTensor are used instead.
152 torch::Tensor goodPoints = oPDiag.index({~badPointsIndexer});
153 if(goodPoints.sizes()[0] == 0)
154 goodPoints = backupTensor.reshape({-1, 2});
155 torch::Tensor badPoints = oPDiag.index({badPointsIndexer});
// Target gap `pers`: a reduction (not visible in this listing) over the gaps
// (column 1 - column 0) of the valid rows.
158 = (goodPoints.index({Slice(), 1}) - goodPoints.index({Slice(), 0}))
// Per invalid row: pers - (column 1 - column 0). If the missing tail divides
// this by 2 (TODO confirm), applying -shift / +shift below gives each invalid
// row a gap of exactly `pers`.
160 torch::Tensor shiftTensor
161 = (torch::full({badPoints.sizes()[0], 1}, pers.item<
float>())
162 - badPoints.index({Slice(), 1}).reshape({-1, 1})
163 + badPoints.index({Slice(), 0}).reshape({-1, 1}))
165 shiftTensor = torch::cat({-shiftTensor, shiftTensor}, 1);
166 badPoints = badPoints + shiftTensor;
// Write the corrected rows back, then replace the caller's tensor with a
// detached column (this drops any autograd history it had).
168 oPDiag.index_put_({badPointsIndexer}, badPoints);
169 tensor = oPDiag.reshape({-1, 1}).detach();
172void mtu::normalizeVectors(torch::Tensor &originTensor,
173 torch::Tensor &vectorsTensor) {
174 torch::Tensor vSliced = vectorsTensor.index({Slice(2, None)});
175 vSliced.index_put_({None}, vSliced / (originTensor[1] - originTensor[0]));
178void mtu::normalizeVectors(mtu::TorchMergeTree<float> &origin,
179 std::vector<std::vector<double>> &vectors) {
180 std::queue<ftm::idNode> queue;
181 queue.emplace(origin.mTree.tree.getRoot());
182 while(!queue.empty()) {
185 if(not origin.mTree.tree.isRoot(node))
186 for(
unsigned int i = 0; i < 2; ++i)
187 vectors[node][i] /= (origin.tensor[1] - origin.tensor[0]).item<
float>();
188 std::vector<ftm::idNode> children;
189 origin.mTree.tree.getChildren(node, children);
190 for(
auto &child : children)
191 queue.emplace(child);
196bool mtu::isThereMissingPairs(mtu::TorchMergeTree<float> &interpolation) {
198 = interpolation.mTree.tree.template getMaximumPersistence<float>();
199 torch::Tensor interTensor = interpolation.tensor;
200 torch::Tensor indexer
201 = torch::abs(interTensor.reshape({-1, 2}).index({Slice(), 0})
202 - interTensor.reshape({-1, 2}).index({Slice(), 1}))
203 > (maxPers * 0.001 / 100.0);
204 torch::Tensor indexed = interTensor.reshape({-1, 2}).index({indexer});
205 return indexed.sizes()[0] > interpolation.mTree.tree.getRealNumberOfNodes();
T mean(const T *v, const int &dimension=3)
unsigned int idNode
Node index in vect_nodes_.