Program Listing for omni/graph/core/ogn/Database.h

↰ Return to documentation for omni/graph/core/ogn/Database.h

// Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto.  Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.
//
#pragma once

#include <carb/InterfaceUtils.h>
#include <omni/fabric/IToken.h>
#include <carb/logging/Log.h>

#include <omni/graph/core/IAttributeType.h>
#include <omni/graph/core/iComputeGraph.h>
#include <omni/graph/core/IVariable2.h>
#include <omni/graph/core/StringUtils.h>
#include <omni/graph/core/ogn/RuntimeAttribute.h>
#include <omni/graph/core/ogn/SimpleRuntimeAttribute.h>

#include <algorithm>
#include <string>
#include <vector>
using omni::fabric::IToken;

// =================================================================================================================
// This file contains simple interface classes which wrap data in the OGN database for easier use.
//
//    OmniGraphDatabase  Base class for generated node database classes. Provides common functionality.
//
// WARNING: These interfaces are subject to change without warning and are only meant to be used by generated code.
//          If you call them directly you may have to modify your code when they change.
// =================================================================================================================

// Hardcoded metadata key names used by generated node code.
// Keep these in sync with the Python constants of the MetadataKey object in
// source/extensions/omni.graph.tools/python/node_generator/keys.py
//
// Keys whose string value has no prefix:
#define kOgnMetadataAllowMultiInputs "allowMultiInputs"
#define kOgnMetadataAllowedTokens "allowedTokens"
#define kOgnMetadataHidden "hidden"
#define kOgnMetadataInternal "internal"
#define kOgnMetadataLiteralOnly "literalOnly"
#define kOgnMetadataOutputOnly "outputOnly"
#define kOgnMetadataTags "tags"
#define kOgnMetadataUiName "uiName"
#define kOgnMetadataUiType "uiType"
//
// Keys whose string value carries a double-underscore prefix:
#define kOgnMetadataAllowedTokensRaw "__allowedTokens"
#define kOgnMetadataCategories "__categories"
#define kOgnMetadataCategoryDescriptions "__categoryDescriptions"
#define kOgnMetadataCudaPointers "__cudaPointers"
#define kOgnMetadataDefault "__default"
#define kOgnMetadataDescription "__description"
#define kOgnMetadataExclusions "__exclusions"
#define kOgnMetadataExtension "__extension"
#define kOgnMetadataIconBackgroundColor "__iconBackgroundColor"
#define kOgnMetadataIconBorderColor "__iconBorderColor"
#define kOgnMetadataIconColor "__iconColor"
#define kOgnMetadataIconPath "__icon"
#define kOgnMetadataLanguage "__language"
#define kOgnMetadataMemoryType "__memoryType"
#define kOgnMetadataObjectId "__objectId"
#define kOgnMetadataOptional "__optional"
#define kOgnMetadataTokens "__tokens"
//
// Not a metadata key: the name used for singleton nodes.
#define kOgnSingletonName "singleton"


namespace omni {
namespace graph {
namespace core {

class Node;

namespace ogn {
// Internal, opaque type aliases: treat them as black boxes, since their definitions may change between releases.
// Each CPU-side runtime attribute type is declared next to the simple accessor that wraps it.

// Inputs (the accessor wraps a const attribute)
using InputAttribute = ogn::RuntimeAttribute<ogn::kOgnInput, ogn::kCpu>;
using DynamicInput = ogn::SimpleInput<const InputAttribute, ogn::kCpu>;

// Outputs
using OutputAttribute = ogn::RuntimeAttribute<ogn::kOgnOutput, ogn::kCpu>;
using DynamicOutput = ogn::SimpleOutput<OutputAttribute, ogn::kCpu>;

// State attributes, also used to expose graph variables
using VariableAttribute = ogn::RuntimeAttribute<ogn::kOgnState, ogn::kCpu>;
using DynamicState = ogn::SimpleState<VariableAttribute, ogn::kCpu>;

// ======================================================================
class OmniGraphDatabase
{
protected:
    //! Graph context handles, indexed by absolute instance index (m_offset + relative index).
    GraphContextObj const* m_graphContextHandles{ nullptr };
    //! Node handles, indexed the same way as m_graphContextHandles.
    NodeObj const* m_nodeHandles{ nullptr };
    //! Index of the instance currently being processed; every relative index passed to an accessor is added to it.
    InstanceIndex m_offset{ 0 };
    //! Number of entries in m_graphContextHandles and m_nodeHandles (the bound checked by abi_context()/abi_node()).
    size_t m_handleCounts{ 0 };

    //! Stores the handle arrays and their length. The arrays are referenced, not copied, so they must outlive
    //! this database.
    //!
    //! @param graphContexts  Array of graph context handles
    //! @param nodeObjects    Array of node handles, indexed like graphContexts
    //! @param handleCount    Number of entries in each array
    void _ctor(GraphContextObj const* graphContexts, NodeObj const* nodeObjects, size_t handleCount)
    {
        m_graphContextHandles = graphContexts;
        m_nodeHandles = nodeObjects;
        m_handleCounts = handleCount;
    }

    //! Constructs the database over the given handle arrays (see _ctor()).
    OmniGraphDatabase(GraphContextObj const* graphContexts, NodeObj const* nodeObjects, size_t handleCount)
    { _ctor(graphContexts, nodeObjects, handleCount); }

    //! Constructs an empty database; _ctor() must be called before any accessor is used.
    OmniGraphDatabase() = default;

public:

    //! Converts a token holding an OGN type name into an attribute Type.
    //!
    //! @param typeNameToken  Token whose text is the OGN type name
    //! @return The parsed Type, or a default-constructed Type if the IAttributeType or token interface is unavailable
    Type typeFromName(NameToken typeNameToken) const
    {
        auto typeInterface = carb::getCachedInterface<omni::graph::core::IAttributeType>();
        if (!typeInterface)
        {
            CARB_LOG_ERROR_ONCE("Could not acquire the IAttributeType interface");
            return {};
        }
        auto typeName = tokenToString(typeNameToken);
        // tokenToString() returns nullptr when the token interface is missing; never forward a null name.
        if (!typeName)
            return {};
        return typeInterface->typeFromOgnTypeName(typeName);
    }

    //! Returns the token handle for a string, or kUninitializedToken if the token interface is unavailable.
    NameToken stringToToken(const char* tokenName) const
    {
        auto tokenInterface = carb::getCachedInterface<omni::fabric::IToken>();
        if (!tokenInterface)
        {
            CARB_LOG_ERROR_ONCE("Failed to initialize node type - no token interface");
            return omni::fabric::kUninitializedToken;
        }
        return tokenInterface->getHandle(tokenName);
    }

    //! Returns the text of a token, or nullptr if the token interface is unavailable.
    const char* tokenToString(NameToken token) const
    {
        auto tokenInterface = carb::getCachedInterface<omni::fabric::IToken>();
        if (!tokenInterface)
        {
            CARB_LOG_ERROR_ONCE("Failed to initialize node type - no token interface");
            return nullptr;
        }
        return tokenInterface->getText(token);
    }

    //! Returns the path handle for a string, or kUninitializedPath if the path interface is unavailable.
    TargetPath stringToPath(const char* pathString) const
    {
        auto pathInterface = carb::getCachedInterface<omni::fabric::IPath>();
        if (!pathInterface)
        {
            CARB_LOG_ERROR_ONCE("Failed to initialize node type - no path interface");
            return omni::fabric::kUninitializedPath;
        }
        return pathInterface->getHandle(pathString);
    }

    //! Returns the text of a path, or nullptr if the path interface is unavailable.
    const char* pathToString(const TargetPath path) const
    {
        auto pathInterface = carb::getCachedInterface<omni::fabric::IPath>();
        if (!pathInterface)
        {
            CARB_LOG_ERROR_ONCE("Failed to initialize node type - no path interface");
            return nullptr;
        }
        return pathInterface->getText(path);
    }

    //! Converts a path to a token with the same text, or returns kUninitializedToken if either the path or
    //! token interface is unavailable.
    NameToken pathToToken(const TargetPath path) const
    {
        auto pathInterface = carb::getCachedInterface<omni::fabric::IPath>();
        auto tokenInterface = carb::getCachedInterface<omni::fabric::IToken>();
        if (!pathInterface || !tokenInterface)
        {
            CARB_LOG_ERROR_ONCE("Failed to initialize node type - no path or token interface");
            return omni::fabric::kUninitializedToken;
        }
        return tokenInterface->getHandle(pathInterface->getText(path));
    }

    //! Converts a token to a path with the same text, or returns kUninitializedPath if either the path or
    //! token interface is unavailable.
    TargetPath tokenToPath(const NameToken pathString) const
    {
        auto pathInterface = carb::getCachedInterface<omni::fabric::IPath>();
        auto tokenInterface = carb::getCachedInterface<omni::fabric::IToken>();
        if (!pathInterface || !tokenInterface)
        {
            CARB_LOG_ERROR_ONCE("Failed to initialize node type - no path or token interface");
            return omni::fabric::kUninitializedPath;
        }
        return pathInterface->getHandle(tokenInterface->getText(pathString));
    }

    //! Returns the graph context for instance m_offset + relativeIdx. If that index is not below
    //! m_handleCounts, the first context is returned instead (the array is assumed to be non-empty).
    const GraphContextObj& abi_context(InstanceIndex relativeIdx = { 0 }) const
    {
        InstanceIndex idx = m_offset + relativeIdx;
        if (idx.index < m_handleCounts)
            return m_graphContextHandles[idx.index];
        return m_graphContextHandles[0];
    }

    //! Returns the node handle for instance m_offset + relativeIdx. If that index is not below
    //! m_handleCounts, the first node is returned instead (the array is assumed to be non-empty).
    const NodeObj& abi_node(InstanceIndex relativeIdx = { 0 }) const
    {
        InstanceIndex idx = m_offset + relativeIdx;
        if (idx.index < m_handleCounts)
            return m_nodeHandles[idx.index];
        return m_nodeHandles[0];
    }

    //! Returns the instance count reported by the graph that owns the current node.
    size_t getGraphTotalInstanceCount() const
    {
        NodeObj const& nodeObj = abi_node();
        GraphObj graphObj = nodeObj.iNode->getGraph(nodeObj);
        return graphObj.iGraph->getInstanceCount(graphObj);
    }

    //! Returns the node's user data pointer reinterpreted as UserDataType*. No type checking is performed.
    template <typename UserDataType>
    UserDataType* userData(InstanceIndex relativeIdx = { 0 }) const
    {
        NodeObj const& obj = abi_node(relativeIdx);
        return reinterpret_cast<UserDataType*>(obj.iNode->getUserData(obj));
    }

    //! Logs a compute message of severity sev on instance inst of nodeObj.
    //! With no format arguments fmt is forwarded as-is; otherwise it is expanded with formatString() first.
    template <typename... Args>
    static void logMessage(NodeObj const& nodeObj, InstanceIndex inst, Severity sev, const char* fmt, Args&&... args)
    {
        if (sizeof...(args) == 0)
        {
            nodeObj.iNode->logComputeMessageOnInstance(nodeObj, inst, sev, fmt);
        }
        else
        {
            std::string msg = formatString(fmt, args...);
            nodeObj.iNode->logComputeMessageOnInstance(nodeObj, inst, sev, msg.c_str());
        }
    }

    //! Logs an error on nodeObj, letting the context choose the instance (kAccordingToContextIndex).
    template <typename... Args>
    static void logError(NodeObj const& nodeObj, const char* fmt, Args&&... args)
    { logMessage(nodeObj, kAccordingToContextIndex, Severity::eError, fmt, args...); }

    //! Logs an error on a specific instance of nodeObj.
    template <typename... Args>
    static void logError(NodeObj const& nodeObj, InstanceIndex inst, const char* fmt, Args&&... args)
    { logMessage(nodeObj, inst, Severity::eError, fmt, args...); }

    //! Logs a warning on nodeObj, letting the context choose the instance (kAccordingToContextIndex).
    template <typename... Args>
    static void logWarning(NodeObj const& nodeObj, const char* fmt, Args&&... args)
    { logMessage(nodeObj, kAccordingToContextIndex, Severity::eWarning, fmt, args...); }

    //! Logs a warning on a specific instance of nodeObj.
    template <typename... Args>
    static void logWarning(NodeObj const& nodeObj, InstanceIndex inst, const char* fmt, Args&&... args)
    { logMessage(nodeObj, inst, Severity::eWarning, fmt, args...); }

    //! Logs an error on the current instance (m_offset) of this database's node.
    template <typename... Args>
    void logError(const char* fmt, Args&&... args) { logError(abi_node(), m_offset, fmt, args...); }

    //! Logs an error on instance m_offset + relativeIdx of this database's node.
    template <typename... Args>
    void logError(InstanceIndex relativeIdx, const char* fmt, Args&&... args) { logError(abi_node(relativeIdx), m_offset + relativeIdx, fmt, args...); }

    //! Logs a warning on the current instance (m_offset) of this database's node.
    template <typename... Args>
    void logWarning(const char* fmt, Args&&... args) { logWarning(abi_node(), m_offset, fmt, args...); }

    //! Logs a warning on instance m_offset + relativeIdx of this database's node.
    template <typename... Args>
    void logWarning(InstanceIndex relativeIdx, const char* fmt, Args&&... args) { logWarning(abi_node(relativeIdx), m_offset + relativeIdx, fmt, args...); }

    //! Looks up a graph variable by name token; see the const char* overload.
    VariableAttribute getVariable(NameToken token, InstanceIndex relativeIdx = { 0 })
    {
        return getVariable(tokenToString(token), relativeIdx);
    }

    //! Looks up a variable of the graph owning the node and wraps its data for instance m_offset + relativeIdx.
    //!
    //! @return The wrapped variable, or a default-constructed VariableAttribute if the name is null or no
    //!         variable with that name exists
    VariableAttribute getVariable(const char* variableName, InstanceIndex relativeIdx = { 0 })
    {
        // A null name (e.g. from tokenToString() when the token interface is missing) cannot match any variable.
        if (!variableName)
            return VariableAttribute();
        NodeObj const& obj = abi_node(relativeIdx);
        GraphContextObj const& ctx = abi_context(relativeIdx);
        auto graphObj = obj.iNode->getGraph(obj);
        auto variable = graphObj.iGraph->findVariable(graphObj, variableName);
        if (!variable)
            return VariableAttribute();
        auto handle = ctx.iContext->getVariableDataHandle(ctx, variable, m_offset + relativeIdx);
        return VariableAttribute(ctx, handle);
    }

    //! Returns the graph target of instance m_offset + relativeIdx.
    NameToken getGraphTarget(InstanceIndex relativeIdx = { 0 }) const
    {
        GraphContextObj const& ctx = abi_context(relativeIdx);
        return ctx.iContext->getGraphTarget(ctx, relativeIdx+m_offset);
    }

    //! Returns a view of `count` graph targets starting at the current instance.
    //! NOTE(review): relies on getGraphTarget() returning a reference into contiguous per-instance storage,
    //! because the span is built from the address of the first returned target — confirm against IGraphContext.
    gsl::span<NameToken const> getGraphTargets(size_t count) const
    {
        GraphContextObj const& ctx = abi_context();
        return { &ctx.iContext->getGraphTarget(ctx, m_offset), count };
    }


    //! Advances the current instance offset by one.
    inline void moveToNextInstance()
    {
        ++m_offset.index;
    }
    //! Resets the current instance offset to the first instance.
    inline void resetToFirstInstance()
    {
        m_offset = { 0 };
    }

    //! Returns the offset of the instance currently being processed.
    inline const InstanceIndex& getInstanceIndex() const
    {
        return m_offset;
    }

protected:

    //! Appends a wrapper (constructed from m_offset.index) for every dynamic attribute of the current node
    //! whose port type equals portType. Existing entries of dynamicAttributes are kept.
    //!
    //! @param staticAttributeCount  Number of attributes known statically; nothing is scanned unless the node
    //!                              has more attributes than this
    //! @param dynamicAttributes     Receives the wrappers
    //! @return true if at least one matching attribute was appended
    template <AttributePortType portType, typename TAttribute>
    bool tryGetDynamicAttributes(size_t staticAttributeCount, std::vector<TAttribute>& dynamicAttributes)
    {
        NodeObj const& obj = abi_node();
        GraphContextObj const& ctx = abi_context();
        auto totalAttributeCount = obj.iNode->getAttributeCount(obj);
        if (totalAttributeCount > staticAttributeCount)
        {
            dynamicAttributes.reserve(totalAttributeCount - staticAttributeCount);
            std::vector<AttributeObj> allAttributes(totalAttributeCount);
            obj.iNode->getAttributes(obj, allAttributes.data(), totalAttributeCount);

            bool foundAny = false;
            for (auto const& attr : allAttributes)
            {
                if (attr.iAttribute->isDynamic(attr) && attr.iAttribute->getPortType(attr) == portType)
                {
                    foundAny = true;
                    auto handle = attr.iAttribute->getAttributeDataHandle(attr, kAccordingToContextIndex);
                    dynamicAttributes.emplace_back(m_offset.index);
                    dynamicAttributes.back()().reset(ctx, handle, attr);
                }
            }

            return foundAny;
        }

        return false;
    }

    //! DynamicInput overload of tryGetDynamicAttributes(). Inputs expose a const attribute, so the wrapped
    //! attribute is const_cast back to InputAttribute& in order to bind it to its data handle.
    template <AttributePortType portType>
    bool tryGetDynamicAttributes(size_t staticAttributeCount, std::vector<DynamicInput>& dynamicAttributes)
    {
        NodeObj const& obj = abi_node();
        GraphContextObj const& ctx = abi_context();
        auto totalAttributeCount = obj.iNode->getAttributeCount(obj);
        if (totalAttributeCount > staticAttributeCount)
        {
            dynamicAttributes.reserve(totalAttributeCount - staticAttributeCount);
            std::vector<AttributeObj> allAttributes(totalAttributeCount);
            obj.iNode->getAttributes(obj, allAttributes.data(), totalAttributeCount);

            bool foundAny = false;
            for (auto const& attr : allAttributes)
            {
                if (attr.iAttribute->isDynamic(attr) &&
                    attr.iAttribute->getPortType(attr) == portType)
                {
                    foundAny = true;
                    auto handle = attr.iAttribute->getAttributeDataHandle(attr, kAccordingToContextIndex);
                    dynamicAttributes.emplace_back(m_offset.index);
                    const_cast<InputAttribute&>(dynamicAttributes.back()()).reset(ctx, handle, attr);
                }
            }

            return foundAny;
        }

        return false;
    }

    //! Appends a wrapper for a newly created dynamic attribute and binds it to the attribute's data handle.
    template<typename TAttribute>
    void onDynamicAttributeCreated(std::vector<TAttribute>& dynamicAttributes, AttributeObj const& attribute)
    {
        auto handle = attribute.iAttribute->getAttributeDataHandle(attribute, kAccordingToContextIndex);
        dynamicAttributes.emplace_back(m_offset.index);
        dynamicAttributes.back()().reset(abi_context(), handle, attribute);
    }

    //! DynamicInput variant of onDynamicAttributeCreated(); const_casts the wrapped input so it can be bound.
    void onDynamicInputsCreated(std::vector<DynamicInput>& dynamicInputs, AttributeObj const& attribute)
    {
        auto handle = attribute.iAttribute->getAttributeDataHandle(attribute, kAccordingToContextIndex);
        dynamicInputs.emplace_back(m_offset.index);
        const_cast<InputAttribute&>(dynamicInputs.back()()).reset(abi_context(), handle, attribute);
    }

    //! Removes the first wrapper whose data handle matches the given attribute's handle; no-op if none matches.
    template<typename TAttribute>
    void onDynamicAttributeRemoved(std::vector<TAttribute>& dynamicAttributes, AttributeObj const& attribute)
    {
        auto handle = attribute.iAttribute->getAttributeDataHandle(attribute, kAccordingToContextIndex);
        auto const match = std::find_if(dynamicAttributes.begin(), dynamicAttributes.end(),
                                        [&handle](TAttribute& entry) { return entry().abi_handle() == handle; });
        if (match != dynamicAttributes.end())
            dynamicAttributes.erase(match);
    }

    //! Dispatches a dynamic attribute creation/removal to the list matching the attribute's port type.
    //! Does nothing if the attribute has no interface or its port type is not input, output or state.
    void onDynamicAttributeCreatedOrRemoved(std::vector<DynamicInput>& inputs,
                                            std::vector<DynamicOutput>& outputs,
                                            std::vector<DynamicState>& states,
                                            AttributeObj const& attribute,
                                            bool isAttributeCreated)
    {
        if (! attribute.iAttribute) return;
        switch (attribute.iAttribute->getPortType(attribute))
        {
        case AttributePortType::kAttributePortType_Input:
        {
            if (isAttributeCreated)
            {
                onDynamicInputsCreated(inputs, attribute);
            }
            else
            {
                onDynamicAttributeRemoved(inputs, attribute);
            }
            break;
        }
        case AttributePortType::kAttributePortType_Output:
        {
            if (isAttributeCreated)
            {
                onDynamicAttributeCreated(outputs, attribute);
            }
            else
            {
                onDynamicAttributeRemoved(outputs, attribute);
            }
            break;
        }
        case AttributePortType::kAttributePortType_State:
        {
            if (isAttributeCreated)
            {
                onDynamicAttributeCreated(states, attribute);
            }
            else
            {
                onDynamicAttributeRemoved(states, attribute);
            }
            break;
        }
        default:
            break;
        }
    }

    //! Appends the name tokens of every node attribute that has a target mapping, then sorts the whole list.
    //! The list is not cleared first.
    void collectMappedAttributes(std::vector<NameToken>& mappedAttributes)
    {
        NodeObj const& obj = abi_node();
        auto totalAttributeCount = obj.iNode->getAttributeCount(obj);
        std::vector<AttributeObj> allAttributes(totalAttributeCount);
        obj.iNode->getAttributes(obj, allAttributes.data(), totalAttributeCount);
        for (auto const& attr : allAttributes)
        {
            NameToken mapping = attr.iAttribute->getTargetMapping(attr);
            if (mapping != fabric::kUninitializedToken)
                mappedAttributes.push_back(attr.iAttribute->getNameToken(attr));
        }
        // keep sorted so updateMappedAttributes() can binary-search
        std::sort(mappedAttributes.begin(), mappedAttributes.end());
    }

    //! Keeps the sorted mappedAttributes list in sync after attr's target mapping changed.
    //! The attribute's name is inserted if it is now mapped, or removed if it is no longer mapped.
    void updateMappedAttributes(std::vector<NameToken>& mappedAttributes, AttributeObj const& attr)
    {
        bool const isMapped = attr.iAttribute->getTargetMapping(attr) != fabric::kUninitializedToken;
        NameToken const token = attr.iAttribute->getNameToken(attr);

        // The list is sorted, so one binary search gives both the membership test and the insertion point.
        auto const found = std::lower_bound(mappedAttributes.begin(), mappedAttributes.end(), token);
        bool const alreadyPresent = (found != mappedAttributes.end() && *found == token);
        if (isMapped && !alreadyPresent)
        {
            // newly mapped (a mere change of mapping leaves it in place)
            mappedAttributes.insert(found, token);
        }
        else if (!isMapped && alreadyPresent)
        {
            // not mapped anymore: remove from list
            mappedAttributes.erase(found);
        }
    }

    //! Returns the result of INode::validateCompute() for the current node.
    bool validateNode() const
    {
        NodeObj const& nodeObj = abi_node();
        return nodeObj.iNode->validateCompute(nodeObj);
    }
};

//! Builds a read-only (input-style) runtime attribute that views the data of one of the node's outputs.
//!
//! @param db           Database whose current node, context and instance index are used
//! @param output       Output attribute whose resolved type is given to the new input view
//! @param outputToken  Name token of the output attribute on the node
//! @return An input runtime attribute bound to the output's const data handle for db's current instance
template <eMemoryType MemoryType>
static inline ogn::RuntimeAttribute<ogn::kOgnInput, MemoryType> constructInputFromOutput(
    OmniGraphDatabase const& db,
    ogn::RuntimeAttribute<ogn::kOgnOutput, MemoryType> const& output,
    NameToken outputToken)
{
    using InputView = ogn::RuntimeAttribute<ogn::kOgnInput, MemoryType>;

    // Resolve the output attribute on the database's current node.
    auto const& node = db.abi_node();
    auto const& graphContext = db.abi_context();
    auto const& outputAttribute = node.iNode->getAttributeByToken(node, outputToken);

    // A const handle is requested because the result is only ever read through the input view.
    auto readHandle =
        outputAttribute.iAttribute->getConstAttributeDataHandle(outputAttribute, db.getInstanceIndex());

    return InputView(graphContext, readHandle, output.type());
}

} // namespace ogn
} // namespace core
} // namespace graph
} // namespace omni