omni/graph/exec/unstable/ExecutionContext.h

// Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto. Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.
//

#pragma once

#include <carb/thread/RecursiveSharedMutex.h>
#include <carb/thread/Spinlock.h>

#include <omni/graph/exec/unstable/Assert.h>
#include <omni/graph/exec/unstable/ExecutionPath.h>
#include <omni/graph/exec/unstable/Executor.h>
#include <omni/graph/exec/unstable/IExecutionContext.h>
#include <omni/graph/exec/unstable/IGraph.h>
#include <omni/graph/exec/unstable/INodeGraphDefDebug.h>
#include <omni/graph/exec/unstable/SmallVector.h>
#include <omni/graph/exec/unstable/Traversal.h>

#include <thread>
#include <unordered_map>

namespace omni
{
namespace graph
{
namespace exec
{
namespace unstable
{

namespace detail
{

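// Caches the execution paths at which a given definition, looked up either by IDef* or by ConstName, occurs in the
// execution graph, and invokes a caller-provided IApplyOnEachFunction for each discovered path.
//
// The cache is rebuilt lazily whenever the graph's global topology stamp changes, and may be queried from multiple
// threads concurrently.
//
// Illustrative usage sketch ('graph', 'def', and 'applyFn' are hypothetical variables):
//
//   detail::ExecutionPathCache cache{ *graph };
//   cache.applyOnEach(def, applyFn);            // match by definition pointer
//   cache.applyOnEach(def->getName(), applyFn); // match by definition name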
class ExecutionPathCache
{
public:
    ExecutionPathCache() = delete;

    explicit ExecutionPathCache(IGraph& graph) noexcept : m_graph(graph)
    {
    }

    template <typename Key>
    void applyOnEach(const Key& key, IApplyOnEachFunction& applyFn)
    {
        if (m_graph.inBuild())
        {
            // Traversing the entire graph while building it isn't allowed, since multiple threads may be building it.
            OMNI_GRAPH_EXEC_ASSERT(!m_graph.inBuild());
            return;
        }

        if (!m_graph.getTopology()->isValid())
        {
            return;
        }

        // If we're targeting the top-level NodeGraphDef that belongs to the execution graph,
        // simply execute the function here and skip subsequent processing.
        if (_isMatch(key, m_graph.getNodeGraphDef()))
        {
            applyFn.invoke(ExecutionPath::getEmpty());
            return;
        }

        auto discoverAndApplyOnNodesWithDefinitionFn = [this, &key, &applyFn](
                                                           const ExecutionPath& upstreamPath, INodeGraphDef& graph,
                                                           Paths& collectedPaths, auto recursionFn) -> void
        {
            traverseDepthFirst<VisitFirst>(
                graph.getRoot(),
                [this, &upstreamPath, &key, &recursionFn, &applyFn, &collectedPaths](auto info, INode* prev, INode* curr)
                {
                    auto currNodeGraph = curr->getNodeGraphDef();
                    if (currNodeGraph)
                    {
                        ExecutionPath newUpstreamPath(upstreamPath, curr);
                        recursionFn(newUpstreamPath, *currNodeGraph, collectedPaths, recursionFn);
                    }

                    auto def = curr->getDef();
                    if (def && _isMatch(key, def))
                    {
                        collectedPaths.emplace_back(upstreamPath, curr);
                        applyFn.invoke(collectedPaths.back());
                    }

                    info.continueVisit(curr);
                });
        };

        // Check whether this cache is in sync with the current topology. Since this method can run in parallel, we
        // need a read lock on m_mutex to safely read m_topologyStamp.
        std::shared_lock<MutexType> readLock(m_mutex);

        auto topologyStamp = *m_graph.getGlobalTopologyStamp();
        if (!m_topologyStamp.inSync(topologyStamp))
        {
            // Cache is out-of-sync; upgrade to a write lock.
            readLock.unlock();
            {
                // Check once again whether the cache is in sync, since another thread may have beaten this thread
                // to the write lock and already brought the cache into sync.
                std::lock_guard<MutexType> writeLock(m_mutex);
                if (m_topologyStamp.makeSync(topologyStamp))
                {
                    // We're the thread that got to the write lock first, so it's our job to clear the cache.
                    m_defCache.clear();
                    m_nameCache.clear();
                }
            }

            // Grab the read lock again so we can safely read the cache.
            readLock.lock();
        }

        auto& cache = _getCache(key);
        auto findIt = cache.find(key);
        if (findIt != cache.end())
        {
            // We've seen this key before. Make a copy of the paths so we can release the readLock. This is
            // required because an invocation can re-enter this method and take the writeLock.
            auto pathsCopy = findIt->second;
            readLock.unlock();
            for (ExecutionPath& path : pathsCopy)
            {
                applyFn.invoke(path);
            }
        }
        else
        {
            // Release readLock because apply below can result in re-entry of this function.
            readLock.unlock();
            // The key wasn't found in the cache, so discover the matching paths by traversing the graph.
            Paths paths;
            discoverAndApplyOnNodesWithDefinitionFn(
                ExecutionPath::getEmpty(), *m_graph.getNodeGraphDef(), paths, discoverAndApplyOnNodesWithDefinitionFn);

            // Insert only once we've collected all the paths; another thread may be looking up this key at the
            // same time.
            std::lock_guard<MutexType> writeLock(m_mutex);
            cache.emplace(key, std::move(paths));
        }
    }

private:
    bool _isMatch(const ConstName& desired, IDef* candidate) noexcept
    {
        return (desired == candidate->getName());
    }

    bool _isMatch(IDef* desired, IDef* candidate) noexcept
    {
        return (desired == candidate);
    }

    auto& _getCache(const ConstName&) noexcept
    {
        return m_nameCache;
    }

    auto& _getCache(IDef*) noexcept
    {
        return m_defCache;
    }

    using Paths = SmallVector<ExecutionPath, 2>;
    using DefCache = std::unordered_map<IDef*, Paths>;
    using NameCache = std::unordered_map<ConstName, Paths>;
    using MutexType = carb::thread::recursive_shared_mutex;

    IGraph& m_graph;
    DefCache m_defCache;
    NameCache m_nameCache;
    MutexType m_mutex;
    SyncStamp m_topologyStamp;
};

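// Executor used by ExecutionContext::executeNode_abi() for on-demand execution of a single node. Unlike
// ExecutorFallback, its continueExecute_abi() simply reports the status of the just-executed task rather than
// continuing the traversal to downstream nodes.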
class ExecutorSingleNode final : public ExecutorFallback
{
public:
    static omni::core::ObjectPtr<ExecutorSingleNode> create(omni::core::ObjectParam<ITopology> toExecute,
                                                            const ExecutionTask& thisTask)
    {
        return omni::core::steal(new ExecutorSingleNode(toExecute.get(), thisTask));
    }

private:
    ExecutorSingleNode(ITopology* toExecute, const ExecutionTask& thisTask) : ExecutorFallback(toExecute, thisTask)
    {
    }

    Status continueExecute_abi(ExecutionTask* currentTask) noexcept override
    {
        OMNI_GRAPH_EXEC_ASSERT(currentTask);

        return currentTask->getExecutionStatus();
    }
};

} // namespace detail

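// Reusable implementation of IExecutionContext (or a derived ParentInterface). Per-node execution state is kept in
// an instance of StorageType (m_storage), which is expected to provide the getStateInfo(), getNodeData(), and
// setNodeData() methods used by the corresponding _abi overrides below. The context also tracks which threads are
// currently executing within it and caches discovered execution paths via detail::ExecutionPathCache.
//
// A concrete context pairs this template with a storage type, e.g. (illustrative sketch; 'MyStateStorage' is a
// hypothetical type, not part of this header):
//
//   class MyExecutionContext : public ExecutionContext<MyStateStorage>
//   {
//   public:
//       explicit MyExecutionContext(IGraph* graph) noexcept : ExecutionContext<MyStateStorage>(graph)
//       {
//       }
//   };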
template <typename StorageType, typename ParentInterface = IExecutionContext>
class ExecutionContext : public Implements<ParentInterface>
{
protected:
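    // RAII helper that marks the calling thread as executing within this context for its lifetime; queried by
    // inExecute_abi() and isExecutingThread_abi() below.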
    class ScopedInExecute
    {
    public:
        ScopedInExecute(ExecutionContext& context, const bool recursive) noexcept
            : m_context(context), m_recursive(recursive)
        {
            // Note that we skip adding a new thread ID when we're (a) accounting for recursive execution and
            // (b) invoking context evaluation from within an already-executing task. This helps prevent
            // deadlocks when parallel-scheduled tasks directly invoke on-demand execution of other definitions
            // via the context.
            if ((m_recursive && !getCurrentTask()) || !m_recursive)
            {
                std::lock_guard<carb::thread::Spinlock> lock(m_context.m_threadIdSpinlock);
                ++m_context.m_contextThreadIds[std::this_thread::get_id()];
            }
        }
        ~ScopedInExecute() noexcept
        {
            if ((m_recursive && !getCurrentTask()) || !m_recursive)
            {
                std::lock_guard<carb::thread::Spinlock> lock(m_context.m_threadIdSpinlock);
                --m_context.m_contextThreadIds[std::this_thread::get_id()];
                if (m_context.m_contextThreadIds[std::this_thread::get_id()] == 0)
                {
                    m_context.m_contextThreadIds.erase(std::this_thread::get_id());
                }
            }
        }

    private:
        ExecutionContext& m_context;
        const bool m_recursive;
    };

    Stamp getExecutionStamp_abi() noexcept override
    {
        return m_executionStamp;
    }

    bool inExecute_abi() noexcept override
    {
        std::lock_guard<carb::thread::Spinlock> lock(m_threadIdSpinlock);
        return !m_contextThreadIds.empty();
    }

    bool isExecutingThread_abi() noexcept override
    {
        std::lock_guard<carb::thread::Spinlock> lock(m_threadIdSpinlock);
        return m_contextThreadIds.find(std::this_thread::get_id()) != m_contextThreadIds.end();
    }

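    // Evaluates the entire execution graph: initializes per-node state, bumps the execution stamp, marks the calling
    // thread as executing in this context, and delegates to the current thread's executeGraph().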
    Status execute_abi() noexcept override
    {
        this->initialize();

        m_executionStamp = _getNextGlobalExecutionStamp();

        ScopedInExecute scopedInExecute(*this, false);
        ScopedExecutionDebug scopedDebug{ m_graph->getNodeGraphDef() };

        return getCurrentThread()->executeGraph(m_graph, this);
    }

    Status executeNode_abi(const ExecutionPath* path, INode* node) noexcept override
    {
        OMNI_GRAPH_EXEC_ASSERT(path);
        OMNI_GRAPH_EXEC_ASSERT(node);

        this->initialize();

        m_executionStamp = _getNextGlobalExecutionStamp();

        ScopedInExecute scopedInExecute(*this, true);
        ScopedExecutionDebug scopedDebug{ m_graph->getNodeGraphDef() };

        auto def = node->getDef();
        if (def)
        {
            // Part of the top-level execute() method of the Executor (here referring to the templated "default"
            // implementation of IExecutor that comes with EF) queries its scheduler's status, which pops the current
            // task from the queue, executes it, and automatically tells the executor to continue with its traversal.
            // The subsequent continued executor traversals then keep scheduling work for evaluating all downstream
            // dependent nodes until the scheduler's task list is cleared. On-demand execution of individual
            // definitions (via ExecutionContext::executeNode), however, does not call into this code-path; it instead
            // creates and directly executes a temporary task with a temporary executor for the specified node. As a
            // result, if one were to simply use ExecutorFallback here, the eventual call to continueExecute after
            // computing the given task would force the executor to take one more unnecessary "step" downstream to try
            // and process a child. This can cause several issues, for example:
            // - The top-level call to execute the temporary task actually ends up returning the status of this
            //   continued traversal, which won't necessarily line up with the execution status of the temporary task
            //   (despite us only caring about the latter in this situation). For example, if the specified node was
            //   successfully executed, but the next node in the chain is then scheduled for deferred execution, the
            //   final returned status will be a combination of eSuccess and eDeferred.
            // - If the child node that gets processed after the initially-targeted node is set to bypass scheduling,
            //   then it will be incorrectly executed on-the-spot, which does not reflect the original query made by
            //   executeNode(). Depending on the definitions, this can be unnecessarily costly.
            // Even if the extra traversal step doesn't cause any issues to immediately arise, there's no need to
            // spend CPU cycles on code-paths that don't need to be run in this specific situation.

            ExecutionTask newTask{ this, node, *path };
            auto tmpExecutor = detail::ExecutorSingleNode::create(node->getTopology(), newTask);
            return newTask.execute(tmpExecutor);
        }
        else
        {
            // There was no work to be done, so we successfully no-op!
            return Status::eSuccess;
        }
    }

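    // Recursively visits every INodeGraphDef reachable from the execution graph and lets it initialize its state for
    // this context. The traversal is re-run only when the top-level topology stamp has changed since the last call.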
    void initialize_abi() noexcept override
    {
        if (!m_initStamp.makeSync(m_graph->getTopology()->getStamp()))
        {
            return; // in sync
        }

        auto traversalFn = [this](INodeGraphDef* nodeGraphDef, const ExecutionPath& path, auto& recursionFn) -> void
        {
            ExecutionTask info(this, nodeGraphDef->getRoot(), path);
            nodeGraphDef->initializeState(info);

            traverseDepthFirst<VisitFirst>(nodeGraphDef->getRoot(),
                                           [&path, &recursionFn, nodeGraphDef](auto info, INode* prev, INode* curr)
                                           {
                                               auto currNodeGraphDef = curr->getNodeGraphDef();
                                               if (currNodeGraphDef)
                                               {
                                                   ExecutionPath newPath{ path, curr }; // may throw
                                                   recursionFn(currNodeGraphDef, newPath, recursionFn);
                                               }

                                               info.continueVisit(curr);
                                           });
        };

        ExecutionPath path;
        traversalFn(m_graph->getNodeGraphDef(), path, traversalFn);
    }

    virtual IExecutionStateInfo* getStateInfo_abi(const ExecutionPath* path, INode* node) noexcept override
    {
        OMNI_GRAPH_EXEC_ASSERT(path);
        return m_storage.getStateInfo(*path, node);
    }

    virtual omni::core::Result getNodeData_abi(const ExecutionPath* path,
                                               INode* node,
                                               NodeDataKey key,
                                               omni::core::TypeId* outTypeId,
                                               void** outPtr,
                                               uint64_t* outItemSize,
                                               uint64_t* outBufferSize) noexcept override
    {
        // outTypeId, outItemSize, and outBufferSize should be checked by m_storage
        OMNI_GRAPH_EXEC_ASSERT(path);
        return m_storage.getNodeData(*path, node, key, outTypeId, outPtr, outItemSize, outBufferSize);
    }

    virtual void setNodeData_abi(const ExecutionPath* path,
                                 INode* node,
                                 NodeDataKey key,
                                 omni::core::TypeId typeId,
                                 void* data,
                                 uint64_t dataByteCount,
                                 uint64_t dataItemCount,
                                 NodeDataDeleterFn* deleter) noexcept override
    {
        OMNI_GRAPH_EXEC_FATAL_UNLESS_ARG(path);
        m_storage.setNodeData(*path, node, key, typeId, data, dataByteCount, dataItemCount, deleter);
    }

    void applyOnEachDef_abi(IDef* def, IApplyOnEachFunction* callback) noexcept override
    {
        OMNI_GRAPH_EXEC_ASSERT(callback);
        m_pathCache.applyOnEach(def, *callback);
    }

    void applyOnEachDefWithName_abi(const ConstName* name, IApplyOnEachFunction* callback) noexcept override
    {
        OMNI_GRAPH_EXEC_ASSERT(name);
        OMNI_GRAPH_EXEC_ASSERT(callback);
        m_pathCache.applyOnEach(*name, *callback);
    }

    ExecutionContext(IGraph* graph) noexcept
        : m_graph(graph), m_executionStamp(_getNextGlobalExecutionStamp()), m_pathCache(*graph)
    {
    }

    StorageType m_storage;

private:
    static Stamp _getNextGlobalExecutionStamp() noexcept
    {
        // Since this is private and will only be accessed indirectly via virtual methods, declaring this as a
        // function-local static should be fine.
        static Stamp gExecutionStamp;
        gExecutionStamp.next();
        return gExecutionStamp;
    }

    IGraph* m_graph{ nullptr };
    Stamp m_executionStamp;
    SyncStamp m_initStamp;
    detail::ExecutionPathCache m_pathCache;
    std::unordered_map<std::thread::id, size_t> m_contextThreadIds;
    carb::thread::Spinlock m_threadIdSpinlock;
};

} // namespace unstable
} // namespace exec
} // namespace graph
} // namespace omni