48105fa7e0d005854afb6da4d7a215d3431ec57d
[nmigen-gf.git] / src / nmigen_gf / hdl / cldivrem.py
1 # SPDX-License-Identifier: LGPL-3-or-later
2 # Copyright 2022 Jacob Lifshay programmerjake@gmail.com
3
4 # Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
5 # of Horizon 2020 EU Programme 957073.
6
7 """ Carry-less Division and Remainder.
8
9 https://bugs.libre-soc.org/show_bug.cgi?id=784
10 """
11
12 from dataclasses import dataclass, field, fields
13 from nmigen.hdl.ast import Signal, Value
14 from nmigen.hdl.dsl import Module
15 from nmutil.singlepipe import ControlBase
16 from nmutil.clz import CLZ, clz
17
18
def cldivrem_shifting(n, d, width):
    """ Software model of carry-less division with remainder.

    Both operands are shifted left by the same amount before iterating and
    the remainder is shifted back right at the end, so every iteration only
    has to look at one bit of the remainder instead of comparing degrees.

    `n` is the numerator and `d` the (non-zero) denominator, both integers
    in `range(1 << width)`; `width` is the number of bits needed to hold
    each input/output.

    Returns the tuple `(q, r)`: quotient and remainder.
    """
    assert isinstance(width, int) and width >= 1
    limit = 1 << width
    assert isinstance(n, int) and 0 <= n < limit
    assert isinstance(d, int) and 0 <= d < limit
    assert d != 0, "TODO: decide what happens on division by zero"

    shape = CLDivRemShape(width)

    # Equivalent to `clz(d, width)` except that it tops out at `width - 1`
    # rather than `width`, so the value fits in `shape.shift_width` bits and
    # nothing is shifted further than necessary.
    shift = clz(d >> 1, width - 1)
    assert 0 <= shift < 1 << shape.shift_width, "shift overflow"
    divisor = d << shift
    assert 0 <= divisor < 1 << shape.d_width, "d overflow"
    rem = n << shift
    assert 0 <= rem < 1 << shape.r_width, "r overflow"

    # loop-invariant values
    top_bit = 2 * width - 1
    aligned_divisor = divisor << width
    q_limit = 1 << shape.q_width
    r_limit = 1 << shape.r_width

    quot = 0
    for _ in range(width):
        quot <<= 1
        rem <<= 1
        if rem >> top_bit:
            # top bit of the remainder is set: subtract (xor) the divisor
            # and record a 1 bit in the quotient.
            rem ^= aligned_divisor
            quot |= 1
        assert 0 <= quot < q_limit, "q overflow"
        assert 0 <= rem < r_limit, "r overflow"

    # drop the `width` working bits, then undo the initial alignment shift
    return quot, rem >> (width + shift)
55
56
@dataclass(frozen=True, unsafe_hash=True)
class CLDivRemShape:
    """ Bit-width parameters for carry-less div/rem.

    `width` is the bit-width of each input/output; every other size is
    derived from it.
    """
    width: int

    def __post_init__(self):
        width = self.width
        assert isinstance(width, int) and width >= 1, "invalid width"

    @property
    def done_step(self):
        """ Step number at which iteration is finished, i.e. the largest
        value `CLDivRemState.step` reaches.
        """
        return self.width

    @property
    def step_range(self):
        """ All values `CLDivRemState.step` can take, as a `range`
        covering `0` through `done_step` inclusive.
        """
        return range(1 + self.done_step)

    @property
    def d_width(self):
        """ Number of bits in the internal signal `CLDivRemState.d`. """
        return self.width

    @property
    def r_width(self):
        """ Number of bits in the internal signal `CLDivRemState.r`
        (twice `width`).
        """
        return 2 * self.width

    @property
    def q_width(self):
        """ Number of bits in the internal signal `CLDivRemState.q`. """
        return self.width

    @property
    def shift_width(self):
        """ Number of bits in the internal signal `CLDivRemState.shift`:
        just enough to represent `width - 1`.
        """
        max_shift = self.width - 1
        return max_shift.bit_length()
98
99
@dataclass(frozen=True, eq=False)
class CLDivRemState:
    """ Bundle of signals holding the working state of a carry-less
    div/rem computation: step counter, shifted divisor, remainder,
    quotient and the shift amount applied at the start.
    """
    shape: CLDivRemShape
    name: str
    step: Signal = field(init=False)
    d: Signal = field(init=False)
    r: Signal = field(init=False)
    q: Signal = field(init=False)
    shift: Signal = field(init=False)

    def __init__(self, shape, *, name=None, src_loc_at=0):
        assert isinstance(shape, CLDivRemShape)
        if name is None:
            # borrow Signal's own name inference; this call must stay
            # directly in `__init__` so `src_loc_at` refers to our caller.
            name = Signal(src_loc_at=1 + src_loc_at).name
        assert isinstance(name, str)
        members = {
            "step": Signal(shape.step_range, name=f"{name}_step"),
            "d": Signal(shape.d_width, name=f"{name}_d"),
            "r": Signal(shape.r_width, name=f"{name}_r"),
            "q": Signal(shape.q_width, name=f"{name}_q"),
            "shift": Signal(shape.shift_width, name=f"{name}_shift"),
        }
        # the dataclass is frozen, so go around its `__setattr__`
        object.__setattr__(self, "shape", shape)
        object.__setattr__(self, "name", name)
        for attr, sig in members.items():
            object.__setattr__(self, attr, sig)

    def eq(self, rhs):
        """ Yield one assignment per signal member, copying from `rhs`.
        `shape` and `name` are not signals and are skipped.
        """
        assert isinstance(rhs, CLDivRemState)
        non_signals = ("shape", "name")
        for fld in fields(CLDivRemState):
            if fld.name not in non_signals:
                yield getattr(self, fld.name).eq(getattr(rhs, fld.name))

    @staticmethod
    def like(other, *, name=None, src_loc_at=0):
        """ Create a fresh state with the same shape as `other`. """
        assert isinstance(other, CLDivRemState)
        return CLDivRemState(other.shape, name=name, src_loc_at=1 + src_loc_at)

    @property
    def done(self):
        """ Expression that is true once `step` has reached
        `shape.done_step`.
        """
        return self.will_be_done_after(steps=0)

    def will_be_done_after(self, steps):
        """ Expression that is true if this state will be done after
        another `steps` passes through `set_to_next`.
        """
        assert isinstance(steps, int) and steps >= 0
        threshold = max(0, self.shape.done_step - steps)
        return self.step >= threshold

    def get_output(self):
        """ Return `(q, r)` expressions: the quotient, and the remainder
        with the `width` working bits and the initial shift removed.
        """
        upper_half = self.r >> self.shape.width
        return self.q, upper_half >> self.shift

    def set_to_initial(self, m, n, d):
        """ Combinatorially drive this state with the starting values for
        numerator `n` and denominator `d`. Adds a `CLZ` submodule to `m`.
        """
        assert isinstance(m, Module)
        n = Value.cast(n)  # accept anything convertible to a Value
        d = Value.cast(d)
        clz_mod = CLZ(self.shape.width - 1)
        # left anonymous: a named submodule would clash if this method is
        # called more than once on the same Module
        m.submodules += clz_mod
        assert clz_mod.lz.width == self.shape.shift_width, \
            "internal inconsistency -- mismatched shift signal width"
        m.d.comb += clz_mod.sig_in.eq(d >> 1)
        m.d.comb += self.shift.eq(clz_mod.lz)
        m.d.comb += self.d.eq(d << self.shift)
        m.d.comb += self.r.eq(n << self.shift)
        m.d.comb += self.q.eq(0)
        m.d.comb += self.step.eq(0)

    def set_to_next(self, m, state_in):
        """ Combinatorially drive this state with the result of advancing
        `state_in` by one step; a `state_in` that is already done is
        passed through unchanged.
        """
        assert isinstance(m, Module)
        assert isinstance(state_in, CLDivRemState)
        assert state_in.shape == self.shape
        assert self is not state_in, "a.set_to_next(m, a) is not allowed"
        width = self.shape.width

        q_shifted = state_in.q << 1
        r_shifted = state_in.r << 1
        d_aligned = state_in.d << width

        with m.If(state_in.done):
            m.d.comb += self.eq(state_in)
        with m.Else():
            m.d.comb += [
                self.step.eq(state_in.step + 1),
                self.d.eq(state_in.d),
                self.shift.eq(state_in.shift),
            ]
            with m.If(r_shifted[2 * width - 1]):
                # top bit set: xor in the divisor, emit a 1 quotient bit
                m.d.comb += [
                    self.q.eq(q_shifted | 1),
                    self.r.eq(r_shifted ^ d_aligned),
                ]
            with m.Else():
                m.d.comb += [
                    self.q.eq(q_shifted),
                    self.r.eq(r_shifted),
                ]
201
202
class CLDivRemInputData:
    """ Input record: numerator `n` and denominator `d`, each
    `shape.width` bits wide.
    """

    def __init__(self, shape):
        assert isinstance(shape, CLDivRemShape)
        self.shape = shape
        # assigned directly to attributes so the signals pick up the
        # names `n` and `d`
        self.n = Signal(shape.width)
        self.d = Signal(shape.width)

    def __iter__(self):
        """ Iterate over the member signals: `n`, then `d`. """
        for attr in ("n", "d"):
            yield getattr(self, attr)

    def eq(self, rhs):
        """ Return the list of assignments copying `rhs.n` and `rhs.d`
        into this record.
        """
        return [
            getattr(self, attr).eq(getattr(rhs, attr))
            for attr in ("n", "d")
        ]
221
222
class CLDivRemOutputData:
    """ Output record: quotient `q` and remainder `r`, each
    `shape.width` bits wide.
    """

    def __init__(self, shape):
        assert isinstance(shape, CLDivRemShape)
        self.shape = shape
        # assigned directly to attributes so the signals pick up the
        # names `q` and `r`
        self.q = Signal(shape.width)
        self.r = Signal(shape.width)

    def __iter__(self):
        """ Iterate over the member signals: `q`, then `r`. """
        for attr in ("q", "r"):
            yield getattr(self, attr)

    def eq(self, rhs):
        """ Return the list of assignments copying `rhs.q` and `rhs.r`
        into this record.
        """
        return [
            getattr(self, attr).eq(getattr(rhs, attr))
            for attr in ("q", "r")
        ]
241
242
class CLDivRemFSMStage(ControlBase):
    """ Carry-less div/rem as a finite-state-machine pipeline stage.

    Attributes:
    shape: CLDivRemShape
        the shape
    steps_per_clock: int
        how many `set_to_next` steps are chained within one clock cycle
    pspec:
        pipe-spec
    empty: Signal()
        set while `self.saved_state` holds no operation in progress
    saved_state: CLDivRemState()
        registered state of the operation currently being worked on
    """

    def __init__(self, pspec, shape, *, steps_per_clock=8):
        assert isinstance(shape, CLDivRemShape)
        assert isinstance(steps_per_clock, int) and steps_per_clock >= 1
        self.shape = shape
        self.steps_per_clock = steps_per_clock
        # must be stored before the base-class constructor runs, since
        # ispec and ospec make use of it
        self.pspec = pspec
        super().__init__(stage=self)
        self.empty = Signal(reset=1)
        self.saved_state = CLDivRemState(shape)

    def ispec(self):
        """ Create the input data record. """
        return CLDivRemInputData(self.shape)

    def ospec(self):
        """ Create the output data record. """
        return CLDivRemOutputData(self.shape)

    def setup(self, m, i):
        """ Nothing to set up; all logic is built in `elaborate`. """
        pass

    def elaborate(self, platform):
        """ Build the stage: handshake signals, a combinatorial chain of
        `steps_per_clock` steps, and the registered state update.
        """
        m = super().elaborate(platform)
        i_data = self.p.i_data  # CLDivRemInputData
        o_data = self.n.o_data  # CLDivRemOutputData
        saved = self.saved_state

        # TODO: handle cancellation

        # output is valid once a stored operation has finished; new input
        # is accepted only while nothing is stored
        m.d.comb += self.n.o_valid.eq(~self.empty & saved.done)
        m.d.comb += self.p.o_ready.eq(self.empty)

        # next_chain[0] mirrors the saved state; each later entry is one
        # step further along
        next_chain = [
            CLDivRemState(self.shape, name=f"next_chain_{i}")
            for i in range(self.steps_per_clock + 1)
        ]
        for before, after in zip(next_chain, next_chain[1:]):
            after.set_to_next(m, before)
        m.d.comb += next_chain[0].eq(saved)

        out_q, out_r = saved.get_output()
        m.d.comb += o_data.q.eq(out_q)
        m.d.comb += o_data.r.eq(out_r)

        # keep this variable name: the state's signal names derive from it
        initial_state = CLDivRemState(self.shape)
        initial_state.set_to_initial(m, n=i_data.n, d=i_data.d)

        with m.If(self.empty):
            # idle: keep loading the initial state for the current inputs,
            # and start working once the inputs are valid
            m.d.sync += saved.eq(initial_state)
            with m.If(self.p.i_valid):
                m.d.sync += self.empty.eq(0)
        with m.Else():
            # busy: advance by `steps_per_clock` steps, and go back to
            # idle once the result has been accepted downstream
            m.d.sync += saved.eq(next_chain[-1])
            with m.If(self.n.i_ready & self.n.o_valid):
                m.d.sync += self.empty.eq(1)
        return m

    def __iter__(self):
        """ Iterate over everything in `self.p`, then `self.n`. """
        for port_group in (self.p, self.n):
            yield from port_group

    def ports(self):
        """ Return everything this stage iterates over, as a list. """
        return list(self)