From 06af1a3a6433bce62c4498458c78b76eef48ebcf Mon Sep 17 00:00:00 2001 From: Jacob Lifshay Date: Thu, 5 May 2022 00:33:05 -0700 Subject: [PATCH] switch to better CLDivRem algorithm --- src/nmigen_gf/hdl/cldivrem.py | 156 ++++++++++++++---------- src/nmigen_gf/hdl/test/test_cldivrem.py | 75 ++++++++---- 2 files changed, 139 insertions(+), 92 deletions(-) diff --git a/src/nmigen_gf/hdl/cldivrem.py b/src/nmigen_gf/hdl/cldivrem.py index 02ce034..73bdd1e 100644 --- a/src/nmigen_gf/hdl/cldivrem.py +++ b/src/nmigen_gf/hdl/cldivrem.py @@ -121,26 +121,26 @@ def cldivrem_shifting(n, d, width): assert isinstance(d, int) and 0 <= d < 1 << width assert d != 0, "TODO: decide what happens on division by zero" - shift_wid = (width - 1).bit_length() + shape = CLDivRemShape(width) # `clz(d, width)`, but maxes out at `width - 1` instead of `width` in - # order to both fit in `shift_wid` bits and not shift by more than needed. + # order to both fit in `shape.shift_width` bits and to not shift by more + # than needed. shift = clz(d >> 1, width - 1) - assert shift < 1 << shift_wid, f"shift overflows a {shift_wid}-bit signal" + assert 0 <= shift < 1 << shape.shift_width, "shift overflow" d <<= shift - assert d < 1 << width, f"d overflows a {width}-bit signal" - n <<= shift - assert n < 1 << (width * 2), f"n overflows a {width * 2}-bit signal" - r = n + assert 0 <= d < 1 << shape.d_width, "d overflow" + r = n << shift + assert 0 <= r < 1 << shape.r_width, "r overflow" q = 0 - for _ in range(width): + for step in range(width): q <<= 1 r <<= 1 if r >> (width * 2 - 1) != 0: r ^= d << width q |= 1 - assert q < 1 << width, f"q overflows a {width}-bit signal" - assert r < 1 << (width * 2), f"r overflows a {width * 2}-bit signal" + assert 0 <= q < 1 << shape.q_width, "q overflow" + assert 0 <= r < 1 << shape.r_width, "r overflow" r >>= width r >>= shift return q, r @@ -149,44 +149,73 @@ def cldivrem_shifting(n, d, width): @dataclass(frozen=True, unsafe_hash=True) class CLDivRemShape: width: int - n_width: int def __post_init__(self): - assert self.n_width >= self.width > 0 + assert isinstance(self.width, int) and self.width >= 1, "invalid width" @property def done_step(self): + """the step number when iteration is finished + -- the largest `CLDivRemState.step` will get + """ return self.width @property def step_range(self): + """the range that `CLDivRemState.step` will fall in. + + returns: range + """ return range(self.done_step + 1) + @property + def d_width(self): + """bit-width of the internal signal `CLDivRemState.d`""" + return self.width + + @property + def r_width(self): + """bit-width of the internal signal `CLDivRemState.r`""" + return self.width * 2 + + @property + def q_width(self): + """bit-width of the internal signal `CLDivRemState.q`""" + return self.width + + @property + def shift_width(self): + """bit-width of the internal signal `CLDivRemState.shift`""" + return (self.width - 1).bit_length() + @dataclass(frozen=True, eq=False) class CLDivRemState: shape: CLDivRemShape name: str + step: Signal = field(init=False) d: Signal = field(init=False) r: Signal = field(init=False) q: Signal = field(init=False) - step: Signal = field(init=False) + shift: Signal = field(init=False) def __init__(self, shape, *, name=None, src_loc_at=0): assert isinstance(shape, CLDivRemShape) if name is None: name = Signal(src_loc_at=1 + src_loc_at).name assert isinstance(name, str) - d = Signal(2 * shape.width, name=f"{name}_d") - r = Signal(shape.n_width, name=f"{name}_r") - q = Signal(shape.width, name=f"{name}_q") - step = Signal(shape.width, name=f"{name}_step") + step = Signal(shape.step_range, name=f"{name}_step") + d = Signal(shape.d_width, name=f"{name}_d") + r = Signal(shape.r_width, name=f"{name}_r") + q = Signal(shape.q_width, name=f"{name}_q") + shift = Signal(shape.shift_width, name=f"{name}_shift") object.__setattr__(self, "shape", shape) object.__setattr__(self, "name", name) + object.__setattr__(self, "step", step) object.__setattr__(self, "d", d) object.__setattr__(self, "r", r) object.__setattr__(self, "q", q) - object.__setattr__(self, "step", step) + object.__setattr__(self, "shift", shift) def eq(self, rhs): assert isinstance(rhs, CLDivRemState) @@ -212,11 +241,24 @@ class CLDivRemState: assert isinstance(steps, int) and steps >= 0 return self.step >= max(0, self.shape.done_step - steps) + def get_output(self): + return self.q, (self.r >> self.shape.width) >> self.shift + def set_to_initial(self, m, n, d): assert isinstance(m, Module) + n = Value.cast(n) # convert to Value + d = Value.cast(d) # convert to Value + clz_mod = CLZ(self.shape.width - 1) + # can't name submodule since it would conflict if this function is + # called multiple times in a Module + m.submodules += clz_mod + assert clz_mod.lz.width == self.shape.shift_width, \ + "internal inconsistency -- mismatched shift signal width" m.d.comb += [ - self.d.eq(Value.cast(d) << self.shape.width), - self.r.eq(n), + clz_mod.sig_in.eq(d >> 1), + self.shift.eq(clz_mod.lz), + self.d.eq(d << self.shift), + self.r.eq(n << self.shift), self.q.eq(0), self.step.eq(0), ] @@ -226,31 +268,27 @@ class CLDivRemState: assert isinstance(state_in, CLDivRemState) assert state_in.shape == self.shape assert self is not state_in, "a.set_to_next(m, a) is not allowed" - - equal_leading_zero_count = EqualLeadingZeroCount(self.shape.n_width) - # can't name submodule since it would conflict if this function is - # called multiple times in a Module - m.submodules += equal_leading_zero_count + width = self.shape.width with m.If(state_in.done): m.d.comb += self.eq(state_in) with m.Else(): m.d.comb += [ self.step.eq(state_in.step + 1), - self.d.eq(state_in.d >> 1), - equal_leading_zero_count.a.eq(self.d), - equal_leading_zero_count.b.eq(state_in.r), + self.d.eq(state_in.d), + self.shift.eq(state_in.shift), ] - d_top = self.d[self.shape.n_width:] - with m.If(equal_leading_zero_count.out & (d_top == 0)): + q = state_in.q << 1 + r = state_in.r << 1 + with m.If(r[width * 2 - 1]): m.d.comb += [ - self.r.eq(state_in.r ^ self.d), - self.q.eq((state_in.q << 1) | 1), + self.q.eq(q | 1), + self.r.eq(r ^ (state_in.d << width)), ] with m.Else(): m.d.comb += [ - self.r.eq(state_in.r), - self.q.eq(state_in.q << 1), + self.q.eq(q), + self.r.eq(r), ] @@ -258,7 +296,7 @@ class CLDivRemInputData: def __init__(self, shape): assert isinstance(shape, CLDivRemShape) self.shape = shape - self.n = Signal(shape.n_width) + self.n = Signal(shape.width) self.d = Signal(shape.width) def __iter__(self): @@ -302,31 +340,15 @@ class CLDivRemFSMStage(ControlBase): the shape steps_per_clock: int number of steps that should be taken per clock cycle - in_valid: Signal() - input. true when the data inputs (`n` and `d`) are valid. - data transfer in occurs when `in_valid & in_ready`. - in_ready: Signal() - output. true when this FSM is ready to accept input. - data transfer in occurs when `in_valid & in_ready`. - n: Signal(shape.n_width) - numerator in, the value must be small enough that `q` and `r` don't - overflow. having `n_width == width` is sufficient. - d: Signal(shape.width) - denominator in, must be non-zero. - q: Signal(shape.width) - quotient out. - r: Signal(shape.width) - remainder out. - out_valid: Signal() - output. true when the data outputs (`q` and `r`) are valid - (or are junk because the inputs were out of range). - data transfer out occurs when `out_valid & out_ready`. - out_ready: Signal() - input. true when the output can be read. - data transfer out occurs when `out_valid & out_ready`. + pspec: + pipe-spec + empty: Signal() + true if nothing is stored in `self.saved_state` + saved_state: CLDivRemState() + the saved state that is currently being worked on. """ - def __init__(self, pspec, shape, *, steps_per_clock=4): + def __init__(self, pspec, shape, *, steps_per_clock=8): assert isinstance(shape, CLDivRemShape) assert isinstance(steps_per_clock, int) and steps_per_clock >= 1 self.shape = shape @@ -352,9 +374,7 @@ class CLDivRemFSMStage(ControlBase): # TODO: handle cancellation - state_will_be_done = self.saved_state.will_be_done_after( - self.steps_per_clock) - m.d.comb += self.n.o_valid.eq(~self.empty & state_will_be_done) + m.d.comb += self.n.o_valid.eq(~self.empty & self.saved_state.done) m.d.comb += self.p.o_ready.eq(self.empty) def make_nc(i): @@ -362,19 +382,21 @@ class CLDivRemFSMStage(ControlBase): next_chain = [make_nc(i) for i in range(self.steps_per_clock + 1)] for i in range(self.steps_per_clock): next_chain[i + 1].set_to_next(m, next_chain[i]) - m.d.sync += self.saved_state.eq(next_chain[-1]) - m.d.comb += o_data.q.eq(next_chain[-1].q) - m.d.comb += o_data.r.eq(next_chain[-1].r) + m.d.comb += next_chain[0].eq(self.saved_state) + out_q, out_r = self.saved_state.get_output() + m.d.comb += o_data.q.eq(out_q) + m.d.comb += o_data.r.eq(out_r) + initial_state = CLDivRemState(self.shape) + initial_state.set_to_initial(m, n=i_data.n, d=i_data.d) with m.If(self.empty): - next_chain[0].set_to_initial(m, n=i_data.n, d=i_data.d) + m.d.sync += self.saved_state.eq(initial_state) with m.If(self.p.i_valid): m.d.sync += self.empty.eq(0) with m.Else(): - m.d.comb += next_chain[0].eq(self.saved_state) + m.d.sync += self.saved_state.eq(next_chain[-1]) with m.If(self.n.i_ready & self.n.o_valid): m.d.sync += self.empty.eq(1) - return m def __iter__(self): diff --git a/src/nmigen_gf/hdl/test/test_cldivrem.py b/src/nmigen_gf/hdl/test/test_cldivrem.py index de7cf34..2782070 100644 --- a/src/nmigen_gf/hdl/test/test_cldivrem.py +++ b/src/nmigen_gf/hdl/test/test_cldivrem.py @@ -144,9 +144,12 @@ class TestCLDivRemShifting(FHDLTestCase): class TestCLDivRemComb(FHDLTestCase): def tst(self, shape, full): assert isinstance(shape, CLDivRemShape) + width = shape.width m = Module() - n_in = Signal(shape.n_width) - d_in = Signal(shape.width) + n_in = Signal(width) + d_in = Signal(width) + q_out = Signal(width) + r_out = Signal(width) states: "list[CLDivRemState]" = [] for i in shape.step_range: states.append(CLDivRemState(shape, name=f"state_{i}")) @@ -154,13 +157,14 @@ class TestCLDivRemComb(FHDLTestCase): states[i].set_to_initial(m, n=n_in, d=d_in) else: states[i].set_to_next(m, states[i - 1]) + q, r = states[-1].get_output() + m.d.comb += [q_out.eq(q), r_out.eq(r)] def case(n, d): assert isinstance(n, int) assert isinstance(d, int) - max_width = max(shape.width, shape.n_width) if d != 0: - expected_q, expected_r = cldivrem(n, d, width=max_width) + expected_q, expected_r = cldivrem_shifting(n, d, width) else: expected_q = expected_r = 0 with self.subTest(n=hex(n), d=hex(d), @@ -175,35 +179,38 @@ class TestCLDivRemComb(FHDLTestCase): step = yield states[i].step self.assertEqual(done, i >= shape.done_step) self.assertEqual(step, i) - q = yield states[-1].q - r = yield states[-1].r + q = yield q_out + r = yield r_out with self.subTest(q=hex(q), r=hex(r)): # only check results when inputs are valid - if d != 0 and (expected_q >> shape.width) == 0: + if d != 0: self.assertEqual(q, expected_q) self.assertEqual(r, expected_r) def process(): if full: - for n in range(1 << shape.n_width): - for d in range(1 << shape.width): + for n in range(1 << width): + for d in range(1 << width): yield from case(n, d) else: for i in range(100): n = hash_256(f"cldivrem comb n {i}") - n = Const.normalize(n, unsigned(shape.n_width)) + n = Const.normalize(n, unsigned(width)) d = hash_256(f"cldivrem comb d {i}") - d = Const.normalize(d, unsigned(shape.width)) + d = Const.normalize(d, unsigned(width)) yield from case(n, d) - with do_sim(self, m, [n_in, d_in, states[-1].q, states[-1].r]) as sim: + with do_sim(self, m, [n_in, d_in, q_out, r_out]) as sim: sim.add_process(process) sim.run() def test_4(self): - self.tst(CLDivRemShape(width=4, n_width=4), full=True) + self.tst(CLDivRemShape(width=4), full=True) + + def test_6(self): + self.tst(CLDivRemShape(width=6), full=True) - def test_8_by_4(self): - self.tst(CLDivRemShape(width=4, n_width=8), full=True) + def test_8(self): + self.tst(CLDivRemShape(width=8), full=False) class TestCLDivRemFSM(FHDLTestCase): @@ -214,7 +221,7 @@ class TestCLDivRemFSM(FHDLTestCase): dut = CLDivRemFSMStage(pspec, shape, steps_per_clock=steps_per_clock) i_data: CLDivRemInputData = dut.p.i_data o_data: CLDivRemOutputData = dut.n.o_data - self.assertEqual(i_data.n.shape(), unsigned(shape.n_width)) + self.assertEqual(i_data.n.shape(), unsigned(shape.width)) self.assertEqual(i_data.d.shape(), unsigned(shape.width)) self.assertEqual(o_data.q.shape(), unsigned(shape.width)) self.assertEqual(o_data.r.shape(), unsigned(shape.width)) @@ -222,9 +229,8 @@ class TestCLDivRemFSM(FHDLTestCase): def case(n, d): assert isinstance(n, int) assert isinstance(d, int) - max_width = max(shape.width, shape.n_width) if d != 0: - expected_q, expected_r = cldivrem(n, d, width=max_width) + expected_q, expected_r = cldivrem(n, d, width=shape.width) else: expected_q = expected_r = 0 with self.subTest(n=hex(n), d=hex(d), @@ -245,8 +251,7 @@ class TestCLDivRemFSM(FHDLTestCase): yield i_data.n.eq(-1) yield i_data.d.eq(-1) yield dut.p.i_valid.eq(0) - for i in range(steps_per_clock * 2, shape.done_step, - steps_per_clock): + for step in range(0, shape.done_step, steps_per_clock): yield Delay(0.1e-6) valid = yield dut.n.o_valid ready = yield dut.p.o_ready @@ -279,13 +284,13 @@ class TestCLDivRemFSM(FHDLTestCase): def process(): if full: - for n in range(1 << shape.n_width): + for n in range(1 << shape.width): for d in range(1 << shape.width): yield from case(n, d) else: for i in range(100): n = hash_256(f"cldivrem fsm n {i}") - n = Const.normalize(n, unsigned(shape.n_width)) + n = Const.normalize(n, unsigned(shape.width)) d = hash_256(f"cldivrem fsm d {i}") d = Const.normalize(d, unsigned(shape.width)) yield from case(n, d) @@ -296,20 +301,40 @@ class TestCLDivRemFSM(FHDLTestCase): sim.run() def test_4_step_1(self): - self.tst(CLDivRemShape(width=4, n_width=4), + self.tst(CLDivRemShape(width=4), full=True, steps_per_clock=1) def test_4_step_2(self): - self.tst(CLDivRemShape(width=4, n_width=4), + self.tst(CLDivRemShape(width=4), full=True, steps_per_clock=2) def test_4_step_3(self): - self.tst(CLDivRemShape(width=4, n_width=4), + self.tst(CLDivRemShape(width=4), full=True, steps_per_clock=3) + def test_4_step_4(self): + self.tst(CLDivRemShape(width=4), + full=True, + steps_per_clock=4) + + def test_8_step_4(self): + self.tst(CLDivRemShape(width=8), + full=False, + steps_per_clock=4) + + def test_64_step_4(self): + self.tst(CLDivRemShape(width=64), + full=False, + steps_per_clock=4) + + def test_64_step_8(self): + self.tst(CLDivRemShape(width=64), + full=False, + steps_per_clock=8) + if __name__ == "__main__": unittest.main() -- 2.30.2