rename inputs_ to terms_
[ieee754fpu.git] / src / ieee754 / part_mul_add / multiply.py
1 # SPDX-License-Identifier: LGPL-2.1-or-later
2 # See Notices.txt for copyright information
3 """Integer Multiplication."""
4
5 from nmigen import Signal, Module, Value, Elaboratable, Cat, C, Mux, Repl
6 from nmigen.hdl.ast import Assign
7 from abc import ABCMeta, abstractmethod
8 from nmigen.cli import main
9 from functools import reduce
10 from operator import or_
11 from ieee754.pipeline import PipelineSpec
12 from nmutil.pipemodbase import PipeModBase
13
14
class PartitionPoints(dict):
    """Map partition-point bit indices to their enabling ``Value``s.

    Each key is a bit position at which an ALU may be split; the value
    stored under it says whether that split is currently active.

    For example: ``{1: True, 5: True, 10: True}`` with
    ``width == 16`` specifies that the ALU is split into 4 sections:
    * bits 0 <= ``i`` < 1
    * bits 1 <= ``i`` < 5
    * bits 5 <= ``i`` < 10
    * bits 10 <= ``i`` < 16

    If the partition_points were instead ``{1: True, 5: a, 10: True}``
    where ``a`` is a 1-bit ``Signal``:
    * If ``a`` is asserted:
        * bits 0 <= ``i`` < 1
        * bits 1 <= ``i`` < 5
        * bits 5 <= ``i`` < 10
        * bits 10 <= ``i`` < 16
    * Otherwise
        * bits 0 <= ``i`` < 1
        * bits 1 <= ``i`` < 10
        * bits 10 <= ``i`` < 16
    """

    def __init__(self, partition_points=None):
        """Build a ``PartitionPoints`` from an optional mapping.

        :param partition_points: mapping of point (non-negative ``int``)
            to an enable value; every value is passed through
            ``Value.wrap``.  ``None`` gives an empty mapping.
        :raises TypeError: if a point is not an ``int``.
        :raises ValueError: if a point is negative.
        """
        super().__init__()
        if partition_points is None:
            return
        for point, enabled in partition_points.items():
            if not isinstance(point, int):
                raise TypeError("point must be a non-negative integer")
            if point < 0:
                raise ValueError("point must be a non-negative integer")
            self[point] = Value.wrap(enabled)

    def like(self, name=None, src_loc_at=0, mul=1):
        """Return a new ``PartitionPoints`` holding fresh ``Signal``s.

        One ``Signal`` is created per entry, with the same shape as the
        existing value and stored under the key multiplied by ``mul``.

        :param name: the base name for the new ``Signal``s; when omitted
            it is taken from a throw-away ``Signal`` created with
            ``src_loc_at=1+src_loc_at``.
        :param src_loc_at: extra source-location depth for that lookup.
        :param mul: a multiplication factor on the indices
        """
        if name is None:
            # must stay in this frame: the name is derived from the caller
            name = Signal(src_loc_at=1+src_loc_at).name
        retval = PartitionPoints()
        for point, enabled in self.items():
            scaled = point * mul
            retval[scaled] = Signal(enabled.shape(), name=f"{name}_{scaled}")
        return retval

    def eq(self, rhs):
        """Yield one ``eq`` assignment per point, taking values from ``rhs``.

        This is a generator, so the key-set check (and its ``ValueError``)
        only happens once iteration starts.
        """
        if set(self.keys()) != set(rhs.keys()):
            raise ValueError("incompatible point set")
        for point, enabled in self.items():
            yield enabled.eq(rhs[point])

    def as_mask(self, width, mul=1):
        """Create a bit-mask from `self`.

        A mask bit is the inverse of the partition point's value when a
        point exists at (bit index / ``mul``); every other bit is set.

        :param width: the bit width of the resulting mask
        :param mul: a "multiplier" which in-place expands the partition points
                    typically set to "2" when used for multipliers
        """
        bits = []
        for bit_index in range(width):
            scaled = bit_index / mul  # true division: always a float
            is_point = scaled.is_integer() and int(scaled) in self
            bits.append(~self[int(scaled)] if is_point else True)
        return Cat(*bits)

    def get_max_partition_count(self, width):
        """Return the partition count with every point below ``width`` enabled.

        That is one more than the number of points smaller than ``width``.
        """
        return 1 + sum(1 for point in self if point < width)

    def fits_in_width(self, width):
        """Return True when every partition point is smaller than `width`."""
        return all(point < width for point in self)

    def part_byte(self, index, mfactor=1):
        """Return the partition value at the top of byte ``index``.

        Indices -1 and 7 yield a constant 1; other indices must lie in
        0..7 and look up key ``(index * 8 + 8) * mfactor``.

        :param mfactor: used for "expanding" the partition point keys.
        """
        if index in (-1, 7):
            return C(True, 1)
        assert 0 <= index < 8
        return self[(index * 8 + 8)*mfactor]
118
119
class FullAdder(Elaboratable):
    """Bit-parallel full adder (3-to-2 compressor).

    :attribute in0: the first input
    :attribute in1: the second input
    :attribute in2: the third input
    :attribute sum: the sum output
    :attribute carry: the carry output

    Instead of instantiating an array of single-bit full adders (which
    would be very slow to simulate), all ports share one bit width and
    the 3-2 add is applied to every bit position "in parallel".
    """

    def __init__(self, width):
        """Create a ``FullAdder``.

        :param width: the bit width of the input and output
        """
        self.in0 = Signal(width, reset_less=True)
        self.in1 = Signal(width, reset_less=True)
        self.in2 = Signal(width, reset_less=True)
        self.sum = Signal(width, reset_less=True)
        self.carry = Signal(width, reset_less=True)

    def elaborate(self, platform):
        """Build the combinatorial sum (XOR) and carry (majority) logic."""
        m = Module()
        a, b, c = self.in0, self.in1, self.in2
        m.d.comb += [
            self.sum.eq(a ^ b ^ c),
            # carry is set wherever at least two of the inputs are set
            self.carry.eq((a & b) | (b & c) | (c & a)),
        ]
        return m
154
155
class MaskedFullAdder(Elaboratable):
    """Full adder whose carry output is shifted and masked.

    :attribute mask: the carry partition mask
    :attribute in0: the first input
    :attribute in1: the second input
    :attribute in2: the third input
    :attribute sum: the sum output
    :attribute mcarry: the masked carry output

    FullAdders are always used with a "mask" on the output, so the
    masking is done inside this module (keeping the graphviz output
    "clean") instead of in a large for-loop at the point of use.

    This class is deliberately not derived from FullAdder: each input is
    shifted *before* being ANDed with the mask, so that an AOI cell may
    be used (which is more gate-efficient).  See:
    https://en.wikipedia.org/wiki/AND-OR-Invert
    https://groups.google.com/d/msg/comp.arch/fcq-GLQqvas/vTxmcA0QAgAJ
    """

    def __init__(self, width):
        """Create a ``MaskedFullAdder``.

        :param width: the bit width of the input and output
        """
        self.width = width
        self.mask = Signal(width, reset_less=True)
        self.mcarry = Signal(width, reset_less=True)
        self.in0 = Signal(width, reset_less=True)
        self.in1 = Signal(width, reset_less=True)
        self.in2 = Signal(width, reset_less=True)
        self.sum = Signal(width, reset_less=True)

    def elaborate(self, platform):
        """Build the sum and the shifted, masked carry logic."""
        m = Module()
        width = self.width
        # s1..s3: the inputs moved up one bit (a zero enters at bit 0)
        s1 = Signal(width, reset_less=True)
        s2 = Signal(width, reset_less=True)
        s3 = Signal(width, reset_less=True)
        # c1..c3: pairwise carry terms, each gated by the mask
        c1 = Signal(width, reset_less=True)
        c2 = Signal(width, reset_less=True)
        c3 = Signal(width, reset_less=True)
        a, b, c, mask = self.in0, self.in1, self.in2, self.mask
        m.d.comb += [
            self.sum.eq(a ^ b ^ c),
            s1.eq(Cat(0, a)),
            s2.eq(Cat(0, b)),
            s3.eq(Cat(0, c)),
            c1.eq(s1 & s2 & mask),
            c2.eq(s2 & s3 & mask),
            c3.eq(s3 & s1 & mask),
            self.mcarry.eq(c1 | c2 | c3),
        ]
        return m
209
210
class PartitionedAdder(Elaboratable):
    """Partitioned Adder.

    Performs the final add.  The partition points are included in the
    actual add (in one of the operands only), which causes a carry over
    to the next bit.  Then the final output *removes* the extra bits from
    the result.

    partition: .... P... P... P... P... (32 bits)
    a        : .... .... .... .... .... (32 bits)
    b        : .... .... .... .... .... (32 bits)
    exp-a    : ....P....P....P....P.... (32+4 bits, P=1 if no partition)
    exp-b    : ....0....0....0....0.... (32 bits plus 4 zeros)
    exp-o    : ....xN...xN...xN...xN... (32+4 bits - x to be discarded)
    o        : .... N... N... N... N... (32 bits - x ignored, N is carry-over)

    :attribute width: the bit width of the input and output. Read-only.
    :attribute a: the first input to the adder
    :attribute b: the second input to the adder
    :attribute output: the sum output
    :attribute partition_points: the input partition points. Modification not
        supported, except for by ``Signal.eq``.
    """

    def __init__(self, width, partition_points, partition_step=1):
        """Create a ``PartitionedAdder``.

        :param width: the bit width of the input and output
        :param partition_points: the input partition points
        :param partition_step: a multiplier (typically double) step
                               which in-place "expands" the partition points
        :raises ValueError: if a partition point is not below ``width``.
        """
        self.width = width
        self.pmul = partition_step
        self.a = Signal(width, reset_less=True)
        self.b = Signal(width, reset_less=True)
        self.output = Signal(width, reset_less=True)
        self.partition_points = PartitionPoints(partition_points)
        if not self.partition_points.fits_in_width(width):
            raise ValueError("partition_points doesn't fit in width")
        # one extra "break" bit per partition point that elaborate() will
        # insert; counted with the same (step-scaled) test used there so
        # the expanded signals are exactly as wide as needed.
        n_breaks = sum(1 for i in range(width)
                       if self._partition_key(i) is not None)
        self._expanded_width = width + n_breaks

    def _partition_key(self, bit_index):
        """Return the partition-point key that sits at ``bit_index``.

        A point with key ``k`` sits at bit ``k * partition_step``.
        Returns ``None`` when no partition point is at this bit.
        """
        quotient, remainder = divmod(bit_index, self.pmul)
        if remainder == 0 and quotient in self.partition_points:
            return quotient
        return None

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        expanded_a = Signal(self._expanded_width, reset_less=True)
        expanded_b = Signal(self._expanded_width, reset_less=True)
        expanded_o = Signal(self._expanded_width, reset_less=True)

        expanded_index = 0
        # store bits in a list, use Cat later.  graphviz is much cleaner
        al, bl, ol, ea, eb, eo = [], [], [], [], [], []

        # partition points are "breaks" (extra zeros or 1s) in what would
        # otherwise be a massive long add.  when the "break" points are 0,
        # whatever is in it (in the output) is discarded.  however when
        # there is a "1", it causes a roll-over carry to the *next* bit.
        # we still ignore the "break" bit in the [intermediate] output,
        # however by that time we've got the effect that we wanted: the
        # carry has been carried *over* the break point.

        for i in range(self.width):
            key = self._partition_key(i)
            if key is not None:
                # add extra bit set to 0 + 0 for enabled partition points
                # and 1 + 0 for disabled partition points
                ea.append(expanded_a[expanded_index])
                al.append(~self.partition_points[key])  # extra bit in a
                eb.append(expanded_b[expanded_index])
                bl.append(C(0))  # yes, add a zero
                expanded_index += 1  # skip the extra point.  NOT in the output
            ea.append(expanded_a[expanded_index])
            eb.append(expanded_b[expanded_index])
            eo.append(expanded_o[expanded_index])
            al.append(self.a[i])
            bl.append(self.b[i])
            ol.append(self.output[i])
            expanded_index += 1

        # combine above using Cat
        m.d.comb += Cat(*ea).eq(Cat(*al))
        m.d.comb += Cat(*eb).eq(Cat(*bl))
        m.d.comb += Cat(*ol).eq(Cat(*eo))

        # use only one addition to take advantage of look-ahead carry and
        # special hardware on FPGAs
        m.d.comb += expanded_o.eq(expanded_a + expanded_b)
        return m
304
305
# Number of terms consumed by one full-adder group in the add-reduce tree;
# used as the stride when grouping terms and to find the leftover count.
FULL_ADDER_INPUT_COUNT = 3
307
class AddReduceData:
    """Signal bundle passed between add-reduce levels.

    Holds ``n_parts`` 2-bit ``part_ops`` signals, ``n_inputs`` terms of
    ``output_width`` bits each, and a ``like()`` copy of the partition
    points.
    """

    def __init__(self, part_pts, n_inputs, output_width, n_parts):
        self.part_ops = [
            Signal(2, name=f"part_ops_{i}", reset_less=True)
            for i in range(n_parts)
        ]
        self.terms = [
            Signal(output_width, name=f"terms_{i}", reset_less=True)
            for i in range(n_inputs)
        ]
        self.part_pts = part_pts.like()

    def eq_from(self, part_pts, inputs, part_ops):
        """Return the list of assignments loading this bundle.

        The first entry is the (generator) result of ``part_pts.eq``,
        followed by one assignment per term and one per part_op, each
        indexed into ``inputs`` / ``part_ops`` by position.
        """
        assignments = [self.part_pts.eq(part_pts)]
        assignments.extend(term.eq(inputs[i])
                           for i, term in enumerate(self.terms))
        assignments.extend(op.eq(part_ops[i])
                           for i, op in enumerate(self.part_ops))
        return assignments

    def eq(self, rhs):
        """Return the assignments copying another bundle into this one."""
        return self.eq_from(rhs.part_pts, rhs.terms, rhs.part_ops)
327
328
class FinalReduceData:
    """Signal bundle produced by the last add-reduce stage.

    Holds ``n_parts`` 2-bit ``part_ops`` signals, a single ``output``
    of ``output_width`` bits, and a ``like()`` copy of the partition
    points.
    """

    def __init__(self, part_pts, output_width, n_parts):
        self.part_ops = [
            Signal(2, name=f"part_ops_{i}", reset_less=True)
            for i in range(n_parts)
        ]
        self.output = Signal(output_width, reset_less=True)
        self.part_pts = part_pts.like()

    def eq_from(self, part_pts, output, part_ops):
        """Return the list of assignments loading this bundle.

        Order: partition points, then the output, then one assignment
        per part_op (indexed into ``part_ops`` by position).
        """
        assignments = [self.part_pts.eq(part_pts), self.output.eq(output)]
        assignments.extend(op.eq(part_ops[i])
                           for i, op in enumerate(self.part_ops))
        return assignments

    def eq(self, rhs):
        """Return the assignments copying another bundle into this one."""
        return self.eq_from(rhs.part_pts, rhs.output, rhs.part_ops)
345
346
class FinalAdd(PipeModBase):
    """Last stage of the add-reduce chain.

    Takes the 0, 1 or 2 terms left over by the reduction levels and
    produces the single output value, using a ``PartitionedAdder`` when
    two terms remain.
    """

    def __init__(self, pspec, lidx, n_inputs, partition_points,
                 partition_step=1):
        """Create a ``FinalAdd``.

        :param pspec: pipeline spec; ``width`` and ``n_parts`` are read.
        :param lidx: level index of this stage (stored only).
        :param n_inputs: number of incoming terms (0, 1 or 2).
        :param partition_points: the input partition points.
        :param partition_step: step passed on to the ``PartitionedAdder``.
        :raises ValueError: if a partition point is not below the output
            width (twice ``pspec.width``).
        """
        self.lidx = lidx
        self.partition_step = partition_step
        self.output_width = pspec.width * 2
        self.n_inputs = n_inputs
        self.n_parts = pspec.n_parts
        self.partition_points = PartitionPoints(partition_points)
        if not self.partition_points.fits_in_width(self.output_width):
            raise ValueError("partition_points doesn't fit in output_width")

        super().__init__(pspec, "finaladd")

    def ispec(self):
        """Input bundle: ``n_inputs`` terms plus part_ops and part_pts."""
        return AddReduceData(self.partition_points, self.n_inputs,
                             self.output_width, self.n_parts)

    def ospec(self):
        """Output bundle: the single summed output, part_ops, part_pts."""
        return FinalReduceData(self.partition_points,
                               self.output_width, self.n_parts)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        comb = m.d.comb

        output_width = self.output_width
        output = Signal(output_width, reset_less=True)
        if self.n_inputs == 0:
            # nothing to add: the output defaults to zero
            comb += output.eq(0)
        elif self.n_inputs == 1:
            # a lone term is passed straight through
            comb += output.eq(self.i.terms[0])
        else:
            # exactly two terms left: do the real (partitioned) addition
            assert self.n_inputs == 2
            adder = PartitionedAdder(output_width,
                                     self.i.part_pts, self.partition_step)
            m.submodules.final_adder = adder
            comb += adder.a.eq(self.i.terms[0])
            comb += adder.b.eq(self.i.terms[1])
            comb += output.eq(adder.output)

        # forward partition info alongside the result
        comb += self.o.eq_from(self.i.part_pts, output,
                               self.i.part_ops)

        return m
399
400
class AddReduceSingle(PipeModBase):
    """One level of the add-reduce tree.

    Incoming terms are grouped in threes; each group goes through a
    ``MaskedFullAdder`` and comes out as two terms (sum and masked
    carry).  One or two leftover terms are passed through unchanged.

    :attribute n_inputs: number of incoming terms.
    :attribute n_terms: number of outgoing terms.
    :attribute groups: starting indices of the full-adder groups.
    :attribute partition_points: the input partition points. Modification not
        supported, except for by ``Signal.eq``.
    """

    def __init__(self, pspec, lidx, n_inputs, partition_points,
                 partition_step=1):
        """Create an ``AddReduceSingle``.

        :param pspec: pipeline spec; ``width`` and ``n_parts`` are read.
        :param lidx: level index, used in the module name.
        :param n_inputs: number of incoming terms.
        :param partition_points: the input partition points.
        :param partition_step: multiplier applied to the partition points
            when building the carry mask.
        :raises ValueError: if a partition point is not below the output
            width (twice ``pspec.width``).
        """
        self.lidx = lidx
        self.partition_step = partition_step
        self.n_inputs = n_inputs
        self.n_parts = pspec.n_parts
        self.output_width = pspec.width * 2
        self.partition_points = PartitionPoints(partition_points)
        if not self.partition_points.fits_in_width(self.output_width):
            raise ValueError("partition_points doesn't fit in output_width")

        self.groups = AddReduceSingle.full_adder_groups(n_inputs)
        self.n_terms = AddReduceSingle.calc_n_inputs(n_inputs, self.groups)

        super().__init__(pspec, "addreduce_%d" % lidx)

    def ispec(self):
        """Input bundle with ``n_inputs`` terms."""
        return AddReduceData(self.partition_points, self.n_inputs,
                             self.output_width, self.n_parts)

    def ospec(self):
        """Output bundle with ``n_terms`` terms."""
        return AddReduceData(self.partition_points, self.n_terms,
                             self.output_width, self.n_parts)

    @staticmethod
    def calc_n_inputs(n_inputs, groups):
        """Return how many terms leave a level fed with ``n_inputs`` terms.

        Each group yields two terms; the 0, 1 or 2 terms left over after
        grouping in threes are carried across as-is.
        """
        return 2 * len(groups) + n_inputs % FULL_ADDER_INPUT_COUNT

    @staticmethod
    def get_max_level(input_count):
        """Get the maximum level.

        Counts how many reduction levels are needed until no full-adder
        group can be formed any more.  All ``register_levels`` must be
        less than or equal to the maximum level.
        """
        level = 0
        n_groups = len(AddReduceSingle.full_adder_groups(input_count))
        while n_groups != 0:
            input_count = input_count % FULL_ADDER_INPUT_COUNT + 2 * n_groups
            level += 1
            n_groups = len(AddReduceSingle.full_adder_groups(input_count))
        return level

    @staticmethod
    def full_adder_groups(input_count):
        """Get the term indices at which a full adder group starts."""
        last_start = input_count - FULL_ADDER_INPUT_COUNT
        return range(0, last_start + 1, FULL_ADDER_INPUT_COUNT)

    def create_next_terms(self):
        """Create the outgoing terms and the adders that produce them.

        Returns ``(terms, adders)`` where ``adders`` is a list of
        ``(first_input_index, MaskedFullAdder)`` pairs; the wiring of the
        adder inputs is done in ``elaborate``.
        """
        terms = []
        adders = []

        # every group of 3 terms becomes 2: N terms shrink to
        # 2 * (N // 3) plus the remainder
        for start in self.groups:
            adder = MaskedFullAdder(self.output_width)
            adders.append((start, adder))
            terms.extend((adder.sum, adder.mcarry))

        # leftover terms (1 or 2) go to the next layer untouched: a half
        # adder on 2 terms would still leave 2 terms, so passing them
        # through saves gates.
        leftover = self.n_inputs % FULL_ADDER_INPUT_COUNT
        if leftover:
            terms.extend(self.i.terms[-leftover:])

        return terms, adders

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        comb = m.d.comb

        terms, adders = self.create_next_terms()

        # drive the output terms from the intermediate terms
        for idx, value in enumerate(terms):
            comb += self.o.terms[idx].eq(value)

        # partition points and part ops pass straight through
        comb += self.o.part_pts.eq(self.i.part_pts)
        comb += [op_out.eq(self.i.part_ops[idx])
                 for idx, op_out in zip(range(len(self.i.part_ops)),
                                        self.o.part_ops)]

        # the carry mask shared by all adders of this level
        part_mask = Signal(self.output_width, reset_less=True)
        mask = self.i.part_pts.as_mask(self.output_width,
                                       mul=self.partition_step)
        comb += part_mask.eq(mask)

        # register each adder and hook up its three inputs and the mask
        for idx, (first, adder) in enumerate(adders):
            setattr(m.submodules, f"adder_{idx}", adder)

            comb += adder.in0.eq(self.i.terms[first])
            comb += adder.in1.eq(self.i.terms[first + 1])
            comb += adder.in2.eq(self.i.terms[first + 2])
            comb += adder.mask.eq(part_mask)

        return m
539
540
class AddReduceInternal:
    """Build the chain of add-reduce levels for a given number of terms.

    :attribute levels: list of ``AddReduceSingle`` modules followed by
        one ``FinalAdd``, in evaluation order.
    :attribute n_inputs: number of terms entering the first level.
    :attribute output_width: twice ``pspec.width``.
    :attribute partition_points: the input partition points. Modification not
        supported, except for by ``Signal.eq``.
    """

    def __init__(self, pspec, n_inputs, part_pts, partition_step=1):
        """Create an ``AddReduceInternal`` and its levels.

        :param pspec: pipeline spec handed to every level.
        :param n_inputs: number of terms to be summed.
        :param part_pts: the input partition points.
        :param partition_step: step handed to every level.
        """
        self.pspec = pspec
        self.n_inputs = n_inputs
        self.output_width = pspec.width * 2
        self.partition_points = part_pts
        self.partition_step = partition_step

        self.create_levels()

    def create_levels(self):
        """creates reduction levels"""

        levels = []
        part_pts = self.partition_points
        n_terms = self.n_inputs
        # keep adding levels for as long as a group of three can be formed
        while len(AddReduceSingle.full_adder_groups(n_terms)) != 0:
            level = AddReduceSingle(self.pspec, len(levels), n_terms,
                                    part_pts, self.partition_step)
            levels.append(level)
            # the next level is described by this level's own signals
            part_pts = level.i.part_pts
            n_terms = len(level.o.terms)

        # whatever is left (at most 2 terms) goes to the final adder
        final = FinalAdd(self.pspec, len(levels), n_terms,
                         part_pts, self.partition_step)
        levels.append(final)

        self.levels = levels
592
593
class AddReduce(AddReduceInternal, Elaboratable):
    """Add a list of numbers together through a chain of reduce levels.

    :attribute i: ``AddReduceData`` loaded from the constructor inputs.
    :attribute o: ``FinalReduceData`` carrying the sum.
    :attribute levels: the reduction modules (see ``AddReduceInternal``).
    :attribute register_levels: List of nesting levels that should have
        pipeline registers.
    """

    def __init__(self, inputs, output_width, register_levels, part_pts,
                 part_ops, partition_step=1, pspec=None):
        """Create an ``AddReduce``.

        :param inputs: input ``Signal``s to be summed.
        :param output_width: bit-width of ``output``.
        :param register_levels: List of nesting levels that should have
            pipeline registers.
        :param part_pts: the input partition points.
        :param part_ops: per-partition operation signals.
        :param partition_step: step handed to every level.
        :param pspec: optional pipeline spec for the levels.  When omitted
            a minimal one is derived from ``output_width`` and
            ``len(part_ops)``.
        :raises ValueError: if ``pspec`` is omitted and ``output_width``
            is odd (the levels use ``pspec.width * 2`` as their width).
        """
        self._inputs = inputs
        self._part_pts = part_pts
        self._part_ops = part_ops
        n_parts = len(part_ops)
        # previously ``pspec`` and ``n_inputs`` were used here without
        # being defined, so construction always raised NameError.
        n_inputs = len(inputs)
        if pspec is None:
            if output_width % 2:
                raise ValueError("output_width must be even")
            from types import SimpleNamespace
            # NOTE(review): stand-in carrying only the attributes this
            # file reads from a pspec (width, n_parts); assumes
            # PipeModBase needs nothing else - confirm, or pass a real
            # PipelineSpec via ``pspec``.
            pspec = SimpleNamespace(width=output_width // 2,
                                    n_parts=n_parts)
        self.i = AddReduceData(part_pts, n_inputs,
                               output_width, n_parts)
        AddReduceInternal.__init__(self, pspec, n_inputs, part_pts,
                                   partition_step)
        self.o = FinalReduceData(part_pts, output_width, n_parts)
        self.register_levels = register_levels

    @staticmethod
    def get_max_level(input_count):
        """Return the maximum level (see ``AddReduceSingle``)."""
        return AddReduceSingle.get_max_level(input_count)

    @staticmethod
    def next_register_levels(register_levels):
        """``Iterable`` of ``register_levels`` for next recursive level."""
        for level in register_levels:
            if level > 0:
                yield level - 1

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()

        # load the input bundle from the constructor arguments
        m.d.comb += self.i.eq_from(self._part_pts, self._inputs,
                                   self._part_ops)

        for i, next_level in enumerate(self.levels):
            setattr(m.submodules, "next_level%d" % i, next_level)

        # chain the levels: each one is fed from the previous output,
        # through a register when its index is in register_levels
        i = self.i
        for idx, mcur in enumerate(self.levels):
            if idx in self.register_levels:
                m.d.sync += mcur.i.eq(i)
            else:
                m.d.comb += mcur.i.eq(i)
            i = mcur.o  # for next loop

        # output comes from last module
        m.d.comb += self.o.eq(i)

        return m
660
661
# Operation codes compared against the 2-bit ``part_ops`` signals.
# IntermediateOut selects the low half of a product for OP_MUL_LOW and
# the high half for every other code.
OP_MUL_LOW = 0
OP_MUL_SIGNED_HIGH = 1
OP_MUL_SIGNED_UNSIGNED_HIGH = 2  # a is signed, b is unsigned
OP_MUL_UNSIGNED_HIGH = 3
666
667
def get_term(value, shift=0, enabled=None):
    """Return ``value`` optionally gated and shifted left.

    :param value: the term to process.
    :param shift: number of zero bits to place below ``value``; must be
        zero or positive.
    :param enabled: when not ``None``, ``value`` is replaced by 0
        whenever ``enabled`` is false (via ``Mux``).
    """
    gated = value if enabled is None else Mux(enabled, value, 0)
    if shift > 0:
        return Cat(Repl(C(0, 1), shift), gated)
    assert shift == 0
    return gated
676
677
class ProductTerm(Elaboratable):
    """ Creates a single product term (a[..]*b[..]).

        Known design flaw: it is the *output* that is selected, so the
        multiplication(s) are combinatorially generated all the time.
    """

    def __init__(self, width, twidth, pbwid, a_index, b_index):
        """Create a ``ProductTerm``.

        :param width: bit width of one operand slice.
        :param twidth: bit width of the resulting term; ``a`` and ``b``
            are half of this.
        :param pbwid: bit width of the partition-byte enable input.
        :param a_index: which slice of ``a`` to multiply.
        :param b_index: which slice of ``b`` to multiply.
        """
        self.a_index = a_index
        self.b_index = b_index
        self.pwidth = width
        self.twidth = twidth
        self.width = width*2
        self.shift = 8 * (self.a_index + self.b_index)

        self.ti = Signal(self.width, reset_less=True)
        self.term = Signal(twidth, reset_less=True)
        self.a = Signal(twidth//2, reset_less=True)
        self.b = Signal(twidth//2, reset_less=True)
        self.pb_en = Signal(pbwid, reset_less=True)

        # partition-byte bits lying between the two slice indices
        lo_index = min(self.a_index, self.b_index)
        hi_index = max(self.a_index, self.b_index)
        self.tl = tl = [self.pb_en[i] for i in range(lo_index, hi_index)]

        name = "te_%d_%d" % (self.a_index, self.b_index)
        # an enable signal only exists when there are bits to test
        if len(tl) > 0:
            term_enabled = Signal(name=name, reset_less=True)
        else:
            term_enabled = None
        self.enabled = term_enabled
        self.term.name = "term_%d_%d" % (a_index, b_index)  # rename

    def elaborate(self, platform):
        """Multiply the selected slices and gate/shift the product."""

        m = Module()
        if self.enabled is not None:
            # enabled only when none of the in-between bits is set
            m.d.comb += self.enabled.eq(~(Cat(*self.tl).bool()))

        bsa = Signal(self.width, reset_less=True)
        bsb = Signal(self.width, reset_less=True)
        a_index, b_index = self.a_index, self.b_index
        pwidth = self.pwidth
        m.d.comb += [
            bsa.eq(self.a.bit_select(a_index * pwidth, pwidth)),
            bsb.eq(self.b.bit_select(b_index * pwidth, pwidth)),
            self.ti.eq(bsa * bsb),
            self.term.eq(get_term(self.ti, self.shift, self.enabled)),
        ]

        # TODO: sort out width issues, get inputs a/b switched on/off.
        # data going into Muxes is 1/2 the required width
        #
        # pwidth = self.pwidth
        # width = self.width
        # bsa = Signal(self.twidth//2, reset_less=True)
        # bsb = Signal(self.twidth//2, reset_less=True)
        # asel = Signal(width, reset_less=True)
        # bsel = Signal(width, reset_less=True)
        # a_index, b_index = self.a_index, self.b_index
        # m.d.comb += asel.eq(self.a.bit_select(a_index * pwidth, pwidth))
        # m.d.comb += bsel.eq(self.b.bit_select(b_index * pwidth, pwidth))
        # m.d.comb += bsa.eq(get_term(asel, self.shift, self.enabled))
        # m.d.comb += bsb.eq(get_term(bsel, self.shift, self.enabled))
        # m.d.comb += self.ti.eq(bsa * bsb)
        # m.d.comb += self.term.eq(self.ti)

        return m
747
748
class ProductTerms(Elaboratable):
    """ Bank of product terms for one slice of the "a" operand.

        Meant to be instantiated inside a for-loop over the "a" slices;
        this class supplies the inner loop over the "b" slices, creating
        one ``ProductTerm`` (which does the bit-selection) per "b" index.
    """
    def __init__(self, width, twidth, pbwid, a_index, blen):
        """Create a ``ProductTerms``.

        :param width: bit width of one operand slice.
        :param twidth: bit width of each term; ``a`` and ``b`` are half.
        :param pbwid: bit width of the partition-byte enable input.
        :param a_index: the "a" slice handled by this bank.
        :param blen: number of "b" slices (and of output terms).
        """
        self.a_index = a_index
        self.blen = blen
        self.pwidth = width
        self.twidth = twidth
        self.pbwid = pbwid
        self.a = Signal(twidth//2, reset_less=True)
        self.b = Signal(twidth//2, reset_less=True)
        self.pb_en = Signal(pbwid, reset_less=True)
        self.terms = [
            Signal(twidth, name="term%d" % i, reset_less=True)
            for i in range(blen)
        ]

    def elaborate(self, platform):
        """Instantiate and wire one ``ProductTerm`` per "b" index."""

        m = Module()

        for b_index in range(self.blen):
            t = ProductTerm(self.pwidth, self.twidth, self.pbwid,
                            self.a_index, b_index)
            setattr(m.submodules, "term_%d" % b_index, t)

            m.d.comb += [
                # every term sees the full operands and enables
                t.a.eq(self.a),
                t.b.eq(self.b),
                t.pb_en.eq(self.pb_en),
                self.terms[b_index].eq(t.term),
            ]

        return m
782
783
class LSBNegTerm(Elaboratable):
    """Produce the two correction terms for one signed partition.

    When ``part``, ``msb`` and ``signed`` are all set, ``nt`` carries the
    bitwise inverse of ``op`` in its upper half and ``nl`` carries a
    single 1 at the lowest bit of its upper half; otherwise both are 0.
    """

    def __init__(self, bit_width):
        """Create an ``LSBNegTerm``.

        :param bit_width: width of ``op``; the outputs are twice as wide.
        """
        self.bit_width = bit_width
        self.part = Signal(reset_less=True)
        self.signed = Signal(reset_less=True)
        self.op = Signal(bit_width, reset_less=True)
        self.msb = Signal(reset_less=True)
        self.nt = Signal(bit_width*2, reset_less=True)
        self.nl = Signal(bit_width*2, reset_less=True)

    def elaborate(self, platform):
        """Build the enable and the two output terms."""
        m = Module()
        comb = m.d.comb
        bit_wid = self.bit_width
        ext = Repl(0, bit_wid)  # zero LO part: pushes the data into HI

        # the correction applies only to an active partition whose
        # operand is negative and for which signed was requested
        enabled = Signal(reset_less=True)
        comb += enabled.eq(self.part & self.msb & self.signed)

        # for 8-bit values: form a * 0xFF00 by using -a * 0x100, the
        # negation operation is split into a bitwise not and a +1.
        # likewise for 16, 32, and 64-bit values.

        # the "bitwise not" half: inverted operand in the HI part
        inverted_hi = Cat(ext, ~self.op)
        comb += self.nt.eq(Mux(enabled, inverted_hi, 0))

        # the "+1" half: a one at the bottom of the HI part when enabled
        plus_one_hi = Cat(ext, enabled, Repl(0, bit_wid-1))
        comb += self.nl.eq(plus_one_hi)

        return m
816
817
class Parts(Elaboratable):
    """Decode partition points into one enable bit per partition.

    ``parts[i]`` is asserted when the partition bytes at both ends of
    group ``i`` are set and all partition bytes inside the group are
    clear.
    """

    def __init__(self, pbwid, part_pts, n_parts):
        """Create a ``Parts``.

        :param pbwid: number of partition bytes collected.
        :param part_pts: partition points whose layout is copied for the
            ``part_pts`` input.
        :param n_parts: number of partition enables to produce.
        """
        self.pbwid = pbwid
        # inputs
        self.part_pts = PartitionPoints.like(part_pts)
        # outputs
        self.parts = [
            Signal(name=f"part_{i}", reset_less=True)
            for i in range(n_parts)
        ]

    def elaborate(self, platform):
        """Build the partition-enable decode logic."""
        m = Module()

        part_pts, parts = self.part_pts, self.parts
        # collect part-bytes (double factor because the input is extended)
        pbs = Signal(self.pbwid, reset_less=True)
        tl = []
        for i in range(self.pbwid):
            pb = Signal(name="pb%d" % i, reset_less=True)
            m.d.comb += pb.eq(part_pts.part_byte(i))
            tl.append(pb)
        m.d.comb += pbs.eq(Cat(*tl))

        # negated-temporary copy of partition bits
        npbs = Signal.like(pbs, reset_less=True)
        m.d.comb += npbs.eq(~pbs)
        byte_count = 8 // len(parts)
        for i in range(len(parts)):
            first = i * byte_count
            last = first + byte_count - 1
            # negated bits at the two edges, plain bits in between
            # (for i == 0 the lower edge index is -1, i.e. the top bit)
            pbl = [npbs[first - 1]]
            pbl.extend(pbs[j] for j in range(first, last))
            pbl.append(npbs[last])
            value = Signal(len(pbl), name="value_%d" % i, reset_less=True)
            m.d.comb += value.eq(Cat(*pbl))
            # the partition is active only if none of those bits is set
            m.d.comb += parts[i].eq(~(value).bool())

        return m
856
857
class Part(Elaboratable):
    """ a key class which, depending on the partitioning, will determine
        what action to take when parts of the output are signed or unsigned.

        Two pieces of data are needed *per operand, per partition*:
        whether the MSB is HI/LO (per partition!), and whether a signed
        or unsigned operation has been *requested*.

        With those known, signed handling is done by splitting 2's
        complement into 1's complement (a bit-inversion) plus one.

        The extra terms - as separate terms - are then thrown at the
        AddReduce alongside the multiplication part-results.
    """
    def __init__(self, part_pts, width, n_parts, pbwid):
        """Create a ``Part``.

        :param part_pts: the partition points.
        :param width: bit width of the four output terms.
        :param n_parts: number of partitions.
        :param pbwid: number of partition bytes.
        """

        self.pbwid = pbwid
        self.part_pts = part_pts

        # inputs
        self.a = Signal(64, reset_less=True)
        self.b = Signal(64, reset_less=True)
        self.a_signed = [
            Signal(name=f"a_signed_{i}", reset_less=True)
            for i in range(8)
        ]
        self.b_signed = [
            Signal(name=f"_b_signed_{i}", reset_less=True)
            for i in range(8)
        ]
        self.pbs = Signal(pbwid, reset_less=True)

        # outputs
        self.parts = [
            Signal(name=f"part_{i}", reset_less=True)
            for i in range(n_parts)
        ]

        self.not_a_term = Signal(width, reset_less=True)
        self.neg_lsb_a_term = Signal(width, reset_less=True)
        self.not_b_term = Signal(width, reset_less=True)
        self.neg_lsb_b_term = Signal(width, reset_less=True)

    def elaborate(self, platform):
        """Build the per-partition correction terms for ``a`` and ``b``."""
        m = Module()

        part_pts = self.part_pts
        m.submodules.p = p = Parts(self.pbwid, part_pts, len(self.parts))
        m.d.comb += p.part_pts.eq(part_pts)
        # NOTE(review): the partition enables used below come from the
        # Parts submodule; self.parts is not assigned in this method -
        # confirm whether callers read self.parts.
        parts = p.parts

        byte_width = 8 // len(parts)  # byte width
        bit_wid = 8 * byte_width      # bit width
        nat, nbt, nla, nlb = [], [], [], []
        for i in range(len(parts)):
            lsb = bit_wid * i
            msb = (i + 1) * bit_wid - 1
            sign_idx = i * byte_width

            # work out bit-inverted and +1 term for a.
            pa = LSBNegTerm(bit_wid)
            setattr(m.submodules, "lnt_%d_a_%d" % (bit_wid, i), pa)
            m.d.comb += [
                pa.part.eq(parts[i]),
                pa.op.eq(self.a.bit_select(lsb, bit_wid)),
                pa.signed.eq(self.b_signed[sign_idx]),  # yes b
                pa.msb.eq(self.b[msb]),  # really, b
            ]
            nat.append(pa.nt)
            nla.append(pa.nl)

            # work out bit-inverted and +1 term for b
            pb = LSBNegTerm(bit_wid)
            setattr(m.submodules, "lnt_%d_b_%d" % (bit_wid, i), pb)
            m.d.comb += [
                pb.part.eq(parts[i]),
                pb.op.eq(self.b.bit_select(lsb, bit_wid)),
                pb.signed.eq(self.a_signed[sign_idx]),  # yes a
                pb.msb.eq(self.a[msb]),  # really, a
            ]
            nbt.append(pb.nt)
            nlb.append(pb.nl)

        # concatenate the per-partition pieces into the 4 output terms
        m.d.comb += [
            self.not_a_term.eq(Cat(*nat)),
            self.not_b_term.eq(Cat(*nbt)),
            self.neg_lsb_a_term.eq(Cat(*nla)),
            self.neg_lsb_b_term.eq(Cat(*nlb)),
        ]

        return m
943
944
class IntermediateOut(Elaboratable):
    """ Picks the HI or LO half of each product for a given bit-width.

        The chosen halves are concatenated back together, lane by lane,
        to rebuild the output in its SIMD (partition) layout.
    """
    def __init__(self, width, out_wid, n_parts):
        """Create an ``IntermediateOut``.

        :param width: bit width of one output lane.
        :param out_wid: bit width of ``intermed``; ``output`` is half.
        :param n_parts: number of lanes.
        """
        self.width = width
        self.n_parts = n_parts
        self.part_ops = [
            Signal(2, name="dpop%d" % i, reset_less=True)
            for i in range(8)
        ]
        self.intermed = Signal(out_wid, reset_less=True)
        self.output = Signal(out_wid//2, reset_less=True)

    def elaborate(self, platform):
        """Build the per-lane HI/LO selection."""
        m = Module()

        ol = []
        w = self.width
        sel = w // 8  # stride between the part_ops entries consulted
        for i in range(self.n_parts):
            base = i * w*2
            op = Signal(w, reset_less=True, name="op%d_%d" % (w, i))
            low_half = self.intermed.bit_select(base, w)
            high_half = self.intermed.bit_select(base + w, w)
            wants_low = self.part_ops[sel * i] == OP_MUL_LOW
            m.d.comb += op.eq(Mux(wants_low, low_half, high_half))
            ol.append(op)
        m.d.comb += self.output.eq(Cat(*ol))

        return m
973
974
class FinalOut(PipeModBase):
    """Assemble the final result, byte by byte, from the per-width outputs.

    Every output byte is chosen on its own, so one part of the word may
    come from the 8-bit result while other parts come from the 16-, 32-
    or 64-bit results.
    """

    def __init__(self, pspec, part_pts):
        """
        :param pspec: pipeline specification; ``width`` and ``n_parts``
                      are read from it.
        :param part_pts: partition points handed to the ``Parts``
                         submodules and to ``ispec``.
        """
        self.part_pts = part_pts
        self.output_width = pspec.width * 2
        self.n_parts = pspec.n_parts
        self.out_wid = pspec.width

        super().__init__(pspec, "finalout")

    def ispec(self):
        """Input data format: an ``IntermediateData``."""
        return IntermediateData(self.part_pts, self.output_width, self.n_parts)

    def ospec(self):
        """Output data format: an ``OutputData``."""
        return OutputData()

    def elaborate(self, platform):
        """Wire up the lane detectors and the per-byte output selection."""
        m = Module()
        comb = m.d.comb

        # one Parts submodule per lane width: 8, 16, 32 and 64 bit
        splitters = []
        for lane_bits, n_lanes in ((8, 8), (16, 4), (32, 2), (64, 1)):
            splitter = Parts(8, self.part_pts, n_lanes)
            setattr(m.submodules, "p_%d" % lane_bits, splitter)
            splitters.append(splitter)
        p_8, p_16, p_32, p_64 = splitters

        in_part_pts = self.i.part_pts

        # temporaries: lane-enable flags ...
        d8 = [Signal(name=f"d8_{i}", reset_less=True) for i in range(8)]
        d16 = [Signal(name=f"d16_{i}", reset_less=True) for i in range(4)]
        d32 = [Signal(name=f"d32_{i}", reset_less=True) for i in range(2)]

        # ... and local copies of the four candidate results
        i8, i16, i32, i64 = (Signal(self.out_wid, name=nm, reset_less=True)
                             for nm in ("i8", "i16", "i32", "i64"))

        for splitter in splitters:
            comb += splitter.part_pts.eq(in_part_pts)

        for flags, splitter in ((d8, p_8), (d16, p_16), (d32, p_32)):
            for i in range(len(splitter.parts)):
                comb += flags[i].eq(splitter.parts[i])
        for slot, copy in enumerate((i8, i16, i32, i64)):
            comb += copy.eq(self.i.outputs[slot])

        out_bytes = []
        for byte in range(8):
            # priority per byte: d8 picks i8, otherwise d16 picks i16,
            # otherwise d32 picks i32, and i64 is the fallback.
            # the outer Mux tests (d8 | d16); its "true" arm then
            # separates i8 from i16, its "false" arm i32 from i64.
            lsb = byte * 8
            is_8 = d8[byte]
            is_16 = d16[byte // 2]
            is_32 = d32[byte // 4]
            op = Signal(8, reset_less=True, name="op_%d" % byte)
            narrow = Mux(is_8, i8.bit_select(lsb, 8),
                         i16.bit_select(lsb, 8))
            wide = Mux(is_32, i32.bit_select(lsb, 8),
                       i64.bit_select(lsb, 8))
            comb += op.eq(Mux(is_8 | is_16, narrow, wide))
            out_bytes.append(op)

        # create outputs
        comb += self.o.output.eq(Cat(*out_bytes))
        comb += self.o.intermediate_output.eq(self.i.intermediate_output)

        return m
1055
1056
class OrMod(Elaboratable):
    """Bitwise-OR of four inputs, built as a two-level tree."""

    def __init__(self, wid):
        """
        :param wid: bit width of each input and of the output.
        """
        self.wid = wid
        self.orin = [Signal(wid, name="orin%d" % i, reset_less=True)
                     for i in range(4)]
        self.orout = Signal(wid, reset_less=True)

    def elaborate(self, platform):
        """OR inputs 0/1 and 2/3 pairwise, then OR the two partial results."""
        m = Module()
        in0, in1, in2, in3 = self.orin

        or1 = Signal(self.wid, reset_less=True)
        or2 = Signal(self.wid, reset_less=True)
        m.d.comb += [
            or1.eq(in0 | in1),
            or2.eq(in2 | in3),
            self.orout.eq(or1 | or2),
        ]

        return m
1075
1076
class Signs(Elaboratable):
    """Decode an operation code (OP_MUL_*) into signedness flags.

    ``a_signed`` is set for every operation except OP_MUL_UNSIGNED_HIGH.
    ``b_signed`` is set only for OP_MUL_LOW and OP_MUL_SIGNED_HIGH.
    """

    def __init__(self):
        self.part_ops = Signal(2, reset_less=True)
        self.a_signed = Signal(reset_less=True)
        self.b_signed = Signal(reset_less=True)

    def elaborate(self, platform):
        """Drive ``a_signed`` and ``b_signed`` from ``part_ops``."""
        m = Module()

        op = self.part_ops
        a_is_signed = op != OP_MUL_UNSIGNED_HIGH
        b_is_signed = (op == OP_MUL_LOW) | (op == OP_MUL_SIGNED_HIGH)
        m.d.comb += self.a_signed.eq(a_is_signed)
        m.d.comb += self.b_signed.eq(b_is_signed)

        return m
1098
1099
class IntermediateData:
    """Data record holding the four per-width outputs plus bookkeeping.

    Fields: ``part_ops`` (one 2-bit code per part), ``part_pts``,
    ``outputs`` (four signals) and ``intermediate_output``.
    """

    def __init__(self, part_pts, output_width, n_parts):
        """
        :param part_pts: partition points; a copy is made with ``like()``.
        :param output_width: width of each ``outputs`` entry and of
                             ``intermediate_output``.
        :param n_parts: number of ``part_ops`` entries to create.
        """
        self.part_ops = [Signal(2, name=f"part_ops_{i}", reset_less=True)
                         for i in range(n_parts)]
        self.part_pts = part_pts.like()
        self.outputs = [Signal(output_width, name="io%d" % i, reset_less=True)
                        for i in range(4)]
        # intermediates (needed for unit tests)
        self.intermediate_output = Signal(output_width)

    def eq_from(self, part_pts, outputs, intermediate_output,
                part_ops):
        """Return the list of assignments that load this record.

        The list holds, in order: the partition points, the intermediate
        output, the four outputs, then one entry per ``self.part_ops``.
        """
        assigns = [self.part_pts.eq(part_pts),
                   self.intermediate_output.eq(intermediate_output)]
        for i in range(4):
            assigns.append(self.outputs[i].eq(outputs[i]))
        for i, dest in enumerate(self.part_ops):
            assigns.append(dest.eq(part_ops[i]))
        return assigns

    def eq(self, rhs):
        """Return the assignments that copy every field of ``rhs``."""
        return self.eq_from(rhs.part_pts, rhs.outputs,
                            rhs.intermediate_output, rhs.part_ops)
1123
1124
class InputData:
    """Input record of the multiplier.

    Holds the two 64-bit operands ``a`` and ``b``, a partition point
    signal at every multiple of 8 in 8..56, and eight 2-bit operation
    codes in ``part_ops``.
    """

    def __init__(self):
        self.a = Signal(64)
        self.b = Signal(64)
        self.part_pts = PartitionPoints()
        for point in range(8, 64, 8):
            self.part_pts[point] = Signal(name=f"part_pts_{point}")
        self.part_ops = [Signal(2, name=f"part_ops_{i}") for i in range(8)]

    def eq_from(self, part_pts, a, b, part_ops):
        """Return the list of assignments that load this record.

        Order: partition points, ``a``, ``b``, then one entry per
        ``self.part_ops``.
        """
        assigns = [self.part_pts.eq(part_pts), self.a.eq(a), self.b.eq(b)]
        for i, dest in enumerate(self.part_ops):
            assigns.append(dest.eq(part_ops[i]))
        return assigns

    def eq(self, rhs):
        """Return the assignments that copy every field of ``rhs``."""
        return self.eq_from(rhs.part_pts, rhs.a, rhs.b, rhs.part_ops)
1143
1144
class OutputData:
    """Output record: 64-bit ``output`` plus 128-bit ``intermediate_output``."""

    def __init__(self):
        self.intermediate_output = Signal(128)  # needed for unit tests
        self.output = Signal(64)

    def eq(self, rhs):
        """Return the assignments that copy both fields of ``rhs``."""
        copy_intermediate = self.intermediate_output.eq(
            rhs.intermediate_output)
        copy_output = self.output.eq(rhs.output)
        return [copy_intermediate, copy_output]
1154
1155
class AllTerms(PipeModBase):
    """Produce the complete list of terms that are to be summed."""

    def __init__(self, pspec, n_inputs):
        """Create an ``AllTerms``.

        :param pspec: pipeline specification; ``width`` and ``n_parts``
                      are read from it.
        :param n_inputs: number of terms, passed on to ``AddReduceData``.
        """
        self.n_inputs = n_inputs
        self.n_parts = pspec.n_parts
        self.output_width = pspec.width * 2
        super().__init__(pspec, "allterms")

    def ispec(self):
        """Input data format: an ``InputData``."""
        return InputData()

    def ospec(self):
        """Output data format: an ``AddReduceData``."""
        return AddReduceData(self.i.part_pts, self.n_inputs,
                             self.output_width, self.n_parts)

    def elaborate(self, platform):
        """Instantiate all term generators and copy their terms to ``o``."""
        m = Module()
        comb = m.d.comb

        inp = self.i
        eps = inp.part_pts

        # collect part-bytes
        pbs = Signal(8, reset_less=True)
        byte_flags = []
        for idx in range(8):
            flag = Signal(name="pb%d" % idx, reset_less=True)
            comb += flag.eq(eps.part_byte(idx))
            byte_flags.append(flag)
        comb += pbs.eq(Cat(*byte_flags))

        # one sign decoder per operation code
        signs = []
        for idx in range(8):
            decoder = Signs()
            signs.append(decoder)
            setattr(m.submodules, "signs%d" % idx, decoder)
            comb += decoder.part_ops.eq(inp.part_ops[idx])

        # one Part module per lane width (8, 16, 32 and 64 bit)
        part_mods = []
        for n_lanes in (8, 4, 2, 1):
            part_mod = Part(eps, 128, n_lanes, 8)
            setattr(m.submodules, "part_%d" % (64 // n_lanes), part_mod)
            part_mods.append(part_mod)

        nat_l, nbt_l, nla_l, nlb_l = [], [], [], []
        for part_mod in part_mods:
            comb += part_mod.a.eq(inp.a)
            comb += part_mod.b.eq(inp.b)
            for idx, decoder in enumerate(signs):
                comb += part_mod.a_signed[idx].eq(decoder.a_signed)
                comb += part_mod.b_signed[idx].eq(decoder.b_signed)
            comb += part_mod.pbs.eq(pbs)
            nat_l.append(part_mod.not_a_term)
            nbt_l.append(part_mod.not_b_term)
            nla_l.append(part_mod.neg_lsb_a_term)
            nlb_l.append(part_mod.neg_lsb_b_term)

        terms = []

        # product terms, one generator per byte index of a
        for a_index in range(8):
            gen = ProductTerms(8, 128, 8, a_index, 8)
            setattr(m.submodules, "terms_%d" % a_index, gen)

            comb += gen.a.eq(inp.a)
            comb += gen.b.eq(inp.b)
            comb += gen.pb_en.eq(pbs)

            terms.extend(gen.terms)

        # it's fine to bitwise-or data together since they are never enabled
        # at the same time
        or_pairs = []
        for label, sources in (("nat", nat_l), ("nbt", nbt_l),
                               ("nla", nla_l), ("nlb", nlb_l)):
            or_mod = OrMod(128)
            setattr(m.submodules, "%s_or" % label, or_mod)
            or_pairs.append((sources, or_mod))
        for sources, or_mod in or_pairs:
            for idx, src in enumerate(sources):
                comb += or_mod.orin[idx].eq(src)
            terms.append(or_mod.orout)

        # copy the intermediate terms to the output
        for idx, term in enumerate(terms):
            comb += self.o.terms[idx].eq(term)

        # copy reg part points and part ops to output
        comb += self.o.part_pts.eq(eps)
        comb += [self.o.part_ops[idx].eq(inp.part_ops[idx])
                 for idx in range(len(inp.part_ops))]

        return m
1251
1252
class Intermediates(PipeModBase):
    """Create the four per-width intermediate outputs.

    One ``IntermediateOut`` is instantiated for each of the 64-, 32-, 16-
    and 8-bit lane widths; their results go to ``o.outputs[3]``,
    ``[2]``, ``[1]`` and ``[0]`` respectively.
    """

    def __init__(self, pspec, part_pts):
        """
        :param pspec: pipeline specification; ``width`` and ``n_parts``
                      are read from it.
        :param part_pts: partition points used for ``ispec`` and ``ospec``.
        """
        self.part_pts = part_pts
        self.output_width = pspec.width * 2
        self.n_parts = pspec.n_parts

        super().__init__(pspec, "intermediates")

    def ispec(self):
        """Input data format: a ``FinalReduceData``."""
        return FinalReduceData(self.part_pts, self.output_width, self.n_parts)

    def ospec(self):
        """Output data format: an ``IntermediateData``."""
        return IntermediateData(self.part_pts, self.output_width, self.n_parts)

    def elaborate(self, platform):
        """Instantiate the selectors and pass the bookkeeping data through."""
        m = Module()
        comb = m.d.comb

        in_part_ops = self.i.part_ops
        in_part_pts = self.i.part_pts

        # (lane width, number of lanes, index into o.outputs),
        # widest lane first
        layout = ((64, 1, 3), (32, 2, 2), (16, 4, 1), (8, 8, 0))
        for lane_bits, n_lanes, slot in layout:
            selector = IntermediateOut(lane_bits, 128, n_lanes)
            setattr(m.submodules, "io%d" % lane_bits, selector)
            comb += selector.intermed.eq(self.i.output)
            for byte in range(8):
                comb += selector.part_ops[byte].eq(in_part_ops[byte])
            comb += self.o.outputs[slot].eq(selector.output)

        # forward op codes, partition points and the raw sum unchanged
        for byte in range(8):
            comb += self.o.part_ops[byte].eq(in_part_ops[byte])
        comb += self.o.part_pts.eq(in_part_pts)
        comb += self.o.intermediate_output.eq(self.i.output)

        return m
1310
1311
class Mul8_16_32_64(Elaboratable):
    """Signed/Unsigned 8/16/32/64-bit partitioned integer multiplier.

    Supports partitioning into any combination of 8, 16, 32, and 64-bit
    partitions on naturally-aligned boundaries. Supports the operation being
    set for each partition independently.

    :attribute part_pts: the input partition points. Has a partition point at
        multiples of 8 in 0 < i < 64. Each partition point's associated
        ``Value`` is a ``Signal``. Modification not supported, except via
        ``Signal.eq``.
    :attribute part_ops: the operation for each byte. The operation for a
        particular partition is selected by assigning the selected operation
        code to each byte in the partition. The allowed operation codes are:

        :attribute OP_MUL_LOW: the LSB half of the product. Equivalent to
            RISC-V's `mul` instruction.
        :attribute OP_MUL_SIGNED_HIGH: the MSB half of the product where both
            ``a`` and ``b`` are signed. Equivalent to RISC-V's `mulh`
            instruction.
        :attribute OP_MUL_SIGNED_UNSIGNED_HIGH: the MSB half of the product
            where ``a`` is signed and ``b`` is unsigned. Equivalent to RISC-V's
            `mulhsu` instruction.
        :attribute OP_MUL_UNSIGNED_HIGH: the MSB half of the product where both
            ``a`` and ``b`` are unsigned. Equivalent to RISC-V's `mulhu`
            instruction.
    :attribute a: first operand (alias of ``i.a``).
    :attribute b: second operand (alias of ``i.b``).
    :attribute output: the result (alias of ``o.output``).
    :attribute intermediate_output: alias of ``o.intermediate_output``.
    """

    def __init__(self, register_levels=()):
        """ register_levels: specifies the points in the cascade at which
            flip-flops are to be inserted.  Each entry is compared against
            the index of an add-reduce level in ``elaborate``.
        """

        self.id_wid = 0  # num_bits(num_rows)
        self.op_wid = 0
        self.pspec = PipelineSpec(64, self.id_wid, self.op_wid, n_ops=3)
        self.pspec.n_parts = 8

        # parameter(s)
        self.register_levels = list(register_levels)

        self.i = self.ispec()
        self.o = self.ospec()

        # inputs
        self.part_pts = self.i.part_pts
        self.part_ops = self.i.part_ops
        self.a = self.i.a
        self.b = self.i.b

        # output
        self.intermediate_output = self.o.intermediate_output
        self.output = self.o.output

    def ispec(self):
        """Input data format: an ``InputData``."""
        return InputData()

    def ospec(self):
        """Output data format: an ``OutputData``."""
        return OutputData()

    def elaborate(self, platform):
        """Chain the stages: terms, add-reduce levels, intermediates, final.

        :param platform: not used by this method.
        :returns: the populated ``Module``.
        """
        m = Module()

        part_pts = self.part_pts

        # presumably 64 product terms plus the 4 OR-ed correction terms
        # appended by AllTerms -- the 64 depends on ProductTerms; confirm
        n_inputs = 64 + 4
        t = AllTerms(self.pspec, n_inputs)
        t.setup(m, self.i)

        at = AddReduceInternal(self.pspec, n_inputs, part_pts, partition_step=2)

        # thread the data through every add-reduce level; a level whose
        # index is listed in register_levels gets a registered (sync)
        # output, all others are purely combinatorial
        i = t.o
        for idx, mcur in enumerate(at.levels):
            mcur.setup(m, i)
            o = mcur.ospec()
            if idx in self.register_levels:
                m.d.sync += o.eq(mcur.process(i))
            else:
                m.d.comb += o.eq(mcur.process(i))
            i = o  # for next loop

        interm = Intermediates(self.pspec, part_pts)
        interm.setup(m, i)
        o = interm.process(interm.i)

        # final output
        finalout = FinalOut(self.pspec, part_pts)
        finalout.setup(m, o)
        m.d.comb += self.o.eq(finalout.process(o))

        return m
1406
1407
if __name__ == "__main__":
    # hand the multiplier to the nmigen command-line front end, exposing
    # the operands, both outputs, every op code and every partition point
    m = Mul8_16_32_64()
    main(m, ports=[m.a, m.b, m.intermediate_output, m.output]
         + list(m.part_ops)
         + list(m.part_pts.values()))