e86f655b53b6ffe5d06086a36ca62f2523a7850f
[ieee754fpu.git] / src / ieee754 / part_mul_add / multiply.py
1 # SPDX-License-Identifier: LGPL-2.1-or-later
2 # See Notices.txt for copyright information
3 """Integer Multiplication."""
4
5 from nmigen import Signal, Module, Value, Elaboratable, Cat, C, Mux, Repl
6 from nmigen.hdl.ast import Assign
7 from abc import ABCMeta, abstractmethod
8 from nmigen.cli import main
9 from functools import reduce
10 from operator import or_
11
12
class PartitionPoints(dict):
    """Mapping of partition-point bit index to its enabling ``Value``.

    Each key is a bit position at which an ALU may be split, and each value
    is a ``Value`` saying whether that split is currently enabled.

    For example: ``{1: True, 5: True, 10: True}`` with
    ``width == 16`` specifies that the ALU is split into 4 sections:
    * bits 0 <= ``i`` < 1
    * bits 1 <= ``i`` < 5
    * bits 5 <= ``i`` < 10
    * bits 10 <= ``i`` < 16

    If the partition_points were instead ``{1: True, 5: a, 10: True}``
    where ``a`` is a 1-bit ``Signal``:
    * If ``a`` is asserted:
        * bits 0 <= ``i`` < 1
        * bits 1 <= ``i`` < 5
        * bits 5 <= ``i`` < 10
        * bits 10 <= ``i`` < 16
    * Otherwise
        * bits 0 <= ``i`` < 1
        * bits 1 <= ``i`` < 10
        * bits 10 <= ``i`` < 16
    """

    def __init__(self, partition_points=None):
        """Build a ``PartitionPoints`` from an optional mapping.

        :param partition_points: mapping of point (non-negative ``int``) to
            its enable value; every value is passed through ``Value.wrap``.
            ``None`` produces an empty set of points.
        :raises TypeError: if a point is not an ``int``.
        :raises ValueError: if a point is negative.
        """
        super().__init__()
        if partition_points is None:
            return
        for index, enable in partition_points.items():
            if not isinstance(index, int):
                raise TypeError("point must be a non-negative integer")
            if index < 0:
                raise ValueError("point must be a non-negative integer")
            self[index] = Value.wrap(enable)

    def like(self, name=None, src_loc_at=0):
        """Return a new ``PartitionPoints`` holding a fresh ``Signal`` per point.

        Every new ``Signal`` has the shape of the value it mirrors and is
        named ``{name}_{point}``.

        :param name: the base name for the new ``Signal``s. When omitted, the
            name is taken from a throw-away ``Signal`` created with
            ``src_loc_at`` pointing at the caller.
        :param src_loc_at: extra stack levels to skip for that name lookup.
        """
        if name is None:
            # throw-away Signal, used only to obtain the variable name
            name = Signal(src_loc_at=1+src_loc_at).name
        result = PartitionPoints()
        for index, enable in self.items():
            result[index] = Signal(enable.shape(), name=f"{name}_{index}")
        return result

    def eq(self, rhs):
        """Yield one ``.eq()`` assignment per point, driving `self` from `rhs`.

        This is a generator: the key-set check below only runs once the
        result is iterated.

        :raises ValueError: if `rhs` does not have exactly the same points.
        """
        if set(rhs.keys()) != set(self.keys()):
            raise ValueError("incompatible point set")
        for index, enable in self.items():
            yield enable.eq(rhs[index])

    def as_mask(self, width):
        """Create a bit-mask from `self`.

        Bit ``i`` of the mask is the inverse of the partition point at ``i``
        when there is one, and a constant ``True`` otherwise, so a bit is
        clear only where an enabled partition point sits.

        :param width: the bit width of the resulting mask
        """
        return Cat(*[~self[i] if i in self else True
                     for i in range(width)])

    def get_max_partition_count(self, width):
        """Get the maximum number of partitions.

        This is the partition count with every partition point enabled:
        one, plus the number of points below `width`.
        """
        return 1 + sum(1 for index in self if index < width)

    def fits_in_width(self, width):
        """Check if all partition points are smaller than `width`."""
        return all(index < width for index in self)
105
106
class FullAdder(Elaboratable):
    """Bit-parallel 3-to-2 full adder.

    :attribute in0: the first input
    :attribute in1: the second input
    :attribute in2: the third input
    :attribute sum: the sum output
    :attribute carry: the carry output

    Instead of instantiating an array of one-bit full adders (which would
    be very slow to simulate), a single instance works on `width`-bit
    operands: in effect it performs multiple Full 3-2 Add operations
    "in parallel", one per bit position.
    """

    def __init__(self, width):
        """Create a ``FullAdder``.

        :param width: the bit width of the input and output
        """
        self.in0 = Signal(width)
        self.in1 = Signal(width)
        self.in2 = Signal(width)
        self.sum = Signal(width)
        self.carry = Signal(width)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        x, y, z = self.in0, self.in1, self.in2
        # per-bit sum is the XOR of the three inputs; per-bit carry is set
        # when at least two of the three inputs are set.
        bit_sum = x ^ y ^ z
        bit_carry = (x & y) | (y & z) | (z & x)
        m.d.comb += [
            self.sum.eq(bit_sum),
            self.carry.eq(bit_carry),
        ]
        return m
141
142
class MaskedFullAdder(FullAdder):
    """Full adder whose carry output is shifted and masked.

    :attribute mask: the carry partition mask
    :attribute in0: the first input
    :attribute in1: the second input
    :attribute in2: the third input
    :attribute sum: the sum output
    :attribute mcarry: the masked carry output

    FullAdders are always used with a "mask" on the output. To keep
    the graphviz "clean", the masking is done inside this class rather
    than inside a large for-loop at the point of use.
    """

    def __init__(self, width):
        """Create a ``MaskedFullAdder``.

        :param width: the bit width of the input and output
        """
        super().__init__(width)
        self.mask = Signal(width)
        self.mcarry = Signal(width)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = super().elaborate(platform)
        # move the carry up one bit position, then apply the mask
        shifted_carry = self.carry << 1
        m.d.comb += self.mcarry.eq(shifted_carry & self.mask)
        return m
172
173
class PartitionedAdder(Elaboratable):
    """Partitioned Adder.

    Performs the final add. An extra bit is inserted at every partition
    point (in one of the operands only), which lets a carry roll over
    to the next bit. The final output then *removes* those extra bits from
    the result.

    partition: .... P... P... P... P... (32 bits)
    a        : .... .... .... .... .... (32 bits)
    b        : .... .... .... .... .... (32 bits)
    exp-a    : ....P....P....P....P.... (32+4 bits)
    exp-b    : ....0....0....0....0.... (32 bits plus 4 zeros)
    exp-o    : ....xN...xN...xN...xN... (32+4 bits)
    o        : .... N... N... N... N... (32 bits)

    :attribute width: the bit width of the input and output. Read-only.
    :attribute a: the first input to the adder
    :attribute b: the second input to the adder
    :attribute output: the sum output
    :attribute partition_points: the input partition points. Modification not
        supported, except for by ``Signal.eq``.
    """

    def __init__(self, width, partition_points):
        """Create a ``PartitionedAdder``.

        :param width: the bit width of the input and output
        :param partition_points: the input partition points
        :raises ValueError: if a partition point is not below `width`.
        """
        self.width = width
        self.a = Signal(width)
        self.b = Signal(width)
        self.output = Signal(width)
        self.partition_points = PartitionPoints(partition_points)
        if not self.partition_points.fits_in_width(width):
            raise ValueError("partition_points doesn't fit in width")
        # one bit per data bit, plus one extra bit per partition point
        extra_bits = sum(1 for i in range(self.width)
                         if i in self.partition_points)
        self._expanded_width = self.width + extra_bits
        # XXX these have to remain here due to some horrible nmigen
        # simulation bugs involving sync.  it is *not* necessary to
        # have them here, they should (under normal circumstances)
        # be moved into elaborate, as they are entirely local
        self._expanded_a = Signal(self._expanded_width)  # with part-points
        self._expanded_b = Signal(self._expanded_width)  # likewise.
        self._expanded_o = Signal(self._expanded_width)  # likewise.

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()

        # bits are gathered into lists and joined with Cat afterwards,
        # which keeps the graphviz output much cleaner.
        src_a, src_b, dst_out = [], [], []   # un-expanded side
        exp_a, exp_b, exp_out = [], [], []   # expanded side

        # partition points are "breaks" (extra zeros or 1s) in what would
        # otherwise be a massive long add.  when the "break" points are 0,
        # whatever is in it (in the output) is discarded.  however when
        # there is a "1", it causes a roll-over carry to the *next* bit.
        # we still ignore the "break" bit in the [intermediate] output,
        # however by that time we've got the effect that we wanted: the
        # carry has been carried *over* the break point.
        pos = 0  # index into the expanded signals
        for i in range(self.width):
            if i in self.partition_points:
                # extra bit: 0 + 0 for an enabled partition point,
                # 1 + 0 for a disabled one
                exp_a.append(self._expanded_a[pos])
                src_a.append(~self.partition_points[i])
                exp_b.append(self._expanded_b[pos])
                src_b.append(C(0))
                pos += 1  # the extra bit is NOT copied to the output
            exp_a.append(self._expanded_a[pos])
            src_a.append(self.a[i])
            exp_b.append(self._expanded_b[pos])
            src_b.append(self.b[i])
            exp_out.append(self._expanded_o[pos])
            dst_out.append(self.output[i])
            pos += 1

        m.d.comb += [
            # scatter a and b into the expanded operands, gather the output
            Cat(*exp_a).eq(Cat(*src_a)),
            Cat(*exp_b).eq(Cat(*src_b)),
            Cat(*dst_out).eq(Cat(*exp_out)),
            # use only one addition to take advantage of look-ahead carry
            # and special hardware on FPGAs
            self._expanded_o.eq(self._expanded_a + self._expanded_b),
        ]
        return m
267
268
# Number of inputs taken by one full adder.  The add-reduce code below uses
# it as the group size when splitting its inputs into full-adder groups.
FULL_ADDER_INPUT_COUNT = 3
270
271
class AddReduceSingle(Elaboratable):
    """Add list of numbers together: one level of the reduction.

    ``_elaborate`` either produces the final sum (for 0, 1 or 2 inputs) or
    reduces every group of three inputs to two terms and hands the reduced
    list back to the caller.

    :attribute inputs: input ``Signal``s to be summed. Modification not
        supported, except for by ``Signal.eq``.
    :attribute register_levels: List of nesting levels that should have
        pipeline registers.
    :attribute output: output sum.
    :attribute partition_points: the input partition points. Modification not
        supported, except for by ``Signal.eq``.
    """

    def __init__(self, inputs, output_width, register_levels, partition_points):
        """Create an ``AddReduceSingle``.

        :param inputs: input ``Signal``s to be summed.
        :param output_width: bit-width of ``output``.
        :param register_levels: List of nesting levels that should have
            pipeline registers.
        :param partition_points: the input partition points.
        :raises ValueError: if a partition point does not fit in
            `output_width`, or a register level exceeds the maximum level
            for this number of inputs.
        """
        self.inputs = list(inputs)
        self._resized_inputs = [
            Signal(output_width, name=f"resized_inputs[{i}]")
            for i in range(len(self.inputs))]
        self.register_levels = list(register_levels)
        self.output = Signal(output_width)
        self.partition_points = PartitionPoints(partition_points)
        if not self.partition_points.fits_in_width(output_width):
            raise ValueError("partition_points doesn't fit in output_width")
        self._reg_partition_points = self.partition_points.like()

        # use this class's own helper rather than reaching into a subclass
        max_level = AddReduceSingle.get_max_level(len(self.inputs))
        for level in self.register_levels:
            if level > max_level:
                raise ValueError(
                    "not enough adder levels for specified register levels")

    @staticmethod
    def get_max_level(input_count):
        """Get the maximum level.

        All ``register_levels`` must be less than or equal to the maximum
        level.

        :param input_count: number of inputs at the first level.
        :returns: how many reduction rounds contain at least one full adder.
        """
        retval = 0
        while True:
            groups = AddReduceSingle.full_adder_groups(input_count)
            if len(groups) == 0:
                return retval
            # each group of 3 becomes 2 terms; left-over inputs pass through
            input_count %= FULL_ADDER_INPUT_COUNT
            input_count += 2 * len(groups)
            retval += 1

    @staticmethod
    def full_adder_groups(input_count):
        """Get ``inputs`` indices for which a full adder should be built.

        :returns: a ``range`` of the starting index of every complete group
            of ``FULL_ADDER_INPUT_COUNT`` inputs.
        """
        return range(0,
                     input_count - FULL_ADDER_INPUT_COUNT + 1,
                     FULL_ADDER_INPUT_COUNT)

    def _elaborate(self, platform):
        """Elaborate this module.

        :returns: a tuple ``(intermediate_terms, m)``. ``intermediate_terms``
            is ``None`` when this level produced the final ``output`` (fewer
            than three inputs); otherwise it is the list of reduced terms
            that still have to be added together.
        """
        m = Module()

        # resize inputs to correct bit-width and optionally add in
        # pipeline registers
        resized_input_assignments = [self._resized_inputs[i].eq(self.inputs[i])
                                     for i in range(len(self.inputs))]
        if 0 in self.register_levels:
            m.d.sync += resized_input_assignments
            m.d.sync += self._reg_partition_points.eq(self.partition_points)
        else:
            m.d.comb += resized_input_assignments
            m.d.comb += self._reg_partition_points.eq(self.partition_points)

        groups = AddReduceSingle.full_adder_groups(len(self.inputs))
        # if there are no full adders to create, then we handle the base cases
        # and return, otherwise we go on to the recursive case
        if len(groups) == 0:
            if len(self.inputs) == 0:
                # use 0 as the default output value
                m.d.comb += self.output.eq(0)
            elif len(self.inputs) == 1:
                # handle single input
                m.d.comb += self.output.eq(self._resized_inputs[0])
            else:
                # base case for adding 2 or more inputs, which get recursively
                # reduced to 2 inputs
                assert len(self.inputs) == 2
                adder = PartitionedAdder(len(self.output),
                                         self._reg_partition_points)
                m.submodules.final_adder = adder
                m.d.comb += adder.a.eq(self._resized_inputs[0])
                m.d.comb += adder.b.eq(self._resized_inputs[1])
                m.d.comb += self.output.eq(adder.output)
            return None, m

        # go on to prepare recursive case
        intermediate_terms = []

        def add_intermediate_term(value):
            """Store `value` in a new named Signal and record that Signal."""
            intermediate_term = Signal(
                len(self.output),
                name=f"intermediate_terms[{len(intermediate_terms)}]")
            intermediate_terms.append(intermediate_term)
            m.d.comb += intermediate_term.eq(value)

        # store mask in intermediary (simplifies graph)
        part_mask = Signal(len(self.output), reset_less=True)
        mask = self._reg_partition_points.as_mask(len(self.output))
        m.d.comb += part_mask.eq(mask)

        # create full adders for this recursive level.
        # this shrinks N terms to 2 * (N // 3) plus the remainder
        for i in groups:
            adder_i = MaskedFullAdder(len(self.output))
            setattr(m.submodules, f"adder_{i}", adder_i)
            m.d.comb += adder_i.in0.eq(self._resized_inputs[i])
            m.d.comb += adder_i.in1.eq(self._resized_inputs[i + 1])
            m.d.comb += adder_i.in2.eq(self._resized_inputs[i + 2])
            m.d.comb += adder_i.mask.eq(part_mask)
            # add both the sum and the masked-carry to the next level.
            # 3 inputs have now been reduced to 2...
            add_intermediate_term(adder_i.sum)
            add_intermediate_term(adder_i.mcarry)
        # handle the remaining inputs.
        if len(self.inputs) % FULL_ADDER_INPUT_COUNT == 1:
            add_intermediate_term(self._resized_inputs[-1])
        elif len(self.inputs) % FULL_ADDER_INPUT_COUNT == 2:
            # Just pass the terms to the next layer, since we wouldn't gain
            # anything by using a half adder since there would still be 2 terms
            # and just passing the terms to the next layer saves gates.
            add_intermediate_term(self._resized_inputs[-2])
            add_intermediate_term(self._resized_inputs[-1])
        else:
            assert len(self.inputs) % FULL_ADDER_INPUT_COUNT == 0

        return intermediate_terms, m
410
411
class AddReduce(AddReduceSingle):
    """Recursively add a list of numbers together.

    One ``AddReduceSingle`` level is elaborated here; if it leaves terms
    still to be summed, a nested ``AddReduce`` is created for them.

    :attribute inputs: input ``Signal``s to be summed. Modification not
        supported, except for by ``Signal.eq``.
    :attribute register_levels: List of nesting levels that should have
        pipeline registers.
    :attribute output: output sum.
    :attribute partition_points: the input partition points. Modification not
        supported, except for by ``Signal.eq``.
    """

    def __init__(self, inputs, output_width, register_levels, partition_points):
        """Create an ``AddReduce``.

        :param inputs: input ``Signal``s to be summed.
        :param output_width: bit-width of ``output``.
        :param register_levels: List of nesting levels that should have
            pipeline registers.
        :param partition_points: the input partition points.
        """
        super().__init__(inputs, output_width, register_levels,
                         partition_points)

    def next_register_levels(self):
        """``Iterable`` of ``register_levels`` for next recursive level.

        Level 0 belongs to the current level and is dropped; every other
        level is decremented by one.
        """
        yield from (level - 1
                    for level in self.register_levels
                    if level > 0)

    def elaborate(self, platform):
        """Elaborate this module."""
        remaining_terms, m = AddReduceSingle._elaborate(self, platform)
        if remaining_terms is None:
            # this level already drove ``output``: nothing left to reduce
            return m

        # recursive invocation of ``AddReduce`` on the reduced term list
        next_level = AddReduce(remaining_terms,
                               len(self.output),
                               self.next_register_levels(),
                               self._reg_partition_points)
        m.submodules.next_level = next_level
        m.d.comb += self.output.eq(next_level.output)
        return m
456
457
# Operation codes.  They are compared against the 2-bit ``part_ops`` signals
# in ``Signs`` and ``IntermediateOut`` below.
OP_MUL_LOW = 0
OP_MUL_SIGNED_HIGH = 1
OP_MUL_SIGNED_UNSIGNED_HIGH = 2  # a is signed, b is unsigned
OP_MUL_UNSIGNED_HIGH = 3
462
463
def get_term(value, shift=0, enabled=None):
    """Optionally gate `value` and shift it left by `shift` bits.

    :param value: the term to process.
    :param shift: number of zero bits placed below `value`; must be >= 0.
    :param enabled: when not ``None``, `value` is replaced by 0 whenever
        `enabled` is false (via ``Mux``).
    :returns: `value` itself when neither gating nor shifting applies,
        otherwise the gated and/or shifted expression.
    """
    gated = value if enabled is None else Mux(enabled, value, 0)
    if shift > 0:
        # low-order zeros first: Cat places its first argument at the LSB end
        return Cat(Repl(C(0, 1), shift), gated)
    assert shift == 0
    return gated
472
473
class ProductTerm(Elaboratable):
    """Create a single product term (a[..]*b[..]).

    A `width`-bit slice of ``a`` (slice number `a_index`) is multiplied by
    a `width`-bit slice of ``b`` (slice number `b_index`); the product is
    shifted left by ``shift`` bits and driven onto ``term``.

    When `a_index` and `b_index` differ, the term is zeroed unless all
    bits of ``pb_en`` in ``range(min_index, max_index)`` are clear.

    It has a design flaw in that it is the *output* that is selected,
    whereas the multiplication(s) are combinatorially generated
    all the time.

    :attribute a: the full-width first operand (``twidth//2`` bits)
    :attribute b: the full-width second operand (``twidth//2`` bits)
    :attribute pb_en: the partition-break bits (`pbwid` bits)
    :attribute ti: the un-shifted product (``2*width`` bits)
    :attribute term: the shifted, gated product (`twidth` bits)
    :attribute enabled: the gating ``Signal``, or ``None`` when
        `a_index` equals `b_index` (the term is then never gated)
    """

    def __init__(self, width, twidth, pbwid, a_index, b_index):
        """Create a ``ProductTerm``.

        :param width: bit width of each operand slice
        :param twidth: bit width of the output ``term``
        :param pbwid: bit width of ``pb_en``
        :param a_index: which slice of ``a`` to use
        :param b_index: which slice of ``b`` to use
        """
        self.a_index = a_index
        self.b_index = b_index
        # NOTE(review): the shift uses a hard-coded 8 rather than `width`;
        # the two agree only when width == 8 -- confirm this is intended.
        shift = 8 * (self.a_index + self.b_index)
        self.pwidth = width
        self.twidth = twidth
        self.width = width*2
        self.shift = shift

        self.ti = Signal(self.width, reset_less=True)
        self.term = Signal(twidth, reset_less=True)
        self.a = Signal(twidth//2, reset_less=True)
        self.b = Signal(twidth//2, reset_less=True)
        self.pb_en = Signal(pbwid, reset_less=True)

        # partition-break bits lying between the two slice indices
        self.tl = tl = []
        min_index = min(self.a_index, self.b_index)
        max_index = max(self.a_index, self.b_index)
        for i in range(min_index, max_index):
            tl.append(self.pb_en[i])
        name = "te_%d_%d" % (self.a_index, self.b_index)
        if len(tl) > 0:
            term_enabled = Signal(name=name, reset_less=True)
        else:
            term_enabled = None
        self.enabled = term_enabled
        self.term.name = "term_%d_%d" % (a_index, b_index)  # rename

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        if self.enabled is not None:
            # enabled only when none of the collected break bits is set
            m.d.comb += self.enabled.eq(~(Cat(*self.tl).bool()))

        bsa = Signal(self.width, reset_less=True)
        bsb = Signal(self.width, reset_less=True)
        a_index, b_index = self.a_index, self.b_index
        pwidth = self.pwidth
        # pick out the operand slices (held in product-width signals)
        m.d.comb += bsa.eq(self.a.bit_select(a_index * pwidth, pwidth))
        m.d.comb += bsb.eq(self.b.bit_select(b_index * pwidth, pwidth))
        m.d.comb += self.ti.eq(bsa * bsb)
        m.d.comb += self.term.eq(get_term(self.ti, self.shift, self.enabled))

        # TODO: sort out width issues and gate the a/b *inputs* on/off
        # (apply get_term to the operand slices before multiplying) instead
        # of gating the product; data going into the Muxes would currently
        # be 1/2 the required width.

        return m
543
544
class ProductTerms(Elaboratable):
    """Bank of product terms for one slice of the "a" operand.

    Also performs the actual bit-selection. This class is meant to be
    wrapped in a for-loop over the "a" operand; it provides the
    second-level for-loop over the "b" operand, creating one
    ``ProductTerm`` per "b" index.

    :attribute a: the first operand (``twidth//2`` bits)
    :attribute b: the second operand (``twidth//2`` bits)
    :attribute pb_en: the partition-break bits (`pbwid` bits)
    :attribute terms: list of `blen` output terms, each `twidth` bits
    """

    def __init__(self, width, twidth, pbwid, a_index, blen):
        """Create a ``ProductTerms``.

        :param width: bit width of each operand slice
        :param twidth: bit width of each output term
        :param pbwid: bit width of ``pb_en``
        :param a_index: which slice of ``a`` this bank handles
        :param blen: number of "b" slices (and of output terms)
        """
        self.a_index = a_index
        self.blen = blen
        self.pwidth = width
        self.twidth = twidth
        self.pbwid = pbwid
        self.a = Signal(twidth//2, reset_less=True)
        self.b = Signal(twidth//2, reset_less=True)
        self.pb_en = Signal(pbwid, reset_less=True)
        self.terms = [
            Signal(twidth, name="term%d" % i, reset_less=True)
            for i in range(blen)
        ]

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()

        for b_index, out_term in enumerate(self.terms):
            product = ProductTerm(self.pwidth, self.twidth, self.pbwid,
                                  self.a_index, b_index)
            setattr(m.submodules, "term_%d" % b_index, product)

            # every product term sees the complete operands and break bits
            m.d.comb += product.a.eq(self.a)
            m.d.comb += product.b.eq(self.b)
            m.d.comb += product.pb_en.eq(self.pb_en)

            m.d.comb += out_term.eq(product.term)

        return m
578
class LSBNegTerm(Elaboratable):
    """Produce the "bitwise-not" and "+1" terms of a negated operand.

    Both outputs are non-zero only when ``part``, ``msb`` and ``signed``
    are all set; they are placed in the upper `bit_width` bits of a
    ``2*bit_width``-bit result.

    :attribute part: 1-bit enable for this partition
    :attribute signed: 1-bit "signed operation" flag
    :attribute op: the `bit_width`-bit operand
    :attribute msb: 1-bit sign (most-significant) bit
    :attribute nt: output: inverted ``op`` in the upper half, or zero
    :attribute nl: output: a single 1 at bit `bit_width`, or zero
    """

    def __init__(self, bit_width):
        """Create an ``LSBNegTerm``.

        :param bit_width: the bit width of the operand
        """
        self.bit_width = bit_width
        self.part = Signal(reset_less=True)
        self.signed = Signal(reset_less=True)
        self.op = Signal(bit_width, reset_less=True)
        self.msb = Signal(reset_less=True)
        self.nt = Signal(bit_width*2, reset_less=True)
        self.nl = Signal(bit_width*2, reset_less=True)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        bit_wid = self.bit_width
        ext = Repl(0, bit_wid)  # extend output to HI part

        # determine sign of each incoming number *in this partition*
        enabled = Signal(reset_less=True)

        # for 8-bit values: form a * 0xFF00 by using -a * 0x100, the
        # negation operation is split into a bitwise not and a +1.
        # likewise for 16, 32, and 64-bit values.
        m.d.comb += [
            enabled.eq(self.part & self.msb & self.signed),
            # width-extended 1s complement if a is signed, otherwise zero
            self.nt.eq(Mux(enabled, Cat(ext, ~self.op), 0)),
            # add 1 if signed, otherwise add zero
            self.nl.eq(Cat(ext, enabled, Repl(0, bit_wid-1))),
        ]

        return m
611
612
class Part(Elaboratable):
    """Work out the sign-correction terms for one partition size.

    A key class which, depending on the partitioning, determines
    what action to take when parts of the output are signed or unsigned.

    This requires 2 pieces of data *per operand, per partition*:
    whether the MSB is HI/LO (per partition!), and whether a signed
    or unsigned operation has been *requested*.

    Once that is determined, signed is basically carried out
    by splitting 2's complement into 1's complement plus one.
    1's complement is just a bit-inversion.

    The extra terms - as separate terms - are then thrown at the
    AddReduce alongside the multiplication part-results.
    """

    def __init__(self, width, n_parts, n_levels, pbwid):
        """Create a ``Part``.

        :param width: bit width of the four output terms
        :param n_parts: number of partitions the 8 bytes are divided into
        :param n_levels: number of entries in the ``delayed_parts`` chain
        :param pbwid: bit width of ``pbs``
        """
        # inputs
        self.a = Signal(64)
        self.b = Signal(64)
        self.a_signed = [Signal(name=f"a_signed_{i}") for i in range(8)]
        self.b_signed = [Signal(name=f"_b_signed_{i}") for i in range(8)]
        self.pbs = Signal(pbwid, reset_less=True)

        # outputs
        self.parts = [Signal(name=f"part_{i}") for i in range(n_parts)]
        self.delayed_parts = [
            [Signal(name=f"delayed_part_{delay}_{i}")
             for i in range(n_parts)]
            for delay in range(n_levels)
        ]
        # XXX REALLY WEIRD BUG - have to take a copy of the last delayed_parts
        self.dplast = [Signal(name=f"dplast_{i}")
                       for i in range(n_parts)]

        self.not_a_term = Signal(width)
        self.neg_lsb_a_term = Signal(width)
        self.not_b_term = Signal(width)
        self.neg_lsb_b_term = Signal(width)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()

        pbs, parts, delayed_parts = self.pbs, self.parts, self.delayed_parts
        # negated-temporary copy of partition bits
        npbs = Signal.like(pbs, reset_less=True)
        m.d.comb += npbs.eq(~pbs)
        byte_count = 8 // len(parts)
        for i in range(len(parts)):
            first = i * byte_count
            last = first + byte_count - 1
            # inverted break bit just below the partition, the break bits
            # inside it, then the inverted break bit at its top; the
            # partition is active only when all of these are zero.
            pbl = [npbs[first - 1]]
            pbl.extend(pbs[j] for j in range(first, last))
            pbl.append(npbs[last])
            value = Signal(len(pbl), name="value_%di" % i, reset_less=True)
            m.d.comb += value.eq(Cat(*pbl))
            m.d.comb += parts[i].eq(~value.bool())
            # delay chain: stage 0 follows parts, each later stage is
            # clocked from the one before it
            m.d.comb += delayed_parts[0][i].eq(parts[i])
            m.d.sync += [later[i].eq(earlier[i])
                         for earlier, later in zip(delayed_parts,
                                                   delayed_parts[1:])]
            m.d.comb += self.dplast[i].eq(delayed_parts[-1][i])

        byte_width = 8 // len(parts)  # byte width
        bit_wid = 8 * byte_width      # bit width

        def neg_term(name_fmt, i, operand, other, other_signed):
            """Add one LSBNegTerm for partition `i` of `operand`.

            The sign flag and MSB are deliberately taken from the *other*
            operand.
            """
            term = LSBNegTerm(bit_wid)
            setattr(m.submodules, name_fmt % (bit_wid, i), term)
            m.d.comb += term.part.eq(parts[i])
            m.d.comb += term.op.eq(operand.bit_select(bit_wid * i, bit_wid))
            m.d.comb += term.signed.eq(other_signed[i * byte_width])
            m.d.comb += term.msb.eq(other[(i + 1) * bit_wid - 1])
            return term

        nat, nbt, nla, nlb = [], [], [], []
        for i in range(len(parts)):
            # work out bit-inverted and +1 term for a (yes: b's sign, b's msb)
            pa = neg_term("lnt_%d_a_%d", i, self.a, self.b, self.b_signed)
            nat.append(pa.nt)
            nla.append(pa.nl)

            # work out bit-inverted and +1 term for b (yes: a's sign, a's msb)
            pb = neg_term("lnt_%d_b_%d", i, self.b, self.a, self.a_signed)
            nbt.append(pb.nt)
            nlb.append(pb.nl)

        # concatenate together and return all 4 results.
        m.d.comb += [
            self.not_a_term.eq(Cat(*nat)),
            self.not_b_term.eq(Cat(*nbt)),
            self.neg_lsb_a_term.eq(Cat(*nla)),
            self.neg_lsb_b_term.eq(Cat(*nlb)),
        ]

        return m
710
711
class IntermediateOut(Elaboratable):
    """Select the HI/LO part of the multiplication, for a given bit-width.

    The output is also reconstructed in its SIMD (partition) lanes: each
    lane occupies ``2*width`` bits of ``intermed`` and contributes `width`
    bits to ``output``.

    :attribute delayed_part_ops: 8 two-bit operation codes, one per byte
    :attribute intermed: the full-width products (`out_wid` bits)
    :attribute output: the selected halves (``out_wid//2`` bits)
    """

    def __init__(self, width, out_wid, n_parts):
        """Create an ``IntermediateOut``.

        :param width: bit width of one lane's result
        :param out_wid: bit width of ``intermed``
        :param n_parts: number of lanes
        """
        self.width = width
        self.n_parts = n_parts
        self.delayed_part_ops = [
            Signal(2, name="dpop%d" % i, reset_less=True)
            for i in range(8)
        ]
        self.intermed = Signal(out_wid, reset_less=True)
        self.output = Signal(out_wid//2, reset_less=True)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()

        w = self.width
        sel = w // 8  # bytes per lane: index step into delayed_part_ops
        ol = []
        for i in range(self.n_parts):
            lane_base = i * w * 2
            want_low = self.delayed_part_ops[sel * i] == OP_MUL_LOW
            lo_half = self.intermed.bit_select(lane_base, w)
            hi_half = self.intermed.bit_select(lane_base + w, w)
            op = Signal(w, reset_less=True, name="op%d_%d" % (w, i))
            m.d.comb += op.eq(Mux(want_low, lo_half, hi_half))
            ol.append(op)
        m.d.comb += self.output.eq(Cat(*ol))

        return m
740
741
class FinalOut(Elaboratable):
    """Select the final output based on the partitioning.

    Each byte is selectable independently, i.e. it is possible
    that some partitions requested 8-bit computation whilst others
    requested 16 or 32 bit.

    :attribute d8: 8 one-bit flags, one per byte
    :attribute d16: 4 one-bit flags, one per pair of bytes
    :attribute d32: 2 one-bit flags, one per group of four bytes
    :attribute i8: candidate result selected by ``d8``
    :attribute i16: candidate result selected by ``d16``
    :attribute i32: candidate result selected by ``d32``
    :attribute i64: candidate result used when no flag is set
    :attribute out: the assembled output
    """

    def __init__(self, out_wid):
        """Create a ``FinalOut``.

        :param out_wid: bit width of the candidate inputs and of ``out``
        """
        # inputs
        self.d8 = [Signal(name=f"d8_{i}", reset_less=True) for i in range(8)]
        self.d16 = [Signal(name=f"d16_{i}", reset_less=True) for i in range(4)]
        self.d32 = [Signal(name=f"d32_{i}", reset_less=True) for i in range(2)]

        self.i8 = Signal(out_wid, reset_less=True)
        self.i16 = Signal(out_wid, reset_less=True)
        self.i32 = Signal(out_wid, reset_less=True)
        self.i64 = Signal(out_wid, reset_less=True)

        # output
        self.out = Signal(out_wid, reset_less=True)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        ol = []
        for i in range(8):
            # select one of the outputs: d8 selects i8, d16 selects i16
            # d32 selects i32, and the default is i64.
            # d8 and d16 are ORed together in the first Mux
            # then the 2nd selects either i8 or i16.
            # if neither d8 nor d16 are set, d32 selects either i32 or i64.
            narrow = Mux(self.d8[i],
                         self.i8.bit_select(i * 8, 8),
                         self.i16.bit_select(i * 8, 8))
            wide = Mux(self.d32[i // 4],
                       self.i32.bit_select(i * 8, 8),
                       self.i64.bit_select(i * 8, 8))
            op = Signal(8, reset_less=True, name="op_%d" % i)
            m.d.comb += op.eq(Mux(self.d8[i] | self.d16[i // 2], narrow, wide))
            ol.append(op)
        m.d.comb += self.out.eq(Cat(*ol))
        return m
782
783
class OrMod(Elaboratable):
    """OR four values together in a two-level tree.

    :attribute orin: list of the four inputs
    :attribute orout: the OR of all four inputs
    """

    def __init__(self, wid):
        """Create an ``OrMod``.

        :param wid: the bit width of the inputs and the output
        """
        self.wid = wid
        self.orin = [
            Signal(wid, name="orin%d" % i, reset_less=True)
            for i in range(4)
        ]
        self.orout = Signal(wid, reset_less=True)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        in0, in1, in2, in3 = self.orin
        or1 = Signal(self.wid, reset_less=True)
        or2 = Signal(self.wid, reset_less=True)
        m.d.comb += [
            # first level: two pair-wise ORs
            or1.eq(in0 | in1),
            or2.eq(in2 | in3),
            # second level: combine the pairs
            self.orout.eq(or1 | or2),
        ]

        return m
802
803
class Signs(Elaboratable):
    """Determine whether a or b are signed numbers.

    The decision is based on the required operation type (OP_MUL_*):
    ``a`` is signed for every operation except OP_MUL_UNSIGNED_HIGH, and
    ``b`` is signed for OP_MUL_LOW and OP_MUL_SIGNED_HIGH.

    :attribute part_ops: the 2-bit operation code
    :attribute a_signed: output: ``a`` is to be treated as signed
    :attribute b_signed: output: ``b`` is to be treated as signed
    """

    def __init__(self):
        """Create a ``Signs``."""
        self.part_ops = Signal(2, reset_less=True)
        self.a_signed = Signal(reset_less=True)
        self.b_signed = Signal(reset_less=True)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()

        ops = self.part_ops
        a_is_signed = ops != OP_MUL_UNSIGNED_HIGH
        b_is_signed = (ops == OP_MUL_LOW) | (ops == OP_MUL_SIGNED_HIGH)
        m.d.comb += [
            self.a_signed.eq(a_is_signed),
            self.b_signed.eq(b_is_signed),
        ]

        return m
825
826
class Mul8_16_32_64(Elaboratable):
    """Signed/Unsigned 8/16/32/64-bit partitioned integer multiplier.

    Supports partitioning into any combination of 8, 16, 32, and 64-bit
    partitions on naturally-aligned boundaries. Supports the operation being
    set for each partition independently.

    :attribute part_pts: the input partition points. Has a partition point at
        multiples of 8 in 0 < i < 64. Each partition point's associated
        ``Value`` is a ``Signal``. Modification not supported, except for by
        ``Signal.eq``.
    :attribute part_ops: the operation for each byte. The operation for a
        particular partition is selected by assigning the selected operation
        code to each byte in the partition. The allowed operation codes are:

        :attribute OP_MUL_LOW: the LSB half of the product. Equivalent to
            RISC-V's `mul` instruction.
        :attribute OP_MUL_SIGNED_HIGH: the MSB half of the product where both
            ``a`` and ``b`` are signed. Equivalent to RISC-V's `mulh`
            instruction.
        :attribute OP_MUL_SIGNED_UNSIGNED_HIGH: the MSB half of the product
            where ``a`` is signed and ``b`` is unsigned. Equivalent to RISC-V's
            `mulhsu` instruction.
        :attribute OP_MUL_UNSIGNED_HIGH: the MSB half of the product where both
            ``a`` and ``b`` are unsigned. Equivalent to RISC-V's `mulhu`
            instruction.
    """

    def __init__(self, register_levels=()):
        """Create a ``Mul8_16_32_64``.

        :param register_levels: specifies the points in the cascade at which
            flip-flops are to be inserted.
        """
        # parameter(s)
        self.register_levels = list(register_levels)

        # inputs
        self.part_pts = PartitionPoints()
        for i in range(8, 64, 8):
            self.part_pts[i] = Signal(name=f"part_pts_{i}")
        self.part_ops = [Signal(2, name=f"part_ops_{i}") for i in range(8)]
        self.a = Signal(64)
        self.b = Signal(64)

        # intermediates (needed for unit tests)
        self._intermediate_output = Signal(128)

        # output
        self.output = Signal(64)

    def _part_byte(self, index):
        """Return the partition-break value at the top of byte `index`.

        Indices -1 and 7 (the two ends) are always a constant 1.
        """
        if index in (-1, 7):
            return C(True, 1)
        assert 0 <= index < 8
        return self.part_pts[index * 8 + 8]

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()

        # collect part-bytes
        pbs = Signal(8, reset_less=True)
        tl = []
        for i in range(8):
            pb = Signal(name="pb%d" % i, reset_less=True)
            m.d.comb += pb.eq(self._part_byte(i))
            tl.append(pb)
        m.d.comb += pbs.eq(Cat(*tl))

        # one sign decoder per byte, fed from that byte's operation code
        signs = []
        for i in range(8):
            s = Signs()
            signs.append(s)
            setattr(m.submodules, "signs%d" % i, s)
            m.d.comb += s.part_ops.eq(self.part_ops[i])

        # operation codes delayed by one clock per register level
        delayed_part_ops = [
            [Signal(2, name=f"_delayed_part_ops_{delay}_{i}")
             for i in range(8)]
            for delay in range(1 + len(self.register_levels))
        ]
        for i in range(len(self.part_ops)):
            m.d.comb += delayed_part_ops[0][i].eq(self.part_ops[i])
            m.d.sync += [later[i].eq(earlier[i])
                         for earlier, later in zip(delayed_part_ops,
                                                   delayed_part_ops[1:])]

        # sign-correction term generators, one per partition size
        n_levels = len(self.register_levels)+1
        m.submodules.part_8 = part_8 = Part(128, 8, n_levels, 8)
        m.submodules.part_16 = part_16 = Part(128, 4, n_levels, 8)
        m.submodules.part_32 = part_32 = Part(128, 2, n_levels, 8)
        m.submodules.part_64 = part_64 = Part(128, 1, n_levels, 8)
        nat_l, nbt_l, nla_l, nlb_l = [], [], [], []
        for mod in (part_8, part_16, part_32, part_64):
            m.d.comb += mod.a.eq(self.a)
            m.d.comb += mod.b.eq(self.b)
            for i, sign in enumerate(signs):
                m.d.comb += mod.a_signed[i].eq(sign.a_signed)
                m.d.comb += mod.b_signed[i].eq(sign.b_signed)
            m.d.comb += mod.pbs.eq(pbs)
            nat_l.append(mod.not_a_term)
            nbt_l.append(mod.not_b_term)
            nla_l.append(mod.neg_lsb_a_term)
            nlb_l.append(mod.neg_lsb_b_term)

        # the 8x8 byte-by-byte partial products
        terms = []
        for a_index in range(8):
            t = ProductTerms(8, 128, 8, a_index, 8)
            setattr(m.submodules, "terms_%d" % a_index, t)

            m.d.comb += t.a.eq(self.a)
            m.d.comb += t.b.eq(self.b)
            m.d.comb += t.pb_en.eq(pbs)

            terms.extend(t.terms)

        # it's fine to bitwise-or data together since they are never enabled
        # at the same time
        m.submodules.nat_or = nat_or = OrMod(128)
        m.submodules.nbt_or = nbt_or = OrMod(128)
        m.submodules.nla_or = nla_or = OrMod(128)
        m.submodules.nlb_or = nlb_or = OrMod(128)
        for sources, mod in ((nat_l, nat_or),
                             (nbt_l, nbt_or),
                             (nla_l, nla_or),
                             (nlb_l, nlb_or)):
            for i, source in enumerate(sources):
                m.d.comb += mod.orin[i].eq(source)
            terms.append(mod.orout)

        # partition points of the double-width intermediate result
        expanded_part_pts = PartitionPoints()
        for i, v in self.part_pts.items():
            signal = Signal(name=f"expanded_part_pts_{i*2}", reset_less=True)
            expanded_part_pts[i * 2] = signal
            m.d.comb += signal.eq(v)

        add_reduce = AddReduce(terms,
                               128,
                               self.register_levels,
                               expanded_part_pts)
        m.submodules.add_reduce = add_reduce
        m.d.comb += self._intermediate_output.eq(add_reduce.output)

        # create _output_64, _output_32, _output_16 and _output_8: one
        # HI/LO selector per lane width, all fed from the same intermediate
        # result and the most-delayed operation codes
        selectors = {}
        for lane_width, lane_count in ((64, 1), (32, 2), (16, 4), (8, 8)):
            io = IntermediateOut(lane_width, 128, lane_count)
            setattr(m.submodules, "io%d" % lane_width, io)
            m.d.comb += io.intermed.eq(self._intermediate_output)
            for i in range(8):
                m.d.comb += io.delayed_part_ops[i].eq(delayed_part_ops[-1][i])
            selectors[lane_width] = io

        # final output
        m.submodules.finalout = finalout = FinalOut(64)
        for i, flag in enumerate(part_8.dplast):
            m.d.comb += finalout.d8[i].eq(flag)
        for i, flag in enumerate(part_16.dplast):
            m.d.comb += finalout.d16[i].eq(flag)
        for i, flag in enumerate(part_32.dplast):
            m.d.comb += finalout.d32[i].eq(flag)
        m.d.comb += finalout.i8.eq(selectors[8].output)
        m.d.comb += finalout.i16.eq(selectors[16].output)
        m.d.comb += finalout.i32.eq(selectors[32].output)
        m.d.comb += finalout.i64.eq(selectors[64].output)
        m.d.comb += self.output.eq(finalout.out)

        return m
1008
1009
if __name__ == "__main__":
    # hand the multiplier and its ports to the nmigen command-line driver
    m = Mul8_16_32_64()
    ports = [m.a, m.b, m._intermediate_output, m.output]
    ports.extend(m.part_ops)
    ports.extend(m.part_pts.values())
    main(m, ports=ports)