carb/memory/ArenaAllocator.h

File members: carb/memory/ArenaAllocator.h

// Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto. Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.
//

#pragma once
#include "../Defines.h"

#include <memory>

namespace carb
{
namespace memory
{

//! An allocator that hands out memory from a caller-supplied buffer (the "arena") and falls back to another
//! allocator when the arena cannot satisfy a request.
//!
//! Arena memory is carved out sequentially. Only the most recent arena allocation can be reclaimed by
//! deallocate(); deallocating any other arena pointer is a no-op. Requests that do not fit in the remaining arena
//! space are forwarded to the fallback allocator.
//!
//! @warning Copy-constructing an ArenaAllocator transfers use of the arena to the new object: the source (even though
//! it is passed as `const`) can no longer allocate from the arena or reclaim its last arena allocation. Allocators
//! created by the rebinding constructor never allocate from the arena.
//!
//! @tparam T The type of object allocated.
//! @tparam FallbackAllocator The allocator used when the arena is absent or exhausted.
template <class T, class FallbackAllocator = std::allocator<T>>
class ArenaAllocator
{
public:
    using pointer = typename std::allocator_traits<FallbackAllocator>::pointer;
    using const_pointer = typename std::allocator_traits<FallbackAllocator>::const_pointer;
    using void_pointer = typename std::allocator_traits<FallbackAllocator>::void_pointer;
    using const_void_pointer = typename std::allocator_traits<FallbackAllocator>::const_void_pointer;
    using value_type = typename std::allocator_traits<FallbackAllocator>::value_type;
    using size_type = typename std::allocator_traits<FallbackAllocator>::size_type;
    using difference_type = typename std::allocator_traits<FallbackAllocator>::difference_type;

    //! Rebinds this allocator type to a different value type.
    //!
    //! std::allocator_traits is used (rather than a nested `FallbackAllocator::rebind`) because std::allocator no
    //! longer provides `rebind` as of C++20; allocator_traits still honors a nested `rebind` when one exists.
    template <class U>
    struct rebind
    {
        using other =
            ArenaAllocator<U, typename std::allocator_traits<FallbackAllocator>::template rebind_alloc<U>>;
    };

    //! Constructs an allocator with no arena and a value-initialized fallback allocator.
    //! All allocations go to the fallback allocator.
    ArenaAllocator() : m_members(ValueInitFirst{}, nullptr), m_current(nullptr), m_end(nullptr)
    {
    }

    //! Constructs an allocator with no arena that forwards all allocations to a copy of @p fallback.
    //! @param fallback The fallback allocator to copy.
    explicit ArenaAllocator(const FallbackAllocator& fallback)
        : m_members(InitBoth{}, fallback, nullptr), m_current(nullptr), m_end(nullptr)
    {
    }

    //! Constructs an allocator that allocates from the arena `[begin, end)` before using the fallback allocator.
    //!
    //! The first allocation position is @p begin rounded up to `alignof(value_type)`.
    //! @param begin Start of the arena buffer. The caller keeps ownership and must keep it alive while in use.
    //! @param end One past the last byte of the arena buffer.
    //! @param fallback The fallback allocator to copy.
    ArenaAllocator(void* begin, void* end, const FallbackAllocator& fallback = FallbackAllocator())
        : m_members(InitBoth{}, fallback, static_cast<uint8_t*>(begin)),
          m_current(alignForward(m_members.second)),
          m_end(static_cast<uint8_t*>(end))
    {
    }

    //! Move-constructs from @p other, taking over its arena state. Afterward @p other can no longer allocate from
    //! the arena.
    //!
    //! Allocator requirements forbid allocator moves from throwing, so this is `noexcept`.
    ArenaAllocator(ArenaAllocator&& other) noexcept
        : m_members(InitBoth{}, std::move(other.m_members.first()), other.m_members.second),
          m_current(other.m_current),
          m_end(other.m_end)
    {
        // Prevent `other` from allocating memory from the arena. By adding 1 we put it past the end which prevents
        // other->deallocate() from reclaiming the last allocation.
        other.m_current = other.m_end + 1;
    }

    //! Copy-constructs from @p other, taking over its arena state.
    //!
    //! Although @p other is `const`, its (mutable) arena cursor is modified so that it can no longer allocate from
    //! the arena or reclaim its last arena allocation.
    ArenaAllocator(const ArenaAllocator& other)
        : m_members(InitBoth{}, other.m_members.first(), other.m_members.second),
          m_current(other.m_current),
          m_end(other.m_end)
    {
        // Prevent `other` from allocating memory from the arena. By adding 1 we put it past the end which prevents
        // other->deallocate() from reclaiming the last allocation.
        other.m_current = other.m_end + 1;
    }

    //! Constructs a rebound copy of @p other.
    //!
    //! The new allocator remembers the arena bounds, so arena pointers passed to deallocate() are recognized and
    //! are not handed to the fallback allocator. It never allocates from the arena itself. @p other is unchanged.
    template <class U, class UFallbackAllocator>
    ArenaAllocator(const ArenaAllocator<U, UFallbackAllocator>& other)
        : m_members(InitBoth{}, other.m_members.first(), other.m_members.second),
          m_current(other.m_end + 1),
          m_end(other.m_end)
    {
        // m_current is explicitly assigned to `other.m_end + 1` to prevent further allocations from the arena from
        // *this and to prevent this->deallocate() from reclaiming the last allocation.
    }

    //! Allocates storage for @p n objects of `value_type`.
    //!
    //! @param n The number of objects to allocate storage for.
    //! @returns A pointer into the arena if enough arena space remains; otherwise the result of the fallback
    //!   allocator's `allocate(n)`, which may throw.
    pointer allocate(size_type n = 1)
    {
        // Work in element counts rather than forming `m_current + sizeof(value_type) * n`. That product can
        // overflow size_t for very large n and wrap to a small value, which would hand out an undersized arena
        // block. `m_current > m_end` means arena use is disabled (see the copy/move constructors) or that aligning
        // the start pushed it past the end.
        if (m_current <= m_end)
        {
            const size_t availableBytes = static_cast<size_t>(m_end - m_current);
            if (n <= availableBytes / sizeof(value_type))
            {
                pointer p = reinterpret_cast<pointer>(m_current);
                m_current += sizeof(value_type) * n;
                return p;
            }
        }
        return m_members.first().allocate(n);
    }

    //! Releases storage previously obtained from allocate().
    //!
    //! If @p in lies in the arena, the storage is reclaimed only when it is the most recent arena allocation.
    //! Otherwise this is a no-op. Pointers outside the arena are passed to the fallback allocator's `deallocate`.
    //! @param in The pointer returned by allocate().
    //! @param n The same count that was passed to allocate().
    void deallocate(pointer in, size_type n = 1)
    {
        uint8_t* p = reinterpret_cast<uint8_t*>(in);
        if (p >= begin() && p < end())
        {
            // Only the top of the arena "stack" can be given back.
            if ((p + (sizeof(value_type) * n)) == m_current)
                m_current -= (sizeof(value_type) * n);
        }
        else
            m_members.first().deallocate(in, n);
    }

private:
    //! Returns the (unaligned) start of the arena, or nullptr if there is no arena.
    uint8_t* begin() const noexcept
    {
        return m_members.second;
    }
    //! Returns one past the end of the arena, or nullptr if there is no arena.
    uint8_t* end() const noexcept
    {
        return m_end;
    }

    //! Rounds @p p up to the next multiple of `alignof(value_type)`.
    //!
    //! The result is computed by offsetting the original pointer, not by casting an integer back to a pointer.
    static uint8_t* alignForward(void* p)
    {
        uint8_t* out = reinterpret_cast<uint8_t*>(p);
        constexpr static size_t align = alignof(value_type);
        // alignof() is always a power of two, so masking with ~(align - 1) rounds down to a multiple of align.
        const size_t address = reinterpret_cast<size_t>(out);
        const size_t aligned = (address + (align - 1)) & ~(align - 1);
        return out + (aligned - address);
    }

    template <class U, class UFallbackAllocator>
    friend class ArenaAllocator;

    // The free equality operator compares private state, so it must be a friend. Its template parameter names
    // differ from T to avoid shadowing the class template parameter.
    template <class T1, class T2, class Allocator1, class Allocator2>
    friend bool operator==(const ArenaAllocator<T1, Allocator1>& lhs, const ArenaAllocator<T2, Allocator2>& rhs);

    // first: the fallback allocator; second: the (unaligned) arena begin pointer.
    mutable EmptyMemberPair<FallbackAllocator, uint8_t* /*begin*/> m_members;
    // Next free arena byte. Mutable so that copy construction can disable arena use on a const source.
    mutable uint8_t* m_current;
    // One past the end of the arena.
    mutable uint8_t* m_end;
};

//! Compares two ArenaAllocator objects for equality.
//!
//! Two allocators compare equal when they have the same arena start address and their fallback allocators compare
//! equal.
//! @param lhs The first allocator.
//! @param rhs The second allocator.
//! @returns `true` if both the arena start and the fallback allocators match; `false` otherwise.
template <class T, class U, class Allocator1, class Allocator2>
bool operator==(const ArenaAllocator<T, Allocator1>& lhs, const ArenaAllocator<U, Allocator2>& rhs)
{
    // The arena start is stored as uint8_t* regardless of T/U, so the pointers compare directly without casts.
    return lhs.m_members.second == rhs.m_members.second && lhs.m_members.first() == rhs.m_members.first();
}

//! Compares two ArenaAllocator objects for inequality.
//! @param lhs The first allocator.
//! @param rhs The second allocator.
//! @returns The logical negation of `lhs == rhs`.
template <class T, class U, class Allocator1, class Allocator2>
bool operator!=(const ArenaAllocator<T, Allocator1>& lhs, const ArenaAllocator<U, Allocator2>& rhs)
{
    const bool equal = (lhs == rhs);
    return !equal;
}

} // namespace memory
} // namespace carb