一呼百應, "one call, a hundred responses"
Loading...
Searching...
No Matches
numerics.hpp
Go to the documentation of this file.
1#pragma once
2
11#include <bit>
12#include <cmath>
13#include <cstdint>
14#include <type_traits>
15#include "attributes/common.hpp"
16
17#ifdef __AVX512F__
18#define IFAVX512(x,y) x
19#else
20#define IFAVX512(x,y) y
21#endif
22
23#include "types.hpp"
24#include "numerics/bf16.hpp"
25#include "numerics/fp16.hpp"
26
27namespace ein {
30
32template <auto N, auto ... candidates>
33concept one_of = ((N==candidates) || ... || false);
34
36template <auto N, auto ... candidates>
37concept not_one_of = (!one_of<N,candidates...>);
38
/// \brief Maps a bit width to the fixed-width signed and unsigned integer types of that width.
///
/// The primary template is intentionally empty; only the widths 8, 16, 32 and 64
/// provide the `signed_t` / `unsigned_t` members.
template <size_t N>
struct integer_traits {};

/// 8-bit integers.
template <>
struct integer_traits<8> {
  using signed_t = int8_t;
  using unsigned_t = uint8_t;
};

/// 16-bit integers.
template <>
struct integer_traits<16> {
  using signed_t = int16_t;
  using unsigned_t = uint16_t;
};

/// 32-bit integers.
template <>
struct integer_traits<32> {
  using signed_t = int32_t;
  using unsigned_t = uint32_t;
};

/// 64-bit integers.
template <>
struct integer_traits<64> {
  using signed_t = int64_t;
  using unsigned_t = uint64_t;
};
46
49template <typename T>
50requires one_of<sizeof(T),1,2,4,8>
51using int_t = typename integer_traits<sizeof(T)*8>::signed_t;
52
55template <typename T>
56requires one_of<sizeof(T),1,2,4,8>
57using uint_t = typename integer_traits<sizeof(T)*8>::unsigned_t;
58
60template <size_t N>
61struct imm_t {
62 static constexpr size_t value = N;
64 consteval operator size_t () const noexcept { return N; }
65};
66
74template <size_t N>
75constinit imm_t<N> imm {};
76
/// \brief Unordered comparison: true when at least one operand is NaN.
///
/// \param a first operand
/// \param b second operand
/// \return `isnan(a) || isnan(b)`; \p b is only inspected when \p a is not NaN.
template <typename T>
constexpr bool cmp_unord(T a, T b) noexcept {
  if (isnan(a))
    return true;
  return isnan(b);
}
88
90extern template bool cmp_unord(float,float) noexcept;
91extern template bool cmp_unord(double,double) noexcept;
93
/// \brief Ordered comparison: true when neither operand is NaN.
///
/// \param a first operand
/// \param b second operand
/// \return the negation of "either operand is NaN"; \p b is only inspected when \p a is not NaN.
template <typename T>
constexpr bool cmp_ord(T a, T b) noexcept {
  return !(isnan(a) || isnan(b));
}
105
107extern template bool cmp_ord(float,float) noexcept;
108extern template bool cmp_ord(double,double) noexcept;
110
112template <one_of_t<float,double> T>
114constexpr T scalef(T x, T y) noexcept {
115 if consteval {
116 // Constexpr path using bit manipulation
117
118 using uint_type = uint_t<T>; // std::conditional_t<std::is_same_v<T, float>, uint32_t, uint64_t>;
119 constexpr int exponent_bias = std::is_same_v<T, float> ? 127 : 1023;
120 constexpr uint_type exponent_mask = std::is_same_v<T, float> ? uint_type(0x7F800000) : uint_type(0x7FF0000000000000ull);
121 constexpr uint_type mantissa_mask = std::is_same_v<T, float> ? uint_type(0x007FFFFF) : uint_type(0x000FFFFFFFFFFFFFull);
122
123 if (x == 0.0 || isnan(x) || std::isinf(x))
124 return x; // Handle special cases
125
126 // Get the raw bits of the floating-point number
127 uint_type x_bits = std::bit_cast<uint_type>(x);
128
129 // Extract exponent and mantissa
130 int exponent = static_cast<int>(((x_bits & exponent_mask) >> (std::is_same_v<T, float> ? 23 : 52)) - exponent_bias);
131 uint_type mantissa = x_bits & mantissa_mask;
132
133 // Normalize the mantissa (implicit leading 1 for normalized numbers)
134 if (exponent != -exponent_bias)
135 mantissa |= std::is_same_v<T, float> ? (1u << 23) : (1ull << 52); // Set implicit leading bit
136
137 // Scale the exponent by adding y
138 exponent += static_cast<int>(y);
139
140 // Check for overflow and underflow
141 if (exponent > std::numeric_limits<T>::max_exponent)
142 return std::numeric_limits<T>::infinity();
143
144 if (exponent < std::numeric_limits<T>::min_exponent)
145 return static_cast<T>(0.0); // Underflow to zero
146
147 // Rebuild the floating-point number from the new exponent and mantissa
148 x_bits =
149 (x_bits & (std::is_same_v<T, float> ? 0x80000000u : 0x8000000000000000ull)) | // Preserve the sign bit
150 (static_cast<uint_type>(exponent + exponent_bias) << (std::is_same_v<T, float> ? 23 : 52)) | // Apply new exponent
151 (mantissa & mantissa_mask); // Mask mantissa bits
152
153 return std::bit_cast<T>(x_bits);
154 } else {
155 // Runtime path using fast standard library call
156 return std::scalbn(x, static_cast<int>(y));
157 }
158}
159
161enum class ein_nodiscard CMPINT : size_t {
162 EQ = 0x0uz
163, LT = 0x1uz
164, LE = 0x2uz
165, FALSE = 0x3uz
166, NE = 0x4uz
167, NLT = 0x5uz
168, NLE = 0x6uz
169, TRUE = 0x7uz
170};
171
172template <CMPINT imm8, typename T>
173requires (one_of_t<T,uint8_t,int8_t,uint16_t,int16_t,uint32_t,int32_t,uint64_t,int64_t> && (size_t(imm8) < 8uz))
175constexpr bool cmpint(T a, T b) noexcept {
176 using enum CMPINT;
177 if constexpr (imm8 == TRUE) return -1;
178 else if constexpr (imm8 == FALSE) return 0;
179 else if constexpr (imm8 == LT) return a < b;
180 else if constexpr (imm8 == NLT) return a >= b;
181 else if constexpr (imm8 == LE) return a <= b;
182 else if constexpr (imm8 == NLE) return a > b;
183 else if constexpr (imm8 == EQ) return a == b;
184 else if constexpr (imm8 == NE) return a != b;
185 else static_assert(false);
186}
187
190constexpr size_t max_fp_comparison_predicate = IFAVX512(32,8);
191
193enum class ein_nodiscard CMP : size_t {
194 EQ_OQ = 0x00uz
195, LT_OS = 0x01uz
196, LE_OS = 0x02uz
197, UNORD_Q = 0x03uz
198, NEQ_UQ = 0x04uz
199, NLT_US = 0x05uz
200, NLE_US = 0x06uz
201, ORD_Q = 0x07uz
202, EQ_UQ = 0x08uz
203, NGE_US = 0x09uz
204, NGT_US = 0x0Auz
205, FALSE_OQ = 0x0Buz
206, NEQ_OQ = 0x0Cuz
207, GE_OS = 0x0Duz
208, GT_OS = 0x0Euz
209, TRUE_UQ = 0x0Fuz
210, EQ_OS = 0x10uz
211, LT_OQ = 0x11uz
212, LE_OQ = 0x12uz
213, UNORD_S = 0x13uz
214, NEQ_US = 0x14uz
215, NLT_UQ = 0x15uz
216, NLE_UQ = 0x16uz
217, ORD_S = 0x17uz
218, EQ_US = 0x18uz
219, NGE_UQ = 0x19uz
220, NGT_UQ = 0x1Auz
221, FALSE_OS = 0x1Buz
222, NEQ_OS = 0x1Cuz
223, GE_OQ = 0x1Duz
224, GT_OQ = 0x1Euz
225, TRUE_US = 0x1Fuz
226};
227
229template <CMP imm8, typename T>
230requires (one_of_t<T,float,double> && (size_t(imm8) < 32uz))
232constexpr bool cmp(T a, T b) noexcept {
233 using enum CMP;
234 if constexpr (imm8 == EQ_OQ) return cmp_ord(a, b) && (a == b);
235 else if constexpr (imm8 == LT_OS) return cmp_ord(a, b) && (a < b);
236 else if constexpr (imm8 == LE_OS) return cmp_ord(a, b) && (a <= b);
237 else if constexpr (imm8 == UNORD_Q) return cmp_unord(a, b);
238 else if constexpr (imm8 == NEQ_UQ) return cmp_unord(a, b) || (a != b);
239 else if constexpr (imm8 == NLT_US) return cmp_unord(a, b) || !(a < b);
240 else if constexpr (imm8 == NLE_US) return cmp_unord(a, b) || !(a <= b);
241 else if constexpr (imm8 == ORD_Q) return cmp_ord(a, b);
242 else if constexpr (imm8 == EQ_UQ) return cmp_unord(a, b) || (a == b);
243 else if constexpr (imm8 == NGE_US) return cmp_unord(a, b) || !(a >= b);
244 else if constexpr (imm8 == NGT_US) return cmp_unord(a, b) || !(a > b);
245 else if constexpr (imm8 == FALSE_OQ) return 0;
246 else if constexpr (imm8 == NEQ_OQ) return cmp_ord(a, b) && (a != b);
247 else if constexpr (imm8 == GE_OS) return cmp_ord(a, b) && (a >= b);
248 else if constexpr (imm8 == GT_OS) return cmp_ord(a, b) && (a > b);
249 else if constexpr (imm8 == TRUE_UQ) return -1;
250 else if constexpr (imm8 == EQ_OS) return cmp_ord(a, b) && (a == b);
251 else if constexpr (imm8 == LT_OQ) return cmp_ord(a, b) && (a < b);
252 else if constexpr (imm8 == LE_OQ) return cmp_ord(a, b) && (a <= b);
253 else if constexpr (imm8 == UNORD_S) return cmp_unord(a, b);
254 else if constexpr (imm8 == NEQ_US) return cmp_unord(a, b) || (a != b);
255 else if constexpr (imm8 == NLT_UQ) return cmp_unord(a, b) || !(a < b);
256 else if constexpr (imm8 == NLE_UQ) return cmp_unord(a, b) || !(a <= b);
257 else if constexpr (imm8 == ORD_S) return cmp_ord(a, b);
258 else if constexpr (imm8 == EQ_US) return cmp_unord(a, b) || (a == b);
259 else if constexpr (imm8 == NGE_UQ) return cmp_unord(a, b) || !(a >= b);
260 else if constexpr (imm8 == NGT_UQ) return cmp_unord(a, b) || !(a > b);
261 else if constexpr (imm8 == FALSE_OS) return 0;
262 else if constexpr (imm8 == NEQ_OS) return cmp_ord(a, b) && (a != b);
263 else if constexpr (imm8 == GE_OQ) return cmp_ord(a, b) && (a >= b);
264 else if constexpr (imm8 == GT_OQ) return cmp_ord(a, b) && (a > b);
265 else if constexpr (imm8 == TRUE_US) return -1;
266 else static_assert(false);
267}
268
270#define X extern
271#include "numerics.x"
272#undef X
274
276}
277
278#if defined(EIN_TESTING) || defined(EIN_TESTING_NUMERICS)
// Catch2 test case covering the helpers declared in this header; it is only
// compiled when EIN_TESTING or EIN_TESTING_NUMERICS is defined.
279TEST_CASE("numerics","[numerics]") {
280 using namespace ein;
281 using Catch::Approx;
282
 // Membership concepts, checked both for a value in the list and one outside it.
283 SECTION("Concepts: one_of and not_one_of") {
284 CHECK(one_of<1, 1, 2, 3>);
285 CHECK_FALSE(one_of<4, 1, 2, 3>);
286 CHECK(not_one_of<4, 1, 2, 3>);
287 CHECK_FALSE(not_one_of<1, 1, 2, 3>);
288 }
289
 // Each supported bit width maps to the matching fixed-width integer type.
290 SECTION("integer_traits") {
291 CHECK(std::is_same_v<integer_traits<8>::signed_t, int8_t>);
292 CHECK(std::is_same_v<integer_traits<16>::unsigned_t, uint16_t>);
293 CHECK(std::is_same_v<integer_traits<32>::signed_t, int32_t>);
294 CHECK(std::is_same_v<integer_traits<64>::unsigned_t, uint64_t>);
295 }
296
 // int_t / uint_t select a type by the size of their argument type.
297 SECTION("int_t and uint_t") {
298 CHECK(std::is_same_v<int_t<int8_t>, int8_t>);
299 CHECK(std::is_same_v<uint_t<int16_t>, uint16_t>);
300 CHECK(std::is_same_v<int_t<int32_t>, int32_t>);
301 CHECK(std::is_same_v<uint_t<int64_t>, uint64_t>);
302 }
303
 // imm<N> exposes N both through ::value and through conversion to size_t.
304 SECTION("imm_t compile-time constant") {
305 constexpr auto imm4 = imm<4>;
306 CHECK(imm4.value == 4);
307 CHECK(static_cast<size_t>(imm4) == 4);
308 }
309
 // NaN in either operand position makes the pair unordered.
310 SECTION("cmp_unord and cmp_ord") {
311 CHECK(cmp_unord(NAN, 1.0f));
312 CHECK(cmp_unord(1.0f, NAN));
313 CHECK_FALSE(cmp_unord(1.0f, 2.0f));
314
315 CHECK(cmp_ord(1.0f, 2.0f));
316 CHECK_FALSE(cmp_ord(NAN, 1.0f));
317 CHECK_FALSE(cmp_ord(1.0f, NAN));
318 }
319
 // 2 * 2^3 == 16 and 4 * 2^-2 == 1.
 // NOTE(review): these calls are not forced into constant evaluation, so they
 // presumably exercise only the runtime branch of scalef - confirm.
320 SECTION("scalef") {
321 constexpr float x = 2.0f;
322 constexpr float y = 3.0f;
323 CHECK(scalef(x, y) == Approx(16.0f));
324
325 constexpr double a = 4.0;
326 constexpr double b = -2.0;
327 CHECK(scalef(a, b) == Approx(1.0));
328 }
329
 // A sample of the integer predicates on int operands.
330 SECTION("CMPINT comparison") {
331 CHECK(cmpint<CMPINT::EQ>(5, 5));
332 CHECK(cmpint<CMPINT::NE>(5, 4));
333 CHECK_FALSE(cmpint<CMPINT::LT>(5, 4));
334 CHECK(cmpint<CMPINT::LE>(5, 5));
335 CHECK(cmpint<CMPINT::NLT>(5, 5));
336 }
337
 // The floating-point predicate checks are compiled only when __AVX512F__ is defined.
338#ifdef __AVX512F__
339 SECTION("CMP floating-point comparison with AVX512") {
340 CHECK(cmp<CMP::EQ_OQ>(1.0f, 1.0f));
341 CHECK(cmp<CMP::LT_OS>(1.0f, 2.0f));
342 CHECK_FALSE(cmp<CMP::GT_OS>(1.0f, 2.0f));
343 //CHECK(cmp<CMP::NEQ_UQ>(1.0f, NAN));
344 }
345#endif // __AVX512F__
346}
347#endif
N is not one of the candidates
Definition numerics.hpp:37
N is one of the candidates
Definition numerics.hpp:33
#define ein_artificial
[[artificial]].
Definition common.hpp:220
#define ein_inline
inline [[always_inline]]
Definition common.hpp:188
#define ein_nodiscard
C++17 [[nodiscard]].
Definition common.hpp:165
static constexpr size_t value
Definition numerics.hpp:62
constexpr size_t max_fp_comparison_predicate
AVX512 added many more floating point comparison types. Do we have them?
Definition numerics.hpp:190
typename integer_traits< sizeof(T) *8 >::signed_t int_t
returns a signed integer type of the same size as T suitable for std::bit_cast
Definition numerics.hpp:51
constexpr bool cmpint(T a, T b) noexcept
Definition numerics.hpp:175
constinit imm_t< N > imm
A compile time constant passed as an empty struct.
Definition numerics.hpp:75
typename integer_traits< sizeof(T) *8 >::unsigned_t uint_t
returns an unsigned integer type of the same size as T suitable for std::bit_cast
Definition numerics.hpp:57
constexpr bool cmp(T a, T b) noexcept
perform an avx512 style floating point comparison for scalar values.
Definition numerics.hpp:232
@ FALSE
always false
@ TRUE
always true
@ LE_OQ
Less-than-or-equal (ordered, nonsignaling) (AVX-512)
@ FALSE_OQ
False (ordered, nonsignaling) (AVX-512)
@ NEQ_OS
Not-equal (ordered, signaling) (AVX-512)
@ NLE_US
Not-less-than-or-equal (unordered, signaling)
@ TRUE_US
True (unordered, signaling) (AVX-512)
@ NGE_UQ
Not-greater-than-or-equal (unordered, nonsignaling) (AVX-512)
@ GE_OS
Greater-than-or-equal (ordered, signaling) (AVX-512)
@ NEQ_OQ
Not-equal (ordered, nonsignaling) (AVX-512)
@ ORD_Q
Ordered (nonsignaling)
@ LT_OS
Less-than (ordered, signaling)
@ GE_OQ
Greater-than-or-equal (ordered, nonsignaling) (AVX-512)
@ EQ_US
Equal (unordered, signaling) (AVX-512)
@ TRUE_UQ
True (unordered, nonsignaling) (AVX-512)
@ EQ_OS
Equal (ordered, signaling) (AVX-512)
@ FALSE_OS
False (ordered, signaling) (AVX-512)
@ GT_OS
Greater-than (ordered, signaling) (AVX-512)
@ ORD_S
Ordered (signaling) (AVX-512)
@ NLE_UQ
Not-less-than-or-equal (unordered, nonsignaling) (AVX-512)
@ GT_OQ
Greater-than (ordered, nonsignaling) (AVX-512)
@ UNORD_S
Unordered (signaling) (AVX-512)
@ LE_OS
Less-than-or-equal (ordered, signaling)
@ EQ_OQ
Equal (ordered, nonsignaling)
@ NEQ_UQ
Not-equal (unordered, nonsignaling)
@ EQ_UQ
Equal (unordered, nonsignaling) (AVX-512)
@ NLT_US
Not-less-than (unordered, signaling)
@ UNORD_Q
Unordered (nonsignaling)
@ LT_OQ
Less-than (ordered, nonsignaling) (AVX-512)
@ NLT_UQ
Not-less-than (unordered, nonsignaling) (AVX-512)
@ NGT_US
Not-greater-than (unordered, signaling) (AVX-512)
@ NGE_US
Not-greater-than-or-equal (unordered, signaling) (AVX-512)
@ NGT_UQ
Not-greater-than (unordered, nonsignaling) (AVX-512)
@ NEQ_US
Not-equal (unordered, signaling) (AVX-512)
A compile time constant passed as an empty struct.
Definition numerics.hpp:61
#define ein_const
[[const]] is not const
Definition common.hpp:84
#define ein_pure
[[pure]]
Definition common.hpp:102
Definition cpuid.cpp:16
template bool cmp_unord(float, float) noexcept
template bool cmp_ord(float, float) noexcept
X template float scalef(float, float) noexcept
cond xmacro
constexpr bool isnan(ein::bf16 x) noexcept
Definition bf16.hpp:131
#define IFAVX512(x, y)
Definition numerics.hpp:18