pass in partition step parameter
[ieee754fpu.git] / src / ieee754 / part_mul_add / multiply.py
1 # SPDX-License-Identifier: LGPL-2.1-or-later
2 # See Notices.txt for copyright information
3 """Integer Multiplication."""
4
5 from nmigen import Signal, Module, Value, Elaboratable, Cat, C, Mux, Repl
6 from nmigen.hdl.ast import Assign
7 from abc import ABCMeta, abstractmethod
8 from nmigen.cli import main
9 from functools import reduce
10 from operator import or_
11
12
class PartitionPoints(dict):
    """Partition points and corresponding ``Value``s.

    The points at where an ALU is partitioned along with ``Value``s that
    specify if the corresponding partition points are enabled.

    For example: ``{1: True, 5: True, 10: True}`` with
    ``width == 16`` specifies that the ALU is split into 4 sections:
    * bits 0 <= ``i`` < 1
    * bits 1 <= ``i`` < 5
    * bits 5 <= ``i`` < 10
    * bits 10 <= ``i`` < 16

    If the partition_points were instead ``{1: True, 5: a, 10: True}``
    where ``a`` is a 1-bit ``Signal``:
    * If ``a`` is asserted:
        * bits 0 <= ``i`` < 1
        * bits 1 <= ``i`` < 5
        * bits 5 <= ``i`` < 10
        * bits 10 <= ``i`` < 16
    * Otherwise
        * bits 0 <= ``i`` < 1
        * bits 1 <= ``i`` < 10
        * bits 10 <= ``i`` < 16
    """

    def __init__(self, partition_points=None):
        """Create a new ``PartitionPoints``.

        :param partition_points: the input partition points to values
            mapping.  Every value is passed through ``Value.wrap``.
        :raises TypeError: if a point is not an ``int``.
        :raises ValueError: if a point is negative.
        """
        super().__init__()
        if partition_points is not None:
            for point, enabled in partition_points.items():
                if not isinstance(point, int):
                    raise TypeError("point must be a non-negative integer")
                if point < 0:
                    raise ValueError("point must be a non-negative integer")
                self[point] = Value.wrap(enabled)

    def like(self, name=None, src_loc_at=0, mul=1):
        """Create a new ``PartitionPoints`` with ``Signal``s for all values.

        :param name: the base name for the new ``Signal``s.  When omitted,
            the name is taken from the variable the caller assigns to.
        :param src_loc_at: extra stack depth to skip when inferring ``name``.
        :param mul: a multiplication factor on the indices
        """
        if name is None:
            # must stay directly in this method: the throw-away Signal
            # inspects the caller's frame to pick up the variable name.
            name = Signal(src_loc_at=1+src_loc_at).name  # get variable name
        retval = PartitionPoints()
        for point, enabled in self.items():
            point *= mul
            retval[point] = Signal(enabled.shape(), name=f"{name}_{point}")
        return retval

    def eq(self, rhs):
        """Assign ``PartitionPoints`` using ``Signal.eq``.

        This is a generator: the statements (and the ``ValueError`` for
        mismatching point sets) are only produced when it is iterated.
        """
        if set(self.keys()) != set(rhs.keys()):
            raise ValueError("incompatible point set")
        for point, enabled in self.items():
            yield enabled.eq(rhs[point])

    def as_mask(self, width, mul=1):
        """Create a bit-mask from `self`.

        Each bit in the returned mask is clear only if the partition point at
        the same bit-index is enabled.

        :param width: the bit width of the resulting mask
        :param mul: a "multiplier" which in-place expands the partition points
                    typically set to "2" when used for multipliers
        """
        bits = []
        for i in range(width):
            # integer arithmetic: bit i corresponds to point i/mul only
            # when the division is exact.
            point, remainder = divmod(i, mul)
            if remainder == 0 and point in self:
                bits.append(~self[point])
            else:
                bits.append(True)
        return Cat(*bits)

    def get_max_partition_count(self, width):
        """Get the maximum number of partitions.

        Gets the number of partitions when all partition points are enabled.
        Only points below ``width`` are counted.
        """
        return 1 + sum(1 for point in self if point < width)

    def fits_in_width(self, width):
        """Check if all partition points are smaller than `width`."""
        return all(point < width for point in self)

    def part_byte(self, index, mfactor=1):  # mfactor used for "expanding"
        """Get the partition value at the upper boundary of byte ``index``.

        Indices -1 and 7 have no stored point and yield a constant 1;
        other indices must lie in 0..7 and look up bit position
        ``(index * 8 + 8) * mfactor``.

        :param index: byte index, -1 to 7 inclusive.
        :param mfactor: multiplier applied to the looked-up bit position.
        """
        if index == -1 or index == 7:
            return C(True, 1)
        assert index >= 0 and index < 8
        return self[(index * 8 + 8)*mfactor]
116
117
class FullAdder(Elaboratable):
    """Full Adder.

    :attribute in0: the first input
    :attribute in1: the second input
    :attribute in2: the third input
    :attribute sum: the sum output
    :attribute carry: the carry output

    Rather than do individual full adders (and have an array of them,
    which would be very slow to simulate), this module can specify the
    bit width of the inputs and outputs: in effect it performs multiple
    Full 3-2 Add operations "in parallel".
    """

    def __init__(self, width):
        """Create a ``FullAdder``.

        :param width: the bit width of the input and output
        """
        # three addends
        self.in0 = Signal(width, reset_less=True)
        self.in1 = Signal(width, reset_less=True)
        self.in2 = Signal(width, reset_less=True)
        # bitwise results
        self.sum = Signal(width, reset_less=True)
        self.carry = Signal(width, reset_less=True)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        in0, in1, in2 = self.in0, self.in1, self.in2
        # per-bit sum is the XOR of the three inputs
        m.d.comb += self.sum.eq(in0 ^ in1 ^ in2)
        # per-bit carry is set when at least two inputs are set
        majority = (in0 & in1) | (in1 & in2) | (in2 & in0)
        m.d.comb += self.carry.eq(majority)
        return m
152
153
class MaskedFullAdder(Elaboratable):
    """Masked Full Adder.

    :attribute mask: the carry partition mask
    :attribute in0: the first input
    :attribute in1: the second input
    :attribute in2: the third input
    :attribute sum: the sum output
    :attribute mcarry: the masked carry output

    FullAdders are always used with a "mask" on the output.  To keep
    the graphviz "clean", this class performs the masking here rather
    than inside a large for-loop.

    See the following discussion as to why this is no longer derived
    from FullAdder.  Each carry is shifted here *before* being ANDed
    with the mask, so that an AOI cell may be used (which is more
    gate-efficient)
    https://en.wikipedia.org/wiki/AND-OR-Invert
    https://groups.google.com/d/msg/comp.arch/fcq-GLQqvas/vTxmcA0QAgAJ
    """

    def __init__(self, width):
        """Create a ``MaskedFullAdder``.

        :param width: the bit width of the input and output
        """
        self.width = width
        self.mask = Signal(width, reset_less=True)
        self.mcarry = Signal(width, reset_less=True)
        self.in0 = Signal(width, reset_less=True)
        self.in1 = Signal(width, reset_less=True)
        self.in2 = Signal(width, reset_less=True)
        self.sum = Signal(width, reset_less=True)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        width = self.width
        # shifted copies of the inputs (s*) and masked pairwise ANDs (c*);
        # the variable names double as the generated signal names.
        s1 = Signal(width, reset_less=True)
        s2 = Signal(width, reset_less=True)
        s3 = Signal(width, reset_less=True)
        c1 = Signal(width, reset_less=True)
        c2 = Signal(width, reset_less=True)
        c3 = Signal(width, reset_less=True)

        m.d.comb += self.sum.eq(self.in0 ^ self.in1 ^ self.in2)
        # shift each input up by one bit (a zero enters at bit 0)
        m.d.comb += [
            s1.eq(Cat(0, self.in0)),
            s2.eq(Cat(0, self.in1)),
            s3.eq(Cat(0, self.in2)),
        ]
        # mask every pairwise carry term
        m.d.comb += [
            c1.eq(s1 & s2 & self.mask),
            c2.eq(s2 & s3 & self.mask),
            c3.eq(s3 & s1 & self.mask),
        ]
        m.d.comb += self.mcarry.eq(c1 | c2 | c3)
        return m
207
208
class PartitionedAdder(Elaboratable):
    """Partitioned Adder.

    Performs the final add.  The partition points are included in the
    actual add (in one of the operands only), which causes a carry over
    to the next bit.  Then the final output *removes* the extra bits from
    the result.

    partition: .... P... P... P... P... (32 bits)
    a        : .... .... .... .... .... (32 bits)
    b        : .... .... .... .... .... (32 bits)
    exp-a    : ....P....P....P....P.... (32+4 bits, P=1 if no partition)
    exp-b    : ....0....0....0....0.... (32 bits plus 4 zeros)
    exp-o    : ....xN...xN...xN...xN... (32+4 bits - x to be discarded)
    o        : .... N... N... N... N... (32 bits - x ignored, N is carry-over)

    :attribute width: the bit width of the input and output. Read-only.
    :attribute a: the first input to the adder
    :attribute b: the second input to the adder
    :attribute output: the sum output
    :attribute partition_points: the input partition points. Modification not
        supported, except for by ``Signal.eq``.
    """

    def __init__(self, width, partition_points, partition_step=1):
        """Create a ``PartitionedAdder``.

        :param width: the bit width of the input and output
        :param partition_points: the input partition points
        :param partition_step: a multiplier (typically double) step
                               which in-place "expands" the partition points
        :raises ValueError: if a partition point is not below ``width``.
        """
        self.width = width
        self.pmul = partition_step
        self.a = Signal(width, reset_less=True)
        self.b = Signal(width, reset_less=True)
        self.output = Signal(width, reset_less=True)
        self.partition_points = PartitionPoints(partition_points)
        if not self.partition_points.fits_in_width(width):
            raise ValueError("partition_points doesn't fit in width")
        # one extra bit per break actually inserted by elaborate(); the
        # same (scaled) test is used in both places so the sizes agree.
        extra_bits = sum(1 for i in range(width)
                         if self._partition_point_at(i) is not None)
        self._expanded_width = width + extra_bits

    def _partition_point_at(self, bit_index):
        """Return the partition point mapped onto ``bit_index``, or None.

        A point ``p`` maps onto bit ``p * partition_step``.
        """
        point, remainder = divmod(bit_index, self.pmul)
        if remainder == 0 and point in self.partition_points:
            return point
        return None

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        expanded_a = Signal(self._expanded_width, reset_less=True)
        expanded_b = Signal(self._expanded_width, reset_less=True)
        expanded_o = Signal(self._expanded_width, reset_less=True)

        expanded_index = 0
        # store bits in a list, use Cat later.  graphviz is much cleaner
        al, bl, ol, ea, eb, eo = [], [], [], [], [], []

        # partition points are "breaks" (extra zeros or 1s) in what would
        # otherwise be a massive long add.  when the "break" points are 0,
        # whatever is in it (in the output) is discarded.  however when
        # there is a "1", it causes a roll-over carry to the *next* bit.
        # we still ignore the "break" bit in the [intermediate] output,
        # however by that time we've got the effect that we wanted: the
        # carry has been carried *over* the break point.

        for i in range(self.width):
            point = self._partition_point_at(i)
            if point is not None:
                # add extra bit set to 0 + 0 for enabled partition points
                # and 1 + 0 for disabled partition points
                ea.append(expanded_a[expanded_index])
                al.append(~self.partition_points[point])  # extra bit in a
                eb.append(expanded_b[expanded_index])
                bl.append(C(0))  # yes, add a zero
                expanded_index += 1  # skip the extra point. NOT in the output
            ea.append(expanded_a[expanded_index])
            eb.append(expanded_b[expanded_index])
            eo.append(expanded_o[expanded_index])
            al.append(self.a[i])
            bl.append(self.b[i])
            ol.append(self.output[i])
            expanded_index += 1

        # combine above using Cat
        m.d.comb += Cat(*ea).eq(Cat(*al))
        m.d.comb += Cat(*eb).eq(Cat(*bl))
        m.d.comb += Cat(*ol).eq(Cat(*eo))

        # use only one addition to take advantage of look-ahead carry and
        # special hardware on FPGAs
        m.d.comb += expanded_o.eq(expanded_a + expanded_b)
        return m
302
303
# number of inputs consumed by one full adder in the add-reduce tree
FULL_ADDER_INPUT_COUNT = 3
305
class AddReduceData:
    """Data bundle passed between add-reduce levels.

    :attribute part_ops: one 2-bit ``Signal`` per partition.
    :attribute terms: the ``Signal``s to be summed, each ``output_width``
        bits wide.
    :attribute part_pts: a ``Signal`` copy (via ``like``) of the
        partition points.
    """

    def __init__(self, part_pts, n_inputs, output_width, n_parts):
        """Create the signals.

        :param part_pts: partition points to mirror.
        :param n_inputs: number of terms.
        :param output_width: bit-width of every term.
        :param n_parts: number of ``part_ops`` entries.
        """
        self.part_ops = [
            Signal(2, name=f"part_ops_{i}", reset_less=True)
            for i in range(n_parts)
        ]
        self.terms = [
            Signal(output_width, name=f"inputs_{i}", reset_less=True)
            for i in range(n_inputs)
        ]
        # the attribute name is picked up by like() as the signal base name
        self.part_pts = part_pts.like()

    def eq_from(self, part_pts, inputs, part_ops):
        """Return the assignments loading this bundle from separate pieces."""
        assigns = [self.part_pts.eq(part_pts)]
        assigns += [term.eq(inputs[i]) for i, term in enumerate(self.terms)]
        assigns += [op.eq(part_ops[i]) for i, op in enumerate(self.part_ops)]
        return assigns

    def eq(self, rhs):
        """Return the assignments copying another ``AddReduceData``."""
        return self.eq_from(rhs.part_pts, rhs.terms, rhs.part_ops)
325
326
class FinalReduceData:
    """Data bundle produced by the last add-reduce stage.

    :attribute part_ops: one 2-bit ``Signal`` per partition.
    :attribute output: the summed result, ``output_width`` bits wide.
    :attribute part_pts: a ``Signal`` copy (via ``like``) of the
        partition points.
    """

    def __init__(self, part_pts, output_width, n_parts):
        """Create the signals.

        :param part_pts: partition points to mirror.
        :param output_width: bit-width of ``output``.
        :param n_parts: number of ``part_ops`` entries.
        """
        self.part_ops = [
            Signal(2, name=f"part_ops_{i}", reset_less=True)
            for i in range(n_parts)
        ]
        self.output = Signal(output_width, reset_less=True)
        # the attribute name is picked up by like() as the signal base name
        self.part_pts = part_pts.like()

    def eq_from(self, part_pts, output, part_ops):
        """Return the assignments loading this bundle from separate pieces."""
        assigns = [self.part_pts.eq(part_pts), self.output.eq(output)]
        assigns += [op.eq(part_ops[i]) for i, op in enumerate(self.part_ops)]
        return assigns

    def eq(self, rhs):
        """Return the assignments copying another ``FinalReduceData``."""
        return self.eq_from(rhs.part_pts, rhs.output, rhs.part_ops)
343
344
class FinalAdd(Elaboratable):
    """Final stage of add reduce.

    Takes at most two remaining terms and produces a single output,
    using a ``PartitionedAdder`` when there are two of them.
    """

    def __init__(self, n_inputs, output_width, n_parts, partition_points,
                 partition_step=1):
        """Create a ``FinalAdd``.

        :param n_inputs: number of input terms (0, 1 or 2).
        :param output_width: bit-width of the terms and of the output.
        :param n_parts: number of ``part_ops`` entries.
        :param partition_points: the input partition points.
        :param partition_step: forwarded to ``PartitionedAdder``.
        :raises ValueError: if a partition point is not below
            ``output_width``.
        """
        self.partition_step = partition_step
        self.output_width = output_width
        self.n_inputs = n_inputs
        self.n_parts = n_parts
        self.partition_points = PartitionPoints(partition_points)
        if not self.partition_points.fits_in_width(output_width):
            raise ValueError("partition_points doesn't fit in output_width")

        self.i = self.ispec()
        self.o = self.ospec()

    def ispec(self):
        """Build the input data bundle."""
        return AddReduceData(self.partition_points, self.n_inputs,
                             self.output_width, self.n_parts)

    def ospec(self):
        """Build the output data bundle."""
        return FinalReduceData(self.partition_points,
                               self.output_width, self.n_parts)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()

        terms = self.i.terms
        output = Signal(self.output_width, reset_less=True)
        if self.n_inputs == 0:
            # nothing to add: the result defaults to zero
            m.d.comb += output.eq(0)
        elif self.n_inputs == 1:
            # a lone term is passed straight through
            m.d.comb += output.eq(terms[0])
        else:
            # two terms left: do the real (partitioned) addition
            assert self.n_inputs == 2
            adder = PartitionedAdder(self.output_width,
                                     self.i.part_pts, self.partition_step)
            m.submodules.final_adder = adder
            m.d.comb += [
                adder.a.eq(terms[0]),
                adder.b.eq(terms[1]),
                output.eq(adder.output),
            ]

        # forward the result together with partition points and ops
        m.d.comb += self.o.eq_from(self.i.part_pts, output,
                                   self.i.part_ops)

        return m
397
398
class AddReduceSingle(Elaboratable):
    """One level of the add-reduce tree.

    The input terms are grouped in threes; each group goes through a
    ``MaskedFullAdder`` which yields two terms (sum and masked carry).
    Terms left over after grouping are passed through unchanged.

    :attribute i: input data (``AddReduceData`` with ``n_inputs`` terms).
    :attribute o: output data (``AddReduceData`` with ``n_terms`` terms).
    :attribute partition_points: the input partition points. Modification not
        supported, except for by ``Signal.eq``.
    """

    def __init__(self, n_inputs, output_width, n_parts, partition_points,
                 partition_step=2):
        """Create an ``AddReduceSingle``.

        :param n_inputs: number of input terms to be reduced.
        :param output_width: bit-width of every term.
        :param n_parts: number of ``part_ops`` entries.
        :param partition_points: the input partition points.
        :param partition_step: multiplier applied to the partition points
            when building the carry mask.  Defaults to 2, the value that
            was previously hard-coded.
        :raises ValueError: if a partition point is not below
            ``output_width``.
        """
        self.n_inputs = n_inputs
        self.n_parts = n_parts
        self.output_width = output_width
        self.partition_step = partition_step
        self.partition_points = PartitionPoints(partition_points)
        if not self.partition_points.fits_in_width(output_width):
            raise ValueError("partition_points doesn't fit in output_width")

        self.groups = AddReduceSingle.full_adder_groups(n_inputs)
        self.n_terms = AddReduceSingle.calc_n_inputs(n_inputs, self.groups)

        self.i = self.ispec()
        self.o = self.ospec()

    def ispec(self):
        """Build the input data bundle."""
        return AddReduceData(self.partition_points, self.n_inputs,
                             self.output_width, self.n_parts)

    def ospec(self):
        """Build the output data bundle."""
        return AddReduceData(self.partition_points, self.n_terms,
                             self.output_width, self.n_parts)

    @staticmethod
    def calc_n_inputs(n_inputs, groups):
        """Number of terms produced by a level with ``n_inputs`` inputs.

        Each group yields two terms; the ungrouped remainder (0, 1 or 2
        terms) is passed through.
        """
        return len(groups) * 2 + n_inputs % FULL_ADDER_INPUT_COUNT

    @staticmethod
    def get_max_level(input_count):
        """Get the maximum level.

        All ``register_levels`` must be less than or equal to the maximum
        level.
        """
        retval = 0
        while True:
            groups = AddReduceSingle.full_adder_groups(input_count)
            if len(groups) == 0:
                return retval
            input_count %= FULL_ADDER_INPUT_COUNT
            input_count += 2 * len(groups)
            retval += 1

    @staticmethod
    def full_adder_groups(input_count):
        """Get ``inputs`` indices for which a full adder should be built."""
        return range(0,
                     input_count - FULL_ADDER_INPUT_COUNT + 1,
                     FULL_ADDER_INPUT_COUNT)

    def create_next_terms(self):
        """Create next intermediate terms, for linking up in elaborate.

        :returns: ``(terms, adders)`` where ``terms`` are the values for
            the next level and ``adders`` is a list of
            ``(first_input_index, MaskedFullAdder)`` pairs.
        """
        terms = []
        adders = []

        # create full adders for this recursive level.
        # this shrinks N terms to 2 * (N // 3) plus the remainder
        for i in self.groups:
            adder_i = MaskedFullAdder(self.output_width)
            adders.append((i, adder_i))
            # add both the sum and the masked-carry to the next level.
            # 3 inputs have now been reduced to 2...
            terms.append(adder_i.sum)
            terms.append(adder_i.mcarry)

        # handle the remaining inputs: just pass them to the next layer.
        # with 2 left over a half adder would gain nothing (there would
        # still be 2 terms) and passing them through saves gates.
        remainder = self.n_inputs % FULL_ADDER_INPUT_COUNT
        if remainder:
            terms.extend(self.i.terms[-remainder:])

        return terms, adders

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()

        terms, adders = self.create_next_terms()

        # copy the intermediate terms to the output
        for i, value in enumerate(terms):
            m.d.comb += self.o.terms[i].eq(value)

        # copy reg part points and part ops to output
        m.d.comb += self.o.part_pts.eq(self.i.part_pts)
        m.d.comb += [self.o.part_ops[i].eq(self.i.part_ops[i])
                     for i in range(len(self.i.part_ops))]

        # set up the partition mask (for the adders)
        part_mask = Signal(self.output_width, reset_less=True)

        # get partition points as a mask
        mask = self.i.part_pts.as_mask(self.output_width,
                                       mul=self.partition_step)
        m.d.comb += part_mask.eq(mask)

        # add and link the intermediate term modules
        for i, (iidx, adder_i) in enumerate(adders):
            setattr(m.submodules, f"adder_{i}", adder_i)

            m.d.comb += adder_i.in0.eq(self.i.terms[iidx])
            m.d.comb += adder_i.in1.eq(self.i.terms[iidx + 1])
            m.d.comb += adder_i.in2.eq(self.i.terms[iidx + 2])
            m.d.comb += adder_i.mask.eq(part_mask)

        return m
534
535
class AddReduceInternal:
    """Builds the chain of add-reduce levels for a set of terms.

    :attribute i: the input data bundle.
    :attribute inputs: the terms of ``i``.
    :attribute part_ops: the ``part_ops`` of ``i``.
    :attribute partition_points: the partition points of ``i``.
    :attribute levels: list of ``AddReduceSingle`` modules followed by
        one ``FinalAdd``.
    """

    def __init__(self, i, output_width, partition_step=1):
        """Create an ``AddReduceInternal``.

        :param i: input data bundle providing ``terms``, ``part_ops`` and
            ``part_pts``.
        :param output_width: bit-width of the terms and of the output.
        :param partition_step: forwarded to the ``FinalAdd`` stage.
        """
        self.i = i
        self.inputs = i.terms
        self.part_ops = i.part_ops
        self.output_width = output_width
        self.partition_points = i.part_pts
        self.partition_step = partition_step

        self.create_levels()

    def create_levels(self):
        """creates reduction levels"""
        levels = []
        part_pts = self.partition_points
        n_parts = len(self.part_ops)
        n_terms = len(self.inputs)

        # keep stacking levels for as long as a full adder group can be made
        while len(AddReduceSingle.full_adder_groups(n_terms)) != 0:
            level = AddReduceSingle(n_terms, self.output_width, n_parts,
                                    part_pts)
            levels.append(level)
            part_pts = level.i.part_pts
            n_terms = len(level.o.terms)

        # whatever is left goes through the final adder
        levels.append(FinalAdd(n_terms, self.output_width, n_parts,
                               part_pts, self.partition_step))

        self.levels = levels
590
591
class AddReduce(AddReduceInternal, Elaboratable):
    """Recursively Add list of numbers together.

    :attribute i: input data bundle, loaded from the constructor arguments.
    :attribute o: output data bundle holding the final sum.
    :attribute register_levels: List of nesting levels that should have
        pipeline registers.
    :attribute levels: the chained reduction modules.
    """

    def __init__(self, inputs, output_width, register_levels, part_pts,
                 part_ops, partition_step=1):
        """Create an ``AddReduce``.

        :param inputs: input ``Signal``s to be summed.
        :param output_width: bit-width of ``output``.
        :param register_levels: List of nesting levels that should have
            pipeline registers.
        :param part_pts: the input partition points.
        :param part_ops: the per-partition operation values.
        :param partition_step: forwarded to ``AddReduceInternal``.
        """
        self._inputs = inputs
        self._part_pts = part_pts
        self._part_ops = part_ops
        n_parts = len(part_ops)
        self.i = AddReduceData(part_pts, len(inputs),
                               output_width, n_parts)
        AddReduceInternal.__init__(self, self.i, output_width, partition_step)
        self.o = FinalReduceData(part_pts, output_width, n_parts)
        self.register_levels = register_levels

    @staticmethod
    def get_max_level(input_count):
        """Get the maximum level; see ``AddReduceSingle.get_max_level``."""
        return AddReduceSingle.get_max_level(input_count)

    @staticmethod
    def next_register_levels(register_levels):
        """``Iterable`` of ``register_levels`` for next recursive level."""
        for level in register_levels:
            if level > 0:
                yield level - 1

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()

        # load the input bundle from the values given at construction
        m.d.comb += self.i.eq_from(self._part_pts, self._inputs,
                                   self._part_ops)

        for idx, level in enumerate(self.levels):
            setattr(m.submodules, "next_level%d" % idx, level)

        # chain the levels: each one is fed from the previous output,
        # through a register when its index is in register_levels
        data = self.i
        for idx, level in enumerate(self.levels):
            if idx in self.register_levels:
                m.d.sync += level.i.eq(data)
            else:
                m.d.comb += level.i.eq(data)
            data = level.o  # for next loop

        # output comes from last module
        m.d.comb += self.o.eq(data)

        return m
657
658
# multiply operation codes (values fit the 2-bit part_ops signals)
OP_MUL_LOW = 0
OP_MUL_SIGNED_HIGH = 1
OP_MUL_SIGNED_UNSIGNED_HIGH = 2  # a is signed, b is unsigned
OP_MUL_UNSIGNED_HIGH = 3
663
664
def get_term(value, shift=0, enabled=None):
    """Optionally gate ``value`` and shift it left by ``shift`` bits.

    :param value: the value to process.
    :param shift: number of zero bits to place below ``value``; must not
        be negative.
    :param enabled: when given, ``value`` is replaced by 0 whenever
        ``enabled`` is not asserted.
    """
    if enabled is not None:
        value = Mux(enabled, value, 0)
    if shift <= 0:
        assert shift == 0
        return value
    # shift left by prepending zeros on the LSB side
    return Cat(Repl(C(0, 1), shift), value)
673
674
class ProductTerm(Elaboratable):
    """ this class creates a single product term (a[..]*b[..]).
        it has a design flaw in that is the *output* that is selected,
        where the multiplication(s) are combinatorially generated
        all the time.
    """

    def __init__(self, width, twidth, pbwid, a_index, b_index):
        """Create a ``ProductTerm``.

        :param width: bit-width of the slice taken from each operand.
        :param twidth: bit-width of the resulting term; the operands are
            ``twidth//2`` bits wide.
        :param pbwid: bit-width of the ``pb_en`` input.
        :param a_index: index of the slice taken from ``a``.
        :param b_index: index of the slice taken from ``b``.
        """
        self.a_index = a_index
        self.b_index = b_index
        self.pwidth = width
        self.twidth = twidth
        self.width = width*2
        self.shift = 8 * (a_index + b_index)

        self.ti = Signal(self.width, reset_less=True)
        self.term = Signal(twidth, reset_less=True)
        self.a = Signal(twidth//2, reset_less=True)
        self.b = Signal(twidth//2, reset_less=True)
        self.pb_en = Signal(pbwid, reset_less=True)

        # the pb_en bits lying between the two slice indices
        lo_index, hi_index = sorted((a_index, b_index))
        self.tl = [self.pb_en[i] for i in range(lo_index, hi_index)]
        if len(self.tl) > 0:
            self.enabled = Signal(name="te_%d_%d" % (a_index, b_index),
                                  reset_less=True)
        else:
            # same index on both sides: the term is never gated
            self.enabled = None
        self.term.name = "term_%d_%d" % (a_index, b_index)  # rename

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        if self.enabled is not None:
            # enabled only when none of the collected pb_en bits is set
            m.d.comb += self.enabled.eq(~(Cat(*self.tl).bool()))

        bsa = Signal(self.width, reset_less=True)
        bsb = Signal(self.width, reset_less=True)
        pwidth = self.pwidth
        a_slice = self.a.bit_select(self.a_index * pwidth, pwidth)
        b_slice = self.b.bit_select(self.b_index * pwidth, pwidth)
        m.d.comb += bsa.eq(a_slice)
        m.d.comb += bsb.eq(b_slice)
        m.d.comb += self.ti.eq(bsa * bsb)
        m.d.comb += self.term.eq(get_term(self.ti, self.shift, self.enabled))

        # TODO: sort out width issues, get inputs a/b switched on/off
        # (i.e. apply get_term to the selected a/b slices *before* the
        # multiply instead of to the product).  data going into the Muxes
        # is 1/2 the required width.

        return m
744
745
class ProductTerms(Elaboratable):
    """ creates a bank of product terms.  also performs the actual bit-selection
        this class is to be wrapped with a for-loop on the "a" operand.
        it creates a second-level for-loop on the "b" operand.
    """

    def __init__(self, width, twidth, pbwid, a_index, blen):
        """Create a ``ProductTerms`` bank.

        :param width: bit-width of each operand slice.
        :param twidth: bit-width of each term; operands are ``twidth//2``.
        :param pbwid: bit-width of the ``pb_en`` input.
        :param a_index: the fixed slice index used for ``a``.
        :param blen: number of ``b`` slices, i.e. number of terms.
        """
        self.a_index = a_index
        self.blen = blen
        self.pwidth = width
        self.twidth = twidth
        self.pbwid = pbwid
        self.a = Signal(twidth//2, reset_less=True)
        self.b = Signal(twidth//2, reset_less=True)
        self.pb_en = Signal(pbwid, reset_less=True)
        self.terms = [
            Signal(twidth, name="term%d" % i, reset_less=True)
            for i in range(blen)
        ]

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()

        for b_index, term_out in enumerate(self.terms):
            # one ProductTerm per b slice, all sharing the same inputs
            product = ProductTerm(self.pwidth, self.twidth, self.pbwid,
                                  self.a_index, b_index)
            setattr(m.submodules, "term_%d" % b_index, product)

            m.d.comb += product.a.eq(self.a)
            m.d.comb += product.b.eq(self.b)
            m.d.comb += product.pb_en.eq(self.pb_en)

            m.d.comb += term_out.eq(product.term)

        return m
779
780
class LSBNegTerm(Elaboratable):
    """Produces the two extra terms used for a signed operand.

    :attribute part: input, gates the whole operation.
    :attribute signed: input, gates the whole operation.
    :attribute msb: input, gates the whole operation.
    :attribute op: input operand, ``bit_width`` bits.
    :attribute nt: output, inverted ``op`` placed in the high half when
        ``part``, ``msb`` and ``signed`` are all set, otherwise zero.
    :attribute nl: output, the "+1" placed at the bottom of the high half
        under the same condition.
    """

    def __init__(self, bit_width):
        """Create an ``LSBNegTerm``.

        :param bit_width: width of ``op``; the outputs are twice as wide.
        """
        self.bit_width = bit_width
        self.part = Signal(reset_less=True)
        self.signed = Signal(reset_less=True)
        self.op = Signal(bit_width, reset_less=True)
        self.msb = Signal(reset_less=True)
        self.nt = Signal(bit_width*2, reset_less=True)
        self.nl = Signal(bit_width*2, reset_less=True)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        bit_wid = self.bit_width
        ext = Repl(0, bit_wid)  # extend output to HI part

        # determine sign of each incoming number *in this partition*
        enabled = Signal(reset_less=True)
        m.d.comb += enabled.eq(self.part & self.msb & self.signed)

        # for 8-bit values: form a * 0xFF00 by using -a * 0x100, the
        # negation operation is split into a bitwise not and a +1.
        # likewise for 16, 32, and 64-bit values.

        # width-extended 1s complement if a is signed, otherwise zero
        inverted_hi = Cat(ext, ~self.op)
        m.d.comb += self.nt.eq(Mux(enabled, inverted_hi, 0))

        # add 1 if signed, otherwise add zero
        plus_one_hi = Cat(ext, enabled, Repl(0, bit_wid-1))
        m.d.comb += self.nl.eq(plus_one_hi)

        return m
813
814
class Parts(Elaboratable):
    """Derives one flag per partition from the byte partition points.

    :attribute part_pts: input partition points (``Signal`` copies).
    :attribute parts: output, one 1-bit ``Signal`` per partition.
    """

    def __init__(self, pbwid, part_pts, n_parts):
        """Create a ``Parts``.

        :param pbwid: number of partition bytes collected.
        :param part_pts: partition points to mirror as inputs.
        :param n_parts: number of equally sized partitions.
        """
        self.pbwid = pbwid
        # inputs
        self.part_pts = PartitionPoints.like(part_pts)
        # outputs
        self.parts = [Signal(name=f"part_{i}", reset_less=True)
                      for i in range(n_parts)]

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()

        part_pts, parts = self.part_pts, self.parts
        # collect part-bytes (double factor because the input is extended)
        pbs = Signal(self.pbwid, reset_less=True)
        byte_bits = []
        for i in range(self.pbwid):
            pb = Signal(name="pb%d" % i, reset_less=True)
            m.d.comb += pb.eq(part_pts.part_byte(i))
            byte_bits.append(pb)
        m.d.comb += pbs.eq(Cat(*byte_bits))

        # negated-temporary copy of partition bits
        npbs = Signal.like(pbs, reset_less=True)
        m.d.comb += npbs.eq(~pbs)
        byte_count = 8 // len(parts)
        for i, part in enumerate(parts):
            first = i * byte_count
            last = (i + 1) * byte_count - 1
            # negated boundary bit below the partition (index -1 for the
            # first partition wraps to the top bit), the plain bits inside
            # it, then the negated boundary bit on top of it.
            pbl = [npbs[first - 1]]
            pbl += [pbs[j] for j in range(first, last)]
            pbl.append(npbs[last])
            value = Signal(len(pbl), name="value_%d" % i, reset_less=True)
            m.d.comb += value.eq(Cat(*pbl))
            # the partition flag is set only when all collected bits are 0
            m.d.comb += part.eq(~(value).bool())

        return m
853
854
class Part(Elaboratable):
    """ a key class which, depending on the partitioning, will determine
        what action to take when parts of the output are signed or unsigned.

        this requires 2 pieces of data *per operand, per partition*:
        whether the MSB is HI/LO (per partition!), and whether a signed
        or unsigned operation has been *requested*.

        once that is determined, signed is basically carried out
        by splitting 2's complement into 1's complement plus one.
        1's complement is just a bit-inversion.

        the extra terms - as separate terms - are then thrown at the
        AddReduce alongside the multiplication part-results.
    """

    def __init__(self, part_pts, width, n_parts, n_levels, pbwid):
        """Create a ``Part``.

        :param part_pts: the partition points.
        :param width: bit-width of the four output terms.
        :param n_parts: number of partitions.
        :param n_levels: not used by this class.
        :param pbwid: bit-width of ``pbs``; also passed to ``Parts``.
        """
        self.pbwid = pbwid
        self.part_pts = part_pts

        # inputs
        self.a = Signal(64, reset_less=True)
        self.b = Signal(64, reset_less=True)
        self.a_signed = [Signal(name=f"a_signed_{i}", reset_less=True)
                         for i in range(8)]
        self.b_signed = [Signal(name=f"_b_signed_{i}", reset_less=True)
                         for i in range(8)]
        self.pbs = Signal(pbwid, reset_less=True)

        # outputs
        self.parts = [Signal(name=f"part_{i}", reset_less=True)
                      for i in range(n_parts)]

        self.not_a_term = Signal(width, reset_less=True)
        self.neg_lsb_a_term = Signal(width, reset_less=True)
        self.not_b_term = Signal(width, reset_less=True)
        self.neg_lsb_b_term = Signal(width, reset_less=True)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()

        # the partition flags come from a Parts submodule
        p = Parts(self.pbwid, self.part_pts, len(self.parts))
        m.submodules.p = p
        m.d.comb += p.part_pts.eq(self.part_pts)
        parts = p.parts

        byte_width = 8 // len(parts)  # byte width
        bit_wid = 8 * byte_width      # bit width
        nat, nbt, nla, nlb = [], [], [], []
        for i, part in enumerate(parts):
            lo_bit = bit_wid * i
            msb_bit = (i + 1) * bit_wid - 1
            signed_idx = i * byte_width

            # work out bit-inverted and +1 term for a.
            pa = LSBNegTerm(bit_wid)
            setattr(m.submodules, "lnt_%d_a_%d" % (bit_wid, i), pa)
            m.d.comb += pa.part.eq(part)
            m.d.comb += pa.op.eq(self.a.bit_select(lo_bit, bit_wid))
            m.d.comb += pa.signed.eq(self.b_signed[signed_idx])  # yes b
            m.d.comb += pa.msb.eq(self.b[msb_bit])  # really, b
            nat.append(pa.nt)
            nla.append(pa.nl)

            # work out bit-inverted and +1 term for b
            pb = LSBNegTerm(bit_wid)
            setattr(m.submodules, "lnt_%d_b_%d" % (bit_wid, i), pb)
            m.d.comb += pb.part.eq(part)
            m.d.comb += pb.op.eq(self.b.bit_select(lo_bit, bit_wid))
            m.d.comb += pb.signed.eq(self.a_signed[signed_idx])  # yes a
            m.d.comb += pb.msb.eq(self.a[msb_bit])  # really, a
            nbt.append(pb.nt)
            nlb.append(pb.nl)

        # concatenate together and return all 4 results.
        m.d.comb += [self.not_a_term.eq(Cat(*nat)),
                     self.not_b_term.eq(Cat(*nbt)),
                     self.neg_lsb_a_term.eq(Cat(*nla)),
                     self.neg_lsb_b_term.eq(Cat(*nlb)),
                     ]

        return m
940
941
class IntermediateOut(Elaboratable):
    """ selects the HI/LO part of the multiplication, for a given bit-width
        the output is also reconstructed in its SIMD (partition) lanes.
    """

    def __init__(self, width, out_wid, n_parts):
        """Create an ``IntermediateOut``.

        :param width: bit-width of one output lane.
        :param out_wid: bit-width of ``intermed``; ``output`` is half that.
        :param n_parts: number of lanes.
        """
        self.width = width
        self.n_parts = n_parts
        self.part_ops = [Signal(2, name="dpop%d" % i, reset_less=True)
                         for i in range(8)]
        self.intermed = Signal(out_wid, reset_less=True)
        self.output = Signal(out_wid//2, reset_less=True)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()

        w = self.width
        sel = w // 8  # stride into part_ops: one entry per byte
        lanes = []
        for i in range(self.n_parts):
            # each lane occupies 2*w bits of intermed: low half, high half
            base = i * w*2
            low_half = self.intermed.bit_select(base, w)
            high_half = self.intermed.bit_select(base + w, w)
            op = Signal(w, reset_less=True, name="op%d_%d" % (w, i))
            want_low = self.part_ops[sel * i] == OP_MUL_LOW
            m.d.comb += op.eq(Mux(want_low, low_half, high_half))
            lanes.append(op)
        m.d.comb += self.output.eq(Cat(*lanes))

        return m
970
971
class FinalOut(Elaboratable):
    """Choose the final result byte by byte according to the partitioning.

    Every output byte is picked independently from the 8, 16, 32 or 64-bit
    candidate results, so differently sized partitions can coexist in one
    operation.
    """
    def __init__(self, output_width, n_parts, part_pts):
        """
        :param output_width: width handed to the input spec; half of it
                             is used for the local candidate signals.
        :param n_parts: number of partitions, handed to the input spec.
        :param part_pts: partition points used as a template for the
                         ``Parts`` submodules and the input spec.
        """
        self.part_pts = part_pts
        self.output_width = output_width
        self.n_parts = n_parts
        self.out_wid = output_width//2

        self.i = self.ispec()
        self.o = self.ospec()

    def ispec(self):
        """Return a fresh input record (``IntermediateData``)."""
        return IntermediateData(self.part_pts, self.output_width, self.n_parts)

    def ospec(self):
        """Return a fresh output record (``OutputData``)."""
        return OutputData()

    def elaborate(self, platform):
        """Build the byte-wise output selector."""
        m = Module()

        part_pts = self.part_pts
        m.submodules.p_8 = p_8 = Parts(8, part_pts, 8)
        m.submodules.p_16 = p_16 = Parts(8, part_pts, 4)
        m.submodules.p_32 = p_32 = Parts(8, part_pts, 2)
        m.submodules.p_64 = p_64 = Parts(8, part_pts, 1)

        out_part_pts = self.i.part_pts

        # temporaries: local copies of the per-lane flags
        d8 = [Signal(name=f"d8_{i}", reset_less=True) for i in range(8)]
        d16 = [Signal(name=f"d16_{i}", reset_less=True) for i in range(4)]
        d32 = [Signal(name=f"d32_{i}", reset_less=True) for i in range(2)]

        # temporaries: the four candidate results
        i8 = Signal(self.out_wid, reset_less=True)
        i16 = Signal(self.out_wid, reset_less=True)
        i32 = Signal(self.out_wid, reset_less=True)
        i64 = Signal(self.out_wid, reset_less=True)

        # all four Parts submodules see the incoming partition points
        for parts_mod in (p_8, p_16, p_32, p_64):
            m.d.comb += parts_mod.part_pts.eq(out_part_pts)

        # p_64's flag is never copied: 64-bit is the fall-through case
        for flags, parts_mod in ((d8, p_8), (d16, p_16), (d32, p_32)):
            for i in range(len(parts_mod.parts)):
                m.d.comb += flags[i].eq(parts_mod.parts[i])

        for idx, candidate in enumerate((i8, i16, i32, i64)):
            m.d.comb += candidate.eq(self.i.outputs[idx])

        out_bytes = []
        for i in range(8):
            # select one of the candidates for this byte: d8 picks i8,
            # d16 picks i16, d32 picks i32 and otherwise i64 is used.
            # The outer Mux tests (d8 | d16); when that is set the inner
            # Mux on d8 decides between i8 and i16, when it is clear the
            # inner Mux on d32 decides between i32 and i64.
            op = Signal(8, reset_less=True, name="op_%d" % i)
            lsb = i * 8
            narrow = Mux(d8[i], i8.bit_select(lsb, 8),
                         i16.bit_select(lsb, 8))
            wide = Mux(d32[i // 4], i32.bit_select(lsb, 8),
                       i64.bit_select(lsb, 8))
            m.d.comb += op.eq(Mux(d8[i] | d16[i // 2], narrow, wide))
            out_bytes.append(op)

        # create outputs
        m.d.comb += self.o.output.eq(Cat(*out_bytes))
        m.d.comb += self.o.intermediate_output.eq(self.i.intermediate_output)

        return m
1052
1053
class OrMod(Elaboratable):
    """Bitwise-OR of four equally wide inputs, built as a two-level tree."""
    def __init__(self, wid):
        """
        :param wid: bit-width of every input and of the output.
        """
        self.wid = wid
        self.orin = [Signal(wid, name="orin%d" % i, reset_less=True)
                     for i in range(4)]
        self.orout = Signal(wid, reset_less=True)

    def elaborate(self, platform):
        """OR the inputs pairwise, then OR the two partial results."""
        m = Module()
        in0, in1, in2, in3 = self.orin
        # first level: two independent pairwise ORs
        or1 = Signal(self.wid, reset_less=True)
        or2 = Signal(self.wid, reset_less=True)
        m.d.comb += [
            or1.eq(in0 | in1),
            or2.eq(in2 | in3),
            # second level: combine the partial results
            self.orout.eq(or1 | or2),
        ]

        return m
1072
1073
class Signs(Elaboratable):
    """Derive the signedness of ``a`` and ``b`` from the operation code.

    ``a`` is treated as signed for every operation except
    ``OP_MUL_UNSIGNED_HIGH``; ``b`` is treated as signed only for
    ``OP_MUL_LOW`` and ``OP_MUL_SIGNED_HIGH``.
    """

    def __init__(self):
        self.part_ops = Signal(2, reset_less=True)
        self.a_signed = Signal(reset_less=True)
        self.b_signed = Signal(reset_less=True)

    def elaborate(self, platform):
        """Decode ``part_ops`` into ``a_signed`` and ``b_signed``."""
        m = Module()

        opcode = self.part_ops
        a_is_signed = opcode != OP_MUL_UNSIGNED_HIGH
        b_is_signed = ((opcode == OP_MUL_LOW)
                       | (opcode == OP_MUL_SIGNED_HIGH))
        m.d.comb += [self.a_signed.eq(a_is_signed),
                     self.b_signed.eq(b_is_signed)]

        return m
1095
1096
class IntermediateData:
    """Record carrying the four per-width candidate results.

    Besides the candidates it holds the partition points, the per-partition
    operation codes and the raw intermediate output.
    """

    def __init__(self, part_pts, output_width, n_parts):
        """
        :param part_pts: partition points to duplicate via ``like()``.
        :param output_width: bit-width of every signal in ``outputs`` and
                             of ``intermediate_output``.
        :param n_parts: number of 2-bit operation-code signals.
        """
        self.part_ops = [Signal(2, name=f"part_ops_{i}", reset_less=True)
                         for i in range(n_parts)]
        self.part_pts = part_pts.like()
        self.outputs = [Signal(output_width, name="io%d" % i, reset_less=True)
                        for i in range(4)]
        # intermediates (needed for unit tests)
        self.intermediate_output = Signal(output_width)

    def eq_from(self, part_pts, outputs, intermediate_output,
                part_ops):
        """Return the assignments that load this record from loose values.

        The list is ordered: partition points, intermediate output, the
        four outputs, then one entry per operation code held here.
        """
        stmts = [self.part_pts.eq(part_pts),
                 self.intermediate_output.eq(intermediate_output)]
        stmts.extend(self.outputs[i].eq(outputs[i]) for i in range(4))
        stmts.extend(own_op.eq(part_ops[i])
                     for i, own_op in enumerate(self.part_ops))
        return stmts

    def eq(self, rhs):
        """Return the assignments that copy ``rhs`` into this record."""
        return self.eq_from(part_pts=rhs.part_pts,
                            outputs=rhs.outputs,
                            intermediate_output=rhs.intermediate_output,
                            part_ops=rhs.part_ops)
1120
1121
class InputData:
    """Record holding the multiplier inputs.

    Contains the two 64-bit operands, a partition point at every multiple
    of 8 from 8 to 56, and eight 2-bit operation codes.
    """

    def __init__(self):
        self.a = Signal(64)
        self.b = Signal(64)
        self.part_pts = PartitionPoints()
        for bit_pos in range(8, 64, 8):
            self.part_pts[bit_pos] = Signal(name=f"part_pts_{bit_pos}")
        self.part_ops = [Signal(2, name=f"part_ops_{i}") for i in range(8)]

    def eq_from(self, part_pts, a, b, part_ops):
        """Return the assignments that load this record from loose values.

        The list is ordered: partition points, ``a``, ``b``, then one
        entry per operation code held here.
        """
        stmts = [self.part_pts.eq(part_pts), self.a.eq(a), self.b.eq(b)]
        stmts.extend(own_op.eq(part_ops[i])
                     for i, own_op in enumerate(self.part_ops))
        return stmts

    def eq(self, rhs):
        """Return the assignments that copy ``rhs`` into this record."""
        return self.eq_from(part_pts=rhs.part_pts, a=rhs.a, b=rhs.b,
                            part_ops=rhs.part_ops)
1140
1141
class OutputData:
    """Record holding the multiplier results."""

    def __init__(self):
        self.intermediate_output = Signal(128)  # needed for unit tests
        self.output = Signal(64)

    def eq(self, rhs):
        """Return the assignments that copy ``rhs`` into this record.

        The intermediate output comes first, then the output.
        """
        pairs = ((self.intermediate_output, rhs.intermediate_output),
                 (self.output, rhs.output))
        return [dst.eq(src) for dst, src in pairs]
1151
1152
class AllTerms(Elaboratable):
    """Set of terms to be added together
    """

    def __init__(self, n_inputs, output_width, n_parts, register_levels):
        """Create an ``AllTerms``.

        :param n_inputs: number of terms, passed on to the output spec.
        :param output_width: bit-width passed on to the output spec.
        :param n_parts: number of partitions, passed on to the output spec.
        :param register_levels: List of nesting levels that should have
            pipeline registers.  Only its length is used by this class.
        """
        self.register_levels = register_levels
        self.n_inputs = n_inputs
        self.n_parts = n_parts
        self.output_width = output_width

        self.i = self.ispec()
        self.o = self.ospec()

    def ispec(self):
        """Return a fresh input record (``InputData``)."""
        return InputData()

    def ospec(self):
        """Return a fresh output record (``AddReduceData``)."""
        return AddReduceData(self.i.part_pts, self.n_inputs,
                             self.output_width, self.n_parts)

    def elaborate(self, platform):
        """Generate every term and copy the terms to the output record.

        The terms are the ``ProductTerms`` outputs for each of the 8
        ``a`` indices, followed by the four OR-combined outputs of the
        ``Part`` submodules.
        """
        m = Module()

        eps = self.i.part_pts

        # collect part-bytes
        pbs = Signal(8, reset_less=True)
        tl = []
        for i in range(8):
            pb = Signal(name="pb%d" % i, reset_less=True)
            m.d.comb += pb.eq(eps.part_byte(i))
            tl.append(pb)
        m.d.comb += pbs.eq(Cat(*tl))

        # one sign decoder per byte, fed with that byte's operation code
        signs = []
        for i in range(8):
            s = Signs()
            signs.append(s)
            setattr(m.submodules, "signs%d" % i, s)
            m.d.comb += s.part_ops.eq(self.i.part_ops[i])

        n_levels = len(self.register_levels)+1
        m.submodules.part_8 = part_8 = Part(eps, 128, 8, n_levels, 8)
        m.submodules.part_16 = part_16 = Part(eps, 128, 4, n_levels, 8)
        m.submodules.part_32 = part_32 = Part(eps, 128, 2, n_levels, 8)
        m.submodules.part_64 = part_64 = Part(eps, 128, 1, n_levels, 8)
        nat_l, nbt_l, nla_l, nlb_l = [], [], [], []
        for mod in [part_8, part_16, part_32, part_64]:
            m.d.comb += mod.a.eq(self.i.a)
            m.d.comb += mod.b.eq(self.i.b)
            for i, sign in enumerate(signs):
                m.d.comb += mod.a_signed[i].eq(sign.a_signed)
                m.d.comb += mod.b_signed[i].eq(sign.b_signed)
            m.d.comb += mod.pbs.eq(pbs)
            nat_l.append(mod.not_a_term)
            nbt_l.append(mod.not_b_term)
            nla_l.append(mod.neg_lsb_a_term)
            nlb_l.append(mod.neg_lsb_b_term)

        terms = []

        for a_index in range(8):
            t = ProductTerms(8, 128, 8, a_index, 8)
            setattr(m.submodules, "terms_%d" % a_index, t)

            m.d.comb += t.a.eq(self.i.a)
            m.d.comb += t.b.eq(self.i.b)
            m.d.comb += t.pb_en.eq(pbs)

            terms.extend(t.terms)

        # it's fine to bitwise-or data together since they are never enabled
        # at the same time
        m.submodules.nat_or = nat_or = OrMod(128)
        m.submodules.nbt_or = nbt_or = OrMod(128)
        m.submodules.nla_or = nla_or = OrMod(128)
        m.submodules.nlb_or = nlb_or = OrMod(128)
        for term_list, mod in [(nat_l, nat_or),
                               (nbt_l, nbt_or),
                               (nla_l, nla_or),
                               (nlb_l, nlb_or)]:
            for i, value in enumerate(term_list):
                m.d.comb += mod.orin[i].eq(value)
            terms.append(mod.orout)

        # copy the intermediate terms to the output
        for i, value in enumerate(terms):
            m.d.comb += self.o.terms[i].eq(value)

        # copy reg part points and part ops to output
        m.d.comb += self.o.part_pts.eq(eps)
        m.d.comb += [self.o.part_ops[i].eq(op)
                     for i, op in enumerate(self.i.part_ops)]

        return m
1258
1259
class Intermediates(Elaboratable):
    """Produce the four candidate results from the reduced sum.

    One ``IntermediateOut`` is instantiated per lane width (64, 32, 16
    and 8 bits); each drives one entry of the output record's
    ``outputs`` list.
    """

    def __init__(self, output_width, n_parts, part_pts):
        """
        :param output_width: width passed to the input and output specs.
        :param n_parts: number of partitions passed to both specs.
        :param part_pts: partition points passed to both specs.
        """
        self.part_pts = part_pts
        self.output_width = output_width
        self.n_parts = n_parts

        self.i = self.ispec()
        self.o = self.ospec()

    def ispec(self):
        """Return a fresh input record (``FinalReduceData``)."""
        return FinalReduceData(self.part_pts, self.output_width, self.n_parts)

    def ospec(self):
        """Return a fresh output record (``IntermediateData``)."""
        return IntermediateData(self.part_pts, self.output_width, self.n_parts)

    def elaborate(self, platform):
        """Instantiate the per-width selectors and forward the side data."""
        m = Module()

        out_part_ops = self.i.part_ops
        out_part_pts = self.i.part_pts

        # (lane width, lane count, index into self.o.outputs), widest
        # first: this creates submodules io64, io32, io16 and io8
        layout = ((64, 1, 3), (32, 2, 2), (16, 4, 1), (8, 8, 0))
        for lane_w, n_lanes, slot in layout:
            sel = IntermediateOut(lane_w, 128, n_lanes)
            setattr(m.submodules, "io%d" % lane_w, sel)
            m.d.comb += sel.intermed.eq(self.i.output)
            for i in range(8):
                m.d.comb += sel.part_ops[i].eq(out_part_ops[i])
            m.d.comb += self.o.outputs[slot].eq(sel.output)

        # pass operation codes, partition points and the raw sum through
        for i in range(8):
            m.d.comb += self.o.part_ops[i].eq(out_part_ops[i])
        m.d.comb += self.o.part_pts.eq(out_part_pts)
        m.d.comb += self.o.intermediate_output.eq(self.i.output)

        return m
1318
1319
class Mul8_16_32_64(Elaboratable):
    """Partitioned signed/unsigned integer multiplier for 8/16/32/64 bits.

    Any mix of 8, 16, 32 and 64-bit partitions on naturally-aligned
    boundaries is supported, and the operation can be chosen separately
    for every partition.

    :attribute part_pts: the input partition points. There is a partition
        point at every multiple of 8 in 0 < i < 64, and the ``Value``
        associated with each point is a ``Signal``. Modification is not
        supported, except by ``Signal.eq``.
    :attribute part_ops: the operation for each byte. To select the
        operation of a partition, assign the chosen operation code to
        every byte of that partition. The allowed operation codes are:

        :attribute OP_MUL_LOW: the LSB half of the product. Equivalent to
            RISC-V's `mul` instruction.
        :attribute OP_MUL_SIGNED_HIGH: the MSB half of the product where
            both ``a`` and ``b`` are signed. Equivalent to RISC-V's `mulh`
            instruction.
        :attribute OP_MUL_SIGNED_UNSIGNED_HIGH: the MSB half of the product
            where ``a`` is signed and ``b`` is unsigned. Equivalent to
            RISC-V's `mulhsu` instruction.
        :attribute OP_MUL_UNSIGNED_HIGH: the MSB half of the product where
            both ``a`` and ``b`` are unsigned. Equivalent to RISC-V's
            `mulhu` instruction.
    """

    def __init__(self, register_levels=()):
        """ register_levels: the points in the cascade at which flip-flops
            are to be inserted.
        """

        # parameter(s)
        self.register_levels = list(register_levels)

        self.i = self.ispec()
        self.o = self.ospec()

        # inputs: shortcuts to the fields of the input record
        self.part_pts = self.i.part_pts
        self.part_ops = self.i.part_ops
        self.a = self.i.a
        self.b = self.i.b

        # outputs: shortcuts to the fields of the output record
        self.intermediate_output = self.o.intermediate_output
        self.output = self.o.output

    def ispec(self):
        """Return a fresh input record (``InputData``)."""
        return InputData()

    def ospec(self):
        """Return a fresh output record (``OutputData``)."""
        return OutputData()

    def elaborate(self, platform):
        """Chain the stages: terms, add-reduce, intermediates, final out."""
        m = Module()

        part_pts = self.part_pts

        # stage 1: generate all the terms to be summed
        n_inputs = 64 + 4
        n_parts = 8
        t = AllTerms(n_inputs, 128, n_parts, self.register_levels)
        m.submodules.allterms = t
        m.d.comb += t.i.eq(self.i)

        # stage 2: sum the terms
        add_reduce = AddReduce(t.o.terms,
                               128,
                               self.register_levels,
                               t.o.part_pts,
                               t.o.part_ops,
                               partition_step=2)
        m.submodules.add_reduce = add_reduce

        # stage 3: derive the per-width candidate results
        interm = Intermediates(128, 8, part_pts)
        m.submodules.intermediates = interm
        m.d.comb += interm.i.eq(add_reduce.o)

        # stage 4: final output
        finalout = FinalOut(128, 8, part_pts)
        m.submodules.finalout = finalout
        m.d.comb += finalout.i.eq(interm.o)
        m.d.comb += self.o.eq(finalout.o)

        return m
1407
1408
if __name__ == "__main__":
    # run the nmigen command-line entry point on a default multiplier,
    # exposing the operands, both outputs, every operation code and
    # every partition point as ports
    m = Mul8_16_32_64()
    ports = [m.a, m.b, m.intermediate_output, m.output]
    ports.extend(m.part_ops)
    ports.extend(m.part_pts.values())
    main(m, ports=ports)