Graph Traversal Guide

This is a practitioner’s guide to using the Execution Framework. Before continuing, we recommend first reviewing the Execution Framework Overview along with basic topics such as Graph Concepts, Pass Concepts, and Execution Concepts.

Graph traversal, the systematic visitation of nodes within the intermediate representation (IR), is an integral part of the Execution Framework (EF).
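To make "systematic visitation" concrete outside of EF, here is a minimal, framework-independent sketch of depth-first traversal in which every reachable node is visited exactly once. The names Graph and depthFirst, and the integer-indexed adjacency list, are hypothetical illustrations and not part of the EF API.

```cpp
#include <cassert>
#include <functional>
#include <vector>

// Hypothetical adjacency-list graph: children[n] lists the children of node n,
// in the order their edges should be followed.
using Graph = std::vector<std::vector<int>>;

// Visit every node reachable from `root` once, in depth-first (pre-order) order.
// `visit(prev, curr)` is called the first time `curr` is reached, with `prev`
// being the parent it was reached from (-1 for the root itself).
void depthFirst(const Graph& g, int root, const std::function<void(int, int)>& visit)
{
    std::vector<bool> seen(g.size(), false);
    std::function<void(int, int)> recurse = [&](int prev, int curr) {
        if (seen[curr])
            return; // already reached through another edge
        seen[curr] = true;
        visit(prev, curr);
        for (int child : g[curr])
            recurse(curr, child);
    };
    recurse(-1, root);
}
```

For a diamond-shaped graph {0→1, 0→2, 1→3, 2→3}, the visit order is 0, 1, 3, 2: node 3 is reached first through node 1 and skipped when the traversal later arrives from node 2.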

EF contains several built-in traversal functions, including traverseDepthFirst, which Listing 30 uses to implement cycle detection.

The following sections present code examples demonstrating how to explore EF graphs in a customized manner using the available APIs.

Getting Started with Writing Graph Traversals

To make the concepts in these examples more concrete, some of the traversals are applied to the following sample IR graph \(G_1\), showing the output they produce for a specific case:


Figure 19 An example IR graph \(G_1\).

Note that each node’s downstream edges are ordered alphabetically by the child node they connect to; for example, node \(a\)’s first, second, and third edges are \(\{a,b\}\), \(\{a,c\}\), and \(\{a,d\}\), respectively. Also note that all of the examples below are assumed to reside within the omni::graph::exec::unstable namespace.
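One simple way to reproduce this edge-ordering convention in a standalone experiment is to store each node’s children in a sorted container. The sketch below assumes a hypothetical Adjacency type built on std::map and std::set; it is not how EF stores edges. Only the edges leaving \(a\) come from Figure 19.

```cpp
#include <cassert>
#include <map>
#include <set>
#include <utility>
#include <vector>

// Hypothetical adjacency representation: std::set keeps each node's children
// sorted, so iterating it yields edges in alphabetical order of the child.
using Adjacency = std::map<char, std::set<char>>;

// Return the downstream edges of `node` as (parent, child) pairs, in order.
std::vector<std::pair<char, char>> orderedEdges(const Adjacency& adj, char node)
{
    std::vector<std::pair<char, char>> edges;
    auto it = adj.find(node);
    if (it == adj.end())
        return edges; // node has no children
    for (char child : it->second)
        edges.emplace_back(node, child);
    return edges;
}
```

Even if node \(a\)’s children are inserted as d, b, c, orderedEdges returns \(\{a,b\}\), \(\{a,c\}\), \(\{a,d\}\) in that order, matching the convention used for \(G_1\).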

Using Custom NodeUserData

Listing 30 shows how custom per-node data can be passed into the traversal methods, which makes some problems far easier to solve (or solvable at all) than they would be if the API lacked that flexibility. Here, the SCC_NodeData struct stores the per-node state that Tarjan’s algorithm for strongly connected components requires: each node’s discovery index, its low-link value, whether it is currently on the stack, and a count of parents whose edges close a cycle through it. This is what ultimately allows us to create the global graph transformation pass responsible for detecting cycles in the graph.
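The idea of traversal-owned user data can be illustrated independently of EF. In this sketch, a hypothetical traverseWithData function (loosely mirroring the traverseDepthFirst<VisitAll, SCC_NodeData> call, but not its real signature) keeps one default-constructed NodeData per node and hands the parent’s and child’s data to the callback on every edge. The DepthData payload is likewise a made-up example.

```cpp
#include <cassert>
#include <functional>
#include <vector>

using Graph = std::vector<std::vector<int>>;

// Hypothetical traversal that owns one NodeData per node. The callback runs
// for every edge (prev -> curr); the traversal only descends into curr on
// its first visit.
template <typename NodeData, typename Fn>
std::vector<NodeData> traverseWithData(const Graph& g, int root, Fn&& onEdge)
{
    std::vector<NodeData> data(g.size());
    std::vector<bool> seen(g.size(), false);
    std::function<void(int)> descend = [&](int prev) {
        for (int curr : g[prev])
        {
            bool first = !seen[curr];
            seen[curr] = true;
            onEdge(data[prev], data[curr], first);
            if (first)
                descend(curr);
        }
    };
    seen[root] = true;
    descend(root);
    return data;
}

// Example per-node payload: depth at first visit and number of incoming edges.
struct DepthData
{
    int depth{ 0 };
    int incoming{ 0 };
};
```

Because the callback receives references into storage the traversal manages, it can accumulate state such as depth or edge counts without maintaining its own node-to-data map, which is the same convenience SCC_NodeData relies on.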

Listing 30 The global pass used for detecting cycles in the execution graph.
#include <algorithm> // std::min
#include <stack>     // std::stack

class PassStronglyConnectedComponents : public Implements<IGlobalPass>
{
public:
    static omni::core::ObjectPtr<PassStronglyConnectedComponents> create(
        omni::core::ObjectParam<exec::unstable::IGraphBuilder> builder)
    {
        return omni::core::steal(new PassStronglyConnectedComponents(builder.get()));
    }

protected:
    PassStronglyConnectedComponents(IGraphBuilder*)
    {
    }

    void run_abi(IGraphBuilder* builder) noexcept override
    {
        _detectCycles(builder, builder->getTopology());
    }

private:
    void _detectCycles(IGraphBuilder* builder, ITopology* topology)
    {
        struct SCC_NodeData
        {
            size_t index{ 0 };
            size_t lowLink{ 0 };
            uint32_t cycleParentCount{ 0 };
            bool onStack{ false };
        };

        size_t globalIndex = 0;
        std::stack<INode*> globalStack;

        traverseDepthFirst<VisitAll, SCC_NodeData>(
            topology->getRoot(),
            [this, builder, &globalIndex, &globalStack](auto info, INode* prev, INode* curr)
            {
                // Mark a node as being on the Tarjan stack and push it.
                auto pushStack = [&globalStack](INode* node, SCC_NodeData& data)
                {
                    data.onStack = true;
                    globalStack.push(node);
                };

                // Pop the top node, clear its on-stack flag, and record on the
                // node how many cycle-closing parents point at it.
                auto popStack = [&info, &globalStack]()
                {
                    auto* top = globalStack.top();
                    globalStack.pop();

                    auto& userData = info.userData(top);
                    userData.onStack = false;

                    auto node = exec::unstable::cast<exec::unstable::IGraphBuilderNode>(top);
                    node->setCycleParentCount(userData.cycleParentCount);

                    return top;
                };

                auto& userData = info.userData(curr);
                auto& userDataPrev = info.userData(prev);
                if (info.isFirstVisit())
                {
                    // First visit: assign the next discovery index and push curr.
                    userData.index = userData.lowLink = globalIndex++;
                    pushStack(curr, userData);

                    // Explore curr's children before finishing curr.
                    info.continueVisit(curr);

                    // Propagate the lowest reachable index back to the parent.
                    userDataPrev.lowLink = std::min(userDataPrev.lowLink, userData.lowLink);

                    // curr roots an SCC: pop every node up to and including curr.
                    if (userData.lowLink == userData.index)
                    {
                        auto* top = popStack();
                        while (top != curr)
                        {
                            top = popStack();
                        }
                    }

                    // Nested graph definitions are checked independently.
                    auto nodeGraph = curr->getNodeGraphDef();
                    if (nodeGraph)
                    {
                        this->_detectCycles(builder, nodeGraph->getTopology());
                    }
                }
                else if (userData.onStack)
                {
                    // curr is still on the stack, so the edge prev -> curr closes a cycle.
                    userDataPrev.lowLink = std::min(userDataPrev.lowLink, userData.index);
                    userData.cycleParentCount++;
                }
            });
    }
};
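To check the logic of the pass without an EF build, the same bookkeeping can be expressed as the textbook recursive form of Tarjan’s algorithm over a plain adjacency list. This is a standalone sketch, not EF code: the Graph, SccData, and tarjan names are hypothetical, and the component field is an extra added only to make the result easy to inspect. As in Listing 30, cycleParentCount is incremented whenever an edge reaches a node that is still on the stack.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <stack>
#include <vector>

using Graph = std::vector<std::vector<int>>;

// Per-node state, analogous to SCC_NodeData in Listing 30.
struct SccData
{
    std::size_t index{ 0 };
    std::size_t lowLink{ 0 };
    std::uint32_t cycleParentCount{ 0 };
    bool visited{ false };
    bool onStack{ false };
    int component{ -1 }; // hypothetical extra field: SCC id assigned when popped
};

// Tarjan's strongly connected components over the nodes reachable from `root`.
std::vector<SccData> tarjan(const Graph& g, int root)
{
    std::vector<SccData> data(g.size());
    std::stack<int> stack;
    std::size_t nextIndex = 0;
    int nextComponent = 0;

    std::function<void(int)> strongConnect = [&](int v) {
        data[v].visited = true;
        data[v].index = data[v].lowLink = nextIndex++;
        data[v].onStack = true;
        stack.push(v);

        for (int w : g[v])
        {
            if (!data[w].visited)
            {
                strongConnect(w); // explore w, then propagate its lowLink
                data[v].lowLink = std::min(data[v].lowLink, data[w].lowLink);
            }
            else if (data[w].onStack)
            {
                // w is still on the stack, so v and w share an SCC
                data[v].lowLink = std::min(data[v].lowLink, data[w].index);
                data[w].cycleParentCount++;
            }
        }

        if (data[v].lowLink == data[v].index)
        {
            // v is the root of an SCC: pop every node up to and including v
            int top;
            do
            {
                top = stack.top();
                stack.pop();
                data[top].onStack = false;
                data[top].component = nextComponent;
            } while (top != v);
            ++nextComponent;
        }
    };

    strongConnect(root);
    return data;
}
```

For the graph {0→1, 1→2, 2→1, 2→3}, nodes 1 and 2 end up in the same component, nodes 0 and 3 are each in their own, and node 1 has a cycleParentCount of 1 because the edge 2→1 closes the cycle.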

Next Steps

To learn more about graph traversals in the context of EF, see Graph Traversal In-Depth.