2cf37e84f1f3d6f322d4af48ede1e65e38f2d977
[ieee754fpu.git] / src / ieee754 / part_mul_add / multiply.py
1 # SPDX-License-Identifier: LGPL-2.1-or-later
2 # See Notices.txt for copyright information
3 """Integer Multiplication."""
4
5 from nmigen import Signal, Module, Value, Elaboratable, Cat, C, Mux, Repl
6 from nmigen.hdl.ast import Assign
7 from abc import ABCMeta, abstractmethod
8 from nmigen.cli import main
9 from functools import reduce
10 from operator import or_
11
12
class PartitionPoints(dict):
    """Mapping of partition-point bit indices to their enable ``Value``s.

    Each key is a bit index at which an ALU may be split; the associated
    ``Value`` says whether that split is currently enabled.

    For example: ``{1: True, 5: True, 10: True}`` with
    ``width == 16`` specifies that the ALU is split into 4 sections:
    * bits 0 <= ``i`` < 1
    * bits 1 <= ``i`` < 5
    * bits 5 <= ``i`` < 10
    * bits 10 <= ``i`` < 16

    If the partition_points were instead ``{1: True, 5: a, 10: True}``
    where ``a`` is a 1-bit ``Signal``:
    * If ``a`` is asserted:
        * bits 0 <= ``i`` < 1
        * bits 1 <= ``i`` < 5
        * bits 5 <= ``i`` < 10
        * bits 10 <= ``i`` < 16
    * Otherwise
        * bits 0 <= ``i`` < 1
        * bits 1 <= ``i`` < 10
        * bits 10 <= ``i`` < 16
    """

    def __init__(self, partition_points=None):
        """Create a new ``PartitionPoints``.

        :param partition_points: optional mapping of non-negative integer
            bit indices to enable values; each value is passed through
            ``Value.wrap``.
        :raises TypeError: if a key is not an ``int``.
        :raises ValueError: if a key is negative.
        """
        super().__init__()
        if partition_points is None:
            return
        for point, enabled in partition_points.items():
            if not isinstance(point, int):
                raise TypeError("point must be a non-negative integer")
            if point < 0:
                raise ValueError("point must be a non-negative integer")
            self[point] = Value.wrap(enabled)

    def like(self, name=None, src_loc_at=0, mul=1):
        """Create a new ``PartitionPoints`` with fresh ``Signal``s as values.

        :param name: the base name for the new ``Signal``s.  When omitted,
            the name of a temporary ``Signal`` is used instead.
        :param src_loc_at: extra stack depth used when the name is omitted.
        :param mul: a multiplication factor applied to every index.
        :returns: a ``PartitionPoints`` keyed by ``point * mul``, each value
            being a ``Signal`` with the shape of the original value.
        """
        if name is None:
            # a temporary Signal is created only to obtain a variable name.
            # this must stay directly in this method body: ``src_loc_at``
            # is relative to this frame.
            name = Signal(src_loc_at=1+src_loc_at).name
        retval = PartitionPoints()
        for point, enabled in self.items():
            scaled = point * mul
            retval[scaled] = Signal(enabled.shape(), name=f"{name}_{scaled}")
        return retval

    def eq(self, rhs):
        """Yield one ``.eq`` assignment per partition point.

        This is a generator: the key check (and any ``ValueError``) only
        happens once iteration starts.

        :param rhs: mapping with exactly the same set of points.
        :raises ValueError: if the two key sets differ.
        """
        if set(rhs.keys()) != set(self.keys()):
            raise ValueError("incompatible point set")
        for point, enabled in self.items():
            yield enabled.eq(rhs[point])

    def as_mask(self, width):
        """Create a bit-mask from `self`.

        Each bit in the returned mask is clear only if the partition point at
        the same bit-index is enabled.

        :param width: the bit width of the resulting mask
        """
        bits = [~self[i] if i in self else True for i in range(width)]
        return Cat(*bits)

    def get_max_partition_count(self, width):
        """Get the maximum number of partitions.

        This is the number of partitions obtained when every partition
        point below `width` is enabled.
        """
        return 1 + sum(1 for point in self if point < width)

    def fits_in_width(self, width):
        """Check if all partition points are smaller than `width`."""
        return all(point < width for point in self)

    def part_byte(self, index, mfactor=1):
        """Return the enable value at the upper boundary of byte `index`.

        Indices -1 and 7 always yield a constant 1; any other index must
        lie in 0..7 and looks up the point at ``(index * 8 + 8) * mfactor``.

        :param index: byte index.
        :param mfactor: index multiplier, used for "expanding".
        """
        if index in (-1, 7):
            return C(True, 1)
        assert 0 <= index < 8
        return self[(index * 8 + 8)*mfactor]
113
114
class FullAdder(Elaboratable):
    """Bit-parallel full adder (3-to-2 compressor).

    :attribute in0: the first input
    :attribute in1: the second input
    :attribute in2: the third input
    :attribute sum: the sum output
    :attribute carry: the carry output

    Instead of instantiating an array of single-bit full adders (which
    would be very slow to simulate), the inputs and outputs are all
    `width` bits wide, so that `width` independent Full 3-2 Add
    operations are performed "in parallel".
    """

    def __init__(self, width):
        """Create a ``FullAdder``.

        :param width: the bit width of the input and output
        """
        self.in0 = Signal(width)
        self.in1 = Signal(width)
        self.in2 = Signal(width)
        self.sum = Signal(width)
        self.carry = Signal(width)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        a, b, c = self.in0, self.in1, self.in2
        # sum is the 3-way XOR, carry is the majority function; the carry
        # is *not* shifted here.
        majority = (a & b) | (b & c) | (c & a)
        m.d.comb += [
            self.sum.eq(a ^ b ^ c),
            self.carry.eq(majority),
        ]
        return m
149
150
class MaskedFullAdder(Elaboratable):
    """Masked Full Adder.

    :attribute mask: the carry partition mask
    :attribute in0: the first input
    :attribute in1: the second input
    :attribute in2: the third input
    :attribute sum: the sum output
    :attribute mcarry: the masked carry output

    FullAdders are always used with a "mask" on the output.  Doing the
    masking inside this class (rather than in a large for-loop in the
    user) keeps the graphviz "clean".

    This class is deliberately not derived from FullAdder: each input is
    shifted up by one bit *before* being ANDed with the mask, so that an
    AOI cell may be used (which is more gate-efficient).  See:
    https://en.wikipedia.org/wiki/AND-OR-Invert
    https://groups.google.com/d/msg/comp.arch/fcq-GLQqvas/vTxmcA0QAgAJ
    """

    def __init__(self, width):
        """Create a ``MaskedFullAdder``.

        :param width: the bit width of the input and output
        """
        self.width = width
        self.mask = Signal(width, reset_less=True)
        self.mcarry = Signal(width, reset_less=True)
        self.in0 = Signal(width, reset_less=True)
        self.in1 = Signal(width, reset_less=True)
        self.in2 = Signal(width, reset_less=True)
        self.sum = Signal(width, reset_less=True)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        # shifted copies of the three inputs
        s1 = Signal(self.width, reset_less=True)
        s2 = Signal(self.width, reset_less=True)
        s3 = Signal(self.width, reset_less=True)
        # pair-wise masked carry terms
        c1 = Signal(self.width, reset_less=True)
        c2 = Signal(self.width, reset_less=True)
        c3 = Signal(self.width, reset_less=True)
        mask = self.mask
        m.d.comb += [
            self.sum.eq(self.in0 ^ self.in1 ^ self.in2),
            # prepend a zero bit: shifts each input left by one (the top
            # bit is dropped by the assignment to a `width`-bit signal)
            s1.eq(Cat(0, self.in0)),
            s2.eq(Cat(0, self.in1)),
            s3.eq(Cat(0, self.in2)),
            c1.eq(s1 & s2 & mask),
            c2.eq(s2 & s3 & mask),
            c3.eq(s3 & s1 & mask),
            self.mcarry.eq(c1 | c2 | c3),
        ]
        return m
204
205
class PartitionedAdder(Elaboratable):
    """Partitioned Adder.

    Performs the final add.  An extra bit is inserted into both operands
    at every partition point; in one operand only, that bit is set when
    the partition is *not* enabled, which makes a carry roll over into
    the next bit.  The extra bits are then *removed* from the result.

    partition: .... P... P... P... P... (32 bits)
    a        : .... .... .... .... .... (32 bits)
    b        : .... .... .... .... .... (32 bits)
    exp-a    : ....P....P....P....P.... (32+4 bits, P=1 if no partition)
    exp-b    : ....0....0....0....0.... (32 bits plus 4 zeros)
    exp-o    : ....xN...xN...xN...xN... (32+4 bits - x to be discarded)
    o        : .... N... N... N... N... (32 bits - x ignored, N is carry-over)

    :attribute width: the bit width of the input and output. Read-only.
    :attribute a: the first input to the adder
    :attribute b: the second input to the adder
    :attribute output: the sum output
    :attribute partition_points: the input partition points. Modification not
        supported, except for by ``Signal.eq``.
    """

    def __init__(self, width, partition_points):
        """Create a ``PartitionedAdder``.

        :param width: the bit width of the input and output
        :param partition_points: the input partition points
        :raises ValueError: if any partition point is >= `width`.
        """
        self.width = width
        self.a = Signal(width)
        self.b = Signal(width)
        self.output = Signal(width)
        self.partition_points = PartitionPoints(partition_points)
        if not self.partition_points.fits_in_width(width):
            raise ValueError("partition_points doesn't fit in width")
        # one bit per data bit, plus one extra per partition point in range
        n_breaks = sum(1 for i in range(self.width)
                       if i in self.partition_points)
        expanded_width = self.width + n_breaks
        self._expanded_width = expanded_width
        # XXX these have to remain here due to some horrible nmigen
        # simulation bugs involving sync.  it is *not* necessary to
        # have them here, they should (under normal circumstances)
        # be moved into elaborate, as they are entirely local
        self._expanded_a = Signal(expanded_width)  # includes extra part-points
        self._expanded_b = Signal(expanded_width)  # likewise.
        self._expanded_o = Signal(expanded_width)  # likewise.

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        # bits are collected in lists and joined with Cat afterwards:
        # the resulting graphviz is much cleaner.
        exp_a_bits, a_src = [], []      # expanded-a targets / their sources
        exp_b_bits, b_src = [], []      # expanded-b targets / their sources
        out_bits, exp_o_bits = [], []   # output targets / their sources

        # partition points are "breaks" (extra zeros or 1s) in what would
        # otherwise be a massive long add.  when the "break" points are 0,
        # whatever is in it (in the output) is discarded.  however when
        # there is a "1", it causes a roll-over carry to the *next* bit.
        # we still ignore the "break" bit in the [intermediate] output,
        # however by that time we've got the effect that we wanted: the
        # carry has been carried *over* the break point.
        pos = 0  # index into the expanded signals
        for i in range(self.width):
            if i in self.partition_points:
                # extra bit: 0 + 0 for an enabled partition point,
                # 1 + 0 for a disabled one
                exp_a_bits.append(self._expanded_a[pos])
                a_src.append(~self.partition_points[i])
                exp_b_bits.append(self._expanded_b[pos])
                b_src.append(C(0))  # yes, add a zero
                pos += 1  # the extra bit is NOT copied to the output
            exp_a_bits.append(self._expanded_a[pos])
            a_src.append(self.a[i])
            exp_b_bits.append(self._expanded_b[pos])
            b_src.append(self.b[i])
            exp_o_bits.append(self._expanded_o[pos])
            out_bits.append(self.output[i])
            pos += 1

        # combine above using Cat
        m.d.comb += Cat(*exp_a_bits).eq(Cat(*a_src))
        m.d.comb += Cat(*exp_b_bits).eq(Cat(*b_src))
        m.d.comb += Cat(*out_bits).eq(Cat(*exp_o_bits))

        # use only one addition to take advantage of look-ahead carry and
        # special hardware on FPGAs
        m.d.comb += self._expanded_o.eq(
            self._expanded_a + self._expanded_b)
        return m
299
300
# Number of terms consumed by each full adder when the add-reduce levels
# group their inputs (see ``AddReduceSingle.full_adder_groups``).
FULL_ADDER_INPUT_COUNT = 3
302
class AddReduceData:
    """Bundle of signals passed into an add-reduce stage.

    :attribute part_ops: list of 2-bit ``Signal``s, one per partition.
    :attribute inputs: list of ``output_width``-bit input term ``Signal``s.
    :attribute reg_partition_points: ``Signal`` copies of the partition
        points (created with ``PartitionPoints.like``).
    """

    def __init__(self, ppoints, n_inputs, output_width, n_parts):
        """Create the signals.

        :param ppoints: partition points to mirror.
        :param n_inputs: number of input terms.
        :param output_width: bit-width of each input term.
        :param n_parts: number of ``part_ops`` signals.
        """
        self.part_ops = [Signal(2, name=f"part_ops_{i}")
                         for i in range(n_parts)]
        self.inputs = [Signal(output_width, name=f"inputs[{i}]")
                       for i in range(n_inputs)]
        # keep this exact form: the signal names are derived from the
        # attribute name on this line
        self.reg_partition_points = ppoints.like()

    def eq(self, rhs):
        """Return a list of assignments copying every field from `rhs`.

        The first entry is whatever ``reg_partition_points.eq`` returns;
        the per-input and per-part_op assignments follow, in order.
        """
        assigns = [self.reg_partition_points.eq(rhs.reg_partition_points)]
        assigns += [sig.eq(rhs.inputs[i])
                    for i, sig in enumerate(self.inputs)]
        assigns += [sig.eq(rhs.part_ops[i])
                    for i, sig in enumerate(self.part_ops)]
        return assigns
318
319
class FinalAdd(Elaboratable):
    """Final stage of add reduce.

    Handles the last 0, 1 or 2 remaining terms: outputs zero, passes the
    single term through, or sums the two terms with a
    ``PartitionedAdder``.
    """

    def __init__(self, n_inputs, output_width, n_parts, register_levels,
                 partition_points):
        """Create a ``FinalAdd``.

        :param n_inputs: number of input terms (0, 1 or 2).
        :param output_width: bit-width of ``output`` and of each input.
        :param n_parts: number of ``part_ops`` signals.
        :param register_levels: List of nesting levels that should have
            pipeline registers.
        :param partition_points: the input partition points.
        :raises ValueError: if a partition point is >= `output_width`.
        """
        self.i = AddReduceData(partition_points, n_inputs,
                               output_width, n_parts)
        self.n_inputs = n_inputs
        self.n_parts = n_parts
        # the attributes below alias the signals held in ``self.i``
        self.out_part_ops = self.i.part_ops
        self._resized_inputs = self.i.inputs
        self.register_levels = list(register_levels)
        self.output = Signal(output_width)
        self.partition_points = PartitionPoints(partition_points)
        if not self.partition_points.fits_in_width(output_width):
            raise ValueError("partition_points doesn't fit in output_width")
        self._reg_partition_points = self.i.reg_partition_points
        # nothing follows the final stage
        self.intermediate_terms = []

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()

        if self.n_inputs == 0:
            # use 0 as the default output value
            m.d.comb += self.output.eq(0)
            return m

        if self.n_inputs == 1:
            # handle single input
            m.d.comb += self.output.eq(self._resized_inputs[0])
            return m

        # base case for adding 2 inputs
        assert self.n_inputs == 2
        adder = PartitionedAdder(len(self.output),
                                 self._reg_partition_points)
        m.submodules.final_adder = adder
        first, second = self._resized_inputs[0], self._resized_inputs[1]
        m.d.comb += adder.a.eq(first)
        m.d.comb += adder.b.eq(second)
        m.d.comb += self.output.eq(adder.output)
        return m
360
361
class AddReduceSingle(Elaboratable):
    """Add list of numbers together (one reduction level).

    Groups the inputs in threes, feeds each group to a
    ``MaskedFullAdder`` and exposes the resulting sum/carry terms (plus
    any left-over inputs, passed through unchanged) as
    ``intermediate_terms`` for the next level.

    :attribute register_levels: List of nesting levels that should have
        pipeline registers.
    :attribute intermediate_terms: output ``Signal``s for the next level.
    :attribute out_part_ops: list of 2-bit ``Signal``s, one per partition.
    :attribute partition_points: the input partition points. Modification not
        supported, except for by ``Signal.eq``.
    """

    def __init__(self, n_inputs, output_width, n_parts, register_levels,
                 partition_points):
        """Create an ``AddReduceSingle``.

        :param n_inputs: number of input terms to be summed.
        :param output_width: bit-width of every term.
        :param n_parts: number of ``out_part_ops`` signals.
        :param register_levels: List of nesting levels that should have
            pipeline registers.
        :param partition_points: the input partition points.
        :raises ValueError: if a partition point is >= `output_width`, or
            if a register level exceeds the available adder levels.
        """
        self.n_inputs = n_inputs
        self.n_parts = n_parts
        self.output_width = output_width
        self.out_part_ops = [Signal(2, name=f"out_part_ops_{i}")
                             for i in range(n_parts)]
        self._resized_inputs = [
            Signal(output_width, name=f"resized_inputs[{i}]")
            for i in range(n_inputs)]
        self.register_levels = list(register_levels)
        self.partition_points = PartitionPoints(partition_points)
        if not self.partition_points.fits_in_width(output_width):
            raise ValueError("partition_points doesn't fit in output_width")
        # keep this exact form: the signal names are derived from the
        # attribute name on this line
        self._reg_partition_points = self.partition_points.like()

        max_level = AddReduceSingle.get_max_level(n_inputs)
        for level in self.register_levels:
            if level > max_level:
                raise ValueError(
                    "not enough adder levels for specified register levels")

        # this is annoying.  we have to create the modules (and terms)
        # because we need to know what they are (in order to set up the
        # interconnects back in AddReduce), but cannot do the m.d.comb +=
        # etc because this is not in elaboratable.
        self.groups = AddReduceSingle.full_adder_groups(n_inputs)
        self._intermediate_terms = []
        # always create the terms, even when there are fewer than three
        # inputs (no full-adder groups): otherwise ``intermediate_terms``,
        # ``adders`` and ``part_mask`` would not exist and both
        # ``elaborate`` and ``AddReduce`` would fail with AttributeError.
        # With no groups the inputs are simply passed through.
        self.create_next_terms()

    @staticmethod
    def get_max_level(input_count):
        """Get the maximum level.

        All ``register_levels`` must be less than or equal to the maximum
        level.

        :param input_count: number of terms entering the first level.
        :returns: the number of full-adder levels needed until fewer than
            ``FULL_ADDER_INPUT_COUNT`` terms remain.
        """
        retval = 0
        while True:
            groups = AddReduceSingle.full_adder_groups(input_count)
            if len(groups) == 0:
                return retval
            # each group turns 3 terms into 2; the remainder passes through
            input_count %= FULL_ADDER_INPUT_COUNT
            input_count += 2 * len(groups)
            retval += 1

    @staticmethod
    def full_adder_groups(input_count):
        """Get ``inputs`` indices for which a full adder should be built.

        :returns: a ``range`` of the starting index of every complete
            group of ``FULL_ADDER_INPUT_COUNT`` inputs.
        """
        return range(0,
                     input_count - FULL_ADDER_INPUT_COUNT + 1,
                     FULL_ADDER_INPUT_COUNT)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()

        # drive each intermediate term from its source value
        for (value, term) in self._intermediate_terms:
            m.d.comb += term.eq(value)

        mask = self._reg_partition_points.as_mask(self.output_width)
        m.d.comb += self.part_mask.eq(mask)

        # add and link the intermediate term modules
        for i, (iidx, adder_i) in enumerate(self.adders):
            setattr(m.submodules, f"adder_{i}", adder_i)

            m.d.comb += adder_i.in0.eq(self._resized_inputs[iidx])
            m.d.comb += adder_i.in1.eq(self._resized_inputs[iidx + 1])
            m.d.comb += adder_i.in2.eq(self._resized_inputs[iidx + 2])
            m.d.comb += adder_i.mask.eq(self.part_mask)

        return m

    def create_next_terms(self):
        """Create the adders and the terms handed on to the next level.

        Sets ``part_mask``, ``adders`` (list of ``(input_index, adder)``),
        ``intermediate_terms`` (the new ``Signal``s) and
        ``_intermediate_terms`` (``(source_value, signal)`` pairs that are
        connected up in ``elaborate``).
        """
        # go on to prepare recursive case
        intermediate_terms = []
        _intermediate_terms = []

        def add_intermediate_term(value):
            intermediate_term = Signal(
                self.output_width,
                name=f"intermediate_terms[{len(intermediate_terms)}]")
            _intermediate_terms.append((value, intermediate_term))
            intermediate_terms.append(intermediate_term)

        # store mask in intermediary (simplifies graph)
        self.part_mask = Signal(self.output_width, reset_less=True)

        # create full adders for this recursive level.
        # this shrinks N terms to 2 * (N // 3) plus the remainder
        self.adders = []
        for i in self.groups:
            adder_i = MaskedFullAdder(self.output_width)
            self.adders.append((i, adder_i))
            # add both the sum and the masked-carry to the next level.
            # 3 inputs have now been reduced to 2...
            add_intermediate_term(adder_i.sum)
            add_intermediate_term(adder_i.mcarry)
        # handle the remaining inputs.
        remainder = self.n_inputs % FULL_ADDER_INPUT_COUNT
        if remainder == 1:
            add_intermediate_term(self._resized_inputs[-1])
        elif remainder == 2:
            # Just pass the terms to the next layer, since we wouldn't gain
            # anything by using a half adder since there would still be 2 terms
            # and just passing the terms to the next layer saves gates.
            add_intermediate_term(self._resized_inputs[-2])
            add_intermediate_term(self._resized_inputs[-1])
        else:
            assert remainder == 0

        self.intermediate_terms = intermediate_terms
        self._intermediate_terms = _intermediate_terms
497
498
class AddReduce(Elaboratable):
    """Recursively Add list of numbers together.

    Builds a chain of ``AddReduceSingle`` levels followed by a
    ``FinalAdd``, and wires each level's intermediate terms into the
    next one.

    :attribute inputs: input ``Signal``s to be summed. Modification not
        supported, except for by ``Signal.eq``.
    :attribute register_levels: List of nesting levels that should have
        pipeline registers.
    :attribute output: output sum.
    :attribute out_part_ops: copies of ``part_ops`` taken from the last
        level.
    :attribute partition_points: the input partition points. Modification not
        supported, except for by ``Signal.eq``.
    :attribute levels: the list of level modules, in order.
    """

    def __init__(self, inputs, output_width, register_levels, partition_points,
                 part_ops):
        """Create an ``AddReduce``.

        :param inputs: input ``Signal``s to be summed.
        :param output_width: bit-width of ``output``.
        :param register_levels: List of nesting levels that should have
            pipeline registers.
        :param partition_points: the input partition points.
        :param part_ops: list of 2-bit operation ``Signal``s, one per
            partition, carried alongside the data through every level.
        """
        self.inputs = inputs
        self.part_ops = part_ops
        self.out_part_ops = [Signal(2, name=f"out_part_ops_{i}")
                             for i in range(len(part_ops))]
        self.output = Signal(output_width)
        self.output_width = output_width
        self.register_levels = register_levels
        self.partition_points = partition_points

        self.create_levels()

    @staticmethod
    def get_max_level(input_count):
        """Get the maximum level (see ``AddReduceSingle.get_max_level``)."""
        return AddReduceSingle.get_max_level(input_count)

    @staticmethod
    def next_register_levels(register_levels):
        """``Iterable`` of ``register_levels`` for next recursive level."""
        for level in register_levels:
            if level > 0:
                yield level - 1

    def create_levels(self):
        """Create the reduction levels and store them in ``self.levels``."""
        mods = []
        next_levels = self.register_levels
        partition_points = self.partition_points
        inputs = self.inputs
        n_parts = len(self.part_ops)
        while True:
            next_level = AddReduceSingle(len(inputs), self.output_width,
                                         n_parts, next_levels,
                                         partition_points)
            mods.append(next_level)
            next_levels = list(AddReduce.next_register_levels(next_levels))
            partition_points = next_level._reg_partition_points
            inputs = next_level.intermediate_terms
            # stop once too few terms remain to form a full-adder group
            groups = AddReduceSingle.full_adder_groups(len(inputs))
            if len(groups) == 0:
                break

        # always finish with a FinalAdd so that the last level is
        # guaranteed to have an ``output``; FinalAdd itself copes with
        # zero, one or two remaining terms.
        mods.append(FinalAdd(len(inputs), self.output_width, n_parts,
                             next_levels, partition_points))

        self.levels = mods

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()

        for i, level in enumerate(self.levels):
            setattr(m.submodules, "next_level%d" % i, level)

        partition_points = self.partition_points
        inputs = self.inputs
        part_ops = self.part_ops
        for mcur in self.levels:
            # each level was created with exactly len(inputs) inputs and
            # len(part_ops) part_ops, so pairing by position is exact
            inassign = [dst.eq(src)
                        for dst, src in zip(mcur._resized_inputs, inputs)]
            copy_part_ops = [dst.eq(src)
                             for dst, src in zip(mcur.out_part_ops, part_ops)]
            # a level listing 0 in its register_levels gets a pipeline
            # register on its inputs; otherwise it is purely combinatorial
            domain = m.d.sync if 0 in mcur.register_levels else m.d.comb
            domain += copy_part_ops
            domain += inassign
            domain += mcur._reg_partition_points.eq(partition_points)
            partition_points = mcur._reg_partition_points
            inputs = mcur.intermediate_terms
            part_ops = mcur.out_part_ops

        # output comes from last module (referenced explicitly rather than
        # through a variable left over from an earlier loop)
        final = self.levels[-1]
        m.d.comb += self.output.eq(final.output)
        m.d.comb += [dst.eq(src)
                     for dst, src in zip(self.out_part_ops,
                                         final.out_part_ops)]

        return m
608
609
# Operation selectors (values of the 2-bit ``part_ops`` signals).
# ``IntermediateOut`` picks the low half of a product for OP_MUL_LOW and
# the high half for every other value.
OP_MUL_LOW = 0
OP_MUL_SIGNED_HIGH = 1
OP_MUL_SIGNED_UNSIGNED_HIGH = 2  # a is signed, b is unsigned
OP_MUL_UNSIGNED_HIGH = 3
614
615
def get_term(value, shift=0, enabled=None):
    """Optionally gate `value` and shift it left by `shift` bits.

    :param value: the term to process.
    :param shift: number of zero bits to prepend (must be >= 0).
    :param enabled: when not ``None``, `value` is replaced by 0 whenever
        `enabled` is false (via ``Mux``).
    :returns: `value` itself when no gating or shifting is requested,
        otherwise the gated and/or shifted expression.
    """
    gated = value if enabled is None else Mux(enabled, value, 0)
    if shift <= 0:
        assert shift == 0
        return gated
    # prepend `shift` zero bits below the value
    return Cat(Repl(C(0, 1), shift), gated)
624
625
class ProductTerm(Elaboratable):
    """Creates a single product term (a[..]*b[..]).

    It has a design flaw in that it is the *output* that is selected,
    whereas the multiplication(s) are combinatorially generated all the
    time.
    """

    def __init__(self, width, twidth, pbwid, a_index, b_index):
        """Create a ``ProductTerm``.

        :param width: bit-width of each operand slice.
        :param twidth: bit-width of the output ``term``; ``a`` and ``b``
            are each ``twidth//2`` bits wide.
        :param pbwid: bit-width of the ``pb_en`` input.
        :param a_index: index of the slice taken from ``a``.
        :param b_index: index of the slice taken from ``b``.
        """
        self.a_index = a_index
        self.b_index = b_index
        self.pwidth = width
        self.twidth = twidth
        self.width = width*2
        self.shift = 8 * (self.a_index + self.b_index)

        self.ti = Signal(self.width, reset_less=True)
        self.term = Signal(twidth, reset_less=True)
        self.a = Signal(twidth//2, reset_less=True)
        self.b = Signal(twidth//2, reset_less=True)
        self.pb_en = Signal(pbwid, reset_less=True)

        # pb_en bits lying between the two slice indices
        lo = min(self.a_index, self.b_index)
        hi = max(self.a_index, self.b_index)
        self.tl = tl = [self.pb_en[i] for i in range(lo, hi)]
        name = "te_%d_%d" % (self.a_index, self.b_index)
        # no enable signal is needed when no pb_en bits are involved
        self.enabled = Signal(name=name, reset_less=True) if tl else None
        self.term.name = "term_%d_%d" % (a_index, b_index)  # rename

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        if self.enabled is not None:
            # enabled only when none of the collected pb_en bits is set
            m.d.comb += self.enabled.eq(~(Cat(*self.tl).bool()))

        bsa = Signal(self.width, reset_less=True)
        bsb = Signal(self.width, reset_less=True)
        pwidth = self.pwidth
        a_slice = self.a.part(self.a_index * pwidth, pwidth)
        b_slice = self.b.part(self.b_index * pwidth, pwidth)
        m.d.comb += bsa.eq(a_slice)
        m.d.comb += bsb.eq(b_slice)
        m.d.comb += self.ti.eq(bsa * bsb)
        m.d.comb += self.term.eq(get_term(self.ti, self.shift, self.enabled))

        # TODO: sort out width issues, get inputs a/b switched on/off
        # (i.e. apply get_term to the selected a/b slices *before* the
        # multiply, then assign the product straight to self.term).
        # data going into the Muxes is 1/2 the required width.

        return m
695
696
class ProductTerms(Elaboratable):
    """Creates a bank of product terms; also performs the actual
    bit-selection.

    This class is to be wrapped with a for-loop on the "a" operand.
    It creates a second-level for-loop on the "b" operand.
    """

    def __init__(self, width, twidth, pbwid, a_index, blen):
        """Create a ``ProductTerms``.

        :param width: bit-width of each operand slice.
        :param twidth: bit-width of each output term.
        :param pbwid: bit-width of the ``pb_en`` input.
        :param a_index: slice index used for the "a" operand.
        :param blen: number of "b" slices, i.e. number of terms created.
        """
        self.a_index = a_index
        self.blen = blen
        self.pwidth = width
        self.twidth = twidth
        self.pbwid = pbwid
        self.a = Signal(twidth//2, reset_less=True)
        self.b = Signal(twidth//2, reset_less=True)
        self.pb_en = Signal(pbwid, reset_less=True)
        self.terms = [Signal(twidth, name="term%d" % i, reset_less=True)
                      for i in range(blen)]

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()

        for b_index, out_term in enumerate(self.terms):
            t = ProductTerm(self.pwidth, self.twidth, self.pbwid,
                            self.a_index, b_index)
            setattr(m.submodules, "term_%d" % b_index, t)

            # every sub-term sees the same operands and enables
            m.d.comb += t.a.eq(self.a)
            m.d.comb += t.b.eq(self.b)
            m.d.comb += t.pb_en.eq(self.pb_en)

            m.d.comb += out_term.eq(t.term)

        return m
730
731
class LSBNegTerm(Elaboratable):
    """Produces the two extra terms used for a signed operand.

    When ``part``, ``msb`` and ``signed`` are all set, ``nt`` holds the
    bit-inverted ``op`` in its upper half and ``nl`` holds a single 1 at
    bit ``bit_width``; otherwise both are zero.
    """

    def __init__(self, bit_width):
        """Create an ``LSBNegTerm``.

        :param bit_width: width of ``op``; ``nt`` and ``nl`` are twice
            as wide.
        """
        self.bit_width = bit_width
        self.part = Signal(reset_less=True)
        self.signed = Signal(reset_less=True)
        self.op = Signal(bit_width, reset_less=True)
        self.msb = Signal(reset_less=True)
        self.nt = Signal(bit_width*2, reset_less=True)
        self.nl = Signal(bit_width*2, reset_less=True)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        comb = m.d.comb
        bit_wid = self.bit_width
        ext = Repl(0, bit_wid)  # extend output to HI part

        # determine sign of each incoming number *in this partition*
        enabled = Signal(reset_less=True)
        comb += enabled.eq(self.part & self.msb & self.signed)

        # for 8-bit values: form a * 0xFF00 by using -a * 0x100, the
        # negation operation is split into a bitwise not and a +1.
        # likewise for 16, 32, and 64-bit values.

        # width-extended 1s complement if a is signed, otherwise zero
        ones_complement = Cat(ext, ~self.op)
        comb += self.nt.eq(Mux(enabled, ones_complement, 0))

        # add 1 if signed, otherwise add zero
        plus_one = Cat(ext, enabled, Repl(0, bit_wid-1))
        comb += self.nl.eq(plus_one)

        return m
764
765
class Parts(Elaboratable):
    """Works out which partitions are active from the partition points.

    ``parts[i]`` is asserted when the break bits on both sides of
    partition ``i`` are set and none of the break bits inside it is set.
    """

    def __init__(self, pbwid, epps, n_parts):
        """Create a ``Parts``.

        :param pbwid: number of partition-byte bits collected.
        :param epps: the expanded partition points to mirror.
        :param n_parts: number of ``parts`` output signals.
        """
        self.pbwid = pbwid
        # inputs
        self.epps = PartitionPoints.like(epps, name="epps")  # expanded points
        # outputs
        self.parts = [Signal(name=f"part_{i}") for i in range(n_parts)]

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()

        epps, parts = self.epps, self.parts
        # collect part-bytes (double factor because the input is extended)
        pbs = Signal(self.pbwid, reset_less=True)
        collected = []
        for i in range(self.pbwid):
            pb = Signal(name="pb%d" % i, reset_less=True)
            m.d.comb += pb.eq(epps.part_byte(i, mfactor=2))  # double
            collected.append(pb)
        m.d.comb += pbs.eq(Cat(*collected))

        # negated-temporary copy of partition bits
        npbs = Signal.like(pbs, reset_less=True)
        m.d.comb += npbs.eq(~pbs)
        byte_count = 8 // len(parts)
        for i, part in enumerate(parts):
            first = i * byte_count
            last = (i + 1) * byte_count - 1
            # (negated) break below the partition, the breaks inside it,
            # then the (negated) break at its top.  for i == 0 the index
            # ``first - 1`` is -1, i.e. the last bit of npbs.
            pbl = [npbs[first - 1]]
            pbl.extend(pbs[j] for j in range(first, last))
            pbl.append(npbs[last])
            value = Signal(len(pbl), name="value_%d" % i, reset_less=True)
            m.d.comb += value.eq(Cat(*pbl))
            m.d.comb += part.eq(~(value).bool())

        return m
803
804
class Part(Elaboratable):
    """A key class which, depending on the partitioning, will determine
    what action to take when parts of the output are signed or unsigned.

    This requires 2 pieces of data *per operand, per partition*:
    whether the MSB is HI/LO (per partition!), and whether a signed
    or unsigned operation has been *requested*.

    Once that is determined, signed is basically carried out
    by splitting 2's complement into 1's complement plus one.
    1's complement is just a bit-inversion.

    The extra terms - as separate terms - are then thrown at the
    AddReduce alongside the multiplication part-results.
    """

    def __init__(self, epps, width, n_parts, n_levels, pbwid):
        """Create a ``Part``.

        :param epps: the expanded partition points.
        :param width: bit-width of the four output terms.
        :param n_parts: number of partitions.
        :param n_levels: not used by this class.
        :param pbwid: number of partition-byte bits.
        """
        self.pbwid = pbwid
        self.epps = epps

        # inputs
        self.a = Signal(64)
        self.b = Signal(64)
        self.a_signed = [Signal(name=f"a_signed_{i}") for i in range(8)]
        self.b_signed = [Signal(name=f"_b_signed_{i}") for i in range(8)]
        self.pbs = Signal(pbwid, reset_less=True)

        # outputs
        self.parts = [Signal(name=f"part_{i}") for i in range(n_parts)]

        self.not_a_term = Signal(width)
        self.neg_lsb_a_term = Signal(width)
        self.not_b_term = Signal(width)
        self.neg_lsb_b_term = Signal(width)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()

        epps = self.epps
        m.submodules.p = p = Parts(self.pbwid, epps, len(self.parts))
        m.d.comb += p.epps.eq(epps)
        # NOTE(review): the partition flags used below come from the Parts
        # submodule; ``self.parts`` is only used for its length here and is
        # never driven in this method - confirm whether that is intended.
        parts = p.parts

        byte_width = 8 // len(parts)  # byte width
        bit_wid = 8 * byte_width      # bit width
        nat, nbt, nla, nlb = [], [], [], []
        for i, part in enumerate(parts):
            signed_idx = i * byte_width
            msb_idx = (i + 1) * bit_wid - 1

            # work out bit-inverted and +1 term for a.
            pa = LSBNegTerm(bit_wid)
            setattr(m.submodules, "lnt_%d_a_%d" % (bit_wid, i), pa)
            m.d.comb += pa.part.eq(part)
            m.d.comb += pa.op.eq(self.a.part(bit_wid * i, bit_wid))
            m.d.comb += pa.signed.eq(self.b_signed[signed_idx])  # yes b
            m.d.comb += pa.msb.eq(self.b[msb_idx])  # really, b
            nat.append(pa.nt)
            nla.append(pa.nl)

            # work out bit-inverted and +1 term for b
            pb = LSBNegTerm(bit_wid)
            setattr(m.submodules, "lnt_%d_b_%d" % (bit_wid, i), pb)
            m.d.comb += pb.part.eq(part)
            m.d.comb += pb.op.eq(self.b.part(bit_wid * i, bit_wid))
            m.d.comb += pb.signed.eq(self.a_signed[signed_idx])  # yes a
            m.d.comb += pb.msb.eq(self.a[msb_idx])  # really, a
            nbt.append(pb.nt)
            nlb.append(pb.nl)

        # concatenate together and return all 4 results.
        m.d.comb += [self.not_a_term.eq(Cat(*nat)),
                     self.not_b_term.eq(Cat(*nbt)),
                     self.neg_lsb_a_term.eq(Cat(*nla)),
                     self.neg_lsb_b_term.eq(Cat(*nlb)),
                     ]

        return m
887
888
class IntermediateOut(Elaboratable):
    """Selects the HI/LO part of the multiplication, for a given bit-width.

    The output is also reconstructed in its SIMD (partition) lanes.
    """

    def __init__(self, width, out_wid, n_parts):
        """Create an ``IntermediateOut``.

        :param width: bit-width of one output lane.
        :param out_wid: bit-width of ``intermed``; ``output`` is half that.
        :param n_parts: number of lanes.
        """
        self.width = width
        self.n_parts = n_parts
        self.part_ops = [Signal(2, name="dpop%d" % i, reset_less=True)
                         for i in range(8)]
        self.intermed = Signal(out_wid, reset_less=True)
        self.output = Signal(out_wid//2, reset_less=True)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()

        w = self.width
        sel = w // 8  # stride between the part_ops entries that are used
        lanes = []
        for i in range(self.n_parts):
            op = Signal(w, reset_less=True, name="op%d_%d" % (w, i))
            # each lane occupies 2*w bits of the intermediate result
            base = i * w * 2
            want_low = self.part_ops[sel * i] == OP_MUL_LOW
            low_half = self.intermed.part(base, w)
            high_half = self.intermed.part(base + w, w)
            m.d.comb += op.eq(Mux(want_low, low_half, high_half))
            lanes.append(op)
        m.d.comb += self.output.eq(Cat(*lanes))

        return m
917
918
class FinalOut(Elaboratable):
    """Selects the final output based on the partitioning.

    Each byte is selectable independently, i.e. it is possible
    that some partitions requested 8-bit computation whilst others
    requested 16 or 32 bit.
    """

    def __init__(self, out_wid):
        """Create a ``FinalOut``.

        :param out_wid: bit-width of the four candidate inputs and of
            ``out``.
        """
        # inputs
        self.d8 = [Signal(name=f"d8_{i}", reset_less=True) for i in range(8)]
        self.d16 = [Signal(name=f"d16_{i}", reset_less=True) for i in range(4)]
        self.d32 = [Signal(name=f"d32_{i}", reset_less=True) for i in range(2)]

        self.i8 = Signal(out_wid, reset_less=True)
        self.i16 = Signal(out_wid, reset_less=True)
        self.i32 = Signal(out_wid, reset_less=True)
        self.i64 = Signal(out_wid, reset_less=True)

        # output
        self.out = Signal(out_wid, reset_less=True)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        selected = []
        for i in range(8):
            # select one of the outputs: d8 selects i8, d16 selects i16
            # d32 selects i32, and the default is i64.
            # d8 and d16 are ORed together in the first Mux
            # then the 2nd selects either i8 or i16.
            # if neither d8 nor d16 are set, d32 selects either i32 or i64.
            op = Signal(8, reset_less=True, name="op_%d" % i)
            lsb = i * 8
            narrow = Mux(self.d8[i],
                         self.i8.part(lsb, 8),
                         self.i16.part(lsb, 8))
            wide = Mux(self.d32[i // 4],
                       self.i32.part(lsb, 8),
                       self.i64.part(lsb, 8))
            m.d.comb += op.eq(
                Mux(self.d8[i] | self.d16[i // 2], narrow, wide))
            selected.append(op)
        m.d.comb += self.out.eq(Cat(*selected))
        return m
959
960
class OrMod(Elaboratable):
    """Bitwise-OR of four equally wide inputs, built as a two-level tree.

    :param wid: width in bits of each input and of the result.

    ``orin`` holds the four inputs, ``orout`` the combined result.
    """

    def __init__(self, wid):
        self.wid = wid
        self.orin = [Signal(wid, reset_less=True, name="orin%d" % idx)
                     for idx in range(4)]
        self.orout = Signal(wid, reset_less=True)

    def elaborate(self, platform):
        m = Module()

        in0, in1, in2, in3 = self.orin

        # first tree level: two pairwise ORs
        or1 = Signal(self.wid, reset_less=True)
        or2 = Signal(self.wid, reset_less=True)
        m.d.comb += [or1.eq(in0 | in1),
                     or2.eq(in2 | in3),
                     # second tree level: merge the two partial results
                     self.orout.eq(or1 | or2),
                     ]

        return m
979
980
class Signs(Elaboratable):
    """Derive the signedness of operands ``a`` and ``b`` from an op code.

    * ``a_signed`` is asserted for every operation except
      ``OP_MUL_UNSIGNED_HIGH``.
    * ``b_signed`` is asserted only for ``OP_MUL_LOW`` and
      ``OP_MUL_SIGNED_HIGH``.
    """

    def __init__(self):
        self.part_ops = Signal(2, reset_less=True)
        self.a_signed = Signal(reset_less=True)
        self.b_signed = Signal(reset_less=True)

    def elaborate(self, platform):
        m = Module()

        op = self.part_ops
        a_is_signed = op != OP_MUL_UNSIGNED_HIGH
        b_is_signed = (op == OP_MUL_LOW) | (op == OP_MUL_SIGNED_HIGH)

        m.d.comb += [self.a_signed.eq(a_is_signed),
                     self.b_signed.eq(b_is_signed),
                     ]

        return m
1002
1003
class Mul8_16_32_64(Elaboratable):
    """Partitioned signed/unsigned integer multiplier for 8/16/32/64 bits.

    The 64-bit operands can be split into any mix of 8-, 16-, 32- and
    64-bit partitions on naturally-aligned boundaries, and the operation
    can be chosen separately for each partition.

    :attribute part_pts: the input partition points.  There is one
        partition point at every multiple of 8 in 0 < i < 64, and the
        ``Value`` stored for each point is a ``Signal``.  Modification is
        not supported, except through ``Signal.eq``.
    :attribute part_ops: the operation for each byte.  To select the
        operation of a partition, assign the same operation code to every
        byte of that partition.  The allowed operation codes are:

        :attribute OP_MUL_LOW: the LSB half of the product.  Equivalent
            to RISC-V's `mul` instruction.
        :attribute OP_MUL_SIGNED_HIGH: the MSB half of the product where
            both ``a`` and ``b`` are signed.  Equivalent to RISC-V's
            `mulh` instruction.
        :attribute OP_MUL_SIGNED_UNSIGNED_HIGH: the MSB half of the
            product where ``a`` is signed and ``b`` is unsigned.
            Equivalent to RISC-V's `mulhsu` instruction.
        :attribute OP_MUL_UNSIGNED_HIGH: the MSB half of the product
            where both ``a`` and ``b`` are unsigned.  Equivalent to
            RISC-V's `mulhu` instruction.
    """

    def __init__(self, register_levels=()):
        """Create the multiplier's ports.

        :param register_levels: the points in the adder cascade at which
            flip-flops are to be inserted.
        """

        # parameter(s)
        self.register_levels = list(register_levels)

        # inputs
        self.part_pts = PartitionPoints()
        for i in range(8, 64, 8):
            self.part_pts[i] = Signal(name=f"part_pts_{i}")
        self.part_ops = [Signal(2, name=f"part_ops_{i}") for i in range(8)]
        self.a = Signal(64)
        self.b = Signal(64)

        # intermediates (needed for unit tests)
        self._intermediate_output = Signal(128)

        # output
        self.output = Signal(64)

    def elaborate(self, platform):
        m = Module()

        # ---- collect the per-byte partition flags into one 8-bit bus ----
        pbs = Signal(8, reset_less=True)
        byte_flags = [Signal(reset_less=True, name="pb%d" % idx)
                      for idx in range(8)]
        for idx, flag in enumerate(byte_flags):
            m.d.comb += flag.eq(self.part_pts.part_byte(idx))
        m.d.comb += pbs.eq(Cat(*byte_flags))

        # ---- doubled PartitionPoints (output is double the input width) --
        expanded_part_pts = eps = PartitionPoints()
        for i, enable in self.part_pts.items():
            mirror = Signal(reset_less=True, name=f"expanded_part_pts_{i*2}")
            m.d.comb += mirror.eq(enable)
            eps[i * 2] = mirror

        # ---- one sign decoder per operand byte ---------------------------
        signs = [Signs() for _ in range(8)]
        for idx, sign_mod in enumerate(signs):
            setattr(m.submodules, "signs%d" % idx, sign_mod)
            m.d.comb += sign_mod.part_ops.eq(self.part_ops[idx])

        # ---- one Part module per lane width (8, 16, 32 and 64 bits) ------
        n_levels = len(self.register_levels) + 1
        part_mods = []
        for n_lanes in (8, 4, 2, 1):
            part_mod = Part(eps, 128, n_lanes, n_levels, 8)
            # registered as part_8, part_16, part_32 and part_64
            setattr(m.submodules, "part_%d" % (64 // n_lanes), part_mod)
            part_mods.append(part_mod)
        part_8, part_16, part_32, part_64 = part_mods

        for part_mod in part_mods:
            m.d.comb += [part_mod.a.eq(self.a),
                         part_mod.b.eq(self.b)]
            for idx, sign_mod in enumerate(signs):
                m.d.comb += [part_mod.a_signed[idx].eq(sign_mod.a_signed),
                             part_mod.b_signed[idx].eq(sign_mod.b_signed)]
            m.d.comb += part_mod.pbs.eq(pbs)

        nat_l = [part_mod.not_a_term for part_mod in part_mods]
        nbt_l = [part_mod.not_b_term for part_mod in part_mods]
        nla_l = [part_mod.neg_lsb_a_term for part_mod in part_mods]
        nlb_l = [part_mod.neg_lsb_b_term for part_mod in part_mods]

        # ---- product terms, one ProductTerms module per byte of a --------
        terms = []
        for a_index in range(8):
            prod = ProductTerms(8, 128, 8, a_index, 8)
            setattr(m.submodules, "terms_%d" % a_index, prod)

            m.d.comb += [prod.a.eq(self.a),
                         prod.b.eq(self.b),
                         prod.pb_en.eq(pbs)]

            terms.extend(prod.terms)

        # it's fine to bitwise-or data together since they are never enabled
        # at the same time
        for prefix, sources in (("nat", nat_l),
                                ("nbt", nbt_l),
                                ("nla", nla_l),
                                ("nlb", nlb_l)):
            or_mod = OrMod(128)
            # registered as nat_or, nbt_or, nla_or and nlb_or
            setattr(m.submodules, "%s_or" % prefix, or_mod)
            for or_input, source in zip(or_mod.orin, sources):
                m.d.comb += or_input.eq(source)
            terms.append(or_mod.orout)

        # ---- sum all terms -----------------------------------------------
        add_reduce = AddReduce(terms,
                               128,
                               self.register_levels,
                               expanded_part_pts,
                               self.part_ops)

        # op codes and partition points as seen at the last adder level
        out_part_ops = add_reduce.levels[-1].out_part_ops
        out_part_pts = add_reduce.levels[-1]._reg_partition_points

        m.submodules.add_reduce = add_reduce
        m.d.comb += self._intermediate_output.eq(add_reduce.output)

        # ---- HI/LO selection for each lane width: io64, io32, io16, io8 --
        io_mods = []
        for lane_width, n_lanes in ((64, 1), (32, 2), (16, 4), (8, 8)):
            io_mod = IntermediateOut(lane_width, 128, n_lanes)
            setattr(m.submodules, "io%d" % lane_width, io_mod)
            m.d.comb += io_mod.intermed.eq(self._intermediate_output)
            for idx in range(8):
                m.d.comb += io_mod.part_ops[idx].eq(out_part_ops[idx])
            io_mods.append(io_mod)
        io64, io32, io16, io8 = io_mods

        # ---- Parts modules: p_8, p_16, p_32 and p_64 ---------------------
        sel_mods = []
        for lane_width, part_mod in zip((8, 16, 32, 64), part_mods):
            sel_mod = Parts(8, eps, len(part_mod.parts))
            setattr(m.submodules, "p_%d" % lane_width, sel_mod)
            sel_mods.append(sel_mod)
        p_8, p_16, p_32, p_64 = sel_mods

        for sel_mod in sel_mods:
            m.d.comb += sel_mod.epps.eq(out_part_pts)

        # ---- final output ------------------------------------------------
        m.submodules.finalout = finalout = FinalOut(64)
        # p_64 drives no select flags: i64 is FinalOut's fallback choice
        for flags, sel_mod, part_mod in ((finalout.d8, p_8, part_8),
                                         (finalout.d16, p_16, part_16),
                                         (finalout.d32, p_32, part_32)):
            for idx in range(len(part_mod.parts)):
                m.d.comb += flags[idx].eq(sel_mod.parts[idx])

        m.d.comb += [finalout.i8.eq(io8.output),
                     finalout.i16.eq(io16.output),
                     finalout.i32.eq(io32.output),
                     finalout.i64.eq(io64.output),
                     self.output.eq(finalout.out),
                     ]

        return m
1186
1187
if __name__ == "__main__":
    # Build a multiplier with the default (empty) register_levels and hand
    # it to the nmigen command-line driver together with its ports.
    m = Mul8_16_32_64()
    ports = [m.a, m.b, m._intermediate_output, m.output]
    ports.extend(m.part_ops)
    ports.extend(m.part_pts.values())
    main(m, ports=ports)