working on code
[bigint-presentation-code.git] / src / bigint_presentation_code / compiler_ir.py
1 import enum
2 from abc import ABCMeta, abstractmethod
3 from enum import Enum, unique
4 from functools import lru_cache, total_ordering
5 from io import StringIO
6 from typing import (AbstractSet, Any, Callable, Generic, Iterable, Iterator,
7 Mapping, Sequence, TypeVar, Union, overload)
8 from weakref import WeakValueDictionary as _WeakVDict
9
10 from cached_property import cached_property
11 from nmutil.plain_data import fields, plain_data
12
13 from bigint_presentation_code.type_util import (Literal, Self, assert_never,
14 final)
15 from bigint_presentation_code.util import (BitSet, FBitSet, FMap, InternedMeta,
16 OFSet, OSet)
17
# size of one general-purpose register
GPR_SIZE_IN_BYTES = 8
BITS_IN_BYTE = 8
GPR_SIZE_IN_BITS = BITS_IN_BYTE * GPR_SIZE_IN_BYTES
# all-ones mask covering exactly one GPR's worth of bits
GPR_VALUE_MASK = 2 ** GPR_SIZE_IN_BITS - 1
22
23
@final
class Fn:
    """ a function: an ordered list of `Op`s (`self.ops`) plus the
    bookkeeping needed to give each `Op` a name that's unique within `self`.
    """

    def __init__(self):
        self.ops = []  # type: list[Op]
        # weak values: a name is only reserved while its Op is still alive
        self.__op_names = _WeakVDict()  # type: _WeakVDict[str, Op]
        self.__next_name_suffix = 2

    def _add_op_with_unused_name(self, op, name=""):
        # type: (Op, str) -> str
        """ reserve and return a name for `op` that no live `Op` in `self`
        uses. `name` is tried first (unless empty); otherwise `name` followed
        by an ever-increasing numeric suffix is tried until one is free.

        Raises ValueError if `op` belongs to a different `Fn` or already has
        a `name` attribute.
        """
        if op.fn is not self:
            raise ValueError("can't add Op to wrong Fn")
        if hasattr(op, "name"):
            raise ValueError("Op already named")
        candidate = name
        while candidate == "" or candidate in self.__op_names:
            candidate = name + str(self.__next_name_suffix)
            self.__next_name_suffix += 1
        self.__op_names[candidate] = op
        return candidate

    def __repr__(self):
        # type: () -> str
        return "<Fn>"

    def ops_to_str(self, as_python_literal=False, wrap_width=63,
                   python_indent=" ", indent=" "):
        # type: (bool, int, str, str) -> str
        """ render every `Op` in `self.ops` (one `Op` repr per line).

        If `as_python_literal` is set, the text is instead emitted as a
        sequence of adjacent Python string literals (one per line, each
        prefixed by `python_indent`), with quotes, backslashes and
        non-printable/non-ASCII characters escaped.
        """
        text = "\n".join(op.__repr__(wrap_width=wrap_width, indent=indent)
                         for op in self.ops)
        if not as_python_literal:
            return text
        pieces = [python_indent + "\""]
        for ch in text:
            if ch == "\n":
                # close the current literal and start the next line's one
                pieces.append(f"\\n\"\n{python_indent}\"")
            elif ch in "\"\\":
                pieces.append("\\" + ch)
            elif ch.isascii() and ch.isprintable():
                pieces.append(ch)
            else:
                # reuse Python's own escape sequence for the character
                pieces.append(repr(ch).strip("\"'"))
        pieces.append("\"")
        text = "".join(pieces)
        # drop a trailing empty literal produced by a final newline
        empty_end = f"\"\n{python_indent}\"\""
        if text.endswith(empty_end):
            text = text[:-len(empty_end)]
        return text

    def append_op(self, op):
        # type: (Op) -> None
        """ append `op` to `self.ops`; raises ValueError if `op` belongs to
        another `Fn`.
        """
        if op.fn is not self:
            raise ValueError("can't add Op to wrong Fn")
        self.ops.append(op)

    def append_new_op(self, kind, input_vals=(), immediates=(), name="",
                      maxvl=1):
        # type: (OpKind, Iterable[SSAVal], Iterable[int], str, int) -> Op
        """ create a new `Op` of `kind` (instantiated for `maxvl`), append it
        to `self.ops`, and return it.
        """
        properties = kind.instantiate(maxvl=maxvl)
        new_op = Op(fn=self, properties=properties, input_vals=input_vals,
                    immediates=immediates, name=name)
        self.append_op(new_op)
        return new_op

    def sim(self, state):
        # type: (BaseSimState) -> None
        """ simulate every `Op` in program order against `state` """
        for current in self.ops:
            current.sim(state)

    def gen_asm(self, state):
        # type: (GenAsmState) -> None
        """ generate assembly for every `Op` in program order into `state` """
        for current in self.ops:
            current.gen_asm(state)

    def pre_ra_insert_copies(self):
        # type: () -> None
        """ rebuild `self.ops`, inserting a copy `Op` before every I64 input
        and after every I64 output (vector copies also get their own
        `SetVLI`). presumably this gives the register allocator freedom to
        place each value independently -- NOTE(review): confirm.

        CA and VL_MAXVL values are never copied (the copies would be no-ops),
        but every use of a `SetVLI` output gets a fresh `SetVLI` with the
        same immediates emitted right before the using `Op`.
        """
        old_ops = self.ops[:]
        # original output -> the value later Ops should read instead
        replacements = {}  # type: dict[SSAVal, SSAVal]
        # SetVLI output -> the SetVLI Op that produced it
        setvli_defs = {}  # type: dict[SSAVal, Op]
        self.ops.clear()
        for op in old_ops:
            for idx, orig_inp in enumerate(tuple(op.input_vals)):
                inp = replacements[orig_inp]
                base_ty = inp.ty.base_ty
                if base_ty is BaseTy.CA or base_ty is BaseTy.VL_MAXVL:
                    # copying is unnecessary, but SetVLI outputs are
                    # rematerialized right before the Op that reads them
                    if inp in setvli_defs:
                        remat = self.append_new_op(
                            OpKind.SetVLI,
                            immediates=setvli_defs[inp].immediates,
                            name=f"{op.name}.inp{idx}.setvl")
                        inp = remat.outputs[0]
                    op.input_vals[idx] = inp
                elif base_ty is BaseTy.I64:
                    reg_len = inp.ty.reg_len
                    if reg_len == 1:
                        copy_op = self.append_new_op(
                            OpKind.CopyToReg, input_vals=[inp],
                            name=f"{op.name}.inp{idx}.copy")
                    else:
                        setvl = self.append_new_op(
                            OpKind.SetVLI, immediates=[reg_len],
                            name=f"{op.name}.inp{idx}.setvl")
                        copy_op = self.append_new_op(
                            OpKind.VecCopyToReg,
                            input_vals=[inp, setvl.outputs[0]],
                            maxvl=reg_len, name=f"{op.name}.inp{idx}.copy")
                    op.input_vals[idx] = copy_op.outputs[0]
                else:
                    assert_never(base_ty)
            self.ops.append(op)
            for idx, out in enumerate(op.outputs):
                if op.kind is OpKind.SetVLI:
                    setvli_defs[out] = op
                base_ty = out.ty.base_ty
                if base_ty is BaseTy.CA or base_ty is BaseTy.VL_MAXVL:
                    # copies would be no-ops, so readers use `out` directly
                    replacements[out] = out
                elif base_ty is BaseTy.I64:
                    reg_len = out.ty.reg_len
                    if reg_len == 1:
                        copy_op = self.append_new_op(
                            OpKind.CopyFromReg, input_vals=[out],
                            name=f"{op.name}.out{idx}.copy")
                    else:
                        setvl = self.append_new_op(
                            OpKind.SetVLI, immediates=[reg_len],
                            name=f"{op.name}.out{idx}.setvl")
                        copy_op = self.append_new_op(
                            OpKind.VecCopyFromReg,
                            input_vals=[out, setvl.outputs[0]],
                            maxvl=reg_len, name=f"{op.name}.out{idx}.copy")
                    replacements[out] = copy_op.outputs[0]
                else:
                    assert_never(base_ty)
161
162
@final
@unique
@total_ordering
class OpStage(Enum):
    """ the stage of an `Op`'s execution at which a read/write happens.
    ordered so that `Early < Late`.
    """
    value: Literal[0, 1]  # type: ignore

    def __new__(cls, value):
        # type: (int) -> OpStage
        as_int = int(value)
        if as_int not in (0, 1):
            raise ValueError("invalid value")
        member = object.__new__(cls)
        member._value_ = as_int
        return member

    Early = 0
    """ early stage of Op execution, where all input reads occur.
    all output writes with `write_stage == Early` occur here too, and therefore
    conflict with input reads, telling the compiler that it can't share
    that output's register with any inputs that the output isn't tied to.

    All outputs, even unused outputs, can't share registers with any other
    outputs, independent of `write_stage` settings.
    """
    Late = 1
    """ late stage of Op execution, where all output writes with
    `write_stage == Late` occur, and therefore don't conflict with input reads,
    telling the compiler that any inputs can safely use the same register as
    those outputs.

    All outputs, even unused outputs, can't share registers with any other
    outputs, independent of `write_stage` settings.
    """

    def __repr__(self):
        # type: () -> str
        return "OpStage." + self._name_

    def __lt__(self, other):
        # type: (OpStage | object) -> bool
        if not isinstance(other, OpStage):
            return NotImplemented
        return self.value < other.value


assert OpStage.Early < OpStage.Late, "early must be less than late"
209
210
@plain_data(frozen=True, unsafe_hash=True, repr=False)
@final
@total_ordering
class ProgramPoint(metaclass=InternedMeta):
    """ a single point of a `Fn`'s execution: stage `stage` of the `Op` at
    `fn.ops[op_index]`. ordered by `op_index`, then by `stage`.
    """
    __slots__ = "op_index", "stage"

    def __init__(self, op_index, stage):
        # type: (int, OpStage) -> None
        self.op_index = op_index
        self.stage = stage

    @property
    def int_value(self):
        # type: () -> int
        """ an integer encoding of `self` (two points per `Op`) that keeps
        ordering and successor/predecessor relations.
        """
        return 2 * self.op_index + self.stage.value

    @staticmethod
    def from_int_value(int_value):
        # type: (int) -> ProgramPoint
        """ the inverse of `int_value` """
        return ProgramPoint(op_index=int_value // 2,
                            stage=OpStage(int_value % 2))

    def next(self, steps=1):
        # type: (int) -> ProgramPoint
        """ the point `steps` positions after `self` """
        return ProgramPoint.from_int_value(self.int_value + steps)

    def prev(self, steps=1):
        # type: (int) -> ProgramPoint
        """ the point `steps` positions before `self` """
        return self.next(steps=-steps)

    def __lt__(self, other):
        # type: (ProgramPoint | Any) -> bool
        if isinstance(other, ProgramPoint):
            return (self.op_index, self.stage) < (other.op_index, other.stage)
        return NotImplemented

    def __repr__(self):
        # type: () -> str
        return f"<ops[{self.op_index}]:{self.stage._name_}>"
255
256
@plain_data(frozen=True, unsafe_hash=True, repr=False)
@final
class ProgramRange(Sequence[ProgramPoint], metaclass=InternedMeta):
    """ the half-open range of `ProgramPoint`s `[start, stop)`, usable as a
    `Sequence[ProgramPoint]`.
    """
    __slots__ = "start", "stop"

    def __init__(self, start, stop):
        # type: (ProgramPoint, ProgramPoint) -> None
        self.start = start
        self.stop = stop

    @cached_property
    def int_value_range(self):
        # type: () -> range
        """ `self` expressed as a range of `ProgramPoint.int_value`s """
        return range(self.start.int_value, self.stop.int_value)

    @staticmethod
    def from_int_value_range(int_value_range):
        # type: (range) -> ProgramRange
        """ the inverse of `int_value_range`; only step 1 is supported """
        if int_value_range.step != 1:
            raise ValueError("int_value_range must have step == 1")
        first = ProgramPoint.from_int_value(int_value_range.start)
        end = ProgramPoint.from_int_value(int_value_range.stop)
        return ProgramRange(start=first, stop=end)

    @overload
    def __getitem__(self, __idx):
        # type: (int) -> ProgramPoint
        ...

    @overload
    def __getitem__(self, __idx):
        # type: (slice) -> ProgramRange
        ...

    def __getitem__(self, __idx):
        # type: (int | slice) -> ProgramPoint | ProgramRange
        picked = self.int_value_range[__idx]
        if isinstance(picked, range):
            return ProgramRange.from_int_value_range(picked)
        return ProgramPoint.from_int_value(picked)

    def __len__(self):
        # type: () -> int
        return len(self.int_value_range)

    def __iter__(self):
        # type: () -> Iterator[ProgramPoint]
        return (ProgramPoint.from_int_value(v) for v in self.int_value_range)

    def __repr__(self):
        # type: () -> str
        first = repr(self.start).lstrip("<").rstrip(">")
        end = repr(self.stop).lstrip("<").rstrip(">")
        return f"<range:{first}..{end}>"
311
312
@plain_data(frozen=True, eq=False, repr=False)
@final
class FnAnalysis:
    """ def/use and liveness analysis of a `Fn`, computed at construction.

    - `op_indexes`: each `Op`'s index in `fn.ops`
    - `uses`: each `SSAVal`'s uses, in program order
    - `def_program_ranges`: where each `SSAVal` is written by its `Op`
    - `use_program_points`: where each `SSAUse` reads its value
    - `live_ranges`: from each `SSAVal`'s def through its last use
    - `live_at`: the `SSAVal`s live at each `ProgramPoint`

    equality and hashing depend only on `fn`.
    """
    __slots__ = ("fn", "uses", "op_indexes", "live_ranges", "live_at",
                 "def_program_ranges", "use_program_points",
                 "all_program_points")

    def __init__(self, fn):
        # type: (Fn) -> None
        self.fn = fn
        self.op_indexes = FMap((op, idx) for idx, op in enumerate(fn.ops))
        # from the first Op's Early stage up to (excluding) one past the end
        first_point = ProgramPoint(op_index=0, stage=OpStage.Early)
        end_point = ProgramPoint(op_index=len(fn.ops), stage=OpStage.Early)
        self.all_program_points = ProgramRange(start=first_point,
                                               stop=end_point)
        val_uses = {}  # type: dict[SSAVal, OSet[SSAUse]]
        def_ranges = {}  # type: dict[SSAVal, ProgramRange]
        use_points = {}  # type: dict[SSAUse, ProgramPoint]
        range_stops = {}  # type: dict[SSAVal, ProgramPoint]
        for op in fn.ops:
            for use in op.input_uses:
                val_uses[use.ssa_val].add(use)
                use_point = self.__get_use_program_point(use)
                use_points[use] = use_point
                # the live range must extend past every use (stop is
                # exclusive)
                range_stops[use.ssa_val] = max(
                    range_stops[use.ssa_val], use_point.next())
            for out in op.outputs:
                val_uses[out] = OSet()
                def_range = self.__get_def_program_range(out)
                def_ranges[out] = def_range
                range_stops[out] = def_range.stop
        self.uses = FMap((val, OFSet(us)) for val, us in val_uses.items())
        self.def_program_ranges = FMap(def_ranges)
        self.use_program_points = FMap(use_points)
        live_ranges = {}  # type: dict[SSAVal, ProgramRange]
        live_at = {point: OSet[SSAVal]() for point in self.all_program_points}
        for ssa_val in val_uses:
            live_range = ProgramRange(
                start=self.def_program_ranges[ssa_val].start,
                stop=range_stops[ssa_val])
            live_ranges[ssa_val] = live_range
            for point in live_range:
                live_at[point].add(ssa_val)
        self.live_ranges = FMap(live_ranges)
        self.live_at = FMap(
            (point, OFSet(vals)) for point, vals in live_at.items())

    def __get_def_program_range(self, ssa_val):
        # type: (SSAVal) -> ProgramRange
        """ from the output's `write_stage` through the defining `Op`'s Late
        stage, so every output overlaps all other outputs of its `Op`.
        """
        write_stage = ssa_val.defining_descriptor.write_stage
        op_index = self.op_indexes[ssa_val.op]
        start = ProgramPoint(op_index=op_index, stage=write_stage)
        # stop is exclusive, so step one past the Late stage
        stop = ProgramPoint(op_index=op_index, stage=OpStage.Late).next()
        return ProgramRange(start=start, stop=stop)

    def __get_use_program_point(self, ssa_use):
        # type: (SSAUse) -> ProgramPoint
        """ inputs are always read at the using `Op`'s Early stage """
        assert ssa_use.defining_descriptor.write_stage is OpStage.Early, \
            "assumed here, ensured by GenericOpProperties.__init__"
        return ProgramPoint(
            op_index=self.op_indexes[ssa_use.op], stage=OpStage.Early)

    def __eq__(self, other):
        # type: (FnAnalysis | Any) -> bool
        if not isinstance(other, FnAnalysis):
            return NotImplemented
        return self.fn == other.fn

    def __hash__(self):
        # type: () -> int
        return hash(self.fn)

    def __repr__(self):
        # type: () -> str
        return "<FnAnalysis>"
388
389
@unique
@final
class BaseTy(Enum):
    """ the base (element) type of a `Ty` """
    I64 = enum.auto()
    CA = enum.auto()
    VL_MAXVL = enum.auto()

    @cached_property
    def only_scalar(self):
        # type: () -> bool
        """ True if this type can't be used as a vector element """
        if self is BaseTy.CA or self is BaseTy.VL_MAXVL:
            return True
        if self is BaseTy.I64:
            return False
        assert_never(self)

    @cached_property
    def max_reg_len(self):
        # type: () -> int
        """ the largest `reg_len` a `Ty` of this base type may have """
        if self is BaseTy.CA or self is BaseTy.VL_MAXVL:
            return 1
        if self is BaseTy.I64:
            return 128
        assert_never(self)

    def __repr__(self):
        # type: () -> str
        return f"BaseTy.{self._name_}"
419
420
@plain_data(frozen=True, unsafe_hash=True, repr=False)
@final
class Ty(metaclass=InternedMeta):
    """ a concrete type: `reg_len` consecutive registers of `base_ty` """
    __slots__ = "base_ty", "reg_len"

    @staticmethod
    def validate(base_ty, reg_len):
        # type: (BaseTy, int) -> str | None
        """ return a string with the error if the combination is invalid,
        otherwise return None
        """
        error = None  # type: str | None
        if base_ty.only_scalar and reg_len != 1:
            error = f"can't create a vector of an only-scalar type: {base_ty}"
        elif not 1 <= reg_len <= base_ty.max_reg_len:
            error = "reg_len out of range"
        return error

    def __init__(self, base_ty, reg_len):
        # type: (BaseTy, int) -> None
        error = self.validate(base_ty=base_ty, reg_len=reg_len)
        if error is not None:
            raise ValueError(error)
        self.base_ty = base_ty
        self.reg_len = reg_len

    def __repr__(self):
        # type: () -> str
        suffix = "" if self.reg_len == 1 else f"*{self.reg_len}"
        return f"<{self.base_ty._name_}{suffix}>"
453
454
@unique
@final
class LocKind(Enum):
    """ a family of locations a value can be assigned to """
    GPR = enum.auto()
    StackI64 = enum.auto()
    CA = enum.auto()
    VL_MAXVL = enum.auto()

    @cached_property
    def base_ty(self):
        # type: () -> BaseTy
        """ the `BaseTy` of values stored in this kind of location """
        if self is LocKind.CA:
            return BaseTy.CA
        if self is LocKind.VL_MAXVL:
            return BaseTy.VL_MAXVL
        if self is LocKind.GPR or self is LocKind.StackI64:
            return BaseTy.I64
        assert_never(self)

    @cached_property
    def loc_count(self):
        # type: () -> int
        """ how many individual locations of this kind exist """
        if self is LocKind.StackI64:
            return 512
        is_reg_kind = (self is LocKind.GPR or self is LocKind.CA
                       or self is LocKind.VL_MAXVL)
        if is_reg_kind:
            return self.base_ty.max_reg_len
        assert_never(self)

    def __repr__(self):
        # type: () -> str
        return f"LocKind.{self._name_}"
488
489
@final
@unique
class LocSubKind(Enum):
    """ a subset of a `LocKind`'s locations that an operand may be allocated
    from (e.g. the GPRs reachable by a particular instruction encoding).
    """
    BASE_GPR = enum.auto()
    SV_EXTRA2_VGPR = enum.auto()
    SV_EXTRA2_SGPR = enum.auto()
    SV_EXTRA3_VGPR = enum.auto()
    SV_EXTRA3_SGPR = enum.auto()
    StackI64 = enum.auto()
    CA = enum.auto()
    VL_MAXVL = enum.auto()

    @cached_property
    def kind(self):
        # type: () -> LocKind
        """ the `LocKind` this sub-kind is a subset of """
        # pyright fails typechecking when using `in` here:
        # reported: https://github.com/microsoft/pyright/issues/4102
        if self in (LocSubKind.BASE_GPR, LocSubKind.SV_EXTRA2_VGPR,
                    LocSubKind.SV_EXTRA2_SGPR, LocSubKind.SV_EXTRA3_VGPR,
                    LocSubKind.SV_EXTRA3_SGPR):
            return LocKind.GPR
        if self is LocSubKind.StackI64:
            return LocKind.StackI64
        if self is LocSubKind.CA:
            return LocKind.CA
        if self is LocSubKind.VL_MAXVL:
            return LocKind.VL_MAXVL
        assert_never(self)

    @property
    def base_ty(self):
        # type: () -> BaseTy
        return self.kind.base_ty

    # the key space is bounded (8 sub-kinds x a finite set of `Ty`s), but
    # can exceed lru_cache's default maxsize of 128, which would cause
    # needless evictions and recomputation -- so don't bound the cache.
    @lru_cache(maxsize=None)
    def allocatable_locs(self, ty):
        # type: (Ty) -> LocSet
        """ every `Loc` of type `ty` in this sub-kind that doesn't overlap
        any of `SPECIAL_GPRS`.

        candidate starts: BASE_GPR 0..31, SV_EXTRA2_VGPR even 0..126,
        SV_EXTRA2_SGPR 0..63, SV_EXTRA3_* 0..127, StackI64 every stack slot;
        starts whose `Loc` would run past the end of the `LocKind` are
        skipped. CA and VL_MAXVL always have exactly one location.

        Raises ValueError if `ty` has a different `BaseTy` than `self`.
        """
        if ty.base_ty != self.base_ty:
            raise ValueError("type mismatch")
        if self is LocSubKind.BASE_GPR:
            starts = range(32)
        elif self is LocSubKind.SV_EXTRA2_VGPR:
            starts = range(0, 128, 2)
        elif self is LocSubKind.SV_EXTRA2_SGPR:
            starts = range(64)
        elif self is LocSubKind.SV_EXTRA3_VGPR \
                or self is LocSubKind.SV_EXTRA3_SGPR:
            starts = range(128)
        elif self is LocSubKind.StackI64:
            starts = range(LocKind.StackI64.loc_count)
        elif self is LocSubKind.CA or self is LocSubKind.VL_MAXVL:
            return LocSet([Loc(kind=self.kind, start=0, reg_len=1)])
        else:
            assert_never(self)
        retval = []  # type: list[Loc]
        for start in starts:
            loc = Loc.try_make(kind=self.kind, start=start, reg_len=ty.reg_len)
            if loc is None:
                continue
            if any(loc.conflicts(special) for special in SPECIAL_GPRS):
                continue
            retval.append(loc)
        return LocSet(retval)

    def __repr__(self):
        # type: () -> str
        return "LocSubKind." + self._name_
559
560
@plain_data(frozen=True, unsafe_hash=True)
@final
class GenericTy(metaclass=InternedMeta):
    """ a type whose vector length is left open until `instantiate` """
    __slots__ = "base_ty", "is_vec"

    def __init__(self, base_ty, is_vec):
        # type: (BaseTy, bool) -> None
        self.base_ty = base_ty
        if base_ty.only_scalar and is_vec:
            raise ValueError(f"base_ty={base_ty} requires is_vec=False")
        self.is_vec = is_vec

    def instantiate(self, maxvl):
        # type: (int) -> Ty
        """ the concrete `Ty`: `maxvl` registers if vector, else one """
        # here's where subvl and elwid would be accounted for
        return Ty(self.base_ty, maxvl if self.is_vec else 1)

    def can_instantiate_to(self, ty):
        # type: (Ty) -> bool
        """ True if some `maxvl` makes `instantiate` produce `ty` """
        if self.base_ty != ty.base_ty:
            return False
        return True if self.is_vec else ty.reg_len == 1
587
588
@plain_data(frozen=True, unsafe_hash=True)
@final
class Loc(metaclass=InternedMeta):
    """ `reg_len` consecutive locations of kind `kind`, starting at index
    `start` (so covering `start` up to, but excluding, `stop`).
    """
    __slots__ = "kind", "start", "reg_len"

    @staticmethod
    def validate(kind, start, reg_len):
        # type: (LocKind, int, int) -> str | None
        """ return a string with the error if the combination is invalid,
        otherwise return None
        """
        msg = Ty.validate(base_ty=kind.base_ty, reg_len=reg_len)
        if msg is not None:
            return msg
        if reg_len > kind.loc_count:
            return "invalid reg_len"
        if start < 0 or start + reg_len > kind.loc_count:
            return "start not in valid range"
        return None

    @staticmethod
    def try_make(kind, start, reg_len):
        # type: (LocKind, int, int) -> Loc | None
        """ like the constructor, but return None instead of raising """
        msg = Loc.validate(kind=kind, start=start, reg_len=reg_len)
        if msg is not None:
            return None
        return Loc(kind=kind, start=start, reg_len=reg_len)

    def __init__(self, kind, start, reg_len):
        # type: (LocKind, int, int) -> None
        msg = self.validate(kind=kind, start=start, reg_len=reg_len)
        if msg is not None:
            raise ValueError(msg)
        self.kind = kind
        self.reg_len = reg_len
        self.start = start

    def conflicts(self, other):
        # type: (Loc) -> bool
        """ True if `self` and `other` share at least one location """
        return (self.kind == other.kind
                and self.start < other.stop and other.start < self.stop)

    @staticmethod
    def make_ty(kind, reg_len):
        # type: (LocKind, int) -> Ty
        return Ty(base_ty=kind.base_ty, reg_len=reg_len)

    @cached_property
    def ty(self):
        # type: () -> Ty
        return self.make_ty(kind=self.kind, reg_len=self.reg_len)

    @property
    def stop(self):
        # type: () -> int
        """ exclusive end index """
        return self.start + self.reg_len

    def try_concat(self, *others):
        # type: (*Loc | None) -> Loc | None
        """ the `Loc` formed by `self` followed by each of `others` in order.

        returns None if any of `others` is None, has a different kind, or
        doesn't start exactly where the previous `Loc` stops -- and also if
        the combined `Loc` would be invalid (e.g. a StackI64 run longer than
        `BaseTy.I64.max_reg_len`), rather than raising.
        """
        reg_len = self.reg_len
        stop = self.stop
        for other in others:
            if other is None or other.kind != self.kind:
                return None
            if stop != other.start:
                return None
            stop = other.stop
            reg_len += other.reg_len
        # the combined reg_len can exceed the type's max_reg_len even though
        # every part is in range, so don't let the constructor raise here
        return Loc.try_make(kind=self.kind, start=self.start, reg_len=reg_len)

    def get_subloc_at_offset(self, subloc_ty, offset):
        # type: (Ty, int) -> Loc
        """ the `Loc` of type `subloc_ty` starting `offset` locations into
        `self`; raises ValueError if it doesn't fit inside `self`.
        """
        if subloc_ty.base_ty != self.kind.base_ty:
            raise ValueError("BaseTy mismatch")
        if offset < 0 or offset + subloc_ty.reg_len > self.reg_len:
            raise ValueError("invalid sub-Loc: offset and/or "
                             "subloc_ty.reg_len out of range")
        return Loc(kind=self.kind,
                   start=self.start + offset, reg_len=subloc_ty.reg_len)
665
666
# GPRs excluded from every `LocSubKind.allocatable_locs` result.
# NOTE(review): presumably r1 = stack pointer, r2 = TOC pointer and
# r13 = thread pointer per the Power ELFv2 ABI, with r0 reserved since it
# reads as zero in some addressing modes -- confirm.
SPECIAL_GPRS = tuple(
    Loc(kind=LocKind.GPR, start=reg_num, reg_len=1)
    for reg_num in (0, 1, 2, 13))
673
674
@final
class LocSet(OFSet[Loc], metaclass=InternedMeta):
    """ an immutable ordered set of `Loc`s that all share one `Ty`.

    also keeps, for each `LocKind` present, a bit-set of the `Loc`s' start
    indexes, which makes set-wide operations like `concat` cheap.
    """

    def __init__(self, __locs=()):
        # type: (Iterable[Loc]) -> None
        super().__init__(__locs)
        if isinstance(__locs, LocSet):
            # already validated, so reuse its precomputed data
            self.__starts = __locs.starts
            self.__ty = __locs.ty
            return
        starts = {i: BitSet() for i in LocKind}
        ty = None  # type: None | Ty
        for loc in self:
            if ty is None:
                ty = loc.ty
            if ty != loc.ty:
                raise ValueError(f"conflicting types: {ty} != {loc.ty}")
            starts[loc.kind].add(loc.start)
        self.__starts = FMap(
            (k, FBitSet(v)) for k, v in starts.items() if len(v) != 0)
        self.__ty = ty

    @property
    def starts(self):
        # type: () -> FMap[LocKind, FBitSet]
        """ per-`LocKind` bit-set of the start index of each `Loc` """
        return self.__starts

    @property
    def ty(self):
        # type: () -> Ty | None
        """ the shared `Ty`, or None if `self` is empty """
        return self.__ty

    @cached_property
    def stops(self):
        # type: () -> FMap[LocKind, FBitSet]
        """ per-`LocKind` bit-set of the (exclusive) stop index of each `Loc` """
        if self.ty is None:
            return FMap()
        sh = self.ty.reg_len
        return FMap(
            (k, FBitSet(bits=v.bits << sh)) for k, v in self.starts.items())

    @property
    def kinds(self):
        # type: () -> AbstractSet[LocKind]
        return self.starts.keys()

    @property
    def reg_len(self):
        # type: () -> int | None
        if self.ty is None:
            return None
        return self.ty.reg_len

    @property
    def base_ty(self):
        # type: () -> BaseTy | None
        if self.ty is None:
            return None
        return self.ty.base_ty

    def concat(self, *others):
        # type: (*LocSet) -> LocSet
        """ the `LocSet` of every `Loc` that can be formed by concatenating a
        `Loc` from `self` with one `Loc` from each of `others` in order, where
        each following `Loc` has the same `LocKind` and starts exactly where
        the previous one stops.

        returns an empty `LocSet` if `self` or any of `others` is empty or
        has a different `BaseTy` than `self`, or if no concatenation exists.
        """
        if self.ty is None:
            return LocSet()
        base_ty = self.ty.base_ty
        reg_len = self.ty.reg_len
        # start indexes (per LocKind) of concatenations still possible
        starts = {k: BitSet(v) for k, v in self.starts.items()}
        for other in others:
            if other.ty is None or other.ty.base_ty != base_ty:
                return LocSet()
            remaining = {}  # type: dict[LocKind, BitSet]
            for kind, kind_starts in starts.items():
                # kinds missing from `other` can't be continued, so drop them
                if kind not in other.starts:
                    continue
                # `other`'s Loc must start `reg_len` after the combined start
                kind_starts.bits &= other.starts[kind].bits >> reg_len
                if kind_starts.bits != 0:
                    remaining[kind] = kind_starts
            starts = remaining
            if len(starts) == 0:
                return LocSet()
            reg_len += other.ty.reg_len

        def locs():
            # type: () -> Iterable[Loc]
            for kind, v in starts.items():
                for start in v:
                    loc = Loc.try_make(kind=kind, start=start, reg_len=reg_len)
                    if loc is not None:
                        yield loc
        return LocSet(locs())

    @lru_cache(maxsize=None, typed=True)
    def max_conflicts_with(self, other):
        # type: (LocSet | Loc) -> int
        """the largest number of Locs in `self` that a single Loc
        from `other` can conflict with; 0 if `other` is an empty `LocSet`.
        """
        if isinstance(other, LocSet):
            # default=0: an empty `other` has no Loc to conflict with
            return max((self.max_conflicts_with(i) for i in other), default=0)
        else:
            return sum(other.conflicts(i) for i in self)

    def __repr__(self):
        return f"LocSet(starts={self.starts!r}, ty={self.ty!r})"
778
779
@plain_data(frozen=True, unsafe_hash=True)
@final
class GenericOperandDesc(metaclass=InternedMeta):
    """generic Op operand descriptor"""
    __slots__ = ("ty", "fixed_loc", "sub_kinds", "tied_input_index", "spread",
                 "write_stage")

    def __init__(
        self, ty,  # type: GenericTy
        sub_kinds,  # type: Iterable[LocSubKind]
        *,
        fixed_loc=None,  # type: Loc | None
        tied_input_index=None,  # type: int | None
        spread=False,  # type: bool
        write_stage=OpStage.Early,  # type: OpStage
    ):
        # type: (...) -> None
        """
        ty: the operand's generic type (before any spreading).
        sub_kinds: the non-empty set of `LocSubKind`s it may be placed in.
        fixed_loc: if set, the single `Loc` the operand must use; requires
            exactly one sub-kind containing that `Loc`.
        tied_input_index: if set, the index of the input this operand is tied
            to; can't be combined with `fixed_loc` or `spread`.
        spread: if set, `instantiate` yields one scalar operand per element
            of the vector `ty_before_spread`.
        write_stage: the `OpStage` at which the operand is written.

        Raises ValueError on any inconsistent combination.
        """
        self.ty = ty
        self.sub_kinds = OFSet(sub_kinds)
        if len(self.sub_kinds) == 0:
            raise ValueError("sub_kinds can't be empty")
        self.fixed_loc = fixed_loc
        if fixed_loc is not None:
            if tied_input_index is not None:
                raise ValueError("operand can't be both tied and fixed")
            if not ty.can_instantiate_to(fixed_loc.ty):
                raise ValueError(
                    f"fixed_loc has incompatible type for given generic "
                    f"type: fixed_loc={fixed_loc} generic ty={ty}")
            if len(self.sub_kinds) != 1:
                raise ValueError(
                    "multiple sub_kinds not allowed for fixed operand")
            (only_sub_kind,) = self.sub_kinds
            if fixed_loc not in only_sub_kind.allocatable_locs(fixed_loc.ty):
                raise ValueError(
                    f"fixed_loc not in given sub_kind: "
                    f"fixed_loc={fixed_loc} sub_kind={only_sub_kind}")
        bad_sub_kind = next(
            (sk for sk in self.sub_kinds if sk.base_ty != ty.base_ty), None)
        if bad_sub_kind is not None:
            raise ValueError(f"sub_kind is incompatible with type: "
                             f"sub_kind={bad_sub_kind} ty={ty}")
        if tied_input_index is not None and tied_input_index < 0:
            raise ValueError("invalid tied_input_index")
        self.tied_input_index = tied_input_index
        self.spread = spread
        if spread:
            if tied_input_index is not None:
                raise ValueError("operand can't be both spread and tied")
            if fixed_loc is not None:
                raise ValueError("operand can't be both spread and fixed")
            if ty.is_vec:
                raise ValueError("operand can't be both spread and vector")
        self.write_stage = write_stage

    @cached_property
    def ty_before_spread(self):
        # type: () -> GenericTy
        """ the vector type a spread operand is split from, else `ty` """
        if not self.spread:
            return self.ty
        return GenericTy(base_ty=self.ty.base_ty, is_vec=True)

    def tied_to_input(self, tied_input_index):
        # type: (int) -> Self
        """ a copy tied to input `tied_input_index`, dropping any
        `fixed_loc`/`spread`
        """
        return GenericOperandDesc(
            self.ty, self.sub_kinds, tied_input_index=tied_input_index,
            write_stage=self.write_stage)

    def with_fixed_loc(self, fixed_loc):
        # type: (Loc) -> Self
        """ a copy fixed to `fixed_loc`, dropping any tie/`spread` """
        return GenericOperandDesc(
            self.ty, self.sub_kinds, fixed_loc=fixed_loc,
            write_stage=self.write_stage)

    def with_write_stage(self, write_stage):
        # type: (OpStage) -> Self
        """ a copy with only `write_stage` replaced """
        return GenericOperandDesc(
            self.ty, self.sub_kinds, fixed_loc=self.fixed_loc,
            tied_input_index=self.tied_input_index, spread=self.spread,
            write_stage=write_stage)

    def instantiate(self, maxvl):
        # type: (int) -> Iterable[OperandDesc]
        """ lazily yield the concrete `OperandDesc`s for `maxvl`: `maxvl` of
        them (one per spread index) if spread, otherwise exactly one.
        """
        # assumes all spread operands have ty.reg_len = 1
        ty_before_spread = self.ty_before_spread.instantiate(maxvl=maxvl)
        if self.fixed_loc is not None:
            if ty_before_spread != self.fixed_loc.ty:
                raise ValueError(
                    f"instantiation failed: type mismatch with fixed_loc: "
                    f"instantiated type: {ty_before_spread} "
                    f"fixed_loc: {self.fixed_loc}")
            candidates = [self.fixed_loc]
        else:
            candidates = [
                loc
                for sub_kind in self.sub_kinds
                for loc in sub_kind.allocatable_locs(ty_before_spread)]
        loc_set_before_spread = LocSet(candidates)
        spread_indexes = range(maxvl) if self.spread else [None]
        for spread_index in spread_indexes:
            yield OperandDesc(loc_set_before_spread=loc_set_before_spread,
                              tied_input_index=self.tied_input_index,
                              spread_index=spread_index,
                              write_stage=self.write_stage)
887
888
@plain_data(frozen=True, unsafe_hash=True)
@final
class OperandDesc(metaclass=InternedMeta):
    """Op operand descriptor"""
    __slots__ = ("loc_set_before_spread", "tied_input_index", "spread_index",
                 "write_stage")

    def __init__(self, loc_set_before_spread, tied_input_index, spread_index,
                 write_stage):
        # type: (LocSet, int | None, int | None, OpStage) -> None
        """ Raises ValueError if `loc_set_before_spread` is empty or if both
        `tied_input_index` and `spread_index` are set.
        """
        if len(loc_set_before_spread) == 0:
            raise ValueError("loc_set_before_spread must not be empty")
        self.loc_set_before_spread = loc_set_before_spread
        self.tied_input_index = tied_input_index
        is_tied = tied_input_index is not None
        is_spread = spread_index is not None
        if is_tied and is_spread:
            raise ValueError("operand can't be both spread and tied")
        self.spread_index = spread_index
        self.write_stage = write_stage

    @cached_property
    def ty_before_spread(self):
        # type: () -> Ty
        """ the `Ty` of the whole (unspread) operand """
        ty = self.loc_set_before_spread.ty
        assert ty is not None, (
            "__init__ checked that the LocSet isn't empty, "
            "non-empty LocSets should always have ty set")
        return ty

    @cached_property
    def ty(self):
        # type: () -> Ty
        """ Ty after any spread is applied """
        if self.spread_index is None:
            return self.ty_before_spread
        # assumes all spread operands have ty.reg_len = 1
        return Ty(base_ty=self.ty_before_spread.base_ty, reg_len=1)

    @property
    def reg_offset_in_unspread(self):
        # type: () -> int
        """ the number of reg-sized slots in the unspread Loc before self's Loc

        e.g. if the unspread Loc containing self is:
        `Loc(kind=LocKind.GPR, start=8, reg_len=4)`
        and self's Loc is `Loc(kind=LocKind.GPR, start=10, reg_len=1)`
        then reg_offset_in_unspread == 2 == 10 - 8
        """
        if self.spread_index is not None:
            return self.spread_index * self.ty.reg_len
        return 0
937
938
# commonly-used generic operand descriptors.
# *_SGPR: scalar (is_vec=False) I64 operand; *_VGPR: vector (is_vec=True)
# I64 operand; the middle part names the LocSubKind it may be placed in.
OD_BASE_SGPR = GenericOperandDesc(
    sub_kinds=[LocSubKind.BASE_GPR],
    ty=GenericTy(base_ty=BaseTy.I64, is_vec=False))
OD_EXTRA3_SGPR = GenericOperandDesc(
    sub_kinds=[LocSubKind.SV_EXTRA3_SGPR],
    ty=GenericTy(base_ty=BaseTy.I64, is_vec=False))
OD_EXTRA3_VGPR = GenericOperandDesc(
    sub_kinds=[LocSubKind.SV_EXTRA3_VGPR],
    ty=GenericTy(base_ty=BaseTy.I64, is_vec=True))
OD_EXTRA2_SGPR = GenericOperandDesc(
    sub_kinds=[LocSubKind.SV_EXTRA2_SGPR],
    ty=GenericTy(base_ty=BaseTy.I64, is_vec=False))
OD_EXTRA2_VGPR = GenericOperandDesc(
    sub_kinds=[LocSubKind.SV_EXTRA2_VGPR],
    ty=GenericTy(base_ty=BaseTy.I64, is_vec=True))
# scalar carry operand
OD_CA = GenericOperandDesc(
    sub_kinds=[LocSubKind.CA],
    ty=GenericTy(base_ty=BaseTy.CA, is_vec=False))
# scalar VL/MAXVL operand
OD_VL = GenericOperandDesc(
    sub_kinds=[LocSubKind.VL_MAXVL],
    ty=GenericTy(base_ty=BaseTy.VL_MAXVL, is_vec=False))
960
961
@plain_data(frozen=True, unsafe_hash=True)
@final
class GenericOpProperties(metaclass=InternedMeta):
    """ the properties of an Op kind before being instantiated for a
    particular `maxvl` (see `OpProperties`).

    validates that inputs are neither tied nor written at a non-Early stage,
    that tied outputs match their input, and that fixed output Locs don't
    overlap; raises ValueError otherwise.
    """
    __slots__ = ("demo_asm", "inputs", "outputs", "immediates",
                 "is_copy", "is_load_immediate", "has_side_effects")

    def __init__(
        self, demo_asm,  # type: str
        inputs,  # type: Iterable[GenericOperandDesc]
        outputs,  # type: Iterable[GenericOperandDesc]
        immediates=(),  # type: Iterable[range]
        is_copy=False,  # type: bool
        is_load_immediate=False,  # type: bool
        has_side_effects=False,  # type: bool
    ):
        # type: (...) -> None
        self.demo_asm = demo_asm  # type: str
        self.inputs = tuple(inputs)  # type: tuple[GenericOperandDesc, ...]
        for inp in self.inputs:
            if inp.tied_input_index is not None:
                raise ValueError(
                    f"tied_input_index is not allowed on inputs: {inp}")
            if inp.write_stage is not OpStage.Early:
                raise ValueError(
                    f"write_stage is not allowed on inputs: {inp}")
        self.outputs = tuple(outputs)  # type: tuple[GenericOperandDesc, ...]
        # (fixed Loc, output index) of every fixed output seen so far
        seen_fixed = []  # type: list[tuple[Loc, int]]
        for idx, out in enumerate(self.outputs):
            tied_index = out.tied_input_index
            if tied_index is not None:
                if tied_index >= len(self.inputs):
                    raise ValueError(f"tied_input_index out of range: {out}")
                tied_inp = self.inputs[tied_index]
                # a tied output must be its input's descriptor, apart from
                # the tie itself and the write stage
                expected_out = tied_inp.tied_to_input(tied_index) \
                    .with_write_stage(out.write_stage)
                if expected_out != out:
                    raise ValueError(f"output can't be tied to non-equivalent "
                                     f"input: {out} tied to {tied_inp}")
            if out.fixed_loc is None:
                continue
            for other_fixed_loc, other_idx in seen_fixed:
                if other_fixed_loc.conflicts(out.fixed_loc):
                    raise ValueError(
                        f"conflicting fixed_locs: outputs[{idx}] and "
                        f"outputs[{other_idx}]: {out.fixed_loc} conflicts "
                        f"with {other_fixed_loc}")
            seen_fixed.append((out.fixed_loc, idx))
        self.immediates = tuple(immediates)  # type: tuple[range, ...]
        self.is_copy = is_copy  # type: bool
        self.is_load_immediate = is_load_immediate  # type: bool
        self.has_side_effects = has_side_effects  # type: bool
1012
1013
@plain_data(frozen=True, unsafe_hash=True)
@final
class OpProperties(metaclass=InternedMeta):
    """The properties of an `OpKind` instantiated for a specific `maxvl`.

    `inputs` and `outputs` hold the concrete `OperandDesc`s obtained by
    instantiating each generic operand descriptor of `kind.properties`
    with `maxvl`, in order. The remaining attributes are read straight
    from the generic properties.
    """
    __slots__ = "kind", "inputs", "outputs", "maxvl"

    def __init__(self, kind, maxvl):
        # type: (OpKind, int) -> None
        self.kind = kind  # type: OpKind
        # `generic` reads `self.kind`, so `kind` has to be stored first
        generic = self.generic
        self.inputs = tuple(
            operand
            for generic_input in generic.inputs
            for operand in generic_input.instantiate(maxvl=maxvl)
        )  # type: tuple[OperandDesc, ...]
        self.outputs = tuple(
            operand
            for generic_output in generic.outputs
            for operand in generic_output.instantiate(maxvl=maxvl)
        )  # type: tuple[OperandDesc, ...]
        self.maxvl = maxvl  # type: int

    @property
    def generic(self):
        # type: () -> GenericOpProperties
        """The MAXVL-independent properties of `kind`."""
        kind = self.kind
        return kind.properties

    @property
    def immediates(self):
        # type: () -> tuple[range, ...]
        """Allowed range of each immediate, from the generic properties."""
        generic = self.generic
        return generic.immediates

    @property
    def demo_asm(self):
        # type: () -> str
        """Assembly template from the generic properties."""
        generic = self.generic
        return generic.demo_asm

    @property
    def is_copy(self):
        # type: () -> bool
        """Whether this kind is a copy, per the generic properties."""
        generic = self.generic
        return generic.is_copy

    @property
    def is_load_immediate(self):
        # type: () -> bool
        """Whether this kind loads an immediate, per the generic properties."""
        generic = self.generic
        return generic.is_load_immediate

    @property
    def has_side_effects(self):
        # type: () -> bool
        """Whether this kind has side effects, per the generic properties."""
        generic = self.generic
        return generic.has_side_effects
1061
1062
# every value representable in a signed 16-bit immediate field
IMM_S16 = range(-(1 << 15), 1 << 15)

# Simulation callbacks, stored as zero-argument factories keyed by each
# OpKind's GenericOpProperties (filled in by the OpKind class body).
_SIM_FN = Callable[["Op", "BaseSimState"], None]
_SIM_FN2 = Callable[[], _SIM_FN]
_SIM_FNS = {}  # type: dict[GenericOpProperties | Any, _SIM_FN2]
# Assembly-generation callbacks, registered the same way as _SIM_FNS.
_GEN_ASM_FN = Callable[["Op", "GenAsmState"], None]
_GEN_ASM_FN2 = Callable[[], _GEN_ASM_FN]
_GEN_ASMS = {}  # type: dict[GenericOpProperties | Any, _GEN_ASM_FN2]
1071
1072
@unique
@final
class OpKind(Enum):
    """Every kind of operation the IR supports.

    Each member's value is the `GenericOpProperties` describing that kind's
    operands and immediates. Next to each member, the class body registers
    two zero-argument factories keyed by that `GenericOpProperties` value:
    one for its simulation function in `_SIM_FNS`, and one for its assembly
    generator in `_GEN_ASMS`. The factories are lambdas, so the
    name-mangled private static methods are looked up through `OpKind`
    only after the class exists. `sim` and `gen_asm` call the factory once
    per member and cache the result.
    """

    def __init__(self, properties):
        # type: (GenericOpProperties) -> None
        super().__init__()
        self.__properties = properties

    @property
    def properties(self):
        # type: () -> GenericOpProperties
        """The MAXVL-independent properties this member was defined with."""
        return self.__properties

    def instantiate(self, maxvl):
        # type: (int) -> OpProperties
        """Return the `OpProperties` of this kind for the given `maxvl`."""
        return OpProperties(self, maxvl=maxvl)

    def __repr__(self):
        # type: () -> str
        return "OpKind." + self._name_

    @cached_property
    def sim(self):
        # type: () -> _SIM_FN
        """Function simulating an `Op` of this kind on a `BaseSimState`."""
        return _SIM_FNS[self.properties]()

    @cached_property
    def gen_asm(self):
        # type: () -> _GEN_ASM_FN
        """Function writing the assembly of an `Op` of this kind."""
        return _GEN_ASMS[self.properties]()

    # ClearCA: set the carry flag to 0
    @staticmethod
    def __clearca_sim(op, state):
        # type: (Op, BaseSimState) -> None
        state[op.outputs[0]] = False,

    @staticmethod
    def __clearca_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        state.writeln("addic 0, 0, 0")
    ClearCA = GenericOpProperties(
        demo_asm="addic 0, 0, 0",
        inputs=[],
        outputs=[OD_CA.with_write_stage(OpStage.Late)],
    )
    _SIM_FNS[ClearCA] = lambda: OpKind.__clearca_sim
    _GEN_ASMS[ClearCA] = lambda: OpKind.__clearca_gen_asm

    # SetCA: set the carry flag to 1
    @staticmethod
    def __setca_sim(op, state):
        # type: (Op, BaseSimState) -> None
        state[op.outputs[0]] = True,

    @staticmethod
    def __setca_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        state.writeln("subfc 0, 0, 0")
    SetCA = GenericOpProperties(
        demo_asm="subfc 0, 0, 0",
        inputs=[],
        outputs=[OD_CA.with_write_stage(OpStage.Late)],
    )
    _SIM_FNS[SetCA] = lambda: OpKind.__setca_sim
    _GEN_ASMS[SetCA] = lambda: OpKind.__setca_gen_asm

    # SvAddE: element-wise RT[i] = RA[i] + RB[i] + carry over VL elements,
    # with the carry chained from element to element; outputs RT and the
    # final carry (tied to the carry input)
    @staticmethod
    def __svadde_sim(op, state):
        # type: (Op, BaseSimState) -> None
        RA = state[op.input_vals[0]]
        RB = state[op.input_vals[1]]
        carry, = state[op.input_vals[2]]
        VL, = state[op.input_vals[3]]
        RT = []  # type: list[int]
        for i in range(VL):
            v = RA[i] + RB[i] + carry
            RT.append(v & GPR_VALUE_MASK)
            carry = (v >> GPR_SIZE_IN_BITS) != 0
        state[op.outputs[0]] = tuple(RT)
        state[op.outputs[1]] = carry,

    @staticmethod
    def __svadde_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        RT = state.vgpr(op.outputs[0])
        RA = state.vgpr(op.input_vals[0])
        RB = state.vgpr(op.input_vals[1])
        state.writeln(f"sv.adde {RT}, {RA}, {RB}")
    SvAddE = GenericOpProperties(
        demo_asm="sv.adde *RT, *RA, *RB",
        inputs=[OD_EXTRA3_VGPR, OD_EXTRA3_VGPR, OD_CA, OD_VL],
        outputs=[OD_EXTRA3_VGPR, OD_CA.tied_to_input(2)],
    )
    _SIM_FNS[SvAddE] = lambda: OpKind.__svadde_sim
    _GEN_ASMS[SvAddE] = lambda: OpKind.__svadde_gen_asm

    # AddZE: scalar RT = RA + carry, producing the new carry
    @staticmethod
    def __addze_sim(op, state):
        # type: (Op, BaseSimState) -> None
        RA, = state[op.input_vals[0]]
        carry, = state[op.input_vals[1]]
        v = RA + carry
        RT = v & GPR_VALUE_MASK
        carry = (v >> GPR_SIZE_IN_BITS) != 0
        state[op.outputs[0]] = RT,
        state[op.outputs[1]] = carry,

    @staticmethod
    def __addze_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        # scalar (non-SV) instruction on OD_BASE_SGPR operands, so format the
        # registers with `sgpr` like every other scalar op
        RT = state.sgpr(op.outputs[0])
        RA = state.sgpr(op.input_vals[0])
        state.writeln(f"addze {RT}, {RA}")
    AddZE = GenericOpProperties(
        demo_asm="addze RT, RA",
        inputs=[OD_BASE_SGPR, OD_CA],
        outputs=[OD_BASE_SGPR, OD_CA.tied_to_input(1)],
    )
    _SIM_FNS[AddZE] = lambda: OpKind.__addze_sim
    _GEN_ASMS[AddZE] = lambda: OpKind.__addze_gen_asm

    # SvSubFE: element-wise RT[i] = ~RA[i] + RB[i] + carry over VL elements,
    # carry chained between elements; outputs RT and the final carry
    @staticmethod
    def __svsubfe_sim(op, state):
        # type: (Op, BaseSimState) -> None
        RA = state[op.input_vals[0]]
        RB = state[op.input_vals[1]]
        carry, = state[op.input_vals[2]]
        VL, = state[op.input_vals[3]]
        RT = []  # type: list[int]
        for i in range(VL):
            v = (~RA[i] & GPR_VALUE_MASK) + RB[i] + carry
            RT.append(v & GPR_VALUE_MASK)
            carry = (v >> GPR_SIZE_IN_BITS) != 0
        state[op.outputs[0]] = tuple(RT)
        state[op.outputs[1]] = carry,

    @staticmethod
    def __svsubfe_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        RT = state.vgpr(op.outputs[0])
        RA = state.vgpr(op.input_vals[0])
        RB = state.vgpr(op.input_vals[1])
        state.writeln(f"sv.subfe {RT}, {RA}, {RB}")
    SvSubFE = GenericOpProperties(
        demo_asm="sv.subfe *RT, *RA, *RB",
        inputs=[OD_EXTRA3_VGPR, OD_EXTRA3_VGPR, OD_CA, OD_VL],
        outputs=[OD_EXTRA3_VGPR, OD_CA.tied_to_input(2)],
    )
    _SIM_FNS[SvSubFE] = lambda: OpKind.__svsubfe_sim
    _GEN_ASMS[SvSubFE] = lambda: OpKind.__svsubfe_gen_asm

    # SvAndVS: element-wise RT[i] = RA[i] & RB with a vector RA and scalar RB
    @staticmethod
    def __svandvs_sim(op, state):
        # type: (Op, BaseSimState) -> None
        RA = state[op.input_vals[0]]
        RB, = state[op.input_vals[1]]
        VL, = state[op.input_vals[2]]
        RT = []  # type: list[int]
        for i in range(VL):
            RT.append(RA[i] & RB & GPR_VALUE_MASK)
        state[op.outputs[0]] = tuple(RT)

    @staticmethod
    def __svandvs_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        RT = state.vgpr(op.outputs[0])
        RA = state.vgpr(op.input_vals[0])
        RB = state.sgpr(op.input_vals[1])
        state.writeln(f"sv.and {RT}, {RA}, {RB}")
    SvAndVS = GenericOpProperties(
        demo_asm="sv.and *RT, *RA, RB",
        inputs=[OD_EXTRA3_VGPR, OD_EXTRA3_SGPR, OD_VL],
        outputs=[OD_EXTRA3_VGPR],
    )
    _SIM_FNS[SvAndVS] = lambda: OpKind.__svandvs_sim
    _GEN_ASMS[SvAndVS] = lambda: OpKind.__svandvs_gen_asm

    # SvMAddEDU: element-wise RT[i] = low64(RA[i] * RB + carry), with the
    # high part becoming the carry for the next element; the carry starts as
    # RC and the final carry is written back to RC's register (tied output)
    @staticmethod
    def __svmaddedu_sim(op, state):
        # type: (Op, BaseSimState) -> None
        RA = state[op.input_vals[0]]
        RB, = state[op.input_vals[1]]
        carry, = state[op.input_vals[2]]
        VL, = state[op.input_vals[3]]
        RT = []  # type: list[int]
        for i in range(VL):
            v = RA[i] * RB + carry
            RT.append(v & GPR_VALUE_MASK)
            carry = v >> GPR_SIZE_IN_BITS
        state[op.outputs[0]] = tuple(RT)
        state[op.outputs[1]] = carry,

    @staticmethod
    def __svmaddedu_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        RT = state.vgpr(op.outputs[0])
        RA = state.vgpr(op.input_vals[0])
        RB = state.sgpr(op.input_vals[1])
        RC = state.sgpr(op.input_vals[2])
        state.writeln(f"sv.maddedu {RT}, {RA}, {RB}, {RC}")
    # NOTE(review): the RT output uses OD_EXTRA3_VGPR while every other
    # register operand here uses an EXTRA2 descriptor -- confirm intended.
    SvMAddEDU = GenericOpProperties(
        demo_asm="sv.maddedu *RT, *RA, RB, RC",
        inputs=[OD_EXTRA2_VGPR, OD_EXTRA2_SGPR, OD_EXTRA2_SGPR, OD_VL],
        outputs=[OD_EXTRA3_VGPR, OD_EXTRA2_SGPR.tied_to_input(2)],
    )
    _SIM_FNS[SvMAddEDU] = lambda: OpKind.__svmaddedu_sim
    _GEN_ASMS[SvMAddEDU] = lambda: OpKind.__svmaddedu_gen_asm

    # SRADI: arithmetic shift right of RS by `imm`. As in the Power ISA, CA
    # is set only when RS is negative and at least one 1-bit is shifted out.
    @staticmethod
    def __sradi_sim(op, state):
        # type: (Op, BaseSimState) -> None
        rs, = state[op.input_vals[0]]
        imm = op.immediates[0]
        # reinterpret the unsigned register value as signed
        if rs >= 1 << (GPR_SIZE_IN_BITS - 1):
            rs -= 1 << GPR_SIZE_IN_BITS
        v = rs >> imm  # Python's >> on int is an arithmetic (flooring) shift
        RA = v & GPR_VALUE_MASK
        # `v << imm` is RS with the shifted-out bits cleared, so it differs
        # from RS exactly when 1-bits were shifted out. Use the signed `v`,
        # not the masked RA, which would make CA true for every negative RS
        # and for non-negative RS with 1-bits shifted out.
        CA = rs < 0 and (v << imm) != rs
        state[op.outputs[0]] = RA,
        state[op.outputs[1]] = CA,

    @staticmethod
    def __sradi_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        RA = state.sgpr(op.outputs[0])
        RS = state.sgpr(op.input_vals[0])
        imm = op.immediates[0]
        state.writeln(f"sradi {RA}, {RS}, {imm}")
    SRADI = GenericOpProperties(
        demo_asm="sradi RA, RS, imm",
        inputs=[OD_BASE_SGPR],
        outputs=[OD_BASE_SGPR.with_write_stage(OpStage.Late),
                 OD_CA.with_write_stage(OpStage.Late)],
        immediates=[range(GPR_SIZE_IN_BITS)],
    )
    _SIM_FNS[SRADI] = lambda: OpKind.__sradi_sim
    _GEN_ASMS[SRADI] = lambda: OpKind.__sradi_gen_asm

    # SetVLI: set VL to the immediate (1..64)
    @staticmethod
    def __setvli_sim(op, state):
        # type: (Op, BaseSimState) -> None
        state[op.outputs[0]] = op.immediates[0],

    @staticmethod
    def __setvli_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        imm = op.immediates[0]
        state.writeln(f"setvl 0, 0, {imm}, 0, 1, 1")
    SetVLI = GenericOpProperties(
        demo_asm="setvl 0, 0, imm, 0, 1, 1",
        inputs=(),
        outputs=[OD_VL.with_write_stage(OpStage.Late)],
        immediates=[range(1, 65)],
        is_load_immediate=True,
    )
    _SIM_FNS[SetVLI] = lambda: OpKind.__setvli_sim
    _GEN_ASMS[SetVLI] = lambda: OpKind.__setvli_gen_asm

    # SvLI: fill all VL elements of a vector with a signed 16-bit immediate
    @staticmethod
    def __svli_sim(op, state):
        # type: (Op, BaseSimState) -> None
        VL, = state[op.input_vals[0]]
        imm = op.immediates[0] & GPR_VALUE_MASK
        state[op.outputs[0]] = (imm,) * VL

    @staticmethod
    def __svli_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        RT = state.vgpr(op.outputs[0])
        imm = op.immediates[0]
        state.writeln(f"sv.addi {RT}, 0, {imm}")
    SvLI = GenericOpProperties(
        demo_asm="sv.addi *RT, 0, imm",
        inputs=[OD_VL],
        outputs=[OD_EXTRA3_VGPR],
        immediates=[IMM_S16],
        is_load_immediate=True,
    )
    _SIM_FNS[SvLI] = lambda: OpKind.__svli_sim
    _GEN_ASMS[SvLI] = lambda: OpKind.__svli_gen_asm

    # LI: load a signed 16-bit immediate into a scalar register
    @staticmethod
    def __li_sim(op, state):
        # type: (Op, BaseSimState) -> None
        imm = op.immediates[0] & GPR_VALUE_MASK
        state[op.outputs[0]] = imm,

    @staticmethod
    def __li_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        RT = state.sgpr(op.outputs[0])
        imm = op.immediates[0]
        state.writeln(f"addi {RT}, 0, {imm}")
    LI = GenericOpProperties(
        demo_asm="addi RT, 0, imm",
        inputs=(),
        outputs=[OD_BASE_SGPR.with_write_stage(OpStage.Late)],
        immediates=[IMM_S16],
        is_load_immediate=True,
    )
    _SIM_FNS[LI] = lambda: OpKind.__li_sim
    _GEN_ASMS[LI] = lambda: OpKind.__li_gen_asm

    # VecCopyToReg: copy a vector from GPRs or the stack into GPRs
    @staticmethod
    def __veccopytoreg_sim(op, state):
        # type: (Op, BaseSimState) -> None
        state[op.outputs[0]] = state[op.input_vals[0]]

    @staticmethod
    def __copy_to_from_reg_gen_asm(src_loc, dest_loc, is_vec, state):
        # type: (Loc, Loc, bool, GenAsmState) -> None
        """Emit the instruction copying `src_loc` to `dest_loc`.

        Emits nothing when the locations are equal. A stack source becomes a
        load (`ld`), a stack destination a store (`std`), and a GPR-to-GPR
        copy an `or` of the source with itself. `sv.` is prefixed when
        `is_vec`. Stack-to-stack copies and non-GPR/StackI64 kinds raise
        ValueError.
        """
        sv = "sv." if is_vec else ""
        rev = ""
        # overlapping ranges where dest starts above src: presumably `/mrr`
        # makes the element loop run in reverse, so source elements are read
        # before being overwritten
        if src_loc.conflicts(dest_loc) and src_loc.start < dest_loc.start:
            rev = "/mrr"
        if src_loc == dest_loc:
            return  # no-op
        if src_loc.kind not in (LocKind.GPR, LocKind.StackI64):
            raise ValueError(f"invalid src_loc.kind: {src_loc.kind}")
        if dest_loc.kind not in (LocKind.GPR, LocKind.StackI64):
            raise ValueError(f"invalid dest_loc.kind: {dest_loc.kind}")
        if src_loc.kind is LocKind.StackI64:
            if dest_loc.kind is LocKind.StackI64:
                raise ValueError(
                    f"can't copy from stack to stack: {src_loc} {dest_loc}")
            elif dest_loc.kind is not LocKind.GPR:
                assert_never(dest_loc.kind)
            src = state.stack(src_loc)
            dest = state.gpr(dest_loc, is_vec=is_vec)
            state.writeln(f"{sv}ld {dest}, {src}")
        elif dest_loc.kind is LocKind.StackI64:
            if src_loc.kind is not LocKind.GPR:
                assert_never(src_loc.kind)
            src = state.gpr(src_loc, is_vec=is_vec)
            dest = state.stack(dest_loc)
            state.writeln(f"{sv}std {src}, {dest}")
        elif src_loc.kind is LocKind.GPR:
            if dest_loc.kind is not LocKind.GPR:
                assert_never(dest_loc.kind)
            src = state.gpr(src_loc, is_vec=is_vec)
            dest = state.gpr(dest_loc, is_vec=is_vec)
            state.writeln(f"{sv}or{rev} {dest}, {src}, {src}")
        else:
            assert_never(src_loc.kind)

    @staticmethod
    def __veccopytoreg_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        OpKind.__copy_to_from_reg_gen_asm(
            src_loc=state.loc(
                op.input_vals[0], (LocKind.GPR, LocKind.StackI64)),
            dest_loc=state.loc(op.outputs[0], LocKind.GPR),
            is_vec=True, state=state)

    VecCopyToReg = GenericOpProperties(
        demo_asm="sv.mv dest, src",
        inputs=[GenericOperandDesc(
            ty=GenericTy(BaseTy.I64, is_vec=True),
            sub_kinds=[LocSubKind.SV_EXTRA3_VGPR, LocSubKind.StackI64],
        ), OD_VL],
        outputs=[OD_EXTRA3_VGPR.with_write_stage(OpStage.Late)],
        is_copy=True,
    )
    _SIM_FNS[VecCopyToReg] = lambda: OpKind.__veccopytoreg_sim
    _GEN_ASMS[VecCopyToReg] = lambda: OpKind.__veccopytoreg_gen_asm

    # VecCopyFromReg: copy a vector from GPRs into GPRs or the stack
    @staticmethod
    def __veccopyfromreg_sim(op, state):
        # type: (Op, BaseSimState) -> None
        state[op.outputs[0]] = state[op.input_vals[0]]

    @staticmethod
    def __veccopyfromreg_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        OpKind.__copy_to_from_reg_gen_asm(
            src_loc=state.loc(op.input_vals[0], LocKind.GPR),
            dest_loc=state.loc(
                op.outputs[0], (LocKind.GPR, LocKind.StackI64)),
            is_vec=True, state=state)
    VecCopyFromReg = GenericOpProperties(
        demo_asm="sv.mv dest, src",
        inputs=[OD_EXTRA3_VGPR, OD_VL],
        outputs=[GenericOperandDesc(
            ty=GenericTy(BaseTy.I64, is_vec=True),
            sub_kinds=[LocSubKind.SV_EXTRA3_VGPR, LocSubKind.StackI64],
            write_stage=OpStage.Late,
        )],
        is_copy=True,
    )
    _SIM_FNS[VecCopyFromReg] = lambda: OpKind.__veccopyfromreg_sim
    _GEN_ASMS[VecCopyFromReg] = lambda: OpKind.__veccopyfromreg_gen_asm

    # CopyToReg: copy a scalar from a GPR or the stack into a GPR
    @staticmethod
    def __copytoreg_sim(op, state):
        # type: (Op, BaseSimState) -> None
        state[op.outputs[0]] = state[op.input_vals[0]]

    @staticmethod
    def __copytoreg_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        OpKind.__copy_to_from_reg_gen_asm(
            src_loc=state.loc(
                op.input_vals[0], (LocKind.GPR, LocKind.StackI64)),
            dest_loc=state.loc(op.outputs[0], LocKind.GPR),
            is_vec=False, state=state)
    CopyToReg = GenericOpProperties(
        demo_asm="mv dest, src",
        inputs=[GenericOperandDesc(
            ty=GenericTy(BaseTy.I64, is_vec=False),
            sub_kinds=[LocSubKind.SV_EXTRA3_SGPR, LocSubKind.BASE_GPR,
                       LocSubKind.StackI64],
        )],
        outputs=[GenericOperandDesc(
            ty=GenericTy(BaseTy.I64, is_vec=False),
            sub_kinds=[LocSubKind.SV_EXTRA3_SGPR, LocSubKind.BASE_GPR],
            write_stage=OpStage.Late,
        )],
        is_copy=True,
    )
    _SIM_FNS[CopyToReg] = lambda: OpKind.__copytoreg_sim
    _GEN_ASMS[CopyToReg] = lambda: OpKind.__copytoreg_gen_asm

    # CopyFromReg: copy a scalar from a GPR into a GPR or the stack
    @staticmethod
    def __copyfromreg_sim(op, state):
        # type: (Op, BaseSimState) -> None
        state[op.outputs[0]] = state[op.input_vals[0]]

    @staticmethod
    def __copyfromreg_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        OpKind.__copy_to_from_reg_gen_asm(
            src_loc=state.loc(op.input_vals[0], LocKind.GPR),
            dest_loc=state.loc(
                op.outputs[0], (LocKind.GPR, LocKind.StackI64)),
            is_vec=False, state=state)
    CopyFromReg = GenericOpProperties(
        demo_asm="mv dest, src",
        inputs=[GenericOperandDesc(
            ty=GenericTy(BaseTy.I64, is_vec=False),
            sub_kinds=[LocSubKind.SV_EXTRA3_SGPR, LocSubKind.BASE_GPR],
        )],
        outputs=[GenericOperandDesc(
            ty=GenericTy(BaseTy.I64, is_vec=False),
            sub_kinds=[LocSubKind.SV_EXTRA3_SGPR, LocSubKind.BASE_GPR,
                       LocSubKind.StackI64],
            write_stage=OpStage.Late,
        )],
        is_copy=True,
    )
    _SIM_FNS[CopyFromReg] = lambda: OpKind.__copyfromreg_sim
    _GEN_ASMS[CopyFromReg] = lambda: OpKind.__copyfromreg_gen_asm

    # Concat: gather the spread scalar inputs (all inputs except the last,
    # which is VL) into one vector
    @staticmethod
    def __concat_sim(op, state):
        # type: (Op, BaseSimState) -> None
        state[op.outputs[0]] = tuple(
            state[i][0] for i in op.input_vals[:-1])

    @staticmethod
    def __concat_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        OpKind.__copy_to_from_reg_gen_asm(
            src_loc=state.loc(op.input_vals[0:-1], LocKind.GPR),
            dest_loc=state.loc(op.outputs[0], LocKind.GPR),
            is_vec=True, state=state)
    Concat = GenericOpProperties(
        demo_asm="sv.mv dest, src",
        inputs=[GenericOperandDesc(
            ty=GenericTy(BaseTy.I64, is_vec=False),
            sub_kinds=[LocSubKind.SV_EXTRA3_VGPR],
            spread=True,
        ), OD_VL],
        outputs=[OD_EXTRA3_VGPR.with_write_stage(OpStage.Late)],
        is_copy=True,
    )
    _SIM_FNS[Concat] = lambda: OpKind.__concat_sim
    _GEN_ASMS[Concat] = lambda: OpKind.__concat_gen_asm

    # Spread: split a vector into one scalar output per element
    @staticmethod
    def __spread_sim(op, state):
        # type: (Op, BaseSimState) -> None
        for idx, inp in enumerate(state[op.input_vals[0]]):
            state[op.outputs[idx]] = inp,

    @staticmethod
    def __spread_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        OpKind.__copy_to_from_reg_gen_asm(
            src_loc=state.loc(op.input_vals[0], LocKind.GPR),
            dest_loc=state.loc(op.outputs, LocKind.GPR),
            is_vec=True, state=state)
    Spread = GenericOpProperties(
        demo_asm="sv.mv dest, src",
        inputs=[OD_EXTRA3_VGPR, OD_VL],
        outputs=[GenericOperandDesc(
            ty=GenericTy(BaseTy.I64, is_vec=False),
            sub_kinds=[LocSubKind.SV_EXTRA3_VGPR],
            spread=True,
            write_stage=OpStage.Late,
        )],
        is_copy=True,
    )
    _SIM_FNS[Spread] = lambda: OpKind.__spread_sim
    _GEN_ASMS[Spread] = lambda: OpKind.__spread_gen_asm

    # SvLd: load VL consecutive 64-bit words starting at RA + imm
    @staticmethod
    def __svld_sim(op, state):
        # type: (Op, BaseSimState) -> None
        RA, = state[op.input_vals[0]]
        VL, = state[op.input_vals[1]]
        addr = RA + op.immediates[0]
        RT = []  # type: list[int]
        for i in range(VL):
            v = state.load(addr + GPR_SIZE_IN_BYTES * i)
            RT.append(v & GPR_VALUE_MASK)
        state[op.outputs[0]] = tuple(RT)

    @staticmethod
    def __svld_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        RA = state.sgpr(op.input_vals[0])
        RT = state.vgpr(op.outputs[0])
        imm = op.immediates[0]
        state.writeln(f"sv.ld {RT}, {imm}({RA})")
    SvLd = GenericOpProperties(
        demo_asm="sv.ld *RT, imm(RA)",
        inputs=[OD_EXTRA3_SGPR, OD_VL],
        outputs=[OD_EXTRA3_VGPR],
        immediates=[IMM_S16],
    )
    _SIM_FNS[SvLd] = lambda: OpKind.__svld_sim
    _GEN_ASMS[SvLd] = lambda: OpKind.__svld_gen_asm

    # Ld: load one 64-bit word from RA + imm
    @staticmethod
    def __ld_sim(op, state):
        # type: (Op, BaseSimState) -> None
        RA, = state[op.input_vals[0]]
        addr = RA + op.immediates[0]
        v = state.load(addr)
        state[op.outputs[0]] = v & GPR_VALUE_MASK,

    @staticmethod
    def __ld_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        RA = state.sgpr(op.input_vals[0])
        RT = state.sgpr(op.outputs[0])
        imm = op.immediates[0]
        state.writeln(f"ld {RT}, {imm}({RA})")
    Ld = GenericOpProperties(
        demo_asm="ld RT, imm(RA)",
        inputs=[OD_BASE_SGPR],
        outputs=[OD_BASE_SGPR.with_write_stage(OpStage.Late)],
        immediates=[IMM_S16],
    )
    _SIM_FNS[Ld] = lambda: OpKind.__ld_sim
    _GEN_ASMS[Ld] = lambda: OpKind.__ld_gen_asm

    # SvStd: store VL elements as consecutive 64-bit words at RA + imm
    @staticmethod
    def __svstd_sim(op, state):
        # type: (Op, BaseSimState) -> None
        RS = state[op.input_vals[0]]
        RA, = state[op.input_vals[1]]
        VL, = state[op.input_vals[2]]
        addr = RA + op.immediates[0]
        for i in range(VL):
            state.store(addr + GPR_SIZE_IN_BYTES * i, value=RS[i])

    @staticmethod
    def __svstd_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        RS = state.vgpr(op.input_vals[0])
        RA = state.sgpr(op.input_vals[1])
        imm = op.immediates[0]
        state.writeln(f"sv.std {RS}, {imm}({RA})")
    SvStd = GenericOpProperties(
        demo_asm="sv.std *RS, imm(RA)",
        inputs=[OD_EXTRA3_VGPR, OD_EXTRA3_SGPR, OD_VL],
        outputs=[],
        immediates=[IMM_S16],
        has_side_effects=True,
    )
    _SIM_FNS[SvStd] = lambda: OpKind.__svstd_sim
    _GEN_ASMS[SvStd] = lambda: OpKind.__svstd_gen_asm

    # Std: store one 64-bit word at RA + imm
    @staticmethod
    def __std_sim(op, state):
        # type: (Op, BaseSimState) -> None
        RS, = state[op.input_vals[0]]
        RA, = state[op.input_vals[1]]
        addr = RA + op.immediates[0]
        state.store(addr, value=RS)

    @staticmethod
    def __std_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        RS = state.sgpr(op.input_vals[0])
        RA = state.sgpr(op.input_vals[1])
        imm = op.immediates[0]
        state.writeln(f"std {RS}, {imm}({RA})")
    Std = GenericOpProperties(
        demo_asm="std RS, imm(RA)",
        inputs=[OD_BASE_SGPR, OD_BASE_SGPR],
        outputs=[],
        immediates=[IMM_S16],
        has_side_effects=True,
    )
    _SIM_FNS[Std] = lambda: OpKind.__std_sim
    _GEN_ASMS[Std] = lambda: OpKind.__std_gen_asm

    # FuncArgR3: the function argument already present in r3 (fixed loc)
    @staticmethod
    def __funcargr3_sim(op, state):
        # type: (Op, BaseSimState) -> None
        pass  # return value set before simulation

    @staticmethod
    def __funcargr3_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        pass  # no instructions needed
    FuncArgR3 = GenericOpProperties(
        demo_asm="",
        inputs=[],
        outputs=[OD_BASE_SGPR.with_fixed_loc(
            Loc(kind=LocKind.GPR, start=3, reg_len=1))],
    )
    _SIM_FNS[FuncArgR3] = lambda: OpKind.__funcargr3_sim
    _GEN_ASMS[FuncArgR3] = lambda: OpKind.__funcargr3_gen_asm
1698
1699
@plain_data(frozen=True, unsafe_hash=True, repr=False)
class SSAValOrUse(metaclass=InternedMeta):
    """Reference to one operand slot (`operand_idx`) of an `Op`.

    Subclasses choose which operand list is referenced by providing
    `descriptor_array` (outputs for `SSAVal`, inputs for `SSAUse`). The
    type/location helpers below all derive from the descriptor at
    `operand_idx` in that array.
    """
    __slots__ = "op", "operand_idx"

    def __init__(self, op, operand_idx):
        # type: (Op, int) -> None
        super().__init__()
        self.op = op
        # `descriptor_array` is derived from `self.op`, so the bounds check
        # can only run once `op` has been stored
        if not 0 <= operand_idx < len(self.descriptor_array):
            raise ValueError("invalid operand_idx")
        self.operand_idx = operand_idx

    @abstractmethod
    def __repr__(self):
        # type: () -> str
        ...

    @property
    @abstractmethod
    def descriptor_array(self):
        # type: () -> tuple[OperandDesc, ...]
        """The operand descriptors `operand_idx` indexes into."""
        ...

    @cached_property
    def defining_descriptor(self):
        # type: () -> OperandDesc
        """The descriptor of the referenced operand slot."""
        descriptors = self.descriptor_array
        return descriptors[self.operand_idx]

    @cached_property
    def ty(self):
        # type: () -> Ty
        """Type of the referenced operand, as declared by its descriptor."""
        descriptor = self.defining_descriptor
        return descriptor.ty

    @cached_property
    def ty_before_spread(self):
        # type: () -> Ty
        """The descriptor's `ty_before_spread`."""
        descriptor = self.defining_descriptor
        return descriptor.ty_before_spread

    @property
    def base_ty(self):
        # type: () -> BaseTy
        """Base type of `ty_before_spread`."""
        unspread_ty = self.ty_before_spread
        return unspread_ty.base_ty

    @property
    def reg_offset_in_unspread(self):
        # type: () -> int
        """ the number of reg-sized slots in the unspread Loc before self's Loc

        e.g. if the unspread Loc containing self is:
        `Loc(kind=LocKind.GPR, start=8, reg_len=4)`
        and self's Loc is `Loc(kind=LocKind.GPR, start=10, reg_len=1)`
        then reg_offset_in_unspread == 2 == 10 - 8
        """
        descriptor = self.defining_descriptor
        return descriptor.reg_offset_in_unspread

    @property
    def unspread_start_idx(self):
        # type: () -> int
        """Operand index of the first slot of the (possibly spread) operand
        this slot belongs to; equal to `operand_idx` when the descriptor's
        `spread_index` is None."""
        spread_index = self.defining_descriptor.spread_index
        if spread_index is None:
            return self.operand_idx
        return self.operand_idx - spread_index

    @property
    def unspread_start(self):
        # type: () -> Self
        """Same-class reference to the slot at `unspread_start_idx`."""
        start_idx = self.unspread_start_idx
        return type(self)(op=self.op, operand_idx=start_idx)
1763
1764
@plain_data(frozen=True, unsafe_hash=True, repr=False)
@final
class SSAVal(SSAValOrUse):
    """A value defined by an `Op`: one of its output operand slots."""
    __slots__ = ()

    @cached_property
    def descriptor_array(self):
        # type: () -> tuple[OperandDesc, ...]
        """The output descriptors of the defining op."""
        properties = self.op.properties
        return properties.outputs

    def __repr__(self):
        # type: () -> str
        op_name = self.op.name
        return f"<{op_name}.outputs[{self.operand_idx}]: {self.ty}>"

    @cached_property
    def def_loc_set_before_spread(self):
        # type: () -> LocSet
        """The descriptor's `loc_set_before_spread`, for the definition."""
        descriptor = self.defining_descriptor
        return descriptor.loc_set_before_spread

    @cached_property
    def tied_input(self):
        # type: () -> None | SSAUse
        """The input slot this output is tied to, or None if untied."""
        tied_idx = self.defining_descriptor.tied_input_index
        if tied_idx is None:
            return None
        return SSAUse(op=self.op, operand_idx=tied_idx)

    @property
    def write_stage(self):
        # type: () -> OpStage
        """The stage at which the op writes this output."""
        descriptor = self.defining_descriptor
        return descriptor.write_stage
1796
1797
@plain_data(frozen=True, unsafe_hash=True, repr=False)
@final
class SSAUse(SSAValOrUse):
    """A use of an `SSAVal`: one input operand slot of an `Op`."""
    __slots__ = ()

    @cached_property
    def descriptor_array(self):
        # type: () -> tuple[OperandDesc, ...]
        """The input descriptors of the using op."""
        properties = self.op.properties
        return properties.inputs

    @cached_property
    def use_loc_set_before_spread(self):
        # type: () -> LocSet
        """The descriptor's `loc_set_before_spread`, for this use."""
        descriptor = self.defining_descriptor
        return descriptor.loc_set_before_spread

    def __repr__(self):
        # type: () -> str
        op_name = self.op.name
        return f"<{op_name}.input_uses[{self.operand_idx}]: {self.ty}>"

    @property
    def ssa_val(self):
        # type: () -> SSAVal
        """The value currently stored in `op.input_vals` for this slot;
        assigning replaces that entry."""
        input_vals = self.op.input_vals
        return input_vals[self.operand_idx]

    @ssa_val.setter
    def ssa_val(self, ssa_val):
        # type: (SSAVal) -> None
        input_vals = self.op.input_vals
        input_vals[self.operand_idx] = ssa_val
1826
1827
_T = TypeVar("_T")
_Desc = TypeVar("_Desc")


class OpInputSeq(Sequence[_T], Generic[_T, _Desc]):
    """Fixed-length, index-writable sequence of per-operand items of an `Op`.

    Each position has a descriptor (from `_get_descriptors`). Every write,
    including the initial population in `__init__`, is checked against the
    descriptor of its position by `_verify_write_with_desc`. The length
    always equals the number of descriptors; slices are readable but not
    writable.
    """

    @abstractmethod
    def _verify_write_with_desc(self, idx, item, desc):
        # type: (int, _T | Any, _Desc) -> None
        """Raise if `item` may not be stored at `idx`, whose descriptor is
        `desc`."""
        raise NotImplementedError

    @final
    def _verify_write(self, idx, item):
        # type: (int | Any, _T | Any) -> int
        """Validate writing `item` at `idx`; return the normalized index.

        Raises TypeError for slices and other non-int indexes, IndexError if
        `idx` is out of range, plus whatever `_verify_write_with_desc`
        raises.
        """
        if isinstance(idx, slice):
            raise TypeError(
                f"can't write to slice of {type(self).__name__}")
        if not isinstance(idx, int):
            raise TypeError(f"can't write with index {idx!r}")
        # indexing a range maps negative indexes to non-negative ones and
        # raises IndexError when out of range, just like indexing a list
        position = range(len(self.descriptors))[idx]
        self._verify_write_with_desc(
            position, item, self.descriptors[position])
        return position

    def _on_set(self, idx, new_item, old_item):
        # type: (int, _T, _T | None) -> None
        """Subclass hook; does nothing here.

        NOTE(review): neither `__init__` nor `__setitem__` in this class
        calls `_on_set`.
        """

    @abstractmethod
    def _get_descriptors(self):
        # type: () -> tuple[_Desc, ...]
        """Return one descriptor per position of the sequence."""
        raise NotImplementedError

    @cached_property
    @final
    def descriptors(self):
        # type: () -> tuple[_Desc, ...]
        """Per-position descriptors, computed once via `_get_descriptors`."""
        return self._get_descriptors()

    @property
    @final
    def op(self):
        # type: () -> Op
        """The `Op` passed to the constructor."""
        return self.__op

    def __init__(self, items, op):
        # type: (Iterable[_T], Op) -> None
        """Populate from `items`, validating each against its descriptor.

        Raises ValueError if `items` yields more or fewer entries than
        there are descriptors.
        """
        super().__init__()
        # stored first: `descriptors` may depend on `op`
        self.__op = op
        self.__items = []  # type: list[_T]
        for position, value in enumerate(items):
            if position >= len(self.descriptors):
                raise ValueError("too many items")
            self._verify_write(position, value)
            self.__items.append(value)
        if len(self.__items) < len(self.descriptors):
            raise ValueError("not enough items")

    @final
    def __iter__(self):
        # type: () -> Iterator[_T]
        for value in self.__items:
            yield value

    @overload
    def __getitem__(self, idx):
        # type: (int) -> _T
        ...

    @overload
    def __getitem__(self, idx):
        # type: (slice) -> list[_T]
        ...

    @final
    def __getitem__(self, idx):
        # type: (int | slice) -> _T | list[_T]
        """Item at an int index, or a new list for a slice."""
        return self.__items[idx]

    @final
    def __setitem__(self, idx, item):
        # type: (int, _T) -> None
        """Replace the item at int `idx` after validating it."""
        position = self._verify_write(idx, item)
        self.__items[position] = item

    @final
    def __len__(self):
        # type: () -> int
        return len(self.__items)

    def __repr__(self):
        # type: () -> str
        return f"{type(self).__name__}({self.__items}, op=...)"
1919
1920
@final
class OpInputVals(OpInputSeq[SSAVal, OperandDesc]):
    """The input `SSAVal`s of an `Op`, one per entry of
    `op.properties.inputs`, each required to have that entry's type."""

    def _get_descriptors(self):
        # type: () -> tuple[OperandDesc, ...]
        return self.op.properties.inputs

    def _verify_write_with_desc(self, idx, item, desc):
        # type: (int, SSAVal | Any, OperandDesc) -> None
        """Raise TypeError for non-`SSAVal` items and ValueError when the
        item's type differs from the input descriptor's type."""
        if not isinstance(item, SSAVal):
            raise TypeError("expected value of type SSAVal")
        if item.ty != desc.ty:
            raise ValueError(f"assigned item's type {item.ty!r} doesn't match "
                             f"corresponding input's type {desc.ty!r}")

    def _on_set(self, idx, new_item, old_item):
        # type: (int, SSAVal, SSAVal | None) -> None
        # NOTE(review): `SSAUses` is not defined in the part of this file
        # visible during review, and `OpInputSeq` never calls `_on_set`, so
        # this override is currently unreachable.
        SSAUses._on_op_input_set(self, idx, new_item, old_item)  # type: ignore

    def __init__(self, items, op):
        # type: (Iterable[SSAVal], Op) -> None
        """Raise ValueError if `op.input_vals` is already set."""
        # `Op` stores this sequence in its `input_vals` slot (it has no
        # `inputs` attribute, so checking for "inputs" could never trigger);
        # refuse to build a second one for the same Op, mirroring the
        # `immediates` check in `OpImmediates`.
        if hasattr(op, "input_vals"):
            raise ValueError("Op.input_vals already set")
        super().__init__(items, op)
1944
1945
@final
class OpImmediates(OpInputSeq[int, range]):
    """The immediate operands of an `Op`, each required to be an int
    inside the corresponding range of `op.properties.immediates`."""

    def __init__(self, items, op):
        # type: (Iterable[int], Op) -> None
        """Raise ValueError if `op.immediates` is already set."""
        # each Op owns exactly one OpImmediates, stored as `op.immediates`
        if hasattr(op, "immediates"):
            raise ValueError("Op.immediates already set")
        super().__init__(items, op)

    def _get_descriptors(self):
        # type: () -> tuple[range, ...]
        properties = self.op.properties
        return properties.immediates

    def _verify_write_with_desc(self, idx, item, desc):
        # type: (int, int | Any, range) -> None
        """Raise TypeError for non-ints and ValueError for ints outside
        `desc`."""
        if isinstance(item, int):
            if item in desc:
                return
            raise ValueError(f"immediate value {item!r} not in {desc!r}")
        raise TypeError("expected value of type int")
1964
1965
1966 @plain_data(frozen=True, eq=False, repr=False)
1967 @final
1968 class Op:
1969 __slots__ = ("fn", "properties", "input_vals", "input_uses", "immediates",
1970 "outputs", "name")
1971
1972 def __init__(self, fn, properties, input_vals, immediates, name=""):
1973 # type: (Fn, OpProperties, Iterable[SSAVal], Iterable[int], str) -> None
1974 self.fn = fn
1975 self.properties = properties
1976 self.input_vals = OpInputVals(input_vals, op=self)
1977 inputs_len = len(self.properties.inputs)
1978 self.input_uses = tuple(SSAUse(self, i) for i in range(inputs_len))
1979 self.immediates = OpImmediates(immediates, op=self)
1980 outputs_len = len(self.properties.outputs)
1981 self.outputs = tuple(SSAVal(self, i) for i in range(outputs_len))
1982 self.name = fn._add_op_with_unused_name(self, name) # type: ignore
1983
1984 @property
1985 def kind(self):
1986 # type: () -> OpKind
1987 return self.properties.kind
1988
1989 def __eq__(self, other):
1990 # type: (Op | Any) -> bool
1991 if isinstance(other, Op):
1992 return self is other
1993 return NotImplemented
1994
1995 def __hash__(self):
1996 # type: () -> int
1997 return object.__hash__(self)
1998
    def __repr__(self, wrap_width=63, indent=" "):
        # type: (int, str) -> str
        """Render the op as `name:` followed by an indented, word-wrapped
        `(<outputs>) <= Kind(<inputs>, <immediates>)` line.

        Outputs are shown as `<...outputs[i]: ty>`, inputs via their
        `repr`, and immediates in hex. Line breaks are only inserted at
        internal WRAP_POINT markers (between operands). A part moves to a
        new line, prefixed with `indent`, when appending it would make the
        current line (ignoring trailing whitespace) longer than
        `wrap_width`; a single part longer than that is not split.
        """
        WRAP_POINT = "\u200B"  # zero-width space
        items = [f"{self.name}:\n"]
        # "(" before the first output, ", " between outputs, ") <= " after
        # the last one; nothing at all when there are no outputs
        for i, out in enumerate(self.outputs):
            item = f"<...outputs[{i}]: {out.ty}>"
            if i == 0:
                item = "(" + WRAP_POINT + item
            if i != len(self.outputs) - 1:
                item += ", " + WRAP_POINT
            else:
                item += WRAP_POINT + ") <= "
            items.append(item)
        items.append(self.kind._name_)
        # the argument list (and its parentheses) is omitted entirely when
        # there are neither inputs nor immediates
        if len(self.input_vals) + len(self.immediates) != 0:
            items[-1] += "("
        items[-1] += WRAP_POINT
        for i, inp in enumerate(self.input_vals):
            item = repr(inp)
            # the closing ")" goes after the last input only when no
            # immediates follow
            if i != len(self.input_vals) - 1 or len(self.immediates) != 0:
                item += ", " + WRAP_POINT
            else:
                item += ") " + WRAP_POINT
            items.append(item)
        for i, imm in enumerate(self.immediates):
            item = hex(imm)
            if i != len(self.immediates) - 1:
                item += ", " + WRAP_POINT
            else:
                item += ") " + WRAP_POINT
            items.append(item)
        lines = []  # type: list[str]
        for i, line_in in enumerate("".join(items).splitlines()):
            # every line after the `name:` header is indented
            if i != 0:
                line_in = indent + line_in
            line_out = ""
            # greedy wrap: keep appending parts until the next one would
            # overflow, then flush the line and start an indented new one
            for part in line_in.split(WRAP_POINT):
                if line_out == "":
                    line_out = part
                    continue
                trial_line_out = line_out + part
                if len(trial_line_out.rstrip()) > wrap_width:
                    lines.append(line_out.rstrip())
                    line_out = indent + part
                else:
                    line_out = trial_line_out
            lines.append(line_out.rstrip())
        return "\n".join(lines)
2047
2048 def sim(self, state):
2049 # type: (BaseSimState) -> None
2050 for inp in self.input_vals:
2051 try:
2052 val = state[inp]
2053 except KeyError:
2054 raise ValueError(f"SSAVal {inp} not yet assigned when "
2055 f"running {self}")
2056 if len(val) != inp.ty.reg_len:
2057 raise ValueError(
2058 f"value of SSAVal {inp} has wrong number of elements: "
2059 f"expected {inp.ty.reg_len} found "
2060 f"{len(val)}: {val!r}")
2061 if isinstance(state, PreRASimState):
2062 for out in self.outputs:
2063 if out in state.ssa_vals:
2064 if self.kind is OpKind.FuncArgR3:
2065 continue
2066 raise ValueError(f"SSAVal {out} already assigned before "
2067 f"running {self}")
2068 self.kind.sim(self, state)
2069 for out in self.outputs:
2070 try:
2071 val = state[out]
2072 except KeyError:
2073 raise ValueError(f"running {self} failed to assign to {out}")
2074 if len(val) != out.ty.reg_len:
2075 raise ValueError(
2076 f"value of SSAVal {out} has wrong number of elements: "
2077 f"expected {out.ty.reg_len} found "
2078 f"{len(val)}: {val!r}")
2079
2080 def gen_asm(self, state):
2081 # type: (GenAsmState) -> None
2082 all_loc_kinds = tuple(LocKind)
2083 for inp in self.input_vals:
2084 state.loc(inp, expected_kinds=all_loc_kinds)
2085 for out in self.outputs:
2086 state.loc(out, expected_kinds=all_loc_kinds)
2087 self.kind.gen_asm(self, state)
2088
2089
@plain_data(frozen=True, repr=False)
class BaseSimState(metaclass=ABCMeta):
    """Base class for simulator state.

    Holds a sparse byte-addressed memory; subclasses implement how
    ``SSAVal`` values are read (``__getitem__``) and written
    (``__setitem__``).
    """
    __slots__ = "memory",

    def __init__(self, memory):
        # type: (dict[int, int]) -> None
        super().__init__()
        # byte address -> byte value; addresses never written read as 0 (see
        # load_byte). The dict is used as-is, not copied.
        self.memory = memory  # type: dict[int, int]

    def load_byte(self, addr):
        # type: (int) -> int
        """Return the byte at ``addr`` (wrapped to the GPR width); unset
        bytes read as 0."""
        addr &= GPR_VALUE_MASK
        return self.memory.get(addr, 0) & 0xFF

    def store_byte(self, addr, value):
        # type: (int, int) -> None
        """Store the low 8 bits of ``value`` at ``addr`` (wrapped to the GPR
        width)."""
        addr &= GPR_VALUE_MASK
        value &= 0xFF
        self.memory[addr] = value

    @staticmethod
    def _check_mem_access(addr, size_in_bytes):
        # type: (int, int) -> None
        """Raise ``ValueError`` unless ``size_in_bytes`` is positive and
        ``addr`` is a multiple of it."""
        if size_in_bytes <= 0:
            raise ValueError(f"invalid access size: {size_in_bytes}")
        if addr % size_in_bytes != 0:
            raise ValueError(f"address not aligned: {hex(addr)} "
                             f"required alignment: {size_in_bytes}")

    def load(self, addr, size_in_bytes=GPR_SIZE_IN_BYTES, signed=False):
        # type: (int, int, bool) -> int
        """Read ``size_in_bytes`` bytes at ``addr`` as a little-endian
        integer, sign-extended when ``signed`` is true.

        Raises:
            ValueError: if ``size_in_bytes`` is not positive or ``addr`` is
                not a multiple of it.
        """
        self._check_mem_access(addr, size_in_bytes)
        retval = 0
        for i in range(size_in_bytes):
            retval |= self.load_byte(addr + i) << i * BITS_IN_BYTE
        bit_width = size_in_bytes * BITS_IN_BYTE
        # a set top bit means the two's-complement value is negative
        if signed and retval >> (bit_width - 1) != 0:
            retval -= 1 << bit_width
        return retval

    def store(self, addr, value, size_in_bytes=GPR_SIZE_IN_BYTES):
        # type: (int, int, int) -> None
        """Write the low ``size_in_bytes`` bytes of ``value`` at ``addr`` in
        little-endian order.

        Raises:
            ValueError: if ``size_in_bytes`` is not positive or ``addr`` is
                not a multiple of it.
        """
        self._check_mem_access(addr, size_in_bytes)
        for i in range(size_in_bytes):
            self.store_byte(addr + i, (value >> i * BITS_IN_BYTE) & 0xFF)

    def _memory__repr(self):
        # type: () -> str
        """Format ``memory`` for ``__repr__``.

        A run of ``GPR_SIZE_IN_BYTES`` consecutive bytes starting at an
        aligned address is shown as a single little-endian value; all other
        bytes are shown individually.
        """
        if len(self.memory) == 0:
            return "{}"
        # descending, so the lowest remaining address is at the end of the
        # list where it can be popped/sliced off cheaply
        keys = sorted(self.memory.keys(), reverse=True)
        CHUNK_SIZE = GPR_SIZE_IN_BYTES
        items = []  # type: list[str]
        while len(keys) != 0:
            addr = keys[-1]
            if (len(keys) >= CHUNK_SIZE
                    and addr % CHUNK_SIZE == 0
                    and keys[-CHUNK_SIZE:]
                    == list(reversed(range(addr, addr + CHUNK_SIZE)))):
                # build the value from exactly the entries being consumed
                # (not via load(), which wraps addresses), so the shown value
                # always belongs to the shown keys
                value = 0
                for i in range(CHUNK_SIZE):
                    byte = self.memory[addr + i] & 0xFF
                    value |= byte << i * BITS_IN_BYTE
                items.append(f"0x{addr:05x}: <0x{value:0{CHUNK_SIZE * 2}x}>")
                keys[-CHUNK_SIZE:] = ()
            else:
                items.append(f"0x{addr:05x}: 0x{self.memory[keys.pop()]:02x}")
        if len(items) == 1:
            return f"{{{items[0]}}}"
        items_str = ",\n".join(items)
        return f"{{\n{items_str}}}"

    def __repr__(self):
        # type: () -> str
        """Like a plain-data repr, except that a field with a
        ``_<field>__repr`` method is formatted by it and an unset field is
        shown as ``<not set>``."""
        field_vals = []  # type: list[str]
        for name in fields(self):
            try:
                value = getattr(self, name)
            except AttributeError:
                field_vals.append(f"{name}=<not set>")
                continue
            repr_fn = getattr(self, f"_{name}__repr", None)
            if callable(repr_fn):
                field_vals.append(f"{name}={repr_fn()}")
            else:
                field_vals.append(f"{name}={value!r}")
        field_vals_str = ", ".join(field_vals)
        return f"{self.__class__.__name__}({field_vals_str})"

    @abstractmethod
    def __getitem__(self, ssa_val):
        # type: (SSAVal) -> tuple[int, ...]
        """Return the current value of ``ssa_val``."""
        ...

    @abstractmethod
    def __setitem__(self, ssa_val, value):
        # type: (SSAVal, tuple[int, ...]) -> None
        """Set the value of ``ssa_val``."""
        ...
2179
2180
@plain_data(frozen=True, repr=False)
@final
class PreRASimState(BaseSimState):
    """Simulator state used before register allocation.

    Values are stored per ``SSAVal`` in ``ssa_vals``, each a tuple of
    ``ssa_val.ty.reg_len`` ints (checked on assignment via ``__setitem__``).
    """
    __slots__ = "ssa_vals",

    def __init__(self, ssa_vals, memory):
        # type: (dict[SSAVal, tuple[int, ...]], dict[int, int]) -> None
        super().__init__(memory)
        # used as-is, not copied; entries are not validated here
        self.ssa_vals = ssa_vals  # type: dict[SSAVal, tuple[int, ...]]

    def _ssa_vals__repr(self):
        # type: () -> str
        """Format ``ssa_vals`` for ``__repr__``.

        Each value is shown in hex as a tuple; values with more than
        ``CHUNK_SIZE`` elements are laid out ``CHUNK_SIZE`` elements per
        line.
        """
        if len(self.ssa_vals) == 0:
            return "{}"
        items = []  # type: list[str]
        CHUNK_SIZE = 4
        for k, v in self.ssa_vals.items():
            if len(v) == 0:
                # nothing to lay out (element_strs[0] below would not exist)
                items.append(f"{k!r}: ()")
                continue
            element_strs = []  # type: list[str]
            for i, el in enumerate(v):
                if i % CHUNK_SIZE != 0:
                    element_strs.append(" " + hex(el))
                else:
                    # first element of a chunk starts a new line
                    element_strs.append("\n " + hex(el))
            if len(element_strs) <= CHUNK_SIZE:
                # a single chunk stays on the key's line
                element_strs[0] = element_strs[0].lstrip()
            if len(element_strs) == 1:
                # trailing comma so a 1-element value reads as a tuple
                element_strs.append("")
            v_str = ",".join(element_strs)
            items.append(f"{k!r}: ({v_str})")
        if len(items) == 1 and "\n" not in items[0]:
            return f"{{{items[0]}}}"
        items_str = ",\n".join(items)
        return f"{{\n{items_str},\n}}"

    def __getitem__(self, ssa_val):
        # type: (SSAVal) -> tuple[int, ...]
        """Return the stored value of ``ssa_val``; raises ``KeyError`` if it
        has none."""
        return self.ssa_vals[ssa_val]

    def __setitem__(self, ssa_val, value):
        # type: (SSAVal, tuple[int, ...]) -> None
        """Store ``value`` for ``ssa_val``.

        Raises:
            ValueError: if ``len(value)`` differs from ``ssa_val.ty.reg_len``.
        """
        if len(value) != ssa_val.ty.reg_len:
            raise ValueError("value has wrong len")
        self.ssa_vals[ssa_val] = value
2224
2225
@plain_data(frozen=True, repr=False)
@final
class PostRASimState(BaseSimState):
    """Simulator state used after register allocation.

    Each ``SSAVal`` is mapped to a ``Loc``; values are stored per
    ``reg_len=1`` sub-``Loc`` in ``loc_values``.
    """
    __slots__ = "ssa_val_to_loc_map", "loc_values"

    def __init__(self, ssa_val_to_loc_map, memory, loc_values):
        # type: (dict[SSAVal, Loc], dict[int, int], dict[Loc, int]) -> None
        """Raises ``ValueError`` if an ``SSAVal``'s ``ty`` differs from its
        ``Loc``'s ``ty``, or if ``loc_values`` contains a ``Loc`` whose
        ``reg_len`` is not 1."""
        super().__init__(memory)
        self.ssa_val_to_loc_map = FMap(ssa_val_to_loc_map)
        for ssa_val, loc in self.ssa_val_to_loc_map.items():
            if ssa_val.ty != loc.ty:
                raise ValueError(
                    f"type mismatch for SSAVal and Loc: {ssa_val} {loc}")
        # used as-is, not copied
        self.loc_values = loc_values
        for loc in self.loc_values.keys():
            if loc.reg_len != 1:
                raise ValueError(
                    "loc_values must only contain Locs with reg_len=1, all "
                    "larger Locs will be split into reg_len=1 sub-Locs")

    def _loc_values__repr(self):
        # type: () -> str
        """Format ``loc_values`` for ``__repr__``, ordered by
        ``(kind, start)``, with values in hex."""
        if len(self.loc_values) == 0:
            # same empty form as the other _<field>__repr helpers
            return "{}"
        locs = sorted(self.loc_values.keys(), key=lambda v: (v.kind, v.start))
        items = []  # type: list[str]
        for loc in locs:
            items.append(f"{loc}: 0x{self.loc_values[loc]:x}")
        items_str = ",\n".join(items)
        return f"{{\n{items_str},\n}}"

    def __getitem__(self, ssa_val):
        # type: (SSAVal) -> tuple[int, ...]
        """Read ``ssa_val``'s value from its ``Loc``, one ``reg_len=1``
        sub-``Loc`` at a time; sub-``Loc``s never written read as 0."""
        loc = self.ssa_val_to_loc_map[ssa_val]
        subloc_ty = Ty(base_ty=loc.ty.base_ty, reg_len=1)
        retval = []  # type: list[int]
        for i in range(loc.reg_len):
            subloc = loc.get_subloc_at_offset(subloc_ty=subloc_ty, offset=i)
            retval.append(self.loc_values.get(subloc, 0))
        return tuple(retval)

    def __setitem__(self, ssa_val, value):
        # type: (SSAVal, tuple[int, ...]) -> None
        """Write ``value`` into ``ssa_val``'s ``Loc``, one element per
        ``reg_len=1`` sub-``Loc``.

        Raises:
            ValueError: if ``len(value)`` differs from ``ssa_val.ty.reg_len``.
        """
        if len(value) != ssa_val.ty.reg_len:
            raise ValueError("value has wrong len")
        loc = self.ssa_val_to_loc_map[ssa_val]
        subloc_ty = Ty(base_ty=loc.ty.base_ty, reg_len=1)
        for i in range(loc.reg_len):
            subloc = loc.get_subloc_at_offset(subloc_ty=subloc_ty, offset=i)
            self.loc_values[subloc] = value[i]
2274
2275
@plain_data(frozen=True)
class GenAsmState:
    """State used while generating assembly text.

    ``allocated_locs`` maps each ``SSAVal`` to its allocated ``Loc``;
    ``output`` receives emitted lines, either appended to a list of strings
    or written (newline-terminated) to a ``StringIO``.
    """
    __slots__ = "allocated_locs", "output"

    def __init__(self, allocated_locs, output=None):
        # type: (Mapping[SSAVal, Loc], StringIO | list[str] | None) -> None
        """Raises ``ValueError`` if any ``SSAVal``'s ``ty`` differs from the
        ``ty`` of its allocated ``Loc``; ``output`` defaults to a new list."""
        super().__init__()
        self.allocated_locs = FMap(allocated_locs)
        for val, allocated in self.allocated_locs.items():
            if val.ty != allocated.ty:
                raise ValueError(
                    f"Ty mismatch: ssa_val.ty:{val.ty} != loc.ty:{allocated.ty}")
        self.output = [] if output is None else output

    # alias used only in the type comments below
    __SSA_VAL_OR_LOCS = Union[SSAVal, Loc, Sequence["SSAVal | Loc"]]

    def loc(self, ssa_val_or_locs, expected_kinds):
        # type: (__SSA_VAL_OR_LOCS, LocKind | tuple[LocKind, ...]) -> Loc
        """Resolve an ``SSAVal``/``Loc`` (or a sequence of them) to a single
        ``Loc``, concatenating multiple entries with ``Loc.try_concat``.

        Raises:
            ValueError: if the sequence is empty, the concatenation fails,
                or the resulting kind is not among ``expected_kinds``.
        """
        if isinstance(ssa_val_or_locs, (SSAVal, Loc)):
            ssa_val_or_locs = [ssa_val_or_locs]
        locs = [
            self.allocated_locs[entry] if isinstance(entry, SSAVal) else entry
            for entry in ssa_val_or_locs
        ]  # type: list[Loc]
        if not locs:
            raise ValueError("invalid Loc sequence: must not be empty")
        first, *rest = locs
        retval = first.try_concat(*rest)
        if retval is None:
            raise ValueError("invalid Loc sequence: try_concat failed")
        if isinstance(expected_kinds, LocKind):
            kinds = (expected_kinds,)
        else:
            kinds = expected_kinds
        if retval.kind in kinds:
            return retval
        # report a lone expected kind without the surrounding tuple
        shown = kinds[0] if len(kinds) == 1 else kinds
        raise ValueError(f"LocKind mismatch: {ssa_val_or_locs}: found "
                         f"{retval.kind} expected {shown}")

    def gpr(self, ssa_val_or_locs, is_vec):
        # type: (__SSA_VAL_OR_LOCS, bool) -> str
        """Return the GPR operand text: the start register number, prefixed
        with ``*`` when ``is_vec`` is true."""
        start = self.loc(ssa_val_or_locs, LocKind.GPR).start
        prefix = "*" if is_vec else ""
        return prefix + str(start)

    def sgpr(self, ssa_val_or_locs):
        # type: (__SSA_VAL_OR_LOCS) -> str
        """GPR operand text without the ``*`` prefix."""
        return self.gpr(ssa_val_or_locs, is_vec=False)

    def vgpr(self, ssa_val_or_locs):
        # type: (__SSA_VAL_OR_LOCS) -> str
        """GPR operand text with the ``*`` prefix."""
        return self.gpr(ssa_val_or_locs, is_vec=True)

    def stack(self, ssa_val_or_locs):
        # type: (__SSA_VAL_OR_LOCS) -> str
        """Return ``"<start>(1)"`` for a ``StackI64`` location."""
        start = self.loc(ssa_val_or_locs, LocKind.StackI64).start
        return f"{start}(1)"

    def writeln(self, *line_segments):
        # type: (*str) -> None
        """Emit one output line made of ``line_segments`` joined by single
        spaces."""
        line = " ".join(line_segments)
        if not isinstance(self.output, list):
            self.output.write(line + "\n")
            return
        self.output.append(line)