// carb/container/IntrusiveList.h
//
// File members: carb/container/IntrusiveList.h

// Copyright (c) 2020-2024, NVIDIA CORPORATION. All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto. Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.
//

#pragma once

#include "../Defines.h"

#include <cstddef>
#include <cstdint>
#include <functional>
#include <iterator>
#include <memory>
#include <utility>

namespace carb
{

namespace container
{

//! Link object that a type `T` embeds as a member so that instances of `T` can be stored in an IntrusiveList.
//!
//! A link counts as "contained" while its next pointer is non-null. IntrusiveList sets both pointers back to
//! `nullptr` whenever it unlinks an element. CARB_PREVENT_COPY_AND_MOVE is applied to this class.
//! @tparam T The type that embeds this link as a member.
template <class T>
class IntrusiveListLink
{
public:
    //! The type that contains this link.
    using ValueType = T;

    //! Constructs a link that is not contained in any list (both pointers are `nullptr`).
    constexpr IntrusiveListLink() noexcept = default;

    //! Destructor. Asserts that the link is not contained in a list at destruction time.
    ~IntrusiveListLink()
    {
        // Shouldn't be contained at destruction time
        CARB_ASSERT(!isContained());
    }

    //! Reports whether this link is currently linked.
    //! @returns `true` if the next pointer is non-null; `false` otherwise.
    bool isContained() const noexcept
    {
        return m_next != nullptr;
    }

private:
    template <class U, IntrusiveListLink<U> U::*V>
    friend class IntrusiveList;

    CARB_VIZ IntrusiveListLink* m_next{ nullptr };
    IntrusiveListLink* m_prev{ nullptr };

    CARB_PREVENT_COPY_AND_MOVE(IntrusiveListLink);

    // Initializes both pointers to `init`. IntrusiveList uses this for its sentinel, passing the sentinel's own
    // address so that the sentinel points at itself. Declared explicit so that a raw pointer never converts to a
    // link implicitly.
    constexpr explicit IntrusiveListLink(IntrusiveListLink* init) noexcept : m_next(init), m_prev(init)
    {
    }
};

template <class T, IntrusiveListLink<T> T::*U>
class CARB_VIZ IntrusiveList
{
public:
    using ValueType = T;
    using Link = IntrusiveListLink<T>;

    // Iterator support
    // clang-format off
    //! Bidirectional iterator giving read-only access to the elements of an IntrusiveList.
    class const_iterator
    {
    public:
#ifndef DOXYGEN_SHOULD_SKIP_THIS
        using iterator_category = std::bidirectional_iterator_tag;
        using value_type = ValueType;
        using difference_type = ptrdiff_t;
        using pointer = const value_type*;
        using reference = const value_type&;

        const_iterator() noexcept = default;

        // Element access: maps the current link back to the element that embeds it. Asserts against end().
        reference operator*() const
        {
            assertNotEnd();
            return IntrusiveList::_value(*m_where);
        }
        pointer operator->() const
        {
            assertNotEnd();
            return std::addressof(operator*());
        }

        // Increment: asserts that the iterator is not already end().
        const_iterator& operator++() noexcept /* ++iter */
        {
            assertNotEnd();
            incr();
            return *this;
        }
        const_iterator operator++(int) noexcept /* iter++ */
        {
            assertNotEnd();
            const_iterator previous{ *this };
            incr();
            return previous;
        }

        // Decrement: no assertion is made here.
        const_iterator& operator--() noexcept /* --iter */
        {
            decr();
            return *this;
        }
        const_iterator operator--(int) noexcept /* iter-- */
        {
            const_iterator previous{ *this };
            decr();
            return previous;
        }

        // Two iterators are equal when they reference the same link. Both must come from the same container.
        bool operator==(const const_iterator& rhs) const noexcept
        {
            assertSameOwner(rhs);
            return m_where == rhs.m_where;
        }
        bool operator!=(const const_iterator& rhs) const noexcept
        {
            assertSameOwner(rhs);
            return !(m_where == rhs.m_where);
        }

    protected:
        friend class IntrusiveList;

        // The link currently referenced; the owning list's sentinel for end().
        Link* m_where{ nullptr };

        // The owning list is only tracked when assertions are enabled.
#if CARB_ASSERT_ENABLED
        const IntrusiveList* m_owner{ nullptr };
        const_iterator(Link* where, const IntrusiveList* owner) : m_where(where), m_owner(owner) {}
#else
        const_iterator(Link* where, const IntrusiveList*) : m_where(where) {}
#endif // !CARB_ASSERT_ENABLED

        // Asserts that this iterator was created by `list`.
        void assertOwner(const IntrusiveList* list) const noexcept
        {
#if CARB_ASSERT_ENABLED
            CARB_ASSERT(m_owner == list, "IntrusiveList iterator for invalid container");
#else
            CARB_UNUSED(list);
#endif
        }

        // Asserts that this iterator and `rhs` were created by the same list.
        void assertSameOwner(const const_iterator& rhs) const noexcept
        {
#if CARB_ASSERT_ENABLED
            CARB_ASSERT(m_owner == rhs.m_owner, "IntrusiveList iterators are from different containers");
#else
            CARB_UNUSED(rhs);
#endif
        }

        // Asserts that this iterator does not reference the owning list's sentinel.
        void assertNotEnd() const noexcept
        {
#if CARB_ASSERT_ENABLED
            CARB_ASSERT(m_where != m_owner->_end(), "Invalid operation on IntrusiveList::end() iterator");
#endif
        }

        // Raw link stepping, with no checks.
        void incr()
        {
            m_where = m_where->m_next;
        }
        void decr()
        {
            m_where = m_where->m_prev;
        }
#endif // !DOXYGEN_SHOULD_SKIP_THIS
    };

    //! Bidirectional iterator giving mutable access to the elements of an IntrusiveList.
    //!
    //! Derives from const_iterator, so it converts to one and shares its comparison operators.
    class iterator : public const_iterator
    {
        using Base = const_iterator;

    public:
#ifndef DOXYGEN_SHOULD_SKIP_THIS
        using iterator_category = std::bidirectional_iterator_tag;
        using value_type = ValueType;
        using difference_type = ptrdiff_t;
        using pointer = value_type*;
        using reference = value_type&;

        iterator() noexcept = default;

        // Element access: same checks as the base class, but yields a mutable reference/pointer.
        reference operator*() const
        {
            this->assertNotEnd();
            return IntrusiveList::_value(*this->m_where);
        }
        pointer operator->() const
        {
            this->assertNotEnd();
            return std::addressof(operator*());
        }

        // Increment: asserts that the iterator is not already end().
        iterator& operator++() noexcept /* ++iter */
        {
            this->assertNotEnd();
            this->incr();
            return *this;
        }
        iterator operator++(int) noexcept /* iter++ */
        {
            this->assertNotEnd();
            iterator previous{ *this };
            this->incr();
            return previous;
        }

        // Decrement: no assertion is made here.
        iterator& operator--() noexcept /* --iter */
        {
            this->decr();
            return *this;
        }
        iterator operator--(int) noexcept /* iter-- */
        {
            iterator previous{ *this };
            this->decr();
            return previous;
        }

    protected:
        friend class IntrusiveList;

        // Only the list itself creates positioned iterators.
        iterator(Link* where, const IntrusiveList* owner) : Base(where, owner) {}
#endif
    };

    using reverse_iterator = std::reverse_iterator<iterator>;
    using const_reverse_iterator = std::reverse_iterator<const_iterator>;
    // clang-format on

    CARB_PREVENT_COPY(IntrusiveList);

    //! Constructs an empty list.
    //!
    //! The sentinel link `m_list` is initialized with its own address, so its next and previous pointers both refer
    //! to the sentinel itself. That self-reference is how an empty list is represented.
    constexpr IntrusiveList() : m_list(&m_list)
    {
    }

    //! Move constructor. Starts out empty and then swaps contents with @p other.
    //! @param other The list to take all elements from. It holds no elements afterwards.
    IntrusiveList(IntrusiveList&& other) noexcept : m_list(&m_list)
    {
        swap(other);
    }

    //! Destructor. Unlinks every remaining element through clear(); this code does not destroy the elements.
    ~IntrusiveList()
    {
        clear();
        // The sentinel still points at itself after clear(), which isContained() would report as "contained".
        // Null its pointers so that the sentinel's own destructor does not assert.
        m_list.m_next = m_list.m_prev = nullptr; // prevent the assert
    }

    //! Move assignment, implemented as a swap.
    //!
    //! Elements previously held by `*this` are not unlinked; they end up in @p other.
    //! @param other The list to exchange contents with.
    //! @returns `*this`.
    IntrusiveList& operator=(IntrusiveList&& other) noexcept
    {
        swap(other);
        return *this;
    }

    //! Checks whether the list holds no elements.
    //! @returns `true` if the tracked element count is zero; `false` otherwise.
    bool empty() const noexcept
    {
        return m_size == 0;
    }

    //! Returns the number of elements in the list.
    //! @returns The tracked element count; the count is stored, so no traversal is performed.
    size_t size() const noexcept
    {
        return m_size;
    }

    //! Returns the maximum number of elements the list reports it can hold.
    //! @returns `size_t(-1)`, the largest value representable by `size_t`.
    size_t max_size() const noexcept
    {
        return size_t(-1);
    }

    // Forward iterator support
    //! Returns an iterator to the first element.
    //! @returns An iterator referencing the head link. For an empty list the head is the sentinel, so the result
    //! equals end().
    iterator begin() noexcept
    {
        return iterator(_head(), this);
    }

    //! Returns the past-the-end iterator.
    //! @returns An iterator referencing the list's sentinel link, which is not an element.
    iterator end() noexcept
    {
        return iterator(_end(), this);
    }

    //! Returns a const iterator to the first element.
    //! @returns A const iterator referencing the head link. For an empty list the result equals cend().
    const_iterator cbegin() const noexcept
    {
        return const_iterator(_head(), this);
    }

    //! Returns the const past-the-end iterator.
    //! @returns A const iterator referencing the list's sentinel link, which is not an element.
    const_iterator cend() const noexcept
    {
        return const_iterator(_end(), this);
    }

    //! Const overload of begin(); forwards to cbegin().
    //! @returns A const iterator to the first element, or end() if the list is empty.
    const_iterator begin() const noexcept
    {
        return cbegin();
    }

    //! Const overload of end(); forwards to cend().
    //! @returns The const past-the-end iterator.
    const_iterator end() const noexcept
    {
        return cend();
    }

    // Reverse iterator support
    //! Returns a reverse iterator to the start of the reversed sequence.
    //! @returns A `std::reverse_iterator` wrapping end().
    reverse_iterator rbegin() noexcept
    {
        return reverse_iterator(end());
    }

    //! Returns a reverse iterator to the end of the reversed sequence.
    //! @returns A `std::reverse_iterator` wrapping begin().
    reverse_iterator rend() noexcept
    {
        return reverse_iterator(begin());
    }

    //! Returns a const reverse iterator to the start of the reversed sequence.
    //! @returns A `std::reverse_iterator` wrapping cend().
    const_reverse_iterator crbegin() const noexcept
    {
        return const_reverse_iterator(cend());
    }

    //! Returns a const reverse iterator to the end of the reversed sequence.
    //! @returns A `std::reverse_iterator` wrapping cbegin().
    const_reverse_iterator crend() const noexcept
    {
        return const_reverse_iterator(cbegin());
    }

    //! Const overload of rbegin(); forwards to crbegin().
    //! @returns A const reverse iterator to the start of the reversed sequence.
    const_reverse_iterator rbegin() const noexcept
    {
        return crbegin();
    }

    //! Const overload of rend(); forwards to crend().
    //! @returns A const reverse iterator to the end of the reversed sequence.
    const_reverse_iterator rend() const noexcept
    {
        return crend();
    }

    //! Searches this list for @p value by identity (the address of its embedded link), not by comparing values.
    //!
    //! If the link reports that it is contained, the list is walked from the head until the link is found.
    //! @param value The element to look for.
    //! @returns An iterator to @p value if its link is part of this list; end() otherwise, including when the link
    //! is not contained in any list.
    iterator locate(T& value) noexcept
    {
        Link* const target = _link(value);
        if (target->isContained())
        {
            // The link is in *some* list; walk ours to find out whether it is this one.
            Link* const sentinel = _end();
            for (Link* cur = m_list.m_next; cur != sentinel; cur = cur->m_next)
            {
                if (cur == target)
                    return iterator(target, this);
            }
        }

        return end();
    }

    //! Const overload: searches this list for @p value by identity (the address of its embedded link).
    //!
    //! If the link reports that it is contained, the list is walked from the head until the link is found.
    //! @param value The element to look for.
    //! @returns A const iterator to @p value if its link is part of this list; end() otherwise, including when the
    //! link is not contained in any list.
    const_iterator locate(T& value) const noexcept
    {
        Link* const target = _link(value);
        if (target->isContained())
        {
            // The link is in *some* list; walk ours to find out whether it is this one.
            Link* const sentinel = _end();
            for (Link* cur = m_list.m_next; cur != sentinel; cur = cur->m_next)
            {
                if (cur == target)
                    return const_iterator(target, this);
            }
        }

        return end();
    }

#ifndef DOXYGEN_SHOULD_SKIP_THIS
    // Deprecated alias kept for existing callers; forwards to locate().
    CARB_DEPRECATED("Use locate()") iterator find(T& value) noexcept
    {
        return locate(value);
    }
    // Deprecated alias (const overload) kept for existing callers; forwards to locate().
    CARB_DEPRECATED("Use locate()") const_iterator find(T& value) const noexcept
    {
        return locate(value);
    }
#endif

    //! Builds an iterator directly from an element without walking the list.
    //!
    //! Membership in *this* list is only verified (via locate()) inside an assertion.
    //! @param value The element to create an iterator for.
    //! @returns An iterator to @p value if its link is contained; end() otherwise.
    iterator iter_from_value(T& value)
    {
        Link* const link = _link(value);
        CARB_ASSERT(!link->isContained() || locate(value) != end());
        if (!link->isContained())
            return iterator(_end(), this);
        return iterator(link, this);
    }

    //! Const overload: builds a const iterator directly from an element without walking the list.
    //!
    //! Membership in *this* list is only verified (via locate()) inside an assertion.
    //! @param value The element to create an iterator for.
    //! @returns A const iterator to @p value if its link is contained; end() otherwise.
    const_iterator iter_from_value(T& value) const
    {
        Link* const link = _link(value);
        CARB_ASSERT(!link->isContained() || locate(value) != end());
        if (!link->isContained())
            return const_iterator(_end(), this);
        return const_iterator(link, this);
    }

    //! Accesses the first element. The list must not be empty (asserted).
    //! @returns A reference to the element embedding the head link.
    T& front()
    {
        CARB_ASSERT(!empty());
        Link* const first = _head();
        return _value(*first);
    }

    //! Accesses the first element (const overload). The list must not be empty (asserted).
    //! @returns A const reference to the element embedding the head link.
    const T& front() const
    {
        CARB_ASSERT(!empty());
        Link* const first = _head();
        return _value(*first);
    }

    //! Accesses the last element. The list must not be empty (asserted).
    //! @returns A reference to the element embedding the tail link.
    T& back()
    {
        CARB_ASSERT(!empty());
        Link* const last = _tail();
        return _value(*last);
    }

    //! Accesses the last element (const overload). The list must not be empty (asserted).
    //! @returns A const reference to the element embedding the tail link.
    const T& back() const
    {
        CARB_ASSERT(!empty());
        Link* const last = _tail();
        return _value(*last);
    }

    //! Links @p value in as the new first element.
    //!
    //! No copy is made: the list references the link embedded in @p value. The link must not already be contained
    //! in a list (asserted).
    //! @param value The element to insert.
    //! @returns @p value.
    T& push_front(T& value)
    {
        Link* const node = _link(value);
        CARB_ASSERT(!node->isContained());

        Link* const sentinel = _end();
        Link* const oldHead = sentinel->m_next; // the sentinel itself when the list is empty

        // Wire the node in between the sentinel and the previous head.
        node->m_prev = sentinel;
        node->m_next = oldHead;
        oldHead->m_prev = node;
        sentinel->m_next = node;

        ++m_size;
        return value;
    }

    //! Unlinks the first element. The list must not be empty (asserted).
    //!
    //! The element is only unlinked (its link pointers are reset to `nullptr`); this code does not destroy it.
    //! @returns A reference to the element that was removed.
    T& pop_front()
    {
        CARB_ASSERT(!empty());

        Link* const sentinel = _end();
        Link* const removed = sentinel->m_next;
        Link* const newHead = removed->m_next; // the sentinel itself when this was the only element

        sentinel->m_next = newHead;
        newHead->m_prev = sentinel;

        // Mark the link as no longer contained.
        removed->m_prev = nullptr;
        removed->m_next = nullptr;

        --m_size;
        return _value(*removed);
    }

    //! Links @p value in as the new last element.
    //!
    //! No copy is made: the list references the link embedded in @p value. The link must not already be contained
    //! in a list (asserted).
    //! @param value The element to insert.
    //! @returns @p value.
    T& push_back(T& value)
    {
        Link* const node = _link(value);
        CARB_ASSERT(!node->isContained());

        Link* const sentinel = _end();
        Link* const oldTail = sentinel->m_prev; // the sentinel itself when the list is empty

        // Wire the node in between the previous tail and the sentinel.
        node->m_prev = oldTail;
        node->m_next = sentinel;
        oldTail->m_next = node;
        sentinel->m_prev = node;

        ++m_size;
        return value;
    }

    //! Unlinks the last element. The list must not be empty (asserted).
    //!
    //! The element is only unlinked (its link pointers are reset to `nullptr`); this code does not destroy it.
    //! @returns A reference to the element that was removed.
    T& pop_back()
    {
        CARB_ASSERT(!empty());

        Link* const sentinel = _end();
        Link* const removed = sentinel->m_prev;
        Link* const newTail = removed->m_prev; // the sentinel itself when this was the only element

        sentinel->m_prev = newTail;
        newTail->m_next = sentinel;

        // Mark the link as no longer contained.
        removed->m_prev = nullptr;
        removed->m_next = nullptr;

        --m_size;
        return _value(*removed);
    }

    void clear()
    {
        if (_head() != _end())
        {
            do
            {
                Link* p = _head();
                m_list.m_next = p->m_next;
                p->m_next = p->m_prev = nullptr;
            } while (_head() != _end());
            m_list.m_prev = _end();
            m_size = 0;
        }
    }

    //! Links @p value into the list immediately before @p pos.
    //!
    //! No copy is made: the list references the link embedded in @p value. The link must not already be contained
    //! in a list (asserted), and @p pos must be an iterator of this list (asserted).
    //! @param pos The position to insert before; end() appends to the list.
    //! @param value The element to insert.
    //! @returns An iterator to the inserted element.
    iterator insert(const_iterator pos, T& value)
    {
        // Linking relative to another list's node would corrupt both lists' size bookkeeping, so check ownership
        // just as remove() and splice() do.
        pos.assertOwner(this);
        Link* l = _link(value);
        CARB_ASSERT(!l->isContained());
        l->m_prev = pos.m_where->m_prev;
        l->m_next = pos.m_where;
        l->m_prev->m_next = l;
        l->m_next->m_prev = l;
        ++m_size;
        return iterator(l, this);
    }

    //! Unlinks the element at @p pos.
    //!
    //! @p pos must belong to this list and must not be end() (both asserted). The element is only unlinked; this
    //! code does not destroy it.
    //! @param pos Iterator to the element to remove.
    //! @returns An iterator to the element that followed the removed one (end() if it was the last).
    iterator remove(const_iterator pos)
    {
        CARB_ASSERT(!empty());
        pos.assertNotEnd();
        pos.assertOwner(this);

        Link* const node = pos.m_where;
        Link* const before = node->m_prev;
        Link* const after = node->m_next;

        // Bridge the neighbors over the removed node, then mark the node as not contained.
        before->m_next = after;
        after->m_prev = before;
        node->m_next = node->m_prev = nullptr;

        --m_size;
        return iterator(after, this);
    }

    //! Unlinks @p value from the list if its link is currently contained; otherwise does nothing.
    //!
    //! Membership in *this* list is only verified (via locate()) inside an assertion.
    //! @param value The element to remove.
    //! @returns @p value.
    T& remove(T& value)
    {
        Link* const node = _link(value);
        if (!node->isContained())
            return value;

        CARB_ASSERT(!empty());
        CARB_ASSERT(locate(value) != end());

        Link* const before = node->m_prev;
        Link* const after = node->m_next;

        // Bridge the neighbors over the removed node, then mark the node as not contained.
        before->m_next = after;
        after->m_prev = before;
        node->m_next = node->m_prev = nullptr;

        --m_size;
        return value;
    }

    //! Exchanges the contents (elements and size) of this list with @p other. Swapping a list with itself does
    //! nothing.
    //! @param other The list to exchange contents with.
    void swap(IntrusiveList& other) noexcept
    {
        if (this != std::addressof(other))
        {
            // Fix up the end iterators first.
            // The first and last nodes of each list point back at that list's sentinel. Redirect those back-pointers
            // to the sentinel of the list that is about to own the nodes. For an empty list, _head() and _tail() are
            // the sentinel itself, so these references alias the sentinel's own m_prev/m_next.
            Link *&lhead = _head()->m_prev, *&ltail = _tail()->m_next;
            Link *&rhead = other._head()->m_prev, *&rtail = other._tail()->m_next;
            lhead = ltail = other._end();
            rhead = rtail = _end();

            // Now swap pointers
            // Exchanging the sentinels' pointers hands each chain to the other list; an empty side ends up with a
            // sentinel that points at itself thanks to the fix-up above.
            std::swap(_end()->m_next, other._end()->m_next);
            std::swap(_end()->m_prev, other._end()->m_prev);
            std::swap(m_size, other.m_size);
        }
    }

    //! Merges all elements of @p other into this list using @p comp for ordering.
    //!
    //! The merge step operates on sorted ranges, so both lists are expected to already be sorted by @p comp.
    //! Afterwards @p other is empty. Merging a list with itself, or merging in an empty list, does nothing.
    //! @param other The list whose elements are moved into this list.
    //! @param comp Comparison callable invoked as `comp(a, b)` on two elements.
    template <class Compare>
    void merge(IntrusiveList& other, Compare comp)
    {
        // Nothing to do for a self-merge or an empty source.
        if (this == std::addressof(other) || other.m_size == 0)
            return;

        // Append all of other's nodes after our current tail; `otherFirst` marks where the second run begins.
        Link* const sentinel = _end();
        Link* const otherSentinel = other._end();
        Link* const otherFirst = otherSentinel->m_next;
        _splice(sentinel, other, otherFirst, otherSentinel, other.m_size);

        // If *this was empty before the splice, the head is now `otherFirst` and there is no first run to merge.
        if (sentinel->m_next != otherFirst)
            _mergeSame(sentinel->m_next, otherFirst, sentinel, comp);
    }

    //! Rvalue overload of merge(); forwards to the lvalue overload taking a comparator.
    //! @param other The list whose elements are moved into this list.
    //! @param comp Comparison callable invoked as `comp(a, b)` on two elements.
    template <class Compare>
    void merge(IntrusiveList&& other, Compare comp)
    {
        merge(other, comp);
    }

    //! Merges all elements of @p other into this list, ordering with `std::less<ValueType>`.
    //! @param other The list whose elements are moved into this list.
    void merge(IntrusiveList& other)
    {
        merge(other, std::less<ValueType>());
    }

    //! Rvalue overload of merge(); forwards to the lvalue overload that orders with `std::less<ValueType>`.
    //! @param other The list whose elements are moved into this list.
    void merge(IntrusiveList&& other)
    {
        merge(other);
    }

    //! Moves every element of @p other into this list, immediately before @p pos.
    //!
    //! Does nothing if @p other is this list or is empty. @p pos must belong to this list (asserted).
    //! @param pos The position in this list to insert before.
    //! @param other The list to take all elements from; it is empty afterwards.
    void splice(const_iterator pos, IntrusiveList& other)
    {
        pos.assertOwner(this);

        const bool sameList = (this == std::addressof(other));
        if (!sameList && !other.empty())
            _splice(pos.m_where, other, other._head(), other._end(), other.m_size);
    }

    //! Rvalue overload; forwards to splice(const_iterator, IntrusiveList&).
    //! @param pos The position in this list to insert before.
    //! @param other The list to take all elements from.
    void splice(const_iterator pos, IntrusiveList&& other)
    {
        splice(pos, other);
    }

    //! Moves the single element at @p it from @p other into this list, immediately before @p pos.
    //!
    //! @p other may be this list. In that case nothing happens when @p pos is @p it or the position right after
    //! @p it, because the element is already where it would end up.
    //! @param pos The position in this list to insert before (asserted to belong to this list).
    //! @param other The list that currently holds the element.
    //! @param it Iterator to the element to move (asserted to belong to @p other and not be end()).
    void splice(const_iterator pos, IntrusiveList& other, iterator it)
    {
        pos.assertOwner(this);
        it.assertNotEnd();
        it.assertOwner(std::addressof(other));

        Link* const node = it.m_where;
        Link* const after = node->m_next;

        // Within one list, "before itself" and "before its own successor" both describe the current position.
        if (this == std::addressof(other) && (pos.m_where == node || pos.m_where == after))
            return;

        _splice(pos.m_where, other, node, after, 1);
    }

    //! Rvalue overload; forwards to splice(const_iterator, IntrusiveList&, iterator).
    //! @param pos The position in this list to insert before.
    //! @param other The list that currently holds the element.
    //! @param it Iterator to the element to move.
    void splice(const_iterator pos, IntrusiveList&& other, iterator it)
    {
        splice(pos, other, it);
    }

    //! Moves the elements in the range `[first, end)` from @p other into this list, immediately before @p pos.
    //!
    //! @p other may be this list. The behavior is undefined if @p pos lies within `[first, end)`; with assertions
    //! enabled this is checked by walking the range. Passing @p pos equal to @p end is allowed and does nothing,
    //! since the range already sits immediately before @p end.
    //!
    //! When the lists differ, the range is walked once to count the elements so both sizes can be adjusted.
    //! @param pos The position in this list to insert before (asserted to belong to this list).
    //! @param other The list that currently holds the range.
    //! @param first Iterator to the first element to move (asserted to belong to @p other).
    //! @param end Iterator one past the last element to move (asserted to belong to @p other).
    void splice(const_iterator pos, IntrusiveList& other, const_iterator first, const_iterator end)
    {
        pos.assertOwner(this);
        first.assertOwner(std::addressof(other));
        end.assertOwner(std::addressof(other));

        if (first == end)
            return;

        // Splicing [first, end) to just before `end` (only possible within a single list) is a no-op. It must be
        // filtered out here: the low-level _splice() requires `before != last_` and would otherwise detach the
        // range into a separate cycle. Compare the raw links because the iterators may have different owners.
        if (pos.m_where == end.m_where)
            return;

#if CARB_ASSERT_ENABLED
        if (pos.m_owner == first.m_owner)
        {
            // The behavior is undefined if pos is an iterator in the range [first, end); though we don't have an
            // efficient way of testing for that, so loop through and check
            for (const_iterator it = first; it != end; ++it)
                CARB_ASSERT(it != pos);
        }
#endif

        if (this != std::addressof(other))
        {
            // Different lists: count the moved elements so that both sizes stay accurate.
            const size_t range = static_cast<size_t>(std::distance(first, end));
            CARB_ASSERT(other.m_size >= range);
            other.m_size -= range;
            m_size += range;
        }

        _splice(pos.m_where, first.m_where, end.m_where);
    }

    //! Rvalue overload; forwards to splice(const_iterator, IntrusiveList&, const_iterator, const_iterator).
    //! @param pos The position in this list to insert before.
    //! @param other The list that currently holds the range.
    //! @param first Iterator to the first element to move.
    //! @param end Iterator one past the last element to move.
    void splice(const_iterator pos, IntrusiveList&& other, const_iterator first, const_iterator end)
    {
        splice(pos, other, first, end);
    }

    void reverse() noexcept
    {
        Link* end = _end();
        Link* n = end;

        for (;;)
        {
            Link* next = n->m_next;
            n->m_next = n->m_prev;
            n->m_prev = next;

            if (next == end)
                break;

            n = next;
        }
    }

    //! Sorts the list in place using @p comp.
    //!
    //! Delegates to the internal _sort() over all `m_size` nodes, starting at the head; elements are re-linked
    //! rather than copied.
    //! @param comp Comparison callable invoked as `comp(a, b)` on two elements.
    template <class Compare>
    void sort(Compare comp)
    {
        _sort(_end()->m_next, m_size, comp);
    }

    //! Sorts the list in place, ordering with `std::less<ValueType>`.
    void sort()
    {
        sort(std::less<ValueType>());
    }

private:
    // Sentinel link: m_next is the head, m_prev is the tail, and both point at the sentinel itself when empty.
    CARB_VIZ Link m_list;
    // Number of elements currently linked into the list.
    CARB_VIZ size_t m_size{ 0 };

    // Recursive merge sort over a run of `size` consecutive nodes starting at `first`.
    //
    // `first` is taken by reference and is updated to the node that begins the run once it is sorted.
    // Returns the node that follows the run (i.e. the node `size` steps after the original `first`).
    template <class Compare>
    static Link* _sort(Link*& first, size_t size, Compare comp)
    {
        switch (size)
        {
            case 0:
                // Empty run: the node after it is `first` itself.
                return first;
            case 1:
                // A single node is already sorted.
                return first->m_next;
            default:
                break;
        }
        // Sort each half; each call returns where its half ends, which is where the next one begins.
        auto mid = _sort(first, size >> 1, comp);
        const auto last_ = _sort(mid, size - (size >> 1), comp);
        // Merge the two sorted halves; the merged run may now begin with a node from the second half.
        first = _mergeSame(first, mid, last_, comp);
        return last_;
    }

    // Merges two adjacent sorted ranges of the same node chain, [first, mid) and [mid, last_), in place.
    //
    // Nodes are re-linked via _splice(); no element is copied. Both ranges must be non-empty, since `*first` and
    // `*mid` are compared immediately.
    // Returns the node that begins the merged range.
    template <class Compare>
    static Link* _mergeSame(Link* first, Link* mid, const Link* last_, Compare comp)
    {
        // Merge the sorted ranges [first, mid) and [mid, last)
        // Returns the new beginning of the range (which won't be `first` if it was spliced elsewhere)
        Link* newfirst;
        if (comp(_value(*mid), _value(*first)))
            // mid will be spliced to the front of the range
            newfirst = mid;
        else
        {
            // Establish comp(mid, first) by skipping over elements from the first range already in position
            newfirst = first;
            do
            {
                first = first->m_next;
                if (first == mid)
                    // The first range was exhausted without finding anything that must follow `mid`: already merged.
                    return newfirst;
            } while (!comp(_value(*mid), _value(*first)));
        }

        // process one run splice
        for (;;)
        {
            auto runStart = mid;
            // find the end of the run of elements we need to splice from the second range into the first
            do
            {
                mid = mid->m_next;
            } while (mid != last_ && comp(_value(*mid), _value(*first)));

            // [runStart, mid) goes before first
            _splice(first, runStart, mid);
            if (mid == last_)
                // The second range is exhausted; everything is in place.
                return newfirst;

            // Re-establish comp(mid, first) by skipping over elements from the first range already in position.
            do
            {
                first = first->m_next;
                if (first == mid)
                    // The first range is exhausted; the remainder of the second range is already in position.
                    return newfirst;
            } while (!comp(_value(*mid), _value(*first)));
        }
    }

    // Moves the `count` nodes in [first, last_) from `other` to immediately before `where` in this list, keeping
    // both lists' sizes accurate. The sizes are only touched when `other` is a different list.
    // Returns `last_`.
    Link* _splice(Link* const where, IntrusiveList& other, Link* const first, Link* const last_, size_t count)
    {
        const bool crossList = (this != std::addressof(other));
        if (crossList)
        {
            // The nodes change owner, so transfer the count as well.
            other.m_size -= count;
            m_size += count;
        }
        return _splice(where, first, last_);
    }

    // Low-level relink: detaches the node range [first, last_) from wherever it currently is and links it in
    // immediately before `before`. Sizes are not adjusted here.
    //
    // Requires `before`, `first` and `last_` to be three distinct nodes (asserted).
    // Returns `last_`.
    static Link* _splice(Link* const before, Link* const first, Link* const last_) noexcept
    {
        CARB_ASSERT(before != first && before != last_ && first != last_);
        // Forward (m_next) pointers first: close the gap left by the range, then hook the range in before `before`.
        Link* const firstPrev = first->m_prev;
        firstPrev->m_next = last_;
        Link* const lastPrev = last_->m_prev;
        lastPrev->m_next = before;
        Link* const beforePrev = before->m_prev;
        beforePrev->m_next = first;

        // Then the backward (m_prev) pointers, using the neighbors captured above.
        before->m_prev = lastPrev;
        last_->m_prev = firstPrev;
        first->m_prev = beforePrev;
        return last_;
    }

    // Returns the address of the link embedded in `value`, selected by the pointer-to-member template argument `U`.
    static Link* _link(T& value) noexcept
    {
        return std::addressof(value.*U);
    }

    // Inverse of _link(): recovers the element that embeds link `l`.
    //
    // `l` must be a link embedded in a `T` as member `U`. Passing the list's sentinel yields an invalid reference,
    // which is why iterators assert against end() before dereferencing.
    static T& _value(Link& l) noexcept
    {
        // Need to calculate the offset of our link member which will allow adjusting the pointer to where T is.
        // This will not work if T uses virtual inheritance. Also, offsetof() cannot be used because we have a pointer
        // to the member
        size_t offset = size_t(reinterpret_cast<uint8_t*>(&(((T*)0)->*U)));
        // Step back from the link's address by that offset to reach the start of the enclosing T.
        return *reinterpret_cast<T*>(reinterpret_cast<uint8_t*>(std::addressof(l)) - offset);
    }

    // Const overload of the inverse of _link(): recovers the element that embeds link `l`.
    //
    // `l` must be a link embedded in a `T` as member `U`; the list's sentinel is not such a link.
    static const T& _value(const Link& l) noexcept
    {
        // Need to calculate the offset of our link member which will allow adjusting the pointer to where T is.
        // This will not work if T uses virtual inheritance. Also, offsetof() cannot be used because we have a pointer
        // to the member
        size_t offset = size_t(reinterpret_cast<uint8_t*>(&(((T*)0)->*U)));
        // Step back from the link's address by that offset to reach the start of the enclosing T.
        return *reinterpret_cast<const T*>(reinterpret_cast<const uint8_t*>(std::addressof(l)) - offset);
    }

    // Returns the first link (the sentinel's m_next); this is the sentinel itself when the list is empty.
    // The const_cast lets const member functions obtain a mutable link pointer.
    constexpr Link* _head() const noexcept
    {
        return const_cast<Link*>(m_list.m_next);
    }

    // Returns the last link (the sentinel's m_prev); this is the sentinel itself when the list is empty.
    // The const_cast lets const member functions obtain a mutable link pointer.
    constexpr Link* _tail() const noexcept
    {
        return const_cast<Link*>(m_list.m_prev);
    }

    // Returns the address of the sentinel link, which marks the past-the-end position.
    // The const_cast lets const member functions obtain a mutable link pointer.
    constexpr Link* _end() const noexcept
    {
        return const_cast<Link*>(&m_list);
    }
};

} // namespace container

} // namespace carb