removing recursion from AddReduce
src/ieee754/part_mul_add/multiply.py
1 # SPDX-License-Identifier: LGPL-2.1-or-later
2 # See Notices.txt for copyright information
3 """Integer Multiplication."""
4
5 from nmigen import Signal, Module, Value, Elaboratable, Cat, C, Mux, Repl
6 from nmigen.hdl.ast import Assign
7 from abc import ABCMeta, abstractmethod
8 from nmigen.cli import main
9 from functools import reduce
10 from operator import or_
11
12
13 class PartitionPoints(dict):
14 """Partition points and corresponding ``Value``s.
15
16 The points at where an ALU is partitioned along with ``Value``s that
17 specify if the corresponding partition points are enabled.
18
19 For example: ``{1: True, 5: True, 10: True}`` with
20 ``width == 16`` specifies that the ALU is split into 4 sections:
21 * bits 0 <= ``i`` < 1
22 * bits 1 <= ``i`` < 5
23 * bits 5 <= ``i`` < 10
24 * bits 10 <= ``i`` < 16
25
26 If the partition_points were instead ``{1: True, 5: a, 10: True}``
27 where ``a`` is a 1-bit ``Signal``:
28 * If ``a`` is asserted:
29 * bits 0 <= ``i`` < 1
30 * bits 1 <= ``i`` < 5
31 * bits 5 <= ``i`` < 10
32 * bits 10 <= ``i`` < 16
33 * Otherwise
34 * bits 0 <= ``i`` < 1
35 * bits 1 <= ``i`` < 10
36 * bits 10 <= ``i`` < 16
37 """
38
39 def __init__(self, partition_points=None):
40 """Create a new ``PartitionPoints``.
41
42 :param partition_points: the input partition points to values mapping.
43 """
44 super().__init__()
45 if partition_points is not None:
46 for point, enabled in partition_points.items():
47 if not isinstance(point, int):
48                     raise TypeError("point must be an integer")
49 if point < 0:
50 raise ValueError("point must be a non-negative integer")
51 self[point] = Value.wrap(enabled)
52
53 def like(self, name=None, src_loc_at=0):
54 """Create a new ``PartitionPoints`` with ``Signal``s for all values.
55
56 :param name: the base name for the new ``Signal``s.
57 """
58 if name is None:
59 name = Signal(src_loc_at=1+src_loc_at).name # get variable name
60 retval = PartitionPoints()
61 for point, enabled in self.items():
62 retval[point] = Signal(enabled.shape(), name=f"{name}_{point}")
63 return retval
64
65 def eq(self, rhs):
66 """Assign ``PartitionPoints`` using ``Signal.eq``."""
67 if set(self.keys()) != set(rhs.keys()):
68 raise ValueError("incompatible point set")
69 for point, enabled in self.items():
70 yield enabled.eq(rhs[point])
71
72 def as_mask(self, width):
73 """Create a bit-mask from `self`.
74
75 Each bit in the returned mask is clear only if the partition point at
76 the same bit-index is enabled.
77
78 :param width: the bit width of the resulting mask
79 """
80 bits = []
81 for i in range(width):
82 if i in self:
83 bits.append(~self[i])
84 else:
85 bits.append(True)
86 return Cat(*bits)
87
88 def get_max_partition_count(self, width):
89 """Get the maximum number of partitions.
90
91 Gets the number of partitions when all partition points are enabled.
92 """
93 retval = 1
94 for point in self.keys():
95 if point < width:
96 retval += 1
97 return retval
98
99 def fits_in_width(self, width):
100 """Check if all partition points are smaller than `width`."""
101 for point in self.keys():
102 if point >= width:
103 return False
104 return True
105
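# Editorial usage sketch (comments only): a rough illustration of how
# PartitionPoints is intended to be used, assuming a 16-bit ALU and two
# hypothetical 1-bit enable Signals `p4` and `p8` (not part of this file):
#
#     p4, p8 = Signal(name="p4"), Signal(name="p8")
#     pp = PartitionPoints({4: p4, 8: p8, 12: True})
#     assert pp.fits_in_width(16)
#     assert pp.get_max_partition_count(16) == 4
#     mask = pp.as_mask(16)     # bit i is low only where a point is enabled
#     pp_regs = pp.like(name="pp_reg")
#     m.d.sync += list(pp_regs.eq(pp))   # `m` being some enclosing Module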
106
107 class FullAdder(Elaboratable):
108 """Full Adder.
109
110 :attribute in0: the first input
111 :attribute in1: the second input
112 :attribute in2: the third input
113 :attribute sum: the sum output
114 :attribute carry: the carry output
115
116 Rather than do individual full adders (and have an array of them,
117 which would be very slow to simulate), this module can specify the
118 bit width of the inputs and outputs: in effect it performs multiple
119 Full 3-2 Add operations "in parallel".
120 """
121
122 def __init__(self, width):
123 """Create a ``FullAdder``.
124
125 :param width: the bit width of the input and output
126 """
127 self.in0 = Signal(width)
128 self.in1 = Signal(width)
129 self.in2 = Signal(width)
130 self.sum = Signal(width)
131 self.carry = Signal(width)
132
133 def elaborate(self, platform):
134 """Elaborate this module."""
135 m = Module()
136 m.d.comb += self.sum.eq(self.in0 ^ self.in1 ^ self.in2)
137 m.d.comb += self.carry.eq((self.in0 & self.in1)
138 | (self.in1 & self.in2)
139 | (self.in2 & self.in0))
140 return m
141
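# Editorial note (comments only): per bit column this is the classic 3:2
# compressor identity, evaluated `width` times in parallel:
#
#     in0[i] + in1[i] + in2[i] == sum[i] + 2 * carry[i]
#
# which can be checked with plain Python on single-bit values:
#
#     for a in (0, 1):
#         for b in (0, 1):
#             for c in (0, 1):
#                 s = a ^ b ^ c
#                 cy = (a & b) | (b & c) | (c & a)
#                 assert a + b + c == s + 2 * cy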
142
143 class MaskedFullAdder(Elaboratable):
144 """Masked Full Adder.
145
146 :attribute mask: the carry partition mask
147 :attribute in0: the first input
148 :attribute in1: the second input
149 :attribute in2: the third input
150 :attribute sum: the sum output
151 :attribute mcarry: the masked carry output
152
153 FullAdders are always used with a "mask" on the output. To keep
154 the graphviz "clean", this class performs the masking here rather
155 than inside a large for-loop.
156
157 See the following discussion as to why this is no longer derived
158 from FullAdder. Each carry is shifted here *before* being ANDed
159 with the mask, so that an AOI cell may be used (which is more
160 gate-efficient)
161 https://en.wikipedia.org/wiki/AND-OR-Invert
162 https://groups.google.com/d/msg/comp.arch/fcq-GLQqvas/vTxmcA0QAgAJ
163 """
164
165 def __init__(self, width):
166 """Create a ``MaskedFullAdder``.
167
168 :param width: the bit width of the input and output
169 """
170 self.width = width
171 self.mask = Signal(width, reset_less=True)
172 self.mcarry = Signal(width, reset_less=True)
173 self.in0 = Signal(width, reset_less=True)
174 self.in1 = Signal(width, reset_less=True)
175 self.in2 = Signal(width, reset_less=True)
176 self.sum = Signal(width, reset_less=True)
177
178 def elaborate(self, platform):
179 """Elaborate this module."""
180 m = Module()
181 s1 = Signal(self.width, reset_less=True)
182 s2 = Signal(self.width, reset_less=True)
183 s3 = Signal(self.width, reset_less=True)
184 c1 = Signal(self.width, reset_less=True)
185 c2 = Signal(self.width, reset_less=True)
186 c3 = Signal(self.width, reset_less=True)
187 m.d.comb += self.sum.eq(self.in0 ^ self.in1 ^ self.in2)
188 m.d.comb += s1.eq(Cat(0, self.in0))
189 m.d.comb += s2.eq(Cat(0, self.in1))
190 m.d.comb += s3.eq(Cat(0, self.in2))
191 m.d.comb += c1.eq(s1 & s2 & self.mask)
192 m.d.comb += c2.eq(s2 & s3 & self.mask)
193 m.d.comb += c3.eq(s3 & s1 & self.mask)
194 m.d.comb += self.mcarry.eq(c1 | c2 | c3)
195 return m
196
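# Editorial note (comments only): the Cat(0, ...) copies above pre-shift each
# input up by one bit, so the AND/OR network produces a carry that is already
# aligned with the column it feeds *and* already masked at the partition
# boundaries.  In rough integer terms (ignoring truncation at the top bit):
#
#     sum    = in0 ^ in1 ^ in2
#     mcarry = (((in0 & in1) | (in1 & in2) | (in2 & in0)) << 1) & mask
#
# so that, within each partition, in0 + in1 + in2 == sum + mcarry, and only
# `sum` and `mcarry` need to be passed on to the next reduction level.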
197
198 class PartitionedAdder(Elaboratable):
199 """Partitioned Adder.
200
201 Performs the final add. The partition points are included in the
202 actual add (in one of the operands only), which causes a carry over
203 to the next bit. Then the final output *removes* the extra bits from
204 the result.
205
206 partition: .... P... P... P... P... (32 bits)
207 a : .... .... .... .... .... (32 bits)
208 b : .... .... .... .... .... (32 bits)
209 exp-a : ....P....P....P....P.... (32+4 bits, P=1 if no partition)
210 exp-b : ....0....0....0....0.... (32 bits plus 4 zeros)
211 exp-o : ....xN...xN...xN...xN... (32+4 bits - x to be discarded)
212 o : .... N... N... N... N... (32 bits - x ignored, N is carry-over)
213
214 :attribute width: the bit width of the input and output. Read-only.
215 :attribute a: the first input to the adder
216 :attribute b: the second input to the adder
217 :attribute output: the sum output
218 :attribute partition_points: the input partition points. Modification not
219 supported, except for by ``Signal.eq``.
220 """
221
222 def __init__(self, width, partition_points):
223 """Create a ``PartitionedAdder``.
224
225 :param width: the bit width of the input and output
226 :param partition_points: the input partition points
227 """
228 self.width = width
229 self.a = Signal(width)
230 self.b = Signal(width)
231 self.output = Signal(width)
232 self.partition_points = PartitionPoints(partition_points)
233 if not self.partition_points.fits_in_width(width):
234 raise ValueError("partition_points doesn't fit in width")
235 expanded_width = 0
236 for i in range(self.width):
237 if i in self.partition_points:
238 expanded_width += 1
239 expanded_width += 1
240 self._expanded_width = expanded_width
241 # XXX these have to remain here due to some horrible nmigen
242 # simulation bugs involving sync. it is *not* necessary to
243         # have them here; they should (under normal circumstances)
244 # be moved into elaborate, as they are entirely local
245 self._expanded_a = Signal(expanded_width) # includes extra part-points
246 self._expanded_b = Signal(expanded_width) # likewise.
247 self._expanded_o = Signal(expanded_width) # likewise.
248
249 def elaborate(self, platform):
250 """Elaborate this module."""
251 m = Module()
252 expanded_index = 0
253 # store bits in a list, use Cat later. graphviz is much cleaner
254         al, bl, ol, ea, eb, eo = [], [], [], [], [], []
255
256 # partition points are "breaks" (extra zeros or 1s) in what would
257 # otherwise be a massive long add. when the "break" points are 0,
258 # whatever is in it (in the output) is discarded. however when
259 # there is a "1", it causes a roll-over carry to the *next* bit.
260 # we still ignore the "break" bit in the [intermediate] output,
261 # however by that time we've got the effect that we wanted: the
262 # carry has been carried *over* the break point.
263
264 for i in range(self.width):
265 if i in self.partition_points:
266 # add extra bit set to 0 + 0 for enabled partition points
267 # and 1 + 0 for disabled partition points
268 ea.append(self._expanded_a[expanded_index])
269 al.append(~self.partition_points[i]) # add extra bit in a
270 eb.append(self._expanded_b[expanded_index])
271 bl.append(C(0)) # yes, add a zero
272 expanded_index += 1 # skip the extra point. NOT in the output
273 ea.append(self._expanded_a[expanded_index])
274 eb.append(self._expanded_b[expanded_index])
275 eo.append(self._expanded_o[expanded_index])
276 al.append(self.a[i])
277 bl.append(self.b[i])
278 ol.append(self.output[i])
279 expanded_index += 1
280
281 # combine above using Cat
282 m.d.comb += Cat(*ea).eq(Cat(*al))
283 m.d.comb += Cat(*eb).eq(Cat(*bl))
284 m.d.comb += Cat(*ol).eq(Cat(*eo))
285
286 # use only one addition to take advantage of look-ahead carry and
287 # special hardware on FPGAs
288 m.d.comb += self._expanded_o.eq(
289 self._expanded_a + self._expanded_b)
290 return m
291
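# Editorial worked example (comments only, plain Python arithmetic): an 8-bit
# add split at bit 4, i.e. partition_points == {4: p}.  The expanded operands
# gain one extra bit at position 4: `a` contributes ~p there, `b` contributes 0.
#
#     a, b, p = 0x0F, 0x01, 1                      # partition enabled
#     exp_a = (a & 0xF) | ((~p & 1) << 4) | ((a >> 4) << 5)
#     exp_b = (b & 0xF) | ((b >> 4) << 5)
#     exp_o = exp_a + exp_b                        # one full-width add
#     lo = exp_o & 0xF                             # result bits 0..3
#     hi = (exp_o >> 5) & 0xF                      # result bits 4..7, break bit skipped
#
# With p == 1 the break bit absorbs the carry, so lo == 0 and hi == 0 (the
# lanes stay independent); with p == 0 the inserted 1 rolls the carry across,
# giving lo == 0 and hi == 1, exactly as a plain 8-bit 0x0F + 0x01 would.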
292
293 FULL_ADDER_INPUT_COUNT = 3
294
295
296 class AddReduceSingle(Elaboratable):
297 """Add list of numbers together.
298
299 :attribute inputs: input ``Signal``s to be summed. Modification not
300 supported, except for by ``Signal.eq``.
301 :attribute register_levels: List of nesting levels that should have
302 pipeline registers.
303 :attribute output: output sum.
304 :attribute partition_points: the input partition points. Modification not
305 supported, except for by ``Signal.eq``.
306 """
307
308 def __init__(self, inputs, output_width, register_levels, partition_points):
309         """Create an ``AddReduceSingle``.
310
311 :param inputs: input ``Signal``s to be summed.
312 :param output_width: bit-width of ``output``.
313 :param register_levels: List of nesting levels that should have
314 pipeline registers.
315 :param partition_points: the input partition points.
316 """
317 self.inputs = list(inputs)
318 self._resized_inputs = [
319 Signal(output_width, name=f"resized_inputs[{i}]")
320 for i in range(len(self.inputs))]
321 self.register_levels = list(register_levels)
322 self.output = Signal(output_width)
323 self.partition_points = PartitionPoints(partition_points)
324 if not self.partition_points.fits_in_width(output_width):
325 raise ValueError("partition_points doesn't fit in output_width")
326 self._reg_partition_points = self.partition_points.like()
327
328 max_level = AddReduceSingle.get_max_level(len(self.inputs))
329 for level in self.register_levels:
330 if level > max_level:
331 raise ValueError(
332 "not enough adder levels for specified register levels")
333
334 self.groups = AddReduceSingle.full_adder_groups(len(self.inputs))
335 self._intermediate_terms = []
336 if len(self.groups) != 0:
337 self.create_next_terms()
338
339 @staticmethod
340 def get_max_level(input_count):
341 """Get the maximum level.
342
343 All ``register_levels`` must be less than or equal to the maximum
344 level.
345 """
346 retval = 0
347 while True:
348 groups = AddReduceSingle.full_adder_groups(input_count)
349 if len(groups) == 0:
350 return retval
351 input_count %= FULL_ADDER_INPUT_COUNT
352 input_count += 2 * len(groups)
353 retval += 1
354
355 @staticmethod
356 def full_adder_groups(input_count):
357 """Get ``inputs`` indices for which a full adder should be built."""
358 return range(0,
359 input_count - FULL_ADDER_INPUT_COUNT + 1,
360 FULL_ADDER_INPUT_COUNT)
361
362 def elaborate(self, platform):
363 """Elaborate this module."""
364 m = Module()
365
366 # resize inputs to correct bit-width and optionally add in
367 # pipeline registers
368 resized_input_assignments = [self._resized_inputs[i].eq(self.inputs[i])
369 for i in range(len(self.inputs))]
370 if 0 in self.register_levels:
371 m.d.sync += resized_input_assignments
372 m.d.sync += self._reg_partition_points.eq(self.partition_points)
373 else:
374 m.d.comb += resized_input_assignments
375 m.d.comb += self._reg_partition_points.eq(self.partition_points)
376
377 for (value, term) in self._intermediate_terms:
378 m.d.comb += term.eq(value)
379
380 # if there are no full adders to create, then we handle the base cases
381         # and return, otherwise we wire up this level's full adders
382 if len(self.groups) == 0:
383 if len(self.inputs) == 0:
384 # use 0 as the default output value
385 m.d.comb += self.output.eq(0)
386 elif len(self.inputs) == 1:
387 # handle single input
388 m.d.comb += self.output.eq(self._resized_inputs[0])
389 else:
390 # base case for adding 2 inputs
391 assert len(self.inputs) == 2
392 adder = PartitionedAdder(len(self.output),
393 self._reg_partition_points)
394 m.submodules.final_adder = adder
395 m.d.comb += adder.a.eq(self._resized_inputs[0])
396 m.d.comb += adder.b.eq(self._resized_inputs[1])
397 m.d.comb += self.output.eq(adder.output)
398 return m
399
400 mask = self._reg_partition_points.as_mask(len(self.output))
401 m.d.comb += self.part_mask.eq(mask)
402
403 # add and link the intermediate term modules
404 for i, (iidx, adder_i) in enumerate(self.adders):
405 setattr(m.submodules, f"adder_{i}", adder_i)
406
407 m.d.comb += adder_i.in0.eq(self._resized_inputs[iidx])
408 m.d.comb += adder_i.in1.eq(self._resized_inputs[iidx + 1])
409 m.d.comb += adder_i.in2.eq(self._resized_inputs[iidx + 2])
410 m.d.comb += adder_i.mask.eq(self.part_mask)
411
412 return m
413
414 def create_next_terms(self):
415
416         # prepare the terms that will feed the next reduction level
417 intermediate_terms = []
418 _intermediate_terms = []
419
420 def add_intermediate_term(value):
421 intermediate_term = Signal(
422 len(self.output),
423 name=f"intermediate_terms[{len(intermediate_terms)}]")
424 _intermediate_terms.append((value, intermediate_term))
425 intermediate_terms.append(intermediate_term)
426
427 # store mask in intermediary (simplifies graph)
428 self.part_mask = Signal(len(self.output), reset_less=True)
429
430         # create full adders for this reduction level.
431 # this shrinks N terms to 2 * (N // 3) plus the remainder
432 self.adders = []
433 for i in self.groups:
434 adder_i = MaskedFullAdder(len(self.output))
435 self.adders.append((i, adder_i))
436 # add both the sum and the masked-carry to the next level.
437 # 3 inputs have now been reduced to 2...
438 add_intermediate_term(adder_i.sum)
439 add_intermediate_term(adder_i.mcarry)
440 # handle the remaining inputs.
441 if len(self.inputs) % FULL_ADDER_INPUT_COUNT == 1:
442 add_intermediate_term(self._resized_inputs[-1])
443 elif len(self.inputs) % FULL_ADDER_INPUT_COUNT == 2:
444             # Just pass both terms to the next layer: a half adder would
445             # still leave 2 terms, so passing them through unchanged
446             # saves gates.
447 add_intermediate_term(self._resized_inputs[-2])
448 add_intermediate_term(self._resized_inputs[-1])
449 else:
450 assert len(self.inputs) % FULL_ADDER_INPUT_COUNT == 0
451
452 self.intermediate_terms = intermediate_terms
453 self._intermediate_terms = _intermediate_terms
454
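# Editorial note (comments only): one AddReduceSingle performs a single 3:2
# compression pass, so N inputs shrink to 2*(N // 3) + (N % 3) terms.  For
# example, 9 inputs reduce as 9 -> 6 -> 4 -> 3 -> 2 across successive levels,
# after which the 2-input base case uses a PartitionedAdder:
#
#     n, levels = 9, 0
#     while n > 2:
#         n = 2 * (n // 3) + (n % 3)
#         levels += 1
#     assert levels == AddReduceSingle.get_max_level(9)   # both give 4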
455
456 class AddReduce(Elaboratable):
457     """Add list of numbers together in a chain of ``AddReduceSingle`` levels.
458
459 :attribute inputs: input ``Signal``s to be summed. Modification not
460 supported, except for by ``Signal.eq``.
461 :attribute register_levels: List of nesting levels that should have
462 pipeline registers.
463 :attribute output: output sum.
464 :attribute partition_points: the input partition points. Modification not
465 supported, except for by ``Signal.eq``.
466 """
467
468 def __init__(self, inputs, output_width, register_levels, partition_points):
469 """Create an ``AddReduce``.
470
471 :param inputs: input ``Signal``s to be summed.
472 :param output_width: bit-width of ``output``.
473 :param register_levels: List of nesting levels that should have
474 pipeline registers.
475 :param partition_points: the input partition points.
476 """
477 self.inputs = inputs
478 self.output = Signal(output_width)
479 self.output_width = output_width
480 self.register_levels = register_levels
481 self.partition_points = partition_points
482
483 self.create_levels()
484
485 @staticmethod
486 def next_register_levels(register_levels):
487         """``Iterable`` of ``register_levels`` for the next reduction level."""
488 for level in register_levels:
489 if level > 0:
490 yield level - 1
491
492 def create_levels(self):
493         """Create the chain of ``AddReduceSingle`` reduction levels."""
494
495 mods = []
496 next_levels = self.register_levels
497 partition_points = self.partition_points
498 inputs = self.inputs
499 while True:
500 next_level = AddReduceSingle(inputs, self.output_width, next_levels,
501 partition_points)
502 mods.append(next_level)
503 if len(next_level.groups) == 0:
504 break
505 next_levels = list(AddReduce.next_register_levels(next_levels))
506 partition_points = next_level._reg_partition_points
507 inputs = next_level.intermediate_terms
508
509 self.levels = mods
510
511 def elaborate(self, platform):
512 """Elaborate this module."""
513 m = Module()
514
515 for i, next_level in enumerate(self.levels):
516 setattr(m.submodules, "next_level%d" % i, next_level)
517
518 # output comes from last module
519 m.d.comb += self.output.eq(next_level.output)
520
521 return m
522
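# Editorial note (comments only): AddReduce is now iterative rather than
# recursive.  create_levels() keeps instantiating AddReduceSingle until a
# level reports no full-adder groups (2 or fewer inputs remain), threading
# each level's intermediate_terms and _reg_partition_points into the next.
# A level registers its (resized) inputs when 0 appears in its
# register_levels list, and next_register_levels() simply decrements, e.g.:
#
#     list(AddReduce.next_register_levels([1, 3]))   # -> [0, 2]
#
# so register_levels=[1, 3] causes levels 1 and 3 to register their inputs.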
523
524 OP_MUL_LOW = 0
525 OP_MUL_SIGNED_HIGH = 1
526 OP_MUL_SIGNED_UNSIGNED_HIGH = 2 # a is signed, b is unsigned
527 OP_MUL_UNSIGNED_HIGH = 3
528
529
530 def get_term(value, shift=0, enabled=None):
531 if enabled is not None:
532 value = Mux(enabled, value, 0)
533 if shift > 0:
534 value = Cat(Repl(C(0, 1), shift), value)
535 else:
536 assert shift == 0
537 return value
538
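# Editorial note (comments only): get_term() optionally gates a value to zero
# and shifts it left by concatenating `shift` zero bits below it (Cat keeps
# the width explicit, unlike a bare `<<`).  The plain-integer equivalent is:
#
#     def get_term_int(value, shift=0, enabled=True):
#         return (value if enabled else 0) << shift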
539
540 class ProductTerm(Elaboratable):
541 """ this class creates a single product term (a[..]*b[..]).
542         it has a design flaw in that it is the *output* that is selected,
543 where the multiplication(s) are combinatorially generated
544 all the time.
545 """
546
547 def __init__(self, width, twidth, pbwid, a_index, b_index):
548 self.a_index = a_index
549 self.b_index = b_index
550 shift = 8 * (self.a_index + self.b_index)
551 self.pwidth = width
552 self.twidth = twidth
553 self.width = width*2
554 self.shift = shift
555
556 self.ti = Signal(self.width, reset_less=True)
557 self.term = Signal(twidth, reset_less=True)
558 self.a = Signal(twidth//2, reset_less=True)
559 self.b = Signal(twidth//2, reset_less=True)
560 self.pb_en = Signal(pbwid, reset_less=True)
561
562 self.tl = tl = []
563 min_index = min(self.a_index, self.b_index)
564 max_index = max(self.a_index, self.b_index)
565 for i in range(min_index, max_index):
566 tl.append(self.pb_en[i])
567 name = "te_%d_%d" % (self.a_index, self.b_index)
568 if len(tl) > 0:
569 term_enabled = Signal(name=name, reset_less=True)
570 else:
571 term_enabled = None
572 self.enabled = term_enabled
573 self.term.name = "term_%d_%d" % (a_index, b_index) # rename
574
575 def elaborate(self, platform):
576
577 m = Module()
578 if self.enabled is not None:
579 m.d.comb += self.enabled.eq(~(Cat(*self.tl).bool()))
580
581 bsa = Signal(self.width, reset_less=True)
582 bsb = Signal(self.width, reset_less=True)
583 a_index, b_index = self.a_index, self.b_index
584 pwidth = self.pwidth
585 m.d.comb += bsa.eq(self.a.bit_select(a_index * pwidth, pwidth))
586 m.d.comb += bsb.eq(self.b.bit_select(b_index * pwidth, pwidth))
587 m.d.comb += self.ti.eq(bsa * bsb)
588 m.d.comb += self.term.eq(get_term(self.ti, self.shift, self.enabled))
589 """
590 #TODO: sort out width issues, get inputs a/b switched on/off.
591 #data going into Muxes is 1/2 the required width
592
593 pwidth = self.pwidth
594 width = self.width
595 bsa = Signal(self.twidth//2, reset_less=True)
596 bsb = Signal(self.twidth//2, reset_less=True)
597 asel = Signal(width, reset_less=True)
598 bsel = Signal(width, reset_less=True)
599 a_index, b_index = self.a_index, self.b_index
600 m.d.comb += asel.eq(self.a.bit_select(a_index * pwidth, pwidth))
601 m.d.comb += bsel.eq(self.b.bit_select(b_index * pwidth, pwidth))
602 m.d.comb += bsa.eq(get_term(asel, self.shift, self.enabled))
603 m.d.comb += bsb.eq(get_term(bsel, self.shift, self.enabled))
604 m.d.comb += self.ti.eq(bsa * bsb)
605 m.d.comb += self.term.eq(self.ti)
606 """
607
608 return m
609
610
611 class ProductTerms(Elaboratable):
612     """ creates a bank of product terms. also performs the actual bit-selection.
613         this class is to be wrapped with a for-loop on the "a" operand;
614         it creates a second-level for-loop on the "b" operand.
615 """
616 def __init__(self, width, twidth, pbwid, a_index, blen):
617 self.a_index = a_index
618 self.blen = blen
619 self.pwidth = width
620 self.twidth = twidth
621 self.pbwid = pbwid
622 self.a = Signal(twidth//2, reset_less=True)
623 self.b = Signal(twidth//2, reset_less=True)
624 self.pb_en = Signal(pbwid, reset_less=True)
625 self.terms = [Signal(twidth, name="term%d"%i, reset_less=True) \
626 for i in range(blen)]
627
628 def elaborate(self, platform):
629
630 m = Module()
631
632 for b_index in range(self.blen):
633 t = ProductTerm(self.pwidth, self.twidth, self.pbwid,
634 self.a_index, b_index)
635 setattr(m.submodules, "term_%d" % b_index, t)
636
637 m.d.comb += t.a.eq(self.a)
638 m.d.comb += t.b.eq(self.b)
639 m.d.comb += t.pb_en.eq(self.pb_en)
640
641 m.d.comb += self.terms[b_index].eq(t.term)
642
643 return m
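
# Editorial note (comments only): with 8-bit lanes, the full 64x64 product is
# the sum of 64 byte-sized partial products, one per (a_index, b_index) pair:
#
#     a_byte[ai] * b_byte[bi] << (8 * (ai + bi))
#
# Each ProductTerm is gated off (via pb_en, see ProductTerm.enabled) whenever
# a partition boundary lies strictly between lane ai and lane bi, so only
# products that belong to the same partition ever reach the AddReduce tree.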
644
645 class LSBNegTerm(Elaboratable):
646
647 def __init__(self, bit_width):
648 self.bit_width = bit_width
649 self.part = Signal(reset_less=True)
650 self.signed = Signal(reset_less=True)
651 self.op = Signal(bit_width, reset_less=True)
652 self.msb = Signal(reset_less=True)
653 self.nt = Signal(bit_width*2, reset_less=True)
654 self.nl = Signal(bit_width*2, reset_less=True)
655
656 def elaborate(self, platform):
657 m = Module()
658 comb = m.d.comb
659 bit_wid = self.bit_width
660 ext = Repl(0, bit_wid) # extend output to HI part
661
662 # determine sign of each incoming number *in this partition*
663 enabled = Signal(reset_less=True)
664 m.d.comb += enabled.eq(self.part & self.msb & self.signed)
665
666 # for 8-bit values: form a * 0xFF00 by using -a * 0x100, the
667 # negation operation is split into a bitwise not and a +1.
668 # likewise for 16, 32, and 64-bit values.
669
670 # width-extended 1s complement if a is signed, otherwise zero
671 comb += self.nt.eq(Mux(enabled, Cat(ext, ~self.op), 0))
672
673 # add 1 if signed, otherwise add zero
674 comb += self.nl.eq(Cat(ext, enabled, Repl(0, bit_wid-1)))
675
676 return m
677
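# Editorial worked example (comments only, plain Python arithmetic): for an
# 8-bit signed lane whose MSB is set, the sign-correction term a * 0xFF00 is
# produced as (-a) * 0x100 == ((~a) + 1) << 8, split into `nt` (the inverted
# bits, shifted into the HI half) and `nl` (the +1, also shifted):
#
#     a = 0xF0                              # MSB set, lane is signed
#     nt = (~a & 0xFF) << 8                 # one's complement, HI half
#     nl = 1 << 8                           # the "+1" of the two's complement
#     assert (nt + nl) & 0xFFFF == ((-a) << 8) & 0xFFFF
#
# Keeping `nt` and `nl` as two separate terms lets them be thrown straight
# into the AddReduce tree instead of needing a dedicated adder here.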
678
679 class Part(Elaboratable):
680 """ a key class which, depending on the partitioning, will determine
681 what action to take when parts of the output are signed or unsigned.
682
683 this requires 2 pieces of data *per operand, per partition*:
684 whether the MSB is HI/LO (per partition!), and whether a signed
685 or unsigned operation has been *requested*.
686
687         once that is determined, signed multiplication is basically carried
688         out by splitting 2's complement into 1's complement plus one.
689 1's complement is just a bit-inversion.
690
691 the extra terms - as separate terms - are then thrown at the
692 AddReduce alongside the multiplication part-results.
693 """
694 def __init__(self, width, n_parts, n_levels, pbwid):
695
696 # inputs
697 self.a = Signal(64)
698 self.b = Signal(64)
699 self.a_signed = [Signal(name=f"a_signed_{i}") for i in range(8)]
700         self.b_signed = [Signal(name=f"b_signed_{i}") for i in range(8)]
701 self.pbs = Signal(pbwid, reset_less=True)
702
703 # outputs
704 self.parts = [Signal(name=f"part_{i}") for i in range(n_parts)]
705 self.delayed_parts = [
706 [Signal(name=f"delayed_part_{delay}_{i}")
707 for i in range(n_parts)]
708 for delay in range(n_levels)]
709 # XXX REALLY WEIRD BUG - have to take a copy of the last delayed_parts
710 self.dplast = [Signal(name=f"dplast_{i}")
711 for i in range(n_parts)]
712
713 self.not_a_term = Signal(width)
714 self.neg_lsb_a_term = Signal(width)
715 self.not_b_term = Signal(width)
716 self.neg_lsb_b_term = Signal(width)
717
718 def elaborate(self, platform):
719 m = Module()
720
721 pbs, parts, delayed_parts = self.pbs, self.parts, self.delayed_parts
722 # negated-temporary copy of partition bits
723 npbs = Signal.like(pbs, reset_less=True)
724 m.d.comb += npbs.eq(~pbs)
725 byte_count = 8 // len(parts)
726 for i in range(len(parts)):
727 pbl = []
728 pbl.append(npbs[i * byte_count - 1])
729 for j in range(i * byte_count, (i + 1) * byte_count - 1):
730 pbl.append(pbs[j])
731 pbl.append(npbs[(i + 1) * byte_count - 1])
732 value = Signal(len(pbl), name="value_%di" % i, reset_less=True)
733 m.d.comb += value.eq(Cat(*pbl))
734 m.d.comb += parts[i].eq(~(value).bool())
735 m.d.comb += delayed_parts[0][i].eq(parts[i])
736 m.d.sync += [delayed_parts[j + 1][i].eq(delayed_parts[j][i])
737 for j in range(len(delayed_parts)-1)]
738 m.d.comb += self.dplast[i].eq(delayed_parts[-1][i])
739
740 not_a_term, neg_lsb_a_term, not_b_term, neg_lsb_b_term = \
741 self.not_a_term, self.neg_lsb_a_term, \
742 self.not_b_term, self.neg_lsb_b_term
743
744 byte_width = 8 // len(parts) # byte width
745 bit_wid = 8 * byte_width # bit width
746 nat, nbt, nla, nlb = [], [], [], []
747 for i in range(len(parts)):
748 # work out bit-inverted and +1 term for a.
749 pa = LSBNegTerm(bit_wid)
750 setattr(m.submodules, "lnt_%d_a_%d" % (bit_wid, i), pa)
751 m.d.comb += pa.part.eq(parts[i])
752 m.d.comb += pa.op.eq(self.a.bit_select(bit_wid * i, bit_wid))
753 m.d.comb += pa.signed.eq(self.b_signed[i * byte_width]) # yes b
754 m.d.comb += pa.msb.eq(self.b[(i + 1) * bit_wid - 1]) # really, b
755 nat.append(pa.nt)
756 nla.append(pa.nl)
757
758 # work out bit-inverted and +1 term for b
759 pb = LSBNegTerm(bit_wid)
760 setattr(m.submodules, "lnt_%d_b_%d" % (bit_wid, i), pb)
761 m.d.comb += pb.part.eq(parts[i])
762 m.d.comb += pb.op.eq(self.b.bit_select(bit_wid * i, bit_wid))
763 m.d.comb += pb.signed.eq(self.a_signed[i * byte_width]) # yes a
764 m.d.comb += pb.msb.eq(self.a[(i + 1) * bit_wid - 1]) # really, a
765 nbt.append(pb.nt)
766 nlb.append(pb.nl)
767
768 # concatenate together and return all 4 results.
769 m.d.comb += [not_a_term.eq(Cat(*nat)),
770 not_b_term.eq(Cat(*nbt)),
771 neg_lsb_a_term.eq(Cat(*nla)),
772 neg_lsb_b_term.eq(Cat(*nlb)),
773 ]
774
775 return m
776
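# Editorial note (comments only): Part.parts[i] asserts when byte-group i is
# exactly one partition of this Part's width: the partition bits at both
# edges of the group must be set (their inverted copies in npbs read 0) and
# no partition bit may be set inside the group.  For i == 0 the expression
# npbs[i * byte_count - 1] indexes npbs[-1], i.e. the inverse of pbs[7],
# which Mul8_16_32_64 ties permanently high, so it conveniently doubles as
# the implicit, always-present boundary below lane 0.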
777
778 class IntermediateOut(Elaboratable):
779     """ selects the HI/LO part of the multiplication for a given bit-width.
780         the output is also reconstructed in its SIMD (partition) lanes.
781 """
782 def __init__(self, width, out_wid, n_parts):
783 self.width = width
784 self.n_parts = n_parts
785 self.delayed_part_ops = [Signal(2, name="dpop%d" % i, reset_less=True)
786 for i in range(8)]
787 self.intermed = Signal(out_wid, reset_less=True)
788 self.output = Signal(out_wid//2, reset_less=True)
789
790 def elaborate(self, platform):
791 m = Module()
792
793 ol = []
794 w = self.width
795 sel = w // 8
796 for i in range(self.n_parts):
797 op = Signal(w, reset_less=True, name="op%d_%d" % (w, i))
798 m.d.comb += op.eq(
799 Mux(self.delayed_part_ops[sel * i] == OP_MUL_LOW,
800 self.intermed.bit_select(i * w*2, w),
801 self.intermed.bit_select(i * w*2 + w, w)))
802 ol.append(op)
803 m.d.comb += self.output.eq(Cat(*ol))
804
805 return m
806
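# Editorial note (comments only): for lane width w, lane i's full 2*w-bit
# product occupies intermed[i*2*w : (i+1)*2*w].  OP_MUL_LOW selects its lower
# w bits and every other op selects the upper w bits, mirroring the RISC-V
# mul versus mulh/mulhsu/mulhu split; the selected halves are then packed
# back together into `output` in SIMD lane order.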
807
808 class FinalOut(Elaboratable):
809 """ selects the final output based on the partitioning.
810
811 each byte is selectable independently, i.e. it is possible
812 that some partitions requested 8-bit computation whilst others
813 requested 16 or 32 bit.
814 """
815 def __init__(self, out_wid):
816 # inputs
817 self.d8 = [Signal(name=f"d8_{i}", reset_less=True) for i in range(8)]
818 self.d16 = [Signal(name=f"d16_{i}", reset_less=True) for i in range(4)]
819 self.d32 = [Signal(name=f"d32_{i}", reset_less=True) for i in range(2)]
820
821 self.i8 = Signal(out_wid, reset_less=True)
822 self.i16 = Signal(out_wid, reset_less=True)
823 self.i32 = Signal(out_wid, reset_less=True)
824 self.i64 = Signal(out_wid, reset_less=True)
825
826 # output
827 self.out = Signal(out_wid, reset_less=True)
828
829 def elaborate(self, platform):
830 m = Module()
831 ol = []
832 for i in range(8):
833 # select one of the outputs: d8 selects i8, d16 selects i16
834 # d32 selects i32, and the default is i64.
835 # d8 and d16 are ORed together in the first Mux
836 # then the 2nd selects either i8 or i16.
837 # if neither d8 nor d16 are set, d32 selects either i32 or i64.
838 op = Signal(8, reset_less=True, name="op_%d" % i)
839 m.d.comb += op.eq(
840 Mux(self.d8[i] | self.d16[i // 2],
841 Mux(self.d8[i], self.i8.bit_select(i * 8, 8),
842 self.i16.bit_select(i * 8, 8)),
843 Mux(self.d32[i // 4], self.i32.bit_select(i * 8, 8),
844 self.i64.bit_select(i * 8, 8))))
845 ol.append(op)
846 m.d.comb += self.out.eq(Cat(*ol))
847 return m
848
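# Editorial note (comments only): the nested Muxes above implement, per output
# byte i, the following priority chain (byte(x, i) being hypothetical
# shorthand for x[i*8 : i*8 + 8]):
#
#     if   d8[i]:        out_byte = byte(i8,  i)
#     elif d16[i // 2]:  out_byte = byte(i16, i)
#     elif d32[i // 4]:  out_byte = byte(i32, i)
#     else:              out_byte = byte(i64, i)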
849
850 class OrMod(Elaboratable):
851 """ ORs four values together in a hierarchical tree
852 """
853 def __init__(self, wid):
854 self.wid = wid
855 self.orin = [Signal(wid, name="orin%d" % i, reset_less=True)
856 for i in range(4)]
857 self.orout = Signal(wid, reset_less=True)
858
859 def elaborate(self, platform):
860 m = Module()
861 or1 = Signal(self.wid, reset_less=True)
862 or2 = Signal(self.wid, reset_less=True)
863 m.d.comb += or1.eq(self.orin[0] | self.orin[1])
864 m.d.comb += or2.eq(self.orin[2] | self.orin[3])
865 m.d.comb += self.orout.eq(or1 | or2)
866
867 return m
868
869
870 class Signs(Elaboratable):
871     """ determines whether a and/or b are to be treated as signed
872         based on the requested operation type (OP_MUL_*)
873 """
874
875 def __init__(self):
876 self.part_ops = Signal(2, reset_less=True)
877 self.a_signed = Signal(reset_less=True)
878 self.b_signed = Signal(reset_less=True)
879
880 def elaborate(self, platform):
881
882 m = Module()
883
884 asig = self.part_ops != OP_MUL_UNSIGNED_HIGH
885 bsig = (self.part_ops == OP_MUL_LOW) \
886 | (self.part_ops == OP_MUL_SIGNED_HIGH)
887 m.d.comb += self.a_signed.eq(asig)
888 m.d.comb += self.b_signed.eq(bsig)
889
890 return m
891
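# Editorial note (comments only): the sign selection above reduces to this
# table (matching RISC-V mul / mulh / mulhsu / mulhu semantics):
#
#     part_ops                       a_signed   b_signed
#     OP_MUL_LOW                        1          1
#     OP_MUL_SIGNED_HIGH                1          1
#     OP_MUL_SIGNED_UNSIGNED_HIGH       1          0
#     OP_MUL_UNSIGNED_HIGH              0          0
#
# (for OP_MUL_LOW the low half is the same either way, so treating both
# operands as signed is harmless.)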
892
893 class Mul8_16_32_64(Elaboratable):
894 """Signed/Unsigned 8/16/32/64-bit partitioned integer multiplier.
895
896 Supports partitioning into any combination of 8, 16, 32, and 64-bit
897 partitions on naturally-aligned boundaries. Supports the operation being
898 set for each partition independently.
899
900 :attribute part_pts: the input partition points. Has a partition point at
901 multiples of 8 in 0 < i < 64. Each partition point's associated
902 ``Value`` is a ``Signal``. Modification not supported, except for by
903 ``Signal.eq``.
904 :attribute part_ops: the operation for each byte. The operation for a
905 particular partition is selected by assigning the selected operation
906 code to each byte in the partition. The allowed operation codes are:
907
908 :attribute OP_MUL_LOW: the LSB half of the product. Equivalent to
909 RISC-V's `mul` instruction.
910 :attribute OP_MUL_SIGNED_HIGH: the MSB half of the product where both
911 ``a`` and ``b`` are signed. Equivalent to RISC-V's `mulh`
912 instruction.
913 :attribute OP_MUL_SIGNED_UNSIGNED_HIGH: the MSB half of the product
914 where ``a`` is signed and ``b`` is unsigned. Equivalent to RISC-V's
915 `mulhsu` instruction.
916 :attribute OP_MUL_UNSIGNED_HIGH: the MSB half of the product where both
917 ``a`` and ``b`` are unsigned. Equivalent to RISC-V's `mulhu`
918 instruction.
919 """
920
921 def __init__(self, register_levels=()):
922 """ register_levels: specifies the points in the cascade at which
923 flip-flops are to be inserted.
924 """
925
926 # parameter(s)
927 self.register_levels = list(register_levels)
928
929 # inputs
930 self.part_pts = PartitionPoints()
931 for i in range(8, 64, 8):
932 self.part_pts[i] = Signal(name=f"part_pts_{i}")
933 self.part_ops = [Signal(2, name=f"part_ops_{i}") for i in range(8)]
934 self.a = Signal(64)
935 self.b = Signal(64)
936
937 # intermediates (needed for unit tests)
938 self._intermediate_output = Signal(128)
939
940 # output
941 self.output = Signal(64)
942
943 def _part_byte(self, index):
944 if index == -1 or index == 7:
945 return C(True, 1)
946 assert index >= 0 and index < 8
947 return self.part_pts[index * 8 + 8]
948
949 def elaborate(self, platform):
950 m = Module()
951
952 # collect part-bytes
953 pbs = Signal(8, reset_less=True)
954 tl = []
955 for i in range(8):
956 pb = Signal(name="pb%d" % i, reset_less=True)
957 m.d.comb += pb.eq(self._part_byte(i))
958 tl.append(pb)
959 m.d.comb += pbs.eq(Cat(*tl))
960
961 # local variables
962 signs = []
963 for i in range(8):
964 s = Signs()
965 signs.append(s)
966 setattr(m.submodules, "signs%d" % i, s)
967 m.d.comb += s.part_ops.eq(self.part_ops[i])
968
969 delayed_part_ops = [
970 [Signal(2, name=f"_delayed_part_ops_{delay}_{i}")
971 for i in range(8)]
972 for delay in range(1 + len(self.register_levels))]
973 for i in range(len(self.part_ops)):
974 m.d.comb += delayed_part_ops[0][i].eq(self.part_ops[i])
975 m.d.sync += [delayed_part_ops[j + 1][i].eq(delayed_part_ops[j][i])
976 for j in range(len(self.register_levels))]
977
978 n_levels = len(self.register_levels)+1
979 m.submodules.part_8 = part_8 = Part(128, 8, n_levels, 8)
980 m.submodules.part_16 = part_16 = Part(128, 4, n_levels, 8)
981 m.submodules.part_32 = part_32 = Part(128, 2, n_levels, 8)
982 m.submodules.part_64 = part_64 = Part(128, 1, n_levels, 8)
983 nat_l, nbt_l, nla_l, nlb_l = [], [], [], []
984 for mod in [part_8, part_16, part_32, part_64]:
985 m.d.comb += mod.a.eq(self.a)
986 m.d.comb += mod.b.eq(self.b)
987 for i in range(len(signs)):
988 m.d.comb += mod.a_signed[i].eq(signs[i].a_signed)
989 m.d.comb += mod.b_signed[i].eq(signs[i].b_signed)
990 m.d.comb += mod.pbs.eq(pbs)
991 nat_l.append(mod.not_a_term)
992 nbt_l.append(mod.not_b_term)
993 nla_l.append(mod.neg_lsb_a_term)
994 nlb_l.append(mod.neg_lsb_b_term)
995
996 terms = []
997
998 for a_index in range(8):
999 t = ProductTerms(8, 128, 8, a_index, 8)
1000 setattr(m.submodules, "terms_%d" % a_index, t)
1001
1002 m.d.comb += t.a.eq(self.a)
1003 m.d.comb += t.b.eq(self.b)
1004 m.d.comb += t.pb_en.eq(pbs)
1005
1006 for term in t.terms:
1007 terms.append(term)
1008
1009 # it's fine to bitwise-or data together since they are never enabled
1010 # at the same time
1011 m.submodules.nat_or = nat_or = OrMod(128)
1012 m.submodules.nbt_or = nbt_or = OrMod(128)
1013 m.submodules.nla_or = nla_or = OrMod(128)
1014 m.submodules.nlb_or = nlb_or = OrMod(128)
1015 for l, mod in [(nat_l, nat_or),
1016 (nbt_l, nbt_or),
1017 (nla_l, nla_or),
1018 (nlb_l, nlb_or)]:
1019 for i in range(len(l)):
1020 m.d.comb += mod.orin[i].eq(l[i])
1021 terms.append(mod.orout)
1022
1023 expanded_part_pts = PartitionPoints()
1024 for i, v in self.part_pts.items():
1025 signal = Signal(name=f"expanded_part_pts_{i*2}", reset_less=True)
1026 expanded_part_pts[i * 2] = signal
1027 m.d.comb += signal.eq(v)
1028
1029 add_reduce = AddReduce(terms,
1030 128,
1031 self.register_levels,
1032 expanded_part_pts)
1033 m.submodules.add_reduce = add_reduce
1034 m.d.comb += self._intermediate_output.eq(add_reduce.output)
1035 # create _output_64
1036 m.submodules.io64 = io64 = IntermediateOut(64, 128, 1)
1037 m.d.comb += io64.intermed.eq(self._intermediate_output)
1038 for i in range(8):
1039 m.d.comb += io64.delayed_part_ops[i].eq(delayed_part_ops[-1][i])
1040
1041 # create _output_32
1042 m.submodules.io32 = io32 = IntermediateOut(32, 128, 2)
1043 m.d.comb += io32.intermed.eq(self._intermediate_output)
1044 for i in range(8):
1045 m.d.comb += io32.delayed_part_ops[i].eq(delayed_part_ops[-1][i])
1046
1047 # create _output_16
1048 m.submodules.io16 = io16 = IntermediateOut(16, 128, 4)
1049 m.d.comb += io16.intermed.eq(self._intermediate_output)
1050 for i in range(8):
1051 m.d.comb += io16.delayed_part_ops[i].eq(delayed_part_ops[-1][i])
1052
1053 # create _output_8
1054 m.submodules.io8 = io8 = IntermediateOut(8, 128, 8)
1055 m.d.comb += io8.intermed.eq(self._intermediate_output)
1056 for i in range(8):
1057 m.d.comb += io8.delayed_part_ops[i].eq(delayed_part_ops[-1][i])
1058
1059 # final output
1060 m.submodules.finalout = finalout = FinalOut(64)
1061 for i in range(len(part_8.delayed_parts[-1])):
1062 m.d.comb += finalout.d8[i].eq(part_8.dplast[i])
1063 for i in range(len(part_16.delayed_parts[-1])):
1064 m.d.comb += finalout.d16[i].eq(part_16.dplast[i])
1065 for i in range(len(part_32.delayed_parts[-1])):
1066 m.d.comb += finalout.d32[i].eq(part_32.dplast[i])
1067 m.d.comb += finalout.i8.eq(io8.output)
1068 m.d.comb += finalout.i16.eq(io16.output)
1069 m.d.comb += finalout.i32.eq(io32.output)
1070 m.d.comb += finalout.i64.eq(io64.output)
1071 m.d.comb += self.output.eq(finalout.out)
1072
1073 return m
1074
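# Editorial usage sketch (comments only): a rough outline of driving the
# multiplier as two independent 32-bit unsigned-high lanes.  It assumes an
# enclosing Module `m`, a sync clock domain for the pipeline registers, and
# two hypothetical 64-bit Signals `ra` and `rb` (none of which are defined
# in this file):
#
#     mul = Mul8_16_32_64(register_levels=[2])
#     m.submodules.mul = mul
#     m.d.comb += [mul.a.eq(ra), mul.b.eq(rb)]
#     # enable only the partition point at bit 32 -> two 32-bit lanes
#     m.d.comb += [mul.part_pts[i].eq(i == 32) for i in range(8, 64, 8)]
#     # request unsigned-high (mulhu-style) in every byte of both lanes
#     m.d.comb += [op.eq(OP_MUL_UNSIGNED_HIGH) for op in mul.part_ops]
#     # once the pipeline registers have settled, mul.output holds the two
#     # 32-bit high halves side by side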
1075
1076 if __name__ == "__main__":
1077 m = Mul8_16_32_64()
1078 main(m, ports=[m.a,
1079 m.b,
1080 m._intermediate_output,
1081 m.output,
1082 *m.part_ops,
1083 *m.part_pts.values()])