1 # SPDX-License-Identifier: LGPL-2.1-or-later
2 # See Notices.txt for copyright information
3 """Integer Multiplication."""
5 from nmigen
import Signal
, Module
, Value
, Elaboratable
, Cat
, C
, Mux
, Repl
6 from nmigen
.hdl
.ast
import Assign
7 from abc
import ABCMeta
, abstractmethod
8 from nmigen
.cli
import main
9 from functools
import reduce
10 from operator
import or_
11 from ieee754
.pipeline
import PipelineSpec
12 from nmutil
.pipemodbase
import PipeModBase
14 from ieee754
.part_mul_add
.partpoints
import PartitionPoints
15 from ieee754
.part_mul_add
.adder
import PartitionedAdder
, MaskedFullAdder
18 FULL_ADDER_INPUT_COUNT
= 3
22 def __init__(self
, part_pts
, n_inputs
, output_width
, n_parts
):
23 self
.part_ops
= [Signal(2, name
=f
"part_ops_{i}", reset_less
=True)
24 for i
in range(n_parts
)]
25 self
.terms
= [Signal(output_width
, name
=f
"terms_{i}",
27 for i
in range(n_inputs
)]
28 self
.part_pts
= part_pts
.like()
def eq_from(self, part_pts, inputs, part_ops):
    """Build the assignments that load this record from separate values.

    The result is a plain list of ``.eq()`` statements, suitable for
    ``m.d.comb +=`` / ``m.d.sync +=``. Ordering: the partition-point
    assignment first, then one per term, then one per part-op.

    :param part_pts: source partition points, assigned to ``self.part_pts``.
    :param inputs: indexable source terms; only the first
        ``len(self.terms)`` entries are used.
    :param part_ops: indexable source op codes; only the first
        ``len(self.part_ops)`` entries are used.
    :returns: list of assignment statements.
    """
    stmts = [self.part_pts.eq(part_pts)]
    stmts.extend(term.eq(inputs[idx])
                 for idx, term in enumerate(self.terms))
    stmts.extend(op.eq(part_ops[idx])
                 for idx, op in enumerate(self.part_ops))
    return stmts
38 return self
.eq_from(rhs
.part_pts
, rhs
.terms
, rhs
.part_ops
)
41 class FinalReduceData
:
def __init__(self, part_pts, output_width, n_parts):
    """Declare the signals of the final-reduction record.

    :param part_pts: partition points; ``self.part_pts`` is set to
        ``part_pts.like()``.
    :param output_width: bit-width of the ``output`` signal.
    :param n_parts: number of 2-bit ``part_ops`` signals to create.
    """
    # one 2-bit operation code per part, named part_ops_0, part_ops_1, ...
    ops = []
    for i in range(n_parts):
        ops.append(Signal(2, name=f"part_ops_{i}", reset_less=True))
    self.part_ops = ops
    # the reduced (summed) result
    self.output = Signal(output_width, reset_less=True)
    self.part_pts = part_pts.like()
def eq_from(self, part_pts, output, part_ops):
    """Build the assignments that load this record from separate values.

    Ordering: partition points, then ``output``, then one assignment per
    entry of ``self.part_ops`` (in index order).

    :param part_pts: source partition points.
    :param output: source value for ``self.output``.
    :param part_ops: indexable source op codes; only the first
        ``len(self.part_ops)`` entries are used.
    :returns: list of assignment statements.
    """
    head = [self.part_pts.eq(part_pts), self.output.eq(output)]
    tail = [op.eq(part_ops[n]) for n, op in enumerate(self.part_ops)]
    return head + tail
56 return self
.eq_from(rhs
.part_pts
, rhs
.output
, rhs
.part_ops
)
59 class FinalAdd(PipeModBase
):
60 """ Final stage of add reduce
63 def __init__(self
, pspec
, lidx
, n_inputs
, partition_points
,
66 self
.partition_step
= partition_step
67 self
.output_width
= pspec
.width
* 2
68 self
.n_inputs
= n_inputs
69 self
.n_parts
= pspec
.n_parts
70 self
.partition_points
= PartitionPoints(partition_points
)
71 if not self
.partition_points
.fits_in_width(self
.output_width
):
72 raise ValueError("partition_points doesn't fit in output_width")
74 super().__init
__(pspec
, "finaladd")
77 return AddReduceData(self
.partition_points
, self
.n_inputs
,
78 self
.output_width
, self
.n_parts
)
81 return FinalReduceData(self
.partition_points
,
82 self
.output_width
, self
.n_parts
)
84 def elaborate(self
, platform
):
85 """Elaborate this module."""
88 output_width
= self
.output_width
89 output
= Signal(output_width
, reset_less
=True)
90 if self
.n_inputs
== 0:
91 # use 0 as the default output value
92 m
.d
.comb
+= output
.eq(0)
93 elif self
.n_inputs
== 1:
95 m
.d
.comb
+= output
.eq(self
.i
.terms
[0])
97 # base case for adding 2 inputs
98 assert self
.n_inputs
== 2
99 adder
= PartitionedAdder(output_width
,
100 self
.i
.part_pts
, self
.partition_step
)
101 m
.submodules
.final_adder
= adder
102 m
.d
.comb
+= adder
.a
.eq(self
.i
.terms
[0])
103 m
.d
.comb
+= adder
.b
.eq(self
.i
.terms
[1])
104 m
.d
.comb
+= output
.eq(adder
.output
)
107 m
.d
.comb
+= self
.o
.eq_from(self
.i
.part_pts
, output
,
113 class AddReduceSingle(PipeModBase
):
114 """Add list of numbers together.
116 :attribute inputs: input ``Signal``s to be summed. Modification not
117 supported, except for by ``Signal.eq``.
118 :attribute register_levels: List of nesting levels that should have
120 :attribute output: output sum.
121 :attribute partition_points: the input partition points. Modification not
122 supported, except for by ``Signal.eq``.
125 def __init__(self
, pspec
, lidx
, n_inputs
, partition_points
,
127 """Create an ``AddReduce``.
129 :param inputs: input ``Signal``s to be summed.
130 :param output_width: bit-width of ``output``.
131 :param partition_points: the input partition points.
134 self
.partition_step
= partition_step
135 self
.n_inputs
= n_inputs
136 self
.n_parts
= pspec
.n_parts
137 self
.output_width
= pspec
.width
* 2
138 self
.partition_points
= PartitionPoints(partition_points
)
139 if not self
.partition_points
.fits_in_width(self
.output_width
):
140 raise ValueError("partition_points doesn't fit in output_width")
142 self
.groups
= AddReduceSingle
.full_adder_groups(n_inputs
)
143 self
.n_terms
= AddReduceSingle
.calc_n_inputs(n_inputs
, self
.groups
)
145 super().__init
__(pspec
, "addreduce_%d" % lidx
)
148 return AddReduceData(self
.partition_points
, self
.n_inputs
,
149 self
.output_width
, self
.n_parts
)
152 return AddReduceData(self
.partition_points
, self
.n_terms
,
153 self
.output_width
, self
.n_parts
)
156 def calc_n_inputs(n_inputs
, groups
):
157 retval
= len(groups
)*2
158 if n_inputs
% FULL_ADDER_INPUT_COUNT
== 1:
160 elif n_inputs
% FULL_ADDER_INPUT_COUNT
== 2:
163 assert n_inputs
% FULL_ADDER_INPUT_COUNT
== 0
167 def get_max_level(input_count
):
168 """Get the maximum level.
170 All ``register_levels`` must be less than or equal to the maximum
175 groups
= AddReduceSingle
.full_adder_groups(input_count
)
178 input_count
%= FULL_ADDER_INPUT_COUNT
179 input_count
+= 2 * len(groups
)
183 def full_adder_groups(input_count
):
184 """Get ``inputs`` indices for which a full adder should be built."""
186 input_count
- FULL_ADDER_INPUT_COUNT
+ 1,
187 FULL_ADDER_INPUT_COUNT
)
189 def create_next_terms(self
):
190 """ create next intermediate terms, for linking up in elaborate, below
195 # create full adders for this recursive level.
196 # this shrinks N terms to 2 * (N // 3) plus the remainder
197 for i
in self
.groups
:
198 adder_i
= MaskedFullAdder(self
.output_width
)
199 adders
.append((i
, adder_i
))
200 # add both the sum and the masked-carry to the next level.
201 # 3 inputs have now been reduced to 2...
202 terms
.append(adder_i
.sum)
203 terms
.append(adder_i
.mcarry
)
204 # handle the remaining inputs.
205 if self
.n_inputs
% FULL_ADDER_INPUT_COUNT
== 1:
206 terms
.append(self
.i
.terms
[-1])
207 elif self
.n_inputs
% FULL_ADDER_INPUT_COUNT
== 2:
208 # Just pass the terms to the next layer, since we wouldn't gain
209 # anything by using a half adder since there would still be 2 terms
210 # and just passing the terms to the next layer saves gates.
211 terms
.append(self
.i
.terms
[-2])
212 terms
.append(self
.i
.terms
[-1])
214 assert self
.n_inputs
% FULL_ADDER_INPUT_COUNT
== 0
218 def elaborate(self
, platform
):
219 """Elaborate this module."""
222 terms
, adders
= self
.create_next_terms()
224 # copy the intermediate terms to the output
225 for i
, value
in enumerate(terms
):
226 m
.d
.comb
+= self
.o
.terms
[i
].eq(value
)
228 # copy reg part points and part ops to output
229 m
.d
.comb
+= self
.o
.part_pts
.eq(self
.i
.part_pts
)
230 m
.d
.comb
+= [self
.o
.part_ops
[i
].eq(self
.i
.part_ops
[i
])
231 for i
in range(len(self
.i
.part_ops
))]
233 # set up the partition mask (for the adders)
234 part_mask
= Signal(self
.output_width
, reset_less
=True)
236 # get partition points as a mask
237 mask
= self
.i
.part_pts
.as_mask(self
.output_width
,
238 mul
=self
.partition_step
)
239 m
.d
.comb
+= part_mask
.eq(mask
)
241 # add and link the intermediate term modules
242 for i
, (iidx
, adder_i
) in enumerate(adders
):
243 setattr(m
.submodules
, f
"adder_{i}", adder_i
)
245 m
.d
.comb
+= adder_i
.in0
.eq(self
.i
.terms
[iidx
])
246 m
.d
.comb
+= adder_i
.in1
.eq(self
.i
.terms
[iidx
+ 1])
247 m
.d
.comb
+= adder_i
.in2
.eq(self
.i
.terms
[iidx
+ 2])
248 m
.d
.comb
+= adder_i
.mask
.eq(part_mask
)
253 class AddReduceInternal
:
254 """Iteratively Add list of numbers together.
256 :attribute inputs: input ``Signal``s to be summed. Modification not
257 supported, except for by ``Signal.eq``.
258 :attribute register_levels: List of nesting levels that should have
260 :attribute output: output sum.
261 :attribute partition_points: the input partition points. Modification not
262 supported, except for by ``Signal.eq``.
265 def __init__(self
, pspec
, n_inputs
, part_pts
, partition_step
=1):
266 """Create an ``AddReduce``.
268 :param inputs: input ``Signal``s to be summed.
269 :param output_width: bit-width of ``output``.
270 :param partition_points: the input partition points.
273 self
.n_inputs
= n_inputs
274 self
.output_width
= pspec
.width
* 2
275 self
.partition_points
= part_pts
276 self
.partition_step
= partition_step
280 def create_levels(self
):
281 """creates reduction levels"""
284 partition_points
= self
.partition_points
287 groups
= AddReduceSingle
.full_adder_groups(ilen
)
291 next_level
= AddReduceSingle(self
.pspec
, lidx
, ilen
,
294 mods
.append(next_level
)
295 partition_points
= next_level
.i
.part_pts
296 ilen
= len(next_level
.o
.terms
)
299 next_level
= FinalAdd(self
.pspec
, lidx
, ilen
,
300 partition_points
, self
.partition_step
)
301 mods
.append(next_level
)
306 class AddReduce(AddReduceInternal
, Elaboratable
):
307 """Recursively Add list of numbers together.
309 :attribute inputs: input ``Signal``s to be summed. Modification not
310 supported, except for by ``Signal.eq``.
311 :attribute register_levels: List of nesting levels that should have
313 :attribute output: output sum.
314 :attribute partition_points: the input partition points. Modification not
315 supported, except for by ``Signal.eq``.
318 def __init__(self
, inputs
, output_width
, register_levels
, part_pts
,
319 part_ops
, partition_step
=1):
320 """Create an ``AddReduce``.
322 :param inputs: input ``Signal``s to be summed.
323 :param output_width: bit-width of ``output``.
324 :param register_levels: List of nesting levels that should have
326 :param partition_points: the input partition points.
328 self
._inputs
= inputs
329 self
._part
_pts
= part_pts
330 self
._part
_ops
= part_ops
331 n_parts
= len(part_ops
)
332 self
.i
= AddReduceData(part_pts
, len(inputs
),
333 output_width
, n_parts
)
334 AddReduceInternal
.__init
__(self
, pspec
, n_inputs
, part_pts
,
336 self
.o
= FinalReduceData(part_pts
, output_width
, n_parts
)
337 self
.register_levels
= register_levels
340 def get_max_level(input_count
):
341 return AddReduceSingle
.get_max_level(input_count
)
344 def next_register_levels(register_levels
):
345 """``Iterable`` of ``register_levels`` for next recursive level."""
346 for level
in register_levels
:
350 def elaborate(self
, platform
):
351 """Elaborate this module."""
354 m
.d
.comb
+= self
.i
.eq_from(self
._part
_pts
, self
._inputs
, self
._part
_ops
)
356 for i
, next_level
in enumerate(self
.levels
):
357 setattr(m
.submodules
, "next_level%d" % i
, next_level
)
360 for idx
in range(len(self
.levels
)):
361 mcur
= self
.levels
[idx
]
362 if idx
in self
.register_levels
:
363 m
.d
.sync
+= mcur
.i
.eq(i
)
365 m
.d
.comb
+= mcur
.i
.eq(i
)
366 i
= mcur
.o
# for next loop
368 # output comes from last module
369 m
.d
.comb
+= self
.o
.eq(i
)
375 OP_MUL_SIGNED_HIGH
= 1
376 OP_MUL_SIGNED_UNSIGNED_HIGH
= 2 # a is signed, b is unsigned
377 OP_MUL_UNSIGNED_HIGH
= 3
380 def get_term(value
, shift
=0, enabled
=None):
381 if enabled
is not None:
382 value
= Mux(enabled
, value
, 0)
384 value
= Cat(Repl(C(0, 1), shift
), value
)
390 class ProductTerm(Elaboratable
):
391 """ this class creates a single product term (a[..]*b[..]).
392 it has a design flaw in that is the *output* that is selected,
393 where the multiplication(s) are combinatorially generated
397 def __init__(self
, width
, twidth
, pbwid
, a_index
, b_index
):
398 self
.a_index
= a_index
399 self
.b_index
= b_index
400 shift
= 8 * (self
.a_index
+ self
.b_index
)
406 self
.ti
= Signal(self
.width
, reset_less
=True)
407 self
.term
= Signal(twidth
, reset_less
=True)
408 self
.a
= Signal(twidth
//2, reset_less
=True)
409 self
.b
= Signal(twidth
//2, reset_less
=True)
410 self
.pb_en
= Signal(pbwid
, reset_less
=True)
413 min_index
= min(self
.a_index
, self
.b_index
)
414 max_index
= max(self
.a_index
, self
.b_index
)
415 for i
in range(min_index
, max_index
):
416 tl
.append(self
.pb_en
[i
])
417 name
= "te_%d_%d" % (self
.a_index
, self
.b_index
)
419 term_enabled
= Signal(name
=name
, reset_less
=True)
422 self
.enabled
= term_enabled
423 self
.term
.name
= "term_%d_%d" % (a_index
, b_index
) # rename
425 def elaborate(self
, platform
):
428 if self
.enabled
is not None:
429 m
.d
.comb
+= self
.enabled
.eq(~
(Cat(*self
.tl
).bool()))
431 bsa
= Signal(self
.width
, reset_less
=True)
432 bsb
= Signal(self
.width
, reset_less
=True)
433 a_index
, b_index
= self
.a_index
, self
.b_index
435 m
.d
.comb
+= bsa
.eq(self
.a
.bit_select(a_index
* pwidth
, pwidth
))
436 m
.d
.comb
+= bsb
.eq(self
.b
.bit_select(b_index
* pwidth
, pwidth
))
437 m
.d
.comb
+= self
.ti
.eq(bsa
* bsb
)
438 m
.d
.comb
+= self
.term
.eq(get_term(self
.ti
, self
.shift
, self
.enabled
))
440 #TODO: sort out width issues, get inputs a/b switched on/off.
441 #data going into Muxes is 1/2 the required width
445 bsa = Signal(self.twidth//2, reset_less=True)
446 bsb = Signal(self.twidth//2, reset_less=True)
447 asel = Signal(width, reset_less=True)
448 bsel = Signal(width, reset_less=True)
449 a_index, b_index = self.a_index, self.b_index
450 m.d.comb += asel.eq(self.a.bit_select(a_index * pwidth, pwidth))
451 m.d.comb += bsel.eq(self.b.bit_select(b_index * pwidth, pwidth))
452 m.d.comb += bsa.eq(get_term(asel, self.shift, self.enabled))
453 m.d.comb += bsb.eq(get_term(bsel, self.shift, self.enabled))
454 m.d.comb += self.ti.eq(bsa * bsb)
455 m.d.comb += self.term.eq(self.ti)
461 class ProductTerms(Elaboratable
):
462 """ creates a bank of product terms. also performs the actual bit-selection
463 this class is to be wrapped with a for-loop on the "a" operand.
464 it creates a second-level for-loop on the "b" operand.
466 def __init__(self
, width
, twidth
, pbwid
, a_index
, blen
):
467 self
.a_index
= a_index
472 self
.a
= Signal(twidth
//2, reset_less
=True)
473 self
.b
= Signal(twidth
//2, reset_less
=True)
474 self
.pb_en
= Signal(pbwid
, reset_less
=True)
475 self
.terms
= [Signal(twidth
, name
="term%d"%i, reset_less
=True) \
476 for i
in range(blen
)]
478 def elaborate(self
, platform
):
482 for b_index
in range(self
.blen
):
483 t
= ProductTerm(self
.pwidth
, self
.twidth
, self
.pbwid
,
484 self
.a_index
, b_index
)
485 setattr(m
.submodules
, "term_%d" % b_index
, t
)
487 m
.d
.comb
+= t
.a
.eq(self
.a
)
488 m
.d
.comb
+= t
.b
.eq(self
.b
)
489 m
.d
.comb
+= t
.pb_en
.eq(self
.pb_en
)
491 m
.d
.comb
+= self
.terms
[b_index
].eq(t
.term
)
496 class LSBNegTerm(Elaboratable
):
def __init__(self, bit_width):
    """Declare the ports of one LSB negation-term generator.

    :param bit_width: width of the operand slice ``op``; the two output
        terms ``nt`` and ``nl`` are twice this width.
    """
    self.bit_width = bit_width
    # single-bit control inputs; elaborate() ANDs part, msb and signed
    # together to decide whether the terms are produced
    self.part = Signal(reset_less=True)
    self.signed = Signal(reset_less=True)
    # operand slice whose bitwise inverse goes into ``nt``
    self.op = Signal(bit_width, reset_less=True)
    # MSB input (Part drives this from the opposite operand)
    self.msb = Signal(reset_less=True)
    # double-width outputs: inverted term and the "+1" correction term
    out_width = bit_width * 2
    self.nt = Signal(out_width, reset_less=True)
    self.nl = Signal(out_width, reset_less=True)
507 def elaborate(self
, platform
):
510 bit_wid
= self
.bit_width
511 ext
= Repl(0, bit_wid
) # extend output to HI part
513 # determine sign of each incoming number *in this partition*
514 enabled
= Signal(reset_less
=True)
515 m
.d
.comb
+= enabled
.eq(self
.part
& self
.msb
& self
.signed
)
517 # for 8-bit values: form a * 0xFF00 by using -a * 0x100, the
518 # negation operation is split into a bitwise not and a +1.
519 # likewise for 16, 32, and 64-bit values.
521 # width-extended 1s complement if a is signed, otherwise zero
522 comb
+= self
.nt
.eq(Mux(enabled
, Cat(ext
, ~self
.op
), 0))
524 # add 1 if signed, otherwise add zero
525 comb
+= self
.nl
.eq(Cat(ext
, enabled
, Repl(0, bit_wid
-1)))
530 class Parts(Elaboratable
):
532 def __init__(self
, pbwid
, part_pts
, n_parts
):
535 self
.part_pts
= PartitionPoints
.like(part_pts
)
537 self
.parts
= [Signal(name
=f
"part_{i}", reset_less
=True)
538 for i
in range(n_parts
)]
540 def elaborate(self
, platform
):
543 part_pts
, parts
= self
.part_pts
, self
.parts
544 # collect part-bytes (double factor because the input is extended)
545 pbs
= Signal(self
.pbwid
, reset_less
=True)
547 for i
in range(self
.pbwid
):
548 pb
= Signal(name
="pb%d" % i
, reset_less
=True)
549 m
.d
.comb
+= pb
.eq(part_pts
.part_byte(i
))
551 m
.d
.comb
+= pbs
.eq(Cat(*tl
))
553 # negated-temporary copy of partition bits
554 npbs
= Signal
.like(pbs
, reset_less
=True)
555 m
.d
.comb
+= npbs
.eq(~pbs
)
556 byte_count
= 8 // len(parts
)
557 for i
in range(len(parts
)):
559 pbl
.append(npbs
[i
* byte_count
- 1])
560 for j
in range(i
* byte_count
, (i
+ 1) * byte_count
- 1):
562 pbl
.append(npbs
[(i
+ 1) * byte_count
- 1])
563 value
= Signal(len(pbl
), name
="value_%d" % i
, reset_less
=True)
564 m
.d
.comb
+= value
.eq(Cat(*pbl
))
565 m
.d
.comb
+= parts
[i
].eq(~
(value
).bool())
570 class Part(Elaboratable
):
571 """ a key class which, depending on the partitioning, will determine
572 what action to take when parts of the output are signed or unsigned.
574 this requires 2 pieces of data *per operand, per partition*:
575 whether the MSB is HI/LO (per partition!), and whether a signed
576 or unsigned operation has been *requested*.
578 once that is determined, signed is basically carried out
579 by splitting 2's complement into 1's complement plus one.
580 1's complement is just a bit-inversion.
582 the extra terms - as separate terms - are then thrown at the
583 AddReduce alongside the multiplication part-results.
585 def __init__(self
, part_pts
, width
, n_parts
, pbwid
):
588 self
.part_pts
= part_pts
591 self
.a
= Signal(64, reset_less
=True)
592 self
.b
= Signal(64, reset_less
=True)
593 self
.a_signed
= [Signal(name
=f
"a_signed_{i}", reset_less
=True)
595 self
.b_signed
= [Signal(name
=f
"_b_signed_{i}", reset_less
=True)
597 self
.pbs
= Signal(pbwid
, reset_less
=True)
600 self
.parts
= [Signal(name
=f
"part_{i}", reset_less
=True)
601 for i
in range(n_parts
)]
603 self
.not_a_term
= Signal(width
, reset_less
=True)
604 self
.neg_lsb_a_term
= Signal(width
, reset_less
=True)
605 self
.not_b_term
= Signal(width
, reset_less
=True)
606 self
.neg_lsb_b_term
= Signal(width
, reset_less
=True)
608 def elaborate(self
, platform
):
611 pbs
, parts
= self
.pbs
, self
.parts
612 part_pts
= self
.part_pts
613 m
.submodules
.p
= p
= Parts(self
.pbwid
, part_pts
, len(parts
))
614 m
.d
.comb
+= p
.part_pts
.eq(part_pts
)
617 byte_count
= 8 // len(parts
)
619 not_a_term
, neg_lsb_a_term
, not_b_term
, neg_lsb_b_term
= (
620 self
.not_a_term
, self
.neg_lsb_a_term
,
621 self
.not_b_term
, self
.neg_lsb_b_term
)
623 byte_width
= 8 // len(parts
) # byte width
624 bit_wid
= 8 * byte_width
# bit width
625 nat
, nbt
, nla
, nlb
= [], [], [], []
626 for i
in range(len(parts
)):
627 # work out bit-inverted and +1 term for a.
628 pa
= LSBNegTerm(bit_wid
)
629 setattr(m
.submodules
, "lnt_%d_a_%d" % (bit_wid
, i
), pa
)
630 m
.d
.comb
+= pa
.part
.eq(parts
[i
])
631 m
.d
.comb
+= pa
.op
.eq(self
.a
.bit_select(bit_wid
* i
, bit_wid
))
632 m
.d
.comb
+= pa
.signed
.eq(self
.b_signed
[i
* byte_width
]) # yes b
633 m
.d
.comb
+= pa
.msb
.eq(self
.b
[(i
+ 1) * bit_wid
- 1]) # really, b
637 # work out bit-inverted and +1 term for b
638 pb
= LSBNegTerm(bit_wid
)
639 setattr(m
.submodules
, "lnt_%d_b_%d" % (bit_wid
, i
), pb
)
640 m
.d
.comb
+= pb
.part
.eq(parts
[i
])
641 m
.d
.comb
+= pb
.op
.eq(self
.b
.bit_select(bit_wid
* i
, bit_wid
))
642 m
.d
.comb
+= pb
.signed
.eq(self
.a_signed
[i
* byte_width
]) # yes a
643 m
.d
.comb
+= pb
.msb
.eq(self
.a
[(i
+ 1) * bit_wid
- 1]) # really, a
647 # concatenate together and return all 4 results.
648 m
.d
.comb
+= [not_a_term
.eq(Cat(*nat
)),
649 not_b_term
.eq(Cat(*nbt
)),
650 neg_lsb_a_term
.eq(Cat(*nla
)),
651 neg_lsb_b_term
.eq(Cat(*nlb
)),
657 class IntermediateOut(Elaboratable
):
658 """ selects the HI/LO part of the multiplication, for a given bit-width
659 the output is also reconstructed in its SIMD (partition) lanes.
661 def __init__(self
, width
, out_wid
, n_parts
):
663 self
.n_parts
= n_parts
664 self
.part_ops
= [Signal(2, name
="dpop%d" % i
, reset_less
=True)
666 self
.intermed
= Signal(out_wid
, reset_less
=True)
667 self
.output
= Signal(out_wid
//2, reset_less
=True)
669 def elaborate(self
, platform
):
675 for i
in range(self
.n_parts
):
676 op
= Signal(w
, reset_less
=True, name
="op%d_%d" % (w
, i
))
678 Mux(self
.part_ops
[sel
* i
] == OP_MUL_LOW
,
679 self
.intermed
.bit_select(i
* w
*2, w
),
680 self
.intermed
.bit_select(i
* w
*2 + w
, w
)))
682 m
.d
.comb
+= self
.output
.eq(Cat(*ol
))
687 class FinalOut(PipeModBase
):
688 """ selects the final output based on the partitioning.
690 each byte is selectable independently, i.e. it is possible
691 that some partitions requested 8-bit computation whilst others
692 requested 16 or 32 bit.
694 def __init__(self
, pspec
, part_pts
):
696 self
.part_pts
= part_pts
697 self
.output_width
= pspec
.width
* 2
698 self
.n_parts
= pspec
.n_parts
699 self
.out_wid
= pspec
.width
701 super().__init
__(pspec
, "finalout")
704 return IntermediateData(self
.part_pts
, self
.output_width
, self
.n_parts
)
709 def elaborate(self
, platform
):
712 part_pts
= self
.part_pts
713 m
.submodules
.p_8
= p_8
= Parts(8, part_pts
, 8)
714 m
.submodules
.p_16
= p_16
= Parts(8, part_pts
, 4)
715 m
.submodules
.p_32
= p_32
= Parts(8, part_pts
, 2)
716 m
.submodules
.p_64
= p_64
= Parts(8, part_pts
, 1)
718 out_part_pts
= self
.i
.part_pts
721 d8
= [Signal(name
=f
"d8_{i}", reset_less
=True) for i
in range(8)]
722 d16
= [Signal(name
=f
"d16_{i}", reset_less
=True) for i
in range(4)]
723 d32
= [Signal(name
=f
"d32_{i}", reset_less
=True) for i
in range(2)]
725 i8
= Signal(self
.out_wid
, reset_less
=True)
726 i16
= Signal(self
.out_wid
, reset_less
=True)
727 i32
= Signal(self
.out_wid
, reset_less
=True)
728 i64
= Signal(self
.out_wid
, reset_less
=True)
730 m
.d
.comb
+= p_8
.part_pts
.eq(out_part_pts
)
731 m
.d
.comb
+= p_16
.part_pts
.eq(out_part_pts
)
732 m
.d
.comb
+= p_32
.part_pts
.eq(out_part_pts
)
733 m
.d
.comb
+= p_64
.part_pts
.eq(out_part_pts
)
735 for i
in range(len(p_8
.parts
)):
736 m
.d
.comb
+= d8
[i
].eq(p_8
.parts
[i
])
737 for i
in range(len(p_16
.parts
)):
738 m
.d
.comb
+= d16
[i
].eq(p_16
.parts
[i
])
739 for i
in range(len(p_32
.parts
)):
740 m
.d
.comb
+= d32
[i
].eq(p_32
.parts
[i
])
741 m
.d
.comb
+= i8
.eq(self
.i
.outputs
[0])
742 m
.d
.comb
+= i16
.eq(self
.i
.outputs
[1])
743 m
.d
.comb
+= i32
.eq(self
.i
.outputs
[2])
744 m
.d
.comb
+= i64
.eq(self
.i
.outputs
[3])
748 # select one of the outputs: d8 selects i8, d16 selects i16
749 # d32 selects i32, and the default is i64.
750 # d8 and d16 are ORed together in the first Mux
751 # then the 2nd selects either i8 or i16.
752 # if neither d8 nor d16 are set, d32 selects either i32 or i64.
753 op
= Signal(8, reset_less
=True, name
="op_%d" % i
)
755 Mux(d8
[i
] | d16
[i
// 2],
756 Mux(d8
[i
], i8
.bit_select(i
* 8, 8),
757 i16
.bit_select(i
* 8, 8)),
758 Mux(d32
[i
// 4], i32
.bit_select(i
* 8, 8),
759 i64
.bit_select(i
* 8, 8))))
763 m
.d
.comb
+= self
.o
.output
.eq(Cat(*ol
))
764 m
.d
.comb
+= self
.o
.intermediate_output
.eq(self
.i
.intermediate_output
)
769 class OrMod(Elaboratable
):
770 """ ORs four values together in a hierarchical tree
772 def __init__(self
, wid
):
774 self
.orin
= [Signal(wid
, name
="orin%d" % i
, reset_less
=True)
776 self
.orout
= Signal(wid
, reset_less
=True)
778 def elaborate(self
, platform
):
780 or1
= Signal(self
.wid
, reset_less
=True)
781 or2
= Signal(self
.wid
, reset_less
=True)
782 m
.d
.comb
+= or1
.eq(self
.orin
[0] | self
.orin
[1])
783 m
.d
.comb
+= or2
.eq(self
.orin
[2] | self
.orin
[3])
784 m
.d
.comb
+= self
.orout
.eq(or1 | or2
)
789 class Signs(Elaboratable
):
790 """ determines whether a or b are signed numbers
791 based on the required operation type (OP_MUL_*)
795 self
.part_ops
= Signal(2, reset_less
=True)
796 self
.a_signed
= Signal(reset_less
=True)
797 self
.b_signed
= Signal(reset_less
=True)
799 def elaborate(self
, platform
):
803 asig
= self
.part_ops
!= OP_MUL_UNSIGNED_HIGH
804 bsig
= (self
.part_ops
== OP_MUL_LOW
) \
805 |
(self
.part_ops
== OP_MUL_SIGNED_HIGH
)
806 m
.d
.comb
+= self
.a_signed
.eq(asig
)
807 m
.d
.comb
+= self
.b_signed
.eq(bsig
)
812 class IntermediateData
:
814 def __init__(self
, part_pts
, output_width
, n_parts
):
815 self
.part_ops
= [Signal(2, name
=f
"part_ops_{i}", reset_less
=True)
816 for i
in range(n_parts
)]
817 self
.part_pts
= part_pts
.like()
818 self
.outputs
= [Signal(output_width
, name
="io%d" % i
, reset_less
=True)
820 # intermediates (needed for unit tests)
821 self
.intermediate_output
= Signal(output_width
)
823 def eq_from(self
, part_pts
, outputs
, intermediate_output
,
825 return [self
.part_pts
.eq(part_pts
)] + \
826 [self
.intermediate_output
.eq(intermediate_output
)] + \
827 [self
.outputs
[i
].eq(outputs
[i
])
828 for i
in range(4)] + \
829 [self
.part_ops
[i
].eq(part_ops
[i
])
830 for i
in range(len(self
.part_ops
))]
833 return self
.eq_from(rhs
.part_pts
, rhs
.outputs
,
834 rhs
.intermediate_output
, rhs
.part_ops
)
842 self
.part_pts
= PartitionPoints()
843 for i
in range(8, 64, 8):
844 self
.part_pts
[i
] = Signal(name
=f
"part_pts_{i}")
845 self
.part_ops
= [Signal(2, name
=f
"part_ops_{i}") for i
in range(8)]
def eq_from(self, part_pts, a, b, part_ops):
    """Build the assignments that load the multiplier inputs.

    Ordering: partition points, ``a``, ``b``, then one assignment per
    entry of ``self.part_ops`` (in index order).

    :param part_pts: source partition points.
    :param a: source value for operand ``self.a``.
    :param b: source value for operand ``self.b``.
    :param part_ops: indexable source op codes; only the first
        ``len(self.part_ops)`` entries are used.
    :returns: list of assignment statements.
    """
    result = [self.part_pts.eq(part_pts), self.a.eq(a), self.b.eq(b)]
    for pos, op in enumerate(self.part_ops):
        result.append(op.eq(part_ops[pos]))
    return result
854 return self
.eq_from(rhs
.part_pts
, rhs
.a
, rhs
.b
, rhs
.part_ops
)
860 self
.intermediate_output
= Signal(128) # needed for unit tests
861 self
.output
= Signal(64)
864 return [self
.intermediate_output
.eq(rhs
.intermediate_output
),
865 self
.output
.eq(rhs
.output
)]
868 class AllTerms(PipeModBase
):
869 """Set of terms to be added together
872 def __init__(self
, pspec
, n_inputs
):
873 """Create an ``AllTerms``.
875 self
.n_inputs
= n_inputs
876 self
.n_parts
= pspec
.n_parts
877 self
.output_width
= pspec
.width
* 2
878 super().__init
__(pspec
, "allterms")
884 return AddReduceData(self
.i
.part_pts
, self
.n_inputs
,
885 self
.output_width
, self
.n_parts
)
887 def elaborate(self
, platform
):
890 eps
= self
.i
.part_pts
893 pbs
= Signal(8, reset_less
=True)
896 pb
= Signal(name
="pb%d" % i
, reset_less
=True)
897 m
.d
.comb
+= pb
.eq(eps
.part_byte(i
))
899 m
.d
.comb
+= pbs
.eq(Cat(*tl
))
906 setattr(m
.submodules
, "signs%d" % i
, s
)
907 m
.d
.comb
+= s
.part_ops
.eq(self
.i
.part_ops
[i
])
909 m
.submodules
.part_8
= part_8
= Part(eps
, 128, 8, 8)
910 m
.submodules
.part_16
= part_16
= Part(eps
, 128, 4, 8)
911 m
.submodules
.part_32
= part_32
= Part(eps
, 128, 2, 8)
912 m
.submodules
.part_64
= part_64
= Part(eps
, 128, 1, 8)
913 nat_l
, nbt_l
, nla_l
, nlb_l
= [], [], [], []
914 for mod
in [part_8
, part_16
, part_32
, part_64
]:
915 m
.d
.comb
+= mod
.a
.eq(self
.i
.a
)
916 m
.d
.comb
+= mod
.b
.eq(self
.i
.b
)
917 for i
in range(len(signs
)):
918 m
.d
.comb
+= mod
.a_signed
[i
].eq(signs
[i
].a_signed
)
919 m
.d
.comb
+= mod
.b_signed
[i
].eq(signs
[i
].b_signed
)
920 m
.d
.comb
+= mod
.pbs
.eq(pbs
)
921 nat_l
.append(mod
.not_a_term
)
922 nbt_l
.append(mod
.not_b_term
)
923 nla_l
.append(mod
.neg_lsb_a_term
)
924 nlb_l
.append(mod
.neg_lsb_b_term
)
928 for a_index
in range(8):
929 t
= ProductTerms(8, 128, 8, a_index
, 8)
930 setattr(m
.submodules
, "terms_%d" % a_index
, t
)
932 m
.d
.comb
+= t
.a
.eq(self
.i
.a
)
933 m
.d
.comb
+= t
.b
.eq(self
.i
.b
)
934 m
.d
.comb
+= t
.pb_en
.eq(pbs
)
939 # it's fine to bitwise-or data together since they are never enabled
941 m
.submodules
.nat_or
= nat_or
= OrMod(128)
942 m
.submodules
.nbt_or
= nbt_or
= OrMod(128)
943 m
.submodules
.nla_or
= nla_or
= OrMod(128)
944 m
.submodules
.nlb_or
= nlb_or
= OrMod(128)
945 for l
, mod
in [(nat_l
, nat_or
),
949 for i
in range(len(l
)):
950 m
.d
.comb
+= mod
.orin
[i
].eq(l
[i
])
951 terms
.append(mod
.orout
)
953 # copy the intermediate terms to the output
954 for i
, value
in enumerate(terms
):
955 m
.d
.comb
+= self
.o
.terms
[i
].eq(value
)
957 # copy reg part points and part ops to output
958 m
.d
.comb
+= self
.o
.part_pts
.eq(eps
)
959 m
.d
.comb
+= [self
.o
.part_ops
[i
].eq(self
.i
.part_ops
[i
])
960 for i
in range(len(self
.i
.part_ops
))]
965 class Intermediates(PipeModBase
):
966 """ Intermediate output modules
969 def __init__(self
, pspec
, part_pts
):
970 self
.part_pts
= part_pts
971 self
.output_width
= pspec
.width
* 2
972 self
.n_parts
= pspec
.n_parts
974 super().__init
__(pspec
, "intermediates")
977 return FinalReduceData(self
.part_pts
, self
.output_width
, self
.n_parts
)
980 return IntermediateData(self
.part_pts
, self
.output_width
, self
.n_parts
)
982 def elaborate(self
, platform
):
985 out_part_ops
= self
.i
.part_ops
986 out_part_pts
= self
.i
.part_pts
989 m
.submodules
.io64
= io64
= IntermediateOut(64, 128, 1)
990 m
.d
.comb
+= io64
.intermed
.eq(self
.i
.output
)
992 m
.d
.comb
+= io64
.part_ops
[i
].eq(out_part_ops
[i
])
993 m
.d
.comb
+= self
.o
.outputs
[3].eq(io64
.output
)
996 m
.submodules
.io32
= io32
= IntermediateOut(32, 128, 2)
997 m
.d
.comb
+= io32
.intermed
.eq(self
.i
.output
)
999 m
.d
.comb
+= io32
.part_ops
[i
].eq(out_part_ops
[i
])
1000 m
.d
.comb
+= self
.o
.outputs
[2].eq(io32
.output
)
1003 m
.submodules
.io16
= io16
= IntermediateOut(16, 128, 4)
1004 m
.d
.comb
+= io16
.intermed
.eq(self
.i
.output
)
1006 m
.d
.comb
+= io16
.part_ops
[i
].eq(out_part_ops
[i
])
1007 m
.d
.comb
+= self
.o
.outputs
[1].eq(io16
.output
)
1010 m
.submodules
.io8
= io8
= IntermediateOut(8, 128, 8)
1011 m
.d
.comb
+= io8
.intermed
.eq(self
.i
.output
)
1013 m
.d
.comb
+= io8
.part_ops
[i
].eq(out_part_ops
[i
])
1014 m
.d
.comb
+= self
.o
.outputs
[0].eq(io8
.output
)
1017 m
.d
.comb
+= self
.o
.part_ops
[i
].eq(out_part_ops
[i
])
1018 m
.d
.comb
+= self
.o
.part_pts
.eq(out_part_pts
)
1019 m
.d
.comb
+= self
.o
.intermediate_output
.eq(self
.i
.output
)
1024 class Mul8_16_32_64(Elaboratable
):
1025 """Signed/Unsigned 8/16/32/64-bit partitioned integer multiplier.
1027 XXX NOTE: this class is intended for unit test purposes ONLY.
1029 Supports partitioning into any combination of 8, 16, 32, and 64-bit
1030 partitions on naturally-aligned boundaries. Supports the operation being
1031 set for each partition independently.
1033 :attribute part_pts: the input partition points. Has a partition point at
1034 multiples of 8 in 0 < i < 64. Each partition point's associated
1035 ``Value`` is a ``Signal``. Modification not supported, except for by
1037 :attribute part_ops: the operation for each byte. The operation for a
1038 particular partition is selected by assigning the selected operation
1039 code to each byte in the partition. The allowed operation codes are:
1041 :attribute OP_MUL_LOW: the LSB half of the product. Equivalent to
1042 RISC-V's `mul` instruction.
1043 :attribute OP_MUL_SIGNED_HIGH: the MSB half of the product where both
1044 ``a`` and ``b`` are signed. Equivalent to RISC-V's `mulh`
1046 :attribute OP_MUL_SIGNED_UNSIGNED_HIGH: the MSB half of the product
1047 where ``a`` is signed and ``b`` is unsigned. Equivalent to RISC-V's
1048 `mulhsu` instruction.
1049 :attribute OP_MUL_UNSIGNED_HIGH: the MSB half of the product where both
1050 ``a`` and ``b`` are unsigned. Equivalent to RISC-V's `mulhu`
1054 def __init__(self
, register_levels
=()):
1055 """ register_levels: specifies the points in the cascade at which
1056 flip-flops are to be inserted.
1059 self
.id_wid
= 0 # num_bits(num_rows)
1061 self
.pspec
= PipelineSpec(64, self
.id_wid
, self
.op_wid
, n_ops
=3)
1062 self
.pspec
.n_parts
= 8
1065 self
.register_levels
= list(register_levels
)
1067 self
.i
= self
.ispec()
1068 self
.o
= self
.ospec()
1071 self
.part_pts
= self
.i
.part_pts
1072 self
.part_ops
= self
.i
.part_ops
1077 self
.intermediate_output
= self
.o
.intermediate_output
1078 self
.output
= self
.o
.output
1086 def elaborate(self
, platform
):
1089 part_pts
= self
.part_pts
1092 t
= AllTerms(self
.pspec
, n_inputs
)
1097 at
= AddReduceInternal(self
.pspec
, n_inputs
, part_pts
, partition_step
=2)
1100 for idx
in range(len(at
.levels
)):
1101 mcur
= at
.levels
[idx
]
1104 if idx
in self
.register_levels
:
1105 m
.d
.sync
+= o
.eq(mcur
.process(i
))
1107 m
.d
.comb
+= o
.eq(mcur
.process(i
))
1108 i
= o
# for next loop
1110 interm
= Intermediates(self
.pspec
, part_pts
)
1112 o
= interm
.process(interm
.i
)
1115 finalout
= FinalOut(self
.pspec
, part_pts
)
1116 finalout
.setup(m
, o
)
1117 m
.d
.comb
+= self
.o
.eq(finalout
.process(o
))
1122 if __name__
== "__main__":
1126 m
.intermediate_output
,
1129 *m
.part_pts
.values()])