NativeFutex.h

Fully qualified name: carb/thread/detail/NativeFutex.h

File members: carb/thread/detail/NativeFutex.h

// Copyright (c) 2020-2024, NVIDIA CORPORATION. All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto. Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.
//

#pragma once

#include "../../Defines.h"
#include "../../thread/Util.h"

#include <atomic>

#if CARB_PLATFORM_WINDOWS
#    pragma comment(lib, "synchronization.lib") // must link with synchronization.lib
#    include "../../CarbWindows.h"
#elif CARB_PLATFORM_LINUX
#    include <linux/futex.h>
#    include <sys/syscall.h>
#    include <sys/time.h>

#    include <unistd.h>
#else
CARB_UNSUPPORTED_PLATFORM();
#endif

namespace carb
{
namespace thread
{
namespace detail
{

// Supplies the nested `type` member for the size specializations of to_integral below.
template <class I>
struct to_integral_holder
{
    using type = I;
};

// Maps a type T to the signed fixed-width integer of the same size (1, 2, 4 or 8 bytes).
// For any other size the primary template provides no `type`, so use with an unsupported size fails to compile.
template <class T, size_t S = sizeof(T)>
struct to_integral
{
};

template <class T>
struct to_integral<T, 8> : to_integral_holder<int64_t>
{
};

template <class T>
struct to_integral<T, 4> : to_integral_holder<int32_t>
{
};

template <class T>
struct to_integral<T, 2> : to_integral_holder<int16_t>
{
};

template <class T>
struct to_integral<T, 1> : to_integral_holder<int8_t>
{
};

// Convenience alias: the same-sized signed integer type for T.
template <class T>
using to_integral_t = typename to_integral<T>::type;

// Integral -> integral of identical size: a plain static_cast to the requested integral type.
template <class As, class T>
CARB_NODISCARD std::enable_if_t<std::is_integral<T>::value && sizeof(As) == sizeof(T), As> reinterpret_as(const T& in) noexcept
{
    static_assert(std::is_integral<As>::value, "Must be integral type");
    const As converted = static_cast<As>(in);
    return converted;
}

// Pointer -> integral of identical size: yields the pointer's value via reinterpret_cast.
template <class As, class T>
CARB_NODISCARD std::enable_if_t<std::is_pointer<T>::value && sizeof(As) == sizeof(T), As> reinterpret_as(const T& in) noexcept
{
    static_assert(std::is_integral<As>::value, "Must be integral type");
    const As bits = reinterpret_cast<As>(in);
    return bits;
}

// Fallback overload, selected for every T that is neither integral nor a pointer (e.g. floating-point or enum types)
// and for any size mismatch between As and T. Copies the object representation of `in` into a zero-initialized `As`.
// `in` must not be larger than `As`: copying sizeof(in) bytes into a smaller `As` would write past the end of `out`,
// so that case is rejected at compile time.
template <class As, class T>
CARB_NODISCARD std::enable_if_t<(!std::is_pointer<T>::value && !std::is_integral<T>::value) || sizeof(As) != sizeof(T), As> reinterpret_as(
    const T& in) noexcept
{
    static_assert(std::is_integral<As>::value, "Must be integral type");
    static_assert(sizeof(T) <= sizeof(As), "Source type is larger than the destination integral type");
    As out{}; // Init to zero so any bytes of `out` not covered by `in` are well-defined
    memcpy(&out, std::addressof(in), sizeof(in));
    return out;
}

template <class T>
bool futex_compare(const std::atomic<T>& val, T compare, std::memory_order order = std::memory_order_seq_cst) noexcept
{
    using I = to_integral_t<T>;
    return reinterpret_as<I>(val.load(order)) == reinterpret_as<I>(compare);
}

#if CARB_PLATFORM_WINDOWS
// Duration in 100-nanosecond ticks: the timeout unit used by the ntdll wait functions in this file.
using hundrednanos = std::chrono::duration<int64_t, std::ratio<1, 10'000'000>>;

// Signature of ntdll's RtlWaitOnAddress: (address to wait on, pointer to comparison value, value size in bytes,
// timeout). The timeout is in 100ns units; callers in this file pass nullptr to wait indefinitely and negative
// values for relative timeouts.
using RtlWaitOnAddressFn = NTSTATUS(__stdcall*)(volatile const void*, const void*, size_t, int64_t*);
// Looked up from ntdll.dll via GetModuleHandleW/GetProcAddress when this namespace-scope variable is initialized.
// NOTE(review): GetProcAddress yields null if the export is missing, and callers do not check for that.
CARB_WEAKLINK auto RtlWaitOnAddress =
    (RtlWaitOnAddressFn)GetProcAddress(GetModuleHandleW(L"ntdll.dll"), "RtlWaitOnAddress");

inline bool WaitOnAddress(volatile const void* val, const void* compare, size_t size, int64_t* timeout) noexcept
{
    // Use the NTDLL version of this function since we can give it relative or absolute times in 100ns units
    switch (NTSTATUS ret = RtlWaitOnAddress(val, compare, size, timeout))
    {
        case CARBWIN_STATUS_SUCCESS:
            return true;

        default:
            CARB_FATAL_UNLESS(
                0, "Unexpected result from RtlWaitOnAddress: 0x%lx, GetLastError=%" PRIdword, ret, ::GetLastError());
            CARB_FALLTHROUGH; // (not really, but the compiler doesn't know that CARB_FATAL_UNLESS doesn't return)
        case CARBWIN_STATUS_TIMEOUT:
            return false;
    }
}

// Typed convenience overload: waits on the storage of `val` while it holds `compare`.
// Relies on std::atomic<T> having the same size as T so the raw bytes can be compared directly.
template <class T>
inline bool WaitOnAddress(const std::atomic<T>& val, T compare, int64_t* timeout) noexcept
{
    static_assert(sizeof(val) == sizeof(compare), "Invalid assumption about atomic");
    const void* const expected = std::addressof(compare);
    return WaitOnAddress(std::addressof(val), expected, sizeof(T), timeout);
}

// Use undocumented API unlikely to change since so many things are based on it.
// https://ntdoc.m417z.com/ntwaitforalertbythreadid
// https://dennisbabkin.com/blog/?t=how-to-put-thread-into-kernel-wait-and-to-wake-it-by-thread-id
//
// NtWaitForAlertByThreadId parks the calling thread until it is alerted by thread ID or the timeout (100ns units)
// expires; NtAlertThreadByThreadId performs that alert. Both are looked up from ntdll.dll when these namespace-scope
// variables are initialized.
using NtWaitForAlertByThreadIdFn = NTSTATUS(__stdcall*)(const void*, int64_t*);
CARB_WEAKLINK auto NtWaitForAlertByThreadId =
    (NtWaitForAlertByThreadIdFn)GetProcAddress(GetModuleHandleW(L"ntdll.dll"), "NtWaitForAlertByThreadId");
// The target thread ID is passed in a HANDLE-typed parameter (AlertThreadByThreadId casts the ThreadId to HANDLE).
using NtAlertThreadByThreadIdFn = NTSTATUS(__stdcall*)(HANDLE);
CARB_WEAKLINK auto NtAlertThreadByThreadId =
    (NtAlertThreadByThreadIdFn)GetProcAddress(GetModuleHandleW(L"ntdll.dll"), "NtAlertThreadByThreadId");

inline bool WaitForAlertByThreadId(const void* addr, int64_t* timeout) noexcept
{
    // OVCC-1549: Do not use thread-safe static initialization for looking up NtWaitForAlertByThreadId! See Jira/MR for
    // more info.
    switch (NTSTATUS ret = NtWaitForAlertByThreadId(addr, timeout))
    {
        case CARBWIN_STATUS_SUCCESS: // not received in practice
        case CARBWIN_STATUS_ALERTED:
            return true;

        default:
            CARB_FATAL_UNLESS(0, "Unexpected result from NtWaitForAlertByThreadId: 0x%lx, GetLastError=%" PRIdword, ret,
                              ::GetLastError());
            CARB_FALLTHROUGH; // (not really, but the compiler doesn't know that CARB_FATAL_UNLESS doesn't return)
        case CARBWIN_STATUS_TIMEOUT:
            return false;
    }
}

// Wakes the thread identified by `threadId` if it is parked in WaitForAlertByThreadId().
inline void AlertThreadByThreadId(ThreadId threadId) noexcept
{
    // OVCC-1549: Do not use thread-safe static initialization for looking up NtAlertThreadByThreadId! See Jira/MR for
    // more info.
    // The ntdll entry point takes the thread ID in a HANDLE-typed parameter.
    const HANDLE target = reinterpret_cast<HANDLE>(static_cast<UINT_PTR>(threadId));
    const NTSTATUS status = NtAlertThreadByThreadId(target);
    // Win32 APIs typically ignore this return value
    CARB_ASSERT(status == CARBWIN_STATUS_SUCCESS, "NtAlertThreadByThreadId gave unexpected result: %lx", status);
    CARB_UNUSED(status);
}

template <class T>
inline void threadId_wait(const std::atomic<T>& val) noexcept
{
    WaitForAlertByThreadId(std::addressof(val), nullptr);
}

// Parks the calling thread until it is alerted or `time_point` passes. Returns true if alerted, false on timeout.
template <class T, class Clock, class Duration>
inline bool threadId_wait_until(const std::atomic<T>& val, std::chrono::time_point<Clock, Duration> time_point)
{
    // Monotonic clocks are preferred over the system clock, which is subject to time correction and can jump
    // backwards or forwards. Positive (absolute) timeouts for WaitForAlertByThreadId are based on the system clock,
    // so only the relative (negative) form is ever used, recomputed against steady_clock on every pass.
    const auto deadline = cpp::detail::convertToClock<std::chrono::steady_clock>(time_point);

    auto current = std::chrono::steady_clock::now();
    for (;;)
    {
        // WaitForAlertByThreadId is always called at least once, even if the deadline has already passed.
        int64_t timeout =
            -std::chrono::duration_cast<detail::hundrednanos>(cpp::detail::clampDuration(deadline - current)).count();
        if (WaitForAlertByThreadId(std::addressof(val), &timeout))
        {
            return true;
        }

        current = std::chrono::steady_clock::now();
        if (current > deadline)
        {
            return false;
        }
    }
}

// Parks the calling thread until it is alerted or `duration` elapses. Returns true if alerted, false on timeout.
template <class T, class Rep, class Period>
inline bool threadId_wait_for(const std::atomic<T>& val, std::chrono::duration<Rep, Period> duration)
{
    // Relative WaitForAlertByThreadId timeouts are not accurate and often wait less than requested, so convert to an
    // absolute steady_clock deadline and let threadId_wait_until() re-wait as needed.
    const auto deadline = cpp::detail::absTime<std::chrono::steady_clock>(duration);
    return threadId_wait_until(val, deadline);
}

// Alerts the thread `id` out of threadId_wait*(); thin forwarder to AlertThreadByThreadId().
inline void threadId_wake(ThreadId id) noexcept
{
    return AlertThreadByThreadId(id);
}

// Blocks while `val` holds `compare`, with no timeout (nullptr). The boolean result is discarded because an
// untimed wait cannot report a timeout.
template <class T>
inline void futex_wait(const std::atomic<T>& val, T compare) noexcept
{
    int64_t* const noTimeout = nullptr;
    (void)WaitOnAddress(val, compare, noTimeout);
}

// Waits while `val` holds `compare`, until woken or `time_point` is reached.
// Returns true when RtlWaitOnAddress reports a wake or, once the deadline has passed, when `val` no longer equals
// `compare`; returns false if the deadline passed with the value unchanged.
template <class T, class Clock, class Duration>
inline bool futex_wait_until(const std::atomic<T>& val, T compare, std::chrono::time_point<Clock, Duration> time_point)
{
    // We want to use monotonic clocks instead of the system clock which is subject to time correction and can jump
    // backwards or forwards. Positive (absolute) timeouts for RtlWaitOnAddress are based on the system clock, so always
    // use the relative (negative) form, recomputed from steady_clock on each iteration.
    auto const steadyClockTarget = cpp::detail::convertToClock<std::chrono::steady_clock>(time_point);
    for (;;)
    {
        const auto now = std::chrono::steady_clock::now();
        if (steadyClockTarget <= now)
        {
            // Timeout expired, do immediate check and return. WaitOnAddress is quite slow if the timeout is zero.
            return !futex_compare(val, compare);
        }

        int64_t relTime =
            -std::chrono::duration_cast<detail::hundrednanos>(cpp::detail::clampDuration(steadyClockTarget - now)).count();
        CARB_ASSERT(relTime < 0);
        if (detail::WaitOnAddress(val, compare, &relTime))
            return true;
        // Timed out: loop to re-check the steady_clock deadline, since the wait may end before it.
    }
}

// Waits while `val` holds `compare` for at most `duration`. Returns false on timeout.
// A non-positive duration does not wait; it only reports whether `val` already differs from `compare`.
template <class T, class Rep, class Period>
inline bool futex_wait_for(const std::atomic<T>& val, T compare, std::chrono::duration<Rep, Period> duration)
{
    if (duration.count() > 0)
    {
        const auto deadline = cpp::detail::absTime<std::chrono::steady_clock>(duration);
        return futex_wait_until(val, compare, deadline);
    }
    return !futex_compare(val, compare);
}

template <class T>
inline void futex_wake_one(std::atomic<T>& val) noexcept
{
    WakeByAddressSingle(std::addressof(val));
}

// Wakes up to `n` threads waiting on `val` by issuing `n` single wakes (only single/all wake primitives are used in
// this file).
template <class T>
inline void futex_wake_n(std::atomic<T>& val, size_t n) noexcept
{
    for (size_t issued = 0; issued < n; ++issued)
    {
        futex_wake_one(val);
    }
}

template <class T>
inline void futex_wake_all(std::atomic<T>& val) noexcept
{
    WakeByAddressAll(std::addressof(val));
}

// Windows-specific futex
//
// Static policy wrapper exposing the futex_* free functions above for any T whose size is 1, 2, 4 or 8 bytes.
template <class T, size_t S = sizeof(T)>
class WindowsFutex
{
    static_assert(S == 1 || S == 2 || S == 4 || S == 8, "Unsupported size");

public:
    using AtomicType = std::atomic<T>;
    using Type = T;

    // Blocks while `val` holds `compare`; see futex_wait().
    static void wait(const AtomicType& val, Type compare) noexcept
    {
        futex_wait(val, compare);
    }

    // Waits for at most `duration`; returns false on timeout. See futex_wait_for().
    template <class Rep, class Period>
    static bool wait_for(const AtomicType& val, Type compare, std::chrono::duration<Rep, Period> duration)
    {
        return futex_wait_for(val, compare, duration);
    }

    // Waits until `time_point`; returns false on timeout. See futex_wait_until().
    template <class Clock, class Duration>
    static bool wait_until(const AtomicType& val, Type compare, std::chrono::time_point<Clock, Duration> time_point)
    {
        return futex_wait_until(val, compare, time_point);
    }

    // Wakes at most one waiter.
    static void notify_one(AtomicType& a) noexcept
    {
        futex_wake_one(a);
    }

    // Wakes up to `n` waiters.
    static void notify_n(AtomicType& a, size_t n) noexcept
    {
        futex_wake_n(a, n);
    }

    // Wakes all waiters.
    static void notify_all(AtomicType& a) noexcept
    {
        futex_wake_all(a);
    }
};
#elif CARB_PLATFORM_LINUX
// Nanoseconds per second, used to split nanosecond counts into timespec fields.
constexpr int64_t kNsPerSec = 1'000'000'000;

// Thin wrapper over the raw futex(2) system call operating on the 32-bit word inside `aval`.
// Returns the non-negative syscall result on success, or the negated errno value on failure.
inline int futex(const std::atomic_uint32_t& aval,
                 int futex_op,
                 uint32_t val,
                 const struct timespec* timeout,
                 uint32_t* uaddr2,
                 unsigned int val3) noexcept
{
    static_assert(sizeof(aval) == sizeof(uint32_t), "Invalid assumption about atomic");
    const long result = syscall(SYS_futex, std::addressof(aval), futex_op, val, timeout, uaddr2, val3);
    if (result < 0)
    {
        return -errno;
    }
    return static_cast<int>(result);
}

// Blocks while `val` holds `compare`, with no timeout. Returns on a wakeup (valid or spurious) or when the kernel
// reports that `val` no longer matches `compare`.
inline void futex_wait(const std::atomic_uint32_t& val, uint32_t compare) noexcept
{
    for (;;)
    {
        const int rc = futex(val, FUTEX_WAIT_BITSET_PRIVATE, compare, nullptr, nullptr, FUTEX_BITSET_MATCH_ANY);
        if (rc == 0 || rc == -EAGAIN)
        {
            return; // Valid or spurious wakeup
        }

        // -EINTR: interrupted by a signal. -ETIMEDOUT: apparently on Windows Subsystem for Linux, calls to the kernel
        // can time out even when a timeout value was not specified. Both simply wait again.
        if (rc != -EINTR && rc != -ETIMEDOUT)
        {
            CARB_FATAL_UNLESS(0, "Unexpected result from futex(): %d/%s", -rc, strerror(-rc));
        }
    }
}

template <class Rep, class Period>
inline bool futex_wait_for(const std::atomic_uint32_t& val, uint32_t compare, std::chrono::duration<Rep, Period> duration)
{
    // Relative time
    int64_t ns = std::chrono::duration_cast<std::chrono::nanoseconds>(cpp::detail::clampDuration(duration)).count();
    if (ns <= 0)
    {
        return false;
    }

    struct timespec ts;
    ts.tv_sec = time_t(ns / detail::kNsPerSec);
    ts.tv_nsec = long(ns % detail::kNsPerSec);

    // Since we're using relative time here, we can use FUTEX_WAIT_PRIVATE (see futex() man page)
    int ret = futex(val, FUTEX_WAIT_PRIVATE, compare, &ts, nullptr, 0);
    switch (ret)
    {
        case 0: // Valid wakeup
        case -EAGAIN: // Valid or spurious wakeup
        case -EINTR: // Interrupted by signal; treat as a spurious wakeup
            return true;

        default:
            CARB_FATAL_UNLESS(0, "Unexpected result from futex(): %d/%s", -ret, strerror(-ret));
            CARB_FALLTHROUGH; // (not really but the compiler doesn't know that the above won't return)
        case -ETIMEDOUT:
            return false;
    }
}

template <class Clock, class Duration>
inline bool futex_wait_until(const std::atomic_uint32_t& val,
                             uint32_t compare,
                             const std::chrono::time_point<Clock, Duration>& time_point)
{
    struct timespec ts;
    clock_gettime(CLOCK_MONOTONIC, &ts);

    auto const duration = time_point - Clock::now();

    // Get the number of nanoseconds to go (constrained)
    int64_t ns = std::chrono::duration_cast<std::chrono::nanoseconds>(cpp::detail::clampDuration(duration)).count();

    if (ns <= 0)
    {
        return false;
    }

    ts.tv_sec += time_t(ns / kNsPerSec);
    ts.tv_nsec += long(ns % kNsPerSec);

    // Handle rollover
    if (ts.tv_nsec >= kNsPerSec)
    {
        ++ts.tv_sec;
        ts.tv_nsec -= kNsPerSec;
    }

    for (;;)
    {
        // Since we're using absolute monotonic time, we use FUTEX_WAIT_BITSET_PRIVATE. See the man page for futex for
        // more info.
        int ret = futex(val, FUTEX_WAIT_BITSET_PRIVATE, compare, &ts, nullptr, FUTEX_BITSET_MATCH_ANY);
        switch (ret)
        {
            case 0: // Valid wakeup
            case -EAGAIN: // Valid or spurious wakeup
                return true;

            case -EINTR: // Interrupted by signal; loop again
                break;

            default:
                CARB_FATAL_UNLESS(0, "Unexpected result from futex(): %d/%s", -ret, strerror(-ret));
                CARB_FALLTHROUGH; // (not really but the compiler doesn't know that the above won't return)
            case -ETIMEDOUT:
                return false;
        }
    }
}

// Wakes up to `count` threads waiting on `val`.
// The kernel reads the wake count as a signed int and stops once the number woken reaches it, checked after each
// wake. A count of 0, or anything above INT_MAX (which becomes negative), would therefore still wake one waiter.
// Handle both here: 0 wakes nothing, and larger counts are clamped to INT_MAX ("wake all", as futex_wake_all() uses).
inline void futex_wake_n(std::atomic_uint32_t& val, unsigned count) noexcept
{
    if (count == 0)
    {
        return;
    }
    const unsigned wakeCount = count > unsigned(INT_MAX) ? unsigned(INT_MAX) : count;
    int ret = futex(val, FUTEX_WAKE_BITSET_PRIVATE, wakeCount, nullptr, nullptr, FUTEX_BITSET_MATCH_ANY);
    CARB_ASSERT(ret >= 0, "futex(FUTEX_WAKE) failed with errno=%d/%s", -ret, strerror(-ret));
    CARB_UNUSED(ret);
}

// Wakes at most one thread waiting on `val`.
inline void futex_wake_one(std::atomic_uint32_t& val) noexcept
{
    constexpr unsigned kSingleWaiter = 1u;
    futex_wake_n(val, kSingleWaiter);
}

// Wakes every thread waiting on `val` (INT_MAX is the conventional "all" count for FUTEX_WAKE).
inline void futex_wake_all(std::atomic_uint32_t& val) noexcept
{
    const unsigned everyone = unsigned(INT_MAX);
    futex_wake_n(val, everyone);
}
#endif

// Futex policy over a 32-bit atomic word that forwards to the platform's futex_* free functions
// (RtlWaitOnAddress/WakeByAddress* on Windows, the futex syscall on Linux).
class NativeFutex
{
public:
    using AtomicType = std::atomic_uint32_t;
    using Type = uint32_t;

    // Blocks while `val` holds `compare`.
    static void wait(const AtomicType& val, Type compare) noexcept
    {
        futex_wait(val, compare);
    }

    // Waits for at most `duration`; returns false on timeout.
    template <class Rep, class Period>
    static bool wait_for(const AtomicType& val, Type compare, std::chrono::duration<Rep, Period> duration)
    {
        return futex_wait_for(val, compare, duration);
    }

    // Waits until `time_point`; returns false on timeout.
    template <class Clock, class Duration>
    static bool wait_until(const AtomicType& val, Type compare, std::chrono::time_point<Clock, Duration> time_point)
    {
        return futex_wait_until(val, compare, time_point);
    }

    // Wakes at most one waiter.
    static void notify_one(AtomicType& a) noexcept
    {
        futex_wake_one(a);
    }

    // Wakes up to `n` waiters; `n` is saturated so that it fits in an `unsigned`.
    static void notify_n(AtomicType& a, size_t n) noexcept
    {
        const auto count = unsigned(carb_min<size_t>(UINT_MAX, n));
        futex_wake_n(a, count);
    }

    // Wakes all waiters.
    static void notify_all(AtomicType& a) noexcept
    {
        futex_wake_all(a);
    }
};

} // namespace detail
} // namespace thread
} // namespace carb