X-Git-Url: https://git.libre-soc.org/?a=blobdiff_plain;f=src%2Fieee754%2Fpart_mul_add%2Fmultiply.py;h=e0531b420e2cc885b2e10c42acec581c8e852c0b;hb=9d3ead0d8c46b1d4929a3e03f61597f3a9c56e00;hp=dbb09d03808c00a1ee6095d600bddf346c635b34;hpb=9662e1b65a20183d1f9a0c9dbb9acd98b63c6cb8;p=ieee754fpu.git diff --git a/src/ieee754/part_mul_add/multiply.py b/src/ieee754/part_mul_add/multiply.py index dbb09d03..e0531b42 100644 --- a/src/ieee754/part_mul_add/multiply.py +++ b/src/ieee754/part_mul_add/multiply.py @@ -346,15 +346,13 @@ class FinalAdd(Elaboratable): """ Final stage of add reduce """ - def __init__(self, n_inputs, output_width, n_parts, register_levels, - partition_points): + def __init__(self, n_inputs, output_width, n_parts, partition_points): self.i = AddReduceData(partition_points, n_inputs, output_width, n_parts) self.o = FinalReduceData(partition_points, output_width, n_parts) self.output_width = output_width self.n_inputs = n_inputs self.n_parts = n_parts - self.register_levels = list(register_levels) self.partition_points = PartitionPoints(partition_points) if not self.partition_points.fits_in_width(output_width): raise ValueError("partition_points doesn't fit in output_width") @@ -400,14 +398,11 @@ class AddReduceSingle(Elaboratable): supported, except for by ``Signal.eq``. """ - def __init__(self, n_inputs, output_width, n_parts, register_levels, - partition_points): + def __init__(self, n_inputs, output_width, n_parts, partition_points): """Create an ``AddReduce``. :param inputs: input ``Signal``s to be summed. :param output_width: bit-width of ``output``. - :param register_levels: List of nesting levels that should have - pipeline registers. :param partition_points: the input partition points. """ self.n_inputs = n_inputs @@ -415,17 +410,10 @@ class AddReduceSingle(Elaboratable): self.output_width = output_width self.i = AddReduceData(partition_points, n_inputs, output_width, n_parts) - self.register_levels = list(register_levels) self.partition_points = PartitionPoints(partition_points) if not self.partition_points.fits_in_width(output_width): raise ValueError("partition_points doesn't fit in output_width") - max_level = AddReduceSingle.get_max_level(n_inputs) - for level in self.register_levels: - if level > max_level: - raise ValueError( - "not enough adder levels for specified register levels") - self.groups = AddReduceSingle.full_adder_groups(n_inputs) n_terms = AddReduceSingle.calc_n_inputs(n_inputs, self.groups) self.o = AddReduceData(partition_points, n_terms, output_width, n_parts) @@ -574,7 +562,6 @@ class AddReduce(Elaboratable): """creates reduction levels""" mods = [] - next_levels = self.register_levels partition_points = self.partition_points part_ops = self.part_ops n_parts = len(part_ops) @@ -585,16 +572,15 @@ class AddReduce(Elaboratable): if len(groups) == 0: break next_level = AddReduceSingle(ilen, self.output_width, n_parts, - next_levels, partition_points) + partition_points) mods.append(next_level) - next_levels = list(AddReduce.next_register_levels(next_levels)) partition_points = next_level.i.part_pts inputs = next_level.o.terms ilen = len(inputs) part_ops = next_level.i.part_ops next_level = FinalAdd(ilen, self.output_width, n_parts, - next_levels, partition_points) + partition_points) mods.append(next_level) self.levels = mods @@ -616,7 +602,7 @@ class AddReduce(Elaboratable): m.d.comb += i.eq_from(partition_points, inputs, part_ops) for idx in range(len(self.levels)): mcur = self.levels[idx] - if 0 in mcur.register_levels: + if idx in self.register_levels: m.d.sync += mcur.i.eq(i) else: m.d.comb += mcur.i.eq(i) @@ -959,11 +945,11 @@ class FinalOut(Elaboratable): def elaborate(self, platform): m = Module() - pps = self.part_pts - m.submodules.p_8 = p_8 = Parts(8, pps, 8) - m.submodules.p_16 = p_16 = Parts(8, pps, 4) - m.submodules.p_32 = p_32 = Parts(8, pps, 2) - m.submodules.p_64 = p_64 = Parts(8, pps, 1) + part_pts = self.part_pts + m.submodules.p_8 = p_8 = Parts(8, part_pts, 8) + m.submodules.p_16 = p_16 = Parts(8, part_pts, 4) + m.submodules.p_32 = p_32 = Parts(8, part_pts, 2) + m.submodules.p_64 = p_64 = Parts(8, part_pts, 1) out_part_pts = self.i.part_pts @@ -1079,12 +1065,14 @@ class IntermediateData: rhs.intermediate_output, rhs.part_ops) -class AllTermsData: +class InputData: - def __init__(self, partition_points): + def __init__(self): self.a = Signal(64) self.b = Signal(64) - self.part_pts = partition_points.like() + self.part_pts = PartitionPoints() + for i in range(8, 64, 8): + self.part_pts[i] = Signal(name=f"part_pts_{i}") self.part_ops = [Signal(2, name=f"part_ops_{i}") for i in range(8)] def eq_from(self, part_pts, inputs, part_ops): @@ -1101,8 +1089,7 @@ class AllTerms(Elaboratable): """Set of terms to be added together """ - def __init__(self, n_inputs, output_width, n_parts, register_levels, - partition_points): + def __init__(self, n_inputs, output_width, n_parts, register_levels): """Create an ``AddReduce``. :param inputs: input ``Signal``s to be summed. @@ -1111,7 +1098,7 @@ class AllTerms(Elaboratable): pipeline registers. :param partition_points: the input partition points. """ - self.i = AllTermsData(partition_points) + self.i = InputData() self.register_levels = register_levels self.n_inputs = n_inputs self.n_parts = n_parts @@ -1285,12 +1272,11 @@ class Mul8_16_32_64(Elaboratable): self.register_levels = list(register_levels) # inputs - self.part_pts = PartitionPoints() - for i in range(8, 64, 8): - self.part_pts[i] = Signal(name=f"part_pts_{i}") - self.part_ops = [Signal(2, name=f"part_ops_{i}") for i in range(8)] - self.a = Signal(64) - self.b = Signal(64) + self.i = InputData() + self.part_pts = self.i.part_pts + self.part_ops = self.i.part_ops + self.a = self.i.a + self.b = self.i.b # intermediates (needed for unit tests) self.intermediate_output = Signal(128) @@ -1301,15 +1287,15 @@ class Mul8_16_32_64(Elaboratable): def elaborate(self, platform): m = Module() - pps = self.part_pts + part_pts = self.part_pts n_inputs = 64 + 4 n_parts = 8 #len(self.part_pts) - t = AllTerms(n_inputs, 128, n_parts, self.register_levels, pps) + t = AllTerms(n_inputs, 128, n_parts, self.register_levels) m.submodules.allterms = t m.d.comb += t.i.a.eq(self.a) m.d.comb += t.i.b.eq(self.b) - m.d.comb += t.i.part_pts.eq(pps) + m.d.comb += t.i.part_pts.eq(part_pts) for i in range(8): m.d.comb += t.i.part_ops[i].eq(self.part_ops[i]) @@ -1326,12 +1312,12 @@ class Mul8_16_32_64(Elaboratable): m.submodules.add_reduce = add_reduce - interm = Intermediates(128, 8, pps) + interm = Intermediates(128, 8, part_pts) m.submodules.intermediates = interm m.d.comb += interm.i.eq(add_reduce.o) # final output - m.submodules.finalout = finalout = FinalOut(128, 8, pps) + m.submodules.finalout = finalout = FinalOut(128, 8, part_pts) m.d.comb += finalout.i.eq(interm.o) m.d.comb += self.output.eq(finalout.out) m.d.comb += self.intermediate_output.eq(finalout.intermediate_output)