// omni/core/IObject.h
//
// File members: omni/core/IObject.h

// Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto. Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.
//

#pragma once

#include "Assert.h"
#include "Result.h"
#include "TypeId.h"
#include "OmniAttr.h"

#include <atomic>
#include <climits> // CHAR_BIT
#include <cstddef> // std::nullptr_t
#include <cstdint> // uint32_t
#include <functional> // std::less
#include <type_traits>
#include <utility> // std::exchange

namespace omni
{

namespace core
{

// This header assumes 8-bit chars; refuse to compile on any platform where that does not hold.
static_assert(8 == CHAR_BIT, "non-octet char is not supported");

// Forward declarations of the ABI class and its public wrapper.
class IObject;
class IObject_abi;

//! ABI of the root interface: interface casting plus reference counting.
//!
//! All three methods are protected pure virtuals; they are implemented by Implements / ImplementsCast later in this
//! header. No destructor is declared here: objects are destroyed through release_abi() (Implements does this with
//! `delete this` when the count reaches zero).
//!
//! NOTE(review): OMNI_ATTR("no_py") presumably tells the binding generator not to emit Python bindings, and because
//! this is an ABI boundary the declaration order of the virtual methods (vtable layout) should be treated as frozen.
class OMNI_ATTR("no_py") IObject_abi
{
public:
#ifndef DOXYGEN_SHOULD_SKIP_THIS
    //! Type id of this interface, derived from the name "omni.core.IObject" by OMNI_TYPE_ID.
    enum : TypeId
    {
        kTypeId = OMNI_TYPE_ID("omni.core.IObject")
    };
#endif

protected:
    //! Returns a pointer to the interface identified by @p id, or nullptr if the object does not implement it.
    //! On success the returned pointer already carries a reference (see detail::cast), which the caller must release.
    virtual void* cast_abi(TypeId id) noexcept = 0;

    //! Adds a reference to the object (Implements increments its atomic m_refCount).
    virtual void acquire_abi() noexcept = 0;

    //! Drops a reference; Implements deletes the object when the last reference is released.
    virtual void release_abi() noexcept = 0;
};

} // namespace core
} // namespace omni

#define OMNI_BIND_INCLUDE_INTERFACE_DECL
#include "IObject.gen.h"

namespace omni
{
namespace core
{

//! Root interface: detail::cast() stops walking an interface's inheritance chain when it reaches IObject.
//!
//! All members come from Generated<IObject_abi>, which is declared in "IObject.gen.h" (included above with
//! OMNI_BIND_INCLUDE_INTERFACE_DECL). NOTE(review): presumably Generated<> supplies the non-virtual cast(),
//! acquire() and release() wrappers used throughout this header — confirm against the generated file.
class IObject : public omni::core::Generated<omni::core::IObject_abi>
{
};

//! Helper base for declaring an interface: derives from @p BASE and records the new interface's type id.
//!
//! detail::cast() compares kTypeId against the requested id and, on mismatch, continues with BaseType, so every
//! interface declared through Inherits can be found by cast().
//!
//! @tparam BASE   The parent interface.
//! @tparam TYPEID The type id of the interface being declared.
template <typename BASE, TypeId TYPEID>
class Inherits : public BASE
{
public:
#ifndef DOXYGEN_BUILD
    //! The parent interface; used to walk up the inheritance chain.
    using BaseType = BASE;

    //! The type id of the interface being declared.
    enum : TypeId
    {
        kTypeId = TYPEID
    };
#endif
};

#ifndef DOXYGEN_BUILD
namespace detail
{
//! Tag type selecting the ObjectPtr constructor that adds a reference to the wrapped pointer.
//! Its default constructor is explicit, mirroring standard tag types such as std::in_place_t.
class BorrowPtrType
{
public:
    constexpr explicit BorrowPtrType() noexcept = default;
};
} // namespace detail
#endif

//! Tag: the ObjectPtr calls acquire() on the given pointer; the caller keeps its own reference.
constexpr detail::BorrowPtrType kBorrow = detail::BorrowPtrType{};

#ifndef DOXYGEN_BUILD
namespace detail
{
//! Tag type selecting the ObjectPtr constructor that adopts an existing reference without calling acquire().
//! Its default constructor is explicit, mirroring standard tag types such as std::in_place_t.
class StealPtrType
{
public:
    constexpr explicit StealPtrType() noexcept = default;
};
} // namespace detail
#endif

//! Tag: the ObjectPtr takes over a reference the caller already owns; no acquire() is performed.
constexpr detail::StealPtrType kSteal = detail::StealPtrType{};

//! Smart pointer that manages a reference-counted omni object through the object's acquire()/release() methods.
//!
//! Copying adds a reference (acquire()); destruction, reset(), release() and assignment drop the held reference
//! (release()). Raw pointers are adopted either with kSteal (take over a reference the caller already owns) or with
//! kBorrow (add a new reference).
//!
//! Every operation that replaces the held object installs the new pointer *before* releasing the old one. Releasing
//! first is unsafe: the old object may own the last reference to the new one (e.g. `node = node->next` or
//! `node = std::move(node->next)`), and releasing it would destroy the object about to be stored.
//!
//! @tparam T The interface type pointed to; it must provide acquire() and release().
template <typename T>
class ObjectPtr
{
public:
    //! Constructs an empty pointer. Implicit so that `nullptr` converts to an empty ObjectPtr.
    constexpr ObjectPtr(std::nullptr_t = nullptr) noexcept
    {
    }

    //! Wraps @p other and calls acquire() on it (if non-null); the caller keeps its own reference.
    ObjectPtr(T* other, detail::BorrowPtrType) noexcept : m_ptr(other)
    {
        addRef();
    }

    //! Adopts a reference the caller already holds on @p other; acquire() is not called.
    constexpr ObjectPtr(T* other, detail::StealPtrType) noexcept : m_ptr(other)
    {
    }

    //! Copy constructor: shares @p other's object and adds one reference.
    ObjectPtr(const ObjectPtr& other) noexcept : m_ptr(other.m_ptr)
    {
        addRef();
    }

    //! Converting copy constructor; U* must be implicitly convertible to T*.
    template <typename U>
    ObjectPtr(const ObjectPtr<U>& other) noexcept : m_ptr(other.m_ptr)
    {
        addRef();
    }

    //! Move constructor: takes over @p other's reference and leaves @p other empty.
    ObjectPtr(ObjectPtr&& other) noexcept : m_ptr(std::exchange(other.m_ptr, {}))
    {
    }

    //! Converting move constructor: takes over @p other's reference and leaves @p other empty.
    template <typename U>
    ObjectPtr(ObjectPtr<U>&& other) noexcept : m_ptr(std::exchange(other.m_ptr, {}))
    {
    }

    //! Releases the held reference, if any.
    ~ObjectPtr() noexcept
    {
        releaseRef();
    }

    //! Copy assignment: shares @p other's object. Safe even if the current object owns @p other.
    ObjectPtr& operator=(const ObjectPtr& other) noexcept
    {
        copyRef(other.m_ptr);
        return *this;
    }

    //! Move assignment: takes over @p other's reference and leaves @p other empty.
    ObjectPtr& operator=(ObjectPtr&& other) noexcept
    {
        if (this != &other)
        {
            // detach from `other` before releasing our object, which may be what keeps `other` alive
            replaceRef(std::exchange(other.m_ptr, {}));
        }

        return *this;
    }

    //! Converting copy assignment; U* must be implicitly convertible to T*.
    template <typename U>
    ObjectPtr& operator=(const ObjectPtr<U>& other) noexcept
    {
        copyRef(other.m_ptr);
        return *this;
    }

    //! Converting move assignment: takes over @p other's reference and leaves @p other empty.
    template <typename U>
    ObjectPtr& operator=(ObjectPtr<U>&& other) noexcept
    {
        replaceRef(std::exchange(other.m_ptr, {}));
        return *this;
    }

    //! Returns true if a pointer is held.
    explicit operator bool() const noexcept
    {
        return m_ptr != nullptr;
    }

    //! Member access on the held object. No null check is performed.
    T* operator->() const noexcept
    {
        return m_ptr;
    }

    //! Dereferences the held object. No null check is performed.
    T& operator*() const noexcept
    {
        return *m_ptr;
    }

    //! Returns the raw pointer without changing the reference count.
    T* get() const noexcept
    {
        return m_ptr;
    }

    //! Returns the address of the internal pointer so a callee can write an owned pointer into it.
    //! Asserts that this ObjectPtr is currently empty (otherwise the held reference would leak).
    T** put() noexcept
    {
        OMNI_ASSERT(m_ptr == nullptr);
        return &m_ptr;
    }

    //! Adopts a reference the caller already owns on @p value, releasing the previously held object.
    void steal(T* value) noexcept
    {
        replaceRef(value);
    }

    //! Gives up ownership without calling release(); the caller now owns the returned reference.
    T* detach() noexcept
    {
        return std::exchange(m_ptr, {});
    }

    //! Holds @p value, adding a reference to it, and releases the previously held object.
    //! Safe even when the previously held object is what keeps @p value alive.
    void borrow(T* value) noexcept
    {
        reset(value);
    }

    //! Queries the held object for interface @p To via cast(To::kTypeId).
    //! Returns an empty pointer if this pointer is empty or the object does not implement @p To.
    template <typename To>
    ObjectPtr<To> as() const noexcept
    {
        if (!m_ptr)
        {
            return nullptr; // dynamic_cast allows a nullptr, so we do as well
        }
        else
        {
            // cast() returns an already-acquired pointer, so it is stolen rather than borrowed
            return ObjectPtr<To>(reinterpret_cast<To*>(m_ptr->cast(To::kTypeId)), kSteal);
        }
    }

    //! Same as as<To>(), but stores the result in @p to (releasing whatever @p to held before).
    template <typename To>
    void as(ObjectPtr<To>& to) const noexcept
    {
        if (!m_ptr)
        {
            to.steal(nullptr); // dynamic_cast allows a nullptr, so we do as well
        }
        else
        {
            to.steal(reinterpret_cast<To*>(m_ptr->cast(To::kTypeId)));
        }
    }

    //! Releases the held reference (if any) and leaves this pointer empty.
    void release() noexcept
    {
        releaseRef();
    }

    //! Holds @p value (adding a reference to it, if non-null) and releases the previously held object.
    void reset(T* value = nullptr) noexcept
    {
        acquireRef(value);
        replaceRef(value);
    }

private:
    // Adds a reference to `ptr` if it is non-null. acquire() is invoked through a non-const pointer so that
    // ObjectPtr<const I> can still manage the object's lifetime.
    static void acquireRef(T* ptr) noexcept
    {
        if (ptr)
        {
            const_cast<std::remove_const_t<T>*>(ptr)->acquire();
        }
    }

    // Installs `value`, whose reference this object now owns, and only then releases the previously held object.
    // m_ptr already holds the new value while the old object's release() runs, so code executed during that release
    // (which may destroy the old object) never observes a dangling pointer here.
    void replaceRef(T* value) noexcept
    {
        T* const old = std::exchange(m_ptr, value);
        if (old)
        {
            old->release();
        }
    }

    // Shares `other`: acquire the new object first, because the old one may own the last reference to `other`.
    void copyRef(T* other) noexcept
    {
        if (m_ptr != other)
        {
            acquireRef(other);
            replaceRef(other);
        }
    }

    // Adds a reference to the currently held object (if any).
    void addRef() const noexcept
    {
        acquireRef(m_ptr);
    }

    // Empties this pointer, then releases the object it held (if any).
    void releaseRef() noexcept
    {
        if (m_ptr)
        {
            std::exchange(m_ptr, {})->release();
        }
    }

    template <typename U>
    friend class ObjectPtr;

    T* m_ptr{};
};

// Breathe/Sphinx cannot handle these overloads and emits warnings for them. Since we don't want warnings, these
// overloads are hidden from the docs until Breathe/Sphinx is updated.
#ifndef DOXYGEN_SHOULD_SKIP_THIS

//! Orders two ObjectPtrs by the address of the object they point to.
//!
//! Uses std::less instead of the built-in `<`: comparing pointers to unrelated objects with `<` gives an unspecified
//! result, whereas std::less<T*> guarantees a strict total order, which ordered containers such as std::set require.
template <typename T>
inline bool operator<(const ObjectPtr<T>& left, const ObjectPtr<T>& right) noexcept
{
    return std::less<T*>()(left.get(), right.get());
}

// Equality: two ObjectPtrs (or an ObjectPtr and a raw pointer / nullptr) are equal when they hold the same address.

template <typename T>
inline bool operator==(const ObjectPtr<T>& left, const ObjectPtr<T>& right) noexcept
{
    return (left.get() == right.get());
}

template <typename T>
inline bool operator==(const ObjectPtr<T>& left, const T* right) noexcept
{
    return (left.get() == right);
}

template <typename T>
inline bool operator==(const T* left, const ObjectPtr<T>& right) noexcept
{
    return (left == right.get());
}

template <typename T>
inline bool operator==(const ObjectPtr<T>& left, std::nullptr_t) noexcept
{
    return (left.get() == nullptr);
}

template <typename T>
inline bool operator==(std::nullptr_t, const ObjectPtr<T>& right) noexcept
{
    return (right.get() == nullptr);
}

// Inequality: the negation of the corresponding equality comparison.

template <typename T>
inline bool operator!=(const ObjectPtr<T>& left, const ObjectPtr<T>& right) noexcept
{
    return (left.get() != right.get());
}

template <typename T>
inline bool operator!=(const ObjectPtr<T>& left, const T* right) noexcept
{
    return (left.get() != right);
}

template <typename T>
inline bool operator!=(const T* left, const ObjectPtr<T>& right) noexcept
{
    return (left != right.get());
}

template <typename T>
inline bool operator!=(const ObjectPtr<T>& left, std::nullptr_t) noexcept
{
    return (left.get() != nullptr);
}

template <typename T>
inline bool operator!=(std::nullptr_t, const ObjectPtr<T>& right) noexcept
{
    return (right.get() != nullptr);
}

#endif // DOXYGEN_SHOULD_SKIP_THIS

//! Wraps @p ptr without calling acquire(): the returned ObjectPtr adopts the reference the caller already owns.
template <typename T>
inline ObjectPtr<T> steal(T* ptr) noexcept
{
    return { ptr, kSteal };
}

//! Wraps @p ptr and calls acquire() on it (if non-null): the caller keeps its own reference.
template <typename T>
inline ObjectPtr<T> borrow(T* ptr) noexcept
{
    return { ptr, kBorrow };
}

//! Queries @p ptr for interface @p T and returns the result as an owning ObjectPtr<T>.
//!
//! On success, cast() has already added a reference (see detail::cast), so the result adopts that reference with
//! kSteal instead of acquiring again. Returns an empty pointer when @p ptr is null or does not implement @p T.
template <typename T, typename U>
inline ObjectPtr<T> cast(U* ptr) noexcept
{
    static_assert(std::is_base_of<IObject, T>::value, "cast can only be used with classes that derive from IObject");

    // a null input yields a null result rather than a crash
    if (!ptr)
    {
        return nullptr;
    }

    auto* const found = ptr->cast(T::kTypeId);
    return ObjectPtr<T>(reinterpret_cast<T*>(found), kSteal);
}

#ifndef DOXYGEN_BUILD
namespace detail
{
//! Walks T's inheritance chain (T, T::BaseType, ..., IObject) looking for the interface whose kTypeId equals @p id.
//! Declared here because ImplementsCast::castImpl() calls it; the definition, including the IObject specialization
//! that ends the walk, appears after Implements. The signature must stay identical to that definition.
template <typename T>
inline void* cast(T* obj, TypeId id) noexcept; // forward declaration
} // namespace detail
#endif

//! Implements cast_abi() for an object implementing the interfaces <T, Rest...>.
//!
//! @tparam T    The primary interface; used to disambiguate cast().
//! @tparam Rest Additional implemented interfaces.
template <typename T, typename... Rest>
struct ImplementsCast : public T, public Rest...
{
public:
    //! Calls the primary interface's cast().
    //!
    //! With several interfaces in <T, Rest...>, each base brings its own cast(), so an unqualified call would be
    //! ambiguous; routing the call through T resolves that. This is only a static up-cast and adds no overhead.
    inline void* cast(omni::core::TypeId id) noexcept
    {
        T* const primary = this;
        return primary->cast(id);
    }

protected:
    virtual ~ImplementsCast() noexcept = default;

    //! Tries each implemented interface in the order <T, Rest...> and returns the first match (already acquired by
    //! detail::cast()), or nullptr if none matches.
    void* cast_abi(TypeId id) noexcept override
    {
        return castImpl<T, Rest...>(id);
    }

private:
    // Checks Head's inheritance chain (Head, Head::BaseType, ..., IObject) for `id` via detail::cast(). If nothing in
    // that chain matches, moves on to the next interface in Tail. The recursion is over template arguments, so the
    // compiler is expected to flatten it.
    template <typename Head, typename... Tail>
    inline void* castImpl(TypeId id) noexcept
    {
        void* const match = omni::core::detail::cast<Head>(this, id);
        if (match != nullptr)
        {
            return match;
        }
        return castImpl<Tail...>(id);
    }

    // Ends the recursion once Tail is empty: `castImpl<>(id)` can only select this overload, because the overload
    // above requires at least one type argument.
    template <int = 0>
    inline void* castImpl(TypeId) noexcept
    {
        return nullptr;
    }
};

//! Implements casting (via ImplementsCast) and atomic reference counting for the interfaces <T, Rest...>.
//!
//! A newly constructed object starts with a reference count of 1, owned by its creator. When the count drops to zero,
//! the object deletes itself.
template <typename T, typename... Rest>
struct Implements : public ImplementsCast<T, Rest...>
{
public:
    //! Adds a reference to this object.
    //!
    //! When several interfaces are implemented, each base brings its own acquire(); routing the call through the
    //! primary interface T makes it unambiguous. This is only a static up-cast and adds no overhead.
    inline void acquire() noexcept
    {
        T* const primary = this;
        primary->acquire();
    }

    //! Drops a reference to this object (it may be destroyed by this call).
    //!
    //! Routed through the primary interface T for the same disambiguation reason as acquire().
    inline void release() noexcept
    {
        T* const primary = this;
        primary->release();
    }

protected:
    //! Current number of references; the creator owns the initial one.
    std::atomic<uint32_t> m_refCount{ 1 };

    virtual ~Implements() noexcept = default;

    //! Increments the reference count. Relaxed ordering is enough: taking an additional reference needs no
    //! synchronization with other threads.
    void acquire_abi() noexcept override
    {
        m_refCount.fetch_add(1, std::memory_order_relaxed);
    }

    //! Decrements the reference count and deletes the object when the last reference is dropped.
    void release_abi() noexcept override
    {
        // Release ordering publishes this thread's writes to the object before the count drops. The thread that
        // drops the final reference issues an acquire fence, so it observes all of those writes before destroying.
        if (m_refCount.fetch_sub(1, std::memory_order_release) == 1)
        {
            std::atomic_thread_fence(std::memory_order_acquire);
            delete this;
        }
    }
};

#ifndef DOXYGEN_BUILD
namespace detail
{
//! Returns @p obj (after calling acquire() on it) if T's type id is @p id; otherwise retries with T's parent
//! interface (T::BaseType). The IObject specialization below ends the walk.
template <typename T>
inline void* cast(T* obj, TypeId id) noexcept
{
    if (id != T::kTypeId)
    {
        // not this interface: continue with the parent interface
        return cast<typename T::BaseType>(obj, id);
    }

    // match: the caller receives an interface pointer, so it must carry its own reference
    obj->acquire();
    return obj;
}

//! End of every inheritance chain: matches IObject's own type id, otherwise reports "not implemented" (nullptr).
template <>
inline void* cast<IObject>(IObject* obj, TypeId id) noexcept
{
    if (id != IObject::kTypeId)
    {
        return nullptr;
    }

    obj->acquire();
    return obj;
}
} // namespace detail
#endif

} // namespace core
} // namespace omni

#define OMNI_BIND_INCLUDE_INTERFACE_IMPL
#include "IObject.gen.h"