eacceb410d8726d49d32f9be1192b8446c46a707
[bigint-presentation-code.git] / src / bigint_presentation_code / compiler_ir2.py
1 import enum
2 from enum import Enum, unique
3 from typing import AbstractSet, Iterable, Iterator, NoReturn, Tuple, Union, overload
4
5 from cached_property import cached_property
6 from nmutil.plain_data import plain_data
7
8 from bigint_presentation_code.util import OFSet, OSet, Self, assert_never, final
9 from weakref import WeakValueDictionary
10
11
@final
class Fn:
    """Holds a list of `Op`s together with a registry of their unique names."""

    def __init__(self):
        self.ops = []  # type: list[Op]
        # values are held weakly, so the registry itself never keeps an
        # `Op` alive
        self.__op_names = WeakValueDictionary()  # type: WeakValueDictionary[str, Op]
        # numeric suffix tried next when a requested name is already in use;
        # the counter is shared by all names in this `Fn`
        self.__next_name_suffix = 2

    def _add_op_with_unused_name(self, op, name=""):
        # type: (Op, str) -> str
        """Register `op` under a name not yet used in this `Fn`.

        `name` is used as-is when it is free; otherwise the shared numeric
        suffix is appended (and incremented) until a free name is found.

        Returns the name that was registered. Raises `ValueError` when `op`
        belongs to a different `Fn` or already has a `name` attribute.
        """
        if op.fn is not self:
            raise ValueError("can't add Op to wrong Fn")
        if hasattr(op, "name"):
            raise ValueError("Op already named")
        candidate = name
        while candidate in self.__op_names:
            candidate = name + str(self.__next_name_suffix)
            self.__next_name_suffix += 1
        self.__op_names[candidate] = op
        return candidate

    def __repr__(self):
        return "<Fn>"
36
37
@unique
@final
class RegKind(Enum):
    """The kinds of register distinguished by this IR."""

    GPR = enum.auto()
    CA = enum.auto()
    VL_MAXVL = enum.auto()

    @cached_property
    def only_scalar(self):
        """`True` for kinds that may only be used with `vec=False`.

        `OperandType` rejects `vec=True` for any kind where this is `True`.
        """
        if self is RegKind.GPR:
            return False
        if self is RegKind.CA or self is RegKind.VL_MAXVL:
            return True
        # unreachable unless a new member is added without updating this
        assert_never(self)

    @cached_property
    def reg_count(self):
        """Number of registers of this kind; upper bound for `RegShape.length`."""
        if self is RegKind.GPR:
            return 128
        if self is RegKind.CA or self is RegKind.VL_MAXVL:
            return 1
        # unreachable unless a new member is added without updating this
        assert_never(self)

    def __repr__(self):
        return f"RegKind.{self._name_}"
65
66
@plain_data(frozen=True, unsafe_hash=True)
@final
class OperandType:
    """Type of an operand: a register kind plus a scalar/vector flag."""

    __slots__ = ("kind", "vec")

    def __init__(self, kind, vec):
        # type: (RegKind, bool) -> None
        """Raises `ValueError` when `vec` is set for a scalar-only `kind`."""
        self.kind = kind
        if kind.only_scalar and vec:
            raise ValueError(f"kind={kind} must have vec=False")
        self.vec = vec

    def get_length(self, maxvl):
        # type: (int) -> int
        """Return `maxvl` for a vector operand and 1 for a scalar one."""
        # here's where subvl and elwid would be accounted for
        return maxvl if self.vec else 1
85
86
@plain_data(frozen=True, unsafe_hash=True)
@final
class RegShape:
    """A register kind together with a count of registers (`length`)."""

    __slots__ = ("kind", "length")

    def __init__(self, kind, length=1):
        # type: (RegKind, int) -> None
        """Raises `ValueError` unless `1 <= length <= kind.reg_count`."""
        self.kind = kind
        if length < 1 or kind.reg_count < length:
            raise ValueError("invalid length")
        self.length = length

    def try_concat(self, *others):
        # type: (*RegShape | Reg | RegClass | None) -> RegShape | None
        """Return the shape of `self` followed by all of `others`.

        A `Reg` or `RegClass` argument contributes its `.shape`. Returns
        `None` when any contributed shape is `None`, has a different kind
        than `self`, or when the summed length exceeds `kind.reg_count`.
        """
        kind = self.kind
        total = self.length
        for item in others:
            item_shape = item.shape if isinstance(item, (Reg, RegClass)) else item
            if item_shape is None or item_shape.kind != kind:
                return None
            total += item_shape.length
        if total > kind.reg_count:
            return None
        return RegShape(kind=kind, length=total)
114
115
@plain_data(frozen=True, unsafe_hash=True)
@final
class Reg:
    """A run of `shape.length` registers of `shape.kind` beginning at `start`."""

    __slots__ = ("shape", "start")

    def __init__(self, shape, start):
        # type: (RegShape, int) -> None
        """Raises `ValueError` unless the whole run fits in `0..reg_count`."""
        self.shape = shape
        if start < 0 or start + shape.length > shape.kind.reg_count:
            raise ValueError("start not in valid range")
        self.start = start

    @property
    def kind(self):
        return self.shape.kind

    @property
    def length(self):
        return self.shape.length

    @property
    def stop(self):
        """Index one past the last register of the run."""
        return self.start + self.length

    def conflicts(self, other):
        # type: (Reg) -> bool
        """Return `True` when both are the same kind and their runs overlap."""
        if self.kind != other.kind:
            return False
        return self.start < other.stop and other.start < self.stop

    def try_concat(self, *others):
        # type: (*Reg | None) -> Reg | None
        """Return the `Reg` covering `self` immediately followed by `others`.

        Returns `None` when the shapes cannot be concatenated or when any
        of `others` does not start exactly where the previous run stops.
        """
        shape = self.shape.try_concat(*others)
        if shape is None:
            return None
        expected_start = self.stop
        for other in others:
            assert other is not None, "already caught by RegShape.try_concat"
            if other.start != expected_start:
                return None
            expected_start = other.stop
        return Reg(shape, self.start)
157
158
@final
class RegClass(AbstractSet[Reg]):
    """Set of `Reg`s that all share one `RegShape`.

    Membership is stored as an integer bitset: bit `i` of `starts_bitset`
    is set exactly when the `Reg` starting at index `i` is a member. An
    empty `RegClass` always has `shape is None`.
    """

    def __init__(self, regs_or_starts=(), shape=None, starts_bitset=0):
        # type: (Iterable[Reg | int], RegShape | None, int) -> None
        """Build the set from `Reg`s and/or bare start indexes.

        `shape` is taken from the first `Reg` when not given; bare start
        indexes rely on `shape` being supplied or inferred. `starts_bitset`
        seeds the membership bitset before `regs_or_starts` is merged in.

        Raises `ValueError` for a negative bitset or start, for `Reg`s of
        differing shapes, for a non-empty set with no shape, and for any
        member that would extend past `shape.kind.reg_count`.
        """
        # a negative int has infinitely many set bits and would slip past
        # the upper-bound check below
        if starts_bitset < 0:
            raise ValueError("starts_bitset must not be negative")
        for reg_or_start in regs_or_starts:
            if isinstance(reg_or_start, Reg):
                if shape is None:
                    shape = reg_or_start.shape
                elif shape != reg_or_start.shape:
                    raise ValueError(f"conflicting RegShapes: {shape} and "
                                     f"{reg_or_start.shape}")
                start = reg_or_start.start
            else:
                start = reg_or_start
            if start < 0:
                raise ValueError("a Reg's start is out of range")
            starts_bitset |= 1 << start
        if starts_bitset == 0:
            # canonical form for the empty set
            shape = None
        self.__shape = shape
        self.__starts_bitset = starts_bitset
        if shape is None:
            if starts_bitset != 0:
                raise ValueError("non-empty RegClass must have non-None shape")
            return
        # a stop index equal to reg_count is still valid (`Reg.__init__`
        # accepts start + length == reg_count), so only bits strictly above
        # bit `reg_count` of the stops bitset are out of range
        if self.stops_bitset >> (shape.kind.reg_count + 1):
            raise ValueError("a Reg's start is out of range")

    @property
    def shape(self):
        # type: () -> RegShape | None
        return self.__shape

    @property
    def starts_bitset(self):
        # type: () -> int
        return self.__starts_bitset

    @property
    def stops_bitset(self):
        # type: () -> int
        """Bitset with bit `start + length` set for every member."""
        if self.__shape is None:
            return 0
        return self.__starts_bitset << self.__shape.length

    @cached_property
    def starts(self):
        # type: () -> OFSet[int]
        """Start indexes of all members, in ascending order."""
        bits = self.__starts_bitset
        if bits == 0:
            return OFSet()
        return OFSet(i for i in range(bits.bit_length()) if (bits >> i) & 1)

    @cached_property
    def stops(self):
        # type: () -> OFSet[int]
        """Stop indexes (`start + length`) of all members."""
        if self.__shape is None:
            return OFSet()
        length = self.__shape.length
        return OFSet(start + length for start in self.starts)

    @property
    def kind(self):
        if self.__shape is None:
            return None
        return self.__shape.kind

    @property
    def length(self):
        """length of registers in this RegClass, not to be confused with the number of `Reg`s in self"""
        if self.__shape is None:
            return None
        return self.__shape.length

    def concat(self, *others):
        # type: (*RegClass) -> RegClass
        """Return the class of `Reg`s formed by a member of `self` directly
        followed by a member of each of `others`, in order.

        Returns an empty `RegClass` when `self` is empty or the shapes
        cannot be concatenated.
        """
        self_shape = self.__shape
        if self_shape is None:
            return RegClass()
        shape = self_shape.try_concat(*others)
        if shape is None:
            return RegClass()
        starts = OSet(self.starts)
        # offset of the next piece relative to the start of the combined
        # Reg; the first of `others` begins right after self's own length
        offset = self_shape.length
        for other in others:
            assert other.__shape is not None, \
                "already caught by RegShape.try_concat"
            starts &= OSet(i - offset for i in other.starts)
            offset += other.__shape.length
        return RegClass(starts, shape=shape)

    def __contains__(self, reg):
        # type: (Reg) -> bool
        # non-Reg objects are simply not members
        if not isinstance(reg, Reg):
            return False
        return reg.shape == self.shape and reg.start in self.starts

    def __iter__(self):
        # type: () -> Iterator[Reg]
        if self.shape is None:
            return
        for start in self.starts:
            yield Reg(shape=self.shape, start=start)

    def __len__(self):
        return len(self.starts)

    def __hash__(self):
        # the Set ABC supplies `_hash` but leaves `__hash__` undefined
        return super()._hash()
265
266
@plain_data(frozen=True, unsafe_hash=True)
@final
class Operand:
    """An operand: its `OperandType` plus an optional set of integers `regs`.

    NOTE(review): the meaning of `regs` is not established in this file;
    presumably the register starts the operand may use, with `None`
    meaning unconstrained -- confirm against the callers.
    """

    __slots__ = "ty", "regs"

    def __init__(self, ty, regs=None):
        # type: (OperandType, OFSet[int] | None) -> None
        # every declared slot must be assigned, otherwise reading the
        # fields (and anything generated from them) raises AttributeError
        self.ty = ty
        self.regs = regs
275
276
# shorthand `OperandType` instances
OT_VGPR = OperandType(kind=RegKind.GPR, vec=True)  # GPR, vector
OT_SGPR = OperandType(kind=RegKind.GPR, vec=False)  # GPR, scalar
OT_CA = OperandType(kind=RegKind.CA, vec=False)  # CA, always scalar
OT_VL = OperandType(kind=RegKind.VL_MAXVL, vec=False)  # VL_MAXVL, always scalar
281
282
@plain_data(frozen=True, unsafe_hash=True)
class TiedOutput:
    """Pairs an input index with an output index.

    NOTE(review): presumably a constraint tying that output to the same
    register(s) as that input -- confirm where constraints are consumed.
    """

    __slots__ = ("input_index", "output_index")

    def __init__(self, input_index, output_index):
        # type: (int, int) -> None
        self.input_index, self.output_index = input_index, output_index
291
292
# type alias for the entries of `OpProperties.constraints`; `TiedOutput` is
# the only concrete member of the union
Constraint = Union[TiedOutput, NoReturn]
294
295
@plain_data(frozen=True, unsafe_hash=True)
@final
class OpProperties:
    """Static description of one kind of op.

    The iterable arguments are stored as tuples.
    """

    __slots__ = (
        "demo_asm",
        "inputs",
        "outputs",
        "immediates",
        "constraints",
        "is_copy",
        "is_load_immediate",
        "has_side_effects",
    )

    def __init__(self, demo_asm,  # type: str
                 inputs,  # type: Iterable[OperandType]
                 outputs,  # type: Iterable[OperandType]
                 immediates,  # type: Iterable[range]
                 constraints,  # type: Iterable[Constraint]
                 is_copy=False,  # type: bool
                 is_load_immediate=False,  # type: bool
                 has_side_effects=False,  # type: bool
                 ):
        # type: (...) -> None
        self.demo_asm = demo_asm

        # operand descriptions, frozen into tuples
        self.inputs = tuple(inputs)
        self.outputs = tuple(outputs)
        self.immediates = tuple(immediates)
        self.constraints = tuple(constraints)

        # boolean flags
        self.is_copy = is_copy
        self.is_load_immediate = is_load_immediate
        self.has_side_effects = has_side_effects
320
321
@unique
@final
class OpKind(Enum):
    """Kinds of `Op`; each member's value is the `OpProperties` describing it.

    The same `OpProperties` is also available as `member.properties`.
    """

    def __init__(self, properties):
        # type: (OpProperties) -> None
        super().__init__()
        # Enum passes the member's value to __init__
        self.properties = properties

    # inputs: two GPR vectors, CA, VL; outputs: GPR vector, CA
    SvAddE = OpProperties(
        demo_asm="sv.adde *RT, *RA, *RB",
        immediates=(),
        constraints=(),
        inputs=(OT_VGPR, OT_VGPR, OT_CA, OT_VL),
        outputs=(OT_VGPR, OT_CA),
    )
    # same operand layout as SvAddE
    SvSubFE = OpProperties(
        demo_asm="sv.subfe *RT, *RA, *RB",
        immediates=(),
        constraints=(),
        inputs=(OT_VGPR, OT_VGPR, OT_CA, OT_VL),
        outputs=(OT_VGPR, OT_CA),
    )
    # inputs: GPR vector, two scalar GPRs, VL; outputs: GPR vector, scalar GPR
    SvMAddEDU = OpProperties(
        demo_asm="sv.maddedu *RT, *RA, RB, RC",
        immediates=(),
        constraints=(),
        inputs=(OT_VGPR, OT_SGPR, OT_SGPR, OT_VL),
        outputs=(OT_VGPR, OT_SGPR),
    )
    # no inputs; one immediate in 1..64; output: VL
    SetVLI = OpProperties(
        demo_asm="setvl 0, 0, imm, 0, 1, 1",
        is_load_immediate=True,
        immediates=(range(1, 65),),
        constraints=(),
        inputs=(),
        outputs=(OT_VL,),
    )
    # input: VL; one signed 16-bit immediate; output: GPR vector
    SvLI = OpProperties(
        demo_asm="sv.addi *RT, 0, imm",
        is_load_immediate=True,
        immediates=(range(-(1 << 15), 1 << 15),),
        constraints=(),
        inputs=(OT_VL,),
        outputs=(OT_VGPR,),
    )
    # no inputs; one signed 16-bit immediate; output: scalar GPR
    LI = OpProperties(
        demo_asm="addi RT, 0, imm",
        is_load_immediate=True,
        immediates=(range(-(1 << 15), 1 << 15),),
        constraints=(),
        inputs=(),
        outputs=(OT_SGPR,),
    )
    # inputs: GPR vector, VL; output: GPR vector
    SvMv = OpProperties(
        demo_asm="sv.or *RT, *src, *src",
        is_copy=True,
        immediates=(),
        constraints=(),
        inputs=(OT_VGPR, OT_VL),
        outputs=(OT_VGPR,),
    )
    # input: scalar GPR; output: scalar GPR
    Mv = OpProperties(
        demo_asm="mv RT, src",
        is_copy=True,
        immediates=(),
        constraints=(),
        inputs=(OT_SGPR,),
        outputs=(OT_SGPR,),
    )
391
392
@plain_data(frozen=True, unsafe_hash=True, repr=False)
@final
class SSAVal:
    """A value made by concatenating slices of `Op` outputs.

    `sliced_op_outputs` is a tuple of `(op, output_index, range_)` entries:
    the elements `range_` (always step 1, never empty) of output number
    `output_index` of `op`. Adjacent entries that continue the same output
    are merged into one entry.
    """

    __slots__ = "sliced_op_outputs",

    _SlicedOpOutputIn = Union["tuple[Op, int, int | range | slice]",
                              "tuple[Op, int]", "SSAVal"]

    @staticmethod
    def __process_sliced_op_outputs(inp):
        # type: (Iterable[_SlicedOpOutputIn]) -> Iterable[Tuple["Op", int, range]]
        """Normalize each input item to `(op, output_index, range)`.

        An `SSAVal` item contributes its entries unchanged. A 2-tuple
        selects the whole output. In a 3-tuple, an `int` selects a single
        element and a `range` or `slice` selects a sub-range. Empty
        selections are dropped.

        Raises `ValueError` for an invalid `output_index` or a step other
        than 1, and `IndexError` for an `int` element index out of range.
        """
        for v in inp:
            if isinstance(v, SSAVal):
                yield from v.sliced_op_outputs
                continue
            op = v[0]
            output_index = v[1]
            if output_index < 0 or output_index >= len(op.properties.outputs):
                raise ValueError("invalid output_index")
            cur_len = op.properties.outputs[output_index].get_length(op.maxvl)
            slice_ = slice(None) if len(v) == 2 else v[2]
            if isinstance(slice_, range):
                slice_ = slice(slice_.start, slice_.stop, slice_.step)
            if isinstance(slice_, int):
                # indexing a range raises IndexError for out-of-range values
                # and resolves negative indexes
                idx = range(cur_len)[slice_]
                range_ = range(idx, idx + 1)
            else:
                # slicing a range never raises: out-of-range bounds are
                # clamped to 0..cur_len like any other Python slice
                range_ = range(cur_len)[slice_]
                if range_.step != 1:
                    raise ValueError("slice step must be 1")
            if len(range_) == 0:
                continue
            yield op, output_index, range_

    def __init__(self, sliced_op_outputs):
        # type: (Iterable[_SlicedOpOutputIn] | SSAVal) -> None
        """Accepts another `SSAVal` or an iterable of the item forms
        described by `_SlicedOpOutputIn`."""
        if isinstance(sliced_op_outputs, SSAVal):
            inp = sliced_op_outputs.sliced_op_outputs
        else:
            inp = SSAVal.__process_sliced_op_outputs(sliced_op_outputs)
        processed = []  # type: list[tuple[Op, int, range]]
        for op, output_index, range_ in inp:
            if processed:
                last_op, last_output_index, last_range_ = processed[-1]
                if last_op == op and last_output_index == output_index \
                        and last_range_.stop == range_.start:
                    # merge slices that continue the previous entry
                    merged = range(last_range_.start, range_.stop)
                    processed[-1] = op, output_index, merged
                    continue
            processed.append((op, output_index, range_))
        self.sliced_op_outputs = tuple(processed)

    def __add__(self, other):
        # type: (SSAVal) -> SSAVal
        """Concatenate two `SSAVal`s."""
        if not isinstance(other, SSAVal):
            return NotImplemented
        return SSAVal(self.sliced_op_outputs + other.sliced_op_outputs)

    def __radd__(self, other):
        # type: (SSAVal) -> SSAVal
        if isinstance(other, SSAVal):
            return other.__add__(self)
        return NotImplemented

    @cached_property
    def expanded_sliced_op_outputs(self):
        # type: () -> tuple[tuple[Op, int, int], ...]
        """One `(op, output_index, element_index)` entry per element."""
        retval = []
        for op, output_index, range_ in self.sliced_op_outputs:
            for i in range_:
                retval.append((op, output_index, i))
        # must be tuple to not be modifiable since it's cached
        return tuple(retval)

    def __getitem__(self, idx):
        # type: (int | slice) -> SSAVal
        """Return the element or sub-sequence of elements as a new `SSAVal`."""
        if isinstance(idx, int):
            return SSAVal([self.expanded_sliced_op_outputs[idx]])
        return SSAVal(self.expanded_sliced_op_outputs[idx])

    def __len__(self):
        return len(self.expanded_sliced_op_outputs)

    def __iter__(self):
        # type: () -> Iterator[SSAVal]
        for v in self.expanded_sliced_op_outputs:
            yield SSAVal([v])

    def __repr__(self):
        # type: () -> str
        if len(self.sliced_op_outputs) == 0:
            return "SSAVal([])"
        parts = []
        for op, output_index, range_ in self.sliced_op_outputs:
            out_len = op.properties.outputs[output_index].get_length(op.maxvl)
            parts.append(f"<{op.name}#{output_index}>")
            # the slice suffix is only shown for a partial output
            if range_ != range(out_len):
                parts[-1] += f"[{range_.start}:{range_.stop}]"
        return " + ".join(parts)
500
501
@plain_data(frozen=True, eq=False)
@final
class Op:
    """One operation inside a `Fn`; compared and hashed by identity."""

    __slots__ = ("fn", "kind", "inputs", "immediates", "outputs", "maxvl",
                 "name")

    def __init__(self, fn, kind, inputs, immediates, maxvl, name=""):
        # type: (Fn, OpKind, Iterable[SSAVal], Iterable[int], int, str) -> None
        """Create the op, build one `SSAVal` per output and register a
        unique name with `fn` (derived from `name`)."""
        self.fn = fn
        self.kind = kind
        self.inputs = list(inputs)
        self.immediates = list(immediates)
        self.maxvl = maxvl
        # `kind` and `maxvl` must already be set here: building an SSAVal
        # reads `self.properties` and `self.maxvl`
        self.outputs = tuple(
            SSAVal([(self, index)])
            for index, _ in enumerate(self.properties.outputs))
        # assigned last: `Fn` refuses ops that already have a `name`
        self.name = fn._add_op_with_unused_name(self, name)

    @property
    def properties(self):
        """The `OpProperties` of this op's `kind`."""
        return self.kind.properties

    def __eq__(self, other):
        if not isinstance(other, Op):
            return NotImplemented
        return self is other

    def __hash__(self):
        return object.__hash__(self)