# SPDX-License-Identifier: LGPL-2.1-or-later
# See Notices.txt for copyright information
"""Integer Multiplication."""

from nmigen import Signal, Module, Value, Elaboratable, Cat, C, Mux, Repl
from nmigen.hdl.ast import Assign
from abc import ABCMeta, abstractmethod
from nmigen.cli import main
from functools import reduce
from operator import or_
class PartitionPoints(dict):
    """Partition points and corresponding ``Value``s.

    The points at where an ALU is partitioned along with ``Value``s that
    specify if the corresponding partition points are enabled.

    For example: ``{1: True, 5: True, 10: True}`` with
    ``width == 16`` specifies that the ALU is split into 4 sections:

    * bits 0 <= ``i`` < 1
    * bits 1 <= ``i`` < 5
    * bits 5 <= ``i`` < 10
    * bits 10 <= ``i`` < 16

    If the partition_points were instead ``{1: True, 5: a, 10: True}``
    where ``a`` is a 1-bit ``Signal``:

    * If ``a`` is asserted:

        * bits 0 <= ``i`` < 1
        * bits 1 <= ``i`` < 5
        * bits 5 <= ``i`` < 10
        * bits 10 <= ``i`` < 16

    * Otherwise

        * bits 0 <= ``i`` < 1
        * bits 1 <= ``i`` < 10
        * bits 10 <= ``i`` < 16
    """

    def __init__(self, partition_points=None):
        """Create a new ``PartitionPoints``.

        :param partition_points: the input partition points to values mapping.
        :raises TypeError: if a key is not an integer.
        :raises ValueError: if a key is negative.
        """
        super().__init__()
        if partition_points is not None:
            for point, enabled in partition_points.items():
                if not isinstance(point, int):
                    raise TypeError("point must be a non-negative integer")
                if point < 0:
                    raise ValueError("point must be a non-negative integer")
                self[point] = Value.wrap(enabled)

    def like(self, name=None, src_loc_at=0):
        """Create a new ``PartitionPoints`` with ``Signal``s for all values.

        :param name: the base name for the new ``Signal``s.
        """
        if name is None:
            name = Signal(src_loc_at=1+src_loc_at).name  # get variable name
        retval = PartitionPoints()
        for point, enabled in self.items():
            retval[point] = Signal(enabled.shape(), name=f"{name}_{point}")
        return retval

    def eq(self, rhs):
        """Assign ``PartitionPoints`` using ``Signal.eq``."""
        if set(self.keys()) != set(rhs.keys()):
            raise ValueError("incompatible point set")
        for point, enabled in self.items():
            yield enabled.eq(rhs[point])

    def as_mask(self, width):
        """Create a bit-mask from `self`.

        Each bit in the returned mask is clear only if the partition point at
        the same bit-index is enabled.

        :param width: the bit width of the resulting mask
        """
        bits = []
        for i in range(width):
            if i in self:
                # an enabled partition point clears the mask bit
                bits.append(~self[i])
            else:
                bits.append(C(True, 1))
        return Cat(*bits)

    def get_max_partition_count(self, width):
        """Get the maximum number of partitions.

        Gets the number of partitions when all partition points are enabled.
        """
        retval = 1
        for point in self.keys():
            if point < width:
                retval += 1
        return retval

    def fits_in_width(self, width):
        """Check if all partition points are smaller than `width`."""
        for point in self.keys():
            if point >= width:
                return False
        return True

    def part_byte(self, index, mfactor=1):  # mfactor used for "expanding"
        # out-of-range byte indices are treated as a permanently-enabled
        # partition boundary
        if index == -1 or index == 7:
            return C(True, 1)
        assert index >= 0 and index < 8
        return self[(index * 8 + 8)*mfactor]
class FullAdder(Elaboratable):
    """Full Adder.

    :attribute in0: the first input
    :attribute in1: the second input
    :attribute in2: the third input
    :attribute sum: the sum output
    :attribute carry: the carry output

    Rather than do individual full adders (and have an array of them,
    which would be very slow to simulate), this module can specify the
    bit width of the inputs and outputs: in effect it performs multiple
    Full 3-2 Add operations "in parallel".
    """

    def __init__(self, width):
        """Create a ``FullAdder``.

        :param width: the bit width of the input and output
        """
        self.in0 = Signal(width)
        self.in1 = Signal(width)
        self.in2 = Signal(width)
        self.sum = Signal(width)
        self.carry = Signal(width)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        # standard 3:2 compressor equations, one per bit lane
        m.d.comb += self.sum.eq(self.in0 ^ self.in1 ^ self.in2)
        m.d.comb += self.carry.eq((self.in0 & self.in1)
                                  | (self.in1 & self.in2)
                                  | (self.in2 & self.in0))
        return m
class MaskedFullAdder(Elaboratable):
    """Masked Full Adder.

    :attribute mask: the carry partition mask
    :attribute in0: the first input
    :attribute in1: the second input
    :attribute in2: the third input
    :attribute sum: the sum output
    :attribute mcarry: the masked carry output

    FullAdders are always used with a "mask" on the output.  To keep
    the graphviz "clean", this class performs the masking here rather
    than inside a large for-loop.

    See the following discussion as to why this is no longer derived
    from FullAdder.  Each carry is shifted here *before* being ANDed
    with the mask, so that an AOI cell may be used (which is more
    gate-efficient):
    https://en.wikipedia.org/wiki/AND-OR-Invert
    https://groups.google.com/d/msg/comp.arch/fcq-GLQqvas/vTxmcA0QAgAJ
    """

    def __init__(self, width):
        """Create a ``MaskedFullAdder``.

        :param width: the bit width of the input and output
        """
        self.width = width
        self.mask = Signal(width, reset_less=True)
        self.mcarry = Signal(width, reset_less=True)
        self.in0 = Signal(width, reset_less=True)
        self.in1 = Signal(width, reset_less=True)
        self.in2 = Signal(width, reset_less=True)
        self.sum = Signal(width, reset_less=True)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        s1 = Signal(self.width, reset_less=True)
        s2 = Signal(self.width, reset_less=True)
        s3 = Signal(self.width, reset_less=True)
        c1 = Signal(self.width, reset_less=True)
        c2 = Signal(self.width, reset_less=True)
        c3 = Signal(self.width, reset_less=True)
        m.d.comb += self.sum.eq(self.in0 ^ self.in1 ^ self.in2)
        # shift each input left by one *before* ANDing with the mask,
        # so that carry-and-mask can map onto a single AOI cell
        m.d.comb += s1.eq(Cat(0, self.in0))
        m.d.comb += s2.eq(Cat(0, self.in1))
        m.d.comb += s3.eq(Cat(0, self.in2))
        m.d.comb += c1.eq(s1 & s2 & self.mask)
        m.d.comb += c2.eq(s2 & s3 & self.mask)
        m.d.comb += c3.eq(s3 & s1 & self.mask)
        m.d.comb += self.mcarry.eq(c1 | c2 | c3)
        return m
class PartitionedAdder(Elaboratable):
    """Partitioned Adder.

    Performs the final add.  The partition points are included in the
    actual add (in one of the operands only), which causes a carry over
    to the next bit.  Then the final output *removes* the extra bits from
    the result.

    partition: .... P... P... P... P... (32 bits)
    a        : .... .... .... .... .... (32 bits)
    b        : .... .... .... .... .... (32 bits)
    exp-a    : ....P....P....P....P.... (32+4 bits, P=1 if no partition)
    exp-b    : ....0....0....0....0.... (32 bits plus 4 zeros)
    exp-o    : ....xN...xN...xN...xN... (32+4 bits - x to be discarded)
    o        : .... N... N... N... N... (32 bits - x ignored, N is carry-over)

    :attribute width: the bit width of the input and output. Read-only.
    :attribute a: the first input to the adder
    :attribute b: the second input to the adder
    :attribute output: the sum output
    :attribute partition_points: the input partition points. Modification not
        supported, except for by ``Signal.eq``.
    """

    def __init__(self, width, partition_points):
        """Create a ``PartitionedAdder``.

        :param width: the bit width of the input and output
        :param partition_points: the input partition points
        """
        self.width = width
        self.a = Signal(width)
        self.b = Signal(width)
        self.output = Signal(width)
        self.partition_points = PartitionPoints(partition_points)
        if not self.partition_points.fits_in_width(width):
            raise ValueError("partition_points doesn't fit in width")
        # one extra bit is interposed at each partition point
        expanded_width = 0
        for i in range(self.width):
            if i in self.partition_points:
                expanded_width += 1
            expanded_width += 1
        self._expanded_width = expanded_width
        # XXX these have to remain here due to some horrible nmigen
        # simulation bugs involving sync. it is *not* necessary to
        # have them here, they should (under normal circumstances)
        # be moved into elaborate, as they are entirely local
        self._expanded_a = Signal(expanded_width)  # includes extra part-points
        self._expanded_b = Signal(expanded_width)  # likewise.
        self._expanded_o = Signal(expanded_width)  # likewise.

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        expanded_index = 0
        # store bits in a list, use Cat later.  graphviz is much cleaner
        al, bl, ol, ea, eb, eo = [], [], [], [], [], []

        # partition points are "breaks" (extra zeros or 1s) in what would
        # otherwise be a massive long add.  when the "break" points are 0,
        # whatever is in it (in the output) is discarded.  however when
        # there is a "1", it causes a roll-over carry to the *next* bit.
        # we still ignore the "break" bit in the [intermediate] output,
        # however by that time we've got the effect that we wanted: the
        # carry has been carried *over* the break point.

        for i in range(self.width):
            if i in self.partition_points:
                # add extra bit set to 0 + 0 for enabled partition points
                # and 1 + 0 for disabled partition points
                ea.append(self._expanded_a[expanded_index])
                al.append(~self.partition_points[i])  # add extra bit in a
                eb.append(self._expanded_b[expanded_index])
                bl.append(C(0))  # yes, add a zero
                expanded_index += 1  # skip the extra point.  NOT in the output
            ea.append(self._expanded_a[expanded_index])
            al.append(self.a[i])
            eb.append(self._expanded_b[expanded_index])
            bl.append(self.b[i])
            eo.append(self._expanded_o[expanded_index])
            ol.append(self.output[i])
            expanded_index += 1

        # combine above using Cat
        m.d.comb += Cat(*ea).eq(Cat(*al))
        m.d.comb += Cat(*eb).eq(Cat(*bl))
        m.d.comb += Cat(*ol).eq(Cat(*eo))

        # use only one addition to take advantage of look-ahead carry and
        # special hardware on FPGAs
        m.d.comb += self._expanded_o.eq(
            self._expanded_a + self._expanded_b)
        return m
# number of terms one full adder (3:2 compressor) consumes per level
FULL_ADDER_INPUT_COUNT = 3
class AddReduceSingle(Elaboratable):
    """Add list of numbers together.

    :attribute inputs: input ``Signal``s to be summed. Modification not
        supported, except for by ``Signal.eq``.
    :attribute register_levels: List of nesting levels that should have
        pipeline registers.
    :attribute output: output sum.
    :attribute partition_points: the input partition points. Modification not
        supported, except for by ``Signal.eq``.
    """

    def __init__(self, inputs, output_width, register_levels, partition_points,
                 part_ops):
        """Create an ``AddReduce``.

        :param inputs: input ``Signal``s to be summed.
        :param output_width: bit-width of ``output``.
        :param register_levels: List of nesting levels that should have
            pipeline registers.
        :param partition_points: the input partition points.
        """
        self.part_ops = part_ops
        self.out_part_ops = [Signal(2, name=f"out_part_ops_{i}")
                             for i in range(len(part_ops))]
        self.inputs = list(inputs)
        self._resized_inputs = [
            Signal(output_width, name=f"resized_inputs[{i}]")
            for i in range(len(self.inputs))]
        self.register_levels = list(register_levels)
        self.output = Signal(output_width)
        self.partition_points = PartitionPoints(partition_points)
        if not self.partition_points.fits_in_width(output_width):
            raise ValueError("partition_points doesn't fit in output_width")
        self._reg_partition_points = self.partition_points.like()

        max_level = AddReduceSingle.get_max_level(len(self.inputs))
        for level in self.register_levels:
            if level > max_level:
                raise ValueError(
                    "not enough adder levels for specified register levels")

        # this is annoying.  we have to create the modules (and terms)
        # because we need to know what they are (in order to set up the
        # interconnects back in AddReduce), but cannot do the m.d.comb +=
        # etc because this is not in elaboratable.
        self.groups = AddReduceSingle.full_adder_groups(len(self.inputs))
        self._intermediate_terms = []
        if len(self.groups) != 0:
            self.create_next_terms()

    @staticmethod
    def get_max_level(input_count):
        """Get the maximum level.

        All ``register_levels`` must be less than or equal to the maximum
        level.
        """
        retval = 0
        while True:
            groups = AddReduceSingle.full_adder_groups(input_count)
            if len(groups) == 0:
                return retval
            input_count %= FULL_ADDER_INPUT_COUNT
            input_count += 2 * len(groups)
            retval += 1

    @staticmethod
    def full_adder_groups(input_count):
        """Get ``inputs`` indices for which a full adder should be built."""
        return range(0,
                     input_count - FULL_ADDER_INPUT_COUNT + 1,
                     FULL_ADDER_INPUT_COUNT)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()

        # resize inputs to correct bit-width and optionally add in
        # a pipeline register at level 0
        resized_input_assignments = [self._resized_inputs[i].eq(self.inputs[i])
                                     for i in range(len(self.inputs))]
        copy_part_ops = [self.out_part_ops[i].eq(self.part_ops[i])
                         for i in range(len(self.part_ops))]
        if 0 in self.register_levels:
            m.d.sync += copy_part_ops
            m.d.sync += resized_input_assignments
            m.d.sync += self._reg_partition_points.eq(self.partition_points)
        else:
            m.d.comb += copy_part_ops
            m.d.comb += resized_input_assignments
            m.d.comb += self._reg_partition_points.eq(self.partition_points)

        for (value, term) in self._intermediate_terms:
            m.d.comb += term.eq(value)

        # if there are no full adders to create, then we handle the base cases
        # and return, otherwise we go on to the recursive case
        if len(self.groups) == 0:
            if len(self.inputs) == 0:
                # use 0 as the default output value
                m.d.comb += self.output.eq(0)
            elif len(self.inputs) == 1:
                # handle single input
                m.d.comb += self.output.eq(self._resized_inputs[0])
            else:
                # base case for adding 2 inputs
                assert len(self.inputs) == 2
                adder = PartitionedAdder(len(self.output),
                                         self._reg_partition_points)
                m.submodules.final_adder = adder
                m.d.comb += adder.a.eq(self._resized_inputs[0])
                m.d.comb += adder.b.eq(self._resized_inputs[1])
                m.d.comb += self.output.eq(adder.output)
            return m

        mask = self._reg_partition_points.as_mask(len(self.output))
        m.d.comb += self.part_mask.eq(mask)

        # add and link the intermediate term modules
        for i, (iidx, adder_i) in enumerate(self.adders):
            setattr(m.submodules, f"adder_{i}", adder_i)

            m.d.comb += adder_i.in0.eq(self._resized_inputs[iidx])
            m.d.comb += adder_i.in1.eq(self._resized_inputs[iidx + 1])
            m.d.comb += adder_i.in2.eq(self._resized_inputs[iidx + 2])
            m.d.comb += adder_i.mask.eq(self.part_mask)

        return m

    def create_next_terms(self):
        """Prepare the masked full adders and intermediate terms for the
        recursive case (called from ``__init__`` when groups is non-empty).
        """
        # go on to prepare recursive case
        intermediate_terms = []
        _intermediate_terms = []

        def add_intermediate_term(value):
            # record a (value, signal) pair; the actual m.d.comb hookup
            # happens later in elaborate()
            intermediate_term = Signal(
                len(self.output),
                name=f"intermediate_terms[{len(intermediate_terms)}]")
            _intermediate_terms.append((value, intermediate_term))
            intermediate_terms.append(intermediate_term)

        # store mask in intermediary (simplifies graph)
        self.part_mask = Signal(len(self.output), reset_less=True)

        # create full adders for this recursive level.
        # this shrinks N terms to 2 * (N // 3) plus the remainder
        self.adders = []
        for i in self.groups:
            adder_i = MaskedFullAdder(len(self.output))
            self.adders.append((i, adder_i))
            # add both the sum and the masked-carry to the next level.
            # 3 inputs have now been reduced to 2...
            add_intermediate_term(adder_i.sum)
            add_intermediate_term(adder_i.mcarry)
        # handle the remaining inputs.
        if len(self.inputs) % FULL_ADDER_INPUT_COUNT == 1:
            add_intermediate_term(self._resized_inputs[-1])
        elif len(self.inputs) % FULL_ADDER_INPUT_COUNT == 2:
            # Just pass the terms to the next layer, since we wouldn't gain
            # anything by using a half adder since there would still be 2 terms
            # and just passing the terms to the next layer saves gates.
            add_intermediate_term(self._resized_inputs[-2])
            add_intermediate_term(self._resized_inputs[-1])
        else:
            assert len(self.inputs) % FULL_ADDER_INPUT_COUNT == 0

        self.intermediate_terms = intermediate_terms
        self._intermediate_terms = _intermediate_terms
class AddReduce(Elaboratable):
    """Recursively Add list of numbers together.

    :attribute inputs: input ``Signal``s to be summed. Modification not
        supported, except for by ``Signal.eq``.
    :attribute register_levels: List of nesting levels that should have
        pipeline registers.
    :attribute output: output sum.
    :attribute partition_points: the input partition points. Modification not
        supported, except for by ``Signal.eq``.
    """

    def __init__(self, inputs, output_width, register_levels, partition_points,
                 part_ops):
        """Create an ``AddReduce``.

        :param inputs: input ``Signal``s to be summed.
        :param output_width: bit-width of ``output``.
        :param register_levels: List of nesting levels that should have
            pipeline registers.
        :param partition_points: the input partition points.
        """
        self.inputs = inputs
        self.part_ops = part_ops
        self.out_part_ops = [Signal(2, name=f"out_part_ops_{i}")
                             for i in range(len(part_ops))]
        self.output = Signal(output_width)
        self.output_width = output_width
        self.register_levels = register_levels
        self.partition_points = partition_points

        self.create_levels()

    @staticmethod
    def get_max_level(input_count):
        """Delegate to ``AddReduceSingle.get_max_level``."""
        return AddReduceSingle.get_max_level(input_count)

    @staticmethod
    def next_register_levels(register_levels):
        """``Iterable`` of ``register_levels`` for next recursive level."""
        for level in register_levels:
            if level > 0:
                yield level - 1

    def create_levels(self):
        """creates reduction levels"""

        mods = []
        next_levels = self.register_levels
        partition_points = self.partition_points
        inputs = self.inputs
        part_ops = self.part_ops
        while True:
            next_level = AddReduceSingle(inputs, self.output_width, next_levels,
                                         partition_points, part_ops)
            mods.append(next_level)
            if len(next_level.groups) == 0:
                break
            next_levels = list(AddReduce.next_register_levels(next_levels))
            partition_points = next_level._reg_partition_points
            inputs = next_level.intermediate_terms
            part_ops = next_level.out_part_ops

        self.levels = mods

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        for i, next_level in enumerate(self.levels):
            setattr(m.submodules, "next_level%d" % i, next_level)

        # output comes from last module
        m.d.comb += self.output.eq(next_level.output)
        copy_part_ops = [self.out_part_ops[i].eq(next_level.out_part_ops[i])
                         for i in range(len(self.part_ops))]
        m.d.comb += copy_part_ops

        return m
# operation codes, one 2-bit code per byte of the operands
OP_MUL_LOW = 0
OP_MUL_SIGNED_HIGH = 1
OP_MUL_SIGNED_UNSIGNED_HIGH = 2  # a is signed, b is unsigned
OP_MUL_UNSIGNED_HIGH = 3
def get_term(value, shift=0, enabled=None):
    """Optionally mask and shift a partial-product term.

    :param value: the term to be included in the sum
    :param shift: number of zero bits to prepend (i.e. a left shift)
    :param enabled: optional 1-bit condition; when given, the term is
        zeroed out if the condition is de-asserted
    """
    if enabled is not None:
        value = Mux(enabled, value, 0)
    if shift > 0:
        value = Cat(Repl(C(0, 1), shift), value)
    return value
class ProductTerm(Elaboratable):
    """ this class creates a single product term (a[..]*b[..]).
        it has a design flaw in that is the *output* that is selected,
        where the multiplication(s) are combinatorially generated
        all the time.
    """

    def __init__(self, width, twidth, pbwid, a_index, b_index):
        self.a_index = a_index
        self.b_index = b_index
        shift = 8 * (self.a_index + self.b_index)
        self.pwidth = width
        self.twidth = twidth
        self.width = width*2
        self.shift = shift

        self.ti = Signal(self.width, reset_less=True)
        self.term = Signal(twidth, reset_less=True)
        self.a = Signal(twidth//2, reset_less=True)
        self.b = Signal(twidth//2, reset_less=True)
        self.pb_en = Signal(pbwid, reset_less=True)

        # the partition bits spanned by this term determine whether it
        # is enabled at all
        self.tl = tl = []
        min_index = min(self.a_index, self.b_index)
        max_index = max(self.a_index, self.b_index)
        for i in range(min_index, max_index):
            tl.append(self.pb_en[i])
        name = "te_%d_%d" % (self.a_index, self.b_index)
        if len(tl) > 0:
            term_enabled = Signal(name=name, reset_less=True)
        else:
            term_enabled = None
        self.enabled = term_enabled
        self.term.name = "term_%d_%d" % (a_index, b_index)  # rename

    def elaborate(self, platform):

        m = Module()
        if self.enabled is not None:
            # term is enabled only when no partition bit in its span is set
            m.d.comb += self.enabled.eq(~(Cat(*self.tl).bool()))

        bsa = Signal(self.width, reset_less=True)
        bsb = Signal(self.width, reset_less=True)
        a_index, b_index = self.a_index, self.b_index
        pwidth = self.pwidth
        m.d.comb += bsa.eq(self.a.part(a_index * pwidth, pwidth))
        m.d.comb += bsb.eq(self.b.part(b_index * pwidth, pwidth))
        m.d.comb += self.ti.eq(bsa * bsb)
        m.d.comb += self.term.eq(get_term(self.ti, self.shift, self.enabled))
        # NOTE(review): the alternative implementation below was disabled
        # in the original (kept as a string); preserved verbatim:
        """
        #TODO: sort out width issues, get inputs a/b switched on/off.
        #data going into Muxes is 1/2 the required width

        bsa = Signal(self.twidth//2, reset_less=True)
        bsb = Signal(self.twidth//2, reset_less=True)
        asel = Signal(width, reset_less=True)
        bsel = Signal(width, reset_less=True)
        a_index, b_index = self.a_index, self.b_index
        m.d.comb += asel.eq(self.a.part(a_index * pwidth, pwidth))
        m.d.comb += bsel.eq(self.b.part(b_index * pwidth, pwidth))
        m.d.comb += bsa.eq(get_term(asel, self.shift, self.enabled))
        m.d.comb += bsb.eq(get_term(bsel, self.shift, self.enabled))
        m.d.comb += self.ti.eq(bsa * bsb)
        m.d.comb += self.term.eq(self.ti)
        """

        return m
class ProductTerms(Elaboratable):
    """ creates a bank of product terms.  also performs the actual bit-selection
        this class is to be wrapped with a for-loop on the "a" operand.
        it creates a second-level for-loop on the "b" operand.
    """
    def __init__(self, width, twidth, pbwid, a_index, blen):
        self.a_index = a_index
        self.blen = blen
        self.pwidth = width
        self.twidth = twidth
        self.pbwid = pbwid
        self.a = Signal(twidth//2, reset_less=True)
        self.b = Signal(twidth//2, reset_less=True)
        self.pb_en = Signal(pbwid, reset_less=True)
        self.terms = [Signal(twidth, name="term%d" % i, reset_less=True)
                      for i in range(blen)]

    def elaborate(self, platform):

        m = Module()

        for b_index in range(self.blen):
            t = ProductTerm(self.pwidth, self.twidth, self.pbwid,
                            self.a_index, b_index)
            setattr(m.submodules, "term_%d" % b_index, t)

            m.d.comb += t.a.eq(self.a)
            m.d.comb += t.b.eq(self.b)
            m.d.comb += t.pb_en.eq(self.pb_en)

            m.d.comb += self.terms[b_index].eq(t.term)

        return m
class LSBNegTerm(Elaboratable):
    """Generates the 1s-complement and +1 correction terms needed to turn
    an unsigned partial product into a signed one for a single partition.
    """

    def __init__(self, bit_width):
        self.bit_width = bit_width
        self.part = Signal(reset_less=True)
        self.signed = Signal(reset_less=True)
        self.op = Signal(bit_width, reset_less=True)
        self.msb = Signal(reset_less=True)
        self.nt = Signal(bit_width*2, reset_less=True)
        self.nl = Signal(bit_width*2, reset_less=True)

    def elaborate(self, platform):
        m = Module()
        comb = m.d.comb
        bit_wid = self.bit_width
        ext = Repl(0, bit_wid)  # extend output to HI part

        # determine sign of each incoming number *in this partition*
        enabled = Signal(reset_less=True)
        m.d.comb += enabled.eq(self.part & self.msb & self.signed)

        # for 8-bit values: form a * 0xFF00 by using -a * 0x100, the
        # negation operation is split into a bitwise not and a +1.
        # likewise for 16, 32, and 64-bit values.

        # width-extended 1s complement if a is signed, otherwise zero
        comb += self.nt.eq(Mux(enabled, Cat(ext, ~self.op), 0))

        # add 1 if signed, otherwise add zero
        comb += self.nl.eq(Cat(ext, enabled, Repl(0, bit_wid-1)))

        return m
class Parts(Elaboratable):
    """Computes, for each SIMD lane, whether that lane forms a complete
    partition at this element width (all interior partition bits clear,
    both boundary bits set).
    """

    def __init__(self, pbwid, epps, n_parts):
        self.pbwid = pbwid
        # inputs
        self.epps = PartitionPoints.like(epps, name="epps")  # expanded points
        # outputs
        self.parts = [Signal(name=f"part_{i}") for i in range(n_parts)]

    def elaborate(self, platform):
        m = Module()

        epps, parts = self.epps, self.parts
        # collect part-bytes (double factor because the input is extended)
        pbs = Signal(self.pbwid, reset_less=True)
        tl = []
        for i in range(self.pbwid):
            pb = Signal(name="pb%d" % i, reset_less=True)
            m.d.comb += pb.eq(epps.part_byte(i, mfactor=2))  # double
            tl.append(pb)
        m.d.comb += pbs.eq(Cat(*tl))

        # negated-temporary copy of partition bits
        npbs = Signal.like(pbs, reset_less=True)
        m.d.comb += npbs.eq(~pbs)
        byte_count = 8 // len(parts)
        for i in range(len(parts)):
            pbl = []
            pbl.append(npbs[i * byte_count - 1])
            for j in range(i * byte_count, (i + 1) * byte_count - 1):
                pbl.append(pbs[j])
            pbl.append(npbs[(i + 1) * byte_count - 1])
            value = Signal(len(pbl), name="value_%d" % i, reset_less=True)
            m.d.comb += value.eq(Cat(*pbl))
            m.d.comb += parts[i].eq(~(value).bool())

        return m
class Part(Elaboratable):
    """ a key class which, depending on the partitioning, will determine
        what action to take when parts of the output are signed or unsigned.

        this requires 2 pieces of data *per operand, per partition*:
        whether the MSB is HI/LO (per partition!), and whether a signed
        or unsigned operation has been *requested*.

        once that is determined, signed is basically carried out
        by splitting 2's complement into 1's complement plus one.
        1's complement is just a bit-inversion.

        the extra terms - as separate terms - are then thrown at the
        AddReduce alongside the multiplication part-results.
    """
    def __init__(self, width, n_parts, n_levels, pbwid):

        self.pbwid = pbwid
        # inputs
        self.a = Signal(64)
        self.b = Signal(64)
        self.a_signed = [Signal(name=f"a_signed_{i}") for i in range(8)]
        self.b_signed = [Signal(name=f"_b_signed_{i}") for i in range(8)]
        self.pbs = Signal(pbwid, reset_less=True)

        # outputs
        self.parts = [Signal(name=f"part_{i}") for i in range(n_parts)]
        self.delayed_parts = [
            [Signal(name=f"delayed_part_{delay}_{i}")
             for i in range(n_parts)]
            for delay in range(n_levels)]
        # XXX REALLY WEIRD BUG - have to take a copy of the last delayed_parts
        self.dplast = [Signal(name=f"dplast_{i}")
                       for i in range(n_parts)]

        self.not_a_term = Signal(width)
        self.neg_lsb_a_term = Signal(width)
        self.not_b_term = Signal(width)
        self.neg_lsb_b_term = Signal(width)

    def elaborate(self, platform):
        m = Module()

        pbs, parts, delayed_parts = self.pbs, self.parts, self.delayed_parts
        # negated-temporary copy of partition bits
        npbs = Signal.like(pbs, reset_less=True)
        m.d.comb += npbs.eq(~pbs)
        byte_count = 8 // len(parts)
        for i in range(len(parts)):
            pbl = []
            pbl.append(npbs[i * byte_count - 1])
            for j in range(i * byte_count, (i + 1) * byte_count - 1):
                pbl.append(pbs[j])
            pbl.append(npbs[(i + 1) * byte_count - 1])
            value = Signal(len(pbl), name="value_%di" % i, reset_less=True)
            m.d.comb += value.eq(Cat(*pbl))
            m.d.comb += parts[i].eq(~(value).bool())
            m.d.comb += delayed_parts[0][i].eq(parts[i])
            m.d.sync += [delayed_parts[j + 1][i].eq(delayed_parts[j][i])
                         for j in range(len(delayed_parts)-1)]
            m.d.comb += self.dplast[i].eq(delayed_parts[-1][i])

        not_a_term, neg_lsb_a_term, not_b_term, neg_lsb_b_term = \
            self.not_a_term, self.neg_lsb_a_term, \
            self.not_b_term, self.neg_lsb_b_term

        byte_width = 8 // len(parts)  # byte width
        bit_wid = 8 * byte_width      # bit width
        nat, nbt, nla, nlb = [], [], [], []
        for i in range(len(parts)):
            # work out bit-inverted and +1 term for a.
            pa = LSBNegTerm(bit_wid)
            setattr(m.submodules, "lnt_%d_a_%d" % (bit_wid, i), pa)
            m.d.comb += pa.part.eq(parts[i])
            m.d.comb += pa.op.eq(self.a.part(bit_wid * i, bit_wid))
            m.d.comb += pa.signed.eq(self.b_signed[i * byte_width])  # yes b
            m.d.comb += pa.msb.eq(self.b[(i + 1) * bit_wid - 1])  # really, b
            nat.append(pa.nt)
            nla.append(pa.nl)

            # work out bit-inverted and +1 term for b
            pb = LSBNegTerm(bit_wid)
            setattr(m.submodules, "lnt_%d_b_%d" % (bit_wid, i), pb)
            m.d.comb += pb.part.eq(parts[i])
            m.d.comb += pb.op.eq(self.b.part(bit_wid * i, bit_wid))
            m.d.comb += pb.signed.eq(self.a_signed[i * byte_width])  # yes a
            m.d.comb += pb.msb.eq(self.a[(i + 1) * bit_wid - 1])  # really, a
            nbt.append(pb.nt)
            nlb.append(pb.nl)

        # concatenate together and return all 4 results.
        m.d.comb += [not_a_term.eq(Cat(*nat)),
                     not_b_term.eq(Cat(*nbt)),
                     neg_lsb_a_term.eq(Cat(*nla)),
                     neg_lsb_b_term.eq(Cat(*nlb)),
                     ]

        return m
class IntermediateOut(Elaboratable):
    """ selects the HI/LO part of the multiplication, for a given bit-width
        the output is also reconstructed in its SIMD (partition) lanes.
    """
    def __init__(self, width, out_wid, n_parts):
        self.width = width
        self.n_parts = n_parts
        self.part_ops = [Signal(2, name="dpop%d" % i, reset_less=True)
                         for i in range(8)]
        self.intermed = Signal(out_wid, reset_less=True)
        self.output = Signal(out_wid//2, reset_less=True)

    def elaborate(self, platform):
        m = Module()

        ol = []
        w = self.width
        sel = w // 8
        for i in range(self.n_parts):
            op = Signal(w, reset_less=True, name="op%d_%d" % (w, i))
            # OP_MUL_LOW selects the LSB half of the double-width product,
            # anything else the MSB half
            m.d.comb += op.eq(
                Mux(self.part_ops[sel * i] == OP_MUL_LOW,
                    self.intermed.part(i * w*2, w),
                    self.intermed.part(i * w*2 + w, w)))
            ol.append(op)

        m.d.comb += self.output.eq(Cat(*ol))

        return m
class FinalOut(Elaboratable):
    """ selects the final output based on the partitioning.

        each byte is selectable independently, i.e. it is possible
        that some partitions requested 8-bit computation whilst others
        requested 16 or 32 bit.
    """
    def __init__(self, out_wid):
        # inputs
        self.d8 = [Signal(name=f"d8_{i}", reset_less=True) for i in range(8)]
        self.d16 = [Signal(name=f"d16_{i}", reset_less=True) for i in range(4)]
        self.d32 = [Signal(name=f"d32_{i}", reset_less=True) for i in range(2)]

        self.i8 = Signal(out_wid, reset_less=True)
        self.i16 = Signal(out_wid, reset_less=True)
        self.i32 = Signal(out_wid, reset_less=True)
        self.i64 = Signal(out_wid, reset_less=True)

        # output
        self.out = Signal(out_wid, reset_less=True)

    def elaborate(self, platform):
        m = Module()

        ol = []
        for i in range(8):
            # select one of the outputs: d8 selects i8, d16 selects i16
            # d32 selects i32, and the default is i64.
            # d8 and d16 are ORed together in the first Mux
            # then the 2nd selects either i8 or i16.
            # if neither d8 nor d16 are set, d32 selects either i32 or i64.
            op = Signal(8, reset_less=True, name="op_%d" % i)
            m.d.comb += op.eq(
                Mux(self.d8[i] | self.d16[i // 2],
                    Mux(self.d8[i], self.i8.part(i * 8, 8),
                        self.i16.part(i * 8, 8)),
                    Mux(self.d32[i // 4], self.i32.part(i * 8, 8),
                        self.i64.part(i * 8, 8))))
            ol.append(op)

        m.d.comb += self.out.eq(Cat(*ol))

        return m
class OrMod(Elaboratable):
    """ ORs four values together in a hierarchical tree
    """
    def __init__(self, wid):
        # wid: bit-width of each of the four inputs and of the output
        self.wid = wid
        self.orin = [Signal(wid, name="orin%d" % i, reset_less=True)
                     for i in range(4)]
        self.orout = Signal(wid, reset_less=True)

    def elaborate(self, platform):
        m = Module()
        # two-level tree: (in0|in1) | (in2|in3)
        or1 = Signal(self.wid, reset_less=True)
        or2 = Signal(self.wid, reset_less=True)
        m.d.comb += or1.eq(self.orin[0] | self.orin[1])
        m.d.comb += or2.eq(self.orin[2] | self.orin[3])
        m.d.comb += self.orout.eq(or1 | or2)
        return m
class Signs(Elaboratable):
    """ determines whether a or b are signed numbers
        based on the required operation type (OP_MUL_*)
    """

    def __init__(self):
        # 2-bit operation code (input)
        self.part_ops = Signal(2, reset_less=True)
        # signedness of operands a and b (outputs)
        self.a_signed = Signal(reset_less=True)
        self.b_signed = Signal(reset_less=True)

    def elaborate(self, platform):
        m = Module()

        # a is signed for every op except mulhu (unsigned-high);
        # b is signed only for mul (low) and mulh (signed-high)
        asig = self.part_ops != OP_MUL_UNSIGNED_HIGH
        bsig = (self.part_ops == OP_MUL_LOW) \
            | (self.part_ops == OP_MUL_SIGNED_HIGH)
        m.d.comb += self.a_signed.eq(asig)
        m.d.comb += self.b_signed.eq(bsig)

        return m
class Mul8_16_32_64(Elaboratable):
    """Signed/Unsigned 8/16/32/64-bit partitioned integer multiplier.

    Supports partitioning into any combination of 8, 16, 32, and 64-bit
    partitions on naturally-aligned boundaries. Supports the operation being
    set for each partition independently.

    :attribute part_pts: the input partition points. Has a partition point at
        multiples of 8 in 0 < i < 64. Each partition point's associated
        ``Value`` is a ``Signal``. Modification not supported, except for by
        ``Signal.eq``.
    :attribute part_ops: the operation for each byte. The operation for a
        particular partition is selected by assigning the selected operation
        code to each byte in the partition. The allowed operation codes are:

        :attribute OP_MUL_LOW: the LSB half of the product. Equivalent to
            RISC-V's `mul` instruction.
        :attribute OP_MUL_SIGNED_HIGH: the MSB half of the product where both
            ``a`` and ``b`` are signed. Equivalent to RISC-V's `mulh`
            instruction.
        :attribute OP_MUL_SIGNED_UNSIGNED_HIGH: the MSB half of the product
            where ``a`` is signed and ``b`` is unsigned. Equivalent to
            RISC-V's `mulhsu` instruction.
        :attribute OP_MUL_UNSIGNED_HIGH: the MSB half of the product where both
            ``a`` and ``b`` are unsigned. Equivalent to RISC-V's `mulhu`
            instruction.
    """

    def __init__(self, register_levels=()):
        """ register_levels: specifies the points in the cascade at which
            flip-flops are to be inserted.
        """
        self.register_levels = list(register_levels)

        # inputs
        self.part_pts = PartitionPoints()
        for i in range(8, 64, 8):
            self.part_pts[i] = Signal(name=f"part_pts_{i}")
        self.part_ops = [Signal(2, name=f"part_ops_{i}") for i in range(8)]
        self.a = Signal(64)
        self.b = Signal(64)

        # intermediates (needed for unit tests)
        self._intermediate_output = Signal(128)

        # output
        self.output = Signal(64)

    def elaborate(self, platform):
        m = Module()

        # collect part-bytes
        pbs = Signal(8, reset_less=True)
        tl = []
        for i in range(8):
            pb = Signal(name="pb%d" % i, reset_less=True)
            m.d.comb += pb.eq(self.part_pts.part_byte(i))
            tl.append(pb)
        m.d.comb += pbs.eq(Cat(*tl))

        # create (doubled) PartitionPoints (output is double input width)
        expanded_part_pts = PartitionPoints()
        for i, v in self.part_pts.items():
            ep = Signal(name=f"expanded_part_pts_{i*2}", reset_less=True)
            expanded_part_pts[i * 2] = ep
            m.d.comb += ep.eq(v)

        # one Signs module per byte: decodes each byte's op code
        # into a/b signedness flags
        signs = []
        for i in range(8):
            s = Signs()
            signs.append(s)
            setattr(m.submodules, "signs%d" % i, s)
            m.d.comb += s.part_ops.eq(self.part_ops[i])

        # partition-tracking modules, one per byte-width granularity
        n_levels = len(self.register_levels)+1
        m.submodules.part_8 = part_8 = Part(128, 8, n_levels, 8)
        m.submodules.part_16 = part_16 = Part(128, 4, n_levels, 8)
        m.submodules.part_32 = part_32 = Part(128, 2, n_levels, 8)
        m.submodules.part_64 = part_64 = Part(128, 1, n_levels, 8)
        nat_l, nbt_l, nla_l, nlb_l = [], [], [], []
        for mod in [part_8, part_16, part_32, part_64]:
            m.d.comb += mod.a.eq(self.a)
            m.d.comb += mod.b.eq(self.b)
            for i in range(len(signs)):
                m.d.comb += mod.a_signed[i].eq(signs[i].a_signed)
                m.d.comb += mod.b_signed[i].eq(signs[i].b_signed)
            m.d.comb += mod.pbs.eq(pbs)
            nat_l.append(mod.not_a_term)
            nbt_l.append(mod.not_b_term)
            nla_l.append(mod.neg_lsb_a_term)
            nlb_l.append(mod.neg_lsb_b_term)

        # byte-by-byte partial-product terms
        terms = []
        for a_index in range(8):
            t = ProductTerms(8, 128, 8, a_index, 8)
            setattr(m.submodules, "terms_%d" % a_index, t)

            m.d.comb += t.a.eq(self.a)
            m.d.comb += t.b.eq(self.b)
            m.d.comb += t.pb_en.eq(pbs)

            for term in t.terms:
                terms.append(term)

        # it's fine to bitwise-or data together since they are never enabled
        # at the same time
        m.submodules.nat_or = nat_or = OrMod(128)
        m.submodules.nbt_or = nbt_or = OrMod(128)
        m.submodules.nla_or = nla_or = OrMod(128)
        m.submodules.nlb_or = nlb_or = OrMod(128)
        for l, mod in [(nat_l, nat_or),
                       (nbt_l, nbt_or),
                       (nla_l, nla_or),
                       (nlb_l, nlb_or)]:
            for i in range(len(l)):
                m.d.comb += mod.orin[i].eq(l[i])
            terms.append(mod.orout)

        # adder-reduction tree, with flip-flops at the requested levels
        add_reduce = AddReduce(terms,
                               128,
                               self.register_levels,
                               expanded_part_pts,
                               self.part_ops)

        out_part_ops = add_reduce.levels[-1].out_part_ops

        m.submodules.add_reduce = add_reduce
        m.d.comb += self._intermediate_output.eq(add_reduce.output)

        # HI/LO selection for each computation width
        m.submodules.io64 = io64 = IntermediateOut(64, 128, 1)
        m.d.comb += io64.intermed.eq(self._intermediate_output)
        for i in range(8):
            m.d.comb += io64.part_ops[i].eq(out_part_ops[i])

        m.submodules.io32 = io32 = IntermediateOut(32, 128, 2)
        m.d.comb += io32.intermed.eq(self._intermediate_output)
        for i in range(8):
            m.d.comb += io32.part_ops[i].eq(out_part_ops[i])

        m.submodules.io16 = io16 = IntermediateOut(16, 128, 4)
        m.d.comb += io16.intermed.eq(self._intermediate_output)
        for i in range(8):
            m.d.comb += io16.part_ops[i].eq(out_part_ops[i])

        m.submodules.io8 = io8 = IntermediateOut(8, 128, 8)
        m.d.comb += io8.intermed.eq(self._intermediate_output)
        for i in range(8):
            m.d.comb += io8.part_ops[i].eq(out_part_ops[i])

        # final byte-wise output selection, steered by the (delayed)
        # partition information from the Part modules
        m.submodules.finalout = finalout = FinalOut(64)
        for i in range(len(part_8.delayed_parts[-1])):
            m.d.comb += finalout.d8[i].eq(part_8.dplast[i])
        for i in range(len(part_16.delayed_parts[-1])):
            m.d.comb += finalout.d16[i].eq(part_16.dplast[i])
        for i in range(len(part_32.delayed_parts[-1])):
            m.d.comb += finalout.d32[i].eq(part_32.dplast[i])
        m.d.comb += finalout.i8.eq(io8.output)
        m.d.comb += finalout.i16.eq(io16.output)
        m.d.comb += finalout.i32.eq(io32.output)
        m.d.comb += finalout.i64.eq(io64.output)
        m.d.comb += self.output.eq(finalout.out)

        return m
if __name__ == "__main__":
    # generate Verilog / run the nmigen CLI for the full multiplier
    m = Mul8_16_32_64()
    main(m, ports=[m.a,
                   m.b,
                   m._intermediate_output,
                   m.output,
                   *m.part_ops,
                   *m.part_pts.values()])