1 # SPDX-License-Identifier: LGPL-2.1-or-later
2 # See Notices.txt for copyright information
3 """Integer Multiplication."""
5 from nmigen
import Signal
, Module
, Value
, Elaboratable
, Cat
, C
, Mux
, Repl
6 from nmigen
.hdl
.ast
import Assign
7 from abc
import ABCMeta
, abstractmethod
8 from nmigen
.cli
import main
9 from functools
import reduce
10 from operator
import or_
class PartitionPoints(dict):
    """Partition points and corresponding ``Value``s.

    Maps each bit index at which an ALU may be split to a ``Value`` that
    specifies whether that partition point is enabled.

    For example: ``{1: True, 5: True, 10: True}`` with
    ``width == 16`` specifies that the ALU is split into 4 sections:

    * bits 0 <= ``i`` < 1
    * bits 1 <= ``i`` < 5
    * bits 5 <= ``i`` < 10
    * bits 10 <= ``i`` < 16

    If the partition_points were instead ``{1: True, 5: a, 10: True}``
    where ``a`` is a 1-bit ``Signal``:

    * If ``a`` is asserted:
        * bits 0 <= ``i`` < 1
        * bits 1 <= ``i`` < 5
        * bits 5 <= ``i`` < 10
        * bits 10 <= ``i`` < 16
    * Otherwise:
        * bits 0 <= ``i`` < 1
        * bits 1 <= ``i`` < 10
        * bits 10 <= ``i`` < 16
    """

    def __init__(self, partition_points=None):
        """Create a new ``PartitionPoints``.

        :param partition_points: the input partition points to values mapping.
        :raises TypeError: if a point is not an ``int``.
        :raises ValueError: if a point is negative.
        """
        super().__init__()
        if partition_points is None:
            return
        for point, enabled in partition_points.items():
            if not isinstance(point, int):
                raise TypeError("point must be a non-negative integer")
            if point < 0:
                raise ValueError("point must be a non-negative integer")
            # every stored value is an nmigen ``Value``
            self[point] = Value.wrap(enabled)

    def like(self, name=None, src_loc_at=0):
        """Create a new ``PartitionPoints`` with ``Signal``s for all values.

        :param name: the base name for the new ``Signal``s.
        :param src_loc_at: extra stack depth to skip when the base name is
            taken from the caller's assignment target.
        """
        if name is None:
            # a throw-away Signal is created only to get the variable name
            name = Signal(src_loc_at=1+src_loc_at).name
        retval = PartitionPoints()
        for point, enabled in self.items():
            retval[point] = Signal(enabled.shape(), name=f"{name}_{point}")
        return retval

    def eq(self, rhs):
        """Assign ``PartitionPoints`` using ``Signal.eq``.

        This is a generator: the point-set check and the assignments are
        only produced once the result is iterated.

        :param rhs: mapping with exactly the same points as ``self``.
        :raises ValueError: if the two point sets differ.
        """
        if set(self) != set(rhs):
            raise ValueError("incompatible point set")
        for point, enabled in self.items():
            yield enabled.eq(rhs[point])

    def as_mask(self, width):
        """Create a bit-mask from `self`.

        Each bit in the returned mask is clear only if the partition point at
        the same bit-index is enabled.

        :param width: the bit width of the resulting mask
        """
        bits = [~self[i] if i in self else True for i in range(width)]
        return Cat(*bits)

    def get_max_partition_count(self, width):
        """Get the maximum number of partitions.

        Gets the number of partitions when all partition points are enabled.

        :param width: only points strictly below this width are counted.
        """
        return 1 + sum(1 for point in self if point < width)

    def fits_in_width(self, width):
        """Check if all partition points are smaller than `width`."""
        return all(point < width for point in self)
class FullAdder(Elaboratable):
    """Full Adder.

    :attribute in0: the first input
    :attribute in1: the second input
    :attribute in2: the third input
    :attribute sum: the sum output
    :attribute carry: the carry output

    Rather than do individual full adders (and have an array of them,
    which would be very slow to simulate), this module can specify the
    bit width of the inputs and outputs: in effect it performs multiple
    Full 3-2 Add operations "in parallel".
    """

    def __init__(self, width):
        """Create a ``FullAdder``.

        :param width: the bit width of the input and output
        """
        self.in0 = Signal(width)
        self.in1 = Signal(width)
        self.in2 = Signal(width)
        self.sum = Signal(width)
        self.carry = Signal(width)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        comb = m.d.comb
        in0, in1, in2 = self.in0, self.in1, self.in2
        # bitwise sum of the three inputs
        comb += self.sum.eq(in0 ^ in1 ^ in2)
        # carry: set wherever at least two of the three inputs are set
        comb += self.carry.eq((in0 & in1)
                              | (in1 & in2)
                              | (in2 & in0))
        return m
class MaskedFullAdder(Elaboratable):
    """Masked Full Adder.

    :attribute mask: the carry partition mask
    :attribute in0: the first input
    :attribute in1: the second input
    :attribute in2: the third input
    :attribute sum: the sum output
    :attribute mcarry: the masked carry output

    FullAdders are always used with a "mask" on the output.  To keep
    the graphviz "clean", this class performs the masking here rather
    than inside a large for-loop.

    See the following discussion as to why this is no longer derived
    from FullAdder.  Each carry is shifted here *before* being ANDed
    with the mask, so that an AOI cell may be used.

    https://en.wikipedia.org/wiki/AND-OR-Invert
    https://groups.google.com/d/msg/comp.arch/fcq-GLQqvas/vTxmcA0QAgAJ
    """

    def __init__(self, width):
        """Create a ``MaskedFullAdder``.

        :param width: the bit width of the input and output
        """
        self.width = width
        self.mask = Signal(width, reset_less=True)
        self.mcarry = Signal(width, reset_less=True)
        self.in0 = Signal(width, reset_less=True)
        self.in1 = Signal(width, reset_less=True)
        self.in2 = Signal(width, reset_less=True)
        self.sum = Signal(width, reset_less=True)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        comb = m.d.comb
        width = self.width
        # NOTE: the local names below double as the generated signal names
        s1 = Signal(width, reset_less=True)
        s2 = Signal(width, reset_less=True)
        s3 = Signal(width, reset_less=True)
        c1 = Signal(width, reset_less=True)
        c2 = Signal(width, reset_less=True)
        c3 = Signal(width, reset_less=True)
        comb += self.sum.eq(self.in0 ^ self.in1 ^ self.in2)
        # each input shifted up by one bit (a zero is inserted at bit 0)
        comb += s1.eq(Cat(0, self.in0))
        comb += s2.eq(Cat(0, self.in1))
        comb += s3.eq(Cat(0, self.in2))
        # pairwise ANDs of the shifted inputs, each gated by the mask
        comb += c1.eq(s1 & s2 & self.mask)
        comb += c2.eq(s2 & s3 & self.mask)
        comb += c3.eq(s3 & s1 & self.mask)
        comb += self.mcarry.eq(c1 | c2 | c3)
        return m
class PartitionedAdder(Elaboratable):
    """Partitioned Adder.

    Performs the final add.  The partition points are included in the
    actual add (in one of the operands only), which causes a carry over
    to the next bit.  Then the final output *removes* the extra bits from
    the result.

    partition: .... P... P... P... P... (32 bits)
    a        : .... .... .... .... .... (32 bits)
    b        : .... .... .... .... .... (32 bits)
    exp-a    : ....P....P....P....P.... (32+4 bits, P=1 if no partition)
    exp-b    : ....0....0....0....0.... (32 bits plus 4 zeros)
    exp-o    : ....xN...xN...xN...xN... (32+4 bits - x to be discarded)
    o        : .... N... N... N... N... (32 bits - x ignored, N is carry-over)

    :attribute width: the bit width of the input and output. Read-only.
    :attribute a: the first input to the adder
    :attribute b: the second input to the adder
    :attribute output: the sum output
    :attribute partition_points: the input partition points. Modification not
        supported, except for by ``Signal.eq``.
    """

    def __init__(self, width, partition_points):
        """Create a ``PartitionedAdder``.

        :param width: the bit width of the input and output
        :param partition_points: the input partition points
        :raises ValueError: if a partition point is not below ``width``.
        """
        self.width = width
        self.a = Signal(width)
        self.b = Signal(width)
        self.output = Signal(width)
        self.partition_points = PartitionPoints(partition_points)
        if not self.partition_points.fits_in_width(width):
            raise ValueError("partition_points doesn't fit in width")
        # one extra bit for every partition point that lies inside the width
        extra_bits = sum(1 for i in range(self.width)
                         if i in self.partition_points)
        expanded_width = self.width + extra_bits
        self._expanded_width = expanded_width
        # XXX these have to remain here due to some horrible nmigen
        # simulation bugs involving sync.  it is *not* necessary to
        # have them here, they should (under normal circumstances)
        # be moved into elaborate, as they are entirely local
        self._expanded_a = Signal(expanded_width)  # includes extra part-points
        self._expanded_b = Signal(expanded_width)  # likewise.
        self._expanded_o = Signal(expanded_width)  # likewise.

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        expanded_index = 0
        # store bits in a list, use Cat later.  graphviz is much cleaner
        al, bl, ol, ea, eb, eo = [], [], [], [], [], []

        # partition points are "breaks" (extra zeros or 1s) in what would
        # otherwise be a massive long add.  when the "break" points are 0,
        # whatever is in it (in the output) is discarded.  however when
        # there is a "1", it causes a roll-over carry to the *next* bit.
        # we still ignore the "break" bit in the [intermediate] output,
        # however by that time we've got the effect that we wanted: the
        # carry has been carried *over* the break point.

        for i in range(self.width):
            if i in self.partition_points:
                # add extra bit set to 0 + 0 for enabled partition points
                # and 1 + 0 for disabled partition points
                ea.append(self._expanded_a[expanded_index])
                al.append(~self.partition_points[i])  # add extra bit in a
                eb.append(self._expanded_b[expanded_index])
                bl.append(C(0))  # yes, add a zero
                expanded_index += 1  # skip the extra point.  NOT in the output
            # the ordinary data bit i of a, b and the output
            ea.append(self._expanded_a[expanded_index])
            eb.append(self._expanded_b[expanded_index])
            eo.append(self._expanded_o[expanded_index])
            al.append(self.a[i])
            bl.append(self.b[i])
            ol.append(self.output[i])
            expanded_index += 1

        # combine above using Cat
        m.d.comb += Cat(*ea).eq(Cat(*al))
        m.d.comb += Cat(*eb).eq(Cat(*bl))
        m.d.comb += Cat(*ol).eq(Cat(*eo))

        # use only one addition to take advantage of look-ahead carry and
        # special hardware on FPGAs
        m.d.comb += self._expanded_o.eq(
            self._expanded_a + self._expanded_b)
        return m
# Number of inputs each full adder consumes; AddReduceSingle uses it as the
# group stride and to work out how many inputs are left over per level.
FULL_ADDER_INPUT_COUNT = 3
class AddReduceSingle(Elaboratable):
    """Add list of numbers together.

    One level of the reduction: groups of three inputs are fed to
    ``MaskedFullAdder``s, whose sum and masked-carry outputs (plus any
    left-over inputs) become ``intermediate_terms`` for the next level.
    With two or fewer inputs the level produces ``output`` directly.

    :attribute inputs: input ``Signal``s to be summed. Modification not
        supported, except for by ``Signal.eq``.
    :attribute register_levels: List of nesting levels that should have
        pipeline registers.
    :attribute output: output sum.
    :attribute partition_points: the input partition points. Modification not
        supported, except for by ``Signal.eq``.
    """

    def __init__(self, inputs, output_width, register_levels, partition_points,
                 part_ops):
        """Create an ``AddReduce``.

        :param inputs: input ``Signal``s to be summed.
        :param output_width: bit-width of ``output``.
        :param register_levels: List of nesting levels that should have
            pipeline registers.
        :param partition_points: the input partition points.
        :param part_ops: per-byte operation codes, copied alongside the data.
        :raises ValueError: if the partition points do not fit in
            ``output_width`` or a register level exceeds the maximum level.
        """
        self.part_ops = part_ops
        self._part_ops = [Signal(2, name=f"part_ops_{i}")
                          for i in range(len(part_ops))]
        self.inputs = list(inputs)
        self._resized_inputs = [
            Signal(output_width, name=f"resized_inputs[{i}]")
            for i in range(len(self.inputs))]
        self.register_levels = list(register_levels)
        self.output = Signal(output_width)
        self.partition_points = PartitionPoints(partition_points)
        if not self.partition_points.fits_in_width(output_width):
            raise ValueError("partition_points doesn't fit in output_width")
        self._reg_partition_points = self.partition_points.like()

        max_level = AddReduceSingle.get_max_level(len(self.inputs))
        if any(level > max_level for level in self.register_levels):
            raise ValueError(
                "not enough adder levels for specified register levels")

        self.groups = AddReduceSingle.full_adder_groups(len(self.inputs))
        self._intermediate_terms = []
        if len(self.groups) != 0:
            self.create_next_terms()

    @staticmethod
    def get_max_level(input_count):
        """Get the maximum level.

        All ``register_levels`` must be less than or equal to the maximum
        level.
        """
        level = 0
        groups = AddReduceSingle.full_adder_groups(input_count)
        while len(groups) != 0:
            # each group of 3 becomes 2 terms; the remainder passes through
            input_count = (input_count % FULL_ADDER_INPUT_COUNT
                           + 2 * len(groups))
            level += 1
            groups = AddReduceSingle.full_adder_groups(input_count)
        return level

    @staticmethod
    def full_adder_groups(input_count):
        """Get ``inputs`` indices for which a full adder should be built."""
        last_start = input_count - FULL_ADDER_INPUT_COUNT + 1
        return range(0, last_start, FULL_ADDER_INPUT_COUNT)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()

        # resize inputs to correct bit-width and optionally add in
        # pipeline registers
        resized_input_assignments = [
            resized.eq(original)
            for resized, original in zip(self._resized_inputs, self.inputs)]
        copy_part_ops = [
            copied.eq(original)
            for copied, original in zip(self._part_ops, self.part_ops)]
        # level 0 registered => latch the inputs, otherwise pass them through
        domain = m.d.sync if 0 in self.register_levels else m.d.comb
        domain += copy_part_ops
        domain += resized_input_assignments
        domain += self._reg_partition_points.eq(self.partition_points)

        for value, term in self._intermediate_terms:
            m.d.comb += term.eq(value)

        # if there are no full adders to create, then we handle the base cases
        # and return, otherwise we go on to the recursive case
        if len(self.groups) == 0:
            n_inputs = len(self.inputs)
            if n_inputs == 0:
                # use 0 as the default output value
                m.d.comb += self.output.eq(0)
            elif n_inputs == 1:
                # handle single input
                m.d.comb += self.output.eq(self._resized_inputs[0])
            else:
                # base case for adding 2 inputs
                assert n_inputs == 2
                adder = PartitionedAdder(len(self.output),
                                         self._reg_partition_points)
                m.submodules.final_adder = adder
                m.d.comb += adder.a.eq(self._resized_inputs[0])
                m.d.comb += adder.b.eq(self._resized_inputs[1])
                m.d.comb += self.output.eq(adder.output)
            return m

        mask = self._reg_partition_points.as_mask(len(self.output))
        m.d.comb += self.part_mask.eq(mask)

        # add and link the intermediate term modules
        for i, (iidx, adder_i) in enumerate(self.adders):
            setattr(m.submodules, f"adder_{i}", adder_i)

            m.d.comb += adder_i.in0.eq(self._resized_inputs[iidx])
            m.d.comb += adder_i.in1.eq(self._resized_inputs[iidx + 1])
            m.d.comb += adder_i.in2.eq(self._resized_inputs[iidx + 2])
            m.d.comb += adder_i.mask.eq(self.part_mask)

        return m

    def create_next_terms(self):
        """Build this level's full adders and its ``intermediate_terms``."""
        # go on to prepare recursive case
        intermediate_terms = []
        _intermediate_terms = []

        def add_intermediate_term(value):
            intermediate_term = Signal(
                len(self.output),
                name=f"intermediate_terms[{len(intermediate_terms)}]")
            _intermediate_terms.append((value, intermediate_term))
            intermediate_terms.append(intermediate_term)

        # store mask in intermediary (simplifies graph)
        self.part_mask = Signal(len(self.output), reset_less=True)

        # create full adders for this recursive level.
        # this shrinks N terms to 2 * (N // 3) plus the remainder
        self.adders = []
        for i in self.groups:
            adder_i = MaskedFullAdder(len(self.output))
            self.adders.append((i, adder_i))
            # add both the sum and the masked-carry to the next level.
            # 3 inputs have now been reduced to 2...
            add_intermediate_term(adder_i.sum)
            add_intermediate_term(adder_i.mcarry)
        # handle the remaining inputs: just pass them to the next layer,
        # since we wouldn't gain anything by using a half adder (there would
        # still be 2 terms) and passing them through saves gates.
        leftover = len(self.inputs) % FULL_ADDER_INPUT_COUNT
        for idx in range(-leftover, 0):
            add_intermediate_term(self._resized_inputs[idx])

        self.intermediate_terms = intermediate_terms
        self._intermediate_terms = _intermediate_terms
class AddReduce(Elaboratable):
    """Recursively Add list of numbers together.

    Chains ``AddReduceSingle`` levels until a level has no full-adder
    groups left; that last level provides ``output``.

    :attribute inputs: input ``Signal``s to be summed. Modification not
        supported, except for by ``Signal.eq``.
    :attribute register_levels: List of nesting levels that should have
        pipeline registers.
    :attribute output: output sum.
    :attribute partition_points: the input partition points. Modification not
        supported, except for by ``Signal.eq``.
    """

    def __init__(self, inputs, output_width, register_levels, partition_points,
                 part_ops):
        """Create an ``AddReduce``.

        :param inputs: input ``Signal``s to be summed.
        :param output_width: bit-width of ``output``.
        :param register_levels: List of nesting levels that should have
            pipeline registers.
        :param partition_points: the input partition points.
        :param part_ops: per-byte operation codes handed to every level.
        """
        self.inputs = inputs
        self.part_ops = part_ops
        self.output = Signal(output_width)
        self.output_width = output_width
        self.register_levels = register_levels
        self.partition_points = partition_points

        self.create_levels()

    @staticmethod
    def next_register_levels(register_levels):
        """``Iterable`` of ``register_levels`` for next recursive level."""
        return (level - 1 for level in register_levels if level > 0)

    def create_levels(self):
        """creates reduction levels"""
        mods = []
        next_levels = self.register_levels
        partition_points = self.partition_points
        inputs = self.inputs
        finished = False
        while not finished:
            next_level = AddReduceSingle(inputs, self.output_width,
                                         next_levels, partition_points,
                                         self.part_ops)
            mods.append(next_level)
            if len(next_level.groups) == 0:
                # nothing left to reduce: this level is the final one
                finished = True
            else:
                next_levels = list(
                    AddReduce.next_register_levels(next_levels))
                partition_points = next_level._reg_partition_points
                inputs = next_level.intermediate_terms

        self.levels = mods

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()

        for i, next_level in enumerate(self.levels):
            setattr(m.submodules, "next_level%d" % i, next_level)

        # output comes from last module
        m.d.comb += self.output.eq(self.levels[-1].output)

        return m
# Operation codes selected per byte through ``part_ops`` (2-bit Signals).
# NOTE(review): ``OP_MUL_LOW`` is referenced by code in this file but its
# definition is not visible in this extract; presumably ``OP_MUL_LOW = 0``
# sits directly above -- confirm.
OP_MUL_SIGNED_HIGH = 1
OP_MUL_SIGNED_UNSIGNED_HIGH = 2  # a is signed, b is unsigned
OP_MUL_UNSIGNED_HIGH = 3
def get_term(value, shift=0, enabled=None):
    """Gate and left-shift a partial-product term.

    :param value: the term to place.
    :param shift: number of zero bits to put below ``value``; must be >= 0.
    :param enabled: optional enable; when given, the term is replaced by 0
        whenever ``enabled`` is not asserted.
    :returns: ``value`` unchanged when ``shift == 0`` and ``enabled`` is
        ``None``, otherwise the gated and/or shifted expression.
    :raises ValueError: if ``shift`` is negative.
    """
    # an explicit check (rather than ``assert``) so that a bad shift is
    # still rejected when Python runs with assertions disabled (-O)
    if shift < 0:
        raise ValueError("shift must be a non-negative integer")
    if enabled is not None:
        value = Mux(enabled, value, 0)
    if shift > 0:
        # ``shift`` zero bits in the LSBs, the term above them
        value = Cat(Repl(C(0, 1), shift), value)
    return value
class ProductTerm(Elaboratable):
    """ this class creates a single product term (a[..]*b[..]).
        it has a design flaw in that is the *output* that is selected,
        where the multiplication(s) are combinatorially generated
        all the time.
    """

    def __init__(self, width, twidth, pbwid, a_index, b_index):
        """Create a ``ProductTerm``.

        :param width: bit width of one slice of ``a`` / ``b``.
        :param twidth: bit width of the output ``term``; ``a`` and ``b``
            are ``twidth // 2`` bits wide.
        :param pbwid: bit width of the partition-enable input ``pb_en``.
        :param a_index: index of the slice taken from ``a``.
        :param b_index: index of the slice taken from ``b``.
        """
        self.a_index = a_index
        self.b_index = b_index
        shift = 8 * (self.a_index + self.b_index)
        # NOTE(review): the next four assignments are not visible in this
        # extract; they are reconstructed from how ``elaborate`` uses them
        # (``self.width`` assumed to be twice the slice width) -- confirm.
        self.pwidth = width
        self.twidth = twidth
        self.width = width*2
        self.shift = shift

        self.ti = Signal(self.width, reset_less=True)
        self.term = Signal(twidth, reset_less=True)
        self.a = Signal(twidth//2, reset_less=True)
        self.b = Signal(twidth//2, reset_less=True)
        self.pb_en = Signal(pbwid, reset_less=True)

        # partition-enable bits lying between the two slice indices
        self.tl = tl = []
        min_index = min(self.a_index, self.b_index)
        max_index = max(self.a_index, self.b_index)
        for i in range(min_index, max_index):
            tl.append(self.pb_en[i])
        name = "te_%d_%d" % (self.a_index, self.b_index)
        if len(tl) > 0:
            term_enabled = Signal(name=name, reset_less=True)
        else:
            # same slice index on both sides: the term is never gated
            term_enabled = None
        self.enabled = term_enabled
        self.term.name = "term_%d_%d" % (a_index, b_index)  # rename

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        if self.enabled is not None:
            # enabled only when none of the collected partition bits is set
            m.d.comb += self.enabled.eq(~(Cat(*self.tl).bool()))

        bsa = Signal(self.width, reset_less=True)
        bsb = Signal(self.width, reset_less=True)
        a_index, b_index = self.a_index, self.b_index
        pwidth = self.pwidth
        m.d.comb += bsa.eq(self.a.bit_select(a_index * pwidth, pwidth))
        m.d.comb += bsb.eq(self.b.bit_select(b_index * pwidth, pwidth))
        m.d.comb += self.ti.eq(bsa * bsb)
        m.d.comb += self.term.eq(get_term(self.ti, self.shift, self.enabled))

        # TODO: sort out width issues, get inputs a/b switched on/off
        # (gate the slices *before* multiplying instead of gating the
        # product); data going into Muxes is 1/2 the required width.

        return m
class ProductTerms(Elaboratable):
    """ creates a bank of product terms.  also performs the actual bit-selection
        this class is to be wrapped with a for-loop on the "a" operand.
        it creates a second-level for-loop on the "b" operand.
    """

    def __init__(self, width, twidth, pbwid, a_index, blen):
        """Create the bank for one ``a`` slice.

        :param width: bit width of one slice, handed to each ``ProductTerm``.
        :param twidth: bit width of every output term.
        :param pbwid: bit width of the partition-enable input ``pb_en``.
        :param a_index: index of the ``a`` slice shared by the whole bank.
        :param blen: number of ``b`` slices, i.e. number of terms created.
        """
        self.a_index = a_index
        self.blen = blen
        self.pwidth = width
        self.twidth = twidth
        self.pbwid = pbwid
        self.a = Signal(twidth//2, reset_less=True)
        self.b = Signal(twidth//2, reset_less=True)
        self.pb_en = Signal(pbwid, reset_less=True)
        self.terms = [Signal(twidth, name="term%d" % i, reset_less=True)
                      for i in range(blen)]

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        comb = m.d.comb

        for b_index, term_out in enumerate(self.terms):
            t = ProductTerm(self.pwidth, self.twidth, self.pbwid,
                            self.a_index, b_index)
            setattr(m.submodules, "term_%d" % b_index, t)

            # every term sees the full operands and partition enables
            comb += t.a.eq(self.a)
            comb += t.b.eq(self.b)
            comb += t.pb_en.eq(self.pb_en)

            comb += term_out.eq(t.term)

        return m
class LSBNegTerm(Elaboratable):
    """Correction terms for one signed operand within one partition.

    Produces ``nt`` (the bit-inverted operand, placed in the upper half)
    and ``nl`` (the matching "+1", placed at the lowest bit of the upper
    half); both are zero unless ``part``, ``msb`` and ``signed`` are all set.
    """

    def __init__(self, bit_width):
        """:param bit_width: bit width of the operand ``op``."""
        self.bit_width = bit_width
        self.part = Signal(reset_less=True)
        self.signed = Signal(reset_less=True)
        self.op = Signal(bit_width, reset_less=True)
        self.msb = Signal(reset_less=True)
        self.nt = Signal(bit_width*2, reset_less=True)
        self.nl = Signal(bit_width*2, reset_less=True)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        comb = m.d.comb
        bit_wid = self.bit_width
        ext = Repl(0, bit_wid)  # extend output to HI part

        # determine sign of each incoming number *in this partition*
        enabled = Signal(reset_less=True)
        comb += enabled.eq(self.part & self.msb & self.signed)

        # for 8-bit values: form a * 0xFF00 by using -a * 0x100, the
        # negation operation is split into a bitwise not and a +1.
        # likewise for 16, 32, and 64-bit values.

        # width-extended 1s complement if a is signed, otherwise zero
        inverted_hi = Cat(ext, ~self.op)
        comb += self.nt.eq(Mux(enabled, inverted_hi, 0))

        # add 1 if signed, otherwise add zero
        plus_one_hi = Cat(ext, enabled, Repl(0, bit_wid-1))
        comb += self.nl.eq(plus_one_hi)

        return m
class Part(Elaboratable):
    """ a key class which, depending on the partitioning, will determine
        what action to take when parts of the output are signed or unsigned.

        this requires 2 pieces of data *per operand, per partition*:
        whether the MSB is HI/LO (per partition!), and whether a signed
        or unsigned operation has been *requested*.

        once that is determined, signed is basically carried out
        by splitting 2's complement into 1's complement plus one.
        1's complement is just a bit-inversion.

        the extra terms - as separate terms - are then thrown at the
        AddReduce alongside the multiplication part-results.
    """

    def __init__(self, width, n_parts, n_levels, pbwid):
        """Create a ``Part``.

        :param width: bit width of the four correction-term outputs.
        :param n_parts: number of equal partitions the operands split into.
        :param n_levels: number of entries in the ``delayed_parts`` chain.
        :param pbwid: bit width of the partition-bits input ``pbs``.
        """
        # inputs
        # NOTE(review): the two operand declarations are not visible in this
        # extract; 64-bit widths are assumed from the enclosing multiplier.
        self.a = Signal(64)
        self.b = Signal(64)
        self.a_signed = [Signal(name=f"a_signed_{i}") for i in range(8)]
        self.b_signed = [Signal(name=f"_b_signed_{i}") for i in range(8)]
        self.pbs = Signal(pbwid, reset_less=True)

        # outputs
        self.parts = [Signal(name=f"part_{i}") for i in range(n_parts)]
        self.delayed_parts = [
            [Signal(name=f"delayed_part_{delay}_{i}")
             for i in range(n_parts)]
            for delay in range(n_levels)]
        # XXX REALLY WEIRD BUG - have to take a copy of the last delayed_parts
        self.dplast = [Signal(name=f"dplast_{i}")
                       for i in range(n_parts)]

        self.not_a_term = Signal(width)
        self.neg_lsb_a_term = Signal(width)
        self.not_b_term = Signal(width)
        self.neg_lsb_b_term = Signal(width)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        comb, sync = m.d.comb, m.d.sync

        pbs, parts, delayed_parts = self.pbs, self.parts, self.delayed_parts
        n_parts = len(parts)
        # negated-temporary copy of partition bits
        npbs = Signal.like(pbs, reset_less=True)
        comb += npbs.eq(~pbs)
        byte_count = 8 // n_parts
        for i in range(n_parts):
            lo_byte = i * byte_count
            hi_byte = (i + 1) * byte_count
            # negated boundary bit below, the interior bits as-is, then the
            # negated boundary bit above: all-zero means partition i is active
            pbl = [npbs[lo_byte - 1]]
            # NOTE(review): the statement filling the interior bits is not
            # visible in this extract; ``pbs[j]`` is assumed -- confirm.
            pbl.extend(pbs[j] for j in range(lo_byte, hi_byte - 1))
            pbl.append(npbs[hi_byte - 1])
            value = Signal(len(pbl), name="value_%di" % i, reset_less=True)
            comb += value.eq(Cat(*pbl))
            comb += parts[i].eq(~(value).bool())
            comb += delayed_parts[0][i].eq(parts[i])
            # one registered stage per further entry in the delay chain
            sync += [delayed_parts[j + 1][i].eq(delayed_parts[j][i])
                     for j in range(len(delayed_parts)-1)]
            comb += self.dplast[i].eq(delayed_parts[-1][i])

        byte_width = 8 // n_parts  # byte width
        bit_wid = 8 * byte_width  # bit width
        nat, nbt, nla, nlb = [], [], [], []
        for i in range(n_parts):
            lo_bit = bit_wid * i
            msb_bit = (i + 1) * bit_wid - 1

            # work out bit-inverted and +1 term for a.
            pa = LSBNegTerm(bit_wid)
            setattr(m.submodules, "lnt_%d_a_%d" % (bit_wid, i), pa)
            comb += pa.part.eq(parts[i])
            comb += pa.op.eq(self.a.bit_select(lo_bit, bit_wid))
            comb += pa.signed.eq(self.b_signed[i * byte_width])  # yes b
            comb += pa.msb.eq(self.b[msb_bit])  # really, b
            nat.append(pa.nt)
            nla.append(pa.nl)

            # work out bit-inverted and +1 term for b
            pb = LSBNegTerm(bit_wid)
            setattr(m.submodules, "lnt_%d_b_%d" % (bit_wid, i), pb)
            comb += pb.part.eq(parts[i])
            comb += pb.op.eq(self.b.bit_select(lo_bit, bit_wid))
            comb += pb.signed.eq(self.a_signed[i * byte_width])  # yes a
            comb += pb.msb.eq(self.a[msb_bit])  # really, a
            nbt.append(pb.nt)
            nlb.append(pb.nl)

        # concatenate together and return all 4 results.
        comb += [self.not_a_term.eq(Cat(*nat)),
                 self.not_b_term.eq(Cat(*nbt)),
                 self.neg_lsb_a_term.eq(Cat(*nla)),
                 self.neg_lsb_b_term.eq(Cat(*nlb)),
                 ]

        return m
class IntermediateOut(Elaboratable):
    """ selects the HI/LO part of the multiplication, for a given bit-width
        the output is also reconstructed in its SIMD (partition) lanes.
    """

    def __init__(self, width, out_wid, n_parts):
        """Create an ``IntermediateOut``.

        :param width: bit width of one lane of the result.
        :param out_wid: bit width of ``intermed``; ``output`` is half that.
        :param n_parts: number of lanes.
        """
        self.width = width
        self.n_parts = n_parts
        # NOTE(review): the list length is not visible in this extract;
        # 8 entries are assumed (one per byte) -- confirm.
        self.delayed_part_ops = [Signal(2, name="dpop%d" % i, reset_less=True)
                                 for i in range(8)]
        self.intermed = Signal(out_wid, reset_less=True)
        self.output = Signal(out_wid//2, reset_less=True)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()

        ol = []
        w = self.width
        # NOTE(review): ``sel`` is not visible in this extract; assumed to be
        # the number of bytes per lane -- confirm.
        sel = w // 8
        for i in range(self.n_parts):
            op = Signal(w, reset_less=True, name="op%d_%d" % (w, i))
            lo_half = self.intermed.bit_select(i * w*2, w)
            hi_half = self.intermed.bit_select(i * w*2 + w, w)
            # the op-code of the lane's first byte picks LO or HI half
            wants_low = self.delayed_part_ops[sel * i] == OP_MUL_LOW
            m.d.comb += op.eq(Mux(wants_low, lo_half, hi_half))
            ol.append(op)
        m.d.comb += self.output.eq(Cat(*ol))

        return m
class FinalOut(Elaboratable):
    """ selects the final output based on the partitioning.

        each byte is selectable independently, i.e. it is possible
        that some partitions requested 8-bit computation whilst others
        requested 16 or 32 bit.
    """

    def __init__(self, out_wid):
        """:param out_wid: bit width of each candidate input and of ``out``."""
        # inputs
        self.d8 = [Signal(name=f"d8_{i}", reset_less=True) for i in range(8)]
        self.d16 = [Signal(name=f"d16_{i}", reset_less=True) for i in range(4)]
        self.d32 = [Signal(name=f"d32_{i}", reset_less=True) for i in range(2)]

        self.i8 = Signal(out_wid, reset_less=True)
        self.i16 = Signal(out_wid, reset_less=True)
        self.i32 = Signal(out_wid, reset_less=True)
        self.i64 = Signal(out_wid, reset_less=True)

        # output
        self.out = Signal(out_wid, reset_less=True)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        ol = []
        for i in range(8):
            # select one of the outputs: d8 selects i8, d16 selects i16
            # d32 selects i32, and the default is i64.
            # d8 and d16 are ORed together in the first Mux
            # then the 2nd selects either i8 or i16.
            # if neither d8 nor d16 are set, d32 selects either i32 or i64.
            op = Signal(8, reset_less=True, name="op_%d" % i)
            lsb = i * 8
            narrow = Mux(self.d8[i],
                         self.i8.bit_select(lsb, 8),
                         self.i16.bit_select(lsb, 8))
            wide = Mux(self.d32[i // 4],
                       self.i32.bit_select(lsb, 8),
                       self.i64.bit_select(lsb, 8))
            m.d.comb += op.eq(Mux(self.d8[i] | self.d16[i // 2],
                                  narrow, wide))
            ol.append(op)
        m.d.comb += self.out.eq(Cat(*ol))
        return m
class OrMod(Elaboratable):
    """ ORs four values together in a hierarchical tree
    """

    def __init__(self, wid):
        """:param wid: bit width of the four inputs and of the output."""
        self.wid = wid
        self.orin = [Signal(wid, name="orin%d" % i, reset_less=True)
                     for i in range(4)]
        self.orout = Signal(wid, reset_less=True)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        comb = m.d.comb
        in0, in1, in2, in3 = self.orin
        # two pairwise ORs, then one OR of the two partial results
        or1 = Signal(self.wid, reset_less=True)
        or2 = Signal(self.wid, reset_less=True)
        comb += or1.eq(in0 | in1)
        comb += or2.eq(in2 | in3)
        comb += self.orout.eq(or1 | or2)
        return m
class Signs(Elaboratable):
    """ determines whether a or b are signed numbers
        based on the required operation type (OP_MUL_*)
    """

    def __init__(self):
        self.part_ops = Signal(2, reset_less=True)
        self.a_signed = Signal(reset_less=True)
        self.b_signed = Signal(reset_less=True)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        op = self.part_ops

        # a is signed for every operation except the unsigned-high one
        asig = op != OP_MUL_UNSIGNED_HIGH
        # b is signed only for the low and the signed-high operations
        bsig = (op == OP_MUL_LOW) | (op == OP_MUL_SIGNED_HIGH)
        m.d.comb += self.a_signed.eq(asig)
        m.d.comb += self.b_signed.eq(bsig)

        return m
class Mul8_16_32_64(Elaboratable):
    """Signed/Unsigned 8/16/32/64-bit partitioned integer multiplier.

    Supports partitioning into any combination of 8, 16, 32, and 64-bit
    partitions on naturally-aligned boundaries. Supports the operation being
    set for each partition independently.

    :attribute part_pts: the input partition points. Has a partition point at
        multiples of 8 in 0 < i < 64. Each partition point's associated
        ``Value`` is a ``Signal``. Modification not supported, except for by
        ``Signal.eq``.
    :attribute part_ops: the operation for each byte. The operation for a
        particular partition is selected by assigning the selected operation
        code to each byte in the partition. The allowed operation codes are:

        :attribute OP_MUL_LOW: the LSB half of the product. Equivalent to
            RISC-V's `mul` instruction.
        :attribute OP_MUL_SIGNED_HIGH: the MSB half of the product where both
            ``a`` and ``b`` are signed. Equivalent to RISC-V's `mulh`
            instruction.
        :attribute OP_MUL_SIGNED_UNSIGNED_HIGH: the MSB half of the product
            where ``a`` is signed and ``b`` is unsigned. Equivalent to RISC-V's
            `mulhsu` instruction.
        :attribute OP_MUL_UNSIGNED_HIGH: the MSB half of the product where both
            ``a`` and ``b`` are unsigned. Equivalent to RISC-V's `mulhu`
            instruction.
    """

    def __init__(self, register_levels=()):
        """ register_levels: specifies the points in the cascade at which
            flip-flops are to be inserted.
        """
        # parameter(s)
        self.register_levels = list(register_levels)

        # inputs
        self.part_pts = PartitionPoints()
        for i in range(8, 64, 8):
            self.part_pts[i] = Signal(name=f"part_pts_{i}")
        self.part_ops = [Signal(2, name=f"part_ops_{i}") for i in range(8)]
        # NOTE(review): the operand declarations are not visible in this
        # extract; 64-bit widths are assumed from the class purpose.
        self.a = Signal(64)
        self.b = Signal(64)

        # intermediates (needed for unit tests)
        self._intermediate_output = Signal(128)

        # output
        self.output = Signal(64)

    def _part_byte(self, index):
        """Partition-point value at the upper edge of byte ``index``.

        Indices -1 and 7 denote the outer edges and are always a boundary.
        """
        if index == -1 or index == 7:
            return C(True, 1)
        assert index >= 0 and index < 8
        return self.part_pts[index * 8 + 8]

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        comb = m.d.comb

        # collect part-bytes
        pbs = Signal(8, reset_less=True)
        tl = []
        for i in range(8):
            pb = Signal(name="pb%d" % i, reset_less=True)
            comb += pb.eq(self._part_byte(i))
            tl.append(pb)
        comb += pbs.eq(Cat(*tl))

        # local variables
        signs = []
        for i in range(8):
            s = Signs()
            signs.append(s)
            setattr(m.submodules, "signs%d" % i, s)
            comb += s.part_ops.eq(self.part_ops[i])

        # per-byte op-codes, delayed by one stage per register level
        n_regs = len(self.register_levels)
        delayed_part_ops = [
            [Signal(2, name=f"_delayed_part_ops_{delay}_{i}")
             for i in range(len(self.part_ops))]
            for delay in range(1 + n_regs)]
        for i in range(len(self.part_ops)):
            comb += delayed_part_ops[0][i].eq(self.part_ops[i])
            m.d.sync += [delayed_part_ops[j + 1][i].eq(delayed_part_ops[j][i])
                         for j in range(n_regs)]

        n_levels = n_regs + 1
        m.submodules.part_8 = part_8 = Part(128, 8, n_levels, 8)
        m.submodules.part_16 = part_16 = Part(128, 4, n_levels, 8)
        m.submodules.part_32 = part_32 = Part(128, 2, n_levels, 8)
        m.submodules.part_64 = part_64 = Part(128, 1, n_levels, 8)
        nat_l, nbt_l, nla_l, nlb_l = [], [], [], []
        for mod in [part_8, part_16, part_32, part_64]:
            comb += mod.a.eq(self.a)
            comb += mod.b.eq(self.b)
            for i in range(len(signs)):
                comb += mod.a_signed[i].eq(signs[i].a_signed)
                comb += mod.b_signed[i].eq(signs[i].b_signed)
            comb += mod.pbs.eq(pbs)
            nat_l.append(mod.not_a_term)
            nbt_l.append(mod.not_b_term)
            nla_l.append(mod.neg_lsb_a_term)
            nlb_l.append(mod.neg_lsb_b_term)

        terms = []

        for a_index in range(8):
            t = ProductTerms(8, 128, 8, a_index, 8)
            setattr(m.submodules, "terms_%d" % a_index, t)

            comb += t.a.eq(self.a)
            comb += t.b.eq(self.b)
            comb += t.pb_en.eq(pbs)

            terms.extend(t.terms)

        # it's fine to bitwise-or data together since they are never enabled
        # at the same time
        m.submodules.nat_or = nat_or = OrMod(128)
        m.submodules.nbt_or = nbt_or = OrMod(128)
        m.submodules.nla_or = nla_or = OrMod(128)
        m.submodules.nlb_or = nlb_or = OrMod(128)
        # NOTE(review): only the first pair of this list is visible in the
        # extract; the remaining order is assumed -- confirm.
        for l, mod in [(nat_l, nat_or),
                       (nbt_l, nbt_or),
                       (nla_l, nla_or),
                       (nlb_l, nlb_or)]:
            for i in range(len(l)):
                comb += mod.orin[i].eq(l[i])
            terms.append(mod.orout)

        # partition points scaled to the double-width intermediate result
        expanded_part_pts = PartitionPoints()
        for i, v in self.part_pts.items():
            signal = Signal(name=f"expanded_part_pts_{i*2}", reset_less=True)
            expanded_part_pts[i * 2] = signal
            comb += signal.eq(v)

        # NOTE(review): the width, partition-point and part_ops arguments
        # are not visible in this extract; reconstructed -- confirm.
        add_reduce = AddReduce(terms,
                               128,
                               self.register_levels,
                               expanded_part_pts,
                               self.part_ops)

        m.submodules.add_reduce = add_reduce
        comb += self._intermediate_output.eq(add_reduce.output)

        # one HI/LO selector per lane width, all fed from the same
        # intermediate result and the most-delayed op-codes
        selectors = {}
        for lane_width, n_parts in ((64, 1), (32, 2), (16, 4), (8, 8)):
            io = IntermediateOut(lane_width, 128, n_parts)
            setattr(m.submodules, "io%d" % lane_width, io)
            comb += io.intermed.eq(self._intermediate_output)
            for i in range(8):
                comb += io.delayed_part_ops[i].eq(delayed_part_ops[-1][i])
            selectors[lane_width] = io

        # final output
        m.submodules.finalout = finalout = FinalOut(64)
        for i in range(len(part_8.delayed_parts[-1])):
            comb += finalout.d8[i].eq(part_8.dplast[i])
        for i in range(len(part_16.delayed_parts[-1])):
            comb += finalout.d16[i].eq(part_16.dplast[i])
        for i in range(len(part_32.delayed_parts[-1])):
            comb += finalout.d32[i].eq(part_32.dplast[i])
        comb += finalout.i8.eq(selectors[8].output)
        comb += finalout.i16.eq(selectors[16].output)
        comb += finalout.i32.eq(selectors[32].output)
        comb += finalout.i64.eq(selectors[64].output)
        comb += self.output.eq(finalout.out)

        return m
if __name__ == "__main__":
    # hand the multiplier and its ports to the nmigen command-line driver
    m = Mul8_16_32_64()
    ports = [m.a, m.b, m._intermediate_output, m.output]
    ports.extend(m.part_ops)
    ports.extend(m.part_pts.values())
    main(m, ports=ports)