split step counter into clock and substep
authorJacob Lifshay <programmerjake@gmail.com>
Fri, 6 May 2022 03:10:32 +0000 (20:10 -0700)
committerJacob Lifshay <programmerjake@gmail.com>
Fri, 6 May 2022 03:10:32 +0000 (20:10 -0700)
this allows substep to be completely optimized away by yosys for CLDivRemFSMStage

src/nmigen_gf/hdl/cldivrem.py
src/nmigen_gf/hdl/test/test_cldivrem.py

index 48105fa7e0d005854afb6da4d7a215d3431ec57d..31650e53db40ea65b429629298001be5fe2f1232 100644 (file)
@@ -10,13 +10,13 @@ https://bugs.libre-soc.org/show_bug.cgi?id=784
 """
 
 from dataclasses import dataclass, field, fields
-from nmigen.hdl.ast import Signal, Value
+from nmigen.hdl.ast import Signal, Value, Assert
 from nmigen.hdl.dsl import Module
 from nmutil.singlepipe import ControlBase
 from nmutil.clz import CLZ, clz
 
 
-def cldivrem_shifting(n, d, width):
+def cldivrem_shifting(n, d, shape):
     """ Carry-less Division and Remainder based on shifting at start and end
         allowing us to get away with checking a single bit each iteration
         rather than checking for equal degrees every iteration.
@@ -24,57 +24,104 @@ def cldivrem_shifting(n, d, width):
         each input/output.
         Returns a tuple `q, r` of the quotient and remainder.
     """
-    assert isinstance(width, int) and width >= 1
-    assert isinstance(n, int) and 0 <= n < 1 << width
-    assert isinstance(d, int) and 0 <= d < 1 << width
+    assert isinstance(shape, CLDivRemShape)
+    assert isinstance(n, int) and 0 <= n < 1 << shape.width
+    assert isinstance(d, int) and 0 <= d < 1 << shape.width
     assert d != 0, "TODO: decide what happens on division by zero"
 
-    shape = CLDivRemShape(width)
-
-    # `clz(d, width)`, but maxes out at `width - 1` instead of `width` in
-    # order to both fit in `shape.shift_width` bits and to not shift by more
-    # than needed.
-    shift = clz(d >> 1, width - 1)
-    assert 0 <= shift < 1 << shape.shift_width, "shift overflow"
-    d <<= shift
-    assert 0 <= d < 1 << shape.d_width, "d overflow"
-    r = n << shift
-    assert 0 <= r < 1 << shape.r_width, "r overflow"
-    q = 0
-    for step in range(width):
+    # declare locals so nonlocal works
+    r = q = shift = clock = substep = NotImplemented
+
+    # functions match up to HDL parts:
+
+    def set_to_initial():
+        nonlocal d, r, q, clock, substep, shift
+        # `clz(d, shape.width)`, but maxes out at `shape.width - 1` instead of
+        # `shape.width` in order to both fit in `shape.shift_width` bits and
+        # to not shift by more than needed.
+        shift = clz(d >> 1, shape.width - 1)
+        assert 0 <= shift < 1 << shape.shift_width, "shift overflow"
+        d <<= shift
+        assert 0 <= d < 1 << shape.d_width, "d overflow"
+        r = n << shift
+        assert 0 <= r < 1 << shape.r_width, "r overflow"
+        q = 0
+        clock = 0
+        substep = 0
+
+    def done():
+        return clock == shape.done_clock
+
+    def set_to_next():
+        nonlocal r, q, clock, substep
+        substep += 1
+        substep %= shape.steps_per_clock
+        if done():
+            return
+        elif substep == 0:
+            clock += 1
+        if clock == shape.width // shape.steps_per_clock \
+                and substep >= shape.width % shape.steps_per_clock:
+            clock = shape.done_clock
         q <<= 1
         r <<= 1
-        if r >> (width * 2 - 1) != 0:
-            r ^= d << width
+        if r >> (shape.width * 2 - 1) != 0:
+            r ^= d << shape.width
             q |= 1
         assert 0 <= q < 1 << shape.q_width, "q overflow"
         assert 0 <= r < 1 << shape.r_width, "r overflow"
-    r >>= width
-    r >>= shift
-    return q, r
+
+    def get_output():
+        return q, (r >> shape.width) >> shift
+
+    set_to_initial()
+
+    # one clock-cycle per outer loop
+    while not done():
+        for expected_substep in range(shape.steps_per_clock):
+            assert substep == expected_substep
+            set_to_next()
+
+    return get_output()
 
 
 @dataclass(frozen=True, unsafe_hash=True)
 class CLDivRemShape:
     width: int
+    """bit-width of each of the carry-less div/rem inputs and outputs"""
+
+    steps_per_clock: int = 8
+    """number of steps that should be taken per clock cycle"""
 
     def __post_init__(self):
         assert isinstance(self.width, int) and self.width >= 1, "invalid width"
+        assert (isinstance(self.steps_per_clock, int)
+                and self.steps_per_clock >= 1), "invalid steps_per_clock"
 
     @property
-    def done_step(self):
-        """the step number when iteration is finished
-        -- the largest `CLDivRemState.step` will get
+    def done_clock(self):
+        """the clock tick number when iteration is finished
+        -- the largest `CLDivRemState.clock` will get
         """
-        return self.width
+        if self.width % self.steps_per_clock == 0:
+            return self.width // self.steps_per_clock
+        return self.width // self.steps_per_clock + 1
 
     @property
-    def step_range(self):
-        """the range that `CLDivRemState.step` will fall in.
+    def clock_range(self):
+        """the range that `CLDivRemState.clock` will fall in.
 
         returns: range
         """
-        return range(self.done_step + 1)
+        return range(self.done_clock + 1)
+
+    @property
+    def substep_range(self):
+        """the range that `CLDivRemState.substep` will fall in.
+
+        returns: range
+        """
+        return range(self.steps_per_clock)
 
     @property
     def d_width(self):
@@ -101,7 +148,8 @@ class CLDivRemShape:
 class CLDivRemState:
     shape: CLDivRemShape
     name: str
-    step: Signal = field(init=False)
+    clock: Signal = field(init=False)
+    substep: Signal = field(init=False)
     d: Signal = field(init=False)
     r: Signal = field(init=False)
     q: Signal = field(init=False)
@@ -112,14 +160,16 @@ class CLDivRemState:
         if name is None:
             name = Signal(src_loc_at=1 + src_loc_at).name
         assert isinstance(name, str)
-        step = Signal(shape.step_range, name=f"{name}_step")
+        clock = Signal(shape.clock_range, name=f"{name}_clock")
+        substep = Signal(shape.substep_range, name=f"{name}_substep", reset=0)
         d = Signal(shape.d_width, name=f"{name}_d")
         r = Signal(shape.r_width, name=f"{name}_r")
         q = Signal(shape.q_width, name=f"{name}_q")
         shift = Signal(shape.shift_width, name=f"{name}_shift")
         object.__setattr__(self, "shape", shape)
         object.__setattr__(self, "name", name)
-        object.__setattr__(self, "step", step)
+        object.__setattr__(self, "clock", clock)
+        object.__setattr__(self, "substep", substep)
         object.__setattr__(self, "d", d)
         object.__setattr__(self, "r", r)
         object.__setattr__(self, "q", q)
@@ -141,13 +191,7 @@ class CLDivRemState:
 
     @property
     def done(self):
-        return self.will_be_done_after(steps=0)
-
-    def will_be_done_after(self, steps):
-        """ Returns True if this state will be done after
-            another `steps` passes through `set_to_next`."""
-        assert isinstance(steps, int) and steps >= 0
-        return self.step >= max(0, self.shape.done_step - steps)
+        return self.clock == self.shape.done_clock
 
     def get_output(self):
         return self.q, (self.r >> self.shape.width) >> self.shift
@@ -168,21 +212,51 @@ class CLDivRemState:
             self.d.eq(d << self.shift),
             self.r.eq(n << self.shift),
             self.q.eq(0),
-            self.step.eq(0),
+            self.clock.eq(0),
+            self.substep.eq(0),
         ]
 
+    def eq_but_zero_substep(self, rhs, do_assert):
+        assert isinstance(rhs, CLDivRemState)
+        for f in fields(CLDivRemState):
+            if f.name in ("shape", "name"):
+                continue
+            l = getattr(self, f.name)
+            r = getattr(rhs, f.name)
+            if f.name == "substep":
+                if do_assert:
+                    yield Assert(r == 0)
+                r = 0
+            yield l.eq(r)
+
     def set_to_next(self, m, state_in):
         assert isinstance(m, Module)
         assert isinstance(state_in, CLDivRemState)
         assert state_in.shape == self.shape
         assert self is not state_in, "a.set_to_next(m, a) is not allowed"
         width = self.shape.width
+        substep_wraps = state_in.substep >= self.shape.steps_per_clock - 1
+        with m.If(substep_wraps):
+            m.d.comb += self.substep.eq(0)
+        with m.Else():
+            m.d.comb += self.substep.eq(state_in.substep + 1)
 
         with m.If(state_in.done):
-            m.d.comb += self.eq(state_in)
+            m.d.comb += [
+                self.clock.eq(state_in.clock),
+                self.d.eq(state_in.d),
+                self.r.eq(state_in.r),
+                self.q.eq(state_in.q),
+                self.shift.eq(state_in.shift),
+            ]
         with m.Else():
+            clock = state_in.clock + substep_wraps
+            with m.If((clock == width // self.shape.steps_per_clock)
+                      & (self.substep >= width % self.shape.steps_per_clock)):
+                m.d.comb += self.clock.eq(self.shape.done_clock)
+            with m.Else():
+                m.d.comb += self.clock.eq(clock)
             m.d.comb += [
-                self.step.eq(state_in.step + 1),
                 self.d.eq(state_in.d),
                 self.shift.eq(state_in.shift),
             ]
@@ -239,6 +313,12 @@ class CLDivRemOutputData:
             self.r.eq(rhs.r),
         ]
 
+    def eq_output(self, state):
+        assert isinstance(state, CLDivRemState)
+        assert state.shape == self.shape
+        q, r = state.get_output()
+        return [self.q.eq(q), self.r.eq(r)]
+
 
 class CLDivRemFSMStage(ControlBase):
     """carry-less div/rem
@@ -246,8 +326,6 @@ class CLDivRemFSMStage(ControlBase):
     Attributes:
     shape: CLDivRemShape
         the shape
-    steps_per_clock: int
-        number of steps that should be taken per clock cycle
     pspec:
         pipe-spec
     empty: Signal()
@@ -256,11 +334,9 @@ class CLDivRemFSMStage(ControlBase):
         the saved state that is currently being worked on.
     """
 
-    def __init__(self, pspec, shape, *, steps_per_clock=8):
+    def __init__(self, pspec, shape):
         assert isinstance(shape, CLDivRemShape)
-        assert isinstance(steps_per_clock, int) and steps_per_clock >= 1
         self.shape = shape
-        self.steps_per_clock = steps_per_clock
         self.pspec = pspec  # store now: used in ispec and ospec
         super().__init__(stage=self)
         self.empty = Signal(reset=1)
@@ -279,6 +355,7 @@ class CLDivRemFSMStage(ControlBase):
         m = super().elaborate(platform)
         i_data: CLDivRemInputData = self.p.i_data
         o_data: CLDivRemOutputData = self.n.o_data
+        steps_per_clock = self.shape.steps_per_clock
 
         # TODO: handle cancellation
 
@@ -287,22 +364,24 @@ class CLDivRemFSMStage(ControlBase):
 
         def make_nc(i):
             return CLDivRemState(self.shape, name=f"next_chain_{i}")
-        next_chain = [make_nc(i) for i in range(self.steps_per_clock + 1)]
-        for i in range(self.steps_per_clock):
+        next_chain = [make_nc(i) for i in range(steps_per_clock + 1)]
+        for i in range(steps_per_clock):
             next_chain[i + 1].set_to_next(m, next_chain[i])
         m.d.comb += next_chain[0].eq(self.saved_state)
-        out_q, out_r = self.saved_state.get_output()
-        m.d.comb += o_data.q.eq(out_q)
-        m.d.comb += o_data.r.eq(out_r)
+        m.d.comb += o_data.eq_output(self.saved_state)
         initial_state = CLDivRemState(self.shape)
         initial_state.set_to_initial(m, n=i_data.n, d=i_data.d)
 
+        do_assert = platform == "formal"
+
         with m.If(self.empty):
-            m.d.sync += self.saved_state.eq(initial_state)
+            m.d.sync += self.saved_state.eq_but_zero_substep(initial_state,
+                                                             do_assert)
             with m.If(self.p.i_valid):
                 m.d.sync += self.empty.eq(0)
         with m.Else():
-            m.d.sync += self.saved_state.eq(next_chain[-1])
+            m.d.sync += self.saved_state.eq_but_zero_substep(next_chain[-1],
+                                                             do_assert)
             with m.If(self.n.i_ready & self.n.o_valid):
                 m.d.sync += self.empty.eq(1)
         return m
index 438b547fca5cc822f91c6add9f71cadd2ddfe8cb..efcd23e08f8511e51e0e5e34a37ef10201ac5f9e 100644 (file)
@@ -17,39 +17,68 @@ from nmigen_gf.reference.cldivrem import cldivrem
 
 
 class TestCLDivRemShifting(FHDLTestCase):
-    def tst(self, width, full):
+    def tst(self, shape, full):
+        assert isinstance(shape, CLDivRemShape)
+
         def case(n, d):
             assert isinstance(n, int)
             assert isinstance(d, int)
             if d != 0:
-                expected_q, expected_r = cldivrem(n, d, width=width)
-                q, r = cldivrem_shifting(n, d, width=width)
+                expected_q, expected_r = cldivrem(n, d, width=shape.width)
+                q, r = cldivrem_shifting(n, d, shape)
             else:
                 expected_q = expected_r = 0
                 q = r = 0
-            with self.subTest(n=hex(n), d=hex(d),
-                              expected_q=hex(expected_q),
+            with self.subTest(expected_q=hex(expected_q),
                               expected_r=hex(expected_r),
                               q=hex(q), r=hex(r)):
                 self.assertEqual(expected_q, q)
                 self.assertEqual(expected_r, r)
         if full:
-            for n in range(1 << width):
-                for d in range(1 << width):
-                    case(n, d)
+            for n in range(1 << shape.width):
+                for d in range(1 << shape.width):
+                    with self.subTest(n=hex(n), d=hex(d)):
+                        case(n, d)
         else:
             for i in range(100):
                 n = hash_256(f"cldivrem comb n {i}")
-                n = Const.normalize(n, unsigned(width))
+                n = Const.normalize(n, unsigned(shape.width))
                 d = hash_256(f"cldivrem comb d {i}")
-                d = Const.normalize(d, unsigned(width))
+                d = Const.normalize(d, unsigned(shape.width))
                 case(n, d)
 
-    def test_6(self):
-        self.tst(6, full=True)
+    def test_6_step_1(self):
+        self.tst(CLDivRemShape(width=6, steps_per_clock=1), full=True)
+
+    def test_6_step_2(self):
+        self.tst(CLDivRemShape(width=6, steps_per_clock=2), full=True)
+
+    def test_6_step_3(self):
+        self.tst(CLDivRemShape(width=6, steps_per_clock=3), full=True)
 
-    def test_64(self):
-        self.tst(64, full=False)
+    def test_6_step_4(self):
+        self.tst(CLDivRemShape(width=6, steps_per_clock=4), full=True)
+
+    def test_6_step_6(self):
+        self.tst(CLDivRemShape(width=6, steps_per_clock=6), full=True)
+
+    def test_6_step_10(self):
+        self.tst(CLDivRemShape(width=6, steps_per_clock=10), full=True)
+
+    def test_64_step_1(self):
+        self.tst(CLDivRemShape(width=64, steps_per_clock=1), full=False)
+
+    def test_64_step_2(self):
+        self.tst(CLDivRemShape(width=64, steps_per_clock=2), full=False)
+
+    def test_64_step_3(self):
+        self.tst(CLDivRemShape(width=64, steps_per_clock=3), full=False)
+
+    def test_64_step_4(self):
+        self.tst(CLDivRemShape(width=64, steps_per_clock=4), full=False)
+
+    def test_64_step_8(self):
+        self.tst(CLDivRemShape(width=64, steps_per_clock=8), full=False)
 
 
 class TestCLDivRemComb(FHDLTestCase):
@@ -62,7 +91,7 @@ class TestCLDivRemComb(FHDLTestCase):
         q_out = Signal(width)
         r_out = Signal(width)
         states: "list[CLDivRemState]" = []
-        for i in shape.step_range:
+        for i in range(shape.width + 1):
             states.append(CLDivRemState(shape, name=f"state_{i}"))
             if i == 0:
                 states[i].set_to_initial(m, n=n_in, d=d_in)
@@ -75,7 +104,7 @@ class TestCLDivRemComb(FHDLTestCase):
             assert isinstance(n, int)
             assert isinstance(d, int)
             if d != 0:
-                expected_q, expected_r = cldivrem_shifting(n, d, width)
+                expected_q, expected_r = cldivrem_shifting(n, d, shape)
             else:
                 expected_q = expected_r = 0
             with self.subTest(n=hex(n), d=hex(d),
@@ -84,12 +113,17 @@ class TestCLDivRemComb(FHDLTestCase):
                 yield n_in.eq(n)
                 yield d_in.eq(d)
                 yield Delay(1e-6)
-                for i in shape.step_range:
+                for i, state in enumerate(states):
                     with self.subTest(i=i):
-                        done = yield states[i].done
-                        step = yield states[i].step
-                        self.assertEqual(done, i >= shape.done_step)
-                        self.assertEqual(step, i)
+                        done = yield state.done
+                        substep = yield state.substep
+                        clock = yield state.clock
+                        self.assertEqual(done, i >= shape.width)
+                        if i % shape.steps_per_clock == 0:
+                            self.assertEqual(substep, 0)
+                        if i < shape.width:
+                            self.assertEqual(substep
+                                             + clock * shape.steps_per_clock, i)
                 q = yield q_out
                 r = yield r_out
                 with self.subTest(q=hex(q), r=hex(r)):
@@ -114,22 +148,45 @@ class TestCLDivRemComb(FHDLTestCase):
             sim.add_process(process)
             sim.run()
 
-    def test_4(self):
-        self.tst(CLDivRemShape(width=4), full=True)
+    def test_4_step_1(self):
+        self.tst(CLDivRemShape(width=4, steps_per_clock=1), full=True)
+
+    def test_4_step_2(self):
+        self.tst(CLDivRemShape(width=4, steps_per_clock=2), full=True)
+
+    def test_4_step_3(self):
+        self.tst(CLDivRemShape(width=4, steps_per_clock=3), full=True)
+
+    def test_4_step_4(self):
+        self.tst(CLDivRemShape(width=4, steps_per_clock=4), full=True)
 
-    def test_6(self):
-        self.tst(CLDivRemShape(width=6), full=True)
+    def test_6_step_1(self):
+        self.tst(CLDivRemShape(width=6, steps_per_clock=1), full=False)
 
-    def test_8(self):
-        self.tst(CLDivRemShape(width=8), full=False)
+    def test_6_step_2(self):
+        self.tst(CLDivRemShape(width=6, steps_per_clock=2), full=False)
+
+    def test_6_step_6(self):
+        self.tst(CLDivRemShape(width=6, steps_per_clock=6), full=False)
+
+    def test_6_step_8(self):
+        self.tst(CLDivRemShape(width=6, steps_per_clock=8), full=False)
+
+    def test_8_step_1(self):
+        self.tst(CLDivRemShape(width=8, steps_per_clock=1), full=False)
+
+    def test_8_step_4(self):
+        self.tst(CLDivRemShape(width=8, steps_per_clock=4), full=False)
+
+    def test_8_step_8(self):
+        self.tst(CLDivRemShape(width=8, steps_per_clock=8), full=False)
 
 
 class TestCLDivRemFSM(FHDLTestCase):
-    def tst(self, shape, full, steps_per_clock):
+    def tst(self, shape, full):
         assert isinstance(shape, CLDivRemShape)
-        assert isinstance(steps_per_clock, int) and steps_per_clock >= 1
         pspec = {}
-        dut = CLDivRemFSMStage(pspec, shape, steps_per_clock=steps_per_clock)
+        dut = CLDivRemFSMStage(pspec, shape)
         i_data: CLDivRemInputData = dut.p.i_data
         o_data: CLDivRemOutputData = dut.n.o_data
         self.assertEqual(i_data.n.shape(), unsigned(shape.width))
@@ -162,7 +219,7 @@ class TestCLDivRemFSM(FHDLTestCase):
                 yield i_data.n.eq(-1)
                 yield i_data.d.eq(-1)
                 yield dut.p.i_valid.eq(0)
-                for step in range(0, shape.done_step, steps_per_clock):
+                for step in range(shape.done_clock):
                     yield Delay(0.1e-6)
                     valid = yield dut.n.o_valid
                     ready = yield dut.p.o_ready
@@ -212,39 +269,25 @@ class TestCLDivRemFSM(FHDLTestCase):
             sim.run()
 
     def test_4_step_1(self):
-        self.tst(CLDivRemShape(width=4),
-                 full=True,
-                 steps_per_clock=1)
+        self.tst(CLDivRemShape(width=4, steps_per_clock=1), full=True)
 
     def test_4_step_2(self):
-        self.tst(CLDivRemShape(width=4),
-                 full=True,
-                 steps_per_clock=2)
+        self.tst(CLDivRemShape(width=4, steps_per_clock=2), full=True)
 
     def test_4_step_3(self):
-        self.tst(CLDivRemShape(width=4),
-                 full=True,
-                 steps_per_clock=3)
+        self.tst(CLDivRemShape(width=4, steps_per_clock=3), full=True)
 
     def test_4_step_4(self):
-        self.tst(CLDivRemShape(width=4),
-                 full=True,
-                 steps_per_clock=4)
+        self.tst(CLDivRemShape(width=4, steps_per_clock=4), full=True)
 
     def test_8_step_4(self):
-        self.tst(CLDivRemShape(width=8),
-                 full=False,
-                 steps_per_clock=4)
+        self.tst(CLDivRemShape(width=8, steps_per_clock=4), full=False)
 
     def test_64_step_4(self):
-        self.tst(CLDivRemShape(width=64),
-                 full=False,
-                 steps_per_clock=4)
+        self.tst(CLDivRemShape(width=64, steps_per_clock=4), full=False)
 
     def test_64_step_8(self):
-        self.tst(CLDivRemShape(width=64),
-                 full=False,
-                 steps_per_clock=8)
+        self.tst(CLDivRemShape(width=64, steps_per_clock=8), full=False)
 
 
 if __name__ == "__main__":