carb/tasking/TaskingHelpers.h

File members: carb/tasking/TaskingHelpers.h

// Copyright (c) 2020-2024, NVIDIA CORPORATION. All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto. Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.
//

#pragma once

#include "TaskingTypes.h"

#include "../thread/Futex.h"
#include "../cpp/Functional.h"
#include "../cpp/Optional.h"
#include "../cpp/Variant.h"
#include "../Memory.h"

#include <atomic>
#include <chrono>
#include <future>
#include <iterator>
#include <vector>

namespace carb
{
namespace tasking
{

#ifndef DOXYGEN_BUILD
namespace detail
{

// Deleter functor (usable with std::unique_ptr) for objects whose memory must be handed back through
// carb::deallocate(). The destructor of T is invoked explicitly first. The pointer is not checked for null.
template <class T>
struct CarbDeleter
{
    // Destroys `*object` in place and then releases its storage.
    void operator()(T* object) noexcept
    {
        object->~T();
        carb::deallocate(object);
    }
};

// Destroys the object at `p` and returns its memory through carb::deallocate(). A null `p` is ignored.
// U is T with cv-qualifiers stripped, so pointers to const/volatile objects are accepted and the qualifier is cast away
// only for the deallocation call.
template <class T, class U = std::remove_cv_t<T>>
inline void carbDelete(T* p) noexcept
{
    if (!p)
    {
        return;
    }

    p->~T();
    carb::deallocate(const_cast<U*>(p));
}

// Reports a future/promise error and never returns.
//   err         - the std::future_errc being reported
//   explanation - text included in the fatal message when exceptions are disabled
// When CARB_EXCEPTIONS_ENABLED is non-zero a std::future_error built from `err` is thrown (the explanation is unused);
// otherwise the error code and explanation are passed to CARB_FATAL_UNLESS with a condition that is always false.
[[noreturn]] inline void future_error(std::future_errc err, const char* explanation)
{
#    if !(CARB_EXCEPTIONS_ENABLED)
    CARB_FATAL_UNLESS(false, "future_errc(%d): %s", int(err), explanation);
#    else
    CARB_UNUSED(explanation);
    throw std::future_error(err);
#    endif
}

template <class T>
inline void check_future_valid(T* me)
{
    CARB_UNLIKELY_IF(!me->valid())
    {
        future_error(std::future_errc::no_state, "no state");
    }
}

// Trait whose `value` is true only when T is exactly `const char (&)[N]`, which is the type a forwarding reference
// deduces when it is handed a string literal. Note that a named `const char[N]` array lvalue deduces to the same type
// and is therefore also reported as a literal. Every other T (including `const char*`) yields false.
template <class T>
struct is_literal_string
{
    static constexpr bool value{ false };
};

// Matching case: reference to an array of N const chars.
template <size_t N>
struct is_literal_string<const char (&)[N]>
{
    static constexpr bool value{ true };
};

// Sentinel Counter pointer with every bit set (same value as size_t(-1)); it is never dereferenced in this file.
// Trackers::fill() returns it whenever at least one tracker is present.
Counter* const kListOfCounters{ reinterpret_cast<Counter*>(~size_t(0)) };

// Converts an arbitrary std::chrono duration into a nanosecond count.
// The duration is first passed through thread::detail::clampDuration() (defined elsewhere; presumably guards against
// overflow - confirm there), converted to nanoseconds, and negative results are raised to zero before the value is
// returned as an unsigned 64-bit integer.
template <class Rep, class Period>
uint64_t convertDuration(const std::chrono::duration<Rep, Period>& dur)
{
    using Nanos = std::chrono::nanoseconds;
    const Nanos clamped = std::chrono::duration_cast<Nanos>(thread::detail::clampDuration(dur));
    return uint64_t(::carb_max(Nanos::rep(0), clamped.count()));
}

// Converts an absolute time point into the number of nanoseconds remaining from Clock::now() until `tp`.
// The difference is handed to convertDuration(), so a time point that already lies in the past yields 0.
template <class Clock, class Duration>
uint64_t convertAbsTime(const std::chrono::time_point<Clock, Duration>& tp)
{
    const auto remaining = tp - Clock::now();
    return convertDuration(remaining);
}

// Implementation detail of applyExtra(): invokes `f` through cpp::invoke with the elements of tuple `t` (selected by
// the index pack I...) as the leading arguments, followed by `args...`. Everything is perfectly forwarded and the
// result of the invocation is returned with its exact type and value category (decltype(auto)).
template <class F, class Tuple, size_t... I, class... Args>
decltype(auto) applyExtraImpl(F&& f, Tuple&& t, std::index_sequence<I...>, Args&&... args)
{
    CARB_UNUSED(t); // Can get C4100: unreferenced formal parameter on MSVC when Tuple is empty.
    return cpp::invoke(std::forward<F>(f), std::get<I>(std::forward<Tuple>(t))..., std::forward<Args>(args)...);
}

// Like std::apply, but with additional trailing arguments: calls `f` with every element of tuple `t` followed by
// `args...`. The index sequence covering the tuple is built here and the actual call is made by applyExtraImpl().
template <class F, class Tuple, class... Args>
decltype(auto) applyExtra(F&& f, Tuple&& t, Args&&... args)
{
    using Indices = std::make_index_sequence<std::tuple_size<std::remove_reference_t<Tuple>>::value>;
    return applyExtraImpl(std::forward<F>(f), std::forward<Tuple>(t), Indices{}, std::forward<Args>(args)...);
}

// U looks like an iterator convertible to V when dereferenced.
// True when all of the following hold:
//   - U's iterator_category is NOT convertible to std::random_access_iterator_tag,
//   - U's iterator_category is convertible to std::forward_iterator_tag,
//   - the result of dereferencing a U is convertible to V.
// Because random-access iterators are excluded, this trait and IsRandomAccessIter can never both be true for the same
// U, which lets them select between two otherwise identical constructor overloads (see All and Any).
template <class U, class V>
using IsForwardIter = carb::cpp::conjunction<
    carb::cpp::negation<
        typename std::is_convertible<typename std::iterator_traits<U>::iterator_category, std::random_access_iterator_tag>>,
    typename std::is_convertible<typename std::iterator_traits<U>::iterator_category, std::forward_iterator_tag>,
    std::is_convertible<decltype(*std::declval<U&>()), V>>;

// True when U's iterator_category is convertible to std::random_access_iterator_tag and the result of dereferencing a
// U is convertible to V.
template <class U, class V>
using IsRandomAccessIter = carb::cpp::conjunction<
    typename std::is_convertible<typename std::iterator_traits<U>::iterator_category, std::random_access_iterator_tag>,
    std::is_convertible<decltype(*std::declval<U&>()), V>>;

// Must exactly fit within a pointer, appropriate alignment, and be trivially copyable.
// True when sizeof(T) equals sizeof(void*), alignof(T) does not exceed alignof(void*) and T is trivially copyable.
// generateTaskFunc() uses this to select the overload that bit_casts the functor directly into a void*.
template <class T>
using TriviallyBitCastable = carb::cpp::conjunction<carb::cpp::bool_constant<sizeof(T) == sizeof(void*)>,
                                                    carb::cpp::bool_constant<(alignof(T) <= alignof(void*))>,
                                                    std::is_trivially_copyable<T>>;

// Must fit within a pointer, appropriate alignment, and be trivially copyable.
// True when sizeof(T) is strictly smaller than sizeof(void*), alignof(T) does not exceed alignof(void*) and T is
// trivially copyable. The size comparison is strict, so a type of exactly pointer size satisfies TriviallyBitCastable
// instead and the two traits never both hold.
template <class T>
using TriviallyFitsWithinPointer = carb::cpp::conjunction<carb::cpp::bool_constant<(sizeof(T) < sizeof(void*))>,
                                                          carb::cpp::bool_constant<(alignof(T) <= alignof(void*))>,
                                                          std::is_trivially_copyable<T>>;

// generateTaskFunc() overload for functors that are exactly pointer-sized and trivially copyable.
// SFINAE (TriviallyBitCastable) restricts this overload to such functors. The decayed functor is bit_cast into
// desc.taskArg itself, so nothing is allocated; desc.task receives a function that bit_casts the argument back into
// the functor type and calls it. desc.cancel is not touched.
template <class Functor,
          class FuncType = typename std::decay_t<Functor>,
          std::enable_if_t<TriviallyBitCastable<FuncType>::value, bool> = false>
inline void generateTaskFunc(TaskDesc& desc, Functor&& func)
{
    // Materialize the decayed functor first so that the bit_cast source has exactly type FuncType
    FuncType callable(std::forward<Functor>(func));
    desc.taskArg = cpp::bit_cast<void*>(callable);
    desc.task = [](void* packed) {
        FuncType restored = cpp::bit_cast<FuncType>(packed);
        restored();
    };
}

// generateTaskFunc() overload for trivially copyable functors that are smaller than a pointer.
// SFINAE (TriviallyFitsWithinPointer) restricts this overload to such functors. The functor is constructed directly in
// the storage of desc.taskArg, padded out to pointer size, so nothing is allocated. The original author's note records
// that this path measured at roughly a tenth of the cost of the heap-allocating overload and is hit more often.
// desc.cancel is not touched.
template <class Functor,
          class FuncType = typename std::decay_t<Functor>,
          std::enable_if_t<TriviallyFitsWithinPointer<FuncType>::value, bool> = false>
inline void generateTaskFunc(TaskDesc& desc, Functor&& func)
{
    // Functor plus zeroed tail bytes; must be the same size as (void*) so it can be bit_cast from the task argument
    struct Padded
    {
        FuncType f;
        char padding[sizeof(void*) - sizeof(FuncType)] = {};
        Padded(Functor&& fn) : f(std::forward<Functor>(fn))
        {
        }
    };
    static_assert(sizeof(Padded) == sizeof(void*), "Internal error");

    // Build the padded functor in place, using the taskArg member itself as the storage
    new (&desc.taskArg) Padded(std::forward<Functor>(func));
    desc.task = [](void* packed) {
        Padded restored = carb::cpp::bit_cast<Padded>(packed);
        restored.f();
    };
    // No cancel function is needed: FuncType is trivially copyable and owns no separate storage
}

// generateTaskFunc() overload for every functor that can NOT be stored inside a void* (neither
// TriviallyFitsWithinPointer nor TriviallyBitCastable), so the heap is used only when necessary.
// The functor is copy/move-constructed into memory obtained from carb::allocate() and that pointer becomes
// desc.taskArg. Exactly one of the two installed functions is expected to consume it:
//   - desc.task   calls the functor and then destroys/deallocates it,
//   - desc.cancel destroys/deallocates it without calling it.
template <class Functor,
          class FuncType = typename std::decay_t<Functor>,
          std::enable_if_t<cpp::conjunction<cpp::negation<TriviallyFitsWithinPointer<FuncType>>,
                                            cpp::negation<TriviallyBitCastable<FuncType>>>::value,
                           bool> = false>
inline void generateTaskFunc(TaskDesc& desc, Functor&& func)
{
    auto storage = carb::allocate(sizeof(FuncType), alignof(FuncType));
    desc.taskArg = new (storage) FuncType(std::forward<Functor>(func));

    desc.task = [](void* arg) {
        // The owning pointer releases the functor when this scope exits, whether the call returns or throws
        std::unique_ptr<FuncType, detail::CarbDeleter<FuncType>> owned(static_cast<FuncType*>(arg));
        (*owned)();
    };
    desc.cancel = [](void* arg) { detail::carbDelete(static_cast<FuncType*>(arg)); };
}

// Primary template; defined further below, after the <void> and <T&> specializations.
template <class T>
class SharedState;

// Reference-counted state without a value. The value-carrying SharedState variants derive from this class and reuse
// its reference count, its futex word and its Object descriptor.
template <>
class SharedState<void>
{
    // Outstanding references; release() destroys the object when the last one is dropped.
    std::atomic_size_t m_refs{ 1 };

public:
    // Tag type selecting the constructor for state that belongs to a task.
    enum FromTask_t
    {
        FromTask
    };

    // Values held by m_futex.
    enum State : uint8_t
    {
        eReady = 0, // ready() reports true for this value only
        eUnset, // initial value; no set() has succeeded yet
        eInProgress, // a derived set() has claimed the state and is storing its value
        eTaskPending, // value stored by set() when isTask() is true
    };

    // Starts with one reference, an unset futex, and m_object describing m_futex.
    constexpr SharedState() noexcept = default;

    // Refs are increased by 1 because the future is "retrieved". Only called from GenerateFuture. It's important that
    // m_object.type gets set to eTaskContext because isTask() must always return true for the life of the object, even
    // if we don't have a TaskContext yet. Otherwise there is a race between the task fulfilling the promise and the
    // future being read (OVCC-1500).
    constexpr SharedState(FromTask_t) noexcept
        : m_refs(2), m_futureRetrieved(true), m_object{ ObjectType::eTaskContext, nullptr }
    {
    }
    virtual ~SharedState() = default;

    // Adds a reference.
    void addRef() noexcept
    {
        m_refs.fetch_add(1, std::memory_order_relaxed);
    }

    // Drops a reference; the caller that drops the last one destroys and deallocates the object via carbDelete().
    void release()
    {
        const bool wasLastRef = m_refs.fetch_sub(1, std::memory_order_release) == 1;
        if (wasLastRef)
        {
            // Pairs with the release decrements performed by the other owners before the object is torn down
            std::atomic_thread_fence(std::memory_order_acquire);
            detail::carbDelete(this);
        }
    }

    // Marks the (valueless) promise as satisfied: eUnset becomes eTaskPending for task state, otherwise eReady.
    // If the state was anything other than eUnset, future_error(promise_already_satisfied) is raised.
    void set()
    {
        const State desired = isTask() ? eTaskPending : eReady;
        State expected = eUnset;
        const bool claimed = m_futex.compare_exchange_strong(expected, desired);
        CARB_UNLIKELY_IF(!claimed)
        {
            detail::future_error(std::future_errc::promise_already_satisfied, "promise already satisfied");
        }
    }

    // Nothing to retrieve for the void state.
    void get()
    {
    }

    // Declared here; defined elsewhere.
    void notify();

    // Unconditionally publishes eReady.
    void markReady()
    {
        m_futex.store(eReady, std::memory_order_release);
    }

    // True once m_futex holds eReady.
    bool ready() const
    {
        const State current = m_futex.load(std::memory_order_acquire);
        return current == eReady;
    }

    // True when this state was created through the FromTask constructor (m_object is typed eTaskContext).
    bool isTask() const
    {
        return ObjectType::eTaskContext == m_object.type;
    }

    std::atomic<State> m_futex{ eUnset };
    std::atomic_bool m_futureRetrieved{ false };
    // By default describes m_futex as an eFutex1 object; the FromTask constructor replaces this with a null
    // eTaskContext entry.
    Object m_object{ ObjectType::eFutex1, &m_futex };
};

template <class T>
class SharedState<T&> final : public SharedState<void>
{
public:
    constexpr SharedState() noexcept = default;
    constexpr SharedState(FromTask_t) noexcept : SharedState<void>(FromTask)
    {
    }

    bool isSet() const noexcept
    {
        return m_value != nullptr;
    }

    T& get() const
    {
        CARB_UNLIKELY_IF(!m_value)
        {
            detail::future_error(std::future_errc::broken_promise, "broken promise");
        }
        return *m_value;
    }
    void set(T& val)
    {
        State prev = eUnset;
        CARB_UNLIKELY_IF(!m_futex.compare_exchange_strong(prev, eInProgress, std::memory_order_acquire))
        {
            detail::future_error(std::future_errc::promise_already_satisfied, "promise already satisfied");
        }
        m_value = std::addressof(val);
        const State newState = this->isTask() ? eTaskPending : eReady;
#    if CARB_ASSERT_ENABLED
        CARB_ASSERT(m_futex.exchange(newState, std::memory_order_release) == eInProgress);
#    else
        m_futex.store(newState, std::memory_order_release);
#    endif
    }

    T* m_value{ nullptr };
};

template <class T>
class SharedState final : public SharedState<void>
{
public:
    using Type = typename std::decay<T>::type;

    constexpr SharedState() noexcept = default;
    constexpr SharedState(FromTask_t) noexcept : SharedState<void>(FromTask)
    {
    }

    bool isSet() const noexcept
    {
        return m_type.has_value();
    }

    const T& get_ref() const
    {
        CARB_UNLIKELY_IF(!m_type)
        {
            detail::future_error(std::future_errc::broken_promise, "broken promise");
        }
        return m_type.value();
    }

    T get()
    {
        CARB_UNLIKELY_IF(!m_type)
        {
            detail::future_error(std::future_errc::broken_promise, "broken promise");
        }
        return std::move(m_type.value());
    }
    void set(const T& value)
    {
        State prev = eUnset;
        CARB_UNLIKELY_IF(!m_futex.compare_exchange_strong(prev, eInProgress, std::memory_order_acquire))
        {
            detail::future_error(std::future_errc::promise_already_satisfied, "promise already satisfied");
        }
        m_type.emplace(value);
        const State newState = this->isTask() ? eTaskPending : eReady;
#    if CARB_ASSERT_ENABLED
        CARB_ASSERT(m_futex.exchange(newState, std::memory_order_release) == eInProgress);
#    else
        m_futex.store(newState, std::memory_order_release);
#    endif
    }
    void set(T&& value)
    {
        State prev = eUnset;
        CARB_UNLIKELY_IF(!m_futex.compare_exchange_strong(prev, eInProgress, std::memory_order_acquire))
        {
            detail::future_error(std::future_errc::promise_already_satisfied, "promise already satisfied");
        }
        m_type.emplace(std::move(value));
        const State newState = this->isTask() ? eTaskPending : eReady;
#    if CARB_ASSERT_ENABLED
        CARB_ASSERT(m_futex.exchange(newState, std::memory_order_release) == eInProgress);
#    else
        m_futex.store(newState, std::memory_order_release);
#    endif
    }

    carb::cpp::optional<Type> m_type;
};

} // namespace detail
#endif

class TaskGroup;

// An Object that can be built implicitly from several kinds of handles. Judging by the name it describes something a
// task requires before it may run - confirm against the ITasking documentation.
struct RequiredObject final : public Object
{
    // nullptr produces an empty entry (ObjectType::eNone).
    constexpr RequiredObject(std::nullptr_t) : Object{ ObjectType::eNone, nullptr }
    {
    }

    // Anything implicitly convertible to Counter* produces an ObjectType::eCounter entry.
    template <class T, std::enable_if_t<std::is_convertible<T, Counter*>::value, bool> = false>
    constexpr RequiredObject(T&& counter) : Object{ ObjectType::eCounter, static_cast<Counter*>(counter) }
    {
    }

    // Anything implicitly convertible to TaskContext produces an ObjectType::eTaskContext entry; the TaskContext value
    // itself is stored in the pointer field.
    template <class T, std::enable_if_t<std::is_convertible<T, TaskContext>::value, bool> = true>
    constexpr RequiredObject(T&& context)
        : Object{ ObjectType::eTaskContext, reinterpret_cast<void*>(static_cast<TaskContext>(context)) }
    {
    }

    // Construction from a TaskGroup reference; declared here, defined elsewhere.
    constexpr RequiredObject(const TaskGroup& tg);

    // Construction from a TaskGroup pointer; declared here, defined elsewhere.
    constexpr RequiredObject(const TaskGroup* tg);

private:
    friend struct ITasking;
    template <class U>
    friend class Future;
    template <class U>
    friend class SharedFuture;

    // Wraps an existing Object unchanged; reachable only by the friends listed above.
    constexpr RequiredObject(const Object& object) : Object(object)
    {
    }

    // Declared here, defined elsewhere.
    void get(TaskDesc& desc);
};

// Helper that groups several RequiredObject entries behind a single Counter. Judging by the name, the group is
// satisfied once ALL entries are - the constructors are only declared here, so confirm against their definitions.
struct All final
{
    // Builds the group from a braced list of RequiredObject (defined elsewhere).
    All(std::initializer_list<RequiredObject> il);

    // Builds the group from an iterator range; this overload is selected for forward (non-random-access) iterators
    // whose dereferenced value converts to RequiredObject (defined elsewhere).
    template <class InputIt, std::enable_if_t<detail::IsForwardIter<InputIt, RequiredObject>::value, bool> = false>
    All(InputIt begin, InputIt end);

    // Builds the group from an iterator range; this overload is selected for random-access iterators whose
    // dereferenced value converts to RequiredObject (defined elsewhere).
    template <class InputIt, std::enable_if_t<detail::IsRandomAccessIter<InputIt, RequiredObject>::value, bool> = false>
    All(InputIt begin, InputIt end);

    // Public conversion: wraps the stored Counter in a RequiredObject.
    operator RequiredObject() const
    {
        return RequiredObject(m_counter);
    }

private:
    friend struct RequiredObject;
    // Set by the constructors, which are defined elsewhere.
    Counter* m_counter;

    // Private conversion to the raw Counter; accessible to RequiredObject through the friend declaration.
    operator Counter*() const
    {
        return m_counter;
    }
};

// Helper that groups several RequiredObject entries behind a single Counter. Judging by the name, the group is
// satisfied once ANY entry is - the constructors are only declared here, so confirm against their definitions.
struct Any final
{
    // Builds the group from a braced list of RequiredObject (defined elsewhere).
    Any(std::initializer_list<RequiredObject> il);

    // Builds the group from an iterator range; this overload is selected for forward (non-random-access) iterators
    // whose dereferenced value converts to RequiredObject (defined elsewhere).
    template <class InputIt, std::enable_if_t<detail::IsForwardIter<InputIt, RequiredObject>::value, bool> = false>
    Any(InputIt begin, InputIt end);

    // Builds the group from an iterator range; this overload is selected for random-access iterators whose
    // dereferenced value converts to RequiredObject (defined elsewhere).
    template <class InputIt, std::enable_if_t<detail::IsRandomAccessIter<InputIt, RequiredObject>::value, bool> = false>
    Any(InputIt begin, InputIt end);

    // Public conversion: wraps the stored Counter in a RequiredObject.
    operator RequiredObject() const
    {
        return RequiredObject(m_counter);
    }

private:
    friend struct RequiredObject;
    // Set by the constructors, which are defined elsewhere.
    Counter* m_counter;

    // Private conversion to the raw Counter; accessible to RequiredObject through the friend declaration.
    operator Counter*() const
    {
        return m_counter;
    }
};

// An Object that can be built implicitly from the various things a task may be tracked by: a Counter, a task name, a
// Future/SharedFuture, a TaskContext slot or a TaskGroup. Each constructor only records the matching ObjectType
// together with a pointer; nothing is copied or owned.
struct Tracker final : Object
{
    // nullptr produces an empty entry (ObjectType::eNone).
    constexpr Tracker(std::nullptr_t) : Object{ ObjectType::eNone, nullptr }
    {
    }

    // Anything implicitly convertible to Counter* produces an ObjectType::eCounter entry.
    template <class T, std::enable_if_t<std::is_convertible<T, Counter*>::value, bool> = false>
    constexpr Tracker(T&& counter)
        : Object{ ObjectType::eCounter, reinterpret_cast<void*>(static_cast<Counter*>(counter)) }
    {
    }

    // Anything implicitly convertible to const char* is treated as a task name. When the argument deduces as a
    // reference to a const char array (see detail::is_literal_string) the entry is tagged eTaskNameLiteral, otherwise
    // eTaskName. Only the pointer is stored, not a copy of the characters.
    template <class T, std::enable_if_t<std::is_convertible<T, const char*>::value, bool> = false>
    constexpr Tracker(T&& taskName)
        : Object{ detail::is_literal_string<T>::value ? ObjectType::eTaskNameLiteral : ObjectType::eTaskName,
                  const_cast<void*>(reinterpret_cast<const void*>(taskName)) }
    {
    }

    // Records the pointer returned by the future's ptask() as an ePtrTaskContext entry.
    Tracker(Future<>& future) : Object{ ObjectType::ePtrTaskContext, future.ptask() }
    {
    }

    // Pointer flavor of the above; a null future yields a null pointer entry.
    Tracker(Future<>* future) : Object{ ObjectType::ePtrTaskContext, future ? future->ptask() : nullptr }
    {
    }

    // Records the pointer returned by the shared future's ptask() as an ePtrTaskContext entry.
    Tracker(SharedFuture<>& future) : Object{ ObjectType::ePtrTaskContext, future.ptask() }
    {
    }

    // Pointer flavor of the above; a null future yields a null pointer entry.
    Tracker(SharedFuture<>* future) : Object{ ObjectType::ePtrTaskContext, future ? future->ptask() : nullptr }
    {
    }

    // Records the address of the given TaskContext as an ePtrTaskContext entry.
    constexpr Tracker(TaskContext& context) : Object{ ObjectType::ePtrTaskContext, &context }
    {
    }

    // Records the given TaskContext pointer (which may be null) as an ePtrTaskContext entry.
    constexpr Tracker(TaskContext* context) : Object{ ObjectType::ePtrTaskContext, context }
    {
    }

    // Construction from a TaskGroup reference; declared here, defined elsewhere.
    Tracker(TaskGroup& grp);

    // Construction from a TaskGroup pointer; declared here, defined elsewhere.
    Tracker(TaskGroup* grp);

private:
    friend struct Trackers;
};

struct Trackers final
{
    constexpr Trackers() : m_variant{}
    {
    }

    template <class T, std::enable_if_t<std::is_constructible<Tracker, T>::value, bool> = false>
    constexpr Trackers(T&& t) : m_variant(Tracker(t))
    {
    }

    constexpr Trackers(std::initializer_list<Tracker> il)
    {
        switch (il.size())
        {
            case 0:
                break;
            case 1:
                m_variant.emplace<Tracker>(*il.begin());
                break;
            default:
                m_variant.emplace<TrackerVec>(std::move(il));
        }
    }

    Trackers(std::initializer_list<Tracker> il, Tracker const* p, size_t count)
        : m_variant(carb::cpp::in_place_index<2>)
    {
        switch (il.size() + count)
        {
            case 0:
                break;
            case 1:
                m_variant.emplace<Tracker>(il.size() == 0 ? *p : *il.begin());
                break;
            default:
            {
                auto& vec = m_variant.emplace<TrackerVec>();
                vec.reserve(il.size() + count);
                vec.insert(vec.end(), il.begin(), il.end());
                vec.insert(vec.end(), p, p + count);
            }
        }
    }

    void output(Tracker const*& trackers, size_t& count) const
    {
        static_assert(sizeof(Object) == sizeof(Tracker), "");
        fill(reinterpret_cast<Object const*&>(trackers), count);
    }

    CARB_PREVENT_COPY(Trackers);

    Trackers(Trackers&&) = default;
    Trackers& operator=(Trackers&&) = default;

private:
    friend struct ITasking;
    using TrackerVec = std::vector<Tracker, carb::Allocator<Tracker>>;
    using Variant = carb::cpp::variant<carb::cpp::monostate, Tracker, TrackerVec>;
    Variant m_variant;
    Counter* fill(carb::tasking::Object const*& trackers, size_t& count) const
    {
        if (m_variant.index() == 0)
        {
            trackers = nullptr;
            count = 0;
            return nullptr;
        }

        if (auto* vec = carb::cpp::get_if<TrackerVec>(&m_variant))
        {
            trackers = vec->data();
            count = vec->size();
        }
        else
        {
            const Tracker& t = carb::cpp::get<Tracker>(m_variant);
            trackers = &t;
            count = 1;
        }
        return detail::kListOfCounters;
    }
};

// Annotation macro that expands to nothing. Judging by the name it marks a function as running in task context -
// confirm against the project's usage guidelines.
#define CARB_ASYNC

// Annotation macro that expands to nothing. Judging by the name it marks a function that may or may not run in task
// context - confirm against the project's usage guidelines.
#define CARB_MAYBE_ASYNC

// Expression that is true when the cached ITasking interface reports a task context other than kInvalidTaskContext for
// the caller. Every name is fully qualified from the global namespace so that the macro resolves the same way no
// matter which namespace it is expanded in. ITasking and getCachedInterface() must be declared at the point of use.
#define CARB_IS_ASYNC                                                                                                  \
    (::carb::getCachedInterface<::carb::tasking::ITasking>()->getTaskContext() !=                                      \
     ::carb::tasking::kInvalidTaskContext)

// CARB_ASSERT that the caller is in task context (see CARB_IS_ASYNC).
#define CARB_ASSERT_ASYNC CARB_ASSERT(CARB_IS_ASYNC)

// CARB_CHECK that the caller is in task context (see CARB_IS_ASYNC).
#define CARB_CHECK_ASYNC CARB_CHECK(CARB_IS_ASYNC)

// Fatal error with the message "Not running in task context!" unless the caller is in task context.
#define CARB_FATAL_UNLESS_ASYNC CARB_FATAL_UNLESS(CARB_IS_ASYNC, "Not running in task context!")

} // namespace tasking
} // namespace carb