// omni/graph/exec/unstable/Node.h
// File members: omni/graph/exec/unstable/Node.h
// Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto. Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.
//
#pragma once
#include <omni/core/ResultError.h>
#include <omni/graph/exec/unstable/Assert.h>
#include <omni/graph/exec/unstable/IGraph.h>
#include <omni/graph/exec/unstable/IGraphBuilderNode.h>
#include <omni/graph/exec/unstable/INode.h>
#include <omni/graph/exec/unstable/INodeDef.h>
#include <omni/graph/exec/unstable/INodeGraphDef.h>
#include <omni/graph/exec/unstable/ITopology.h>
#include <omni/graph/exec/unstable/SmallVector.h>
#include <omni/graph/exec/unstable/Types.h>

#include <algorithm>
#include <type_traits>
#include <utility>
namespace omni
{
namespace graph
{
namespace exec
{
namespace unstable
{
//! Concrete node implementation, parameterized on the interfaces it implements.
//!
//! A node is bound to a single ITopology at construction time and stores:
//!  - its name and its index within that topology,
//!  - the lists of parent and child nodes,
//!  - a cycle-parent counter,
//!  - at most one definition (either an INodeDef or an INodeGraphDef).
//!
//! Parent/child information is only reported while the node's own topology stamp is in sync with the
//! stamp of the owning topology (see isValidTopology_abi()).
template <typename... Bases>
class NodeT : public Implements<Bases...>
{
public:
    //! Creates a node without a definition.
    //!
    //! @param graphOrTopology  Graph or topology the node will live in. Any of the forms accepted by
    //!                         GetTopology::get() may be passed (raw pointer, ObjectPtr or ObjectParam of
    //!                         IGraph or ITopology). Must not be null.
    //! @param idName           Name of the node. Its data pointer must not be null.
    //! @return Newly allocated node wrapped in an ObjectPtr.
    template <typename T>
    static omni::core::ObjectPtr<NodeT> create(T&& graphOrTopology, const carb::cpp::string_view& idName)
    {
        OMNI_GRAPH_EXEC_ASSERT(graphOrTopology);
        OMNI_GRAPH_EXEC_ASSERT(idName.data());

        // Get the topology in which the node should reside.
        ITopology* topology = GetTopology::get(std::forward<T>(graphOrTopology));

        return omni::core::steal(new NodeT(topology, idName));
    }

    //! Creates a node with the given definition attached.
    //!
    //! @param graphOrTopology  Graph or topology the node will live in (see the overload above). Must not
    //!                         be null.
    //! @param def              Definition to attach. May be a raw pointer, ObjectPtr or ObjectParam to an
    //!                         IDef, INodeDef or INodeGraphDef; the matching _create() overload is picked
    //!                         at compile time from the deduced definition type.
    //! @param idName           Name of the node. Its data pointer must not be null.
    //! @return Newly allocated node wrapped in an ObjectPtr.
    template <typename T, typename D>
    static omni::core::ObjectPtr<NodeT> create(T&& graphOrTopology, D&& def, const carb::cpp::string_view& idName)
    {
        OMNI_GRAPH_EXEC_ASSERT(graphOrTopology);
        OMNI_GRAPH_EXEC_ASSERT(idName.data());

        // Get the definition type (IDef, INodeDef, or INodeGraphDef).
        //
        // Both the reference AND any top-level cv-qualifier are stripped before the lookup. Without
        // removing cv, a `const ObjectPtr<X>&` / `const ObjectParam<X>&` argument would not match the
        // GetDefType specializations and would fall through to the primary template, producing a
        // smart-pointer type instead of the interface type.
        using DefType = typename GetDefType<std::remove_cv_t<std::remove_reference_t<D>>>::type;

        // Get the topology in which the node should reside.
        ITopology* topology = GetTopology::get(std::forward<T>(graphOrTopology));

        // The temporary ObjectParam lives until the end of this full expression, i.e. for the whole
        // duration of the _create() call that receives the raw pointer.
        return _create(topology, omni::core::ObjectParam<DefType>(std::forward<D>(def)).get(), idName);
    }

    //! Returns the node index to the topology and, if this node was still in sync with the topology,
    //! invalidates the topology.
    virtual ~NodeT()
    {
        // The index check is here in case we decide to implement a move constructor.
        if (m_indexInTopology != kInvalidNodeIndexInTopology)
        {
            m_topology->releaseNodeIndex(m_indexInTopology);
            if (isValidTopology_abi())
            {
                m_topology->invalidate();
            }
        }
    }

    // Disambiguate between INode and IGraphBuilderNode.
    using INode::getChildren;
    using INode::getName;
    using INode::getParents;
    using INode::getTopology;
    using INode::hasChild;

protected:
    //! Returns the topology this node was created in.
    ITopology* getTopology_abi() noexcept override
    {
        return m_topology;
    }

    //! Returns a pointer to the node's name. The pointee is a member of this node.
    const ConstName* getName_abi() noexcept override
    {
        return &m_name;
    }

    //! Returns the index acquired from the topology when the node was constructed.
    NodeIndexInTopology getIndexInTopology_abi() noexcept override
    {
        return m_indexInTopology;
    }

    //! Returns the parents of this node, or an empty span if the node's topology is not valid.
    Span<INode* const> getParents_abi() noexcept override
    {
        return isValidTopology_abi() ? Span<INode* const>{ m_parents.begin(), m_parents.size() } :
                                       Span<INode* const>{ nullptr, 0 };
    }

    //! Returns the children of this node, or an empty span if the node's topology is not valid.
    Span<INode* const> getChildren_abi() noexcept override
    {
        return isValidTopology_abi() ? Span<INode* const>{ m_children.begin(), m_children.size() } :
                                       Span<INode* const>{ nullptr, 0 };
    }

    //! Returns the stored cycle-parent count, or 0 if the node's topology is not valid.
    uint32_t getCycleParentCount_abi() noexcept override
    {
        return isValidTopology_abi() ? m_cycleParentCount : 0;
    }

    //! Returns whether this node's topology stamp is in sync with the owning topology's stamp.
    bool isValidTopology_abi() noexcept final override
    {
        return m_topologyStamp.inSync(m_topology->getStamp());
    }

    //! Synchronizes this node's stamp with the topology's stamp. When makeSync() reports that a
    //! synchronization took place, the stale parent/child lists and the cycle-parent count are cleared.
    virtual void validateOrResetTopology_abi() noexcept
    {
        if (m_topologyStamp.makeSync(m_topology->getStamp()))
        {
            // Topology changed, let's clear the old one.
            m_parents.clear();
            m_children.clear();
            m_cycleParentCount = 0;
        }
    }

    //! Returns the attached definition: the node definition if one is set, otherwise the node graph
    //! definition (which may itself be null).
    IDef* getDef_abi() noexcept override
    {
        if (m_nodeDef.get())
        {
            return m_nodeDef.get();
        }
        else
        {
            return m_nodeGraphDef.get();
        }
    }

    //! Returns the attached node definition, or null if none is set.
    INodeDef* getNodeDef_abi() noexcept override
    {
        return m_nodeDef.get();
    }

    //! Returns the attached node graph definition, or null if none is set.
    INodeGraphDef* getNodeGraphDef_abi() noexcept override
    {
        return m_nodeGraphDef.get();
    }

    //! Appends @p parent to the list of parents.
    //!
    //! @param parent  Node to add. Must not be null and must be castable to INode.
    void _addParent_abi(IGraphBuilderNode* parent) noexcept override
    {
        OMNI_GRAPH_EXEC_ASSERT(parent);
        OMNI_GRAPH_EXEC_CAST_OR_FATAL(asNode, INode, parent);
        OMNI_GRAPH_EXEC_ASSERT(isValidTopology_abi());

        // May throw. NOTE(review): this method is noexcept, so an exception escaping push_back would end
        // in std::terminate rather than propagate to the caller.
        m_parents.push_back(asNode);
    }

    //! Removes every occurrence of @p parent from the list of parents. A null @p parent is ignored.
    void _removeParent_abi(IGraphBuilderNode* parent) noexcept override
    {
        if (!parent)
        {
            return; // LCOV_EXCL_LINE
        }

        OMNI_GRAPH_EXEC_CAST_OR_FATAL(asNode, INode, parent);
        _eraseRemove(m_parents, asNode);
    }

    //! Appends @p child to the list of children.
    //!
    //! @param child  Node to add. Must not be null and must be castable to INode.
    void _addChild_abi(IGraphBuilderNode* child) noexcept override
    {
        OMNI_GRAPH_EXEC_ASSERT(child);
        OMNI_GRAPH_EXEC_CAST_OR_FATAL(asNode, INode, child);
        OMNI_GRAPH_EXEC_ASSERT(isValidTopology_abi());

        m_children.push_back(asNode);
    }

    //! Removes every occurrence of @p child from the list of children. A null @p child is ignored.
    void _removeChild_abi(IGraphBuilderNode* child) noexcept override
    {
        if (!child)
        {
            return; // LCOV_EXCL_LINE
        }

        OMNI_GRAPH_EXEC_CAST_OR_FATAL(asNode, INode, child);
        _eraseRemove(m_children, asNode);
    }

    //! Drops every parent whose own topology is no longer valid. Does nothing if this node's topology is
    //! not valid.
    void _removeInvalidParents_abi() noexcept override
    {
        if (isValidTopology_abi())
        {
            m_parents.erase(
                std::remove_if(m_parents.begin(), m_parents.end(), [](INode* n) { return !n->isValidTopology(); }),
                m_parents.end());
        }
    }

    //! Drops every child whose own topology is no longer valid. Does nothing if this node's topology is
    //! not valid.
    void _removeInvalidChildren_abi() noexcept override
    {
        if (isValidTopology_abi())
        {
            m_children.erase(
                std::remove_if(m_children.begin(), m_children.end(), [](INode* n) { return !n->isValidTopology(); }),
                m_children.end());
        }
    }

    //! Invalidates this node's topology stamp.
    void _invalidateConnections_abi() noexcept override
    {
        m_topologyStamp.invalidate();
    }

    //! Stores the cycle-parent count.
    void setCycleParentCount_abi(uint32_t count) noexcept override
    {
        m_cycleParentCount = count;
    }

    //! Attaches a node definition and drops any node graph definition, so that at most one of the two
    //! is held at a time.
    void _setNodeDef_abi(INodeDef* nodeDef) noexcept override
    {
        m_nodeDef.borrow(nodeDef);
        m_nodeGraphDef.release();
    }

    //! Attaches a node graph definition and drops any node definition, so that at most one of the two
    //! is held at a time.
    void _setNodeGraphDef_abi(INodeGraphDef* nodeGraphDef) noexcept override
    {
        m_nodeGraphDef.borrow(nodeGraphDef);
        m_nodeDef.release();
    }

    //! Drops both the node definition and the node graph definition.
    void _clearDef_abi() noexcept override
    {
        m_nodeDef.release();
        m_nodeGraphDef.release();
    }

    //! Returns the parent at @p index as an IGraphBuilderNode.
    //!
    //! The topology must be valid and @p index must be smaller than the number of parents; both
    //! conditions are checked with OMNI_GRAPH_EXEC_FATAL_UNLESS.
    IGraphBuilderNode* getParentAt_abi(uint64_t index) noexcept override
    {
        OMNI_GRAPH_EXEC_FATAL_UNLESS(isValidTopology_abi());
        OMNI_GRAPH_EXEC_FATAL_UNLESS(index < m_parents.size());

        // The bounds check above guarantees the index fits the container's index type.
        OMNI_GRAPH_EXEC_CAST_OR_FATAL(asGraphBuilderNode, IGraphBuilderNode, m_parents[static_cast<uint32_t>(index)]);
        return asGraphBuilderNode;
    }

    //! Returns the number of parents, or 0 if the node's topology is not valid.
    uint64_t getParentCount_abi() noexcept override
    {
        return isValidTopology_abi() ? m_parents.size() : 0;
    }

    //! Returns the child at @p index as an IGraphBuilderNode.
    //!
    //! The topology must be valid and @p index must be smaller than the number of children; both
    //! conditions are checked with OMNI_GRAPH_EXEC_FATAL_UNLESS.
    IGraphBuilderNode* getChildAt_abi(uint64_t index) noexcept override
    {
        OMNI_GRAPH_EXEC_FATAL_UNLESS(isValidTopology_abi());
        OMNI_GRAPH_EXEC_FATAL_UNLESS(index < m_children.size());

        // The bounds check above guarantees the index fits the container's index type.
        OMNI_GRAPH_EXEC_CAST_OR_FATAL(asGraphBuilderNode, IGraphBuilderNode, m_children[static_cast<uint32_t>(index)]);
        return asGraphBuilderNode;
    }

    //! Returns the number of children, or 0 if the node's topology is not valid.
    uint64_t getChildCount_abi() noexcept override
    {
        return isValidTopology_abi() ? m_children.size() : 0;
    }

    //! Returns true if @p node is in the list of children.
    //!
    //! Returns false when the node's topology is not valid or when @p node cannot be cast to INode.
    bool hasChild_abi(IGraphBuilderNode* node) noexcept override
    {
        if (!isValidTopology_abi())
        {
            return false; // LCOV_EXCL_LINE
        }

        auto asNode = omni::graph::exec::unstable::cast<INode>(node);
        if (!asNode)
        {
            return false; // LCOV_EXCL_LINE
        }

        return std::find(m_children.begin(), m_children.end(), asNode) != m_children.end();
    }

    //! Returns true if this node is the root node of its topology.
    bool isRoot_abi() noexcept override
    {
        return (m_topology->getRoot() == static_cast<INode*>(this));
    }

    //! Returns the root node of this node's topology as an IGraphBuilderNode.
    IGraphBuilderNode* getRoot_abi() noexcept override
    {
        OMNI_GRAPH_EXEC_CAST_OR_FATAL(asGraphBuilderNode, IGraphBuilderNode, m_topology->getRoot());
        return asGraphBuilderNode;
    }

    //! Constructs a node without a definition and acquires a node index from @p topology.
    //!
    //! @p topology is dereferenced unconditionally and therefore must not be null.
    NodeT(ITopology* topology, const carb::cpp::string_view& idName) noexcept
        : m_topology{ topology }, m_indexInTopology{ m_topology->acquireNodeIndex() }, m_name{ idName }
    {
    }

    //! Constructs a node with a node graph definition and acquires a node index from @p topology.
    //!
    //! @p topology is dereferenced unconditionally and therefore must not be null.
    NodeT(ITopology* topology, INodeGraphDef* nodeGraphDef, const carb::cpp::string_view& idName) noexcept
        : m_topology{ topology },
          m_indexInTopology{ m_topology->acquireNodeIndex() },
          m_nodeGraphDef{ nodeGraphDef, omni::core::kBorrow },
          m_name{ idName }
    {
    }

    //! Constructs a node with a node definition and acquires a node index from @p topology.
    //!
    //! @p topology is dereferenced unconditionally and therefore must not be null.
    NodeT(ITopology* topology, INodeDef* nodeDef, const carb::cpp::string_view& idName) noexcept
        : m_topology{ topology },
          m_indexInTopology{ m_topology->acquireNodeIndex() },
          m_nodeDef{ nodeDef, omni::core::kBorrow },
          m_name{ idName }
    {
    }

private:
    //! Storage for parent/child lists (SmallVector with an inline size argument of 2).
    using NodeArray = SmallVector<INode*, 2>;

    //! Erase-remove idiom: removes every element equal to @p n from container @p v.
    template <typename T>
    void _eraseRemove(T& v, INode* n) noexcept
    {
        v.erase(std::remove(v.begin(), v.end(), n), v.end());
    }

    // NOTE: member order matters. The constructors initialize m_indexInTopology through m_topology, so
    // m_topology must be declared (and therefore initialized) first.
    ITopology* m_topology; //!< Topology this node lives in.
    NodeIndexInTopology m_indexInTopology{ kInvalidNodeIndexInTopology }; //!< Index acquired from m_topology.

    NodeArray m_parents; //!< Parent nodes.
    NodeArray m_children; //!< Child nodes.
    uint32_t m_cycleParentCount{ 0 }; //!< Value set via setCycleParentCount_abi().
    SyncStamp m_topologyStamp; //!< Compared against the topology's stamp to detect stale connections.

    omni::core::ObjectPtr<INodeDef> m_nodeDef; //!< Node definition, if any.
    omni::core::ObjectPtr<INodeGraphDef> m_nodeGraphDef; //!< Node graph definition, if any.

    ConstName m_name; //!< Name of the node.

    // Helper struct for getting a topology pointer from one of the following:
    // - IGraph*
    // - omni::core::ObjectPtr<IGraph>
    // - omni::core::ObjectParam<IGraph>
    // - ITopology*
    // - omni::core::ObjectPtr<ITopology>
    // - omni::core::ObjectParam<ITopology>
    struct GetTopology
    {
        //! Returns the topology of the given graph.
        static ITopology* get(omni::core::ObjectParam<IGraph> graph)
        {
            return graph->getTopology();
        }

        //! Returns the given topology unchanged.
        static ITopology* get(omni::core::ObjectParam<ITopology> topology)
        {
            return topology.get();
        }
    };

    // Helper structs for obtaining a definition type.
    // The primary template handles raw pointers; the specializations unwrap ObjectPtr and ObjectParam.
    // LCOV_EXCL_START
    template <typename T>
    struct GetDefType
    {
        using type = std::remove_pointer_t<T>;
    };

    template <typename T>
    struct GetDefType<omni::core::ObjectPtr<T>>
    {
        using type = T;
    };

    template <typename T>
    struct GetDefType<omni::core::ObjectParam<T>>
    {
        using type = T;
    };
    // LCOV_EXCL_STOP

    // Internal implementations for node creation that take a topology and one of the various
    // definitions. We have separate overloads for the definition interfaces (IDef, INodeDef, and
    // INodeGraphDef) so that, for the latter two, the correct constructor is selected at compile time
    // and no runtime type check is needed. The runtime check is still required when an IDef is passed
    // in directly.

    //! Creates a node holding a node graph definition.
    template <typename T>
    static omni::core::ObjectPtr<NodeT> _create(T topology, INodeGraphDef* def, const carb::cpp::string_view& idName)
    {
        return omni::core::steal(new NodeT(topology, def, idName));
    }

    //! Creates a node holding a node definition.
    template <typename T>
    static omni::core::ObjectPtr<NodeT> _create(T topology, INodeDef* def, const carb::cpp::string_view& idName)
    {
        return omni::core::steal(new NodeT(topology, def, idName));
    }

    //! Creates a node from a definition whose concrete kind is only known at runtime.
    //!
    //! A null @p def produces a node without a definition. Otherwise @p def is cast to INodeDef first
    //! and to INodeGraphDef second; if neither cast succeeds, a null ObjectPtr is returned.
    template <typename T, typename D>
    static omni::core::ObjectPtr<NodeT> _create(T topology, D* def, const carb::cpp::string_view& idName)
    {
        // We don't know if def is an INodeDef or INodeGraphDef at compile time, so
        // we have to perform this runtime check.
        if (!def)
        {
            return omni::core::steal(new NodeT(topology, idName));
        }
        else if (auto* nodeDef = omni::graph::exec::unstable::cast<INodeDef>(def))
        {
            return omni::core::steal(new NodeT(topology, nodeDef, idName));
        }
        else if (auto* nodeGraphDef = omni::graph::exec::unstable::cast<INodeGraphDef>(def))
        {
            return omni::core::steal(new NodeT(topology, nodeGraphDef, idName));
        }
        // LCOV_EXCL_START
        else
        {
            return nullptr; // Should not happen.
        }
        // LCOV_EXCL_STOP
    }
};
// Convenience alias: NodeT instantiated to implement both the INode and IGraphBuilderNode interfaces.
using Node = NodeT<INode, IGraphBuilderNode>;
} // namespace unstable
} // namespace exec
} // namespace graph
} // namespace omni