ThreadLocal.h

Fully qualified name: carb/thread/ThreadLocal.h

File members: carb/thread/ThreadLocal.h

// Copyright (c) 2019-2024, NVIDIA CORPORATION. All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto. Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.
//

#pragma once

#include "../Defines.h"

#include "../cpp/Bit.h"
#include "../cpp/TypeTraits.h"

#include <atomic>
#include <cstring>
#include <mutex>
#include <type_traits>
#include <utility>

#if CARB_POSIX
#    include <pthread.h>
#elif CARB_PLATFORM_WINDOWS
#    include "../CarbWindows.h"

#    include <map>
#else
CARB_UNSUPPORTED_PLATFORM();
#endif

namespace carb
{
namespace thread
{

#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace detail
{

// Destructor callback for a TLS slot: invoked with a thread's (non-null) stored value to destroy it. A null callback
// means the slot needs no destruction (the trivial ThreadLocal specialization passes nullptr).
using TlsDestructor = void (*)(void*);

// Minimal non-recursive spinlock providing lock()/unlock(), so it can be used with std::lock_guard.
// Uses test-and-test-and-set: waiters spin on a plain load and only retry the atomic exchange once the lock looks free.
struct SimpleSpinlock
{
    constexpr SimpleSpinlock() noexcept = default;

    void lock() noexcept
    {
        // The acquiring exchange is a full read/write barrier.
        while (m_flag.exchange(true, std::memory_order_seq_cst))
        {
            // While the lock is held, wait with plain loads rather than repeating the exchange: every read-modify-write
            // requires exclusive ownership of the cache line, so spinning on exchange() makes waiters bounce the line
            // between cores and slows down the owner's unlock. Retry the exchange only once the flag is seen clear.
            do
            {
                CARB_HARDWARE_PAUSE();
            } while (m_flag.load(std::memory_order_relaxed));
        }
    }
    void unlock() noexcept
    {
        // Full read/write barrier (hence the exchange)
        (void)m_flag.exchange(false, std::memory_order_seq_cst);
    }

    std::atomic_bool m_flag{ false };
};

#    if CARB_POSIX
// POSIX implementation: owns one pthread TLS key for the lifetime of the object.
class ThreadLocalBase
{
    pthread_key_t m_key;
    // True if key 0 was handed to us and deliberately kept allocated (see constructor).
    bool m_keyZero = false;

public:
    // Creates the key. @p destructor (may be null) is registered with pthread_key_create() and is therefore run for
    // each thread's non-null value when that thread exits. Failure to create a key is fatal.
    ThreadLocalBase(TlsDestructor destructor)
    {
        for (;;)
        {
            int res = pthread_key_create(&m_key, destructor);
            CARB_FATAL_UNLESS(res == 0, "pthread_key_create failed: %d/%s", res, strerror(res));

            // Even though pthread key 0 might be valid, we have seen some issues where third-party libraries use 0 as
            // an 'uninitialized' state, and may potentially close key 0 even if they don't own it. So if we were given
            // key 0, keep it reserved (so nobody else gets it) and create another key too.
            CARB_LIKELY_IF(m_key != pthread_key_t{})
            {
                break;
            }
            CARB_ASSERT(!m_keyZero);
            m_keyZero = true;
        }
    }
    ~ThreadLocalBase()
    {
        pthread_key_delete(m_key);
        if (m_keyZero)
            // See constructor above for rationale.
            pthread_key_delete(pthread_key_t{});
    }

    // Not copyable or movable
    CARB_PREVENT_COPY_AND_MOVE(ThreadLocalBase);

    // Returns the calling thread's value for this key (null if never set).
    void* get() const
    {
        return pthread_getspecific(m_key);
    }
    // Sets the calling thread's value for this key.
    void set(void* val) const
    {
        int res = pthread_setspecific(m_key, val);
        // The type of pthread_key_t is implementation-defined (it is not `unsigned int` everywhere, e.g. macOS uses
        // `unsigned long`), so widen it explicitly to match the format specifier. Assumes an integral key type.
        CARB_CHECK(res == 0, "pthread_setspecific failed with %d/%s for key %lu", res, strerror(res),
                   static_cast<unsigned long>(m_key));
    }
};
#    elif CARB_PLATFORM_WINDOWS
// Lock guarding the TLS destructor registry (ThreadLocalBase::Destructors). Because this is a header,
// __declspec(selectany) lets every translation unit define it while the linker keeps a single instance per binary.
__declspec(selectany) CARBWIN_SRWLOCK mutex = CARBWIN_SRWLOCK_INIT;
// Set (under `mutex`) once the destructor registry has been destroyed during static destruction; afterwards the
// registry's add/remove/call operations do nothing.
__declspec(selectany) bool destructed = false;

// Windows implementation: owns one TLS index obtained from TlsAlloc().
//
// Windows TLS indices have no per-index destructor, so the destructor given to the constructor is recorded in a
// binary-wide registry (Destructors) and run for the exiting thread from callDestructors(), which is installed as a
// CRT TLS callback by `pthread_thread_callback` at the end of this namespace.
class CARB_VIZ ThreadLocalBase
{
    CARB_VIZ DWORD m_key;

    // Registry mapping TLS index -> destructor. All access is serialized by the file-scope `mutex` SRW lock.
    class Destructors
    {
        using DestructorMap = std::map<DWORD, TlsDestructor>;
        DestructorMap m_map;

    public:
        Destructors() = default;
        // Runs during static destruction. Sets `destructed` so later add/remove/call calls become no-ops.
        // NOTE(review): after this runs, destructors() still returns a reference to this (destroyed) object; callers
        // rely on only the global `destructed` flag being read in that state.
        ~Destructors()
        {
            AcquireSRWLockExclusive((PSRWLOCK)&mutex);
            destructed = true;
            // Destroy the map under the lock
            DestructorMap{}.swap(m_map);
            ReleaseSRWLockExclusive((PSRWLOCK)&mutex);
        }

        // Registers @p fn as the destructor for TLS index @p slot (replacing any previous entry).
        void add(DWORD slot, TlsDestructor fn)
        {
            // If there's no destructor, don't do anything
            if (!fn)
                return;

            AcquireSRWLockExclusive((PSRWLOCK)&mutex);
            if (!destructed)
            {
                m_map[slot] = fn;
            }
            ReleaseSRWLockExclusive((PSRWLOCK)&mutex);
        }
        // Unregisters the destructor for TLS index @p slot, if any.
        void remove(DWORD slot)
        {
            AcquireSRWLockExclusive((PSRWLOCK)&mutex);
            if (!destructed)
            {
                m_map.erase(slot);
            }
            ReleaseSRWLockExclusive((PSRWLOCK)&mutex);
        }
        // Runs the registered destructors for the calling thread's non-null values.
        // NOTE(review): the shared lock is held while user destructors run. SRW locks are not recursive, so a
        // destructor that constructs or destroys a ThreadLocal on this thread (add()/remove() take the lock
        // exclusively) would deadlock -- confirm no such usage exists.
        void call()
        {
            AcquireSRWLockShared((PSRWLOCK)&mutex);

            // It is possible for atexit destructors to run before other threads call destructors with thread-storage
            // duration.
            if (destructed)
            {
                ReleaseSRWLockShared((PSRWLOCK)&mutex);
                return;
            }

            // This mimics the process of destructors with pthread_key_create which will iterate multiple (up to
            // PTHREAD_DESTRUCTOR_ITERATIONS) times, which is typically 4.
            bool again;
            int iters = 0;
            const int kMaxIters = 4;
            do
            {
                again = false;
                for (auto& pair : m_map)
                {
                    if (void* val = ::TlsGetValue(pair.first))
                    {
                        // Set to nullptr and call destructor
                        ::TlsSetValue(pair.first, nullptr);
                        pair.second(val);
                        // A destructor may have stored new values in other slots; make another pass.
                        again = true;
                    }
                }
            } while (again && ++iters < kMaxIters);
            ReleaseSRWLockShared((PSRWLOCK)&mutex);
        }
    };

    // The registry, constructed on first use.
    static Destructors& destructors()
    {
        static Destructors d;
        return d;
    }

public:
    // Allocates the TLS index (failure is fatal) and registers @p destructor for it; a null destructor is not recorded.
    ThreadLocalBase(TlsDestructor destructor)
    {
        m_key = ::TlsAlloc();
        CARB_FATAL_UNLESS(m_key != CARBWIN_TLS_OUT_OF_INDEXES, "TlsAlloc() failed: %" PRIdword "", ::GetLastError());
        destructors().add(m_key, destructor);
    }
    // Unregisters the destructor, then frees the TLS index. Values still held by other threads are not destroyed here.
    ~ThreadLocalBase()
    {
        destructors().remove(m_key);
        BOOL b = ::TlsFree(m_key);
        CARB_CHECK(!!b);
    }

    // Not copyable or movable
    CARB_PREVENT_COPY_AND_MOVE(ThreadLocalBase);

    // Returns the calling thread's value for this index (null if never set).
    void* get() const
    {
        return ::TlsGetValue(m_key);
    }
    // Sets the calling thread's value for this index.
    void set(void* val) const
    {
        BOOL b = ::TlsSetValue(m_key, val);
        CARB_CHECK(!!b);
    }

    // CRT TLS callback. Only DLL_THREAD_DETACH is handled; for any other reason (e.g. process detach) nothing is run.
    // NOTE(review): declared without __stdcall although TlsHookFunc is __stdcall; this only matches on targets where
    // the calling conventions coincide (x64) -- confirm if a 32-bit build is ever needed.
    static void callDestructors(HINSTANCE, DWORD fdwReason, PVOID)
    {
        if (fdwReason == CARBWIN_DLL_THREAD_DETACH)
        {
            // Call for current thread
            destructors().call();
        }
    }
};

extern "C"
{
    // Hook the TLS destructors in the CRT
    // see crt/src/vcruntime/tlsdtor.cpp
    // Signature of a CRT TLS callback entry.
    using TlsHookFunc = void(__stdcall*)(HINSTANCE, DWORD, PVOID);

    // Reference these so that the linker knows to include them
    // NOTE(review): they are only declared here, not otherwise used in this header; the /include pragma below is what
    // keeps pthread_thread_callback in the binary.
    extern DWORD _tls_used;
    extern TlsHookFunc __xl_a[], __xl_z[];

    // Force the linker to keep pthread_thread_callback even though nothing references it by name.
#        pragma comment(linker, "/include:pthread_thread_callback")
    // Declare the read-only section the callback pointer is allocated into.
#        pragma section(".CRT$XLD", long, read)
    // Since this is a header file, the __declspec(selectany) enables weak linking so that the linker will throw away
    // all the duplicates and leave only one instance of pthread_thread_callback in the binary.
    // This is placed into the specific binary section used for TLS destructors
    __declspec(allocate(".CRT$XLD")) __declspec(selectany)
        TlsHookFunc pthread_thread_callback = carb::thread::detail::ThreadLocalBase::callDestructors;
}
#    else
CARB_UNSUPPORTED_PLATFORM();
#    endif

} // namespace detail
#endif

//! A per-object thread-local variable: each ThreadLocal instance owns its own TLS key/index (see
//! detail::ThreadLocalBase), and every thread observes an independent value of type @p T in each instance.
//!
//! @tparam T The type of the per-thread value.
//! @tparam Trivial Selects the implementation. Defaults to `true` when @p T is trivial, trivially destructible and no
//!         larger than a pointer; that specialization stores the value's bytes directly in the TLS slot. Otherwise the
//!         `false` specialization heap-allocates each thread's value on first access and destroys it when the thread
//!         exits, when reset() is called on that thread, or when the ThreadLocal itself is destroyed.
//!
//! The primary template is intentionally empty: the `true` and `false` partial specializations below cover every
//! instantiation.
template <class T,
          bool Trivial CARB_NO_DOC(= std::conjunction<std::is_trivial<T>,
                                                      std::is_trivially_destructible<T>,
                                                      std::bool_constant<sizeof(T) <= sizeof(void*)>>::value)>
class ThreadLocal
{
};

// Specializations
// Trivial and can fit within a pointer

//! Specialization for trivial types no larger than a pointer: the value's bytes are stored directly in the TLS slot,
//! so no allocation and no per-thread destructor are involved.
template <class T>
class ThreadLocal<T, true> : private detail::ThreadLocalBase
{
public:
    //! Constructs the object. A thread that has not called set() reads back the bytes of a null slot value (all-zero
    //! on mainstream platforms).
    ThreadLocal() : ThreadLocalBase(nullptr)
    {
    }

    ~ThreadLocal() = default;

    //! Returns the calling thread's value.
    T get() const
    {
        // Cannot use bit_cast because sizeof(T) may be smaller than sizeof(void*); copy only the first sizeof(T) bytes
        // of the stored pointer, mirroring set().
        T t;
        void* p = ThreadLocalBase::get();
        std::memcpy(&t, &p, sizeof(T));
        return t;
    }

    //! Sets the calling thread's value to @p t.
    void set(T t)
    {
        void* p = nullptr; // any bytes not covered by T stay zero
        std::memcpy(&p, &t, sizeof(T));
        ThreadLocalBase::set(p);
    }

    //! Returns the calling thread's value.
    operator T()
    {
        return get();
    }

    //! Returns the calling thread's value.
    operator T() const
    {
        return get();
    }

    //! Sets the calling thread's value to @p t.
    ThreadLocal& operator=(T t)
    {
        set(t);
        return *this;
    }

    //! For pointer types: returns the calling thread's pointer for member access.
    T operator->() const
    {
        static_assert(std::is_pointer<T>::value, "Requires pointer type");
        return get();
    }

    //! For pointer types: returns a reference to the object the calling thread's pointer refers to.
    decltype(auto) operator*() const
    {
        static_assert(std::is_pointer<T>::value, "Requires pointer type");
        // decltype(auto) keeps the lvalue reference; plain `auto` would return a copy of the pointee.
        return *get();
    }

    //! Compares the calling thread's value with @p rhs.
    bool operator==(const T& rhs) const
    {
        return get() == rhs;
    }

    //! Compares the calling thread's value with @p rhs.
    bool operator!=(const T& rhs) const
    {
        return get() != rhs;
    }
};

// Non-trivial or needs more than a pointer (uses the heap)

//! Specialization for types that are non-trivial, not trivially destructible, or larger than a pointer.
//!
//! Each thread's value lives in a heap-allocated node that is value-initialized on that thread's first access and
//! linked into a per-object list. A node is destroyed when its thread exits (via the TLS destructor registered with
//! detail::ThreadLocalBase), when reset() is called on its thread, or when this object is destroyed.
template <class T>
class ThreadLocal<T, false> : private detail::ThreadLocalBase
{
public:
    //! Constructs the object; no per-thread value exists until a thread first accesses it.
    ThreadLocal() : ThreadLocalBase(destructor), m_list(_end())
    {
    }

    //! Destroys every thread's value that still exists.
    //! NOTE(review): a thread exiting concurrently would run destructor() on a node this loop may also delete; callers
    //! must ensure no thread uses (or exits holding a value in) this object during destruction.
    ~ThreadLocal()
    {
        // NOTE: we do not need to hold spinlock() here because it is UB to call functions in an object from other
        // threads while it is being destructed.

        // Delete all instances for threads created by this object
        Wrapper* p = m_list.next;
        m_list.next = m_list.prev = _end();
        for (Wrapper* next; p != _end(); p = next)
        {
            next = p->next;
            delete p;
        }
        // It would be very bad if a thread was using this while we're destroying it
        CARB_ASSERT(m_list.next == _end() && m_list.prev == _end());
    }

    //! Returns the calling thread's value, or nullptr if it has none yet. Never allocates.
    T* get_if()
    {
        return _get_if();
    }

    //! Returns the calling thread's value, or nullptr if it has none yet. Never allocates.
    const T* get_if() const
    {
        return _get_if();
    }

    //! Returns the calling thread's value, creating a value-initialized one on first access (may throw).
    T& get()
    {
        return *_get();
    }

    //! Returns the calling thread's value, creating a value-initialized one on first access (may throw).
    const T& get() const
    {
        return *_get();
    }

    //! Copy-assigns @p t to the calling thread's value (creating it first if needed).
    void set(const T& t)
    {
        *_get() = t;
    }

    //! Move-assigns @p t to the calling thread's value (creating it first if needed).
    void set(T&& t)
    {
        *_get() = std::move(t);
    }

    //! Destroys the calling thread's value, if any; the next access creates a fresh one.
    void reset()
    {
        if (auto w = static_cast<Wrapper*>(ThreadLocalBase::get()))
        {
            // Clear the slot *before* destroying the value (the same order used when thread-exit destructors run) so
            // anything ~T() does with this object on this thread sees an empty slot rather than a dangling pointer.
            ThreadLocalBase::set(nullptr);
            destructor(w);
        }
    }

    //! Returns a copy of the calling thread's value.
    operator T()
    {
        return get();
    }

    //! Returns a copy of the calling thread's value.
    operator T() const
    {
        return get();
    }

    //! Copy-assigns @p rhs to the calling thread's value.
    ThreadLocal& operator=(const T& rhs)
    {
        set(rhs);
        return *this;
    }

    //! Move-assigns @p rhs to the calling thread's value.
    ThreadLocal& operator=(T&& rhs)
    {
        set(std::move(rhs));
        return *this;
    }

    //! Forwards to the calling thread's value's operator->() (for pointer-like T).
    auto operator->()
    {
        return get().operator->();
    }
    //! Forwards to the calling thread's value's operator->() (for pointer-like T).
    auto operator->() const
    {
        return get().operator->();
    }
    //! Forwards to the calling thread's value's operator*(), preserving its reference return type.
    decltype(auto) operator*()
    {
        return get().operator*();
    }
    //! Forwards to the calling thread's value's operator*(), preserving its reference return type.
    decltype(auto) operator*() const
    {
        return get().operator*();
    }
    //! Forwards to the calling thread's value's operator[], preserving its reference return type.
    template <class U>
    decltype(auto) operator[](const U& u)
    {
        return get().operator[](u);
    }
    //! Forwards to the calling thread's value's operator[] const, preserving its reference return type.
    template <class U>
    decltype(auto) operator[](const U& u) const
    {
        return get().operator[](u);
    }

    //! Compares the calling thread's value with @p rhs.
    bool operator==(const T& rhs) const
    {
        return get() == rhs;
    }

    //! Compares the calling thread's value with @p rhs.
    bool operator!=(const T& rhs) const
    {
        return get() != rhs;
    }

private:
    struct Wrapper;
    // Intrusive doubly-linked list links; m_list below is the sentinel node.
    struct ListNode
    {
        Wrapper* next = nullptr;
        Wrapper* prev = nullptr;
        constexpr ListNode() noexcept = default;
        constexpr ListNode(Wrapper* init) noexcept : next(init), prev(init)
        {
        }
    };
    // One heap node per thread that has a value.
    struct Wrapper : public ListNode
    {
        T t;
    };

    Wrapper* _tail() const
    {
        return const_cast<Wrapper*>(m_list.prev);
    }
    // The sentinel node viewed as a Wrapper*; only its ListNode links are ever accessed.
    Wrapper* _end() const
    {
        return cpp::bit_cast<Wrapper*>(const_cast<ListNode*>(&m_list));
    }

    // TLS destructor: unlinks the node and deletes it (running ~T()).
    static void destructor(void* p)
    {
        auto w = static_cast<Wrapper*>(p);
        {
            // Remove from the list
            std::lock_guard<detail::SimpleSpinlock> g(spinlock());
            w->next->prev = w->prev;
            w->prev->next = w->next;
        }
        delete w;
    }

    T* _get_if() const
    {
        auto w = static_cast<Wrapper*>(ThreadLocalBase::get());
        return w ? std::addressof(w->t) : nullptr;
    }

    T* _get() const
    {
        T* p = _get_if();
        return p ? p : _create();
    }
    // Creates the calling thread's node, links it at the list tail, and stores it in the TLS slot.
    T* _create() const
    {
        // Value-initialize: `new Wrapper` would leave a trivial-but-large T (e.g. an array of ints) indeterminate, and
        // the first get() would read uninitialized memory. Non-trivial T still has its default constructor run.
        Wrapper* w = new Wrapper(); // may throw
        // Add to end of list
        {
            std::lock_guard<detail::SimpleSpinlock> g(spinlock());
            w->next = _end();
            w->prev = _tail();
            w->next->prev = w;
            w->prev->next = w;
        }
        ThreadLocalBase::set(w);
        return std::addressof(w->t);
    }

    // A single lock shared by every ThreadLocal<T, false> object with the same T; guards the list links.
    static detail::SimpleSpinlock& spinlock() noexcept
    {
        static detail::SimpleSpinlock s_spinlock;
        return s_spinlock;
    }

    ListNode m_list;
};

} // namespace thread
} // namespace carb