// carb/tasking/TaskingUtils.h
//
// File members: carb/tasking/TaskingUtils.h

// Copyright (c) 2019-2024, NVIDIA CORPORATION. All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto. Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.
//

#pragma once

#include "ITasking.h"

#include "../thread/Util.h"

#include <atomic>
#include <condition_variable> // for std::cv_status
#include <functional>

namespace carb
{
namespace tasking
{

class RecursiveSharedMutex;

//! A minimal spin lock stored in a single atomic flag.
//!
//! Exposes `lock()`, `try_lock()` and `unlock()`. The type is neither copyable nor movable.
struct SpinMutex
{
public:
    //! Creates the mutex in the unlocked state (flag value-initialized to `false`).
    constexpr SpinMutex() noexcept = default;

    CARB_PREVENT_COPY_AND_MOVE(SpinMutex);

    //! Acquires the lock, waiting as long as necessary.
    //!
    //! The waiting itself is delegated to `this_thread::spinWaitWithBackoff()`, which is handed `try_lock()` as its
    //! predicate (presumably it retries until the predicate returns `true`).
    void lock() noexcept
    {
        this_thread::spinWaitWithBackoff([this] { return try_lock(); });
    }

    //! Makes one attempt to acquire the lock without waiting.
    //! @returns `true` if this call took the lock; `false` if it was already held.
    bool try_lock() noexcept
    {
        // Cheap read first: a lock that is visibly held is not hit with a read-modify-write operation.
        if (mtx.load(std::memory_order_relaxed))
        {
            return false;
        }
        // exchange() yields the previous value; `false` means this call performed the false -> true transition.
        const bool wasLocked = mtx.exchange(true, std::memory_order_acquire);
        return !wasLocked;
    }

    //! Releases the lock by storing `false` with release ordering.
    void unlock() noexcept
    {
        mtx.store(false, std::memory_order_release);
    }

private:
    std::atomic_bool mtx{}; // `true` while the lock is held
};

//! A spin-waiting reader/writer lock stored in a single atomic counter.
//!
//! Counter encoding: `0` = unlocked, `> 0` = number of shared holders, `-1` = held exclusively.
//! No fairness is implemented: an exclusive locker can only succeed while the counter reads 0, and shared lockers
//! succeed whenever the counter is non-negative. The type is neither copyable nor movable.
struct SpinSharedMutex
{
public:
    //! Creates the mutex in the unlocked state.
    constexpr SpinSharedMutex() = default;

    CARB_PREVENT_COPY_AND_MOVE(SpinSharedMutex);

    //! Acquires the lock exclusively, spinning until it is available.
    void lock()
    {
        while (!try_lock())
        {
            // After a failed attempt, wait on plain relaxed loads until the lock looks free instead of retrying the
            // compare-exchange on every iteration. This matches how lock_shared() waits and keeps waiting writers
            // from issuing a read-modify-write on every spin.
            do
            {
                CARB_HARDWARE_PAUSE();
            } while (counter.load(std::memory_order_relaxed) != 0);
        }
    }

    //! Makes one attempt to acquire the lock exclusively.
    //! @returns `true` if the counter was moved from 0 to -1 by this call; `false` otherwise.
    bool try_lock()
    {
        int expected = 0;
        return counter.compare_exchange_strong(expected, -1, std::memory_order_acquire, std::memory_order_relaxed);
    }

    //! Releases an exclusive lock. Asserts that the lock is exclusively held.
    void unlock()
    {
        CARB_ASSERT(counter == -1);
        counter.store(0, std::memory_order_release);
    }

    //! Makes one attempt to acquire the lock in shared mode.
    //! @returns `true` if the shared count was incremented; `false` if the lock is exclusively held or the counter
    //! changed between the load and the compare-exchange.
    bool try_lock_shared()
    {
        auto ctr = counter.load(std::memory_order_relaxed);
        if (ctr >= 0)
        {
            return counter.compare_exchange_strong(ctr, ctr + 1, std::memory_order_acquire, std::memory_order_relaxed);
        }
        return false;
    }

    //! Acquires the lock in shared mode, spinning while it is exclusively held.
    void lock_shared()
    {
        auto ctr = counter.load(std::memory_order_relaxed);
        for (;;)
        {
            if (ctr < 0)
            {
                // Exclusively held: pause and re-read.
                CARB_HARDWARE_PAUSE();
                ctr = counter.load(std::memory_order_relaxed);
            }
            else if (counter.compare_exchange_strong(ctr, ctr + 1, std::memory_order_acquire, std::memory_order_relaxed))
            {
                return;
            }
            // On a failed compare-exchange `ctr` already holds the freshly observed value, so just loop.
        }
    }

    //! Releases one shared hold. Asserts that at least one shared hold existed.
    void unlock_shared()
    {
        int ctr = counter.fetch_sub(1, std::memory_order_release);
        CARB_ASSERT(ctr > 0);
        CARB_UNUSED(ctr);
    }

private:
    //   0 - unlocked
    // > 0 - Shared lock count
    //  -1 - Exclusive lock
    std::atomic<int> counter{ 0 };
};

class CounterWrapper
{
public:
    CounterWrapper(uint32_t target = 0)
        : m_counter(carb::getCachedInterface<ITasking>()->createCounterWithTarget(target))
    {
    }

    CARB_DEPRECATED("ITasking no longer needed.")
    CounterWrapper(ITasking* tasking, uint32_t target = 0)
        : m_counter(carb::getCachedInterface<ITasking>()->createCounterWithTarget(target))
    {
        CARB_UNUSED(tasking);
    }

    ~CounterWrapper()
    {
        carb::getCachedInterface<ITasking>()->destroyCounter(m_counter);
    }

    CARB_DEPRECATED("The Counter interface is deprecated.") bool check() const
    {
        return try_wait();
    }

    bool try_wait() const
    {
        return carb::getCachedInterface<ITasking>()->try_wait(m_counter);
    }

    void wait() const
    {
        carb::getCachedInterface<ITasking>()->wait(m_counter);
    }

    template <class Rep, class Period>
    bool wait_for(const std::chrono::duration<Rep, Period>& dur) const
    {
        return carb::getCachedInterface<ITasking>()->wait_for(dur, m_counter);
    }

    template <class Clock, class Duration>
    bool wait_until(const std::chrono::time_point<Clock, Duration>& tp) const
    {
        return carb::getCachedInterface<ITasking>()->wait_until(tp, m_counter);
    }

    operator Counter*() const
    {
        return m_counter;
    }

    CARB_DEPRECATED("Use carb::getCachedInterface") ITasking* getTasking() const
    {
        return carb::getCachedInterface<ITasking>();
    }

    CARB_PREVENT_COPY_AND_MOVE(CounterWrapper);

private:
    Counter* m_counter;
};

//! Counts outstanding work through enter()/leave() and lets callers wait until the count drops to zero.
//!
//! The count lives in `m_count`; `Tracker` and `RequiredObject` are friends so they can refer to it directly.
//! The type is neither copyable nor movable.
class TaskGroup
{
public:
    CARB_PREVENT_COPY_AND_MOVE(TaskGroup);

    //! Creates an empty group (count of zero).
    constexpr TaskGroup() = default;

    //! Checks (via CARB_CHECK) that the group is empty at destruction time.
    ~TaskGroup()
    {
        CARB_CHECK(empty(), "Destroying busy TaskGroup!");
    }

    //! @returns `true` if the count is currently zero.
    bool empty() const
    {
        // This cannot be memory_order_relaxed because it does not synchronize with anything and would allow the
        // compiler to cache the value or hoist it out of a loop. Acquire semantics will require synchronization with
        // all other locations that release m_count.
        return m_count.load(std::memory_order_acquire) == 0;
    }

    //! Increments the count. Every call must be balanced by a later leave().
    void enter()
    {
        m_count.fetch_add(1, std::memory_order_acquire); // synchronizes-with all other locations releasing m_count
    }

    //! Decrements the count; when it reaches zero, wakes threads blocked on the count via `ITasking::futexWakeup()`.
    //! Asserts if called without a matching enter().
    void leave()
    {
        size_t v = m_count.fetch_sub(1, std::memory_order_release);
        CARB_ASSERT(v, "Mismatched enter()/leave() calls");
        if (v == 1)
        {
            // `v` is the value before the decrement, so 1 means this call emptied the group.
            carb::getCachedInterface<ITasking>()->futexWakeup(m_count, UINT_MAX);
        }
    }

    //! Non-blocking check; identical to empty().
    bool try_wait() const
    {
        return empty();
    }

    //! Blocks until the group is empty. Returns immediately if it already is; otherwise defers to
    //! `ITasking::wait()`.
    void wait() const
    {
        size_t v = m_count.load(std::memory_order_acquire); // synchronizes-with all other locations releasing m_count
        if (v)
        {
            carb::getCachedInterface<ITasking>()->wait(*this);
        }
    }

    //! Waits for the group to become empty for at most `dur`.
    //! @param dur Relative timeout; converted to a `std::chrono::steady_clock` deadline.
    //! @returns `true` if the group was observed empty; `false` if the wait timed out first.
    template <class Rep, class Period>
    bool try_wait_for(std::chrono::duration<Rep, Period> dur)
    {
        return try_wait_until(std::chrono::steady_clock::now() + dur);
    }

    //! Waits for the group to become empty until `when`.
    //! @param when Absolute deadline forwarded to `ITasking::futexWaitUntil()`.
    //! @returns `true` if the group was observed empty; `false` if `futexWaitUntil()` reported a timeout.
    template <class Clock, class Duration>
    bool try_wait_until(std::chrono::time_point<Clock, Duration> when)
    {
        size_t v = m_count.load(std::memory_order_acquire); // synchronizes-with all other locations releasing m_count
        if (v)
        {
            ITasking* tasking = carb::getCachedInterface<ITasking>();
            while (v)
            {
                if (!tasking->futexWaitUntil(m_count, v, when))
                {
                    return false;
                }
                // Must be an acquire load: if this read observes zero we return `true`, and the caller then expects
                // to see everything published before the final leave(), exactly as with empty() and wait(). A
                // relaxed load here would not synchronize-with the release in leave().
                v = m_count.load(std::memory_order_acquire);
            }
        }
        return true;
    }

    //! Invokes a callable while the group is entered.
    //!
    //! Calls enter(), invokes `args...` through `carb::cpp::invoke()`, and registers leave() in a CARB_SCOPE_EXIT
    //! block (presumably executed when this function's scope exits, including via an exception).
    //! @param args The callable followed by its arguments.
    //! @returns Whatever the invoked callable returns.
    template <class... Args>
    auto with(Args&&... args)
    {
        enter();
        CARB_SCOPE_EXIT
        {
            leave();
        };
        return carb::cpp::invoke(std::forward<Args>(args)...);
    }

private:
    friend struct Tracker;
    friend struct RequiredObject;
    std::atomic_size_t m_count{ 0 }; // number of enter() calls not yet matched by leave()
};

class MutexWrapper
{
public:
    MutexWrapper() : m_mutex(carb::getCachedInterface<ITasking>()->createMutex())
    {
    }

    CARB_DEPRECATED("ITasking no longer needed.")
    MutexWrapper(ITasking*) : m_mutex(carb::getCachedInterface<ITasking>()->createMutex())
    {
    }

    ~MutexWrapper()
    {
        carb::getCachedInterface<ITasking>()->destroyMutex(m_mutex);
    }

    bool try_lock()
    {
        return carb::getCachedInterface<ITasking>()->timedLockMutex(m_mutex, 0);
    }

    void lock()
    {
        carb::getCachedInterface<ITasking>()->lockMutex(m_mutex);
    }

    void unlock()
    {
        carb::getCachedInterface<ITasking>()->unlockMutex(m_mutex);
    }

    template <class Rep, class Period>
    bool try_lock_for(const std::chrono::duration<Rep, Period>& duration)
    {
        return carb::getCachedInterface<ITasking>()->timedLockMutex(m_mutex, detail::convertDuration(duration));
    }

    template <class Clock, class Duration>
    bool try_lock_until(const std::chrono::time_point<Clock, Duration>& time_point)
    {
        return carb::getCachedInterface<ITasking>()->timedLockMutex(m_mutex, detail::convertAbsTime(time_point));
    }

    operator Mutex*() const
    {
        return m_mutex;
    }

    CARB_DEPRECATED("Use carb::getCachedInterface") ITasking* getTasking() const
    {
        return carb::getCachedInterface<ITasking>();
    }

    CARB_PREVENT_COPY_AND_MOVE(MutexWrapper);

private:
    Mutex* m_mutex;
};

class RecursiveMutexWrapper
{
public:
    RecursiveMutexWrapper() : m_mutex(carb::getCachedInterface<ITasking>()->createRecursiveMutex())
    {
    }

    CARB_DEPRECATED("ITasking no longer needed.")
    RecursiveMutexWrapper(ITasking*) : m_mutex(carb::getCachedInterface<ITasking>()->createRecursiveMutex())
    {
    }

    ~RecursiveMutexWrapper()
    {
        carb::getCachedInterface<ITasking>()->destroyMutex(m_mutex);
    }

    bool try_lock()
    {
        return carb::getCachedInterface<ITasking>()->timedLockMutex(m_mutex, 0);
    }

    void lock()
    {
        carb::getCachedInterface<ITasking>()->lockMutex(m_mutex);
    }

    void unlock()
    {
        carb::getCachedInterface<ITasking>()->unlockMutex(m_mutex);
    }

    template <class Rep, class Period>
    bool try_lock_for(const std::chrono::duration<Rep, Period>& duration)
    {
        return carb::getCachedInterface<ITasking>()->timedLockMutex(m_mutex, detail::convertDuration(duration));
    }

    template <class Clock, class Duration>
    bool try_lock_until(const std::chrono::time_point<Clock, Duration>& time_point)
    {
        return carb::getCachedInterface<ITasking>()->timedLockMutex(m_mutex, detail::convertAbsTime(time_point));
    }

    operator Mutex*() const
    {
        return m_mutex;
    }

    CARB_DEPRECATED("Use carb::getCachedInterface") ITasking* getTasking() const
    {
        return carb::getCachedInterface<ITasking>();
    }

    CARB_PREVENT_COPY_AND_MOVE(RecursiveMutexWrapper);

private:
    Mutex* m_mutex;
};

class SemaphoreWrapper
{
public:
    SemaphoreWrapper(unsigned value) : m_sema(carb::getCachedInterface<ITasking>()->createSemaphore(value))
    {
    }

    CARB_DEPRECATED("ITasking no longer needed.")
    SemaphoreWrapper(ITasking*, unsigned value) : m_sema(carb::getCachedInterface<ITasking>()->createSemaphore(value))
    {
    }

    ~SemaphoreWrapper()
    {
        carb::getCachedInterface<ITasking>()->destroySemaphore(m_sema);
    }

    void release(unsigned count = 1)
    {
        carb::getCachedInterface<ITasking>()->releaseSemaphore(m_sema, count);
    }

    void acquire()
    {
        carb::getCachedInterface<ITasking>()->waitSemaphore(m_sema);
    }

    bool try_acquire()
    {
        return carb::getCachedInterface<ITasking>()->timedWaitSemaphore(m_sema, 0);
    }

    template <class Rep, class Period>
    bool try_acquire_for(const std::chrono::duration<Rep, Period>& dur)
    {
        return carb::getCachedInterface<ITasking>()->timedWaitSemaphore(m_sema, detail::convertDuration(dur));
    }

    template <class Clock, class Duration>
    bool try_acquire_until(const std::chrono::time_point<Clock, Duration>& tp)
    {
        return carb::getCachedInterface<ITasking>()->timedWaitSemaphore(m_sema, detail::convertAbsTime(tp));
    }

    operator Semaphore*() const
    {
        return m_sema;
    }

    CARB_DEPRECATED("Use carb::getCachedInterface") ITasking* getTasking() const
    {
        return carb::getCachedInterface<ITasking>();
    }

    CARB_PREVENT_COPY_AND_MOVE(SemaphoreWrapper);

private:
    Semaphore* m_sema;
};

class SharedMutexWrapper
{
public:
    SharedMutexWrapper() : m_mutex(carb::getCachedInterface<ITasking>()->createSharedMutex())
    {
    }

    CARB_DEPRECATED("ITasking no longer needed.")
    SharedMutexWrapper(ITasking*) : m_mutex(carb::getCachedInterface<ITasking>()->createSharedMutex())
    {
    }

    ~SharedMutexWrapper()
    {
        carb::getCachedInterface<ITasking>()->destroySharedMutex(m_mutex);
    }

    bool try_lock_shared()
    {
        return carb::getCachedInterface<ITasking>()->timedLockSharedMutex(m_mutex, 0);
    }

    bool try_lock()
    {
        return carb::getCachedInterface<ITasking>()->timedLockSharedMutexExclusive(m_mutex, 0);
    }

    template <class Rep, class Period>
    bool try_lock_for(const std::chrono::duration<Rep, Period>& duration)
    {
        return carb::getCachedInterface<ITasking>()->timedLockSharedMutexExclusive(
            m_mutex, detail::convertDuration(duration));
    }

    template <class Rep, class Period>
    bool try_lock_shared_for(const std::chrono::duration<Rep, Period>& duration)
    {
        return carb::getCachedInterface<ITasking>()->timedLockSharedMutex(m_mutex, detail::convertDuration(duration));
    }

    template <class Clock, class Duration>
    bool try_lock_until(const std::chrono::time_point<Clock, Duration>& time_point)
    {
        return try_lock_for(time_point - Clock::now());
    }

    template <class Clock, class Duration>
    bool try_lock_shared_until(const std::chrono::time_point<Clock, Duration>& time_point)
    {
        return try_lock_shared_for(time_point - Clock::now());
    }

    void lock_shared()
    {
        carb::getCachedInterface<ITasking>()->lockSharedMutex(m_mutex);
    }

    void unlock_shared()
    {
        carb::getCachedInterface<ITasking>()->unlockSharedMutex(m_mutex);
    }

    void lock()
    {
        carb::getCachedInterface<ITasking>()->lockSharedMutexExclusive(m_mutex);
    }

    void unlock()
    {
        carb::getCachedInterface<ITasking>()->unlockSharedMutex(m_mutex);
    }

    operator SharedMutex*() const
    {
        return m_mutex;
    }

    CARB_DEPRECATED("Use carb::getCachedInterface") ITasking* getTasking() const
    {
        return carb::getCachedInterface<ITasking>();
    }

    CARB_PREVENT_COPY_AND_MOVE(SharedMutexWrapper);

private:
    SharedMutex* m_mutex;
};

class ConditionVariableWrapper
{
public:
    ConditionVariableWrapper() : m_cv(carb::getCachedInterface<ITasking>()->createConditionVariable())
    {
    }

    CARB_DEPRECATED("ITasking no longer needed.")
    ConditionVariableWrapper(ITasking*) : m_cv(carb::getCachedInterface<ITasking>()->createConditionVariable())
    {
    }

    ~ConditionVariableWrapper()
    {
        carb::getCachedInterface<ITasking>()->destroyConditionVariable(m_cv);
    }

    void wait(Mutex* m)
    {
        carb::getCachedInterface<ITasking>()->waitConditionVariable(m_cv, m);
    }

    template <class Pred>
    void wait(Mutex* m, Pred&& pred)
    {
        carb::getCachedInterface<ITasking>()->waitConditionVariablePred(m_cv, m, std::forward<Pred>(pred));
    }

    template <class Rep, class Period>
    std::cv_status wait_for(Mutex* m, const std::chrono::duration<Rep, Period>& duration)
    {
        return carb::getCachedInterface<ITasking>()->timedWaitConditionVariable(
                   m_cv, m, detail::convertDuration(duration)) ?
                   std::cv_status::no_timeout :
                   std::cv_status::timeout;
    }

    template <class Rep, class Period, class Pred>
    bool wait_for(Mutex* m, const std::chrono::duration<Rep, Period>& duration, Pred&& pred)
    {
        return carb::getCachedInterface<ITasking>()->timedWaitConditionVariablePred(
            m_cv, m, detail::convertDuration(duration), std::forward<Pred>(pred));
    }

    template <class Clock, class Duration>
    std::cv_status wait_until(Mutex* m, const std::chrono::time_point<Clock, Duration>& time_point)
    {
        return carb::getCachedInterface<ITasking>()->timedWaitConditionVariable(
                   m_cv, m, detail::convertAbsTime(time_point)) ?
                   std::cv_status::no_timeout :
                   std::cv_status::timeout;
    }

    template <class Clock, class Duration, class Pred>
    bool wait_until(Mutex* m, const std::chrono::time_point<Clock, Duration>& time_point, Pred&& pred)
    {
        return carb::getCachedInterface<ITasking>()->timedWaitConditionVariablePred(
            m_cv, m, detail::convertAbsTime(time_point), std::forward<Pred>(pred));
    }

    void notify_one()
    {
        carb::getCachedInterface<ITasking>()->notifyConditionVariableOne(m_cv);
    }

    void notify_all()
    {
        carb::getCachedInterface<ITasking>()->notifyConditionVariableAll(m_cv);
    }

    operator ConditionVariable*() const
    {
        return m_cv;
    }

    CARB_DEPRECATED("Use carb::getCachedInterface") ITasking* getTasking() const
    {
        return carb::getCachedInterface<ITasking>();
    }

    CARB_PREVENT_COPY_AND_MOVE(ConditionVariableWrapper);

private:
    ConditionVariable* m_cv;
};

//! Scope guard that holds a tracking `Object` and ends the tracking when destroyed.
//!
//! Move-only: copying is disabled; moving hands the held tracker over to the destination. The non-inline members
//! are defined further down in this file.
class ScopedTracking
{
public:
    //! Creates an empty guard holding `{ ObjectType::eNone, nullptr }`; destroying it does not call into ITasking.
    ScopedTracking() : m_tracker{ ObjectType::eNone, nullptr }
    {
    }

    //! Begins tracking for the given trackers via `ITasking::beginTracking()` and stores the returned object.
    ScopedTracking(Trackers trackers);

    //! Calls `ITasking::endTracking()` if the held object is of type `ObjectType::eTrackerGroup`.
    ~ScopedTracking();

    CARB_PREVENT_COPY(ScopedTracking);

    //! Takes over the tracker held by `rhs`, leaving `rhs` empty.
    //! NOTE(review): unlike the move-assignment operator, this is not declared `noexcept`.
    ScopedTracking(ScopedTracking&& rhs);

    //! Swaps the held tracker with `rhs`; the previous tracker of `*this` is then ended when `rhs` is destroyed.
    ScopedTracking& operator=(ScopedTracking&& rhs) noexcept;

private:
    Object m_tracker; // `{ ObjectType::eNone, nullptr }` when nothing is being tracked
};

//! Builds a RequiredObject of type `ObjectType::eTaskGroup` that refers to the group's internal counter.
//! The const_cast strips constness from the counter's address (presumably because `Object` stores a non-const
//! pointer - TODO confirm against the `Object` definition).
inline constexpr RequiredObject::RequiredObject(const TaskGroup& tg)
    : Object{ ObjectType::eTaskGroup, const_cast<std::atomic_size_t*>(&tg.m_count) }
{
}

//! Pointer overload: builds a RequiredObject of type `ObjectType::eTaskGroup` that refers to the group's internal
//! counter, or carries `nullptr` as its payload when `tg` is null (the type is still `eTaskGroup` in that case).
inline constexpr RequiredObject::RequiredObject(const TaskGroup* tg)
    : Object{ ObjectType::eTaskGroup, tg ? const_cast<std::atomic_size_t*>(&tg->m_count) : nullptr }
{
}

//! Groups the listed objects with `ITasking::eAll` and stores the resulting counter in `m_counter`.
//! @param il Objects handed to `ITasking::internalGroupObjects()`.
inline All::All(std::initializer_list<RequiredObject> il)
{
    // The array is passed through unchanged, which relies on RequiredObject having the same size as Object.
    static_assert(sizeof(RequiredObject) == sizeof(Object), "Invalid assumption");
    ITasking* const tasking = carb::getCachedInterface<ITasking>();
    m_counter = tasking->internalGroupObjects(ITasking::eAll, il.begin(), il.size());
}

//! Groups the objects in `[begin, end)` with `ITasking::eAll` and stores the resulting counter in `m_counter`.
//!
//! Overload selected through `detail::IsForwardIter`. The elements are first gathered into contiguous storage
//! because `ITasking::internalGroupObjects()` takes a pointer and a count.
//! @param begin Start of the range.
//! @param end End of the range.
template <class InputIt, std::enable_if_t<detail::IsForwardIter<InputIt, RequiredObject>::value, bool>>
inline All::All(InputIt begin, InputIt end)
{
    static_assert(sizeof(RequiredObject) == sizeof(Object), "Invalid assumption");
    // The iterator-range constructor lets the vector size itself once for multi-pass iterators, instead of growing
    // through repeated push_back() reallocations.
    std::vector<RequiredObject> objects(begin, end);
    m_counter =
        carb::getCachedInterface<ITasking>()->internalGroupObjects(ITasking::eAll, objects.data(), objects.size());
}

//! Groups the objects in `[begin, end)` with `ITasking::eAll` and stores the resulting counter in `m_counter`.
//!
//! Overload selected through `detail::IsRandomAccessIter`: the element count is known up front from `end - begin`,
//! so the elements are copied into a buffer obtained from CARB_STACK_ALLOC (presumably stack storage that is only
//! valid for the duration of this call) rather than a heap container.
//! @param begin Start of the range.
//! @param end End of the range.
template <class InputIt, std::enable_if_t<detail::IsRandomAccessIter<InputIt, RequiredObject>::value, bool>>
inline All::All(InputIt begin, InputIt end)
{
    static_assert(sizeof(RequiredObject) == sizeof(Object), "Invalid assumption");
    size_t const total = end - begin;
    RequiredObject* buffer = CARB_STACK_ALLOC(RequiredObject, total);
    size_t filled = 0;
    while (begin != end)
    {
        buffer[filled] = *begin;
        ++filled;
        ++begin;
    }
    CARB_ASSERT(filled == total);
    ITasking* const tasking = carb::getCachedInterface<ITasking>();
    m_counter = tasking->internalGroupObjects(ITasking::eAll, buffer, total);
}

//! Groups the listed objects with `ITasking::eAny` and stores the resulting counter in `m_counter`.
//! @param il Objects handed to `ITasking::internalGroupObjects()`.
inline Any::Any(std::initializer_list<RequiredObject> il)
{
    // The array is passed through unchanged, which relies on RequiredObject having the same size as Object.
    static_assert(sizeof(RequiredObject) == sizeof(Object), "Invalid assumption");
    ITasking* const tasking = carb::getCachedInterface<ITasking>();
    m_counter = tasking->internalGroupObjects(ITasking::eAny, il.begin(), il.size());
}

//! Groups the objects in `[begin, end)` with `ITasking::eAny` and stores the resulting counter in `m_counter`.
//!
//! Overload selected through `detail::IsForwardIter`. The elements are first gathered into contiguous storage
//! because `ITasking::internalGroupObjects()` takes a pointer and a count.
//! @param begin Start of the range.
//! @param end End of the range.
template <class InputIt, std::enable_if_t<detail::IsForwardIter<InputIt, RequiredObject>::value, bool>>
inline Any::Any(InputIt begin, InputIt end)
{
    static_assert(sizeof(RequiredObject) == sizeof(Object), "Invalid assumption");
    // The iterator-range constructor lets the vector size itself once for multi-pass iterators, instead of growing
    // through repeated push_back() reallocations.
    std::vector<RequiredObject> objects(begin, end);
    m_counter =
        carb::getCachedInterface<ITasking>()->internalGroupObjects(ITasking::eAny, objects.data(), objects.size());
}

//! Groups the objects in `[begin, end)` with `ITasking::eAny` and stores the resulting counter in `m_counter`.
//!
//! Overload selected through `detail::IsRandomAccessIter`: the element count is known up front from `end - begin`,
//! so the elements are copied into a buffer obtained from CARB_STACK_ALLOC (presumably stack storage that is only
//! valid for the duration of this call) rather than a heap container.
//! @param begin Start of the range.
//! @param end End of the range.
template <class InputIt, std::enable_if_t<detail::IsRandomAccessIter<InputIt, RequiredObject>::value, bool>>
inline Any::Any(InputIt begin, InputIt end)
{
    static_assert(sizeof(RequiredObject) == sizeof(Object), "Invalid assumption");
    size_t const total = end - begin;
    RequiredObject* buffer = CARB_STACK_ALLOC(RequiredObject, total);
    size_t filled = 0;
    while (begin != end)
    {
        buffer[filled] = *begin;
        ++filled;
        ++begin;
    }
    CARB_ASSERT(filled == total);
    ITasking* const tasking = carb::getCachedInterface<ITasking>();
    m_counter = tasking->internalGroupObjects(ITasking::eAny, buffer, total);
}

//! Builds a Tracker of type `ObjectType::eTaskGroup` whose payload is the address of the group's internal counter.
inline Tracker::Tracker(TaskGroup& grp) : Object{ ObjectType::eTaskGroup, &grp.m_count }
{
}

//! Pointer overload: builds a Tracker of type `ObjectType::eTaskGroup` whose payload is the address of the group's
//! internal counter, or `nullptr` when `grp` is null (the type is still `eTaskGroup` in that case).
inline Tracker::Tracker(TaskGroup* grp) : Object{ ObjectType::eTaskGroup, grp ? &grp->m_count : nullptr }
{
}

//! Begins tracking for `trackers` and remembers the object returned by `ITasking::beginTracking()`.
//! @param trackers Source of the tracker array; its `output()` member fills in the pointer and count below.
inline ScopedTracking::ScopedTracking(Trackers trackers)
{
    // Both locals are written by Trackers::output() before they are read.
    Tracker const* list;
    size_t listSize;
    trackers.output(list, listSize);

    ITasking* const tasking = carb::getCachedInterface<ITasking>();
    m_tracker = tasking->beginTracking(list, listSize);
}

//! Ends the tracking, if any is active.
//!
//! The held object is first swapped out for `{ ObjectType::eNone, nullptr }`; `ITasking::endTracking()` is called
//! only when the previously held object is of type `ObjectType::eTrackerGroup`.
inline ScopedTracking::~ScopedTracking()
{
    Object released = std::exchange(m_tracker, { ObjectType::eNone, nullptr });
    if (released.type != ObjectType::eTrackerGroup)
    {
        // Empty or moved-from guard: nothing to end.
        return;
    }
    carb::getCachedInterface<ITasking>()->endTracking(released);
}

//! Move constructor: takes the tracker held by `rhs` and leaves `rhs` holding `{ ObjectType::eNone, nullptr }`,
//! so that the destructor of `rhs` will not end the tracking.
inline ScopedTracking::ScopedTracking(ScopedTracking&& rhs)
    : m_tracker(std::exchange(rhs.m_tracker, { ObjectType::eNone, nullptr }))
{
}

//! Move assignment implemented as a swap: `*this` receives the tracker of `rhs`, and `rhs` receives the tracker
//! previously held by `*this`, which is therefore ended only when `rhs` is destroyed (not immediately).
//! @returns `*this`.
inline ScopedTracking& ScopedTracking::operator=(ScopedTracking&& rhs) noexcept
{
    std::swap(m_tracker, rhs.m_tracker);
    return *this;
}

} // namespace tasking
} // namespace carb