add copy_sign and genericify abs
[vector-math.git] / src / f16.rs
1 use crate::{
2 scalar::Value,
3 traits::{ConvertFrom, ConvertTo, Float},
4 };
5 use core::{
6 fmt,
7 ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Rem, RemAssign, Sub, SubAssign},
8 };
9
// With the `f16` feature the payload is the `half` crate's binary16 type.
#[cfg(feature = "f16")]
use half::f16 as F16Impl;

/// Without the `f16` feature only the raw 16-bit pattern is stored; operations
/// that need real half-precision support panic (see `f16_impl!`).
#[cfg(not(feature = "f16"))]
type F16Impl = u16;

/// Half-precision float wrapper with the same layout as its backing type.
///
/// NOTE: without the `f16` feature the derived `PartialEq`/`PartialOrd`
/// compare the raw `u16` bit patterns, not float values.
#[repr(transparent)]
#[derive(Copy, Clone, PartialOrd, PartialEq)]
pub struct F16(F16Impl);
19
/// Diverging helper expanded by `f16_impl!` when the crate is built without
/// the `f16` feature. `#[track_caller]` attributes the panic to the caller.
#[cfg(not(feature = "f16"))]
#[track_caller]
pub(crate) fn panic_f16_feature_disabled() -> ! {
    panic!("f16 feature is not enabled");
}
25
/// With the `f16` feature, `f16_impl!(body, [idents...])` simply evaluates
/// `body`; the identifier list is accepted and ignored.
#[cfg(feature = "f16")]
macro_rules! f16_impl {
    ($body:expr, [$($used:ident),*]) => {
        $body
    };
}

/// Without the `f16` feature, `body` is dropped from the expansion entirely
/// (so it is never type-checked) and the expansion panics instead. Each listed
/// identifier is consumed with `let _ = ...` so callers do not get
/// unused-variable warnings.
#[cfg(not(feature = "f16"))]
macro_rules! f16_impl {
    ($body:expr, [$($used:ident),*]) => {{
        $(let _ = $used;)*
        panic_f16_feature_disabled()
    }};
}
42
43 impl fmt::Display for F16 {
44 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
45 f16_impl!(self.0.fmt(f), [f])
46 }
47 }
48
49 impl fmt::Debug for F16 {
50 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
51 f16_impl!(self.0.fmt(f), [f])
52 }
53 }
54
55 impl Default for F16 {
56 fn default() -> Self {
57 f16_impl!(F16(F16Impl::default()), [])
58 }
59 }
60
61 impl From<F16Impl> for F16 {
62 fn from(v: F16Impl) -> Self {
63 F16(v)
64 }
65 }
66
67 impl From<F16> for F16Impl {
68 fn from(v: F16) -> Self {
69 v.0
70 }
71 }
72
/// For each listed source type, implements `From<src> for F16` via the backing
/// type's own `From<src>` impl, plus a `ConvertFrom<src>` that delegates to it.
macro_rules! impl_f16_from {
    ($($src:ident,)*) => {
        $(
            impl From<$src> for F16 {
                fn from(value: $src) -> Self {
                    f16_impl!(Self(F16Impl::from(value)), [value])
                }
            }

            impl ConvertFrom<$src> for F16 {
                fn cvt_from(value: $src) -> F16 {
                    <F16 as From<$src>>::from(value)
                }
            }
        )*
    };
}
90
/// For each listed destination type, implements `From<F16> for dst` via the
/// backing type's conversion, plus a `ConvertFrom<F16>` that delegates to it.
macro_rules! impl_from_f16 {
    ($($dst:ident,)*) => {
        $(
            impl From<F16> for $dst {
                fn from(value: F16) -> Self {
                    f16_impl!(Self::from(value.0), [value])
                }
            }

            impl ConvertFrom<F16> for $dst {
                fn cvt_from(value: F16) -> Self {
                    Self::from(value)
                }
            }
        )*
    };
}
108
109 impl_f16_from![i8, u8,];
110
111 impl_from_f16![f32, f64,];
112
/// Implements `ConvertFrom<int> for F16` by widening the integer to `f32`
/// and then narrowing that `f32` to `F16`.
macro_rules! impl_int_to_f16 {
    ($($src:ident),*) => {
        $(
            impl ConvertFrom<$src> for F16 {
                fn cvt_from(value: $src) -> Self {
                    // Every integer below 2^24 is exact in f32, and anything at or
                    // above 2^24 is far past the largest finite F16, so it becomes
                    // infinity either way: the two-step conversion rounds correctly.
                    let widened = value as f32;
                    Self::cvt_from(widened)
                }
            }
        )*
    };
}
127
/// Implements `ConvertFrom<F16> for int` by widening to `f32` and applying
/// Rust's float-to-int `as` cast (saturating; NaN maps to 0).
macro_rules! impl_f16_to_int {
    ($($dst:ident),*) => {
        $(
            impl ConvertFrom<F16> for $dst {
                fn cvt_from(value: F16) -> Self {
                    let widened: f32 = value.into();
                    widened as Self
                }
            }
        )*
    };
}
139
140 impl_int_to_f16![i16, u16, i32, u32, i64, u64, i128, u128];
141 impl_f16_to_int![i8, u8, i16, u16, i32, u32, i64, u64, i128, u128];
142
143 impl ConvertFrom<f32> for F16 {
144 fn cvt_from(v: f32) -> Self {
145 f16_impl!(F16(F16Impl::from_f32(v)), [v])
146 }
147 }
148
149 impl ConvertFrom<f64> for F16 {
150 fn cvt_from(v: f64) -> Self {
151 f16_impl!(F16(F16Impl::from_f64(v)), [v])
152 }
153 }
154
155 impl Neg for F16 {
156 type Output = Self;
157
158 fn neg(self) -> Self::Output {
159 f16_impl!(Self::from_bits(self.to_bits() ^ 0x8000), [])
160 }
161 }
162
163 macro_rules! impl_bin_op_using_f32 {
164 ($($op:ident, $op_fn:ident, $op_assign:ident, $op_assign_fn:ident;)*) => {
165 $(
166 impl $op for F16 {
167 type Output = Self;
168
169 fn $op_fn(self, rhs: Self) -> Self::Output {
170 f32::from(self).$op_fn(f32::from(rhs)).to()
171 }
172 }
173
174 impl $op_assign for F16 {
175 fn $op_assign_fn(&mut self, rhs: Self) {
176 *self = (*self).$op_fn(rhs);
177 }
178 }
179 )*
180 };
181 }
182
183 impl_bin_op_using_f32! {
184 Add, add, AddAssign, add_assign;
185 Sub, sub, SubAssign, sub_assign;
186 Mul, mul, MulAssign, mul_assign;
187 Div, div, DivAssign, div_assign;
188 Rem, rem, RemAssign, rem_assign;
189 }
190
191 impl F16 {
192 pub fn from_bits(v: u16) -> Self {
193 #[cfg(feature = "f16")]
194 return F16(F16Impl::from_bits(v));
195 #[cfg(not(feature = "f16"))]
196 return F16(v);
197 }
198 pub fn to_bits(self) -> u16 {
199 #[cfg(feature = "f16")]
200 return self.0.to_bits();
201 #[cfg(not(feature = "f16"))]
202 return self.0;
203 }
204 pub fn abs(self) -> Self {
205 f16_impl!(Self::from_bits(self.to_bits() & 0x7FFF), [])
206 }
207 pub fn copysign(self, sign: Self) -> Self {
208 f16_impl!(
209 Self::from_bits((self.to_bits() & 0x7FFF) | (sign.to_bits() & 0x8000)),
210 [sign]
211 )
212 }
213 pub fn trunc(self) -> Self {
214 #[cfg(feature = "std")]
215 return f32::from(self).trunc().to();
216 #[cfg(not(feature = "std"))]
217 todo!();
218 }
219 pub fn ceil(self) -> Self {
220 #[cfg(feature = "std")]
221 return f32::from(self).ceil().to();
222 #[cfg(not(feature = "std"))]
223 todo!();
224 }
225 pub fn floor(self) -> Self {
226 #[cfg(feature = "std")]
227 return f32::from(self).floor().to();
228 #[cfg(not(feature = "std"))]
229 todo!();
230 }
231 pub fn round(self) -> Self {
232 #[cfg(feature = "std")]
233 return f32::from(self).round().to();
234 #[cfg(not(feature = "std"))]
235 todo!();
236 }
237 #[cfg(feature = "fma")]
238 pub fn fma(self, a: Self, b: Self) -> Self {
239 (f64::from(self) * f64::from(a) + f64::from(b)).to()
240 }
241
242 pub fn is_nan(self) -> bool {
243 f16_impl!(self.0.is_nan(), [])
244 }
245
246 pub fn is_infinite(self) -> bool {
247 f16_impl!(self.0.is_infinite(), [])
248 }
249
250 pub fn is_finite(self) -> bool {
251 f16_impl!(self.0.is_finite(), [])
252 }
253 }
254
255 impl Float for Value<F16> {
256 type PrimFloat = F16;
257 type BitsType = Value<u16>;
258 type SignedBitsType = Value<i16>;
259
260 fn abs(self) -> Self {
261 Value(self.0.abs())
262 }
263
264 fn trunc(self) -> Self {
265 Value(self.0.trunc())
266 }
267
268 fn ceil(self) -> Self {
269 Value(self.0.ceil())
270 }
271
272 fn floor(self) -> Self {
273 Value(self.0.floor())
274 }
275
276 fn round(self) -> Self {
277 Value(self.0.round())
278 }
279
280 #[cfg(feature = "fma")]
281 fn fma(self, a: Self, b: Self) -> Self {
282 Value(self.0.fma(a.0, b.0))
283 }
284
285 fn is_nan(self) -> Self::Bool {
286 Value(self.0.is_nan())
287 }
288
289 fn is_infinite(self) -> Self::Bool {
290 Value(self.0.is_infinite())
291 }
292
293 fn is_finite(self) -> Self::Bool {
294 Value(self.0.is_finite())
295 }
296
297 fn from_bits(v: Self::BitsType) -> Self {
298 Value(F16::from_bits(v.0))
299 }
300
301 fn to_bits(self) -> Self::BitsType {
302 Value(self.0.to_bits())
303 }
304 }
305
#[cfg(test)]
mod tests {
    use super::*;
    use core::cmp::Ordering;

    #[test]
    #[cfg_attr(
        not(feature = "f16"),
        should_panic(expected = "f16 feature is not enabled")
    )]
    fn test_abs() {
        assert_eq!(F16::from_bits(0x8000).abs().to_bits(), 0);
        assert_eq!(F16::from_bits(0).abs().to_bits(), 0);
        assert_eq!(F16::from_bits(0x8ABC).abs().to_bits(), 0xABC);
        assert_eq!(F16::from_bits(0xFE00).abs().to_bits(), 0x7E00);
        assert_eq!(F16::from_bits(0x7E00).abs().to_bits(), 0x7E00);
    }

    #[test]
    #[cfg_attr(
        not(feature = "f16"),
        should_panic(expected = "f16 feature is not enabled")
    )]
    fn test_neg() {
        assert_eq!(F16::from_bits(0x8000).neg().to_bits(), 0);
        assert_eq!(F16::from_bits(0).neg().to_bits(), 0x8000);
        assert_eq!(F16::from_bits(0x8ABC).neg().to_bits(), 0xABC);
        assert_eq!(F16::from_bits(0xFE00).neg().to_bits(), 0x7E00);
        assert_eq!(F16::from_bits(0x7E00).neg().to_bits(), 0xFE00);
    }

    #[test]
    #[cfg_attr(
        not(feature = "f16"),
        should_panic(expected = "f16 feature is not enabled")
    )]
    fn test_copysign() {
        // 1.0 takes the sign of -0.0.
        assert_eq!(F16::from_bits(0x3C00).copysign(F16::from_bits(0x8000)).to_bits(), 0xBC00);
        // -1.0 takes the sign of +0.0.
        assert_eq!(F16::from_bits(0xBC00).copysign(F16::from_bits(0)).to_bits(), 0x3C00);
        // Same sign: unchanged.
        assert_eq!(F16::from_bits(0xBC00).copysign(F16::from_bits(0xC000)).to_bits(), 0xBC00);
        // +0.0 takes the sign of -inf.
        assert_eq!(F16::from_bits(0).copysign(F16::from_bits(0xFC00)).to_bits(), 0x8000);
        // NaN payloads are preserved; only the sign bit changes.
        assert_eq!(F16::from_bits(0x7E01).copysign(F16::from_bits(0xFE00)).to_bits(), 0xFE01);
        assert_eq!(F16::from_bits(0xFE01).copysign(F16::from_bits(0x7C00)).to_bits(), 0x7E01);
    }

    #[test]
    #[cfg_attr(
        not(feature = "f16"),
        should_panic(expected = "f16 feature is not enabled")
    )]
    fn test_int_to_f16() {
        assert_eq!(F16::to_bits(0u32.to()), 0);
        for v in 1..0x20000u32 {
            // Normalize so the leading 1 sits at bit 31; the 11 significant
            // bits of an F16 are then bits 31..=21.
            let leading_zeros = u32::leading_zeros(v);
            let shifted_v = v << leading_zeros;
            // round to nearest, ties to even
            let round_up = match (shifted_v & 0x1FFFFF).cmp(&0x100000) {
                Ordering::Less => false,
                Ordering::Equal => (shifted_v & 0x200000) != 0,
                Ordering::Greater => true,
            };
            let increment = if round_up { 0x200000 } else { 0 };
            let (rounded, carry) = (shifted_v & !0x1FFFFF).overflowing_add(increment);
            // A carry out of bit 31 means the significand rolled over to the
            // next power of two.
            let mantissa = if carry {
                (rounded >> 22) as u16 + 0x400
            } else {
                (rounded >> 21) as u16
            };
            assert_eq!((mantissa & !0x3FF), 0x400);
            let exponent = 31 - leading_zeros as u16 + 15 + carry as u16;
            let expected = if exponent < 0x1F {
                (mantissa & 0x3FF) + (exponent << 10)
            } else {
                0x7C00
            };
            let actual = F16::to_bits(v.to());
            assert_eq!(
                actual, expected,
                "actual = {:#X}, expected = {:#X}, v = {:#X}",
                actual, expected, v
            );
        }
    }
}