use ospec in FPAddAlignSingleAdd
[ieee754fpu.git] / src / add / nmigen_add_experiment.py
index eb49c38c053b41055897a650f7debdd8171c41d9..e3e904bc05c61156775973b33c677fed9ae14313 100644 (file)
@@ -716,8 +716,7 @@ class FPAddAlignSingleAdd(FPState, FPID):
         self.a0_out_z = FPNumBase(width, False)
 
         self.a1mod = FPAddStage1Mod(width)
-        self.out_z = FPNumBase(width, False)
-        self.out_of = Overflow()
+        self.a1o = self.a1mod.ospec()
 
     def setup(self, m, in_a, in_b, in_mid):
         """ links module to inputs and outputs
@@ -726,8 +725,8 @@ class FPAddAlignSingleAdd(FPState, FPID):
         m.d.comb += self.o.eq(self.mod.o)
 
         self.a0mod.setup(m, self.o.a, self.o.b)
-        m.d.comb += self.a0_out_z.eq(self.a0mod.out_z)
-        m.d.comb += self.out_tot.eq(self.a0mod.out_tot)
+        m.d.comb += self.a0_out_z.eq(self.a0mod.o.z)
+        m.d.comb += self.out_tot.eq(self.a0mod.o.tot)
 
         self.a1mod.setup(m, self.out_tot, self.a0_out_z)
 
@@ -736,61 +735,75 @@ class FPAddAlignSingleAdd(FPState, FPID):
 
     def action(self, m):
         self.idsync(m)
-        m.d.sync += self.out_of.eq(self.a1mod.out_of)
-        m.d.sync += self.out_z.eq(self.a1mod.out_z)
+        m.d.sync += self.a1o.eq(self.a1mod.o)
         m.next = "normalise_1"
 
 
+class FPAddStage0Data:
+
+    def __init__(self, width):
+        self.z = FPNumBase(width, False)
+        self.tot = Signal(self.z.m_width + 4, reset_less=True)
+
+    def eq(self, i):
+        return [self.z.eq(i.z), self.tot.eq(i.tot)]
+
+
 class FPAddStage0Mod:
 
     def __init__(self, width):
-        self.in_a = FPNumBase(width)
-        self.in_b = FPNumBase(width)
-        self.out_z = FPNumBase(width, False)
-        self.out_tot = Signal(self.out_z.m_width + 4, reset_less=True)
+        self.width = width
+        self.i = self.ispec()
+        self.o = self.ospec()
+
+    def ispec(self):
+        return FPNumBase2Ops(self.width)
+
+    def ospec(self):
+        return FPAddStage0Data(self.width)
 
     def setup(self, m, in_a, in_b):
         """ links module to inputs and outputs
         """
         m.submodules.add0 = self
-        m.d.comb += self.in_a.eq(in_a)
-        m.d.comb += self.in_b.eq(in_b)
+        m.d.comb += self.i.a.eq(in_a)
+        m.d.comb += self.i.b.eq(in_b)
 
     def elaborate(self, platform):
         m = Module()
-        m.submodules.add0_in_a = self.in_a
-        m.submodules.add0_in_b = self.in_b
-        m.submodules.add0_out_z = self.out_z
+        m.submodules.add0_in_a = self.i.a
+        m.submodules.add0_in_b = self.i.b
+        m.submodules.add0_out_z = self.o.z
 
-        m.d.comb += self.out_z.e.eq(self.in_a.e)
+        m.d.comb += self.o.z.e.eq(self.i.a.e)
 
         # store intermediate tests (and zero-extended mantissas)
         seq = Signal(reset_less=True)
         mge = Signal(reset_less=True)
-        am0 = Signal(len(self.in_a.m)+1, reset_less=True)
-        bm0 = Signal(len(self.in_b.m)+1, reset_less=True)
-        m.d.comb += [seq.eq(self.in_a.s == self.in_b.s),
-                     mge.eq(self.in_a.m >= self.in_b.m),
-                     am0.eq(Cat(self.in_a.m, 0)),
-                     bm0.eq(Cat(self.in_b.m, 0))
+        am0 = Signal(len(self.i.a.m)+1, reset_less=True)
+        bm0 = Signal(len(self.i.b.m)+1, reset_less=True)
+        m.d.comb += [seq.eq(self.i.a.s == self.i.b.s),
+                     mge.eq(self.i.a.m >= self.i.b.m),
+                     am0.eq(Cat(self.i.a.m, 0)),
+                     bm0.eq(Cat(self.i.b.m, 0))
                     ]
         # same-sign (both negative or both positive) add mantissas
         with m.If(seq):
             m.d.comb += [
-                self.out_tot.eq(am0 + bm0),
-                self.out_z.s.eq(self.in_a.s)
+                self.o.tot.eq(am0 + bm0),
+                self.o.z.s.eq(self.i.a.s)
             ]
         # a mantissa greater than b, use a
         with m.Elif(mge):
             m.d.comb += [
-                self.out_tot.eq(am0 - bm0),
-                self.out_z.s.eq(self.in_a.s)
+                self.o.tot.eq(am0 - bm0),
+                self.o.z.s.eq(self.i.a.s)
             ]
         # b mantissa greater than a, use b
         with m.Else():
             m.d.comb += [
-                self.out_tot.eq(bm0 - am0),
-                self.out_z.s.eq(self.in_b.s)
+                self.o.tot.eq(bm0 - am0),
+                self.o.z.s.eq(self.i.b.s)
         ]
         return m
 
@@ -805,8 +818,7 @@ class FPAddStage0(FPState, FPID):
         FPState.__init__(self, "add_0")
         FPID.__init__(self, id_wid)
         self.mod = FPAddStage0Mod(width)
-        self.out_z = FPNumBase(width, False)
-        self.out_tot = Signal(self.out_z.m_width + 4, reset_less=True)
+        self.o = self.mod.ospec()
 
     def setup(self, m, in_a, in_b, in_mid):
         """ links module to inputs and outputs
@@ -818,31 +830,45 @@ class FPAddStage0(FPState, FPID):
     def action(self, m):
         self.idsync(m)
         # NOTE: these could be done as combinatorial (merge add0+add1)
-        m.d.sync += self.out_z.eq(self.mod.out_z)
-        m.d.sync += self.out_tot.eq(self.mod.out_tot)
+        m.d.sync += self.o.eq(self.mod.o)
         m.next = "add_1"
 
 
+class FPAddStage1Data:
+
+    def __init__(self, width):
+        self.z = FPNumBase(width, False)
+        self.of = Overflow()
+
+    def eq(self, i):
+        return [self.z.eq(i.z), self.of.eq(i.of)]
+
+
+
 class FPAddStage1Mod(FPState):
     """ Second stage of add: preparation for normalisation.
         detects when tot sum is too big (tot[27] is kinda a carry bit)
     """
 
     def __init__(self, width):
-        self.out_norm = Signal(reset_less=True)
-        self.in_z = FPNumBase(width, False)
-        self.in_tot = Signal(self.in_z.m_width + 4, reset_less=True)
-        self.out_z = FPNumBase(width, False)
-        self.out_of = Overflow()
+        self.width = width
+        self.i = self.ispec()
+        self.o = self.ospec()
+
+    def ispec(self):
+        return FPAddStage0Data(self.width)
+
+    def ospec(self):
+        return FPAddStage1Data(self.width)
 
     def setup(self, m, in_tot, in_z):
         """ links module to inputs and outputs
         """
         m.submodules.add1 = self
-        m.submodules.add1_out_overflow = self.out_of
+        m.submodules.add1_out_overflow = self.o.of
 
-        m.d.comb += self.in_z.eq(in_z)
-        m.d.comb += self.in_tot.eq(in_tot)
+        m.d.comb += self.i.z.eq(in_z)
+        m.d.comb += self.i.tot.eq(in_tot)
 
     def elaborate(self, platform):
         m = Module()
@@ -850,25 +876,25 @@ class FPAddStage1Mod(FPState):
         #m.submodules.norm1_out_overflow = self.out_of
         #m.submodules.norm1_in_z = self.in_z
         #m.submodules.norm1_out_z = self.out_z
-        m.d.comb += self.out_z.eq(self.in_z)
+        m.d.comb += self.o.z.eq(self.i.z)
         # tot[-1] (MSB) gets set when the sum overflows. shift result down
-        with m.If(self.in_tot[-1]):
+        with m.If(self.i.tot[-1]):
             m.d.comb += [
-                self.out_z.m.eq(self.in_tot[4:]),
-                self.out_of.m0.eq(self.in_tot[4]),
-                self.out_of.guard.eq(self.in_tot[3]),
-                self.out_of.round_bit.eq(self.in_tot[2]),
-                self.out_of.sticky.eq(self.in_tot[1] | self.in_tot[0]),
-                self.out_z.e.eq(self.in_z.e + 1)
+                self.o.z.m.eq(self.i.tot[4:]),
+                self.o.of.m0.eq(self.i.tot[4]),
+                self.o.of.guard.eq(self.i.tot[3]),
+                self.o.of.round_bit.eq(self.i.tot[2]),
+                self.o.of.sticky.eq(self.i.tot[1] | self.i.tot[0]),
+                self.o.z.e.eq(self.i.z.e + 1)
         ]
         # tot[-1] (MSB) zero case
         with m.Else():
             m.d.comb += [
-                self.out_z.m.eq(self.in_tot[3:]),
-                self.out_of.m0.eq(self.in_tot[3]),
-                self.out_of.guard.eq(self.in_tot[2]),
-                self.out_of.round_bit.eq(self.in_tot[1]),
-                self.out_of.sticky.eq(self.in_tot[0])
+                self.o.z.m.eq(self.i.tot[3:]),
+                self.o.of.m0.eq(self.i.tot[3]),
+                self.o.of.guard.eq(self.i.tot[2]),
+                self.o.of.round_bit.eq(self.i.tot[1]),
+                self.o.of.sticky.eq(self.i.tot[0])
         ]
         return m
 
@@ -1068,7 +1094,6 @@ class FPNorm1ModMulti:
     def __init__(self, width, single_cycle=True):
         self.width = width
         self.in_select = Signal(reset_less=True)
-        self.out_norm = Signal(reset_less=True)
         self.in_z = FPNumBase(width, False)
         self.in_of = Overflow()
         self.temp_z = FPNumBase(width, False)
@@ -1539,7 +1564,7 @@ class FPADDBaseMod(FPID):
         alm.setup(m, sc.o.a, sc.o.b, sc.in_mid)
 
         n1 = self.add_state(FPNormToPack(self.width, self.id_wid))
-        n1.setup(m, alm.out_z, alm.out_of, alm.in_mid)
+        n1.setup(m, alm.a1o.z, alm.a1o.of, alm.in_mid)
 
         ppz = self.add_state(FPPutZ("pack_put_z", n1.out_z, self.out_z,
                                     n1.in_mid, self.out_mid))