carb/container/IntrusiveUnorderedMultimap.h

File members: carb/container/IntrusiveUnorderedMultimap.h

// Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto. Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.
//

#pragma once

#include "../Defines.h"

#include "../cpp/Bit.h"

#include <cmath>
#include <functional>
#include <iterator>
#include <memory>

namespace carb
{
namespace container
{

// The link and the container refer to each other, so both are declared before either is defined.
template <class K, class V>
class IntrusiveUnorderedMultimapLink;

template <class K, class V, IntrusiveUnorderedMultimapLink<K, V> V::*LinkMember, class H, class P>
class IntrusiveUnorderedMultimap;

#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace detail
{

// An empty type whose default constructor is user-provided, and therefore not trivial. It serves as the
// placeholder member of the union inside IntrusiveUnorderedMultimapLink, giving the link's constructor a union
// member that it can initialize.
struct NontrivialDummyType
{
    constexpr NontrivialDummyType() noexcept {}
};
static_assert(!std::is_trivially_default_constructible<NontrivialDummyType>::value, "Invalid assumption");

template <class Key, class T>
class IntrusiveUnorderedMultimapLinkBase
{
public:
    using KeyType = const Key;
    using MappedType = T;
    using ValueType = std::pair<const Key, T&>;

    constexpr IntrusiveUnorderedMultimapLinkBase() noexcept = default;
    ~IntrusiveUnorderedMultimapLinkBase()
    {
        // Shouldn't be contained at destruction time
        CARB_ASSERT(!isContained());
    }

    bool isContained() const noexcept
    {
        return m_next != nullptr;
    }

private:
    template <class Key2, class U, IntrusiveUnorderedMultimapLink<Key2, U> U::*V, class Hash, class Pred>
    friend class ::carb::container::IntrusiveUnorderedMultimap;

    IntrusiveUnorderedMultimapLink<Key, T>* m_next{ nullptr };
    IntrusiveUnorderedMultimapLink<Key, T>* m_prev{ nullptr };

    CARB_PREVENT_COPY_AND_MOVE(IntrusiveUnorderedMultimapLinkBase);

    constexpr IntrusiveUnorderedMultimapLinkBase(IntrusiveUnorderedMultimapLink<Key, T>* init) noexcept
        : m_next(init), m_prev(init)
    {
    }
};

} // namespace detail
#endif

// The member that a type `T` must embed in order to be stored in an IntrusiveUnorderedMultimap<Key, T, ...>.
// Besides the list pointers inherited from the base, it provides storage for the element's key together with a
// reference back to the embedding `T`. The container constructs that storage (placement new) on insertion and
// destroys it explicitly on removal; otherwise it is left unconstructed.
template <class Key, class T>
class IntrusiveUnorderedMultimapLink : public detail::IntrusiveUnorderedMultimapLinkBase<Key, T>
{
    using Base = detail::IntrusiveUnorderedMultimapLinkBase<Key, T>;

public:
    // Creates an uncontained link whose value storage is not yet constructed.
    constexpr IntrusiveUnorderedMultimapLink() noexcept : empty()
    {
    }

    ~IntrusiveUnorderedMultimapLink()
    {
        // The owning object must be removed from its container before it is destroyed.
        CARB_ASSERT(!Base::isContained());
    }

private:
    template <class OtherKey, class OtherT, IntrusiveUnorderedMultimapLink<OtherKey, OtherT> OtherT::*Member, class OtherHash, class OtherPred>
    friend class IntrusiveUnorderedMultimap;

    // The union lets the container construct/destroy `value` manually. `ValueType` contains a reference and so
    // cannot be default-constructed; the constructor initializes the dummy member instead.
    union
    {
        detail::NontrivialDummyType empty;
        typename Base::ValueType value;
    };

    CARB_PREVENT_COPY_AND_MOVE(IntrusiveUnorderedMultimapLink);
};

//! An intrusive, unordered multimap.
//!
//! Values of type `T` are neither owned nor allocated by the container. Each `T` embeds an
//! `IntrusiveUnorderedMultimapLink<Key, T>`, selected by the pointer-to-member `U`, and inserting a value constructs
//! a copy of its key inside that link. All contained links form one doubly-linked list that is closed by a sentinel
//! (`m_list`). Each bucket records the first and last link of a contiguous run of that list, and links with equal
//! keys are kept adjacent. The only allocation the container makes is its bucket array.
//!
//! @tparam Key  Key type; a copy is stored in the link of every contained value.
//! @tparam T    The type that embeds the link; iterators expose references to it.
//! @tparam U    Pointer-to-member identifying the link inside `T`.
//! @tparam Hash Hash functor; a default-constructed instance is used for each hash computation.
//! @tparam Pred Key-equality functor; a default-constructed instance is used for each comparison pass.
template <class Key, class T, IntrusiveUnorderedMultimapLink<Key, T> T::*U, class Hash = ::std::hash<Key>, class Pred = ::std::equal_to<Key>>
class IntrusiveUnorderedMultimap
{
    using BaseLink = detail::IntrusiveUnorderedMultimapLinkBase<Key, T>;

public:
    //! The key type (const-qualified, as stored in the link).
    using KeyType = const Key;
    //! The mapped type, i.e. the type that embeds the link.
    using MappedType = T;
    //! The link type that `T` must embed.
    using Link = IntrusiveUnorderedMultimapLink<Key, T>;
    //! The element type seen through iterators: `std::pair<const Key, T&>`.
    using ValueType = typename Link::ValueType;

    // Iterator support
    // clang-format off
    //! Forward iterator granting read-only access to the elements.
    class const_iterator
    {
    public:
#ifndef DOXYGEN_SHOULD_SKIP_THIS
        using iterator_category = std::forward_iterator_tag;
        using value_type = ValueType;
        using difference_type = ptrdiff_t;
        using pointer = const value_type*;
        using reference = const value_type&;
        const_iterator() noexcept = default;
        reference       operator *  () const                    { assertNotEnd(); return m_where->value; }
        pointer         operator -> () const                    { assertNotEnd(); return std::addressof(operator*()); }
        const_iterator& operator ++ () noexcept    /* ++iter */ { assertNotEnd(); incr(); return *this; }
        const_iterator  operator ++ (int) noexcept /* iter++ */ { assertNotEnd(); const_iterator i{ *this }; incr(); return i; }
        bool operator == (const const_iterator& rhs) const noexcept { assertSameOwner(rhs); return m_where == rhs.m_where; }
        bool operator != (const const_iterator& rhs) const noexcept { assertSameOwner(rhs); return m_where != rhs.m_where; }
    protected:
        friend class IntrusiveUnorderedMultimap;
        Link* m_where{ nullptr };
#if CARB_ASSERT_ENABLED
        const IntrusiveUnorderedMultimap* m_owner{ nullptr };
        const_iterator(Link* where, const IntrusiveUnorderedMultimap* owner) noexcept : m_where(where), m_owner(owner) {}
        void assertOwner(const IntrusiveUnorderedMultimap* other) const noexcept
        {
            CARB_ASSERT(m_owner == other, "IntrusiveUnorderedMultimap iterator for invalid container");
        }
        void assertSameOwner(const const_iterator& rhs) const noexcept
        {
            CARB_ASSERT(m_owner == rhs.m_owner, "IntrusiveUnorderedMultimap iterators are from different containers");
        }
        void assertNotEnd() const noexcept
        {
            CARB_ASSERT(m_where != m_owner->_end(), "Invalid operation on IntrusiveUnorderedMultimap::end() iterator");
        }
#else
        const_iterator(Link* where, const IntrusiveUnorderedMultimap*) noexcept : m_where(where) {}
        void assertOwner(const IntrusiveUnorderedMultimap* other) const noexcept
        {
            CARB_UNUSED(other);
        }
        void assertSameOwner(const const_iterator& rhs) const noexcept
        {
            CARB_UNUSED(rhs);
        }
        void assertNotEnd() const noexcept {}
#endif // !CARB_ASSERT_ENABLED
        void incr() noexcept { m_where = m_where->m_next; }
#endif // !DOXYGEN_SHOULD_SKIP_THIS
    };

    //! Forward iterator granting mutable access to the elements.
    class iterator : public const_iterator
    {
        using Base = const_iterator;
    public:
#ifndef DOXYGEN_SHOULD_SKIP_THIS
        using iterator_category = std::forward_iterator_tag;
        using value_type = ValueType;
        using difference_type = ptrdiff_t;
        using pointer = value_type*;
        using reference = value_type&;
        iterator() noexcept = default;
        reference operator *  () const                    { this->assertNotEnd(); return this->m_where->value; }
        pointer   operator -> () const                    { this->assertNotEnd(); return std::addressof(operator*()); }
        iterator& operator ++ () noexcept    /* ++iter */ { this->assertNotEnd(); this->incr(); return *this; }
        iterator  operator ++ (int) noexcept /* iter++ */ { this->assertNotEnd(); iterator i{ *this }; this->incr(); return i; }
    private:
        friend class IntrusiveUnorderedMultimap;
        iterator(Link* where, const IntrusiveUnorderedMultimap* owner) : Base(where, owner) {}
#endif
    };
    // clang-format on

    CARB_PREVENT_COPY(IntrusiveUnorderedMultimap);

    //! Constructs an empty container. No bucket array is allocated until the first insertion or an explicit
    //! reserve()/rehash().
    constexpr IntrusiveUnorderedMultimap() : m_list(_end())
    {
    }

    //! Move constructor: takes over the elements and bucket array of @p other, leaving @p other empty.
    IntrusiveUnorderedMultimap(IntrusiveUnorderedMultimap&& other) noexcept : m_list(_end())
    {
        swap(other);
    }

    //! Destructor: removes every element (see clear()). The `T` objects themselves are not destroyed.
    ~IntrusiveUnorderedMultimap()
    {
        clear();
        // The sentinel points at itself even when empty; reset it so the link base destructor's
        // !isContained() assertion does not fire.
        m_list.m_next = m_list.m_prev = nullptr;
    }

    //! Move assignment, implemented as swap(). @p other therefore receives this container's previous elements.
    IntrusiveUnorderedMultimap& operator=(IntrusiveUnorderedMultimap&& other) noexcept
    {
        swap(other);
        return *this;
    }

    //! @returns `true` if the container holds no elements.
    bool empty() const noexcept
    {
        return !m_size;
    }

    //! @returns The number of contained elements.
    size_t size() const noexcept
    {
        return m_size;
    }

    //! @returns The largest representable element count.
    size_t max_size() const noexcept
    {
        return size_t(-1);
    }

    // Iterator support

    //! @returns An iterator to the first element, or end() if empty.
    iterator begin() noexcept
    {
        return iterator(_head(), this);
    }

    //! @returns The past-the-end iterator.
    iterator end() noexcept
    {
        return iterator(_end(), this);
    }

    //! @returns A const iterator to the first element, or cend() if empty.
    const_iterator cbegin() const noexcept
    {
        return const_iterator(_head(), this);
    }

    //! @returns The past-the-end const iterator.
    const_iterator cend() const noexcept
    {
        return const_iterator(_end(), this);
    }

    //! @returns A const iterator to the first element, or end() if empty.
    const_iterator begin() const noexcept
    {
        return cbegin();
    }

    //! @returns The past-the-end const iterator.
    const_iterator end() const noexcept
    {
        return cend();
    }

    //! Finds @p value in this container.
    //! @returns An iterator to @p value if it belongs to this container; otherwise end().
    //! @note If @p value is contained (in any container) this walks the whole list, so it takes linear time.
    iterator locate(T& value) noexcept
    {
        Link* l = _link(value);
        return iterator(l->isContained() ? _listfind(value) : _end(), this);
    }

    //! Const overload of locate(); same semantics and complexity.
    const_iterator locate(T& value) const noexcept
    {
        Link* l = _link(value);
        return const_iterator(l->isContained() ? _listfind(value) : _end(), this);
    }

    //! Converts @p value directly into an iterator.
    //! @p value must be contained in this container or in none at all. When assertions are enabled, membership is
    //! verified in linear time.
    //! @returns An iterator to @p value, or end() if @p value is not contained.
    iterator iter_from_value(T& value)
    {
        Link* l = _link(value);
        CARB_ASSERT(!l->isContained() || _listfind(value) != _end());
        return iterator(l->isContained() ? l : _end(), this);
    }

    //! Const overload of iter_from_value(); same semantics.
    const_iterator iter_from_value(T& value) const
    {
        Link* l = _link(value);
        CARB_ASSERT(!l->isContained() || _listfind(value) != _end());
        return const_iterator(l->isContained() ? l : _end(), this);
    }

    //! Removes every element, destroying the key stored in each link. The `T` objects are not destroyed.
    //! The bucket array is kept (zeroed), not freed.
    void clear()
    {
        if (_head() != _end())
        {
            do
            {
                Link* p = _head();
                p->value.~ValueType(); // Destruct the key
                m_list.m_next = p->m_next;
                p->m_next = p->m_prev = nullptr;
            } while (_head() != _end());
            m_list.m_prev = _end();
            m_size = 0;

            // Clear the buckets
            memset(m_buckets.get(), 0, sizeof(LinkPair) * m_bucketCount);
        }
    }

    //! Inserts the object referenced by `value.second` under the key `value.first`.
    //!
    //! The object must not currently be contained in any container. If an equal key is already present, the new
    //! element is placed immediately before the existing run of equal keys; otherwise it goes at the end of its
    //! bucket (or at the end of the list for a previously empty bucket).
    //!
    //! If hashing, growing the bucket array, comparing keys, or copying the key throws, the container and the
    //! object's link are left unchanged.
    //! @returns An iterator to the inserted element.
    iterator insert(ValueType value)
    {
        T& val = value.second;
        Link* l = _link(val);
        CARB_ASSERT(!l->isContained());

        // Perform every operation that may throw (hashing, growing, key comparison) *before* the key is constructed
        // in the link and before m_size changes. An exception therefore cannot leave a constructed-but-unlinked key
        // behind or a size that disagrees with the list.
        size_t const hash = Hash{}(value.first);
        reserve(size() + 1);
        LinkPair& bucket = m_buckets[_bucket(hash)];

        // Look for an equal key in the bucket so that all equal keys stay grouped together.
        Link* match = nullptr;
        if (bucket.first)
        {
            Pred pred{};
            Link* const bucketEnd = bucket.second->m_next;
            for (Link* p = bucket.first; p != bucketEnd; p = p->m_next)
            {
                if (pred(value.first, p->value.first))
                {
                    match = p;
                    break;
                }
            }
        }

        // Construct the key. If this throws, no container state has been modified yet.
        new (&l->value) ValueType(std::move(value));
        ++m_size;

        // Everything below is pointer manipulation only.
        if (match)
        {
            // Insert in front of the first equal key
            l->m_prev = match->m_prev;
            l->m_next = match;
            if (match == bucket.first)
            {
                bucket.first = l;
            }
        }
        else if (bucket.first)
        {
            // No equal key in the bucket; append to the end of the bucket
            l->m_prev = bucket.second;
            l->m_next = bucket.second->m_next;
            bucket.second = l;
        }
        else
        {
            // Empty bucket: append to the end of the list and start the bucket there
            l->m_prev = _tail();
            l->m_next = _end();
            bucket.first = bucket.second = l;
        }
        l->m_prev->m_next = l;
        l->m_next->m_prev = l;

        return iterator(l, this);
    }

    //! Constructs a ValueType from @p args and insert()s it.
    template <class... Args>
    iterator emplace(Args&&... args)
    {
        return insert(ValueType{ std::forward<Args>(args)... });
    }

    //! @returns An iterator to the first element whose key equals @p key, or end() if there is none.
    iterator find(const Key& key)
    {
        Link* bucketEnd{ nullptr };
        Link* const first = _findFirst(key, bucketEnd);
        return first ? iterator(first, this) : end();
    }

    //! @returns A const iterator to the first element whose key equals @p key, or cend() if there is none.
    const_iterator find(const Key& key) const
    {
        Link* bucketEnd{ nullptr };
        Link* const first = _findFirst(key, bucketEnd);
        return first ? const_iterator(first, this) : cend();
    }

    //! @returns The half-open range of elements whose key equals @p key, or `{end(), end()}` if there are none.
    std::pair<iterator, iterator> equal_range(const Key& key)
    {
        Link* bucketEnd{ nullptr };
        Link* const first = _findFirst(key, bucketEnd);
        if (!first)
        {
            return std::make_pair(end(), end());
        }
        return std::make_pair(iterator(first, this), iterator(_skipEqual(first->m_next, key, bucketEnd), this));
    }

    //! @returns The half-open range of elements whose key equals @p key, or `{cend(), cend()}` if there are none.
    std::pair<const_iterator, const_iterator> equal_range(const Key& key) const
    {
        Link* bucketEnd{ nullptr };
        Link* const first = _findFirst(key, bucketEnd);
        if (!first)
        {
            return std::make_pair(cend(), cend());
        }
        return std::make_pair(
            const_iterator(first, this), const_iterator(_skipEqual(first->m_next, key, bucketEnd), this));
    }

    //! @returns The number of elements whose key equals @p key.
    size_t count(const Key& key) const
    {
        Link* bucketEnd{ nullptr };
        Link* const first = _findFirst(key, bucketEnd);
        if (!first)
        {
            return 0;
        }

        // Equal keys are adjacent, so count the run that starts at `first`.
        Pred pred{};
        size_t n = 1;
        for (Link* p = first->m_next; p != bucketEnd && pred(p->value.first, key); p = p->m_next)
        {
            ++n;
        }
        return n;
    }

    //! Removes the element at @p pos, which must be a dereferenceable iterator into this container.
    //! @returns An iterator to the element that followed the removed one.
    iterator remove(const_iterator pos)
    {
        CARB_ASSERT(!empty());
        pos.assertNotEnd();
        pos.assertOwner(this);

        return iterator(_unlink(pos.m_where), this);
    }

    //! Removes @p value if it is contained; does nothing otherwise.
    //! If @p value is contained it must be contained in *this* container (asserted in linear time).
    //! @returns @p value.
    T& remove(T& value)
    {
        Link* l = _link(value);
        if (l->isContained())
        {
            CARB_ASSERT(!empty());
            CARB_ASSERT(_listfind(value) != _end());
            _unlink(l);
        }
        return value;
    }

    //! Removes every element whose key equals @p key.
    //! @returns The number of elements removed.
    size_t remove(const Key& key)
    {
        size_t count{ 0 };
        auto pair = equal_range(key);
        while (pair.first != pair.second)
        {
            remove(pair.first++);
            ++count;
        }
        return count;
    }

    //! Exchanges the contents of this container and @p other in constant time.
    //! Only the first and last link of each list are touched, re-pointed at the other container's sentinel.
    void swap(IntrusiveUnorderedMultimap& other) noexcept
    {
        if (this != std::addressof(other))
        {
            // Fix up the end iterators first. If a container is empty, its head and tail *are* its sentinel, so these
            // references alias the sentinel's own pointers. The swaps of the sentinel pointers below then restore
            // the correct self-references.
            Link *&lhead = _head()->m_prev, *&ltail = _tail()->m_next;
            Link *&rhead = other._head()->m_prev, *&rtail = other._tail()->m_next;
            lhead = ltail = other._end();
            rhead = rtail = _end();

            // Now swap everything else
            std::swap(m_buckets, other.m_buckets);
            std::swap(m_bucketCount, other.m_bucketCount);
            std::swap(_end()->m_next, other._end()->m_next);
            std::swap(_end()->m_prev, other._end()->m_prev);
            std::swap(m_size, other.m_size);
            std::swap(m_maxLoadFactor, other.m_maxLoadFactor);
        }
    }

    //! @returns The number of buckets (zero until the first allocation, otherwise a power of two).
    size_t bucket_count() const noexcept
    {
        return m_bucketCount;
    }

    //! @returns The largest representable bucket count.
    size_t max_bucket_count() const noexcept
    {
        return size_t(-1);
    }

    //! @returns The index of the bucket that @p key hashes to.
    //! @pre bucket_count() must be non-zero. With no buckets, the mask is all ones and the raw hash is returned.
    size_t bucket(const Key& key) const
    {
        return _bucket(Hash{}(key));
    }

    //! @returns size() / bucket_count(), or 0 if there are no buckets.
    float load_factor() const
    {
        return bucket_count() ? float(size()) / float(bucket_count()) : 0.f;
    }

    //! @returns The maximum load factor used to size the bucket array (defaults to 1).
    float max_load_factor() const noexcept
    {
        return m_maxLoadFactor;
    }

    //! Sets the maximum load factor; @p ml must be positive.
    //! This does not rehash by itself. The new value is applied by the next reserve(), which every insert() calls.
    void max_load_factor(float ml)
    {
        CARB_ASSERT(ml > 0.f);
        m_maxLoadFactor = ml;
    }

    //! Grows the bucket array so that @p count elements fit within max_load_factor(). Never shrinks.
    void reserve(size_t count)
    {
        rehash(size_t(std::ceil(float(count) / max_load_factor())));
    }

    //! Grows the bucket array to at least @p buckets buckets and redistributes every element.
    //! The count is raised to a minimum of 8 and rounded up to a power of two. Does nothing if @p buckets does not
    //! exceed bucket_count().
    void rehash(size_t buckets)
    {
        if (buckets > m_bucketCount)
        {
            constexpr static size_t kMinBuckets(8);
            static_assert(carb::cpp::has_single_bit(kMinBuckets), "Invalid assumption");
            buckets = carb::cpp::bit_ceil(::carb_max(buckets, kMinBuckets));
            CARB_ASSERT(carb::cpp::has_single_bit(buckets));
            m_buckets.reset(new LinkPair[buckets]);
            memset(m_buckets.get(), 0, sizeof(LinkPair) * buckets);
            m_bucketCount = buckets;

            // Walk through the list backwards and rehash everything. Things that have equal keys and are already
            // grouped together will remain so.
            Link* cur = _tail();
            m_list.m_prev = m_list.m_next = _end();

            Link* next;
            Hash hasher;
            for (; cur != _end(); cur = next)
            {
                next = cur->m_prev;

                LinkPair& bucket = m_buckets[_bucket(hasher(cur->value.first))];
                if (bucket.first)
                {
                    // Insert in front of whatever was in the bucket
                    cur->m_prev = bucket.first->m_prev;
                    cur->m_next = bucket.first;
                    cur->m_prev->m_next = cur;
                    cur->m_next->m_prev = cur;
                    bucket.first = cur;
                }
                else
                {
                    // Insert at the front of the list and the beginning of the bucket
                    cur->m_prev = _end();
                    cur->m_next = _head();
                    cur->m_prev->m_next = cur;
                    cur->m_next->m_prev = cur;
                    bucket.first = bucket.second = cur;
                }
            }
        }
    }

private:
    // First and last link of a bucket's contiguous run in the list; both null for an empty bucket.
    struct LinkPair
    {
        Link* first;
        Link* second;
    };
    std::unique_ptr<LinkPair[]> m_buckets{};
    size_t m_bucketCount{ 0 };
    BaseLink m_list; // sentinel of the list; its m_next/m_prev are the head/tail
    size_t m_size{ 0 };
    float m_maxLoadFactor{ 1.f };

    // Maps a hash to a bucket index.
    size_t _bucket(size_t hash) const
    {
        // bucket count is always a power of 2
        return hash & (m_bucketCount - 1);
    }

    // Finds the first link whose key equals `key`. On success, stores in `bucketEnd` the link just past the key's
    // bucket and returns the match; otherwise returns nullptr and leaves `bucketEnd` untouched.
    Link* _findFirst(const Key& key, Link*& bucketEnd) const
    {
        if (empty())
        {
            return nullptr;
        }

        LinkPair& pair = m_buckets[_bucket(Hash{}(key))];
        if (!pair.first)
        {
            return nullptr;
        }

        Pred pred{};
        bucketEnd = pair.second->m_next;
        for (Link* p = pair.first; p != bucketEnd; p = p->m_next)
        {
            if (pred(p->value.first, key))
            {
                return p;
            }
        }
        return nullptr;
    }

    // Advances from `p` past every link whose key equals `key`, stopping at `bucketEnd`. Returns the first link that
    // does not match (or `bucketEnd`).
    Link* _skipEqual(Link* p, const Key& key, Link* bucketEnd) const
    {
        Pred pred{};
        while (p != bucketEnd && pred(p->value.first, key))
        {
            p = p->m_next;
        }
        return p;
    }

    // Unlinks a contained link from its bucket and from the list, decrements the size, and destroys the stored key.
    // Returns the link that followed `l` (possibly the sentinel).
    Link* _unlink(Link* l)
    {
        Link* const next = l->m_next;

        // Fix up bucket if necessary
        LinkPair& pair = m_buckets[_bucket(Hash{}(l->value.first))];
        if (pair.first == l)
        {
            if (pair.second == l)
            {
                // Empty bucket now
                pair.first = pair.second = nullptr;
            }
            else
            {
                pair.first = next;
            }
        }
        else if (pair.second == l)
        {
            pair.second = l->m_prev;
        }

        l->m_prev->m_next = next;
        next->m_prev = l->m_prev;
        l->m_next = l->m_prev = nullptr;
        --m_size;

        // Destruct value
        l->value.~ValueType();
        return next;
    }

    // Linear scan of the list for `value`'s link; returns it, or the sentinel if it is not in this container.
    Link* _listfind(T& value) const
    {
        Link* find = _link(value);
        Link* p = _head();
        for (; p != _end(); p = p->m_next)
        {
            if (p == find)
            {
                return p;
            }
        }
        return _end();
    }

    // Returns the link embedded in `value`.
    static Link* _link(T& value) noexcept
    {
        return std::addressof(value.*U);
    }

    // NOTE(review): the two _value() overloads are not referenced elsewhere in this class. The offset computation
    // forms a member reference through a null T*, which is technically undefined behaviour.
    static T& _value(Link& l) noexcept
    {
        // Need to calculate the offset of our link member which will allow adjusting the pointer to where T is.
        // This will not work if T uses virtual inheritance. Also, offsetof() cannot be used because we have a pointer
        // to the member
        size_t offset = size_t(reinterpret_cast<uint8_t*>(&(((T*)0)->*U)));
        return *reinterpret_cast<T*>(reinterpret_cast<uint8_t*>(std::addressof(l)) - offset);
    }

    static const T& _value(const Link& l) noexcept
    {
        // Need to calculate the offset of our link member which will allow adjusting the pointer to where T is.
        // This will not work if T uses virtual inheritance. Also, offsetof() cannot be used because we have a pointer
        // to the member
        size_t offset = size_t(reinterpret_cast<uint8_t*>(&(((T*)0)->*U)));
        return *reinterpret_cast<const T*>(reinterpret_cast<const uint8_t*>(std::addressof(l)) - offset);
    }

    // First link in the list (the sentinel when empty).
    constexpr Link* _head() const noexcept
    {
        return const_cast<Link*>(m_list.m_next);
    }

    // Last link in the list (the sentinel when empty).
    constexpr Link* _tail() const noexcept
    {
        return const_cast<Link*>(m_list.m_prev);
    }

    // The sentinel, viewed as a Link; used as the end() position.
    constexpr Link* _end() const noexcept
    {
        return static_cast<Link*>(const_cast<BaseLink*>(&m_list));
    }
};

} // namespace container

} // namespace carb