carb/thread/ThreadLocal.h

File members: carb/thread/ThreadLocal.h

// Copyright (c) 2019-2024, NVIDIA CORPORATION. All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto. Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.
//

#pragma once

#include "../Defines.h"

#include "../cpp/detail/ImplDummy.h"
#include "../cpp/TypeTraits.h"

#include <atomic>
#include <cstring>
#include <memory>
#include <mutex>
#include <type_traits>
#include <utility>

#if CARB_POSIX
#    include <pthread.h>
#elif CARB_PLATFORM_WINDOWS
#    include "../CarbWindows.h"
#    include "SharedMutex.h"

#    include <map>
#else
CARB_UNSUPPORTED_PLATFORM();
#endif

namespace carb
{
namespace thread
{

#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace detail
{

// Per-thread cleanup callback registered with a slot; it is invoked with a thread's non-null slot value when that
// thread exits. A null TlsDestructor means "no cleanup".
using TlsDestructor = void (*)(void*);

// Returns the mutex that guards the per-object value lists of ThreadLocal. The mutex is a function-local static, so it
// is created on first call and the same instance is returned on every subsequent call.
inline std::mutex& tlsMutex()
{
    static std::mutex s_listMutex;
    return s_listMutex;
}

#    if CARB_POSIX
// POSIX implementation of a thread-local slot: owns one pthread key. The `destructor` passed to the constructor (may be
// null) is registered with pthread_key_create(), so pthreads calls it with a thread's non-null value when that thread
// exits.
class ThreadLocalBase
{
    pthread_key_t m_key;

public:
    // Creates the key; failure is fatal.
    explicit ThreadLocalBase(TlsDestructor destructor)
    {
        int res = pthread_key_create(&m_key, destructor);
        CARB_FATAL_UNLESS(res == 0, "pthread_key_create failed: %d/%s", res, strerror(res));
    }
    // Deletes the key. Per POSIX, pthread_key_delete() runs no destructors, so values still held by live threads must
    // be released by the owner (ThreadLocal<T, false> frees them in its destructor).
    ~ThreadLocalBase()
    {
        pthread_key_delete(m_key);
    }

    // Not copyable or movable
    CARB_PREVENT_COPY_AND_MOVE(ThreadLocalBase);

    // Returns the calling thread's value for this key (null if never set).
    void* get() const
    {
        return pthread_getspecific(m_key);
    }
    // Stores `val` as the calling thread's value for this key.
    void set(void* val) const
    {
        int res = pthread_setspecific(m_key, val);
        // pthread_key_t is `unsigned int` on Linux but `unsigned long` on macOS; widen explicitly so the format
        // specifier matches the argument on every POSIX platform.
        CARB_CHECK(res == 0, "pthread_setspecific failed with %d/%s for key %lu", res, strerror(res),
                   static_cast<unsigned long>(m_key));
    }
};
#    elif CARB_PLATFORM_WINDOWS
// SRW lock guarding ThreadLocalBase::Destructors (the TLS index -> destructor map). It is statically initialized, and
// __declspec(selectany) lets the linker keep a single instance per binary even though this is defined in a header.
__declspec(selectany) CARBWIN_SRWLOCK mutex = CARBWIN_SRWLOCK_INIT;
// Set to true (under `mutex`) once the destructor map has been destroyed at static teardown; threads that detach
// afterwards check it and skip the map.
__declspec(selectany) bool destructed = false;

// Windows implementation of a thread-local slot. TLS indices from TlsAlloc() have no destructor hook, so each index's
// destructor is kept in a shared map and run by callDestructors(), which is installed as a TLS callback
// (pthread_thread_callback) and acts only on DLL_THREAD_DETACH.
class CARB_VIZ ThreadLocalBase
{
    // TLS index returned by TlsAlloc().
    CARB_VIZ DWORD m_key;

    // Registry of TLS index -> destructor. Every access takes the SRW lock `mutex`; only non-null destructors are
    // recorded.
    class Destructors
    {
        using DestructorMap = std::map<DWORD, TlsDestructor>;
        DestructorMap m_map;

    public:
        Destructors() = default;
        // Runs when the function-local static in destructors() is torn down. Marks the registry as gone so that
        // threads detaching afterwards skip it in call().
        ~Destructors()
        {
            AcquireSRWLockExclusive((PSRWLOCK)&mutex);
            destructed = true;
            // Destroy the map under the lock
            DestructorMap{}.swap(m_map);
            ReleaseSRWLockExclusive((PSRWLOCK)&mutex);
        }

        // Records `fn` as the destructor for `slot`.
        void add(DWORD slot, TlsDestructor fn)
        {
            // If there's no destructor, don't do anything
            if (!fn)
                return;

            AcquireSRWLockExclusive((PSRWLOCK)&mutex);
            m_map[slot] = fn;
            ReleaseSRWLockExclusive((PSRWLOCK)&mutex);
        }
        // Forgets the destructor for `slot` (no-op if none was recorded). Always takes the lock exclusively.
        void remove(DWORD slot)
        {
            AcquireSRWLockExclusive((PSRWLOCK)&mutex);
            m_map.erase(slot);
            ReleaseSRWLockExclusive((PSRWLOCK)&mutex);
        }
        // Runs the recorded destructors for the calling thread's non-null values. Each slot is cleared before its
        // destructor is called. Passes repeat while the previous pass found any value (a destructor may store a new
        // one), up to kMaxIters passes.
        // NOTE(review): destructors run while this thread holds `mutex` shared. SRW locks cannot be acquired
        // recursively, so a destructor that destroys a ThreadLocal, or constructs one with a destructor (remove()/add()
        // lock exclusively), would deadlock this thread -- confirm no stored type does that.
        void call()
        {
            AcquireSRWLockShared((PSRWLOCK)&mutex);

            // It is possible for atexit destructors to run before other threads call destructors with thread-storage
            // duration.
            if (destructed)
            {
                ReleaseSRWLockShared((PSRWLOCK)&mutex);
                return;
            }

            // This mimics the process of destructors with pthread_key_create which will iterate multiple (up to
            // PTHREAD_DESTRUCTOR_ITERATIONS) times, which is typically 4.
            bool again;
            int iters = 0;
            const int kMaxIters = 4;
            do
            {
                again = false;
                for (auto& pair : m_map)
                {
                    if (void* val = ::TlsGetValue(pair.first))
                    {
                        // Set to nullptr and call destructor
                        ::TlsSetValue(pair.first, nullptr);
                        pair.second(val);
                        again = true;
                    }
                }
            } while (again && ++iters < kMaxIters);
            ReleaseSRWLockShared((PSRWLOCK)&mutex);
        }
    };

    // The shared registry. It is a function-local static, so it is constructed on first use.
    static Destructors& destructors()
    {
        static Destructors d;
        return d;
    }

public:
    // Allocates a TLS index (failure is fatal) and registers `destructor` for it; a null destructor is not recorded.
    ThreadLocalBase(TlsDestructor destructor)
    {
        m_key = ::TlsAlloc();
        CARB_FATAL_UNLESS(m_key != CARBWIN_TLS_OUT_OF_INDEXES, "TlsAlloc() failed: %" PRIu32 "", ::GetLastError());
        destructors().add(m_key, destructor);
    }
    // Unregisters the destructor, then frees the TLS index. The destructor is not run for values other threads still
    // hold.
    ~ThreadLocalBase()
    {
        destructors().remove(m_key);
        BOOL b = ::TlsFree(m_key);
        CARB_CHECK(!!b);
    }

    // Not copyable or movable
    CARB_PREVENT_COPY_AND_MOVE(ThreadLocalBase);

    // Returns the calling thread's value for this index.
    void* get() const
    {
        return ::TlsGetValue(m_key);
    }
    // Stores `val` as the calling thread's value for this index.
    void set(void* val) const
    {
        BOOL b = ::TlsSetValue(m_key, val);
        CARB_CHECK(!!b);
    }

    // TLS callback (installed as pthread_thread_callback). Only DLL_THREAD_DETACH is handled; every other reason code,
    // including process detach, is ignored.
    static void callDestructors(HINSTANCE, DWORD fdwReason, PVOID)
    {
        if (fdwReason == CARBWIN_DLL_THREAD_DETACH)
        {
            // Call for current thread
            destructors().call();
        }
    }
};

extern "C"
{
    // Hook the TLS destructors in the CRT
    // see crt/src/vcruntime/tlsdtor.cpp
    // Signature of a TLS callback: module handle, reason code (compared against DLL_THREAD_DETACH by
    // callDestructors) and a reserved pointer.
    using TlsHookFunc = void(__stdcall*)(HINSTANCE, DWORD, PVOID);

    // Reference these so that the linker knows to include them
    // NOTE(review): these are only declared here, never used; confirm that the CRT's TLS directory (_tls_used) and the
    // __xl_a/__xl_z bounds are pulled into the link by something else (e.g. the CRT itself).
    extern DWORD _tls_used;
    extern TlsHookFunc __xl_a[], __xl_z[];

    // Force the linker to keep pthread_thread_callback even though no code refers to it by name.
#        pragma comment(linker, "/include:pthread_thread_callback")
    // Declare the read-only section the callback pointer is placed in. Presumably .CRT$XLD sorts between the CRT's
    // __xl_a/__xl_z markers so it lands inside the TLS callback array -- confirm against tlssup.
#        pragma section(".CRT$XLD", long, read)
    // Since this is a header file, the __declspec(selectany) enables weak linking so that the linker will throw away
    // all the duplicates and leave only one instance of pthread_thread_callback in the binary.
    // This is placed into the specific binary section used for TLS destructors
    // NOTE(review): callDestructors is not declared __stdcall; that is harmless on x64 (one calling convention) but the
    // conversion would not compile for 32-bit x86.
    __declspec(allocate(".CRT$XLD")) __declspec(selectany)
        TlsHookFunc pthread_thread_callback = carb::thread::detail::ThreadLocalBase::callDestructors;
}
#    else
CARB_UNSUPPORTED_PLATFORM();
#    endif

} // namespace detail
#endif

// Primary template for a per-thread value. `Trivial` defaults to true when T is trivial, trivially destructible and no
// larger than a pointer; both `true` and `false` are specialized below, so this empty definition itself is never used.
template <class T,
          bool Trivial CARB_NO_DOC(= (std::is_trivial<T>::value && std::is_trivially_destructible<T>::value &&
                                      sizeof(T) <= sizeof(void*)))>
class ThreadLocal
{
};

// Specializations
// Trivial and can fit within a pointer

// Specialization for T that is trivial, trivially destructible and no larger than a pointer. The value's bytes live
// directly in the thread-local slot, so nothing is heap-allocated and no per-thread cleanup is registered. The slot
// starts out null, so a thread that has not called set() reads T built from a null pointer's bytes.
template <class T>
class ThreadLocal<T, true> : private detail::ThreadLocalBase
{
public:
    // Passes a null destructor: there is nothing to clean up when a thread exits.
    ThreadLocal() : ThreadLocalBase(nullptr)
    {
    }

    ~ThreadLocal() = default;

    // Returns a copy of the calling thread's value.
    T get() const
    {
        // `= {}` initializes the union's first member (the dummy), so T itself needs no initialization before its
        // bytes are copied out of the pointer-sized slot.
        union
        {
            cpp::detail::NontrivialDummyType dummy;
            T t;
        } u = {};
        void* p = ThreadLocalBase::get();
        memcpy(&u.t, &p, sizeof(T));
        return u.t;
    }

    // Stores `t` as the calling thread's value by copying its bytes into the slot.
    void set(T t)
    {
        void* p = nullptr;
        memcpy(&p, &t, sizeof(T));
        ThreadLocalBase::set(p);
    }

    operator T()
    {
        return get();
    }

    operator T() const
    {
        return get();
    }

    ThreadLocal& operator=(T t)
    {
        set(t);
        return *this;
    }

    // Pointer T only: member access through the calling thread's pointer.
    T operator->() const
    {
        static_assert(std::is_pointer<T>::value, "Requires pointer type");
        return get();
    }

    // Pointer T only: dereferences the calling thread's pointer. decltype(auto) yields a reference to the pointee,
    // so `*tl = x` writes through the pointer and non-copyable pointees are supported. Plain `auto` would return a
    // copy of the pointee.
    decltype(auto) operator*() const
    {
        static_assert(std::is_pointer<T>::value, "Requires pointer type");
        return *get();
    }

    bool operator==(const T& rhs) const
    {
        return get() == rhs;
    }

    bool operator!=(const T& rhs) const
    {
        return get() != rhs;
    }
};

// Non-trivial or needs more than a pointer (uses the heap)

// Specialization for T that is non-trivial or larger than a pointer. Each thread's value is heap-allocated on first
// access inside a Wrapper node, and the thread-local slot stores that node. Every node created through this object is
// also linked into an intrusive circular list (guarded by detail::tlsMutex()). This lets the ThreadLocal destructor free
// values that belong to threads which are still running.
template <class T>
class ThreadLocal<T, false> : private detail::ThreadLocalBase
{
public:
    // Registers destructor() as the per-thread cleanup callback and starts with an empty list.
    ThreadLocal() : ThreadLocalBase(destructor), m_head(&m_head)
    {
        detail::tlsMutex(); // make sure this is constructed since we'll need it at shutdown
    }

    // Frees every thread's value created through this object. No other thread may use this object during
    // destruction, or exit while holding a value in it.
    ~ThreadLocal()
    {
        // Detach the whole list from m_head, then walk the detached copy and free each node.
        ListNode n = m_head;
        m_head.next = m_head.prev = _end();
        while (n.next != _end())
        {
            Wrapper* w = static_cast<Wrapper*>(n.next);
            n.next = n.next->next;
            delete w;
        }
        // It would be very bad if a thread was using this while we're destroying it
        CARB_ASSERT(m_head.next == _end() && m_head.prev == _end());
    }

    // Returns the calling thread's value, or nullptr if it has not been created yet. Never allocates.
    T* get_if()
    {
        return _get_if();
    }

    const T* get_if() const
    {
        return _get_if();
    }

    // Returns the calling thread's value, value-initializing it on first access.
    T& get()
    {
        return *_get();
    }

    const T& get() const
    {
        return *_get();
    }

    void set(const T& t)
    {
        *_get() = t;
    }

    void set(T&& t)
    {
        *_get() = std::move(t);
    }

    // Destroys the calling thread's value, if any; the next get() on this thread creates a fresh one.
    void reset()
    {
        if (void* node = ThreadLocalBase::get())
        {
            // Clear the slot before destroying the value (as thread-exit cleanup does), so ~T never observes a slot
            // that still points at the object being destroyed.
            ThreadLocalBase::set(nullptr);
            destructor(node);
        }
    }

    operator T()
    {
        return get();
    }

    operator T() const
    {
        return get();
    }

    ThreadLocal& operator=(const T& rhs)
    {
        set(rhs);
        return *this;
    }

    ThreadLocal& operator=(T&& rhs)
    {
        set(std::move(rhs));
        return *this;
    }

    // The forwarding operators below require T to provide the corresponding member operator. decltype(auto)
    // preserves references, so e.g. `*tl` and `tl[i]` refer to the stored object's parts instead of copying them.
    auto operator->()
    {
        return get().operator->();
    }
    auto operator->() const
    {
        return get().operator->();
    }
    decltype(auto) operator*()
    {
        return get().operator*();
    }
    decltype(auto) operator*() const
    {
        return get().operator*();
    }
    template <class U>
    decltype(auto) operator[](const U& u)
    {
        return get().operator[](u);
    }
    template <class U>
    decltype(auto) operator[](const U& u) const
    {
        return get().operator[](u);
    }

    bool operator==(const T& rhs) const
    {
        return get() == rhs;
    }

    bool operator!=(const T& rhs) const
    {
        return get() != rhs;
    }

private:
    // Intrusive circular doubly-linked list node.
    struct ListNode
    {
        ListNode* next;
        ListNode* prev;
        ListNode() = default;
        ListNode(ListNode* init) : next(init), prev(init)
        {
        }
    };
    // Heap node holding one thread's value. The thread-local slot stores the Wrapper* itself (not a T*), so the cleanup
    // callback gets the node back directly without any offset arithmetic.
    struct Wrapper : public ListNode
    {
        T t;
    };

    // Sentinel of the circular list. `mutable` because const accessors (e.g. get() const) may create this thread's
    // value and link it in; without it, doing so on a `const ThreadLocal` would modify a const object.
    mutable ListNode m_head;
    ListNode* _tail() const
    {
        return m_head.prev;
    }
    ListNode* _end() const
    {
        return &m_head;
    }

    // Cleanup callback for thread exit and reset(): `p` is the Wrapper* stored in a thread's slot. Unlinks the node
    // from its owner's list under detail::tlsMutex(), then frees it (running ~T).
    static void destructor(void* p)
    {
        Wrapper* w = static_cast<Wrapper*>(p);
        {
            // Remove from the list
            std::lock_guard<std::mutex> g(detail::tlsMutex());
            w->next->prev = w->prev;
            w->prev->next = w->next;
        }
        delete w;
    }

    // Returns the calling thread's value, or nullptr if the slot is empty.
    T* _get_if() const
    {
        Wrapper* w = static_cast<Wrapper*>(ThreadLocalBase::get());
        return w ? std::addressof(w->t) : nullptr;
    }

    T* _get() const
    {
        T* p = _get_if();
        return p ? p : _create();
    }

    // Allocates this thread's node, appends it to the list and publishes it in the slot.
    T* _create() const
    {
        // `new Wrapper()` value-initializes. The node (including t) is zero-initialized before T's default
        // constructor, if any, runs, so e.g. a large trivial T does not start out indeterminate.
        Wrapper* w = new Wrapper();
        // Add to end of list
        {
            std::lock_guard<std::mutex> g(detail::tlsMutex());
            w->next = _end();
            w->prev = _tail();
            w->next->prev = w;
            w->prev->next = w;
        }
        ThreadLocalBase::set(w);
        return std::addressof(w->t);
    }
};

} // namespace thread
} // namespace carb