Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

[MTNN] merge tree neural network refactor #1065

Open
wants to merge 2 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions core/base/ftmTree/FTMNode.h
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,22 @@ namespace ttk {
}
}

inline void removeDownSuperArcs(std::vector<idSuperArc> &idSa) {
if(idSa.empty())
return;
std::vector<bool> toDelete(
(*std::max_element(idSa.begin(), idSa.end())) + 1, false);
for(auto &id : idSa)
toDelete[id] = true;
vect_downSuperArcList_.erase(
std::remove_if(vect_downSuperArcList_.begin(),
vect_downSuperArcList_.end(),
[&toDelete](const idSuperArc &i) {
return i < toDelete.size() and toDelete[i];
}),
vect_downSuperArcList_.end());
}

// Find and remove the arc
inline void removeUpSuperArc(idSuperArc idSa) {
for(idSuperArc i = 0; i < vect_upSuperArcList_.size(); ++i) {
Expand Down
78 changes: 40 additions & 38 deletions core/base/ftmTree/FTMTreeUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,34 +13,34 @@ namespace ttk {
// --------------------
// Is
// --------------------
bool FTMTree_MT::isNodeOriginDefined(idNode nodeId) {
bool FTMTree_MT::isNodeOriginDefined(idNode nodeId) const {
unsigned int const origin
= (unsigned int)this->getNode(nodeId)->getOrigin();
return origin != nullNodes && origin < this->getNumberOfNodes();
}

bool FTMTree_MT::isRoot(idNode nodeId) {
bool FTMTree_MT::isRoot(idNode nodeId) const {
return this->getNode(nodeId)->getNumberOfUpSuperArcs() == 0;
}

bool FTMTree_MT::isLeaf(idNode nodeId) {
bool FTMTree_MT::isLeaf(idNode nodeId) const {
return this->getNode(nodeId)->getNumberOfDownSuperArcs() == 0;
}

bool FTMTree_MT::isNodeAlone(idNode nodeId) {
bool FTMTree_MT::isNodeAlone(idNode nodeId) const {
return this->isRoot(nodeId) and this->isLeaf(nodeId);
}

bool FTMTree_MT::isFullMerge() {
bool FTMTree_MT::isFullMerge() const {
idNode const treeRoot = this->getRoot();
return (unsigned int)this->getNode(treeRoot)->getOrigin() == treeRoot;
}

bool FTMTree_MT::isBranchOrigin(idNode nodeId) {
bool FTMTree_MT::isBranchOrigin(idNode nodeId) const {
return this->getParentSafe(this->getNode(nodeId)->getOrigin()) != nodeId;
}

bool FTMTree_MT::isNodeMerged(idNode nodeId) {
bool FTMTree_MT::isNodeMerged(idNode nodeId) const {
bool merged = this->isNodeAlone(nodeId)
or this->isNodeAlone(this->getNode(nodeId)->getOrigin());
auto nodeIdOrigin = this->getNode(nodeId)->getOrigin();
Expand All @@ -49,11 +49,11 @@ namespace ttk {
return merged;
}

bool FTMTree_MT::isNodeIdInconsistent(idNode nodeId) {
bool FTMTree_MT::isNodeIdInconsistent(idNode nodeId) const {
return nodeId >= this->getNumberOfNodes();
}

bool FTMTree_MT::isThereOnlyOnePersistencePair() {
bool FTMTree_MT::isThereOnlyOnePersistencePair() const {
idNode const treeRoot = this->getRoot();
unsigned int cptNodeAlone = 0;
idNode otherNode = treeRoot;
Expand All @@ -74,7 +74,7 @@ namespace ttk {
}

// Do not normalize node is if root or son of a merged root
bool FTMTree_MT::notNeedToNormalize(idNode nodeId) {
bool FTMTree_MT::notNeedToNormalize(idNode nodeId) const {
auto nodeIdParent = this->getParentSafe(nodeId);
return this->isRoot(nodeId)
or (this->isRoot(nodeIdParent)
Expand All @@ -84,7 +84,7 @@ namespace ttk {
// and nodeIdOrigin == nodeIdParent) )
}

bool FTMTree_MT::isMultiPersPair(idNode nodeId) {
bool FTMTree_MT::isMultiPersPair(idNode nodeId) const {
auto nodeOriginOrigin
= (unsigned int)this->getNode(this->getNode(nodeId)->getOrigin())
->getOrigin();
Expand All @@ -94,14 +94,14 @@ namespace ttk {
// --------------------
// Get
// --------------------
idNode FTMTree_MT::getRoot() {
idNode FTMTree_MT::getRoot() const {
for(idNode node = 0; node < this->getNumberOfNodes(); ++node)
if(this->isRoot(node) and !this->isLeaf(node))
return node;
return nullNodes;
}

idNode FTMTree_MT::getParentSafe(idNode nodeId) {
idNode FTMTree_MT::getParentSafe(idNode nodeId) const {
if(!this->isRoot(nodeId)) {
// _ Nodes in merge trees should have only one parent
idSuperArc const arcId = this->getNode(nodeId)->getUpSuperArcId(0);
Expand All @@ -112,7 +112,7 @@ namespace ttk {
}

void FTMTree_MT::getChildren(idNode nodeId,
std::vector<idNode> &childrens) {
std::vector<idNode> &childrens) const {
childrens.clear();
for(idSuperArc i = 0;
i < this->getNode(nodeId)->getNumberOfDownSuperArcs(); ++i) {
Expand All @@ -121,33 +121,34 @@ namespace ttk {
}
}

void FTMTree_MT::getLeavesFromTree(std::vector<idNode> &treeLeaves) {
void FTMTree_MT::getLeavesFromTree(std::vector<idNode> &treeLeaves) const {
treeLeaves.clear();
for(idNode i = 0; i < this->getNumberOfNodes(); ++i) {
if(this->isLeaf(i) and !this->isRoot(i))
treeLeaves.push_back(i);
}
}

int FTMTree_MT::getNumberOfLeavesFromTree() {
int FTMTree_MT::getNumberOfLeavesFromTree() const {
std::vector<idNode> leaves;
this->getLeavesFromTree(leaves);
return leaves.size();
}

int FTMTree_MT::getNumberOfNodeAlone() {
int FTMTree_MT::getNumberOfNodeAlone() const {
int cpt = 0;
for(idNode i = 0; i < this->getNumberOfNodes(); ++i)
cpt += this->isNodeAlone(i) ? 1 : 0;
return cpt;
}

int FTMTree_MT::getRealNumberOfNodes() {
int FTMTree_MT::getRealNumberOfNodes() const {
return this->getNumberOfNodes() - this->getNumberOfNodeAlone();
}

void FTMTree_MT::getBranchOriginsFromThisBranch(
idNode node, std::tuple<std::vector<idNode>, std::vector<idNode>> &res) {
idNode node,
std::tuple<std::vector<idNode>, std::vector<idNode>> &res) const {
std::vector<idNode> branchOrigins, nonBranchOrigins;

idNode const nodeOrigin = this->getNode(node)->getOrigin();
Expand All @@ -166,7 +167,7 @@ namespace ttk {
void FTMTree_MT::getTreeBranching(
std::vector<idNode> &branching,
std::vector<int> &branchingID,
std::vector<std::vector<idNode>> &nodeBranching) {
std::vector<std::vector<idNode>> &nodeBranching) const {
branching = std::vector<idNode>(this->getNumberOfNodes());
branchingID = std::vector<int>(this->getNumberOfNodes(), -1);
nodeBranching
Expand Down Expand Up @@ -200,31 +201,31 @@ namespace ttk {
}

void FTMTree_MT::getTreeBranching(std::vector<idNode> &branching,
std::vector<int> &branchingID) {
std::vector<int> &branchingID) const {
std::vector<std::vector<idNode>> nodeBranching;
this->getTreeBranching(branching, branchingID, nodeBranching);
}

void FTMTree_MT::getAllRoots(std::vector<idNode> &roots) {
void FTMTree_MT::getAllRoots(std::vector<idNode> &roots) const {
roots.clear();
for(idNode node = 0; node < this->getNumberOfNodes(); ++node)
if(this->isRoot(node) and !this->isLeaf(node))
roots.push_back(node);
}

int FTMTree_MT::getNumberOfRoot() {
int FTMTree_MT::getNumberOfRoot() const {
int noRoot = 0;
for(idNode node = 0; node < this->getNumberOfNodes(); ++node)
if(this->isRoot(node) and !this->isLeaf(node))
++noRoot;
return noRoot;
}

int FTMTree_MT::getNumberOfChildren(idNode nodeId) {
int FTMTree_MT::getNumberOfChildren(idNode nodeId) const {
return this->getNode(nodeId)->getNumberOfDownSuperArcs();
}

int FTMTree_MT::getTreeDepth() {
int FTMTree_MT::getTreeDepth() const {
int maxDepth = 0;
std::queue<std::tuple<idNode, int>> queue;
queue.emplace(this->getRoot(), 0);
Expand All @@ -242,7 +243,7 @@ namespace ttk {
return maxDepth;
}

int FTMTree_MT::getNodeLevel(idNode nodeId) {
int FTMTree_MT::getNodeLevel(idNode nodeId) const {
int level = 0;
auto root = this->getRoot();
int const noRoot = this->getNumberOfRoot();
Expand All @@ -261,7 +262,7 @@ namespace ttk {
return level;
}

void FTMTree_MT::getAllNodeLevel(std::vector<int> &allNodeLevel) {
void FTMTree_MT::getAllNodeLevel(std::vector<int> &allNodeLevel) const {
allNodeLevel = std::vector<int>(this->getNumberOfNodes());
std::queue<std::tuple<idNode, int>> queue;
queue.emplace(this->getRoot(), 0);
Expand All @@ -279,7 +280,7 @@ namespace ttk {
}

void FTMTree_MT::getLevelToNode(
std::vector<std::vector<idNode>> &levelToNode) {
std::vector<std::vector<idNode>> &levelToNode) const {
std::vector<int> allNodeLevel;
this->getAllNodeLevel(allNodeLevel);
int const maxLevel
Expand All @@ -290,9 +291,10 @@ namespace ttk {
}
}

void FTMTree_MT::getBranchSubtree(std::vector<idNode> &branching,
idNode branchRoot,
std::vector<idNode> &branchSubtree) {
void
FTMTree_MT::getBranchSubtree(std::vector<idNode> &branching,
idNode branchRoot,
std::vector<idNode> &branchSubtree) const {
branchSubtree.clear();
std::queue<idNode> queue;
queue.push(branchRoot);
Expand All @@ -316,7 +318,7 @@ namespace ttk {
// Persistence
// --------------------
void FTMTree_MT::getMultiPersOriginsVectorFromTree(
std::vector<std::vector<idNode>> &treeMultiPers) {
std::vector<std::vector<idNode>> &treeMultiPers) const {
treeMultiPers
= std::vector<std::vector<idNode>>(this->getNumberOfNodes());
for(unsigned int i = 0; i < this->getNumberOfNodes(); ++i)
Expand Down Expand Up @@ -398,7 +400,7 @@ namespace ttk {
// --------------------
// Create/Delete/Modify Tree
// --------------------
void FTMTree_MT::copyMergeTreeStructure(FTMTree_MT *tree) {
void FTMTree_MT::copyMergeTreeStructure(const FTMTree_MT *tree) {
// Add Nodes
for(unsigned int i = 0; i < tree->getNumberOfNodes(); ++i)
this->makeNode(i);
Expand All @@ -418,7 +420,7 @@ namespace ttk {
// --------------------
// Utils
// --------------------
void FTMTree_MT::printNodeSS(idNode node, std::stringstream &ss) {
void FTMTree_MT::printNodeSS(idNode node, std::stringstream &ss) const {
ss << "(" << node << ") \\ ";

std::vector<idNode> children;
Expand All @@ -431,7 +433,7 @@ namespace ttk {
ss << std::endl;
}

std::stringstream FTMTree_MT::printSubTree(idNode subRoot) {
std::stringstream FTMTree_MT::printSubTree(idNode subRoot) const {
std::stringstream ss;
ss << "Nodes----------" << std::endl;
std::queue<idNode> queue;
Expand All @@ -450,7 +452,7 @@ namespace ttk {
return ss;
}

std::stringstream FTMTree_MT::printTree(bool doPrint) {
std::stringstream FTMTree_MT::printTree(bool doPrint) const {
std::stringstream ss;
std::vector<idNode> allRoots;
this->getAllRoots(allRoots);
Expand All @@ -471,7 +473,7 @@ namespace ttk {
return ss;
}

std::stringstream FTMTree_MT::printTreeStats(bool doPrint) {
std::stringstream FTMTree_MT::printTreeStats(bool doPrint) const {
auto noNodesT = this->getNumberOfNodes();
auto noNodes = this->getRealNumberOfNodes();
std::stringstream ss;
Expand All @@ -483,7 +485,7 @@ namespace ttk {
}

std::stringstream
FTMTree_MT::printMultiPersOriginsVectorFromTree(bool doPrint) {
FTMTree_MT::printMultiPersOriginsVectorFromTree(bool doPrint) const {
std::stringstream ss;
std::vector<std::vector<idNode>> vec;
this->getMultiPersOriginsVectorFromTree(vec);
Expand Down
6 changes: 3 additions & 3 deletions core/base/ftmTree/FTMTreeUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ namespace ttk {
}

template <class dataType>
void getTreeScalars(ftm::FTMTree_MT *tree,
void getTreeScalars(const ftm::FTMTree_MT *tree,
std::vector<dataType> &scalarsVector) {
scalarsVector.clear();
for(unsigned int i = 0; i < tree->getNumberOfNodes(); ++i)
Expand All @@ -162,7 +162,7 @@ namespace ttk {
}

template <class dataType>
MergeTree<dataType> copyMergeTree(ftm::FTMTree_MT *tree,
MergeTree<dataType> copyMergeTree(const ftm::FTMTree_MT *tree,
bool doSplitMultiPersPairs = false) {
std::vector<dataType> scalarsVector;
getTreeScalars<dataType>(tree, scalarsVector);
Expand Down Expand Up @@ -201,7 +201,7 @@ namespace ttk {
}

template <class dataType>
MergeTree<dataType> copyMergeTree(MergeTree<dataType> &mergeTree,
MergeTree<dataType> copyMergeTree(const MergeTree<dataType> &mergeTree,
bool doSplitMultiPersPairs = false) {
return copyMergeTree<dataType>(&(mergeTree.tree), doSplitMultiPersPairs);
}
Expand Down
Loading
Loading