TTK
Loading...
Searching...
No Matches
KDTree.h
Go to the documentation of this file.
1
8
9#pragma once
10
11// base code includes
#include <Debug.h>
#include <Geometry.h> // for pow

#include <algorithm>
#include <cmath>
#include <limits>
#include <memory>
#include <numeric>
#include <vector>
18
namespace ttk {
  // KD-tree node. Every node stores one point (coordinates_), the axis-aligned
  // box enclosing its subtree (coords_min_ / coords_max_), one or several
  // weights, and owns its two children.
  //
  // dataType  : scalar type of coordinates, weights and costs.
  // Container : indexable container of dataType holding one value per axis.
  //             It is indexed but never resized by this class.
  template <typename dataType, typename Container>
  class KDTree {

  protected:
    // Boolean indicating if the current node is a left node of its parent
    bool is_left_{};
    // Indicates according to which coordinate the tree splits its elements
    int coords_number_{0};
    // Power used for the computation of distances. p=2 yields euclidean
    // distance
    int p_{2};
    // Whether or not the KDTree should include weights that add up to distance
    // for the computation of nearest neighbours
    bool include_weights_{false};

  public:
    using KDTreeRoot = std::unique_ptr<KDTree>;
    using KDTreeMap = std::vector<KDTree *>;

    KDTreeRoot left_{}; // Lower half for the coordinate specified
    KDTreeRoot right_{}; // Higher half
    // Non-owning pointer to the parent node, nullptr for the root
    KDTree *parent_{};
    // ID of the object saved here. The whole object is not kept in the KDTree
    // Users should keep track of them in a table for instance
    int id_{};
    // Coordinates of the point stored in this node
    Container coordinates_{};
    // Lower / upper corner of the box enclosing the subtree of this node
    Container coords_min_{};
    Container coords_max_{};
    // Depth of the node, 0 for the root
    int level_{};

    // Weights of this node; weight_[i] is added to the cost in getKClosest
    std::vector<dataType> weight_{};
    // min_subweights_[i] is the smallest weight_[i] found in the subtree
    // rooted at this node (this node included)
    std::vector<dataType> min_subweights_{};

    KDTree() = default;
    KDTree(const bool include_weights, const int p)
      : p_{p}, include_weights_{include_weights} {
    }

    // Creates a child of `father` (must not be null) splitting along axis
    // `coords_number`. The distance power and the weight flag are inherited
    // from the father so that queries behave the same on every node.
    KDTree(KDTree *const father, const int coords_number, const bool is_left)
      : is_left_{is_left}, coords_number_{coords_number}, p_{father->p_},
        include_weights_{father->include_weights_}, parent_{father} {
    }

    // ptNumber : Number of points in the dataset.
    // nodeNumber : Number of nodes in the tree.
    // preciseBoundingBox : Allows for a bounding box that is accurate relative
    // to the input data. It is recommended to set it to false for persistence
    // diagrams and true otherwise.
    KDTreeMap build(dataType *data,
                    const int &ptNumber,
                    const int &dimension,
                    const std::vector<std::vector<dataType>> &weights = {},
                    const int &weightNumber = 1,
                    const int &nodeNumber = -1,
                    const bool &preciseBoundingBox = false);

    void buildRecursive(dataType *data,
                        std::vector<int> &idx_side,
                        const int &ptNumber,
                        const int &dimension,
                        KDTree<dataType, Container> *parent,
                        KDTreeMap &correspondence_map,
                        const int &nodeNumber,
                        const int &maximumLevel,
                        int &createdNumberNode,
                        const std::vector<std::vector<dataType>> &weights = {},
                        const int &weightNumber = 1);

    // Replaces one weight of this node and refreshes the subtree minima that
    // depend on it.
    inline void updateWeight(const dataType new_weight,
                             const int weight_index = 0) {
      weight_[weight_index] = new_weight;
      updateMinSubweight(weight_index);
    }

    void updateMinSubweight(const int weight_index = 0);
    void getKClosest(const unsigned int k,
                     const Container &coordinates,
                     KDTreeMap &neighbours,
                     std::vector<dataType> &costs,
                     const int weight_index = 0);

    template <typename PowerFunc>
    void recursiveGetKClosest(const unsigned int k,
                              const Container &coordinates,
                              KDTreeMap &neighbours,
                              std::vector<dataType> &costs,
                              const int weight_index,
                              const PowerFunc &power);

    // Sum over all axes of power(|coordinates[i] - coordinates_[i]|).
    // The weight of the node is NOT included.
    template <typename PowerFunc>
    inline dataType getCost(const Container &coordinates,
                            const PowerFunc &power) const {
      dataType cost = 0;
      for(size_t i = 0; i < coordinates.size(); i++) {
        cost += power(std::abs(coordinates[i] - coordinates_[i]));
      }
      return cost;
    }

    // Lower bound of getCost() over the bounding box of `subtree`: for every
    // axis on which `coordinates` lies outside the box, the powered gap to
    // the nearest face is accumulated. Returns 0 for a point inside the box.
    template <typename PowerFunc>
    inline dataType distanceToBox(const KDTree<dataType, Container> &subtree,
                                  const Container &coordinates,
                                  const PowerFunc &power) const {
      dataType d_min = 0;
      for(size_t axis = 0; axis < coordinates.size(); axis++) {
        if(subtree.coords_min_[axis] > coordinates[axis]) {
          d_min += power(subtree.coords_min_[axis] - coordinates[axis]);
        } else if(subtree.coords_max_[axis] < coordinates[axis]) {
          d_min += power(coordinates[axis] - subtree.coords_max_[axis]);
        }
      }
      return d_min;
    }

    // Returns the stored point. The return type follows Container so that
    // the accessor is usable with any container type, not only std::vector.
    inline const Container &getCoordinates() const {
      return coordinates_;
    }
    inline dataType getWeight(const int weight_index = 0) const {
      return weight_[weight_index];
    }
    inline dataType getMinSubWeight(const int weight_index = 0) const {
      return min_subweights_[weight_index];
    }
    inline bool isLeaf() const {
      return left_ == nullptr && right_ == nullptr;
    }
    inline bool isRoot() const {
      return parent_ == nullptr;
    }
  };
} // namespace ttk
151
152template <typename dataType, typename Container>
155 dataType *data,
156 const int &ptNumber,
157 const int &dimension,
158 const std::vector<std::vector<dataType>> &weights,
159 const int &weightNumber,
160 const int &nodeNumber,
161 const bool &preciseBoundingBox) {
162
163 int createdNumberNode = 1;
164 int idGenerator = createdNumberNode - 1;
165 int maximumLevel = 0;
166
167 int correspondence_map_size = 0;
168
169 if(nodeNumber == -1) {
170 correspondence_map_size = ptNumber;
171 } else {
172 correspondence_map_size = nodeNumber;
173 maximumLevel = ceil(log2(nodeNumber + 1)) - 1;
174 }
175
176 KDTreeMap correspondence_map(correspondence_map_size);
177
178 if(preciseBoundingBox) {
179 // First, perform a argsort on the data
180 // initialize original index locations
181 dataType x_max = std::numeric_limits<dataType>::lowest();
182 dataType x_min = std::numeric_limits<dataType>::max();
183 dataType y_max = std::numeric_limits<dataType>::lowest();
184 dataType y_min = std::numeric_limits<dataType>::max();
185 dataType z_max = std::numeric_limits<dataType>::lowest();
186 dataType z_min = std::numeric_limits<dataType>::max();
187
188 for(int i = 0; i < ptNumber * dimension; i += dimension) {
189 if(x_max < data[i]) {
190 x_max = data[i];
191 }
192 if(x_min > data[i]) {
193 x_min = data[i];
194 }
195
196 if(y_max < data[i + 1]) {
197 y_max = data[i + 1];
198 }
199 if(y_min > data[i + 1]) {
200 y_min = data[i + 1];
201 }
202
203 if(dimension > 2) {
204 if(z_max < data[i + 2]) {
205 z_max = data[i + 2];
206 }
207 if(z_min > data[i + 2]) {
208 z_min = data[i + 2];
209 }
210 }
211 }
212
213 coords_min_[0] = x_min;
214 coords_max_[0] = x_max;
215 coords_min_[1] = y_min;
216 coords_max_[1] = y_max;
217 if(dimension > 2) {
218 coords_min_[2] = z_min;
219 coords_max_[2] = z_max;
220 }
221 } else {
222 for(int axis = 0; axis < dimension; axis++) {
223 coords_min_[axis] = std::numeric_limits<dataType>::lowest();
224 coords_max_[axis] = std::numeric_limits<dataType>::max();
225 }
226 }
227
228 std::vector<int> idx(ptNumber);
229 for(int i = 0; i < ptNumber; i++) {
230 idx[i] = i;
231 }
232 // sort indexes based on comparing values in coordinates
233 sort(idx.begin(), idx.end(), [&](int i1, int i2) {
234 return data[dimension * i1 + coords_number_]
235 < data[dimension * i2 + coords_number_];
236 });
237 int median_loc = (int)(ptNumber - 1) / 2;
238 int median_idx = idx[median_loc];
239
240 for(int axis = 0; axis < dimension; axis++) {
241 coordinates_[axis] = data[dimension * median_idx + axis];
242 }
243
244 if(nodeNumber == -1) {
245 correspondence_map[median_idx] = this;
246 id_ = median_idx;
247 } else {
248 correspondence_map[idGenerator] = this;
249 id_ = idGenerator;
250 }
251
252 parent_ = nullptr;
253 level_ = 0;
254
255 this->weight_.clear();
256 this->min_subweights_.clear();
257
258 if(weights.empty()) {
259 this->weight_.resize(weightNumber);
260 this->min_subweights_.resize(weightNumber);
261 } else {
262 for(int i = 0; i < weightNumber; i++) {
263 weight_.push_back(weights[i][median_idx]);
264 min_subweights_.push_back(weights[i][median_idx]);
265 }
266 }
267
268 if(((nodeNumber == -1) && (idx.size() > 2))
269 || ((nodeNumber != -1) && (createdNumberNode < nodeNumber))) {
270 // Build left leaf
271 std::vector<int> idx_left(median_loc);
272 for(int i = 0; i < median_loc; i++) {
273 idx_left[i] = idx[i];
274 }
275
276 this->left_
277 = std::make_unique<KDTree>(this, (coords_number_ + 1) % dimension, true);
278 this->left_->buildRecursive(data, idx_left, ptNumber, dimension, this,
279 correspondence_map, nodeNumber, maximumLevel,
280 createdNumberNode, weights, weightNumber);
281 }
282
283 if(((nodeNumber == -1) && (idx.size() > 1))
284 || ((nodeNumber != -1) && (createdNumberNode < nodeNumber))) {
285 // Build right leaf
286 std::vector<int> idx_right(ptNumber - median_loc - 1);
287 for(int i = 0; i < ptNumber - median_loc - 1; i++) {
288 idx_right[i] = idx[i + median_loc + 1];
289 }
290 this->right_
291 = std::make_unique<KDTree>(this, (coords_number_ + 1) % dimension, false);
292 this->right_->buildRecursive(data, idx_right, ptNumber, dimension, this,
293 correspondence_map, nodeNumber, maximumLevel,
294 createdNumberNode, weights, weightNumber);
295 }
296
297 return correspondence_map;
298}
299
300template <typename dataType, typename Container>
302 dataType *data,
303 std::vector<int> &idx_side,
304 const int &ptNumber,
305 const int &dimension,
307 KDTreeMap &correspondence_map,
308 const int &nodeNumber,
309 const int &maximumLevel,
310 int &createdNumberNode,
311 const std::vector<std::vector<dataType>> &weights,
312 const int &weightNumber) {
313
314 createdNumberNode++;
315 int idGenerator = createdNumberNode - 1;
316
317 // First, perform a argsort on the data
318 sort(idx_side.begin(), idx_side.end(), [&](int i1, int i2) {
319 return data[dimension * i1 + coords_number_]
320 < data[dimension * i2 + coords_number_];
321 });
322 int median_loc = (int)(idx_side.size() - 1) / 2;
323 int median_idx = idx_side[median_loc];
324
325 for(int axis = 0; axis < dimension; axis++) {
326 coordinates_[axis] = data[dimension * median_idx + axis];
327 }
328
329 if(nodeNumber == -1) {
330 correspondence_map[median_idx] = this;
331 id_ = median_idx;
332 } else {
333 correspondence_map[idGenerator] = this;
334 id_ = idGenerator;
335 }
336
337 parent_ = parent;
338 level_ = parent->level_ + 1;
339
340 this->weight_.clear();
341 this->min_subweights_.clear();
342
343 if(weights.empty()) {
344 this->weight_.resize(weightNumber);
345 this->min_subweights_.resize(weightNumber);
346 } else {
347 for(int i = 0; i < weightNumber; i++) {
348 weight_.push_back(weights[i][median_idx]);
349 min_subweights_.push_back(weights[i][median_idx]);
350 }
351
352 if(idx_side.size() > 1) {
353 // Once we get to a leaf, update min_subweights of the parents
354 for(int w = 0; w < weightNumber; w++) {
355 this->updateMinSubweight(w);
356 }
357 }
358 }
359
360 // Create bounding box
361 for(int axis = 0; axis < dimension; axis++) {
362 coords_min_[axis] = parent_->coords_min_[axis];
363 coords_max_[axis] = parent_->coords_max_[axis];
364 }
365 if(is_left_ && !this->isRoot()) {
366 coords_max_[parent_->coords_number_]
367 = parent_->coordinates_[parent_->coords_number_];
368 } else if(!is_left_ && !this->isRoot()) {
369 coords_min_[parent_->coords_number_]
370 = parent_->coordinates_[parent_->coords_number_];
371 }
372
373 if(((nodeNumber == -1) && (idx_side.size() > 2))
374 || ((nodeNumber != -1) && (level_ < maximumLevel)
375 && (createdNumberNode < nodeNumber))) {
376 // Build left leaf
377 std::vector<int> idx_left(median_loc);
378 for(int i = 0; i < median_loc; i++) {
379 idx_left[i] = idx_side[i];
380 }
381
382 this->left_
383 = std::make_unique<KDTree>(this, (coords_number_ + 1) % dimension, true);
384 this->left_->buildRecursive(data, idx_left, ptNumber, dimension, this,
385 correspondence_map, nodeNumber, maximumLevel,
386 createdNumberNode, weights, weightNumber);
387 }
388
389 if(((nodeNumber == -1) && (idx_side.size() > 1))
390 || ((nodeNumber != -1) && (level_ < maximumLevel)
391 && (createdNumberNode < nodeNumber))) {
392 // Build right leaf
393 std::vector<int> idx_right(idx_side.size() - median_loc - 1);
394 for(unsigned int i = 0; i < idx_side.size() - median_loc - 1; i++) {
395 idx_right[i] = idx_side[i + median_loc + 1];
396 }
397 this->right_
398 = std::make_unique<KDTree>(this, (coords_number_ + 1) % dimension, false);
399 this->right_->buildRecursive(data, idx_right, ptNumber, dimension, this,
400 correspondence_map, nodeNumber, maximumLevel,
401 createdNumberNode, weights, weightNumber);
402 }
403}
404
405template <typename dataType, typename Container>
407 const int weight_index) {
408 dataType new_min_subweight;
409 if(this->isLeaf()) {
410 new_min_subweight = weight_[weight_index];
411 } else if(!left_) {
412 new_min_subweight
413 = std::min(right_->min_subweights_[weight_index], weight_[weight_index]);
414 } else if(!right_) {
415 new_min_subweight
416 = std::min(left_->min_subweights_[weight_index], weight_[weight_index]);
417 } else {
418 new_min_subweight
419 = std::min(std::min(left_->min_subweights_[weight_index],
420 right_->min_subweights_[weight_index]),
421 weight_[weight_index]);
422 }
423
424 if(new_min_subweight != min_subweights_[weight_index]) {
425 min_subweights_[weight_index] = new_min_subweight;
426 if(!this->isRoot()) {
427 parent_->updateMinSubweight(weight_index);
428 }
429 }
430}
431
432template <typename dataType, typename Container>
434 const Container &coordinates,
435 KDTreeMap &neighbours,
436 std::vector<dataType> &costs,
437 const int weight_index) {
438
439 const auto p{this->p_};
440
445 if(this->isLeaf()) {
446 dataType cost{};
447 TTK_POW_LAMBDA(cost = this->getCost, dataType, p, coordinates);
448 cost += weight_[weight_index];
449 neighbours.push_back(this);
450 costs.push_back(cost);
451 } else {
452 neighbours.reserve(k);
453 costs.reserve(k);
454 TTK_POW_LAMBDA(this->recursiveGetKClosest, dataType, p, k, coordinates,
455 neighbours, costs, weight_index);
456 }
457 // TODO sort neighbours and costs !
458}
459
460template <typename dataType, typename Container>
461template <typename PowerFunc>
463 const unsigned int k,
464 const Container &coordinates,
465 KDTreeMap &neighbours,
466 std::vector<dataType> &costs,
467 const int weight_index,
468 const PowerFunc &power) {
469 // 1- Look whether or not to include the current point in the nearest
470 // neighbours
471 dataType cost = this->getCost(coordinates, power);
472 cost += weight_[weight_index];
473
474 if(costs.size() < k) {
475 neighbours.push_back(this);
476 costs.push_back(cost);
477 } else {
478 // 1.1- Find the most costly amongst neighbours
479 const auto idx_max_cost = std::distance(
480 costs.begin(), std::max_element(costs.begin(), costs.begin() + k));
481 const dataType max_cost = costs[idx_max_cost];
482
483 // 1.2- If the current KDTree is less costly, put it in the neighbours and
484 // update costs.
485 if(cost < max_cost) {
486 costs[idx_max_cost] = cost;
487 neighbours[idx_max_cost] = this;
488 }
489 }
490
491 // 2- Recursively visit KDTrees that are worth it
492 if(left_) {
493 const dataType max_cost = *std::max_element(costs.begin(), costs.end());
494 const dataType min_subweight = left_->min_subweights_[weight_index];
495 const dataType d_min = this->distanceToBox(*left_, coordinates, power);
496 if(costs.size() < k || d_min + min_subweight < max_cost) {
497 // 2.2- It is possible that there exists a point in this subtree that is
498 // less costly than max_cost
499 left_->recursiveGetKClosest(
500 k, coordinates, neighbours, costs, weight_index, power);
501 }
502 }
503
504 if(right_) {
505 const dataType max_cost = *std::max_element(costs.begin(), costs.end());
506 const dataType min_subweight = right_->min_subweights_[weight_index];
507 const dataType d_min = this->distanceToBox(*right_, coordinates, power);
508 if(costs.size() < k || d_min + min_subweight < max_cost) {
509 // 2.2- It is possible that there exists a point in this subtree that is
510 // less costly than max_cost
511 right_->recursiveGetKClosest(
512 k, coordinates, neighbours, costs, weight_index, power);
513 }
514 }
515}
#define TTK_POW_LAMBDA(CALLEXPR, TYPE, EXPN,...)
Optimized Power function with lambdas.
Definition Geometry.h:430
TTK KD-Tree.
Definition KDTree.h:21
dataType getMinSubWeight(const int weight_index=0) const
Definition KDTree.h:140
bool isRoot() const
Definition KDTree.h:146
KDTreeMap build(dataType *data, const int &ptNumber, const int &dimension, const std::vector< std::vector< dataType > > &weights={}, const int &weightNumber=1, const int &nodeNumber=-1, const bool &preciseBoundingBox=false)
Definition KDTree.h:154
void recursiveGetKClosest(const unsigned int k, const Container &coordinates, KDTreeMap &neighbours, std::vector< dataType > &costs, const int weight_index, const PowerFunc &power)
Definition KDTree.h:462
Container coords_min_
Definition KDTree.h:46
void updateWeight(const dataType new_weight, const int weight_index=0)
Definition KDTree.h:88
bool include_weights_
Definition KDTree.h:33
void getKClosest(const unsigned int k, const Container &coordinates, KDTreeMap &neighbours, std::vector< dataType > &costs, const int weight_index=0)
Definition KDTree.h:433
std::vector< KDTree * > KDTreeMap
Definition KDTree.h:37
KDTree * parent_
Definition KDTree.h:41
dataType getCost(const Container &coordinates, const PowerFunc &power) const
Definition KDTree.h:110
std::vector< dataType > min_subweights_
Definition KDTree.h:51
std::vector< dataType > weight_
Definition KDTree.h:50
dataType getWeight(const int weight_index=0) const
Definition KDTree.h:137
bool is_left_
Definition KDTree.h:25
KDTree()=default
KDTree(KDTree *const father, const int coords_number, const bool is_left)
Definition KDTree.h:58
Container coordinates_
Definition KDTree.h:45
KDTree(const bool include_weights, const int p)
Definition KDTree.h:54
void buildRecursive(dataType *data, std::vector< int > &idx_side, const int &ptNumber, const int &dimension, KDTree< dataType, Container > *parent, KDTreeMap &correspondence_map, const int &nodeNumber, const int &maximumLevel, int &createdNumberNode, const std::vector< std::vector< dataType > > &weights={}, const int &weightNumber=1)
Definition KDTree.h:301
KDTreeRoot left_
Definition KDTree.h:39
void updateMinSubweight(const int weight_index=0)
Definition KDTree.h:406
const std::vector< dataType > & getCoordinates() const
Definition KDTree.h:134
KDTreeRoot right_
Definition KDTree.h:40
int coords_number_
Definition KDTree.h:27
bool isLeaf() const
Definition KDTree.h:143
int level_
Definition KDTree.h:48
std::unique_ptr< KDTree > KDTreeRoot
Definition KDTree.h:36
Container coords_max_
Definition KDTree.h:47
dataType distanceToBox(const KDTree< dataType, Container > &subtree, const Container &coordinates, const PowerFunc &power) const
Definition KDTree.h:120
int id_
Definition KDTree.h:44
The Topology ToolKit.