// carb/container/IntrusiveUnorderedMultimap.h
// File members: carb/container/IntrusiveUnorderedMultimap.h
// Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto. Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.
//
#pragma once
#include "../Defines.h"
#include "../cpp/Bit.h"
#include <cmath>
#include <functional>
#include <iterator>
#include <memory>
namespace carb
{
namespace container
{
template <class Key, class T>
class IntrusiveUnorderedMultimapLink;
template <class Key, class T, IntrusiveUnorderedMultimapLink<Key, T> T::*U, class Hash, class Pred>
class IntrusiveUnorderedMultimap;
#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace detail
{
//! Empty placeholder type with a user-provided default constructor.
//! Because the constructor is user-provided rather than defaulted, the type is not trivially
//! default-constructible; the static_assert below pins that assumption.
struct NontrivialDummyType
{
    constexpr NontrivialDummyType() noexcept {}
};
static_assert(!std::is_trivially_default_constructible<NontrivialDummyType>::value, "Invalid assumption");
template <class Key, class T>
class IntrusiveUnorderedMultimapLinkBase
{
public:
    //! Key type as stored in a link (const: it may not change while contained).
    using KeyType = const Key;
    //! The element type that embeds the link.
    using MappedType = T;
    //! The (key, element reference) pair stored by a contained link.
    using ValueType = std::pair<const Key, T&>;

    constexpr IntrusiveUnorderedMultimapLinkBase() noexcept = default;

    //! Asserts that the link has already been removed from any container.
    ~IntrusiveUnorderedMultimapLinkBase()
    {
        CARB_ASSERT(!isContained());
    }

    //! Returns true while the link is part of a container's list (i.e. its next pointer is set).
    bool isContained() const noexcept
    {
        return nullptr != m_next;
    }

private:
    template <class Key2, class U, IntrusiveUnorderedMultimapLink<Key2, U> U::*V, class Hash, class Pred>
    friend class ::carb::container::IntrusiveUnorderedMultimap;

    // Neighbours in the container's doubly-linked list; both are null when not contained.
    IntrusiveUnorderedMultimapLink<Key, T>* m_next{ nullptr };
    IntrusiveUnorderedMultimapLink<Key, T>* m_prev{ nullptr };

    CARB_PREVENT_COPY_AND_MOVE(IntrusiveUnorderedMultimapLinkBase);

    // Initializes both neighbour pointers to `init`; the container uses this so its list head refers to itself.
    constexpr IntrusiveUnorderedMultimapLinkBase(IntrusiveUnorderedMultimapLink<Key, T>* init) noexcept
        : m_next(init), m_prev(init)
    {
    }
};
} // namespace detail
#endif
template <class Key, class T>
class IntrusiveUnorderedMultimapLink : public detail::IntrusiveUnorderedMultimapLinkBase<Key, T>
{
    using Base = detail::IntrusiveUnorderedMultimapLinkBase<Key, T>;

public:
    //! Constructs an unlinked link. Only the placeholder union member is constructed here; the stored
    //! key/value pair is constructed by the container upon insertion.
    constexpr IntrusiveUnorderedMultimapLink() noexcept : empty{} {}

    //! Asserts that the link has already been removed from any container. The union's `value` member is not
    //! destroyed here; the container destroys it when the element is removed.
    ~IntrusiveUnorderedMultimapLink()
    {
        CARB_ASSERT(!this->isContained());
    }

private:
    template <class Key2, class U, IntrusiveUnorderedMultimapLink<Key2, U> U::*V, class Hash, class Pred>
    friend class IntrusiveUnorderedMultimap;

    // `empty` is active while the link is not contained; `value` is active while it is.
    union
    {
        detail::NontrivialDummyType empty;
        typename Base::ValueType value;
    };

    CARB_PREVENT_COPY_AND_MOVE(IntrusiveUnorderedMultimapLink);
};
//! An intrusive unordered multimap.
//!
//! Each element of type T embeds an IntrusiveUnorderedMultimapLink<Key, T> (the member designated by U). While
//! contained, the link stores the element's key and the list pointers used by this container. The container never
//! allocates, copies or destroys elements; the only memory it allocates is its bucket array. Elements whose keys
//! compare equal are kept adjacent in iteration order.
//!
//! @tparam Key  Key type.
//! @tparam T    Element type containing the link member.
//! @tparam U    Pointer-to-member designating the link inside T.
//! @tparam Hash Default-constructible hash functor for Key.
//! @tparam Pred Default-constructible equality functor for Key.
template <class Key, class T, IntrusiveUnorderedMultimapLink<Key, T> T::*U, class Hash = ::std::hash<Key>, class Pred = ::std::equal_to<Key>>
class IntrusiveUnorderedMultimap
{
    using BaseLink = detail::IntrusiveUnorderedMultimapLinkBase<Key, T>;

public:
    //! Key type (const: keys cannot be changed while an element is contained).
    using KeyType = const Key;
    //! Element type.
    using MappedType = T;
    //! Link type that T must contain.
    using Link = IntrusiveUnorderedMultimapLink<Key, T>;
    //! Type produced by dereferencing an iterator: the stored key and a reference to the element.
    using ValueType = typename Link::ValueType;

    // Iterator support
    // clang-format off
    //! Forward iterator over the (key, element) pairs of the container; yields const references.
    class const_iterator
    {
    public:
#ifndef DOXYGEN_SHOULD_SKIP_THIS
        using iterator_category = std::forward_iterator_tag;
        using value_type = ValueType;
        using difference_type = ptrdiff_t;
        using pointer = const value_type*;
        using reference = const value_type&;
        const_iterator() noexcept = default;
        reference operator * () const { assertNotEnd(); return m_where->value; }
        pointer operator -> () const { assertNotEnd(); return std::addressof(operator*()); }
        const_iterator& operator ++ () noexcept /* ++iter */ { assertNotEnd(); incr(); return *this; }
        const_iterator operator ++ (int) noexcept /* iter++ */ { assertNotEnd(); const_iterator i{ *this }; incr(); return i; }
        bool operator == (const const_iterator& rhs) const noexcept { assertSameOwner(rhs); return m_where == rhs.m_where; }
        bool operator != (const const_iterator& rhs) const noexcept { assertSameOwner(rhs); return m_where != rhs.m_where; }
    protected:
        friend class IntrusiveUnorderedMultimap;
        Link* m_where{ nullptr };
#if CARB_ASSERT_ENABLED
        const IntrusiveUnorderedMultimap* m_owner{ nullptr };
        const_iterator(Link* where, const IntrusiveUnorderedMultimap* owner) noexcept : m_where(where), m_owner(owner) {}
        void assertOwner(const IntrusiveUnorderedMultimap* other) const noexcept
        {
            CARB_ASSERT(m_owner == other, "IntrusiveUnorderedMultimap iterator for invalid container");
        }
        void assertSameOwner(const const_iterator& rhs) const noexcept
        {
            CARB_ASSERT(m_owner == rhs.m_owner, "IntrusiveUnorderedMultimap iterators are from different containers");
        }
        void assertNotEnd() const noexcept
        {
            CARB_ASSERT(m_where != m_owner->_end(), "Invalid operation on IntrusiveUnorderedMultimap::end() iterator");
        }
#else
        const_iterator(Link* where, const IntrusiveUnorderedMultimap*) noexcept : m_where(where) {}
        void assertOwner(const IntrusiveUnorderedMultimap* other) const noexcept
        {
            CARB_UNUSED(other);
        }
        void assertSameOwner(const const_iterator& rhs) const noexcept
        {
            CARB_UNUSED(rhs);
        }
        void assertNotEnd() const noexcept {}
#endif // !CARB_ASSERT_ENABLED
        void incr() noexcept { m_where = m_where->m_next; }
#endif // !DOXYGEN_SHOULD_SKIP_THIS
    };
    //! Forward iterator over the (key, element) pairs of the container.
    class iterator : public const_iterator
    {
        using Base = const_iterator;
    public:
#ifndef DOXYGEN_SHOULD_SKIP_THIS
        using iterator_category = std::forward_iterator_tag;
        using value_type = ValueType;
        using difference_type = ptrdiff_t;
        using pointer = value_type*;
        using reference = value_type&;
        iterator() noexcept = default;
        reference operator * () const { this->assertNotEnd(); return this->m_where->value; }
        pointer operator -> () const { this->assertNotEnd(); return std::addressof(operator*()); }
        iterator& operator ++ () noexcept /* ++iter */ { this->assertNotEnd(); this->incr(); return *this; }
        iterator operator ++ (int) noexcept /* iter++ */ { this->assertNotEnd(); iterator i{ *this }; this->incr(); return i; }
    private:
        friend class IntrusiveUnorderedMultimap;
        iterator(Link* where, const IntrusiveUnorderedMultimap* owner) noexcept : Base(where, owner) {}
#endif
    };
    // clang-format on

    CARB_PREVENT_COPY(IntrusiveUnorderedMultimap);

    //! Constructs an empty container. No bucket array is allocated until one is needed.
    constexpr IntrusiveUnorderedMultimap() : m_list(_end())
    {
    }

    //! Takes over all elements (and the bucket array) of @p other, which is left empty.
    IntrusiveUnorderedMultimap(IntrusiveUnorderedMultimap&& other) noexcept : m_list(_end())
    {
        swap(other);
    }

    //! Removes all elements (see clear()); the elements themselves are not destroyed.
    ~IntrusiveUnorderedMultimap()
    {
        clear();
        // The list head points at itself when empty; null it so the link base destructor's assert does not fire.
        m_list.m_next = m_list.m_prev = nullptr;
    }

    //! Swaps contents with @p other; @p other receives this container's previous elements.
    IntrusiveUnorderedMultimap& operator=(IntrusiveUnorderedMultimap&& other) noexcept
    {
        swap(other);
        return *this;
    }

    //! Returns true if the container holds no elements.
    bool empty() const noexcept
    {
        return !m_size;
    }

    //! Returns the number of contained elements.
    size_t size() const noexcept
    {
        return m_size;
    }

    //! Returns the largest representable size.
    size_t max_size() const noexcept
    {
        return size_t(-1);
    }

    // Iterator support
    iterator begin() noexcept
    {
        return iterator(_head(), this);
    }
    iterator end() noexcept
    {
        return iterator(_end(), this);
    }
    const_iterator cbegin() const noexcept
    {
        return const_iterator(_head(), this);
    }
    const_iterator cend() const noexcept
    {
        return const_iterator(_end(), this);
    }
    const_iterator begin() const noexcept
    {
        return cbegin();
    }
    const_iterator end() const noexcept
    {
        return cend();
    }

    //! Returns an iterator to @p value if it is contained in *this* container, otherwise end().
    //! Linear in size(): the whole list is searched to confirm membership.
    iterator locate(T& value) noexcept
    {
        Link* l = _link(value);
        return iterator(l->isContained() ? _listfind(value) : _end(), this);
    }

    //! Const overload of locate().
    const_iterator locate(T& value) const noexcept
    {
        Link* l = _link(value);
        return const_iterator(l->isContained() ? _listfind(value) : _end(), this);
    }

    //! Returns an iterator to @p value in constant time, or end() if @p value is not contained anywhere.
    //! If @p value is contained, it must be contained in this container (checked only by assert).
    iterator iter_from_value(T& value)
    {
        Link* l = _link(value);
        CARB_ASSERT(!l->isContained() || _listfind(value) != _end());
        return iterator(l->isContained() ? l : _end(), this);
    }

    //! Const overload of iter_from_value().
    const_iterator iter_from_value(T& value) const
    {
        Link* l = _link(value);
        CARB_ASSERT(!l->isContained() || _listfind(value) != _end());
        return const_iterator(l->isContained() ? l : _end(), this);
    }

    //! Removes all elements. Each element's stored key is destroyed and its link reset; the elements themselves
    //! are not destroyed. The bucket array is kept but emptied.
    void clear()
    {
        if (_head() == _end())
        {
            return;
        }
        while (_head() != _end())
        {
            Link* p = _head();
            p->value.~ValueType(); // Destruct the stored key/value pair
            m_list.m_next = p->m_next;
            p->m_next = p->m_prev = nullptr;
        }
        m_list.m_prev = _end();
        m_size = 0;
        memset(m_buckets.get(), 0, sizeof(LinkPair) * m_bucketCount);
    }

    //! Inserts the element referenced by `value.second` under the key `value.first`.
    //!
    //! The element must not currently be contained in any container. If elements with an equal key are already
    //! present, the new element is placed immediately before the first of them; otherwise it goes at the end of
    //! its bucket, or at the end of the list if the bucket is empty.
    //!
    //! All steps that may throw (hashing, growing the bucket array, key comparison and constructing the stored
    //! pair) happen before the list or size() are modified, so on an exception no element is added. The bucket
    //! array may still have grown. NOTE: this assumes Hash does not throw while rehashing.
    //! @returns an iterator to the inserted element.
    iterator insert(ValueType value)
    {
        T& val = value.second;
        Link* l = _link(val);
        CARB_ASSERT(!l->isContained());

        size_t const hash = Hash{}(value.first);
        reserve(size() + 1);
        LinkPair& bucket = m_buckets[_bucket(hash)];

        // Look for an element with an equal key so that all equal keys stay grouped together.
        Link* match = nullptr;
        if (bucket.first)
        {
            Pred pred{};
            Link* const bucketEnd = bucket.second->m_next;
            for (Link* p = bucket.first; p != bucketEnd; p = p->m_next)
            {
                if (pred(value.first, p->value.first))
                {
                    match = p;
                    break;
                }
            }
        }

        // Construct the stored pair. This is the last step that may throw; everything after it cannot fail.
        new (&l->value) ValueType(std::move(value));

        if (match)
        {
            // Insert in front of the first equal key.
            _linkBefore(l, match);
            if (match == bucket.first)
            {
                bucket.first = l;
            }
        }
        else if (bucket.first)
        {
            // No equal key in this bucket: append to the end of the bucket.
            _linkBefore(l, bucket.second->m_next);
            bucket.second = l;
        }
        else
        {
            // Empty bucket: append to the end of the list.
            _linkBefore(l, _end());
            bucket.first = bucket.second = l;
        }
        ++m_size;
        return iterator(l, this);
    }

    //! Constructs a ValueType from @p args and inserts it (see insert()).
    template <class... Args>
    iterator emplace(Args&&... args)
    {
        return insert(ValueType{ std::forward<Args>(args)... });
    }

    //! Returns an iterator to the first element whose key equals @p key, or end() if there is none.
    iterator find(const Key& key)
    {
        Link* const match = _findFirst(key);
        return match ? iterator(match, this) : end();
    }

    //! Const overload of find().
    const_iterator find(const Key& key) const
    {
        Link* const match = _findFirst(key);
        return match ? const_iterator(match, this) : cend();
    }

    //! Returns the range of elements whose key equals @p key, or (end(), end()) if there are none.
    std::pair<iterator, iterator> equal_range(const Key& key)
    {
        Link* bucketEnd = nullptr;
        Link* const first = _findFirst(key, &bucketEnd);
        if (!first)
        {
            return std::make_pair(end(), end());
        }
        return std::make_pair(iterator(first, this), iterator(_groupEnd(first, bucketEnd, key), this));
    }

    //! Const overload of equal_range().
    std::pair<const_iterator, const_iterator> equal_range(const Key& key) const
    {
        Link* bucketEnd = nullptr;
        Link* const first = _findFirst(key, &bucketEnd);
        if (!first)
        {
            return std::make_pair(cend(), cend());
        }
        return std::make_pair(const_iterator(first, this), const_iterator(_groupEnd(first, bucketEnd, key), this));
    }

    //! Returns the number of elements whose key equals @p key.
    size_t count(const Key& key) const
    {
        Link* bucketEnd = nullptr;
        Link* const first = _findFirst(key, &bucketEnd);
        if (!first)
        {
            return 0;
        }
        Link* const last = _groupEnd(first, bucketEnd, key);
        size_t n = 0;
        for (Link* p = first; p != last; p = p->m_next)
        {
            ++n;
        }
        return n;
    }

    //! Removes the element at @p pos, which must be a dereferenceable iterator of this container. The element's
    //! stored key is destroyed; the element itself is not.
    //! @returns an iterator to the element that followed the removed one.
    iterator remove(const_iterator pos)
    {
        CARB_ASSERT(!empty());
        pos.assertNotEnd();
        pos.assertOwner(this);
        return iterator(_unlink(pos.m_where), this);
    }

    //! Removes @p value if it is contained (it must then be contained in this container). Does nothing otherwise.
    //! @returns @p value.
    T& remove(T& value)
    {
        Link* l = _link(value);
        if (l->isContained())
        {
            CARB_ASSERT(!empty());
            CARB_ASSERT(_listfind(value) != _end());
            _unlink(l);
        }
        return value;
    }

    //! Removes every element whose key equals @p key.
    //! @returns the number of elements removed.
    size_t remove(const Key& key)
    {
        auto range = equal_range(key);
        size_t removed = 0;
        for (Link* p = range.first.m_where; p != range.second.m_where; ++removed)
        {
            p = _unlink(p);
        }
        return removed;
    }

    //! Exchanges the contents (elements, bucket array, size and max load factor) with @p other.
    void swap(IntrusiveUnorderedMultimap& other) noexcept
    {
        if (this == std::addressof(other))
        {
            return;
        }
        // First re-point each list's first/last element at the *other* container's list head. Bind all four
        // references before assigning any of them. For an empty list, head and tail are the list head itself,
        // so these references alias that container's own m_list links; the m_list swap below fixes those up.
        Link*& ourHeadPrev = _head()->m_prev;
        Link*& ourTailNext = _tail()->m_next;
        Link*& theirHeadPrev = other._head()->m_prev;
        Link*& theirTailNext = other._tail()->m_next;
        ourHeadPrev = ourTailNext = other._end();
        theirHeadPrev = theirTailNext = _end();

        // Now swap everything else.
        std::swap(m_buckets, other.m_buckets);
        std::swap(m_bucketCount, other.m_bucketCount);
        std::swap(_end()->m_next, other._end()->m_next);
        std::swap(_end()->m_prev, other._end()->m_prev);
        std::swap(m_size, other.m_size);
        std::swap(m_maxLoadFactor, other.m_maxLoadFactor);
    }

    //! Returns the number of buckets (0 until the first allocation, otherwise a power of two).
    size_t bucket_count() const noexcept
    {
        return m_bucketCount;
    }

    //! Returns the largest representable bucket count.
    size_t max_bucket_count() const noexcept
    {
        return size_t(-1);
    }

    //! Returns the index of the bucket that @p key hashes to. Requires bucket_count() > 0.
    size_t bucket(const Key& key) const
    {
        return _bucket(Hash{}(key));
    }

    //! Returns size() / bucket_count(), or 0 if there are no buckets.
    float load_factor() const
    {
        return bucket_count() ? float(size()) / float(bucket_count()) : 0.f;
    }

    //! Returns the maximum load factor used by reserve().
    float max_load_factor() const noexcept
    {
        return m_maxLoadFactor;
    }

    //! Sets the maximum load factor (must be positive). Takes effect on the next reserve()/insertion.
    void max_load_factor(float ml)
    {
        CARB_ASSERT(ml > 0.f);
        m_maxLoadFactor = ml;
    }

    //! Ensures enough buckets for @p count elements without exceeding max_load_factor().
    void reserve(size_t count)
    {
        // Computed in double: float cannot represent large element counts exactly.
        rehash(size_t(std::ceil(double(count) / double(max_load_factor()))));
    }

    //! Grows the bucket array to at least @p buckets (rounded up to a power of two, minimum 8) and redistributes
    //! all elements, keeping equal keys adjacent. Never shrinks the bucket array.
    void rehash(size_t buckets)
    {
        if (buckets <= m_bucketCount)
        {
            return;
        }
        constexpr static size_t kMinBuckets(8);
        static_assert(carb::cpp::has_single_bit(kMinBuckets), "Invalid assumption");
        buckets = carb::cpp::bit_ceil(::carb_max(buckets, kMinBuckets));
        CARB_ASSERT(carb::cpp::has_single_bit(buckets));
        m_buckets.reset(new LinkPair[buckets]()); // value-initialized: every bucket starts empty
        m_bucketCount = buckets;

        // Walk the old list backwards and re-link every node at the front of its bucket (or the front of the list
        // if its bucket is still empty). Elements with equal keys are adjacent in the old list, so they remain so.
        // NOTE: a throwing Hash here would leave the list half rebuilt.
        Link* cur = _tail();
        m_list.m_prev = m_list.m_next = _end();
        Hash hasher;
        while (cur != _end())
        {
            Link* const prev = cur->m_prev;
            LinkPair& bucket = m_buckets[_bucket(hasher(cur->value.first))];
            if (bucket.first)
            {
                _linkBefore(cur, bucket.first);
                bucket.first = cur;
            }
            else
            {
                _linkBefore(cur, _head());
                bucket.first = bucket.second = cur;
            }
            cur = prev;
        }
    }

private:
    // First and last node of a bucket's contiguous run in the list; both null for an empty bucket.
    struct LinkPair
    {
        Link* first;
        Link* second;
    };

    std::unique_ptr<LinkPair[]> m_buckets{};
    size_t m_bucketCount{ 0 };
    BaseLink m_list; // List head: m_next is the first element, m_prev the last; points at itself when empty
    size_t m_size{ 0 };
    float m_maxLoadFactor{ 1.f };

    size_t _bucket(size_t hash) const
    {
        // bucket count is always a power of 2
        return hash & (m_bucketCount - 1);
    }

    // Splices the unlinked node `l` into the list immediately before `next` (which may be _end()).
    static void _linkBefore(Link* l, Link* next) noexcept
    {
        l->m_prev = next->m_prev;
        l->m_next = next;
        l->m_prev->m_next = l;
        next->m_prev = l;
    }

    // Detaches the contained node `l` from the list and its bucket, decrements the size, destroys the stored
    // key/value pair and returns the node that followed `l`.
    Link* _unlink(Link* l)
    {
        Link* const next = l->m_next;
        LinkPair& bucket = m_buckets[_bucket(Hash{}(l->value.first))];
        if (bucket.first == l)
        {
            if (bucket.second == l)
            {
                // Bucket is empty now
                bucket.first = bucket.second = nullptr;
            }
            else
            {
                bucket.first = next;
            }
        }
        else if (bucket.second == l)
        {
            bucket.second = l->m_prev;
        }
        l->m_prev->m_next = next;
        next->m_prev = l->m_prev;
        l->m_next = l->m_prev = nullptr;
        --m_size;
        l->value.~ValueType(); // Destruct the stored key/value pair
        return next;
    }

    // Returns the first node in `key`'s bucket whose key equals `key`, or nullptr. When a node is found and
    // `bucketEnd` is non-null, *bucketEnd receives the node following the bucket's run.
    Link* _findFirst(const Key& key, Link** bucketEnd = nullptr) const
    {
        if (empty())
        {
            return nullptr;
        }
        LinkPair& bucket = m_buckets[_bucket(Hash{}(key))];
        if (!bucket.first)
        {
            return nullptr;
        }
        Pred pred{};
        Link* const stop = bucket.second->m_next;
        for (Link* p = bucket.first; p != stop; p = p->m_next)
        {
            if (pred(p->value.first, key))
            {
                if (bucketEnd)
                {
                    *bucketEnd = stop;
                }
                return p;
            }
        }
        return nullptr;
    }

    // Given `first` (a node whose key equals `key`), returns the node following the run of equal keys,
    // scanning no further than `bucketEnd`.
    Link* _groupEnd(Link* first, Link* bucketEnd, const Key& key) const
    {
        Pred pred{};
        Link* p = first->m_next;
        while (p != bucketEnd && pred(p->value.first, key))
        {
            p = p->m_next;
        }
        return p;
    }

    // Linear search of the list for `value`'s link; returns it, or _end() if it is not in this container.
    Link* _listfind(T& value) const
    {
        Link* find = _link(value);
        for (Link* p = _head(); p != _end(); p = p->m_next)
        {
            if (p == find)
            {
                return p;
            }
        }
        return _end();
    }

    static Link* _link(T& value) noexcept
    {
        return std::addressof(value.*U);
    }

    static T& _value(Link& l) noexcept
    {
        // Need to calculate the offset of our link member which will allow adjusting the pointer to where T is.
        // This will not work if T uses virtual inheritance. Also, offsetof() cannot be used because we have a pointer
        // to the member
        size_t offset = size_t(reinterpret_cast<uint8_t*>(&(((T*)0)->*U)));
        return *reinterpret_cast<T*>(reinterpret_cast<uint8_t*>(std::addressof(l)) - offset);
    }

    static const T& _value(const Link& l) noexcept
    {
        // Need to calculate the offset of our link member which will allow adjusting the pointer to where T is.
        // This will not work if T uses virtual inheritance. Also, offsetof() cannot be used because we have a pointer
        // to the member
        size_t offset = size_t(reinterpret_cast<uint8_t*>(&(((T*)0)->*U)));
        return *reinterpret_cast<const T*>(reinterpret_cast<const uint8_t*>(std::addressof(l)) - offset);
    }

    constexpr Link* _head() const noexcept
    {
        return const_cast<Link*>(m_list.m_next);
    }

    constexpr Link* _tail() const noexcept
    {
        return const_cast<Link*>(m_list.m_prev);
    }

    // The list head doubles as the past-the-end node.
    constexpr Link* _end() const noexcept
    {
        return static_cast<Link*>(const_cast<BaseLink*>(&m_list));
    }
};
} // namespace container
} // namespace carb