4c3a3cf177af557c2d0c1fa8520847ae6fa33cfb
[ieee754fpu.git] / src / ieee754 / part_mul_add / multiply.py
1 # SPDX-License-Identifier: LGPL-2.1-or-later
2 # See Notices.txt for copyright information
3 """Integer Multiplication."""
4
5 from nmigen import Signal, Module, Value, Elaboratable, Cat, C, Mux, Repl
6 from nmigen.hdl.ast import Assign
7 from abc import ABCMeta, abstractmethod
8 from nmigen.cli import main
9 from functools import reduce
10 from operator import or_
11
12
class PartitionPoints(dict):
    """Partition points and corresponding ``Value``s.

    The points where an ALU is partitioned, along with ``Value``s that
    specify whether the corresponding partition points are enabled.

    For example: ``{1: True, 5: True, 10: True}`` with
    ``width == 16`` specifies that the ALU is split into 4 sections:
    * bits 0 <= ``i`` < 1
    * bits 1 <= ``i`` < 5
    * bits 5 <= ``i`` < 10
    * bits 10 <= ``i`` < 16

    If the partition_points were instead ``{1: True, 5: a, 10: True}``
    where ``a`` is a 1-bit ``Signal``:
    * If ``a`` is asserted:
        * bits 0 <= ``i`` < 1
        * bits 1 <= ``i`` < 5
        * bits 5 <= ``i`` < 10
        * bits 10 <= ``i`` < 16
    * Otherwise
        * bits 0 <= ``i`` < 1
        * bits 1 <= ``i`` < 10
        * bits 10 <= ``i`` < 16
    """

    def __init__(self, partition_points=None):
        """Create a new ``PartitionPoints``.

        :param partition_points: the input partition points to values mapping.
        :raises TypeError: if a point is not an ``int``.
        :raises ValueError: if a point is negative.
        """
        super().__init__()
        if partition_points is not None:
            for point, enabled in partition_points.items():
                if not isinstance(point, int):
                    raise TypeError("point must be a non-negative integer")
                if point < 0:
                    raise ValueError("point must be a non-negative integer")
                self[point] = Value.wrap(enabled)

    def like(self, name=None, src_loc_at=0):
        """Create a new ``PartitionPoints`` with ``Signal``s for all values.

        :param name: the base name for the new ``Signal``s; if ``None`` the
            name of the variable the result is assigned to is used.
        :param src_loc_at: extra stack depth used when looking up that
            variable name.
        """
        if name is None:
            name = Signal(src_loc_at=1+src_loc_at).name  # get variable name
        retval = PartitionPoints()
        for point, enabled in self.items():
            retval[point] = Signal(enabled.shape(), name=f"{name}_{point}")
        return retval

    def eq(self, rhs):
        """Assign ``PartitionPoints`` using ``Signal.eq``.

        :param rhs: a mapping with exactly the same set of points as ``self``.
        :returns: a list of assignment statements, one per partition point.
        :raises ValueError: immediately (not lazily on iteration) if the
            point sets differ.
        """
        if set(self.keys()) != set(rhs.keys()):
            raise ValueError("incompatible point set")
        return [enabled.eq(rhs[point]) for point, enabled in self.items()]

    def as_mask(self, width):
        """Create a bit-mask from `self`.

        Each bit in the returned mask is clear only if the partition point at
        the same bit-index is enabled.

        :param width: the bit width of the resulting mask
        """
        bits = []
        for i in range(width):
            if i in self:
                bits.append(~self[i])
            else:
                bits.append(True)
        return Cat(*bits)

    def get_max_partition_count(self, width):
        """Get the maximum number of partitions.

        Gets the number of partitions when all partition points below
        ``width`` are enabled.
        """
        return 1 + sum(1 for point in self if point < width)

    def fits_in_width(self, width):
        """Check if all partition points are smaller than `width`."""
        return all(point < width for point in self)
105
106
class FullAdder(Elaboratable):
    """Bit-parallel full (3:2) adder.

    Each bit position ``i`` independently adds ``in0[i]``, ``in1[i]`` and
    ``in2[i]``, producing ``sum[i]`` and ``carry[i]``.  Carries are *not*
    shifted or propagated here; doing all bit positions in a single module
    (instead of an array of 1-bit adders) keeps simulation fast.

    :attribute in0: the first input
    :attribute in1: the second input
    :attribute in2: the third input
    :attribute sum: the sum output (XOR of the three inputs)
    :attribute carry: the carry output (majority of the three inputs)
    """

    def __init__(self, width):
        """Create a ``FullAdder``.

        :param width: the bit width of the input and output
        """
        self.in0 = Signal(width)
        self.in1 = Signal(width)
        self.in2 = Signal(width)
        self.sum = Signal(width)
        self.carry = Signal(width)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        x, y, z = self.in0, self.in1, self.in2
        parity = x ^ y ^ z
        majority = (x & y) | (y & z) | (z & x)
        m.d.comb += [self.sum.eq(parity), self.carry.eq(majority)]
        return m
141
142
class MaskedFullAdder(FullAdder):
    """Full adder whose carry output is pre-shifted and partition-masked.

    ``mcarry`` is ``carry`` moved up one bit position and ANDed with
    ``mask``, so a clear mask bit blocks the carry from crossing into that
    bit.  Doing the masking inside this module keeps the generated graph
    tidy compared with masking in the caller's loop.

    :attribute mask: the carry partition mask
    :attribute in0: the first input
    :attribute in1: the second input
    :attribute in2: the third input
    :attribute sum: the sum output
    :attribute mcarry: the masked carry output
    """

    def __init__(self, width):
        """Create a ``MaskedFullAdder``.

        :param width: the bit width of the input and output
        """
        super().__init__(width)
        self.mask = Signal(width)
        self.mcarry = Signal(width)

    def elaborate(self, platform):
        """Elaborate this module (the plain full adder plus the mask)."""
        m = super().elaborate(platform)
        shifted_carry = self.carry << 1
        m.d.comb += self.mcarry.eq(shifted_carry & self.mask)
        return m
172
173
class PartitionedAdder(Elaboratable):
    """Partitioned Adder.

    Performs the final add. The partition points are included in the
    actual add (in one of the operands only), which causes a carry over
    to the next bit. Then the final output *removes* the extra bits from
    the result.

    partition: .... P... P... P... P... (32 bits)
    a        : .... .... .... .... .... (32 bits)
    b        : .... .... .... .... .... (32 bits)
    exp-a    : ....P....P....P....P.... (32+4 bits)
    exp-b    : ....0....0....0....0.... (32 bits plus 4 zeros)
    exp-o    : ....xN...xN...xN...xN... (32+4 bits)
    o        : .... N... N... N... N... (32 bits)

    :attribute width: the bit width of the input and output. Read-only.
    :attribute a: the first input to the adder
    :attribute b: the second input to the adder
    :attribute output: the sum output
    :attribute partition_points: the input partition points. Modification not
        supported, except for by ``Signal.eq``.
    """

    def __init__(self, width, partition_points):
        """Create a ``PartitionedAdder``.

        :param width: the bit width of the input and output
        :param partition_points: the input partition points
        :raises ValueError: if a partition point is not below ``width``.
        """
        self.width = width
        self.a = Signal(width)
        self.b = Signal(width)
        self.output = Signal(width)
        self.partition_points = PartitionPoints(partition_points)
        if not self.partition_points.fits_in_width(width):
            raise ValueError("partition_points doesn't fit in width")
        # one extra "break" bit is inserted below every partition point
        expanded_width = width + sum(1 for bit in range(width)
                                     if bit in self.partition_points)
        self._expanded_width = expanded_width
        # XXX these have to remain here due to some horrible nmigen
        # simulation bugs involving sync. it is *not* necessary to
        # have them here, they should (under normal circumstances)
        # be moved into elaborate, as they are entirely local
        self._expanded_a = Signal(expanded_width)  # includes extra part-points
        self._expanded_b = Signal(expanded_width)  # likewise.
        self._expanded_o = Signal(expanded_width)  # likewise.

    def elaborate(self, platform):
        """Elaborate this module.

        Partition points are "breaks" inserted into what would otherwise be
        one long add.  At an enabled point the break is 0 + 0, so no carry
        crosses it; at a disabled point it is 1 + 0, which passes an incoming
        carry through to the next real bit.  The break bits of the expanded
        sum are never copied to ``output``.
        """
        m = Module()
        exp_a = self._expanded_a
        exp_b = self._expanded_b
        exp_o = self._expanded_o

        # destination bits of the expanded operands and their sources;
        # collected in lists and joined with Cat to keep the graph clean.
        a_dst, a_src = [], []
        b_dst, b_src = [], []
        out_dst, out_src = [], []

        pos = 0
        for bit in range(self.width):
            if bit in self.partition_points:
                a_dst.append(exp_a[pos])
                a_src.append(~self.partition_points[bit])
                b_dst.append(exp_b[pos])
                b_src.append(C(0))
                pos += 1  # the break bit never reaches the output
            a_dst.append(exp_a[pos])
            a_src.append(self.a[bit])
            b_dst.append(exp_b[pos])
            b_src.append(self.b[bit])
            out_dst.append(self.output[bit])
            out_src.append(exp_o[pos])
            pos += 1

        m.d.comb += Cat(*a_dst).eq(Cat(*a_src))
        m.d.comb += Cat(*b_dst).eq(Cat(*b_src))
        m.d.comb += Cat(*out_dst).eq(Cat(*out_src))

        # a single wide addition lets tools use carry look-ahead / FPGA
        # carry chains
        m.d.comb += exp_o.eq(exp_a + exp_b)
        return m
267
268
# Number of terms consumed by one full-adder group in ``AddReduce``: each
# group of this many inputs is reduced to two outputs (sum and carry).
FULL_ADDER_INPUT_COUNT = 3
270
271
class AddReduce(Elaboratable):
    """Add list of numbers together.

    Each level groups the inputs in threes, reduces every group to two terms
    with a ``MaskedFullAdder``, and recurses on the resulting terms until at
    most two remain; those are summed by a ``PartitionedAdder``.

    :attribute inputs: input ``Signal``s to be summed. Modification not
        supported, except for by ``Signal.eq``.
    :attribute register_levels: List of nesting levels that should have
        pipeline registers.
    :attribute output: output sum.
    :attribute partition_points: the input partition points. Modification not
        supported, except for by ``Signal.eq``.
    """

    def __init__(self, inputs, output_width, register_levels, partition_points):
        """Create an ``AddReduce``.

        :param inputs: input ``Signal``s to be summed.
        :param output_width: bit-width of ``output``.
        :param register_levels: List of nesting levels that should have
            pipeline registers.
        :param partition_points: the input partition points.
        :raises ValueError: if ``partition_points`` doesn't fit in
            ``output_width``, or a register level is negative or greater
            than ``get_max_level(len(inputs))``.
        """
        self.inputs = list(inputs)
        self._resized_inputs = [
            Signal(output_width, name=f"resized_inputs[{i}]")
            for i in range(len(self.inputs))]
        self.register_levels = list(register_levels)
        self.output = Signal(output_width)
        self.partition_points = PartitionPoints(partition_points)
        if not self.partition_points.fits_in_width(output_width):
            raise ValueError("partition_points doesn't fit in output_width")
        self._reg_partition_points = self.partition_points.like()
        max_level = AddReduce.get_max_level(len(self.inputs))
        for level in self.register_levels:
            # a negative level can never match any nesting level (see
            # next_register_levels), so it would be silently ignored
            if level < 0:
                raise ValueError("register levels must be non-negative")
            if level > max_level:
                raise ValueError(
                    "not enough adder levels for specified register levels")

    @staticmethod
    def get_max_level(input_count):
        """Get the maximum level.

        All ``register_levels`` must be less than or equal to the maximum
        level.
        """
        retval = 0
        while True:
            groups = AddReduce.full_adder_groups(input_count)
            if len(groups) == 0:
                return retval
            input_count %= FULL_ADDER_INPUT_COUNT
            input_count += 2 * len(groups)
            retval += 1

    def next_register_levels(self):
        """``Iterable`` of ``register_levels`` for next recursive level."""
        for level in self.register_levels:
            if level > 0:
                yield level - 1

    @staticmethod
    def full_adder_groups(input_count):
        """Get ``inputs`` indices for which a full adder should be built."""
        return range(0,
                     input_count - FULL_ADDER_INPUT_COUNT + 1,
                     FULL_ADDER_INPUT_COUNT)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()

        # resize inputs to correct bit-width and optionally add in
        # pipeline registers
        resized_input_assignments = [self._resized_inputs[i].eq(self.inputs[i])
                                     for i in range(len(self.inputs))]
        if 0 in self.register_levels:
            m.d.sync += resized_input_assignments
            m.d.sync += self._reg_partition_points.eq(self.partition_points)
        else:
            m.d.comb += resized_input_assignments
            m.d.comb += self._reg_partition_points.eq(self.partition_points)

        groups = AddReduce.full_adder_groups(len(self.inputs))
        # if there are no full adders to create, then we handle the base cases
        # and return, otherwise we go on to the recursive case
        if len(groups) == 0:
            if len(self.inputs) == 0:
                # use 0 as the default output value
                m.d.comb += self.output.eq(0)
            elif len(self.inputs) == 1:
                # handle single input
                m.d.comb += self.output.eq(self._resized_inputs[0])
            else:
                # base case for adding 2 or more inputs, which get recursively
                # reduced to 2 inputs
                assert len(self.inputs) == 2
                adder = PartitionedAdder(len(self.output),
                                         self._reg_partition_points)
                m.submodules.final_adder = adder
                m.d.comb += adder.a.eq(self._resized_inputs[0])
                m.d.comb += adder.b.eq(self._resized_inputs[1])
                m.d.comb += self.output.eq(adder.output)
            return m
        # go on to handle recursive case
        intermediate_terms = []

        def add_intermediate_term(value):
            # append a named, output-width copy of value to the next level
            intermediate_term = Signal(
                len(self.output),
                name=f"intermediate_terms[{len(intermediate_terms)}]")
            intermediate_terms.append(intermediate_term)
            m.d.comb += intermediate_term.eq(value)

        # store mask in intermediary (simplifies graph)
        part_mask = Signal(len(self.output), reset_less=True)
        mask = self._reg_partition_points.as_mask(len(self.output))
        m.d.comb += part_mask.eq(mask)

        # create full adders for this recursive level.
        # this shrinks N terms to 2 * (N // 3) plus the remainder
        for i in groups:
            adder_i = MaskedFullAdder(len(self.output))
            setattr(m.submodules, f"adder_{i}", adder_i)
            m.d.comb += adder_i.in0.eq(self._resized_inputs[i])
            m.d.comb += adder_i.in1.eq(self._resized_inputs[i + 1])
            m.d.comb += adder_i.in2.eq(self._resized_inputs[i + 2])
            m.d.comb += adder_i.mask.eq(part_mask)
            # add both the sum and the masked-carry to the next level.
            # 3 inputs have now been reduced to 2...
            add_intermediate_term(adder_i.sum)
            add_intermediate_term(adder_i.mcarry)
        # handle the remaining inputs.
        if len(self.inputs) % FULL_ADDER_INPUT_COUNT == 1:
            add_intermediate_term(self._resized_inputs[-1])
        elif len(self.inputs) % FULL_ADDER_INPUT_COUNT == 2:
            # Just pass the terms to the next layer, since we wouldn't gain
            # anything by using a half adder since there would still be 2 terms
            # and just passing the terms to the next layer saves gates.
            add_intermediate_term(self._resized_inputs[-2])
            add_intermediate_term(self._resized_inputs[-1])
        else:
            assert len(self.inputs) % FULL_ADDER_INPUT_COUNT == 0
        # recursive invocation of ``AddReduce``
        next_level = AddReduce(intermediate_terms,
                               len(self.output),
                               self.next_register_levels(),
                               self._reg_partition_points)
        m.submodules.next_level = next_level
        m.d.comb += self.output.eq(next_level.output)
        return m
421
422
# Per-byte multiply operation codes (values 0..3, fit in a 2-bit signal).
(OP_MUL_LOW,                   # low half of the product (RISC-V ``mul``)
 OP_MUL_SIGNED_HIGH,           # high half, a and b signed (``mulh``)
 OP_MUL_SIGNED_UNSIGNED_HIGH,  # high half, a signed, b unsigned (``mulhsu``)
 OP_MUL_UNSIGNED_HIGH,         # high half, a and b unsigned (``mulhu``)
 ) = range(4)
427
428
def get_term(value, shift=0, enabled=None):
    """Optionally gate ``value`` and shift it left by ``shift`` bits.

    :param value: the value (e.g. a partial product) to place.
    :param shift: non-negative number of zero bits to insert below
        ``value``; 0 leaves it where it is.
    :param enabled: if not ``None``, a 1-bit value; when it is deasserted
        the result is forced to zero.
    :returns: the resulting expression, or ``value`` itself when neither
        gating nor shifting is requested.
    :raises ValueError: if ``shift`` is negative (previously an ``assert``,
        which is skipped when Python runs with ``-O``).
    """
    if shift < 0:
        raise ValueError("shift must be non-negative")
    if enabled is not None:
        value = Mux(enabled, value, 0)
    if shift > 0:
        value = Cat(Repl(C(0, 1), shift), value)
    return value
437
438
class ProductTerm(Elaboratable):
    """ this class creates a single product term (a[..]*b[..]).
    it has a design flaw in that it is the *output* that is selected,
    where the multiplication(s) are combinatorially generated
    all the time.

    :param width: width in bits of each operand slice (``pwidth``).
    :param twidth: width in bits of the shifted output ``term``; the ``a``
        and ``b`` inputs are ``twidth // 2`` bits wide.
    :param pbwid: width of the ``pb_en`` partition-bits input.
    :param a_index: index of the ``width``-bit slice taken from ``a``.
    :param b_index: index of the ``width``-bit slice taken from ``b``.
    """

    def __init__(self, width, twidth, pbwid, a_index, b_index):
        self.a_index = a_index
        self.b_index = b_index
        # the product of slice a_index and slice b_index has weight
        # 2 ** (width * (a_index + b_index)); this must use ``width`` (not a
        # literal 8) to agree with the bit_select offsets in elaborate().
        shift = width * (self.a_index + self.b_index)
        self.pwidth = width
        self.twidth = twidth
        self.width = width*2
        self.shift = shift

        self.ti = Signal(self.width, reset_less=True)
        self.term = Signal(twidth, reset_less=True)
        self.a = Signal(twidth//2, reset_less=True)
        self.b = Signal(twidth//2, reset_less=True)
        self.pb_en = Signal(pbwid, reset_less=True)

        # partition bits lying between the two slices: if any of them is set
        # the term is forced to zero in elaborate() (in Mul8_16_32_64, pb_en
        # bit i is the partition point just above slice i).
        self.tl = tl = []
        min_index = min(self.a_index, self.b_index)
        max_index = max(self.a_index, self.b_index)
        for i in range(min_index, max_index):
            tl.append(self.pb_en[i])
        name = "te_%d_%d" % (self.a_index, self.b_index)
        if len(tl) > 0:
            term_enabled = Signal(name=name, reset_less=True)
        else:
            # same slice index: never separated by a partition point
            term_enabled = None
        self.enabled = term_enabled
        self.term.name = "term_%d_%d" % (a_index, b_index)  # rename

    def elaborate(self, platform):

        m = Module()
        if self.enabled is not None:
            m.d.comb += self.enabled.eq(~(Cat(*self.tl).bool()))

        bsa = Signal(self.width, reset_less=True)
        bsb = Signal(self.width, reset_less=True)
        a_index, b_index = self.a_index, self.b_index
        pwidth = self.pwidth
        m.d.comb += bsa.eq(self.a.bit_select(a_index * pwidth, pwidth))
        m.d.comb += bsb.eq(self.b.bit_select(b_index * pwidth, pwidth))
        m.d.comb += self.ti.eq(bsa * bsb)
        m.d.comb += self.term.eq(get_term(self.ti, self.shift, self.enabled))
        # TODO: gate the a/b inputs (rather than the product) with
        # self.enabled so disabled multipliers do not toggle; this needs the
        # operand widths sorting out, since the data going into the Muxes
        # would be half the required width.

        return m
508
509
class ProductTerms(Elaboratable):
    """Bank of ``ProductTerm``s sharing one slice of the ``a`` operand.

    Intended to be instantiated once per ``a`` slice index by the caller's
    loop; this module supplies the inner loop over ``blen`` slices of ``b``,
    exposing one output in ``terms`` per ``b`` slice.
    """

    def __init__(self, width, twidth, pbwid, a_index, blen):
        self.a_index = a_index
        self.blen = blen
        self.pwidth = width
        self.twidth = twidth
        self.pbwid = pbwid
        self.a = Signal(twidth//2, reset_less=True)
        self.b = Signal(twidth//2, reset_less=True)
        self.pb_en = Signal(pbwid, reset_less=True)
        self.terms = [Signal(twidth, name=f"term{idx}", reset_less=True)
                      for idx in range(blen)]

    def elaborate(self, platform):
        m = Module()

        for b_index, term_out in enumerate(self.terms):
            product = ProductTerm(self.pwidth, self.twidth, self.pbwid,
                                  self.a_index, b_index)
            setattr(m.submodules, f"term_{b_index}", product)
            m.d.comb += [
                product.a.eq(self.a),
                product.b.eq(self.b),
                product.pb_en.eq(self.pb_en),
                term_out.eq(product.term),
            ]

        return m
543
class LSBNegTerm(Elaboratable):
    """Correction terms that subtract ``op << bit_width`` from a product.

    When ``part``, ``msb`` and ``signed`` are all asserted, ``nt`` is
    ``~op`` placed in the upper half and ``nl`` is a single 1 at bit
    ``bit_width``; together they add ``(~op + 1) << bit_width``, i.e. the
    two's-complement negation of ``op << bit_width``.  Otherwise both are 0.
    """

    def __init__(self, bit_width):
        self.bit_width = bit_width
        self.part = Signal(reset_less=True)
        self.signed = Signal(reset_less=True)
        self.op = Signal(bit_width, reset_less=True)
        self.msb = Signal(reset_less=True)
        self.nt = Signal(bit_width*2, reset_less=True)
        self.nl = Signal(bit_width*2, reset_less=True)

    def elaborate(self, platform):
        m = Module()
        half = self.bit_width
        low_zeros = Repl(0, half)  # pushes the terms into the upper half

        # the correction applies only to an active, signed, negative operand
        enabled = Signal(reset_less=True)
        m.d.comb += [
            enabled.eq(self.part & self.msb & self.signed),
            # 1's complement (bit-inversion) half of the negation
            self.nt.eq(Mux(enabled, Cat(low_zeros, ~self.op), 0)),
            # the "+1" half of the negation
            self.nl.eq(Cat(low_zeros, enabled, Repl(0, half - 1))),
        ]
        return m
576
577
class Part(Elaboratable):
    """ a key class which, depending on the partitioning, will determine
    what action to take when parts of the output are signed or unsigned.

    this requires 2 pieces of data *per operand, per partition*:
    whether the MSB is HI/LO (per partition!), and whether a signed
    or unsigned operation has been *requested*.

    once that is determined, signed is basically carried out
    by splitting 2's complement into 1's complement plus one.
    1's complement is just a bit-inversion.

    the extra terms - as separate terms - are then thrown at the
    AddReduce alongside the multiplication part-results.
    """

    def __init__(self, width, n_parts, n_levels, pbwid):
        """Create a ``Part``.

        :param width: bit width of the four correction-term outputs.
        :param n_parts: number of equal partitions the 8 input bytes are
            divided into (``8 // n_parts`` bytes each).
        :param n_levels: number of entries in ``delayed_parts``; entry 0 is
            combinatorial and each later entry adds one clock of delay.
        :param pbwid: bit width of the ``pbs`` partition-bits input.
        """

        # inputs
        self.a = Signal(64)
        self.b = Signal(64)
        self.a_signed = [Signal(name=f"a_signed_{i}") for i in range(8)]
        # named consistently with a_signed (was "_b_signed_{i}")
        self.b_signed = [Signal(name=f"b_signed_{i}") for i in range(8)]
        self.pbs = Signal(pbwid, reset_less=True)

        # outputs
        self.parts = [Signal(name=f"part_{i}") for i in range(n_parts)]
        self.delayed_parts = [
            [Signal(name=f"delayed_part_{delay}_{i}")
             for i in range(n_parts)]
            for delay in range(n_levels)]
        # XXX REALLY WEIRD BUG - have to take a copy of the last delayed_parts
        self.dplast = [Signal(name=f"dplast_{i}")
                       for i in range(n_parts)]

        self.not_a_term = Signal(width)
        self.neg_lsb_a_term = Signal(width)
        self.not_b_term = Signal(width)
        self.neg_lsb_b_term = Signal(width)

    def elaborate(self, platform):
        m = Module()

        pbs, parts, delayed_parts = self.pbs, self.parts, self.delayed_parts
        # bytes per partition: used both for partition detection and for
        # sizing the correction terms (previously computed twice)
        byte_count = 8 // len(parts)
        bit_wid = 8 * byte_count  # bits per partition

        # negated-temporary copy of partition bits
        npbs = Signal.like(pbs, reset_less=True)
        m.d.comb += npbs.eq(~pbs)
        for i in range(len(parts)):
            # partition i is active when the point at its lower boundary is
            # set, every point strictly inside it is clear, and the point at
            # its upper boundary is set.  For i == 0 the lower index is -1,
            # i.e. the top bit of pbs; Mul8_16_32_64 ties that bit high.
            pbl = []
            pbl.append(npbs[i * byte_count - 1])
            for j in range(i * byte_count, (i + 1) * byte_count - 1):
                pbl.append(pbs[j])
            pbl.append(npbs[(i + 1) * byte_count - 1])
            value = Signal(len(pbl), name="value_%di" % i, reset_less=True)
            m.d.comb += value.eq(Cat(*pbl))
            m.d.comb += parts[i].eq(~(value).bool())
            m.d.comb += delayed_parts[0][i].eq(parts[i])
            m.d.sync += [delayed_parts[j + 1][i].eq(delayed_parts[j][i])
                         for j in range(len(delayed_parts)-1)]
            m.d.comb += self.dplast[i].eq(delayed_parts[-1][i])

        nat, nbt, nla, nlb = [], [], [], []
        for i in range(len(parts)):
            # a's correction depends on b's sign (and vice versa):
            # a * b_signed == a * b_unsigned - b_msb * (a << bit_wid)
            # work out bit-inverted and +1 term for a.
            pa = LSBNegTerm(bit_wid)
            setattr(m.submodules, "lnt_%d_a_%d" % (bit_wid, i), pa)
            m.d.comb += pa.part.eq(parts[i])
            m.d.comb += pa.op.eq(self.a.bit_select(bit_wid * i, bit_wid))
            m.d.comb += pa.signed.eq(self.b_signed[i * byte_count])  # yes b
            m.d.comb += pa.msb.eq(self.b[(i + 1) * bit_wid - 1])  # really, b
            nat.append(pa.nt)
            nla.append(pa.nl)

            # work out bit-inverted and +1 term for b
            pb = LSBNegTerm(bit_wid)
            setattr(m.submodules, "lnt_%d_b_%d" % (bit_wid, i), pb)
            m.d.comb += pb.part.eq(parts[i])
            m.d.comb += pb.op.eq(self.b.bit_select(bit_wid * i, bit_wid))
            m.d.comb += pb.signed.eq(self.a_signed[i * byte_count])  # yes a
            m.d.comb += pb.msb.eq(self.a[(i + 1) * bit_wid - 1])  # really, a
            nbt.append(pb.nt)
            nlb.append(pb.nl)

        # concatenate together and return all 4 results.
        m.d.comb += [self.not_a_term.eq(Cat(*nat)),
                     self.not_b_term.eq(Cat(*nbt)),
                     self.neg_lsb_a_term.eq(Cat(*nla)),
                     self.neg_lsb_b_term.eq(Cat(*nlb)),
                     ]

        return m
675
676
class IntermediateOut(Elaboratable):
    """Pick the HI or LO half of each ``width``-bit partition's product.

    ``intermed`` holds ``n_parts`` double-width products side by side; for
    each one the low half is chosen when that partition's operation is
    ``OP_MUL_LOW`` and the high half otherwise.  The selected halves are
    packed into ``output`` in SIMD lane order.
    """

    def __init__(self, width, out_wid, n_parts):
        self.width = width
        self.n_parts = n_parts
        self.delayed_part_ops = [Signal(2, name="dpop%d" % i, reset_less=True)
                                 for i in range(8)]
        self.intermed = Signal(out_wid, reset_less=True)
        self.output = Signal(out_wid//2, reset_less=True)

    def elaborate(self, platform):
        m = Module()

        w = self.width
        stride = w // 8  # byte lanes per partition; op of the lowest byte is used
        lanes = []
        for lane in range(self.n_parts):
            base = lane * w * 2
            low_half = self.intermed.bit_select(base, w)
            high_half = self.intermed.bit_select(base + w, w)
            wants_low = self.delayed_part_ops[stride * lane] == OP_MUL_LOW
            op = Signal(w, reset_less=True, name=f"op{w}_{lane}")
            m.d.comb += op.eq(Mux(wants_low, low_half, high_half))
            lanes.append(op)
        m.d.comb += self.output.eq(Cat(*lanes))

        return m
705
706
class FinalOut(Elaboratable):
    """ selects the final output based on the partitioning.

    each byte is selectable independently, i.e. it is possible
    that some partitions requested 8-bit computation whilst others
    requested 16 or 32 bit.
    """

    def __init__(self, out_wid):
        # inputs
        self.d8 = [Signal(name=f"d8_{i}", reset_less=True) for i in range(8)]
        self.d16 = [Signal(name=f"d16_{i}", reset_less=True) for i in range(4)]
        self.d32 = [Signal(name=f"d32_{i}", reset_less=True) for i in range(2)]

        self.i8 = Signal(out_wid, reset_less=True)
        self.i16 = Signal(out_wid, reset_less=True)
        self.i32 = Signal(out_wid, reset_less=True)
        self.i64 = Signal(out_wid, reset_less=True)

        # output
        self.out = Signal(out_wid, reset_less=True)

    def elaborate(self, platform):
        m = Module()
        out_bytes = []
        for byte in range(8):
            lsb = byte * 8
            # d8 picks i8 over i16 when either 8- or 16-bit mode is active
            narrow = Mux(self.d8[byte],
                         self.i8.bit_select(lsb, 8),
                         self.i16.bit_select(lsb, 8))
            # otherwise d32 picks i32, and the fallback is i64
            wide = Mux(self.d32[byte // 4],
                       self.i32.bit_select(lsb, 8),
                       self.i64.bit_select(lsb, 8))
            is_narrow = self.d8[byte] | self.d16[byte // 2]
            op = Signal(8, reset_less=True, name=f"op_{byte}")
            m.d.comb += op.eq(Mux(is_narrow, narrow, wide))
            out_bytes.append(op)
        m.d.comb += self.out.eq(Cat(*out_bytes))
        return m
747
748
class OrMod(Elaboratable):
    """Bitwise OR of the four ``orin`` inputs, built as a two-level tree.

    Inputs 0|1 and 2|3 are combined first, then the two partial results.
    """

    def __init__(self, wid):
        self.wid = wid
        self.orin = [Signal(wid, name=f"orin{i}", reset_less=True)
                     for i in range(4)]
        self.orout = Signal(wid, reset_less=True)

    def elaborate(self, platform):
        m = Module()
        ins = self.orin
        lower_pair = Signal(self.wid, name="or1", reset_less=True)
        upper_pair = Signal(self.wid, name="or2", reset_less=True)
        m.d.comb += [
            lower_pair.eq(ins[0] | ins[1]),
            upper_pair.eq(ins[2] | ins[3]),
            self.orout.eq(lower_pair | upper_pair),
        ]
        return m
767
768
class Signs(Elaboratable):
    """Decode a 2-bit ``OP_MUL_*`` code into per-operand signedness flags."""

    def __init__(self):
        self.part_ops = Signal(2, reset_less=True)
        self.a_signed = Signal(reset_less=True)
        self.b_signed = Signal(reset_less=True)

    def elaborate(self, platform):
        m = Module()
        ops = self.part_ops
        # ``a`` is signed for every operation except OP_MUL_UNSIGNED_HIGH
        a_is_signed = ops != OP_MUL_UNSIGNED_HIGH
        # ``b`` is signed only for OP_MUL_LOW and OP_MUL_SIGNED_HIGH
        b_is_signed = (ops == OP_MUL_LOW) | (ops == OP_MUL_SIGNED_HIGH)
        m.d.comb += [self.a_signed.eq(a_is_signed),
                     self.b_signed.eq(b_is_signed)]
        return m
790
791
class Mul8_16_32_64(Elaboratable):
    """Signed/Unsigned 8/16/32/64-bit partitioned integer multiplier.

    Supports partitioning into any combination of 8, 16, 32, and 64-bit
    partitions on naturally-aligned boundaries. Supports the operation being
    set for each partition independently.

    :attribute part_pts: the input partition points. Has a partition point at
        multiples of 8 in 0 < i < 64. Each partition point's associated
        ``Value`` is a ``Signal``. Modification not supported, except for by
        ``Signal.eq``.
    :attribute part_ops: the operation for each byte. The operation for a
        particular partition is selected by assigning the selected operation
        code to each byte in the partition. The allowed operation codes are
        the module-level constants:

        :attribute OP_MUL_LOW: the LSB half of the product. Equivalent to
            RISC-V's `mul` instruction.
        :attribute OP_MUL_SIGNED_HIGH: the MSB half of the product where both
            ``a`` and ``b`` are signed. Equivalent to RISC-V's `mulh`
            instruction.
        :attribute OP_MUL_SIGNED_UNSIGNED_HIGH: the MSB half of the product
            where ``a`` is signed and ``b`` is unsigned. Equivalent to RISC-V's
            `mulhsu` instruction.
        :attribute OP_MUL_UNSIGNED_HIGH: the MSB half of the product where both
            ``a`` and ``b`` are unsigned. Equivalent to RISC-V's `mulhu`
            instruction.
    :attribute a: the first 64-bit operand.
    :attribute b: the second 64-bit operand.
    :attribute output: the 64-bit result, assembled per byte by ``FinalOut``.
    """

    def __init__(self, register_levels=()):
        """ register_levels: specifies the points in the cascade at which
        flip-flops are to be inserted.
        """

        # parameter(s)
        self.register_levels = list(register_levels)

        # inputs
        self.part_pts = PartitionPoints()
        for i in range(8, 64, 8):
            self.part_pts[i] = Signal(name=f"part_pts_{i}")
        self.part_ops = [Signal(2, name=f"part_ops_{i}") for i in range(8)]
        self.a = Signal(64)
        self.b = Signal(64)

        # intermediates (needed for unit tests)
        self._intermediate_output = Signal(128)

        # output
        self.output = Signal(64)

    def _part_byte(self, index):
        # partition point just above byte ``index`` (bit (index + 1) * 8);
        # index 7 (the top) and -1 (the bottom) are constant-1 boundaries.
        if index == -1 or index == 7:
            return C(True, 1)
        assert index >= 0 and index < 8
        return self.part_pts[index * 8 + 8]

    def elaborate(self, platform):
        m = Module()

        # collect part-bytes: pbs bit i is the partition point above byte i
        # (bit 7 is constant 1)
        pbs = Signal(8, reset_less=True)
        tl = []
        for i in range(8):
            pb = Signal(name="pb%d" % i, reset_less=True)
            m.d.comb += pb.eq(self._part_byte(i))
            tl.append(pb)
        m.d.comb += pbs.eq(Cat(*tl))

        # local variables: one signedness decoder per byte's operation
        signs = []
        for i in range(8):
            s = Signs()
            signs.append(s)
            setattr(m.submodules, "signs%d" % i, s)
            m.d.comb += s.part_ops.eq(self.part_ops[i])

        # part_ops delayed by one clock per register level; the last entry
        # presumably lines up with the registered AddReduce output
        delayed_part_ops = [
            [Signal(2, name=f"_delayed_part_ops_{delay}_{i}")
             for i in range(8)]
            for delay in range(1 + len(self.register_levels))]
        for i in range(len(self.part_ops)):
            m.d.comb += delayed_part_ops[0][i].eq(self.part_ops[i])
            m.d.sync += [delayed_part_ops[j + 1][i].eq(delayed_part_ops[j][i])
                         for j in range(len(self.register_levels))]

        # one Part per partition size (8, 4, 2, 1 partitions of the 8 bytes):
        # partition detection plus the signed-correction terms
        n_levels = len(self.register_levels)+1
        m.submodules.part_8 = part_8 = Part(128, 8, n_levels, 8)
        m.submodules.part_16 = part_16 = Part(128, 4, n_levels, 8)
        m.submodules.part_32 = part_32 = Part(128, 2, n_levels, 8)
        m.submodules.part_64 = part_64 = Part(128, 1, n_levels, 8)
        nat_l, nbt_l, nla_l, nlb_l = [], [], [], []
        for mod in [part_8, part_16, part_32, part_64]:
            m.d.comb += mod.a.eq(self.a)
            m.d.comb += mod.b.eq(self.b)
            for i in range(len(signs)):
                m.d.comb += mod.a_signed[i].eq(signs[i].a_signed)
                m.d.comb += mod.b_signed[i].eq(signs[i].b_signed)
            m.d.comb += mod.pbs.eq(pbs)
            nat_l.append(mod.not_a_term)
            nbt_l.append(mod.not_b_term)
            nla_l.append(mod.neg_lsb_a_term)
            nlb_l.append(mod.neg_lsb_b_term)

        terms = []

        # 8 x 8 byte-by-byte partial products
        for a_index in range(8):
            t = ProductTerms(8, 128, 8, a_index, 8)
            setattr(m.submodules, "terms_%d" % a_index, t)

            m.d.comb += t.a.eq(self.a)
            m.d.comb += t.b.eq(self.b)
            m.d.comb += t.pb_en.eq(pbs)

            for term in t.terms:
                terms.append(term)

        # it's fine to bitwise-or data together since they are never enabled
        # at the same time
        m.submodules.nat_or = nat_or = OrMod(128)
        m.submodules.nbt_or = nbt_or = OrMod(128)
        m.submodules.nla_or = nla_or = OrMod(128)
        m.submodules.nlb_or = nlb_or = OrMod(128)
        for l, mod in [(nat_l, nat_or),
                       (nbt_l, nbt_or),
                       (nla_l, nla_or),
                       (nlb_l, nlb_or)]:
            for i in range(len(l)):
                m.d.comb += mod.orin[i].eq(l[i])
            terms.append(mod.orout)

        # the 128-bit product has its partition points at twice the
        # operand positions
        expanded_part_pts = PartitionPoints()
        for i, v in self.part_pts.items():
            signal = Signal(name=f"expanded_part_pts_{i*2}", reset_less=True)
            expanded_part_pts[i * 2] = signal
            m.d.comb += signal.eq(v)

        add_reduce = AddReduce(terms,
                               128,
                               self.register_levels,
                               expanded_part_pts)
        m.submodules.add_reduce = add_reduce
        m.d.comb += self._intermediate_output.eq(add_reduce.output)
        # create _output_64
        m.submodules.io64 = io64 = IntermediateOut(64, 128, 1)
        m.d.comb += io64.intermed.eq(self._intermediate_output)
        for i in range(8):
            m.d.comb += io64.delayed_part_ops[i].eq(delayed_part_ops[-1][i])

        # create _output_32
        m.submodules.io32 = io32 = IntermediateOut(32, 128, 2)
        m.d.comb += io32.intermed.eq(self._intermediate_output)
        for i in range(8):
            m.d.comb += io32.delayed_part_ops[i].eq(delayed_part_ops[-1][i])

        # create _output_16
        m.submodules.io16 = io16 = IntermediateOut(16, 128, 4)
        m.d.comb += io16.intermed.eq(self._intermediate_output)
        for i in range(8):
            m.d.comb += io16.delayed_part_ops[i].eq(delayed_part_ops[-1][i])

        # create _output_8
        m.submodules.io8 = io8 = IntermediateOut(8, 128, 8)
        m.d.comb += io8.intermed.eq(self._intermediate_output)
        for i in range(8):
            m.d.comb += io8.delayed_part_ops[i].eq(delayed_part_ops[-1][i])

        # final output: per-byte selection between the four widths, driven by
        # the (delayed) partition-detection flags of each Part
        m.submodules.finalout = finalout = FinalOut(64)
        for i in range(len(part_8.delayed_parts[-1])):
            m.d.comb += finalout.d8[i].eq(part_8.dplast[i])
        for i in range(len(part_16.delayed_parts[-1])):
            m.d.comb += finalout.d16[i].eq(part_16.dplast[i])
        for i in range(len(part_32.delayed_parts[-1])):
            m.d.comb += finalout.d32[i].eq(part_32.dplast[i])
        m.d.comb += finalout.i8.eq(io8.output)
        m.d.comb += finalout.i16.eq(io16.output)
        m.d.comb += finalout.i32.eq(io32.output)
        m.d.comb += finalout.i64.eq(io64.output)
        m.d.comb += self.output.eq(finalout.out)

        return m
973
974
if __name__ == "__main__":
    # generate RTL for the multiplier, exposing operands, result, the
    # per-byte operation codes and the partition points as ports
    mul = Mul8_16_32_64()
    top_ports = [mul.a, mul.b, mul._intermediate_output, mul.output]
    top_ports.extend(mul.part_ops)
    top_ports.extend(mul.part_pts.values())
    main(mul, ports=top_ports)