TOOM-2 256x256->512-bit [un]signed*[un]signed mul works!
[bigint-presentation-code.git] / src / bigint_presentation_code / compiler_ir.py
1 from contextlib import contextmanager
2 import enum
3 from abc import ABCMeta, abstractmethod
4 from enum import Enum, unique
5 from functools import lru_cache, total_ordering
6 from io import StringIO
7 from typing import (AbstractSet, Any, Callable, Generic, Iterable, Iterator,
8 Mapping, Sequence, TypeVar, Union, overload)
9 from weakref import WeakValueDictionary as _WeakVDict
10
11 from cached_property import cached_property
12 from nmutil.plain_data import fields, plain_data
13
14 from bigint_presentation_code.type_util import (Literal, Self, assert_never,
15 final)
16 from bigint_presentation_code.util import (BitSet, FBitSet, FMap, InternedMeta,
17 OFSet, OSet)
18
# width of a general-purpose register, in bytes and in bits
BITS_IN_BYTE = 8
GPR_SIZE_IN_BYTES = 8
GPR_SIZE_IN_BITS = BITS_IN_BYTE * GPR_SIZE_IN_BYTES
# all-ones mask covering exactly one GPR's worth of bits
GPR_VALUE_MASK = 2 ** GPR_SIZE_IN_BITS - 1
23
24
@final
class Fn:
    """ a function: an ordered list of `Op`s (`self.ops`) plus the
    bookkeeping needed to give every `Op` a unique name.
    """

    def __init__(self):
        self.ops = []  # type: list[Op]
        # name -> Op; a weak-value dict, so an entry disappears once its Op
        # is garbage collected
        self.__op_names = _WeakVDict()  # type: _WeakVDict[str, Op]
        # next numeric suffix to try when a requested name is empty or
        # already taken; shared by all names in this Fn
        self.__next_name_suffix = 2

    def _add_op_with_unused_name(self, op, name=""):
        # type: (Op, str) -> str
        """ reserve a name for `op` that no other live `Op` in `self` uses.

        `name` is used as-is if it is non-empty and unused. Otherwise
        `name` + an increasing numeric suffix is tried until an unused
        name is found.

        Returns the reserved name.
        Raises ValueError if `op` belongs to a different `Fn` or already has
        a `name` attribute.
        """
        if op.fn is not self:
            raise ValueError("can't add Op to wrong Fn")
        if hasattr(op, "name"):
            raise ValueError("Op already named")
        orig_name = name
        while True:
            if name != "" and name not in self.__op_names:
                self.__op_names[name] = op
                return name
            name = orig_name + str(self.__next_name_suffix)
            self.__next_name_suffix += 1

    def __repr__(self):
        # type: () -> str
        return "<Fn>"

    def ops_to_str(self, as_python_literal=False, wrap_width=63,
                   python_indent=" ", indent=" "):
        # type: (bool, int, str, str) -> str
        """ render `self.ops`, one `Op` repr per line.

        `wrap_width` and `indent` are forwarded to each `Op.__repr__`.

        If `as_python_literal` is set, the text is instead rendered as a
        sequence of adjacent Python string literals, one per output line,
        each prefixed by `python_indent`. `"` and `\\` are backslash-escaped.
        Other non-ASCII-printable characters are written using their `repr`
        with the surrounding quotes stripped. A trailing empty literal
        (produced when the text ends in a newline) is dropped.
        """
        l = []  # type: list[str]
        for op in self.ops:
            l.append(op.__repr__(wrap_width=wrap_width, indent=indent))
        retval = "\n".join(l)
        if as_python_literal:
            l = [python_indent + "\""]
            for ch in retval:
                if ch == "\n":
                    # close the current literal and start a new one on the
                    # next line
                    l.append(f"\\n\"\n{python_indent}\"")
                elif ch in "\"\\":
                    l.append("\\" + ch)
                elif ch.isascii() and ch.isprintable():
                    l.append(ch)
                else:
                    l.append(repr(ch).strip("\"'"))
            l.append("\"")
            retval = "".join(l)
            # drop a trailing empty `""` literal line. The previous line's
            # closing quote must be kept, so the pattern starts at the
            # newline rather than at that quote.
            empty_end = f"\n{python_indent}\"\""
            if retval.endswith(empty_end):
                retval = retval[:-len(empty_end)]
        return retval

    def append_op(self, op):
        # type: (Op) -> None
        """ append `op` to `self.ops`; ValueError if `op` is for another Fn """
        if op.fn is not self:
            raise ValueError("can't add Op to wrong Fn")
        self.ops.append(op)

    def append_new_op(self, kind, input_vals=(), immediates=(), name="",
                      maxvl=1):
        # type: (OpKind, Iterable[SSAVal], Iterable[int], str, int) -> Op
        """ create an `Op` of `kind` instantiated for `maxvl`, append it to
        `self.ops` and return it.
        """
        retval = Op(fn=self, properties=kind.instantiate(maxvl=maxvl),
                    input_vals=input_vals, immediates=immediates, name=name)
        self.append_op(retval)
        return retval

    def sim(self, state):
        # type: (BaseSimState) -> None
        """ simulate every `Op` in order against `state` """
        for op in self.ops:
            op.sim(state)

    def gen_asm(self, state):
        # type: (GenAsmState) -> None
        """ generate assembly for every `Op` in order into `state` """
        for op in self.ops:
            op.gen_asm(state)

    def pre_ra_insert_copies(self):
        # type: () -> None
        """ rewrite `self.ops` in place, inserting copies around I64 values.

        For every I64 input of every `Op`, a copy is inserted right before
        the `Op`: `CopyToReg` for scalars, or `SetVLI` + `VecCopyToReg` for
        vectors. The `Op` is changed to read the copy. For every I64 output,
        a copy is inserted right after the `Op`: `CopyFromReg`, or `SetVLI` +
        `VecCopyFromReg`. Later uses are redirected to that copy.
        Presumably this lets register allocation pick locations per
        use/def independently -- TODO confirm.

        CA and VL_MAXVL values are not copied. A VL produced by a `SetVLI` is
        instead re-materialized by a fresh `SetVLI` with the same immediates
        right before each `Op` that uses it.

        Assumes every input is produced by an earlier `Op` in `self.ops`;
        otherwise the lookup in `copied_outputs` raises KeyError.
        Mutates each `Op`'s `input_vals`.
        """
        orig_ops = list(self.ops)
        # original output SSAVal -> value later uses should read instead
        copied_outputs = {}  # type: dict[SSAVal, SSAVal]
        # output of a SetVLI Op -> that SetVLI Op
        setvli_outputs = {}  # type: dict[SSAVal, Op]
        self.ops.clear()
        for op in orig_ops:
            for i in range(len(op.input_vals)):
                inp = copied_outputs[op.input_vals[i]]
                if inp.ty.base_ty is BaseTy.I64:
                    maxvl = inp.ty.reg_len
                    if inp.ty.reg_len != 1:
                        setvl = self.append_new_op(
                            OpKind.SetVLI, immediates=[maxvl],
                            name=f"{op.name}.inp{i}.setvl")
                        vl = setvl.outputs[0]
                        mv = self.append_new_op(
                            OpKind.VecCopyToReg, input_vals=[inp, vl],
                            maxvl=maxvl, name=f"{op.name}.inp{i}.copy")
                    else:
                        mv = self.append_new_op(
                            OpKind.CopyToReg, input_vals=[inp],
                            name=f"{op.name}.inp{i}.copy")
                    op.input_vals[i] = mv.outputs[0]
                elif inp.ty.base_ty is BaseTy.CA \
                        or inp.ty.base_ty is BaseTy.VL_MAXVL:
                    # all copies would be no-ops, so we don't need to copy,
                    # though we do need to rematerialize SetVLI ops right
                    # before the op that uses the VL
                    if inp in setvli_outputs:
                        setvl = self.append_new_op(
                            OpKind.SetVLI,
                            immediates=setvli_outputs[inp].immediates,
                            name=f"{op.name}.inp{i}.setvl")
                        inp = setvl.outputs[0]
                    op.input_vals[i] = inp
                else:
                    assert_never(inp.ty.base_ty)
            self.ops.append(op)
            for i, out in enumerate(op.outputs):
                if op.kind is OpKind.SetVLI:
                    setvli_outputs[out] = op
                if out.ty.base_ty is BaseTy.I64:
                    maxvl = out.ty.reg_len
                    if out.ty.reg_len != 1:
                        setvl = self.append_new_op(
                            OpKind.SetVLI, immediates=[maxvl],
                            name=f"{op.name}.out{i}.setvl")
                        vl = setvl.outputs[0]
                        mv = self.append_new_op(
                            OpKind.VecCopyFromReg, input_vals=[out, vl],
                            maxvl=maxvl, name=f"{op.name}.out{i}.copy")
                    else:
                        mv = self.append_new_op(
                            OpKind.CopyFromReg, input_vals=[out],
                            name=f"{op.name}.out{i}.copy")
                    copied_outputs[out] = mv.outputs[0]
                elif out.ty.base_ty is BaseTy.CA \
                        or out.ty.base_ty is BaseTy.VL_MAXVL:
                    # all copies would be no-ops, so we don't need to copy
                    copied_outputs[out] = out
                else:
                    assert_never(out.ty.base_ty)
162
163
@final
@unique
@total_ordering
class OpStage(Enum):
    """ the two stages of an `Op`'s execution, ordered `Early < Late` """
    value: Literal[0, 1]  # type: ignore

    def __new__(cls, value):
        # type: (int) -> OpStage
        as_int = int(value)
        if as_int not in (0, 1):
            raise ValueError("invalid value")
        member = object.__new__(cls)
        member._value_ = as_int
        return member

    Early = 0
    """ early stage of Op execution, where all input reads occur.
    all output writes with `write_stage == Early` occur here too, and therefore
    conflict with input reads, telling the compiler that it can't share
    that output's register with any inputs that the output isn't tied to.

    All outputs, even unused outputs, can't share registers with any other
    outputs, independent of `write_stage` settings.
    """
    Late = 1
    """ late stage of Op execution, where all output writes with
    `write_stage == Late` occur, and therefore don't conflict with input reads,
    telling the compiler that any inputs can safely use the same register as
    those outputs.

    All outputs, even unused outputs, can't share registers with any other
    outputs, independent of `write_stage` settings.
    """

    def __repr__(self):
        # type: () -> str
        return "OpStage." + self._name_

    def __lt__(self, other):
        # type: (OpStage | object) -> bool
        # the remaining comparisons are filled in by `total_ordering`
        if not isinstance(other, OpStage):
            return NotImplemented
        return self.value < other.value
207
208
# import-time sanity check: ProgramPoint ordering (and the `max` used by the
# liveness analysis) relies on the early stage sorting before the late stage
assert OpStage.Late > OpStage.Early, "early must be less than late"
210
211
@plain_data(frozen=True, unsafe_hash=True, repr=False)
@final
@total_ordering
class ProgramPoint(metaclass=InternedMeta):
    """ a point in a `Fn`: one stage (`OpStage`) of the `Op` at `op_index` """
    __slots__ = "op_index", "stage"

    def __init__(self, op_index, stage):
        # type: (int, OpStage) -> None
        self.op_index = op_index
        self.stage = stage

    @property
    def int_value(self):
        # type: () -> int
        """ an integer encoding of `self` (two stages per Op) that preserves
        ordering as well as successor/predecessor relations.
        """
        return 2 * self.op_index + self.stage.value

    @staticmethod
    def from_int_value(int_value):
        # type: (int) -> ProgramPoint
        """ inverse of `int_value` """
        index, stage_value = divmod(int_value, 2)
        return ProgramPoint(op_index=index, stage=OpStage(stage_value))

    def next(self, steps=1):
        # type: (int) -> ProgramPoint
        """ the point `steps` stages after `self` """
        return ProgramPoint.from_int_value(self.int_value + steps)

    def prev(self, steps=1):
        # type: (int) -> ProgramPoint
        """ the point `steps` stages before `self` """
        return ProgramPoint.from_int_value(self.int_value - steps)

    def __lt__(self, other):
        # type: (ProgramPoint | Any) -> bool
        # lexicographic: Op index first, then stage within the same Op
        if isinstance(other, ProgramPoint):
            return (self.op_index, self.stage) < (other.op_index, other.stage)
        return NotImplemented

    def __repr__(self):
        # type: () -> str
        stage_name = self.stage._name_
        return f"<ops[{self.op_index}]:{stage_name}>"
256
257
@plain_data(frozen=True, unsafe_hash=True, repr=False)
@final
class ProgramRange(Sequence[ProgramPoint], metaclass=InternedMeta):
    """ the half-open range of `ProgramPoint`s `[start, stop)`, usable as a
    `Sequence` of `ProgramPoint`s.
    """
    __slots__ = "start", "stop"

    def __init__(self, start, stop):
        # type: (ProgramPoint, ProgramPoint) -> None
        self.start = start
        self.stop = stop

    @cached_property
    def int_value_range(self):
        # type: () -> range
        """ `self` as a `range` of `ProgramPoint.int_value`s """
        return range(self.start.int_value, self.stop.int_value)

    @staticmethod
    def from_int_value_range(int_value_range):
        # type: (range) -> ProgramRange
        """ inverse of `int_value_range`; only unit-step ranges are allowed """
        if int_value_range.step != 1:
            raise ValueError("int_value_range must have step == 1")
        first = ProgramPoint.from_int_value(int_value_range.start)
        past_end = ProgramPoint.from_int_value(int_value_range.stop)
        return ProgramRange(start=first, stop=past_end)

    @overload
    def __getitem__(self, __idx):
        # type: (int) -> ProgramPoint
        ...

    @overload
    def __getitem__(self, __idx):
        # type: (slice) -> ProgramRange
        ...

    def __getitem__(self, __idx):
        # type: (int | slice) -> ProgramPoint | ProgramRange
        # delegate indexing/slicing semantics to the underlying int range
        selected = self.int_value_range[__idx]
        if isinstance(selected, range):
            return ProgramRange.from_int_value_range(selected)
        return ProgramPoint.from_int_value(selected)

    def __len__(self):
        # type: () -> int
        return len(self.int_value_range)

    def __iter__(self):
        # type: () -> Iterator[ProgramPoint]
        return (ProgramPoint.from_int_value(i) for i in self.int_value_range)

    def __repr__(self):
        # type: () -> str
        def unbracket(point):
            # type: (ProgramPoint) -> str
            return repr(point).lstrip("<").rstrip(">")
        return f"<range:{unbracket(self.start)}..{unbracket(self.stop)}>"
312
313
@plain_data(frozen=True, eq=False, repr=False)
@final
class FnAnalysis:
    """ use and liveness analysis of a `Fn`, computed once in `__init__`.

    Attributes (mappings stored as `FMap`, sets as `OFSet`):
    * `fn`: the analyzed `Fn`.
    * `op_indexes`: each `Op` -> its index in `fn.ops`.
    * `all_program_points`: every `ProgramPoint` from the early stage of
      `fn.ops[0]` up to, but excluding, the early stage of op index
      `len(fn.ops)`.
    * `def_program_ranges`: each `SSAVal` -> the range from its defining
      descriptor's `write_stage` through the late stage of its defining Op.
    * `use_program_points`: each `SSAUse` -> the early stage of the using Op.
    * `uses`: each `SSAVal` -> the ordered set of its `SSAUse`s.
    * `live_ranges`: each `SSAVal` -> from the start of its def range to
      the later of the def range's stop and the late stage of its last
      using Op (stop is exclusive, so the using Op's early stage, where
      inputs are read, is included).
    * `live_at`: each program point in `all_program_points` -> the set of
      `SSAVal`s live there.

    Assumes `fn.ops` is in SSA order: every input must be defined by an
    earlier Op, otherwise `__init__` raises KeyError.
    """
    __slots__ = ("fn", "uses", "op_indexes", "live_ranges", "live_at",
                 "def_program_ranges", "use_program_points",
                 "all_program_points")

    def __init__(self, fn):
        # type: (Fn) -> None
        self.fn = fn
        self.op_indexes = FMap((op, idx) for idx, op in enumerate(fn.ops))
        self.all_program_points = ProgramRange(
            start=ProgramPoint(op_index=0, stage=OpStage.Early),
            stop=ProgramPoint(op_index=len(fn.ops), stage=OpStage.Early))
        def_program_ranges = {}  # type: dict[SSAVal, ProgramRange]
        use_program_points = {}  # type: dict[SSAUse, ProgramPoint]
        uses = {}  # type: dict[SSAVal, OSet[SSAUse]]
        # exclusive end of each SSAVal's live range, extended as uses are seen
        live_range_stops = {}  # type: dict[SSAVal, ProgramPoint]
        for op in fn.ops:
            # inputs are handled before outputs, so an Op can't use its own
            # outputs and every input must already have an entry in `uses`
            for use in op.input_uses:
                uses[use.ssa_val].add(use)
                use_program_point = self.__get_use_program_point(use)
                use_program_points[use] = use_program_point
                # stay live through the read (stop is exclusive)
                live_range_stops[use.ssa_val] = max(
                    live_range_stops[use.ssa_val], use_program_point.next())
            for out in op.outputs:
                uses[out] = OSet()
                def_program_range = self.__get_def_program_range(out)
                def_program_ranges[out] = def_program_range
                # unused values are still live for their whole def range
                live_range_stops[out] = def_program_range.stop
        self.uses = FMap((k, OFSet(v)) for k, v in uses.items())
        self.def_program_ranges = FMap(def_program_ranges)
        self.use_program_points = FMap(use_program_points)
        live_ranges = {}  # type: dict[SSAVal, ProgramRange]
        live_at = {i: OSet[SSAVal]() for i in self.all_program_points}
        for ssa_val in uses.keys():
            live_ranges[ssa_val] = live_range = ProgramRange(
                start=self.def_program_ranges[ssa_val].start,
                stop=live_range_stops[ssa_val])
            for program_point in live_range:
                live_at[program_point].add(ssa_val)
        self.live_ranges = FMap(live_ranges)
        self.live_at = FMap((k, OFSet(v)) for k, v in live_at.items())

    def __get_def_program_range(self, ssa_val):
        # type: (SSAVal) -> ProgramRange
        """ range from `ssa_val`'s write stage through the late stage of its
        defining Op (stop is exclusive).
        """
        write_stage = ssa_val.defining_descriptor.write_stage
        start = ProgramPoint(
            op_index=self.op_indexes[ssa_val.op], stage=write_stage)
        # always include late stage of ssa_val.op, to ensure outputs always
        # overlap all other outputs.
        # stop is exclusive, so we need the next program point.
        stop = ProgramPoint(op_index=start.op_index, stage=OpStage.Late).next()
        return ProgramRange(start=start, stop=stop)

    def __get_use_program_point(self, ssa_use):
        # type: (SSAUse) -> ProgramPoint
        """ the early stage of the using Op, since all inputs are read then """
        assert ssa_use.defining_descriptor.write_stage is OpStage.Early, \
            "assumed here, ensured by GenericOpProperties.__init__"
        return ProgramPoint(
            op_index=self.op_indexes[ssa_use.op], stage=OpStage.Early)

    def __eq__(self, other):
        # type: (FnAnalysis | Any) -> bool
        # two analyses are equal when they analyze equal `Fn`s
        if isinstance(other, FnAnalysis):
            return self.fn == other.fn
        return NotImplemented

    def __hash__(self):
        # type: () -> int
        return hash(self.fn)

    def __repr__(self):
        # type: () -> str
        return "<FnAnalysis>"
389
390
@unique
@final
class BaseTy(Enum):
    """ the kind of value stored in each register-sized slot """
    I64 = enum.auto()
    CA = enum.auto()
    VL_MAXVL = enum.auto()

    @cached_property
    def only_scalar(self):
        # type: () -> bool
        """ True if values of this base type can't be vectors """
        if self is BaseTy.CA or self is BaseTy.VL_MAXVL:
            return True
        elif self is BaseTy.I64:
            return False
        else:
            assert_never(self)

    @cached_property
    def max_reg_len(self):
        # type: () -> int
        """ the largest allowed `Ty.reg_len` for this base type """
        if self is BaseTy.CA or self is BaseTy.VL_MAXVL:
            return 1
        elif self is BaseTy.I64:
            return 128
        else:
            assert_never(self)

    def __repr__(self):
        return f"BaseTy.{self._name_}"
420
421
@plain_data(frozen=True, unsafe_hash=True, repr=False)
@final
class Ty(metaclass=InternedMeta):
    """ a concrete type: `reg_len` consecutive slots of `base_ty` """
    __slots__ = "base_ty", "reg_len"

    @staticmethod
    def validate(base_ty, reg_len):
        # type: (BaseTy, int) -> str | None
        """ describe why `(base_ty, reg_len)` is invalid, or return None if
        the combination is valid.
        """
        if base_ty.only_scalar and reg_len != 1:
            return f"can't create a vector of an only-scalar type: {base_ty}"
        if not 1 <= reg_len <= base_ty.max_reg_len:
            return "reg_len out of range"
        return None

    def __init__(self, base_ty, reg_len):
        # type: (BaseTy, int) -> None
        error = self.validate(base_ty=base_ty, reg_len=reg_len)
        if error is not None:
            raise ValueError(error)
        self.base_ty = base_ty
        self.reg_len = reg_len

    def __repr__(self):
        # type: () -> str
        # scalars print as `<I64>`, vectors as `<I64*4>`
        suffix = "" if self.reg_len == 1 else f"*{self.reg_len}"
        return f"<{self.base_ty._name_}{suffix}>"
454
455
@unique
@final
class LocKind(Enum):
    """ a family of locations that values can be allocated to """
    GPR = enum.auto()
    StackI64 = enum.auto()
    CA = enum.auto()
    VL_MAXVL = enum.auto()

    @cached_property
    def base_ty(self):
        # type: () -> BaseTy
        """ the `BaseTy` of values stored in this kind of location """
        if self is LocKind.VL_MAXVL:
            return BaseTy.VL_MAXVL
        elif self is LocKind.CA:
            return BaseTy.CA
        elif self is LocKind.GPR or self is LocKind.StackI64:
            return BaseTy.I64
        else:
            assert_never(self)

    @cached_property
    def loc_count(self):
        # type: () -> int
        """ the number of slots of this kind """
        if self is LocKind.StackI64:
            return 512
        elif self is LocKind.GPR or self is LocKind.CA \
                or self is LocKind.VL_MAXVL:
            return self.base_ty.max_reg_len
        else:
            assert_never(self)

    def __repr__(self):
        return f"LocKind.{self._name_}"
489
490
@final
@unique
class LocSubKind(Enum):
    """ a restriction of a `LocKind` to the subset of its locations that a
    particular operand encoding can address.
    """
    BASE_GPR = enum.auto()
    SV_EXTRA2_VGPR = enum.auto()
    SV_EXTRA2_SGPR = enum.auto()
    SV_EXTRA3_VGPR = enum.auto()
    SV_EXTRA3_SGPR = enum.auto()
    StackI64 = enum.auto()
    CA = enum.auto()
    VL_MAXVL = enum.auto()

    @cached_property
    def kind(self):
        # type: () -> LocKind
        """ the `LocKind` all locations of this sub-kind belong to """
        if self is LocSubKind.StackI64:
            return LocKind.StackI64
        if self is LocSubKind.CA:
            return LocKind.CA
        if self is LocSubKind.VL_MAXVL:
            return LocKind.VL_MAXVL
        # pyright fails typechecking when using `in` here:
        # reported: https://github.com/microsoft/pyright/issues/4102
        if self in (LocSubKind.BASE_GPR, LocSubKind.SV_EXTRA2_VGPR,
                    LocSubKind.SV_EXTRA2_SGPR, LocSubKind.SV_EXTRA3_VGPR,
                    LocSubKind.SV_EXTRA3_SGPR):
            return LocKind.GPR
        assert_never(self)

    @property
    def base_ty(self):
        """ the `BaseTy` of `self.kind` """
        return self.kind.base_ty

    @lru_cache()
    def allocatable_locs(self, ty):
        # type: (Ty) -> LocSet
        """ every `Loc` of type `ty` this sub-kind can address, excluding
        any that overlap `SPECIAL_GPRS`.

        Raises ValueError if `ty.base_ty` doesn't match `self.base_ty`.
        """
        if ty.base_ty != self.base_ty:
            raise ValueError("type mismatch")
        if self is LocSubKind.CA or self is LocSubKind.VL_MAXVL:
            # only one location exists for these
            return LocSet([Loc(kind=self.kind, start=0, reg_len=1)])
        if self is LocSubKind.BASE_GPR:
            starts = range(32)
        elif self is LocSubKind.SV_EXTRA2_VGPR:
            starts = range(0, 128, 2)
        elif self is LocSubKind.SV_EXTRA2_SGPR:
            starts = range(64)
        elif self is LocSubKind.SV_EXTRA3_VGPR \
                or self is LocSubKind.SV_EXTRA3_SGPR:
            starts = range(128)
        elif self is LocSubKind.StackI64:
            starts = range(LocKind.StackI64.loc_count)
        else:
            assert_never(self)

        def is_usable(candidate):
            # type: (Loc | None) -> bool
            if candidate is None:
                return False  # start + reg_len runs past the end
            return not any(candidate.conflicts(special)
                           for special in SPECIAL_GPRS)

        candidates = (
            Loc.try_make(kind=self.kind, start=start, reg_len=ty.reg_len)
            for start in starts)
        return LocSet([loc for loc in candidates if is_usable(loc)])

    def __repr__(self):
        return f"LocSubKind.{self._name_}"
560
561
@plain_data(frozen=True, unsafe_hash=True)
@final
class GenericTy(metaclass=InternedMeta):
    """ a type whose vector length, if any, isn't known until instantiation
    with a concrete `maxvl`.
    """
    __slots__ = "base_ty", "is_vec"

    def __init__(self, base_ty, is_vec):
        # type: (BaseTy, bool) -> None
        self.base_ty = base_ty
        if base_ty.only_scalar and is_vec:
            raise ValueError(f"base_ty={base_ty} requires is_vec=False")
        self.is_vec = is_vec

    def instantiate(self, maxvl):
        # type: (int) -> Ty
        """ the concrete `Ty` for the given `maxvl` """
        # here's where subvl and elwid would be accounted for
        reg_len = maxvl if self.is_vec else 1
        return Ty(self.base_ty, reg_len)

    def can_instantiate_to(self, ty):
        # type: (Ty) -> bool
        """ True if some `maxvl` makes `self.instantiate(maxvl)` produce a
        type with `ty`'s base type and length.
        """
        if self.base_ty == ty.base_ty:
            return True if self.is_vec else ty.reg_len == 1
        return False
588
589
@plain_data(frozen=True, unsafe_hash=True)
@final
class Loc(metaclass=InternedMeta):
    """ a concrete location: `reg_len` consecutive slots of `kind`, starting
    at slot `start`.
    """
    __slots__ = "kind", "start", "reg_len"

    @staticmethod
    def validate(kind, start, reg_len):
        # type: (LocKind, int, int) -> str | None
        """ describe why the combination is invalid, or return None if it's
        valid.
        """
        msg = Ty.validate(base_ty=kind.base_ty, reg_len=reg_len)
        if msg is not None:
            return msg
        if reg_len > kind.loc_count:
            return "invalid reg_len"
        if start < 0 or start + reg_len > kind.loc_count:
            return "start not in valid range"
        return None

    @staticmethod
    def try_make(kind, start, reg_len):
        # type: (LocKind, int, int) -> Loc | None
        """ like the constructor, but returns None instead of raising
        ValueError for invalid combinations.
        """
        msg = Loc.validate(kind=kind, start=start, reg_len=reg_len)
        if msg is not None:
            return None
        return Loc(kind=kind, start=start, reg_len=reg_len)

    def __init__(self, kind, start, reg_len):
        # type: (LocKind, int, int) -> None
        msg = self.validate(kind=kind, start=start, reg_len=reg_len)
        if msg is not None:
            raise ValueError(msg)
        self.kind = kind
        self.reg_len = reg_len
        self.start = start

    def conflicts(self, other):
        # type: (Loc) -> bool
        """ True if `self` and `other` share at least one slot """
        return (self.kind == other.kind
                and self.start < other.stop and other.start < self.stop)

    @staticmethod
    def make_ty(kind, reg_len):
        # type: (LocKind, int) -> Ty
        """ the `Ty` of a `Loc` with the given `kind` and `reg_len` """
        return Ty(base_ty=kind.base_ty, reg_len=reg_len)

    @cached_property
    def ty(self):
        # type: () -> Ty
        return self.make_ty(kind=self.kind, reg_len=self.reg_len)

    @property
    def stop(self):
        # type: () -> int
        """ one past the last slot of `self` """
        return self.start + self.reg_len

    def try_concat(self, *others):
        # type: (*Loc | None) -> Loc | None
        """ concatenate `self` followed by `others` into one `Loc`.

        Returns None if any of `others` is None, has a different `kind`, or
        doesn't start right where the previous `Loc` stops. Also returns None
        if the combined `Loc` would be invalid, e.g. its `reg_len` exceeds
        the maximum for its type.
        """
        reg_len = self.reg_len
        stop = self.stop
        for other in others:
            if other is None or other.kind != self.kind:
                return None
            if stop != other.start:
                return None
            stop = other.stop
            reg_len += other.reg_len
        # `try_make` rather than the constructor: e.g. contiguous StackI64
        # locs (512 slots) can add up past I64's max reg_len (128). That
        # must report failure via None rather than raising ValueError.
        return Loc.try_make(kind=self.kind, start=self.start, reg_len=reg_len)

    def get_subloc_at_offset(self, subloc_ty, offset):
        # type: (Ty, int) -> Loc
        """ the `Loc` of type `subloc_ty` starting `offset` slots into `self`.

        Raises ValueError on a base type mismatch, or if the sub-`Loc`
        doesn't fit inside `self`.
        """
        if subloc_ty.base_ty != self.kind.base_ty:
            raise ValueError("BaseTy mismatch")
        if offset < 0 or offset + subloc_ty.reg_len > self.reg_len:
            raise ValueError("invalid sub-Loc: offset and/or "
                             "subloc_ty.reg_len out of range")
        return Loc(kind=self.kind,
                   start=self.start + offset, reg_len=subloc_ty.reg_len)
666
667
# single GPRs that `LocSubKind.allocatable_locs` never hands out. Presumably
# these are reserved by the 64-bit PowerPC ABI (r1 stack pointer, r2 TOC
# pointer, r13 thread pointer) plus r0 -- TODO confirm.
SPECIAL_GPRS = tuple(
    Loc(kind=LocKind.GPR, start=reg_num, reg_len=1)
    for reg_num in (0, 1, 2, 13))
674
675
@final
class LocSet(OFSet[Loc], metaclass=InternedMeta):
    """ an ordered set (`OFSet`) of `Loc`s that all have the same `Ty`.

    Besides the `Loc`s themselves, this tracks a bit-set per `LocKind` of
    the contained `Loc`s' `start`s, which makes `concat` cheap. An empty
    `LocSet` has `ty`, `reg_len` and `base_ty` all None.
    """

    def __init__(self, __locs=()):
        # type: (Iterable[Loc]) -> None
        super().__init__(__locs)
        if isinstance(__locs, LocSet):
            # copying another LocSet: reuse its precomputed data
            self.__starts = __locs.starts
            self.__ty = __locs.ty
            return
        starts = {i: BitSet() for i in LocKind}
        ty = None  # type: None | Ty
        for loc in self:
            if ty is None:
                ty = loc.ty
            if ty != loc.ty:
                raise ValueError(f"conflicting types: {ty} != {loc.ty}")
            starts[loc.kind].add(loc.start)
        self.__starts = FMap(
            (k, FBitSet(v)) for k, v in starts.items() if len(v) != 0)
        self.__ty = ty

    @property
    def starts(self):
        # type: () -> FMap[LocKind, FBitSet]
        """ per `LocKind`, the `start`s of the contained `Loc`s; kinds with
        no `Loc`s are omitted.
        """
        return self.__starts

    @property
    def ty(self):
        # type: () -> Ty | None
        return self.__ty

    @cached_property
    def stops(self):
        # type: () -> FMap[LocKind, FBitSet]
        """ per `LocKind`, the `stop`s of the contained `Loc`s """
        if self.ty is None:
            return FMap()
        sh = self.ty.reg_len
        return FMap(
            (k, FBitSet(bits=v.bits << sh)) for k, v in self.starts.items())

    @property
    def kinds(self):
        # type: () -> AbstractSet[LocKind]
        return self.starts.keys()

    @property
    def reg_len(self):
        # type: () -> int | None
        if self.ty is None:
            return None
        return self.ty.reg_len

    @property
    def base_ty(self):
        # type: () -> BaseTy | None
        if self.ty is None:
            return None
        return self.ty.base_ty

    def concat(self, *others):
        # type: (*LocSet) -> LocSet
        """ every valid `Loc` made by concatenating one `Loc` from `self`
        with one `Loc` from each of `others`, in order.

        The parts must all have the same `kind`, and each must start where
        the previous one stops. Returns an empty `LocSet` if any input is
        empty, the base types differ, or no such combination exists.
        """
        if self.ty is None:
            return LocSet()
        base_ty = self.ty.base_ty
        reg_len = self.ty.reg_len
        starts = {k: BitSet(v) for k, v in self.starts.items()}
        for other in others:
            if other.ty is None:
                return LocSet()
            if other.ty.base_ty != base_ty:
                return LocSet()
            # iterate over a snapshot since kinds are deleted while looping
            for kind in list(starts):
                if kind not in other.starts:
                    # `other` has no Locs of this kind, so nothing of this
                    # kind can be extended by it
                    del starts[kind]
                    continue
                # keep start `s` only if `other` has a Loc starting at
                # `s + reg_len`, i.e. right where the concatenation so far
                # stops
                starts[kind].bits &= other.starts[kind].bits >> reg_len
                if starts[kind].bits == 0:
                    del starts[kind]
            if len(starts) == 0:
                return LocSet()
            reg_len += other.ty.reg_len

        def locs():
            # type: () -> Iterable[Loc]
            for kind, v in starts.items():
                for start in v:
                    loc = Loc.try_make(kind=kind, start=start, reg_len=reg_len)
                    if loc is not None:
                        yield loc
        return LocSet(locs())

    @lru_cache(maxsize=None, typed=True)
    def max_conflicts_with(self, other):
        # type: (LocSet | Loc) -> int
        """the largest number of Locs in `self` that a single Loc
        from `other` can conflict with; 0 if `other` is an empty LocSet.
        """
        if isinstance(other, LocSet):
            return max((self.max_conflicts_with(i) for i in other), default=0)
        else:
            return sum(other.conflicts(i) for i in self)

    def __repr__(self):
        return f"LocSet(starts={self.starts!r}, ty={self.ty!r})"
779
780
@plain_data(frozen=True, unsafe_hash=True)
@final
class GenericOperandDesc(metaclass=InternedMeta):
    """generic Op operand descriptor: describes an operand independently of
    `maxvl`; `instantiate` turns it into concrete `OperandDesc`s.

    Fields:
    * `ty`: the operand's `GenericTy`.
    * `sub_kinds`: the `LocSubKind`s the operand may be allocated to
      (non-empty).
    * `fixed_loc`: if not None, the only `Loc` the operand may use.
    * `tied_input_index`: if not None, the index of the input this operand
      is tied to.
    * `spread`: if set, a scalar operand is repeated `maxvl` times, one
      `OperandDesc` per element of an unspread vector.
    * `write_stage`: the `OpStage` at which the operand is written.
    """
    __slots__ = ("ty", "fixed_loc", "sub_kinds", "tied_input_index", "spread",
                 "write_stage")

    def __init__(
        self, ty,  # type: GenericTy
        sub_kinds,  # type: Iterable[LocSubKind]
        *,
        fixed_loc=None,  # type: Loc | None
        tied_input_index=None,  # type: int | None
        spread=False,  # type: bool
        write_stage=OpStage.Early,  # type: OpStage
    ):
        # type: (...) -> None
        # Raises ValueError on inconsistent combinations; each check below
        # corresponds to one error message.
        self.ty = ty
        self.sub_kinds = OFSet(sub_kinds)
        if len(self.sub_kinds) == 0:
            raise ValueError("sub_kinds can't be empty")
        self.fixed_loc = fixed_loc
        if fixed_loc is not None:
            if tied_input_index is not None:
                raise ValueError("operand can't be both tied and fixed")
            if not ty.can_instantiate_to(fixed_loc.ty):
                raise ValueError(
                    f"fixed_loc has incompatible type for given generic "
                    f"type: fixed_loc={fixed_loc} generic ty={ty}")
            if len(self.sub_kinds) != 1:
                raise ValueError(
                    "multiple sub_kinds not allowed for fixed operand")
            # the fixed Loc must be one the (single) sub-kind can address
            for sub_kind in self.sub_kinds:
                if fixed_loc not in sub_kind.allocatable_locs(fixed_loc.ty):
                    raise ValueError(
                        f"fixed_loc not in given sub_kind: "
                        f"fixed_loc={fixed_loc} sub_kind={sub_kind}")
        for sub_kind in self.sub_kinds:
            if sub_kind.base_ty != ty.base_ty:
                raise ValueError(f"sub_kind is incompatible with type: "
                                 f"sub_kind={sub_kind} ty={ty}")
        if tied_input_index is not None and tied_input_index < 0:
            raise ValueError("invalid tied_input_index")
        self.tied_input_index = tied_input_index
        self.spread = spread
        if spread:
            if self.tied_input_index is not None:
                raise ValueError("operand can't be both spread and tied")
            if self.fixed_loc is not None:
                raise ValueError("operand can't be both spread and fixed")
            if self.ty.is_vec:
                raise ValueError("operand can't be both spread and vector")
        self.write_stage = write_stage

    @cached_property
    def ty_before_spread(self):
        # type: () -> GenericTy
        """ the vector type a spread operand is an element of; `self.ty` for
        non-spread operands.
        """
        if self.spread:
            return GenericTy(base_ty=self.ty.base_ty, is_vec=True)
        return self.ty

    def tied_to_input(self, tied_input_index):
        # type: (int) -> Self
        """ copy of `self` tied to input `tied_input_index`; `fixed_loc` and
        `spread` are not carried over.
        """
        return GenericOperandDesc(self.ty, self.sub_kinds,
                                  tied_input_index=tied_input_index,
                                  write_stage=self.write_stage)

    def with_fixed_loc(self, fixed_loc):
        # type: (Loc) -> Self
        """ copy of `self` fixed to `fixed_loc`; `tied_input_index` and
        `spread` are not carried over.
        """
        return GenericOperandDesc(self.ty, self.sub_kinds, fixed_loc=fixed_loc,
                                  write_stage=self.write_stage)

    def with_write_stage(self, write_stage):
        # type: (OpStage) -> Self
        """ copy of `self` with only `write_stage` replaced """
        return GenericOperandDesc(self.ty, self.sub_kinds,
                                  fixed_loc=self.fixed_loc,
                                  tied_input_index=self.tied_input_index,
                                  spread=self.spread,
                                  write_stage=write_stage)

    def instantiate(self, maxvl):
        # type: (int) -> Iterable[OperandDesc]
        """ generator of the concrete `OperandDesc`s for `maxvl`.

        Yields one descriptor, or `maxvl` descriptors (with
        `spread_index` 0..maxvl-1) for spread operands. All of them share
        the same `loc_set_before_spread`. Raises ValueError (when iterated)
        if `fixed_loc`'s type doesn't match the instantiated type.
        """
        # assumes all spread operands have ty.reg_len = 1
        rep_count = 1
        if self.spread:
            rep_count = maxvl
        ty_before_spread = self.ty_before_spread.instantiate(maxvl=maxvl)

        def locs_before_spread():
            # type: () -> Iterable[Loc]
            if self.fixed_loc is not None:
                if ty_before_spread != self.fixed_loc.ty:
                    raise ValueError(
                        f"instantiation failed: type mismatch with fixed_loc: "
                        f"instantiated type: {ty_before_spread} "
                        f"fixed_loc: {self.fixed_loc}")
                yield self.fixed_loc
                return
            for sub_kind in self.sub_kinds:
                yield from sub_kind.allocatable_locs(ty_before_spread)
        loc_set_before_spread = LocSet(locs_before_spread())
        for idx in range(rep_count):
            # non-spread operands have no spread index
            if not self.spread:
                idx = None
            yield OperandDesc(loc_set_before_spread=loc_set_before_spread,
                              tied_input_index=self.tied_input_index,
                              spread_index=idx, write_stage=self.write_stage)
888
889
@plain_data(frozen=True, unsafe_hash=True)
@final
class OperandDesc(metaclass=InternedMeta):
    """Op operand descriptor, concrete for a particular `maxvl`"""
    __slots__ = ("loc_set_before_spread", "tied_input_index", "spread_index",
                 "write_stage")

    def __init__(self, loc_set_before_spread, tied_input_index, spread_index,
                 write_stage):
        # type: (LocSet, int | None, int | None, OpStage) -> None
        if len(loc_set_before_spread) == 0:
            raise ValueError("loc_set_before_spread must not be empty")
        if tied_input_index is not None and spread_index is not None:
            raise ValueError("operand can't be both spread and tied")
        self.loc_set_before_spread = loc_set_before_spread
        self.tied_input_index = tied_input_index
        self.spread_index = spread_index
        self.write_stage = write_stage

    @cached_property
    def ty_before_spread(self):
        # type: () -> Ty
        """ the type of the whole (unspread) operand """
        ty = self.loc_set_before_spread.ty
        assert ty is not None, (
            "__init__ checked that the LocSet isn't empty, "
            "non-empty LocSets should always have ty set")
        return ty

    @cached_property
    def ty(self):
        """ Ty after any spread is applied """
        unspread = self.ty_before_spread
        if self.spread_index is None:
            return unspread
        # assumes all spread operands have ty.reg_len = 1
        return Ty(base_ty=unspread.base_ty, reg_len=1)

    @property
    def reg_offset_in_unspread(self):
        """ the number of reg-sized slots in the unspread Loc before self's Loc

        e.g. if the unspread Loc containing self is:
        `Loc(kind=LocKind.GPR, start=8, reg_len=4)`
        and self's Loc is `Loc(kind=LocKind.GPR, start=10, reg_len=1)`
        then reg_offset_in_unspread == 2 == 10 - 8
        """
        return (0 if self.spread_index is None
                else self.spread_index * self.ty.reg_len)
938
939
# Commonly-used generic operand descriptors. Each one restricts an operand to
# the locations allowed by a single `LocSubKind`.
# scalar I64 operand in a `LocSubKind.BASE_GPR` location
OD_BASE_SGPR = GenericOperandDesc(
    ty=GenericTy(base_ty=BaseTy.I64, is_vec=False),
    sub_kinds=[LocSubKind.BASE_GPR])
# scalar I64 operand in a `LocSubKind.SV_EXTRA3_SGPR` location
OD_EXTRA3_SGPR = GenericOperandDesc(
    ty=GenericTy(base_ty=BaseTy.I64, is_vec=False),
    sub_kinds=[LocSubKind.SV_EXTRA3_SGPR])
# vector I64 operand in a `LocSubKind.SV_EXTRA3_VGPR` location
OD_EXTRA3_VGPR = GenericOperandDesc(
    ty=GenericTy(base_ty=BaseTy.I64, is_vec=True),
    sub_kinds=[LocSubKind.SV_EXTRA3_VGPR])
# scalar I64 operand in a `LocSubKind.SV_EXTRA2_SGPR` location
OD_EXTRA2_SGPR = GenericOperandDesc(
    ty=GenericTy(base_ty=BaseTy.I64, is_vec=False),
    sub_kinds=[LocSubKind.SV_EXTRA2_SGPR])
# vector I64 operand in a `LocSubKind.SV_EXTRA2_VGPR` location
OD_EXTRA2_VGPR = GenericOperandDesc(
    ty=GenericTy(base_ty=BaseTy.I64, is_vec=True),
    sub_kinds=[LocSubKind.SV_EXTRA2_VGPR])
# carry (CA) operand
OD_CA = GenericOperandDesc(
    ty=GenericTy(base_ty=BaseTy.CA, is_vec=False),
    sub_kinds=[LocSubKind.CA])
# VL/MAXVL operand
OD_VL = GenericOperandDesc(
    ty=GenericTy(base_ty=BaseTy.VL_MAXVL, is_vec=False),
    sub_kinds=[LocSubKind.VL_MAXVL])
961
962
@plain_data(frozen=True, unsafe_hash=True)
@final
class GenericOpProperties(metaclass=InternedMeta):
    """ the `maxvl`-independent description of an `OpKind`: its example
    assembly, operand descriptors, immediate ranges and flags.

    `__init__` raises ValueError in each of these cases:
    * an input is tied or has a `write_stage` other than `OpStage.Early`;
    * a tied output's index is out of range, or the output isn't equivalent
      to its input re-tied (allowing only a different `write_stage`);
    * two outputs have conflicting `fixed_loc`s.
    """
    __slots__ = ("demo_asm", "inputs", "outputs", "immediates",
                 "is_copy", "is_load_immediate", "has_side_effects")

    def __init__(
        self, demo_asm,  # type: str
        inputs,  # type: Iterable[GenericOperandDesc]
        outputs,  # type: Iterable[GenericOperandDesc]
        immediates=(),  # type: Iterable[range]
        is_copy=False,  # type: bool
        is_load_immediate=False,  # type: bool
        has_side_effects=False,  # type: bool
    ):
        # type: (...) -> None
        self.demo_asm = demo_asm  # type: str
        self.inputs = tuple(inputs)  # type: tuple[GenericOperandDesc, ...]
        for inp in self.inputs:
            if inp.tied_input_index is not None:
                raise ValueError(
                    f"tied_input_index is not allowed on inputs: {inp}")
            # FnAnalysis relies on every input being read at the early stage
            if inp.write_stage is not OpStage.Early:
                raise ValueError(
                    f"write_stage is not allowed on inputs: {inp}")
        self.outputs = tuple(outputs)  # type: tuple[GenericOperandDesc, ...]
        # fixed_locs of the outputs checked so far, with their output index
        fixed_locs = []  # type: list[tuple[Loc, int]]
        for idx, out in enumerate(self.outputs):
            if out.tied_input_index is not None:
                if out.tied_input_index >= len(self.inputs):
                    raise ValueError(f"tied_input_index out of range: {out}")
                tied_inp = self.inputs[out.tied_input_index]
                expected_out = tied_inp.tied_to_input(out.tied_input_index) \
                    .with_write_stage(out.write_stage)
                if expected_out != out:
                    raise ValueError(f"output can't be tied to non-equivalent "
                                     f"input: {out} tied to {tied_inp}")
            if out.fixed_loc is not None:
                for other_fixed_loc, other_idx in fixed_locs:
                    if not other_fixed_loc.conflicts(out.fixed_loc):
                        continue
                    raise ValueError(
                        f"conflicting fixed_locs: outputs[{idx}] and "
                        f"outputs[{other_idx}]: {out.fixed_loc} conflicts "
                        f"with {other_fixed_loc}")
                fixed_locs.append((out.fixed_loc, idx))
        self.immediates = tuple(immediates)  # type: tuple[range, ...]
        self.is_copy = is_copy  # type: bool
        self.is_load_immediate = is_load_immediate  # type: bool
        self.has_side_effects = has_side_effects  # type: bool
1013
1014
@plain_data(frozen=True, unsafe_hash=True)
@final
class OpProperties(metaclass=InternedMeta):
    """ the properties of an `OpKind` instantiated for a concrete `maxvl`.

    `inputs`/`outputs` are the concatenation of what each generic operand
    descriptor of `kind.properties` yields from `instantiate(maxvl=maxvl)`;
    everything else is forwarded to the generic properties.
    """
    __slots__ = "kind", "inputs", "outputs", "maxvl"

    def __init__(self, kind, maxvl):
        # type: (OpKind, int) -> None
        # must be assigned first: `self.generic` reads `self.kind`
        self.kind = kind  # type: OpKind
        self.inputs = tuple(
            operand
            for generic_operand in self.generic.inputs
            for operand in generic_operand.instantiate(maxvl=maxvl)
        )  # type: tuple[OperandDesc, ...]
        self.outputs = tuple(
            operand
            for generic_operand in self.generic.outputs
            for operand in generic_operand.instantiate(maxvl=maxvl)
        )  # type: tuple[OperandDesc, ...]
        self.maxvl = maxvl  # type: int

    @property
    def generic(self):
        # type: () -> GenericOpProperties
        """ the maxvl-independent properties this was instantiated from """
        return self.kind.properties

    @property
    def immediates(self):
        # type: () -> tuple[range, ...]
        return self.kind.properties.immediates

    @property
    def demo_asm(self):
        # type: () -> str
        return self.kind.properties.demo_asm

    @property
    def is_copy(self):
        # type: () -> bool
        return self.kind.properties.is_copy

    @property
    def is_load_immediate(self):
        # type: () -> bool
        return self.kind.properties.is_load_immediate

    @property
    def has_side_effects(self):
        # type: () -> bool
        return self.kind.properties.has_side_effects
1062
1063
# every signed 16-bit value, i.e. [-2**15, 2**15)
IMM_S16 = range(-(1 << 15), 1 << 15)

# A `*_FN` is a callback; a `*_FN2` is a zero-argument thunk returning that
# callback, which lets `OpKind`'s class body register callbacks that refer
# to `OpKind` itself before the class object exists.
_SIM_FN = Callable[["Op", "BaseSimState"], None]
_GEN_ASM_FN = Callable[["Op", "GenAsmState"], None]
_SIM_FN2 = Callable[[], _SIM_FN]
_GEN_ASM_FN2 = Callable[[], _GEN_ASM_FN]

# registries filled in by the `OpKind` class body, keyed by each member's
# `GenericOpProperties`
_SIM_FNS = dict()  # type: dict[GenericOpProperties | Any, _SIM_FN2]
_GEN_ASMS = dict()  # type: dict[GenericOpProperties | Any, _GEN_ASM_FN2]
1072
1073
@unique
@final
class OpKind(Enum):
    """ every kind of operation in the compiler IR.

    Each member's value is the `GenericOpProperties` describing its
    operands, immediates, and flags. While the class body executes, each
    member's simulation and asm-generation callbacks are registered in
    `_SIM_FNS` / `_GEN_ASMS` (keyed by those properties); they are resolved
    lazily through `OpKind.sim` / `OpKind.gen_asm`.
    """

    def __init__(self, properties):
        # type: (GenericOpProperties) -> None
        super().__init__()
        self.__properties = properties

    @property
    def properties(self):
        # type: () -> GenericOpProperties
        return self.__properties

    def instantiate(self, maxvl):
        # type: (int) -> OpProperties
        return OpProperties(self, maxvl=maxvl)

    def __repr__(self):
        # type: () -> str
        return "OpKind." + self._name_

    @cached_property
    def sim(self):
        # type: () -> _SIM_FN
        """ the simulation callback for this kind, looked up on first use """
        return _SIM_FNS[self.properties]()

    @cached_property
    def gen_asm(self):
        # type: () -> _GEN_ASM_FN
        """ the asm-generation callback for this kind, looked up on first
        use """
        return _GEN_ASMS[self.properties]()

    # Note: the registrations below use `lambda: OpKind.__...` because the
    # name `OpKind` is only bound once this class statement has finished.

    @staticmethod
    def __clearca_sim(op, state):
        # type: (Op, BaseSimState) -> None
        state[op.outputs[0]] = False,

    @staticmethod
    def __clearca_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        state.writeln("addic 0, 0, 0")
    ClearCA = GenericOpProperties(
        demo_asm="addic 0, 0, 0",
        inputs=[],
        outputs=[OD_CA.with_write_stage(OpStage.Late)],
    )
    _SIM_FNS[ClearCA] = lambda: OpKind.__clearca_sim
    _GEN_ASMS[ClearCA] = lambda: OpKind.__clearca_gen_asm

    @staticmethod
    def __setca_sim(op, state):
        # type: (Op, BaseSimState) -> None
        state[op.outputs[0]] = True,

    @staticmethod
    def __setca_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        state.writeln("subfc 0, 0, 0")
    SetCA = GenericOpProperties(
        demo_asm="subfc 0, 0, 0",
        inputs=[],
        outputs=[OD_CA.with_write_stage(OpStage.Late)],
    )
    _SIM_FNS[SetCA] = lambda: OpKind.__setca_sim
    _GEN_ASMS[SetCA] = lambda: OpKind.__setca_gen_asm

    @staticmethod
    def __svadde_sim(op, state):
        # type: (Op, BaseSimState) -> None
        # element-wise RT = RA + RB + carry, propagating the carry along
        # the vector
        RA = state[op.input_vals[0]]
        RB = state[op.input_vals[1]]
        carry, = state[op.input_vals[2]]
        VL, = state[op.input_vals[3]]
        RT = []  # type: list[int]
        for i in range(VL):
            v = RA[i] + RB[i] + carry
            RT.append(v & GPR_VALUE_MASK)
            carry = (v >> GPR_SIZE_IN_BITS) != 0
        state[op.outputs[0]] = tuple(RT)
        state[op.outputs[1]] = carry,

    @staticmethod
    def __svadde_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        RT = state.vgpr(op.outputs[0])
        RA = state.vgpr(op.input_vals[0])
        RB = state.vgpr(op.input_vals[1])
        state.writeln(f"sv.adde {RT}, {RA}, {RB}")
    SvAddE = GenericOpProperties(
        demo_asm="sv.adde *RT, *RA, *RB",
        inputs=[OD_EXTRA3_VGPR, OD_EXTRA3_VGPR, OD_CA, OD_VL],
        outputs=[OD_EXTRA3_VGPR, OD_CA.tied_to_input(2)],
    )
    _SIM_FNS[SvAddE] = lambda: OpKind.__svadde_sim
    _GEN_ASMS[SvAddE] = lambda: OpKind.__svadde_gen_asm

    @staticmethod
    def __addze_sim(op, state):
        # type: (Op, BaseSimState) -> None
        RA, = state[op.input_vals[0]]
        carry, = state[op.input_vals[1]]
        v = RA + carry
        RT = v & GPR_VALUE_MASK
        carry = (v >> GPR_SIZE_IN_BITS) != 0
        state[op.outputs[0]] = RT,
        state[op.outputs[1]] = carry,

    @staticmethod
    def __addze_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        # `addze` is a plain scalar instruction on OD_BASE_SGPR operands, so
        # format them as scalar registers (not SV vector `*N` syntax)
        RT = state.sgpr(op.outputs[0])
        RA = state.sgpr(op.input_vals[0])
        state.writeln(f"addze {RT}, {RA}")
    AddZE = GenericOpProperties(
        demo_asm="addze RT, RA",
        inputs=[OD_BASE_SGPR, OD_CA],
        outputs=[OD_BASE_SGPR, OD_CA.tied_to_input(1)],
    )
    _SIM_FNS[AddZE] = lambda: OpKind.__addze_sim
    _GEN_ASMS[AddZE] = lambda: OpKind.__addze_gen_asm

    @staticmethod
    def __svsubfe_sim(op, state):
        # type: (Op, BaseSimState) -> None
        # element-wise RT = ~RA + RB + carry (i.e. RB - RA with borrow)
        RA = state[op.input_vals[0]]
        RB = state[op.input_vals[1]]
        carry, = state[op.input_vals[2]]
        VL, = state[op.input_vals[3]]
        RT = []  # type: list[int]
        for i in range(VL):
            v = (~RA[i] & GPR_VALUE_MASK) + RB[i] + carry
            RT.append(v & GPR_VALUE_MASK)
            carry = (v >> GPR_SIZE_IN_BITS) != 0
        state[op.outputs[0]] = tuple(RT)
        state[op.outputs[1]] = carry,

    @staticmethod
    def __svsubfe_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        RT = state.vgpr(op.outputs[0])
        RA = state.vgpr(op.input_vals[0])
        RB = state.vgpr(op.input_vals[1])
        state.writeln(f"sv.subfe {RT}, {RA}, {RB}")
    SvSubFE = GenericOpProperties(
        demo_asm="sv.subfe *RT, *RA, *RB",
        inputs=[OD_EXTRA3_VGPR, OD_EXTRA3_VGPR, OD_CA, OD_VL],
        outputs=[OD_EXTRA3_VGPR, OD_CA.tied_to_input(2)],
    )
    _SIM_FNS[SvSubFE] = lambda: OpKind.__svsubfe_sim
    _GEN_ASMS[SvSubFE] = lambda: OpKind.__svsubfe_gen_asm

    @staticmethod
    def __svandvs_sim(op, state):
        # type: (Op, BaseSimState) -> None
        # vector AND scalar, element-wise
        RA = state[op.input_vals[0]]
        RB, = state[op.input_vals[1]]
        VL, = state[op.input_vals[2]]
        RT = []  # type: list[int]
        for i in range(VL):
            RT.append(RA[i] & RB & GPR_VALUE_MASK)
        state[op.outputs[0]] = tuple(RT)

    @staticmethod
    def __svandvs_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        RT = state.vgpr(op.outputs[0])
        RA = state.vgpr(op.input_vals[0])
        RB = state.sgpr(op.input_vals[1])
        state.writeln(f"sv.and {RT}, {RA}, {RB}")
    SvAndVS = GenericOpProperties(
        demo_asm="sv.and *RT, *RA, RB",
        inputs=[OD_EXTRA3_VGPR, OD_EXTRA3_SGPR, OD_VL],
        outputs=[OD_EXTRA3_VGPR],
    )
    _SIM_FNS[SvAndVS] = lambda: OpKind.__svandvs_sim
    _GEN_ASMS[SvAndVS] = lambda: OpKind.__svandvs_gen_asm

    @staticmethod
    def __svmaddedu_sim(op, state):
        # type: (Op, BaseSimState) -> None
        # vector * scalar + carry, element-wise; the high word of each
        # product becomes the carry into the next element
        RA = state[op.input_vals[0]]
        RB, = state[op.input_vals[1]]
        carry, = state[op.input_vals[2]]
        VL, = state[op.input_vals[3]]
        RT = []  # type: list[int]
        for i in range(VL):
            v = RA[i] * RB + carry
            RT.append(v & GPR_VALUE_MASK)
            carry = v >> GPR_SIZE_IN_BITS
        state[op.outputs[0]] = tuple(RT)
        state[op.outputs[1]] = carry,

    @staticmethod
    def __svmaddedu_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        RT = state.vgpr(op.outputs[0])
        RA = state.vgpr(op.input_vals[0])
        RB = state.sgpr(op.input_vals[1])
        RC = state.sgpr(op.input_vals[2])
        state.writeln(f"sv.maddedu {RT}, {RA}, {RB}, {RC}")
    SvMAddEDU = GenericOpProperties(
        demo_asm="sv.maddedu *RT, *RA, RB, RC",
        inputs=[OD_EXTRA2_VGPR, OD_EXTRA2_SGPR, OD_EXTRA2_SGPR, OD_VL],
        # maddedu has four register operands, so SVP64 encodes *every*
        # one of them (RT included) with EXTRA2, not EXTRA3
        outputs=[OD_EXTRA2_VGPR, OD_EXTRA2_SGPR.tied_to_input(2)],
    )
    _SIM_FNS[SvMAddEDU] = lambda: OpKind.__svmaddedu_sim
    _GEN_ASMS[SvMAddEDU] = lambda: OpKind.__svmaddedu_gen_asm

    @staticmethod
    def __sradi_sim(op, state):
        # type: (Op, BaseSimState) -> None
        rs, = state[op.input_vals[0]]
        imm = op.immediates[0]
        # reinterpret the register bits as a signed integer
        if rs >= 1 << (GPR_SIZE_IN_BITS - 1):
            rs -= 1 << GPR_SIZE_IN_BITS
        v = rs >> imm  # arithmetic shift: Python ints shift with sign
        RA = v & GPR_VALUE_MASK
        # Power ISA: CA is set iff RS is negative and any 1-bits were
        # shifted out; shifting back and comparing against the *signed*
        # source detects lost bits
        CA = rs < 0 and (v << imm) != rs
        state[op.outputs[0]] = RA,
        state[op.outputs[1]] = CA,

    @staticmethod
    def __sradi_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        RA = state.sgpr(op.outputs[0])
        RS = state.sgpr(op.input_vals[0])
        imm = op.immediates[0]
        state.writeln(f"sradi {RA}, {RS}, {imm}")
    SRADI = GenericOpProperties(
        demo_asm="sradi RA, RS, imm",
        inputs=[OD_BASE_SGPR],
        outputs=[OD_BASE_SGPR.with_write_stage(OpStage.Late),
                 OD_CA.with_write_stage(OpStage.Late)],
        immediates=[range(GPR_SIZE_IN_BITS)],
    )
    _SIM_FNS[SRADI] = lambda: OpKind.__sradi_sim
    _GEN_ASMS[SRADI] = lambda: OpKind.__sradi_gen_asm

    @staticmethod
    def __setvli_sim(op, state):
        # type: (Op, BaseSimState) -> None
        state[op.outputs[0]] = op.immediates[0],

    @staticmethod
    def __setvli_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        imm = op.immediates[0]
        state.writeln(f"setvl 0, 0, {imm}, 0, 1, 1")
    SetVLI = GenericOpProperties(
        demo_asm="setvl 0, 0, imm, 0, 1, 1",
        inputs=(),
        outputs=[OD_VL.with_write_stage(OpStage.Late)],
        immediates=[range(1, 65)],
        is_load_immediate=True,
    )
    _SIM_FNS[SetVLI] = lambda: OpKind.__setvli_sim
    _GEN_ASMS[SetVLI] = lambda: OpKind.__setvli_gen_asm

    @staticmethod
    def __svli_sim(op, state):
        # type: (Op, BaseSimState) -> None
        VL, = state[op.input_vals[0]]
        # masking turns a negative immediate into its 64-bit two's
        # complement register value
        imm = op.immediates[0] & GPR_VALUE_MASK
        state[op.outputs[0]] = (imm,) * VL

    @staticmethod
    def __svli_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        RT = state.vgpr(op.outputs[0])
        imm = op.immediates[0]
        state.writeln(f"sv.addi {RT}, 0, {imm}")
    SvLI = GenericOpProperties(
        demo_asm="sv.addi *RT, 0, imm",
        inputs=[OD_VL],
        outputs=[OD_EXTRA3_VGPR],
        immediates=[IMM_S16],
        is_load_immediate=True,
    )
    _SIM_FNS[SvLI] = lambda: OpKind.__svli_sim
    _GEN_ASMS[SvLI] = lambda: OpKind.__svli_gen_asm

    @staticmethod
    def __li_sim(op, state):
        # type: (Op, BaseSimState) -> None
        imm = op.immediates[0] & GPR_VALUE_MASK
        state[op.outputs[0]] = imm,

    @staticmethod
    def __li_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        RT = state.sgpr(op.outputs[0])
        imm = op.immediates[0]
        state.writeln(f"addi {RT}, 0, {imm}")
    LI = GenericOpProperties(
        demo_asm="addi RT, 0, imm",
        inputs=(),
        outputs=[OD_BASE_SGPR.with_write_stage(OpStage.Late)],
        immediates=[IMM_S16],
        is_load_immediate=True,
    )
    _SIM_FNS[LI] = lambda: OpKind.__li_sim
    _GEN_ASMS[LI] = lambda: OpKind.__li_gen_asm

    @staticmethod
    def __veccopytoreg_sim(op, state):
        # type: (Op, BaseSimState) -> None
        state[op.outputs[0]] = state[op.input_vals[0]]

    @staticmethod
    def __copy_to_from_reg_gen_asm(src_loc, dest_loc, is_vec, state):
        # type: (Loc, Loc, bool, GenAsmState) -> None
        """ emit a GPR<->GPR, stack->GPR, or GPR->stack copy.

        Emits nothing when `src_loc == dest_loc`; raises ValueError for
        stack->stack copies or non-GPR/non-stack locations.
        """
        sv = "sv." if is_vec else ""
        rev = ""
        # overlapping copy to a higher register: iterate in reverse (/mrr)
        # so source elements are read before being overwritten
        if src_loc.conflicts(dest_loc) and src_loc.start < dest_loc.start:
            rev = "/mrr"
        if src_loc == dest_loc:
            return  # no-op
        if src_loc.kind not in (LocKind.GPR, LocKind.StackI64):
            raise ValueError(f"invalid src_loc.kind: {src_loc.kind}")
        if dest_loc.kind not in (LocKind.GPR, LocKind.StackI64):
            raise ValueError(f"invalid dest_loc.kind: {dest_loc.kind}")
        if src_loc.kind is LocKind.StackI64:
            if dest_loc.kind is LocKind.StackI64:
                raise ValueError(
                    f"can't copy from stack to stack: {src_loc} {dest_loc}")
            elif dest_loc.kind is not LocKind.GPR:
                assert_never(dest_loc.kind)
            src = state.stack(src_loc)
            dest = state.gpr(dest_loc, is_vec=is_vec)
            state.writeln(f"{sv}ld {dest}, {src}")
        elif dest_loc.kind is LocKind.StackI64:
            if src_loc.kind is not LocKind.GPR:
                assert_never(src_loc.kind)
            src = state.gpr(src_loc, is_vec=is_vec)
            dest = state.stack(dest_loc)
            state.writeln(f"{sv}std {src}, {dest}")
        elif src_loc.kind is LocKind.GPR:
            if dest_loc.kind is not LocKind.GPR:
                assert_never(dest_loc.kind)
            src = state.gpr(src_loc, is_vec=is_vec)
            dest = state.gpr(dest_loc, is_vec=is_vec)
            state.writeln(f"{sv}or{rev} {dest}, {src}, {src}")
        else:
            assert_never(src_loc.kind)

    @staticmethod
    def __veccopytoreg_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        OpKind.__copy_to_from_reg_gen_asm(
            src_loc=state.loc(
                op.input_vals[0], (LocKind.GPR, LocKind.StackI64)),
            dest_loc=state.loc(op.outputs[0], LocKind.GPR),
            is_vec=True, state=state)

    VecCopyToReg = GenericOpProperties(
        demo_asm="sv.mv dest, src",
        inputs=[GenericOperandDesc(
            ty=GenericTy(BaseTy.I64, is_vec=True),
            sub_kinds=[LocSubKind.SV_EXTRA3_VGPR, LocSubKind.StackI64],
        ), OD_VL],
        outputs=[OD_EXTRA3_VGPR.with_write_stage(OpStage.Late)],
        is_copy=True,
    )
    _SIM_FNS[VecCopyToReg] = lambda: OpKind.__veccopytoreg_sim
    _GEN_ASMS[VecCopyToReg] = lambda: OpKind.__veccopytoreg_gen_asm

    @staticmethod
    def __veccopyfromreg_sim(op, state):
        # type: (Op, BaseSimState) -> None
        state[op.outputs[0]] = state[op.input_vals[0]]

    @staticmethod
    def __veccopyfromreg_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        OpKind.__copy_to_from_reg_gen_asm(
            src_loc=state.loc(op.input_vals[0], LocKind.GPR),
            dest_loc=state.loc(
                op.outputs[0], (LocKind.GPR, LocKind.StackI64)),
            is_vec=True, state=state)
    VecCopyFromReg = GenericOpProperties(
        demo_asm="sv.mv dest, src",
        inputs=[OD_EXTRA3_VGPR, OD_VL],
        outputs=[GenericOperandDesc(
            ty=GenericTy(BaseTy.I64, is_vec=True),
            sub_kinds=[LocSubKind.SV_EXTRA3_VGPR, LocSubKind.StackI64],
            write_stage=OpStage.Late,
        )],
        is_copy=True,
    )
    _SIM_FNS[VecCopyFromReg] = lambda: OpKind.__veccopyfromreg_sim
    _GEN_ASMS[VecCopyFromReg] = lambda: OpKind.__veccopyfromreg_gen_asm

    @staticmethod
    def __copytoreg_sim(op, state):
        # type: (Op, BaseSimState) -> None
        state[op.outputs[0]] = state[op.input_vals[0]]

    @staticmethod
    def __copytoreg_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        OpKind.__copy_to_from_reg_gen_asm(
            src_loc=state.loc(
                op.input_vals[0], (LocKind.GPR, LocKind.StackI64)),
            dest_loc=state.loc(op.outputs[0], LocKind.GPR),
            is_vec=False, state=state)
    CopyToReg = GenericOpProperties(
        demo_asm="mv dest, src",
        inputs=[GenericOperandDesc(
            ty=GenericTy(BaseTy.I64, is_vec=False),
            sub_kinds=[LocSubKind.SV_EXTRA3_SGPR, LocSubKind.BASE_GPR,
                       LocSubKind.StackI64],
        )],
        outputs=[GenericOperandDesc(
            ty=GenericTy(BaseTy.I64, is_vec=False),
            sub_kinds=[LocSubKind.SV_EXTRA3_SGPR, LocSubKind.BASE_GPR],
            write_stage=OpStage.Late,
        )],
        is_copy=True,
    )
    _SIM_FNS[CopyToReg] = lambda: OpKind.__copytoreg_sim
    _GEN_ASMS[CopyToReg] = lambda: OpKind.__copytoreg_gen_asm

    @staticmethod
    def __copyfromreg_sim(op, state):
        # type: (Op, BaseSimState) -> None
        state[op.outputs[0]] = state[op.input_vals[0]]

    @staticmethod
    def __copyfromreg_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        OpKind.__copy_to_from_reg_gen_asm(
            src_loc=state.loc(op.input_vals[0], LocKind.GPR),
            dest_loc=state.loc(
                op.outputs[0], (LocKind.GPR, LocKind.StackI64)),
            is_vec=False, state=state)
    CopyFromReg = GenericOpProperties(
        demo_asm="mv dest, src",
        inputs=[GenericOperandDesc(
            ty=GenericTy(BaseTy.I64, is_vec=False),
            sub_kinds=[LocSubKind.SV_EXTRA3_SGPR, LocSubKind.BASE_GPR],
        )],
        outputs=[GenericOperandDesc(
            ty=GenericTy(BaseTy.I64, is_vec=False),
            sub_kinds=[LocSubKind.SV_EXTRA3_SGPR, LocSubKind.BASE_GPR,
                       LocSubKind.StackI64],
            write_stage=OpStage.Late,
        )],
        is_copy=True,
    )
    _SIM_FNS[CopyFromReg] = lambda: OpKind.__copyfromreg_sim
    _GEN_ASMS[CopyFromReg] = lambda: OpKind.__copyfromreg_gen_asm

    @staticmethod
    def __concat_sim(op, state):
        # type: (Op, BaseSimState) -> None
        # every input except the trailing VL operand is one element
        state[op.outputs[0]] = tuple(
            state[i][0] for i in op.input_vals[:-1])

    @staticmethod
    def __concat_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        OpKind.__copy_to_from_reg_gen_asm(
            src_loc=state.loc(op.input_vals[0:-1], LocKind.GPR),
            dest_loc=state.loc(op.outputs[0], LocKind.GPR),
            is_vec=True, state=state)
    Concat = GenericOpProperties(
        demo_asm="sv.mv dest, src",
        inputs=[GenericOperandDesc(
            ty=GenericTy(BaseTy.I64, is_vec=False),
            sub_kinds=[LocSubKind.SV_EXTRA3_VGPR],
            spread=True,
        ), OD_VL],
        outputs=[OD_EXTRA3_VGPR.with_write_stage(OpStage.Late)],
        is_copy=True,
    )
    _SIM_FNS[Concat] = lambda: OpKind.__concat_sim
    _GEN_ASMS[Concat] = lambda: OpKind.__concat_gen_asm

    @staticmethod
    def __spread_sim(op, state):
        # type: (Op, BaseSimState) -> None
        for idx, inp in enumerate(state[op.input_vals[0]]):
            state[op.outputs[idx]] = inp,

    @staticmethod
    def __spread_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        OpKind.__copy_to_from_reg_gen_asm(
            src_loc=state.loc(op.input_vals[0], LocKind.GPR),
            dest_loc=state.loc(op.outputs, LocKind.GPR),
            is_vec=True, state=state)
    Spread = GenericOpProperties(
        demo_asm="sv.mv dest, src",
        inputs=[OD_EXTRA3_VGPR, OD_VL],
        outputs=[GenericOperandDesc(
            ty=GenericTy(BaseTy.I64, is_vec=False),
            sub_kinds=[LocSubKind.SV_EXTRA3_VGPR],
            spread=True,
            write_stage=OpStage.Late,
        )],
        is_copy=True,
    )
    _SIM_FNS[Spread] = lambda: OpKind.__spread_sim
    _GEN_ASMS[Spread] = lambda: OpKind.__spread_gen_asm

    @staticmethod
    def __svld_sim(op, state):
        # type: (Op, BaseSimState) -> None
        RA, = state[op.input_vals[0]]
        VL, = state[op.input_vals[1]]
        addr = RA + op.immediates[0]
        RT = []  # type: list[int]
        for i in range(VL):
            v = state.load(addr + GPR_SIZE_IN_BYTES * i)
            RT.append(v & GPR_VALUE_MASK)
        state[op.outputs[0]] = tuple(RT)

    @staticmethod
    def __svld_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        RA = state.sgpr(op.input_vals[0])
        RT = state.vgpr(op.outputs[0])
        imm = op.immediates[0]
        state.writeln(f"sv.ld {RT}, {imm}({RA})")
    SvLd = GenericOpProperties(
        demo_asm="sv.ld *RT, imm(RA)",
        inputs=[OD_EXTRA3_SGPR, OD_VL],
        outputs=[OD_EXTRA3_VGPR],
        immediates=[IMM_S16],
    )
    _SIM_FNS[SvLd] = lambda: OpKind.__svld_sim
    _GEN_ASMS[SvLd] = lambda: OpKind.__svld_gen_asm

    @staticmethod
    def __ld_sim(op, state):
        # type: (Op, BaseSimState) -> None
        RA, = state[op.input_vals[0]]
        addr = RA + op.immediates[0]
        v = state.load(addr)
        state[op.outputs[0]] = v & GPR_VALUE_MASK,

    @staticmethod
    def __ld_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        RA = state.sgpr(op.input_vals[0])
        RT = state.sgpr(op.outputs[0])
        imm = op.immediates[0]
        state.writeln(f"ld {RT}, {imm}({RA})")
    Ld = GenericOpProperties(
        demo_asm="ld RT, imm(RA)",
        inputs=[OD_BASE_SGPR],
        outputs=[OD_BASE_SGPR.with_write_stage(OpStage.Late)],
        immediates=[IMM_S16],
    )
    _SIM_FNS[Ld] = lambda: OpKind.__ld_sim
    _GEN_ASMS[Ld] = lambda: OpKind.__ld_gen_asm

    @staticmethod
    def __svstd_sim(op, state):
        # type: (Op, BaseSimState) -> None
        RS = state[op.input_vals[0]]
        RA, = state[op.input_vals[1]]
        VL, = state[op.input_vals[2]]
        addr = RA + op.immediates[0]
        for i in range(VL):
            state.store(addr + GPR_SIZE_IN_BYTES * i, value=RS[i])

    @staticmethod
    def __svstd_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        RS = state.vgpr(op.input_vals[0])
        RA = state.sgpr(op.input_vals[1])
        imm = op.immediates[0]
        state.writeln(f"sv.std {RS}, {imm}({RA})")
    SvStd = GenericOpProperties(
        demo_asm="sv.std *RS, imm(RA)",
        inputs=[OD_EXTRA3_VGPR, OD_EXTRA3_SGPR, OD_VL],
        outputs=[],
        immediates=[IMM_S16],
        has_side_effects=True,
    )
    _SIM_FNS[SvStd] = lambda: OpKind.__svstd_sim
    _GEN_ASMS[SvStd] = lambda: OpKind.__svstd_gen_asm

    @staticmethod
    def __std_sim(op, state):
        # type: (Op, BaseSimState) -> None
        RS, = state[op.input_vals[0]]
        RA, = state[op.input_vals[1]]
        addr = RA + op.immediates[0]
        state.store(addr, value=RS)

    @staticmethod
    def __std_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        RS = state.sgpr(op.input_vals[0])
        RA = state.sgpr(op.input_vals[1])
        imm = op.immediates[0]
        state.writeln(f"std {RS}, {imm}({RA})")
    Std = GenericOpProperties(
        demo_asm="std RS, imm(RA)",
        inputs=[OD_BASE_SGPR, OD_BASE_SGPR],
        outputs=[],
        immediates=[IMM_S16],
        has_side_effects=True,
    )
    _SIM_FNS[Std] = lambda: OpKind.__std_sim
    _GEN_ASMS[Std] = lambda: OpKind.__std_gen_asm

    @staticmethod
    def __funcargr3_sim(op, state):
        # type: (Op, BaseSimState) -> None
        pass  # return value set before simulation

    @staticmethod
    def __funcargr3_gen_asm(op, state):
        # type: (Op, GenAsmState) -> None
        pass  # no instructions needed
    FuncArgR3 = GenericOpProperties(
        demo_asm="",
        inputs=[],
        outputs=[OD_BASE_SGPR.with_fixed_loc(
            Loc(kind=LocKind.GPR, start=3, reg_len=1))],
    )
    _SIM_FNS[FuncArgR3] = lambda: OpKind.__funcargr3_sim
    _GEN_ASMS[FuncArgR3] = lambda: OpKind.__funcargr3_gen_asm
1699
1700
@plain_data(frozen=True, unsafe_hash=True, repr=False)
class SSAValOrUse(metaclass=InternedMeta):
    """ shared base of `SSAVal` and `SSAUse`: one operand of `op`, identified
    by `operand_idx` into the subclass's `descriptor_array`.

    Raises ValueError on construction if `operand_idx` is out of range.
    """
    __slots__ = "op", "operand_idx"

    def __init__(self, op, operand_idx):
        # type: (Op, int) -> None
        super().__init__()
        # `descriptor_array` is computed from `self.op`, so set `op` first
        self.op = op
        if not 0 <= operand_idx < len(self.descriptor_array):
            raise ValueError("invalid operand_idx")
        self.operand_idx = operand_idx

    @abstractmethod
    def __repr__(self):
        # type: () -> str
        ...

    @property
    @abstractmethod
    def descriptor_array(self):
        # type: () -> tuple[OperandDesc, ...]
        """ the operand descriptors `operand_idx` indexes into """
        ...

    @cached_property
    def defining_descriptor(self):
        # type: () -> OperandDesc
        return self.descriptor_array[self.operand_idx]

    @cached_property
    def ty(self):
        # type: () -> Ty
        return self.defining_descriptor.ty

    @cached_property
    def ty_before_spread(self):
        # type: () -> Ty
        return self.defining_descriptor.ty_before_spread

    @property
    def base_ty(self):
        # type: () -> BaseTy
        return self.ty_before_spread.base_ty

    @property
    def reg_offset_in_unspread(self):
        # type: () -> int
        """ how many reg-sized slots of the unspread Loc precede self's Loc

        e.g. if the unspread Loc containing self is
        `Loc(kind=LocKind.GPR, start=8, reg_len=4)`
        and self's Loc is `Loc(kind=LocKind.GPR, start=10, reg_len=1)`
        then reg_offset_in_unspread == 2 == 10 - 8
        """
        return self.defining_descriptor.reg_offset_in_unspread

    @property
    def unspread_start_idx(self):
        # type: () -> int
        """ operand index of the first element of the spread group this
        operand belongs to (its own index when not spread) """
        spread_index = self.defining_descriptor.spread_index
        return self.operand_idx - (spread_index or 0)

    @property
    def unspread_start(self):
        # type: () -> Self
        """ the same kind of operand, at `unspread_start_idx` """
        start_idx = self.unspread_start_idx
        return self.__class__(op=self.op, operand_idx=start_idx)
1764
1765
@plain_data(frozen=True, unsafe_hash=True, repr=False)
@final
class SSAVal(SSAValOrUse):
    """ an SSA value: output number `operand_idx` of `op` """
    __slots__ = ()

    def __repr__(self):
        # type: () -> str
        return f"<{self.op.name}.outputs[{self.operand_idx}]: {self.ty}>"

    @cached_property
    def descriptor_array(self):
        # type: () -> tuple[OperandDesc, ...]
        return self.op.properties.outputs

    @cached_property
    def def_loc_set_before_spread(self):
        # type: () -> LocSet
        return self.defining_descriptor.loc_set_before_spread

    @cached_property
    def tied_input(self):
        # type: () -> None | SSAUse
        """ the input this output is tied to, or None if it isn't tied """
        tied_idx = self.defining_descriptor.tied_input_index
        if tied_idx is not None:
            return SSAUse(op=self.op, operand_idx=tied_idx)
        return None

    @property
    def write_stage(self):
        # type: () -> OpStage
        return self.defining_descriptor.write_stage

    @property
    def current_debugging_value(self):
        # type: () -> tuple[int, ...]
        """ look up this value in the current debugging state (e.g. from pdb).

        Meant to be paired with `PreRASimState.set_current_debugging_state`.

        Debugging aid only -- never use this from unit tests or production
        code.
        """
        debugging_state = PreRASimState.get_current_debugging_state()
        return debugging_state[self]
1810
1811
@plain_data(frozen=True, unsafe_hash=True, repr=False)
@final
class SSAUse(SSAValOrUse):
    """ a use of an SSA value: input number `operand_idx` of `op` """
    __slots__ = ()

    def __repr__(self):
        # type: () -> str
        return f"<{self.op.name}.input_uses[{self.operand_idx}]: {self.ty}>"

    @cached_property
    def descriptor_array(self):
        # type: () -> tuple[OperandDesc, ...]
        return self.op.properties.inputs

    @cached_property
    def use_loc_set_before_spread(self):
        # type: () -> LocSet
        return self.defining_descriptor.loc_set_before_spread

    @property
    def ssa_val(self):
        # type: () -> SSAVal
        """ the `SSAVal` currently plugged into this input of `op` """
        return self.op.input_vals[self.operand_idx]

    @ssa_val.setter
    def ssa_val(self, ssa_val):
        # type: (SSAVal) -> None
        # goes through `op.input_vals`, which validates the new value
        self.op.input_vals[self.operand_idx] = ssa_val
1840
1841
_T = TypeVar("_T")  # element type
_Desc = TypeVar("_Desc")  # per-position descriptor type


class OpInputSeq(Sequence[_T], Generic[_T, _Desc]):
    """ fixed-length sequence of an `Op`'s operands whose every write is
    validated against the descriptor for that position.

    Subclasses provide the descriptors (`_get_descriptors`) and the check
    (`_verify_write_with_desc`). Reading slices is allowed; writing them is
    not.
    """

    def __init__(self, items, op):
        # type: (Iterable[_T], Op) -> None
        super().__init__()
        self.__op = op
        self.__items = []  # type: list[_T]
        for position, value in enumerate(items):
            if position >= len(self.descriptors):
                raise ValueError("too many items")
            self._verify_write(position, value)
            self.__items.append(value)
        if len(self.descriptors) > len(self.__items):
            raise ValueError("not enough items")

    @property
    @final
    def op(self):
        return self.__op

    @abstractmethod
    def _get_descriptors(self):
        # type: () -> tuple[_Desc, ...]
        raise NotImplementedError

    @cached_property
    @final
    def descriptors(self):
        # type: () -> tuple[_Desc, ...]
        """ one descriptor per position, computed once """
        return self._get_descriptors()

    @abstractmethod
    def _verify_write_with_desc(self, idx, item, desc):
        # type: (int, _T | Any, _Desc) -> None
        """ raise if `item` may not be stored at `idx` (described by `desc`)
        """
        raise NotImplementedError

    @final
    def _verify_write(self, idx, item):
        # type: (int | Any, _T | Any) -> int
        """ validate a write and return `idx` normalized to be non-negative
        """
        if isinstance(idx, slice):
            raise TypeError(
                f"can't write to slice of {self.__class__.__name__}")
        if not isinstance(idx, int):
            raise TypeError(f"can't write with index {idx!r}")
        # indexing a range both resolves negative indexes and raises
        # IndexError for out-of-range ones
        position = range(len(self.descriptors))[idx]
        self._verify_write_with_desc(
            position, item, self.descriptors[position])
        return position

    def _on_set(self, idx, new_item, old_item):
        # type: (int, _T, _T | None) -> None
        pass

    @final
    def __iter__(self):
        # type: () -> Iterator[_T]
        for element in self.__items:
            yield element

    @overload
    def __getitem__(self, idx):
        # type: (int) -> _T
        ...

    @overload
    def __getitem__(self, idx):
        # type: (slice) -> list[_T]
        ...

    @final
    def __getitem__(self, idx):
        # type: (int | slice) -> _T | list[_T]
        return self.__items[idx]

    @final
    def __setitem__(self, idx, item):
        # type: (int, _T) -> None
        position = self._verify_write(idx, item)
        self.__items[position] = item

    @final
    def __len__(self):
        # type: () -> int
        return len(self.__items)

    def __repr__(self):
        # type: () -> str
        return f"{self.__class__.__name__}({self.__items}, op=...)"
1933
1934
@final
class OpInputVals(OpInputSeq[SSAVal, OperandDesc]):
    """ the `SSAVal`s an `Op` reads, one per entry of
    `op.properties.inputs`, each required to have that entry's type. """

    def _get_descriptors(self):
        # type: () -> tuple[OperandDesc, ...]
        return self.op.properties.inputs

    def _verify_write_with_desc(self, idx, item, desc):
        # type: (int, SSAVal | Any, OperandDesc) -> None
        if not isinstance(item, SSAVal):
            raise TypeError("expected value of type SSAVal")
        if item.ty != desc.ty:
            raise ValueError(f"assigned item's type {item.ty!r} doesn't match "
                             f"corresponding input's type {desc.ty!r}")

    def _on_set(self, idx, new_item, old_item):
        # type: (int, SSAVal, SSAVal | None) -> None
        SSAUses._on_op_input_set(self, idx, new_item, old_item)  # type: ignore

    def __init__(self, items, op):
        # type: (Iterable[SSAVal], Op) -> None
        # `Op` keeps this sequence in its `input_vals` slot, so that is the
        # attribute whose presence means one was already built for `op`
        # (mirrors `OpImmediates` checking `immediates`)
        if hasattr(op, "input_vals"):
            raise ValueError("Op.input_vals already set")
        super().__init__(items, op)
1958
1959
@final
class OpImmediates(OpInputSeq[int, range]):
    """ the immediate operands of an `Op`; each must be an `int` inside the
    matching `range` from `op.properties.immediates`. """

    def __init__(self, items, op):
        # type: (Iterable[int], Op) -> None
        # refuse to build a second immediates sequence for the same `Op`
        if hasattr(op, "immediates"):
            raise ValueError("Op.immediates already set")
        super().__init__(items, op)

    def _get_descriptors(self):
        # type: () -> tuple[range, ...]
        return self.op.properties.immediates

    def _verify_write_with_desc(self, idx, item, desc):
        # type: (int, int | Any, range) -> None
        if not isinstance(item, int):
            raise TypeError("expected value of type int")
        if item in desc:
            return
        raise ValueError(f"immediate value {item!r} not in {desc!r}")
1978
1979
@plain_data(frozen=True, eq=False, repr=False)
@final
class Op:
    """A single operation belonging to a `Fn`.

    Instances compare and hash by identity. `__init__` builds the input
    value sequence, one `SSAUse` per entry of `properties.inputs`, the
    immediates sequence, and one output `SSAVal` per entry of
    `properties.outputs`, then registers itself with `fn` under a name not
    already in use there.
    """
    __slots__ = ("fn", "properties", "input_vals", "input_uses", "immediates",
                 "outputs", "name")

    def __init__(self, fn, properties, input_vals, immediates, name=""):
        # type: (Fn, OpProperties, Iterable[SSAVal], Iterable[int], str) -> None
        self.fn = fn
        self.properties = properties
        self.input_vals = OpInputVals(input_vals, op=self)
        inputs_len = len(self.properties.inputs)
        self.input_uses = tuple(SSAUse(self, i) for i in range(inputs_len))
        self.immediates = OpImmediates(immediates, op=self)
        outputs_len = len(self.properties.outputs)
        self.outputs = tuple(SSAVal(self, i) for i in range(outputs_len))
        # `Fn` picks `name` or, if it is empty/taken, `name` plus a numeric
        # suffix, and records the mapping.
        self.name = fn._add_op_with_unused_name(self, name)  # type: ignore

    @property
    def kind(self):
        # type: () -> OpKind
        """The `OpKind` of this Op (taken from its properties)."""
        return self.properties.kind

    def __eq__(self, other):
        # type: (Op | Any) -> bool
        # identity equality: two distinct Ops are never equal
        if isinstance(other, Op):
            return self is other
        return NotImplemented

    def __hash__(self):
        # type: () -> int
        return object.__hash__(self)

    def __repr__(self, wrap_width=63, indent=" "):
        # type: (int, str) -> str
        """Render as `name:` followed by `(outputs...) <= KIND(args...)`.

        Text is assembled with zero-width-space markers at permitted break
        points, then each line is greedily wrapped so that no emitted line
        (after stripping trailing whitespace) exceeds `wrap_width` unless a
        single unbreakable piece is longer. Every line after the first, and
        every wrapped continuation, is prefixed with `indent`.
        """
        WRAP_POINT = "\u200B"  # zero-width space
        items = [f"{self.name}:\n"]
        for i, out in enumerate(self.outputs):
            item = f"<...outputs[{i}]: {out.ty}>"
            if i == 0:
                item = "(" + WRAP_POINT + item
            if i != len(self.outputs) - 1:
                item += ", " + WRAP_POINT
            else:
                item += WRAP_POINT + ") <= "
            items.append(item)
        items.append(self.kind._name_)
        if len(self.input_vals) + len(self.immediates) != 0:
            items[-1] += "("
        items[-1] += WRAP_POINT
        # inputs come first; the closing paren goes after the last argument,
        # which is the last immediate if there are any
        for i, inp in enumerate(self.input_vals):
            item = repr(inp)
            if i != len(self.input_vals) - 1 or len(self.immediates) != 0:
                item += ", " + WRAP_POINT
            else:
                item += ") " + WRAP_POINT
            items.append(item)
        for i, imm in enumerate(self.immediates):
            item = hex(imm)
            if i != len(self.immediates) - 1:
                item += ", " + WRAP_POINT
            else:
                item += ") " + WRAP_POINT
            items.append(item)
        lines = []  # type: list[str]
        for i, line_in in enumerate("".join(items).splitlines()):
            if i != 0:
                line_in = indent + line_in
            line_out = ""
            for part in line_in.split(WRAP_POINT):
                if line_out == "":
                    line_out = part
                    continue
                trial_line_out = line_out + part
                if len(trial_line_out.rstrip()) > wrap_width:
                    # too long: flush what we have, start a continuation line
                    lines.append(line_out.rstrip())
                    line_out = indent + part
                else:
                    line_out = trial_line_out
            lines.append(line_out.rstrip())
        return "\n".join(lines)

    def sim(self, state):
        # type: (BaseSimState) -> None
        """Simulate this Op against `state`.

        Checks that every input is assigned with `ty.reg_len` elements. For a
        `PreRASimState`, also checks that no output is already assigned
        (`FuncArgR3` outputs are exempt). It then runs `self.kind.sim` and
        verifies every output was assigned with the right number of elements.

        Raises ValueError if any of these checks fail.
        """
        for inp in self.input_vals:
            try:
                val = state[inp]
            except KeyError as e:
                # chain explicitly so the traceback shows the lookup failure
                # as the cause rather than as an unrelated secondary error
                raise ValueError(f"SSAVal {inp} not yet assigned when "
                                 f"running {self}") from e
            if len(val) != inp.ty.reg_len:
                raise ValueError(
                    f"value of SSAVal {inp} has wrong number of elements: "
                    f"expected {inp.ty.reg_len} found "
                    f"{len(val)}: {val!r}")
        if isinstance(state, PreRASimState):
            for out in self.outputs:
                if out in state.ssa_vals:
                    if self.kind is OpKind.FuncArgR3:
                        continue
                    raise ValueError(f"SSAVal {out} already assigned before "
                                     f"running {self}")
        self.kind.sim(self, state)
        for out in self.outputs:
            try:
                val = state[out]
            except KeyError as e:
                raise ValueError(
                    f"running {self} failed to assign to {out}") from e
            if len(val) != out.ty.reg_len:
                raise ValueError(
                    f"value of SSAVal {out} has wrong number of elements: "
                    f"expected {out.ty.reg_len} found "
                    f"{len(val)}: {val!r}")

    def gen_asm(self, state):
        # type: (GenAsmState) -> None
        """Emit assembly for this Op into `state`.

        First looks up the Loc of every input and output via `state.loc`
        (accepting any LocKind), so a missing or invalid allocation fails
        early. It then delegates to `self.kind.gen_asm`.
        """
        all_loc_kinds = tuple(LocKind)
        for inp in self.input_vals:
            state.loc(inp, expected_kinds=all_loc_kinds)
        for out in self.outputs:
            state.loc(out, expected_kinds=all_loc_kinds)
        self.kind.gen_asm(self, state)
2102
2103
@plain_data(frozen=True, repr=False)
class BaseSimState(metaclass=ABCMeta):
    """Shared base for simulator states: a sparse byte-addressed memory.

    `memory` maps addresses to byte values. The byte/word accessors below
    mask addresses to `GPR_VALUE_MASK` and stored bytes to 8 bits.
    Subclasses define how SSAVal values are read and written via
    `__getitem__`/`__setitem__`.
    """
    __slots__ = "memory",

    def __init__(self, memory):
        # type: (dict[int, int]) -> None
        super().__init__()
        self.memory = memory  # type: dict[int, int]

    def load_byte(self, addr):
        # type: (int) -> int
        """Return the byte at `addr`, or 0 if it was never written."""
        return self.memory.get(addr & GPR_VALUE_MASK, 0) & 0xFF

    def store_byte(self, addr, value):
        # type: (int, int) -> None
        """Store the low 8 bits of `value` at `addr`."""
        self.memory[addr & GPR_VALUE_MASK] = value & 0xFF

    def load(self, addr, size_in_bytes=GPR_SIZE_IN_BYTES, signed=False):
        # type: (int, int, bool) -> int
        """Read a little-endian integer of `size_in_bytes` bytes at `addr`.

        If `signed` is true, the result is sign-extended from the top bit.
        Raises ValueError if `addr` is not a multiple of `size_in_bytes`.
        """
        if addr % size_in_bytes != 0:
            raise ValueError(f"address not aligned: {hex(addr)} "
                             f"required alignment: {size_in_bytes}")
        width = size_in_bytes * BITS_IN_BYTE
        value = 0
        for byte_index in range(size_in_bytes):
            shift = byte_index * BITS_IN_BYTE
            value |= self.load_byte(addr + byte_index) << shift
        if signed and value >> (width - 1) != 0:
            value -= 1 << width
        return value

    def store(self, addr, value, size_in_bytes=GPR_SIZE_IN_BYTES):
        # type: (int, int, int) -> None
        """Write the low `size_in_bytes` bytes of `value` little-endian.

        Raises ValueError if `addr` is not a multiple of `size_in_bytes`.
        """
        if addr % size_in_bytes != 0:
            raise ValueError(f"address not aligned: {hex(addr)} "
                             f"required alignment: {size_in_bytes}")
        for byte_index in range(size_in_bytes):
            shifted = value >> (byte_index * BITS_IN_BYTE)
            self.store_byte(addr + byte_index, shifted & 0xFF)

    def _memory__repr(self):
        # type: () -> str
        """Render `memory` in ascending address order.

        Every aligned run of `GPR_SIZE_IN_BYTES` consecutive present bytes is
        shown as one `<0x...>` word; any other byte is shown individually.
        """
        if len(self.memory) == 0:
            return "{}"
        word_size = GPR_SIZE_IN_BYTES
        # descending order, so the smallest remaining address is at the end
        pending = sorted(self.memory.keys(), reverse=True)
        entries = []  # type: list[str]
        while pending:
            base = pending[-1]
            is_full_word = (
                len(pending) >= word_size
                and base % word_size == 0
                and pending[-word_size:]
                == list(range(base + word_size - 1, base - 1, -1)))
            if is_full_word:
                word = self.load(base, size_in_bytes=word_size)
                entries.append(f"0x{base:05x}: <0x{word:0{word_size * 2}x}>")
                del pending[-word_size:]
            else:
                byte_addr = pending.pop()
                entries.append(
                    f"0x{base:05x}: 0x{self.memory[byte_addr]:02x}")
        if len(entries) == 1:
            return "{" + entries[0] + "}"
        return "{\n" + ",\n".join(entries) + "}"

    def __repr__(self):
        # type: () -> str
        # fields with a `_<field>__repr` method use it; others use repr()
        rendered_fields = []  # type: list[str]
        for field_name in fields(self):
            try:
                field_value = getattr(self, field_name)
            except AttributeError:
                rendered_fields.append(f"{field_name}=<not set>")
                continue
            custom_repr = getattr(self, f"_{field_name}__repr", None)
            if callable(custom_repr):
                rendered = custom_repr()
            else:
                rendered = repr(field_value)
            rendered_fields.append(f"{field_name}={rendered}")
        return f"{self.__class__.__name__}({', '.join(rendered_fields)})"

    @abstractmethod
    def __getitem__(self, ssa_val):
        # type: (SSAVal) -> tuple[int, ...]
        """Return the value of `ssa_val`, one int per register."""
        ...

    @abstractmethod
    def __setitem__(self, ssa_val, value):
        # type: (SSAVal, tuple[int, ...]) -> None
        """Assign `value` (one int per register) to `ssa_val`."""
        ...
2193
2194
@plain_data(frozen=True, repr=False)
@final
class PreRASimState(BaseSimState):
    """Simulator state used before register allocation.

    Each SSAVal's value is stored directly in `ssa_vals` as a tuple of ints.
    `__setitem__` requires the tuple length to equal `ssa_val.ty.reg_len`.
    Entries passed to the constructor are not checked.
    """
    __slots__ = "ssa_vals",

    def __init__(self, ssa_vals, memory):
        # type: (dict[SSAVal, tuple[int, ...]], dict[int, int]) -> None
        super().__init__(memory)
        self.ssa_vals = ssa_vals  # type: dict[SSAVal, tuple[int, ...]]

    def _ssa_vals__repr(self):
        # type: () -> str
        """Render `ssa_vals` in hex for `__repr__`.

        Values longer than `CHUNK_SIZE` elements are split over several
        lines, `CHUNK_SIZE` elements per line.
        """
        if len(self.ssa_vals) == 0:
            return "{}"
        items = []  # type: list[str]
        CHUNK_SIZE = 4
        for k, v in self.ssa_vals.items():
            element_strs = []  # type: list[str]
            for i, el in enumerate(v):
                if i % CHUNK_SIZE != 0:
                    element_strs.append(" " + hex(el))
                else:
                    # each chunk starts on a new line
                    element_strs.append("\n " + hex(el))
            # Short values fit on one line, so drop the leading newline.
            # The `0 <` guard avoids an IndexError on an empty tuple, which
            # the constructor does not reject; it then renders as `()`.
            if 0 < len(element_strs) <= CHUNK_SIZE:
                element_strs[0] = element_strs[0].lstrip()
            if len(element_strs) == 1:
                # trailing comma so a 1-element value reads like a tuple
                element_strs.append("")
            v_str = ",".join(element_strs)
            items.append(f"{k!r}: ({v_str})")
        if len(items) == 1 and "\n" not in items[0]:
            return f"{{{items[0]}}}"
        items_str = ",\n".join(items)
        return f"{{\n{items_str},\n}}"

    def __getitem__(self, ssa_val):
        # type: (SSAVal) -> tuple[int, ...]
        return self.ssa_vals[ssa_val]

    def __setitem__(self, ssa_val, value):
        # type: (SSAVal, tuple[int, ...]) -> None
        if len(value) != ssa_val.ty.reg_len:
            raise ValueError("value has wrong len")
        self.ssa_vals[ssa_val] = value

    # stack of states pushed by `set_as_current_debugging_state`; the most
    # recently entered state is last
    __CURRENT_DEBUGGING_STATE = []  # type: list[PreRASimState]

    @contextmanager
    def set_as_current_debugging_state(self):
        """ return a context manager that sets self as the current state for
        debugging in pdb or similar. This is intended only for use with
        `get_current_debugging_state` which should not be used in unit tests
        or production code.
        """
        try:
            PreRASimState.__CURRENT_DEBUGGING_STATE.append(self)
            yield
        finally:
            assert self is PreRASimState.__CURRENT_DEBUGGING_STATE.pop(), \
                "inconsistent __CURRENT_DEBUGGING_STATE"

    @staticmethod
    def get_current_debugging_state():
        # type: () -> PreRASimState
        """ get the current state for debugging in pdb or similar.

        This is intended for use with `set_as_current_debugging_state`.

        This is only intended for debugging, do not use in unit tests or
        production code.

        Raises ValueError if no state is currently set.
        """
        if len(PreRASimState.__CURRENT_DEBUGGING_STATE) == 0:
            raise ValueError("no current debugging state")
        return PreRASimState.__CURRENT_DEBUGGING_STATE[-1]
2268
2269
@plain_data(frozen=True, repr=False)
@final
class PostRASimState(BaseSimState):
    """Simulator state used after register allocation.

    SSAVal values live in `loc_values`, keyed by single-register
    (`reg_len == 1`) Locs. `ssa_val_to_loc_map` gives the allocated Loc for
    each SSAVal. Unwritten registers read as 0.
    """
    __slots__ = "ssa_val_to_loc_map", "loc_values"

    def __init__(self, ssa_val_to_loc_map, memory, loc_values):
        # type: (dict[SSAVal, Loc], dict[int, int], dict[Loc, int]) -> None
        """Raises ValueError if an SSAVal's Ty differs from its Loc's Ty, or
        if `loc_values` has a key whose `reg_len` is not 1.
        """
        super().__init__(memory)
        self.ssa_val_to_loc_map = FMap(ssa_val_to_loc_map)
        for ssa_val, loc in self.ssa_val_to_loc_map.items():
            if ssa_val.ty != loc.ty:
                raise ValueError(
                    f"type mismatch for SSAVal and Loc: {ssa_val} {loc}")
        # kept by reference: __setitem__ writes into the caller's dict
        self.loc_values = loc_values
        for loc in self.loc_values.keys():
            if loc.reg_len != 1:
                raise ValueError(
                    "loc_values must only contain Locs with reg_len=1, all "
                    "larger Locs will be split into reg_len=1 sub-Locs")

    def _loc_values__repr(self):
        # type: () -> str
        """Render `loc_values` sorted by (LocKind name, start), in hex."""
        # without this, an empty mapping rendered as "{\n,\n}"; match the
        # other `_*__repr` helpers, which return "{}" when empty
        if len(self.loc_values) == 0:
            return "{}"
        locs = sorted(self.loc_values.keys(),
                      key=lambda v: (v.kind.name, v.start))
        items = [f"{loc}: 0x{self.loc_values[loc]:x}" for loc in locs]
        items_str = ",\n".join(items)
        return f"{{\n{items_str},\n}}"

    def __getitem__(self, ssa_val):
        # type: (SSAVal) -> tuple[int, ...]
        """Gather `ssa_val`'s value register by register from its Loc."""
        loc = self.ssa_val_to_loc_map[ssa_val]
        subloc_ty = Ty(base_ty=loc.ty.base_ty, reg_len=1)
        retval = []  # type: list[int]
        for i in range(loc.reg_len):
            subloc = loc.get_subloc_at_offset(subloc_ty=subloc_ty, offset=i)
            retval.append(self.loc_values.get(subloc, 0))
        return tuple(retval)

    def __setitem__(self, ssa_val, value):
        # type: (SSAVal, tuple[int, ...]) -> None
        """Scatter `value` into the single-register sub-Locs of `ssa_val`'s
        Loc. Raises ValueError if `len(value) != ssa_val.ty.reg_len`.
        """
        if len(value) != ssa_val.ty.reg_len:
            raise ValueError("value has wrong len")
        loc = self.ssa_val_to_loc_map[ssa_val]
        subloc_ty = Ty(base_ty=loc.ty.base_ty, reg_len=1)
        for i in range(loc.reg_len):
            subloc = loc.get_subloc_at_offset(subloc_ty=subloc_ty, offset=i)
            self.loc_values[subloc] = value[i]
2319
2320
@plain_data(frozen=True)
class GenAsmState:
    """State carried through assembly generation.

    `allocated_locs` maps each SSAVal to its allocated Loc (whose Ty must
    match). `output` receives emitted lines: each is appended to a list, or
    written with a trailing newline to a StringIO. It defaults to a new list.
    """
    __slots__ = "allocated_locs", "output"

    def __init__(self, allocated_locs, output=None):
        # type: (Mapping[SSAVal, Loc], StringIO | list[str] | None) -> None
        super().__init__()
        self.allocated_locs = FMap(allocated_locs)
        for val, assigned in self.allocated_locs.items():
            if val.ty != assigned.ty:
                raise ValueError(
                    f"Ty mismatch: ssa_val.ty:{val.ty} != loc.ty:{assigned.ty}")
        self.output = [] if output is None else output

    __SSA_VAL_OR_LOCS = Union[SSAVal, Loc, Sequence["SSAVal | Loc"]]

    def loc(self, ssa_val_or_locs, expected_kinds):
        # type: (__SSA_VAL_OR_LOCS, LocKind | tuple[LocKind, ...]) -> Loc
        """Resolve one SSAVal/Loc, or a sequence of them, into a single Loc.

        SSAVals are replaced by their allocated Loc, and the pieces are
        joined with `try_concat`. Raises ValueError if the sequence is empty,
        cannot be concatenated, or the result's kind is not in
        `expected_kinds`.
        """
        if isinstance(ssa_val_or_locs, (SSAVal, Loc)):
            ssa_val_or_locs = [ssa_val_or_locs]
        pieces = [self.allocated_locs[entry] if isinstance(entry, SSAVal)
                  else entry
                  for entry in ssa_val_or_locs]
        if not pieces:
            raise ValueError("invalid Loc sequence: must not be empty")
        head, *tail = pieces
        combined = head.try_concat(*tail)
        if combined is None:
            raise ValueError("invalid Loc sequence: try_concat failed")
        if isinstance(expected_kinds, LocKind):
            allowed = (expected_kinds,)
        else:
            allowed = expected_kinds
        if combined.kind in allowed:
            return combined
        # show a lone expected kind without tuple syntax
        shown = allowed[0] if len(allowed) == 1 else allowed
        raise ValueError(f"LocKind mismatch: {ssa_val_or_locs}: found "
                         f"{combined.kind} expected {shown}")

    def gpr(self, ssa_val_or_locs, is_vec):
        # type: (__SSA_VAL_OR_LOCS, bool) -> str
        """Format a GPR operand as its start index, `*`-prefixed if vector."""
        start_text = str(self.loc(ssa_val_or_locs, LocKind.GPR).start)
        if is_vec:
            return "*" + start_text
        return "" + start_text

    def sgpr(self, ssa_val_or_locs):
        # type: (__SSA_VAL_OR_LOCS) -> str
        """Format a scalar GPR operand."""
        return self.gpr(ssa_val_or_locs, is_vec=False)

    def vgpr(self, ssa_val_or_locs):
        # type: (__SSA_VAL_OR_LOCS) -> str
        """Format a vector (`*`-prefixed) GPR operand."""
        return self.gpr(ssa_val_or_locs, is_vec=True)

    def stack(self, ssa_val_or_locs):
        # type: (__SSA_VAL_OR_LOCS) -> str
        """Format a StackI64 operand as `<start>(1)`."""
        stack_loc = self.loc(ssa_val_or_locs, LocKind.StackI64)
        return f"{stack_loc.start}(1)"

    def writeln(self, *line_segments):
        # type: (*str) -> None
        """Emit one line made of `line_segments` joined by spaces."""
        line = " ".join(line_segments)
        sink = self.output
        if isinstance(sink, list):
            sink.append(line)
            return
        sink.write(line + "\n")