omni/graph/exec/unstable/Node.h

File members: omni/graph/exec/unstable/Node.h

// Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto. Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.
//

#pragma once

#include <omni/core/ResultError.h>
#include <omni/graph/exec/unstable/Assert.h>
#include <omni/graph/exec/unstable/IGraph.h>
#include <omni/graph/exec/unstable/IGraphBuilderNode.h>
#include <omni/graph/exec/unstable/INode.h>
#include <omni/graph/exec/unstable/INodeDef.h>
#include <omni/graph/exec/unstable/INodeGraphDef.h>
#include <omni/graph/exec/unstable/ITopology.h>
#include <omni/graph/exec/unstable/SmallVector.h>
#include <omni/graph/exec/unstable/Types.h>

#include <algorithm>
#include <type_traits>
#include <utility>

namespace omni
{
namespace graph
{
namespace exec
{
namespace unstable
{

//! Default implementation of a node living in an execution graph topology.
//!
//! @p Bases is the list of interfaces forwarded to @c Implements. The @c Node alias at the end of this
//! block instantiates it with @c INode and @c IGraphBuilderNode.
//!
//! A node stores:
//! - the topology it belongs to and the index it acquired from that topology,
//! - its name,
//! - at most one definition: either an @c INodeDef or an @c INodeGraphDef (the setters clear the other one),
//! - parent/child edge lists and a cycle parent count. These are only reported while the node's topology
//!   stamp is in sync with the topology's stamp (see @c isValidTopology_abi()). Otherwise the accessors
//!   report no edges.
template <typename... Bases>
class NodeT : public Implements<Bases...>
{
public:
    //! Creates a node without a definition.
    //!
    //! @param graphOrTopology The @c IGraph or @c ITopology in which the node resides. Accepts a raw pointer,
    //!                        @c ObjectPtr or @c ObjectParam (see @c GetTopology). Must not be null.
    //! @param idName          Identifier of the node. Its data pointer must not be null.
    //! @returns The new node. The returned pointer takes ownership of the freshly allocated object.
    template <typename T>
    static omni::core::ObjectPtr<NodeT> create(T&& graphOrTopology, const carb::cpp::string_view& idName)
    {
        OMNI_GRAPH_EXEC_ASSERT(graphOrTopology);
        OMNI_GRAPH_EXEC_ASSERT(idName.data());

        // Get the topology in which the node should reside in.
        ITopology* topology = GetTopology::get(std::forward<T>(graphOrTopology));

        return omni::core::steal(new NodeT(topology, idName));
    }

    //! Creates a node with the given definition.
    //!
    //! @param graphOrTopology The @c IGraph or @c ITopology in which the node resides. Must not be null.
    //! @param def             Definition of the node: an @c IDef, @c INodeDef or @c INodeGraphDef, given as a
    //!                        raw pointer, @c ObjectPtr or @c ObjectParam. When the static type is exactly
    //!                        @c INodeDef or @c INodeGraphDef, the matching constructor is chosen at compile
    //!                        time. Otherwise a runtime cast decides. A null @p def yields a node without a
    //!                        definition.
    //! @param idName          Identifier of the node. Its data pointer must not be null.
    //! @returns The new node, or nullptr if @p def is neither an @c INodeDef nor an @c INodeGraphDef.
    template <typename T, typename D>
    static omni::core::ObjectPtr<NodeT> create(T&& graphOrTopology, D&& def, const carb::cpp::string_view& idName)
    {
        OMNI_GRAPH_EXEC_ASSERT(graphOrTopology);
        OMNI_GRAPH_EXEC_ASSERT(idName.data());

        // Get the definition type (IDef, INodeDef, or INodeGraphDef).
        using DefType = typename GetDefType<typename std::remove_reference<D>::type>::type;

        // Get the topology in which the node should reside in.
        ITopology* topology = GetTopology::get(std::forward<T>(graphOrTopology));

        // The ObjectParam temporary lives until the end of this full expression, so the raw pointer stays
        // valid for the duration of _create().
        return _create(topology, omni::core::ObjectParam<DefType>(std::forward<D>(def)).get(), idName);
    }

    //! Returns the node's index to the topology.
    //!
    //! If the node's edges are currently valid for the topology, the topology is invalidated as well.
    //! NOTE(review): presumably this forces the topology to be rebuilt without this node. Confirm.
    //! Assumes @c m_topology is still alive when the node is destroyed. TODO: confirm the ownership contract.
    virtual ~NodeT()
    {
        // in case we decide to implement move constructor
        if (m_indexInTopology != kInvalidNodeIndexInTopology)
        {
            m_topology->releaseNodeIndex(m_indexInTopology);

            if (isValidTopology_abi())
            {
                m_topology->invalidate();
            }
        }
    }

    // Copying would make two objects release the same topology index in their destructors, so it is
    // explicitly forbidden.
    NodeT(const NodeT&) = delete;
    NodeT& operator=(const NodeT&) = delete;

    // disambiguate between INode and IGraphBuilderNode
    using INode::getChildren;
    using INode::getName;
    using INode::getParents;
    using INode::getTopology;
    using INode::hasChild;

protected:
    //! Returns the topology this node belongs to.
    ITopology* getTopology_abi() noexcept override
    {
        return m_topology;
    }

    //! Returns the node's name (owned by the node).
    const ConstName* getName_abi() noexcept override
    {
        return &m_name;
    }

    //! Returns the index acquired from the topology at construction.
    NodeIndexInTopology getIndexInTopology_abi() noexcept override
    {
        return m_indexInTopology;
    }

    //! Returns the parents, or an empty span if the node's edges are out of sync with the topology.
    Span<INode* const> getParents_abi() noexcept override
    {
        return isValidTopology_abi() ? Span<INode* const>{ m_parents.begin(), m_parents.size() } :
                                       Span<INode* const>{ nullptr, 0 };
    }

    //! Returns the children, or an empty span if the node's edges are out of sync with the topology.
    Span<INode* const> getChildren_abi() noexcept override
    {
        return isValidTopology_abi() ? Span<INode* const>{ m_children.begin(), m_children.size() } :
                                       Span<INode* const>{ nullptr, 0 };
    }

    //! Returns the stored cycle parent count, or 0 if the node's edges are out of sync with the topology.
    uint32_t getCycleParentCount_abi() noexcept override
    {
        return isValidTopology_abi() ? m_cycleParentCount : 0;
    }

    //! Returns true when this node's stamp is in sync with the topology's current stamp.
    bool isValidTopology_abi() noexcept final override
    {
        return m_topologyStamp.inSync(m_topology->getStamp());
    }

    //! Syncs the node's stamp with the topology.
    //!
    //! When @c makeSync() reports a change, the stale parents, children and cycle parent count are
    //! discarded. NOTE(review): assumes makeSync() returns true when the stamp was out of date. Confirm.
    virtual void validateOrResetTopology_abi() noexcept
    {
        if (m_topologyStamp.makeSync(m_topology->getStamp()))
        {
            // topology changed, let's clear the old one
            m_parents.clear();
            m_children.clear();
            m_cycleParentCount = 0;
        }
    }

    //! Returns the node definition if set, otherwise the node graph definition (which may be nullptr).
    IDef* getDef_abi() noexcept override
    {
        if (m_nodeDef.get())
        {
            return m_nodeDef.get();
        }
        else
        {
            return m_nodeGraphDef.get();
        }
    }

    //! Returns the @c INodeDef definition, or nullptr if none is set.
    INodeDef* getNodeDef_abi() noexcept override
    {
        return m_nodeDef.get();
    }

    //! Returns the @c INodeGraphDef definition, or nullptr if none is set.
    INodeGraphDef* getNodeGraphDef_abi() noexcept override
    {
        return m_nodeGraphDef.get();
    }

    //! Appends @p parent to the parent list.
    //!
    //! @p parent must be non-null and castable to @c INode, and the node's edges must be valid.
    //! The call is noexcept, so an allocation failure inside push_back terminates the program.
    void _addParent_abi(IGraphBuilderNode* parent) noexcept override
    {
        OMNI_GRAPH_EXEC_ASSERT(parent);
        OMNI_GRAPH_EXEC_CAST_OR_FATAL(asNode, INode, parent);
        OMNI_GRAPH_EXEC_ASSERT(isValidTopology_abi());
        m_parents.push_back(asNode); // may throw
    }

    //! Removes every occurrence of @p parent from the parent list. A null @p parent is ignored.
    void _removeParent_abi(IGraphBuilderNode* parent) noexcept override
    {
        if (!parent)
        {
            return; // LCOV_EXCL_LINE
        }
        OMNI_GRAPH_EXEC_CAST_OR_FATAL(asNode, INode, parent);
        _eraseRemove(m_parents, asNode);
    }

    //! Appends @p child to the child list.
    //!
    //! @p child must be non-null and castable to @c INode, and the node's edges must be valid.
    void _addChild_abi(IGraphBuilderNode* child) noexcept override
    {
        OMNI_GRAPH_EXEC_ASSERT(child);
        OMNI_GRAPH_EXEC_CAST_OR_FATAL(asNode, INode, child);
        OMNI_GRAPH_EXEC_ASSERT(isValidTopology_abi());
        m_children.push_back(asNode);
    }

    //! Removes every occurrence of @p child from the child list. A null @p child is ignored.
    void _removeChild_abi(IGraphBuilderNode* child) noexcept override
    {
        if (!child)
        {
            return; // LCOV_EXCL_LINE
        }
        OMNI_GRAPH_EXEC_CAST_OR_FATAL(asNode, INode, child);
        _eraseRemove(m_children, asNode);
    }

    //! Drops parents whose own topology is no longer valid. Does nothing if this node's edges are invalid.
    void _removeInvalidParents_abi() noexcept override
    {
        if (isValidTopology_abi())
        {
            _eraseInvalid(m_parents);
        }
    }

    //! Drops children whose own topology is no longer valid. Does nothing if this node's edges are invalid.
    void _removeInvalidChildren_abi() noexcept override
    {
        if (isValidTopology_abi())
        {
            _eraseInvalid(m_children);
        }
    }

    //! Marks this node's edges as out of sync with the topology.
    void _invalidateConnections_abi() noexcept override
    {
        m_topologyStamp.invalidate();
    }

    //! Stores the cycle parent count as given (no validation is performed).
    void setCycleParentCount_abi(uint32_t count) noexcept override
    {
        m_cycleParentCount = count;
    }

    //! Sets an @c INodeDef definition and drops any @c INodeGraphDef definition.
    void _setNodeDef_abi(INodeDef* nodeDef) noexcept override
    {
        m_nodeDef.borrow(nodeDef);
        m_nodeGraphDef.release();
    }

    //! Sets an @c INodeGraphDef definition and drops any @c INodeDef definition.
    void _setNodeGraphDef_abi(INodeGraphDef* nodeGraphDef) noexcept override
    {
        m_nodeGraphDef.borrow(nodeGraphDef);
        m_nodeDef.release();
    }

    //! Drops both definitions.
    void _clearDef_abi() noexcept override
    {
        m_nodeDef.release();
        m_nodeGraphDef.release();
    }

    //! Returns the parent at @p index. Fatal if the edges are invalid or @p index is out of range.
    IGraphBuilderNode* getParentAt_abi(uint64_t index) noexcept override
    {
        OMNI_GRAPH_EXEC_FATAL_UNLESS(isValidTopology_abi());
        OMNI_GRAPH_EXEC_FATAL_UNLESS(index < m_parents.size());
        OMNI_GRAPH_EXEC_CAST_OR_FATAL(asGraphBuilderNode, IGraphBuilderNode, m_parents[static_cast<uint32_t>(index)]);
        return asGraphBuilderNode;
    }

    //! Returns the number of parents, or 0 if the node's edges are invalid.
    uint64_t getParentCount_abi() noexcept override
    {
        return isValidTopology_abi() ? m_parents.size() : 0;
    }

    //! Returns the child at @p index. Fatal if the edges are invalid or @p index is out of range.
    IGraphBuilderNode* getChildAt_abi(uint64_t index) noexcept override
    {
        OMNI_GRAPH_EXEC_FATAL_UNLESS(isValidTopology_abi());
        OMNI_GRAPH_EXEC_FATAL_UNLESS(index < m_children.size());
        OMNI_GRAPH_EXEC_CAST_OR_FATAL(asGraphBuilderNode, IGraphBuilderNode, m_children[static_cast<uint32_t>(index)]);
        return asGraphBuilderNode;
    }

    //! Returns the number of children, or 0 if the node's edges are invalid.
    uint64_t getChildCount_abi() noexcept override
    {
        return isValidTopology_abi() ? m_children.size() : 0;
    }

    //! Returns true if @p node is among this node's children.
    //!
    //! Returns false when the edges are invalid or @p node cannot be cast to @c INode. Uses a linear search.
    bool hasChild_abi(IGraphBuilderNode* node) noexcept override
    {
        if (!isValidTopology_abi())
        {
            return false; // LCOV_EXCL_LINE
        }

        auto asNode = omni::graph::exec::unstable::cast<INode>(node);
        if (!asNode)
        {
            return false; // LCOV_EXCL_LINE
        }

        return std::find(m_children.begin(), m_children.end(), asNode) != m_children.end();
    }

    //! Returns true if this node is the topology's root.
    bool isRoot_abi() noexcept override
    {
        return (m_topology->getRoot() == static_cast<INode*>(this));
    }

    //! Returns the topology's root node. Fatal if it cannot be cast to @c IGraphBuilderNode.
    IGraphBuilderNode* getRoot_abi() noexcept override
    {
        OMNI_GRAPH_EXEC_CAST_OR_FATAL(asGraphBuilderNode, IGraphBuilderNode, m_topology->getRoot());
        return asGraphBuilderNode;
    }

    //! Constructs a node without a definition and acquires an index from @p topology.
    NodeT(ITopology* topology, const carb::cpp::string_view& idName) noexcept
        : m_topology{ topology }, m_indexInTopology{ m_topology->acquireNodeIndex() }, m_name{ idName }
    {
    }

    //! Constructs a node with an @c INodeGraphDef (borrowed) and acquires an index from @p topology.
    NodeT(ITopology* topology, INodeGraphDef* nodeGraphDef, const carb::cpp::string_view& idName) noexcept
        : m_topology{ topology },
          m_indexInTopology{ m_topology->acquireNodeIndex() },
          m_nodeGraphDef{ nodeGraphDef, omni::core::kBorrow },
          m_name{ idName }
    {
    }

    //! Constructs a node with an @c INodeDef (borrowed) and acquires an index from @p topology.
    NodeT(ITopology* topology, INodeDef* nodeDef, const carb::cpp::string_view& idName) noexcept
        : m_topology{ topology },
          m_indexInTopology{ m_topology->acquireNodeIndex() },
          m_nodeDef{ nodeDef, omni::core::kBorrow },
          m_name{ idName }
    {
    }

private:
    using NodeArray = SmallVector<INode*, 2>;

    // Removes every occurrence of n from v (erase-remove idiom).
    template <typename T>
    static void _eraseRemove(T& v, INode* n) noexcept
    {
        v.erase(std::remove(v.begin(), v.end(), n), v.end());
    }

    // Removes every node from v whose own topology is no longer valid (erase-remove_if idiom).
    template <typename T>
    static void _eraseInvalid(T& v) noexcept
    {
        v.erase(std::remove_if(v.begin(), v.end(), [](INode* n) { return !n->isValidTopology(); }), v.end());
    }

    // Initialization order matters: m_indexInTopology is initialized from m_topology in the constructors,
    // so m_topology must stay declared first.
    ITopology* m_topology;

    NodeIndexInTopology m_indexInTopology{ kInvalidNodeIndexInTopology };

    NodeArray m_parents;
    NodeArray m_children;

    uint32_t m_cycleParentCount{ 0 };
    SyncStamp m_topologyStamp;

    // At most one of these two is non-null (enforced by the _set*Def_abi setters).
    omni::core::ObjectPtr<INodeDef> m_nodeDef;
    omni::core::ObjectPtr<INodeGraphDef> m_nodeGraphDef;

    ConstName m_name;

    // Helper struct for getting a topology pointer from one of the following:
    // - IGraph*
    // - omni::core::ObjectPtr<IGraph>
    // - omni::core::ObjectParam<IGraph>
    // - ITopology*
    // - omni::core::ObjectPtr<ITopology>
    // - omni::core::ObjectParam<ITopology>
    struct GetTopology
    {
        static ITopology* get(omni::core::ObjectParam<IGraph> graph)
        {
            return graph->getTopology();
        }
        static ITopology* get(omni::core::ObjectParam<ITopology> topology)
        {
            return topology.get();
        }
    };

    // Helper structs for obtaining a definition type.
    // LCOV_EXCL_START
    template <typename T>
    struct GetDefType
    {
        using type = std::remove_pointer_t<T>;
    };
    template <typename T>
    struct GetDefType<omni::core::ObjectPtr<T>>
    {
        using type = T;
    };
    template <typename T>
    struct GetDefType<omni::core::ObjectParam<T>>
    {
        using type = T;
    };
    // LCOV_EXCL_STOP

    // Internal implementations for node creation that take a topology and various definitions. There are
    // separate overloads for the definition interfaces (IDef, INodeDef, and INodeGraphDef) so that the
    // latter two are resolved at compile time. This avoids runtime type checks when a node is created from
    // an INodeDef or INodeGraphDef pointer. Those checks are still required when an IDef is passed in
    // directly.
    template <typename T>
    static omni::core::ObjectPtr<NodeT> _create(T topology, INodeGraphDef* def, const carb::cpp::string_view& idName)
    {
        return omni::core::steal(new NodeT(topology, def, idName));
    }

    template <typename T>
    static omni::core::ObjectPtr<NodeT> _create(T topology, INodeDef* def, const carb::cpp::string_view& idName)
    {
        return omni::core::steal(new NodeT(topology, def, idName));
    }

    template <typename T, typename D>
    static omni::core::ObjectPtr<NodeT> _create(T topology, D* def, const carb::cpp::string_view& idName)
    {
        // We don't know if def is an INodeDef or INodeGraphDef at compile time, so
        // we have to perform this runtime check.
        if (!def)
        {
            return omni::core::steal(new NodeT(topology, idName));
        }
        else if (auto* nodeDef = omni::graph::exec::unstable::cast<INodeDef>(def))
        {
            return omni::core::steal(new NodeT(topology, nodeDef, idName));
        }
        else if (auto* nodeGraphDef = omni::graph::exec::unstable::cast<INodeGraphDef>(def))
        {
            return omni::core::steal(new NodeT(topology, nodeGraphDef, idName));
        }
        // LCOV_EXCL_START
        else
        {
            return nullptr; // Should not happen.
        }
        // LCOV_EXCL_STOP
    }
};

//! The concrete node type implementing both INode and IGraphBuilderNode.
using Node = NodeT<INode, IGraphBuilderNode>;

} // namespace unstable
} // namespace exec
} // namespace graph
} // namespace omni