working on adding signed multiplication -- needed for toom-cook
[bigint-presentation-code.git] / src / bigint_presentation_code / compiler_ir.py
1 import enum
2 from abc import ABCMeta, abstractmethod
3 from enum import Enum, unique
4 from functools import lru_cache, total_ordering
5 from io import StringIO
6 from typing import (AbstractSet, Any, Callable, Generic, Iterable, Iterator,
7 Mapping, Sequence, TypeVar, Union, overload)
8 from weakref import WeakValueDictionary as _WeakVDict
9
10 from cached_property import cached_property
11 from nmutil.plain_data import fields, plain_data
12
13 from bigint_presentation_code.type_util import (Literal, Self, assert_never,
14 final)
15 from bigint_presentation_code.util import (BitSet, FBitSet, FMap, InternedMeta,
16 OFSet, OSet)
17
18
# Width of one general-purpose register (GPR), in bytes.
GPR_SIZE_IN_BYTES = 8
# Number of bits in a byte.
BITS_IN_BYTE = 8
# Width of one GPR, in bits.
GPR_SIZE_IN_BITS = BITS_IN_BYTE * GPR_SIZE_IN_BYTES
# All-ones mask that keeps exactly the low `GPR_SIZE_IN_BITS` bits.
GPR_VALUE_MASK = ~(-1 << GPR_SIZE_IN_BITS)
23
24
@final
class Fn:
    """A list of `Op`s in program order, plus the bookkeeping used to give
    every `Op` in the function a unique name.
    """

    def __init__(self):
        # ops in program order
        self.ops = []  # type: list[Op]
        # name -> op holding that name; a WeakValueDictionary, so an entry
        # only lasts as long as its op does
        self.__op_names = _WeakVDict()  # type: _WeakVDict[str, Op]
        # next number to append when a requested name can't be used as-is
        self.__next_name_suffix = 2

    def _add_op_with_unused_name(self, op, name=""):
        # type: (Op, str) -> str
        """Reserve a unique name for `op` and return it.

        `name` is used unchanged when it is non-empty and not yet taken;
        otherwise increasing numeric suffixes are appended to `name` until
        an unused name is found.

        Raises `ValueError` if `op` belongs to a different `Fn` or already
        has a `name` attribute.
        """
        if op.fn is not self:
            raise ValueError("can't add Op to wrong Fn")
        if hasattr(op, "name"):
            raise ValueError("Op already named")
        candidate = name
        while candidate == "" or candidate in self.__op_names:
            candidate = name + str(self.__next_name_suffix)
            self.__next_name_suffix += 1
        self.__op_names[candidate] = op
        return candidate

    def __repr__(self):
        # type: () -> str
        return "<Fn>"

    def append_op(self, op):
        # type: (Op) -> None
        """Append `op` to `self.ops`; `op.fn` must be this `Fn`."""
        if op.fn is not self:
            raise ValueError("can't add Op to wrong Fn")
        self.ops.append(op)

    def append_new_op(self, kind, input_vals=(), immediates=(), name="",
                      maxvl=1):
        # type: (OpKind, Iterable[SSAVal], Iterable[int], str, int) -> Op
        """Create an `Op` of the given `kind` (instantiated for `maxvl`),
        append it to this function and return it.
        """
        properties = kind.instantiate(maxvl=maxvl)
        new_op = Op(fn=self, properties=properties,
                    input_vals=input_vals, immediates=immediates, name=name)
        self.append_op(new_op)
        return new_op

    def sim(self, state):
        # type: (BaseSimState) -> None
        """Call `sim(state)` on every op, in program order."""
        for op in self.ops:
            op.sim(state)

    def gen_asm(self, state):
        # type: (GenAsmState) -> None
        """Call `gen_asm(state)` on every op, in program order."""
        for op in self.ops:
            op.gen_asm(state)

    def pre_ra_insert_copies(self):
        # type: () -> None
        """Rebuild `self.ops` with copy ops around every original op.

        For each original op, in order:

        * every `I64` input is replaced by the output of a newly appended
          `CopyToReg` (or, when `reg_len != 1`, a `SetVLI` followed by a
          `VecCopyToReg`);
        * `CA` / `VL_MAXVL` inputs are passed through without a copy,
          except that an input produced by a `SetVLI` op is replaced by the
          output of a new `SetVLI` with the same immediates;
        * the op itself is re-appended;
        * every `I64` output is followed by a `CopyFromReg` (or `SetVLI` +
          `VecCopyFromReg`), and later ops read the copy's output instead.

        NOTE(review): "pre_ra" presumably means "before register
        allocation" -- confirm.
        """
        original_ops = list(self.ops)
        # original output -> value that later ops must read instead
        replacements = {}  # type: dict[SSAVal, SSAVal]
        # output of a SetVLI op -> that SetVLI op
        setvli_ops = {}  # type: dict[SSAVal, Op]
        self.ops.clear()
        for op in original_ops:
            for idx in range(len(op.input_vals)):
                src = replacements[op.input_vals[idx]]
                base_ty = src.ty.base_ty
                if base_ty is BaseTy.I64:
                    reg_len = src.ty.reg_len
                    if reg_len == 1:
                        copy_op = self.append_new_op(
                            OpKind.CopyToReg, input_vals=[src],
                            name=f"{op.name}.inp{idx}.copy")
                    else:
                        # vector copies need VL set up right before them
                        setvl = self.append_new_op(
                            OpKind.SetVLI, immediates=[reg_len],
                            name=f"{op.name}.inp{idx}.setvl")
                        copy_op = self.append_new_op(
                            OpKind.VecCopyToReg,
                            input_vals=[src, setvl.outputs[0]],
                            maxvl=reg_len, name=f"{op.name}.inp{idx}.copy")
                    op.input_vals[idx] = copy_op.outputs[0]
                elif base_ty is BaseTy.CA or base_ty is BaseTy.VL_MAXVL:
                    # all copies would be no-ops, so no copy is inserted,
                    # but a SetVLI has to be rematerialized right before
                    # the op that uses its VL
                    if src in setvli_ops:
                        setvl = self.append_new_op(
                            OpKind.SetVLI,
                            immediates=setvli_ops[src].immediates,
                            name=f"{op.name}.inp{idx}.setvl")
                        src = setvl.outputs[0]
                    op.input_vals[idx] = src
                else:
                    assert_never(base_ty)
            self.ops.append(op)
            for idx, out in enumerate(op.outputs):
                if op.kind is OpKind.SetVLI:
                    setvli_ops[out] = op
                base_ty = out.ty.base_ty
                if base_ty is BaseTy.I64:
                    reg_len = out.ty.reg_len
                    if reg_len == 1:
                        copy_op = self.append_new_op(
                            OpKind.CopyFromReg, input_vals=[out],
                            name=f"{op.name}.out{idx}.copy")
                    else:
                        setvl = self.append_new_op(
                            OpKind.SetVLI, immediates=[reg_len],
                            name=f"{op.name}.out{idx}.setvl")
                        copy_op = self.append_new_op(
                            OpKind.VecCopyFromReg,
                            input_vals=[out, setvl.outputs[0]],
                            maxvl=reg_len, name=f"{op.name}.out{idx}.copy")
                    replacements[out] = copy_op.outputs[0]
                elif base_ty is BaseTy.CA or base_ty is BaseTy.VL_MAXVL:
                    # all copies would be no-ops, so we don't need to copy
                    replacements[out] = out
                else:
                    assert_never(base_ty)
137
138
@final
@unique
@total_ordering
class OpStage(Enum):
    """The two stages of one Op's execution, ordered `Early < Late`."""

    value: Literal[0, 1]  # type: ignore

    def __new__(cls, value):
        # type: (int) -> OpStage
        # only 0 and 1 are accepted as member values
        value = int(value)
        if value != 0 and value != 1:
            raise ValueError("invalid value")
        member = object.__new__(cls)
        member._value_ = value
        return member

    Early = 0
    """ early stage of Op execution, where all input reads occur.
    all output writes with `write_stage == Early` occur here too, and therefore
    conflict with input reads, telling the compiler that it that can't share
    that output's register with any inputs that the output isn't tied to.

    All outputs, even unused outputs, can't share registers with any other
    outputs, independent of `write_stage` settings.
    """
    Late = 1
    """ late stage of Op execution, where all output writes with
    `write_stage == Late` occur, and therefore don't conflict with input reads,
    telling the compiler that any inputs can safely use the same register as
    those outputs.

    All outputs, even unused outputs, can't share registers with any other
    outputs, independent of `write_stage` settings.
    """

    def __repr__(self):
        # type: () -> str
        return "OpStage." + self._name_

    def __lt__(self, other):
        # type: (OpStage | object) -> bool
        # ordering is only defined between two stages; `total_ordering`
        # derives the remaining comparison operators from this one
        if not isinstance(other, OpStage):
            return NotImplemented
        return self.value < other.value


assert OpStage.Early < OpStage.Late, "early must be less than late"
185
186
@plain_data(frozen=True, unsafe_hash=True, repr=False)
@final
@total_ordering
class ProgramPoint(metaclass=InternedMeta):
    """One stage (`OpStage`) of the op at index `op_index`.

    Points are ordered by `op_index` first and by `stage` second.
    """
    __slots__ = ("op_index", "stage")

    def __init__(self, op_index, stage):
        # type: (int, OpStage) -> None
        self.op_index = op_index
        self.stage = stage

    @property
    def int_value(self):
        # type: () -> int
        """ an integer representation of `self` such that it keeps ordering and
        successor/predecessor relations.
        """
        # two program points per op: even = Early, odd = Late
        return 2 * self.op_index + self.stage.value

    @staticmethod
    def from_int_value(int_value):
        # type: (int) -> ProgramPoint
        """Inverse of `int_value`."""
        index = int_value // 2
        stage_value = int_value % 2
        return ProgramPoint(op_index=index, stage=OpStage(stage_value))

    def next(self, steps=1):
        # type: (int) -> ProgramPoint
        """Return the program point `steps` points after `self`."""
        return ProgramPoint.from_int_value(self.int_value + steps)

    def prev(self, steps=1):
        # type: (int) -> ProgramPoint
        """Return the program point `steps` points before `self`."""
        return self.next(steps=-steps)

    def __lt__(self, other):
        # type: (ProgramPoint | Any) -> bool
        if not isinstance(other, ProgramPoint):
            return NotImplemented
        if self.op_index == other.op_index:
            return self.stage < other.stage
        return self.op_index < other.op_index

    def __repr__(self):
        # type: () -> str
        return f"<ops[{self.op_index}]:{self.stage._name_}>"
231
232
@plain_data(frozen=True, unsafe_hash=True, repr=False)
@final
class ProgramRange(Sequence[ProgramPoint], metaclass=InternedMeta):
    """The sequence of `ProgramPoint`s from `start` (inclusive) up to
    `stop` (exclusive).
    """
    __slots__ = ("start", "stop")

    def __init__(self, start, stop):
        # type: (ProgramPoint, ProgramPoint) -> None
        self.start = start
        self.stop = stop

    @cached_property
    def int_value_range(self):
        # type: () -> range
        """`range` over the `int_value`s of the points in `self`."""
        first = self.start.int_value
        end = self.stop.int_value
        return range(first, end)

    @staticmethod
    def from_int_value_range(int_value_range):
        # type: (range) -> ProgramRange
        """Build a `ProgramRange` from a `range` of `int_value`s.

        Raises `ValueError` unless the range's step is 1.
        """
        if int_value_range.step == 1:
            first = ProgramPoint.from_int_value(int_value_range.start)
            end = ProgramPoint.from_int_value(int_value_range.stop)
            return ProgramRange(start=first, stop=end)
        raise ValueError("int_value_range must have step == 1")

    @overload
    def __getitem__(self, __idx):
        # type: (int) -> ProgramPoint
        ...

    @overload
    def __getitem__(self, __idx):
        # type: (slice) -> ProgramRange
        ...

    def __getitem__(self, __idx):
        # type: (int | slice) -> ProgramPoint | ProgramRange
        # index/slice the underlying integer range, then convert back
        int_values = range(self.start.int_value, self.stop.int_value)
        picked = int_values[__idx]
        if not isinstance(picked, int):
            return ProgramRange.from_int_value_range(picked)
        return ProgramPoint.from_int_value(picked)

    def __len__(self):
        # type: () -> int
        return len(self.int_value_range)

    def __iter__(self):
        # type: () -> Iterator[ProgramPoint]
        return map(ProgramPoint.from_int_value, self.int_value_range)

    def __repr__(self):
        # type: () -> str
        # reuse the endpoints' reprs without their surrounding <...>
        start_text = repr(self.start).lstrip("<").rstrip(">")
        stop_text = repr(self.stop).lstrip("<").rstrip(">")
        return f"<range:{start_text}..{stop_text}>"
287
288
@plain_data(frozen=True, eq=False, repr=False)
@final
class FnAnalysis:
    """Use and liveness information computed once for a `Fn`.

    Attributes, all filled in by `__init__`:

    * `fn`: the analyzed function.
    * `op_indexes`: each op -> its index in `fn.ops`.
    * `all_program_points`: from the early stage of op 0 up to (excluding)
      the early stage of index `len(fn.ops)`.
    * `uses`: each SSA value -> the uses of it found in `op.input_uses`.
    * `def_program_ranges`: each SSA value -> the range where it's written.
    * `use_program_points`: each use -> the point where it's read.
    * `live_ranges`: each SSA value -> from the start of its def range to
      the later of the end of its def range and the point after its uses.
    * `live_at`: each program point -> the SSA values live at that point.

    Two `FnAnalysis` objects compare and hash by their `fn`.
    """
    __slots__ = ("fn", "uses", "op_indexes", "live_ranges", "live_at",
                 "def_program_ranges", "use_program_points",
                 "all_program_points")

    def __init__(self, fn):
        # type: (Fn) -> None
        self.fn = fn
        self.op_indexes = FMap((op, idx) for idx, op in enumerate(fn.ops))
        self.all_program_points = ProgramRange(
            start=ProgramPoint(op_index=0, stage=OpStage.Early),
            stop=ProgramPoint(op_index=len(fn.ops), stage=OpStage.Early))
        def_ranges = {}  # type: dict[SSAVal, ProgramRange]
        use_points = {}  # type: dict[SSAUse, ProgramPoint]
        uses = {}  # type: dict[SSAVal, OSet[SSAUse]]
        # exclusive end of each value's live range found so far
        stops = {}  # type: dict[SSAVal, ProgramPoint]
        for op in fn.ops:
            for use in op.input_uses:
                used_val = use.ssa_val
                uses[used_val].add(use)
                point = self.__get_use_program_point(use)
                use_points[use] = point
                # the value must stay live through the point where it's read
                stops[used_val] = max(stops[used_val], point.next())
            for out in op.outputs:
                uses[out] = OSet()
                def_range = self.__get_def_program_range(out)
                def_ranges[out] = def_range
                stops[out] = def_range.stop
        self.uses = FMap((k, OFSet(v)) for k, v in uses.items())
        self.def_program_ranges = FMap(def_ranges)
        self.use_program_points = FMap(use_points)
        live_ranges = {}  # type: dict[SSAVal, ProgramRange]
        live_at = {i: OSet[SSAVal]() for i in self.all_program_points}
        for ssa_val in uses:
            live_range = ProgramRange(
                start=self.def_program_ranges[ssa_val].start,
                stop=stops[ssa_val])
            live_ranges[ssa_val] = live_range
            for point in live_range:
                live_at[point].add(ssa_val)
        self.live_ranges = FMap(live_ranges)
        self.live_at = FMap((k, OFSet(v)) for k, v in live_at.items())

    def __get_def_program_range(self, ssa_val):
        # type: (SSAVal) -> ProgramRange
        """Range of program points during which `ssa_val` is being written:
        from its `write_stage` in the defining op through that op's late
        stage.
        """
        op_index = self.op_indexes[ssa_val.op]
        start = ProgramPoint(
            op_index=op_index,
            stage=ssa_val.defining_descriptor.write_stage)
        # always include late stage of ssa_val.op, to ensure outputs always
        # overlap all other outputs.
        # stop is exclusive, so we need the next program point.
        late = ProgramPoint(op_index=start.op_index, stage=OpStage.Late)
        return ProgramRange(start=start, stop=late.next())

    def __get_use_program_point(self, ssa_use):
        # type: (SSAUse) -> ProgramPoint
        """Program point where `ssa_use` reads its value: the early stage
        of the using op.
        """
        assert ssa_use.defining_descriptor.write_stage is OpStage.Early, \
            "assumed here, ensured by GenericOpProperties.__init__"
        return ProgramPoint(
            op_index=self.op_indexes[ssa_use.op], stage=OpStage.Early)

    def __eq__(self, other):
        # type: (FnAnalysis | Any) -> bool
        if not isinstance(other, FnAnalysis):
            return NotImplemented
        return self.fn == other.fn

    def __hash__(self):
        # type: () -> int
        return hash(self.fn)

    def __repr__(self):
        # type: () -> str
        return "<FnAnalysis>"
364
365
@unique
@final
class BaseTy(Enum):
    """Base (element) type of a value, before any vector length is applied."""
    I64 = enum.auto()
    CA = enum.auto()
    VL_MAXVL = enum.auto()

    @cached_property
    def only_scalar(self):
        # type: () -> bool
        """True if this base type can't be made into a vector."""
        if self is BaseTy.CA or self is BaseTy.VL_MAXVL:
            return True
        if self is BaseTy.I64:
            return False
        assert_never(self)

    @cached_property
    def max_reg_len(self):
        # type: () -> int
        """Largest `reg_len` allowed for a `Ty` with this base type."""
        if self is BaseTy.CA or self is BaseTy.VL_MAXVL:
            return 1
        if self is BaseTy.I64:
            return 128
        assert_never(self)

    def __repr__(self):
        return f"BaseTy.{self._name_}"
395
396
@plain_data(frozen=True, unsafe_hash=True, repr=False)
@final
class Ty(metaclass=InternedMeta):
    """A concrete type: a `BaseTy` together with a register count
    (`reg_len`).
    """
    __slots__ = ("base_ty", "reg_len")

    @staticmethod
    def validate(base_ty, reg_len):
        # type: (BaseTy, int) -> str | None
        """ return a string with the error if the combination is invalid,
        otherwise return None
        """
        if reg_len != 1 and base_ty.only_scalar:
            return f"can't create a vector of an only-scalar type: {base_ty}"
        if not 1 <= reg_len <= base_ty.max_reg_len:
            return "reg_len out of range"
        return None

    def __init__(self, base_ty, reg_len):
        # type: (BaseTy, int) -> None
        error = self.validate(base_ty=base_ty, reg_len=reg_len)
        if error is not None:
            raise ValueError(error)
        self.base_ty = base_ty
        self.reg_len = reg_len

    def __repr__(self):
        # type: () -> str
        # scalars print as <I64>, vectors as <I64*N>
        suffix = "" if self.reg_len == 1 else f"*{self.reg_len}"
        return f"<{self.base_ty._name_}{suffix}>"
429
430
@unique
@final
class LocKind(Enum):
    """The kind of storage a `Loc` refers to."""
    GPR = enum.auto()
    StackI64 = enum.auto()
    CA = enum.auto()
    VL_MAXVL = enum.auto()

    @cached_property
    def base_ty(self):
        # type: () -> BaseTy
        """The `BaseTy` of values stored in this kind of location."""
        if self is LocKind.CA:
            return BaseTy.CA
        if self is LocKind.VL_MAXVL:
            return BaseTy.VL_MAXVL
        if self is LocKind.GPR or self is LocKind.StackI64:
            return BaseTy.I64
        assert_never(self)

    @cached_property
    def loc_count(self):
        # type: () -> int
        """Number of register-sized slots of this kind."""
        if self is LocKind.GPR or self is LocKind.CA \
                or self is LocKind.VL_MAXVL:
            return self.base_ty.max_reg_len
        if self is LocKind.StackI64:
            return 512
        assert_never(self)

    def __repr__(self):
        return f"LocKind.{self._name_}"
464
465
@final
@unique
class LocSubKind(Enum):
    """A subdivision of `LocKind`; each sub-kind has its own set of
    allocatable start positions (see `allocatable_locs`).
    """
    BASE_GPR = enum.auto()
    SV_EXTRA2_VGPR = enum.auto()
    SV_EXTRA2_SGPR = enum.auto()
    SV_EXTRA3_VGPR = enum.auto()
    SV_EXTRA3_SGPR = enum.auto()
    StackI64 = enum.auto()
    CA = enum.auto()
    VL_MAXVL = enum.auto()

    @cached_property
    def kind(self):
        # type: () -> LocKind
        """The `LocKind` this sub-kind belongs to."""
        # pyright fails typechecking when using `in` here:
        # reported: https://github.com/microsoft/pyright/issues/4102
        gpr_sub_kinds = (
            LocSubKind.BASE_GPR, LocSubKind.SV_EXTRA2_VGPR,
            LocSubKind.SV_EXTRA2_SGPR, LocSubKind.SV_EXTRA3_VGPR,
            LocSubKind.SV_EXTRA3_SGPR)
        if self in gpr_sub_kinds:
            return LocKind.GPR
        if self is LocSubKind.StackI64:
            return LocKind.StackI64
        if self is LocSubKind.CA:
            return LocKind.CA
        if self is LocSubKind.VL_MAXVL:
            return LocKind.VL_MAXVL
        assert_never(self)

    @property
    def base_ty(self):
        return self.kind.base_ty

    @lru_cache()
    def allocatable_locs(self, ty):
        # type: (Ty) -> LocSet
        """Return every `Loc` of type `ty` that this sub-kind allows.

        Candidate start positions depend on the sub-kind; candidates that
        aren't valid `Loc`s or that overlap any of `SPECIAL_GPRS` are left
        out.  Raises `ValueError` if `ty.base_ty` doesn't match.
        """
        if ty.base_ty != self.base_ty:
            raise ValueError("type mismatch")
        if self is LocSubKind.CA or self is LocSubKind.VL_MAXVL:
            # only a single location of these kinds is ever offered
            return LocSet([Loc(kind=self.kind, start=0, reg_len=1)])
        if self is LocSubKind.BASE_GPR:
            starts = range(32)
        elif self is LocSubKind.SV_EXTRA2_VGPR:
            # only even start positions
            starts = range(0, 128, 2)
        elif self is LocSubKind.SV_EXTRA2_SGPR:
            starts = range(64)
        elif self is LocSubKind.SV_EXTRA3_VGPR \
                or self is LocSubKind.SV_EXTRA3_SGPR:
            starts = range(128)
        elif self is LocSubKind.StackI64:
            starts = range(LocKind.StackI64.loc_count)
        else:
            assert_never(self)
        allowed = []  # type: list[Loc]
        for start in starts:
            loc = Loc.try_make(kind=self.kind, start=start, reg_len=ty.reg_len)
            if loc is None:
                continue
            if any(loc.conflicts(special) for special in SPECIAL_GPRS):
                continue
            allowed.append(loc)
        return LocSet(allowed)

    def __repr__(self):
        return f"LocSubKind.{self._name_}"
535
536
@plain_data(frozen=True, unsafe_hash=True)
@final
class GenericTy(metaclass=InternedMeta):
    """A type whose length is not fixed yet: a `BaseTy` plus a flag saying
    whether it is a vector.  `instantiate` turns it into a concrete `Ty`.
    """
    __slots__ = ("base_ty", "is_vec")

    def __init__(self, base_ty, is_vec):
        # type: (BaseTy, bool) -> None
        self.base_ty = base_ty
        if is_vec and base_ty.only_scalar:
            raise ValueError(f"base_ty={base_ty} requires is_vec=False")
        self.is_vec = is_vec

    def instantiate(self, maxvl):
        # type: (int) -> Ty
        """Return the concrete `Ty`: `maxvl` registers for a vector,
        one register for a scalar.
        """
        # here's where subvl and elwid would be accounted for
        if not self.is_vec:
            return Ty(self.base_ty, 1)
        return Ty(self.base_ty, maxvl)

    def can_instantiate_to(self, ty):
        # type: (Ty) -> bool
        """True if some `maxvl` could make `instantiate` return `ty`'s
        shape: base types must match, and scalars need `reg_len == 1`.
        """
        if self.base_ty != ty.base_ty:
            return False
        return True if self.is_vec else ty.reg_len == 1
563
564
@plain_data(frozen=True, unsafe_hash=True)
@final
class Loc(metaclass=InternedMeta):
    """A run of `reg_len` consecutive slots of one `LocKind`, beginning at
    slot `start`.
    """
    __slots__ = "kind", "start", "reg_len"

    @staticmethod
    def validate(kind, start, reg_len):
        # type: (LocKind, int, int) -> str | None
        """Return a string describing why the combination is invalid, or
        None if it is valid.
        """
        msg = Ty.validate(base_ty=kind.base_ty, reg_len=reg_len)
        if msg is not None:
            return msg
        if reg_len > kind.loc_count:
            return "invalid reg_len"
        if start < 0 or start + reg_len > kind.loc_count:
            return "start not in valid range"
        return None

    @staticmethod
    def try_make(kind, start, reg_len):
        # type: (LocKind, int, int) -> Loc | None
        """Like calling `Loc(...)`, but return None instead of raising
        when the arguments don't pass `validate`.
        """
        msg = Loc.validate(kind=kind, start=start, reg_len=reg_len)
        if msg is not None:
            return None
        return Loc(kind=kind, start=start, reg_len=reg_len)

    def __init__(self, kind, start, reg_len):
        # type: (LocKind, int, int) -> None
        msg = self.validate(kind=kind, start=start, reg_len=reg_len)
        if msg is not None:
            raise ValueError(msg)
        self.kind = kind
        self.reg_len = reg_len
        self.start = start

    def conflicts(self, other):
        # type: (Loc) -> bool
        """True if `self` and `other` are of the same kind and their slot
        ranges overlap.
        """
        return (self.kind == other.kind
                and self.start < other.stop and other.start < self.stop)

    @staticmethod
    def make_ty(kind, reg_len):
        # type: (LocKind, int) -> Ty
        """Return the `Ty` of a `Loc` with the given kind and length."""
        return Ty(base_ty=kind.base_ty, reg_len=reg_len)

    @cached_property
    def ty(self):
        # type: () -> Ty
        return self.make_ty(kind=self.kind, reg_len=self.reg_len)

    @property
    def stop(self):
        # type: () -> int
        """Exclusive end slot: `start + reg_len`."""
        return self.start + self.reg_len

    def try_concat(self, *others):
        # type: (*Loc | None) -> Loc | None
        """Join `self` and `others` into one `Loc`, or return None.

        None is returned when any of `others` is None, is of a different
        kind, or doesn't start exactly where the previous `Loc` stops, and
        also when the joined length isn't a valid `Loc` (e.g. it exceeds
        the base type's `max_reg_len`).
        """
        reg_len = self.reg_len
        stop = self.stop
        for other in others:
            if other is None or other.kind != self.kind:
                return None
            if stop != other.start:
                return None
            stop = other.stop
            reg_len += other.reg_len
        # try_make rather than Loc(...): a `try_` method reports failure by
        # returning None, not by letting validation raise ValueError
        return Loc.try_make(kind=self.kind, start=self.start, reg_len=reg_len)

    def get_subloc_at_offset(self, subloc_ty, offset):
        # type: (Ty, int) -> Loc
        """Return the `Loc` of type `subloc_ty` that begins `offset` slots
        into `self`.

        Raises `ValueError` if the base types differ or the sub-`Loc`
        wouldn't lie entirely inside `self`.
        """
        if subloc_ty.base_ty != self.kind.base_ty:
            raise ValueError("BaseTy mismatch")
        if offset < 0 or offset + subloc_ty.reg_len > self.reg_len:
            raise ValueError("invalid sub-Loc: offset and/or "
                             "subloc_ty.reg_len out of range")
        return Loc(kind=self.kind,
                   start=self.start + offset, reg_len=subloc_ty.reg_len)
641
642
# Single-register GPR locations that `LocSubKind.allocatable_locs` never
# hands out.
# NOTE(review): presumably the ABI-reserved registers (r0, r1, r2, r13) --
# confirm against the target ABI.
SPECIAL_GPRS = tuple(
    Loc(kind=LocKind.GPR, start=reg_num, reg_len=1)
    for reg_num in (0, 1, 2, 13)
)
649
650
@final
class LocSet(OFSet[Loc], metaclass=InternedMeta):
    """A set of `Loc`s that all have the same `Ty`.

    Besides the `Loc`s themselves, a per-`LocKind` bit-set of their start
    positions (`starts`) and their common type (`ty`, None when the set is
    empty) are kept.
    """

    def __init__(self, __locs=()):
        # type: (Iterable[Loc]) -> None
        super().__init__(__locs)
        if isinstance(__locs, LocSet):
            # copying another LocSet: reuse what it already computed
            self.__starts = __locs.starts
            self.__ty = __locs.ty
            return
        starts = {i: BitSet() for i in LocKind}
        ty = None  # type: None | Ty
        for loc in self:
            if ty is None:
                ty = loc.ty
            if ty != loc.ty:
                raise ValueError(f"conflicting types: {ty} != {loc.ty}")
            starts[loc.kind].add(loc.start)
        # kinds without any Loc are left out of `starts`
        self.__starts = FMap(
            (k, FBitSet(v)) for k, v in starts.items() if len(v) != 0)
        self.__ty = ty

    @property
    def starts(self):
        # type: () -> FMap[LocKind, FBitSet]
        """Start positions of the contained `Loc`s, per `LocKind`."""
        return self.__starts

    @property
    def ty(self):
        # type: () -> Ty | None
        """Common `Ty` of the contained `Loc`s, or None if empty."""
        return self.__ty

    @cached_property
    def stops(self):
        # type: () -> FMap[LocKind, FBitSet]
        """Exclusive stop positions of the contained `Loc`s, per `LocKind`
        (every start shifted up by `reg_len`).
        """
        if self.ty is None:
            return FMap()
        sh = self.ty.reg_len
        return FMap(
            (k, FBitSet(bits=v.bits << sh)) for k, v in self.starts.items())

    @property
    def kinds(self):
        # type: () -> AbstractSet[LocKind]
        return self.starts.keys()

    @property
    def reg_len(self):
        # type: () -> int | None
        if self.ty is None:
            return None
        return self.ty.reg_len

    @property
    def base_ty(self):
        # type: () -> BaseTy | None
        if self.ty is None:
            return None
        return self.ty.base_ty

    def concat(self, *others):
        # type: (*LocSet) -> LocSet
        """Return the set of `Loc`s made by placing one `Loc` from `self`
        and one from each of `others`, in order, directly next to each
        other.

        The result is empty if any operand is empty or has a different
        base type.
        """
        if self.ty is None:
            return LocSet()
        base_ty = self.ty.base_ty
        reg_len = self.ty.reg_len
        starts = {k: BitSet(v) for k, v in self.starts.items()}
        for other in others:
            if other.ty is None:
                return LocSet()
            if other.ty.base_ty != base_ty:
                return LocSet()
            # iterate over a copy of the keys since entries get deleted
            for kind in list(starts):
                if kind not in other.starts:
                    # `other` has no Loc of this kind at all, so nothing
                    # of this kind can be extended by it
                    del starts[kind]
                    continue
                # a start survives only if `other` has a Loc beginning
                # exactly `reg_len` slots after it
                starts[kind].bits &= other.starts[kind].bits >> reg_len
                if starts[kind].bits == 0:
                    del starts[kind]
            if len(starts) == 0:
                return LocSet()
            reg_len += other.ty.reg_len

        def locs():
            # type: () -> Iterable[Loc]
            for kind, v in starts.items():
                for start in v:
                    loc = Loc.try_make(kind=kind, start=start, reg_len=reg_len)
                    if loc is not None:
                        yield loc
        return LocSet(locs())

    @lru_cache(maxsize=None, typed=True)
    def max_conflicts_with(self, other):
        # type: (LocSet | Loc) -> int
        """the largest number of Locs in `self` that a single Loc
        from `other` can conflict with

        An empty `other` LocSet gives 0.
        """
        if isinstance(other, LocSet):
            return max((self.max_conflicts_with(i) for i in other), default=0)
        else:
            return sum(other.conflicts(i) for i in self)

    def __repr__(self):
        return f"LocSet(starts={self.starts!r}, ty={self.ty!r})"
754
755
@plain_data(frozen=True, unsafe_hash=True)
@final
class GenericOperandDesc(metaclass=InternedMeta):
    """generic Op operand descriptor"""
    __slots__ = ("ty", "fixed_loc", "sub_kinds", "tied_input_index", "spread",
                 "write_stage")

    def __init__(
        self, ty,  # type: GenericTy
        sub_kinds,  # type: Iterable[LocSubKind]
        *,
        fixed_loc=None,  # type: Loc | None
        tied_input_index=None,  # type: int | None
        spread=False,  # type: bool
        write_stage=OpStage.Early,  # type: OpStage
    ):
        # type: (...) -> None
        """Check the combination of options and store them.

        Raises `ValueError` when `sub_kinds` is empty or doesn't match
        `ty`, when `fixed_loc` doesn't fit `ty` or the single allowed
        sub-kind, when `tied_input_index` is negative, or when more than
        one of tied / fixed / spread is requested (spread also excludes
        vector types).
        """
        self.ty = ty
        self.sub_kinds = OFSet(sub_kinds)
        if len(self.sub_kinds) == 0:
            raise ValueError("sub_kinds can't be empty")
        self.fixed_loc = fixed_loc
        is_tied = tied_input_index is not None
        if fixed_loc is not None:
            if is_tied:
                raise ValueError("operand can't be both tied and fixed")
            if not ty.can_instantiate_to(fixed_loc.ty):
                raise ValueError(
                    f"fixed_loc has incompatible type for given generic "
                    f"type: fixed_loc={fixed_loc} generic ty={ty}")
            if len(self.sub_kinds) != 1:
                raise ValueError(
                    "multiple sub_kinds not allowed for fixed operand")
            for sub_kind in self.sub_kinds:
                if fixed_loc in sub_kind.allocatable_locs(fixed_loc.ty):
                    continue
                raise ValueError(
                    f"fixed_loc not in given sub_kind: "
                    f"fixed_loc={fixed_loc} sub_kind={sub_kind}")
        for sub_kind in self.sub_kinds:
            if sub_kind.base_ty == ty.base_ty:
                continue
            raise ValueError(f"sub_kind is incompatible with type: "
                             f"sub_kind={sub_kind} ty={ty}")
        if is_tied and tied_input_index < 0:
            raise ValueError("invalid tied_input_index")
        self.tied_input_index = tied_input_index
        self.spread = spread
        if spread:
            if self.tied_input_index is not None:
                raise ValueError("operand can't be both spread and tied")
            if self.fixed_loc is not None:
                raise ValueError("operand can't be both spread and fixed")
            if self.ty.is_vec:
                raise ValueError("operand can't be both spread and vector")
        self.write_stage = write_stage

    @cached_property
    def ty_before_spread(self):
        # type: () -> GenericTy
        """`ty`, or for a spread operand the vector type with the same
        base type.
        """
        if not self.spread:
            return self.ty
        return GenericTy(base_ty=self.ty.base_ty, is_vec=True)

    def tied_to_input(self, tied_input_index):
        # type: (int) -> Self
        """Return a descriptor with the same `ty`, `sub_kinds` and
        `write_stage`, tied to the input at `tied_input_index`.
        """
        return GenericOperandDesc(self.ty, self.sub_kinds,
                                  tied_input_index=tied_input_index,
                                  write_stage=self.write_stage)

    def with_fixed_loc(self, fixed_loc):
        # type: (Loc) -> Self
        """Return a descriptor with the same `ty`, `sub_kinds` and
        `write_stage`, fixed to `fixed_loc`.
        """
        return GenericOperandDesc(self.ty, self.sub_kinds, fixed_loc=fixed_loc,
                                  write_stage=self.write_stage)

    def with_write_stage(self, write_stage):
        # type: (OpStage) -> Self
        """Return a copy of `self` that differs only in `write_stage`."""
        return GenericOperandDesc(self.ty, self.sub_kinds,
                                  fixed_loc=self.fixed_loc,
                                  tied_input_index=self.tied_input_index,
                                  spread=self.spread,
                                  write_stage=write_stage)

    def instantiate(self, maxvl):
        # type: (int) -> Iterable[OperandDesc]
        """Yield the concrete `OperandDesc`s for the given `maxvl`: one
        per element (with `spread_index` 0, 1, ...) for a spread operand,
        otherwise a single one with `spread_index=None`.
        """
        # assumes all spread operands have ty.reg_len = 1
        instantiated_ty = self.ty_before_spread.instantiate(maxvl=maxvl)

        def candidate_locs():
            # type: () -> Iterable[Loc]
            if self.fixed_loc is None:
                for sub_kind in self.sub_kinds:
                    yield from sub_kind.allocatable_locs(instantiated_ty)
                return
            if instantiated_ty != self.fixed_loc.ty:
                raise ValueError(
                    f"instantiation failed: type mismatch with fixed_loc: "
                    f"instantiated type: {instantiated_ty} "
                    f"fixed_loc: {self.fixed_loc}")
            yield self.fixed_loc

        loc_set_before_spread = LocSet(candidate_locs())
        if self.spread:
            spread_indexes = range(maxvl)
        else:
            spread_indexes = (None,)
        for spread_index in spread_indexes:
            yield OperandDesc(loc_set_before_spread=loc_set_before_spread,
                              tied_input_index=self.tied_input_index,
                              spread_index=spread_index,
                              write_stage=self.write_stage)
863
864
@plain_data(frozen=True, unsafe_hash=True)
@final
class OperandDesc(metaclass=InternedMeta):
    """Op operand descriptor"""
    __slots__ = ("loc_set_before_spread", "tied_input_index", "spread_index",
                 "write_stage")

    def __init__(self, loc_set_before_spread, tied_input_index, spread_index,
                 write_stage):
        # type: (LocSet, int | None, int | None, OpStage) -> None
        if len(loc_set_before_spread) == 0:
            raise ValueError("loc_set_before_spread must not be empty")
        self.loc_set_before_spread = loc_set_before_spread
        self.tied_input_index = tied_input_index
        is_tied = self.tied_input_index is not None
        if is_tied and spread_index is not None:
            raise ValueError("operand can't be both spread and tied")
        self.spread_index = spread_index
        self.write_stage = write_stage

    @cached_property
    def ty_before_spread(self):
        # type: () -> Ty
        """The `Ty` shared by all of `loc_set_before_spread`."""
        unspread_ty = self.loc_set_before_spread.ty
        assert unspread_ty is not None, (
            "__init__ checked that the LocSet isn't empty, "
            "non-empty LocSets should always have ty set")
        return unspread_ty

    @cached_property
    def ty(self):
        """ Ty after any spread is applied """
        if self.spread_index is None:
            return self.ty_before_spread
        # assumes all spread operands have ty.reg_len = 1
        return Ty(base_ty=self.ty_before_spread.base_ty, reg_len=1)

    @property
    def reg_offset_in_unspread(self):
        """ the number of reg-sized slots in the unspread Loc before self's Loc

        e.g. if the unspread Loc containing self is:
        `Loc(kind=LocKind.GPR, start=8, reg_len=4)`
        and self's Loc is `Loc(kind=LocKind.GPR, start=10, reg_len=1)`
        then reg_offset_in_unspread == 2 == 10 - 8
        """
        if self.spread_index is not None:
            return self.spread_index * self.ty.reg_len
        return 0
913
914
# Ready-made operand descriptors, one per (type, sub-kind) combination.

# scalar I64 in a BASE_GPR location
OD_BASE_SGPR = GenericOperandDesc(
    sub_kinds=(LocSubKind.BASE_GPR,),
    ty=GenericTy(base_ty=BaseTy.I64, is_vec=False))
# scalar I64 in an SV_EXTRA3_SGPR location
OD_EXTRA3_SGPR = GenericOperandDesc(
    sub_kinds=(LocSubKind.SV_EXTRA3_SGPR,),
    ty=GenericTy(base_ty=BaseTy.I64, is_vec=False))
# vector of I64 in an SV_EXTRA3_VGPR location
OD_EXTRA3_VGPR = GenericOperandDesc(
    sub_kinds=(LocSubKind.SV_EXTRA3_VGPR,),
    ty=GenericTy(base_ty=BaseTy.I64, is_vec=True))
# scalar I64 in an SV_EXTRA2_SGPR location
OD_EXTRA2_SGPR = GenericOperandDesc(
    sub_kinds=(LocSubKind.SV_EXTRA2_SGPR,),
    ty=GenericTy(base_ty=BaseTy.I64, is_vec=False))
# vector of I64 in an SV_EXTRA2_VGPR location
OD_EXTRA2_VGPR = GenericOperandDesc(
    sub_kinds=(LocSubKind.SV_EXTRA2_VGPR,),
    ty=GenericTy(base_ty=BaseTy.I64, is_vec=True))
# the CA value
OD_CA = GenericOperandDesc(
    sub_kinds=(LocSubKind.CA,),
    ty=GenericTy(base_ty=BaseTy.CA, is_vec=False))
# the VL_MAXVL value
OD_VL = GenericOperandDesc(
    sub_kinds=(LocSubKind.VL_MAXVL,),
    ty=GenericTy(base_ty=BaseTy.VL_MAXVL, is_vec=False))
936
937
@plain_data(frozen=True, unsafe_hash=True)
@final
class GenericOpProperties(metaclass=InternedMeta):
    """Properties of a kind of Op before it is instantiated for a
    particular `maxvl`: its generic operand descriptors, allowed immediate
    ranges and a few flags.
    """
    __slots__ = ("demo_asm", "inputs", "outputs", "immediates",
                 "is_copy", "is_load_immediate", "has_side_effects")

    def __init__(
        self, demo_asm,  # type: str
        inputs,  # type: Iterable[GenericOperandDesc]
        outputs,  # type: Iterable[GenericOperandDesc]
        immediates=(),  # type: Iterable[range]
        is_copy=False,  # type: bool
        is_load_immediate=False,  # type: bool
        has_side_effects=False,  # type: bool
    ):
        # type: (...) -> None
        """Store the properties after checking the operand descriptors.

        Raises `ValueError` if an input is tied or has a non-`Early`
        `write_stage`, if an output is tied to a missing or non-equivalent
        input, or if two outputs have conflicting `fixed_loc`s.
        """
        self.demo_asm = demo_asm  # type: str
        self.inputs = tuple(inputs)  # type: tuple[GenericOperandDesc, ...]
        for inp in self.inputs:
            if inp.tied_input_index is not None:
                raise ValueError(
                    f"tied_input_index is not allowed on inputs: {inp}")
            if inp.write_stage is not OpStage.Early:
                raise ValueError(
                    f"write_stage is not allowed on inputs: {inp}")
        self.outputs = tuple(outputs)  # type: tuple[GenericOperandDesc, ...]
        # (fixed_loc, output index) of the outputs checked so far
        seen_fixed = []  # type: list[tuple[Loc, int]]
        for idx, out in enumerate(self.outputs):
            tied_index = out.tied_input_index
            if tied_index is not None:
                if tied_index >= len(self.inputs):
                    raise ValueError(f"tied_input_index out of range: {out}")
                tied_inp = self.inputs[tied_index]
                # a tied output must look exactly like its input, apart
                # from being tied and from its write stage
                expected_out = tied_inp.tied_to_input(tied_index) \
                    .with_write_stage(out.write_stage)
                if expected_out != out:
                    raise ValueError(f"output can't be tied to non-equivalent "
                                     f"input: {out} tied to {tied_inp}")
            if out.fixed_loc is None:
                continue
            for other_fixed_loc, other_idx in seen_fixed:
                if other_fixed_loc.conflicts(out.fixed_loc):
                    raise ValueError(
                        f"conflicting fixed_locs: outputs[{idx}] and "
                        f"outputs[{other_idx}]: {out.fixed_loc} conflicts "
                        f"with {other_fixed_loc}")
            seen_fixed.append((out.fixed_loc, idx))
        self.immediates = tuple(immediates)  # type: tuple[range, ...]
        self.is_copy = is_copy  # type: bool
        self.is_load_immediate = is_load_immediate  # type: bool
        self.has_side_effects = has_side_effects  # type: bool
988
989
@plain_data(frozen=True, unsafe_hash=True)
@final
class OpProperties(metaclass=InternedMeta):
    """The properties of an `OpKind` instantiated for one `maxvl`.

    `inputs` and `outputs` hold the concrete `OperandDesc`s; everything
    else is forwarded from the kind's `GenericOpProperties`.
    """
    __slots__ = ("kind", "inputs", "outputs", "maxvl")

    def __init__(self, kind, maxvl):
        # type: (OpKind, int) -> None
        self.kind = kind  # type: OpKind
        # each generic descriptor may expand to several concrete ones
        self.inputs = tuple(
            desc
            for generic_inp in self.generic.inputs
            for desc in generic_inp.instantiate(maxvl=maxvl)
        )  # type: tuple[OperandDesc, ...]
        self.outputs = tuple(
            desc
            for generic_out in self.generic.outputs
            for desc in generic_out.instantiate(maxvl=maxvl)
        )  # type: tuple[OperandDesc, ...]
        self.maxvl = maxvl  # type: int

    @property
    def generic(self):
        # type: () -> GenericOpProperties
        """The `GenericOpProperties` of `self.kind`."""
        return self.kind.properties

    @property
    def immediates(self):
        # type: () -> tuple[range, ...]
        generic = self.generic
        return generic.immediates

    @property
    def demo_asm(self):
        # type: () -> str
        generic = self.generic
        return generic.demo_asm

    @property
    def is_copy(self):
        # type: () -> bool
        generic = self.generic
        return generic.is_copy

    @property
    def is_load_immediate(self):
        # type: () -> bool
        generic = self.generic
        return generic.is_load_immediate

    @property
    def has_side_effects(self):
        # type: () -> bool
        generic = self.generic
        return generic.has_side_effects
1037
1038
# the values of a signed 16-bit immediate: -0x8000 through 0x7FFF inclusive
IMM_S16 = range(-1 << 15, 1 << 15)

# a function that simulates one `Op` by reading/writing a `BaseSimState`
_SIM_FN = Callable[["Op", "BaseSimState"], None]
# a zero-argument callable that returns the `_SIM_FN`; `OpKind.sim` calls it
_SIM_FN2 = Callable[[], _SIM_FN]
# registry of simulation functions, looked up by `OpKind.sim` using the
# kind's `GenericOpProperties` as the key
_SIM_FNS = {}  # type: dict[GenericOpProperties | Any, _SIM_FN2]
# a function that writes the assembly for one `Op` using a `GenAsmState`
_GEN_ASM_FN = Callable[["Op", "GenAsmState"], None]
# a zero-argument callable that returns the `_GEN_ASM_FN`; `OpKind.gen_asm`
# calls it
_GEN_ASM_FN2 = Callable[[], _GEN_ASM_FN]
# registry of assembly generators, looked up by `OpKind.gen_asm` using the
# kind's `GenericOpProperties` as the key
_GEN_ASMS = {}  # type: dict[GenericOpProperties | Any, _GEN_ASM_FN2]
1047
1048
@unique
@final
class OpKind(Enum):
    """ the kinds of operations that can appear in a `Fn`.

    The value of each member is a `GenericOpProperties` describing the
    operation's inputs, outputs and immediates. Each member is preceded by
    two private static methods:

    * `__<name>_sim(op, state)` -- simulates the op on a `BaseSimState`
    * `__<name>_gen_asm(op, state)` -- writes the op's assembly through a
      `GenAsmState`

    They are registered in `_SIM_FNS` / `_GEN_ASMS`, keyed by the member's
    `GenericOpProperties`, through a lambda. The lambda defers the
    `OpKind.__...` lookup until it is called, since the name `OpKind` isn't
    bound yet while the class body is executing.
    """

    def __init__(self, properties):
        # type: (GenericOpProperties) -> None
        super().__init__()
        self.__properties = properties

    @property
    def properties(self):
        # type: () -> GenericOpProperties
        """ the `GenericOpProperties` this member was defined with """
        return self.__properties

    def instantiate(self, maxvl):
        # type: (int) -> OpProperties
        """ create the `OpProperties` of this kind for the given `maxvl` """
        return OpProperties(self, maxvl=maxvl)

    def __repr__(self):
        # type: () -> str
        return "OpKind." + self._name_

    @cached_property
    def sim(self):
        # type: () -> _SIM_FN
        """ the simulation function registered for this kind """
        return _SIM_FNS[self.properties]()

    @cached_property
    def gen_asm(self):
        # type: () -> _GEN_ASM_FN
        """ the assembly generation function registered for this kind """
        return _GEN_ASMS[self.properties]()

    @staticmethod
    def __clearca_sim(op, state):
        # type: (Op, BaseSimState) -> None
        """ simulate ClearCA: the CA output becomes False """
        state[op.outputs[0]] = False,

    @staticmethod
    def __clearca_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        """ write the assembly for ClearCA """
        state.writeln("addic 0, 0, 0")
    ClearCA = GenericOpProperties(
        demo_asm="addic 0, 0, 0",
        inputs=[],
        outputs=[OD_CA.with_write_stage(OpStage.Late)],
    )
    _SIM_FNS[ClearCA] = lambda: OpKind.__clearca_sim
    _GEN_ASMS[ClearCA] = lambda: OpKind.__clearca_gen_asm

    @staticmethod
    def __setca_sim(op, state):
        # type: (Op, BaseSimState) -> None
        """ simulate SetCA: the CA output becomes True """
        state[op.outputs[0]] = True,

    @staticmethod
    def __setca_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        """ write the assembly for SetCA """
        state.writeln("subfc 0, 0, 0")
    SetCA = GenericOpProperties(
        demo_asm="subfc 0, 0, 0",
        inputs=[],
        outputs=[OD_CA.with_write_stage(OpStage.Late)],
    )
    _SIM_FNS[SetCA] = lambda: OpKind.__setca_sim
    _GEN_ASMS[SetCA] = lambda: OpKind.__setca_gen_asm

    @staticmethod
    def __svadde_sim(op, state):
        # type: (Op, BaseSimState) -> None
        """ simulate SvAddE: element-wise `RA[i] + RB[i] + carry` over `VL`
        elements, chaining the carry from each element into the next.
        """
        RA = state[op.input_vals[0]]
        RB = state[op.input_vals[1]]
        carry, = state[op.input_vals[2]]
        VL, = state[op.input_vals[3]]
        RT = []  # type: list[int]
        for i in range(VL):
            v = RA[i] + RB[i] + carry
            RT.append(v & GPR_VALUE_MASK)
            carry = (v >> GPR_SIZE_IN_BITS) != 0
        state[op.outputs[0]] = tuple(RT)
        state[op.outputs[1]] = carry,

    @staticmethod
    def __svadde_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        """ write the assembly for SvAddE """
        RT = state.vgpr(op.outputs[0])
        RA = state.vgpr(op.input_vals[0])
        RB = state.vgpr(op.input_vals[1])
        state.writeln(f"sv.adde {RT}, {RA}, {RB}")
    SvAddE = GenericOpProperties(
        demo_asm="sv.adde *RT, *RA, *RB",
        inputs=[OD_EXTRA3_VGPR, OD_EXTRA3_VGPR, OD_CA, OD_VL],
        outputs=[OD_EXTRA3_VGPR, OD_CA.tied_to_input(2)],
    )
    _SIM_FNS[SvAddE] = lambda: OpKind.__svadde_sim
    _GEN_ASMS[SvAddE] = lambda: OpKind.__svadde_gen_asm

    @staticmethod
    def __addze_sim(op, state):
        # type: (Op, BaseSimState) -> None
        """ simulate AddZE: `RT = RA + carry`, producing a new carry """
        RA, = state[op.input_vals[0]]
        carry, = state[op.input_vals[1]]
        v = RA + carry
        RT = v & GPR_VALUE_MASK
        carry = (v >> GPR_SIZE_IN_BITS) != 0
        state[op.outputs[0]] = RT,
        state[op.outputs[1]] = carry,

    @staticmethod
    def __addze_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        """ write the assembly for AddZE """
        # both operands are declared as OD_BASE_SGPR, so format them as
        # scalar registers like the other scalar ops (LI, Ld, Std, SRADI)
        RT = state.sgpr(op.outputs[0])
        RA = state.sgpr(op.input_vals[0])
        state.writeln(f"addze {RT}, {RA}")
    AddZE = GenericOpProperties(
        demo_asm="addze RT, RA",
        inputs=[OD_BASE_SGPR, OD_CA],
        outputs=[OD_BASE_SGPR, OD_CA.tied_to_input(1)],
    )
    _SIM_FNS[AddZE] = lambda: OpKind.__addze_sim
    _GEN_ASMS[AddZE] = lambda: OpKind.__addze_gen_asm

    @staticmethod
    def __svsubfe_sim(op, state):
        # type: (Op, BaseSimState) -> None
        """ simulate SvSubFE: element-wise `~RA[i] + RB[i] + carry` over `VL`
        elements, chaining the carry from each element into the next.
        """
        RA = state[op.input_vals[0]]
        RB = state[op.input_vals[1]]
        carry, = state[op.input_vals[2]]
        VL, = state[op.input_vals[3]]
        RT = []  # type: list[int]
        for i in range(VL):
            v = (~RA[i] & GPR_VALUE_MASK) + RB[i] + carry
            RT.append(v & GPR_VALUE_MASK)
            carry = (v >> GPR_SIZE_IN_BITS) != 0
        state[op.outputs[0]] = tuple(RT)
        state[op.outputs[1]] = carry,

    @staticmethod
    def __svsubfe_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        """ write the assembly for SvSubFE """
        RT = state.vgpr(op.outputs[0])
        RA = state.vgpr(op.input_vals[0])
        RB = state.vgpr(op.input_vals[1])
        state.writeln(f"sv.subfe {RT}, {RA}, {RB}")
    SvSubFE = GenericOpProperties(
        demo_asm="sv.subfe *RT, *RA, *RB",
        inputs=[OD_EXTRA3_VGPR, OD_EXTRA3_VGPR, OD_CA, OD_VL],
        outputs=[OD_EXTRA3_VGPR, OD_CA.tied_to_input(2)],
    )
    _SIM_FNS[SvSubFE] = lambda: OpKind.__svsubfe_sim
    _GEN_ASMS[SvSubFE] = lambda: OpKind.__svsubfe_gen_asm

    @staticmethod
    def __svmaddedu_sim(op, state):
        # type: (Op, BaseSimState) -> None
        """ simulate SvMAddEDU: for each of the `VL` elements compute
        `RA[i] * RB + carry`; the low `GPR_SIZE_IN_BITS` bits go to `RT[i]`
        and the remaining high bits become the carry for the next element.
        """
        RA = state[op.input_vals[0]]
        RB, = state[op.input_vals[1]]
        carry, = state[op.input_vals[2]]
        VL, = state[op.input_vals[3]]
        RT = []  # type: list[int]
        for i in range(VL):
            v = RA[i] * RB + carry
            RT.append(v & GPR_VALUE_MASK)
            carry = v >> GPR_SIZE_IN_BITS
        state[op.outputs[0]] = tuple(RT)
        state[op.outputs[1]] = carry,

    @staticmethod
    def __svmaddedu_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        """ write the assembly for SvMAddEDU """
        RT = state.vgpr(op.outputs[0])
        RA = state.vgpr(op.input_vals[0])
        RB = state.sgpr(op.input_vals[1])
        RC = state.sgpr(op.input_vals[2])
        state.writeln(f"sv.maddedu {RT}, {RA}, {RB}, {RC}")
    SvMAddEDU = GenericOpProperties(
        demo_asm="sv.maddedu *RT, *RA, RB, RC",
        inputs=[OD_EXTRA2_VGPR, OD_EXTRA2_SGPR, OD_EXTRA2_SGPR, OD_VL],
        outputs=[OD_EXTRA3_VGPR, OD_EXTRA2_SGPR.tied_to_input(2)],
    )
    _SIM_FNS[SvMAddEDU] = lambda: OpKind.__svmaddedu_sim
    _GEN_ASMS[SvMAddEDU] = lambda: OpKind.__svmaddedu_gen_asm

    @staticmethod
    def __sradi_sim(op, state):
        # type: (Op, BaseSimState) -> None
        """ simulate SRADI: arithmetic shift right of `RS` by the immediate.

        CA is set only when `RS` is negative and at least one 1 bit was
        shifted out, matching the Power ISA definition of `sradi`.
        """
        rs, = state[op.input_vals[0]]
        imm = op.immediates[0]
        # reinterpret the unsigned register value as a signed integer
        if rs >= 1 << (GPR_SIZE_IN_BITS - 1):
            rs -= 1 << GPR_SIZE_IN_BITS
        v = rs >> imm
        RA = v & GPR_VALUE_MASK
        # compare using the signed result `v`: comparing the masked `RA`
        # against the signed `rs` could never be equal for negative `rs`
        CA = rs < 0 and (v << imm) != rs
        state[op.outputs[0]] = RA,
        state[op.outputs[1]] = CA,

    @staticmethod
    def __sradi_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        """ write the assembly for SRADI """
        RA = state.sgpr(op.outputs[0])
        # SRADI has exactly one input (RS), so it's at index 0
        RS = state.sgpr(op.input_vals[0])
        imm = op.immediates[0]
        state.writeln(f"sradi {RA}, {RS}, {imm}")
    SRADI = GenericOpProperties(
        demo_asm="sradi RA, RS, imm",
        inputs=[OD_BASE_SGPR],
        outputs=[OD_BASE_SGPR.with_write_stage(OpStage.Late),
                 OD_CA.with_write_stage(OpStage.Late)],
        immediates=[range(GPR_SIZE_IN_BITS)],
    )
    _SIM_FNS[SRADI] = lambda: OpKind.__sradi_sim
    _GEN_ASMS[SRADI] = lambda: OpKind.__sradi_gen_asm

    @staticmethod
    def __setvli_sim(op, state):
        # type: (Op, BaseSimState) -> None
        """ simulate SetVLI: the VL output becomes the immediate """
        state[op.outputs[0]] = op.immediates[0],

    @staticmethod
    def __setvli_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        """ write the assembly for SetVLI """
        imm = op.immediates[0]
        state.writeln(f"setvl 0, 0, {imm}, 0, 1, 1")
    SetVLI = GenericOpProperties(
        demo_asm="setvl 0, 0, imm, 0, 1, 1",
        inputs=(),
        outputs=[OD_VL.with_write_stage(OpStage.Late)],
        immediates=[range(1, 65)],
        is_load_immediate=True,
    )
    _SIM_FNS[SetVLI] = lambda: OpKind.__setvli_sim
    _GEN_ASMS[SetVLI] = lambda: OpKind.__setvli_gen_asm

    @staticmethod
    def __svli_sim(op, state):
        # type: (Op, BaseSimState) -> None
        """ simulate SvLI: every one of the `VL` output elements becomes the
        immediate (wrapped to the GPR width).
        """
        VL, = state[op.input_vals[0]]
        imm = op.immediates[0] & GPR_VALUE_MASK
        state[op.outputs[0]] = (imm,) * VL

    @staticmethod
    def __svli_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        """ write the assembly for SvLI """
        RT = state.vgpr(op.outputs[0])
        imm = op.immediates[0]
        state.writeln(f"sv.addi {RT}, 0, {imm}")
    SvLI = GenericOpProperties(
        demo_asm="sv.addi *RT, 0, imm",
        inputs=[OD_VL],
        outputs=[OD_EXTRA3_VGPR],
        immediates=[IMM_S16],
        is_load_immediate=True,
    )
    _SIM_FNS[SvLI] = lambda: OpKind.__svli_sim
    _GEN_ASMS[SvLI] = lambda: OpKind.__svli_gen_asm

    @staticmethod
    def __li_sim(op, state):
        # type: (Op, BaseSimState) -> None
        """ simulate LI: the output becomes the immediate (wrapped to the
        GPR width).
        """
        imm = op.immediates[0] & GPR_VALUE_MASK
        state[op.outputs[0]] = imm,

    @staticmethod
    def __li_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        """ write the assembly for LI """
        RT = state.sgpr(op.outputs[0])
        imm = op.immediates[0]
        state.writeln(f"addi {RT}, 0, {imm}")
    LI = GenericOpProperties(
        demo_asm="addi RT, 0, imm",
        inputs=(),
        outputs=[OD_BASE_SGPR.with_write_stage(OpStage.Late)],
        immediates=[IMM_S16],
        is_load_immediate=True,
    )
    _SIM_FNS[LI] = lambda: OpKind.__li_sim
    _GEN_ASMS[LI] = lambda: OpKind.__li_gen_asm

    @staticmethod
    def __veccopytoreg_sim(op, state):
        # type: (Op, BaseSimState) -> None
        """ simulate VecCopyToReg: the output is a copy of input 0 """
        state[op.outputs[0]] = state[op.input_vals[0]]

    @staticmethod
    def __copy_to_from_reg_gen_asm(src_loc, dest_loc, is_vec, state):
        # type: (Loc, Loc, bool, GenAsmState) -> None
        """ write the assembly that copies `src_loc` to `dest_loc`.

        Shared by all the copy-like ops. Writes nothing when the locations
        are equal, a load for stack -> GPR, a store for GPR -> stack, and an
        `or` for GPR -> GPR. Instructions get the `sv.` prefix when `is_vec`
        is set.

        Raises `ValueError` if either location isn't a GPR or `StackI64`
        location, or if both are on the stack.
        """
        sv = "sv." if is_vec else ""
        rev = ""
        # NOTE(review): `/mrr` presumably makes an overlapping copy towards
        # higher registers run in reverse element order so the source isn't
        # overwritten before it's read -- confirm against the SVP64 spec
        if src_loc.conflicts(dest_loc) and src_loc.start < dest_loc.start:
            rev = "/mrr"
        if src_loc == dest_loc:
            return  # no-op
        if src_loc.kind not in (LocKind.GPR, LocKind.StackI64):
            raise ValueError(f"invalid src_loc.kind: {src_loc.kind}")
        if dest_loc.kind not in (LocKind.GPR, LocKind.StackI64):
            raise ValueError(f"invalid dest_loc.kind: {dest_loc.kind}")
        if src_loc.kind is LocKind.StackI64:
            if dest_loc.kind is LocKind.StackI64:
                raise ValueError(
                    f"can't copy from stack to stack: {src_loc} {dest_loc}")
            elif dest_loc.kind is not LocKind.GPR:
                assert_never(dest_loc.kind)
            src = state.stack(src_loc)
            dest = state.gpr(dest_loc, is_vec=is_vec)
            state.writeln(f"{sv}ld {dest}, {src}")
        elif dest_loc.kind is LocKind.StackI64:
            if src_loc.kind is not LocKind.GPR:
                assert_never(src_loc.kind)
            src = state.gpr(src_loc, is_vec=is_vec)
            dest = state.stack(dest_loc)
            state.writeln(f"{sv}std {src}, {dest}")
        elif src_loc.kind is LocKind.GPR:
            if dest_loc.kind is not LocKind.GPR:
                assert_never(dest_loc.kind)
            src = state.gpr(src_loc, is_vec=is_vec)
            dest = state.gpr(dest_loc, is_vec=is_vec)
            state.writeln(f"{sv}or{rev} {dest}, {src}, {src}")
        else:
            assert_never(src_loc.kind)

    @staticmethod
    def __veccopytoreg_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        """ write the assembly for VecCopyToReg """
        OpKind.__copy_to_from_reg_gen_asm(
            src_loc=state.loc(
                op.input_vals[0], (LocKind.GPR, LocKind.StackI64)),
            dest_loc=state.loc(op.outputs[0], LocKind.GPR),
            is_vec=True, state=state)

    VecCopyToReg = GenericOpProperties(
        demo_asm="sv.mv dest, src",
        inputs=[GenericOperandDesc(
            ty=GenericTy(BaseTy.I64, is_vec=True),
            sub_kinds=[LocSubKind.SV_EXTRA3_VGPR, LocSubKind.StackI64],
        ), OD_VL],
        outputs=[OD_EXTRA3_VGPR.with_write_stage(OpStage.Late)],
        is_copy=True,
    )
    _SIM_FNS[VecCopyToReg] = lambda: OpKind.__veccopytoreg_sim
    _GEN_ASMS[VecCopyToReg] = lambda: OpKind.__veccopytoreg_gen_asm

    @staticmethod
    def __veccopyfromreg_sim(op, state):
        # type: (Op, BaseSimState) -> None
        """ simulate VecCopyFromReg: the output is a copy of input 0 """
        state[op.outputs[0]] = state[op.input_vals[0]]

    @staticmethod
    def __veccopyfromreg_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        """ write the assembly for VecCopyFromReg """
        OpKind.__copy_to_from_reg_gen_asm(
            src_loc=state.loc(op.input_vals[0], LocKind.GPR),
            dest_loc=state.loc(
                op.outputs[0], (LocKind.GPR, LocKind.StackI64)),
            is_vec=True, state=state)
    VecCopyFromReg = GenericOpProperties(
        demo_asm="sv.mv dest, src",
        inputs=[OD_EXTRA3_VGPR, OD_VL],
        outputs=[GenericOperandDesc(
            ty=GenericTy(BaseTy.I64, is_vec=True),
            sub_kinds=[LocSubKind.SV_EXTRA3_VGPR, LocSubKind.StackI64],
            write_stage=OpStage.Late,
        )],
        is_copy=True,
    )
    _SIM_FNS[VecCopyFromReg] = lambda: OpKind.__veccopyfromreg_sim
    _GEN_ASMS[VecCopyFromReg] = lambda: OpKind.__veccopyfromreg_gen_asm

    @staticmethod
    def __copytoreg_sim(op, state):
        # type: (Op, BaseSimState) -> None
        """ simulate CopyToReg: the output is a copy of input 0 """
        state[op.outputs[0]] = state[op.input_vals[0]]

    @staticmethod
    def __copytoreg_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        """ write the assembly for CopyToReg """
        OpKind.__copy_to_from_reg_gen_asm(
            src_loc=state.loc(
                op.input_vals[0], (LocKind.GPR, LocKind.StackI64)),
            dest_loc=state.loc(op.outputs[0], LocKind.GPR),
            is_vec=False, state=state)
    CopyToReg = GenericOpProperties(
        demo_asm="mv dest, src",
        inputs=[GenericOperandDesc(
            ty=GenericTy(BaseTy.I64, is_vec=False),
            sub_kinds=[LocSubKind.SV_EXTRA3_SGPR, LocSubKind.BASE_GPR,
                       LocSubKind.StackI64],
        )],
        outputs=[GenericOperandDesc(
            ty=GenericTy(BaseTy.I64, is_vec=False),
            sub_kinds=[LocSubKind.SV_EXTRA3_SGPR, LocSubKind.BASE_GPR],
            write_stage=OpStage.Late,
        )],
        is_copy=True,
    )
    _SIM_FNS[CopyToReg] = lambda: OpKind.__copytoreg_sim
    _GEN_ASMS[CopyToReg] = lambda: OpKind.__copytoreg_gen_asm

    @staticmethod
    def __copyfromreg_sim(op, state):
        # type: (Op, BaseSimState) -> None
        """ simulate CopyFromReg: the output is a copy of input 0 """
        state[op.outputs[0]] = state[op.input_vals[0]]

    @staticmethod
    def __copyfromreg_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        """ write the assembly for CopyFromReg """
        OpKind.__copy_to_from_reg_gen_asm(
            src_loc=state.loc(op.input_vals[0], LocKind.GPR),
            dest_loc=state.loc(
                op.outputs[0], (LocKind.GPR, LocKind.StackI64)),
            is_vec=False, state=state)
    CopyFromReg = GenericOpProperties(
        demo_asm="mv dest, src",
        inputs=[GenericOperandDesc(
            ty=GenericTy(BaseTy.I64, is_vec=False),
            sub_kinds=[LocSubKind.SV_EXTRA3_SGPR, LocSubKind.BASE_GPR],
        )],
        outputs=[GenericOperandDesc(
            ty=GenericTy(BaseTy.I64, is_vec=False),
            sub_kinds=[LocSubKind.SV_EXTRA3_SGPR, LocSubKind.BASE_GPR,
                       LocSubKind.StackI64],
            write_stage=OpStage.Late,
        )],
        is_copy=True,
    )
    _SIM_FNS[CopyFromReg] = lambda: OpKind.__copyfromreg_sim
    _GEN_ASMS[CopyFromReg] = lambda: OpKind.__copyfromreg_gen_asm

    @staticmethod
    def __concat_sim(op, state):
        # type: (Op, BaseSimState) -> None
        """ simulate Concat: the output's elements are the first element of
        each input except the last input (which is VL).
        """
        state[op.outputs[0]] = tuple(
            state[i][0] for i in op.input_vals[:-1])

    @staticmethod
    def __concat_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        """ write the assembly for Concat """
        OpKind.__copy_to_from_reg_gen_asm(
            src_loc=state.loc(op.input_vals[0:-1], LocKind.GPR),
            dest_loc=state.loc(op.outputs[0], LocKind.GPR),
            is_vec=True, state=state)
    Concat = GenericOpProperties(
        demo_asm="sv.mv dest, src",
        inputs=[GenericOperandDesc(
            ty=GenericTy(BaseTy.I64, is_vec=False),
            sub_kinds=[LocSubKind.SV_EXTRA3_VGPR],
            spread=True,
        ), OD_VL],
        outputs=[OD_EXTRA3_VGPR.with_write_stage(OpStage.Late)],
        is_copy=True,
    )
    _SIM_FNS[Concat] = lambda: OpKind.__concat_sim
    _GEN_ASMS[Concat] = lambda: OpKind.__concat_gen_asm

    @staticmethod
    def __spread_sim(op, state):
        # type: (Op, BaseSimState) -> None
        """ simulate Spread: element `idx` of input 0 becomes the value of
        output `idx`.
        """
        for idx, inp in enumerate(state[op.input_vals[0]]):
            state[op.outputs[idx]] = inp,

    @staticmethod
    def __spread_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        """ write the assembly for Spread """
        OpKind.__copy_to_from_reg_gen_asm(
            src_loc=state.loc(op.input_vals[0], LocKind.GPR),
            dest_loc=state.loc(op.outputs, LocKind.GPR),
            is_vec=True, state=state)
    Spread = GenericOpProperties(
        demo_asm="sv.mv dest, src",
        inputs=[OD_EXTRA3_VGPR, OD_VL],
        outputs=[GenericOperandDesc(
            ty=GenericTy(BaseTy.I64, is_vec=False),
            sub_kinds=[LocSubKind.SV_EXTRA3_VGPR],
            spread=True,
            write_stage=OpStage.Late,
        )],
        is_copy=True,
    )
    _SIM_FNS[Spread] = lambda: OpKind.__spread_sim
    _GEN_ASMS[Spread] = lambda: OpKind.__spread_gen_asm

    @staticmethod
    def __svld_sim(op, state):
        # type: (Op, BaseSimState) -> None
        """ simulate SvLd: load `VL` consecutive GPR-sized values starting
        at address `RA + imm`.
        """
        RA, = state[op.input_vals[0]]
        VL, = state[op.input_vals[1]]
        addr = RA + op.immediates[0]
        RT = []  # type: list[int]
        for i in range(VL):
            v = state.load(addr + GPR_SIZE_IN_BYTES * i)
            RT.append(v & GPR_VALUE_MASK)
        state[op.outputs[0]] = tuple(RT)

    @staticmethod
    def __svld_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        """ write the assembly for SvLd """
        RA = state.sgpr(op.input_vals[0])
        RT = state.vgpr(op.outputs[0])
        imm = op.immediates[0]
        state.writeln(f"sv.ld {RT}, {imm}({RA})")
    SvLd = GenericOpProperties(
        demo_asm="sv.ld *RT, imm(RA)",
        inputs=[OD_EXTRA3_SGPR, OD_VL],
        outputs=[OD_EXTRA3_VGPR],
        immediates=[IMM_S16],
    )
    _SIM_FNS[SvLd] = lambda: OpKind.__svld_sim
    _GEN_ASMS[SvLd] = lambda: OpKind.__svld_gen_asm

    @staticmethod
    def __ld_sim(op, state):
        # type: (Op, BaseSimState) -> None
        """ simulate Ld: load one GPR-sized value from address `RA + imm` """
        RA, = state[op.input_vals[0]]
        addr = RA + op.immediates[0]
        v = state.load(addr)
        state[op.outputs[0]] = v & GPR_VALUE_MASK,

    @staticmethod
    def __ld_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        """ write the assembly for Ld """
        RA = state.sgpr(op.input_vals[0])
        RT = state.sgpr(op.outputs[0])
        imm = op.immediates[0]
        state.writeln(f"ld {RT}, {imm}({RA})")
    Ld = GenericOpProperties(
        demo_asm="ld RT, imm(RA)",
        inputs=[OD_BASE_SGPR],
        outputs=[OD_BASE_SGPR.with_write_stage(OpStage.Late)],
        immediates=[IMM_S16],
    )
    _SIM_FNS[Ld] = lambda: OpKind.__ld_sim
    _GEN_ASMS[Ld] = lambda: OpKind.__ld_gen_asm

    @staticmethod
    def __svstd_sim(op, state):
        # type: (Op, BaseSimState) -> None
        """ simulate SvStd: store the `VL` elements of `RS` to consecutive
        GPR-sized slots starting at address `RA + imm`.
        """
        RS = state[op.input_vals[0]]
        RA, = state[op.input_vals[1]]
        VL, = state[op.input_vals[2]]
        addr = RA + op.immediates[0]
        for i in range(VL):
            state.store(addr + GPR_SIZE_IN_BYTES * i, value=RS[i])

    @staticmethod
    def __svstd_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        """ write the assembly for SvStd """
        RS = state.vgpr(op.input_vals[0])
        RA = state.sgpr(op.input_vals[1])
        imm = op.immediates[0]
        state.writeln(f"sv.std {RS}, {imm}({RA})")
    SvStd = GenericOpProperties(
        demo_asm="sv.std *RS, imm(RA)",
        inputs=[OD_EXTRA3_VGPR, OD_EXTRA3_SGPR, OD_VL],
        outputs=[],
        immediates=[IMM_S16],
        has_side_effects=True,
    )
    _SIM_FNS[SvStd] = lambda: OpKind.__svstd_sim
    _GEN_ASMS[SvStd] = lambda: OpKind.__svstd_gen_asm

    @staticmethod
    def __std_sim(op, state):
        # type: (Op, BaseSimState) -> None
        """ simulate Std: store `RS` to address `RA + imm` """
        RS, = state[op.input_vals[0]]
        RA, = state[op.input_vals[1]]
        addr = RA + op.immediates[0]
        state.store(addr, value=RS)

    @staticmethod
    def __std_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        """ write the assembly for Std """
        RS = state.sgpr(op.input_vals[0])
        RA = state.sgpr(op.input_vals[1])
        imm = op.immediates[0]
        state.writeln(f"std {RS}, {imm}({RA})")
    Std = GenericOpProperties(
        demo_asm="std RS, imm(RA)",
        inputs=[OD_BASE_SGPR, OD_BASE_SGPR],
        outputs=[],
        immediates=[IMM_S16],
        has_side_effects=True,
    )
    _SIM_FNS[Std] = lambda: OpKind.__std_sim
    _GEN_ASMS[Std] = lambda: OpKind.__std_gen_asm

    @staticmethod
    def __funcargr3_sim(op, state):
        # type: (Op, BaseSimState) -> None
        """ simulate FuncArgR3: nothing to do """
        pass  # return value set before simulation

    @staticmethod
    def __funcargr3_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        """ write the assembly for FuncArgR3: nothing to write """
        pass  # no instructions needed
    FuncArgR3 = GenericOpProperties(
        demo_asm="",
        inputs=[],
        outputs=[OD_BASE_SGPR.with_fixed_loc(
            Loc(kind=LocKind.GPR, start=3, reg_len=1))],
    )
    _SIM_FNS[FuncArgR3] = lambda: OpKind.__funcargr3_sim
    _GEN_ASMS[FuncArgR3] = lambda: OpKind.__funcargr3_gen_asm
1648
1649
@plain_data(frozen=True, unsafe_hash=True, repr=False)
class SSAValOrUse(metaclass=InternedMeta):
    """ common base of `SSAVal` and `SSAUse`: an `Op` together with an index
    into one of that op's operand descriptor arrays.

    Subclasses pick the descriptor array through `descriptor_array`.
    """
    __slots__ = "op", "operand_idx"

    def __init__(self, op, operand_idx):
        # type: (Op, int) -> None
        super().__init__()
        # `op` has to be assigned before `descriptor_array` can be read
        self.op = op
        if not 0 <= operand_idx < len(self.descriptor_array):
            raise ValueError("invalid operand_idx")
        self.operand_idx = operand_idx

    @abstractmethod
    def __repr__(self):
        # type: () -> str
        ...

    @property
    @abstractmethod
    def descriptor_array(self):
        # type: () -> tuple[OperandDesc, ...]
        """ the operand descriptors that `operand_idx` indexes into """
        ...

    @cached_property
    def defining_descriptor(self):
        # type: () -> OperandDesc
        """ the descriptor at `operand_idx` in `descriptor_array` """
        descriptors = self.descriptor_array
        return descriptors[self.operand_idx]

    @cached_property
    def ty(self):
        # type: () -> Ty
        """ the type given by the defining descriptor """
        descriptor = self.defining_descriptor
        return descriptor.ty

    @cached_property
    def ty_before_spread(self):
        # type: () -> Ty
        """ the defining descriptor's `ty_before_spread` """
        descriptor = self.defining_descriptor
        return descriptor.ty_before_spread

    @property
    def base_ty(self):
        # type: () -> BaseTy
        """ the base type of `ty_before_spread` """
        unspread_ty = self.ty_before_spread
        return unspread_ty.base_ty

    @property
    def reg_offset_in_unspread(self):
        """ the number of reg-sized slots in the unspread Loc before self's Loc

        e.g. if the unspread Loc containing self is:
        `Loc(kind=LocKind.GPR, start=8, reg_len=4)`
        and self's Loc is `Loc(kind=LocKind.GPR, start=10, reg_len=1)`
        then reg_offset_in_unspread == 2 == 10 - 8
        """
        descriptor = self.defining_descriptor
        return descriptor.reg_offset_in_unspread

    @property
    def unspread_start_idx(self):
        # type: () -> int
        """ `operand_idx` minus the defining descriptor's `spread_index`
        (a `spread_index` of `None` counts as 0).
        """
        spread_index = self.defining_descriptor.spread_index
        if not spread_index:
            # None or 0: nothing to subtract
            return self.operand_idx
        return self.operand_idx - spread_index

    @property
    def unspread_start(self):
        # type: () -> Self
        """ the same kind of object as self, for the same op, but at
        `unspread_start_idx`.
        """
        cls = self.__class__
        return cls(op=self.op, operand_idx=self.unspread_start_idx)
1713
1714
@plain_data(frozen=True, unsafe_hash=True, repr=False)
@final
class SSAVal(SSAValOrUse):
    """ a value defined by an `Op`: `operand_idx` indexes the op's outputs.
    """
    __slots__ = ()

    def __repr__(self):
        # type: () -> str
        op_name = self.op.name
        return f"<{op_name}.outputs[{self.operand_idx}]: {self.ty}>"

    @cached_property
    def def_loc_set_before_spread(self):
        # type: () -> LocSet
        """ the defining descriptor's `loc_set_before_spread` """
        descriptor = self.defining_descriptor
        return descriptor.loc_set_before_spread

    @cached_property
    def descriptor_array(self):
        # type: () -> tuple[OperandDesc, ...]
        """ the output descriptors of the op's properties """
        properties = self.op.properties
        return properties.outputs

    @cached_property
    def tied_input(self):
        # type: () -> None | SSAUse
        """ the `SSAUse` of the input this output is tied to, or `None` when
        the defining descriptor has no `tied_input_index`.
        """
        tied_input_index = self.defining_descriptor.tied_input_index
        if tied_input_index is not None:
            return SSAUse(op=self.op, operand_idx=tied_input_index)
        return None

    @property
    def write_stage(self):
        # type: () -> OpStage
        """ the defining descriptor's `write_stage` """
        descriptor = self.defining_descriptor
        return descriptor.write_stage
1746
1747
@plain_data(frozen=True, unsafe_hash=True, repr=False)
@final
class SSAUse(SSAValOrUse):
    """ a use of a value by an `Op`: `operand_idx` indexes the op's inputs.
    """
    __slots__ = ()

    @cached_property
    def use_loc_set_before_spread(self):
        # type: () -> LocSet
        """ the defining descriptor's `loc_set_before_spread` """
        descriptor = self.defining_descriptor
        return descriptor.loc_set_before_spread

    @cached_property
    def descriptor_array(self):
        # type: () -> tuple[OperandDesc, ...]
        """ the input descriptors of the op's properties """
        properties = self.op.properties
        return properties.inputs

    def __repr__(self):
        # type: () -> str
        op_name = self.op.name
        return f"<{op_name}.input_uses[{self.operand_idx}]: {self.ty}>"

    @property
    def ssa_val(self):
        # type: () -> SSAVal
        """ the `SSAVal` currently stored in the op's `input_vals` at
        `operand_idx`; assigning writes it back to the same slot.
        """
        input_vals = self.op.input_vals
        return input_vals[self.operand_idx]

    @ssa_val.setter
    def ssa_val(self, ssa_val):
        # type: (SSAVal) -> None
        input_vals = self.op.input_vals
        input_vals[self.operand_idx] = ssa_val
1776
1777
# the item type stored in an `OpInputSeq`
_T = TypeVar("_T")
# the type of the descriptor that each `OpInputSeq` item is verified against
_Desc = TypeVar("_Desc")
1780
1781
class OpInputSeq(Sequence[_T], Generic[_T, _Desc]):
    """ a fixed-length sequence of an `Op`'s inputs, with one descriptor per
    item.

    Every item is checked against its descriptor, both at construction and
    on item assignment. Subclasses supply the descriptors
    (`_get_descriptors`) and the check (`_verify_write_with_desc`).
    """

    @abstractmethod
    def _verify_write_with_desc(self, idx, item, desc):
        # type: (int, _T | Any, _Desc) -> None
        """ raise if `item` can't be stored at `idx`, which has descriptor
        `desc`.
        """
        raise NotImplementedError

    @final
    def _verify_write(self, idx, item):
        # type: (int | Any, _T | Any) -> int
        """ check that `item` can be written at `idx`.

        Returns the normalized (non-negative) index. Raises `TypeError` for
        slices and other non-int indexes and `IndexError` for out-of-range
        indexes.
        """
        if isinstance(idx, slice):
            raise TypeError(
                f"can't write to slice of {self.__class__.__name__}")
        if not isinstance(idx, int):
            raise TypeError(f"can't write with index {idx!r}")
        descriptors = self.descriptors
        # normalize idx, raising IndexError if it is out of range
        idx = range(len(descriptors))[idx]
        self._verify_write_with_desc(idx, item, descriptors[idx])
        return idx

    def _on_set(self, idx, new_item, old_item):
        # type: (int, _T, _T | None) -> None
        """ hook for subclasses; does nothing here """
        pass

    @abstractmethod
    def _get_descriptors(self):
        # type: () -> tuple[_Desc, ...]
        """ return the descriptors, one per item """
        raise NotImplementedError

    @cached_property
    @final
    def descriptors(self):
        # type: () -> tuple[_Desc, ...]
        """ the cached result of `_get_descriptors()` """
        return self._get_descriptors()

    @property
    @final
    def op(self):
        """ the `Op` passed to the constructor """
        return self.__op

    def __init__(self, items, op):
        # type: (Iterable[_T], Op) -> None
        super().__init__()
        self.__op = op
        self.__items = []  # type: list[_T]
        expected_len = len(self.descriptors)
        for idx, item in enumerate(items):
            if idx >= expected_len:
                raise ValueError("too many items")
            self._verify_write(idx, item)
            self.__items.append(item)
        if len(self.__items) < expected_len:
            raise ValueError("not enough items")

    @final
    def __iter__(self):
        # type: () -> Iterator[_T]
        for item in self.__items:
            yield item

    @overload
    def __getitem__(self, idx):
        # type: (int) -> _T
        ...

    @overload
    def __getitem__(self, idx):
        # type: (slice) -> list[_T]
        ...

    @final
    def __getitem__(self, idx):
        # type: (int | slice) -> _T | list[_T]
        items = self.__items
        return items[idx]

    @final
    def __setitem__(self, idx, item):
        # type: (int, _T) -> None
        normalized_idx = self._verify_write(idx, item)
        self.__items[normalized_idx] = item

    @final
    def __len__(self):
        # type: () -> int
        items = self.__items
        return len(items)

    def __repr__(self):
        # type: () -> str
        cls_name = self.__class__.__name__
        return f"{cls_name}({self.__items}, op=...)"
1869
1870
@final
class OpInputVals(OpInputSeq[SSAVal, OperandDesc]):
    """ the `SSAVal` inputs of an `Op`, each checked against the matching
    input descriptor of the op's properties.
    """

    def _get_descriptors(self):
        # type: () -> tuple[OperandDesc, ...]
        return self.op.properties.inputs

    def _verify_write_with_desc(self, idx, item, desc):
        # type: (int, SSAVal | Any, OperandDesc) -> None
        """ check that `item` is an `SSAVal` of the type `desc` requires.

        Raises `TypeError` if `item` isn't an `SSAVal` and `ValueError` if
        its type differs from `desc.ty`.
        """
        if not isinstance(item, SSAVal):
            raise TypeError("expected value of type SSAVal")
        if item.ty != desc.ty:
            raise ValueError(f"assigned item's type {item.ty!r} doesn't match "
                             f"corresponding input's type {desc.ty!r}")

    def _on_set(self, idx, new_item, old_item):
        # type: (int, SSAVal, SSAVal | None) -> None
        SSAUses._on_op_input_set(self, idx, new_item, old_item)  # type: ignore

    def __init__(self, items, op):
        # type: (Iterable[SSAVal], Op) -> None
        """ Raises `ValueError` if `op` already has its `input_vals` set. """
        # `Op` stores this sequence in its `input_vals` slot; checking for
        # an attribute named `inputs` (which `Op` never has) could never
        # detect an op whose inputs were already set
        if hasattr(op, "input_vals"):
            raise ValueError("Op.input_vals already set")
        super().__init__(items, op)
1894
1895
@final
class OpImmediates(OpInputSeq[int, range]):
    """ the immediate operands of an `Op`, each checked against the matching
    allowed `range` of the op's properties.
    """

    def _get_descriptors(self):
        # type: () -> tuple[range, ...]
        properties = self.op.properties
        return properties.immediates

    def _verify_write_with_desc(self, idx, item, desc):
        # type: (int, int | Any, range) -> None
        """ check that `item` is an `int` contained in the range `desc`.

        Raises `TypeError` for non-int items and `ValueError` for ints
        outside `desc`.
        """
        if not isinstance(item, int):
            raise TypeError("expected value of type int")
        if item in desc:
            return
        raise ValueError(f"immediate value {item!r} not in {desc!r}")

    def __init__(self, items, op):
        # type: (Iterable[int], Op) -> None
        already_set = hasattr(op, "immediates")
        if already_set:
            raise ValueError("Op.immediates already set")
        super().__init__(items, op)
1914
1915
@plain_data(frozen=True, eq=False, repr=False)
@final
class Op:
    """ a single operation in a `Fn`.

    Holds the op's `OpProperties`, its input values and immediates, and the
    `SSAUse` / `SSAVal` objects for each of its inputs / outputs. Ops compare
    and hash by identity.
    """
    __slots__ = ("fn", "properties", "input_vals", "input_uses", "immediates",
                 "outputs", "name")

    def __init__(self, fn, properties, input_vals, immediates, name=""):
        # type: (Fn, OpProperties, Iterable[SSAVal], Iterable[int], str) -> None
        self.fn = fn
        self.properties = properties
        self.input_vals = OpInputVals(input_vals, op=self)
        input_count = len(self.properties.inputs)
        self.input_uses = tuple(
            SSAUse(self, idx) for idx in range(input_count))
        self.immediates = OpImmediates(immediates, op=self)
        output_count = len(self.properties.outputs)
        self.outputs = tuple(
            SSAVal(self, idx) for idx in range(output_count))
        # the name is assigned last: `fn` rejects ops that already have one
        self.name = fn._add_op_with_unused_name(self, name)  # type: ignore

    @property
    def kind(self):
        # type: () -> OpKind
        """ the `OpKind` of this op's properties """
        properties = self.properties
        return properties.kind

    def __eq__(self, other):
        # type: (Op | Any) -> bool
        if not isinstance(other, Op):
            return NotImplemented
        return self is other

    def __hash__(self):
        # type: () -> int
        return object.__hash__(self)

    def __repr__(self):
        # type: () -> str
        parts = []  # type: list[str]
        for name in fields(self):
            if name == "fn":
                continue
            if name == "properties":
                # shown through the `kind` property instead
                name = "kind"
            try:
                value = getattr(self, name)
            except AttributeError:
                # can happen when repr is called on a partially built Op
                parts.append(f"{name}=<not set>")
                continue
            if isinstance(value, OpInputSeq):
                value = list(value)  # type: ignore
            parts.append(f"{name}={value!r}")
        joined = ", ".join(parts)
        return f"Op({joined})"

    def sim(self, state):
        # type: (BaseSimState) -> None
        """ simulate this op on `state`.

        Raises `ValueError` if an input isn't assigned yet, if an output is
        already assigned in a `PreRASimState` (allowed only for
        `OpKind.FuncArgR3`), if the simulation didn't assign an output, or if
        any input/output value has the wrong number of elements.
        """
        def check_len(ssa_val, val):
            # the number of elements must match the value's type
            expected = ssa_val.ty.reg_len
            if len(val) != expected:
                raise ValueError(
                    f"value of SSAVal {ssa_val} has wrong number of elements: "
                    f"expected {ssa_val.ty.reg_len} found "
                    f"{len(val)}: {val!r}")

        for inp in self.input_vals:
            try:
                val = state[inp]
            except KeyError:
                raise ValueError(f"SSAVal {inp} not yet assigned when "
                                 f"running {self}")
            check_len(inp, val)
        if isinstance(state, PreRASimState):
            outputs_may_be_preset = self.kind is OpKind.FuncArgR3
            for out in self.outputs:
                if out not in state.ssa_vals or outputs_may_be_preset:
                    continue
                raise ValueError(f"SSAVal {out} already assigned before "
                                 f"running {self}")
        self.kind.sim(self, state)
        for out in self.outputs:
            try:
                val = state[out]
            except KeyError:
                raise ValueError(f"running {self} failed to assign to {out}")
            check_len(out, val)

    def gen_asm(self, state):
        # type: (GenAsmState) -> None
        """ write this op's assembly through `state`.

        `state.loc` is first called for every input and output with all
        `LocKind`s allowed, then the kind's `gen_asm` function is run.
        """
        any_kind = tuple(LocKind)
        for ssa_val in (*self.input_vals, *self.outputs):
            state.loc(ssa_val, expected_kinds=any_kind)
        self.kind.gen_asm(self, state)
2008
2009
@plain_data(frozen=True, repr=False)
class BaseSimState(metaclass=ABCMeta):
    """Abstract simulator state: byte-addressed memory plus SSAVal access.

    `memory` maps an address to a byte value; addresses that are absent
    read as zero. Subclasses supply `__getitem__`/`__setitem__` to read and
    write the element tuple of an `SSAVal`.
    """
    __slots__ = "memory",

    def __init__(self, memory):
        # type: (dict[int, int]) -> None
        super().__init__()
        self.memory = memory  # type: dict[int, int]

    def load_byte(self, addr):
        # type: (int) -> int
        """Return the byte at `addr` (wrapped to GPR width), 0 if unset."""
        return self.memory.get(addr & GPR_VALUE_MASK, 0) & 0xFF

    def store_byte(self, addr, value):
        # type: (int, int) -> None
        """Store the low 8 bits of `value` at `addr` (wrapped to GPR width)."""
        self.memory[addr & GPR_VALUE_MASK] = value & 0xFF

    def load(self, addr, size_in_bytes=GPR_SIZE_IN_BYTES, signed=False):
        # type: (int, int, bool) -> int
        """Read `size_in_bytes` bytes at `addr`, lowest address least
        significant.

        With `signed` true the result is interpreted as two's complement.
        Raises `ValueError` if `addr` is not a multiple of `size_in_bytes`.
        """
        if addr % size_in_bytes != 0:
            raise ValueError(f"address not aligned: {hex(addr)} "
                             f"required alignment: {size_in_bytes}")
        total_bits = size_in_bytes * BITS_IN_BYTE
        combined = 0
        for offset in range(size_in_bytes):
            combined |= self.load_byte(addr + offset) << (offset * BITS_IN_BYTE)
        # sign bit set -> subtract 2**total_bits to get the negative value
        if signed and (combined >> (total_bits - 1)) != 0:
            combined -= 1 << total_bits
        return combined

    def store(self, addr, value, size_in_bytes=GPR_SIZE_IN_BYTES):
        # type: (int, int, int) -> None
        """Write the low `size_in_bytes` bytes of `value` at `addr`, least
        significant byte at the lowest address.

        Raises `ValueError` if `addr` is not a multiple of `size_in_bytes`.
        """
        if addr % size_in_bytes != 0:
            raise ValueError(f"address not aligned: {hex(addr)} "
                             f"required alignment: {size_in_bytes}")
        remaining = value
        for offset in range(size_in_bytes):
            self.store_byte(addr + offset, remaining & 0xFF)
            remaining >>= BITS_IN_BYTE

    def _memory__repr(self):
        # type: () -> str
        """Format `memory` for `__repr__`, in ascending address order.

        A run of `GPR_SIZE_IN_BYTES` consecutive addresses starting at an
        aligned address is shown as one `<0x...>` word; every other byte is
        shown on its own.
        """
        if not self.memory:
            return "{}"
        CHUNK_SIZE = GPR_SIZE_IN_BYTES
        addrs = sorted(self.memory.keys())
        entries = []  # type: list[str]
        pos = 0
        while pos < len(addrs):
            addr = addrs[pos]
            window = addrs[pos:pos + CHUNK_SIZE]
            is_full_word = (addr % CHUNK_SIZE == 0
                            and window == list(range(addr, addr + CHUNK_SIZE)))
            if is_full_word:
                word = self.load(addr, size_in_bytes=CHUNK_SIZE)
                entries.append(
                    f"0x{addr:05x}: <0x{word:0{CHUNK_SIZE * 2}x}>")
                pos += CHUNK_SIZE
            else:
                entries.append(f"0x{addr:05x}: 0x{self.memory[addr]:02x}")
                pos += 1
        if len(entries) == 1:
            return f"{{{entries[0]}}}"
        joined = ",\n".join(entries)
        return f"{{\n{joined}}}"

    def __repr__(self):
        # type: () -> str
        """Return `ClassName(field=value, ...)`.

        A field `x` is formatted by the method `_x__repr` when one exists,
        otherwise by `repr`. Unset fields are shown as `<not set>`.
        """
        rendered = []  # type: list[str]
        for name in fields(self):
            try:
                value = getattr(self, name)
            except AttributeError:
                text = "<not set>"
            else:
                custom_repr = getattr(self, f"_{name}__repr", None)
                text = custom_repr() if callable(custom_repr) else repr(value)
            rendered.append(f"{name}={text}")
        return f"{self.__class__.__name__}({', '.join(rendered)})"

    @abstractmethod
    def __getitem__(self, ssa_val):
        # type: (SSAVal) -> tuple[int, ...]
        """Return the element values currently held for `ssa_val`."""
        ...

    @abstractmethod
    def __setitem__(self, ssa_val, value):
        # type: (SSAVal, tuple[int, ...]) -> None
        """Set the element values held for `ssa_val`."""
        ...
2099
2100
@plain_data(frozen=True, repr=False)
@final
class PreRASimState(BaseSimState):
    """Simulator state that stores values directly per `SSAVal`.

    `ssa_vals` maps each `SSAVal` to its tuple of element values.
    """
    __slots__ = "ssa_vals",

    def __init__(self, ssa_vals, memory):
        # type: (dict[SSAVal, tuple[int, ...]], dict[int, int]) -> None
        super().__init__(memory)
        self.ssa_vals = ssa_vals  # type: dict[SSAVal, tuple[int, ...]]

    def _ssa_vals__repr(self):
        # type: () -> str
        """Format `ssa_vals` for `__repr__`, elements in hex.

        Elements are grouped `CHUNK_SIZE` per line; a value that fits on
        one line is kept inline, and a single element gets a trailing comma.
        """
        if not self.ssa_vals:
            return "{}"
        CHUNK_SIZE = 4
        entries = []  # type: list[str]
        for ssa_val, elements in self.ssa_vals.items():
            # each chunk's first element starts a new line, the rest follow
            # after a space
            pieces = [
                ("\n " if index % CHUNK_SIZE == 0 else " ") + hex(element)
                for index, element in enumerate(elements)
            ]
            if len(pieces) <= CHUNK_SIZE:
                # fits on one line: drop the leading line break
                pieces[0] = pieces[0].lstrip()
            if len(pieces) == 1:
                # empty trailing piece yields the 1-tuple style comma
                pieces.append("")
            joined_elements = ",".join(pieces)
            entries.append(f"{ssa_val!r}: ({joined_elements})")
        if len(entries) == 1 and "\n" not in entries[0]:
            return f"{{{entries[0]}}}"
        joined_entries = ",\n".join(entries)
        return f"{{\n{joined_entries},\n}}"

    def __getitem__(self, ssa_val):
        # type: (SSAVal) -> tuple[int, ...]
        """Return the stored elements of `ssa_val`; `KeyError` if unset."""
        return self.ssa_vals[ssa_val]

    def __setitem__(self, ssa_val, value):
        # type: (SSAVal, tuple[int, ...]) -> None
        """Store `value` for `ssa_val`.

        Raises `ValueError` unless `value` has `ssa_val.ty.reg_len` elements.
        """
        actual_len = len(value)
        if actual_len != ssa_val.ty.reg_len:
            raise ValueError("value has wrong len")
        self.ssa_vals[ssa_val] = value
2144
2145
@plain_data(frozen=True, repr=False)
@final
class PostRASimState(BaseSimState):
    """Simulator state that stores values per allocated `Loc`.

    `ssa_val_to_loc_map` gives the `Loc` assigned to each `SSAVal`;
    `loc_values` holds one integer per `reg_len=1` `Loc`. Reading or
    writing an `SSAVal` goes through the `reg_len=1` sub-Locs of its `Loc`.
    """
    __slots__ = "ssa_val_to_loc_map", "loc_values"

    def __init__(self, ssa_val_to_loc_map, memory, loc_values):
        # type: (dict[SSAVal, Loc], dict[int, int], dict[Loc, int]) -> None
        """Validate and store the maps.

        Raises `ValueError` if an `SSAVal` and its `Loc` differ in type, or
        if any key of `loc_values` has `reg_len` other than 1.
        """
        super().__init__(memory)
        self.ssa_val_to_loc_map = FMap(ssa_val_to_loc_map)
        for ssa_val, loc in self.ssa_val_to_loc_map.items():
            if ssa_val.ty != loc.ty:
                raise ValueError(
                    f"type mismatch for SSAVal and Loc: {ssa_val} {loc}")
        self.loc_values = loc_values
        for loc in self.loc_values.keys():
            if loc.reg_len != 1:
                raise ValueError(
                    "loc_values must only contain Locs with reg_len=1, all "
                    "larger Locs will be split into reg_len=1 sub-Locs")

    def _loc_values__repr(self):
        # type: () -> str
        """Format `loc_values` for `__repr__`, sorted by kind then start,
        values in hex. An empty mapping is shown as `{}`.
        """
        if len(self.loc_values) == 0:
            # without this guard the result would be the malformed "{\n,\n}"
            return "{}"
        locs = sorted(self.loc_values.keys(), key=lambda v: (v.kind, v.start))
        items = [f"{loc}: 0x{self.loc_values[loc]:x}" for loc in locs]
        items_str = ",\n".join(items)
        return f"{{\n{items_str},\n}}"

    def _sublocs(self, ssa_val):
        # type: (SSAVal) -> Iterator[Loc]
        # lazily yield, in offset order, the reg_len=1 sub-Locs of the Loc
        # assigned to `ssa_val`
        loc = self.ssa_val_to_loc_map[ssa_val]
        subloc_ty = Ty(base_ty=loc.ty.base_ty, reg_len=1)
        for i in range(loc.reg_len):
            yield loc.get_subloc_at_offset(subloc_ty=subloc_ty, offset=i)

    def __getitem__(self, ssa_val):
        # type: (SSAVal) -> tuple[int, ...]
        """Return the elements of `ssa_val`, one per sub-Loc.

        Sub-Locs with no entry in `loc_values` read as 0. Raises `KeyError`
        if `ssa_val` has no assigned `Loc`.
        """
        return tuple(self.loc_values.get(subloc, 0)
                     for subloc in self._sublocs(ssa_val))

    def __setitem__(self, ssa_val, value):
        # type: (SSAVal, tuple[int, ...]) -> None
        """Write `value` element-by-element into the sub-Locs of `ssa_val`.

        Raises `ValueError` unless `value` has `ssa_val.ty.reg_len`
        elements, and `KeyError` if `ssa_val` has no assigned `Loc`.
        """
        if len(value) != ssa_val.ty.reg_len:
            raise ValueError("value has wrong len")
        for i, subloc in enumerate(self._sublocs(ssa_val)):
            self.loc_values[subloc] = value[i]
2194
2195
@plain_data(frozen=True)
class GenAsmState:
    """State used while generating assembly text.

    `allocated_locs` maps each `SSAVal` to its `Loc`; `output` collects the
    generated lines, either as a list of strings or as a writable stream.
    """
    __slots__ = "allocated_locs", "output"

    def __init__(self, allocated_locs, output=None):
        # type: (Mapping[SSAVal, Loc], StringIO | list[str] | None) -> None
        """Store the allocation map and output sink.

        `output` defaults to a new empty list. Raises `ValueError` if an
        `SSAVal` and its `Loc` differ in type.
        """
        super().__init__()
        self.allocated_locs = FMap(allocated_locs)
        for ssa_val, loc in self.allocated_locs.items():
            if ssa_val.ty != loc.ty:
                raise ValueError(
                    f"Ty mismatch: ssa_val.ty:{ssa_val.ty} != loc.ty:{loc.ty}")
        self.output = [] if output is None else output

    __SSA_VAL_OR_LOCS = Union[SSAVal, Loc, Sequence["SSAVal | Loc"]]

    def loc(self, ssa_val_or_locs, expected_kinds):
        # type: (__SSA_VAL_OR_LOCS, LocKind | tuple[LocKind, ...]) -> Loc
        """Resolve one or more `SSAVal`s/`Loc`s into a single `Loc`.

        `SSAVal`s are replaced by their allocated `Loc`, then all `Loc`s are
        joined with `try_concat`. Raises `ValueError` if the sequence is
        empty, if concatenation fails, or if the resulting kind is not one
        of `expected_kinds`.
        """
        if isinstance(ssa_val_or_locs, (SSAVal, Loc)):
            ssa_val_or_locs = [ssa_val_or_locs]
        locs = [
            self.allocated_locs[item] if isinstance(item, SSAVal) else item
            for item in ssa_val_or_locs
        ]  # type: list[Loc]
        if not locs:
            raise ValueError("invalid Loc sequence: must not be empty")
        first, *rest = locs
        combined = first.try_concat(*rest)
        if combined is None:
            raise ValueError("invalid Loc sequence: try_concat failed")
        if isinstance(expected_kinds, LocKind):
            expected_kinds = (expected_kinds,)
        if combined.kind in expected_kinds:
            return combined
        # report a lone expected kind without the tuple wrapper
        if len(expected_kinds) == 1:
            expected_kinds = expected_kinds[0]
        raise ValueError(f"LocKind mismatch: {ssa_val_or_locs}: found "
                         f"{combined.kind} expected {expected_kinds}")

    def gpr(self, ssa_val_or_locs, is_vec):
        # type: (__SSA_VAL_OR_LOCS, bool) -> str
        """Return the GPR operand text: the start index, prefixed with `*`
        when `is_vec` is true.
        """
        gpr_loc = self.loc(ssa_val_or_locs, LocKind.GPR)
        return ("*" if is_vec else "") + str(gpr_loc.start)

    def sgpr(self, ssa_val_or_locs):
        # type: (__SSA_VAL_OR_LOCS) -> str
        """Return the non-vector GPR operand text (`gpr` with `is_vec=False`)."""
        return self.gpr(ssa_val_or_locs, is_vec=False)

    def vgpr(self, ssa_val_or_locs):
        # type: (__SSA_VAL_OR_LOCS) -> str
        """Return the vector GPR operand text (`gpr` with `is_vec=True`)."""
        return self.gpr(ssa_val_or_locs, is_vec=True)

    def stack(self, ssa_val_or_locs):
        # type: (__SSA_VAL_OR_LOCS) -> str
        """Return the operand text `<start>(1)` for a `StackI64` location."""
        stack_loc = self.loc(ssa_val_or_locs, LocKind.StackI64)
        return f"{stack_loc.start}(1)"

    def writeln(self, *line_segments):
        # type: (*str) -> None
        """Emit one output line made of `line_segments` joined by spaces.

        The line is appended when `output` is a list, otherwise written to
        `output` followed by a newline.
        """
        text = " ".join(line_segments)
        sink = self.output
        if isinstance(sink, list):
            sink.append(text)
            return
        sink.write(text + "\n")