move final adder to separate module
[ieee754fpu.git] / src / ieee754 / part_mul_add / multiply.py
1 # SPDX-License-Identifier: LGPL-2.1-or-later
2 # See Notices.txt for copyright information
3 """Integer Multiplication."""
4
5 from nmigen import Signal, Module, Value, Elaboratable, Cat, C, Mux, Repl
6 from nmigen.hdl.ast import Assign
7 from abc import ABCMeta, abstractmethod
8 from nmigen.cli import main
9 from functools import reduce
10 from operator import or_
11
12
class PartitionPoints(dict):
    """Mapping from partition-point bit index to its enable ``Value``.

    Each key is a bit position at which an ALU may be split; the value
    stored under it tells whether that split is currently active.

    For example: ``{1: True, 5: True, 10: True}`` with
    ``width == 16`` specifies that the ALU is split into 4 sections:
    * bits 0 <= ``i`` < 1
    * bits 1 <= ``i`` < 5
    * bits 5 <= ``i`` < 10
    * bits 10 <= ``i`` < 16

    If the partition_points were instead ``{1: True, 5: a, 10: True}``
    where ``a`` is a 1-bit ``Signal``:
    * If ``a`` is asserted:
        * bits 0 <= ``i`` < 1
        * bits 1 <= ``i`` < 5
        * bits 5 <= ``i`` < 10
        * bits 10 <= ``i`` < 16
    * Otherwise
        * bits 0 <= ``i`` < 1
        * bits 1 <= ``i`` < 10
        * bits 10 <= ``i`` < 16
    """

    def __init__(self, partition_points=None):
        """Create a new ``PartitionPoints``.

        :param partition_points: optional mapping of point -> enable value.
            Every value is passed through ``Value.wrap``.
        :raises TypeError: if a point is not an ``int``.
        :raises ValueError: if a point is negative.
        """
        super().__init__()
        if partition_points is None:
            return
        for point, enabled in partition_points.items():
            if not isinstance(point, int):
                raise TypeError("point must be a non-negative integer")
            if point < 0:
                raise ValueError("point must be a non-negative integer")
            self[point] = Value.wrap(enabled)

    def like(self, name=None, src_loc_at=0, mul=1):
        """Build a ``PartitionPoints`` holding fresh ``Signal``s.

        :param name: base name for the new ``Signal``s; each one is named
            ``{name}_{point}``.  When omitted, the name is taken from a
            throw-away ``Signal`` created with ``src_loc_at``.
        :param src_loc_at: extra stack depth used for that name lookup.
        :param mul: a multiplication factor on the indices
        :returns: a new ``PartitionPoints`` keyed by ``point * mul``.
        """
        if name is None:
            # must stay in this frame: the stack depth is part of the lookup
            name = Signal(src_loc_at=1+src_loc_at).name  # get variable name
        scaled = PartitionPoints()
        for point, enabled in self.items():
            new_point = point * mul
            scaled[new_point] = Signal(enabled.shape(),
                                       name=f"{name}_{new_point}")
        return scaled

    def eq(self, rhs):
        """Yield one ``.eq`` assignment per point, taking values from `rhs`.

        This is a generator, so the key check below only runs once the
        result is iterated.

        :raises ValueError: if `rhs` does not have exactly the same points.
        """
        if set(self) != set(rhs.keys()):
            raise ValueError("incompatible point set")
        yield from (enabled.eq(rhs[point])
                    for point, enabled in self.items())

    def as_mask(self, width):
        """Create a bit-mask from `self`.

        Bit ``i`` of the mask is the inverse of the enable value at point
        ``i``; bits that have no partition point are constant ``True``.

        :param width: the bit width of the resulting mask
        """
        return Cat(*(~self[i] if i in self else True
                     for i in range(width)))

    def get_max_partition_count(self, width):
        """Get the maximum number of partitions.

        This is one more than the number of points strictly below `width`,
        i.e. the partition count when all of those points are enabled.
        """
        return 1 + sum(1 for point in self if point < width)

    def fits_in_width(self, width):
        """Check if all partition points are smaller than `width`."""
        return all(point < width for point in self)

    def part_byte(self, index, mfactor=1):
        """Return the enable value at the upper boundary of byte `index`.

        Indices -1 and 7 give a constant 1-bit true; any other index must
        lie in 0..7 and selects the point at ``(index * 8 + 8) * mfactor``.

        :param mfactor: index multiplier, used for "expanding"
        """
        if index in (-1, 7):
            return C(True, 1)
        assert 0 <= index < 8
        return self[(index * 8 + 8) * mfactor]
113
114
class FullAdder(Elaboratable):
    """Bit-parallel full adder (3-to-2 compressor).

    :attribute in0: the first input
    :attribute in1: the second input
    :attribute in2: the third input
    :attribute sum: the sum output
    :attribute carry: the carry output

    Instead of instantiating an array of one-bit full adders (which
    would be very slow to simulate), all ports share one configurable
    bit width, so a single instance performs ``width`` independent
    Full 3-2 Add operations "in parallel".
    """

    def __init__(self, width):
        """Create a ``FullAdder``.

        :param width: the bit width of the input and output
        """
        self.in0 = Signal(width)
        self.in1 = Signal(width)
        self.in2 = Signal(width)
        self.sum = Signal(width)
        self.carry = Signal(width)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        x, y, z = self.in0, self.in1, self.in2
        m.d.comb += [
            # per-bit sum: XOR of the three inputs
            self.sum.eq(x ^ y ^ z),
            # per-bit carry: set when at least two inputs are set
            self.carry.eq((x & y) | (y & z) | (z & x)),
        ]
        return m
149
150
class MaskedFullAdder(Elaboratable):
    """Full adder whose carry output is shifted up one bit and masked.

    :attribute mask: the carry partition mask
    :attribute in0: the first input
    :attribute in1: the second input
    :attribute in2: the third input
    :attribute sum: the sum output
    :attribute mcarry: the masked carry output

    FullAdders are always used with a "mask" on the output, so the
    masking is done inside this class rather than in a large for-loop
    at the call site; this keeps the graphviz output "clean".

    This class is deliberately not derived from FullAdder: each input
    is shifted *before* being ANDed with the mask, so that an AOI cell
    may be used (which is more gate-efficient).  See:
    https://en.wikipedia.org/wiki/AND-OR-Invert
    https://groups.google.com/d/msg/comp.arch/fcq-GLQqvas/vTxmcA0QAgAJ
    """

    def __init__(self, width):
        """Create a ``MaskedFullAdder``.

        :param width: the bit width of the input and output
        """
        self.width = width
        self.mask = Signal(width, reset_less=True)
        self.mcarry = Signal(width, reset_less=True)
        self.in0 = Signal(width, reset_less=True)
        self.in1 = Signal(width, reset_less=True)
        self.in2 = Signal(width, reset_less=True)
        self.sum = Signal(width, reset_less=True)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        w = self.width
        # the s*/c* variable names are kept: they become the signal names
        s1 = Signal(w, reset_less=True)
        s2 = Signal(w, reset_less=True)
        s3 = Signal(w, reset_less=True)
        c1 = Signal(w, reset_less=True)
        c2 = Signal(w, reset_less=True)
        c3 = Signal(w, reset_less=True)
        mask = self.mask
        m.d.comb += [
            self.sum.eq(self.in0 ^ self.in1 ^ self.in2),
            # inputs shifted up by one bit (a zero enters at bit 0; the
            # top bit falls off because s1..s3 are only `width` wide)
            s1.eq(Cat(0, self.in0)),
            s2.eq(Cat(0, self.in1)),
            s3.eq(Cat(0, self.in2)),
            # pairwise ANDs, each gated by the partition mask
            c1.eq(s1 & s2 & mask),
            c2.eq(s2 & s3 & mask),
            c3.eq(s3 & s1 & mask),
            self.mcarry.eq(c1 | c2 | c3),
        ]
        return m
204
205
class PartitionedAdder(Elaboratable):
    """Adder that can be split into independent lanes at partition points.

    Performs the final add.  One extra bit is inserted into both operands
    at every partition point.  In operand ``a`` that bit is the inverse of
    the partition enable; in operand ``b`` it is zero.  When the extra bit
    is 1 a carry arriving from below rolls over it into the next bit; when
    it is 0 the carry is absorbed.  The extra bits are then *removed* from
    the result.

    partition: .... P... P... P... P... (32 bits)
    a        : .... .... .... .... .... (32 bits)
    b        : .... .... .... .... .... (32 bits)
    exp-a    : ....P....P....P....P.... (32+4 bits, P=1 if no partition)
    exp-b    : ....0....0....0....0.... (32 bits plus 4 zeros)
    exp-o    : ....xN...xN...xN...xN... (32+4 bits - x to be discarded)
    o        : .... N... N... N... N... (32 bits - x ignored, N is carry-over)

    :attribute width: the bit width of the input and output. Read-only.
    :attribute a: the first input to the adder
    :attribute b: the second input to the adder
    :attribute output: the sum output
    :attribute partition_points: the input partition points. Modification not
        supported, except for by ``Signal.eq``.
    """

    def __init__(self, width, partition_points):
        """Create a ``PartitionedAdder``.

        :param width: the bit width of the input and output
        :param partition_points: the input partition points
        :raises ValueError: if any partition point is >= `width`.
        """
        self.width = width
        self.a = Signal(width)
        self.b = Signal(width)
        self.output = Signal(width)
        self.partition_points = PartitionPoints(partition_points)
        if not self.partition_points.fits_in_width(width):
            raise ValueError("partition_points doesn't fit in width")
        # one bit per data bit plus one extra bit per partition point
        extra_bits = sum(1 for i in range(self.width)
                         if i in self.partition_points)
        expanded_width = self.width + extra_bits
        self._expanded_width = expanded_width
        # XXX these have to remain here due to some horrible nmigen
        # simulation bugs involving sync.  it is *not* necessary to
        # have them here, they should (under normal circumstances)
        # be moved into elaborate, as they are entirely local
        self._expanded_a = Signal(expanded_width)  # includes extra part-points
        self._expanded_b = Signal(expanded_width)  # likewise.
        self._expanded_o = Signal(expanded_width)  # likewise.

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()

        # bits are gathered in lists and joined with Cat afterwards:
        # graphviz is much cleaner that way.
        exp_a_bits, a_src = [], []   # destination / source bits for a
        exp_b_bits, b_src = [], []   # destination / source bits for b
        exp_o_bits, out_bits = [], []  # source / destination bits for output

        # partition points are "breaks" (extra zeros or 1s) in what would
        # otherwise be a massive long add.  when the "break" points are 0,
        # whatever is in it (in the output) is discarded.  however when
        # there is a "1", it causes a roll-over carry to the *next* bit.
        # we still ignore the "break" bit in the [intermediate] output,
        # however by that time we've got the effect that we wanted: the
        # carry has been carried *over* the break point.
        pos = 0  # current bit position inside the expanded signals
        for i in range(self.width):
            if i in self.partition_points:
                # extra bit: 0 + 0 for enabled partition points
                # and 1 + 0 for disabled partition points
                exp_a_bits.append(self._expanded_a[pos])
                a_src.append(~self.partition_points[i])
                exp_b_bits.append(self._expanded_b[pos])
                b_src.append(C(0))  # yes, add a zero
                pos += 1  # the extra bit is NOT copied to the output
            exp_a_bits.append(self._expanded_a[pos])
            a_src.append(self.a[i])
            exp_b_bits.append(self._expanded_b[pos])
            b_src.append(self.b[i])
            exp_o_bits.append(self._expanded_o[pos])
            out_bits.append(self.output[i])
            pos += 1

        m.d.comb += [
            # scatter a/b into the expanded operands, gather the output
            Cat(*exp_a_bits).eq(Cat(*a_src)),
            Cat(*exp_b_bits).eq(Cat(*b_src)),
            Cat(*out_bits).eq(Cat(*exp_o_bits)),
            # use only one addition to take advantage of look-ahead carry
            # and special hardware on FPGAs
            self._expanded_o.eq(self._expanded_a + self._expanded_b),
        ]
        return m
299
300
# Number of inputs taken by one full adder.  AddReduceSingle uses it as the
# group size when deciding which inputs are fed to a MaskedFullAdder.
FULL_ADDER_INPUT_COUNT = 3
302
class AddReduceData:
    """Bundle of signals describing the data at one add-reduce stage.

    :attribute part_ops: list of ``n_parts`` 2-bit ``Signal``s.
    :attribute inputs: list of ``n_inputs`` ``Signal``s, each
        ``output_width`` bits wide.
    :attribute reg_partition_points: ``Signal`` copy (made with
        ``like()``) of the partition points passed to the constructor.
    """

    def __init__(self, ppoints, output_width, n_parts, n_inputs=None):
        """Create an ``AddReduceData``.

        :param ppoints: the partition points; must provide ``like()``
            (i.e. be a ``PartitionPoints``).
        :param output_width: bit width of every entry of ``inputs``.
        :param n_parts: number of ``part_ops`` signals to create.
        :param n_inputs: number of ``inputs`` signals to create.
        """
        if n_inputs is None:
            # NOTE(review): the original code sized ``inputs`` from
            # ``len(self.inputs)`` before that attribute existed, so the
            # intended count cannot be recovered from it.  Falling back to
            # ``n_parts`` is an assumption - confirm against the callers
            # and pass ``n_inputs`` explicitly where it differs.
            n_inputs = n_parts
        self.part_ops = [Signal(2, name=f"part_ops_{i}")
                         for i in range(n_parts)]
        self.inputs = [Signal(output_width, name=f"inputs[{i}]")
                       for i in range(n_inputs)]
        # was ``partition_points.like()``: that name is not defined here
        self.reg_partition_points = ppoints.like()

    def eq(self, rhs):
        """Return the assignments that copy every signal of `rhs` to `self`.

        The first list entry is the generator returned by
        ``PartitionPoints.eq``; the others are plain ``Signal.eq`` results.
        """
        stmts = [self.reg_partition_points.eq(rhs.reg_partition_points)]
        stmts += [sig.eq(rhs.inputs[i])
                  for i, sig in enumerate(self.inputs)]
        stmts += [op.eq(rhs.part_ops[i])
                  for i, op in enumerate(self.part_ops)]
        return stmts
318
319
class FinalAdd(Elaboratable):
    """Final stage of add reduce.

    Accepts 0, 1 or 2 inputs: with none the output is zero, with one it
    is that input, and with two they are summed by a
    ``PartitionedAdder``.  The inputs, ``part_ops`` and partition points
    are first copied into internal signals, through registers if level 0
    is listed in ``register_levels``.
    """

    def __init__(self, inputs, output_width, register_levels, partition_points,
                 part_ops):
        """Create a ``FinalAdd``.

        :param inputs: input ``Signal``s to be summed (at most two).
        :param output_width: bit-width of ``output``.
        :param register_levels: levels that should have pipeline registers;
            only the presence of level 0 is looked at here.
        :param partition_points: the input partition points.
        :param part_ops: operation selectors, copied to ``out_part_ops``.
        :raises ValueError: if a partition point is >= `output_width`.
        """
        self.part_ops = part_ops
        self.out_part_ops = [Signal(2, name=f"out_part_ops_{i}")
                             for i in range(len(part_ops))]
        self.inputs = list(inputs)
        self._resized_inputs = [
            Signal(output_width, name=f"resized_inputs[{i}]")
            for i in range(len(self.inputs))]
        self.register_levels = list(register_levels)
        self.output = Signal(output_width)
        self.partition_points = PartitionPoints(partition_points)
        if not self.partition_points.fits_in_width(output_width):
            raise ValueError("partition_points doesn't fit in output_width")
        self._reg_partition_points = self.partition_points.like()

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()

        # resize inputs to correct bit-width and optionally add in
        # pipeline registers: the same three copies go to the sync domain
        # when level 0 is registered, otherwise to the comb domain
        copy_part_ops = [self.out_part_ops[i].eq(op)
                         for i, op in enumerate(self.part_ops)]
        resized_input_assignments = [self._resized_inputs[i].eq(src)
                                     for i, src in enumerate(self.inputs)]
        if 0 in self.register_levels:
            domain = m.d.sync
        else:
            domain = m.d.comb
        domain += copy_part_ops
        domain += resized_input_assignments
        domain += self._reg_partition_points.eq(self.partition_points)

        n_inputs = len(self.inputs)
        if n_inputs == 0:
            # use 0 as the default output value
            m.d.comb += self.output.eq(0)
            return m
        if n_inputs == 1:
            # handle single input
            m.d.comb += self.output.eq(self._resized_inputs[0])
            return m

        # base case for adding 2 inputs
        assert n_inputs == 2
        adder = PartitionedAdder(len(self.output),
                                 self._reg_partition_points)
        m.submodules.final_adder = adder
        m.d.comb += [
            adder.a.eq(self._resized_inputs[0]),
            adder.b.eq(self._resized_inputs[1]),
            self.output.eq(adder.output),
        ]
        return m
375
376
class AddReduceSingle(Elaboratable):
    """One level of 3-to-2 reduction over a list of numbers.

    Every complete group of ``FULL_ADDER_INPUT_COUNT`` inputs goes through
    a ``MaskedFullAdder`` and comes out as two terms (sum and masked
    carry); any left-over inputs are passed on unchanged.

    :attribute inputs: input ``Signal``s to be summed. Modification not
        supported, except for by ``Signal.eq``.
    :attribute register_levels: List of nesting levels that should have
        pipeline registers.
    :attribute intermediate_terms: the ``Signal``s produced by this level,
        to be used as the inputs of the next one.
    :attribute out_part_ops: copy of ``part_ops`` (registered together
        with the inputs when level 0 has registers).
    :attribute partition_points: the input partition points. Modification not
        supported, except for by ``Signal.eq``.
    """

    def __init__(self, inputs, output_width, register_levels, partition_points,
                 part_ops):
        """Create an ``AddReduceSingle``.

        :param inputs: input ``Signal``s to be summed.
        :param output_width: bit-width of the intermediate terms.
        :param register_levels: List of nesting levels that should have
            pipeline registers.
        :param partition_points: the input partition points.
        :param part_ops: operation selectors, copied to ``out_part_ops``.
        :raises ValueError: if a partition point is >= `output_width`, or
            if a register level exceeds ``get_max_level``.
        """
        self.output_width = output_width
        self.part_ops = part_ops
        self.out_part_ops = [Signal(2, name=f"out_part_ops_{i}")
                             for i in range(len(part_ops))]
        self.inputs = list(inputs)
        self._resized_inputs = [
            Signal(output_width, name=f"resized_inputs[{i}]")
            for i in range(len(self.inputs))]
        self.register_levels = list(register_levels)
        self.partition_points = PartitionPoints(partition_points)
        if not self.partition_points.fits_in_width(output_width):
            raise ValueError("partition_points doesn't fit in output_width")
        self._reg_partition_points = self.partition_points.like()

        max_level = AddReduceSingle.get_max_level(len(self.inputs))
        for level in self.register_levels:
            if level > max_level:
                raise ValueError(
                    "not enough adder levels for specified register levels")

        # this is annoying.  we have to create the modules (and terms)
        # because we need to know what they are (in order to set up the
        # interconnects back in AddReduce), but cannot do the m.d.comb +=
        # etc because this is not in elaboratable.
        self.groups = AddReduceSingle.full_adder_groups(len(self.inputs))
        # always build the terms, even when there is no full-adder group:
        # elaborate() and AddReduce read ``adders``, ``part_mask`` and
        # ``intermediate_terms`` unconditionally, and used to fail with
        # AttributeError for fewer than FULL_ADDER_INPUT_COUNT inputs.
        # With no groups the inputs are simply passed straight through.
        self.create_next_terms()

    @staticmethod
    def get_max_level(input_count):
        """Get the maximum level.

        All ``register_levels`` must be less than or equal to the maximum
        level.  The value is the number of reduction levels needed before
        no complete full-adder group is left.
        """
        retval = 0
        while True:
            groups = AddReduceSingle.full_adder_groups(input_count)
            if len(groups) == 0:
                return retval
            # each group turns 3 terms into 2; the remainder is carried over
            input_count %= FULL_ADDER_INPUT_COUNT
            input_count += 2 * len(groups)
            retval += 1

    @staticmethod
    def full_adder_groups(input_count):
        """Get ``inputs`` indices for which a full adder should be built.

        The result is a ``range`` holding the first index of every
        complete group of ``FULL_ADDER_INPUT_COUNT`` inputs.
        """
        return range(0,
                     input_count - FULL_ADDER_INPUT_COUNT + 1,
                     FULL_ADDER_INPUT_COUNT)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()

        # resize inputs to correct bit-width and optionally add in
        # pipeline registers
        resized_input_assignments = [self._resized_inputs[i].eq(src)
                                     for i, src in enumerate(self.inputs)]
        copy_part_ops = [self.out_part_ops[i].eq(op)
                         for i, op in enumerate(self.part_ops)]
        if 0 in self.register_levels:
            m.d.sync += copy_part_ops
            m.d.sync += resized_input_assignments
            m.d.sync += self._reg_partition_points.eq(self.partition_points)
        else:
            m.d.comb += copy_part_ops
            m.d.comb += resized_input_assignments
            m.d.comb += self._reg_partition_points.eq(self.partition_points)

        # drive the terms that were created in create_next_terms
        for (value, term) in self._intermediate_terms:
            m.d.comb += term.eq(value)

        mask = self._reg_partition_points.as_mask(self.output_width)
        m.d.comb += self.part_mask.eq(mask)

        # add and link the intermediate term modules
        for i, (iidx, adder_i) in enumerate(self.adders):
            setattr(m.submodules, f"adder_{i}", adder_i)

            m.d.comb += adder_i.in0.eq(self._resized_inputs[iidx])
            m.d.comb += adder_i.in1.eq(self._resized_inputs[iidx + 1])
            m.d.comb += adder_i.in2.eq(self._resized_inputs[iidx + 2])
            m.d.comb += adder_i.mask.eq(self.part_mask)

        return m

    def create_next_terms(self):
        """Create the adders and output terms of this level.

        Sets ``part_mask``, ``adders`` (a list of ``(first input index,
        MaskedFullAdder)`` pairs), ``intermediate_terms`` and the private
        ``(value, term)`` list that ``elaborate`` turns into assignments.
        """
        # go on to prepare recursive case
        intermediate_terms = []
        _intermediate_terms = []

        def add_intermediate_term(value):
            intermediate_term = Signal(
                self.output_width,
                name=f"intermediate_terms[{len(intermediate_terms)}]")
            _intermediate_terms.append((value, intermediate_term))
            intermediate_terms.append(intermediate_term)

        # store mask in intermediary (simplifies graph)
        self.part_mask = Signal(self.output_width, reset_less=True)

        # create full adders for this recursive level.
        # this shrinks N terms to 2 * (N // 3) plus the remainder
        self.adders = []
        for i in self.groups:
            adder_i = MaskedFullAdder(self.output_width)
            self.adders.append((i, adder_i))
            # add both the sum and the masked-carry to the next level.
            # 3 inputs have now been reduced to 2...
            add_intermediate_term(adder_i.sum)
            add_intermediate_term(adder_i.mcarry)

        # handle the remaining inputs.
        leftover = len(self.inputs) % FULL_ADDER_INPUT_COUNT
        if leftover == 1:
            add_intermediate_term(self._resized_inputs[-1])
        elif leftover == 2:
            # Just pass the terms to the next layer, since we wouldn't gain
            # anything by using a half adder since there would still be 2
            # terms and just passing the terms to the next layer saves gates.
            add_intermediate_term(self._resized_inputs[-2])
            add_intermediate_term(self._resized_inputs[-1])
        else:
            assert leftover == 0

        self.intermediate_terms = intermediate_terms
        self._intermediate_terms = _intermediate_terms
527
528
class AddReduce(Elaboratable):
    """Add a list of numbers together with a chain of reduction levels.

    The chain consists of zero or more ``AddReduceSingle`` levels followed
    by one ``FinalAdd``; it is built by ``create_levels``.

    :attribute inputs: input ``Signal``s to be summed. Modification not
        supported, except for by ``Signal.eq``.
    :attribute register_levels: List of nesting levels that should have
        pipeline registers.
    :attribute output: output sum.
    :attribute out_part_ops: copy of the last level's ``out_part_ops``.
    :attribute levels: the list of level modules, ``FinalAdd`` last.
    :attribute partition_points: the input partition points. Modification not
        supported, except for by ``Signal.eq``.
    """

    def __init__(self, inputs, output_width, register_levels, partition_points,
                 part_ops):
        """Create an ``AddReduce``.

        :param inputs: input ``Signal``s to be summed.
        :param output_width: bit-width of ``output``.
        :param register_levels: List of nesting levels that should have
            pipeline registers.
        :param partition_points: the input partition points.
        :param part_ops: operation selectors passed along the levels.
        """
        self.inputs = inputs
        self.part_ops = part_ops
        self.out_part_ops = [Signal(2, name=f"out_part_ops_{i}")
                             for i in range(len(part_ops))]
        self.output = Signal(output_width)
        self.output_width = output_width
        self.register_levels = register_levels
        self.partition_points = partition_points

        self.create_levels()

    @staticmethod
    def get_max_level(input_count):
        """Get the maximum level (see ``AddReduceSingle.get_max_level``)."""
        return AddReduceSingle.get_max_level(input_count)

    @staticmethod
    def next_register_levels(register_levels):
        """``Iterable`` of ``register_levels`` for next recursive level.

        Level 0 is dropped and every other level is decremented by one.
        """
        for level in register_levels:
            if level > 0:
                yield level - 1

    def create_levels(self):
        """Create the reduction levels and store them in ``levels``."""
        mods = []
        # materialise once so that a one-shot iterable is not exhausted by
        # the first level
        next_levels = list(self.register_levels)
        partition_points = self.partition_points
        inputs = self.inputs
        part_ops = self.part_ops
        while True:
            # check *before* building a level: with fewer than
            # FULL_ADDER_INPUT_COUNT terms there is nothing left to
            # reduce.  The check used to run only after a level had been
            # built, which broke for fewer than 3 inputs.
            groups = AddReduceSingle.full_adder_groups(len(inputs))
            if len(groups) == 0:
                break
            next_level = AddReduceSingle(inputs, self.output_width,
                                         next_levels, partition_points,
                                         part_ops)
            mods.append(next_level)
            next_levels = list(AddReduce.next_register_levels(next_levels))
            partition_points = next_level._reg_partition_points
            inputs = next_level.intermediate_terms
            part_ops = next_level.out_part_ops

        # whatever is left (0, 1 or 2 terms) goes to the final adder
        next_level = FinalAdd(inputs, self.output_width, next_levels,
                              partition_points, part_ops)
        mods.append(next_level)

        self.levels = mods

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()

        for i, level in enumerate(self.levels):
            setattr(m.submodules, "next_level%d" % i, level)

        # output comes from last module (always the FinalAdd)
        last_level = self.levels[-1]
        m.d.comb += self.output.eq(last_level.output)
        copy_part_ops = [self.out_part_ops[i].eq(last_level.out_part_ops[i])
                         for i in range(len(self.part_ops))]
        m.d.comb += copy_part_ops

        return m
613
614
# Operation codes compared against the 2-bit ``part_ops`` signals.
# IntermediateOut selects the low half of a product for OP_MUL_LOW and the
# high half for every other code.
OP_MUL_LOW = 0
OP_MUL_SIGNED_HIGH = 1
OP_MUL_SIGNED_UNSIGNED_HIGH = 2  # a is signed, b is unsigned
OP_MUL_UNSIGNED_HIGH = 3
619
620
def get_term(value, shift=0, enabled=None):
    """Optionally gate `value` and shift it left by `shift` bits.

    :param value: the term to process.
    :param shift: non-negative number of zero bits placed below `value`.
    :param enabled: if not ``None``, the term becomes ``0`` whenever this
        is false (``Mux(enabled, value, 0)``).
    :returns: `value` itself when there is nothing to do, otherwise the
        gated and/or shifted expression.
    """
    term = value if enabled is None else Mux(enabled, value, 0)
    if shift <= 0:
        # negative shifts are not supported
        assert shift == 0
        return term
    # shifting left == concatenating `shift` zero bits below the term
    return Cat(Repl(C(0, 1), shift), term)
629
630
class ProductTerm(Elaboratable):
    """Creates a single product term (a[..]*b[..]).

    Design flaw: it is the *output* that is selected by ``enabled``; the
    multiplication itself is combinatorially generated all the time.

    :attribute a: first operand (``twidth//2`` bits).
    :attribute b: second operand (``twidth//2`` bits).
    :attribute pb_en: partition-byte enable bits (``pbwid`` bits).
    :attribute ti: the raw product of the two selected slices.
    :attribute term: ``ti`` shifted by ``8 * (a_index + b_index)`` bits and
        gated by ``enabled`` (when there is such a signal).
    :attribute enabled: ``None`` if `a_index` equals `b_index`; otherwise
        a ``Signal`` that is high when none of the ``pb_en`` bits between
        the two indices is set.
    """

    def __init__(self, width, twidth, pbwid, a_index, b_index):
        """Create a ``ProductTerm``.

        :param width: bit width of each operand slice.
        :param twidth: bit width of ``term``; the operands are half of it.
        :param pbwid: bit width of ``pb_en``.
        :param a_index: which slice of ``a`` to multiply.
        :param b_index: which slice of ``b`` to multiply.
        """
        self.a_index = a_index
        self.b_index = b_index
        self.pwidth = width
        self.twidth = twidth
        self.width = width*2
        self.shift = 8 * (a_index + b_index)

        self.ti = Signal(self.width, reset_less=True)
        self.term = Signal(twidth, reset_less=True)
        self.a = Signal(twidth//2, reset_less=True)
        self.b = Signal(twidth//2, reset_less=True)
        self.pb_en = Signal(pbwid, reset_less=True)

        # enable bits lying between the two slice indices
        lowest = min(a_index, b_index)
        highest = max(a_index, b_index)
        self.tl = [self.pb_en[i] for i in range(lowest, highest)]
        if len(self.tl) > 0:
            self.enabled = Signal(name="te_%d_%d" % (a_index, b_index),
                                  reset_less=True)
        else:
            self.enabled = None
        self.term.name = "term_%d_%d" % (a_index, b_index)  # rename

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        if self.enabled is not None:
            # enabled only when every bit collected in ``tl`` is clear
            m.d.comb += self.enabled.eq(~(Cat(*self.tl).bool()))

        pwidth = self.pwidth
        bsa = Signal(self.width, reset_less=True)
        bsb = Signal(self.width, reset_less=True)
        m.d.comb += [
            # pick the pwidth-wide slices selected by a_index / b_index
            bsa.eq(self.a.part(self.a_index * pwidth, pwidth)),
            bsb.eq(self.b.part(self.b_index * pwidth, pwidth)),
            self.ti.eq(bsa * bsb),
            self.term.eq(get_term(self.ti, self.shift, self.enabled)),
        ]

        # TODO: sort out width issues, get inputs a/b switched on/off.
        # data going into Muxes is 1/2 the required width
        #
        # pwidth = self.pwidth
        # width = self.width
        # bsa = Signal(self.twidth//2, reset_less=True)
        # bsb = Signal(self.twidth//2, reset_less=True)
        # asel = Signal(width, reset_less=True)
        # bsel = Signal(width, reset_less=True)
        # a_index, b_index = self.a_index, self.b_index
        # m.d.comb += asel.eq(self.a.part(a_index * pwidth, pwidth))
        # m.d.comb += bsel.eq(self.b.part(b_index * pwidth, pwidth))
        # m.d.comb += bsa.eq(get_term(asel, self.shift, self.enabled))
        # m.d.comb += bsb.eq(get_term(bsel, self.shift, self.enabled))
        # m.d.comb += self.ti.eq(bsa * bsb)
        # m.d.comb += self.term.eq(self.ti)

        return m
700
701
class ProductTerms(Elaboratable):
    """Creates a bank of product terms for one slice of operand "a".

    Also performs the actual bit-selection.  This class is meant to be
    wrapped in a for-loop over the "a" operand; internally it runs the
    second-level for-loop over the "b" operand, building one
    ``ProductTerm`` per "b" index.

    :attribute a: first operand (``twidth//2`` bits).
    :attribute b: second operand (``twidth//2`` bits).
    :attribute pb_en: partition-byte enable bits (``pbwid`` bits).
    :attribute terms: list of `blen` output terms, each ``twidth`` bits.
    """

    def __init__(self, width, twidth, pbwid, a_index, blen):
        """Create a ``ProductTerms``.

        :param width: bit width of each operand slice.
        :param twidth: bit width of every term.
        :param pbwid: bit width of ``pb_en``.
        :param a_index: the fixed slice index into ``a``.
        :param blen: number of slices of ``b`` (and of output terms).
        """
        self.a_index = a_index
        self.blen = blen
        self.pwidth = width
        self.twidth = twidth
        self.pbwid = pbwid
        self.a = Signal(twidth//2, reset_less=True)
        self.b = Signal(twidth//2, reset_less=True)
        self.pb_en = Signal(pbwid, reset_less=True)
        self.terms = [Signal(twidth, name="term%d" % i, reset_less=True)
                      for i in range(blen)]

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()

        for b_index, out_term in enumerate(self.terms[:self.blen]):
            t = ProductTerm(self.pwidth, self.twidth, self.pbwid,
                            self.a_index, b_index)
            setattr(m.submodules, "term_%d" % b_index, t)

            m.d.comb += [
                # every term sees the full operands and enable bits
                t.a.eq(self.a),
                t.b.eq(self.b),
                t.pb_en.eq(self.pb_en),
                out_term.eq(t.term),
            ]

        return m
735
736
class LSBNegTerm(Elaboratable):
    """Produces the two correction terms of a negation, for one partition.

    The negation is split into a bitwise not (``nt``) and a +1 (``nl``),
    both placed in the upper half of a double-width word.  Both terms are
    zero unless ``part``, ``msb`` and ``signed`` are all set.

    :attribute part: whether this partition is active.
    :attribute signed: whether a signed operation was requested.
    :attribute op: the operand (``bit_width`` bits).
    :attribute msb: the sign bit to test.
    :attribute nt: the "not" term (``2 * bit_width`` bits).
    :attribute nl: the "+1" term (``2 * bit_width`` bits).
    """

    def __init__(self, bit_width):
        """Create an ``LSBNegTerm``.

        :param bit_width: bit width of ``op``; the outputs are twice that.
        """
        self.bit_width = bit_width
        self.part = Signal(reset_less=True)
        self.signed = Signal(reset_less=True)
        self.op = Signal(bit_width, reset_less=True)
        self.msb = Signal(reset_less=True)
        self.nt = Signal(bit_width*2, reset_less=True)
        self.nl = Signal(bit_width*2, reset_less=True)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        bit_wid = self.bit_width
        ext = Repl(0, bit_wid)  # zeros filling the LO part of the outputs

        # determine sign of each incoming number *in this partition*
        enabled = Signal(reset_less=True)

        # for 8-bit values: form a * 0xFF00 by using -a * 0x100, the
        # negation operation is split into a bitwise not and a +1.
        # likewise for 16, 32, and 64-bit values.
        m.d.comb += [
            enabled.eq(self.part & self.msb & self.signed),
            # width-extended 1s complement if enabled, otherwise zero
            self.nt.eq(Mux(enabled, Cat(ext, ~self.op), 0)),
            # add 1 (at bit position bit_wid) if enabled, otherwise zero
            self.nl.eq(Cat(ext, enabled, Repl(0, bit_wid-1))),
        ]

        return m
769
770
class Parts(Elaboratable):
    """Works out, from the partition points, which partitions are active.

    :attribute epps: input: the expanded partition points (a ``Signal``
        copy of the constructor argument, named ``epps_<point>``).
    :attribute parts: output: one 1-bit ``Signal`` per partition.
    """

    def __init__(self, pbwid, epps, n_parts):
        """Create a ``Parts``.

        :param pbwid: number of partition-byte bits to collect.
        :param epps: the expanded partition points.
        :param n_parts: number of equally-sized partitions to report on.
        """
        self.pbwid = pbwid
        # inputs
        self.epps = PartitionPoints.like(epps, name="epps")  # expanded points
        # outputs
        self.parts = [Signal(name=f"part_{i}") for i in range(n_parts)]

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()

        epps, parts = self.epps, self.parts
        # collect part-bytes (double factor because the input is extended)
        pbs = Signal(self.pbwid, reset_less=True)
        tl = []
        for i in range(self.pbwid):
            pb = Signal(name="pb%d" % i, reset_less=True)
            m.d.comb += pb.eq(epps.part_byte(i, mfactor=2))  # double
            tl.append(pb)
        m.d.comb += pbs.eq(Cat(*tl))

        # negated-temporary copy of partition bits
        npbs = Signal.like(pbs, reset_less=True)
        m.d.comb += npbs.eq(~pbs)

        byte_count = 8 // len(parts)
        for i, part in enumerate(parts):
            first = i * byte_count           # first byte of this partition
            last = first + byte_count - 1    # last byte of this partition
            # the boundaries just below and at the top of the partition
            # come from the negated bits, the ones strictly inside it from
            # the plain bits.  NOTE: for i == 0 the lower index is -1.
            pbl = [npbs[first - 1]]
            pbl.extend(pbs[j] for j in range(first, last))
            pbl.append(npbs[last])
            value = Signal(len(pbl), name="value_%d" % i, reset_less=True)
            m.d.comb += value.eq(Cat(*pbl))
            # the partition is reported only when all collected bits are 0
            m.d.comb += part.eq(~(value).bool())

        return m
808
809
class Part(Elaboratable):
    """Key class which, depending on the partitioning, determines what
    action to take when parts of the output are signed or unsigned.

    This requires 2 pieces of data *per operand, per partition*:
    whether the MSB is HI/LO (per partition!), and whether a signed
    or unsigned operation has been *requested*.

    Once that is determined, signed is basically carried out
    by splitting 2's complement into 1's complement plus one.
    1's complement is just a bit-inversion.

    The extra terms - as separate terms - are then thrown at the
    AddReduce alongside the multiplication part-results.
    """

    def __init__(self, epps, width, n_parts, n_levels, pbwid):
        """Create a ``Part``.

        :param epps: the expanded partition points.
        :param width: bit width of the four output terms.
        :param n_parts: number of partitions.
        :param n_levels: not used by this class.
        :param pbwid: number of partition-byte bits.
        """
        self.pbwid = pbwid
        self.epps = epps

        # inputs
        self.a = Signal(64)
        self.b = Signal(64)
        self.a_signed = [Signal(name=f"a_signed_{i}") for i in range(8)]
        self.b_signed = [Signal(name=f"_b_signed_{i}") for i in range(8)]
        self.pbs = Signal(pbwid, reset_less=True)

        # outputs
        self.parts = [Signal(name=f"part_{i}") for i in range(n_parts)]

        self.not_a_term = Signal(width)
        self.neg_lsb_a_term = Signal(width)
        self.not_b_term = Signal(width)
        self.neg_lsb_b_term = Signal(width)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()

        epps = self.epps
        m.submodules.p = p = Parts(self.pbwid, epps, len(self.parts))
        m.d.comb += p.epps.eq(epps)
        # from here on the partition flags computed by Parts are used
        parts = p.parts

        byte_width = 8 // len(parts)  # byte width
        bit_wid = 8 * byte_width      # bit width
        nat, nbt, nla, nlb = [], [], [], []

        # (submodule name format, operand, sign flags of the *other*
        #  operand, the *other* operand, "not" list, "+1" list)
        operand_table = (
            ("lnt_%d_a_%d", self.a, self.b_signed, self.b, nat, nla),
            ("lnt_%d_b_%d", self.b, self.a_signed, self.a, nbt, nlb),
        )
        for i in range(len(parts)):
            # work out bit-inverted and +1 term, first for a and then for b
            for fmt, operand, signed, other, nts, nls in operand_table:
                term = LSBNegTerm(bit_wid)
                setattr(m.submodules, fmt % (bit_wid, i), term)
                m.d.comb += term.part.eq(parts[i])
                m.d.comb += term.op.eq(operand.part(bit_wid * i, bit_wid))
                # yes, really: sign request and MSB come from the other
                # operand
                m.d.comb += term.signed.eq(signed[i * byte_width])
                m.d.comb += term.msb.eq(other[(i + 1) * bit_wid - 1])
                nts.append(term.nt)
                nls.append(term.nl)

        # concatenate together and return all 4 results.
        m.d.comb += [self.not_a_term.eq(Cat(*nat)),
                     self.not_b_term.eq(Cat(*nbt)),
                     self.neg_lsb_a_term.eq(Cat(*nla)),
                     self.neg_lsb_b_term.eq(Cat(*nlb)),
                     ]

        return m
892
893
class IntermediateOut(Elaboratable):
    """Selects the HI/LO part of the multiplication, for a given bit-width.

    The output is also reconstructed in its SIMD (partition) lanes.

    :attribute part_ops: eight 2-bit operation selectors.
    :attribute intermed: the double-width products, ``out_wid`` bits.
    :attribute output: the selected halves, ``out_wid//2`` bits.
    """

    def __init__(self, width, out_wid, n_parts):
        """Create an ``IntermediateOut``.

        :param width: bit width of one lane of the output.
        :param out_wid: bit width of ``intermed``.
        :param n_parts: number of lanes.
        """
        self.width = width
        self.n_parts = n_parts
        self.part_ops = [Signal(2, name="dpop%d" % i, reset_less=True)
                         for i in range(8)]
        self.intermed = Signal(out_wid, reset_less=True)
        self.output = Signal(out_wid//2, reset_less=True)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()

        w = self.width
        sel = w // 8  # stride through part_ops: one entry per byte
        ol = []
        for i in range(self.n_parts):
            op = Signal(w, reset_less=True, name="op%d_%d" % (w, i))
            # each lane owns 2*w bits of intermed: low half then high half
            lo_half = self.intermed.part(i * w*2, w)
            hi_half = self.intermed.part(i * w*2 + w, w)
            want_low = self.part_ops[sel * i] == OP_MUL_LOW
            m.d.comb += op.eq(Mux(want_low, lo_half, hi_half))
            ol.append(op)
        m.d.comb += self.output.eq(Cat(*ol))

        return m
922
923
class FinalOut(Elaboratable):
    """Assemble the final result byte by byte from the per-width results.

    Every output byte is chosen independently, so some bytes may come
    from the 8-bit result while others come from the 16, 32 or 64-bit
    result.  For output byte ``i`` the priority is:

    * ``d8[i]`` set: byte ``i`` of ``i8``
    * otherwise ``d16[i // 2]`` set: byte ``i`` of ``i16``
    * otherwise ``d32[i // 4]`` set: byte ``i`` of ``i32``
    * otherwise: byte ``i`` of ``i64``

    :param out_wid: width of ``i8``, ``i16``, ``i32``, ``i64`` and ``out``.
    """

    def __init__(self, out_wid):
        # inputs: 1-bit selectors, one per 8/16/32-bit lane
        self.d8 = [Signal(name=f"d8_{i}", reset_less=True) for i in range(8)]
        self.d16 = [Signal(name=f"d16_{i}", reset_less=True) for i in range(4)]
        self.d32 = [Signal(name=f"d32_{i}", reset_less=True) for i in range(2)]

        # inputs: candidate results, one per lane width
        self.i8 = Signal(out_wid, reset_less=True)
        self.i16 = Signal(out_wid, reset_less=True)
        self.i32 = Signal(out_wid, reset_less=True)
        self.i64 = Signal(out_wid, reset_less=True)

        # output
        self.out = Signal(out_wid, reset_less=True)

    def elaborate(self, platform):
        m = Module()

        out_bytes = []
        for byte in range(8):
            lsb = byte * 8
            op = Signal(8, reset_less=True, name="op_%d" % byte)

            use8 = self.d8[byte]
            use16 = self.d16[byte // 2]
            use32 = self.d32[byte // 4]

            # narrow candidates: i8 wins over i16
            narrow = Mux(use8,
                         self.i8.part(lsb, 8),
                         self.i16.part(lsb, 8))
            # wide candidates: i32 if selected, i64 as the default
            wide = Mux(use32,
                       self.i32.part(lsb, 8),
                       self.i64.part(lsb, 8))

            # the narrow result is used if either d8 or d16 is set
            m.d.comb += op.eq(Mux(use8 | use16, narrow, wide))
            out_bytes.append(op)

        m.d.comb += self.out.eq(Cat(*out_bytes))
        return m
964
965
class OrMod(Elaboratable):
    """Bitwise-OR four equal-width inputs using a two-level tree.

    :param wid: width in bits of each input and of the output.
    """

    def __init__(self, wid):
        self.wid = wid
        self.orin = [Signal(wid, name="orin%d" % i, reset_less=True)
                     for i in range(4)]
        self.orout = Signal(wid, reset_less=True)

    def elaborate(self, platform):
        m = Module()

        # first level: two pairwise ORs
        or1 = Signal(self.wid, reset_less=True)
        or2 = Signal(self.wid, reset_less=True)
        m.d.comb += [
            or1.eq(self.orin[0] | self.orin[1]),
            or2.eq(self.orin[2] | self.orin[3]),
        ]

        # second level: combine the two partial results
        m.d.comb += self.orout.eq(or1 | or2)

        return m
984
985
class Signs(Elaboratable):
    """Decode an operation code (OP_MUL_*) into operand signedness.

    * ``a_signed`` is asserted for every operation except
      ``OP_MUL_UNSIGNED_HIGH``.
    * ``b_signed`` is asserted for ``OP_MUL_LOW`` and
      ``OP_MUL_SIGNED_HIGH`` only.
    """

    def __init__(self):
        self.part_ops = Signal(2, reset_less=True)
        self.a_signed = Signal(reset_less=True)
        self.b_signed = Signal(reset_less=True)

    def elaborate(self, platform):
        m = Module()

        op = self.part_ops
        a_is_signed = op != OP_MUL_UNSIGNED_HIGH
        is_low = op == OP_MUL_LOW
        is_signed_high = op == OP_MUL_SIGNED_HIGH

        m.d.comb += [
            self.a_signed.eq(a_is_signed),
            self.b_signed.eq(is_low | is_signed_high),
        ]

        return m
1007
1008
class Mul8_16_32_64(Elaboratable):
    """Signed/Unsigned 8/16/32/64-bit partitioned integer multiplier.

    The 64-bit operands may be split into any combination of 8, 16, 32
    and 64-bit partitions on naturally-aligned boundaries, and the
    operation may be chosen for each partition independently.

    :attribute part_pts: the input partition points. Has a partition point
        at multiples of 8 in 0 < i < 64. Each partition point's associated
        ``Value`` is a ``Signal``. Modification not supported, except for
        by ``Signal.eq``.
    :attribute part_ops: the operation for each byte. The operation for a
        particular partition is selected by assigning the selected
        operation code to each byte in the partition. The allowed
        operation codes are:

        :attribute OP_MUL_LOW: the LSB half of the product. Equivalent to
            RISC-V's `mul` instruction.
        :attribute OP_MUL_SIGNED_HIGH: the MSB half of the product where
            both ``a`` and ``b`` are signed. Equivalent to RISC-V's `mulh`
            instruction.
        :attribute OP_MUL_SIGNED_UNSIGNED_HIGH: the MSB half of the
            product where ``a`` is signed and ``b`` is unsigned.
            Equivalent to RISC-V's `mulhsu` instruction.
        :attribute OP_MUL_UNSIGNED_HIGH: the MSB half of the product where
            both ``a`` and ``b`` are unsigned. Equivalent to RISC-V's
            `mulhu` instruction.
    """

    def __init__(self, register_levels=()):
        """Create the multiplier.

        :param register_levels: specifies the points in the cascade at
            which flip-flops are to be inserted.
        """

        # parameter(s)
        self.register_levels = list(register_levels)

        # inputs
        self.part_pts = PartitionPoints()
        for i in range(8, 64, 8):
            self.part_pts[i] = Signal(name=f"part_pts_{i}")
        self.part_ops = [Signal(2, name=f"part_ops_{i}") for i in range(8)]
        self.a = Signal(64)
        self.b = Signal(64)

        # intermediates (needed for unit tests)
        self._intermediate_output = Signal(128)

        # output
        self.output = Signal(64)

    def elaborate(self, platform):
        m = Module()

        # collect the per-byte partition flags into a single 8-bit vector
        pbs = Signal(8, reset_less=True)
        byte_flags = []
        for byte in range(8):
            flag = Signal(name="pb%d" % byte, reset_less=True)
            m.d.comb += flag.eq(self.part_pts.part_byte(byte))
            byte_flags.append(flag)
        m.d.comb += pbs.eq(Cat(*byte_flags))

        # doubled PartitionPoints: the product is twice the input width,
        # so every partition point moves to twice its index
        eps = PartitionPoints()
        for i, enable in self.part_pts.items():
            doubled = Signal(name=f"expanded_part_pts_{i*2}", reset_less=True)
            eps[i * 2] = doubled
            m.d.comb += doubled.eq(enable)

        # one signedness decoder per byte, fed from that byte's part_ops
        signs = []
        for byte in range(8):
            sign = Signs()
            signs.append(sign)
            setattr(m.submodules, "signs%d" % byte, sign)
            m.d.comb += sign.part_ops.eq(self.part_ops[byte])

        # one Part submodule per lane width: (number of lanes, lane bits)
        n_levels = len(self.register_levels)+1
        part_mods = []
        for n_lanes, lane_bits in ((8, 8), (4, 16), (2, 32), (1, 64)):
            part_mod = Part(eps, 128, n_lanes, n_levels, 8)
            setattr(m.submodules, "part_%d" % lane_bits, part_mod)
            part_mods.append(part_mod)
        part_8, part_16, part_32, part_64 = part_mods

        # feed every Part the operands, signs and partition flags, and
        # gather its four correction terms
        nat_l, nbt_l, nla_l, nlb_l = [], [], [], []
        for part_mod in part_mods:
            m.d.comb += part_mod.a.eq(self.a)
            m.d.comb += part_mod.b.eq(self.b)
            for byte, sign in enumerate(signs):
                m.d.comb += part_mod.a_signed[byte].eq(sign.a_signed)
                m.d.comb += part_mod.b_signed[byte].eq(sign.b_signed)
            m.d.comb += part_mod.pbs.eq(pbs)
            nat_l.append(part_mod.not_a_term)
            nbt_l.append(part_mod.not_b_term)
            nla_l.append(part_mod.neg_lsb_a_term)
            nlb_l.append(part_mod.neg_lsb_b_term)

        # all terms that go into the adder tree, product terms first
        terms = []
        for a_index in range(8):
            prod = ProductTerms(8, 128, 8, a_index, 8)
            setattr(m.submodules, "terms_%d" % a_index, prod)

            m.d.comb += prod.a.eq(self.a)
            m.d.comb += prod.b.eq(self.b)
            m.d.comb += prod.pb_en.eq(pbs)

            terms.extend(prod.terms)

        # it's fine to bitwise-or data together since they are never enabled
        # at the same time
        for label, sources in (("nat", nat_l),
                               ("nbt", nbt_l),
                               ("nla", nla_l),
                               ("nlb", nlb_l)):
            or_mod = OrMod(128)
            setattr(m.submodules, "%s_or" % label, or_mod)
            for idx, src in enumerate(sources):
                m.d.comb += or_mod.orin[idx].eq(src)
            terms.append(or_mod.orout)

        add_reduce = AddReduce(terms,
                               128,
                               self.register_levels,
                               eps,
                               self.part_ops)

        # operation codes and partition points as seen at the last level
        # of the adder tree
        last_level = add_reduce.levels[-1]
        out_part_ops = last_level.out_part_ops
        out_part_pts = last_level._reg_partition_points

        m.submodules.add_reduce = add_reduce
        m.d.comb += self._intermediate_output.eq(add_reduce.output)

        # HI/LO selection for each lane width: io64, io32, io16, io8
        lane_out = {}
        for lane_bits, n_lanes in ((64, 1), (32, 2), (16, 4), (8, 8)):
            io = IntermediateOut(lane_bits, 128, n_lanes)
            setattr(m.submodules, "io%d" % lane_bits, io)
            m.d.comb += io.intermed.eq(self._intermediate_output)
            for byte in range(8):
                m.d.comb += io.part_ops[byte].eq(out_part_ops[byte])
            lane_out[lane_bits] = io

        # one Parts submodule per lane width, all driven from the
        # partition points of the last adder level
        lane_parts = []
        for lane_bits, part_mod in zip((8, 16, 32, 64), part_mods):
            lane_part = Parts(8, eps, len(part_mod.parts))
            setattr(m.submodules, "p_%d" % lane_bits, lane_part)
            lane_parts.append(lane_part)
        for lane_part in lane_parts:
            m.d.comb += lane_part.epps.eq(out_part_pts)
        p_8, p_16, p_32, p_64 = lane_parts

        # final output
        m.submodules.finalout = finalout = FinalOut(64)
        for selects, part_mod, lane_part in ((finalout.d8, part_8, p_8),
                                             (finalout.d16, part_16, p_16),
                                             (finalout.d32, part_32, p_32)):
            for idx in range(len(part_mod.parts)):
                m.d.comb += selects[idx].eq(lane_part.parts[idx])
        m.d.comb += finalout.i8.eq(lane_out[8].output)
        m.d.comb += finalout.i16.eq(lane_out[16].output)
        m.d.comb += finalout.i32.eq(lane_out[32].output)
        m.d.comb += finalout.i64.eq(lane_out[64].output)
        m.d.comb += self.output.eq(finalout.out)

        return m
1191
1192
if __name__ == "__main__":
    # build a default (unregistered) multiplier and hand it to nmigen's
    # command-line driver with all of its inputs and outputs as ports
    m = Mul8_16_32_64()
    ports = [m.a, m.b, m._intermediate_output, m.output]
    ports.extend(m.part_ops)
    ports.extend(m.part_pts.values())
    main(m, ports=ports)