155 const std::vector<ttk::DiagramType> &intermediateDiagrams,
156 std::vector<ttk::DiagramType> &dictDiagrams,
157 std::vector<std::vector<double>> &vectorWeights,
159 std::vector<double> &lossTab,
160 std::vector<std::vector<double>> &allLosses,
161 std::vector<std::vector<double>> &histoVectorWeights,
162 std::vector<ttk::DiagramType> &histoDictDiagrams,
165 bool doCompression) {
169 bool doOptimizeAtoms =
false;
170 bool doOptimizeWeights =
false;
173 printMsg(
"Weight Optimization activated");
174 doOptimizeWeights =
true;
176 printWrn(
"Weight Optimization desactivated");
179 doOptimizeAtoms =
true;
180 printMsg(
"Atom Optimization activated");
182 printWrn(
"Atom Optimization desactivated");
185 const auto nDiags = intermediateDiagrams.size();
188 std::vector<ttk::DiagramType> inputDiagramsMin(nDiags);
189 std::vector<ttk::DiagramType> inputDiagramsSad(nDiags);
190 std::vector<ttk::DiagramType> inputDiagramsMax(nDiags);
192 std::vector<BidderDiagram> bidderDiagramsMin{};
193 std::vector<BidderDiagram> bidderDiagramsSad{};
194 std::vector<BidderDiagram> bidderDiagramsMax{};
196 std::vector<std::vector<size_t>> originIndexDatasMin(nDiags);
197 std::vector<std::vector<size_t>> originIndexDatasSad(nDiags);
198 std::vector<std::vector<size_t>> originIndexDatasMax(nDiags);
201 intermediateDiagrams, inputDiagramsMin, inputDiagramsSad, inputDiagramsMax,
202 bidderDiagramsMin, bidderDiagramsSad, bidderDiagramsMax,
203 originIndexDatasMin, originIndexDatasSad, originIndexDatasMax,
true);
205 std::vector<ttk::DiagramType> barycentersList(nDiags);
206 std::vector<std::vector<std::vector<ttk::MatchingType>>> allMatchingsAtoms(
208 std::vector<std::vector<ttk::MatchingType>> matchingsDatasMin(nDiags);
209 std::vector<std::vector<ttk::MatchingType>> matchingsDatasSad(nDiags);
210 std::vector<std::vector<ttk::MatchingType>> matchingsDatasMax(nDiags);
221 int nbEpochPrevious = lossTab.size();
222 std::vector<size_t> initSizes(dictDiagrams.size());
223 for(
size_t j = 0; j < dictDiagrams.size(); ++j) {
224 initSizes[j] = dictDiagrams[j].size();
227 double factEquiv =
static_cast<double>(numAtom);
228 double step = 1. / (2. * 2. * factEquiv);
232 std::vector<std::vector<int>> bufferHistoAllEpochLife(dictDiagrams.size());
233 std::vector<std::vector<bool>> bufferHistoAllBoolLife(dictDiagrams.size());
234 std::vector<std::vector<bool>> bufferCheckUnderDiag(dictDiagrams.size());
235 std::vector<std::vector<bool>> bufferCheckDiag(dictDiagrams.size());
236 std::vector<std::vector<bool>> bufferCheckAboveGlobal(dictDiagrams.size());
238 std::vector<std::vector<int>> histoAllEpochLife(dictDiagrams.size());
239 std::vector<std::vector<bool>> histoAllBoolLife(dictDiagrams.size());
240 std::vector<std::vector<bool>> checkUnderDiag(dictDiagrams.size());
241 std::vector<std::vector<bool>> checkDiag(dictDiagrams.size());
242 std::vector<std::vector<bool>> checkAboveGlobal(dictDiagrams.size());
244 std::vector<double> allLossesAtEpoch(nDiags, 0.);
245 std::vector<double> trueAllLossesAtEpoch(nDiags, 0.);
247 while(epoch < maxEpoch && cond) {
250#ifdef TTK_ENABLE_OPENMP
251#pragma omp parallel for num_threads(threadNumber_)
254 for(size_t i = 0; i < nDiags; ++i) {
255 auto &barycenter = barycentersList[i];
256 std::vector<double> &weight = vectorWeights[i];
257 std::vector<std::vector<ttk::MatchingType>> &matchings
258 = allMatchingsAtoms[i];
263 "Computed 1st Barycenters for epoch " + std::to_string(epoch),
264 epoch /
static_cast<double>(maxEpoch), tm_it.getElapsedTime(),
270 std::vector<ttk::DiagramType> barycentersListMin(nDiags);
271 std::vector<ttk::DiagramType> barycentersListSad(nDiags);
272 std::vector<ttk::DiagramType> barycentersListMax(nDiags);
274 std::vector<BidderDiagram> bidderBarycentersListMin{};
275 std::vector<BidderDiagram> bidderBarycentersListSad{};
276 std::vector<BidderDiagram> bidderBarycentersListMax{};
278 std::vector<std::vector<size_t>> originIndexBarysMin(nDiags);
279 std::vector<std::vector<size_t>> originIndexBarysSad(nDiags);
280 std::vector<std::vector<size_t>> originIndexBarysMax(nDiags);
283 barycentersList, nDiags, barycentersListMin, barycentersListSad,
284 barycentersListMax, bidderBarycentersListMin, bidderBarycentersListSad,
285 bidderBarycentersListMax, originIndexBarysMin, originIndexBarysSad,
286 originIndexBarysMax, bidderDiagramsMin, bidderDiagramsMax,
287 bidderDiagramsSad, matchingsDatasMin, matchingsDatasMax,
288 matchingsDatasSad, allLossesAtEpoch,
true);
290 for(
size_t p = 0; p < nDiags; ++p) {
291 loss += allLossesAtEpoch[p];
294 for(
size_t p = 0; p < nDiags; ++p) {
296 allLosses[p].push_back(allLossesAtEpoch[p]);
298 allLosses[p].push_back(trueAllLossesAtEpoch[p]);
302 lossTab.push_back(loss);
305 " Epoch " + std::to_string(epoch) +
", loss = " + std::to_string(loss), 1,
311 doOptimizeAtoms =
false;
313 doOptimizeAtoms =
true;
319 doOptimizeWeights =
false;
320 doOptimizeAtoms =
false;
324 = *std::min_element(lossTab.begin() + nbEpochPrevious, lossTab.end() - 1);
326 for(
size_t p = 0; p < dictDiagrams.size(); ++p) {
327 const auto &atom = dictDiagrams[p];
328 histoDictDiagrams[p] = atom;
330 const auto &histoEpochAtom = histoAllEpochLife[p];
331 const auto &histoBoolAtom = histoAllBoolLife[p];
332 const auto &boolUnderDiag = checkUnderDiag[p];
333 const auto &boolDiag = checkDiag[p];
334 const auto &boolAboveGlobal = checkAboveGlobal[p];
336 bufferHistoAllEpochLife[p] = histoEpochAtom;
337 bufferHistoAllBoolLife[p] = histoBoolAtom;
338 bufferCheckUnderDiag[p] = boolUnderDiag;
339 bufferCheckDiag[p] = boolDiag;
340 bufferCheckAboveGlobal[p] = boolAboveGlobal;
342 for(
size_t p = 0; p < nDiags; ++p) {
343 const auto &weights = vectorWeights[p];
344 histoVectorWeights[p] = weights;
349 if(epoch > minEpoch) {
356 if((epoch > minEpoch)
357 && (lossTab[epoch + nbEpochPrevious]
358 / lossTab[epoch + nbEpochPrevious - 1]
360 if(lossTab[epoch + nbEpochPrevious]
361 < lossTab[epoch + nbEpochPrevious - 1]) {
363 this->
printMsg(
"Loss not decreasing enough");
365 for(
size_t p = 0; p < dictDiagrams.size(); ++p) {
366 const auto &atom = histoDictDiagrams[p];
367 dictDiagrams[p] = atom;
369 for(
size_t p = 0; p < nDiags; ++p) {
370 const auto &weights = histoVectorWeights[p];
371 vectorWeights[p] = weights;
373 doOptimizeWeights =
false;
374 doOptimizeAtoms =
false;
385 if(epoch > minEpoch && lag > lagLimit) {
388 for(
size_t p = 0; p < dictDiagrams.size(); ++p) {
389 const auto &atom = histoDictDiagrams[p];
390 dictDiagrams[p] = atom;
392 for(
size_t p = 0; p < nDiags; ++p) {
393 const auto &weights = histoVectorWeights[p];
394 vectorWeights[p] = weights;
396 this->
printMsg(
"Minimum not passed");
397 doOptimizeWeights =
false;
398 doOptimizeAtoms =
false;
404 if(cond && (epoch > 0) && (lossTab[epoch + nbEpochPrevious] > 2. * mini)) {
409 this->
printMsg(
"Loss increasing too much, reducing step and recompute "
410 "Barycenters and matchings");
411 for(
size_t p = 0; p < dictDiagrams.size(); ++p) {
412 const auto &atom = histoDictDiagrams[p];
413 dictDiagrams[p] = atom;
415 const auto &bufferHistoEpochAtom = bufferHistoAllEpochLife[p];
416 const auto &bufferHistoBoolAtom = bufferHistoAllBoolLife[p];
417 const auto &bufferBoolUnderDiag = bufferCheckUnderDiag[p];
418 const auto &bufferBoolDiag = bufferCheckDiag[p];
419 const auto &bufferBoolAboveGlobal = bufferCheckAboveGlobal[p];
421 histoAllEpochLife[p] = bufferHistoEpochAtom;
422 histoAllBoolLife[p] = bufferHistoBoolAtom;
423 checkUnderDiag[p] = bufferBoolUnderDiag;
424 checkDiag[p] = bufferBoolDiag;
425 checkAboveGlobal[p] = bufferBoolAboveGlobal;
428 for(
size_t p = 0; p < nDiags; ++p) {
429 const auto &weights = histoVectorWeights[p];
430 vectorWeights[p] = weights;
436 barycentersList.clear();
437 barycentersList.resize(nDiags);
438 allMatchingsAtoms.clear();
439 allMatchingsAtoms.resize(nDiags);
441 barycentersListMin.clear();
442 barycentersListSad.clear();
443 barycentersListMax.clear();
445 barycentersListMin.resize(nDiags);
446 barycentersListSad.resize(nDiags);
447 barycentersListMax.resize(nDiags);
449 bidderBarycentersListMin.clear();
450 bidderBarycentersListSad.clear();
451 bidderBarycentersListMax.clear();
453 originIndexBarysMin.clear();
454 originIndexBarysSad.clear();
455 originIndexBarysMax.clear();
457 originIndexBarysMin.resize(nDiags);
458 originIndexBarysSad.resize(nDiags);
459 originIndexBarysMax.resize(nDiags);
461 matchingsDatasMin.clear();
462 matchingsDatasSad.clear();
463 matchingsDatasMax.clear();
465 matchingsDatasMin.resize(nDiags);
466 matchingsDatasSad.resize(nDiags);
467 matchingsDatasMax.resize(nDiags);
469#ifdef TTK_ENABLE_OPENMP
470#pragma omp parallel for num_threads(threadNumber_)
473 for(size_t i = 0; i < nDiags; ++i) {
474 auto &barycenter = barycentersList[i];
475 std::vector<double> &weight = vectorWeights[i];
476 std::vector<std::vector<ttk::MatchingType>> &matchings
477 = allMatchingsAtoms[i];
483 barycentersList, nDiags, barycentersListMin, barycentersListSad,
484 barycentersListMax, bidderBarycentersListMin,
485 bidderBarycentersListSad, bidderBarycentersListMax,
486 originIndexBarysMin, originIndexBarysSad, originIndexBarysMax,
487 bidderDiagramsMin, bidderDiagramsMax, bidderDiagramsSad,
488 matchingsDatasMin, matchingsDatasMax, matchingsDatasSad,
489 allLossesAtEpoch,
false);
493 std::vector<std::vector<Matrix>> allHessianLists(nDiags);
494 std::vector<std::vector<double>> gradWeightsList(nDiags);
498 if(doOptimizeWeights) {
499#ifdef TTK_ENABLE_OPENMP
500#pragma omp parallel for num_threads(threadNumber_)
502 for(
size_t i = 0; i < nDiags; ++i) {
503 auto &gradWeights = gradWeightsList[i];
504 const auto &matchingsAtoms = allMatchingsAtoms[i];
505 const auto &Barycenter = barycentersList[i];
506 const auto &Data = intermediateDiagrams[i];
507 std::vector<Matrix> &hessianList = allHessianLists[i];
508 const std::vector<ttk::MatchingType> &matchingsMin
509 = matchingsDatasMin[i];
510 const std::vector<ttk::MatchingType> &matchingsMax
511 = matchingsDatasMax[i];
512 const std::vector<ttk::MatchingType> &matchingsSad
513 = matchingsDatasSad[i];
514 const std::vector<size_t> &indexBaryMin = originIndexBarysMin[i];
515 const std::vector<size_t> &indexBarySad = originIndexBarysSad[i];
516 const std::vector<size_t> &indexBaryMax = originIndexBarysMax[i];
517 const std::vector<size_t> &indexDataMin = originIndexDatasMin[i];
518 const std::vector<size_t> &indexDataSad = originIndexDatasSad[i];
519 const std::vector<size_t> &indexDataMax = originIndexDatasMax[i];
520 std::vector<double> &weights = vectorWeights[i];
521 computeGradientWeights(gradWeights, hessianList, dictDiagrams,
522 matchingsAtoms, Barycenter, Data, matchingsMin,
523 matchingsMax, matchingsSad, indexBaryMin,
524 indexBaryMax, indexBarySad, indexDataMin,
525 indexDataMax, indexDataSad, doOptimizeAtoms);
526 gradActor.executeWeightsProjected(
527 hessianList, weights, gradWeights, MaxEigenValue_);
530 this->
printMsg(
"Computed 1st opt for epoch " + std::to_string(epoch),
531 epoch /
static_cast<double>(maxEpoch),
532 tm_opt1.getElapsedTime(), threadNumber_,
536 for(
size_t p = 0; p < nDiags; ++p) {
537 allLossesAtEpoch[p] = 0.;
538 trueAllLossesAtEpoch[p] = 0.;
540 barycentersList.clear();
541 barycentersList.resize(nDiags);
542 allMatchingsAtoms.clear();
543 allMatchingsAtoms.resize(nDiags);
545 if(doOptimizeAtoms) {
547#ifdef TTK_ENABLE_OPENMP
548#pragma omp parallel for num_threads(threadNumber_)
550 for(
size_t i = 0; i < nDiags; ++i) {
551 auto &barycenter = barycentersList[i];
552 std::vector<double> &weight = vectorWeights[i];
553 std::vector<std::vector<ttk::MatchingType>> &matchings
554 = allMatchingsAtoms[i];
556 dictDiagrams, weight, barycenter, matchings, *
this, ProgBarycenter_);
559 "Computed 2nd Barycenters for epoch " + std::to_string(epoch),
560 epoch /
static_cast<double>(maxEpoch), tm_it2.getElapsedTime(),
563 barycentersListMin.clear();
564 barycentersListSad.clear();
565 barycentersListMax.clear();
567 barycentersListMin.resize(nDiags);
568 barycentersListSad.resize(nDiags);
569 barycentersListMax.resize(nDiags);
571 bidderBarycentersListMin.clear();
572 bidderBarycentersListSad.clear();
573 bidderBarycentersListMax.clear();
575 originIndexBarysMin.clear();
576 originIndexBarysSad.clear();
577 originIndexBarysMax.clear();
579 originIndexBarysMin.resize(nDiags);
580 originIndexBarysSad.resize(nDiags);
581 originIndexBarysMax.resize(nDiags);
583 matchingsDatasMin.clear();
584 matchingsDatasSad.clear();
585 matchingsDatasMax.clear();
587 matchingsDatasMin.resize(nDiags);
588 matchingsDatasSad.resize(nDiags);
589 matchingsDatasMax.resize(nDiags);
592 barycentersList, nDiags, barycentersListMin, barycentersListSad,
593 barycentersListMax, bidderBarycentersListMin, bidderBarycentersListSad,
594 bidderBarycentersListMax, originIndexBarysMin, originIndexBarysSad,
595 originIndexBarysMax, bidderDiagramsMin, bidderDiagramsMax,
596 bidderDiagramsSad, matchingsDatasMin, matchingsDatasMax,
597 matchingsDatasSad, allLossesAtEpoch,
false);
599 std::vector<std::vector<std::vector<std::array<double, 2>>>>
600 allPairToAddToGradList(nDiags);
601 std::vector<ttk::DiagramType> allInfoToAdd(nDiags);
602 std::vector<std::vector<Matrix>> gradsAtomsList(nDiags);
603 std::vector<std::vector<int>> checkerAtomsList(nDiags);
605 bool doDimReduct =
false;
606 if(DimReductMode_ && numAtom <= 3) {
613#ifdef TTK_ENABLE_OPENMP
614#pragma omp parallel for num_threads(threadNumber_)
616 for(
size_t i = 0; i < nDiags; ++i) {
617 auto &pairToAddGradList = allPairToAddToGradList[i];
618 auto &infoToAdd = allInfoToAdd[i];
619 auto &gradsAtoms = gradsAtomsList[i];
620 auto &checkerAtoms = checkerAtomsList[i];
621 const auto &Barycenter = barycentersList[i];
622 const auto &Data = intermediateDiagrams[i];
623 const std::vector<ttk::MatchingType> &matchingsMin
624 = matchingsDatasMin[i];
625 const std::vector<ttk::MatchingType> &matchingsMax
626 = matchingsDatasMax[i];
627 const std::vector<ttk::MatchingType> &matchingsSad
628 = matchingsDatasSad[i];
629 const std::vector<size_t> &indexBaryMin = originIndexBarysMin[i];
630 const std::vector<size_t> &indexBarySad = originIndexBarysSad[i];
631 const std::vector<size_t> &indexBaryMax = originIndexBarysMax[i];
632 const std::vector<size_t> &indexDataMin = originIndexDatasMin[i];
633 const std::vector<size_t> &indexDataSad = originIndexDatasSad[i];
634 const std::vector<size_t> &indexDataMax = originIndexDatasMax[i];
635 const std::vector<double> &weights = vectorWeights[i];
636 computeGradientAtoms(
637 gradsAtoms, weights, Barycenter, Data, matchingsMin, matchingsMax,
638 matchingsSad, indexBaryMin, indexBaryMax, indexBarySad, indexDataMin,
639 indexDataMax, indexDataSad, checkerAtoms, pairToAddGradList,
640 infoToAdd, doDimReduct);
643 std::vector<double> maxiDeath(numAtom);
644 std::vector<double> minBirth(numAtom);
645 for(
int j = 0; j < numAtom; ++j) {
646 auto &atom = dictDiagrams[j];
647 auto &temp = atom[0];
648 maxiDeath[j] = temp.death.sfValue;
649 minBirth[j] = temp.birth.sfValue;
652 std::vector<std::vector<std::vector<int>>> allProjectionsList(nDiags);
653 std::vector<ttk::DiagramType> allFeaturesToAdd(nDiags);
654 std::vector<std::vector<std::array<double, 2>>> allProjLocations(nDiags);
655 std::vector<std::vector<std::vector<double>>>
656 allVectorForProjContributions(nDiags);
657 std::vector<ttk::DiagramType> allTrueFeaturesToAdd(numAtom);
658 std::vector<std::vector<std::vector<int>>> allTrueProj(numAtom);
659 std::vector<std::vector<std::array<double, 2>>> allTrueProjLoc(numAtom);
661 for(
size_t i = 0; i < nDiags; ++i) {
662 auto &pairToAddGradList = allPairToAddToGradList[i];
663 auto &infoToAdd = allInfoToAdd[i];
664 auto &projForDiag = allProjectionsList[i];
665 auto &vectorForProjContrib = allVectorForProjContributions[i];
666 auto &featuresToAdd = allFeaturesToAdd[i];
667 auto &projLocations = allProjLocations[i];
668 auto &gradsAtoms = gradsAtomsList[i];
669 const auto &matchingsAtoms = allMatchingsAtoms[i];
670 const auto &Barycenter = barycentersList[i];
671 const auto &checkerAtoms = checkerAtomsList[i];
672 gradActor.executeAtoms(
673 dictDiagrams, matchingsAtoms, Barycenter, gradsAtoms, checkerAtoms,
674 projForDiag, featuresToAdd, projLocations, vectorForProjContrib,
675 pairToAddGradList, infoToAdd);
678 if(CreationFeatures_) {
679 for(
size_t i = 0; i < nDiags; ++i) {
680 auto &projForDiag = allProjectionsList[i];
681 auto &featuresToAdd = allFeaturesToAdd[i];
682 auto &projLocations = allProjLocations[i];
683 auto &vectorForProjContrib = allVectorForProjContributions[i];
684 for(
size_t j = 0; j < projForDiag.size(); ++j) {
685 auto &t = featuresToAdd[j];
686 std::array<double, 2> &pair = projLocations[j];
687 std::vector<double> &vectorContrib = vectorForProjContrib[j];
688 std::vector<int> &projAndIndex = projForDiag[j];
689 std::vector<int> proj(numAtom);
690 for(
int m = 0; m < numAtom; ++m) {
691 proj[m] = projAndIndex[m];
693 int atomIndex =
static_cast<int>(projAndIndex[numAtom]);
694 bool lenNull = allTrueProj[atomIndex].size() == 0;
699 pair[0] = pair[0] - step * vectorContrib[0];
700 pair[1] = pair[1] - step * vectorContrib[1];
701 if(pair[0] > pair[1]) {
704 if(pair[0] < minBirth[atomIndex]) {
707 if(pair[1] - pair[0] < 1e-7) {
714 allTrueProj[atomIndex].push_back(proj);
715 allTrueFeaturesToAdd[atomIndex].push_back(newPair);
716 allTrueProjLoc[atomIndex].push_back(pair);
724 pair[0] = pair[0] - step * vectorContrib[0];
725 pair[1] = pair[1] - step * vectorContrib[1];
726 if(pair[0] > pair[1]) {
729 if(pair[0] < minBirth[atomIndex]) {
732 if(pair[1] - pair[0] < 1e-7) {
739 allTrueProj[atomIndex].push_back(proj);
740 allTrueFeaturesToAdd[atomIndex].push_back(newPair);
741 allTrueProjLoc[atomIndex].push_back(pair);
744 auto &tReal = allTrueFeaturesToAdd[atomIndex][index];
746 = tReal.birth.sfValue - step * vectorContrib[0];
748 = tReal.death.sfValue - step * vectorContrib[1];
749 if(tReal.birth.sfValue > tReal.death.sfValue) {
750 tReal.death.sfValue = tReal.birth.sfValue;
758 for(
int i = 0; i < numAtom; ++i) {
759 auto &atom = dictDiagrams[i];
760 auto &histoEpochAtom = histoAllEpochLife[i];
761 auto &histoBoolAtom = histoAllBoolLife[i];
762 auto &boolUnderDiag = checkUnderDiag[i];
763 auto &boolDiag = checkDiag[i];
764 auto initSize = initSizes[i];
765 auto &boolAboveGlobal = checkAboveGlobal[i];
766 if(histoEpochAtom.size() > 0) {
767 for(
size_t j = 0; j < histoEpochAtom.size(); ++j) {
768 auto &t = atom[initSize + j];
769 histoEpochAtom[j] += 1;
770 histoBoolAtom[j] = t.death.sfValue - t.birth.sfValue
771 < 0.1 * (percent / 100.) * maxiDeath[i];
772 boolDiag[j] = t.death.sfValue - t.birth.sfValue < 1e-6;
773 boolUnderDiag[j] = t.death.sfValue < t.birth.sfValue;
774 boolAboveGlobal[j] = t.birth.sfValue > maxiDeath[i];
779 for(
int i = 0; i < numAtom; ++i) {
780 auto &atom = dictDiagrams[i];
781 auto &histoEpochAtom = histoAllEpochLife[i];
782 auto &histoBoolAtom = histoAllBoolLife[i];
783 auto &boolUnderDiag = checkUnderDiag[i];
784 auto &boolDiag = checkDiag[i];
785 auto &trueFeaturesToAdd = allTrueFeaturesToAdd[i];
786 auto &boolAboveGlobal = checkAboveGlobal[i];
787 for(
size_t j = 0; j < trueFeaturesToAdd.size(); ++j) {
788 auto &t = trueFeaturesToAdd[j];
790 histoEpochAtom.push_back(0);
791 histoBoolAtom.push_back(t.death.sfValue - t.birth.sfValue
792 < 0.1 * (percent / 100.) * maxiDeath[i]);
793 boolDiag.push_back(t.death.sfValue - t.birth.sfValue < 1e-6);
794 boolUnderDiag.push_back(t.death.sfValue < t.birth.sfValue);
795 boolAboveGlobal.push_back(t.birth.sfValue > maxiDeath[i]);
799 std::vector<std::vector<size_t>> allIndicesToDelete(numAtom);
800 for(
int i = 0; i < numAtom; ++i) {
801 auto &indicesAtomToDelete = allIndicesToDelete[i];
802 auto &histoEpochAtom = histoAllEpochLife[i];
803 auto &histoBoolAtom = histoAllBoolLife[i];
804 auto &boolUnderDiag = checkUnderDiag[i];
805 auto &boolDiag = checkDiag[i];
806 auto &boolAboveGlobal = checkAboveGlobal[i];
807 for(
size_t j = 0; j < histoEpochAtom.size(); ++j) {
808 if(boolUnderDiag[j] || boolDiag[j] || boolAboveGlobal[j]
809 || (histoEpochAtom[j] > 5 && histoBoolAtom[j])) {
810 indicesAtomToDelete.push_back(j);
815 for(
int i = 0; i < numAtom; ++i) {
816 auto &atom = dictDiagrams[i];
817 auto &histoEpochAtom = histoAllEpochLife[i];
818 auto &histoBoolAtom = histoAllBoolLife[i];
819 auto &indicesAtomToDelete = allIndicesToDelete[i];
820 auto &boolUnderDiag = checkUnderDiag[i];
821 auto &boolDiag = checkDiag[i];
822 auto &boolAboveGlobal = checkAboveGlobal[i];
823 auto initSize = initSizes[i];
824 if(
static_cast<int>(indicesAtomToDelete.size()) > 0) {
825 for(
int j =
static_cast<int>(indicesAtomToDelete.size()) - 1;
827 atom.erase(atom.begin() + initSize + indicesAtomToDelete[j]);
828 histoEpochAtom.erase(histoEpochAtom.begin()
829 + indicesAtomToDelete[j]);
830 histoBoolAtom.erase(histoBoolAtom.begin()
831 + indicesAtomToDelete[j]);
832 boolUnderDiag.erase(boolUnderDiag.begin()
833 + indicesAtomToDelete[j]);
834 boolDiag.erase(boolDiag.begin() + indicesAtomToDelete[j]);
835 boolAboveGlobal.erase(boolAboveGlobal.begin()
836 + indicesAtomToDelete[j]);
843 for(
int i = 0; i < numAtom; ++i) {
844 auto &atom = dictDiagrams[i];
845 auto &trueFeaturesToAdd = allTrueFeaturesToAdd[i];
846 for(
size_t j = 0; j < trueFeaturesToAdd.size(); ++j) {
847 auto &t = trueFeaturesToAdd[j];
851 controlAtomsSize(intermediateDiagrams, dictDiagrams);
856 controlAtomsSize(intermediateDiagrams, dictDiagrams);
861 for(
int i = 0; i < numAtom; ++i) {
862 auto &atom = dictDiagrams[i];
863 auto &globalPair = atom[0];
864 for(
size_t j = 0; j < atom.size(); ++j) {
867 if(t.death.sfValue > globalPair.death.sfValue) {
868 t.death.sfValue = globalPair.death.sfValue;
871 if(t.birth.sfValue > t.death.sfValue) {
872 t.death.sfValue = t.birth.sfValue;
875 if(t.birth.sfValue < globalPair.birth.sfValue) {
876 t.birth.sfValue = globalPair.birth.sfValue;
880 this->
printMsg(
"Computed 2nd opt for epoch " + std::to_string(epoch),
881 epoch /
static_cast<double>(maxEpoch),
882 tm_opt2.getElapsedTime(), threadNumber_,
888 barycentersList.clear();
889 barycentersList.resize(nDiags);
890 allMatchingsAtoms.clear();
891 allMatchingsAtoms.resize(nDiags);
894 " Epoch " + std::to_string(epoch) +
", loss = " + std::to_string(loss), 1.0,
895 tm.getElapsedTime(), threadNumber_);
900 *std::min_element(lossTab.begin() + nbEpochPrevious, lossTab.end()))
903 std::min_element(lossTab.begin() + nbEpochPrevious, lossTab.end())
905 1.0, tm.getElapsedTime(), threadNumber_);
907 for(
size_t p = 0; p < dictDiagrams.size(); ++p) {
908 const auto &atom = histoDictDiagrams[p];
909 dictDiagrams[p] = atom;
911 for(
size_t p = 0; p < nDiags; ++p) {
912 const auto &weights = histoVectorWeights[p];
913 vectorWeights[p] = weights;
916 this->
printMsg(
"Complete", 1.0, tm.getElapsedTime(), this->threadNumber_);
967 std::vector<double> &gradWeights,
968 std::vector<Matrix> &hessianList,
969 const std::vector<ttk::DiagramType> &dictDiagrams,
970 const std::vector<std::vector<ttk::MatchingType>> &matchingsAtoms,
973 const std::vector<ttk::MatchingType> &matchingsMin,
974 const std::vector<ttk::MatchingType> &matchingsMax,
975 const std::vector<ttk::MatchingType> &matchingsSad,
976 const std::vector<size_t> &indexBaryMin,
977 const std::vector<size_t> &indexBaryMax,
978 const std::vector<size_t> &indexBarySad,
979 const std::vector<size_t> &indexDataMin,
980 const std::vector<size_t> &indexDataMax,
981 const std::vector<size_t> &indexDataSad,
982 const bool doOptimizeAtoms)
const {
985 std::vector<std::vector<std::array<double, 2>>> gradBuffersList(
987 std::vector<std::vector<std::array<double, 2>>> pairToAddGradList;
988 for(
size_t i = 0; i < gradBuffersList.size(); ++i) {
989 gradBuffersList[i].resize(matchingsAtoms.size());
991 std::vector<std::array<double, 2>> directions(Barycenter.size());
992 std::vector<std::array<double, 2>> dataAssigned(Barycenter.size());
993 gradWeights.resize(dictDiagrams.size());
994 for(
size_t i = 0; i < dictDiagrams.size(); ++i) {
998 std::vector<std::vector<int>> checker(Barycenter.size());
999 for(
size_t j = 0; j < Barycenter.size(); ++j) {
1000 checker[j].resize(matchingsAtoms.size());
1002 std::vector<int> tracker(Barycenter.size(), 0);
1003 std::vector<int> tracker2(Barycenter.size(), 0);
1006 for(
size_t i = 0; i < matchingsAtoms.size(); ++i) {
1007 for(
size_t j = 0; j < matchingsAtoms[i].size(); ++j) {
1013 if(Id2 < 0 ||
static_cast<int>(gradBuffersList.size()) <= Id2
1014 ||
static_cast<int>(dictDiagrams[i].size()) <= Id1) {
1016 }
else if(Id1 < 0) {
1018 auto &point = gradBuffersList[Id2][i];
1021 const double birth_death_atom
1022 = birthBarycenter + (deathBarycenter - birthBarycenter) / 2.;
1023 point[0] = birth_death_atom;
1024 point[1] = birth_death_atom;
1025 checker[Id2][i] = i;
1029 auto &point = gradBuffersList[Id2][i];
1032 point[0] = birth_atom;
1033 point[1] = death_atom;
1034 checker[Id2][i] = i;
1042 indexBaryMin, indexDataMin, pairToAddGradList,
1043 directions, dataAssigned, tracker2,
1048 indexBaryMax, indexDataMax, pairToAddGradList,
1049 directions, dataAssigned, tracker2,
1054 indexBarySad, indexDataSad, pairToAddGradList,
1055 directions, dataAssigned, tracker2,
1058 std::vector<int> temp(pairToAddGradList.size(), 1);
1059 gradBuffersList.insert(
1060 gradBuffersList.end(), pairToAddGradList.begin(), pairToAddGradList.end());
1061 tracker2.insert(tracker2.end(), temp.begin(), temp.end());
1062 tracker.insert(tracker.end(), temp.begin(), temp.end());
1063 std::vector<int> temp2(matchingsAtoms.size());
1064 for(
size_t j = 0; j < matchingsAtoms.size(); ++j) {
1065 temp2.push_back(
static_cast<int>(j));
1067 for(
size_t j = 0; j < pairToAddGradList.size(); ++j) {
1068 checker.push_back(temp2);
1071 for(
size_t i = 0; i < gradBuffersList.size(); ++i) {
1072 const auto &data_point = dataAssigned[i];
1073 for(
size_t j = 0; j < checker[i].size(); ++j) {
1074 auto &point = gradBuffersList[i][checker[i][j]];
1075 point[0] -= data_point[0];
1076 point[1] -= data_point[1];
1080 for(
size_t i = 0; i < gradBuffersList.size(); ++i) {
1081 if(tracker[i] == 0 || tracker2[i] == 0) {
1084 for(
size_t j = 0; j < checker[i].size(); ++j) {
1085 const auto &point = gradBuffersList[i][checker[i][j]];
1086 const auto &direction = directions[i];
1087 gradWeights[checker[i][j]]
1088 += -2 * (point[0] * direction[0] + point[1] * direction[1]);
1092 hessianList.resize(gradBuffersList.size());
1093 for(
size_t i = 0; i < gradBuffersList.size(); ++i) {
1094 Matrix &hessian = hessianList[i];
1095 hessian.resize(checker[i].size());
1096 for(
size_t j = 0; j < checker[i].size(); ++j) {
1097 auto &line = hessian[j];
1098 line.resize(checker[i].size());
1099 const auto &point = gradBuffersList[i][checker[i][j]];
1100 for(
size_t q = 0; q < checker[i].size(); ++q) {
1101 const auto &point_temp = gradBuffersList[i][checker[i][q]];
1102 line[q] = point[0] * point_temp[0] + point[1] * point_temp[1];