switch to better CLDivRem algorithm
authorJacob Lifshay <programmerjake@gmail.com>
Thu, 5 May 2022 07:33:05 +0000 (00:33 -0700)
committerJacob Lifshay <programmerjake@gmail.com>
Thu, 5 May 2022 07:33:05 +0000 (00:33 -0700)
src/nmigen_gf/hdl/cldivrem.py
src/nmigen_gf/hdl/test/test_cldivrem.py

index 02ce034dcfb92aec4f6d2aa0eb0406e9c5358d94..73bdd1e9b75f144bcd81d5ad961b688f1a70f8c7 100644 (file)
@@ -121,26 +121,26 @@ def cldivrem_shifting(n, d, width):
     assert isinstance(d, int) and 0 <= d < 1 << width
     assert d != 0, "TODO: decide what happens on division by zero"
 
-    shift_wid = (width - 1).bit_length()
+    shape = CLDivRemShape(width)
 
     # `clz(d, width)`, but maxes out at `width - 1` instead of `width` in
-    # order to both fit in `shift_wid` bits and not shift by more than needed.
+    # order to both fit in `shape.shift_width` bits and to not shift by more
+    # than needed.
     shift = clz(d >> 1, width - 1)
-    assert shift < 1 << shift_wid, f"shift overflows a {shift_wid}-bit signal"
+    assert 0 <= shift < 1 << shape.shift_width, "shift overflow"
     d <<= shift
-    assert d < 1 << width, f"d overflows a {width}-bit signal"
-    n <<= shift
-    assert n < 1 << (width * 2), f"n overflows a {width * 2}-bit signal"
-    r = n
+    assert 0 <= d < 1 << shape.d_width, "d overflow"
+    r = n << shift
+    assert 0 <= r < 1 << shape.r_width, "r overflow"
     q = 0
-    for _ in range(width):
+    for step in range(width):
         q <<= 1
         r <<= 1
         if r >> (width * 2 - 1) != 0:
             r ^= d << width
             q |= 1
-        assert q < 1 << width, f"q overflows a {width}-bit signal"
-        assert r < 1 << (width * 2), f"r overflows a {width * 2}-bit signal"
+        assert 0 <= q < 1 << shape.q_width, "q overflow"
+        assert 0 <= r < 1 << shape.r_width, "r overflow"
     r >>= width
     r >>= shift
     return q, r
@@ -149,44 +149,73 @@ def cldivrem_shifting(n, d, width):
 @dataclass(frozen=True, unsafe_hash=True)
 class CLDivRemShape:
     width: int
-    n_width: int
 
     def __post_init__(self):
-        assert self.n_width >= self.width > 0
+        assert isinstance(self.width, int) and self.width >= 1, "invalid width"
 
     @property
     def done_step(self):
+        """the step number when iteration is finished
+        -- the largest `CLDivRemState.step` will get
+        """
         return self.width
 
     @property
     def step_range(self):
+        """the range that `CLDivRemState.step` will fall in.
+
+        returns: range
+        """
         return range(self.done_step + 1)
 
+    @property
+    def d_width(self):
+        """bit-width of the internal signal `CLDivRemState.d`"""
+        return self.width
+
+    @property
+    def r_width(self):
+        """bit-width of the internal signal `CLDivRemState.r`"""
+        return self.width * 2
+
+    @property
+    def q_width(self):
+        """bit-width of the internal signal `CLDivRemState.q`"""
+        return self.width
+
+    @property
+    def shift_width(self):
+        """bit-width of the internal signal `CLDivRemState.shift`"""
+        return (self.width - 1).bit_length()
+
 
 @dataclass(frozen=True, eq=False)
 class CLDivRemState:
     shape: CLDivRemShape
     name: str
+    step: Signal = field(init=False)
     d: Signal = field(init=False)
     r: Signal = field(init=False)
     q: Signal = field(init=False)
-    step: Signal = field(init=False)
+    shift: Signal = field(init=False)
 
     def __init__(self, shape, *, name=None, src_loc_at=0):
         assert isinstance(shape, CLDivRemShape)
         if name is None:
             name = Signal(src_loc_at=1 + src_loc_at).name
         assert isinstance(name, str)
-        d = Signal(2 * shape.width, name=f"{name}_d")
-        r = Signal(shape.n_width, name=f"{name}_r")
-        q = Signal(shape.width, name=f"{name}_q")
-        step = Signal(shape.width, name=f"{name}_step")
+        step = Signal(shape.step_range, name=f"{name}_step")
+        d = Signal(shape.d_width, name=f"{name}_d")
+        r = Signal(shape.r_width, name=f"{name}_r")
+        q = Signal(shape.q_width, name=f"{name}_q")
+        shift = Signal(shape.shift_width, name=f"{name}_shift")
         object.__setattr__(self, "shape", shape)
         object.__setattr__(self, "name", name)
+        object.__setattr__(self, "step", step)
         object.__setattr__(self, "d", d)
         object.__setattr__(self, "r", r)
         object.__setattr__(self, "q", q)
-        object.__setattr__(self, "step", step)
+        object.__setattr__(self, "shift", shift)
 
     def eq(self, rhs):
         assert isinstance(rhs, CLDivRemState)
@@ -212,11 +241,24 @@ class CLDivRemState:
         assert isinstance(steps, int) and steps >= 0
         return self.step >= max(0, self.shape.done_step - steps)
 
+    def get_output(self):
+        return self.q, (self.r >> self.shape.width) >> self.shift
+
     def set_to_initial(self, m, n, d):
         assert isinstance(m, Module)
+        n = Value.cast(n)  # convert to Value
+        d = Value.cast(d)  # convert to Value
+        clz_mod = CLZ(self.shape.width - 1)
+        # can't name submodule since it would conflict if this function is
+        # called multiple times in a Module
+        m.submodules += clz_mod
+        assert clz_mod.lz.width == self.shape.shift_width, \
+            "internal inconsistency -- mismatched shift signal width"
         m.d.comb += [
-            self.d.eq(Value.cast(d) << self.shape.width),
-            self.r.eq(n),
+            clz_mod.sig_in.eq(d >> 1),
+            self.shift.eq(clz_mod.lz),
+            self.d.eq(d << self.shift),
+            self.r.eq(n << self.shift),
             self.q.eq(0),
             self.step.eq(0),
         ]
@@ -226,31 +268,27 @@ class CLDivRemState:
         assert isinstance(state_in, CLDivRemState)
         assert state_in.shape == self.shape
         assert self is not state_in, "a.set_to_next(m, a) is not allowed"
-
-        equal_leading_zero_count = EqualLeadingZeroCount(self.shape.n_width)
-        # can't name submodule since it would conflict if this function is
-        # called multiple times in a Module
-        m.submodules += equal_leading_zero_count
+        width = self.shape.width
 
         with m.If(state_in.done):
             m.d.comb += self.eq(state_in)
         with m.Else():
             m.d.comb += [
                 self.step.eq(state_in.step + 1),
-                self.d.eq(state_in.d >> 1),
-                equal_leading_zero_count.a.eq(self.d),
-                equal_leading_zero_count.b.eq(state_in.r),
+                self.d.eq(state_in.d),
+                self.shift.eq(state_in.shift),
             ]
-            d_top = self.d[self.shape.n_width:]
-            with m.If(equal_leading_zero_count.out & (d_top == 0)):
+            q = state_in.q << 1
+            r = state_in.r << 1
+            with m.If(r[width * 2 - 1]):
                 m.d.comb += [
-                    self.r.eq(state_in.r ^ self.d),
-                    self.q.eq((state_in.q << 1) | 1),
+                    self.q.eq(q | 1),
+                    self.r.eq(r ^ (state_in.d << width)),
                 ]
             with m.Else():
                 m.d.comb += [
-                    self.r.eq(state_in.r),
-                    self.q.eq(state_in.q << 1),
+                    self.q.eq(q),
+                    self.r.eq(r),
                 ]
 
 
@@ -258,7 +296,7 @@ class CLDivRemInputData:
     def __init__(self, shape):
         assert isinstance(shape, CLDivRemShape)
         self.shape = shape
-        self.n = Signal(shape.n_width)
+        self.n = Signal(shape.width)
         self.d = Signal(shape.width)
 
     def __iter__(self):
@@ -302,31 +340,15 @@ class CLDivRemFSMStage(ControlBase):
         the shape
     steps_per_clock: int
         number of steps that should be taken per clock cycle
-    in_valid: Signal()
-        input. true when the data inputs (`n` and `d`) are valid.
-        data transfer in occurs when `in_valid & in_ready`.
-    in_ready: Signal()
-        output. true when this FSM is ready to accept input.
-        data transfer in occurs when `in_valid & in_ready`.
-    n: Signal(shape.n_width)
-        numerator in, the value must be small enough that `q` and `r` don't
-        overflow. having `n_width == width` is sufficient.
-    d: Signal(shape.width)
-        denominator in, must be non-zero.
-    q: Signal(shape.width)
-        quotient out.
-    r: Signal(shape.width)
-        remainder out.
-    out_valid: Signal()
-        output. true when the data outputs (`q` and `r`) are valid
-        (or are junk because the inputs were out of range).
-        data transfer out occurs when `out_valid & out_ready`.
-    out_ready: Signal()
-        input. true when the output can be read.
-        data transfer out occurs when `out_valid & out_ready`.
+    pspec:
+        pipe-spec
+    empty: Signal()
+        true if nothing is stored in `self.saved_state`
+    saved_state: CLDivRemState()
+        the saved state that is currently being worked on.
     """
 
-    def __init__(self, pspec, shape, *, steps_per_clock=4):
+    def __init__(self, pspec, shape, *, steps_per_clock=8):
         assert isinstance(shape, CLDivRemShape)
         assert isinstance(steps_per_clock, int) and steps_per_clock >= 1
         self.shape = shape
@@ -352,9 +374,7 @@ class CLDivRemFSMStage(ControlBase):
 
         # TODO: handle cancellation
 
-        state_will_be_done = self.saved_state.will_be_done_after(
-            self.steps_per_clock)
-        m.d.comb += self.n.o_valid.eq(~self.empty & state_will_be_done)
+        m.d.comb += self.n.o_valid.eq(~self.empty & self.saved_state.done)
         m.d.comb += self.p.o_ready.eq(self.empty)
 
         def make_nc(i):
@@ -362,19 +382,21 @@ class CLDivRemFSMStage(ControlBase):
         next_chain = [make_nc(i) for i in range(self.steps_per_clock + 1)]
         for i in range(self.steps_per_clock):
             next_chain[i + 1].set_to_next(m, next_chain[i])
-        m.d.sync += self.saved_state.eq(next_chain[-1])
-        m.d.comb += o_data.q.eq(next_chain[-1].q)
-        m.d.comb += o_data.r.eq(next_chain[-1].r)
+        m.d.comb += next_chain[0].eq(self.saved_state)
+        out_q, out_r = self.saved_state.get_output()
+        m.d.comb += o_data.q.eq(out_q)
+        m.d.comb += o_data.r.eq(out_r)
+        initial_state = CLDivRemState(self.shape)
+        initial_state.set_to_initial(m, n=i_data.n, d=i_data.d)
 
         with m.If(self.empty):
-            next_chain[0].set_to_initial(m, n=i_data.n, d=i_data.d)
+            m.d.sync += self.saved_state.eq(initial_state)
             with m.If(self.p.i_valid):
                 m.d.sync += self.empty.eq(0)
         with m.Else():
-            m.d.comb += next_chain[0].eq(self.saved_state)
+            m.d.sync += self.saved_state.eq(next_chain[-1])
             with m.If(self.n.i_ready & self.n.o_valid):
                 m.d.sync += self.empty.eq(1)
-
         return m
 
     def __iter__(self):
index de7cf34025e5e7c13e173d3c30a3440abb63c2f1..27820705176c7528d2298a499292f72b03377b9b 100644 (file)
@@ -144,9 +144,12 @@ class TestCLDivRemShifting(FHDLTestCase):
 class TestCLDivRemComb(FHDLTestCase):
     def tst(self, shape, full):
         assert isinstance(shape, CLDivRemShape)
+        width = shape.width
         m = Module()
-        n_in = Signal(shape.n_width)
-        d_in = Signal(shape.width)
+        n_in = Signal(width)
+        d_in = Signal(width)
+        q_out = Signal(width)
+        r_out = Signal(width)
         states: "list[CLDivRemState]" = []
         for i in shape.step_range:
             states.append(CLDivRemState(shape, name=f"state_{i}"))
@@ -154,13 +157,14 @@ class TestCLDivRemComb(FHDLTestCase):
                 states[i].set_to_initial(m, n=n_in, d=d_in)
             else:
                 states[i].set_to_next(m, states[i - 1])
+        q, r = states[-1].get_output()
+        m.d.comb += [q_out.eq(q), r_out.eq(r)]
 
         def case(n, d):
             assert isinstance(n, int)
             assert isinstance(d, int)
-            max_width = max(shape.width, shape.n_width)
             if d != 0:
-                expected_q, expected_r = cldivrem(n, d, width=max_width)
+                expected_q, expected_r = cldivrem_shifting(n, d, width)
             else:
                 expected_q = expected_r = 0
             with self.subTest(n=hex(n), d=hex(d),
@@ -175,35 +179,38 @@ class TestCLDivRemComb(FHDLTestCase):
                         step = yield states[i].step
                         self.assertEqual(done, i >= shape.done_step)
                         self.assertEqual(step, i)
-                q = yield states[-1].q
-                r = yield states[-1].r
+                q = yield q_out
+                r = yield r_out
                 with self.subTest(q=hex(q), r=hex(r)):
                     # only check results when inputs are valid
-                    if d != 0 and (expected_q >> shape.width) == 0:
+                    if d != 0:
                         self.assertEqual(q, expected_q)
                         self.assertEqual(r, expected_r)
 
         def process():
             if full:
-                for n in range(1 << shape.n_width):
-                    for d in range(1 << shape.width):
+                for n in range(1 << width):
+                    for d in range(1 << width):
                         yield from case(n, d)
             else:
                 for i in range(100):
                     n = hash_256(f"cldivrem comb n {i}")
-                    n = Const.normalize(n, unsigned(shape.n_width))
+                    n = Const.normalize(n, unsigned(width))
                     d = hash_256(f"cldivrem comb d {i}")
-                    d = Const.normalize(d, unsigned(shape.width))
+                    d = Const.normalize(d, unsigned(width))
                     yield from case(n, d)
-        with do_sim(self, m, [n_in, d_in, states[-1].q, states[-1].r]) as sim:
+        with do_sim(self, m, [n_in, d_in, q_out, r_out]) as sim:
             sim.add_process(process)
             sim.run()
 
     def test_4(self):
-        self.tst(CLDivRemShape(width=4, n_width=4), full=True)
+        self.tst(CLDivRemShape(width=4), full=True)
+
+    def test_6(self):
+        self.tst(CLDivRemShape(width=6), full=True)
 
-    def test_8_by_4(self):
-        self.tst(CLDivRemShape(width=4, n_width=8), full=True)
+    def test_8(self):
+        self.tst(CLDivRemShape(width=8), full=False)
 
 
 class TestCLDivRemFSM(FHDLTestCase):
@@ -214,7 +221,7 @@ class TestCLDivRemFSM(FHDLTestCase):
         dut = CLDivRemFSMStage(pspec, shape, steps_per_clock=steps_per_clock)
         i_data: CLDivRemInputData = dut.p.i_data
         o_data: CLDivRemOutputData = dut.n.o_data
-        self.assertEqual(i_data.n.shape(), unsigned(shape.n_width))
+        self.assertEqual(i_data.n.shape(), unsigned(shape.width))
         self.assertEqual(i_data.d.shape(), unsigned(shape.width))
         self.assertEqual(o_data.q.shape(), unsigned(shape.width))
         self.assertEqual(o_data.r.shape(), unsigned(shape.width))
@@ -222,9 +229,8 @@ class TestCLDivRemFSM(FHDLTestCase):
         def case(n, d):
             assert isinstance(n, int)
             assert isinstance(d, int)
-            max_width = max(shape.width, shape.n_width)
             if d != 0:
-                expected_q, expected_r = cldivrem(n, d, width=max_width)
+                expected_q, expected_r = cldivrem(n, d, width=shape.width)
             else:
                 expected_q = expected_r = 0
             with self.subTest(n=hex(n), d=hex(d),
@@ -245,8 +251,7 @@ class TestCLDivRemFSM(FHDLTestCase):
                 yield i_data.n.eq(-1)
                 yield i_data.d.eq(-1)
                 yield dut.p.i_valid.eq(0)
-                for i in range(steps_per_clock * 2, shape.done_step,
-                               steps_per_clock):
+                for step in range(0, shape.done_step, steps_per_clock):
                     yield Delay(0.1e-6)
                     valid = yield dut.n.o_valid
                     ready = yield dut.p.o_ready
@@ -279,13 +284,13 @@ class TestCLDivRemFSM(FHDLTestCase):
 
         def process():
             if full:
-                for n in range(1 << shape.n_width):
+                for n in range(1 << shape.width):
                     for d in range(1 << shape.width):
                         yield from case(n, d)
             else:
                 for i in range(100):
                     n = hash_256(f"cldivrem fsm n {i}")
-                    n = Const.normalize(n, unsigned(shape.n_width))
+                    n = Const.normalize(n, unsigned(shape.width))
                     d = hash_256(f"cldivrem fsm d {i}")
                     d = Const.normalize(d, unsigned(shape.width))
                     yield from case(n, d)
@@ -296,20 +301,40 @@ class TestCLDivRemFSM(FHDLTestCase):
             sim.run()
 
     def test_4_step_1(self):
-        self.tst(CLDivRemShape(width=4, n_width=4),
+        self.tst(CLDivRemShape(width=4),
                  full=True,
                  steps_per_clock=1)
 
     def test_4_step_2(self):
-        self.tst(CLDivRemShape(width=4, n_width=4),
+        self.tst(CLDivRemShape(width=4),
                  full=True,
                  steps_per_clock=2)
 
     def test_4_step_3(self):
-        self.tst(CLDivRemShape(width=4, n_width=4),
+        self.tst(CLDivRemShape(width=4),
                  full=True,
                  steps_per_clock=3)
 
+    def test_4_step_4(self):
+        self.tst(CLDivRemShape(width=4),
+                 full=True,
+                 steps_per_clock=4)
+
+    def test_8_step_4(self):
+        self.tst(CLDivRemShape(width=8),
+                 full=False,
+                 steps_per_clock=4)
+
+    def test_64_step_4(self):
+        self.tst(CLDivRemShape(width=64),
+                 full=False,
+                 steps_per_clock=4)
+
+    def test_64_step_8(self):
+        self.tst(CLDivRemShape(width=64),
+                 full=False,
+                 steps_per_clock=8)
+
 
 if __name__ == "__main__":
     unittest.main()