change all uses of dataclass to plain_data
[ieee754fpu.git] / src / ieee754 / part_swizzle / swizzle.py
1 # SPDX-License-Identifier: LGPL-2.1-or-later
2 # See Notices.txt for copyright information
3
4 from nmutil.plain_data import plain_data
5 from nmigen.hdl.ast import Cat, Const, Shape, Signal, SignalKey, Value, ValueKey
6 from nmigen.hdl.dsl import Module
7 from nmigen.hdl.ir import Elaboratable
8 from ieee754.part.partsig import SimdSignal
9
10
@plain_data(frozen=True, unsafe_hash=True)
class Bit:
    """abstract base class for a single bit that can be part of a `Swizzle`.

    Concrete subclasses must override `get_value`.
    """
    __slots__ = tuple()

    def get_value(self):
        """return this bit as a nmigen `Value`; subclasses must override."""
        raise NotImplementedError("called abstract method")
18
19
@plain_data(frozen=True, unsafe_hash=True)
class ValueBit(Bit):
    """a single bit taken from a nmigen `Value`.

    Attributes:
    src: ValueKey
        key wrapping the value the bit is taken from
    bit_index: int
        index of the bit within `src.value`
    """
    __slots__ = "src", "bit_index"

    def __init__(self, src, bit_index):
        # accept either a raw Value(-castable) or an existing ValueKey
        key = src if isinstance(src, ValueKey) else ValueKey(src)
        assert isinstance(bit_index, int)
        assert bit_index in range(len(key.value))
        self.src = key
        self.bit_index = bit_index

    def get_value(self):
        """get the value of this bit as a nmigen `Value`"""
        source_value = self.src.value
        return source_value[self.bit_index]

    def get_assign_target_sig(self):
        """get the Signal that assigning to this bit would assign to.

        Raises TypeError if the source value is not a plain `Signal`.
        """
        target = self.src.value
        if not isinstance(target, Signal):
            raise TypeError("not a valid assignment target")
        return target

    def assign(self, value, signals_map):
        """return a statement assigning `value` to this bit.

        The assignment goes to the replacement Signal looked up in
        `signals_map` by the `SignalKey` of this bit's target Signal,
        not to the target Signal itself.
        """
        target_sig = self.get_assign_target_sig()
        replacement = signals_map[SignalKey(target_sig)]
        return replacement[self.bit_index].eq(value)
45
46
@plain_data(frozen=True, unsafe_hash=True)
class ConstBit(Bit):
    """a single constant bit.

    Attributes:
    bit: bool
    """
    __slots__ = ("bit",)

    def __init__(self, bit):
        # store the truthiness of `bit` as a plain bool
        self.bit = True if bit else False

    def get_value(self):
        """return this bit as a 1-bit nmigen `Const`"""
        return Const(self.bit, 1)
56
57
@plain_data(frozen=True)
class Swizzle:
    """an ordered sequence of `Bit`s; index 0 is the least-significant bit
    (matching `from_const` and `Cat`'s argument order).

    The `bits` attribute itself can't be re-assigned after construction
    (frozen plain_data), but the list it refers to is mutated in place by
    `convert_u_to`, `convert_s_to` and `__iadd__`.

    Attributes:
    bits: list[Bit]
    """
    __slots__ = "bits",

    def __init__(self, bits=()):
        bits = list(bits)
        for bit in bits:
            assert isinstance(bit, Bit)
        self.bits = bits

    @staticmethod
    def from_const(value, width):
        """create a Swizzle of `width` ConstBits holding the low `width`
        bits of the integer `value`."""
        return Swizzle(ConstBit((value & (1 << i)) != 0) for i in range(width))

    @staticmethod
    def from_value(value):
        """create a Swizzle holding the bits of `value`.

        `Const`s are converted to ConstBits; anything else becomes
        ValueBits referencing `value`.
        """
        value = Value.cast(value)
        if isinstance(value, Const):
            return Swizzle.from_const(value.value, len(value))
        return Swizzle(ValueBit(value, i) for i in range(len(value)))

    def get_value(self):
        """concatenate all bits into a single nmigen `Value`"""
        return Cat(*(bit.get_value() for bit in self.bits))

    def get_sign(self):
        """return the most-significant bit, or a constant 0 bit if empty"""
        return self.bits[-1] if len(self.bits) != 0 else ConstBit(False)

    def convert_u_to(self, shape):
        """zero-extend or truncate `bits` in place to `shape`'s width"""
        shape = Shape.cast(shape)
        additional = shape.width - len(self.bits)
        # when truncating, `additional` is negative so the replacement
        # list is empty and the slice assignment just drops the extra bits
        self.bits[shape.width:] = [ConstBit(False)] * additional

    def convert_s_to(self, shape):
        """sign-extend (replicating the current last bit) or truncate `bits`
        in place to `shape`'s width"""
        shape = Shape.cast(shape)
        additional = shape.width - len(self.bits)
        self.bits[shape.width:] = [self.get_sign()] * additional

    def __getitem__(self, key):
        """index (int) or slice the bits, always returning a new Swizzle"""
        if isinstance(key, int):
            return Swizzle([self.bits[key]])
        assert isinstance(key, slice)
        return Swizzle(self.bits[key])

    def __add__(self, other):
        """concatenate: `self`'s bits followed by `other`'s bits"""
        if isinstance(other, Swizzle):
            return Swizzle(self.bits + other.bits)
        return NotImplemented

    def __radd__(self, other):
        if isinstance(other, Swizzle):
            return Swizzle(other.bits + self.bits)
        return NotImplemented

    def __iadd__(self, other):
        """append `other`'s bits to this Swizzle in place.

        `self.bits += other.bits` would re-assign the `bits` attribute,
        which the frozen plain_data class rejects after construction, so
        the list is extended in place instead.
        """
        assert isinstance(other, Swizzle)
        self.bits.extend(other.bits)
        return self

    def get_assign_target_sigs(self):
        """yield the target Signal of each bit; every bit must be a
        ValueBit."""
        for b in self.bits:
            assert isinstance(b, ValueBit)
            yield b.get_assign_target_sig()
124
125
@plain_data(frozen=True)
class SwizzleKey:
    """should be elwid or something similar.
    importantly, all SimdSignals that are used together must have equal
    SwizzleKeys.

    Attributes:
    value: ValueKey
        key wrapping the nmigen Value that selects between swizzles
    possible_values: FrozenSet[int]
        every value `value` may take; never empty
    """
    __slots__ = "value", "possible_values"

    @staticmethod
    def from_simd_signal(simd_signal):
        """get the SwizzleKey used by `simd_signal`.

        Only SwizzledSimdValue is supported so far; any other SimdSignal
        raises NotImplementedError.
        """
        if isinstance(simd_signal, SwizzledSimdValue):
            return simd_signal.swizzle_key

        # can't just be PartitionPoints, since those vary between
        # SimdSignals with different padding
        raise NotImplementedError("TODO: implement extracting a SwizzleKey "
                                  "from a SimdSignal")

    def __init__(self, value, possible_values):
        """
        Arguments:
        value: Value-castable or ValueKey
            an existing ValueKey is used as-is (like `ValueBit`'s `src`);
            anything else is wrapped in a new ValueKey
        possible_values: iterable of int or Const-castable
            ints must already be normalized for `value`'s shape; other
            items must cast to a `Const`, whose `.value` is used
        """
        if not isinstance(value, ValueKey):
            value = ValueKey(value)
        self.value = value
        shape = value.value.shape()
        pvalues = []
        for possible_value in possible_values:
            if isinstance(possible_value, int):
                assert possible_value == Const.normalize(possible_value,
                                                         shape)
            else:
                possible_value = Value.cast(possible_value)
                assert isinstance(possible_value, Const)
                possible_value = possible_value.value
            pvalues.append(possible_value)
        assert len(pvalues) != 0, "SwizzleKey can't have zero possible values"
        self.possible_values = frozenset(pvalues)
162
163
class ResolveSwizzle(Elaboratable):
    """drives a SwizzledSimdValue's output Signal from its swizzles.

    Emits a Switch on the swizzle key with one Case per possible key value;
    each Case combinatorially assigns that key's swizzle to the output.
    """

    def __init__(self, swizzled_simd_value):
        assert isinstance(swizzled_simd_value, SwizzledSimdValue)
        self.swizzled_simd_value = swizzled_simd_value

    def elaborate(self, platform):
        m = Module()
        swizzle_key = self.swizzled_simd_value.swizzle_key
        swizzles = self.swizzled_simd_value.swizzles
        # read the backing Signal directly (as AssignSwizzle does) rather
        # than through the lazy `sig` property, which can register a
        # ResolveSwizzle submodule if its one-time setup hasn't run yet
        output = self.swizzled_simd_value._sig_internal
        # SwizzleKey.value is a ValueKey, which m.Switch can't cast to a
        # Value; unwrap it to the underlying nmigen Value
        test = swizzle_key.value
        if isinstance(test, ValueKey):
            test = test.value
        with m.Switch(test):
            for k in sorted(swizzle_key.possible_values):
                swizzle = swizzles[k]
                with m.Case(k):
                    m.d.comb += output.eq(swizzle.get_value())
        return m
180
181
class AssignSwizzle(Elaboratable):
    """routes `src_sig` into the Signals underlying a SwizzledSimdValue.

    For every distinct assignment-target Signal, a same-shaped
    `outputs_<i>` Signal is created. It combinatorially defaults to the
    target's current value; the bits selected by the active swizzle are
    then overridden with the corresponding bits of `src_sig`.

    Attributes:
    swizzled_simd_value: SwizzledSimdValue
    src_sig: Signal
    converted_src_sig: Signal
        shaped like the SwizzledSimdValue's internal Signal, driven from
        `src_sig`
    outputs: dict[SignalKey, Signal]
        maps each target Signal's key to its new value
    """

    def __init__(self, swizzled_simd_value, src_sig):
        assert isinstance(swizzled_simd_value, SwizzledSimdValue)
        self.swizzled_simd_value = swizzled_simd_value
        assert isinstance(src_sig, Signal)
        self.src_sig = src_sig
        self.converted_src_sig = Signal.like(swizzled_simd_value._sig_internal)
        targets = swizzled_simd_value._get_assign_target_sigs()
        # dedup the targets, and sort them so output naming is deterministic
        targets = sorted({SignalKey(s) for s in targets})

        def make_sig(i, s):
            return Signal.like(s.signal, name=f"outputs_{i}")
        self.outputs = {s: make_sig(i, s) for i, s in enumerate(targets)}

    def elaborate(self, platform):
        m = Module()
        swizzle_key = self.swizzled_simd_value.swizzle_key
        swizzles = self.swizzled_simd_value.swizzles
        # default each output to its target's current value; bits the
        # active swizzle covers are overridden by the later assignments
        for k, v in self.outputs.items():
            m.d.comb += v.eq(k.signal)
        m.d.comb += self.converted_src_sig.eq(self.src_sig)
        # SwizzleKey.value is a ValueKey, which m.Switch can't cast to a
        # Value; unwrap it to the underlying nmigen Value
        test = swizzle_key.value
        if isinstance(test, ValueKey):
            test = test.value
        with m.Switch(test):
            for k in sorted(swizzle_key.possible_values):
                swizzle = swizzles[k]
                with m.Case(k):
                    for index, bit in enumerate(swizzle.bits):
                        rhs = self.converted_src_sig[index]
                        assert isinstance(bit, ValueBit)
                        m.d.comb += bit.assign(rhs, self.outputs)
        return m
212
213
class SwizzledSimdValue(SimdSignal):
    """the result of any number of Cat and Slice operations on
    Signals/SimdSignals. This is specifically intended to support assignment
    to Cat and Slice, but is also useful for reducing the number of muxes
    chained together down to a single layer of muxes.

    Attributes:
    swizzle_key: SwizzleKey
        selects which entry of `swizzles` is active
    swizzles: dict[int, Swizzle]
        one Swizzle per value in `swizzle_key.possible_values`; all Swizzles
        have the same number of bits
    """
    # class-wide counter used by __make_name to build unique submodule names
    __next_id = 0

    @staticmethod
    def from_simd_signal(simd_signal):
        """return `simd_signal` as a SwizzledSimdValue.

        A SwizzledSimdValue is returned unchanged. Any other SimdSignal is
        wrapped in a single Swizzle over the bits of `simd_signal.sig`,
        attached to `simd_signal.m`. NOTE(review): SwizzleKey
        extraction for plain SimdSignals is still a TODO in this file, so
        the non-trivial path currently raises NotImplementedError.
        """
        if isinstance(simd_signal, SwizzledSimdValue):
            return simd_signal
        assert isinstance(simd_signal, SimdSignal)
        swizzle_key = SwizzleKey.from_simd_signal(simd_signal)
        swizzle = Swizzle.from_value(simd_signal.sig)
        retval = SwizzledSimdValue(swizzle_key, swizzle)
        retval.set_module(simd_signal.m)
        return retval

    @staticmethod
    def __do_splat(swizzle_key, value):
        """splat a non-simd value, returning a SimdSignal (not implemented
        yet)"""
        raise NotImplementedError("TODO: need splat implementation")

    def __do_convert_rhs_to_simd_signal_like_self(self, rhs):
        """convert a value to be a SimdSignal of the same layout/shape as self,
        returning a SimdSignal (not implemented yet)."""
        raise NotImplementedError("TODO: need conversion implementation")

    @staticmethod
    def from_value(swizzle_key, value):
        """convert `value` to a SwizzledSimdValue using `swizzle_key`.

        Non-SimdSignal values are splatted first. Asserts that the
        result's swizzle key equals `swizzle_key`.
        """
        if not isinstance(value, SimdSignal):
            value = SwizzledSimdValue.__do_splat(swizzle_key, value)
        retval = SwizzledSimdValue.from_simd_signal(value)
        assert swizzle_key == retval.swizzle_key
        return retval

    @classmethod
    def __make_name(cls):
        """return a fresh submodule name `swizzle_<n>` and bump the counter"""
        id_ = cls.__next_id
        cls.__next_id = id_ + 1
        return f"swizzle_{id_}"

    def __init__(self, swizzle_key, swizzles):
        """
        Arguments:
        swizzle_key: SwizzleKey
        swizzles: Swizzle or mapping from int to Swizzle
            a single Swizzle is shared (same object) by every possible key;
            otherwise there must be an entry for each value in
            `swizzle_key.possible_values` (extra entries are ignored)
        """
        assert isinstance(swizzle_key, SwizzleKey)
        self.swizzle_key = swizzle_key
        possible_keys = swizzle_key.possible_values
        if isinstance(swizzles, Swizzle):
            self.swizzles = {k: swizzles for k in possible_keys}
        else:
            self.swizzles = {}
            for k in possible_keys:
                swizzle = swizzles[k]
                assert isinstance(swizzle, Swizzle)
                self.swizzles[k] = swizzle
        # all swizzles must agree on width; it becomes the output width
        width = None
        for swizzle in self.swizzles.values():
            if width is None:
                width = len(swizzle.bits)
            assert width == len(swizzle.bits), \
                "inconsistent swizzle widths"
        assert width is not None
        self.__sig_need_setup = False  # ignore accesses during __init__
        # NOTE(review): this passes a ValueKey as SimdSignal's first
        # argument -- confirm SimdSignal.__init__ accepts that (it may need
        # swizzle_key.value.value or a partition mask instead).
        super().__init__(swizzle_key.value, width, name="output")
        self.__sig_need_setup = True

    @property
    def sig(self):
        # override sig to handle lazily adding the ResolveSwizzle submodule:
        # the first read after __init__ registers a ResolveSwizzle (under a
        # unique name in self.m.submodules) that drives the output Signal
        if self.__sig_need_setup:
            self.__sig_need_setup = False
            submodule = ResolveSwizzle(self)
            setattr(self.m.submodules, self.__make_name(), submodule)
        return self._sig_internal

    @sig.setter
    def sig(self, value):
        # writes (e.g. during SimdSignal construction) just store the Signal
        assert isinstance(value, Signal)
        self._sig_internal = value

    def _get_assign_target_sigs(self):
        """yield the target Signal of every bit of every swizzle (may
        contain duplicates)"""
        for swizzle in self.swizzles.values():
            yield from swizzle.get_assign_target_sigs()

    def __Assign__(self, val, *, src_loc_at=0):
        """assignment hook: registers an AssignSwizzle submodule fed by
        `val` and returns statements copying each of its outputs back into
        the corresponding target Signal. Currently raises
        NotImplementedError because rhs conversion is not implemented."""
        rhs = self.__do_convert_rhs_to_simd_signal_like_self(val)
        assert isinstance(rhs, SimdSignal)
        submodule = AssignSwizzle(self, rhs.sig)
        setattr(self.m.submodules, self.__make_name(), submodule)
        return [k.signal.eq(v) for k, v in submodule.outputs.items()]

    def __Cat__(self, *args, src_loc_at=0):
        # not implemented yet
        raise NotImplementedError("TODO: implement")

    def __Slice__(self, start, stop, *, src_loc_at=0):
        # not implemented yet
        raise NotImplementedError("TODO: implement")
307 def __Slice__(self, start, stop, *, src_loc_at=0):
308 raise NotImplementedError("TODO: implement")