ThreadLocal.h
Fully qualified name: carb/thread/ThreadLocal.h
File members: carb/thread/ThreadLocal.h
// Copyright (c) 2019-2024, NVIDIA CORPORATION. All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto. Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.
//
#pragma once
#include "../Defines.h"
#include "../cpp/Bit.h"
#include "../cpp/TypeTraits.h"
#include <atomic>
#include <mutex>
#include <type_traits>
#include <utility>
#if CARB_POSIX
# include <pthread.h>
#elif CARB_PLATFORM_WINDOWS
# include "../CarbWindows.h"
# include <map>
#else
CARB_UNSUPPORTED_PLATFORM();
#endif
namespace carb
{
namespace thread
{
#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace detail
{
using TlsDestructor = void (*)(void*);
/// Minimal busy-waiting lock exposing lock()/unlock(), which makes it usable with std::lock_guard.
struct SimpleSpinlock
{
    constexpr SimpleSpinlock() noexcept = default;

    /// Spins until the calling thread is the one that flips the flag from `false` to `true`.
    void lock() noexcept
    {
        for (;;)
        {
            // Sequentially-consistent exchange, i.e. a full read/write barrier. If the previous value was
            // `false` the lock was free and now belongs to the caller.
            const bool alreadyHeld = m_flag.exchange(true, std::memory_order_seq_cst);
            if (!alreadyHeld)
                return;
            CARB_HARDWARE_PAUSE();
        }
    }

    /// Releases the lock.
    void unlock() noexcept
    {
        // An exchange is used instead of a plain store so that the release is also a full read/write barrier;
        // the previous value is intentionally discarded.
        static_cast<void>(m_flag.exchange(false, std::memory_order_seq_cst));
    }

    std::atomic_bool m_flag{ false }; ///< `true` while the lock is held
};
# if CARB_POSIX
/// POSIX implementation of a type-erased, pointer-sized thread-local slot backed by a pthread key.
class ThreadLocalBase
{
    pthread_key_t m_key;
    bool m_keyZero = false; ///< true if this object additionally holds pthread key 0 (see the constructor)

public:
    /// Creates the pthread key.
    /// @param destructor Passed to pthread_key_create() as the per-thread destructor; may be nullptr.
    /// Failure to create a key is fatal (CARB_FATAL_UNLESS).
    ThreadLocalBase(TlsDestructor destructor)
    {
        for (;;)
        {
            int res = pthread_key_create(&m_key, destructor);
            CARB_FATAL_UNLESS(res == 0, "pthread_key_create failed: %d/%s", res, strerror(res));

            // Even though pthread key 0 might be valid, we have seen some issues where third-party libraries use 0 as
            // an 'uninitialized' state, and may potentially close key 0 even if they don't own it. So if we were given
            // key 0, create another key too.
            CARB_LIKELY_IF(m_key != pthread_key_t{})
            {
                break;
            }

            // Key 0 stays allocated (it is released in the destructor), so the next iteration must yield a
            // different key; being handed key 0 twice would be a logic error.
            CARB_ASSERT(!m_keyZero);
            m_keyZero = true;
        }
    }

    /// Deletes the key used by this object and, if the constructor had to hold on to key 0, that key as well.
    ~ThreadLocalBase()
    {
        pthread_key_delete(m_key);
        if (m_keyZero)
            // See constructor above for rationale.
            pthread_key_delete(pthread_key_t{});
    }

    // Not copyable or movable
    CARB_PREVENT_COPY_AND_MOVE(ThreadLocalBase);

    /// @returns the calling thread's value for this slot, as reported by pthread_getspecific().
    void* get() const
    {
        return pthread_getspecific(m_key);
    }

    /// Stores @p val as the calling thread's value for this slot. A failure is reported through CARB_CHECK.
    void set(void* val) const
    {
        int res = pthread_setspecific(m_key, val);
        // pthread_key_t is an opaque type whose underlying integer type differs between POSIX platforms, so it is
        // converted explicitly to the type that the format specifier expects instead of being passed through the
        // variadic argument list as-is.
        CARB_CHECK(res == 0, "pthread_setspecific failed with %d/%s for key %lu", res, strerror(res),
                   static_cast<unsigned long>(m_key));
    }
};
# elif CARB_PLATFORM_WINDOWS
// Slim reader/writer lock guarding the `destructed` flag below and the slot -> destructor map kept by
// ThreadLocalBase::Destructors. It is a header-defined global, hence __declspec(selectany).
__declspec(selectany) CARBWIN_SRWLOCK mutex = CARBWIN_SRWLOCK_INIT;
// Set to true (while holding `mutex` exclusively) by ~Destructors(). Once set, registering, unregistering and
// invoking TLS destructors all become no-ops.
__declspec(selectany) bool destructed = false;
// Windows implementation of a type-erased, pointer-sized thread-local slot backed by TlsAlloc()/TlsFree().
// Per-slot destructors are kept in a Destructors registry and are invoked from callDestructors() when a
// thread detaches.
class CARB_VIZ ThreadLocalBase
{
// TLS index returned by TlsAlloc()
CARB_VIZ DWORD m_key;
// Registry mapping TLS index -> destructor function. All access is serialized through the namespace-scope
// `mutex`; the namespace-scope `destructed` flag records that the registry has already been torn down.
class Destructors
{
using DestructorMap = std::map<DWORD, TlsDestructor>;
DestructorMap m_map;
public:
Destructors() = default;
~Destructors()
{
AcquireSRWLockExclusive((PSRWLOCK)&mutex);
// From here on add(), remove() and call() do nothing
destructed = true;
// Destroy the map under the lock
DestructorMap{}.swap(m_map);
ReleaseSRWLockExclusive((PSRWLOCK)&mutex);
}
// Registers `fn` as the destructor for TLS index `slot`, replacing any previous entry for that index.
void add(DWORD slot, TlsDestructor fn)
{
// If there's no destructor, don't do anything
if (!fn)
return;
AcquireSRWLockExclusive((PSRWLOCK)&mutex);
// Skip if the registry has already been destroyed
if (!destructed)
{
m_map[slot] = fn;
}
ReleaseSRWLockExclusive((PSRWLOCK)&mutex);
}
// Unregisters the destructor (if any) for TLS index `slot`.
void remove(DWORD slot)
{
AcquireSRWLockExclusive((PSRWLOCK)&mutex);
// Skip if the registry has already been destroyed
if (!destructed)
{
m_map.erase(slot);
}
ReleaseSRWLockExclusive((PSRWLOCK)&mutex);
}
// Invokes the registered destructor for every slot that holds a non-null value according to TlsGetValue().
// The shared lock is held for the whole duration, including while the destructor functions run.
void call()
{
AcquireSRWLockShared((PSRWLOCK)&mutex);
// It is possible for atexit destructors to run before other threads call destructors with thread-storage
// duration.
if (destructed)
{
ReleaseSRWLockShared((PSRWLOCK)&mutex);
return;
}
// This mimics the process of destructors with pthread_key_create which will iterate multiple (up to
// PTHREAD_DESTRUCTOR_ITERATIONS) times, which is typically 4.
bool again;
int iters = 0;
const int kMaxIters = 4;
do
{
// Another pass is needed only if this pass ran at least one destructor, since a destructor may have
// stored a new value into some slot.
again = false;
for (auto& pair : m_map)
{
if (void* val = ::TlsGetValue(pair.first))
{
// Set to nullptr and call destructor
::TlsSetValue(pair.first, nullptr);
pair.second(val);
again = true;
}
}
} while (again && ++iters < kMaxIters);
ReleaseSRWLockShared((PSRWLOCK)&mutex);
}
};
// Accessor for the single function-local static registry instance.
static Destructors& destructors()
{
static Destructors d;
return d;
}
public:
// Allocates a TLS index (failure is fatal) and registers `destructor` for it; a null `destructor` is not
// registered (see Destructors::add).
ThreadLocalBase(TlsDestructor destructor)
{
m_key = ::TlsAlloc();
CARB_FATAL_UNLESS(m_key != CARBWIN_TLS_OUT_OF_INDEXES, "TlsAlloc() failed: %" PRIdword "", ::GetLastError());
destructors().add(m_key, destructor);
}
// Unregisters the destructor first and only then frees the TLS index.
~ThreadLocalBase()
{
destructors().remove(m_key);
BOOL b = ::TlsFree(m_key);
CARB_CHECK(!!b);
}
// Not copyable or movable
CARB_PREVENT_COPY_AND_MOVE(ThreadLocalBase);
// Returns the value reported by TlsGetValue() for this slot.
void* get() const
{
return ::TlsGetValue(m_key);
}
// Stores `val` in this slot via TlsSetValue(); a failure is reported through CARB_CHECK.
void set(void* val) const
{
BOOL b = ::TlsSetValue(m_key, val);
CARB_CHECK(!!b);
}
// TLS callback installed through the `pthread_thread_callback` pointer defined further down in this header.
// Only the thread-detach notification is acted upon; every other reason is ignored.
static void callDestructors(HINSTANCE, DWORD fdwReason, PVOID)
{
if (fdwReason == CARBWIN_DLL_THREAD_DETACH)
{
// Call for current thread
destructors().call();
}
}
};
// Registration of ThreadLocalBase::callDestructors as a TLS callback. The symbols are given C linkage so that
// the name used in the linker directive below matches the emitted symbol name.
extern "C"
{
// Hook the TLS destructors in the CRT
// see crt/src/vcruntime/tlsdtor.cpp
// Signature of a TLS callback entry.
using TlsHookFunc = void(__stdcall*)(HINSTANCE, DWORD, PVOID);
// Reference these so that the linker knows to include them
extern DWORD _tls_used;
extern TlsHookFunc __xl_a[], __xl_z[];
// Force the linker to keep pthread_thread_callback even though nothing references it by name.
# pragma comment(linker, "/include:pthread_thread_callback")
// Declare the section that the callback pointer is allocated into.
# pragma section(".CRT$XLD", long, read)
// Since this is a header file, the __declspec(selectany) enables weak linking so that the linker will throw away
// all the duplicates and leave only one instance of pthread_thread_callback in the binary.
// This is placed into the specific binary section used for TLS destructors
__declspec(allocate(".CRT$XLD")) __declspec(selectany)
TlsHookFunc pthread_thread_callback = carb::thread::detail::ThreadLocalBase::callDestructors;
}
# else
CARB_UNSUPPORTED_PLATFORM();
# endif
} // namespace detail
#endif
/// Thread-local storage wrapper; all functionality lives in the two partial specializations.
/// @tparam T The type stored per thread.
/// @tparam Trivial Selects the specialization. The default is `true` only when @p T is trivial, trivially
///     destructible and no larger than a pointer.
template <class T,
          bool Trivial CARB_NO_DOC(= (std::is_trivial<T>::value && std::is_trivially_destructible<T>::value &&
                                      (sizeof(T) <= sizeof(void*))))>
class ThreadLocal
{
};
// Specializations
// Trivial and can fit within a pointer
/// Specialization for trivial types that fit within a pointer: the bytes of the value are stored directly in
/// the thread-local slot, so no heap allocation and no per-thread destructor is involved.
template <class T>
class ThreadLocal<T, true> : private detail::ThreadLocalBase
{
    using Base = detail::ThreadLocalBase;

public:
    /// Creates the slot without a destructor function.
    ThreadLocal() : detail::ThreadLocalBase(nullptr)
    {
    }

    ~ThreadLocal() = default;

    /// @returns the calling thread's value, rebuilt from the first `sizeof(T)` bytes of the slot's pointer.
    T get() const
    {
        // bit_cast is not an option since sizeof(T) may be smaller than sizeof(void*)
        void* const raw = Base::get();
        T value;
        memcpy(&value, &raw, sizeof(T));
        return value;
    }

    /// Stores @p t for the calling thread by copying its bytes into an otherwise null pointer.
    void set(T t)
    {
        void* raw = nullptr;
        memcpy(&raw, &t, sizeof(T));
        Base::set(raw);
    }

    /// Implicit conversion to the calling thread's value.
    operator T()
    {
        return this->get();
    }

    /// Implicit conversion to the calling thread's value.
    operator T() const
    {
        return this->get();
    }

    /// Assigns @p t as the calling thread's value.
    ThreadLocal& operator=(T t)
    {
        this->set(t);
        return *this;
    }

    /// Pointer-style member access; only available when `T` is a pointer type.
    T operator->() const
    {
        static_assert(std::is_pointer<T>::value, "Requires pointer type");
        return this->get();
    }

    /// Dereferences the stored pointer; only available when `T` is a pointer type.
    auto operator*() const
    {
        static_assert(std::is_pointer<T>::value, "Requires pointer type");
        return *this->get();
    }

    /// Compares the calling thread's value against @p rhs for equality.
    bool operator==(const T& rhs) const
    {
        return this->get() == rhs;
    }

    /// Compares the calling thread's value against @p rhs for inequality.
    bool operator!=(const T& rhs) const
    {
        return this->get() != rhs;
    }
};
// Non-trivial or needs more than a pointer (uses the heap)
/// Specialization for types that are non-trivial or need more than a pointer (uses the heap).
///
/// Each thread lazily gets its own heap-allocated, value-initialized `T` on first access. Every instance is
/// linked into a circular doubly-linked list anchored at `m_list`, so that instances still alive when the
/// ThreadLocal object itself is destroyed can be deleted.
template <class T>
class ThreadLocal<T, false> : private detail::ThreadLocalBase
{
public:
    /// Creates the slot with `destructor` as the per-thread cleanup function and an empty instance list.
    ThreadLocal() : ThreadLocalBase(destructor), m_list(_end())
    {
    }

    /// Deletes every per-thread instance that is still linked into the list.
    ~ThreadLocal()
    {
        // NOTE: we do not need to hold m_spinlock here because it is UB to call functions in an object from other
        // threads while it is being destructed.

        // Delete all instances for threads created by this object. The list is detached first so that the
        // sentinel describes an empty list while the nodes are being deleted.
        Wrapper* p = m_list.next;
        m_list.next = m_list.prev = _end();
        for (Wrapper* next; p != _end(); p = next)
        {
            next = p->next;
            delete p;
        }

        // It would be very bad if a thread was using this while we're destroying it
        CARB_ASSERT(m_list.next == _end() && m_list.prev == _end());
    }

    /// @returns a pointer to the calling thread's instance, or `nullptr` if none has been created yet.
    T* get_if()
    {
        return _get_if();
    }

    /// @returns a pointer to the calling thread's instance, or `nullptr` if none has been created yet.
    const T* get_if() const
    {
        return _get_if();
    }

    /// @returns a reference to the calling thread's instance, creating it (value-initialized) if necessary.
    T& get()
    {
        return *_get();
    }

    /// @returns a reference to the calling thread's instance, creating it (value-initialized) if necessary.
    const T& get() const
    {
        return *_get();
    }

    /// Copy-assigns @p t to the calling thread's instance, creating the instance first if necessary.
    void set(const T& t)
    {
        *_get() = t;
    }

    /// Move-assigns @p t to the calling thread's instance, creating the instance first if necessary.
    void set(T&& t)
    {
        *_get() = std::move(t);
    }

    /// Destroys the calling thread's instance, if any. A later get() creates a fresh one.
    void reset()
    {
        if (auto w = static_cast<Wrapper*>(ThreadLocalBase::get()))
        {
            destructor(w);
            ThreadLocalBase::set(nullptr);
        }
    }

    /// Implicit conversion to a copy of the calling thread's instance.
    operator T()
    {
        return get();
    }

    /// Implicit conversion to a copy of the calling thread's instance.
    operator T() const
    {
        return get();
    }

    /// Copy-assigns @p rhs to the calling thread's instance.
    ThreadLocal& operator=(const T& rhs)
    {
        set(rhs);
        return *this;
    }

    /// Move-assigns @p rhs to the calling thread's instance.
    ThreadLocal& operator=(T&& rhs)
    {
        set(std::move(rhs));
        return *this;
    }

    /// Forwards to `T::operator->()` on the calling thread's instance.
    auto operator->()
    {
        return get().operator->();
    }

    /// Forwards to `T::operator->()` on the calling thread's instance.
    auto operator->() const
    {
        return get().operator->();
    }

    /// Forwards to `T::operator*()` on the calling thread's instance.
    auto operator*()
    {
        return get().operator*();
    }

    /// Forwards to `T::operator*()` on the calling thread's instance.
    auto operator*() const
    {
        return get().operator*();
    }

    /// Forwards to `T::operator[]()` on the calling thread's instance.
    template <class U>
    auto operator[](const U& u) const
    {
        return get().operator[](u);
    }

    /// Compares the calling thread's instance against @p rhs for equality.
    bool operator==(const T& rhs) const
    {
        return get() == rhs;
    }

    /// Compares the calling thread's instance against @p rhs for inequality.
    bool operator!=(const T& rhs) const
    {
        return get() != rhs;
    }

private:
    struct Wrapper;

    /// Link fields of the intrusive circular list. `m_list` is a bare ListNode acting as the sentinel.
    struct ListNode
    {
        Wrapper* next = nullptr;
        Wrapper* prev = nullptr;

        constexpr ListNode() noexcept = default;
        constexpr ListNode(Wrapper* init) noexcept : next(init), prev(init)
        {
        }
    };

    /// One per-thread instance: the list links followed by the payload.
    struct Wrapper : public ListNode
    {
        T t;
    };

    /// @returns the last node of the list (the sentinel itself when the list is empty).
    Wrapper* _tail() const
    {
        return const_cast<Wrapper*>(m_list.prev);
    }

    /// @returns the sentinel, expressed as a `Wrapper*` so that it can be linked like any other node.
    Wrapper* _end() const
    {
        return cpp::bit_cast<Wrapper*>(const_cast<ListNode*>(&m_list));
    }

    /// Cleanup function handed to ThreadLocalBase (and called by reset()): unlinks the wrapper and deletes it.
    static void destructor(void* p)
    {
        auto w = static_cast<Wrapper*>(p);
        {
            // Remove from the list
            std::lock_guard<detail::SimpleSpinlock> g(spinlock());
            w->next->prev = w->prev;
            w->prev->next = w->next;
        }
        // The payload is destroyed outside of the spinlock
        delete w;
    }

    T* _get_if() const
    {
        auto w = static_cast<Wrapper*>(ThreadLocalBase::get());
        return w ? std::addressof(w->t) : nullptr;
    }

    T* _get() const
    {
        T* p = _get_if();
        return p ? p : _create();
    }

    /// Allocates the calling thread's instance, links it at the end of the list and stores it in the slot.
    T* _create() const
    {
        // Value-initialize (note the parentheses) rather than default-initialize: for a `T` without a
        // user-provided default constructor (e.g. a plain struct larger than a pointer) default-initialization
        // would leave the payload indeterminate. Value-initialization gives a zeroed payload, matching what the
        // pointer-sized specialization reports for a slot that has never been set.
        Wrapper* w = new Wrapper(); // may throw

        // Add to end of list
        {
            std::lock_guard<detail::SimpleSpinlock> g(spinlock());
            w->next = _end();
            w->prev = _tail();
            w->next->prev = w;
            w->prev->next = w;
        }
        ThreadLocalBase::set(w);
        return std::addressof(w->t);
    }

    /// @returns the spinlock guarding list manipulation; a function-local static, so it is shared by every
    /// ThreadLocal object of this specialization.
    static detail::SimpleSpinlock& spinlock() noexcept
    {
        static detail::SimpleSpinlock s_spinlock;
        return s_spinlock;
    }

    ListNode m_list; ///< Sentinel node of the circular list of per-thread instances
};
} // namespace thread
} // namespace carb