X-Git-Url: https://git.libre-soc.org/?a=blobdiff_plain;f=src%2Fieee754%2Fpart_mul_add%2Fmultiply.py;h=2c828c187f2747bde3285df64599636719e3be72;hb=a5a060d10873d4ae26ba656fa9bfdda96a429d4e;hp=aef5c16e1a1a5db0c402857c41a68876ce5f12f4;hpb=15dc23ee136eff9b2adaec11b4723979ce57b6bf;p=ieee754fpu.git diff --git a/src/ieee754/part_mul_add/multiply.py b/src/ieee754/part_mul_add/multiply.py index aef5c16e..2c828c18 100644 --- a/src/ieee754/part_mul_add/multiply.py +++ b/src/ieee754/part_mul_add/multiply.py @@ -346,7 +346,9 @@ class FinalAdd(Elaboratable): """ Final stage of add reduce """ - def __init__(self, n_inputs, output_width, n_parts, partition_points): + def __init__(self, n_inputs, output_width, n_parts, partition_points, + partition_step=1): + self.partition_step = partition_step self.output_width = output_width self.n_inputs = n_inputs self.n_parts = n_parts @@ -381,7 +383,7 @@ class FinalAdd(Elaboratable): # base case for adding 2 inputs assert self.n_inputs == 2 adder = PartitionedAdder(output_width, - self.i.part_pts, 2) + self.i.part_pts, self.partition_step) m.submodules.final_adder = adder m.d.comb += adder.a.eq(self.i.terms[0]) m.d.comb += adder.b.eq(self.i.terms[1]) @@ -406,13 +408,15 @@ class AddReduceSingle(Elaboratable): supported, except for by ``Signal.eq``. """ - def __init__(self, n_inputs, output_width, n_parts, partition_points): + def __init__(self, n_inputs, output_width, n_parts, partition_points, + partition_step=1): """Create an ``AddReduce``. :param inputs: input ``Signal``s to be summed. :param output_width: bit-width of ``output``. :param partition_points: the input partition points. """ + self.partition_step = partition_step self.n_inputs = n_inputs self.n_parts = n_parts self.output_width = output_width @@ -516,7 +520,8 @@ class AddReduceSingle(Elaboratable): part_mask = Signal(self.output_width, reset_less=True) # get partition points as a mask - mask = self.i.part_pts.as_mask(self.output_width, mul=2) + mask = self.i.part_pts.as_mask(self.output_width, + mul=self.partition_step) m.d.comb += part_mask.eq(mask) # add and link the intermediate term modules @@ -543,18 +548,19 @@ class AddReduceInternal: supported, except for by ``Signal.eq``. """ - def __init__(self, inputs, output_width, partition_points, - part_ops): + def __init__(self, i, output_width, partition_step=1): """Create an ``AddReduce``. :param inputs: input ``Signal``s to be summed. :param output_width: bit-width of ``output``. :param partition_points: the input partition points. """ - self.inputs = inputs - self.part_ops = part_ops + self.i = i + self.inputs = i.terms + self.part_ops = i.part_ops self.output_width = output_width - self.partition_points = partition_points + self.partition_points = i.part_pts + self.partition_step = partition_step self.create_levels() @@ -572,7 +578,8 @@ class AddReduceInternal: if len(groups) == 0: break next_level = AddReduceSingle(ilen, self.output_width, n_parts, - partition_points) + partition_points, + self.partition_step) mods.append(next_level) partition_points = next_level.i.part_pts inputs = next_level.o.terms @@ -580,7 +587,7 @@ class AddReduceInternal: part_ops = next_level.i.part_ops next_level = FinalAdd(ilen, self.output_width, n_parts, - partition_points) + partition_points, self.partition_step) mods.append(next_level) self.levels = mods @@ -598,8 +605,8 @@ class AddReduce(AddReduceInternal, Elaboratable): supported, except for by ``Signal.eq``. """ - def __init__(self, inputs, output_width, register_levels, partition_points, - part_ops): + def __init__(self, inputs, output_width, register_levels, part_pts, + part_ops, partition_step=1): """Create an ``AddReduce``. :param inputs: input ``Signal``s to be summed. @@ -608,10 +615,14 @@ class AddReduce(AddReduceInternal, Elaboratable): pipeline registers. :param partition_points: the input partition points. """ - AddReduceInternal.__init__(self, inputs, output_width, - partition_points, part_ops) + self._inputs = inputs + self._part_pts = part_pts + self._part_ops = part_ops n_parts = len(part_ops) - self.o = FinalReduceData(partition_points, output_width, n_parts) + self.i = AddReduceData(part_pts, len(inputs), + output_width, n_parts) + AddReduceInternal.__init__(self, self.i, output_width, partition_step) + self.o = FinalReduceData(part_pts, output_width, n_parts) self.register_levels = register_levels @staticmethod @@ -629,17 +640,12 @@ class AddReduce(AddReduceInternal, Elaboratable): """Elaborate this module.""" m = Module() + m.d.comb += self.i.eq_from(self._part_pts, self._inputs, self._part_ops) + for i, next_level in enumerate(self.levels): setattr(m.submodules, "next_level%d" % i, next_level) - partition_points = self.partition_points - inputs = self.inputs - part_ops = self.part_ops - n_parts = len(part_ops) - n_inputs = len(inputs) - output_width = self.output_width - i = AddReduceData(partition_points, n_inputs, output_width, n_parts) - m.d.comb += i.eq_from(partition_points, inputs, part_ops) + i = self.i for idx in range(len(self.levels)): mcur = self.levels[idx] if idx in self.register_levels: @@ -1383,17 +1389,21 @@ class Mul8_16_32_64(Elaboratable): terms = t.o.terms - add_reduce = AddReduce(terms, - 128, - self.register_levels, - t.o.part_pts, - t.o.part_ops) + at = AddReduceInternal(t.o, 128, partition_step=2) - m.submodules.add_reduce = add_reduce + i = at.i + for idx in range(len(at.levels)): + mcur = at.levels[idx] + setattr(m.submodules, "addreduce_%d" % idx, mcur) + if idx in self.register_levels: + m.d.sync += mcur.i.eq(i) + else: + m.d.comb += mcur.i.eq(i) + i = mcur.o # for next loop interm = Intermediates(128, 8, part_pts) m.submodules.intermediates = interm - m.d.comb += interm.i.eq(add_reduce.o) + m.d.comb += interm.i.eq(i) # final output m.submodules.finalout = finalout = FinalOut(128, 8, part_pts)