1 # SPDX-License-Identifier: LGPL-2.1-or-later
2 # See Notices.txt for copyright information
3 """Integer Multiplication."""
5 from nmigen
import Signal
, Module
, Value
, Elaboratable
, Cat
, C
, Mux
, Repl
6 from nmigen
.hdl
.ast
import Assign
7 from abc
import ABCMeta
, abstractmethod
8 from nmigen
.cli
import main
9 from functools
import reduce
10 from operator
import or_
13 class PartitionPoints(dict):
14 """Partition points and corresponding ``Value``s.
16 The points at where an ALU is partitioned along with ``Value``s that
17 specify if the corresponding partition points are enabled.
19 For example: ``{1: True, 5: True, 10: True}`` with
20 ``width == 16`` specifies that the ALU is split into 4 sections:
23 * bits 5 <= ``i`` < 10
24 * bits 10 <= ``i`` < 16
26 If the partition_points were instead ``{1: True, 5: a, 10: True}``
27 where ``a`` is a 1-bit ``Signal``:
28 * If ``a`` is asserted:
31 * bits 5 <= ``i`` < 10
32 * bits 10 <= ``i`` < 16
35 * bits 1 <= ``i`` < 10
36 * bits 10 <= ``i`` < 16
39 def __init__(self
, partition_points
=None):
40 """Create a new ``PartitionPoints``.
42 :param partition_points: the input partition points to values mapping.
45 if partition_points
is not None:
46 for point
, enabled
in partition_points
.items():
47 if not isinstance(point
, int):
48 raise TypeError("point must be a non-negative integer")
50 raise ValueError("point must be a non-negative integer")
51 self
[point
] = Value
.wrap(enabled
)
53 def like(self
, name
=None, src_loc_at
=0):
54 """Create a new ``PartitionPoints`` with ``Signal``s for all values.
56 :param name: the base name for the new ``Signal``s.
59 name
= Signal(src_loc_at
=1+src_loc_at
).name
# get variable name
60 retval
= PartitionPoints()
61 for point
, enabled
in self
.items():
62 retval
[point
] = Signal(enabled
.shape(), name
=f
"{name}_{point}")
66 """Assign ``PartitionPoints`` using ``Signal.eq``."""
67 if set(self
.keys()) != set(rhs
.keys()):
68 raise ValueError("incompatible point set")
69 for point
, enabled
in self
.items():
70 yield enabled
.eq(rhs
[point
])
72 def as_mask(self
, width
):
73 """Create a bit-mask from `self`.
75 Each bit in the returned mask is clear only if the partition point at
76 the same bit-index is enabled.
78 :param width: the bit width of the resulting mask
81 for i
in range(width
):
88 def get_max_partition_count(self
, width
):
89 """Get the maximum number of partitions.
91 Gets the number of partitions when all partition points are enabled.
94 for point
in self
.keys():
99 def fits_in_width(self
, width
):
100 """Check if all partition points are smaller than `width`."""
101 for point
in self
.keys():
107 class FullAdder(Elaboratable
):
110 :attribute in0: the first input
111 :attribute in1: the second input
112 :attribute in2: the third input
113 :attribute sum: the sum output
114 :attribute carry: the carry output
116 Rather than do individual full adders (and have an array of them,
117 which would be very slow to simulate), this module can specify the
118 bit width of the inputs and outputs: in effect it performs multiple
119 Full 3-2 Add operations "in parallel".
122 def __init__(self
, width
):
123 """Create a ``FullAdder``.
125 :param width: the bit width of the input and output
127 self
.in0
= Signal(width
)
128 self
.in1
= Signal(width
)
129 self
.in2
= Signal(width
)
130 self
.sum = Signal(width
)
131 self
.carry
= Signal(width
)
133 def elaborate(self
, platform
):
134 """Elaborate this module."""
136 m
.d
.comb
+= self
.sum.eq(self
.in0 ^ self
.in1 ^ self
.in2
)
137 m
.d
.comb
+= self
.carry
.eq((self
.in0
& self
.in1
)
138 |
(self
.in1
& self
.in2
)
139 |
(self
.in2
& self
.in0
))
143 class MaskedFullAdder(Elaboratable
):
144 """Masked Full Adder.
146 :attribute mask: the carry partition mask
147 :attribute in0: the first input
148 :attribute in1: the second input
149 :attribute in2: the third input
150 :attribute sum: the sum output
151 :attribute mcarry: the masked carry output
153 FullAdders are always used with a "mask" on the output. To keep
154 the graphviz "clean", this class performs the masking here rather
155 than inside a large for-loop.
157 See the following discussion as to why this is no longer derived
158 from FullAdder. Each carry is shifted here *before* being ANDed
159 with the mask, so that an AOI cell may be used (which is more
161 https://en.wikipedia.org/wiki/AND-OR-Invert
162 https://groups.google.com/d/msg/comp.arch/fcq-GLQqvas/vTxmcA0QAgAJ
165 def __init__(self
, width
):
166 """Create a ``MaskedFullAdder``.
168 :param width: the bit width of the input and output
171 self
.mask
= Signal(width
, reset_less
=True)
172 self
.mcarry
= Signal(width
, reset_less
=True)
173 self
.in0
= Signal(width
, reset_less
=True)
174 self
.in1
= Signal(width
, reset_less
=True)
175 self
.in2
= Signal(width
, reset_less
=True)
176 self
.sum = Signal(width
, reset_less
=True)
178 def elaborate(self
, platform
):
179 """Elaborate this module."""
181 s1
= Signal(self
.width
, reset_less
=True)
182 s2
= Signal(self
.width
, reset_less
=True)
183 s3
= Signal(self
.width
, reset_less
=True)
184 c1
= Signal(self
.width
, reset_less
=True)
185 c2
= Signal(self
.width
, reset_less
=True)
186 c3
= Signal(self
.width
, reset_less
=True)
187 m
.d
.comb
+= self
.sum.eq(self
.in0 ^ self
.in1 ^ self
.in2
)
188 m
.d
.comb
+= s1
.eq(Cat(0, self
.in0
))
189 m
.d
.comb
+= s2
.eq(Cat(0, self
.in1
))
190 m
.d
.comb
+= s3
.eq(Cat(0, self
.in2
))
191 m
.d
.comb
+= c1
.eq(s1
& s2
& self
.mask
)
192 m
.d
.comb
+= c2
.eq(s2
& s3
& self
.mask
)
193 m
.d
.comb
+= c3
.eq(s3
& s1
& self
.mask
)
194 m
.d
.comb
+= self
.mcarry
.eq(c1 | c2 | c3
)
198 class PartitionedAdder(Elaboratable
):
199 """Partitioned Adder.
201 Performs the final add. The partition points are included in the
202 actual add (in one of the operands only), which causes a carry over
203 to the next bit. Then the final output *removes* the extra bits from
206 partition: .... P... P... P... P... (32 bits)
207 a : .... .... .... .... .... (32 bits)
208 b : .... .... .... .... .... (32 bits)
209 exp-a : ....P....P....P....P.... (32+4 bits, P=1 if no partition)
210 exp-b : ....0....0....0....0.... (32 bits plus 4 zeros)
211 exp-o : ....xN...xN...xN...xN... (32+4 bits - x to be discarded)
212 o : .... N... N... N... N... (32 bits - x ignored, N is carry-over)
214 :attribute width: the bit width of the input and output. Read-only.
215 :attribute a: the first input to the adder
216 :attribute b: the second input to the adder
217 :attribute output: the sum output
218 :attribute partition_points: the input partition points. Modification not
219 supported, except for by ``Signal.eq``.
222 def __init__(self
, width
, partition_points
):
223 """Create a ``PartitionedAdder``.
225 :param width: the bit width of the input and output
226 :param partition_points: the input partition points
229 self
.a
= Signal(width
)
230 self
.b
= Signal(width
)
231 self
.output
= Signal(width
)
232 self
.partition_points
= PartitionPoints(partition_points
)
233 if not self
.partition_points
.fits_in_width(width
):
234 raise ValueError("partition_points doesn't fit in width")
236 for i
in range(self
.width
):
237 if i
in self
.partition_points
:
240 self
._expanded
_width
= expanded_width
241 # XXX these have to remain here due to some horrible nmigen
242 # simulation bugs involving sync. it is *not* necessary to
243 # have them here, they should (under normal circumstances)
244 # be moved into elaborate, as they are entirely local
245 self
._expanded
_a
= Signal(expanded_width
) # includes extra part-points
246 self
._expanded
_b
= Signal(expanded_width
) # likewise.
247 self
._expanded
_o
= Signal(expanded_width
) # likewise.
249 def elaborate(self
, platform
):
250 """Elaborate this module."""
253 # store bits in a list, use Cat later. graphviz is much cleaner
254 al
, bl
, ol
, ea
, eb
, eo
= [],[],[],[],[],[]
256 # partition points are "breaks" (extra zeros or 1s) in what would
257 # otherwise be a massive long add. when the "break" points are 0,
258 # whatever is in it (in the output) is discarded. however when
259 # there is a "1", it causes a roll-over carry to the *next* bit.
260 # we still ignore the "break" bit in the [intermediate] output,
261 # however by that time we've got the effect that we wanted: the
262 # carry has been carried *over* the break point.
264 for i
in range(self
.width
):
265 if i
in self
.partition_points
:
266 # add extra bit set to 0 + 0 for enabled partition points
267 # and 1 + 0 for disabled partition points
268 ea
.append(self
._expanded
_a
[expanded_index
])
269 al
.append(~self
.partition_points
[i
]) # add extra bit in a
270 eb
.append(self
._expanded
_b
[expanded_index
])
271 bl
.append(C(0)) # yes, add a zero
272 expanded_index
+= 1 # skip the extra point. NOT in the output
273 ea
.append(self
._expanded
_a
[expanded_index
])
274 eb
.append(self
._expanded
_b
[expanded_index
])
275 eo
.append(self
._expanded
_o
[expanded_index
])
278 ol
.append(self
.output
[i
])
281 # combine above using Cat
282 m
.d
.comb
+= Cat(*ea
).eq(Cat(*al
))
283 m
.d
.comb
+= Cat(*eb
).eq(Cat(*bl
))
284 m
.d
.comb
+= Cat(*ol
).eq(Cat(*eo
))
286 # use only one addition to take advantage of look-ahead carry and
287 # special hardware on FPGAs
288 m
.d
.comb
+= self
._expanded
_o
.eq(
289 self
._expanded
_a
+ self
._expanded
_b
)
293 FULL_ADDER_INPUT_COUNT
= 3
296 class AddReduceSingle(Elaboratable
):
297 """Add list of numbers together.
299 :attribute inputs: input ``Signal``s to be summed. Modification not
300 supported, except for by ``Signal.eq``.
301 :attribute register_levels: List of nesting levels that should have
303 :attribute output: output sum.
304 :attribute partition_points: the input partition points. Modification not
305 supported, except for by ``Signal.eq``.
308 def __init__(self
, inputs
, output_width
, register_levels
, partition_points
):
309 """Create an ``AddReduce``.
311 :param inputs: input ``Signal``s to be summed.
312 :param output_width: bit-width of ``output``.
313 :param register_levels: List of nesting levels that should have
315 :param partition_points: the input partition points.
317 self
.inputs
= list(inputs
)
318 self
._resized
_inputs
= [
319 Signal(output_width
, name
=f
"resized_inputs[{i}]")
320 for i
in range(len(self
.inputs
))]
321 self
.register_levels
= list(register_levels
)
322 self
.output
= Signal(output_width
)
323 self
.partition_points
= PartitionPoints(partition_points
)
324 if not self
.partition_points
.fits_in_width(output_width
):
325 raise ValueError("partition_points doesn't fit in output_width")
326 self
._reg
_partition
_points
= self
.partition_points
.like()
328 max_level
= AddReduceSingle
.get_max_level(len(self
.inputs
))
329 for level
in self
.register_levels
:
330 if level
> max_level
:
332 "not enough adder levels for specified register levels")
334 self
.groups
= AddReduceSingle
.full_adder_groups(len(self
.inputs
))
335 self
._intermediate
_terms
= []
336 if len(self
.groups
) != 0:
337 self
.create_next_terms()
340 def get_max_level(input_count
):
341 """Get the maximum level.
343 All ``register_levels`` must be less than or equal to the maximum
348 groups
= AddReduceSingle
.full_adder_groups(input_count
)
351 input_count
%= FULL_ADDER_INPUT_COUNT
352 input_count
+= 2 * len(groups
)
356 def full_adder_groups(input_count
):
357 """Get ``inputs`` indices for which a full adder should be built."""
359 input_count
- FULL_ADDER_INPUT_COUNT
+ 1,
360 FULL_ADDER_INPUT_COUNT
)
362 def elaborate(self
, platform
):
363 """Elaborate this module."""
366 # resize inputs to correct bit-width and optionally add in
368 resized_input_assignments
= [self
._resized
_inputs
[i
].eq(self
.inputs
[i
])
369 for i
in range(len(self
.inputs
))]
370 if 0 in self
.register_levels
:
371 m
.d
.sync
+= resized_input_assignments
372 m
.d
.sync
+= self
._reg
_partition
_points
.eq(self
.partition_points
)
374 m
.d
.comb
+= resized_input_assignments
375 m
.d
.comb
+= self
._reg
_partition
_points
.eq(self
.partition_points
)
377 for (value
, term
) in self
._intermediate
_terms
:
378 m
.d
.comb
+= term
.eq(value
)
380 # if there are no full adders to create, then we handle the base cases
381 # and return, otherwise we go on to the recursive case
382 if len(self
.groups
) == 0:
383 if len(self
.inputs
) == 0:
384 # use 0 as the default output value
385 m
.d
.comb
+= self
.output
.eq(0)
386 elif len(self
.inputs
) == 1:
387 # handle single input
388 m
.d
.comb
+= self
.output
.eq(self
._resized
_inputs
[0])
390 # base case for adding 2 inputs
391 assert len(self
.inputs
) == 2
392 adder
= PartitionedAdder(len(self
.output
),
393 self
._reg
_partition
_points
)
394 m
.submodules
.final_adder
= adder
395 m
.d
.comb
+= adder
.a
.eq(self
._resized
_inputs
[0])
396 m
.d
.comb
+= adder
.b
.eq(self
._resized
_inputs
[1])
397 m
.d
.comb
+= self
.output
.eq(adder
.output
)
400 mask
= self
._reg
_partition
_points
.as_mask(len(self
.output
))
401 m
.d
.comb
+= self
.part_mask
.eq(mask
)
403 # add and link the intermediate term modules
404 for i
, (iidx
, adder_i
) in enumerate(self
.adders
):
405 setattr(m
.submodules
, f
"adder_{i}", adder_i
)
407 m
.d
.comb
+= adder_i
.in0
.eq(self
._resized
_inputs
[iidx
])
408 m
.d
.comb
+= adder_i
.in1
.eq(self
._resized
_inputs
[iidx
+ 1])
409 m
.d
.comb
+= adder_i
.in2
.eq(self
._resized
_inputs
[iidx
+ 2])
410 m
.d
.comb
+= adder_i
.mask
.eq(self
.part_mask
)
414 def create_next_terms(self
):
416 # go on to prepare recursive case
417 intermediate_terms
= []
418 _intermediate_terms
= []
420 def add_intermediate_term(value
):
421 intermediate_term
= Signal(
423 name
=f
"intermediate_terms[{len(intermediate_terms)}]")
424 _intermediate_terms
.append((value
, intermediate_term
))
425 intermediate_terms
.append(intermediate_term
)
427 # store mask in intermediary (simplifies graph)
428 self
.part_mask
= Signal(len(self
.output
), reset_less
=True)
430 # create full adders for this recursive level.
431 # this shrinks N terms to 2 * (N // 3) plus the remainder
433 for i
in self
.groups
:
434 adder_i
= MaskedFullAdder(len(self
.output
))
435 self
.adders
.append((i
, adder_i
))
436 # add both the sum and the masked-carry to the next level.
437 # 3 inputs have now been reduced to 2...
438 add_intermediate_term(adder_i
.sum)
439 add_intermediate_term(adder_i
.mcarry
)
440 # handle the remaining inputs.
441 if len(self
.inputs
) % FULL_ADDER_INPUT_COUNT
== 1:
442 add_intermediate_term(self
._resized
_inputs
[-1])
443 elif len(self
.inputs
) % FULL_ADDER_INPUT_COUNT
== 2:
444 # Just pass the terms to the next layer, since we wouldn't gain
445 # anything by using a half adder since there would still be 2 terms
446 # and just passing the terms to the next layer saves gates.
447 add_intermediate_term(self
._resized
_inputs
[-2])
448 add_intermediate_term(self
._resized
_inputs
[-1])
450 assert len(self
.inputs
) % FULL_ADDER_INPUT_COUNT
== 0
452 self
.intermediate_terms
= intermediate_terms
453 self
._intermediate
_terms
= _intermediate_terms
456 class AddReduce(Elaboratable
):
457 """Recursively Add list of numbers together.
459 :attribute inputs: input ``Signal``s to be summed. Modification not
460 supported, except for by ``Signal.eq``.
461 :attribute register_levels: List of nesting levels that should have
463 :attribute output: output sum.
464 :attribute partition_points: the input partition points. Modification not
465 supported, except for by ``Signal.eq``.
468 def __init__(self
, inputs
, output_width
, register_levels
, partition_points
):
469 """Create an ``AddReduce``.
471 :param inputs: input ``Signal``s to be summed.
472 :param output_width: bit-width of ``output``.
473 :param register_levels: List of nesting levels that should have
475 :param partition_points: the input partition points.
478 self
.output
= Signal(output_width
)
479 self
.output_width
= output_width
480 self
.register_levels
= register_levels
481 self
.partition_points
= partition_points
486 def next_register_levels(register_levels
):
487 """``Iterable`` of ``register_levels`` for next recursive level."""
488 for level
in register_levels
:
492 def create_levels(self
):
493 """creates reduction levels"""
496 next_levels
= self
.register_levels
497 partition_points
= self
.partition_points
500 next_level
= AddReduceSingle(inputs
, self
.output_width
, next_levels
,
502 mods
.append(next_level
)
503 if len(next_level
.groups
) == 0:
505 next_levels
= list(AddReduce
.next_register_levels(next_levels
))
506 partition_points
= next_level
._reg
_partition
_points
507 inputs
= next_level
.intermediate_terms
511 def elaborate(self
, platform
):
512 """Elaborate this module."""
515 for i
, next_level
in enumerate(self
.levels
):
516 setattr(m
.submodules
, "next_level%d" % i
, next_level
)
518 # output comes from last module
519 m
.d
.comb
+= self
.output
.eq(next_level
.output
)
525 OP_MUL_SIGNED_HIGH
= 1
526 OP_MUL_SIGNED_UNSIGNED_HIGH
= 2 # a is signed, b is unsigned
527 OP_MUL_UNSIGNED_HIGH
= 3
530 def get_term(value
, shift
=0, enabled
=None):
531 if enabled
is not None:
532 value
= Mux(enabled
, value
, 0)
534 value
= Cat(Repl(C(0, 1), shift
), value
)
540 class ProductTerm(Elaboratable
):
541 """ this class creates a single product term (a[..]*b[..]).
542 it has a design flaw in that is the *output* that is selected,
543 where the multiplication(s) are combinatorially generated
547 def __init__(self
, width
, twidth
, pbwid
, a_index
, b_index
):
548 self
.a_index
= a_index
549 self
.b_index
= b_index
550 shift
= 8 * (self
.a_index
+ self
.b_index
)
556 self
.ti
= Signal(self
.width
, reset_less
=True)
557 self
.term
= Signal(twidth
, reset_less
=True)
558 self
.a
= Signal(twidth
//2, reset_less
=True)
559 self
.b
= Signal(twidth
//2, reset_less
=True)
560 self
.pb_en
= Signal(pbwid
, reset_less
=True)
563 min_index
= min(self
.a_index
, self
.b_index
)
564 max_index
= max(self
.a_index
, self
.b_index
)
565 for i
in range(min_index
, max_index
):
566 tl
.append(self
.pb_en
[i
])
567 name
= "te_%d_%d" % (self
.a_index
, self
.b_index
)
569 term_enabled
= Signal(name
=name
, reset_less
=True)
572 self
.enabled
= term_enabled
573 self
.term
.name
= "term_%d_%d" % (a_index
, b_index
) # rename
575 def elaborate(self
, platform
):
578 if self
.enabled
is not None:
579 m
.d
.comb
+= self
.enabled
.eq(~
(Cat(*self
.tl
).bool()))
581 bsa
= Signal(self
.width
, reset_less
=True)
582 bsb
= Signal(self
.width
, reset_less
=True)
583 a_index
, b_index
= self
.a_index
, self
.b_index
585 m
.d
.comb
+= bsa
.eq(self
.a
.bit_select(a_index
* pwidth
, pwidth
))
586 m
.d
.comb
+= bsb
.eq(self
.b
.bit_select(b_index
* pwidth
, pwidth
))
587 m
.d
.comb
+= self
.ti
.eq(bsa
* bsb
)
588 m
.d
.comb
+= self
.term
.eq(get_term(self
.ti
, self
.shift
, self
.enabled
))
590 #TODO: sort out width issues, get inputs a/b switched on/off.
591 #data going into Muxes is 1/2 the required width
595 bsa = Signal(self.twidth//2, reset_less=True)
596 bsb = Signal(self.twidth//2, reset_less=True)
597 asel = Signal(width, reset_less=True)
598 bsel = Signal(width, reset_less=True)
599 a_index, b_index = self.a_index, self.b_index
600 m.d.comb += asel.eq(self.a.bit_select(a_index * pwidth, pwidth))
601 m.d.comb += bsel.eq(self.b.bit_select(b_index * pwidth, pwidth))
602 m.d.comb += bsa.eq(get_term(asel, self.shift, self.enabled))
603 m.d.comb += bsb.eq(get_term(bsel, self.shift, self.enabled))
604 m.d.comb += self.ti.eq(bsa * bsb)
605 m.d.comb += self.term.eq(self.ti)
611 class ProductTerms(Elaboratable
):
612 """ creates a bank of product terms. also performs the actual bit-selection
613 this class is to be wrapped with a for-loop on the "a" operand.
614 it creates a second-level for-loop on the "b" operand.
616 def __init__(self
, width
, twidth
, pbwid
, a_index
, blen
):
617 self
.a_index
= a_index
622 self
.a
= Signal(twidth
//2, reset_less
=True)
623 self
.b
= Signal(twidth
//2, reset_less
=True)
624 self
.pb_en
= Signal(pbwid
, reset_less
=True)
625 self
.terms
= [Signal(twidth
, name
="term%d"%i, reset_less
=True) \
626 for i
in range(blen
)]
628 def elaborate(self
, platform
):
632 for b_index
in range(self
.blen
):
633 t
= ProductTerm(self
.pwidth
, self
.twidth
, self
.pbwid
,
634 self
.a_index
, b_index
)
635 setattr(m
.submodules
, "term_%d" % b_index
, t
)
637 m
.d
.comb
+= t
.a
.eq(self
.a
)
638 m
.d
.comb
+= t
.b
.eq(self
.b
)
639 m
.d
.comb
+= t
.pb_en
.eq(self
.pb_en
)
641 m
.d
.comb
+= self
.terms
[b_index
].eq(t
.term
)
646 class LSBNegTerm(Elaboratable
):
648 def __init__(self
, bit_width
):
649 self
.bit_width
= bit_width
650 self
.part
= Signal(reset_less
=True)
651 self
.signed
= Signal(reset_less
=True)
652 self
.op
= Signal(bit_width
, reset_less
=True)
653 self
.msb
= Signal(reset_less
=True)
654 self
.nt
= Signal(bit_width
*2, reset_less
=True)
655 self
.nl
= Signal(bit_width
*2, reset_less
=True)
657 def elaborate(self
, platform
):
660 bit_wid
= self
.bit_width
661 ext
= Repl(0, bit_wid
) # extend output to HI part
663 # determine sign of each incoming number *in this partition*
664 enabled
= Signal(reset_less
=True)
665 m
.d
.comb
+= enabled
.eq(self
.part
& self
.msb
& self
.signed
)
667 # for 8-bit values: form a * 0xFF00 by using -a * 0x100, the
668 # negation operation is split into a bitwise not and a +1.
669 # likewise for 16, 32, and 64-bit values.
671 # width-extended 1s complement if a is signed, otherwise zero
672 comb
+= self
.nt
.eq(Mux(enabled
, Cat(ext
, ~self
.op
), 0))
674 # add 1 if signed, otherwise add zero
675 comb
+= self
.nl
.eq(Cat(ext
, enabled
, Repl(0, bit_wid
-1)))
680 class Part(Elaboratable
):
681 """ a key class which, depending on the partitioning, will determine
682 what action to take when parts of the output are signed or unsigned.
684 this requires 2 pieces of data *per operand, per partition*:
685 whether the MSB is HI/LO (per partition!), and whether a signed
686 or unsigned operation has been *requested*.
688 once that is determined, signed is basically carried out
689 by splitting 2's complement into 1's complement plus one.
690 1's complement is just a bit-inversion.
692 the extra terms - as separate terms - are then thrown at the
693 AddReduce alongside the multiplication part-results.
695 def __init__(self
, width
, n_parts
, n_levels
, pbwid
):
700 self
.a_signed
= [Signal(name
=f
"a_signed_{i}") for i
in range(8)]
701 self
.b_signed
= [Signal(name
=f
"_b_signed_{i}") for i
in range(8)]
702 self
.pbs
= Signal(pbwid
, reset_less
=True)
705 self
.parts
= [Signal(name
=f
"part_{i}") for i
in range(n_parts
)]
706 self
.delayed_parts
= [
707 [Signal(name
=f
"delayed_part_{delay}_{i}")
708 for i
in range(n_parts
)]
709 for delay
in range(n_levels
)]
710 # XXX REALLY WEIRD BUG - have to take a copy of the last delayed_parts
711 self
.dplast
= [Signal(name
=f
"dplast_{i}")
712 for i
in range(n_parts
)]
714 self
.not_a_term
= Signal(width
)
715 self
.neg_lsb_a_term
= Signal(width
)
716 self
.not_b_term
= Signal(width
)
717 self
.neg_lsb_b_term
= Signal(width
)
719 def elaborate(self
, platform
):
722 pbs
, parts
, delayed_parts
= self
.pbs
, self
.parts
, self
.delayed_parts
723 # negated-temporary copy of partition bits
724 npbs
= Signal
.like(pbs
, reset_less
=True)
725 m
.d
.comb
+= npbs
.eq(~pbs
)
726 byte_count
= 8 // len(parts
)
727 for i
in range(len(parts
)):
729 pbl
.append(npbs
[i
* byte_count
- 1])
730 for j
in range(i
* byte_count
, (i
+ 1) * byte_count
- 1):
732 pbl
.append(npbs
[(i
+ 1) * byte_count
- 1])
733 value
= Signal(len(pbl
), name
="value_%di" % i
, reset_less
=True)
734 m
.d
.comb
+= value
.eq(Cat(*pbl
))
735 m
.d
.comb
+= parts
[i
].eq(~
(value
).bool())
736 m
.d
.comb
+= delayed_parts
[0][i
].eq(parts
[i
])
737 m
.d
.sync
+= [delayed_parts
[j
+ 1][i
].eq(delayed_parts
[j
][i
])
738 for j
in range(len(delayed_parts
)-1)]
739 m
.d
.comb
+= self
.dplast
[i
].eq(delayed_parts
[-1][i
])
741 not_a_term
, neg_lsb_a_term
, not_b_term
, neg_lsb_b_term
= \
742 self
.not_a_term
, self
.neg_lsb_a_term
, \
743 self
.not_b_term
, self
.neg_lsb_b_term
745 byte_width
= 8 // len(parts
) # byte width
746 bit_wid
= 8 * byte_width
# bit width
747 nat
, nbt
, nla
, nlb
= [], [], [], []
748 for i
in range(len(parts
)):
749 # work out bit-inverted and +1 term for a.
750 pa
= LSBNegTerm(bit_wid
)
751 setattr(m
.submodules
, "lnt_%d_a_%d" % (bit_wid
, i
), pa
)
752 m
.d
.comb
+= pa
.part
.eq(parts
[i
])
753 m
.d
.comb
+= pa
.op
.eq(self
.a
.bit_select(bit_wid
* i
, bit_wid
))
754 m
.d
.comb
+= pa
.signed
.eq(self
.b_signed
[i
* byte_width
]) # yes b
755 m
.d
.comb
+= pa
.msb
.eq(self
.b
[(i
+ 1) * bit_wid
- 1]) # really, b
759 # work out bit-inverted and +1 term for b
760 pb
= LSBNegTerm(bit_wid
)
761 setattr(m
.submodules
, "lnt_%d_b_%d" % (bit_wid
, i
), pb
)
762 m
.d
.comb
+= pb
.part
.eq(parts
[i
])
763 m
.d
.comb
+= pb
.op
.eq(self
.b
.bit_select(bit_wid
* i
, bit_wid
))
764 m
.d
.comb
+= pb
.signed
.eq(self
.a_signed
[i
* byte_width
]) # yes a
765 m
.d
.comb
+= pb
.msb
.eq(self
.a
[(i
+ 1) * bit_wid
- 1]) # really, a
769 # concatenate together and return all 4 results.
770 m
.d
.comb
+= [not_a_term
.eq(Cat(*nat
)),
771 not_b_term
.eq(Cat(*nbt
)),
772 neg_lsb_a_term
.eq(Cat(*nla
)),
773 neg_lsb_b_term
.eq(Cat(*nlb
)),
779 class IntermediateOut(Elaboratable
):
780 """ selects the HI/LO part of the multiplication, for a given bit-width
781 the output is also reconstructed in its SIMD (partition) lanes.
783 def __init__(self
, width
, out_wid
, n_parts
):
785 self
.n_parts
= n_parts
786 self
.delayed_part_ops
= [Signal(2, name
="dpop%d" % i
, reset_less
=True)
788 self
.intermed
= Signal(out_wid
, reset_less
=True)
789 self
.output
= Signal(out_wid
//2, reset_less
=True)
791 def elaborate(self
, platform
):
797 for i
in range(self
.n_parts
):
798 op
= Signal(w
, reset_less
=True, name
="op%d_%d" % (w
, i
))
800 Mux(self
.delayed_part_ops
[sel
* i
] == OP_MUL_LOW
,
801 self
.intermed
.bit_select(i
* w
*2, w
),
802 self
.intermed
.bit_select(i
* w
*2 + w
, w
)))
804 m
.d
.comb
+= self
.output
.eq(Cat(*ol
))
809 class FinalOut(Elaboratable
):
810 """ selects the final output based on the partitioning.
812 each byte is selectable independently, i.e. it is possible
813 that some partitions requested 8-bit computation whilst others
814 requested 16 or 32 bit.
816 def __init__(self
, out_wid
):
818 self
.d8
= [Signal(name
=f
"d8_{i}", reset_less
=True) for i
in range(8)]
819 self
.d16
= [Signal(name
=f
"d16_{i}", reset_less
=True) for i
in range(4)]
820 self
.d32
= [Signal(name
=f
"d32_{i}", reset_less
=True) for i
in range(2)]
822 self
.i8
= Signal(out_wid
, reset_less
=True)
823 self
.i16
= Signal(out_wid
, reset_less
=True)
824 self
.i32
= Signal(out_wid
, reset_less
=True)
825 self
.i64
= Signal(out_wid
, reset_less
=True)
828 self
.out
= Signal(out_wid
, reset_less
=True)
830 def elaborate(self
, platform
):
834 # select one of the outputs: d8 selects i8, d16 selects i16
835 # d32 selects i32, and the default is i64.
836 # d8 and d16 are ORed together in the first Mux
837 # then the 2nd selects either i8 or i16.
838 # if neither d8 nor d16 are set, d32 selects either i32 or i64.
839 op
= Signal(8, reset_less
=True, name
="op_%d" % i
)
841 Mux(self
.d8
[i
] | self
.d16
[i
// 2],
842 Mux(self
.d8
[i
], self
.i8
.bit_select(i
* 8, 8),
843 self
.i16
.bit_select(i
* 8, 8)),
844 Mux(self
.d32
[i
// 4], self
.i32
.bit_select(i
* 8, 8),
845 self
.i64
.bit_select(i
* 8, 8))))
847 m
.d
.comb
+= self
.out
.eq(Cat(*ol
))
851 class OrMod(Elaboratable
):
852 """ ORs four values together in a hierarchical tree
854 def __init__(self
, wid
):
856 self
.orin
= [Signal(wid
, name
="orin%d" % i
, reset_less
=True)
858 self
.orout
= Signal(wid
, reset_less
=True)
860 def elaborate(self
, platform
):
862 or1
= Signal(self
.wid
, reset_less
=True)
863 or2
= Signal(self
.wid
, reset_less
=True)
864 m
.d
.comb
+= or1
.eq(self
.orin
[0] | self
.orin
[1])
865 m
.d
.comb
+= or2
.eq(self
.orin
[2] | self
.orin
[3])
866 m
.d
.comb
+= self
.orout
.eq(or1 | or2
)
871 class Signs(Elaboratable
):
872 """ determines whether a or b are signed numbers
873 based on the required operation type (OP_MUL_*)
877 self
.part_ops
= Signal(2, reset_less
=True)
878 self
.a_signed
= Signal(reset_less
=True)
879 self
.b_signed
= Signal(reset_less
=True)
881 def elaborate(self
, platform
):
885 asig
= self
.part_ops
!= OP_MUL_UNSIGNED_HIGH
886 bsig
= (self
.part_ops
== OP_MUL_LOW
) \
887 |
(self
.part_ops
== OP_MUL_SIGNED_HIGH
)
888 m
.d
.comb
+= self
.a_signed
.eq(asig
)
889 m
.d
.comb
+= self
.b_signed
.eq(bsig
)
894 class Mul8_16_32_64(Elaboratable
):
895 """Signed/Unsigned 8/16/32/64-bit partitioned integer multiplier.
897 Supports partitioning into any combination of 8, 16, 32, and 64-bit
898 partitions on naturally-aligned boundaries. Supports the operation being
899 set for each partition independently.
901 :attribute part_pts: the input partition points. Has a partition point at
902 multiples of 8 in 0 < i < 64. Each partition point's associated
903 ``Value`` is a ``Signal``. Modification not supported, except for by
905 :attribute part_ops: the operation for each byte. The operation for a
906 particular partition is selected by assigning the selected operation
907 code to each byte in the partition. The allowed operation codes are:
909 :attribute OP_MUL_LOW: the LSB half of the product. Equivalent to
910 RISC-V's `mul` instruction.
911 :attribute OP_MUL_SIGNED_HIGH: the MSB half of the product where both
912 ``a`` and ``b`` are signed. Equivalent to RISC-V's `mulh`
914 :attribute OP_MUL_SIGNED_UNSIGNED_HIGH: the MSB half of the product
915 where ``a`` is signed and ``b`` is unsigned. Equivalent to RISC-V's
916 `mulhsu` instruction.
917 :attribute OP_MUL_UNSIGNED_HIGH: the MSB half of the product where both
918 ``a`` and ``b`` are unsigned. Equivalent to RISC-V's `mulhu`
922 def __init__(self
, register_levels
=()):
923 """ register_levels: specifies the points in the cascade at which
924 flip-flops are to be inserted.
928 self
.register_levels
= list(register_levels
)
931 self
.part_pts
= PartitionPoints()
932 for i
in range(8, 64, 8):
933 self
.part_pts
[i
] = Signal(name
=f
"part_pts_{i}")
934 self
.part_ops
= [Signal(2, name
=f
"part_ops_{i}") for i
in range(8)]
938 # intermediates (needed for unit tests)
939 self
._intermediate
_output
= Signal(128)
942 self
.output
= Signal(64)
944 def _part_byte(self
, index
):
945 if index
== -1 or index
== 7:
947 assert index
>= 0 and index
< 8
948 return self
.part_pts
[index
* 8 + 8]
950 def elaborate(self
, platform
):
954 pbs
= Signal(8, reset_less
=True)
957 pb
= Signal(name
="pb%d" % i
, reset_less
=True)
958 m
.d
.comb
+= pb
.eq(self
._part
_byte
(i
))
960 m
.d
.comb
+= pbs
.eq(Cat(*tl
))
967 setattr(m
.submodules
, "signs%d" % i
, s
)
968 m
.d
.comb
+= s
.part_ops
.eq(self
.part_ops
[i
])
971 [Signal(2, name
=f
"_delayed_part_ops_{delay}_{i}")
973 for delay
in range(1 + len(self
.register_levels
))]
974 for i
in range(len(self
.part_ops
)):
975 m
.d
.comb
+= delayed_part_ops
[0][i
].eq(self
.part_ops
[i
])
976 m
.d
.sync
+= [delayed_part_ops
[j
+ 1][i
].eq(delayed_part_ops
[j
][i
])
977 for j
in range(len(self
.register_levels
))]
979 n_levels
= len(self
.register_levels
)+1
980 m
.submodules
.part_8
= part_8
= Part(128, 8, n_levels
, 8)
981 m
.submodules
.part_16
= part_16
= Part(128, 4, n_levels
, 8)
982 m
.submodules
.part_32
= part_32
= Part(128, 2, n_levels
, 8)
983 m
.submodules
.part_64
= part_64
= Part(128, 1, n_levels
, 8)
984 nat_l
, nbt_l
, nla_l
, nlb_l
= [], [], [], []
985 for mod
in [part_8
, part_16
, part_32
, part_64
]:
986 m
.d
.comb
+= mod
.a
.eq(self
.a
)
987 m
.d
.comb
+= mod
.b
.eq(self
.b
)
988 for i
in range(len(signs
)):
989 m
.d
.comb
+= mod
.a_signed
[i
].eq(signs
[i
].a_signed
)
990 m
.d
.comb
+= mod
.b_signed
[i
].eq(signs
[i
].b_signed
)
991 m
.d
.comb
+= mod
.pbs
.eq(pbs
)
992 nat_l
.append(mod
.not_a_term
)
993 nbt_l
.append(mod
.not_b_term
)
994 nla_l
.append(mod
.neg_lsb_a_term
)
995 nlb_l
.append(mod
.neg_lsb_b_term
)
999 for a_index
in range(8):
1000 t
= ProductTerms(8, 128, 8, a_index
, 8)
1001 setattr(m
.submodules
, "terms_%d" % a_index
, t
)
1003 m
.d
.comb
+= t
.a
.eq(self
.a
)
1004 m
.d
.comb
+= t
.b
.eq(self
.b
)
1005 m
.d
.comb
+= t
.pb_en
.eq(pbs
)
1007 for term
in t
.terms
:
1010 # it's fine to bitwise-or data together since they are never enabled
1012 m
.submodules
.nat_or
= nat_or
= OrMod(128)
1013 m
.submodules
.nbt_or
= nbt_or
= OrMod(128)
1014 m
.submodules
.nla_or
= nla_or
= OrMod(128)
1015 m
.submodules
.nlb_or
= nlb_or
= OrMod(128)
1016 for l
, mod
in [(nat_l
, nat_or
),
1020 for i
in range(len(l
)):
1021 m
.d
.comb
+= mod
.orin
[i
].eq(l
[i
])
1022 terms
.append(mod
.orout
)
1024 expanded_part_pts
= PartitionPoints()
1025 for i
, v
in self
.part_pts
.items():
1026 signal
= Signal(name
=f
"expanded_part_pts_{i*2}", reset_less
=True)
1027 expanded_part_pts
[i
* 2] = signal
1028 m
.d
.comb
+= signal
.eq(v
)
1030 add_reduce
= AddReduce(terms
,
1032 self
.register_levels
,
1034 m
.submodules
.add_reduce
= add_reduce
1035 m
.d
.comb
+= self
._intermediate
_output
.eq(add_reduce
.output
)
1037 m
.submodules
.io64
= io64
= IntermediateOut(64, 128, 1)
1038 m
.d
.comb
+= io64
.intermed
.eq(self
._intermediate
_output
)
1040 m
.d
.comb
+= io64
.delayed_part_ops
[i
].eq(delayed_part_ops
[-1][i
])
1043 m
.submodules
.io32
= io32
= IntermediateOut(32, 128, 2)
1044 m
.d
.comb
+= io32
.intermed
.eq(self
._intermediate
_output
)
1046 m
.d
.comb
+= io32
.delayed_part_ops
[i
].eq(delayed_part_ops
[-1][i
])
1049 m
.submodules
.io16
= io16
= IntermediateOut(16, 128, 4)
1050 m
.d
.comb
+= io16
.intermed
.eq(self
._intermediate
_output
)
1052 m
.d
.comb
+= io16
.delayed_part_ops
[i
].eq(delayed_part_ops
[-1][i
])
1055 m
.submodules
.io8
= io8
= IntermediateOut(8, 128, 8)
1056 m
.d
.comb
+= io8
.intermed
.eq(self
._intermediate
_output
)
1058 m
.d
.comb
+= io8
.delayed_part_ops
[i
].eq(delayed_part_ops
[-1][i
])
1061 m
.submodules
.finalout
= finalout
= FinalOut(64)
1062 for i
in range(len(part_8
.delayed_parts
[-1])):
1063 m
.d
.comb
+= finalout
.d8
[i
].eq(part_8
.dplast
[i
])
1064 for i
in range(len(part_16
.delayed_parts
[-1])):
1065 m
.d
.comb
+= finalout
.d16
[i
].eq(part_16
.dplast
[i
])
1066 for i
in range(len(part_32
.delayed_parts
[-1])):
1067 m
.d
.comb
+= finalout
.d32
[i
].eq(part_32
.dplast
[i
])
1068 m
.d
.comb
+= finalout
.i8
.eq(io8
.output
)
1069 m
.d
.comb
+= finalout
.i16
.eq(io16
.output
)
1070 m
.d
.comb
+= finalout
.i32
.eq(io32
.output
)
1071 m
.d
.comb
+= finalout
.i64
.eq(io64
.output
)
1072 m
.d
.comb
+= self
.output
.eq(finalout
.out
)
1077 if __name__
== "__main__":
1081 m
._intermediate
_output
,
1084 *m
.part_pts
.values()])