move input assignments (chain) out of AddReduceSingle
[ieee754fpu.git] / src / ieee754 / part_mul_add / multiply.py
1 # SPDX-License-Identifier: LGPL-2.1-or-later
2 # See Notices.txt for copyright information
3 """Integer Multiplication."""
4
5 from nmigen import Signal, Module, Value, Elaboratable, Cat, C, Mux, Repl
6 from nmigen.hdl.ast import Assign
7 from abc import ABCMeta, abstractmethod
8 from nmigen.cli import main
9 from functools import reduce
10 from operator import or_
11
12
13 class PartitionPoints(dict):
14 """Partition points and corresponding ``Value``s.
15
16 The points at where an ALU is partitioned along with ``Value``s that
17 specify if the corresponding partition points are enabled.
18
19 For example: ``{1: True, 5: True, 10: True}`` with
20 ``width == 16`` specifies that the ALU is split into 4 sections:
21 * bits 0 <= ``i`` < 1
22 * bits 1 <= ``i`` < 5
23 * bits 5 <= ``i`` < 10
24 * bits 10 <= ``i`` < 16
25
26 If the partition_points were instead ``{1: True, 5: a, 10: True}``
27 where ``a`` is a 1-bit ``Signal``:
28 * If ``a`` is asserted:
29 * bits 0 <= ``i`` < 1
30 * bits 1 <= ``i`` < 5
31 * bits 5 <= ``i`` < 10
32 * bits 10 <= ``i`` < 16
33 * Otherwise
34 * bits 0 <= ``i`` < 1
35 * bits 1 <= ``i`` < 10
36 * bits 10 <= ``i`` < 16
37 """
38
39 def __init__(self, partition_points=None):
40 """Create a new ``PartitionPoints``.
41
42 :param partition_points: the input partition points to values mapping.
43 """
44 super().__init__()
45 if partition_points is not None:
46 for point, enabled in partition_points.items():
47 if not isinstance(point, int):
48 raise TypeError("point must be a non-negative integer")
49 if point < 0:
50 raise ValueError("point must be a non-negative integer")
51 self[point] = Value.wrap(enabled)
52
53 def like(self, name=None, src_loc_at=0, mul=1):
54 """Create a new ``PartitionPoints`` with ``Signal``s for all values.
55
56 :param name: the base name for the new ``Signal``s.
57 :param mul: a multiplication factor on the indices
58 """
59 if name is None:
60 name = Signal(src_loc_at=1+src_loc_at).name # get variable name
61 retval = PartitionPoints()
62 for point, enabled in self.items():
63 point *= mul
64 retval[point] = Signal(enabled.shape(), name=f"{name}_{point}")
65 return retval
66
67 def eq(self, rhs):
68 """Assign ``PartitionPoints`` using ``Signal.eq``."""
69 if set(self.keys()) != set(rhs.keys()):
70 raise ValueError("incompatible point set")
71 for point, enabled in self.items():
72 yield enabled.eq(rhs[point])
73
74 def as_mask(self, width):
75 """Create a bit-mask from `self`.
76
77 Each bit in the returned mask is clear only if the partition point at
78 the same bit-index is enabled.
79
80 :param width: the bit width of the resulting mask
81 """
82 bits = []
83 for i in range(width):
84 if i in self:
85 bits.append(~self[i])
86 else:
87 bits.append(True)
88 return Cat(*bits)
89
90 def get_max_partition_count(self, width):
91 """Get the maximum number of partitions.
92
93 Gets the number of partitions when all partition points are enabled.
94 """
95 retval = 1
96 for point in self.keys():
97 if point < width:
98 retval += 1
99 return retval
100
101 def fits_in_width(self, width):
102 """Check if all partition points are smaller than `width`."""
103 for point in self.keys():
104 if point >= width:
105 return False
106 return True
107
108 def part_byte(self, index, mfactor=1): # mfactor used for "expanding"
109 if index == -1 or index == 7:
110 return C(True, 1)
111 assert index >= 0 and index < 8
112 return self[(index * 8 + 8)*mfactor]
113
114
115 class FullAdder(Elaboratable):
116 """Full Adder.
117
118 :attribute in0: the first input
119 :attribute in1: the second input
120 :attribute in2: the third input
121 :attribute sum: the sum output
122 :attribute carry: the carry output
123
124 Rather than do individual full adders (and have an array of them,
125 which would be very slow to simulate), this module can specify the
126 bit width of the inputs and outputs: in effect it performs multiple
127 Full 3-2 Add operations "in parallel".
128 """
129
130 def __init__(self, width):
131 """Create a ``FullAdder``.
132
133 :param width: the bit width of the input and output
134 """
135 self.in0 = Signal(width)
136 self.in1 = Signal(width)
137 self.in2 = Signal(width)
138 self.sum = Signal(width)
139 self.carry = Signal(width)
140
141 def elaborate(self, platform):
142 """Elaborate this module."""
143 m = Module()
144 m.d.comb += self.sum.eq(self.in0 ^ self.in1 ^ self.in2)
145 m.d.comb += self.carry.eq((self.in0 & self.in1)
146 | (self.in1 & self.in2)
147 | (self.in2 & self.in0))
148 return m
149
150
151 class MaskedFullAdder(Elaboratable):
152 """Masked Full Adder.
153
154 :attribute mask: the carry partition mask
155 :attribute in0: the first input
156 :attribute in1: the second input
157 :attribute in2: the third input
158 :attribute sum: the sum output
159 :attribute mcarry: the masked carry output
160
161 FullAdders are always used with a "mask" on the output. To keep
162 the graphviz "clean", this class performs the masking here rather
163 than inside a large for-loop.
164
165 See the following discussion as to why this is no longer derived
166 from FullAdder. Each carry is shifted here *before* being ANDed
167 with the mask, so that an AOI cell may be used (which is more
168 gate-efficient)
169 https://en.wikipedia.org/wiki/AND-OR-Invert
170 https://groups.google.com/d/msg/comp.arch/fcq-GLQqvas/vTxmcA0QAgAJ
171 """
172
173 def __init__(self, width):
174 """Create a ``MaskedFullAdder``.
175
176 :param width: the bit width of the input and output
177 """
178 self.width = width
179 self.mask = Signal(width, reset_less=True)
180 self.mcarry = Signal(width, reset_less=True)
181 self.in0 = Signal(width, reset_less=True)
182 self.in1 = Signal(width, reset_less=True)
183 self.in2 = Signal(width, reset_less=True)
184 self.sum = Signal(width, reset_less=True)
185
186 def elaborate(self, platform):
187 """Elaborate this module."""
188 m = Module()
189 s1 = Signal(self.width, reset_less=True)
190 s2 = Signal(self.width, reset_less=True)
191 s3 = Signal(self.width, reset_less=True)
192 c1 = Signal(self.width, reset_less=True)
193 c2 = Signal(self.width, reset_less=True)
194 c3 = Signal(self.width, reset_less=True)
195 m.d.comb += self.sum.eq(self.in0 ^ self.in1 ^ self.in2)
196 m.d.comb += s1.eq(Cat(0, self.in0))
197 m.d.comb += s2.eq(Cat(0, self.in1))
198 m.d.comb += s3.eq(Cat(0, self.in2))
199 m.d.comb += c1.eq(s1 & s2 & self.mask)
200 m.d.comb += c2.eq(s2 & s3 & self.mask)
201 m.d.comb += c3.eq(s3 & s1 & self.mask)
202 m.d.comb += self.mcarry.eq(c1 | c2 | c3)
203 return m
204
205
206 class PartitionedAdder(Elaboratable):
207 """Partitioned Adder.
208
209 Performs the final add. The partition points are included in the
210 actual add (in one of the operands only), which causes a carry over
211 to the next bit. Then the final output *removes* the extra bits from
212 the result.
213
214 partition: .... P... P... P... P... (32 bits)
215 a : .... .... .... .... .... (32 bits)
216 b : .... .... .... .... .... (32 bits)
217 exp-a : ....P....P....P....P.... (32+4 bits, P=1 if no partition)
218 exp-b : ....0....0....0....0.... (32 bits plus 4 zeros)
219 exp-o : ....xN...xN...xN...xN... (32+4 bits - x to be discarded)
220 o : .... N... N... N... N... (32 bits - x ignored, N is carry-over)
221
222 :attribute width: the bit width of the input and output. Read-only.
223 :attribute a: the first input to the adder
224 :attribute b: the second input to the adder
225 :attribute output: the sum output
226 :attribute partition_points: the input partition points. Modification not
227 supported, except for by ``Signal.eq``.
228 """
229
230 def __init__(self, width, partition_points):
231 """Create a ``PartitionedAdder``.
232
233 :param width: the bit width of the input and output
234 :param partition_points: the input partition points
235 """
236 self.width = width
237 self.a = Signal(width)
238 self.b = Signal(width)
239 self.output = Signal(width)
240 self.partition_points = PartitionPoints(partition_points)
241 if not self.partition_points.fits_in_width(width):
242 raise ValueError("partition_points doesn't fit in width")
243 expanded_width = 0
244 for i in range(self.width):
245 if i in self.partition_points:
246 expanded_width += 1
247 expanded_width += 1
248 self._expanded_width = expanded_width
249 # XXX these have to remain here due to some horrible nmigen
250 # simulation bugs involving sync. it is *not* necessary to
251 # have them here, they should (under normal circumstances)
252 # be moved into elaborate, as they are entirely local
253 self._expanded_a = Signal(expanded_width) # includes extra part-points
254 self._expanded_b = Signal(expanded_width) # likewise.
255 self._expanded_o = Signal(expanded_width) # likewise.
256
257 def elaborate(self, platform):
258 """Elaborate this module."""
259 m = Module()
260 expanded_index = 0
261 # store bits in a list, use Cat later. graphviz is much cleaner
262 al, bl, ol, ea, eb, eo = [],[],[],[],[],[]
263
264 # partition points are "breaks" (extra zeros or 1s) in what would
265 # otherwise be a massive long add. when the "break" points are 0,
266 # whatever is in it (in the output) is discarded. however when
267 # there is a "1", it causes a roll-over carry to the *next* bit.
268 # we still ignore the "break" bit in the [intermediate] output,
269 # however by that time we've got the effect that we wanted: the
270 # carry has been carried *over* the break point.
271
272 for i in range(self.width):
273 if i in self.partition_points:
274 # add extra bit set to 0 + 0 for enabled partition points
275 # and 1 + 0 for disabled partition points
276 ea.append(self._expanded_a[expanded_index])
277 al.append(~self.partition_points[i]) # add extra bit in a
278 eb.append(self._expanded_b[expanded_index])
279 bl.append(C(0)) # yes, add a zero
280 expanded_index += 1 # skip the extra point. NOT in the output
281 ea.append(self._expanded_a[expanded_index])
282 eb.append(self._expanded_b[expanded_index])
283 eo.append(self._expanded_o[expanded_index])
284 al.append(self.a[i])
285 bl.append(self.b[i])
286 ol.append(self.output[i])
287 expanded_index += 1
288
289 # combine above using Cat
290 m.d.comb += Cat(*ea).eq(Cat(*al))
291 m.d.comb += Cat(*eb).eq(Cat(*bl))
292 m.d.comb += Cat(*ol).eq(Cat(*eo))
293
294 # use only one addition to take advantage of look-ahead carry and
295 # special hardware on FPGAs
296 m.d.comb += self._expanded_o.eq(
297 self._expanded_a + self._expanded_b)
298 return m
299
300
301 FULL_ADDER_INPUT_COUNT = 3
302
303 class AddReduceData:
304
305 def __init__(self, ppoints, n_inputs, output_width, n_parts):
306 self.part_ops = [Signal(2, name=f"part_ops_{i}")
307 for i in range(n_parts)]
308 self.inputs = [Signal(output_width, name=f"inputs[{i}]")
309 for i in range(n_inputs)]
310 self.reg_partition_points = ppoints.like()
311
312 def eq(self, rhs):
313 return [self.reg_partition_points.eq(rhs.reg_partition_points)] + \
314 [self.inputs[i].eq(rhs.inputs[i])
315 for i in range(len(self.inputs))] + \
316 [self.part_ops[i].eq(rhs.part_ops[i])
317 for i in range(len(self.part_ops))]
318
319
320 class FinalAdd(Elaboratable):
321 """ Final stage of add reduce
322 """
323
324 def __init__(self, inputs, output_width, register_levels, partition_points,
325 part_ops):
326 self.part_ops = part_ops
327 self.out_part_ops = [Signal(2, name=f"out_part_ops_{i}")
328 for i in range(len(part_ops))]
329 self.inputs = list(inputs)
330 self._resized_inputs = [
331 Signal(output_width, name=f"resized_inputs[{i}]")
332 for i in range(len(self.inputs))]
333 self.register_levels = list(register_levels)
334 self.output = Signal(output_width)
335 self.partition_points = PartitionPoints(partition_points)
336 if not self.partition_points.fits_in_width(output_width):
337 raise ValueError("partition_points doesn't fit in output_width")
338 self._reg_partition_points = self.partition_points.like()
339
340 def elaborate(self, platform):
341 """Elaborate this module."""
342 m = Module()
343
344 if len(self.inputs) == 0:
345 # use 0 as the default output value
346 m.d.comb += self.output.eq(0)
347 elif len(self.inputs) == 1:
348 # handle single input
349 m.d.comb += self.output.eq(self._resized_inputs[0])
350 else:
351 # base case for adding 2 inputs
352 assert len(self.inputs) == 2
353 adder = PartitionedAdder(len(self.output),
354 self._reg_partition_points)
355 m.submodules.final_adder = adder
356 m.d.comb += adder.a.eq(self._resized_inputs[0])
357 m.d.comb += adder.b.eq(self._resized_inputs[1])
358 m.d.comb += self.output.eq(adder.output)
359 return m
360
361
362 class AddReduceSingle(Elaboratable):
363 """Add list of numbers together.
364
365 :attribute inputs: input ``Signal``s to be summed. Modification not
366 supported, except for by ``Signal.eq``.
367 :attribute register_levels: List of nesting levels that should have
368 pipeline registers.
369 :attribute intermediate_terms: intermediate partial-sum terms.
370 :attribute partition_points: the input partition points. Modification not
371 supported, except for by ``Signal.eq``.
372 """
373
374 def __init__(self, inputs, output_width, register_levels, partition_points,
375 part_ops):
376 """Create an ``AddReduce``.
377
378 :param inputs: input ``Signal``s to be summed.
379 :param output_width: bit-width of ``output``.
380 :param register_levels: List of nesting levels that should have
381 pipeline registers.
382 :param partition_points: the input partition points.
383 """
384 self.output_width = output_width
385 self.part_ops = part_ops
386 self.out_part_ops = [Signal(2, name=f"out_part_ops_{i}")
387 for i in range(len(part_ops))]
388 self.inputs = list(inputs)
389 self._resized_inputs = [
390 Signal(output_width, name=f"resized_inputs[{i}]")
391 for i in range(len(self.inputs))]
392 self.register_levels = list(register_levels)
393 self.partition_points = PartitionPoints(partition_points)
394 if not self.partition_points.fits_in_width(output_width):
395 raise ValueError("partition_points doesn't fit in output_width")
396 self._reg_partition_points = self.partition_points.like()
397
398 max_level = AddReduceSingle.get_max_level(len(self.inputs))
399 for level in self.register_levels:
400 if level > max_level:
401 raise ValueError(
402 "not enough adder levels for specified register levels")
403
404 # this is annoying. we have to create the modules (and terms)
405 # because we need to know what they are (in order to set up the
406 # interconnects back in AddReduce), but cannot do the m.d.comb +=
407 # etc. because this is not inside elaborate().
408 self.groups = AddReduceSingle.full_adder_groups(len(self.inputs))
409 self._intermediate_terms = []
410 # note: create_next_terms() also copes with fewer than 3 inputs
411 self.create_next_terms()
412
413 @staticmethod
414 def get_max_level(input_count):
415 """Get the maximum level.
416
417 All ``register_levels`` must be less than or equal to the maximum
418 level.
419 """
420 retval = 0
421 while True:
422 groups = AddReduceSingle.full_adder_groups(input_count)
423 if len(groups) == 0:
424 return retval
425 input_count %= FULL_ADDER_INPUT_COUNT
426 input_count += 2 * len(groups)
427 retval += 1
428
429 @staticmethod
430 def full_adder_groups(input_count):
431 """Get ``inputs`` indices for which a full adder should be built."""
432 return range(0,
433 input_count - FULL_ADDER_INPUT_COUNT + 1,
434 FULL_ADDER_INPUT_COUNT)
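    # Illustrative reduction trace (comment only): with 9 input terms the
    # per-level term counts are 9 -> 6 -> 4 -> 3 -> 2, so get_max_level(9)
    # returns 4; the final 2 terms are then summed by FinalAdd.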
435
436 def elaborate(self, platform):
437 """Elaborate this module."""
438 m = Module()
439
440 for (value, term) in self._intermediate_terms:
441 m.d.comb += term.eq(value)
442
443 mask = self._reg_partition_points.as_mask(self.output_width)
444 m.d.comb += self.part_mask.eq(mask)
445
446 # add and link the intermediate term modules
447 for i, (iidx, adder_i) in enumerate(self.adders):
448 setattr(m.submodules, f"adder_{i}", adder_i)
449
450 m.d.comb += adder_i.in0.eq(self._resized_inputs[iidx])
451 m.d.comb += adder_i.in1.eq(self._resized_inputs[iidx + 1])
452 m.d.comb += adder_i.in2.eq(self._resized_inputs[iidx + 2])
453 m.d.comb += adder_i.mask.eq(self.part_mask)
454
455 return m
456
457 def create_next_terms(self):
458
459 # go on to prepare recursive case
460 intermediate_terms = []
461 _intermediate_terms = []
462
463 def add_intermediate_term(value):
464 intermediate_term = Signal(
465 self.output_width,
466 name=f"intermediate_terms[{len(intermediate_terms)}]")
467 _intermediate_terms.append((value, intermediate_term))
468 intermediate_terms.append(intermediate_term)
469
470 # store mask in intermediary (simplifies graph)
471 self.part_mask = Signal(self.output_width, reset_less=True)
472
473 # create full adders for this recursive level.
474 # this shrinks N terms to 2 * (N // 3) plus the remainder
475 self.adders = []
476 for i in self.groups:
477 adder_i = MaskedFullAdder(self.output_width)
478 self.adders.append((i, adder_i))
479 # add both the sum and the masked-carry to the next level.
480 # 3 inputs have now been reduced to 2...
481 add_intermediate_term(adder_i.sum)
482 add_intermediate_term(adder_i.mcarry)
483 # handle the remaining inputs.
484 if len(self.inputs) % FULL_ADDER_INPUT_COUNT == 1:
485 add_intermediate_term(self._resized_inputs[-1])
486 elif len(self.inputs) % FULL_ADDER_INPUT_COUNT == 2:
487 # Just pass the terms to the next layer: a half adder would not
488 # gain anything, as there would still be 2 terms, and passing
489 # them straight through saves gates.
490 add_intermediate_term(self._resized_inputs[-2])
491 add_intermediate_term(self._resized_inputs[-1])
492 else:
493 assert len(self.inputs) % FULL_ADDER_INPUT_COUNT == 0
494
495 self.intermediate_terms = intermediate_terms
496 self._intermediate_terms = _intermediate_terms
497
498
499 class AddReduce(Elaboratable):
500 """Recursively Add list of numbers together.
501
502 :attribute inputs: input ``Signal``s to be summed. Modification not
503 supported, except for by ``Signal.eq``.
504 :attribute register_levels: List of nesting levels that should have
505 pipeline registers.
506 :attribute output: output sum.
507 :attribute partition_points: the input partition points. Modification not
508 supported, except for by ``Signal.eq``.
509 """
510
511 def __init__(self, inputs, output_width, register_levels, partition_points,
512 part_ops):
513 """Create an ``AddReduce``.
514
515 :param inputs: input ``Signal``s to be summed.
516 :param output_width: bit-width of ``output``.
517 :param register_levels: List of nesting levels that should have
518 pipeline registers.
519 :param partition_points: the input partition points.
520 """
521 self.inputs = inputs
522 self.part_ops = part_ops
523 self.out_part_ops = [Signal(2, name=f"out_part_ops_{i}")
524 for i in range(len(part_ops))]
525 self.output = Signal(output_width)
526 self.output_width = output_width
527 self.register_levels = register_levels
528 self.partition_points = partition_points
529
530 self.create_levels()
531
532 @staticmethod
533 def get_max_level(input_count):
534 return AddReduceSingle.get_max_level(input_count)
535
536 @staticmethod
537 def next_register_levels(register_levels):
538 """``Iterable`` of ``register_levels`` for next recursive level."""
539 for level in register_levels:
540 if level > 0:
541 yield level - 1
542
543 def create_levels(self):
544 """creates reduction levels"""
545
546 mods = []
547 next_levels = self.register_levels
548 partition_points = self.partition_points
549 inputs = self.inputs
550 part_ops = self.part_ops
551 while True:
552 next_level = AddReduceSingle(inputs, self.output_width, next_levels,
553 partition_points, part_ops)
554 mods.append(next_level)
555 next_levels = list(AddReduce.next_register_levels(next_levels))
556 partition_points = next_level._reg_partition_points
557 inputs = next_level.intermediate_terms
558 part_ops = next_level.out_part_ops
559 groups = AddReduceSingle.full_adder_groups(len(inputs))
560 if len(groups) == 0:
561 break
562
563 next_level = FinalAdd(inputs, self.output_width, next_levels,
564 partition_points, part_ops)
565 mods.append(next_level)
566
567 self.levels = mods
568
569 def elaborate(self, platform):
570 """Elaborate this module."""
571 m = Module()
572
573 for i, next_level in enumerate(self.levels):
574 setattr(m.submodules, "next_level%d" % i, next_level)
575
576 for i in range(len(self.levels)):
577 mcur = self.levels[i]
578 #mnext = self.levels[i+1]
579 inassign = [mcur._resized_inputs[i].eq(mcur.inputs[i])
580 for i in range(len(mcur.inputs))]
581 copy_part_ops = [mcur.out_part_ops[i].eq(mcur.part_ops[i])
582 for i in range(len(mcur.part_ops))]
583 if 0 in mcur.register_levels:
584 m.d.sync += copy_part_ops
585 m.d.sync += inassign
586 m.d.sync += mcur._reg_partition_points.eq(mcur.partition_points)
587 else:
588 m.d.comb += copy_part_ops
589 m.d.comb += inassign
590 m.d.comb += mcur._reg_partition_points.eq(mcur.partition_points)
591
592 # output comes from last module
593 m.d.comb += self.output.eq(next_level.output)
594 copy_part_ops = [self.out_part_ops[i].eq(next_level.out_part_ops[i])
595 for i in range(len(self.part_ops))]
596 m.d.comb += copy_part_ops
597
598 return m
599
600
601 OP_MUL_LOW = 0
602 OP_MUL_SIGNED_HIGH = 1
603 OP_MUL_SIGNED_UNSIGNED_HIGH = 2 # a is signed, b is unsigned
604 OP_MUL_UNSIGNED_HIGH = 3
605
606
607 def get_term(value, shift=0, enabled=None):
608 if enabled is not None:
609 value = Mux(enabled, value, 0)
610 if shift > 0:
611 value = Cat(Repl(C(0, 1), shift), value)
612 else:
613 assert shift == 0
614 return value
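# get_term() sketch (comment only): get_term(x, shift=16, enabled=en) is
# Mux(en, x, 0) with 16 zero bits concatenated below it, i.e. the enabled
# value shifted left by 16 without the cost of a variable shifter.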
615
616
617 class ProductTerm(Elaboratable):
618 """ this class creates a single product term (a[..]*b[..]).
619 it has a design flaw in that it is the *output* that is selected,
620 while the multiplication(s) are combinatorially generated
621 all the time.
622 """
623
624 def __init__(self, width, twidth, pbwid, a_index, b_index):
625 self.a_index = a_index
626 self.b_index = b_index
627 shift = 8 * (self.a_index + self.b_index)
628 self.pwidth = width
629 self.twidth = twidth
630 self.width = width*2
631 self.shift = shift
632
633 self.ti = Signal(self.width, reset_less=True)
634 self.term = Signal(twidth, reset_less=True)
635 self.a = Signal(twidth//2, reset_less=True)
636 self.b = Signal(twidth//2, reset_less=True)
637 self.pb_en = Signal(pbwid, reset_less=True)
638
639 self.tl = tl = []
640 min_index = min(self.a_index, self.b_index)
641 max_index = max(self.a_index, self.b_index)
642 for i in range(min_index, max_index):
643 tl.append(self.pb_en[i])
644 name = "te_%d_%d" % (self.a_index, self.b_index)
645 if len(tl) > 0:
646 term_enabled = Signal(name=name, reset_less=True)
647 else:
648 term_enabled = None
649 self.enabled = term_enabled
650 self.term.name = "term_%d_%d" % (a_index, b_index) # rename
651
652 def elaborate(self, platform):
653
654 m = Module()
655 if self.enabled is not None:
656 m.d.comb += self.enabled.eq(~(Cat(*self.tl).bool()))
657
658 bsa = Signal(self.width, reset_less=True)
659 bsb = Signal(self.width, reset_less=True)
660 a_index, b_index = self.a_index, self.b_index
661 pwidth = self.pwidth
662 m.d.comb += bsa.eq(self.a.part(a_index * pwidth, pwidth))
663 m.d.comb += bsb.eq(self.b.part(b_index * pwidth, pwidth))
664 m.d.comb += self.ti.eq(bsa * bsb)
665 m.d.comb += self.term.eq(get_term(self.ti, self.shift, self.enabled))
666 """
667 #TODO: sort out width issues, get inputs a/b switched on/off.
668 #data going into Muxes is 1/2 the required width
669
670 pwidth = self.pwidth
671 width = self.width
672 bsa = Signal(self.twidth//2, reset_less=True)
673 bsb = Signal(self.twidth//2, reset_less=True)
674 asel = Signal(width, reset_less=True)
675 bsel = Signal(width, reset_less=True)
676 a_index, b_index = self.a_index, self.b_index
677 m.d.comb += asel.eq(self.a.part(a_index * pwidth, pwidth))
678 m.d.comb += bsel.eq(self.b.part(b_index * pwidth, pwidth))
679 m.d.comb += bsa.eq(get_term(asel, self.shift, self.enabled))
680 m.d.comb += bsb.eq(get_term(bsel, self.shift, self.enabled))
681 m.d.comb += self.ti.eq(bsa * bsb)
682 m.d.comb += self.term.eq(self.ti)
683 """
684
685 return m
686
687
688 class ProductTerms(Elaboratable):
689 """ creates a bank of product terms. also performs the actual bit-selection
690 this class is to be wrapped with a for-loop on the "a" operand.
691 it creates a second-level for-loop on the "b" operand.
692 """
693 def __init__(self, width, twidth, pbwid, a_index, blen):
694 self.a_index = a_index
695 self.blen = blen
696 self.pwidth = width
697 self.twidth = twidth
698 self.pbwid = pbwid
699 self.a = Signal(twidth//2, reset_less=True)
700 self.b = Signal(twidth//2, reset_less=True)
701 self.pb_en = Signal(pbwid, reset_less=True)
702 self.terms = [Signal(twidth, name="term%d"%i, reset_less=True) \
703 for i in range(blen)]
704
705 def elaborate(self, platform):
706
707 m = Module()
708
709 for b_index in range(self.blen):
710 t = ProductTerm(self.pwidth, self.twidth, self.pbwid,
711 self.a_index, b_index)
712 setattr(m.submodules, "term_%d" % b_index, t)
713
714 m.d.comb += t.a.eq(self.a)
715 m.d.comb += t.b.eq(self.b)
716 m.d.comb += t.pb_en.eq(self.pb_en)
717
718 m.d.comb += self.terms[b_index].eq(t.term)
719
720 return m
721
722
723 class LSBNegTerm(Elaboratable):
724
725 def __init__(self, bit_width):
726 self.bit_width = bit_width
727 self.part = Signal(reset_less=True)
728 self.signed = Signal(reset_less=True)
729 self.op = Signal(bit_width, reset_less=True)
730 self.msb = Signal(reset_less=True)
731 self.nt = Signal(bit_width*2, reset_less=True)
732 self.nl = Signal(bit_width*2, reset_less=True)
733
734 def elaborate(self, platform):
735 m = Module()
736 comb = m.d.comb
737 bit_wid = self.bit_width
738 ext = Repl(0, bit_wid) # extend output to HI part
739
740 # determine sign of each incoming number *in this partition*
741 enabled = Signal(reset_less=True)
742 m.d.comb += enabled.eq(self.part & self.msb & self.signed)
743
744 # for 8-bit values: form a * 0xFF00 by using -a * 0x100, the
745 # negation operation is split into a bitwise not and a +1.
746 # likewise for 16, 32, and 64-bit values.
747
748 # width-extended 1s complement if a is signed, otherwise zero
749 comb += self.nt.eq(Mux(enabled, Cat(ext, ~self.op), 0))
750
751 # add 1 if signed, otherwise add zero
752 comb += self.nl.eq(Cat(ext, enabled, Repl(0, bit_wid-1)))
753
754 return m
755
756
757 class Parts(Elaboratable):
758
759 def __init__(self, pbwid, epps, n_parts):
760 self.pbwid = pbwid
761 # inputs
762 self.epps = PartitionPoints.like(epps, name="epps") # expanded points
763 # outputs
764 self.parts = [Signal(name=f"part_{i}") for i in range(n_parts)]
765
766 def elaborate(self, platform):
767 m = Module()
768
769 epps, parts = self.epps, self.parts
770 # collect part-bytes (double factor because the input is extended)
771 pbs = Signal(self.pbwid, reset_less=True)
772 tl = []
773 for i in range(self.pbwid):
774 pb = Signal(name="pb%d" % i, reset_less=True)
775 m.d.comb += pb.eq(epps.part_byte(i, mfactor=2)) # double
776 tl.append(pb)
777 m.d.comb += pbs.eq(Cat(*tl))
778
779 # negated-temporary copy of partition bits
780 npbs = Signal.like(pbs, reset_less=True)
781 m.d.comb += npbs.eq(~pbs)
782 byte_count = 8 // len(parts)
783 for i in range(len(parts)):
784 pbl = []
785 pbl.append(npbs[i * byte_count - 1])
786 for j in range(i * byte_count, (i + 1) * byte_count - 1):
787 pbl.append(pbs[j])
788 pbl.append(npbs[(i + 1) * byte_count - 1])
789 value = Signal(len(pbl), name="value_%d" % i, reset_less=True)
790 m.d.comb += value.eq(Cat(*pbl))
791 m.d.comb += parts[i].eq(~(value).bool())
792
793 return m
794
795
796 class Part(Elaboratable):
797 """ a key class which, depending on the partitioning, will determine
798 what action to take when parts of the output are signed or unsigned.
799
800 this requires 2 pieces of data *per operand, per partition*:
801 whether the MSB is HI/LO (per partition!), and whether a signed
802 or unsigned operation has been *requested*.
803
804 once that is determined, signed multiplication is basically carried out
805 by splitting 2's complement into 1's complement plus one.
806 1's complement is just a bit-inversion.
807
808 the extra terms - as separate terms - are then thrown at the
809 AddReduce alongside the multiplication part-results.
810 """
811 def __init__(self, epps, width, n_parts, n_levels, pbwid):
812
813 self.pbwid = pbwid
814 self.epps = epps
815
816 # inputs
817 self.a = Signal(64)
818 self.b = Signal(64)
819 self.a_signed = [Signal(name=f"a_signed_{i}") for i in range(8)]
820 self.b_signed = [Signal(name=f"_b_signed_{i}") for i in range(8)]
821 self.pbs = Signal(pbwid, reset_less=True)
822
823 # outputs
824 self.parts = [Signal(name=f"part_{i}") for i in range(n_parts)]
825
826 self.not_a_term = Signal(width)
827 self.neg_lsb_a_term = Signal(width)
828 self.not_b_term = Signal(width)
829 self.neg_lsb_b_term = Signal(width)
830
831 def elaborate(self, platform):
832 m = Module()
833
834 pbs, parts = self.pbs, self.parts
835 epps = self.epps
836 m.submodules.p = p = Parts(self.pbwid, epps, len(parts))
837 m.d.comb += p.epps.eq(epps)
838 parts = p.parts
839
840 byte_count = 8 // len(parts)
841
842 not_a_term, neg_lsb_a_term, not_b_term, neg_lsb_b_term = (
843 self.not_a_term, self.neg_lsb_a_term,
844 self.not_b_term, self.neg_lsb_b_term)
845
846 byte_width = 8 // len(parts) # byte width
847 bit_wid = 8 * byte_width # bit width
848 nat, nbt, nla, nlb = [], [], [], []
849 for i in range(len(parts)):
850 # work out bit-inverted and +1 term for a.
851 pa = LSBNegTerm(bit_wid)
852 setattr(m.submodules, "lnt_%d_a_%d" % (bit_wid, i), pa)
853 m.d.comb += pa.part.eq(parts[i])
854 m.d.comb += pa.op.eq(self.a.part(bit_wid * i, bit_wid))
855 m.d.comb += pa.signed.eq(self.b_signed[i * byte_width]) # yes b
856 m.d.comb += pa.msb.eq(self.b[(i + 1) * bit_wid - 1]) # really, b
857 nat.append(pa.nt)
858 nla.append(pa.nl)
859
860 # work out bit-inverted and +1 term for b
861 pb = LSBNegTerm(bit_wid)
862 setattr(m.submodules, "lnt_%d_b_%d" % (bit_wid, i), pb)
863 m.d.comb += pb.part.eq(parts[i])
864 m.d.comb += pb.op.eq(self.b.part(bit_wid * i, bit_wid))
865 m.d.comb += pb.signed.eq(self.a_signed[i * byte_width]) # yes a
866 m.d.comb += pb.msb.eq(self.a[(i + 1) * bit_wid - 1]) # really, a
867 nbt.append(pb.nt)
868 nlb.append(pb.nl)
869
870 # concatenate together and return all 4 results.
871 m.d.comb += [not_a_term.eq(Cat(*nat)),
872 not_b_term.eq(Cat(*nbt)),
873 neg_lsb_a_term.eq(Cat(*nla)),
874 neg_lsb_b_term.eq(Cat(*nlb)),
875 ]
876
877 return m
878
879
880 class IntermediateOut(Elaboratable):
881 """ selects the HI/LO part of the multiplication, for a given bit-width
882 the output is also reconstructed in its SIMD (partition) lanes.
883 """
884 def __init__(self, width, out_wid, n_parts):
885 self.width = width
886 self.n_parts = n_parts
887 self.part_ops = [Signal(2, name="dpop%d" % i, reset_less=True)
888 for i in range(8)]
889 self.intermed = Signal(out_wid, reset_less=True)
890 self.output = Signal(out_wid//2, reset_less=True)
891
892 def elaborate(self, platform):
893 m = Module()
894
895 ol = []
896 w = self.width
897 sel = w // 8
898 for i in range(self.n_parts):
899 op = Signal(w, reset_less=True, name="op%d_%d" % (w, i))
900 m.d.comb += op.eq(
901 Mux(self.part_ops[sel * i] == OP_MUL_LOW,
902 self.intermed.part(i * w*2, w),
903 self.intermed.part(i * w*2 + w, w)))
904 ol.append(op)
905 m.d.comb += self.output.eq(Cat(*ol))
906
907 return m
908
909
910 class FinalOut(Elaboratable):
911 """ selects the final output based on the partitioning.
912
913 each byte is selectable independently, i.e. it is possible
914 that some partitions requested 8-bit computation whilst others
915 requested 16 or 32 bit.
916 """
917 def __init__(self, out_wid):
918 # inputs
919 self.d8 = [Signal(name=f"d8_{i}", reset_less=True) for i in range(8)]
920 self.d16 = [Signal(name=f"d16_{i}", reset_less=True) for i in range(4)]
921 self.d32 = [Signal(name=f"d32_{i}", reset_less=True) for i in range(2)]
922
923 self.i8 = Signal(out_wid, reset_less=True)
924 self.i16 = Signal(out_wid, reset_less=True)
925 self.i32 = Signal(out_wid, reset_less=True)
926 self.i64 = Signal(out_wid, reset_less=True)
927
928 # output
929 self.out = Signal(out_wid, reset_less=True)
930
931 def elaborate(self, platform):
932 m = Module()
933 ol = []
934 for i in range(8):
935 # select one of the outputs: d8 selects i8, d16 selects i16
936 # d32 selects i32, and the default is i64.
937 # d8 and d16 are ORed together in the first Mux
938 # then the 2nd selects either i8 or i16.
939 # if neither d8 nor d16 are set, d32 selects either i32 or i64.
940 op = Signal(8, reset_less=True, name="op_%d" % i)
941 m.d.comb += op.eq(
942 Mux(self.d8[i] | self.d16[i // 2],
943 Mux(self.d8[i], self.i8.part(i * 8, 8),
944 self.i16.part(i * 8, 8)),
945 Mux(self.d32[i // 4], self.i32.part(i * 8, 8),
946 self.i64.part(i * 8, 8))))
947 ol.append(op)
948 m.d.comb += self.out.eq(Cat(*ol))
949 return m
950
951
952 class OrMod(Elaboratable):
953 """ ORs four values together in a hierarchical tree
954 """
955 def __init__(self, wid):
956 self.wid = wid
957 self.orin = [Signal(wid, name="orin%d" % i, reset_less=True)
958 for i in range(4)]
959 self.orout = Signal(wid, reset_less=True)
960
961 def elaborate(self, platform):
962 m = Module()
963 or1 = Signal(self.wid, reset_less=True)
964 or2 = Signal(self.wid, reset_less=True)
965 m.d.comb += or1.eq(self.orin[0] | self.orin[1])
966 m.d.comb += or2.eq(self.orin[2] | self.orin[3])
967 m.d.comb += self.orout.eq(or1 | or2)
968
969 return m
970
971
972 class Signs(Elaboratable):
973 """ determines whether a or b are signed numbers
974 based on the required operation type (OP_MUL_*)
975 """
976
977 def __init__(self):
978 self.part_ops = Signal(2, reset_less=True)
979 self.a_signed = Signal(reset_less=True)
980 self.b_signed = Signal(reset_less=True)
981
982 def elaborate(self, platform):
983
984 m = Module()
985
986 asig = self.part_ops != OP_MUL_UNSIGNED_HIGH
987 bsig = (self.part_ops == OP_MUL_LOW) \
988 | (self.part_ops == OP_MUL_SIGNED_HIGH)
989 m.d.comb += self.a_signed.eq(asig)
990 m.d.comb += self.b_signed.eq(bsig)
991
992 return m
993
994
995 class Mul8_16_32_64(Elaboratable):
996 """Signed/Unsigned 8/16/32/64-bit partitioned integer multiplier.
997
998 Supports partitioning into any combination of 8, 16, 32, and 64-bit
999 partitions on naturally-aligned boundaries. Supports the operation being
1000 set for each partition independently.
1001
1002 :attribute part_pts: the input partition points. Has a partition point at
1003 multiples of 8 in 0 < i < 64. Each partition point's associated
1004 ``Value`` is a ``Signal``. Modification not supported, except for by
1005 ``Signal.eq``.
1006 :attribute part_ops: the operation for each byte. The operation for a
1007 particular partition is selected by assigning the selected operation
1008 code to each byte in the partition. The allowed operation codes are:
1009
1010 :attribute OP_MUL_LOW: the LSB half of the product. Equivalent to
1011 RISC-V's `mul` instruction.
1012 :attribute OP_MUL_SIGNED_HIGH: the MSB half of the product where both
1013 ``a`` and ``b`` are signed. Equivalent to RISC-V's `mulh`
1014 instruction.
1015 :attribute OP_MUL_SIGNED_UNSIGNED_HIGH: the MSB half of the product
1016 where ``a`` is signed and ``b`` is unsigned. Equivalent to RISC-V's
1017 `mulhsu` instruction.
1018 :attribute OP_MUL_UNSIGNED_HIGH: the MSB half of the product where both
1019 ``a`` and ``b`` are unsigned. Equivalent to RISC-V's `mulhu`
1020 instruction.
1021 """
1022
1023 def __init__(self, register_levels=()):
1024 """ register_levels: specifies the points in the cascade at which
1025 flip-flops are to be inserted.
1026 """
1027
1028 # parameter(s)
1029 self.register_levels = list(register_levels)
1030
1031 # inputs
1032 self.part_pts = PartitionPoints()
1033 for i in range(8, 64, 8):
1034 self.part_pts[i] = Signal(name=f"part_pts_{i}")
1035 self.part_ops = [Signal(2, name=f"part_ops_{i}") for i in range(8)]
1036 self.a = Signal(64)
1037 self.b = Signal(64)
1038
1039 # intermediates (needed for unit tests)
1040 self._intermediate_output = Signal(128)
1041
1042 # output
1043 self.output = Signal(64)
1044
1045 def elaborate(self, platform):
1046 m = Module()
1047
1048 # collect part-bytes
1049 pbs = Signal(8, reset_less=True)
1050 tl = []
1051 for i in range(8):
1052 pb = Signal(name="pb%d" % i, reset_less=True)
1053 m.d.comb += pb.eq(self.part_pts.part_byte(i))
1054 tl.append(pb)
1055 m.d.comb += pbs.eq(Cat(*tl))
1056
1057 # create (doubled) PartitionPoints (output is double input width)
1058 expanded_part_pts = eps = PartitionPoints()
1059 for i, v in self.part_pts.items():
1060 ep = Signal(name=f"expanded_part_pts_{i*2}", reset_less=True)
1061 expanded_part_pts[i * 2] = ep
1062 m.d.comb += ep.eq(v)
1063
1064 # local variables
1065 signs = []
1066 for i in range(8):
1067 s = Signs()
1068 signs.append(s)
1069 setattr(m.submodules, "signs%d" % i, s)
1070 m.d.comb += s.part_ops.eq(self.part_ops[i])
1071
1072 n_levels = len(self.register_levels)+1
1073 m.submodules.part_8 = part_8 = Part(eps, 128, 8, n_levels, 8)
1074 m.submodules.part_16 = part_16 = Part(eps, 128, 4, n_levels, 8)
1075 m.submodules.part_32 = part_32 = Part(eps, 128, 2, n_levels, 8)
1076 m.submodules.part_64 = part_64 = Part(eps, 128, 1, n_levels, 8)
1077 nat_l, nbt_l, nla_l, nlb_l = [], [], [], []
1078 for mod in [part_8, part_16, part_32, part_64]:
1079 m.d.comb += mod.a.eq(self.a)
1080 m.d.comb += mod.b.eq(self.b)
1081 for i in range(len(signs)):
1082 m.d.comb += mod.a_signed[i].eq(signs[i].a_signed)
1083 m.d.comb += mod.b_signed[i].eq(signs[i].b_signed)
1084 m.d.comb += mod.pbs.eq(pbs)
1085 nat_l.append(mod.not_a_term)
1086 nbt_l.append(mod.not_b_term)
1087 nla_l.append(mod.neg_lsb_a_term)
1088 nlb_l.append(mod.neg_lsb_b_term)
1089
1090 terms = []
1091
1092 for a_index in range(8):
1093 t = ProductTerms(8, 128, 8, a_index, 8)
1094 setattr(m.submodules, "terms_%d" % a_index, t)
1095
1096 m.d.comb += t.a.eq(self.a)
1097 m.d.comb += t.b.eq(self.b)
1098 m.d.comb += t.pb_en.eq(pbs)
1099
1100 for term in t.terms:
1101 terms.append(term)
1102
1103 # it's fine to bitwise-or data together since they are never enabled
1104 # at the same time
1105 m.submodules.nat_or = nat_or = OrMod(128)
1106 m.submodules.nbt_or = nbt_or = OrMod(128)
1107 m.submodules.nla_or = nla_or = OrMod(128)
1108 m.submodules.nlb_or = nlb_or = OrMod(128)
1109 for l, mod in [(nat_l, nat_or),
1110 (nbt_l, nbt_or),
1111 (nla_l, nla_or),
1112 (nlb_l, nlb_or)]:
1113 for i in range(len(l)):
1114 m.d.comb += mod.orin[i].eq(l[i])
1115 terms.append(mod.orout)
1116
1117 add_reduce = AddReduce(terms,
1118 128,
1119 self.register_levels,
1120 expanded_part_pts,
1121 self.part_ops)
1122
1123 out_part_ops = add_reduce.levels[-1].out_part_ops
1124 out_part_pts = add_reduce.levels[-1]._reg_partition_points
1125
1126 m.submodules.add_reduce = add_reduce
1127 m.d.comb += self._intermediate_output.eq(add_reduce.output)
1128 # create _output_64
1129 m.submodules.io64 = io64 = IntermediateOut(64, 128, 1)
1130 m.d.comb += io64.intermed.eq(self._intermediate_output)
1131 for i in range(8):
1132 m.d.comb += io64.part_ops[i].eq(out_part_ops[i])
1133
1134 # create _output_32
1135 m.submodules.io32 = io32 = IntermediateOut(32, 128, 2)
1136 m.d.comb += io32.intermed.eq(self._intermediate_output)
1137 for i in range(8):
1138 m.d.comb += io32.part_ops[i].eq(out_part_ops[i])
1139
1140 # create _output_16
1141 m.submodules.io16 = io16 = IntermediateOut(16, 128, 4)
1142 m.d.comb += io16.intermed.eq(self._intermediate_output)
1143 for i in range(8):
1144 m.d.comb += io16.part_ops[i].eq(out_part_ops[i])
1145
1146 # create _output_8
1147 m.submodules.io8 = io8 = IntermediateOut(8, 128, 8)
1148 m.d.comb += io8.intermed.eq(self._intermediate_output)
1149 for i in range(8):
1150 m.d.comb += io8.part_ops[i].eq(out_part_ops[i])
1151
1152 m.submodules.p_8 = p_8 = Parts(8, eps, len(part_8.parts))
1153 m.submodules.p_16 = p_16 = Parts(8, eps, len(part_16.parts))
1154 m.submodules.p_32 = p_32 = Parts(8, eps, len(part_32.parts))
1155 m.submodules.p_64 = p_64 = Parts(8, eps, len(part_64.parts))
1156
1157 m.d.comb += p_8.epps.eq(out_part_pts)
1158 m.d.comb += p_16.epps.eq(out_part_pts)
1159 m.d.comb += p_32.epps.eq(out_part_pts)
1160 m.d.comb += p_64.epps.eq(out_part_pts)
1161
1162 # final output
1163 m.submodules.finalout = finalout = FinalOut(64)
1164 for i in range(len(part_8.parts)):
1165 m.d.comb += finalout.d8[i].eq(p_8.parts[i])
1166 for i in range(len(part_16.parts)):
1167 m.d.comb += finalout.d16[i].eq(p_16.parts[i])
1168 for i in range(len(part_32.parts)):
1169 m.d.comb += finalout.d32[i].eq(p_32.parts[i])
1170 m.d.comb += finalout.i8.eq(io8.output)
1171 m.d.comb += finalout.i16.eq(io16.output)
1172 m.d.comb += finalout.i32.eq(io32.output)
1173 m.d.comb += finalout.i64.eq(io64.output)
1174 m.d.comb += self.output.eq(finalout.out)
1175
1176 return m
1177
1178
1179 if __name__ == "__main__":
1180 m = Mul8_16_32_64()
1181 main(m, ports=[m.a,
1182 m.b,
1183 m._intermediate_output,
1184 m.output,
1185 *m.part_ops,
1186 *m.part_pts.values()])