1 # SPDX-License-Identifier: LGPL-2.1-or-later
2 # See Notices.txt for copyright information
3 """Integer Multiplication."""
from abc import ABCMeta, abstractmethod
from functools import reduce
from operator import or_

from nmigen import (Signal, Module, Value, Elaboratable, Cat, C, Mux, Repl)
from nmigen.hdl.ast import Assign
from nmigen.cli import main
class PartitionPoints(dict):
    """Partition points and corresponding ``Value``s.

    The points at where an ALU is partitioned along with ``Value``s that
    specify if the corresponding partition points are enabled.

    For example: ``{1: True, 5: True, 10: True}`` with
    ``width == 16`` specifies that the ALU is split into 4 sections:
    * bits 0 <= ``i`` < 1
    * bits 1 <= ``i`` < 5
    * bits 5 <= ``i`` < 10
    * bits 10 <= ``i`` < 16

    If the partition_points were instead ``{1: True, 5: a, 10: True}``
    where ``a`` is a 1-bit ``Signal``:
    * If ``a`` is asserted:
        * bits 0 <= ``i`` < 1
        * bits 1 <= ``i`` < 5
        * bits 5 <= ``i`` < 10
        * bits 10 <= ``i`` < 16
    * Otherwise
        * bits 0 <= ``i`` < 1
        * bits 1 <= ``i`` < 10
        * bits 10 <= ``i`` < 16
    """

    def __init__(self, partition_points=None):
        """Create a new ``PartitionPoints``.

        :param partition_points: the input partition points to values mapping.
        :raises TypeError: if a point is not an ``int``.
        :raises ValueError: if a point is negative.
        """
        super().__init__()
        if partition_points is not None:
            for point, enabled in partition_points.items():
                if not isinstance(point, int):
                    raise TypeError("point must be a non-negative integer")
                if point < 0:
                    raise ValueError("point must be a non-negative integer")
                self[point] = Value.wrap(enabled)

    def like(self, name=None, src_loc_at=0):
        """Create a new ``PartitionPoints`` with ``Signal``s for all values.

        :param name: the base name for the new ``Signal``s.
        """
        if name is None:
            # get the name of the variable this instance is assigned to
            name = Signal(src_loc_at=1+src_loc_at).name
        retval = PartitionPoints()
        for point, enabled in self.items():
            retval[point] = Signal(enabled.shape(), name=f"{name}_{point}")
        return retval

    def eq(self, rhs):
        """Assign ``PartitionPoints`` using ``Signal.eq``."""
        if set(self.keys()) != set(rhs.keys()):
            raise ValueError("incompatible point set")
        for point, enabled in self.items():
            yield enabled.eq(rhs[point])

    def as_mask(self, width):
        """Create a bit-mask from `self`.

        Each bit in the returned mask is clear only if the partition point at
        the same bit-index is enabled.

        :param width: the bit width of the resulting mask
        """
        bits = []
        for i in range(width):
            if i in self:
                # mask bit is the inverse of the partition-enable value
                bits.append(~self[i])
            else:
                # no partition point at this index: bit is always set
                bits.append(True)
        return Cat(*bits)

    def get_max_partition_count(self, width):
        """Get the maximum number of partitions.

        Gets the number of partitions when all partition points are enabled.
        """
        retval = 1
        for point in self.keys():
            if point < width:
                retval += 1
        return retval

    def fits_in_width(self, width):
        """Check if all partition points are smaller than `width`."""
        for point in self.keys():
            if point >= width:
                return False
        return True
class FullAdder(Elaboratable):
    """Full Adder.

    :attribute in0: the first input
    :attribute in1: the second input
    :attribute in2: the third input
    :attribute sum: the sum output
    :attribute carry: the carry output

    Rather than do individual full adders (and have an array of them,
    which would be very slow to simulate), this module can specify the
    bit width of the inputs and outputs: in effect it performs multiple
    Full 3-2 Add operations "in parallel".
    """

    def __init__(self, width):
        """Create a ``FullAdder``.

        :param width: the bit width of the input and output
        """
        self.in0 = Signal(width)
        self.in1 = Signal(width)
        self.in2 = Signal(width)
        self.sum = Signal(width)
        self.carry = Signal(width)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        # per-bit sum of three inputs is their XOR
        m.d.comb += self.sum.eq(self.in0 ^ self.in1 ^ self.in2)
        # per-bit carry is set when at least two of the three inputs are set
        m.d.comb += self.carry.eq((self.in0 & self.in1)
                                  | (self.in1 & self.in2)
                                  | (self.in2 & self.in0))
        return m
class MaskedFullAdder(FullAdder):
    """Masked Full Adder.

    :attribute mask: the carry partition mask
    :attribute in0: the first input
    :attribute in1: the second input
    :attribute in2: the third input
    :attribute sum: the sum output
    :attribute mcarry: the masked carry output

    FullAdders are always used with a "mask" on the output.  To keep
    the graphviz "clean", this class performs the masking here rather
    than inside a large for-loop.
    """

    def __init__(self, width):
        """Create a ``MaskedFullAdder``.

        :param width: the bit width of the input and output
        """
        FullAdder.__init__(self, width)
        self.mask = Signal(width)
        self.mcarry = Signal(width)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = FullAdder.elaborate(self, platform)
        # the carry is shifted up one bit (it contributes to the next
        # column) and then masked so it never crosses a partition point
        m.d.comb += self.mcarry.eq((self.carry << 1) & self.mask)
        return m
class PartitionedAdder(Elaboratable):
    """Partitioned Adder.

    Performs the final add.  The partition points are included in the
    actual add (in one of the operands only), which causes a carry over
    to the next bit.  Then the final output *removes* the extra bits from
    the result.

    partition: .... P... P... P... P... (32 bits)
    a        : .... .... .... .... .... (32 bits)
    b        : .... .... .... .... .... (32 bits)
    exp-a    : ....P....P....P....P.... (32+4 bits)
    exp-b    : ....0....0....0....0.... (32 bits plus 4 zeros)
    exp-o    : ....xN...xN...xN...xN... (32+4 bits)
    o        : .... N... N... N... N... (32 bits)

    :attribute width: the bit width of the input and output. Read-only.
    :attribute a: the first input to the adder
    :attribute b: the second input to the adder
    :attribute output: the sum output
    :attribute partition_points: the input partition points. Modification not
        supported, except for by ``Signal.eq``.
    """

    def __init__(self, width, partition_points):
        """Create a ``PartitionedAdder``.

        :param width: the bit width of the input and output
        :param partition_points: the input partition points
        """
        self.width = width
        self.a = Signal(width)
        self.b = Signal(width)
        self.output = Signal(width)
        self.partition_points = PartitionPoints(partition_points)
        if not self.partition_points.fits_in_width(width):
            raise ValueError("partition_points doesn't fit in width")
        # one extra bit is interposed at every partition point
        expanded_width = 0
        for i in range(self.width):
            if i in self.partition_points:
                expanded_width += 1
            expanded_width += 1
        self._expanded_width = expanded_width
        # XXX these have to remain here due to some horrible nmigen
        # simulation bugs involving sync.  it is *not* necessary to
        # have them here, they should (under normal circumstances)
        # be moved into elaborate, as they are entirely local
        self._expanded_a = Signal(expanded_width)  # includes extra part-points
        self._expanded_b = Signal(expanded_width)  # likewise.
        self._expanded_o = Signal(expanded_width)  # likewise.

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        expanded_index = 0
        # store bits in a list, use Cat later.  graphviz is much cleaner
        al, bl, ol, ea, eb, eo = [], [], [], [], [], []

        # partition points are "breaks" (extra zeros or 1s) in what would
        # otherwise be a massive long add.  when the "break" points are 0,
        # whatever is in it (in the output) is discarded.  however when
        # there is a "1", it causes a roll-over carry to the *next* bit.
        # we still ignore the "break" bit in the [intermediate] output,
        # however by that time we've got the effect that we wanted: the
        # carry has been carried *over* the break point.

        for i in range(self.width):
            if i in self.partition_points:
                # add extra bit set to 0 + 0 for enabled partition points
                # and 1 + 0 for disabled partition points
                ea.append(self._expanded_a[expanded_index])
                al.append(~self.partition_points[i])  # add extra bit in a
                eb.append(self._expanded_b[expanded_index])
                bl.append(C(0))  # yes, add a zero
                expanded_index += 1  # skip the extra point.  NOT in the output
            ea.append(self._expanded_a[expanded_index])
            al.append(self.a[i])
            eb.append(self._expanded_b[expanded_index])
            bl.append(self.b[i])
            eo.append(self._expanded_o[expanded_index])
            ol.append(self.output[i])
            expanded_index += 1

        # combine above using Cat
        m.d.comb += Cat(*ea).eq(Cat(*al))
        m.d.comb += Cat(*eb).eq(Cat(*bl))
        m.d.comb += Cat(*ol).eq(Cat(*eo))

        # use only one addition to take advantage of look-ahead carry and
        # special hardware on FPGAs
        m.d.comb += self._expanded_o.eq(
            self._expanded_a + self._expanded_b)
        return m
# number of inputs that a single FullAdder (3:2 compressor) consumes per level
FULL_ADDER_INPUT_COUNT = 3
class AddReduce(Elaboratable):
    """Add list of numbers together.

    :attribute inputs: input ``Signal``s to be summed. Modification not
        supported, except for by ``Signal.eq``.
    :attribute register_levels: List of nesting levels that should have
        pipeline registers.
    :attribute output: output sum.
    :attribute partition_points: the input partition points. Modification not
        supported, except for by ``Signal.eq``.
    """

    def __init__(self, inputs, output_width, register_levels,
                 partition_points):
        """Create an ``AddReduce``.

        :param inputs: input ``Signal``s to be summed.
        :param output_width: bit-width of ``output``.
        :param register_levels: List of nesting levels that should have
            pipeline registers.
        :param partition_points: the input partition points.
        """
        self.inputs = list(inputs)
        self._resized_inputs = [
            Signal(output_width, name=f"resized_inputs[{i}]")
            for i in range(len(self.inputs))]
        self.register_levels = list(register_levels)
        self.output = Signal(output_width)
        self.partition_points = PartitionPoints(partition_points)
        if not self.partition_points.fits_in_width(output_width):
            raise ValueError("partition_points doesn't fit in output_width")
        self._reg_partition_points = self.partition_points.like()
        max_level = AddReduce.get_max_level(len(self.inputs))
        for level in self.register_levels:
            if level > max_level:
                raise ValueError(
                    "not enough adder levels for specified register levels")

    @staticmethod
    def get_max_level(input_count):
        """Get the maximum level.

        All ``register_levels`` must be less than or equal to the maximum
        level.
        """
        retval = 0
        while True:
            groups = AddReduce.full_adder_groups(input_count)
            if len(groups) == 0:
                return retval
            input_count %= FULL_ADDER_INPUT_COUNT
            input_count += 2 * len(groups)
            retval += 1

    def next_register_levels(self):
        """``Iterable`` of ``register_levels`` for next recursive level."""
        for level in self.register_levels:
            if level > 0:
                yield level - 1

    @staticmethod
    def full_adder_groups(input_count):
        """Get ``inputs`` indices for which a full adder should be built."""
        return range(0,
                     input_count - FULL_ADDER_INPUT_COUNT + 1,
                     FULL_ADDER_INPUT_COUNT)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()

        # resize inputs to correct bit-width and optionally add in
        # pipeline registers
        resized_input_assignments = [self._resized_inputs[i].eq(self.inputs[i])
                                     for i in range(len(self.inputs))]
        if 0 in self.register_levels:
            m.d.sync += resized_input_assignments
            m.d.sync += self._reg_partition_points.eq(self.partition_points)
        else:
            m.d.comb += resized_input_assignments
            m.d.comb += self._reg_partition_points.eq(self.partition_points)

        groups = AddReduce.full_adder_groups(len(self.inputs))
        # if there are no full adders to create, then we handle the base cases
        # and return, otherwise we go on to the recursive case
        if len(groups) == 0:
            if len(self.inputs) == 0:
                # use 0 as the default output value
                m.d.comb += self.output.eq(0)
            elif len(self.inputs) == 1:
                # handle single input
                m.d.comb += self.output.eq(self._resized_inputs[0])
            else:
                # base case for adding 2 or more inputs, which get recursively
                # reduced to 2 inputs
                assert len(self.inputs) == 2
                adder = PartitionedAdder(len(self.output),
                                         self._reg_partition_points)
                m.submodules.final_adder = adder
                m.d.comb += adder.a.eq(self._resized_inputs[0])
                m.d.comb += adder.b.eq(self._resized_inputs[1])
                m.d.comb += self.output.eq(adder.output)
            return m

        # go on to handle recursive case
        intermediate_terms = []

        def add_intermediate_term(value):
            intermediate_term = Signal(
                len(self.output),
                name=f"intermediate_terms[{len(intermediate_terms)}]")
            intermediate_terms.append(intermediate_term)
            m.d.comb += intermediate_term.eq(value)

        # store mask in intermediary (simplifies graph)
        part_mask = Signal(len(self.output), reset_less=True)
        mask = self._reg_partition_points.as_mask(len(self.output))
        m.d.comb += part_mask.eq(mask)

        # create full adders for this recursive level.
        # this shrinks N terms to 2 * (N // 3) plus the remainder
        for i in groups:
            adder_i = MaskedFullAdder(len(self.output))
            setattr(m.submodules, f"adder_{i}", adder_i)
            m.d.comb += adder_i.in0.eq(self._resized_inputs[i])
            m.d.comb += adder_i.in1.eq(self._resized_inputs[i + 1])
            m.d.comb += adder_i.in2.eq(self._resized_inputs[i + 2])
            m.d.comb += adder_i.mask.eq(part_mask)
            # add both the sum and the masked-carry to the next level.
            # 3 inputs have now been reduced to 2...
            add_intermediate_term(adder_i.sum)
            add_intermediate_term(adder_i.mcarry)
        # handle the remaining inputs.
        if len(self.inputs) % FULL_ADDER_INPUT_COUNT == 1:
            add_intermediate_term(self._resized_inputs[-1])
        elif len(self.inputs) % FULL_ADDER_INPUT_COUNT == 2:
            # Just pass the terms to the next layer, since we wouldn't gain
            # anything by using a half adder since there would still be 2 terms
            # and just passing the terms to the next layer saves gates.
            add_intermediate_term(self._resized_inputs[-2])
            add_intermediate_term(self._resized_inputs[-1])
        else:
            assert len(self.inputs) % FULL_ADDER_INPUT_COUNT == 0
        # recursive invocation of ``AddReduce``
        next_level = AddReduce(intermediate_terms,
                               len(self.output),
                               self.next_register_levels(),
                               self._reg_partition_points)
        m.submodules.next_level = next_level
        m.d.comb += self.output.eq(next_level.output)
        return m
# operation codes for the partitioned multiplier (per byte of part_ops)
OP_MUL_LOW = 0                  # LSB half of the product
OP_MUL_SIGNED_HIGH = 1          # MSB half, both a and b signed
OP_MUL_SIGNED_UNSIGNED_HIGH = 2  # MSB half, a is signed, b is unsigned
OP_MUL_UNSIGNED_HIGH = 3        # MSB half, both a and b unsigned
def get_term(value, shift=0, enabled=None):
    """Build a partial-product term.

    :param value: the value of the term.
    :param shift: number of zero bits to prepend (i.e. multiply by 2**shift).
    :param enabled: optional 1-bit condition; when deasserted the term is 0.
    :return: the (optionally gated and shifted) term.
    """
    if enabled is not None:
        value = Mux(enabled, value, 0)
    if shift != 0:
        value = Cat(Repl(C(0, 1), shift), value)
    return value
class ProductTerm(Elaboratable):
    """ this class creates a single product term (a[..]*b[..]).
        it has a design flaw in that is the *output* that is selected,
        where the multiplication(s) are combinatorially generated
        all the time.
    """

    def __init__(self, width, twidth, pbwid, a_index, b_index):
        self.a_index = a_index
        self.b_index = b_index
        # byte position of this term within the full product
        shift = 8 * (self.a_index + self.b_index)
        self.pwidth = width       # width of one byte-sized sub-operand
        self.twidth = twidth      # total width of the term output
        self.width = width*2      # width of a pwidth x pwidth product
        self.shift = shift

        self.ti = Signal(self.width, reset_less=True)
        self.term = Signal(twidth, reset_less=True)
        self.a = Signal(twidth//2, reset_less=True)
        self.b = Signal(twidth//2, reset_less=True)
        self.pb_en = Signal(pbwid, reset_less=True)

        # partition bits strictly between the two byte indices: if any
        # of them is set, a partition boundary crosses this term and it
        # must be disabled
        self.tl = tl = []
        min_index = min(self.a_index, self.b_index)
        max_index = max(self.a_index, self.b_index)
        for i in range(min_index, max_index):
            tl.append(self.pb_en[i])
        name = "te_%d_%d" % (self.a_index, self.b_index)
        if len(tl) > 0:
            term_enabled = Signal(name=name, reset_less=True)
        else:
            term_enabled = None
        self.enabled = term_enabled
        self.term.name = "term_%d_%d" % (a_index, b_index)  # rename

    def elaborate(self, platform):

        m = Module()
        if self.enabled is not None:
            # enabled only when no partition bit between the indices is set
            m.d.comb += self.enabled.eq(~(Cat(*self.tl).bool()))

        bsa = Signal(self.width, reset_less=True)
        bsb = Signal(self.width, reset_less=True)
        a_index, b_index = self.a_index, self.b_index
        pwidth = self.pwidth
        m.d.comb += bsa.eq(self.a.bit_select(a_index * pwidth, pwidth))
        m.d.comb += bsb.eq(self.b.bit_select(b_index * pwidth, pwidth))
        m.d.comb += self.ti.eq(bsa * bsb)
        m.d.comb += self.term.eq(get_term(self.ti, self.shift, self.enabled))

        # TODO: sort out width issues, get inputs a/b switched on/off.
        # data going into Muxes is 1/2 the required width.
        # NOTE(review): disabled alternative which gates/shifts the *inputs*
        # instead of the output — kept for reference:
        #
        # bsa = Signal(self.twidth//2, reset_less=True)
        # bsb = Signal(self.twidth//2, reset_less=True)
        # asel = Signal(width, reset_less=True)
        # bsel = Signal(width, reset_less=True)
        # a_index, b_index = self.a_index, self.b_index
        # m.d.comb += asel.eq(self.a.bit_select(a_index * pwidth, pwidth))
        # m.d.comb += bsel.eq(self.b.bit_select(b_index * pwidth, pwidth))
        # m.d.comb += bsa.eq(get_term(asel, self.shift, self.enabled))
        # m.d.comb += bsb.eq(get_term(bsel, self.shift, self.enabled))
        # m.d.comb += self.ti.eq(bsa * bsb)
        # m.d.comb += self.term.eq(self.ti)

        return m
class ProductTerms(Elaboratable):
    """ creates a bank of product terms.  also performs the actual
        bit-selection.
        this class is to be wrapped with a for-loop on the "a" operand.
        it creates a second-level for-loop on the "b" operand.
    """

    def __init__(self, width, twidth, pbwid, a_index, blen):
        self.a_index = a_index
        self.blen = blen
        self.pwidth = width
        self.twidth = twidth
        self.pbwid = pbwid
        self.a = Signal(twidth//2, reset_less=True)
        self.b = Signal(twidth//2, reset_less=True)
        self.pb_en = Signal(pbwid, reset_less=True)
        self.terms = [Signal(twidth, name="term%d" % i, reset_less=True)
                      for i in range(blen)]

    def elaborate(self, platform):

        m = Module()

        for b_index in range(self.blen):
            t = ProductTerm(self.pwidth, self.twidth, self.pbwid,
                            self.a_index, b_index)
            setattr(m.submodules, "term_%d" % b_index, t)

            m.d.comb += t.a.eq(self.a)
            m.d.comb += t.b.eq(self.b)
            m.d.comb += t.pb_en.eq(self.pb_en)

            m.d.comb += self.terms[b_index].eq(t.term)

        return m
class LSBNegTerm(Elaboratable):
    """Generates the extra terms needed for a signed partial product.

    Splits 2's-complement negation into a bitwise NOT (``nt``) and a
    "+1" LSB term (``nl``), both gated on the partition being active,
    the MSB being set, and the operation being signed.
    """

    def __init__(self, bit_width):
        self.bit_width = bit_width
        self.part = Signal(reset_less=True)      # partition active
        self.signed = Signal(reset_less=True)    # signed op requested
        self.op = Signal(bit_width, reset_less=True)   # the operand
        self.msb = Signal(reset_less=True)       # MSB of the *other* operand
        self.nt = Signal(bit_width*2, reset_less=True)  # NOT term
        self.nl = Signal(bit_width*2, reset_less=True)  # +1 LSB term

    def elaborate(self, platform):
        m = Module()
        comb = m.d.comb
        bit_wid = self.bit_width
        ext = Repl(0, bit_wid)  # extend output to HI part

        # determine sign of each incoming number *in this partition*
        enabled = Signal(reset_less=True)
        m.d.comb += enabled.eq(self.part & self.msb & self.signed)

        # for 8-bit values: form a * 0xFF00 by using -a * 0x100, the
        # negation operation is split into a bitwise not and a +1.
        # likewise for 16, 32, and 64-bit values.

        # width-extended 1s complement if a is signed, otherwise zero
        comb += self.nt.eq(Mux(enabled, Cat(ext, ~self.op), 0))

        # add 1 if signed, otherwise add zero
        comb += self.nl.eq(Cat(ext, enabled, Repl(0, bit_wid-1)))

        return m
class Part(Elaboratable):
    """ a key class which, depending on the partitioning, will determine
        what action to take when parts of the output are signed or unsigned.

        this requires 2 pieces of data *per operand, per partition*:
        whether the MSB is HI/LO (per partition!), and whether a signed
        or unsigned operation has been *requested*.

        once that is determined, signed is basically carried out
        by splitting 2's complement into 1's complement plus one.
        1's complement is just a bit-inversion.

        the extra terms - as separate terms - are then thrown at the
        AddReduce alongside the multiplication part-results.
    """

    def __init__(self, width, n_parts, n_levels, pbwid):
        self.pbwid = pbwid
        # inputs
        self.a = Signal(64)
        self.b = Signal(64)
        self.a_signed = [Signal(name=f"a_signed_{i}") for i in range(8)]
        self.b_signed = [Signal(name=f"_b_signed_{i}") for i in range(8)]
        self.pbs = Signal(pbwid, reset_less=True)

        # outputs
        self.parts = [Signal(name=f"part_{i}") for i in range(n_parts)]
        self.delayed_parts = [
            [Signal(name=f"delayed_part_{delay}_{i}")
             for i in range(n_parts)]
            for delay in range(n_levels)]
        # XXX REALLY WEIRD BUG - have to take a copy of the last delayed_parts
        self.dplast = [Signal(name=f"dplast_{i}")
                       for i in range(n_parts)]

        self.not_a_term = Signal(width)
        self.neg_lsb_a_term = Signal(width)
        self.not_b_term = Signal(width)
        self.neg_lsb_b_term = Signal(width)

    def elaborate(self, platform):
        m = Module()

        pbs, parts, delayed_parts = self.pbs, self.parts, self.delayed_parts
        # negated-temporary copy of partition bits
        npbs = Signal.like(pbs, reset_less=True)
        m.d.comb += npbs.eq(~pbs)
        byte_count = 8 // len(parts)
        for i in range(len(parts)):
            # a partition of this size is "active" when its boundary
            # partition bits are set and no internal partition bit is
            pbl = []
            pbl.append(npbs[i * byte_count - 1])
            for j in range(i * byte_count, (i + 1) * byte_count - 1):
                pbl.append(pbs[j])
            pbl.append(npbs[(i + 1) * byte_count - 1])
            value = Signal(len(pbl), name="value_%di" % i, reset_less=True)
            m.d.comb += value.eq(Cat(*pbl))
            m.d.comb += parts[i].eq(~(value).bool())
            m.d.comb += delayed_parts[0][i].eq(parts[i])
            m.d.sync += [delayed_parts[j + 1][i].eq(delayed_parts[j][i])
                         for j in range(len(delayed_parts)-1)]
            m.d.comb += self.dplast[i].eq(delayed_parts[-1][i])

        not_a_term, neg_lsb_a_term, not_b_term, neg_lsb_b_term = \
            self.not_a_term, self.neg_lsb_a_term, \
            self.not_b_term, self.neg_lsb_b_term

        byte_width = 8 // len(parts)  # byte width
        bit_wid = 8 * byte_width      # bit width
        nat, nbt, nla, nlb = [], [], [], []
        for i in range(len(parts)):
            # work out bit-inverted and +1 term for a.
            pa = LSBNegTerm(bit_wid)
            setattr(m.submodules, "lnt_%d_a_%d" % (bit_wid, i), pa)
            m.d.comb += pa.part.eq(parts[i])
            m.d.comb += pa.op.eq(self.a.bit_select(bit_wid * i, bit_wid))
            m.d.comb += pa.signed.eq(self.b_signed[i * byte_width])  # yes b
            m.d.comb += pa.msb.eq(self.b[(i + 1) * bit_wid - 1])  # really, b
            nat.append(pa.nt)
            nla.append(pa.nl)

            # work out bit-inverted and +1 term for b
            pb = LSBNegTerm(bit_wid)
            setattr(m.submodules, "lnt_%d_b_%d" % (bit_wid, i), pb)
            m.d.comb += pb.part.eq(parts[i])
            m.d.comb += pb.op.eq(self.b.bit_select(bit_wid * i, bit_wid))
            m.d.comb += pb.signed.eq(self.a_signed[i * byte_width])  # yes a
            m.d.comb += pb.msb.eq(self.a[(i + 1) * bit_wid - 1])  # really, a
            nbt.append(pb.nt)
            nlb.append(pb.nl)

        # concatenate together and return all 4 results.
        m.d.comb += [not_a_term.eq(Cat(*nat)),
                     not_b_term.eq(Cat(*nbt)),
                     neg_lsb_a_term.eq(Cat(*nla)),
                     neg_lsb_b_term.eq(Cat(*nlb)),
                     ]

        return m
class IntermediateOut(Elaboratable):
    """ selects the HI/LO part of the multiplication, for a given bit-width
        the output is also reconstructed in its SIMD (partition) lanes.
    """

    def __init__(self, width, out_wid, n_parts):
        self.width = width
        self.n_parts = n_parts
        self.delayed_part_ops = [Signal(2, name="dpop%d" % i, reset_less=True)
                                 for i in range(8)]
        self.intermed = Signal(out_wid, reset_less=True)
        self.output = Signal(out_wid//2, reset_less=True)

    def elaborate(self, platform):
        m = Module()

        ol = []
        w = self.width
        sel = w // 8  # number of bytes per partition lane
        for i in range(self.n_parts):
            op = Signal(w, reset_less=True, name="op%d_%d" % (w, i))
            # OP_MUL_LOW selects the LSB half of the lane's product,
            # everything else selects the MSB half
            m.d.comb += op.eq(
                Mux(self.delayed_part_ops[sel * i] == OP_MUL_LOW,
                    self.intermed.bit_select(i * w*2, w),
                    self.intermed.bit_select(i * w*2 + w, w)))
            ol.append(op)
        m.d.comb += self.output.eq(Cat(*ol))

        return m
class FinalOut(Elaboratable):
    """ selects the final output based on the partitioning.

        each byte is selectable independently, i.e. it is possible
        that some partitions requested 8-bit computation whilst others
        requested 16 or 32 bit.
    """

    def __init__(self, out_wid):
        # inputs
        self.d8 = [Signal(name=f"d8_{i}", reset_less=True) for i in range(8)]
        self.d16 = [Signal(name=f"d16_{i}", reset_less=True) for i in range(4)]
        self.d32 = [Signal(name=f"d32_{i}", reset_less=True) for i in range(2)]

        self.i8 = Signal(out_wid, reset_less=True)
        self.i16 = Signal(out_wid, reset_less=True)
        self.i32 = Signal(out_wid, reset_less=True)
        self.i64 = Signal(out_wid, reset_less=True)

        # output
        self.out = Signal(out_wid, reset_less=True)

    def elaborate(self, platform):
        m = Module()
        ol = []
        for i in range(8):
            # select one of the outputs: d8 selects i8, d16 selects i16
            # d32 selects i32, and the default is i64.
            # d8 and d16 are ORed together in the first Mux
            # then the 2nd selects either i8 or i16.
            # if neither d8 nor d16 are set, d32 selects either i32 or i64.
            op = Signal(8, reset_less=True, name="op_%d" % i)
            m.d.comb += op.eq(
                Mux(self.d8[i] | self.d16[i // 2],
                    Mux(self.d8[i], self.i8.bit_select(i * 8, 8),
                        self.i16.bit_select(i * 8, 8)),
                    Mux(self.d32[i // 4], self.i32.bit_select(i * 8, 8),
                        self.i64.bit_select(i * 8, 8))))
            ol.append(op)
        m.d.comb += self.out.eq(Cat(*ol))

        return m
class OrMod(Elaboratable):
    """ ORs four values together in a hierarchical tree
    """

    def __init__(self, wid):
        self.wid = wid
        self.orin = [Signal(wid, name="orin%d" % i, reset_less=True)
                     for i in range(4)]
        self.orout = Signal(wid, reset_less=True)

    def elaborate(self, platform):
        m = Module()
        # two-level OR tree: (0|1) | (2|3)
        or1 = Signal(self.wid, reset_less=True)
        or2 = Signal(self.wid, reset_less=True)
        m.d.comb += or1.eq(self.orin[0] | self.orin[1])
        m.d.comb += or2.eq(self.orin[2] | self.orin[3])
        m.d.comb += self.orout.eq(or1 | or2)

        return m
class Signs(Elaboratable):
    """ determines whether a or b are signed numbers
        based on the required operation type (OP_MUL_*)
    """

    def __init__(self):
        self.part_ops = Signal(2, reset_less=True)
        self.a_signed = Signal(reset_less=True)
        self.b_signed = Signal(reset_less=True)

    def elaborate(self, platform):
        m = Module()

        # a is signed for every op except OP_MUL_UNSIGNED_HIGH
        asig = self.part_ops != OP_MUL_UNSIGNED_HIGH
        # b is signed only for OP_MUL_LOW and OP_MUL_SIGNED_HIGH
        bsig = (self.part_ops == OP_MUL_LOW) \
            | (self.part_ops == OP_MUL_SIGNED_HIGH)
        m.d.comb += self.a_signed.eq(asig)
        m.d.comb += self.b_signed.eq(bsig)

        return m
class Mul8_16_32_64(Elaboratable):
    """Signed/Unsigned 8/16/32/64-bit partitioned integer multiplier.

    Supports partitioning into any combination of 8, 16, 32, and 64-bit
    partitions on naturally-aligned boundaries. Supports the operation being
    set for each partition independently.

    :attribute part_pts: the input partition points. Has a partition point at
        multiples of 8 in 0 < i < 64. Each partition point's associated
        ``Value`` is a ``Signal``. Modification not supported, except for by
        ``Signal.eq``.
    :attribute part_ops: the operation for each byte. The operation for a
        particular partition is selected by assigning the selected operation
        code to each byte in the partition. The allowed operation codes are:

        :attribute OP_MUL_LOW: the LSB half of the product. Equivalent to
            RISC-V's `mul` instruction.
        :attribute OP_MUL_SIGNED_HIGH: the MSB half of the product where both
            ``a`` and ``b`` are signed. Equivalent to RISC-V's `mulh`
            instruction.
        :attribute OP_MUL_SIGNED_UNSIGNED_HIGH: the MSB half of the product
            where ``a`` is signed and ``b`` is unsigned. Equivalent to RISC-V's
            `mulhsu` instruction.
        :attribute OP_MUL_UNSIGNED_HIGH: the MSB half of the product where both
            ``a`` and ``b`` are unsigned. Equivalent to RISC-V's `mulhu`
            instruction.
    """

    def __init__(self, register_levels=()):
        """ register_levels: specifies the points in the cascade at which
            flip-flops are to be inserted.
        """
        self.register_levels = list(register_levels)
        self.part_pts = PartitionPoints()
        for i in range(8, 64, 8):
            self.part_pts[i] = Signal(name=f"part_pts_{i}")
        self.part_ops = [Signal(2, name=f"part_ops_{i}") for i in range(8)]
        # inputs
        self.a = Signal(64)
        self.b = Signal(64)
        # intermediates (needed for unit tests)
        self._intermediate_output = Signal(128)
        # output
        self.output = Signal(64)

    def _part_byte(self, index):
        # the boundaries outside byte 0 and byte 7 are always "partitioned"
        if index == -1 or index == 7:
            return C(True, 1)
        assert index >= 0 and index < 8
        return self.part_pts[index * 8 + 8]

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()

        # collect the partition bit for each byte into one signal
        pbs = Signal(8, reset_less=True)
        tl = []
        for i in range(8):
            pb = Signal(name="pb%d" % i, reset_less=True)
            m.d.comb += pb.eq(self._part_byte(i))
            tl.append(pb)
        m.d.comb += pbs.eq(Cat(*tl))

        # determine a/b signed-ness, per byte, from the requested operation
        signs = []
        for i in range(8):
            s = Signs()
            signs.append(s)
            setattr(m.submodules, "signs%d" % i, s)
            m.d.comb += s.part_ops.eq(self.part_ops[i])

        # pipeline-delayed copies of part_ops, one per register level
        delayed_part_ops = [
            [Signal(2, name=f"_delayed_part_ops_{delay}_{i}")
             for i in range(8)]
            for delay in range(1 + len(self.register_levels))]
        for i in range(len(self.part_ops)):
            m.d.comb += delayed_part_ops[0][i].eq(self.part_ops[i])
            m.d.sync += [delayed_part_ops[j + 1][i].eq(delayed_part_ops[j][i])
                         for j in range(len(self.register_levels))]

        n_levels = len(self.register_levels)+1
        m.submodules.part_8 = part_8 = Part(128, 8, n_levels, 8)
        m.submodules.part_16 = part_16 = Part(128, 4, n_levels, 8)
        m.submodules.part_32 = part_32 = Part(128, 2, n_levels, 8)
        m.submodules.part_64 = part_64 = Part(128, 1, n_levels, 8)
        nat_l, nbt_l, nla_l, nlb_l = [], [], [], []
        for mod in [part_8, part_16, part_32, part_64]:
            m.d.comb += mod.a.eq(self.a)
            m.d.comb += mod.b.eq(self.b)
            for i in range(len(signs)):
                m.d.comb += mod.a_signed[i].eq(signs[i].a_signed)
                m.d.comb += mod.b_signed[i].eq(signs[i].b_signed)
            m.d.comb += mod.pbs.eq(pbs)
            nat_l.append(mod.not_a_term)
            nbt_l.append(mod.not_b_term)
            nla_l.append(mod.neg_lsb_a_term)
            nlb_l.append(mod.neg_lsb_b_term)

        # create all the partial-product terms, bank by bank
        terms = []
        for a_index in range(8):
            t = ProductTerms(8, 128, 8, a_index, 8)
            setattr(m.submodules, "terms_%d" % a_index, t)

            m.d.comb += t.a.eq(self.a)
            m.d.comb += t.b.eq(self.b)
            m.d.comb += t.pb_en.eq(pbs)

            for term in t.terms:
                terms.append(term)

        # it's fine to bitwise-or data together since they are never enabled
        # at the same time
        m.submodules.nat_or = nat_or = OrMod(128)
        m.submodules.nbt_or = nbt_or = OrMod(128)
        m.submodules.nla_or = nla_or = OrMod(128)
        m.submodules.nlb_or = nlb_or = OrMod(128)
        for l, mod in [(nat_l, nat_or),
                       (nbt_l, nbt_or),
                       (nla_l, nla_or),
                       (nlb_l, nlb_or)]:
            for i in range(len(l)):
                m.d.comb += mod.orin[i].eq(l[i])
            terms.append(mod.orout)

        # scale the partition points up to the double-width product
        expanded_part_pts = PartitionPoints()
        for i, v in self.part_pts.items():
            signal = Signal(name=f"expanded_part_pts_{i*2}", reset_less=True)
            expanded_part_pts[i * 2] = signal
            m.d.comb += signal.eq(v)

        # sum all the terms into the 128-bit intermediate product
        add_reduce = AddReduce(terms,
                               128,
                               self.register_levels,
                               expanded_part_pts)
        m.submodules.add_reduce = add_reduce
        m.d.comb += self._intermediate_output.eq(add_reduce.output)

        # HI/LO selection for each supported partition size
        m.submodules.io64 = io64 = IntermediateOut(64, 128, 1)
        m.d.comb += io64.intermed.eq(self._intermediate_output)
        for i in range(8):
            m.d.comb += io64.delayed_part_ops[i].eq(delayed_part_ops[-1][i])

        m.submodules.io32 = io32 = IntermediateOut(32, 128, 2)
        m.d.comb += io32.intermed.eq(self._intermediate_output)
        for i in range(8):
            m.d.comb += io32.delayed_part_ops[i].eq(delayed_part_ops[-1][i])

        m.submodules.io16 = io16 = IntermediateOut(16, 128, 4)
        m.d.comb += io16.intermed.eq(self._intermediate_output)
        for i in range(8):
            m.d.comb += io16.delayed_part_ops[i].eq(delayed_part_ops[-1][i])

        m.submodules.io8 = io8 = IntermediateOut(8, 128, 8)
        m.d.comb += io8.intermed.eq(self._intermediate_output)
        for i in range(8):
            m.d.comb += io8.delayed_part_ops[i].eq(delayed_part_ops[-1][i])

        # final per-byte output selection based on the partitioning
        m.submodules.finalout = finalout = FinalOut(64)
        for i in range(len(part_8.delayed_parts[-1])):
            m.d.comb += finalout.d8[i].eq(part_8.dplast[i])
        for i in range(len(part_16.delayed_parts[-1])):
            m.d.comb += finalout.d16[i].eq(part_16.dplast[i])
        for i in range(len(part_32.delayed_parts[-1])):
            m.d.comb += finalout.d32[i].eq(part_32.dplast[i])
        m.d.comb += finalout.i8.eq(io8.output)
        m.d.comb += finalout.i16.eq(io16.output)
        m.d.comb += finalout.i32.eq(io32.output)
        m.d.comb += finalout.i64.eq(io64.output)
        m.d.comb += self.output.eq(finalout.out)

        return m
if __name__ == "__main__":
    # generate Verilog (or run other nmigen.cli actions) for the multiplier,
    # exposing inputs, outputs and the intermediate product as ports
    m = Mul8_16_32_64()
    main(m, ports=[m.a,
                   m.b,
                   m._intermediate_output,
                   m.output,
                   *m.part_ops,
                   *m.part_pts.values()])