omni/graph/exec/unstable/IBase.h

File members: omni/graph/exec/unstable/IBase.h

// Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto. Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.
//

#pragma once

#include <omni/core/IObject.h>

namespace omni
{
namespace graph
{
namespace exec
{
namespace unstable
{

// forward declarations; ExecutionTask is not used in this header itself and is presumably needed by the omni.bind
// generated IBase.gen.h -- NOTE(review): confirm before removing
class IBase;
class IBase_abi;
class ExecutionTask;

//! ABI of the base interface shared by omni.graph.exec objects.
//!
//! Extends omni::core::IObject with a non-acquiring cast and a use-count query. The callable wrappers for these
//! functions come from the omni.bind generated header (IBase.gen.h).
//!
//! NOTE(review): the declaration order of the pure virtual functions below defines the vtable layout, so reordering
//! or inserting functions is presumably an ABI break.
class IBase_abi : public omni::core::Inherits<omni::core::IObject, OMNI_TYPE_ID("omni.graph.exec.unstable.IBase")>
{
protected:
    //! Returns a pointer to the interface identified by @p id, or nullptr if the object does not implement it.
    //!
    //! No reference is acquired on the returned object: the pointer is only valid while the caller otherwise keeps
    //! the object alive.
    virtual void* castWithoutAcquire_abi(omni::core::TypeId id) noexcept = 0;

    //! Returns the object's current use count. For objects built on Implements<> this is the reference count.
    virtual uint32_t getUseCount_abi() noexcept = 0;
};

} // namespace unstable
} // namespace exec
} // namespace graph
} // namespace omni

#define OMNI_BIND_INCLUDE_INTERFACE_DECL
#include <omni/graph/exec/unstable/IBase.gen.h>

namespace omni
{
namespace graph
{
namespace exec
{
namespace unstable
{

//! Base interface for objects in omni.graph.exec.
//!
//! Extends omni::core::IObject with castWithoutAcquire() and getUseCount(), whose ABI is declared by IBase_abi. The
//! callable (non-_abi) wrappers are presumably supplied by the omni.bind generated base
//! omni::core::Generated<IBase_abi>.
class IBase : public omni::core::Generated<omni::graph::exec::unstable::IBase_abi>
{
    // intentionally empty: all functionality is inherited from the generated base class
};

//! Casts @p ptr to the interface @p T without acquiring a reference.
//!
//! @tparam T   Target interface; must derive from IBase.
//! @param  ptr Object to query; may be nullptr.
//! @return The @p T interface of @p ptr, or nullptr if @p ptr is nullptr or does not implement @p T.
//!
//! The result is borrowed: no reference is added, so it is only valid while the caller keeps the object alive.
template <typename T, typename U>
inline T* cast(U* ptr) noexcept
{
    static_assert(std::is_base_of<IBase, T>::value, "cast can only be used with classes that derive from IBase");
    if (!ptr)
    {
        return nullptr;
    }

    // castWithoutAcquire() returns the address of the matching sub-object as void*; static_cast is the designated
    // (and here equivalent) void* -> T* conversion, unlike reinterpret_cast
    return static_cast<T*>(ptr->castWithoutAcquire(T::kTypeId));
}

//! Casts the object referenced by @p ptr to the interface @p T without acquiring a reference.
//!
//! @tparam T   Target interface; must derive from IBase.
//! @param  ptr Object to query; may be empty.
//! @return The @p T interface of the object, or nullptr if @p ptr is empty or the object does not implement @p T.
//!
//! The result is borrowed: no reference is added, so it is only valid while the caller keeps the object alive.
template <typename T, typename U>
inline T* cast(omni::core::ObjectParam<U> ptr) noexcept
{
    static_assert(std::is_base_of<IBase, T>::value, "cast can only be used with classes that derive from IBase");
    if (!ptr)
    {
        return nullptr;
    }

    // castWithoutAcquire() returns the address of the matching sub-object as void*; static_cast is the designated
    // (and here equivalent) void* -> T* conversion, unlike reinterpret_cast
    return static_cast<T*>(ptr->castWithoutAcquire(T::kTypeId));
}

template <typename T, typename U>
inline T* cast(omni::core::ObjectPtr<U> ptr) noexcept
{
    static_assert(std::is_base_of<IBase, T>::value, "cast can only be used with classes that derive from IBase");
    if (ptr)
    {
        return reinterpret_cast<T*>(ptr->castWithoutAcquire(T::kTypeId));
    }
    else
    {
        return nullptr;
    }
}

#ifndef DOXYGEN_BUILD
namespace details
{
// forward declaration so ImplementsCastWithoutAcquire can call details::castWithoutAcquire before the primary
// template and its IBase specialization are defined later in this header. keep this signature identical to that
// definition.
template <typename T>
inline void* castWithoutAcquire(T* obj, omni::core::TypeId id) noexcept; // forward declaration
} // namespace details
#endif

//! Helper base for implementing IBase-derived interfaces: supplies cast_abi() and castWithoutAcquire_abi() for the
//! interface @p T and every interface in @p Rest.
//!
//! Lookups try T's inheritance chain first, then each interface in Rest in declaration order; the first match wins.
//! Reference counting (acquire_abi(), release_abi(), getUseCount_abi()) is not implemented here; Implements, which
//! derives from this struct, adds it.
template <typename T, typename... Rest>
struct ImplementsCastWithoutAcquire : public T, public Rest...
{
public:
    //! Forwards to T's cast().
    inline void* cast(omni::core::TypeId id) noexcept
    {
        // note: T and each interface in Rest inherit their own `cast`, so an unqualified call would be ambiguous when
        // using multiple inheritance. forwarding to T's disambiguates it. it has zero-overhead.
        return static_cast<T*>(this)->cast(id);
    }

    //! Forwards to T's castWithoutAcquire(); no reference is acquired on the result.
    inline void* castWithoutAcquire(omni::core::TypeId id) noexcept
    {
        // note: T and each interface in Rest inherit their own `castWithoutAcquire`, so an unqualified call would be
        // ambiguous when using multiple inheritance. forwarding to T's disambiguates it. it has zero-overhead.
        return static_cast<T*>(this)->castWithoutAcquire(id);
    }

private:
    // given a type id, castImpl() checks whether it matches U's type id. if not, U's parent class type id is checked,
    // then the grandparent class's type id, and so on until IObject's type id is checked.
    //
    // if no type id in U's inheritance chain matches, the next interface in Args is checked.
    //
    // it's expected the compiler can optimize away the recursion
    template <typename U, typename... Args>
    inline void* castImpl(omni::core::TypeId id) noexcept
    {
        // omni::core::detail::cast will march down the inheritance chain
        void* obj = omni::core::detail::cast<U>(this, id);
        if (nullptr == obj)
        {
            // check the next class (inheritance chain) provided in the inheritance list
            return castImpl<Args...>(id);
        }

        return obj;
    }

    // given a type id, castWithoutAcquireImpl() checks whether it matches U's type id. if not, U's parent class type
    // id is checked, then the grandparent class's type id, and so on until IBase's type id is checked. the
    // details::castWithoutAcquire<IBase> specialization ends the walk, so IObject's type id never matches here.
    //
    // if no type id in U's inheritance chain matches, the next interface in Args is checked.
    //
    // it's expected the compiler can optimize away the recursion
    template <typename U, typename... Args>
    inline void* castWithoutAcquireImpl(omni::core::TypeId id) noexcept
    {
        // details::castWithoutAcquire will march down the inheritance chain (without acquiring a reference)
        void* obj = details::castWithoutAcquire<U>(this, id);
        if (nullptr == obj)
        {
            // check the next class (inheritance chain) provided in the inheritance list
            return castWithoutAcquireImpl<Args...>(id);
        }

        return obj;
    }

    // this terminates walking across the types in the variadic template. once Args is empty, castImpl<>() cannot
    // select the overload above (it needs at least one type argument), so this overload is chosen.
    // LCOV_EXCL_START
    template <int = 0>
    inline void* castImpl(omni::core::TypeId) noexcept
    {
        return nullptr;
    }
    // LCOV_EXCL_STOP

    // this terminates walking across the types in the variadic template (same overload trick as castImpl above)
    template <int = 0>
    inline void* castWithoutAcquireImpl(omni::core::TypeId) noexcept
    {
        return nullptr;
    }

protected:
    virtual ~ImplementsCastWithoutAcquire() noexcept = default;

    //! Returns the first interface in T's or Rest's inheritance chains whose type id is @p id (found via
    //! omni::core::detail::cast), or nullptr if none match.
    void* cast_abi(omni::core::TypeId id) noexcept override
    {
        return castImpl<T, Rest...>(id);
    }

    //! Like cast_abi(), but walks each chain only up to IBase (via details::castWithoutAcquire) and acquires no
    //! reference on the result.
    void* castWithoutAcquire_abi(omni::core::TypeId id) noexcept override
    {
        return castWithoutAcquireImpl<T, Rest...>(id);
    }
};

//! Helper base for implementing IBase-derived interfaces.
//!
//! Adds atomic, intrusive reference counting on top of ImplementsCastWithoutAcquire, which supplies cast_abi() and
//! castWithoutAcquire_abi().
//!
//! @tparam T    The primary interface; acquire(), release() and getUseCount() are routed through it.
//! @tparam Rest Additional interfaces the object implements.
//!
//! The reference count starts at 1. The object deletes itself when release() drops the count to 0, so instances must
//! be allocated with `new`.
template <typename T, typename... Rest>
struct Implements : public ImplementsCastWithoutAcquire<T, Rest...>
{
public:
    //! Adds a reference to this object.
    inline void acquire() noexcept
    {
        // note: with several base interfaces, each inherits its own `acquire`, so an unqualified call would be
        // ambiguous. forwarding to T's disambiguates it. it has zero-overhead.
        static_cast<T*>(this)->acquire();
    }

    //! Removes a reference from this object, destroying it when the last reference is released.
    inline void release() noexcept
    {
        // note: with several base interfaces, each inherits its own `release`, so an unqualified call would be
        // ambiguous. forwarding to T's disambiguates it. it has zero-overhead.
        static_cast<T*>(this)->release();
    }

    //! Returns the current reference count.
    inline uint32_t getUseCount() noexcept
    {
        // note: with several base interfaces, each inherits its own `getUseCount`, so an unqualified call would be
        // ambiguous. forwarding to T's disambiguates it. it has zero-overhead.
        return static_cast<T*>(this)->getUseCount();
    }

protected:
    //! Reference count; starts at 1 so a newly constructed object holds one reference.
    std::atomic<uint32_t> m_refCount{ 1 };

    virtual ~Implements() noexcept = default;

    void acquire_abi() noexcept override
    {
        // an increment only needs atomicity: the caller must already hold a reference, so the count cannot reach
        // zero concurrently
        m_refCount.fetch_add(1, std::memory_order_relaxed);
    }

    void release_abi() noexcept override
    {
        // fetch_sub() returns the previous value, so 1 means this call dropped the last reference. the release
        // ordering publishes this thread's writes to the object. the acquire fence makes every other releasing
        // thread's writes visible before the object is destroyed.
        if (1 == m_refCount.fetch_sub(1, std::memory_order_release))
        {
            std::atomic_thread_fence(std::memory_order_acquire);
            delete this;
        }
    }

    uint32_t getUseCount_abi() noexcept override
    {
        // a snapshot: other threads may acquire/release immediately after this load
        return m_refCount.load();
    }
};

#ifndef DOXYGEN_BUILD
namespace details
{
//! Walks T's inheritance chain (T, T::BaseType, T::BaseType::BaseType, ...) looking for the interface whose type id is
//! @p id. Returns the matching sub-object of @p obj without acquiring a reference. Returns nullptr if no interface up
//! to and including IBase matches.
template <typename T>
inline void* castWithoutAcquire(T* obj, omni::core::TypeId id) noexcept
{
    // on a mismatch, retry with the parent interface; `obj` converts implicitly to a pointer to that base
    return (id == T::kTypeId) ? static_cast<void*>(obj) : castWithoutAcquire<typename T::BaseType>(obj, id);
}

//! Recursion terminator: the walk stops at IBase, so IBase's own parents (e.g. IObject) are never matched.
template <>
inline void* castWithoutAcquire<IBase>(IBase* obj, omni::core::TypeId id) noexcept
{
    return (id == IBase::kTypeId) ? static_cast<void*>(obj) : nullptr;
}
} // namespace details
#endif

//! Returns the use count reported by @p ptr, or 0 when @p ptr is nullptr.
//!
//! The value may be stale if other threads acquire or release the object concurrently.
template <typename T>
inline uint32_t useCount(T* ptr) noexcept
{
    static_assert(std::is_base_of<IBase, T>::value, "useCount can only be used with classes that derive from IBase");
    return (nullptr == ptr) ? 0u : ptr->getUseCount();
}

} // namespace unstable
} // namespace exec
} // namespace graph
} // namespace omni

#define OMNI_BIND_INCLUDE_INTERFACE_IMPL
#include <omni/graph/exec/unstable/IBase.gen.h>