231 triangulationType *triangulation,
233 dataType *
const inputScalars,
236 std::vector<SimplexId> &listAllIndicesToChange,
237 std::vector<std::vector<SimplexId>> &pair2MatchedPair,
238 std::vector<std::vector<SimplexId>> &pair2Delete,
239 std::vector<SimplexId> &pairChangeMatchingPair,
240 std::vector<SimplexId> &birthPairToDeleteCurrentDiagram,
241 std::vector<double> &birthPairToDeleteTargetDiagram,
242 std::vector<SimplexId> &deathPairToDeleteCurrentDiagram,
243 std::vector<double> &deathPairToDeleteTargetDiagram,
244 std::vector<SimplexId> &birthPairToChangeCurrentDiagram,
245 std::vector<double> &birthPairToChangeTargetDiagram,
246 std::vector<SimplexId> &deathPairToChangeCurrentDiagram,
247 std::vector<double> &deathPairToChangeTargetDiagram,
248 std::vector<std::vector<SimplexId>> ¤tVertex2PairsCurrentDiagram,
249 std::vector<int> &vertexInHowManyPairs)
const {
255 bool needUpdateDefaultValue
258 std::vector<bool> needUpdate(
vertexNumber_, needUpdateDefaultValue);
266 if(not(epoch == 0 || epoch < 0)) {
267#ifdef TTK_ENABLE_OPENMP
268#pragma omp parallel for num_threads(threadNumber_)
270 for(
size_t index = 0; index < listAllIndicesToChange.size(); index++) {
271 if(listAllIndicesToChange[index] == 1) {
272 needUpdate[index] =
true;
275 int vertexNumber = triangulation->getVertexNeighborNumber(index);
276 for(
int i = 0; i < vertexNumber; i++) {
278 triangulation->getVertexNeighbor(index, i, vertexNeighborId);
279 needUpdate[vertexNeighborId] =
true;
286 SimplexId count = std::count(needUpdate.begin(), needUpdate.end(),
true);
289 "Get Indices | The number of vertices that need to be updated is: "
290 + std::to_string(count),
299 std::vector<ttk::PersistencePair> diagramOutput;
300 ttk::preconditionOrderArray<dataType>(
308 diagramOutput, inputScalars, 0, inputOffsets, triangulation, &needUpdate);
311 diagramOutput, inputScalars, epoch, inputOffsets, triangulation);
320 auto pair = diagramOutput[i];
323 birthPairToDeleteCurrentDiagram.push_back(
325 birthPairToDeleteTargetDiagram.push_back(
326 (pair.birth.sfValue + pair.death.sfValue) / 2);
327 deathPairToDeleteCurrentDiagram.push_back(
329 deathPairToDeleteTargetDiagram.push_back(
330 (pair.birth.sfValue + pair.death.sfValue) / 2);
334 birthPairToDeleteCurrentDiagram.push_back(
336 birthPairToDeleteTargetDiagram.push_back(
337 (pair.birth.sfValue + pair.death.sfValue) / 2);
338 deathPairToDeleteCurrentDiagram.push_back(
340 deathPairToDeleteTargetDiagram.push_back(
341 (pair.birth.sfValue + pair.death.sfValue) / 2);
343 birthPairToDeleteCurrentDiagram.push_back(
345 birthPairToDeleteTargetDiagram.push_back(
346 (pair.birth.sfValue + pair.death.sfValue) / 2);
347 deathPairToDeleteCurrentDiagram.push_back(
349 deathPairToDeleteTargetDiagram.push_back(
350 (pair.birth.sfValue + pair.death.sfValue) / 2);
355 std::vector<std::vector<SimplexId>> vertex2PairsCurrentDiagram(
358 auto &pair = diagramOutput[i];
359 vertex2PairsCurrentDiagram[pair.birth.id].push_back(i);
360 vertex2PairsCurrentDiagram[pair.death.id].push_back(i);
361 vertexInHowManyPairs[pair.birth.id]++;
362 vertexInHowManyPairs[pair.death.id]++;
365 std::vector<std::vector<SimplexId>> vertex2PairsTargetDiagram(
368 auto &pair = constraintDiagram[i];
369 vertex2PairsTargetDiagram[pair.birth.id].push_back(i);
370 vertex2PairsTargetDiagram[pair.death.id].push_back(i);
373 std::vector<std::vector<SimplexId>> matchedPairs;
375 auto &pair = constraintDiagram[i];
380 if(pairChangeMatchingPair[i] == 1) {
381 birthId = pair2MatchedPair[i][0];
382 deathId = pair2MatchedPair[i][1];
384 birthId = pair.birth.id;
385 deathId = pair.death.id;
389 for(
auto &idPairBirth : vertex2PairsCurrentDiagram[birthId]) {
390 for(
auto &idPairDeath : vertex2PairsCurrentDiagram[deathId]) {
391 if(idPairBirth == idPairDeath) {
392 matchedPairs.push_back({i, idPairBirth});
396 }
else if((vertex2PairsCurrentDiagram[birthId].size() == 1)
397 && (vertex2PairsCurrentDiagram[deathId].size() == 1)) {
398 if(vertex2PairsCurrentDiagram[birthId][0]
399 == vertex2PairsCurrentDiagram[deathId][0]) {
400 matchedPairs.push_back({i, vertex2PairsCurrentDiagram[deathId][0]});
405 std::vector<SimplexId> matchingPairCurrentDiagram(
407 std::vector<SimplexId> matchingPairTargetDiagram(
408 (
SimplexId)constraintDiagram.size(), -1);
410 for(
auto &match : matchedPairs) {
411 auto &indicePairTargetDiagram = match[0];
412 auto &indicePairCurrentDiagram = match[1];
414 auto &pairCurrentDiagram = diagramOutput[indicePairCurrentDiagram];
415 auto &pairTargetDiagram = constraintDiagram[indicePairTargetDiagram];
417 pair2MatchedPair[indicePairTargetDiagram][0]
418 = pairCurrentDiagram.birth.id;
419 pair2MatchedPair[indicePairTargetDiagram][1]
420 = pairCurrentDiagram.death.id;
422 matchingPairCurrentDiagram[indicePairCurrentDiagram] = 1;
423 matchingPairTargetDiagram[indicePairTargetDiagram] = 1;
425 SimplexId valueBirthPairToChangeCurrentDiagram
426 = (
SimplexId)(pairCurrentDiagram.birth.id);
427 SimplexId valueDeathPairToChangeCurrentDiagram
428 = (
SimplexId)(pairCurrentDiagram.death.id);
430 double valueBirthPairToChangeTargetDiagram
431 = pairTargetDiagram.birth.sfValue;
432 double valueDeathPairToChangeTargetDiagram
433 = pairTargetDiagram.death.sfValue;
435 birthPairToChangeCurrentDiagram.push_back(
436 valueBirthPairToChangeCurrentDiagram);
437 birthPairToChangeTargetDiagram.push_back(
438 valueBirthPairToChangeTargetDiagram);
439 deathPairToChangeCurrentDiagram.push_back(
440 valueDeathPairToChangeCurrentDiagram);
441 deathPairToChangeTargetDiagram.push_back(
442 valueDeathPairToChangeTargetDiagram);
447 auto &pair = diagramOutput[i];
449 if((pair2Delete[pair.birth.id].size() == 1)
450 && (pair2Delete[pair.death.id].size() == 1)
451 && (pair2Delete[pair.birth.id] == pair2Delete[pair.death.id])) {
453 birthPairToDeleteCurrentDiagram.push_back(
455 birthPairToDeleteTargetDiagram.push_back(
456 (pair.birth.sfValue + pair.death.sfValue) / 2);
457 deathPairToDeleteCurrentDiagram.push_back(
459 deathPairToDeleteTargetDiagram.push_back(
460 (pair.birth.sfValue + pair.death.sfValue) / 2);
463 if(matchingPairCurrentDiagram[i] == -1) {
464 thresholdCurrentDiagram.push_back(pair);
469 std::vector<SimplexId> pairIndiceLocal2Global{};
471 auto &pair = constraintDiagram[i];
473 if(matchingPairTargetDiagram[i] == -1) {
474 thresholdConstraintDiagram.push_back(pair);
475 pairIndiceLocal2Global.push_back(i);
479 this->
printMsg(
"Get Indices | thresholdCurrentDiagram.size(): "
480 + std::to_string(thresholdCurrentDiagram.size()),
483 this->
printMsg(
"Get Indices | thresholdConstraintDiagram.size(): "
484 + std::to_string(thresholdConstraintDiagram.size()),
487 if(thresholdConstraintDiagram.size() == 0) {
489 auto &pair = thresholdCurrentDiagram[i];
496 if((vertex2PairsTargetDiagram[pair.birth.id].size() >= 1)
497 && (vertex2PairsTargetDiagram[pair.death.id].size() == 0)) {
498 deathPairToDeleteCurrentDiagram.push_back(
500 deathPairToDeleteTargetDiagram.push_back(
501 (pair.birth.sfValue + pair.death.sfValue) / 2);
508 if((vertex2PairsTargetDiagram[pair.birth.id].size() == 0)
509 && (vertex2PairsTargetDiagram[pair.death.id].size() >= 1)) {
510 birthPairToDeleteCurrentDiagram.push_back(
512 birthPairToDeleteTargetDiagram.push_back(
513 (pair.birth.sfValue + pair.death.sfValue) / 2);
520 if((vertex2PairsTargetDiagram[pair.birth.id].size() >= 1)
521 || (vertex2PairsTargetDiagram[pair.death.id].size() >= 1)) {
526 birthPairToDeleteCurrentDiagram.push_back(
528 birthPairToDeleteTargetDiagram.push_back(
529 (pair.birth.sfValue + pair.death.sfValue) / 2);
530 deathPairToDeleteCurrentDiagram.push_back(
532 deathPairToDeleteTargetDiagram.push_back(
533 (pair.birth.sfValue + pair.death.sfValue) / 2);
535 pair2Delete[pair.birth.id].push_back(i);
536 pair2Delete[pair.death.id].push_back(i);
544 std::vector<ttk::DiagramType> intermediateDiagrams{
545 thresholdConstraintDiagram, thresholdCurrentDiagram};
546 std::vector<std::vector<std::vector<ttk::MatchingType>>> allMatchings;
547 std::vector<ttk::DiagramType> centroids{};
569 std::vector<int> clusterIds = persistenceDiagramClustering.
execute(
570 intermediateDiagrams, centroids, allMatchings);
574 const auto wassersteinMetric = std::to_string(2);
575 pdBarycenter.setWasserstein(wassersteinMetric);
576 pdBarycenter.setMethod(2);
577 pdBarycenter.setNumberOfInputs(2);
578 pdBarycenter.setDeterministic(
true);
579 pdBarycenter.setUseProgressive(
true);
582 pdBarycenter.setAlpha(1);
583 pdBarycenter.setLambda(1);
584 pdBarycenter.execute(intermediateDiagrams, centroids[0], allMatchings);
587 std::vector<std::vector<SimplexId>> allPairsSelected{};
588 std::vector<std::vector<SimplexId>> matchingsBlockPairs(
589 centroids[0].size());
591 for(
auto i = 1; i >= 0; --i) {
592 std::vector<ttk::MatchingType> &matching = allMatchings[0][i];
594 const auto &diag{intermediateDiagrams[i]};
598 const auto &m{matching[j]};
599 const auto &bidderId{std::get<0>(m)};
600 const auto &goodId{std::get<1>(m)};
602 if((goodId == -1) | (bidderId == -1)) {
606 if(diag[bidderId].persistence() != 0) {
608 matchingsBlockPairs[goodId].push_back(bidderId);
609 }
else if(matchingsBlockPairs[goodId].size() > 0) {
610 matchingsBlockPairs[goodId].push_back(bidderId);
612 allPairsSelected.push_back(
613 {diag[bidderId].birth.id, diag[bidderId].death.id});
618 std::vector<ttk::PersistencePair> pairsToErase{};
620 std::map<std::vector<SimplexId>,
SimplexId> currentToTarget;
621 for(
auto &pair : allPairsSelected) {
622 currentToTarget[{pair[0], pair[1]}] = 1;
625 for(
auto &pair : intermediateDiagrams[1]) {
626 if(pair.isFinite != 0) {
627 if(!(currentToTarget.count({pair.birth.id, pair.death.id}) > 0)) {
628 pairsToErase.push_back(pair);
633 for(
auto &pair : pairsToErase) {
640 if((vertex2PairsTargetDiagram[pair.birth.id].size() >= 1)
641 && (vertex2PairsTargetDiagram[pair.death.id].size() == 0)) {
642 deathPairToDeleteCurrentDiagram.push_back(
644 deathPairToDeleteTargetDiagram.push_back(
645 (pair.birth.sfValue + pair.death.sfValue) / 2);
652 if((vertex2PairsTargetDiagram[pair.birth.id].size() == 0)
653 && (vertex2PairsTargetDiagram[pair.death.id].size() >= 1)) {
654 birthPairToDeleteCurrentDiagram.push_back(
656 birthPairToDeleteTargetDiagram.push_back(
657 (pair.birth.sfValue + pair.death.sfValue) / 2);
664 if((vertex2PairsTargetDiagram[pair.birth.id].size() >= 1)
665 || (vertex2PairsTargetDiagram[pair.death.id].size() >= 1)) {
670 birthPairToDeleteCurrentDiagram.push_back(
672 birthPairToDeleteTargetDiagram.push_back(
673 (pair.birth.sfValue + pair.death.sfValue) / 2);
674 deathPairToDeleteCurrentDiagram.push_back(
676 deathPairToDeleteTargetDiagram.push_back(
677 (pair.birth.sfValue + pair.death.sfValue) / 2);
680 for(
const auto &entry : matchingsBlockPairs) {
682 if(entry.size() == 1) {
690 if((vertex2PairsTargetDiagram[thresholdCurrentDiagram[entry[0]]
694 && (vertex2PairsTargetDiagram[thresholdCurrentDiagram[entry[0]]
698 deathPairToDeleteCurrentDiagram.push_back(
static_cast<SimplexId>(
699 thresholdCurrentDiagram[entry[0]].death.id));
700 deathPairToDeleteTargetDiagram.push_back(
701 (thresholdCurrentDiagram[entry[0]].birth.sfValue
702 + thresholdCurrentDiagram[entry[0]].death.sfValue)
712 if((vertex2PairsTargetDiagram[thresholdCurrentDiagram[entry[0]]
716 && (vertex2PairsTargetDiagram[thresholdCurrentDiagram[entry[0]]
720 birthPairToDeleteCurrentDiagram.push_back(
static_cast<SimplexId>(
721 thresholdCurrentDiagram[entry[0]].birth.id));
722 birthPairToDeleteTargetDiagram.push_back(
723 (thresholdCurrentDiagram[entry[0]].birth.sfValue
724 + thresholdCurrentDiagram[entry[0]].death.sfValue)
733 if((vertex2PairsTargetDiagram[thresholdCurrentDiagram[entry[0]]
737 || (vertex2PairsTargetDiagram[thresholdCurrentDiagram[entry[0]]
745 birthPairToDeleteCurrentDiagram.push_back(
746 static_cast<SimplexId>(thresholdCurrentDiagram[entry[0]].birth.id));
747 birthPairToDeleteTargetDiagram.push_back(
748 (thresholdCurrentDiagram[entry[0]].birth.sfValue
749 + thresholdCurrentDiagram[entry[0]].death.sfValue)
751 deathPairToDeleteCurrentDiagram.push_back(
752 static_cast<SimplexId>(thresholdCurrentDiagram[entry[0]].death.id));
753 deathPairToDeleteTargetDiagram.push_back(
754 (thresholdCurrentDiagram[entry[0]].birth.sfValue
755 + thresholdCurrentDiagram[entry[0]].death.sfValue)
758 }
else if(entry.empty())
761 SimplexId valueBirthPairToChangeCurrentDiagram
762 =
static_cast<SimplexId>(thresholdCurrentDiagram[entry[0]].birth.id);
763 SimplexId valueDeathPairToChangeCurrentDiagram
764 =
static_cast<SimplexId>(thresholdCurrentDiagram[entry[0]].death.id);
766 double valueBirthPairToChangeTargetDiagram
767 = thresholdConstraintDiagram[entry[1]].birth.sfValue;
768 double valueDeathPairToChangeTargetDiagram
769 = thresholdConstraintDiagram[entry[1]].death.sfValue;
771 pair2MatchedPair[pairIndiceLocal2Global[entry[1]]][0]
772 = thresholdCurrentDiagram[entry[0]].birth.id;
773 pair2MatchedPair[pairIndiceLocal2Global[entry[1]]][1]
774 = thresholdCurrentDiagram[entry[0]].death.id;
776 pairChangeMatchingPair[pairIndiceLocal2Global[entry[1]]] = 1;
778 birthPairToChangeCurrentDiagram.push_back(
779 valueBirthPairToChangeCurrentDiagram);
780 birthPairToChangeTargetDiagram.push_back(
781 valueBirthPairToChangeTargetDiagram);
782 deathPairToChangeCurrentDiagram.push_back(
783 valueDeathPairToChangeCurrentDiagram);
784 deathPairToChangeTargetDiagram.push_back(
785 valueDeathPairToChangeTargetDiagram);
798 auto &pair = diagramOutput[i];
799 currentVertex2PairsCurrentDiagram[pair.birth.id].push_back(i);
800 currentVertex2PairsCurrentDiagram[pair.death.id].push_back(i);
803 std::vector<std::vector<SimplexId>> newVertex2PairsCurrentDiagram(
807 auto &pair = diagramOutput[i];
808 newVertex2PairsCurrentDiagram[pair.birth.id].push_back(i);
809 newVertex2PairsCurrentDiagram[pair.death.id].push_back(i);
812 currentVertex2PairsCurrentDiagram = newVertex2PairsCurrentDiagram;
815 std::vector<std::vector<SimplexId>> vertex2PairsCurrentDiagram(
818 auto &pair = diagramOutput[i];
819 vertex2PairsCurrentDiagram[pair.birth.id].push_back(i);
820 vertex2PairsCurrentDiagram[pair.death.id].push_back(i);
821 vertexInHowManyPairs[pair.birth.id]++;
822 vertexInHowManyPairs[pair.death.id]++;
825 std::vector<std::vector<SimplexId>> vertex2PairsTargetDiagram(
828 auto &pair = constraintDiagram[i];
829 vertex2PairsTargetDiagram[pair.birth.id].push_back(i);
830 vertex2PairsTargetDiagram[pair.death.id].push_back(i);
840 std::vector<ttk::DiagramType> intermediateDiagrams{
841 constraintDiagram, diagramOutput};
842 std::vector<ttk::DiagramType> centroids;
843 std::vector<std::vector<std::vector<ttk::MatchingType>>> allMatchings;
862 persistenceDiagramClustering.
setDeltaLim(0.00000001);
868 std::vector<int> clusterIds = persistenceDiagramClustering.
execute(
869 intermediateDiagrams, centroids, allMatchings);
872 const auto wassersteinMetric = std::to_string(2);
873 pdBarycenter.setWasserstein(wassersteinMetric);
874 pdBarycenter.setMethod(2);
875 pdBarycenter.setNumberOfInputs(2);
876 pdBarycenter.setDeterministic(
true);
877 pdBarycenter.setUseProgressive(
true);
880 pdBarycenter.setAlpha(1);
881 pdBarycenter.setLambda(1);
882 pdBarycenter.execute(intermediateDiagrams, centroids[0], allMatchings);
886 "Get Indices | Persistence Diagram Clustering Time: "
887 + std::to_string(timePersistenceDiagramClustering.
getElapsedTime()),
894 std::vector<std::vector<SimplexId>> allPairsSelected{};
895 std::vector<std::vector<std::vector<double>>> matchingsBlock(
896 centroids[0].size());
897 std::vector<std::vector<ttk::PersistencePair>> matchingsBlockPairs(
898 centroids[0].size());
900 for(
auto i = 1; i >= 0; --i) {
901 std::vector<ttk::MatchingType> &matching = allMatchings[0][i];
903 const auto &diag{intermediateDiagrams[i]};
907 const auto &m{matching[j]};
908 const auto &bidderId{std::get<0>(m)};
909 const auto &goodId{std::get<1>(m)};
911 if((goodId == -1) | (bidderId == -1))
914 if(diag[bidderId].persistence() != 0) {
915 matchingsBlock[goodId].push_back(
916 {
static_cast<double>(diag[bidderId].birth.id),
917 static_cast<double>(diag[bidderId].death.id),
918 diag[bidderId].persistence()});
920 matchingsBlockPairs[goodId].push_back(diag[bidderId]);
921 }
else if(matchingsBlockPairs[goodId].size() > 0) {
922 matchingsBlockPairs[goodId].push_back(diag[bidderId]);
924 allPairsSelected.push_back(
925 {diag[bidderId].birth.id, diag[bidderId].death.id});
930 std::vector<ttk::PersistencePair> pairsToErase{};
932 std::map<std::vector<SimplexId>,
SimplexId> currentToTarget;
933 for(
auto &pair : allPairsSelected) {
934 currentToTarget[{pair[0], pair[1]}] = 1;
937 for(
auto &pair : intermediateDiagrams[1]) {
938 if(pair.isFinite != 0) {
939 if(!(currentToTarget.count({pair.birth.id, pair.death.id}) > 0)) {
940 pairsToErase.push_back(pair);
945 for(
auto &pair : pairsToErase) {
952 if((vertex2PairsTargetDiagram[pair.birth.id].size() >= 1)
953 && (vertex2PairsTargetDiagram[pair.death.id].size() == 0)) {
954 deathPairToDeleteCurrentDiagram.push_back(
956 deathPairToDeleteTargetDiagram.push_back(
957 (pair.birth.sfValue + pair.death.sfValue) / 2);
964 if((vertex2PairsTargetDiagram[pair.birth.id].size() == 0)
965 && (vertex2PairsTargetDiagram[pair.death.id].size() >= 1)) {
966 birthPairToDeleteCurrentDiagram.push_back(
968 birthPairToDeleteTargetDiagram.push_back(
969 (pair.birth.sfValue + pair.death.sfValue) / 2);
976 if((vertex2PairsTargetDiagram[pair.birth.id].size() >= 1)
977 || (vertex2PairsTargetDiagram[pair.death.id].size() >= 1)) {
982 birthPairToDeleteCurrentDiagram.push_back(
984 birthPairToDeleteTargetDiagram.push_back(
985 (pair.birth.sfValue + pair.death.sfValue) / 2);
986 deathPairToDeleteCurrentDiagram.push_back(
988 deathPairToDeleteTargetDiagram.push_back(
989 (pair.birth.sfValue + pair.death.sfValue) / 2);
992 for(
const auto &entry : matchingsBlockPairs) {
994 if(entry.size() == 1) {
995 birthPairToDeleteCurrentDiagram.push_back(
996 static_cast<SimplexId>(entry[0].birth.id));
997 birthPairToDeleteTargetDiagram.push_back(
998 (entry[0].birth.sfValue + entry[0].death.sfValue) / 2);
999 deathPairToDeleteCurrentDiagram.push_back(
1000 static_cast<SimplexId>(entry[0].death.id));
1001 deathPairToDeleteTargetDiagram.push_back(
1002 (entry[0].birth.sfValue + entry[0].death.sfValue) / 2);
1004 }
else if(entry.empty())
1007 SimplexId valueBirthPairToChangeCurrentDiagram
1008 =
static_cast<SimplexId>(entry[0].birth.id);
1009 SimplexId valueDeathPairToChangeCurrentDiagram
1010 =
static_cast<SimplexId>(entry[0].death.id);
1012 double valueBirthPairToChangeTargetDiagram = entry[1].birth.sfValue;
1013 double valueDeathPairToChangeTargetDiagram = entry[1].death.sfValue;
1015 birthPairToChangeCurrentDiagram.push_back(
1016 valueBirthPairToChangeCurrentDiagram);
1017 birthPairToChangeTargetDiagram.push_back(
1018 valueBirthPairToChangeTargetDiagram);
1019 deathPairToChangeCurrentDiagram.push_back(
1020 valueDeathPairToChangeCurrentDiagram);
1021 deathPairToChangeTargetDiagram.push_back(
1022 valueDeathPairToChangeTargetDiagram);
1045 const dataType *
const inputScalars,
1046 dataType *
const outputScalars,
1048 triangulationType *triangulation,
1052 double stoppingCondition = 0;
1053 bool enableTorch =
true;
1055 if(methodOptimization_ == 1) {
1056#ifndef TTK_ENABLE_TORCH
1057 this->printWrn(
"Adam unavailable (Torch not found).");
1058 this->printWrn(
"Using direct gradient descent.");
1059 enableTorch =
false;
1066 std::vector<double> dataVector(vertexNumber_);
1067 SimplexId *inputOffsetsCopie = inputOffsets;
1069#ifdef TTK_ENABLE_OPENMP
1070#pragma omp parallel for num_threads(threadNumber_)
1072 for(
SimplexId k = 0; k < vertexNumber_; ++k) {
1073 outputScalars[k] = inputScalars[k];
1074 dataVector[k] = inputScalars[k];
1075 if(std::isnan((
double)outputScalars[k]))
1076 outputScalars[k] = 0;
1083 dataType minVal = *std::min_element(dataVector.begin(), dataVector.end());
1084 dataType maxVal = *std::max_element(dataVector.begin(), dataVector.end());
1086#ifdef TTK_ENABLE_OPENMP
1087#pragma omp parallel for num_threads(threadNumber_)
1089 for(
size_t i = 0; i < dataVector.size(); ++i) {
1090 dataVector[i] = (dataVector[i] - minVal) / (maxVal - minVal);
1095#ifdef TTK_ENABLE_OPENMP
1096#pragma omp parallel for num_threads(threadNumber_)
1099 auto pair = constraintDiagram[i];
1100 pair.birth.sfValue = (pair.birth.sfValue - minVal) / (maxVal - minVal);
1101 pair.death.sfValue = (pair.death.sfValue - minVal) / (maxVal - minVal);
1102 normalizedConstraintDiagram[i] = pair;
1105 std::vector<double> losses;
1106 std::vector<double> inputScalarsX(vertexNumber_);
1111 if((methodOptimization_ == 0) || !(enableTorch)) {
1112 std::vector<SimplexId> listAllIndicesToChangeSmoothing(vertexNumber_, 0);
1113 std::vector<std::vector<SimplexId>> pair2MatchedPair(
1114 constraintDiagram.size(), std::vector<SimplexId>(2));
1115 std::vector<SimplexId> pairChangeMatchingPair(constraintDiagram.size(), -1);
1116 std::vector<std::vector<SimplexId>> pair2Delete(
1117 vertexNumber_, std::vector<SimplexId>());
1118 std::vector<std::vector<SimplexId>> currentVertex2PairsCurrentDiagram(
1119 vertexNumber_, std::vector<SimplexId>());
1121 for(
int it = 0; it < epochNumber_; it++) {
1123 if(it % printFrequency_ == 0) {
1129 this->
printMsg(
"DirectGradientDescent - iteration #" + std::to_string(it),
1133 std::vector<SimplexId> birthPairToChangeCurrentDiagram{};
1134 std::vector<double> birthPairToChangeTargetDiagram{};
1135 std::vector<SimplexId> deathPairToChangeCurrentDiagram{};
1136 std::vector<double> deathPairToChangeTargetDiagram{};
1139 std::vector<SimplexId> birthPairToDeleteCurrentDiagram{};
1140 std::vector<double> birthPairToDeleteTargetDiagram{};
1141 std::vector<SimplexId> deathPairToDeleteCurrentDiagram{};
1142 std::vector<double> deathPairToDeleteTargetDiagram{};
1144 std::vector<int> vertexInHowManyPairs(vertexNumber_, 0);
1147 triangulation, inputOffsetsCopie, dataVector.data(),
1148 normalizedConstraintDiagram, it, listAllIndicesToChangeSmoothing,
1149 pair2MatchedPair, pair2Delete, pairChangeMatchingPair,
1150 birthPairToDeleteCurrentDiagram, birthPairToDeleteTargetDiagram,
1151 deathPairToDeleteCurrentDiagram, deathPairToDeleteTargetDiagram,
1152 birthPairToChangeCurrentDiagram, birthPairToChangeTargetDiagram,
1153 deathPairToChangeCurrentDiagram, deathPairToChangeTargetDiagram,
1154 currentVertex2PairsCurrentDiagram, vertexInHowManyPairs);
1155 std::fill(listAllIndicesToChangeSmoothing.begin(),
1156 listAllIndicesToChangeSmoothing.end(), 0);
1161 double lossDeletePairs = 0;
1163 std::vector<SimplexId> &indexBirthPairToDelete
1164 = birthPairToDeleteCurrentDiagram;
1165 std::vector<double> &targetValueBirthPairToDelete
1166 = birthPairToDeleteTargetDiagram;
1167 std::vector<SimplexId> &indexDeathPairToDelete
1168 = deathPairToDeleteCurrentDiagram;
1169 std::vector<double> &targetValueDeathPairToDelete
1170 = deathPairToDeleteTargetDiagram;
1172 this->
printMsg(
"DirectGradientDescent - Number of pairs to delete: "
1173 + std::to_string(indexBirthPairToDelete.size()),
1176 std::vector<int> vertexInCellMultiple(vertexNumber_, -1);
1177 std::vector<std::vector<double>> vertexToTargetValue(
1178 vertexNumber_, std::vector<double>());
1180 if(indexBirthPairToDelete.size() == indexDeathPairToDelete.size()) {
1181 for(
size_t i = 0; i < indexBirthPairToDelete.size(); i++) {
1182 lossDeletePairs += std::pow(dataVector[indexBirthPairToDelete[i]]
1183 - targetValueBirthPairToDelete[i],
1185 + std::pow(dataVector[indexDeathPairToDelete[i]]
1186 - targetValueDeathPairToDelete[i],
1188 SimplexId indexMax = indexBirthPairToDelete[i];
1189 SimplexId indexSelle = indexDeathPairToDelete[i];
1191 if(!(finePairManagement_ == 2) && !(finePairManagement_ == 1)) {
1192 if(constraintAveraging_) {
1193 if(vertexInHowManyPairs[indexMax] == 1) {
1194 dataVector[indexMax]
1195 = dataVector[indexMax]
1197 * (dataVector[indexMax]
1198 - targetValueBirthPairToDelete[i]);
1199 listAllIndicesToChangeSmoothing[indexMax] = 1;
1201 vertexInCellMultiple[indexMax] = 1;
1202 vertexToTargetValue[indexMax].push_back(
1203 targetValueBirthPairToDelete[i]);
1206 if(vertexInHowManyPairs[indexSelle] == 1) {
1207 dataVector[indexSelle]
1208 = dataVector[indexSelle]
1210 * (dataVector[indexSelle]
1211 - targetValueDeathPairToDelete[i]);
1212 listAllIndicesToChangeSmoothing[indexSelle] = 1;
1214 vertexInCellMultiple[indexSelle] = 1;
1215 vertexToTargetValue[indexSelle].push_back(
1216 targetValueDeathPairToDelete[i]);
1219 dataVector[indexMax] = dataVector[indexMax]
1221 * (dataVector[indexMax]
1222 - targetValueBirthPairToDelete[i]);
1223 dataVector[indexSelle]
1224 = dataVector[indexSelle]
1226 * (dataVector[indexSelle]
1227 - targetValueDeathPairToDelete[i]);
1228 listAllIndicesToChangeSmoothing[indexMax] = 1;
1229 listAllIndicesToChangeSmoothing[indexSelle] = 1;
1231 }
else if(finePairManagement_ == 1) {
1232 if(constraintAveraging_) {
1233 if(vertexInHowManyPairs[indexSelle] == 1) {
1234 dataVector[indexSelle]
1235 = dataVector[indexSelle]
1237 * (dataVector[indexSelle]
1238 - targetValueDeathPairToDelete[i]);
1239 listAllIndicesToChangeSmoothing[indexSelle] = 1;
1241 vertexInCellMultiple[indexSelle] = 1;
1242 vertexToTargetValue[indexSelle].push_back(
1243 targetValueDeathPairToDelete[i]);
1246 dataVector[indexSelle]
1247 = dataVector[indexSelle]
1249 * (dataVector[indexSelle]
1250 - targetValueDeathPairToDelete[i]);
1251 listAllIndicesToChangeSmoothing[indexSelle] = 1;
1253 }
else if(finePairManagement_ == 2) {
1254 if(constraintAveraging_) {
1255 if(vertexInHowManyPairs[indexMax] == 1) {
1256 dataVector[indexMax]
1257 = dataVector[indexMax]
1259 * (dataVector[indexMax]
1260 - targetValueBirthPairToDelete[i]);
1261 listAllIndicesToChangeSmoothing[indexMax] = 1;
1263 vertexInCellMultiple[indexMax] = 1;
1264 vertexToTargetValue[indexMax].push_back(
1265 targetValueBirthPairToDelete[i]);
1268 dataVector[indexMax] = dataVector[indexMax]
1270 * (dataVector[indexMax]
1271 - targetValueBirthPairToDelete[i]);
1272 listAllIndicesToChangeSmoothing[indexMax] = 1;
1277 for(
size_t i = 0; i < indexBirthPairToDelete.size(); i++) {
1278 lossDeletePairs += std::pow(dataVector[indexBirthPairToDelete[i]]
1279 - targetValueBirthPairToDelete[i],
1281 SimplexId indexMax = indexBirthPairToDelete[i];
1283 if(!(finePairManagement_ == 1)) {
1284 if(constraintAveraging_) {
1285 if(vertexInHowManyPairs[indexMax] == 1) {
1286 dataVector[indexMax]
1287 = dataVector[indexMax]
1289 * (dataVector[indexMax]
1290 - targetValueBirthPairToDelete[i]);
1291 listAllIndicesToChangeSmoothing[indexMax] = 1;
1293 vertexInCellMultiple[indexMax] = 1;
1294 vertexToTargetValue[indexMax].push_back(
1295 targetValueBirthPairToDelete[i]);
1298 dataVector[indexMax] = dataVector[indexMax]
1300 * (dataVector[indexMax]
1301 - targetValueBirthPairToDelete[i]);
1302 listAllIndicesToChangeSmoothing[indexMax] = 1;
1309 for(
size_t i = 0; i < indexDeathPairToDelete.size(); i++) {
1310 lossDeletePairs += std::pow(dataVector[indexDeathPairToDelete[i]]
1311 - targetValueDeathPairToDelete[i],
1313 SimplexId indexSelle = indexDeathPairToDelete[i];
1315 if(!(finePairManagement_ == 2)) {
1316 if(constraintAveraging_) {
1317 if(vertexInHowManyPairs[indexSelle] == 1) {
1318 dataVector[indexSelle]
1319 = dataVector[indexSelle]
1321 * (dataVector[indexSelle]
1322 - targetValueDeathPairToDelete[i]);
1323 listAllIndicesToChangeSmoothing[indexSelle] = 1;
1325 vertexInCellMultiple[indexSelle] = 1;
1326 vertexToTargetValue[indexSelle].push_back(
1327 targetValueDeathPairToDelete[i]);
1330 dataVector[indexSelle]
1331 = dataVector[indexSelle]
1333 * (dataVector[indexSelle]
1334 - targetValueDeathPairToDelete[i]);
1335 listAllIndicesToChangeSmoothing[indexSelle] = 1;
1343 this->
printMsg(
"DirectGradientDescent - Loss Delete Pairs: "
1344 + std::to_string(lossDeletePairs),
1349 double lossChangePairs = 0;
1351 std::vector<SimplexId> &indexBirthPairToChange
1352 = birthPairToChangeCurrentDiagram;
1353 std::vector<double> &targetValueBirthPairToChange
1354 = birthPairToChangeTargetDiagram;
1355 std::vector<SimplexId> &indexDeathPairToChange
1356 = deathPairToChangeCurrentDiagram;
1357 std::vector<double> &targetValueDeathPairToChange
1358 = deathPairToChangeTargetDiagram;
1360 for(
size_t i = 0; i < indexBirthPairToChange.size(); i++) {
1361 lossChangePairs += std::pow(dataVector[indexBirthPairToChange[i]]
1362 - targetValueBirthPairToChange[i],
1364 + std::pow(dataVector[indexDeathPairToChange[i]]
1365 - targetValueDeathPairToChange[i],
1368 SimplexId indexMax = indexBirthPairToChange[i];
1369 SimplexId indexSelle = indexDeathPairToChange[i];
1371 if(constraintAveraging_) {
1372 if(vertexInHowManyPairs[indexMax] == 1) {
1373 dataVector[indexMax]
1374 = dataVector[indexMax]
1376 * (dataVector[indexMax] - targetValueBirthPairToChange[i]);
1377 listAllIndicesToChangeSmoothing[indexMax] = 1;
1379 vertexInCellMultiple[indexMax] = 1;
1380 vertexToTargetValue[indexMax].push_back(
1381 targetValueBirthPairToChange[i]);
1384 if(vertexInHowManyPairs[indexSelle] == 1) {
1385 dataVector[indexSelle] = dataVector[indexSelle]
1387 * (dataVector[indexSelle]
1388 - targetValueDeathPairToChange[i]);
1389 listAllIndicesToChangeSmoothing[indexSelle] = 1;
1391 vertexInCellMultiple[indexSelle] = 1;
1392 vertexToTargetValue[indexSelle].push_back(
1393 targetValueDeathPairToChange[i]);
1396 dataVector[indexMax]
1397 = dataVector[indexMax]
1399 * (dataVector[indexMax] - targetValueBirthPairToChange[i]);
1400 dataVector[indexSelle]
1401 = dataVector[indexSelle]
1403 * (dataVector[indexSelle] - targetValueDeathPairToChange[i]);
1404 listAllIndicesToChangeSmoothing[indexMax] = 1;
1405 listAllIndicesToChangeSmoothing[indexSelle] = 1;
1409 this->
printMsg(
"DirectGradientDescent - Loss Change Pairs: "
1410 + std::to_string(lossChangePairs),
1413 if(constraintAveraging_) {
1415 double averageTargetValue = 0;
1417 if(vertexInCellMultiple[i] == 1) {
1418 for(
auto targetValue : vertexToTargetValue[i]) {
1419 averageTargetValue += targetValue;
1422 = averageTargetValue / (int)vertexToTargetValue[i].size();
1424 dataVector[i] = dataVector[i]
1425 - alpha_ * 2 * (dataVector[i] - averageTargetValue);
1426 listAllIndicesToChangeSmoothing[i] = 1;
1437 = coefStopCondition_ * (lossDeletePairs + lossChangePairs);
1440 if(((lossDeletePairs + lossChangePairs) <= stoppingCondition))
1447#ifdef TTK_ENABLE_OPENMP
1448#pragma omp parallel for num_threads(threadNumber_)
1450 for(
SimplexId k = 0; k < vertexNumber_; ++k) {
1451 outputScalars[k] = dataVector[k] * (maxVal - minVal) + minVal;
1458#ifdef TTK_ENABLE_TORCH
1459 else if(methodOptimization_ == 1) {
1464 = torch::from_blob(dataVector.data(), {SimplexId(dataVector.size())},
1465 torch::dtype(torch::kDouble))
1466 .to(torch::kDouble);
1467 PersistenceGradientDescent model(F);
1469 torch::optim::Adam optimizer(model.parameters(), learningRate_);
1475 std::vector<std::vector<SimplexId>> pair2MatchedPair(
1476 constraintDiagram.size(), std::vector<SimplexId>(2));
1477 std::vector<SimplexId> pairChangeMatchingPair(constraintDiagram.size(), -1);
1478 std::vector<SimplexId> listAllIndicesToChange(vertexNumber_, 0);
1479 std::vector<std::vector<SimplexId>> pair2Delete(
1480 vertexNumber_, std::vector<SimplexId>());
1481 std::vector<std::vector<SimplexId>> currentVertex2PairsCurrentDiagram(
1482 vertexNumber_, std::vector<SimplexId>());
1484 for(
int i = 0; i < epochNumber_; i++) {
1486 if(i % printFrequency_ == 0) {
1498 tensorToVectorFast(model.X.to(torch::kDouble), inputScalarsX);
1501 std::vector<SimplexId> birthPairToChangeCurrentDiagram{};
1502 std::vector<double> birthPairToChangeTargetDiagram{};
1503 std::vector<SimplexId> deathPairToChangeCurrentDiagram{};
1504 std::vector<double> deathPairToChangeTargetDiagram{};
1507 std::vector<SimplexId> birthPairToDeleteCurrentDiagram{};
1508 std::vector<double> birthPairToDeleteTargetDiagram{};
1509 std::vector<SimplexId> deathPairToDeleteCurrentDiagram{};
1510 std::vector<double> deathPairToDeleteTargetDiagram{};
1512 std::vector<int> vertexInHowManyPairs(vertexNumber_, 0);
1517 triangulation, inputOffsetsCopie, inputScalarsX.data(),
1518 normalizedConstraintDiagram, i, listAllIndicesToChange,
1519 pair2MatchedPair, pair2Delete, pairChangeMatchingPair,
1520 birthPairToDeleteCurrentDiagram, birthPairToDeleteTargetDiagram,
1521 deathPairToDeleteCurrentDiagram, deathPairToDeleteTargetDiagram,
1522 birthPairToChangeCurrentDiagram, birthPairToChangeTargetDiagram,
1523 deathPairToChangeCurrentDiagram, deathPairToChangeTargetDiagram,
1524 currentVertex2PairsCurrentDiagram, vertexInHowManyPairs);
1527 listAllIndicesToChange.begin(), listAllIndicesToChange.end(), 0);
1532 torch::Tensor valueOfXDeleteBirth = torch::index_select(
1533 model.X, 0, torch::tensor(birthPairToDeleteCurrentDiagram));
1534 auto valueDeleteBirth = torch::from_blob(
1535 birthPairToDeleteTargetDiagram.data(),
1536 {static_cast<SimplexId>(birthPairToDeleteTargetDiagram.size())},
1538 torch::Tensor valueOfXDeleteDeath = torch::index_select(
1539 model.X, 0, torch::tensor(deathPairToDeleteCurrentDiagram));
1540 auto valueDeleteDeath = torch::from_blob(
1541 deathPairToDeleteTargetDiagram.data(),
1542 {static_cast<SimplexId>(deathPairToDeleteTargetDiagram.size())},
1545 torch::Tensor lossDeletePairs = torch::zeros({1}, torch::kDouble);
1546 if(!(finePairManagement_ == 2) && !(finePairManagement_ == 1)) {
1548 = torch::sum(torch::pow(valueOfXDeleteBirth - valueDeleteBirth, 2));
1551 + torch::sum(torch::pow(valueOfXDeleteDeath - valueDeleteDeath, 2));
1552 }
else if(finePairManagement_ == 1) {
1554 = torch::sum(torch::pow(valueOfXDeleteDeath - valueDeleteDeath, 2));
1555 }
else if(finePairManagement_ == 2) {
1557 = torch::sum(torch::pow(valueOfXDeleteBirth - valueDeleteBirth, 2));
1560 this->
printMsg(
"Adam - Loss Delete Pairs: "
1561 + std::to_string(lossDeletePairs.item<
double>()),
1568 torch::Tensor valueOfXChangeBirth = torch::index_select(
1569 model.X, 0, torch::tensor(birthPairToChangeCurrentDiagram));
1570 auto valueChangeBirth = torch::from_blob(
1571 birthPairToChangeTargetDiagram.data(),
1572 {static_cast<SimplexId>(birthPairToChangeTargetDiagram.size())},
1574 torch::Tensor valueOfXChangeDeath = torch::index_select(
1575 model.X, 0, torch::tensor(deathPairToChangeCurrentDiagram));
1576 auto valueChangeDeath = torch::from_blob(
1577 deathPairToChangeTargetDiagram.data(),
1578 {static_cast<SimplexId>(deathPairToChangeTargetDiagram.size())},
1581 auto lossChangePairs
1582 = torch::sum((torch::pow(valueOfXChangeBirth - valueChangeBirth, 2)
1583 + torch::pow(valueOfXChangeDeath - valueChangeDeath, 2)));
1585 this->
printMsg(
"Adam - Loss Change Pairs: "
1586 + std::to_string(lossChangePairs.item<
double>()),
1593 auto loss = lossDeletePairs + lossChangePairs;
1595 this->
printMsg(
"Adam - Loss: " + std::to_string(loss.item<
double>()),
1602 losses.push_back(loss.item<
double>());
1605 optimizer.zero_grad();
1614 std::vector<double> NewinputScalarsX(vertexNumber_);
1615 tensorToVectorFast(model.X.to(torch::kDouble), NewinputScalarsX);
1617#ifdef TTK_ENABLE_OPENMP
1618#pragma omp parallel for num_threads(threadNumber_)
1620 for(
SimplexId k = 0; k < vertexNumber_; ++k) {
1621 double diff = NewinputScalarsX[k] - inputScalarsX[k];
1623 listAllIndicesToChange[k] = 1;
1631 stoppingCondition = coefStopCondition_ * loss.item<
double>();
1634 if(loss.item<
double>() < stoppingCondition)
1641#ifdef TTK_ENABLE_OPENMP
1642#pragma omp parallel for num_threads(threadNumber_)
1644 for(
SimplexId k = 0; k < vertexNumber_; ++k) {
1646 = model.X[k].item().to<
double>() * (maxVal - minVal) + minVal;
1647 if(std::isnan((
double)outputScalars[k]))
1648 outputScalars[k] = 0;
1662 this->
printMsg(
"Number of constrained pairs: "
1663 + std::to_string(numberPairsConstraintDiagram),
1666 this->
printMsg(
"Stopping condition: " + std::to_string(stoppingCondition),
1669 this->
printMsg(
"Scalar field optimized", 1.0, time, this->threadNumber_);