From 78bc58e1e6031a81620ced176b798f2e9e4703d1 Mon Sep 17 00:00:00 2001 From: Jacob Lifshay Date: Thu, 30 Jun 2022 20:34:24 -0700 Subject: [PATCH] finish adding all rounding modes to fadd -- formal proof passes --- src/ieee754/fpadd/add0.py | 1 + src/ieee754/fpadd/add1.py | 4 +- src/ieee754/fpadd/align.py | 1 + src/ieee754/fpadd/datastruct.py | 10 +- src/ieee754/fpadd/specialcases.py | 22 ++-- src/ieee754/fpadd/test/test_add_formal.py | 130 ++++++++++++++++--- src/ieee754/fpcommon/basedata.py | 6 + src/ieee754/fpcommon/denorm.py | 1 + src/ieee754/fpcommon/fpbase.py | 147 ++++++++++++++++++---- src/ieee754/fpcommon/pack.py | 13 +- src/ieee754/fpcommon/postnormalise.py | 12 +- src/ieee754/fpcommon/pscdata.py | 9 +- src/ieee754/fpcommon/roundz.py | 8 +- 13 files changed, 294 insertions(+), 70 deletions(-) diff --git a/src/ieee754/fpadd/add0.py b/src/ieee754/fpadd/add0.py index f7e77e48..32c4fdf9 100644 --- a/src/ieee754/fpadd/add0.py +++ b/src/ieee754/fpadd/add0.py @@ -60,5 +60,6 @@ class FPAddStage0Mod(PipeModBase): comb += self.o.oz.eq(self.i.oz) comb += self.o.out_do_z.eq(self.i.out_do_z) comb += self.o.ctx.eq(self.i.ctx) + comb += self.o.rm.eq(self.i.rm) return m diff --git a/src/ieee754/fpadd/add1.py b/src/ieee754/fpadd/add1.py index c57534c0..38677d4c 100644 --- a/src/ieee754/fpadd/add1.py +++ b/src/ieee754/fpadd/add1.py @@ -57,7 +57,9 @@ class FPAddStage1Mod(PipeModBase): self.o.of.guard.eq(to[3]), self.o.of.round_bit.eq(to[2]), # sticky sourced from LSB and shifted if MSB hi, else unshifted - self.o.of.sticky.eq(Mux(msb, to[1] | tot[0], to[1])) + self.o.of.sticky.eq(Mux(msb, to[1] | tot[0], to[1])), + self.o.of.rm.eq(self.i.rm), + self.o.of.sign.eq(self.i.z.s), ] comb += self.o.out_do_z.eq(self.i.out_do_z) diff --git a/src/ieee754/fpadd/align.py b/src/ieee754/fpadd/align.py index 59b6c1ec..d565e6e4 100644 --- a/src/ieee754/fpadd/align.py +++ b/src/ieee754/fpadd/align.py @@ -132,6 +132,7 @@ class FPAddAlignSingleMod(PipeModBase): comb += self.o.z.eq(self.i.z) comb += self.o.out_do_z.eq(self.i.out_do_z) comb += self.o.oz.eq(self.i.oz) + comb += self.o.rm.eq(self.i.rm) return m diff --git a/src/ieee754/fpadd/datastruct.py b/src/ieee754/fpadd/datastruct.py index e2ca9427..c4bf4078 100644 --- a/src/ieee754/fpadd/datastruct.py +++ b/src/ieee754/fpadd/datastruct.py @@ -4,9 +4,8 @@ Copyright (C) 2019 Luke Kenneth Casson Leighton """ -from nmigen import Module, Signal - -from ieee754.fpcommon.fpbase import FPNumBaseRecord +from nmigen import Signal +from ieee754.fpcommon.fpbase import FPNumBaseRecord, FPRoundingMode from ieee754.fpcommon.getop import FPPipeContext @@ -21,6 +20,9 @@ class FPAddStage0Data: self.ctx = FPPipeContext(pspec) self.muxid = self.ctx.muxid + self.rm = Signal(FPRoundingMode, reset=FPRoundingMode.DEFAULT) + """rounding mode""" + def eq(self, i): return [self.z.eq(i.z), self.out_do_z.eq(i.out_do_z), self.oz.eq(i.oz), - self.tot.eq(i.tot), self.ctx.eq(i.ctx)] + self.tot.eq(i.tot), self.ctx.eq(i.ctx), self.rm.eq(i.rm)] diff --git a/src/ieee754/fpadd/specialcases.py b/src/ieee754/fpadd/specialcases.py index 3b6c6921..31a49068 100644 --- a/src/ieee754/fpadd/specialcases.py +++ b/src/ieee754/fpadd/specialcases.py @@ -2,12 +2,10 @@ # Copyright (C) Jonathan P Dawson 2013 # 2013-12-12 -from nmigen import Module, Signal, Cat, Const, Mux -from nmigen.cli import main, verilog -from math import log +from nmigen import Module, Signal, Cat, Mux from nmutil.pipemodbase import PipeModBase, PipeModBaseChain -from ieee754.fpcommon.fpbase import FPNumDecode +from ieee754.fpcommon.fpbase import FPNumDecode, FPRoundingMode from ieee754.fpcommon.fpbase import FPNumBaseRecord from ieee754.fpcommon.basedata import FPBaseData @@ -45,8 +43,11 @@ class FPAddSpecialCasesMod(PipeModBase): self.o.b.eq(b1) ] + zero_sign_array = FPRoundingMode.make_array(FPRoundingMode.zero_sign) + # temporaries used below s_nomatch = Signal(reset_less=True) + s_match = Signal(reset_less=True) m_match = Signal(reset_less=True) e_match = Signal(reset_less=True) absa = Signal(reset_less=True) # a1.s & b1.s @@ -61,6 +62,7 @@ class FPAddSpecialCasesMod(PipeModBase): t_special = Signal(reset_less=True) comb += s_nomatch.eq(a1.s != b1.s) + comb += s_match.eq(a1.s == b1.s) comb += m_match.eq(a1.m == b1.m) comb += e_match.eq(a1.e == b1.e) @@ -80,12 +82,14 @@ class FPAddSpecialCasesMod(PipeModBase): # prepare inf/zero/nans z_zero = FPNumBaseRecord(width, False, name="z_zero") + z_default_zero = FPNumBaseRecord(width, False, name="z_default_zero") z_default_nan = FPNumBaseRecord(width, False, name="z_default_nan") z_quieted_a = FPNumBaseRecord(width, False, name="z_quieted_a") z_quieted_b = FPNumBaseRecord(width, False, name="z_quieted_b") z_infa = FPNumBaseRecord(width, False, name="z_infa") z_infb = FPNumBaseRecord(width, False, name="z_infb") comb += z_zero.zero(0) + comb += z_default_zero.zero(zero_sign_array[self.i.rm]) comb += z_default_nan.nan(0) comb += z_quieted_a.quieted_nan(a1) comb += z_quieted_b.quieted_nan(b1) @@ -103,20 +107,20 @@ class FPAddSpecialCasesMod(PipeModBase): # if a is inf and signs don't match return NaN # else return inf(a) # elif b is inf return inf(b) - # elif a is zero and b zero return signed-a/b + # elif a is zero and b zero with same sign return a + # elif a equal to -b return zero (sign determined by rounding-mode) # elif a is zero return b # elif b is zero return a - # elif a equal to -b return zero (+ve zero) # XXX *sigh* there are better ways to do this... # one of them: use a priority-picker! # in reverse-order, accumulate Muxing oz = 0 - oz = Mux(t_aeqmb, z_zero.v, oz) oz = Mux(t_b1zero, a1.v, oz) oz = Mux(t_a1zero, b1.v, oz) - oz = Mux(t_abz, Cat(self.i.b[:-1], absa), oz) + oz = Mux(t_aeqmb, z_default_zero.v, oz) + oz = Mux(t_abz & s_match, a1.v, oz) oz = Mux(t_b1inf, z_infb.v, oz) oz = Mux(t_a1inf, Mux(bexp128s, z_default_nan.v, z_infa.v), oz) oz = Mux(t_abnan, Mux(a1.is_nan, z_quieted_a.v, z_quieted_b.v), oz) @@ -125,6 +129,8 @@ class FPAddSpecialCasesMod(PipeModBase): comb += self.o.ctx.eq(self.i.ctx) + comb += self.o.rm.eq(self.i.rm) + return m diff --git a/src/ieee754/fpadd/test/test_add_formal.py b/src/ieee754/fpadd/test/test_add_formal.py index 2d3584d2..915c2b94 100644 --- a/src/ieee754/fpadd/test/test_add_formal.py +++ b/src/ieee754/fpadd/test/test_add_formal.py @@ -2,30 +2,55 @@ import unittest from nmutil.formaltest import FHDLTestCase from ieee754.fpadd.pipeline import FPADDBasePipe from nmigen.hdl.dsl import Module -from nmigen.hdl.ast import Initial, Assert, AnyConst, Signal, Assume +from nmigen.hdl.ast import Initial, Assert, AnyConst, Signal, Assume, Mux from nmigen.hdl.smtlib2 import SmtFloatingPoint, SmtSortFloatingPoint, \ - SmtSortFloat16, SmtSortFloat32, SmtSortFloat64, \ - ROUND_NEAREST_TIES_TO_EVEN + SmtSortFloat16, SmtSortFloat32, SmtSortFloat64, SmtBool, \ + SmtRoundingMode, ROUND_TOWARD_POSITIVE, ROUND_TOWARD_NEGATIVE +from ieee754.fpcommon.fpbase import FPRoundingMode from ieee754.pipeline import PipelineSpec class TestFAddFormal(FHDLTestCase): - def tst_fadd_rne_formal(self, sort): + def tst_fadd_formal(self, sort, rm): assert isinstance(sort, SmtSortFloatingPoint) + assert isinstance(rm, FPRoundingMode) width = sort.width dut = FPADDBasePipe(PipelineSpec(width, id_width=4)) m = Module() m.submodules.dut = dut m.d.comb += dut.n.i_ready.eq(True) m.d.comb += dut.p.i_valid.eq(Initial()) - a = dut.p.i_data.a - b = dut.p.i_data.b - z = dut.n.o_data.z - rm = ROUND_NEAREST_TIES_TO_EVEN + m.d.comb += dut.p.i_data.rm.eq(Mux(Initial(), rm, 0)) + out = Signal(width) + out_full = Signal(reset=False) + with m.If(dut.n.trigger): + # check we only got output for one cycle + m.d.comb += Assert(~out_full) + m.d.sync += out.eq(dut.n.o_data.z) + m.d.sync += out_full.eq(True) + a = Signal(width) + b = Signal(width) + m.d.comb += dut.p.i_data.a.eq(Mux(Initial(), a, 0)) + m.d.comb += dut.p.i_data.b.eq(Mux(Initial(), b, 0)) a_fp = SmtFloatingPoint.from_bits(a, sort=sort) b_fp = SmtFloatingPoint.from_bits(b, sort=sort) - z_fp = SmtFloatingPoint.from_bits(z, sort=sort) - expected_fp = a_fp.add(b_fp, rm=rm) + out_fp = SmtFloatingPoint.from_bits(out, sort=sort) + if rm in (FPRoundingMode.ROUND_TO_ODD_UNSIGNED_ZEROS_ARE_POSITIVE, + FPRoundingMode.ROUND_TO_ODD_UNSIGNED_ZEROS_ARE_NEGATIVE): + rounded_up = Signal(width) + m.d.comb += rounded_up.eq(AnyConst(width)) + rounded_up_fp = a_fp.add(b_fp, rm=ROUND_TOWARD_POSITIVE) + rounded_down_fp = a_fp.add(b_fp, rm=ROUND_TOWARD_NEGATIVE) + m.d.comb += Assume(SmtFloatingPoint.from_bits( + rounded_up, sort=sort).same(rounded_up_fp).as_value()) + use_rounded_up = SmtBool.make(rounded_up[0]) + if rm is FPRoundingMode.ROUND_TO_ODD_UNSIGNED_ZEROS_ARE_POSITIVE: + is_zero = rounded_up_fp.is_zero() & rounded_down_fp.is_zero() + use_rounded_up |= is_zero + expected_fp = use_rounded_up.ite(rounded_up_fp, rounded_down_fp) + else: + smt_rm = SmtRoundingMode.make(rm.to_smtlib2()) + expected_fp = a_fp.add(b_fp, rm=smt_rm) expected = Signal(width) m.d.comb += expected.eq(AnyConst(width)) quiet_bit = 1 << (sort.mantissa_field_width - 1) @@ -42,24 +67,89 @@ class TestFAddFormal(FHDLTestCase): .same(expected_fp).as_value()) m.d.comb += a.eq(AnyConst(width)) m.d.comb += b.eq(AnyConst(width)) - with m.If(dut.n.trigger): - m.d.sync += Assert(z_fp.same(expected_fp).as_value()) - m.d.sync += Assert(z == expected) + with m.If(out_full): + m.d.comb += Assert(out_fp.same(expected_fp).as_value()) + m.d.comb += Assert(out == expected) self.assertFormal(m, depth=5, solver="bitwuzla") - # FIXME: check other rounding modes # FIXME: check exception flags - def test_fadd16_rne_formal(self): - self.tst_fadd_rne_formal(SmtSortFloat16()) + def test_fadd_f16_rne_formal(self): + self.tst_fadd_formal(SmtSortFloat16(), FPRoundingMode.RNE) + + def test_fadd_f32_rne_formal(self): + self.tst_fadd_formal(SmtSortFloat32(), FPRoundingMode.RNE) + + @unittest.skip("too slow") + def test_fadd_f64_rne_formal(self): + self.tst_fadd_formal(SmtSortFloat64(), FPRoundingMode.RNE) + + def test_fadd_f16_rtz_formal(self): + self.tst_fadd_formal(SmtSortFloat16(), FPRoundingMode.RTZ) + + def test_fadd_f32_rtz_formal(self): + self.tst_fadd_formal(SmtSortFloat32(), FPRoundingMode.RTZ) + + @unittest.skip("too slow") + def test_fadd_f64_rtz_formal(self): + self.tst_fadd_formal(SmtSortFloat64(), FPRoundingMode.RTZ) + + def test_fadd_f16_rtp_formal(self): + self.tst_fadd_formal(SmtSortFloat16(), FPRoundingMode.RTP) + + def test_fadd_f32_rtp_formal(self): + self.tst_fadd_formal(SmtSortFloat32(), FPRoundingMode.RTP) + + @unittest.skip("too slow") + def test_fadd_f64_rtp_formal(self): + self.tst_fadd_formal(SmtSortFloat64(), FPRoundingMode.RTP) + + def test_fadd_f16_rtn_formal(self): + self.tst_fadd_formal(SmtSortFloat16(), FPRoundingMode.RTN) + + def test_fadd_f32_rtn_formal(self): + self.tst_fadd_formal(SmtSortFloat32(), FPRoundingMode.RTN) + + @unittest.skip("too slow") + def test_fadd_f64_rtn_formal(self): + self.tst_fadd_formal(SmtSortFloat64(), FPRoundingMode.RTN) + + def test_fadd_f16_rna_formal(self): + self.tst_fadd_formal(SmtSortFloat16(), FPRoundingMode.RNA) + + def test_fadd_f32_rna_formal(self): + self.tst_fadd_formal(SmtSortFloat32(), FPRoundingMode.RNA) + + @unittest.skip("too slow") + def test_fadd_f64_rna_formal(self): + self.tst_fadd_formal(SmtSortFloat64(), FPRoundingMode.RNA) + + def test_fadd_f16_rtop_formal(self): + self.tst_fadd_formal(SmtSortFloat16(), FPRoundingMode.RTOP) + + def test_fadd_f32_rtop_formal(self): + self.tst_fadd_formal(SmtSortFloat32(), FPRoundingMode.RTOP) + + @unittest.skip("too slow") + def test_fadd_f64_rtop_formal(self): + self.tst_fadd_formal(SmtSortFloat64(), FPRoundingMode.RTOP) + + def test_fadd_f16_rton_formal(self): + self.tst_fadd_formal(SmtSortFloat16(), FPRoundingMode.RTON) - def test_fadd32_rne_formal(self): - self.tst_fadd_rne_formal(SmtSortFloat32()) + def test_fadd_f32_rton_formal(self): + self.tst_fadd_formal(SmtSortFloat32(), FPRoundingMode.RTON) @unittest.skip("too slow") - def test_fadd64_rne_formal(self): - self.tst_fadd_rne_formal(SmtSortFloat64()) + def test_fadd_f64_rton_formal(self): + self.tst_fadd_formal(SmtSortFloat64(), FPRoundingMode.RTON) + def test_all_rounding_modes_covered(self): + for width in 16, 32, 64: + for rm in FPRoundingMode: + rm_s = rm.name.lower() + name = f"test_fadd_f{width}_{rm_s}_formal" + assert callable(getattr(self, name)) if __name__ == '__main__': diff --git a/src/ieee754/fpcommon/basedata.py b/src/ieee754/fpcommon/basedata.py index 197ae96f..c6be83f3 100644 --- a/src/ieee754/fpcommon/basedata.py +++ b/src/ieee754/fpcommon/basedata.py @@ -2,6 +2,7 @@ # Copyright (C) 2019 Luke Kenneth Casson Leighton from nmigen import Signal +from ieee754.fpcommon.fpbase import FPRoundingMode from ieee754.fpcommon.getop import FPPipeContext @@ -20,17 +21,22 @@ class FPBaseData: self.muxid = self.ctx.muxid # make muxid available here: complicated self.ops = ops + self.rm = Signal(FPRoundingMode, reset=FPRoundingMode.DEFAULT) + """rounding mode""" + def eq(self, i): ret = [] for op1, op2 in zip(self.ops, i.ops): ret.append(op1.eq(op2)) ret.append(self.ctx.eq(i.ctx)) + ret.append(self.rm.eq(i.rm)) return ret def __iter__(self): if self.ops: yield from self.ops yield from self.ctx + yield self.rm def ports(self): return list(self) diff --git a/src/ieee754/fpcommon/denorm.py b/src/ieee754/fpcommon/denorm.py index 350f413e..14c98c12 100644 --- a/src/ieee754/fpcommon/denorm.py +++ b/src/ieee754/fpcommon/denorm.py @@ -52,5 +52,6 @@ class FPAddDeNormMod(PipeModBase): comb += self.o.z.eq(self.i.z) comb += self.o.out_do_z.eq(self.i.out_do_z) comb += self.o.oz.eq(self.i.oz) + comb += self.o.rm.eq(self.i.rm) return m diff --git a/src/ieee754/fpcommon/fpbase.py b/src/ieee754/fpcommon/fpbase.py index 01f3e07f..fee88bef 100644 --- a/src/ieee754/fpcommon/fpbase.py +++ b/src/ieee754/fpcommon/fpbase.py @@ -6,7 +6,7 @@ Copyright (C) 2019 Jake Lifshay """ -from nmigen import Signal, Cat, Const, Mux, Module, Elaboratable, Array +from nmigen import Signal, Cat, Const, Mux, Module, Elaboratable, Array, Value from math import log from operator import or_ from functools import reduce @@ -89,8 +89,8 @@ class FPRoundingMode(enum.Enum): ROUND_NEAREST_TIES_TO_AWAY = RNA - RTO = 0b101 - """Round to Odd + RTOP = 0b101 + """Round to Odd, unsigned zeros are Positive Not in smtlib2. @@ -98,12 +98,109 @@ class FPRoundingMode(enum.Enum): that, otherwise return the nearest representable floating-point value that has an odd mantissa. + If the result is zero but with otherwise undetermined sign + (e.g. `1.0 - 1.0`), the sign is positive. + + This rounding mode is used for instructions with Round To Odd enabled, + and `FPSCR.RN != RTN`. + + This is useful to avoid double-rounding errors when doing arithmetic in a + larger type (e.g. f128) but where the answer should be a smaller type + (e.g. f80). + """ + + ROUND_TO_ODD_UNSIGNED_ZEROS_ARE_POSITIVE = RTOP + + RTON = 0b110 + """Round to Odd, unsigned zeros are Negative + + Not in smtlib2. + + If the result is exactly representable as a floating-point number, return + that, otherwise return the nearest representable floating-point value + that has an odd mantissa. + + If the result is zero but with otherwise undetermined sign + (e.g. `1.0 - 1.0`), the sign is negative. + + This rounding mode is used for instructions with Round To Odd enabled, + and `FPSCR.RN == RTN`. + This is useful to avoid double-rounding errors when doing arithmetic in a larger type (e.g. f128) but where the answer should be a smaller type (e.g. f80). """ - ROUND_TO_ODD = RTO + ROUND_TO_ODD_UNSIGNED_ZEROS_ARE_NEGATIVE = RTON + + @staticmethod + def make_array(f): + l = [None] * len(FPRoundingMode) + for rm in FPRoundingMode: + l[rm.value] = f(rm) + return Array(l) + + def overflow_rounds_to_inf(self, sign): + """returns true if an overflow should round to `inf`, + false if it should round to `max_normal` + """ + not_sign = ~sign if isinstance(sign, Value) else not sign + if self is FPRoundingMode.RNE: + return True + elif self is FPRoundingMode.RTZ: + return False + elif self is FPRoundingMode.RTP: + return not_sign + elif self is FPRoundingMode.RTN: + return sign + elif self is FPRoundingMode.RNA: + return True + elif self is FPRoundingMode.RTOP: + return False + else: + assert self is FPRoundingMode.RTON + return False + + def underflow_rounds_to_zero(self, sign): + """returns true if an underflow should round to `zero`, + false if it should round to `min_denormal` + """ + not_sign = ~sign if isinstance(sign, Value) else not sign + if self is FPRoundingMode.RNE: + return True + elif self is FPRoundingMode.RTZ: + return True + elif self is FPRoundingMode.RTP: + return sign + elif self is FPRoundingMode.RTN: + return not_sign + elif self is FPRoundingMode.RNA: + return True + elif self is FPRoundingMode.RTOP: + return False + else: + assert self is FPRoundingMode.RTON + return False + + def zero_sign(self): + """which sign an exact zero result should have when it isn't + otherwise determined, e.g. for `1.0 - 1.0`. + """ + if self is FPRoundingMode.RNE: + return False + elif self is FPRoundingMode.RTZ: + return False + elif self is FPRoundingMode.RTP: + return False + elif self is FPRoundingMode.RTN: + return True + elif self is FPRoundingMode.RNA: + return False + elif self is FPRoundingMode.RTOP: + return False + else: + assert self is FPRoundingMode.RTON + return True if _HAVE_SMTLIB2: def to_smtlib2(self, default=_raise_err): @@ -122,7 +219,7 @@ class FPRoundingMode(enum.Enum): elif self is FPRoundingMode.RNA: return RoundingModeEnum.RNA else: - assert self is FPRoundingMode.RTO + assert self in (FPRoundingMode.RTOP, FPRoundingMode.RTON) if default is _raise_err: raise ValueError( "no corresponding smtlib2 rounding mode", self) @@ -548,6 +645,12 @@ class FPNumBaseRecord: def inf(self, s): return self.create(*self._inf(s)) + def max_normal(self, s): + return self.create(s, self.fp.P127, ~0) + + def min_denormal(self, s): + return self.create(s, self.fp.N127, 1) + def zero(self, s): return self.create(*self._zero(s)) @@ -1017,7 +1120,8 @@ class Overflow: self.sign = Signal(reset_less=True, name=name+"sign") """sign bit -- 1 means negative, 0 means positive""" - self.rm = Signal(name=name+"rm", reset=FPRoundingMode.DEFAULT) + self.rm = Signal(FPRoundingMode, name=name+"rm", + reset=FPRoundingMode.DEFAULT) """rounding mode""" #self.roundz = Signal(reset_less=True) @@ -1062,16 +1166,15 @@ class Overflow: assumes the rounding mode is `ROUND_TOWARDS_NEGATIVE` """ - FPRoundingMode.ROUND_TOWARDS_NEGATIVE return self.sign & (self.guard | self.round_bit | self.sticky) @property def roundz_rto(self): - """true if the mantissa should be rounded up for `rm == RTO` + """true if the mantissa should be rounded up for `rm in (RTOP, RTON)` - assumes the rounding mode is `ROUND_TO_ODD` + assumes the rounding mode is `ROUND_TO_ODD_UNSIGNED_ZEROS_ARE_POSITIVE` + or `ROUND_TO_ODD_UNSIGNED_ZEROS_ARE_NEGATIVE` """ - FPRoundingMode.ROUND_TO_ODD return ~self.m0 & (self.guard | self.round_bit | self.sticky) @property @@ -1080,7 +1183,6 @@ class Overflow: assumes the rounding mode is `ROUND_TOWARDS_POSITIVE` """ - FPRoundingMode.ROUND_TOWARDS_POSITIVE return ~self.sign & (self.guard | self.round_bit | self.sticky) @property @@ -1089,24 +1191,23 @@ class Overflow: assumes the rounding mode is `ROUND_TOWARDS_ZERO` """ - FPRoundingMode.ROUND_TOWARDS_ZERO - return self.guard & (self.round_bit | self.sticky | self.m0) + return False @property def roundz(self): """true if the mantissa should be rounded up for the current rounding mode `self.rm` """ - l = [None] * len(FPRoundingMode) - l[FPRoundingMode.RNA.value] = self.roundz_rna - l[FPRoundingMode.RNE.value] = self.roundz_rne - l[FPRoundingMode.RTN.value] = self.roundz_rtn - l[FPRoundingMode.RTO.value] = self.roundz_rto - l[FPRoundingMode.RTP.value] = self.roundz_rtp - l[FPRoundingMode.RTZ.value] = self.roundz_rtz - for (i, v) in enumerate(l): - assert v is not None, f"missing rounding mode {bin(i)}" - return Array(l)[self.rm] + d = { + FPRoundingMode.RNA: self.roundz_rna, + FPRoundingMode.RNE: self.roundz_rne, + FPRoundingMode.RTN: self.roundz_rtn, + FPRoundingMode.RTOP: self.roundz_rto, + FPRoundingMode.RTON: self.roundz_rto, + FPRoundingMode.RTP: self.roundz_rtp, + FPRoundingMode.RTZ: self.roundz_rtz, + } + return FPRoundingMode.make_array(lambda rm: d[rm])[self.rm] class OverflowMod(Elaboratable, Overflow): diff --git a/src/ieee754/fpcommon/pack.py b/src/ieee754/fpcommon/pack.py index d53533b1..56c4457d 100644 --- a/src/ieee754/fpcommon/pack.py +++ b/src/ieee754/fpcommon/pack.py @@ -3,12 +3,10 @@ # 2013-12-12 from nmigen import Module, Signal -from nmigen.cli import main, verilog from nmutil.pipemodbase import PipeModBase -from ieee754.fpcommon.fpbase import FPNumBaseRecord, FPNumBase +from ieee754.fpcommon.fpbase import FPNumBaseRecord, FPNumBase, FPRoundingMode from ieee754.fpcommon.roundz import FPRoundData -from ieee754.fpcommon.getop import FPPipeContext from ieee754.fpcommon.packdata import FPPackData @@ -29,10 +27,17 @@ class FPPackMod(PipeModBase): z = FPNumBaseRecord(self.pspec.width, False, name="z") m.submodules.pack_in_z = in_z = FPNumBase(self.i.z) + overflow_array = FPRoundingMode.make_array( + lambda rm: rm.overflow_rounds_to_inf(self.i.z.s)) + overflow_rounds_to_inf = Signal() + m.d.comb += overflow_rounds_to_inf.eq(overflow_array[self.i.rm]) with m.If(~self.i.out_do_z): with m.If(in_z.is_overflowed): - comb += z.inf(self.i.z.s) + with m.If(overflow_rounds_to_inf): + comb += z.inf(self.i.z.s) + with m.Else(): + comb += z.max_normal(self.i.z.s) with m.Else(): comb += z.create(self.i.z.s, self.i.z.e, self.i.z.m) with m.Else(): diff --git a/src/ieee754/fpcommon/postnormalise.py b/src/ieee754/fpcommon/postnormalise.py index f1d33801..3c582b4f 100644 --- a/src/ieee754/fpcommon/postnormalise.py +++ b/src/ieee754/fpcommon/postnormalise.py @@ -2,12 +2,10 @@ # Copyright (C) Jonathan P Dawson 2013 # 2013-12-12 -from nmigen import Module, Signal, Cat, Mux -from nmigen.cli import main, verilog -from math import log +from nmigen import Module, Signal, Cat from nmutil.pipemodbase import PipeModBase -from ieee754.fpcommon.fpbase import (Overflow, OverflowMod, +from ieee754.fpcommon.fpbase import (FPRoundingMode, Overflow, OverflowMod, FPNumBase, FPNumBaseRecord) from ieee754.fpcommon.fpbase import FPState from ieee754.fpcommon.getop import FPPipeContext @@ -27,9 +25,12 @@ class FPNorm1Data: self.ctx = FPPipeContext(pspec) self.muxid = self.ctx.muxid + self.rm = Signal(FPRoundingMode, reset=FPRoundingMode.DEFAULT) + """rounding mode""" + def eq(self, i): ret = [self.z.eq(i.z), self.out_do_z.eq(i.out_do_z), self.oz.eq(i.oz), - self.roundz.eq(i.roundz), self.ctx.eq(i.ctx)] + self.roundz.eq(i.roundz), self.ctx.eq(i.ctx), self.rm.eq(i.rm)] return ret @@ -128,6 +129,7 @@ class FPNorm1ModSingle(PipeModBase): m.d.comb += self.o.ctx.eq(self.i.ctx) m.d.comb += self.o.out_do_z.eq(self.i.out_do_z) m.d.comb += self.o.oz.eq(self.i.oz) + m.d.comb += self.o.rm.eq(of.rm) return m diff --git a/src/ieee754/fpcommon/pscdata.py b/src/ieee754/fpcommon/pscdata.py index da449d27..7bf22e94 100644 --- a/src/ieee754/fpcommon/pscdata.py +++ b/src/ieee754/fpcommon/pscdata.py @@ -5,7 +5,7 @@ Copyright (C) 2019 Luke Kenneth Casson Leighton """ from nmigen import Signal -from ieee754.fpcommon.fpbase import FPNumBaseRecord +from ieee754.fpcommon.fpbase import FPNumBaseRecord, FPRoundingMode from ieee754.fpcommon.getop import FPPipeContext @@ -25,6 +25,9 @@ class FPSCData: self.ctx = FPPipeContext(pspec) self.muxid = self.ctx.muxid + self.rm = Signal(FPRoundingMode, reset=FPRoundingMode.DEFAULT) + """rounding mode""" + def __iter__(self): yield from self.a yield from self.b @@ -32,8 +35,10 @@ class FPSCData: yield self.oz yield self.out_do_z yield from self.ctx + yield self.rm def eq(self, i): ret = [self.z.eq(i.z), self.out_do_z.eq(i.out_do_z), self.oz.eq(i.oz), - self.a.eq(i.a), self.b.eq(i.b), self.ctx.eq(i.ctx)] + self.a.eq(i.a), self.b.eq(i.b), self.ctx.eq(i.ctx), + self.rm.eq(i.rm)] return ret diff --git a/src/ieee754/fpcommon/roundz.py b/src/ieee754/fpcommon/roundz.py index a11f42ac..22b1d7f3 100644 --- a/src/ieee754/fpcommon/roundz.py +++ b/src/ieee754/fpcommon/roundz.py @@ -3,10 +3,9 @@ # 2013-12-12 from nmigen import Module, Signal, Mux -from nmigen.cli import main, verilog from nmutil.pipemodbase import PipeModBase -from ieee754.fpcommon.fpbase import FPNumBase, FPNumBaseRecord +from ieee754.fpcommon.fpbase import FPNumBaseRecord, FPRoundingMode from ieee754.fpcommon.getop import FPPipeContext from ieee754.fpcommon.postnormalise import FPNorm1Data @@ -22,9 +21,12 @@ class FPRoundData: self.out_do_z = Signal(reset_less=True) self.oz = Signal(width, reset_less=True) + self.rm = Signal(FPRoundingMode, reset=FPRoundingMode.DEFAULT) + """rounding mode""" + def eq(self, i): ret = [self.z.eq(i.z), self.out_do_z.eq(i.out_do_z), self.oz.eq(i.oz), - self.ctx.eq(i.ctx)] + self.ctx.eq(i.ctx), self.rm.eq(i.rm)] return ret -- 2.30.2