TTK
Loading...
Searching...
No Matches
TopologicalOptimization.h
Go to the documentation of this file.
1
6
7#pragma once
8
9#ifdef TTK_ENABLE_TORCH
10#include <torch/optim.h>
11#include <torch/torch.h>
12#endif
13
14// base code includes
15#include <Debug.h>
16#include <PersistenceDiagram.h>
18#include <Triangulation.h>
19
20namespace ttk {
21
// NOTE(review): this text is a Doxygen source-listing scrape of the header.
// Every line still carries its original listing number (e.g. "22 class"),
// and several listing lines are absent (e.g. 24, 33, 151-152, 155, 158),
// including the signature of the method whose body starts at listing line
// 34 and most protected member declarations. Restore from the real header
// before compiling.
//
// TopologicalOptimization: execute() reads input scalars/offsets and writes
// output scalars, given a constraint (target) persistence diagram;
// getIndices() selects the critical points that must be modified so that
// the current diagram matches that target diagram.
22 class TopologicalOptimization : virtual public Debug {
23 public:
25
// Main entry point (definition not visible in this chunk); const member,
// templated on the scalar type and the triangulation type.
26 template <typename dataType, typename triangulationType>
27 int execute(const dataType *const inputScalars,
28 dataType *const outputScalars,
29 SimplexId *const inputOffsets,
30 triangulationType *triangulation,
31 const ttk::DiagramType &constraintDiagram) const;
32
// NOTE(review): the signature for the body below (listing line 33) is
// missing; presumably preconditionTriangulation(triangulationType *) —
// confirm. When triangulation is non-null it caches the vertex count in
// vertexNumber_ and preconditions vertex neighbors (getIndices later calls
// getVertexNeighborNumber/getVertexNeighbor). Always returns 0.
34 if(triangulation) {
35 vertexNumber_ = triangulation->getNumberOfVertices();
36 triangulation->preconditionVertexNeighbors();
37 }
38 return 0;
39 }
40
41 /*
42 This function allows us to retrieve the indices of the critical points
43 that we must modify in order to match our current diagram to our target
44 diagram.
45 */
// In the visible part of its definition, the birth/death "ToDelete" and
// "ToChange" vectors are only appended to (push_back) and
// vertexInHowManyPairs is only incremented, so callers presumably pass
// empty / zero-filled containers each epoch — confirm at the call site.
// With epoch <= 0 (or fast persistence update disabled) every vertex is
// flagged for persistence-diagram update.
46 template <typename dataType, typename triangulationType>
47 void getIndices(
48 triangulationType *triangulation,
49 SimplexId *&inputOffsets,
50 dataType *const inputScalars,
51 const ttk::DiagramType &constraintDiagram,
52 int epoch,
53 std::vector<SimplexId> &listAllIndicesToChange,
54 std::vector<std::vector<SimplexId>> &pair2MatchedPair,
55 std::vector<std::vector<SimplexId>> &pair2Delete,
56 std::vector<SimplexId> &pairChangeMatchingPair,
57 std::vector<SimplexId> &birthPairToDeleteCurrentDiagram,
58 std::vector<double> &birthPairToDeleteTargetDiagram,
59 std::vector<SimplexId> &deathPairToDeleteCurrentDiagram,
60 std::vector<double> &deathPairToDeleteTargetDiagram,
61 std::vector<SimplexId> &birthPairToChangeCurrentDiagram,
62 std::vector<double> &birthPairToChangeTargetDiagram,
63 std::vector<SimplexId> &deathPairToChangeCurrentDiagram,
64 std::vector<double> &deathPairToChangeTargetDiagram,
65 std::vector<std::vector<SimplexId>> &currentVertex2PairsCurrentDiagram,
66 std::vector<int> &vertexInHowManyPairs) const;
67
68/*
69 This function allows you to copy the values of a pytorch tensor
70 to a vector in an optimized way.
71*/
72#ifdef TTK_ENABLE_TORCH
73 int tensorToVectorFast(const torch::Tensor &tensor,
74 std::vector<double> &result) const;
75#endif
76
// When enabled, getIndices (after epoch 0) only flags vertices marked in
// listAllIndicesToChange, plus their neighbors, for diagram update.
77 inline void setUseFastPersistenceUpdate(bool UseFastPersistenceUpdate) {
78 useFastPersistenceUpdate_ = UseFastPersistenceUpdate;
79 }
80
// When enabled, getIndices takes its incremental matching path, reusing
// pair2MatchedPair / pair2Delete from previous epochs.
81 inline void setFastAssignmentUpdate(bool FastAssignmentUpdate) {
82 fastAssignmentUpdate_ = FastAssignmentUpdate;
83 }
84
// Presumably the number of optimization epochs — confirm in execute().
85 inline void setEpochNumber(int EpochNumber) {
86 epochNumber_ = EpochNumber;
87 }
88
// In getIndices: 0 -> PersistenceDiagramClustering,
// any other value -> PersistenceDiagramBarycenter.
89 inline void setPDCMethod(int PDCMethod) {
90 pdcMethod_ = PDCMethod;
91 }
92
// 0: direct gradient descent, 1: Adam (see member comments below).
93 inline void setMethodOptimization(int methodOptimization) {
94 methodOptimization_ = methodOptimization;
95 }
96
// 0: let the algorithm choose, 1: fill the domain, 2: cut the domain.
97 inline void setFinePairManagement(int finePairManagement) {
98 finePairManagement_ = finePairManagement;
99 }
100
// NOTE(review): chooseLearningRate_/learningRate_ are presumably the Adam
// settings (the "Adam" member block is missing from this listing) — confirm.
101 inline void setChooseLearningRate(int chooseLearningRate) {
102 chooseLearningRate_ = chooseLearningRate;
103 }
104
105 inline void setLearningRate(double learningRate) {
106 learningRate_ = learningRate;
107 }
108
// Gradient step size for direct gradient descent.
109 inline void setAlpha(double alpha) {
110 alpha_ = alpha;
111 }
112
// Stop when the loss drops below this fraction of the initial loss.
113 inline void setCoefStopCondition(double coefStopCondition) {
114 coefStopCondition_ = coefStopCondition;
115 }
116
// Optimization without matching (OWM): pairs to delete are chosen by
// thresholdMethod_ rather than by matching to the target diagram
// (presumably the guard of the first branch in getIndices — confirm).
117 inline void
118 setOptimizationWithoutMatching(bool optimizationWithoutMatching) {
119 optimizationWithoutMatching_ = optimizationWithoutMatching;
120 }
121
// [OWM] 0: persistence threshold, 1: dimension outside
// [lowerThreshold_, upperThreshold_], 2: dimension == pairTypeToDelete_.
122 inline void setThresholdMethod(int thresholdMethod) {
123 thresholdMethod_ = thresholdMethod;
124 }
125
// [OWM] pairs with persistence strictly below this value are removed.
126 inline void setThresholdPersistence(double thresholdPersistence) {
127 thresholdPersistence_ = thresholdPersistence;
128 }
129
// [OWM] pairs with dim < lowerThreshold_ are removed (thresholdMethod_ 1).
130 inline void setLowerThreshold(int lowerThreshold) {
131 lowerThreshold_ = lowerThreshold;
132 }
133
// [OWM] pairs with dim > upperThreshold_ are removed (thresholdMethod_ 1).
134 inline void setUpperThreshold(int upperThreshold) {
135 upperThreshold_ = upperThreshold;
136 }
137
// [OWM] pairs with dim == pairTypeToDelete_ are removed (thresholdMethod_ 2).
138 inline void setPairTypeToDelete(int pairTypeToDelete) {
139 pairTypeToDelete_ = pairTypeToDelete;
140 }
141
// NOTE(review): constraintAveraging_ is not read in the visible code.
142 inline void setConstraintAveraging(bool ConstraintAveraging) {
143 constraintAveraging_ = ConstraintAveraging;
144 }
145
// NOTE(review): printFrequency_ is not read in the visible code.
146 inline void setPrintFrequency(int printFrequency) {
147 printFrequency_ = printFrequency;
148 }
149
150 protected:
// NOTE(review): the member declarations described below (e.g.
// useFastPersistenceUpdate_, pdcMethod_, thresholdMethod_, vertexNumber_)
// sit on listing lines missing from this scrape; only alpha_ survives.
153
154 // enable the fast update of the persistence diagram
156
157 // enable the fast update of the pair assignments between the target diagram
159
160 // if pdcMethod_ == 0 then we use Progressive approach
161 // if pdcMethod_ == 1 then we use Classical Auction approach
163
164 // if methodOptimization_ == 0 then we use Direct gradient descent
165 // if methodOptimization_ == 1 then we use Adam
167
168 // if finePairManagement_ == 0 then we let the algorithm choose
169 // if finePairManagement_ == 1 then we fill the domain
170 // if finePairManagement_ == 2 then we cut the domain
172
173 // Adam
176
177 // Direct gradient descent
178 // alpha_ : the gradient step size
179 double alpha_;
180
181 // Stopping criterion: when the loss becomes less than a percentage
182 // coefStopCondition_ (e.g. coefStopCondition_ = 0.01 => 1%) of the original
183 // loss (between input diagram and simplified diagram)
185
186 // Optimization without matching (OWM)
// [OWM] In getIndices, both endpoints of a removed pair are targeted at the
// pair's midpoint value (birth.sfValue + death.sfValue) / 2.
188
189 // [OWM] if thresholdMethod_ == 0 : threshold on persistence
190 // [OWM] if thresholdMethod_ == 1 : threshold on pair type
// [OWM] if thresholdMethod_ == 2 : remove pairs whose dimension equals
// pairTypeToDelete_ (handled in getIndices, missing from the list above)
192
193 // [OWM] thresholdPersistence_ : The threshold value on persistence.
195
196 // [OWM] lowerThreshold_ : The lower threshold on pair type
// ("pair type" here is compared against PersistencePair::dim)
198
199 // [OWM] upperThreshold_ : The upper threshold on pair type
201
202 // [OWM] pairTypeToDelete_ : Remove only pairs of type pairTypeToDelete_
204
206
208 };
209
210} // namespace ttk
211
212#ifdef TTK_ENABLE_TORCH
// Small torch::nn::Module wrapping a single tensor X. The constructor
// registers X_tensor as a parameter named "X" with requires_grad = true,
// presumably so a torch optimizer over parameters() can update the scalar
// field being optimized — confirm at the call site.
// NOTE(review): listing line 214 (the remaining base class(es) and the
// opening brace, implied by the trailing comma below) is missing from this
// scrape; restore it from the real header.
// NOTE(review): the single-argument constructor is not marked explicit.
213class PersistenceGradientDescent : public torch::nn::Module,
215public:
216 PersistenceGradientDescent(torch::Tensor X_tensor) : torch::nn::Module() {
217 X = register_parameter("X", X_tensor, true);
218 }
// Handle to the tensor returned by register_parameter("X", ...).
219 torch::Tensor X;
220};
221
222#endif
223
224/*
225 This function allows us to retrieve the indices of the critical points
226 that we must modify in order to match our current diagram to our target
227 diagram.
228*/
229template <typename dataType, typename triangulationType>
231 triangulationType *triangulation,
232 SimplexId *&inputOffsets,
233 dataType *const inputScalars,
234 const ttk::DiagramType &constraintDiagram,
235 int epoch,
236 std::vector<SimplexId> &listAllIndicesToChange,
237 std::vector<std::vector<SimplexId>> &pair2MatchedPair,
238 std::vector<std::vector<SimplexId>> &pair2Delete,
239 std::vector<SimplexId> &pairChangeMatchingPair,
240 std::vector<SimplexId> &birthPairToDeleteCurrentDiagram,
241 std::vector<double> &birthPairToDeleteTargetDiagram,
242 std::vector<SimplexId> &deathPairToDeleteCurrentDiagram,
243 std::vector<double> &deathPairToDeleteTargetDiagram,
244 std::vector<SimplexId> &birthPairToChangeCurrentDiagram,
245 std::vector<double> &birthPairToChangeTargetDiagram,
246 std::vector<SimplexId> &deathPairToChangeCurrentDiagram,
247 std::vector<double> &deathPairToChangeTargetDiagram,
248 std::vector<std::vector<SimplexId>> &currentVertex2PairsCurrentDiagram,
249 std::vector<int> &vertexInHowManyPairs) const {
250
251 //=========================================
252 // Lazy Gradient
253 //=========================================
254
255 bool needUpdateDefaultValue
256 = (useFastPersistenceUpdate_ ? (epoch == 0 || epoch < 0 ? true : false)
257 : true);
258 std::vector<bool> needUpdate(vertexNumber_, needUpdateDefaultValue);
260 /*
261 There is a 10% loss of performance
262 */
263 this->printMsg(
264 "Get Indices | UseFastPersistenceUpdate_", debug::Priority::DETAIL);
265
266 if(not(epoch == 0 || epoch < 0)) {
267#ifdef TTK_ENABLE_OPENMP
268#pragma omp parallel for num_threads(threadNumber_)
269#endif
270 for(size_t index = 0; index < listAllIndicesToChange.size(); index++) {
271 if(listAllIndicesToChange[index] == 1) {
272 needUpdate[index] = true;
273
274 // Find all the neighbors of the vertex
275 int vertexNumber = triangulation->getVertexNeighborNumber(index);
276 for(int i = 0; i < vertexNumber; i++) {
277 SimplexId vertexNeighborId = -1;
278 triangulation->getVertexNeighbor(index, i, vertexNeighborId);
279 needUpdate[vertexNeighborId] = true;
280 }
281 }
282 }
283 }
284 }
285
286 SimplexId count = std::count(needUpdate.begin(), needUpdate.end(), true);
287
288 this->printMsg(
289 "Get Indices | The number of vertices that need to be updated is: "
290 + std::to_string(count),
292
293 //=========================================
294 // Compute the persistence diagram
295 //=========================================
296 ttk::Timer timePersistenceDiagram;
297
299 std::vector<ttk::PersistencePair> diagramOutput;
300 ttk::preconditionOrderArray<dataType>(
301 vertexNumber_, inputScalars, inputOffsets, threadNumber_);
302 diagram.setDebugLevel(debugLevel_);
304 diagram.preconditionTriangulation(triangulation);
305
307 diagram.execute(
308 diagramOutput, inputScalars, 0, inputOffsets, triangulation, &needUpdate);
309 } else {
310 diagram.execute(
311 diagramOutput, inputScalars, epoch, inputOffsets, triangulation);
312 }
313
314 //=====================================
315 // Matching Pairs
316 //=====================================
317
319 for(SimplexId i = 0; i < (SimplexId)diagramOutput.size(); i++) {
320 auto pair = diagramOutput[i];
321 if((thresholdMethod_ == 0)
322 && (pair.persistence() < thresholdPersistence_)) {
323 birthPairToDeleteCurrentDiagram.push_back(
324 static_cast<SimplexId>(pair.birth.id));
325 birthPairToDeleteTargetDiagram.push_back(
326 (pair.birth.sfValue + pair.death.sfValue) / 2);
327 deathPairToDeleteCurrentDiagram.push_back(
328 static_cast<SimplexId>(pair.death.id));
329 deathPairToDeleteTargetDiagram.push_back(
330 (pair.birth.sfValue + pair.death.sfValue) / 2);
331 } else if((thresholdMethod_ == 1)
332 && ((pair.dim < lowerThreshold_)
333 || (pair.dim > upperThreshold_))) {
334 birthPairToDeleteCurrentDiagram.push_back(
335 static_cast<SimplexId>(pair.birth.id));
336 birthPairToDeleteTargetDiagram.push_back(
337 (pair.birth.sfValue + pair.death.sfValue) / 2);
338 deathPairToDeleteCurrentDiagram.push_back(
339 static_cast<SimplexId>(pair.death.id));
340 deathPairToDeleteTargetDiagram.push_back(
341 (pair.birth.sfValue + pair.death.sfValue) / 2);
342 } else if((thresholdMethod_ == 2) && (pair.dim == pairTypeToDelete_)) {
343 birthPairToDeleteCurrentDiagram.push_back(
344 static_cast<SimplexId>(pair.birth.id));
345 birthPairToDeleteTargetDiagram.push_back(
346 (pair.birth.sfValue + pair.death.sfValue) / 2);
347 deathPairToDeleteCurrentDiagram.push_back(
348 static_cast<SimplexId>(pair.death.id));
349 deathPairToDeleteTargetDiagram.push_back(
350 (pair.birth.sfValue + pair.death.sfValue) / 2);
351 }
352 }
353 } else if(fastAssignmentUpdate_) {
354
355 std::vector<std::vector<SimplexId>> vertex2PairsCurrentDiagram(
356 vertexNumber_, std::vector<SimplexId>());
357 for(SimplexId i = 0; i < (SimplexId)diagramOutput.size(); i++) {
358 auto &pair = diagramOutput[i];
359 vertex2PairsCurrentDiagram[pair.birth.id].push_back(i);
360 vertex2PairsCurrentDiagram[pair.death.id].push_back(i);
361 vertexInHowManyPairs[pair.birth.id]++;
362 vertexInHowManyPairs[pair.death.id]++;
363 }
364
365 std::vector<std::vector<SimplexId>> vertex2PairsTargetDiagram(
366 vertexNumber_, std::vector<SimplexId>());
367 for(SimplexId i = 0; i < (SimplexId)constraintDiagram.size(); i++) {
368 auto &pair = constraintDiagram[i];
369 vertex2PairsTargetDiagram[pair.birth.id].push_back(i);
370 vertex2PairsTargetDiagram[pair.death.id].push_back(i);
371 }
372
373 std::vector<std::vector<SimplexId>> matchedPairs;
374 for(SimplexId i = 0; i < (SimplexId)constraintDiagram.size(); i++) {
375 auto &pair = constraintDiagram[i];
376
377 SimplexId birthId = -1;
378 SimplexId deathId = -1;
379
380 if(pairChangeMatchingPair[i] == 1) {
381 birthId = pair2MatchedPair[i][0];
382 deathId = pair2MatchedPair[i][1];
383 } else {
384 birthId = pair.birth.id;
385 deathId = pair.death.id;
386 }
387
388 if(epoch == 0) {
389 for(auto &idPairBirth : vertex2PairsCurrentDiagram[birthId]) {
390 for(auto &idPairDeath : vertex2PairsCurrentDiagram[deathId]) {
391 if(idPairBirth == idPairDeath) {
392 matchedPairs.push_back({i, idPairBirth});
393 }
394 }
395 }
396 } else if((vertex2PairsCurrentDiagram[birthId].size() == 1)
397 && (vertex2PairsCurrentDiagram[deathId].size() == 1)) {
398 if(vertex2PairsCurrentDiagram[birthId][0]
399 == vertex2PairsCurrentDiagram[deathId][0]) {
400 matchedPairs.push_back({i, vertex2PairsCurrentDiagram[deathId][0]});
401 }
402 }
403 }
404
405 std::vector<SimplexId> matchingPairCurrentDiagram(
406 (SimplexId)diagramOutput.size(), -1);
407 std::vector<SimplexId> matchingPairTargetDiagram(
408 (SimplexId)constraintDiagram.size(), -1);
409
410 for(auto &match : matchedPairs) {
411 auto &indicePairTargetDiagram = match[0];
412 auto &indicePairCurrentDiagram = match[1];
413
414 auto &pairCurrentDiagram = diagramOutput[indicePairCurrentDiagram];
415 auto &pairTargetDiagram = constraintDiagram[indicePairTargetDiagram];
416
417 pair2MatchedPair[indicePairTargetDiagram][0]
418 = pairCurrentDiagram.birth.id;
419 pair2MatchedPair[indicePairTargetDiagram][1]
420 = pairCurrentDiagram.death.id;
421
422 matchingPairCurrentDiagram[indicePairCurrentDiagram] = 1;
423 matchingPairTargetDiagram[indicePairTargetDiagram] = 1;
424
425 SimplexId valueBirthPairToChangeCurrentDiagram
426 = (SimplexId)(pairCurrentDiagram.birth.id);
427 SimplexId valueDeathPairToChangeCurrentDiagram
428 = (SimplexId)(pairCurrentDiagram.death.id);
429
430 double valueBirthPairToChangeTargetDiagram
431 = pairTargetDiagram.birth.sfValue;
432 double valueDeathPairToChangeTargetDiagram
433 = pairTargetDiagram.death.sfValue;
434
435 birthPairToChangeCurrentDiagram.push_back(
436 valueBirthPairToChangeCurrentDiagram);
437 birthPairToChangeTargetDiagram.push_back(
438 valueBirthPairToChangeTargetDiagram);
439 deathPairToChangeCurrentDiagram.push_back(
440 valueDeathPairToChangeCurrentDiagram);
441 deathPairToChangeTargetDiagram.push_back(
442 valueDeathPairToChangeTargetDiagram);
443 }
444
445 ttk::DiagramType thresholdCurrentDiagram{};
446 for(SimplexId i = 0; i < (SimplexId)diagramOutput.size(); i++) {
447 auto &pair = diagramOutput[i];
448
449 if((pair2Delete[pair.birth.id].size() == 1)
450 && (pair2Delete[pair.death.id].size() == 1)
451 && (pair2Delete[pair.birth.id] == pair2Delete[pair.death.id])) {
452
453 birthPairToDeleteCurrentDiagram.push_back(
454 static_cast<SimplexId>(pair.birth.id));
455 birthPairToDeleteTargetDiagram.push_back(
456 (pair.birth.sfValue + pair.death.sfValue) / 2);
457 deathPairToDeleteCurrentDiagram.push_back(
458 static_cast<SimplexId>(pair.death.id));
459 deathPairToDeleteTargetDiagram.push_back(
460 (pair.birth.sfValue + pair.death.sfValue) / 2);
461 continue;
462 }
463 if(matchingPairCurrentDiagram[i] == -1) {
464 thresholdCurrentDiagram.push_back(pair);
465 }
466 }
467
468 ttk::DiagramType thresholdConstraintDiagram{};
469 std::vector<SimplexId> pairIndiceLocal2Global{};
470 for(SimplexId i = 0; i < (SimplexId)constraintDiagram.size(); i++) {
471 auto &pair = constraintDiagram[i];
472
473 if(matchingPairTargetDiagram[i] == -1) {
474 thresholdConstraintDiagram.push_back(pair);
475 pairIndiceLocal2Global.push_back(i);
476 }
477 }
478
479 this->printMsg("Get Indices | thresholdCurrentDiagram.size(): "
480 + std::to_string(thresholdCurrentDiagram.size()),
482
483 this->printMsg("Get Indices | thresholdConstraintDiagram.size(): "
484 + std::to_string(thresholdConstraintDiagram.size()),
486
487 if(thresholdConstraintDiagram.size() == 0) {
488 for(SimplexId i = 0; i < (SimplexId)thresholdCurrentDiagram.size(); i++) {
489 auto &pair = thresholdCurrentDiagram[i];
490
492
493 // If the point pair.birth.id is in a signal pair
494 // AND If the point pair.death.id is not in a signal pair
495 // Then we only modify the pair.death.id
496 if((vertex2PairsTargetDiagram[pair.birth.id].size() >= 1)
497 && (vertex2PairsTargetDiagram[pair.death.id].size() == 0)) {
498 deathPairToDeleteCurrentDiagram.push_back(
499 static_cast<SimplexId>(pair.death.id));
500 deathPairToDeleteTargetDiagram.push_back(
501 (pair.birth.sfValue + pair.death.sfValue) / 2);
502 continue;
503 }
504
505 // If the point pair.death.id is in a signal pair
506 // AND If the point pair.birth.id is not in a signal pair
507 // Then we only modify the pair.birth.id
508 if((vertex2PairsTargetDiagram[pair.birth.id].size() == 0)
509 && (vertex2PairsTargetDiagram[pair.death.id].size() >= 1)) {
510 birthPairToDeleteCurrentDiagram.push_back(
511 static_cast<SimplexId>(pair.birth.id));
512 birthPairToDeleteTargetDiagram.push_back(
513 (pair.birth.sfValue + pair.death.sfValue) / 2);
514 continue;
515 }
516
517 // If the point pair.birth.id is in a signal pair
518 // AND If the point pair.death.id is in a signal pair
519 // Then we do not modify either point
520 if((vertex2PairsTargetDiagram[pair.birth.id].size() >= 1)
521 || (vertex2PairsTargetDiagram[pair.death.id].size() >= 1)) {
522 continue;
523 }
524 }
525
526 birthPairToDeleteCurrentDiagram.push_back(
527 static_cast<SimplexId>(pair.birth.id));
528 birthPairToDeleteTargetDiagram.push_back(
529 (pair.birth.sfValue + pair.death.sfValue) / 2);
530 deathPairToDeleteCurrentDiagram.push_back(
531 static_cast<SimplexId>(pair.death.id));
532 deathPairToDeleteTargetDiagram.push_back(
533 (pair.birth.sfValue + pair.death.sfValue) / 2);
534
535 pair2Delete[pair.birth.id].push_back(i);
536 pair2Delete[pair.death.id].push_back(i);
537 }
538 } else {
539
540 ttk::Timer timePersistenceDiagramClustering;
541
542 ttk::PersistenceDiagramClustering persistenceDiagramClustering;
543 PersistenceDiagramBarycenter pdBarycenter{};
544 std::vector<ttk::DiagramType> intermediateDiagrams{
545 thresholdConstraintDiagram, thresholdCurrentDiagram};
546 std::vector<std::vector<std::vector<ttk::MatchingType>>> allMatchings;
547 std::vector<ttk::DiagramType> centroids{};
548
549 if(pdcMethod_ == 0) {
550 persistenceDiagramClustering.setDebugLevel(debugLevel_);
551 persistenceDiagramClustering.setThreadNumber(threadNumber_);
552 // setDeterministic ==> Deterministic algorithm
553 persistenceDiagramClustering.setDeterministic(true);
554 // setUseProgressive ==> Compute Progressive Barycenter
555 persistenceDiagramClustering.setUseProgressive(true);
556 // setUseInterruptible ==> Interruptible algorithm
557 persistenceDiagramClustering.setUseInterruptible(false);
558 // // setTimeLimit ==> Maximal computation time (s)
559 persistenceDiagramClustering.setTimeLimit(0.01);
560 // setUseAdditionalPrecision ==> Force minimum precision on matchings
561 persistenceDiagramClustering.setUseAdditionalPrecision(true);
562 // setDeltaLim ==> Minimal relative precision
563 persistenceDiagramClustering.setDeltaLim(1e-5);
564 // setUseAccelerated ==> Use Accelerated KMeans
565 persistenceDiagramClustering.setUseAccelerated(false);
566 // setUseKmeansppInit ==> KMeanspp Initialization
567 persistenceDiagramClustering.setUseKmeansppInit(false);
568
569 std::vector<int> clusterIds = persistenceDiagramClustering.execute(
570 intermediateDiagrams, centroids, allMatchings);
571 } else {
572
573 centroids.resize(1);
574 const auto wassersteinMetric = std::to_string(2);
575 pdBarycenter.setWasserstein(wassersteinMetric);
576 pdBarycenter.setMethod(2);
577 pdBarycenter.setNumberOfInputs(2);
578 pdBarycenter.setDeterministic(true);
579 pdBarycenter.setUseProgressive(true);
580 pdBarycenter.setDebugLevel(debugLevel_);
581 pdBarycenter.setThreadNumber(threadNumber_);
582 pdBarycenter.setAlpha(1);
583 pdBarycenter.setLambda(1);
584 pdBarycenter.execute(intermediateDiagrams, centroids[0], allMatchings);
585 }
586
587 std::vector<std::vector<SimplexId>> allPairsSelected{};
588 std::vector<std::vector<SimplexId>> matchingsBlockPairs(
589 centroids[0].size());
590
591 for(auto i = 1; i >= 0; --i) {
592 std::vector<ttk::MatchingType> &matching = allMatchings[0][i];
593
594 const auto &diag{intermediateDiagrams[i]};
595
596 for(SimplexId j = 0; j < (SimplexId)matching.size(); j++) {
597
598 const auto &m{matching[j]};
599 const auto &bidderId{std::get<0>(m)};
600 const auto &goodId{std::get<1>(m)};
601
602 if((goodId == -1) | (bidderId == -1)) {
603 continue;
604 }
605
606 if(diag[bidderId].persistence() != 0) {
607 if(i == 1) {
608 matchingsBlockPairs[goodId].push_back(bidderId);
609 } else if(matchingsBlockPairs[goodId].size() > 0) {
610 matchingsBlockPairs[goodId].push_back(bidderId);
611 }
612 allPairsSelected.push_back(
613 {diag[bidderId].birth.id, diag[bidderId].death.id});
614 }
615 }
616 }
617
618 std::vector<ttk::PersistencePair> pairsToErase{};
619
620 std::map<std::vector<SimplexId>, SimplexId> currentToTarget;
621 for(auto &pair : allPairsSelected) {
622 currentToTarget[{pair[0], pair[1]}] = 1;
623 }
624
625 for(auto &pair : intermediateDiagrams[1]) {
626 if(pair.isFinite != 0) {
627 if(!(currentToTarget.count({pair.birth.id, pair.death.id}) > 0)) {
628 pairsToErase.push_back(pair);
629 }
630 }
631 }
632
633 for(auto &pair : pairsToErase) {
634
636
637 // If the point pair.birth.id is in a signal pair
638 // AND If the point pair.death.id is not in a signal pair
639 // Then we only modify the pair.death.id
640 if((vertex2PairsTargetDiagram[pair.birth.id].size() >= 1)
641 && (vertex2PairsTargetDiagram[pair.death.id].size() == 0)) {
642 deathPairToDeleteCurrentDiagram.push_back(
643 static_cast<SimplexId>(pair.death.id));
644 deathPairToDeleteTargetDiagram.push_back(
645 (pair.birth.sfValue + pair.death.sfValue) / 2);
646 continue;
647 }
648
649 // If the point pair.death.id is in a signal pair
650 // AND If the point pair.birth.id is not in a signal pair
651 // Then we only modify the pair.birth.id
652 if((vertex2PairsTargetDiagram[pair.birth.id].size() == 0)
653 && (vertex2PairsTargetDiagram[pair.death.id].size() >= 1)) {
654 birthPairToDeleteCurrentDiagram.push_back(
655 static_cast<SimplexId>(pair.birth.id));
656 birthPairToDeleteTargetDiagram.push_back(
657 (pair.birth.sfValue + pair.death.sfValue) / 2);
658 continue;
659 }
660
661 // If the point pair.birth.id is in a signal pair
662 // AND If the point pair.death.id is in a signal pair
663 // Then we do not modify either point
664 if((vertex2PairsTargetDiagram[pair.birth.id].size() >= 1)
665 || (vertex2PairsTargetDiagram[pair.death.id].size() >= 1)) {
666 continue;
667 }
668 }
669
670 birthPairToDeleteCurrentDiagram.push_back(
671 static_cast<SimplexId>(pair.birth.id));
672 birthPairToDeleteTargetDiagram.push_back(
673 (pair.birth.sfValue + pair.death.sfValue) / 2);
674 deathPairToDeleteCurrentDiagram.push_back(
675 static_cast<SimplexId>(pair.death.id));
676 deathPairToDeleteTargetDiagram.push_back(
677 (pair.birth.sfValue + pair.death.sfValue) / 2);
678 }
679
680 for(const auto &entry : matchingsBlockPairs) {
681 // Delete pairs that have no equivalence
682 if(entry.size() == 1) {
683
685 // If the point thresholdCurrentDiagram[entry[0]].birth.id is in a
686 // signal pair AND If the point
687 // thresholdCurrentDiagram[entry[0]].death.id is not in a signal
688 // pair Then we only modify the
689 // thresholdCurrentDiagram[entry[0]].death.id
690 if((vertex2PairsTargetDiagram[thresholdCurrentDiagram[entry[0]]
691 .birth.id]
692 .size()
693 >= 1)
694 && (vertex2PairsTargetDiagram[thresholdCurrentDiagram[entry[0]]
695 .death.id]
696 .size()
697 == 0)) {
698 deathPairToDeleteCurrentDiagram.push_back(static_cast<SimplexId>(
699 thresholdCurrentDiagram[entry[0]].death.id));
700 deathPairToDeleteTargetDiagram.push_back(
701 (thresholdCurrentDiagram[entry[0]].birth.sfValue
702 + thresholdCurrentDiagram[entry[0]].death.sfValue)
703 / 2);
704 continue;
705 }
706
707 // If the point thresholdCurrentDiagram[entry[0]].death.id is in a
708 // signal pair AND If the point
709 // thresholdCurrentDiagram[entry[0]].birth.id is not in a signal
710 // pair Then we only modify the
711 // thresholdCurrentDiagram[entry[0]].birth.id
712 if((vertex2PairsTargetDiagram[thresholdCurrentDiagram[entry[0]]
713 .birth.id]
714 .size()
715 == 0)
716 && (vertex2PairsTargetDiagram[thresholdCurrentDiagram[entry[0]]
717 .death.id]
718 .size()
719 >= 1)) {
720 birthPairToDeleteCurrentDiagram.push_back(static_cast<SimplexId>(
721 thresholdCurrentDiagram[entry[0]].birth.id));
722 birthPairToDeleteTargetDiagram.push_back(
723 (thresholdCurrentDiagram[entry[0]].birth.sfValue
724 + thresholdCurrentDiagram[entry[0]].death.sfValue)
725 / 2);
726 continue;
727 }
728
729 // If the point thresholdCurrentDiagram[entry[0]].birth.id is in a
730 // signal pair AND If the point
731 // thresholdCurrentDiagram[entry[0]].death.id is in a signal pair
732 // Then we do not modify either point
733 if((vertex2PairsTargetDiagram[thresholdCurrentDiagram[entry[0]]
734 .birth.id]
735 .size()
736 >= 1)
737 || (vertex2PairsTargetDiagram[thresholdCurrentDiagram[entry[0]]
738 .death.id]
739 .size()
740 >= 1)) {
741 continue;
742 }
743 }
744
745 birthPairToDeleteCurrentDiagram.push_back(
746 static_cast<SimplexId>(thresholdCurrentDiagram[entry[0]].birth.id));
747 birthPairToDeleteTargetDiagram.push_back(
748 (thresholdCurrentDiagram[entry[0]].birth.sfValue
749 + thresholdCurrentDiagram[entry[0]].death.sfValue)
750 / 2);
751 deathPairToDeleteCurrentDiagram.push_back(
752 static_cast<SimplexId>(thresholdCurrentDiagram[entry[0]].death.id));
753 deathPairToDeleteTargetDiagram.push_back(
754 (thresholdCurrentDiagram[entry[0]].birth.sfValue
755 + thresholdCurrentDiagram[entry[0]].death.sfValue)
756 / 2);
757 continue;
758 } else if(entry.empty())
759 continue;
760
761 SimplexId valueBirthPairToChangeCurrentDiagram
762 = static_cast<SimplexId>(thresholdCurrentDiagram[entry[0]].birth.id);
763 SimplexId valueDeathPairToChangeCurrentDiagram
764 = static_cast<SimplexId>(thresholdCurrentDiagram[entry[0]].death.id);
765
766 double valueBirthPairToChangeTargetDiagram
767 = thresholdConstraintDiagram[entry[1]].birth.sfValue;
768 double valueDeathPairToChangeTargetDiagram
769 = thresholdConstraintDiagram[entry[1]].death.sfValue;
770
771 pair2MatchedPair[pairIndiceLocal2Global[entry[1]]][0]
772 = thresholdCurrentDiagram[entry[0]].birth.id;
773 pair2MatchedPair[pairIndiceLocal2Global[entry[1]]][1]
774 = thresholdCurrentDiagram[entry[0]].death.id;
775
776 pairChangeMatchingPair[pairIndiceLocal2Global[entry[1]]] = 1;
777
778 birthPairToChangeCurrentDiagram.push_back(
779 valueBirthPairToChangeCurrentDiagram);
780 birthPairToChangeTargetDiagram.push_back(
781 valueBirthPairToChangeTargetDiagram);
782 deathPairToChangeCurrentDiagram.push_back(
783 valueDeathPairToChangeCurrentDiagram);
784 deathPairToChangeTargetDiagram.push_back(
785 valueDeathPairToChangeTargetDiagram);
786 }
787 }
788 }
789 //=====================================//
790 // Basic Matching //
791 //=====================================//
792 else {
793 this->printMsg(
794 "Get Indices | Compute Wasserstein distance: ", debug::Priority::DETAIL);
795
796 if(epoch == 0) {
797 for(SimplexId i = 0; i < (SimplexId)diagramOutput.size(); i++) {
798 auto &pair = diagramOutput[i];
799 currentVertex2PairsCurrentDiagram[pair.birth.id].push_back(i);
800 currentVertex2PairsCurrentDiagram[pair.death.id].push_back(i);
801 }
802 } else {
803 std::vector<std::vector<SimplexId>> newVertex2PairsCurrentDiagram(
804 vertexNumber_, std::vector<SimplexId>());
805
806 for(SimplexId i = 0; i < (SimplexId)diagramOutput.size(); i++) {
807 auto &pair = diagramOutput[i];
808 newVertex2PairsCurrentDiagram[pair.birth.id].push_back(i);
809 newVertex2PairsCurrentDiagram[pair.death.id].push_back(i);
810 }
811
812 currentVertex2PairsCurrentDiagram = newVertex2PairsCurrentDiagram;
813 }
814
815 std::vector<std::vector<SimplexId>> vertex2PairsCurrentDiagram(
816 vertexNumber_, std::vector<SimplexId>());
817 for(SimplexId i = 0; i < (SimplexId)diagramOutput.size(); i++) {
818 auto &pair = diagramOutput[i];
819 vertex2PairsCurrentDiagram[pair.birth.id].push_back(i);
820 vertex2PairsCurrentDiagram[pair.death.id].push_back(i);
821 vertexInHowManyPairs[pair.birth.id]++;
822 vertexInHowManyPairs[pair.death.id]++;
823 }
824
825 std::vector<std::vector<SimplexId>> vertex2PairsTargetDiagram(
826 vertexNumber_, std::vector<SimplexId>());
827 for(SimplexId i = 0; i < (SimplexId)constraintDiagram.size(); i++) {
828 auto &pair = constraintDiagram[i];
829 vertex2PairsTargetDiagram[pair.birth.id].push_back(i);
830 vertex2PairsTargetDiagram[pair.death.id].push_back(i);
831 }
832
833 //=========================================
834 // Compute wasserstein distance
835 //=========================================
836 ttk::Timer timePersistenceDiagramClustering;
837
838 ttk::PersistenceDiagramClustering persistenceDiagramClustering;
839 PersistenceDiagramBarycenter pdBarycenter{};
840 std::vector<ttk::DiagramType> intermediateDiagrams{
841 constraintDiagram, diagramOutput};
842 std::vector<ttk::DiagramType> centroids;
843 std::vector<std::vector<std::vector<ttk::MatchingType>>> allMatchings;
844
845 if(pdcMethod_ == 0) {
846 persistenceDiagramClustering.setDebugLevel(debugLevel_);
847 persistenceDiagramClustering.setThreadNumber(threadNumber_);
848 // SetForceUseOfAlgorithm ==> Force the progressive approch if 2 inputs
849 persistenceDiagramClustering.setForceUseOfAlgorithm(false);
850 // setDeterministic ==> Deterministic algorithm
851 persistenceDiagramClustering.setDeterministic(true);
852 // setUseProgressive ==> Compute Progressive Barycenter
853 persistenceDiagramClustering.setUseProgressive(true);
854 // setUseInterruptible ==> Interruptible algorithm
855 // persistenceDiagramClustering.setUseInterruptible(true);
856 persistenceDiagramClustering.setUseInterruptible(false);
857 // // setTimeLimit ==> Maximal computation time (s)
858 persistenceDiagramClustering.setTimeLimit(0.01);
859 // setUseAdditionalPrecision ==> Force minimum precision on matchings
860 persistenceDiagramClustering.setUseAdditionalPrecision(true);
861 // setDeltaLim ==> Minimal relative precision
862 persistenceDiagramClustering.setDeltaLim(0.00000001);
863 // setUseAccelerated ==> Use Accelerated KMeans
864 persistenceDiagramClustering.setUseAccelerated(false);
865 // setUseKmeansppInit ==> KMeanspp Initialization
866 persistenceDiagramClustering.setUseKmeansppInit(false);
867
868 std::vector<int> clusterIds = persistenceDiagramClustering.execute(
869 intermediateDiagrams, centroids, allMatchings);
870 } else {
871 centroids.resize(1);
872 const auto wassersteinMetric = std::to_string(2);
873 pdBarycenter.setWasserstein(wassersteinMetric);
874 pdBarycenter.setMethod(2);
875 pdBarycenter.setNumberOfInputs(2);
876 pdBarycenter.setDeterministic(true);
877 pdBarycenter.setUseProgressive(true);
878 pdBarycenter.setDebugLevel(debugLevel_);
879 pdBarycenter.setThreadNumber(threadNumber_);
880 pdBarycenter.setAlpha(1);
881 pdBarycenter.setLambda(1);
882 pdBarycenter.execute(intermediateDiagrams, centroids[0], allMatchings);
883 }
884
885 this->printMsg(
886 "Get Indices | Persistence Diagram Clustering Time: "
887 + std::to_string(timePersistenceDiagramClustering.getElapsedTime()),
889
890 //=========================================
891 // Find matched pairs
892 //=========================================
893
894 std::vector<std::vector<SimplexId>> allPairsSelected{};
895 std::vector<std::vector<std::vector<double>>> matchingsBlock(
896 centroids[0].size());
897 std::vector<std::vector<ttk::PersistencePair>> matchingsBlockPairs(
898 centroids[0].size());
899
900 for(auto i = 1; i >= 0; --i) {
901 std::vector<ttk::MatchingType> &matching = allMatchings[0][i];
902
903 const auto &diag{intermediateDiagrams[i]};
904
905 for(SimplexId j = 0; j < (SimplexId)matching.size(); j++) {
906
907 const auto &m{matching[j]};
908 const auto &bidderId{std::get<0>(m)};
909 const auto &goodId{std::get<1>(m)};
910
911 if((goodId == -1) | (bidderId == -1))
912 continue;
913
914 if(diag[bidderId].persistence() != 0) {
915 matchingsBlock[goodId].push_back(
916 {static_cast<double>(diag[bidderId].birth.id),
917 static_cast<double>(diag[bidderId].death.id),
918 diag[bidderId].persistence()});
919 if(i == 1) {
920 matchingsBlockPairs[goodId].push_back(diag[bidderId]);
921 } else if(matchingsBlockPairs[goodId].size() > 0) {
922 matchingsBlockPairs[goodId].push_back(diag[bidderId]);
923 }
924 allPairsSelected.push_back(
925 {diag[bidderId].birth.id, diag[bidderId].death.id});
926 }
927 }
928 }
929
930 std::vector<ttk::PersistencePair> pairsToErase{};
931
932 std::map<std::vector<SimplexId>, SimplexId> currentToTarget;
933 for(auto &pair : allPairsSelected) {
934 currentToTarget[{pair[0], pair[1]}] = 1;
935 }
936
937 for(auto &pair : intermediateDiagrams[1]) {
938 if(pair.isFinite != 0) {
939 if(!(currentToTarget.count({pair.birth.id, pair.death.id}) > 0)) {
940 pairsToErase.push_back(pair);
941 }
942 }
943 }
944
945 for(auto &pair : pairsToErase) {
946
948
949 // If the point pair.birth.id is in a signal pair
950 // AND If the point pair.death.id is not in a signal pair
951 // Then we only modify the pair.death.id
952 if((vertex2PairsTargetDiagram[pair.birth.id].size() >= 1)
953 && (vertex2PairsTargetDiagram[pair.death.id].size() == 0)) {
954 deathPairToDeleteCurrentDiagram.push_back(
955 static_cast<SimplexId>(pair.death.id));
956 deathPairToDeleteTargetDiagram.push_back(
957 (pair.birth.sfValue + pair.death.sfValue) / 2);
958 continue;
959 }
960
961 // If the point pair.death.id is in a signal pair
962 // AND If the point pair.birth.id is not in a signal pair
963 // Then we only modify the pair.birth.id
964 if((vertex2PairsTargetDiagram[pair.birth.id].size() == 0)
965 && (vertex2PairsTargetDiagram[pair.death.id].size() >= 1)) {
966 birthPairToDeleteCurrentDiagram.push_back(
967 static_cast<SimplexId>(pair.birth.id));
968 birthPairToDeleteTargetDiagram.push_back(
969 (pair.birth.sfValue + pair.death.sfValue) / 2);
970 continue;
971 }
972
973 // If the point pair.birth.id is in a signal pair
974 // AND If the point pair.death.id is in a signal pair
975 // Then we do not modify either point
976 if((vertex2PairsTargetDiagram[pair.birth.id].size() >= 1)
977 || (vertex2PairsTargetDiagram[pair.death.id].size() >= 1)) {
978 continue;
979 }
980 }
981
982 birthPairToDeleteCurrentDiagram.push_back(
983 static_cast<SimplexId>(pair.birth.id));
984 birthPairToDeleteTargetDiagram.push_back(
985 (pair.birth.sfValue + pair.death.sfValue) / 2);
986 deathPairToDeleteCurrentDiagram.push_back(
987 static_cast<SimplexId>(pair.death.id));
988 deathPairToDeleteTargetDiagram.push_back(
989 (pair.birth.sfValue + pair.death.sfValue) / 2);
990 }
991
992 for(const auto &entry : matchingsBlockPairs) {
993 // Delete pairs that have no equivalence
994 if(entry.size() == 1) {
995 birthPairToDeleteCurrentDiagram.push_back(
996 static_cast<SimplexId>(entry[0].birth.id));
997 birthPairToDeleteTargetDiagram.push_back(
998 (entry[0].birth.sfValue + entry[0].death.sfValue) / 2);
999 deathPairToDeleteCurrentDiagram.push_back(
1000 static_cast<SimplexId>(entry[0].death.id));
1001 deathPairToDeleteTargetDiagram.push_back(
1002 (entry[0].birth.sfValue + entry[0].death.sfValue) / 2);
1003 continue;
1004 } else if(entry.empty())
1005 continue;
1006
1007 SimplexId valueBirthPairToChangeCurrentDiagram
1008 = static_cast<SimplexId>(entry[0].birth.id);
1009 SimplexId valueDeathPairToChangeCurrentDiagram
1010 = static_cast<SimplexId>(entry[0].death.id);
1011
1012 double valueBirthPairToChangeTargetDiagram = entry[1].birth.sfValue;
1013 double valueDeathPairToChangeTargetDiagram = entry[1].death.sfValue;
1014
1015 birthPairToChangeCurrentDiagram.push_back(
1016 valueBirthPairToChangeCurrentDiagram);
1017 birthPairToChangeTargetDiagram.push_back(
1018 valueBirthPairToChangeTargetDiagram);
1019 deathPairToChangeCurrentDiagram.push_back(
1020 valueDeathPairToChangeCurrentDiagram);
1021 deathPairToChangeTargetDiagram.push_back(
1022 valueDeathPairToChangeTargetDiagram);
1023 }
1024 }
1025}
1026
/*
  Copies the values of a double-precision PyTorch tensor into a flat
  std::vector<double>, in logical (row-major) element order.
*/
#ifdef TTK_ENABLE_TORCH
/**
 * @brief Copy the elements of a CPU double tensor into a std::vector.
 *
 * data_ptr() exposes the tensor's raw storage, whose layout only matches the
 * logical element order when the tensor is contiguous. The copy is therefore
 * taken from tensor.contiguous(), which returns the tensor itself (no copy)
 * when it already is contiguous, so the common case costs nothing extra.
 *
 * @param tensor Source tensor; must have dtype torch::kDouble and reside on
 *               the CPU (a device pointer cannot be read from the host).
 * @param result Overwritten with tensor.numel() values.
 * @return 0 on success; TORCH_CHECK throws on a dtype or device mismatch.
 */
int ttk::TopologicalOptimization::tensorToVectorFast(
  const torch::Tensor &tensor, std::vector<double> &result) const {
  TORCH_CHECK(
    tensor.dtype() == torch::kDouble, "The tensor must be of double type");
  TORCH_CHECK(tensor.device().is_cpu(), "The tensor must be on the CPU");

  // Keep the contiguous tensor alive while its storage is being read.
  const torch::Tensor contiguousTensor = tensor.contiguous();
  const double *dataPtr = contiguousTensor.data_ptr<double>();
  result.assign(dataPtr, dataPtr + contiguousTensor.numel());

  return 0;
}
#endif
1042
// Optimizes the scalar field so that its persistence diagram moves towards
// constraintDiagram.
// - methodOptimization_ == 0 runs a hand-written gradient descent.
// - methodOptimization_ == 1 runs Adam through LibTorch, falling back to the
//   direct gradient descent when Torch is not compiled in.
// - For any other value neither branch runs, so outputScalars stays a copy of
//   inputScalars (with NaN replaced by 0).
// Each iteration, getIndices lists the vertices to move and their target
// values. The squared distance to those targets is the loss being minimized.
// Always returns 0.
1043template <typename dataType, typename triangulationType>
1045 const dataType *const inputScalars,
1046 dataType *const outputScalars,
1047 SimplexId *const inputOffsets,
1048 triangulationType *triangulation,
1049 const ttk::DiagramType &constraintDiagram) const {
1050
1051 Timer t;
1052 double stoppingCondition = 0;
1053 bool enableTorch = true;
1054
 // Adam needs LibTorch. Without it, enableTorch = false routes execution to
 // the direct gradient descent branch below.
1055 if(methodOptimization_ == 1) {
1056#ifndef TTK_ENABLE_TORCH
1057 this->printWrn("Adam unavailable (Torch not found).");
1058 this->printWrn("Using direct gradient descent.");
1059 enableTorch = false;
1060#endif
1061 }
1062
1063 //=======================
1064 // Copy input data
1065 //=======================
 // dataVector is the working copy (as double) optimized by the direct
 // gradient descent branch.
 // NOTE(review): NaN inputs are zeroed in outputScalars only; dataVector
 // keeps the NaN, which can then affect min/max and the normalization.
1066 std::vector<double> dataVector(vertexNumber_);
 // Non-const copy of the pointer: getIndices takes `SimplexId *&`, which
 // cannot bind to the `SimplexId *const` parameter.
1067 SimplexId *inputOffsetsCopie = inputOffsets;
1068
1069#ifdef TTK_ENABLE_OPENMP
1070#pragma omp parallel for num_threads(threadNumber_)
1071#endif
1072 for(SimplexId k = 0; k < vertexNumber_; ++k) {
1073 outputScalars[k] = inputScalars[k];
1074 dataVector[k] = inputScalars[k];
1075 if(std::isnan((double)outputScalars[k]))
1076 outputScalars[k] = 0;
1077 }
1078
1079 //===============================
1080 // Normalize the data
1081 //===============================
1082
 // Min-max rescale the field to [0, 1]. The constraint diagram's birth/death
 // values are rescaled identically so that targets and values share one range.
 // NOTE(review): a constant field makes (maxVal - minVal) zero (inf/NaN
 // results), and vertexNumber_ == 0 makes *min_element dereference end()
 // (UB). Neither case is guarded here.
1083 dataType minVal = *std::min_element(dataVector.begin(), dataVector.end());
1084 dataType maxVal = *std::max_element(dataVector.begin(), dataVector.end());
1085
1086#ifdef TTK_ENABLE_OPENMP
1087#pragma omp parallel for num_threads(threadNumber_)
1088#endif
1089 for(size_t i = 0; i < dataVector.size(); ++i) {
1090 dataVector[i] = (dataVector[i] - minVal) / (maxVal - minVal);
1091 }
1092
1093 ttk::DiagramType normalizedConstraintDiagram(constraintDiagram.size());
1094
1095#ifdef TTK_ENABLE_OPENMP
1096#pragma omp parallel for num_threads(threadNumber_)
1097#endif
1098 for(SimplexId i = 0; i < (SimplexId)constraintDiagram.size(); i++) {
1099 auto pair = constraintDiagram[i];
1100 pair.birth.sfValue = (pair.birth.sfValue - minVal) / (maxVal - minVal);
1101 pair.death.sfValue = (pair.death.sfValue - minVal) / (maxVal - minVal);
1102 normalizedConstraintDiagram[i] = pair;
1103 }
1104
 // losses: per-epoch loss history, only appended to by the Adam branch.
 // inputScalarsX: host-side snapshot of the Adam parameters for getIndices.
1105 std::vector<double> losses;
1106 std::vector<double> inputScalarsX(vertexNumber_);
1107
1108 //========================================
1109 // Direct gradient descent
1110 //========================================
1111 if((methodOptimization_ == 0) || !(enableTorch)) {
 // Bookkeeping buffers handed to getIndices and reused across iterations.
1112 std::vector<SimplexId> listAllIndicesToChangeSmoothing(vertexNumber_, 0);
1113 std::vector<std::vector<SimplexId>> pair2MatchedPair(
1114 constraintDiagram.size(), std::vector<SimplexId>(2));
1115 std::vector<SimplexId> pairChangeMatchingPair(constraintDiagram.size(), -1);
1116 std::vector<std::vector<SimplexId>> pair2Delete(
1117 vertexNumber_, std::vector<SimplexId>());
1118 std::vector<std::vector<SimplexId>> currentVertex2PairsCurrentDiagram(
1119 vertexNumber_, std::vector<SimplexId>());
1120
1121 for(int it = 0; it < epochNumber_; it++) {
1122
 // Only every printFrequency_-th iteration is printed, by toggling
 // debugLevel_.
 // NOTE(review): printFrequency_ == 0 is a modulo by zero, and the user's
 // debug level is overwritten (execute forces it to 3 at the end).
1123 if(it % printFrequency_ == 0) {
1124 debugLevel_ = 3;
1125 } else {
1126 debugLevel_ = 0;
1127 }
1128
1129 this->printMsg("DirectGradientDescent - iteration #" + std::to_string(it),
1131
1132 // pairs to change
1133 std::vector<SimplexId> birthPairToChangeCurrentDiagram{};
1134 std::vector<double> birthPairToChangeTargetDiagram{};
1135 std::vector<SimplexId> deathPairToChangeCurrentDiagram{};
1136 std::vector<double> deathPairToChangeTargetDiagram{};
1137
1138 // pairs to delete
1139 std::vector<SimplexId> birthPairToDeleteCurrentDiagram{};
1140 std::vector<double> birthPairToDeleteTargetDiagram{};
1141 std::vector<SimplexId> deathPairToDeleteCurrentDiagram{};
1142 std::vector<double> deathPairToDeleteTargetDiagram{};
1143
1144 std::vector<int> vertexInHowManyPairs(vertexNumber_, 0);
1145
 // getIndices fills the lists above: vertex ids to move and their
 // (normalized) target values, for pairs to delete and pairs to change.
 // vertexInHowManyPairs presumably counts the pairs each vertex belongs
 // to (confirm in getIndices).
1146 getIndices(
1147 triangulation, inputOffsetsCopie, dataVector.data(),
1148 normalizedConstraintDiagram, it, listAllIndicesToChangeSmoothing,
1149 pair2MatchedPair, pair2Delete, pairChangeMatchingPair,
1150 birthPairToDeleteCurrentDiagram, birthPairToDeleteTargetDiagram,
1151 deathPairToDeleteCurrentDiagram, deathPairToDeleteTargetDiagram,
1152 birthPairToChangeCurrentDiagram, birthPairToChangeTargetDiagram,
1153 deathPairToChangeCurrentDiagram, deathPairToChangeTargetDiagram,
1154 currentVertex2PairsCurrentDiagram, vertexInHowManyPairs);
 // Reset the change mask; it is set to 1 below for every vertex that this
 // iteration updates, and read by getIndices on the next iteration.
1155 std::fill(listAllIndicesToChangeSmoothing.begin(),
1156 listAllIndicesToChangeSmoothing.end(), 0);
1157
1158 //==========================================================================
1159 // Retrieve the indices for the pairs that we want to send diagonally
1160 //==========================================================================
1161 double lossDeletePairs = 0;
1162
1163 std::vector<SimplexId> &indexBirthPairToDelete
1164 = birthPairToDeleteCurrentDiagram;
1165 std::vector<double> &targetValueBirthPairToDelete
1166 = birthPairToDeleteTargetDiagram;
1167 std::vector<SimplexId> &indexDeathPairToDelete
1168 = deathPairToDeleteCurrentDiagram;
1169 std::vector<double> &targetValueDeathPairToDelete
1170 = deathPairToDeleteTargetDiagram;
1171
1172 this->printMsg("DirectGradientDescent - Number of pairs to delete: "
1173 + std::to_string(indexBirthPairToDelete.size()),
1175
 // Vertices whose update is deferred (constraintAveraging_): flag plus the
 // list of targets to be averaged after all pairs have been visited.
1176 std::vector<int> vertexInCellMultiple(vertexNumber_, -1);
1177 std::vector<std::vector<double>> vertexToTargetValue(
1178 vertexNumber_, std::vector<double>());
1179
 // Every update below is one gradient step on (x - target)^2:
 //   x -= alpha_ * 2 * (x - target).
 // finePairManagement_ selects which vertex of a pair moves:
 // 1 = death vertex only, 2 = birth vertex only, anything else = both.
 // With constraintAveraging_, a vertex whose vertexInHowManyPairs != 1 is
 // not stepped here; its target is queued for averaging instead.
1180 if(indexBirthPairToDelete.size() == indexDeathPairToDelete.size()) {
1181 for(size_t i = 0; i < indexBirthPairToDelete.size(); i++) {
1182 lossDeletePairs += std::pow(dataVector[indexBirthPairToDelete[i]]
1183 - targetValueBirthPairToDelete[i],
1184 2)
1185 + std::pow(dataVector[indexDeathPairToDelete[i]]
1186 - targetValueDeathPairToDelete[i],
1187 2);
1188 SimplexId indexMax = indexBirthPairToDelete[i];
1189 SimplexId indexSelle = indexDeathPairToDelete[i];
1190
1191 if(!(finePairManagement_ == 2) && !(finePairManagement_ == 1)) {
1192 if(constraintAveraging_) {
1193 if(vertexInHowManyPairs[indexMax] == 1) {
1194 dataVector[indexMax]
1195 = dataVector[indexMax]
1196 - alpha_ * 2
1197 * (dataVector[indexMax]
1198 - targetValueBirthPairToDelete[i]);
1199 listAllIndicesToChangeSmoothing[indexMax] = 1;
1200 } else {
1201 vertexInCellMultiple[indexMax] = 1;
1202 vertexToTargetValue[indexMax].push_back(
1203 targetValueBirthPairToDelete[i]);
1204 }
1205
1206 if(vertexInHowManyPairs[indexSelle] == 1) {
1207 dataVector[indexSelle]
1208 = dataVector[indexSelle]
1209 - alpha_ * 2
1210 * (dataVector[indexSelle]
1211 - targetValueDeathPairToDelete[i]);
1212 listAllIndicesToChangeSmoothing[indexSelle] = 1;
1213 } else {
1214 vertexInCellMultiple[indexSelle] = 1;
1215 vertexToTargetValue[indexSelle].push_back(
1216 targetValueDeathPairToDelete[i]);
1217 }
1218 } else {
1219 dataVector[indexMax] = dataVector[indexMax]
1220 - alpha_ * 2
1221 * (dataVector[indexMax]
1222 - targetValueBirthPairToDelete[i]);
1223 dataVector[indexSelle]
1224 = dataVector[indexSelle]
1225 - alpha_ * 2
1226 * (dataVector[indexSelle]
1227 - targetValueDeathPairToDelete[i]);
1228 listAllIndicesToChangeSmoothing[indexMax] = 1;
1229 listAllIndicesToChangeSmoothing[indexSelle] = 1;
1230 }
1231 } else if(finePairManagement_ == 1) {
1232 if(constraintAveraging_) {
1233 if(vertexInHowManyPairs[indexSelle] == 1) {
1234 dataVector[indexSelle]
1235 = dataVector[indexSelle]
1236 - alpha_ * 2
1237 * (dataVector[indexSelle]
1238 - targetValueDeathPairToDelete[i]);
1239 listAllIndicesToChangeSmoothing[indexSelle] = 1;
1240 } else {
1241 vertexInCellMultiple[indexSelle] = 1;
1242 vertexToTargetValue[indexSelle].push_back(
1243 targetValueDeathPairToDelete[i]);
1244 }
1245 } else {
1246 dataVector[indexSelle]
1247 = dataVector[indexSelle]
1248 - alpha_ * 2
1249 * (dataVector[indexSelle]
1250 - targetValueDeathPairToDelete[i]);
1251 listAllIndicesToChangeSmoothing[indexSelle] = 1;
1252 }
1253 } else if(finePairManagement_ == 2) {
1254 if(constraintAveraging_) {
1255 if(vertexInHowManyPairs[indexMax] == 1) {
1256 dataVector[indexMax]
1257 = dataVector[indexMax]
1258 - alpha_ * 2
1259 * (dataVector[indexMax]
1260 - targetValueBirthPairToDelete[i]);
1261 listAllIndicesToChangeSmoothing[indexMax] = 1;
1262 } else {
1263 vertexInCellMultiple[indexMax] = 1;
1264 vertexToTargetValue[indexMax].push_back(
1265 targetValueBirthPairToDelete[i]);
1266 }
1267 } else {
1268 dataVector[indexMax] = dataVector[indexMax]
1269 - alpha_ * 2
1270 * (dataVector[indexMax]
1271 - targetValueBirthPairToDelete[i]);
1272 listAllIndicesToChangeSmoothing[indexMax] = 1;
1273 }
1274 }
1275 }
1276 } else {
 // Birth and death lists have different lengths: walk them independently.
 // Loss terms are accumulated for every entry, even for the side that
 // finePairManagement_ leaves untouched.
1277 for(size_t i = 0; i < indexBirthPairToDelete.size(); i++) {
1278 lossDeletePairs += std::pow(dataVector[indexBirthPairToDelete[i]]
1279 - targetValueBirthPairToDelete[i],
1280 2);
1281 SimplexId indexMax = indexBirthPairToDelete[i];
1282
1283 if(!(finePairManagement_ == 1)) {
1284 if(constraintAveraging_) {
1285 if(vertexInHowManyPairs[indexMax] == 1) {
1286 dataVector[indexMax]
1287 = dataVector[indexMax]
1288 - alpha_ * 2
1289 * (dataVector[indexMax]
1290 - targetValueBirthPairToDelete[i]);
1291 listAllIndicesToChangeSmoothing[indexMax] = 1;
1292 } else {
1293 vertexInCellMultiple[indexMax] = 1;
1294 vertexToTargetValue[indexMax].push_back(
1295 targetValueBirthPairToDelete[i]);
1296 }
1297 } else {
1298 dataVector[indexMax] = dataVector[indexMax]
1299 - alpha_ * 2
1300 * (dataVector[indexMax]
1301 - targetValueBirthPairToDelete[i]);
1302 listAllIndicesToChangeSmoothing[indexMax] = 1;
1303 }
1304 } else { // finePairManagement_ == 1
1305 continue;
1306 }
1307 }
1308
1309 for(size_t i = 0; i < indexDeathPairToDelete.size(); i++) {
1310 lossDeletePairs += std::pow(dataVector[indexDeathPairToDelete[i]]
1311 - targetValueDeathPairToDelete[i],
1312 2);
1313 SimplexId indexSelle = indexDeathPairToDelete[i];
1314
1315 if(!(finePairManagement_ == 2)) {
1316 if(constraintAveraging_) {
1317 if(vertexInHowManyPairs[indexSelle] == 1) {
1318 dataVector[indexSelle]
1319 = dataVector[indexSelle]
1320 - alpha_ * 2
1321 * (dataVector[indexSelle]
1322 - targetValueDeathPairToDelete[i]);
1323 listAllIndicesToChangeSmoothing[indexSelle] = 1;
1324 } else {
1325 vertexInCellMultiple[indexSelle] = 1;
1326 vertexToTargetValue[indexSelle].push_back(
1327 targetValueDeathPairToDelete[i]);
1328 }
1329 } else {
1330 dataVector[indexSelle]
1331 = dataVector[indexSelle]
1332 - alpha_ * 2
1333 * (dataVector[indexSelle]
1334 - targetValueDeathPairToDelete[i]);
1335 listAllIndicesToChangeSmoothing[indexSelle] = 1;
1336 }
1337 } else { // finePairManagement_ == 2
1338 continue;
1339 }
1340 }
1341 }
1342
1343 this->printMsg("DirectGradientDescent - Loss Delete Pairs: "
1344 + std::to_string(lossDeletePairs),
1346 //==========================================================================
1347 // Retrieve the indices for the pairs that we want to change
1348 //==========================================================================
1349 double lossChangePairs = 0;
1350
1351 std::vector<SimplexId> &indexBirthPairToChange
1352 = birthPairToChangeCurrentDiagram;
1353 std::vector<double> &targetValueBirthPairToChange
1354 = birthPairToChangeTargetDiagram;
1355 std::vector<SimplexId> &indexDeathPairToChange
1356 = deathPairToChangeCurrentDiagram;
1357 std::vector<double> &targetValueDeathPairToChange
1358 = deathPairToChangeTargetDiagram;
1359
 // Pairs to change: both vertices always move (finePairManagement_ is not
 // consulted here). Birth and death lists are indexed in lockstep.
1360 for(size_t i = 0; i < indexBirthPairToChange.size(); i++) {
1361 lossChangePairs += std::pow(dataVector[indexBirthPairToChange[i]]
1362 - targetValueBirthPairToChange[i],
1363 2)
1364 + std::pow(dataVector[indexDeathPairToChange[i]]
1365 - targetValueDeathPairToChange[i],
1366 2);
1367
1368 SimplexId indexMax = indexBirthPairToChange[i];
1369 SimplexId indexSelle = indexDeathPairToChange[i];
1370
1371 if(constraintAveraging_) {
1372 if(vertexInHowManyPairs[indexMax] == 1) {
1373 dataVector[indexMax]
1374 = dataVector[indexMax]
1375 - alpha_ * 2
1376 * (dataVector[indexMax] - targetValueBirthPairToChange[i]);
1377 listAllIndicesToChangeSmoothing[indexMax] = 1;
1378 } else {
1379 vertexInCellMultiple[indexMax] = 1;
1380 vertexToTargetValue[indexMax].push_back(
1381 targetValueBirthPairToChange[i]);
1382 }
1383
1384 if(vertexInHowManyPairs[indexSelle] == 1) {
1385 dataVector[indexSelle] = dataVector[indexSelle]
1386 - alpha_ * 2
1387 * (dataVector[indexSelle]
1388 - targetValueDeathPairToChange[i]);
1389 listAllIndicesToChangeSmoothing[indexSelle] = 1;
1390 } else {
1391 vertexInCellMultiple[indexSelle] = 1;
1392 vertexToTargetValue[indexSelle].push_back(
1393 targetValueDeathPairToChange[i]);
1394 }
1395 } else {
1396 dataVector[indexMax]
1397 = dataVector[indexMax]
1398 - alpha_ * 2
1399 * (dataVector[indexMax] - targetValueBirthPairToChange[i]);
1400 dataVector[indexSelle]
1401 = dataVector[indexSelle]
1402 - alpha_ * 2
1403 * (dataVector[indexSelle] - targetValueDeathPairToChange[i]);
1404 listAllIndicesToChangeSmoothing[indexMax] = 1;
1405 listAllIndicesToChangeSmoothing[indexSelle] = 1;
1406 }
1407 }
1408
1409 this->printMsg("DirectGradientDescent - Loss Change Pairs: "
1410 + std::to_string(lossChangePairs),
1412
 // Deferred updates: a vertex targeted by several pairs takes one step
 // towards the mean of all the targets collected for it.
1413 if(constraintAveraging_) {
1414 for(SimplexId i = 0; i < (SimplexId)vertexInCellMultiple.size(); i++) {
1415 double averageTargetValue = 0;
1416
1417 if(vertexInCellMultiple[i] == 1) {
1418 for(auto targetValue : vertexToTargetValue[i]) {
1419 averageTargetValue += targetValue;
1420 }
1421 averageTargetValue
1422 = averageTargetValue / (int)vertexToTargetValue[i].size();
1423
1424 dataVector[i] = dataVector[i]
1425 - alpha_ * 2 * (dataVector[i] - averageTargetValue);
1426 listAllIndicesToChangeSmoothing[i] = 1;
1427 }
1428 }
1429 }
1430
1431 //==================================
1432 // Stop Condition
1433 //==================================
1434
 // The threshold is fixed on the first iteration as coefStopCondition_
 // times the initial loss. The losses compared here were measured before
 // this iteration's updates were applied.
1435 if(it == 0) {
1436 stoppingCondition
1437 = coefStopCondition_ * (lossDeletePairs + lossChangePairs);
1438 }
1439
1440 if(((lossDeletePairs + lossChangePairs) <= stoppingCondition))
1441 break;
1442 }
1443
1444//========================================================
1445// De-normalize data & Update output data
1446//========================================================
 // Inverse of the min-max normalization above. Unlike the Adam branch, NaN
 // results are not replaced here.
1447#ifdef TTK_ENABLE_OPENMP
1448#pragma omp parallel for num_threads(threadNumber_)
1449#endif
1450 for(SimplexId k = 0; k < vertexNumber_; ++k) {
1451 outputScalars[k] = dataVector[k] * (maxVal - minVal) + minVal;
1452 }
1453 }
1454
1455//=======================================
1456// Adam Optimization
1457//=======================================
1458#ifdef TTK_ENABLE_TORCH
1459 else if(methodOptimization_ == 1) {
1460 //=====================================================
1461 // Initialization of model parameters
1462 //=====================================================
 // NOTE(review): from_blob does not copy, and .to(kDouble) on a tensor that
 // is already kDouble returns the same tensor, so F aliases dataVector's
 // buffer. Whether PersistenceGradientDescent copies it is not visible here.
1463 torch::Tensor F
1464 = torch::from_blob(dataVector.data(), {SimplexId(dataVector.size())},
1465 torch::dtype(torch::kDouble))
1466 .to(torch::kDouble);
1467 PersistenceGradientDescent model(F);
1468
1469 torch::optim::Adam optimizer(model.parameters(), learningRate_);
1470
1471 //=======================================
1472 // Optimization
1473 //=======================================
1474
1475 std::vector<std::vector<SimplexId>> pair2MatchedPair(
1476 constraintDiagram.size(), std::vector<SimplexId>(2));
1477 std::vector<SimplexId> pairChangeMatchingPair(constraintDiagram.size(), -1);
1478 std::vector<SimplexId> listAllIndicesToChange(vertexNumber_, 0);
1479 std::vector<std::vector<SimplexId>> pair2Delete(
1480 vertexNumber_, std::vector<SimplexId>());
1481 std::vector<std::vector<SimplexId>> currentVertex2PairsCurrentDiagram(
1482 vertexNumber_, std::vector<SimplexId>());
1483
1484 for(int i = 0; i < epochNumber_; i++) {
1485
 // Same print throttling as the direct branch.
 // NOTE(review): printFrequency_ == 0 is a modulo by zero.
1486 if(i % printFrequency_ == 0) {
1487 debugLevel_ = 3;
1488 } else {
1489 debugLevel_ = 0;
1490 }
1491
1492 this->printMsg(
1493 "Adam - epoch: " + std::to_string(i), debug::Priority::PERFORMANCE);
1494
1495 ttk::Timer timeOneIteration;
1496
1497 // Update the tensor with the new optimized values
 // (snapshot of the current parameters into host memory for getIndices)
1498 tensorToVectorFast(model.X.to(torch::kDouble), inputScalarsX);
1499
1500 // pairs to change
1501 std::vector<SimplexId> birthPairToChangeCurrentDiagram{};
1502 std::vector<double> birthPairToChangeTargetDiagram{};
1503 std::vector<SimplexId> deathPairToChangeCurrentDiagram{};
1504 std::vector<double> deathPairToChangeTargetDiagram{};
1505
1506 // pairs to delete
1507 std::vector<SimplexId> birthPairToDeleteCurrentDiagram{};
1508 std::vector<double> birthPairToDeleteTargetDiagram{};
1509 std::vector<SimplexId> deathPairToDeleteCurrentDiagram{};
1510 std::vector<double> deathPairToDeleteTargetDiagram{};
1511
1512 std::vector<int> vertexInHowManyPairs(vertexNumber_, 0);
1513
1514 // Retrieve the indices of the critical points that we must modify in
1515 // order to match our current diagram to our target diagram.
1516 getIndices(
1517 triangulation, inputOffsetsCopie, inputScalarsX.data(),
1518 normalizedConstraintDiagram, i, listAllIndicesToChange,
1519 pair2MatchedPair, pair2Delete, pairChangeMatchingPair,
1520 birthPairToDeleteCurrentDiagram, birthPairToDeleteTargetDiagram,
1521 deathPairToDeleteCurrentDiagram, deathPairToDeleteTargetDiagram,
1522 birthPairToChangeCurrentDiagram, birthPairToChangeTargetDiagram,
1523 deathPairToChangeCurrentDiagram, deathPairToChangeTargetDiagram,
1524 currentVertex2PairsCurrentDiagram, vertexInHowManyPairs);
1525
 // Reset the change mask; it is rebuilt after the optimizer step below.
1526 std::fill(
1527 listAllIndicesToChange.begin(), listAllIndicesToChange.end(), 0);
1528 //==========================================================================
1529 // Retrieve the indices for the pairs that we want to send diagonally
1530 //==========================================================================
1531
 // index_select keeps the gathered values in the autograd graph of model.X.
 // The targets are wrapped with from_blob (no copy), so their vectors must
 // outlive these tensors, which they do within this iteration.
1532 torch::Tensor valueOfXDeleteBirth = torch::index_select(
1533 model.X, 0, torch::tensor(birthPairToDeleteCurrentDiagram));
1534 auto valueDeleteBirth = torch::from_blob(
1535 birthPairToDeleteTargetDiagram.data(),
1536 {static_cast<SimplexId>(birthPairToDeleteTargetDiagram.size())},
1537 torch::kDouble);
1538 torch::Tensor valueOfXDeleteDeath = torch::index_select(
1539 model.X, 0, torch::tensor(deathPairToDeleteCurrentDiagram));
1540 auto valueDeleteDeath = torch::from_blob(
1541 deathPairToDeleteTargetDiagram.data(),
1542 {static_cast<SimplexId>(deathPairToDeleteTargetDiagram.size())},
1543 torch::kDouble);
1544
 // finePairManagement_: 1 = only death terms, 2 = only birth terms,
 // anything else = both (same convention as the direct branch).
1545 torch::Tensor lossDeletePairs = torch::zeros({1}, torch::kDouble);
1546 if(!(finePairManagement_ == 2) && !(finePairManagement_ == 1)) {
1547 lossDeletePairs
1548 = torch::sum(torch::pow(valueOfXDeleteBirth - valueDeleteBirth, 2));
1549 lossDeletePairs
1550 = lossDeletePairs
1551 + torch::sum(torch::pow(valueOfXDeleteDeath - valueDeleteDeath, 2));
1552 } else if(finePairManagement_ == 1) {
1553 lossDeletePairs
1554 = torch::sum(torch::pow(valueOfXDeleteDeath - valueDeleteDeath, 2));
1555 } else if(finePairManagement_ == 2) {
1556 lossDeletePairs
1557 = torch::sum(torch::pow(valueOfXDeleteBirth - valueDeleteBirth, 2));
1558 }
1559
1560 this->printMsg("Adam - Loss Delete Pairs: "
1561 + std::to_string(lossDeletePairs.item<double>()),
1563
1564 //==========================================================================
1565 // Retrieve the indices for the pairs that we want to change
1566 //==========================================================================
1567
1568 torch::Tensor valueOfXChangeBirth = torch::index_select(
1569 model.X, 0, torch::tensor(birthPairToChangeCurrentDiagram));
1570 auto valueChangeBirth = torch::from_blob(
1571 birthPairToChangeTargetDiagram.data(),
1572 {static_cast<SimplexId>(birthPairToChangeTargetDiagram.size())},
1573 torch::kDouble);
1574 torch::Tensor valueOfXChangeDeath = torch::index_select(
1575 model.X, 0, torch::tensor(deathPairToChangeCurrentDiagram));
1576 auto valueChangeDeath = torch::from_blob(
1577 deathPairToChangeTargetDiagram.data(),
1578 {static_cast<SimplexId>(deathPairToChangeTargetDiagram.size())},
1579 torch::kDouble);
1580
1581 auto lossChangePairs
1582 = torch::sum((torch::pow(valueOfXChangeBirth - valueChangeBirth, 2)
1583 + torch::pow(valueOfXChangeDeath - valueChangeDeath, 2)));
1584
1585 this->printMsg("Adam - Loss Change Pairs: "
1586 + std::to_string(lossChangePairs.item<double>()),
1588
1589 //====================================
1590 // Definition of final loss
1591 //====================================
1592
1593 auto loss = lossDeletePairs + lossChangePairs;
1594
1595 this->printMsg("Adam - Loss: " + std::to_string(loss.item<double>()),
1597
1598 //==========================================
1599 // Back Propagation
1600 //==========================================
1601
1602 losses.push_back(loss.item<double>());
1603
1604 ttk::Timer timeBackPropagation;
1605 optimizer.zero_grad();
1606 loss.backward();
1607 optimizer.step();
1608
1609 //==========================================
1610 // Modified index checking
1611 //==========================================
1612
1613 // Find the indices whose values changed
 // (compare the parameters after the step with the snapshot taken at the
 // start of this epoch; getIndices receives this mask next epoch)
1614 std::vector<double> NewinputScalarsX(vertexNumber_);
1615 tensorToVectorFast(model.X.to(torch::kDouble), NewinputScalarsX);
1616
1617#ifdef TTK_ENABLE_OPENMP
1618#pragma omp parallel for num_threads(threadNumber_)
1619#endif
1620 for(SimplexId k = 0; k < vertexNumber_; ++k) {
1621 double diff = NewinputScalarsX[k] - inputScalarsX[k];
1622 if(diff != 0) {
1623 listAllIndicesToChange[k] = 1;
1624 }
1625 }
1626
1627 //=======================================
1628 // Stop condition
1629 //=======================================
 // NOTE(review): this branch stops on a strict `<` while the direct branch
 // uses `<=`. With an initial loss of 0, this loop never stops early.
1630 if(i == 0) {
1631 stoppingCondition = coefStopCondition_ * loss.item<double>();
1632 }
1633
1634 if(loss.item<double>() < stoppingCondition)
1635 break;
1636 }
1637
1638//============================================
1639// De-normalize data & Update output data
1640//============================================
 // NOTE(review): model.X[k].item() builds a tensor view per element inside an
 // OpenMP loop. A single tensorToVectorFast copy would be much cheaper.
1641#ifdef TTK_ENABLE_OPENMP
1642#pragma omp parallel for num_threads(threadNumber_)
1643#endif
1644 for(SimplexId k = 0; k < vertexNumber_; ++k) {
1645 outputScalars[k]
1646 = model.X[k].item().to<double>() * (maxVal - minVal) + minVal;
1647 if(std::isnan((double)outputScalars[k]))
1648 outputScalars[k] = 0;
1649 }
1650 }
1651#endif
1652 //========================================
1653 // Information display
1654 //========================================
 // Restore full verbosity for the summary, whatever the last iteration set.
1655 debugLevel_ = 3;
1656
1657 // Total execution time
1658 double time = t.getElapsedTime();
1659
1660 // Number Pairs Constraint Diagram
1661 SimplexId numberPairsConstraintDiagram = (SimplexId)constraintDiagram.size();
1662 this->printMsg("Number of constrained pairs: "
1663 + std::to_string(numberPairsConstraintDiagram),
1665
1666 this->printMsg("Stopping condition: " + std::to_string(stoppingCondition),
1668
1669 this->printMsg("Scalar field optimized", 1.0, time, this->threadNumber_);
1670
1671 return 0;
1672}
AbstractTriangulation is an interface class that defines an interface for efficient traversal methods...
virtual SimplexId getNumberOfVertices() const
virtual int setThreadNumber(const int threadNumber)
Definition BaseClass.h:80
Minimalist debugging class.
Definition Debug.h:88
int debugLevel_
Definition Debug.h:379
virtual int setDebugLevel(const int &debugLevel)
Definition Debug.cpp:147
TTK processing package for the computation of Wasserstein barycenters and K-Means clusterings of a se...
void setUseInterruptible(bool UseInterruptible_)
std::vector< int > execute(std::vector< DiagramType > &intermediateDiagrams, std::vector< DiagramType > &centroids, std::vector< std::vector< std::vector< MatchingType > > > &all_matchings)
void setForceUseOfAlgorithm(bool forceUseOfAlgorithm)
TTK processing package for the computation of persistence diagrams.
void preconditionTriangulation(AbstractTriangulation *triangulation)
int execute(std::vector< PersistencePair > &CTDiagram, const scalarType *inputScalars, const size_t scalarsMTime, const SimplexId *inputOffsets, const triangulationType *triangulation, const std::vector< bool > *updateMask=nullptr)
double getElapsedTime()
Definition Timer.h:15
void setMethodOptimization(int methodOptimization)
void setLearningRate(double learningRate)
void setFinePairManagement(int finePairManagement)
void setThresholdMethod(int thresholdMethod)
void setChooseLearningRate(int chooseLearningRate)
void setFastAssignmentUpdate(bool FastAssignmentUpdate)
int preconditionTriangulation(AbstractTriangulation *triangulation)
void setLowerThreshold(int lowerThreshold)
void getIndices(triangulationType *triangulation, SimplexId *&inputOffsets, dataType *const inputScalars, const ttk::DiagramType &constraintDiagram, int epoch, std::vector< SimplexId > &listAllIndicesToChange, std::vector< std::vector< SimplexId > > &pair2MatchedPair, std::vector< std::vector< SimplexId > > &pair2Delete, std::vector< SimplexId > &pairChangeMatchingPair, std::vector< SimplexId > &birthPairToDeleteCurrentDiagram, std::vector< double > &birthPairToDeleteTargetDiagram, std::vector< SimplexId > &deathPairToDeleteCurrentDiagram, std::vector< double > &deathPairToDeleteTargetDiagram, std::vector< SimplexId > &birthPairToChangeCurrentDiagram, std::vector< double > &birthPairToChangeTargetDiagram, std::vector< SimplexId > &deathPairToChangeCurrentDiagram, std::vector< double > &deathPairToChangeTargetDiagram, std::vector< std::vector< SimplexId > > &currentVertex2PairsCurrentDiagram, std::vector< int > &vertexInHowManyPairs) const
void setConstraintAveraging(bool ConstraintAveraging)
void setPrintFrequency(int printFrequency)
void setThresholdPersistence(double thresholdPersistence)
void setPairTypeToDelete(int pairTypeToDelete)
void setUpperThreshold(int upperThreshold)
void setOptimizationWithoutMatching(bool optimizationWithoutMatching)
void setCoefStopCondition(double coefStopCondition)
void setUseFastPersistenceUpdate(bool UseFastPersistenceUpdate)
int execute(const dataType *const inputScalars, dataType *const outputScalars, SimplexId *const inputOffsets, triangulationType *triangulation, const ttk::DiagramType &constraintDiagram) const
The Topology ToolKit.
std::vector< PersistencePair > DiagramType
Persistence Diagram type as a vector of Persistence pairs.
int SimplexId
Identifier type for simplices of any dimension.
Definition DataTypes.h:22
printMsg(debug::output::BOLD+" | | | | | . \\ | | (__| | / __/| |_| / __/|__ _|"+debug::output::ENDCOLOR, debug::Priority::PERFORMANCE, debug::LineMode::NEW, stream)