working on code
[bigint-presentation-code.git] / src / bigint_presentation_code / compiler_ir2.py
1 from collections import defaultdict
2 import enum
3 from abc import ABCMeta, abstractmethod
4 from enum import Enum, unique
5 from functools import lru_cache, total_ordering
6 from typing import (AbstractSet, Any, Callable, Generic, Iterable, Iterator,
7 Sequence, TypeVar, overload)
8 from weakref import WeakValueDictionary as _WeakVDict
9
10 from cached_property import cached_property
11 from nmutil.plain_data import fields, plain_data
12
13 from bigint_presentation_code.type_util import Self, assert_never, final, Literal
14 from bigint_presentation_code.util import BitSet, FBitSet, FMap, OFSet, OSet
15
16
@final
class Fn:
    """ An ordered list of `Op`s, plus the bookkeeping needed to give every
    `Op` added to it a unique name.
    """

    def __init__(self):
        self.ops = []  # type: list[Op]
        # names already handed out; values are held weakly, so an entry
        # goes away once its Op is no longer referenced anywhere else
        self.__op_names = _WeakVDict()  # type: _WeakVDict[str, Op]
        # numeric suffix to try next when a requested name is empty or taken
        self.__next_name_suffix = 2

    def _add_op_with_unused_name(self, op, name=""):
        # type: (Op, str) -> str
        """ Register `op` under a name that isn't in use and return that name.

        `name` is used as-is when it is non-empty and free, otherwise numeric
        suffixes are appended to it until an unused name is found.

        Raises `ValueError` if `op` belongs to a different `Fn` or already has
        a `name` attribute.
        """
        if op.fn is not self:
            raise ValueError("can't add Op to wrong Fn")
        if hasattr(op, "name"):
            raise ValueError("Op already named")
        candidate = name
        while candidate == "" or candidate in self.__op_names:
            candidate = name + str(self.__next_name_suffix)
            self.__next_name_suffix += 1
        self.__op_names[candidate] = op
        return candidate

    def __repr__(self):
        # type: () -> str
        return "<Fn>"

    def append_op(self, op):
        # type: (Op) -> None
        """ Append an existing `op` to `self.ops`.

        Raises `ValueError` if `op` was created for a different `Fn`.
        """
        if op.fn is not self:
            raise ValueError("can't add Op to wrong Fn")
        self.ops.append(op)

    def append_new_op(self, kind, input_vals=(), immediates=(), name="",
                      maxvl=1):
        # type: (OpKind, Iterable[SSAVal], Iterable[int], str, int) -> Op
        """ Create a new `Op` of the given `kind`, append it and return it.
        """
        properties = kind.instantiate(maxvl=maxvl)
        new_op = Op(fn=self, properties=properties,
                    input_vals=input_vals, immediates=immediates, name=name)
        self.append_op(new_op)
        return new_op

    def pre_ra_sim(self, state):
        # type: (PreRASimState) -> None
        """ Run every `Op`'s `pre_ra_sim` on `state`, in program order. """
        for op in self.ops:
            op.pre_ra_sim(state)

    def __insert_copy(self, op, slot, val, vec_kind, scalar_kind):
        # type: (Op, str, SSAVal, OpKind, OpKind) -> SSAVal
        """ Append the copy `Op`s needed for `val` and return the copied value.

        `slot` (e.g. `"inp0"`, `"out1"`) is only used to build the names of
        the new `Op`s. Values whose base type is `CA` or `VL_MAXVL` are
        returned unchanged since all copies would be no-ops.
        """
        base_ty = val.ty.base_ty
        if base_ty is BaseTy.I64:
            maxvl = val.ty.reg_len
            if maxvl == 1:
                copy = self.append_new_op(
                    scalar_kind, input_vals=[val],
                    name=f"{op.name}.{slot}.copy")
            else:
                # vector copies need VL set up first
                setvl = self.append_new_op(
                    OpKind.SetVLI, immediates=[maxvl],
                    name=f"{op.name}.{slot}.setvl")
                copy = self.append_new_op(
                    vec_kind, input_vals=[val, setvl.outputs[0]],
                    maxvl=maxvl, name=f"{op.name}.{slot}.copy")
            return copy.outputs[0]
        if base_ty is BaseTy.CA or base_ty is BaseTy.VL_MAXVL:
            # all copies would be no-ops, so we don't need to copy
            return val
        assert_never(base_ty)

    def pre_ra_insert_copies(self):
        # type: () -> None
        """ Rebuild `self.ops` so every original `Op` has a copy inserted
        before each of its `I64` inputs and after each of its `I64` outputs.

        Inputs of later `Op`s are rewired to read the copied outputs.
        """
        pending_ops = list(self.ops)
        # maps each original output to the value later Ops must read instead
        replacements = {}  # type: dict[SSAVal, SSAVal]
        self.ops.clear()
        for op in pending_ops:
            for i in range(len(op.input_vals)):
                op.input_vals[i] = self.__insert_copy(
                    op, f"inp{i}", replacements[op.input_vals[i]],
                    vec_kind=OpKind.VecCopyToReg,
                    scalar_kind=OpKind.CopyToReg)
            self.ops.append(op)
            for i, out in enumerate(op.outputs):
                replacements[out] = self.__insert_copy(
                    op, f"out{i}", out,
                    vec_kind=OpKind.VecCopyFromReg,
                    scalar_kind=OpKind.CopyFromReg)
113
114
115 @final
116 @unique
117 @total_ordering
118 class OpStage(Enum):
119 value: Literal[0, 1] # type: ignore
120
121 def __new__(cls, value):
122 # type: (int) -> OpStage
123 value = int(value)
124 if value not in (0, 1):
125 raise ValueError("invalid value")
126 retval = object.__new__(cls)
127 retval._value_ = value
128 return retval
129
130 Early = 0
131 """ early stage of Op execution, where all input reads occur.
132 all output writes with `write_stage == Early` occur here too, and therefore
133 conflict with input reads, telling the compiler that it that can't share
134 that output's register with any inputs that the output isn't tied to.
135
136 All outputs, even unused outputs, can't share registers with any other
137 outputs, independent of `write_stage` settings.
138 """
139 Late = 1
140 """ late stage of Op execution, where all output writes with
141 `write_stage == Late` occur, and therefore don't conflict with input reads,
142 telling the compiler that any inputs can safely use the same register as
143 those outputs.
144
145 All outputs, even unused outputs, can't share registers with any other
146 outputs, independent of `write_stage` settings.
147 """
148
149 def __repr__(self):
150 # type: () -> str
151 return f"OpStage.{self._name_}"
152
153 def __lt__(self, other):
154 # type: (OpStage | object) -> bool
155 if isinstance(other, OpStage):
156 return self.value < other.value
157 return NotImplemented
158
159
160 assert OpStage.Early < OpStage.Late, "early must be less than late"
161
162
@plain_data(frozen=True, unsafe_hash=True, repr=False)
@final
@total_ordering
class ProgramPoint:
    """ A point in a function's execution: the index of an Op in `Fn.ops`
    together with the stage (`OpStage`) within that Op.
    """
    __slots__ = "op_index", "stage"

    def __init__(self, op_index, stage):
        # type: (int, OpStage) -> None
        self.op_index = op_index
        self.stage = stage

    @property
    def int_value(self):
        # type: () -> int
        """ an integer representation of `self` such that it keeps ordering and
        successor/predecessor relations.
        """
        # two stages per Op, so the stage occupies the lowest bit
        return 2 * self.op_index + self.stage.value

    @staticmethod
    def from_int_value(int_value):
        # type: (int) -> ProgramPoint
        """ inverse of `int_value` """
        return ProgramPoint(op_index=int_value // 2,
                            stage=OpStage(int_value % 2))

    def next(self, steps=1):
        # type: (int) -> ProgramPoint
        """ the program point `steps` stages after `self` """
        return ProgramPoint.from_int_value(self.int_value + steps)

    def prev(self, steps=1):
        # type: (int) -> ProgramPoint
        """ the program point `steps` stages before `self` """
        return self.next(steps=-steps)

    def __lt__(self, other):
        # type: (ProgramPoint | Any) -> bool
        if not isinstance(other, ProgramPoint):
            return NotImplemented
        # order by Op first, stage only breaks ties within the same Op
        if self.op_index == other.op_index:
            return self.stage < other.stage
        return self.op_index < other.op_index

    def __repr__(self):
        # type: () -> str
        return f"<ops[{self.op_index}]:{self.stage._name_}>"
207
208
@plain_data(frozen=True, unsafe_hash=True, repr=False)
@final
class ProgramRange(Sequence[ProgramPoint]):
    """ A half-open range of `ProgramPoint`s: `start` is included, `stop` is
    not. Behaves as a sequence of the contained `ProgramPoint`s.
    """
    __slots__ = "start", "stop"

    def __init__(self, start, stop):
        # type: (ProgramPoint, ProgramPoint) -> None
        self.start = start
        self.stop = stop

    @cached_property
    def int_value_range(self):
        # type: () -> range
        """ the `range` of `ProgramPoint.int_value`s covered by `self` """
        first = self.start.int_value
        past_last = self.stop.int_value
        return range(first, past_last)

    @staticmethod
    def from_int_value_range(int_value_range):
        # type: (range) -> ProgramRange
        """ inverse of `int_value_range`; only `range`s with `step == 1` can
        be represented, anything else raises `ValueError`.
        """
        if int_value_range.step != 1:
            raise ValueError("int_value_range must have step == 1")
        start = ProgramPoint.from_int_value(int_value_range.start)
        stop = ProgramPoint.from_int_value(int_value_range.stop)
        return ProgramRange(start=start, stop=stop)

    @overload
    def __getitem__(self, __idx):
        # type: (int) -> ProgramPoint
        ...

    @overload
    def __getitem__(self, __idx):
        # type: (slice) -> ProgramRange
        ...

    def __getitem__(self, __idx):
        # type: (int | slice) -> ProgramPoint | ProgramRange
        # let `range` do all the index/slice arithmetic
        picked = range(self.start.int_value, self.stop.int_value)[__idx]
        if isinstance(picked, int):
            return ProgramPoint.from_int_value(picked)
        return ProgramRange.from_int_value_range(picked)

    def __len__(self):
        # type: () -> int
        return len(self.int_value_range)

    def __iter__(self):
        # type: () -> Iterator[ProgramPoint]
        return map(ProgramPoint.from_int_value, self.int_value_range)

    def __repr__(self):
        # type: () -> str
        # reuse ProgramPoint's repr, minus its surrounding angle brackets
        start_text = repr(self.start).lstrip("<").rstrip(">")
        stop_text = repr(self.stop).lstrip("<").rstrip(">")
        return f"<range:{start_text}..{stop_text}>"
263
264
@plain_data(frozen=True, eq=False)
@final
class FnAnalysis:
    """ Use/def and liveness information computed once for a `Fn`.

    Two `FnAnalysis` objects compare equal (and hash the same) when their
    `fn`s do.
    """
    __slots__ = ("fn", "uses", "op_indexes", "live_ranges", "live_at",
                 "def_program_ranges", "use_program_points",
                 "all_program_points")

    def __init__(self, fn):
        # type: (Fn) -> None
        self.fn = fn
        self.op_indexes = FMap((op, idx) for idx, op in enumerate(fn.ops))
        self.all_program_points = ProgramRange(
            start=ProgramPoint(op_index=0, stage=OpStage.Early),
            stop=ProgramPoint(op_index=len(fn.ops), stage=OpStage.Early))
        def_ranges = {}  # type: dict[SSAVal, ProgramRange]
        use_points = {}  # type: dict[SSAUse, ProgramPoint]
        uses = {}  # type: dict[SSAVal, OSet[SSAUse]]
        # exclusive end of each value's live range found so far
        live_stops = {}  # type: dict[SSAVal, ProgramPoint]
        for op in fn.ops:
            # inputs first: their entries were created when the defining Op
            # was visited, so plain indexing is enough here
            for use in op.input_uses:
                used_val = use.ssa_val
                uses[used_val].add(use)
                use_point = self.__get_use_program_point(use)
                use_points[use] = use_point
                # stop is exclusive, hence .next()
                live_stops[used_val] = max(
                    live_stops[used_val], use_point.next())
            for out in op.outputs:
                uses[out] = OSet()
                def_range = self.__get_def_program_range(out)
                def_ranges[out] = def_range
                live_stops[out] = def_range.stop
        self.uses = FMap((k, OFSet(v)) for k, v in uses.items())
        self.def_program_ranges = FMap(def_ranges)
        self.use_program_points = FMap(use_points)
        live_ranges = {}  # type: dict[SSAVal, ProgramRange]
        live_at = {i: OSet[SSAVal]() for i in self.all_program_points}
        for ssa_val in uses:
            live_range = ProgramRange(
                start=self.def_program_ranges[ssa_val].start,
                stop=live_stops[ssa_val])
            live_ranges[ssa_val] = live_range
            for point in live_range:
                live_at[point].add(ssa_val)
        self.live_ranges = FMap(live_ranges)
        self.live_at = FMap((k, OFSet(v)) for k, v in live_at.items())

    def __get_def_program_range(self, ssa_val):
        # type: (SSAVal) -> ProgramRange
        """ the program points at which `ssa_val` is being written by its
        defining Op.
        """
        op_index = self.op_indexes[ssa_val.op]
        start = ProgramPoint(
            op_index=op_index,
            stage=ssa_val.defining_descriptor.write_stage)
        # always include late stage of ssa_val.op, to ensure outputs always
        # overlap all other outputs.
        # stop is exclusive, so we need the next program point.
        late = ProgramPoint(op_index=start.op_index, stage=OpStage.Late)
        return ProgramRange(start=start, stop=late.next())

    def __get_use_program_point(self, ssa_use):
        # type: (SSAUse) -> ProgramPoint
        """ the program point at which `ssa_use` reads its value """
        assert ssa_use.defining_descriptor.write_stage is OpStage.Early, \
            "assumed here, ensured by GenericOpProperties.__init__"
        op_index = self.op_indexes[ssa_use.op]
        return ProgramPoint(op_index=op_index, stage=OpStage.Early)

    def __eq__(self, other):
        # type: (FnAnalysis | Any) -> bool
        if not isinstance(other, FnAnalysis):
            return NotImplemented
        return self.fn == other.fn

    def __hash__(self):
        # type: () -> int
        return hash(self.fn)
336
337
@unique
@final
class BaseTy(Enum):
    """ The base type of a value, before any vector length is applied. """
    I64 = enum.auto()
    CA = enum.auto()
    VL_MAXVL = enum.auto()

    @cached_property
    def only_scalar(self):
        # type: () -> bool
        """ True if this base type can't be used to make vectors """
        if self is BaseTy.CA or self is BaseTy.VL_MAXVL:
            return True
        if self is BaseTy.I64:
            return False
        assert_never(self)

    @cached_property
    def max_reg_len(self):
        # type: () -> int
        """ the largest `reg_len` a `Ty` with this base type may have """
        if self is BaseTy.CA or self is BaseTy.VL_MAXVL:
            return 1
        if self is BaseTy.I64:
            return 128
        assert_never(self)

    def __repr__(self):
        return f"BaseTy.{self._name_}"
367
368
@plain_data(frozen=True, unsafe_hash=True, repr=False)
@final
class Ty:
    """ A concrete type: a `BaseTy` plus the number of registers (`reg_len`)
    a value of that type occupies.
    """
    __slots__ = "base_ty", "reg_len"

    @staticmethod
    def validate(base_ty, reg_len):
        # type: (BaseTy, int) -> str | None
        """ return a string with the error if the combination is invalid,
        otherwise return None
        """
        if reg_len != 1 and base_ty.only_scalar:
            return f"can't create a vector of an only-scalar type: {base_ty}"
        if not 1 <= reg_len <= base_ty.max_reg_len:
            return "reg_len out of range"
        return None

    def __init__(self, base_ty, reg_len):
        # type: (BaseTy, int) -> None
        """ Raises `ValueError` if `validate` rejects the combination. """
        error = self.validate(base_ty=base_ty, reg_len=reg_len)
        if error is not None:
            raise ValueError(error)
        self.base_ty = base_ty
        self.reg_len = reg_len

    def __repr__(self):
        # type: () -> str
        # scalars are shown without a length, vectors as `<base*len>`
        suffix = "" if self.reg_len == 1 else f"*{self.reg_len}"
        return f"<{self.base_ty._name_}{suffix}>"
401
402
@unique
@final
class LocKind(Enum):
    """ The kind of storage a `Loc` lives in. """
    GPR = enum.auto()
    StackI64 = enum.auto()
    CA = enum.auto()
    VL_MAXVL = enum.auto()

    @cached_property
    def base_ty(self):
        # type: () -> BaseTy
        """ the `BaseTy` of values stored in this kind of location """
        if self is LocKind.CA:
            return BaseTy.CA
        if self is LocKind.VL_MAXVL:
            return BaseTy.VL_MAXVL
        if self is LocKind.GPR or self is LocKind.StackI64:
            return BaseTy.I64
        assert_never(self)

    @cached_property
    def loc_count(self):
        # type: () -> int
        """ the number of single-register slots of this kind """
        if self is LocKind.StackI64:
            return 1024
        if self is LocKind.GPR or self is LocKind.CA \
                or self is LocKind.VL_MAXVL:
            # for these kinds one maximum-length value fills all the slots
            return self.base_ty.max_reg_len
        assert_never(self)

    def __repr__(self):
        return f"LocKind.{self._name_}"
436
437
@final
@unique
class LocSubKind(Enum):
    """ A subdivision of `LocKind` that selects which locations of that kind
    an operand is allowed to be allocated to.
    """
    BASE_GPR = enum.auto()
    SV_EXTRA2_VGPR = enum.auto()
    SV_EXTRA2_SGPR = enum.auto()
    SV_EXTRA3_VGPR = enum.auto()
    SV_EXTRA3_SGPR = enum.auto()
    StackI64 = enum.auto()
    CA = enum.auto()
    VL_MAXVL = enum.auto()

    @cached_property
    def kind(self):
        # type: () -> LocKind
        """ the `LocKind` this sub-kind belongs to """
        # pyright fails typechecking when using `in` here:
        # reported: https://github.com/microsoft/pyright/issues/4102
        gpr_sub_kinds = (
            LocSubKind.BASE_GPR,
            LocSubKind.SV_EXTRA2_VGPR, LocSubKind.SV_EXTRA2_SGPR,
            LocSubKind.SV_EXTRA3_VGPR, LocSubKind.SV_EXTRA3_SGPR,
        )
        if self in gpr_sub_kinds:
            return LocKind.GPR
        if self is LocSubKind.StackI64:
            return LocKind.StackI64
        if self is LocSubKind.CA:
            return LocKind.CA
        if self is LocSubKind.VL_MAXVL:
            return LocKind.VL_MAXVL
        assert_never(self)

    @property
    def base_ty(self):
        return self.kind.base_ty

    @lru_cache()
    def allocatable_locs(self, ty):
        # type: (Ty) -> LocSet
        """ the set of `Loc`s of type `ty` that the register allocator may
        pick for an operand with this sub-kind.

        Raises `ValueError` if `ty`'s base type doesn't match this sub-kind.
        """
        if ty.base_ty != self.base_ty:
            raise ValueError("type mismatch")
        if self is LocSubKind.CA or self is LocSubKind.VL_MAXVL:
            # there is exactly one location of these kinds
            return LocSet([Loc(kind=self.kind, start=0, reg_len=1)])
        if self is LocSubKind.BASE_GPR:
            starts = range(32)
        elif self is LocSubKind.SV_EXTRA2_VGPR:
            starts = range(0, 128, 2)
        elif self is LocSubKind.SV_EXTRA2_SGPR:
            starts = range(64)
        elif self is LocSubKind.SV_EXTRA3_VGPR \
                or self is LocSubKind.SV_EXTRA3_SGPR:
            starts = range(128)
        elif self is LocSubKind.StackI64:
            starts = range(LocKind.StackI64.loc_count)
        else:
            assert_never(self)
        usable = []  # type: list[Loc]
        for start in starts:
            # None means a value of this length doesn't fit at `start`
            loc = Loc.try_make(kind=self.kind, start=start, reg_len=ty.reg_len)
            if loc is None:
                continue
            # never hand out anything overlapping a special register
            if any(loc.conflicts(special) for special in SPECIAL_GPRS):
                continue
            usable.append(loc)
        return LocSet(usable)

    def __repr__(self):
        return f"LocSubKind.{self._name_}"
507
508
@plain_data(frozen=True, unsafe_hash=True)
@final
class GenericTy:
    """ A type whose vector length isn't fixed yet: a `BaseTy` plus whether
    the value is a vector. `instantiate` turns it into a concrete `Ty`.
    """
    __slots__ = "base_ty", "is_vec"

    def __init__(self, base_ty, is_vec):
        # type: (BaseTy, bool) -> None
        """ Raises `ValueError` for a vector of an only-scalar base type. """
        self.base_ty = base_ty
        if is_vec and base_ty.only_scalar:
            raise ValueError(f"base_ty={base_ty} requires is_vec=False")
        self.is_vec = is_vec

    def instantiate(self, maxvl):
        # type: (int) -> Ty
        """ the concrete `Ty` for the given `maxvl`; scalars ignore `maxvl`
        """
        # here's where subvl and elwid would be accounted for
        reg_len = maxvl if self.is_vec else 1
        return Ty(self.base_ty, reg_len)

    def can_instantiate_to(self, ty):
        # type: (Ty) -> bool
        """ True if some `maxvl` could make `instantiate` return `ty`'s
        base type and register length.
        """
        if self.base_ty != ty.base_ty:
            return False
        # vectors accept any length, scalars only a single register
        return True if self.is_vec else ty.reg_len == 1
535
536
@plain_data(frozen=True, unsafe_hash=True)
@final
class Loc:
    """ A location: `reg_len` consecutive slots of the given `kind`,
    beginning at slot `start`.
    """
    __slots__ = "kind", "start", "reg_len"

    @staticmethod
    def validate(kind, start, reg_len):
        # type: (LocKind, int, int) -> str | None
        """ return a string with the error if the combination is invalid,
        otherwise return None
        """
        msg = Ty.validate(base_ty=kind.base_ty, reg_len=reg_len)
        if msg is not None:
            return msg
        if reg_len > kind.loc_count:
            return "invalid reg_len"
        if start < 0 or start + reg_len > kind.loc_count:
            return "start not in valid range"
        return None

    @staticmethod
    def try_make(kind, start, reg_len):
        # type: (LocKind, int, int) -> Loc | None
        """ like the constructor, but returns None instead of raising when
        the combination is invalid.
        """
        msg = Loc.validate(kind=kind, start=start, reg_len=reg_len)
        if msg is not None:
            return None
        return Loc(kind=kind, start=start, reg_len=reg_len)

    def __init__(self, kind, start, reg_len):
        # type: (LocKind, int, int) -> None
        """ Raises `ValueError` if `validate` rejects the combination. """
        msg = self.validate(kind=kind, start=start, reg_len=reg_len)
        if msg is not None:
            raise ValueError(msg)
        self.kind = kind
        self.reg_len = reg_len
        self.start = start

    def conflicts(self, other):
        # type: (Loc) -> bool
        """ True if `self` and `other` are the same kind and overlap """
        return (self.kind == other.kind
                and self.start < other.stop and other.start < self.stop)

    @staticmethod
    def make_ty(kind, reg_len):
        # type: (LocKind, int) -> Ty
        """ the `Ty` of a `Loc` with the given `kind` and `reg_len` """
        return Ty(base_ty=kind.base_ty, reg_len=reg_len)

    @cached_property
    def ty(self):
        # type: () -> Ty
        return self.make_ty(kind=self.kind, reg_len=self.reg_len)

    @property
    def stop(self):
        # type: () -> int
        """ one past the last slot covered by `self` """
        return self.start + self.reg_len

    def try_concat(self, *others):
        # type: (*Loc | None) -> Loc | None
        """ the `Loc` made by joining `self` and `others` end to end.

        Returns None if any of `others` is None, has a different kind, or
        doesn't begin exactly where the previous `Loc` stops. Also returns
        None if the combined `Loc` would be invalid, e.g. longer than the
        type's maximum register length.
        """
        reg_len = self.reg_len
        stop = self.stop
        for other in others:
            if other is None or other.kind != self.kind:
                return None
            if stop != other.start:
                return None
            stop = other.stop
            reg_len += other.reg_len
        # try_make rather than Loc(...): the pieces can each be valid while
        # their concatenation is too long, and a `try_` method must report
        # that as None instead of raising ValueError
        return Loc.try_make(kind=self.kind, start=self.start, reg_len=reg_len)
603
604
# single-register GPR locations that `LocSubKind.allocatable_locs` never
# hands out.
# NOTE(review): presumably the ABI-reserved registers -- confirm
SPECIAL_GPRS = tuple(
    Loc(kind=LocKind.GPR, start=reg_num, reg_len=1)
    for reg_num in (0, 1, 2, 13)
)
611
612
@plain_data(frozen=True, eq=False)
@final
class LocSet(AbstractSet[Loc]):
    """ An immutable set of `Loc`s that all share one `Ty`.

    Stored as one bit set of start indexes per `LocKind` (`starts`) plus the
    shared type (`ty`); `ty` is None exactly when the set is empty.
    """
    __slots__ = "starts", "ty"

    def __init__(self, __locs=()):
        # type: (Iterable[Loc]) -> None
        """ Raises `ValueError` if the given `Loc`s don't all have the same
        type.
        """
        if isinstance(__locs, LocSet):
            # both fields are immutable, so they can simply be shared
            self.starts = __locs.starts  # type: FMap[LocKind, FBitSet]
            self.ty = __locs.ty  # type: Ty | None
            return
        starts = {i: BitSet() for i in LocKind}
        ty = None
        for loc in __locs:
            if ty is None:
                ty = loc.ty
            if ty != loc.ty:
                raise ValueError(f"conflicting types: {ty} != {loc.ty}")
            starts[loc.kind].add(loc.start)
        # only kinds that actually have locations are kept as keys
        self.starts = FMap(
            (k, FBitSet(v)) for k, v in starts.items() if len(v) != 0)
        self.ty = ty

    @cached_property
    def stops(self):
        # type: () -> FMap[LocKind, FBitSet]
        """ like `starts`, but holding every `Loc`'s `stop` instead """
        if self.ty is None:
            return FMap()
        sh = self.ty.reg_len
        # stop == start + reg_len, i.e. every bit moves up by reg_len
        return FMap(
            (k, FBitSet(bits=v.bits << sh)) for k, v in self.starts.items())

    @property
    def kinds(self):
        # type: () -> AbstractSet[LocKind]
        return self.starts.keys()

    @property
    def reg_len(self):
        # type: () -> int | None
        if self.ty is None:
            return None
        return self.ty.reg_len

    @property
    def base_ty(self):
        # type: () -> BaseTy | None
        if self.ty is None:
            return None
        return self.ty.base_ty

    def concat(self, *others):
        # type: (*LocSet) -> LocSet
        """ the set of all `Loc`s that can be made by joining end to end one
        `Loc` from `self` and one from each of `others`, in order.

        The result is empty if any set is empty or the base types differ.
        """
        if self.ty is None:
            return LocSet()
        base_ty = self.ty.base_ty
        reg_len = self.ty.reg_len
        starts = {k: BitSet(v) for k, v in self.starts.items()}
        for other in others:
            if other.ty is None:
                return LocSet()
            if other.ty.base_ty != base_ty:
                return LocSet()
            # iterate over a snapshot of the keys since kinds are deleted
            for kind in list(starts):
                if kind not in other.starts:
                    # `other` has nothing of this kind to append, so no
                    # concatenation of this kind can exist
                    del starts[kind]
                    continue
                # keep a start only if `other` has a Loc beginning right
                # where the Locs concatenated so far stop
                starts[kind].bits &= other.starts[kind].bits >> reg_len
                if starts[kind].bits == 0:
                    del starts[kind]
            if len(starts) == 0:
                return LocSet()
            reg_len += other.ty.reg_len

        def locs():
            # type: () -> Iterable[Loc]
            for kind, v in starts.items():
                for start in v:
                    # skip combinations that don't make a valid Loc
                    loc = Loc.try_make(kind=kind, start=start, reg_len=reg_len)
                    if loc is not None:
                        yield loc
        return LocSet(locs())

    def __contains__(self, loc):
        # type: (Loc | Any) -> bool
        if not isinstance(loc, Loc) or loc.ty != self.ty:
            return False
        if loc.kind not in self.starts:
            return False
        return loc.start in self.starts[loc.kind]

    def __iter__(self):
        # type: () -> Iterator[Loc]
        if self.ty is None:
            return
        for kind, starts in self.starts.items():
            for start in starts:
                yield Loc(kind=kind, start=start, reg_len=self.ty.reg_len)

    @cached_property
    def __len(self):
        return sum((len(v) for v in self.starts.values()), 0)

    def __len__(self):
        return self.__len

    @cached_property
    def __hash(self):
        return super()._hash()

    def __hash__(self):
        return self.__hash

    @lru_cache(maxsize=None, typed=True)
    def max_conflicts_with(self, other):
        # type: (LocSet | Loc) -> int
        """the largest number of Locs in `self` that a single Loc
        from `other` can conflict with
        """
        if isinstance(other, LocSet):
            # default=0: an empty `other` has no Loc that could conflict
            return max((self.max_conflicts_with(i) for i in other), default=0)
        else:
            return sum(other.conflicts(i) for i in self)
735
736
@plain_data(frozen=True, unsafe_hash=True)
@final
class GenericOperandDesc:
    """generic Op operand descriptor"""
    __slots__ = ("ty", "fixed_loc", "sub_kinds", "tied_input_index", "spread",
                 "write_stage")

    def __init__(
        self, ty,  # type: GenericTy
        sub_kinds,  # type: Iterable[LocSubKind]
        *,
        fixed_loc=None,  # type: Loc | None
        tied_input_index=None,  # type: int | None
        spread=False,  # type: bool
        write_stage=OpStage.Early,  # type: OpStage
    ):
        # type: (...) -> None
        """ Raises `ValueError` for any inconsistent combination of
        arguments, e.g. an operand that is both tied and fixed.
        """
        self.ty = ty
        self.sub_kinds = OFSet(sub_kinds)
        if len(self.sub_kinds) == 0:
            raise ValueError("sub_kinds can't be empty")
        self.fixed_loc = fixed_loc
        if fixed_loc is not None:
            if tied_input_index is not None:
                raise ValueError("operand can't be both tied and fixed")
            if not ty.can_instantiate_to(fixed_loc.ty):
                raise ValueError(
                    f"fixed_loc has incompatible type for given generic "
                    f"type: fixed_loc={fixed_loc} generic ty={ty}")
            if len(self.sub_kinds) != 1:
                raise ValueError(
                    "multiple sub_kinds not allowed for fixed operand")
            # the fixed location must be one the sub-kind could allocate
            for sub_kind in self.sub_kinds:
                allowed = sub_kind.allocatable_locs(fixed_loc.ty)
                if fixed_loc not in allowed:
                    raise ValueError(
                        f"fixed_loc not in given sub_kind: "
                        f"fixed_loc={fixed_loc} sub_kind={sub_kind}")
        for sub_kind in self.sub_kinds:
            if sub_kind.base_ty != ty.base_ty:
                raise ValueError(f"sub_kind is incompatible with type: "
                                 f"sub_kind={sub_kind} ty={ty}")
        if tied_input_index is not None and tied_input_index < 0:
            raise ValueError("invalid tied_input_index")
        self.tied_input_index = tied_input_index
        self.spread = spread
        if spread:
            if self.tied_input_index is not None:
                raise ValueError("operand can't be both spread and tied")
            if self.fixed_loc is not None:
                raise ValueError("operand can't be both spread and fixed")
            if self.ty.is_vec:
                raise ValueError("operand can't be both spread and vector")
        self.write_stage = write_stage

    def tied_to_input(self, tied_input_index):
        # type: (int) -> Self
        """ a descriptor with the same `ty`, `sub_kinds` and `write_stage`
        that is tied to the input at `tied_input_index`.
        """
        return GenericOperandDesc(
            self.ty, self.sub_kinds,
            tied_input_index=tied_input_index,
            write_stage=self.write_stage)

    def with_fixed_loc(self, fixed_loc):
        # type: (Loc) -> Self
        """ a descriptor with the same `ty`, `sub_kinds` and `write_stage`
        that is fixed to `fixed_loc`.
        """
        return GenericOperandDesc(
            self.ty, self.sub_kinds,
            fixed_loc=fixed_loc,
            write_stage=self.write_stage)

    def with_write_stage(self, write_stage):
        # type: (OpStage) -> Self
        """ a copy of `self` with `write_stage` replaced """
        return GenericOperandDesc(
            self.ty, self.sub_kinds,
            fixed_loc=self.fixed_loc,
            tied_input_index=self.tied_input_index,
            spread=self.spread,
            write_stage=write_stage)

    def instantiate(self, maxvl):
        # type: (int) -> Iterable[OperandDesc]
        """ yield the concrete `OperandDesc`s for the given `maxvl`: one per
        element for a spread operand, otherwise exactly one.
        """
        # assumes all spread operands have ty.reg_len = 1
        if self.spread:
            spread_indexes = range(maxvl)
            ty = self.ty.instantiate(maxvl=1)
        else:
            spread_indexes = (None,)
            ty = self.ty.instantiate(maxvl=maxvl)

        def locs():
            # type: () -> Iterable[Loc]
            fixed_loc = self.fixed_loc
            if fixed_loc is None:
                for sub_kind in self.sub_kinds:
                    yield from sub_kind.allocatable_locs(ty)
                return
            if ty != fixed_loc.ty:
                raise ValueError(
                    f"instantiation failed: type mismatch with fixed_loc: "
                    f"instantiated type: {ty} fixed_loc: {self.fixed_loc}")
            yield fixed_loc

        loc_set_before_spread = LocSet(locs())
        for spread_index in spread_indexes:
            yield OperandDesc(loc_set_before_spread=loc_set_before_spread,
                              tied_input_index=self.tied_input_index,
                              spread_index=spread_index,
                              write_stage=self.write_stage)
837
838
@plain_data(frozen=True, unsafe_hash=True)
@final
class OperandDesc:
    """Op operand descriptor"""
    __slots__ = ("loc_set_before_spread", "tied_input_index", "spread_index",
                 "write_stage")

    def __init__(self, loc_set_before_spread, tied_input_index, spread_index,
                 write_stage):
        # type: (LocSet, int | None, int | None, OpStage) -> None
        """ Raises `ValueError` if `loc_set_before_spread` is empty or if the
        operand is both spread and tied.
        """
        if len(loc_set_before_spread) == 0:
            raise ValueError("loc_set_before_spread must not be empty")
        # check the arguments, not the attributes: `self.spread_index` is
        # not assigned yet at this point, so reading it here would fail
        # instead of performing the check
        if tied_input_index is not None and spread_index is not None:
            raise ValueError("operand can't be both spread and tied")
        self.loc_set_before_spread = loc_set_before_spread
        self.tied_input_index = tied_input_index
        self.spread_index = spread_index
        self.write_stage = write_stage

    @cached_property
    def ty_before_spread(self):
        # type: () -> Ty
        """ the `Ty` shared by all `Loc`s in `loc_set_before_spread` """
        ty = self.loc_set_before_spread.ty
        assert ty is not None, (
            "__init__ checked that the LocSet isn't empty, "
            "non-empty LocSets should always have ty set")
        return ty

    @cached_property
    def ty(self):
        """ Ty after any spread is applied """
        if self.spread_index is not None:
            # assumes all spread operands have ty.reg_len = 1
            return Ty(base_ty=self.ty_before_spread.base_ty, reg_len=1)
        return self.ty_before_spread

    @property
    def reg_offset_in_unspread(self):
        """ the number of reg-sized slots in the unspread Loc before self's Loc

        e.g. if the unspread Loc containing self is:
        `Loc(kind=LocKind.GPR, start=8, reg_len=4)`
        and self's Loc is `Loc(kind=LocKind.GPR, start=10, reg_len=1)`
        then reg_offset_in_unspread == 2 == 10 - 8
        """
        if self.spread_index is None:
            return 0
        return self.spread_index * self.ty.reg_len
887
888
# Generic operand descriptors shared by the Op definitions.
# Each one has a single sub-kind; the *_VGPR ones are vectors, all others
# are scalars.
OD_BASE_SGPR = GenericOperandDesc(
    ty=GenericTy(base_ty=BaseTy.I64, is_vec=False),
    sub_kinds=[LocSubKind.BASE_GPR],
)
OD_EXTRA3_SGPR = GenericOperandDesc(
    ty=GenericTy(base_ty=BaseTy.I64, is_vec=False),
    sub_kinds=[LocSubKind.SV_EXTRA3_SGPR],
)
OD_EXTRA3_VGPR = GenericOperandDesc(
    ty=GenericTy(base_ty=BaseTy.I64, is_vec=True),
    sub_kinds=[LocSubKind.SV_EXTRA3_VGPR],
)
OD_EXTRA2_SGPR = GenericOperandDesc(
    ty=GenericTy(base_ty=BaseTy.I64, is_vec=False),
    sub_kinds=[LocSubKind.SV_EXTRA2_SGPR],
)
OD_EXTRA2_VGPR = GenericOperandDesc(
    ty=GenericTy(base_ty=BaseTy.I64, is_vec=True),
    sub_kinds=[LocSubKind.SV_EXTRA2_VGPR],
)
OD_CA = GenericOperandDesc(
    ty=GenericTy(base_ty=BaseTy.CA, is_vec=False),
    sub_kinds=[LocSubKind.CA],
)
OD_VL = GenericOperandDesc(
    ty=GenericTy(base_ty=BaseTy.VL_MAXVL, is_vec=False),
    sub_kinds=[LocSubKind.VL_MAXVL],
)
910
911
@plain_data(frozen=True, unsafe_hash=True)
@final
class GenericOpProperties:
    """ The properties of a kind of Op that don't depend on `maxvl`: its
    operand descriptors, immediate ranges and a few flags.
    """
    __slots__ = ("demo_asm", "inputs", "outputs", "immediates",
                 "is_copy", "is_load_immediate", "has_side_effects")

    def __init__(
        self, demo_asm,  # type: str
        inputs,  # type: Iterable[GenericOperandDesc]
        outputs,  # type: Iterable[GenericOperandDesc]
        immediates=(),  # type: Iterable[range]
        is_copy=False,  # type: bool
        is_load_immediate=False,  # type: bool
        has_side_effects=False,  # type: bool
    ):
        # type: (...) -> None
        """ Raises `ValueError` if an input is tied or has a non-`Early`
        write stage, if an output is tied to a missing or non-equivalent
        input, or if two outputs have conflicting fixed locations.
        """
        self.demo_asm = demo_asm  # type: str
        self.inputs = tuple(inputs)  # type: tuple[GenericOperandDesc, ...]
        for inp in self.inputs:
            # tying and write stages only make sense for outputs
            if inp.tied_input_index is not None:
                raise ValueError(
                    f"tied_input_index is not allowed on inputs: {inp}")
            if inp.write_stage is not OpStage.Early:
                raise ValueError(
                    f"write_stage is not allowed on inputs: {inp}")
        self.outputs = tuple(outputs)  # type: tuple[GenericOperandDesc, ...]
        # fixed locations of the outputs checked so far, with their indexes
        seen_fixed_locs = []  # type: list[tuple[Loc, int]]
        for idx, out in enumerate(self.outputs):
            tied_index = out.tied_input_index
            if tied_index is not None:
                if tied_index >= len(self.inputs):
                    raise ValueError(f"tied_input_index out of range: {out}")
                tied_inp = self.inputs[tied_index]
                # the output must be exactly what tying that input produces
                if tied_inp.tied_to_input(tied_index) != out:
                    raise ValueError(f"output can't be tied to non-equivalent "
                                     f"input: {out} tied to {tied_inp}")
            if out.fixed_loc is not None:
                for other_fixed_loc, other_idx in seen_fixed_locs:
                    if other_fixed_loc.conflicts(out.fixed_loc):
                        raise ValueError(
                            f"conflicting fixed_locs: outputs[{idx}] and "
                            f"outputs[{other_idx}]: {out.fixed_loc} conflicts "
                            f"with {other_fixed_loc}")
                seen_fixed_locs.append((out.fixed_loc, idx))
        self.immediates = tuple(immediates)  # type: tuple[range, ...]
        self.is_copy = is_copy  # type: bool
        self.is_load_immediate = is_load_immediate  # type: bool
        self.has_side_effects = has_side_effects  # type: bool
960
961
@plain_data(frozen=True, unsafe_hash=True)
@final
class OpProperties:
    """ The properties of an `OpKind` instantiated for a specific `maxvl`.

    Everything that doesn't depend on `maxvl` is forwarded from the kind's
    `GenericOpProperties`.
    """
    __slots__ = "kind", "inputs", "outputs", "maxvl"

    def __init__(self, kind, maxvl):
        # type: (OpKind, int) -> None
        # must be set first: `self.generic` reads it
        self.kind = kind  # type: OpKind
        # one generic operand can expand to several concrete ones
        self.inputs = tuple(
            desc
            for generic_inp in self.generic.inputs
            for desc in generic_inp.instantiate(maxvl=maxvl)
        )  # type: tuple[OperandDesc, ...]
        self.outputs = tuple(
            desc
            for generic_out in self.generic.outputs
            for desc in generic_out.instantiate(maxvl=maxvl)
        )  # type: tuple[OperandDesc, ...]
        self.maxvl = maxvl  # type: int

    @property
    def generic(self):
        # type: () -> GenericOpProperties
        """ the `maxvl`-independent properties of `self.kind` """
        return self.kind.properties

    @property
    def immediates(self):
        # type: () -> tuple[range, ...]
        return self.generic.immediates

    @property
    def demo_asm(self):
        # type: () -> str
        return self.generic.demo_asm

    @property
    def is_copy(self):
        # type: () -> bool
        return self.generic.is_copy

    @property
    def is_load_immediate(self):
        # type: () -> bool
        return self.generic.is_load_immediate

    @property
    def has_side_effects(self):
        # type: () -> bool
        return self.generic.has_side_effects
1009
1010
# range of values representable as a signed 16-bit immediate
IMM_S16 = range(-(2 ** 15), 2 ** 15)

# signature of a pre-register-allocation simulation function for one Op
_PRE_RA_SIM_FN = Callable[["Op", "PreRASimState"], None]
# zero-argument callable that returns the simulation function
_PRE_RA_SIM_FN2 = Callable[[], _PRE_RA_SIM_FN]
# registry filled in next to each OpKind member, looked up by
# `OpKind.pre_ra_sim` using the kind's properties as the key
_PRE_RA_SIMS = {}  # type: dict[GenericOpProperties | Any, _PRE_RA_SIM_FN2]
1016
1017
@unique
@final
class OpKind(Enum):
    """ The different kinds of `Op`.

    The value of each member is the `GenericOpProperties` describing it,
    available through `properties`.

    Each member is preceded by a private static method that simulates it
    before register allocation: the method reads the values of `op`'s inputs
    from `state.ssa_vals` and writes the values of `op`'s outputs back into
    `state.ssa_vals` (loads/stores go through `state.load`/`state.store`).
    Each member is followed by the registration of that method in
    `_PRE_RA_SIMS`, keyed by the member's `GenericOpProperties`. The
    registration is wrapped in a `lambda` so that the attribute lookup on
    `OpKind` only happens when `pre_ra_sim` is first used, after the class
    has been created.

    Only methods and members may be bound in this class body: any other
    plain name assigned here would be turned into an enum member.
    """

    def __init__(self, properties):
        # type: (GenericOpProperties) -> None
        super().__init__()
        self.__properties = properties

    @property
    def properties(self):
        # type: () -> GenericOpProperties
        """ the `GenericOpProperties` this member was defined with """
        return self.__properties

    def instantiate(self, maxvl):
        # type: (int) -> OpProperties
        """ create the `OpProperties` of this kind for the given `maxvl` """
        return OpProperties(self, maxvl=maxvl)

    def __repr__(self):
        # type: () -> str
        return f"OpKind.{self._name_}"

    @cached_property
    def pre_ra_sim(self):
        # type: () -> _PRE_RA_SIM_FN
        """ the simulation function registered for this kind """
        get_sim_fn = _PRE_RA_SIMS[self.properties]
        return get_sim_fn()

    @staticmethod
    def __clearca_pre_ra_sim(op, state):
        # type: (Op, PreRASimState) -> None
        """ the only output becomes `(False,)` """
        state.ssa_vals[op.outputs[0]] = (False,)
    ClearCA = GenericOpProperties(
        demo_asm="addic 0, 0, 0",
        inputs=[],
        outputs=[OD_CA.with_write_stage(OpStage.Late)],
    )
    _PRE_RA_SIMS[ClearCA] = lambda: OpKind.__clearca_pre_ra_sim

    @staticmethod
    def __setca_pre_ra_sim(op, state):
        # type: (Op, PreRASimState) -> None
        """ the only output becomes `(True,)` """
        state.ssa_vals[op.outputs[0]] = (True,)
    SetCA = GenericOpProperties(
        demo_asm="subfc 0, 0, 0",
        inputs=[],
        outputs=[OD_CA.with_write_stage(OpStage.Late)],
    )
    _PRE_RA_SIMS[SetCA] = lambda: OpKind.__setca_pre_ra_sim

    @staticmethod
    def __svadde_pre_ra_sim(op, state):
        # type: (Op, PreRASimState) -> None
        """ element-wise `RA[i] + RB[i] + carry` for the first `VL` elements,
        propagating the carry from each element to the next.
        """
        RA = state.ssa_vals[op.input_vals[0]]
        RB = state.ssa_vals[op.input_vals[1]]
        carry, = state.ssa_vals[op.input_vals[2]]
        VL, = state.ssa_vals[op.input_vals[3]]
        sums = []  # type: list[int]
        for i in range(VL):
            total = RA[i] + RB[i] + carry
            sums.append(total & GPR_VALUE_MASK)
            # carry out is set if anything remains above the register width
            carry = bool(total >> GPR_SIZE_IN_BITS)
        state.ssa_vals[op.outputs[0]] = tuple(sums)
        state.ssa_vals[op.outputs[1]] = (carry,)
    SvAddE = GenericOpProperties(
        demo_asm="sv.adde *RT, *RA, *RB",
        inputs=[OD_EXTRA3_VGPR, OD_EXTRA3_VGPR, OD_CA, OD_VL],
        outputs=[OD_EXTRA3_VGPR, OD_CA],
    )
    _PRE_RA_SIMS[SvAddE] = lambda: OpKind.__svadde_pre_ra_sim

    @staticmethod
    def __svsubfe_pre_ra_sim(op, state):
        # type: (Op, PreRASimState) -> None
        """ element-wise `~RA[i] + RB[i] + carry` for the first `VL`
        elements, propagating the carry from each element to the next.
        """
        RA = state.ssa_vals[op.input_vals[0]]
        RB = state.ssa_vals[op.input_vals[1]]
        carry, = state.ssa_vals[op.input_vals[2]]
        VL, = state.ssa_vals[op.input_vals[3]]
        results = []  # type: list[int]
        for i in range(VL):
            inverted_ra = ~RA[i] & GPR_VALUE_MASK
            total = inverted_ra + RB[i] + carry
            results.append(total & GPR_VALUE_MASK)
            carry = bool(total >> GPR_SIZE_IN_BITS)
        state.ssa_vals[op.outputs[0]] = tuple(results)
        state.ssa_vals[op.outputs[1]] = (carry,)
    SvSubFE = GenericOpProperties(
        demo_asm="sv.subfe *RT, *RA, *RB",
        inputs=[OD_EXTRA3_VGPR, OD_EXTRA3_VGPR, OD_CA, OD_VL],
        outputs=[OD_EXTRA3_VGPR, OD_CA],
    )
    _PRE_RA_SIMS[SvSubFE] = lambda: OpKind.__svsubfe_pre_ra_sim

    @staticmethod
    def __svmaddedu_pre_ra_sim(op, state):
        # type: (Op, PreRASimState) -> None
        """ element-wise `RA[i] * RB + carry` for the first `VL` elements.

        The low bits of each result go to the vector output; everything above
        the register width is carried into the next element, and the final
        carry is the second output.
        """
        RA = state.ssa_vals[op.input_vals[0]]
        RB, = state.ssa_vals[op.input_vals[1]]
        carry, = state.ssa_vals[op.input_vals[2]]
        VL, = state.ssa_vals[op.input_vals[3]]
        low_words = []  # type: list[int]
        for i in range(VL):
            product = RA[i] * RB + carry
            low_words.append(product & GPR_VALUE_MASK)
            carry = product >> GPR_SIZE_IN_BITS
        state.ssa_vals[op.outputs[0]] = tuple(low_words)
        state.ssa_vals[op.outputs[1]] = (carry,)
    SvMAddEDU = GenericOpProperties(
        demo_asm="sv.maddedu *RT, *RA, RB, RC",
        inputs=[OD_EXTRA2_VGPR, OD_EXTRA2_SGPR, OD_EXTRA2_SGPR, OD_VL],
        outputs=[OD_EXTRA3_VGPR, OD_EXTRA2_SGPR.tied_to_input(2)],
    )
    _PRE_RA_SIMS[SvMAddEDU] = lambda: OpKind.__svmaddedu_pre_ra_sim

    @staticmethod
    def __setvli_pre_ra_sim(op, state):
        # type: (Op, PreRASimState) -> None
        """ the only output becomes a 1-tuple holding the immediate """
        new_vl = op.immediates[0]
        state.ssa_vals[op.outputs[0]] = (new_vl,)
    SetVLI = GenericOpProperties(
        demo_asm="setvl 0, 0, imm, 0, 1, 1",
        inputs=(),
        outputs=[OD_VL.with_write_stage(OpStage.Late)],
        immediates=[range(1, 65)],
        is_load_immediate=True,
    )
    _PRE_RA_SIMS[SetVLI] = lambda: OpKind.__setvli_pre_ra_sim

    @staticmethod
    def __svli_pre_ra_sim(op, state):
        # type: (Op, PreRASimState) -> None
        """ the output becomes `VL` copies of the immediate, masked to the
        register width.
        """
        VL, = state.ssa_vals[op.input_vals[0]]
        masked_imm = op.immediates[0] & GPR_VALUE_MASK
        state.ssa_vals[op.outputs[0]] = tuple([masked_imm] * VL)
    SvLI = GenericOpProperties(
        demo_asm="sv.addi *RT, 0, imm",
        inputs=[OD_VL],
        outputs=[OD_EXTRA3_VGPR],
        immediates=[IMM_S16],
        is_load_immediate=True,
    )
    _PRE_RA_SIMS[SvLI] = lambda: OpKind.__svli_pre_ra_sim

    @staticmethod
    def __li_pre_ra_sim(op, state):
        # type: (Op, PreRASimState) -> None
        """ the output becomes the immediate, masked to the register width """
        masked_imm = op.immediates[0] & GPR_VALUE_MASK
        state.ssa_vals[op.outputs[0]] = (masked_imm,)
    LI = GenericOpProperties(
        demo_asm="addi RT, 0, imm",
        inputs=(),
        outputs=[OD_BASE_SGPR.with_write_stage(OpStage.Late)],
        immediates=[IMM_S16],
        is_load_immediate=True,
    )
    _PRE_RA_SIMS[LI] = lambda: OpKind.__li_pre_ra_sim

    @staticmethod
    def __veccopytoreg_pre_ra_sim(op, state):
        # type: (Op, PreRASimState) -> None
        """ the output gets the value of the first input, unchanged """
        source = op.input_vals[0]
        state.ssa_vals[op.outputs[0]] = state.ssa_vals[source]
    VecCopyToReg = GenericOpProperties(
        demo_asm="sv.mv dest, src",
        inputs=[
            GenericOperandDesc(
                ty=GenericTy(BaseTy.I64, is_vec=True),
                sub_kinds=[LocSubKind.SV_EXTRA3_VGPR, LocSubKind.StackI64],
            ),
            OD_VL,
        ],
        outputs=[OD_EXTRA3_VGPR.with_write_stage(OpStage.Late)],
        is_copy=True,
    )
    _PRE_RA_SIMS[VecCopyToReg] = lambda: OpKind.__veccopytoreg_pre_ra_sim

    @staticmethod
    def __veccopyfromreg_pre_ra_sim(op, state):
        # type: (Op, PreRASimState) -> None
        """ the output gets the value of the first input, unchanged """
        source = op.input_vals[0]
        state.ssa_vals[op.outputs[0]] = state.ssa_vals[source]
    VecCopyFromReg = GenericOpProperties(
        demo_asm="sv.mv dest, src",
        inputs=[OD_EXTRA3_VGPR, OD_VL],
        outputs=[
            GenericOperandDesc(
                ty=GenericTy(BaseTy.I64, is_vec=True),
                sub_kinds=[LocSubKind.SV_EXTRA3_VGPR, LocSubKind.StackI64],
                write_stage=OpStage.Late,
            ),
        ],
        is_copy=True,
    )
    _PRE_RA_SIMS[VecCopyFromReg] = lambda: OpKind.__veccopyfromreg_pre_ra_sim

    @staticmethod
    def __copytoreg_pre_ra_sim(op, state):
        # type: (Op, PreRASimState) -> None
        """ the output gets the value of the only input, unchanged """
        source = op.input_vals[0]
        state.ssa_vals[op.outputs[0]] = state.ssa_vals[source]
    CopyToReg = GenericOpProperties(
        demo_asm="mv dest, src",
        inputs=[
            GenericOperandDesc(
                ty=GenericTy(BaseTy.I64, is_vec=False),
                sub_kinds=[LocSubKind.SV_EXTRA3_SGPR, LocSubKind.BASE_GPR,
                           LocSubKind.StackI64],
            ),
        ],
        outputs=[
            GenericOperandDesc(
                ty=GenericTy(BaseTy.I64, is_vec=False),
                sub_kinds=[LocSubKind.SV_EXTRA3_SGPR, LocSubKind.BASE_GPR],
                write_stage=OpStage.Late,
            ),
        ],
        is_copy=True,
    )
    _PRE_RA_SIMS[CopyToReg] = lambda: OpKind.__copytoreg_pre_ra_sim

    @staticmethod
    def __copyfromreg_pre_ra_sim(op, state):
        # type: (Op, PreRASimState) -> None
        """ the output gets the value of the only input, unchanged """
        source = op.input_vals[0]
        state.ssa_vals[op.outputs[0]] = state.ssa_vals[source]
    CopyFromReg = GenericOpProperties(
        demo_asm="mv dest, src",
        inputs=[
            GenericOperandDesc(
                ty=GenericTy(BaseTy.I64, is_vec=False),
                sub_kinds=[LocSubKind.SV_EXTRA3_SGPR, LocSubKind.BASE_GPR],
            ),
        ],
        outputs=[
            GenericOperandDesc(
                ty=GenericTy(BaseTy.I64, is_vec=False),
                sub_kinds=[LocSubKind.SV_EXTRA3_SGPR, LocSubKind.BASE_GPR,
                           LocSubKind.StackI64],
                write_stage=OpStage.Late,
            ),
        ],
        is_copy=True,
    )
    _PRE_RA_SIMS[CopyFromReg] = lambda: OpKind.__copyfromreg_pre_ra_sim

    @staticmethod
    def __concat_pre_ra_sim(op, state):
        # type: (Op, PreRASimState) -> None
        """ the output is built from element 0 of every input except the
        last one.
        """
        # the last input is VL, it isn't part of the concatenated value
        elements = []  # type: list[int]
        for inp in op.input_vals[:-1]:
            elements.append(state.ssa_vals[inp][0])
        state.ssa_vals[op.outputs[0]] = tuple(elements)
    Concat = GenericOpProperties(
        demo_asm="sv.mv dest, src",
        inputs=[
            GenericOperandDesc(
                ty=GenericTy(BaseTy.I64, is_vec=False),
                sub_kinds=[LocSubKind.SV_EXTRA3_VGPR],
                spread=True,
            ),
            OD_VL,
        ],
        outputs=[OD_EXTRA3_VGPR.with_write_stage(OpStage.Late)],
        is_copy=True,
    )
    _PRE_RA_SIMS[Concat] = lambda: OpKind.__concat_pre_ra_sim

    @staticmethod
    def __spread_pre_ra_sim(op, state):
        # type: (Op, PreRASimState) -> None
        """ element `idx` of the first input becomes the 1-tuple value of
        output `idx`.
        """
        source_elements = state.ssa_vals[op.input_vals[0]]
        for idx, element in enumerate(source_elements):
            state.ssa_vals[op.outputs[idx]] = (element,)
    Spread = GenericOpProperties(
        demo_asm="sv.mv dest, src",
        inputs=[OD_EXTRA3_VGPR, OD_VL],
        outputs=[
            GenericOperandDesc(
                ty=GenericTy(BaseTy.I64, is_vec=False),
                sub_kinds=[LocSubKind.SV_EXTRA3_VGPR],
                spread=True,
                write_stage=OpStage.Late,
            ),
        ],
        is_copy=True,
    )
    _PRE_RA_SIMS[Spread] = lambda: OpKind.__spread_pre_ra_sim

    @staticmethod
    def __svld_pre_ra_sim(op, state):
        # type: (Op, PreRASimState) -> None
        """ load `VL` consecutive register-sized values starting at
        `RA + immediate`.
        """
        RA, = state.ssa_vals[op.input_vals[0]]
        VL, = state.ssa_vals[op.input_vals[1]]
        addr = RA + op.immediates[0]
        state.ssa_vals[op.outputs[0]] = tuple(
            state.load(addr + GPR_SIZE_IN_BYTES * i) & GPR_VALUE_MASK
            for i in range(VL))
    SvLd = GenericOpProperties(
        demo_asm="sv.ld *RT, imm(RA)",
        inputs=[OD_EXTRA3_SGPR, OD_VL],
        outputs=[OD_EXTRA3_VGPR],
        immediates=[IMM_S16],
    )
    _PRE_RA_SIMS[SvLd] = lambda: OpKind.__svld_pre_ra_sim

    @staticmethod
    def __ld_pre_ra_sim(op, state):
        # type: (Op, PreRASimState) -> None
        """ load one register-sized value from `RA + immediate` """
        RA, = state.ssa_vals[op.input_vals[0]]
        loaded = state.load(RA + op.immediates[0])
        state.ssa_vals[op.outputs[0]] = (loaded & GPR_VALUE_MASK,)
    Ld = GenericOpProperties(
        demo_asm="ld RT, imm(RA)",
        inputs=[OD_BASE_SGPR],
        outputs=[OD_BASE_SGPR.with_write_stage(OpStage.Late)],
        immediates=[IMM_S16],
    )
    _PRE_RA_SIMS[Ld] = lambda: OpKind.__ld_pre_ra_sim

    @staticmethod
    def __svstd_pre_ra_sim(op, state):
        # type: (Op, PreRASimState) -> None
        """ store the first `VL` elements of `RS` to consecutive
        register-sized slots starting at `RA + immediate`.
        """
        RS = state.ssa_vals[op.input_vals[0]]
        RA, = state.ssa_vals[op.input_vals[1]]
        VL, = state.ssa_vals[op.input_vals[2]]
        addr = RA + op.immediates[0]
        for i in range(VL):
            element_addr = addr + GPR_SIZE_IN_BYTES * i
            state.store(element_addr, value=RS[i])
    SvStd = GenericOpProperties(
        demo_asm="sv.std *RS, imm(RA)",
        inputs=[OD_EXTRA3_VGPR, OD_EXTRA3_SGPR, OD_VL],
        outputs=[],
        immediates=[IMM_S16],
        has_side_effects=True,
    )
    _PRE_RA_SIMS[SvStd] = lambda: OpKind.__svstd_pre_ra_sim

    @staticmethod
    def __std_pre_ra_sim(op, state):
        # type: (Op, PreRASimState) -> None
        """ store `RS` to `RA + immediate` """
        RS, = state.ssa_vals[op.input_vals[0]]
        RA, = state.ssa_vals[op.input_vals[1]]
        state.store(RA + op.immediates[0], value=RS)
    # NOTE(review): the simulation above and `SvStd`'s demo_asm call the
    # stored register RS, but this demo_asm says RT -- confirm whether the
    # text is intentional before changing it.
    Std = GenericOpProperties(
        demo_asm="std RT, imm(RA)",
        inputs=[OD_BASE_SGPR, OD_BASE_SGPR],
        outputs=[],
        immediates=[IMM_S16],
        has_side_effects=True,
    )
    _PRE_RA_SIMS[Std] = lambda: OpKind.__std_pre_ra_sim

    @staticmethod
    def __funcargr3_pre_ra_sim(op, state):
        # type: (Op, PreRASimState) -> None
        """ does nothing: the return value is set before simulation """
        return None
    FuncArgR3 = GenericOpProperties(
        demo_asm="",
        inputs=[],
        outputs=[
            OD_BASE_SGPR.with_fixed_loc(
                Loc(kind=LocKind.GPR, start=3, reg_len=1)),
        ],
    )
    _PRE_RA_SIMS[FuncArgR3] = lambda: OpKind.__funcargr3_pre_ra_sim
1356
1357
@plain_data(frozen=True, unsafe_hash=True, repr=False)
class SSAValOrUse(metaclass=ABCMeta):
    """ Common base of `SSAVal` and `SSAUse`: one operand of an `Op`,
    identified by the `Op` and the operand's index in `descriptor_array`.

    Subclasses decide which of the `Op`'s descriptor tuples
    `descriptor_array` refers to.
    """
    __slots__ = "op", "operand_idx"

    def __init__(self, op, operand_idx):
        # type: (Op, int) -> None
        super().__init__()
        # `op` has to be assigned first: `descriptor_array` reads it.
        self.op = op
        if not 0 <= operand_idx < len(self.descriptor_array):
            raise ValueError("invalid operand_idx")
        self.operand_idx = operand_idx

    @abstractmethod
    def __repr__(self):
        # type: () -> str
        ...

    @property
    @abstractmethod
    def descriptor_array(self):
        # type: () -> tuple[OperandDesc, ...]
        """ the descriptors `operand_idx` indexes into """
        ...

    @cached_property
    def defining_descriptor(self):
        # type: () -> OperandDesc
        """ the descriptor of this operand """
        descriptors = self.descriptor_array
        return descriptors[self.operand_idx]

    @cached_property
    def ty(self):
        # type: () -> Ty
        """ `ty` of the defining descriptor """
        return self.defining_descriptor.ty

    @cached_property
    def ty_before_spread(self):
        # type: () -> Ty
        """ `ty_before_spread` of the defining descriptor """
        return self.defining_descriptor.ty_before_spread

    @property
    def base_ty(self):
        # type: () -> BaseTy
        """ `base_ty` of `ty_before_spread` """
        return self.ty_before_spread.base_ty

    @property
    def reg_offset_in_unspread(self):
        """ the number of reg-sized slots in the unspread Loc before self's Loc

        e.g. if the unspread Loc containing self is:
        `Loc(kind=LocKind.GPR, start=8, reg_len=4)`
        and self's Loc is `Loc(kind=LocKind.GPR, start=10, reg_len=1)`
        then reg_offset_in_unspread == 2 == 10 - 8
        """
        return self.defining_descriptor.reg_offset_in_unspread

    @property
    def unspread_start_idx(self):
        # type: () -> int
        """ `operand_idx` minus the descriptor's `spread_index`, where a
        `spread_index` of `None` counts as 0.
        """
        spread_index = self.defining_descriptor.spread_index
        if not spread_index:
            # covers both `None` and 0
            return self.operand_idx
        return self.operand_idx - spread_index

    @property
    def unspread_start(self):
        # type: () -> Self
        """ an instance of the same class for the same `Op`, with
        `operand_idx` set to `unspread_start_idx`.
        """
        start_idx = self.unspread_start_idx
        return self.__class__(op=self.op, operand_idx=start_idx)
1421
1422
@plain_data(frozen=True, unsafe_hash=True, repr=False)
@final
class SSAVal(SSAValOrUse):
    """ An output of an `Op`: `operand_idx` indexes into the `Op`'s output
    descriptors.
    """
    __slots__ = ()

    def __repr__(self):
        # type: () -> str
        return f"<{self.op.name}.outputs[{self.operand_idx}]: {self.ty}>"

    @cached_property
    def descriptor_array(self):
        # type: () -> tuple[OperandDesc, ...]
        """ the output descriptors of `self.op` """
        return self.op.properties.outputs

    @cached_property
    def def_loc_set_before_spread(self):
        # type: () -> LocSet
        """ `loc_set_before_spread` of the defining descriptor """
        return self.defining_descriptor.loc_set_before_spread

    @cached_property
    def tied_input(self):
        # type: () -> None | SSAUse
        """ the `SSAUse` of the input this output is tied to, or `None` when
        the descriptor has no `tied_input_index`.
        """
        tied_idx = self.defining_descriptor.tied_input_index
        if tied_idx is not None:
            return SSAUse(op=self.op, operand_idx=tied_idx)
        return None

    @property
    def write_stage(self):
        # type: () -> OpStage
        """ `write_stage` of the defining descriptor """
        return self.defining_descriptor.write_stage
1454
1455
@plain_data(frozen=True, unsafe_hash=True, repr=False)
@final
class SSAUse(SSAValOrUse):
    """ An input slot of an `Op`: `operand_idx` indexes into the `Op`'s input
    descriptors and into `op.input_vals`.
    """
    __slots__ = ()

    def __repr__(self):
        # type: () -> str
        return f"<{self.op.name}.input_uses[{self.operand_idx}]: {self.ty}>"

    @cached_property
    def descriptor_array(self):
        # type: () -> tuple[OperandDesc, ...]
        """ the input descriptors of `self.op` """
        return self.op.properties.inputs

    @cached_property
    def use_loc_set_before_spread(self):
        # type: () -> LocSet
        """ `loc_set_before_spread` of the defining descriptor """
        return self.defining_descriptor.loc_set_before_spread

    @property
    def ssa_val(self):
        # type: () -> SSAVal
        """ the `SSAVal` currently stored in this input slot.

        Assigning to this property stores the new value into
        `op.input_vals` at `operand_idx`.
        """
        input_vals = self.op.input_vals
        return input_vals[self.operand_idx]

    @ssa_val.setter
    def ssa_val(self, ssa_val):
        # type: (SSAVal) -> None
        input_vals = self.op.input_vals
        input_vals[self.operand_idx] = ssa_val
1484
1485
# item type of an `OpInputSeq`
_T = TypeVar("_T")
# descriptor type of an `OpInputSeq`
_Desc = TypeVar("_Desc")


class OpInputSeq(Sequence[_T], Generic[_T, _Desc]):
    """ Fixed-length sequence of items belonging to an `Op`, with exactly one
    item per descriptor.

    Every write, both during construction and through `seq[idx] = item`, is
    checked against the matching descriptor by `_verify_write_with_desc`.
    Only single integer indexes may be written; slices can only be read.

    Subclasses provide `_get_descriptors` and `_verify_write_with_desc`.
    """

    @abstractmethod
    def _verify_write_with_desc(self, idx, item, desc):
        # type: (int, _T | Any, _Desc) -> None
        """ raise if `item` can't be stored at `idx`, whose descriptor is
        `desc`.
        """
        raise NotImplementedError

    @final
    def _verify_write(self, idx, item):
        # type: (int | Any, _T | Any) -> int
        """ check that `item` may be written at `idx`.

        Returns the normalized (non-negative) index.
        Raises `TypeError` if `idx` isn't an `int`, `IndexError` if it is out
        of range, plus whatever `_verify_write_with_desc` raises.
        """
        if isinstance(idx, slice):
            raise TypeError(
                f"can't write to slice of {self.__class__.__name__}")
        if not isinstance(idx, int):
            raise TypeError(f"can't write with index {idx!r}")
        # normalize idx, raising IndexError if it is out of range
        idx = range(len(self.descriptors))[idx]
        self._verify_write_with_desc(idx, item, self.descriptors[idx])
        return idx

    def _on_set(self, idx, new_item, old_item):
        # type: (int, _T, _T | None) -> None
        """ hook for subclasses, does nothing by default.

        NOTE(review): nothing in this class calls this hook -- confirm
        whether `__init__`/`__setitem__` are supposed to.
        """
        pass

    @abstractmethod
    def _get_descriptors(self):
        # type: () -> tuple[_Desc, ...]
        """ return the descriptors, one per item """
        raise NotImplementedError

    @cached_property
    @final
    def descriptors(self):
        # type: () -> tuple[_Desc, ...]
        """ the result of `_get_descriptors`, computed once """
        return self._get_descriptors()

    @property
    @final
    def op(self):
        """ the `op` passed to `__init__` """
        return self.__op

    def __init__(self, items, op):
        # type: (Iterable[_T], Op) -> None
        """ Raises `ValueError` unless `items` has exactly one item per
        descriptor; every item is verified before it is stored.
        """
        super().__init__()
        # must be set before `descriptors` is read: subclasses derive the
        # descriptors from `self.op`.
        self.__op = op
        self.__items = []  # type: list[_T]
        expected_len = len(self.descriptors)
        for idx, item in enumerate(items):
            if idx >= expected_len:
                raise ValueError("too many items")
            self._verify_write(idx, item)
            self.__items.append(item)
        if len(self.__items) < expected_len:
            raise ValueError("not enough items")

    @final
    def __iter__(self):
        # type: () -> Iterator[_T]
        for item in self.__items:
            yield item

    @overload
    def __getitem__(self, idx):
        # type: (int) -> _T
        ...

    @overload
    def __getitem__(self, idx):
        # type: (slice) -> list[_T]
        ...

    @final
    def __getitem__(self, idx):
        # type: (int | slice) -> _T | list[_T]
        return self.__items[idx]

    @final
    def __setitem__(self, idx, item):
        # type: (int, _T) -> None
        normalized_idx = self._verify_write(idx, item)
        self.__items[normalized_idx] = item

    @final
    def __len__(self):
        # type: () -> int
        return len(self.__items)

    def __repr__(self):
        # type: () -> str
        return f"{self.__class__.__name__}({self.__items}, op=...)"
1577
1578
@final
class OpInputVals(OpInputSeq[SSAVal, OperandDesc]):
    """ The input `SSAVal`s of an `Op`, one per input descriptor.

    Every stored item must be an `SSAVal` whose type equals the type of the
    input descriptor at the same index.
    """

    def _get_descriptors(self):
        # type: () -> tuple[OperandDesc, ...]
        return self.op.properties.inputs

    def _verify_write_with_desc(self, idx, item, desc):
        # type: (int, SSAVal | Any, OperandDesc) -> None
        """ Raises `TypeError` if `item` isn't an `SSAVal` and `ValueError`
        if its type differs from `desc.ty`.
        """
        if not isinstance(item, SSAVal):
            raise TypeError("expected value of type SSAVal")
        if item.ty != desc.ty:
            raise ValueError(f"assigned item's type {item.ty!r} doesn't match "
                             f"corresponding input's type {desc.ty!r}")

    def _on_set(self, idx, new_item, old_item):
        # type: (int, SSAVal, SSAVal | None) -> None
        SSAUses._on_op_input_set(self, idx, new_item, old_item)  # type: ignore

    def __init__(self, items, op):
        # type: (Iterable[SSAVal], Op) -> None
        """ Raises `ValueError` if `op` already has its `input_vals` set. """
        # `Op` stores this sequence in its `input_vals` slot and has no
        # `inputs` attribute, so the previous `hasattr(op, "inputs")` check
        # could never fire. Check the real slot, like `OpImmediates` does
        # for `immediates`.
        if hasattr(op, "input_vals"):
            raise ValueError("Op.input_vals already set")
        super().__init__(items, op)
1602
1603
@final
class OpImmediates(OpInputSeq[int, range]):
    """ The immediate operands of an `Op`, one `int` per immediate
    descriptor. Each descriptor is the `range` of values allowed at that
    position.
    """

    def __init__(self, items, op):
        # type: (Iterable[int], Op) -> None
        """ Raises `ValueError` if `op` already has its `immediates` set. """
        already_set = hasattr(op, "immediates")
        if already_set:
            raise ValueError("Op.immediates already set")
        super().__init__(items, op)

    def _get_descriptors(self):
        # type: () -> tuple[range, ...]
        return self.op.properties.immediates

    def _verify_write_with_desc(self, idx, item, desc):
        # type: (int, int | Any, range) -> None
        """ Raises `TypeError` if `item` isn't an `int` and `ValueError` if
        it isn't in `desc`.
        """
        if not isinstance(item, int):
            raise TypeError("expected value of type int")
        if item in desc:
            return
        raise ValueError(f"immediate value {item!r} not in {desc!r}")
1622
1623
@plain_data(frozen=True, eq=False, repr=False)
@final
class Op:
    """ One operation in a `Fn`.

    An `Op` has one `SSAUse` per input descriptor (`input_uses`), one
    `SSAVal` per output descriptor (`outputs`), the input values and
    immediates it was created with, and a name obtained from its `Fn`.

    `Op`s compare and hash by identity.
    """
    __slots__ = ("fn", "properties", "input_vals", "input_uses", "immediates",
                 "outputs", "name")

    def __init__(self, fn, properties, input_vals, immediates, name=""):
        # type: (Fn, OpProperties, Iterable[SSAVal], Iterable[int], str) -> None
        self.fn = fn
        self.properties = properties
        # the order of the following assignments matters: the sequences and
        # uses/values created here read `self.properties` through `self`.
        self.input_vals = OpInputVals(input_vals, op=self)
        self.input_uses = tuple(
            SSAUse(self, idx) for idx, _ in enumerate(properties.inputs))
        self.immediates = OpImmediates(immediates, op=self)
        self.outputs = tuple(
            SSAVal(self, idx) for idx, _ in enumerate(properties.outputs))
        self.name = fn._add_op_with_unused_name(self, name)  # type: ignore

    @property
    def kind(self):
        # type: () -> OpKind
        """ the `OpKind` of `self.properties` """
        return self.properties.kind

    def __eq__(self, other):
        # type: (Op | Any) -> bool
        if not isinstance(other, Op):
            return NotImplemented
        return other is self

    def __hash__(self):
        # type: () -> int
        return object.__hash__(self)

    def __repr__(self):
        # type: () -> str
        """ list all fields except `fn`; `properties` is shown as `kind`,
        and fields that aren't assigned yet are shown as `<not set>`.
        """
        field_vals = []  # type: list[str]
        for name in fields(self):
            if name == "fn":
                continue
            if name == "properties":
                name = "kind"
            try:
                value = getattr(self, name)
            except AttributeError:
                field_vals.append(f"{name}=<not set>")
            else:
                if isinstance(value, OpInputSeq):
                    # show just the items instead of the sequence's repr
                    value = list(value)  # type: ignore
                field_vals.append(f"{name}={value!r}")
        field_vals_str = ", ".join(field_vals)
        return f"Op({field_vals_str})"

    def pre_ra_sim(self, state):
        # type: (PreRASimState) -> None
        """ simulate this `Op`, updating `state` in place.

        Raises `ValueError` if an input isn't assigned yet, if an output is
        already assigned (allowed only for `OpKind.FuncArgR3`), if the
        simulation leaves an output unassigned, or if any input/output value
        doesn't have `ty.reg_len` elements.
        """
        for inp in self.input_vals:
            if inp not in state.ssa_vals:
                raise ValueError(f"SSAVal {inp} not yet assigned when "
                                 f"running {self}")
            if len(state.ssa_vals[inp]) == inp.ty.reg_len:
                continue
            raise ValueError(
                f"value of SSAVal {inp} has wrong number of elements: "
                f"expected {inp.ty.reg_len} found "
                f"{len(state.ssa_vals[inp])}: {state.ssa_vals[inp]!r}")
        # FuncArgR3's output is allowed to be assigned before simulation
        outputs_may_be_preassigned = self.kind is OpKind.FuncArgR3
        for out in self.outputs:
            if out in state.ssa_vals and not outputs_may_be_preassigned:
                raise ValueError(f"SSAVal {out} already assigned before "
                                 f"running {self}")
        self.kind.pre_ra_sim(self, state)
        for out in self.outputs:
            if out not in state.ssa_vals:
                raise ValueError(f"running {self} failed to assign to {out}")
            if len(state.ssa_vals[out]) == out.ty.reg_len:
                continue
            raise ValueError(
                f"value of SSAVal {out} has wrong number of elements: "
                f"expected {out.ty.reg_len} found "
                f"{len(state.ssa_vals[out])}: {state.ssa_vals[out]!r}")
1702
1703
BITS_IN_BYTE = 8
# size of one general-purpose register
GPR_SIZE_IN_BYTES = 8
GPR_SIZE_IN_BITS = BITS_IN_BYTE * GPR_SIZE_IN_BYTES
# mask with all GPR_SIZE_IN_BITS low bits set
GPR_VALUE_MASK = 2 ** GPR_SIZE_IN_BITS - 1
1708
1709
@plain_data(frozen=True, repr=False)
@final
class PreRASimState:
    """ State used when simulating `Op`s before register allocation.

    `ssa_vals` maps each assigned `SSAVal` to a tuple of its elements.
    `memory` maps byte addresses to byte values; addresses that were never
    written read as 0.
    """
    __slots__ = "ssa_vals", "memory"

    def __init__(self, ssa_vals, memory):
        # type: (dict[SSAVal, tuple[int, ...]], dict[int, int]) -> None
        self.ssa_vals = ssa_vals  # type: dict[SSAVal, tuple[int, ...]]
        self.memory = memory  # type: dict[int, int]

    def load_byte(self, addr):
        # type: (int) -> int
        """ read one byte; `addr` is wrapped to the register width. """
        addr &= GPR_VALUE_MASK
        return self.memory.get(addr, 0) & 0xFF

    def store_byte(self, addr, value):
        # type: (int, int) -> None
        """ write the low 8 bits of `value`; `addr` is wrapped to the
        register width.
        """
        addr &= GPR_VALUE_MASK
        value &= 0xFF
        self.memory[addr] = value

    def load(self, addr, size_in_bytes=GPR_SIZE_IN_BYTES, signed=False):
        # type: (int, int, bool) -> int
        """ read a little-endian integer of `size_in_bytes` bytes.

        If `signed` is set the value is interpreted as two's complement.
        Raises `ValueError` if `addr` isn't a multiple of `size_in_bytes`.
        """
        if addr % size_in_bytes != 0:
            raise ValueError(f"address not aligned: {hex(addr)} "
                             f"required alignment: {size_in_bytes}")
        retval = 0
        for i in range(size_in_bytes):
            retval |= self.load_byte(addr + i) << i * BITS_IN_BYTE
        # sign bit set: subtract 2 ** bit-width to get the negative value
        if signed and retval >> (size_in_bytes * BITS_IN_BYTE - 1) != 0:
            retval -= 1 << size_in_bytes * BITS_IN_BYTE
        return retval

    def store(self, addr, value, size_in_bytes=GPR_SIZE_IN_BYTES):
        # type: (int, int, int) -> None
        """ write the low `size_in_bytes` bytes of `value` in little-endian
        order.

        Raises `ValueError` if `addr` isn't a multiple of `size_in_bytes`.
        """
        if addr % size_in_bytes != 0:
            raise ValueError(f"address not aligned: {hex(addr)} "
                             f"required alignment: {size_in_bytes}")
        for i in range(size_in_bytes):
            self.store_byte(addr + i, (value >> i * BITS_IN_BYTE) & 0xFF)

    def _memory__repr(self):
        # type: () -> str
        """ repr of `memory`, used by `__repr__`.

        Addresses are listed in ascending order. A run of `CHUNK_SIZE`
        consecutive addresses starting at an aligned address is shown as one
        `<0x...>` value, all other bytes are shown individually.
        """
        if len(self.memory) == 0:
            return "{}"
        # sorted descending, so the lowest remaining address is always last
        # and can be removed cheaply from the end.
        keys = sorted(self.memory.keys(), reverse=True)
        CHUNK_SIZE = GPR_SIZE_IN_BYTES
        items = []  # type: list[str]
        while len(keys) != 0:
            addr = keys[-1]
            if (len(keys) >= CHUNK_SIZE
                    and addr % CHUNK_SIZE == 0
                    and keys[-CHUNK_SIZE:]
                    == list(reversed(range(addr, addr + CHUNK_SIZE)))):
                value = self.load(addr, size_in_bytes=CHUNK_SIZE)
                items.append(f"0x{addr:05x}: <0x{value:0{CHUNK_SIZE * 2}x}>")
                keys[-CHUNK_SIZE:] = ()
            else:
                items.append(f"0x{addr:05x}: 0x{self.memory[keys.pop()]:02x}")
        if len(items) == 1:
            return f"{{{items[0]}}}"
        items_str = ",\n".join(items)
        return f"{{\n{items_str}}}"

    def _ssa_vals__repr(self):
        # type: () -> str
        """ repr of `ssa_vals`, used by `__repr__`.

        Each value is shown as a tuple of hex numbers, `CHUNK_SIZE` per line
        when it has more than `CHUNK_SIZE` elements.
        """
        if len(self.ssa_vals) == 0:
            return "{}"
        items = []  # type: list[str]
        CHUNK_SIZE = 4
        for k, v in self.ssa_vals.items():
            element_strs = []  # type: list[str]
            for i, el in enumerate(v):
                if i % CHUNK_SIZE != 0:
                    element_strs.append(" " + hex(el))
                else:
                    element_strs.append("\n " + hex(el))
            # a value that fits on one line doesn't need the leading newline.
            # an empty value has no first element to strip: it is shown as
            # `()` instead of raising IndexError.
            if element_strs and len(element_strs) <= CHUNK_SIZE:
                element_strs[0] = element_strs[0].lstrip()
            if len(element_strs) == 1:
                # produces the trailing comma of a 1-tuple
                element_strs.append("")
            v_str = ",".join(element_strs)
            items.append(f"{k!r}: ({v_str})")
        if len(items) == 1 and "\n" not in items[0]:
            return f"{{{items[0]}}}"
        items_str = ",\n".join(items)
        return f"{{\n{items_str},\n}}"

    def __repr__(self):
        # type: () -> str
        """ fields are formatted with the method `_<field>__repr` when it
        exists, otherwise with `repr()`; fields that aren't assigned yet are
        shown as `<not set>`.
        """
        field_vals = []  # type: list[str]
        for name in fields(self):
            try:
                value = getattr(self, name)
            except AttributeError:
                field_vals.append(f"{name}=<not set>")
                continue
            repr_fn = getattr(self, f"_{name}__repr", None)
            if callable(repr_fn):
                field_vals.append(f"{name}={repr_fn()}")
            else:
                field_vals.append(f"{name}={value!r}")
        field_vals_str = ", ".join(field_vals)
        return f"PreRASimState({field_vals_str})"