working on implementing fma, f16 rtz formal proof seems likely to work
[ieee754fpu.git] / src / ieee754 / fpfma / main_stage.py
index 1ab2b2b8a048a14330c6c0f6230e6b0539cfac65..7a028107e40d1c52ae363b42019b2c550f91e91f 100644 (file)
@@ -3,13 +3,13 @@
 computes `z = (a * c) + b` but only rounds once at the end
 """
 
-from nmutil.pipemodbase import PipeModBase
+from nmutil.pipemodbase import PipeModBase, PipeModBaseChain
 from ieee754.fpcommon.fpbase import FPRoundingMode
 from ieee754.fpfma.special_cases import FPFMASpecialCasesDeNormOutData
 from nmigen.hdl.dsl import Module
-from nmigen.hdl.ast import Signal, signed, unsigned, Mux
+from nmigen.hdl.ast import Signal, signed, unsigned, Mux, Cat
 from ieee754.fpfma.util import expanded_exponent_shape, \
-    expanded_mantissa_shape, get_fpformat
+    expanded_mantissa_shape, get_fpformat, EXPANDED_MANTISSA_EXTRA_LSBS
 from ieee754.fpcommon.getop import FPPipeContext
 
 
@@ -38,8 +38,31 @@ class FPFMAPostCalcData:
         self.rm = Signal(FPRoundingMode, reset=FPRoundingMode.DEFAULT)
         """rounding mode"""
 
+    def eq(self, i):
+        return [
+            self.sign.eq(i.sign),
+            self.exponent.eq(i.exponent),
+            self.mantissa.eq(i.mantissa),
+            self.bypassed_z.eq(i.bypassed_z),
+            self.do_bypass.eq(i.do_bypass),
+            self.ctx.eq(i.ctx),
+            self.rm.eq(i.rm),
+        ]
+
+    def __iter__(self):
+        yield self.sign
+        yield self.exponent
+        yield self.mantissa
+        yield self.bypassed_z
+        yield self.do_bypass
+        yield self.ctx
+        yield self.rm
+
+    def ports(self):
+        return list(self)
+
 
-class FPFMAMainStage(PipeModBase):
+class FPFMAMain(PipeModBase):
     def __init__(self, pspec):
         super().__init__(pspec, "main")
 
@@ -65,8 +88,9 @@ class FPFMAMainStage(PipeModBase):
             negate_b_s.eq(inp.do_sub),
             negate_b_u.eq(inp.do_sub),
         ]
-        sum_v = product_v + (inp.b_mantissa ^ negate_b_s) + negate_b_u
-        sum = Signal(sum_v.shape())
+        sum_v = (product_v << EXPANDED_MANTISSA_EXTRA_LSBS) + \
+            (inp.b_mantissa ^ negate_b_s) + negate_b_u
+        sum = Signal(expanded_mantissa_shape(fpf))
         m.d.comb += sum.eq(sum_v)
 
         sum_neg = Signal()
@@ -97,3 +121,13 @@ class FPFMAMainStage(PipeModBase):
             out.rm.eq(inp.rm),
         ]
         return m
+
+
+class FPFMAMainStage(PipeModBaseChain):
+    def __init__(self, pspec):
+        super().__init__(pspec)
+
+    def get_chain(self):
+        """ gets chain of modules
+        """
+        return [FPFMAMain(self.pspec)]