Create a new "MaskedFullAdder" class, which performs the partition-carry masking
[ieee754fpu.git] / src / ieee754 / part_mul_add / multiply.py
1 # SPDX-License-Identifier: LGPL-2.1-or-later
2 # See Notices.txt for copyright information
3 """Integer Multiplication."""
4
5 from nmigen import Signal, Module, Value, Elaboratable, Cat, C, Mux, Repl
6 from nmigen.hdl.ast import Assign
7 from abc import ABCMeta, abstractmethod
8 from nmigen.cli import main
9 from functools import reduce
10 from operator import or_
11
12
class PartitionPoints(dict):
    """Partition points and corresponding ``Value``s.

    A mapping from the bit index at which an ALU may be split to a ``Value``
    that says whether that split is enabled.

    For example: ``{1: True, 5: True, 10: True}`` with
    ``width == 16`` specifies that the ALU is split into 4 sections:
    * bits 0 <= ``i`` < 1
    * bits 1 <= ``i`` < 5
    * bits 5 <= ``i`` < 10
    * bits 10 <= ``i`` < 16

    If the partition_points were instead ``{1: True, 5: a, 10: True}``
    where ``a`` is a 1-bit ``Signal``:
    * If ``a`` is asserted:
        * bits 0 <= ``i`` < 1
        * bits 1 <= ``i`` < 5
        * bits 5 <= ``i`` < 10
        * bits 10 <= ``i`` < 16
    * Otherwise
        * bits 0 <= ``i`` < 1
        * bits 1 <= ``i`` < 10
        * bits 10 <= ``i`` < 16
    """

    def __init__(self, partition_points=None):
        """Create a new ``PartitionPoints``.

        :param partition_points: the input partition points to values mapping,
            or ``None`` for an empty set of points.
        :raises TypeError: if a point is not an ``int``.
        :raises ValueError: if a point is negative.
        """
        super().__init__()
        if partition_points is None:
            return
        for point, enabled in partition_points.items():
            if not isinstance(point, int):
                raise TypeError("point must be a non-negative integer")
            if point < 0:
                raise ValueError("point must be a non-negative integer")
            # every stored value is normalised to an nmigen ``Value``
            self[point] = Value.wrap(enabled)

    def like(self, name=None, src_loc_at=0):
        """Create a new ``PartitionPoints`` with ``Signal``s for all values.

        :param name: the base name for the new ``Signal``s. When omitted, the
            name is taken from the variable the caller assigns the result to.
        :param src_loc_at: extra stack frames to skip when deriving ``name``.
        """
        if name is None:
            # This throw-away Signal must be created directly in this frame:
            # ``src_loc_at`` counts stack frames up to the caller's assignment.
            name = Signal(src_loc_at=1+src_loc_at).name  # get variable name
        clone = PartitionPoints()
        for point, enabled in self.items():
            clone[point] = Signal(enabled.shape(), name=f"{name}_{point}")
        return clone

    def eq(self, rhs):
        """Assign ``PartitionPoints`` using ``Signal.eq``.

        This is a generator: it yields one assignment per partition point and
        raises ``ValueError`` (on first iteration) if ``rhs`` does not have
        exactly the same set of points.
        """
        if set(self) != set(rhs.keys()):
            raise ValueError("incompatible point set")
        for point, enabled in self.items():
            yield enabled.eq(rhs[point])

    def as_mask(self, width):
        """Create a bit-mask from `self`.

        Each bit in the returned mask is clear only if the partition point at
        the same bit-index is enabled.

        :param width: the bit width of the resulting mask
        """
        return Cat(*[~self[i] if i in self else True
                     for i in range(width)])

    def get_max_partition_count(self, width):
        """Get the maximum number of partitions.

        Gets the number of partitions when all partition points are enabled.
        Only points below ``width`` are counted.
        """
        return 1 + sum(1 for point in self if point < width)

    def fits_in_width(self, width):
        """Check if all partition points are smaller than `width`."""
        return all(point < width for point in self)
105
106
class FullAdder(Elaboratable):
    """Full Adder.

    :attribute in0: the first input
    :attribute in1: the second input
    :attribute in2: the third input
    :attribute sum: the sum output
    :attribute carry: the carry output

    Instead of instantiating an array of 1-bit full adders (which would be
    very slow to simulate), the inputs and outputs are ``width`` bits wide:
    in effect this performs multiple Full 3-2 Add operations "in parallel",
    one per bit position.
    """

    def __init__(self, width):
        """Create a ``FullAdder``.

        :param width: the bit width of the input and output
        """
        # signal names are derived from these attribute names
        self.in0 = Signal(width)
        self.in1 = Signal(width)
        self.in2 = Signal(width)
        self.sum = Signal(width)
        self.carry = Signal(width)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        x, y, z = self.in0, self.in1, self.in2
        m.d.comb += [
            # per-bit sum: XOR of the three inputs
            self.sum.eq(x ^ y ^ z),
            # per-bit carry: set when at least two of the inputs are set
            self.carry.eq((x & y) | (y & z) | (z & x)),
        ]
        return m
141
142
class MaskedFullAdder(FullAdder):
    """Masked Full Adder.

    :attribute mask: the carry partition mask
    :attribute in0: the first input
    :attribute in1: the second input
    :attribute in2: the third input
    :attribute sum: the sum output
    :attribute mcarry: the masked carry output

    FullAdders are always used with a "mask" on the output. To keep
    the graphviz "clean", the masking is done inside this class rather
    than inside a large for-loop in the user of the adder.
    """

    def __init__(self, width):
        """Create a ``MaskedFullAdder``.

        :param width: the bit width of the input and output
        """
        super().__init__(width)
        self.mask = Signal(width)
        self.mcarry = Signal(width)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = super().elaborate(platform)
        # the carry is moved up one bit position and then masked
        shifted_carry = self.carry << 1
        m.d.comb += self.mcarry.eq(shifted_carry & self.mask)
        return m
172
173
class PartitionedAdder(Elaboratable):
    """Partitioned Adder.

    Performs the final add. The partition points are included in the
    actual add (in one of the operands only), which causes a carry over
    to the next bit. Then the final output *removes* the extra bits from
    the result.

    :attribute width: the bit width of the input and output. Read-only.
    :attribute a: the first input to the adder
    :attribute b: the second input to the adder
    :attribute output: the sum output
    :attribute partition_points: the input partition points. Modification not
        supported, except for by ``Signal.eq``.
    """

    def __init__(self, width, partition_points):
        """Create a ``PartitionedAdder``.

        :param width: the bit width of the input and output
        :param partition_points: the input partition points
        :raises ValueError: if a partition point is not below ``width``.
        """
        self.width = width
        self.a = Signal(width)
        self.b = Signal(width)
        self.output = Signal(width)
        self.partition_points = PartitionPoints(partition_points)
        if not self.partition_points.fits_in_width(width):
            raise ValueError("partition_points doesn't fit in width")
        # one bit per data bit, plus one extra "break" bit in front of every
        # bit index that is a partition point
        expanded_width = sum(2 if i in self.partition_points else 1
                             for i in range(self.width))
        self._expanded_width = expanded_width
        # XXX these have to remain here due to some horrible nmigen
        # simulation bugs involving sync.  it is *not* necessary to
        # have them here, they should (under normal circumstances)
        # be moved into elaborate, as they are entirely local
        self._expanded_a = Signal(expanded_width)
        self._expanded_b = Signal(expanded_width)
        self._expanded_output = Signal(expanded_width)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        # bits are gathered in lists and joined with Cat afterwards:
        # the resulting graphviz output is much cleaner.
        #   exp_*  : bits of the expanded (internal) signals
        #   src_*  : what drives the corresponding expanded operand bit
        #   out_*  : bits of the real output, driven from exp_out
        exp_a, src_a = [], []
        exp_b, src_b = [], []
        exp_out, out_bits = [], []

        # partition points are "breaks" (extra zeros or 1s) in what would
        # otherwise be a massive long add.  when the "break" points are 0,
        # whatever is in it (in the output) is discarded.  however when
        # there is a "1", it causes a roll-over carry to the *next* bit.
        # we still ignore the "break" bit in the [intermediate] output,
        # however by that time we've got the effect that we wanted: the
        # carry has been carried *over* the break point.

        pos = 0  # index into the expanded signals
        for i in range(self.width):
            if i in self.partition_points:
                # extra bit: 0 + 0 for an enabled partition point,
                # 1 + 0 for a disabled one
                exp_a.append(self._expanded_a[pos])
                src_a.append(~self.partition_points[i])  # extra bit in a
                exp_b.append(self._expanded_b[pos])
                src_b.append(C(0))  # do *not* add extra bit into b.
                pos += 1
            # the ordinary data bit
            exp_a.append(self._expanded_a[pos])
            src_a.append(self.a[i])
            exp_b.append(self._expanded_b[pos])
            src_b.append(self.b[i])
            exp_out.append(self._expanded_output[pos])
            out_bits.append(self.output[i])
            pos += 1

        # combine above using Cat
        m.d.comb += Cat(*exp_a).eq(Cat(*src_a))
        m.d.comb += Cat(*exp_b).eq(Cat(*src_b))
        m.d.comb += Cat(*out_bits).eq(Cat(*exp_out))

        # use only one addition to take advantage of look-ahead carry and
        # special hardware on FPGAs
        m.d.comb += self._expanded_output.eq(
            self._expanded_a + self._expanded_b)
        return m
259
260
# Number of inputs one FullAdder combines (in0, in1, in2).  AddReduce groups
# its inputs in chunks of this size.
FULL_ADDER_INPUT_COUNT = 3
262
263
class AddReduce(Elaboratable):
    """Add list of numbers together.

    :attribute inputs: input ``Signal``s to be summed. Modification not
        supported, except for by ``Signal.eq``.
    :attribute register_levels: List of nesting levels that should have
        pipeline registers.
    :attribute output: output sum.
    :attribute partition_points: the input partition points. Modification not
        supported, except for by ``Signal.eq``.
    """

    def __init__(self, inputs, output_width, register_levels, partition_points):
        """Create an ``AddReduce``.

        :param inputs: input ``Signal``s to be summed.
        :param output_width: bit-width of ``output``.
        :param register_levels: List of nesting levels that should have
            pipeline registers.
        :param partition_points: the input partition points.
        :raises ValueError: if a partition point is not below
            ``output_width`` or a register level exceeds the maximum level.
        """
        self.inputs = list(inputs)
        self._resized_inputs = [
            Signal(output_width, name=f"resized_inputs[{i}]")
            for i, _ in enumerate(self.inputs)]
        self.register_levels = list(register_levels)
        self.output = Signal(output_width)
        self.partition_points = PartitionPoints(partition_points)
        if not self.partition_points.fits_in_width(output_width):
            raise ValueError("partition_points doesn't fit in output_width")
        # NOTE: ``like()`` derives the signal base name from the attribute
        # assigned here, so this must stay a direct assignment.
        self._reg_partition_points = self.partition_points.like()
        max_level = AddReduce.get_max_level(len(self.inputs))
        if any(level > max_level for level in self.register_levels):
            raise ValueError(
                "not enough adder levels for specified register levels")

    @staticmethod
    def get_max_level(input_count):
        """Get the maximum level.

        All ``register_levels`` must be less than or equal to the maximum
        level.
        """
        level = 0
        groups = AddReduce.full_adder_groups(input_count)
        while len(groups) != 0:
            # every group yields two terms (sum and carry); ungrouped inputs
            # are passed straight through to the next level
            input_count = (input_count % FULL_ADDER_INPUT_COUNT
                           + 2 * len(groups))
            level += 1
            groups = AddReduce.full_adder_groups(input_count)
        return level

    def next_register_levels(self):
        """``Iterable`` of ``register_levels`` for next recursive level."""
        for level in self.register_levels:
            if level <= 0:
                continue
            yield level - 1

    @staticmethod
    def full_adder_groups(input_count):
        """Get ``inputs`` indices for which a full adder should be built."""
        last_start = input_count - FULL_ADDER_INPUT_COUNT
        return range(0, last_start + 1, FULL_ADDER_INPUT_COUNT)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        out_width = len(self.output)
        n_inputs = len(self.inputs)

        # resize inputs to correct bit-width and optionally add in
        # pipeline registers (when level 0 is a register level)
        resized_input_assignments = [
            resized.eq(term)
            for resized, term in zip(self._resized_inputs, self.inputs)]
        stage = m.d.sync if 0 in self.register_levels else m.d.comb
        stage += resized_input_assignments
        stage += self._reg_partition_points.eq(self.partition_points)

        groups = AddReduce.full_adder_groups(n_inputs)
        # if there are no full adders to create, then we handle the base cases
        # and return, otherwise we go on to the recursive case
        if len(groups) == 0:
            if n_inputs == 0:
                # use 0 as the default output value
                m.d.comb += self.output.eq(0)
                return m
            if n_inputs == 1:
                # handle single input
                m.d.comb += self.output.eq(self._resized_inputs[0])
                return m
            # base case for adding 2 or more inputs, which get recursively
            # reduced to 2 inputs
            assert n_inputs == 2
            adder = PartitionedAdder(out_width, self._reg_partition_points)
            m.submodules.final_adder = adder
            m.d.comb += adder.a.eq(self._resized_inputs[0])
            m.d.comb += adder.b.eq(self._resized_inputs[1])
            m.d.comb += self.output.eq(adder.output)
            return m

        # go on to handle recursive case
        intermediate_terms = []

        def add_intermediate_term(value):
            # register ``value`` as an input of the next level, via a
            # named intermediate signal
            index = len(intermediate_terms)
            intermediate_term = Signal(
                out_width,
                name=f"intermediate_terms[{index}]")
            intermediate_terms.append(intermediate_term)
            m.d.comb += intermediate_term.eq(value)

        # store mask in intermediary (simplifies graph)
        # NOTE: the signal is named after this local variable.
        part_mask = Signal(len(self.output), reset_less=True)
        mask = self._reg_partition_points.as_mask(out_width)
        m.d.comb += part_mask.eq(mask)

        # create full adders for this recursive level.
        # this shrinks N terms to 2 * (N // 3) plus the remainder
        for i in groups:
            adder_i = MaskedFullAdder(out_width)
            setattr(m.submodules, f"adder_{i}", adder_i)
            first, second, third = self._resized_inputs[i:i + 3]
            m.d.comb += adder_i.in0.eq(first)
            m.d.comb += adder_i.in1.eq(second)
            m.d.comb += adder_i.in2.eq(third)
            m.d.comb += adder_i.mask.eq(part_mask)
            add_intermediate_term(adder_i.sum)
            # mask out carry bits to prevent carries between partitions
            add_intermediate_term(adder_i.mcarry)

        # handle the remaining inputs.
        # With two left over we just pass the terms to the next layer: a half
        # adder would gain nothing since there would still be 2 terms, and
        # passing them on saves gates.
        leftover = n_inputs % FULL_ADDER_INPUT_COUNT
        assert leftover in (0, 1, 2)
        for term in self._resized_inputs[n_inputs - leftover:]:
            add_intermediate_term(term)

        # recursive invocation of ``AddReduce``
        next_level = AddReduce(intermediate_terms,
                               out_width,
                               self.next_register_levels(),
                               self._reg_partition_points)
        m.submodules.next_level = next_level
        m.d.comb += self.output.eq(next_level.output)
        return m
412
413
# Operation codes for ``Mul8_16_32_64.part_ops`` (each entry is a 2-bit
# signal, selected per byte).
OP_MUL_LOW = 0  # low half of the product
OP_MUL_SIGNED_HIGH = 1  # high half, a and b both signed
OP_MUL_SIGNED_UNSIGNED_HIGH = 2  # a is signed, b is unsigned
OP_MUL_UNSIGNED_HIGH = 3  # high half, a and b both unsigned
418
419
def get_term(value, shift=0, enabled=None):
    """Build a (possibly gated, possibly shifted) term.

    :param value: the value to use as the term.
    :param shift: number of zero bits to place below ``value``; must be
        non-negative (asserted).
    :param enabled: optional select; when given, the term is ``value`` if
        ``enabled`` is set and 0 otherwise. ``None`` means "always enabled".
    :return: ``value`` unchanged when there is no gating and no shift,
        otherwise the corresponding ``Mux``/``Cat`` expression.
    """
    term = value if enabled is None else Mux(enabled, value, 0)
    if shift > 0:
        # prepend ``shift`` zero bits at the LSB end
        return Cat(Repl(C(0, 1), shift), term)
    assert shift == 0
    return term
428
429
class ProductTerm(Elaboratable):
    """ this class creates a single product term (a[..]*b[..]).
        it has a design flaw in that it is the *output* that is selected,
        whereas the multiplication(s) are combinatorially generated
        all the time.

        :attribute a: first operand (``twidth//2`` bits); the part at
            ``a_index`` is used.
        :attribute b: second operand (``twidth//2`` bits); the part at
            ``b_index`` is used.
        :attribute pb_en: partition bits; the term is enabled only when
            none of the bits ``min(a_index, b_index) <= i <
            max(a_index, b_index)`` is set.
        :attribute ti: the raw product of the two selected parts.
        :attribute term: the output term: ``ti`` shifted into position and
            gated by the enable.
    """

    def __init__(self, width, twidth, pbwid, a_index, b_index):
        """Create a ``ProductTerm``.

        :param width: bit width of one operand part
        :param twidth: bit width of the output term
        :param pbwid: bit width of ``pb_en``
        :param a_index: index of the part of ``a`` to multiply
        :param b_index: index of the part of ``b`` to multiply
        """
        self.a_index = a_index
        self.b_index = b_index
        # bit offset of this partial product within the full product: the
        # parts sit at a_index*width and b_index*width, so their product
        # sits at the sum of both offsets.  (this used to be hard-coded as
        # 8 * (...), which is only right for width == 8.)
        shift = width * (self.a_index + self.b_index)
        self.pwidth = width
        self.twidth = twidth
        self.width = width*2
        self.shift = shift

        self.ti = Signal(self.width, reset_less=True)
        self.term = Signal(twidth, reset_less=True)
        self.a = Signal(twidth//2, reset_less=True)
        self.b = Signal(twidth//2, reset_less=True)
        self.pb_en = Signal(pbwid, reset_less=True)

        # partition bits lying between the two parts being multiplied
        min_index = min(self.a_index, self.b_index)
        max_index = max(self.a_index, self.b_index)
        self.tl = [self.pb_en[i] for i in range(min_index, max_index)]
        name = "te_%d_%d" % (self.a_index, self.b_index)
        if self.tl:
            term_enabled = Signal(name=name, reset_less=True)
        else:
            # a_index == b_index: nothing can disable the term
            term_enabled = None
        self.enabled = term_enabled
        self.term.name = "term_%d_%d" % (a_index, b_index)  # rename

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        if self.enabled is not None:
            # enabled only if no partition bit between the parts is set
            m.d.comb += self.enabled.eq(~(Cat(*self.tl).bool()))

        # NOTE: bsa/bsb are named after these local variables.
        bsa = Signal(self.width, reset_less=True)
        bsb = Signal(self.width, reset_less=True)
        a_index, b_index = self.a_index, self.b_index
        pwidth = self.pwidth
        m.d.comb += bsa.eq(self.a.bit_select(a_index * pwidth, pwidth))
        m.d.comb += bsb.eq(self.b.bit_select(b_index * pwidth, pwidth))
        m.d.comb += self.ti.eq(bsa * bsb)
        m.d.comb += self.term.eq(get_term(self.ti, self.shift, self.enabled))

        # TODO: sort out width issues, get inputs a/b switched on/off.
        # data going into Muxes is 1/2 the required width
        #
        # pwidth = self.pwidth
        # width = self.width
        # bsa = Signal(self.twidth//2, reset_less=True)
        # bsb = Signal(self.twidth//2, reset_less=True)
        # asel = Signal(width, reset_less=True)
        # bsel = Signal(width, reset_less=True)
        # a_index, b_index = self.a_index, self.b_index
        # m.d.comb += asel.eq(self.a.bit_select(a_index * pwidth, pwidth))
        # m.d.comb += bsel.eq(self.b.bit_select(b_index * pwidth, pwidth))
        # m.d.comb += bsa.eq(get_term(asel, self.shift, self.enabled))
        # m.d.comb += bsb.eq(get_term(bsel, self.shift, self.enabled))
        # m.d.comb += self.ti.eq(bsa * bsb)
        # m.d.comb += self.term.eq(self.ti)

        return m
499
500
class ProductTerms(Elaboratable):
    """ creates a bank of product terms.  also performs the actual
        bit-selection.  this class is to be wrapped with a for-loop on the
        "a" operand: it creates the second-level for-loop on the "b" operand.

        :attribute a: first operand, passed on to every ``ProductTerm``
        :attribute b: second operand, passed on to every ``ProductTerm``
        :attribute pb_en: partition bits, passed on to every ``ProductTerm``
        :attribute terms: one output term per ``b`` index
    """
    def __init__(self, width, twidth, pbwid, a_index, blen):
        """Create a bank of ``blen`` product terms for part ``a_index``.

        :param width: bit width of one operand part
        :param twidth: bit width of each output term
        :param pbwid: bit width of ``pb_en``
        :param a_index: index of the part of ``a`` used by all terms
        :param blen: number of ``b`` parts (and therefore of terms)
        """
        self.a_index = a_index
        self.blen = blen
        self.pwidth = width
        self.twidth = twidth
        self.pbwid = pbwid
        self.a = Signal(twidth//2, reset_less=True)
        self.b = Signal(twidth//2, reset_less=True)
        self.pb_en = Signal(pbwid, reset_less=True)
        self.terms = [Signal(twidth, name="term%d" % i, reset_less=True)
                      for i in range(blen)]

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()

        for b_index in range(self.blen):
            product = ProductTerm(self.pwidth, self.twidth, self.pbwid,
                                  self.a_index, b_index)
            setattr(m.submodules, "term_%d" % b_index, product)

            # feed the whole operands in; the ProductTerm selects its parts
            m.d.comb += [
                product.a.eq(self.a),
                product.b.eq(self.b),
                product.pb_en.eq(self.pb_en),
                self.terms[b_index].eq(product.term),
            ]

        return m
534
class LSBNegTerm(Elaboratable):
    """ produces the two extra terms used for a signed partition: the
        bit-inverted operand and the "+1", both placed in the HI half.

        :attribute part: set when this partition is active
        :attribute signed: set when a signed operation is requested
        :attribute msb: the sign (MSB) bit being tested
        :attribute op: the operand to be inverted (``bit_width`` bits)
        :attribute nt: output: ``~op`` in the HI half when enabled, else 0
        :attribute nl: output: a single 1 at bit ``bit_width`` when enabled
    """

    def __init__(self, bit_width):
        """Create an ``LSBNegTerm`` for partitions of ``bit_width`` bits."""
        self.bit_width = bit_width
        self.part = Signal(reset_less=True)
        self.signed = Signal(reset_less=True)
        self.op = Signal(bit_width, reset_less=True)
        self.msb = Signal(reset_less=True)
        self.nt = Signal(bit_width*2, reset_less=True)
        self.nl = Signal(bit_width*2, reset_less=True)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        comb = m.d.comb
        bit_wid = self.bit_width
        ext = Repl(0, bit_wid)  # extend output to HI part

        # determine sign of each incoming number *in this partition*
        # NOTE: the signal is named after this local variable.
        enabled = Signal(reset_less=True)
        m.d.comb += enabled.eq(self.part & self.msb & self.signed)

        # for 8-bit values: form a * 0xFF00 by using -a * 0x100, the
        # negation operation is split into a bitwise not and a +1.
        # likewise for 16, 32, and 64-bit values.

        # width-extended 1s complement if a is signed, otherwise zero
        inverted_hi = Cat(ext, ~self.op)
        comb += self.nt.eq(Mux(enabled, inverted_hi, 0))

        # add 1 if signed, otherwise add zero
        plus_one_hi = Cat(ext, enabled, Repl(0, bit_wid-1))
        comb += self.nl.eq(plus_one_hi)

        return m
567
568
class Part(Elaboratable):
    """ a key class which, depending on the partitioning, will determine
        what action to take when parts of the output are signed or unsigned.

        this requires 2 pieces of data *per operand, per partition*:
        whether the MSB is HI/LO (per partition!), and whether a signed
        or unsigned operation has been *requested*.

        once that is determined, signed is basically carried out
        by splitting 2's complement into 1's complement plus one.
        1's complement is just a bit-inversion.

        the extra terms - as separate terms - are then thrown at the
        AddReduce alongside the multiplication part-results.
    """
    def __init__(self, width, n_parts, n_levels, pbwid):
        """Create a ``Part``.

        :param width: bit width of the four output terms
        :param n_parts: number of partitions handled (8 // n_parts bytes
            each)
        :param n_levels: number of stages in the ``delayed_parts`` chain
        :param pbwid: bit width of the partition-bits input ``pbs``
        """

        # inputs
        self.a = Signal(64)
        self.b = Signal(64)
        self.a_signed = [Signal(name=f"a_signed_{i}") for i in range(8)]
        self.b_signed = [Signal(name=f"_b_signed_{i}") for i in range(8)]
        self.pbs = Signal(pbwid, reset_less=True)

        # outputs
        self.parts = [Signal(name=f"part_{i}") for i in range(n_parts)]
        self.delayed_parts = [
            [Signal(name=f"delayed_part_{delay}_{i}")
             for i in range(n_parts)]
            for delay in range(n_levels)]
        # XXX REALLY WEIRD BUG - have to take a copy of the last delayed_parts
        self.dplast = [Signal(name=f"dplast_{i}")
                       for i in range(n_parts)]

        self.not_a_term = Signal(width)
        self.neg_lsb_a_term = Signal(width)
        self.not_b_term = Signal(width)
        self.neg_lsb_b_term = Signal(width)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()

        pbs, parts, delayed_parts = self.pbs, self.parts, self.delayed_parts
        bytes_per_part = 8 // len(parts)

        # negated-temporary copy of partition bits
        # NOTE: the signal is named after this local variable.
        npbs = Signal.like(pbs, reset_less=True)
        m.d.comb += npbs.eq(~pbs)

        for i in range(len(parts)):
            first = i * bytes_per_part
            last = first + bytes_per_part - 1
            # a part is active when all bits gathered here are zero: the
            # *negated* partition bit just below and at the top of the
            # part, and the plain partition bits in between.
            # for i == 0 the lower index is -1, i.e. the topmost bit of npbs.
            pbl = [npbs[first - 1]]
            pbl.extend(pbs[j] for j in range(first, last))
            pbl.append(npbs[last])
            value = Signal(len(pbl), name="value_%di" % i, reset_less=True)
            m.d.comb += value.eq(Cat(*pbl))
            m.d.comb += parts[i].eq(~(value).bool())
            # delay chain: stage 0 is combinatorial, later stages registered
            m.d.comb += delayed_parts[0][i].eq(parts[i])
            m.d.sync += [later[i].eq(earlier[i])
                         for earlier, later
                         in zip(delayed_parts, delayed_parts[1:])]
            m.d.comb += self.dplast[i].eq(delayed_parts[-1][i])

        bit_wid = 8 * bytes_per_part  # bit width of one part
        nat, nbt, nla, nlb = [], [], [], []
        for i in range(len(parts)):
            lsb = bit_wid * i
            msb = (i + 1) * bit_wid - 1
            sign_index = i * bytes_per_part
            # each row: submodule name format, operand to invert, and the
            # sign flags / value of the *other* operand (yes, really the
            # other one: a's terms are gated by b's sign and vice versa),
            # then the lists collecting the two outputs.
            wiring = (
                ("lnt_%d_a_%d", self.a, self.b_signed, self.b, nat, nla),
                ("lnt_%d_b_%d", self.b, self.a_signed, self.a, nbt, nlb),
            )
            for name_fmt, operand, other_signed, other, nts, nls in wiring:
                # work out bit-inverted and +1 term for this operand
                lnt = LSBNegTerm(bit_wid)
                setattr(m.submodules, name_fmt % (bit_wid, i), lnt)
                m.d.comb += lnt.part.eq(parts[i])
                m.d.comb += lnt.op.eq(operand.bit_select(lsb, bit_wid))
                m.d.comb += lnt.signed.eq(other_signed[sign_index])
                m.d.comb += lnt.msb.eq(other[msb])
                nts.append(lnt.nt)
                nls.append(lnt.nl)

        # concatenate together and return all 4 results.
        m.d.comb += [self.not_a_term.eq(Cat(*nat)),
                     self.not_b_term.eq(Cat(*nbt)),
                     self.neg_lsb_a_term.eq(Cat(*nla)),
                     self.neg_lsb_b_term.eq(Cat(*nlb)),
                     ]

        return m
666
667
class IntermediateOut(Elaboratable):
    """ selects the HI/LO part of the multiplication, for a given bit-width
        the output is also reconstructed in its SIMD (partition) lanes.

        :attribute delayed_part_ops: per-byte operation codes
        :attribute intermed: the full-width intermediate product
        :attribute output: the selected halves, one per lane, concatenated
    """
    def __init__(self, width, out_wid, n_parts):
        """Create an ``IntermediateOut``.

        :param width: bit width of one output lane
        :param out_wid: bit width of ``intermed`` (``output`` is half that)
        :param n_parts: number of lanes
        """
        self.width = width
        self.n_parts = n_parts
        self.delayed_part_ops = [Signal(2, name="dpop%d" % i, reset_less=True)
                                 for i in range(8)]
        self.intermed = Signal(out_wid, reset_less=True)
        self.output = Signal(out_wid//2, reset_less=True)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()

        w = self.width
        sel = w // 8  # number of bytes (i.e. part_ops entries) per lane
        lanes = []
        for i in range(self.n_parts):
            # each lane owns 2*w bits of the intermediate: LO then HI half
            lo_start = i * w*2
            hi_start = lo_start + w
            # the operation of the lane's first byte decides for the lane
            want_low = self.delayed_part_ops[sel * i] == OP_MUL_LOW
            op = Signal(w, reset_less=True, name="op%d_%d" % (w, i))
            m.d.comb += op.eq(
                Mux(want_low,
                    self.intermed.bit_select(lo_start, w),
                    self.intermed.bit_select(hi_start, w)))
            lanes.append(op)
        m.d.comb += self.output.eq(Cat(*lanes))

        return m
696
697
class FinalOut(Elaboratable):
    """ selects the final output based on the partitioning.

        each byte is selectable independently, i.e. it is possible
        that some partitions requested 8-bit computation whilst others
        requested 16 or 32 bit.

        :attribute d8: per-byte select for the 8-bit result
        :attribute d16: per-halfword select for the 16-bit result
        :attribute d32: per-word select for the 32-bit result
        :attribute i8: result of the 8-bit partitioning
        :attribute i16: result of the 16-bit partitioning
        :attribute i32: result of the 32-bit partitioning
        :attribute i64: result of the 64-bit partitioning
        :attribute out: the assembled output
    """
    def __init__(self, out_wid):
        """Create a ``FinalOut`` with ``out_wid``-bit inputs and output."""
        # inputs
        self.d8 = [Signal(name=f"d8_{i}", reset_less=True) for i in range(8)]
        self.d16 = [Signal(name=f"d16_{i}", reset_less=True) for i in range(4)]
        self.d32 = [Signal(name=f"d32_{i}", reset_less=True) for i in range(2)]

        self.i8 = Signal(out_wid, reset_less=True)
        self.i16 = Signal(out_wid, reset_less=True)
        self.i32 = Signal(out_wid, reset_less=True)
        self.i64 = Signal(out_wid, reset_less=True)

        # output
        self.out = Signal(out_wid, reset_less=True)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        selected_bytes = []
        for i in range(8):
            # select one of the outputs: d8 selects i8, d16 selects i16
            # d32 selects i32, and the default is i64.
            # d8 and d16 are ORed together in the first Mux
            # then the 2nd selects either i8 or i16.
            # if neither d8 nor d16 are set, d32 selects either i32 or i64.
            start = i * 8
            is_8 = self.d8[i]
            is_8_or_16 = is_8 | self.d16[i // 2]
            narrow = Mux(is_8, self.i8.bit_select(start, 8),
                         self.i16.bit_select(start, 8))
            wide = Mux(self.d32[i // 4], self.i32.bit_select(start, 8),
                       self.i64.bit_select(start, 8))
            op = Signal(8, reset_less=True, name="op_%d" % i)
            m.d.comb += op.eq(Mux(is_8_or_16, narrow, wide))
            selected_bytes.append(op)
        m.d.comb += self.out.eq(Cat(*selected_bytes))
        return m
738
739
class OrMod(Elaboratable):
    """ ORs four values together in a hierarchical tree

        :attribute orin: list of the four ``wid``-bit inputs
        :attribute orout: the ``wid``-bit OR of all four inputs
    """
    def __init__(self, wid):
        """Create an ``OrMod`` for ``wid``-bit values."""
        self.wid = wid
        self.orin = [Signal(wid, name="orin%d" % i, reset_less=True)
                     for i in range(4)]
        self.orout = Signal(wid, reset_less=True)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        # NOTE: or1/or2 are named after these local variables.
        or1 = Signal(self.wid, reset_less=True)
        or2 = Signal(self.wid, reset_less=True)
        in0, in1, in2, in3 = self.orin
        m.d.comb += [
            # first tree level: two pairwise ORs
            or1.eq(in0 | in1),
            or2.eq(in2 | in3),
            # second tree level: combine the pairs
            self.orout.eq(or1 | or2),
        ]

        return m
758
759
class Signs(Elaboratable):
    """ determines whether a or b are signed numbers
        based on the required operation type (OP_MUL_*)

        :attribute part_ops: the 2-bit operation code
        :attribute a_signed: set unless the operation is
            ``OP_MUL_UNSIGNED_HIGH``
        :attribute b_signed: set when the operation is ``OP_MUL_LOW`` or
            ``OP_MUL_SIGNED_HIGH``
    """

    def __init__(self):
        """Create a ``Signs`` decoder."""
        self.part_ops = Signal(2, reset_less=True)
        self.a_signed = Signal(reset_less=True)
        self.b_signed = Signal(reset_less=True)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()

        op = self.part_ops
        is_low = op == OP_MUL_LOW
        is_signed_high = op == OP_MUL_SIGNED_HIGH
        m.d.comb += [
            # a is unsigned only for the fully-unsigned high multiply
            self.a_signed.eq(op != OP_MUL_UNSIGNED_HIGH),
            self.b_signed.eq(is_low | is_signed_high),
        ]

        return m
781
782
class Mul8_16_32_64(Elaboratable):
    """Signed/Unsigned 8/16/32/64-bit partitioned integer multiplier.

    Supports partitioning into any combination of 8, 16, 32, and 64-bit
    partitions on naturally-aligned boundaries. Supports the operation being
    set for each partition independently.

    :attribute part_pts: the input partition points. Has a partition point at
        multiples of 8 in 0 < i < 64. Each partition point's associated
        ``Value`` is a ``Signal``. Modification not supported, except for by
        ``Signal.eq``.
    :attribute part_ops: the operation for each byte. The operation for a
        particular partition is selected by assigning the selected operation
        code to each byte in the partition. The allowed operation codes are:

        :attribute OP_MUL_LOW: the LSB half of the product. Equivalent to
            RISC-V's `mul` instruction.
        :attribute OP_MUL_SIGNED_HIGH: the MSB half of the product where both
            ``a`` and ``b`` are signed. Equivalent to RISC-V's `mulh`
            instruction.
        :attribute OP_MUL_SIGNED_UNSIGNED_HIGH: the MSB half of the product
            where ``a`` is signed and ``b`` is unsigned. Equivalent to RISC-V's
            `mulhsu` instruction.
        :attribute OP_MUL_UNSIGNED_HIGH: the MSB half of the product where both
            ``a`` and ``b`` are unsigned. Equivalent to RISC-V's `mulhu`
            instruction.
    """

    def __init__(self, register_levels=()):
        """ register_levels: specifies the points in the cascade at which
            flip-flops are to be inserted.
        """

        # parameter(s)
        self.register_levels = list(register_levels)

        # inputs
        self.part_pts = PartitionPoints()
        for i in range(8, 64, 8):
            self.part_pts[i] = Signal(name=f"part_pts_{i}")
        self.part_ops = [Signal(2, name=f"part_ops_{i}") for i in range(8)]
        self.a = Signal(64)
        self.b = Signal(64)

        # intermediates (needed for unit tests)
        self._intermediate_output = Signal(128)

        # output
        self.output = Signal(64)

    def _part_byte(self, index):
        """Return the partition bit at the top of byte ``index``.

        Indices -1 and 7 (the outer edges) are constant "partitioned";
        every other byte uses the partition point above it.
        """
        if index in (-1, 7):
            return C(True, 1)
        assert 0 <= index < 8
        return self.part_pts[index * 8 + 8]

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        n_regs = len(self.register_levels)

        # collect part-bytes
        # NOTE: the signal is named after this local variable.
        pbs = Signal(8, reset_less=True)
        tl = []
        for i in range(8):
            pb = Signal(name="pb%d" % i, reset_less=True)
            m.d.comb += pb.eq(self._part_byte(i))
            tl.append(pb)
        m.d.comb += pbs.eq(Cat(*tl))

        # one sign decoder per byte
        signs = []
        for i in range(8):
            s = Signs()
            signs.append(s)
            setattr(m.submodules, "signs%d" % i, s)
            m.d.comb += s.part_ops.eq(self.part_ops[i])

        # the operation codes are delayed by one register stage per
        # register level
        delayed_part_ops = [
            [Signal(2, name=f"_delayed_part_ops_{delay}_{i}")
             for i in range(8)]
            for delay in range(1 + n_regs)]
        for i, part_op in enumerate(self.part_ops):
            m.d.comb += delayed_part_ops[0][i].eq(part_op)
            m.d.sync += [later[i].eq(earlier[i])
                         for earlier, later
                         in zip(delayed_part_ops, delayed_part_ops[1:])]

        n_levels = n_regs + 1
        m.submodules.part_8 = part_8 = Part(128, 8, n_levels, 8)
        m.submodules.part_16 = part_16 = Part(128, 4, n_levels, 8)
        m.submodules.part_32 = part_32 = Part(128, 2, n_levels, 8)
        m.submodules.part_64 = part_64 = Part(128, 1, n_levels, 8)
        nat_l, nbt_l, nla_l, nlb_l = [], [], [], []
        for mod in (part_8, part_16, part_32, part_64):
            m.d.comb += mod.a.eq(self.a)
            m.d.comb += mod.b.eq(self.b)
            for i, sign in enumerate(signs):
                m.d.comb += mod.a_signed[i].eq(sign.a_signed)
                m.d.comb += mod.b_signed[i].eq(sign.b_signed)
            m.d.comb += mod.pbs.eq(pbs)
            nat_l.append(mod.not_a_term)
            nbt_l.append(mod.not_b_term)
            nla_l.append(mod.neg_lsb_a_term)
            nlb_l.append(mod.neg_lsb_b_term)

        # all byte-by-byte product terms
        terms = []
        for a_index in range(8):
            t = ProductTerms(8, 128, 8, a_index, 8)
            setattr(m.submodules, "terms_%d" % a_index, t)

            m.d.comb += t.a.eq(self.a)
            m.d.comb += t.b.eq(self.b)
            m.d.comb += t.pb_en.eq(pbs)

            terms.extend(t.terms)

        # it's fine to bitwise-or data together since they are never enabled
        # at the same time
        for sub_name, sources in (("nat_or", nat_l),
                                  ("nbt_or", nbt_l),
                                  ("nla_or", nla_l),
                                  ("nlb_or", nlb_l)):
            or_mod = OrMod(128)
            setattr(m.submodules, sub_name, or_mod)
            for i, source in enumerate(sources):
                m.d.comb += or_mod.orin[i].eq(source)
            terms.append(or_mod.orout)

        # partition points for the (double-width) product
        expanded_part_pts = PartitionPoints()
        for i, v in self.part_pts.items():
            signal = Signal(name=f"expanded_part_pts_{i*2}", reset_less=True)
            expanded_part_pts[i * 2] = signal
            m.d.comb += signal.eq(v)

        add_reduce = AddReduce(terms,
                               128,
                               self.register_levels,
                               expanded_part_pts)
        m.submodules.add_reduce = add_reduce
        m.d.comb += self._intermediate_output.eq(add_reduce.output)

        # create the per-lane-width outputs (_output_64, _output_32,
        # _output_16 and _output_8), each fed with the most-delayed
        # operation codes
        lane_outputs = {}
        for lane_width, lane_count in ((64, 1), (32, 2), (16, 4), (8, 8)):
            io = IntermediateOut(lane_width, 128, lane_count)
            setattr(m.submodules, "io%d" % lane_width, io)
            m.d.comb += io.intermed.eq(self._intermediate_output)
            for i in range(8):
                m.d.comb += io.delayed_part_ops[i].eq(delayed_part_ops[-1][i])
            lane_outputs[lane_width] = io

        # final output
        m.submodules.finalout = finalout = FinalOut(64)
        for select, delayed in zip(finalout.d8, part_8.dplast):
            m.d.comb += select.eq(delayed)
        for select, delayed in zip(finalout.d16, part_16.dplast):
            m.d.comb += select.eq(delayed)
        for select, delayed in zip(finalout.d32, part_32.dplast):
            m.d.comb += select.eq(delayed)
        m.d.comb += finalout.i8.eq(lane_outputs[8].output)
        m.d.comb += finalout.i16.eq(lane_outputs[16].output)
        m.d.comb += finalout.i32.eq(lane_outputs[32].output)
        m.d.comb += finalout.i64.eq(lane_outputs[64].output)
        m.d.comb += self.output.eq(finalout.out)

        return m
964
965
if __name__ == "__main__":
    # hand the multiplier and its ports to the nmigen command-line entry
    # point
    m = Mul8_16_32_64()
    main(m, ports=[m.a, m.b, m._intermediate_output, m.output]
         + list(m.part_ops)
         + list(m.part_pts.values()))