// MulDiv.h
// Fully qualified name: carb/math/MulDiv.h
// File members: carb/math/MulDiv.h
// Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto. Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.
//
#pragma once
#include "../cpp/Optional.h"
#if CARB_COMPILER_GNUC || CARB_TOOLCHAIN_CLANG
// GCC/Clang: nothing to declare here; the helpers below use compiler builtins and inline assembly instead.
#elif CARB_COMPILER_MSC
// MSVC: declare the compiler intrinsics used by the helpers below and request that they be expanded inline.
extern "C"
{
# if CARB_X86_64
// Add-with-carry, full 64x64->128 multiply, and 128/64 divide are all available as intrinsics on x86-64.
unsigned char _addcarry_u64(unsigned char, unsigned __int64, unsigned __int64, unsigned __int64*);
unsigned __int64 _umul128(unsigned __int64, unsigned __int64, unsigned __int64*);
unsigned __int64 _udiv128(unsigned __int64, unsigned __int64, unsigned __int64, unsigned __int64*);
# pragma intrinsic(_addcarry_u64)
# pragma intrinsic(_umul128)
# pragma intrinsic(_udiv128)
# elif CARB_AARCH64
// Only the high half of a 64x64 product has an intrinsic here; add-with-carry and 128-bit division are
// implemented in portable code in the helpers below.
unsigned __int64 __umulh(unsigned __int64, unsigned __int64);
# pragma intrinsic(__umulh)
# else
CARB_UNSUPPORTED_ARCHITECTURE();
# endif
}
#else
CARB_UNSUPPORTED_PLATFORM();
#endif
namespace carb
{
namespace math
{
namespace detail
{
/// Add-with-carry: computes `addend1 + addend2 + cf`, stores the low 64 bits in @p sum, and returns the carry-out.
/// @param cf Incoming carry. NOTE(review): presumably 0 or 1; the callers in this file always pass 0.
/// @param addend1 First 64-bit addend.
/// @param addend2 Second 64-bit addend.
/// @param sum Receives the low 64 bits of the total.
/// @returns 1 if the addition carried out of bit 63, otherwise 0.
inline uint8_t adc(uint8_t cf, uint64_t addend1, uint64_t addend2, uint64_t& sum) noexcept
{
#if CARB_COMPILER_GNUC || CARB_TOOLCHAIN_CLANG
    // Each builtin reports whether its own addition wrapped; the overall carry is set if either did.
    bool const carryFromAddends = __builtin_add_overflow(addend1, addend2, &sum);
    bool const carryFromIncoming = __builtin_add_overflow(sum, cf, &sum);
    return static_cast<uint8_t>(carryFromAddends || carryFromIncoming);
#elif CARB_COMPILER_MSC
# if CARB_X86_64
    return _addcarry_u64(cf, addend1, addend2, &sum);
# elif CARB_AARCH64
    // No carry-flag intrinsic is available here. Detect wraparound by comparison instead: an unsigned sum that
    // ends up smaller than one of its operands must have wrapped.
    uint64_t const partial = addend1 + addend2;
    sum = partial + cf;
    return static_cast<uint8_t>((partial < addend1) || (sum < partial));
# else
    CARB_UNSUPPORTED_ARCHITECTURE();
# endif
#else
    CARB_UNSUPPORTED_COMPILER();
#endif
}
/// Counts the leading zero bits of a 64-bit value.
/// @param val The value to inspect.
/// @returns The number of zero bits above the most significant set bit, or 64 if @p val is zero.
inline unsigned clz(uint64_t val) noexcept
{
#if CARB_COMPILER_MSC
    // _BitScanReverse64 reports the index of the highest set bit and returns zero when no bit is set.
    unsigned long highestSetBit;
    if (!_BitScanReverse64(&highestSetBit, val))
        return 64u;
    return 63u - unsigned(highestSetBit);
#elif CARB_COMPILER_GNUC
    // Zero is handled separately; __builtin_clzll does not define a result for it.
    if (val == 0)
        return 64u;
    return static_cast<unsigned>(__builtin_clzll(val));
#else
    CARB_UNSUPPORTED_COMPILER();
#endif
}
/// Full 64x64 -> 128-bit unsigned multiplication.
/// @param multiplier First factor.
/// @param multiplicand Second factor.
/// @param productHigh Receives the upper 64 bits of the 128-bit product.
/// @returns The lower 64 bits of the 128-bit product.
inline uint64_t umul128(uint64_t multiplier, uint64_t multiplicand, uint64_t& productHigh) noexcept
{
#if CARB_COMPILER_GNUC || CARB_TOOLCHAIN_CLANG
# if CARB_X86_64
// `mulq` multiplies RAX by its operand and leaves the 128-bit product in RDX:RAX. The constraints load multiplier
// into RAX and read productHigh from RDX and the low half from RAX.
uint64_t low;
asm("mulq %[multiplicand]"
: "=d"(productHigh), "=a"(low)
: "a"(multiplier), [multiplicand] "rm"(multiplicand)
: "cc" /* flags */
);
return low;
# elif CARB_AARCH64
// Combining these instructions into one asm block usually results in a runtime error since the compiler uses the x0
// register both for the multiplier as well as the result from the umulh instruction.
// `umulh` produces the upper 64 bits of the product and `mul` the lower 64 bits.
uint64_t low;
asm("umulh %[high], %[u], %[v]" : [high] "=r"(productHigh) : [u] "r"(multiplier), [v] "r"(multiplicand));
asm("mul %[low], %[u], %[v]" : [low] "=r"(low) : [u] "r"(multiplier), [v] "r"(multiplicand));
return low;
# else
CARB_UNSUPPORTED_ARCHITECTURE();
# endif
#elif CARB_COMPILER_MSC
# if CARB_X86_64
// _umul128 returns the low half and writes the high half through the pointer.
return _umul128(multiplier, multiplicand, &productHigh);
# elif CARB_AARCH64
// High half from the intrinsic; the low half is simply the product modulo 2^64.
productHigh = __umulh(multiplier, multiplicand);
return multiplier * multiplicand;
# else
CARB_UNSUPPORTED_ARCHITECTURE();
# endif
#else
CARB_UNSUPPORTED_COMPILER();
#endif
}
/// Divides the 128-bit unsigned value `dividendHigh:dividendLow` by @p divisor.
/// @param dividendHigh Upper 64 bits of the dividend. It must be less than @p divisor so the quotient fits in 64 bits.
///   Only the AArch64 path asserts this, so callers must guarantee it; mulDiv() checks it before calling.
/// @param dividendLow Lower 64 bits of the dividend.
/// @param divisor Must be non-zero. Again, this is asserted only on the AArch64 path.
/// @param remainder If non-null, receives the 64-bit remainder.
/// @returns The 64-bit quotient.
inline uint64_t udiv128(uint64_t dividendHigh, uint64_t dividendLow, uint64_t divisor, uint64_t* remainder = nullptr) noexcept
{
#if CARB_AARCH64
// aarch64 doesn't have 128-bit divide, so we do Knuth's long division (Algorithm D)
// Values are treated as base-2^32 digits. The divisor is two digits: (divisor >> 32) and unsigned(divisor).
// The quotient is produced one 32-bit digit at a time, q1 first and then q0.
CARB_ASSERT(divisor != 0); // divide-by-zero
CARB_ASSERT(dividendHigh < divisor); // overflow
// Normalize: shift left until the divisor's top bit is set, shifting the dividend by the same amount.
// The `if (s)` guard avoids `dividendLow >> 64`, which would be undefined behavior when s == 0.
unsigned s = clz(divisor);
if (s)
{
divisor <<= s;
dividendHigh <<= s;
dividendHigh |= (dividendLow >> (64 - s));
dividendLow <<= s;
}
// High quotient
// Estimate the digit by dividing by the divisor's high 32 bits. Then decrement the estimate while it is too large,
// i.e. it does not fit in 32 bits, or qhat * (divisor's low digit) exceeds the remainder extended by the next
// dividend digit. The loop stops once rhat reaches 2^32, because `rhat << 32` would no longer fit.
uint64_t qhat = dividendHigh / unsigned(divisor >> 32);
uint64_t rhat = dividendHigh % unsigned(divisor >> 32);
while (unsigned(qhat >> 32) != 0 ||
uint64_t(unsigned(qhat)) * unsigned(divisor) > ((rhat << 32) | unsigned(dividendLow >> 32)))
{
--qhat;
rhat += unsigned(divisor >> 32);
if (unsigned(rhat >> 32))
break;
}
unsigned q1 = unsigned(qhat);
// Subtract q1 * divisor from the top three dividend digits. The true difference is less than the divisor, so
// wraparound in the intermediate 64-bit terms cancels out.
uint64_t uhat = ((dividendHigh << 32) | unsigned(dividendLow >> 32)) - q1 * divisor;
// Low quotient: the same estimate-and-correct step applied to the partial remainder.
qhat = uhat / unsigned(divisor >> 32);
rhat = uhat % unsigned(divisor >> 32);
while (unsigned(qhat >> 32) != 0 ||
uint64_t(unsigned(qhat)) * unsigned(divisor) > ((rhat << 32) | unsigned(dividendLow)))
{
--qhat;
rhat += unsigned(divisor >> 32);
if (unsigned(rhat >> 32))
break;
}
unsigned q0 = unsigned(qhat);
// Remainder: subtract q0 * divisor from the last partial dividend, then undo the normalization shift.
if (remainder)
*remainder = (((uhat << 32) | unsigned(dividendLow)) - q0 * divisor) >> s;
return uint64_t(q1) << 32 | q0;
#elif CARB_X86_64
# if CARB_COMPILER_GNUC || CARB_TOOLCHAIN_CLANG
// `divq` divides RDX:RAX by its operand. The constraints tie dividendHigh to RDX ("1") and dividendLow to RAX ("0"),
// and read the quotient from RAX and the remainder from RDX.
uint64_t quotient, rem;
asm("divq %4" : "=a"(quotient), "=d"(rem) : "1"(dividendHigh), "0"(dividendLow), "rm"(divisor) : "cc" /*flags*/);
if (remainder)
*remainder = rem;
return quotient;
# elif CARB_COMPILER_MSC
uint64_t rem;
uint64_t quotient = _udiv128(dividendHigh, dividendLow, divisor, &rem);
if (remainder)
*remainder = rem;
return quotient;
# else
CARB_UNSUPPORTED_COMPILER();
# endif
#else
CARB_UNSUPPORTED_ARCHITECTURE();
#endif
}
} // namespace detail
/// Rounding policy for mulDiv() that truncates the quotient (rounds toward zero).
///
/// A rounding policy is invoked as `policy(high, low, divisor)` on the 128-bit dividend `high:low` before that
/// dividend is divided by `divisor`. It may adjust the dividend in place so that the subsequent truncating division
/// yields the desired rounding.
struct round_toward_zero_t
{
    /// Leaves the dividend untouched, because integer division already truncates toward zero.
    /// Declared `const` because the policy is stateless; this lets it be invoked through a const object or reference.
    void operator()(uint64_t&, uint64_t&, uint64_t) const noexcept
    {
        // nothing needed; integer math already rounds towards zero
    }
};
/// Rounding policy instance that truncates the quotient; the mulDiv() overloads that take no policy use it.
CARB_WEAKLINK round_toward_zero_t round_toward_zero;
/// Alias of round_toward_zero (same type). For unsigned operands, truncation is the same as flooring.
/// NOTE(review): the signed mulDiv() overload applies the policy to the operands' magnitudes. For a negative result
/// this therefore rounds toward zero, not toward negative infinity.
CARB_WEAKLINK round_toward_zero_t round_floor;
/// Rounding policy for mulDiv() that rounds any inexact quotient up (away from zero for an unsigned magnitude).
struct round_away_from_zero_t
{
    /// Adds `divisor - 1` to the 128-bit dividend `high:low`, so the following truncating division rounds up
    /// whenever the remainder is non-zero.
    /// Expects a non-zero @p divisor; mulDiv() rejects a zero divisor before invoking the policy.
    /// Declared `const` because the policy is stateless.
    void operator()(uint64_t& high, uint64_t& low, uint64_t divisor) const noexcept
    {
        // Add into the low word and propagate its carry-out into the high word.
        high += detail::adc(0, low, divisor - 1, low);
    }
};
/// Rounding policy instance that rounds any inexact quotient away from zero.
CARB_WEAKLINK round_away_from_zero_t round_away_from_zero;
/// Alias of round_away_from_zero (same type). For unsigned operands this is the same as ceiling.
/// NOTE(review): the signed mulDiv() overload applies the policy to the operands' magnitudes. For a negative result
/// this therefore rounds toward negative infinity rather than toward positive infinity.
CARB_WEAKLINK round_away_from_zero_t round_ceil;
/// Rounding policy for mulDiv() that rounds to the nearest integer quotient.
struct round_nearest_neighbor_t
{
    /// Adds `divisor / 2` to the 128-bit dividend `high:low`, so the following truncating division rounds to the
    /// nearest quotient. An exact half, which is only possible with an even divisor, rounds up.
    /// Declared `const` because the policy is stateless.
    void operator()(uint64_t& high, uint64_t& low, uint64_t divisor) const noexcept
    {
        // Add into the low word and propagate its carry-out into the high word.
        high += detail::adc(0, low, divisor / 2, low);
    }
};
/// Rounding policy instance that rounds to the nearest quotient; exact halves round away from zero.
CARB_WEAKLINK round_nearest_neighbor_t round_nearest_neighbor;
/// Computes `num * multiplier / divisor` for unsigned operands without losing precision in the product.
///
/// The product is formed in 128 bits. The rounding policy is then given a chance to bias that 128-bit dividend,
/// and the result is produced by a 128-by-64-bit division.
/// @param round Rounding policy, invoked as `round(high, low, divisor)` on the 128-bit dividend.
/// @param num First factor.
/// @param multiplier Second factor.
/// @param divisor Divisor.
/// @returns The rounded quotient, or an empty optional if @p divisor is zero or the quotient does not fit in 64 bits.
template <class RoundPolicy>
[[nodiscard]] cpp::optional<uint64_t> mulDiv(RoundPolicy round, uint64_t num, uint64_t multiplier, uint64_t divisor) noexcept
{
    // A zero divisor has no meaningful result.
    CARB_UNLIKELY_IF(!divisor)
    {
        return cpp::nullopt;
    }

    // Keep the full 128-bit product so nothing is lost before the division.
    uint64_t productHigh;
    uint64_t productLow = detail::umul128(num, multiplier, productHigh);

    // Let the policy adjust the dividend so that truncating division gives the requested rounding.
    round(productHigh, productLow, divisor);

    // The quotient only fits in 64 bits while the high word stays below the divisor.
    if (productHigh < divisor)
        return cpp::make_optional(detail::udiv128(productHigh, productLow, divisor));
    return cpp::nullopt;
}
/// Computes `num * multiplier / divisor` for signed operands using a 128-bit intermediate product.
///
/// The sign of the result comes from the operands' signs, and the unsigned overload does the arithmetic on their
/// magnitudes. Because @p round is applied to the magnitude, rounding is symmetric about zero. For example,
/// round_floor (an alias of round_toward_zero) truncates a negative result toward zero.
/// @param round Rounding policy applied to the magnitude of the result.
/// @param num First factor.
/// @param multiplier Second factor.
/// @param divisor Divisor.
/// @returns The signed result, or an empty optional if @p divisor is zero or the result does not fit in `int64_t`.
template <class RoundPolicy>
[[nodiscard]] cpp::optional<int64_t> mulDiv(RoundPolicy round, int64_t num, int64_t multiplier, int64_t divisor) noexcept
{
    // An odd number of negative operands gives a negative result; XOR combines the sign bits.
    bool const negative = (num ^ multiplier ^ divisor) < 0;
    // Negate in unsigned arithmetic. `-val` overflows (undefined behavior) for INT64_MIN, whereas unsigned
    // subtraction is well-defined modulo 2^64 and yields INT64_MIN's magnitude, 2^63.
    auto const absolute = [](int64_t val) noexcept { return val >= 0 ? uint64_t(val) : uint64_t(0) - uint64_t(val); };
    auto uresult = mulDiv(round, absolute(num), absolute(multiplier), absolute(divisor));
    if (!uresult)
        return cpp::nullopt;

    uint64_t const magnitude = uresult.value();
    // This equals INT64_MAX. A negative result may have a magnitude one larger than this (INT64_MIN).
    uint64_t const maxPositiveMagnitude = (uint64_t(1) << 63) - 1;
    if (!negative || magnitude == 0)
    {
        if (magnitude > maxPositiveMagnitude)
            return cpp::nullopt; // overflow
        return cpp::make_optional(int64_t(magnitude));
    }
    if (magnitude > maxPositiveMagnitude + 1)
        return cpp::nullopt; // overflow
    // Negate via (magnitude - 1) so that a magnitude of exactly 2^63 yields INT64_MIN without signed overflow.
    return cpp::make_optional(-int64_t(magnitude - 1) - 1);
}
/// Computes `(num * multiplier) / divisor` in floating point.
///
/// The rounding policy exists only for signature symmetry with the integer overloads and is ignored. No overflow
/// check is performed on the floating-point result.
/// @returns The result, or an empty optional if @p divisor compares equal to zero.
template <class RoundPolicy>
[[nodiscard]] cpp::optional<double> mulDiv(RoundPolicy /*ignored*/, double num, double multiplier, double divisor) noexcept
{
    if (divisor != 0.)
        return cpp::make_optional((num * multiplier) / divisor);
    return cpp::nullopt;
}
/// Computes `num * multiplier / divisor` for unsigned operands, truncating the quotient.
/// @returns The quotient, or an empty optional if @p divisor is zero or the quotient does not fit in `uint64_t`.
[[nodiscard]] inline cpp::optional<uint64_t> mulDiv(uint64_t num, uint64_t multiplier, uint64_t divisor) noexcept
{
    return mulDiv(round_toward_zero, num, multiplier, divisor);
}

/// Computes `num * multiplier / divisor` for signed operands, truncating the result toward zero.
/// @returns The result, or an empty optional if @p divisor is zero or the result does not fit in `int64_t`.
[[nodiscard]] inline cpp::optional<int64_t> mulDiv(int64_t num, int64_t multiplier, int64_t divisor) noexcept
{
    return mulDiv(round_toward_zero, num, multiplier, divisor);
}

/// Computes `(num * multiplier) / divisor` in floating point.
/// @returns The result, or an empty optional if @p divisor is zero.
[[nodiscard]] inline cpp::optional<double> mulDiv(double num, double multiplier, double divisor) noexcept
{
    // The floating-point overload ignores its policy argument; round_toward_zero merely fills that slot.
    return mulDiv(round_toward_zero, num, multiplier, divisor);
}
} // namespace math
} // namespace carb