Add rudimentary test for partitioned add with carry
[ieee754fpu.git] / src / ieee754 / part_mul_add / adder.py
1 # SPDX-License-Identifier: LGPL-2.1-or-later
2 # See Notices.txt for copyright information
3 """Partitioned Integer Addition.
4
5 See:
6 * https://libre-riscv.org/3d_gpu/architecture/dynamic_simd/add/
7 """
8
9 from nmigen import Signal, Module, Value, Elaboratable, Cat, C, Mux, Repl
10 from nmigen.hdl.ast import Assign
11 from abc import ABCMeta, abstractmethod
12 from nmigen.cli import main
13 from functools import reduce
14 from operator import or_
15 from ieee754.pipeline import PipelineSpec
16 from nmutil.pipemodbase import PipeModBase
17
18 from ieee754.part_mul_add.partpoints import PartitionPoints
19
20
class FullAdder(Elaboratable):
    """Bitwise 3-to-2 full adder.

    :attribute in0: the first input
    :attribute in1: the second input
    :attribute in2: the third input
    :attribute sum: the sum output
    :attribute carry: the carry output

    Instead of instantiating an array of single-bit full adders (which
    would be very slow to simulate), one instance handles ``width`` bits
    at once: every bit position is an independent Full 3-2 Add, all of
    them evaluated "in parallel".
    """

    def __init__(self, width):
        """Create a ``FullAdder``.

        :param width: the bit width of the input and output
        """
        # three addends
        self.in0 = Signal(width, reset_less=True)
        self.in1 = Signal(width, reset_less=True)
        self.in2 = Signal(width, reset_less=True)
        # two results
        self.sum = Signal(width, reset_less=True)
        self.carry = Signal(width, reset_less=True)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        x, y, z = self.in0, self.in1, self.in2
        m.d.comb += [
            # sum bit: parity of the three inputs
            self.sum.eq(x ^ y ^ z),
            # carry bit: set when at least two of the inputs are set.
            # note that it is NOT shifted up here: it sits at the same
            # bit position as the inputs that produced it.
            self.carry.eq((x & y) | (y & z) | (z & x)),
        ]
        return m
56
57
class MaskedFullAdder(Elaboratable):
    """Masked Full Adder.

    :attribute mask: the carry partition mask
    :attribute in0: the first input
    :attribute in1: the second input
    :attribute in2: the third input
    :attribute sum: the sum output
    :attribute mcarry: the masked carry output

    A full adder whose carry output is always consumed through a "mask".
    Doing the masking inside this class, rather than in a large for-loop
    at the point of use, keeps the graphviz output "clean".

    This class is deliberately not derived from FullAdder (see the
    discussion linked below): each input is shifted up by one bit
    *before* the pairwise AND with the mask, so that an AOI cell may be
    used (which is more gate-efficient)
    https://en.wikipedia.org/wiki/AND-OR-Invert
    https://groups.google.com/d/msg/comp.arch/fcq-GLQqvas/vTxmcA0QAgAJ
    """

    def __init__(self, width):
        """Create a ``MaskedFullAdder``.

        :param width: the bit width of the input and output
        """
        self.width = width
        # mask applied to the (already shifted) carry terms
        self.mask = Signal(width, reset_less=True)
        self.mcarry = Signal(width, reset_less=True)
        # the three addends and their bitwise sum
        self.in0 = Signal(width, reset_less=True)
        self.in1 = Signal(width, reset_less=True)
        self.in2 = Signal(width, reset_less=True)
        self.sum = Signal(width, reset_less=True)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()

        # s1..s3: the inputs shifted up by one bit position
        s1 = Signal(self.width, reset_less=True)
        s2 = Signal(self.width, reset_less=True)
        s3 = Signal(self.width, reset_less=True)
        # c1..c3: the three masked pairwise carry terms
        c1 = Signal(self.width, reset_less=True)
        c2 = Signal(self.width, reset_less=True)
        c3 = Signal(self.width, reset_less=True)

        # sum is the plain (unshifted, unmasked) parity of the inputs
        m.d.comb += self.sum.eq(self.in0 ^ self.in1 ^ self.in2)

        # prepending a zero moves every input bit up by one; the result
        # is one bit too wide for the target, so the top bit is dropped
        # by the assignment.
        shifts = ((s1, self.in0), (s2, self.in1), (s3, self.in2))
        m.d.comb += [shifted.eq(Cat(0, src)) for shifted, src in shifts]

        # majority function split into three AND terms, each one masked
        m.d.comb += [
            c1.eq(s1 & s2 & self.mask),
            c2.eq(s2 & s3 & self.mask),
            c3.eq(s3 & s1 & self.mask),
        ]
        m.d.comb += self.mcarry.eq(c1 | c2 | c3)
        return m
112
113
class PartitionedAdder(Elaboratable):
    """Partitioned Adder.

    Performs the final add. The partition points are included in the
    actual add as extra "break" bits, which is what lets a carry roll
    over (or not) into the next partition. The final output then
    *removes* the extra bits from the result.

    The diagram shows the case where every carry-in is zero, and leaves
    out the one extra bit added at each end of the expanded values:

    partition:      .... P... P... P... P... (32 bits)
    a        :      .... .... .... .... .... (32 bits)
    b        :      .... .... .... .... .... (32 bits)
    exp-a    : ....P....P....P....P.... (32+4 bits, P=1 if no partition)
    exp-b    : ....0....0....0....0.... (32 bits plus 4 zeros)
    exp-o    : ....xN...xN...xN...xN... (32+4 bits - x to be discarded)
    o        :      .... N... N... N... N... (32 bits - x ignored, N is carry-over)

    :attribute width: the bit width of the input and output. Read-only.
    :attribute a: the first input to the adder
    :attribute b: the second input to the adder
    :attribute carry_in: carry inputs. Bit 0 feeds the lowest data bit;
        each further bit belongs to one break point, in ascending order,
        and is only used when that partition point is set.
    :attribute carry_out: carry outputs, taken from the break bits of the
        expanded sum in ascending order, followed by the topmost bit of
        the expanded sum.
    :attribute output: the sum output
    :attribute partition_points: the input partition points. Modification not
        supported, except for by ``Signal.eq``.
    """

    def __init__(self, width, partition_points, partition_step=1):
        """Create a ``PartitionedAdder``.

        :param width: the bit width of the input and output
        :param partition_points: the input partition points
        :param partition_step: a multiplier (typically double) step
            which in-place "expands" the partition points
        :raises ValueError: if ``partition_points`` doesn't fit in ``width``
        """
        self.width = width
        self.pmul = partition_step
        self.partition_points = PartitionPoints(partition_points)
        self.a = Signal(width, reset_less=True)
        self.b = Signal(width, reset_less=True)
        self.carry_in = Signal(
            self.partition_points.get_max_partition_count(width))
        self.carry_out = Signal(
            self.partition_points.get_max_partition_count(width))
        self.output = Signal(width, reset_less=True)
        if not self.partition_points.fits_in_width(width):
            raise ValueError("partition_points doesn't fit in width")
        # one expanded bit per data bit, one per break bit that elaborate()
        # will actually insert, plus one extra bit at each end (carry-in
        # at the bottom, carry-out at the top). The break bits are counted
        # with the same test elaborate() uses, so the two always agree,
        # whatever partition_step is.
        break_count = sum(1 for i in range(width)
                          if self._break_key(i) is not None)
        self._expanded_width = width + break_count + 2

    def _break_key(self, i):
        """Return the partition_points key breaking before bit i, or None."""
        pi = i/self.pmul  # scale the bit index down by the partition step
        if pi.is_integer() and pi in self.partition_points:
            return pi
        return None

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        comb = m.d.comb
        expanded_a = Signal(self._expanded_width, reset_less=True)
        expanded_b = Signal(self._expanded_width, reset_less=True)
        expanded_o = Signal(self._expanded_width, reset_less=True)

        # store bits in lists, use Cat later. graphviz is much cleaner.
        # each pair is (destination bits, source bits), built in parallel.
        a_dst, a_src = [], []        # expanded_a   <- a / break bits
        b_dst, b_src = [], []        # expanded_b   <- b / break bits
        out_dst, out_src = [], []    # output       <- expanded_o data bits
        cout_dst, cout_src = [], []  # carry_out    <- expanded_o break bits

        # partition points are "breaks" (extra bits) in what would
        # otherwise be a massive long add. whatever ends up in the break
        # position of the sum is never copied to the output. what the
        # break bits of the *operands* do is control the carry into the
        # next data bit: see the loop below.

        # bit 0 is an extra bit below the data: adding carry_in[0] to
        # itself produces a carry into the first data bit exactly when
        # carry_in[0] is set.
        a_src.append(self.carry_in[0])
        b_src.append(self.carry_in[0])
        a_dst.append(expanded_a[0])
        b_dst.append(expanded_b[0])
        carry_bit = 1
        pos = 1  # next free bit position in the expanded signals

        for i in range(self.width):
            key = self._break_key(i)
            if key is not None:
                point = self.partition_points[key]
                cin = self.carry_in[carry_bit]
                # partition point clear: the break bits are 1 + 0, so a
                # carry arriving from below is passed on upwards.
                # partition point set: the break bits are cin + cin, so
                # the carry passed upwards is this partition's carry-in
                # and nothing from below gets through.
                a_bit = Signal()
                comb += a_bit.eq(~point | (point & cin))
                a_dst.append(expanded_a[pos])
                a_src.append(a_bit)
                b_dst.append(expanded_b[pos])
                b_src.append(cin & point)
                # the break bit of the sum goes to carry_out. with the
                # point set it equals the carry out of the partition
                # below; with the point clear it is the inverse of it.
                cout_src.append(expanded_o[pos])
                cout_dst.append(self.carry_out[carry_bit-1])
                pos += 1  # skip the extra point. NOT in the output
                carry_bit += 1
            # ordinary data bit
            a_dst.append(expanded_a[pos])
            b_dst.append(expanded_b[pos])
            out_src.append(expanded_o[pos])
            a_src.append(self.a[i])
            b_src.append(self.b[i])
            out_dst.append(self.output[i])
            pos += 1

        # explicit zeros above the last data bit. they have no matching
        # destination bit, so the Cat assignments below simply drop them.
        a_src.append(0)
        b_src.append(0)
        # topmost bit of the sum: carry out of the highest data bit
        cout_src.append(expanded_o[pos])
        cout_dst.append(self.carry_out[carry_bit-1])

        # combine above using Cat
        comb += Cat(*a_dst).eq(Cat(*a_src))
        comb += Cat(*b_dst).eq(Cat(*b_src))
        comb += Cat(*out_dst).eq(Cat(*out_src))
        comb += Cat(*cout_dst).eq(Cat(*cout_src))

        # use only one addition to take advantage of look-ahead carry and
        # special hardware on FPGAs
        comb += expanded_o.eq(expanded_a + expanded_b)

        return m
231
232