[ieee754fpu.git] src/ieee754/part_mul_add/multiply.py — note: MaskedFullAdder performs the ANDing in a group by pre-shifting the carry bits.
1 # SPDX-License-Identifier: LGPL-2.1-or-later
2 # See Notices.txt for copyright information
3 """Integer Multiplication."""
4
5 from nmigen import Signal, Module, Value, Elaboratable, Cat, C, Mux, Repl
6 from nmigen.hdl.ast import Assign
7 from abc import ABCMeta, abstractmethod
8 from nmigen.cli import main
9 from functools import reduce
10 from operator import or_
11
12
class PartitionPoints(dict):
    """Partition points and corresponding ``Value``s.

    The points at where an ALU is partitioned along with ``Value``s that
    specify if the corresponding partition points are enabled.

    For example: ``{1: True, 5: True, 10: True}`` with
    ``width == 16`` specifies that the ALU is split into 4 sections:
    * bits 0 <= ``i`` < 1
    * bits 1 <= ``i`` < 5
    * bits 5 <= ``i`` < 10
    * bits 10 <= ``i`` < 16

    If the partition_points were instead ``{1: True, 5: a, 10: True}``
    where ``a`` is a 1-bit ``Signal``:
    * If ``a`` is asserted:
        * bits 0 <= ``i`` < 1
        * bits 1 <= ``i`` < 5
        * bits 5 <= ``i`` < 10
        * bits 10 <= ``i`` < 16
    * Otherwise
        * bits 0 <= ``i`` < 1
        * bits 1 <= ``i`` < 10
        * bits 10 <= ``i`` < 16
    """

    def __init__(self, partition_points=None):
        """Create a new ``PartitionPoints``.

        :param partition_points: the input partition points to values mapping.
        """
        super().__init__()
        if partition_points is None:
            return
        for point, enabled in partition_points.items():
            # validate the key before storing the wrapped enable value
            if not isinstance(point, int):
                raise TypeError("point must be a non-negative integer")
            if point < 0:
                raise ValueError("point must be a non-negative integer")
            self[point] = Value.wrap(enabled)

    def like(self, name=None, src_loc_at=0):
        """Create a new ``PartitionPoints`` with ``Signal``s for all values.

        :param name: the base name for the new ``Signal``s.
        """
        if name is None:
            # derive a base name from the variable the caller assigns to
            name = Signal(src_loc_at=1+src_loc_at).name
        result = PartitionPoints()
        for point, enabled in self.items():
            result[point] = Signal(enabled.shape(), name=f"{name}_{point}")
        return result

    def eq(self, rhs):
        """Assign ``PartitionPoints`` using ``Signal.eq``."""
        if set(self.keys()) != set(rhs.keys()):
            raise ValueError("incompatible point set")
        for point in self:
            yield self[point].eq(rhs[point])

    def as_mask(self, width):
        """Create a bit-mask from `self`.

        Each bit in the returned mask is clear only if the partition point at
        the same bit-index is enabled.

        :param width: the bit width of the resulting mask
        """
        # bit i is the inverted enable where a point exists, constant 1
        # everywhere else
        return Cat(*[~self[i] if i in self else True
                     for i in range(width)])

    def get_max_partition_count(self, width):
        """Get the maximum number of partitions.

        Gets the number of partitions when all partition points are enabled.
        """
        # one partition plus one more per point inside the width
        return 1 + sum(1 for point in self if point < width)

    def fits_in_width(self, width):
        """Check if all partition points are smaller than `width`."""
        return all(point < width for point in self)
105
106
class FullAdder(Elaboratable):
    """Full Adder.

    :attribute in0: the first input
    :attribute in1: the second input
    :attribute in2: the third input
    :attribute sum: the sum output
    :attribute carry: the carry output

    Instead of instantiating one full adder per bit (an array of them
    would be very slow to simulate), the inputs and outputs are vectors
    of the given width: every bit position performs an independent
    3:2 add, all "in parallel".
    """

    def __init__(self, width):
        """Create a ``FullAdder``.

        :param width: the bit width of the input and output
        """
        self.in0 = Signal(width)
        self.in1 = Signal(width)
        self.in2 = Signal(width)
        self.sum = Signal(width)
        self.carry = Signal(width)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        comb = m.d.comb
        # sum is the XOR of all three inputs; carry is the majority vote
        comb += self.sum.eq(self.in0 ^ self.in1 ^ self.in2)
        comb += self.carry.eq((self.in0 & self.in1) |
                              (self.in1 & self.in2) |
                              (self.in2 & self.in0))
        return m
141
142
class MaskedFullAdder(Elaboratable):
    """Masked Full Adder.

    :attribute mask: the carry partition mask
    :attribute in0: the first input
    :attribute in1: the second input
    :attribute in2: the third input
    :attribute sum: the sum output
    :attribute mcarry: the masked carry output

    FullAdders are always used with a "mask" on the output.  To keep
    the graphviz "clean", the masking is performed here rather than
    inside a large for-loop.

    See the following discussion as to why this is no longer derived
    from FullAdder.  Each carry is shifted here *before* being ANDed
    with the mask, so that an AOI cell may be used (which is more
    gate-efficient)
    https://en.wikipedia.org/wiki/AND-OR-Invert
    https://groups.google.com/d/msg/comp.arch/fcq-GLQqvas/vTxmcA0QAgAJ
    """

    def __init__(self, width):
        """Create a ``MaskedFullAdder``.

        :param width: the bit width of the input and output
        """
        self.width = width
        self.mask = Signal(width, reset_less=True)
        self.mcarry = Signal(width, reset_less=True)
        self.in0 = Signal(width, reset_less=True)
        self.in1 = Signal(width, reset_less=True)
        self.in2 = Signal(width, reset_less=True)
        self.sum = Signal(width, reset_less=True)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        comb = m.d.comb
        # sum needs no masking: plain XOR of the three inputs
        comb += self.sum.eq(self.in0 ^ self.in1 ^ self.in2)
        # pre-shift each input up by one (the carry's weight)...
        shifted = []
        for idx, operand in enumerate((self.in0, self.in1, self.in2), 1):
            s = Signal(self.width, reset_less=True, name="s%d" % idx)
            comb += s.eq(Cat(0, operand))
            shifted.append(s)
        # ...then AND each pairwise carry with the partition mask
        pairs = ((shifted[0], shifted[1]),
                 (shifted[1], shifted[2]),
                 (shifted[2], shifted[0]))
        carries = []
        for idx, (x, y) in enumerate(pairs, 1):
            c = Signal(self.width, reset_less=True, name="c%d" % idx)
            comb += c.eq(x & y & self.mask)
            carries.append(c)
        comb += self.mcarry.eq(carries[0] | carries[1] | carries[2])
        return m
196
197
class PartitionedAdder(Elaboratable):
    """Partitioned Adder.

    Performs the final add.  The partition points are included in the
    actual add (in one of the operands only), which causes a carry over
    to the next bit.  Then the final output *removes* the extra bits from
    the result.

    partition: .... P... P... P... P... (32 bits)
    a        : .... .... .... .... .... (32 bits)
    b        : .... .... .... .... .... (32 bits)
    exp-a    : ....P....P....P....P.... (32+4 bits)
    exp-b    : ....0....0....0....0.... (32 bits plus 4 zeros)
    exp-o    : ....xN...xN...xN...xN... (32+4 bits)
    o        : .... N... N... N... N... (32 bits)

    :attribute width: the bit width of the input and output. Read-only.
    :attribute a: the first input to the adder
    :attribute b: the second input to the adder
    :attribute output: the sum output
    :attribute partition_points: the input partition points. Modification not
        supported, except for by ``Signal.eq``.
    """

    def __init__(self, width, partition_points):
        """Create a ``PartitionedAdder``.

        :param width: the bit width of the input and output
        :param partition_points: the input partition points
        :raises ValueError: if any partition point is >= ``width``
        """
        self.width = width
        self.a = Signal(width)
        self.b = Signal(width)
        self.output = Signal(width)
        self.partition_points = PartitionPoints(partition_points)
        if not self.partition_points.fits_in_width(width):
            raise ValueError("partition_points doesn't fit in width")
        # the expanded operands carry one extra bit per partition point;
        # that extra bit is where the cross-partition carry is absorbed
        expanded_width = 0
        for i in range(self.width):
            if i in self.partition_points:
                expanded_width += 1
            expanded_width += 1
        self._expanded_width = expanded_width
        # XXX these have to remain here due to some horrible nmigen
        # simulation bugs involving sync. it is *not* necessary to
        # have them here, they should (under normal circumstances)
        # be moved into elaborate, as they are entirely local
        self._expanded_a = Signal(expanded_width) # includes extra part-points
        self._expanded_b = Signal(expanded_width) # likewise.
        self._expanded_o = Signal(expanded_width) # likewise.

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        expanded_index = 0
        # store bits in a list, use Cat later. graphviz is much cleaner
        al, bl, ol, ea, eb, eo = [],[],[],[],[],[]

        # partition points are "breaks" (extra zeros or 1s) in what would
        # otherwise be a massive long add. when the "break" points are 0,
        # whatever is in it (in the output) is discarded. however when
        # there is a "1", it causes a roll-over carry to the *next* bit.
        # we still ignore the "break" bit in the [intermediate] output,
        # however by that time we've got the effect that we wanted: the
        # carry has been carried *over* the break point.

        for i in range(self.width):
            if i in self.partition_points:
                # add extra bit set to 0 + 0 for enabled partition points
                # and 1 + 0 for disabled partition points
                # (~enable: a disabled point must let the carry ripple
                # through, so a "1" is inserted in operand a only)
                ea.append(self._expanded_a[expanded_index])
                al.append(~self.partition_points[i]) # add extra bit in a
                eb.append(self._expanded_b[expanded_index])
                bl.append(C(0)) # yes, add a zero
                expanded_index += 1 # skip the extra point. NOT in the output
            # ordinary data bit: wired straight through in a, b and output
            ea.append(self._expanded_a[expanded_index])
            eb.append(self._expanded_b[expanded_index])
            eo.append(self._expanded_o[expanded_index])
            al.append(self.a[i])
            bl.append(self.b[i])
            ol.append(self.output[i])
            expanded_index += 1

        # combine above using Cat
        m.d.comb += Cat(*ea).eq(Cat(*al))
        m.d.comb += Cat(*eb).eq(Cat(*bl))
        m.d.comb += Cat(*ol).eq(Cat(*eo))

        # use only one addition to take advantage of look-ahead carry and
        # special hardware on FPGAs
        m.d.comb += self._expanded_o.eq(
            self._expanded_a + self._expanded_b)
        return m
291
292
# number of inputs a single FullAdder (3:2 compressor) consumes
FULL_ADDER_INPUT_COUNT = 3
294
295
class AddReduceSingle(Elaboratable):
    """Add list of numbers together.

    Performs one level of a carry-save reduction: groups of three inputs
    are compressed to two terms via ``MaskedFullAdder``; two or fewer
    inputs are summed directly with a ``PartitionedAdder``.

    :attribute inputs: input ``Signal``s to be summed. Modification not
        supported, except for by ``Signal.eq``.
    :attribute register_levels: List of nesting levels that should have
        pipeline registers.
    :attribute output: output sum.
    :attribute partition_points: the input partition points. Modification not
        supported, except for by ``Signal.eq``.
    """

    def __init__(self, inputs, output_width, register_levels, partition_points):
        """Create an ``AddReduceSingle``.

        :param inputs: input ``Signal``s to be summed.
        :param output_width: bit-width of ``output``.
        :param register_levels: List of nesting levels that should have
            pipeline registers.
        :param partition_points: the input partition points.
        :raises ValueError: if the partition points don't fit in
            ``output_width``, or a register level exceeds the adder depth.
        """
        self.inputs = list(inputs)
        self._resized_inputs = [
            Signal(output_width, name=f"resized_inputs[{i}]")
            for i in range(len(self.inputs))]
        self.register_levels = list(register_levels)
        self.output = Signal(output_width)
        self.partition_points = PartitionPoints(partition_points)
        if not self.partition_points.fits_in_width(output_width):
            raise ValueError("partition_points doesn't fit in output_width")
        self._reg_partition_points = self.partition_points.like()

        # use this class's own staticmethod (previously this forward-
        # referenced the AddReduce subclass defined later in the file,
        # which made the base class depend on its own subclass)
        max_level = AddReduceSingle.get_max_level(len(self.inputs))
        for level in self.register_levels:
            if level > max_level:
                raise ValueError(
                    "not enough adder levels for specified register levels")

    @staticmethod
    def get_max_level(input_count):
        """Get the maximum level.

        All ``register_levels`` must be less than or equal to the maximum
        level.
        """
        retval = 0
        while True:
            groups = AddReduceSingle.full_adder_groups(input_count)
            if len(groups) == 0:
                return retval
            # each group of 3 inputs becomes 2 terms (sum + carry); the
            # remainder (input_count % 3) passes through unchanged
            input_count %= FULL_ADDER_INPUT_COUNT
            input_count += 2 * len(groups)
            retval += 1

    @staticmethod
    def full_adder_groups(input_count):
        """Get ``inputs`` indices for which a full adder should be built."""
        return range(0,
                     input_count - FULL_ADDER_INPUT_COUNT + 1,
                     FULL_ADDER_INPUT_COUNT)

    def _elaborate(self, platform):
        """Elaborate this module.

        :returns: a ``(intermediate_terms, m)`` tuple, where
            ``intermediate_terms`` is ``None`` when the base case (0, 1 or
            2 inputs) has been handled completely, otherwise the list of
            terms the caller must reduce further.
        """
        m = Module()

        # resize inputs to correct bit-width and optionally add in
        # pipeline registers
        resized_input_assignments = [self._resized_inputs[i].eq(self.inputs[i])
                                     for i in range(len(self.inputs))]
        if 0 in self.register_levels:
            m.d.sync += resized_input_assignments
            m.d.sync += self._reg_partition_points.eq(self.partition_points)
        else:
            m.d.comb += resized_input_assignments
            m.d.comb += self._reg_partition_points.eq(self.partition_points)

        groups = AddReduceSingle.full_adder_groups(len(self.inputs))
        # if there are no full adders to create, then we handle the base cases
        # and return, otherwise we go on to the recursive case
        if len(groups) == 0:
            if len(self.inputs) == 0:
                # use 0 as the default output value
                m.d.comb += self.output.eq(0)
            elif len(self.inputs) == 1:
                # handle single input
                m.d.comb += self.output.eq(self._resized_inputs[0])
            else:
                # base case for adding 2 or more inputs, which get recursively
                # reduced to 2 inputs
                assert len(self.inputs) == 2
                adder = PartitionedAdder(len(self.output),
                                         self._reg_partition_points)
                m.submodules.final_adder = adder
                m.d.comb += adder.a.eq(self._resized_inputs[0])
                m.d.comb += adder.b.eq(self._resized_inputs[1])
                m.d.comb += self.output.eq(adder.output)
            return None, m

        # go on to prepare recursive case
        intermediate_terms = []

        def add_intermediate_term(value):
            # names a fresh signal for each term, keeping the graph readable
            intermediate_term = Signal(
                len(self.output),
                name=f"intermediate_terms[{len(intermediate_terms)}]")
            intermediate_terms.append(intermediate_term)
            m.d.comb += intermediate_term.eq(value)

        # store mask in intermediary (simplifies graph)
        part_mask = Signal(len(self.output), reset_less=True)
        mask = self._reg_partition_points.as_mask(len(self.output))
        m.d.comb += part_mask.eq(mask)

        # create full adders for this recursive level.
        # this shrinks N terms to 2 * (N // 3) plus the remainder
        for i in groups:
            adder_i = MaskedFullAdder(len(self.output))
            setattr(m.submodules, f"adder_{i}", adder_i)
            m.d.comb += adder_i.in0.eq(self._resized_inputs[i])
            m.d.comb += adder_i.in1.eq(self._resized_inputs[i + 1])
            m.d.comb += adder_i.in2.eq(self._resized_inputs[i + 2])
            m.d.comb += adder_i.mask.eq(part_mask)
            # add both the sum and the masked-carry to the next level.
            # 3 inputs have now been reduced to 2...
            add_intermediate_term(adder_i.sum)
            add_intermediate_term(adder_i.mcarry)
        # handle the remaining inputs.
        if len(self.inputs) % FULL_ADDER_INPUT_COUNT == 1:
            add_intermediate_term(self._resized_inputs[-1])
        elif len(self.inputs) % FULL_ADDER_INPUT_COUNT == 2:
            # Just pass the terms to the next layer, since we wouldn't gain
            # anything by using a half adder since there would still be 2 terms
            # and just passing the terms to the next layer saves gates.
            add_intermediate_term(self._resized_inputs[-2])
            add_intermediate_term(self._resized_inputs[-1])
        else:
            assert len(self.inputs) % FULL_ADDER_INPUT_COUNT == 0

        return intermediate_terms, m
434
435
class AddReduce(AddReduceSingle):
    """Recursively Add list of numbers together.

    :attribute inputs: input ``Signal``s to be summed. Modification not
        supported, except for by ``Signal.eq``.
    :attribute register_levels: List of nesting levels that should have
        pipeline registers.
    :attribute output: output sum.
    :attribute partition_points: the input partition points. Modification not
        supported, except for by ``Signal.eq``.
    """

    def __init__(self, inputs, output_width, register_levels, partition_points):
        """Create an ``AddReduce``.

        :param inputs: input ``Signal``s to be summed.
        :param output_width: bit-width of ``output``.
        :param register_levels: List of nesting levels that should have
            pipeline registers.
        :param partition_points: the input partition points.
        """
        super().__init__(inputs, output_width, register_levels,
                         partition_points)

    def next_register_levels(self):
        """``Iterable`` of ``register_levels`` for next recursive level."""
        # level 0 is consumed by this level; everything else moves down one
        return (level - 1 for level in self.register_levels if level > 0)

    def elaborate(self, platform):
        """Elaborate this module."""
        intermediate_terms, m = self._elaborate(platform)
        if intermediate_terms is None:
            # base case (0, 1 or 2 inputs) was fully handled by the parent
            return m

        # recursive invocation of ``AddReduce`` on the reduced term list
        next_level = AddReduce(intermediate_terms,
                               len(self.output),
                               self.next_register_levels(),
                               self._reg_partition_points)
        m.submodules.next_level = next_level
        m.d.comb += self.output.eq(next_level.output)
        return m
480
481
# per-partition operation codes (see the Mul8_16_32_64 docstring):
OP_MUL_LOW = 0                   # LSB half of the product
OP_MUL_SIGNED_HIGH = 1           # MSB half, both operands signed
OP_MUL_SIGNED_UNSIGNED_HIGH = 2  # a is signed, b is unsigned
OP_MUL_UNSIGNED_HIGH = 3         # MSB half, both operands unsigned
486
487
def get_term(value, shift=0, enabled=None):
    """Optionally gate *value* with *enabled* and shift it up by *shift* bits.

    :param value: the term to process
    :param shift: non-negative bit count to shift the term up by
    :param enabled: optional 1-bit enable; when given, a de-asserted
        enable forces the term to zero
    """
    gated = value if enabled is None else Mux(enabled, value, 0)
    if shift == 0:
        return gated
    # a shift up is a concatenation with zeros at the bottom
    assert shift > 0
    return Cat(Repl(C(0, 1), shift), gated)
496
497
class ProductTerm(Elaboratable):
    """ this class creates a single product term (a[..]*b[..]).
        it has a design flaw in that is the *output* that is selected,
        where the multiplication(s) are combinatorially generated
        all the time.
    """

    def __init__(self, width, twidth, pbwid, a_index, b_index):
        """Create a ``ProductTerm``.

        :param width: bit width of one partial-product operand (one byte's
            worth: 8)
        :param twidth: total bit width of the ``term`` output
        :param pbwid: number of partition-byte enable bits
        :param a_index: byte index of the ``a`` operand selected
        :param b_index: byte index of the ``b`` operand selected
        """
        self.a_index = a_index
        self.b_index = b_index
        # weight of this partial product: one byte (8 bits) per index step
        shift = 8 * (self.a_index + self.b_index)
        self.pwidth = width
        self.twidth = twidth
        self.width = width*2
        self.shift = shift

        self.ti = Signal(self.width, reset_less=True)  # raw product a*b
        self.term = Signal(twidth, reset_less=True)    # shifted/gated term
        self.a = Signal(twidth//2, reset_less=True)
        self.b = Signal(twidth//2, reset_less=True)
        self.pb_en = Signal(pbwid, reset_less=True)    # partition-byte enables

        # collect the partition bits strictly between the two byte indices:
        # if any of them is set, this cross-partition term must be disabled
        self.tl = tl = []
        min_index = min(self.a_index, self.b_index)
        max_index = max(self.a_index, self.b_index)
        for i in range(min_index, max_index):
            tl.append(self.pb_en[i])
        name = "te_%d_%d" % (self.a_index, self.b_index)
        if len(tl) > 0:
            term_enabled = Signal(name=name, reset_less=True)
        else:
            # a_index == b_index: the term never crosses a partition,
            # so no enable is needed at all
            term_enabled = None
        self.enabled = term_enabled
        self.term.name = "term_%d_%d" % (a_index, b_index) # rename

    def elaborate(self, platform):
        """Elaborate this module."""

        m = Module()
        if self.enabled is not None:
            # enabled only when no partition bit between the indices is set
            m.d.comb += self.enabled.eq(~(Cat(*self.tl).bool()))

        # select one byte from each operand, multiply, then shift/gate the
        # result into position via get_term
        bsa = Signal(self.width, reset_less=True)
        bsb = Signal(self.width, reset_less=True)
        a_index, b_index = self.a_index, self.b_index
        pwidth = self.pwidth
        m.d.comb += bsa.eq(self.a.bit_select(a_index * pwidth, pwidth))
        m.d.comb += bsb.eq(self.b.bit_select(b_index * pwidth, pwidth))
        m.d.comb += self.ti.eq(bsa * bsb)
        m.d.comb += self.term.eq(get_term(self.ti, self.shift, self.enabled))
        # NOTE(review): the block below is deliberately-disabled draft code
        # (kept as a bare string); see the TODO inside it.
        """
        #TODO: sort out width issues, get inputs a/b switched on/off.
        #data going into Muxes is 1/2 the required width

        pwidth = self.pwidth
        width = self.width
        bsa = Signal(self.twidth//2, reset_less=True)
        bsb = Signal(self.twidth//2, reset_less=True)
        asel = Signal(width, reset_less=True)
        bsel = Signal(width, reset_less=True)
        a_index, b_index = self.a_index, self.b_index
        m.d.comb += asel.eq(self.a.bit_select(a_index * pwidth, pwidth))
        m.d.comb += bsel.eq(self.b.bit_select(b_index * pwidth, pwidth))
        m.d.comb += bsa.eq(get_term(asel, self.shift, self.enabled))
        m.d.comb += bsb.eq(get_term(bsel, self.shift, self.enabled))
        m.d.comb += self.ti.eq(bsa * bsb)
        m.d.comb += self.term.eq(self.ti)
        """

        return m
567
568
class ProductTerms(Elaboratable):
    """ creates a bank of product terms. also performs the actual bit-selection
        this class is to be wrapped with a for-loop on the "a" operand.
        it creates a second-level for-loop on the "b" operand.
    """

    def __init__(self, width, twidth, pbwid, a_index, blen):
        self.a_index = a_index
        self.blen = blen
        self.pwidth = width
        self.twidth = twidth
        self.pbwid = pbwid
        self.a = Signal(twidth//2, reset_less=True)
        self.b = Signal(twidth//2, reset_less=True)
        self.pb_en = Signal(pbwid, reset_less=True)
        self.terms = [Signal(twidth, name="term%d" % i, reset_less=True)
                      for i in range(blen)]

    def elaborate(self, platform):
        """Elaborate this module: one ProductTerm per b-index."""
        m = Module()
        comb = m.d.comb

        for b_index in range(self.blen):
            term = ProductTerm(self.pwidth, self.twidth, self.pbwid,
                               self.a_index, b_index)
            setattr(m.submodules, "term_%d" % b_index, term)
            # wire the shared operands and partition enables into the
            # term, and capture its output
            comb += [term.a.eq(self.a),
                     term.b.eq(self.b),
                     term.pb_en.eq(self.pb_en),
                     self.terms[b_index].eq(term.term)]

        return m
602
class LSBNegTerm(Elaboratable):
    """Generate the extra terms needed for a signed operand.

    Signed multiply is decomposed into unsigned multiply plus correction
    terms: the width-extended 1's complement of the operand (``nt``) and
    a separate "+1" term (``nl``), both produced only when the partition
    is active, the operand's MSB is set and a signed op was requested.
    """

    def __init__(self, bit_width):
        # bit width of the input operand (outputs are twice this width)
        self.bit_width = bit_width
        self.part = Signal(reset_less=True)    # partition-active flag
        self.signed = Signal(reset_less=True)  # signed-operation request
        self.op = Signal(bit_width, reset_less=True)  # the operand itself
        self.msb = Signal(reset_less=True)     # MSB of the *other* operand
        self.nt = Signal(bit_width*2, reset_less=True)  # 1's-complement term
        self.nl = Signal(bit_width*2, reset_less=True)  # "+1" term

    def elaborate(self, platform):
        m = Module()
        comb = m.d.comb
        bit_wid = self.bit_width
        ext = Repl(0, bit_wid) # extend output to HI part

        # determine sign of each incoming number *in this partition*
        enabled = Signal(reset_less=True)
        m.d.comb += enabled.eq(self.part & self.msb & self.signed)

        # for 8-bit values: form a * 0xFF00 by using -a * 0x100, the
        # negation operation is split into a bitwise not and a +1.
        # likewise for 16, 32, and 64-bit values.

        # width-extended 1s complement if a is signed, otherwise zero
        comb += self.nt.eq(Mux(enabled, Cat(ext, ~self.op), 0))

        # add 1 if signed, otherwise add zero
        comb += self.nl.eq(Cat(ext, enabled, Repl(0, bit_wid-1)))

        return m
635
636
class Part(Elaboratable):
    """ a key class which, depending on the partitioning, will determine
        what action to take when parts of the output are signed or unsigned.

        this requires 2 pieces of data *per operand, per partition*:
        whether the MSB is HI/LO (per partition!), and whether a signed
        or unsigned operation has been *requested*.

        once that is determined, signed is basically carried out
        by splitting 2's complement into 1's complement plus one.
        1's complement is just a bit-inversion.

        the extra terms - as separate terms - are then thrown at the
        AddReduce alongside the multiplication part-results.
    """
    def __init__(self, width, n_parts, n_levels, pbwid):

        # inputs
        self.a = Signal(64)
        self.b = Signal(64)
        # per-byte signedness requests for each operand
        # NOTE(review): the b_signed name carries a stray leading
        # underscore, inconsistent with a_signed — debug-name only
        self.a_signed = [Signal(name=f"a_signed_{i}") for i in range(8)]
        self.b_signed = [Signal(name=f"_b_signed_{i}") for i in range(8)]
        self.pbs = Signal(pbwid, reset_less=True)  # partition bits

        # outputs
        self.parts = [Signal(name=f"part_{i}") for i in range(n_parts)]
        # delayed_parts[delay][i]: parts[i] after `delay` sync stages
        self.delayed_parts = [
            [Signal(name=f"delayed_part_{delay}_{i}")
             for i in range(n_parts)]
            for delay in range(n_levels)]
        # XXX REALLY WEIRD BUG - have to take a copy of the last delayed_parts
        self.dplast = [Signal(name=f"dplast_{i}")
                       for i in range(n_parts)]

        self.not_a_term = Signal(width)
        self.neg_lsb_a_term = Signal(width)
        self.not_b_term = Signal(width)
        self.neg_lsb_b_term = Signal(width)

    def elaborate(self, platform):
        m = Module()

        pbs, parts, delayed_parts = self.pbs, self.parts, self.delayed_parts
        # negated-temporary copy of partition bits
        npbs = Signal.like(pbs, reset_less=True)
        m.d.comb += npbs.eq(~pbs)
        byte_count = 8 // len(parts)
        for i in range(len(parts)):
            # a partition is "active" when its boundary bits are set and
            # no partition bit falls inside it
            pbl = []
            # NOTE(review): for i == 0 this index is -1, selecting the top
            # bit of npbs — confirm this wrap-around is intended
            pbl.append(npbs[i * byte_count - 1])
            for j in range(i * byte_count, (i + 1) * byte_count - 1):
                pbl.append(pbs[j])
            pbl.append(npbs[(i + 1) * byte_count - 1])
            value = Signal(len(pbl), name="value_%di" % i, reset_less=True)
            m.d.comb += value.eq(Cat(*pbl))
            m.d.comb += parts[i].eq(~(value).bool())
            # pipeline the part-select through the delay stages
            m.d.comb += delayed_parts[0][i].eq(parts[i])
            m.d.sync += [delayed_parts[j + 1][i].eq(delayed_parts[j][i])
                         for j in range(len(delayed_parts)-1)]
            m.d.comb += self.dplast[i].eq(delayed_parts[-1][i])

        not_a_term, neg_lsb_a_term, not_b_term, neg_lsb_b_term = \
            self.not_a_term, self.neg_lsb_a_term, \
            self.not_b_term, self.neg_lsb_b_term

        byte_width = 8 // len(parts) # byte width
        bit_wid = 8 * byte_width # bit width
        nat, nbt, nla, nlb = [], [], [], []
        for i in range(len(parts)):
            # work out bit-inverted and +1 term for a.
            pa = LSBNegTerm(bit_wid)
            setattr(m.submodules, "lnt_%d_a_%d" % (bit_wid, i), pa)
            m.d.comb += pa.part.eq(parts[i])
            m.d.comb += pa.op.eq(self.a.bit_select(bit_wid * i, bit_wid))
            m.d.comb += pa.signed.eq(self.b_signed[i * byte_width]) # yes b
            m.d.comb += pa.msb.eq(self.b[(i + 1) * bit_wid - 1]) # really, b
            nat.append(pa.nt)
            nla.append(pa.nl)

            # work out bit-inverted and +1 term for b
            pb = LSBNegTerm(bit_wid)
            setattr(m.submodules, "lnt_%d_b_%d" % (bit_wid, i), pb)
            m.d.comb += pb.part.eq(parts[i])
            m.d.comb += pb.op.eq(self.b.bit_select(bit_wid * i, bit_wid))
            m.d.comb += pb.signed.eq(self.a_signed[i * byte_width]) # yes a
            m.d.comb += pb.msb.eq(self.a[(i + 1) * bit_wid - 1]) # really, a
            nbt.append(pb.nt)
            nlb.append(pb.nl)

        # concatenate together and return all 4 results.
        m.d.comb += [not_a_term.eq(Cat(*nat)),
                     not_b_term.eq(Cat(*nbt)),
                     neg_lsb_a_term.eq(Cat(*nla)),
                     neg_lsb_b_term.eq(Cat(*nlb)),
                    ]

        return m
734
735
class IntermediateOut(Elaboratable):
    """ selects the HI/LO part of the multiplication, for a given bit-width
        the output is also reconstructed in its SIMD (partition) lanes.
    """

    def __init__(self, width, out_wid, n_parts):
        self.width = width
        self.n_parts = n_parts
        self.delayed_part_ops = [Signal(2, name="dpop%d" % i, reset_less=True)
                                 for i in range(8)]
        self.intermed = Signal(out_wid, reset_less=True)
        self.output = Signal(out_wid//2, reset_less=True)

    def elaborate(self, platform):
        m = Module()
        comb = m.d.comb

        lanes = []
        w = self.width
        sel = w // 8
        for i in range(self.n_parts):
            # OP_MUL_LOW selects the bottom half of the double-width
            # product; all other ops select the top half
            lo = self.intermed.bit_select(i * w*2, w)
            hi = self.intermed.bit_select(i * w*2 + w, w)
            lane = Signal(w, reset_less=True, name="op%d_%d" % (w, i))
            comb += lane.eq(
                Mux(self.delayed_part_ops[sel * i] == OP_MUL_LOW, lo, hi))
            lanes.append(lane)
        comb += self.output.eq(Cat(*lanes))

        return m
764
765
class FinalOut(Elaboratable):
    """ selects the final output based on the partitioning.

        each byte is selectable independently, i.e. it is possible
        that some partitions requested 8-bit computation whilst others
        requested 16 or 32 bit.
    """

    def __init__(self, out_wid):
        # inputs: per-lane width-select flags
        self.d8 = [Signal(name=f"d8_{i}", reset_less=True) for i in range(8)]
        self.d16 = [Signal(name=f"d16_{i}", reset_less=True) for i in range(4)]
        self.d32 = [Signal(name=f"d32_{i}", reset_less=True) for i in range(2)]

        # inputs: candidate results for each partition width
        self.i8 = Signal(out_wid, reset_less=True)
        self.i16 = Signal(out_wid, reset_less=True)
        self.i32 = Signal(out_wid, reset_less=True)
        self.i64 = Signal(out_wid, reset_less=True)

        # output
        self.out = Signal(out_wid, reset_less=True)

    def elaborate(self, platform):
        m = Module()
        comb = m.d.comb
        chosen = []
        for i in range(8):
            # select one of the outputs: d8 selects i8, d16 selects i16
            # d32 selects i32, and the default is i64.
            # d8 and d16 are ORed together in the first Mux
            # then the 2nd selects either i8 or i16.
            # if neither d8 nor d16 are set, d32 selects either i32 or i64.
            b8 = self.i8.bit_select(i * 8, 8)
            b16 = self.i16.bit_select(i * 8, 8)
            b32 = self.i32.bit_select(i * 8, 8)
            b64 = self.i64.bit_select(i * 8, 8)
            op = Signal(8, reset_less=True, name="op_%d" % i)
            comb += op.eq(
                Mux(self.d8[i] | self.d16[i // 2],
                    Mux(self.d8[i], b8, b16),
                    Mux(self.d32[i // 4], b32, b64)))
            chosen.append(op)
        comb += self.out.eq(Cat(*chosen))
        return m
806
807
class OrMod(Elaboratable):
    """ ORs four values together in a hierarchical tree
    """

    def __init__(self, wid):
        self.wid = wid
        self.orin = [Signal(wid, name="orin%d" % i, reset_less=True)
                     for i in range(4)]
        self.orout = Signal(wid, reset_less=True)

    def elaborate(self, platform):
        """Elaborate: two 2-input ORs feeding a final OR (a balanced tree)."""
        m = Module()
        comb = m.d.comb
        or1 = Signal(self.wid, reset_less=True)
        or2 = Signal(self.wid, reset_less=True)
        comb += [or1.eq(self.orin[0] | self.orin[1]),
                 or2.eq(self.orin[2] | self.orin[3]),
                 self.orout.eq(or1 | or2)]
        return m
826
827
class Signs(Elaboratable):
    """ determines whether a or b are signed numbers
        based on the required operation type (OP_MUL_*)
    """

    def __init__(self):
        self.part_ops = Signal(2, reset_less=True)
        self.a_signed = Signal(reset_less=True)
        self.b_signed = Signal(reset_less=True)

    def elaborate(self, platform):
        m = Module()
        comb = m.d.comb
        # a is signed for every op except mulhu; b is signed only for
        # mul (low half) and mulh (both-signed high half)
        comb += self.a_signed.eq(self.part_ops != OP_MUL_UNSIGNED_HIGH)
        comb += self.b_signed.eq((self.part_ops == OP_MUL_LOW)
                                 | (self.part_ops == OP_MUL_SIGNED_HIGH))
        return m
849
850
851 class Mul8_16_32_64(Elaboratable):
852 """Signed/Unsigned 8/16/32/64-bit partitioned integer multiplier.
853
854 Supports partitioning into any combination of 8, 16, 32, and 64-bit
855 partitions on naturally-aligned boundaries. Supports the operation being
856 set for each partition independently.
857
858 :attribute part_pts: the input partition points. Has a partition point at
859 multiples of 8 in 0 < i < 64. Each partition point's associated
860 ``Value`` is a ``Signal``. Modification not supported, except for by
861 ``Signal.eq``.
862 :attribute part_ops: the operation for each byte. The operation for a
863 particular partition is selected by assigning the selected operation
864 code to each byte in the partition. The allowed operation codes are:
865
866 :attribute OP_MUL_LOW: the LSB half of the product. Equivalent to
867 RISC-V's `mul` instruction.
868 :attribute OP_MUL_SIGNED_HIGH: the MSB half of the product where both
869 ``a`` and ``b`` are signed. Equivalent to RISC-V's `mulh`
870 instruction.
871 :attribute OP_MUL_SIGNED_UNSIGNED_HIGH: the MSB half of the product
872 where ``a`` is signed and ``b`` is unsigned. Equivalent to RISC-V's
873 `mulhsu` instruction.
874 :attribute OP_MUL_UNSIGNED_HIGH: the MSB half of the product where both
875 ``a`` and ``b`` are unsigned. Equivalent to RISC-V's `mulhu`
876 instruction.
877 """
878
    def __init__(self, register_levels=()):
        """ register_levels: specifies the points in the cascade at which
            flip-flops are to be inserted.
        """

        # parameter(s)
        self.register_levels = list(register_levels)

        # inputs: a partition point at every byte boundary (8..56)
        self.part_pts = PartitionPoints()
        for i in range(8, 64, 8):
            self.part_pts[i] = Signal(name=f"part_pts_{i}")
        # one 2-bit OP_MUL_* code per byte
        self.part_ops = [Signal(2, name=f"part_ops_{i}") for i in range(8)]
        self.a = Signal(64)
        self.b = Signal(64)

        # intermediates (needed for unit tests)
        self._intermediate_output = Signal(128)

        # output
        self.output = Signal(64)
900
    def _part_byte(self, index):
        """Return the partition-enable ``Value`` for byte boundary *index*.

        Indices -1 and 7 are the outer edges of the 64-bit operand and
        are always treated as enabled partition points.
        """
        if index == -1 or index == 7:
            return C(True, 1)
        assert index >= 0 and index < 8
        return self.part_pts[index * 8 + 8]
906
907 def elaborate(self, platform):
908 m = Module()
909
910 # collect part-bytes
911 pbs = Signal(8, reset_less=True)
912 tl = []
913 for i in range(8):
914 pb = Signal(name="pb%d" % i, reset_less=True)
915 m.d.comb += pb.eq(self._part_byte(i))
916 tl.append(pb)
917 m.d.comb += pbs.eq(Cat(*tl))
918
919 # local variables
920 signs = []
921 for i in range(8):
922 s = Signs()
923 signs.append(s)
924 setattr(m.submodules, "signs%d" % i, s)
925 m.d.comb += s.part_ops.eq(self.part_ops[i])
926
927 delayed_part_ops = [
928 [Signal(2, name=f"_delayed_part_ops_{delay}_{i}")
929 for i in range(8)]
930 for delay in range(1 + len(self.register_levels))]
931 for i in range(len(self.part_ops)):
932 m.d.comb += delayed_part_ops[0][i].eq(self.part_ops[i])
933 m.d.sync += [delayed_part_ops[j + 1][i].eq(delayed_part_ops[j][i])
934 for j in range(len(self.register_levels))]
935
936 n_levels = len(self.register_levels)+1
937 m.submodules.part_8 = part_8 = Part(128, 8, n_levels, 8)
938 m.submodules.part_16 = part_16 = Part(128, 4, n_levels, 8)
939 m.submodules.part_32 = part_32 = Part(128, 2, n_levels, 8)
940 m.submodules.part_64 = part_64 = Part(128, 1, n_levels, 8)
941 nat_l, nbt_l, nla_l, nlb_l = [], [], [], []
942 for mod in [part_8, part_16, part_32, part_64]:
943 m.d.comb += mod.a.eq(self.a)
944 m.d.comb += mod.b.eq(self.b)
945 for i in range(len(signs)):
946 m.d.comb += mod.a_signed[i].eq(signs[i].a_signed)
947 m.d.comb += mod.b_signed[i].eq(signs[i].b_signed)
948 m.d.comb += mod.pbs.eq(pbs)
949 nat_l.append(mod.not_a_term)
950 nbt_l.append(mod.not_b_term)
951 nla_l.append(mod.neg_lsb_a_term)
952 nlb_l.append(mod.neg_lsb_b_term)
953
954 terms = []
955
956 for a_index in range(8):
957 t = ProductTerms(8, 128, 8, a_index, 8)
958 setattr(m.submodules, "terms_%d" % a_index, t)
959
960 m.d.comb += t.a.eq(self.a)
961 m.d.comb += t.b.eq(self.b)
962 m.d.comb += t.pb_en.eq(pbs)
963
964 for term in t.terms:
965 terms.append(term)
966
967 # it's fine to bitwise-or data together since they are never enabled
968 # at the same time
969 m.submodules.nat_or = nat_or = OrMod(128)
970 m.submodules.nbt_or = nbt_or = OrMod(128)
971 m.submodules.nla_or = nla_or = OrMod(128)
972 m.submodules.nlb_or = nlb_or = OrMod(128)
973 for l, mod in [(nat_l, nat_or),
974 (nbt_l, nbt_or),
975 (nla_l, nla_or),
976 (nlb_l, nlb_or)]:
977 for i in range(len(l)):
978 m.d.comb += mod.orin[i].eq(l[i])
979 terms.append(mod.orout)
980
981 expanded_part_pts = PartitionPoints()
982 for i, v in self.part_pts.items():
983 signal = Signal(name=f"expanded_part_pts_{i*2}", reset_less=True)
984 expanded_part_pts[i * 2] = signal
985 m.d.comb += signal.eq(v)
986
987 add_reduce = AddReduce(terms,
988 128,
989 self.register_levels,
990 expanded_part_pts)
991 m.submodules.add_reduce = add_reduce
992 m.d.comb += self._intermediate_output.eq(add_reduce.output)
993 # create _output_64
994 m.submodules.io64 = io64 = IntermediateOut(64, 128, 1)
995 m.d.comb += io64.intermed.eq(self._intermediate_output)
996 for i in range(8):
997 m.d.comb += io64.delayed_part_ops[i].eq(delayed_part_ops[-1][i])
998
999 # create _output_32
1000 m.submodules.io32 = io32 = IntermediateOut(32, 128, 2)
1001 m.d.comb += io32.intermed.eq(self._intermediate_output)
1002 for i in range(8):
1003 m.d.comb += io32.delayed_part_ops[i].eq(delayed_part_ops[-1][i])
1004
1005 # create _output_16
1006 m.submodules.io16 = io16 = IntermediateOut(16, 128, 4)
1007 m.d.comb += io16.intermed.eq(self._intermediate_output)
1008 for i in range(8):
1009 m.d.comb += io16.delayed_part_ops[i].eq(delayed_part_ops[-1][i])
1010
1011 # create _output_8
1012 m.submodules.io8 = io8 = IntermediateOut(8, 128, 8)
1013 m.d.comb += io8.intermed.eq(self._intermediate_output)
1014 for i in range(8):
1015 m.d.comb += io8.delayed_part_ops[i].eq(delayed_part_ops[-1][i])
1016
1017 # final output
1018 m.submodules.finalout = finalout = FinalOut(64)
1019 for i in range(len(part_8.delayed_parts[-1])):
1020 m.d.comb += finalout.d8[i].eq(part_8.dplast[i])
1021 for i in range(len(part_16.delayed_parts[-1])):
1022 m.d.comb += finalout.d16[i].eq(part_16.dplast[i])
1023 for i in range(len(part_32.delayed_parts[-1])):
1024 m.d.comb += finalout.d32[i].eq(part_32.dplast[i])
1025 m.d.comb += finalout.i8.eq(io8.output)
1026 m.d.comb += finalout.i16.eq(io16.output)
1027 m.d.comb += finalout.i32.eq(io32.output)
1028 m.d.comb += finalout.i64.eq(io64.output)
1029 m.d.comb += self.output.eq(finalout.out)
1030
1031 return m
1032
1033
if __name__ == "__main__":
    # generate the design with every port exposed, including the
    # test-only intermediate product
    mul = Mul8_16_32_64()
    ports = [mul.a, mul.b, mul._intermediate_output, mul.output]
    ports.extend(mul.part_ops)
    ports.extend(mul.part_pts.values())
    main(mul, ports=ports)