+/// calculate `sqrt(v)`, error inherited from `kernel_fn`. Calls `kernel_fn` with inputs in the range `0.5 <= x < 2.0`.
+pub fn sqrt_impl<
+ Ctx: Context,
+ VecF: Float<PrimFloat = PrimF, BitsType = VecU> + Make<Context = Ctx>,
+ VecU: UInt<PrimUInt = PrimU> + Make<Context = Ctx>,
+ PrimF: PrimFloat<BitsType = PrimU>,
+ PrimU: PrimUInt,
+ KernelFn: FnOnce(Ctx, VecF) -> VecF,
+>(
+ ctx: Ctx,
+ v: VecF,
+ kernel_fn: KernelFn,
+) -> VecF {
+ let is_normal_case = v.gt(ctx.make(0.0.to())) & v.is_finite();
+ let is_zero_or_positive = v.ge(ctx.make(0.0.to()));
+ let exceptional_retval = is_zero_or_positive.select(v, VecF::nan(ctx));
+ let need_subnormal_scale = v.is_zero_or_subnormal();
+ let subnormal_result_scale_exponent: PrimU = (PrimF::MANTISSA_FIELD_WIDTH + 1.to()) / 2.to();
+ let subnormal_input_scale_exponent = subnormal_result_scale_exponent * 2.to();
+ let subnormal_result_scale_prim =
+ PrimF::cvt_from(1) / PrimF::cvt_from(PrimU::cvt_from(1) << subnormal_result_scale_exponent);
+ let subnormal_input_scale_prim =
+ PrimF::cvt_from(PrimU::cvt_from(1) << subnormal_input_scale_exponent);
+ let subnormal_result_scale: VecF =
+ need_subnormal_scale.select(ctx.make(subnormal_result_scale_prim), ctx.make(1.0.to()));
+ let subnormal_input_scale: VecF =
+ need_subnormal_scale.select(ctx.make(subnormal_input_scale_prim), ctx.make(1.0.to()));
+ let v = v * subnormal_input_scale;
+ let exponent_field = v.extract_exponent_field();
+ let normal_result_scale_exponent_field_offset: PrimU =
+ PrimF::EXPONENT_BIAS_UNSIGNED - (PrimF::EXPONENT_BIAS_UNSIGNED >> 1.to());
+ let shifted_exponent_field = exponent_field >> ctx.make(1.to());
+ let normal_result_scale_exponent_field =
+ shifted_exponent_field + ctx.make(normal_result_scale_exponent_field_offset);
+ let normal_result_scale = ctx
+ .make::<VecF>(1.to())
+ .with_exponent_field(normal_result_scale_exponent_field);
+ let v = v.with_exponent_field(
+ (exponent_field & ctx.make(1.to()))
+ | ctx.make(PrimF::EXPONENT_BIAS_UNSIGNED & !PrimU::cvt_from(1)),
+ );
+ let normal_result = kernel_fn(ctx, v) * (normal_result_scale * subnormal_result_scale);
+ is_normal_case.select(normal_result, exceptional_retval)
+}
+
+/// computes `sqrt(x)`
+/// has an error of up to 2ULP
+pub fn sqrt_fast_f16<Ctx: Context>(ctx: Ctx, v: Ctx::VecF16) -> Ctx::VecF16 {
+ sqrt_impl(ctx, v, |ctx, v| sqrt_rsqrt_kernel_fast(ctx, v, 3).0)
+}
+
+/// computes `sqrt(x)`
+/// has an error of up to 3ULP
+pub fn sqrt_fast_f32<Ctx: Context>(ctx: Ctx, v: Ctx::VecF32) -> Ctx::VecF32 {
+ sqrt_impl(ctx, v, |ctx, v| sqrt_rsqrt_kernel_fast(ctx, v, 4).0)
+}
+
+/// computes `sqrt(x)`
+/// has an error of up to 2ULP
+pub fn sqrt_fast_f64<Ctx: Context>(ctx: Ctx, v: Ctx::VecF64) -> Ctx::VecF64 {
+ sqrt_impl(ctx, v, |ctx, v| sqrt_rsqrt_kernel_fast(ctx, v, 5).0)
+}
+