add docstrings / comments to PartitionedAdder
[ieee754fpu.git] / src / ieee754 / part_mul_add / multiply.py
1 # SPDX-License-Identifier: LGPL-2.1-or-later
2 # See Notices.txt for copyright information
3 """Integer Multiplication."""
4
5 from nmigen import Signal, Module, Value, Elaboratable, Cat, C, Mux, Repl
6 from nmigen.hdl.ast import Assign
7 from abc import ABCMeta, abstractmethod
8 from nmigen.cli import main
9 from functools import reduce
10 from operator import or_
11
12
class PartitionPoints(dict):
    """Mapping of partition-point bit index to its enable ``Value``.

    Describes where an ALU may be split, together with the ``Value``s that
    say whether each of those split points is currently active.

    For example: ``{1: True, 5: True, 10: True}`` with
    ``width == 16`` specifies that the ALU is split into 4 sections:
    * bits 0 <= ``i`` < 1
    * bits 1 <= ``i`` < 5
    * bits 5 <= ``i`` < 10
    * bits 10 <= ``i`` < 16

    If the partition_points were instead ``{1: True, 5: a, 10: True}``
    where ``a`` is a 1-bit ``Signal``:
    * If ``a`` is asserted:
        * bits 0 <= ``i`` < 1
        * bits 1 <= ``i`` < 5
        * bits 5 <= ``i`` < 10
        * bits 10 <= ``i`` < 16
    * Otherwise
        * bits 0 <= ``i`` < 1
        * bits 1 <= ``i`` < 10
        * bits 10 <= ``i`` < 16
    """

    def __init__(self, partition_points=None):
        """Build a ``PartitionPoints`` from an optional mapping.

        :param partition_points: mapping of point (non-negative ``int``) to
            enable value; every value is passed through ``Value.wrap``.
        :raises TypeError: if a point is not an ``int``.
        :raises ValueError: if a point is negative.
        """
        super().__init__()
        if partition_points is None:
            return
        for point, enabled in partition_points.items():
            if not isinstance(point, int):
                raise TypeError("point must be a non-negative integer")
            if point < 0:
                raise ValueError("point must be a non-negative integer")
            self[point] = Value.wrap(enabled)

    def like(self, name=None, src_loc_at=0):
        """Return a new ``PartitionPoints`` holding fresh ``Signal``s.

        Each new ``Signal`` has the shape of the corresponding value in
        ``self`` and is named ``{name}_{point}``.

        :param name: the base name for the new ``Signal``s. When omitted,
            the name is taken from a throw-away ``Signal`` created with
            ``src_loc_at`` pointing at the caller.
        :param src_loc_at: extra stack depth used for that name lookup.
        """
        if name is None:
            # must stay directly in this frame: the lookup depth is relative
            name = Signal(src_loc_at=1+src_loc_at).name  # get variable name
        clone = PartitionPoints()
        for point, enabled in self.items():
            clone[point] = Signal(enabled.shape(), name=f"{name}_{point}")
        return clone

    def eq(self, rhs):
        """Yield one ``enabled.eq(rhs[point])`` assignment per point.

        This is a generator: the key-set check (and its ``ValueError``)
        happens when iteration starts, not when ``eq`` is called.

        :raises ValueError: if ``rhs`` does not have the same set of points.
        """
        if set(self) != set(rhs.keys()):
            raise ValueError("incompatible point set")
        for point in self:
            yield self[point].eq(rhs[point])

    def as_mask(self, width):
        """Create a bit-mask from `self`.

        A mask bit is the inverse of the enable value when a partition point
        exists at that bit-index, and constant ``True`` otherwise.

        :param width: the bit width of the resulting mask
        """
        mask_bits = [~self[i] if i in self else True for i in range(width)]
        return Cat(*mask_bits)

    def get_max_partition_count(self, width):
        """Return the partition count when every point is enabled.

        Only points strictly below ``width`` are counted.
        """
        return 1 + sum(1 for point in self if point < width)

    def fits_in_width(self, width):
        """Return ``True`` when every partition point is below `width`."""
        return not any(point >= width for point in self)
105
106
class FullAdder(Elaboratable):
    """Bit-parallel full adder (3-to-2 compressor).

    :attribute in0: the first input
    :attribute in1: the second input
    :attribute in2: the third input
    :attribute sum: the sum output
    :attribute carry: the carry output

    Instead of instantiating an array of 1-bit full adders (which would be
    very slow to simulate), all ports are ``width`` bits wide, so a single
    instance performs ``width`` independent Full 3-2 Add operations
    "in parallel". Note that ``carry`` is *not* shifted here.
    """

    def __init__(self, width):
        """Create a ``FullAdder``.

        :param width: the bit width of the input and output
        """
        self.in0 = Signal(width)
        self.in1 = Signal(width)
        self.in2 = Signal(width)
        self.sum = Signal(width)
        self.carry = Signal(width)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        # sum is the bitwise parity of the three inputs
        parity = self.in0 ^ self.in1 ^ self.in2
        # carry is the bitwise majority of the three inputs
        majority = ((self.in0 & self.in1)
                    | (self.in1 & self.in2)
                    | (self.in2 & self.in0))
        m.d.comb += [
            self.sum.eq(parity),
            self.carry.eq(majority),
        ]
        return m
141
142
class PartitionedAdder(Elaboratable):
    """Partitioned Adder.

    Performs the final add using one wide addition. An extra bit is
    inserted into the operands at every partition point (driven in one of
    the operands only); that bit either lets a carry roll over to the next
    bit or absorbs it. The extra bits are then *removed* again when the
    result is copied to ``output``.

    :attribute width: the bit width of the input and output. Read-only.
    :attribute a: the first input to the adder
    :attribute b: the second input to the adder
    :attribute output: the sum output
    :attribute partition_points: the input partition points. Modification not
        supported, except for by ``Signal.eq``.
    """

    def __init__(self, width, partition_points):
        """Create a ``PartitionedAdder``.

        :param width: the bit width of the input and output
        :param partition_points: the input partition points
        :raises ValueError: if a partition point is not below ``width``.
        """
        self.width = width
        self.a = Signal(width)
        self.b = Signal(width)
        self.output = Signal(width)
        self.partition_points = PartitionPoints(partition_points)
        if not self.partition_points.fits_in_width(width):
            raise ValueError("partition_points doesn't fit in width")
        # one bit per data bit, plus one extra bit per partition point
        self._expanded_width = sum(
            2 if bit in self.partition_points else 1
            for bit in range(self.width))
        # XXX these have to remain here due to some horrible nmigen
        # simulation bugs involving sync.  it is *not* necessary to
        # have them here, they should (under normal circumstances)
        # be moved into elaborate, as they are entirely local
        self._expanded_a = Signal(self._expanded_width)
        self._expanded_b = Signal(self._expanded_width)
        self._expanded_output = Signal(self._expanded_width)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()

        # bits are gathered in lists and joined with Cat afterwards, which
        # keeps the graphviz output much cleaner.
        #   *_src lists: values being read,  *_dst lists: bits being driven
        a_src, exp_a_dst = [], []
        b_src, exp_b_dst = [], []
        exp_out_src, out_dst = [], []

        # partition points are "breaks" (extra zeros or 1s) in what would
        # otherwise be a massive long add.  when the "break" points are 0,
        # whatever is in it (in the output) is discarded.  however when
        # there is a "1", it causes a roll-over carry to the *next* bit.
        # we still ignore the "break" bit in the [intermediate] output,
        # however by that time we've got the effect that we wanted: the
        # carry has been carried *over* the break point.

        pos = 0  # current bit position within the expanded signals
        for bit in range(self.width):
            if bit in self.partition_points:
                # extra bit: 0 + 0 for enabled partition points
                # and 1 + 0 for disabled partition points
                exp_a_dst.append(self._expanded_a[pos])
                a_src.append(~self.partition_points[bit])  # extra bit in a
                exp_b_dst.append(self._expanded_b[pos])
                b_src.append(C(0))  # do *not* add extra bit into b.
                pos += 1
            # the ordinary data bit
            exp_a_dst.append(self._expanded_a[pos])
            a_src.append(self.a[bit])
            exp_b_dst.append(self._expanded_b[pos])
            b_src.append(self.b[bit])
            exp_out_src.append(self._expanded_output[pos])
            out_dst.append(self.output[bit])
            pos += 1

        # combine above using Cat
        m.d.comb += [
            Cat(*exp_a_dst).eq(Cat(*a_src)),
            Cat(*exp_b_dst).eq(Cat(*b_src)),
            Cat(*out_dst).eq(Cat(*exp_out_src)),
        ]

        # use only one addition to take advantage of look-ahead carry and
        # special hardware on FPGAs
        m.d.comb += self._expanded_output.eq(
            self._expanded_a + self._expanded_b)
        return m
228
229
# Number of inputs consumed by one full adder; AddReduce walks its input list
# in groups of this size when building a reduction level.
FULL_ADDER_INPUT_COUNT = 3
231
232
class AddReduce(Elaboratable):
    """Add list of numbers together.

    Each level groups the inputs in threes, feeds every group through a
    ``FullAdder`` (producing a sum term and a shifted, partition-masked carry
    term), passes any left-over inputs through unchanged, and hands the
    resulting terms to a nested ``AddReduce``. Once two or fewer terms
    remain they are summed by a ``PartitionedAdder`` (or passed straight
    through for zero/one terms).

    :attribute inputs: input ``Signal``s to be summed. Modification not
        supported, except for by ``Signal.eq``.
    :attribute register_levels: List of nesting levels that should have
        pipeline registers.
    :attribute output: output sum.
    :attribute partition_points: the input partition points. Modification not
        supported, except for by ``Signal.eq``.
    """

    def __init__(self, inputs, output_width, register_levels, partition_points):
        """Create an ``AddReduce``.

        :param inputs: input ``Signal``s to be summed.
        :param output_width: bit-width of ``output``.
        :param register_levels: List of nesting levels that should have
            pipeline registers.
        :param partition_points: the input partition points.
        :raises ValueError: if a partition point is not below
            ``output_width``, or a register level exceeds the maximum level
            for this number of inputs.
        """
        self.inputs = list(inputs)
        self._resized_inputs = [
            Signal(output_width, name=f"resized_inputs[{i}]")
            for i in range(len(self.inputs))]
        self.register_levels = list(register_levels)
        self.output = Signal(output_width)
        self.partition_points = PartitionPoints(partition_points)
        if not self.partition_points.fits_in_width(output_width):
            raise ValueError("partition_points doesn't fit in output_width")
        self._reg_partition_points = self.partition_points.like()
        max_level = AddReduce.get_max_level(len(self.inputs))
        for level in self.register_levels:
            if level > max_level:
                raise ValueError(
                    "not enough adder levels for specified register levels")

    @staticmethod
    def get_max_level(input_count):
        """Get the maximum level.

        All ``register_levels`` must be less than or equal to the maximum
        level.

        :param input_count: number of inputs at the outermost level.
        :returns: how many full-adder levels are needed before no group of
            ``FULL_ADDER_INPUT_COUNT`` inputs remains.
        """
        retval = 0
        while True:
            groups = AddReduce.full_adder_groups(input_count)
            if len(groups) == 0:
                return retval
            # each group of 3 becomes 2 terms; the remainder passes through
            input_count %= FULL_ADDER_INPUT_COUNT
            input_count += 2 * len(groups)
            retval += 1

    def next_register_levels(self):
        """``Iterable`` of ``register_levels`` for next recursive level."""
        for level in self.register_levels:
            if level > 0:
                yield level - 1

    @staticmethod
    def full_adder_groups(input_count):
        """Get ``inputs`` indices for which a full adder should be built.

        :returns: a ``range`` of the starting index of every complete group
            of ``FULL_ADDER_INPUT_COUNT`` inputs.
        """
        return range(0,
                     input_count - FULL_ADDER_INPUT_COUNT + 1,
                     FULL_ADDER_INPUT_COUNT)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()

        # resize inputs to correct bit-width and optionally add in
        # pipeline registers
        resized_input_assignments = [
            self._resized_inputs[i].eq(value)
            for i, value in enumerate(self.inputs)]
        if 0 in self.register_levels:
            m.d.sync += resized_input_assignments
            m.d.sync += self._reg_partition_points.eq(self.partition_points)
        else:
            m.d.comb += resized_input_assignments
            m.d.comb += self._reg_partition_points.eq(self.partition_points)

        groups = AddReduce.full_adder_groups(len(self.inputs))
        # if there are no full adders to create, then we handle the base cases
        # and return, otherwise we go on to the recursive case
        if len(groups) == 0:
            if len(self.inputs) == 0:
                # use 0 as the default output value
                m.d.comb += self.output.eq(0)
            elif len(self.inputs) == 1:
                # handle single input
                m.d.comb += self.output.eq(self._resized_inputs[0])
            else:
                # base case for adding 2 or more inputs, which get recursively
                # reduced to 2 inputs
                assert len(self.inputs) == 2
                adder = PartitionedAdder(len(self.output),
                                         self._reg_partition_points)
                m.submodules.final_adder = adder
                m.d.comb += adder.a.eq(self._resized_inputs[0])
                m.d.comb += adder.b.eq(self._resized_inputs[1])
                m.d.comb += self.output.eq(adder.output)
            return m
        # go on to handle recursive case
        intermediate_terms = []

        def add_intermediate_term(value):
            # latch ``value`` into a named signal and queue it for next level
            intermediate_term = Signal(
                len(self.output),
                name=f"intermediate_terms[{len(intermediate_terms)}]")
            intermediate_terms.append(intermediate_term)
            m.d.comb += intermediate_term.eq(value)

        # store mask in intermediary (simplifies graph)
        part_mask = Signal(len(self.output), reset_less=True)
        mask = self._reg_partition_points.as_mask(len(self.output))
        m.d.comb += part_mask.eq(mask)

        # create full adders for this recursive level.
        # this shrinks N terms to 2 * (N // 3) plus the remainder
        for i in groups:
            adder_i = FullAdder(len(self.output))
            setattr(m.submodules, f"adder_{i}", adder_i)
            m.d.comb += adder_i.in0.eq(self._resized_inputs[i])
            m.d.comb += adder_i.in1.eq(self._resized_inputs[i + 1])
            m.d.comb += adder_i.in2.eq(self._resized_inputs[i + 2])
            add_intermediate_term(adder_i.sum)
            # the carry out of bit n belongs to bit n + 1
            shifted_carry = adder_i.carry << 1
            # mask out carry bits to prevent carries between partitions
            add_intermediate_term(shifted_carry & part_mask)
        # handle the remaining inputs.
        if len(self.inputs) % FULL_ADDER_INPUT_COUNT == 1:
            add_intermediate_term(self._resized_inputs[-1])
        elif len(self.inputs) % FULL_ADDER_INPUT_COUNT == 2:
            # Just pass the terms to the next layer, since we wouldn't gain
            # anything by using a half adder since there would still be 2 terms
            # and just passing the terms to the next layer saves gates.
            add_intermediate_term(self._resized_inputs[-2])
            add_intermediate_term(self._resized_inputs[-1])
        else:
            assert len(self.inputs) % FULL_ADDER_INPUT_COUNT == 0
        # recursive invocation of ``AddReduce``
        next_level = AddReduce(intermediate_terms,
                               len(self.output),
                               self.next_register_levels(),
                               self._reg_partition_points)
        m.submodules.next_level = next_level
        m.d.comb += self.output.eq(next_level.output)
        return m
381
382
# Per-byte operation codes accepted by Mul8_16_32_64.part_ops (2-bit values).
OP_MUL_LOW = 0  # LSB half of the product
OP_MUL_SIGNED_HIGH = 1  # MSB half of the product, a and b both signed
OP_MUL_SIGNED_UNSIGNED_HIGH = 2  # a is signed, b is unsigned
OP_MUL_UNSIGNED_HIGH = 3  # MSB half of the product, a and b both unsigned
387
388
def get_term(value, shift=0, enabled=None):
    """Optionally gate ``value`` and shift it left by ``shift`` bits.

    :param value: the term to process.
    :param shift: number of zero bits to place below ``value``; must be
        non-negative (checked with ``assert``).
    :param enabled: when not ``None``, ``value`` is replaced by
        ``Mux(enabled, value, 0)`` before shifting.
    :returns: the gated/shifted term; ``value`` itself when ``shift`` is 0
        and ``enabled`` is ``None``.
    """
    if enabled is not None:
        value = Mux(enabled, value, 0)
    if shift > 0:
        # prepend ``shift`` zero bits below the value
        return Cat(Repl(C(0, 1), shift), value)
    assert shift == 0
    return value
397
398
class ProductTerm(Elaboratable):
    """ this class creates a single product term (a[..]*b[..]).
        it has a design flaw in that it is the *output* that is selected,
        whereas the multiplication(s) are combinatorially generated
        all the time.
    """

    def __init__(self, width, twidth, pbwid, a_index, b_index):
        """ width: bit width of each selected operand slice
            twidth: bit width of the output ``term``; ``a`` and ``b`` are
                    half of this
            pbwid: bit width of the partition-enable input ``pb_en``
            a_index, b_index: which slice of ``a`` / ``b`` to multiply
        """
        self.a_index = a_index
        self.b_index = b_index
        self.pwidth = width
        self.twidth = twidth
        self.width = width*2
        self.shift = 8 * (self.a_index + self.b_index)

        self.ti = Signal(self.width, reset_less=True)
        self.term = Signal(twidth, reset_less=True)
        self.a = Signal(twidth//2, reset_less=True)
        self.b = Signal(twidth//2, reset_less=True)
        self.pb_en = Signal(pbwid, reset_less=True)

        # partition-enable bits lying between the two slice indices
        lo_index = min(self.a_index, self.b_index)
        hi_index = max(self.a_index, self.b_index)
        self.tl = [self.pb_en[i] for i in range(lo_index, hi_index)]

        # an enable signal only exists when there are bits to look at
        if len(self.tl) > 0:
            self.enabled = Signal(
                name="te_%d_%d" % (self.a_index, self.b_index),
                reset_less=True)
        else:
            self.enabled = None
        self.term.name = "term_%d_%d" % (a_index, b_index)  # rename

    def elaborate(self, platform):
        """ multiplies the selected slices and drives ``term`` with the
            shifted (and, if applicable, gated) product
        """
        m = Module()

        # the term is enabled only when none of the collected bits is set
        if self.enabled is not None:
            m.d.comb += self.enabled.eq(~(Cat(*self.tl).bool()))

        bsa = Signal(self.width, reset_less=True)
        bsb = Signal(self.width, reset_less=True)
        pwidth = self.pwidth
        a_slice = self.a.bit_select(self.a_index * pwidth, pwidth)
        b_slice = self.b.bit_select(self.b_index * pwidth, pwidth)
        m.d.comb += [
            bsa.eq(a_slice),
            bsb.eq(b_slice),
            self.ti.eq(bsa * bsb),
            self.term.eq(get_term(self.ti, self.shift, self.enabled)),
        ]

        # TODO: sort out width issues, get inputs a/b switched on/off.
        # data going into Muxes is 1/2 the required width
        #
        # pwidth = self.pwidth
        # width = self.width
        # bsa = Signal(self.twidth//2, reset_less=True)
        # bsb = Signal(self.twidth//2, reset_less=True)
        # asel = Signal(width, reset_less=True)
        # bsel = Signal(width, reset_less=True)
        # a_index, b_index = self.a_index, self.b_index
        # m.d.comb += asel.eq(self.a.bit_select(a_index * pwidth, pwidth))
        # m.d.comb += bsel.eq(self.b.bit_select(b_index * pwidth, pwidth))
        # m.d.comb += bsa.eq(get_term(asel, self.shift, self.enabled))
        # m.d.comb += bsb.eq(get_term(bsel, self.shift, self.enabled))
        # m.d.comb += self.ti.eq(bsa * bsb)
        # m.d.comb += self.term.eq(self.ti)

        return m
468
469
class ProductTerms(Elaboratable):
    """ creates a bank of product terms.  also performs the actual bit-selection
        this class is to be wrapped with a for-loop on the "a" operand.
        it creates a second-level for-loop on the "b" operand.
    """
    def __init__(self, width, twidth, pbwid, a_index, blen):
        """ width: bit width of each operand slice
            twidth: bit width of each output term; ``a``/``b`` are half this
            pbwid: bit width of the partition-enable input ``pb_en``
            a_index: the fixed slice index into ``a``
            blen: number of ``b`` slices, i.e. number of terms produced
        """
        self.a_index = a_index
        self.blen = blen
        self.pwidth = width
        self.twidth = twidth
        self.pbwid = pbwid
        self.a = Signal(twidth//2, reset_less=True)
        self.b = Signal(twidth//2, reset_less=True)
        self.pb_en = Signal(pbwid, reset_less=True)
        self.terms = [Signal(twidth, name="term%d" % i, reset_less=True)
                      for i in range(blen)]

    def elaborate(self, platform):
        """ instantiates one ``ProductTerm`` per ``b`` slice and wires it up
        """
        m = Module()

        for b_index in range(self.blen):
            product = ProductTerm(self.pwidth, self.twidth, self.pbwid,
                                  self.a_index, b_index)
            setattr(m.submodules, "term_%d" % b_index, product)

            # every term sees the full operands and enable bits
            m.d.comb += [
                product.a.eq(self.a),
                product.b.eq(self.b),
                product.pb_en.eq(self.pb_en),
                self.terms[b_index].eq(product.term),
            ]

        return m
503
class LSBNegTerm(Elaboratable):
    """ produces the two correction terms for one partition of one operand:
        ``nt``: the bit-inverted ``op`` placed in the upper half (or zero)
        ``nl``: the "+1" bit, placed at the bottom of the upper half

        both are only non-zero when ``part``, ``msb`` and ``signed`` are
        all asserted.
    """

    def __init__(self, bit_width):
        """ bit_width: width of ``op``; ``nt`` and ``nl`` are twice as wide
        """
        self.bit_width = bit_width
        self.part = Signal(reset_less=True)
        self.signed = Signal(reset_less=True)
        self.op = Signal(bit_width, reset_less=True)
        self.msb = Signal(reset_less=True)
        self.nt = Signal(bit_width*2, reset_less=True)
        self.nl = Signal(bit_width*2, reset_less=True)

    def elaborate(self, platform):
        m = Module()
        half = self.bit_width
        low_zeros = Repl(0, half)  # extend output to HI part

        # determine sign of each incoming number *in this partition*
        enabled = Signal(reset_less=True)
        m.d.comb += enabled.eq(self.part & self.msb & self.signed)

        # for 8-bit values: form a * 0xFF00 by using -a * 0x100, the
        # negation operation is split into a bitwise not and a +1.
        # likewise for 16, 32, and 64-bit values.

        # width-extended 1s complement if a is signed, otherwise zero
        ones_complement = Mux(enabled, Cat(low_zeros, ~self.op), 0)
        # add 1 if signed, otherwise add zero
        plus_one = Cat(low_zeros, enabled, Repl(0, half-1))

        m.d.comb += self.nt.eq(ones_complement)
        m.d.comb += self.nl.eq(plus_one)

        return m
536
537
class Part(Elaboratable):
    """ a key class which, depending on the partitioning, will determine
        what action to take when parts of the output are signed or unsigned.

        this requires 2 pieces of data *per operand, per partition*:
        whether the MSB is HI/LO (per partition!), and whether a signed
        or unsigned operation has been *requested*.

        once that is determined, signed is basically carried out
        by splitting 2's complement into 1's complement plus one.
        1's complement is just a bit-inversion.

        the extra terms - as separate terms - are then thrown at the
        AddReduce alongside the multiplication part-results.
    """
    def __init__(self, width, n_parts, n_levels, pbwid):
        """ width: bit width of the four output term signals
            n_parts: number of partitions handled by this instance
            n_levels: number of entries in the ``delayed_parts`` chain
            pbwid: bit width of the partition-bits input ``pbs``
        """

        # inputs
        self.a = Signal(64)
        self.b = Signal(64)
        self.a_signed = [Signal(name=f"a_signed_{i}") for i in range(8)]
        self.b_signed = [Signal(name=f"_b_signed_{i}") for i in range(8)]
        self.pbs = Signal(pbwid, reset_less=True)

        # outputs
        # parts[i]: asserted when partition i is active at this size
        self.parts = [Signal(name=f"part_{i}") for i in range(n_parts)]
        # delayed_parts[d][i]: parts[i] after d sync-domain register stages
        self.delayed_parts = [
            [Signal(name=f"delayed_part_{delay}_{i}")
             for i in range(n_parts)]
            for delay in range(n_levels)]
        # XXX REALLY WEIRD BUG - have to take a copy of the last delayed_parts
        self.dplast = [Signal(name=f"dplast_{i}")
                       for i in range(n_parts)]

        # the four correction terms handed on for summation
        self.not_a_term = Signal(width)
        self.neg_lsb_a_term = Signal(width)
        self.not_b_term = Signal(width)
        self.neg_lsb_b_term = Signal(width)

    def elaborate(self, platform):
        """ works out which partitions are active, delays that information,
            and builds the per-partition correction terms
        """
        m = Module()

        pbs, parts, delayed_parts = self.pbs, self.parts, self.delayed_parts
        # negated-temporary copy of partition bits
        npbs = Signal.like(pbs, reset_less=True)
        m.d.comb += npbs.eq(~pbs)
        byte_count = 8 // len(parts)
        for i in range(len(parts)):
            # pbl collects bits that must ALL be zero for parts[i] to be set:
            # the inverted boundary bit below the partition, the (plain)
            # bits inside it, and the inverted boundary bit at its top.
            pbl = []
            # NOTE(review): for i == 0 this index is -1, which presumably
            # selects the top bit of npbs -- confirm this is intended
            pbl.append(npbs[i * byte_count - 1])
            for j in range(i * byte_count, (i + 1) * byte_count - 1):
                pbl.append(pbs[j])
            pbl.append(npbs[(i + 1) * byte_count - 1])
            value = Signal(len(pbl), name="value_%di" % i, reset_less=True)
            m.d.comb += value.eq(Cat(*pbl))
            m.d.comb += parts[i].eq(~(value).bool())
            # first entry of the delay chain is combinatorial, the rest
            # are registered in the sync domain
            m.d.comb += delayed_parts[0][i].eq(parts[i])
            m.d.sync += [delayed_parts[j + 1][i].eq(delayed_parts[j][i])
                         for j in range(len(delayed_parts)-1)]
            m.d.comb += self.dplast[i].eq(delayed_parts[-1][i])

        not_a_term, neg_lsb_a_term, not_b_term, neg_lsb_b_term = \
                self.not_a_term, self.neg_lsb_a_term, \
                self.not_b_term, self.neg_lsb_b_term

        byte_width = 8 // len(parts) # byte width
        bit_wid = 8 * byte_width     # bit width
        nat, nbt, nla, nlb = [], [], [], []
        for i in range(len(parts)):
            # work out bit-inverted and +1 term for a.
            # note the operand is a, but the sign request and MSB are b's
            pa = LSBNegTerm(bit_wid)
            setattr(m.submodules, "lnt_%d_a_%d" % (bit_wid, i), pa)
            m.d.comb += pa.part.eq(parts[i])
            m.d.comb += pa.op.eq(self.a.bit_select(bit_wid * i, bit_wid))
            m.d.comb += pa.signed.eq(self.b_signed[i * byte_width]) # yes b
            m.d.comb += pa.msb.eq(self.b[(i + 1) * bit_wid - 1]) # really, b
            nat.append(pa.nt)
            nla.append(pa.nl)

            # work out bit-inverted and +1 term for b
            # note the operand is b, but the sign request and MSB are a's
            pb = LSBNegTerm(bit_wid)
            setattr(m.submodules, "lnt_%d_b_%d" % (bit_wid, i), pb)
            m.d.comb += pb.part.eq(parts[i])
            m.d.comb += pb.op.eq(self.b.bit_select(bit_wid * i, bit_wid))
            m.d.comb += pb.signed.eq(self.a_signed[i * byte_width]) # yes a
            m.d.comb += pb.msb.eq(self.a[(i + 1) * bit_wid - 1]) # really, a
            nbt.append(pb.nt)
            nlb.append(pb.nl)

        # concatenate together and return all 4 results.
        m.d.comb += [not_a_term.eq(Cat(*nat)),
                     not_b_term.eq(Cat(*nbt)),
                     neg_lsb_a_term.eq(Cat(*nla)),
                     neg_lsb_b_term.eq(Cat(*nlb)),
                    ]

        return m
635
636
class IntermediateOut(Elaboratable):
    """ selects the HI/LO part of the multiplication, for a given bit-width
        the output is also reconstructed in its SIMD (partition) lanes.
    """
    def __init__(self, width, out_wid, n_parts):
        """ width: bit width of one lane of the output
            out_wid: bit width of ``intermed``; ``output`` is half of this
            n_parts: number of lanes
        """
        self.width = width
        self.n_parts = n_parts
        self.delayed_part_ops = [Signal(2, name="dpop%d" % i, reset_less=True)
                                 for i in range(8)]
        self.intermed = Signal(out_wid, reset_less=True)
        self.output = Signal(out_wid//2, reset_less=True)

    def elaborate(self, platform):
        m = Module()

        w = self.width
        sel = w // 8  # stride into delayed_part_ops: one entry per byte
        lanes = []
        for i in range(self.n_parts):
            # each lane's double-width product starts at i * 2w
            base = i * w*2
            lo_half = self.intermed.bit_select(base, w)
            hi_half = self.intermed.bit_select(base + w, w)
            wants_low = self.delayed_part_ops[sel * i] == OP_MUL_LOW

            op = Signal(w, reset_less=True, name="op%d_%d" % (w, i))
            m.d.comb += op.eq(Mux(wants_low, lo_half, hi_half))
            lanes.append(op)
        m.d.comb += self.output.eq(Cat(*lanes))

        return m
665
666
class FinalOut(Elaboratable):
    """ selects the final output based on the partitioning.

        each byte is selectable independently, i.e. it is possible
        that some partitions requested 8-bit computation whilst others
        requested 16 or 32 bit.
    """
    def __init__(self, out_wid):
        """ out_wid: bit width of the four candidate inputs and of ``out``
        """
        # inputs: per-lane selectors ...
        self.d8 = [Signal(name=f"d8_{i}", reset_less=True) for i in range(8)]
        self.d16 = [Signal(name=f"d16_{i}", reset_less=True) for i in range(4)]
        self.d32 = [Signal(name=f"d32_{i}", reset_less=True) for i in range(2)]

        # ... and the candidate results for each partition size
        self.i8 = Signal(out_wid, reset_less=True)
        self.i16 = Signal(out_wid, reset_less=True)
        self.i32 = Signal(out_wid, reset_less=True)
        self.i64 = Signal(out_wid, reset_less=True)

        # output
        self.out = Signal(out_wid, reset_less=True)

    def elaborate(self, platform):
        m = Module()
        byte_lanes = []
        for i in range(8):
            lo = i * 8
            # select one of the outputs: d8 selects i8, d16 selects i16
            # d32 selects i32, and the default is i64.
            # d8 and d16 are ORed together in the outer Mux,
            # then the "narrow" Mux selects either i8 or i16.
            narrow = Mux(self.d8[i], self.i8.bit_select(lo, 8),
                         self.i16.bit_select(lo, 8))
            # if neither d8 nor d16 are set, d32 selects either i32 or i64.
            wide = Mux(self.d32[i // 4], self.i32.bit_select(lo, 8),
                       self.i64.bit_select(lo, 8))
            use_narrow = self.d8[i] | self.d16[i // 2]

            op = Signal(8, reset_less=True, name="op_%d" % i)
            m.d.comb += op.eq(Mux(use_narrow, narrow, wide))
            byte_lanes.append(op)
        m.d.comb += self.out.eq(Cat(*byte_lanes))
        return m
707
708
class OrMod(Elaboratable):
    """ ORs four values together in a hierarchical tree
    """
    def __init__(self, wid):
        """ wid: bit width of the four inputs and of the output
        """
        self.wid = wid
        self.orin = [Signal(wid, name="orin%d" % i, reset_less=True)
                     for i in range(4)]
        self.orout = Signal(wid, reset_less=True)

    def elaborate(self, platform):
        m = Module()
        # first tree level: two pairwise ORs
        or1 = Signal(self.wid, reset_less=True)
        or2 = Signal(self.wid, reset_less=True)
        lower_pair = self.orin[0] | self.orin[1]
        upper_pair = self.orin[2] | self.orin[3]
        m.d.comb += [
            or1.eq(lower_pair),
            or2.eq(upper_pair),
            # second tree level: combine the two partial results
            self.orout.eq(or1 | or2),
        ]

        return m
727
728
class Signs(Elaboratable):
    """ determines whether a or b are signed numbers
        based on the required operation type (OP_MUL_*)
    """

    def __init__(self):
        self.part_ops = Signal(2, reset_less=True)
        self.a_signed = Signal(reset_less=True)
        self.b_signed = Signal(reset_less=True)

    def elaborate(self, platform):

        m = Module()

        op = self.part_ops
        # a is treated as signed for everything except the unsigned-high op
        a_is_signed = op != OP_MUL_UNSIGNED_HIGH
        # b is treated as signed only for the low and signed-high ops
        b_is_signed = (op == OP_MUL_LOW) | (op == OP_MUL_SIGNED_HIGH)
        m.d.comb += [
            self.a_signed.eq(a_is_signed),
            self.b_signed.eq(b_is_signed),
        ]

        return m
750
751
class Mul8_16_32_64(Elaboratable):
    """Signed/Unsigned 8/16/32/64-bit partitioned integer multiplier.

    Supports partitioning into any combination of 8, 16, 32, and 64-bit
    partitions on naturally-aligned boundaries. Supports the operation being
    set for each partition independently.

    :attribute a: the first 64-bit input operand.
    :attribute b: the second 64-bit input operand.
    :attribute output: the 64-bit result.
    :attribute register_levels: the ``AddReduce`` levels at which pipeline
        registers are inserted.
    :attribute part_pts: the input partition points. Has a partition point at
        multiples of 8 in 0 < i < 64. Each partition point's associated
        ``Value`` is a ``Signal``. Modification not supported, except for by
        ``Signal.eq``.
    :attribute part_ops: the operation for each byte. The operation for a
        particular partition is selected by assigning the selected operation
        code to each byte in the partition. The allowed operation codes are:

        :attribute OP_MUL_LOW: the LSB half of the product. Equivalent to
            RISC-V's `mul` instruction.
        :attribute OP_MUL_SIGNED_HIGH: the MSB half of the product where both
            ``a`` and ``b`` are signed. Equivalent to RISC-V's `mulh`
            instruction.
        :attribute OP_MUL_SIGNED_UNSIGNED_HIGH: the MSB half of the product
            where ``a`` is signed and ``b`` is unsigned. Equivalent to RISC-V's
            `mulhsu` instruction.
        :attribute OP_MUL_UNSIGNED_HIGH: the MSB half of the product where both
            ``a`` and ``b`` are unsigned. Equivalent to RISC-V's `mulhu`
            instruction.
    """

    def __init__(self, register_levels=()):
        """ register_levels: specifies the points in the cascade at which
            flip-flops are to be inserted.
        """

        # parameter(s)
        self.register_levels = list(register_levels)

        # inputs
        self.part_pts = PartitionPoints()
        for i in range(8, 64, 8):
            self.part_pts[i] = Signal(name=f"part_pts_{i}")
        self.part_ops = [Signal(2, name=f"part_ops_{i}") for i in range(8)]
        self.a = Signal(64)
        self.b = Signal(64)

        # intermediates (needed for unit tests)
        self._intermediate_output = Signal(128)

        # output
        self.output = Signal(64)

    def _part_byte(self, index):
        """ returns the partition bit at the top of byte ``index``.

            indices -1 and 7 (the outer edges of the 64-bit word) yield a
            constant 1; indices 0..6 yield ``part_pts[index * 8 + 8]``.
        """
        if index == -1 or index == 7:
            return C(True, 1)
        assert index >= 0 and index < 8
        return self.part_pts[index * 8 + 8]

    def elaborate(self, platform):
        """ wires up sign detection, partition decoding, the 8x8 grid of
            byte product terms, the correction terms, the ``AddReduce``
            tree and the per-size output selection.
        """
        m = Module()

        # collect part-bytes
        # pbs[i] is the partition bit at the top of byte i (pbs[7] is
        # always 1, see _part_byte)
        pbs = Signal(8, reset_less=True)
        tl = []
        for i in range(8):
            pb = Signal(name="pb%d" % i, reset_less=True)
            m.d.comb += pb.eq(self._part_byte(i))
            tl.append(pb)
        m.d.comb += pbs.eq(Cat(*tl))

        # local variables
        # one Signs decoder per byte: turns part_ops[i] into a/b signedness
        signs = []
        for i in range(8):
            s = Signs()
            signs.append(s)
            setattr(m.submodules, "signs%d" % i, s)
            m.d.comb += s.part_ops.eq(self.part_ops[i])

        # delay chain for part_ops: entry 0 is combinatorial, each further
        # entry is one sync register later (one per register level)
        delayed_part_ops = [
            [Signal(2, name=f"_delayed_part_ops_{delay}_{i}")
             for i in range(8)]
            for delay in range(1 + len(self.register_levels))]
        for i in range(len(self.part_ops)):
            m.d.comb += delayed_part_ops[0][i].eq(self.part_ops[i])
            m.d.sync += [delayed_part_ops[j + 1][i].eq(delayed_part_ops[j][i])
                         for j in range(len(self.register_levels))]

        # one Part per partition size (8 x 8-bit, 4 x 16-bit, 2 x 32-bit,
        # 1 x 64-bit); each produces four 128-bit correction terms
        n_levels = len(self.register_levels)+1
        m.submodules.part_8 = part_8 = Part(128, 8, n_levels, 8)
        m.submodules.part_16 = part_16 = Part(128, 4, n_levels, 8)
        m.submodules.part_32 = part_32 = Part(128, 2, n_levels, 8)
        m.submodules.part_64 = part_64 = Part(128, 1, n_levels, 8)
        nat_l, nbt_l, nla_l, nlb_l = [], [], [], []
        for mod in [part_8, part_16, part_32, part_64]:
            m.d.comb += mod.a.eq(self.a)
            m.d.comb += mod.b.eq(self.b)
            for i in range(len(signs)):
                m.d.comb += mod.a_signed[i].eq(signs[i].a_signed)
                m.d.comb += mod.b_signed[i].eq(signs[i].b_signed)
            m.d.comb += mod.pbs.eq(pbs)
            nat_l.append(mod.not_a_term)
            nbt_l.append(mod.not_b_term)
            nla_l.append(mod.neg_lsb_a_term)
            nlb_l.append(mod.neg_lsb_b_term)

        # all values to be summed by AddReduce
        terms = []

        # 8 banks of 8 byte-by-byte product terms (one bank per byte of a)
        for a_index in range(8):
            t = ProductTerms(8, 128, 8, a_index, 8)
            setattr(m.submodules, "terms_%d" % a_index, t)

            m.d.comb += t.a.eq(self.a)
            m.d.comb += t.b.eq(self.b)
            m.d.comb += t.pb_en.eq(pbs)

            for term in t.terms:
                terms.append(term)

        # it's fine to bitwise-or data together since they are never enabled
        # at the same time
        m.submodules.nat_or = nat_or = OrMod(128)
        m.submodules.nbt_or = nbt_or = OrMod(128)
        m.submodules.nla_or = nla_or = OrMod(128)
        m.submodules.nlb_or = nlb_or = OrMod(128)
        for l, mod in [(nat_l, nat_or),
                             (nbt_l, nbt_or),
                             (nla_l, nla_or),
                             (nlb_l, nlb_or)]:
            for i in range(len(l)):
                m.d.comb += mod.orin[i].eq(l[i])
            terms.append(mod.orout)

        # the sum is 128 bits wide, so every partition point index of the
        # 64-bit inputs is doubled
        expanded_part_pts = PartitionPoints()
        for i, v in self.part_pts.items():
            signal = Signal(name=f"expanded_part_pts_{i*2}", reset_less=True)
            expanded_part_pts[i * 2] = signal
            m.d.comb += signal.eq(v)

        add_reduce = AddReduce(terms,
                               128,
                               self.register_levels,
                               expanded_part_pts)
        m.submodules.add_reduce = add_reduce
        m.d.comb += self._intermediate_output.eq(add_reduce.output)
        # each IntermediateOut below picks the HI or LO half of every lane
        # for one partition size, using the fully delayed part_ops
        # create _output_64
        m.submodules.io64 = io64 = IntermediateOut(64, 128, 1)
        m.d.comb += io64.intermed.eq(self._intermediate_output)
        for i in range(8):
            m.d.comb += io64.delayed_part_ops[i].eq(delayed_part_ops[-1][i])

        # create _output_32
        m.submodules.io32 = io32 = IntermediateOut(32, 128, 2)
        m.d.comb += io32.intermed.eq(self._intermediate_output)
        for i in range(8):
            m.d.comb += io32.delayed_part_ops[i].eq(delayed_part_ops[-1][i])

        # create _output_16
        m.submodules.io16 = io16 = IntermediateOut(16, 128, 4)
        m.d.comb += io16.intermed.eq(self._intermediate_output)
        for i in range(8):
            m.d.comb += io16.delayed_part_ops[i].eq(delayed_part_ops[-1][i])

        # create _output_8
        m.submodules.io8 = io8 = IntermediateOut(8, 128, 8)
        m.d.comb += io8.intermed.eq(self._intermediate_output)
        for i in range(8):
            m.d.comb += io8.delayed_part_ops[i].eq(delayed_part_ops[-1][i])

        # final output
        # per byte, FinalOut picks the 8/16/32/64-bit candidate according to
        # the delayed partition-active flags (64-bit is the default)
        m.submodules.finalout = finalout = FinalOut(64)
        for i in range(len(part_8.delayed_parts[-1])):
            m.d.comb += finalout.d8[i].eq(part_8.dplast[i])
        for i in range(len(part_16.delayed_parts[-1])):
            m.d.comb += finalout.d16[i].eq(part_16.dplast[i])
        for i in range(len(part_32.delayed_parts[-1])):
            m.d.comb += finalout.d32[i].eq(part_32.dplast[i])
        m.d.comb += finalout.i8.eq(io8.output)
        m.d.comb += finalout.i16.eq(io16.output)
        m.d.comb += finalout.i32.eq(io32.output)
        m.d.comb += finalout.i64.eq(io64.output)
        m.d.comb += self.output.eq(finalout.out)

        return m
933
934
if __name__ == "__main__":
    # build a multiplier without pipeline registers and hand it, together
    # with its externally visible signals, to the nmigen command-line driver
    m = Mul8_16_32_64()
    ports = [m.a, m.b, m._intermediate_output, m.output]
    ports.extend(m.part_ops)
    ports.extend(m.part_pts.values())
    main(m, ports=ports)