Fix carry output of adder/subtracter
[ieee754fpu.git] / src / ieee754 / part_mul_add / adder.py
1 # SPDX-License-Identifier: LGPL-2.1-or-later
2 # See Notices.txt for copyright information
3 """Partitioned Integer Addition.
4
5 See:
6 * https://libre-riscv.org/3d_gpu/architecture/dynamic_simd/add/
7 """
8
9 from nmigen import Signal, Module, Elaboratable, Cat
10
11 from ieee754.part_mul_add.partpoints import PartitionPoints
12 from ieee754.part_cmp.ripple import MoveMSBDown
13
14
class FullAdder(Elaboratable):
    """Bit-parallel full adder (3:2 compressor).

    Each bit position behaves as an independent one-bit full adder, so
    a single instance of ``width`` bits replaces ``width`` separate
    adders.  One wide module is used instead of an array of one-bit
    modules because the array would be very slow to simulate.

    :attribute in0: the first input
    :attribute in1: the second input
    :attribute in2: the third input
    :attribute sum: the per-bit sum output (XOR of the three inputs)
    :attribute carry: the per-bit carry output (set where at least two
        of the three inputs are set); it is *not* shifted here
    """

    def __init__(self, width):
        """Create a ``FullAdder``.

        :param width: the bit width of every input and output
        """
        # The name is given explicitly so that each signal is called
        # exactly what a plain ``self.<port> = Signal(...)`` would give.
        for port in ("in0", "in1", "in2", "sum", "carry"):
            setattr(self, port, Signal(width, name=port, reset_less=True))

    def elaborate(self, platform):
        """Build the combinatorial sum and carry logic."""
        m = Module()
        x, y, z = self.in0, self.in1, self.in2
        m.d.comb += [
            self.sum.eq(x ^ y ^ z),
            # majority function: carry wherever any two inputs are set
            self.carry.eq((x & y)
                          | (y & z)
                          | (z & x)),
        ]
        return m
50
51
class MaskedFullAdder(Elaboratable):
    """Bit-parallel full adder whose carry output is shifted and masked.

    :attribute mask: the carry partition mask
    :attribute in0: the first input
    :attribute in1: the second input
    :attribute in2: the third input
    :attribute sum: the sum output
    :attribute mcarry: the masked carry output

    A full adder is always followed by a mask on its carry, so the
    masking is done inside this module; that keeps the generated
    graphviz output tidy compared with masking in a big outer loop.

    This class deliberately does not derive from ``FullAdder``: each
    input is shifted up by one bit *before* the pairwise AND with the
    mask, which lets an AND-OR-Invert cell be used (fewer gates).  See:
    https://en.wikipedia.org/wiki/AND-OR-Invert
    https://groups.google.com/d/msg/comp.arch/fcq-GLQqvas/vTxmcA0QAgAJ
    """

    def __init__(self, width):
        """Create a ``MaskedFullAdder``.

        :param width: the bit width of the input and output
        """
        self.width = width
        # explicit names keep the signals named after their attributes
        for port in ("mask", "mcarry", "in0", "in1", "in2", "sum"):
            setattr(self, port, Signal(width, name=port, reset_less=True))

    def elaborate(self, platform):
        """Build the sum and the shifted, masked carry logic."""
        m = Module()
        width = self.width
        operands = (self.in0, self.in1, self.in2)

        # s1..s3: the inputs shifted up one bit (truncated to ``width``)
        shifted = [Signal(width, name="s%d" % n, reset_less=True)
                   for n in (1, 2, 3)]
        # c1..c3: the masked pairwise ANDs of the shifted inputs
        partial = [Signal(width, name="c%d" % n, reset_less=True)
                   for n in (1, 2, 3)]
        s1, s2, s3 = shifted
        c1, c2, c3 = partial

        stmts = [self.sum.eq(self.in0 ^ self.in1 ^ self.in2)]
        # a zero is inserted at the LSB, moving every input bit up one place
        stmts.extend(dst.eq(Cat(0, src))
                     for dst, src in zip(shifted, operands))
        # the mask is applied to every pairwise term, then the terms are ORed
        pairings = ((s1, s2), (s2, s3), (s3, s1))
        stmts.extend(dst.eq(lhs & rhs & self.mask)
                     for dst, (lhs, rhs) in zip(partial, pairings))
        stmts.append(self.mcarry.eq(c1 | c2 | c3))

        m.d.comb += stmts
        return m
106
107
class PartitionedAdder(Elaboratable):
    """Partitioned Adder.

    Performs the final add.  The partition points are included in the
    actual add (in one of the operands only), which causes a carry over
    to the next bit.  Then the final output *removes* the extra bits from
    the result.

    NOTE(review): the first diagram appears to describe the basic scheme
    without carry-in/carry-out; the second diagram adds them.

    partition: .... P... P... P... P... (32 bits)
    a : .... .... .... .... .... (32 bits)
    b : .... .... .... .... .... (32 bits)
    exp-a : ....P....P....P....P.... (32+4 bits, P=1 if no partition)
    exp-b : ....0....0....0....0.... (32 bits plus 4 zeros)
    exp-o : ....xN...xN...xN...xN... (32+4 bits - x to be discarded)
    o : .... N... N... N... N... (32 bits - x ignored, N is carry-over)

    partition: p p p p (4 bits)
    carry-in : c c c c (4 bits)
    C = c & P: C C C c (4 bits)
    I = P=>c : I I I I (4 bits)
    a : AAAA AAAA AAAA AAAA AAAA (32 bits)
    b : BBBB BBBB BBBB BBBB BBBB (32 bits)
    exp-a : 0AAAApAAAACAAAACAAAACAAAAc (32+4 bits, P=1 if no partition)
    exp-b : 0BBBB0BBBBIBBBBIBBBBIBBBBI (32 bits plus 4 zeros)
    exp-o : o....oN...oN...oN...oN...x (32+4 bits - x to be discarded)
    o : .... N... N... N... N... (32 bits - x ignored, N is carry-over)
    carry-out: o o o o (4 bits)

    NOTE(review): ``elaborate`` places the implication bit (I) in the
    expanded ``a`` operand and the masked carry bit (C) in the expanded
    ``b`` operand, i.e. the reverse of the exp-a/exp-b rows drawn above.
    The sum is the same either way because addition is commutative.

    :attribute width: the bit width of the input and output. Read-only.
    :attribute pmul: the ``partition_step`` given to the constructor
    :attribute a: the first input to the adder
    :attribute b: the second input to the adder
    :attribute carry_in: the carry inputs; bit 0 feeds the lowest bit of
        the add and each further bit is used at one partition point
    :attribute carry_out: the carry outputs, taken from the output of the
        ``MoveMSBDown`` submodule
    :attribute output: the sum output
    :attribute part_pts: the input partition points. Modification not
        supported, except for by ``Signal.eq``.
    """

    def __init__(self, width, part_pts, partition_step=1):
        """Create a ``PartitionedAdder``.

        :param width: the bit width of the input and output
        :param part_pts: the input partition points
        :param partition_step: a multiplier (typically double) step
                               which in-place "expands" the partition points
        :raises ValueError: if ``part_pts`` does not fit in ``width``
        """
        self.width = width
        self.pmul = partition_step
        self.part_pts = PartitionPoints(part_pts)
        self.a = Signal(width, reset_less=True)
        self.b = Signal(width, reset_less=True)
        self.carry_in = Signal(self.part_pts.get_max_partition_count(width))
        self.carry_out = Signal(self.part_pts.get_max_partition_count(width))
        self.output = Signal(width, reset_less=True)
        if not self.part_pts.fits_in_width(width):
            raise ValueError("partition_points doesn't fit in width")
        # expanded width: one bit per data bit, one extra bit per partition
        # point, plus two more (elaborate() uses one below the LSB for
        # carry_in[0] and reads the top one as the final carry).
        # NOTE(review): this count tests ``i in self.part_pts`` directly,
        # whereas elaborate() tests ``i / self.pmul``.  The two counts only
        # agree when every partition point multiplied by partition_step is
        # still below width - confirm for partition_step != 1.
        expanded_width = 2
        for i in range(self.width):
            if i in self.part_pts:
                expanded_width += 1
            expanded_width += 1
        self._expanded_width = expanded_width

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        comb = m.d.comb

        # raw per-partition carry bits, picked out of the expanded result
        carry_tmp = Signal(self.carry_out.width)
        # post-processes carry_tmp into carry_out, gated by the partition
        # points (behaviour defined in ieee754.part_cmp.ripple)
        m.submodules.ripple = ripple = MoveMSBDown(self.carry_out.width)

        expanded_a = Signal(self._expanded_width, reset_less=True)
        expanded_b = Signal(self._expanded_width, reset_less=True)
        expanded_o = Signal(self._expanded_width, reset_less=True)

        expanded_index = 0
        # store bits in a list, use Cat later.  graphviz is much cleaner
        # ea/eb: target bits of expanded_a/expanded_b, al/bl: their sources
        # ol: target bits of self.output,            eo: their sources
        # cl: target bits of carry_tmp,              co: their sources
        al, bl, ol, cl, ea, eb, eo, co = [], [], [], [], [], [], [], []

        # partition points are "breaks" (extra zeros or 1s) in what would
        # otherwise be a massive long add.  when the "break" points are 0,
        # whatever is in it (in the output) is discarded.  however when
        # there is a "1", it causes a roll-over carry to the *next* bit.
        # we still ignore the "break" bit in the [intermediate] output,
        # however by that time we've got the effect that we wanted: the
        # carry has been carried *over* the break point.

        # expanded bit 0: carry_in[0] is fed to *both* operands, so the bit
        # itself sums to 0 and carry_in[0] is carried into the next bit up.
        # bit 0 of expanded_o is never copied to the output or carry lists.
        carry_bit = 0
        al.append(self.carry_in[carry_bit])
        bl.append(self.carry_in[carry_bit])
        ea.append(expanded_a[expanded_index])
        eb.append(expanded_b[expanded_index])
        carry_bit += 1
        expanded_index += 1

        for i in range(self.width):
            pi = i/self.pmul # double the range of the partition point test
            # pi is a float: only whole values can match a partition point
            if pi.is_integer() and pi in self.part_pts:
                # add extra bit set to carry + carry for enabled
                # partition points
                a_bit = Signal(name="a_bit_%d" % i, reset_less=True)
                carry_in = self.carry_in[carry_bit] # convenience
                m.d.comb += a_bit.eq(self.part_pts[pi].implies(carry_in))

                # and 1 + 0 for disabled partition points
                ea.append(expanded_a[expanded_index])
                al.append(a_bit) # add extra bit in a
                eb.append(expanded_b[expanded_index])
                bl.append(carry_in & self.part_pts[pi]) # carry bit
                # the result bit at the break position is recorded as the
                # carry of the partition below it
                co.append(expanded_o[expanded_index])
                cl.append(carry_tmp[carry_bit-1])
                expanded_index += 1 # skip the extra point.  NOT in the output
                carry_bit += 1
            # ordinary data bit: copied into the expanded operands, and the
            # matching expanded result bit is routed to the output
            ea.append(expanded_a[expanded_index])
            eb.append(expanded_b[expanded_index])
            eo.append(expanded_o[expanded_index])
            al.append(self.a[i])
            bl.append(self.b[i])
            ol.append(self.output[i])
            expanded_index += 1
        # trailing zero source bits; ea/eb receive no matching entry here
        al.append(0)
        bl.append(0)
        # the top bit of the expanded result is the last partition's carry
        co.append(expanded_o[-1])
        cl.append(carry_tmp[carry_bit-1])

        # combine above using Cat
        comb += Cat(*ea).eq(Cat(*al))
        comb += Cat(*eb).eq(Cat(*bl))
        comb += Cat(*ol).eq(Cat(*eo))
        comb += Cat(*cl).eq(Cat(*co))

        # use only one addition to take advantage of look-ahead carry and
        # special hardware on FPGAs
        comb += expanded_o.eq(expanded_a + expanded_b)

        # carry_out is the MoveMSBDown result of the raw carries, with the
        # partition points as its gates
        comb += ripple.results_in.eq(carry_tmp)
        comb += ripple.gates.eq(self.part_pts.as_sig())
        comb += self.carry_out.eq(ripple.output)

        return m