add f8 fma tests -- f16 takes >8hr to run with bitwuzla
[ieee754fpu.git] / src / ieee754 / fpfma / norm.py
1 from nmutil.pipemodbase import PipeModBaseChain, PipeModBase
2 from ieee754.fpcommon.fpbase import OverflowMod, FPFormat
3 from ieee754.fpcommon.postnormalise import FPNorm1Data
4 from ieee754.fpcommon.roundz import FPRoundMod
5 from ieee754.fpcommon.corrections import FPCorrectionsMod
6 from ieee754.fpcommon.pack import FPPackMod
7 from ieee754.fpfma.main_stage import FPFMAPostCalcData
8 from nmigen.hdl.dsl import Module
9 from nmigen.hdl.ast import Signal
10 from nmigen.lib.coding import PriorityEncoder
11
12
class FPFMANorm(PipeModBase):
    """Normalisation stage for the fused multiply-add pipeline.

    Consumes an ``FPFMAPostCalcData`` record and produces an ``FPNorm1Data``
    record: the mantissa is shifted left (by an amount limited by the
    exponent), the exponent is reduced by the same amount, and the bits
    below the kept mantissa are fed to an ``OverflowMod`` whose
    ``roundz_out`` is forwarded to the next stage.
    """

    def __init__(self, pspec):
        super().__init__(pspec, "norm")

    def ispec(self):
        """Return a new input record for this stage."""
        return FPFMAPostCalcData(self.pspec)

    def ospec(self):
        """Return a new output record for this stage."""
        return FPNorm1Data(self.pspec)

    def elaborate(self, platform):
        """Build and return the combinatorial normalisation logic."""
        m = Module()
        fpf = FPFormat.from_pspec(self.pspec)
        assert fpf.has_sign
        inp: FPFMAPostCalcData = self.i
        out: FPNorm1Data = self.o

        # Feeding the bit-reversed mantissa to the priority encoder makes
        # its output the distance of the most-significant set bit from the
        # top of the mantissa (per nmigen's PriorityEncoder, which encodes
        # the least-significant asserted input bit).
        m.submodules.pri_enc = pri_enc = PriorityEncoder(inp.mantissa.width)
        m.d.comb += pri_enc.i.eq(inp.mantissa[::-1])

        # NOTE: the Python variable names below double as the generated
        # signal names, so they are intentionally left unchanged.
        unrestricted_shift_amount = Signal(range(inp.mantissa.width))
        shift_amount = Signal(range(inp.mantissa.width))
        m.d.comb += unrestricted_shift_amount.eq(pri_enc.o)

        # Largest shift permitted by the exponent: the output exponent
        # (inp.exponent - shift_amount) is not taken below 1 + fpf.e_sub.
        # NOTE(review): if inp.exponent can be less than 1 + fpf.e_sub this
        # difference is negative and gets truncated when assigned to the
        # unsigned shift_amount -- confirm the previous stage rules it out.
        max_shift = inp.exponent - (1 + fpf.e_sub)
        with m.If(max_shift < unrestricted_shift_amount):
            m.d.comb += shift_amount.eq(max_shift)
        with m.Else():
            m.d.comb += shift_amount.eq(unrestricted_shift_amount)

        n_mantissa = Signal(inp.mantissa.width)
        m.d.comb += n_mantissa.eq(inp.mantissa << shift_amount)

        # Bit positions, counted from the top of n_mantissa: the top
        # out.z.m.width bits are kept; the next bit down is the guard bit,
        # the one after that the round bit, and everything below is ORed
        # into sticky.
        guard_ofs = out.z.m.width + 1
        round_ofs = out.z.m.width + 2

        m.submodules.of = of = OverflowMod()
        # pri_enc.i is already driven above; the second, identical
        # assignment that used to sit in this list was redundant.
        m.d.comb += [
            of.guard.eq(n_mantissa[-guard_ofs]),
            of.round_bit.eq(n_mantissa[-round_ofs]),
            of.sticky.eq(n_mantissa[:-round_ofs].bool()),
            of.m0.eq(out.z.m[0]),
            of.fpflags.eq(0),
            of.sign.eq(inp.sign),
            of.rm.eq(inp.rm),
            out.roundz.eq(of.roundz_out),
            out.z.s.eq(inp.sign),
            out.z.e.eq(inp.exponent - shift_amount),
            out.z.m.eq(n_mantissa[-out.z.m.width:]),
            # bypass / context / rounding-mode are passed straight through
            out.out_do_z.eq(inp.do_bypass),
            out.oz.eq(inp.bypassed_z),
            out.ctx.eq(inp.ctx),
            out.rm.eq(inp.rm),
        ]
        return m
61
62
class FPFMANormToPack(PipeModBaseChain):
    """Chain of FMA post-processing stages, from normalisation to packing."""

    def __init__(self, pspec):
        super().__init__(pspec)

    def get_chain(self):
        """Return the stage modules in order: norm, round, corrections, pack.

        Each stage is constructed with this chain's ``pspec``.
        """
        stage_classes = (
            FPFMANorm,
            FPRoundMod,
            FPCorrectionsMod,
            FPPackMod,
        )
        chain = []
        for stage_cls in stage_classes:
            chain.append(stage_cls(self.pspec))
        return chain