# syntax error
# [ieee754fpu.git] / src / ieee754 / part_mul_add / multiply.py
1 # SPDX-License-Identifier: LGPL-2.1-or-later
2 # See Notices.txt for copyright information
3 """Integer Multiplication."""
4
5 from nmigen import Signal, Module, Value, Elaboratable, Cat, C, Mux, Repl
6 from nmigen.hdl.ast import Assign
7 from abc import ABCMeta, abstractmethod
8 from nmigen.cli import main
9 from functools import reduce
10 from operator import or_
11
12
class PartitionPoints(dict):
    """Partition points and corresponding ``Value``s.

    The points at where an ALU is partitioned along with ``Value``s that
    specify if the corresponding partition points are enabled.

    For example: ``{1: True, 5: True, 10: True}`` with
    ``width == 16`` specifies that the ALU is split into 4 sections:
    * bits 0 <= ``i`` < 1
    * bits 1 <= ``i`` < 5
    * bits 5 <= ``i`` < 10
    * bits 10 <= ``i`` < 16

    If the partition_points were instead ``{1: True, 5: a, 10: True}``
    where ``a`` is a 1-bit ``Signal``:
    * If ``a`` is asserted:
        * bits 0 <= ``i`` < 1
        * bits 1 <= ``i`` < 5
        * bits 5 <= ``i`` < 10
        * bits 10 <= ``i`` < 16
    * Otherwise
        * bits 0 <= ``i`` < 1
        * bits 1 <= ``i`` < 10
        * bits 10 <= ``i`` < 16
    """

    def __init__(self, partition_points=None):
        """Create a new ``PartitionPoints``.

        :param partition_points: the input partition points to values
            mapping, or ``None`` for an empty set of points.
        :raises TypeError: if a point is not an ``int``.
        :raises ValueError: if a point is negative.
        """
        super().__init__()
        if partition_points is None:
            return
        for point, enabled in partition_points.items():
            if not isinstance(point, int):
                raise TypeError("point must be a non-negative integer")
            if point < 0:
                raise ValueError("point must be a non-negative integer")
            # every stored value goes through Value.wrap
            self[point] = Value.wrap(enabled)

    def like(self, name=None, src_loc_at=0, mul=1):
        """Create a new ``PartitionPoints`` with ``Signal``s for all values.

        :param name: the base name for the new ``Signal``s. When ``None``
            the name is taken from a temporary ``Signal``.
        :param src_loc_at: extra offset added to the ``src_loc_at`` that is
            handed to that temporary ``Signal``.
        :param mul: a multiplication factor on the indices
        :returns: a new ``PartitionPoints`` keyed by ``point * mul``, whose
            values are fresh ``Signal``s with the shape of the originals.
        """
        if name is None:
            name = Signal(src_loc_at=1+src_loc_at).name  # get variable name
        retval = PartitionPoints()
        for point, enabled in self.items():
            point *= mul
            retval[point] = Signal(enabled.shape(), name=f"{name}_{point}")
        return retval

    def eq(self, rhs):
        """Assign ``PartitionPoints`` using ``Signal.eq``.

        The point sets are compared immediately (not lazily on first
        iteration), so a mismatch is reported at the call site.

        :param rhs: the ``PartitionPoints`` (or mapping) to assign from.
        :returns: a list with one assignment per partition point.
        :raises ValueError: if ``rhs`` does not have exactly the same
            points as ``self``.
        """
        if set(self.keys()) != set(rhs.keys()):
            raise ValueError("incompatible point set")
        return [enabled.eq(rhs[point]) for point, enabled in self.items()]

    def as_mask(self, width):
        """Create a bit-mask from `self`.

        Each bit in the returned mask is clear only if the partition point at
        the same bit-index is enabled.

        :param width: the bit width of the resulting mask
        """
        # bits without a partition point are always set
        bits = [~self[i] if i in self else True for i in range(width)]
        return Cat(*bits)

    def get_max_partition_count(self, width):
        """Get the maximum number of partitions.

        Gets the number of partitions when all partition points are enabled.
        Only points smaller than `width` are counted.
        """
        return 1 + sum(1 for point in self if point < width)

    def fits_in_width(self, width):
        """Check if all partition points are smaller than `width`."""
        return all(point < width for point in self)

    def part_byte(self, index, mfactor=1):  # mfactor used for "expanding"
        """Look up the partition value stored for byte `index`.

        :param index: byte index. ``-1`` and ``7`` return a constant 1-bit
            true value instead of a stored point.
        :param mfactor: factor applied to the looked-up bit position.
        :returns: ``self[(index * 8 + 8) * mfactor]`` for 0 <= index < 7.
        """
        if index == -1 or index == 7:
            return C(True, 1)
        assert index >= 0 and index < 8
        return self[(index * 8 + 8)*mfactor]
113
114
class FullAdder(Elaboratable):
    """Bit-parallel full (3:2) adder.

    :attribute in0: the first input
    :attribute in1: the second input
    :attribute in2: the third input
    :attribute sum: the sum output
    :attribute carry: the carry output

    Instead of instantiating an array of one-bit full adders (which
    would be very slow to simulate), all signals share a common bit
    width, so that this one module performs several Full 3-2 Add
    operations "in parallel".
    """

    def __init__(self, width):
        """Create a ``FullAdder``.

        :param width: the bit width of the input and output
        """
        self.in0 = Signal(width)
        self.in1 = Signal(width)
        self.in2 = Signal(width)
        self.sum = Signal(width)
        self.carry = Signal(width)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        a, b, c = self.in0, self.in1, self.in2
        m.d.comb += [
            # sum: XOR of the three inputs
            self.sum.eq(a ^ b ^ c),
            # carry: majority of the three inputs (not shifted here)
            self.carry.eq((a & b) | (b & c) | (c & a)),
        ]
        return m
149
150
class MaskedFullAdder(Elaboratable):
    """Masked Full Adder.

    :attribute mask: the carry partition mask
    :attribute in0: the first input
    :attribute in1: the second input
    :attribute in2: the third input
    :attribute sum: the sum output
    :attribute mcarry: the masked carry output

    FullAdders are always used with a "mask" on the output. To keep
    the graphviz "clean", the masking is done inside this class rather
    than inside a large for-loop.

    See the following discussion as to why this is no longer derived
    from FullAdder. Each carry is shifted here *before* being ANDed
    with the mask, so that an AOI cell may be used (which is more
    gate-efficient)
    https://en.wikipedia.org/wiki/AND-OR-Invert
    https://groups.google.com/d/msg/comp.arch/fcq-GLQqvas/vTxmcA0QAgAJ
    """

    def __init__(self, width):
        """Create a ``MaskedFullAdder``.

        :param width: the bit width of the input and output
        """
        self.width = width
        self.mask = Signal(width, reset_less=True)
        self.mcarry = Signal(width, reset_less=True)
        self.in0 = Signal(width, reset_less=True)
        self.in1 = Signal(width, reset_less=True)
        self.in2 = Signal(width, reset_less=True)
        self.sum = Signal(width, reset_less=True)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        comb = m.d.comb

        # s1..s3: the inputs with a zero bit prepended at the LSB end,
        # assigned into signals of the original width
        s1 = Signal(self.width, reset_less=True)
        s2 = Signal(self.width, reset_less=True)
        s3 = Signal(self.width, reset_less=True)
        # c1..c3: pairwise ANDs of the shifted inputs, gated by the mask
        c1 = Signal(self.width, reset_less=True)
        c2 = Signal(self.width, reset_less=True)
        c3 = Signal(self.width, reset_less=True)

        comb += self.sum.eq(self.in0 ^ self.in1 ^ self.in2)
        comb += [
            s1.eq(Cat(0, self.in0)),
            s2.eq(Cat(0, self.in1)),
            s3.eq(Cat(0, self.in2)),
        ]
        comb += [
            c1.eq(s1 & s2 & self.mask),
            c2.eq(s2 & s3 & self.mask),
            c3.eq(s3 & s1 & self.mask),
        ]
        comb += self.mcarry.eq(c1 | c2 | c3)
        return m
204
205
class PartitionedAdder(Elaboratable):
    """Partitioned Adder.

    Performs the final add. The partition points are included in the
    actual add (in one of the operands only), which causes a carry over
    to the next bit. Then the final output *removes* the extra bits from
    the result.

    partition: .... P... P... P... P... (32 bits)
    a        : .... .... .... .... .... (32 bits)
    b        : .... .... .... .... .... (32 bits)
    exp-a    : ....P....P....P....P.... (32+4 bits, P=1 if no partition)
    exp-b    : ....0....0....0....0.... (32 bits plus 4 zeros)
    exp-o    : ....xN...xN...xN...xN... (32+4 bits - x to be discarded)
    o        : .... N... N... N... N... (32 bits - x ignored, N is carry-over)

    :attribute width: the bit width of the input and output. Read-only.
    :attribute a: the first input to the adder
    :attribute b: the second input to the adder
    :attribute output: the sum output
    :attribute partition_points: the input partition points. Modification not
        supported, except for by ``Signal.eq``.
    """

    def __init__(self, width, partition_points):
        """Create a ``PartitionedAdder``.

        :param width: the bit width of the input and output
        :param partition_points: the input partition points
        :raises ValueError: if a partition point is not below ``width``.
        """
        self.width = width
        self.a = Signal(width)
        self.b = Signal(width)
        self.output = Signal(width)
        self.partition_points = PartitionPoints(partition_points)
        if not self.partition_points.fits_in_width(width):
            raise ValueError("partition_points doesn't fit in width")
        # one bit per data bit, plus one extra bit per partition point
        n_breaks = sum(1 for i in range(self.width)
                       if i in self.partition_points)
        expanded_width = self.width + n_breaks
        self._expanded_width = expanded_width
        # XXX these have to remain here due to some horrible nmigen
        # simulation bugs involving sync.  it is *not* necessary to
        # have them here, they should (under normal circumstances)
        # be moved into elaborate, as they are entirely local
        self._expanded_a = Signal(expanded_width)  # includes extra part-points
        self._expanded_b = Signal(expanded_width)  # likewise.
        self._expanded_o = Signal(expanded_width)  # likewise.

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()

        # bits are gathered in lists and joined with Cat afterwards:
        # graphviz output is much cleaner that way.
        exp_a_bits, src_a_bits = [], []   # expanded a  <= a (+ break bits)
        exp_b_bits, src_b_bits = [], []   # expanded b  <= b (+ zeros)
        out_bits, exp_o_bits = [], []     # output      <= expanded o

        # partition points are "breaks" (extra zeros or 1s) in what would
        # otherwise be a massive long add.  when the "break" points are 0,
        # whatever is in it (in the output) is discarded.  however when
        # there is a "1", it causes a roll-over carry to the *next* bit.
        # we still ignore the "break" bit in the [intermediate] output,
        # however by that time we've got the effect that we wanted: the
        # carry has been carried *over* the break point.
        pos = 0  # index into the expanded signals
        for i in range(self.width):
            if i in self.partition_points:
                # extra bit: 0 + 0 for an enabled partition point,
                # 1 + 0 for a disabled one
                exp_a_bits.append(self._expanded_a[pos])
                src_a_bits.append(~self.partition_points[i])
                exp_b_bits.append(self._expanded_b[pos])
                src_b_bits.append(C(0))  # yes, add a zero
                pos += 1  # the break bit is NOT copied to the output
            exp_a_bits.append(self._expanded_a[pos])
            exp_b_bits.append(self._expanded_b[pos])
            exp_o_bits.append(self._expanded_o[pos])
            src_a_bits.append(self.a[i])
            src_b_bits.append(self.b[i])
            out_bits.append(self.output[i])
            pos += 1

        m.d.comb += [
            # combine the gathered bits using Cat
            Cat(*exp_a_bits).eq(Cat(*src_a_bits)),
            Cat(*exp_b_bits).eq(Cat(*src_b_bits)),
            Cat(*out_bits).eq(Cat(*exp_o_bits)),
            # use only one addition to take advantage of look-ahead carry
            # and special hardware on FPGAs
            self._expanded_o.eq(self._expanded_a + self._expanded_b),
        ]
        return m
299
300
# number of inputs consumed by each full adder when grouping terms
FULL_ADDER_INPUT_COUNT = 3
302
class AddReduceData:
    """Bundle of signals: per-part operations, input terms and a
    ``Signal`` copy of the partition points.

    :attribute part_ops: list of 2-bit ``Signal``s, one per part.
    :attribute inputs: list of ``output_width``-bit ``Signal``s.
    :attribute reg_partition_points: result of ``ppoints.like()``.
    """

    def __init__(self, ppoints, output_width, n_parts, n_inputs=None):
        """Create an ``AddReduceData``.

        :param ppoints: the partition points; only its ``like()`` method
            is used here.
        :param output_width: bit-width of each entry of ``inputs``.
        :param n_parts: number of entries in ``part_ops``.
        :param n_inputs: number of entries in ``inputs``. The previous
            code tried to size ``inputs`` from ``self.inputs`` before it
            existed, so there was no way to specify this. Defaults to
            ``n_parts``.
            NOTE(review): confirm the intended default against callers.
        """
        if n_inputs is None:
            n_inputs = n_parts
        self.part_ops = [Signal(2, name=f"part_ops_{i}")
                         for i in range(n_parts)]
        self.inputs = [Signal(output_width, name=f"inputs[{i}]")
                       for i in range(n_inputs)]
        # use the constructor argument (the old code referenced an
        # undefined name here)
        self.reg_partition_points = ppoints.like()

    def eq(self, rhs):
        """Return the assignments that copy every field of ``rhs``.

        :param rhs: another object with ``reg_partition_points``,
            ``inputs`` and ``part_ops`` attributes.
        :returns: a list: first the partition-point assignment(s), then
            one assignment per input, then one per part operation.
        """
        return [self.reg_partition_points.eq(rhs.reg_partition_points)] + \
               [term.eq(rhs.inputs[i])
                for i, term in enumerate(self.inputs)] + \
               [op.eq(rhs.part_ops[i])
                for i, op in enumerate(self.part_ops)]
318
319
class AddReduceSingle(Elaboratable):
    """One level of the adder tree that sums a list of numbers.

    :attribute inputs: input ``Signal``s to be summed. Modification not
        supported, except for by ``Signal.eq``.
    :attribute register_levels: List of nesting levels that should have
        pipeline registers.
    :attribute output: output sum.
    :attribute partition_points: the input partition points. Modification not
        supported, except for by ``Signal.eq``.
    """

    def __init__(self, inputs, output_width, register_levels, partition_points,
                 part_ops):
        """Create an ``AddReduceSingle``.

        :param inputs: input ``Signal``s to be summed.
        :param output_width: bit-width of ``output``.
        :param register_levels: List of nesting levels that should have
            pipeline registers.
        :param partition_points: the input partition points.
        :param part_ops: per-part operation values; copied to
            ``out_part_ops`` during elaboration.
        :raises ValueError: if the partition points do not fit in
            ``output_width`` or a register level exceeds the maximum level.
        """
        self.part_ops = part_ops
        self.out_part_ops = [Signal(2, name=f"out_part_ops_{i}")
                             for i, _ in enumerate(part_ops)]
        self.inputs = list(inputs)
        self._resized_inputs = [
            Signal(output_width, name=f"resized_inputs[{i}]")
            for i, _ in enumerate(self.inputs)]
        self.register_levels = list(register_levels)
        self.output = Signal(output_width)
        self.partition_points = PartitionPoints(partition_points)
        if not self.partition_points.fits_in_width(output_width):
            raise ValueError("partition_points doesn't fit in output_width")
        self._reg_partition_points = self.partition_points.like()

        max_level = AddReduceSingle.get_max_level(len(self.inputs))
        if any(level > max_level for level in self.register_levels):
            raise ValueError(
                "not enough adder levels for specified register levels")

        # this is annoying.  we have to create the modules (and terms)
        # because we need to know what they are (in order to set up the
        # interconnects back in AddReduce), but cannot do the m.d.comb +=
        # etc because this is not in elaboratable.
        self.groups = AddReduceSingle.full_adder_groups(len(self.inputs))
        self._intermediate_terms = []
        if len(self.groups) != 0:
            self.create_next_terms()

    @staticmethod
    def get_max_level(input_count):
        """Get the maximum level.

        All ``register_levels`` must be less than or equal to the maximum
        level.
        """
        level = 0
        groups = AddReduceSingle.full_adder_groups(input_count)
        while len(groups) != 0:
            # each group of 3 becomes 2 terms; leftovers pass through
            leftover = input_count % FULL_ADDER_INPUT_COUNT
            input_count = leftover + 2 * len(groups)
            level += 1
            groups = AddReduceSingle.full_adder_groups(input_count)
        return level

    @staticmethod
    def full_adder_groups(input_count):
        """Get ``inputs`` indices for which a full adder should be built."""
        last_start = input_count - FULL_ADDER_INPUT_COUNT
        return range(0, last_start + 1, FULL_ADDER_INPUT_COUNT)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()

        # resize inputs to correct bit-width and optionally add in
        # pipeline registers
        resized_input_assignments = [
            self._resized_inputs[i].eq(term)
            for i, term in enumerate(self.inputs)]
        copy_part_ops = [self.out_part_ops[i].eq(op)
                         for i, op in enumerate(self.part_ops)]
        # level 0 registered => sync domain, otherwise combinatorial
        stage = m.d.sync if 0 in self.register_levels else m.d.comb
        stage += copy_part_ops
        stage += resized_input_assignments
        stage += self._reg_partition_points.eq(self.partition_points)

        m.d.comb += [term.eq(value)
                     for (value, term) in self._intermediate_terms]

        # if there are no full adders to create, then we handle the base cases
        # and return, otherwise we go on to the recursive case
        if len(self.groups) == 0:
            n_inputs = len(self.inputs)
            if n_inputs == 0:
                # use 0 as the default output value
                m.d.comb += self.output.eq(0)
            elif n_inputs == 1:
                # handle single input
                m.d.comb += self.output.eq(self._resized_inputs[0])
            else:
                # base case for adding 2 inputs
                assert len(self.inputs) == 2
                adder = PartitionedAdder(len(self.output),
                                         self._reg_partition_points)
                m.submodules.final_adder = adder
                m.d.comb += [
                    adder.a.eq(self._resized_inputs[0]),
                    adder.b.eq(self._resized_inputs[1]),
                    self.output.eq(adder.output),
                ]
            return m

        mask = self._reg_partition_points.as_mask(len(self.output))
        m.d.comb += self.part_mask.eq(mask)

        # add and link the intermediate term modules
        for i, (iidx, adder_i) in enumerate(self.adders):
            setattr(m.submodules, f"adder_{i}", adder_i)

            m.d.comb += [
                adder_i.in0.eq(self._resized_inputs[iidx]),
                adder_i.in1.eq(self._resized_inputs[iidx + 1]),
                adder_i.in2.eq(self._resized_inputs[iidx + 2]),
                adder_i.mask.eq(self.part_mask),
            ]

        return m

    def create_next_terms(self):
        """Create the full adders and the terms feeding the next level.

        Sets ``part_mask``, ``adders``, ``intermediate_terms`` and
        ``_intermediate_terms`` (pairs of source value and term signal).
        """
        # go on to prepare recursive case
        next_terms = []
        term_sources = []

        def pass_on(value):
            # one fresh signal per value handed to the next level
            term = Signal(
                len(self.output),
                name=f"intermediate_terms[{len(next_terms)}]")
            term_sources.append((value, term))
            next_terms.append(term)

        # store mask in intermediary (simplifies graph)
        self.part_mask = Signal(len(self.output), reset_less=True)

        # create full adders for this recursive level.
        # this shrinks N terms to 2 * (N // 3) plus the remainder
        self.adders = []
        for i in self.groups:
            adder_i = MaskedFullAdder(len(self.output))
            self.adders.append((i, adder_i))
            # add both the sum and the masked-carry to the next level.
            # 3 inputs have now been reduced to 2...
            pass_on(adder_i.sum)
            pass_on(adder_i.mcarry)

        # handle the remaining inputs: one or two left-over terms are just
        # passed to the next layer.  a half adder would gain nothing, since
        # there would still be 2 terms, and passing them on saves gates.
        leftover = len(self.inputs) % FULL_ADDER_INPUT_COUNT
        first_leftover = len(self._resized_inputs) - leftover
        for resized in self._resized_inputs[first_leftover:]:
            pass_on(resized)

        self.intermediate_terms = next_terms
        self._intermediate_terms = term_sources
490
491
class AddReduce(Elaboratable):
    """Add a list of numbers together using a chain of reduction levels.

    :attribute inputs: input ``Signal``s to be summed. Modification not
        supported, except for by ``Signal.eq``.
    :attribute register_levels: List of nesting levels that should have
        pipeline registers.
    :attribute output: output sum.
    :attribute partition_points: the input partition points. Modification not
        supported, except for by ``Signal.eq``.
    """

    def __init__(self, inputs, output_width, register_levels, partition_points,
                 part_ops):
        """Create an ``AddReduce``.

        :param inputs: input ``Signal``s to be summed.
        :param output_width: bit-width of ``output``.
        :param register_levels: List of nesting levels that should have
            pipeline registers.
        :param partition_points: the input partition points.
        :param part_ops: per-part operation values, forwarded through the
            levels to ``out_part_ops``.
        """
        self.inputs = inputs
        self.part_ops = part_ops
        self.out_part_ops = [Signal(2, name=f"out_part_ops_{i}")
                             for i, _ in enumerate(part_ops)]
        self.output = Signal(output_width)
        self.output_width = output_width
        self.register_levels = register_levels
        self.partition_points = partition_points

        self.create_levels()

    @staticmethod
    def get_max_level(input_count):
        """Maximum level, as computed by ``AddReduceSingle``."""
        return AddReduceSingle.get_max_level(input_count)

    @staticmethod
    def next_register_levels(register_levels):
        """``Iterable`` of ``register_levels`` for next recursive level."""
        for level in register_levels:
            if level <= 0:
                continue  # level 0 belongs to the current level only
            yield level - 1

    def create_levels(self):
        """creates reduction levels"""
        next_levels = self.register_levels
        level = AddReduceSingle(self.inputs, self.output_width, next_levels,
                                self.partition_points, self.part_ops)
        mods = [level]
        # keep chaining levels until one has no full adders left to build
        while len(level.groups) != 0:
            next_levels = list(AddReduce.next_register_levels(next_levels))
            level = AddReduceSingle(level.intermediate_terms,
                                    self.output_width, next_levels,
                                    level._reg_partition_points,
                                    level.out_part_ops)
            mods.append(level)

        self.levels = mods

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()

        for i, next_level in enumerate(self.levels):
            setattr(m.submodules, "next_level%d" % i, next_level)

        # output comes from last module
        last = self.levels[-1]
        m.d.comb += self.output.eq(last.output)
        m.d.comb += [self.out_part_ops[i].eq(last.out_part_ops[i])
                     for i in range(len(self.part_ops))]

        return m
571
572
# operation codes, numbered consecutively from zero
(OP_MUL_LOW,
 OP_MUL_SIGNED_HIGH,
 OP_MUL_SIGNED_UNSIGNED_HIGH,  # a is signed, b is unsigned
 OP_MUL_UNSIGNED_HIGH) = range(4)
577
578
def get_term(value, shift=0, enabled=None):
    """Optionally gate and left-shift a term.

    :param value: the term to process.
    :param shift: number of zero bits to place below ``value``.
    :param enabled: if not ``None``, ``value`` is replaced by
        ``Mux(enabled, value, 0)`` before shifting.
    :returns: ``value`` unchanged when ``shift`` is 0 and ``enabled`` is
        ``None``; otherwise the gated and/or shifted expression.
    :raises ValueError: if ``shift`` is negative. (This used to be an
        ``assert``, which is skipped under ``python -O`` and would then
        silently return an unshifted term.)
    """
    if shift < 0:
        raise ValueError("shift must be a non-negative integer")
    if enabled is not None:
        value = Mux(enabled, value, 0)
    if shift > 0:
        # prepend `shift` zero bits at the LSB end
        value = Cat(Repl(C(0, 1), shift), value)
    return value
587
588
class ProductTerm(Elaboratable):
    """ this class creates a single product term (a[..]*b[..]).
        it has a design flaw in that is the *output* that is selected,
        where the multiplication(s) are combinatorially generated
        all the time.
    """

    def __init__(self, width, twidth, pbwid, a_index, b_index):
        """Create a ``ProductTerm``.

        :param width: bit width of the slice taken from each operand.
        :param twidth: bit width of ``term``; ``a`` and ``b`` are half that.
        :param pbwid: bit width of ``pb_en``.
        :param a_index: index of the slice taken from ``a``.
        :param b_index: index of the slice taken from ``b``.
        """
        self.a_index = a_index
        self.b_index = b_index
        self.pwidth = width
        self.twidth = twidth
        self.width = width*2
        self.shift = 8 * (self.a_index + self.b_index)

        self.ti = Signal(self.width, reset_less=True)
        self.term = Signal(twidth, reset_less=True)
        self.a = Signal(twidth//2, reset_less=True)
        self.b = Signal(twidth//2, reset_less=True)
        self.pb_en = Signal(pbwid, reset_less=True)

        # the pb_en bits lying between the two indices
        lo_index = min(self.a_index, self.b_index)
        hi_index = max(self.a_index, self.b_index)
        self.tl = tl = [self.pb_en[i] for i in range(lo_index, hi_index)]
        name = "te_%d_%d" % (self.a_index, self.b_index)
        # an enable signal is only needed when there are bits to check
        if len(tl) > 0:
            term_enabled = Signal(name=name, reset_less=True)
        else:
            term_enabled = None
        self.enabled = term_enabled
        self.term.name = "term_%d_%d" % (a_index, b_index)  # rename

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        if self.enabled is not None:
            # enabled only when none of the collected pb_en bits is set
            m.d.comb += self.enabled.eq(~(Cat(*self.tl).bool()))

        bsa = Signal(self.width, reset_less=True)
        bsb = Signal(self.width, reset_less=True)
        a_index, b_index = self.a_index, self.b_index
        pwidth = self.pwidth
        m.d.comb += [
            bsa.eq(self.a.part(a_index * pwidth, pwidth)),
            bsb.eq(self.b.part(b_index * pwidth, pwidth)),
            self.ti.eq(bsa * bsb),
            self.term.eq(get_term(self.ti, self.shift, self.enabled)),
        ]
        # TODO: sort out width issues, get inputs a/b switched on/off.
        # data going into Muxes is 1/2 the required width
        #
        # pwidth = self.pwidth
        # width = self.width
        # bsa = Signal(self.twidth//2, reset_less=True)
        # bsb = Signal(self.twidth//2, reset_less=True)
        # asel = Signal(width, reset_less=True)
        # bsel = Signal(width, reset_less=True)
        # a_index, b_index = self.a_index, self.b_index
        # m.d.comb += asel.eq(self.a.part(a_index * pwidth, pwidth))
        # m.d.comb += bsel.eq(self.b.part(b_index * pwidth, pwidth))
        # m.d.comb += bsa.eq(get_term(asel, self.shift, self.enabled))
        # m.d.comb += bsb.eq(get_term(bsel, self.shift, self.enabled))
        # m.d.comb += self.ti.eq(bsa * bsb)
        # m.d.comb += self.term.eq(self.ti)

        return m
658
659
class ProductTerms(Elaboratable):
    """ creates a bank of product terms.  also performs the actual bit-selection
        this class is to be wrapped with a for-loop on the "a" operand.
        it creates a second-level for-loop on the "b" operand.
    """

    def __init__(self, width, twidth, pbwid, a_index, blen):
        """Create a ``ProductTerms``.

        :param width: passed to each ``ProductTerm`` as its ``width``.
        :param twidth: bit width of each term; ``a``/``b`` are half that.
        :param pbwid: bit width of ``pb_en``.
        :param a_index: the fixed "a" index of every term in this bank.
        :param blen: number of "b" indices, i.e. number of terms.
        """
        self.a_index = a_index
        self.blen = blen
        self.pwidth = width
        self.twidth = twidth
        self.pbwid = pbwid
        self.a = Signal(twidth//2, reset_less=True)
        self.b = Signal(twidth//2, reset_less=True)
        self.pb_en = Signal(pbwid, reset_less=True)
        self.terms = [Signal(twidth, name="term%d" % i, reset_less=True)
                      for i in range(blen)]

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()

        for b_index, term_out in enumerate(self.terms[:self.blen]):
            t = ProductTerm(self.pwidth, self.twidth, self.pbwid,
                            self.a_index, b_index)
            setattr(m.submodules, "term_%d" % b_index, t)

            m.d.comb += [
                # every term sees the full operands and enables
                t.a.eq(self.a),
                t.b.eq(self.b),
                t.pb_en.eq(self.pb_en),
                term_out.eq(t.term),
            ]

        return m
693
694
class LSBNegTerm(Elaboratable):
    """Produces the bit-inverted term ``nt`` and the "+1" term ``nl``
    that are used for one partition when it is active, signed and its
    ``msb`` input is set.
    """

    def __init__(self, bit_width):
        """Create an ``LSBNegTerm``.

        :param bit_width: bit width of ``op``; ``nt`` and ``nl`` are
            twice as wide.
        """
        self.bit_width = bit_width
        self.part = Signal(reset_less=True)
        self.signed = Signal(reset_less=True)
        self.op = Signal(bit_width, reset_less=True)
        self.msb = Signal(reset_less=True)
        self.nt = Signal(bit_width*2, reset_less=True)
        self.nl = Signal(bit_width*2, reset_less=True)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        bit_wid = self.bit_width
        ext = Repl(0, bit_wid)  # extend output to HI part

        # determine sign of each incoming number *in this partition*
        enabled = Signal(reset_less=True)
        m.d.comb += enabled.eq(self.part & self.msb & self.signed)

        # for 8-bit values: form a * 0xFF00 by using -a * 0x100, the
        # negation operation is split into a bitwise not and a +1.
        # likewise for 16, 32, and 64-bit values.
        m.d.comb += [
            # width-extended 1s complement if a is signed, otherwise zero
            self.nt.eq(Mux(enabled, Cat(ext, ~self.op), 0)),
            # add 1 if signed, otherwise add zero
            self.nl.eq(Cat(ext, enabled, Repl(0, bit_wid-1))),
        ]

        return m
727
728
class Parts(Elaboratable):
    """Derives one ``parts[i]`` flag per partition from the (expanded)
    partition points.
    """

    def __init__(self, pbwid, epps, n_parts):
        """Create a ``Parts``.

        :param pbwid: number of partition bytes collected in elaborate.
        :param epps: expanded partition points; ``Signal`` copies of them
            are stored in ``self.epps``.
        :param n_parts: number of ``parts`` output signals.
        """
        self.pbwid = pbwid
        # inputs
        self.epps = PartitionPoints.like(epps, name="epps")  # expanded points
        # outputs
        self.parts = [Signal(name=f"part_{i}") for i in range(n_parts)]

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        comb = m.d.comb

        epps, parts = self.epps, self.parts
        # collect part-bytes (double factor because the input is extended)
        pbs = Signal(self.pbwid, reset_less=True)
        tl = []
        for i in range(self.pbwid):
            pb = Signal(name="pb%d" % i, reset_less=True)
            comb += pb.eq(epps.part_byte(i, mfactor=2))  # double
            tl.append(pb)
        comb += pbs.eq(Cat(*tl))

        # negated-temporary copy of partition bits
        npbs = Signal.like(pbs, reset_less=True)
        comb += npbs.eq(~pbs)
        byte_count = 8 // len(parts)
        for i, part in enumerate(parts):
            first = i * byte_count
            last = (i + 1) * byte_count - 1
            # negated bit just below the partition, the plain bits inside
            # it, and the negated bit at its top
            pbl = [npbs[first - 1]]
            pbl += [pbs[j] for j in range(first, last)]
            pbl.append(npbs[last])
            value = Signal(len(pbl), name="value_%d" % i, reset_less=True)
            comb += value.eq(Cat(*pbl))
            # the part flag is set only when every collected bit is zero
            comb += part.eq(~(value).bool())

        return m
766
767
class Part(Elaboratable):
    """ a key class which, depending on the partitioning, will determine
        what action to take when parts of the output are signed or unsigned.

        this requires 2 pieces of data *per operand, per partition*:
        whether the MSB is HI/LO (per partition!), and whether a signed
        or unsigned operation has been *requested*.

        once that is determined, signed is basically carried out
        by splitting 2's complement into 1's complement plus one.
        1's complement is just a bit-inversion.

        the extra terms - as separate terms - are then thrown at the
        AddReduce alongside the multiplication part-results.
    """

    def __init__(self, epps, width, n_parts, n_levels, pbwid):
        """Create a ``Part``.

        :param epps: expanded partition points, handed to ``Parts``.
        :param width: bit width of the four output terms.
        :param n_parts: number of partitions.
        :param n_levels: accepted but not used by this class.
        :param pbwid: bit width of ``pbs``; also handed to ``Parts``.
        """
        self.pbwid = pbwid
        self.epps = epps

        # inputs
        self.a = Signal(64)
        self.b = Signal(64)
        self.a_signed = [Signal(name=f"a_signed_{i}") for i in range(8)]
        self.b_signed = [Signal(name=f"_b_signed_{i}") for i in range(8)]
        self.pbs = Signal(pbwid, reset_less=True)

        # outputs
        self.parts = [Signal(name=f"part_{i}") for i in range(n_parts)]

        self.not_a_term = Signal(width)
        self.neg_lsb_a_term = Signal(width)
        self.not_b_term = Signal(width)
        self.neg_lsb_b_term = Signal(width)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        comb = m.d.comb

        epps = self.epps
        m.submodules.p = p = Parts(self.pbwid, epps, len(self.parts))
        comb += p.epps.eq(epps)
        # from here on the part flags computed by the submodule are used
        parts = p.parts

        byte_width = 8 // len(parts)    # byte width
        bit_wid = 8 * byte_width        # bit width
        nat, nbt, nla, nlb = [], [], [], []
        for i, part in enumerate(parts):
            lsb = bit_wid * i
            msb = (i + 1) * bit_wid - 1
            first_byte = i * byte_width

            # work out bit-inverted and +1 term for a.
            pa = LSBNegTerm(bit_wid)
            setattr(m.submodules, "lnt_%d_a_%d" % (bit_wid, i), pa)
            comb += [
                pa.part.eq(part),
                pa.op.eq(self.a.part(lsb, bit_wid)),
                pa.signed.eq(self.b_signed[first_byte]),  # yes b
                pa.msb.eq(self.b[msb]),  # really, b
            ]
            nat.append(pa.nt)
            nla.append(pa.nl)

            # work out bit-inverted and +1 term for b
            pb = LSBNegTerm(bit_wid)
            setattr(m.submodules, "lnt_%d_b_%d" % (bit_wid, i), pb)
            comb += [
                pb.part.eq(part),
                pb.op.eq(self.b.part(lsb, bit_wid)),
                pb.signed.eq(self.a_signed[first_byte]),  # yes a
                pb.msb.eq(self.a[msb]),  # really, a
            ]
            nbt.append(pb.nt)
            nlb.append(pb.nl)

        # concatenate together and return all 4 results.
        comb += [self.not_a_term.eq(Cat(*nat)),
                 self.not_b_term.eq(Cat(*nbt)),
                 self.neg_lsb_a_term.eq(Cat(*nla)),
                 self.neg_lsb_b_term.eq(Cat(*nlb)),
                 ]

        return m
850
851
class IntermediateOut(Elaboratable):
    """ selects the HI/LO part of the multiplication, for a given bit-width
        the output is also reconstructed in its SIMD (partition) lanes.
    """

    def __init__(self, width, out_wid, n_parts):
        """Create an ``IntermediateOut``.

        :param width: bit width of one output lane.
        :param out_wid: bit width of ``intermed``; ``output`` is half that.
        :param n_parts: number of lanes.
        """
        self.width = width
        self.n_parts = n_parts
        self.part_ops = [Signal(2, name="dpop%d" % i, reset_less=True)
                         for i in range(8)]
        self.intermed = Signal(out_wid, reset_less=True)
        self.output = Signal(out_wid//2, reset_less=True)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()

        w = self.width
        sel = w // 8  # stride through part_ops from one lane to the next
        ol = []
        for i in range(self.n_parts):
            op = Signal(w, reset_less=True, name="op%d_%d" % (w, i))
            # each lane owns 2*w bits of intermed: low half, then high half
            lo_half = self.intermed.part(i * w*2, w)
            hi_half = self.intermed.part(i * w*2 + w, w)
            want_low = self.part_ops[sel * i] == OP_MUL_LOW
            m.d.comb += op.eq(Mux(want_low, lo_half, hi_half))
            ol.append(op)
        m.d.comb += self.output.eq(Cat(*ol))

        return m
880
881
class FinalOut(Elaboratable):
    """ selects the final output based on the partitioning.

        each byte is selectable independently, i.e. it is possible
        that some partitions requested 8-bit computation whilst others
        requested 16 or 32 bit.
    """

    def __init__(self, out_wid):
        """Create a ``FinalOut``.

        :param out_wid: bit width of ``i8``..``i64`` and of ``out``.
        """
        # inputs
        self.d8 = [Signal(name=f"d8_{i}", reset_less=True) for i in range(8)]
        self.d16 = [Signal(name=f"d16_{i}", reset_less=True) for i in range(4)]
        self.d32 = [Signal(name=f"d32_{i}", reset_less=True) for i in range(2)]

        self.i8 = Signal(out_wid, reset_less=True)
        self.i16 = Signal(out_wid, reset_less=True)
        self.i32 = Signal(out_wid, reset_less=True)
        self.i64 = Signal(out_wid, reset_less=True)

        # output
        self.out = Signal(out_wid, reset_less=True)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        ol = []
        for i in range(8):
            # select one of the outputs: d8 selects i8, d16 selects i16
            # d32 selects i32, and the default is i64.
            # d8 and d16 are ORed together in the first Mux
            # then the 2nd selects either i8 or i16.
            # if neither d8 nor d16 are set, d32 selects either i32 or i64.
            op = Signal(8, reset_less=True, name="op_%d" % i)
            narrow = Mux(self.d8[i],
                         self.i8.part(i * 8, 8),
                         self.i16.part(i * 8, 8))
            wide = Mux(self.d32[i // 4],
                       self.i32.part(i * 8, 8),
                       self.i64.part(i * 8, 8))
            m.d.comb += op.eq(
                Mux(self.d8[i] | self.d16[i // 2], narrow, wide))
            ol.append(op)
        m.d.comb += self.out.eq(Cat(*ol))
        return m
922
923
class OrMod(Elaboratable):
    """ bitwise-ORs four equal-width inputs using a two-level tree:
        two pairwise ORs followed by an OR of the two partial results.
    """
    def __init__(self, wid):
        self.wid = wid
        # the four inputs
        self.orin = [Signal(wid, name="orin%d" % i, reset_less=True)
                     for i in range(4)]
        # the combined result
        self.orout = Signal(wid, reset_less=True)

    def elaborate(self, platform):
        m = Module()

        # first level of the tree: one partial result per input pair
        or1 = Signal(self.wid, reset_less=True)
        or2 = Signal(self.wid, reset_less=True)
        in0, in1, in2, in3 = self.orin

        m.d.comb += [
            or1.eq(in0 | in1),
            or2.eq(in2 | in3),
            # second level: merge the two partial results
            self.orout.eq(or1 | or2),
        ]

        return m
942
943
class Signs(Elaboratable):
    """ decodes an operation code (OP_MUL_*) into two flags saying
        whether operand a and operand b are to be treated as signed.
    """

    def __init__(self):
        # input: the 2-bit operation code
        self.part_ops = Signal(2, reset_less=True)
        # outputs: signedness of each operand
        self.a_signed = Signal(reset_less=True)
        self.b_signed = Signal(reset_less=True)

    def elaborate(self, platform):
        m = Module()

        op = self.part_ops

        # a is signed for every operation except OP_MUL_UNSIGNED_HIGH
        a_is_signed = op != OP_MUL_UNSIGNED_HIGH

        # b is signed only for OP_MUL_LOW and OP_MUL_SIGNED_HIGH
        is_low = op == OP_MUL_LOW
        is_signed_high = op == OP_MUL_SIGNED_HIGH
        b_is_signed = is_low | is_signed_high

        m.d.comb += self.a_signed.eq(a_is_signed)
        m.d.comb += self.b_signed.eq(b_is_signed)

        return m
965
966
class Mul8_16_32_64(Elaboratable):
    """Signed/Unsigned 8/16/32/64-bit partitioned integer multiplier.

    The 64-bit operands may be split into any combination of 8, 16, 32 and
    64-bit partitions on naturally-aligned boundaries, and the operation
    can be chosen independently for each partition.

    :attribute part_pts: the input partition points. There is one partition
        point at every multiple of 8 in 0 < i < 64, and the ``Value``
        associated with each point is a ``Signal``. Modification is not
        supported, other than through ``Signal.eq``.
    :attribute part_ops: the operation for each byte. To select the
        operation for a partition, assign the chosen operation code to every
        byte of that partition. The allowed operation codes are:

        :attribute OP_MUL_LOW: the LSB half of the product. Equivalent to
            RISC-V's `mul` instruction.
        :attribute OP_MUL_SIGNED_HIGH: the MSB half of the product where both
            ``a`` and ``b`` are signed. Equivalent to RISC-V's `mulh`
            instruction.
        :attribute OP_MUL_SIGNED_UNSIGNED_HIGH: the MSB half of the product
            where ``a`` is signed and ``b`` is unsigned. Equivalent to RISC-V's
            `mulhsu` instruction.
        :attribute OP_MUL_UNSIGNED_HIGH: the MSB half of the product where both
            ``a`` and ``b`` are unsigned. Equivalent to RISC-V's `mulhu`
            instruction.
    """

    def __init__(self, register_levels=()):
        """ register_levels: the points in the cascade at which
            flip-flops are to be inserted.
        """

        # parameter(s)
        self.register_levels = list(register_levels)

        # inputs
        self.part_pts = PartitionPoints()
        for i in range(8, 64, 8):
            self.part_pts[i] = Signal(name=f"part_pts_{i}")
        self.part_ops = [Signal(2, name=f"part_ops_{i}") for i in range(8)]
        self.a = Signal(64)
        self.b = Signal(64)

        # intermediates (needed for unit tests)
        self._intermediate_output = Signal(128)

        # output
        self.output = Signal(64)

    def elaborate(self, platform):
        """ wires the submodules of the multiplier together. """
        m = Module()

        # collect part-bytes: one flag per byte, bundled into ``pbs``
        pbs = Signal(8, reset_less=True)
        part_bytes = []
        for i in range(8):
            pb = Signal(name="pb%d" % i, reset_less=True)
            m.d.comb += pb.eq(self.part_pts.part_byte(i))
            part_bytes.append(pb)
        m.d.comb += pbs.eq(Cat(*part_bytes))

        # create (doubled) PartitionPoints (output is double input width)
        eps = PartitionPoints()
        for i, enabled in self.part_pts.items():
            ep = Signal(name=f"expanded_part_pts_{i*2}", reset_less=True)
            eps[i * 2] = ep
            m.d.comb += ep.eq(enabled)

        # one sign decoder per byte, each fed by that byte's operation
        signs = [Signs() for _ in range(8)]
        for i, sign in enumerate(signs):
            setattr(m.submodules, "signs%d" % i, sign)
            m.d.comb += sign.part_ops.eq(self.part_ops[i])

        # one Part module per partition size (8, 4, 2 and 1 parts)
        n_levels = len(self.register_levels) + 1
        m.submodules.part_8 = part_8 = Part(eps, 128, 8, n_levels, 8)
        m.submodules.part_16 = part_16 = Part(eps, 128, 4, n_levels, 8)
        m.submodules.part_32 = part_32 = Part(eps, 128, 2, n_levels, 8)
        m.submodules.part_64 = part_64 = Part(eps, 128, 1, n_levels, 8)
        part_mods = (part_8, part_16, part_32, part_64)
        for part in part_mods:
            m.d.comb += part.a.eq(self.a)
            m.d.comb += part.b.eq(self.b)
            for i, sign in enumerate(signs):
                m.d.comb += part.a_signed[i].eq(sign.a_signed)
                m.d.comb += part.b_signed[i].eq(sign.b_signed)
            m.d.comb += part.pbs.eq(pbs)
        # gather the four term outputs of every Part module
        nat_l = [part.not_a_term for part in part_mods]
        nbt_l = [part.not_b_term for part in part_mods]
        nla_l = [part.neg_lsb_a_term for part in part_mods]
        nlb_l = [part.neg_lsb_b_term for part in part_mods]

        # product terms, one ProductTerms module per value of a_index
        terms = []
        for a_index in range(8):
            product = ProductTerms(8, 128, 8, a_index, 8)
            setattr(m.submodules, "terms_%d" % a_index, product)

            m.d.comb += product.a.eq(self.a)
            m.d.comb += product.b.eq(self.b)
            m.d.comb += product.pb_en.eq(pbs)

            terms.extend(product.terms)

        # it's fine to bitwise-or data together since they are never enabled
        # at the same time
        m.submodules.nat_or = nat_or = OrMod(128)
        m.submodules.nbt_or = nbt_or = OrMod(128)
        m.submodules.nla_or = nla_or = OrMod(128)
        m.submodules.nlb_or = nlb_or = OrMod(128)
        or_wiring = ((nat_l, nat_or),
                     (nbt_l, nbt_or),
                     (nla_l, nla_or),
                     (nlb_l, nlb_or))
        for sources, or_mod in or_wiring:
            for i, source in enumerate(sources):
                m.d.comb += or_mod.orin[i].eq(source)
            terms.append(or_mod.orout)

        # sum every term
        add_reduce = AddReduce(terms,
                               128,
                               self.register_levels,
                               eps,
                               self.part_ops)

        # operations and partition points as seen at the last level
        last_level = add_reduce.levels[-1]
        out_part_ops = last_level.out_part_ops
        out_part_pts = last_level._reg_partition_points

        m.submodules.add_reduce = add_reduce
        m.d.comb += self._intermediate_output.eq(add_reduce.output)

        # create the 64, 32, 16 and 8-bit candidate outputs: each one is
        # an IntermediateOut fed with the same intermediate result and
        # the same per-byte operations.
        inter_outs = {}
        for width in (64, 32, 16, 8):
            inter = IntermediateOut(width, 128, 64 // width)
            setattr(m.submodules, "io%d" % width, inter)
            m.d.comb += inter.intermed.eq(self._intermediate_output)
            for i in range(8):
                m.d.comb += inter.part_ops[i].eq(out_part_ops[i])
            inter_outs[width] = inter

        # one Parts module per partition size, all driven from the same
        # partition points
        m.submodules.p_8 = p_8 = Parts(8, eps, len(part_8.parts))
        m.submodules.p_16 = p_16 = Parts(8, eps, len(part_16.parts))
        m.submodules.p_32 = p_32 = Parts(8, eps, len(part_32.parts))
        m.submodules.p_64 = p_64 = Parts(8, eps, len(part_64.parts))

        for parts_mod in (p_8, p_16, p_32, p_64):
            m.d.comb += parts_mod.epps.eq(out_part_pts)

        # final output
        m.submodules.finalout = finalout = FinalOut(64)
        selector_wiring = ((finalout.d8, part_8, p_8),
                           (finalout.d16, part_16, p_16),
                           (finalout.d32, part_32, p_32))
        for selectors, part, parts_mod in selector_wiring:
            for i in range(len(part.parts)):
                m.d.comb += selectors[i].eq(parts_mod.parts[i])
        m.d.comb += finalout.i8.eq(inter_outs[8].output)
        m.d.comb += finalout.i16.eq(inter_outs[16].output)
        m.d.comb += finalout.i32.eq(inter_outs[32].output)
        m.d.comb += finalout.i64.eq(inter_outs[64].output)
        m.d.comb += self.output.eq(finalout.out)

        return m
1149
1150
if __name__ == "__main__":
    # hand the multiplier, together with every signal that is to be
    # exposed as a port, to the nmigen command-line entry point.
    m = Mul8_16_32_64()
    ports = [m.a, m.b, m._intermediate_output, m.output]
    ports.extend(m.part_ops)
    ports.extend(m.part_pts.values())
    main(m, ports=ports)