1 # SPDX-License-Identifier: LGPL-2.1-or-later
2 # See Notices.txt for copyright information
3 """Integer Multiplication."""
5 from nmigen
import Signal
, Module
, Value
, Elaboratable
, Cat
, C
, Mux
, Repl
6 from nmigen
.hdl
.ast
import Assign
7 from abc
import ABCMeta
, abstractmethod
8 from nmigen
.cli
import main
9 from functools
import reduce
10 from operator
import or_
class PartitionPoints(dict):
    """Partition points and corresponding ``Value``s.

    The points at which an ALU is partitioned, along with ``Value``s that
    specify whether the corresponding partition points are enabled.

    For example: ``{1: True, 5: True, 10: True}`` with ``width == 16``
    specifies that the ALU is split into 4 sections:

    * bits  0 <= ``i`` <  1
    * bits  1 <= ``i`` <  5
    * bits  5 <= ``i`` < 10
    * bits 10 <= ``i`` < 16

    If instead ``{1: True, 5: a, 10: True}`` is used, where ``a`` is a
    1-bit ``Signal``, the split at bit 5 only occurs when ``a`` is
    asserted.
    """

    def __init__(self, partition_points=None):
        """Create a new ``PartitionPoints``.

        :param partition_points: the input partition points to values
            mapping (may be ``None`` for an empty set).
        """
        super().__init__()
        if partition_points is None:
            return
        for point, enabled in partition_points.items():
            # only non-negative integer bit positions are legal keys
            if not isinstance(point, int):
                raise TypeError("point must be a non-negative integer")
            if point < 0:
                raise ValueError("point must be a non-negative integer")
            self[point] = Value.wrap(enabled)

    def like(self, name=None, src_loc_at=0):
        """Create a new ``PartitionPoints`` with ``Signal``s for all values.

        :param name: the base name for the new ``Signal``s.
        """
        if name is None:
            # derive a variable name from the caller's source location
            name = Signal(src_loc_at=1+src_loc_at).name
        result = PartitionPoints()
        for point, enabled in self.items():
            result[point] = Signal(enabled.shape(), name=f"{name}_{point}")
        return result

    def eq(self, rhs):
        """Assign ``PartitionPoints`` using ``Signal.eq``."""
        if set(self.keys()) != set(rhs.keys()):
            raise ValueError("incompatible point set")
        for point, enabled in self.items():
            yield enabled.eq(rhs[point])

    def as_mask(self, width):
        """Create a bit-mask from `self`.

        Each bit in the returned mask is clear only if the partition point
        at the same bit-index is enabled.

        :param width: the bit width of the resulting mask
        """
        bits = []
        for i in range(width):
            if i in self:
                # clear where the partition point is enabled
                bits.append(~self[i])
            else:
                bits.append(True)
        return Cat(*bits)

    def get_max_partition_count(self, width):
        """Get the maximum number of partitions.

        Gets the number of partitions when all partition points are enabled.
        """
        count = 1
        for point in self.keys():
            if point < width:
                count += 1
        return count

    def fits_in_width(self, width):
        """Check if all partition points are smaller than `width`."""
        for point in self.keys():
            if point >= width:
                return False
        return True

    def part_byte(self, index, mfactor=1):  # mfactor used for "expanding"
        """Look up the partition-point value at byte boundary ``index``."""
        if index == -1 or index == 7:
            # outermost boundaries are always "broken"
            return C(True, 1)
        assert index >= 0 and index < 8
        return self[(index * 8 + 8)*mfactor]
class FullAdder(Elaboratable):
    """Full Adder.

    :attribute in0: the first input
    :attribute in1: the second input
    :attribute in2: the third input
    :attribute sum: the sum output
    :attribute carry: the carry output

    Rather than instantiating individual single-bit full adders (which
    would be very slow to simulate as an array), this module performs
    multiple 3:2 add operations "in parallel" across the whole bit width.
    """

    def __init__(self, width):
        """Create a ``FullAdder``.

        :param width: the bit width of the input and output
        """
        self.in0 = Signal(width)
        self.in1 = Signal(width)
        self.in2 = Signal(width)
        self.sum = Signal(width)
        self.carry = Signal(width)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        # classic 3:2 compressor equations, bitwise across the width
        m.d.comb += self.sum.eq(self.in0 ^ self.in1 ^ self.in2)
        m.d.comb += self.carry.eq((self.in0 & self.in1)
                                  | (self.in1 & self.in2)
                                  | (self.in2 & self.in0))
        return m
class MaskedFullAdder(Elaboratable):
    """Masked Full Adder.

    :attribute mask: the carry partition mask
    :attribute in0: the first input
    :attribute in1: the second input
    :attribute in2: the third input
    :attribute sum: the sum output
    :attribute mcarry: the masked carry output

    FullAdders are always used with a "mask" on the output.  To keep the
    graphviz "clean", this class performs the masking here rather than
    inside a large for-loop.

    See the following discussion as to why this is no longer derived
    from FullAdder.  Each carry is shifted here *before* being ANDed
    with the mask, so that an AOI cell may be used (which is more
    efficient):

    https://en.wikipedia.org/wiki/AND-OR-Invert
    https://groups.google.com/d/msg/comp.arch/fcq-GLQqvas/vTxmcA0QAgAJ
    """

    def __init__(self, width):
        """Create a ``MaskedFullAdder``.

        :param width: the bit width of the input and output
        """
        self.width = width
        self.mask = Signal(width, reset_less=True)
        self.mcarry = Signal(width, reset_less=True)
        self.in0 = Signal(width, reset_less=True)
        self.in1 = Signal(width, reset_less=True)
        self.in2 = Signal(width, reset_less=True)
        self.sum = Signal(width, reset_less=True)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        # pre-shifted copies of the inputs (carry lands one bit higher)
        s1 = Signal(self.width, reset_less=True)
        s2 = Signal(self.width, reset_less=True)
        s3 = Signal(self.width, reset_less=True)
        c1 = Signal(self.width, reset_less=True)
        c2 = Signal(self.width, reset_less=True)
        c3 = Signal(self.width, reset_less=True)
        m.d.comb += self.sum.eq(self.in0 ^ self.in1 ^ self.in2)
        # shift-then-mask so the synthesiser can use AOI cells
        m.d.comb += s1.eq(Cat(0, self.in0))
        m.d.comb += s2.eq(Cat(0, self.in1))
        m.d.comb += s3.eq(Cat(0, self.in2))
        m.d.comb += c1.eq(s1 & s2 & self.mask)
        m.d.comb += c2.eq(s2 & s3 & self.mask)
        m.d.comb += c3.eq(s3 & s1 & self.mask)
        m.d.comb += self.mcarry.eq(c1 | c2 | c3)
        return m
class PartitionedAdder(Elaboratable):
    """Partitioned Adder.

    Performs the final add.  The partition points are included in the
    actual add (in one of the operands only), which causes a carry over
    to the next bit.  The final output then *removes* the extra bits from
    the result::

        partition: .... P... P... P... P... (32 bits)
        a        : .... .... .... .... .... (32 bits)
        b        : .... .... .... .... .... (32 bits)
        exp-a    : ....P....P....P....P.... (32+4 bits, P=1 if no partition)
        exp-b    : ....0....0....0....0.... (32 bits plus 4 zeros)
        exp-o    : ....xN...xN...xN...xN... (32+4 bits - x to be discarded)
        o        : .... N... N... N... N... (32 bits - x ignored, N carry)

    :attribute width: the bit width of the input and output.  Read-only.
    :attribute a: the first input to the adder
    :attribute b: the second input to the adder
    :attribute output: the sum output
    :attribute partition_points: the input partition points.  Modification
        not supported, except for by ``Signal.eq``.
    """

    def __init__(self, width, partition_points):
        """Create a ``PartitionedAdder``.

        :param width: the bit width of the input and output
        :param partition_points: the input partition points
        """
        self.width = width
        self.a = Signal(width)
        self.b = Signal(width)
        self.output = Signal(width)
        self.partition_points = PartitionPoints(partition_points)
        if not self.partition_points.fits_in_width(width):
            raise ValueError("partition_points doesn't fit in width")
        # one extra bit per partition point, interleaved into the operands
        expanded_width = 0
        for i in range(self.width):
            if i in self.partition_points:
                expanded_width += 1
            expanded_width += 1
        self._expanded_width = expanded_width
        # XXX these have to remain here due to some horrible nmigen
        # simulation bugs involving sync. it is *not* necessary to
        # have them here, they should (under normal circumstances)
        # be moved into elaborate, as they are entirely local
        self._expanded_a = Signal(expanded_width)  # includes extra part-points
        self._expanded_b = Signal(expanded_width)  # likewise.
        self._expanded_o = Signal(expanded_width)  # likewise.

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        expanded_index = 0
        # store bits in a list, use Cat later. graphviz is much cleaner
        al, bl, ol, ea, eb, eo = [], [], [], [], [], []

        # partition points are "breaks" (extra zeros or 1s) in what would
        # otherwise be a massive long add.  when the "break" points are 0,
        # whatever is in it (in the output) is discarded.  however when
        # there is a "1", it causes a roll-over carry to the *next* bit.
        # we still ignore the "break" bit in the [intermediate] output,
        # however by that time we've got the effect that we wanted: the
        # carry has been carried *over* the break point.
        for i in range(self.width):
            if i in self.partition_points:
                # add extra bit set to 0 + 0 for enabled partition points
                # and 1 + 0 for disabled partition points
                ea.append(self._expanded_a[expanded_index])
                al.append(~self.partition_points[i])  # add extra bit in a
                eb.append(self._expanded_b[expanded_index])
                bl.append(C(0))  # yes, add a zero
                expanded_index += 1  # skip the extra point.  NOT in the output
            ea.append(self._expanded_a[expanded_index])
            al.append(self.a[i])
            eb.append(self._expanded_b[expanded_index])
            bl.append(self.b[i])
            eo.append(self._expanded_o[expanded_index])
            ol.append(self.output[i])
            expanded_index += 1

        # combine above using Cat
        m.d.comb += Cat(*ea).eq(Cat(*al))
        m.d.comb += Cat(*eb).eq(Cat(*bl))
        m.d.comb += Cat(*ol).eq(Cat(*eo))

        # use only one addition to take advantage of look-ahead carry and
        # special hardware on FPGAs
        m.d.comb += self._expanded_o.eq(
            self._expanded_a + self._expanded_b)
        return m
# number of inputs a single 3:2 compressor (full adder) consumes per stage
FULL_ADDER_INPUT_COUNT = 3
class AddReduceSingle(Elaboratable):
    """Add list of numbers together.

    :attribute inputs: input ``Signal``s to be summed.  Modification not
        supported, except for by ``Signal.eq``.
    :attribute register_levels: List of nesting levels that should have
        pipeline registers.
    :attribute output: output sum.
    :attribute partition_points: the input partition points.  Modification
        not supported, except for by ``Signal.eq``.
    """

    def __init__(self, inputs, output_width, register_levels, partition_points,
                 part_ops):
        """Create an ``AddReduce``.

        :param inputs: input ``Signal``s to be summed.
        :param output_width: bit-width of ``output``.
        :param register_levels: List of nesting levels that should have
            pipeline registers.
        :param partition_points: the input partition points.
        """
        self.part_ops = part_ops
        self.out_part_ops = [Signal(2, name=f"out_part_ops_{i}")
                             for i in range(len(part_ops))]
        self.inputs = list(inputs)
        self._resized_inputs = [
            Signal(output_width, name=f"resized_inputs[{i}]")
            for i in range(len(self.inputs))]
        self.register_levels = list(register_levels)
        self.output = Signal(output_width)
        self.partition_points = PartitionPoints(partition_points)
        if not self.partition_points.fits_in_width(output_width):
            raise ValueError("partition_points doesn't fit in output_width")
        self._reg_partition_points = self.partition_points.like()

        max_level = AddReduceSingle.get_max_level(len(self.inputs))
        for level in self.register_levels:
            if level > max_level:
                raise ValueError(
                    "not enough adder levels for specified register levels")

        # this is annoying.  we have to create the modules (and terms)
        # because we need to know what they are (in order to set up the
        # interconnects back in AddReduce), but cannot do the m.d.comb +=
        # etc because this is not in elaboratable.
        self.groups = AddReduceSingle.full_adder_groups(len(self.inputs))
        self._intermediate_terms = []
        if len(self.groups) != 0:
            self.create_next_terms()

    @staticmethod
    def get_max_level(input_count):
        """Get the maximum level.

        All ``register_levels`` must be less than or equal to the maximum
        level.
        """
        level = 0
        while True:
            groups = AddReduceSingle.full_adder_groups(input_count)
            if len(groups) == 0:
                return level
            # each group of 3 inputs becomes 2 terms at the next level
            input_count %= FULL_ADDER_INPUT_COUNT
            input_count += 2 * len(groups)
            level += 1

    @staticmethod
    def full_adder_groups(input_count):
        """Get ``inputs`` indices for which a full adder should be built."""
        return range(0,
                     input_count - FULL_ADDER_INPUT_COUNT + 1,
                     FULL_ADDER_INPUT_COUNT)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()

        # resize inputs to correct bit-width and optionally add in
        # a register stage at level 0
        resized_input_assignments = [self._resized_inputs[i].eq(self.inputs[i])
                                     for i in range(len(self.inputs))]
        copy_part_ops = [self.out_part_ops[i].eq(self.part_ops[i])
                         for i in range(len(self.part_ops))]
        if 0 in self.register_levels:
            m.d.sync += copy_part_ops
            m.d.sync += resized_input_assignments
            m.d.sync += self._reg_partition_points.eq(self.partition_points)
        else:
            m.d.comb += copy_part_ops
            m.d.comb += resized_input_assignments
            m.d.comb += self._reg_partition_points.eq(self.partition_points)

        for (value, term) in self._intermediate_terms:
            m.d.comb += term.eq(value)

        # if there are no full adders to create, then we handle the base
        # cases and return, otherwise we go on to the recursive case
        if len(self.groups) == 0:
            if len(self.inputs) == 0:
                # use 0 as the default output value
                m.d.comb += self.output.eq(0)
            elif len(self.inputs) == 1:
                # handle single input
                m.d.comb += self.output.eq(self._resized_inputs[0])
            else:
                # base case for adding 2 inputs
                assert len(self.inputs) == 2
                adder = PartitionedAdder(len(self.output),
                                         self._reg_partition_points)
                m.submodules.final_adder = adder
                m.d.comb += adder.a.eq(self._resized_inputs[0])
                m.d.comb += adder.b.eq(self._resized_inputs[1])
                m.d.comb += self.output.eq(adder.output)
            return m

        mask = self._reg_partition_points.as_mask(len(self.output))
        m.d.comb += self.part_mask.eq(mask)

        # add and link the intermediate term modules
        for i, (iidx, adder_i) in enumerate(self.adders):
            setattr(m.submodules, f"adder_{i}", adder_i)
            m.d.comb += adder_i.in0.eq(self._resized_inputs[iidx])
            m.d.comb += adder_i.in1.eq(self._resized_inputs[iidx + 1])
            m.d.comb += adder_i.in2.eq(self._resized_inputs[iidx + 2])
            m.d.comb += adder_i.mask.eq(self.part_mask)

        return m

    def create_next_terms(self):
        """Prepare the masked full adders and terms for the next level."""
        intermediate_terms = []
        _intermediate_terms = []

        def add_intermediate_term(value):
            intermediate_term = Signal(
                len(self.output),
                name=f"intermediate_terms[{len(intermediate_terms)}]")
            _intermediate_terms.append((value, intermediate_term))
            intermediate_terms.append(intermediate_term)

        # store mask in intermediary (simplifies graph)
        self.part_mask = Signal(len(self.output), reset_less=True)

        # create full adders for this recursive level.
        # this shrinks N terms to 2 * (N // 3) plus the remainder
        self.adders = []
        for i in self.groups:
            adder_i = MaskedFullAdder(len(self.output))
            self.adders.append((i, adder_i))
            # add both the sum and the masked-carry to the next level.
            # 3 inputs have now been reduced to 2...
            add_intermediate_term(adder_i.sum)
            add_intermediate_term(adder_i.mcarry)
        # handle the remaining inputs.
        if len(self.inputs) % FULL_ADDER_INPUT_COUNT == 1:
            add_intermediate_term(self._resized_inputs[-1])
        elif len(self.inputs) % FULL_ADDER_INPUT_COUNT == 2:
            # Just pass the terms to the next layer, since we wouldn't gain
            # anything by using a half adder since there would still be 2
            # terms and just passing the terms to the next layer saves gates.
            add_intermediate_term(self._resized_inputs[-2])
            add_intermediate_term(self._resized_inputs[-1])
        else:
            assert len(self.inputs) % FULL_ADDER_INPUT_COUNT == 0

        self.intermediate_terms = intermediate_terms
        self._intermediate_terms = _intermediate_terms
class AddReduce(Elaboratable):
    """Recursively Add list of numbers together.

    :attribute inputs: input ``Signal``s to be summed.  Modification not
        supported, except for by ``Signal.eq``.
    :attribute register_levels: List of nesting levels that should have
        pipeline registers.
    :attribute output: output sum.
    :attribute partition_points: the input partition points.  Modification
        not supported, except for by ``Signal.eq``.
    """

    def __init__(self, inputs, output_width, register_levels, partition_points,
                 part_ops):
        """Create an ``AddReduce``.

        :param inputs: input ``Signal``s to be summed.
        :param output_width: bit-width of ``output``.
        :param register_levels: List of nesting levels that should have
            pipeline registers.
        :param partition_points: the input partition points.
        """
        self.inputs = inputs
        self.part_ops = part_ops
        self.out_part_ops = [Signal(2, name=f"out_part_ops_{i}")
                             for i in range(len(part_ops))]
        self.output = Signal(output_width)
        self.output_width = output_width
        self.register_levels = register_levels
        self.partition_points = partition_points

        self.create_levels()

    @staticmethod
    def get_max_level(input_count):
        """Delegate to ``AddReduceSingle.get_max_level``."""
        return AddReduceSingle.get_max_level(input_count)

    @staticmethod
    def next_register_levels(register_levels):
        """``Iterable`` of ``register_levels`` for next recursive level."""
        for level in register_levels:
            if level > 0:
                yield level - 1

    def create_levels(self):
        """creates reduction levels"""
        mods = []
        next_levels = self.register_levels
        partition_points = self.partition_points
        inputs = self.inputs
        part_ops = self.part_ops
        while True:
            next_level = AddReduceSingle(inputs, self.output_width,
                                         next_levels,
                                         partition_points, part_ops)
            mods.append(next_level)
            if len(next_level.groups) == 0:
                break
            # feed this level's outputs into the next level
            next_levels = list(AddReduce.next_register_levels(next_levels))
            partition_points = next_level._reg_partition_points
            inputs = next_level.intermediate_terms
            part_ops = next_level.out_part_ops

        self.levels = mods

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        for i, next_level in enumerate(self.levels):
            setattr(m.submodules, "next_level%d" % i, next_level)

        # output comes from last module
        m.d.comb += self.output.eq(next_level.output)
        copy_part_ops = [self.out_part_ops[i].eq(next_level.out_part_ops[i])
                         for i in range(len(self.part_ops))]
        m.d.comb += copy_part_ops

        return m
# Operation codes for each partitioned byte lane.  OP_MUL_LOW is referenced
# by IntermediateOut and Signs below but its definition was missing: restore
# it as 0 so the four codes fit the 2-bit part_ops signals.
OP_MUL_LOW = 0
OP_MUL_SIGNED_HIGH = 1
OP_MUL_SIGNED_UNSIGNED_HIGH = 2  # a is signed, b is unsigned
OP_MUL_UNSIGNED_HIGH = 3
def get_term(value, shift=0, enabled=None):
    """Optionally gate ``value`` with ``enabled`` and left-shift it.

    :param value: the term to be (optionally) masked and shifted
    :param shift: number of zero bits to prepend (left-shift amount)
    :param enabled: optional 1-bit gate; when deasserted the term is zero
    """
    if enabled is not None:
        value = Mux(enabled, value, 0)
    if shift > 0:
        # prepend zeros rather than using "<<" (keeps widths explicit)
        value = Cat(Repl(C(0, 1), shift), value)
    return value
class ProductTerm(Elaboratable):
    """ this class creates a single product term (a[..]*b[..]).
        it has a design flaw in that is the *output* that is selected,
        where the multiplication(s) are combinatorially generated
        all the time.
    """

    def __init__(self, width, twidth, pbwid, a_index, b_index):
        self.a_index = a_index
        self.b_index = b_index
        shift = 8 * (self.a_index + self.b_index)
        self.pwidth = width
        self.twidth = twidth
        self.width = width*2
        self.shift = shift

        self.ti = Signal(self.width, reset_less=True)
        self.term = Signal(twidth, reset_less=True)
        self.a = Signal(twidth//2, reset_less=True)
        self.b = Signal(twidth//2, reset_less=True)
        self.pb_en = Signal(pbwid, reset_less=True)

        # partition-point bits between the two byte indices decide whether
        # this term is enabled (all must be clear)
        self.tl = tl = []
        min_index = min(self.a_index, self.b_index)
        max_index = max(self.a_index, self.b_index)
        for i in range(min_index, max_index):
            tl.append(self.pb_en[i])
        name = "te_%d_%d" % (self.a_index, self.b_index)
        if len(tl) > 0:
            term_enabled = Signal(name=name, reset_less=True)
        else:
            term_enabled = None
        self.enabled = term_enabled
        self.term.name = "term_%d_%d" % (a_index, b_index)  # rename

    def elaborate(self, platform):
        m = Module()
        if self.enabled is not None:
            m.d.comb += self.enabled.eq(~(Cat(*self.tl).bool()))

        bsa = Signal(self.width, reset_less=True)
        bsb = Signal(self.width, reset_less=True)
        a_index, b_index = self.a_index, self.b_index
        pwidth = self.pwidth
        m.d.comb += bsa.eq(self.a.part(a_index * pwidth, pwidth))
        m.d.comb += bsb.eq(self.b.part(b_index * pwidth, pwidth))
        m.d.comb += self.ti.eq(bsa * bsb)
        m.d.comb += self.term.eq(get_term(self.ti, self.shift, self.enabled))
        """
        #TODO: sort out width issues, get inputs a/b switched on/off.
        #data going into Muxes is 1/2 the required width

        bsa = Signal(self.twidth//2, reset_less=True)
        bsb = Signal(self.twidth//2, reset_less=True)
        asel = Signal(width, reset_less=True)
        bsel = Signal(width, reset_less=True)
        a_index, b_index = self.a_index, self.b_index
        m.d.comb += asel.eq(self.a.part(a_index * pwidth, pwidth))
        m.d.comb += bsel.eq(self.b.part(b_index * pwidth, pwidth))
        m.d.comb += bsa.eq(get_term(asel, self.shift, self.enabled))
        m.d.comb += bsb.eq(get_term(bsel, self.shift, self.enabled))
        m.d.comb += self.ti.eq(bsa * bsb)
        m.d.comb += self.term.eq(self.ti)
        """

        return m
class ProductTerms(Elaboratable):
    """ creates a bank of product terms.  also performs the actual
        bit-selection.  this class is to be wrapped with a for-loop on the
        "a" operand; it creates a second-level for-loop on the "b" operand.
    """

    def __init__(self, width, twidth, pbwid, a_index, blen):
        self.a_index = a_index
        self.blen = blen
        self.pwidth = width
        self.twidth = twidth
        self.pbwid = pbwid
        self.a = Signal(twidth//2, reset_less=True)
        self.b = Signal(twidth//2, reset_less=True)
        self.pb_en = Signal(pbwid, reset_less=True)
        self.terms = [Signal(twidth, name="term%d" % i, reset_less=True)
                      for i in range(blen)]

    def elaborate(self, platform):
        m = Module()

        for b_index in range(self.blen):
            t = ProductTerm(self.pwidth, self.twidth, self.pbwid,
                            self.a_index, b_index)
            setattr(m.submodules, "term_%d" % b_index, t)

            # fan the shared operands and partition enables into each term
            m.d.comb += t.a.eq(self.a)
            m.d.comb += t.b.eq(self.b)
            m.d.comb += t.pb_en.eq(self.pb_en)

            m.d.comb += self.terms[b_index].eq(t.term)

        return m
class LSBNegTerm(Elaboratable):
    """Generates the 1s-complement and +1 correction terms for one partition."""

    def __init__(self, bit_width):
        self.bit_width = bit_width
        self.part = Signal(reset_less=True)
        self.signed = Signal(reset_less=True)
        self.op = Signal(bit_width, reset_less=True)
        self.msb = Signal(reset_less=True)
        self.nt = Signal(bit_width*2, reset_less=True)
        self.nl = Signal(bit_width*2, reset_less=True)

    def elaborate(self, platform):
        m = Module()
        comb = m.d.comb
        bit_wid = self.bit_width
        ext = Repl(0, bit_wid)  # extend output to HI part

        # determine sign of each incoming number *in this partition*
        enabled = Signal(reset_less=True)
        comb += enabled.eq(self.part & self.msb & self.signed)

        # for 8-bit values: form a * 0xFF00 by using -a * 0x100, the
        # negation operation is split into a bitwise not and a +1.
        # likewise for 16, 32, and 64-bit values.

        # width-extended 1s complement if a is signed, otherwise zero
        comb += self.nt.eq(Mux(enabled, Cat(ext, ~self.op), 0))

        # add 1 if signed, otherwise add zero
        comb += self.nl.eq(Cat(ext, enabled, Repl(0, bit_wid-1)))

        return m
class Part(Elaboratable):
    """ a key class which, depending on the partitioning, will determine
        what action to take when parts of the output are signed or unsigned.

        this requires 2 pieces of data *per operand, per partition*:
        whether the MSB is HI/LO (per partition!), and whether a signed
        or unsigned operation has been *requested*.

        once that is determined, signed is basically carried out
        by splitting 2's complement into 1's complement plus one.
        1's complement is just a bit-inversion.

        the extra terms - as separate terms - are then thrown at the
        AddReduce alongside the multiplication part-results.
    """

    def __init__(self, width, n_parts, n_levels, pbwid):
        self.pbwid = pbwid
        # inputs
        self.a = Signal(64)
        self.b = Signal(64)
        self.a_signed = [Signal(name=f"a_signed_{i}") for i in range(8)]
        self.b_signed = [Signal(name=f"_b_signed_{i}") for i in range(8)]
        self.pbs = Signal(pbwid, reset_less=True)

        # outputs
        self.parts = [Signal(name=f"part_{i}") for i in range(n_parts)]
        self.delayed_parts = [
            [Signal(name=f"delayed_part_{delay}_{i}")
             for i in range(n_parts)]
            for delay in range(n_levels)]
        # XXX REALLY WEIRD BUG - have to take a copy of the last
        # delayed_parts
        self.dplast = [Signal(name=f"dplast_{i}")
                       for i in range(n_parts)]

        self.not_a_term = Signal(width)
        self.neg_lsb_a_term = Signal(width)
        self.not_b_term = Signal(width)
        self.neg_lsb_b_term = Signal(width)

    def elaborate(self, platform):
        m = Module()

        pbs, parts, delayed_parts = self.pbs, self.parts, self.delayed_parts
        # negated-temporary copy of partition bits
        npbs = Signal.like(pbs, reset_less=True)
        m.d.comb += npbs.eq(~pbs)
        byte_count = 8 // len(parts)
        for i in range(len(parts)):
            pbl = []
            pbl.append(npbs[i * byte_count - 1])
            for j in range(i * byte_count, (i + 1) * byte_count - 1):
                pbl.append(pbs[j])
            pbl.append(npbs[(i + 1) * byte_count - 1])
            value = Signal(len(pbl), name="value_%di" % i, reset_less=True)
            m.d.comb += value.eq(Cat(*pbl))
            m.d.comb += parts[i].eq(~(value).bool())
            m.d.comb += delayed_parts[0][i].eq(parts[i])
            m.d.sync += [delayed_parts[j + 1][i].eq(delayed_parts[j][i])
                         for j in range(len(delayed_parts)-1)]
            m.d.comb += self.dplast[i].eq(delayed_parts[-1][i])

        not_a_term, neg_lsb_a_term, not_b_term, neg_lsb_b_term = \
            self.not_a_term, self.neg_lsb_a_term, \
            self.not_b_term, self.neg_lsb_b_term

        byte_width = 8 // len(parts)  # byte width
        bit_wid = 8 * byte_width      # bit width
        nat, nbt, nla, nlb = [], [], [], []
        for i in range(len(parts)):
            # work out bit-inverted and +1 term for a.
            pa = LSBNegTerm(bit_wid)
            setattr(m.submodules, "lnt_%d_a_%d" % (bit_wid, i), pa)
            m.d.comb += pa.part.eq(parts[i])
            m.d.comb += pa.op.eq(self.a.part(bit_wid * i, bit_wid))
            m.d.comb += pa.signed.eq(self.b_signed[i * byte_width])  # yes b
            m.d.comb += pa.msb.eq(self.b[(i + 1) * bit_wid - 1])  # really, b
            nat.append(pa.nt)
            nla.append(pa.nl)

            # work out bit-inverted and +1 term for b
            pb = LSBNegTerm(bit_wid)
            setattr(m.submodules, "lnt_%d_b_%d" % (bit_wid, i), pb)
            m.d.comb += pb.part.eq(parts[i])
            m.d.comb += pb.op.eq(self.b.part(bit_wid * i, bit_wid))
            m.d.comb += pb.signed.eq(self.a_signed[i * byte_width])  # yes a
            m.d.comb += pb.msb.eq(self.a[(i + 1) * bit_wid - 1])  # really, a
            nbt.append(pb.nt)
            nlb.append(pb.nl)

        # concatenate together and return all 4 results.
        m.d.comb += [not_a_term.eq(Cat(*nat)),
                     not_b_term.eq(Cat(*nbt)),
                     neg_lsb_a_term.eq(Cat(*nla)),
                     neg_lsb_b_term.eq(Cat(*nlb)),
                     ]

        return m
class IntermediateOut(Elaboratable):
    """ selects the HI/LO part of the multiplication, for a given bit-width
        the output is also reconstructed in its SIMD (partition) lanes.
    """

    def __init__(self, width, out_wid, n_parts):
        self.width = width
        self.n_parts = n_parts
        self.part_ops = [Signal(2, name="dpop%d" % i, reset_less=True)
                         for i in range(8)]
        self.intermed = Signal(out_wid, reset_less=True)
        self.output = Signal(out_wid//2, reset_less=True)

    def elaborate(self, platform):
        m = Module()

        ol = []
        w = self.width
        sel = w // 8  # bytes per lane: index into part_ops per lane
        for i in range(self.n_parts):
            op = Signal(w, reset_less=True, name="op%d_%d" % (w, i))
            # OP_MUL_LOW takes the LSB half of the lane product,
            # everything else takes the MSB half
            m.d.comb += op.eq(
                Mux(self.part_ops[sel * i] == OP_MUL_LOW,
                    self.intermed.part(i * w*2, w),
                    self.intermed.part(i * w*2 + w, w)))
            ol.append(op)
        m.d.comb += self.output.eq(Cat(*ol))

        return m
class FinalOut(Elaboratable):
    """ selects the final output based on the partitioning.

        each byte is selectable independently, i.e. it is possible
        that some partitions requested 8-bit computation whilst others
        requested 16 or 32 bit.
    """

    def __init__(self, out_wid):
        # inputs
        self.d8 = [Signal(name=f"d8_{i}", reset_less=True) for i in range(8)]
        self.d16 = [Signal(name=f"d16_{i}", reset_less=True) for i in range(4)]
        self.d32 = [Signal(name=f"d32_{i}", reset_less=True) for i in range(2)]

        self.i8 = Signal(out_wid, reset_less=True)
        self.i16 = Signal(out_wid, reset_less=True)
        self.i32 = Signal(out_wid, reset_less=True)
        self.i64 = Signal(out_wid, reset_less=True)

        # output
        self.out = Signal(out_wid, reset_less=True)

    def elaborate(self, platform):
        m = Module()
        ol = []
        for i in range(8):
            # select one of the outputs: d8 selects i8, d16 selects i16
            # d32 selects i32, and the default is i64.
            # d8 and d16 are ORed together in the first Mux
            # then the 2nd selects either i8 or i16.
            # if neither d8 nor d16 are set, d32 selects either i32 or i64.
            op = Signal(8, reset_less=True, name="op_%d" % i)
            m.d.comb += op.eq(
                Mux(self.d8[i] | self.d16[i // 2],
                    Mux(self.d8[i], self.i8.part(i * 8, 8),
                        self.i16.part(i * 8, 8)),
                    Mux(self.d32[i // 4], self.i32.part(i * 8, 8),
                        self.i64.part(i * 8, 8))))
            ol.append(op)
        m.d.comb += self.out.eq(Cat(*ol))
        return m
class OrMod(Elaboratable):
    """ ORs four values together in a hierarchical tree
    """

    def __init__(self, wid):
        self.wid = wid
        self.orin = [Signal(wid, name="orin%d" % i, reset_less=True)
                     for i in range(4)]
        self.orout = Signal(wid, reset_less=True)

    def elaborate(self, platform):
        m = Module()
        # two-level OR tree: (0|1) | (2|3)
        or1 = Signal(self.wid, reset_less=True)
        or2 = Signal(self.wid, reset_less=True)
        m.d.comb += or1.eq(self.orin[0] | self.orin[1])
        m.d.comb += or2.eq(self.orin[2] | self.orin[3])
        m.d.comb += self.orout.eq(or1 | or2)

        return m
class Signs(Elaboratable):
    """ determines whether a or b are signed numbers
        based on the required operation type (OP_MUL_*)
    """

    def __init__(self):
        self.part_ops = Signal(2, reset_less=True)
        self.a_signed = Signal(reset_less=True)
        self.b_signed = Signal(reset_less=True)

    def elaborate(self, platform):
        m = Module()

        # a is signed for every op except mulhu (both-unsigned high)
        asig = self.part_ops != OP_MUL_UNSIGNED_HIGH
        # b is signed only for mul (low) and mulh (both-signed high)
        bsig = (self.part_ops == OP_MUL_LOW) \
            | (self.part_ops == OP_MUL_SIGNED_HIGH)
        m.d.comb += self.a_signed.eq(asig)
        m.d.comb += self.b_signed.eq(bsig)

        return m
class Mul8_16_32_64(Elaboratable):
    """Signed/Unsigned 8/16/32/64-bit partitioned integer multiplier.

    Supports partitioning into any combination of 8, 16, 32, and 64-bit
    partitions on naturally-aligned boundaries.  Supports the operation
    being set for each partition independently.

    :attribute part_pts: the input partition points.  Has a partition point
        at multiples of 8 in 0 < i < 64.  Each partition point's associated
        ``Value`` is a ``Signal``.  Modification not supported, except for
        by ``Signal.eq``.
    :attribute part_ops: the operation for each byte.  The operation for a
        particular partition is selected by assigning the selected operation
        code to each byte in the partition.  The allowed operation codes
        are:

        :attribute OP_MUL_LOW: the LSB half of the product.  Equivalent to
            RISC-V's `mul` instruction.
        :attribute OP_MUL_SIGNED_HIGH: the MSB half of the product where
            both ``a`` and ``b`` are signed.  Equivalent to RISC-V's `mulh`
            instruction.
        :attribute OP_MUL_SIGNED_UNSIGNED_HIGH: the MSB half of the product
            where ``a`` is signed and ``b`` is unsigned.  Equivalent to
            RISC-V's `mulhsu` instruction.
        :attribute OP_MUL_UNSIGNED_HIGH: the MSB half of the product where
            both ``a`` and ``b`` are unsigned.  Equivalent to RISC-V's
            `mulhu` instruction.
    """

    def __init__(self, register_levels=()):
        """ register_levels: specifies the points in the cascade at which
            flip-flops are to be inserted.
        """
        self.register_levels = list(register_levels)

        # inputs
        self.part_pts = PartitionPoints()
        for i in range(8, 64, 8):
            self.part_pts[i] = Signal(name=f"part_pts_{i}")
        self.part_ops = [Signal(2, name=f"part_ops_{i}") for i in range(8)]
        self.a = Signal(64)
        self.b = Signal(64)

        # intermediates (needed for unit tests)
        self._intermediate_output = Signal(128)

        # output
        self.output = Signal(64)

    def elaborate(self, platform):
        m = Module()

        # collect the partition-point enable bit for each byte boundary
        pbs = Signal(8, reset_less=True)
        tl = []
        for i in range(8):
            pb = Signal(name="pb%d" % i, reset_less=True)
            m.d.comb += pb.eq(self.part_pts.part_byte(i))
            tl.append(pb)
        m.d.comb += pbs.eq(Cat(*tl))

        # decode sign-ness of a and b per byte from the op codes
        signs = []
        for i in range(8):
            s = Signs()
            signs.append(s)
            setattr(m.submodules, "signs%d" % i, s)
            m.d.comb += s.part_ops.eq(self.part_ops[i])

        n_levels = len(self.register_levels)+1
        m.submodules.part_8 = part_8 = Part(128, 8, n_levels, 8)
        m.submodules.part_16 = part_16 = Part(128, 4, n_levels, 8)
        m.submodules.part_32 = part_32 = Part(128, 2, n_levels, 8)
        m.submodules.part_64 = part_64 = Part(128, 1, n_levels, 8)
        nat_l, nbt_l, nla_l, nlb_l = [], [], [], []
        for mod in [part_8, part_16, part_32, part_64]:
            m.d.comb += mod.a.eq(self.a)
            m.d.comb += mod.b.eq(self.b)
            for i in range(len(signs)):
                m.d.comb += mod.a_signed[i].eq(signs[i].a_signed)
                m.d.comb += mod.b_signed[i].eq(signs[i].b_signed)
            m.d.comb += mod.pbs.eq(pbs)
            nat_l.append(mod.not_a_term)
            nbt_l.append(mod.not_b_term)
            nla_l.append(mod.neg_lsb_a_term)
            nlb_l.append(mod.neg_lsb_b_term)

        # the 8x8 partial products
        terms = []
        for a_index in range(8):
            t = ProductTerms(8, 128, 8, a_index, 8)
            setattr(m.submodules, "terms_%d" % a_index, t)

            m.d.comb += t.a.eq(self.a)
            m.d.comb += t.b.eq(self.b)
            m.d.comb += t.pb_en.eq(pbs)

            for term in t.terms:
                terms.append(term)

        # it's fine to bitwise-or data together since they are never enabled
        # at the same time
        m.submodules.nat_or = nat_or = OrMod(128)
        m.submodules.nbt_or = nbt_or = OrMod(128)
        m.submodules.nla_or = nla_or = OrMod(128)
        m.submodules.nlb_or = nlb_or = OrMod(128)
        for l, mod in [(nat_l, nat_or),
                       (nbt_l, nbt_or),
                       (nla_l, nla_or),
                       (nlb_l, nlb_or)]:
            for i in range(len(l)):
                m.d.comb += mod.orin[i].eq(l[i])
            terms.append(mod.orout)

        # the reduction tree works on the 128-bit expanded result, so
        # the partition points must be doubled
        expanded_part_pts = PartitionPoints()
        for i, v in self.part_pts.items():
            signal = Signal(name=f"expanded_part_pts_{i*2}", reset_less=True)
            expanded_part_pts[i * 2] = signal
            m.d.comb += signal.eq(v)

        add_reduce = AddReduce(terms,
                               128,
                               self.register_levels,
                               expanded_part_pts,
                               self.part_ops)
        out_part_ops = add_reduce.levels[-1].out_part_ops

        m.submodules.add_reduce = add_reduce
        m.d.comb += self._intermediate_output.eq(add_reduce.output)

        # lane-select HI/LO halves for each supported partition width
        m.submodules.io64 = io64 = IntermediateOut(64, 128, 1)
        m.d.comb += io64.intermed.eq(self._intermediate_output)
        for i in range(8):
            m.d.comb += io64.part_ops[i].eq(out_part_ops[i])

        m.submodules.io32 = io32 = IntermediateOut(32, 128, 2)
        m.d.comb += io32.intermed.eq(self._intermediate_output)
        for i in range(8):
            m.d.comb += io32.part_ops[i].eq(out_part_ops[i])

        m.submodules.io16 = io16 = IntermediateOut(16, 128, 4)
        m.d.comb += io16.intermed.eq(self._intermediate_output)
        for i in range(8):
            m.d.comb += io16.part_ops[i].eq(out_part_ops[i])

        m.submodules.io8 = io8 = IntermediateOut(8, 128, 8)
        m.d.comb += io8.intermed.eq(self._intermediate_output)
        for i in range(8):
            m.d.comb += io8.part_ops[i].eq(out_part_ops[i])

        # final output byte-wise multiplexing by partition width
        m.submodules.finalout = finalout = FinalOut(64)
        for i in range(len(part_8.delayed_parts[-1])):
            m.d.comb += finalout.d8[i].eq(part_8.dplast[i])
        for i in range(len(part_16.delayed_parts[-1])):
            m.d.comb += finalout.d16[i].eq(part_16.dplast[i])
        for i in range(len(part_32.delayed_parts[-1])):
            m.d.comb += finalout.d32[i].eq(part_32.dplast[i])
        m.d.comb += finalout.i8.eq(io8.output)
        m.d.comb += finalout.i16.eq(io16.output)
        m.d.comb += finalout.i32.eq(io32.output)
        m.d.comb += finalout.i64.eq(io64.output)
        m.d.comb += self.output.eq(finalout.out)

        return m
if __name__ == "__main__":
    # generate Verilog / run via nmigen's command-line interface
    m = Mul8_16_32_64()
    main(m, ports=[m.a,
                   m.b,
                   m._intermediate_output,
                   m.output,
                   *m.part_ops,
                   *m.part_pts.values()])