// omni/structuredlog/BinarySerializer.h
//
// File members: omni/structuredlog/BinarySerializer.h

// Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto. Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.
//
#pragma once

#include "../structuredlog/StringView.h"
#include "../../carb/extras/StringSafe.h"
#include "../../carb/Defines.h"

#include <stdint.h>

namespace omni
{
namespace structuredlog
{

/** Calculates the number of bytes a binary blob will occupy before it is written.
 *
 *  Each track*() overload mirrors the layout produced by the corresponding
 *  BlobWriter::copy() overload, including alignment padding and uint16_t length
 *  prefixes.  After every value has been tracked, getSize() returns the size of
 *  the buffer that a BlobWriter needs to hold them.
 */
class BinaryBlobSizeCalculator
{
public:
    /** Version identifier of this class (presumably of the blob layout it describes). */
    static constexpr uint32_t kVersion = 0;

    /** Retrieves the number of bytes tracked so far.
     *  @returns The total tracked size in bytes, including alignment padding.
     */
    size_t getSize() const
    {
        return m_counter;
    }

    /** Tracks a single primitive value (layout of BlobWriter::copy(T)).
     *  @param[in] v The value; only its type affects the size.
     */
    template <typename T>
    void track(T v)
    {
        CARB_UNUSED(v);
        static_assert(std::is_arithmetic<T>::value, "this is only valid for primitive types");
        m_counter = alignOffset<T>(m_counter);
        m_counter += sizeof(T);
    }

    /** Tracks a length-prefixed array of primitive values (layout of BlobWriter::copy(T*, uint16_t)).
     *  @param[in] v   The array; only its element type affects the size.
     *  @param[in] len The number of elements in the array.
     */
    template <typename T>
    void track(T* v, uint16_t len)
    {
        CARB_UNUSED(v);
        static_assert(std::is_arithmetic<T>::value, "this is only valid for primitive types");
        m_counter = alignOffset<uint16_t>(m_counter);
        m_counter += sizeof(uint16_t);

        // the writer emits no element data (and so no element alignment) for empty arrays.
        if (len > 0)
        {
            m_counter = alignOffset<T>(m_counter);
            m_counter += sizeof(T) * len;
        }
    }

    /** Tracks a length-prefixed, null terminated string (layout of BlobWriter::copy(const StringView&)).
     *
     *  The writer stores at most UINT16_MAX bytes (terminator included) for a
     *  string, truncating longer ones; the same limit is applied here so the
     *  calculated size matches what is actually written.
     *
     *  @param[in] v The string to track.
     */
    void track(const StringView& v)
    {
        m_counter = alignOffset<uint16_t>(m_counter);
        m_counter += sizeof(uint16_t);
        m_counter += CARB_MIN(size_t(UINT16_MAX), v.length() + 1);
    }

    /** @copydoc track(const StringView&) */
    void track(StringView& v)
    {
        track(static_cast<const StringView&>(v));
    }

    /** Tracks an array of strings with caller-supplied lengths
     *  (layout of BlobWriter::copy(const char* const*, const uint16_t*, uint16_t)).
     *  @param[in] v             The string array; only the lengths affect the size.
     *  @param[in] stringLengths The number of bytes stored for each string.
     *  @param[in] len           The number of strings.
     */
    void track(const char* const* v, const uint16_t* stringLengths, uint16_t len)
    {
        CARB_UNUSED(v);
        m_counter = alignOffset<uint16_t>(m_counter);
        m_counter += sizeof(uint16_t);
        for (uint16_t i = 0; i < len; i++)
        {
            m_counter = alignOffset<uint16_t>(m_counter);
            m_counter += sizeof(uint16_t);
            m_counter += stringLengths[i];
        }
    }

    /** Tracks an array of null terminated strings (layout of BlobWriter::copy(const char* const*, uint16_t)).
     *
     *  A nullptr entry contributes only its length prefix.  Each other string
     *  contributes its length plus the terminator, capped at UINT16_MAX bytes
     *  as the writer does.
     *
     *  @param[in] v   The string array.
     *  @param[in] len The number of strings in @p v.
     */
    void track(const char* const* v, uint16_t len)
    {
        m_counter = alignOffset<uint16_t>(m_counter);
        m_counter += sizeof(uint16_t);
        for (uint16_t i = 0; i < len; i++)
        {
            m_counter = alignOffset<uint16_t>(m_counter);
            m_counter += sizeof(uint16_t);
            if (v[i] != nullptr)
            {
                size_t size = strlen(v[i]) + 1;
                m_counter += CARB_MIN(size_t(UINT16_MAX), size);
            }
        }
    }

    /** Tracks a fixed-length array with no length prefix
     *  (layout of BlobWriter::copy(T*, uint16_t actualLen, uint16_t fixedLen)).
     *  @param[in] v   The array; only its element type affects the size.
     *  @param[in] len The fixed number of elements stored.
     */
    template <typename T>
    void trackFixed(T* v, uint16_t len)
    {
        CARB_UNUSED(v);
        static_assert(std::is_arithmetic<T>::value, "this is only valid for primitive types");
        m_counter = alignOffset<T>(m_counter);
        m_counter += sizeof(T) * len;
    }

    /** Rounds an offset up to the next multiple of sizeof(T).
     *
     *  The bit mask used here is only correct when sizeof(T) is a power of two.
     *
     *  @param[in] offset The offset to align.
     *  @returns The aligned offset (unchanged if already aligned).
     */
    template <typename T>
    static size_t alignOffset(size_t offset)
    {
        size_t misalign = offset & (sizeof(T) - 1);
        if (misalign != 0)
            offset += sizeof(T) - misalign;
        return offset;
    }

private:
    /** Running total of tracked bytes. */
    size_t m_counter = 0;
};

static constexpr bool kBlobWriterValidate = true;
static constexpr bool kBlobWriterNoValidate = false;

/** Default validation error handler for BlobWriter and BlobReader; discards the message.
 *
 *  This is an `inline` function rather than one inside an anonymous namespace.
 *  A header-level anonymous namespace gives every translation unit its own copy,
 *  which makes the default `onValidationError` template argument of
 *  BlobWriter/BlobReader name a different function in each translation unit.
 *  It can also trigger unused-function warnings wherever the header is included
 *  but the function is unused.
 *
 *  @param[in] s The validation error message; ignored.
 */
inline void ignoreBlobWriterValidationError(const char* s) noexcept
{
    CARB_UNUSED(s);
}

/** Signature of the callback that BlobWriter and BlobReader invoke with a
 *  human readable message when a validation check fails.
 */
using OnBlobWriterValidationErrorFunc = void (*)(const char* message);

/** Serializes primitive values, arrays and strings into a caller-provided buffer.
 *
 *  Every value is aligned to its own size relative to the start of the buffer.
 *  Variable-length arrays and strings are preceded by a uint16_t length prefix.
 *  BinaryBlobSizeCalculator computes the buffer size this layout requires.
 *
 *  @tparam validate          When true, every write is bounds checked and a
 *                            failure is reported through @p onValidationError
 *                            (the copy() call then returns false).  When false,
 *                            no bounds checks are made and the buffer must be
 *                            large enough for everything written.
 *  @tparam onValidationError Callback receiving a message when a validation check fails.
 */
template <bool validate = false, OnBlobWriterValidationErrorFunc onValidationError = ignoreBlobWriterValidationError>
class BlobWriter
{
public:
    /** Version identifier of this class (presumably of the blob layout it writes). */
    static constexpr uint32_t kVersion = 0;

    /** Constructor.
     *  @param[in] buffer The buffer to write into.  Must not be nullptr and must be
     *                    aligned to sizeof(void*) (both asserted).
     *  @param[in] bytes  The size of @p buffer in bytes.
     */
    BlobWriter(void* buffer, size_t bytes)
    {
        CARB_ASSERT(buffer != nullptr);
        CARB_ASSERT((uintptr_t(buffer) & (sizeof(void*) - 1)) == 0);
        m_buffer = static_cast<uint8_t*>(buffer);
        m_bufferLen = bytes;
    }

    /** Writes a single primitive value, aligned to sizeof(T).
     *  @param[in] v The value to write.
     *  @returns true on success; false if validation is enabled and the buffer is full.
     */
    template <typename T>
    bool copy(T v)
    {
        static_assert(std::is_arithmetic<T>::value, "this is only valid for primitive types");

        alignBuffer<T>();
        if (validate && m_bufferLen < m_written + sizeof(v))
        {
            outOfMemoryErrorMessage(sizeof(v));
            return false;
        }

        reinterpret_cast<T*>(m_buffer + m_written)[0] = v;
        m_written += sizeof(v);
        return true;
    }

    /** Writes an array of strings whose stored lengths are supplied by the caller.
     *
     *  Layout: uint16_t count, then for each string a uint16_t length followed by
     *  exactly that many bytes copied from the string (no terminator is added).
     *
     *  @param[in] v             The strings to write.  May be nullptr if @p len is 0.
     *  @param[in] stringLengths The number of bytes to store for each string.
     *  @param[in] len           The number of strings.
     *  @returns true on success; false if validation is enabled and the buffer is full.
     */
    bool copy(const char* const* v, const uint16_t* stringLengths, uint16_t len)
    {
        CARB_ASSERT(v != nullptr || len == 0);
        CARB_ASSERT(stringLengths != nullptr || len == 0);

        alignBuffer<uint16_t>();
        if (validate && m_bufferLen < m_written + sizeof(len))
        {
            outOfMemoryErrorMessage(sizeof(len));
            return false;
        }

        reinterpret_cast<uint16_t*>(m_buffer + m_written)[0] = len;
        m_written += sizeof(len);
        for (uint16_t i = 0; i < len; i++)
        {
            const uint16_t s = stringLengths[i];

            alignBuffer<uint16_t>();
            if (validate && m_bufferLen < m_written + sizeof(s) + s)
            {
                outOfMemoryErrorMessage(sizeof(s) + s);
                return false;
            }

            reinterpret_cast<uint16_t*>(m_buffer + m_written)[0] = s;
            m_written += sizeof(s);

            // memcpy() requires a valid source pointer even for a zero count.
            if (s > 0)
                memcpy(m_buffer + m_written, v[i], s);
            m_written += s;
        }
        return true;
    }

    /** Writes an array of null terminated strings.
     *
     *  Layout: uint16_t count, then for each string a uint16_t length (including
     *  the terminator) followed by the string bytes and terminator.  A nullptr
     *  entry is stored as a length of 0 with no data.
     *
     *  @param[in] v   The strings to write.  May be nullptr if @p len is 0.
     *  @param[in] len The number of strings.
     *  @returns true on success; false if validation is enabled and the buffer is full.
     */
    bool copy(const char* const* v, uint16_t len)
    {
        CARB_ASSERT(v != nullptr || len == 0);

        alignBuffer<uint16_t>();
        if (validate && m_bufferLen < m_written + sizeof(len))
        {
            outOfMemoryErrorMessage(sizeof(len));
            return false;
        }

        reinterpret_cast<uint16_t*>(m_buffer + m_written)[0] = len;
        m_written += sizeof(len);
        for (uint16_t i = 0; i < len; i++)
        {
            if (v[i] == nullptr)
            {
                alignBuffer<uint16_t>();
                if (validate && m_bufferLen < m_written + sizeof(uint16_t))
                {
                    outOfMemoryErrorMessage(sizeof(uint16_t));
                    return false;
                }

                reinterpret_cast<uint16_t*>(m_buffer + m_written)[0] = 0;
                m_written += sizeof(uint16_t);
                continue;
            }

            // strings longer than UINT16_MAX - 1 characters are silently
            // truncated.  The terminator is written separately so a truncated
            // copy is still null terminated, matching copy(const StringView&).
            const size_t size = strlen(v[i]) + 1;
            const uint16_t s = uint16_t(CARB_MIN(size_t(UINT16_MAX), size));

            alignBuffer<uint16_t>();
            if (validate && m_bufferLen < m_written + sizeof(s) + s)
            {
                outOfMemoryErrorMessage(sizeof(s) + s);
                return false;
            }

            reinterpret_cast<uint16_t*>(m_buffer + m_written)[0] = s;
            m_written += sizeof(s);

            memcpy(m_buffer + m_written, v[i], s - 1);
            m_written += s - 1;
            m_buffer[m_written++] = '\0';
        }
        return true;
    }

    /** Writes a length-prefixed array of primitive values.
     *
     *  Layout: uint16_t count, then (if non-empty) the elements aligned to sizeof(T).
     *
     *  @param[in] v   The elements.  May be nullptr if @p len is 0.
     *  @param[in] len The number of elements.
     *  @returns true on success; false if validation is enabled and the buffer is full.
     */
    template <typename T>
    bool copy(T* v, uint16_t len)
    {
        static_assert(std::is_arithmetic<T>::value, "this is only valid for primitive types");

        CARB_ASSERT(v != nullptr || len == 0);

        alignBuffer<uint16_t>();
        if (validate && m_bufferLen < m_written + sizeof(len))
        {
            outOfMemoryErrorMessage(sizeof(len));
            return false;
        }

        reinterpret_cast<uint16_t*>(m_buffer + m_written)[0] = len;
        m_written += sizeof(len);

        if (len == 0)
            return true;

        alignBuffer<T>();
        if (validate && m_bufferLen < m_written + sizeof(T) * len)
        {
            outOfMemoryErrorMessage(sizeof(T) * len);
            return false;
        }

        memcpy(m_buffer + m_written, v, sizeof(T) * len);
        m_written += sizeof(T) * len;
        return true;
    }

    /** Writes a string view as a length-prefixed, null terminated string.
     *
     *  The stored length includes the terminator and is capped at UINT16_MAX;
     *  longer strings are truncated.
     *
     *  @param[in] v The string to write.
     *  @returns true on success; false if validation is enabled and the buffer is full.
     */
    bool copy(const StringView& v)
    {
        uint16_t len = uint16_t(CARB_MIN(size_t(UINT16_MAX), v.length() + 1));
        alignBuffer<uint16_t>();
        if (validate && m_bufferLen < m_written + sizeof(len))
        {
            outOfMemoryErrorMessage(sizeof(len));
            return false;
        }

        reinterpret_cast<uint16_t*>(m_buffer + m_written)[0] = len;
        m_written += sizeof(len);

        if (len == 0)
            return true;

        if (validate && m_bufferLen < m_written + len)
        {
            outOfMemoryErrorMessage(len);
            return false;
        }

        // the string view may not be null terminated, so we need to write the
        // terminator separately
        if (len > 1)
        {
            memcpy(m_buffer + m_written, v.data(), len - 1);
            m_written += len - 1;
        }

        m_buffer[m_written++] = '\0';
        return true;
    }

    /** @copydoc copy(const StringView&) */
    bool copy(StringView& v)
    {
        return copy(static_cast<const StringView&>(v));
    }

    /** Writes a fixed-length array with no length prefix.
     *
     *  The first @p actualLen elements are copied from @p v and the remainder of
     *  the @p fixedLen slots are zero filled.
     *
     *  @param[in] v         The elements.  Must not be nullptr.
     *  @param[in] actualLen The number of elements to copy from @p v.
     *  @param[in] fixedLen  The number of element slots to write; must be >= @p actualLen.
     *  @returns true on success; false if validation is enabled and the buffer is full.
     */
    template <typename T>
    bool copy(T* v, uint16_t actualLen, uint16_t fixedLen)
    {
        static_assert(std::is_arithmetic<T>::value, "this is only valid for primitive types");
        CARB_ASSERT(v != nullptr);
        CARB_ASSERT(fixedLen >= actualLen);

        const size_t total = sizeof(T) * fixedLen;
        const size_t written = sizeof(T) * actualLen;

        alignBuffer<T>();
        if (validate && m_bufferLen < m_written + total)
        {
            outOfMemoryErrorMessage(total);
            return false;
        }

        // write out the actual values
        memcpy(m_buffer + m_written, v, written);

        // zero the padding at the end
        memset(m_buffer + m_written + written, 0, total - written);

        m_written += total;
        return true;
    }

    /** Advances the write position to the next multiple of sizeof(T).
     *
     *  In validation mode the skipped padding bytes are zeroed.  Only the part of
     *  the padding that lies inside the buffer is touched: this runs before the
     *  callers' bounds checks, so it must not write past the end itself.
     */
    template <typename T>
    void alignBuffer()
    {
        size_t next = BinaryBlobSizeCalculator::alignOffset<T>(m_written);

        // there's no strict requirement for the padding to be 0
        if (validate && m_written < m_bufferLen)
            memset(m_buffer + m_written, 0, CARB_MIN(next, m_bufferLen) - m_written);

        m_written = next;
    }

private:
    /** Formats and reports an out-of-space error through @p onValidationError.
     *  @param[in] size The number of bytes the failed write needed.
     */
    void outOfMemoryErrorMessage(size_t size)
    {
        char tmp[256];
        carb::extras::formatString(tmp, sizeof(tmp),
                                   "hit end of buffer while writing"
                                   " (tried to write %zu bytes, with %zd available)",
                                   size, m_bufferLen - m_written);
        onValidationError(tmp);
    }

    /** The destination buffer. */
    uint8_t* m_buffer = nullptr;

    /** Size of m_buffer in bytes. */
    size_t m_bufferLen = 0;

    /** Current write offset into m_buffer. */
    size_t m_written = 0;
};

static constexpr bool kBlobReaderValidate = true;
static constexpr bool kBlobReaderNoValidate = false;

/** Reads values back out of a blob laid out as BlobWriter writes it.
 *
 *  Array and string reads return pointers into the blob itself; no data is copied.
 *
 *  @tparam validate          When true, every read is bounds checked and a
 *                            failure is reported through @p onValidationError
 *                            (the read() call then returns false).  When false,
 *                            only the output-array capacity check in
 *                            read(const char**, uint16_t*, uint16_t) is performed.
 *  @tparam onValidationError Callback receiving a message when a validation check fails.
 */
template <bool validate = false, OnBlobWriterValidationErrorFunc onValidationError = ignoreBlobWriterValidationError>
class BlobReader
{
public:
    /** Version identifier of this class (presumably of the blob layout it reads). */
    static constexpr uint32_t kVersion = 0;

    /** Constructor.
     *  @param[in] blob     The blob to read.  May be nullptr only if @p blobSize is 0.
     *  @param[in] blobSize The size of @p blob in bytes.
     */
    BlobReader(const void* blob, size_t blobSize)
    {
        CARB_ASSERT(blob != nullptr || blobSize == 0);
        m_buffer = static_cast<const uint8_t*>(blob);
        m_bufferLen = blobSize;
    }

    /** Reads a single primitive value aligned to sizeof(T).
     *  @param[out] out Receives the value.  Must not be nullptr.
     *  @returns true on success; false if validation is enabled and the blob is too short.
     */
    template <typename T>
    bool read(T* out)
    {
        static_assert(std::is_arithmetic<T>::value, "this is only valid for primitive types");
        CARB_ASSERT(out != nullptr);

        alignBuffer<T>();
        if (validate && m_bufferLen < m_read + sizeof(T))
        {
            outOfMemoryErrorMessage(sizeof(T));
            return false;
        }

        *out = reinterpret_cast<const T*>(m_buffer + m_read)[0];
        m_read += sizeof(T);
        return true;
    }

    /** Reads an array of strings (the layout of the BlobWriter string-array copy() overloads).
     *
     *  Each out[i] points into the blob; an entry stored with length 0 is returned
     *  as nullptr.
     *
     *  If @p maxLen is 0 and the array is not empty, only the element count is
     *  reported through @p outLen and the read position stays on the length
     *  prefix, so the array can be read again by a later call.
     *
     *  @param[out] out    Receives up to @p maxLen string pointers.
     *  @param[out] outLen Receives the number of strings stored in the blob.  Must not be nullptr.
     *  @param[in]  maxLen The capacity of @p out.
     *  @returns true on success.  Returns false if the array holds more than
     *           @p maxLen strings (the read position then stays on the length
     *           prefix and @p out is not written).  Also returns false if
     *           validation is enabled and the blob is too short.
     */
    bool read(const char** out, uint16_t* outLen, uint16_t maxLen)
    {
        CARB_ASSERT(outLen != nullptr);
        uint16_t len = 0;

        alignBuffer<uint16_t>();
        if (validate && m_bufferLen < m_read + sizeof(uint16_t))
        {
            outOfMemoryErrorMessage(sizeof(uint16_t));
            return false;
        }

        len = reinterpret_cast<const uint16_t*>(m_buffer + m_read)[0];
        *outLen = len;
        if (maxLen == 0 && len != 0)
            return true;

        // the output array cannot hold every string; writing them all would run
        // past the end of it, so fail without consuming the array.
        if (len > maxLen)
        {
            if (validate)
            {
                char tmp[256];
                carb::extras::formatString(tmp, sizeof(tmp),
                                           "buffer is too small to read the data"
                                           " (length = %" PRIu16 ", needed = %" PRIu16 ")",
                                           maxLen, len);
                onValidationError(tmp);
            }
            return false;
        }

        CARB_ASSERT(out != nullptr || len == 0);
        m_read += sizeof(len);

        for (uint16_t i = 0; i < len; i++)
        {
            uint16_t s = 0;

            alignBuffer<uint16_t>();
            if (validate && m_bufferLen < m_read + sizeof(uint16_t))
            {
                outOfMemoryErrorMessage(sizeof(uint16_t));
                return false;
            }

            s = reinterpret_cast<const uint16_t*>(m_buffer + m_read)[0];
            m_read += sizeof(s);
            if (validate && m_bufferLen < m_read + s)
            {
                outOfMemoryErrorMessage(s);
                return false;
            }

            if (s == 0)
            {
                out[i] = nullptr;
                continue;
            }

            out[i] = reinterpret_cast<const char*>(m_buffer + m_read);

            m_read += s;
        }

        return true;
    }

    /** Reads a length-prefixed array (layout of BlobWriter::copy(T*, uint16_t)).
     *
     *  With T = char this also matches the layout of BlobWriter::copy(const StringView&).
     *
     *  @param[out] out    Receives a pointer into the blob, or nullptr for an empty array.
     *  @param[out] outLen Receives the number of elements.
     *  @returns true on success; false if validation is enabled and the blob is too short.
     */
    template <typename T>
    bool read(const T** out, uint16_t* outLen)
    {
        CARB_ASSERT(out != nullptr);
        CARB_ASSERT(outLen != nullptr);

        alignBuffer<uint16_t>();
        if (validate && m_bufferLen < m_read + sizeof(uint16_t))
        {
            outOfMemoryErrorMessage(sizeof(uint16_t));
            return false;
        }

        *outLen = reinterpret_cast<const uint16_t*>(m_buffer + m_read)[0];
        m_read += sizeof(*outLen);

        if (*outLen == 0)
        {
            *out = nullptr;
            return true;
        }

        // the bounds check must be done in bytes, not elements.
        const size_t bytes = size_t(*outLen) * sizeof(T);

        alignBuffer<T>();
        if (validate && m_bufferLen < m_read + bytes)
        {
            outOfMemoryErrorMessage(bytes);
            return false;
        }

        *out = reinterpret_cast<const T*>(m_buffer + m_read);
        m_read += bytes;
        return true;
    }

    /** Reads a fixed-length array with no length prefix
     *  (layout of BlobWriter::copy(T*, uint16_t actualLen, uint16_t fixedLen)).
     *  @param[out] out      Receives a pointer into the blob.
     *  @param[in]  fixedLen The number of elements stored.
     *  @returns true on success; false if validation is enabled and the blob is too short.
     */
    template <typename T>
    bool read(const T** out, uint16_t fixedLen)
    {
        CARB_ASSERT(out != nullptr);

        // the bounds check must be done in bytes, not elements.
        const size_t bytes = size_t(fixedLen) * sizeof(T);

        alignBuffer<T>();
        if (validate && m_bufferLen < m_read + bytes)
        {
            outOfMemoryErrorMessage(bytes);
            return false;
        }

        *out = reinterpret_cast<const T*>(m_buffer + m_read);
        m_read += bytes;
        return true;
    }

    /** Advances the read position to the next multiple of sizeof(T). */
    template <typename T>
    void alignBuffer()
    {
        m_read = BinaryBlobSizeCalculator::alignOffset<T>(m_read);
    }

protected:
    /** Formats and reports an out-of-data error through @p onValidationError.
     *  @param[in] size The number of bytes the failed read needed.
     */
    void outOfMemoryErrorMessage(size_t size)
    {
        char tmp[256];
        carb::extras::formatString(tmp, sizeof(tmp),
                                   "hit end of buffer while reading"
                                   " (tried to read %zu bytes, with %zd available)",
                                   size, m_bufferLen - m_read);
        onValidationError(tmp);
    }

    /** The blob being read. */
    const uint8_t* m_buffer = nullptr;

    /** Size of m_buffer in bytes. */
    size_t m_bufferLen = 0;

    /** Current read offset into m_buffer. */
    size_t m_read = 0;
};

} // namespace structuredlog
} // namespace omni