use ospec in FPAddAlignSingleAdd
[ieee754fpu.git] / src / add / nmigen_add_experiment.py
index 744f5049451f81def45afcb1acac14fc19eb1d70..e3e904bc05c61156775973b33c677fed9ae14313 100644 (file)
@@ -716,8 +716,7 @@ class FPAddAlignSingleAdd(FPState, FPID):
         self.a0_out_z = FPNumBase(width, False)
 
         self.a1mod = FPAddStage1Mod(width)
-        self.out_z = FPNumBase(width, False)
-        self.out_of = Overflow()
+        self.a1o = self.a1mod.ospec()
 
     def setup(self, m, in_a, in_b, in_mid):
         """ links module to inputs and outputs
@@ -736,8 +735,7 @@ class FPAddAlignSingleAdd(FPState, FPID):
 
     def action(self, m):
         self.idsync(m)
-        m.d.sync += self.out_of.eq(self.a1mod.out_of)
-        m.d.sync += self.out_z.eq(self.a1mod.out_z)
+        m.d.sync += self.a1o.eq(self.a1mod.o)
         m.next = "normalise_1"
 
 
@@ -836,26 +834,41 @@ class FPAddStage0(FPState, FPID):
         m.next = "add_1"
 
 
+class FPAddStage1Data:
+
+    def __init__(self, width):
+        self.z = FPNumBase(width, False)
+        self.of = Overflow()
+
+    def eq(self, i):
+        return [self.z.eq(i.z), self.of.eq(i.of)]
+
+
+
 class FPAddStage1Mod(FPState):
     """ Second stage of add: preparation for normalisation.
         detects when tot sum is too big (tot[27] is kinda a carry bit)
     """
 
     def __init__(self, width):
-        self.out_norm = Signal(reset_less=True)
-        self.in_z = FPNumBase(width, False)
-        self.in_tot = Signal(self.in_z.m_width + 4, reset_less=True)
-        self.out_z = FPNumBase(width, False)
-        self.out_of = Overflow()
+        self.width = width
+        self.i = self.ispec()
+        self.o = self.ospec()
+
+    def ispec(self):
+        return FPAddStage0Data(self.width)
+
+    def ospec(self):
+        return FPAddStage1Data(self.width)
 
     def setup(self, m, in_tot, in_z):
         """ links module to inputs and outputs
         """
         m.submodules.add1 = self
-        m.submodules.add1_out_overflow = self.out_of
+        m.submodules.add1_out_overflow = self.o.of
 
-        m.d.comb += self.in_z.eq(in_z)
-        m.d.comb += self.in_tot.eq(in_tot)
+        m.d.comb += self.i.z.eq(in_z)
+        m.d.comb += self.i.tot.eq(in_tot)
 
     def elaborate(self, platform):
         m = Module()
@@ -863,25 +876,25 @@ class FPAddStage1Mod(FPState):
         #m.submodules.norm1_out_overflow = self.out_of
         #m.submodules.norm1_in_z = self.in_z
         #m.submodules.norm1_out_z = self.out_z
-        m.d.comb += self.out_z.eq(self.in_z)
+        m.d.comb += self.o.z.eq(self.i.z)
         # tot[-1] (MSB) gets set when the sum overflows. shift result down
-        with m.If(self.in_tot[-1]):
+        with m.If(self.i.tot[-1]):
             m.d.comb += [
-                self.out_z.m.eq(self.in_tot[4:]),
-                self.out_of.m0.eq(self.in_tot[4]),
-                self.out_of.guard.eq(self.in_tot[3]),
-                self.out_of.round_bit.eq(self.in_tot[2]),
-                self.out_of.sticky.eq(self.in_tot[1] | self.in_tot[0]),
-                self.out_z.e.eq(self.in_z.e + 1)
+                self.o.z.m.eq(self.i.tot[4:]),
+                self.o.of.m0.eq(self.i.tot[4]),
+                self.o.of.guard.eq(self.i.tot[3]),
+                self.o.of.round_bit.eq(self.i.tot[2]),
+                self.o.of.sticky.eq(self.i.tot[1] | self.i.tot[0]),
+                self.o.z.e.eq(self.i.z.e + 1)
         ]
         # tot[-1] (MSB) zero case
         with m.Else():
             m.d.comb += [
-                self.out_z.m.eq(self.in_tot[3:]),
-                self.out_of.m0.eq(self.in_tot[3]),
-                self.out_of.guard.eq(self.in_tot[2]),
-                self.out_of.round_bit.eq(self.in_tot[1]),
-                self.out_of.sticky.eq(self.in_tot[0])
+                self.o.z.m.eq(self.i.tot[3:]),
+                self.o.of.m0.eq(self.i.tot[3]),
+                self.o.of.guard.eq(self.i.tot[2]),
+                self.o.of.round_bit.eq(self.i.tot[1]),
+                self.o.of.sticky.eq(self.i.tot[0])
         ]
         return m
 
@@ -1081,7 +1094,6 @@ class FPNorm1ModMulti:
     def __init__(self, width, single_cycle=True):
         self.width = width
         self.in_select = Signal(reset_less=True)
-        self.out_norm = Signal(reset_less=True)
         self.in_z = FPNumBase(width, False)
         self.in_of = Overflow()
         self.temp_z = FPNumBase(width, False)
@@ -1552,7 +1564,7 @@ class FPADDBaseMod(FPID):
         alm.setup(m, sc.o.a, sc.o.b, sc.in_mid)
 
         n1 = self.add_state(FPNormToPack(self.width, self.id_wid))
-        n1.setup(m, alm.out_z, alm.out_of, alm.in_mid)
+        n1.setup(m, alm.a1o.z, alm.a1o.of, alm.in_mid)
 
         ppz = self.add_state(FPPutZ("pack_put_z", n1.out_z, self.out_z,
                                     n1.in_mid, self.out_mid))