carb/thread/ThreadLocal.h
File members: carb/thread/ThreadLocal.h
// Copyright (c) 2019-2024, NVIDIA CORPORATION. All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto. Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.
//
#pragma once
#include "../Defines.h"
#include "../cpp/detail/ImplDummy.h"
#include "../cpp/TypeTraits.h"
#include <atomic>
#include <type_traits>
#if CARB_POSIX
# include <pthread.h>
#elif CARB_PLATFORM_WINDOWS
# include "../CarbWindows.h"
# include "SharedMutex.h"
# include <map>
#else
CARB_UNSUPPORTED_PLATFORM();
#endif
namespace carb
{
namespace thread
{
#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace detail
{
//! Signature of the per-slot cleanup callback invoked with a thread's non-null slot value.
using TlsDestructor = void (*)(void*);

//! Returns the single mutex shared by all heap-backed ThreadLocal objects to guard their per-thread value lists.
//! The mutex is a function-local static, so it is constructed on first call and the same object is returned on
//! every call.
inline std::mutex& tlsMutex()
{
    static std::mutex s_listMutex;
    return s_listMutex;
}
# if CARB_POSIX
class ThreadLocalBase
{
    // Key returned by pthread_key_create(); owned by this object and released in the destructor.
    pthread_key_t m_key;

public:
    //! Allocates a new POSIX thread-specific-data key.
    //! @param destructor Callback that pthread invokes at thread exit with the exiting thread's non-null value for
    //!   this key, or `nullptr` for no cleanup.
    //! Raises a fatal error if no key can be allocated.
    ThreadLocalBase(TlsDestructor destructor)
    {
        int res = pthread_key_create(&m_key, destructor);
        CARB_FATAL_UNLESS(res == 0, "pthread_key_create failed: %d/%s", res, strerror(res));
    }

    //! Releases the key. pthread_key_delete() does not invoke the destructor for values other threads still hold.
    ~ThreadLocalBase()
    {
        pthread_key_delete(m_key);
    }

    // Not copyable or movable
    CARB_PREVENT_COPY_AND_MOVE(ThreadLocalBase);

    //! @returns the calling thread's value for this key (`nullptr` if the thread has not stored one).
    void* get() const
    {
        return pthread_getspecific(m_key);
    }

    //! Stores @p val as the calling thread's value for this key.
    void set(void* val) const
    {
        int res = pthread_setspecific(m_key, val);
        // pthread_key_t is `unsigned int` on Linux but `unsigned long` on macOS; widen explicitly so the format
        // specifier matches the argument on every POSIX target.
        CARB_CHECK(res == 0, "pthread_setspecific failed with %d/%s for key %lu", res, strerror(res),
                   static_cast<unsigned long>(m_key));
    }
};
# elif CARB_PLATFORM_WINDOWS
// Process-wide lock guarding the destructor registry below (the map inside Destructors and the `destructed` flag).
// __declspec(selectany) lets every translation unit that includes this header define it while the linker keeps a
// single instance. It uses a constant initializer and has no destructor, so presumably it stays usable during
// static destruction.
__declspec(selectany) CARBWIN_SRWLOCK mutex = CARBWIN_SRWLOCK_INIT;
// Set (under `mutex`) when the registry is destroyed; Destructors::call() checks it so that threads exiting after
// static destruction do not touch the destroyed map.
__declspec(selectany) bool destructed = false;
// Windows implementation of a TLS key: a TlsAlloc() slot plus a process-wide registry of per-slot destructors,
// since TlsAlloc() itself takes no destructor callback.
class CARB_VIZ ThreadLocalBase
{
// Slot index returned by TlsAlloc(); released with TlsFree() in the destructor.
CARB_VIZ DWORD m_key;
// Registry mapping each TLS slot index to the destructor to run for that slot's value when a thread exits.
class Destructors
{
using DestructorMap = std::map<DWORD, TlsDestructor>;
DestructorMap m_map;
public:
Destructors() = default;
// Runs at static destruction: empties the map and marks the registry as gone so call() becomes a no-op.
~Destructors()
{
AcquireSRWLockExclusive((PSRWLOCK)&mutex);
destructed = true;
// Destroy the map under the lock
DestructorMap{}.swap(m_map);
ReleaseSRWLockExclusive((PSRWLOCK)&mutex);
}
// Registers `fn` for `slot`; a null `fn` registers nothing.
void add(DWORD slot, TlsDestructor fn)
{
// If there's no destructor, don't do anything
if (!fn)
return;
AcquireSRWLockExclusive((PSRWLOCK)&mutex);
m_map[slot] = fn;
ReleaseSRWLockExclusive((PSRWLOCK)&mutex);
}
// Unregisters `slot` (no-op if it was never registered).
void remove(DWORD slot)
{
AcquireSRWLockExclusive((PSRWLOCK)&mutex);
m_map.erase(slot);
ReleaseSRWLockExclusive((PSRWLOCK)&mutex);
}
// For the calling thread, clears each registered slot that holds a non-null value and then invokes that slot's
// destructor with the old value. Passes repeat (at most kMaxIters) while any destructor ran, because a destructor
// may store a new value into some slot.
// NOTE(review): the lock is held in shared mode while user destructors run. A destructor that constructs or
// destroys a ThreadLocal reaches add()/remove(), which request the same lock exclusively; SRW locks are not
// recursive, so that would self-deadlock. Confirm destructors never do this.
void call()
{
AcquireSRWLockShared((PSRWLOCK)&mutex);
// It is possible for atexit destructors to run before other threads call destructors with thread-storage
// duration.
if (destructed)
{
ReleaseSRWLockShared((PSRWLOCK)&mutex);
return;
}
// This mimics the process of destructors with pthread_key_create which will iterate multiple (up to
// PTHREAD_DESTRUCTOR_ITERATIONS) times, which is typically 4.
bool again;
int iters = 0;
const int kMaxIters = 4;
do
{
again = false;
for (auto& pair : m_map)
{
if (void* val = ::TlsGetValue(pair.first))
{
// Set to nullptr and call destructor
::TlsSetValue(pair.first, nullptr);
pair.second(val);
again = true;
}
}
} while (again && ++iters < kMaxIters);
ReleaseSRWLockShared((PSRWLOCK)&mutex);
}
};
// Lazily-constructed registry shared by all ThreadLocalBase instances.
static Destructors& destructors()
{
static Destructors d;
return d;
}
public:
// Allocates a TLS slot and registers `destructor` (if non-null) to run on each thread's value at thread exit.
// Raises a fatal error if TlsAlloc() fails.
ThreadLocalBase(TlsDestructor destructor)
{
m_key = ::TlsAlloc();
CARB_FATAL_UNLESS(m_key != CARBWIN_TLS_OUT_OF_INDEXES, "TlsAlloc() failed: %" PRIu32 "", ::GetLastError());
destructors().add(m_key, destructor);
}
// Unregisters the slot's destructor and frees the slot. Values still stored by other threads are not destroyed
// here.
~ThreadLocalBase()
{
destructors().remove(m_key);
BOOL b = ::TlsFree(m_key);
CARB_CHECK(!!b);
}
// Not copyable or movable
CARB_PREVENT_COPY_AND_MOVE(ThreadLocalBase);
// Returns the calling thread's value for this slot.
void* get() const
{
return ::TlsGetValue(m_key);
}
// Stores `val` as the calling thread's value for this slot.
void set(void* val) const
{
BOOL b = ::TlsSetValue(m_key, val);
CARB_CHECK(!!b);
}
// TLS callback installed through pthread_thread_callback. On DLL_THREAD_DETACH it runs the registry's destructors
// for the exiting thread. Every other notification is ignored, so presumably the values of the thread that performs
// process shutdown are not destroyed by this path.
// NOTE(review): declared without __stdcall although TlsHookFunc is __stdcall. The two are identical on x64/ARM64;
// confirm that 32-bit x86 is not a supported target.
static void callDestructors(HINSTANCE, DWORD fdwReason, PVOID)
{
if (fdwReason == CARBWIN_DLL_THREAD_DETACH)
{
// Call for current thread
destructors().call();
}
}
};
extern "C"
{
// Hook the TLS destructors in the CRT
// see crt/src/vcruntime/tlsdtor.cpp
// Signature of an entry in the CRT's TLS-callback array.
using TlsHookFunc = void(__stdcall*)(HINSTANCE, DWORD, PVOID);
// Declarations of the CRT's TLS directory (_tls_used) and the bounds of its TLS-callback array (__xl_a/__xl_z).
// Presumably declared here so that the CRT's TLS support is pulled into the link; confirm against the vcruntime
// sources.
extern DWORD _tls_used;
extern TlsHookFunc __xl_a[], __xl_z[];
// Force the linker to keep pthread_thread_callback even though no code references it by name.
// NOTE(review): the undecorated name only matches on targets without C symbol decoration (e.g. x64).
# pragma comment(linker, "/include:pthread_thread_callback")
// Per the MSVC CRT convention, `.CRT$XL?` sections are ordered alphabetically, so `.CRT$XLD` lands between the
// array bounds __xl_a (XLA) and __xl_z (XLZ).
# pragma section(".CRT$XLD", long, read)
// Since this is a header file, the __declspec(selectany) enables weak linking so that the linker will throw away
// all the duplicates and leave only one instance of pthread_thread_callback in the binary.
// This is placed into the specific binary section used for TLS callbacks, which is how
// ThreadLocalBase::callDestructors gets notified of thread attach/detach events.
__declspec(allocate(".CRT$XLD")) __declspec(selectany)
TlsHookFunc pthread_thread_callback = carb::thread::detail::ThreadLocalBase::callDestructors;
}
# else
CARB_UNSUPPORTED_PLATFORM();
# endif
} // namespace detail
#endif
//! Thread-local storage holding an independent value of type T for each thread.
//!
//! The primary template is intentionally empty; @p Trivial selects one of two specializations:
//! - `true` (the default when T is trivial, trivially destructible and no larger than `void*`): the value's bytes
//!   are stored directly in the TLS slot and no cleanup is needed.
//! - `false` (otherwise): each thread's value is heap-allocated on first access and destroyed at thread exit, on
//!   reset(), or when the ThreadLocal object is destroyed.
//! @tparam T The stored type.
//! @tparam Trivial Selects the storage strategy; normally left at its default.
template <class T,
bool Trivial CARB_NO_DOC(= cpp::conjunction<std::is_trivial<T>,
std::is_trivially_destructible<T>,
cpp::bool_constant<sizeof(T) <= sizeof(void*)>>::value)>
class ThreadLocal
{
};
// Specializations
// Trivial and can fit within a pointer
template <class T>
class ThreadLocal<T, true> : private detail::ThreadLocalBase
{
public:
    //! Constructor. The value lives directly in the TLS slot, so no per-thread cleanup callback is registered.
    ThreadLocal() : ThreadLocalBase(nullptr)
    {
    }

    //! Destructor. Releases the TLS slot.
    ~ThreadLocal() = default;

    //! Returns the calling thread's value.
    //! The value is rebuilt from the first sizeof(T) bytes of the TLS slot; a thread that never called set() reads
    //! the bytes of a null slot value.
    //! @returns the calling thread's value.
    T get() const
    {
        // `= {}` initializes the union's first (dummy) member, so `t` itself is only written by the memcpy below.
        union
        {
            cpp::detail::NontrivialDummyType dummy;
            T t;
        } u = {};
        void* p = ThreadLocalBase::get();
        memcpy(&u.t, &p, sizeof(T));
        return u.t;
    }

    //! Sets the calling thread's value.
    //! @param t The value; its bytes are copied into a null-initialized pointer-sized buffer stored in the slot.
    void set(T t)
    {
        void* p = nullptr;
        memcpy(&p, &t, sizeof(T));
        ThreadLocalBase::set(p);
    }

    //! Implicit conversion to the calling thread's value.
    operator T()
    {
        return get();
    }

    //! Implicit conversion to the calling thread's value.
    operator T() const
    {
        return get();
    }

    //! Sets the calling thread's value.
    //! @param t The new value.
    //! @returns `*this`
    ThreadLocal& operator=(T t)
    {
        set(t);
        return *this;
    }

    //! Member access through the stored pointer (requires T to be a pointer type).
    //! @returns the calling thread's pointer value.
    T operator->() const
    {
        static_assert(std::is_pointer<T>::value, "Requires pointer type");
        return get();
    }

    //! Dereferences the stored pointer (requires T to be a pointer type).
    //! @returns an lvalue reference to the pointee, so the pointee can be modified in place and non-copyable
    //!   pointees are supported.
    decltype(auto) operator*() const
    {
        static_assert(std::is_pointer<T>::value, "Requires pointer type");
        return *get();
    }

    //! @returns `true` if the calling thread's value equals @p rhs.
    bool operator==(const T& rhs) const
    {
        return get() == rhs;
    }

    //! @returns `true` if the calling thread's value differs from @p rhs.
    bool operator!=(const T& rhs) const
    {
        return get() != rhs;
    }
};
// Non-trivial or needs more than a pointer (uses the heap)
template <class T>
class ThreadLocal<T, false> : private detail::ThreadLocalBase
{
public:
    //! Constructor. Registers destructor() so that each thread's heap-allocated value is freed when the thread exits.
    ThreadLocal() : ThreadLocalBase(destructor), m_head(&m_head)
    {
        detail::tlsMutex(); // make sure this is constructed since we'll need it at shutdown
    }

    //! Destructor. Deletes every per-thread value still owned by this object, whichever thread created it.
    //! @warning Other threads must not use this object, or exit while holding a value from it, during destruction;
    //!   either would race with the list teardown below.
    ~ThreadLocal()
    {
        // Detach the whole list from the head, then delete every node that was on it.
        ListNode* node = m_head.next;
        m_head.next = m_head.prev = _end();
        while (node != _end())
        {
            // Wrapper derives from ListNode, so a static downcast recovers the owning allocation.
            Wrapper* w = static_cast<Wrapper*>(node);
            node = node->next;
            delete w;
        }
        // It would be very bad if a thread was using this while we're destroying it
        CARB_ASSERT(m_head.next == _end() && m_head.prev == _end());
    }

    //! @returns a pointer to the calling thread's value, or `nullptr` if this thread has none. Never allocates.
    T* get_if()
    {
        return _get_if();
    }

    //! @returns a pointer to the calling thread's value, or `nullptr` if this thread has none. Never allocates.
    const T* get_if() const
    {
        return _get_if();
    }

    //! @returns the calling thread's value, creating a value-initialized one on first access.
    T& get()
    {
        return *_get();
    }

    //! @returns the calling thread's value, creating a value-initialized one on first access.
    const T& get() const
    {
        return *_get();
    }

    //! Copy-assigns @p t to the calling thread's value, creating the value first if necessary.
    void set(const T& t)
    {
        *_get() = t;
    }

    //! Move-assigns @p t to the calling thread's value, creating the value first if necessary.
    void set(T&& t)
    {
        *_get() = std::move(t);
    }

    //! Destroys the calling thread's value, if any. A later get() creates a fresh value.
    void reset()
    {
        if (void* p = ThreadLocalBase::get())
        {
            // Clear the slot before destroying the value so that T's destructor never sees a dangling slot; this is
            // the same order the thread-exit cleanup path uses.
            ThreadLocalBase::set(nullptr);
            destructor(p);
        }
    }

    //! Implicit conversion: returns a copy of the calling thread's value.
    operator T()
    {
        return get();
    }

    //! Implicit conversion: returns a copy of the calling thread's value.
    operator T() const
    {
        return get();
    }

    //! Copy-assigns @p rhs to the calling thread's value. @returns `*this`
    ThreadLocal& operator=(const T& rhs)
    {
        set(rhs);
        return *this;
    }

    //! Move-assigns @p rhs to the calling thread's value. @returns `*this`
    ThreadLocal& operator=(T&& rhs)
    {
        set(std::move(rhs));
        return *this;
    }

    //! Forwards to `T::operator->()` on the calling thread's value (e.g. when T is a smart pointer).
    auto operator->()
    {
        return get().operator->();
    }

    //! Forwards to `T::operator->() const` on the calling thread's value.
    auto operator->() const
    {
        return get().operator->();
    }

    //! Forwards to `T::operator*()` on the calling thread's value, returning exactly what it returns (references are
    //! preserved rather than decayed to copies).
    decltype(auto) operator*()
    {
        return get().operator*();
    }

    //! Forwards to `T::operator*() const` on the calling thread's value, preserving its return type.
    decltype(auto) operator*() const
    {
        return get().operator*();
    }

    //! Forwards to `T::operator[]` on the calling thread's value.
    template <class U>
    auto operator[](const U& u) const
    {
        return get().operator[](u);
    }

    //! @returns `true` if the calling thread's value equals @p rhs.
    bool operator==(const T& rhs) const
    {
        return get() == rhs;
    }

    //! @returns `true` if the calling thread's value differs from @p rhs.
    bool operator!=(const T& rhs) const
    {
        return get() != rhs;
    }

private:
    // Intrusive doubly-linked list node; every live per-thread value of this object is linked into m_head's list so
    // that ~ThreadLocal() can free values belonging to threads that are still running.
    struct ListNode
    {
        ListNode* next;
        ListNode* prev;
        ListNode() = default;
        ListNode(ListNode* init) : next(init), prev(init)
        {
        }
    };

    // One thread's heap allocation: list linkage plus the value. The TLS slot stores a pointer to the Wrapper itself,
    // so getting from the slot back to the list node needs no offset arithmetic.
    struct Wrapper : public ListNode
    {
        T t;
    };

    // Sentinel of the circular list; an empty list points at itself.
    ListNode m_head;

    ListNode* _tail() const
    {
        return const_cast<ListNode*>(m_head.prev);
    }

    ListNode* _end() const
    {
        return const_cast<ListNode*>(&m_head);
    }

    // Frees one thread's value. Called with a non-null slot value, i.e. the Wrapper* stored by _create(), either by
    // the platform TLS layer at thread exit or by reset().
    static void destructor(void* p)
    {
        Wrapper* w = static_cast<Wrapper*>(p);
        {
            // Remove from the list
            std::lock_guard<std::mutex> g(detail::tlsMutex());
            w->next->prev = w->prev;
            w->prev->next = w->next;
        }
        delete w;
    }

    // Returns the calling thread's value without creating one.
    T* _get_if() const
    {
        Wrapper* w = static_cast<Wrapper*>(ThreadLocalBase::get());
        return w ? std::addressof(w->t) : nullptr;
    }

    // Returns the calling thread's value, creating it if needed.
    T* _get() const
    {
        T* p = _get_if();
        return p ? p : _create();
    }

    // Allocates the calling thread's value, links it into the list and stores it in the TLS slot.
    T* _create() const
    {
        // Value-initialize so that trivial types (too large to fit in a pointer) start zeroed rather than holding
        // indeterminate bytes.
        Wrapper* w = new Wrapper();
        // Add to end of list
        {
            std::lock_guard<std::mutex> g(detail::tlsMutex());
            w->next = _end();
            w->prev = _tail();
            w->next->prev = w;
            w->prev->next = w;
        }
        ThreadLocalBase::set(w);
        return std::addressof(w->t);
    }
};
} // namespace thread
} // namespace carb