1 # SPDX-License-Identifier: LGPL-2.1-or-later
2 # See Notices.txt for copyright information
3 """Integer Multiplication."""
5 from nmigen
import Signal
, Module
, Value
, Elaboratable
, Cat
, C
, Mux
, Repl
6 from nmigen
.hdl
.ast
import Assign
7 from abc
import ABCMeta
, abstractmethod
8 from nmigen
.cli
import main
9 from functools
import reduce
10 from operator
import or_
class PartitionPoints(dict):
    """Partition points and corresponding ``Value``s.

    The points at where an ALU is partitioned along with ``Value``s that
    specify if the corresponding partition points are enabled.

    For example: ``{1: True, 5: True, 10: True}`` with
    ``width == 16`` specifies that the ALU is split into 4 sections:

    * bits 0 <= ``i`` < 1
    * bits 1 <= ``i`` < 5
    * bits 5 <= ``i`` < 10
    * bits 10 <= ``i`` < 16

    If the partition_points were instead ``{1: True, 5: a, 10: True}``
    where ``a`` is a 1-bit ``Signal``:

    * If ``a`` is asserted:

        * bits 0 <= ``i`` < 1
        * bits 1 <= ``i`` < 5
        * bits 5 <= ``i`` < 10
        * bits 10 <= ``i`` < 16

    * Otherwise

        * bits 0 <= ``i`` < 1
        * bits 1 <= ``i`` < 10
        * bits 10 <= ``i`` < 16
    """

    def __init__(self, partition_points=None):
        """Create a new ``PartitionPoints``.

        :param partition_points: the input partition points to values mapping.
        :raises TypeError: if a key is not an ``int``
        :raises ValueError: if a key is negative
        """
        super().__init__()
        if partition_points is not None:
            for point, enabled in partition_points.items():
                if not isinstance(point, int):
                    raise TypeError("point must be a non-negative integer")
                if point < 0:
                    raise ValueError("point must be a non-negative integer")
                self[point] = Value.wrap(enabled)

    def like(self, name=None, src_loc_at=0):
        """Create a new ``PartitionPoints`` with ``Signal``s for all values.

        :param name: the base name for the new ``Signal``s.
        :param src_loc_at: stack-frame offset used to derive a default name.
        """
        if name is None:
            name = Signal(src_loc_at=1+src_loc_at).name  # get variable name
        retval = PartitionPoints()
        for point, enabled in self.items():
            retval[point] = Signal(enabled.shape(), name=f"{name}_{point}")
        return retval

    def eq(self, rhs):
        """Assign ``PartitionPoints`` using ``Signal.eq``."""
        if set(self.keys()) != set(rhs.keys()):
            raise ValueError("incompatible point set")
        for point, enabled in self.items():
            yield enabled.eq(rhs[point])

    def as_mask(self, width):
        """Create a bit-mask from `self`.

        Each bit in the returned mask is clear only if the partition point at
        the same bit-index is enabled.

        :param width: the bit width of the resulting mask
        """
        bits = []
        for i in range(width):
            if i in self:
                bits.append(~self[i])
            else:
                bits.append(True)
        return Cat(*bits)

    def get_max_partition_count(self, width):
        """Get the maximum number of partitions.

        Gets the number of partitions when all partition points are enabled.
        """
        # one partition, plus one extra per enabled point below `width`
        retval = 1
        for point in self.keys():
            if point < width:
                retval += 1
        return retval

    def fits_in_width(self, width):
        """Check if all partition points are smaller than `width`."""
        for point in self.keys():
            if point >= width:
                return False
        return True
class FullAdder(Elaboratable):
    """Full Adder.

    :attribute in0: the first input
    :attribute in1: the second input
    :attribute in2: the third input
    :attribute sum: the sum output
    :attribute carry: the carry output

    Rather than do individual full adders (and have an array of them,
    which would be very slow to simulate), this module can specify the
    bit width of the inputs and outputs: in effect it performs multiple
    Full 3-2 Add operations "in parallel".
    """

    def __init__(self, width):
        """Create a ``FullAdder``.

        :param width: the bit width of the input and output
        """
        self.in0 = Signal(width)
        self.in1 = Signal(width)
        self.in2 = Signal(width)
        self.sum = Signal(width)
        self.carry = Signal(width)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        # per-bit sum of three inputs is their XOR
        m.d.comb += self.sum.eq(self.in0 ^ self.in1 ^ self.in2)
        # per-bit carry is the majority function of the three inputs
        m.d.comb += self.carry.eq((self.in0 & self.in1)
                                  | (self.in1 & self.in2)
                                  | (self.in2 & self.in0))
        return m
class MaskedFullAdder(FullAdder):
    """Masked Full Adder.

    :attribute mask: the carry partition mask
    :attribute in0: the first input
    :attribute in1: the second input
    :attribute in2: the third input
    :attribute sum: the sum output
    :attribute mcarry: the masked carry output

    FullAdders are always used with a "mask" on the output.  To keep
    the graphviz "clean", this class performs the masking here rather
    than inside a large for-loop.
    """

    def __init__(self, width):
        """Create a ``MaskedFullAdder``.

        :param width: the bit width of the input and output
        """
        FullAdder.__init__(self, width)
        self.mask = Signal(width)
        self.mcarry = Signal(width)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = FullAdder.elaborate(self, platform)
        # carry is shifted up one bit (it applies to the next column),
        # then gated by the partition mask so carries never cross
        # partition boundaries
        m.d.comb += self.mcarry.eq((self.carry << 1) & self.mask)
        return m
class PartitionedAdder(Elaboratable):
    """Partitioned Adder.

    Performs the final add.  The partition points are included in the
    actual add (in one of the operands only), which causes a carry over
    to the next bit.  Then the final output *removes* the extra bits from
    the intermediate result.

    :attribute width: the bit width of the input and output. Read-only.
    :attribute a: the first input to the adder
    :attribute b: the second input to the adder
    :attribute output: the sum output
    :attribute partition_points: the input partition points. Modification not
        supported, except for by ``Signal.eq``.
    """

    def __init__(self, width, partition_points):
        """Create a ``PartitionedAdder``.

        :param width: the bit width of the input and output
        :param partition_points: the input partition points
        """
        self.width = width
        self.a = Signal(width)
        self.b = Signal(width)
        self.output = Signal(width)
        self.partition_points = PartitionPoints(partition_points)
        if not self.partition_points.fits_in_width(width):
            raise ValueError("partition_points doesn't fit in width")
        # one extra bit per partition point, to absorb the roll-over carry
        expanded_width = 0
        for i in range(self.width):
            if i in self.partition_points:
                expanded_width += 1
            expanded_width += 1
        self._expanded_width = expanded_width
        # XXX these have to remain here due to some horrible nmigen
        # simulation bugs involving sync. it is *not* necessary to
        # have them here, they should (under normal circumstances)
        # be moved into elaborate, as they are entirely local
        self._expanded_a = Signal(expanded_width)
        self._expanded_b = Signal(expanded_width)
        self._expanded_output = Signal(expanded_width)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        expanded_index = 0
        # store bits in a list, use Cat later.  graphviz is much cleaner
        al, bl, ol, ea, eb, eo = [], [], [], [], [], []

        # partition points are "breaks" (extra zeros or 1s) in what would
        # otherwise be a massive long add.  when the "break" points are 0,
        # whatever is in it (in the output) is discarded.  however when
        # there is a "1", it causes a roll-over carry to the *next* bit.
        # we still ignore the "break" bit in the [intermediate] output,
        # however by that time we've got the effect that we wanted: the
        # carry has been carried *over* the break point.

        for i in range(self.width):
            if i in self.partition_points:
                # add extra bit set to 0 + 0 for enabled partition points
                # and 1 + 0 for disabled partition points
                ea.append(self._expanded_a[expanded_index])
                al.append(~self.partition_points[i])  # add extra bit in a
                eb.append(self._expanded_b[expanded_index])
                bl.append(C(0))  # do *not* add extra bit into b.
                expanded_index += 1
            ea.append(self._expanded_a[expanded_index])
            al.append(self.a[i])
            eb.append(self._expanded_b[expanded_index])
            bl.append(self.b[i])
            eo.append(self._expanded_output[expanded_index])
            ol.append(self.output[i])
            expanded_index += 1

        # combine above using Cat
        m.d.comb += Cat(*ea).eq(Cat(*al))
        m.d.comb += Cat(*eb).eq(Cat(*bl))
        m.d.comb += Cat(*ol).eq(Cat(*eo))

        # use only one addition to take advantage of look-ahead carry and
        # special hardware on FPGAs
        m.d.comb += self._expanded_output.eq(
            self._expanded_a + self._expanded_b)
        return m
# a full adder (3:2 compressor) reduces exactly three operands at a time
FULL_ADDER_INPUT_COUNT = 3
class AddReduce(Elaboratable):
    """Add list of numbers together.

    :attribute inputs: input ``Signal``s to be summed. Modification not
        supported, except for by ``Signal.eq``.
    :attribute register_levels: List of nesting levels that should have
        pipeline registers.
    :attribute output: output sum.
    :attribute partition_points: the input partition points. Modification not
        supported, except for by ``Signal.eq``.
    """

    def __init__(self, inputs, output_width, register_levels,
                 partition_points):
        """Create an ``AddReduce``.

        :param inputs: input ``Signal``s to be summed.
        :param output_width: bit-width of ``output``.
        :param register_levels: List of nesting levels that should have
            pipeline registers.
        :param partition_points: the input partition points.
        """
        self.inputs = list(inputs)
        self._resized_inputs = [
            Signal(output_width, name=f"resized_inputs[{i}]")
            for i in range(len(self.inputs))]
        self.register_levels = list(register_levels)
        self.output = Signal(output_width)
        self.partition_points = PartitionPoints(partition_points)
        if not self.partition_points.fits_in_width(output_width):
            raise ValueError("partition_points doesn't fit in output_width")
        self._reg_partition_points = self.partition_points.like()
        max_level = AddReduce.get_max_level(len(self.inputs))
        for level in self.register_levels:
            if level > max_level:
                raise ValueError(
                    "not enough adder levels for specified register levels")

    @staticmethod
    def get_max_level(input_count):
        """Get the maximum level.

        All ``register_levels`` must be less than or equal to the maximum
        level.
        """
        retval = 0
        while True:
            groups = AddReduce.full_adder_groups(input_count)
            if len(groups) == 0:
                return retval
            # each group of 3 inputs is reduced to 2 terms (sum + carry);
            # the remainder is passed through unchanged
            input_count %= FULL_ADDER_INPUT_COUNT
            input_count += 2 * len(groups)
            retval += 1

    def next_register_levels(self):
        """``Iterable`` of ``register_levels`` for next recursive level."""
        for level in self.register_levels:
            if level > 0:
                yield level - 1

    @staticmethod
    def full_adder_groups(input_count):
        """Get ``inputs`` indices for which a full adder should be built."""
        return range(0,
                     input_count - FULL_ADDER_INPUT_COUNT + 1,
                     FULL_ADDER_INPUT_COUNT)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()

        # resize inputs to correct bit-width and optionally add in
        # pipeline registers
        resized_input_assignments = [self._resized_inputs[i].eq(self.inputs[i])
                                     for i in range(len(self.inputs))]
        if 0 in self.register_levels:
            m.d.sync += resized_input_assignments
            m.d.sync += self._reg_partition_points.eq(self.partition_points)
        else:
            m.d.comb += resized_input_assignments
            m.d.comb += self._reg_partition_points.eq(self.partition_points)

        groups = AddReduce.full_adder_groups(len(self.inputs))
        # if there are no full adders to create, then we handle the base cases
        # and return, otherwise we go on to the recursive case
        if len(groups) == 0:
            if len(self.inputs) == 0:
                # use 0 as the default output value
                m.d.comb += self.output.eq(0)
            elif len(self.inputs) == 1:
                # handle single input
                m.d.comb += self.output.eq(self._resized_inputs[0])
            else:
                # base case for adding 2 or more inputs, which get recursively
                # reduced to 2 inputs
                assert len(self.inputs) == 2
                adder = PartitionedAdder(len(self.output),
                                         self._reg_partition_points)
                m.submodules.final_adder = adder
                m.d.comb += adder.a.eq(self._resized_inputs[0])
                m.d.comb += adder.b.eq(self._resized_inputs[1])
                m.d.comb += self.output.eq(adder.output)
            return m

        # go on to handle recursive case
        intermediate_terms = []

        def add_intermediate_term(value):
            # named intermediary keeps the graphviz readable
            intermediate_term = Signal(
                len(self.output),
                name=f"intermediate_terms[{len(intermediate_terms)}]")
            intermediate_terms.append(intermediate_term)
            m.d.comb += intermediate_term.eq(value)

        # store mask in intermediary (simplifies graph)
        part_mask = Signal(len(self.output), reset_less=True)
        mask = self._reg_partition_points.as_mask(len(self.output))
        m.d.comb += part_mask.eq(mask)

        # create full adders for this recursive level.
        # this shrinks N terms to 2 * (N // 3) plus the remainder
        for i in groups:
            adder_i = MaskedFullAdder(len(self.output))
            setattr(m.submodules, f"adder_{i}", adder_i)
            m.d.comb += adder_i.in0.eq(self._resized_inputs[i])
            m.d.comb += adder_i.in1.eq(self._resized_inputs[i + 1])
            m.d.comb += adder_i.in2.eq(self._resized_inputs[i + 2])
            m.d.comb += adder_i.mask.eq(part_mask)
            add_intermediate_term(adder_i.sum)
            # mask out carry bits to prevent carries between partitions
            add_intermediate_term(adder_i.mcarry)
        # handle the remaining inputs.
        if len(self.inputs) % FULL_ADDER_INPUT_COUNT == 1:
            add_intermediate_term(self._resized_inputs[-1])
        elif len(self.inputs) % FULL_ADDER_INPUT_COUNT == 2:
            # Just pass the terms to the next layer, since we wouldn't gain
            # anything by using a half adder since there would still be 2 terms
            # and just passing the terms to the next layer saves gates.
            add_intermediate_term(self._resized_inputs[-2])
            add_intermediate_term(self._resized_inputs[-1])
        else:
            assert len(self.inputs) % FULL_ADDER_INPUT_COUNT == 0
        # recursive invocation of ``AddReduce``
        next_level = AddReduce(intermediate_terms,
                               len(self.output),
                               self.next_register_levels(),
                               self._reg_partition_points)
        m.submodules.next_level = next_level
        m.d.comb += self.output.eq(next_level.output)
        return m
# operation codes for each partition (see Mul8_16_32_64.part_ops)
OP_MUL_LOW = 0                   # LSB half of the product
OP_MUL_SIGNED_HIGH = 1           # MSB half, both operands signed
OP_MUL_SIGNED_UNSIGNED_HIGH = 2  # a is signed, b is unsigned
OP_MUL_UNSIGNED_HIGH = 3         # MSB half, both operands unsigned
def get_term(value, shift=0, enabled=None):
    """Shift (and optionally gate) a partial-product term.

    :param value: the term to be included in the summation
    :param shift: number of zero LSBs to prepend (i.e. multiply by 2**shift)
    :param enabled: optional qualifier; when de-asserted the term is zeroed
    """
    if enabled is not None:
        value = Mux(enabled, value, 0)
    if shift > 0:
        value = Cat(Repl(C(0, 1), shift), value)
    return value
class ProductTerm(Elaboratable):
    """ this class creates a single product term (a[..]*b[..]).
        it has a design flaw in that is the *output* that is selected,
        where the multiplication(s) are combinatorially generated
        all the time.
    """

    def __init__(self, width, twidth, pbwid, a_index, b_index):
        """Create a ``ProductTerm``.

        :param width: bit width of one partial operand (one byte-group)
        :param twidth: total bit width of the term output
        :param pbwid: number of partition-enable bits
        :param a_index: byte-group index into operand a
        :param b_index: byte-group index into operand b
        """
        self.a_index = a_index
        self.b_index = b_index
        # each byte-group is 8 bits wide, so the product lands at this offset
        shift = 8 * (self.a_index + self.b_index)
        self.pwidth = width
        self.twidth = twidth
        self.width = width * 2
        self.shift = shift

        self.ti = Signal(self.width, reset_less=True)
        self.term = Signal(twidth, reset_less=True)
        self.a = Signal(twidth//2, reset_less=True)
        self.b = Signal(twidth//2, reset_less=True)
        self.pb_en = Signal(pbwid, reset_less=True)

        # partition-enable bits spanned by this term: if any is set, the
        # term crosses a partition break and must be disabled
        self.tl = tl = []
        min_index = min(self.a_index, self.b_index)
        max_index = max(self.a_index, self.b_index)
        for i in range(min_index, max_index):
            tl.append(self.pb_en[i])
        name = "te_%d_%d" % (self.a_index, self.b_index)
        if len(tl) > 0:
            term_enabled = Signal(name=name, reset_less=True)
        else:
            term_enabled = None
        self.enabled = term_enabled
        self.term.name = "term_%d_%d" % (a_index, b_index)  # rename

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        if self.enabled is not None:
            # enabled only when no spanned partition bit is set
            m.d.comb += self.enabled.eq(~(Cat(*self.tl).bool()))

        bsa = Signal(self.width, reset_less=True)
        bsb = Signal(self.width, reset_less=True)
        a_index, b_index = self.a_index, self.b_index
        pwidth = self.pwidth
        m.d.comb += bsa.eq(self.a.bit_select(a_index * pwidth, pwidth))
        m.d.comb += bsb.eq(self.b.bit_select(b_index * pwidth, pwidth))
        m.d.comb += self.ti.eq(bsa * bsb)
        m.d.comb += self.term.eq(get_term(self.ti, self.shift, self.enabled))
        # TODO: sort out width issues, get inputs a/b switched on/off.
        # data going into Muxes is 1/2 the required width
        # (dead experimental alternative, kept for reference):
        #    bsa = Signal(self.twidth//2, reset_less=True)
        #    bsb = Signal(self.twidth//2, reset_less=True)
        #    asel = Signal(width, reset_less=True)
        #    bsel = Signal(width, reset_less=True)
        #    a_index, b_index = self.a_index, self.b_index
        #    m.d.comb += asel.eq(self.a.bit_select(a_index * pwidth, pwidth))
        #    m.d.comb += bsel.eq(self.b.bit_select(b_index * pwidth, pwidth))
        #    m.d.comb += bsa.eq(get_term(asel, self.shift, self.enabled))
        #    m.d.comb += bsb.eq(get_term(bsel, self.shift, self.enabled))
        #    m.d.comb += self.ti.eq(bsa * bsb)
        #    m.d.comb += self.term.eq(self.ti)
        return m
class ProductTerms(Elaboratable):
    """ creates a bank of product terms.  also performs the actual
        bit-selection.  this class is to be wrapped with a for-loop on the
        "a" operand.  it creates a second-level for-loop on the "b" operand.
    """

    def __init__(self, width, twidth, pbwid, a_index, blen):
        """Create a ``ProductTerms`` bank.

        :param width: bit width of one partial operand
        :param twidth: total bit width of each term output
        :param pbwid: number of partition-enable bits
        :param a_index: fixed byte-group index into operand a
        :param blen: number of b byte-groups (inner-loop length)
        """
        self.a_index = a_index
        self.blen = blen
        self.pwidth = width
        self.twidth = twidth
        self.pbwid = pbwid

        self.a = Signal(twidth//2, reset_less=True)
        self.b = Signal(twidth//2, reset_less=True)
        self.pb_en = Signal(pbwid, reset_less=True)
        self.terms = [Signal(twidth, name="term%d" % i, reset_less=True)
                      for i in range(blen)]

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()

        for b_index in range(self.blen):
            t = ProductTerm(self.pwidth, self.twidth, self.pbwid,
                            self.a_index, b_index)
            setattr(m.submodules, "term_%d" % b_index, t)

            m.d.comb += t.a.eq(self.a)
            m.d.comb += t.b.eq(self.b)
            m.d.comb += t.pb_en.eq(self.pb_en)

            m.d.comb += self.terms[b_index].eq(t.term)

        return m
class LSBNegTerm(Elaboratable):
    """ computes the extra terms needed for signed partitions: a
        width-extended one's-complement term (``nt``) and a +1 LSB
        term (``nl``), which together form the two's-complement
        correction added into the final sum.
    """

    def __init__(self, bit_width):
        self.bit_width = bit_width
        self.part = Signal(reset_less=True)      # partition-select bit
        self.signed = Signal(reset_less=True)    # operation requests signed
        self.op = Signal(bit_width, reset_less=True)   # the operand slice
        self.msb = Signal(reset_less=True)       # MSB of the *other* operand
        self.nt = Signal(bit_width*2, reset_less=True)  # 1s-complement term
        self.nl = Signal(bit_width*2, reset_less=True)  # +1 LSB term

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        comb = m.d.comb
        bit_wid = self.bit_width
        ext = Repl(0, bit_wid)  # extend output to HI part

        # determine sign of each incoming number *in this partition*
        enabled = Signal(reset_less=True)
        comb += enabled.eq(self.part & self.msb & self.signed)

        # for 8-bit values: form a * 0xFF00 by using -a * 0x100, the
        # negation operation is split into a bitwise not and a +1.
        # likewise for 16, 32, and 64-bit values.

        # width-extended 1s complement if a is signed, otherwise zero
        comb += self.nt.eq(Mux(enabled, Cat(ext, ~self.op), 0))

        # add 1 if signed, otherwise add zero
        comb += self.nl.eq(Cat(ext, enabled, Repl(0, bit_wid-1)))

        return m
class Part(Elaboratable):
    """ a key class which, depending on the partitioning, will determine
        what action to take when parts of the output are signed or unsigned.

        this requires 2 pieces of data *per operand, per partition*:
        whether the MSB is HI/LO (per partition!), and whether a signed
        or unsigned operation has been *requested*.

        once that is determined, signed is basically carried out
        by splitting 2's complement into 1's complement plus one.
        1's complement is just a bit-inversion.

        the extra terms - as separate terms - are then thrown at the
        AddReduce alongside the multiplication part-results.
    """

    def __init__(self, width, n_parts, n_levels, pbwid):
        self.pbwid = pbwid
        # inputs
        self.a = Signal(64)
        self.b = Signal(64)
        self.a_signed = [Signal(name=f"a_signed_{i}") for i in range(8)]
        self.b_signed = [Signal(name=f"_b_signed_{i}") for i in range(8)]
        self.pbs = Signal(pbwid, reset_less=True)

        # outputs
        self.parts = [Signal(name=f"part_{i}") for i in range(n_parts)]
        self.delayed_parts = [
            [Signal(name=f"delayed_part_{delay}_{i}")
             for i in range(n_parts)]
            for delay in range(n_levels)]
        # XXX REALLY WEIRD BUG - have to take a copy of the last delayed_parts
        self.dplast = [Signal(name=f"dplast_{i}")
                       for i in range(n_parts)]

        self.not_a_term = Signal(width)
        self.neg_lsb_a_term = Signal(width)
        self.not_b_term = Signal(width)
        self.neg_lsb_b_term = Signal(width)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()

        pbs, parts, delayed_parts = self.pbs, self.parts, self.delayed_parts
        # negated-temporary copy of partition bits
        npbs = Signal.like(pbs, reset_less=True)
        m.d.comb += npbs.eq(~pbs)
        byte_count = 8 // len(parts)
        for i in range(len(parts)):
            # a partition is "active" when its boundary bits are set and
            # no internal break bit is set
            pbl = []
            pbl.append(npbs[i * byte_count - 1])
            for j in range(i * byte_count, (i + 1) * byte_count - 1):
                pbl.append(pbs[j])
            pbl.append(npbs[(i + 1) * byte_count - 1])
            value = Signal(len(pbl), name="value_%di" % i, reset_less=True)
            m.d.comb += value.eq(Cat(*pbl))
            m.d.comb += parts[i].eq(~(value).bool())
            # pipeline the partition-select bits alongside the adder levels
            m.d.comb += delayed_parts[0][i].eq(parts[i])
            m.d.sync += [delayed_parts[j + 1][i].eq(delayed_parts[j][i])
                         for j in range(len(delayed_parts)-1)]
            m.d.comb += self.dplast[i].eq(delayed_parts[-1][i])

        not_a_term, neg_lsb_a_term, not_b_term, neg_lsb_b_term = \
            self.not_a_term, self.neg_lsb_a_term, \
            self.not_b_term, self.neg_lsb_b_term

        byte_width = 8 // len(parts)  # byte width
        bit_wid = 8 * byte_width      # bit width
        nat, nbt, nla, nlb = [], [], [], []
        for i in range(len(parts)):
            # work out bit-inverted and +1 term for a.
            pa = LSBNegTerm(bit_wid)
            setattr(m.submodules, "lnt_%d_a_%d" % (bit_wid, i), pa)
            m.d.comb += pa.part.eq(parts[i])
            m.d.comb += pa.op.eq(self.a.bit_select(bit_wid * i, bit_wid))
            m.d.comb += pa.signed.eq(self.b_signed[i * byte_width])  # yes b
            m.d.comb += pa.msb.eq(self.b[(i + 1) * bit_wid - 1])  # really, b
            nat.append(pa.nt)
            nla.append(pa.nl)

            # work out bit-inverted and +1 term for b
            pb = LSBNegTerm(bit_wid)
            setattr(m.submodules, "lnt_%d_b_%d" % (bit_wid, i), pb)
            m.d.comb += pb.part.eq(parts[i])
            m.d.comb += pb.op.eq(self.b.bit_select(bit_wid * i, bit_wid))
            m.d.comb += pb.signed.eq(self.a_signed[i * byte_width])  # yes a
            m.d.comb += pb.msb.eq(self.a[(i + 1) * bit_wid - 1])  # really, a
            nbt.append(pb.nt)
            nlb.append(pb.nl)

        # concatenate together and return all 4 results.
        m.d.comb += [not_a_term.eq(Cat(*nat)),
                     not_b_term.eq(Cat(*nbt)),
                     neg_lsb_a_term.eq(Cat(*nla)),
                     neg_lsb_b_term.eq(Cat(*nlb)),
                     ]

        return m
class IntermediateOut(Elaboratable):
    """ selects the HI/LO part of the multiplication, for a given bit-width
        the output is also reconstructed in its SIMD (partition) lanes.
    """

    def __init__(self, width, out_wid, n_parts):
        self.width = width
        self.n_parts = n_parts
        self.delayed_part_ops = [Signal(2, name="dpop%d" % i, reset_less=True)
                                 for i in range(8)]
        self.intermed = Signal(out_wid, reset_less=True)
        self.output = Signal(out_wid//2, reset_less=True)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()

        ol = []
        w = self.width
        # select the part_ops entry at the start byte of each lane
        sel = w // 8
        for i in range(self.n_parts):
            op = Signal(w, reset_less=True, name="op%d_%d" % (w, i))
            m.d.comb += op.eq(
                Mux(self.delayed_part_ops[sel * i] == OP_MUL_LOW,
                    self.intermed.bit_select(i * w*2, w),
                    self.intermed.bit_select(i * w*2 + w, w)))
            ol.append(op)
        m.d.comb += self.output.eq(Cat(*ol))

        return m
class FinalOut(Elaboratable):
    """ selects the final output based on the partitioning.

        each byte is selectable independently, i.e. it is possible
        that some partitions requested 8-bit computation whilst others
        requested 16 or 32 bit.
    """

    def __init__(self, out_wid):
        # inputs: per-byte partition-width selects
        self.d8 = [Signal(name=f"d8_{i}", reset_less=True) for i in range(8)]
        self.d16 = [Signal(name=f"d16_{i}", reset_less=True) for i in range(4)]
        self.d32 = [Signal(name=f"d32_{i}", reset_less=True) for i in range(2)]

        # inputs: candidate results at each partition width
        self.i8 = Signal(out_wid, reset_less=True)
        self.i16 = Signal(out_wid, reset_less=True)
        self.i32 = Signal(out_wid, reset_less=True)
        self.i64 = Signal(out_wid, reset_less=True)

        # output
        self.out = Signal(out_wid, reset_less=True)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()

        ol = []
        for i in range(8):
            # select one of the outputs: d8 selects i8, d16 selects i16
            # d32 selects i32, and the default is i64.
            # d8 and d16 are ORed together in the first Mux
            # then the 2nd selects either i8 or i16.
            # if neither d8 nor d16 are set, d32 selects either i32 or i64.
            op = Signal(8, reset_less=True, name="op_%d" % i)
            m.d.comb += op.eq(
                Mux(self.d8[i] | self.d16[i // 2],
                    Mux(self.d8[i], self.i8.bit_select(i * 8, 8),
                        self.i16.bit_select(i * 8, 8)),
                    Mux(self.d32[i // 4], self.i32.bit_select(i * 8, 8),
                        self.i64.bit_select(i * 8, 8))))
            ol.append(op)
        m.d.comb += self.out.eq(Cat(*ol))

        return m
class OrMod(Elaboratable):
    """ ORs four values together in a hierarchical tree
    """

    def __init__(self, wid):
        self.wid = wid
        self.orin = [Signal(wid, name="orin%d" % i, reset_less=True)
                     for i in range(4)]
        self.orout = Signal(wid, reset_less=True)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        # two-level OR tree: (0|1) | (2|3)
        or1 = Signal(self.wid, reset_less=True)
        or2 = Signal(self.wid, reset_less=True)
        m.d.comb += or1.eq(self.orin[0] | self.orin[1])
        m.d.comb += or2.eq(self.orin[2] | self.orin[3])
        m.d.comb += self.orout.eq(or1 | or2)

        return m
class Signs(Elaboratable):
    """ determines whether a or b are signed numbers
        based on the required operation type (OP_MUL_*)
    """

    def __init__(self):
        self.part_ops = Signal(2, reset_less=True)
        self.a_signed = Signal(reset_less=True)
        self.b_signed = Signal(reset_less=True)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()

        # a is signed for every op except mulhu (both-unsigned high)
        asig = self.part_ops != OP_MUL_UNSIGNED_HIGH
        # b is signed only for mul (low) and mulh (both-signed high)
        bsig = (self.part_ops == OP_MUL_LOW) \
            | (self.part_ops == OP_MUL_SIGNED_HIGH)
        m.d.comb += self.a_signed.eq(asig)
        m.d.comb += self.b_signed.eq(bsig)

        return m
class Mul8_16_32_64(Elaboratable):
    """Signed/Unsigned 8/16/32/64-bit partitioned integer multiplier.

    Supports partitioning into any combination of 8, 16, 32, and 64-bit
    partitions on naturally-aligned boundaries. Supports the operation being
    set for each partition independently.

    :attribute part_pts: the input partition points. Has a partition point at
        multiples of 8 in 0 < i < 64. Each partition point's associated
        ``Value`` is a ``Signal``. Modification not supported, except for by
        ``Signal.eq``.
    :attribute part_ops: the operation for each byte. The operation for a
        particular partition is selected by assigning the selected operation
        code to each byte in the partition. The allowed operation codes are:

        :attribute OP_MUL_LOW: the LSB half of the product. Equivalent to
            RISC-V's `mul` instruction.
        :attribute OP_MUL_SIGNED_HIGH: the MSB half of the product where both
            ``a`` and ``b`` are signed. Equivalent to RISC-V's `mulh`
            instruction.
        :attribute OP_MUL_SIGNED_UNSIGNED_HIGH: the MSB half of the product
            where ``a`` is signed and ``b`` is unsigned. Equivalent to RISC-V's
            `mulhsu` instruction.
        :attribute OP_MUL_UNSIGNED_HIGH: the MSB half of the product where both
            ``a`` and ``b`` are unsigned. Equivalent to RISC-V's `mulhu`
            instruction.
    """

    def __init__(self, register_levels=()):
        """ register_levels: specifies the points in the cascade at which
            flip-flops are to be inserted.
        """
        # parameter(s)
        self.register_levels = list(register_levels)

        # inputs
        self.part_pts = PartitionPoints()
        for i in range(8, 64, 8):
            self.part_pts[i] = Signal(name=f"part_pts_{i}")
        self.part_ops = [Signal(2, name=f"part_ops_{i}") for i in range(8)]
        self.a = Signal(64)
        self.b = Signal(64)

        # intermediates (needed for unit tests)
        self._intermediate_output = Signal(128)

        # output
        self.output = Signal(64)

    def _part_byte(self, index):
        # out-of-range indices behave as permanently-enabled boundaries
        if index == -1 or index == 7:
            return C(True, 1)
        assert index >= 0 and index < 8
        return self.part_pts[index * 8 + 8]

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()

        # collect the partition-boundary bits into one 8-bit signal
        pbs = Signal(8, reset_less=True)
        tl = []
        for i in range(8):
            pb = Signal(name="pb%d" % i, reset_less=True)
            m.d.comb += pb.eq(self._part_byte(i))
            tl.append(pb)
        m.d.comb += pbs.eq(Cat(*tl))

        # decode per-byte signedness of a and b from the op codes
        signs = []
        for i in range(8):
            s = Signs()
            signs.append(s)
            setattr(m.submodules, "signs%d" % i, s)
            m.d.comb += s.part_ops.eq(self.part_ops[i])

        # delay part_ops in step with the pipeline registers
        delayed_part_ops = [
            [Signal(2, name=f"_delayed_part_ops_{delay}_{i}")
             for i in range(8)]
            for delay in range(1 + len(self.register_levels))]
        for i in range(len(self.part_ops)):
            m.d.comb += delayed_part_ops[0][i].eq(self.part_ops[i])
            m.d.sync += [delayed_part_ops[j + 1][i].eq(delayed_part_ops[j][i])
                         for j in range(len(self.register_levels))]

        n_levels = len(self.register_levels)+1
        m.submodules.part_8 = part_8 = Part(128, 8, n_levels, 8)
        m.submodules.part_16 = part_16 = Part(128, 4, n_levels, 8)
        m.submodules.part_32 = part_32 = Part(128, 2, n_levels, 8)
        m.submodules.part_64 = part_64 = Part(128, 1, n_levels, 8)
        nat_l, nbt_l, nla_l, nlb_l = [], [], [], []
        for mod in [part_8, part_16, part_32, part_64]:
            m.d.comb += mod.a.eq(self.a)
            m.d.comb += mod.b.eq(self.b)
            for i in range(len(signs)):
                m.d.comb += mod.a_signed[i].eq(signs[i].a_signed)
                m.d.comb += mod.b_signed[i].eq(signs[i].b_signed)
            m.d.comb += mod.pbs.eq(pbs)
            nat_l.append(mod.not_a_term)
            nbt_l.append(mod.not_b_term)
            nla_l.append(mod.neg_lsb_a_term)
            nlb_l.append(mod.neg_lsb_b_term)

        # create the bank of 8x8 partial-product terms
        terms = []
        for a_index in range(8):
            t = ProductTerms(8, 128, 8, a_index, 8)
            setattr(m.submodules, "terms_%d" % a_index, t)

            m.d.comb += t.a.eq(self.a)
            m.d.comb += t.b.eq(self.b)
            m.d.comb += t.pb_en.eq(pbs)

            for term in t.terms:
                terms.append(term)

        # it's fine to bitwise-or data together since they are never enabled
        # at the same time
        m.submodules.nat_or = nat_or = OrMod(128)
        m.submodules.nbt_or = nbt_or = OrMod(128)
        m.submodules.nla_or = nla_or = OrMod(128)
        m.submodules.nlb_or = nlb_or = OrMod(128)
        for l, mod in [(nat_l, nat_or),
                       (nbt_l, nbt_or),
                       (nla_l, nla_or),
                       (nlb_l, nlb_or)]:
            for i in range(len(l)):
                m.d.comb += mod.orin[i].eq(l[i])
            terms.append(mod.orout)

        # partition points at double the index, since the product is
        # twice as wide as the operands
        expanded_part_pts = PartitionPoints()
        for i, v in self.part_pts.items():
            signal = Signal(name=f"expanded_part_pts_{i*2}", reset_less=True)
            expanded_part_pts[i * 2] = signal
            m.d.comb += signal.eq(v)

        add_reduce = AddReduce(terms,
                               128,
                               self.register_levels,
                               expanded_part_pts)
        m.submodules.add_reduce = add_reduce
        m.d.comb += self._intermediate_output.eq(add_reduce.output)

        # select HI/LO halves per lane at each partition width
        m.submodules.io64 = io64 = IntermediateOut(64, 128, 1)
        m.d.comb += io64.intermed.eq(self._intermediate_output)
        for i in range(8):
            m.d.comb += io64.delayed_part_ops[i].eq(delayed_part_ops[-1][i])

        m.submodules.io32 = io32 = IntermediateOut(32, 128, 2)
        m.d.comb += io32.intermed.eq(self._intermediate_output)
        for i in range(8):
            m.d.comb += io32.delayed_part_ops[i].eq(delayed_part_ops[-1][i])

        m.submodules.io16 = io16 = IntermediateOut(16, 128, 4)
        m.d.comb += io16.intermed.eq(self._intermediate_output)
        for i in range(8):
            m.d.comb += io16.delayed_part_ops[i].eq(delayed_part_ops[-1][i])

        m.submodules.io8 = io8 = IntermediateOut(8, 128, 8)
        m.d.comb += io8.intermed.eq(self._intermediate_output)
        for i in range(8):
            m.d.comb += io8.delayed_part_ops[i].eq(delayed_part_ops[-1][i])

        # final output mux: per-byte selection of the partition width
        m.submodules.finalout = finalout = FinalOut(64)
        for i in range(len(part_8.delayed_parts[-1])):
            m.d.comb += finalout.d8[i].eq(part_8.dplast[i])
        for i in range(len(part_16.delayed_parts[-1])):
            m.d.comb += finalout.d16[i].eq(part_16.dplast[i])
        for i in range(len(part_32.delayed_parts[-1])):
            m.d.comb += finalout.d32[i].eq(part_32.dplast[i])
        m.d.comb += finalout.i8.eq(io8.output)
        m.d.comb += finalout.i16.eq(io16.output)
        m.d.comb += finalout.i32.eq(io32.output)
        m.d.comb += finalout.i64.eq(io64.output)
        m.d.comb += self.output.eq(finalout.out)

        return m
if __name__ == "__main__":
    # NOTE(review): register_levels value reconstructed — confirm against
    # the original invocation.
    m = Mul8_16_32_64(register_levels=[1, 2])
    main(m, ports=[m.a,
                   m.b,
                   m._intermediate_output,
                   m.output,
                   *m.part_ops,
                   *m.part_pts.values()])