src/ieee754/part_mul_add/multiply.py
1 # SPDX-License-Identifier: LGPL-2.1-or-later
2 # See Notices.txt for copyright information
3 """Integer Multiplication."""
4
5 from nmigen import Signal, Module, Value, Elaboratable, Cat, C, Mux, Repl
6 from nmigen.hdl.ast import Assign
7 from abc import ABCMeta, abstractmethod
8 from nmigen.cli import main
9 from functools import reduce
10 from operator import or_
11
12
13 class PartitionPoints(dict):
14 """Partition points and corresponding ``Value``s.
15
16     The points at which an ALU is partitioned, along with ``Value``s that
17     specify whether the corresponding partition points are enabled.
18
19 For example: ``{1: True, 5: True, 10: True}`` with
20 ``width == 16`` specifies that the ALU is split into 4 sections:
21 * bits 0 <= ``i`` < 1
22 * bits 1 <= ``i`` < 5
23 * bits 5 <= ``i`` < 10
24 * bits 10 <= ``i`` < 16
25
26 If the partition_points were instead ``{1: True, 5: a, 10: True}``
27 where ``a`` is a 1-bit ``Signal``:
28 * If ``a`` is asserted:
29 * bits 0 <= ``i`` < 1
30 * bits 1 <= ``i`` < 5
31 * bits 5 <= ``i`` < 10
32 * bits 10 <= ``i`` < 16
33 * Otherwise
34 * bits 0 <= ``i`` < 1
35 * bits 1 <= ``i`` < 10
36 * bits 10 <= ``i`` < 16
37 """
38
39 def __init__(self, partition_points=None):
40 """Create a new ``PartitionPoints``.
41
42 :param partition_points: the input partition points to values mapping.
43 """
44 super().__init__()
45 if partition_points is not None:
46 for point, enabled in partition_points.items():
47 if not isinstance(point, int):
48 raise TypeError("point must be a non-negative integer")
49 if point < 0:
50 raise ValueError("point must be a non-negative integer")
51 self[point] = Value.wrap(enabled)
52
53 def like(self, name=None, src_loc_at=0):
54 """Create a new ``PartitionPoints`` with ``Signal``s for all values.
55
56 :param name: the base name for the new ``Signal``s.
57 """
58 if name is None:
59 name = Signal(src_loc_at=1+src_loc_at).name # get variable name
60 retval = PartitionPoints()
61 for point, enabled in self.items():
62 retval[point] = Signal(enabled.shape(), name=f"{name}_{point}")
63 return retval
64
65 def eq(self, rhs):
66 """Assign ``PartitionPoints`` using ``Signal.eq``."""
67 if set(self.keys()) != set(rhs.keys()):
68 raise ValueError("incompatible point set")
69 for point, enabled in self.items():
70 yield enabled.eq(rhs[point])
71
72 def as_mask(self, width):
73 """Create a bit-mask from `self`.
74
75 Each bit in the returned mask is clear only if the partition point at
76 the same bit-index is enabled.
77
78 :param width: the bit width of the resulting mask
79 """
80 bits = []
81 for i in range(width):
82 if i in self:
83 bits.append(~self[i])
84 else:
85 bits.append(True)
86 return Cat(*bits)
87
88 def get_max_partition_count(self, width):
89 """Get the maximum number of partitions.
90
91 Gets the number of partitions when all partition points are enabled.
92 """
93 retval = 1
94 for point in self.keys():
95 if point < width:
96 retval += 1
97 return retval
98
99 def fits_in_width(self, width):
100 """Check if all partition points are smaller than `width`."""
101 for point in self.keys():
102 if point >= width:
103 return False
104 return True
105
106
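# Illustrative usage sketch (not part of the original module): a single
# run-time partition point at bit 4 of an 8-bit ALU.  The helper name
# ``_demo_partition_points`` is made up purely for illustration.
def _demo_partition_points():
    p4 = Signal(name="p4")                     # 1 = split at bit 4
    pp = PartitionPoints({4: p4})
    assert pp.get_max_partition_count(8) == 2  # two lanes when p4 is set
    mask = pp.as_mask(8)                       # bit 4 is ~p4, other bits are 1
    return pp, mask

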
107 class FullAdder(Elaboratable):
108 """Full Adder.
109
110 :attribute in0: the first input
111 :attribute in1: the second input
112 :attribute in2: the third input
113 :attribute sum: the sum output
114 :attribute carry: the carry output
115
116 Rather than do individual full adders (and have an array of them,
117 which would be very slow to simulate), this module can specify the
118 bit width of the inputs and outputs: in effect it performs multiple
119 Full 3-2 Add operations "in parallel".
120 """
121
122 def __init__(self, width):
123 """Create a ``FullAdder``.
124
125 :param width: the bit width of the input and output
126 """
127 self.in0 = Signal(width)
128 self.in1 = Signal(width)
129 self.in2 = Signal(width)
130 self.sum = Signal(width)
131 self.carry = Signal(width)
132
133 def elaborate(self, platform):
134 """Elaborate this module."""
135 m = Module()
136 m.d.comb += self.sum.eq(self.in0 ^ self.in1 ^ self.in2)
137 m.d.comb += self.carry.eq((self.in0 & self.in1)
138 | (self.in1 & self.in2)
139 | (self.in2 & self.in0))
140 return m
141
142
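# Illustrative sketch (plain Python, not from the original code): the 3:2
# compressor identity that FullAdder implements independently for every bit
# position of its width-wide inputs.
def _check_full_adder_identity():
    for in0 in (0, 1):
        for in1 in (0, 1):
            for in2 in (0, 1):
                s = in0 ^ in1 ^ in2                          # sum bit
                c = (in0 & in1) | (in1 & in2) | (in2 & in0)  # carry bit
                assert in0 + in1 + in2 == s + 2 * c

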
143 class MaskedFullAdder(Elaboratable):
144 """Masked Full Adder.
145
146 :attribute mask: the carry partition mask
147 :attribute in0: the first input
148 :attribute in1: the second input
149 :attribute in2: the third input
150 :attribute sum: the sum output
151 :attribute mcarry: the masked carry output
152
153 FullAdders are always used with a "mask" on the output. To keep
154 the graphviz "clean", this class performs the masking here rather
155 than inside a large for-loop.
156
157     See the following links for why this is no longer derived from
158     FullAdder.  Each carry is shifted here *before* being ANDed with
159     the mask, so that an AOI (AND-OR-Invert) cell may be used, which
160     is more gate-efficient:
161     https://en.wikipedia.org/wiki/AND-OR-Invert
162     https://groups.google.com/d/msg/comp.arch/fcq-GLQqvas/vTxmcA0QAgAJ
163 """
164
165 def __init__(self, width):
166 """Create a ``MaskedFullAdder``.
167
168 :param width: the bit width of the input and output
169 """
170 self.width = width
171 self.mask = Signal(width, reset_less=True)
172 self.mcarry = Signal(width, reset_less=True)
173 self.in0 = Signal(width, reset_less=True)
174 self.in1 = Signal(width, reset_less=True)
175 self.in2 = Signal(width, reset_less=True)
176 self.sum = Signal(width, reset_less=True)
177
178 def elaborate(self, platform):
179 """Elaborate this module."""
180 m = Module()
181 s1 = Signal(self.width, reset_less=True)
182 s2 = Signal(self.width, reset_less=True)
183 s3 = Signal(self.width, reset_less=True)
184 c1 = Signal(self.width, reset_less=True)
185 c2 = Signal(self.width, reset_less=True)
186 c3 = Signal(self.width, reset_less=True)
187 m.d.comb += self.sum.eq(self.in0 ^ self.in1 ^ self.in2)
188 m.d.comb += s1.eq(Cat(0, self.in0))
189 m.d.comb += s2.eq(Cat(0, self.in1))
190 m.d.comb += s3.eq(Cat(0, self.in2))
191 m.d.comb += c1.eq(s1 & s2 & self.mask)
192 m.d.comb += c2.eq(s2 & s3 & self.mask)
193 m.d.comb += c3.eq(s3 & s1 & self.mask)
194 m.d.comb += self.mcarry.eq(c1 | c2 | c3)
195 return m
196
197
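# Illustrative sketch (assumption, plain integer arithmetic): what the masked
# carry works out to.  It is the ordinary full-adder carry, shifted up one bit
# and ANDed with the partition mask -- exactly what the three shifted-and-masked
# partial carries c1|c2|c3 in elaborate() reduce to.
def _masked_carry(in0, in1, in2, mask, width=8):
    lim = (1 << width) - 1
    carry = (in0 & in1) | (in1 & in2) | (in2 & in0)
    return ((carry << 1) & mask) & lim

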
198 class PartitionedAdder(Elaboratable):
199 """Partitioned Adder.
200
201 Performs the final add. The partition points are included in the
202 actual add (in one of the operands only), which causes a carry over
203 to the next bit. Then the final output *removes* the extra bits from
204 the result.
205
206 partition: .... P... P... P... P... (32 bits)
207 a : .... .... .... .... .... (32 bits)
208 b : .... .... .... .... .... (32 bits)
209 exp-a : ....P....P....P....P.... (32+4 bits, P=1 if no partition)
210 exp-b : ....0....0....0....0.... (32 bits plus 4 zeros)
211 exp-o : ....xN...xN...xN...xN... (32+4 bits - x to be discarded)
212 o : .... N... N... N... N... (32 bits - x ignored, N is carry-over)
213
214 :attribute width: the bit width of the input and output. Read-only.
215 :attribute a: the first input to the adder
216 :attribute b: the second input to the adder
217 :attribute output: the sum output
218 :attribute partition_points: the input partition points. Modification not
219 supported, except for by ``Signal.eq``.
220 """
221
222 def __init__(self, width, partition_points):
223 """Create a ``PartitionedAdder``.
224
225 :param width: the bit width of the input and output
226 :param partition_points: the input partition points
227 """
228 self.width = width
229 self.a = Signal(width)
230 self.b = Signal(width)
231 self.output = Signal(width)
232 self.partition_points = PartitionPoints(partition_points)
233 if not self.partition_points.fits_in_width(width):
234 raise ValueError("partition_points doesn't fit in width")
235 expanded_width = 0
236 for i in range(self.width):
237 if i in self.partition_points:
238 expanded_width += 1
239 expanded_width += 1
240 self._expanded_width = expanded_width
241 # XXX these have to remain here due to some horrible nmigen
242 # simulation bugs involving sync. it is *not* necessary to
243 # have them here, they should (under normal circumstances)
244 # be moved into elaborate, as they are entirely local
245 self._expanded_a = Signal(expanded_width) # includes extra part-points
246 self._expanded_b = Signal(expanded_width) # likewise.
247 self._expanded_o = Signal(expanded_width) # likewise.
248
249 def elaborate(self, platform):
250 """Elaborate this module."""
251 m = Module()
252 expanded_index = 0
253 # store bits in a list, use Cat later. graphviz is much cleaner
254         al, bl, ol, ea, eb, eo = [], [], [], [], [], []
255
256 # partition points are "breaks" (extra zeros or 1s) in what would
257 # otherwise be a massive long add. when the "break" points are 0,
258 # whatever is in it (in the output) is discarded. however when
259 # there is a "1", it causes a roll-over carry to the *next* bit.
260 # we still ignore the "break" bit in the [intermediate] output,
261 # however by that time we've got the effect that we wanted: the
262 # carry has been carried *over* the break point.
263
264 for i in range(self.width):
265 if i in self.partition_points:
266 # add extra bit set to 0 + 0 for enabled partition points
267 # and 1 + 0 for disabled partition points
268 ea.append(self._expanded_a[expanded_index])
269 al.append(~self.partition_points[i]) # add extra bit in a
270 eb.append(self._expanded_b[expanded_index])
271 bl.append(C(0)) # yes, add a zero
272 expanded_index += 1 # skip the extra point. NOT in the output
273 ea.append(self._expanded_a[expanded_index])
274 eb.append(self._expanded_b[expanded_index])
275 eo.append(self._expanded_o[expanded_index])
276 al.append(self.a[i])
277 bl.append(self.b[i])
278 ol.append(self.output[i])
279 expanded_index += 1
280
281 # combine above using Cat
282 m.d.comb += Cat(*ea).eq(Cat(*al))
283 m.d.comb += Cat(*eb).eq(Cat(*bl))
284 m.d.comb += Cat(*ol).eq(Cat(*eo))
285
286 # use only one addition to take advantage of look-ahead carry and
287 # special hardware on FPGAs
288 m.d.comb += self._expanded_o.eq(
289 self._expanded_a + self._expanded_b)
290 return m
291
292
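# Worked example (illustrative, plain Python): the expand/add/compress trick
# used by PartitionedAdder, for width=8 with a single partition point at bit 4.
# A "break" bit is inserted at bit 4 of the expanded operands: it is 1 in
# expanded-a only when the partition is *disabled*, so an incoming carry
# ripples across the boundary; when the partition is enabled it is 0+0 and
# swallows the carry.  The helper name is made up for illustration.
def _partitioned_add_8(a, b, split_at_4):
    brk = 0 if split_at_4 else 1
    exp_a = (a & 0xF) | (brk << 4) | ((a >> 4) << 5)
    exp_b = (b & 0xF) | ((b >> 4) << 5)
    exp_o = (exp_a + exp_b) & 0x1FF                      # one 9-bit add
    return (exp_o & 0xF) | (((exp_o >> 5) & 0xF) << 4)   # drop the break bit

# e.g. _partitioned_add_8(0x0F, 0x01, split_at_4=True)  == 0x00 (carry blocked)
#      _partitioned_add_8(0x0F, 0x01, split_at_4=False) == 0x10 (carry crosses)

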
293 FULL_ADDER_INPUT_COUNT = 3
294
295
296 class AddReduceSingle(Elaboratable):
297 """Add list of numbers together.
298
299 :attribute inputs: input ``Signal``s to be summed. Modification not
300 supported, except for by ``Signal.eq``.
301 :attribute register_levels: List of nesting levels that should have
302 pipeline registers.
303 :attribute output: output sum.
304 :attribute partition_points: the input partition points. Modification not
305 supported, except for by ``Signal.eq``.
306 """
307
308 def __init__(self, inputs, output_width, register_levels, partition_points,
309 part_ops):
310         """Create an ``AddReduceSingle``.
311
312 :param inputs: input ``Signal``s to be summed.
313 :param output_width: bit-width of ``output``.
314 :param register_levels: List of nesting levels that should have
315 pipeline registers.
316         :param partition_points: the input partition points.
            :param part_ops: the per-byte operation codes, carried alongside
                the sums.
317 """
318 self.part_ops = part_ops
319 self.out_part_ops = [Signal(2, name=f"out_part_ops_{i}")
320 for i in range(len(part_ops))]
321 self.inputs = list(inputs)
322 self._resized_inputs = [
323 Signal(output_width, name=f"resized_inputs[{i}]")
324 for i in range(len(self.inputs))]
325 self.register_levels = list(register_levels)
326 self.output = Signal(output_width)
327 self.partition_points = PartitionPoints(partition_points)
328 if not self.partition_points.fits_in_width(output_width):
329 raise ValueError("partition_points doesn't fit in output_width")
330 self._reg_partition_points = self.partition_points.like()
331
332 max_level = AddReduceSingle.get_max_level(len(self.inputs))
333 for level in self.register_levels:
334 if level > max_level:
335 raise ValueError(
336 "not enough adder levels for specified register levels")
337
338 # this is annoying. we have to create the modules (and terms)
339 # because we need to know what they are (in order to set up the
340 # interconnects back in AddReduce), but cannot do the m.d.comb +=
341         # etc because this is not inside elaborate().
342 self.groups = AddReduceSingle.full_adder_groups(len(self.inputs))
343 self._intermediate_terms = []
344 if len(self.groups) != 0:
345 self.create_next_terms()
346
347 @staticmethod
348 def get_max_level(input_count):
349 """Get the maximum level.
350
351 All ``register_levels`` must be less than or equal to the maximum
352 level.
353 """
354 retval = 0
355 while True:
356 groups = AddReduceSingle.full_adder_groups(input_count)
357 if len(groups) == 0:
358 return retval
359 input_count %= FULL_ADDER_INPUT_COUNT
360 input_count += 2 * len(groups)
361 retval += 1
362
363 @staticmethod
364 def full_adder_groups(input_count):
365 """Get ``inputs`` indices for which a full adder should be built."""
366 return range(0,
367 input_count - FULL_ADDER_INPUT_COUNT + 1,
368 FULL_ADDER_INPUT_COUNT)
369
370 def elaborate(self, platform):
371 """Elaborate this module."""
372 m = Module()
373
374 # resize inputs to correct bit-width and optionally add in
375 # pipeline registers
376 resized_input_assignments = [self._resized_inputs[i].eq(self.inputs[i])
377 for i in range(len(self.inputs))]
378 copy_part_ops = [self.out_part_ops[i].eq(self.part_ops[i])
379 for i in range(len(self.part_ops))]
380 if 0 in self.register_levels:
381 m.d.sync += copy_part_ops
382 m.d.sync += resized_input_assignments
383 m.d.sync += self._reg_partition_points.eq(self.partition_points)
384 else:
385 m.d.comb += copy_part_ops
386 m.d.comb += resized_input_assignments
387 m.d.comb += self._reg_partition_points.eq(self.partition_points)
388
389 for (value, term) in self._intermediate_terms:
390 m.d.comb += term.eq(value)
391
392 # if there are no full adders to create, then we handle the base cases
393 # and return, otherwise we go on to the recursive case
394 if len(self.groups) == 0:
395 if len(self.inputs) == 0:
396 # use 0 as the default output value
397 m.d.comb += self.output.eq(0)
398 elif len(self.inputs) == 1:
399 # handle single input
400 m.d.comb += self.output.eq(self._resized_inputs[0])
401 else:
402 # base case for adding 2 inputs
403 assert len(self.inputs) == 2
404 adder = PartitionedAdder(len(self.output),
405 self._reg_partition_points)
406 m.submodules.final_adder = adder
407 m.d.comb += adder.a.eq(self._resized_inputs[0])
408 m.d.comb += adder.b.eq(self._resized_inputs[1])
409 m.d.comb += self.output.eq(adder.output)
410 return m
411
412 mask = self._reg_partition_points.as_mask(len(self.output))
413 m.d.comb += self.part_mask.eq(mask)
414
415 # add and link the intermediate term modules
416 for i, (iidx, adder_i) in enumerate(self.adders):
417 setattr(m.submodules, f"adder_{i}", adder_i)
418
419 m.d.comb += adder_i.in0.eq(self._resized_inputs[iidx])
420 m.d.comb += adder_i.in1.eq(self._resized_inputs[iidx + 1])
421 m.d.comb += adder_i.in2.eq(self._resized_inputs[iidx + 2])
422 m.d.comb += adder_i.mask.eq(self.part_mask)
423
424 return m
425
426 def create_next_terms(self):
427
428 # go on to prepare recursive case
429 intermediate_terms = []
430 _intermediate_terms = []
431
432 def add_intermediate_term(value):
433 intermediate_term = Signal(
434 len(self.output),
435 name=f"intermediate_terms[{len(intermediate_terms)}]")
436 _intermediate_terms.append((value, intermediate_term))
437 intermediate_terms.append(intermediate_term)
438
439 # store mask in intermediary (simplifies graph)
440 self.part_mask = Signal(len(self.output), reset_less=True)
441
442 # create full adders for this recursive level.
443 # this shrinks N terms to 2 * (N // 3) plus the remainder
444 self.adders = []
445 for i in self.groups:
446 adder_i = MaskedFullAdder(len(self.output))
447 self.adders.append((i, adder_i))
448 # add both the sum and the masked-carry to the next level.
449 # 3 inputs have now been reduced to 2...
450 add_intermediate_term(adder_i.sum)
451 add_intermediate_term(adder_i.mcarry)
452 # handle the remaining inputs.
453 if len(self.inputs) % FULL_ADDER_INPUT_COUNT == 1:
454 add_intermediate_term(self._resized_inputs[-1])
455 elif len(self.inputs) % FULL_ADDER_INPUT_COUNT == 2:
456             # just pass the two remaining terms to the next layer: a half
457             # adder would not reduce the term count (there would still be
458             # 2 terms), and passing them straight through saves gates.
459 add_intermediate_term(self._resized_inputs[-2])
460 add_intermediate_term(self._resized_inputs[-1])
461 else:
462 assert len(self.inputs) % FULL_ADDER_INPUT_COUNT == 0
463
464 self.intermediate_terms = intermediate_terms
465 self._intermediate_terms = _intermediate_terms
466
467
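# Illustrative sketch (assumption, plain Python): how the 3:2 compression
# shrinks the term count from one AddReduceSingle level to the next.  Each
# group of 3 inputs becomes 2 outputs (sum plus masked carry) and the 0, 1 or
# 2 left-over inputs pass straight through, mirroring get_max_level() above.
def _reduction_schedule(input_count):
    counts = [input_count]
    while input_count > 2:                    # full_adder_groups() non-empty
        groups = input_count // FULL_ADDER_INPUT_COUNT
        input_count = input_count % FULL_ADDER_INPUT_COUNT + 2 * groups
        counts.append(input_count)
    return counts

# e.g. _reduction_schedule(9) == [9, 6, 4, 3, 2]

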
468 class AddReduce(Elaboratable):
469 """Recursively Add list of numbers together.
470
471 :attribute inputs: input ``Signal``s to be summed. Modification not
472 supported, except for by ``Signal.eq``.
473 :attribute register_levels: List of nesting levels that should have
474 pipeline registers.
475 :attribute output: output sum.
476 :attribute partition_points: the input partition points. Modification not
477 supported, except for by ``Signal.eq``.
478 """
479
480 def __init__(self, inputs, output_width, register_levels, partition_points,
481 part_ops):
482 """Create an ``AddReduce``.
483
484 :param inputs: input ``Signal``s to be summed.
485 :param output_width: bit-width of ``output``.
486 :param register_levels: List of nesting levels that should have
487 pipeline registers.
488         :param partition_points: the input partition points.
            :param part_ops: the per-byte operation codes, carried alongside
                the sums.
489 """
490 self.inputs = inputs
491 self.part_ops = part_ops
492 self.out_part_ops = [Signal(2, name=f"out_part_ops_{i}")
493 for i in range(len(part_ops))]
494 self.output = Signal(output_width)
495 self.output_width = output_width
496 self.register_levels = register_levels
497 self.partition_points = partition_points
498
499 self.create_levels()
500
501 @staticmethod
502 def get_max_level(input_count):
503 return AddReduceSingle.get_max_level(input_count)
504
505 @staticmethod
506 def next_register_levels(register_levels):
507 """``Iterable`` of ``register_levels`` for next recursive level."""
508 for level in register_levels:
509 if level > 0:
510 yield level - 1
511
512 def create_levels(self):
513         """Create the chain of reduction levels."""
514
515 mods = []
516 next_levels = self.register_levels
517 partition_points = self.partition_points
518 inputs = self.inputs
519 part_ops = self.part_ops
520 while True:
521 next_level = AddReduceSingle(inputs, self.output_width, next_levels,
522 partition_points, part_ops)
523 mods.append(next_level)
524 if len(next_level.groups) == 0:
525 break
526 next_levels = list(AddReduce.next_register_levels(next_levels))
527 partition_points = next_level._reg_partition_points
528 inputs = next_level.intermediate_terms
529 part_ops = next_level.out_part_ops
530
531 self.levels = mods
532
533 def elaborate(self, platform):
534 """Elaborate this module."""
535 m = Module()
536
537 for i, next_level in enumerate(self.levels):
538 setattr(m.submodules, "next_level%d" % i, next_level)
539
540 # output comes from last module
541 m.d.comb += self.output.eq(next_level.output)
542 copy_part_ops = [self.out_part_ops[i].eq(next_level.out_part_ops[i])
543 for i in range(len(self.part_ops))]
544 m.d.comb += copy_part_ops
545
546 return m
547
548
549 OP_MUL_LOW = 0
550 OP_MUL_SIGNED_HIGH = 1
551 OP_MUL_SIGNED_UNSIGNED_HIGH = 2 # a is signed, b is unsigned
552 OP_MUL_UNSIGNED_HIGH = 3
553
554
555 def get_term(value, shift=0, enabled=None):
556 if enabled is not None:
557 value = Mux(enabled, value, 0)
558 if shift > 0:
559 value = Cat(Repl(C(0, 1), shift), value)
560 else:
561 assert shift == 0
562 return value
563
564
565 class ProductTerm(Elaboratable):
566 """ this class creates a single product term (a[..]*b[..]).
567         it has a design flaw in that it is the *output* that is
568         selected, while the multiplication(s) are generated
569         combinatorially all the time.
570 """
571
572 def __init__(self, width, twidth, pbwid, a_index, b_index):
573 self.a_index = a_index
574 self.b_index = b_index
575 shift = 8 * (self.a_index + self.b_index)
576 self.pwidth = width
577 self.twidth = twidth
578 self.width = width*2
579 self.shift = shift
580
581 self.ti = Signal(self.width, reset_less=True)
582 self.term = Signal(twidth, reset_less=True)
583 self.a = Signal(twidth//2, reset_less=True)
584 self.b = Signal(twidth//2, reset_less=True)
585 self.pb_en = Signal(pbwid, reset_less=True)
586
587 self.tl = tl = []
588 min_index = min(self.a_index, self.b_index)
589 max_index = max(self.a_index, self.b_index)
590 for i in range(min_index, max_index):
591 tl.append(self.pb_en[i])
592 name = "te_%d_%d" % (self.a_index, self.b_index)
593 if len(tl) > 0:
594 term_enabled = Signal(name=name, reset_less=True)
595 else:
596 term_enabled = None
597 self.enabled = term_enabled
598 self.term.name = "term_%d_%d" % (a_index, b_index) # rename
599
600 def elaborate(self, platform):
601
602 m = Module()
603 if self.enabled is not None:
604 m.d.comb += self.enabled.eq(~(Cat(*self.tl).bool()))
605
606 bsa = Signal(self.width, reset_less=True)
607 bsb = Signal(self.width, reset_less=True)
608 a_index, b_index = self.a_index, self.b_index
609 pwidth = self.pwidth
610 m.d.comb += bsa.eq(self.a.part(a_index * pwidth, pwidth))
611 m.d.comb += bsb.eq(self.b.part(b_index * pwidth, pwidth))
612 m.d.comb += self.ti.eq(bsa * bsb)
613 m.d.comb += self.term.eq(get_term(self.ti, self.shift, self.enabled))
614 """
615 #TODO: sort out width issues, get inputs a/b switched on/off.
616 #data going into Muxes is 1/2 the required width
617
618 pwidth = self.pwidth
619 width = self.width
620 bsa = Signal(self.twidth//2, reset_less=True)
621 bsb = Signal(self.twidth//2, reset_less=True)
622 asel = Signal(width, reset_less=True)
623 bsel = Signal(width, reset_less=True)
624 a_index, b_index = self.a_index, self.b_index
625 m.d.comb += asel.eq(self.a.part(a_index * pwidth, pwidth))
626 m.d.comb += bsel.eq(self.b.part(b_index * pwidth, pwidth))
627 m.d.comb += bsa.eq(get_term(asel, self.shift, self.enabled))
628 m.d.comb += bsb.eq(get_term(bsel, self.shift, self.enabled))
629 m.d.comb += self.ti.eq(bsa * bsb)
630 m.d.comb += self.term.eq(self.ti)
631 """
632
633 return m
634
635
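# Illustrative sketch (assumption, plain Python): the value each ProductTerm
# contributes for 8-bit sub-words.  Byte i of a times byte j of b lands at bit
# offset 8*(i+j), and the term is forced to zero whenever any partition
# boundary (a set pb_en bit) lies between the two byte indices -- the same
# bits that ProductTerm gathers into ``tl`` to derive ``enabled``.
def _product_term_value(a, b, i, j, pb_en):
    a_byte = (a >> (8 * i)) & 0xFF
    b_byte = (b >> (8 * j)) & 0xFF
    crossed = any((pb_en >> k) & 1 for k in range(min(i, j), max(i, j)))
    return 0 if crossed else (a_byte * b_byte) << (8 * (i + j))

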
636 class ProductTerms(Elaboratable):
637     """ creates a bank of product terms, and also performs the actual
638         bit-selection.  this class is to be wrapped with a for-loop on
639         the "a" operand; it creates a second-level for-loop on the "b" operand.
640 """
641 def __init__(self, width, twidth, pbwid, a_index, blen):
642 self.a_index = a_index
643 self.blen = blen
644 self.pwidth = width
645 self.twidth = twidth
646 self.pbwid = pbwid
647 self.a = Signal(twidth//2, reset_less=True)
648 self.b = Signal(twidth//2, reset_less=True)
649 self.pb_en = Signal(pbwid, reset_less=True)
650 self.terms = [Signal(twidth, name="term%d"%i, reset_less=True) \
651 for i in range(blen)]
652
653 def elaborate(self, platform):
654
655 m = Module()
656
657 for b_index in range(self.blen):
658 t = ProductTerm(self.pwidth, self.twidth, self.pbwid,
659 self.a_index, b_index)
660 setattr(m.submodules, "term_%d" % b_index, t)
661
662 m.d.comb += t.a.eq(self.a)
663 m.d.comb += t.b.eq(self.b)
664 m.d.comb += t.pb_en.eq(self.pb_en)
665
666 m.d.comb += self.terms[b_index].eq(t.term)
667
668 return m
669
670
671 class LSBNegTerm(Elaboratable):
672
673 def __init__(self, bit_width):
674 self.bit_width = bit_width
675 self.part = Signal(reset_less=True)
676 self.signed = Signal(reset_less=True)
677 self.op = Signal(bit_width, reset_less=True)
678 self.msb = Signal(reset_less=True)
679 self.nt = Signal(bit_width*2, reset_less=True)
680 self.nl = Signal(bit_width*2, reset_less=True)
681
682 def elaborate(self, platform):
683 m = Module()
684 comb = m.d.comb
685 bit_wid = self.bit_width
686 ext = Repl(0, bit_wid) # extend output to HI part
687
688 # determine sign of each incoming number *in this partition*
689 enabled = Signal(reset_less=True)
690 m.d.comb += enabled.eq(self.part & self.msb & self.signed)
691
692 # for 8-bit values: form a * 0xFF00 by using -a * 0x100, the
693 # negation operation is split into a bitwise not and a +1.
694 # likewise for 16, 32, and 64-bit values.
695
696 # width-extended 1s complement if a is signed, otherwise zero
697 comb += self.nt.eq(Mux(enabled, Cat(ext, ~self.op), 0))
698
699 # add 1 if signed, otherwise add zero
700 comb += self.nl.eq(Cat(ext, enabled, Repl(0, bit_wid-1)))
701
702 return m
703
704
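# Worked example (illustrative, plain Python): the identity behind LSBNegTerm
# for an 8-bit partition.  The sign-correction term a * 0xFF00 (mod 2**16) is
# formed as -a * 0x100, with the negation split into a bitwise NOT plus one --
# exactly the pair of outputs ``nt`` (the inverted, width-extended operand)
# and ``nl`` (the lone +1 at bit 8).  The hardware additionally gates both
# terms with ``enabled``.
def _check_lsb_neg_term(a):
    assert 0 <= a <= 0xFF
    nt = (~a & 0xFF) << 8          # width-extended one's complement
    nl = 1 << 8                    # the "+1", placed in the upper half's LSB
    assert (nt + nl) & 0xFFFF == (a * 0xFF00) & 0xFFFF

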
705 class Part(Elaboratable):
706 """ a key class which, depending on the partitioning, will determine
707 what action to take when parts of the output are signed or unsigned.
708
709 this requires 2 pieces of data *per operand, per partition*:
710 whether the MSB is HI/LO (per partition!), and whether a signed
711 or unsigned operation has been *requested*.
712
713     once that is determined, signed multiplication is basically carried out
714 by splitting 2's complement into 1's complement plus one.
715 1's complement is just a bit-inversion.
716
717 the extra terms - as separate terms - are then thrown at the
718 AddReduce alongside the multiplication part-results.
719 """
720 def __init__(self, width, n_parts, n_levels, pbwid):
721
722 # inputs
723 self.a = Signal(64)
724 self.b = Signal(64)
725 self.a_signed = [Signal(name=f"a_signed_{i}") for i in range(8)]
726         self.b_signed = [Signal(name=f"b_signed_{i}") for i in range(8)]
727 self.pbs = Signal(pbwid, reset_less=True)
728
729 # outputs
730 self.parts = [Signal(name=f"part_{i}") for i in range(n_parts)]
731 self.delayed_parts = [
732 [Signal(name=f"delayed_part_{delay}_{i}")
733 for i in range(n_parts)]
734 for delay in range(n_levels)]
735 # XXX REALLY WEIRD BUG - have to take a copy of the last delayed_parts
736 self.dplast = [Signal(name=f"dplast_{i}")
737 for i in range(n_parts)]
738
739 self.not_a_term = Signal(width)
740 self.neg_lsb_a_term = Signal(width)
741 self.not_b_term = Signal(width)
742 self.neg_lsb_b_term = Signal(width)
743
744 def elaborate(self, platform):
745 m = Module()
746
747 pbs, parts, delayed_parts = self.pbs, self.parts, self.delayed_parts
748 # negated-temporary copy of partition bits
749 npbs = Signal.like(pbs, reset_less=True)
750 m.d.comb += npbs.eq(~pbs)
751 byte_count = 8 // len(parts)
752 for i in range(len(parts)):
753 pbl = []
754 pbl.append(npbs[i * byte_count - 1])
755 for j in range(i * byte_count, (i + 1) * byte_count - 1):
756 pbl.append(pbs[j])
757 pbl.append(npbs[(i + 1) * byte_count - 1])
758 value = Signal(len(pbl), name="value_%di" % i, reset_less=True)
759 m.d.comb += value.eq(Cat(*pbl))
760 m.d.comb += parts[i].eq(~(value).bool())
761 m.d.comb += delayed_parts[0][i].eq(parts[i])
762 m.d.sync += [delayed_parts[j + 1][i].eq(delayed_parts[j][i])
763 for j in range(len(delayed_parts)-1)]
764 m.d.comb += self.dplast[i].eq(delayed_parts[-1][i])
765
766 not_a_term, neg_lsb_a_term, not_b_term, neg_lsb_b_term = \
767 self.not_a_term, self.neg_lsb_a_term, \
768 self.not_b_term, self.neg_lsb_b_term
769
770 byte_width = 8 // len(parts) # byte width
771 bit_wid = 8 * byte_width # bit width
772 nat, nbt, nla, nlb = [], [], [], []
773 for i in range(len(parts)):
774 # work out bit-inverted and +1 term for a.
775 pa = LSBNegTerm(bit_wid)
776 setattr(m.submodules, "lnt_%d_a_%d" % (bit_wid, i), pa)
777 m.d.comb += pa.part.eq(parts[i])
778 m.d.comb += pa.op.eq(self.a.part(bit_wid * i, bit_wid))
779 m.d.comb += pa.signed.eq(self.b_signed[i * byte_width]) # yes b
780 m.d.comb += pa.msb.eq(self.b[(i + 1) * bit_wid - 1]) # really, b
781 nat.append(pa.nt)
782 nla.append(pa.nl)
783
784 # work out bit-inverted and +1 term for b
785 pb = LSBNegTerm(bit_wid)
786 setattr(m.submodules, "lnt_%d_b_%d" % (bit_wid, i), pb)
787 m.d.comb += pb.part.eq(parts[i])
788 m.d.comb += pb.op.eq(self.b.part(bit_wid * i, bit_wid))
789 m.d.comb += pb.signed.eq(self.a_signed[i * byte_width]) # yes a
790 m.d.comb += pb.msb.eq(self.a[(i + 1) * bit_wid - 1]) # really, a
791 nbt.append(pb.nt)
792 nlb.append(pb.nl)
793
794 # concatenate together and return all 4 results.
795 m.d.comb += [not_a_term.eq(Cat(*nat)),
796 not_b_term.eq(Cat(*nbt)),
797 neg_lsb_a_term.eq(Cat(*nla)),
798 neg_lsb_b_term.eq(Cat(*nlb)),
799 ]
800
801 return m
802
803
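# Worked example (illustrative, plain Python): why the extra terms produced by
# Part/LSBNegTerm are sufficient to turn an unsigned 8x8 multiply into a signed
# one, modulo 2**16.  Sign-extending a negative operand adds 0xFF00 to it, and
# the 0xFF00 * 0xFF00 cross product vanishes mod 2**16, so only the two
# correction terms need to be thrown at the AddReduce alongside the partial
# products.  The 16-bit two's-complement product then has the correct upper
# half for the *_HIGH operations.
def _check_signed_product(a_s, b_s):
    assert -128 <= a_s <= 127 and -128 <= b_s <= 127
    a_u, b_u = a_s & 0xFF, b_s & 0xFF
    product = a_u * b_u
    if b_s < 0:
        product += a_u * 0xFF00    # correction for b's sign extension
    if a_s < 0:
        product += b_u * 0xFF00    # correction for a's sign extension
    assert product & 0xFFFF == (a_s * b_s) & 0xFFFF

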
804 class IntermediateOut(Elaboratable):
805     """ selects the HI/LO part of the multiplication for a given bit-width.
806         the output is also reconstructed in its SIMD (partition) lanes.
807 """
808 def __init__(self, width, out_wid, n_parts):
809 self.width = width
810 self.n_parts = n_parts
811 self.part_ops = [Signal(2, name="dpop%d" % i, reset_less=True)
812 for i in range(8)]
813 self.intermed = Signal(out_wid, reset_less=True)
814 self.output = Signal(out_wid//2, reset_less=True)
815
816 def elaborate(self, platform):
817 m = Module()
818
819 ol = []
820 w = self.width
821 sel = w // 8
822 for i in range(self.n_parts):
823 op = Signal(w, reset_less=True, name="op%d_%d" % (w, i))
824 m.d.comb += op.eq(
825 Mux(self.part_ops[sel * i] == OP_MUL_LOW,
826 self.intermed.part(i * w*2, w),
827 self.intermed.part(i * w*2 + w, w)))
828 ol.append(op)
829 m.d.comb += self.output.eq(Cat(*ol))
830
831 return m
832
833
834 class FinalOut(Elaboratable):
835 """ selects the final output based on the partitioning.
836
837 each byte is selectable independently, i.e. it is possible
838 that some partitions requested 8-bit computation whilst others
839 requested 16 or 32 bit.
840 """
841 def __init__(self, out_wid):
842 # inputs
843 self.d8 = [Signal(name=f"d8_{i}", reset_less=True) for i in range(8)]
844 self.d16 = [Signal(name=f"d16_{i}", reset_less=True) for i in range(4)]
845 self.d32 = [Signal(name=f"d32_{i}", reset_less=True) for i in range(2)]
846
847 self.i8 = Signal(out_wid, reset_less=True)
848 self.i16 = Signal(out_wid, reset_less=True)
849 self.i32 = Signal(out_wid, reset_less=True)
850 self.i64 = Signal(out_wid, reset_less=True)
851
852 # output
853 self.out = Signal(out_wid, reset_less=True)
854
855 def elaborate(self, platform):
856 m = Module()
857 ol = []
858 for i in range(8):
859 # select one of the outputs: d8 selects i8, d16 selects i16
860 # d32 selects i32, and the default is i64.
861 # d8 and d16 are ORed together in the first Mux
862 # then the 2nd selects either i8 or i16.
863 # if neither d8 nor d16 are set, d32 selects either i32 or i64.
864 op = Signal(8, reset_less=True, name="op_%d" % i)
865 m.d.comb += op.eq(
866 Mux(self.d8[i] | self.d16[i // 2],
867 Mux(self.d8[i], self.i8.part(i * 8, 8),
868 self.i16.part(i * 8, 8)),
869 Mux(self.d32[i // 4], self.i32.part(i * 8, 8),
870 self.i64.part(i * 8, 8))))
871 ol.append(op)
872 m.d.comb += self.out.eq(Cat(*ol))
873 return m
874
875
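# Illustrative sketch (assumption, plain Python): the per-byte selection
# priority implemented by the nested Mux in FinalOut.  d8 wins over d16,
# which wins over d32, with the 64-bit result as the fallback.
def _select_final_byte(i, d8, d16, d32, i8, i16, i32, i64):
    if d8[i]:
        src = i8
    elif d16[i // 2]:
        src = i16
    elif d32[i // 4]:
        src = i32
    else:
        src = i64
    return (src >> (8 * i)) & 0xFF

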
876 class OrMod(Elaboratable):
877 """ ORs four values together in a hierarchical tree
878 """
879 def __init__(self, wid):
880 self.wid = wid
881 self.orin = [Signal(wid, name="orin%d" % i, reset_less=True)
882 for i in range(4)]
883 self.orout = Signal(wid, reset_less=True)
884
885 def elaborate(self, platform):
886 m = Module()
887 or1 = Signal(self.wid, reset_less=True)
888 or2 = Signal(self.wid, reset_less=True)
889 m.d.comb += or1.eq(self.orin[0] | self.orin[1])
890 m.d.comb += or2.eq(self.orin[2] | self.orin[3])
891 m.d.comb += self.orout.eq(or1 | or2)
892
893 return m
894
895
896 class Signs(Elaboratable):
897 """ determines whether a or b are signed numbers
898 based on the required operation type (OP_MUL_*)
899 """
900
901 def __init__(self):
902 self.part_ops = Signal(2, reset_less=True)
903 self.a_signed = Signal(reset_less=True)
904 self.b_signed = Signal(reset_less=True)
905
906 def elaborate(self, platform):
907
908 m = Module()
909
910 asig = self.part_ops != OP_MUL_UNSIGNED_HIGH
911 bsig = (self.part_ops == OP_MUL_LOW) \
912 | (self.part_ops == OP_MUL_SIGNED_HIGH)
913 m.d.comb += self.a_signed.eq(asig)
914 m.d.comb += self.b_signed.eq(bsig)
915
916 return m
917
918
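# Illustrative decode table (not used by the code): the (a_signed, b_signed)
# pair produced by Signs for each operation.  ``a`` is treated as signed for
# everything except OP_MUL_UNSIGNED_HIGH; ``b`` only for the fully-signed ops.
_SIGNS_TABLE = {
    OP_MUL_LOW:                  (True, True),
    OP_MUL_SIGNED_HIGH:          (True, True),
    OP_MUL_SIGNED_UNSIGNED_HIGH: (True, False),
    OP_MUL_UNSIGNED_HIGH:        (False, False),
}

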
919 class Mul8_16_32_64(Elaboratable):
920 """Signed/Unsigned 8/16/32/64-bit partitioned integer multiplier.
921
922 Supports partitioning into any combination of 8, 16, 32, and 64-bit
923 partitions on naturally-aligned boundaries. Supports the operation being
924 set for each partition independently.
925
926 :attribute part_pts: the input partition points. Has a partition point at
927 multiples of 8 in 0 < i < 64. Each partition point's associated
928 ``Value`` is a ``Signal``. Modification not supported, except for by
929 ``Signal.eq``.
930 :attribute part_ops: the operation for each byte. The operation for a
931 particular partition is selected by assigning the selected operation
932 code to each byte in the partition. The allowed operation codes are:
933
934 :attribute OP_MUL_LOW: the LSB half of the product. Equivalent to
935 RISC-V's `mul` instruction.
936 :attribute OP_MUL_SIGNED_HIGH: the MSB half of the product where both
937 ``a`` and ``b`` are signed. Equivalent to RISC-V's `mulh`
938 instruction.
939 :attribute OP_MUL_SIGNED_UNSIGNED_HIGH: the MSB half of the product
940 where ``a`` is signed and ``b`` is unsigned. Equivalent to RISC-V's
941 `mulhsu` instruction.
942 :attribute OP_MUL_UNSIGNED_HIGH: the MSB half of the product where both
943 ``a`` and ``b`` are unsigned. Equivalent to RISC-V's `mulhu`
944 instruction.
945 """
946
947 def __init__(self, register_levels=()):
948 """ register_levels: specifies the points in the cascade at which
949 flip-flops are to be inserted.
950 """
951
952 # parameter(s)
953 self.register_levels = list(register_levels)
954
955 # inputs
956 self.part_pts = PartitionPoints()
957 for i in range(8, 64, 8):
958 self.part_pts[i] = Signal(name=f"part_pts_{i}")
959 self.part_ops = [Signal(2, name=f"part_ops_{i}") for i in range(8)]
960 self.a = Signal(64)
961 self.b = Signal(64)
962
963 # intermediates (needed for unit tests)
964 self._intermediate_output = Signal(128)
965
966 # output
967 self.output = Signal(64)
968
969 def _part_byte(self, index):
970 if index == -1 or index == 7:
971 return C(True, 1)
972 assert index >= 0 and index < 8
973 return self.part_pts[index * 8 + 8]
974
975 def elaborate(self, platform):
976 m = Module()
977
978 # collect part-bytes
979 pbs = Signal(8, reset_less=True)
980 tl = []
981 for i in range(8):
982 pb = Signal(name="pb%d" % i, reset_less=True)
983 m.d.comb += pb.eq(self._part_byte(i))
984 tl.append(pb)
985 m.d.comb += pbs.eq(Cat(*tl))
986
987 # local variables
988 signs = []
989 for i in range(8):
990 s = Signs()
991 signs.append(s)
992 setattr(m.submodules, "signs%d" % i, s)
993 m.d.comb += s.part_ops.eq(self.part_ops[i])
994
995 n_levels = len(self.register_levels)+1
996 m.submodules.part_8 = part_8 = Part(128, 8, n_levels, 8)
997 m.submodules.part_16 = part_16 = Part(128, 4, n_levels, 8)
998 m.submodules.part_32 = part_32 = Part(128, 2, n_levels, 8)
999 m.submodules.part_64 = part_64 = Part(128, 1, n_levels, 8)
1000 nat_l, nbt_l, nla_l, nlb_l = [], [], [], []
1001 for mod in [part_8, part_16, part_32, part_64]:
1002 m.d.comb += mod.a.eq(self.a)
1003 m.d.comb += mod.b.eq(self.b)
1004 for i in range(len(signs)):
1005 m.d.comb += mod.a_signed[i].eq(signs[i].a_signed)
1006 m.d.comb += mod.b_signed[i].eq(signs[i].b_signed)
1007 m.d.comb += mod.pbs.eq(pbs)
1008 nat_l.append(mod.not_a_term)
1009 nbt_l.append(mod.not_b_term)
1010 nla_l.append(mod.neg_lsb_a_term)
1011 nlb_l.append(mod.neg_lsb_b_term)
1012
1013 terms = []
1014
1015 for a_index in range(8):
1016 t = ProductTerms(8, 128, 8, a_index, 8)
1017 setattr(m.submodules, "terms_%d" % a_index, t)
1018
1019 m.d.comb += t.a.eq(self.a)
1020 m.d.comb += t.b.eq(self.b)
1021 m.d.comb += t.pb_en.eq(pbs)
1022
1023 for term in t.terms:
1024 terms.append(term)
1025
1026 # it's fine to bitwise-or data together since they are never enabled
1027 # at the same time
1028 m.submodules.nat_or = nat_or = OrMod(128)
1029 m.submodules.nbt_or = nbt_or = OrMod(128)
1030 m.submodules.nla_or = nla_or = OrMod(128)
1031 m.submodules.nlb_or = nlb_or = OrMod(128)
1032 for l, mod in [(nat_l, nat_or),
1033 (nbt_l, nbt_or),
1034 (nla_l, nla_or),
1035 (nlb_l, nlb_or)]:
1036 for i in range(len(l)):
1037 m.d.comb += mod.orin[i].eq(l[i])
1038 terms.append(mod.orout)
1039
1040 expanded_part_pts = PartitionPoints()
1041 for i, v in self.part_pts.items():
1042 signal = Signal(name=f"expanded_part_pts_{i*2}", reset_less=True)
1043 expanded_part_pts[i * 2] = signal
1044 m.d.comb += signal.eq(v)
1045
1046 add_reduce = AddReduce(terms,
1047 128,
1048 self.register_levels,
1049 expanded_part_pts,
1050 self.part_ops)
1051
1052 out_part_ops = add_reduce.levels[-1].out_part_ops
1053
1054 m.submodules.add_reduce = add_reduce
1055 m.d.comb += self._intermediate_output.eq(add_reduce.output)
1056 # create _output_64
1057 m.submodules.io64 = io64 = IntermediateOut(64, 128, 1)
1058 m.d.comb += io64.intermed.eq(self._intermediate_output)
1059 for i in range(8):
1060 m.d.comb += io64.part_ops[i].eq(out_part_ops[i])
1061
1062 # create _output_32
1063 m.submodules.io32 = io32 = IntermediateOut(32, 128, 2)
1064 m.d.comb += io32.intermed.eq(self._intermediate_output)
1065 for i in range(8):
1066 m.d.comb += io32.part_ops[i].eq(out_part_ops[i])
1067
1068 # create _output_16
1069 m.submodules.io16 = io16 = IntermediateOut(16, 128, 4)
1070 m.d.comb += io16.intermed.eq(self._intermediate_output)
1071 for i in range(8):
1072 m.d.comb += io16.part_ops[i].eq(out_part_ops[i])
1073
1074 # create _output_8
1075 m.submodules.io8 = io8 = IntermediateOut(8, 128, 8)
1076 m.d.comb += io8.intermed.eq(self._intermediate_output)
1077 for i in range(8):
1078 m.d.comb += io8.part_ops[i].eq(out_part_ops[i])
1079
1080 # final output
1081 m.submodules.finalout = finalout = FinalOut(64)
1082 for i in range(len(part_8.delayed_parts[-1])):
1083 m.d.comb += finalout.d8[i].eq(part_8.dplast[i])
1084 for i in range(len(part_16.delayed_parts[-1])):
1085 m.d.comb += finalout.d16[i].eq(part_16.dplast[i])
1086 for i in range(len(part_32.delayed_parts[-1])):
1087 m.d.comb += finalout.d32[i].eq(part_32.dplast[i])
1088 m.d.comb += finalout.i8.eq(io8.output)
1089 m.d.comb += finalout.i16.eq(io16.output)
1090 m.d.comb += finalout.i32.eq(io32.output)
1091 m.d.comb += finalout.i64.eq(io64.output)
1092 m.d.comb += self.output.eq(finalout.out)
1093
1094 return m
1095
1096
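# Illustrative reference model (assumption, plain Python): what the output of
# Mul8_16_32_64 means when the 64-bit datapath is split into two 32-bit lanes
# by asserting part_pts[32] and setting every byte's part_ops to OP_MUL_LOW.
# Each lane independently yields the low half of its 32x32 product, like a
# pair of RISC-V `mul` instructions.
def _reference_two_lane_mul_low(a, b):
    # a and b are 64-bit unsigned integer values
    lo = ((a & 0xFFFFFFFF) * (b & 0xFFFFFFFF)) & 0xFFFFFFFF
    hi = ((a >> 32) * (b >> 32)) & 0xFFFFFFFF
    return lo | (hi << 32)

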
1097 if __name__ == "__main__":
1098 m = Mul8_16_32_64()
1099 main(m, ports=[m.a,
1100 m.b,
1101 m._intermediate_output,
1102 m.output,
1103 *m.part_ops,
1104 *m.part_pts.values()])