Skip to content

Commit

Permalink
[MTNN] merge tree neural network refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
MatPont committed Sep 18, 2024
1 parent 483a214 commit 10a87a2
Show file tree
Hide file tree
Showing 38 changed files with 6,135 additions and 3,724 deletions.
16 changes: 16 additions & 0 deletions core/base/ftmTree/FTMNode.h
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,22 @@ namespace ttk {
}
}

inline void removeDownSuperArcs(std::vector<idSuperArc> &idSa) {
if(idSa.empty())
return;
std::vector<bool> toDelete(
(*std::max_element(idSa.begin(), idSa.end())) + 1, false);
for(auto &id : idSa)
toDelete[id] = true;
vect_downSuperArcList_.erase(
std::remove_if(vect_downSuperArcList_.begin(),
vect_downSuperArcList_.end(),
[&toDelete](const idSuperArc &i) {
return i < toDelete.size() and toDelete[i];
}),
vect_downSuperArcList_.end());
}

// Find and remove the arc
inline void removeUpSuperArc(idSuperArc idSa) {
for(idSuperArc i = 0; i < vect_upSuperArcList_.size(); ++i) {
Expand Down
78 changes: 40 additions & 38 deletions core/base/ftmTree/FTMTreeUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,34 +13,34 @@ namespace ttk {
// --------------------
// Is
// --------------------
bool FTMTree_MT::isNodeOriginDefined(idNode nodeId) {
bool FTMTree_MT::isNodeOriginDefined(idNode nodeId) const {
unsigned int const origin
= (unsigned int)this->getNode(nodeId)->getOrigin();
return origin != nullNodes && origin < this->getNumberOfNodes();
}

bool FTMTree_MT::isRoot(idNode nodeId) {
bool FTMTree_MT::isRoot(idNode nodeId) const {
return this->getNode(nodeId)->getNumberOfUpSuperArcs() == 0;
}

bool FTMTree_MT::isLeaf(idNode nodeId) {
bool FTMTree_MT::isLeaf(idNode nodeId) const {
return this->getNode(nodeId)->getNumberOfDownSuperArcs() == 0;
}

bool FTMTree_MT::isNodeAlone(idNode nodeId) {
bool FTMTree_MT::isNodeAlone(idNode nodeId) const {
return this->isRoot(nodeId) and this->isLeaf(nodeId);
}

bool FTMTree_MT::isFullMerge() {
bool FTMTree_MT::isFullMerge() const {
idNode const treeRoot = this->getRoot();
return (unsigned int)this->getNode(treeRoot)->getOrigin() == treeRoot;
}

bool FTMTree_MT::isBranchOrigin(idNode nodeId) {
bool FTMTree_MT::isBranchOrigin(idNode nodeId) const {
return this->getParentSafe(this->getNode(nodeId)->getOrigin()) != nodeId;
}

bool FTMTree_MT::isNodeMerged(idNode nodeId) {
bool FTMTree_MT::isNodeMerged(idNode nodeId) const {
bool merged = this->isNodeAlone(nodeId)
or this->isNodeAlone(this->getNode(nodeId)->getOrigin());
auto nodeIdOrigin = this->getNode(nodeId)->getOrigin();
Expand All @@ -49,11 +49,11 @@ namespace ttk {
return merged;
}

bool FTMTree_MT::isNodeIdInconsistent(idNode nodeId) {
bool FTMTree_MT::isNodeIdInconsistent(idNode nodeId) const {
return nodeId >= this->getNumberOfNodes();
}

bool FTMTree_MT::isThereOnlyOnePersistencePair() {
bool FTMTree_MT::isThereOnlyOnePersistencePair() const {
idNode const treeRoot = this->getRoot();
unsigned int cptNodeAlone = 0;
idNode otherNode = treeRoot;
Expand All @@ -74,7 +74,7 @@ namespace ttk {
}

// Do not normalize node is if root or son of a merged root
bool FTMTree_MT::notNeedToNormalize(idNode nodeId) {
bool FTMTree_MT::notNeedToNormalize(idNode nodeId) const {
auto nodeIdParent = this->getParentSafe(nodeId);
return this->isRoot(nodeId)
or (this->isRoot(nodeIdParent)
Expand All @@ -84,7 +84,7 @@ namespace ttk {
// and nodeIdOrigin == nodeIdParent) )
}

bool FTMTree_MT::isMultiPersPair(idNode nodeId) {
bool FTMTree_MT::isMultiPersPair(idNode nodeId) const {
auto nodeOriginOrigin
= (unsigned int)this->getNode(this->getNode(nodeId)->getOrigin())
->getOrigin();
Expand All @@ -94,14 +94,14 @@ namespace ttk {
// --------------------
// Get
// --------------------
idNode FTMTree_MT::getRoot() {
idNode FTMTree_MT::getRoot() const {
for(idNode node = 0; node < this->getNumberOfNodes(); ++node)
if(this->isRoot(node) and !this->isLeaf(node))
return node;
return nullNodes;
}

idNode FTMTree_MT::getParentSafe(idNode nodeId) {
idNode FTMTree_MT::getParentSafe(idNode nodeId) const {
if(!this->isRoot(nodeId)) {
// _ Nodes in merge trees should have only one parent
idSuperArc const arcId = this->getNode(nodeId)->getUpSuperArcId(0);
Expand All @@ -112,7 +112,7 @@ namespace ttk {
}

void FTMTree_MT::getChildren(idNode nodeId,
std::vector<idNode> &childrens) {
std::vector<idNode> &childrens) const {
childrens.clear();
for(idSuperArc i = 0;
i < this->getNode(nodeId)->getNumberOfDownSuperArcs(); ++i) {
Expand All @@ -121,33 +121,34 @@ namespace ttk {
}
}

void FTMTree_MT::getLeavesFromTree(std::vector<idNode> &treeLeaves) {
void FTMTree_MT::getLeavesFromTree(std::vector<idNode> &treeLeaves) const {
treeLeaves.clear();
for(idNode i = 0; i < this->getNumberOfNodes(); ++i) {
if(this->isLeaf(i) and !this->isRoot(i))
treeLeaves.push_back(i);
}
}

int FTMTree_MT::getNumberOfLeavesFromTree() {
int FTMTree_MT::getNumberOfLeavesFromTree() const {
std::vector<idNode> leaves;
this->getLeavesFromTree(leaves);
return leaves.size();
}

int FTMTree_MT::getNumberOfNodeAlone() {
int FTMTree_MT::getNumberOfNodeAlone() const {
int cpt = 0;
for(idNode i = 0; i < this->getNumberOfNodes(); ++i)
cpt += this->isNodeAlone(i) ? 1 : 0;
return cpt;
}

int FTMTree_MT::getRealNumberOfNodes() {
int FTMTree_MT::getRealNumberOfNodes() const {
return this->getNumberOfNodes() - this->getNumberOfNodeAlone();
}

void FTMTree_MT::getBranchOriginsFromThisBranch(
idNode node, std::tuple<std::vector<idNode>, std::vector<idNode>> &res) {
idNode node,
std::tuple<std::vector<idNode>, std::vector<idNode>> &res) const {
std::vector<idNode> branchOrigins, nonBranchOrigins;

idNode const nodeOrigin = this->getNode(node)->getOrigin();
Expand All @@ -166,7 +167,7 @@ namespace ttk {
void FTMTree_MT::getTreeBranching(
std::vector<idNode> &branching,
std::vector<int> &branchingID,
std::vector<std::vector<idNode>> &nodeBranching) {
std::vector<std::vector<idNode>> &nodeBranching) const {
branching = std::vector<idNode>(this->getNumberOfNodes());
branchingID = std::vector<int>(this->getNumberOfNodes(), -1);
nodeBranching
Expand Down Expand Up @@ -200,31 +201,31 @@ namespace ttk {
}

void FTMTree_MT::getTreeBranching(std::vector<idNode> &branching,
std::vector<int> &branchingID) {
std::vector<int> &branchingID) const {
std::vector<std::vector<idNode>> nodeBranching;
this->getTreeBranching(branching, branchingID, nodeBranching);
}

void FTMTree_MT::getAllRoots(std::vector<idNode> &roots) {
void FTMTree_MT::getAllRoots(std::vector<idNode> &roots) const {
roots.clear();
for(idNode node = 0; node < this->getNumberOfNodes(); ++node)
if(this->isRoot(node) and !this->isLeaf(node))
roots.push_back(node);
}

int FTMTree_MT::getNumberOfRoot() {
int FTMTree_MT::getNumberOfRoot() const {
int noRoot = 0;
for(idNode node = 0; node < this->getNumberOfNodes(); ++node)
if(this->isRoot(node) and !this->isLeaf(node))
++noRoot;
return noRoot;
}

int FTMTree_MT::getNumberOfChildren(idNode nodeId) {
int FTMTree_MT::getNumberOfChildren(idNode nodeId) const {
return this->getNode(nodeId)->getNumberOfDownSuperArcs();
}

int FTMTree_MT::getTreeDepth() {
int FTMTree_MT::getTreeDepth() const {
int maxDepth = 0;
std::queue<std::tuple<idNode, int>> queue;
queue.emplace(this->getRoot(), 0);
Expand All @@ -242,7 +243,7 @@ namespace ttk {
return maxDepth;
}

int FTMTree_MT::getNodeLevel(idNode nodeId) {
int FTMTree_MT::getNodeLevel(idNode nodeId) const {
int level = 0;
auto root = this->getRoot();
int const noRoot = this->getNumberOfRoot();
Expand All @@ -261,7 +262,7 @@ namespace ttk {
return level;
}

void FTMTree_MT::getAllNodeLevel(std::vector<int> &allNodeLevel) {
void FTMTree_MT::getAllNodeLevel(std::vector<int> &allNodeLevel) const {
allNodeLevel = std::vector<int>(this->getNumberOfNodes());
std::queue<std::tuple<idNode, int>> queue;
queue.emplace(this->getRoot(), 0);
Expand All @@ -279,7 +280,7 @@ namespace ttk {
}

void FTMTree_MT::getLevelToNode(
std::vector<std::vector<idNode>> &levelToNode) {
std::vector<std::vector<idNode>> &levelToNode) const {
std::vector<int> allNodeLevel;
this->getAllNodeLevel(allNodeLevel);
int const maxLevel
Expand All @@ -290,9 +291,10 @@ namespace ttk {
}
}

void FTMTree_MT::getBranchSubtree(std::vector<idNode> &branching,
idNode branchRoot,
std::vector<idNode> &branchSubtree) {
void
FTMTree_MT::getBranchSubtree(std::vector<idNode> &branching,
idNode branchRoot,
std::vector<idNode> &branchSubtree) const {
branchSubtree.clear();
std::queue<idNode> queue;
queue.push(branchRoot);
Expand All @@ -316,7 +318,7 @@ namespace ttk {
// Persistence
// --------------------
void FTMTree_MT::getMultiPersOriginsVectorFromTree(
std::vector<std::vector<idNode>> &treeMultiPers) {
std::vector<std::vector<idNode>> &treeMultiPers) const {
treeMultiPers
= std::vector<std::vector<idNode>>(this->getNumberOfNodes());
for(unsigned int i = 0; i < this->getNumberOfNodes(); ++i)
Expand Down Expand Up @@ -398,7 +400,7 @@ namespace ttk {
// --------------------
// Create/Delete/Modify Tree
// --------------------
void FTMTree_MT::copyMergeTreeStructure(FTMTree_MT *tree) {
void FTMTree_MT::copyMergeTreeStructure(const FTMTree_MT *tree) {
// Add Nodes
for(unsigned int i = 0; i < tree->getNumberOfNodes(); ++i)
this->makeNode(i);
Expand All @@ -418,7 +420,7 @@ namespace ttk {
// --------------------
// Utils
// --------------------
void FTMTree_MT::printNodeSS(idNode node, std::stringstream &ss) {
void FTMTree_MT::printNodeSS(idNode node, std::stringstream &ss) const {
ss << "(" << node << ") \\ ";

std::vector<idNode> children;
Expand All @@ -431,7 +433,7 @@ namespace ttk {
ss << std::endl;
}

std::stringstream FTMTree_MT::printSubTree(idNode subRoot) {
std::stringstream FTMTree_MT::printSubTree(idNode subRoot) const {
std::stringstream ss;
ss << "Nodes----------" << std::endl;
std::queue<idNode> queue;
Expand All @@ -450,7 +452,7 @@ namespace ttk {
return ss;
}

std::stringstream FTMTree_MT::printTree(bool doPrint) {
std::stringstream FTMTree_MT::printTree(bool doPrint) const {
std::stringstream ss;
std::vector<idNode> allRoots;
this->getAllRoots(allRoots);
Expand All @@ -471,7 +473,7 @@ namespace ttk {
return ss;
}

std::stringstream FTMTree_MT::printTreeStats(bool doPrint) {
std::stringstream FTMTree_MT::printTreeStats(bool doPrint) const {
auto noNodesT = this->getNumberOfNodes();
auto noNodes = this->getRealNumberOfNodes();
std::stringstream ss;
Expand All @@ -483,7 +485,7 @@ namespace ttk {
}

std::stringstream
FTMTree_MT::printMultiPersOriginsVectorFromTree(bool doPrint) {
FTMTree_MT::printMultiPersOriginsVectorFromTree(bool doPrint) const {
std::stringstream ss;
std::vector<std::vector<idNode>> vec;
this->getMultiPersOriginsVectorFromTree(vec);
Expand Down
6 changes: 3 additions & 3 deletions core/base/ftmTree/FTMTreeUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ namespace ttk {
}

template <class dataType>
void getTreeScalars(ftm::FTMTree_MT *tree,
void getTreeScalars(const ftm::FTMTree_MT *tree,
std::vector<dataType> &scalarsVector) {
scalarsVector.clear();
for(unsigned int i = 0; i < tree->getNumberOfNodes(); ++i)
Expand All @@ -162,7 +162,7 @@ namespace ttk {
}

template <class dataType>
MergeTree<dataType> copyMergeTree(ftm::FTMTree_MT *tree,
MergeTree<dataType> copyMergeTree(const ftm::FTMTree_MT *tree,
bool doSplitMultiPersPairs = false) {
std::vector<dataType> scalarsVector;
getTreeScalars<dataType>(tree, scalarsVector);
Expand Down Expand Up @@ -201,7 +201,7 @@ namespace ttk {
}

template <class dataType>
MergeTree<dataType> copyMergeTree(MergeTree<dataType> &mergeTree,
MergeTree<dataType> copyMergeTree(const MergeTree<dataType> &mergeTree,
bool doSplitMultiPersPairs = false) {
return copyMergeTree<dataType>(&(mergeTree.tree), doSplitMultiPersPairs);
}
Expand Down
Loading

0 comments on commit 10a87a2

Please # to comment.