Graph Traversal Guide
This is a practitioner’s guide to using the Execution Framework (EF). Before continuing, we recommend first reviewing the Execution Framework Overview along with basic topics such as Graph Concepts, Pass Concepts, and Execution Concepts.
Graph traversal – the systematic visitation of nodes within the IR – is an integral part of EF.
EF contains several built-in traversal functions:
traverseDepthFirst(): traverses a graph in depth-first order.
traverseBreadthFirst(): traverses a graph in breadth-first order.
traverseDepthFirstAsync(): traverses a graph in depth-first order, potentially initiating asynchronous work before visiting the next node.
traverseBreadthFirstAsync(): traverses a graph in breadth-first order, potentially initiating asynchronous work before visiting the next node.
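As a rough mental model of how the depth-first and breadth-first orderings differ, the sketch below walks a hypothetical toy graph of integer node IDs (node 0 is the root, and each node's children are listed in edge order). ToyGraph, toyDepthFirst(), and toyBreadthFirst() are names invented for this illustration, not EF APIs; the real functions operate on INode pointers, take a visit strategy such as VisitFirst, and hand the callback an info object. Here each node is reported once, on its first discovery, and the root is only ever passed as prev.

```cpp
#include <functional>
#include <queue>
#include <vector>

// Hypothetical toy graph (not the EF API): node 0 is the root and each
// node's children are stored in the order its edges should be traversed.
using ToyGraph = std::vector<std::vector<int>>;
// The visitor receives (prev, curr), mirroring the shape of EF callbacks.
using Visitor = std::function<void(int prev, int curr)>;

// Depth-first: fully explore the first child before moving on to the next.
void toyDepthFirstFrom(const ToyGraph& g, int prev, std::vector<bool>& seen, const Visitor& visit)
{
    for (int child : g[prev])
    {
        if (seen[child])
            continue; // only the first discovery of a node is reported
        seen[child] = true;
        visit(prev, child);
        toyDepthFirstFrom(g, child, seen, visit);
    }
}

void toyDepthFirst(const ToyGraph& g, const Visitor& visit)
{
    std::vector<bool> seen(g.size(), false);
    seen[0] = true; // the root is the starting point and is never reported as curr
    toyDepthFirstFrom(g, 0, seen, visit);
}

// Breadth-first: report all children of a node before any grandchildren.
void toyBreadthFirst(const ToyGraph& g, const Visitor& visit)
{
    std::vector<bool> seen(g.size(), false);
    std::queue<int> frontier;
    seen[0] = true;
    frontier.push(0);
    while (!frontier.empty())
    {
        int prev = frontier.front();
        frontier.pop();
        for (int child : g[prev])
        {
            if (!seen[child])
            {
                seen[child] = true;
                visit(prev, child);
                frontier.push(child);
            }
        }
    }
}
```

For a root with children 1 and 2 that share a child 3, the depth-first variant reports 1, 3, 2 while the breadth-first variant reports 1, 2, 3.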
The following sections examine a few code examples demonstrating how one can explore EF graphs in a customized manner using the available APIs.
Getting Started with Writing Graph Traversals
To make the concepts in these examples more concrete, some of the traversals are applied to the following sample IR graph \(G_1\) so that we can see what their output looks like in a specific case:
Note that each node’s downstream edges are ordered alphabetically by their connected child nodes; e.g. for node \(a\), the first, second, and third edges are \(\{a,b\}\), \(\{a,c\}\), and \(\{a,d\}\), respectively. Also note that all of the examples below are assumed to reside within the omni::graph::exec::unstable namespace.
Print all Node Names
Listing 26 shows how one can print out the names of all top-level nodes in a given IR graph in serial DFS order using the VisitFirst policy. Here, top-level refers to nodes that lie directly in the top-level execution graph definition; any nodes not contained in the execution graph’s NodeGraphDef (implying that they are contained within other nodes’ NodeGraphDefs) will not have their names printed by the code block below.
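// Collect (and print) every top-level node reachable from the root.
std::vector<INode*> nodes;
traverseDepthFirst<VisitFirst>(myGraph->getRoot(),
    [&nodes](auto info, INode* prev, INode* curr)
    {
        // With VisitFirst, this callback runs once per node, when it is first discovered.
        std::cout << curr->getName() << std::endl;
        nodes.emplace_back(curr);
        // Keep traversing along curr's downstream edges.
        info.continueVisit(curr);
    });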
If we applied the above code-block to \(G_1\), we would get the following ordered list of visited node names:
Note that the root node \(a\) does not appear in the output. The traversal starts at \(a\), so during the very first traversal step prev points to \(a\) (and curr to its first child); since the callback only prints curr, \(a\) itself is never printed.
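Both behaviors described above (nodes inside nested graph definitions are skipped, and the root only ever appears as prev) can be reproduced with a minimal sketch on a hypothetical toy IR. ToyNode, ToyGraphDef, and collectTopLevelNames() are illustrative stand-ins, not EF types: the nested member loosely plays the role of a NodeGraphDef, and the recursion mimics a serial, VisitFirst-like DFS.

```cpp
#include <memory>
#include <string>
#include <vector>

// Hypothetical toy IR (not the EF API): a graph owns its nodes, nodes[0] is
// the root, and a node may own a nested graph (standing in for a NodeGraphDef).
struct ToyGraphDef;

struct ToyNode
{
    std::string name;
    std::vector<int> children;           // indices into the owning graph's nodes
    std::shared_ptr<ToyGraphDef> nested; // nullptr if the node has no nested graph
};

struct ToyGraphDef
{
    std::vector<ToyNode> nodes; // nodes[0] is the root
};

// Collect top-level node names in DFS order, visiting each node only on its
// first discovery (VisitFirst-like). Nested graphs are deliberately ignored,
// and the root is only ever a "prev", so its name is never collected.
void collectTopLevelNames(const ToyGraphDef& graph, int prev, std::vector<bool>& seen,
                          std::vector<std::string>& out)
{
    for (int curr : graph.nodes[prev].children)
    {
        if (seen[curr])
            continue;
        seen[curr] = true;
        out.push_back(graph.nodes[curr].name);
        collectTopLevelNames(graph, curr, seen, out);
    }
}

std::vector<std::string> collectTopLevelNames(const ToyGraphDef& graph)
{
    std::vector<std::string> out;
    std::vector<bool> seen(graph.nodes.size(), false);
    seen[0] = true;
    collectTopLevelNames(graph, 0, seen, out);
    return out;
}
```

For a graph a → {b, c} in which b owns a nested graph h → i, the sketch returns only b and c.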
Print all Node Traversal Paths Recursively
Listing 27 shows how one can recursively print the traversal paths (i.e. the list of upstream nodes visited before reaching the current node) of all nodes in a given IR graph in serial DFS order using the VisitFirst strategy. This includes all nodes that lie within non-execution graph definitions (i.e. inside other nodes’ NodeGraphDefs that are nested within the execution graph definition), hence the need for recursion. The resulting list of nodes can be referred to as the member nodes of the flattened IR.
auto traversalFn =
[](INodeGraphDef* nodeGraphDef, INode* topLevelGraphRoot, std::vector<INode*>& currentTraversalPath,
std::vector<std::pair<INode*, std::vector<INode*>>>& nodeTraversalPaths, auto& recursionFn) -> void
{
traverseDepthFirst<VisitFirst>(
nodeGraphDef->getRoot(),
[nodeGraphDef, topLevelGraphRoot, &currentTraversalPath, &nodeTraversalPaths, &recursionFn](
auto info, INode* prev, INode* curr)
{
// Remove node elements from the current path until we get back to a common
// branching point for the current node.
if (prev == topLevelGraphRoot)
{
currentTraversalPath.clear();
}
else if (!prev->isRoot())
{
while (!currentTraversalPath.empty() && currentTraversalPath.back()->getName() != prev->getName())
{
currentTraversalPath.pop_back();
}
}
// Add the node to the current traversal path. If the previous node was also a
// graph root node, add it as well.
if (prev->isRoot())
{
currentTraversalPath.emplace_back(prev);
}
currentTraversalPath.emplace_back(curr);
// Store the current node's corresponding traversal path.
nodeTraversalPaths.emplace_back(
std::piecewise_construct, std::forward_as_tuple(curr), std::forward_as_tuple(currentTraversalPath));
// Continue the traversal.
INodeGraphDef* currNodeGraphDef = curr->getNodeGraphDef();
if (currNodeGraphDef)
{
recursionFn(
currNodeGraphDef, topLevelGraphRoot, currentTraversalPath, nodeTraversalPaths, recursionFn);
}
info.continueVisit(curr);
});
};
std::vector<INode*> currentTraversalPath;
std::vector<std::pair<INode*, std::vector<INode*>>> nodeTraversalPaths;
traversalFn(myGraph->getNodeGraphDef(), myGraph->getNodeGraphDef()->getRoot(), currentTraversalPath,
nodeTraversalPaths, traversalFn);
// Print the results. Note that nodeTraversalPaths will be ordered in a serial, DFS, VisitFirst manner, with the
// members of each nested NodeGraphDef listed right after the node that owns it.
for (const std::pair<INode*, std::vector<INode*>>& namePathPair : nodeTraversalPaths)
{
// Print the node's name.
std::cout << namePathPair.first->getName() << ": ";
// Print the node's traversal path.
for (INode* const pathElement : namePathPair.second)
{
std::cout << pathElement->getName() << "/";
}
std::cout << std::endl;
}
Applying this logic to \(G_1\), the list of node traversal paths (paired with each node’s name for clarity, and ordered by when each node was visited) would look something like this:
\(b: a/b\)
\(e: a/b/e\)
\(i: a/b/e/h/i\)
\(j: a/b/e/h/i/j\)
\(g: a/b/e/g\)
\(c: a/c\)
\(f: a/c/f\)
\(l: a/c/f/k/l\)
\(m: a/c/f/k/l/m\)
\(i: a/c/f/k/l/m/h/i\)
\(j: a/c/f/k/l/m/h/i/j\)
\(d: a/c/f/d\)
Note
EF typically uses a more space-efficient path representation called the ExecutionPath when discussing nodal paths; the above example prints the explicit traversal path to highlight how the graph is crawled.
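The path bookkeeping in Listing 27 rewinds a single shared vector, which can be hard to follow. A minimal alternative sketch, again on a hypothetical toy IR (ToyNode, ToyGraphDef, and collectPaths() are illustrative, not EF APIs), passes each path down by value instead. It yields the same kind of name/path pairs: nested nodes are listed right after the node that owns them, and the nested graph's root is included in their paths.

```cpp
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical toy IR (not the EF API); nodes[0] of each graph is its root.
struct ToyGraphDef;
struct ToyNode
{
    std::string name;
    std::vector<int> children;
    std::shared_ptr<ToyGraphDef> nested; // stands in for a nested NodeGraphDef
};
struct ToyGraphDef
{
    std::vector<ToyNode> nodes;
};

using PathList = std::vector<std::pair<std::string, std::string>>; // (name, path)

// Serial DFS over one graph; `path` already ends with the name of `prev`.
void collectPaths(const ToyGraphDef& graph, int prev, const std::string& path,
                  std::vector<bool>& seen, PathList& out)
{
    for (int curr : graph.nodes[prev].children)
    {
        if (seen[curr])
            continue; // VisitFirst-like: only the first discovery counts
        seen[curr] = true;
        const ToyNode& node = graph.nodes[curr];
        const std::string currPath = path + "/" + node.name;
        out.emplace_back(node.name, currPath);
        if (node.nested) // recurse into the nested graph first...
        {
            // Each nested instance gets its own visited set.
            std::vector<bool> nestedSeen(node.nested->nodes.size(), false);
            nestedSeen[0] = true;
            collectPaths(*node.nested, 0, currPath + "/" + node.nested->nodes[0].name, nestedSeen, out);
        }
        collectPaths(graph, curr, currPath, seen, out); // ...then continue along curr's edges
    }
}

PathList collectPaths(const ToyGraphDef& top)
{
    PathList out;
    std::vector<bool> seen(top.nodes.size(), false);
    seen[0] = true;
    collectPaths(top, 0, top.nodes[0].name, seen, out);
    return out;
}
```

Because each nested graph gets a fresh visited set, a definition instanced under several nodes (as with \(h\), \(i\), and \(j\) in \(G_1\)) is walked once per instance. Copying strings trades some memory for simpler control flow compared with rewinding a shared path in place.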
Print all Edges Recursively
Listing 28 uses the VisitAll strategy to recursively store and print out all edges in an IR graph in serial BFS order. Note that the choice of serial BFS is arbitrary (other search algorithms would still print all of the edges, albeit in a different order); only the selection of VisitAll matters, since it is what allows every edge to be explored. Also note that traversal only continues from a node the first time it is reached, i.e. along the first discovered edge (similar to the VisitFirst policy).
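To see why VisitAll is the strategy that exposes every edge, consider this minimal sketch on a hypothetical adjacency-list graph (collectAllEdges() is an illustrative name, not an EF API). Every edge is recorded, but the walk only continues out of a node the first time it is reached, just as Listing 28 only calls continueVisit() inside the isFirstVisit() branch.

```cpp
#include <queue>
#include <utility>
#include <vector>

using ToyGraph = std::vector<std::vector<int>>; // hypothetical; node 0 is the root

// VisitAll-like serial BFS: every edge (prev, curr) is reported, but traversal
// only continues out of `curr` the first time it is reached.
std::vector<std::pair<int, int>> collectAllEdges(const ToyGraph& g)
{
    std::vector<std::pair<int, int>> edges;
    std::vector<bool> seen(g.size(), false);
    std::queue<int> frontier;
    seen[0] = true;
    frontier.push(0);
    while (!frontier.empty())
    {
        int prev = frontier.front();
        frontier.pop();
        for (int curr : g[prev])
        {
            edges.emplace_back(prev, curr); // recorded on every visit
            if (!seen[curr])                // analogous to info.isFirstVisit()
            {
                seen[curr] = true;
                frontier.push(curr);        // analogous to info.continueVisit(curr)
            }
        }
    }
    return edges;
}
```

On a graph where nodes 1 and 2 both point to node 3, all four edges are recorded even though node 3 is expanded only once.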
std::vector<std::pair<INode*, INode*>> edges;
auto traversalFn = [&edges](INodeGraphDef* nodeGraphDef, auto& recursionFn) -> void
{
traverseBreadthFirst<VisitAll>(nodeGraphDef->getRoot(),
[&edges, nodeGraphDef, &recursionFn](auto info, INode* prev, INode* curr)
{
std::cout << "{" << prev->getName() << ", " << curr->getName() << "}"
<< std::endl;
edges.emplace_back(prev, curr);
if (info.isFirstVisit())
{
INodeGraphDef* currNodeGraphDef = curr->getNodeGraphDef();
if (currNodeGraphDef)
{
recursionFn(currNodeGraphDef, recursionFn);
}
info.continueVisit(curr);
}
});
};
traversalFn(myGraph->getNodeGraphDef(), traversalFn);
Running this traversal on \(G_1\) would produce the following list of edges (in the order that they are visited):
Note that for node instances that share the same definition (e.g. \(i\), \(j\), etc.), we’ve used their full traversal paths for clarity.
Print all Node Names Recursively in Topological Order
Listing 29 highlights how one can recursively print out all node names in topological order using the VisitLast strategy, meaning that no node is visited until all of its parents have been visited. Note that any traversal, whether it be a serial DFS, serial BFS, parallel DFS, parallel BFS, or something else entirely, can be considered topological as long as it employs the VisitLast strategy; this example uses a serial DFS.
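One way to picture the VisitLast strategy is a per-node counter of parents that have not yet been visited: each incoming edge decrements it, and the node is only visited (and traversed from) once the counter reaches zero. The following sketch on a hypothetical adjacency-list graph implements that idea with a serial DFS. visitLast() and topologicalOrder() are illustrative names, not EF APIs, and the sketch assumes an acyclic graph whose nodes are all reachable from the root (a node on a cycle would never reach zero and would never be visited).

```cpp
#include <vector>

using ToyGraph = std::vector<std::vector<int>>; // hypothetical; node 0 is the root

// VisitLast-like serial DFS: a node is visited only when the edge from its
// last remaining unvisited parent is traversed.
void visitLast(const ToyGraph& g, int prev, std::vector<int>& pendingParents, std::vector<int>& order)
{
    for (int curr : g[prev])
    {
        if (--pendingParents[curr] == 0) // all parents of curr have now been visited
        {
            order.push_back(curr);
            visitLast(g, curr, pendingParents, order);
        }
    }
}

std::vector<int> topologicalOrder(const ToyGraph& g)
{
    // Count the incoming edges (parents) of every node.
    std::vector<int> pendingParents(g.size(), 0);
    for (const auto& children : g)
        for (int child : children)
            ++pendingParents[child];

    std::vector<int> order;
    visitLast(g, 0, pendingParents, order);
    return order;
}
```

For a diamond-shaped graph (1 and 2 both point to 3), a VisitFirst-style DFS would report 1, 3, 2, whereas this sketch reports 1, 2, 3 because node 3 waits for its second parent.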
std::vector<INode*> nodes;
auto traversalFn = [&nodes](INodeGraphDef* nodeGraphDef, auto& recursionFn) -> void
{
traverseDepthFirst<VisitLast>(nodeGraphDef->getRoot(),
[&nodes, nodeGraphDef, &recursionFn](auto info, INode* prev, INode* curr)
{
std::cout << curr->getName() << std::endl;
nodes.emplace_back(curr);
INodeGraphDef* currNodeGraphDef = curr->getNodeGraphDef();
if (currNodeGraphDef)
{
recursionFn(currNodeGraphDef, recursionFn);
}
info.continueVisit(curr);
});
};
traversalFn(myGraph->getNodeGraphDef(), traversalFn);
Again, in the case of \(G_1\), we would obtain the following ordered node name list:
Using Custom NodeUserData
Listing 30 showcases how one can pass custom node data into the traversal methods to tackle problems that would otherwise be much more inconvenient (or downright impossible) to solve without that flexibility. In this case, the SCC_NodeData struct stores the per-node information needed to implement Tarjan’s algorithm for strongly connected components, which is what ultimately allows us to create the global graph transformation pass responsible for detecting cycles in the graph.
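For reference, here is a compact, textbook-style version of Tarjan’s algorithm on a hypothetical adjacency-list graph, keeping its per-node scratch data in a side table in the same spirit as SCC_NodeData. ToyTarjan and ToySccData are illustrative names, not EF APIs. Unlike Listing 30, it uses the classic recursive formulation (the parent’s lowLink is updated after the recursive call returns) and does not track cycleParentCount; a component with more than one node (or a node with an edge to itself) indicates a cycle.

```cpp
#include <algorithm>
#include <cstddef>
#include <stack>
#include <vector>

using ToyGraph = std::vector<std::vector<int>>; // hypothetical adjacency list

// Per-node scratch data kept in a side table indexed by node ID.
struct ToySccData
{
    std::size_t index = 0;
    std::size_t lowLink = 0;
    bool visited = false;
    bool onStack = false;
};

class ToyTarjan
{
public:
    explicit ToyTarjan(const ToyGraph& g) : _g(g), _data(g.size()) {}

    // Returns one vector of node IDs per strongly connected component.
    std::vector<std::vector<int>> run()
    {
        for (int v = 0; v < static_cast<int>(_g.size()); ++v)
            if (!_data[v].visited)
                _strongConnect(v);
        return _components;
    }

private:
    void _strongConnect(int v)
    {
        ToySccData& d = _data[v]; // safe: _data is sized once and never reallocates
        d.visited = true;
        d.index = d.lowLink = _nextIndex++;
        _stack.push(v);
        d.onStack = true;

        for (int w : _g[v])
        {
            if (!_data[w].visited)
            {
                _strongConnect(w);
                d.lowLink = std::min(d.lowLink, _data[w].lowLink);
            }
            else if (_data[w].onStack)
            {
                d.lowLink = std::min(d.lowLink, _data[w].index);
            }
        }

        if (d.lowLink == d.index) // v is the root of a component: pop its members
        {
            std::vector<int> component;
            int w;
            do
            {
                w = _stack.top();
                _stack.pop();
                _data[w].onStack = false;
                component.push_back(w);
            } while (w != v);
            _components.push_back(component);
        }
    }

    const ToyGraph& _g;
    std::vector<ToySccData> _data;
    std::stack<int> _stack;
    std::size_t _nextIndex = 0;
    std::vector<std::vector<int>> _components;
};
```

For the graph 0 → 1 → 2 → 0 with an extra edge 2 → 3, the sketch reports {3} first and then the three-node cycle, since Tarjan’s algorithm emits components in reverse topological order.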
class PassStronglyConnectedComponents : public Implements<IGlobalPass>
{
public:
static omni::core::ObjectPtr<PassStronglyConnectedComponents> create(
omni::core::ObjectParam<exec::unstable::IGraphBuilder> builder)
{
return omni::core::steal(new PassStronglyConnectedComponents(builder.get()));
}
protected:
PassStronglyConnectedComponents(IGraphBuilder*)
{
}
void run_abi(IGraphBuilder* builder) noexcept override
{
_detectCycles(builder, builder->getTopology());
}
private:
void _detectCycles(IGraphBuilder* builder, ITopology* topology)
{
struct SCC_NodeData
{
size_t index{ 0 };
size_t lowLink{ 0 };
uint32_t cycleParentCount{ 0 };
bool onStack{ false };
};
size_t globalIndex = 0;
std::stack<INode*> globalStack;
traverseDepthFirst<VisitAll, SCC_NodeData>(
topology->getRoot(),
[this, builder, &globalIndex, &globalStack](auto info, INode* prev, INode* curr)
{
auto pushStack = [&globalStack](INode* node, SCC_NodeData& data)
{
data.onStack = true;
globalStack.push(node);
};
auto popStack = [builder, &info, &globalStack]()
{
auto* top = globalStack.top();
globalStack.pop();
auto& userData = info.userData(top);
userData.onStack = false;
auto node = exec::unstable::cast<exec::unstable::IGraphBuilderNode>(top);
node->setCycleParentCount(userData.cycleParentCount);
return top;
};
auto& userData = info.userData(curr);
auto& userDataPrev = info.userData(prev);
if (info.isFirstVisit())
{
userData.index = userData.lowLink = globalIndex++;
pushStack(curr, userData);
info.continueVisit(curr);
userDataPrev.lowLink = std::min(userDataPrev.lowLink, userData.lowLink);
if (userData.lowLink == userData.index)
{
auto* top = popStack();
if (top != curr)
{
while (top != curr)
{
top = popStack();
}
}
}
auto nodeGraph = curr->getNodeGraphDef();
if (nodeGraph)
{
this->_detectCycles(builder, nodeGraph->getTopology());
}
}
else if (userData.onStack)
{
userDataPrev.lowLink = std::min(userDataPrev.lowLink, userData.index);
userData.cycleParentCount++;
}
});
}
};
Next Steps
To learn more about graph traversals in the context of EF, see Graph Traversal In-Depth.