# 44c1922f4893b03a1aa58f2225426191deadf037
# [ieee754fpu.git] / src / ieee754 / part_swizzle / swizzle.py
# SPDX-License-Identifier: LGPL-2.1-or-later
# See Notices.txt for copyright information
3
4 from dataclasses import dataclass
5 from functools import reduce
6 from typing import Dict, FrozenSet, List, Set, Tuple
7 from nmigen.hdl.ast import Cat, Const, Shape, Signal, SignalKey, Value, ValueKey
8 from nmigen.hdl.dsl import Module
9 from nmigen.hdl.ir import Elaboratable
10 from ieee754.part.partsig import SimdSignal
11
12
@dataclass(frozen=True, unsafe_hash=True)
class Bit:
    """Abstract base class for a single bit.

    It declares no fields; subclasses are expected to override
    `get_value`.
    """

    def get_value(self):
        """get the value of this bit as a nmigen `Value`"""
        message = "called abstract method"
        raise NotImplementedError(message)
18
19
@dataclass(frozen=True, unsafe_hash=True)
class ValueBit(Bit):
    """One bit, selected by index, of a nmigen value.

    The source value is held wrapped in a `ValueKey`.
    """
    src: ValueKey
    bit_index: int

    def __init__(self, src, bit_index):
        """Wrap `src` in a `ValueKey` (unless it already is one) and record
        which of its bits is referenced.

        `bit_index` must be an int in `0 <= bit_index < len(src)`.
        """
        key = src if isinstance(src, ValueKey) else ValueKey(src)
        assert isinstance(bit_index, int)
        assert 0 <= bit_index < len(key.value)
        # the dataclass is frozen, so fields are stored through
        # object.__setattr__ instead of plain assignment
        object.__setattr__(self, "src", key)
        object.__setattr__(self, "bit_index", bit_index)

    def get_value(self):
        """get the value of this bit as a nmigen `Value`"""
        source = self.src.value
        return source[self.bit_index]

    def get_assign_target_sig(self):
        """get the Signal that assigning to this bit would assign to

        Raises TypeError when the source value is not a `Signal`.
        """
        target = self.src.value
        if not isinstance(target, Signal):
            raise TypeError("not a valid assignment target")
        return target

    def assign(self, value, signals_map):
        """Build the assignment of `value` to this bit.

        The assignment is made to the entry of `signals_map` found under
        the `SignalKey` of this bit's target signal, not to the target
        signal itself.
        """
        lookup_key = SignalKey(self.get_assign_target_sig())
        replacement = signals_map[lookup_key]
        return replacement[self.bit_index].eq(value)
46
47
@dataclass(frozen=True, unsafe_hash=True)
class ConstBit(Bit):
    """A bit whose value is a fixed boolean."""
    bit: bool

    def get_value(self):
        """get the value of this bit as a 1-bit wide nmigen `Const`"""
        width = 1
        return Const(self.bit, width)
54
55
@dataclass(frozen=True)
class Swizzle:
    """An ordered sequence of `Bit`s.

    `get_value` concatenates the bits with `Cat` in list order, and
    `get_sign` treats the last entry as the sign bit.  The dataclass is
    frozen, so the `bits` attribute cannot be rebound, but the list it
    refers to is mutated in place by `convert_u_to`, `convert_s_to` and
    `+=`.
    """
    bits: List[Bit]

    def __init__(self, bits=()):
        """Copy `bits` (any iterable of `Bit`) into a fresh list."""
        bits = list(bits)
        for bit in bits:
            assert isinstance(bit, Bit)
        # frozen dataclass: attribute stores must bypass its __setattr__
        object.__setattr__(self, "bits", bits)

    @staticmethod
    def from_const(value, width):
        """Build a Swizzle of `width` `ConstBit`s, where entry `i` is set
        when bit `i` of the integer `value` is set."""
        return Swizzle(ConstBit((value & (1 << i)) != 0) for i in range(width))

    @staticmethod
    def from_value(value):
        """Build a Swizzle covering every bit of `value`.

        `value` is cast with `Value.cast`; a `Const` becomes `ConstBit`s,
        anything else becomes one `ValueBit` per bit position.
        """
        value = Value.cast(value)
        if isinstance(value, Const):
            return Swizzle.from_const(value.value, len(value))
        return Swizzle(ValueBit(value, i) for i in range(len(value)))

    def get_value(self):
        """Return the `Cat` of every bit's value, in list order."""
        return Cat(*(bit.get_value() for bit in self.bits))

    def get_sign(self):
        """Return the last bit, or `ConstBit(False)` when there are no
        bits."""
        return self.bits[-1] if len(self.bits) != 0 else ConstBit(False)

    def convert_u_to(self, shape):
        """Resize in place to `shape.width` bits, padding with zero bits.

        When the target is narrower, the bits past `shape.width` are
        dropped.
        """
        shape = Shape.cast(shape)
        additional = shape.width - len(self.bits)
        # a non-positive `additional` yields an empty list, so the slice
        # assignment truncates instead of extending
        self.bits[shape.width:] = [ConstBit(False)] * additional

    def convert_s_to(self, shape):
        """Resize in place to `shape.width` bits, padding with copies of
        the sign bit.

        When the target is narrower, the bits past `shape.width` are
        dropped.
        """
        shape = Shape.cast(shape)
        additional = shape.width - len(self.bits)
        # the sign bit is read before the list is modified
        self.bits[shape.width:] = [self.get_sign()] * additional

    def __getitem__(self, key):
        """Return a new Swizzle holding the bit(s) selected by an int or
        slice `key`."""
        if isinstance(key, int):
            return Swizzle([self.bits[key]])
        assert isinstance(key, slice)
        return Swizzle(self.bits[key])

    def __add__(self, other):
        """Return a new Swizzle with `other`'s bits after this one's."""
        if isinstance(other, Swizzle):
            return Swizzle(self.bits + other.bits)
        return NotImplemented

    def __radd__(self, other):
        """Return a new Swizzle with this one's bits after `other`'s."""
        if isinstance(other, Swizzle):
            return Swizzle(other.bits + self.bits)
        return NotImplemented

    def __iadd__(self, other):
        """Append `other`'s bits to this Swizzle in place and return
        `self`."""
        assert isinstance(other, Swizzle)
        # `self.bits += ...` would extend the list and then re-store the
        # attribute, which a frozen dataclass rejects with
        # FrozenInstanceError; extend the existing list directly instead.
        self.bits.extend(other.bits)
        return self

    def get_assign_target_sigs(self):
        """Yield, for each bit, the Signal that assigning to it would
        assign to.  Every bit must be a `ValueBit`."""
        for b in self.bits:
            assert isinstance(b, ValueBit)
            yield b.get_assign_target_sig()
118
119
@dataclass(frozen=True)
class SwizzleKey:
    """should be elwid or something similar.
    importantly, all SimdSignals that are used together must have equal
    SwizzleKeys."""
    value: ValueKey
    possible_values: FrozenSet[int]

    @staticmethod
    def from_simd_signal(simd_signal):
        """Return the SwizzleKey carried by `simd_signal`.

        Only `SwizzledSimdValue` inputs are supported so far; anything
        else raises NotImplementedError.
        """
        if not isinstance(simd_signal, SwizzledSimdValue):
            # can't just be PartitionPoints, since those vary between
            # SimdSignals with different padding
            raise NotImplementedError("TODO: implement extracting a SwizzleKey "
                                      "from a SimdSignal")
        return simd_signal.swizzle_key

    def __init__(self, value, possible_values):
        """Store `value` wrapped in a `ValueKey`, and the set of integer
        values it may take.

        Each entry of `possible_values` is either an int, which must
        already be normalized for `value`'s shape, or something that
        casts to a `Const`, whose integer value is used.  At least one
        entry is required.
        """
        # frozen dataclass: fields are stored through object.__setattr__
        object.__setattr__(self, "value", ValueKey(value))
        shape = self.value.value.shape()

        def as_int(candidate):
            # reduce one entry of possible_values to a plain int
            if isinstance(candidate, int):
                assert candidate == Const.normalize(candidate, shape)
                return candidate
            constant = Value.cast(candidate)
            assert isinstance(constant, Const)
            return constant.value

        pvalues = [as_int(candidate) for candidate in possible_values]
        assert len(pvalues) != 0, "SwizzleKey can't have zero possible values"
        object.__setattr__(self, "possible_values", frozenset(pvalues))
152
153
class ResolveSwizzle(Elaboratable):
    """Submodule that drives a `SwizzledSimdValue`'s signal from the
    swizzle selected by the current value of its swizzle key."""

    def __init__(self, swizzled_simd_value):
        """Remember the `SwizzledSimdValue` whose signal is to be driven."""
        assert isinstance(swizzled_simd_value, SwizzledSimdValue)
        self.swizzled_simd_value = swizzled_simd_value

    def elaborate(self, platform):
        """Build a Switch on the swizzle key with one Case per possible key
        value; each Case combinatorially assigns that key's swizzle to the
        output signal."""
        m = Module()
        swizzle_key = self.swizzled_simd_value.swizzle_key
        swizzles = self.swizzled_simd_value.swizzles
        output = self.swizzled_simd_value.sig
        # `swizzle_key.value` is the ValueKey wrapper, which is not itself
        # a nmigen Value; Switch needs the wrapped Value.
        with m.Switch(swizzle_key.value.value):
            # sorted so the Cases are emitted in a deterministic order
            for k in sorted(swizzle_key.possible_values):
                swizzle = swizzles[k]
                with m.Case(k):
                    m.d.comb += output.eq(swizzle.get_value())
        return m
170
171
class AssignSwizzle(Elaboratable):
    """Submodule implementing assignment *to* a `SwizzledSimdValue`.

    For every Signal that the swizzles can target, `outputs` holds a
    same-shaped replacement Signal, keyed by the target's `SignalKey`.
    """

    def __init__(self, swizzled_simd_value, src_sig):
        """Record the assignment target and source, and allocate the
        replacement Signals.

        `src_sig` must be a `Signal`; `converted_src_sig` is created with
        the layout of the target's internal signal.
        """
        assert isinstance(swizzled_simd_value, SwizzledSimdValue)
        self.swizzled_simd_value = swizzled_simd_value
        assert isinstance(src_sig, Signal)
        self.src_sig = src_sig
        # `_sig_internal` is read directly, bypassing the `sig` property
        self.converted_src_sig = Signal.like(swizzled_simd_value._sig_internal)
        targets = swizzled_simd_value._get_assign_target_sigs()
        # de-duplicate the targets and put them in a stable order so the
        # generated `outputs_{i}` names are deterministic
        targets = sorted({SignalKey(s) for s in targets})

        def make_sig(i, s):
            return Signal.like(s.signal, name=f"outputs_{i}")
        self.outputs = {s: make_sig(i, s) for i, s in enumerate(targets)}

    def elaborate(self, platform):
        """Build the assignment logic.

        Every replacement Signal first gets the value of the Signal it
        stands in for; the Switch on the swizzle key then assigns, for the
        matching key value, each bit of `converted_src_sig` to the bit that
        the swizzle maps it to.
        """
        m = Module()
        swizzle_key = self.swizzled_simd_value.swizzle_key
        swizzles = self.swizzled_simd_value.swizzles
        for k, v in self.outputs.items():
            m.d.comb += v.eq(k.signal)
        m.d.comb += self.converted_src_sig.eq(self.src_sig)
        # `swizzle_key.value` is the ValueKey wrapper, which is not itself
        # a nmigen Value; Switch needs the wrapped Value.
        with m.Switch(swizzle_key.value.value):
            for k in sorted(swizzle_key.possible_values):
                swizzle = swizzles[k]
                with m.Case(k):
                    for index, bit in enumerate(swizzle.bits):
                        rhs = self.converted_src_sig[index]
                        assert isinstance(bit, ValueBit)
                        m.d.comb += bit.assign(rhs, self.outputs)
        return m
202
203
class SwizzledSimdValue(SimdSignal):
    """the result of any number of Cat and Slice operations on
    Signals/SimdSignals. This is specifically intended to support assignment
    to Cat and Slice, but is also useful for reducing the number of muxes
    chained together down to a single layer of muxes."""
    # counter used to give every generated submodule a distinct name
    __next_id = 0

    @staticmethod
    def from_simd_signal(simd_signal):
        """Return `simd_signal` as a SwizzledSimdValue.

        An existing SwizzledSimdValue is returned unchanged; any other
        SimdSignal is wrapped in a new one attached to the same module.
        """
        if isinstance(simd_signal, SwizzledSimdValue):
            return simd_signal
        assert isinstance(simd_signal, SimdSignal)
        key = SwizzleKey.from_simd_signal(simd_signal)
        whole_signal = Swizzle.from_value(simd_signal.sig)
        wrapped = SwizzledSimdValue(key, whole_signal)
        wrapped.set_module(simd_signal.m)
        return wrapped

    @staticmethod
    def __do_splat(swizzle_key, value):
        """splat a non-simd value, returning a SimdSignal"""
        raise NotImplementedError("TODO: need splat implementation")

    def __do_convert_rhs_to_simd_signal_like_self(self, rhs):
        """convert a value to be a SimdSignal of the same layout/shape as self,
        returning a SimdSignal."""
        raise NotImplementedError("TODO: need conversion implementation")

    @staticmethod
    def from_value(swizzle_key, value):
        """Convert `value` to a SwizzledSimdValue whose key must equal
        `swizzle_key`; values that are not SimdSignals are splatted first."""
        simd = value
        if not isinstance(simd, SimdSignal):
            simd = SwizzledSimdValue.__do_splat(swizzle_key, simd)
        result = SwizzledSimdValue.from_simd_signal(simd)
        assert swizzle_key == result.swizzle_key
        return result

    @classmethod
    def __make_name(cls):
        """Return the next unused `swizzle_<n>` submodule name."""
        current = cls.__next_id
        cls.__next_id = current + 1
        return f"swizzle_{current}"

    def __init__(self, swizzle_key, swizzles):
        """Create the value from a key and its swizzle(s).

        `swizzles` is either a single `Swizzle`, used for every possible
        key value, or a mapping indexed by each possible key value.  All
        swizzles must have the same number of bits.
        """
        assert isinstance(swizzle_key, SwizzleKey)
        self.swizzle_key = swizzle_key
        possible_keys = swizzle_key.possible_values
        if isinstance(swizzles, Swizzle):
            # the one Swizzle object is shared by every key
            self.swizzles = dict.fromkeys(possible_keys, swizzles)
        else:
            table = self.swizzles = {}
            for key_value in possible_keys:
                entry = swizzles[key_value]
                assert isinstance(entry, Swizzle)
                table[key_value] = entry
        lengths = (len(entry.bits) for entry in self.swizzles.values())
        width = next(lengths, None)
        for length in lengths:
            assert width == length, \
                "inconsistent swizzle widths"
        assert width is not None
        self.__sig_need_setup = False  # ignore accesses during __init__
        # NOTE(review): swizzle_key.value is a ValueKey wrapper rather than
        # the Value itself -- confirm SimdSignal accepts it as first argument
        super().__init__(swizzle_key.value, width, name="output")
        self.__sig_need_setup = True

    @property
    def sig(self):
        """The underlying Signal.  The first access after construction
        also registers a ResolveSwizzle submodule on `self.m`."""
        # override sig to handle lazily adding the ResolveSwizzle submodule
        if self.__sig_need_setup:
            # clear the flag first so the submodule is only added once
            self.__sig_need_setup = False
            resolver = ResolveSwizzle(self)
            submodules = self.m.submodules
            setattr(submodules, self.__make_name(), resolver)
        return self._sig_internal

    @sig.setter
    def sig(self, value):
        assert isinstance(value, Signal)
        self._sig_internal = value

    def _get_assign_target_sigs(self):
        """Yield the assignment-target Signals of every swizzle."""
        for entry in self.swizzles.values():
            for target in entry.get_assign_target_sigs():
                yield target

    def __Assign__(self, val, *, src_loc_at=0):
        """Assign `val` to this value.

        Registers an AssignSwizzle submodule on `self.m` and returns the
        list of assignments from its replacement Signals to the real
        target Signals.
        """
        rhs = self.__do_convert_rhs_to_simd_signal_like_self(val)
        assert isinstance(rhs, SimdSignal)
        assigner = AssignSwizzle(self, rhs.sig)
        submodules = self.m.submodules
        setattr(submodules, self.__make_name(), assigner)
        statements = []
        for target_key, replacement in assigner.outputs.items():
            statements.append(target_key.signal.eq(replacement))
        return statements

    def __Cat__(self, *args, src_loc_at=0):
        raise NotImplementedError("TODO: implement")

    def __Slice__(self, start, stop, *, src_loc_at=0):
        raise NotImplementedError("TODO: implement")