carb/tasking/ITasking.h

File members: carb/tasking/ITasking.h

// Copyright (c) 2018-2024, NVIDIA CORPORATION. All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto. Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.
//

#pragma once

#include "../Interface.h"
#include "../InterfaceUtils.h"
#include "TaskingHelpers.h"

namespace carb
{

namespace tasking
{

// Returns a TaskingDesc whose fields are all value-/aggregate-initialized, i.e. left at whatever defaults the
// TaskingDesc definition (declared elsewhere) provides.
inline TaskingDesc getDefaultTaskingDesc()
{
    TaskingDesc defaults{};
    return defaults;
}

// ITasking: the Carbonite tasking plugin interface.
//
// The struct mixes two kinds of members:
//  - `CARB_ABI` function-pointer slots, filled in by the plugin. Their order, types and count make up the binary
//    layout of the interface (the version history below records where a change broke ABI compatibility), so they
//    must not be reordered, removed, or have their signatures changed. The most recent additions (internalNameTask
//    in 2.3, reloadFiberEvents in 2.4) are the last two slots.
//  - Non-virtual inline/template helpers and a nested enum. These add no data members and therefore do not affect
//    the layout, even where they appear between slots. Helpers that are only declared here are defined out of
//    line (presumably in ITasking.inl, which is included at the end of this header - TODO confirm).
struct ITasking
{
    // 0.1 - Initial version
    // 0.2 - Thread pinning, sleep, suspending/waking (not ABI compatible with 0.1)
    // 0.3 - Semaphore support, SharedMutex support
    // 0.4 - ConditionVariable support
    // 1.0 - Wait timeouts (git hash e13289c5a5)
    // 1.1 - changeTaskPriority() / executeMainTasks()
    // 1.2 - restart() -> changeParameters(); don't lose tasks when changing parameters
    // 1.3 - Respect task priority when resuming tasks that have slept, waited, or unsuspended (not an API change)
    // 1.4 - Stuck checking (not an API change)
    // 1.5 - internalGroupCounters()
    // 1.6 - createRecursiveMutex()
    // 2.0 - ITasking 2.0 (git hash f68ae95da7)
    // 2.1 - allocTaskStorage() / freeTaskStorage() / setTaskStorage() / getTaskStorage()
    // 2.2 - beginTracking() / endTracking()
    // 2.3 - internalNameTask()
    // 2.4 - reloadFiberEvents()
    CARB_PLUGIN_INTERFACE("carb::tasking::ITasking", 2, 4)

    // ---- Configuration ----

    // Applies new tasking parameters. It replaced restart() in 1.2, and since 1.2 tasks are not lost when the
    // parameters change.
    void(CARB_ABI* changeParameters)(TaskingDesc desc);

    // Returns the current TaskingDesc by reference.
    // NOTE(review): presumably only valid until the next changeParameters() call - confirm before caching it.
    const TaskingDesc&(CARB_ABI* getDesc)();

    // ---- Counters (legacy; the C++ wrappers further below are deprecated) ----

    // Creates a Counter (paired with destroyCounter()).
    Counter*(CARB_ABI* createCounter)();

    // Creates a Counter with an explicit target value (see getCounterTarget()).
    Counter*(CARB_ABI* createCounterWithTarget)(uint32_t target);

    // Destroys a Counter obtained from createCounter() / createCounterWithTarget().
    void(CARB_ABI* destroyCounter)(Counter* counter);

    // ---- Low-level task submission (prefer the C++ addTask() family of helpers below) ----

    // Adds one task described by `task`. `counter` may associate the task with a Counter (semantics TODO confirm).
    TaskContext(CARB_ABI* internalAddTask)(TaskDesc task, Counter* counter);

    // Adds `taskCount` tasks from the `tasks` array, all associated with `counter`.
    void(CARB_ABI* addTasks)(TaskDesc* tasks, size_t taskCount, Counter* counter);

    // Like internalAddTask(), but presumably starts the task only after `delayNs` nanoseconds.
    TaskContext(CARB_ABI* internalAddDelayedTask)(uint64_t delayNs, TaskDesc desc, Counter* counter);

    // Low-level backend of applyRange(): presumably invokes `fn` (with `context`) over the indices of `range`.
    void(CARB_ABI* internalApplyRange)(size_t range, ApplyFn fn, void* context);

    // Deprecated counter waits; defined out of line.
    CARB_DEPRECATED("Use wait() instead") void yieldUntilCounter(RequiredObject counter);

    CARB_DEPRECATED("Use wait_for() or wait_until() instead.")
    bool timedYieldUntilCounter(RequiredObject counter, uint64_t timeoutNs);

    // Raw Counter accessors used by the deprecated wrappers below.
    bool(CARB_ABI* internalCheckCounter)(Counter* counter);
    uint32_t(CARB_ABI* internalGetCounterValue)(Counter* counter);
    uint32_t(CARB_ABI* internalGetCounterTarget)(Counter* counter);
    uint32_t(CARB_ABI* internalFetchAddCounter)(Counter* counter, uint32_t value);
    uint32_t(CARB_ABI* internalFetchSubCounter)(Counter* counter, uint32_t value);
    void(CARB_ABI* internalStoreCounter)(Counter* counter, uint32_t value);

    // Deprecated wrapper: forwards to internalCheckCounter().
    CARB_DEPRECATED("The Counter interface is deprecated.") bool checkCounter(Counter* c)
    {
        return internalCheckCounter(c);
    }

    // Deprecated wrapper: forwards to internalGetCounterValue().
    CARB_DEPRECATED("The Counter interface is deprecated.") uint32_t getCounterValue(Counter* counter)
    {
        return internalGetCounterValue(counter);
    }

    // Deprecated wrapper: forwards to internalGetCounterTarget().
    CARB_DEPRECATED("The Counter interface is deprecated.") uint32_t getCounterTarget(Counter* counter)
    {
        return internalGetCounterTarget(counter);
    }

    // Deprecated wrapper: forwards to internalFetchAddCounter() and returns its result.
    CARB_DEPRECATED("The Counter interface is deprecated.") uint32_t fetchAddCounter(Counter* counter, uint32_t value)
    {
        return internalFetchAddCounter(counter, value);
    }

    // Deprecated wrapper: forwards to internalFetchSubCounter() and returns its result.
    CARB_DEPRECATED("The Counter interface is deprecated.") uint32_t fetchSubCounter(Counter* counter, uint32_t value)
    {
        return internalFetchSubCounter(counter, value);
    }

    // Deprecated wrapper: forwards to internalStoreCounter().
    CARB_DEPRECATED("The Counter interface is deprecated.") void storeCounter(Counter* counter, uint32_t value)
    {
        internalStoreCounter(counter, value);
    }

    // ---- Scheduling / thread pinning ----

    // Yields the calling task.
    void(CARB_ABI* yield)();

    // Pins the calling task to its current thread. PinGuard treats a `true` result as "was already pinned"
    // and skips the matching unpin in that case.
    bool(CARB_ABI* pinToCurrentThread)();

    // Undoes pinToCurrentThread().
    bool(CARB_ABI* unpinFromCurrentThread)();

    // ---- Mutex ----

    // Creates a Mutex (paired with destroyMutex()). See also createRecursiveMutex().
    Mutex*(CARB_ABI* createMutex)();

    // Destroys a Mutex obtained from createMutex() / createRecursiveMutex().
    void(CARB_ABI* destroyMutex)(Mutex* mutex);

    // Locks `mutex` without a time limit; defined out of line.
    void lockMutex(Mutex* mutex);

    // Attempts to lock `mutex` within `timeoutNs` nanoseconds; presumably returns true if the lock was acquired.
    bool(CARB_ABI* timedLockMutex)(Mutex* mutex, uint64_t timeoutNs);

    // Unlocks `mutex`.
    void(CARB_ABI* unlockMutex)(Mutex* mutex);

    // ---- Sleeping and task control ----

    // Sleeps the caller for `nanoseconds` (see also sleep_for() / sleep_until()).
    void(CARB_ABI* sleepNs)(uint64_t nanoseconds);

    // Returns the TaskContext of the calling task.
    TaskContext(CARB_ABI* getTaskContext)();

    // Suspends the calling task until it is woken with wakeTask().
    bool(CARB_ABI* suspendTask)();

    // Wakes a task previously suspended with suspendTask().
    bool(CARB_ABI* wakeTask)(TaskContext task);

    // Deprecated; use wait(). Defined out of line.
    CARB_DEPRECATED("Use wait() instead") bool waitForTask(TaskContext task);

    // Low-level backend for the wait helpers: waits on `obj` for at most `timeoutNs` nanoseconds.
    bool(CARB_ABI* internalTimedWait)(Object obj, uint64_t timeoutNs);

    // Non-blocking check of `req`; presumably true if it is already satisfied. Defined out of line.
    bool try_wait(RequiredObject req);

    // Waits for `req` without a time limit. Defined out of line.
    void wait(RequiredObject req);

    // Waits for `req` for at most `dur`; presumably returns true if it was satisfied in time. Defined out of line.
    template <class Rep, class Period>
    bool wait_for(std::chrono::duration<Rep, Period> dur, RequiredObject req);

    // Waits for `req` until `when`; presumably returns true if it was satisfied in time. Defined out of line.
    template <class Clock, class Duration>
    bool wait_until(std::chrono::time_point<Clock, Duration> when, RequiredObject req);

    // ---- Semaphore ----

    // Creates a Semaphore with initial count `value` (paired with destroySemaphore()).
    Semaphore*(CARB_ABI* createSemaphore)(unsigned value);

    // Destroys a Semaphore obtained from createSemaphore().
    void(CARB_ABI* destroySemaphore)(Semaphore* sema);

    // Releases `count` units to `sema`.
    void(CARB_ABI* releaseSemaphore)(Semaphore* sema, unsigned count);

    // Acquires one unit from `sema` without a time limit; defined out of line.
    void waitSemaphore(Semaphore* sema);

    // Attempts to acquire one unit from `sema` within `timeoutNs` nanoseconds.
    bool(CARB_ABI* timedWaitSemaphore)(Semaphore* sema, uint64_t timeoutNs);

    // ---- SharedMutex (reader/writer lock) ----

    // Creates a SharedMutex (paired with destroySharedMutex()).
    SharedMutex*(CARB_ABI* createSharedMutex)();

    // Locks `mutex` in shared mode without a time limit; defined out of line.
    void lockSharedMutex(SharedMutex* mutex);

    // Attempts a shared lock of `mutex` within `timeoutNs` nanoseconds.
    bool(CARB_ABI* timedLockSharedMutex)(SharedMutex* mutex, uint64_t timeoutNs);

    // Locks `mutex` in exclusive mode without a time limit; defined out of line.
    void lockSharedMutexExclusive(SharedMutex* mutex);

    // Attempts an exclusive lock of `mutex` within `timeoutNs` nanoseconds.
    bool(CARB_ABI* timedLockSharedMutexExclusive)(SharedMutex* mutex, uint64_t timeoutNs);

    // Releases the caller's lock (shared or exclusive) on `mutex`.
    void(CARB_ABI* unlockSharedMutex)(SharedMutex* mutex);

    // Destroys a SharedMutex obtained from createSharedMutex().
    void(CARB_ABI* destroySharedMutex)(SharedMutex* mutex);

    // ---- ConditionVariable ----

    // Creates a ConditionVariable (paired with destroyConditionVariable()).
    ConditionVariable*(CARB_ABI* createConditionVariable)();

    // Destroys a ConditionVariable obtained from createConditionVariable().
    void(CARB_ABI* destroyConditionVariable)(ConditionVariable* cv);

    // Waits on `cv` using mutex `m` without a time limit; defined out of line.
    void waitConditionVariable(ConditionVariable* cv, Mutex* m);

    // Waits on `cv` using mutex `m` for at most `timeoutNs` nanoseconds; presumably returns false on timeout.
    bool(CARB_ABI* timedWaitConditionVariable)(ConditionVariable* cv, Mutex* m, uint64_t timeoutNs);

    // Wakes one waiter of `cv`.
    void(CARB_ABI* notifyConditionVariableOne)(ConditionVariable* cv);

    // Wakes all waiters of `cv`.
    void(CARB_ABI* notifyConditionVariableAll)(ConditionVariable* cv);

    // ---- Priorities and main-thread tasks (1.1) ----

    // Changes the priority of the task `ctx` to `newPrio`.
    bool(CARB_ABI* changeTaskPriority)(TaskContext ctx, Priority newPrio);

    // Executes tasks that are queued for the main thread.
    void(CARB_ABI* executeMainTasks)();

    // Intended for internal use only; only for the RequiredObject object.
    // NOTE: The Counter returned from this function is a one-shot counter that is only intended to be passed as a
    // RequiredObject. It is immediately released.
    enum GroupType
    {
        eAny,
        eAll,
    };
    Counter*(CARB_ABI* internalGroupObjects)(GroupType type, Object const* counters, size_t count);

    // Creates a recursive Mutex (1.6).
    Mutex*(CARB_ABI* createRecursiveMutex)();

    // Attempts to cancel `task`; presumably returns true only if the task was cancelled before it ran.
    bool(CARB_ABI* tryCancelTask)(TaskContext task);

    // ---- Futex backend (use futexWait*() / futexWakeup() below) ----

    // Compares `size` bytes at `addr` against `compare` and waits for at most `timeoutNs` nanoseconds.
    bool(CARB_ABI* internalFutexWait)(const void* addr, const void* compare, size_t size, uint64_t timeoutNs);

    // Wakes up to `count` waiters blocked on `addr`.
    unsigned(CARB_ABI* internalFutexWakeup)(const void* addr, unsigned count);

    // ---- Task-local storage (2.1) ----

    // Allocates a storage key. `fn` is an optional destructor for stored values (semantics TODO confirm).
    TaskStorageKey(CARB_ABI* allocTaskStorage)(TaskStorageDestructorFn fn);

    // Frees a key obtained from allocTaskStorage().
    void(CARB_ABI* freeTaskStorage)(TaskStorageKey key);

    // Stores `value` under `key` for the calling task.
    bool(CARB_ABI* setTaskStorage)(TaskStorageKey key, void* value);

    // Returns the value stored under `key` for the calling task.
    void*(CARB_ABI* getTaskStorage)(TaskStorageKey key);

    // ---- Tracking (2.2) ----

    // Do not call directly; use ScopedTracking instead.
    // Returns a special tracking object that MUST be passed to endTracking().
    Object(CARB_ABI* beginTracking)(Object const* trackers, size_t numTrackers);

    // Do not call directly; use ScopedTracking instead.
    void(CARB_ABI* endTracking)(Object tracker);

    // ---- Debugging ----

    // Fills `out` with debug information about `task`.
    bool(CARB_ABI* getTaskDebugInfo)(TaskContext task, TaskDebugInfo* out);

    // Walks tasks' debug information, invoking `fn` with `info` and `context` (iteration semantics TODO confirm).
    bool(CARB_ABI* walkTaskDebugInfo)(TaskDebugInfo& info, TaskDebugInfoFn fn, void* context);

    // ---- Remaining low-level backends ----

    // Backend of applyRangeBatch(); `batchHint` suggests a batch size.
    void(CARB_ABI* internalApplyRangeBatch)(size_t range, size_t batchHint, ApplyBatchFn fn, void* context);

    // Backend of bindTrackers().
    void(CARB_ABI* internalBindTrackers)(Object required, Object const* ptrackers, size_t numTrackers);

    // Backend of nameTask() (2.3). `dynamic` is true when `name` is not a string literal.
    void(CARB_ABI* internalNameTask)(TaskContext task, const char* name, bool dynamic);

    // Reloads fiber events (2.4).
    void(CARB_ABI* reloadFiberEvents)();

    // Helper functions

    // Deprecated-style counter wait while the caller stays pinned to its thread; defined out of line.
    void yieldUntilCounterPinThread(RequiredObject counter);

    // Waits on `cv` until `pred()` returns true. The predicate is re-checked after every wakeup, so spurious
    // wakeups are harmless. NOTE(review): as with std::condition_variable, `m` is presumably expected to be locked
    // by the caller.
    template <class Pred>
    void waitConditionVariablePred(ConditionVariable* cv, Mutex* m, Pred&& pred)
    {
        while (!pred())
        {
            this->waitConditionVariable(cv, m);
        }
    }

    // Waits on `cv` until `pred()` returns true or `timeoutNs` nanoseconds have elapsed IN TOTAL, whichever comes
    // first. Pass kInfinite to wait without a time limit.
    //
    // Returns the final value of pred(): true if the predicate is satisfied, false if the wait timed out and the
    // predicate is still false. This mirrors std::condition_variable::wait_for(lock, rel_time, pred):
    //  - the timeout bounds the whole call instead of restarting on every wakeup, so a stream of spurious or
    //    unrelated notifications cannot extend the wait indefinitely;
    //  - the predicate is checked one last time after a timeout, so a condition that became true just as the wait
    //    timed out is reported as satisfied.
    template <class Pred>
    bool timedWaitConditionVariablePred(ConditionVariable* cv, Mutex* m, uint64_t timeoutNs, Pred&& pred)
    {
        using Clock = std::chrono::steady_clock;
        const auto start = Clock::now();
        while (!pred())
        {
            uint64_t remainingNs = timeoutNs;
            if (timeoutNs != kInfinite)
            {
                // Charge the time already spent against the budget. Measuring elapsed time (instead of computing
                // start + timeoutNs) cannot overflow the clock's representation for very large finite timeouts.
                const uint64_t elapsedNs = static_cast<uint64_t>(
                    std::chrono::duration_cast<std::chrono::nanoseconds>(Clock::now() - start).count());
                remainingNs = elapsedNs < timeoutNs ? timeoutNs - elapsedNs : 0;
            }
            if (!this->timedWaitConditionVariable(cv, m, remainingNs))
                return pred(); // Timed out: report the predicate's final state.
        }
        return true;
    }

    // Runs `f(args...)` as a task at `priority` and waits for its result; defined out of line.
    template <class Callable, class... Args>
    auto awaitSyncTask(Priority priority, Callable&& f, Args&&... args);

    // Adds a task invoking `f(args...)` at `priority`, notifying `trackers`; defined out of line.
    template <class Callable, class... Args>
    auto addTask(Priority priority, Trackers&& trackers, Callable&& f, Args&&... args);

    // Deprecated C-style submission: forwards to internalAddTask().
    CARB_DEPRECATED("Use a C++ addTask() function") TaskContext addTask(TaskDesc desc, Counter* counter)
    {
        return this->internalAddTask(desc, counter);
    }

    // addTask() variant gated by the `throttler` Semaphore; defined out of line.
    template <class Callable, class... Args>
    auto addThrottledTask(Semaphore* throttler, Priority priority, Trackers&& trackers, Callable&& f, Args&&... args);

    // addTask() variant that waits for `requiredObject` before running; defined out of line.
    template <class Callable, class... Args>
    auto addSubTask(RequiredObject requiredObject, Priority priority, Trackers&& trackers, Callable&& f, Args&&... args);

    // Combination of addSubTask() and addThrottledTask(); defined out of line.
    template <class Callable, class... Args>
    auto addThrottledSubTask(RequiredObject requiredObject,
                             Semaphore* throttler,
                             Priority priority,
                             Trackers&& trackers,
                             Callable&& f,
                             Args&&... args);

    // addTask() variant that starts after the relative delay `dur`; defined out of line.
    template <class Callable, class Rep, class Period, class... Args>
    auto addTaskIn(const std::chrono::duration<Rep, Period>& dur,
                   Priority priority,
                   Trackers&& trackers,
                   Callable&& f,
                   Args&&... args);

    // addTask() variant that starts at the absolute time `when`; defined out of line.
    template <class Callable, class Clock, class Duration, class... Args>
    auto addTaskAt(const std::chrono::time_point<Clock, Duration>& when,
                   Priority priority,
                   Trackers&& trackers,
                   Callable&& f,
                   Args&&... args);

    // Invokes `f` (with `args`) across `range` in parallel; defined out of line.
    template <class Callable, class... Args>
    void applyRange(size_t range, Callable f, Args&&... args);

    // Like applyRange(), but hands `f` batches sized around `batchHint`; defined out of line.
    template <class Callable, class... Args>
    void applyRangeBatch(size_t range, size_t batchHint, Callable f, Args&&... args);

    // Parallel loop from `begin` to `end`; defined out of line.
    template <class T, class Callable, class... Args>
    void parallelFor(T begin, T end, Callable f, Args&&... args);

    // Parallel loop from `begin` to `end` in increments of `step`; defined out of line.
    template <class T, class Callable, class... Args>
    void parallelFor(T begin, T end, T step, Callable f, Args&&... args);

    // Sleeps for the relative duration `dur`, converted with detail::convertDuration() and passed to sleepNs().
    template <class Rep, class Period>
    void sleep_for(const std::chrono::duration<Rep, Period>& dur)
    {
        sleepNs(detail::convertDuration(dur));
    }

    // Sleeps until the absolute time `tp`, converted with detail::convertAbsTime() and passed to sleepNs().
    template <class Clock, class Duration>
    void sleep_until(const std::chrono::time_point<Clock, Duration>& tp)
    {
        sleepNs(detail::convertAbsTime(tp));
    }

    // Futex-style wait on `val` against `compare` (sizeof(T) bytes) with no time limit (kInfinite). The backend's
    // result is checked with CARB_ASSERT; it is otherwise unused.
    template <class T>
    void futexWait(const std::atomic<T>& val, T compare)
    {
        bool b = internalFutexWait(&val, &compare, sizeof(T), kInfinite);
        CARB_ASSERT(b);
        CARB_UNUSED(b);
    }

    // Like futexWait(), but for at most the relative duration `dur`; returns the backend's result
    // (presumably false on timeout).
    template <class T, class Rep, class Period>
    bool futexWaitFor(const std::atomic<T>& val, T compare, std::chrono::duration<Rep, Period> dur)
    {
        return internalFutexWait(&val, &compare, sizeof(T), detail::convertDuration(dur));
    }

    // Like futexWait(), but until the absolute time `when`; returns the backend's result
    // (presumably false on timeout).
    template <class T, class Clock, class Duration>
    bool futexWaitUntil(const std::atomic<T>& val, T compare, std::chrono::time_point<Clock, Duration> when)
    {
        return internalFutexWait(&val, &compare, sizeof(T), detail::convertAbsTime(when));
    }

    // Wakes up to `count` waiters blocked in futexWait*() on `val`; returns internalFutexWakeup()'s result.
    template <class T>
    unsigned futexWakeup(const std::atomic<T>& val, unsigned count)
    {
        return internalFutexWakeup(&val, count);
    }

    // Binds `trackers` to `requiredObject`; defined out of line.
    void bindTrackers(RequiredObject requiredObject, Trackers&& trackers);

    // Names `task`. String literals are passed with dynamic=false and every other const char*-convertible argument
    // with dynamic=true (presumably so the implementation copies non-literal strings - TODO confirm).
    template <class T, std::enable_if_t<std::is_convertible<T, const char*>::value, bool> = false>
    void nameTask(TaskContext task, T&& name)
    {
        internalNameTask(task, name, !detail::is_literal_string<T>::value);
    }
};

// RAII guard that pins the calling task to its current thread for the guard's lifetime.
//
// The constructor calls ITasking::pinToCurrentThread() and remembers the result. The destructor calls
// unpinFromCurrentThread() only when that result was false. The intent is that a guard nested inside an
// already-pinned region does not undo the outer pin.
// NOTE(review): this relies on pinToCurrentThread() returning true when the task was already pinned (matching the
// member name m_wasPinned) - confirm against the ITasking documentation.
//
// The guard is non-copyable (and therefore also non-movable). A copy would run the destructor twice and unpin twice,
// releasing a pin that the other guard, or an enclosing scope, still expects to hold.
class PinGuard
{
public:
    PinGuard() : m_wasPinned(carb::getCachedInterface<ITasking>()->pinToCurrentThread())
    {
    }

    // Deprecated: the ITasking argument is ignored; the cached ITasking interface is used instead.
    CARB_DEPRECATED("ITasking no longer needed.")
    PinGuard(ITasking*) : m_wasPinned(carb::getCachedInterface<ITasking>()->pinToCurrentThread())
    {
    }

    PinGuard(const PinGuard&) = delete;
    PinGuard& operator=(const PinGuard&) = delete;

    ~PinGuard()
    {
        if (!m_wasPinned)
            carb::getCachedInterface<ITasking>()->unpinFromCurrentThread();
    }

private:
    // Result of pinToCurrentThread() at construction; when true, the destructor leaves the pin in place.
    bool m_wasPinned;
};

} // namespace tasking
} // namespace carb

#include "ITasking.inl"