// carb/tasking/TaskingHelpers.h
//
// File members: carb/tasking/TaskingHelpers.h

// Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto. Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.
//

#pragma once

#include "TaskingTypes.h"

#include "../thread/Futex.h"
#include "../cpp17/Functional.h"
#include "../cpp17/Optional.h"
#include "../cpp17/Variant.h"

#include <atomic>
#include <chrono>
#include <iterator>
#include <vector>

namespace carb
{
namespace tasking
{

#ifndef DOXYGEN_BUILD
namespace detail
{

// Type trait reporting whether T is an lvalue reference to a `const char` array, which is the type a forwarding
// reference deduces when it is handed a string literal. The primary template handles every other type and yields false.
template <typename T>
struct is_literal_string
{
    static constexpr bool value = false;
};

// Partial specialization: an lvalue reference to a `const char` array of any extent yields true.
template <size_t Extent>
struct is_literal_string<const char (&)[Extent]>
{
    static constexpr bool value = true;
};

// Sentinel Counter pointer whose value is the all-ones size_t converted to a pointer. It is a marker value rather than
// the address of a real Counter; Trackers::fill() in this file returns it.
Counter* const kListOfCounters = reinterpret_cast<Counter*>(static_cast<size_t>(-1));

// Converts a std::chrono duration into a nanosecond count expressed as uint64_t. The duration is first passed through
// thread::detail::clampDuration(); a negative nanosecond count is then raised to zero before the unsigned conversion.
template <class Rep, class Period>
uint64_t convertDuration(const std::chrono::duration<Rep, Period>& dur)
{
    using Nanos = std::chrono::nanoseconds;

    auto count = std::chrono::duration_cast<Nanos>(thread::detail::clampDuration(dur)).count();

    // A negative count would wrap around when converted to an unsigned type, so floor it at zero first.
    return uint64_t(::carb_max(Nanos::rep(0), count));
}

// Converts an absolute time point into a nanosecond count relative to Clock's current time by handing the difference
// `tp - Clock::now()` to convertDuration().
template <class Clock, class Duration>
uint64_t convertAbsTime(const std::chrono::time_point<Clock, Duration>& tp)
{
    const auto remaining = tp - Clock::now();
    return convertDuration(remaining);
}

// Implementation detail of applyExtra(): invokes `f` with each tuple element selected by the index pack, followed by
// the additional `args`. Everything is perfectly forwarded and the result of the invocation is returned unchanged.
template <class F, class Tuple, size_t... Indices, class... Args>
decltype(auto) applyExtraImpl(F&& f, Tuple&& t, std::index_sequence<Indices...>, Args&&... args)
{
    // With an empty tuple the pack expansion below never names `t`, which MSVC reports as C4100 (unreferenced formal
    // parameter).
    CARB_UNUSED(t);
    return cpp17::invoke(
        std::forward<F>(f), std::get<Indices>(std::forward<Tuple>(t))..., std::forward<Args>(args)...);
}

// Invokes `f` with the elements of tuple `t` unpacked as the leading arguments and `args` appended after them.
// Returns whatever the invocation returns (reference-ness preserved through decltype(auto)).
template <class F, class Tuple, class... Args>
decltype(auto) applyExtra(F&& f, Tuple&& t, Args&&... args)
{
    // One index per element of the (possibly reference-qualified) tuple type.
    using Indices = std::make_index_sequence<std::tuple_size<std::remove_reference_t<Tuple>>::value>;

    return applyExtraImpl(std::forward<F>(f), std::forward<Tuple>(t), Indices{}, std::forward<Args>(args)...);
}

// True when U is an iterator whose category converts to std::forward_iterator_tag but NOT to
// std::random_access_iterator_tag, and whose dereferenced value (`*u` for an lvalue `u` of type U) converts to V.
template <class U, class V>
using IsForwardIter = carb::cpp17::conjunction<
    carb::cpp17::negation<
        std::is_convertible<typename std::iterator_traits<U>::iterator_category, std::random_access_iterator_tag>>,
    std::is_convertible<typename std::iterator_traits<U>::iterator_category, std::forward_iterator_tag>,
    std::is_convertible<decltype(*std::declval<U&>()), V>>;

// True when U is an iterator whose category converts to std::random_access_iterator_tag and whose dereferenced value
// (`*u` for an lvalue `u` of type U) converts to V.
template <class U, class V>
using IsRandomAccessIter = carb::cpp17::conjunction<
    std::is_convertible<typename std::iterator_traits<U>::iterator_category, std::random_access_iterator_tag>,
    std::is_convertible<decltype(*std::declval<U&>()), V>>;

// True when the decayed Functor can live directly inside a pointer-sized slot: it must be no larger than a void*,
// trivially move constructible and trivially destructible.
template <class Functor>
using FitsWithinPointerTrivially =
    carb::cpp17::conjunction<carb::cpp17::bool_constant<(sizeof(std::decay_t<Functor>) <= sizeof(void*))>,
                             std::is_trivially_move_constructible<std::decay_t<Functor>>,
                             std::is_trivially_destructible<std::decay_t<Functor>>>;

// Small-functor overload of generateTaskFunc(). SFINAE enables it only when FitsWithinPointerTrivially<Functor> holds,
// in which case the functor itself is carried inside desc.taskArg and no heap allocation is made. The original author
// measured this path at about a tenth of the time of the allocating overload and noted that it is taken more often.
template <class Functor, std::enable_if_t<FitsWithinPointerTrivially<Functor>::value, bool> = false>
inline void generateTaskFunc(TaskDesc& desc, Functor&& func)
{
    using Func = typename std::decay_t<Functor>;

    // Construct the functor in the first union member, then read the same storage back as a void*.
    union
    {
        Func f;
        void* v;
    } packed{ std::forward<Functor>(func) };
    desc.taskArg = packed.v;

    // The task entry point does the reverse: it initializes the union from the void* and invokes the functor member.
    desc.task = [](void* arg) {
        union CARB_ATTRIBUTE(visibility("hidden"))
        {
            void* v;
            Func f;
        } unpacked{ arg };
        unpacked.f();
    };

    // No cancel function is installed: Func is trivially destructible, so there is nothing to clean up.
}

// Heap-allocating overload of generateTaskFunc(). SFINAE enables it only when FitsWithinPointerTrivially<Functor> does
// NOT hold, so the heap is used only when the functor cannot be carried inside desc.taskArg directly.
template <class Functor, std::enable_if_t<!FitsWithinPointerTrivially<Functor>::value, bool> = false>
inline void generateTaskFunc(TaskDesc& desc, Functor&& func)
{
    using Func = typename std::decay_t<Functor>;

    // The task argument is a heap copy of the functor; exactly one of the two callbacks below frees it.
    desc.taskArg = new Func(std::forward<Functor>(func));

    // Running the task: take ownership first so the copy is destroyed once the call finishes.
    desc.task = [](void* arg) {
        std::unique_ptr<Func> owned(static_cast<Func*>(arg));
        (*owned)();
    };

    // Cancelling the task: the functor never runs, so just destroy the copy.
    desc.cancel = [](void* arg) { delete static_cast<Func*>(arg); };
}

// Shared state template; SharedState<void> is the value-less specialization and the base class of the value-carrying
// variants declared further down.
template <class T>
class SharedState;

// Value-less, reference-counted shared state. It owns the atomic state word (m_futex), the "future retrieved" flag and
// the Object that refers to the state word. Instances delete themselves when the last reference is released.
template <>
class SharedState<void>
{
    // Intrusive reference count; see addRef() / release().
    std::atomic_size_t m_refs;

public:
    // The reference count starts at 1, or at 2 when `futureRetrieved` is true; the flag is also stored in
    // m_futureRetrieved.
    SharedState(bool futureRetrieved) noexcept : m_refs(1 + futureRetrieved), m_futureRetrieved(futureRetrieved)
    {
    }
    virtual ~SharedState() = default;

    // Adds one reference.
    void addRef() noexcept
    {
        m_refs.fetch_add(1, std::memory_order_relaxed);
    }

    // Drops one reference and deletes this object when the count reaches zero. The decrement is a release operation
    // and is followed by an acquire fence before the delete, so writes made by other owners before their release()
    // are visible to the destructor.
    void release()
    {
        if (m_refs.fetch_sub(1, std::memory_order_release) == 1)
        {
            std::atomic_thread_fence(std::memory_order_acquire);
            delete this;
        }
    }

    // Marks the (value-less) state as set: atomically switches the state word to eTaskPending when isTask() is true,
    // otherwise to eReady. Terminates via CARB_FATAL_UNLESS if the previous state was not eUnset.
    void set()
    {
        CARB_FATAL_UNLESS(m_futex.exchange(isTask() ? eTaskPending : eReady, std::memory_order_acq_rel) == eUnset,
                          "Value already set");
    }
    // Nothing to retrieve for the void specialization.
    void get()
    {
    }

    // Declared here; the definition is not part of this section of the file.
    void notify();

    // Unconditionally publishes eReady with release semantics.
    void markReady()
    {
        m_futex.store(eReady, std::memory_order_release);
    }

    // Returns true once the state word is eReady. The load uses acquire ordering so that it pairs with the release
    // stores in set()/markReady(): a caller that observes `true` is guaranteed to also observe the value written
    // before the state was published. (A relaxed load would not provide that guarantee.)
    bool ready() const
    {
        return m_futex.load(std::memory_order_acquire) == eReady;
    }

    // True when m_object currently has type ObjectType::eTaskContext.
    bool isTask() const
    {
        return m_object.type == ObjectType::eTaskContext;
    }

    // Values held by m_futex. set() requires the previous value to be eUnset; value-carrying subclasses pass through
    // eInProgress while storing the value; the final value is eTaskPending or eReady depending on isTask().
    enum State : uint8_t
    {
        eReady = 0,
        eUnset,
        eInProgress,
        eTaskPending,
    };

    std::atomic<State> m_futex{ eUnset }; // current State
    std::atomic_bool m_futureRetrieved{ false }; // initialized from the constructor argument
    Object m_object{ ObjectType::eFutex1, &m_futex }; // initially refers to m_futex
};

template <class T>
class SharedState<T&> final : public SharedState<void>
{
public:
    SharedState(bool futureRetrieved) noexcept : SharedState<void>(futureRetrieved)
    {
    }

    bool isSet() const noexcept
    {
        return m_value != nullptr;
    }

    T& get() const
    {
        CARB_FATAL_UNLESS(m_value, "Attempting to retrieve value from broken promise");
        return *m_value;
    }
    void set(T& val)
    {
        CARB_FATAL_UNLESS(m_futex.exchange(eInProgress, std::memory_order_acquire) == 1, "Value already set");
        m_value = std::addressof(val);
        m_futex.store(this->isTask() ? eTaskPending : eReady, std::memory_order_release);
    }

    T* m_value{ nullptr };
};

template <class T>
class SharedState final : public SharedState<void>
{
public:
    using Type = typename std::decay<T>::type;

    SharedState(bool futureRetrieved) noexcept : SharedState<void>(futureRetrieved)
    {
    }

    bool isSet() const noexcept
    {
        return m_type.has_value();
    }

    const T& get_ref() const
    {
        CARB_FATAL_UNLESS(m_type, "Attempting to retrieve value from broken promise");
        return m_type.value();
    }
    T get()
    {
        CARB_FATAL_UNLESS(m_type, "Attempting to retrieve value from broken promise");
        return std::move(m_type.value());
    }
    void set(const T& value)
    {
        CARB_FATAL_UNLESS(m_futex.exchange(eInProgress, std::memory_order_acquire) == 1, "Value already set");
        m_type.emplace(value);
        m_futex.store(this->isTask() ? eTaskPending : eReady, std::memory_order_release);
    }
    void set(T&& value)
    {
        CARB_FATAL_UNLESS(m_futex.exchange(eInProgress, std::memory_order_acquire) == 1, "Value already set");
        m_type.emplace(std::move(value));
        m_futex.store(this->isTask() ? eTaskPending : eReady, std::memory_order_release);
    }

    carb::cpp17::optional<Type> m_type;
};

} // namespace detail
#endif

// Forward declaration; TaskGroup is defined elsewhere.
class TaskGroup;

// An Object that is implicitly constructible from several kinds of tasking objects. Each constructor stores the
// ObjectType tag that matches its argument together with the argument itself.
// NOTE(review): the name suggests instances describe prerequisites of a task - confirm against the ITasking API.
struct RequiredObject final : public Object
{
    // Empty object: ObjectType::eNone with a null pointer.
    constexpr RequiredObject(std::nullptr_t) : Object{ ObjectType::eNone, nullptr }
    {
    }

    // Accepts anything convertible to Counter*; stored as ObjectType::eCounter.
    template <class T, std::enable_if_t<std::is_convertible<T, Counter*>::value, bool> = false>
    constexpr RequiredObject(T&& c) : Object{ ObjectType::eCounter, static_cast<Counter*>(c) }
    {
    }

    // Accepts anything convertible to TaskContext; the TaskContext value is stored reinterpreted as a void* and tagged
    // ObjectType::eTaskContext.
    template <class T, std::enable_if_t<std::is_convertible<T, TaskContext>::value, bool> = true>
    constexpr RequiredObject(T&& tc)
        : Object{ ObjectType::eTaskContext, reinterpret_cast<void*>(static_cast<TaskContext>(tc)) }
    {
    }

    // Construction from a TaskGroup reference; declared only, the definition is not in this section of the file.
    constexpr RequiredObject(const TaskGroup& tg);

    // Construction from a TaskGroup pointer; declared only, the definition is not in this section of the file.
    constexpr RequiredObject(const TaskGroup* tg);

private:
    // These types may use the private members below.
    friend struct ITasking;
    template <class U>
    friend class Future;
    template <class U>
    friend class SharedFuture;

    // Copies an existing Object verbatim (type tag and data).
    constexpr RequiredObject(const Object& o) : Object(o)
    {
    }

    // Declared only; the definition is not in this section of the file.
    void get(TaskDesc& desc);
};

// Built from a collection of RequiredObject values and convertible to a single RequiredObject that wraps the Counter
// held in m_counter.
// NOTE(review): the constructors are only declared here; presumably the resulting Counter represents "all of the
// given objects" - confirm against the definitions.
struct All final
{
    // Construct from a braced list of RequiredObject values (declared only).
    All(std::initializer_list<RequiredObject> il);

    // Construct from an iterator range; selected for forward (non-random-access) iterators whose dereferenced value
    // converts to RequiredObject (declared only).
    template <class InputIt, std::enable_if_t<detail::IsForwardIter<InputIt, RequiredObject>::value, bool> = false>
    All(InputIt begin, InputIt end);

    // Construct from an iterator range; selected for random-access iterators whose dereferenced value converts to
    // RequiredObject (declared only).
    template <class InputIt, std::enable_if_t<detail::IsRandomAccessIter<InputIt, RequiredObject>::value, bool> = false>
    All(InputIt begin, InputIt end);

    // Implicit conversion: wraps m_counter in a RequiredObject.
    operator RequiredObject() const
    {
        return RequiredObject(m_counter);
    }

private:
    friend struct RequiredObject;
    Counter* m_counter; // not initialized in this section; presumably set by the constructors - TODO confirm

    // Private conversion to the underlying Counter*, accessible to RequiredObject through the friend declaration.
    operator Counter*() const
    {
        return m_counter;
    }
};

// Built from a collection of RequiredObject values and convertible to a single RequiredObject that wraps the Counter
// held in m_counter.
// NOTE(review): the constructors are only declared here; presumably the resulting Counter represents "any one of the
// given objects" - confirm against the definitions.
struct Any final
{
    // Construct from a braced list of RequiredObject values (declared only).
    Any(std::initializer_list<RequiredObject> il);

    // Construct from an iterator range; selected for forward (non-random-access) iterators whose dereferenced value
    // converts to RequiredObject (declared only).
    template <class InputIt, std::enable_if_t<detail::IsForwardIter<InputIt, RequiredObject>::value, bool> = false>
    Any(InputIt begin, InputIt end);

    // Construct from an iterator range; selected for random-access iterators whose dereferenced value converts to
    // RequiredObject (declared only).
    template <class InputIt, std::enable_if_t<detail::IsRandomAccessIter<InputIt, RequiredObject>::value, bool> = false>
    Any(InputIt begin, InputIt end);

    // Implicit conversion: wraps m_counter in a RequiredObject.
    operator RequiredObject() const
    {
        return RequiredObject(m_counter);
    }

private:
    friend struct RequiredObject;
    Counter* m_counter; // not initialized in this section; presumably set by the constructors - TODO confirm

    // Private conversion to the underlying Counter*, accessible to RequiredObject through the friend declaration.
    operator Counter*() const
    {
        return m_counter;
    }
};

// An Object that is implicitly constructible from the various things a task can be tracked by. Each constructor
// stores the ObjectType tag that matches its argument together with a pointer to (or the value of) the argument.
struct Tracker final : Object
{
    // Empty tracker: ObjectType::eNone with a null pointer.
    constexpr Tracker(std::nullptr_t) : Object{ ObjectType::eNone, nullptr }
    {
    }

    // Accepts anything convertible to Counter*; stored as ObjectType::eCounter.
    template <class T, std::enable_if_t<std::is_convertible<T, Counter*>::value, bool> = false>
    constexpr Tracker(T&& c) : Object{ ObjectType::eCounter, reinterpret_cast<void*>(static_cast<Counter*>(c)) }
    {
    }

    // Accepts anything convertible to const char*. The tag is eTaskNameLiteral when T is a reference to a const char
    // array (per detail::is_literal_string, i.e. what a string literal deduces to) and eTaskName otherwise. Only the
    // pointer is stored; constness is cast away to fit the Object's data member.
    template <class T, std::enable_if_t<std::is_convertible<T, const char*>::value, bool> = false>
    constexpr Tracker(T&& name)
        : Object{ detail::is_literal_string<T>::value ? ObjectType::eTaskNameLiteral : ObjectType::eTaskName,
                  const_cast<void*>(reinterpret_cast<const void*>(name)) }
    {
    }

    // Stores the result of fut.ptask(), tagged ObjectType::ePtrTaskContext.
    Tracker(Future<>& fut) : Object{ ObjectType::ePtrTaskContext, fut.ptask() }
    {
    }

    // Pointer form of the above; a null `fut` stores a null pointer.
    Tracker(Future<>* fut) : Object{ ObjectType::ePtrTaskContext, fut ? fut->ptask() : nullptr }
    {
    }

    // Stores the result of fut.ptask(), tagged ObjectType::ePtrTaskContext.
    Tracker(SharedFuture<>& fut) : Object{ ObjectType::ePtrTaskContext, fut.ptask() }
    {
    }

    // Pointer form of the above; a null `fut` stores a null pointer.
    Tracker(SharedFuture<>* fut) : Object{ ObjectType::ePtrTaskContext, fut ? fut->ptask() : nullptr }
    {
    }

    // Stores the address of `ctx`, tagged ObjectType::ePtrTaskContext.
    constexpr Tracker(TaskContext& ctx) : Object{ ObjectType::ePtrTaskContext, &ctx }
    {
    }

    // Stores `ctx` itself (may be null), tagged ObjectType::ePtrTaskContext.
    constexpr Tracker(TaskContext* ctx) : Object{ ObjectType::ePtrTaskContext, ctx }
    {
    }

    // Construction from a TaskGroup reference; declared only, the definition is not in this section of the file.
    Tracker(TaskGroup& grp);

    // Construction from a TaskGroup pointer; declared only, the definition is not in this section of the file.
    Tracker(TaskGroup* grp);

private:
    friend struct Trackers;
};

struct Trackers final
{
    template <class T, std::enable_if_t<std::is_constructible<Tracker, T>::value, bool> = false>
    constexpr Trackers(T&& t) : m_variant(Tracker(t))
    {
    }

    constexpr Trackers(std::initializer_list<Tracker> il) : m_variant(carb::cpp17::in_place_index<1>)
    {
        if (il.size() == 1)
            m_variant.emplace<Tracker>(*il.begin());
        else
        {
            auto& vec = carb::cpp17::get<1>(m_variant);
            vec.reserve(il.size());
            vec.insert(vec.end(), il.begin(), il.end());
        }
    }

    Trackers(std::initializer_list<Tracker> il, Tracker const* p, size_t count)
        : m_variant(carb::cpp17::in_place_index<1>)
    {
        if ((il.size() + count) == 1)
        {
            m_variant.emplace<Tracker>(il.size() == 0 ? *p : *il.begin());
        }
        else
        {
            auto& vec = carb::cpp17::get<1>(m_variant);
            vec.reserve(il.size() + count);
            vec.insert(vec.end(), il.begin(), il.end());
            vec.insert(vec.end(), p, p + count);
        }
    }

    void output(Tracker const*& trackers, size_t& count) const
    {
        static_assert(sizeof(Object) == sizeof(Tracker), "");
        fill(reinterpret_cast<Object const*&>(trackers), count);
    }

    CARB_PREVENT_COPY(Trackers);

    Trackers(Trackers&&) = default;
    Trackers& operator=(Trackers&&) = default;

private:
    friend struct ITasking;
    using Variant = carb::cpp17::variant<Tracker, std::vector<Tracker>>;
    Variant m_variant;
    Counter* fill(carb::tasking::Object const*& trackers, size_t& count) const
    {
        if (auto* vec = carb::cpp17::get_if<1>(&m_variant))
        {
            trackers = vec->data();
            count = vec->size();
        }
        else
        {
            const Tracker& t = carb::cpp17::get<0>(m_variant);
            trackers = &t;
            count = 1;
        }
        return detail::kListOfCounters;
    }
};

// Expands to nothing.
// NOTE(review): presumably an annotation for functions that are expected to run inside a task - confirm.
#define CARB_ASYNC

// Expands to nothing.
// NOTE(review): presumably an annotation for functions that may or may not run inside a task - confirm.
#define CARB_MAYBE_ASYNC

// Expression that is true when the cached ITasking interface reports a task context other than kInvalidTaskContext
// for the caller.
#define CARB_IS_ASYNC                                                                                                  \
    (::carb::getCachedInterface<carb::tasking::ITasking>()->getTaskContext() != ::carb::tasking::kInvalidTaskContext)

// CARB_ASSERT on CARB_IS_ASYNC.
#define CARB_ASSERT_ASYNC CARB_ASSERT(CARB_IS_ASYNC)

// CARB_CHECK on CARB_IS_ASYNC.
#define CARB_CHECK_ASYNC CARB_CHECK(CARB_IS_ASYNC)

// CARB_FATAL_UNLESS on CARB_IS_ASYNC, with the message "Not running in task context!".
#define CARB_FATAL_UNLESS_ASYNC CARB_FATAL_UNLESS(CARB_IS_ASYNC, "Not running in task context!")

} // namespace tasking
} // namespace carb