// File: carb/container/IntrusiveUnorderedMultimap.h
// Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto. Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.
#pragma once
#include "../Defines.h"
#include "../cpp/Bit.h"
#include <cmath>
#include <functional>
#include <iterator>
#include <memory>
// NOTE(review): this listing appears to be extracted from rendered documentation: braces, access specifiers,
// preprocessor conditionals and some short statements are not present. Confirm against the full header before
// relying on the exact control flow shown here.
namespace carb
namespace container
// Forward declaration of the link type that an element type T embeds as a data member.
template <class Key, class T>
class IntrusiveUnorderedMultimapLink;
// Forward declaration of the container. The link classes befriend it so it can manipulate their list pointers
// and stored key/value pair.
template <class Key, class T, IntrusiveUnorderedMultimapLink<Key, T> T::*U, class Hash, class Pred>
class IntrusiveUnorderedMultimap;
namespace detail
// Empty helper type whose constexpr default constructor is user-provided, which makes it non-trivially
// default constructible (checked by the static_assert below).
// NOTE(review): it is used as the default-initialized alternative to the link's stored `value`, presumably so a
// link can be constexpr-constructed without constructing the key pair; confirm against the full header.
struct NontrivialDummyType
constexpr NontrivialDummyType() noexcept
static_assert(!std::is_trivially_default_constructible<NontrivialDummyType>::value, "Invalid assumption");
// Base of IntrusiveUnorderedMultimapLink holding the doubly-linked-list pointers. A link counts as "contained"
// in a container while m_next is non-null; the container nulls both pointers whenever it unlinks an element.
template <class Key, class T>
class IntrusiveUnorderedMultimapLinkBase
using KeyType = const Key;
using MappedType = T;
// Stored per element: the immutable key plus a reference to the element that embeds the link.
using ValueType = std::pair<const Key, T&>;
constexpr IntrusiveUnorderedMultimapLinkBase() noexcept = default;
// Shouldn't be contained at destruction time
// NOTE(review): the comment above suggests a destructor asserting !isContained(); it is not shown in this listing.
bool isContained() const noexcept
return m_next != nullptr;
// The container is befriended so that it can read and rewrite the list pointers below.
template <class Key2, class U, IntrusiveUnorderedMultimapLink<Key2, U> U::*V, class Hash, class Pred>
friend class ::carb::container::IntrusiveUnorderedMultimap;
IntrusiveUnorderedMultimapLink<Key, T>* m_next{ nullptr };
IntrusiveUnorderedMultimapLink<Key, T>* m_prev{ nullptr };
// Points both pointers at `init`. The container uses this for its sentinel node, passing the sentinel's own
// address so that an empty list is a self-referencing ring.
constexpr IntrusiveUnorderedMultimapLinkBase(IntrusiveUnorderedMultimapLink<Key, T>* init) noexcept
: m_next(init), m_prev(init)
} // namespace detail
// Link that an element type T embeds as a data member in order to be stored in an IntrusiveUnorderedMultimap.
// The container is told which member to use through its pointer-to-member template parameter U. Besides the
// list pointers inherited from the base, the link holds the element's key pair (`value`).
template <class Key, class T>
class IntrusiveUnorderedMultimapLink : public detail::IntrusiveUnorderedMultimapLinkBase<Key, T>
using Base = detail::IntrusiveUnorderedMultimapLinkBase<Key, T>;
// Initializes only `empty`. `value` is constructed in place by the container's insert() and destroyed
// explicitly by clear() (and, per its comments, by remove()).
constexpr IntrusiveUnorderedMultimapLink() noexcept : empty{}
// Shouldn't be contained at destruction time
template <class Key2, class U, IntrusiveUnorderedMultimapLink<Key2, U> U::*V, class Hash, class Pred>
friend class IntrusiveUnorderedMultimap;
// NOTE(review): `empty` and `value` are presumably members of an anonymous union (the keyword is not shown in
// this listing): the constructor initializes only `empty`, while the container placement-news `value`.
detail::NontrivialDummyType empty;
typename Base::ValueType value;
// Intrusive unordered multimap. Elements of type T are linked into the container through an embedded
// IntrusiveUnorderedMultimapLink<Key, T> member selected by the pointer-to-member U. The container does not
// allocate per-element nodes (only its bucket array is heap-allocated), and it only references elements (T&);
// it never destroys them.
// All elements live on one circular doubly linked list closed by a sentinel node (m_list). Each bucket records
// the first and last link of its contiguous run in that list, and elements with equal keys are kept adjacent.
// Hash and Pred are default-constructed at each use (e.g. `Hash{}`, `Pred pred{}`), so stateful functors are
// not supported.
template <class Key, class T, IntrusiveUnorderedMultimapLink<Key, T> T::*U, class Hash = ::std::hash<Key>, class Pred = ::std::equal_to<Key>>
class IntrusiveUnorderedMultimap
using BaseLink = detail::IntrusiveUnorderedMultimapLinkBase<Key, T>;
using KeyType = const Key;
using MappedType = T;
using Link = IntrusiveUnorderedMultimapLink<Key, T>;
// Per-element stored pair: `first` is the const key, `second` references the element.
using ValueType = typename Link::ValueType;
// Iterator support
// clang-format off
// Forward iterator over the container's list (in list order, across buckets) yielding `const ValueType&`.
// When owner tracking is compiled in, it asserts on cross-container comparisons and on dereferencing or
// incrementing end().
class const_iterator
using iterator_category = std::forward_iterator_tag;
using value_type = ValueType;
using difference_type = ptrdiff_t;
using pointer = const value_type*;
using reference = const value_type&;
const_iterator() noexcept = default;
reference operator * () const { assertNotEnd(); return m_where->value; }
pointer operator -> () const { assertNotEnd(); return std::addressof(operator*()); }
const_iterator& operator ++ () noexcept /* ++iter */ { assertNotEnd(); incr(); return *this; }
const_iterator operator ++ (int) noexcept /* iter++ */ { assertNotEnd(); const_iterator i{ *this }; incr(); return i; }
bool operator == (const const_iterator& rhs) const noexcept { assertSameOwner(rhs); return m_where == rhs.m_where; }
bool operator != (const const_iterator& rhs) const noexcept { assertSameOwner(rhs); return m_where != rhs.m_where; }
friend class IntrusiveUnorderedMultimap;
// Current link; for end() this is the container's sentinel (_end()).
Link* m_where{ nullptr };
// NOTE(review): m_owner and the first group of members below appear to be the owner-tracking variant, and the
// second group (constructor ignoring its owner argument, empty assertNotEnd) the non-tracking variant. They are
// presumably selected by a preprocessor conditional (e.g. on whether assertions are enabled) not shown here.
const IntrusiveUnorderedMultimap* m_owner{ nullptr };
const_iterator(Link* where, const IntrusiveUnorderedMultimap* owner) noexcept : m_where(where), m_owner(owner) {}
void assertOwner(const IntrusiveUnorderedMultimap* other) const noexcept
CARB_ASSERT(m_owner == other, "IntrusiveUnorderedMultimap iterator for invalid container");
void assertSameOwner(const const_iterator& rhs) const noexcept
CARB_ASSERT(m_owner == rhs.m_owner, "IntrusiveUnorderedMultimap iterators are from different containers");
void assertNotEnd() const noexcept
CARB_ASSERT(m_where != m_owner->_end(), "Invalid operation on IntrusiveUnorderedMultimap::end() iterator");
const_iterator(Link* where, const IntrusiveUnorderedMultimap*) noexcept : m_where(where) {}
void assertOwner(const IntrusiveUnorderedMultimap* other) const noexcept
void assertSameOwner(const const_iterator& rhs) const noexcept
void assertNotEnd() const noexcept {}
// Advances to the next link in list order.
void incr() noexcept { m_where = m_where->m_next; }
// Mutable forward iterator. It derives from const_iterator and re-exposes dereference and increment with
// non-const `ValueType` references (the key inside ValueType stays const).
class iterator : public const_iterator
using Base = const_iterator;
using iterator_category = std::forward_iterator_tag;
using value_type = ValueType;
using difference_type = ptrdiff_t;
using pointer = value_type*;
using reference = value_type&;
iterator() noexcept = default;
reference operator * () const { this->assertNotEnd(); return this->m_where->value; }
pointer operator -> () const { this->assertNotEnd(); return std::addressof(operator*()); }
iterator& operator ++ () noexcept /* ++iter */ { this->assertNotEnd(); this->incr(); return *this; }
iterator operator ++ (int) noexcept /* iter++ */ { this->assertNotEnd(); iterator i{ *this }; this->incr(); return i; }
friend class IntrusiveUnorderedMultimap;
iterator(Link* where, const IntrusiveUnorderedMultimap* owner) : Base(where, owner) {}
// clang-format on
// Default constructor: the sentinel m_list is initialized to point at itself (an empty ring).
constexpr IntrusiveUnorderedMultimap() : m_list(_end())
// Move constructor: starts as an empty ring.
// NOTE(review): no transfer of `other` is visible in this listing; presumably the body (e.g. a swap with
// `other`) was not captured. Confirm against the full header.
IntrusiveUnorderedMultimap(IntrusiveUnorderedMultimap&& other) noexcept : m_list(_end())
// NOTE(review): nulling the sentinel would corrupt a live container. This line presumably belongs to a
// destructor not shown in this listing, where it makes the sentinel report !isContained() so that the link
// base's "shouldn't be contained at destruction time" check does not fire.
m_list.m_next = m_list.m_prev = nullptr; // Prevents the assert
// Move assignment.
// NOTE(review): only `return *this;` is visible; the exchange with `other` is presumably missing from this listing.
IntrusiveUnorderedMultimap& operator=(IntrusiveUnorderedMultimap&& other) noexcept
return *this;
bool empty() const noexcept
return !m_size;
size_t size() const noexcept
return m_size;
size_t max_size() const noexcept
return size_t(-1);
// Iterator support
// Mutable iterator to the first element in list order; equals end() when the container is empty.
iterator begin() noexcept
{
    return iterator{ _head(), this };
}

// Mutable past-the-end iterator; it wraps the sentinel node.
iterator end() noexcept
{
    return iterator{ _end(), this };
}

// Const iterator to the first element; equals cend() when the container is empty.
const_iterator cbegin() const noexcept
{
    // On a const object only the const begin() overload is viable.
    return begin();
}

// Const past-the-end iterator.
const_iterator cend() const noexcept
{
    return end();
}

// Const overload of begin(), selected when the container is const.
const_iterator begin() const noexcept
{
    return const_iterator{ _head(), this };
}

// Const overload of end(), selected when the container is const.
const_iterator end() const noexcept
{
    return const_iterator{ _end(), this };
}
// Searches this container for `value` by walking its element list (linear in size()).
// Returns an iterator to `value` if it is linked into *this*. Returns end() if it is not linked anywhere, or if
// it is linked into a different container.
iterator locate(T& value) noexcept
{
    if (!_link(value)->isContained())
        return end();
    return iterator{ _listfind(value), this };
}

// Const counterpart of locate(); same contract, returning const iterators.
const_iterator locate(T& value) const noexcept
{
    if (!_link(value)->isContained())
        return cend();
    return const_iterator{ _listfind(value), this };
}
// Converts `value` directly into an iterator without searching: returns an iterator at its link if the link is
// contained, otherwise end(). When contained, `value` must belong to *this* container; that is only verified
// by CARB_ASSERT (via a linear search).
iterator iter_from_value(T& value)
{
    Link* const link = _link(value);
    const bool linked = link->isContained();
    CARB_ASSERT(!linked || _listfind(value) != _end());
    return iterator{ linked ? link : _end(), this };
}

// Const counterpart of iter_from_value(); same contract, returning a const iterator.
const_iterator iter_from_value(T& value) const
{
    Link* const link = _link(value);
    const bool linked = link->isContained();
    CARB_ASSERT(!linked || _listfind(value) != _end());
    return const_iterator{ linked ? link : _end(), this };
}
// Unlinks every element and resets the container to empty. Each element's stored key pair is destroyed; the
// elements themselves are only referenced (T&) and are not destroyed. The bucket array is kept and zeroed.
void clear()
if (_head() != _end())
// NOTE(review): the `do` that pairs with the `} while` below is not shown in this listing.
Link* p = _head();
p->value.~ValueType(); // Destruct the key
// Advance the sentinel past p, then mark p as no longer contained.
m_list.m_next = p->m_next;
p->m_next = p->m_prev = nullptr;
} while (_head() != _end());
m_list.m_prev = _end();
m_size = 0;
// Clear the buckets
// NOTE(review): memset is declared in <cstring>, which this header does not include directly.
memset(m_buckets.get(), 0, sizeof(LinkPair) * m_bucketCount);
// Links the element referenced by value.second into the container and returns an iterator to it.
// The key pair is moved into the element's embedded link. If the target bucket already holds an element with
// an equal key, the new element is inserted immediately before the first such element, so equal keys stay
// adjacent. Otherwise it is appended to the bucket's run, or, for an empty bucket, to the end of the list.
// NOTE(review): this listing shows no size increment and no allocation or growth of m_buckets before it is
// indexed (m_buckets starts null). It also shows no `else` separating the non-empty-bucket path from the
// empty-bucket path below. Presumably those lines exist in the full header; confirm.
iterator insert(ValueType value)
T& val = value.second;
Link* l = _link(val);
// Construct the key
new (&l->value) ValueType(std::move(value));
// Hash
size_t const hash = Hash{}(l->value.first);
// Find insertion point
LinkPair& bucket = m_buckets[_bucket(hash)];
if (bucket.first)
// Need to see if there's a matching value in the bucket so that we group all keys together
Pred pred{};
// One past the bucket's last link in list order.
Link* const end = bucket.second->m_next;
for (Link* p = bucket.first; p != end; p = p->m_next)
if (pred(l->value.first, p->value.first))
// Match! Insert here.
l->m_prev = p->m_prev;
l->m_next = p;
l->m_prev->m_next = l;
l->m_next->m_prev = l;
if (p == bucket.first)
bucket.first = l;
return iterator(l, this);
// Didn't find a match within the bucket. Just add to the end of the bucket
l->m_prev = bucket.second;
l->m_next = end;
l->m_prev->m_next = l;
l->m_next->m_prev = l;
bucket.second = l;
// Insert at end of the list
l->m_prev = _tail();
l->m_next = _end();
l->m_prev->m_next = l;
l->m_next->m_prev = l;
bucket.first = bucket.second = l;
return iterator(l, this);
// Builds a ValueType (std::pair<const Key, T&>) from `ctorArgs` and inserts it via insert().
// The pair's second member must reference the T object to link. Returns an iterator to the new element.
template <class... CtorArgs>
iterator emplace(CtorArgs&&... ctorArgs)
{
    return this->insert(ValueType{ std::forward<CtorArgs>(ctorArgs)... });
}
// Finds an element whose key compares equal to `key` under Pred.
// Returns an iterator to the first such element in its bucket, or end() if there is none.
iterator find(const Key& key)
{
    // Reuse the const lookup, then rewrap its link as a mutable iterator (end() maps to end()).
    const IntrusiveUnorderedMultimap& self = *this;
    return iterator{ self.find(key).m_where, this };
}

// Const counterpart of find(); returns cend() when no element matches.
const_iterator find(const Key& key) const
{
    if (empty())
        return cend();

    const LinkPair& slot = m_buckets[_bucket(Hash{}(key))];
    if (slot.first == nullptr)
        return cend();

    // A bucket is the contiguous list run [slot.first, slot.second]; stop one past its last link.
    Pred equals{};
    Link* const stop = slot.second->m_next;
    for (Link* node = slot.first; node != stop; node = node->m_next)
    {
        if (equals(node->value.first, key))
            return const_iterator{ node, this };
    }
    return cend();
}
// Returns the half-open range of elements whose key compares equal to `key` under Pred. Equal keys are kept
// adjacent within a bucket, so the range is [first match, first following non-match or bucket end).
// Returns (end(), end()) when there is no match.
std::pair<iterator, iterator> equal_range(const Key& key)
if (empty())
return std::make_pair(end(), end());
size_t const hash = Hash{}(key);
LinkPair& pair = m_buckets[_bucket(hash)];
if (!pair.first)
return std::make_pair(end(), end());
Pred pred{};
Link* p = pair.first;
// One past the bucket's last link in list order.
Link* const bucketEnd = pair.second->m_next;
for (; p != bucketEnd; p = p->m_next)
if (pred(p->value.first, key))
// Inner loop: terminates when no longer matches or bucket ends
Link* first = p;
p = p->m_next;
// NOTE(review): as listed, nothing returns when the matching run reaches bucketEnd, and the outer loop would then
// step past the bucket. Presumably a `break` on mismatch was dropped from this listing, with the return below
// sitting after the inner loop; confirm against the full header.
for (; p != bucketEnd; p = p->m_next)
if (!pred(p->value.first, key))
return std::make_pair(iterator(first, this), iterator(p, this));
return std::make_pair(end(), end());
// Const counterpart of equal_range(); same contract, returning const iterators.
std::pair<const_iterator, const_iterator> equal_range(const Key& key) const
if (empty())
return std::make_pair(cend(), cend());
size_t const hash = Hash{}(key);
LinkPair& pair = m_buckets[_bucket(hash)];
if (!pair.first)
return std::make_pair(cend(), cend());
Pred pred{};
Link* p = pair.first;
Link* const bucketEnd = pair.second->m_next;
for (; p != bucketEnd; p = p->m_next)
if (pred(p->value.first, key))
// Inner loop: terminates when no longer matches or bucket ends
Link* first = p;
p = p->m_next;
// NOTE(review): same apparently-missing `break` as in the non-const overload above; confirm.
for (; p != bucketEnd; p = p->m_next)
if (!pred(p->value.first, key))
return std::make_pair(const_iterator(first, this), const_iterator(p, this));
return std::make_pair(cend(), cend());
// Counts the elements whose key compares equal to `key` under Pred, by walking the adjacent run of matches in
// the key's bucket. Returns 0 when there is no match.
size_t count(const Key& key) const
if (empty())
return 0;
size_t const hash = Hash{}(key);
LinkPair& pair = m_buckets[_bucket(hash)];
if (!pair.first)
return 0;
Pred pred{};
Link* p = pair.first;
// One past the bucket's last link in list order.
Link* const bucketEnd = pair.second->m_next;
for (; p != bucketEnd; p = p->m_next)
if (pred(p->value.first, key))
// Inner loop: terminates when no longer matches or bucket ends
size_t count = 1;
p = p->m_next;
// NOTE(review): as listed, `count` is never incremented and nothing returns when the run reaches bucketEnd.
// Presumably `break;` (on mismatch) and `++count;` lines were dropped from this listing, with `return count;`
// after the inner loop; confirm against the full header.
for (; p != bucketEnd; p = p->m_next)
if (!pred(p->value.first, key))
return count;
return 0;
// Unlinks the element at `pos` and returns an iterator to the element that followed it in list order.
// The bucket's [first, second] run is shrunk (or cleared) if the removed link was at either end of it.
// `pos` must not be end(); this is not checked here.
// NOTE(review): this listing shows no `else` between the "empty bucket" and "pair.first = next" assignments,
// no size decrement, and no destructor call after "// Destruct value". Presumably they exist in the full header.
iterator remove(const_iterator pos)
Link* l = pos.m_where;
Link* next = l->m_next;
// Fix up bucket if necessary
LinkPair& pair = m_buckets[_bucket(Hash{}(l->value.first))];
if (pair.first == l)
if (pair.second == l)
// Empty bucket now
pair.first = pair.second = nullptr;
pair.first = next;
else if (pair.second == l)
pair.second = l->m_prev;
// Splice l out of the list and mark it as not contained.
l->m_prev->m_next = l->m_next;
l->m_next->m_prev = l->m_prev;
l->m_next = l->m_prev = nullptr;
// Destruct value
return iterator(next, this);
// Unlinks `value` if its link is contained, and returns `value` either way (not contained => no-op).
// Membership in *this* container, rather than in another one, is only verified by CARB_ASSERT.
// NOTE(review): as with remove(const_iterator), this listing shows no `else` before "pair.first = l->m_next",
// no size decrement, and no destructor call after "// Destruct value"; confirm against the full header.
T& remove(T& value)
Link* l = _link(value);
if (l->isContained())
CARB_ASSERT(_listfind(value) != _end());
// Fix up bucket if necessary
LinkPair& pair = m_buckets[_bucket(Hash{}(l->value.first))];
if (pair.first == l)
if (pair.second == l)
// Empty bucket now
pair.first = pair.second = nullptr;
pair.first = l->m_next;
else if (pair.second == l)
pair.second = l->m_prev;
// Splice l out of the list and mark it as not contained.
l->m_prev->m_next = l->m_next;
l->m_next->m_prev = l->m_prev;
l->m_next = l->m_prev = nullptr;
// Destruct value
return value;
// Unlinks every element whose key compares equal to `key` and returns how many were removed.
// NOTE(review): the loop body is not shown in this listing. As listed, the loop would never terminate for a
// non-empty range; presumably it removes pair.first (advancing it) and increments `count`. Confirm.
size_t remove(const Key& key)
size_t count{ 0 };
auto pair = equal_range(key);
while (pair.first != pair.second)
return count;
// Exchanges the contents of *this and `other` in O(1). The bucket arrays, sentinel links, sizes and max load
// factors are swapped; no element is moved or relinked individually. Swapping with itself is a no-op.
void swap(IntrusiveUnorderedMultimap& other) noexcept
{
    if (this != std::addressof(other))
    {
        // Fix up the end iterators first. The first and last elements of each list point back at their own
        // sentinel, so retarget them at the *other* container's sentinel. When a list is empty, _head() and
        // _tail() are the sentinel itself, so these references alias that sentinel's own m_prev/m_next. The
        // sentinel-pointer swaps below then leave each empty sentinel pointing at itself again.
        // (The second declarator was previously corrupted to `*<ail`, leaving `ltail` undeclared.)
        Link *&lhead = _head()->m_prev, *&ltail = _tail()->m_next;
        Link *&rhead = other._head()->m_prev, *&rtail = other._tail()->m_next;
        lhead = ltail = other._end();
        rhead = rtail = _end();

        // Now swap everything else
        std::swap(m_buckets, other.m_buckets);
        std::swap(m_bucketCount, other.m_bucketCount);
        std::swap(_end()->m_next, other._end()->m_next);
        std::swap(_end()->m_prev, other._end()->m_prev);
        std::swap(m_size, other.m_size);
        std::swap(m_maxLoadFactor, other.m_maxLoadFactor);
    }
}
size_t bucket_count() const noexcept
return m_bucketCount;
size_t max_bucket_count() const noexcept
return size_t(-1);
size_t bucket(const Key& key) const
return _bucket(Hash{}(key));
float load_factor() const
return bucket_count() ? float(size()) / float(bucket_count()) : 0.f;
float max_load_factor() const noexcept
return m_maxLoadFactor;
void max_load_factor(float ml)
CARB_ASSERT(ml > 0.f);
m_maxLoadFactor = ml;
void reserve(size_t count)
rehash(size_t(std::ceil(float(count) / max_load_factor())));
// Grows the bucket array to at least `buckets` buckets, rounded up to a power of two and at least 8, then
// redistributes every element. It never shrinks: requests not larger than the current count are ignored.
// Element order is rebuilt by walking the old list from the tail and inserting each element at the front of
// its bucket (or of the whole list for a newly-used bucket), so runs of equal keys stay contiguous.
// NOTE(review): the `else` introducing the "Insert at the front of the list" path is not shown in this
// listing. memset is declared in <cstring>, which this header does not include directly.
void rehash(size_t buckets)
if (buckets > m_bucketCount)
constexpr static size_t kMinBuckets(8);
static_assert(carb::cpp::has_single_bit(kMinBuckets), "Invalid assumption");
buckets = carb::cpp::bit_ceil(::carb_max(buckets, kMinBuckets));
m_buckets.reset(new LinkPair[buckets]);
memset(m_buckets.get(), 0, sizeof(LinkPair) * buckets);
m_bucketCount = buckets;
// Walk through the list backwards and rehash everything. Things that have equal keys and are already
// grouped together will remain so.
Link* cur = _tail();
// Detach the sentinel; elements are re-linked one by one below.
m_list.m_prev = m_list.m_next = _end();
Link* next;
Hash hasher;
for (; cur != _end(); cur = next)
// Save the predecessor before cur's pointers are overwritten.
next = cur->m_prev;
LinkPair& bucket = m_buckets[_bucket(hasher(cur->value.first))];
if (bucket.first)
// Insert in front of whatever was in the bucket
cur->m_prev = bucket.first->m_prev;
cur->m_next = bucket.first;
cur->m_prev->m_next = cur;
cur->m_next->m_prev = cur;
bucket.first = cur;
// Insert at the front of the list and the beginning of the bucket
cur->m_prev = _end();
cur->m_next = _head();
cur->m_prev->m_next = cur;
cur->m_next->m_prev = cur;
bucket.first = bucket.second = cur;
// Per-bucket record: the first and last link of the bucket's contiguous run in the element list.
// Both are nullptr when the bucket is empty (the bucket array is zero-filled on allocation and on clear()).
struct LinkPair
Link* first;
Link* second;
// Bucket array; null until the first rehash() allocates it.
std::unique_ptr<LinkPair[]> m_buckets{};
// Number of entries in m_buckets; 0 or a power of two.
size_t m_bucketCount{ 0 };
// Sentinel node closing the circular element list; its address serves as end().
BaseLink m_list;
size_t m_size{ 0 };
float m_maxLoadFactor{ 1.f };
size_t _bucket(size_t hash) const
// bucket count is always a power of 2
return hash & (m_bucketCount - 1);
// Linear search of this container's list for the link embedded in `value`.
// Returns that link if it is in this list, otherwise the sentinel _end().
Link* _listfind(T& value) const
{
    Link* const target = _link(value);
    Link* node = _head();
    while (node != _end() && node != target)
        node = node->m_next;
    return node;
}

// Returns the link member of `value` selected by the pointer-to-member U.
static Link* _link(T& value) noexcept
{
    Link& member = value.*U;
    return std::addressof(member);
}
// Recovers the element that embeds link `l` by subtracting the byte offset of member U within T from the
// link's address (the inverse of _link()).
// NOTE(review): the offset is computed by forming a member address through a null T*, which the standard does
// not sanctify; this relies on compiler behavior. No caller of _value() is visible in this listing.
static T& _value(Link& l) noexcept
// Need to calculate the offset of our link member which will allow adjusting the pointer to where T is.
// This will not work if T uses virtual inheritance. Also, offsetof() cannot be used because we have a pointer
// to the member
size_t offset = size_t(reinterpret_cast<uint8_t*>(&(((T*)0)->*U)));
return *reinterpret_cast<T*>(reinterpret_cast<uint8_t*>(std::addressof(l)) - offset);
// Const counterpart of _value().
static const T& _value(const Link& l) noexcept
// Need to calculate the offset of our link member which will allow adjusting the pointer to where T is.
// This will not work if T uses virtual inheritance. Also, offsetof() cannot be used because we have a pointer
// to the member
size_t offset = size_t(reinterpret_cast<uint8_t*>(&(((T*)0)->*U)));
return *reinterpret_cast<const T*>(reinterpret_cast<const uint8_t*>(std::addressof(l)) - offset);
// First element in list order, or the sentinel _end() when the container is empty.
constexpr Link* _head() const noexcept
{
    return m_list.m_next;
}

// Last element in list order, or the sentinel _end() when the container is empty.
constexpr Link* _tail() const noexcept
{
    return m_list.m_prev;
}

// The sentinel node m_list, viewed as a Link and used as the past-the-end position. m_list is only a BaseLink,
// so only its list pointers may be accessed; it carries no stored value.
constexpr Link* _end() const noexcept
{
    return const_cast<Link*>(static_cast<const Link*>(&m_list));
}
} // namespace container
} // namespace carb