working on implementing fma, f16 rtz formal proof seems likely to work
[ieee754fpu.git] / src / ieee754 / fpfma / main_stage.py
1 """ floating-point fused-multiply-add
2
3 computes `z = (a * c) + b` but only rounds once at the end
4 """
5
6 from nmutil.pipemodbase import PipeModBase, PipeModBaseChain
7 from ieee754.fpcommon.fpbase import FPRoundingMode
8 from ieee754.fpfma.special_cases import FPFMASpecialCasesDeNormOutData
9 from nmigen.hdl.dsl import Module
10 from nmigen.hdl.ast import Signal, signed, unsigned, Mux, Cat
11 from ieee754.fpfma.util import expanded_exponent_shape, \
12 expanded_mantissa_shape, get_fpformat, EXPANDED_MANTISSA_EXTRA_LSBS
13 from ieee754.fpcommon.getop import FPPipeContext
14
15
class FPFMAPostCalcData:
    """Output record of the FMA main calculation stage.

    Holds the sign, exponent and mantissa of the not-yet-rounded result,
    the bypass value/flag, the pipeline context and the rounding mode.
    """

    # attribute names, in the order used by `eq()`, `__iter__()` and `ports()`
    _FIELD_NAMES = (
        "sign",
        "exponent",
        "mantissa",
        "bypassed_z",
        "do_bypass",
        "ctx",
        "rm",
    )

    def __init__(self, pspec):
        fmt = get_fpformat(pspec)

        # NOTE: each Signal is bound directly to its attribute so that nmigen
        # can derive the signal's name from the attribute name.

        #: sign
        self.sign = Signal()

        #: exponent -- unbiased
        self.exponent = Signal(expanded_exponent_shape(fmt))

        #: unnormalized mantissa
        self.mantissa = Signal(expanded_mantissa_shape(fmt))

        #: final output value of the fma when `do_bypass` is set
        self.bypassed_z = Signal(fmt.width)

        #: set if `bypassed_z` is the final output value of the fma
        self.do_bypass = Signal()

        #: pipe context
        self.ctx = FPPipeContext(pspec)

        #: rounding mode
        self.rm = Signal(FPRoundingMode, reset=FPRoundingMode.DEFAULT)

    def eq(self, i):
        """Return the list of assignments copying every field of `i` here."""
        assignments = []
        for field in self._FIELD_NAMES:
            dest = getattr(self, field)
            assignments.append(dest.eq(getattr(i, field)))
        return assignments

    def __iter__(self):
        """Yield every field, in declaration order."""
        for field in self._FIELD_NAMES:
            yield getattr(self, field)

    def ports(self):
        """Return all fields as a list."""
        return [*self]
63
64
class FPFMAMain(PipeModBase):
    """Main calculation step of the fused multiply-add.

    Consumes `FPFMASpecialCasesDeNormOutData` and produces
    `FPFMAPostCalcData`: the sign, exponent and magnitude of
    `(a_mantissa * c_mantissa) +/- b_mantissa`, before any rounding.
    """

    def __init__(self, pspec):
        super().__init__(pspec, "main")

    def ispec(self):
        """Create the input data record for this stage."""
        return FPFMASpecialCasesDeNormOutData(self.pspec)

    def ospec(self):
        """Create the output data record for this stage."""
        return FPFMAPostCalcData(self.pspec)

    def elaborate(self, platform):
        """Build the combinatorial logic of the stage and return the Module."""
        # NOTE: locals that hold a freshly created `Signal` keep their
        # original names (`product`, `sum`, ...), because nmigen derives the
        # signal name from the name of the variable it is assigned to.
        fmt = get_fpformat(self.pspec)
        assert fmt.has_sign
        m = Module()
        src = self.i
        dst = self.o

        # mantissa product a * c; `product` is not read again in this method
        product_v = src.a_mantissa * src.c_mantissa
        product = Signal(product_v.shape())
        m.d.comb += product.eq(product_v)

        # `do_sub` as a 1-bit signed value (used as an XOR mask) and as a
        # 1-bit unsigned value (used as the +1 carry-in)
        negate_b_s = Signal(signed(1))
        negate_b_u = Signal(unsigned(1))
        m.d.comb += negate_b_s.eq(src.do_sub)
        m.d.comb += negate_b_u.eq(src.do_sub)

        # product + b, or product - b (invert b and add one) when `do_sub`
        shifted_product = product_v << EXPANDED_MANTISSA_EXTRA_LSBS
        maybe_inverted_b = src.b_mantissa ^ negate_b_s
        sum_v = shifted_product + maybe_inverted_b + negate_b_u
        sum = Signal(expanded_mantissa_shape(fmt))
        m.d.comb += sum.eq(sum_v)

        sum_neg = Signal()
        sum_zero = Signal()
        m.d.comb += sum_neg.eq(sum < 0)  # just sign bit
        m.d.comb += sum_zero.eq(sum == 0)

        zero_sign_by_rm = FPRoundingMode.make_array(FPRoundingMode.zero_sign)

        # a zero sum that is not already bypassed becomes a bypassed zero
        # whose sign is selected by the rounding mode
        becomes_zero = sum_zero & ~src.do_bypass
        with m.If(becomes_zero):
            m.d.comb += dst.bypassed_z.eq(fmt.zero(zero_sign_by_rm[src.rm]))
            m.d.comb += dst.do_bypass.eq(True)
        with m.Else():
            # otherwise forward the incoming bypass value and flag unchanged
            m.d.comb += dst.bypassed_z.eq(src.bypassed_z)
            m.d.comb += dst.do_bypass.eq(src.do_bypass)

        # output sign/magnitude: a negative sum flips the incoming sign and
        # is negated to give the mantissa
        magnitude = Mux(sum_neg, -sum, sum)
        m.d.comb += [
            dst.sign.eq(sum_neg ^ src.sign),
            dst.exponent.eq(src.exponent),
            dst.mantissa.eq(magnitude),
            dst.ctx.eq(src.ctx),
            dst.rm.eq(src.rm),
        ]
        return m
124
125
class FPFMAMainStage(PipeModBaseChain):
    """Pipeline chain consisting of the single `FPFMAMain` module."""

    def __init__(self, pspec):
        super().__init__(pspec)

    def get_chain(self):
        """Return the list of modules making up this chain."""
        main = FPFMAMain(self.pspec)
        return [main]