split out LSB and neg term to separate module
[ieee754fpu.git] / src / ieee754 / part_mul_add / multiply.py
1 # SPDX-License-Identifier: LGPL-2.1-or-later
2 # See Notices.txt for copyright information
3 """Integer Multiplication."""
4
5 from nmigen import Signal, Module, Value, Elaboratable, Cat, C, Mux, Repl
6 from nmigen.hdl.ast import Assign
7 from abc import ABCMeta, abstractmethod
8 from nmigen.cli import main
9 from functools import reduce
10 from operator import or_
11
12
class PartitionPoints(dict):
    """Mapping from partition-point bit index to its enabling ``Value``.

    A partition point is a bit position at which an ALU may be split; the
    associated ``Value`` says whether the split at that position is active.

    For example: ``{1: True, 5: True, 10: True}`` with ``width == 16``
    splits the ALU into 4 sections:
        * bits 0 <= ``i`` < 1
        * bits 1 <= ``i`` < 5
        * bits 5 <= ``i`` < 10
        * bits 10 <= ``i`` < 16

    With ``{1: True, 5: a, 10: True}``, where ``a`` is a 1-bit ``Signal``:
        * If ``a`` is asserted the sections are the same four as above.
        * Otherwise sections 1..5 and 5..10 merge:
            * bits 0 <= ``i`` < 1
            * bits 1 <= ``i`` < 10
            * bits 10 <= ``i`` < 16
    """

    def __init__(self, partition_points=None):
        """Create a new ``PartitionPoints``.

        :param partition_points: optional mapping of point -> enable value.
            Every enable value is passed through ``Value.wrap``.
        :raises TypeError: if a point is not an ``int``.
        :raises ValueError: if a point is negative.
        """
        super().__init__()
        if partition_points is None:
            return
        for position, enable in partition_points.items():
            if not isinstance(position, int):
                raise TypeError("point must be a non-negative integer")
            if position < 0:
                raise ValueError("point must be a non-negative integer")
            self[position] = Value.wrap(enable)

    def like(self, name=None, src_loc_at=0):
        """Create a new ``PartitionPoints`` holding fresh ``Signal``s.

        Each new ``Signal`` has the shape of the corresponding value in
        ``self`` and is named ``<name>_<point>``.

        :param name: the base name for the new ``Signal``s. When omitted,
            the name is taken from a throw-away ``Signal`` created with
            ``src_loc_at=1+src_loc_at``.
        :param src_loc_at: extra offset added to the ``src_loc_at`` used for
            that throw-away ``Signal``.
        """
        if name is None:
            # This call has to stay directly in this function body: the
            # src_loc_at offset is relative to the current call depth.
            name = Signal(src_loc_at=1+src_loc_at).name  # get variable name
        clone = PartitionPoints()
        for position, enable in self.items():
            clone[position] = Signal(enable.shape(),
                                     name=f"{name}_{position}")
        return clone

    def eq(self, rhs):
        """Yield one ``enabled.eq(rhs[point])`` assignment per point.

        This is a generator, so the key-set check (and its ``ValueError``)
        only runs once iteration starts.

        :param rhs: mapping with exactly the same set of points as ``self``.
        :raises ValueError: if the two point sets differ.
        """
        if set(self) != set(rhs.keys()):
            raise ValueError("incompatible point set")
        for position, enable in self.items():
            yield enable.eq(rhs[position])

    def as_mask(self, width):
        """Create a bit-mask from `self`.

        Bit ``i`` of the result is the inverse of the enable value when
        ``i`` is a partition point, and a constant ``True`` otherwise, i.e.
        a bit is clear only where an enabled partition point sits.

        :param width: the bit width of the resulting mask
        """
        return Cat(*[~self[i] if i in self else True for i in range(width)])

    def get_max_partition_count(self, width):
        """Get the maximum number of partitions.

        This is the partition count when every partition point is enabled:
        one, plus one for each point that lies below ``width``.
        """
        return 1 + sum(1 for position in self if position < width)

    def fits_in_width(self, width):
        """Check if all partition points are smaller than `width`."""
        return not any(position >= width for position in self)
105
106
class FullAdder(Elaboratable):
    """Bit-parallel full (3-to-2) adder.

    :attribute in0: the first input
    :attribute in1: the second input
    :attribute in2: the third input
    :attribute sum: the sum output
    :attribute carry: the carry output

    Instead of instantiating an array of 1-bit full adders (which would be
    very slow to simulate), a single instance operates on ``width``-bit
    vectors, performing ``width`` independent 3-2 additions "in parallel".
    The carry output is *not* shifted: bit ``i`` of ``carry`` is the carry
    generated by bit ``i`` of the inputs.
    """

    def __init__(self, width):
        """Create a ``FullAdder``.

        :param width: the bit width of every input and output
        """
        self.in0 = Signal(width)
        self.in1 = Signal(width)
        self.in2 = Signal(width)
        self.sum = Signal(width)
        self.carry = Signal(width)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        x, y, z = self.in0, self.in1, self.in2
        m.d.comb += [
            # sum is the bitwise parity of the three inputs
            self.sum.eq(x ^ y ^ z),
            # carry is the bitwise majority of the three inputs
            self.carry.eq((x & y) | (y & z) | (z & x)),
        ]
        return m
141
142
class PartitionedAdder(Elaboratable):
    """Adder whose carry chain can be cut at run-time partition points.

    :attribute width: the bit width of the input and output. Read-only.
    :attribute a: the first input to the adder
    :attribute b: the second input to the adder
    :attribute output: the sum output
    :attribute partition_points: the input partition points. Modification not
        supported, except for by ``Signal.eq``.
    """

    def __init__(self, width, partition_points):
        """Create a ``PartitionedAdder``.

        :param width: the bit width of the input and output
        :param partition_points: the input partition points
        :raises ValueError: if any partition point is >= ``width``
        """
        self.width = width
        self.a = Signal(width)
        self.b = Signal(width)
        self.output = Signal(width)
        self.partition_points = PartitionPoints(partition_points)
        if not self.partition_points.fits_in_width(width):
            raise ValueError("partition_points doesn't fit in width")
        # one extra bit is inserted in front of every partition point
        extra_bits = sum(1 for i in range(self.width)
                         if i in self.partition_points)
        self._expanded_width = self.width + extra_bits
        # XXX these have to remain here due to some horrible nmigen
        # simulation bugs involving sync.  it is *not* necessary to
        # have them here, they should (under normal circumstances)
        # be moved into elaborate, as they are entirely local
        self._expanded_a = Signal(self._expanded_width)
        self._expanded_b = Signal(self._expanded_width)
        self._expanded_output = Signal(self._expanded_width)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        comb = m.d.comb
        points = self.partition_points
        exp_a, exp_b = self._expanded_a, self._expanded_b
        exp_out = self._expanded_output

        # Bits are gathered in lists and joined with Cat afterwards, which
        # gives a much cleaner graphviz rendering than per-bit assignments.
        a_dst, a_src = [], []      # expanded-a bits and what drives them
        b_dst, b_src = [], []      # expanded-b bits and what drives them
        out_dst, out_src = [], []  # output bits and the expanded sum bits

        # partition points are "breaks" (extra bits) in what would otherwise
        # be one long add.
        pos = 0  # current bit index within the expanded vectors
        for i in range(self.width):
            if i in points:
                # extra bit: 0 + 0 for an enabled partition point (absorbs
                # the incoming carry), 1 + 0 for a disabled one (passes the
                # incoming carry on to the next bit)
                a_dst.append(exp_a[pos])
                a_src.append(~points[i])
                b_dst.append(exp_b[pos])
                b_src.append(C(0))
                pos += 1
            # the real data bit
            a_dst.append(exp_a[pos])
            a_src.append(self.a[i])
            b_dst.append(exp_b[pos])
            b_src.append(self.b[i])
            out_dst.append(self.output[i])
            out_src.append(exp_out[pos])
            pos += 1

        comb += Cat(*a_dst).eq(Cat(*a_src))
        comb += Cat(*b_dst).eq(Cat(*b_src))
        comb += Cat(*out_dst).eq(Cat(*out_src))
        # use only one addition to take advantage of look-ahead carry and
        # special hardware on FPGAs
        comb += exp_out.eq(exp_a + exp_b)
        return m
215
216
# Number of terms consumed by one full adder in the AddReduce tree.
FULL_ADDER_INPUT_COUNT = 3
218
219
class AddReduce(Elaboratable):
    """Add list of numbers together.

    The inputs are reduced with layers of ``FullAdder``s (3 terms in, 2 terms
    out per adder) by recursively instantiating ``AddReduce`` until at most
    two terms remain, which are then summed by a ``PartitionedAdder``.

    :attribute inputs: input ``Signal``s to be summed. Modification not
        supported, except for by ``Signal.eq``.
    :attribute register_levels: List of nesting levels that should have
        pipeline registers.
    :attribute output: output sum.
    :attribute partition_points: the input partition points. Modification not
        supported, except for by ``Signal.eq``.
    """

    def __init__(self, inputs, output_width, register_levels, partition_points):
        """Create an ``AddReduce``.

        :param inputs: input ``Signal``s to be summed.
        :param output_width: bit-width of ``output``.
        :param register_levels: List of nesting levels that should have
            pipeline registers.
        :param partition_points: the input partition points.
        :raises ValueError: if a partition point is >= ``output_width`` or
            a register level exceeds ``get_max_level(len(inputs))``.
        """
        self.inputs = list(inputs)
        self._resized_inputs = [
            Signal(output_width, name=f"resized_inputs[{i}]")
            for i in range(len(self.inputs))]
        self.register_levels = list(register_levels)
        self.output = Signal(output_width)
        self.partition_points = PartitionPoints(partition_points)
        if not self.partition_points.fits_in_width(output_width):
            raise ValueError("partition_points doesn't fit in output_width")
        # keep this as a plain attribute assignment: like() derives the
        # base name of the new Signals from the assignment target
        self._reg_partition_points = self.partition_points.like()
        max_level = AddReduce.get_max_level(len(self.inputs))
        for level in self.register_levels:
            if level > max_level:
                raise ValueError(
                    "not enough adder levels for specified register levels")

    @staticmethod
    def get_max_level(input_count):
        """Get the maximum level.

        All ``register_levels`` must be less than or equal to the maximum
        level.

        :param input_count: number of terms entering the first level.
        :returns: number of full-adder layers needed before no group of
            ``FULL_ADDER_INPUT_COUNT`` terms remains.
        """
        retval = 0
        while True:
            groups = AddReduce.full_adder_groups(input_count)
            if len(groups) == 0:
                return retval
            # each group of 3 terms becomes 2 terms (sum and carry); the
            # leftover terms are passed through unchanged
            input_count %= FULL_ADDER_INPUT_COUNT
            input_count += 2 * len(groups)
            retval += 1

    def next_register_levels(self):
        """``Iterable`` of ``register_levels`` for next recursive level."""
        for level in self.register_levels:
            if level > 0:
                yield level - 1

    @staticmethod
    def full_adder_groups(input_count):
        """Get ``inputs`` indices for which a full adder should be built.

        :returns: a ``range`` of the starting index of every complete group
            of ``FULL_ADDER_INPUT_COUNT`` consecutive inputs.
        """
        return range(0,
                     input_count - FULL_ADDER_INPUT_COUNT + 1,
                     FULL_ADDER_INPUT_COUNT)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        out_width = len(self.output)

        # resize inputs to correct bit-width and optionally add in
        # pipeline registers
        resized_input_assignments = [
            resized.eq(original)
            for resized, original in zip(self._resized_inputs, self.inputs)]
        if 0 in self.register_levels:
            m.d.sync += resized_input_assignments
            m.d.sync += self._reg_partition_points.eq(self.partition_points)
        else:
            m.d.comb += resized_input_assignments
            m.d.comb += self._reg_partition_points.eq(self.partition_points)

        groups = AddReduce.full_adder_groups(len(self.inputs))
        # if there are no full adders to create, then we handle the base cases
        # and return, otherwise we go on to the recursive case
        if len(groups) == 0:
            if len(self.inputs) == 0:
                # use 0 as the default output value
                m.d.comb += self.output.eq(0)
            elif len(self.inputs) == 1:
                # handle single input
                m.d.comb += self.output.eq(self._resized_inputs[0])
            else:
                # base case for adding 2 or more inputs, which get recursively
                # reduced to 2 inputs
                assert len(self.inputs) == 2
                adder = PartitionedAdder(out_width,
                                         self._reg_partition_points)
                m.submodules.final_adder = adder
                m.d.comb += adder.a.eq(self._resized_inputs[0])
                m.d.comb += adder.b.eq(self._resized_inputs[1])
                m.d.comb += self.output.eq(adder.output)
            return m

        # go on to handle recursive case
        intermediate_terms = []

        def add_intermediate_term(value):
            """Register ``value`` as an input of the next reduction level."""
            intermediate_term = Signal(
                out_width,
                name=f"intermediate_terms[{len(intermediate_terms)}]")
            intermediate_terms.append(intermediate_term)
            m.d.comb += intermediate_term.eq(value)

        # store mask in intermediary (simplifies graph)
        part_mask = Signal(out_width, reset_less=True)
        mask = self._reg_partition_points.as_mask(out_width)
        m.d.comb += part_mask.eq(mask)

        # create full adders for this recursive level.
        # this shrinks N terms to 2 * (N // 3) plus the remainder
        for i in groups:
            adder_i = FullAdder(out_width)
            setattr(m.submodules, f"adder_{i}", adder_i)
            m.d.comb += adder_i.in0.eq(self._resized_inputs[i])
            m.d.comb += adder_i.in1.eq(self._resized_inputs[i + 1])
            m.d.comb += adder_i.in2.eq(self._resized_inputs[i + 2])
            add_intermediate_term(adder_i.sum)
            # the carry out of bit n has the weight of bit n + 1
            shifted_carry = adder_i.carry << 1
            # mask out carry bits to prevent carries between partitions
            add_intermediate_term(shifted_carry & part_mask)

        # handle the remaining inputs.
        leftover = len(self.inputs) % FULL_ADDER_INPUT_COUNT
        if leftover == 1:
            add_intermediate_term(self._resized_inputs[-1])
        elif leftover == 2:
            # Just pass the terms to the next layer, since we wouldn't gain
            # anything by using a half adder since there would still be 2 terms
            # and just passing the terms to the next layer saves gates.
            add_intermediate_term(self._resized_inputs[-2])
            add_intermediate_term(self._resized_inputs[-1])
        else:
            assert leftover == 0

        # recursive invocation of ``AddReduce``
        next_level = AddReduce(intermediate_terms,
                               out_width,
                               self.next_register_levels(),
                               self._reg_partition_points)
        m.submodules.next_level = next_level
        m.d.comb += self.output.eq(next_level.output)
        return m
368
369
# Per-byte operation codes (2 bits wide) accepted by the multiplier.
OP_MUL_LOW = 0                   # low half of the product
OP_MUL_SIGNED_HIGH = 1           # high half; a and b are both signed
OP_MUL_SIGNED_UNSIGNED_HIGH = 2  # high half; a is signed, b is unsigned
OP_MUL_UNSIGNED_HIGH = 3         # high half; a and b are both unsigned
374
375
def get_term(value, shift=0, enabled=None):
    """Gate ``value`` with ``enabled`` and shift it left by ``shift`` bits.

    :param value: the term to process.
    :param shift: non-negative number of zero bits to place below ``value``;
        a negative shift trips an assertion.
    :param enabled: optional select; when given, the term is replaced by
        ``Mux(enabled, value, 0)``.  ``None`` leaves the term ungated.
    :returns: the gated term, with ``shift`` zero bits concatenated
        underneath it when ``shift`` is positive.
    """
    term = value if enabled is None else Mux(enabled, value, 0)
    if shift > 0:
        # the left shift is expressed as a concatenation with low zero bits
        return Cat(Repl(C(0, 1), shift), term)
    assert shift == 0
    return term
384
385
class ProductTerm(Elaboratable):
    """ creates a single product term (a[..]*b[..]).

        one ``pwidth``-wide slice of ``a`` is multiplied by one
        ``pwidth``-wide slice of ``b``; the product is shifted into place and
        zeroed when a partition boundary lies between the two slices.

        known design flaw: it is the *output* that is gated, so the
        multiplication itself is combinatorially generated all the time.
    """

    def __init__(self, width, twidth, pbwid, a_index, b_index):
        """
        :param width: bit width of each operand slice (stored as ``pwidth``)
        :param twidth: bit width of the output ``term``; ``a`` and ``b`` are
            ``twidth//2`` bits wide
        :param pbwid: bit width of the ``pb_en`` partition-enable input
        :param a_index: index of the slice taken from ``a``
        :param b_index: index of the slice taken from ``b``
        """
        self.a_index = a_index
        self.b_index = b_index
        self.pwidth = width
        self.twidth = twidth
        self.width = width*2
        # NOTE(review): the shift uses a fixed 8 bits per index rather than
        # ``width``; the only caller visible in this file passes width=8.
        self.shift = 8 * (a_index + b_index)

        self.ti = Signal(self.width, reset_less=True)
        self.term = Signal(twidth, reset_less=True)
        self.a = Signal(twidth//2, reset_less=True)
        self.b = Signal(twidth//2, reset_less=True)
        self.pb_en = Signal(pbwid, reset_less=True)

        # partition-enable bits lying between the two slice indices
        lo_index, hi_index = sorted((a_index, b_index))
        self.tl = [self.pb_en[i] for i in range(lo_index, hi_index)]
        if self.tl:
            self.enabled = Signal(name="te_%d_%d" % (a_index, b_index),
                                  reset_less=True)
        else:
            # both slices share an index: the term is never gated
            self.enabled = None
        self.term.name = "term_%d_%d" % (a_index, b_index)  # rename

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        comb = m.d.comb

        if self.enabled is not None:
            # the term is only enabled when no partition bit between the
            # two slices is set
            comb += self.enabled.eq(~(Cat(*self.tl).bool()))

        bsa = Signal(self.width, reset_less=True)
        bsb = Signal(self.width, reset_less=True)
        pwidth = self.pwidth
        a_slice = self.a.bit_select(self.a_index * pwidth, pwidth)
        b_slice = self.b.bit_select(self.b_index * pwidth, pwidth)
        comb += bsa.eq(a_slice)
        comb += bsb.eq(b_slice)
        comb += self.ti.eq(bsa * bsb)
        comb += self.term.eq(get_term(self.ti, self.shift, self.enabled))
        # TODO: sort out width issues and gate the a/b *inputs* on/off
        # instead of the product (data going into the Muxes would be 1/2
        # the required width), i.e. apply get_term() to each selected
        # slice before multiplying rather than to ``ti`` afterwards.

        return m
455
456
class ProductTerms(Elaboratable):
    """ creates a bank of product terms, one ``ProductTerm`` per "b" index,
        all sharing the same "a" index.

        this class is meant to be wrapped with a for-loop on the "a" operand;
        it provides the second-level loop on the "b" operand.  the actual
        bit-selection is done inside each ``ProductTerm``.
    """
    def __init__(self, width, twidth, pbwid, a_index, blen):
        """
        :param width: bit width of each operand slice (stored as ``pwidth``)
        :param twidth: bit width of every output term; ``a`` and ``b`` are
            ``twidth//2`` bits wide
        :param pbwid: bit width of the ``pb_en`` partition-enable input
        :param a_index: the fixed "a" slice index for this bank
        :param blen: number of "b" slices, i.e. number of terms produced
        """
        self.a_index = a_index
        self.blen = blen
        self.pwidth = width
        self.twidth = twidth
        self.pbwid = pbwid
        self.a = Signal(twidth//2, reset_less=True)
        self.b = Signal(twidth//2, reset_less=True)
        self.pb_en = Signal(pbwid, reset_less=True)
        self.terms = [
            Signal(twidth, name="term%d" % i, reset_less=True)
            for i in range(blen)
        ]

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        comb = m.d.comb

        for b_index in range(self.blen):
            product = ProductTerm(self.pwidth, self.twidth, self.pbwid,
                                  self.a_index, b_index)
            setattr(m.submodules, "term_%d" % b_index, product)

            # every ProductTerm sees the full operands and enable bits
            comb += product.a.eq(self.a)
            comb += product.b.eq(self.b)
            comb += product.pb_en.eq(self.pb_en)

            comb += self.terms[b_index].eq(product.term)

        return m
490
class LSBNegTerm(Elaboratable):
    """ generates the two correction terms for one operand of one partition:
        the width-extended 1s complement of the operand (``nt``) and the
        accompanying "+1" bit (``nl``), both placed in the HI half of a
        ``2*bit_width`` result.  both are zero unless the partition is
        active, the MSB is set and a signed operation was requested.

        inputs:  ``part``, ``signed``, ``op`` (``bit_width`` bits), ``msb``
        outputs: ``nt``, ``nl`` (``2*bit_width`` bits each)
    """

    def __init__(self, bit_width):
        """
        :param bit_width: bit width of the operand ``op``; the outputs are
            twice as wide.
        """
        self.bit_width = bit_width
        self.part = Signal(reset_less=True)
        self.signed = Signal(reset_less=True)
        self.op = Signal(bit_width, reset_less=True)
        self.msb = Signal(reset_less=True)
        self.nt = Signal(bit_width*2, reset_less=True)
        self.nl = Signal(bit_width*2, reset_less=True)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        comb = m.d.comb
        bit_wid = self.bit_width
        ext = Repl(0, bit_wid)  # extend output to HI part

        # determine sign of the incoming number *in this partition*
        enabled = Signal(reset_less=True)
        comb += enabled.eq(self.part & self.msb & self.signed)

        # for 8-bit values: form a * 0xFF00 by using -a * 0x100, the
        # negation operation is split into a bitwise not and a +1.
        # likewise for 16, 32, and 64-bit values.

        # width-extended 1s complement if enabled, otherwise zero
        comb += self.nt.eq(Mux(enabled, Cat(ext, ~self.op), 0))

        # add 1 if enabled, otherwise add zero
        comb += self.nl.eq(Cat(ext, enabled, Repl(0, bit_wid-1)))

        return m
523
524
class Part(Elaboratable):
    """ a key class which, depending on the partitioning, will determine
        what action to take when parts of the output are signed or unsigned.

        this requires 2 pieces of data *per operand, per partition*:
        whether the MSB is HI/LO (per partition!), and whether a signed
        or unsigned operation has been *requested*.

        once that is determined, signed is basically carried out
        by splitting 2's complement into 1's complement plus one.
        1's complement is just a bit-inversion.

        the extra terms - as separate terms - are then thrown at the
        AddReduce alongside the multiplication part-results.

        inputs:  ``a``, ``b``, ``a_signed``, ``b_signed``, ``pbs``
        outputs: ``parts``, ``delayed_parts``, ``dplast``, ``not_a_term``,
                 ``neg_lsb_a_term``, ``not_b_term``, ``neg_lsb_b_term``
    """
    def __init__(self, width, n_parts, n_levels, pbwid):
        """
        :param width: bit width of the four correction-term outputs
        :param n_parts: number of partitions handled by this instance
        :param n_levels: number of entries in the ``delayed_parts`` chain
        :param pbwid: bit width of the ``pbs`` partition-byte input
        """
        # inputs
        self.a = Signal(64)
        self.b = Signal(64)
        self.a_signed = [Signal(name=f"a_signed_{i}") for i in range(8)]
        self.b_signed = [Signal(name=f"_b_signed_{i}") for i in range(8)]
        self.pbs = Signal(pbwid, reset_less=True)

        # outputs
        self.parts = [Signal(name=f"part_{i}") for i in range(n_parts)]
        self.delayed_parts = [
            [Signal(name=f"delayed_part_{delay}_{i}") for i in range(n_parts)]
            for delay in range(n_levels)
        ]
        # XXX REALLY WEIRD BUG - have to take a copy of the last delayed_parts
        self.dplast = [Signal(name=f"dplast_{i}") for i in range(n_parts)]

        self.not_a_term = Signal(width)
        self.neg_lsb_a_term = Signal(width)
        self.not_b_term = Signal(width)
        self.neg_lsb_b_term = Signal(width)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        comb = m.d.comb

        pbs, parts, delayed_parts = self.pbs, self.parts, self.delayed_parts
        n_parts = len(parts)
        byte_width = 8 // n_parts  # bytes per partition

        # --- partition-active flags and their delayed copies ---
        for i, part in enumerate(parts):
            first = i * byte_width            # lowest byte of the partition
            last = first + byte_width - 1     # highest byte of the partition
            # partition i is active when the partition bytes just below and
            # at its top are set and all of those in between are clear.
            # (for i == 0 the index -1 selects the last bit of pbs)
            pbl = [~pbs[first - 1]]
            pbl += [pbs[j] for j in range(first, last)]
            pbl.append(~pbs[last])
            value = Signal(len(pbl), reset_less=True)
            comb += value.eq(Cat(*pbl))
            comb += part.eq(~(value.bool()))
            comb += delayed_parts[0][i].eq(part)
            # every further level is a registered copy of the previous one
            m.d.sync += [later[i].eq(earlier[i])
                         for earlier, later in zip(delayed_parts,
                                                   delayed_parts[1:])]
            comb += self.dplast[i].eq(delayed_parts[-1][i])

        # --- sign-correction terms ---
        bit_wid = 8 * byte_width   # bits per partition
        ext = Repl(0, bit_wid)     # extend output to HI part
        nat, nbt, nla, nlb = [], [], [], []
        for i in range(n_parts):
            msb = (i + 1) * bit_wid - 1   # MSB of this partition
            sign_sel = i * byte_width     # index of this partition's sign flag
            # determine sign of each incoming number *in this partition*
            a_negative = (parts[i] & self.a[msb]       # MSB
                          & self.a_signed[sign_sel])   # a op is signed?
            b_negative = (parts[i] & self.b[msb]       # MSB
                          & self.b_signed[sign_sel])   # b op is signed?
            # the crossing is deliberate: the terms built from a are
            # enabled by the sign of b, and vice versa
            a_en = Signal(name="a_en_%d" % i, reset_less=True)
            b_en = Signal(name="b_en_%d" % i, reset_less=True)
            comb += a_en.eq(b_negative)
            comb += b_en.eq(a_negative)

            # for 8-bit values: form a * 0xFF00 by using -a * 0x100, the
            # negation operation is split into a bitwise not and a +1.
            # likewise for 16, 32, and 64-bit values.
            a_part = self.a.bit_select(bit_wid * i, bit_wid)
            b_part = self.b.bit_select(bit_wid * i, bit_wid)

            # a: width-extended 1s complement if enabled, otherwise zero
            nat.append(Mux(a_en, Cat(ext, ~a_part), 0))
            # a: add 1 if enabled, otherwise add zero
            nla.append(Cat(ext, a_en, Repl(0, bit_wid-1)))
            # b: width-extended 1s complement if enabled, otherwise zero
            nbt.append(Mux(b_en, Cat(ext, ~b_part), 0))
            # b: add 1 if enabled, otherwise add zero
            nlb.append(Cat(ext, b_en, Repl(0, bit_wid-1)))

        # concatenate together and output all 4 results.
        comb += [
            self.not_a_term.eq(Cat(*nat)),
            self.not_b_term.eq(Cat(*nbt)),
            self.neg_lsb_a_term.eq(Cat(*nla)),
            self.neg_lsb_b_term.eq(Cat(*nlb)),
        ]

        return m
630
631
class IntermediateOut(Elaboratable):
    """ selects the HI/LO part of the multiplication, for a given bit-width
        the output is also reconstructed in its SIMD (partition) lanes.

        inputs:  ``delayed_part_ops`` (8 x 2-bit op codes), ``intermed``
        output:  ``output`` (half the width of ``intermed``)
    """
    def __init__(self, width, out_wid, n_parts):
        """
        :param width: bit width of one lane of the output
        :param out_wid: bit width of the ``intermed`` input
        :param n_parts: number of lanes
        """
        self.width = width
        self.n_parts = n_parts
        self.delayed_part_ops = [
            Signal(2, name="dpop%d" % i, reset_less=True) for i in range(8)
        ]
        self.intermed = Signal(out_wid, reset_less=True)
        self.output = Signal(out_wid//2, reset_less=True)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        comb = m.d.comb

        w = self.width
        ops_per_lane = w // 8  # stride between the op codes of two lanes
        lanes = []
        for i in range(self.n_parts):
            # each lane owns 2*w bits of intermed: LO half, then HI half
            lo_half = self.intermed.bit_select(i * w*2, w)
            hi_half = self.intermed.bit_select(i * w*2 + w, w)
            wants_low = self.delayed_part_ops[ops_per_lane * i] == OP_MUL_LOW
            op = Signal(w, reset_less=True, name="op%d_%d" % (w, i))
            comb += op.eq(Mux(wants_low, lo_half, hi_half))
            lanes.append(op)
        comb += self.output.eq(Cat(*lanes))

        return m
660
661
class FinalOut(Elaboratable):
    """ selects the final output based on the partitioning.

        each byte is selectable independently, i.e. it is possible
        that some partitions requested 8-bit computation whilst others
        requested 16 or 32 bit.

        inputs:  ``d8`` (8 flags), ``d16`` (4 flags), ``d32`` (2 flags),
                 ``i8``, ``i16``, ``i32``, ``i64``
        output:  ``out``
    """
    def __init__(self, out_wid):
        """
        :param out_wid: bit width of the four candidate inputs and of ``out``
        """
        # inputs
        self.d8 = [Signal(name=f"d8_{i}", reset_less=True) for i in range(8)]
        self.d16 = [Signal(name=f"d16_{i}", reset_less=True) for i in range(4)]
        self.d32 = [Signal(name=f"d32_{i}", reset_less=True) for i in range(2)]

        self.i8 = Signal(out_wid, reset_less=True)
        self.i16 = Signal(out_wid, reset_less=True)
        self.i32 = Signal(out_wid, reset_less=True)
        self.i64 = Signal(out_wid, reset_less=True)

        # output
        self.out = Signal(out_wid, reset_less=True)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        comb = m.d.comb

        selected = []
        for i in range(8):
            lsb = i * 8
            is8, is16, is32 = self.d8[i], self.d16[i // 2], self.d32[i // 4]
            # select one of the outputs: d8 selects i8, d16 selects i16
            # d32 selects i32, and the default is i64.
            # the narrow pair: d8 picks i8, otherwise i16
            narrow = Mux(is8,
                         self.i8.bit_select(lsb, 8),
                         self.i16.bit_select(lsb, 8))
            # the wide pair: d32 picks i32, otherwise i64
            wide = Mux(is32,
                       self.i32.bit_select(lsb, 8),
                       self.i64.bit_select(lsb, 8))
            # d8 and d16 are ORed together to choose between the two pairs:
            # if neither is set, the wide pair is used.
            op = Signal(8, reset_less=True, name="op_%d" % i)
            comb += op.eq(Mux(is8 | is16, narrow, wide))
            selected.append(op)
        comb += self.out.eq(Cat(*selected))
        return m
702
703
class OrMod(Elaboratable):
    """ ORs four values together in a two-level tree:
        (orin[0] | orin[1]) | (orin[2] | orin[3])
    """
    def __init__(self, wid):
        """
        :param wid: bit width of the four inputs and of the output
        """
        self.wid = wid
        self.orin = [
            Signal(wid, name="orin%d" % i, reset_less=True) for i in range(4)
        ]
        self.orout = Signal(wid, reset_less=True)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        # first level: two pairwise ORs held in intermediate signals
        or1 = Signal(self.wid, reset_less=True)
        or2 = Signal(self.wid, reset_less=True)
        m.d.comb += [
            or1.eq(self.orin[0] | self.orin[1]),
            or2.eq(self.orin[2] | self.orin[3]),
            # second level: combine the two partial results
            self.orout.eq(or1 | or2),
        ]

        return m
722
723
class Signs(Elaboratable):
    """ determines whether a or b are signed numbers
        based on the required operation type (OP_MUL_*)

        input:   ``part_ops`` (2-bit op code)
        outputs: ``a_signed``, ``b_signed``
    """

    def __init__(self):
        self.part_ops = Signal(2, reset_less=True)
        self.a_signed = Signal(reset_less=True)
        self.b_signed = Signal(reset_less=True)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        op = self.part_ops

        m.d.comb += [
            # a is signed for every operation except the unsigned-high one
            self.a_signed.eq(op != OP_MUL_UNSIGNED_HIGH),
            # b is signed only for the low and signed-high operations
            self.b_signed.eq((op == OP_MUL_LOW)
                             | (op == OP_MUL_SIGNED_HIGH)),
        ]

        return m
745
746
class Mul8_16_32_64(Elaboratable):
    """Signed/Unsigned 8/16/32/64-bit partitioned integer multiplier.

    Supports partitioning into any combination of 8, 16, 32, and 64-bit
    partitions on naturally-aligned boundaries. Supports the operation being
    set for each partition independently.

    :attribute part_pts: the input partition points. Has a partition point at
        multiples of 8 in 0 < i < 64. Each partition point's associated
        ``Value`` is a ``Signal``. Modification not supported, except for by
        ``Signal.eq``.
    :attribute part_ops: the operation for each byte. The operation for a
        particular partition is selected by assigning the selected operation
        code to each byte in the partition. The allowed operation codes are:

        :attribute OP_MUL_LOW: the LSB half of the product. Equivalent to
            RISC-V's `mul` instruction.
        :attribute OP_MUL_SIGNED_HIGH: the MSB half of the product where both
            ``a`` and ``b`` are signed. Equivalent to RISC-V's `mulh`
            instruction.
        :attribute OP_MUL_SIGNED_UNSIGNED_HIGH: the MSB half of the product
            where ``a`` is signed and ``b`` is unsigned. Equivalent to RISC-V's
            `mulhsu` instruction.
        :attribute OP_MUL_UNSIGNED_HIGH: the MSB half of the product where both
            ``a`` and ``b`` are unsigned. Equivalent to RISC-V's `mulhu`
            instruction.
    """

    def __init__(self, register_levels=()):
        """ register_levels: specifies the points in the cascade at which
            flip-flops are to be inserted.
        """

        # parameter(s)
        self.register_levels = list(register_levels)

        # inputs
        self.part_pts = PartitionPoints()
        for point in range(8, 64, 8):
            self.part_pts[point] = Signal(name=f"part_pts_{point}")
        self.part_ops = [Signal(2, name=f"part_ops_{i}") for i in range(8)]
        self.a = Signal(64)
        self.b = Signal(64)

        # intermediates (needed for unit tests)
        self._intermediate_output = Signal(128)

        # output
        self.output = Signal(64)

    def _part_byte(self, index):
        """Return the partition enable sitting at the top of byte ``index``.

        Indices -1 and 7 (the outer edges of the 64-bit word) yield a
        constant 1; any other index must be in 0..7.
        """
        if index in (-1, 7):
            return C(True, 1)
        assert 0 <= index < 8
        return self.part_pts[index * 8 + 8]

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        comb = m.d.comb

        # collect part-bytes
        pbs = Signal(8, reset_less=True)
        part_bytes = []
        for i in range(8):
            pb = Signal(name="pb%d" % i, reset_less=True)
            comb += pb.eq(self._part_byte(i))
            part_bytes.append(pb)
        comb += pbs.eq(Cat(*part_bytes))

        # one sign decoder per byte
        signs = []
        for i in range(8):
            sign = Signs()
            signs.append(sign)
            setattr(m.submodules, "signs%d" % i, sign)
            comb += sign.part_ops.eq(self.part_ops[i])

        # op codes, delayed by one register per register level
        n_levels = len(self.register_levels) + 1
        delayed_part_ops = [
            [Signal(2, name=f"_delayed_part_ops_{delay}_{i}")
             for i in range(8)]
            for delay in range(n_levels)]
        for i, part_op in enumerate(self.part_ops):
            comb += delayed_part_ops[0][i].eq(part_op)
            m.d.sync += [later[i].eq(earlier[i])
                         for earlier, later in zip(delayed_part_ops,
                                                   delayed_part_ops[1:])]

        # one Part per partition size; submodules are named after the
        # partition bit width (part_8, part_16, part_32, part_64)
        part_mods = {}
        for n_parts in (8, 4, 2, 1):
            part_mod = Part(128, n_parts, n_levels, 8)
            setattr(m.submodules, "part_%d" % (64 // n_parts), part_mod)
            part_mods[64 // n_parts] = part_mod
        nat_l, nbt_l, nla_l, nlb_l = [], [], [], []
        for part_mod in part_mods.values():
            comb += part_mod.a.eq(self.a)
            comb += part_mod.b.eq(self.b)
            for i, sign in enumerate(signs):
                comb += part_mod.a_signed[i].eq(sign.a_signed)
                comb += part_mod.b_signed[i].eq(sign.b_signed)
            comb += part_mod.pbs.eq(pbs)
            nat_l.append(part_mod.not_a_term)
            nbt_l.append(part_mod.not_b_term)
            nla_l.append(part_mod.neg_lsb_a_term)
            nlb_l.append(part_mod.neg_lsb_b_term)

        # the byte-by-byte partial products
        terms = []
        for a_index in range(8):
            bank = ProductTerms(8, 128, 8, a_index, 8)
            setattr(m.submodules, "terms_%d" % a_index, bank)

            comb += bank.a.eq(self.a)
            comb += bank.b.eq(self.b)
            comb += bank.pb_en.eq(pbs)

            terms.extend(bank.terms)

        # it's fine to bitwise-or data together since they are never enabled
        # at the same time
        or_stages = []
        for label, sources in (("nat", nat_l), ("nbt", nbt_l),
                               ("nla", nla_l), ("nlb", nlb_l)):
            or_mod = OrMod(128)
            setattr(m.submodules, label + "_or", or_mod)
            or_stages.append((sources, or_mod))
        for sources, or_mod in or_stages:
            for i, source in enumerate(sources):
                comb += or_mod.orin[i].eq(source)
            terms.append(or_mod.orout)

        # partition points of the double-width intermediate result
        expanded_part_pts = PartitionPoints()
        for point, enable in self.part_pts.items():
            signal = Signal(name=f"expanded_part_pts_{point*2}",
                            reset_less=True)
            expanded_part_pts[point * 2] = signal
            comb += signal.eq(enable)

        add_reduce = AddReduce(terms,
                               128,
                               self.register_levels,
                               expanded_part_pts)
        m.submodules.add_reduce = add_reduce
        comb += self._intermediate_output.eq(add_reduce.output)

        # create the HI/LO selectors io64, io32, io16 and io8
        last_ops = delayed_part_ops[-1]
        lane_outs = {}
        for lane_width in (64, 32, 16, 8):
            io = IntermediateOut(lane_width, 128, 64 // lane_width)
            setattr(m.submodules, "io%d" % lane_width, io)
            comb += io.intermed.eq(self._intermediate_output)
            for i in range(8):
                comb += io.delayed_part_ops[i].eq(last_ops[i])
            lane_outs[lane_width] = io

        # final output
        m.submodules.finalout = finalout = FinalOut(64)
        for flags, lane_width in ((finalout.d8, 8),
                                  (finalout.d16, 16),
                                  (finalout.d32, 32)):
            part_mod = part_mods[lane_width]
            for i in range(len(part_mod.delayed_parts[-1])):
                comb += flags[i].eq(part_mod.dplast[i])
        comb += finalout.i8.eq(lane_outs[8].output)
        comb += finalout.i16.eq(lane_outs[16].output)
        comb += finalout.i32.eq(lane_outs[32].output)
        comb += finalout.i64.eq(lane_outs[64].output)
        comb += self.output.eq(finalout.out)

        return m
928
929
if __name__ == "__main__":
    # Hand the multiplier and its top-level ports to the nmigen CLI.
    m = Mul8_16_32_64()
    ports = [m.a, m.b, m._intermediate_output, m.output]
    ports.extend(m.part_ops)
    ports.extend(m.part_pts.values())
    main(m, ports=ports)