+class IntermediateOut(Elaboratable):
+ """ selects the HI/LO part of the multiplication, for a given bit-width
+ the output is also reconstructed in its SIMD (partition) lanes.
+ """
+ def __init__(self, width, out_wid, n_parts):
+ self.width = width
+ self.n_parts = n_parts
+ self.part_ops = [Signal(2, name="dpop%d" % i, reset_less=True)
+ for i in range(8)]
+ self.intermed = Signal(out_wid, reset_less=True)
+ self.output = Signal(out_wid//2, reset_less=True)
+
+ def elaborate(self, platform):
+ m = Module()
+
+ ol = []
+ w = self.width
+ sel = w // 8
+ for i in range(self.n_parts):
+ op = Signal(w, reset_less=True, name="op%d_%d" % (w, i))
+ m.d.comb += op.eq(
+ Mux(self.part_ops[sel * i] == OP_MUL_LOW,
+ self.intermed.part(i * w*2, w),
+ self.intermed.part(i * w*2 + w, w)))
+ ol.append(op)
+ m.d.comb += self.output.eq(Cat(*ol))
+
+ return m
+
+
+class FinalOut(Elaboratable):
+ """ selects the final output based on the partitioning.
+
+ each byte is selectable independently, i.e. it is possible
+ that some partitions requested 8-bit computation whilst others
+ requested 16 or 32 bit.
+ """
+ def __init__(self, out_wid):
+ # inputs
+ self.d8 = [Signal(name=f"d8_{i}", reset_less=True) for i in range(8)]
+ self.d16 = [Signal(name=f"d16_{i}", reset_less=True) for i in range(4)]
+ self.d32 = [Signal(name=f"d32_{i}", reset_less=True) for i in range(2)]
+
+ self.i8 = Signal(out_wid, reset_less=True)
+ self.i16 = Signal(out_wid, reset_less=True)
+ self.i32 = Signal(out_wid, reset_less=True)
+ self.i64 = Signal(out_wid, reset_less=True)
+
+ # output
+ self.out = Signal(out_wid, reset_less=True)
+
+ def elaborate(self, platform):
+ m = Module()
+ ol = []
+ for i in range(8):
+ # select one of the outputs: d8 selects i8, d16 selects i16
+ # d32 selects i32, and the default is i64.
+ # d8 and d16 are ORed together in the first Mux
+ # then the 2nd selects either i8 or i16.
+ # if neither d8 nor d16 are set, d32 selects either i32 or i64.
+ op = Signal(8, reset_less=True, name="op_%d" % i)
+ m.d.comb += op.eq(
+ Mux(self.d8[i] | self.d16[i // 2],
+ Mux(self.d8[i], self.i8.part(i * 8, 8),
+ self.i16.part(i * 8, 8)),
+ Mux(self.d32[i // 4], self.i32.part(i * 8, 8),
+ self.i64.part(i * 8, 8))))
+ ol.append(op)
+ m.d.comb += self.out.eq(Cat(*ol))
+ return m
+
+
+class OrMod(Elaboratable):
+ """ ORs four values together in a hierarchical tree
+ """
+ def __init__(self, wid):
+ self.wid = wid
+ self.orin = [Signal(wid, name="orin%d" % i, reset_less=True)
+ for i in range(4)]
+ self.orout = Signal(wid, reset_less=True)
+
+ def elaborate(self, platform):
+ m = Module()
+ or1 = Signal(self.wid, reset_less=True)
+ or2 = Signal(self.wid, reset_less=True)
+ m.d.comb += or1.eq(self.orin[0] | self.orin[1])
+ m.d.comb += or2.eq(self.orin[2] | self.orin[3])
+ m.d.comb += self.orout.eq(or1 | or2)
+
+ return m
+
+
+class Signs(Elaboratable):
+ """ determines whether a or b are signed numbers
+ based on the required operation type (OP_MUL_*)
+ """
+
+ def __init__(self):
+ self.part_ops = Signal(2, reset_less=True)
+ self.a_signed = Signal(reset_less=True)
+ self.b_signed = Signal(reset_less=True)
+
+ def elaborate(self, platform):
+
+ m = Module()
+
+ asig = self.part_ops != OP_MUL_UNSIGNED_HIGH
+ bsig = (self.part_ops == OP_MUL_LOW) \
+ | (self.part_ops == OP_MUL_SIGNED_HIGH)
+ m.d.comb += self.a_signed.eq(asig)
+ m.d.comb += self.b_signed.eq(bsig)
+
+ return m
+
+