use part_ops not out_part_ops
[ieee754fpu.git] / src / ieee754 / part_mul_add / multiply.py
1 # SPDX-License-Identifier: LGPL-2.1-or-later
2 # See Notices.txt for copyright information
3 """Integer Multiplication."""
4
5 from nmigen import Signal, Module, Value, Elaboratable, Cat, C, Mux, Repl
6 from nmigen.hdl.ast import Assign
7 from abc import ABCMeta, abstractmethod
8 from nmigen.cli import main
9 from functools import reduce
10 from operator import or_
11
12
class PartitionPoints(dict):
    """Dictionary of partition points and their enabling ``Value``s.

    Every key is a bit index at which an ALU may be split; the value stored
    under that key states whether the split at that index is enabled.

    For example: ``{1: True, 5: True, 10: True}`` with
    ``width == 16`` specifies that the ALU is split into 4 sections:
    * bits 0 <= ``i`` < 1
    * bits 1 <= ``i`` < 5
    * bits 5 <= ``i`` < 10
    * bits 10 <= ``i`` < 16

    If the partition_points were instead ``{1: True, 5: a, 10: True}``
    where ``a`` is a 1-bit ``Signal``:
    * If ``a`` is asserted:
        * bits 0 <= ``i`` < 1
        * bits 1 <= ``i`` < 5
        * bits 5 <= ``i`` < 10
        * bits 10 <= ``i`` < 16
    * Otherwise
        * bits 0 <= ``i`` < 1
        * bits 1 <= ``i`` < 10
        * bits 10 <= ``i`` < 16
    """

    def __init__(self, partition_points=None):
        """Create a new ``PartitionPoints``.

        :param partition_points: optional mapping of point (a non-negative
            ``int``) to its enable value; each value is passed through
            ``Value.wrap``.
        :raises TypeError: if a point is not an ``int``.
        :raises ValueError: if a point is negative.
        """
        super().__init__()
        if partition_points is None:
            return
        for point, enabled in partition_points.items():
            if not isinstance(point, int):
                raise TypeError("point must be a non-negative integer")
            if point < 0:
                raise ValueError("point must be a non-negative integer")
            self[point] = Value.wrap(enabled)

    def like(self, name=None, src_loc_at=0, mul=1):
        """Build a ``PartitionPoints`` holding a fresh ``Signal`` per point.

        :param name: the base name for the new ``Signal``s.  When omitted,
            the name is taken from a temporary ``Signal`` created with
            ``src_loc_at=1+src_loc_at``.
        :param src_loc_at: extra stack offset used for that name lookup.
        :param mul: a multiplication factor on the indices
        :returns: a new ``PartitionPoints`` keyed by ``point * mul``.
        """
        if name is None:
            # must stay directly in this method: the lookup is stack-relative
            name = Signal(src_loc_at=1+src_loc_at).name  # get variable name
        clone = PartitionPoints()
        for index, enabled in self.items():
            point = index * mul
            clone[point] = Signal(enabled.shape(), name=f"{name}_{point}")
        return clone

    def eq(self, rhs):
        """Yield one ``.eq`` assignment per point, taking values from `rhs`.

        This is a generator, so the key-set check (and its ``ValueError``)
        only happens once iteration starts.

        :param rhs: mapping with exactly the same set of points as `self`.
        :raises ValueError: if the two point sets differ.
        """
        if set(self) != set(rhs.keys()):
            raise ValueError("incompatible point set")
        yield from (enabled.eq(rhs[point])
                    for point, enabled in self.items())

    def as_mask(self, width):
        """Create a bit-mask from `self`.

        Bit ``i`` of the result is the inverse of the enable value stored at
        point ``i``; bits with no partition point are constant ``True``.

        :param width: the bit width of the resulting mask
        """
        bits = [~self[i] if i in self else True for i in range(width)]
        return Cat(*bits)

    def get_max_partition_count(self, width):
        """Get the maximum number of partitions.

        This is one more than the number of points lying below `width`,
        i.e. the partition count when all of those points are enabled.
        """
        return 1 + sum(1 for point in self if point < width)

    def fits_in_width(self, width):
        """Return ``True`` when no partition point is at or above `width`."""
        return not any(point >= width for point in self)

    def part_byte(self, index, mfactor=1):  # mfactor used for "expanding"
        """Return the enable value for the boundary above byte `index`.

        Indices -1 and 7 yield a constant 1-bit ``True``; any other index
        must be in 0..7 and looks up point ``(index * 8 + 8) * mfactor``.
        """
        if index in (-1, 7):
            return C(True, 1)
        assert 0 <= index < 8
        return self[(index * 8 + 8)*mfactor]
113
114
class FullAdder(Elaboratable):
    """Bit-parallel full (3-to-2) adder.

    :attribute in0: the first input
    :attribute in1: the second input
    :attribute in2: the third input
    :attribute sum: the sum output
    :attribute carry: the carry output

    Instead of instantiating an array of 1-bit full adders (which would be
    very slow to simulate), the inputs and outputs are `width` bits wide and
    the 3-2 add is performed bitwise on all bits "in parallel".
    """

    def __init__(self, width):
        """Create a ``FullAdder``.

        :param width: the bit width of the input and output
        """
        self.in0 = Signal(width)
        self.in1 = Signal(width)
        self.in2 = Signal(width)
        self.sum = Signal(width)
        self.carry = Signal(width)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        x, y, z = self.in0, self.in1, self.in2
        # carry is the bitwise majority of the three inputs
        majority = (x & y) | (y & z) | (z & x)
        m.d.comb += [
            self.sum.eq(x ^ y ^ z),
            self.carry.eq(majority),
        ]
        return m
149
150
class MaskedFullAdder(Elaboratable):
    """Full adder whose carry output is shifted and masked.

    :attribute mask: the carry partition mask
    :attribute in0: the first input
    :attribute in1: the second input
    :attribute in2: the third input
    :attribute sum: the sum output
    :attribute mcarry: the masked carry output

    FullAdders are always used with a "mask" on the output, so the masking
    is done inside this class (rather than in a large for-loop elsewhere),
    which keeps the graphviz output "clean".

    This class is deliberately not derived from FullAdder: each operand is
    shifted *before* being ANDed with the mask, so that an AOI cell may be
    used (which is more gate-efficient).  See:
    https://en.wikipedia.org/wiki/AND-OR-Invert
    https://groups.google.com/d/msg/comp.arch/fcq-GLQqvas/vTxmcA0QAgAJ
    """

    def __init__(self, width):
        """Create a ``MaskedFullAdder``.

        :param width: the bit width of the input and output
        """
        self.width = width
        self.mask = Signal(width, reset_less=True)
        self.mcarry = Signal(width, reset_less=True)
        self.in0 = Signal(width, reset_less=True)
        self.in1 = Signal(width, reset_less=True)
        self.in2 = Signal(width, reset_less=True)
        self.sum = Signal(width, reset_less=True)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        comb = m.d.comb
        wid = self.width
        # shifted copies of the three operands
        s1 = Signal(wid, reset_less=True)
        s2 = Signal(wid, reset_less=True)
        s3 = Signal(wid, reset_less=True)
        # pairwise masked carry terms
        c1 = Signal(wid, reset_less=True)
        c2 = Signal(wid, reset_less=True)
        c3 = Signal(wid, reset_less=True)
        comb += self.sum.eq(self.in0 ^ self.in1 ^ self.in2)
        # Cat(0, x) prepends a zero bit, i.e. moves every bit up one place
        comb += [
            s1.eq(Cat(0, self.in0)),
            s2.eq(Cat(0, self.in1)),
            s3.eq(Cat(0, self.in2)),
        ]
        comb += [
            c1.eq(s1 & s2 & self.mask),
            c2.eq(s2 & s3 & self.mask),
            c3.eq(s3 & s1 & self.mask),
        ]
        comb += self.mcarry.eq(c1 | c2 | c3)
        return m
204
205
class PartitionedAdder(Elaboratable):
    """Partitioned Adder.

    Performs the final add.  An extra bit is inserted into both operands at
    every partition point; in one operand only, that bit carries the
    (inverted) partition enable, so that a disabled partition point lets the
    carry roll over into the next bit.  The extra bits are *removed* again
    when the final output is assembled.

    partition: .... P... P... P... P... (32 bits)
    a        : .... .... .... .... .... (32 bits)
    b        : .... .... .... .... .... (32 bits)
    exp-a    : ....P....P....P....P.... (32+4 bits, P=1 if no partition)
    exp-b    : ....0....0....0....0.... (32 bits plus 4 zeros)
    exp-o    : ....xN...xN...xN...xN... (32+4 bits - x to be discarded)
    o        : .... N... N... N... N... (32 bits - x ignored, N is carry-over)

    :attribute width: the bit width of the input and output. Read-only.
    :attribute a: the first input to the adder
    :attribute b: the second input to the adder
    :attribute output: the sum output
    :attribute partition_points: the input partition points. Modification not
        supported, except for by ``Signal.eq``.
    """

    def __init__(self, width, partition_points):
        """Create a ``PartitionedAdder``.

        :param width: the bit width of the input and output
        :param partition_points: the input partition points
        :raises ValueError: if a partition point does not fit in `width`.
        """
        self.width = width
        self.a = Signal(width)
        self.b = Signal(width)
        self.output = Signal(width)
        self.partition_points = PartitionPoints(partition_points)
        if not self.partition_points.fits_in_width(width):
            raise ValueError("partition_points doesn't fit in width")
        # one bit per data bit, plus one more wherever a partition point sits
        expanded_width = sum(2 if i in self.partition_points else 1
                             for i in range(self.width))
        self._expanded_width = expanded_width
        # XXX these have to remain here due to some horrible nmigen
        # simulation bugs involving sync.  it is *not* necessary to
        # have them here, they should (under normal circumstances)
        # be moved into elaborate, as they are entirely local
        self._expanded_a = Signal(expanded_width)  # includes extra part-points
        self._expanded_b = Signal(expanded_width)  # likewise.
        self._expanded_o = Signal(expanded_width)  # likewise.

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        ppoints = self.partition_points
        exp_a, exp_b, exp_o = (self._expanded_a, self._expanded_b,
                               self._expanded_o)
        # store bits in a list, use Cat later.  graphviz is much cleaner
        al, bl, ol = [], [], []
        ea, eb, eo = [], [], []

        # partition points are "breaks" (extra zeros or 1s) in what would
        # otherwise be a massive long add.  when the "break" points are 0,
        # whatever is in it (in the output) is discarded.  however when
        # there is a "1", it causes a roll-over carry to the *next* bit.
        # we still ignore the "break" bit in the [intermediate] output,
        # however by that time we've got the effect that we wanted: the
        # carry has been carried *over* the break point.

        pos = 0  # running index into the expanded signals
        for i in range(self.width):
            if i in ppoints:
                # add extra bit set to 0 + 0 for enabled partition points
                # and 1 + 0 for disabled partition points
                ea.append(exp_a[pos])
                al.append(~ppoints[i])  # add extra bit in a
                eb.append(exp_b[pos])
                bl.append(C(0))  # yes, add a zero
                pos += 1  # skip the extra point.  NOT in the output
            # the ordinary data bit
            ea.append(exp_a[pos])
            eb.append(exp_b[pos])
            eo.append(exp_o[pos])
            al.append(self.a[i])
            bl.append(self.b[i])
            ol.append(self.output[i])
            pos += 1

        # combine above using Cat
        m.d.comb += [
            Cat(*ea).eq(Cat(*al)),
            Cat(*eb).eq(Cat(*bl)),
            Cat(*ol).eq(Cat(*eo)),
        ]

        # use only one addition to take advantage of look-ahead carry and
        # special hardware on FPGAs
        m.d.comb += exp_o.eq(exp_a + exp_b)
        return m
299
300
# number of inputs consumed by each full-adder group; used as the step size
# when grouping inputs in AddReduceSingle.full_adder_groups
FULL_ADDER_INPUT_COUNT = 3
302
class AddReduceData:
    """Bundle of signals fed into one add-reduce level.

    :attribute part_ops: `n_parts` 2-bit ``Signal``s.
    :attribute inputs: `n_inputs` ``Signal``s, each `output_width` bits wide.
    :attribute reg_partition_points: ``Signal`` copy of the partition points,
        created with ``ppoints.like()``.
    """

    def __init__(self, ppoints, n_inputs, output_width, n_parts):
        """Create the signals.

        :param ppoints: partition points to mirror (must provide ``like()``).
        :param n_inputs: number of input terms.
        :param output_width: bit width of each input term.
        :param n_parts: number of ``part_ops`` entries.
        """
        self.part_ops = [Signal(2, name=f"part_ops_{i}")
                         for i in range(n_parts)]
        self.inputs = [Signal(output_width, name=f"inputs[{i}]")
                       for i in range(n_inputs)]
        # must remain a direct attribute assignment: like() derives the
        # signal base name from this statement
        self.reg_partition_points = ppoints.like()

    def eq(self, rhs):
        """Return the list of assignments copying `rhs` into `self`.

        The first list entry is the (generator) result of
        ``PartitionPoints.eq``; the remaining entries are one ``.eq`` per
        input followed by one per part_op.
        """
        assigns = [self.reg_partition_points.eq(rhs.reg_partition_points)]
        assigns += [term.eq(rhs.inputs[i])
                    for i, term in enumerate(self.inputs)]
        assigns += [op.eq(rhs.part_ops[i])
                    for i, op in enumerate(self.part_ops)]
        return assigns
318
319
class FinalAdd(Elaboratable):
    """Final stage of add reduce.

    Handles the last 0, 1 or 2 terms: outputs zero, the single term, or the
    ``PartitionedAdder`` sum of the two terms respectively.

    :attribute i: the ``AddReduceData`` input bundle.
    :attribute output: the final sum.
    :attribute intermediate_terms: always an empty list.
    """

    def __init__(self, n_inputs, output_width, n_parts, register_levels,
                 partition_points):
        """Create a ``FinalAdd``.

        :param n_inputs: number of remaining terms (0, 1 or 2).
        :param output_width: bit-width of ``output``.
        :param n_parts: number of ``part_ops`` entries.
        :param register_levels: List of nesting levels that should have
            pipeline registers.
        :param partition_points: the input partition points.
        :raises ValueError: if a partition point does not fit in
            `output_width`.
        """
        self.i = AddReduceData(partition_points, n_inputs,
                               output_width, n_parts)
        self.n_inputs = n_inputs
        self.n_parts = n_parts
        self._resized_inputs = self.i.inputs
        self.register_levels = list(register_levels)
        self.output = Signal(output_width)
        self.partition_points = PartitionPoints(partition_points)
        if not self.partition_points.fits_in_width(output_width):
            raise ValueError("partition_points doesn't fit in output_width")
        self.intermediate_terms = []

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        comb = m.d.comb
        terms = self._resized_inputs

        if self.n_inputs == 0:
            # use 0 as the default output value
            comb += self.output.eq(0)
            return m

        if self.n_inputs == 1:
            # handle single input
            comb += self.output.eq(terms[0])
            return m

        # base case for adding 2 inputs
        assert self.n_inputs == 2
        adder = PartitionedAdder(len(self.output),
                                 self.i.reg_partition_points)
        m.submodules.final_adder = adder
        comb += [
            adder.a.eq(terms[0]),
            adder.b.eq(terms[1]),
            self.output.eq(adder.output),
        ]
        return m
358
359
class AddReduceSingle(Elaboratable):
    """One 3:2 reduction level of the add-reduce tree.

    Groups of three inputs are each fed to a ``MaskedFullAdder``; the
    resulting sum and masked-carry, together with any left-over inputs,
    become this level's ``intermediate_terms``.

    :attribute i: the ``AddReduceData`` input bundle for this level.
    :attribute register_levels: List of nesting levels that should have
        pipeline registers.
    :attribute intermediate_terms: the ``Signal``s produced by this level.
        Always present; with fewer than three inputs the level has no
        adders and the terms are copies of the inputs.
    :attribute partition_points: the input partition points. Modification not
        supported, except for by ``Signal.eq``.
    """

    def __init__(self, n_inputs, output_width, n_parts, register_levels,
                 partition_points):
        """Create an ``AddReduceSingle``.

        :param n_inputs: number of input terms to be reduced.
        :param output_width: bit-width of every term.
        :param n_parts: number of ``part_ops`` entries.
        :param register_levels: List of nesting levels that should have
            pipeline registers.
        :param partition_points: the input partition points.
        :raises ValueError: if a partition point does not fit in
            `output_width`, or a register level exceeds the maximum level.
        """
        self.n_inputs = n_inputs
        self.n_parts = n_parts
        self.output_width = output_width
        self.i = AddReduceData(partition_points, n_inputs,
                               output_width, n_parts)
        self._resized_inputs = self.i.inputs
        self.register_levels = list(register_levels)
        self.partition_points = PartitionPoints(partition_points)
        if not self.partition_points.fits_in_width(output_width):
            raise ValueError("partition_points doesn't fit in output_width")

        max_level = AddReduceSingle.get_max_level(n_inputs)
        for level in self.register_levels:
            if level > max_level:
                raise ValueError(
                    "not enough adder levels for specified register levels")

        # this is annoying.  we have to create the modules (and terms)
        # because we need to know what they are (in order to set up the
        # interconnects back in AddReduce), but cannot do the m.d.comb +=
        # etc because this is not in elaboratable.
        self.groups = AddReduceSingle.full_adder_groups(n_inputs)
        # Always build the terms, even when there is no full-adder group:
        # skipping this left ``intermediate_terms``, ``part_mask`` and
        # ``adders`` undefined, so ``elaborate`` (and any caller reading
        # ``intermediate_terms``) failed with AttributeError for fewer than
        # three inputs.  With no groups the inputs are simply passed through.
        self.create_next_terms()

    @staticmethod
    def get_max_level(input_count):
        """Get the maximum level.

        Counts how many successive 3:2 reduction steps can be applied to
        `input_count` terms before no full-adder group can be formed.

        All ``register_levels`` must be less than or equal to the maximum
        level.
        """
        retval = 0
        while True:
            groups = AddReduceSingle.full_adder_groups(input_count)
            if len(groups) == 0:
                return retval
            # left-over inputs plus two outputs (sum, carry) per group
            input_count %= FULL_ADDER_INPUT_COUNT
            input_count += 2 * len(groups)
            retval += 1

    @staticmethod
    def full_adder_groups(input_count):
        """Get ``inputs`` indices for which a full adder should be built.

        :returns: a ``range`` of the starting index of every complete group
            of ``FULL_ADDER_INPUT_COUNT`` inputs (empty if there is none).
        """
        return range(0,
                     input_count - FULL_ADDER_INPUT_COUNT + 1,
                     FULL_ADDER_INPUT_COUNT)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()

        # drive each intermediate term from the value recorded for it
        for (value, term) in self._intermediate_terms:
            m.d.comb += term.eq(value)

        mask = self.i.reg_partition_points.as_mask(self.output_width)
        m.d.comb += self.part_mask.eq(mask)

        # add and link the intermediate term modules
        for i, (iidx, adder_i) in enumerate(self.adders):
            setattr(m.submodules, f"adder_{i}", adder_i)

            m.d.comb += adder_i.in0.eq(self._resized_inputs[iidx])
            m.d.comb += adder_i.in1.eq(self._resized_inputs[iidx + 1])
            m.d.comb += adder_i.in2.eq(self._resized_inputs[iidx + 2])
            m.d.comb += adder_i.mask.eq(self.part_mask)

        return m

    def create_next_terms(self):
        """Create the adders and the intermediate-term ``Signal``s.

        Sets ``part_mask``, ``adders``, ``intermediate_terms`` and the
        private ``_intermediate_terms`` list of ``(value, term)`` pairs that
        ``elaborate`` later connects.
        """
        # go on to prepare recursive case
        intermediate_terms = []
        _intermediate_terms = []

        def add_intermediate_term(value):
            intermediate_term = Signal(
                self.output_width,
                name=f"intermediate_terms[{len(intermediate_terms)}]")
            _intermediate_terms.append((value, intermediate_term))
            intermediate_terms.append(intermediate_term)

        # store mask in intermediary (simplifies graph)
        self.part_mask = Signal(self.output_width, reset_less=True)

        # create full adders for this recursive level.
        # this shrinks N terms to 2 * (N // 3) plus the remainder
        self.adders = []
        for i in self.groups:
            adder_i = MaskedFullAdder(self.output_width)
            self.adders.append((i, adder_i))
            # add both the sum and the masked-carry to the next level.
            # 3 inputs have now been reduced to 2...
            add_intermediate_term(adder_i.sum)
            add_intermediate_term(adder_i.mcarry)
        # handle the remaining inputs.
        remainder = self.n_inputs % FULL_ADDER_INPUT_COUNT
        if remainder == 1:
            add_intermediate_term(self._resized_inputs[-1])
        elif remainder == 2:
            # Just pass the terms to the next layer, since we wouldn't gain
            # anything by using a half adder since there would still be 2 terms
            # and just passing the terms to the next layer saves gates.
            add_intermediate_term(self._resized_inputs[-2])
            add_intermediate_term(self._resized_inputs[-1])
        else:
            assert remainder == 0

        self.intermediate_terms = intermediate_terms
        self._intermediate_terms = _intermediate_terms
492
493
class AddReduce(Elaboratable):
    """Add a list of numbers together using a chain of reduction levels.

    The chain consists of zero or more ``AddReduceSingle`` (3:2) levels
    followed by exactly one ``FinalAdd``.

    :attribute inputs: input ``Signal``s to be summed. Modification not
        supported, except for by ``Signal.eq``.
    :attribute register_levels: List of nesting levels that should have
        pipeline registers.
    :attribute output: output sum.
    :attribute out_part_ops: copy of the last level's ``part_ops``.
    :attribute levels: the list of level modules, in order.
    :attribute partition_points: the input partition points. Modification not
        supported, except for by ``Signal.eq``.
    """

    def __init__(self, inputs, output_width, register_levels, partition_points,
                 part_ops):
        """Create an ``AddReduce``.

        :param inputs: input ``Signal``s to be summed.
        :param output_width: bit-width of ``output``.
        :param register_levels: List of nesting levels that should have
            pipeline registers.
        :param partition_points: the input partition points.
        :param part_ops: list of per-partition operation ``Signal``s that is
            passed along the levels.
        """
        self.inputs = inputs
        self.part_ops = part_ops
        self.out_part_ops = [Signal(2, name=f"out_part_ops_{i}")
                             for i in range(len(part_ops))]
        self.output = Signal(output_width)
        self.output_width = output_width
        self.register_levels = register_levels
        self.partition_points = partition_points

        self.create_levels()

    @staticmethod
    def get_max_level(input_count):
        """Get the maximum level (see ``AddReduceSingle.get_max_level``)."""
        return AddReduceSingle.get_max_level(input_count)

    @staticmethod
    def next_register_levels(register_levels):
        """``Iterable`` of ``register_levels`` for next recursive level."""
        for level in register_levels:
            if level > 0:
                yield level - 1

    def create_levels(self):
        """Create the reduction levels and store them in ``self.levels``.

        A 3:2 level is only created while at least one full-adder group can
        be formed from the current terms.  Previously the first
        ``AddReduceSingle`` was created unconditionally and its
        ``intermediate_terms`` read back, which fails for fewer than three
        inputs.  The chain is always terminated by a ``FinalAdd``, which
        handles 0, 1 or 2 remaining terms.
        """
        mods = []
        next_levels = self.register_levels
        partition_points = self.partition_points
        inputs = self.inputs
        n_parts = len(self.part_ops)

        while len(AddReduceSingle.full_adder_groups(len(inputs))) != 0:
            level = AddReduceSingle(len(inputs), self.output_width, n_parts,
                                    next_levels, partition_points)
            mods.append(level)
            next_levels = list(AddReduce.next_register_levels(next_levels))
            partition_points = level.i.reg_partition_points
            inputs = level.intermediate_terms

        mods.append(FinalAdd(len(inputs), self.output_width, n_parts,
                             next_levels, partition_points))

        self.levels = mods

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()

        for i, level in enumerate(self.levels):
            setattr(m.submodules, "next_level%d" % i, level)

        # chain the levels: each one is fed from the previous one's results
        partition_points = self.partition_points
        inputs = self.inputs
        part_ops = self.part_ops
        for mcur in self.levels:
            inassign = [mcur._resized_inputs[j].eq(term)
                        for j, term in enumerate(inputs)]
            copy_part_ops = [mcur.i.part_ops[j].eq(op)
                             for j, op in enumerate(part_ops)]
            # a level listed at register level 0 is fed through registers
            if 0 in mcur.register_levels:
                m.d.sync += copy_part_ops
                m.d.sync += inassign
                m.d.sync += mcur.i.reg_partition_points.eq(partition_points)
            else:
                m.d.comb += copy_part_ops
                m.d.comb += inassign
                m.d.comb += mcur.i.reg_partition_points.eq(partition_points)
            partition_points = mcur.i.reg_partition_points
            inputs = mcur.intermediate_terms
            part_ops = mcur.i.part_ops

        # output comes from last module (named explicitly rather than via a
        # loop variable left over from an earlier for-loop)
        last_level = self.levels[-1]
        m.d.comb += self.output.eq(last_level.output)
        copy_part_ops = [self.out_part_ops[j].eq(last_level.i.part_ops[j])
                         for j in range(len(self.part_ops))]
        m.d.comb += copy_part_ops

        return m
603
604
# 2-bit operation codes carried in the ``part_ops`` signals.
# OP_MUL_LOW selects the low half of the product in IntermediateOut; every
# other code selects the high half.  Signs derives the operand signedness:
# a is signed unless the op is OP_MUL_UNSIGNED_HIGH, and b is signed only
# for OP_MUL_LOW and OP_MUL_SIGNED_HIGH.
OP_MUL_LOW = 0
OP_MUL_SIGNED_HIGH = 1
OP_MUL_SIGNED_UNSIGNED_HIGH = 2  # a is signed, b is unsigned
OP_MUL_UNSIGNED_HIGH = 3
609
610
def get_term(value, shift=0, enabled=None):
    """Optionally gate `value` and shift it left by `shift` bits.

    :param value: the value to transform.
    :param shift: non-negative number of zero bits to prepend (via ``Cat``).
    :param enabled: if not ``None``, `value` is replaced by
        ``Mux(enabled, value, 0)`` before shifting.
    :returns: the transformed value; `value` itself when `enabled` is
        ``None`` and `shift` is 0.
    """
    term = value if enabled is None else Mux(enabled, value, 0)
    if shift > 0:
        return Cat(Repl(C(0, 1), shift), term)
    assert shift == 0
    return term
619
620
class ProductTerm(Elaboratable):
    """Creates a single product term (a[..]*b[..]).

    It has a design flaw in that it is the *output* that is selected,
    whereas the multiplication(s) are combinatorially generated all the
    time.

    :attribute a: first operand (``twidth//2`` bits).
    :attribute b: second operand (``twidth//2`` bits).
    :attribute pb_en: partition-boundary bits (`pbwid` bits).
    :attribute ti: the raw product of the two selected slices.
    :attribute term: the product, shifted and (if applicable) gated.
    :attribute enabled: 1-bit gating ``Signal``, or ``None`` when
        `a_index` equals `b_index`.
    """

    def __init__(self, width, twidth, pbwid, a_index, b_index):
        """Create a ``ProductTerm``.

        :param width: bit width of each operand slice.
        :param twidth: bit width of the output term.
        :param pbwid: bit width of ``pb_en``.
        :param a_index: index of the slice taken from ``a``.
        :param b_index: index of the slice taken from ``b``.
        """
        self.a_index = a_index
        self.b_index = b_index
        self.pwidth = width
        self.twidth = twidth
        self.width = width*2
        self.shift = 8 * (self.a_index + self.b_index)

        self.ti = Signal(self.width, reset_less=True)
        self.term = Signal(twidth, reset_less=True)
        self.a = Signal(twidth//2, reset_less=True)
        self.b = Signal(twidth//2, reset_less=True)
        self.pb_en = Signal(pbwid, reset_less=True)

        # partition-boundary bits lying between the two slice indices
        lo, hi = sorted((self.a_index, self.b_index))
        self.tl = [self.pb_en[i] for i in range(lo, hi)]
        name = "te_%d_%d" % (self.a_index, self.b_index)
        if self.tl:
            self.enabled = Signal(name=name, reset_less=True)
        else:
            self.enabled = None
        self.term.name = "term_%d_%d" % (a_index, b_index)  # rename

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        comb = m.d.comb
        if self.enabled is not None:
            # enabled only when none of the boundary bits in tl is set
            comb += self.enabled.eq(~(Cat(*self.tl).bool()))

        pwidth = self.pwidth
        bsa = Signal(self.width, reset_less=True)
        bsb = Signal(self.width, reset_less=True)
        comb += bsa.eq(self.a.part(self.a_index * pwidth, pwidth))
        comb += bsb.eq(self.b.part(self.b_index * pwidth, pwidth))
        comb += self.ti.eq(bsa * bsb)
        comb += self.term.eq(get_term(self.ti, self.shift, self.enabled))

        # TODO: sort out width issues, get inputs a/b switched on/off.
        # data going into Muxes is 1/2 the required width.
        # (the disabled alternative applied get_term to the a/b slices
        # before multiplying, then assigned the product straight to term)

        return m
690
691
class ProductTerms(Elaboratable):
    """Creates a bank of product terms for one slice of the "a" operand.

    This class is meant to be wrapped with a for-loop on the "a" operand;
    it provides the second-level for-loop on the "b" operand, creating one
    ``ProductTerm`` (which performs the actual bit-selection) per "b" index.

    :attribute a: first operand (``twidth//2`` bits).
    :attribute b: second operand (``twidth//2`` bits).
    :attribute pb_en: partition-boundary bits (`pbwid` bits).
    :attribute terms: list of `blen` output terms, each `twidth` bits.
    """

    def __init__(self, width, twidth, pbwid, a_index, blen):
        """Create a ``ProductTerms``.

        :param width: bit width of each operand slice.
        :param twidth: bit width of each output term.
        :param pbwid: bit width of ``pb_en``.
        :param a_index: index of the slice taken from ``a``.
        :param blen: number of "b" slices, i.e. number of terms produced.
        """
        self.a_index = a_index
        self.blen = blen
        self.pwidth = width
        self.twidth = twidth
        self.pbwid = pbwid
        self.a = Signal(twidth//2, reset_less=True)
        self.b = Signal(twidth//2, reset_less=True)
        self.pb_en = Signal(pbwid, reset_less=True)
        self.terms = [Signal(twidth, name="term%d" % i, reset_less=True)
                      for i in range(blen)]

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        comb = m.d.comb

        for b_index in range(self.blen):
            t = ProductTerm(self.pwidth, self.twidth, self.pbwid,
                            self.a_index, b_index)
            setattr(m.submodules, "term_%d" % b_index, t)

            # fan the shared inputs out to the term, collect its result
            comb += [
                t.a.eq(self.a),
                t.b.eq(self.b),
                t.pb_en.eq(self.pb_en),
            ]
            comb += self.terms[b_index].eq(t.term)

        return m
725
726
class LSBNegTerm(Elaboratable):
    """Produces the two extra terms used to negate an operand.

    The negation is split into a bitwise NOT (``nt``) and a +1 (``nl``),
    both placed in the upper `bit_width` bits of a double-width result.
    Both are zero unless ``part``, ``msb`` and ``signed`` are all set.

    :attribute part: partition-active flag.
    :attribute signed: signed-operation flag.
    :attribute op: the operand (`bit_width` bits).
    :attribute msb: most-significant-bit flag.
    :attribute nt: the width-extended 1s complement term.
    :attribute nl: the "+1" term.
    """

    def __init__(self, bit_width):
        """Create an ``LSBNegTerm``.

        :param bit_width: bit width of ``op``; outputs are twice as wide.
        """
        self.bit_width = bit_width
        self.part = Signal(reset_less=True)
        self.signed = Signal(reset_less=True)
        self.op = Signal(bit_width, reset_less=True)
        self.msb = Signal(reset_less=True)
        self.nt = Signal(bit_width*2, reset_less=True)
        self.nl = Signal(bit_width*2, reset_less=True)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        comb = m.d.comb
        bit_wid = self.bit_width
        low_zeros = Repl(0, bit_wid)  # extend output to HI part

        # determine sign of each incoming number *in this partition*
        enabled = Signal(reset_less=True)
        comb += enabled.eq(self.part & self.msb & self.signed)

        # for 8-bit values: form a * 0xFF00 by using -a * 0x100, the
        # negation operation is split into a bitwise not and a +1.
        # likewise for 16, 32, and 64-bit values.

        # width-extended 1s complement if a is signed, otherwise zero
        inverted_hi = Cat(low_zeros, ~self.op)
        comb += self.nt.eq(Mux(enabled, inverted_hi, 0))

        # add 1 if signed, otherwise add zero
        plus_one_hi = Cat(low_zeros, enabled, Repl(0, bit_wid-1))
        comb += self.nl.eq(plus_one_hi)

        return m
759
760
class Parts(Elaboratable):
    """Derives one "part" flag per partition from the partition points.

    :attribute epps: the (expanded) partition points, as ``Signal``s.
    :attribute parts: list of `n_parts` 1-bit output flags.
    """

    def __init__(self, pbwid, epps, n_parts):
        """Create a ``Parts``.

        :param pbwid: number of partition-byte bits to collect.
        :param epps: expanded partition points to mirror.
        :param n_parts: number of output flags.
        """
        self.pbwid = pbwid
        # inputs
        self.epps = PartitionPoints.like(epps, name="epps")  # expanded points
        # outputs
        self.parts = [Signal(name=f"part_{i}") for i in range(n_parts)]

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        comb = m.d.comb
        epps, parts = self.epps, self.parts

        # collect part-bytes (double factor because the input is extended)
        pbs = Signal(self.pbwid, reset_less=True)
        tl = []
        for i in range(self.pbwid):
            pb = Signal(name="pb%d" % i, reset_less=True)
            comb += pb.eq(epps.part_byte(i, mfactor=2))  # double
            tl.append(pb)
        comb += pbs.eq(Cat(*tl))

        # negated-temporary copy of partition bits
        npbs = Signal.like(pbs, reset_less=True)
        comb += npbs.eq(~pbs)

        byte_count = 8 // len(parts)
        for i in range(len(parts)):
            first = i * byte_count
            last = first + byte_count - 1
            # negated boundary below, plain interior boundaries, negated
            # boundary above.
            # NOTE(review): for i == 0 the lower index is -1, which selects
            # the topmost bit of npbs through negative indexing; presumably
            # intentional, confirm.
            pbl = [npbs[first - 1]]
            pbl.extend(pbs[j] for j in range(first, last))
            pbl.append(npbs[last])
            value = Signal(len(pbl), name="value_%d" % i, reset_less=True)
            comb += value.eq(Cat(*pbl))
            # the part flag is set only when every collected bit is zero
            comb += parts[i].eq(~(value).bool())

        return m
798
799
class Part(Elaboratable):
    """A key class which, depending on the partitioning, will determine
    what action to take when parts of the output are signed or unsigned.

    This requires 2 pieces of data *per operand, per partition*:
    whether the MSB is HI/LO (per partition!), and whether a signed
    or unsigned operation has been *requested*.

    Once that is determined, signed is basically carried out
    by splitting 2's complement into 1's complement plus one.
    1's complement is just a bit-inversion.

    The extra terms - as separate terms - are then thrown at the
    AddReduce alongside the multiplication part-results.

    :attribute a: first operand (64 bits).
    :attribute b: second operand (64 bits).
    :attribute a_signed: eight per-byte "a is signed" flags.
    :attribute b_signed: eight per-byte "b is signed" flags.
    :attribute not_a_term: concatenated bit-inverted terms for ``a``.
    :attribute neg_lsb_a_term: concatenated "+1" terms for ``a``.
    :attribute not_b_term: concatenated bit-inverted terms for ``b``.
    :attribute neg_lsb_b_term: concatenated "+1" terms for ``b``.
    """
    def __init__(self, epps, width, n_parts, n_levels, pbwid):
        """Create a ``Part``.

        :param epps: expanded partition points.
        :param width: bit width of the four output terms.
        :param n_parts: number of partitions.
        :param n_levels: not used by this class.
        :param pbwid: bit width of ``pbs``.
        """
        self.pbwid = pbwid
        self.epps = epps

        # inputs
        self.a = Signal(64)
        self.b = Signal(64)
        self.a_signed = [Signal(name=f"a_signed_{i}") for i in range(8)]
        self.b_signed = [Signal(name=f"_b_signed_{i}") for i in range(8)]
        self.pbs = Signal(pbwid, reset_less=True)

        # outputs
        # NOTE(review): elaborate() uses the flags of its internal Parts
        # submodule and does not appear to drive self.parts; confirm.
        self.parts = [Signal(name=f"part_{i}") for i in range(n_parts)]

        self.not_a_term = Signal(width)
        self.neg_lsb_a_term = Signal(width)
        self.not_b_term = Signal(width)
        self.neg_lsb_b_term = Signal(width)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        comb = m.d.comb

        epps = self.epps
        m.submodules.p = p = Parts(self.pbwid, epps, len(self.parts))
        comb += p.epps.eq(epps)
        parts = p.parts

        byte_width = 8 // len(parts)  # byte width
        bit_wid = 8 * byte_width  # bit width
        nat, nbt, nla, nlb = [], [], [], []
        for i, part in enumerate(parts):
            start = bit_wid * i  # lowest bit of this partition
            top = (i + 1) * bit_wid - 1  # its most significant bit
            flag = i * byte_width  # index of its signed flag

            # work out bit-inverted and +1 term for a.
            pa = LSBNegTerm(bit_wid)
            setattr(m.submodules, "lnt_%d_a_%d" % (bit_wid, i), pa)
            comb += pa.part.eq(part)
            comb += pa.op.eq(self.a.part(start, bit_wid))
            comb += pa.signed.eq(self.b_signed[flag])  # yes b
            comb += pa.msb.eq(self.b[top])  # really, b
            nat.append(pa.nt)
            nla.append(pa.nl)

            # work out bit-inverted and +1 term for b
            pb = LSBNegTerm(bit_wid)
            setattr(m.submodules, "lnt_%d_b_%d" % (bit_wid, i), pb)
            comb += pb.part.eq(part)
            comb += pb.op.eq(self.b.part(start, bit_wid))
            comb += pb.signed.eq(self.a_signed[flag])  # yes a
            comb += pb.msb.eq(self.a[top])  # really, a
            nbt.append(pb.nt)
            nlb.append(pb.nl)

        # concatenate together and return all 4 results.
        comb += [self.not_a_term.eq(Cat(*nat)),
                 self.not_b_term.eq(Cat(*nbt)),
                 self.neg_lsb_a_term.eq(Cat(*nla)),
                 self.neg_lsb_b_term.eq(Cat(*nlb)),
                 ]

        return m
882
883
class IntermediateOut(Elaboratable):
    """Selects the HI/LO part of the multiplication, for a given bit-width.

    The output is also reconstructed in its SIMD (partition) lanes.

    :attribute part_ops: eight 2-bit operation codes.
    :attribute intermed: the double-width intermediate result.
    :attribute output: the selected halves, concatenated lane by lane.
    """
    def __init__(self, width, out_wid, n_parts):
        """Create an ``IntermediateOut``.

        :param width: bit width of one output lane.
        :param out_wid: bit width of ``intermed``; ``output`` is half that.
        :param n_parts: number of lanes.
        """
        self.width = width
        self.n_parts = n_parts
        self.part_ops = [Signal(2, name="dpop%d" % i, reset_less=True)
                         for i in range(8)]
        self.intermed = Signal(out_wid, reset_less=True)
        self.output = Signal(out_wid//2, reset_less=True)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()

        w = self.width
        sel = w // 8  # part_ops entries per lane
        ol = []
        for i in range(self.n_parts):
            op = Signal(w, reset_less=True, name="op%d_%d" % (w, i))
            base = i * w*2  # start of this lane's double-width product
            want_low = self.part_ops[sel * i] == OP_MUL_LOW
            low_half = self.intermed.part(base, w)
            high_half = self.intermed.part(base + w, w)
            m.d.comb += op.eq(Mux(want_low, low_half, high_half))
            ol.append(op)
        m.d.comb += self.output.eq(Cat(*ol))

        return m
912
913
class FinalOut(Elaboratable):
    """Selects the final output based on the partitioning.

    Each byte is selectable independently, i.e. it is possible
    that some partitions requested 8-bit computation whilst others
    requested 16 or 32 bit.

    :attribute d8: eight flags, one per byte, selecting ``i8``.
    :attribute d16: four flags, one per byte pair, selecting ``i16``.
    :attribute d32: two flags, one per four bytes, selecting ``i32``.
    :attribute i8: candidate result for 8-bit partitioning.
    :attribute i16: candidate result for 16-bit partitioning.
    :attribute i32: candidate result for 32-bit partitioning.
    :attribute i64: candidate result used when no flag is set.
    :attribute out: the assembled output.
    """
    def __init__(self, out_wid):
        """Create a ``FinalOut``.

        :param out_wid: bit width of the candidate inputs and of ``out``.
        """
        # inputs
        self.d8 = [Signal(name=f"d8_{i}", reset_less=True) for i in range(8)]
        self.d16 = [Signal(name=f"d16_{i}", reset_less=True) for i in range(4)]
        self.d32 = [Signal(name=f"d32_{i}", reset_less=True) for i in range(2)]

        self.i8 = Signal(out_wid, reset_less=True)
        self.i16 = Signal(out_wid, reset_less=True)
        self.i32 = Signal(out_wid, reset_less=True)
        self.i64 = Signal(out_wid, reset_less=True)

        # output
        self.out = Signal(out_wid, reset_less=True)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        ol = []
        for i in range(8):
            # select one of the outputs: d8 selects i8, d16 selects i16
            # d32 selects i32, and the default is i64.
            # d8 and d16 are ORed together in the first Mux
            # then the 2nd selects either i8 or i16.
            # if neither d8 nor d16 are set, d32 selects either i32 or i64.
            op = Signal(8, reset_less=True, name="op_%d" % i)
            narrow = Mux(self.d8[i],
                         self.i8.part(i * 8, 8),
                         self.i16.part(i * 8, 8))
            wide = Mux(self.d32[i // 4],
                       self.i32.part(i * 8, 8),
                       self.i64.part(i * 8, 8))
            use_narrow = self.d8[i] | self.d16[i // 2]
            m.d.comb += op.eq(Mux(use_narrow, narrow, wide))
            ol.append(op)
        m.d.comb += self.out.eq(Cat(*ol))
        return m
954
955
class OrMod(Elaboratable):
    """Bitwise-OR four equally wide values using a two-level tree.

    :param wid: width in bits of each input and of the result.
    :attribute orin: list of the four input signals.
    :attribute orout: the OR of all four inputs.
    """

    def __init__(self, wid):
        self.wid = wid
        self.orin = []
        for idx in range(4):
            self.orin.append(
                Signal(wid, name="orin%d" % idx, reset_less=True))
        self.orout = Signal(wid, reset_less=True)

    def elaborate(self, platform):
        m = Module()

        in0, in1, in2, in3 = self.orin

        # first tree level: two pairwise ORs
        or1 = Signal(self.wid, reset_less=True)
        or2 = Signal(self.wid, reset_less=True)

        m.d.comb += [or1.eq(in0 | in1),
                     or2.eq(in2 | in3),
                     # second tree level: merge the two partial results
                     self.orout.eq(or1 | or2),
                     ]

        return m
974
975
class Signs(Elaboratable):
    """Decode whether operands a and b are to be treated as signed.

    The decision depends only on the requested operation (OP_MUL_*):

    * ``a`` is signed for every operation except ``OP_MUL_UNSIGNED_HIGH``.
    * ``b`` is signed for ``OP_MUL_LOW`` and ``OP_MUL_SIGNED_HIGH`` only.

    :attribute part_ops: 2-bit operation code (input).
    :attribute a_signed: asserted when ``a`` is signed (output).
    :attribute b_signed: asserted when ``b`` is signed (output).
    """

    def __init__(self):
        self.part_ops = Signal(2, reset_less=True)
        self.a_signed = Signal(reset_less=True)
        self.b_signed = Signal(reset_less=True)

    def elaborate(self, platform):
        m = Module()

        op = self.part_ops
        a_is_signed = op != OP_MUL_UNSIGNED_HIGH
        b_is_low = op == OP_MUL_LOW
        b_is_signed_high = op == OP_MUL_SIGNED_HIGH

        m.d.comb += [self.a_signed.eq(a_is_signed),
                     self.b_signed.eq(b_is_low | b_is_signed_high),
                     ]

        return m
997
998
class Mul8_16_32_64(Elaboratable):
    """Signed/Unsigned 8/16/32/64-bit partitioned integer multiplier.

    The 64-bit operands can be split into any combination of 8, 16, 32 and
    64-bit partitions on naturally-aligned boundaries, and the operation
    can be chosen for each partition independently.

    :attribute part_pts: the input partition points.  There is a partition
        point at every multiple of 8 in 0 < i < 64, and the ``Value``
        associated with each point is a ``Signal``.  Modification is not
        supported, except by ``Signal.eq``.
    :attribute part_ops: the operation for each byte.  To select the
        operation of a partition, assign the same operation code to every
        byte in that partition.  The allowed operation codes are:

        :attribute OP_MUL_LOW: the LSB half of the product.  Equivalent to
            RISC-V's `mul` instruction.
        :attribute OP_MUL_SIGNED_HIGH: the MSB half of the product where
            both ``a`` and ``b`` are signed.  Equivalent to RISC-V's `mulh`
            instruction.
        :attribute OP_MUL_SIGNED_UNSIGNED_HIGH: the MSB half of the product
            where ``a`` is signed and ``b`` is unsigned.  Equivalent to
            RISC-V's `mulhsu` instruction.
        :attribute OP_MUL_UNSIGNED_HIGH: the MSB half of the product where
            both ``a`` and ``b`` are unsigned.  Equivalent to RISC-V's
            `mulhu` instruction.
    :attribute a: first 64-bit operand.
    :attribute b: second 64-bit operand.
    :attribute output: the 64-bit result.
    """

    def __init__(self, register_levels=()):
        """Create the multiplier.

        :param register_levels: specifies the points in the cascade at
            which flip-flops are to be inserted.
        """

        # parameter(s)
        self.register_levels = list(register_levels)

        # inputs: one partition-point signal at every byte boundary
        self.part_pts = PartitionPoints()
        for point in range(8, 64, 8):
            self.part_pts[point] = Signal(name=f"part_pts_{point}")
        self.part_ops = [Signal(2, name=f"part_ops_{n}") for n in range(8)]
        self.a = Signal(64)
        self.b = Signal(64)

        # intermediates (needed for unit tests)
        self._intermediate_output = Signal(128)

        # output
        self.output = Signal(64)

    def elaborate(self, platform):
        m = Module()

        # collect part-bytes into a single 8-bit value
        pbs = Signal(8, reset_less=True)
        part_bytes = []
        for n in range(8):
            pb = Signal(name="pb%d" % n, reset_less=True)
            m.d.comb += pb.eq(self.part_pts.part_byte(n))
            part_bytes.append(pb)
        m.d.comb += pbs.eq(Cat(*part_bytes))

        # create (doubled) PartitionPoints (output is double input width)
        eps = PartitionPoints()
        for point, enabled in self.part_pts.items():
            doubled = point * 2
            ep = Signal(name=f"expanded_part_pts_{doubled}",
                        reset_less=True)
            eps[doubled] = ep
            m.d.comb += ep.eq(enabled)

        # one sign decoder per byte, driven by that byte's operation code
        signs = []
        for n in range(8):
            decoder = Signs()
            signs.append(decoder)
            setattr(m.submodules, "signs%d" % n, decoder)
            m.d.comb += decoder.part_ops.eq(self.part_ops[n])

        # one Part module per partition count (8, 4, 2 and 1 partitions)
        n_levels = len(self.register_levels)+1
        m.submodules.part_8 = part_8 = Part(eps, 128, 8, n_levels, 8)
        m.submodules.part_16 = part_16 = Part(eps, 128, 4, n_levels, 8)
        m.submodules.part_32 = part_32 = Part(eps, 128, 2, n_levels, 8)
        m.submodules.part_64 = part_64 = Part(eps, 128, 1, n_levels, 8)
        part_mods = [part_8, part_16, part_32, part_64]
        for part in part_mods:
            m.d.comb += part.a.eq(self.a)
            m.d.comb += part.b.eq(self.b)
            for n, decoder in enumerate(signs):
                m.d.comb += part.a_signed[n].eq(decoder.a_signed)
                m.d.comb += part.b_signed[n].eq(decoder.b_signed)
            m.d.comb += part.pbs.eq(pbs)
        nat_l = [part.not_a_term for part in part_mods]
        nbt_l = [part.not_b_term for part in part_mods]
        nla_l = [part.neg_lsb_a_term for part in part_mods]
        nlb_l = [part.neg_lsb_b_term for part in part_mods]

        # gather the product terms, one ProductTerms module per a_index
        terms = []
        for a_index in range(8):
            pterms = ProductTerms(8, 128, 8, a_index, 8)
            setattr(m.submodules, "terms_%d" % a_index, pterms)

            m.d.comb += pterms.a.eq(self.a)
            m.d.comb += pterms.b.eq(self.b)
            m.d.comb += pterms.pb_en.eq(pbs)

            terms.extend(pterms.terms)

        # it's fine to bitwise-or data together since they are never enabled
        # at the same time
        m.submodules.nat_or = nat_or = OrMod(128)
        m.submodules.nbt_or = nbt_or = OrMod(128)
        m.submodules.nla_or = nla_or = OrMod(128)
        m.submodules.nlb_or = nlb_or = OrMod(128)
        or_wiring = ((nat_l, nat_or),
                     (nbt_l, nbt_or),
                     (nla_l, nla_or),
                     (nlb_l, nlb_or))
        for sources, or_mod in or_wiring:
            for or_input, source in zip(or_mod.orin, sources):
                m.d.comb += or_input.eq(source)
            terms.append(or_mod.orout)

        # sum all of the terms together
        add_reduce = AddReduce(terms,
                               128,
                               self.register_levels,
                               eps,
                               self.part_ops)

        # operation codes / partition points as seen by the last level
        last_in = add_reduce.levels[-1].i
        out_part_ops = last_in.part_ops
        out_part_pts = last_in.reg_partition_points

        m.submodules.add_reduce = add_reduce
        m.d.comb += self._intermediate_output.eq(add_reduce.output)

        # create _output_64, _output_32, _output_16 and _output_8:
        # one HI/LO selector per lane width, all fed from the same
        # intermediate value and the same operation codes.
        lane_outs = {}
        for lane_wid in (64, 32, 16, 8):
            io = IntermediateOut(lane_wid, 128, 64 // lane_wid)
            setattr(m.submodules, "io%d" % lane_wid, io)
            m.d.comb += io.intermed.eq(self._intermediate_output)
            for n in range(8):
                m.d.comb += io.part_ops[n].eq(out_part_ops[n])
            lane_outs[lane_wid] = io

        # per-width Parts modules, fed with the output partition points
        m.submodules.p_8 = p_8 = Parts(8, eps, len(part_8.parts))
        m.submodules.p_16 = p_16 = Parts(8, eps, len(part_16.parts))
        m.submodules.p_32 = p_32 = Parts(8, eps, len(part_32.parts))
        m.submodules.p_64 = p_64 = Parts(8, eps, len(part_64.parts))

        for selector in (p_8, p_16, p_32, p_64):
            m.d.comb += selector.epps.eq(out_part_pts)

        # final output
        m.submodules.finalout = finalout = FinalOut(64)
        select_wiring = ((finalout.d8, part_8, p_8),
                         (finalout.d16, part_16, p_16),
                         (finalout.d32, part_32, p_32))
        for flags, part, selector in select_wiring:
            for n in range(len(part.parts)):
                m.d.comb += flags[n].eq(selector.parts[n])
        m.d.comb += [finalout.i8.eq(lane_outs[8].output),
                     finalout.i16.eq(lane_outs[16].output),
                     finalout.i32.eq(lane_outs[32].output),
                     finalout.i64.eq(lane_outs[64].output),
                     self.output.eq(finalout.out),
                     ]

        return m
1181
1182
if __name__ == "__main__":
    # Build a multiplier with no register levels and hand it to the
    # nmigen command-line entry point, exposing operands, results,
    # per-byte operation codes and partition points as ports.
    mul = Mul8_16_32_64()
    ports = [mul.a, mul.b, mul._intermediate_output, mul.output]
    ports.extend(mul.part_ops)
    ports.extend(mul.part_pts.values())
    main(mul, ports=ports)