// carb/tasking/TaskingUtils.h
// File members: carb/tasking/TaskingUtils.h
// Copyright (c) 2019-2024, NVIDIA CORPORATION. All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto. Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.
//
#pragma once
#include "ITasking.h"
#include "../thread/Util.h"
#include <atomic>
#include <condition_variable> // for std::cv_status
#include <functional>
namespace carb
{
namespace tasking
{
class RecursiveSharedMutex;
//! A user-space spin lock satisfying the C++ Lockable requirement, backed by a single
//! std::atomic_bool. Never enters the kernel, so it is only appropriate for very short
//! critical sections; contended waiters burn CPU (with backoff) instead of sleeping.
struct SpinMutex
{
public:
//! Constructor. The mutex starts unlocked.
constexpr SpinMutex() noexcept = default;
CARB_PREVENT_COPY_AND_MOVE(SpinMutex);
//! Blocks until the lock is acquired, spinning with progressive backoff between attempts.
void lock() noexcept
{
this_thread::spinWaitWithBackoff([&] { return try_lock(); });
}
//! Attempts to acquire the lock without blocking.
//! @returns `true` if the lock was acquired; `false` if it is held by another thread.
bool try_lock() noexcept
{
// Test-and-test-and-set: the relaxed load avoids cache-line write traffic while the
// lock is held; only the winning exchange needs acquire ordering to synchronize with
// the release store in unlock().
return (!mtx.load(std::memory_order_relaxed) && !mtx.exchange(true, std::memory_order_acquire));
}
//! Releases the lock. Must only be called by the current lock holder.
void unlock() noexcept
{
mtx.store(false, std::memory_order_release);
}
private:
// true while locked; false while unlocked.
std::atomic_bool mtx{};
};
//! A user-space reader/writer spin lock satisfying the C++ SharedMutex requirement,
//! backed by a single std::atomic<int>. As with SpinMutex, waiters spin rather than
//! sleep, so keep critical sections short.
struct SpinSharedMutex
{
public:
//! Constructor. The mutex starts unlocked.
constexpr SpinSharedMutex() = default;
CARB_PREVENT_COPY_AND_MOVE(SpinSharedMutex);
//! Blocks until the exclusive (writer) lock is acquired.
void lock()
{
while (!try_lock())
{
CARB_HARDWARE_PAUSE();
}
}
//! Attempts to acquire the exclusive lock without blocking.
//! @returns `true` if acquired (counter transitioned 0 -> -1); `false` otherwise.
bool try_lock()
{
int expected = 0;
return counter.compare_exchange_strong(expected, -1, std::memory_order_acquire, std::memory_order_relaxed);
}
//! Releases the exclusive lock. Must only be called while exclusively locked.
void unlock()
{
CARB_ASSERT(counter == -1);
counter.store(0, std::memory_order_release);
}
//! Attempts to acquire a shared (reader) lock without blocking.
//! @returns `true` if acquired; `false` if an exclusive lock is held or the CAS lost a race.
bool try_lock_shared()
{
auto ctr = counter.load(std::memory_order_relaxed);
if (ctr >= 0)
{
// Single CAS attempt only: a racing reader/writer can cause spurious failure here.
return counter.compare_exchange_strong(ctr, ctr + 1, std::memory_order_acquire, std::memory_order_relaxed);
}
return false;
}
//! Blocks until a shared (reader) lock is acquired.
void lock_shared()
{
auto ctr = counter.load(std::memory_order_relaxed);
for (;;)
{
if (ctr < 0)
{
// Writer holds the lock: pause and re-sample.
CARB_HARDWARE_PAUSE();
ctr = counter.load(std::memory_order_relaxed);
}
// On CAS failure, compare_exchange_strong reloads `ctr` for the next iteration.
else if (counter.compare_exchange_strong(ctr, ctr + 1, std::memory_order_acquire, std::memory_order_relaxed))
{
return;
}
}
}
//! Releases a shared lock. Must only be called while a shared lock is held.
void unlock_shared()
{
int ctr = counter.fetch_sub(1, std::memory_order_release);
CARB_ASSERT(ctr > 0);
CARB_UNUSED(ctr);
}
private:
// 0 - unlocked
// > 0 - Shared lock count
// -1 - Exclusive lock
std::atomic<int> counter{ 0 };
};
//! RAII wrapper for the (deprecated) Counter object: creates a Counter with the given
//! target on construction and destroys it on destruction. Non-copyable and non-movable.
class CounterWrapper
{
public:
//! Constructor.
//! @param target The target value for the underlying Counter (default 0).
CounterWrapper(uint32_t target = 0)
: m_counter(carb::getCachedInterface<ITasking>()->createCounterWithTarget(target))
{
}
//! Deprecated constructor: the passed ITasking pointer is ignored; the cached
//! interface is used instead.
CARB_DEPRECATED("ITasking no longer needed.")
CounterWrapper(ITasking* tasking, uint32_t target = 0)
: m_counter(carb::getCachedInterface<ITasking>()->createCounterWithTarget(target))
{
CARB_UNUSED(tasking);
}
//! Destructor: destroys the underlying Counter.
~CounterWrapper()
{
carb::getCachedInterface<ITasking>()->destroyCounter(m_counter);
}
//! Deprecated alias for try_wait().
CARB_DEPRECATED("The Counter interface is deprecated.") bool check() const
{
return try_wait();
}
//! Non-blocking check of the Counter's completion state.
//! @returns `true` if waiting on the Counter would not block; `false` otherwise.
bool try_wait() const
{
return carb::getCachedInterface<ITasking>()->try_wait(m_counter);
}
//! Blocks until the Counter completes.
void wait() const
{
carb::getCachedInterface<ITasking>()->wait(m_counter);
}
//! Blocks until the Counter completes or the given duration elapses.
//! @returns `true` if the Counter completed within the duration; `false` on timeout.
template <class Rep, class Period>
bool wait_for(const std::chrono::duration<Rep, Period>& dur) const
{
return carb::getCachedInterface<ITasking>()->wait_for(dur, m_counter);
}
//! Blocks until the Counter completes or the given time point is reached.
//! @returns `true` if the Counter completed before the time point; `false` on timeout.
template <class Clock, class Duration>
bool wait_until(const std::chrono::time_point<Clock, Duration>& tp) const
{
return carb::getCachedInterface<ITasking>()->wait_until(tp, m_counter);
}
//! Implicit conversion to the underlying Counter pointer.
operator Counter*() const
{
return m_counter;
}
//! Deprecated: use carb::getCachedInterface<ITasking>() directly.
CARB_DEPRECATED("Use carb::getCachedInterface") ITasking* getTasking() const
{
return carb::getCachedInterface<ITasking>();
}
CARB_PREVENT_COPY_AND_MOVE(CounterWrapper);
private:
// Owned Counter; created in the constructor, destroyed in the destructor.
Counter* m_counter;
};
//! Tracks a group of in-flight tasks: each task calls enter() when it starts and leave()
//! when it finishes, and waiters block (via ITasking's futex operations on m_count) until
//! the count returns to zero. Non-copyable and non-movable.
class TaskGroup
{
public:
CARB_PREVENT_COPY_AND_MOVE(TaskGroup);
//! Constructor. The group starts empty.
constexpr TaskGroup() = default;
//! Destructor. CARB_CHECKs that no tasks are still registered with this group.
~TaskGroup()
{
CARB_CHECK(empty(), "Destroying busy TaskGroup!");
}
//! @returns `true` if no tasks are currently registered with this group.
bool empty() const
{
// This cannot be memory_order_relaxed because it does not synchronize with anything and would allow the
// compiler to cache the value or hoist it out of a loop. Acquire semantics will require synchronization with
// all other locations that release m_count.
return m_count.load(std::memory_order_acquire) == 0;
}
//! Registers one task with the group. Every enter() must be balanced by a leave().
void enter()
{
m_count.fetch_add(1, std::memory_order_acquire); // synchronizes-with all other locations releasing m_count
}
//! Unregisters one task; when the count reaches zero, wakes all waiters.
void leave()
{
size_t v = m_count.fetch_sub(1, std::memory_order_release);
CARB_ASSERT(v, "Mismatched enter()/leave() calls");
// v is the pre-decrement value, so v == 1 means the count just hit zero.
if (v == 1)
{
carb::getCachedInterface<ITasking>()->futexWakeup(m_count, UINT_MAX);
}
}
//! Non-blocking check: same as empty().
//! @returns `true` if the group is empty.
bool try_wait() const
{
return empty();
}
//! Blocks until the group becomes empty.
void wait() const
{
size_t v = m_count.load(std::memory_order_acquire); // synchronizes-with all other locations releasing m_count
if (v)
{
carb::getCachedInterface<ITasking>()->wait(*this);
}
}
//! Blocks until the group becomes empty or the duration elapses.
//! @returns `true` if the group became empty within the duration; `false` on timeout.
template <class Rep, class Period>
bool try_wait_for(std::chrono::duration<Rep, Period> dur)
{
return try_wait_until(std::chrono::steady_clock::now() + dur);
}
//! Blocks until the group becomes empty or the time point is reached.
//! @returns `true` if the group became empty before the time point; `false` on timeout.
template <class Clock, class Duration>
bool try_wait_until(std::chrono::time_point<Clock, Duration> when)
{
size_t v = m_count.load(std::memory_order_acquire); // synchronizes-with all other locations releasing m_count
if (v)
{
ITasking* tasking = carb::getCachedInterface<ITasking>();
// Re-check after each futex wake; the count may have changed without reaching zero.
while (v)
{
if (!tasking->futexWaitUntil(m_count, v, when))
{
return false;
}
v = m_count.load(std::memory_order_relaxed);
}
}
return true;
}
//! Invokes a callable within this group: enter() before, leave() after (the scope-exit
//! guard runs leave() when the call's scope unwinds).
//! @returns whatever the callable returns.
template <class... Args>
auto with(Args&&... args)
{
enter();
CARB_SCOPE_EXIT
{
leave();
};
return carb::cpp::invoke(std::forward<Args>(args)...);
}
private:
// Tracker and RequiredObject reference m_count directly for futex-based waiting.
friend struct Tracker;
friend struct RequiredObject;
// Number of tasks currently registered; waiters futex-wait on this value.
std::atomic_size_t m_count{ 0 };
};
//! RAII wrapper for ITasking's Mutex, satisfying the C++ TimedLockable requirement.
//! Creates the Mutex on construction and destroys it on destruction.
class MutexWrapper
{
public:
//! Constructor.
MutexWrapper() : m_mutex(carb::getCachedInterface<ITasking>()->createMutex())
{
}
//! Deprecated constructor: the passed ITasking pointer is ignored.
CARB_DEPRECATED("ITasking no longer needed.")
MutexWrapper(ITasking*) : m_mutex(carb::getCachedInterface<ITasking>()->createMutex())
{
}
//! Destructor: destroys the underlying Mutex.
~MutexWrapper()
{
carb::getCachedInterface<ITasking>()->destroyMutex(m_mutex);
}
//! Attempts to lock without blocking (a timed lock with zero timeout).
//! @returns `true` if the lock was acquired.
bool try_lock()
{
return carb::getCachedInterface<ITasking>()->timedLockMutex(m_mutex, 0);
}
//! Blocks until the mutex is locked.
void lock()
{
carb::getCachedInterface<ITasking>()->lockMutex(m_mutex);
}
//! Unlocks the mutex. Must be called by the lock holder.
void unlock()
{
carb::getCachedInterface<ITasking>()->unlockMutex(m_mutex);
}
//! Attempts to lock, giving up after the given duration.
//! @returns `true` if the lock was acquired within the duration.
template <class Rep, class Period>
bool try_lock_for(const std::chrono::duration<Rep, Period>& duration)
{
return carb::getCachedInterface<ITasking>()->timedLockMutex(m_mutex, detail::convertDuration(duration));
}
//! Attempts to lock, giving up at the given time point.
//! @returns `true` if the lock was acquired before the time point.
template <class Clock, class Duration>
bool try_lock_until(const std::chrono::time_point<Clock, Duration>& time_point)
{
return carb::getCachedInterface<ITasking>()->timedLockMutex(m_mutex, detail::convertAbsTime(time_point));
}
//! Implicit conversion to the underlying Mutex pointer.
operator Mutex*() const
{
return m_mutex;
}
//! Deprecated: use carb::getCachedInterface<ITasking>() directly.
CARB_DEPRECATED("Use carb::getCachedInterface") ITasking* getTasking() const
{
return carb::getCachedInterface<ITasking>();
}
CARB_PREVENT_COPY_AND_MOVE(MutexWrapper);
private:
// Owned Mutex; created in the constructor, destroyed in the destructor.
Mutex* m_mutex;
};
//! RAII wrapper for ITasking's recursive Mutex (created via createRecursiveMutex),
//! satisfying the C++ TimedLockable requirement. Mirrors MutexWrapper otherwise.
class RecursiveMutexWrapper
{
public:
//! Constructor.
RecursiveMutexWrapper() : m_mutex(carb::getCachedInterface<ITasking>()->createRecursiveMutex())
{
}
//! Deprecated constructor: the passed ITasking pointer is ignored.
CARB_DEPRECATED("ITasking no longer needed.")
RecursiveMutexWrapper(ITasking*) : m_mutex(carb::getCachedInterface<ITasking>()->createRecursiveMutex())
{
}
//! Destructor: destroys the underlying Mutex.
~RecursiveMutexWrapper()
{
carb::getCachedInterface<ITasking>()->destroyMutex(m_mutex);
}
//! Attempts to lock without blocking (a timed lock with zero timeout).
//! @returns `true` if the lock was acquired (or recursively re-acquired).
bool try_lock()
{
return carb::getCachedInterface<ITasking>()->timedLockMutex(m_mutex, 0);
}
//! Blocks until the mutex is locked.
void lock()
{
carb::getCachedInterface<ITasking>()->lockMutex(m_mutex);
}
//! Unlocks the mutex once. Must be called by the lock holder.
void unlock()
{
carb::getCachedInterface<ITasking>()->unlockMutex(m_mutex);
}
//! Attempts to lock, giving up after the given duration.
//! @returns `true` if the lock was acquired within the duration.
template <class Rep, class Period>
bool try_lock_for(const std::chrono::duration<Rep, Period>& duration)
{
return carb::getCachedInterface<ITasking>()->timedLockMutex(m_mutex, detail::convertDuration(duration));
}
//! Attempts to lock, giving up at the given time point.
//! @returns `true` if the lock was acquired before the time point.
template <class Clock, class Duration>
bool try_lock_until(const std::chrono::time_point<Clock, Duration>& time_point)
{
return carb::getCachedInterface<ITasking>()->timedLockMutex(m_mutex, detail::convertAbsTime(time_point));
}
//! Implicit conversion to the underlying Mutex pointer.
operator Mutex*() const
{
return m_mutex;
}
//! Deprecated: use carb::getCachedInterface<ITasking>() directly.
CARB_DEPRECATED("Use carb::getCachedInterface") ITasking* getTasking() const
{
return carb::getCachedInterface<ITasking>();
}
CARB_PREVENT_COPY_AND_MOVE(RecursiveMutexWrapper);
private:
// Owned recursive Mutex; created in the constructor, destroyed in the destructor.
Mutex* m_mutex;
};
//! RAII wrapper for ITasking's Semaphore, modeled after std::counting_semaphore's API
//! (acquire/release/try_acquire). Creates the Semaphore on construction, destroys it on
//! destruction.
class SemaphoreWrapper
{
public:
//! Constructor.
//! @param value The initial count of the semaphore.
SemaphoreWrapper(unsigned value) : m_sema(carb::getCachedInterface<ITasking>()->createSemaphore(value))
{
}
//! Deprecated constructor: the passed ITasking pointer is ignored.
CARB_DEPRECATED("ITasking no longer needed.")
SemaphoreWrapper(ITasking*, unsigned value) : m_sema(carb::getCachedInterface<ITasking>()->createSemaphore(value))
{
}
//! Destructor: destroys the underlying Semaphore.
~SemaphoreWrapper()
{
carb::getCachedInterface<ITasking>()->destroySemaphore(m_sema);
}
//! Increments the semaphore count, potentially unblocking waiters.
//! @param count The number of increments to apply (default 1).
void release(unsigned count = 1)
{
carb::getCachedInterface<ITasking>()->releaseSemaphore(m_sema, count);
}
//! Blocks until the semaphore count can be decremented.
void acquire()
{
carb::getCachedInterface<ITasking>()->waitSemaphore(m_sema);
}
//! Attempts to decrement without blocking (a timed wait with zero timeout).
//! @returns `true` if the count was decremented.
bool try_acquire()
{
return carb::getCachedInterface<ITasking>()->timedWaitSemaphore(m_sema, 0);
}
//! Attempts to decrement, giving up after the given duration.
//! @returns `true` if the count was decremented within the duration.
template <class Rep, class Period>
bool try_acquire_for(const std::chrono::duration<Rep, Period>& dur)
{
return carb::getCachedInterface<ITasking>()->timedWaitSemaphore(m_sema, detail::convertDuration(dur));
}
//! Attempts to decrement, giving up at the given time point.
//! @returns `true` if the count was decremented before the time point.
template <class Clock, class Duration>
bool try_acquire_until(const std::chrono::time_point<Clock, Duration>& tp)
{
return carb::getCachedInterface<ITasking>()->timedWaitSemaphore(m_sema, detail::convertAbsTime(tp));
}
//! Implicit conversion to the underlying Semaphore pointer.
operator Semaphore*() const
{
return m_sema;
}
//! Deprecated: use carb::getCachedInterface<ITasking>() directly.
CARB_DEPRECATED("Use carb::getCachedInterface") ITasking* getTasking() const
{
return carb::getCachedInterface<ITasking>();
}
CARB_PREVENT_COPY_AND_MOVE(SemaphoreWrapper);
private:
// Owned Semaphore; created in the constructor, destroyed in the destructor.
Semaphore* m_sema;
};
//! RAII wrapper for ITasking's SharedMutex, satisfying the C++ SharedTimedMutex
//! requirement (shared/exclusive lock, timed variants). Creates the SharedMutex on
//! construction, destroys it on destruction.
class SharedMutexWrapper
{
public:
//! Constructor.
SharedMutexWrapper() : m_mutex(carb::getCachedInterface<ITasking>()->createSharedMutex())
{
}
//! Deprecated constructor: the passed ITasking pointer is ignored.
CARB_DEPRECATED("ITasking no longer needed.")
SharedMutexWrapper(ITasking*) : m_mutex(carb::getCachedInterface<ITasking>()->createSharedMutex())
{
}
//! Destructor: destroys the underlying SharedMutex.
~SharedMutexWrapper()
{
carb::getCachedInterface<ITasking>()->destroySharedMutex(m_mutex);
}
//! Attempts to take a shared (reader) lock without blocking.
//! @returns `true` if the shared lock was acquired.
bool try_lock_shared()
{
return carb::getCachedInterface<ITasking>()->timedLockSharedMutex(m_mutex, 0);
}
//! Attempts to take an exclusive (writer) lock without blocking.
//! @returns `true` if the exclusive lock was acquired.
bool try_lock()
{
return carb::getCachedInterface<ITasking>()->timedLockSharedMutexExclusive(m_mutex, 0);
}
//! Attempts to take an exclusive lock, giving up after the given duration.
//! @returns `true` if the exclusive lock was acquired within the duration.
template <class Rep, class Period>
bool try_lock_for(const std::chrono::duration<Rep, Period>& duration)
{
return carb::getCachedInterface<ITasking>()->timedLockSharedMutexExclusive(
m_mutex, detail::convertDuration(duration));
}
//! Attempts to take a shared lock, giving up after the given duration.
//! @returns `true` if the shared lock was acquired within the duration.
template <class Rep, class Period>
bool try_lock_shared_for(const std::chrono::duration<Rep, Period>& duration)
{
return carb::getCachedInterface<ITasking>()->timedLockSharedMutex(m_mutex, detail::convertDuration(duration));
}
//! Attempts to take an exclusive lock, giving up at the given time point.
//! @returns `true` if the exclusive lock was acquired before the time point.
template <class Clock, class Duration>
bool try_lock_until(const std::chrono::time_point<Clock, Duration>& time_point)
{
// Converts to a relative duration against the caller's Clock before delegating.
return try_lock_for(time_point - Clock::now());
}
//! Attempts to take a shared lock, giving up at the given time point.
//! @returns `true` if the shared lock was acquired before the time point.
template <class Clock, class Duration>
bool try_lock_shared_until(const std::chrono::time_point<Clock, Duration>& time_point)
{
return try_lock_shared_for(time_point - Clock::now());
}
//! Blocks until a shared (reader) lock is acquired.
void lock_shared()
{
carb::getCachedInterface<ITasking>()->lockSharedMutex(m_mutex);
}
//! Releases a shared lock.
void unlock_shared()
{
carb::getCachedInterface<ITasking>()->unlockSharedMutex(m_mutex);
}
//! Blocks until an exclusive (writer) lock is acquired.
void lock()
{
carb::getCachedInterface<ITasking>()->lockSharedMutexExclusive(m_mutex);
}
//! Releases an exclusive lock.
//! NOTE(review): this routes through unlockSharedMutex(), the same call as
//! unlock_shared(); presumably ITasking detects the exclusive state — confirm against
//! the ITasking documentation.
void unlock()
{
carb::getCachedInterface<ITasking>()->unlockSharedMutex(m_mutex);
}
//! Implicit conversion to the underlying SharedMutex pointer.
operator SharedMutex*() const
{
return m_mutex;
}
//! Deprecated: use carb::getCachedInterface<ITasking>() directly.
CARB_DEPRECATED("Use carb::getCachedInterface") ITasking* getTasking() const
{
return carb::getCachedInterface<ITasking>();
}
CARB_PREVENT_COPY_AND_MOVE(SharedMutexWrapper);
private:
// Owned SharedMutex; created in the constructor, destroyed in the destructor.
SharedMutex* m_mutex;
};
//! RAII wrapper for ITasking's ConditionVariable, modeled after std::condition_variable
//! but operating on an ITasking Mutex*. Creates the ConditionVariable on construction,
//! destroys it on destruction.
class ConditionVariableWrapper
{
public:
//! Constructor.
ConditionVariableWrapper() : m_cv(carb::getCachedInterface<ITasking>()->createConditionVariable())
{
}
//! Deprecated constructor: the passed ITasking pointer is ignored.
CARB_DEPRECATED("ITasking no longer needed.")
ConditionVariableWrapper(ITasking*) : m_cv(carb::getCachedInterface<ITasking>()->createConditionVariable())
{
}
//! Destructor: destroys the underlying ConditionVariable.
~ConditionVariableWrapper()
{
carb::getCachedInterface<ITasking>()->destroyConditionVariable(m_cv);
}
//! Waits until notified.
//! @param m The Mutex, which must be locked by the caller.
void wait(Mutex* m)
{
carb::getCachedInterface<ITasking>()->waitConditionVariable(m_cv, m);
}
//! Waits until notified and the predicate returns `true`.
//! @param m The Mutex, which must be locked by the caller.
//! @param pred The predicate to evaluate; waiting continues while it returns `false`.
template <class Pred>
void wait(Mutex* m, Pred&& pred)
{
carb::getCachedInterface<ITasking>()->waitConditionVariablePred(m_cv, m, std::forward<Pred>(pred));
}
//! Waits until notified or the given duration elapses.
//! @returns std::cv_status::no_timeout if notified; std::cv_status::timeout otherwise.
template <class Rep, class Period>
std::cv_status wait_for(Mutex* m, const std::chrono::duration<Rep, Period>& duration)
{
return carb::getCachedInterface<ITasking>()->timedWaitConditionVariable(
m_cv, m, detail::convertDuration(duration)) ?
std::cv_status::no_timeout :
std::cv_status::timeout;
}
//! Waits until the predicate returns `true` or the given duration elapses.
//! @returns the final result of the predicate.
template <class Rep, class Period, class Pred>
bool wait_for(Mutex* m, const std::chrono::duration<Rep, Period>& duration, Pred&& pred)
{
return carb::getCachedInterface<ITasking>()->timedWaitConditionVariablePred(
m_cv, m, detail::convertDuration(duration), std::forward<Pred>(pred));
}
//! Waits until notified or the given time point is reached.
//! @returns std::cv_status::no_timeout if notified; std::cv_status::timeout otherwise.
template <class Clock, class Duration>
std::cv_status wait_until(Mutex* m, const std::chrono::time_point<Clock, Duration>& time_point)
{
return carb::getCachedInterface<ITasking>()->timedWaitConditionVariable(
m_cv, m, detail::convertAbsTime(time_point)) ?
std::cv_status::no_timeout :
std::cv_status::timeout;
}
//! Waits until the predicate returns `true` or the given time point is reached.
//! @returns the final result of the predicate.
template <class Clock, class Duration, class Pred>
bool wait_until(Mutex* m, const std::chrono::time_point<Clock, Duration>& time_point, Pred&& pred)
{
return carb::getCachedInterface<ITasking>()->timedWaitConditionVariablePred(
m_cv, m, detail::convertAbsTime(time_point), std::forward<Pred>(pred));
}
//! Notifies one waiter.
void notify_one()
{
carb::getCachedInterface<ITasking>()->notifyConditionVariableOne(m_cv);
}
//! Notifies all waiters.
void notify_all()
{
carb::getCachedInterface<ITasking>()->notifyConditionVariableAll(m_cv);
}
//! Implicit conversion to the underlying ConditionVariable pointer.
operator ConditionVariable*() const
{
return m_cv;
}
//! Deprecated: use carb::getCachedInterface<ITasking>() directly.
CARB_DEPRECATED("Use carb::getCachedInterface") ITasking* getTasking() const
{
return carb::getCachedInterface<ITasking>();
}
CARB_PREVENT_COPY_AND_MOVE(ConditionVariableWrapper);
private:
// Owned ConditionVariable; created in the constructor, destroyed in the destructor.
ConditionVariable* m_cv;
};
//! RAII object that begins tracking against a set of Trackers on construction and ends
//! that tracking on destruction. Movable but not copyable; member functions are defined
//! out-of-line below.
class ScopedTracking
{
public:
//! Default constructor: tracks nothing.
ScopedTracking() : m_tracker{ ObjectType::eNone, nullptr }
{
}
//! Begins tracking the given set of Trackers.
ScopedTracking(Trackers trackers);
//! Destructor: ends tracking if this object holds an active tracker group.
~ScopedTracking();
CARB_PREVENT_COPY(ScopedTracking);
//! Move constructor: takes over tracking from `rhs`, which is left tracking nothing.
ScopedTracking(ScopedTracking&& rhs);
//! Move assignment: exchanges tracking state with `rhs`; any previously-held tracker
//! is ended when `rhs` is destroyed.
ScopedTracking& operator=(ScopedTracking&& rhs) noexcept;
private:
// The tracking handle returned by ITasking::beginTracking(); eNone when inactive.
Object m_tracker;
};
//! Constructs a RequiredObject that refers to a TaskGroup's internal count, allowing a
//! TaskGroup to be used as a task prerequisite.
//! @param tg The TaskGroup to wait on.
//! @note const_cast is needed because Object stores a non-const pointer; waiting does
//! not logically modify the TaskGroup.
inline constexpr RequiredObject::RequiredObject(const TaskGroup& tg)
: Object{ ObjectType::eTaskGroup, const_cast<std::atomic_size_t*>(&tg.m_count) }
{
}
//! Pointer overload of the above; a `nullptr` TaskGroup produces a null Object pointer.
//! @param tg The TaskGroup to wait on, or `nullptr` for none.
inline constexpr RequiredObject::RequiredObject(const TaskGroup* tg)
: Object{ ObjectType::eTaskGroup, tg ? const_cast<std::atomic_size_t*>(&tg->m_count) : nullptr }
{
}
//! Constructs an All object from a braced list of RequiredObject entries; waiting on the
//! result waits until every listed object has completed.
//! @param il The list of objects to group.
inline All::All(std::initializer_list<RequiredObject> il)
{
// RequiredObject must be layout-compatible with Object so the array can be passed as-is.
static_assert(sizeof(RequiredObject) == sizeof(Object), "Invalid assumption");
m_counter = carb::getCachedInterface<ITasking>()->internalGroupObjects(ITasking::eAll, il.begin(), il.size());
}
//! Constructs an All object from RequiredObject entries denoted by a pair of forward
//! iterators; waiting on the result waits until every object in the range has completed.
//! @param begin Iterator to the first RequiredObject.
//! @param end Off-the-end iterator.
template <class InputIt, std::enable_if_t<detail::IsForwardIter<InputIt, RequiredObject>::value, bool>>
inline All::All(InputIt begin, InputIt end)
{
    // RequiredObject must be layout-compatible with Object so the array can be passed as-is.
    static_assert(sizeof(RequiredObject) == sizeof(Object), "Invalid assumption");
    // Forward iterators let the range constructor measure the range up front, so the
    // vector makes a single allocation instead of growing through repeated push_back().
    std::vector<RequiredObject> objects(begin, end);
    m_counter =
        carb::getCachedInterface<ITasking>()->internalGroupObjects(ITasking::eAll, objects.data(), objects.size());
}
//! Constructs an All object from RequiredObject entries denoted by a pair of
//! random-access iterators; waiting on the result waits until every object in the range
//! has completed. Uses stack allocation since the count is known up front.
//! @param begin Iterator to the first RequiredObject.
//! @param end Off-the-end iterator.
template <class InputIt, std::enable_if_t<detail::IsRandomAccessIter<InputIt, RequiredObject>::value, bool>>
inline All::All(InputIt begin, InputIt end)
{
// RequiredObject must be layout-compatible with Object so the array can be passed as-is.
static_assert(sizeof(RequiredObject) == sizeof(Object), "Invalid assumption");
size_t const count = end - begin;
// Stack allocation (alloca-style) avoids heap traffic; the buffer lives until this
// constructor returns, which outlasts the internalGroupObjects() call below.
RequiredObject* objects = CARB_STACK_ALLOC(RequiredObject, count);
size_t index = 0;
for (; begin != end; ++begin)
objects[index++] = *begin;
CARB_ASSERT(index == count);
m_counter = carb::getCachedInterface<ITasking>()->internalGroupObjects(ITasking::eAll, objects, count);
}
//! Constructs an Any object from a braced list of RequiredObject entries; waiting on the
//! result waits until at least one listed object has completed.
//! @param il The list of objects to group.
inline Any::Any(std::initializer_list<RequiredObject> il)
{
// RequiredObject must be layout-compatible with Object so the array can be passed as-is.
static_assert(sizeof(RequiredObject) == sizeof(Object), "Invalid assumption");
m_counter = carb::getCachedInterface<ITasking>()->internalGroupObjects(ITasking::eAny, il.begin(), il.size());
}
//! Constructs an Any object from RequiredObject entries denoted by a pair of forward
//! iterators; waiting on the result waits until at least one object in the range has
//! completed.
//! @param begin Iterator to the first RequiredObject.
//! @param end Off-the-end iterator.
template <class InputIt, std::enable_if_t<detail::IsForwardIter<InputIt, RequiredObject>::value, bool>>
inline Any::Any(InputIt begin, InputIt end)
{
    // RequiredObject must be layout-compatible with Object so the array can be passed as-is.
    static_assert(sizeof(RequiredObject) == sizeof(Object), "Invalid assumption");
    // Forward iterators let the range constructor measure the range up front, so the
    // vector makes a single allocation instead of growing through repeated push_back().
    std::vector<RequiredObject> objects(begin, end);
    m_counter =
        carb::getCachedInterface<ITasking>()->internalGroupObjects(ITasking::eAny, objects.data(), objects.size());
}
//! Constructs an Any object from RequiredObject entries denoted by a pair of
//! random-access iterators; waiting on the result waits until at least one object in the
//! range has completed. Uses stack allocation since the count is known up front.
//! @param begin Iterator to the first RequiredObject.
//! @param end Off-the-end iterator.
template <class InputIt, std::enable_if_t<detail::IsRandomAccessIter<InputIt, RequiredObject>::value, bool>>
inline Any::Any(InputIt begin, InputIt end)
{
// RequiredObject must be layout-compatible with Object so the array can be passed as-is.
static_assert(sizeof(RequiredObject) == sizeof(Object), "Invalid assumption");
size_t const count = end - begin;
// Stack allocation (alloca-style) avoids heap traffic; the buffer lives until this
// constructor returns, which outlasts the internalGroupObjects() call below.
RequiredObject* objects = CARB_STACK_ALLOC(RequiredObject, count);
size_t index = 0;
for (; begin != end; ++begin)
objects[index++] = *begin;
CARB_ASSERT(index == count);
m_counter = carb::getCachedInterface<ITasking>()->internalGroupObjects(ITasking::eAny, objects, count);
}
//! Constructs a Tracker that reports task lifetime into the given TaskGroup's count.
//! @param grp The TaskGroup to track against.
inline Tracker::Tracker(TaskGroup& grp) : Object{ ObjectType::eTaskGroup, &grp.m_count }
{
}
//! Pointer overload of the above; a `nullptr` TaskGroup produces a null Object pointer.
//! @param grp The TaskGroup to track against, or `nullptr` for none.
inline Tracker::Tracker(TaskGroup* grp) : Object{ ObjectType::eTaskGroup, grp ? &grp->m_count : nullptr }
{
}
//! Begins tracking against the given set of Trackers; tracking ends when this object is
//! destroyed.
//! @param trackers The Trackers to begin tracking against.
inline ScopedTracking::ScopedTracking(Trackers trackers)
{
    Tracker const* trackerList;
    size_t trackerCount;
    // Unpack the Trackers helper into a raw pointer/count pair for the interface call.
    trackers.output(trackerList, trackerCount);
    ITasking* tasking = carb::getCachedInterface<ITasking>();
    m_tracker = tasking->beginTracking(trackerList, trackerCount);
}
//! Destructor: ends tracking if this object holds an active tracker group, then resets
//! the handle to the inactive state.
inline ScopedTracking::~ScopedTracking()
{
    // Only handles produced by beginTracking() (eTrackerGroup) need to be ended.
    if (m_tracker.type == ObjectType::eTrackerGroup)
    {
        carb::getCachedInterface<ITasking>()->endTracking(m_tracker);
    }
    m_tracker = { ObjectType::eNone, nullptr };
}
//! Move constructor: takes over the tracking handle from `rhs`, leaving `rhs` holding
//! the inactive (eNone) state so its destructor is a no-op.
inline ScopedTracking::ScopedTracking(ScopedTracking&& rhs)
: m_tracker(std::exchange(rhs.m_tracker, { ObjectType::eNone, nullptr }))
{
}
//! Move assignment: swaps tracking handles with `rhs`. Any handle previously held by
//! *this is then owned by `rhs` and will be ended when `rhs` is destroyed.
inline ScopedTracking& ScopedTracking::operator=(ScopedTracking&& rhs) noexcept
{
std::swap(m_tracker, rhs.m_tracker);
return *this;
}
} // namespace tasking
} // namespace carb