1 # SPDX-License-Identifier: LGPL-2.1-or-later
2 # See Notices.txt for copyright information
3 """Integer Multiplication."""
5 from nmigen
import Signal
, Module
, Value
, Elaboratable
, Cat
, C
, Mux
, Repl
6 from nmigen
.hdl
.ast
import Assign
7 from abc
import ABCMeta
, abstractmethod
8 from nmigen
.cli
import main
9 from functools
import reduce
10 from operator
import or_
class PartitionPoints(dict):
    """Partition points and corresponding ``Value``s.

    The points at where an ALU is partitioned along with ``Value``s that
    specify if the corresponding partition points are enabled.

    For example: ``{1: True, 5: True, 10: True}`` with
    ``width == 16`` specifies that the ALU is split into 4 sections:
    * bits 0 <= ``i`` < 1
    * bits 1 <= ``i`` < 5
    * bits 5 <= ``i`` < 10
    * bits 10 <= ``i`` < 16

    If the partition_points were instead ``{1: True, 5: a, 10: True}``
    where ``a`` is a 1-bit ``Signal``:
    * If ``a`` is asserted:
        * bits 0 <= ``i`` < 1
        * bits 1 <= ``i`` < 5
        * bits 5 <= ``i`` < 10
        * bits 10 <= ``i`` < 16
    * Otherwise
        * bits 0 <= ``i`` < 1
        * bits 1 <= ``i`` < 10
        * bits 10 <= ``i`` < 16
    """

    def __init__(self, partition_points=None):
        """Create a new ``PartitionPoints``.

        :param partition_points: the input partition points to values mapping.
        :raises TypeError: if a point is not an ``int``.
        :raises ValueError: if a point is negative.
        """
        super().__init__()
        if partition_points is not None:
            for point, enabled in partition_points.items():
                if not isinstance(point, int):
                    raise TypeError("point must be a non-negative integer")
                if point < 0:
                    raise ValueError("point must be a non-negative integer")
                # store every enable as a Value so that .shape()/.eq() work
                self[point] = Value.wrap(enabled)

    def like(self, name=None, src_loc_at=0, mul=1):
        """Create a new ``PartitionPoints`` with ``Signal``s for all values.

        :param name: the base name for the new ``Signal``s.
        :param mul: a multiplication factor on the indices
        :returns: a new ``PartitionPoints`` keyed by ``point * mul`` whose
            values are fresh ``Signal``s named ``{name}_{point * mul}``.
        """
        if name is None:
            # a throw-away Signal is created purely to obtain the name that
            # the caller is assigning to
            name = Signal(src_loc_at=1+src_loc_at).name  # get variable name
        retval = PartitionPoints()
        for point, enabled in self.items():
            point *= mul
            retval[point] = Signal(enabled.shape(), name=f"{name}_{point}")
        return retval

    def eq(self, rhs):
        """Assign ``PartitionPoints`` using ``Signal.eq``.

        This is a generator: the assignments (and the key-set check) are
        only produced once the result is iterated.

        :param rhs: mapping with exactly the same set of points as ``self``.
        :raises ValueError: if the two point sets differ.
        """
        if set(self.keys()) != set(rhs.keys()):
            raise ValueError("incompatible point set")
        for point, enabled in self.items():
            yield enabled.eq(rhs[point])

    def as_mask(self, width):
        """Create a bit-mask from `self`.

        Each bit in the returned mask is clear only if the partition point at
        the same bit-index is enabled.

        :param width: the bit width of the resulting mask
        """
        bits = []
        for i in range(width):
            if i in self:
                bits.append(~self[i])
            else:
                bits.append(True)
        return Cat(*bits)

    def get_max_partition_count(self, width):
        """Get the maximum number of partitions.

        Gets the number of partitions when all partition points are enabled.

        :param width: only points strictly below ``width`` are counted.
        """
        return 1 + sum(1 for point in self if point < width)

    def fits_in_width(self, width):
        """Check if all partition points are smaller than `width`."""
        return all(point < width for point in self)

    def part_byte(self, index, mfactor=1):  # mfactor used for "expanding"
        """Return the enable for the partition point that ends byte ``index``.

        :param index: byte index, ``-1 <= index <= 7``.  The two outer
            values (``-1`` and ``7``) always yield a constant 1.
        :param mfactor: multiplier applied to the bit position of the point
            (used when the points have been "expanded").
        """
        if index == -1 or index == 7:
            return C(True, 1)
        assert index >= 0 and index < 8
        return self[(index * 8 + 8)*mfactor]
class FullAdder(Elaboratable):
    """Full Adder.

    :attribute in0: the first input
    :attribute in1: the second input
    :attribute in2: the third input
    :attribute sum: the sum output
    :attribute carry: the carry output

    Rather than do individual full adders (and have an array of them,
    which would be very slow to simulate), this module can specify the
    bit width of the inputs and outputs: in effect it performs multiple
    Full 3-2 Add operations "in parallel".
    """

    def __init__(self, width):
        """Create a ``FullAdder``.

        :param width: the bit width of the input and output
        """
        self.in0 = Signal(width)
        self.in1 = Signal(width)
        self.in2 = Signal(width)
        self.sum = Signal(width)
        self.carry = Signal(width)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        # bitwise sum: XOR of the three inputs
        m.d.comb += self.sum.eq(self.in0 ^ self.in1 ^ self.in2)
        # bitwise carry: set wherever at least two of the inputs are set
        m.d.comb += self.carry.eq((self.in0 & self.in1)
                                  | (self.in1 & self.in2)
                                  | (self.in2 & self.in0))
        return m
class MaskedFullAdder(Elaboratable):
    """Masked Full Adder.

    :attribute mask: the carry partition mask
    :attribute in0: the first input
    :attribute in1: the second input
    :attribute in2: the third input
    :attribute sum: the sum output
    :attribute mcarry: the masked carry output

    FullAdders are always used with a "mask" on the output.  To keep
    the graphviz "clean", this class performs the masking here rather
    than inside a large for-loop.

    See the following discussion as to why this is no longer derived
    from FullAdder.  Each carry is shifted here *before* being ANDed
    with the mask, so that an AOI cell may be used (which is more
    gate-efficient)
    https://en.wikipedia.org/wiki/AND-OR-Invert
    https://groups.google.com/d/msg/comp.arch/fcq-GLQqvas/vTxmcA0QAgAJ
    """

    def __init__(self, width):
        """Create a ``MaskedFullAdder``.

        :param width: the bit width of the input and output
        """
        self.width = width  # read again in elaborate() for the temporaries
        self.mask = Signal(width, reset_less=True)
        self.mcarry = Signal(width, reset_less=True)
        self.in0 = Signal(width, reset_less=True)
        self.in1 = Signal(width, reset_less=True)
        self.in2 = Signal(width, reset_less=True)
        self.sum = Signal(width, reset_less=True)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        # NOTE: the local names below become the Signal names, keep them.
        s1 = Signal(self.width, reset_less=True)
        s2 = Signal(self.width, reset_less=True)
        s3 = Signal(self.width, reset_less=True)
        c1 = Signal(self.width, reset_less=True)
        c2 = Signal(self.width, reset_less=True)
        c3 = Signal(self.width, reset_less=True)
        m.d.comb += self.sum.eq(self.in0 ^ self.in1 ^ self.in2)
        # Cat(0, x) puts a zero bit below x: each input moved up one bit
        m.d.comb += s1.eq(Cat(0, self.in0))
        m.d.comb += s2.eq(Cat(0, self.in1))
        m.d.comb += s3.eq(Cat(0, self.in2))
        # pairwise AND of the shifted inputs, each gated by the mask
        m.d.comb += c1.eq(s1 & s2 & self.mask)
        m.d.comb += c2.eq(s2 & s3 & self.mask)
        m.d.comb += c3.eq(s3 & s1 & self.mask)
        m.d.comb += self.mcarry.eq(c1 | c2 | c3)
        return m
class PartitionedAdder(Elaboratable):
    """Partitioned Adder.

    Performs the final add.  The partition points are included in the
    actual add (in one of the operands only), which causes a carry over
    to the next bit.  Then the final output *removes* the extra bits from
    the result.

    partition: .... P... P... P... P... (32 bits)
    a        : .... .... .... .... .... (32 bits)
    b        : .... .... .... .... .... (32 bits)
    exp-a    : ....P....P....P....P.... (32+4 bits, P=1 if no partition)
    exp-b    : ....0....0....0....0.... (32 bits plus 4 zeros)
    exp-o    : ....xN...xN...xN...xN... (32+4 bits - x to be discarded)
    o        : .... N... N... N... N... (32 bits - x ignored, N is carry-over)

    :attribute width: the bit width of the input and output. Read-only.
    :attribute a: the first input to the adder
    :attribute b: the second input to the adder
    :attribute output: the sum output
    :attribute partition_points: the input partition points. Modification not
        supported, except for by ``Signal.eq``.
    """

    def __init__(self, width, partition_points):
        """Create a ``PartitionedAdder``.

        :param width: the bit width of the input and output
        :param partition_points: the input partition points
        :raises ValueError: if a partition point is >= ``width``.
        """
        self.width = width
        self.a = Signal(width)
        self.b = Signal(width)
        self.output = Signal(width)
        self.partition_points = PartitionPoints(partition_points)
        if not self.partition_points.fits_in_width(width):
            raise ValueError("partition_points doesn't fit in width")
        # one bit per data bit, plus one extra bit per partition point
        expanded_width = width + sum(
            1 for i in range(width) if i in self.partition_points)
        self._expanded_width = expanded_width
        # XXX these have to remain here due to some horrible nmigen
        # simulation bugs involving sync.  it is *not* necessary to
        # have them here, they should (under normal circumstances)
        # be moved into elaborate, as they are entirely local
        self._expanded_a = Signal(expanded_width)  # includes extra part-points
        self._expanded_b = Signal(expanded_width)  # likewise.
        self._expanded_o = Signal(expanded_width)  # likewise.

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        expanded_index = 0
        # store bits in a list, use Cat later.  graphviz is much cleaner
        al, bl, ol, ea, eb, eo = [], [], [], [], [], []

        # partition points are "breaks" (extra zeros or 1s) in what would
        # otherwise be a massive long add.  when the "break" points are 0,
        # whatever is in it (in the output) is discarded.  however when
        # there is a "1", it causes a roll-over carry to the *next* bit.
        # we still ignore the "break" bit in the [intermediate] output,
        # however by that time we've got the effect that we wanted: the
        # carry has been carried *over* the break point.

        for i in range(self.width):
            if i in self.partition_points:
                # add extra bit set to 0 + 0 for enabled partition points
                # and 1 + 0 for disabled partition points
                ea.append(self._expanded_a[expanded_index])
                al.append(~self.partition_points[i])  # add extra bit in a
                eb.append(self._expanded_b[expanded_index])
                bl.append(C(0))  # yes, add a zero
                expanded_index += 1  # skip the extra point.  NOT in the output
            ea.append(self._expanded_a[expanded_index])
            eb.append(self._expanded_b[expanded_index])
            eo.append(self._expanded_o[expanded_index])
            al.append(self.a[i])
            bl.append(self.b[i])
            ol.append(self.output[i])
            expanded_index += 1

        # combine above using Cat
        m.d.comb += Cat(*ea).eq(Cat(*al))
        m.d.comb += Cat(*eb).eq(Cat(*bl))
        m.d.comb += Cat(*ol).eq(Cat(*eo))

        # use only one addition to take advantage of look-ahead carry and
        # special hardware on FPGAs
        m.d.comb += self._expanded_o.eq(
            self._expanded_a + self._expanded_b)
        return m
# number of inputs consumed by one full adder (3:2 compression)
FULL_ADDER_INPUT_COUNT = 3


class AddReduceData:
    """Bundle of the values passed between add-reduce stages.

    :attribute part_ops: one 2-bit operation ``Signal`` per part.
    :attribute inputs: the terms to be summed, each ``output_width`` wide.
    :attribute reg_partition_points: ``Signal`` copies of the partition
        points (created with ``like()``).
    """

    def __init__(self, ppoints, output_width, n_parts, n_inputs=0):
        """Create an ``AddReduceData``.

        :param ppoints: partition points to clone with ``like()``.
        :param output_width: bit-width of each entry of ``inputs``.
        :param n_parts: number of entries in ``part_ops``.
        :param n_inputs: number of entries in ``inputs``.  It has to be
            given explicitly: it cannot be derived from ``self.inputs``,
            which does not exist until this constructor has created it.
        """
        self.part_ops = [Signal(2, name=f"part_ops_{i}")
                         for i in range(n_parts)]
        self.inputs = [Signal(output_width, name=f"inputs[{i}]")
                       for i in range(n_inputs)]
        self.reg_partition_points = ppoints.like()

    def eq(self, rhs):
        """Return the list of assignments copying ``rhs`` into ``self``.

        :param rhs: another object with ``reg_partition_points``,
            ``inputs`` and ``part_ops`` of matching sizes.
        :raises ValueError: if the partition point sets differ.
        """
        # PartitionPoints.eq is a generator: expand it so that the result
        # is one flat list and a point-set mismatch is reported here.
        assigns = list(self.reg_partition_points.eq(rhs.reg_partition_points))
        assigns += [term.eq(rhs.inputs[i])
                    for i, term in enumerate(self.inputs)]
        assigns += [op.eq(rhs.part_ops[i])
                    for i, op in enumerate(self.part_ops)]
        return assigns
class FinalAdd(Elaboratable):
    """ Final stage of add reduce

    Takes at most two remaining terms and produces ``output`` using a
    single ``PartitionedAdder`` (or a pass-through / zero when there are
    fewer than two terms).
    """

    def __init__(self, inputs, output_width, register_levels, partition_points,
                 part_ops):
        """Create a ``FinalAdd``.

        :param inputs: the (0, 1 or 2) terms still to be summed.
        :param output_width: bit-width of ``output``.
        :param register_levels: nesting levels that should have pipeline
            registers; only whether ``0`` is present is examined here.
        :param partition_points: the input partition points.
        :param part_ops: per-part operation signals, copied unchanged to
            ``out_part_ops``.
        :raises ValueError: if a partition point is >= ``output_width``.
        """
        self.part_ops = part_ops
        self.out_part_ops = [Signal(2, name=f"out_part_ops_{i}")
                             for i in range(len(part_ops))]
        self.inputs = list(inputs)
        self._resized_inputs = [
            Signal(output_width, name=f"resized_inputs[{i}]")
            for i in range(len(self.inputs))]
        self.register_levels = list(register_levels)
        self.output = Signal(output_width)
        self.partition_points = PartitionPoints(partition_points)
        if not self.partition_points.fits_in_width(output_width):
            raise ValueError("partition_points doesn't fit in output_width")
        self._reg_partition_points = self.partition_points.like()

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()

        # resize inputs to correct bit-width and optionally add in
        # pipeline registers
        resized_input_assignments = [
            resized.eq(self.inputs[i])
            for i, resized in enumerate(self._resized_inputs)]
        copy_part_ops = [out_op.eq(self.part_ops[i])
                         for i, out_op in enumerate(self.out_part_ops)]
        # level 0 registered -> clocked copy, otherwise plain wires
        domain = m.d.sync if 0 in self.register_levels else m.d.comb
        domain += copy_part_ops
        domain += resized_input_assignments
        domain += self._reg_partition_points.eq(self.partition_points)

        if len(self.inputs) == 0:
            # use 0 as the default output value
            m.d.comb += self.output.eq(0)
        elif len(self.inputs) == 1:
            # handle single input
            m.d.comb += self.output.eq(self._resized_inputs[0])
        else:
            # base case for adding 2 inputs
            assert len(self.inputs) == 2
            adder = PartitionedAdder(len(self.output),
                                     self._reg_partition_points)
            m.submodules.final_adder = adder
            m.d.comb += adder.a.eq(self._resized_inputs[0])
            m.d.comb += adder.b.eq(self._resized_inputs[1])
            m.d.comb += self.output.eq(adder.output)
        return m
class AddReduceSingle(Elaboratable):
    """Add list of numbers together.

    One level of 3:2 compression: every group of three inputs is fed to a
    ``MaskedFullAdder`` and replaced by its sum and masked carry; leftover
    inputs are passed through.  The results are ``intermediate_terms``.

    :attribute inputs: input ``Signal``s to be summed. Modification not
        supported, except for by ``Signal.eq``.
    :attribute register_levels: List of nesting levels that should have
        pipeline registers.
    :attribute intermediate_terms: the terms produced by this level.
    :attribute partition_points: the input partition points. Modification not
        supported, except for by ``Signal.eq``.
    """

    def __init__(self, inputs, output_width, register_levels, partition_points,
                 part_ops):
        """Create an ``AddReduceSingle``.

        :param inputs: input ``Signal``s to be summed.
        :param output_width: bit-width of the terms.
        :param register_levels: List of nesting levels that should have
            pipeline registers.
        :param partition_points: the input partition points.
        :param part_ops: per-part operation signals, copied unchanged to
            ``out_part_ops``.
        :raises ValueError: if a partition point is >= ``output_width`` or
            a register level exceeds ``get_max_level(len(inputs))``.
        """
        self.output_width = output_width
        self.part_ops = part_ops
        self.out_part_ops = [Signal(2, name=f"out_part_ops_{i}")
                             for i in range(len(part_ops))]
        self.inputs = list(inputs)
        self._resized_inputs = [
            Signal(output_width, name=f"resized_inputs[{i}]")
            for i in range(len(self.inputs))]
        self.register_levels = list(register_levels)
        self.partition_points = PartitionPoints(partition_points)
        if not self.partition_points.fits_in_width(output_width):
            raise ValueError("partition_points doesn't fit in output_width")
        self._reg_partition_points = self.partition_points.like()

        max_level = AddReduceSingle.get_max_level(len(self.inputs))
        for level in self.register_levels:
            if level > max_level:
                raise ValueError(
                    "not enough adder levels for specified register levels")

        # this is annoying.  we have to create the modules (and terms)
        # because we need to know what they are (in order to set up the
        # interconnects back in AddReduce), but cannot do the m.d.comb +=
        # etc because this is not in elaboratable.
        self.groups = AddReduceSingle.full_adder_groups(len(self.inputs))
        self._intermediate_terms = []
        # always build the terms: with no full-adder group this simply
        # passes the (fewer than three) inputs through, and guarantees that
        # part_mask, adders and intermediate_terms exist for elaborate().
        self.create_next_terms()

    @staticmethod
    def get_max_level(input_count):
        """Get the maximum level.

        All ``register_levels`` must be less than or equal to the maximum
        level.

        :param input_count: number of inputs at the first level.
        :returns: how many reduction levels are needed until no group of
            ``FULL_ADDER_INPUT_COUNT`` inputs can be formed any more.
        """
        retval = 0
        while True:
            groups = AddReduceSingle.full_adder_groups(input_count)
            if len(groups) == 0:
                return retval
            # leftovers pass through, every group produces two terms
            input_count %= FULL_ADDER_INPUT_COUNT
            input_count += 2 * len(groups)
            retval += 1

    @staticmethod
    def full_adder_groups(input_count):
        """Get ``inputs`` indices for which a full adder should be built."""
        return range(0,
                     input_count - FULL_ADDER_INPUT_COUNT + 1,
                     FULL_ADDER_INPUT_COUNT)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()

        # resize inputs to correct bit-width and optionally add in
        # pipeline registers
        resized_input_assignments = [self._resized_inputs[i].eq(self.inputs[i])
                                     for i in range(len(self.inputs))]
        copy_part_ops = [self.out_part_ops[i].eq(self.part_ops[i])
                         for i in range(len(self.part_ops))]
        if 0 in self.register_levels:
            m.d.sync += copy_part_ops
            m.d.sync += resized_input_assignments
            m.d.sync += self._reg_partition_points.eq(self.partition_points)
        else:
            m.d.comb += copy_part_ops
            m.d.comb += resized_input_assignments
            m.d.comb += self._reg_partition_points.eq(self.partition_points)

        for (value, term) in self._intermediate_terms:
            m.d.comb += term.eq(value)

        mask = self._reg_partition_points.as_mask(self.output_width)
        m.d.comb += self.part_mask.eq(mask)

        # add and link the intermediate term modules
        for i, (iidx, adder_i) in enumerate(self.adders):
            setattr(m.submodules, f"adder_{i}", adder_i)

            m.d.comb += adder_i.in0.eq(self._resized_inputs[iidx])
            m.d.comb += adder_i.in1.eq(self._resized_inputs[iidx + 1])
            m.d.comb += adder_i.in2.eq(self._resized_inputs[iidx + 2])
            m.d.comb += adder_i.mask.eq(self.part_mask)

        return m

    def create_next_terms(self):
        """Create the adders and output terms of this level.

        Sets ``part_mask``, ``adders`` (list of ``(input_index, adder)``),
        ``intermediate_terms`` (the output ``Signal``s) and
        ``_intermediate_terms`` (``(value, signal)`` pairs wired up in
        ``elaborate``).
        """

        # go on to prepare recursive case
        intermediate_terms = []
        _intermediate_terms = []

        def add_intermediate_term(value):
            intermediate_term = Signal(
                self.output_width,
                name=f"intermediate_terms[{len(intermediate_terms)}]")
            _intermediate_terms.append((value, intermediate_term))
            intermediate_terms.append(intermediate_term)

        # store mask in intermediary (simplifies graph)
        self.part_mask = Signal(self.output_width, reset_less=True)

        # create full adders for this recursive level.
        # this shrinks N terms to 2 * (N // 3) plus the remainder
        self.adders = []
        for i in self.groups:
            adder_i = MaskedFullAdder(self.output_width)
            self.adders.append((i, adder_i))
            # add both the sum and the masked-carry to the next level.
            # 3 inputs have now been reduced to 2...
            add_intermediate_term(adder_i.sum)
            add_intermediate_term(adder_i.mcarry)
        # handle the remaining inputs.
        if len(self.inputs) % FULL_ADDER_INPUT_COUNT == 1:
            add_intermediate_term(self._resized_inputs[-1])
        elif len(self.inputs) % FULL_ADDER_INPUT_COUNT == 2:
            # Just pass the terms to the next layer, since we wouldn't gain
            # anything by using a half adder since there would still be 2 terms
            # and just passing the terms to the next layer saves gates.
            add_intermediate_term(self._resized_inputs[-2])
            add_intermediate_term(self._resized_inputs[-1])
        else:
            assert len(self.inputs) % FULL_ADDER_INPUT_COUNT == 0

        self.intermediate_terms = intermediate_terms
        self._intermediate_terms = _intermediate_terms
class AddReduce(Elaboratable):
    """Recursively Add list of numbers together.

    Builds a chain of ``AddReduceSingle`` levels followed by one
    ``FinalAdd``; ``output`` is taken from the last level.

    :attribute inputs: input ``Signal``s to be summed. Modification not
        supported, except for by ``Signal.eq``.
    :attribute register_levels: List of nesting levels that should have
        pipeline registers.
    :attribute output: output sum.
    :attribute partition_points: the input partition points. Modification not
        supported, except for by ``Signal.eq``.
    """

    def __init__(self, inputs, output_width, register_levels, partition_points,
                 part_ops):
        """Create an ``AddReduce``.

        :param inputs: input ``Signal``s to be summed.
        :param output_width: bit-width of ``output``.
        :param register_levels: List of nesting levels that should have
            pipeline registers.
        :param partition_points: the input partition points.
        :param part_ops: per-part operation signals, carried alongside the
            data and exposed again as ``out_part_ops``.
        """
        self.inputs = inputs
        self.part_ops = part_ops
        self.out_part_ops = [Signal(2, name=f"out_part_ops_{i}")
                             for i in range(len(part_ops))]
        self.output = Signal(output_width)
        self.output_width = output_width
        self.register_levels = register_levels
        self.partition_points = partition_points

        self.create_levels()

    @staticmethod
    def get_max_level(input_count):
        """Get the maximum level (see ``AddReduceSingle.get_max_level``)."""
        return AddReduceSingle.get_max_level(input_count)

    @staticmethod
    def next_register_levels(register_levels):
        """``Iterable`` of ``register_levels`` for next recursive level."""
        for level in register_levels:
            if level > 0:
                yield level - 1

    def create_levels(self):
        """creates reduction levels"""

        mods = []
        next_levels = self.register_levels
        partition_points = self.partition_points
        inputs = list(self.inputs)
        part_ops = self.part_ops
        # only add a 3:2 compression level while at least one full-adder
        # group can be formed; fewer than three terms go straight to the
        # final adder.
        while len(AddReduceSingle.full_adder_groups(len(inputs))) != 0:
            next_level = AddReduceSingle(inputs, self.output_width, next_levels,
                                         partition_points, part_ops)
            mods.append(next_level)
            # each level consumes one register level and hands its own
            # (possibly registered) partition points / ops to the next one
            next_levels = list(AddReduce.next_register_levels(next_levels))
            partition_points = next_level._reg_partition_points
            inputs = next_level.intermediate_terms
            part_ops = next_level.out_part_ops

        next_level = FinalAdd(inputs, self.output_width, next_levels,
                              partition_points, part_ops)
        mods.append(next_level)

        self.levels = mods

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()

        for i, next_level in enumerate(self.levels):
            setattr(m.submodules, "next_level%d" % i, next_level)

        # output comes from last module
        last_level = self.levels[-1]
        m.d.comb += self.output.eq(last_level.output)
        copy_part_ops = [self.out_part_ops[i].eq(last_level.out_part_ops[i])
                         for i in range(len(self.part_ops))]
        m.d.comb += copy_part_ops

        return m
# 2-bit operation codes, selected per part
OP_MUL_LOW = 0  # the LSB half of the product
OP_MUL_SIGNED_HIGH = 1
OP_MUL_SIGNED_UNSIGNED_HIGH = 2  # a is signed, b is unsigned
OP_MUL_UNSIGNED_HIGH = 3
def get_term(value, shift=0, enabled=None):
    """Return ``value`` optionally gated by ``enabled`` and shifted left.

    :param value: the value to transform.
    :param shift: number of zero bits to place below ``value``; must not
        be negative.
    :param enabled: if not ``None``, the result is ``value`` when
        ``enabled`` is set and 0 otherwise.
    :returns: the transformed value (``value`` itself when ``shift`` is 0
        and ``enabled`` is ``None``).
    """
    if enabled is not None:
        value = Mux(enabled, value, 0)
    if shift > 0:
        # prepend `shift` zero bits, i.e. multiply by 2**shift
        value = Cat(Repl(C(0, 1), shift), value)
    else:
        assert shift == 0
    return value
class ProductTerm(Elaboratable):
    """ this class creates a single product term (a[..]*b[..]).
        it has a design flaw in that is the *output* that is selected,
        where the multiplication(s) are combinatorially generated
        all the time.
    """

    def __init__(self, width, twidth, pbwid, a_index, b_index):
        """Create a ``ProductTerm``.

        :param width: bit-width of the slice taken from each operand.
        :param twidth: bit-width of ``term``; ``a`` and ``b`` are half that.
        :param pbwid: bit-width of ``pb_en`` (the part-byte enables).
        :param a_index: index of the slice of ``a`` to multiply.
        :param b_index: index of the slice of ``b`` to multiply.
        """
        self.a_index = a_index
        self.b_index = b_index
        shift = 8 * (self.a_index + self.b_index)
        self.pwidth = width
        self.twidth = twidth
        self.width = width*2
        self.shift = shift

        self.ti = Signal(self.width, reset_less=True)
        self.term = Signal(twidth, reset_less=True)
        self.a = Signal(twidth//2, reset_less=True)
        self.b = Signal(twidth//2, reset_less=True)
        self.pb_en = Signal(pbwid, reset_less=True)

        # enables of the partition points lying between the two indices
        min_index = min(self.a_index, self.b_index)
        max_index = max(self.a_index, self.b_index)
        self.tl = tl = [self.pb_en[i] for i in range(min_index, max_index)]
        name = "te_%d_%d" % (self.a_index, self.b_index)
        if len(tl) > 0:
            term_enabled = Signal(name=name, reset_less=True)
        else:
            # no partition point in between: nothing gates the term
            term_enabled = None
        self.enabled = term_enabled
        self.term.name = "term_%d_%d" % (a_index, b_index)  # rename

    def elaborate(self, platform):
        """Elaborate this module."""

        m = Module()
        if self.enabled is not None:
            # enabled only when none of the in-between points is set
            m.d.comb += self.enabled.eq(~(Cat(*self.tl).bool()))

        bsa = Signal(self.width, reset_less=True)
        bsb = Signal(self.width, reset_less=True)
        a_index, b_index = self.a_index, self.b_index
        pwidth = self.pwidth
        m.d.comb += bsa.eq(self.a.part(a_index * pwidth, pwidth))
        m.d.comb += bsb.eq(self.b.part(b_index * pwidth, pwidth))
        m.d.comb += self.ti.eq(bsa * bsb)
        m.d.comb += self.term.eq(get_term(self.ti, self.shift, self.enabled))

        # TODO: sort out width issues, get inputs a/b switched on/off.
        # data going into Muxes is 1/2 the required width.  The intended
        # alternative gates/shifts the *operands* instead of the product:
        #   asel/bsel (width `self.width`) = the selected slices of a / b
        #   bsa/bsb (width `self.twidth//2`)
        #       = get_term(asel / bsel, self.shift, self.enabled)
        #   self.ti = bsa * bsb ; self.term = self.ti

        return m
class ProductTerms(Elaboratable):
    """ creates a bank of product terms.  also performs the actual bit-selection
        this class is to be wrapped with a for-loop on the "a" operand.
        it creates a second-level for-loop on the "b" operand.
    """

    def __init__(self, width, twidth, pbwid, a_index, blen):
        """Create a ``ProductTerms``.

        :param width: bit-width of each operand slice (passed on to
            ``ProductTerm``).
        :param twidth: bit-width of each entry of ``terms``.
        :param pbwid: bit-width of ``pb_en``.
        :param a_index: index of the slice of ``a`` used by every term.
        :param blen: number of ``b`` slices, i.e. number of terms.
        """
        self.a_index = a_index
        self.blen = blen
        self.pwidth = width
        self.twidth = twidth
        self.pbwid = pbwid
        self.a = Signal(twidth//2, reset_less=True)
        self.b = Signal(twidth//2, reset_less=True)
        self.pb_en = Signal(pbwid, reset_less=True)
        self.terms = [Signal(twidth, name="term%d" % i, reset_less=True)
                      for i in range(blen)]

    def elaborate(self, platform):
        """Elaborate this module."""

        m = Module()

        for b_index in range(self.blen):
            t = ProductTerm(self.pwidth, self.twidth, self.pbwid,
                            self.a_index, b_index)
            setattr(m.submodules, "term_%d" % b_index, t)

            # every term sees the complete operands and part-byte enables
            m.d.comb += t.a.eq(self.a)
            m.d.comb += t.b.eq(self.b)
            m.d.comb += t.pb_en.eq(self.pb_en)

            m.d.comb += self.terms[b_index].eq(t.term)

        return m
class LSBNegTerm(Elaboratable):
    """Produce the two correction terms for one partition.

    Both outputs are zero unless ``part``, ``msb`` and ``signed`` are all
    set.  When they are, ``nt`` holds the bit-inverted ``op`` in its upper
    half and ``nl`` holds a single 1 at bit ``bit_width`` (the "+1" that
    completes the negation).
    """

    def __init__(self, bit_width):
        """Create an ``LSBNegTerm``.

        :param bit_width: bit-width of ``op``; ``nt``/``nl`` are twice that.
        """
        self.bit_width = bit_width
        self.part = Signal(reset_less=True)
        self.signed = Signal(reset_less=True)
        self.op = Signal(bit_width, reset_less=True)
        self.msb = Signal(reset_less=True)
        self.nt = Signal(bit_width*2, reset_less=True)
        self.nl = Signal(bit_width*2, reset_less=True)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        comb = m.d.comb
        bit_wid = self.bit_width
        ext = Repl(0, bit_wid)  # extend output to HI part

        # determine sign of each incoming number *in this partition*
        enabled = Signal(reset_less=True)
        m.d.comb += enabled.eq(self.part & self.msb & self.signed)

        # for 8-bit values: form a * 0xFF00 by using -a * 0x100, the
        # negation operation is split into a bitwise not and a +1.
        # likewise for 16, 32, and 64-bit values.

        # width-extended 1s complement if a is signed, otherwise zero
        comb += self.nt.eq(Mux(enabled, Cat(ext, ~self.op), 0))

        # add 1 if signed, otherwise add zero
        comb += self.nl.eq(Cat(ext, enabled, Repl(0, bit_wid-1)))

        return m
class Parts(Elaboratable):
    """Work out which equally-sized partitions are currently selected.

    ``parts[i]`` is asserted when the partition points at both ends of
    lane ``i`` are set and none of the points inside the lane is set.
    """

    def __init__(self, pbwid, epps, n_parts):
        """Create a ``Parts``.

        :param pbwid: number of part-byte bits to collect.
        :param epps: the (expanded) partition points; ``Signal`` copies
            named ``epps_*`` are created from it.
        :param n_parts: number of lanes, i.e. entries of ``parts``.
        """
        self.pbwid = pbwid
        # inputs
        self.epps = PartitionPoints.like(epps, name="epps")  # expanded points
        # outputs
        self.parts = [Signal(name=f"part_{i}") for i in range(n_parts)]

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()

        epps, parts = self.epps, self.parts
        # collect part-bytes (double factor because the input is extended)
        pbs = Signal(self.pbwid, reset_less=True)
        tl = []
        for i in range(self.pbwid):
            pb = Signal(name="pb%d" % i, reset_less=True)
            m.d.comb += pb.eq(epps.part_byte(i, mfactor=2))  # double
            tl.append(pb)
        m.d.comb += pbs.eq(Cat(*tl))

        # negated-temporary copy of partition bits
        npbs = Signal.like(pbs, reset_less=True)
        m.d.comb += npbs.eq(~pbs)
        byte_count = 8 // len(parts)
        for i in range(len(parts)):
            # lane i is selected when `value` is all-zero: both boundary
            # points set (negated copy is 0) and no interior point set.
            # for i == 0 the lower boundary index is -1, i.e. the last bit.
            pbl = []
            pbl.append(npbs[i * byte_count - 1])
            for j in range(i * byte_count, (i + 1) * byte_count - 1):
                pbl.append(pbs[j])
            pbl.append(npbs[(i + 1) * byte_count - 1])
            value = Signal(len(pbl), name="value_%d" % i, reset_less=True)
            m.d.comb += value.eq(Cat(*pbl))
            m.d.comb += parts[i].eq(~(value).bool())

        return m
class Part(Elaboratable):
    """ a key class which, depending on the partitioning, will determine
        what action to take when parts of the output are signed or unsigned.

        this requires 2 pieces of data *per operand, per partition*:
        whether the MSB is HI/LO (per partition!), and whether a signed
        or unsigned operation has been *requested*.

        once that is determined, signed is basically carried out
        by splitting 2's complement into 1's complement plus one.
        1's complement is just a bit-inversion.

        the extra terms - as separate terms - are then thrown at the
        AddReduce alongside the multiplication part-results.
    """

    def __init__(self, epps, width, n_parts, n_levels, pbwid):
        """Create a ``Part``.

        :param epps: the (expanded) partition points, handed to ``Parts``.
        :param width: bit-width of the four output terms.
        :param n_parts: number of lanes the operands are divided into.
        :param n_levels: not used by this class.
        :param pbwid: bit-width of ``pbs``; also handed to ``Parts``.
        """

        self.pbwid = pbwid
        self.epps = epps

        # inputs
        self.a = Signal(64)
        self.b = Signal(64)
        self.a_signed = [Signal(name=f"a_signed_{i}") for i in range(8)]
        self.b_signed = [Signal(name=f"_b_signed_{i}") for i in range(8)]
        self.pbs = Signal(pbwid, reset_less=True)

        # outputs
        self.parts = [Signal(name=f"part_{i}") for i in range(n_parts)]

        self.not_a_term = Signal(width)
        self.neg_lsb_a_term = Signal(width)
        self.not_b_term = Signal(width)
        self.neg_lsb_b_term = Signal(width)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()

        epps = self.epps
        m.submodules.p = p = Parts(self.pbwid, epps, len(self.parts))
        m.d.comb += p.epps.eq(epps)
        # from here on the lane selects computed by the submodule are used
        parts = p.parts

        not_a_term, neg_lsb_a_term, not_b_term, neg_lsb_b_term = (
            self.not_a_term, self.neg_lsb_a_term,
            self.not_b_term, self.neg_lsb_b_term)

        byte_width = 8 // len(parts)  # byte width
        bit_wid = 8 * byte_width      # bit width
        nat, nbt, nla, nlb = [], [], [], []
        for i in range(len(parts)):
            # work out bit-inverted and +1 term for a.
            # (the correction on a is governed by the sign of *b*)
            pa = LSBNegTerm(bit_wid)
            setattr(m.submodules, "lnt_%d_a_%d" % (bit_wid, i), pa)
            m.d.comb += pa.part.eq(parts[i])
            m.d.comb += pa.op.eq(self.a.part(bit_wid * i, bit_wid))
            m.d.comb += pa.signed.eq(self.b_signed[i * byte_width])  # yes b
            m.d.comb += pa.msb.eq(self.b[(i + 1) * bit_wid - 1])  # really, b
            nat.append(pa.nt)
            nla.append(pa.nl)

            # work out bit-inverted and +1 term for b
            pb = LSBNegTerm(bit_wid)
            setattr(m.submodules, "lnt_%d_b_%d" % (bit_wid, i), pb)
            m.d.comb += pb.part.eq(parts[i])
            m.d.comb += pb.op.eq(self.b.part(bit_wid * i, bit_wid))
            m.d.comb += pb.signed.eq(self.a_signed[i * byte_width])  # yes a
            m.d.comb += pb.msb.eq(self.a[(i + 1) * bit_wid - 1])  # really, a
            nbt.append(pb.nt)
            nlb.append(pb.nl)

        # concatenate together and return all 4 results.
        m.d.comb += [not_a_term.eq(Cat(*nat)),
                     not_b_term.eq(Cat(*nbt)),
                     neg_lsb_a_term.eq(Cat(*nla)),
                     neg_lsb_b_term.eq(Cat(*nlb)),
                     ]

        return m
class IntermediateOut(Elaboratable):
    """ selects the HI/LO part of the multiplication, for a given bit-width
        the output is also reconstructed in its SIMD (partition) lanes.
    """

    def __init__(self, width, out_wid, n_parts):
        """Create an ``IntermediateOut``.

        :param width: bit-width of one output lane.
        :param out_wid: bit-width of ``intermed``; ``output`` is half that.
        :param n_parts: number of lanes.
        """
        self.width = width
        self.n_parts = n_parts
        self.part_ops = [Signal(2, name="dpop%d" % i, reset_less=True)
                         for i in range(8)]
        self.intermed = Signal(out_wid, reset_less=True)
        self.output = Signal(out_wid//2, reset_less=True)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()

        ol = []
        w = self.width
        sel = w // 8  # number of bytes (= part_ops entries) per lane
        for i in range(self.n_parts):
            op = Signal(w, reset_less=True, name="op%d_%d" % (w, i))
            # each lane of `intermed` is 2*w wide: LOW picks the bottom w
            # bits, every other operation picks the top w bits
            m.d.comb += op.eq(
                Mux(self.part_ops[sel * i] == OP_MUL_LOW,
                    self.intermed.part(i * w*2, w),
                    self.intermed.part(i * w*2 + w, w)))
            ol.append(op)
        m.d.comb += self.output.eq(Cat(*ol))

        return m
class FinalOut(Elaboratable):
    """ selects the final output based on the partitioning.

        each byte is selectable independently, i.e. it is possible
        that some partitions requested 8-bit computation whilst others
        requested 16 or 32 bit.
    """

    def __init__(self, out_wid):
        """Create a ``FinalOut``.

        :param out_wid: bit-width of the four candidate inputs and of
            ``out``.
        """
        # inputs
        self.d8 = [Signal(name=f"d8_{i}", reset_less=True) for i in range(8)]
        self.d16 = [Signal(name=f"d16_{i}", reset_less=True) for i in range(4)]
        self.d32 = [Signal(name=f"d32_{i}", reset_less=True) for i in range(2)]

        self.i8 = Signal(out_wid, reset_less=True)
        self.i16 = Signal(out_wid, reset_less=True)
        self.i32 = Signal(out_wid, reset_less=True)
        self.i64 = Signal(out_wid, reset_less=True)

        # output
        self.out = Signal(out_wid, reset_less=True)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        ol = []
        for i in range(8):
            # select one of the outputs: d8 selects i8, d16 selects i16
            # d32 selects i32, and the default is i64.
            # d8 and d16 are ORed together in the first Mux
            # then the 2nd selects either i8 or i16.
            # if neither d8 nor d16 are set, d32 selects either i32 or i64.
            op = Signal(8, reset_less=True, name="op_%d" % i)
            m.d.comb += op.eq(
                Mux(self.d8[i] | self.d16[i // 2],
                    Mux(self.d8[i], self.i8.part(i * 8, 8),
                        self.i16.part(i * 8, 8)),
                    Mux(self.d32[i // 4], self.i32.part(i * 8, 8),
                        self.i64.part(i * 8, 8))))
            ol.append(op)
        m.d.comb += self.out.eq(Cat(*ol))
        return m
class OrMod(Elaboratable):
    """ two-level OR tree that reduces four inputs to one output

        :attribute orin: list of the four ``wid``-bit inputs.
        :attribute orout: ``wid``-bit bitwise OR of all four inputs.
    """
    def __init__(self, wid):
        self.wid = wid
        self.orin = [Signal(wid, name="orin%d" % i, reset_less=True)
                     for i in range(4)]
        self.orout = Signal(wid, reset_less=True)

    def elaborate(self, platform):
        m = Module()

        # first tree level: one intermediate per pair of inputs.
        # the locals keep the names or1/or2 because the Signals take
        # their names from the variable they are assigned to.
        or1 = Signal(self.wid, reset_less=True)
        or2 = Signal(self.wid, reset_less=True)

        m.d.comb += [
            or1.eq(self.orin[0] | self.orin[1]),
            or2.eq(self.orin[2] | self.orin[3]),
            # second tree level: combine the two pair results
            self.orout.eq(or1 | or2),
        ]

        return m
class Signs(Elaboratable):
    """ derives the signedness of operands a and b from the
        requested operation code (OP_MUL_*)

        :attribute part_ops: 2-bit operation code (input).
        :attribute a_signed: set unless the operation is
            OP_MUL_UNSIGNED_HIGH.
        :attribute b_signed: set when the operation is OP_MUL_LOW or
            OP_MUL_SIGNED_HIGH.
    """

    def __init__(self):
        self.part_ops = Signal(2, reset_less=True)
        self.a_signed = Signal(reset_less=True)
        self.b_signed = Signal(reset_less=True)

    def elaborate(self, platform):
        m = Module()

        opcode = self.part_ops

        # a is unsigned for exactly one operation code
        a_is_signed = opcode != OP_MUL_UNSIGNED_HIGH
        # b is signed for exactly two operation codes
        b_is_signed = (opcode == OP_MUL_LOW) | (opcode == OP_MUL_SIGNED_HIGH)

        m.d.comb += [
            self.a_signed.eq(a_is_signed),
            self.b_signed.eq(b_is_signed),
        ]

        return m
1009 class Mul8_16_32_64(Elaboratable
):
1010 """Signed/Unsigned 8/16/32/64-bit partitioned integer multiplier.
1012 Supports partitioning into any combination of 8, 16, 32, and 64-bit
1013 partitions on naturally-aligned boundaries. Supports the operation being
1014 set for each partition independently.
1016 :attribute part_pts: the input partition points. Has a partition point at
1017 multiples of 8 in 0 < i < 64. Each partition point's associated
1018 ``Value`` is a ``Signal``. Modification not supported, except for by
1020 :attribute part_ops: the operation for each byte. The operation for a
1021 particular partition is selected by assigning the selected operation
1022 code to each byte in the partition. The allowed operation codes are:
1024 :attribute OP_MUL_LOW: the LSB half of the product. Equivalent to
1025 RISC-V's `mul` instruction.
1026 :attribute OP_MUL_SIGNED_HIGH: the MSB half of the product where both
1027 ``a`` and ``b`` are signed. Equivalent to RISC-V's `mulh`
1029 :attribute OP_MUL_SIGNED_UNSIGNED_HIGH: the MSB half of the product
1030 where ``a`` is signed and ``b`` is unsigned. Equivalent to RISC-V's
1031 `mulhsu` instruction.
1032 :attribute OP_MUL_UNSIGNED_HIGH: the MSB half of the product where both
1033 ``a`` and ``b`` are unsigned. Equivalent to RISC-V's `mulhu`
1037 def __init__(self
, register_levels
=()):
1038 """ register_levels: specifies the points in the cascade at which
1039 flip-flops are to be inserted.
# stored as a fresh list, so the caller's sequence is not aliased
1043 self
.register_levels
= list(register_levels
)
# one enable Signal per candidate partition point: 8, 16, ..., 56
1046 self
.part_pts
= PartitionPoints()
1047 for i
in range(8, 64, 8):
1048 self
.part_pts
[i
] = Signal(name
=f
"part_pts_{i}")
# one 2-bit operation-code Signal per byte lane (8 in total)
1049 self
.part_ops
= [Signal(2, name
=f
"part_ops_{i}") for i
in range(8)]
1053 # intermediates (needed for unit tests)
1054 self
._intermediate
_output
= Signal(128)
# final 64-bit result
1057 self
.output
= Signal(64)
1059 def elaborate(self
, platform
):
1062 # collect part-bytes
1063 pbs
= Signal(8, reset_less
=True)
# NOTE(review): the loop header and the definition of `tl` are not visible
# in this view; each pb is driven from part_pts.part_byte(i).
1066 pb
= Signal(name
="pb%d" % i
, reset_less
=True)
1067 m
.d
.comb
+= pb
.eq(self
.part_pts
.part_byte(i
))
1069 m
.d
.comb
+= pbs
.eq(Cat(*tl
))
1071 # create (doubled) PartitionPoints (output is double input width)
1072 expanded_part_pts
= eps
= PartitionPoints()
1073 for i
, v
in self
.part_pts
.items():
1074 ep
= Signal(name
=f
"expanded_part_pts_{i*2}", reset_less
=True)
1075 expanded_part_pts
[i
* 2] = ep
1076 m
.d
.comb
+= ep
.eq(v
)
# `s` is registered as submodule "signs<i>" and driven by the matching
# per-byte operation code.
# NOTE(review): the creation of `s` and of the `signs` list is not visible
# in this view.
1083 setattr(m
.submodules
, "signs%d" % i
, s
)
1084 m
.d
.comb
+= s
.part_ops
.eq(self
.part_ops
[i
])
# one more level than there are register levels
1086 n_levels
= len(self
.register_levels
)+1
# four Part submodules sharing the expanded partition points; they differ
# only in the third argument (8, 4, 2, 1)
1087 m
.submodules
.part_8
= part_8
= Part(eps
, 128, 8, n_levels
, 8)
1088 m
.submodules
.part_16
= part_16
= Part(eps
, 128, 4, n_levels
, 8)
1089 m
.submodules
.part_32
= part_32
= Part(eps
, 128, 2, n_levels
, 8)
1090 m
.submodules
.part_64
= part_64
= Part(eps
, 128, 1, n_levels
, 8)
# per-kind collections of the terms produced by the four Part submodules
1091 nat_l
, nbt_l
, nla_l
, nlb_l
= [], [], [], []
# every Part receives the same operands, per-byte sign flags and part-byte
# enables; its four terms are appended to the matching list
1092 for mod
in [part_8
, part_16
, part_32
, part_64
]:
1093 m
.d
.comb
+= mod
.a
.eq(self
.a
)
1094 m
.d
.comb
+= mod
.b
.eq(self
.b
)
1095 for i
in range(len(signs
)):
1096 m
.d
.comb
+= mod
.a_signed
[i
].eq(signs
[i
].a_signed
)
1097 m
.d
.comb
+= mod
.b_signed
[i
].eq(signs
[i
].b_signed
)
1098 m
.d
.comb
+= mod
.pbs
.eq(pbs
)
1099 nat_l
.append(mod
.not_a_term
)
1100 nbt_l
.append(mod
.not_b_term
)
1101 nla_l
.append(mod
.neg_lsb_a_term
)
1102 nlb_l
.append(mod
.neg_lsb_b_term
)
# one ProductTerms submodule ("terms_<a_index>") per a_index in 0..7, each
# fed with both operands and the part-byte enables
1106 for a_index
in range(8):
1107 t
= ProductTerms(8, 128, 8, a_index
, 8)
1108 setattr(m
.submodules
, "terms_%d" % a_index
, t
)
1110 m
.d
.comb
+= t
.a
.eq(self
.a
)
1111 m
.d
.comb
+= t
.b
.eq(self
.b
)
1112 m
.d
.comb
+= t
.pb_en
.eq(pbs
)
# NOTE(review): the body of this inner loop is not visible in this view
1114 for term
in t
.terms
:
1117 # it's fine to bitwise-or data together since they are never enabled
1119 m
.submodules
.nat_or
= nat_or
= OrMod(128)
1120 m
.submodules
.nbt_or
= nbt_or
= OrMod(128)
1121 m
.submodules
.nla_or
= nla_or
= OrMod(128)
1122 m
.submodules
.nlb_or
= nlb_or
= OrMod(128)
# pairs each collected list with its OrMod: the list entries drive the OrMod
# inputs and the OrMod output is appended to `terms`.
# NOTE(review): only the first (list, OrMod) pair is visible in this view.
1123 for l
, mod
in [(nat_l
, nat_or
),
1127 for i
in range(len(l
)):
1128 m
.d
.comb
+= mod
.orin
[i
].eq(l
[i
])
1129 terms
.append(mod
.orout
)
# NOTE(review): only some of the AddReduce arguments are visible in this view
1131 add_reduce
= AddReduce(terms
,
1133 self
.register_levels
,
# both values are taken from the last level of add_reduce
1137 out_part_ops
= add_reduce
.levels
[-1].out_part_ops
1138 out_part_pts
= add_reduce
.levels
[-1]._reg
_partition
_points
1140 m
.submodules
.add_reduce
= add_reduce
1141 m
.d
.comb
+= self
._intermediate
_output
.eq(add_reduce
.output
)
# IntermediateOut(64, 128, 1): fed the intermediate output and out_part_ops
1143 m
.submodules
.io64
= io64
= IntermediateOut(64, 128, 1)
1144 m
.d
.comb
+= io64
.intermed
.eq(self
._intermediate
_output
)
1146 m
.d
.comb
+= io64
.part_ops
[i
].eq(out_part_ops
[i
])
# IntermediateOut(32, 128, 2): same inputs as io64
1149 m
.submodules
.io32
= io32
= IntermediateOut(32, 128, 2)
1150 m
.d
.comb
+= io32
.intermed
.eq(self
._intermediate
_output
)
1152 m
.d
.comb
+= io32
.part_ops
[i
].eq(out_part_ops
[i
])
# IntermediateOut(16, 128, 4): same inputs as io64
1155 m
.submodules
.io16
= io16
= IntermediateOut(16, 128, 4)
1156 m
.d
.comb
+= io16
.intermed
.eq(self
._intermediate
_output
)
1158 m
.d
.comb
+= io16
.part_ops
[i
].eq(out_part_ops
[i
])
# IntermediateOut(8, 128, 8): same inputs as io64
1161 m
.submodules
.io8
= io8
= IntermediateOut(8, 128, 8)
1162 m
.d
.comb
+= io8
.intermed
.eq(self
._intermediate
_output
)
1164 m
.d
.comb
+= io8
.part_ops
[i
].eq(out_part_ops
[i
])
# one Parts submodule per width, sized from the matching Part's `parts` list
1166 m
.submodules
.p_8
= p_8
= Parts(8, eps
, len(part_8
.parts
))
1167 m
.submodules
.p_16
= p_16
= Parts(8, eps
, len(part_16
.parts
))
1168 m
.submodules
.p_32
= p_32
= Parts(8, eps
, len(part_32
.parts
))
1169 m
.submodules
.p_64
= p_64
= Parts(8, eps
, len(part_64
.parts
))
# all four Parts submodules read the partition points taken from add_reduce
1171 m
.d
.comb
+= p_8
.epps
.eq(out_part_pts
)
1172 m
.d
.comb
+= p_16
.epps
.eq(out_part_pts
)
1173 m
.d
.comb
+= p_32
.epps
.eq(out_part_pts
)
1174 m
.d
.comb
+= p_64
.epps
.eq(out_part_pts
)
# FinalOut(64): its d8/d16/d32 select flags come from p_8/p_16/p_32; p_64 is
# not connected to it on any line visible here
1177 m
.submodules
.finalout
= finalout
= FinalOut(64)
1178 for i
in range(len(part_8
.parts
)):
1179 m
.d
.comb
+= finalout
.d8
[i
].eq(p_8
.parts
[i
])
1180 for i
in range(len(part_16
.parts
)):
1181 m
.d
.comb
+= finalout
.d16
[i
].eq(p_16
.parts
[i
])
1182 for i
in range(len(part_32
.parts
)):
1183 m
.d
.comb
+= finalout
.d32
[i
].eq(p_32
.parts
[i
])
# the four candidate results, one per partition width
1184 m
.d
.comb
+= finalout
.i8
.eq(io8
.output
)
1185 m
.d
.comb
+= finalout
.i16
.eq(io16
.output
)
1186 m
.d
.comb
+= finalout
.i32
.eq(io32
.output
)
1187 m
.d
.comb
+= finalout
.i64
.eq(io64
.output
)
# the module output is whatever FinalOut assembled
1188 m
.d
.comb
+= self
.output
.eq(finalout
.out
)
# Runs only when the file is executed as a script.
# NOTE(review): only fragments of the body are visible in this view; they
# list m._intermediate_output and every value of m.part_pts as elements of a
# sequence - presumably the port list passed to the imported `main`; confirm
# against the full file.
1193 if __name__
== "__main__":
1197 m
._intermediate
_output
,
1200 *m
.part_pts
.values()])