switch to exact version of cython
[ieee754fpu.git] / src / ieee754 / part_mul_add / adder.py
1 # SPDX-License-Identifier: LGPL-2.1-or-later
2 # See Notices.txt for copyright information
3 """Integer Multiplication."""
4
5 from nmigen import Signal, Module, Value, Elaboratable, Cat, C, Mux, Repl
6 from nmigen.hdl.ast import Assign
7 from abc import ABCMeta, abstractmethod
8 from nmigen.cli import main
9 from functools import reduce
10 from operator import or_
11 from ieee754.pipeline import PipelineSpec
12 from nmutil.pipemodbase import PipeModBase
13
14 from ieee754.part_mul_add.partpoints import PartitionPoints
15
16
class FullAdder(Elaboratable):
    """Vector-wide 3-to-2 full adder.

    Each bit position is treated as an independent full adder: bit ``n``
    of ``sum`` and ``carry`` is computed only from bit ``n`` of the three
    inputs.  A single module of configurable width is used in place of an
    array of one-bit adders, which would be very slow to simulate.

    :attribute in0: the first input
    :attribute in1: the second input
    :attribute in2: the third input
    :attribute sum: the sum output (bitwise XOR of the three inputs)
    :attribute carry: the carry output (bitwise majority of the three
        inputs; it is *not* shifted by this module)
    """

    def __init__(self, width):
        """Create a ``FullAdder``.

        :param width: the bit width shared by every input and output
        """
        # Plain attribute assignments are kept so that each Signal picks
        # up its name from the attribute it is stored in.
        self.in0 = Signal(width, reset_less=True)
        self.in1 = Signal(width, reset_less=True)
        self.in2 = Signal(width, reset_less=True)
        self.sum = Signal(width, reset_less=True)
        self.carry = Signal(width, reset_less=True)

    def elaborate(self, platform):
        """Build and return the combinatorial logic for this adder."""
        m = Module()
        a, b, c = self.in0, self.in1, self.in2
        # a carry is produced wherever at least two of the inputs are set
        majority = (a & b) | (b & c) | (c & a)
        m.d.comb += [
            self.sum.eq(a ^ b ^ c),
            self.carry.eq(majority),
        ]
        return m
51
52
class MaskedFullAdder(Elaboratable):
    """Vector-wide full adder whose carry output is shifted and masked.

    :attribute mask: the carry partition mask
    :attribute in0: the first input
    :attribute in1: the second input
    :attribute in2: the third input
    :attribute sum: the sum output
    :attribute mcarry: the masked carry output

    A full adder is always followed by a "mask" on its carry, so the
    masking is done inside this class; that keeps the graphviz output
    "clean" compared with masking inside a large for-loop.

    This class is deliberately not derived from ``FullAdder``: each
    input is shifted up by one bit *before* it is ANDed with its partner
    and with the mask, so that an AOI cell may be used (which is more
    gate-efficient).  See:
    https://en.wikipedia.org/wiki/AND-OR-Invert
    https://groups.google.com/d/msg/comp.arch/fcq-GLQqvas/vTxmcA0QAgAJ
    """

    def __init__(self, width):
        """Create a ``MaskedFullAdder``.

        :param width: the bit width shared by the mask, inputs and outputs
        """
        self.width = width
        # Plain attribute assignments are kept so that each Signal picks
        # up its name from the attribute it is stored in.
        self.mask = Signal(width, reset_less=True)
        self.mcarry = Signal(width, reset_less=True)
        self.in0 = Signal(width, reset_less=True)
        self.in1 = Signal(width, reset_less=True)
        self.in2 = Signal(width, reset_less=True)
        self.sum = Signal(width, reset_less=True)

    def elaborate(self, platform):
        """Build and return the combinatorial logic for this adder."""
        m = Module()
        operands = (self.in0, self.in1, self.in2)

        # Internal signals are named explicitly (s1..s3, c1..c3) because
        # they are created inside comprehensions.
        shifted = [Signal(self.width, name="s%d" % n, reset_less=True)
                   for n in (1, 2, 3)]
        partial = [Signal(self.width, name="c%d" % n, reset_less=True)
                   for n in (1, 2, 3)]

        m.d.comb += self.sum.eq(self.in0 ^ self.in1 ^ self.in2)

        # prepend a zero bit to each input, i.e. shift it up by one; the
        # result is assigned to a signal of the original width
        m.d.comb += [dst.eq(Cat(0, src))
                     for dst, src in zip(shifted, operands)]

        # pairwise AND of the shifted inputs, each gated by the mask
        s1, s2, s3 = shifted
        pairings = ((s1, s2), (s2, s3), (s3, s1))
        m.d.comb += [dst.eq(lhs & rhs & self.mask)
                     for dst, (lhs, rhs) in zip(partial, pairings)]

        # the masked carry is the OR of the three gated pair terms
        m.d.comb += self.mcarry.eq(reduce(or_, partial))
        return m
106
107
class PartitionedAdder(Elaboratable):
    """Partitioned Adder.

    Performs the final add. The partition points are included in the
    actual add (in one of the operands only), which causes a carry over
    to the next bit.  Then the final output *removes* the extra bits from
    the result.

    partition: .... P... P... P... P... (32 bits)
    a        : .... .... .... .... .... (32 bits)
    b        : .... .... .... .... .... (32 bits)
    exp-a    : ....P....P....P....P.... (32+4 bits, P=1 if no partition)
    exp-b    : ....0....0....0....0.... (32 bits plus 4 zeros)
    exp-o    : ....xN...xN...xN...xN... (32+4 bits - x to be discarded)
    o        : .... N... N... N... N... (32 bits - x ignored, N is carry-over)

    :attribute width: the bit width of the input and output. Read-only.
    :attribute a: the first input to the adder
    :attribute b: the second input to the adder
    :attribute output: the sum output
    :attribute partition_points: the input partition points. Modification not
        supported, except for by ``Signal.eq``.
    """

    def __init__(self, width, partition_points, partition_step=1):
        """Create a ``PartitionedAdder``.

        :param width: the bit width of the input and output
        :param partition_points: the input partition points
        :param partition_step: a multiplier (typically double) step
                               which in-place "expands" the partition points
        :raises ValueError: if ``partition_points`` does not fit in ``width``
        """
        self.width = width
        self.pmul = partition_step
        self.a = Signal(width, reset_less=True)
        self.b = Signal(width, reset_less=True)
        self.output = Signal(width, reset_less=True)
        self.partition_points = PartitionPoints(partition_points)
        if not self.partition_points.fits_in_width(width):
            raise ValueError("partition_points doesn't fit in width")
        # One extra "break" bit is needed for every bit position at which
        # elaborate() inserts a break.  The same predicate is used here
        # and in elaborate() so that the expanded signals are exactly as
        # wide as the bits that get wired up, whatever partition_step is.
        breaks = sum(1 for i in range(width)
                     if self._break_key(i) is not None)
        self._expanded_width = width + breaks

    def _break_key(self, bit):
        """Return the partition-point key whose break precedes ``bit``.

        A break is placed before ``bit`` when ``bit`` is an exact multiple
        of ``partition_step`` and ``bit / partition_step`` is one of the
        partition points.  Returns ``None`` when there is no break.
        """
        key, remainder = divmod(bit, self.pmul)
        if remainder == 0 and key in self.partition_points:
            return key
        return None

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        expanded_a = Signal(self._expanded_width, reset_less=True)
        expanded_b = Signal(self._expanded_width, reset_less=True)
        expanded_o = Signal(self._expanded_width, reset_less=True)

        expanded_index = 0
        # store bits in a list, use Cat later.  graphviz is much cleaner
        al, bl, ol, ea, eb, eo = [], [], [], [], [], []

        # partition points are "breaks" (extra zeros or 1s) in what would
        # otherwise be a massive long add.  when the "break" points are 0,
        # whatever is in it (in the output) is discarded.  however when
        # there is a "1", it causes a roll-over carry to the *next* bit.
        # we still ignore the "break" bit in the [intermediate] output,
        # however by that time we've got the effect that we wanted: the
        # carry has been carried *over* the break point.

        for i in range(self.width):
            key = self._break_key(i)
            if key is not None:
                # add extra bit set to 0 + 0 for enabled partition points
                # and 1 + 0 for disabled partition points
                ea.append(expanded_a[expanded_index])
                al.append(~self.partition_points[key])  # extra bit in a
                eb.append(expanded_b[expanded_index])
                bl.append(C(0))  # yes, add a zero
                expanded_index += 1  # skip the extra point. NOT in the output
            ea.append(expanded_a[expanded_index])
            eb.append(expanded_b[expanded_index])
            eo.append(expanded_o[expanded_index])
            al.append(self.a[i])
            bl.append(self.b[i])
            ol.append(self.output[i])
            expanded_index += 1

        # combine above using Cat
        m.d.comb += Cat(*ea).eq(Cat(*al))
        m.d.comb += Cat(*eb).eq(Cat(*bl))
        m.d.comb += Cat(*ol).eq(Cat(*eo))

        # use only one addition to take advantage of look-ahead carry and
        # special hardware on FPGAs
        m.d.comb += expanded_o.eq(expanded_a + expanded_b)
        return m
201
202