1 # SPDX-License-Identifier: LGPL-2.1-or-later
2 # See Notices.txt for copyright information
3 """Integer Multiplication."""
from abc import ABCMeta, abstractmethod
from functools import reduce
from operator import or_

from nmigen import Signal, Module, Value, Elaboratable, Cat, C, Mux, Repl
from nmigen.hdl.ast import Assign
from nmigen.cli import main
class PartitionPoints(dict):
    """Partition points and corresponding ``Value``s.

    The points at where an ALU is partitioned along with ``Value``s that
    specify if the corresponding partition points are enabled.

    For example: ``{1: True, 5: True, 10: True}`` with
    ``width == 16`` specifies that the ALU is split into 4 sections:

    * bits 0 <= ``i`` < 1
    * bits 1 <= ``i`` < 5
    * bits 5 <= ``i`` < 10
    * bits 10 <= ``i`` < 16

    If the partition_points were instead ``{1: True, 5: a, 10: True}``
    where ``a`` is a 1-bit ``Signal``:

    * If ``a`` is asserted:

        * bits 0 <= ``i`` < 1
        * bits 1 <= ``i`` < 5
        * bits 5 <= ``i`` < 10
        * bits 10 <= ``i`` < 16

    * Otherwise

        * bits 0 <= ``i`` < 1
        * bits 1 <= ``i`` < 10
        * bits 10 <= ``i`` < 16
    """

    def __init__(self, partition_points=None):
        """Create a new ``PartitionPoints``.

        :param partition_points: the input partition points to values mapping.
        :raises TypeError: if a key is not an ``int``
        :raises ValueError: if a key is negative
        """
        super().__init__()
        if partition_points is not None:
            for point, enabled in partition_points.items():
                if not isinstance(point, int):
                    raise TypeError("point must be a non-negative integer")
                if point < 0:
                    raise ValueError("point must be a non-negative integer")
                self[point] = Value.wrap(enabled)

    def like(self, name=None, src_loc_at=0, mul=1):
        """Create a new ``PartitionPoints`` with ``Signal``s for all values.

        :param name: the base name for the new ``Signal``s.
        :param mul: a multiplication factor on the indices
        """
        if name is None:
            name = Signal(src_loc_at=1+src_loc_at).name  # get variable name
        retval = PartitionPoints()
        for point, enabled in self.items():
            point *= mul
            retval[point] = Signal(enabled.shape(), name=f"{name}_{point}")
        return retval

    def eq(self, rhs):
        """Assign ``PartitionPoints`` using ``Signal.eq``."""
        if set(self.keys()) != set(rhs.keys()):
            raise ValueError("incompatible point set")
        for point, enabled in self.items():
            yield enabled.eq(rhs[point])

    def as_mask(self, width):
        """Create a bit-mask from `self`.

        Each bit in the returned mask is clear only if the partition point at
        the same bit-index is enabled.

        :param width: the bit width of the resulting mask
        """
        bits = []
        for i in range(width):
            if i in self:
                # mask bit is the inverse of the enable for this point
                bits.append(~self[i])
            else:
                bits.append(C(True, 1))
        return Cat(*bits)

    def get_max_partition_count(self, width):
        """Get the maximum number of partitions.

        Gets the number of partitions when all partition points are enabled.
        """
        # one partition exists even with no points; each in-range point
        # splits one more off
        retval = 1
        for point in self.keys():
            if point < width:
                retval += 1
        return retval

    def fits_in_width(self, width):
        """Check if all partition points are smaller than `width`."""
        for point in self.keys():
            if point >= width:
                return False
        return True

    def part_byte(self, index, mfactor=1):  # mfactor used for "expanding"
        """Get the enable ``Value`` for the partition point at byte `index`.

        Byte indices -1 and 7 are the outer boundaries and always enabled.
        """
        if index == -1 or index == 7:
            return C(True, 1)
        assert index >= 0 and index < 8
        return self[(index * 8 + 8)*mfactor]
class FullAdder(Elaboratable):
    """Full Adder.

    :attribute in0: the first input
    :attribute in1: the second input
    :attribute in2: the third input
    :attribute sum: the sum output
    :attribute carry: the carry output

    Rather than do individual full adders (and have an array of them,
    which would be very slow to simulate), this module can specify the
    bit width of the inputs and outputs: in effect it performs multiple
    Full 3-2 Add operations "in parallel".
    """

    def __init__(self, width):
        """Create a ``FullAdder``.

        :param width: the bit width of the input and output
        """
        self.in0 = Signal(width)
        self.in1 = Signal(width)
        self.in2 = Signal(width)
        self.sum = Signal(width)
        self.carry = Signal(width)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        # sum is the XOR of all three inputs; carry is the majority function
        m.d.comb += self.sum.eq(self.in0 ^ self.in1 ^ self.in2)
        m.d.comb += self.carry.eq((self.in0 & self.in1)
                                  | (self.in1 & self.in2)
                                  | (self.in2 & self.in0))
        return m
class MaskedFullAdder(Elaboratable):
    """Masked Full Adder.

    :attribute mask: the carry partition mask
    :attribute in0: the first input
    :attribute in1: the second input
    :attribute in2: the third input
    :attribute sum: the sum output
    :attribute mcarry: the masked carry output

    FullAdders are always used with a "mask" on the output.  To keep
    the graphviz "clean", this class performs the masking here rather
    than inside a large for-loop.

    See the following discussion as to why this is no longer derived
    from FullAdder.  Each carry is shifted here *before* being ANDed
    with the mask, so that an AOI cell may be used (which is more
    efficient):

    https://en.wikipedia.org/wiki/AND-OR-Invert
    https://groups.google.com/d/msg/comp.arch/fcq-GLQqvas/vTxmcA0QAgAJ
    """

    def __init__(self, width):
        """Create a ``MaskedFullAdder``.

        :param width: the bit width of the input and output
        """
        self.width = width
        self.mask = Signal(width, reset_less=True)
        self.mcarry = Signal(width, reset_less=True)
        self.in0 = Signal(width, reset_less=True)
        self.in1 = Signal(width, reset_less=True)
        self.in2 = Signal(width, reset_less=True)
        self.sum = Signal(width, reset_less=True)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        s1 = Signal(self.width, reset_less=True)
        s2 = Signal(self.width, reset_less=True)
        s3 = Signal(self.width, reset_less=True)
        c1 = Signal(self.width, reset_less=True)
        c2 = Signal(self.width, reset_less=True)
        c3 = Signal(self.width, reset_less=True)
        m.d.comb += self.sum.eq(self.in0 ^ self.in1 ^ self.in2)
        # pre-shift each input left by one (Cat with a leading 0) so the
        # carry lands on the correct bit before the mask is applied
        m.d.comb += s1.eq(Cat(0, self.in0))
        m.d.comb += s2.eq(Cat(0, self.in1))
        m.d.comb += s3.eq(Cat(0, self.in2))
        m.d.comb += c1.eq(s1 & s2 & self.mask)
        m.d.comb += c2.eq(s2 & s3 & self.mask)
        m.d.comb += c3.eq(s3 & s1 & self.mask)
        m.d.comb += self.mcarry.eq(c1 | c2 | c3)
        return m
class PartitionedAdder(Elaboratable):
    """Partitioned Adder.

    Performs the final add.  The partition points are included in the
    actual add (in one of the operands only), which causes a carry over
    to the next bit.  Then the final output *removes* the extra bits from
    the result.

    partition: .... P... P... P... P... (32 bits)
    a        : .... .... .... .... .... (32 bits)
    b        : .... .... .... .... .... (32 bits)
    exp-a    : ....P....P....P....P.... (32+4 bits, P=1 if no partition)
    exp-b    : ....0....0....0....0.... (32 bits plus 4 zeros)
    exp-o    : ....xN...xN...xN...xN... (32+4 bits - x to be discarded)
    o        : .... N... N... N... N... (32 bits - x ignored, N is carry-over)

    :attribute width: the bit width of the input and output. Read-only.
    :attribute a: the first input to the adder
    :attribute b: the second input to the adder
    :attribute output: the sum output
    :attribute partition_points: the input partition points. Modification not
        supported, except for by ``Signal.eq``.
    """

    def __init__(self, width, partition_points):
        """Create a ``PartitionedAdder``.

        :param width: the bit width of the input and output
        :param partition_points: the input partition points
        """
        self.width = width
        self.a = Signal(width)
        self.b = Signal(width)
        self.output = Signal(width)
        self.partition_points = PartitionPoints(partition_points)
        if not self.partition_points.fits_in_width(width):
            raise ValueError("partition_points doesn't fit in width")
        # the expanded operands carry one extra bit per partition point
        expanded_width = 0
        for i in range(self.width):
            if i in self.partition_points:
                expanded_width += 1
            expanded_width += 1
        self._expanded_width = expanded_width
        # XXX these have to remain here due to some horrible nmigen
        # simulation bugs involving sync.  it is *not* necessary to
        # have them here, they should (under normal circumstances)
        # be moved into elaborate, as they are entirely local
        self._expanded_a = Signal(expanded_width)  # includes extra part-points
        self._expanded_b = Signal(expanded_width)  # likewise.
        self._expanded_o = Signal(expanded_width)  # likewise.

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        expanded_index = 0
        # store bits in a list, use Cat later.  graphviz is much cleaner
        al, bl, ol, ea, eb, eo = [], [], [], [], [], []

        # partition points are "breaks" (extra zeros or 1s) in what would
        # otherwise be a massive long add.  when the "break" points are 0,
        # whatever is in it (in the output) is discarded.  however when
        # there is a "1", it causes a roll-over carry to the *next* bit.
        # we still ignore the "break" bit in the [intermediate] output,
        # however by that time we've got the effect that we wanted: the
        # carry has been carried *over* the break point.
        for i in range(self.width):
            if i in self.partition_points:
                # add extra bit set to 0 + 0 for enabled partition points
                # and 1 + 0 for disabled partition points
                ea.append(self._expanded_a[expanded_index])
                al.append(~self.partition_points[i])  # add extra bit in a
                eb.append(self._expanded_b[expanded_index])
                bl.append(C(0))  # yes, add a zero
                expanded_index += 1  # skip the extra point.  NOT in the output
            ea.append(self._expanded_a[expanded_index])
            eb.append(self._expanded_b[expanded_index])
            eo.append(self._expanded_o[expanded_index])
            al.append(self.a[i])
            bl.append(self.b[i])
            ol.append(self.output[i])
            expanded_index += 1

        # combine above using Cat
        m.d.comb += Cat(*ea).eq(Cat(*al))
        m.d.comb += Cat(*eb).eq(Cat(*bl))
        m.d.comb += Cat(*ol).eq(Cat(*eo))

        # use only one addition to take advantage of look-ahead carry and
        # special hardware on FPGAs
        m.d.comb += self._expanded_o.eq(
            self._expanded_a + self._expanded_b)
        return m
# a full (3:2 compressor) adder reduces exactly three inputs per stage
FULL_ADDER_INPUT_COUNT = 3
305 def __init__(self
, ppoints
, output_width
, n_parts
):
306 self
.part_ops
= [Signal(2, name
=f
"part_ops_{i}")
307 for i
in range(n_parts
)]
308 self
.inputs
= [Signal(output_width
, name
=f
"inputs[{i}]")
309 for i
in range(len(self
.inputs
))]
310 self
.reg_partition_points
= partition_points
.like()
313 return [self
.reg_partition_points
.eq(rhs
.reg_partition_points
)] + \
314 [self
.inputs
[i
].eq(rhs
.inputs
[i
])
315 for i
in range(len(self
.inputs
))] + \
316 [self
.part_ops
[i
].eq(rhs
.part_ops
[i
])
317 for i
in range(len(self
.part_ops
))]
320 class AddReduceSingle(Elaboratable
):
321 """Add list of numbers together.
323 :attribute inputs: input ``Signal``s to be summed. Modification not
324 supported, except for by ``Signal.eq``.
325 :attribute register_levels: List of nesting levels that should have
327 :attribute output: output sum.
328 :attribute partition_points: the input partition points. Modification not
329 supported, except for by ``Signal.eq``.
332 def __init__(self
, inputs
, output_width
, register_levels
, partition_points
,
334 """Create an ``AddReduce``.
336 :param inputs: input ``Signal``s to be summed.
337 :param output_width: bit-width of ``output``.
338 :param register_levels: List of nesting levels that should have
340 :param partition_points: the input partition points.
342 self
.part_ops
= part_ops
343 self
.out_part_ops
= [Signal(2, name
=f
"out_part_ops_{i}")
344 for i
in range(len(part_ops
))]
345 self
.inputs
= list(inputs
)
346 self
._resized
_inputs
= [
347 Signal(output_width
, name
=f
"resized_inputs[{i}]")
348 for i
in range(len(self
.inputs
))]
349 self
.register_levels
= list(register_levels
)
350 self
.output
= Signal(output_width
)
351 self
.partition_points
= PartitionPoints(partition_points
)
352 if not self
.partition_points
.fits_in_width(output_width
):
353 raise ValueError("partition_points doesn't fit in output_width")
354 self
._reg
_partition
_points
= self
.partition_points
.like()
356 max_level
= AddReduceSingle
.get_max_level(len(self
.inputs
))
357 for level
in self
.register_levels
:
358 if level
> max_level
:
360 "not enough adder levels for specified register levels")
362 # this is annoying. we have to create the modules (and terms)
363 # because we need to know what they are (in order to set up the
364 # interconnects back in AddReduce), but cannot do the m.d.comb +=
365 # etc because this is not in elaboratable.
366 self
.groups
= AddReduceSingle
.full_adder_groups(len(self
.inputs
))
367 self
._intermediate
_terms
= []
368 if len(self
.groups
) != 0:
369 self
.create_next_terms()
372 def get_max_level(input_count
):
373 """Get the maximum level.
375 All ``register_levels`` must be less than or equal to the maximum
380 groups
= AddReduceSingle
.full_adder_groups(input_count
)
383 input_count
%= FULL_ADDER_INPUT_COUNT
384 input_count
+= 2 * len(groups
)
388 def full_adder_groups(input_count
):
389 """Get ``inputs`` indices for which a full adder should be built."""
391 input_count
- FULL_ADDER_INPUT_COUNT
+ 1,
392 FULL_ADDER_INPUT_COUNT
)
394 def elaborate(self
, platform
):
395 """Elaborate this module."""
398 # resize inputs to correct bit-width and optionally add in
400 resized_input_assignments
= [self
._resized
_inputs
[i
].eq(self
.inputs
[i
])
401 for i
in range(len(self
.inputs
))]
402 copy_part_ops
= [self
.out_part_ops
[i
].eq(self
.part_ops
[i
])
403 for i
in range(len(self
.part_ops
))]
404 if 0 in self
.register_levels
:
405 m
.d
.sync
+= copy_part_ops
406 m
.d
.sync
+= resized_input_assignments
407 m
.d
.sync
+= self
._reg
_partition
_points
.eq(self
.partition_points
)
409 m
.d
.comb
+= copy_part_ops
410 m
.d
.comb
+= resized_input_assignments
411 m
.d
.comb
+= self
._reg
_partition
_points
.eq(self
.partition_points
)
413 for (value
, term
) in self
._intermediate
_terms
:
414 m
.d
.comb
+= term
.eq(value
)
416 # if there are no full adders to create, then we handle the base cases
417 # and return, otherwise we go on to the recursive case
418 if len(self
.groups
) == 0:
419 if len(self
.inputs
) == 0:
420 # use 0 as the default output value
421 m
.d
.comb
+= self
.output
.eq(0)
422 elif len(self
.inputs
) == 1:
423 # handle single input
424 m
.d
.comb
+= self
.output
.eq(self
._resized
_inputs
[0])
426 # base case for adding 2 inputs
427 assert len(self
.inputs
) == 2
428 adder
= PartitionedAdder(len(self
.output
),
429 self
._reg
_partition
_points
)
430 m
.submodules
.final_adder
= adder
431 m
.d
.comb
+= adder
.a
.eq(self
._resized
_inputs
[0])
432 m
.d
.comb
+= adder
.b
.eq(self
._resized
_inputs
[1])
433 m
.d
.comb
+= self
.output
.eq(adder
.output
)
436 mask
= self
._reg
_partition
_points
.as_mask(len(self
.output
))
437 m
.d
.comb
+= self
.part_mask
.eq(mask
)
439 # add and link the intermediate term modules
440 for i
, (iidx
, adder_i
) in enumerate(self
.adders
):
441 setattr(m
.submodules
, f
"adder_{i}", adder_i
)
443 m
.d
.comb
+= adder_i
.in0
.eq(self
._resized
_inputs
[iidx
])
444 m
.d
.comb
+= adder_i
.in1
.eq(self
._resized
_inputs
[iidx
+ 1])
445 m
.d
.comb
+= adder_i
.in2
.eq(self
._resized
_inputs
[iidx
+ 2])
446 m
.d
.comb
+= adder_i
.mask
.eq(self
.part_mask
)
450 def create_next_terms(self
):
452 # go on to prepare recursive case
453 intermediate_terms
= []
454 _intermediate_terms
= []
456 def add_intermediate_term(value
):
457 intermediate_term
= Signal(
459 name
=f
"intermediate_terms[{len(intermediate_terms)}]")
460 _intermediate_terms
.append((value
, intermediate_term
))
461 intermediate_terms
.append(intermediate_term
)
463 # store mask in intermediary (simplifies graph)
464 self
.part_mask
= Signal(len(self
.output
), reset_less
=True)
466 # create full adders for this recursive level.
467 # this shrinks N terms to 2 * (N // 3) plus the remainder
469 for i
in self
.groups
:
470 adder_i
= MaskedFullAdder(len(self
.output
))
471 self
.adders
.append((i
, adder_i
))
472 # add both the sum and the masked-carry to the next level.
473 # 3 inputs have now been reduced to 2...
474 add_intermediate_term(adder_i
.sum)
475 add_intermediate_term(adder_i
.mcarry
)
476 # handle the remaining inputs.
477 if len(self
.inputs
) % FULL_ADDER_INPUT_COUNT
== 1:
478 add_intermediate_term(self
._resized
_inputs
[-1])
479 elif len(self
.inputs
) % FULL_ADDER_INPUT_COUNT
== 2:
480 # Just pass the terms to the next layer, since we wouldn't gain
481 # anything by using a half adder since there would still be 2 terms
482 # and just passing the terms to the next layer saves gates.
483 add_intermediate_term(self
._resized
_inputs
[-2])
484 add_intermediate_term(self
._resized
_inputs
[-1])
486 assert len(self
.inputs
) % FULL_ADDER_INPUT_COUNT
== 0
488 self
.intermediate_terms
= intermediate_terms
489 self
._intermediate
_terms
= _intermediate_terms
492 class AddReduce(Elaboratable
):
493 """Recursively Add list of numbers together.
495 :attribute inputs: input ``Signal``s to be summed. Modification not
496 supported, except for by ``Signal.eq``.
497 :attribute register_levels: List of nesting levels that should have
499 :attribute output: output sum.
500 :attribute partition_points: the input partition points. Modification not
501 supported, except for by ``Signal.eq``.
504 def __init__(self
, inputs
, output_width
, register_levels
, partition_points
,
506 """Create an ``AddReduce``.
508 :param inputs: input ``Signal``s to be summed.
509 :param output_width: bit-width of ``output``.
510 :param register_levels: List of nesting levels that should have
512 :param partition_points: the input partition points.
515 self
.part_ops
= part_ops
516 self
.out_part_ops
= [Signal(2, name
=f
"out_part_ops_{i}")
517 for i
in range(len(part_ops
))]
518 self
.output
= Signal(output_width
)
519 self
.output_width
= output_width
520 self
.register_levels
= register_levels
521 self
.partition_points
= partition_points
526 def get_max_level(input_count
):
527 return AddReduceSingle
.get_max_level(input_count
)
530 def next_register_levels(register_levels
):
531 """``Iterable`` of ``register_levels`` for next recursive level."""
532 for level
in register_levels
:
536 def create_levels(self
):
537 """creates reduction levels"""
540 next_levels
= self
.register_levels
541 partition_points
= self
.partition_points
543 part_ops
= self
.part_ops
545 next_level
= AddReduceSingle(inputs
, self
.output_width
, next_levels
,
546 partition_points
, part_ops
)
547 mods
.append(next_level
)
548 if len(next_level
.groups
) == 0:
550 next_levels
= list(AddReduce
.next_register_levels(next_levels
))
551 partition_points
= next_level
._reg
_partition
_points
552 inputs
= next_level
.intermediate_terms
553 part_ops
= next_level
.out_part_ops
557 def elaborate(self
, platform
):
558 """Elaborate this module."""
561 for i
, next_level
in enumerate(self
.levels
):
562 setattr(m
.submodules
, "next_level%d" % i
, next_level
)
564 # output comes from last module
565 m
.d
.comb
+= self
.output
.eq(next_level
.output
)
566 copy_part_ops
= [self
.out_part_ops
[i
].eq(next_level
.out_part_ops
[i
])
567 for i
in range(len(self
.part_ops
))]
568 m
.d
.comb
+= copy_part_ops
# operation codes for each partitioned multiply byte-lane.
# OP_MUL_LOW was missing from the fragmented source but is referenced by
# IntermediateOut and Signs below, so it is restored here.
OP_MUL_LOW = 0
OP_MUL_SIGNED_HIGH = 1
OP_MUL_SIGNED_UNSIGNED_HIGH = 2  # a is signed, b is unsigned
OP_MUL_UNSIGNED_HIGH = 3
def get_term(value, shift=0, enabled=None):
    """Optionally gate `value` with `enabled` and shift it left by `shift`.

    :param value: the term to process
    :param shift: number of zero bits to prepend (a left shift via Cat)
    :param enabled: if not None, a 1-bit enable; the term is zeroed when low
    """
    if enabled is not None:
        value = Mux(enabled, value, 0)
    if shift > 0:
        value = Cat(Repl(C(0, 1), shift), value)
    return value
class ProductTerm(Elaboratable):
    """ this class creates a single product term (a[..]*b[..]).
        it has a design flaw in that is the *output* that is selected,
        where the multiplication(s) are combinatorially generated
        all the time.
    """

    def __init__(self, width, twidth, pbwid, a_index, b_index):
        """Create a ``ProductTerm``.

        :param width: bit width of one partial operand (one byte-lane group)
        :param twidth: bit width of the full term
        :param pbwid: number of partition-enable bits
        :param a_index: byte-lane index into operand a
        :param b_index: byte-lane index into operand b
        """
        self.a_index = a_index
        self.b_index = b_index
        shift = 8 * (self.a_index + self.b_index)
        self.pwidth = width
        self.twidth = twidth
        self.width = width*2
        self.shift = shift

        self.ti = Signal(self.width, reset_less=True)
        self.term = Signal(twidth, reset_less=True)
        self.a = Signal(twidth//2, reset_less=True)
        self.b = Signal(twidth//2, reset_less=True)
        self.pb_en = Signal(pbwid, reset_less=True)

        # collect the partition-enable bits that lie between the two
        # byte-lanes: if any is set, this cross-term must be disabled
        self.tl = tl = []
        min_index = min(self.a_index, self.b_index)
        max_index = max(self.a_index, self.b_index)
        for i in range(min_index, max_index):
            tl.append(self.pb_en[i])
        name = "te_%d_%d" % (self.a_index, self.b_index)
        if len(tl) > 0:
            term_enabled = Signal(name=name, reset_less=True)
        else:
            term_enabled = None
        self.enabled = term_enabled
        self.term.name = "term_%d_%d" % (a_index, b_index)  # rename

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        if self.enabled is not None:
            # enabled only when no intervening partition bit is set
            m.d.comb += self.enabled.eq(~(Cat(*self.tl).bool()))

        bsa = Signal(self.width, reset_less=True)
        bsb = Signal(self.width, reset_less=True)
        a_index, b_index = self.a_index, self.b_index
        pwidth = self.pwidth
        m.d.comb += bsa.eq(self.a.part(a_index * pwidth, pwidth))
        m.d.comb += bsb.eq(self.b.part(b_index * pwidth, pwidth))
        m.d.comb += self.ti.eq(bsa * bsb)
        m.d.comb += self.term.eq(get_term(self.ti, self.shift, self.enabled))
        """
        #TODO: sort out width issues, get inputs a/b switched on/off.
        #data going into Muxes is 1/2 the required width

        bsa = Signal(self.twidth//2, reset_less=True)
        bsb = Signal(self.twidth//2, reset_less=True)
        asel = Signal(width, reset_less=True)
        bsel = Signal(width, reset_less=True)
        a_index, b_index = self.a_index, self.b_index
        m.d.comb += asel.eq(self.a.part(a_index * pwidth, pwidth))
        m.d.comb += bsel.eq(self.b.part(b_index * pwidth, pwidth))
        m.d.comb += bsa.eq(get_term(asel, self.shift, self.enabled))
        m.d.comb += bsb.eq(get_term(bsel, self.shift, self.enabled))
        m.d.comb += self.ti.eq(bsa * bsb)
        m.d.comb += self.term.eq(self.ti)
        """
        return m
class ProductTerms(Elaboratable):
    """ creates a bank of product terms.  also performs the actual bit-selection
        this class is to be wrapped with a for-loop on the "a" operand.
        it creates a second-level for-loop on the "b" operand.
    """
    def __init__(self, width, twidth, pbwid, a_index, blen):
        """Create a ``ProductTerms``.

        :param width: bit width of one partial operand
        :param twidth: bit width of each full term
        :param pbwid: number of partition-enable bits
        :param a_index: fixed byte-lane index into operand a
        :param blen: number of b byte-lanes to generate terms for
        """
        self.a_index = a_index
        self.blen = blen
        self.pwidth = width
        self.twidth = twidth
        self.pbwid = pbwid
        self.a = Signal(twidth//2, reset_less=True)
        self.b = Signal(twidth//2, reset_less=True)
        self.pb_en = Signal(pbwid, reset_less=True)
        self.terms = [Signal(twidth, name="term%d" % i, reset_less=True)
                      for i in range(blen)]

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()

        for b_index in range(self.blen):
            t = ProductTerm(self.pwidth, self.twidth, self.pbwid,
                            self.a_index, b_index)
            setattr(m.submodules, "term_%d" % b_index, t)

            m.d.comb += t.a.eq(self.a)
            m.d.comb += t.b.eq(self.b)
            m.d.comb += t.pb_en.eq(self.pb_en)

            m.d.comb += self.terms[b_index].eq(t.term)

        return m
class LSBNegTerm(Elaboratable):
    """Generates the two extra terms needed to sign-correct one partition.

    Splits the negation of a signed operand into a width-extended 1's
    complement term (``nt``) and a "+1" term (``nl``).
    """

    def __init__(self, bit_width):
        """Create an ``LSBNegTerm``.

        :param bit_width: bit width of the partition being corrected
        """
        self.bit_width = bit_width
        self.part = Signal(reset_less=True)      # partition-active flag
        self.signed = Signal(reset_less=True)    # operation requests signed
        self.op = Signal(bit_width, reset_less=True)   # the operand slice
        self.msb = Signal(reset_less=True)       # MSB of the *other* operand
        self.nt = Signal(bit_width*2, reset_less=True)  # 1s-complement term
        self.nl = Signal(bit_width*2, reset_less=True)  # the "+1" term

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        comb = m.d.comb
        bit_wid = self.bit_width
        ext = Repl(0, bit_wid)  # extend output to HI part

        # determine sign of each incoming number *in this partition*
        enabled = Signal(reset_less=True)
        m.d.comb += enabled.eq(self.part & self.msb & self.signed)

        # for 8-bit values: form a * 0xFF00 by using -a * 0x100, the
        # negation operation is split into a bitwise not and a +1.
        # likewise for 16, 32, and 64-bit values.

        # width-extended 1s complement if a is signed, otherwise zero
        comb += self.nt.eq(Mux(enabled, Cat(ext, ~self.op), 0))

        # add 1 if signed, otherwise add zero
        comb += self.nl.eq(Cat(ext, enabled, Repl(0, bit_wid-1)))

        return m
class Parts(Elaboratable):
    """Decodes the expanded partition points into per-partition flags.

    Each ``parts[i]`` output is asserted when the partition boundaries
    exactly enclose lane ``i`` (boundary bits set, interior bits clear).
    """

    def __init__(self, pbwid, epps, n_parts):
        """Create a ``Parts``.

        :param pbwid: number of partition bits
        :param epps: expanded partition points to copy the shape of
        :param n_parts: number of partition lanes to decode
        """
        self.pbwid = pbwid
        # inputs
        self.epps = PartitionPoints.like(epps, name="epps")  # expanded points
        # outputs
        self.parts = [Signal(name=f"part_{i}") for i in range(n_parts)]

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()

        epps, parts = self.epps, self.parts
        # collect part-bytes (double factor because the input is extended)
        pbs = Signal(self.pbwid, reset_less=True)
        tl = []
        for i in range(self.pbwid):
            pb = Signal(name="pb%d" % i, reset_less=True)
            m.d.comb += pb.eq(epps.part_byte(i, mfactor=2))  # double
            tl.append(pb)
        m.d.comb += pbs.eq(Cat(*tl))

        # negated-temporary copy of partition bits
        npbs = Signal.like(pbs, reset_less=True)
        m.d.comb += npbs.eq(~pbs)
        byte_count = 8 // len(parts)
        for i in range(len(parts)):
            # boundary bits must be *set* (use negated copy), interior
            # bits must be *clear*; the whole pattern ORs to zero when
            # this partition size is selected
            pbl = []
            pbl.append(npbs[i * byte_count - 1])
            for j in range(i * byte_count, (i + 1) * byte_count - 1):
                pbl.append(pbs[j])
            pbl.append(npbs[(i + 1) * byte_count - 1])
            value = Signal(len(pbl), name="value_%d" % i, reset_less=True)
            m.d.comb += value.eq(Cat(*pbl))
            m.d.comb += parts[i].eq(~(value).bool())

        return m
768 class Part(Elaboratable
):
769 """ a key class which, depending on the partitioning, will determine
770 what action to take when parts of the output are signed or unsigned.
772 this requires 2 pieces of data *per operand, per partition*:
773 whether the MSB is HI/LO (per partition!), and whether a signed
774 or unsigned operation has been *requested*.
776 once that is determined, signed is basically carried out
777 by splitting 2's complement into 1's complement plus one.
778 1's complement is just a bit-inversion.
780 the extra terms - as separate terms - are then thrown at the
781 AddReduce alongside the multiplication part-results.
783 def __init__(self
, epps
, width
, n_parts
, n_levels
, pbwid
):
791 self
.a_signed
= [Signal(name
=f
"a_signed_{i}") for i
in range(8)]
792 self
.b_signed
= [Signal(name
=f
"_b_signed_{i}") for i
in range(8)]
793 self
.pbs
= Signal(pbwid
, reset_less
=True)
796 self
.parts
= [Signal(name
=f
"part_{i}") for i
in range(n_parts
)]
798 self
.not_a_term
= Signal(width
)
799 self
.neg_lsb_a_term
= Signal(width
)
800 self
.not_b_term
= Signal(width
)
801 self
.neg_lsb_b_term
= Signal(width
)
803 def elaborate(self
, platform
):
806 pbs
, parts
= self
.pbs
, self
.parts
808 m
.submodules
.p
= p
= Parts(self
.pbwid
, epps
, len(parts
))
809 m
.d
.comb
+= p
.epps
.eq(epps
)
812 byte_count
= 8 // len(parts
)
814 not_a_term
, neg_lsb_a_term
, not_b_term
, neg_lsb_b_term
= (
815 self
.not_a_term
, self
.neg_lsb_a_term
,
816 self
.not_b_term
, self
.neg_lsb_b_term
)
818 byte_width
= 8 // len(parts
) # byte width
819 bit_wid
= 8 * byte_width
# bit width
820 nat
, nbt
, nla
, nlb
= [], [], [], []
821 for i
in range(len(parts
)):
822 # work out bit-inverted and +1 term for a.
823 pa
= LSBNegTerm(bit_wid
)
824 setattr(m
.submodules
, "lnt_%d_a_%d" % (bit_wid
, i
), pa
)
825 m
.d
.comb
+= pa
.part
.eq(parts
[i
])
826 m
.d
.comb
+= pa
.op
.eq(self
.a
.part(bit_wid
* i
, bit_wid
))
827 m
.d
.comb
+= pa
.signed
.eq(self
.b_signed
[i
* byte_width
]) # yes b
828 m
.d
.comb
+= pa
.msb
.eq(self
.b
[(i
+ 1) * bit_wid
- 1]) # really, b
832 # work out bit-inverted and +1 term for b
833 pb
= LSBNegTerm(bit_wid
)
834 setattr(m
.submodules
, "lnt_%d_b_%d" % (bit_wid
, i
), pb
)
835 m
.d
.comb
+= pb
.part
.eq(parts
[i
])
836 m
.d
.comb
+= pb
.op
.eq(self
.b
.part(bit_wid
* i
, bit_wid
))
837 m
.d
.comb
+= pb
.signed
.eq(self
.a_signed
[i
* byte_width
]) # yes a
838 m
.d
.comb
+= pb
.msb
.eq(self
.a
[(i
+ 1) * bit_wid
- 1]) # really, a
842 # concatenate together and return all 4 results.
843 m
.d
.comb
+= [not_a_term
.eq(Cat(*nat
)),
844 not_b_term
.eq(Cat(*nbt
)),
845 neg_lsb_a_term
.eq(Cat(*nla
)),
846 neg_lsb_b_term
.eq(Cat(*nlb
)),
class IntermediateOut(Elaboratable):
    """ selects the HI/LO part of the multiplication, for a given bit-width
        the output is also reconstructed in its SIMD (partition) lanes.
    """
    def __init__(self, width, out_wid, n_parts):
        """Create an ``IntermediateOut``.

        :param width: bit width of one partition lane
        :param out_wid: bit width of the full intermediate result
        :param n_parts: number of partition lanes
        """
        self.width = width
        self.n_parts = n_parts
        self.part_ops = [Signal(2, name="dpop%d" % i, reset_less=True)
                         for i in range(8)]
        self.intermed = Signal(out_wid, reset_less=True)
        self.output = Signal(out_wid//2, reset_less=True)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()

        ol = []
        w = self.width
        sel = w // 8  # bytes per lane: stride into part_ops
        for i in range(self.n_parts):
            op = Signal(w, reset_less=True, name="op%d_%d" % (w, i))
            # OP_MUL_LOW takes the LO half of each double-width product;
            # all other ops take the HI half
            m.d.comb += op.eq(
                Mux(self.part_ops[sel * i] == OP_MUL_LOW,
                    self.intermed.part(i * w*2, w),
                    self.intermed.part(i * w*2 + w, w)))
            ol.append(op)

        m.d.comb += self.output.eq(Cat(*ol))

        return m
class FinalOut(Elaboratable):
    """ selects the final output based on the partitioning.

        each byte is selectable independently, i.e. it is possible
        that some partitions requested 8-bit computation whilst others
        requested 16 or 32 bit.
    """
    def __init__(self, out_wid):
        """Create a ``FinalOut``.

        :param out_wid: bit width of the output
        """
        # inputs: per-byte-lane selectors for each partition size
        self.d8 = [Signal(name=f"d8_{i}", reset_less=True) for i in range(8)]
        self.d16 = [Signal(name=f"d16_{i}", reset_less=True) for i in range(4)]
        self.d32 = [Signal(name=f"d32_{i}", reset_less=True) for i in range(2)]

        self.i8 = Signal(out_wid, reset_less=True)
        self.i16 = Signal(out_wid, reset_less=True)
        self.i32 = Signal(out_wid, reset_less=True)
        self.i64 = Signal(out_wid, reset_less=True)

        # output
        self.out = Signal(out_wid, reset_less=True)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()

        ol = []
        for i in range(8):
            # select one of the outputs: d8 selects i8, d16 selects i16
            # d32 selects i32, and the default is i64.
            # d8 and d16 are ORed together in the first Mux
            # then the 2nd selects either i8 or i16.
            # if neither d8 nor d16 are set, d32 selects either i32 or i64.
            op = Signal(8, reset_less=True, name="op_%d" % i)
            m.d.comb += op.eq(
                Mux(self.d8[i] | self.d16[i // 2],
                    Mux(self.d8[i], self.i8.part(i * 8, 8),
                        self.i16.part(i * 8, 8)),
                    Mux(self.d32[i // 4], self.i32.part(i * 8, 8),
                        self.i64.part(i * 8, 8))))
            ol.append(op)

        # create the output from the per-byte selections
        m.d.comb += self.out.eq(Cat(*ol))

        return m
class OrMod(Elaboratable):
    """ ORs four values together in a hierarchical tree
    """
    def __init__(self, wid):
        self.wid = wid
        # the four wid-bit inputs to be ORed together
        self.orin = [Signal(wid, name="orin%d" % i, reset_less=True)
                     for i in range(4)]
        # bitwise OR of all four inputs
        self.orout = Signal(wid, reset_less=True)

    def elaborate(self, platform):
        m = Module()
        # two-level OR tree: (orin0 | orin1) | (orin2 | orin3)
        or1 = Signal(self.wid, reset_less=True)
        or2 = Signal(self.wid, reset_less=True)
        m.d.comb += or1.eq(self.orin[0] | self.orin[1])
        m.d.comb += or2.eq(self.orin[2] | self.orin[3])
        m.d.comb += self.orout.eq(or1 | or2)

        return m
class Signs(Elaboratable):
    """ determines whether a or b are signed numbers
        based on the required operation type (OP_MUL_*)
    """

    def __init__(self):
        # input: 2-bit OP_MUL_* op-code
        self.part_ops = Signal(2, reset_less=True)
        # outputs: whether a (resp. b) is treated as a signed number
        self.a_signed = Signal(reset_less=True)
        self.b_signed = Signal(reset_less=True)

    def elaborate(self, platform):
        m = Module()

        # a is signed for every op except the both-unsigned high op (mulhu)
        asig = self.part_ops != OP_MUL_UNSIGNED_HIGH
        # b is signed only for the low op (mul) and both-signed high (mulh)
        bsig = (self.part_ops == OP_MUL_LOW) \
                    | (self.part_ops == OP_MUL_SIGNED_HIGH)
        m.d.comb += self.a_signed.eq(asig)
        m.d.comb += self.b_signed.eq(bsig)

        return m
class Mul8_16_32_64(Elaboratable):
    """Signed/Unsigned 8/16/32/64-bit partitioned integer multiplier.

    Supports partitioning into any combination of 8, 16, 32, and 64-bit
    partitions on naturally-aligned boundaries. Supports the operation being
    set for each partition independently.

    :attribute part_pts: the input partition points. Has a partition point at
        multiples of 8 in 0 < i < 64. Each partition point's associated
        ``Value`` is a ``Signal``. Modification not supported, except for by
        ``Signal.eq``.
    :attribute part_ops: the operation for each byte. The operation for a
        particular partition is selected by assigning the selected operation
        code to each byte in the partition. The allowed operation codes are:

        :attribute OP_MUL_LOW: the LSB half of the product. Equivalent to
            RISC-V's `mul` instruction.
        :attribute OP_MUL_SIGNED_HIGH: the MSB half of the product where both
            ``a`` and ``b`` are signed. Equivalent to RISC-V's `mulh`
            instruction.
        :attribute OP_MUL_SIGNED_UNSIGNED_HIGH: the MSB half of the product
            where ``a`` is signed and ``b`` is unsigned. Equivalent to
            RISC-V's `mulhsu` instruction.
        :attribute OP_MUL_UNSIGNED_HIGH: the MSB half of the product where
            both ``a`` and ``b`` are unsigned. Equivalent to RISC-V's `mulhu`
            instruction.
    """

    def __init__(self, register_levels=()):
        """ register_levels: specifies the points in the cascade at which
            flip-flops are to be inserted.
        """

        # parameter(s)
        self.register_levels = list(register_levels)

        # inputs
        self.part_pts = PartitionPoints()
        for i in range(8, 64, 8):
            self.part_pts[i] = Signal(name=f"part_pts_{i}")
        self.part_ops = [Signal(2, name=f"part_ops_{i}") for i in range(8)]
        # NOTE(review): operand inputs restored from missing source lines --
        # elaborate() reads self.a and self.b; verify widths against callers
        self.a = Signal(64)
        self.b = Signal(64)

        # intermediates (needed for unit tests)
        self._intermediate_output = Signal(128)

        # output
        self.output = Signal(64)

    def elaborate(self, platform):
        m = Module()

        # collect part-bytes
        pbs = Signal(8, reset_less=True)
        tl = []
        for i in range(8):
            pb = Signal(name="pb%d" % i, reset_less=True)
            m.d.comb += pb.eq(self.part_pts.part_byte(i))
            tl.append(pb)
        m.d.comb += pbs.eq(Cat(*tl))

        # create (doubled) PartitionPoints (output is double input width)
        expanded_part_pts = eps = PartitionPoints()
        for i, v in self.part_pts.items():
            ep = Signal(name=f"expanded_part_pts_{i*2}", reset_less=True)
            expanded_part_pts[i * 2] = ep
            m.d.comb += ep.eq(v)

        # determine a/b signedness from each byte's op-code
        signs = []
        for i in range(8):
            s = Signs()
            signs.append(s)
            setattr(m.submodules, "signs%d" % i, s)
            m.d.comb += s.part_ops.eq(self.part_ops[i])

        # partial-product correction terms, one Part per partition width
        n_levels = len(self.register_levels)+1
        m.submodules.part_8 = part_8 = Part(eps, 128, 8, n_levels, 8)
        m.submodules.part_16 = part_16 = Part(eps, 128, 4, n_levels, 8)
        m.submodules.part_32 = part_32 = Part(eps, 128, 2, n_levels, 8)
        m.submodules.part_64 = part_64 = Part(eps, 128, 1, n_levels, 8)
        nat_l, nbt_l, nla_l, nlb_l = [], [], [], []
        for mod in [part_8, part_16, part_32, part_64]:
            m.d.comb += mod.a.eq(self.a)
            m.d.comb += mod.b.eq(self.b)
            for i in range(len(signs)):
                m.d.comb += mod.a_signed[i].eq(signs[i].a_signed)
                m.d.comb += mod.b_signed[i].eq(signs[i].b_signed)
            m.d.comb += mod.pbs.eq(pbs)
            nat_l.append(mod.not_a_term)
            nbt_l.append(mod.not_b_term)
            nla_l.append(mod.neg_lsb_a_term)
            nlb_l.append(mod.neg_lsb_b_term)

        # byte-by-byte product terms
        terms = []
        for a_index in range(8):
            t = ProductTerms(8, 128, 8, a_index, 8)
            setattr(m.submodules, "terms_%d" % a_index, t)

            m.d.comb += t.a.eq(self.a)
            m.d.comb += t.b.eq(self.b)
            m.d.comb += t.pb_en.eq(pbs)

            for term in t.terms:
                terms.append(term)

        # it's fine to bitwise-or data together since they are never enabled
        # at the same time
        m.submodules.nat_or = nat_or = OrMod(128)
        m.submodules.nbt_or = nbt_or = OrMod(128)
        m.submodules.nla_or = nla_or = OrMod(128)
        m.submodules.nlb_or = nlb_or = OrMod(128)
        for l, mod in [(nat_l, nat_or),
                       (nbt_l, nbt_or),
                       (nla_l, nla_or),
                       (nlb_l, nlb_or)]:
            for i in range(len(l)):
                m.d.comb += mod.orin[i].eq(l[i])
            terms.append(mod.orout)

        # NOTE(review): some AddReduce arguments were on missing source
        # lines and are reconstructed here -- verify against
        # AddReduce.__init__ in this file
        add_reduce = AddReduce(terms,
                               128,
                               self.register_levels,
                               expanded_part_pts,
                               self.part_ops)

        out_part_ops = add_reduce.levels[-1].out_part_ops
        out_part_pts = add_reduce.levels[-1]._reg_partition_points

        m.submodules.add_reduce = add_reduce
        m.d.comb += self._intermediate_output.eq(add_reduce.output)

        # HI/LO half selection, one IntermediateOut per partition width
        m.submodules.io64 = io64 = IntermediateOut(64, 128, 1)
        m.d.comb += io64.intermed.eq(self._intermediate_output)
        for i in range(8):
            m.d.comb += io64.part_ops[i].eq(out_part_ops[i])

        m.submodules.io32 = io32 = IntermediateOut(32, 128, 2)
        m.d.comb += io32.intermed.eq(self._intermediate_output)
        for i in range(8):
            m.d.comb += io32.part_ops[i].eq(out_part_ops[i])

        m.submodules.io16 = io16 = IntermediateOut(16, 128, 4)
        m.d.comb += io16.intermed.eq(self._intermediate_output)
        for i in range(8):
            m.d.comb += io16.part_ops[i].eq(out_part_ops[i])

        m.submodules.io8 = io8 = IntermediateOut(8, 128, 8)
        m.d.comb += io8.intermed.eq(self._intermediate_output)
        for i in range(8):
            m.d.comb += io8.part_ops[i].eq(out_part_ops[i])

        # per-width partition-size detection
        m.submodules.p_8 = p_8 = Parts(8, eps, len(part_8.parts))
        m.submodules.p_16 = p_16 = Parts(8, eps, len(part_16.parts))
        m.submodules.p_32 = p_32 = Parts(8, eps, len(part_32.parts))
        m.submodules.p_64 = p_64 = Parts(8, eps, len(part_64.parts))

        m.d.comb += p_8.epps.eq(out_part_pts)
        m.d.comb += p_16.epps.eq(out_part_pts)
        m.d.comb += p_32.epps.eq(out_part_pts)
        m.d.comb += p_64.epps.eq(out_part_pts)

        # final output: pick the right width's result per byte-lane
        m.submodules.finalout = finalout = FinalOut(64)
        for i in range(len(part_8.parts)):
            m.d.comb += finalout.d8[i].eq(p_8.parts[i])
        for i in range(len(part_16.parts)):
            m.d.comb += finalout.d16[i].eq(p_16.parts[i])
        for i in range(len(part_32.parts)):
            m.d.comb += finalout.d32[i].eq(p_32.parts[i])
        m.d.comb += finalout.i8.eq(io8.output)
        m.d.comb += finalout.i16.eq(io16.output)
        m.d.comb += finalout.i32.eq(io32.output)
        m.d.comb += finalout.i64.eq(io64.output)
        m.d.comb += self.output.eq(finalout.out)

        return m
if __name__ == "__main__":
    # NOTE(review): instantiation and the head of the main(...) call were on
    # missing source lines and are reconstructed -- verify the port list
    m = Mul8_16_32_64()
    main(m, ports=[m.a,
                   m.b,
                   m._intermediate_output,
                   m.output,
                   *m.part_ops,
                   *m.part_pts.values()])