445095b1af7eee292f395747a327ac8627bea7ad
[ieee754fpu.git] / src / ieee754 / part_mul_add / multiply.py
1 # SPDX-License-Identifier: LGPL-2.1-or-later
2 # See Notices.txt for copyright information
3 """Integer Multiplication."""
4
5 from nmigen import Signal, Module, Value, Elaboratable, Cat, C, Mux, Repl
6 from nmigen.hdl.ast import Assign
7 from abc import ABCMeta, abstractmethod
8 from nmigen.cli import main
9 from functools import reduce
10 from operator import or_
11
12
class PartitionPoints(dict):
    """Mapping of ALU partition points to ``Value``s that enable them.

    Each key is a bit index at which the ALU may be split; the value says
    whether that split is currently active.

    Example: with ``width == 16``, ``{1: True, 5: True, 10: True}`` splits
    the ALU into 4 sections:
    * bits 0 <= ``i`` < 1
    * bits 1 <= ``i`` < 5
    * bits 5 <= ``i`` < 10
    * bits 10 <= ``i`` < 16

    With ``{1: True, 5: a, 10: True}`` instead, where ``a`` is a 1-bit
    ``Signal``:
    * when ``a`` is asserted:
        * bits 0 <= ``i`` < 1
        * bits 1 <= ``i`` < 5
        * bits 5 <= ``i`` < 10
        * bits 10 <= ``i`` < 16
    * otherwise:
        * bits 0 <= ``i`` < 1
        * bits 1 <= ``i`` < 10
        * bits 10 <= ``i`` < 16
    """

    def __init__(self, partition_points=None):
        """Build a ``PartitionPoints`` from an optional mapping.

        :param partition_points: mapping of non-negative ``int`` bit indices
            to anything ``Value.wrap`` accepts; ``None`` yields an empty map.
        :raises TypeError: if a key is not an ``int``.
        :raises ValueError: if a key is negative.
        """
        super().__init__()
        if partition_points is None:
            return
        for point, enabled in partition_points.items():
            if not isinstance(point, int):
                raise TypeError("point must be a non-negative integer")
            if point < 0:
                raise ValueError("point must be a non-negative integer")
            self[point] = Value.wrap(enabled)

    def like(self, name=None, src_loc_at=0, mul=1):
        """Return a new ``PartitionPoints`` whose values are fresh ``Signal``s.

        Every new signal has the shape of the corresponding value in
        ``self`` and is stored under the original point times ``mul``.

        :param name: base name for the new signals; when ``None`` it is
            taken from the variable the caller assigns the result to.
        :param src_loc_at: extra stack depth used for that name lookup.
        :param mul: factor applied to every partition point index.
        """
        if name is None:
            # throw-away Signal: only used to pick up the caller's var name
            name = Signal(src_loc_at=1 + src_loc_at).name
        result = PartitionPoints()
        for point, enabled in self.items():
            scaled = point * mul
            result[scaled] = Signal(enabled.shape(), name=f"{name}_{scaled}")
        return result

    def eq(self, rhs):
        """Yield ``Signal.eq`` statements assigning ``rhs`` into ``self``.

        This is a generator: the ``ValueError`` for a mismatched point set
        is raised when the result is first iterated.
        """
        if set(rhs.keys()) != set(self.keys()):
            raise ValueError("incompatible point set")
        yield from (enabled.eq(rhs[point]) for point, enabled in self.items())

    def as_mask(self, width):
        """Return a ``width``-bit mask built from the partition points.

        Bit ``i`` of the mask is the inverse of the enable value at point
        ``i``, and is constant 1 where there is no partition point.

        :param width: the bit width of the resulting mask
        """
        return Cat(*[~self[i] if i in self else True for i in range(width)])

    def get_max_partition_count(self, width):
        """Return the number of partitions when every point is enabled.

        Only points below ``width`` are counted.
        """
        return 1 + sum(1 for point in self if point < width)

    def fits_in_width(self, width):
        """Return ``True`` when every partition point is below ``width``."""
        return all(point < width for point in self)

    def part_byte(self, index, mfactor=1):  # mfactor used for "expanding"
        """Return the enable of the partition point at the end of byte ``index``.

        Indices -1 and 7 are the ends of the 64-bit word and always count as
        partitioned (constant 1).

        :param index: byte index in -1..7.
        :param mfactor: multiplier applied to the bit position, e.g. 2 when
            looking up double-width ("expanded") partition points.
        """
        if index in (-1, 7):
            return C(True, 1)
        assert 0 <= index < 8
        return self[(index + 1) * 8 * mfactor]
113
114
class FullAdder(Elaboratable):
    """Bit-parallel full adder (3:2 compressor).

    Each bit position ``k`` independently combines ``in0[k]``, ``in1[k]``
    and ``in2[k]``: ``sum[k]`` is their XOR and ``carry[k]`` their
    majority. Packing all ``width`` adders into one module keeps simulation
    much faster than instantiating an array of single-bit adders.

    :attribute in0: the first input
    :attribute in1: the second input
    :attribute in2: the third input
    :attribute sum: the sum output
    :attribute carry: the carry output (not shifted)
    """

    def __init__(self, width):
        """Create a ``FullAdder``.

        :param width: bit width shared by all inputs and outputs
        """
        self.in0 = Signal(width)
        self.in1 = Signal(width)
        self.in2 = Signal(width)
        self.sum = Signal(width)
        self.carry = Signal(width)

    def elaborate(self, platform):
        """Build the XOR (sum) and majority (carry) logic."""
        m = Module()
        x, y, z = self.in0, self.in1, self.in2
        m.d.comb += [
            self.sum.eq(x ^ y ^ z),
            self.carry.eq((x & y) | (y & z) | (z & x)),
        ]
        return m
149
150
class MaskedFullAdder(Elaboratable):
    """Masked Full Adder.

    :attribute mask: the carry partition mask
    :attribute in0: the first input
    :attribute in1: the second input
    :attribute in2: the third input
    :attribute sum: the sum output
    :attribute mcarry: the masked carry output

    FullAdders are always used with a "mask" on the output. To keep
    the graphviz "clean", this class performs the masking here rather
    than inside a large for-loop.

    This is deliberately not derived from FullAdder: each input is
    shifted up by one bit *before* being ANDed with the mask, so that an
    AOI cell may be used (which is more gate-efficient). See:
    https://en.wikipedia.org/wiki/AND-OR-Invert
    https://groups.google.com/d/msg/comp.arch/fcq-GLQqvas/vTxmcA0QAgAJ
    """

    def __init__(self, width):
        """Create a ``MaskedFullAdder``.

        :param width: the bit width of the input and output
        """
        self.width = width
        self.mask = Signal(width, reset_less=True)
        self.mcarry = Signal(width, reset_less=True)
        self.in0 = Signal(width, reset_less=True)
        self.in1 = Signal(width, reset_less=True)
        self.in2 = Signal(width, reset_less=True)
        self.sum = Signal(width, reset_less=True)

    def elaborate(self, platform):
        """Build the sum and the shifted, masked carry."""
        m = Module()
        comb = m.d.comb
        w = self.width
        # shifted copies of the inputs, then the three masked pair-ANDs
        s1 = Signal(w, reset_less=True)
        s2 = Signal(w, reset_less=True)
        s3 = Signal(w, reset_less=True)
        c1 = Signal(w, reset_less=True)
        c2 = Signal(w, reset_less=True)
        c3 = Signal(w, reset_less=True)
        x, y, z = self.in0, self.in1, self.in2

        comb += self.sum.eq(x ^ y ^ z)
        # move every input up one bit position (carry goes to bit k+1)
        comb += [s1.eq(Cat(0, x)), s2.eq(Cat(0, y)), s3.eq(Cat(0, z))]
        # majority of the shifted bits, with the mask applied per pair
        comb += [c1.eq(s1 & s2 & self.mask),
                 c2.eq(s2 & s3 & self.mask),
                 c3.eq(s3 & s1 & self.mask)]
        comb += self.mcarry.eq(c1 | c2 | c3)
        return m
204
205
class PartitionedAdder(Elaboratable):
    """Partitioned Adder.

    Performs the final add. The partition points are included in the
    actual add (in one of the operands only), which causes a carry over
    to the next bit. Then the final output *removes* the extra bits from
    the result.

    partition: .... P... P... P... P... (32 bits)
    a        : .... .... .... .... .... (32 bits)
    b        : .... .... .... .... .... (32 bits)
    exp-a    : ....P....P....P....P.... (32+4 bits, P=1 if no partition)
    exp-b    : ....0....0....0....0.... (32 bits plus 4 zeros)
    exp-o    : ....xN...xN...xN...xN... (32+4 bits - x to be discarded)
    o        : .... N... N... N... N... (32 bits - x ignored, N is carry-over)

    :attribute width: the bit width of the input and output. Read-only.
    :attribute a: the first input to the adder
    :attribute b: the second input to the adder
    :attribute output: the sum output
    :attribute partition_points: the input partition points. Modification not
        supported, except for by ``Signal.eq``.
    """

    def __init__(self, width, partition_points):
        """Create a ``PartitionedAdder``.

        :param width: the bit width of the input and output
        :param partition_points: the input partition points
        :raises ValueError: if a partition point is not below ``width``.
        """
        self.width = width
        self.a = Signal(width)
        self.b = Signal(width)
        self.output = Signal(width)
        self.partition_points = PartitionPoints(partition_points)
        if not self.partition_points.fits_in_width(width):
            raise ValueError("partition_points doesn't fit in width")
        # one extra bit is inserted just below every partition point
        n_points = sum(1 for bit in range(self.width)
                       if bit in self.partition_points)
        expanded_width = self.width + n_points
        self._expanded_width = expanded_width
        # XXX these have to remain here due to some horrible nmigen
        # simulation bugs involving sync. it is *not* necessary to
        # have them here, they should (under normal circumstances)
        # be moved into elaborate, as they are entirely local
        self._expanded_a = Signal(expanded_width)  # includes extra part-points
        self._expanded_b = Signal(expanded_width)  # likewise.
        self._expanded_o = Signal(expanded_width)  # likewise.

    def elaborate(self, platform):
        """Wire the expanded operands and perform a single wide add."""
        m = Module()
        pp = self.partition_points
        xa, xb, xo = self._expanded_a, self._expanded_b, self._expanded_o

        # destination bits (in the expanded signals) and their sources;
        # collected as lists and Cat'ed once, which keeps graphviz clean
        dst_a, dst_b, dst_o = [], [], []
        src_a, src_b, src_o = [], [], []

        # partition points are "breaks" (extra 0s or 1s) in what would
        # otherwise be one long add. the break bit itself is discarded
        # from the output; when it is 1 (partition disabled) it passes the
        # carry on to the next real bit, when it is 0 it stops the carry.
        pos = 0
        for bit in range(self.width):
            if bit in pp:
                # break bit: a gets ~enabled, b gets 0
                dst_a.append(xa[pos])
                src_a.append(~pp[bit])
                dst_b.append(xb[pos])
                src_b.append(C(0))
                pos += 1  # the break bit is NOT copied to the output
            dst_a.append(xa[pos])
            dst_b.append(xb[pos])
            dst_o.append(xo[pos])
            src_a.append(self.a[bit])
            src_b.append(self.b[bit])
            src_o.append(self.output[bit])
            pos += 1

        m.d.comb += [Cat(*dst_a).eq(Cat(*src_a)),
                     Cat(*dst_b).eq(Cat(*src_b)),
                     Cat(*src_o).eq(Cat(*dst_o))]

        # a single addition lets FPGAs use look-ahead carry / dedicated
        # adder hardware
        m.d.comb += self._expanded_o.eq(self._expanded_a + self._expanded_b)
        return m
299
300
# number of terms consumed by each full adder in one add-reduce level:
# every group of this many terms is reduced to two (sum and carry)
FULL_ADDER_INPUT_COUNT = 3
302
class AddReduceData:
    """Signals handed from one add-reduce stage to the next.

    :attribute part_ops: ``n_parts`` 2-bit operation selectors.
    :attribute inputs: ``n_inputs`` terms, each ``output_width`` bits wide.
    :attribute reg_partition_points: a ``Signal`` copy of ``ppoints``
        (created with ``PartitionPoints.like``).
    """

    def __init__(self, ppoints, n_inputs, output_width, n_parts):
        self.part_ops = [Signal(2, name=f"part_ops_{i}")
                         for i in range(n_parts)]
        self.inputs = [Signal(output_width, name=f"inputs[{i}]")
                       for i in range(n_inputs)]
        self.reg_partition_points = ppoints.like()

    def eq(self, rhs):
        """Return a list of statements copying ``rhs`` into ``self``.

        The first item is the partition-point assignment (itself an
        iterable of statements), followed by the inputs and then part_ops.
        """
        stmts = [self.reg_partition_points.eq(rhs.reg_partition_points)]
        stmts += [dst.eq(rhs.inputs[k]) for k, dst in enumerate(self.inputs)]
        stmts += [dst.eq(rhs.part_ops[k])
                  for k, dst in enumerate(self.part_ops)]
        return stmts
318
319
class FinalAdd(Elaboratable):
    """Final stage of add reduce: sums the remaining 0, 1 or 2 terms.

    Zero terms produce 0, one term is passed straight through, and two
    terms are summed by a ``PartitionedAdder`` using the registered
    partition points.
    """

    def __init__(self, n_inputs, output_width, n_parts, register_levels,
                 partition_points):
        """Create a ``FinalAdd``.

        :param n_inputs: number of terms left to add (0, 1 or 2).
        :param output_width: bit-width of ``output``.
        :param n_parts: number of ``part_ops`` signals carried along.
        :param register_levels: List of nesting levels that should have
            pipeline registers.
        :param partition_points: the input partition points.
        :raises ValueError: if a partition point is not below
            ``output_width``.
        """
        self.i = AddReduceData(partition_points, n_inputs,
                               output_width, n_parts)
        self.n_inputs = n_inputs
        self.n_parts = n_parts
        self.out_part_ops = self.i.part_ops
        self._resized_inputs = self.i.inputs
        self.register_levels = list(register_levels)
        self.output = Signal(output_width)
        self.partition_points = PartitionPoints(partition_points)
        if not self.partition_points.fits_in_width(output_width):
            raise ValueError("partition_points doesn't fit in output_width")
        self._reg_partition_points = self.i.reg_partition_points
        # the last stage feeds no further terms
        self.intermediate_terms = []

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        count = self.n_inputs
        terms = self._resized_inputs

        if count == 0:
            # nothing to add: the default output value is 0
            m.d.comb += self.output.eq(0)
            return m
        if count == 1:
            m.d.comb += self.output.eq(terms[0])
            return m

        # base case for adding 2 inputs
        assert count == 2
        adder = PartitionedAdder(len(self.output),
                                 self._reg_partition_points)
        m.submodules.final_adder = adder
        m.d.comb += [adder.a.eq(terms[0]),
                     adder.b.eq(terms[1]),
                     self.output.eq(adder.output)]
        return m
360
361
class AddReduceSingle(Elaboratable):
    """One level of the add-reduce tree.

    Each group of ``FULL_ADDER_INPUT_COUNT`` input terms goes through a
    ``MaskedFullAdder``, which produces two terms (sum and masked carry) for
    the next level. Any left-over inputs are passed through unchanged.

    :attribute i: ``AddReduceData`` with this level's inputs, part_ops and
        registered partition points.
    :attribute register_levels: List of nesting levels that should have
        pipeline registers.
    :attribute intermediate_terms: terms produced for the next level (only
        created when there is at least one full-adder group).
    :attribute partition_points: the input partition points. Modification not
        supported, except for by ``Signal.eq``.
    """

    def __init__(self, n_inputs, output_width, n_parts, register_levels,
                 partition_points):
        """Create an ``AddReduceSingle``.

        :param n_inputs: number of input terms for this level.
        :param output_width: bit-width of every term.
        :param n_parts: number of ``part_ops`` signals carried along.
        :param register_levels: List of nesting levels that should have
            pipeline registers.
        :param partition_points: the input partition points.
        :raises ValueError: if a partition point is not below
            ``output_width``, or a register level exceeds the maximum level.
        """
        self.n_inputs = n_inputs
        self.n_parts = n_parts
        self.output_width = output_width
        data = AddReduceData(partition_points, n_inputs,
                             output_width, n_parts)
        self.i = data
        self.out_part_ops = data.part_ops
        self._resized_inputs = data.inputs
        self.register_levels = list(register_levels)
        self.partition_points = PartitionPoints(partition_points)
        if not self.partition_points.fits_in_width(output_width):
            raise ValueError("partition_points doesn't fit in output_width")
        self._reg_partition_points = data.reg_partition_points

        max_level = AddReduceSingle.get_max_level(n_inputs)
        if any(level > max_level for level in self.register_levels):
            raise ValueError(
                "not enough adder levels for specified register levels")

        # the adders and terms must exist before elaborate() because
        # AddReduce wires levels together using them; only the m.d.comb
        # statements are deferred to elaborate()
        self.groups = AddReduceSingle.full_adder_groups(n_inputs)
        self._intermediate_terms = []
        if self.groups:
            self.create_next_terms()

    @staticmethod
    def get_max_level(input_count):
        """Get the maximum level.

        This is the number of full-adder levels needed to reduce
        ``input_count`` terms below ``FULL_ADDER_INPUT_COUNT``. All
        ``register_levels`` must be less than or equal to it.
        """
        level = 0
        groups = AddReduceSingle.full_adder_groups(input_count)
        while groups:
            input_count = (input_count % FULL_ADDER_INPUT_COUNT
                           + 2 * len(groups))
            level += 1
            groups = AddReduceSingle.full_adder_groups(input_count)
        return level

    @staticmethod
    def full_adder_groups(input_count):
        """Get ``inputs`` indices for which a full adder should be built."""
        step = FULL_ADDER_INPUT_COUNT
        return range(0, input_count - step + 1, step)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        comb = m.d.comb

        # drive every intermediate term from its source value
        comb += [term.eq(value) for value, term in self._intermediate_terms]

        # the partition mask is computed once and shared by all adders
        comb += self.part_mask.eq(
            self._reg_partition_points.as_mask(self.output_width))

        inputs = self._resized_inputs
        for idx, (first, adder) in enumerate(self.adders):
            setattr(m.submodules, f"adder_{idx}", adder)
            comb += [adder.in0.eq(inputs[first]),
                     adder.in1.eq(inputs[first + 1]),
                     adder.in2.eq(inputs[first + 2]),
                     adder.mask.eq(self.part_mask)]
        return m

    def create_next_terms(self):
        """Create the full adders and the terms handed to the next level.

        Sets ``part_mask``, ``adders`` (list of ``(first input index,
        adder)``), ``intermediate_terms`` and ``_intermediate_terms`` (list
        of ``(source value, term Signal)`` pairs wired up in elaborate).
        """
        terms = []
        pairs = []

        def add_intermediate_term(value):
            term = Signal(self.output_width,
                          name=f"intermediate_terms[{len(terms)}]")
            pairs.append((value, term))
            terms.append(term)

        # store mask in intermediary (simplifies graph)
        self.part_mask = Signal(self.output_width, reset_less=True)

        # one full adder per group: 3 terms become 2 (sum + masked carry),
        # so N terms shrink to 2 * (N // 3) plus the remainder
        self.adders = []
        for first in self.groups:
            adder = MaskedFullAdder(self.output_width)
            self.adders.append((first, adder))
            add_intermediate_term(adder.sum)
            add_intermediate_term(adder.mcarry)

        # left-over inputs are passed straight on: a half adder would still
        # leave 2 terms for 2 inputs, so passing them through saves gates
        remainder = self.n_inputs % FULL_ADDER_INPUT_COUNT
        assert remainder in (0, 1, 2)
        if remainder:
            for value in self._resized_inputs[-remainder:]:
                add_intermediate_term(value)

        self.intermediate_terms = terms
        self._intermediate_terms = pairs
496
497
class AddReduce(Elaboratable):
    """Recursively Add list of numbers together.

    The terms are reduced by a chain of ``AddReduceSingle`` levels (each
    turning groups of three terms into two) until fewer than three terms
    remain. A ``FinalAdd`` then sums the remaining 0, 1 or 2 terms.

    :attribute inputs: input ``Signal``s to be summed. Modification not
        supported, except for by ``Signal.eq``.
    :attribute register_levels: List of nesting levels that should have
        pipeline registers.
    :attribute output: output sum.
    :attribute partition_points: the input partition points. Modification not
        supported, except for by ``Signal.eq``.
    :attribute part_ops: signals carried through the same pipeline as the
        data.
    :attribute out_part_ops: ``part_ops`` as they leave the last level.
    :attribute levels: the stage modules; the last one is a ``FinalAdd``.
    """

    def __init__(self, inputs, output_width, register_levels, partition_points,
                 part_ops):
        """Create an ``AddReduce``.

        :param inputs: input ``Signal``s to be summed.
        :param output_width: bit-width of ``output``.
        :param register_levels: List of nesting levels that should have
            pipeline registers.
        :param partition_points: the input partition points.
        :param part_ops: 2-bit signals carried alongside the data.
        :raises ValueError: if a register level exceeds the number of
            adder levels available for ``len(inputs)`` inputs.
        """
        self.inputs = inputs
        self.part_ops = part_ops
        self.out_part_ops = [Signal(2, name=f"out_part_ops_{i}")
                             for i in range(len(part_ops))]
        self.output = Signal(output_width)
        self.output_width = output_width
        self.register_levels = register_levels
        self.partition_points = partition_points

        self.create_levels()

    @staticmethod
    def get_max_level(input_count):
        """Return the number of full-adder levels for ``input_count`` terms."""
        return AddReduceSingle.get_max_level(input_count)

    @staticmethod
    def next_register_levels(register_levels):
        """``Iterable`` of ``register_levels`` for next recursive level."""
        for level in register_levels:
            if level > 0:
                yield level - 1

    def create_levels(self):
        """Create the reduction levels and store them in ``self.levels``.

        ``AddReduceSingle`` levels are only created while they have at least
        one full-adder group, i.e. at least ``FULL_ADDER_INPUT_COUNT`` terms
        remain. A ``FinalAdd`` always terminates the chain and handles the
        remaining 0, 1 or 2 terms, so short input lists work too.

        :raises ValueError: if a register level exceeds the number of
            full-adder levels available for the inputs.
        """
        # copy once: register_levels may be any iterable
        requested_levels = list(self.register_levels)
        next_levels = requested_levels
        partition_points = self.partition_points
        inputs = self.inputs
        n_parts = len(self.part_ops)
        mods = []

        # an AddReduceSingle without any groups never creates its
        # intermediate_terms / part_mask, so never build one
        while len(AddReduceSingle.full_adder_groups(len(inputs))) != 0:
            next_level = AddReduceSingle(len(inputs), self.output_width,
                                         n_parts, next_levels,
                                         partition_points)
            mods.append(next_level)
            next_levels = list(AddReduce.next_register_levels(next_levels))
            partition_points = next_level._reg_partition_points
            inputs = next_level.intermediate_terms

        # AddReduceSingle validates register levels itself, but that check
        # is skipped entirely when no full-adder level was needed
        max_level = AddReduce.get_max_level(len(self.inputs))
        for level in requested_levels:
            if level > max_level:
                raise ValueError(
                    "not enough adder levels for specified register levels")

        mods.append(FinalAdd(len(inputs), self.output_width, n_parts,
                             next_levels, partition_points))
        self.levels = mods

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()

        for i, level in enumerate(self.levels):
            setattr(m.submodules, "next_level%d" % i, level)

        # feed each level from the previous one (the first from our own
        # inputs). register level 0 of a stage latches its inputs (sync).
        partition_points = self.partition_points
        inputs = self.inputs
        part_ops = self.part_ops
        for mcur in self.levels:
            inassign = [mcur._resized_inputs[j].eq(inputs[j])
                        for j in range(len(inputs))]
            copy_part_ops = [mcur.out_part_ops[j].eq(part_ops[j])
                             for j in range(len(part_ops))]
            if 0 in mcur.register_levels:
                m.d.sync += copy_part_ops
                m.d.sync += inassign
                m.d.sync += mcur._reg_partition_points.eq(partition_points)
            else:
                m.d.comb += copy_part_ops
                m.d.comb += inassign
                m.d.comb += mcur._reg_partition_points.eq(partition_points)
            partition_points = mcur._reg_partition_points
            inputs = mcur.intermediate_terms
            part_ops = mcur.out_part_ops

        # output comes from the last module (always the FinalAdd)
        last = self.levels[-1]
        m.d.comb += self.output.eq(last.output)
        m.d.comb += [self.out_part_ops[j].eq(last.out_part_ops[j])
                     for j in range(len(self.part_ops))]
        return m
607
608
# Multiply operation selectors, carried per byte in 2-bit ``part_ops``
# signals. OP_MUL_LOW selects the low half of each lane's double-width
# product; the *_HIGH variants select the high half, with the operand
# signedness given by their names.
OP_MUL_LOW = 0
OP_MUL_SIGNED_HIGH = 1
OP_MUL_SIGNED_UNSIGNED_HIGH = 2 # a is signed, b is unsigned
OP_MUL_UNSIGNED_HIGH = 3
613
614
def get_term(value, shift=0, enabled=None):
    """Optionally gate ``value`` and shift it left by ``shift`` zero bits.

    :param value: the term to condition.
    :param shift: number of zero bits placed below ``value``; must not be
        negative.
    :param enabled: when given, ``value`` is replaced by 0 unless it is set.
    :returns: the (possibly gated and shifted) value.
    """
    if enabled is not None:
        value = Mux(enabled, value, 0)
    assert shift >= 0
    if shift == 0:
        return value
    return Cat(Repl(C(0, 1), shift), value)
623
624
class ProductTerm(Elaboratable):
    """Create a single product term (``a[..] * b[..]``).

    Slice ``a_index`` of ``a`` (``pwidth`` bits) is multiplied by slice
    ``b_index`` of ``b``, and the product is shifted left by
    ``8 * (a_index + b_index)`` bits into ``term``. When any ``pb_en`` bit
    between the two indices is set, the term is forced to zero (presumably
    because a partition boundary separates the two slices).

    This has a design flaw in that it is the *output* that is selected,
    while the multiplication(s) are combinatorially generated all the time.
    """

    def __init__(self, width, twidth, pbwid, a_index, b_index):
        """Create a ``ProductTerm``.

        :param width: bit width of each operand slice (``pwidth``).
        :param twidth: bit width of ``term``; ``a`` and ``b`` are
            ``twidth // 2`` bits wide.
        :param pbwid: bit width of the ``pb_en`` input.
        :param a_index: index of the ``width``-bit slice taken from ``a``.
        :param b_index: index of the ``width``-bit slice taken from ``b``.
        """
        self.a_index = a_index
        self.b_index = b_index
        # NOTE: uses a fixed 8 rather than ``width``, so the shift matches
        # the slice positions (index * pwidth) only when width == 8
        shift = 8 * (self.a_index + self.b_index)
        self.pwidth = width
        self.twidth = twidth
        self.width = width*2
        self.shift = shift

        self.ti = Signal(self.width, reset_less=True)
        self.term = Signal(twidth, reset_less=True)
        self.a = Signal(twidth//2, reset_less=True)
        self.b = Signal(twidth//2, reset_less=True)
        self.pb_en = Signal(pbwid, reset_less=True)

        # partition bits lying between the two slice indices; the term is
        # only enabled when none of them is set
        self.tl = tl = []
        min_index = min(self.a_index, self.b_index)
        max_index = max(self.a_index, self.b_index)
        for i in range(min_index, max_index):
            tl.append(self.pb_en[i])
        name = "te_%d_%d" % (self.a_index, self.b_index)
        if len(tl) > 0:
            term_enabled = Signal(name=name, reset_less=True)
        else:
            # same slice index: never gated
            term_enabled = None
        self.enabled = term_enabled
        self.term.name = "term_%d_%d" % (a_index, b_index) # rename

    def elaborate(self, platform):
        """Multiply the two slices, then gate and shift the product."""
        m = Module()
        if self.enabled is not None:
            m.d.comb += self.enabled.eq(~(Cat(*self.tl).bool()))

        bsa = Signal(self.width, reset_less=True)
        bsb = Signal(self.width, reset_less=True)
        a_index, b_index = self.a_index, self.b_index
        pwidth = self.pwidth
        m.d.comb += bsa.eq(self.a.part(a_index * pwidth, pwidth))
        m.d.comb += bsb.eq(self.b.part(b_index * pwidth, pwidth))
        m.d.comb += self.ti.eq(bsa * bsb)
        m.d.comb += self.term.eq(get_term(self.ti, self.shift, self.enabled))

        # TODO: sort out width issues so that the a/b inputs are switched
        # on/off (gated with get_term before the multiply) instead of the
        # product. An earlier attempt failed because the data going into
        # the Muxes was 1/2 the required width.
        return m
694
695
class ProductTerms(Elaboratable):
    """A bank of ``blen`` product terms sharing one ``a`` slice index.

    Intended to be instantiated once per ``a`` operand index; internally it
    iterates over every ``b`` index, creating one ``ProductTerm`` each and
    exposing their results in ``terms``.
    """
    def __init__(self, width, twidth, pbwid, a_index, blen):
        self.a_index = a_index
        self.blen = blen
        self.pwidth = width
        self.twidth = twidth
        self.pbwid = pbwid
        self.a = Signal(twidth//2, reset_less=True)
        self.b = Signal(twidth//2, reset_less=True)
        self.pb_en = Signal(pbwid, reset_less=True)
        self.terms = [Signal(twidth, name="term%d" % i, reset_less=True)
                      for i in range(blen)]

    def elaborate(self, platform):
        """Instantiate and wire one ``ProductTerm`` per ``b`` index."""
        m = Module()
        comb = m.d.comb

        for b_index, out in enumerate(self.terms):
            t = ProductTerm(self.pwidth, self.twidth, self.pbwid,
                            self.a_index, b_index)
            setattr(m.submodules, "term_%d" % b_index, t)
            comb += [t.a.eq(self.a),
                     t.b.eq(self.b),
                     t.pb_en.eq(self.pb_en),
                     out.eq(t.term)]
        return m
729
730
class LSBNegTerm(Elaboratable):
    """Sign-correction terms for one operand within one lane.

    When ``part & msb & signed`` holds, ``nt`` is ``~op`` placed in the HI
    half (``~op << bit_width``) and ``nl`` is ``1 << bit_width``. Together
    they add ``-op << bit_width``. Otherwise both contribute 0.
    """

    def __init__(self, bit_width):
        self.bit_width = bit_width
        self.part = Signal(reset_less=True)
        self.signed = Signal(reset_less=True)
        self.op = Signal(bit_width, reset_less=True)
        self.msb = Signal(reset_less=True)
        self.nt = Signal(bit_width*2, reset_less=True)
        self.nl = Signal(bit_width*2, reset_less=True)

    def elaborate(self, platform):
        """Build the inverted term ``nt`` and the +1 term ``nl``."""
        m = Module()
        comb = m.d.comb
        bit_wid = self.bit_width
        zeros_lo = Repl(0, bit_wid)  # fills the LO half: result sits in HI

        # correction applies only to this partition, signed op, MSB set
        enabled = Signal(reset_less=True)
        comb += enabled.eq(self.part & self.msb & self.signed)

        # for 8-bit values: form a * 0xFF00 by using -a * 0x100, the
        # negation being split into a bitwise not and a +1.
        # likewise for 16, 32, and 64-bit values.
        inverted_hi = Cat(zeros_lo, ~self.op)
        plus_one_hi = Cat(zeros_lo, enabled, Repl(0, bit_wid - 1))
        comb += self.nt.eq(Mux(enabled, inverted_hi, 0))
        comb += self.nl.eq(plus_one_hi)
        return m
763
764
class Parts(Elaboratable):
    """Work out which lanes of a given size are active.

    ``parts[i]`` is asserted when lane ``i`` (``8 // n_parts`` bytes wide)
    has partition bits set at both of its ends and clear everywhere inside.
    """

    def __init__(self, pbwid, epps, n_parts):
        self.pbwid = pbwid
        # inputs
        self.epps = PartitionPoints.like(epps, name="epps")  # expanded points
        # outputs
        self.parts = [Signal(name=f"part_{i}") for i in range(n_parts)]

    def elaborate(self, platform):
        """Collect the per-byte partition bits and decode every lane."""
        m = Module()
        comb = m.d.comb
        epps, parts = self.epps, self.parts

        # one bit per byte boundary; mfactor=2 because epps is expanded
        # (double width)
        pbs = Signal(self.pbwid, reset_less=True)
        pb_bits = []
        for i in range(self.pbwid):
            pb = Signal(name="pb%d" % i, reset_less=True)
            comb += pb.eq(epps.part_byte(i, mfactor=2))  # double
            pb_bits.append(pb)
        comb += pbs.eq(Cat(*pb_bits))

        # negated-temporary copy of partition bits
        npbs = Signal.like(pbs, reset_less=True)
        comb += npbs.eq(~pbs)

        byte_count = 8 // len(parts)
        for i, part in enumerate(parts):
            lo = i * byte_count
            hi = (i + 1) * byte_count
            # boundary bits (inverted: must be set) around interior bits
            # (must be clear). for lo == 0, index -1 wraps to the top bit.
            pbl = ([npbs[lo - 1]]
                   + [pbs[j] for j in range(lo, hi - 1)]
                   + [npbs[hi - 1]])
            value = Signal(len(pbl), name="value_%d" % i, reset_less=True)
            comb += value.eq(Cat(*pbl))
            comb += part.eq(~value.bool())
        return m
802
803
class Part(Elaboratable):
    """ a key class which, depending on the partitioning, will determine
    what action to take when parts of the output are signed or unsigned.

    this requires 2 pieces of data *per operand, per partition*:
    whether the MSB is HI/LO (per partition!), and whether a signed
    or unsigned operation has been *requested*.

    once that is determined, signed is basically carried out
    by splitting 2's complement into 1's complement plus one.
    1's complement is just a bit-inversion.

    the extra terms - as separate terms - are then thrown at the
    AddReduce alongside the multiplication part-results.
    """
    def __init__(self, epps, width, n_parts, n_levels, pbwid):
        """Create a ``Part``.

        :param epps: expanded partition points, forwarded to a ``Parts``
            submodule.
        :param width: bit width of the four correction-term outputs.
        :param n_parts: number of lanes; each lane is ``8 // n_parts``
            bytes wide.
        :param n_levels: not used by this class.
        :param pbwid: number of partition bits, forwarded to ``Parts``.
        """

        self.pbwid = pbwid
        self.epps = epps

        # inputs
        self.a = Signal(64)
        self.b = Signal(64)
        # one signedness flag per byte; only the first byte of each lane
        # is read in elaborate()
        self.a_signed = [Signal(name=f"a_signed_{i}") for i in range(8)]
        self.b_signed = [Signal(name=f"_b_signed_{i}") for i in range(8)]
        self.pbs = Signal(pbwid, reset_less=True)

        # outputs
        self.parts = [Signal(name=f"part_{i}") for i in range(n_parts)]

        self.not_a_term = Signal(width)
        self.neg_lsb_a_term = Signal(width)
        self.not_b_term = Signal(width)
        self.neg_lsb_b_term = Signal(width)

    def elaborate(self, platform):
        """Build the inverted and +1 correction terms for every lane."""
        m = Module()

        # NOTE(review): the local ``pbs`` is not used below.
        pbs, parts = self.pbs, self.parts
        epps = self.epps
        m.submodules.p = p = Parts(self.pbwid, epps, len(parts))
        m.d.comb += p.epps.eq(epps)
        # from here on ``parts`` are the lane-active outputs of the Parts
        # submodule. NOTE(review): ``self.parts`` (declared as outputs in
        # __init__) is never driven in this method -- confirm no caller
        # reads it.
        parts = p.parts

        # NOTE(review): unused; duplicates ``byte_width`` below.
        byte_count = 8 // len(parts)

        not_a_term, neg_lsb_a_term, not_b_term, neg_lsb_b_term = (
                self.not_a_term, self.neg_lsb_a_term,
                self.not_b_term, self.neg_lsb_b_term)

        byte_width = 8 // len(parts) # byte width
        bit_wid = 8 * byte_width # bit width
        nat, nbt, nla, nlb = [], [], [], []
        for i in range(len(parts)):
            # work out bit-inverted and +1 term for a.
            # the correction involving a depends on *b*'s sign: for a
            # negative signed b, a*b = a*unsigned(b) - (a << bit_wid) within
            # the lane, and LSBNegTerm supplies -(a << bit_wid) as
            # (~a << bit_wid) + (1 << bit_wid).
            pa = LSBNegTerm(bit_wid)
            setattr(m.submodules, "lnt_%d_a_%d" % (bit_wid, i), pa)
            m.d.comb += pa.part.eq(parts[i])
            m.d.comb += pa.op.eq(self.a.part(bit_wid * i, bit_wid))
            m.d.comb += pa.signed.eq(self.b_signed[i * byte_width]) # yes b
            m.d.comb += pa.msb.eq(self.b[(i + 1) * bit_wid - 1]) # really, b
            nat.append(pa.nt)
            nla.append(pa.nl)

            # work out bit-inverted and +1 term for b
            # (symmetrically, this depends on *a*'s sign)
            pb = LSBNegTerm(bit_wid)
            setattr(m.submodules, "lnt_%d_b_%d" % (bit_wid, i), pb)
            m.d.comb += pb.part.eq(parts[i])
            m.d.comb += pb.op.eq(self.b.part(bit_wid * i, bit_wid))
            m.d.comb += pb.signed.eq(self.a_signed[i * byte_width]) # yes a
            m.d.comb += pb.msb.eq(self.a[(i + 1) * bit_wid - 1]) # really, a
            nbt.append(pb.nt)
            nlb.append(pb.nl)

        # concatenate together and return all 4 results.
        m.d.comb += [not_a_term.eq(Cat(*nat)),
                     not_b_term.eq(Cat(*nbt)),
                     neg_lsb_a_term.eq(Cat(*nla)),
                     neg_lsb_b_term.eq(Cat(*nlb)),
                    ]

        return m
886
887
class IntermediateOut(Elaboratable):
    """ selects the HI/LO part of the multiplication, for a given bit-width
        the output is also reconstructed in its SIMD (partition) lanes.

    Each ``width``-bit output lane takes either the low or the high half of
    its ``2 * width``-bit product, depending on the lane's ``part_ops``.
    """
    def __init__(self, width, out_wid, n_parts):
        self.width = width
        self.n_parts = n_parts
        self.part_ops = [Signal(2, name="dpop%d" % i, reset_less=True)
                         for i in range(8)]
        self.intermed = Signal(out_wid, reset_less=True)
        self.output = Signal(out_wid//2, reset_less=True)

    def elaborate(self, platform):
        """Pick the LO or HI half of every lane and concatenate them."""
        m = Module()
        w = self.width
        # part_ops holds one entry per byte; lane i starts at byte
        # (w // 8) * i, and that entry decides the lane's operation
        bytes_per_lane = w // 8
        lanes = []
        for i in range(self.n_parts):
            lo_half = self.intermed.part(i * w*2, w)
            hi_half = self.intermed.part(i * w*2 + w, w)
            op = Signal(w, reset_less=True, name="op%d_%d" % (w, i))
            is_low = self.part_ops[bytes_per_lane * i] == OP_MUL_LOW
            m.d.comb += op.eq(Mux(is_low, lo_half, hi_half))
            lanes.append(op)
        m.d.comb += self.output.eq(Cat(*lanes))
        return m
916
917
class FinalOut(Elaboratable):
    """ selects the final output based on the partitioning.

        each byte is selectable independently, i.e. it is possible
        that some partitions requested 8-bit computation whilst others
        requested 16 or 32 bit.
    """
    def __init__(self, out_wid):
        # inputs
        self.d8 = [Signal(name=f"d8_{i}", reset_less=True) for i in range(8)]
        self.d16 = [Signal(name=f"d16_{i}", reset_less=True) for i in range(4)]
        self.d32 = [Signal(name=f"d32_{i}", reset_less=True) for i in range(2)]

        self.i8 = Signal(out_wid, reset_less=True)
        self.i16 = Signal(out_wid, reset_less=True)
        self.i32 = Signal(out_wid, reset_less=True)
        self.i64 = Signal(out_wid, reset_less=True)

        # output
        self.out = Signal(out_wid, reset_less=True)

    def elaborate(self, platform):
        """Select every output byte from the i8/i16/i32/i64 candidates."""
        m = Module()
        out_bytes = []
        for i in range(8):
            sel8, sel16, sel32 = self.d8[i], self.d16[i // 2], self.d32[i // 4]
            b8, b16, b32, b64 = (src.part(i * 8, 8) for src in
                                 (self.i8, self.i16, self.i32, self.i64))
            # d8 picks i8, d16 picks i16; failing both, d32 picks i32 and
            # the default is i64
            narrow = Mux(sel8, b8, b16)
            wide = Mux(sel32, b32, b64)
            op = Signal(8, reset_less=True, name="op_%d" % i)
            m.d.comb += op.eq(Mux(sel8 | sel16, narrow, wide))
            out_bytes.append(op)
        m.d.comb += self.out.eq(Cat(*out_bytes))
        return m
958
959
class OrMod(Elaboratable):
    """Bitwise-OR four ``wid``-bit inputs using a two-level tree.

    ``orin[0] | orin[1]`` and ``orin[2] | orin[3]`` are formed first, and
    their results are ORed together into ``orout``.
    """

    def __init__(self, wid):
        self.wid = wid
        self.orin = [Signal(wid, name="orin%d" % i, reset_less=True)
                     for i in range(4)]
        self.orout = Signal(wid, reset_less=True)

    def elaborate(self, platform):
        m = Module()
        # explicit names keep the intermediate nets called or1/or2
        lower_pair = Signal(self.wid, name="or1", reset_less=True)
        upper_pair = Signal(self.wid, name="or2", reset_less=True)
        in0, in1, in2, in3 = self.orin
        m.d.comb += [
            lower_pair.eq(in0 | in1),
            upper_pair.eq(in2 | in3),
            self.orout.eq(lower_pair | upper_pair),
        ]
        return m
978
979
class Signs(Elaboratable):
    """Decode a 2-bit OP_MUL_* code into per-operand signedness flags.

    ``a_signed`` is asserted for every code except ``OP_MUL_UNSIGNED_HIGH``.
    ``b_signed`` is asserted only for ``OP_MUL_LOW`` and
    ``OP_MUL_SIGNED_HIGH``.
    """

    def __init__(self):
        self.part_ops = Signal(2, reset_less=True)
        self.a_signed = Signal(reset_less=True)
        self.b_signed = Signal(reset_less=True)

    def elaborate(self, platform):
        m = Module()

        op = self.part_ops
        a_is_signed = op != OP_MUL_UNSIGNED_HIGH
        b_is_signed = (op == OP_MUL_LOW) | (op == OP_MUL_SIGNED_HIGH)
        m.d.comb += [self.a_signed.eq(a_is_signed),
                     self.b_signed.eq(b_is_signed)]

        return m
1001
1002
class Mul8_16_32_64(Elaboratable):
    """Signed/Unsigned 8/16/32/64-bit partitioned integer multiplier.

    Partitions may be any combination of 8, 16, 32 and 64 bits on
    naturally-aligned boundaries, and each partition may select its own
    operation.

    :attribute part_pts: the input partition points, one at each multiple
        of 8 with 0 < i < 64. Each point's associated ``Value`` is a
        ``Signal``. Modification is not supported, except via ``Signal.eq``.
    :attribute part_ops: the operation for each byte. A partition's
        operation is selected by assigning the same operation code to every
        byte in that partition. The allowed operation codes are:

        :attribute OP_MUL_LOW: the LSB half of the product. Equivalent to
            RISC-V's `mul` instruction.
        :attribute OP_MUL_SIGNED_HIGH: the MSB half of the product where both
            ``a`` and ``b`` are signed. Equivalent to RISC-V's `mulh`
            instruction.
        :attribute OP_MUL_SIGNED_UNSIGNED_HIGH: the MSB half of the product
            where ``a`` is signed and ``b`` is unsigned. Equivalent to
            RISC-V's `mulhsu` instruction.
        :attribute OP_MUL_UNSIGNED_HIGH: the MSB half of the product where
            both ``a`` and ``b`` are unsigned. Equivalent to RISC-V's `mulhu`
            instruction.
    """

    def __init__(self, register_levels=()):
        """Create the multiplier's ports.

        :param register_levels: the points in the adder cascade at which
            flip-flops are to be inserted (any iterable; copied to a list).
        """
        self.register_levels = list(register_levels)

        # inputs: one enable Signal at each byte boundary 8, 16, ..., 56
        self.part_pts = PartitionPoints()
        for boundary in range(8, 64, 8):
            self.part_pts[boundary] = Signal(name=f"part_pts_{boundary}")
        self.part_ops = [Signal(2, name=f"part_ops_{i}") for i in range(8)]
        self.a = Signal(64)
        self.b = Signal(64)

        # intermediate 128-bit product (exposed for unit tests)
        self._intermediate_output = Signal(128)

        # output
        self.output = Signal(64)

    def elaborate(self, platform):
        m = Module()
        widths = (8, 16, 32, 64)

        # Per-byte partition flags, packed into one 8-bit vector.
        pbs = Signal(8, name="pbs", reset_less=True)
        byte_flags = [Signal(name="pb%d" % i, reset_less=True)
                      for i in range(8)]
        for idx, flag in enumerate(byte_flags):
            m.d.comb += flag.eq(self.part_pts.part_byte(idx))
        m.d.comb += pbs.eq(Cat(*byte_flags))

        # The product is twice the input width, so input boundary i maps
        # to boundary 2*i in the doubled PartitionPoints.
        eps = PartitionPoints()
        for point, enabled in self.part_pts.items():
            doubled = Signal(name=f"expanded_part_pts_{point*2}",
                             reset_less=True)
            eps[point * 2] = doubled
            m.d.comb += doubled.eq(enabled)

        # One signedness decoder per byte lane.
        signs = [Signs() for _ in range(8)]
        for idx, (decoder, op) in enumerate(zip(signs, self.part_ops)):
            setattr(m.submodules, "signs%d" % idx, decoder)
            m.d.comb += decoder.part_ops.eq(op)

        # One Part per partition width (64 // width lanes each).
        n_levels = len(self.register_levels) + 1
        part_mods = {}
        for width in widths:
            mod = Part(eps, 128, 64 // width, n_levels, 8)
            setattr(m.submodules, "part_%d" % width, mod)
            part_mods[width] = mod

        collected = {"nat": [], "nbt": [], "nla": [], "nlb": []}
        for width in widths:
            mod = part_mods[width]
            m.d.comb += mod.a.eq(self.a)
            m.d.comb += mod.b.eq(self.b)
            for lane, decoder in enumerate(signs):
                m.d.comb += mod.a_signed[lane].eq(decoder.a_signed)
                m.d.comb += mod.b_signed[lane].eq(decoder.b_signed)
            m.d.comb += mod.pbs.eq(pbs)
            collected["nat"].append(mod.not_a_term)
            collected["nbt"].append(mod.not_b_term)
            collected["nla"].append(mod.neg_lsb_a_term)
            collected["nlb"].append(mod.neg_lsb_b_term)

        # Partial-product terms, one ProductTerms per byte of a.
        terms = []
        for a_index in range(8):
            prod = ProductTerms(8, 128, 8, a_index, 8)
            setattr(m.submodules, "terms_%d" % a_index, prod)
            m.d.comb += [prod.a.eq(self.a),
                         prod.b.eq(self.b),
                         prod.pb_en.eq(pbs)]
            terms.extend(prod.terms)

        # It's fine to bitwise-OR the per-width correction terms together
        # since they are never enabled at the same time.
        for label in ("nat", "nbt", "nla", "nlb"):
            tree = OrMod(128)
            setattr(m.submodules, label + "_or", tree)
            for slot, term in zip(tree.orin, collected[label]):
                m.d.comb += slot.eq(term)
            terms.append(tree.orout)

        add_reduce = AddReduce(terms,
                               128,
                               self.register_levels,
                               eps,
                               self.part_ops)

        last_level = add_reduce.levels[-1]
        out_part_ops = last_level.out_part_ops
        out_part_pts = last_level._reg_partition_points

        m.submodules.add_reduce = add_reduce
        m.d.comb += self._intermediate_output.eq(add_reduce.output)

        # HI/LO half selectors, created widest first (io64 ... io8).
        io_mods = {}
        for width in reversed(widths):
            io = IntermediateOut(width, 128, 64 // width)
            setattr(m.submodules, "io%d" % width, io)
            m.d.comb += io.intermed.eq(self._intermediate_output)
            for lane, dst in enumerate(io.part_ops):
                m.d.comb += dst.eq(out_part_ops[lane])
            io_mods[width] = io

        # Per-width partition detectors driven by the registered points.
        sel_mods = {}
        for width in widths:
            sel = Parts(8, eps, len(part_mods[width].parts))
            setattr(m.submodules, "p_%d" % width, sel)
            sel_mods[width] = sel
            m.d.comb += sel.epps.eq(out_part_pts)

        # final output
        finalout = FinalOut(64)
        m.submodules.finalout = finalout
        for width, flags in ((8, finalout.d8),
                             (16, finalout.d16),
                             (32, finalout.d32)):
            detected = sel_mods[width].parts
            for lane in range(len(part_mods[width].parts)):
                m.d.comb += flags[lane].eq(detected[lane])
        for width, candidate in ((8, finalout.i8),
                                 (16, finalout.i16),
                                 (32, finalout.i32),
                                 (64, finalout.i64)):
            m.d.comb += candidate.eq(io_mods[width].output)
        m.d.comb += self.output.eq(finalout.out)

        return m
1185
1186
if __name__ == "__main__":
    # Generate the top-level design with all ports exposed.
    mul = Mul8_16_32_64()
    top_ports = [mul.a, mul.b, mul._intermediate_output, mul.output]
    top_ports.extend(mul.part_ops)
    top_ports.extend(mul.part_pts.values())
    main(mul, ports=top_ports)
1193 *m.part_ops,
1194 *m.part_pts.values()])