get LocSet hash working correctly
[bigint-presentation-code.git] / src / bigint_presentation_code / compiler_ir.py
1 import enum
2 from abc import ABCMeta, abstractmethod
3 from enum import Enum, unique
4 from functools import lru_cache, total_ordering
5 from io import StringIO
6 from typing import (AbstractSet, Any, Callable, Generic, Iterable, Iterator,
7 Mapping, Sequence, TypeVar, Union, overload)
8 from weakref import WeakValueDictionary as _WeakVDict
9
10 from cached_property import cached_property
11 from nmutil.plain_data import fields, plain_data
12
13 from bigint_presentation_code.type_util import (Literal, Self, assert_never,
14 final)
15 from bigint_presentation_code.util import (BitSet, FBitSet, FMap, InternedMeta,
16 OFSet, OSet)
17
18
@final
class Fn:
    """A function under compilation: the Ops in `ops` (program order) plus
    the bookkeeping used to hand out unique Op names.
    """

    def __init__(self):
        # type: () -> None
        # the function's Ops, in program order
        self.ops = []  # type: list[Op]
        # names already handed out; a WeakValueDictionary, so an entry goes
        # away once its Op is no longer referenced anywhere else
        self.__op_names = _WeakVDict()  # type: _WeakVDict[str, Op]
        # next numeric suffix tried when a requested name is empty or taken;
        # one counter shared by every name in this Fn
        self.__next_name_suffix = 2

    def _add_op_with_unused_name(self, op, name=""):
        # type: (Op, str) -> str
        """Reserve a unique name for `op` and return it.

        `name` is used unchanged if it is non-empty and unused; otherwise
        `name` followed by the next numeric suffix is tried until an unused
        name is found.

        Raises ValueError if `op` belongs to a different Fn or already has a
        `name` attribute.
        """
        if op.fn is not self:
            raise ValueError("can't add Op to wrong Fn")
        if hasattr(op, "name"):
            raise ValueError("Op already named")
        orig_name = name
        while True:
            if name != "" and name not in self.__op_names:
                self.__op_names[name] = op
                return name
            name = orig_name + str(self.__next_name_suffix)
            self.__next_name_suffix += 1

    def __repr__(self):
        # type: () -> str
        return "<Fn>"

    def append_op(self, op):
        # type: (Op) -> None
        """Append `op` to the end of `ops`.

        Raises ValueError if `op` belongs to a different Fn.
        """
        if op.fn is not self:
            raise ValueError("can't add Op to wrong Fn")
        self.ops.append(op)

    def append_new_op(self, kind, input_vals=(), immediates=(), name="",
                      maxvl=1):
        # type: (OpKind, Iterable[SSAVal], Iterable[int], str, int) -> Op
        """Create an Op of `kind` (instantiated for `maxvl`), append it to
        `ops` and return it.
        """
        retval = Op(fn=self, properties=kind.instantiate(maxvl=maxvl),
                    input_vals=input_vals, immediates=immediates, name=name)
        self.append_op(retval)
        return retval

    def sim(self, state):
        # type: (BaseSimState) -> None
        """Simulate every Op in `ops`, in order, against `state`."""
        for op in self.ops:
            op.sim(state)

    def gen_asm(self, state):
        # type: (GenAsmState) -> None
        """Generate assembly for every Op in `ops`, in order, into `state`."""
        for op in self.ops:
            op.gen_asm(state)

    def pre_ra_insert_copies(self):
        # type: () -> None
        """Rebuild `ops`, inserting copy Ops around every original Op.

        For each original Op, in order:

        * each input is first mapped through the table of already-copied
          outputs (I64 outputs map to their copy, CA / VL_MAXVL outputs map
          to themselves), then:
          - I64: a copy is emitted right before the Op (`CopyToReg`, or a
            `SetVLI` followed by `VecCopyToReg` for vectors) and the Op's
            input is replaced by the copy's output;
          - CA / VL_MAXVL: no copy; if the value came from a `SetVLI`, a new
            `SetVLI` with the same immediates is emitted right before the Op
            and the Op uses its output instead;
        * the Op itself is appended;
        * each I64 output gets a copy right after the Op (`CopyFromReg`, or
          `SetVLI` + `VecCopyFromReg` for vectors), which later inputs use.

        The Ops' `input_vals` are modified in place.
        NOTE(review): the name suggests this runs before register allocation
        so the allocator has copy points to split at — confirm.

        Raises KeyError if an Op uses an SSAVal that no earlier Op in `ops`
        produced.
        """
        orig_ops = list(self.ops)
        # original SSAVal -> the value later uses should read instead
        copied_outputs = {}  # type: dict[SSAVal, SSAVal]
        # VL values produced by original SetVLI ops -> the producing op
        setvli_outputs = {}  # type: dict[SSAVal, Op]
        self.ops.clear()
        for op in orig_ops:
            for i in range(len(op.input_vals)):
                inp = copied_outputs[op.input_vals[i]]
                if inp.ty.base_ty is BaseTy.I64:
                    maxvl = inp.ty.reg_len
                    if inp.ty.reg_len != 1:
                        setvl = self.append_new_op(
                            OpKind.SetVLI, immediates=[maxvl],
                            name=f"{op.name}.inp{i}.setvl")
                        vl = setvl.outputs[0]
                        mv = self.append_new_op(
                            OpKind.VecCopyToReg, input_vals=[inp, vl],
                            maxvl=maxvl, name=f"{op.name}.inp{i}.copy")
                    else:
                        mv = self.append_new_op(
                            OpKind.CopyToReg, input_vals=[inp],
                            name=f"{op.name}.inp{i}.copy")
                    op.input_vals[i] = mv.outputs[0]
                elif inp.ty.base_ty is BaseTy.CA \
                        or inp.ty.base_ty is BaseTy.VL_MAXVL:
                    # all copies would be no-ops, so we don't need to copy,
                    # though we do need to rematerialize SetVLI ops right
                    # before the op that uses the VL
                    if inp in setvli_outputs:
                        setvl = self.append_new_op(
                            OpKind.SetVLI,
                            immediates=setvli_outputs[inp].immediates,
                            name=f"{op.name}.inp{i}.setvl")
                        inp = setvl.outputs[0]
                    op.input_vals[i] = inp
                else:
                    # unreachable for the BaseTy members defined in this file
                    assert_never(inp.ty.base_ty)
            self.ops.append(op)
            for i, out in enumerate(op.outputs):
                if op.kind is OpKind.SetVLI:
                    setvli_outputs[out] = op
                if out.ty.base_ty is BaseTy.I64:
                    maxvl = out.ty.reg_len
                    if out.ty.reg_len != 1:
                        setvl = self.append_new_op(
                            OpKind.SetVLI, immediates=[maxvl],
                            name=f"{op.name}.out{i}.setvl")
                        vl = setvl.outputs[0]
                        mv = self.append_new_op(
                            OpKind.VecCopyFromReg, input_vals=[out, vl],
                            maxvl=maxvl, name=f"{op.name}.out{i}.copy")
                    else:
                        mv = self.append_new_op(
                            OpKind.CopyFromReg, input_vals=[out],
                            name=f"{op.name}.out{i}.copy")
                    copied_outputs[out] = mv.outputs[0]
                elif out.ty.base_ty is BaseTy.CA \
                        or out.ty.base_ty is BaseTy.VL_MAXVL:
                    # all copies would be no-ops, so we don't need to copy
                    copied_outputs[out] = out
                else:
                    # unreachable for the BaseTy members defined in this file
                    assert_never(out.ty.base_ty)
131
132
@final
@unique
@total_ordering
class OpStage(Enum):
    """The two stages of an Op's execution, ordered `Early < Late`.

    Whatever the `write_stage` settings, no output (even an unused one) may
    share a register with any other output.
    """
    value: Literal[0, 1]  # type: ignore

    def __new__(cls, value):
        # type: (int) -> OpStage
        # only 0 and 1 are acceptable member values
        as_int = int(value)
        if as_int not in (0, 1):
            raise ValueError("invalid value")
        member = object.__new__(cls)
        member._value_ = as_int
        return member

    Early = 0
    """ the stage in which every input is read. Outputs with
    `write_stage == Early` are also written here, so they conflict with the
    input reads: such an output can't share a register with any input it
    isn't tied to.
    """
    Late = 1
    """ the stage in which outputs with `write_stage == Late` are written.
    Those writes don't conflict with input reads, so inputs may safely share
    registers with such outputs.
    """

    def __repr__(self):
        # type: () -> str
        return "OpStage." + self._name_

    def __lt__(self, other):
        # type: (OpStage | object) -> bool
        if not isinstance(other, OpStage):
            return NotImplemented
        return self.value < other.value


assert OpStage.Early < OpStage.Late, "early must be less than late"
179
180
@plain_data(frozen=True, unsafe_hash=True, repr=False)
@final
@total_ordering
class ProgramPoint(metaclass=InternedMeta):
    """A point in a program: stage `stage` of the Op at index `op_index`."""
    __slots__ = "op_index", "stage"

    def __init__(self, op_index, stage):
        # type: (int, OpStage) -> None
        self.op_index = op_index
        self.stage = stage

    @property
    def int_value(self):
        # type: () -> int
        """ `2 * op_index + stage.value`: an integer encoding of `self` that
        preserves ordering, where +1 / -1 give the successor / predecessor.
        """
        return 2 * self.op_index + self.stage.value

    @staticmethod
    def from_int_value(int_value):
        # type: (int) -> ProgramPoint
        """ inverse of `int_value` """
        index, stage_value = divmod(int_value, 2)
        return ProgramPoint(op_index=index, stage=OpStage(stage_value))

    def next(self, steps=1):
        # type: (int) -> ProgramPoint
        """ the program point `steps` positions after `self` """
        return ProgramPoint.from_int_value(self.int_value + steps)

    def prev(self, steps=1):
        # type: (int) -> ProgramPoint
        """ the program point `steps` positions before `self` """
        return ProgramPoint.from_int_value(self.int_value - steps)

    def __lt__(self, other):
        # type: (ProgramPoint | Any) -> bool
        if isinstance(other, ProgramPoint):
            # lexicographic: op index first, then stage within the op
            return (self.op_index, self.stage) < (other.op_index, other.stage)
        return NotImplemented

    def __repr__(self):
        # type: () -> str
        return f"<ops[{self.op_index}]:{self.stage._name_}>"
223 # type: () -> str
224 return f"<ops[{self.op_index}]:{self.stage._name_}>"
225
226
@plain_data(frozen=True, unsafe_hash=True, repr=False)
@final
class ProgramRange(Sequence[ProgramPoint], metaclass=InternedMeta):
    """ the half-open range of `ProgramPoint`s `[start, stop)`, usable as a
    `Sequence[ProgramPoint]`.
    """
    __slots__ = "start", "stop"

    def __init__(self, start, stop):
        # type: (ProgramPoint, ProgramPoint) -> None
        self.start = start
        self.stop = stop

    @cached_property
    def int_value_range(self):
        # type: () -> range
        """ the `ProgramPoint.int_value`s covered by `self` """
        return range(self.start.int_value, self.stop.int_value)

    @staticmethod
    def from_int_value_range(int_value_range):
        # type: (range) -> ProgramRange
        """ inverse of `int_value_range`.

        Raises ValueError if `int_value_range.step != 1`.
        """
        if int_value_range.step != 1:
            raise ValueError("int_value_range must have step == 1")
        return ProgramRange(
            start=ProgramPoint.from_int_value(int_value_range.start),
            stop=ProgramPoint.from_int_value(int_value_range.stop))

    @overload
    def __getitem__(self, __idx):
        # type: (int) -> ProgramPoint
        ...

    @overload
    def __getitem__(self, __idx):
        # type: (slice) -> ProgramRange
        ...

    def __getitem__(self, __idx):
        # type: (int | slice) -> ProgramPoint | ProgramRange
        """ index or slice exactly like the underlying `int_value_range`.

        Raises IndexError for out-of-range indexes, and ValueError for slices
        whose resulting step isn't 1.
        """
        v = self.int_value_range[__idx]
        if isinstance(v, int):
            return ProgramPoint.from_int_value(v)
        return ProgramRange.from_int_value_range(v)

    def __len__(self):
        # type: () -> int
        return len(self.int_value_range)

    def __iter__(self):
        # type: () -> Iterator[ProgramPoint]
        return map(ProgramPoint.from_int_value, self.int_value_range)

    def __contains__(self, item):
        # type: (ProgramPoint | Any) -> bool
        """ O(1) membership test; `Sequence`'s default would iterate and
        construct every `ProgramPoint` in the range.
        """
        if not isinstance(item, ProgramPoint):
            return False
        return item.int_value in self.int_value_range

    def __repr__(self):
        # type: () -> str
        start = repr(self.start).lstrip("<").rstrip(">")
        stop = repr(self.stop).lstrip("<").rstrip(">")
        return f"<range:{start}..{stop}>"
279 stop = repr(self.stop).lstrip("<").rstrip(">")
280 return f"<range:{start}..{stop}>"
281
282
@plain_data(frozen=True, eq=False, repr=False)
@final
class FnAnalysis:
    """Def/use and liveness information for a `Fn`, computed in `__init__`.

    Fields:
    * `fn`: the analyzed Fn.
    * `op_indexes`: Op -> its index in `fn.ops`.
    * `all_program_points`: from the Early stage of op 0 up to (excluding)
      the Early stage of op index `len(fn.ops)`.
    * `uses`: SSAVal -> OFSet of its SSAUses.
    * `def_program_ranges`: SSAVal -> the range in which it is being
      defined (see `__get_def_program_range`).
    * `use_program_points`: SSAUse -> the point at which it is read.
    * `live_ranges`: SSAVal -> range from its definition start up to the
      later of its definition range's stop and one past its last use.
    * `live_at`: ProgramPoint -> OFSet of SSAVals live there.

    Equality and hashing are based on `fn` only.
    """
    __slots__ = ("fn", "uses", "op_indexes", "live_ranges", "live_at",
                 "def_program_ranges", "use_program_points",
                 "all_program_points")

    def __init__(self, fn):
        # type: (Fn) -> None
        """Analyze `fn`.

        Raises KeyError if an Op uses an SSAVal that no earlier Op in
        `fn.ops` defines.
        """
        self.fn = fn
        self.op_indexes = FMap((op, idx) for idx, op in enumerate(fn.ops))
        self.all_program_points = ProgramRange(
            start=ProgramPoint(op_index=0, stage=OpStage.Early),
            stop=ProgramPoint(op_index=len(fn.ops), stage=OpStage.Early))
        def_program_ranges = {}  # type: dict[SSAVal, ProgramRange]
        use_program_points = {}  # type: dict[SSAUse, ProgramPoint]
        uses = {}  # type: dict[SSAVal, OSet[SSAUse]]
        # exclusive end of each SSAVal's live range found so far
        live_range_stops = {}  # type: dict[SSAVal, ProgramPoint]
        for op in fn.ops:
            # inputs before outputs: `uses` only has entries for outputs of
            # earlier Ops, so using anything else raises KeyError here
            for use in op.input_uses:
                uses[use.ssa_val].add(use)
                use_program_point = self.__get_use_program_point(use)
                use_program_points[use] = use_program_point
                # the value must stay live through the point it's read at
                live_range_stops[use.ssa_val] = max(
                    live_range_stops[use.ssa_val], use_program_point.next())
            for out in op.outputs:
                uses[out] = OSet()
                def_program_range = self.__get_def_program_range(out)
                def_program_ranges[out] = def_program_range
                live_range_stops[out] = def_program_range.stop
        self.uses = FMap((k, OFSet(v)) for k, v in uses.items())
        self.def_program_ranges = FMap(def_program_ranges)
        self.use_program_points = FMap(use_program_points)
        live_ranges = {}  # type: dict[SSAVal, ProgramRange]
        live_at = {i: OSet[SSAVal]() for i in self.all_program_points}
        for ssa_val in uses.keys():
            live_ranges[ssa_val] = live_range = ProgramRange(
                start=self.def_program_ranges[ssa_val].start,
                stop=live_range_stops[ssa_val])
            for program_point in live_range:
                live_at[program_point].add(ssa_val)
        self.live_ranges = FMap(live_ranges)
        self.live_at = FMap((k, OFSet(v)) for k, v in live_at.items())

    def __get_def_program_range(self, ssa_val):
        # type: (SSAVal) -> ProgramRange
        """Return the range in which `ssa_val` is being defined: from its
        descriptor's `write_stage` of the defining Op up to and including
        that Op's Late stage.
        """
        write_stage = ssa_val.defining_descriptor.write_stage
        start = ProgramPoint(
            op_index=self.op_indexes[ssa_val.op], stage=write_stage)
        # always include late stage of ssa_val.op, to ensure outputs always
        # overlap all other outputs.
        # stop is exclusive, so we need the next program point.
        stop = ProgramPoint(op_index=start.op_index, stage=OpStage.Late).next()
        return ProgramRange(start=start, stop=stop)

    def __get_use_program_point(self, ssa_use):
        # type: (SSAUse) -> ProgramPoint
        """Return the point at which `ssa_use` is read: the Early stage of
        the using Op.
        """
        assert ssa_use.defining_descriptor.write_stage is OpStage.Early, \
            "assumed here, ensured by GenericOpProperties.__init__"
        return ProgramPoint(
            op_index=self.op_indexes[ssa_use.op], stage=OpStage.Early)

    def __eq__(self, other):
        # type: (FnAnalysis | Any) -> bool
        if isinstance(other, FnAnalysis):
            return self.fn == other.fn
        return NotImplemented

    def __hash__(self):
        # type: () -> int
        return hash(self.fn)

    def __repr__(self):
        # type: () -> str
        return "<FnAnalysis>"
356 # type: () -> str
357 return "<FnAnalysis>"
358
359
@unique
@final
class BaseTy(Enum):
    """The base (element) type of a value, independent of vector length."""
    I64 = enum.auto()
    CA = enum.auto()
    VL_MAXVL = enum.auto()

    @cached_property
    def only_scalar(self):
        # type: () -> bool
        """True if values of this base type can never be vectors."""
        if self is BaseTy.CA or self is BaseTy.VL_MAXVL:
            return True
        if self is BaseTy.I64:
            return False
        assert_never(self)

    @cached_property
    def max_reg_len(self):
        # type: () -> int
        """The largest `reg_len` a `Ty` of this base type may have."""
        if self is BaseTy.CA or self is BaseTy.VL_MAXVL:
            return 1
        if self is BaseTy.I64:
            return 128
        assert_never(self)

    def __repr__(self):
        # type: () -> str
        return f"BaseTy.{self._name_}"
387 def __repr__(self):
388 return "BaseTy." + self._name_
389
390
@plain_data(frozen=True, unsafe_hash=True, repr=False)
@final
class Ty(metaclass=InternedMeta):
    """A concrete value type: a `BaseTy` plus the number of registers
    (`reg_len`) the value occupies.
    """
    __slots__ = "base_ty", "reg_len"

    @staticmethod
    def validate(base_ty, reg_len):
        # type: (BaseTy, int) -> str | None
        """ describe why `base_ty`/`reg_len` is an invalid combination, or
        return None if it's valid
        """
        if base_ty.only_scalar and reg_len != 1:
            return f"can't create a vector of an only-scalar type: {base_ty}"
        if not 1 <= reg_len <= base_ty.max_reg_len:
            return "reg_len out of range"
        return None

    def __init__(self, base_ty, reg_len):
        # type: (BaseTy, int) -> None
        """ Raises ValueError if the combination fails `validate` """
        problem = self.validate(base_ty=base_ty, reg_len=reg_len)
        if problem is not None:
            raise ValueError(problem)
        self.base_ty = base_ty
        self.reg_len = reg_len

    def __repr__(self):
        # type: () -> str
        suffix = "" if self.reg_len == 1 else f"*{self.reg_len}"
        return f"<{self.base_ty._name_}{suffix}>"
421 reg_len = ""
422 return f"<{self.base_ty._name_}{reg_len}>"
423
424
@unique
@final
class LocKind(Enum):
    """The kind of storage a `Loc` refers to."""
    GPR = enum.auto()
    StackI64 = enum.auto()
    CA = enum.auto()
    VL_MAXVL = enum.auto()

    @cached_property
    def base_ty(self):
        # type: () -> BaseTy
        """The `BaseTy` held by locations of this kind."""
        if self is LocKind.GPR or self is LocKind.StackI64:
            return BaseTy.I64
        elif self is LocKind.CA:
            return BaseTy.CA
        elif self is LocKind.VL_MAXVL:
            return BaseTy.VL_MAXVL
        assert_never(self)

    @cached_property
    def loc_count(self):
        # type: () -> int
        """How many distinct locations of this kind exist."""
        if self is LocKind.GPR or self is LocKind.CA \
                or self is LocKind.VL_MAXVL:
            return self.base_ty.max_reg_len
        elif self is LocKind.StackI64:
            return 512
        assert_never(self)

    def __repr__(self):
        # type: () -> str
        return f"LocKind.{self._name_}"
456 def __repr__(self):
457 return "LocKind." + self._name_
458
459
@final
@unique
class LocSubKind(Enum):
    """A restriction of a `LocKind`: `allocatable_locs` limits which start
    positions may be used (e.g. `BASE_GPR` only allows GPR starts 0..31).

    NOTE(review): the GPR sub-kinds appear to model the register fields
    reachable from different instruction encodings (base, SVP64 EXTRA2 /
    EXTRA3) — inferred from the names, confirm.
    """
    BASE_GPR = enum.auto()
    SV_EXTRA2_VGPR = enum.auto()
    SV_EXTRA2_SGPR = enum.auto()
    SV_EXTRA3_VGPR = enum.auto()
    SV_EXTRA3_SGPR = enum.auto()
    StackI64 = enum.auto()
    CA = enum.auto()
    VL_MAXVL = enum.auto()

    @cached_property
    def kind(self):
        # type: () -> LocKind
        """The `LocKind` this sub-kind restricts."""
        # pyright fails typechecking when using `in` here:
        # reported: https://github.com/microsoft/pyright/issues/4102
        if self in (LocSubKind.BASE_GPR, LocSubKind.SV_EXTRA2_VGPR,
                    LocSubKind.SV_EXTRA2_SGPR, LocSubKind.SV_EXTRA3_VGPR,
                    LocSubKind.SV_EXTRA3_SGPR):
            return LocKind.GPR
        elif self is LocSubKind.StackI64:
            return LocKind.StackI64
        elif self is LocSubKind.CA:
            return LocKind.CA
        elif self is LocSubKind.VL_MAXVL:
            return LocKind.VL_MAXVL
        assert_never(self)

    @property
    def base_ty(self):
        # type: () -> BaseTy
        return self.kind.base_ty

    @lru_cache()
    def allocatable_locs(self, ty):
        # type: (Ty) -> LocSet
        """Every Loc of type `ty` this sub-kind may use, in start order.

        CA and VL_MAXVL always give the single Loc at start 0. Otherwise the
        Loc must begin at one of this sub-kind's permitted starts, fit inside
        its LocKind, and not overlap any of SPECIAL_GPRS. Results are cached.

        Raises ValueError if `ty.base_ty` doesn't match this sub-kind.
        """
        if ty.base_ty != self.base_ty:
            raise ValueError("type mismatch")
        if self is LocSubKind.CA or self is LocSubKind.VL_MAXVL:
            return LocSet([Loc(kind=self.kind, start=0, reg_len=1)])
        if self is LocSubKind.BASE_GPR:
            starts = range(32)
        elif self is LocSubKind.SV_EXTRA2_VGPR:
            starts = range(0, 128, 2)
        elif self is LocSubKind.SV_EXTRA2_SGPR:
            starts = range(64)
        elif self is LocSubKind.SV_EXTRA3_VGPR \
                or self is LocSubKind.SV_EXTRA3_SGPR:
            starts = range(128)
        elif self is LocSubKind.StackI64:
            starts = range(LocKind.StackI64.loc_count)
        else:
            assert_never(self)

        def usable(candidate):
            # type: (Loc | None) -> bool
            # None means the Loc didn't fit; reserved GPRs are never handed out
            if candidate is None:
                return False
            return not any(candidate.conflicts(special)
                           for special in SPECIAL_GPRS)

        candidates = (
            Loc.try_make(kind=self.kind, start=start, reg_len=ty.reg_len)
            for start in starts)
        return LocSet(loc for loc in candidates if usable(loc))

    def __repr__(self):
        # type: () -> str
        return f"LocSubKind.{self._name_}"
527 def __repr__(self):
528 return "LocSubKind." + self._name_
529
530
@plain_data(frozen=True, unsafe_hash=True)
@final
class GenericTy(metaclass=InternedMeta):
    """A type whose vector length is left open: a `BaseTy` plus whether it is
    a vector. `instantiate` fixes the length.
    """
    __slots__ = "base_ty", "is_vec"

    def __init__(self, base_ty, is_vec):
        # type: (BaseTy, bool) -> None
        """ Raises ValueError if `is_vec` is set for an only-scalar type """
        self.base_ty = base_ty
        if base_ty.only_scalar and is_vec:
            raise ValueError(f"base_ty={base_ty} requires is_vec=False")
        self.is_vec = is_vec

    def instantiate(self, maxvl):
        # type: (int) -> Ty
        """ the concrete Ty: `maxvl` registers for vectors, else 1 """
        # here's where subvl and elwid would be accounted for
        reg_len = maxvl if self.is_vec else 1
        return Ty(self.base_ty, reg_len)

    def can_instantiate_to(self, ty):
        # type: (Ty) -> bool
        """ True if some `maxvl` makes `instantiate` produce `ty` """
        if self.base_ty != ty.base_ty:
            return False
        return True if self.is_vec else ty.reg_len == 1
555 return True
556 return ty.reg_len == 1
557
558
@plain_data(frozen=True, unsafe_hash=True)
@final
class Loc(metaclass=InternedMeta):
    """A contiguous run of `reg_len` locations of one `LocKind`, starting at
    index `start` (so covering `[start, stop)`).
    """
    __slots__ = "kind", "start", "reg_len"

    @staticmethod
    def validate(kind, start, reg_len):
        # type: (LocKind, int, int) -> str | None
        """Return a description of why `kind`/`start`/`reg_len` don't form a
        valid Loc, or None if they do.
        """
        msg = Ty.validate(base_ty=kind.base_ty, reg_len=reg_len)
        if msg is not None:
            return msg
        if reg_len > kind.loc_count:
            return "invalid reg_len"
        if start < 0 or start + reg_len > kind.loc_count:
            return "start not in valid range"
        return None

    @staticmethod
    def try_make(kind, start, reg_len):
        # type: (LocKind, int, int) -> Loc | None
        """Like the constructor, but return None instead of raising
        ValueError when the arguments are invalid.
        """
        msg = Loc.validate(kind=kind, start=start, reg_len=reg_len)
        if msg is not None:
            return None
        return Loc(kind=kind, start=start, reg_len=reg_len)

    def __init__(self, kind, start, reg_len):
        # type: (LocKind, int, int) -> None
        """Raises ValueError if the arguments fail `validate`."""
        msg = self.validate(kind=kind, start=start, reg_len=reg_len)
        if msg is not None:
            raise ValueError(msg)
        self.kind = kind
        self.reg_len = reg_len
        self.start = start

    def conflicts(self, other):
        # type: (Loc) -> bool
        """True if `self` and `other` are the same kind and their
        `[start, stop)` intervals overlap.
        """
        return (self.kind == other.kind
                and self.start < other.stop and other.start < self.stop)

    @staticmethod
    def make_ty(kind, reg_len):
        # type: (LocKind, int) -> Ty
        """The Ty of a Loc with the given kind and reg_len."""
        return Ty(base_ty=kind.base_ty, reg_len=reg_len)

    @cached_property
    def ty(self):
        # type: () -> Ty
        return self.make_ty(kind=self.kind, reg_len=self.reg_len)

    @property
    def stop(self):
        # type: () -> int
        """Exclusive end index."""
        return self.start + self.reg_len

    def try_concat(self, *others):
        # type: (*Loc | None) -> Loc | None
        """Concatenate `self` followed by `others` into a single Loc.

        Returns None if any of `others` is None, has a different kind, or
        doesn't start exactly where the previous Loc stops, or if the
        combined Loc would be invalid (e.g. a vector of an only-scalar type,
        or a `reg_len` above the type's `max_reg_len`).
        """
        reg_len = self.reg_len
        stop = self.stop
        for other in others:
            if other is None or other.kind != self.kind:
                return None
            if stop != other.start:
                return None
            stop = other.stop
            reg_len += other.reg_len
        # try_make, not the constructor: an invalid combination must yield
        # None rather than raise ValueError, per this method's contract
        return Loc.try_make(kind=self.kind, start=self.start, reg_len=reg_len)

    def get_subloc_at_offset(self, subloc_ty, offset):
        # type: (Ty, int) -> Loc
        """Return the Loc of type `subloc_ty` starting `offset` locations
        into `self`.

        Raises ValueError if the base types differ or the sub-Loc wouldn't lie
        entirely within `self`.
        """
        if subloc_ty.base_ty != self.kind.base_ty:
            raise ValueError("BaseTy mismatch")
        if offset < 0 or offset + subloc_ty.reg_len > self.reg_len:
            raise ValueError("invalid sub-Loc: offset and/or "
                             "subloc_ty.reg_len out of range")
        return Loc(kind=self.kind,
                   start=self.start + offset, reg_len=subloc_ty.reg_len)
633 return Loc(kind=self.kind,
634 start=self.start + offset, reg_len=subloc_ty.reg_len)
635
636
# Single GPRs that LocSubKind.allocatable_locs never hands out: any candidate
# Loc overlapping one of these is skipped.
# NOTE(review): presumably r0, r1 (stack pointer), r2 (TOC) and r13 (thread
# pointer) are reserved by the PowerPC ABI — confirm.
SPECIAL_GPRS = tuple(
    Loc(kind=LocKind.GPR, start=reg, reg_len=1) for reg in (0, 1, 2, 13))
643
644
@final
class LocSet(OFSet[Loc], metaclass=InternedMeta):
    """An ordered, immutable set of `Loc`s that all share one `Ty`.

    Alongside the Locs, `starts` summarizes the set as one bitset of start
    positions per `LocKind` that has at least one Loc in the set.
    """

    def __init__(self, __locs=()):
        # type: (Iterable[Loc]) -> None
        """Build a LocSet from `__locs`.

        Raises ValueError if the Locs don't all have the same Ty.
        """
        super().__init__(__locs)
        if isinstance(__locs, LocSet):
            # copying another LocSet: reuse its already-computed summary
            self.__starts = __locs.starts
            self.__ty = __locs.ty
            return
        starts = {i: BitSet() for i in LocKind}
        ty = None  # type: None | Ty
        for loc in self:
            if ty is None:
                ty = loc.ty
            if ty != loc.ty:
                raise ValueError(f"conflicting types: {ty} != {loc.ty}")
            starts[loc.kind].add(loc.start)
        # keep only kinds that actually have Locs, so `kinds` is accurate
        self.__starts = FMap(
            (k, FBitSet(v)) for k, v in starts.items() if len(v) != 0)
        self.__ty = ty

    @property
    def starts(self):
        # type: () -> FMap[LocKind, FBitSet]
        """For each LocKind present in `self`, the bitset of `Loc.start`s."""
        return self.__starts

    @property
    def ty(self):
        # type: () -> Ty | None
        """The Ty shared by every Loc in `self`; None if `self` is empty."""
        return self.__ty

    @cached_property
    def stops(self):
        # type: () -> FMap[LocKind, FBitSet]
        """Like `starts`, but bitsets of `Loc.stop`s (starts + reg_len)."""
        if self.ty is None:
            return FMap()
        sh = self.ty.reg_len
        return FMap(
            (k, FBitSet(bits=v.bits << sh)) for k, v in self.starts.items())

    @property
    def kinds(self):
        # type: () -> AbstractSet[LocKind]
        return self.starts.keys()

    @property
    def reg_len(self):
        # type: () -> int | None
        if self.ty is None:
            return None
        return self.ty.reg_len

    @property
    def base_ty(self):
        # type: () -> BaseTy | None
        if self.ty is None:
            return None
        return self.ty.base_ty

    def concat(self, *others):
        # type: (*LocSet) -> LocSet
        """Return every valid Loc formed by concatenating one Loc from `self`
        with one Loc from each of `others`, in order. All pieces must have the
        same kind, and each piece must start exactly where the previous one
        stops.

        Returns an empty LocSet if `self` or any of `others` is empty, if the
        base types differ, or if no such contiguous combination exists.
        """
        if self.ty is None:
            return LocSet()
        base_ty = self.ty.base_ty
        reg_len = self.ty.reg_len
        starts = {k: BitSet(v) for k, v in self.starts.items()}
        for other in others:
            if other.ty is None:
                return LocSet()
            if other.ty.base_ty != base_ty:
                return LocSet()
            for kind in list(starts):
                if kind not in other.starts:
                    # `other` has no Locs of this kind, so no combined Loc of
                    # this kind can continue past this point
                    del starts[kind]
                    continue
                # keep start `s` only if `other` has a Loc starting at
                # `s + reg_len`, i.e. right where the combination so far stops
                starts[kind].bits &= other.starts[kind].bits >> reg_len
                if starts[kind].bits == 0:
                    del starts[kind]
            if len(starts) == 0:
                return LocSet()
            reg_len += other.ty.reg_len

        def locs():
            # type: () -> Iterable[Loc]
            for kind, v in starts.items():
                for start in v:
                    loc = Loc.try_make(kind=kind, start=start, reg_len=reg_len)
                    if loc is not None:
                        yield loc
        return LocSet(locs())

    # NOTE(review): lru_cache on a method holds a strong reference to every
    # `self`/`other` it has seen for the lifetime of the cache.
    @lru_cache(maxsize=None, typed=True)
    def max_conflicts_with(self, other):
        # type: (LocSet | Loc) -> int
        """the largest number of Locs in `self` that a single Loc
        from `other` can conflict with; 0 if `other` is an empty LocSet
        """
        if isinstance(other, LocSet):
            # default=0: an empty `other` has no Loc that could conflict
            return max((self.max_conflicts_with(i) for i in other), default=0)
        return sum(other.conflicts(i) for i in self)

    def __repr__(self):
        return f"LocSet(starts={self.starts!r}, ty={self.ty!r})"
746 def __repr__(self):
747 return f"LocSet(starts={self.starts!r}, ty={self.ty!r})"
748
749
@plain_data(frozen=True, unsafe_hash=True)
@final
class GenericOperandDesc(metaclass=InternedMeta):
    """generic Op operand descriptor

    Describes an operand before the vector length is known. `instantiate`
    turns it into concrete `OperandDesc`s for a given `maxvl`.

    Fields:
    * `ty`: the operand's GenericTy.
    * `sub_kinds`: non-empty OFSet of LocSubKinds the operand may use.
    * `fixed_loc`: if not None, the one Loc the operand must use.
    * `tied_input_index`: if not None, index of the input this operand is
      tied to (only valid on outputs; see GenericOpProperties).
    * `spread`: if True, instantiation yields `maxvl` separate operands,
      one per `spread_index`.
    * `write_stage`: the OpStage at which the operand is written.
    """
    __slots__ = ("ty", "fixed_loc", "sub_kinds", "tied_input_index", "spread",
                 "write_stage")

    def __init__(
        self, ty,  # type: GenericTy
        sub_kinds,  # type: Iterable[LocSubKind]
        *,
        fixed_loc=None,  # type: Loc | None
        tied_input_index=None,  # type: int | None
        spread=False,  # type: bool
        write_stage=OpStage.Early,  # type: OpStage
    ):
        # type: (...) -> None
        """Raises ValueError if:
        * `sub_kinds` is empty;
        * `fixed_loc` is given together with `tied_input_index`, has a type
          `ty` can't instantiate to, is given with more than one sub_kind, or
          isn't among that sub_kind's allocatable Locs;
        * any sub_kind's base type differs from `ty.base_ty`;
        * `tied_input_index` is negative;
        * `spread` is combined with tying, a fixed_loc, or a vector `ty`.
        """
        self.ty = ty
        self.sub_kinds = OFSet(sub_kinds)
        if len(self.sub_kinds) == 0:
            raise ValueError("sub_kinds can't be empty")
        self.fixed_loc = fixed_loc
        if fixed_loc is not None:
            if tied_input_index is not None:
                raise ValueError("operand can't be both tied and fixed")
            if not ty.can_instantiate_to(fixed_loc.ty):
                raise ValueError(
                    f"fixed_loc has incompatible type for given generic "
                    f"type: fixed_loc={fixed_loc} generic ty={ty}")
            if len(self.sub_kinds) != 1:
                raise ValueError(
                    "multiple sub_kinds not allowed for fixed operand")
            for sub_kind in self.sub_kinds:
                if fixed_loc not in sub_kind.allocatable_locs(fixed_loc.ty):
                    raise ValueError(
                        f"fixed_loc not in given sub_kind: "
                        f"fixed_loc={fixed_loc} sub_kind={sub_kind}")
        for sub_kind in self.sub_kinds:
            if sub_kind.base_ty != ty.base_ty:
                raise ValueError(f"sub_kind is incompatible with type: "
                                 f"sub_kind={sub_kind} ty={ty}")
        if tied_input_index is not None and tied_input_index < 0:
            raise ValueError("invalid tied_input_index")
        self.tied_input_index = tied_input_index
        self.spread = spread
        if spread:
            if self.tied_input_index is not None:
                raise ValueError("operand can't be both spread and tied")
            if self.fixed_loc is not None:
                raise ValueError("operand can't be both spread and fixed")
            if self.ty.is_vec:
                raise ValueError("operand can't be both spread and vector")
        self.write_stage = write_stage

    @cached_property
    def ty_before_spread(self):
        # type: () -> GenericTy
        """The vector GenericTy of the whole operand if `spread`, otherwise
        `ty` unchanged.
        """
        if self.spread:
            return GenericTy(base_ty=self.ty.base_ty, is_vec=True)
        return self.ty

    def tied_to_input(self, tied_input_index):
        # type: (int) -> Self
        """Copy with the same `ty`, `sub_kinds` and `write_stage`, tied to
        input `tied_input_index` (no fixed_loc, not spread).
        """
        return GenericOperandDesc(self.ty, self.sub_kinds,
                                  tied_input_index=tied_input_index,
                                  write_stage=self.write_stage)

    def with_fixed_loc(self, fixed_loc):
        # type: (Loc) -> Self
        """Copy with the same `ty`, `sub_kinds` and `write_stage`, fixed to
        `fixed_loc` (not tied, not spread).
        """
        return GenericOperandDesc(self.ty, self.sub_kinds, fixed_loc=fixed_loc,
                                  write_stage=self.write_stage)

    def with_write_stage(self, write_stage):
        # type: (OpStage) -> Self
        """Copy with every field unchanged except `write_stage`."""
        return GenericOperandDesc(self.ty, self.sub_kinds,
                                  fixed_loc=self.fixed_loc,
                                  tied_input_index=self.tied_input_index,
                                  spread=self.spread,
                                  write_stage=write_stage)

    def instantiate(self, maxvl):
        # type: (int) -> Iterable[OperandDesc]
        """Generate the concrete OperandDescs for `maxvl`: one, or `maxvl` of
        them (spread_index 0..maxvl-1) if `spread`.

        This is a generator, so the ValueError raised when `fixed_loc`'s type
        doesn't match the instantiated type only appears once it is iterated.
        """
        # assumes all spread operands have ty.reg_len = 1
        rep_count = 1
        if self.spread:
            rep_count = maxvl
        ty_before_spread = self.ty_before_spread.instantiate(maxvl=maxvl)

        def locs_before_spread():
            # type: () -> Iterable[Loc]
            # just the fixed_loc if set, else every allocatable Loc of every
            # sub_kind
            if self.fixed_loc is not None:
                if ty_before_spread != self.fixed_loc.ty:
                    raise ValueError(
                        f"instantiation failed: type mismatch with fixed_loc: "
                        f"instantiated type: {ty_before_spread} "
                        f"fixed_loc: {self.fixed_loc}")
                yield self.fixed_loc
                return
            for sub_kind in self.sub_kinds:
                yield from sub_kind.allocatable_locs(ty_before_spread)
        loc_set_before_spread = LocSet(locs_before_spread())
        for idx in range(rep_count):
            if not self.spread:
                idx = None
            yield OperandDesc(loc_set_before_spread=loc_set_before_spread,
                              tied_input_index=self.tied_input_index,
                              spread_index=idx, write_stage=self.write_stage)
855 tied_input_index=self.tied_input_index,
856 spread_index=idx, write_stage=self.write_stage)
857
858
@plain_data(frozen=True, unsafe_hash=True)
@final
class OperandDesc(metaclass=InternedMeta):
    """Op operand descriptor

    The instantiated form of a GenericOperandDesc: the Locs the (unspread)
    operand may occupy, plus tying, spread index and write stage.
    """
    __slots__ = ("loc_set_before_spread", "tied_input_index", "spread_index",
                 "write_stage")

    def __init__(self, loc_set_before_spread, tied_input_index, spread_index,
                 write_stage):
        # type: (LocSet, int | None, int | None, OpStage) -> None
        """ Raises ValueError if `loc_set_before_spread` is empty, or if both
        `tied_input_index` and `spread_index` are given.
        """
        if len(loc_set_before_spread) == 0:
            raise ValueError("loc_set_before_spread must not be empty")
        is_tied = tied_input_index is not None
        is_spread = spread_index is not None
        if is_tied and is_spread:
            raise ValueError("operand can't be both spread and tied")
        self.loc_set_before_spread = loc_set_before_spread
        self.tied_input_index = tied_input_index
        self.spread_index = spread_index
        self.write_stage = write_stage

    @cached_property
    def ty_before_spread(self):
        # type: () -> Ty
        """ Ty of the whole (unspread) operand """
        ty = self.loc_set_before_spread.ty
        assert ty is not None, (
            "__init__ checked that the LocSet isn't empty, "
            "non-empty LocSets should always have ty set")
        return ty

    @cached_property
    def ty(self):
        # type: () -> Ty
        """ Ty after any spread is applied """
        if self.spread_index is None:
            return self.ty_before_spread
        # assumes all spread operands have ty.reg_len = 1
        return Ty(base_ty=self.ty_before_spread.base_ty, reg_len=1)

    @property
    def reg_offset_in_unspread(self):
        # type: () -> int
        """ the number of reg-sized slots in the unspread Loc before self's Loc

        e.g. if the unspread Loc containing self is:
        `Loc(kind=LocKind.GPR, start=8, reg_len=4)`
        and self's Loc is `Loc(kind=LocKind.GPR, start=10, reg_len=1)`
        then reg_offset_in_unspread == 2 == 10 - 8
        """
        if self.spread_index is not None:
            return self.spread_index * self.ty.reg_len
        return 0
905 return 0
906 return self.spread_index * self.ty.reg_len
907
908
# Commonly used GenericOperandDescs. The name gives the LocSubKind
# (BASE_GPR, SV_EXTRA2_*, SV_EXTRA3_*) and S = scalar (is_vec=False) or
# V = vector (is_vec=True).
OD_BASE_SGPR = GenericOperandDesc(
    ty=GenericTy(base_ty=BaseTy.I64, is_vec=False),
    sub_kinds=(LocSubKind.BASE_GPR,))
OD_EXTRA3_SGPR = GenericOperandDesc(
    ty=GenericTy(base_ty=BaseTy.I64, is_vec=False),
    sub_kinds=(LocSubKind.SV_EXTRA3_SGPR,))
OD_EXTRA3_VGPR = GenericOperandDesc(
    ty=GenericTy(base_ty=BaseTy.I64, is_vec=True),
    sub_kinds=(LocSubKind.SV_EXTRA3_VGPR,))
OD_EXTRA2_SGPR = GenericOperandDesc(
    ty=GenericTy(base_ty=BaseTy.I64, is_vec=False),
    sub_kinds=(LocSubKind.SV_EXTRA2_SGPR,))
OD_EXTRA2_VGPR = GenericOperandDesc(
    ty=GenericTy(base_ty=BaseTy.I64, is_vec=True),
    sub_kinds=(LocSubKind.SV_EXTRA2_VGPR,))
# single-location, scalar-only operands
OD_CA = GenericOperandDesc(
    ty=GenericTy(base_ty=BaseTy.CA, is_vec=False),
    sub_kinds=(LocSubKind.CA,))
OD_VL = GenericOperandDesc(
    ty=GenericTy(base_ty=BaseTy.VL_MAXVL, is_vec=False),
    sub_kinds=(LocSubKind.VL_MAXVL,))
928 ty=GenericTy(base_ty=BaseTy.VL_MAXVL, is_vec=False),
929 sub_kinds=[LocSubKind.VL_MAXVL])
930
931
@plain_data(frozen=True, unsafe_hash=True)
@final
class GenericOpProperties(metaclass=InternedMeta):
    """Properties of an Op kind that don't depend on `maxvl`: a demo assembly
    string, generic input/output operand descriptors, the allowed ranges of
    its immediates, and flags.
    """
    __slots__ = ("demo_asm", "inputs", "outputs", "immediates",
                 "is_copy", "is_load_immediate", "has_side_effects")

    def __init__(
        self, demo_asm,  # type: str
        inputs,  # type: Iterable[GenericOperandDesc]
        outputs,  # type: Iterable[GenericOperandDesc]
        immediates=(),  # type: Iterable[range]
        is_copy=False,  # type: bool
        is_load_immediate=False,  # type: bool
        has_side_effects=False,  # type: bool
    ):
        # type: (...) -> None
        """Raises ValueError if:
        * an input has a `tied_input_index` or a `write_stage` other than
          `OpStage.Early`;
        * an output's `tied_input_index` is out of range, or the output isn't
          equivalent to its tied input (same ty and sub_kinds, with the
          output's own write_stage);
        * two outputs have conflicting `fixed_loc`s.
        """
        self.demo_asm = demo_asm  # type: str
        self.inputs = tuple(inputs)  # type: tuple[GenericOperandDesc, ...]
        for inp in self.inputs:
            if inp.tied_input_index is not None:
                raise ValueError(
                    f"tied_input_index is not allowed on inputs: {inp}")
            # FnAnalysis relies on inputs always being Early
            if inp.write_stage is not OpStage.Early:
                raise ValueError(
                    f"write_stage is not allowed on inputs: {inp}")
        self.outputs = tuple(outputs)  # type: tuple[GenericOperandDesc, ...]
        # (fixed_loc, output index) of every fixed output seen so far
        fixed_locs = []  # type: list[tuple[Loc, int]]
        for idx, out in enumerate(self.outputs):
            if out.tied_input_index is not None:
                if out.tied_input_index >= len(self.inputs):
                    raise ValueError(f"tied_input_index out of range: {out}")
                tied_inp = self.inputs[out.tied_input_index]
                # what `out` must be: the tied input, marked as tied, with
                # the output's write_stage
                expected_out = tied_inp.tied_to_input(out.tied_input_index) \
                    .with_write_stage(out.write_stage)
                if expected_out != out:
                    raise ValueError(f"output can't be tied to non-equivalent "
                                     f"input: {out} tied to {tied_inp}")
            if out.fixed_loc is not None:
                for other_fixed_loc, other_idx in fixed_locs:
                    if not other_fixed_loc.conflicts(out.fixed_loc):
                        continue
                    raise ValueError(
                        f"conflicting fixed_locs: outputs[{idx}] and "
                        f"outputs[{other_idx}]: {out.fixed_loc} conflicts "
                        f"with {other_fixed_loc}")
                fixed_locs.append((out.fixed_loc, idx))
        self.immediates = tuple(immediates)  # type: tuple[range, ...]
        self.is_copy = is_copy  # type: bool
        self.is_load_immediate = is_load_immediate  # type: bool
        self.has_side_effects = has_side_effects  # type: bool
980 self.is_load_immediate = is_load_immediate # type: bool
981 self.has_side_effects = has_side_effects # type: bool
982
983
@plain_data(frozen=True, unsafe_hash=True)
@final
class OpProperties(metaclass=InternedMeta):
    """An OpKind's GenericOpProperties instantiated for one `maxvl`.

    Each generic operand descriptor is expanded into its concrete
    OperandDescs. Everything else is forwarded from `generic`.
    """
    __slots__ = "kind", "inputs", "outputs", "maxvl"

    def __init__(self, kind, maxvl):
        # type: (OpKind, int) -> None
        self.kind = kind  # type: OpKind
        generic = self.generic
        instantiated_inputs = tuple(
            desc
            for generic_inp in generic.inputs
            for desc in generic_inp.instantiate(maxvl=maxvl))
        self.inputs = instantiated_inputs  # type: tuple[OperandDesc, ...]
        instantiated_outputs = tuple(
            desc
            for generic_out in generic.outputs
            for desc in generic_out.instantiate(maxvl=maxvl))
        self.outputs = instantiated_outputs  # type: tuple[OperandDesc, ...]
        self.maxvl = maxvl  # type: int

    @property
    def generic(self):
        # type: () -> GenericOpProperties
        """The GenericOpProperties of `kind`."""
        return self.kind.properties

    @property
    def immediates(self):
        # type: () -> tuple[range, ...]
        """Forwarded from `generic`."""
        return self.generic.immediates

    @property
    def demo_asm(self):
        # type: () -> str
        """Forwarded from `generic`."""
        return self.generic.demo_asm

    @property
    def is_copy(self):
        # type: () -> bool
        """Forwarded from `generic`."""
        return self.generic.is_copy

    @property
    def is_load_immediate(self):
        # type: () -> bool
        """Forwarded from `generic`."""
        return self.generic.is_load_immediate

    @property
    def has_side_effects(self):
        # type: () -> bool
        """Forwarded from `generic`."""
        return self.generic.has_side_effects
1029 # type: () -> bool
1030 return self.generic.has_side_effects
1031
1032
# every value representable as a signed 16-bit immediate: [-2**15, 2**15)
IMM_S16 = range(-(2 ** 15), 2 ** 15)

# Per-OpKind simulation / asm-generation function types. The registries map
# a kind's GenericOpProperties to a zero-argument factory (the `...2` types)
# that returns the actual function, so lookups can be deferred until first
# use.
_SIM_FN = Callable[["Op", "BaseSimState"], None]
_GEN_ASM_FN = Callable[["Op", "GenAsmState"], None]
_SIM_FN2 = Callable[[], _SIM_FN]
_GEN_ASM_FN2 = Callable[[], _GEN_ASM_FN]
_SIM_FNS = {}  # type: dict[GenericOpProperties | Any, _SIM_FN2]
_GEN_ASMS = {}  # type: dict[GenericOpProperties | Any, _GEN_ASM_FN2]
1041
1042
@unique
@final
class OpKind(Enum):
    """ all kinds of operation the IR can contain.

    Each member's value is a `GenericOpProperties` describing the kind's
    operands, immediates and flags. Next to each member, its simulation
    function and asm-generation function are registered in the module-level
    `_SIM_FNS` / `_GEN_ASMS` dicts, keyed by that `GenericOpProperties`, and
    are looked up lazily by `sim` / `gen_asm`.
    """

    def __init__(self, properties):
        # type: (GenericOpProperties) -> None
        # Enum calls this once per member, passing the member's value
        super().__init__()
        self.__properties = properties

    @property
    def properties(self):
        # type: () -> GenericOpProperties
        """ the `GenericOpProperties` this member was defined with """
        return self.__properties

    def instantiate(self, maxvl):
        # type: (int) -> OpProperties
        """ create the `OpProperties` for this kind with the given `maxvl` """
        return OpProperties(self, maxvl=maxvl)

    def __repr__(self):
        # type: () -> str
        return "OpKind." + self._name_

    @cached_property
    def sim(self):
        # type: () -> _SIM_FN
        """ this kind's simulation function, called as `sim(op, state)`.

        Resolved from `_SIM_FNS` on first access and cached afterwards;
        raises KeyError if nothing was registered for `self.properties`.
        """
        return _SIM_FNS[self.properties]()

    @cached_property
    def gen_asm(self):
        # type: () -> _GEN_ASM_FN
        """ this kind's asm-generation function, called as
        `gen_asm(op, state)`.

        Resolved from `_GEN_ASMS` on first access and cached afterwards;
        raises KeyError if nothing was registered for `self.properties`.
        """
        return _GEN_ASMS[self.properties]()

    # Registration pattern used by every member below:
    # - Inside this class body a member name such as `ClearCA` still refers
    #   to the raw `GenericOpProperties` value. That is the same object
    #   `sim` / `gen_asm` later use as the dict key via `self.properties`.
    # - The registered values are lambdas because `OpKind` isn't bound until
    #   the class body finishes executing.
    # - `OpKind.__clearca_sim` is name-mangled to `OpKind._OpKind__clearca_sim`
    #   and is only looked up when the lambda is called.

    # ClearCA: no inputs; sets the CA output to False.
    @staticmethod
    def __clearca_sim(op, state):
        # type: (Op, BaseSimState) -> None
        # values are stored as tuples, hence the trailing comma
        state[op.outputs[0]] = False,

    @staticmethod
    def __clearca_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        state.writeln("addic 0, 0, 0")
    ClearCA = GenericOpProperties(
        demo_asm="addic 0, 0, 0",
        inputs=[],
        outputs=[OD_CA.with_write_stage(OpStage.Late)],
    )
    _SIM_FNS[ClearCA] = lambda: OpKind.__clearca_sim
    _GEN_ASMS[ClearCA] = lambda: OpKind.__clearca_gen_asm

    # SetCA: no inputs; sets the CA output to True.
    @staticmethod
    def __setca_sim(op, state):
        # type: (Op, BaseSimState) -> None
        state[op.outputs[0]] = True,

    @staticmethod
    def __setca_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        state.writeln("subfc 0, 0, 0")
    SetCA = GenericOpProperties(
        demo_asm="subfc 0, 0, 0",
        inputs=[],
        outputs=[OD_CA.with_write_stage(OpStage.Late)],
    )
    _SIM_FNS[SetCA] = lambda: OpKind.__setca_sim
    _GEN_ASMS[SetCA] = lambda: OpKind.__setca_gen_asm

    # SvAddE: inputs (RA, RB, CA, VL); for each of the first VL elements
    # RT[i] = RA[i] + RB[i] + carry, with the carry-out feeding the next
    # element. Outputs RT and the final CA (tied to input 2).
    @staticmethod
    def __svadde_sim(op, state):
        # type: (Op, BaseSimState) -> None
        RA = state[op.input_vals[0]]
        RB = state[op.input_vals[1]]
        carry, = state[op.input_vals[2]]
        VL, = state[op.input_vals[3]]
        RT = [] # type: list[int]
        for i in range(VL):
            v = RA[i] + RB[i] + carry
            RT.append(v & GPR_VALUE_MASK)
            carry = (v >> GPR_SIZE_IN_BITS) != 0
        state[op.outputs[0]] = tuple(RT)
        state[op.outputs[1]] = carry,

    @staticmethod
    def __svadde_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        RT = state.vgpr(op.outputs[0])
        RA = state.vgpr(op.input_vals[0])
        RB = state.vgpr(op.input_vals[1])
        state.writeln(f"sv.adde {RT}, {RA}, {RB}")
    SvAddE = GenericOpProperties(
        demo_asm="sv.adde *RT, *RA, *RB",
        inputs=[OD_EXTRA3_VGPR, OD_EXTRA3_VGPR, OD_CA, OD_VL],
        outputs=[OD_EXTRA3_VGPR, OD_CA.tied_to_input(2)],
    )
    _SIM_FNS[SvAddE] = lambda: OpKind.__svadde_sim
    _GEN_ASMS[SvAddE] = lambda: OpKind.__svadde_gen_asm

    # AddZE: scalar RT = RA + CA; outputs RT and the carry-out as CA (tied
    # to input 1).
    @staticmethod
    def __addze_sim(op, state):
        # type: (Op, BaseSimState) -> None
        RA, = state[op.input_vals[0]]
        carry, = state[op.input_vals[1]]
        v = RA + carry
        RT = v & GPR_VALUE_MASK
        carry = (v >> GPR_SIZE_IN_BITS) != 0
        state[op.outputs[0]] = RT,
        state[op.outputs[1]] = carry,

    @staticmethod
    def __addze_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        RT = state.vgpr(op.outputs[0])
        RA = state.vgpr(op.input_vals[0])
        state.writeln(f"addze {RT}, {RA}")
    AddZE = GenericOpProperties(
        demo_asm="addze RT, RA",
        inputs=[OD_BASE_SGPR, OD_CA],
        outputs=[OD_BASE_SGPR, OD_CA.tied_to_input(1)],
    )
    _SIM_FNS[AddZE] = lambda: OpKind.__addze_sim
    _GEN_ASMS[AddZE] = lambda: OpKind.__addze_gen_asm

    # SvSubFE: like SvAddE but RT[i] = ~RA[i] + RB[i] + carry, which is
    # RB[i] - RA[i] when the incoming carry is 1.
    @staticmethod
    def __svsubfe_sim(op, state):
        # type: (Op, BaseSimState) -> None
        RA = state[op.input_vals[0]]
        RB = state[op.input_vals[1]]
        carry, = state[op.input_vals[2]]
        VL, = state[op.input_vals[3]]
        RT = [] # type: list[int]
        for i in range(VL):
            # `~` binds tighter than `&`: this is the GPR-width complement
            v = (~RA[i] & GPR_VALUE_MASK) + RB[i] + carry
            RT.append(v & GPR_VALUE_MASK)
            carry = (v >> GPR_SIZE_IN_BITS) != 0
        state[op.outputs[0]] = tuple(RT)
        state[op.outputs[1]] = carry,

    @staticmethod
    def __svsubfe_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        RT = state.vgpr(op.outputs[0])
        RA = state.vgpr(op.input_vals[0])
        RB = state.vgpr(op.input_vals[1])
        state.writeln(f"sv.subfe {RT}, {RA}, {RB}")
    SvSubFE = GenericOpProperties(
        demo_asm="sv.subfe *RT, *RA, *RB",
        inputs=[OD_EXTRA3_VGPR, OD_EXTRA3_VGPR, OD_CA, OD_VL],
        outputs=[OD_EXTRA3_VGPR, OD_CA.tied_to_input(2)],
    )
    _SIM_FNS[SvSubFE] = lambda: OpKind.__svsubfe_sim
    _GEN_ASMS[SvSubFE] = lambda: OpKind.__svsubfe_gen_asm

    # SvMAddEDU: inputs (vector RA, scalar RB, scalar RC, VL). Per element:
    # - v = RA[i] * RB + carry, where carry starts as RC;
    # - RT[i] is the low GPR_SIZE_IN_BITS bits of v;
    # - the next carry is the whole high part (v >> GPR_SIZE_IN_BITS), not
    #   a single bit.
    # The final carry is the second output (tied to input 2, RC).
    @staticmethod
    def __svmaddedu_sim(op, state):
        # type: (Op, BaseSimState) -> None
        RA = state[op.input_vals[0]]
        RB, = state[op.input_vals[1]]
        carry, = state[op.input_vals[2]]
        VL, = state[op.input_vals[3]]
        RT = [] # type: list[int]
        for i in range(VL):
            v = RA[i] * RB + carry
            RT.append(v & GPR_VALUE_MASK)
            carry = v >> GPR_SIZE_IN_BITS
        state[op.outputs[0]] = tuple(RT)
        state[op.outputs[1]] = carry,

    @staticmethod
    def __svmaddedu_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        RT = state.vgpr(op.outputs[0])
        RA = state.vgpr(op.input_vals[0])
        RB = state.sgpr(op.input_vals[1])
        RC = state.sgpr(op.input_vals[2])
        state.writeln(f"sv.maddedu {RT}, {RA}, {RB}, {RC}")
    # NOTE(review): all inputs use EXTRA2 descriptors but output 0 uses
    # OD_EXTRA3_VGPR -- confirm this mix is intended for sv.maddedu.
    SvMAddEDU = GenericOpProperties(
        demo_asm="sv.maddedu *RT, *RA, RB, RC",
        inputs=[OD_EXTRA2_VGPR, OD_EXTRA2_SGPR, OD_EXTRA2_SGPR, OD_VL],
        outputs=[OD_EXTRA3_VGPR, OD_EXTRA2_SGPR.tied_to_input(2)],
    )
    _SIM_FNS[SvMAddEDU] = lambda: OpKind.__svmaddedu_sim
    _GEN_ASMS[SvMAddEDU] = lambda: OpKind.__svmaddedu_gen_asm

    # SetVLI: VL = immediate, restricted to 1..64 by `immediates`.
    @staticmethod
    def __setvli_sim(op, state):
        # type: (Op, BaseSimState) -> None
        state[op.outputs[0]] = op.immediates[0],

    @staticmethod
    def __setvli_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        imm = op.immediates[0]
        state.writeln(f"setvl 0, 0, {imm}, 0, 1, 1")
    SetVLI = GenericOpProperties(
        demo_asm="setvl 0, 0, imm, 0, 1, 1",
        inputs=(),
        outputs=[OD_VL.with_write_stage(OpStage.Late)],
        immediates=[range(1, 65)],
        is_load_immediate=True,
    )
    _SIM_FNS[SetVLI] = lambda: OpKind.__setvli_sim
    _GEN_ASMS[SetVLI] = lambda: OpKind.__setvli_gen_asm

    # SvLI: fill the first VL elements of vector RT with the signed 16-bit
    # immediate, masked to GPR width (so negatives become two's complement).
    @staticmethod
    def __svli_sim(op, state):
        # type: (Op, BaseSimState) -> None
        VL, = state[op.input_vals[0]]
        imm = op.immediates[0] & GPR_VALUE_MASK
        state[op.outputs[0]] = (imm,) * VL

    @staticmethod
    def __svli_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        RT = state.vgpr(op.outputs[0])
        imm = op.immediates[0]
        state.writeln(f"sv.addi {RT}, 0, {imm}")
    SvLI = GenericOpProperties(
        demo_asm="sv.addi *RT, 0, imm",
        inputs=[OD_VL],
        outputs=[OD_EXTRA3_VGPR],
        immediates=[IMM_S16],
        is_load_immediate=True,
    )
    _SIM_FNS[SvLI] = lambda: OpKind.__svli_sim
    _GEN_ASMS[SvLI] = lambda: OpKind.__svli_gen_asm

    # LI: scalar RT = signed 16-bit immediate, masked to GPR width.
    @staticmethod
    def __li_sim(op, state):
        # type: (Op, BaseSimState) -> None
        imm = op.immediates[0] & GPR_VALUE_MASK
        state[op.outputs[0]] = imm,

    @staticmethod
    def __li_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        RT = state.sgpr(op.outputs[0])
        imm = op.immediates[0]
        state.writeln(f"addi {RT}, 0, {imm}")
    LI = GenericOpProperties(
        demo_asm="addi RT, 0, imm",
        inputs=(),
        outputs=[OD_BASE_SGPR.with_write_stage(OpStage.Late)],
        immediates=[IMM_S16],
        is_load_immediate=True,
    )
    _SIM_FNS[LI] = lambda: OpKind.__li_sim
    _GEN_ASMS[LI] = lambda: OpKind.__li_gen_asm

    # VecCopyToReg: vector copy from a GPR range or the stack into GPRs.
    @staticmethod
    def __veccopytoreg_sim(op, state):
        # type: (Op, BaseSimState) -> None
        state[op.outputs[0]] = state[op.input_vals[0]]

    @staticmethod
    def __copy_to_from_reg_gen_asm(src_loc, dest_loc, is_vec, state):
        # type: (Loc, Loc, bool, GenAsmState) -> None
        """ emit a copy from `src_loc` to `dest_loc` (GPR or StackI64).

        Shared by all the copy kinds. Emits `ld` for stack -> GPR, `std`
        for GPR -> stack and `or dest, src, src` (a register move) for
        GPR -> GPR, prefixed with `sv.` when `is_vec`. Emits nothing when
        the two locations are equal.

        Raises ValueError for other location kinds and for stack -> stack
        copies.
        """
        sv = "sv." if is_vec else ""
        rev = ""
        # NOTE(review): presumably `/mrr` makes an overlapping vector copy
        # run in reverse element order so a destination above the source
        # doesn't overwrite not-yet-copied source elements -- confirm
        # against the SVP64 docs. Only the GPR -> GPR `or` below uses it.
        if src_loc.conflicts(dest_loc) and src_loc.start < dest_loc.start:
            rev = "/mrr"
        if src_loc == dest_loc:
            return  # no-op
        if src_loc.kind not in (LocKind.GPR, LocKind.StackI64):
            raise ValueError(f"invalid src_loc.kind: {src_loc.kind}")
        if dest_loc.kind not in (LocKind.GPR, LocKind.StackI64):
            raise ValueError(f"invalid dest_loc.kind: {dest_loc.kind}")
        if src_loc.kind is LocKind.StackI64:
            if dest_loc.kind is LocKind.StackI64:
                raise ValueError(
                    f"can't copy from stack to stack: {src_loc} {dest_loc}")
            elif dest_loc.kind is not LocKind.GPR:
                assert_never(dest_loc.kind)
            src = state.stack(src_loc)
            dest = state.gpr(dest_loc, is_vec=is_vec)
            state.writeln(f"{sv}ld {dest}, {src}")
        elif dest_loc.kind is LocKind.StackI64:
            if src_loc.kind is not LocKind.GPR:
                assert_never(src_loc.kind)
            src = state.gpr(src_loc, is_vec=is_vec)
            dest = state.stack(dest_loc)
            state.writeln(f"{sv}std {src}, {dest}")
        elif src_loc.kind is LocKind.GPR:
            if dest_loc.kind is not LocKind.GPR:
                assert_never(dest_loc.kind)
            src = state.gpr(src_loc, is_vec=is_vec)
            dest = state.gpr(dest_loc, is_vec=is_vec)
            state.writeln(f"{sv}or{rev} {dest}, {src}, {src}")
        else:
            assert_never(src_loc.kind)

    @staticmethod
    def __veccopytoreg_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        OpKind.__copy_to_from_reg_gen_asm(
            src_loc=state.loc(
                op.input_vals[0], (LocKind.GPR, LocKind.StackI64)),
            dest_loc=state.loc(op.outputs[0], LocKind.GPR),
            is_vec=True, state=state)

    VecCopyToReg = GenericOpProperties(
        demo_asm="sv.mv dest, src",
        inputs=[GenericOperandDesc(
            ty=GenericTy(BaseTy.I64, is_vec=True),
            sub_kinds=[LocSubKind.SV_EXTRA3_VGPR, LocSubKind.StackI64],
        ), OD_VL],
        outputs=[OD_EXTRA3_VGPR.with_write_stage(OpStage.Late)],
        is_copy=True,
    )
    _SIM_FNS[VecCopyToReg] = lambda: OpKind.__veccopytoreg_sim
    _GEN_ASMS[VecCopyToReg] = lambda: OpKind.__veccopytoreg_gen_asm

    # VecCopyFromReg: vector copy from GPRs to a GPR range or the stack.
    @staticmethod
    def __veccopyfromreg_sim(op, state):
        # type: (Op, BaseSimState) -> None
        state[op.outputs[0]] = state[op.input_vals[0]]

    @staticmethod
    def __veccopyfromreg_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        OpKind.__copy_to_from_reg_gen_asm(
            src_loc=state.loc(op.input_vals[0], LocKind.GPR),
            dest_loc=state.loc(
                op.outputs[0], (LocKind.GPR, LocKind.StackI64)),
            is_vec=True, state=state)
    VecCopyFromReg = GenericOpProperties(
        demo_asm="sv.mv dest, src",
        inputs=[OD_EXTRA3_VGPR, OD_VL],
        outputs=[GenericOperandDesc(
            ty=GenericTy(BaseTy.I64, is_vec=True),
            sub_kinds=[LocSubKind.SV_EXTRA3_VGPR, LocSubKind.StackI64],
            write_stage=OpStage.Late,
        )],
        is_copy=True,
    )
    _SIM_FNS[VecCopyFromReg] = lambda: OpKind.__veccopyfromreg_sim
    _GEN_ASMS[VecCopyFromReg] = lambda: OpKind.__veccopyfromreg_gen_asm

    # CopyToReg: scalar copy from a GPR or the stack into a GPR.
    @staticmethod
    def __copytoreg_sim(op, state):
        # type: (Op, BaseSimState) -> None
        state[op.outputs[0]] = state[op.input_vals[0]]

    @staticmethod
    def __copytoreg_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        OpKind.__copy_to_from_reg_gen_asm(
            src_loc=state.loc(
                op.input_vals[0], (LocKind.GPR, LocKind.StackI64)),
            dest_loc=state.loc(op.outputs[0], LocKind.GPR),
            is_vec=False, state=state)
    CopyToReg = GenericOpProperties(
        demo_asm="mv dest, src",
        inputs=[GenericOperandDesc(
            ty=GenericTy(BaseTy.I64, is_vec=False),
            sub_kinds=[LocSubKind.SV_EXTRA3_SGPR, LocSubKind.BASE_GPR,
                       LocSubKind.StackI64],
        )],
        outputs=[GenericOperandDesc(
            ty=GenericTy(BaseTy.I64, is_vec=False),
            sub_kinds=[LocSubKind.SV_EXTRA3_SGPR, LocSubKind.BASE_GPR],
            write_stage=OpStage.Late,
        )],
        is_copy=True,
    )
    _SIM_FNS[CopyToReg] = lambda: OpKind.__copytoreg_sim
    _GEN_ASMS[CopyToReg] = lambda: OpKind.__copytoreg_gen_asm

    # CopyFromReg: scalar copy from a GPR to a GPR or the stack.
    @staticmethod
    def __copyfromreg_sim(op, state):
        # type: (Op, BaseSimState) -> None
        state[op.outputs[0]] = state[op.input_vals[0]]

    @staticmethod
    def __copyfromreg_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        OpKind.__copy_to_from_reg_gen_asm(
            src_loc=state.loc(op.input_vals[0], LocKind.GPR),
            dest_loc=state.loc(
                op.outputs[0], (LocKind.GPR, LocKind.StackI64)),
            is_vec=False, state=state)
    CopyFromReg = GenericOpProperties(
        demo_asm="mv dest, src",
        inputs=[GenericOperandDesc(
            ty=GenericTy(BaseTy.I64, is_vec=False),
            sub_kinds=[LocSubKind.SV_EXTRA3_SGPR, LocSubKind.BASE_GPR],
        )],
        outputs=[GenericOperandDesc(
            ty=GenericTy(BaseTy.I64, is_vec=False),
            sub_kinds=[LocSubKind.SV_EXTRA3_SGPR, LocSubKind.BASE_GPR,
                       LocSubKind.StackI64],
            write_stage=OpStage.Late,
        )],
        is_copy=True,
    )
    _SIM_FNS[CopyFromReg] = lambda: OpKind.__copyfromreg_sim
    _GEN_ASMS[CopyFromReg] = lambda: OpKind.__copyfromreg_gen_asm

    # Concat: gather several scalar inputs (the spread operand) into one
    # vector. The last input is VL, hence `input_vals[:-1]`.
    @staticmethod
    def __concat_sim(op, state):
        # type: (Op, BaseSimState) -> None
        state[op.outputs[0]] = tuple(
            state[i][0] for i in op.input_vals[:-1])

    @staticmethod
    def __concat_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        OpKind.__copy_to_from_reg_gen_asm(
            src_loc=state.loc(op.input_vals[0:-1], LocKind.GPR),
            dest_loc=state.loc(op.outputs[0], LocKind.GPR),
            is_vec=True, state=state)
    Concat = GenericOpProperties(
        demo_asm="sv.mv dest, src",
        inputs=[GenericOperandDesc(
            ty=GenericTy(BaseTy.I64, is_vec=False),
            sub_kinds=[LocSubKind.SV_EXTRA3_VGPR],
            spread=True,
        ), OD_VL],
        outputs=[OD_EXTRA3_VGPR.with_write_stage(OpStage.Late)],
        is_copy=True,
    )
    _SIM_FNS[Concat] = lambda: OpKind.__concat_sim
    _GEN_ASMS[Concat] = lambda: OpKind.__concat_gen_asm

    # Spread: split one vector input into one scalar output per element.
    @staticmethod
    def __spread_sim(op, state):
        # type: (Op, BaseSimState) -> None
        for idx, inp in enumerate(state[op.input_vals[0]]):
            state[op.outputs[idx]] = inp,

    @staticmethod
    def __spread_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        OpKind.__copy_to_from_reg_gen_asm(
            src_loc=state.loc(op.input_vals[0], LocKind.GPR),
            dest_loc=state.loc(op.outputs, LocKind.GPR),
            is_vec=True, state=state)
    Spread = GenericOpProperties(
        demo_asm="sv.mv dest, src",
        inputs=[OD_EXTRA3_VGPR, OD_VL],
        outputs=[GenericOperandDesc(
            ty=GenericTy(BaseTy.I64, is_vec=False),
            sub_kinds=[LocSubKind.SV_EXTRA3_VGPR],
            spread=True,
            write_stage=OpStage.Late,
        )],
        is_copy=True,
    )
    _SIM_FNS[Spread] = lambda: OpKind.__spread_sim
    _GEN_ASMS[Spread] = lambda: OpKind.__spread_gen_asm

    # SvLd: load VL consecutive GPR-sized words starting at RA + immediate.
    @staticmethod
    def __svld_sim(op, state):
        # type: (Op, BaseSimState) -> None
        RA, = state[op.input_vals[0]]
        VL, = state[op.input_vals[1]]
        addr = RA + op.immediates[0]
        RT = [] # type: list[int]
        for i in range(VL):
            v = state.load(addr + GPR_SIZE_IN_BYTES * i)
            RT.append(v & GPR_VALUE_MASK)
        state[op.outputs[0]] = tuple(RT)

    @staticmethod
    def __svld_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        RA = state.sgpr(op.input_vals[0])
        RT = state.vgpr(op.outputs[0])
        imm = op.immediates[0]
        state.writeln(f"sv.ld {RT}, {imm}({RA})")
    SvLd = GenericOpProperties(
        demo_asm="sv.ld *RT, imm(RA)",
        inputs=[OD_EXTRA3_SGPR, OD_VL],
        outputs=[OD_EXTRA3_VGPR],
        immediates=[IMM_S16],
    )
    _SIM_FNS[SvLd] = lambda: OpKind.__svld_sim
    _GEN_ASMS[SvLd] = lambda: OpKind.__svld_gen_asm

    # Ld: load one GPR-sized word from RA + immediate.
    @staticmethod
    def __ld_sim(op, state):
        # type: (Op, BaseSimState) -> None
        RA, = state[op.input_vals[0]]
        addr = RA + op.immediates[0]
        v = state.load(addr)
        state[op.outputs[0]] = v & GPR_VALUE_MASK,

    @staticmethod
    def __ld_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        RA = state.sgpr(op.input_vals[0])
        RT = state.sgpr(op.outputs[0])
        imm = op.immediates[0]
        state.writeln(f"ld {RT}, {imm}({RA})")
    Ld = GenericOpProperties(
        demo_asm="ld RT, imm(RA)",
        inputs=[OD_BASE_SGPR],
        outputs=[OD_BASE_SGPR.with_write_stage(OpStage.Late)],
        immediates=[IMM_S16],
    )
    _SIM_FNS[Ld] = lambda: OpKind.__ld_sim
    _GEN_ASMS[Ld] = lambda: OpKind.__ld_gen_asm

    # SvStd: store VL elements of RS to consecutive GPR-sized words starting
    # at RA + immediate. No outputs; marked as having side effects.
    @staticmethod
    def __svstd_sim(op, state):
        # type: (Op, BaseSimState) -> None
        RS = state[op.input_vals[0]]
        RA, = state[op.input_vals[1]]
        VL, = state[op.input_vals[2]]
        addr = RA + op.immediates[0]
        for i in range(VL):
            state.store(addr + GPR_SIZE_IN_BYTES * i, value=RS[i])

    @staticmethod
    def __svstd_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        RS = state.vgpr(op.input_vals[0])
        RA = state.sgpr(op.input_vals[1])
        imm = op.immediates[0]
        state.writeln(f"sv.std {RS}, {imm}({RA})")
    SvStd = GenericOpProperties(
        demo_asm="sv.std *RS, imm(RA)",
        inputs=[OD_EXTRA3_VGPR, OD_EXTRA3_SGPR, OD_VL],
        outputs=[],
        immediates=[IMM_S16],
        has_side_effects=True,
    )
    _SIM_FNS[SvStd] = lambda: OpKind.__svstd_sim
    _GEN_ASMS[SvStd] = lambda: OpKind.__svstd_gen_asm

    # Std: store scalar RS to RA + immediate. No outputs; has side effects.
    @staticmethod
    def __std_sim(op, state):
        # type: (Op, BaseSimState) -> None
        RS, = state[op.input_vals[0]]
        RA, = state[op.input_vals[1]]
        addr = RA + op.immediates[0]
        state.store(addr, value=RS)

    @staticmethod
    def __std_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        RS = state.sgpr(op.input_vals[0])
        RA = state.sgpr(op.input_vals[1])
        imm = op.immediates[0]
        state.writeln(f"std {RS}, {imm}({RA})")
    Std = GenericOpProperties(
        demo_asm="std RS, imm(RA)",
        inputs=[OD_BASE_SGPR, OD_BASE_SGPR],
        outputs=[],
        immediates=[IMM_S16],
        has_side_effects=True,
    )
    _SIM_FNS[Std] = lambda: OpKind.__std_sim
    _GEN_ASMS[Std] = lambda: OpKind.__std_gen_asm

    # FuncArgR3: pseudo-op whose single output is pinned to GPR r3; it emits
    # no code and its value is provided before simulation starts.
    @staticmethod
    def __funcargr3_sim(op, state):
        # type: (Op, BaseSimState) -> None
        pass  # return value set before simulation

    @staticmethod
    def __funcargr3_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        pass  # no instructions needed
    FuncArgR3 = GenericOpProperties(
        demo_asm="",
        inputs=[],
        outputs=[OD_BASE_SGPR.with_fixed_loc(
            Loc(kind=LocKind.GPR, start=3, reg_len=1))],
    )
    _SIM_FNS[FuncArgR3] = lambda: OpKind.__funcargr3_sim
    _GEN_ASMS[FuncArgR3] = lambda: OpKind.__funcargr3_gen_asm
1612
1613
@plain_data(frozen=True, unsafe_hash=True, repr=False)
class SSAValOrUse(metaclass=InternedMeta):
    """ shared base of `SSAVal` and `SSAUse`.

    Identifies operand number `operand_idx` of `op`. `descriptor_array` is
    the op's outputs for an `SSAVal` and its inputs for an `SSAUse`.
    """
    __slots__ = "op", "operand_idx"

    def __init__(self, op, operand_idx):
        # type: (Op, int) -> None
        super().__init__()
        # `op` must be set first: subclasses derive `descriptor_array` from it
        self.op = op
        if not 0 <= operand_idx < len(self.descriptor_array):
            raise ValueError("invalid operand_idx")
        self.operand_idx = operand_idx

    @abstractmethod
    def __repr__(self):
        # type: () -> str
        ...

    @property
    @abstractmethod
    def descriptor_array(self):
        # type: () -> tuple[OperandDesc, ...]
        """ the op's operand descriptors this object indexes into """
        ...

    @cached_property
    def defining_descriptor(self):
        # type: () -> OperandDesc
        """ this operand's own `OperandDesc` """
        descriptors = self.descriptor_array
        return descriptors[self.operand_idx]

    @cached_property
    def ty(self):
        # type: () -> Ty
        """ the operand's type, after spreading """
        desc = self.defining_descriptor
        return desc.ty

    @cached_property
    def ty_before_spread(self):
        # type: () -> Ty
        """ the operand's type before spreading """
        desc = self.defining_descriptor
        return desc.ty_before_spread

    @property
    def base_ty(self):
        # type: () -> BaseTy
        """ base type of `ty_before_spread` """
        unspread_ty = self.ty_before_spread
        return unspread_ty.base_ty

    @property
    def reg_offset_in_unspread(self):
        # type: () -> int
        """ the number of reg-sized slots in the unspread Loc before self's Loc

        e.g. if the unspread Loc containing self is:
        `Loc(kind=LocKind.GPR, start=8, reg_len=4)`
        and self's Loc is `Loc(kind=LocKind.GPR, start=10, reg_len=1)`
        then reg_offset_in_unspread == 2 == 10 - 8
        """
        desc = self.defining_descriptor
        return desc.reg_offset_in_unspread

    @property
    def unspread_start_idx(self):
        # type: () -> int
        """ operand index of the first element of the spread group this
        operand belongs to; just `operand_idx` when it isn't spread """
        spread_index = self.defining_descriptor.spread_index
        if spread_index is None:
            return self.operand_idx
        return self.operand_idx - spread_index

    @property
    def unspread_start(self):
        # type: () -> Self
        """ same kind of object, but for the operand at `unspread_start_idx` """
        cls = self.__class__
        return cls(op=self.op, operand_idx=self.unspread_start_idx)
1677
1678
@plain_data(frozen=True, unsafe_hash=True, repr=False)
@final
class SSAVal(SSAValOrUse):
    """ an SSA value: output number `operand_idx` of `op`. """
    __slots__ = ()

    def __repr__(self):
        # type: () -> str
        return f"<{self.op.name}.outputs[{self.operand_idx}]: {self.ty}>"

    @cached_property
    def descriptor_array(self):
        # type: () -> tuple[OperandDesc, ...]
        """ the defining op's output descriptors """
        op_properties = self.op.properties
        return op_properties.outputs

    @cached_property
    def def_loc_set_before_spread(self):
        # type: () -> LocSet
        """ locations this value may be defined in, before spreading """
        desc = self.defining_descriptor
        return desc.loc_set_before_spread

    @cached_property
    def tied_input(self):
        # type: () -> None | SSAUse
        """ the input of `op` this output is tied to, if any """
        tied_idx = self.defining_descriptor.tied_input_index
        if tied_idx is not None:
            return SSAUse(op=self.op, operand_idx=tied_idx)
        return None

    @property
    def write_stage(self):
        # type: () -> OpStage
        """ the stage of `op` at which this output is written """
        desc = self.defining_descriptor
        return desc.write_stage
1710
1711
@plain_data(frozen=True, unsafe_hash=True, repr=False)
@final
class SSAUse(SSAValOrUse):
    """ a use of an SSA value: input number `operand_idx` of `op`.

    `ssa_val` reads or replaces the `SSAVal` currently plugged into that
    input, via `op.input_vals`.
    """
    __slots__ = ()

    def __repr__(self):
        # type: () -> str
        return f"<{self.op.name}.input_uses[{self.operand_idx}]: {self.ty}>"

    @cached_property
    def descriptor_array(self):
        # type: () -> tuple[OperandDesc, ...]
        """ the using op's input descriptors """
        op_properties = self.op.properties
        return op_properties.inputs

    @cached_property
    def use_loc_set_before_spread(self):
        # type: () -> LocSet
        """ locations this input may be read from, before spreading """
        desc = self.defining_descriptor
        return desc.loc_set_before_spread

    @property
    def ssa_val(self):
        # type: () -> SSAVal
        input_vals = self.op.input_vals
        return input_vals[self.operand_idx]

    @ssa_val.setter
    def ssa_val(self, ssa_val):
        # type: (SSAVal) -> None
        input_vals = self.op.input_vals
        input_vals[self.operand_idx] = ssa_val
1740
1741
_T = TypeVar("_T")
_Desc = TypeVar("_Desc")


class OpInputSeq(Sequence[_T], Generic[_T, _Desc]):
    """ fixed-length sequence of per-operand items belonging to an `Op`.

    The length always equals the number of descriptors returned by
    `_get_descriptors`. Every item written, both at construction and through
    `__setitem__`, is first validated against its descriptor by
    `_verify_write_with_desc`. Slices can be read but not assigned.
    """

    @abstractmethod
    def _verify_write_with_desc(self, idx, item, desc):
        # type: (int, _T | Any, _Desc) -> None
        """ raise if `item` may not be stored at (non-negative) index `idx`,
        whose descriptor is `desc` """
        raise NotImplementedError

    @final
    def _verify_write(self, idx, item):
        # type: (int | Any, _T | Any) -> int
        """ validate writing `item` at `idx`; return the normalized index.

        Raises TypeError for slice or non-int indexes, IndexError when the
        index is out of range, plus whatever `_verify_write_with_desc`
        raises.
        """
        if isinstance(idx, slice):
            raise TypeError(
                f"can't write to slice of {self.__class__.__name__}")
        if not isinstance(idx, int):
            raise TypeError(f"can't write with index {idx!r}")
        descs = self.descriptors
        # indexing a range maps negative indexes to non-negative ones and
        # raises IndexError when out of range, exactly like a list
        idx = range(len(descs))[idx]
        self._verify_write_with_desc(idx, item, descs[idx])
        return idx

    def _on_set(self, idx, new_item, old_item):
        # type: (int, _T, _T | None) -> None
        """ hook for subclasses; does nothing by default.

        NOTE(review): nothing in this class invokes it, so overrides are
        currently never called -- confirm whether that is intended.
        """

    @abstractmethod
    def _get_descriptors(self):
        # type: () -> tuple[_Desc, ...]
        """ one descriptor per item; computed once via `descriptors` """
        raise NotImplementedError

    @cached_property
    @final
    def descriptors(self):
        # type: () -> tuple[_Desc, ...]
        """ the per-item descriptors (cached result of `_get_descriptors`) """
        return self._get_descriptors()

    @property
    @final
    def op(self):
        # type: () -> Op
        """ the `Op` these items belong to """
        return self.__op

    def __init__(self, items, op):
        # type: (Iterable[_T], Op) -> None
        super().__init__()
        # must be set before `descriptors` is first computed, since
        # `_get_descriptors` implementations read `self.op`
        self.__op = op
        self.__items = []  # type: list[_T]
        for idx, item in enumerate(items):
            if idx < len(self.descriptors):
                self._verify_write(idx, item)
                self.__items.append(item)
                continue
            raise ValueError("too many items")
        missing = len(self.descriptors) - len(self.__items)
        if missing > 0:
            raise ValueError("not enough items")

    @final
    def __iter__(self):
        # type: () -> Iterator[_T]
        for item in self.__items:
            yield item

    @overload
    def __getitem__(self, idx):
        # type: (int) -> _T
        ...

    @overload
    def __getitem__(self, idx):
        # type: (slice) -> list[_T]
        ...

    @final
    def __getitem__(self, idx):
        # type: (int | slice) -> _T | list[_T]
        return self.__items[idx]

    @final
    def __setitem__(self, idx, item):
        # type: (int, _T) -> None
        self.__items[self._verify_write(idx, item)] = item

    @final
    def __len__(self):
        # type: () -> int
        return len(self.__items)

    def __repr__(self):
        # type: () -> str
        return f"{type(self).__name__}({self.__items}, op=...)"
1833
1834
@final
class OpInputVals(OpInputSeq[SSAVal, OperandDesc]):
    """ the `SSAVal`s read by an `Op`, one per entry of
    `op.properties.inputs`.

    Each value written must be an `SSAVal` whose type exactly matches the
    corresponding input descriptor's type.
    """

    def _get_descriptors(self):
        # type: () -> tuple[OperandDesc, ...]
        return self.op.properties.inputs

    def _verify_write_with_desc(self, idx, item, desc):
        # type: (int, SSAVal | Any, OperandDesc) -> None
        """ raise TypeError for non-`SSAVal` items and ValueError when the
        item's type differs from the input descriptor's type """
        if not isinstance(item, SSAVal):
            raise TypeError("expected value of type SSAVal")
        if item.ty != desc.ty:
            raise ValueError(f"assigned item's type {item.ty!r} doesn't match "
                             f"corresponding input's type {desc.ty!r}")

    def _on_set(self, idx, new_item, old_item):
        # type: (int, SSAVal, SSAVal | None) -> None
        # NOTE(review): OpInputSeq never calls `_on_set`, so this override
        # is currently dead code; `SSAUses` must exist if that changes.
        SSAUses._on_op_input_set(self, idx, new_item, old_item)  # type: ignore

    def __init__(self, items, op):
        # type: (Iterable[SSAVal], Op) -> None
        # `Op` keeps these in its `input_vals` slot. Checking for a
        # nonexistent `inputs` attribute could never detect re-initialization.
        # This matches the `immediates` check in OpImmediates.
        if hasattr(op, "input_vals"):
            raise ValueError("Op.input_vals already set")
        super().__init__(items, op)
1858
1859
@final
class OpImmediates(OpInputSeq[int, range]):
    """ the immediate operands of an `Op`, one per entry of
    `op.properties.immediates`; each must be an `int` inside the `range`
    given by its descriptor. """

    def __init__(self, items, op):
        # type: (Iterable[int], Op) -> None
        # refuse to build a second set of immediates for the same Op
        if hasattr(op, "immediates"):
            raise ValueError("Op.immediates already set")
        super().__init__(items, op)

    def _get_descriptors(self):
        # type: () -> tuple[range, ...]
        return self.op.properties.immediates

    def _verify_write_with_desc(self, idx, item, desc):
        # type: (int, int | Any, range) -> None
        """ raise TypeError for non-int items, ValueError if out of range """
        if isinstance(item, int):
            if item in desc:
                return
            raise ValueError(f"immediate value {item!r} not in {desc!r}")
        raise TypeError("expected value of type int")
1878
1879
@plain_data(frozen=True, eq=False, repr=False)
@final
class Op:
    """ one operation in a `Fn`.

    Compared and hashed by identity. Construction assigns the op a name
    that is unique within `fn`, but does not append it to `fn.ops`.
    """
    __slots__ = ("fn", "properties", "input_vals", "input_uses", "immediates",
                 "outputs", "name")

    def __init__(self, fn, properties, input_vals, immediates, name=""):
        # type: (Fn, OpProperties, Iterable[SSAVal], Iterable[int], str) -> None
        # The assignment order matters. The sequence/use/value objects built
        # below read `self.properties`. `fn` refuses to name an Op that
        # already has a `name`, so `name` has to be assigned last.
        self.fn = fn
        self.properties = properties
        self.input_vals = OpInputVals(input_vals, op=self)
        self.input_uses = tuple(
            SSAUse(self, operand_idx)
            for operand_idx in range(len(self.properties.inputs)))
        self.immediates = OpImmediates(immediates, op=self)
        self.outputs = tuple(
            SSAVal(self, operand_idx)
            for operand_idx in range(len(self.properties.outputs)))
        self.name = fn._add_op_with_unused_name(self, name)  # type: ignore

    @property
    def kind(self):
        # type: () -> OpKind
        return self.properties.kind

    def __eq__(self, other):
        # type: (Op | Any) -> bool
        if not isinstance(other, Op):
            return NotImplemented
        return self is other

    def __hash__(self):
        # type: () -> int
        return object.__hash__(self)

    def __repr__(self):
        # type: () -> str
        shown = []  # type: list[str]
        for field_name in fields(self):
            if field_name == "fn":
                continue  # omitted: Fn's repr carries no information
            if field_name == "properties":
                field_name = "kind"  # show the OpKind instead
            try:
                value = getattr(self, field_name)
            except AttributeError:
                shown.append(f"{field_name}=<not set>")
                continue
            if isinstance(value, OpInputSeq):
                value = list(value)  # type: ignore
            shown.append(f"{field_name}={value!r}")
        return "Op(" + ", ".join(shown) + ")"

    def sim(self, state):
        # type: (BaseSimState) -> None
        """ simulate this Op against `state`.

        Steps, each raising ValueError on failure:
        - check every input is assigned with the expected number of
          elements;
        - for a `PreRASimState`, check no output is already assigned,
          except for `OpKind.FuncArgR3`;
        - run the kind's sim function;
        - check every output was assigned with the expected number of
          elements.
        """
        def check_len(ssa_val, val):
            # type: (SSAVal, tuple[int, ...]) -> None
            if len(val) != ssa_val.ty.reg_len:
                raise ValueError(
                    f"value of SSAVal {ssa_val} has wrong number of elements: "
                    f"expected {ssa_val.ty.reg_len} found "
                    f"{len(val)}: {val!r}")

        for inp in self.input_vals:
            try:
                val = state[inp]
            except KeyError:
                raise ValueError(f"SSAVal {inp} not yet assigned when "
                                 f"running {self}")
            check_len(inp, val)
        if isinstance(state, PreRASimState):
            for out in self.outputs:
                if (out in state.ssa_vals
                        and self.kind is not OpKind.FuncArgR3):
                    raise ValueError(f"SSAVal {out} already assigned before "
                                     f"running {self}")
        self.kind.sim(self, state)
        for out in self.outputs:
            try:
                val = state[out]
            except KeyError:
                raise ValueError(f"running {self} failed to assign to {out}")
            check_len(out, val)

    def gen_asm(self, state):
        # type: (GenAsmState) -> None
        """ emit this Op's assembly into `state`.

        First calls `state.loc` on every input and then every output,
        accepting any `LocKind`; presumably this validates that each has an
        allocated location.
        """
        every_kind = tuple(LocKind)
        for ssa_val in (*self.input_vals, *self.outputs):
            state.loc(ssa_val, expected_kinds=every_kind)
        self.kind.gen_asm(self, state)
1972
1973
# width of one general-purpose register
GPR_SIZE_IN_BYTES = 8
BITS_IN_BYTE = 8
GPR_SIZE_IN_BITS = BITS_IN_BYTE * GPR_SIZE_IN_BYTES
# all-ones mask covering exactly one register's bits
GPR_VALUE_MASK = 2 ** GPR_SIZE_IN_BITS - 1
1978
1979
@plain_data(frozen=True, repr=False)
class BaseSimState(metaclass=ABCMeta):
    """ abstract base for simulator state.

    Holds a byte-addressed `memory` dict, where missing addresses read as 0.
    Subclasses map each `SSAVal` to its value, a tuple of ints, through
    `__getitem__` / `__setitem__`.
    """
    __slots__ = "memory",

    def __init__(self, memory):
        # type: (dict[int, int]) -> None
        super().__init__()
        self.memory = memory  # type: dict[int, int]

    def load_byte(self, addr):
        # type: (int) -> int
        """ read the byte at `addr`, wrapped with GPR_VALUE_MASK """
        return self.memory.get(addr & GPR_VALUE_MASK, 0) & 0xFF

    def store_byte(self, addr, value):
        # type: (int, int) -> None
        """ write the low byte of `value` at `addr`, wrapped with
        GPR_VALUE_MASK """
        self.memory[addr & GPR_VALUE_MASK] = value & 0xFF

    def load(self, addr, size_in_bytes=GPR_SIZE_IN_BYTES, signed=False):
        # type: (int, int, bool) -> int
        """ little-endian load of `size_in_bytes` bytes from `addr`.

        Raises ValueError if `addr` isn't a multiple of `size_in_bytes`.
        With `signed`, the result is sign-extended.
        """
        if addr % size_in_bytes:
            raise ValueError(f"address not aligned: {hex(addr)} "
                             f"required alignment: {size_in_bytes}")
        total_bits = size_in_bytes * BITS_IN_BYTE
        # the shifted bytes never overlap, so summing them equals OR-ing them
        value = sum(self.load_byte(addr + offset) << (offset * BITS_IN_BYTE)
                    for offset in range(size_in_bytes))
        if signed and value >> (total_bits - 1):
            value -= 1 << total_bits
        return value

    def store(self, addr, value, size_in_bytes=GPR_SIZE_IN_BYTES):
        # type: (int, int, int) -> None
        """ little-endian store of the low `size_in_bytes` bytes of `value`.

        Raises ValueError if `addr` isn't a multiple of `size_in_bytes`.
        """
        if addr % size_in_bytes:
            raise ValueError(f"address not aligned: {hex(addr)} "
                             f"required alignment: {size_in_bytes}")
        for offset in range(size_in_bytes):
            byte = (value >> (offset * BITS_IN_BYTE)) & 0xFF
            self.store_byte(addr + offset, byte)

    def _memory__repr(self):
        # type: () -> str
        """ render `memory` for `__repr__`.

        Addresses are listed in ascending order. Each aligned run of
        GPR_SIZE_IN_BYTES consecutive bytes is shown as one `<0x...>` word;
        any other byte is shown individually.
        """
        if not self.memory:
            return "{}"
        chunk_size = GPR_SIZE_IN_BYTES
        addrs = sorted(self.memory.keys())
        entries = []  # type: list[str]
        pos = 0
        while pos < len(addrs):
            addr = addrs[pos]
            if (addr % chunk_size == 0
                    and addrs[pos:pos + chunk_size]
                    == list(range(addr, addr + chunk_size))):
                value = self.load(addr, size_in_bytes=chunk_size)
                entries.append(
                    f"0x{addr:05x}: <0x{value:0{chunk_size * 2}x}>")
                pos += chunk_size
            else:
                entries.append(f"0x{addr:05x}: 0x{self.memory[addr]:02x}")
                pos += 1
        if len(entries) == 1:
            return f"{{{entries[0]}}}"
        return "{\n" + ",\n".join(entries) + "}"

    def __repr__(self):
        # type: () -> str
        parts = []  # type: list[str]
        for name in fields(self):
            try:
                value = getattr(self, name)
            except AttributeError:
                parts.append(f"{name}=<not set>")
                continue
            # a `_<field>__repr` method, if present, overrides repr(value)
            custom_repr = getattr(self, f"_{name}__repr", None)
            text = custom_repr() if callable(custom_repr) else repr(value)
            parts.append(f"{name}={text}")
        joined = ", ".join(parts)
        return f"{self.__class__.__name__}({joined})"

    @abstractmethod
    def __getitem__(self, ssa_val):
        # type: (SSAVal) -> tuple[int, ...]
        """ the value of `ssa_val` """
        ...

    @abstractmethod
    def __setitem__(self, ssa_val, value):
        # type: (SSAVal, tuple[int, ...]) -> None
        """ set the value of `ssa_val` """
        ...
2069
2070
@plain_data(frozen=True, repr=False)
@final
class PreRASimState(BaseSimState):
    """Simulator state for use before register allocation.

    Each SSAVal's value is kept directly in ``ssa_vals``. It is a tuple
    with one int per element, and ``__setitem__`` enforces
    ``ssa_val.ty.reg_len`` entries.
    """
    __slots__ = "ssa_vals",

    def __init__(self, ssa_vals, memory):
        # type: (dict[SSAVal, tuple[int, ...]], dict[int, int]) -> None
        super().__init__(memory)
        # stored without copying; __setitem__ mutates it in place
        self.ssa_vals = ssa_vals  # type: dict[SSAVal, tuple[int, ...]]

    def _ssa_vals__repr(self):
        # type: () -> str
        """Format ``ssa_vals`` for ``__repr__``.

        Elements are shown in hex. Values with more than CHUNK_SIZE
        elements are wrapped onto multiple lines, CHUNK_SIZE per line.
        Single-element values keep a trailing comma like a 1-tuple, and
        empty values are shown as ``()``.
        """
        if len(self.ssa_vals) == 0:
            return "{}"
        items = []  # type: list[str]
        CHUNK_SIZE = 4
        for k, v in self.ssa_vals.items():
            if len(v) == 0:
                # nothing to format; indexing element_strs[0] would fail
                items.append(f"{k!r}: ()")
                continue
            element_strs = []  # type: list[str]
            for i, el in enumerate(v):
                # each chunk of CHUNK_SIZE elements starts on a new line
                if i % CHUNK_SIZE != 0:
                    element_strs.append(" " + hex(el))
                else:
                    element_strs.append("\n " + hex(el))
            if len(element_strs) <= CHUNK_SIZE:
                # fits on one line: drop the leading newline
                element_strs[0] = element_strs[0].lstrip()
            if len(element_strs) == 1:
                # makes the join below produce "(0x...,)"
                element_strs.append("")
            v_str = ",".join(element_strs)
            items.append(f"{k!r}: ({v_str})")
        if len(items) == 1 and "\n" not in items[0]:
            return f"{{{items[0]}}}"
        items_str = ",\n".join(items)
        return f"{{\n{items_str},\n}}"

    def __getitem__(self, ssa_val):
        # type: (SSAVal) -> tuple[int, ...]
        """Return the stored value of ``ssa_val``.

        Raises KeyError if it was never set.
        """
        return self.ssa_vals[ssa_val]

    def __setitem__(self, ssa_val, value):
        # type: (SSAVal, tuple[int, ...]) -> None
        """Set ``ssa_val``'s value.

        Raises ValueError if ``len(value)`` is not ``ssa_val.ty.reg_len``.
        """
        if len(value) != ssa_val.ty.reg_len:
            raise ValueError("value has wrong len")
        self.ssa_vals[ssa_val] = value
2114
2115
@plain_data(frozen=True, repr=False)
@final
class PostRASimState(BaseSimState):
    """Simulator state for use after register allocation.

    SSAVals are mapped to their allocated Loc by ``ssa_val_to_loc_map``.
    Values are stored per ``reg_len=1`` sub-Loc in ``loc_values``, so
    overlapping Locs share storage.
    """
    __slots__ = "ssa_val_to_loc_map", "loc_values"

    def __init__(self, ssa_val_to_loc_map, memory, loc_values):
        # type: (dict[SSAVal, Loc], dict[int, int], dict[Loc, int]) -> None
        super().__init__(memory)
        self.ssa_val_to_loc_map = FMap(ssa_val_to_loc_map)
        for ssa_val, loc in self.ssa_val_to_loc_map.items():
            if ssa_val.ty != loc.ty:
                raise ValueError(
                    f"type mismatch for SSAVal and Loc: {ssa_val} {loc}")
        # stored without copying; __setitem__ mutates it in place
        self.loc_values = loc_values
        for loc in self.loc_values.keys():
            if loc.reg_len != 1:
                raise ValueError(
                    "loc_values must only contain Locs with reg_len=1, all "
                    "larger Locs will be split into reg_len=1 sub-Locs")

    def _loc_values__repr(self):
        # type: () -> str
        """Format ``loc_values`` for ``__repr__``, one Loc per line, sorted
        by (kind, start).

        An empty mapping is shown as ``{}``, matching the other
        ``_*__repr`` helpers.
        """
        if len(self.loc_values) == 0:
            return "{}"
        locs = sorted(self.loc_values.keys(), key=lambda v: (v.kind, v.start))
        items = []  # type: list[str]
        for loc in locs:
            items.append(f"{loc}: 0x{self.loc_values[loc]:x}")
        items_str = ",\n".join(items)
        return f"{{\n{items_str},\n}}"

    def __getitem__(self, ssa_val):
        # type: (SSAVal) -> tuple[int, ...]
        """Read ``ssa_val``'s value from its allocated Loc, element by
        element.

        Sub-Locs that were never written read as 0.
        """
        loc = self.ssa_val_to_loc_map[ssa_val]
        subloc_ty = Ty(base_ty=loc.ty.base_ty, reg_len=1)
        retval = []  # type: list[int]
        for i in range(loc.reg_len):
            subloc = loc.get_subloc_at_offset(subloc_ty=subloc_ty, offset=i)
            retval.append(self.loc_values.get(subloc, 0))
        return tuple(retval)

    def __setitem__(self, ssa_val, value):
        # type: (SSAVal, tuple[int, ...]) -> None
        """Write ``value`` into ``ssa_val``'s allocated Loc, one sub-Loc
        per element.

        Raises ValueError if ``len(value)`` is not ``ssa_val.ty.reg_len``.
        """
        if len(value) != ssa_val.ty.reg_len:
            raise ValueError("value has wrong len")
        loc = self.ssa_val_to_loc_map[ssa_val]
        subloc_ty = Ty(base_ty=loc.ty.base_ty, reg_len=1)
        for i in range(loc.reg_len):
            subloc = loc.get_subloc_at_offset(subloc_ty=subloc_ty, offset=i)
            self.loc_values[subloc] = value[i]
2164
2165
@plain_data(frozen=True)
class GenAsmState:
    """State used while emitting assembly text.

    ``allocated_locs`` maps each SSAVal to the Loc it was allocated to.
    Its Ty must match the SSAVal's Ty, and this is checked on
    construction. Emitted lines go to ``output``, which is either a list
    of lines or a StringIO-like object with ``write``.
    """
    __slots__ = "allocated_locs", "output"

    def __init__(self, allocated_locs, output=None):
        # type: (Mapping[SSAVal, Loc], StringIO | list[str] | None) -> None
        super().__init__()
        self.allocated_locs = FMap(allocated_locs)
        for val, allocated in self.allocated_locs.items():
            if ssa_val_ty_differs := (val.ty != allocated.ty):
                del ssa_val_ty_differs
                raise ValueError(
                    f"Ty mismatch: ssa_val.ty:{val.ty} != loc.ty:{allocated.ty}")
        # default to collecting lines in a fresh list
        self.output = [] if output is None else output

    __SSA_VAL_OR_LOCS = Union[SSAVal, Loc, Sequence["SSAVal | Loc"]]

    def loc(self, ssa_val_or_locs, expected_kinds):
        # type: (__SSA_VAL_OR_LOCS, LocKind | tuple[LocKind, ...]) -> Loc
        """Resolve SSAVals and/or Locs to a single concatenated Loc.

        Each SSAVal is replaced by its allocated Loc. The resulting Locs
        are joined with ``try_concat``, and the result's kind must be one
        of ``expected_kinds``.

        Raises ValueError if the sequence is empty, the Locs cannot be
        concatenated, or the kind does not match. Raises KeyError if an
        SSAVal has no allocated Loc.
        """
        if isinstance(ssa_val_or_locs, (SSAVal, Loc)):
            ssa_val_or_locs = [ssa_val_or_locs]
        locs = [
            self.allocated_locs[item] if isinstance(item, SSAVal) else item
            for item in ssa_val_or_locs
        ]  # type: list[Loc]
        if not locs:
            raise ValueError("invalid Loc sequence: must not be empty")
        first, *rest = locs
        retval = first.try_concat(*rest)
        if retval is None:
            raise ValueError("invalid Loc sequence: try_concat failed")
        if isinstance(expected_kinds, LocKind):
            expected_kinds = (expected_kinds,)
        if retval.kind in expected_kinds:
            return retval
        # show a lone expected kind without the tuple wrapper
        shown = expected_kinds[0] if len(expected_kinds) == 1 else expected_kinds
        raise ValueError(f"LocKind mismatch: {ssa_val_or_locs}: found "
                         f"{retval.kind} expected {shown}")

    def gpr(self, ssa_val_or_locs, is_vec):
        # type: (__SSA_VAL_OR_LOCS, bool) -> str
        """Return the GPR operand text: the starting register number,
        prefixed with ``*`` when ``is_vec`` is true."""
        reg_num = str(self.loc(ssa_val_or_locs, LocKind.GPR).start)
        if is_vec:
            return "*" + reg_num
        return reg_num

    def sgpr(self, ssa_val_or_locs):
        # type: (__SSA_VAL_OR_LOCS) -> str
        """GPR operand text without the vector ``*`` prefix."""
        return self.gpr(ssa_val_or_locs, is_vec=False)

    def vgpr(self, ssa_val_or_locs):
        # type: (__SSA_VAL_OR_LOCS) -> str
        """GPR operand text with the vector ``*`` prefix."""
        return self.gpr(ssa_val_or_locs, is_vec=True)

    def stack(self, ssa_val_or_locs):
        # type: (__SSA_VAL_OR_LOCS) -> str
        """Return ``"<start>(1)"`` for a StackI64 Loc."""
        offset = self.loc(ssa_val_or_locs, LocKind.StackI64).start
        return f"{offset}(1)"

    def writeln(self, *line_segments):
        # type: (*str) -> None
        """Emit one line made of ``line_segments`` joined by single
        spaces."""
        text = " ".join(line_segments)
        out = self.output
        if not isinstance(out, list):
            out.write(text + "\n")
            return
        out.append(text)