// omni/graph/exec/unstable/Graph.h

// File members: omni/graph/exec/unstable/Graph.h

// Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto. Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.
//

#pragma once

#include <omni/graph/exec/unstable/Assert.h>
#include <omni/graph/exec/unstable/ExecutorFactory.h>
#include <omni/graph/exec/unstable/IGraph.h>
#include <omni/graph/exec/unstable/NodeGraphDef.h>

#include <atomic>
#include <memory>
#include <string>
#include <utility>

namespace omni
{
namespace graph
{
namespace exec
{
namespace unstable
{

//! Graph implementation parameterized on the interfaces it implements.
//!
//! An instance owns a name, a global topology stamp, a single node graph definition and a counter tracking
//! whether the graph is currently being built. Instances are created through the static @c create() functions;
//! the constructors are protected.
//!
//! @tparam Bases Interfaces forwarded to @c Implements.
template <typename... Bases>
class GraphT : public Implements<Bases...>
{
public:
    //! Creates a graph whose node graph definition is produced by @c NodeGraphDef::create(this, "NODE-ROOT").
    //!
    //! @param name Name of the graph. @c name.data() must not be @c nullptr (checked with an assertion).
    //!
    //! @return The newly allocated graph, wrapped with @c omni::core::steal.
    static omni::core::ObjectPtr<GraphT> create(const carb::cpp::string_view& name) noexcept
    {
        OMNI_GRAPH_EXEC_ASSERT(name.data());
        return omni::core::steal(new GraphT(name));
    }

    //! Creates a graph whose node graph definition is produced by
    //! @c NodeGraphDef::create(this, executorFactory, "NODE-ROOT").
    //!
    //! @param executorFactory Factory handed to the node graph definition.
    //! @param name            Name of the graph. @c name.data() must not be @c nullptr (checked with an assertion).
    //!
    //! @return The newly allocated graph, wrapped with @c omni::core::steal.
    static omni::core::ObjectPtr<GraphT> create(const ExecutorFactory& executorFactory,
                                                const carb::cpp::string_view& name) noexcept
    {
        OMNI_GRAPH_EXEC_ASSERT(name.data());
        return omni::core::steal(new GraphT(executorFactory, name));
    }

    //! Creates a graph whose node graph definition is produced by a caller supplied factory.
    //!
    //! @tparam Fn Callable invoked as <tt>nodeGraphDefFactory(graph)</tt>, where @c graph is a pointer to the
    //!            graph under construction. Its result is assigned to the graph's node graph definition pointer.
    //!
    //! @param name                Name of the graph. @c name.data() must not be @c nullptr (checked with an
    //!                            assertion).
    //! @param nodeGraphDefFactory Factory producing the node graph definition.
    //!
    //! @return The newly allocated graph, wrapped with @c omni::core::steal.
    //!
    //! @note This function and the constructor it calls are @c noexcept: an exception escaping the factory calls
    //!       @c std::terminate.
    template <typename Fn>
    static omni::core::ObjectPtr<GraphT> create(const carb::cpp::string_view& name, Fn&& nodeGraphDefFactory) noexcept
    {
        OMNI_GRAPH_EXEC_ASSERT(name.data());
        return omni::core::steal(new GraphT(name, std::forward<Fn>(nodeGraphDefFactory)));
    }

protected:
    //! Returns the node graph definition held by this graph (raw, non-owning pointer).
    INodeGraphDef* getNodeGraphDef_abi() noexcept override
    {
        return m_nodeGraphDef.get();
    }

    //! Returns a pointer to the name stored in this graph.
    const ConstName* getName_abi() noexcept override
    {
        return &m_name;
    }

    //! Returns a pointer to the global topology stamp stored in this graph.
    Stamp* getGlobalTopologyStamp_abi() noexcept override
    {
        return &m_globalTopologyStamp;
    }

    //! Returns @c true while the in-build counter is positive.
    bool inBuild_abi() noexcept override
    {
        return (m_inBuild > 0);
    }

    //! Increments (@p inBuild is @c true) or decrements (@p inBuild is @c false) the in-build counter.
    //!
    //! Because a counter rather than a flag is used, calls may nest; every @c true must eventually be matched by a
    //! @c false. An unmatched @c false trips an assertion.
    void _setInBuild_abi(bool inBuild) noexcept override
    {
        if (inBuild)
        {
            ++m_inBuild;
        }
        else
        {
            // Validate the value produced by *this* decrement. Re-loading the atomic for the check could observe a
            // value already changed by another thread and either miss or misreport an unbalanced call.
            const int remaining = --m_inBuild;
            OMNI_GRAPH_EXEC_ASSERT(remaining > -1);
            (void)remaining; // keeps builds warning-free should the assertion macro compile to nothing
        }
    }

    //! Constructs a graph named @p name with a node graph definition created by
    //! @c NodeGraphDef::create(this, "NODE-ROOT").
    explicit GraphT(const carb::cpp::string_view& name) noexcept : m_name(name)
    {
        m_globalTopologyStamp.next();
        m_nodeGraphDef = NodeGraphDef::create(this, "NODE-ROOT");
    }

    //! Constructs a graph named @p name with a node graph definition created by
    //! @c NodeGraphDef::create(this, executorFactory, "NODE-ROOT").
    GraphT(const ExecutorFactory& executorFactory, const carb::cpp::string_view& name) noexcept : m_name(name)
    {
        m_globalTopologyStamp.next();
        m_nodeGraphDef = NodeGraphDef::create(this, executorFactory, "NODE-ROOT");
    }

    //! Constructs a graph named @p name with a node graph definition produced by @p nodeGraphDefFactory.
    //!
    //! The factory is invoked with @c this after the name has been initialized and the topology stamp advanced, but
    //! before the node graph definition pointer is set.
    //!
    //! @note The constructor is @c noexcept, so an exception thrown by the factory results in @c std::terminate.
    template <typename Fn>
    GraphT(const carb::cpp::string_view& name, Fn&& nodeGraphDefFactory) noexcept : m_name(name)
    {
        m_globalTopologyStamp.next();
        // Forward so the factory is invoked with the value category it was passed with.
        m_nodeGraphDef = std::forward<Fn>(nodeGraphDefFactory)(this); // may throw (terminates, see note above)
    }

private:
    Stamp m_globalTopologyStamp; //!< Global topology stamp; advanced once with next() during construction.
    omni::core::ObjectPtr<INodeGraphDef> m_nodeGraphDef; //!< Node graph definition held by this graph.
    ConstName m_name; //!< Name given at construction.

    //! Number of currently outstanding <tt>_setInBuild_abi(true)</tt> calls.
    std::atomic<int> m_inBuild{ 0 };
};

//! Instantiation of GraphT whose only base interface is IGraph.
using Graph = GraphT<IGraph>;

} // namespace unstable
} // namespace exec
} // namespace graph
} // namespace omni