use PipelineSpec object in AllTerms
[ieee754fpu.git] / src / ieee754 / part_mul_add / multiply.py
1 # SPDX-License-Identifier: LGPL-2.1-or-later
2 # See Notices.txt for copyright information
3 """Integer Multiplication."""
4
5 from nmigen import Signal, Module, Value, Elaboratable, Cat, C, Mux, Repl
6 from nmigen.hdl.ast import Assign
7 from abc import ABCMeta, abstractmethod
8 from nmigen.cli import main
9 from functools import reduce
10 from operator import or_
11 from ieee754.pipeline import PipelineSpec
12 from nmutil.pipemodbase import PipeModBase
13
14
15 class PartitionPoints(dict):
16 """Partition points and corresponding ``Value``s.
17
18 The points at where an ALU is partitioned along with ``Value``s that
19 specify if the corresponding partition points are enabled.
20
21 For example: ``{1: True, 5: True, 10: True}`` with
22 ``width == 16`` specifies that the ALU is split into 4 sections:
23 * bits 0 <= ``i`` < 1
24 * bits 1 <= ``i`` < 5
25 * bits 5 <= ``i`` < 10
26 * bits 10 <= ``i`` < 16
27
28 If the partition_points were instead ``{1: True, 5: a, 10: True}``
29 where ``a`` is a 1-bit ``Signal``:
30 * If ``a`` is asserted:
31 * bits 0 <= ``i`` < 1
32 * bits 1 <= ``i`` < 5
33 * bits 5 <= ``i`` < 10
34 * bits 10 <= ``i`` < 16
35 * Otherwise
36 * bits 0 <= ``i`` < 1
37 * bits 1 <= ``i`` < 10
38 * bits 10 <= ``i`` < 16
39 """
40
41 def __init__(self, partition_points=None):
42 """Create a new ``PartitionPoints``.
43
44 :param partition_points: the input partition points to values mapping.
45 """
46 super().__init__()
47 if partition_points is not None:
48 for point, enabled in partition_points.items():
49 if not isinstance(point, int):
50 raise TypeError("point must be a non-negative integer")
51 if point < 0:
52 raise ValueError("point must be a non-negative integer")
53 self[point] = Value.wrap(enabled)
54
55 def like(self, name=None, src_loc_at=0, mul=1):
56 """Create a new ``PartitionPoints`` with ``Signal``s for all values.
57
58 :param name: the base name for the new ``Signal``s.
59 :param mul: a multiplication factor on the indices
60 """
61 if name is None:
62 name = Signal(src_loc_at=1+src_loc_at).name # get variable name
63 retval = PartitionPoints()
64 for point, enabled in self.items():
65 point *= mul
66 retval[point] = Signal(enabled.shape(), name=f"{name}_{point}")
67 return retval
68
69 def eq(self, rhs):
70 """Assign ``PartitionPoints`` using ``Signal.eq``."""
71 if set(self.keys()) != set(rhs.keys()):
72 raise ValueError("incompatible point set")
73 for point, enabled in self.items():
74 yield enabled.eq(rhs[point])
75
76 def as_mask(self, width, mul=1):
77 """Create a bit-mask from `self`.
78
79 Each bit in the returned mask is clear only if the partition point at
80 the same bit-index is enabled.
81
82 :param width: the bit width of the resulting mask
83 :param mul: a "multiplier" which in-place expands the partition points
84 typically set to "2" when used for multipliers
85 """
86 bits = []
87 for i in range(width):
88 i /= mul
89 if i.is_integer() and int(i) in self:
90 bits.append(~self[i])
91 else:
92 bits.append(True)
93 return Cat(*bits)
94
95 def get_max_partition_count(self, width):
96 """Get the maximum number of partitions.
97
98 Gets the number of partitions when all partition points are enabled.
99 """
100 retval = 1
101 for point in self.keys():
102 if point < width:
103 retval += 1
104 return retval
105
106 def fits_in_width(self, width):
107 """Check if all partition points are smaller than `width`."""
108 for point in self.keys():
109 if point >= width:
110 return False
111 return True
112
113 def part_byte(self, index, mfactor=1): # mfactor used for "expanding"
114 if index == -1 or index == 7:
115 return C(True, 1)
116 assert index >= 0 and index < 8
117 return self[(index * 8 + 8)*mfactor]
118
119
120 class FullAdder(Elaboratable):
121 """Full Adder.
122
123 :attribute in0: the first input
124 :attribute in1: the second input
125 :attribute in2: the third input
126 :attribute sum: the sum output
127 :attribute carry: the carry output
128
129 Rather than do individual full adders (and have an array of them,
130 which would be very slow to simulate), this module can specify the
131 bit width of the inputs and outputs: in effect it performs multiple
132 Full 3-2 Add operations "in parallel".
133 """
134
135 def __init__(self, width):
136 """Create a ``FullAdder``.
137
138 :param width: the bit width of the input and output
139 """
140 self.in0 = Signal(width, reset_less=True)
141 self.in1 = Signal(width, reset_less=True)
142 self.in2 = Signal(width, reset_less=True)
143 self.sum = Signal(width, reset_less=True)
144 self.carry = Signal(width, reset_less=True)
145
146 def elaborate(self, platform):
147 """Elaborate this module."""
148 m = Module()
149 m.d.comb += self.sum.eq(self.in0 ^ self.in1 ^ self.in2)
150 m.d.comb += self.carry.eq((self.in0 & self.in1)
151 | (self.in1 & self.in2)
152 | (self.in2 & self.in0))
153 return m
154
155
156 class MaskedFullAdder(Elaboratable):
157 """Masked Full Adder.
158
159 :attribute mask: the carry partition mask
160 :attribute in0: the first input
161 :attribute in1: the second input
162 :attribute in2: the third input
163 :attribute sum: the sum output
164 :attribute mcarry: the masked carry output
165
166 FullAdders are always used with a "mask" on the output. To keep
167 the graphviz "clean", this class performs the masking here rather
168 than inside a large for-loop.
169
170 See the following discussion as to why this is no longer derived
171 from FullAdder. Each carry is shifted here *before* being ANDed
172 with the mask, so that an AOI cell may be used (which is more
173 gate-efficient)
174 https://en.wikipedia.org/wiki/AND-OR-Invert
175 https://groups.google.com/d/msg/comp.arch/fcq-GLQqvas/vTxmcA0QAgAJ
176 """
177
178 def __init__(self, width):
179 """Create a ``MaskedFullAdder``.
180
181 :param width: the bit width of the input and output
182 """
183 self.width = width
184 self.mask = Signal(width, reset_less=True)
185 self.mcarry = Signal(width, reset_less=True)
186 self.in0 = Signal(width, reset_less=True)
187 self.in1 = Signal(width, reset_less=True)
188 self.in2 = Signal(width, reset_less=True)
189 self.sum = Signal(width, reset_less=True)
190
191 def elaborate(self, platform):
192 """Elaborate this module."""
193 m = Module()
194 s1 = Signal(self.width, reset_less=True)
195 s2 = Signal(self.width, reset_less=True)
196 s3 = Signal(self.width, reset_less=True)
197 c1 = Signal(self.width, reset_less=True)
198 c2 = Signal(self.width, reset_less=True)
199 c3 = Signal(self.width, reset_less=True)
200 m.d.comb += self.sum.eq(self.in0 ^ self.in1 ^ self.in2)
201 m.d.comb += s1.eq(Cat(0, self.in0))
202 m.d.comb += s2.eq(Cat(0, self.in1))
203 m.d.comb += s3.eq(Cat(0, self.in2))
204 m.d.comb += c1.eq(s1 & s2 & self.mask)
205 m.d.comb += c2.eq(s2 & s3 & self.mask)
206 m.d.comb += c3.eq(s3 & s1 & self.mask)
207 m.d.comb += self.mcarry.eq(c1 | c2 | c3)
208 return m
209
210
211 class PartitionedAdder(Elaboratable):
212 """Partitioned Adder.
213
214 Performs the final add. The partition points are included in the
215 actual add (in one of the operands only), which causes a carry over
216 to the next bit. Then the final output *removes* the extra bits from
217 the result.
218
219 partition: .... P... P... P... P... (32 bits)
220 a : .... .... .... .... .... (32 bits)
221 b : .... .... .... .... .... (32 bits)
222 exp-a : ....P....P....P....P.... (32+4 bits, P=1 if no partition)
223 exp-b : ....0....0....0....0.... (32 bits plus 4 zeros)
224 exp-o : ....xN...xN...xN...xN... (32+4 bits - x to be discarded)
225 o : .... N... N... N... N... (32 bits - x ignored, N is carry-over)
226
227 :attribute width: the bit width of the input and output. Read-only.
228 :attribute a: the first input to the adder
229 :attribute b: the second input to the adder
230 :attribute output: the sum output
231 :attribute partition_points: the input partition points. Modification not
232 supported, except for by ``Signal.eq``.
233 """
234
235 def __init__(self, width, partition_points, partition_step=1):
236 """Create a ``PartitionedAdder``.
237
238 :param width: the bit width of the input and output
239 :param partition_points: the input partition points
240 :param partition_step: a multiplier (typically double) step
241 which in-place "expands" the partition points
242 """
243 self.width = width
244 self.pmul = partition_step
245 self.a = Signal(width, reset_less=True)
246 self.b = Signal(width, reset_less=True)
247 self.output = Signal(width, reset_less=True)
248 self.partition_points = PartitionPoints(partition_points)
249 if not self.partition_points.fits_in_width(width):
250 raise ValueError("partition_points doesn't fit in width")
251 expanded_width = 0
252 for i in range(self.width):
253 if i in self.partition_points:
254 expanded_width += 1
255 expanded_width += 1
256 self._expanded_width = expanded_width
257
258 def elaborate(self, platform):
259 """Elaborate this module."""
260 m = Module()
261 expanded_a = Signal(self._expanded_width, reset_less=True)
262 expanded_b = Signal(self._expanded_width, reset_less=True)
263 expanded_o = Signal(self._expanded_width, reset_less=True)
264
265 expanded_index = 0
266 # store bits in a list, use Cat later. graphviz is much cleaner
267 al, bl, ol, ea, eb, eo = [],[],[],[],[],[]
268
269 # partition points are "breaks" (extra zeros or 1s) in what would
270 # otherwise be a massive long add. when the "break" points are 0,
271 # whatever is in it (in the output) is discarded. however when
272 # there is a "1", it causes a roll-over carry to the *next* bit.
273 # we still ignore the "break" bit in the [intermediate] output,
274 # however by that time we've got the effect that we wanted: the
275 # carry has been carried *over* the break point.
276
277 for i in range(self.width):
278 pi = i/self.pmul # double the range of the partition point test
279 if pi.is_integer() and pi in self.partition_points:
280 # add extra bit set to 0 + 0 for enabled partition points
281 # and 1 + 0 for disabled partition points
282 ea.append(expanded_a[expanded_index])
283 al.append(~self.partition_points[pi]) # add extra bit in a
284 eb.append(expanded_b[expanded_index])
285 bl.append(C(0)) # yes, add a zero
286 expanded_index += 1 # skip the extra point. NOT in the output
287 ea.append(expanded_a[expanded_index])
288 eb.append(expanded_b[expanded_index])
289 eo.append(expanded_o[expanded_index])
290 al.append(self.a[i])
291 bl.append(self.b[i])
292 ol.append(self.output[i])
293 expanded_index += 1
294
295 # combine above using Cat
296 m.d.comb += Cat(*ea).eq(Cat(*al))
297 m.d.comb += Cat(*eb).eq(Cat(*bl))
298 m.d.comb += Cat(*ol).eq(Cat(*eo))
299
300 # use only one addition to take advantage of look-ahead carry and
301 # special hardware on FPGAs
302 m.d.comb += expanded_o.eq(expanded_a + expanded_b)
303 return m
304
305
306 FULL_ADDER_INPUT_COUNT = 3
307
308 class AddReduceData:
309
310 def __init__(self, part_pts, n_inputs, output_width, n_parts):
311 self.part_ops = [Signal(2, name=f"part_ops_{i}", reset_less=True)
312 for i in range(n_parts)]
313 self.terms = [Signal(output_width, name=f"inputs_{i}",
314 reset_less=True)
315 for i in range(n_inputs)]
316 self.part_pts = part_pts.like()
317
318 def eq_from(self, part_pts, inputs, part_ops):
319 return [self.part_pts.eq(part_pts)] + \
320 [self.terms[i].eq(inputs[i])
321 for i in range(len(self.terms))] + \
322 [self.part_ops[i].eq(part_ops[i])
323 for i in range(len(self.part_ops))]
324
325 def eq(self, rhs):
326 return self.eq_from(rhs.part_pts, rhs.terms, rhs.part_ops)
327
328
329 class FinalReduceData:
330
331 def __init__(self, part_pts, output_width, n_parts):
332 self.part_ops = [Signal(2, name=f"part_ops_{i}", reset_less=True)
333 for i in range(n_parts)]
334 self.output = Signal(output_width, reset_less=True)
335 self.part_pts = part_pts.like()
336
337 def eq_from(self, part_pts, output, part_ops):
338 return [self.part_pts.eq(part_pts)] + \
339 [self.output.eq(output)] + \
340 [self.part_ops[i].eq(part_ops[i])
341 for i in range(len(self.part_ops))]
342
343 def eq(self, rhs):
344 return self.eq_from(rhs.part_pts, rhs.output, rhs.part_ops)
345
346
347 class FinalAdd(Elaboratable):
348 """ Final stage of add reduce
349 """
350
351 def __init__(self, lidx, n_inputs, output_width, n_parts, partition_points,
352 partition_step=1):
353 self.lidx = lidx
354 self.partition_step = partition_step
355 self.output_width = output_width
356 self.n_inputs = n_inputs
357 self.n_parts = n_parts
358 self.partition_points = PartitionPoints(partition_points)
359 if not self.partition_points.fits_in_width(output_width):
360 raise ValueError("partition_points doesn't fit in output_width")
361
362 self.i = self.ispec()
363 self.o = self.ospec()
364
365 def ispec(self):
366 return AddReduceData(self.partition_points, self.n_inputs,
367 self.output_width, self.n_parts)
368
369 def ospec(self):
370 return FinalReduceData(self.partition_points,
371 self.output_width, self.n_parts)
372
373 def setup(self, m, i):
374 m.submodules.finaladd = self
375 m.d.comb += self.i.eq(i)
376
377 def process(self, i):
378 return self.o
379
380 def elaborate(self, platform):
381 """Elaborate this module."""
382 m = Module()
383
384 output_width = self.output_width
385 output = Signal(output_width, reset_less=True)
386 if self.n_inputs == 0:
387 # use 0 as the default output value
388 m.d.comb += output.eq(0)
389 elif self.n_inputs == 1:
390 # handle single input
391 m.d.comb += output.eq(self.i.terms[0])
392 else:
393 # base case for adding 2 inputs
394 assert self.n_inputs == 2
395 adder = PartitionedAdder(output_width,
396 self.i.part_pts, self.partition_step)
397 m.submodules.final_adder = adder
398 m.d.comb += adder.a.eq(self.i.terms[0])
399 m.d.comb += adder.b.eq(self.i.terms[1])
400 m.d.comb += output.eq(adder.output)
401
402 # create output
403 m.d.comb += self.o.eq_from(self.i.part_pts, output,
404 self.i.part_ops)
405
406 return m
407
408
409 class AddReduceSingle(Elaboratable):
410 """Add list of numbers together.
411
412 :attribute inputs: input ``Signal``s to be summed. Modification not
413 supported, except for by ``Signal.eq``.
414 :attribute register_levels: List of nesting levels that should have
415 pipeline registers.
416 :attribute output: output sum.
417 :attribute partition_points: the input partition points. Modification not
418 supported, except for by ``Signal.eq``.
419 """
420
421 def __init__(self, lidx, n_inputs, output_width, n_parts, partition_points,
422 partition_step=1):
423 """Create an ``AddReduce``.
424
425 :param inputs: input ``Signal``s to be summed.
426 :param output_width: bit-width of ``output``.
427 :param partition_points: the input partition points.
428 """
429 self.lidx = lidx
430 self.partition_step = partition_step
431 self.n_inputs = n_inputs
432 self.n_parts = n_parts
433 self.output_width = output_width
434 self.partition_points = PartitionPoints(partition_points)
435 if not self.partition_points.fits_in_width(output_width):
436 raise ValueError("partition_points doesn't fit in output_width")
437
438 self.groups = AddReduceSingle.full_adder_groups(n_inputs)
439 self.n_terms = AddReduceSingle.calc_n_inputs(n_inputs, self.groups)
440
441 self.i = self.ispec()
442 self.o = self.ospec()
443
444 def ispec(self):
445 return AddReduceData(self.partition_points, self.n_inputs,
446 self.output_width, self.n_parts)
447
448 def ospec(self):
449 return AddReduceData(self.partition_points, self.n_terms,
450 self.output_width, self.n_parts)
451
452 def setup(self, m, i):
453 setattr(m.submodules, "addreduce_%d" % self.lidx, self)
454 m.d.comb += self.i.eq(i)
455
456 def process(self, i):
457 return self.o
458
459 @staticmethod
460 def calc_n_inputs(n_inputs, groups):
461 retval = len(groups)*2
462 if n_inputs % FULL_ADDER_INPUT_COUNT == 1:
463 retval += 1
464 elif n_inputs % FULL_ADDER_INPUT_COUNT == 2:
465 retval += 2
466 else:
467 assert n_inputs % FULL_ADDER_INPUT_COUNT == 0
468 return retval
469
470 @staticmethod
471 def get_max_level(input_count):
472 """Get the maximum level.
473
474 All ``register_levels`` must be less than or equal to the maximum
475 level.
476 """
477 retval = 0
478 while True:
479 groups = AddReduceSingle.full_adder_groups(input_count)
480 if len(groups) == 0:
481 return retval
482 input_count %= FULL_ADDER_INPUT_COUNT
483 input_count += 2 * len(groups)
484 retval += 1
485
486 @staticmethod
487 def full_adder_groups(input_count):
488 """Get ``inputs`` indices for which a full adder should be built."""
489 return range(0,
490 input_count - FULL_ADDER_INPUT_COUNT + 1,
491 FULL_ADDER_INPUT_COUNT)
492
493 def create_next_terms(self):
494 """ create next intermediate terms, for linking up in elaborate, below
495 """
496 terms = []
497 adders = []
498
499 # create full adders for this recursive level.
500 # this shrinks N terms to 2 * (N // 3) plus the remainder
501 for i in self.groups:
502 adder_i = MaskedFullAdder(self.output_width)
503 adders.append((i, adder_i))
504 # add both the sum and the masked-carry to the next level.
505 # 3 inputs have now been reduced to 2...
506 terms.append(adder_i.sum)
507 terms.append(adder_i.mcarry)
508 # handle the remaining inputs.
509 if self.n_inputs % FULL_ADDER_INPUT_COUNT == 1:
510 terms.append(self.i.terms[-1])
511 elif self.n_inputs % FULL_ADDER_INPUT_COUNT == 2:
512 # Just pass the terms to the next layer, since we wouldn't gain
513 # anything by using a half adder since there would still be 2 terms
514 # and just passing the terms to the next layer saves gates.
515 terms.append(self.i.terms[-2])
516 terms.append(self.i.terms[-1])
517 else:
518 assert self.n_inputs % FULL_ADDER_INPUT_COUNT == 0
519
520 return terms, adders
521
522 def elaborate(self, platform):
523 """Elaborate this module."""
524 m = Module()
525
526 terms, adders = self.create_next_terms()
527
528 # copy the intermediate terms to the output
529 for i, value in enumerate(terms):
530 m.d.comb += self.o.terms[i].eq(value)
531
532 # copy reg part points and part ops to output
533 m.d.comb += self.o.part_pts.eq(self.i.part_pts)
534 m.d.comb += [self.o.part_ops[i].eq(self.i.part_ops[i])
535 for i in range(len(self.i.part_ops))]
536
537 # set up the partition mask (for the adders)
538 part_mask = Signal(self.output_width, reset_less=True)
539
540 # get partition points as a mask
541 mask = self.i.part_pts.as_mask(self.output_width,
542 mul=self.partition_step)
543 m.d.comb += part_mask.eq(mask)
544
545 # add and link the intermediate term modules
546 for i, (iidx, adder_i) in enumerate(adders):
547 setattr(m.submodules, f"adder_{i}", adder_i)
548
549 m.d.comb += adder_i.in0.eq(self.i.terms[iidx])
550 m.d.comb += adder_i.in1.eq(self.i.terms[iidx + 1])
551 m.d.comb += adder_i.in2.eq(self.i.terms[iidx + 2])
552 m.d.comb += adder_i.mask.eq(part_mask)
553
554 return m
555
556
557 class AddReduceInternal:
558 """Recursively Add list of numbers together.
559
560 :attribute inputs: input ``Signal``s to be summed. Modification not
561 supported, except for by ``Signal.eq``.
562 :attribute register_levels: List of nesting levels that should have
563 pipeline registers.
564 :attribute output: output sum.
565 :attribute partition_points: the input partition points. Modification not
566 supported, except for by ``Signal.eq``.
567 """
568
569 def __init__(self, i, output_width, partition_step=1):
570 """Create an ``AddReduce``.
571
572 :param inputs: input ``Signal``s to be summed.
573 :param output_width: bit-width of ``output``.
574 :param partition_points: the input partition points.
575 """
576 self.i = i
577 self.inputs = i.terms
578 self.part_ops = i.part_ops
579 self.output_width = output_width
580 self.partition_points = i.part_pts
581 self.partition_step = partition_step
582
583 self.create_levels()
584
585 def create_levels(self):
586 """creates reduction levels"""
587
588 mods = []
589 partition_points = self.partition_points
590 part_ops = self.part_ops
591 n_parts = len(part_ops)
592 inputs = self.inputs
593 ilen = len(inputs)
594 while True:
595 groups = AddReduceSingle.full_adder_groups(len(inputs))
596 if len(groups) == 0:
597 break
598 lidx = len(mods)
599 next_level = AddReduceSingle(lidx, ilen, self.output_width, n_parts,
600 partition_points,
601 self.partition_step)
602 mods.append(next_level)
603 partition_points = next_level.i.part_pts
604 inputs = next_level.o.terms
605 ilen = len(inputs)
606 part_ops = next_level.i.part_ops
607
608 lidx = len(mods)
609 next_level = FinalAdd(lidx, ilen, self.output_width, n_parts,
610 partition_points, self.partition_step)
611 mods.append(next_level)
612
613 self.levels = mods
614
615
616 class AddReduce(AddReduceInternal, Elaboratable):
617 """Recursively Add list of numbers together.
618
619 :attribute inputs: input ``Signal``s to be summed. Modification not
620 supported, except for by ``Signal.eq``.
621 :attribute register_levels: List of nesting levels that should have
622 pipeline registers.
623 :attribute output: output sum.
624 :attribute partition_points: the input partition points. Modification not
625 supported, except for by ``Signal.eq``.
626 """
627
628 def __init__(self, inputs, output_width, register_levels, part_pts,
629 part_ops, partition_step=1):
630 """Create an ``AddReduce``.
631
632 :param inputs: input ``Signal``s to be summed.
633 :param output_width: bit-width of ``output``.
634 :param register_levels: List of nesting levels that should have
635 pipeline registers.
636 :param partition_points: the input partition points.
637 """
638 self._inputs = inputs
639 self._part_pts = part_pts
640 self._part_ops = part_ops
641 n_parts = len(part_ops)
642 self.i = AddReduceData(part_pts, len(inputs),
643 output_width, n_parts)
644 AddReduceInternal.__init__(self, self.i, output_width, partition_step)
645 self.o = FinalReduceData(part_pts, output_width, n_parts)
646 self.register_levels = register_levels
647
648 @staticmethod
649 def get_max_level(input_count):
650 return AddReduceSingle.get_max_level(input_count)
651
652 @staticmethod
653 def next_register_levels(register_levels):
654 """``Iterable`` of ``register_levels`` for next recursive level."""
655 for level in register_levels:
656 if level > 0:
657 yield level - 1
658
659 def elaborate(self, platform):
660 """Elaborate this module."""
661 m = Module()
662
663 m.d.comb += self.i.eq_from(self._part_pts, self._inputs, self._part_ops)
664
665 for i, next_level in enumerate(self.levels):
666 setattr(m.submodules, "next_level%d" % i, next_level)
667
668 i = self.i
669 for idx in range(len(self.levels)):
670 mcur = self.levels[idx]
671 if idx in self.register_levels:
672 m.d.sync += mcur.i.eq(i)
673 else:
674 m.d.comb += mcur.i.eq(i)
675 i = mcur.o # for next loop
676
677 # output comes from last module
678 m.d.comb += self.o.eq(i)
679
680 return m
681
682
683 OP_MUL_LOW = 0
684 OP_MUL_SIGNED_HIGH = 1
685 OP_MUL_SIGNED_UNSIGNED_HIGH = 2 # a is signed, b is unsigned
686 OP_MUL_UNSIGNED_HIGH = 3
687
688
689 def get_term(value, shift=0, enabled=None):
690 if enabled is not None:
691 value = Mux(enabled, value, 0)
692 if shift > 0:
693 value = Cat(Repl(C(0, 1), shift), value)
694 else:
695 assert shift == 0
696 return value
697
698
699 class ProductTerm(Elaboratable):
700 """ this class creates a single product term (a[..]*b[..]).
701 it has a design flaw in that is the *output* that is selected,
702 where the multiplication(s) are combinatorially generated
703 all the time.
704 """
705
706 def __init__(self, width, twidth, pbwid, a_index, b_index):
707 self.a_index = a_index
708 self.b_index = b_index
709 shift = 8 * (self.a_index + self.b_index)
710 self.pwidth = width
711 self.twidth = twidth
712 self.width = width*2
713 self.shift = shift
714
715 self.ti = Signal(self.width, reset_less=True)
716 self.term = Signal(twidth, reset_less=True)
717 self.a = Signal(twidth//2, reset_less=True)
718 self.b = Signal(twidth//2, reset_less=True)
719 self.pb_en = Signal(pbwid, reset_less=True)
720
721 self.tl = tl = []
722 min_index = min(self.a_index, self.b_index)
723 max_index = max(self.a_index, self.b_index)
724 for i in range(min_index, max_index):
725 tl.append(self.pb_en[i])
726 name = "te_%d_%d" % (self.a_index, self.b_index)
727 if len(tl) > 0:
728 term_enabled = Signal(name=name, reset_less=True)
729 else:
730 term_enabled = None
731 self.enabled = term_enabled
732 self.term.name = "term_%d_%d" % (a_index, b_index) # rename
733
734 def elaborate(self, platform):
735
736 m = Module()
737 if self.enabled is not None:
738 m.d.comb += self.enabled.eq(~(Cat(*self.tl).bool()))
739
740 bsa = Signal(self.width, reset_less=True)
741 bsb = Signal(self.width, reset_less=True)
742 a_index, b_index = self.a_index, self.b_index
743 pwidth = self.pwidth
744 m.d.comb += bsa.eq(self.a.bit_select(a_index * pwidth, pwidth))
745 m.d.comb += bsb.eq(self.b.bit_select(b_index * pwidth, pwidth))
746 m.d.comb += self.ti.eq(bsa * bsb)
747 m.d.comb += self.term.eq(get_term(self.ti, self.shift, self.enabled))
748 """
749 #TODO: sort out width issues, get inputs a/b switched on/off.
750 #data going into Muxes is 1/2 the required width
751
752 pwidth = self.pwidth
753 width = self.width
754 bsa = Signal(self.twidth//2, reset_less=True)
755 bsb = Signal(self.twidth//2, reset_less=True)
756 asel = Signal(width, reset_less=True)
757 bsel = Signal(width, reset_less=True)
758 a_index, b_index = self.a_index, self.b_index
759 m.d.comb += asel.eq(self.a.bit_select(a_index * pwidth, pwidth))
760 m.d.comb += bsel.eq(self.b.bit_select(b_index * pwidth, pwidth))
761 m.d.comb += bsa.eq(get_term(asel, self.shift, self.enabled))
762 m.d.comb += bsb.eq(get_term(bsel, self.shift, self.enabled))
763 m.d.comb += self.ti.eq(bsa * bsb)
764 m.d.comb += self.term.eq(self.ti)
765 """
766
767 return m
768
769
770 class ProductTerms(Elaboratable):
771 """ creates a bank of product terms. also performs the actual bit-selection
772 this class is to be wrapped with a for-loop on the "a" operand.
773 it creates a second-level for-loop on the "b" operand.
774 """
775 def __init__(self, width, twidth, pbwid, a_index, blen):
776 self.a_index = a_index
777 self.blen = blen
778 self.pwidth = width
779 self.twidth = twidth
780 self.pbwid = pbwid
781 self.a = Signal(twidth//2, reset_less=True)
782 self.b = Signal(twidth//2, reset_less=True)
783 self.pb_en = Signal(pbwid, reset_less=True)
784 self.terms = [Signal(twidth, name="term%d"%i, reset_less=True) \
785 for i in range(blen)]
786
787 def elaborate(self, platform):
788
789 m = Module()
790
791 for b_index in range(self.blen):
792 t = ProductTerm(self.pwidth, self.twidth, self.pbwid,
793 self.a_index, b_index)
794 setattr(m.submodules, "term_%d" % b_index, t)
795
796 m.d.comb += t.a.eq(self.a)
797 m.d.comb += t.b.eq(self.b)
798 m.d.comb += t.pb_en.eq(self.pb_en)
799
800 m.d.comb += self.terms[b_index].eq(t.term)
801
802 return m
803
804
805 class LSBNegTerm(Elaboratable):
806
807 def __init__(self, bit_width):
808 self.bit_width = bit_width
809 self.part = Signal(reset_less=True)
810 self.signed = Signal(reset_less=True)
811 self.op = Signal(bit_width, reset_less=True)
812 self.msb = Signal(reset_less=True)
813 self.nt = Signal(bit_width*2, reset_less=True)
814 self.nl = Signal(bit_width*2, reset_less=True)
815
816 def elaborate(self, platform):
817 m = Module()
818 comb = m.d.comb
819 bit_wid = self.bit_width
820 ext = Repl(0, bit_wid) # extend output to HI part
821
822 # determine sign of each incoming number *in this partition*
823 enabled = Signal(reset_less=True)
824 m.d.comb += enabled.eq(self.part & self.msb & self.signed)
825
826 # for 8-bit values: form a * 0xFF00 by using -a * 0x100, the
827 # negation operation is split into a bitwise not and a +1.
828 # likewise for 16, 32, and 64-bit values.
829
830 # width-extended 1s complement if a is signed, otherwise zero
831 comb += self.nt.eq(Mux(enabled, Cat(ext, ~self.op), 0))
832
833 # add 1 if signed, otherwise add zero
834 comb += self.nl.eq(Cat(ext, enabled, Repl(0, bit_wid-1)))
835
836 return m
837
838
839 class Parts(Elaboratable):
840
841 def __init__(self, pbwid, part_pts, n_parts):
842 self.pbwid = pbwid
843 # inputs
844 self.part_pts = PartitionPoints.like(part_pts)
845 # outputs
846 self.parts = [Signal(name=f"part_{i}", reset_less=True)
847 for i in range(n_parts)]
848
849 def elaborate(self, platform):
850 m = Module()
851
852 part_pts, parts = self.part_pts, self.parts
853 # collect part-bytes (double factor because the input is extended)
854 pbs = Signal(self.pbwid, reset_less=True)
855 tl = []
856 for i in range(self.pbwid):
857 pb = Signal(name="pb%d" % i, reset_less=True)
858 m.d.comb += pb.eq(part_pts.part_byte(i))
859 tl.append(pb)
860 m.d.comb += pbs.eq(Cat(*tl))
861
862 # negated-temporary copy of partition bits
863 npbs = Signal.like(pbs, reset_less=True)
864 m.d.comb += npbs.eq(~pbs)
865 byte_count = 8 // len(parts)
866 for i in range(len(parts)):
867 pbl = []
868 pbl.append(npbs[i * byte_count - 1])
869 for j in range(i * byte_count, (i + 1) * byte_count - 1):
870 pbl.append(pbs[j])
871 pbl.append(npbs[(i + 1) * byte_count - 1])
872 value = Signal(len(pbl), name="value_%d" % i, reset_less=True)
873 m.d.comb += value.eq(Cat(*pbl))
874 m.d.comb += parts[i].eq(~(value).bool())
875
876 return m
877
878
879 class Part(Elaboratable):
880 """ a key class which, depending on the partitioning, will determine
881 what action to take when parts of the output are signed or unsigned.
882
883 this requires 2 pieces of data *per operand, per partition*:
884 whether the MSB is HI/LO (per partition!), and whether a signed
885 or unsigned operation has been *requested*.
886
887 once that is determined, signed is basically carried out
888 by splitting 2's complement into 1's complement plus one.
889 1's complement is just a bit-inversion.
890
891 the extra terms - as separate terms - are then thrown at the
892 AddReduce alongside the multiplication part-results.
893 """
894 def __init__(self, part_pts, width, n_parts, pbwid):
895
896 self.pbwid = pbwid
897 self.part_pts = part_pts
898
899 # inputs
900 self.a = Signal(64, reset_less=True)
901 self.b = Signal(64, reset_less=True)
902 self.a_signed = [Signal(name=f"a_signed_{i}", reset_less=True)
903 for i in range(8)]
904 self.b_signed = [Signal(name=f"_b_signed_{i}", reset_less=True)
905 for i in range(8)]
906 self.pbs = Signal(pbwid, reset_less=True)
907
908 # outputs
909 self.parts = [Signal(name=f"part_{i}", reset_less=True)
910 for i in range(n_parts)]
911
912 self.not_a_term = Signal(width, reset_less=True)
913 self.neg_lsb_a_term = Signal(width, reset_less=True)
914 self.not_b_term = Signal(width, reset_less=True)
915 self.neg_lsb_b_term = Signal(width, reset_less=True)
916
917 def elaborate(self, platform):
918 m = Module()
919
920 pbs, parts = self.pbs, self.parts
921 part_pts = self.part_pts
922 m.submodules.p = p = Parts(self.pbwid, part_pts, len(parts))
923 m.d.comb += p.part_pts.eq(part_pts)
924 parts = p.parts
925
926 byte_count = 8 // len(parts)
927
928 not_a_term, neg_lsb_a_term, not_b_term, neg_lsb_b_term = (
929 self.not_a_term, self.neg_lsb_a_term,
930 self.not_b_term, self.neg_lsb_b_term)
931
932 byte_width = 8 // len(parts) # byte width
933 bit_wid = 8 * byte_width # bit width
934 nat, nbt, nla, nlb = [], [], [], []
935 for i in range(len(parts)):
936 # work out bit-inverted and +1 term for a.
937 pa = LSBNegTerm(bit_wid)
938 setattr(m.submodules, "lnt_%d_a_%d" % (bit_wid, i), pa)
939 m.d.comb += pa.part.eq(parts[i])
940 m.d.comb += pa.op.eq(self.a.bit_select(bit_wid * i, bit_wid))
941 m.d.comb += pa.signed.eq(self.b_signed[i * byte_width]) # yes b
942 m.d.comb += pa.msb.eq(self.b[(i + 1) * bit_wid - 1]) # really, b
943 nat.append(pa.nt)
944 nla.append(pa.nl)
945
946 # work out bit-inverted and +1 term for b
947 pb = LSBNegTerm(bit_wid)
948 setattr(m.submodules, "lnt_%d_b_%d" % (bit_wid, i), pb)
949 m.d.comb += pb.part.eq(parts[i])
950 m.d.comb += pb.op.eq(self.b.bit_select(bit_wid * i, bit_wid))
951 m.d.comb += pb.signed.eq(self.a_signed[i * byte_width]) # yes a
952 m.d.comb += pb.msb.eq(self.a[(i + 1) * bit_wid - 1]) # really, a
953 nbt.append(pb.nt)
954 nlb.append(pb.nl)
955
956 # concatenate together and return all 4 results.
957 m.d.comb += [not_a_term.eq(Cat(*nat)),
958 not_b_term.eq(Cat(*nbt)),
959 neg_lsb_a_term.eq(Cat(*nla)),
960 neg_lsb_b_term.eq(Cat(*nlb)),
961 ]
962
963 return m
964
965
966 class IntermediateOut(Elaboratable):
967 """ selects the HI/LO part of the multiplication, for a given bit-width
968 the output is also reconstructed in its SIMD (partition) lanes.
969 """
970 def __init__(self, width, out_wid, n_parts):
971 self.width = width
972 self.n_parts = n_parts
973 self.part_ops = [Signal(2, name="dpop%d" % i, reset_less=True)
974 for i in range(8)]
975 self.intermed = Signal(out_wid, reset_less=True)
976 self.output = Signal(out_wid//2, reset_less=True)
977
978 def elaborate(self, platform):
979 m = Module()
980
981 ol = []
982 w = self.width
983 sel = w // 8
984 for i in range(self.n_parts):
985 op = Signal(w, reset_less=True, name="op%d_%d" % (w, i))
986 m.d.comb += op.eq(
987 Mux(self.part_ops[sel * i] == OP_MUL_LOW,
988 self.intermed.bit_select(i * w*2, w),
989 self.intermed.bit_select(i * w*2 + w, w)))
990 ol.append(op)
991 m.d.comb += self.output.eq(Cat(*ol))
992
993 return m
994
995
996 class FinalOut(Elaboratable):
997 """ selects the final output based on the partitioning.
998
999 each byte is selectable independently, i.e. it is possible
1000 that some partitions requested 8-bit computation whilst others
1001 requested 16 or 32 bit.
1002 """
1003 def __init__(self, output_width, n_parts, part_pts):
1004 self.part_pts = part_pts
1005 self.output_width = output_width
1006 self.n_parts = n_parts
1007 self.out_wid = output_width//2
1008
1009 self.i = self.ispec()
1010 self.o = self.ospec()
1011
1012 def ispec(self):
1013 return IntermediateData(self.part_pts, self.output_width, self.n_parts)
1014
1015 def ospec(self):
1016 return OutputData()
1017
1018 def setup(self, m, i):
1019 m.submodules.finalout = self
1020 m.d.comb += self.i.eq(i)
1021
1022 def process(self, i):
1023 return self.o
1024
1025 def elaborate(self, platform):
1026 m = Module()
1027
1028 part_pts = self.part_pts
1029 m.submodules.p_8 = p_8 = Parts(8, part_pts, 8)
1030 m.submodules.p_16 = p_16 = Parts(8, part_pts, 4)
1031 m.submodules.p_32 = p_32 = Parts(8, part_pts, 2)
1032 m.submodules.p_64 = p_64 = Parts(8, part_pts, 1)
1033
1034 out_part_pts = self.i.part_pts
1035
1036 # temporaries
1037 d8 = [Signal(name=f"d8_{i}", reset_less=True) for i in range(8)]
1038 d16 = [Signal(name=f"d16_{i}", reset_less=True) for i in range(4)]
1039 d32 = [Signal(name=f"d32_{i}", reset_less=True) for i in range(2)]
1040
1041 i8 = Signal(self.out_wid, reset_less=True)
1042 i16 = Signal(self.out_wid, reset_less=True)
1043 i32 = Signal(self.out_wid, reset_less=True)
1044 i64 = Signal(self.out_wid, reset_less=True)
1045
1046 m.d.comb += p_8.part_pts.eq(out_part_pts)
1047 m.d.comb += p_16.part_pts.eq(out_part_pts)
1048 m.d.comb += p_32.part_pts.eq(out_part_pts)
1049 m.d.comb += p_64.part_pts.eq(out_part_pts)
1050
1051 for i in range(len(p_8.parts)):
1052 m.d.comb += d8[i].eq(p_8.parts[i])
1053 for i in range(len(p_16.parts)):
1054 m.d.comb += d16[i].eq(p_16.parts[i])
1055 for i in range(len(p_32.parts)):
1056 m.d.comb += d32[i].eq(p_32.parts[i])
1057 m.d.comb += i8.eq(self.i.outputs[0])
1058 m.d.comb += i16.eq(self.i.outputs[1])
1059 m.d.comb += i32.eq(self.i.outputs[2])
1060 m.d.comb += i64.eq(self.i.outputs[3])
1061
1062 ol = []
1063 for i in range(8):
1064 # select one of the outputs: d8 selects i8, d16 selects i16
1065 # d32 selects i32, and the default is i64.
1066 # d8 and d16 are ORed together in the first Mux
1067 # then the 2nd selects either i8 or i16.
1068 # if neither d8 nor d16 are set, d32 selects either i32 or i64.
1069 op = Signal(8, reset_less=True, name="op_%d" % i)
1070 m.d.comb += op.eq(
1071 Mux(d8[i] | d16[i // 2],
1072 Mux(d8[i], i8.bit_select(i * 8, 8),
1073 i16.bit_select(i * 8, 8)),
1074 Mux(d32[i // 4], i32.bit_select(i * 8, 8),
1075 i64.bit_select(i * 8, 8))))
1076 ol.append(op)
1077
1078 # create outputs
1079 m.d.comb += self.o.output.eq(Cat(*ol))
1080 m.d.comb += self.o.intermediate_output.eq(self.i.intermediate_output)
1081
1082 return m
1083
1084
1085 class OrMod(Elaboratable):
1086 """ ORs four values together in a hierarchical tree
1087 """
1088 def __init__(self, wid):
1089 self.wid = wid
1090 self.orin = [Signal(wid, name="orin%d" % i, reset_less=True)
1091 for i in range(4)]
1092 self.orout = Signal(wid, reset_less=True)
1093
1094 def elaborate(self, platform):
1095 m = Module()
1096 or1 = Signal(self.wid, reset_less=True)
1097 or2 = Signal(self.wid, reset_less=True)
1098 m.d.comb += or1.eq(self.orin[0] | self.orin[1])
1099 m.d.comb += or2.eq(self.orin[2] | self.orin[3])
1100 m.d.comb += self.orout.eq(or1 | or2)
1101
1102 return m
1103
1104
1105 class Signs(Elaboratable):
1106 """ determines whether a or b are signed numbers
1107 based on the required operation type (OP_MUL_*)
1108 """
1109
1110 def __init__(self):
1111 self.part_ops = Signal(2, reset_less=True)
1112 self.a_signed = Signal(reset_less=True)
1113 self.b_signed = Signal(reset_less=True)
1114
1115 def elaborate(self, platform):
1116
1117 m = Module()
1118
1119 asig = self.part_ops != OP_MUL_UNSIGNED_HIGH
1120 bsig = (self.part_ops == OP_MUL_LOW) \
1121 | (self.part_ops == OP_MUL_SIGNED_HIGH)
1122 m.d.comb += self.a_signed.eq(asig)
1123 m.d.comb += self.b_signed.eq(bsig)
1124
1125 return m
1126
1127
1128 class IntermediateData:
1129
1130 def __init__(self, part_pts, output_width, n_parts):
1131 self.part_ops = [Signal(2, name=f"part_ops_{i}", reset_less=True)
1132 for i in range(n_parts)]
1133 self.part_pts = part_pts.like()
1134 self.outputs = [Signal(output_width, name="io%d" % i, reset_less=True)
1135 for i in range(4)]
1136 # intermediates (needed for unit tests)
1137 self.intermediate_output = Signal(output_width)
1138
1139 def eq_from(self, part_pts, outputs, intermediate_output,
1140 part_ops):
1141 return [self.part_pts.eq(part_pts)] + \
1142 [self.intermediate_output.eq(intermediate_output)] + \
1143 [self.outputs[i].eq(outputs[i])
1144 for i in range(4)] + \
1145 [self.part_ops[i].eq(part_ops[i])
1146 for i in range(len(self.part_ops))]
1147
1148 def eq(self, rhs):
1149 return self.eq_from(rhs.part_pts, rhs.outputs,
1150 rhs.intermediate_output, rhs.part_ops)
1151
1152
1153 class InputData:
1154
1155 def __init__(self):
1156 self.a = Signal(64)
1157 self.b = Signal(64)
1158 self.part_pts = PartitionPoints()
1159 for i in range(8, 64, 8):
1160 self.part_pts[i] = Signal(name=f"part_pts_{i}")
1161 self.part_ops = [Signal(2, name=f"part_ops_{i}") for i in range(8)]
1162
1163 def eq_from(self, part_pts, a, b, part_ops):
1164 return [self.part_pts.eq(part_pts)] + \
1165 [self.a.eq(a), self.b.eq(b)] + \
1166 [self.part_ops[i].eq(part_ops[i])
1167 for i in range(len(self.part_ops))]
1168
1169 def eq(self, rhs):
1170 return self.eq_from(rhs.part_pts, rhs.a, rhs.b, rhs.part_ops)
1171
1172
1173 class OutputData:
1174
1175 def __init__(self):
1176 self.intermediate_output = Signal(128) # needed for unit tests
1177 self.output = Signal(64)
1178
1179 def eq(self, rhs):
1180 return [self.intermediate_output.eq(rhs.intermediate_output),
1181 self.output.eq(rhs.output)]
1182
1183
1184 class AllTerms(PipeModBase):
1185 """Set of terms to be added together
1186 """
1187
1188 def __init__(self, pspec):
1189 """Create an ``AllTerms``.
1190 """
1191 self.n_inputs = pspec.n_inputs
1192 self.n_parts = pspec.n_parts
1193 self.output_width = pspec.width
1194 super().__init__(pspec, "allterms")
1195
1196 def ispec(self):
1197 return InputData()
1198
1199 def ospec(self):
1200 return AddReduceData(self.i.part_pts, self.n_inputs,
1201 self.output_width, self.n_parts)
1202
1203 def elaborate(self, platform):
1204 m = Module()
1205
1206 eps = self.i.part_pts
1207
1208 # collect part-bytes
1209 pbs = Signal(8, reset_less=True)
1210 tl = []
1211 for i in range(8):
1212 pb = Signal(name="pb%d" % i, reset_less=True)
1213 m.d.comb += pb.eq(eps.part_byte(i))
1214 tl.append(pb)
1215 m.d.comb += pbs.eq(Cat(*tl))
1216
1217 # local variables
1218 signs = []
1219 for i in range(8):
1220 s = Signs()
1221 signs.append(s)
1222 setattr(m.submodules, "signs%d" % i, s)
1223 m.d.comb += s.part_ops.eq(self.i.part_ops[i])
1224
1225 m.submodules.part_8 = part_8 = Part(eps, 128, 8, 8)
1226 m.submodules.part_16 = part_16 = Part(eps, 128, 4, 8)
1227 m.submodules.part_32 = part_32 = Part(eps, 128, 2, 8)
1228 m.submodules.part_64 = part_64 = Part(eps, 128, 1, 8)
1229 nat_l, nbt_l, nla_l, nlb_l = [], [], [], []
1230 for mod in [part_8, part_16, part_32, part_64]:
1231 m.d.comb += mod.a.eq(self.i.a)
1232 m.d.comb += mod.b.eq(self.i.b)
1233 for i in range(len(signs)):
1234 m.d.comb += mod.a_signed[i].eq(signs[i].a_signed)
1235 m.d.comb += mod.b_signed[i].eq(signs[i].b_signed)
1236 m.d.comb += mod.pbs.eq(pbs)
1237 nat_l.append(mod.not_a_term)
1238 nbt_l.append(mod.not_b_term)
1239 nla_l.append(mod.neg_lsb_a_term)
1240 nlb_l.append(mod.neg_lsb_b_term)
1241
1242 terms = []
1243
1244 for a_index in range(8):
1245 t = ProductTerms(8, 128, 8, a_index, 8)
1246 setattr(m.submodules, "terms_%d" % a_index, t)
1247
1248 m.d.comb += t.a.eq(self.i.a)
1249 m.d.comb += t.b.eq(self.i.b)
1250 m.d.comb += t.pb_en.eq(pbs)
1251
1252 for term in t.terms:
1253 terms.append(term)
1254
1255 # it's fine to bitwise-or data together since they are never enabled
1256 # at the same time
1257 m.submodules.nat_or = nat_or = OrMod(128)
1258 m.submodules.nbt_or = nbt_or = OrMod(128)
1259 m.submodules.nla_or = nla_or = OrMod(128)
1260 m.submodules.nlb_or = nlb_or = OrMod(128)
1261 for l, mod in [(nat_l, nat_or),
1262 (nbt_l, nbt_or),
1263 (nla_l, nla_or),
1264 (nlb_l, nlb_or)]:
1265 for i in range(len(l)):
1266 m.d.comb += mod.orin[i].eq(l[i])
1267 terms.append(mod.orout)
1268
1269 # copy the intermediate terms to the output
1270 for i, value in enumerate(terms):
1271 m.d.comb += self.o.terms[i].eq(value)
1272
1273 # copy reg part points and part ops to output
1274 m.d.comb += self.o.part_pts.eq(eps)
1275 m.d.comb += [self.o.part_ops[i].eq(self.i.part_ops[i])
1276 for i in range(len(self.i.part_ops))]
1277
1278 return m
1279
1280
1281 class Intermediates(Elaboratable):
1282 """ Intermediate output modules
1283 """
1284
1285 def __init__(self, output_width, n_parts, part_pts):
1286 self.part_pts = part_pts
1287 self.output_width = output_width
1288 self.n_parts = n_parts
1289
1290 self.i = self.ispec()
1291 self.o = self.ospec()
1292
1293 def ispec(self):
1294 return FinalReduceData(self.part_pts, self.output_width, self.n_parts)
1295
1296 def ospec(self):
1297 return IntermediateData(self.part_pts, self.output_width, self.n_parts)
1298
1299 def setup(self, m, i):
1300 m.submodules.intermediates = self
1301 m.d.comb += self.i.eq(i)
1302
1303 def process(self, i):
1304 return self.o
1305
1306 def elaborate(self, platform):
1307 m = Module()
1308
1309 out_part_ops = self.i.part_ops
1310 out_part_pts = self.i.part_pts
1311
1312 # create _output_64
1313 m.submodules.io64 = io64 = IntermediateOut(64, 128, 1)
1314 m.d.comb += io64.intermed.eq(self.i.output)
1315 for i in range(8):
1316 m.d.comb += io64.part_ops[i].eq(out_part_ops[i])
1317 m.d.comb += self.o.outputs[3].eq(io64.output)
1318
1319 # create _output_32
1320 m.submodules.io32 = io32 = IntermediateOut(32, 128, 2)
1321 m.d.comb += io32.intermed.eq(self.i.output)
1322 for i in range(8):
1323 m.d.comb += io32.part_ops[i].eq(out_part_ops[i])
1324 m.d.comb += self.o.outputs[2].eq(io32.output)
1325
1326 # create _output_16
1327 m.submodules.io16 = io16 = IntermediateOut(16, 128, 4)
1328 m.d.comb += io16.intermed.eq(self.i.output)
1329 for i in range(8):
1330 m.d.comb += io16.part_ops[i].eq(out_part_ops[i])
1331 m.d.comb += self.o.outputs[1].eq(io16.output)
1332
1333 # create _output_8
1334 m.submodules.io8 = io8 = IntermediateOut(8, 128, 8)
1335 m.d.comb += io8.intermed.eq(self.i.output)
1336 for i in range(8):
1337 m.d.comb += io8.part_ops[i].eq(out_part_ops[i])
1338 m.d.comb += self.o.outputs[0].eq(io8.output)
1339
1340 for i in range(8):
1341 m.d.comb += self.o.part_ops[i].eq(out_part_ops[i])
1342 m.d.comb += self.o.part_pts.eq(out_part_pts)
1343 m.d.comb += self.o.intermediate_output.eq(self.i.output)
1344
1345 return m
1346
1347
1348 class Mul8_16_32_64(Elaboratable):
1349 """Signed/Unsigned 8/16/32/64-bit partitioned integer multiplier.
1350
1351 Supports partitioning into any combination of 8, 16, 32, and 64-bit
1352 partitions on naturally-aligned boundaries. Supports the operation being
1353 set for each partition independently.
1354
1355 :attribute part_pts: the input partition points. Has a partition point at
1356 multiples of 8 in 0 < i < 64. Each partition point's associated
1357 ``Value`` is a ``Signal``. Modification not supported, except for by
1358 ``Signal.eq``.
1359 :attribute part_ops: the operation for each byte. The operation for a
1360 particular partition is selected by assigning the selected operation
1361 code to each byte in the partition. The allowed operation codes are:
1362
1363 :attribute OP_MUL_LOW: the LSB half of the product. Equivalent to
1364 RISC-V's `mul` instruction.
1365 :attribute OP_MUL_SIGNED_HIGH: the MSB half of the product where both
1366 ``a`` and ``b`` are signed. Equivalent to RISC-V's `mulh`
1367 instruction.
1368 :attribute OP_MUL_SIGNED_UNSIGNED_HIGH: the MSB half of the product
1369 where ``a`` is signed and ``b`` is unsigned. Equivalent to RISC-V's
1370 `mulhsu` instruction.
1371 :attribute OP_MUL_UNSIGNED_HIGH: the MSB half of the product where both
1372 ``a`` and ``b`` are unsigned. Equivalent to RISC-V's `mulhu`
1373 instruction.
1374 """
1375
1376 def __init__(self, register_levels=()):
1377 """ register_levels: specifies the points in the cascade at which
1378 flip-flops are to be inserted.
1379 """
1380
1381 self.id_wid = 0 # num_bits(num_rows)
1382 self.op_wid = 0
1383 self.pspec = PipelineSpec(128, self.id_wid, self.op_wid, n_ops=3)
1384 self.pspec.n_inputs = 64 + 4
1385 self.pspec.n_parts = 8
1386
1387 # parameter(s)
1388 self.register_levels = list(register_levels)
1389
1390 self.i = self.ispec()
1391 self.o = self.ospec()
1392
1393 # inputs
1394 self.part_pts = self.i.part_pts
1395 self.part_ops = self.i.part_ops
1396 self.a = self.i.a
1397 self.b = self.i.b
1398
1399 # output
1400 self.intermediate_output = self.o.intermediate_output
1401 self.output = self.o.output
1402
1403 def ispec(self):
1404 return InputData()
1405
1406 def ospec(self):
1407 return OutputData()
1408
1409 def elaborate(self, platform):
1410 m = Module()
1411
1412 part_pts = self.part_pts
1413
1414 n_parts = self.pspec.n_parts
1415 n_inputs = self.pspec.n_inputs
1416 output_width = self.pspec.width
1417 t = AllTerms(self.pspec)
1418 t.setup(m, self.i)
1419
1420 terms = t.o.terms
1421
1422 at = AddReduceInternal(t.process(self.i), 128, partition_step=2)
1423
1424 i = at.i
1425 for idx in range(len(at.levels)):
1426 mcur = at.levels[idx]
1427 mcur.setup(m, i)
1428 o = mcur.ospec()
1429 if idx in self.register_levels:
1430 m.d.sync += o.eq(mcur.process(i))
1431 else:
1432 m.d.comb += o.eq(mcur.process(i))
1433 i = o # for next loop
1434
1435 interm = Intermediates(128, 8, part_pts)
1436 interm.setup(m, i)
1437 o = interm.process(interm.i)
1438
1439 # final output
1440 finalout = FinalOut(128, 8, part_pts)
1441 finalout.setup(m, o)
1442 m.d.comb += self.o.eq(finalout.process(o))
1443
1444 return m
1445
1446
1447 if __name__ == "__main__":
1448 m = Mul8_16_32_64()
1449 main(m, ports=[m.a,
1450 m.b,
1451 m.intermediate_output,
1452 m.output,
1453 *m.part_ops,
1454 *m.part_pts.values()])