From 1e1b12812ebd3f052f2ddc8a6eaf8e84e533013f Mon Sep 17 00:00:00 2001 From: Luke Kenneth Casson Leighton Date: Sat, 17 Aug 2019 13:33:22 +0100 Subject: [PATCH] split out "Parts" to separate module --- src/ieee754/part_mul_add/multiply.py | 245 +++++++++--------- .../part_mul_add/test/test_multiply.py | 28 -- 2 files changed, 116 insertions(+), 157 deletions(-) diff --git a/src/ieee754/part_mul_add/multiply.py b/src/ieee754/part_mul_add/multiply.py index 33232e77..189656db 100644 --- a/src/ieee754/part_mul_add/multiply.py +++ b/src/ieee754/part_mul_add/multiply.py @@ -6,7 +6,8 @@ from nmigen import Signal, Module, Value, Elaboratable, Cat, C, Mux, Repl from nmigen.hdl.ast import Assign from abc import ABCMeta, abstractmethod from nmigen.cli import main - +from functools import reduce +from operator import or_ class PartitionPoints(dict): """Partition points and corresponding ``Value``s. @@ -424,6 +425,90 @@ class ProductTerm(Elaboratable): return m +class Part(Elaboratable): + def __init__(self, width, n_parts, n_levels, pbwid): + + # inputs + self.a = Signal(64) + self.b = Signal(64) + self._a_signed = [Signal(name=f"_a_signed_{i}") for i in range(8)] + self._b_signed = [Signal(name=f"_b_signed_{i}") for i in range(8)] + self.pbs = Signal(pbwid, reset_less=True) + + # outputs + self.parts = [Signal(name=f"part_{i}") for i in range(n_parts)] + self.delayed_parts = [ + [Signal(name=f"delayed_part_8_{delay}_{i}") + for i in range(n_parts)] + for delay in range(n_levels)] + + self.not_a_term = Signal(width) + self.neg_lsb_a_term = Signal(width) + self.not_b_term = Signal(width) + self.neg_lsb_b_term = Signal(width) + + def elaborate(self, platform): + m = Module() + + pbs, parts, delayed_parts = self.pbs, self.parts, self.delayed_parts + byte_count = 8 // len(parts) + for i in range(len(parts)): + pbl = [] + pbl.append(~pbs[i * byte_count - 1]) + for j in range(i * byte_count, (i + 1) * byte_count - 1): + pbl.append(pbs[j]) + pbl.append(~pbs[(i + 1) * byte_count - 1]) + value = Signal(len(pbl), reset_less=True) + m.d.comb += value.eq(Cat(*pbl)) + m.d.comb += parts[i].eq(~(value).bool()) + m.d.comb += delayed_parts[0][i].eq(parts[i]) + m.d.sync += [delayed_parts[j + 1][i].eq(delayed_parts[j][i]) + for j in range(len(delayed_parts)-1)] + + not_a_term, neg_lsb_a_term, not_b_term, neg_lsb_b_term = \ + self.not_a_term, self.neg_lsb_a_term, \ + self.not_b_term, self.neg_lsb_b_term + + byte_width = 8 // len(parts) + bit_width = 8 * byte_width + nat, nbt, nla, nlb = [], [], [], [] + for i in range(len(parts)): + be = parts[i] & self.a[(i + 1) * bit_width - 1] \ + & self._a_signed[i * byte_width] + ae = parts[i] & self.b[(i + 1) * bit_width - 1] \ + & self._b_signed[i * byte_width] + a_enabled = Signal(name="a_en_%d" % i, reset_less=True) + b_enabled = Signal(name="b_en_%d" % i, reset_less=True) + m.d.comb += a_enabled.eq(ae) + m.d.comb += b_enabled.eq(be) + + # for 8-bit values: form a * 0xFF00 by using -a * 0x100, the + # negation operation is split into a bitwise not and a +1. + # likewise for 16, 32, and 64-bit values. + nat.append(Mux(a_enabled, + Cat(Repl(0, bit_width), + ~self.a.bit_select(bit_width * i, bit_width)), + 0)) + + nla.append(Cat(Repl(0, bit_width), a_enabled, + Repl(0, bit_width-1))) + + nbt.append(Mux(b_enabled, + Cat(Repl(0, bit_width), + ~self.b.bit_select(bit_width * i, bit_width)), + 0)) + + nlb.append(Cat(Repl(0, bit_width), b_enabled, + Repl(0, bit_width-1))) + + m.d.comb += [not_a_term.eq(Cat(*nat)), + not_b_term.eq(Cat(*nbt)), + neg_lsb_a_term.eq(Cat(*nla)), + neg_lsb_b_term.eq(Cat(*nlb)), + ] + + return m + class Mul8_16_32_64(Elaboratable): """Signed/Unsigned 8/16/32/64-bit partitioned integer multiplier. @@ -467,47 +552,12 @@ class Mul8_16_32_64(Elaboratable): [Signal(2, name=f"_delayed_part_ops_{delay}_{i}") for i in range(8)] for delay in range(1 + len(self.register_levels))] - self._part_8 = [Signal(name=f"_part_8_{i}") for i in range(8)] - self._part_16 = [Signal(name=f"_part_16_{i}") for i in range(4)] - self._part_32 = [Signal(name=f"_part_32_{i}") for i in range(2)] - self._part_64 = [Signal(name=f"_part_64")] - self._delayed_part_8 = [ - [Signal(name=f"_delayed_part_8_{delay}_{i}") - for i in range(8)] - for delay in range(1 + len(self.register_levels))] - self._delayed_part_16 = [ - [Signal(name=f"_delayed_part_16_{delay}_{i}") - for i in range(4)] - for delay in range(1 + len(self.register_levels))] - self._delayed_part_32 = [ - [Signal(name=f"_delayed_part_32_{delay}_{i}") - for i in range(2)] - for delay in range(1 + len(self.register_levels))] - self._delayed_part_64 = [ - [Signal(name=f"_delayed_part_64_{delay}")] - for delay in range(1 + len(self.register_levels))] self._output_64 = Signal(64) self._output_32 = Signal(64) self._output_16 = Signal(64) self._output_8 = Signal(64) self._a_signed = [Signal(name=f"_a_signed_{i}") for i in range(8)] self._b_signed = [Signal(name=f"_b_signed_{i}") for i in range(8)] - self._not_a_term_8 = Signal(128) - self._neg_lsb_a_term_8 = Signal(128) - self._not_b_term_8 = Signal(128) - self._neg_lsb_b_term_8 = Signal(128) - self._not_a_term_16 = Signal(128) - self._neg_lsb_a_term_16 = Signal(128) - self._not_b_term_16 = Signal(128) - self._neg_lsb_b_term_16 = Signal(128) - self._not_a_term_32 = Signal(128) - self._neg_lsb_a_term_32 = Signal(128) - self._not_b_term_32 = Signal(128) - self._neg_lsb_b_term_32 = Signal(128) - self._not_a_term_64 = Signal(128) - self._neg_lsb_a_term_64 = Signal(128) - self._not_b_term_64 = Signal(128) - self._neg_lsb_b_term_64 = Signal(128) def _part_byte(self, index): if index == -1 or index == 7: @@ -533,23 +583,24 @@ class Mul8_16_32_64(Elaboratable): .eq(self._delayed_part_ops[j][i]) for j in range(len(self.register_levels))] - for parts, delayed_parts in [(self._part_64, self._delayed_part_64), - (self._part_32, self._delayed_part_32), - (self._part_16, self._delayed_part_16), - (self._part_8, self._delayed_part_8)]: - byte_count = 8 // len(parts) - for i in range(len(parts)): - pbl = [] - pbl.append(~pbs[i * byte_count - 1]) - for j in range(i * byte_count, (i + 1) * byte_count - 1): - pbl.append(pbs[j]) - pbl.append(~pbs[(i + 1) * byte_count - 1]) - value = Signal(len(pbl), reset_less=True) - m.d.comb += value.eq(Cat(*pbl)) - m.d.comb += parts[i].eq(~(value).bool()) - m.d.comb += delayed_parts[0][i].eq(parts[i]) - m.d.sync += [delayed_parts[j + 1][i].eq(delayed_parts[j][i]) - for j in range(len(self.register_levels))] + n_levels = len(self.register_levels)+1 + m.submodules.part_8 = part_8 = Part(128, 8, n_levels, 8) + m.submodules.part_16 = part_16 = Part(128, 4, n_levels, 8) + m.submodules.part_32 = part_32 = Part(128, 2, n_levels, 8) + m.submodules.part_64 = part_64 = Part(128, 1, n_levels, 8) + nat_l, nbt_l, nla_l, nlb_l = [], [], [], [] + for mod in [part_8, part_16, part_32, part_64]: + m.d.comb += mod.a.eq(self.a) + m.d.comb += mod.b.eq(self.b) + for i in range(len(self._a_signed)): + m.d.comb += mod._a_signed[i].eq(self._a_signed[i]) + for i in range(len(self._b_signed)): + m.d.comb += mod._b_signed[i].eq(self._b_signed[i]) + m.d.comb += mod.pbs.eq(pbs) + nat_l.append(mod.not_a_term) + nbt_l.append(mod.not_b_term) + nla_l.append(mod.neg_lsb_a_term) + nlb_l.append(mod.neg_lsb_b_term) terms = [] @@ -573,87 +624,23 @@ class Mul8_16_32_64(Elaboratable): # it's fine to bitwise-or these together since they are never enabled # at the same time + nat_l = reduce(or_, nat_l) + nbt_l = reduce(or_, nbt_l) + nla_l = reduce(or_, nla_l) + nlb_l = reduce(or_, nlb_l) m.submodules.nat = nat = Term(128, 128) m.submodules.nla = nla = Term(128, 128) m.submodules.nbt = nbt = Term(128, 128) m.submodules.nlb = nlb = Term(128, 128) - m.d.comb += nat.ti.eq(self._not_a_term_8 | self._not_a_term_16 - | self._not_a_term_32 | self._not_a_term_64) - m.d.comb += nbt.ti.eq(self._not_b_term_8 | self._not_b_term_16 - | self._not_b_term_32 | self._not_b_term_64) - m.d.comb += nla.ti.eq(self._neg_lsb_a_term_8 | self._neg_lsb_a_term_16 - | self._neg_lsb_a_term_32 | self._neg_lsb_a_term_64) - m.d.comb += nlb.ti.eq(self._neg_lsb_b_term_8 | self._neg_lsb_b_term_16 - | self._neg_lsb_b_term_32 | self._neg_lsb_b_term_64) + m.d.comb += nat.ti.eq(nat_l) + m.d.comb += nbt.ti.eq(nbt_l) + m.d.comb += nla.ti.eq(nla_l) + m.d.comb += nlb.ti.eq(nlb_l) terms.append(nat.term) terms.append(nla.term) terms.append(nbt.term) terms.append(nlb.term) - for not_a_term, \ - neg_lsb_a_term, \ - not_b_term, \ - neg_lsb_b_term, \ - parts in [ - (self._not_a_term_8, - self._neg_lsb_a_term_8, - self._not_b_term_8, - self._neg_lsb_b_term_8, - self._part_8), - (self._not_a_term_16, - self._neg_lsb_a_term_16, - self._not_b_term_16, - self._neg_lsb_b_term_16, - self._part_16), - (self._not_a_term_32, - self._neg_lsb_a_term_32, - self._not_b_term_32, - self._neg_lsb_b_term_32, - self._part_32), - (self._not_a_term_64, - self._neg_lsb_a_term_64, - self._not_b_term_64, - self._neg_lsb_b_term_64, - self._part_64), - ]: - byte_width = 8 // len(parts) - bit_width = 8 * byte_width - nat, nbt, nla, nlb = [], [], [], [] - for i in range(len(parts)): - be = parts[i] & self.a[(i + 1) * bit_width - 1] \ - & self._a_signed[i * byte_width] - ae = parts[i] & self.b[(i + 1) * bit_width - 1] \ - & self._b_signed[i * byte_width] - a_enabled = Signal(name="a_en_%d" % i, reset_less=True) - b_enabled = Signal(name="b_en_%d" % i, reset_less=True) - m.d.comb += a_enabled.eq(ae) - m.d.comb += b_enabled.eq(be) - - # for 8-bit values: form a * 0xFF00 by using -a * 0x100, the - # negation operation is split into a bitwise not and a +1. - # likewise for 16, 32, and 64-bit values. - nat.append(Mux(a_enabled, - Cat(Repl(0, bit_width), - ~self.a.bit_select(bit_width * i, bit_width)), - 0)) - - nla.append(Cat(Repl(0, bit_width), a_enabled, - Repl(0, bit_width-1))) - - nbt.append(Mux(b_enabled, - Cat(Repl(0, bit_width), - ~self.b.bit_select(bit_width * i, bit_width)), - 0)) - - nlb.append(Cat(Repl(0, bit_width), b_enabled, - Repl(0, bit_width-1))) - - m.d.comb += [not_a_term.eq(Cat(*nat)), - not_b_term.eq(Cat(*nbt)), - neg_lsb_a_term.eq(Cat(*nla)), - neg_lsb_b_term.eq(Cat(*nlb)), - ] - expanded_part_pts = PartitionPoints() for i, v in self.part_pts.items(): signal = Signal(name=f"expanded_part_pts_{i*2}", reset_less=True) @@ -709,12 +696,12 @@ class Mul8_16_32_64(Elaboratable): for i in range(8): op = Signal(8, reset_less=True, name="op%d" % i) m.d.comb += op.eq( - Mux(self._delayed_part_8[-1][i] - | self._delayed_part_16[-1][i // 2], - Mux(self._delayed_part_8[-1][i], + Mux(part_8.delayed_parts[-1][i] + | part_16.delayed_parts[-1][i // 2], + Mux(part_8.delayed_parts[-1][i], self._output_8.bit_select(i * 8, 8), self._output_16.bit_select(i * 8, 8)), - Mux(self._delayed_part_32[-1][i // 4], + Mux(part_32.delayed_parts[-1][i // 4], self._output_32.bit_select(i * 8, 8), self._output_64.bit_select(i * 8, 8)))) ol.append(op) diff --git a/src/ieee754/part_mul_add/test/test_multiply.py b/src/ieee754/part_mul_add/test/test_multiply.py index ef7f5cd7..d96d45c1 100644 --- a/src/ieee754/part_mul_add/test/test_multiply.py +++ b/src/ieee754/part_mul_add/test/test_multiply.py @@ -527,40 +527,12 @@ class TestMul8_16_32_64(unittest.TestCase): ports.extend(module.part_pts.values()) for signals in module._delayed_part_ops: ports.extend(signals) - ports.extend(module._part_8) - ports.extend(module._part_16) - ports.extend(module._part_32) - ports.extend(module._part_64) - for signals in module._delayed_part_8: - ports.extend(signals) - for signals in module._delayed_part_16: - ports.extend(signals) - for signals in module._delayed_part_32: - ports.extend(signals) - for signals in module._delayed_part_64: - ports.extend(signals) ports += [module._output_64, module._output_32, module._output_16, module._output_8] ports.extend(module._a_signed) ports.extend(module._b_signed) - ports += [module._not_a_term_8, - module._neg_lsb_a_term_8, - module._not_b_term_8, - module._neg_lsb_b_term_8, - module._not_a_term_16, - module._neg_lsb_a_term_16, - module._not_b_term_16, - module._neg_lsb_b_term_16, - module._not_a_term_32, - module._neg_lsb_a_term_32, - module._not_b_term_32, - module._neg_lsb_b_term_32, - module._not_a_term_64, - module._neg_lsb_a_term_64, - module._not_b_term_64, - module._neg_lsb_b_term_64] with create_simulator(module, ports, file_name) as sim: def process(gen_or_check: GenOrCheck) -> AsyncProcessGenerator: for a_signed in False, True: -- 2.30.2