pass in part_ops to AddReduce, so that it is synced alongside the other data
[ieee754fpu.git] / src / ieee754 / part_mul_add / multiply.py
1 # SPDX-License-Identifier: LGPL-2.1-or-later
2 # See Notices.txt for copyright information
3 """Integer Multiplication."""
4
5 from nmigen import Signal, Module, Value, Elaboratable, Cat, C, Mux, Repl
6 from nmigen.hdl.ast import Assign
7 from abc import ABCMeta, abstractmethod
8 from nmigen.cli import main
9 from functools import reduce
10 from operator import or_
11
12
class PartitionPoints(dict):
    """Partition points and corresponding ``Value``s.

    The points at where an ALU is partitioned along with ``Value``s that
    specify if the corresponding partition points are enabled.

    For example: ``{1: True, 5: True, 10: True}`` with
    ``width == 16`` specifies that the ALU is split into 4 sections:
    * bits 0 <= ``i`` < 1
    * bits 1 <= ``i`` < 5
    * bits 5 <= ``i`` < 10
    * bits 10 <= ``i`` < 16

    If the partition_points were instead ``{1: True, 5: a, 10: True}``
    where ``a`` is a 1-bit ``Signal``:
    * If ``a`` is asserted:
        * bits 0 <= ``i`` < 1
        * bits 1 <= ``i`` < 5
        * bits 5 <= ``i`` < 10
        * bits 10 <= ``i`` < 16
    * Otherwise
        * bits 0 <= ``i`` < 1
        * bits 1 <= ``i`` < 10
        * bits 10 <= ``i`` < 16
    """

    def __init__(self, partition_points=None):
        """Create a new ``PartitionPoints``.

        :param partition_points: the input partition points to values mapping.
        """
        super().__init__()
        if partition_points is None:
            return
        for point, enabled in partition_points.items():
            # points are bit indices, so must be non-negative ints
            if not isinstance(point, int):
                raise TypeError("point must be a non-negative integer")
            if point < 0:
                raise ValueError("point must be a non-negative integer")
            self[point] = Value.wrap(enabled)

    def like(self, name=None, src_loc_at=0):
        """Create a new ``PartitionPoints`` with ``Signal``s for all values.

        :param name: the base name for the new ``Signal``s.
        """
        if name is None:
            # derive the base name from the assignment target's variable name
            name = Signal(src_loc_at=1+src_loc_at).name
        new_points = PartitionPoints()
        for point, enabled in self.items():
            new_points[point] = Signal(enabled.shape(),
                                       name=f"{name}_{point}")
        return new_points

    def eq(self, rhs):
        """Assign ``PartitionPoints`` using ``Signal.eq``."""
        if set(self.keys()) != set(rhs.keys()):
            raise ValueError("incompatible point set")
        for point in self.keys():
            yield self[point].eq(rhs[point])

    def as_mask(self, width):
        """Create a bit-mask from `self`.

        Each bit in the returned mask is clear only if the partition point at
        the same bit-index is enabled.

        :param width: the bit width of the resulting mask
        """
        bits = [~self[i] if i in self else True for i in range(width)]
        return Cat(*bits)

    def get_max_partition_count(self, width):
        """Get the maximum number of partitions.

        Gets the number of partitions when all partition points are enabled.
        """
        # one partition, plus one extra per in-range partition point
        return 1 + sum(point < width for point in self.keys())

    def fits_in_width(self, width):
        """Check if all partition points are smaller than `width`."""
        return all(point < width for point in self.keys())
105
106
class FullAdder(Elaboratable):
    """Full Adder.

    :attribute in0: the first input
    :attribute in1: the second input
    :attribute in2: the third input
    :attribute sum: the sum output
    :attribute carry: the carry output

    Rather than do individual full adders (and have an array of them,
    which would be very slow to simulate), this module can specify the
    bit width of the inputs and outputs: in effect it performs multiple
    Full 3-2 Add operations "in parallel".
    """

    def __init__(self, width):
        """Create a ``FullAdder``.

        :param width: the bit width of the input and output
        """
        self.in0 = Signal(width)
        self.in1 = Signal(width)
        self.in2 = Signal(width)
        self.sum = Signal(width)
        self.carry = Signal(width)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        a, b, c = self.in0, self.in1, self.in2
        # textbook full-adder equations, applied bitwise across the width
        m.d.comb += self.sum.eq(a ^ b ^ c)
        m.d.comb += self.carry.eq((a & b) | (b & c) | (c & a))
        return m
141
142
class MaskedFullAdder(Elaboratable):
    """Masked Full Adder.

    :attribute mask: the carry partition mask
    :attribute in0: the first input
    :attribute in1: the second input
    :attribute in2: the third input
    :attribute sum: the sum output
    :attribute mcarry: the masked carry output

    FullAdders are always used with a "mask" on the output.  To keep
    the graphviz "clean", this class performs the masking here rather
    than inside a large for-loop.

    See the following discussion as to why this is no longer derived
    from FullAdder.  Each carry is shifted here *before* being ANDed
    with the mask, so that an AOI cell may be used (which is more
    gate-efficient)
    https://en.wikipedia.org/wiki/AND-OR-Invert
    https://groups.google.com/d/msg/comp.arch/fcq-GLQqvas/vTxmcA0QAgAJ
    """

    def __init__(self, width):
        """Create a ``MaskedFullAdder``.

        :param width: the bit width of the input and output
        """
        self.width = width
        self.mask = Signal(width, reset_less=True)
        self.mcarry = Signal(width, reset_less=True)
        self.in0 = Signal(width, reset_less=True)
        self.in1 = Signal(width, reset_less=True)
        self.in2 = Signal(width, reset_less=True)
        self.sum = Signal(width, reset_less=True)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        comb = m.d.comb
        # plain XOR sum, same as an unmasked full adder
        comb += self.sum.eq(self.in0 ^ self.in1 ^ self.in2)
        # shift each input up by one bit first...
        shifted = []
        for idx, operand in enumerate((self.in0, self.in1, self.in2)):
            s = Signal(self.width, reset_less=True, name="s%d" % (idx + 1))
            comb += s.eq(Cat(0, operand))
            shifted.append(s)
        # ... then AND each (pre-shifted) pair with the mask, allowing
        # AOI cells to be used for the mask-and-combine step
        pairs = ((shifted[0], shifted[1]),
                 (shifted[1], shifted[2]),
                 (shifted[2], shifted[0]))
        carries = []
        for idx, (x, y) in enumerate(pairs):
            c = Signal(self.width, reset_less=True, name="c%d" % (idx + 1))
            comb += c.eq(x & y & self.mask)
            carries.append(c)
        comb += self.mcarry.eq(carries[0] | carries[1] | carries[2])
        return m
196
197
class PartitionedAdder(Elaboratable):
    """Partitioned Adder.

    Performs the final add.  The partition points are included in the
    actual add (in one of the operands only), which causes a carry over
    to the next bit.  Then the final output *removes* the extra bits from
    the result.

    partition: .... P... P... P... P... (32 bits)
    a        : .... .... .... .... .... (32 bits)
    b        : .... .... .... .... .... (32 bits)
    exp-a    : ....P....P....P....P.... (32+4 bits, P=1 if no partition)
    exp-b    : ....0....0....0....0.... (32 bits plus 4 zeros)
    exp-o    : ....xN...xN...xN...xN... (32+4 bits - x to be discarded)
    o        : .... N... N... N... N... (32 bits - x ignored, N is carry-over)

    :attribute width: the bit width of the input and output. Read-only.
    :attribute a: the first input to the adder
    :attribute b: the second input to the adder
    :attribute output: the sum output
    :attribute partition_points: the input partition points. Modification not
        supported, except for by ``Signal.eq``.
    """

    def __init__(self, width, partition_points):
        """Create a ``PartitionedAdder``.

        :param width: the bit width of the input and output
        :param partition_points: the input partition points
        :raises ValueError: if any partition point lies at or beyond ``width``
        """
        self.width = width
        self.a = Signal(width)
        self.b = Signal(width)
        self.output = Signal(width)
        self.partition_points = PartitionPoints(partition_points)
        if not self.partition_points.fits_in_width(width):
            raise ValueError("partition_points doesn't fit in width")
        # the expanded operands need one extra bit per partition point
        # (see the diagram in the class docstring)
        expanded_width = 0
        for i in range(self.width):
            if i in self.partition_points:
                expanded_width += 1
            expanded_width += 1
        self._expanded_width = expanded_width
        # XXX these have to remain here due to some horrible nmigen
        # simulation bugs involving sync.  it is *not* necessary to
        # have them here, they should (under normal circumstances)
        # be moved into elaborate, as they are entirely local
        self._expanded_a = Signal(expanded_width) # includes extra part-points
        self._expanded_b = Signal(expanded_width) # likewise.
        self._expanded_o = Signal(expanded_width) # likewise.

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        expanded_index = 0
        # store bits in a list, use Cat later.  graphviz is much cleaner
        al, bl, ol, ea, eb, eo = [],[],[],[],[],[]

        # partition points are "breaks" (extra zeros or 1s) in what would
        # otherwise be a massive long add.  when the "break" points are 0,
        # whatever is in it (in the output) is discarded.  however when
        # there is a "1", it causes a roll-over carry to the *next* bit.
        # we still ignore the "break" bit in the [intermediate] output,
        # however by that time we've got the effect that we wanted: the
        # carry has been carried *over* the break point.

        for i in range(self.width):
            if i in self.partition_points:
                # add extra bit set to 0 + 0 for enabled partition points
                # and 1 + 0 for disabled partition points
                ea.append(self._expanded_a[expanded_index])
                al.append(~self.partition_points[i]) # add extra bit in a
                eb.append(self._expanded_b[expanded_index])
                bl.append(C(0)) # yes, add a zero
                expanded_index += 1 # skip the extra point.  NOT in the output
            # ordinary data bit: passes straight through to/from the
            # expanded operands and the intermediate output
            ea.append(self._expanded_a[expanded_index])
            eb.append(self._expanded_b[expanded_index])
            eo.append(self._expanded_o[expanded_index])
            al.append(self.a[i])
            bl.append(self.b[i])
            ol.append(self.output[i])
            expanded_index += 1

        # combine above using Cat
        m.d.comb += Cat(*ea).eq(Cat(*al))
        m.d.comb += Cat(*eb).eq(Cat(*bl))
        m.d.comb += Cat(*ol).eq(Cat(*eo))

        # use only one addition to take advantage of look-ahead carry and
        # special hardware on FPGAs
        m.d.comb += self._expanded_o.eq(
            self._expanded_a + self._expanded_b)
        return m
291
292
FULL_ADDER_INPUT_COUNT = 3  # a full adder (3:2 compressor) takes 3 inputs
294
295
class AddReduceSingle(Elaboratable):
    """Add list of numbers together.

    Performs a single level of the carry-save reduction: groups of 3
    inputs are compressed to 2 via masked full adders; when 2 or fewer
    inputs remain, a ``PartitionedAdder`` performs the final add.

    :attribute inputs: input ``Signal``s to be summed. Modification not
        supported, except for by ``Signal.eq``.
    :attribute register_levels: List of nesting levels that should have
        pipeline registers.
    :attribute output: output sum.
    :attribute partition_points: the input partition points. Modification not
        supported, except for by ``Signal.eq``.
    :attribute part_ops: the operation codes; copied (or pipeline-
        registered, in step with the data) into ``_part_ops``.
    """

    def __init__(self, inputs, output_width, register_levels, partition_points,
                 part_ops):
        """Create an ``AddReduce``.

        :param inputs: input ``Signal``s to be summed.
        :param output_width: bit-width of ``output``.
        :param register_levels: List of nesting levels that should have
            pipeline registers.
        :param partition_points: the input partition points.
        :param part_ops: the operation codes, registered (or copied
            combinatorially) alongside the other data so that they
            stay in sync with it.
        """
        self.part_ops = part_ops
        # per-byte copies of part_ops, assigned in elaborate (sync or comb)
        self._part_ops = [Signal(2, name=f"part_ops_{i}")
                          for i in range(len(part_ops))]
        self.inputs = list(inputs)
        # inputs widened to the full output width before summing
        self._resized_inputs = [
            Signal(output_width, name=f"resized_inputs[{i}]")
            for i in range(len(self.inputs))]
        self.register_levels = list(register_levels)
        self.output = Signal(output_width)
        self.partition_points = PartitionPoints(partition_points)
        if not self.partition_points.fits_in_width(output_width):
            raise ValueError("partition_points doesn't fit in output_width")
        self._reg_partition_points = self.partition_points.like()

        max_level = AddReduceSingle.get_max_level(len(self.inputs))
        for level in self.register_levels:
            if level > max_level:
                raise ValueError(
                    "not enough adder levels for specified register levels")

        # the start indices of the groups of 3 inputs to be compressed.
        # empty when 2 or fewer inputs remain (the base case)
        self.groups = AddReduceSingle.full_adder_groups(len(self.inputs))
        self._intermediate_terms = []
        if len(self.groups) != 0:
            self.create_next_terms()

    @staticmethod
    def get_max_level(input_count):
        """Get the maximum level.

        All ``register_levels`` must be less than or equal to the maximum
        level.
        """
        retval = 0
        while True:
            groups = AddReduceSingle.full_adder_groups(input_count)
            if len(groups) == 0:
                return retval
            # each group of 3 shrinks to 2 terms (sum + carry);
            # the leftover (input_count % 3) terms pass straight through
            input_count %= FULL_ADDER_INPUT_COUNT
            input_count += 2 * len(groups)
            retval += 1

    @staticmethod
    def full_adder_groups(input_count):
        """Get ``inputs`` indices for which a full adder should be built."""
        return range(0,
                     input_count - FULL_ADDER_INPUT_COUNT + 1,
                     FULL_ADDER_INPUT_COUNT)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()

        # resize inputs to correct bit-width and optionally add in
        # pipeline registers
        resized_input_assignments = [self._resized_inputs[i].eq(self.inputs[i])
                                     for i in range(len(self.inputs))]
        copy_part_ops = [self._part_ops[i].eq(self.part_ops[i])
                         for i in range(len(self.part_ops))]
        if 0 in self.register_levels:
            # register the op codes, the data and the partition points
            # together so they all move through the pipeline in step
            m.d.sync += copy_part_ops
            m.d.sync += resized_input_assignments
            m.d.sync += self._reg_partition_points.eq(self.partition_points)
        else:
            m.d.comb += copy_part_ops
            m.d.comb += resized_input_assignments
            m.d.comb += self._reg_partition_points.eq(self.partition_points)

        # wire up the term signals allocated in create_next_terms
        for (value, term) in self._intermediate_terms:
            m.d.comb += term.eq(value)

        # if there are no full adders to create, then we handle the base cases
        # and return, otherwise we go on to the recursive case
        if len(self.groups) == 0:
            if len(self.inputs) == 0:
                # use 0 as the default output value
                m.d.comb += self.output.eq(0)
            elif len(self.inputs) == 1:
                # handle single input
                m.d.comb += self.output.eq(self._resized_inputs[0])
            else:
                # base case for adding 2 inputs
                assert len(self.inputs) == 2
                adder = PartitionedAdder(len(self.output),
                                         self._reg_partition_points)
                m.submodules.final_adder = adder
                m.d.comb += adder.a.eq(self._resized_inputs[0])
                m.d.comb += adder.b.eq(self._resized_inputs[1])
                m.d.comb += self.output.eq(adder.output)
            return m

        # recursive case: drive the carry mask shared by all the full
        # adders (self.part_mask was created in create_next_terms, which
        # is only called when self.groups is non-empty)
        mask = self._reg_partition_points.as_mask(len(self.output))
        m.d.comb += self.part_mask.eq(mask)

        # add and link the intermediate term modules
        for i, (iidx, adder_i) in enumerate(self.adders):
            setattr(m.submodules, f"adder_{i}", adder_i)

            m.d.comb += adder_i.in0.eq(self._resized_inputs[iidx])
            m.d.comb += adder_i.in1.eq(self._resized_inputs[iidx + 1])
            m.d.comb += adder_i.in2.eq(self._resized_inputs[iidx + 2])
            m.d.comb += adder_i.mask.eq(self.part_mask)

        return m

    def create_next_terms(self):
        """Build the full adders and the terms to pass to the next level."""

        # go on to prepare recursive case
        intermediate_terms = []
        _intermediate_terms = []

        def add_intermediate_term(value):
            # allocate a named term signal; the actual comb assignment
            # is deferred to elaborate (via _intermediate_terms)
            intermediate_term = Signal(
                len(self.output),
                name=f"intermediate_terms[{len(intermediate_terms)}]")
            _intermediate_terms.append((value, intermediate_term))
            intermediate_terms.append(intermediate_term)

        # store mask in intermediary (simplifies graph)
        self.part_mask = Signal(len(self.output), reset_less=True)

        # create full adders for this recursive level.
        # this shrinks N terms to 2 * (N // 3) plus the remainder
        self.adders = []
        for i in self.groups:
            adder_i = MaskedFullAdder(len(self.output))
            self.adders.append((i, adder_i))
            # add both the sum and the masked-carry to the next level.
            # 3 inputs have now been reduced to 2...
            add_intermediate_term(adder_i.sum)
            add_intermediate_term(adder_i.mcarry)
        # handle the remaining inputs.
        if len(self.inputs) % FULL_ADDER_INPUT_COUNT == 1:
            add_intermediate_term(self._resized_inputs[-1])
        elif len(self.inputs) % FULL_ADDER_INPUT_COUNT == 2:
            # Just pass the terms to the next layer, since we wouldn't gain
            # anything by using a half adder since there would still be 2 terms
            # and just passing the terms to the next layer saves gates.
            add_intermediate_term(self._resized_inputs[-2])
            add_intermediate_term(self._resized_inputs[-1])
        else:
            assert len(self.inputs) % FULL_ADDER_INPUT_COUNT == 0

        self.intermediate_terms = intermediate_terms
        self._intermediate_terms = _intermediate_terms
462
463
class AddReduce(Elaboratable):
    """Recursively Add list of numbers together.

    :attribute inputs: input ``Signal``s to be summed. Modification not
        supported, except for by ``Signal.eq``.
    :attribute register_levels: List of nesting levels that should have
        pipeline registers.
    :attribute output: output sum.
    :attribute partition_points: the input partition points. Modification not
        supported, except for by ``Signal.eq``.
    """

    def __init__(self, inputs, output_width, register_levels, partition_points,
                 part_ops):
        """Create an ``AddReduce``.

        :param inputs: input ``Signal``s to be summed.
        :param output_width: bit-width of ``output``.
        :param register_levels: List of nesting levels that should have
            pipeline registers.
        :param partition_points: the input partition points.
        :param part_ops: op codes passed to every ``AddReduceSingle``
            level, so they are kept in step with the data.
        """
        self.output_width = output_width
        self.output = Signal(output_width)
        self.inputs = inputs
        self.part_ops = part_ops
        self.register_levels = register_levels
        self.partition_points = partition_points

        self.create_levels()

    @staticmethod
    def next_register_levels(register_levels):
        """``Iterable`` of ``register_levels`` for next recursive level."""
        return (level - 1 for level in register_levels if level > 0)

    def create_levels(self):
        """creates reduction levels"""

        stages = []
        levels_left = self.register_levels
        points = self.partition_points
        terms = self.inputs
        while True:
            # one reduction stage over the current set of terms
            stage = AddReduceSingle(terms, self.output_width, levels_left,
                                    points, self.part_ops)
            stages.append(stage)
            if len(stage.groups) == 0:
                # 2 or fewer terms remain: this stage does the final add
                break
            levels_left = list(AddReduce.next_register_levels(levels_left))
            points = stage._reg_partition_points
            terms = stage.intermediate_terms

        self.levels = stages

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()

        stage = None
        for i, stage in enumerate(self.levels):
            setattr(m.submodules, "next_level%d" % i, stage)

        # output comes from last module
        m.d.comb += self.output.eq(stage.output)

        return m
532
533
OP_MUL_LOW = 0                   # LSB half of the product (RISC-V mul)
OP_MUL_SIGNED_HIGH = 1           # MSB half, a and b both signed (mulh)
OP_MUL_SIGNED_UNSIGNED_HIGH = 2  # a is signed, b is unsigned (mulhsu)
OP_MUL_UNSIGNED_HIGH = 3         # MSB half, a and b both unsigned (mulhu)
538
539
def get_term(value, shift=0, enabled=None):
    """Optionally gate a term with *enabled* and shift it into position.

    :param value: the term to be returned
    :param shift: number of zero bits to prepend (must not be negative)
    :param enabled: when not None, a condition: the term is forced to
        zero whenever *enabled* is deasserted
    """
    term = value if enabled is None else Mux(enabled, value, 0)
    assert shift >= 0
    if shift:
        # prepend `shift` zero bits (i.e. a left shift by `shift`)
        term = Cat(Repl(C(0, 1), shift), term)
    return term
548
549
class ProductTerm(Elaboratable):
    """ this class creates a single product term (a[..]*b[..]).
    it has a design flaw in that is the *output* that is selected,
    where the multiplication(s) are combinatorially generated
    all the time.
    """

    def __init__(self, width, twidth, pbwid, a_index, b_index):
        """Create a ``ProductTerm``.

        :param width: bit width of one operand byte-slice
        :param twidth: bit width of the full term output
        :param pbwid: number of partition-enable bits in ``pb_en``
        :param a_index: byte index into operand a
        :param b_index: byte index into operand b
        """
        self.a_index = a_index
        self.b_index = b_index
        # each byte index contributes 8 bits of shift to the term position
        shift = 8 * (self.a_index + self.b_index)
        self.pwidth = width
        self.twidth = twidth
        self.width = width*2
        self.shift = shift

        self.ti = Signal(self.width, reset_less=True)
        self.term = Signal(twidth, reset_less=True)
        self.a = Signal(twidth//2, reset_less=True)
        self.b = Signal(twidth//2, reset_less=True)
        self.pb_en = Signal(pbwid, reset_less=True)

        # collect the partition-enable bits lying between the two byte
        # indices; when a_index == b_index the list is empty and
        # self.enabled is left as None (term unconditionally produced)
        self.tl = tl = []
        min_index = min(self.a_index, self.b_index)
        max_index = max(self.a_index, self.b_index)
        for i in range(min_index, max_index):
            tl.append(self.pb_en[i])
        name = "te_%d_%d" % (self.a_index, self.b_index)
        if len(tl) > 0:
            term_enabled = Signal(name=name, reset_less=True)
        else:
            term_enabled = None
        self.enabled = term_enabled
        self.term.name = "term_%d_%d" % (a_index, b_index) # rename

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        if self.enabled is not None:
            # enabled only when none of the in-between partition bits
            # are set: the term is then zeroed via get_term below
            m.d.comb += self.enabled.eq(~(Cat(*self.tl).bool()))

        # select the relevant byte-slice of each operand and multiply
        bsa = Signal(self.width, reset_less=True)
        bsb = Signal(self.width, reset_less=True)
        a_index, b_index = self.a_index, self.b_index
        pwidth = self.pwidth
        m.d.comb += bsa.eq(self.a.bit_select(a_index * pwidth, pwidth))
        m.d.comb += bsb.eq(self.b.bit_select(b_index * pwidth, pwidth))
        m.d.comb += self.ti.eq(bsa * bsb)
        # gate (when self.enabled exists) and shift into final position
        m.d.comb += self.term.eq(get_term(self.ti, self.shift, self.enabled))
        """
        #TODO: sort out width issues, get inputs a/b switched on/off.
        #data going into Muxes is 1/2 the required width

        pwidth = self.pwidth
        width = self.width
        bsa = Signal(self.twidth//2, reset_less=True)
        bsb = Signal(self.twidth//2, reset_less=True)
        asel = Signal(width, reset_less=True)
        bsel = Signal(width, reset_less=True)
        a_index, b_index = self.a_index, self.b_index
        m.d.comb += asel.eq(self.a.bit_select(a_index * pwidth, pwidth))
        m.d.comb += bsel.eq(self.b.bit_select(b_index * pwidth, pwidth))
        m.d.comb += bsa.eq(get_term(asel, self.shift, self.enabled))
        m.d.comb += bsb.eq(get_term(bsel, self.shift, self.enabled))
        m.d.comb += self.ti.eq(bsa * bsb)
        m.d.comb += self.term.eq(self.ti)
        """

        return m
619
620
class ProductTerms(Elaboratable):
    """ creates a bank of product terms.  also performs the actual
    bit-selection.  this class is to be wrapped with a for-loop on the
    "a" operand; it creates a second-level for-loop on the "b" operand.
    """
    def __init__(self, width, twidth, pbwid, a_index, blen):
        self.a_index = a_index
        self.blen = blen
        self.pwidth = width
        self.twidth = twidth
        self.pbwid = pbwid
        self.a = Signal(twidth//2, reset_less=True)
        self.b = Signal(twidth//2, reset_less=True)
        self.pb_en = Signal(pbwid, reset_less=True)
        self.terms = [Signal(twidth, name="term%d"%i, reset_less=True) \
                      for i in range(blen)]

    def elaborate(self, platform):

        m = Module()

        for bi in range(self.blen):
            # one ProductTerm per b byte-index, all sharing a, b, pb_en
            term = ProductTerm(self.pwidth, self.twidth, self.pbwid,
                               self.a_index, bi)
            setattr(m.submodules, "term_%d" % bi, term)

            m.d.comb += [term.a.eq(self.a),
                         term.b.eq(self.b),
                         term.pb_en.eq(self.pb_en),
                         self.terms[bi].eq(term.term)]

        return m
654
655
class LSBNegTerm(Elaboratable):
    """ generates the two extra terms needed for a signed partial
    product: a width-extended 1's complement (bit-inversion) term and
    the corresponding "+1" LSB term (2's complement = 1's complement
    plus one).
    """

    def __init__(self, bit_width):
        self.bit_width = bit_width
        # inputs
        self.part = Signal(reset_less=True)
        self.signed = Signal(reset_less=True)
        self.op = Signal(bit_width, reset_less=True)
        self.msb = Signal(reset_less=True)
        # outputs
        self.nt = Signal(bit_width*2, reset_less=True)
        self.nl = Signal(bit_width*2, reset_less=True)

    def elaborate(self, platform):
        m = Module()
        comb = m.d.comb
        wid = self.bit_width
        lo_zeros = Repl(0, wid) # extend output to HI part

        # terms only take effect when this partition is active, the
        # operand's MSB is set, and a signed operation was requested
        active = Signal(reset_less=True)
        comb += active.eq(self.part & self.msb & self.signed)

        # for 8-bit values: form a * 0xFF00 by using -a * 0x100, the
        # negation operation is split into a bitwise not and a +1.
        # likewise for 16, 32, and 64-bit values.

        # width-extended 1s complement if a is signed, otherwise zero
        comb += self.nt.eq(Mux(active, Cat(lo_zeros, ~self.op), 0))

        # add 1 if signed, otherwise add zero
        comb += self.nl.eq(Cat(lo_zeros, active, Repl(0, wid-1)))

        return m
688
689
class Part(Elaboratable):
    """ a key class which, depending on the partitioning, will determine
    what action to take when parts of the output are signed or unsigned.

    this requires 2 pieces of data *per operand, per partition*:
    whether the MSB is HI/LO (per partition!), and whether a signed
    or unsigned operation has been *requested*.

    once that is determined, signed is basically carried out
    by splitting 2's complement into 1's complement plus one.
    1's complement is just a bit-inversion.

    the extra terms - as separate terms - are then thrown at the
    AddReduce alongside the multiplication part-results.
    """
    def __init__(self, width, n_parts, n_levels, pbwid):

        # inputs
        self.a = Signal(64)
        self.b = Signal(64)
        # per-byte signedness requests for each operand (driven externally)
        self.a_signed = [Signal(name=f"a_signed_{i}") for i in range(8)]
        self.b_signed = [Signal(name=f"_b_signed_{i}") for i in range(8)]
        self.pbs = Signal(pbwid, reset_less=True)  # partition-byte bits

        # outputs
        self.parts = [Signal(name=f"part_{i}") for i in range(n_parts)]
        # parts delayed by 0..n_levels-1 clock cycles, tracking the pipeline
        self.delayed_parts = [
            [Signal(name=f"delayed_part_{delay}_{i}")
             for i in range(n_parts)]
            for delay in range(n_levels)]
        # XXX REALLY WEIRD BUG - have to take a copy of the last delayed_parts
        self.dplast = [Signal(name=f"dplast_{i}")
                       for i in range(n_parts)]

        self.not_a_term = Signal(width)
        self.neg_lsb_a_term = Signal(width)
        self.not_b_term = Signal(width)
        self.neg_lsb_b_term = Signal(width)

    def elaborate(self, platform):
        m = Module()

        pbs, parts, delayed_parts = self.pbs, self.parts, self.delayed_parts
        # negated-temporary copy of partition bits
        npbs = Signal.like(pbs, reset_less=True)
        m.d.comb += npbs.eq(~pbs)
        byte_count = 8 // len(parts)
        for i in range(len(parts)):
            # parts[i] is set when both boundary bits (gathered inverted,
            # via npbs) are set and no break occurs in between.
            # NOTE(review): for i == 0 the first index below is -1, i.e.
            # npbs' *top* bit - this relies on the driver of pbs keeping
            # its top bit permanently set (Mul8_16_32_64._part_byte
            # returns a constant for index 7) - confirm before reusing
            pbl = []
            pbl.append(npbs[i * byte_count - 1])
            for j in range(i * byte_count, (i + 1) * byte_count - 1):
                pbl.append(pbs[j])
            pbl.append(npbs[(i + 1) * byte_count - 1])
            value = Signal(len(pbl), name="value_%di" % i, reset_less=True)
            m.d.comb += value.eq(Cat(*pbl))
            m.d.comb += parts[i].eq(~(value).bool())
            # shift-register chain: delay parts[i] by one cycle per level
            m.d.comb += delayed_parts[0][i].eq(parts[i])
            m.d.sync += [delayed_parts[j + 1][i].eq(delayed_parts[j][i])
                         for j in range(len(delayed_parts)-1)]
            m.d.comb += self.dplast[i].eq(delayed_parts[-1][i])

        not_a_term, neg_lsb_a_term, not_b_term, neg_lsb_b_term = \
                self.not_a_term, self.neg_lsb_a_term, \
                self.not_b_term, self.neg_lsb_b_term

        byte_width = 8 // len(parts) # byte width
        bit_wid = 8 * byte_width # bit width
        nat, nbt, nla, nlb = [], [], [], []
        for i in range(len(parts)):
            # work out bit-inverted and +1 term for a.
            pa = LSBNegTerm(bit_wid)
            setattr(m.submodules, "lnt_%d_a_%d" % (bit_wid, i), pa)
            m.d.comb += pa.part.eq(parts[i])
            m.d.comb += pa.op.eq(self.a.bit_select(bit_wid * i, bit_wid))
            m.d.comb += pa.signed.eq(self.b_signed[i * byte_width]) # yes b
            m.d.comb += pa.msb.eq(self.b[(i + 1) * bit_wid - 1]) # really, b
            nat.append(pa.nt)
            nla.append(pa.nl)

            # work out bit-inverted and +1 term for b
            pb = LSBNegTerm(bit_wid)
            setattr(m.submodules, "lnt_%d_b_%d" % (bit_wid, i), pb)
            m.d.comb += pb.part.eq(parts[i])
            m.d.comb += pb.op.eq(self.b.bit_select(bit_wid * i, bit_wid))
            m.d.comb += pb.signed.eq(self.a_signed[i * byte_width]) # yes a
            m.d.comb += pb.msb.eq(self.a[(i + 1) * bit_wid - 1]) # really, a
            nbt.append(pb.nt)
            nlb.append(pb.nl)

        # concatenate together and return all 4 results.
        m.d.comb += [not_a_term.eq(Cat(*nat)),
                     not_b_term.eq(Cat(*nbt)),
                     neg_lsb_a_term.eq(Cat(*nla)),
                     neg_lsb_b_term.eq(Cat(*nlb)),
                    ]

        return m
787
788
class IntermediateOut(Elaboratable):
    """ selects the HI/LO part of the multiplication, for a given bit-width
    the output is also reconstructed in its SIMD (partition) lanes.
    """
    def __init__(self, width, out_wid, n_parts):
        self.width = width
        self.n_parts = n_parts
        self.delayed_part_ops = [Signal(2, name="dpop%d" % i, reset_less=True)
                                 for i in range(8)]
        self.intermed = Signal(out_wid, reset_less=True)
        self.output = Signal(out_wid//2, reset_less=True)

    def elaborate(self, platform):
        m = Module()

        wid = self.width
        byte_stride = wid // 8
        lanes = []
        for part in range(self.n_parts):
            lane = Signal(wid, reset_less=True, name="op%d_%d" % (wid, part))
            # OP_MUL_LOW picks the LSB half of the double-width product;
            # all other op codes pick the MSB half
            lo = self.intermed.bit_select(part * wid*2, wid)
            hi = self.intermed.bit_select(part * wid*2 + wid, wid)
            sel_low = self.delayed_part_ops[byte_stride * part] == OP_MUL_LOW
            m.d.comb += lane.eq(Mux(sel_low, lo, hi))
            lanes.append(lane)
        m.d.comb += self.output.eq(Cat(*lanes))

        return m
817
818
class FinalOut(Elaboratable):
    """ selects the final output based on the partitioning.

    each byte is selectable independently, i.e. it is possible
    that some partitions requested 8-bit computation whilst others
    requested 16 or 32 bit.
    """
    def __init__(self, out_wid):
        # inputs: per-byte width selectors
        self.d8 = [Signal(name=f"d8_{i}", reset_less=True) for i in range(8)]
        self.d16 = [Signal(name=f"d16_{i}", reset_less=True) for i in range(4)]
        self.d32 = [Signal(name=f"d32_{i}", reset_less=True) for i in range(2)]

        # candidate results, one per computation width
        self.i8 = Signal(out_wid, reset_less=True)
        self.i16 = Signal(out_wid, reset_less=True)
        self.i32 = Signal(out_wid, reset_less=True)
        self.i64 = Signal(out_wid, reset_less=True)

        # output
        self.out = Signal(out_wid, reset_less=True)

    def elaborate(self, platform):
        m = Module()
        chosen = []
        for i in range(8):
            # select one of the outputs: d8 selects i8, d16 selects i16
            # d32 selects i32, and the default is i64.
            # d8 and d16 are ORed together in the first Mux
            # then the 2nd selects either i8 or i16.
            # if neither d8 nor d16 are set, d32 selects either i32 or i64.
            byte = Signal(8, reset_less=True, name="op_%d" % i)
            b8 = self.i8.bit_select(i * 8, 8)
            b16 = self.i16.bit_select(i * 8, 8)
            b32 = self.i32.bit_select(i * 8, 8)
            b64 = self.i64.bit_select(i * 8, 8)
            narrow = Mux(self.d8[i], b8, b16)
            wide = Mux(self.d32[i // 4], b32, b64)
            m.d.comb += byte.eq(Mux(self.d8[i] | self.d16[i // 2],
                                    narrow, wide))
            chosen.append(byte)
        m.d.comb += self.out.eq(Cat(*chosen))
        return m
859
860
class OrMod(Elaboratable):
    """ ORs four values together in a hierarchical tree
    """
    def __init__(self, wid):
        self.wid = wid
        self.orin = [Signal(wid, name="orin%d" % i, reset_less=True)
                     for i in range(4)]
        self.orout = Signal(wid, reset_less=True)

    def elaborate(self, platform):
        m = Module()
        # first level of the tree: pairwise ORs, then combine
        left = Signal(self.wid, reset_less=True)
        right = Signal(self.wid, reset_less=True)
        m.d.comb += [left.eq(self.orin[0] | self.orin[1]),
                     right.eq(self.orin[2] | self.orin[3]),
                     self.orout.eq(left | right)]

        return m
879
880
class Signs(Elaboratable):
    """ determines whether a or b are signed numbers
    based on the required operation type (OP_MUL_*)
    """

    def __init__(self):
        self.part_ops = Signal(2, reset_less=True)
        self.a_signed = Signal(reset_less=True)
        self.b_signed = Signal(reset_less=True)

    def elaborate(self, platform):

        m = Module()

        # a is signed for every op code except OP_MUL_UNSIGNED_HIGH (mulhu)
        m.d.comb += self.a_signed.eq(self.part_ops != OP_MUL_UNSIGNED_HIGH)
        # b is signed only for OP_MUL_LOW (mul) and OP_MUL_SIGNED_HIGH (mulh)
        m.d.comb += self.b_signed.eq((self.part_ops == OP_MUL_LOW)
                                     | (self.part_ops == OP_MUL_SIGNED_HIGH))

        return m
902
903
class Mul8_16_32_64(Elaboratable):
    """Signed/Unsigned 8/16/32/64-bit partitioned integer multiplier.

    Supports partitioning into any combination of 8, 16, 32, and 64-bit
    partitions on naturally-aligned boundaries. Supports the operation being
    set for each partition independently.

    :attribute part_pts: the input partition points. Has a partition point at
        multiples of 8 in 0 < i < 64. Each partition point's associated
        ``Value`` is a ``Signal``. Modification not supported, except for by
        ``Signal.eq``.
    :attribute part_ops: the operation for each byte. The operation for a
        particular partition is selected by assigning the selected operation
        code to each byte in the partition. The allowed operation codes are:

        :attribute OP_MUL_LOW: the LSB half of the product. Equivalent to
            RISC-V's `mul` instruction.
        :attribute OP_MUL_SIGNED_HIGH: the MSB half of the product where both
            ``a`` and ``b`` are signed. Equivalent to RISC-V's `mulh`
            instruction.
        :attribute OP_MUL_SIGNED_UNSIGNED_HIGH: the MSB half of the product
            where ``a`` is signed and ``b`` is unsigned. Equivalent to RISC-V's
            `mulhsu` instruction.
        :attribute OP_MUL_UNSIGNED_HIGH: the MSB half of the product where both
            ``a`` and ``b`` are unsigned. Equivalent to RISC-V's `mulhu`
            instruction.
    :attribute a: the first 64-bit operand input.
    :attribute b: the second 64-bit operand input.
    :attribute output: the 64-bit result output.
    """

    def __init__(self, register_levels=()):
        """ register_levels: specifies the points in the cascade at which
            flip-flops are to be inserted.
        """

        # parameter(s)
        self.register_levels = list(register_levels)

        # inputs: a partition-enable Signal at every byte boundary
        self.part_pts = PartitionPoints()
        for i in range(8, 64, 8):
            self.part_pts[i] = Signal(name=f"part_pts_{i}")
        self.part_ops = [Signal(2, name=f"part_ops_{i}") for i in range(8)]
        self.a = Signal(64)
        self.b = Signal(64)

        # intermediates (needed for unit tests)
        self._intermediate_output = Signal(128)

        # output
        self.output = Signal(64)

    def _part_byte(self, index):
        # boundary byte below 0 or at the top of the word: always a
        # partition end, so return constant 1
        if index == -1 or index == 7:
            return C(True, 1)
        assert index >= 0 and index < 8
        # otherwise return the partition-point enable at the byte boundary
        return self.part_pts[index * 8 + 8]

    def elaborate(self, platform):
        m = Module()

        # collect part-bytes: one enable bit per byte boundary, packed
        # into a single 8-bit bus (bit 7 is constant 1, via _part_byte)
        pbs = Signal(8, reset_less=True)
        tl = []
        for i in range(8):
            pb = Signal(name="pb%d" % i, reset_less=True)
            m.d.comb += pb.eq(self._part_byte(i))
            tl.append(pb)
        m.d.comb += pbs.eq(Cat(*tl))

        # decode each byte's op code into a/b signedness flags
        signs = []
        for i in range(8):
            s = Signs()
            signs.append(s)
            setattr(m.submodules, "signs%d" % i, s)
            m.d.comb += s.part_ops.eq(self.part_ops[i])

        # delay part_ops through one sync stage per register level so the
        # output-selection stages see op codes aligned with the pipelined data
        delayed_part_ops = [
            [Signal(2, name=f"_delayed_part_ops_{delay}_{i}")
             for i in range(8)]
            for delay in range(1 + len(self.register_levels))]
        for i in range(len(self.part_ops)):
            m.d.comb += delayed_part_ops[0][i].eq(self.part_ops[i])
            m.d.sync += [delayed_part_ops[j + 1][i].eq(delayed_part_ops[j][i])
                         for j in range(len(self.register_levels))]

        # per-partition-size term generators (8/16/32/64-bit); collect
        # their sign-correction terms for later OR-merging
        n_levels = len(self.register_levels)+1
        m.submodules.part_8 = part_8 = Part(128, 8, n_levels, 8)
        m.submodules.part_16 = part_16 = Part(128, 4, n_levels, 8)
        m.submodules.part_32 = part_32 = Part(128, 2, n_levels, 8)
        m.submodules.part_64 = part_64 = Part(128, 1, n_levels, 8)
        nat_l, nbt_l, nla_l, nlb_l = [], [], [], []
        for mod in [part_8, part_16, part_32, part_64]:
            m.d.comb += mod.a.eq(self.a)
            m.d.comb += mod.b.eq(self.b)
            for i in range(len(signs)):
                m.d.comb += mod.a_signed[i].eq(signs[i].a_signed)
                m.d.comb += mod.b_signed[i].eq(signs[i].b_signed)
            m.d.comb += mod.pbs.eq(pbs)
            nat_l.append(mod.not_a_term)
            nbt_l.append(mod.not_b_term)
            nla_l.append(mod.neg_lsb_a_term)
            nlb_l.append(mod.neg_lsb_b_term)

        # byte-by-byte partial products: one ProductTerms generator per
        # byte row of operand a
        terms = []

        for a_index in range(8):
            t = ProductTerms(8, 128, 8, a_index, 8)
            setattr(m.submodules, "terms_%d" % a_index, t)

            m.d.comb += t.a.eq(self.a)
            m.d.comb += t.b.eq(self.b)
            m.d.comb += t.pb_en.eq(pbs)

            for term in t.terms:
                terms.append(term)

        # it's fine to bitwise-or data together since they are never enabled
        # at the same time
        m.submodules.nat_or = nat_or = OrMod(128)
        m.submodules.nbt_or = nbt_or = OrMod(128)
        m.submodules.nla_or = nla_or = OrMod(128)
        m.submodules.nlb_or = nlb_or = OrMod(128)
        for l, mod in [(nat_l, nat_or),
                       (nbt_l, nbt_or),
                       (nla_l, nla_or),
                       (nlb_l, nlb_or)]:
            for i in range(len(l)):
                m.d.comb += mod.orin[i].eq(l[i])
            terms.append(mod.orout)

        # the product is 128-bit (double width), so every partition point
        # moves to twice its input position
        expanded_part_pts = PartitionPoints()
        for i, v in self.part_pts.items():
            signal = Signal(name=f"expanded_part_pts_{i*2}", reset_less=True)
            expanded_part_pts[i * 2] = signal
            m.d.comb += signal.eq(v)

        # sum all terms; part_ops is passed in so it is synced alongside
        # the other data through the register levels
        add_reduce = AddReduce(terms,
                               128,
                               self.register_levels,
                               expanded_part_pts,
                               self.part_ops)

        m.submodules.add_reduce = add_reduce
        m.d.comb += self._intermediate_output.eq(add_reduce.output)
        # create _output_64
        m.submodules.io64 = io64 = IntermediateOut(64, 128, 1)
        m.d.comb += io64.intermed.eq(self._intermediate_output)
        for i in range(8):
            m.d.comb += io64.delayed_part_ops[i].eq(delayed_part_ops[-1][i])

        # create _output_32
        m.submodules.io32 = io32 = IntermediateOut(32, 128, 2)
        m.d.comb += io32.intermed.eq(self._intermediate_output)
        for i in range(8):
            m.d.comb += io32.delayed_part_ops[i].eq(delayed_part_ops[-1][i])

        # create _output_16
        m.submodules.io16 = io16 = IntermediateOut(16, 128, 4)
        m.d.comb += io16.intermed.eq(self._intermediate_output)
        for i in range(8):
            m.d.comb += io16.delayed_part_ops[i].eq(delayed_part_ops[-1][i])

        # create _output_8
        m.submodules.io8 = io8 = IntermediateOut(8, 128, 8)
        m.d.comb += io8.intermed.eq(self._intermediate_output)
        for i in range(8):
            m.d.comb += io8.delayed_part_ops[i].eq(delayed_part_ops[-1][i])

        # final output: select between the per-size intermediate results
        # using the (delayed) partition information from the Part modules
        m.submodules.finalout = finalout = FinalOut(64)
        for i in range(len(part_8.delayed_parts[-1])):
            m.d.comb += finalout.d8[i].eq(part_8.dplast[i])
        for i in range(len(part_16.delayed_parts[-1])):
            m.d.comb += finalout.d16[i].eq(part_16.dplast[i])
        for i in range(len(part_32.delayed_parts[-1])):
            m.d.comb += finalout.d32[i].eq(part_32.dplast[i])
        m.d.comb += finalout.i8.eq(io8.output)
        m.d.comb += finalout.i16.eq(io16.output)
        m.d.comb += finalout.i32.eq(io32.output)
        m.d.comb += finalout.i64.eq(io64.output)
        m.d.comb += self.output.eq(finalout.out)

        return m
1087
1088
if __name__ == "__main__":
    # run the nmigen command-line interface (e.g. Verilog generation)
    # for the default multiplier configuration
    mul = Mul8_16_32_64()
    ports = [mul.a, mul.b, mul._intermediate_output, mul.output]
    ports += list(mul.part_ops)
    ports += list(mul.part_pts.values())
    main(mul, ports=ports)