X-Git-Url: https://git.libre-soc.org/?a=blobdiff_plain;f=src%2Fieee754%2Fpart_mul_add%2Fmultiply.py;h=e0531b420e2cc885b2e10c42acec581c8e852c0b;hb=9d3ead0d8c46b1d4929a3e03f61597f3a9c56e00;hp=92afc2bbd85f18e6305a7f85576af793d2d8e2af;hpb=7b4588e1710b0e997b9dd5d1f0376150afe276ce;p=ieee754fpu.git diff --git a/src/ieee754/part_mul_add/multiply.py b/src/ieee754/part_mul_add/multiply.py index 92afc2bb..e0531b42 100644 --- a/src/ieee754/part_mul_add/multiply.py +++ b/src/ieee754/part_mul_add/multiply.py @@ -346,15 +346,13 @@ class FinalAdd(Elaboratable): """ Final stage of add reduce """ - def __init__(self, n_inputs, output_width, n_parts, register_levels, - partition_points): + def __init__(self, n_inputs, output_width, n_parts, partition_points): self.i = AddReduceData(partition_points, n_inputs, output_width, n_parts) self.o = FinalReduceData(partition_points, output_width, n_parts) self.output_width = output_width self.n_inputs = n_inputs self.n_parts = n_parts - self.register_levels = list(register_levels) self.partition_points = PartitionPoints(partition_points) if not self.partition_points.fits_in_width(output_width): raise ValueError("partition_points doesn't fit in output_width") @@ -400,14 +398,11 @@ class AddReduceSingle(Elaboratable): supported, except for by ``Signal.eq``. """ - def __init__(self, n_inputs, output_width, n_parts, register_levels, - partition_points): + def __init__(self, n_inputs, output_width, n_parts, partition_points): """Create an ``AddReduce``. :param inputs: input ``Signal``s to be summed. :param output_width: bit-width of ``output``. - :param register_levels: List of nesting levels that should have - pipeline registers. :param partition_points: the input partition points. """ self.n_inputs = n_inputs @@ -415,17 +410,10 @@ class AddReduceSingle(Elaboratable): self.output_width = output_width self.i = AddReduceData(partition_points, n_inputs, output_width, n_parts) - self.register_levels = list(register_levels) self.partition_points = PartitionPoints(partition_points) if not self.partition_points.fits_in_width(output_width): raise ValueError("partition_points doesn't fit in output_width") - max_level = AddReduceSingle.get_max_level(n_inputs) - for level in self.register_levels: - if level > max_level: - raise ValueError( - "not enough adder levels for specified register levels") - self.groups = AddReduceSingle.full_adder_groups(n_inputs) n_terms = AddReduceSingle.calc_n_inputs(n_inputs, self.groups) self.o = AddReduceData(partition_points, n_terms, output_width, n_parts) @@ -574,7 +562,6 @@ class AddReduce(Elaboratable): """creates reduction levels""" mods = [] - next_levels = self.register_levels partition_points = self.partition_points part_ops = self.part_ops n_parts = len(part_ops) @@ -585,16 +572,15 @@ class AddReduce(Elaboratable): if len(groups) == 0: break next_level = AddReduceSingle(ilen, self.output_width, n_parts, - next_levels, partition_points) + partition_points) mods.append(next_level) - next_levels = list(AddReduce.next_register_levels(next_levels)) partition_points = next_level.i.part_pts inputs = next_level.o.terms ilen = len(inputs) part_ops = next_level.i.part_ops next_level = FinalAdd(ilen, self.output_width, n_parts, - next_levels, partition_points) + partition_points) mods.append(next_level) self.levels = mods @@ -1079,12 +1065,14 @@ class IntermediateData: rhs.intermediate_output, rhs.part_ops) -class AllTermsData: +class InputData: - def __init__(self, partition_points): + def __init__(self): self.a = Signal(64) self.b = Signal(64) - self.part_pts = partition_points.like() + self.part_pts = PartitionPoints() + for i in range(8, 64, 8): + self.part_pts[i] = Signal(name=f"part_pts_{i}") self.part_ops = [Signal(2, name=f"part_ops_{i}") for i in range(8)] def eq_from(self, part_pts, inputs, part_ops): @@ -1101,8 +1089,7 @@ class AllTerms(Elaboratable): """Set of terms to be added together """ - def __init__(self, n_inputs, output_width, n_parts, register_levels, - partition_points): + def __init__(self, n_inputs, output_width, n_parts, register_levels): """Create an ``AddReduce``. :param inputs: input ``Signal``s to be summed. @@ -1111,7 +1098,7 @@ class AllTerms(Elaboratable): pipeline registers. :param partition_points: the input partition points. """ - self.i = AllTermsData(partition_points) + self.i = InputData() self.register_levels = register_levels self.n_inputs = n_inputs self.n_parts = n_parts @@ -1285,12 +1272,11 @@ class Mul8_16_32_64(Elaboratable): self.register_levels = list(register_levels) # inputs - self.part_pts = PartitionPoints() - for i in range(8, 64, 8): - self.part_pts[i] = Signal(name=f"part_pts_{i}") - self.part_ops = [Signal(2, name=f"part_ops_{i}") for i in range(8)] - self.a = Signal(64) - self.b = Signal(64) + self.i = InputData() + self.part_pts = self.i.part_pts + self.part_ops = self.i.part_ops + self.a = self.i.a + self.b = self.i.b # intermediates (needed for unit tests) self.intermediate_output = Signal(128) @@ -1305,7 +1291,7 @@ class Mul8_16_32_64(Elaboratable): n_inputs = 64 + 4 n_parts = 8 #len(self.part_pts) - t = AllTerms(n_inputs, 128, n_parts, self.register_levels, part_pts) + t = AllTerms(n_inputs, 128, n_parts, self.register_levels) m.submodules.allterms = t m.d.comb += t.i.a.eq(self.a) m.d.comb += t.i.b.eq(self.b)