1 # SPDX-License-Identifier: LGPL-2.1-or-later
2 # See Notices.txt for copyright information
3 """Integer Multiplication."""
5 from nmigen
import Signal
, Module
, Value
, Elaboratable
, Cat
, C
, Mux
, Repl
6 from nmigen
.hdl
.ast
import Assign
7 from abc
import ABCMeta
, abstractmethod
8 from nmigen
.cli
import main
9 from functools
import reduce
10 from operator
import or_
11 from ieee754
.pipeline
import PipelineSpec
12 from nmutil
.pipemodbase
import PipeModBase
15 class PartitionPoints(dict):
16 """Partition points and corresponding ``Value``s.
18 The points at where an ALU is partitioned along with ``Value``s that
19 specify if the corresponding partition points are enabled.
21 For example: ``{1: True, 5: True, 10: True}`` with
22 ``width == 16`` specifies that the ALU is split into 4 sections:
25 * bits 5 <= ``i`` < 10
26 * bits 10 <= ``i`` < 16
28 If the partition_points were instead ``{1: True, 5: a, 10: True}``
29 where ``a`` is a 1-bit ``Signal``:
30 * If ``a`` is asserted:
33 * bits 5 <= ``i`` < 10
34 * bits 10 <= ``i`` < 16
37 * bits 1 <= ``i`` < 10
38 * bits 10 <= ``i`` < 16
41 def __init__(self
, partition_points
=None):
42 """Create a new ``PartitionPoints``.
44 :param partition_points: the input partition points to values mapping.
47 if partition_points
is not None:
48 for point
, enabled
in partition_points
.items():
49 if not isinstance(point
, int):
50 raise TypeError("point must be a non-negative integer")
52 raise ValueError("point must be a non-negative integer")
53 self
[point
] = Value
.wrap(enabled
)
55 def like(self
, name
=None, src_loc_at
=0, mul
=1):
56 """Create a new ``PartitionPoints`` with ``Signal``s for all values.
58 :param name: the base name for the new ``Signal``s.
59 :param mul: a multiplication factor on the indices
62 name
= Signal(src_loc_at
=1+src_loc_at
).name
# get variable name
63 retval
= PartitionPoints()
64 for point
, enabled
in self
.items():
66 retval
[point
] = Signal(enabled
.shape(), name
=f
"{name}_{point}")
70 """Assign ``PartitionPoints`` using ``Signal.eq``."""
71 if set(self
.keys()) != set(rhs
.keys()):
72 raise ValueError("incompatible point set")
73 for point
, enabled
in self
.items():
74 yield enabled
.eq(rhs
[point
])
76 def as_mask(self
, width
, mul
=1):
77 """Create a bit-mask from `self`.
79 Each bit in the returned mask is clear only if the partition point at
80 the same bit-index is enabled.
82 :param width: the bit width of the resulting mask
83 :param mul: a "multiplier" which in-place expands the partition points
84 typically set to "2" when used for multipliers
87 for i
in range(width
):
89 if i
.is_integer() and int(i
) in self
:
95 def get_max_partition_count(self
, width
):
96 """Get the maximum number of partitions.
98 Gets the number of partitions when all partition points are enabled.
101 for point
in self
.keys():
106 def fits_in_width(self
, width
):
107 """Check if all partition points are smaller than `width`."""
108 for point
in self
.keys():
113 def part_byte(self
, index
, mfactor
=1): # mfactor used for "expanding"
114 if index
== -1 or index
== 7:
116 assert index
>= 0 and index
< 8
117 return self
[(index
* 8 + 8)*mfactor
]
class FullAdder(Elaboratable):
    """Full Adder.

    :attribute in0: the first input
    :attribute in1: the second input
    :attribute in2: the third input
    :attribute sum: the sum output
    :attribute carry: the carry output

    Rather than do individual full adders (and have an array of them,
    which would be very slow to simulate), this module can specify the
    bit width of the inputs and outputs: in effect it performs multiple
    Full 3-2 Add operations "in parallel".
    """

    def __init__(self, width):
        """Create a ``FullAdder``.

        :param width: the bit width of the input and output
        """
        self.in0 = Signal(width, reset_less=True)
        self.in1 = Signal(width, reset_less=True)
        self.in2 = Signal(width, reset_less=True)
        self.sum = Signal(width, reset_less=True)
        self.carry = Signal(width, reset_less=True)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        # bitwise 3:2 compressor: sum is the XOR of all three inputs,
        # carry is the majority function of the three inputs.
        m.d.comb += self.sum.eq(self.in0 ^ self.in1 ^ self.in2)
        m.d.comb += self.carry.eq((self.in0 & self.in1)
                                  | (self.in1 & self.in2)
                                  | (self.in2 & self.in0))
        return m
class MaskedFullAdder(Elaboratable):
    """Masked Full Adder.

    :attribute mask: the carry partition mask
    :attribute in0: the first input
    :attribute in1: the second input
    :attribute in2: the third input
    :attribute sum: the sum output
    :attribute mcarry: the masked carry output

    FullAdders are always used with a "mask" on the output.  To keep
    the graphviz "clean", this class performs the masking here rather
    than inside a large for-loop.

    See the following discussion as to why this is no longer derived
    from FullAdder.  Each carry is shifted here *before* being ANDed
    with the mask, so that an AOI cell may be used (which is more
    gate-efficient):
    https://en.wikipedia.org/wiki/AND-OR-Invert
    https://groups.google.com/d/msg/comp.arch/fcq-GLQqvas/vTxmcA0QAgAJ
    """

    def __init__(self, width):
        """Create a ``MaskedFullAdder``.

        :param width: the bit width of the input and output
        """
        self.width = width  # needed by elaborate for the temporaries
        self.mask = Signal(width, reset_less=True)
        self.mcarry = Signal(width, reset_less=True)
        self.in0 = Signal(width, reset_less=True)
        self.in1 = Signal(width, reset_less=True)
        self.in2 = Signal(width, reset_less=True)
        self.sum = Signal(width, reset_less=True)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        s1 = Signal(self.width, reset_less=True)
        s2 = Signal(self.width, reset_less=True)
        s3 = Signal(self.width, reset_less=True)
        c1 = Signal(self.width, reset_less=True)
        c2 = Signal(self.width, reset_less=True)
        c3 = Signal(self.width, reset_less=True)
        # sum is the plain bitwise XOR, exactly as in FullAdder
        m.d.comb += self.sum.eq(self.in0 ^ self.in1 ^ self.in2)
        # shift each input left by one *before* the AND so the carry
        # lands in the correct (next-higher) bit position
        m.d.comb += s1.eq(Cat(0, self.in0))
        m.d.comb += s2.eq(Cat(0, self.in1))
        m.d.comb += s3.eq(Cat(0, self.in2))
        # pairwise majority terms, each qualified by the partition mask
        m.d.comb += c1.eq(s1 & s2 & self.mask)
        m.d.comb += c2.eq(s2 & s3 & self.mask)
        m.d.comb += c3.eq(s3 & s1 & self.mask)
        m.d.comb += self.mcarry.eq(c1 | c2 | c3)
        return m
211 class PartitionedAdder(Elaboratable
):
212 """Partitioned Adder.
214 Performs the final add. The partition points are included in the
215 actual add (in one of the operands only), which causes a carry over
216 to the next bit. Then the final output *removes* the extra bits from
219 partition: .... P... P... P... P... (32 bits)
220 a : .... .... .... .... .... (32 bits)
221 b : .... .... .... .... .... (32 bits)
222 exp-a : ....P....P....P....P.... (32+4 bits, P=1 if no partition)
223 exp-b : ....0....0....0....0.... (32 bits plus 4 zeros)
224 exp-o : ....xN...xN...xN...xN... (32+4 bits - x to be discarded)
225 o : .... N... N... N... N... (32 bits - x ignored, N is carry-over)
227 :attribute width: the bit width of the input and output. Read-only.
228 :attribute a: the first input to the adder
229 :attribute b: the second input to the adder
230 :attribute output: the sum output
231 :attribute partition_points: the input partition points. Modification not
232 supported, except for by ``Signal.eq``.
235 def __init__(self
, width
, partition_points
, partition_step
=1):
236 """Create a ``PartitionedAdder``.
238 :param width: the bit width of the input and output
239 :param partition_points: the input partition points
240 :param partition_step: a multiplier (typically double) step
241 which in-place "expands" the partition points
244 self
.pmul
= partition_step
245 self
.a
= Signal(width
, reset_less
=True)
246 self
.b
= Signal(width
, reset_less
=True)
247 self
.output
= Signal(width
, reset_less
=True)
248 self
.partition_points
= PartitionPoints(partition_points
)
249 if not self
.partition_points
.fits_in_width(width
):
250 raise ValueError("partition_points doesn't fit in width")
252 for i
in range(self
.width
):
253 if i
in self
.partition_points
:
256 self
._expanded
_width
= expanded_width
258 def elaborate(self
, platform
):
259 """Elaborate this module."""
261 expanded_a
= Signal(self
._expanded
_width
, reset_less
=True)
262 expanded_b
= Signal(self
._expanded
_width
, reset_less
=True)
263 expanded_o
= Signal(self
._expanded
_width
, reset_less
=True)
266 # store bits in a list, use Cat later. graphviz is much cleaner
267 al
, bl
, ol
, ea
, eb
, eo
= [],[],[],[],[],[]
269 # partition points are "breaks" (extra zeros or 1s) in what would
270 # otherwise be a massive long add. when the "break" points are 0,
271 # whatever is in it (in the output) is discarded. however when
272 # there is a "1", it causes a roll-over carry to the *next* bit.
273 # we still ignore the "break" bit in the [intermediate] output,
274 # however by that time we've got the effect that we wanted: the
275 # carry has been carried *over* the break point.
277 for i
in range(self
.width
):
278 pi
= i
/self
.pmul
# double the range of the partition point test
279 if pi
.is_integer() and pi
in self
.partition_points
:
280 # add extra bit set to 0 + 0 for enabled partition points
281 # and 1 + 0 for disabled partition points
282 ea
.append(expanded_a
[expanded_index
])
283 al
.append(~self
.partition_points
[pi
]) # add extra bit in a
284 eb
.append(expanded_b
[expanded_index
])
285 bl
.append(C(0)) # yes, add a zero
286 expanded_index
+= 1 # skip the extra point. NOT in the output
287 ea
.append(expanded_a
[expanded_index
])
288 eb
.append(expanded_b
[expanded_index
])
289 eo
.append(expanded_o
[expanded_index
])
292 ol
.append(self
.output
[i
])
295 # combine above using Cat
296 m
.d
.comb
+= Cat(*ea
).eq(Cat(*al
))
297 m
.d
.comb
+= Cat(*eb
).eq(Cat(*bl
))
298 m
.d
.comb
+= Cat(*ol
).eq(Cat(*eo
))
300 # use only one addition to take advantage of look-ahead carry and
301 # special hardware on FPGAs
302 m
.d
.comb
+= expanded_o
.eq(expanded_a
+ expanded_b
)
306 FULL_ADDER_INPUT_COUNT
= 3
class AddReduceData:
    """Bundle of signals passed between add-reduce stages.

    Holds the partial-product terms to be summed, plus the partition
    points and per-partition operation codes that are routed alongside
    the data through the reduction pipeline.
    """

    def __init__(self, part_pts, n_inputs, output_width, n_parts):
        self.part_ops = [Signal(2, name=f"part_ops_{i}", reset_less=True)
                         for i in range(n_parts)]
        self.terms = [Signal(output_width, name=f"inputs_{i}",
                             reset_less=True)
                      for i in range(n_inputs)]
        self.part_pts = part_pts.like()

    def eq_from(self, part_pts, inputs, part_ops):
        """Return assignments copying from separate sources."""
        return [self.part_pts.eq(part_pts)] + \
               [self.terms[i].eq(inputs[i])
                for i in range(len(self.terms))] + \
               [self.part_ops[i].eq(part_ops[i])
                for i in range(len(self.part_ops))]

    def eq(self, rhs):
        """Return assignments copying from another AddReduceData."""
        return self.eq_from(rhs.part_pts, rhs.terms, rhs.part_ops)
class FinalReduceData:
    """Output bundle of the final add-reduce stage.

    Carries the single reduced sum, plus the partition points and
    per-partition operation codes routed alongside it.
    """

    def __init__(self, part_pts, output_width, n_parts):
        self.part_ops = [Signal(2, name=f"part_ops_{i}", reset_less=True)
                         for i in range(n_parts)]
        self.output = Signal(output_width, reset_less=True)
        self.part_pts = part_pts.like()

    def eq_from(self, part_pts, output, part_ops):
        """Return assignments copying from separate sources."""
        return [self.part_pts.eq(part_pts)] + \
               [self.output.eq(output)] + \
               [self.part_ops[i].eq(part_ops[i])
                for i in range(len(self.part_ops))]

    def eq(self, rhs):
        """Return assignments copying from another FinalReduceData."""
        return self.eq_from(rhs.part_pts, rhs.output, rhs.part_ops)
347 class FinalAdd(Elaboratable
):
348 """ Final stage of add reduce
351 def __init__(self
, lidx
, n_inputs
, output_width
, n_parts
, partition_points
,
354 self
.partition_step
= partition_step
355 self
.output_width
= output_width
356 self
.n_inputs
= n_inputs
357 self
.n_parts
= n_parts
358 self
.partition_points
= PartitionPoints(partition_points
)
359 if not self
.partition_points
.fits_in_width(output_width
):
360 raise ValueError("partition_points doesn't fit in output_width")
362 self
.i
= self
.ispec()
363 self
.o
= self
.ospec()
366 return AddReduceData(self
.partition_points
, self
.n_inputs
,
367 self
.output_width
, self
.n_parts
)
370 return FinalReduceData(self
.partition_points
,
371 self
.output_width
, self
.n_parts
)
373 def setup(self
, m
, i
):
374 m
.submodules
.finaladd
= self
375 m
.d
.comb
+= self
.i
.eq(i
)
377 def process(self
, i
):
380 def elaborate(self
, platform
):
381 """Elaborate this module."""
384 output_width
= self
.output_width
385 output
= Signal(output_width
, reset_less
=True)
386 if self
.n_inputs
== 0:
387 # use 0 as the default output value
388 m
.d
.comb
+= output
.eq(0)
389 elif self
.n_inputs
== 1:
390 # handle single input
391 m
.d
.comb
+= output
.eq(self
.i
.terms
[0])
393 # base case for adding 2 inputs
394 assert self
.n_inputs
== 2
395 adder
= PartitionedAdder(output_width
,
396 self
.i
.part_pts
, self
.partition_step
)
397 m
.submodules
.final_adder
= adder
398 m
.d
.comb
+= adder
.a
.eq(self
.i
.terms
[0])
399 m
.d
.comb
+= adder
.b
.eq(self
.i
.terms
[1])
400 m
.d
.comb
+= output
.eq(adder
.output
)
403 m
.d
.comb
+= self
.o
.eq_from(self
.i
.part_pts
, output
,
409 class AddReduceSingle(Elaboratable
):
410 """Add list of numbers together.
412 :attribute inputs: input ``Signal``s to be summed. Modification not
413 supported, except for by ``Signal.eq``.
414 :attribute register_levels: List of nesting levels that should have
416 :attribute output: output sum.
417 :attribute partition_points: the input partition points. Modification not
418 supported, except for by ``Signal.eq``.
421 def __init__(self
, lidx
, n_inputs
, output_width
, n_parts
, partition_points
,
423 """Create an ``AddReduce``.
425 :param inputs: input ``Signal``s to be summed.
426 :param output_width: bit-width of ``output``.
427 :param partition_points: the input partition points.
430 self
.partition_step
= partition_step
431 self
.n_inputs
= n_inputs
432 self
.n_parts
= n_parts
433 self
.output_width
= output_width
434 self
.partition_points
= PartitionPoints(partition_points
)
435 if not self
.partition_points
.fits_in_width(output_width
):
436 raise ValueError("partition_points doesn't fit in output_width")
438 self
.groups
= AddReduceSingle
.full_adder_groups(n_inputs
)
439 self
.n_terms
= AddReduceSingle
.calc_n_inputs(n_inputs
, self
.groups
)
441 self
.i
= self
.ispec()
442 self
.o
= self
.ospec()
445 return AddReduceData(self
.partition_points
, self
.n_inputs
,
446 self
.output_width
, self
.n_parts
)
449 return AddReduceData(self
.partition_points
, self
.n_terms
,
450 self
.output_width
, self
.n_parts
)
452 def setup(self
, m
, i
):
453 setattr(m
.submodules
, "addreduce_%d" % self
.lidx
, self
)
454 m
.d
.comb
+= self
.i
.eq(i
)
456 def process(self
, i
):
460 def calc_n_inputs(n_inputs
, groups
):
461 retval
= len(groups
)*2
462 if n_inputs
% FULL_ADDER_INPUT_COUNT
== 1:
464 elif n_inputs
% FULL_ADDER_INPUT_COUNT
== 2:
467 assert n_inputs
% FULL_ADDER_INPUT_COUNT
== 0
471 def get_max_level(input_count
):
472 """Get the maximum level.
474 All ``register_levels`` must be less than or equal to the maximum
479 groups
= AddReduceSingle
.full_adder_groups(input_count
)
482 input_count
%= FULL_ADDER_INPUT_COUNT
483 input_count
+= 2 * len(groups
)
487 def full_adder_groups(input_count
):
488 """Get ``inputs`` indices for which a full adder should be built."""
490 input_count
- FULL_ADDER_INPUT_COUNT
+ 1,
491 FULL_ADDER_INPUT_COUNT
)
493 def create_next_terms(self
):
494 """ create next intermediate terms, for linking up in elaborate, below
499 # create full adders for this recursive level.
500 # this shrinks N terms to 2 * (N // 3) plus the remainder
501 for i
in self
.groups
:
502 adder_i
= MaskedFullAdder(self
.output_width
)
503 adders
.append((i
, adder_i
))
504 # add both the sum and the masked-carry to the next level.
505 # 3 inputs have now been reduced to 2...
506 terms
.append(adder_i
.sum)
507 terms
.append(adder_i
.mcarry
)
508 # handle the remaining inputs.
509 if self
.n_inputs
% FULL_ADDER_INPUT_COUNT
== 1:
510 terms
.append(self
.i
.terms
[-1])
511 elif self
.n_inputs
% FULL_ADDER_INPUT_COUNT
== 2:
512 # Just pass the terms to the next layer, since we wouldn't gain
513 # anything by using a half adder since there would still be 2 terms
514 # and just passing the terms to the next layer saves gates.
515 terms
.append(self
.i
.terms
[-2])
516 terms
.append(self
.i
.terms
[-1])
518 assert self
.n_inputs
% FULL_ADDER_INPUT_COUNT
== 0
522 def elaborate(self
, platform
):
523 """Elaborate this module."""
526 terms
, adders
= self
.create_next_terms()
528 # copy the intermediate terms to the output
529 for i
, value
in enumerate(terms
):
530 m
.d
.comb
+= self
.o
.terms
[i
].eq(value
)
532 # copy reg part points and part ops to output
533 m
.d
.comb
+= self
.o
.part_pts
.eq(self
.i
.part_pts
)
534 m
.d
.comb
+= [self
.o
.part_ops
[i
].eq(self
.i
.part_ops
[i
])
535 for i
in range(len(self
.i
.part_ops
))]
537 # set up the partition mask (for the adders)
538 part_mask
= Signal(self
.output_width
, reset_less
=True)
540 # get partition points as a mask
541 mask
= self
.i
.part_pts
.as_mask(self
.output_width
,
542 mul
=self
.partition_step
)
543 m
.d
.comb
+= part_mask
.eq(mask
)
545 # add and link the intermediate term modules
546 for i
, (iidx
, adder_i
) in enumerate(adders
):
547 setattr(m
.submodules
, f
"adder_{i}", adder_i
)
549 m
.d
.comb
+= adder_i
.in0
.eq(self
.i
.terms
[iidx
])
550 m
.d
.comb
+= adder_i
.in1
.eq(self
.i
.terms
[iidx
+ 1])
551 m
.d
.comb
+= adder_i
.in2
.eq(self
.i
.terms
[iidx
+ 2])
552 m
.d
.comb
+= adder_i
.mask
.eq(part_mask
)
557 class AddReduceInternal
:
558 """Recursively Add list of numbers together.
560 :attribute inputs: input ``Signal``s to be summed. Modification not
561 supported, except for by ``Signal.eq``.
562 :attribute register_levels: List of nesting levels that should have
564 :attribute output: output sum.
565 :attribute partition_points: the input partition points. Modification not
566 supported, except for by ``Signal.eq``.
569 def __init__(self
, i
, output_width
, partition_step
=1):
570 """Create an ``AddReduce``.
572 :param inputs: input ``Signal``s to be summed.
573 :param output_width: bit-width of ``output``.
574 :param partition_points: the input partition points.
577 self
.inputs
= i
.terms
578 self
.part_ops
= i
.part_ops
579 self
.output_width
= output_width
580 self
.partition_points
= i
.part_pts
581 self
.partition_step
= partition_step
585 def create_levels(self
):
586 """creates reduction levels"""
589 partition_points
= self
.partition_points
590 part_ops
= self
.part_ops
591 n_parts
= len(part_ops
)
595 groups
= AddReduceSingle
.full_adder_groups(len(inputs
))
599 next_level
= AddReduceSingle(lidx
, ilen
, self
.output_width
, n_parts
,
602 mods
.append(next_level
)
603 partition_points
= next_level
.i
.part_pts
604 inputs
= next_level
.o
.terms
606 part_ops
= next_level
.i
.part_ops
609 next_level
= FinalAdd(lidx
, ilen
, self
.output_width
, n_parts
,
610 partition_points
, self
.partition_step
)
611 mods
.append(next_level
)
616 class AddReduce(AddReduceInternal
, Elaboratable
):
617 """Recursively Add list of numbers together.
619 :attribute inputs: input ``Signal``s to be summed. Modification not
620 supported, except for by ``Signal.eq``.
621 :attribute register_levels: List of nesting levels that should have
623 :attribute output: output sum.
624 :attribute partition_points: the input partition points. Modification not
625 supported, except for by ``Signal.eq``.
628 def __init__(self
, inputs
, output_width
, register_levels
, part_pts
,
629 part_ops
, partition_step
=1):
630 """Create an ``AddReduce``.
632 :param inputs: input ``Signal``s to be summed.
633 :param output_width: bit-width of ``output``.
634 :param register_levels: List of nesting levels that should have
636 :param partition_points: the input partition points.
638 self
._inputs
= inputs
639 self
._part
_pts
= part_pts
640 self
._part
_ops
= part_ops
641 n_parts
= len(part_ops
)
642 self
.i
= AddReduceData(part_pts
, len(inputs
),
643 output_width
, n_parts
)
644 AddReduceInternal
.__init
__(self
, self
.i
, output_width
, partition_step
)
645 self
.o
= FinalReduceData(part_pts
, output_width
, n_parts
)
646 self
.register_levels
= register_levels
649 def get_max_level(input_count
):
650 return AddReduceSingle
.get_max_level(input_count
)
653 def next_register_levels(register_levels
):
654 """``Iterable`` of ``register_levels`` for next recursive level."""
655 for level
in register_levels
:
659 def elaborate(self
, platform
):
660 """Elaborate this module."""
663 m
.d
.comb
+= self
.i
.eq_from(self
._part
_pts
, self
._inputs
, self
._part
_ops
)
665 for i
, next_level
in enumerate(self
.levels
):
666 setattr(m
.submodules
, "next_level%d" % i
, next_level
)
669 for idx
in range(len(self
.levels
)):
670 mcur
= self
.levels
[idx
]
671 if idx
in self
.register_levels
:
672 m
.d
.sync
+= mcur
.i
.eq(i
)
674 m
.d
.comb
+= mcur
.i
.eq(i
)
675 i
= mcur
.o
# for next loop
677 # output comes from last module
678 m
.d
.comb
+= self
.o
.eq(i
)
684 OP_MUL_SIGNED_HIGH
= 1
685 OP_MUL_SIGNED_UNSIGNED_HIGH
= 2 # a is signed, b is unsigned
686 OP_MUL_UNSIGNED_HIGH
= 3
def get_term(value, shift=0, enabled=None):
    """Qualify and position a partial-product term.

    :param value: the term value
    :param shift: number of zero bits to prepend (i.e. a left shift)
    :param enabled: optional qualifier; when deasserted the term is zero
    :return: the (optionally gated and shifted) term
    """
    if enabled is not None:
        value = Mux(enabled, value, 0)
    if shift > 0:
        value = Cat(Repl(C(0, 1), shift), value)
    return value
699 class ProductTerm(Elaboratable
):
700 """ this class creates a single product term (a[..]*b[..]).
701 it has a design flaw in that is the *output* that is selected,
702 where the multiplication(s) are combinatorially generated
706 def __init__(self
, width
, twidth
, pbwid
, a_index
, b_index
):
707 self
.a_index
= a_index
708 self
.b_index
= b_index
709 shift
= 8 * (self
.a_index
+ self
.b_index
)
715 self
.ti
= Signal(self
.width
, reset_less
=True)
716 self
.term
= Signal(twidth
, reset_less
=True)
717 self
.a
= Signal(twidth
//2, reset_less
=True)
718 self
.b
= Signal(twidth
//2, reset_less
=True)
719 self
.pb_en
= Signal(pbwid
, reset_less
=True)
722 min_index
= min(self
.a_index
, self
.b_index
)
723 max_index
= max(self
.a_index
, self
.b_index
)
724 for i
in range(min_index
, max_index
):
725 tl
.append(self
.pb_en
[i
])
726 name
= "te_%d_%d" % (self
.a_index
, self
.b_index
)
728 term_enabled
= Signal(name
=name
, reset_less
=True)
731 self
.enabled
= term_enabled
732 self
.term
.name
= "term_%d_%d" % (a_index
, b_index
) # rename
734 def elaborate(self
, platform
):
737 if self
.enabled
is not None:
738 m
.d
.comb
+= self
.enabled
.eq(~
(Cat(*self
.tl
).bool()))
740 bsa
= Signal(self
.width
, reset_less
=True)
741 bsb
= Signal(self
.width
, reset_less
=True)
742 a_index
, b_index
= self
.a_index
, self
.b_index
744 m
.d
.comb
+= bsa
.eq(self
.a
.bit_select(a_index
* pwidth
, pwidth
))
745 m
.d
.comb
+= bsb
.eq(self
.b
.bit_select(b_index
* pwidth
, pwidth
))
746 m
.d
.comb
+= self
.ti
.eq(bsa
* bsb
)
747 m
.d
.comb
+= self
.term
.eq(get_term(self
.ti
, self
.shift
, self
.enabled
))
749 #TODO: sort out width issues, get inputs a/b switched on/off.
750 #data going into Muxes is 1/2 the required width
754 bsa = Signal(self.twidth//2, reset_less=True)
755 bsb = Signal(self.twidth//2, reset_less=True)
756 asel = Signal(width, reset_less=True)
757 bsel = Signal(width, reset_less=True)
758 a_index, b_index = self.a_index, self.b_index
759 m.d.comb += asel.eq(self.a.bit_select(a_index * pwidth, pwidth))
760 m.d.comb += bsel.eq(self.b.bit_select(b_index * pwidth, pwidth))
761 m.d.comb += bsa.eq(get_term(asel, self.shift, self.enabled))
762 m.d.comb += bsb.eq(get_term(bsel, self.shift, self.enabled))
763 m.d.comb += self.ti.eq(bsa * bsb)
764 m.d.comb += self.term.eq(self.ti)
770 class ProductTerms(Elaboratable
):
771 """ creates a bank of product terms. also performs the actual bit-selection
772 this class is to be wrapped with a for-loop on the "a" operand.
773 it creates a second-level for-loop on the "b" operand.
775 def __init__(self
, width
, twidth
, pbwid
, a_index
, blen
):
776 self
.a_index
= a_index
781 self
.a
= Signal(twidth
//2, reset_less
=True)
782 self
.b
= Signal(twidth
//2, reset_less
=True)
783 self
.pb_en
= Signal(pbwid
, reset_less
=True)
784 self
.terms
= [Signal(twidth
, name
="term%d"%i, reset_less
=True) \
785 for i
in range(blen
)]
787 def elaborate(self
, platform
):
791 for b_index
in range(self
.blen
):
792 t
= ProductTerm(self
.pwidth
, self
.twidth
, self
.pbwid
,
793 self
.a_index
, b_index
)
794 setattr(m
.submodules
, "term_%d" % b_index
, t
)
796 m
.d
.comb
+= t
.a
.eq(self
.a
)
797 m
.d
.comb
+= t
.b
.eq(self
.b
)
798 m
.d
.comb
+= t
.pb_en
.eq(self
.pb_en
)
800 m
.d
.comb
+= self
.terms
[b_index
].eq(t
.term
)
class LSBNegTerm(Elaboratable):
    """Generates the 1's-complement and +1 terms for one signed partition.

    :attribute part: partition-enable qualifier for this lane
    :attribute signed: whether a signed operation was requested
    :attribute op: the operand slice for this partition
    :attribute msb: the operand's top bit (its sign, when signed)
    :attribute nt: width-extended 1's-complement ("not") term output
    :attribute nl: the "+1" LSB correction term output
    """

    def __init__(self, bit_width):
        self.bit_width = bit_width
        self.part = Signal(reset_less=True)
        self.signed = Signal(reset_less=True)
        self.op = Signal(bit_width, reset_less=True)
        self.msb = Signal(reset_less=True)
        self.nt = Signal(bit_width*2, reset_less=True)
        self.nl = Signal(bit_width*2, reset_less=True)

    def elaborate(self, platform):
        m = Module()
        comb = m.d.comb
        bit_wid = self.bit_width
        ext = Repl(0, bit_wid)  # extend output to HI part

        # determine sign of each incoming number *in this partition*
        enabled = Signal(reset_less=True)
        m.d.comb += enabled.eq(self.part & self.msb & self.signed)

        # for 8-bit values: form a * 0xFF00 by using -a * 0x100, the
        # negation operation is split into a bitwise not and a +1.
        # likewise for 16, 32, and 64-bit values.

        # width-extended 1s complement if a is signed, otherwise zero
        comb += self.nt.eq(Mux(enabled, Cat(ext, ~self.op), 0))

        # add 1 if signed, otherwise add zero
        comb += self.nl.eq(Cat(ext, enabled, Repl(0, bit_wid-1)))

        return m
839 class Parts(Elaboratable
):
841 def __init__(self
, pbwid
, part_pts
, n_parts
):
844 self
.part_pts
= PartitionPoints
.like(part_pts
)
846 self
.parts
= [Signal(name
=f
"part_{i}", reset_less
=True)
847 for i
in range(n_parts
)]
849 def elaborate(self
, platform
):
852 part_pts
, parts
= self
.part_pts
, self
.parts
853 # collect part-bytes (double factor because the input is extended)
854 pbs
= Signal(self
.pbwid
, reset_less
=True)
856 for i
in range(self
.pbwid
):
857 pb
= Signal(name
="pb%d" % i
, reset_less
=True)
858 m
.d
.comb
+= pb
.eq(part_pts
.part_byte(i
))
860 m
.d
.comb
+= pbs
.eq(Cat(*tl
))
862 # negated-temporary copy of partition bits
863 npbs
= Signal
.like(pbs
, reset_less
=True)
864 m
.d
.comb
+= npbs
.eq(~pbs
)
865 byte_count
= 8 // len(parts
)
866 for i
in range(len(parts
)):
868 pbl
.append(npbs
[i
* byte_count
- 1])
869 for j
in range(i
* byte_count
, (i
+ 1) * byte_count
- 1):
871 pbl
.append(npbs
[(i
+ 1) * byte_count
- 1])
872 value
= Signal(len(pbl
), name
="value_%d" % i
, reset_less
=True)
873 m
.d
.comb
+= value
.eq(Cat(*pbl
))
874 m
.d
.comb
+= parts
[i
].eq(~
(value
).bool())
879 class Part(Elaboratable
):
880 """ a key class which, depending on the partitioning, will determine
881 what action to take when parts of the output are signed or unsigned.
883 this requires 2 pieces of data *per operand, per partition*:
884 whether the MSB is HI/LO (per partition!), and whether a signed
885 or unsigned operation has been *requested*.
887 once that is determined, signed is basically carried out
888 by splitting 2's complement into 1's complement plus one.
889 1's complement is just a bit-inversion.
891 the extra terms - as separate terms - are then thrown at the
892 AddReduce alongside the multiplication part-results.
894 def __init__(self
, part_pts
, width
, n_parts
, pbwid
):
897 self
.part_pts
= part_pts
900 self
.a
= Signal(64, reset_less
=True)
901 self
.b
= Signal(64, reset_less
=True)
902 self
.a_signed
= [Signal(name
=f
"a_signed_{i}", reset_less
=True)
904 self
.b_signed
= [Signal(name
=f
"_b_signed_{i}", reset_less
=True)
906 self
.pbs
= Signal(pbwid
, reset_less
=True)
909 self
.parts
= [Signal(name
=f
"part_{i}", reset_less
=True)
910 for i
in range(n_parts
)]
912 self
.not_a_term
= Signal(width
, reset_less
=True)
913 self
.neg_lsb_a_term
= Signal(width
, reset_less
=True)
914 self
.not_b_term
= Signal(width
, reset_less
=True)
915 self
.neg_lsb_b_term
= Signal(width
, reset_less
=True)
917 def elaborate(self
, platform
):
920 pbs
, parts
= self
.pbs
, self
.parts
921 part_pts
= self
.part_pts
922 m
.submodules
.p
= p
= Parts(self
.pbwid
, part_pts
, len(parts
))
923 m
.d
.comb
+= p
.part_pts
.eq(part_pts
)
926 byte_count
= 8 // len(parts
)
928 not_a_term
, neg_lsb_a_term
, not_b_term
, neg_lsb_b_term
= (
929 self
.not_a_term
, self
.neg_lsb_a_term
,
930 self
.not_b_term
, self
.neg_lsb_b_term
)
932 byte_width
= 8 // len(parts
) # byte width
933 bit_wid
= 8 * byte_width
# bit width
934 nat
, nbt
, nla
, nlb
= [], [], [], []
935 for i
in range(len(parts
)):
936 # work out bit-inverted and +1 term for a.
937 pa
= LSBNegTerm(bit_wid
)
938 setattr(m
.submodules
, "lnt_%d_a_%d" % (bit_wid
, i
), pa
)
939 m
.d
.comb
+= pa
.part
.eq(parts
[i
])
940 m
.d
.comb
+= pa
.op
.eq(self
.a
.bit_select(bit_wid
* i
, bit_wid
))
941 m
.d
.comb
+= pa
.signed
.eq(self
.b_signed
[i
* byte_width
]) # yes b
942 m
.d
.comb
+= pa
.msb
.eq(self
.b
[(i
+ 1) * bit_wid
- 1]) # really, b
946 # work out bit-inverted and +1 term for b
947 pb
= LSBNegTerm(bit_wid
)
948 setattr(m
.submodules
, "lnt_%d_b_%d" % (bit_wid
, i
), pb
)
949 m
.d
.comb
+= pb
.part
.eq(parts
[i
])
950 m
.d
.comb
+= pb
.op
.eq(self
.b
.bit_select(bit_wid
* i
, bit_wid
))
951 m
.d
.comb
+= pb
.signed
.eq(self
.a_signed
[i
* byte_width
]) # yes a
952 m
.d
.comb
+= pb
.msb
.eq(self
.a
[(i
+ 1) * bit_wid
- 1]) # really, a
956 # concatenate together and return all 4 results.
957 m
.d
.comb
+= [not_a_term
.eq(Cat(*nat
)),
958 not_b_term
.eq(Cat(*nbt
)),
959 neg_lsb_a_term
.eq(Cat(*nla
)),
960 neg_lsb_b_term
.eq(Cat(*nlb
)),
966 class IntermediateOut(Elaboratable
):
967 """ selects the HI/LO part of the multiplication, for a given bit-width
968 the output is also reconstructed in its SIMD (partition) lanes.
970 def __init__(self
, width
, out_wid
, n_parts
):
972 self
.n_parts
= n_parts
973 self
.part_ops
= [Signal(2, name
="dpop%d" % i
, reset_less
=True)
975 self
.intermed
= Signal(out_wid
, reset_less
=True)
976 self
.output
= Signal(out_wid
//2, reset_less
=True)
978 def elaborate(self
, platform
):
984 for i
in range(self
.n_parts
):
985 op
= Signal(w
, reset_less
=True, name
="op%d_%d" % (w
, i
))
987 Mux(self
.part_ops
[sel
* i
] == OP_MUL_LOW
,
988 self
.intermed
.bit_select(i
* w
*2, w
),
989 self
.intermed
.bit_select(i
* w
*2 + w
, w
)))
991 m
.d
.comb
+= self
.output
.eq(Cat(*ol
))
996 class FinalOut(Elaboratable
):
997 """ selects the final output based on the partitioning.
999 each byte is selectable independently, i.e. it is possible
1000 that some partitions requested 8-bit computation whilst others
1001 requested 16 or 32 bit.
1003 def __init__(self
, output_width
, n_parts
, part_pts
):
1004 self
.part_pts
= part_pts
1005 self
.output_width
= output_width
1006 self
.n_parts
= n_parts
1007 self
.out_wid
= output_width
//2
1009 self
.i
= self
.ispec()
1010 self
.o
= self
.ospec()
1013 return IntermediateData(self
.part_pts
, self
.output_width
, self
.n_parts
)
1018 def setup(self
, m
, i
):
1019 m
.submodules
.finalout
= self
1020 m
.d
.comb
+= self
.i
.eq(i
)
1022 def process(self
, i
):
def elaborate(self, platform):
    """Select, byte-by-byte, which partition-size result goes to the output.

    FinalOut.elaborate.  NOTE(review): reconstructed from a garbled chunk;
    only forced scaffolding lines (``m = Module()``, ``ol = []``, the
    ``for i in range(8)`` header, ``op.eq(`` and ``ol.append``) were not
    literally visible — verify against version control.
    """
    m = Module()

    part_pts = self.part_pts
    # partition detectors, one per granularity (8/16/32/64-bit)
    m.submodules.p_8 = p_8 = Parts(8, part_pts, 8)
    m.submodules.p_16 = p_16 = Parts(8, part_pts, 4)
    m.submodules.p_32 = p_32 = Parts(8, part_pts, 2)
    m.submodules.p_64 = p_64 = Parts(8, part_pts, 1)

    out_part_pts = self.i.part_pts

    # per-granularity "this byte belongs to an N-bit partition" flags
    d8 = [Signal(name=f"d8_{i}", reset_less=True) for i in range(8)]
    d16 = [Signal(name=f"d16_{i}", reset_less=True) for i in range(4)]
    d32 = [Signal(name=f"d32_{i}", reset_less=True) for i in range(2)]

    # candidate results, one per partition size
    i8 = Signal(self.out_wid, reset_less=True)
    i16 = Signal(self.out_wid, reset_less=True)
    i32 = Signal(self.out_wid, reset_less=True)
    i64 = Signal(self.out_wid, reset_less=True)

    m.d.comb += p_8.part_pts.eq(out_part_pts)
    m.d.comb += p_16.part_pts.eq(out_part_pts)
    m.d.comb += p_32.part_pts.eq(out_part_pts)
    m.d.comb += p_64.part_pts.eq(out_part_pts)

    for i in range(len(p_8.parts)):
        m.d.comb += d8[i].eq(p_8.parts[i])
    for i in range(len(p_16.parts)):
        m.d.comb += d16[i].eq(p_16.parts[i])
    for i in range(len(p_32.parts)):
        m.d.comb += d32[i].eq(p_32.parts[i])
    m.d.comb += i8.eq(self.i.outputs[0])
    m.d.comb += i16.eq(self.i.outputs[1])
    m.d.comb += i32.eq(self.i.outputs[2])
    m.d.comb += i64.eq(self.i.outputs[3])

    ol = []
    for i in range(8):
        # select one of the outputs: d8 selects i8, d16 selects i16
        # d32 selects i32, and the default is i64.
        # d8 and d16 are ORed together in the first Mux
        # then the 2nd selects either i8 or i16.
        # if neither d8 nor d16 are set, d32 selects either i32 or i64.
        op = Signal(8, reset_less=True, name="op_%d" % i)
        m.d.comb += op.eq(
            Mux(d8[i] | d16[i // 2],
                Mux(d8[i], i8.bit_select(i * 8, 8),
                           i16.bit_select(i * 8, 8)),
                Mux(d32[i // 4], i32.bit_select(i * 8, 8),
                                 i64.bit_select(i * 8, 8))))
        ol.append(op)

    # concatenate the per-byte selections into the final output
    m.d.comb += self.o.output.eq(Cat(*ol))
    m.d.comb += self.o.intermediate_output.eq(self.i.intermediate_output)

    return m
class OrMod(Elaboratable):
    """ ORs four values together in a hierarchical tree

    NOTE(review): reconstructed from a garbled extraction; the
    ``self.wid = wid`` line, the comprehension tail and ``return m``
    were missing from the chunk — all are forced by the visible uses.

    :attribute orin: the four input signals
    :attribute orout: OR of all four inputs
    """
    def __init__(self, wid):
        # wid: bit-width of each input and of the result
        self.wid = wid
        self.orin = [Signal(wid, name="orin%d" % i, reset_less=True)
                     for i in range(4)]
        self.orout = Signal(wid, reset_less=True)

    def elaborate(self, platform):
        m = Module()
        # first tree level: pair-wise ORs
        or1 = Signal(self.wid, reset_less=True)
        or2 = Signal(self.wid, reset_less=True)
        m.d.comb += or1.eq(self.orin[0] | self.orin[1])
        m.d.comb += or2.eq(self.orin[2] | self.orin[3])
        # second level: combine the two partial results
        m.d.comb += self.orout.eq(or1 | or2)
        return m
class Signs(Elaboratable):
    """ determines whether a or b are signed numbers
        based on the required operation type (OP_MUL_*)

    NOTE(review): reconstructed from a garbled extraction; ``__init__``
    header, ``m = Module()`` and ``return m`` were missing — verify
    against version control.
    """

    def __init__(self):
        self.part_ops = Signal(2, reset_less=True)  # operation code in
        self.a_signed = Signal(reset_less=True)     # "a is signed" out
        self.b_signed = Signal(reset_less=True)     # "b is signed" out

    def elaborate(self, platform):
        m = Module()

        # a is signed for every op except unsigned-high (mulhu-style)
        asig = self.part_ops != OP_MUL_UNSIGNED_HIGH
        # b is signed only for mul-low and signed-high (mul / mulh)
        bsig = (self.part_ops == OP_MUL_LOW) \
            | (self.part_ops == OP_MUL_SIGNED_HIGH)
        m.d.comb += self.a_signed.eq(asig)
        m.d.comb += self.b_signed.eq(bsig)

        return m
class IntermediateData:
    """Data record carried between the intermediate and final stages.

    NOTE(review): reconstructed from a garbled extraction; the
    ``for i in range(4)]`` comprehension tail and the ``def eq`` header
    were missing from the chunk — confirm against version control.
    """

    def __init__(self, part_pts, output_width, n_parts):
        # one 2-bit op code per part
        self.part_ops = [Signal(2, name=f"part_ops_{i}", reset_less=True)
                         for i in range(n_parts)]
        self.part_pts = part_pts.like()
        # one candidate output per partition granularity (8/16/32/64)
        self.outputs = [Signal(output_width, name="io%d" % i,
                               reset_less=True)
                        for i in range(4)]
        # intermediates (needed for unit tests)
        self.intermediate_output = Signal(output_width)

    def eq_from(self, part_pts, outputs, intermediate_output,
                part_ops):
        """Return assignments setting this record from individual fields."""
        return [self.part_pts.eq(part_pts)] + \
               [self.intermediate_output.eq(intermediate_output)] + \
               [self.outputs[i].eq(outputs[i])
                for i in range(4)] + \
               [self.part_ops[i].eq(part_ops[i])
                for i in range(len(self.part_ops))]

    def eq(self, rhs):
        """Return assignments copying every field from *rhs*."""
        return self.eq_from(rhs.part_pts, rhs.outputs,
                            rhs.intermediate_output, rhs.part_ops)
# NOTE(review): garbled tail of an input-record constructor; the class
# header and ``def __init__`` lines are missing from this chunk —
# recover from version control.
# Builds a PartitionPoints with a Signal-enabled point every 8 bits in
# 0 < i < 64, plus one 2-bit operation code per byte (8 of them).
1158 self
.part_pts
= PartitionPoints()
1159 for i
in range(8, 64, 8):
1160 self
.part_pts
[i
] = Signal(name
=f
"part_pts_{i}")
1161 self
.part_ops
= [Signal(2, name
=f
"part_ops_{i}") for i
in range(8)]
def eq_from(self, part_pts, a, b, part_ops):
    """Return assignments setting this input record from individual fields.

    NOTE(review): formatting reconstructed from a garbled extraction;
    all tokens were visible in the original.

    :param part_pts: partition-point values to assign.
    :param a, b: operand values to assign.
    :param part_ops: per-byte operation codes to assign.
    :return: list of nmigen assignment statements.
    """
    return [self.part_pts.eq(part_pts)] + \
           [self.a.eq(a), self.b.eq(b)] + \
           [self.part_ops[i].eq(part_ops[i])
            for i in range(len(self.part_ops))]
# NOTE(review): body of the input record's ``eq`` method; its ``def``
# header (original line ~1169) is missing from this chunk — recover
# from version control.  Copies every field from ``rhs`` via eq_from.
1170 return self
.eq_from(rhs
.part_pts
, rhs
.a
, rhs
.b
, rhs
.part_ops
)
# NOTE(review): garbled fragments of an output-record class; the class
# header, ``def __init__`` and the ``def eq`` header (original lines
# ~1178-1179) are missing from this chunk — recover from version control.
# Constructor tail: 128-bit intermediate result plus 64-bit final output.
1176 self
.intermediate_output
= Signal(128) # needed for unit tests
1177 self
.output
= Signal(64)
# eq body: assignments copying both fields from ``rhs``.
1180 return [self
.intermediate_output
.eq(rhs
.intermediate_output
),
1181 self
.output
.eq(rhs
.output
)]
class AllTerms(PipeModBase):
    """Set of terms to be added together

    NOTE(review): reconstructed from a garbled extraction.  Original
    lines ~1195-1198 are absent from this chunk (an ``ispec`` method
    presumably lived there) — recover them from version control.
    """

    def __init__(self, pspec):
        """Create an ``AllTerms``.
        """
        self.n_inputs = pspec.n_inputs
        self.n_parts = pspec.n_parts
        self.output_width = pspec.width
        super().__init__(pspec, "allterms")

    def ospec(self):
        return AddReduceData(self.i.part_pts, self.n_inputs,
                             self.output_width, self.n_parts)

    def elaborate(self, platform):
        m = Module()

        eps = self.i.part_pts

        # collect part-bytes
        pbs = Signal(8, reset_less=True)
        tl = []
        for i in range(8):
            pb = Signal(name="pb%d" % i, reset_less=True)
            m.d.comb += pb.eq(eps.part_byte(i))
            tl.append(pb)
        m.d.comb += pbs.eq(Cat(*tl))

        # one Signs decoder per byte operation code
        signs = []
        for i in range(8):
            s = Signs()
            signs.append(s)
            setattr(m.submodules, "signs%d" % i, s)
            m.d.comb += s.part_ops.eq(self.i.part_ops[i])

        # partial-product generators, one per partition granularity
        m.submodules.part_8 = part_8 = Part(eps, 128, 8, 8)
        m.submodules.part_16 = part_16 = Part(eps, 128, 4, 8)
        m.submodules.part_32 = part_32 = Part(eps, 128, 2, 8)
        m.submodules.part_64 = part_64 = Part(eps, 128, 1, 8)
        nat_l, nbt_l, nla_l, nlb_l = [], [], [], []
        for mod in [part_8, part_16, part_32, part_64]:
            m.d.comb += mod.a.eq(self.i.a)
            m.d.comb += mod.b.eq(self.i.b)
            for i in range(len(signs)):
                m.d.comb += mod.a_signed[i].eq(signs[i].a_signed)
                m.d.comb += mod.b_signed[i].eq(signs[i].b_signed)
            m.d.comb += mod.pbs.eq(pbs)
            nat_l.append(mod.not_a_term)
            nbt_l.append(mod.not_b_term)
            nla_l.append(mod.neg_lsb_a_term)
            nlb_l.append(mod.neg_lsb_b_term)

        # 8x8 byte-by-byte product terms
        terms = []
        for a_index in range(8):
            t = ProductTerms(8, 128, 8, a_index, 8)
            setattr(m.submodules, "terms_%d" % a_index, t)

            m.d.comb += t.a.eq(self.i.a)
            m.d.comb += t.b.eq(self.i.b)
            m.d.comb += t.pb_en.eq(pbs)

            for term in t.terms:
                terms.append(term)

        # it's fine to bitwise-or data together since they are never enabled
        # at the same time
        m.submodules.nat_or = nat_or = OrMod(128)
        m.submodules.nbt_or = nbt_or = OrMod(128)
        m.submodules.nla_or = nla_or = OrMod(128)
        m.submodules.nlb_or = nlb_or = OrMod(128)
        for l, mod in [(nat_l, nat_or),
                       (nbt_l, nbt_or),
                       (nla_l, nla_or),
                       (nlb_l, nlb_or)]:
            for i in range(len(l)):
                m.d.comb += mod.orin[i].eq(l[i])
            terms.append(mod.orout)

        # copy the intermediate terms to the output
        for i, value in enumerate(terms):
            m.d.comb += self.o.terms[i].eq(value)

        # copy reg part points and part ops to output
        m.d.comb += self.o.part_pts.eq(eps)
        m.d.comb += [self.o.part_ops[i].eq(self.i.part_ops[i])
                     for i in range(len(self.i.part_ops))]

        return m
class Intermediates(Elaboratable):
    """ Intermediate output modules

    NOTE(review): reconstructed from a garbled extraction; the method
    ``def`` headers for ispec/ospec/process, the ``for i in range(8)``
    loop headers and ``return`` lines were missing from this chunk —
    verify against version control.
    """

    def __init__(self, output_width, n_parts, part_pts):
        self.part_pts = part_pts
        self.output_width = output_width
        self.n_parts = n_parts

        self.i = self.ispec()
        self.o = self.ospec()

    def ispec(self):
        return FinalReduceData(self.part_pts, self.output_width,
                               self.n_parts)

    def ospec(self):
        return IntermediateData(self.part_pts, self.output_width,
                                self.n_parts)

    def setup(self, m, i):
        """Attach to Module *m* and connect the input record."""
        m.submodules.intermediates = self
        m.d.comb += self.i.eq(i)

    def process(self, i):
        # NOTE(review): body reconstructed (standard pattern in this file)
        return self.o

    def elaborate(self, platform):
        m = Module()

        out_part_ops = self.i.part_ops
        out_part_pts = self.i.part_pts

        # 64-bit partition result
        m.submodules.io64 = io64 = IntermediateOut(64, 128, 1)
        m.d.comb += io64.intermed.eq(self.i.output)
        for i in range(8):
            m.d.comb += io64.part_ops[i].eq(out_part_ops[i])
        m.d.comb += self.o.outputs[3].eq(io64.output)

        # 32-bit partition result
        m.submodules.io32 = io32 = IntermediateOut(32, 128, 2)
        m.d.comb += io32.intermed.eq(self.i.output)
        for i in range(8):
            m.d.comb += io32.part_ops[i].eq(out_part_ops[i])
        m.d.comb += self.o.outputs[2].eq(io32.output)

        # 16-bit partition result
        m.submodules.io16 = io16 = IntermediateOut(16, 128, 4)
        m.d.comb += io16.intermed.eq(self.i.output)
        for i in range(8):
            m.d.comb += io16.part_ops[i].eq(out_part_ops[i])
        m.d.comb += self.o.outputs[1].eq(io16.output)

        # 8-bit partition result
        m.submodules.io8 = io8 = IntermediateOut(8, 128, 8)
        m.d.comb += io8.intermed.eq(self.i.output)
        for i in range(8):
            m.d.comb += io8.part_ops[i].eq(out_part_ops[i])
        m.d.comb += self.o.outputs[0].eq(io8.output)

        # pass-through: part ops, part points, raw intermediate result
        for i in range(8):
            m.d.comb += self.o.part_ops[i].eq(out_part_ops[i])
        m.d.comb += self.o.part_pts.eq(out_part_pts)
        m.d.comb += self.o.intermediate_output.eq(self.i.output)

        return m
class Mul8_16_32_64(Elaboratable):
    """Signed/Unsigned 8/16/32/64-bit partitioned integer multiplier.

    Supports partitioning into any combination of 8, 16, 32, and 64-bit
    partitions on naturally-aligned boundaries. Supports the operation being
    set for each partition independently.

    :attribute part_pts: the input partition points. Has a partition point at
        multiples of 8 in 0 < i < 64. Each partition point's associated
        ``Value`` is a ``Signal``. Modification not supported, except for by
        ``Signal.eq``.
    :attribute part_ops: the operation for each byte. The operation for a
        particular partition is selected by assigning the selected operation
        code to each byte in the partition. The allowed operation codes are:

        :attribute OP_MUL_LOW: the LSB half of the product. Equivalent to
            RISC-V's `mul` instruction.
        :attribute OP_MUL_SIGNED_HIGH: the MSB half of the product where both
            ``a`` and ``b`` are signed. Equivalent to RISC-V's `mulh`
            instruction.
        :attribute OP_MUL_SIGNED_UNSIGNED_HIGH: the MSB half of the product
            where ``a`` is signed and ``b`` is unsigned. Equivalent to RISC-V's
            `mulhsu` instruction.
        :attribute OP_MUL_UNSIGNED_HIGH: the MSB half of the product where both
            ``a`` and ``b`` are unsigned. Equivalent to RISC-V's `mulhu`
            instruction.

    NOTE(review): reconstructed from a garbled extraction with missing
    lines (notably ``self.op_wid``, ``ispec``/``ospec`` and parts of
    ``elaborate``) — verify every reconstructed line against VCS.
    """

    def __init__(self, register_levels=()):
        """ register_levels: specifies the points in the cascade at which
            flip-flops are to be inserted.
        """

        self.id_wid = 0  # num_bits(num_rows)
        self.op_wid = 0  # NOTE(review): reconstructed — pspec reads it

        self.pspec = PipelineSpec(128, self.id_wid, self.op_wid, n_ops=3)
        self.pspec.n_inputs = 64 + 4
        self.pspec.n_parts = 8

        # parameter(s)
        self.register_levels = list(register_levels)

        self.i = self.ispec()
        self.o = self.ospec()

        # convenience aliases for the input record's fields
        self.part_pts = self.i.part_pts
        self.part_ops = self.i.part_ops

        # convenience aliases for the output record's fields
        self.intermediate_output = self.o.intermediate_output
        self.output = self.o.output

    # NOTE(review): ispec/ospec were entirely absent from this chunk;
    # reconstructed from the record classes defined above — TODO confirm.
    def ispec(self):
        return InputData()

    def ospec(self):
        return OutputData()

    def elaborate(self, platform):
        m = Module()

        part_pts = self.part_pts

        n_parts = self.pspec.n_parts
        n_inputs = self.pspec.n_inputs
        output_width = self.pspec.width
        t = AllTerms(self.pspec)
        t.setup(m, self.i)

        # add-reduce cascade over all partial-product terms
        at = AddReduceInternal(t.process(self.i), 128, partition_step=2)

        i = at.i
        for idx in range(len(at.levels)):
            mcur = at.levels[idx]
            mcur.setup(m, i)
            o = mcur.ospec()
            # insert a pipeline register at requested cascade levels
            if idx in self.register_levels:
                m.d.sync += o.eq(mcur.process(i))
            else:
                m.d.comb += o.eq(mcur.process(i))
            i = o  # for next loop

        interm = Intermediates(128, 8, part_pts)
        interm.setup(m, i)
        o = interm.process(interm.i)

        # final output
        finalout = FinalOut(128, 8, part_pts)
        finalout.setup(m, o)
        m.d.comb += self.o.eq(finalout.process(o))

        return m
if __name__ == "__main__":
    # generate Verilog / run nmigen CLI for the full multiplier.
    # NOTE(review): port-list lines were partially missing from this
    # garbled chunk; reconstructed — verify against version control.
    m = Mul8_16_32_64()
    main(m, ports=[m.a,
                   m.b,
                   m.intermediate_output,
                   m.output,
                   *m.part_ops,
                   *m.part_pts.values()])