- m.submodules.p_8 = p_8 = Parts(8, eps, len(part_8.parts))
- m.submodules.p_16 = p_16 = Parts(8, eps, len(part_16.parts))
- m.submodules.p_32 = p_32 = Parts(8, eps, len(part_32.parts))
- m.submodules.p_64 = p_64 = Parts(8, eps, len(part_64.parts))
+ :attribute part_pts: the input partition points. Has a partition point at
+ multiples of 8 in 0 < i < 64. Each partition point's associated
+ ``Value`` is a ``Signal``. Modification not supported, except for by
+ ``Signal.eq``.
+ :attribute part_ops: the operation for each byte. The operation for a
+ particular partition is selected by assigning the selected operation
+ code to each byte in the partition. The allowed operation codes are:
+
+ :attribute OP_MUL_LOW: the LSB half of the product. Equivalent to
+ RISC-V's `mul` instruction.
+ :attribute OP_MUL_SIGNED_HIGH: the MSB half of the product where both
+ ``a`` and ``b`` are signed. Equivalent to RISC-V's `mulh`
+ instruction.
+ :attribute OP_MUL_SIGNED_UNSIGNED_HIGH: the MSB half of the product
+ where ``a`` is signed and ``b`` is unsigned. Equivalent to RISC-V's
+ `mulhsu` instruction.
+ :attribute OP_MUL_UNSIGNED_HIGH: the MSB half of the product where both
+ ``a`` and ``b`` are unsigned. Equivalent to RISC-V's `mulhu`
+ instruction.
+ """
+
+ def __init__(self, register_levels=()):
+ """ register_levels: specifies the points in the cascade at which
+ flip-flops are to be inserted.
+ """
+
+ # parameter(s)
+ self.register_levels = list(register_levels)
+
+ # inputs
+ self.part_pts = PartitionPoints()
+ for i in range(8, 64, 8):
+ self.part_pts[i] = Signal(name=f"part_pts_{i}")
+ self.part_ops = [Signal(2, name=f"part_ops_{i}") for i in range(8)]
+ self.a = Signal(64)
+ self.b = Signal(64)
+
+ # intermediates (needed for unit tests)
+ self.intermediate_output = Signal(128)
+
+ # output
+ self.output = Signal(64)
+
+ def elaborate(self, platform):
+ m = Module()
+
+ part_pts = self.part_pts
+
+ n_inputs = 64 + 4
+ n_parts = 8 #len(self.part_pts)
+ t = AllTerms(n_inputs, 128, n_parts, self.register_levels, part_pts)
+ m.submodules.allterms = t
+ m.d.comb += t.i.a.eq(self.a)
+ m.d.comb += t.i.b.eq(self.b)
+ m.d.comb += t.i.part_pts.eq(part_pts)
+ for i in range(8):
+ m.d.comb += t.i.part_ops[i].eq(self.part_ops[i])
+
+ terms = t.o.terms
+
+ add_reduce = AddReduce(terms,
+ 128,
+ self.register_levels,
+ t.o.part_pts,
+ t.o.part_ops)
+
+ out_part_ops = add_reduce.o.part_ops
+ out_part_pts = add_reduce.o.part_pts
+
+ m.submodules.add_reduce = add_reduce