omni/graph/exec/unstable/ExecutionContext.h
// Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto. Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.
//
#pragma once
#include <carb/thread/RecursiveSharedMutex.h>
#include <carb/thread/Spinlock.h>
#include <omni/graph/exec/unstable/Assert.h>
#include <omni/graph/exec/unstable/ExecutionPath.h>
#include <omni/graph/exec/unstable/Executor.h>
#include <omni/graph/exec/unstable/IExecutionContext.h>
#include <omni/graph/exec/unstable/IGraph.h>
#include <omni/graph/exec/unstable/INodeGraphDefDebug.h>
#include <omni/graph/exec/unstable/SmallVector.h>
#include <omni/graph/exec/unstable/Traversal.h>
#include <thread>
#include <unordered_map>
namespace omni
{
namespace graph
{
namespace exec
{
namespace unstable
{
namespace detail
{
class ExecutionPathCache
{
public:
ExecutionPathCache() = delete;
explicit ExecutionPathCache(IGraph& graph) noexcept : m_graph(graph)
{
}
template <typename Key>
void applyOnEach(const Key& key, IApplyOnEachFunction& applyFn)
{
if (m_graph.inBuild())
{
// Traversing the entire graph while it is being built isn't allowed, since multiple threads may be building it.
OMNI_GRAPH_EXEC_ASSERT(!m_graph.inBuild());
return;
}
if (!m_graph.getTopology()->isValid())
{
return;
}
// If we're targeting the top-level NodeGraphDef that belongs to the execution graph,
// simply execute the function here and skip subsequent processing.
if (_isMatch(key, m_graph.getNodeGraphDef()))
{
applyFn.invoke(ExecutionPath::getEmpty());
return;
}
auto discoverAndApplyOnNodesWithDefinitionFn = [this, &key, &applyFn](
const ExecutionPath& upstreamPath, INodeGraphDef& graph,
Paths& collectedPaths, auto recursionFn) -> void
{
traverseDepthFirst<VisitFirst>(
graph.getRoot(),
[this, &upstreamPath, &key, &recursionFn, &applyFn, &collectedPaths](auto info, INode* prev, INode* curr)
{
auto currNodeGraph = curr->getNodeGraphDef();
if (currNodeGraph)
{
ExecutionPath newUpstreamPath(upstreamPath, curr);
recursionFn(newUpstreamPath, *currNodeGraph, collectedPaths, recursionFn);
}
auto def = curr->getDef();
if (def && _isMatch(key, def))
{
collectedPaths.emplace_back(upstreamPath, curr);
applyFn.invoke(collectedPaths.back());
}
info.continueVisit(curr);
});
};
// Check if this cache is in sync with the current topology. Since this method can run in parallel, we
// need a read lock on m_mutex to safely read m_topologyStamp.
std::shared_lock<MutexType> readLock(m_mutex);
auto topologyStamp = *m_graph.getGlobalTopologyStamp();
if (!m_topologyStamp.inSync(topologyStamp))
{
// Cache is out-of-sync; upgrade to a write lock.
readLock.unlock();
{
// Here we check once more whether the cache is in sync, since another thread may have beaten this
// thread to the write lock and already brought the cache into sync.
std::lock_guard<MutexType> writeLock(m_mutex);
if (m_topologyStamp.makeSync(topologyStamp))
{
// We're the thread that got to the write lock first, so it's our job to clear the cache.
m_defCache.clear();
m_nameCache.clear();
}
}
// Grab the read lock again so we can safely read the cache.
readLock.lock();
}
auto& cache = _getCache(key);
auto findIt = cache.find(key);
if (findIt != cache.end())
{
// We've seen this key before. Make a copy of the paths so we can release the readLock. This is
// required because an invocation can result in re-entering this method and taking the writeLock.
auto pathsCopy = findIt->second;
readLock.unlock();
for (ExecutionPath& path : pathsCopy)
{
applyFn.invoke(path);
}
}
else
{
// Release readLock because apply below can result in re-entry of this function.
readLock.unlock();
// The key wasn't found in the cache, so discover the matching paths by traversing the graph.
Paths paths;
discoverAndApplyOnNodesWithDefinitionFn(
ExecutionPath::getEmpty(), *m_graph.getNodeGraphDef(), paths, discoverAndApplyOnNodesWithDefinitionFn);
// Insert only once we've collected all the paths. Another thread may be looking up this key at
// the same time.
std::lock_guard<MutexType> writeLock(m_mutex);
cache.emplace(key, std::move(paths));
}
}
private:
bool _isMatch(const ConstName& desired, IDef* candidate) noexcept
{
return (desired == candidate->getName());
}
bool _isMatch(IDef* desired, IDef* candidate) noexcept
{
return (desired == candidate);
}
auto& _getCache(const ConstName&) noexcept
{
return m_nameCache;
}
auto& _getCache(IDef*) noexcept
{
return m_defCache;
}
using Paths = SmallVector<ExecutionPath, 2>;
using DefCache = std::unordered_map<IDef*, Paths>;
using NameCache = std::unordered_map<ConstName, Paths>;
using MutexType = carb::thread::recursive_shared_mutex;
IGraph& m_graph;
DefCache m_defCache;
NameCache m_nameCache;
MutexType m_mutex;
SyncStamp m_topologyStamp;
};
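// A rough caller-side sketch of how the cache above is reached (illustrative only; the concrete
// IApplyOnEachFunction subclass and the non-ABI wrapper signatures below are assumptions, not
// part of this header). ExecutionContext owns the cache privately and forwards to it from
// applyOnEachDef_abi()/applyOnEachDefWithName_abi() further below:
//
//   // Hypothetical functor invoked once per execution path that reaches a matching definition.
//   struct CollectPaths : public Implements<IApplyOnEachFunction>
//   {
//       void invoke_abi(const ExecutionPath* path) noexcept override
//       {
//           // Inspect or copy *path here.
//       }
//   };
//
//   auto fn = omni::core::steal(new CollectPaths);
//   context->applyOnEachDefWithName(ConstName("myDef"), fn.get()); // ends up in applyOnEach() above
//
// ExecutorSingleNode below executes exactly one node and then stops: continueExecute() reports
// the current task's status instead of scheduling children. See the comment in
// ExecutionContext::executeNode_abi() for why the stock ExecutorFallback is not used there.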
class ExecutorSingleNode final : public ExecutorFallback
{
public:
static omni::core::ObjectPtr<ExecutorSingleNode> create(omni::core::ObjectParam<ITopology> toExecute,
const ExecutionTask& thisTask)
{
return omni::core::steal(new ExecutorSingleNode(toExecute.get(), thisTask));
}
private:
ExecutorSingleNode(ITopology* toExecute, const ExecutionTask& thisTask) : ExecutorFallback(toExecute, thisTask)
{
}
Status continueExecute_abi(ExecutionTask* currentTask) noexcept override
{
OMNI_GRAPH_EXEC_ASSERT(currentTask);
return currentTask->getExecutionStatus();
}
};
} // namespace detail
template <typename StorageType, typename ParentInterface = IExecutionContext>
class ExecutionContext : public Implements<ParentInterface>
{
protected:
class ScopedInExecute
{
public:
ScopedInExecute(ExecutionContext& context, const bool recursive) noexcept
: m_context(context), m_recursive(recursive)
{
// Note that we skip adding a new thread ID if (a) we're accounting for recursive
// execution, and (b) we've invoked context evaluation from within some executed
// task. This helps prevent deadlocks when parallel-scheduled tasks directly
// invoke on-demand execution of other definitions via the context.
if ((m_recursive && !getCurrentTask()) || !m_recursive)
{
std::lock_guard<carb::thread::Spinlock> lock(m_context.m_threadIdSpinlock);
++m_context.m_contextThreadIds[std::this_thread::get_id()];
}
}
~ScopedInExecute() noexcept
{
if ((m_recursive && !getCurrentTask()) || !m_recursive)
{
std::lock_guard<carb::thread::Spinlock> lock(m_context.m_threadIdSpinlock);
--m_context.m_contextThreadIds[std::this_thread::get_id()];
if (m_context.m_contextThreadIds[std::this_thread::get_id()] == 0)
{
m_context.m_contextThreadIds.erase(std::this_thread::get_id());
}
}
}
private:
ExecutionContext& m_context;
const bool m_recursive;
};
Stamp getExecutionStamp_abi() noexcept override
{
return m_executionStamp;
}
bool inExecute_abi() noexcept override
{
std::lock_guard<carb::thread::Spinlock> lock(m_threadIdSpinlock);
return !m_contextThreadIds.empty();
}
bool isExecutingThread_abi() noexcept override
{
std::lock_guard<carb::thread::Spinlock> lock(m_threadIdSpinlock);
return m_contextThreadIds.find(std::this_thread::get_id()) != m_contextThreadIds.end();
}
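// The two queries above are what client code typically consults before mutating state that the
// context may be reading during evaluation. A minimal sketch (the call site below is an assumed
// usage pattern, not something defined in this header; inExecute()/isExecutingThread() are the
// generated non-ABI wrappers on IExecutionContext):
//
//   if (context->inExecute() && !context->isExecutingThread())
//   {
//       // Some other thread is currently evaluating this context; defer the mutation rather
//       // than racing the executing tasks.
//   }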
Status execute_abi() noexcept override
{
this->initialize();
m_executionStamp = _getNextGlobalExecutionStamp();
ScopedInExecute scopedInExecute(*this, false);
ScopedExecutionDebug scopedDebug{ m_graph->getNodeGraphDef() };
return getCurrentThread()->executeGraph(m_graph, this);
}
Status executeNode_abi(const ExecutionPath* path, INode* node) noexcept override
{
OMNI_GRAPH_EXEC_ASSERT(path);
OMNI_GRAPH_EXEC_ASSERT(node);
this->initialize();
m_executionStamp = _getNextGlobalExecutionStamp();
ScopedInExecute scopedInExecute(*this, true);
ScopedExecutionDebug scopedDebug{ m_graph->getNodeGraphDef() };
auto def = node->getDef();
if (def)
{
// The Executor's top-level execute() method (here referring to the templated "default"
// implementation of IExecutor that comes with EF) queries its scheduler's status, which pops the
// current task from the queue, executes it, and automatically tells the executor to continue its
// traversal. The continued traversals then keep scheduling work to evaluate all downstream
// dependent nodes until the scheduler's task list is cleared. On-demand execution of an
// individual definition (via ExecutionContext::executeNode), however, does not go through that
// code-path; it instead creates and directly executes a temporary task with a temporary executor
// for the specified node. As a result, if we simply used ExecutorFallback here, the eventual call
// to continueExecute after computing the given task would force the executor to take one more
// unnecessary "step" downstream to try and process a child. That can cause several issues, for
// example:
// - The top-level call to execute the temporary task ends up returning the status of the
//   continued traversal, which won't necessarily match the execution status of the temporary
//   task (the only status we care about here). For example, if the specified node executed
//   successfully but the next node in the chain is scheduled for deferred execution, the final
//   returned status is a combination of eSuccess and eDeferred.
// - If the child node processed after the initially-targeted node is set to bypass scheduling,
//   it will be incorrectly executed on-the-spot, which does not reflect the original request made
//   by executeNode(). Depending on the definitions involved, this can be unnecessarily costly.
// Even when the extra traversal step doesn't immediately cause problems, there's no need to spend
// CPU cycles on code-paths that don't have to run in this specific situation, so we use
// detail::ExecutorSingleNode, which stops after the targeted node.
ExecutionTask newTask{ this, node, *path };
auto tmpExecutor = detail::ExecutorSingleNode::create(node->getTopology(), newTask);
return newTask.execute(tmpExecutor);
}
else
{
// There was no work to be done, so we successfully no-op!
return Status::eSuccess;
}
}
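// A minimal caller-side sketch of the on-demand path implemented above (assumptions: executeNode()
// is the generated non-ABI wrapper for executeNode_abi(), and pathToParent addresses the
// NodeGraphDef instance that owns the node, e.g. a path previously collected via
// applyOnEachDefWithName()):
//
//   ExecutionPath pathToParent = /* path to the NodeGraphDef owning the node */;
//   Status status = context->executeNode(pathToParent, node);
//   // `status` reflects only the targeted node; no downstream work is scheduled because
//   // detail::ExecutorSingleNode stops the traversal after this single task.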
void initialize_abi() noexcept override
{
if (!m_initStamp.makeSync(m_graph->getTopology()->getStamp()))
{
return; // in sync
}
auto traversalFn = [this](INodeGraphDef* nodeGraphDef, const ExecutionPath& path, auto& recursionFn) -> void
{
ExecutionTask info(this, nodeGraphDef->getRoot(), path);
nodeGraphDef->initializeState(info);
traverseDepthFirst<VisitFirst>(nodeGraphDef->getRoot(),
[&path, &recursionFn, nodeGraphDef](auto info, INode* prev, INode* curr)
{
auto currNodeGraphDef = curr->getNodeGraphDef();
if (currNodeGraphDef)
{
ExecutionPath newPath{ path, curr }; // may throw
recursionFn(currNodeGraphDef, newPath, recursionFn);
}
info.continueVisit(curr);
});
};
ExecutionPath path;
traversalFn(m_graph->getNodeGraphDef(), path, traversalFn);
}
virtual IExecutionStateInfo* getStateInfo_abi(const ExecutionPath* path, INode* node) noexcept override
{
OMNI_GRAPH_EXEC_ASSERT(path);
return m_storage.getStateInfo(*path, node);
}
virtual omni::core::Result getNodeData_abi(const ExecutionPath* path,
INode* node,
NodeDataKey key,
omni::core::TypeId* outTypeId,
void** outPtr,
uint64_t* outItemSize,
uint64_t* outBufferSize) noexcept override
{
// outTypeId, outItemSize, and outBufferSize should be checked by m_storage
OMNI_GRAPH_EXEC_ASSERT(path);
return m_storage.getNodeData(*path, node, key, outTypeId, outPtr, outItemSize, outBufferSize);
}
virtual void setNodeData_abi(const ExecutionPath* path,
INode* node,
NodeDataKey key,
omni::core::TypeId typeId,
void* data,
uint64_t dataByteCount,
uint64_t dataItemCount,
NodeDataDeleterFn* deleter) noexcept override
{
OMNI_GRAPH_EXEC_FATAL_UNLESS_ARG(path);
m_storage.setNodeData(*path, node, key, typeId, data, dataByteCount, dataItemCount, deleter);
}
void applyOnEachDef_abi(IDef* def, IApplyOnEachFunction* callback) noexcept override
{
OMNI_GRAPH_EXEC_ASSERT(callback);
m_pathCache.applyOnEach(def, *callback);
}
void applyOnEachDefWithName_abi(const ConstName* name, IApplyOnEachFunction* callback) noexcept override
{
OMNI_GRAPH_EXEC_ASSERT(name);
OMNI_GRAPH_EXEC_ASSERT(callback);
m_pathCache.applyOnEach(*name, *callback);
}
ExecutionContext(IGraph* graph) noexcept
: m_graph(graph), m_executionStamp(_getNextGlobalExecutionStamp()), m_pathCache(*graph)
{
}
StorageType m_storage;
private:
static Stamp _getNextGlobalExecutionStamp() noexcept
{
// Since this is private and will only be accessed indirectly via virtual methods, declaring this
// inline static should be OK.
static Stamp gExecutionStamp;
gExecutionStamp.next();
return gExecutionStamp;
}
IGraph* m_graph{ nullptr };
Stamp m_executionStamp;
SyncStamp m_initStamp;
detail::ExecutionPathCache m_pathCache;
std::unordered_map<std::thread::id, size_t> m_contextThreadIds;
carb::thread::Spinlock m_threadIdSpinlock;
};
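// A minimal sketch of how a concrete context is typically assembled from this template. The
// storage type name MyStateStorage and the factory shape are assumptions; a real storage type
// must provide the getStateInfo()/getNodeData()/setNodeData() members used by the ABI methods
// above:
//
//   class MyExecutionContext : public ExecutionContext<MyStateStorage>
//   {
//   public:
//       static omni::core::ObjectPtr<MyExecutionContext> create(IGraph* graph)
//       {
//           return omni::core::steal(new MyExecutionContext(graph));
//       }
//
//   private:
//       MyExecutionContext(IGraph* graph) noexcept : ExecutionContext<MyStateStorage>(graph)
//       {
//       }
//   };
//
//   // myContext->execute() then stamps a new execution, marks the calling thread as executing,
//   // and runs the execution graph's top-level NodeGraphDef via the current thread's executor.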
} // namespace unstable
} // namespace exec
} // namespace graph
} // namespace omni