src/ieee754/part_mul_add/multiply.py
1 # SPDX-License-Identifier: LGPL-2.1-or-later
2 # See Notices.txt for copyright information
3 """Integer Multiplication."""
4
5 from nmigen import Signal, Module, Value, Elaboratable, Cat, C, Mux, Repl
6 from nmigen.hdl.ast import Assign
7 from abc import ABCMeta, abstractmethod
8 from nmigen.cli import main
9 from functools import reduce
10 from operator import or_
11
12
13 class PartitionPoints(dict):
14 """Partition points and corresponding ``Value``s.
15
16 The points at which an ALU is partitioned, along with ``Value``s that
17 specify whether the corresponding partition points are enabled.
18
19 For example: ``{1: True, 5: True, 10: True}`` with
20 ``width == 16`` specifies that the ALU is split into 4 sections:
21 * bits 0 <= ``i`` < 1
22 * bits 1 <= ``i`` < 5
23 * bits 5 <= ``i`` < 10
24 * bits 10 <= ``i`` < 16
25
26 If the partition_points were instead ``{1: True, 5: a, 10: True}``
27 where ``a`` is a 1-bit ``Signal``:
28 * If ``a`` is asserted:
29 * bits 0 <= ``i`` < 1
30 * bits 1 <= ``i`` < 5
31 * bits 5 <= ``i`` < 10
32 * bits 10 <= ``i`` < 16
33 * Otherwise
34 * bits 0 <= ``i`` < 1
35 * bits 1 <= ``i`` < 10
36 * bits 10 <= ``i`` < 16
37 """
38
39 def __init__(self, partition_points=None):
40 """Create a new ``PartitionPoints``.
41
42 :param partition_points: the input mapping of partition points to values.
43 """
44 super().__init__()
45 if partition_points is not None:
46 for point, enabled in partition_points.items():
47 if not isinstance(point, int):
48 raise TypeError("point must be a non-negative integer")
49 if point < 0:
50 raise ValueError("point must be a non-negative integer")
51 self[point] = Value.wrap(enabled)
52
53 def like(self, name=None, src_loc_at=0):
54 """Create a new ``PartitionPoints`` with ``Signal``s for all values.
55
56 :param name: the base name for the new ``Signal``s.
57 """
58 if name is None:
59 name = Signal(src_loc_at=1+src_loc_at).name # get variable name
60 retval = PartitionPoints()
61 for point, enabled in self.items():
62 retval[point] = Signal(enabled.shape(), name=f"{name}_{point}")
63 return retval
64
65 def eq(self, rhs):
66 """Assign ``PartitionPoints`` using ``Signal.eq``."""
67 if set(self.keys()) != set(rhs.keys()):
68 raise ValueError("incompatible point set")
69 for point, enabled in self.items():
70 yield enabled.eq(rhs[point])
71
72 def as_mask(self, width):
73 """Create a bit-mask from `self`.
74
75 Each bit in the returned mask is clear only if the partition point at
76 the same bit-index is enabled.
77
78 :param width: the bit width of the resulting mask
79 """
80 bits = []
81 for i in range(width):
82 if i in self:
83 bits.append(~self[i])
84 else:
85 bits.append(True)
86 return Cat(*bits)
87
88 def get_max_partition_count(self, width):
89 """Get the maximum number of partitions.
90
91 Gets the number of partitions when all partition points are enabled.
92 """
93 retval = 1
94 for point in self.keys():
95 if point < width:
96 retval += 1
97 return retval
98
99 def fits_in_width(self, width):
100 """Check if all partition points are smaller than `width`."""
101 for point in self.keys():
102 if point >= width:
103 return False
104 return True
105
106 def part_byte(self, index, mfactor=1): # mfactor used for "expanding"
107 if index == -1 or index == 7:
108 return C(True, 1)
109 assert index >= 0 and index < 8
110 return self[(index * 8 + 8)*mfactor]
111
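# Illustrative sketch (not part of the original module; helper name is
# hypothetical): a PartitionPoints describing a 32-bit ALU with points at
# bits 8, 16 and 24. When every point is enabled the ALU acts as four
# independent 8-bit lanes.
def _partition_points_sketch():
    pp = PartitionPoints({8: C(1), 16: C(1), 24: C(1)})
    assert pp.fits_in_width(32)
    return pp.get_max_partition_count(32)   # == 4 lanes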
112
113 class FullAdder(Elaboratable):
114 """Full Adder.
115
116 :attribute in0: the first input
117 :attribute in1: the second input
118 :attribute in2: the third input
119 :attribute sum: the sum output
120 :attribute carry: the carry output
121
122 Rather than do individual full adders (and have an array of them,
123 which would be very slow to simulate), this module can specify the
124 bit width of the inputs and outputs: in effect it performs multiple
125 Full 3-2 Add operations "in parallel".
126 """
127
128 def __init__(self, width):
129 """Create a ``FullAdder``.
130
131 :param width: the bit width of the input and output
132 """
133 self.in0 = Signal(width)
134 self.in1 = Signal(width)
135 self.in2 = Signal(width)
136 self.sum = Signal(width)
137 self.carry = Signal(width)
138
139 def elaborate(self, platform):
140 """Elaborate this module."""
141 m = Module()
142 m.d.comb += self.sum.eq(self.in0 ^ self.in1 ^ self.in2)
143 m.d.comb += self.carry.eq((self.in0 & self.in1)
144 | (self.in1 & self.in2)
145 | (self.in2 & self.in0))
146 return m
147
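# Illustrative sketch (hypothetical helper, plain Python): the FullAdder above
# is a bank of independent 3:2 compressors; per bit position it preserves
# in0 + in1 + in2 == sum + 2*carry, which is what lets AddReduceSingle turn
# three partial-product terms into two.
def _full_adder_identity_check():
    for a in (0, 1):
        for b in (0, 1):
            for c in (0, 1):
                s = a ^ b ^ c
                cy = (a & b) | (b & c) | (c & a)
                assert a + b + c == s + 2 * cy
    return True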
148
149 class MaskedFullAdder(Elaboratable):
150 """Masked Full Adder.
151
152 :attribute mask: the carry partition mask
153 :attribute in0: the first input
154 :attribute in1: the second input
155 :attribute in2: the third input
156 :attribute sum: the sum output
157 :attribute mcarry: the masked carry output
158
159 FullAdders are always used with a "mask" on the output. To keep
160 the graphviz "clean", this class performs the masking here rather
161 than inside a large for-loop.
162
163 See the following discussion as to why this is no longer derived
164 from FullAdder. Each carry is shifted here *before* being ANDed
165 with the mask, so that an AOI cell may be used (which is more
166 gate-efficient)
167 https://en.wikipedia.org/wiki/AND-OR-Invert
168 https://groups.google.com/d/msg/comp.arch/fcq-GLQqvas/vTxmcA0QAgAJ
169 """
170
171 def __init__(self, width):
172 """Create a ``MaskedFullAdder``.
173
174 :param width: the bit width of the input and output
175 """
176 self.width = width
177 self.mask = Signal(width, reset_less=True)
178 self.mcarry = Signal(width, reset_less=True)
179 self.in0 = Signal(width, reset_less=True)
180 self.in1 = Signal(width, reset_less=True)
181 self.in2 = Signal(width, reset_less=True)
182 self.sum = Signal(width, reset_less=True)
183
184 def elaborate(self, platform):
185 """Elaborate this module."""
186 m = Module()
187 s1 = Signal(self.width, reset_less=True)
188 s2 = Signal(self.width, reset_less=True)
189 s3 = Signal(self.width, reset_less=True)
190 c1 = Signal(self.width, reset_less=True)
191 c2 = Signal(self.width, reset_less=True)
192 c3 = Signal(self.width, reset_less=True)
193 m.d.comb += self.sum.eq(self.in0 ^ self.in1 ^ self.in2)
194 m.d.comb += s1.eq(Cat(0, self.in0))
195 m.d.comb += s2.eq(Cat(0, self.in1))
196 m.d.comb += s3.eq(Cat(0, self.in2))
197 m.d.comb += c1.eq(s1 & s2 & self.mask)
198 m.d.comb += c2.eq(s2 & s3 & self.mask)
199 m.d.comb += c3.eq(s3 & s1 & self.mask)
200 m.d.comb += self.mcarry.eq(c1 | c2 | c3)
201 return m
202
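# Illustrative sketch (hypothetical, plain Python): why the shifted-then-masked
# carry cannot cross a partition boundary. For an 8-bit datapath with the
# partition point at bit 4 enabled, as_mask() gives a mask with bit 4 clear,
# so a carry generated at bit 3 (which lands on bit 4 after the shift) is
# dropped instead of leaking into the next lane.
def _masked_carry_sketch():
    mask = 0b11101111        # bit 4 clear == partition point at 4 enabled
    carries = 0b00001000     # carry generated at bit 3
    mcarry = (carries << 1) & mask
    assert mcarry == 0       # absorbed at the boundary
    return mcarry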
203
204 class PartitionedAdder(Elaboratable):
205 """Partitioned Adder.
206
207 Performs the final add. The partition points are included in the
208 actual add (in one of the operands only), which causes a carry over
209 to the next bit. Then the final output *removes* the extra bits from
210 the result.
211
212 partition: .... P... P... P... P... (32 bits)
213 a : .... .... .... .... .... (32 bits)
214 b : .... .... .... .... .... (32 bits)
215 exp-a : ....P....P....P....P.... (32+4 bits, P=1 if no partition)
216 exp-b : ....0....0....0....0.... (32 bits plus 4 zeros)
217 exp-o : ....xN...xN...xN...xN... (32+4 bits - x to be discarded)
218 o : .... N... N... N... N... (32 bits - x ignored, N is carry-over)
219
220 :attribute width: the bit width of the input and output. Read-only.
221 :attribute a: the first input to the adder
222 :attribute b: the second input to the adder
223 :attribute output: the sum output
224 :attribute partition_points: the input partition points. Modification not
225 supported, except for by ``Signal.eq``.
226 """
227
228 def __init__(self, width, partition_points):
229 """Create a ``PartitionedAdder``.
230
231 :param width: the bit width of the input and output
232 :param partition_points: the input partition points
233 """
234 self.width = width
235 self.a = Signal(width)
236 self.b = Signal(width)
237 self.output = Signal(width)
238 self.partition_points = PartitionPoints(partition_points)
239 if not self.partition_points.fits_in_width(width):
240 raise ValueError("partition_points doesn't fit in width")
241 expanded_width = 0
242 for i in range(self.width):
243 if i in self.partition_points:
244 expanded_width += 1
245 expanded_width += 1
246 self._expanded_width = expanded_width
247 # XXX these have to remain here due to some horrible nmigen
248 # simulation bugs involving sync. it is *not* necessary to
249 # have them here, they should (under normal circumstances)
250 # be moved into elaborate, as they are entirely local
251 self._expanded_a = Signal(expanded_width) # includes extra part-points
252 self._expanded_b = Signal(expanded_width) # likewise.
253 self._expanded_o = Signal(expanded_width) # likewise.
254
255 def elaborate(self, platform):
256 """Elaborate this module."""
257 m = Module()
258 expanded_index = 0
259 # store bits in a list, use Cat later. graphviz is much cleaner
260 al, bl, ol, ea, eb, eo = [], [], [], [], [], []
261
262 # partition points are "breaks" (extra zeros or 1s) in what would
263 # otherwise be a massive long add. when the "break" points are 0,
264 # whatever is in it (in the output) is discarded. however when
265 # there is a "1", it causes a roll-over carry to the *next* bit.
266 # we still ignore the "break" bit in the [intermediate] output,
267 # however by that time we've got the effect that we wanted: the
268 # carry has been carried *over* the break point.
269
270 for i in range(self.width):
271 if i in self.partition_points:
272 # add extra bit set to 0 + 0 for enabled partition points
273 # and 1 + 0 for disabled partition points
274 ea.append(self._expanded_a[expanded_index])
275 al.append(~self.partition_points[i]) # add extra bit in a
276 eb.append(self._expanded_b[expanded_index])
277 bl.append(C(0)) # yes, add a zero
278 expanded_index += 1 # skip the extra point. NOT in the output
279 ea.append(self._expanded_a[expanded_index])
280 eb.append(self._expanded_b[expanded_index])
281 eo.append(self._expanded_o[expanded_index])
282 al.append(self.a[i])
283 bl.append(self.b[i])
284 ol.append(self.output[i])
285 expanded_index += 1
286
287 # combine above using Cat
288 m.d.comb += Cat(*ea).eq(Cat(*al))
289 m.d.comb += Cat(*eb).eq(Cat(*bl))
290 m.d.comb += Cat(*ol).eq(Cat(*eo))
291
292 # use only one addition to take advantage of look-ahead carry and
293 # special hardware on FPGAs
294 m.d.comb += self._expanded_o.eq(
295 self._expanded_a + self._expanded_b)
296 return m
297
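# Illustrative sketch (hypothetical, plain Python): the "expanded add" trick
# above, for a single partition point at bit 4 of an 8-bit add. A break bit is
# spliced between the two lanes: 0 when the partition is enabled (the carry is
# absorbed and discarded), 1 when it is disabled (the carry rolls over into
# the next lane).
def _expanded_add_sketch(a, b, partition_enabled):
    lo_a, hi_a = a & 0xF, (a >> 4) & 0xF
    lo_b, hi_b = b & 0xF, (b >> 4) & 0xF
    brk = 0 if partition_enabled else 1
    exp_a = lo_a | (brk << 4) | (hi_a << 5)    # a carries the break bit
    exp_b = lo_b | (0 << 4) | (hi_b << 5)      # b gets a zero there
    exp_o = exp_a + exp_b
    return (exp_o & 0xF) | (((exp_o >> 5) & 0xF) << 4)   # break bit dropped

# _expanded_add_sketch(0x0F, 0x01, True)  == 0x00  (two independent 4-bit adds)
# _expanded_add_sketch(0x0F, 0x01, False) == 0x10  (one plain 8-bit add)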
298
299 FULL_ADDER_INPUT_COUNT = 3
300
301
302 class AddReduceSingle(Elaboratable):
303 """Add list of numbers together.
304
305 :attribute inputs: input ``Signal``s to be summed. Modification not
306 supported, except for by ``Signal.eq``.
307 :attribute register_levels: List of nesting levels that should have
308 pipeline registers.
309 :attribute output: output sum.
310 :attribute partition_points: the input partition points. Modification not
311 supported, except for by ``Signal.eq``.
312 """
313
314 def __init__(self, inputs, output_width, register_levels, partition_points,
315 part_ops):
316 """Create an ``AddReduce``.
317
318 :param inputs: input ``Signal``s to be summed.
319 :param output_width: bit-width of ``output``.
320 :param register_levels: List of nesting levels that should have
321 pipeline registers.
322 :param partition_points: the input partition points.
323 """
324 self.part_ops = part_ops
325 self.out_part_ops = [Signal(2, name=f"out_part_ops_{i}")
326 for i in range(len(part_ops))]
327 self.inputs = list(inputs)
328 self._resized_inputs = [
329 Signal(output_width, name=f"resized_inputs[{i}]")
330 for i in range(len(self.inputs))]
331 self.register_levels = list(register_levels)
332 self.output = Signal(output_width)
333 self.partition_points = PartitionPoints(partition_points)
334 if not self.partition_points.fits_in_width(output_width):
335 raise ValueError("partition_points doesn't fit in output_width")
336 self._reg_partition_points = self.partition_points.like()
337
338 max_level = AddReduceSingle.get_max_level(len(self.inputs))
339 for level in self.register_levels:
340 if level > max_level:
341 raise ValueError(
342 "not enough adder levels for specified register levels")
343
344 # this is annoying. we have to create the modules (and terms)
345 # because we need to know what they are (in order to set up the
346 # interconnects back in AddReduce), but cannot do the m.d.comb +=
347 # etc. because this code is not in elaborate().
348 self.groups = AddReduceSingle.full_adder_groups(len(self.inputs))
349 self._intermediate_terms = []
350 if len(self.groups) != 0:
351 self.create_next_terms()
352
353 @staticmethod
354 def get_max_level(input_count):
355 """Get the maximum level.
356
357 All ``register_levels`` must be less than or equal to the maximum
358 level.
359 """
360 retval = 0
361 while True:
362 groups = AddReduceSingle.full_adder_groups(input_count)
363 if len(groups) == 0:
364 return retval
365 input_count %= FULL_ADDER_INPUT_COUNT
366 input_count += 2 * len(groups)
367 retval += 1
368
369 @staticmethod
370 def full_adder_groups(input_count):
371 """Get ``inputs`` indices for which a full adder should be built."""
372 return range(0,
373 input_count - FULL_ADDER_INPUT_COUNT + 1,
374 FULL_ADDER_INPUT_COUNT)
375
376 def elaborate(self, platform):
377 """Elaborate this module."""
378 m = Module()
379
380 # resize inputs to correct bit-width and optionally add in
381 # pipeline registers
382 resized_input_assignments = [self._resized_inputs[i].eq(self.inputs[i])
383 for i in range(len(self.inputs))]
384 copy_part_ops = [self.out_part_ops[i].eq(self.part_ops[i])
385 for i in range(len(self.part_ops))]
386 if 0 in self.register_levels:
387 m.d.sync += copy_part_ops
388 m.d.sync += resized_input_assignments
389 m.d.sync += self._reg_partition_points.eq(self.partition_points)
390 else:
391 m.d.comb += copy_part_ops
392 m.d.comb += resized_input_assignments
393 m.d.comb += self._reg_partition_points.eq(self.partition_points)
394
395 for (value, term) in self._intermediate_terms:
396 m.d.comb += term.eq(value)
397
398 # if there are no full adders to create, then we handle the base cases
399 # and return, otherwise we go on to the recursive case
400 if len(self.groups) == 0:
401 if len(self.inputs) == 0:
402 # use 0 as the default output value
403 m.d.comb += self.output.eq(0)
404 elif len(self.inputs) == 1:
405 # handle single input
406 m.d.comb += self.output.eq(self._resized_inputs[0])
407 else:
408 # base case for adding 2 inputs
409 assert len(self.inputs) == 2
410 adder = PartitionedAdder(len(self.output),
411 self._reg_partition_points)
412 m.submodules.final_adder = adder
413 m.d.comb += adder.a.eq(self._resized_inputs[0])
414 m.d.comb += adder.b.eq(self._resized_inputs[1])
415 m.d.comb += self.output.eq(adder.output)
416 return m
417
418 mask = self._reg_partition_points.as_mask(len(self.output))
419 m.d.comb += self.part_mask.eq(mask)
420
421 # add and link the intermediate term modules
422 for i, (iidx, adder_i) in enumerate(self.adders):
423 setattr(m.submodules, f"adder_{i}", adder_i)
424
425 m.d.comb += adder_i.in0.eq(self._resized_inputs[iidx])
426 m.d.comb += adder_i.in1.eq(self._resized_inputs[iidx + 1])
427 m.d.comb += adder_i.in2.eq(self._resized_inputs[iidx + 2])
428 m.d.comb += adder_i.mask.eq(self.part_mask)
429
430 return m
431
432 def create_next_terms(self):
433
434 # go on to prepare recursive case
435 intermediate_terms = []
436 _intermediate_terms = []
437
438 def add_intermediate_term(value):
439 intermediate_term = Signal(
440 len(self.output),
441 name=f"intermediate_terms[{len(intermediate_terms)}]")
442 _intermediate_terms.append((value, intermediate_term))
443 intermediate_terms.append(intermediate_term)
444
445 # store mask in intermediary (simplifies graph)
446 self.part_mask = Signal(len(self.output), reset_less=True)
447
448 # create full adders for this recursive level.
449 # this shrinks N terms to 2 * (N // 3) plus the remainder
450 self.adders = []
451 for i in self.groups:
452 adder_i = MaskedFullAdder(len(self.output))
453 self.adders.append((i, adder_i))
454 # add both the sum and the masked-carry to the next level.
455 # 3 inputs have now been reduced to 2...
456 add_intermediate_term(adder_i.sum)
457 add_intermediate_term(adder_i.mcarry)
458 # handle the remaining inputs.
459 if len(self.inputs) % FULL_ADDER_INPUT_COUNT == 1:
460 add_intermediate_term(self._resized_inputs[-1])
461 elif len(self.inputs) % FULL_ADDER_INPUT_COUNT == 2:
462 # Just pass the two remaining terms to the next layer: a half adder
463 # would not reduce the term count (there would still be 2 terms),
464 # and passing them through unchanged saves gates.
465 add_intermediate_term(self._resized_inputs[-2])
466 add_intermediate_term(self._resized_inputs[-1])
467 else:
468 assert len(self.inputs) % FULL_ADDER_INPUT_COUNT == 0
469
470 self.intermediate_terms = intermediate_terms
471 self._intermediate_terms = _intermediate_terms
472
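# Illustrative sketch (hypothetical, plain Python): how the term count shrinks
# at each reduction level, mirroring get_max_level()/full_adder_groups().
# Every group of three terms becomes two (sum + masked carry); the remainder
# is passed through unchanged.
def _reduction_levels_sketch(input_count):
    levels = 0
    while input_count > 2:
        groups = input_count // FULL_ADDER_INPUT_COUNT
        input_count = input_count % FULL_ADDER_INPUT_COUNT + 2 * groups
        levels += 1
    return levels

# _reduction_levels_sketch(9) == 4   (9 -> 6 -> 4 -> 3 -> 2)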
473
474 class AddReduce(Elaboratable):
475 """Recursively Add list of numbers together.
476
477 :attribute inputs: input ``Signal``s to be summed. Modification not
478 supported, except for by ``Signal.eq``.
479 :attribute register_levels: List of nesting levels that should have
480 pipeline registers.
481 :attribute output: output sum.
482 :attribute partition_points: the input partition points. Modification not
483 supported, except for by ``Signal.eq``.
484 """
485
486 def __init__(self, inputs, output_width, register_levels, partition_points,
487 part_ops):
488 """Create an ``AddReduce``.
489
490 :param inputs: input ``Signal``s to be summed.
491 :param output_width: bit-width of ``output``.
492 :param register_levels: List of nesting levels that should have
493 pipeline registers.
494 :param partition_points: the input partition points.
495 """
496 self.inputs = inputs
497 self.part_ops = part_ops
498 self.out_part_ops = [Signal(2, name=f"out_part_ops_{i}")
499 for i in range(len(part_ops))]
500 self.output = Signal(output_width)
501 self.output_width = output_width
502 self.register_levels = register_levels
503 self.partition_points = partition_points
504
505 self.create_levels()
506
507 @staticmethod
508 def get_max_level(input_count):
509 return AddReduceSingle.get_max_level(input_count)
510
511 @staticmethod
512 def next_register_levels(register_levels):
513 """``Iterable`` of ``register_levels`` for next recursive level."""
514 for level in register_levels:
515 if level > 0:
516 yield level - 1
517
518 def create_levels(self):
519 """creates reduction levels"""
520
521 mods = []
522 next_levels = self.register_levels
523 partition_points = self.partition_points
524 inputs = self.inputs
525 part_ops = self.part_ops
526 while True:
527 next_level = AddReduceSingle(inputs, self.output_width, next_levels,
528 partition_points, part_ops)
529 mods.append(next_level)
530 if len(next_level.groups) == 0:
531 break
532 next_levels = list(AddReduce.next_register_levels(next_levels))
533 partition_points = next_level._reg_partition_points
534 inputs = next_level.intermediate_terms
535 part_ops = next_level.out_part_ops
536
537 self.levels = mods
538
539 def elaborate(self, platform):
540 """Elaborate this module."""
541 m = Module()
542
543 for i, next_level in enumerate(self.levels):
544 setattr(m.submodules, "next_level%d" % i, next_level)
545
546 # output comes from last module
547 m.d.comb += self.output.eq(next_level.output)
548 copy_part_ops = [self.out_part_ops[i].eq(next_level.out_part_ops[i])
549 for i in range(len(self.part_ops))]
550 m.d.comb += copy_part_ops
551
552 return m
553
554
555 OP_MUL_LOW = 0
556 OP_MUL_SIGNED_HIGH = 1
557 OP_MUL_SIGNED_UNSIGNED_HIGH = 2 # a is signed, b is unsigned
558 OP_MUL_UNSIGNED_HIGH = 3
559
560
561 def get_term(value, shift=0, enabled=None):
562 if enabled is not None:
563 value = Mux(enabled, value, 0)
564 if shift > 0:
565 value = Cat(Repl(C(0, 1), shift), value)
566 else:
567 assert shift == 0
568 return value
569
570
571 class ProductTerm(Elaboratable):
572 """ this class creates a single product term (a[..]*b[..]).
573 it has a design flaw in that it is the *output* that is selected,
574 while the multiplication(s) are combinatorially generated
575 all the time.
576 """
577
578 def __init__(self, width, twidth, pbwid, a_index, b_index):
579 self.a_index = a_index
580 self.b_index = b_index
581 shift = 8 * (self.a_index + self.b_index)
582 self.pwidth = width
583 self.twidth = twidth
584 self.width = width*2
585 self.shift = shift
586
587 self.ti = Signal(self.width, reset_less=True)
588 self.term = Signal(twidth, reset_less=True)
589 self.a = Signal(twidth//2, reset_less=True)
590 self.b = Signal(twidth//2, reset_less=True)
591 self.pb_en = Signal(pbwid, reset_less=True)
592
593 self.tl = tl = []
594 min_index = min(self.a_index, self.b_index)
595 max_index = max(self.a_index, self.b_index)
596 for i in range(min_index, max_index):
597 tl.append(self.pb_en[i])
598 name = "te_%d_%d" % (self.a_index, self.b_index)
599 if len(tl) > 0:
600 term_enabled = Signal(name=name, reset_less=True)
601 else:
602 term_enabled = None
603 self.enabled = term_enabled
604 self.term.name = "term_%d_%d" % (a_index, b_index) # rename
605
606 def elaborate(self, platform):
607
608 m = Module()
609 if self.enabled is not None:
610 m.d.comb += self.enabled.eq(~(Cat(*self.tl).bool()))
611
612 bsa = Signal(self.width, reset_less=True)
613 bsb = Signal(self.width, reset_less=True)
614 a_index, b_index = self.a_index, self.b_index
615 pwidth = self.pwidth
616 m.d.comb += bsa.eq(self.a.part(a_index * pwidth, pwidth))
617 m.d.comb += bsb.eq(self.b.part(b_index * pwidth, pwidth))
618 m.d.comb += self.ti.eq(bsa * bsb)
619 m.d.comb += self.term.eq(get_term(self.ti, self.shift, self.enabled))
620 """
621 #TODO: sort out width issues, get inputs a/b switched on/off.
622 #data going into Muxes is 1/2 the required width
623
624 pwidth = self.pwidth
625 width = self.width
626 bsa = Signal(self.twidth//2, reset_less=True)
627 bsb = Signal(self.twidth//2, reset_less=True)
628 asel = Signal(width, reset_less=True)
629 bsel = Signal(width, reset_less=True)
630 a_index, b_index = self.a_index, self.b_index
631 m.d.comb += asel.eq(self.a.part(a_index * pwidth, pwidth))
632 m.d.comb += bsel.eq(self.b.part(b_index * pwidth, pwidth))
633 m.d.comb += bsa.eq(get_term(asel, self.shift, self.enabled))
634 m.d.comb += bsb.eq(get_term(bsel, self.shift, self.enabled))
635 m.d.comb += self.ti.eq(bsa * bsb)
636 m.d.comb += self.term.eq(self.ti)
637 """
638
639 return m
640
641
642 class ProductTerms(Elaboratable):
643 """ creates a bank of product terms. also performs the actual bit-selection
644 this class is to be wrapped with a for-loop on the "a" operand.
645 it creates a second-level for-loop on the "b" operand.
646 """
647 def __init__(self, width, twidth, pbwid, a_index, blen):
648 self.a_index = a_index
649 self.blen = blen
650 self.pwidth = width
651 self.twidth = twidth
652 self.pbwid = pbwid
653 self.a = Signal(twidth//2, reset_less=True)
654 self.b = Signal(twidth//2, reset_less=True)
655 self.pb_en = Signal(pbwid, reset_less=True)
656 self.terms = [Signal(twidth, name="term%d"%i, reset_less=True) \
657 for i in range(blen)]
658
659 def elaborate(self, platform):
660
661 m = Module()
662
663 for b_index in range(self.blen):
664 t = ProductTerm(self.pwidth, self.twidth, self.pbwid,
665 self.a_index, b_index)
666 setattr(m.submodules, "term_%d" % b_index, t)
667
668 m.d.comb += t.a.eq(self.a)
669 m.d.comb += t.b.eq(self.b)
670 m.d.comb += t.pb_en.eq(self.pb_en)
671
672 m.d.comb += self.terms[b_index].eq(t.term)
673
674 return m
675
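# Illustrative sketch (hypothetical, plain Python): one partial product of the
# byte-wise array multiply built by ProductTerm/ProductTerms. The product of
# byte lane i of a and byte lane j of b is shifted up to bit 8*(i+j), and only
# contributes when no enabled partition point lies between the two lanes.
def _product_term_sketch(a_bytes, b_bytes, part_enables, i, j):
    crosses_break = any(part_enables[k] for k in range(min(i, j), max(i, j)))
    if crosses_break:
        return 0
    return (a_bytes[i] * b_bytes[j]) << (8 * (i + j))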
676
677 class LSBNegTerm(Elaboratable):
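    """ computes the two correction terms needed to negate ``op`` within one
        partition: ``nt`` is the width-extended 1's complement (bitwise NOT)
        of ``op`` and ``nl`` is the "+1" of the 2's complement negation,
        placed at the LSB of the HI half. Both are zero unless the partition
        is active and the associated sign and MSB inputs are set.
    """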
678
679 def __init__(self, bit_width):
680 self.bit_width = bit_width
681 self.part = Signal(reset_less=True)
682 self.signed = Signal(reset_less=True)
683 self.op = Signal(bit_width, reset_less=True)
684 self.msb = Signal(reset_less=True)
685 self.nt = Signal(bit_width*2, reset_less=True)
686 self.nl = Signal(bit_width*2, reset_less=True)
687
688 def elaborate(self, platform):
689 m = Module()
690 comb = m.d.comb
691 bit_wid = self.bit_width
692 ext = Repl(0, bit_wid) # extend output to HI part
693
694 # determine sign of each incoming number *in this partition*
695 enabled = Signal(reset_less=True)
696 m.d.comb += enabled.eq(self.part & self.msb & self.signed)
697
698 # for 8-bit values: form a * 0xFF00 by using -a * 0x100; the
699 # negation operation is split into a bitwise not and a +1.
700 # likewise for 16, 32, and 64-bit values.
701
702 # width-extended 1s complement if a is signed, otherwise zero
703 comb += self.nt.eq(Mux(enabled, Cat(ext, ~self.op), 0))
704
705 # add 1 if signed, otherwise add zero
706 comb += self.nl.eq(Cat(ext, enabled, Repl(0, bit_wid-1)))
707
708 return m
709
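# Illustrative sketch (hypothetical, plain Python): the algebra behind the nt
# and nl terms for one 8-bit lane. When b is signed and negative, a*b equals
# the unsigned product plus (~a << 8) plus (1 << 8), all modulo 2**16, i.e.
# the unsigned product plus the two correction terms generated above.
def _lsb_neg_term_sketch(a, b_signed):
    assert 0 <= a < 256 and -128 <= b_signed < 0
    b_unsigned = b_signed & 0xFF
    nt = (~a & 0xFF) << 8            # width-extended 1's complement of a
    nl = 1 << 8                      # the "+1", at the LSB of the HI half
    corrected = (a * b_unsigned + nt + nl) & 0xFFFF
    assert corrected == (a * b_signed) & 0xFFFF
    return corrected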
710
711 class Part(Elaboratable):
712 """ a key class which, depending on the partitioning, will determine
713 what action to take when parts of the output are signed or unsigned.
714
715 this requires 2 pieces of data *per operand, per partition*:
716 whether the MSB is HI/LO (per partition!), and whether a signed
717 or unsigned operation has been *requested*.
718
719 once that is determined, signed multiplication is basically carried
720 out by splitting the 2's complement negation into a 1's complement
721 plus one. 1's complement is just a bit-inversion.
722
723 the extra terms - as separate terms - are then thrown at the
724 AddReduce alongside the multiplication part-results.
725 """
726 def __init__(self, width, n_parts, n_levels, pbwid):
727
728 # inputs
729 self.a = Signal(64)
730 self.b = Signal(64)
731 self.a_signed = [Signal(name=f"a_signed_{i}") for i in range(8)]
732 self.b_signed = [Signal(name=f"_b_signed_{i}") for i in range(8)]
733 self.pbs = Signal(pbwid, reset_less=True)
734
735 # outputs
736 self.parts = [Signal(name=f"part_{i}") for i in range(n_parts)]
737 self.delayed_parts = [
738 [Signal(name=f"delayed_part_{delay}_{i}")
739 for i in range(n_parts)]
740 for delay in range(n_levels)]
741 # XXX REALLY WEIRD BUG - have to take a copy of the last delayed_parts
742 self.dplast = [Signal(name=f"dplast_{i}")
743 for i in range(n_parts)]
744
745 self.not_a_term = Signal(width)
746 self.neg_lsb_a_term = Signal(width)
747 self.not_b_term = Signal(width)
748 self.neg_lsb_b_term = Signal(width)
749
750 def elaborate(self, platform):
751 m = Module()
752
753 pbs, parts, delayed_parts = self.pbs, self.parts, self.delayed_parts
754 # negated-temporary copy of partition bits
755 npbs = Signal.like(pbs, reset_less=True)
756 m.d.comb += npbs.eq(~pbs)
757 byte_count = 8 // len(parts)
758 for i in range(len(parts)):
759 pbl = []
760 pbl.append(npbs[i * byte_count - 1])
761 for j in range(i * byte_count, (i + 1) * byte_count - 1):
762 pbl.append(pbs[j])
763 pbl.append(npbs[(i + 1) * byte_count - 1])
764 value = Signal(len(pbl), name="value_%d" % i, reset_less=True)
765 m.d.comb += value.eq(Cat(*pbl))
766 m.d.comb += parts[i].eq(~(value).bool())
767 m.d.comb += delayed_parts[0][i].eq(parts[i])
768 m.d.sync += [delayed_parts[j + 1][i].eq(delayed_parts[j][i])
769 for j in range(len(delayed_parts)-1)]
770 m.d.comb += self.dplast[i].eq(delayed_parts[-1][i])
771
772 not_a_term, neg_lsb_a_term, not_b_term, neg_lsb_b_term = \
773 self.not_a_term, self.neg_lsb_a_term, \
774 self.not_b_term, self.neg_lsb_b_term
775
776 byte_width = 8 // len(parts) # byte width
777 bit_wid = 8 * byte_width # bit width
778 nat, nbt, nla, nlb = [], [], [], []
779 for i in range(len(parts)):
780 # work out bit-inverted and +1 term for a.
781 pa = LSBNegTerm(bit_wid)
782 setattr(m.submodules, "lnt_%d_a_%d" % (bit_wid, i), pa)
783 m.d.comb += pa.part.eq(parts[i])
784 m.d.comb += pa.op.eq(self.a.part(bit_wid * i, bit_wid))
785 m.d.comb += pa.signed.eq(self.b_signed[i * byte_width]) # yes b
786 m.d.comb += pa.msb.eq(self.b[(i + 1) * bit_wid - 1]) # really, b
787 nat.append(pa.nt)
788 nla.append(pa.nl)
789
790 # work out bit-inverted and +1 term for b
791 pb = LSBNegTerm(bit_wid)
792 setattr(m.submodules, "lnt_%d_b_%d" % (bit_wid, i), pb)
793 m.d.comb += pb.part.eq(parts[i])
794 m.d.comb += pb.op.eq(self.b.part(bit_wid * i, bit_wid))
795 m.d.comb += pb.signed.eq(self.a_signed[i * byte_width]) # yes a
796 m.d.comb += pb.msb.eq(self.a[(i + 1) * bit_wid - 1]) # really, a
797 nbt.append(pb.nt)
798 nlb.append(pb.nl)
799
800 # concatenate together and return all 4 results.
801 m.d.comb += [not_a_term.eq(Cat(*nat)),
802 not_b_term.eq(Cat(*nbt)),
803 neg_lsb_a_term.eq(Cat(*nla)),
804 neg_lsb_b_term.eq(Cat(*nlb)),
805 ]
806
807 return m
808
809
810 class IntermediateOut(Elaboratable):
811 """ selects the HI/LO part of the multiplication, for a given bit-width
812 the output is also reconstructed in its SIMD (partition) lanes.
813 """
814 def __init__(self, width, out_wid, n_parts):
815 self.width = width
816 self.n_parts = n_parts
817 self.part_ops = [Signal(2, name="dpop%d" % i, reset_less=True)
818 for i in range(8)]
819 self.intermed = Signal(out_wid, reset_less=True)
820 self.output = Signal(out_wid//2, reset_less=True)
821
822 def elaborate(self, platform):
823 m = Module()
824
825 ol = []
826 w = self.width
827 sel = w // 8
828 for i in range(self.n_parts):
829 op = Signal(w, reset_less=True, name="op%d_%d" % (w, i))
830 m.d.comb += op.eq(
831 Mux(self.part_ops[sel * i] == OP_MUL_LOW,
832 self.intermed.part(i * w*2, w),
833 self.intermed.part(i * w*2 + w, w)))
834 ol.append(op)
835 m.d.comb += self.output.eq(Cat(*ol))
836
837 return m
838
839
840 class FinalOut(Elaboratable):
841 """ selects the final output based on the partitioning.
842
843 each byte is selectable independently, i.e. it is possible
844 that some partitions requested 8-bit computation whilst others
845 requested 16 or 32 bit.
846 """
847 def __init__(self, out_wid):
848 # inputs
849 self.d8 = [Signal(name=f"d8_{i}", reset_less=True) for i in range(8)]
850 self.d16 = [Signal(name=f"d16_{i}", reset_less=True) for i in range(4)]
851 self.d32 = [Signal(name=f"d32_{i}", reset_less=True) for i in range(2)]
852
853 self.i8 = Signal(out_wid, reset_less=True)
854 self.i16 = Signal(out_wid, reset_less=True)
855 self.i32 = Signal(out_wid, reset_less=True)
856 self.i64 = Signal(out_wid, reset_less=True)
857
858 # output
859 self.out = Signal(out_wid, reset_less=True)
860
861 def elaborate(self, platform):
862 m = Module()
863 ol = []
864 for i in range(8):
865 # select one of the outputs: d8 selects i8, d16 selects i16
866 # d32 selects i32, and the default is i64.
867 # d8 and d16 are ORed together in the first Mux
868 # then the 2nd selects either i8 or i16.
869 # if neither d8 nor d16 are set, d32 selects either i32 or i64.
870 op = Signal(8, reset_less=True, name="op_%d" % i)
871 m.d.comb += op.eq(
872 Mux(self.d8[i] | self.d16[i // 2],
873 Mux(self.d8[i], self.i8.part(i * 8, 8),
874 self.i16.part(i * 8, 8)),
875 Mux(self.d32[i // 4], self.i32.part(i * 8, 8),
876 self.i64.part(i * 8, 8))))
877 ol.append(op)
878 m.d.comb += self.out.eq(Cat(*ol))
879 return m
880
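# Illustrative sketch (hypothetical, plain Python): the per-byte selection done
# by FinalOut, with the four candidate bytes already extracted. d8/d16/d32 are
# the per-partition "this lane is 8/16/32-bit" flags; 64-bit is the default.
def _final_out_byte_sketch(i, d8, d16, d32, b8, b16, b32, b64):
    if d8[i] or d16[i // 2]:
        return b8 if d8[i] else b16
    return b32 if d32[i // 4] else b64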
881
882 class OrMod(Elaboratable):
883 """ ORs four values together in a hierarchical tree
884 """
885 def __init__(self, wid):
886 self.wid = wid
887 self.orin = [Signal(wid, name="orin%d" % i, reset_less=True)
888 for i in range(4)]
889 self.orout = Signal(wid, reset_less=True)
890
891 def elaborate(self, platform):
892 m = Module()
893 or1 = Signal(self.wid, reset_less=True)
894 or2 = Signal(self.wid, reset_less=True)
895 m.d.comb += or1.eq(self.orin[0] | self.orin[1])
896 m.d.comb += or2.eq(self.orin[2] | self.orin[3])
897 m.d.comb += self.orout.eq(or1 | or2)
898
899 return m
900
901
902 class Signs(Elaboratable):
903 """ determines whether a or b are signed numbers
904 based on the required operation type (OP_MUL_*)
905 """
906
907 def __init__(self):
908 self.part_ops = Signal(2, reset_less=True)
909 self.a_signed = Signal(reset_less=True)
910 self.b_signed = Signal(reset_less=True)
911
912 def elaborate(self, platform):
913
914 m = Module()
915
916 asig = self.part_ops != OP_MUL_UNSIGNED_HIGH
917 bsig = (self.part_ops == OP_MUL_LOW) \
918 | (self.part_ops == OP_MUL_SIGNED_HIGH)
919 m.d.comb += self.a_signed.eq(asig)
920 m.d.comb += self.b_signed.eq(bsig)
921
922 return m
923
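# Illustrative sketch (hypothetical, plain Python): the signedness decode
# performed by Signs, matching RISC-V mul/mulh/mulhsu/mulhu.
def _signs_decode_sketch(op):
    a_signed = op != OP_MUL_UNSIGNED_HIGH
    b_signed = op in (OP_MUL_LOW, OP_MUL_SIGNED_HIGH)
    return (a_signed, b_signed)

# _signs_decode_sketch(OP_MUL_SIGNED_UNSIGNED_HIGH) == (True, False)  # mulhsu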
924
925 class Mul8_16_32_64(Elaboratable):
926 """Signed/Unsigned 8/16/32/64-bit partitioned integer multiplier.
927
928 Supports partitioning into any combination of 8, 16, 32, and 64-bit
929 partitions on naturally-aligned boundaries. Supports the operation being
930 set for each partition independently.
931
932 :attribute part_pts: the input partition points. Has a partition point at
933 multiples of 8 in 0 < i < 64. Each partition point's associated
934 ``Value`` is a ``Signal``. Modification not supported, except for by
935 ``Signal.eq``.
936 :attribute part_ops: the operation for each byte. The operation for a
937 particular partition is selected by assigning the selected operation
938 code to each byte in the partition. The allowed operation codes are:
939
940 :attribute OP_MUL_LOW: the LSB half of the product. Equivalent to
941 RISC-V's `mul` instruction.
942 :attribute OP_MUL_SIGNED_HIGH: the MSB half of the product where both
943 ``a`` and ``b`` are signed. Equivalent to RISC-V's `mulh`
944 instruction.
945 :attribute OP_MUL_SIGNED_UNSIGNED_HIGH: the MSB half of the product
946 where ``a`` is signed and ``b`` is unsigned. Equivalent to RISC-V's
947 `mulhsu` instruction.
948 :attribute OP_MUL_UNSIGNED_HIGH: the MSB half of the product where both
949 ``a`` and ``b`` are unsigned. Equivalent to RISC-V's `mulhu`
950 instruction.
951 """
952
953 def __init__(self, register_levels=()):
954 """ register_levels: specifies the points in the cascade at which
955 flip-flops are to be inserted.
956 """
957
958 # parameter(s)
959 self.register_levels = list(register_levels)
960
961 # inputs
962 self.part_pts = PartitionPoints()
963 for i in range(8, 64, 8):
964 self.part_pts[i] = Signal(name=f"part_pts_{i}")
965 self.part_ops = [Signal(2, name=f"part_ops_{i}") for i in range(8)]
966 self.a = Signal(64)
967 self.b = Signal(64)
968
969 # intermediates (needed for unit tests)
970 self._intermediate_output = Signal(128)
971
972 # output
973 self.output = Signal(64)
974
975 def elaborate(self, platform):
976 m = Module()
977
978 # collect part-bytes
979 pbs = Signal(8, reset_less=True)
980 tl = []
981 for i in range(8):
982 pb = Signal(name="pb%d" % i, reset_less=True)
983 m.d.comb += pb.eq(self.part_pts.part_byte(i))
984 tl.append(pb)
985 m.d.comb += pbs.eq(Cat(*tl))
986
987 # create (doubled) PartitionPoints (output is double input width)
988 expanded_part_pts = PartitionPoints()
989 for i, v in self.part_pts.items():
990 ep = Signal(name=f"expanded_part_pts_{i*2}", reset_less=True)
991 expanded_part_pts[i * 2] = ep
992 m.d.comb += ep.eq(v)
993
994 # local variables
995 signs = []
996 for i in range(8):
997 s = Signs()
998 signs.append(s)
999 setattr(m.submodules, "signs%d" % i, s)
1000 m.d.comb += s.part_ops.eq(self.part_ops[i])
1001
1002 n_levels = len(self.register_levels)+1
1003 m.submodules.part_8 = part_8 = Part(128, 8, n_levels, 8)
1004 m.submodules.part_16 = part_16 = Part(128, 4, n_levels, 8)
1005 m.submodules.part_32 = part_32 = Part(128, 2, n_levels, 8)
1006 m.submodules.part_64 = part_64 = Part(128, 1, n_levels, 8)
1007 nat_l, nbt_l, nla_l, nlb_l = [], [], [], []
1008 for mod in [part_8, part_16, part_32, part_64]:
1009 m.d.comb += mod.a.eq(self.a)
1010 m.d.comb += mod.b.eq(self.b)
1011 for i in range(len(signs)):
1012 m.d.comb += mod.a_signed[i].eq(signs[i].a_signed)
1013 m.d.comb += mod.b_signed[i].eq(signs[i].b_signed)
1014 m.d.comb += mod.pbs.eq(pbs)
1015 nat_l.append(mod.not_a_term)
1016 nbt_l.append(mod.not_b_term)
1017 nla_l.append(mod.neg_lsb_a_term)
1018 nlb_l.append(mod.neg_lsb_b_term)
1019
1020 terms = []
1021
1022 for a_index in range(8):
1023 t = ProductTerms(8, 128, 8, a_index, 8)
1024 setattr(m.submodules, "terms_%d" % a_index, t)
1025
1026 m.d.comb += t.a.eq(self.a)
1027 m.d.comb += t.b.eq(self.b)
1028 m.d.comb += t.pb_en.eq(pbs)
1029
1030 for term in t.terms:
1031 terms.append(term)
1032
1033 # it's fine to bitwise-or data together since they are never enabled
1034 # at the same time
1035 m.submodules.nat_or = nat_or = OrMod(128)
1036 m.submodules.nbt_or = nbt_or = OrMod(128)
1037 m.submodules.nla_or = nla_or = OrMod(128)
1038 m.submodules.nlb_or = nlb_or = OrMod(128)
1039 for l, mod in [(nat_l, nat_or),
1040 (nbt_l, nbt_or),
1041 (nla_l, nla_or),
1042 (nlb_l, nlb_or)]:
1043 for i in range(len(l)):
1044 m.d.comb += mod.orin[i].eq(l[i])
1045 terms.append(mod.orout)
1046
1047 add_reduce = AddReduce(terms,
1048 128,
1049 self.register_levels,
1050 expanded_part_pts,
1051 self.part_ops)
1052
1053 out_part_ops = add_reduce.levels[-1].out_part_ops
1054
1055 m.submodules.add_reduce = add_reduce
1056 m.d.comb += self._intermediate_output.eq(add_reduce.output)
1057 # create _output_64
1058 m.submodules.io64 = io64 = IntermediateOut(64, 128, 1)
1059 m.d.comb += io64.intermed.eq(self._intermediate_output)
1060 for i in range(8):
1061 m.d.comb += io64.part_ops[i].eq(out_part_ops[i])
1062
1063 # create _output_32
1064 m.submodules.io32 = io32 = IntermediateOut(32, 128, 2)
1065 m.d.comb += io32.intermed.eq(self._intermediate_output)
1066 for i in range(8):
1067 m.d.comb += io32.part_ops[i].eq(out_part_ops[i])
1068
1069 # create _output_16
1070 m.submodules.io16 = io16 = IntermediateOut(16, 128, 4)
1071 m.d.comb += io16.intermed.eq(self._intermediate_output)
1072 for i in range(8):
1073 m.d.comb += io16.part_ops[i].eq(out_part_ops[i])
1074
1075 # create _output_8
1076 m.submodules.io8 = io8 = IntermediateOut(8, 128, 8)
1077 m.d.comb += io8.intermed.eq(self._intermediate_output)
1078 for i in range(8):
1079 m.d.comb += io8.part_ops[i].eq(out_part_ops[i])
1080
1081 # final output
1082 m.submodules.finalout = finalout = FinalOut(64)
1083 for i in range(len(part_8.delayed_parts[-1])):
1084 m.d.comb += finalout.d8[i].eq(part_8.dplast[i])
1085 for i in range(len(part_16.delayed_parts[-1])):
1086 m.d.comb += finalout.d16[i].eq(part_16.dplast[i])
1087 for i in range(len(part_32.delayed_parts[-1])):
1088 m.d.comb += finalout.d32[i].eq(part_32.dplast[i])
1089 m.d.comb += finalout.i8.eq(io8.output)
1090 m.d.comb += finalout.i16.eq(io16.output)
1091 m.d.comb += finalout.i32.eq(io32.output)
1092 m.d.comb += finalout.i64.eq(io64.output)
1093 m.d.comb += self.output.eq(finalout.out)
1094
1095 return m
1096
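# Illustrative usage sketch (hypothetical, not exercised by the __main__ block
# below): wiring the multiplier up as two independent 32-bit signed-high
# (mulh-style) lanes. Only the partition point at bit 32 is enabled, and every
# byte of both lanes is given the same operation code.
def _mul_32x2_sketch():
    m = Module()
    m.submodules.mul = mul = Mul8_16_32_64()
    for point, enable in mul.part_pts.items():
        m.d.comb += enable.eq(1 if point == 32 else 0)
    for op in mul.part_ops:
        m.d.comb += op.eq(OP_MUL_SIGNED_HIGH)
    return m, mul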
1097
1098 if __name__ == "__main__":
1099 m = Mul8_16_32_64()
1100 main(m, ports=[m.a,
1101 m.b,
1102 m._intermediate_output,
1103 m.output,
1104 *m.part_ops,
1105 *m.part_pts.values()])