create FPNum class
authorLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Thu, 14 Feb 2019 10:35:13 +0000 (10:35 +0000)
committerLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Thu, 14 Feb 2019 10:35:13 +0000 (10:35 +0000)
src/add/nmigen_add_experiment.py

index 8c99a1ff157e3564496e7a466e59aabbbf72bad6..7f90ca456fbe7574493352e380e81921d3c037dd 100644 (file)
@@ -5,6 +5,16 @@
 from nmigen import Module, Signal, Cat
 from nmigen.cli import main
 
+class FPNum:
+    def __init__(self, width, m_width=None):
+        self.width = width
+        if m_width is None:
+            m_width = width + 3
+        self.v = Signal(width)      # Latched copy of value
+        self.m = Signal(m_width)    # Mantissa: ??? seems to be 1 bit extra??
+        self.e = Signal((10, True)) # Exponent: 10 bits, signed
+        self.s = Signal()           # Sign bit
+
 
 class FPADD:
     def __init__(self, width):
@@ -29,9 +39,9 @@ class FPADD:
 
     def create_z(self, z, s, e, m):
         return [
-          z[31].eq(s),    # sign
-          z[23:31].eq(e), # exp
-          z[0:23].eq(m)   # mantissa
+          z.v[31].eq(s),    # sign
+          z.v[23:31].eq(e), # exp
+          z.v[0:23].eq(m)   # mantissa
         ]
 
     def nan(self, z, s):
@@ -44,19 +54,9 @@ class FPADD:
         m = Module()
 
         # Latches
-        a = Signal(self.width)
-        b = Signal(self.width)
-        z = Signal(self.width)
-
-        # Mantissa
-        a_m = Signal(27) # ??? seems to be 1 bit extra??
-        b_m = Signal(27) # ??? seems to be 1 bit extra??
-        z_m = Signal(24)
-
-        # Exponent: 10 bits, signed (the exponent bias is subtracted)
-        a_e = Signal((10, True))
-        b_e = Signal((10, True))
-        z_e = Signal((10, True))
+        a = FPNum(self.width)
+        b = FPNum(self.width)
+        z = FPNum(self.width, 24)
 
         # Sign
         a_s = Signal()
@@ -78,7 +78,7 @@ class FPADD:
                 with m.If((self.in_a_ack) & (self.in_a_stb)):
                     m.next = "get_b"
                     m.d.sync += [
-                        a.eq(self.in_a),
+                        a.v.eq(self.in_a),
                         self.in_a_ack.eq(0)
                     ]
                 with m.Else():
@@ -91,7 +91,7 @@ class FPADD:
                 with m.If((self.in_b_ack) & (self.in_b_stb)):
                     m.next = "get_a"
                     m.d.sync += [
-                        b.eq(self.in_b),
+                        b.v.eq(self.in_b),
                         self.in_b_ack.eq(0)
                     ]
                 with m.Else():
@@ -104,14 +104,14 @@ class FPADD:
                 m.next = "special_cases"
                 m.d.sync += [
                     # mantissa
-                    a_m.eq(Cat(0, 0, 0, a[0:23])),
-                    b_m.eq(Cat(0, 0, 0, b[0:23])),
+                    a.m.eq(Cat(0, 0, 0, a.v[0:23])),
+                    b.m.eq(Cat(0, 0, 0, b.v[0:23])),
                     # exponent (take off exponent bias, here)
-                    a_e.eq(Cat(a[23:31]) - 127),
-                    b_e.eq(Cat(b[23:31]) - 127),
+                    a.e.eq(Cat(a.v[23:31]) - 127),
+                    b.e.eq(Cat(b.v[23:31]) - 127),
                     # sign
-                    a_s.eq(Cat(a[31])),
-                    b_s.eq(Cat(b[31]))
+                    a.s.eq(Cat(a.v[31])),
+                    b.s.eq(Cat(b.v[31]))
                 ]
 
             # ******
@@ -120,59 +120,59 @@ class FPADD:
             with m.State("special_cases"):
 
                 # if a is NaN or b is NaN return NaN
-                with m.If(((a_e == 128) & (a_m != 0)) | \
-                          ((b_e == 128) & (b_m != 0))):
+                with m.If(((a.e == 128) & (a.m != 0)) | \
+                          ((b.e == 128) & (b.m != 0))):
                     m.next = "put_z"
                     m.d.sync += self.nan(z, 1)
 
                 # if a is inf return inf (or NaN)
-                with m.Elif(a_e == 128):
+                with m.Elif(a.e == 128):
                     m.next = "put_z"
-                    m.d.sync += self.inf(z, a_s)
+                    m.d.sync += self.inf(z, a.s)
                     # if a is inf and signs don't match return NaN
-                    with m.If((b_e == 128) & (a_s != b_s)):
-                        m.d.sync += self.nan(z, b_s)
+                    with m.If((b.e == 128) & (a.s != b.s)):
+                        m.d.sync += self.nan(z, b.s)
 
                 # if b is inf return inf
-                with m.Elif(b_e == 128):
+                with m.Elif(b.e == 128):
                     m.next = "put_z"
-                    m.d.sync += self.inf(z, b_s)
+                    m.d.sync += self.inf(z, b.s)
 
                 # if a is zero and b zero return signed-a/b
-                with m.Elif(((a_e == -127) & (a_m == 0)) & \
-                            ((b_e == -127) & (b_m == 0))):
+                with m.Elif(((a.e == -127) & (a.m == 0)) & \
+                            ((b.e == -127) & (b.m == 0))):
                     m.next = "put_z"
-                    m.d.sync += self.create_z(z, a_s & b_s,
-                                                 b_e[0:8] + 127,
-                                                 b_m[3:26])
+                    m.d.sync += self.create_z(z, a.s & b.s,
+                                                 b.e[0:8] + 127,
+                                                 b.m[3:26])
 
                 # if a is zero return b
-                with m.Elif((a_e == -127) & (a_m == 0)):
+                with m.Elif((a.e == -127) & (a.m == 0)):
                     m.next = "put_z"
-                    m.d.sync += self.create_z(z, b_s,
-                                                 b_e[0:8] + 127,
-                                                 b_m[3:26])
+                    m.d.sync += self.create_z(z, b.s,
+                                                 b.e[0:8] + 127,
+                                                 b.m[3:26])
 
                 # if b is zero return a
-                with m.Elif((b_e == -127) & (b_m == 0)):
+                with m.Elif((b.e == -127) & (b.m == 0)):
                     m.next = "put_z"
-                    m.d.sync += self.create_z(z, a_s,
-                                                 a_e[0:8] + 127,
-                                                 a_m[3:26])
+                    m.d.sync += self.create_z(z, a.s,
+                                                 a.e[0:8] + 127,
+                                                 a.m[3:26])
 
                 # Denormalised Number checks
                 with m.Else():
                     m.next = "align"
                     # denormalise a check
-                    with m.If(a_e == -127):
-                        m.d.sync += a_e.eq(-126) # limit a exponent
+                    with m.If(a.e == -127):
+                        m.d.sync += a.e.eq(-126) # limit a exponent
                     with m.Else():
-                        m.d.sync += a_m[26].eq(1) # set highest mantissa bit
+                        m.d.sync += a.m[26].eq(1) # set highest mantissa bit
                     # denormalise b check
-                    with m.If(b_e == -127):
-                        m.d.sync += b_e.eq(-126) # limit b exponent
+                    with m.If(b.e == -127):
+                        m.d.sync += b.e.eq(-126) # limit b exponent
                     with m.Else():
-                        m.d.sync += b_m[26].eq(1) # set highest mantissa bit
+                        m.d.sync += b.m[26].eq(1) # set highest mantissa bit
 
             # ******
             # align.  NOTE: this does *not* do single-cycle multi-shifting,
@@ -180,18 +180,18 @@ class FPADD:
 
             with m.State("align"):
                 # exponent of a greater than b: increment b exp, shift b mant
-                with m.If(a_e > b_e):
+                with m.If(a.e > b.e):
                     m.d.sync += [
-                      b_e.eq(b_e + 1),
-                      b_m.eq(b_m >> 1),
-                      b_m[0].eq(b_m[0] | b_m[1]) # moo??
+                      b.e.eq(b.e + 1),
+                      b.m.eq(b.m >> 1),
+                      b.m[0].eq(b.m[0] | b.m[1]) # moo??
                     ]
                 # exponent of b greater than a: increment a exp, shift a mant
-                with m.Elif(a_e < b_e):
+                with m.Elif(a.e < b.e):
                     m.d.sync += [
-                      a_e.eq(a_e + 1),
-                      a_m.eq(a_m >> 1),
-                      a_m[0].eq(a_m[0] | a_m[1]) # moo??
+                      a.e.eq(a.e + 1),
+                      a.m.eq(a.m >> 1),
+                      a.m[0].eq(a.m[0] | a.m[1]) # moo??
                     ]
                 # exponents equal: move to next stage.
                 with m.Else():
@@ -204,24 +204,24 @@ class FPADD:
 
             with m.State("add_0"):
                 m.next = "add_1"
-                m.d.sync += z_e.eq(a_e)
+                m.d.sync += z.e.eq(a.e)
                 # same-sign (both negative or both positive) add mantissas
-                with m.If(a_s == b_s):
+                with m.If(a.s == b.s):
                     m.d.sync += [
-                        tot.eq(a_m + b_m),
-                        z_s.eq(a_s)
+                        tot.eq(a.m + b.m),
+                        z_s.eq(a.s)
                     ]
                 # a mantissa greater than b, use a
-                with m.Elif(a_m >= b_m):
+                with m.Elif(a.m >= b.m):
                     m.d.sync += [
-                        tot.eq(a_m - b_m),
-                        z_s.eq(a_s)
+                        tot.eq(a.m - b.m),
+                        z_s.eq(a.s)
                     ]
                 # b mantissa greater than a, use b
                 with m.Else():
                     m.d.sync += [
-                        tot.eq(b_m - a_m),
-                        z_s.eq(b_s)
+                        tot.eq(b.m - a.m),
+                        z_s.eq(b.s)
                 ]
 
             # ******
@@ -233,16 +233,16 @@ class FPADD:
                 # tot[27] gets set when the sum overflows. shift result down
                 with m.If(tot[27]):
                     m.d.sync += [
-                        z_m.eq(tot[4:28]),
+                        z.m.eq(tot[4:28]),
                         guard.eq(tot[3]),
                         round_bit.eq(tot[2]),
                         sticky.eq(tot[1] | tot[0]),
-                        z_e.eq(z_e + 1)
+                        z.e.eq(z.e + 1)
                 ]
                 # tot[27] zero case
                 with m.Else():
                     m.d.sync += [
-                        z_m.eq(tot[3:27]),
+                        z.m.eq(tot[3:27]),
                         guard.eq(tot[2]),
                         round_bit.eq(tot[1]),
                         sticky.eq(tot[0])
@@ -256,11 +256,11 @@ class FPADD:
             #       the extra mantissa bits coming from tot[0..2]
 
             with m.State("normalise_1"):
-                with m.If((z_m[23] == 0) & (z_e > -126)):
+                with m.If((z.m[23] == 0) & (z.e > -126)):
                     m.d.sync +=[
-                        z_e.eq(z_e - 1),  # DECREASE exponent
-                        z_m.eq(z_m << 1), # shift mantissa UP
-                        z_m[0].eq(guard), # steal guard bit (was tot[2])
+                        z.e.eq(z.e - 1),  # DECREASE exponent
+                        z.m.eq(z.m << 1), # shift mantissa UP
+                        z.m[0].eq(guard), # steal guard bit (was tot[2])
                         guard.eq(round_bit), # steal round_bit (was tot[1])
                     ]
                 with m.Else():
@@ -274,11 +274,11 @@ class FPADD:
             #       the extra mantissa bits coming from tot[0..2]
 
             with m.State("normalise_2"):
-                with m.If(z_e < -126):
+                with m.If(z.e < -126):
                     m.d.sync +=[
-                        z_e.eq(z_e + 1),  # INCREASE exponent
-                        z_m.eq(z_m >> 1), # shift mantissa DOWN
-                        guard.eq(z_m[0]),
+                        z.e.eq(z.e + 1),  # INCREASE exponent
+                        z.m.eq(z.m >> 1), # shift mantissa DOWN
+                        guard.eq(z.m[0]),
                         round_bit.eq(guard),
                         sticky.eq(sticky | round_bit)
                     ]
@@ -290,10 +290,10 @@ class FPADD:
 
             with m.State("round"):
                 m.next = "pack"
-                with m.If(guard & (round_bit | sticky | z_m[0])):
-                    m.d.sync += z_m.eq(z_m + 1) # mantissa rounds up
-                    with m.If(z_m == 0xffffff): # all 1s
-                        m.d.sync += z_e.eq(z_e + 1) # exponent rounds up
+                with m.If(guard & (round_bit | sticky | z.m[0])):
+                    m.d.sync += z.m.eq(z.m + 1) # mantissa rounds up
+                    with m.If(z.m == 0xffffff): # all 1s
+                        m.d.sync += z.e.eq(z.e + 1) # exponent rounds up
 
         return m