working on rewriting compiler ir to fix reg alloc issues
[bigint-presentation-code.git] / src / bigint_presentation_code / compiler_ir.py
1 # type: ignore
2 """
3 Compiler IR for Toom-Cook algorithm generator for SVP64
4
5 This assumes VL != 0 throughout.
6 """
7
8 from abc import ABCMeta, abstractmethod
9 from collections import defaultdict
10 from enum import Enum, EnumMeta, unique
11 from functools import lru_cache
12 from typing import Any, Generic, Iterable, Sequence, Type, TypeVar, cast
13
14 from nmutil.plain_data import fields, plain_data
15
16 from bigint_presentation_code.type_util import final
17 from bigint_presentation_code.util import FMap, OFSet, OSet
18
19
class ABCEnumMeta(EnumMeta, ABCMeta):
    """Metaclass combining `EnumMeta` with `ABCMeta`.

    Needed so an `Enum` can also inherit from a class whose metaclass is
    `ABCMeta` without a metaclass conflict.
    """
22
23
class RegLoc(metaclass=ABCMeta):
    """Abstract base class of everything that can be used as a register
    location.
    """
    __slots__ = ()

    @abstractmethod
    def conflicts(self, other):
        # type: (RegLoc) -> bool
        """Return True if `self` and `other` conflict."""
        ...

    def get_subreg_at_offset(self, subreg_type, offset):
        # type: (RegType, int) -> RegLoc
        """Return the sub-register of type `subreg_type` at `offset`.

        This default implementation only supports the whole register:
        `self` must be a member of `subreg_type.reg_class` and `offset` must
        be 0, otherwise ValueError is raised.
        """
        is_member = self in subreg_type.reg_class
        if not is_member:
            raise ValueError(f"register not a member of subreg_type: "
                             f"reg={self} subreg_type={subreg_type}")
        if offset == 0:
            return self
        raise ValueError(f"non-zero sub-register offset not supported "
                         f"for register: {self}")
41
42
# number of GPRs; GPRRange rejects any range extending past this count
GPR_COUNT = 128
44
45
@plain_data(frozen=True, unsafe_hash=True, repr=False)
@final
class GPRRange(RegLoc, Sequence["GPRRange"]):
    """A contiguous run of `length` GPRs starting at register `start`.

    Also behaves as a sequence of sub-ranges: indexing with an int gives the
    length-1 range for that register, slicing gives the matching sub-range.
    """
    __slots__ = "start", "length"

    def __init__(self, start, length=None):
        # type: (int | range, int | None) -> None
        """Create a range from `start`/`length`, or from a step-1 `range`.

        `length` defaults to 1 and must not be given together with a
        `range`. Raises TypeError if both a `range` and a `length` are
        given, ValueError if the range has a step other than 1, is empty,
        starts below 0 or extends past GPR_COUNT.
        """
        if isinstance(start, range):
            if length is not None:
                raise TypeError("can't specify length when input is a range")
            if start.step != 1:
                raise ValueError("range must have a step of 1")
            length = len(start)
            start = start.start
        elif length is None:
            length = 1
        if length <= 0 or start < 0 or start + length > GPR_COUNT:
            raise ValueError("invalid GPRRange")
        self.start = start
        self.length = length

    @property
    def stop(self):
        """One past the last register number in this range."""
        return self.start + self.length

    @property
    def step(self):
        """Always 1 -- ranges are contiguous."""
        return 1

    @property
    def range(self):
        """The register numbers of this range as a builtin `range`."""
        return range(self.start, self.stop, self.step)

    def __len__(self):
        return self.length

    def __getitem__(self, item):
        # type: (int | slice) -> GPRRange
        # an int index yields an int, which GPRRange() turns into a length-1
        # range; a slice yields a range, which GPRRange() validates.
        return GPRRange(self.range[item])

    def __contains__(self, value):
        # type: (GPRRange) -> bool
        """True if `value` lies entirely within this range."""
        return value.start >= self.start and value.stop <= self.stop

    def index(self, sub, start=None, end=None):
        # type: (GPRRange, int | None, int | None) -> int
        """Return the offset of `sub` from the start of this range.

        Raises ValueError if `sub` isn't fully inside `self[start:end]`.
        """
        r = self.range[start:end]
        if sub.start < r.start or sub.stop > r.stop:
            raise ValueError("GPR range not found")
        return sub.start - self.start

    def count(self, sub, start=None, end=None):
        # type: (GPRRange, int | None, int | None) -> int
        """Return 1 if `sub` is fully inside `self[start:end]`, else 0."""
        r = self.range[start:end]
        if len(r) == 0:
            return 0
        return int(sub in GPRRange(r))

    def conflicts(self, other):
        # type: (RegLoc) -> bool
        """True if `other` is a GPRRange overlapping this one."""
        if isinstance(other, GPRRange):
            return self.stop > other.start and other.stop > self.start
        return False

    def get_subreg_at_offset(self, subreg_type, offset):
        # type: (RegType, int) -> GPRRange
        """Return the sub-range of `subreg_type.length` registers that starts
        `offset` registers into this range.

        Raises ValueError if `subreg_type` isn't a GPR range type or the
        sub-range doesn't fit inside this range.
        """
        if not isinstance(subreg_type, (GPRRangeType, FixedGPRRangeType)):
            raise ValueError(f"subreg_type is not a FixedGPRRangeType or "
                             f"GPRRangeType: {subreg_type}")
        # `offset` is relative to `self.start`, so the bound must be the
        # relative `self.length`, not the absolute `self.stop`.
        if offset < 0 or offset + subreg_type.length > self.length:
            raise ValueError(f"sub-register offset is out of range: {offset}")
        return GPRRange(self.start + offset, subreg_type.length)

    def __repr__(self):
        if self.length == 1:
            return f"<r{self.start}>"
        return f"<r{self.start}..len={self.length}>"
123
124
# GPRs that GPRRangeType's register class never includes: any candidate range
# containing one of these is skipped.
# NOTE(review): presumably these are ABI-reserved registers -- confirm.
SPECIAL_GPRS = GPRRange(0), GPRRange(1), GPRRange(2), GPRRange(13)
126
127
@final
@unique
class XERBit(RegLoc, Enum, metaclass=ABCEnumMeta):
    """XER bits usable as register locations (currently only CA)."""
    CA = "CA"

    def conflicts(self, other):
        # type: (RegLoc) -> bool
        """True only when `other` is the very same XER bit."""
        return isinstance(other, XERBit) and self == other
138
139
@final
@unique
class GlobalMem(RegLoc, Enum, metaclass=ABCEnumMeta):
    """singleton representing all non-StackSlot memory -- treated as a single
    physical register for register allocation purposes.
    """
    GlobalMem = "GlobalMem"

    def conflicts(self, other):
        # type: (RegLoc) -> bool
        """True only when `other` is also the GlobalMem singleton."""
        return isinstance(other, GlobalMem) and self == other
153
154
@final
@unique
class VL(RegLoc, Enum, metaclass=ABCEnumMeta):
    """Register location holding the vector length."""
    VL_MAXVL = "VL_MAXVL"
    """VL and MAXVL"""

    def conflicts(self, other):
        # type: (RegLoc) -> bool
        """True only when `other` is the same VL location."""
        return isinstance(other, VL) and self == other
166
167
@final
class RegClass(OFSet[RegLoc]):
    """Ordered set of register locations.

    Registers that come earlier are preferred by the register allocator.
    """

    @lru_cache(maxsize=None, typed=True)
    def max_conflicts_with(self, other):
        # type: (RegClass | RegLoc) -> int
        """Return the largest number of registers in `self` that any single
        register from `other` can conflict with.

        `other` is either a whole register class (the maximum over its
        members is returned) or a single register location.
        """
        if not isinstance(other, RegClass):
            # single register: count members of `self` it conflicts with
            return sum(map(other.conflicts, self))
        return max(map(self.max_conflicts_with, other))
184
185
@plain_data(frozen=True, unsafe_hash=True)
class RegType(metaclass=ABCMeta):
    """Abstract base class for the type of an SSA value.

    Each concrete type provides the register class its values may be
    allocated from.
    """
    __slots__ = ()

    @property
    @abstractmethod
    def reg_class(self):
        # type: () -> RegClass
        """The `RegClass` of locations allowed for values of this type."""
        return ...
195
196
# type variables for code generic over a register type / register location
_RegType = TypeVar("_RegType", bound=RegType)
_RegLoc = TypeVar("_RegLoc", bound=RegLoc)
199
200
@plain_data(frozen=True, eq=False, repr=False)
@final
class GPRRangeType(RegType):
    """Type of a value occupying `length` consecutive GPRs, allocatable at
    any start that avoids SPECIAL_GPRS.

    Two GPRRangeTypes are equal iff their lengths are equal.
    """
    __slots__ = "length",

    def __init__(self, length=1):
        # type: (int) -> None
        """Raises ValueError unless 1 <= length <= GPR_COUNT."""
        if length < 1 or length > GPR_COUNT:
            raise ValueError("invalid length")
        self.length = length

    @staticmethod
    @lru_cache(maxsize=None)
    def __get_reg_class(length):
        # type: (int) -> RegClass
        """Build (and cache per `length`) the class of every GPRRange of
        `length` registers that contains none of SPECIAL_GPRS.
        """
        regs = []
        # `+ 1`: the last valid start is GPR_COUNT - length, where the range
        # ends exactly at GPR_COUNT (GPRRange accepts start + length ==
        # GPR_COUNT); without it the final register was never allocatable.
        for start in range(GPR_COUNT - length + 1):
            reg = GPRRange(start, length)
            if any(i in reg for i in SPECIAL_GPRS):
                continue
            regs.append(reg)
        return RegClass(regs)

    @property
    @final
    def reg_class(self):
        # type: () -> RegClass
        """The (cached) register class for this length."""
        return GPRRangeType.__get_reg_class(self.length)

    @final
    def __eq__(self, other):
        if isinstance(other, GPRRangeType):
            return self.length == other.length
        return False

    @final
    def __hash__(self):
        return hash(self.length)

    def __repr__(self):
        return f"<gpr_ty[{self.length}]>"
242
243
# NOTE: this is an alias of the GPRRangeType class itself, not an instance;
# GPRRangeType's `length` parameter defaults to 1, so `GPRType()` produces a
# single-register type.
GPRType = GPRRangeType
"""a length=1 GPRRangeType"""
246
247
@plain_data(frozen=True, unsafe_hash=True, repr=False)
@final
class FixedGPRRangeType(RegType):
    """Type of a value that must live in one specific GPRRange, `reg`."""
    __slots__ = "reg",

    def __init__(self, reg):
        # type: (GPRRange) -> None
        self.reg = reg

    @property
    def reg_class(self):
        # type: () -> RegClass
        """A register class whose only member is `self.reg`."""
        only_choice = [self.reg]
        return RegClass(only_choice)

    @property
    def length(self):
        # type: () -> int
        """Number of GPRs in the fixed range."""
        fixed = self.reg
        return fixed.length

    def __repr__(self):
        return f"<fixed({self.reg})>"
269
270
@plain_data(frozen=True, unsafe_hash=True)
@final
class CAType(RegType):
    """Type of a value held in the XER CA bit."""
    __slots__ = ()

    @property
    def reg_class(self):
        # type: () -> RegClass
        """A register class whose only member is `XERBit.CA`."""
        ca_bit = XERBit.CA
        return RegClass([ca_bit])
280
281
@plain_data(frozen=True, unsafe_hash=True)
@final
class GlobalMemType(RegType):
    """Type of a value representing global memory."""
    __slots__ = ()

    @property
    def reg_class(self):
        # type: () -> RegClass
        """A register class whose only member is `GlobalMem.GlobalMem`."""
        mem = GlobalMem.GlobalMem
        return RegClass([mem])
291
292
@plain_data(frozen=True, unsafe_hash=True)
@final
class KnownVLType(RegType):
    """Type of a VL value whose vector length is known to be `length`."""
    __slots__ = "length",

    def __init__(self, length):
        # type: (int) -> None
        """Raises ValueError unless 0 < length <= 64."""
        in_range = 0 < length <= 64
        if not in_range:
            raise ValueError("invalid VL value")
        self.length = length

    @property
    def reg_class(self):
        # type: () -> RegClass
        """A register class whose only member is `VL.VL_MAXVL`."""
        vl_reg = VL.VL_MAXVL
        return RegClass([vl_reg])
308
309
def assert_vl_is(vl, expected_vl):
    # type: (SSAKnownVL | KnownVLType | int | None, int) -> None
    """Check that `vl` describes a vector length of `expected_vl`.

    `vl` may be None (treated as a length of 1), an SSAVal (its type's
    `length` is used), a KnownVLType (its `length` is used) or the length
    itself. Raises ValueError on mismatch.
    """
    if vl is None:
        actual_vl = 1
    elif isinstance(vl, SSAVal):
        actual_vl = vl.ty.length
    elif isinstance(vl, KnownVLType):
        actual_vl = vl.length
    else:
        actual_vl = vl
    if actual_vl != expected_vl:
        raise ValueError(
            f"wrong VL: expected {expected_vl} got {actual_vl}")
321
322
# multiplier used by StackSlot.start_byte to turn a slot index into a byte
# offset
STACK_SLOT_SIZE = 8
324
325
@plain_data(frozen=True, unsafe_hash=True)
@final
class StackSlot(RegLoc):
    """A contiguous run of `length_in_slots` stack slots starting at
    `start_slot`.
    """
    __slots__ = "start_slot", "length_in_slots",

    def __init__(self, start_slot, length_in_slots):
        # type: (int, int) -> None
        """Raises ValueError if `length_in_slots` is less than 1."""
        self.start_slot = start_slot
        if length_in_slots < 1:
            raise ValueError("invalid length_in_slots")
        self.length_in_slots = length_in_slots

    @property
    def stop_slot(self):
        """One past the last slot index covered."""
        return self.start_slot + self.length_in_slots

    @property
    def start_byte(self):
        """Byte offset of the first slot (`start_slot * STACK_SLOT_SIZE`)."""
        return self.start_slot * STACK_SLOT_SIZE

    def conflicts(self, other):
        # type: (RegLoc) -> bool
        """True if `other` is a StackSlot overlapping this one."""
        if isinstance(other, StackSlot):
            return (self.stop_slot > other.start_slot
                    and other.stop_slot > self.start_slot)
        return False

    def get_subreg_at_offset(self, subreg_type, offset):
        # type: (RegType, int) -> StackSlot
        """Return the run of `subreg_type.length_in_slots` slots that starts
        `offset` slots into this one.

        Raises ValueError if `subreg_type` isn't a StackSlotType or the
        requested run doesn't fit inside this one.
        """
        if not isinstance(subreg_type, StackSlotType):
            raise ValueError(f"subreg_type is not a "
                             f"StackSlotType: {subreg_type}")
        # `offset` is relative to `self.start_slot`, so the bound must be the
        # relative `self.length_in_slots`, not the absolute `self.stop_slot`.
        if (offset < 0
                or offset + subreg_type.length_in_slots
                > self.length_in_slots):
            raise ValueError(f"sub-register offset is out of range: {offset}")
        return StackSlot(self.start_slot + offset, subreg_type.length_in_slots)
361
362
# bounds the start slots StackSlotType enumerates for its register class
STACK_SLOT_COUNT = 128
364
365
@plain_data(frozen=True, eq=False)
@final
class StackSlotType(RegType):
    """Type of a value occupying `length_in_slots` consecutive stack slots.

    Two StackSlotTypes are equal iff their lengths are equal.
    """
    __slots__ = "length_in_slots",

    def __init__(self, length_in_slots=1):
        # type: (int) -> None
        """Raises ValueError if `length_in_slots` is less than 1."""
        if length_in_slots < 1:
            raise ValueError("invalid length_in_slots")
        self.length_in_slots = length_in_slots

    @staticmethod
    @lru_cache(maxsize=None)
    def __get_reg_class(length_in_slots):
        # type: (int) -> RegClass
        """Build (and cache per length) the class of every StackSlot of
        `length_in_slots` slots that fits within STACK_SLOT_COUNT slots.
        """
        regs = []
        # `+ 1`: the last valid start is STACK_SLOT_COUNT - length_in_slots,
        # where the run ends exactly at STACK_SLOT_COUNT; without it the
        # final slot could never be used.
        for start in range(STACK_SLOT_COUNT - length_in_slots + 1):
            reg = StackSlot(start, length_in_slots)
            regs.append(reg)
        return RegClass(regs)

    @property
    def reg_class(self):
        # type: () -> RegClass
        """The (cached) register class for this length."""
        return StackSlotType.__get_reg_class(self.length_in_slots)

    @final
    def __eq__(self, other):
        if isinstance(other, StackSlotType):
            return self.length_in_slots == other.length_in_slots
        return False

    @final
    def __hash__(self):
        return hash(self.length_in_slots)
401
402
@plain_data(frozen=True, eq=False, repr=False)
@final
class SSAVal(Generic[_RegType]):
    """A value in SSA form, identified by the Op that writes it and the name
    of the argument it is written through.

    Equality and hashing use the identity of `op` plus `arg_name`; `ty` does
    not take part.
    """
    __slots__ = "op", "arg_name", "ty",

    def __init__(self, op, arg_name, ty):
        # type: (Op, str, _RegType) -> None
        # the Op that writes this SSAVal
        self.op = op
        # the name of the argument of self.op that writes this SSAVal
        self.arg_name = arg_name
        self.ty = ty

    def __eq__(self, rhs):
        if not isinstance(rhs, SSAVal):
            return False
        same_op = self.op is rhs.op
        return same_op and self.arg_name == rhs.arg_name

    def __hash__(self):
        # id() rather than hash(): ops are compared by identity in __eq__
        key = id(self.op), self.arg_name
        return hash(key)

    def __repr__(self):
        op_id = self.op.id
        return f"<#{op_id}.{self.arg_name}: {self.ty}>"
429
430
# shorthand aliases for SSAVal parameterized by commonly used register types
SSAGPRRange = SSAVal[GPRRangeType]
SSAGPR = SSAVal[GPRType]
SSAKnownVL = SSAVal[KnownVLType]
434
435
@final
@plain_data(unsafe_hash=True, frozen=True)
class EqualityConstraint:
    """Constraint relating the SSA values in `lhs` to those in `rhs`.

    Neither side may be empty.
    """
    __slots__ = "lhs", "rhs"

    def __init__(self, lhs, rhs):
        # type: (list[SSAVal], list[SSAVal]) -> None
        """Raises ValueError if `lhs` or `rhs` is empty."""
        self.lhs = lhs
        self.rhs = rhs
        both_non_empty = len(lhs) != 0 and len(rhs) != 0
        if not both_non_empty:
            raise ValueError("can't constrain an empty list to be equal")
447
448
@final
class Fn:
    """A function: an ordered list of Ops (each Op appends itself on
    construction).
    """
    __slots__ = "ops",

    def __init__(self):
        # type: () -> None
        self.ops = []  # type: list[Op]

    def __repr__(self, short=False):
        """Return `<Fn>` when `short`, else list the ids of all ops."""
        if not short:
            op_ids = [op.__repr__(just_id=True) for op in self.ops]
            ops = ", ".join(op_ids)
            return f"<Fn([{ops}])>"
        return "<Fn>"

    def pre_ra_sim(self, state):
        # type: (PreRASimState) -> None
        """Simulate every op in order, updating `state` in place."""
        for current_op in self.ops:
            current_op.pre_ra_sim(state)
467
468
469 class _NotSet:
470 """ helper for __repr__ for when fields aren't set """
471
472 def __repr__(self):
473 return "<not set>"
474
475
# shared placeholder instance, used as the getattr() default in Op.__repr__
_NOT_SET = _NotSet()
477
478
@final
class AsmContext:
    """Looks up the register assigned to each SSAVal and formats register
    operands while generating assembly.
    """

    def __init__(self, assigned_registers):
        # type: (dict[SSAVal, RegLoc]) -> None
        self.__assigned_registers = assigned_registers

    def reg(self, ssa_val, expected_ty):
        # type: (SSAVal[Any], Type[_RegLoc]) -> _RegLoc
        """Return the register assigned to `ssa_val`.

        Raises ValueError (chained from the KeyError) if `ssa_val` has no
        assigned register, and TypeError if the register isn't an instance
        of `expected_ty` or is a GPRRange whose length differs from
        `ssa_val.ty.length`.
        """
        try:
            reg = self.__assigned_registers[ssa_val]
        except KeyError as e:
            raise ValueError(
                f"SSAVal not assigned a register: {ssa_val}") from e
        wrong_len = (isinstance(reg, GPRRange)
                     and reg.length != ssa_val.ty.length)
        if not isinstance(reg, expected_ty) or wrong_len:
            raise TypeError(
                f"SSAVal is assigned a register of the wrong type: "
                f"ssa_val={ssa_val} expected_ty={expected_ty} reg={reg}")
        return reg

    def gpr_range(self, ssa_val):
        # type: (SSAGPRRange | SSAVal[FixedGPRRangeType]) -> GPRRange
        """Return the GPRRange assigned to `ssa_val`."""
        return self.reg(ssa_val, GPRRange)

    def stack_slot(self, ssa_val):
        # type: (SSAVal[StackSlotType]) -> StackSlot
        """Return the StackSlot assigned to `ssa_val`."""
        return self.reg(ssa_val, StackSlot)

    def gpr(self, ssa_val, vec, offset=0):
        # type: (SSAGPRRange | SSAVal[FixedGPRRangeType], bool, int) -> str
        """Format the operand for the register `offset` past the start of
        `ssa_val`'s range; a `*` prefix is added when `vec` is true.
        """
        reg = self.gpr_range(ssa_val).start + offset
        return "*" * vec + str(reg)

    def vgpr(self, ssa_val, offset=0):
        # type: (SSAGPRRange | SSAVal[FixedGPRRangeType], int) -> str
        """Same as `gpr` with `vec=True`."""
        return self.gpr(ssa_val=ssa_val, vec=True, offset=offset)

    def sgpr(self, ssa_val, offset=0):
        # type: (SSAGPR | SSAVal[FixedGPRRangeType], int) -> str
        """Same as `gpr` with `vec=False`."""
        return self.gpr(ssa_val=ssa_val, vec=False, offset=offset)

    def needs_sv(self, *regs):
        # type: (*SSAGPRRange | SSAVal[FixedGPRRangeType]) -> bool
        """Return True if any of `regs` is assigned a range that is longer
        than one register or starts at register 32 or above.
        """
        assigned = map(self.gpr_range, regs)
        return any(r.length != 1 or r.start >= 32 for r in assigned)
527
528
# width of one GPR in bytes and bits, and the mask used by the pre-RA
# simulation to truncate a Python int to that width
GPR_SIZE_IN_BYTES = 8
GPR_SIZE_IN_BITS = GPR_SIZE_IN_BYTES * 8
GPR_VALUE_MASK = (1 << GPR_SIZE_IN_BITS) - 1
532
533
@plain_data(frozen=True)
@final
class PreRASimState:
    """State of the pre-register-allocation simulation.

    Holds one dict per kind of SSA value, each mapping an SSAVal to its
    simulated contents. Ops mutate these dicts in `pre_ra_sim`.
    """
    __slots__ = ("gprs", "VLs", "CAs",
                 "global_mems", "stack_slots",
                 "fixed_gprs")

    def __init__(
        self,
        gprs,  # type: dict[SSAGPRRange, tuple[int, ...]]
        VLs,  # type: dict[SSAKnownVL, int]
        CAs,  # type: dict[SSAVal[CAType], bool]
        global_mems,  # type: dict[SSAVal[GlobalMemType], FMap[int, int]]
        stack_slots,  # type: dict[SSAVal[StackSlotType], tuple[int, ...]]
        fixed_gprs,  # type: dict[SSAVal[FixedGPRRangeType], tuple[int, ...]]
    ):
        # type: (...) -> None
        # register-like state
        self.gprs = gprs
        self.VLs = VLs
        self.CAs = CAs
        # memory-like state
        self.global_mems = global_mems
        self.stack_slots = stack_slots
        # values pinned to a fixed GPR range
        self.fixed_gprs = fixed_gprs
557
558
@plain_data(unsafe_hash=True, frozen=True, repr=False)
class Op(metaclass=ABCMeta):
    """Abstract base class for operations in a `Fn`.

    Constructing an Op appends it to `fn.ops`; `id` is its index there.
    """
    __slots__ = "id", "fn"

    @abstractmethod
    def inputs(self):
        # type: () -> dict[str, SSAVal]
        """Map of argument name to the SSAVal read through it."""
        ...

    @abstractmethod
    def outputs(self):
        # type: () -> dict[str, SSAVal]
        """Map of argument name to the SSAVal written through it."""
        ...

    def get_equality_constraints(self):
        # type: () -> Iterable[EqualityConstraint]
        """Yield this op's equality constraints -- none by default."""
        yield from ()

    def get_extra_interferences(self):
        # type: () -> Iterable[tuple[SSAVal, SSAVal]]
        """Yield extra pairs of interfering SSAVals -- none by default."""
        yield from ()

    def __init__(self, fn):
        # type: (Fn) -> None
        # the id is the position this op is about to take in fn.ops
        self.id = len(fn.ops)
        fn.ops.append(self)
        self.fn = fn

    @final
    def __repr__(self, just_id=False):
        """Return `ClassName(#id, ...)`; fields are omitted if `just_id`.

        A field that is one of this op's outputs is shown as the bare
        SSAVal repr; any other field is shown as `name=value`.
        """
        parts = [f"#{self.id}"]
        try:
            outputs = self.outputs()
        except AttributeError:
            # presumably some fields aren't assigned yet, e.g. when repr()
            # runs before a subclass's __init__ has finished
            outputs = None
        if not just_id:
            for name in fields(self):
                if name == "id" or name == "fn":
                    continue
                value = getattr(self, name, _NOT_SET)
                is_output = (outputs is not None and name in outputs
                             and outputs[name] is value)
                if is_output:
                    parts.append(repr(value))
                else:
                    parts.append(f"{name}={value!r}")
        joined = ', '.join(parts)
        return f"{self.__class__.__name__}({joined})"

    @abstractmethod
    def get_asm_lines(self, ctx):
        # type: (AsmContext) -> list[str]
        """get the lines of assembly for this Op"""
        ...

    @abstractmethod
    def pre_ra_sim(self, state):
        # type: (PreRASimState) -> None
        """simulate op before register allocation"""
        ...
621
622
@plain_data(unsafe_hash=True, frozen=True, repr=False)
@final
class OpLoadFromStackSlot(Op):
    """Load the stack slot(s) `src` into a new GPR range `dest`."""
    __slots__ = "dest", "src", "vl"

    def inputs(self):
        # type: () -> dict[str, SSAVal]
        if self.vl is None:
            return {"src": self.src}
        return {"src": self.src, "vl": self.vl}

    def outputs(self):
        # type: () -> dict[str, SSAVal]
        return {"dest": self.dest}

    def __init__(self, fn, src, vl=None):
        # type: (Fn, SSAVal[StackSlotType], SSAKnownVL | None) -> None
        """`vl` must match the number of slots in `src` (None means 1)."""
        super().__init__(fn)
        dest_ty = GPRRangeType(src.ty.length_in_slots)
        self.dest = SSAVal(self, "dest", dest_ty)
        self.src = src
        self.vl = vl
        assert_vl_is(vl, dest_ty.length)

    def get_asm_lines(self, ctx):
        # type: (AsmContext) -> list[str]
        """Emit a single `ld`, or `sv.ld` when `dest` needs SVP64."""
        is_vec = self.dest.ty.length != 1
        dest = ctx.gpr(self.dest, vec=is_vec)
        src = ctx.stack_slot(self.src)
        # the slot is addressed relative to r1
        if not ctx.needs_sv(self.dest):
            return [f"ld {dest}, {src.start_byte}(1)"]
        return [f"sv.ld {dest}, {src.start_byte}(1)"]

    def pre_ra_sim(self, state):
        # type: (PreRASimState) -> None
        """simulate op before register allocation"""
        loaded = state.stack_slots[self.src]
        state.gprs[self.dest] = loaded
659
660
@plain_data(unsafe_hash=True, frozen=True, repr=False)
@final
class OpStoreToStackSlot(Op):
    """Store the GPR range `src` into new stack slot(s) `dest`."""
    __slots__ = "dest", "src", "vl"

    def inputs(self):
        # type: () -> dict[str, SSAVal]
        if self.vl is None:
            return {"src": self.src}
        return {"src": self.src, "vl": self.vl}

    def outputs(self):
        # type: () -> dict[str, SSAVal]
        return {"dest": self.dest}

    def __init__(self, fn, src, vl=None):
        # type: (Fn, SSAGPRRange, SSAKnownVL | None) -> None
        """`vl` must match the length of `src` (None means 1)."""
        super().__init__(fn)
        src_len = src.ty.length
        self.dest = SSAVal(self, "dest", StackSlotType(src_len))
        self.src = src
        self.vl = vl
        assert_vl_is(vl, src.ty.length)

    def get_asm_lines(self, ctx):
        # type: (AsmContext) -> list[str]
        """Emit a single `std`, or `sv.std` when `src` needs SVP64."""
        is_vec = self.src.ty.length != 1
        src = ctx.gpr(self.src, vec=is_vec)
        dest = ctx.stack_slot(self.dest)
        # the slot is addressed relative to r1
        if not ctx.needs_sv(self.src):
            return [f"std {src}, {dest.start_byte}(1)"]
        return [f"sv.std {src}, {dest.start_byte}(1)"]

    def pre_ra_sim(self, state):
        # type: (PreRASimState) -> None
        """simulate op before register allocation"""
        stored = state.gprs[self.src]
        state.stack_slots[self.dest] = stored
697
698
# type variable for the source register type of OpCopy
_RegSrcType = TypeVar("_RegSrcType", bound=RegType)
700
701
@plain_data(unsafe_hash=True, frozen=True, repr=False)
@final
class OpCopy(Op, Generic[_RegSrcType, _RegType]):
    """Copy `src` into a new SSAVal `dest`.

    The types must be equal, except that a GPRRangeType may be copied to or
    from a FixedGPRRangeType of the same length. Stack slots can't be
    copied with this op.
    """
    __slots__ = "dest", "src", "vl"

    def inputs(self):
        # type: () -> dict[str, SSAVal]
        retval = {"src": self.src}  # type: dict[str, SSAVal[Any]]
        if self.vl is not None:
            retval["vl"] = self.vl
        return retval

    def outputs(self):
        # type: () -> dict[str, SSAVal]
        return {"dest": self.dest}

    def __init__(self, fn, src, dest_ty=None, vl=None):
        # type: (Fn, SSAVal[_RegSrcType], _RegType | None, SSAKnownVL | None) -> None
        """`dest_ty` defaults to `src.ty`.

        Raises ValueError if the types are incompatible, if `src` is a stack
        slot, or if `vl` doesn't match the copied length (1 for non-GPR
        types).
        """
        super().__init__(fn)
        if dest_ty is None:
            dest_ty = cast(_RegType, src.ty)
        if isinstance(src.ty, GPRRangeType) \
                and isinstance(dest_ty, FixedGPRRangeType):
            if src.ty.length != dest_ty.reg.length:
                raise ValueError(f"incompatible source and destination "
                                 f"types: {src.ty} and {dest_ty}")
            length = src.ty.length
        elif isinstance(src.ty, FixedGPRRangeType) \
                and isinstance(dest_ty, GPRRangeType):
            if src.ty.reg.length != dest_ty.length:
                raise ValueError(f"incompatible source and destination "
                                 f"types: {src.ty} and {dest_ty}")
            length = src.ty.length
        elif src.ty != dest_ty:
            raise ValueError(f"incompatible source and destination "
                             f"types: {src.ty} and {dest_ty}")
        elif isinstance(src.ty, StackSlotType):
            raise ValueError("can't use OpCopy on stack slots")
        elif isinstance(src.ty, (GPRRangeType, FixedGPRRangeType)):
            length = src.ty.length
        else:
            length = 1

        self.dest = SSAVal(self, "dest", dest_ty)  # type: SSAVal[_RegType]
        self.src = src
        self.vl = vl
        assert_vl_is(vl, length)

    def get_asm_lines(self, ctx):
        # type: (AsmContext) -> list[str]
        """Emit nothing if `src` and `dest` share a register, otherwise an
        `or`/`sv.or` register move.

        Raises NotImplementedError for non-GPR copies that need an actual
        move.
        """
        if ctx.reg(self.src, RegLoc) == ctx.reg(self.dest, RegLoc):
            return []
        if (isinstance(self.src.ty, (GPRRangeType, FixedGPRRangeType)) and
                isinstance(self.dest.ty, (GPRRangeType, FixedGPRRangeType))):
            vec = self.dest.ty.length != 1
            dest = ctx.gpr_range(self.dest)  # type: ignore
            src = ctx.gpr_range(self.src)  # type: ignore
            dest_s = ctx.gpr(self.dest, vec=vec)  # type: ignore
            src_s = ctx.gpr(self.src, vec=vec)  # type: ignore
            mrr = ""
            # Overlapping copy: when dest starts above src, a first-to-last
            # element copy would overwrite source registers that haven't
            # been read yet, so the copy has to run last-to-first. When dest
            # starts below src the first-to-last order is already safe.
            # NOTE(review): assumes `/mrr` makes the element loop run from
            # the last element to the first -- confirm against the SVP64
            # spec.
            if src.conflicts(dest) and src.start < dest.start:
                mrr = "/mrr"
            if ctx.needs_sv(self.src, self.dest):  # type: ignore
                return [f"sv.or{mrr} {dest_s}, {src_s}, {src_s}"]
            return [f"or {dest_s}, {src_s}, {src_s}"]
        raise NotImplementedError

    def pre_ra_sim(self, state):
        # type: (PreRASimState) -> None
        """Simulate the copy before register allocation.

        Raises NotImplementedError for unsupported type combinations.
        """
        if (isinstance(self.src.ty, (GPRRangeType, FixedGPRRangeType)) and
                isinstance(self.dest.ty, (GPRRangeType, FixedGPRRangeType))):
            # covers every combination of plain and fixed GPR ranges
            if isinstance(self.src.ty, GPRRangeType):
                v = state.gprs[self.src]  # type: ignore
            else:
                v = state.fixed_gprs[self.src]  # type: ignore
            if isinstance(self.dest.ty, GPRRangeType):
                state.gprs[self.dest] = v  # type: ignore
            else:
                state.fixed_gprs[self.dest] = v  # type: ignore
        elif (isinstance(self.src.ty, CAType) and
                self.src.ty == self.dest.ty):
            state.CAs[self.dest] = state.CAs[self.src]  # type: ignore
        elif (isinstance(self.src.ty, KnownVLType) and
                self.src.ty == self.dest.ty):
            state.VLs[self.dest] = state.VLs[self.src]  # type: ignore
        elif (isinstance(self.src.ty, GlobalMemType) and
                self.src.ty == self.dest.ty):
            v = state.global_mems[self.src]  # type: ignore
            state.global_mems[self.dest] = v  # type: ignore
        else:
            raise NotImplementedError
799
800
@plain_data(unsafe_hash=True, frozen=True, repr=False)
@final
class OpConcat(Op):
    """Concatenate several GPR ranges into one range `dest`.

    Emits no assembly; it yields an equality constraint between `dest` and
    the sources instead.
    """
    __slots__ = "dest", "sources"

    def inputs(self):
        # type: () -> dict[str, SSAVal]
        named = {}  # type: dict[str, SSAVal]
        for i, v in enumerate(self.sources):
            named[f"sources[{i}]"] = v
        return named

    def outputs(self):
        # type: () -> dict[str, SSAVal]
        return {"dest": self.dest}

    def __init__(self, fn, sources):
        # type: (Fn, Iterable[SSAGPRRange]) -> None
        super().__init__(fn)
        sources = tuple(sources)
        total_length = sum(i.ty.length for i in sources)
        self.dest = SSAVal(self, "dest", GPRRangeType(total_length))
        self.sources = sources

    def get_equality_constraints(self):
        # type: () -> Iterable[EqualityConstraint]
        yield EqualityConstraint([self.dest], list(self.sources))

    def get_asm_lines(self, ctx):
        # type: (AsmContext) -> list[str]
        return []

    def pre_ra_sim(self, state):
        # type: (PreRASimState) -> None
        """Set `dest` to the words of all sources, in order."""
        state.gprs[self.dest] = tuple(
            word for src in self.sources for word in state.gprs[src])
836
837
@plain_data(unsafe_hash=True, frozen=True, repr=False)
@final
class OpSplit(Op):
    """Split the GPR range `src` into consecutive pieces `results`.

    Emits no assembly; it yields an equality constraint between the results
    and `src` instead.
    """
    __slots__ = "results", "src"

    def inputs(self):
        # type: () -> dict[str, SSAVal]
        return {"src": self.src}

    def outputs(self):
        # type: () -> dict[str, SSAVal]
        named = {}  # type: dict[str, SSAVal]
        for piece in self.results:
            named[piece.arg_name] = piece
        return named

    def __init__(self, fn, src, split_indexes):
        # type: (Fn, SSAGPRRange, Iterable[int]) -> None
        """`split_indexes` are the positions in `src` where a new piece
        starts; each must satisfy 0 < i < src.ty.length, otherwise
        ValueError is raised.
        """
        super().__init__(fn)
        src_len = src.ty.length
        piece_types = []  # type: list[GPRRangeType]
        prev = 0
        for idx in split_indexes:
            if idx <= 0 or idx >= src_len:
                raise ValueError(f"invalid split index: {idx}, must be in "
                                 f"0 < i < {src_len}")
            # GPRRangeType rejects a length below 1, which is what a
            # non-increasing index produces here
            piece_types.append(GPRRangeType(idx - prev))
            prev = idx
        # whatever remains after the last split index
        piece_types.append(GPRRangeType(src.ty.length - prev))
        self.src = src
        self.results = tuple(
            SSAVal(self, f"results[{n}]", ty)
            for n, ty in enumerate(piece_types))

    def get_equality_constraints(self):
        # type: () -> Iterable[EqualityConstraint]
        yield EqualityConstraint(list(self.results), [self.src])

    def get_asm_lines(self, ctx):
        # type: (AsmContext) -> list[str]
        return []

    def pre_ra_sim(self, state):
        # type: (PreRASimState) -> None
        """Hand out the words of `src` to the results, working from the
        last result back to the first.
        """
        remaining = state.gprs[self.src]
        for dest in reversed(self.results):
            count = dest.ty.length
            state.gprs[dest] = remaining[-count:]
            remaining = remaining[:-count]
881
882
@plain_data(unsafe_hash=True, frozen=True, repr=False)
@final
class OpBigIntAddSub(Op):
    """Multi-word add (`adde`) or subtract (`subfe`) with carry in and
    carry out.
    """
    __slots__ = "out", "lhs", "rhs", "CA_in", "CA_out", "is_sub", "vl"

    def inputs(self):
        # type: () -> dict[str, SSAVal]
        retval = {
            "lhs": self.lhs,
            "rhs": self.rhs,
            "CA_in": self.CA_in,
        }  # type: dict[str, SSAVal[Any]]
        if self.vl is not None:
            retval["vl"] = self.vl
        return retval

    def outputs(self):
        # type: () -> dict[str, SSAVal]
        return {"out": self.out, "CA_out": self.CA_out}

    def __init__(self, fn, lhs, rhs, CA_in, is_sub, vl=None):
        # type: (Fn, SSAGPRRange, SSAGPRRange, SSAVal[CAType], bool, SSAKnownVL | None) -> None
        """Raises TypeError if `lhs` and `rhs` differ in type, ValueError
        if `vl` doesn't match their length.
        """
        super().__init__(fn)
        if lhs.ty != rhs.ty:
            raise TypeError(f"source types must match: "
                            f"{lhs} doesn't match {rhs}")
        self.out = SSAVal(self, "out", lhs.ty)
        self.lhs = lhs
        self.rhs = rhs
        self.CA_in = CA_in
        self.CA_out = SSAVal(self, "CA_out", CA_in.ty)
        self.is_sub = is_sub
        self.vl = vl
        assert_vl_is(vl, lhs.ty.length)

    def get_extra_interferences(self):
        # type: () -> Iterable[tuple[SSAVal, SSAVal]]
        for operand in (self.lhs, self.rhs):
            yield self.out, operand

    def get_asm_lines(self, ctx):
        # type: (AsmContext) -> list[str]
        vec = self.out.ty.length != 1
        out = ctx.gpr(self.out, vec=vec)
        lhs_s = ctx.gpr(self.lhs, vec=vec)
        rhs_s = ctx.gpr(self.rhs, vec=vec)
        if self.is_sub:
            # operands are swapped to match subfe's operand order
            mnemonic, RA, RB = "subfe", rhs_s, lhs_s
        else:
            mnemonic, RA, RB = "adde", lhs_s, rhs_s
        if not ctx.needs_sv(self.out, self.lhs, self.rhs):
            return [f"{mnemonic} {out}, {RA}, {RB}"]
        return [f"sv.{mnemonic} {out}, {RA}, {RB}"]

    def pre_ra_sim(self, state):
        # type: (PreRASimState) -> None
        """Simulate word by word, propagating the carry between words.

        Subtraction is done by adding the bitwise complement of `rhs`.
        """
        carry = state.CAs[self.CA_in]
        words = []  # type: list[int]
        for l, r in zip(state.gprs[self.lhs], state.gprs[self.rhs]):
            if self.is_sub:
                r = r ^ GPR_VALUE_MASK
            total = l + r + carry
            truncated = total & GPR_VALUE_MASK
            # carry out iff the sum didn't fit in one GPR
            carry = total != truncated
            words.append(truncated)
        state.CAs[self.CA_out] = carry
        state.gprs[self.out] = tuple(words)
948
949
@plain_data(unsafe_hash=True, frozen=True, repr=False)
@final
class OpBigIntMulDiv(Op):
    """Multiply (`sv.maddedu`) or divide (`sv.divmod2du/mrr`) the multi-word
    `RA` by the single word `RB`, with `RC` as carry in and `RS` as carry
    out.
    """
    __slots__ = "RT", "RA", "RB", "RC", "RS", "is_div", "vl"

    def inputs(self):
        # type: () -> dict[str, SSAVal]
        retval = {
            "RA": self.RA,
            "RB": self.RB,
            "RC": self.RC,
        }  # type: dict[str, SSAVal[Any]]
        if self.vl is not None:
            retval["vl"] = self.vl
        return retval

    def outputs(self):
        # type: () -> dict[str, SSAVal]
        return {"RT": self.RT, "RS": self.RS}

    def __init__(self, fn, RA, RB, RC, is_div, vl):
        # type: (Fn, SSAGPRRange, SSAGPR, SSAGPR, bool, SSAKnownVL | None) -> None
        """Raises ValueError if `vl` doesn't match the length of `RA`."""
        super().__init__(fn)
        self.RT = SSAVal(self, "RT", RA.ty)
        self.RA = RA
        self.RB = RB
        self.RC = RC
        self.RS = SSAVal(self, "RS", RC.ty)
        self.is_div = is_div
        self.vl = vl
        assert_vl_is(vl, RA.ty.length)

    def get_equality_constraints(self):
        # type: () -> Iterable[EqualityConstraint]
        yield EqualityConstraint([self.RC], [self.RS])

    def get_extra_interferences(self):
        # type: () -> Iterable[tuple[SSAVal, SSAVal]]
        for other in (self.RA, self.RB, self.RC, self.RS):
            yield self.RT, other
        for other in (self.RA, self.RB):
            yield self.RS, other

    def get_asm_lines(self, ctx):
        # type: (AsmContext) -> list[str]
        vec = self.RT.ty.length != 1
        RT = ctx.gpr(self.RT, vec=vec)
        RA = ctx.gpr(self.RA, vec=vec)
        RB = ctx.sgpr(self.RB)
        RC = ctx.sgpr(self.RC)
        mnemonic = "divmod2du/mrr" if self.is_div else "maddedu"
        return [f"sv.{mnemonic} {RT}, {RA}, {RB}, {RC}"]

    def pre_ra_sim(self, state):
        # type: (PreRASimState) -> None
        """Simulate the multiply or divide one word at a time.

        Division walks from the last word to the first, with the remainder
        carried into the next step; when the carry isn't below the divisor
        the word is set to all ones and the carry to 0. Multiplication walks
        from the first word to the last, carrying the high half.
        """
        carry = state.gprs[self.RC][0]
        ra_words = state.gprs[self.RA]
        rb = state.gprs[self.RB][0]
        length = self.RT.ty.length
        rt_words = [0] * length
        if self.is_div:
            for i in reversed(range(length)):
                if carry >= rb or rb == 0:
                    # quotient wouldn't fit in one word (or divide by zero)
                    rt_words[i] = GPR_VALUE_MASK
                    carry = 0
                else:
                    dividend = (carry << 64) | ra_words[i]
                    quotient, remainder = divmod(dividend, rb)
                    rt_words[i] = quotient & GPR_VALUE_MASK
                    carry = remainder & GPR_VALUE_MASK
        else:
            for i in range(length):
                product = ra_words[i] * rb + carry
                carry = product >> 64
                rt_words[i] = product & GPR_VALUE_MASK
        state.gprs[self.RS] = (carry,)
        state.gprs[self.RT] = tuple(rt_words)
1028
1029
@final
@unique
class ShiftKind(Enum):
    """The kind of a shift op; also builds the ops for a big-int shift.

    `Sl` and `Sr` use a carry-in of zero, `Sra` derives its carry-in from
    the last word of the input.
    """
    Sl = "sl"
    Sr = "sr"
    Sra = "sra"

    def make_big_int_carry_in(self, fn, inp):
        # type: (Fn, SSAGPRRange) -> tuple[SSAGPR, list[Op]]
        """Create the ops computing the carry-in for a big-int shift of
        `inp`; returns the carry-in value and the list of new ops.
        """
        if self in (ShiftKind.Sl, ShiftKind.Sr):
            zero = OpLI(fn, 0)
            return zero.out, [zero]
        assert self is ShiftKind.Sra
        # carry-in is the last word of `inp` shifted right by 63 (Sra)
        last_index = inp.ty.length - 1
        split = OpSplit(fn, inp, [last_index])
        last_word = split.results[1]
        shr = OpShiftImm(fn, last_word, sh=63, kind=ShiftKind.Sra)
        return shr.out, [split, shr]

    def make_big_int_shift(self, fn, inp, sh, vl):
        # type: (Fn, SSAGPRRange, SSAGPR, SSAKnownVL | None) -> tuple[SSAGPRRange, list[Op]]
        """Create the ops shifting `inp` by `sh`; returns the shifted value
        and the list of all ops created (carry-in ops first).
        """
        carry_in, created = self.make_big_int_carry_in(fn, inp)
        shift_op = OpBigIntShift(fn, inp, sh, carry_in, kind=self, vl=vl)
        created.append(shift_op)
        return shift_op.out, created
1054
1055
@plain_data(unsafe_hash=True, frozen=True, repr=False)
@final
class OpBigIntShift(Op):
    """Shift the multi-word `inp` by `sh` bits, shifting in bits from
    `carry_in`.

    `_out_padding` is a one-register output that exists only for the
    equality constraint; the simulation never writes it.
    """
    __slots__ = "out", "inp", "carry_in", "_out_padding", "sh", "kind", "vl"

    def inputs(self):
        # type: () -> dict[str, SSAVal]
        retval = {
            "inp": self.inp,
            "sh": self.sh,
            "carry_in": self.carry_in,
        }  # type: dict[str, SSAVal[Any]]
        if self.vl is not None:
            retval["vl"] = self.vl
        return retval

    def outputs(self):
        # type: () -> dict[str, SSAVal]
        return {"out": self.out, "_out_padding": self._out_padding}

    def __init__(self, fn, inp, sh, carry_in, kind, vl=None):
        # type: (Fn, SSAGPRRange, SSAGPR, SSAGPR, ShiftKind, SSAKnownVL | None) -> None
        """Raises ValueError if `vl` doesn't match the length of `inp`."""
        super().__init__(fn)
        self.out = SSAVal(self, "out", inp.ty)
        self._out_padding = SSAVal(self, "_out_padding", GPRRangeType())
        self.carry_in = carry_in
        self.inp = inp
        self.sh = sh
        self.kind = kind
        self.vl = vl
        assert_vl_is(vl, inp.ty.length)

    def get_extra_interferences(self):
        # type: () -> Iterable[tuple[SSAVal, SSAVal]]
        yield self.out, self.sh

    def get_equality_constraints(self):
        # type: () -> Iterable[EqualityConstraint]
        if self.kind is not ShiftKind.Sl:
            assert self.kind is ShiftKind.Sr or self.kind is ShiftKind.Sra
            yield EqualityConstraint([self.inp, self.carry_in],
                                     [self._out_padding, self.out])
        else:
            yield EqualityConstraint([self.carry_in, self.inp],
                                     [self.out, self._out_padding])

    def get_asm_lines(self, ctx):
        # type: (AsmContext) -> list[str]
        vec = self.out.ty.length != 1
        if self.kind is not ShiftKind.Sl:
            assert self.kind is ShiftKind.Sr or self.kind is ShiftKind.Sra
            RT = ctx.gpr(self.out, vec=vec)
            # RA is the register just after the start of `out`
            RA = ctx.gpr(self.out, vec=vec, offset=1)
            RB = ctx.sgpr(self.sh)
            return [f"sv.dsrd {RT}, {RA}, {RB}, 1"]
        RT = ctx.gpr(self.out, vec=vec)
        # RA is the register just before the start of `out`
        RA = ctx.gpr(self.out, vec=vec, offset=-1)
        RB = ctx.sgpr(self.sh)
        mrr = "/mrr" if vec else ""
        return [f"sv.dsld{mrr} {RT}, {RA}, {RB}, 0"]

    def pre_ra_sim(self, state):
        # type: (PreRASimState) -> None
        """Simulate the shift on pairs of adjacent words.

        The shift amount is taken modulo 64. For `Sl` the carry word sits
        below the input words, for `Sr`/`Sra` above them.
        """
        length = self.out.ty.length
        out = [0] * length
        carry = state.gprs[self.carry_in][0]
        sh = state.gprs[self.sh][0] % 64
        if self.kind is ShiftKind.Sl:
            inp = carry, *state.gprs[self.inp]
            for i in reversed(range(length)):
                pair = inp[i] | (inp[i + 1] << 64)
                shifted = pair << sh
                # keep the upper word of the shifted pair
                out[i] = (shifted >> 64) & GPR_VALUE_MASK
        else:
            assert self.kind is ShiftKind.Sr or self.kind is ShiftKind.Sra
            inp = *state.gprs[self.inp], carry
            for i in range(length):
                pair = inp[i] | (inp[i + 1] << 64)
                shifted = pair >> sh
                # keep the lower word of the shifted pair
                out[i] = shifted & GPR_VALUE_MASK
        # state.gprs[self._out_padding] is intentionally not written
        state.gprs[self.out] = tuple(out)
1137
1138
@plain_data(unsafe_hash=True, frozen=True, repr=False)
@final
class OpShiftImm(Op):
    """Shift a single GPR by a constant amount in the range 0..63.

    `ca_out` is an SSA value of type `CAType` for arithmetic right shifts
    (`ShiftKind.Sra`) and `None` for the other shift kinds.
    """
    __slots__ = "out", "inp", "sh", "kind", "ca_out"

    def inputs(self):
        # type: () -> dict[str, SSAVal]
        """Return the named input values."""
        return {"inp": self.inp}

    def outputs(self):
        # type: () -> dict[str, SSAVal]
        """Return the named output values; `ca_out` only when present."""
        if self.ca_out is not None:
            return {"out": self.out, "ca_out": self.ca_out}
        return {"out": self.out}

    def __init__(self, fn, inp, sh, kind):
        # type: (Fn, SSAGPR, int, ShiftKind) -> None
        """Create the shift op in `fn`.

        Raises ValueError if `sh` is not in the range 0 <= sh < 64.
        """
        super().__init__(fn)
        self.out = SSAVal(self, "out", inp.ty)
        self.inp = inp
        if not (0 <= sh < 64):
            raise ValueError("shift amount out of range")
        self.sh = sh
        self.kind = kind
        if self.kind is ShiftKind.Sra:
            self.ca_out = SSAVal(self, "ca_out", CAType())
        else:
            self.ca_out = None

    def get_asm_lines(self, ctx):
        # type: (AsmContext) -> list[str]
        """Return the single assembly line implementing this shift.

        The line is prefixed with `sv.` when `ctx.needs_sv` says so.
        """
        out = ctx.sgpr(self.out)
        inp = ctx.sgpr(self.inp)
        if self.kind is ShiftKind.Sl:
            # rotate left by sh, keeping bits 0..63-sh (shift left)
            mnemonic = "rldicr"
            args = f"{self.sh}, {63 - self.sh}"
        elif self.kind is ShiftKind.Sr:
            # rotate left by 64-sh, clearing the top sh bits (shift right);
            # the `% 64` keeps the rotate amount valid when sh is 0
            mnemonic = "rldicl"
            v = (64 - self.sh) % 64
            args = f"{v}, {self.sh}"
        else:
            assert self.kind is ShiftKind.Sra
            mnemonic = "sradi"
            args = f"{self.sh}"
        if ctx.needs_sv(self.out, self.inp):
            return [f"sv.{mnemonic} {out}, {inp}, {args}"]
        return [f"{mnemonic} {out}, {inp}, {args}"]

    def pre_ra_sim(self, state):
        # type: (PreRASimState) -> None
        """Simulate the shift on `state`.

        The stored result is always truncated to `GPR_VALUE_MASK`, so a
        left shift cannot grow past the register width and an arithmetic
        right shift of a negative value is stored as its unsigned
        two's-complement bit pattern rather than a negative Python int.
        """
        inp = state.gprs[self.inp][0]
        if self.kind is ShiftKind.Sl:
            assert self.ca_out is None
            out = inp << self.sh
        elif self.kind is ShiftKind.Sr:
            assert self.ca_out is None
            out = inp >> self.sh
        else:
            assert self.kind is ShiftKind.Sra
            assert self.ca_out is not None
            if inp & (1 << 63):  # sign extend
                inp -= 1 << 64
            out = inp >> self.sh
            # carry is set when the input is negative and shifting the
            # result back does not reproduce it, i.e. 1-bits were lost
            ca = inp < 0 and (out << self.sh) != inp
            state.CAs[self.ca_out] = ca
        # mask like the other ops in this file do before writing a GPR
        state.gprs[self.out] = (out & GPR_VALUE_MASK,)
1205
1206
@plain_data(unsafe_hash=True, frozen=True, repr=False)
@final
class OpLI(Op):
    """Load a signed 16-bit immediate into every register of `out`.

    `out` is one register long when `vl` is `None`, otherwise it has the
    length of `vl`'s type.
    """
    __slots__ = "out", "value", "vl"

    def inputs(self):
        # type: () -> dict[str, SSAVal]
        """Return the named input values; `vl` is included only if set."""
        named = {}  # type: dict[str, SSAVal[Any]]
        if self.vl is not None:
            named["vl"] = self.vl
        return named

    def outputs(self):
        # type: () -> dict[str, SSAVal]
        """Return the named output values."""
        return {"out": self.out}

    def __init__(self, fn, value, vl=None):
        # type: (Fn, int, SSAKnownVL | None) -> None
        """Create the load-immediate op in `fn`.

        Raises ValueError if `value` does not fit in a signed 16-bit
        immediate.
        """
        super().__init__(fn)
        length = 1 if vl is None else vl.ty.length
        self.out = SSAVal(self, "out", GPRRangeType(length))
        lowest, highest = -(1 << 15), (1 << 15) - 1
        if not lowest <= value <= highest:
            raise ValueError(f"value out of range: {value}")
        self.value = value
        self.vl = vl
        assert_vl_is(vl, length)

    def get_asm_lines(self, ctx):
        # type: (AsmContext) -> list[str]
        """Return the single `addi` line, `sv.`-prefixed when needed."""
        out = ctx.gpr(self.out, vec=self.out.ty.length != 1)
        prefix = "sv." if ctx.needs_sv(self.out) else ""
        return [f"{prefix}addi {out}, 0, {self.value}"]

    def pre_ra_sim(self, state):
        # type: (PreRASimState) -> None
        """Simulate the op: every word of `out` gets the masked value."""
        word = self.value & GPR_VALUE_MASK
        state.gprs[self.out] = (word,) * self.out.ty.length
1249
1250
@plain_data(unsafe_hash=True, frozen=True, repr=False)
@final
class OpSetCA(Op):
    """Produce a carry (`CAType`) value with a fixed boolean `value`."""
    __slots__ = "out", "value"

    def inputs(self):
        # type: () -> dict[str, SSAVal]
        """This op has no inputs."""
        return {}

    def outputs(self):
        # type: () -> dict[str, SSAVal]
        """Return the named output values."""
        return {"out": self.out}

    def __init__(self, fn, value):
        # type: (Fn, bool) -> None
        """Create the op in `fn`; `value` is the carry to produce."""
        super().__init__(fn)
        self.out = SSAVal(self, "out", CAType())
        self.value = value

    def get_asm_lines(self, ctx):
        # type: (AsmContext) -> list[str]
        """Return the single instruction that sets or clears the carry."""
        # NOTE(review): both instructions name register 0 as destination
        # and source -- confirm that writing it here is acceptable.
        return ["subfic 0, 0, -1"] if self.value else ["addic 0, 0, 0"]

    def pre_ra_sim(self, state):
        # type: (PreRASimState) -> None
        """Simulate the op by recording `value` as the carry for `out`."""
        state.CAs[self.out] = self.value
1279
1280
@plain_data(unsafe_hash=True, frozen=True, repr=False)
@final
class OpLoad(Op):
    """Load one or more GPR-sized words from `mem` at `RA + offset`.

    `RT` is one register long when `vl` is `None`, otherwise it has the
    length of `vl`'s type.
    """
    __slots__ = "RT", "RA", "offset", "mem", "vl"

    def inputs(self):
        # type: () -> dict[str, SSAVal]
        """Return the named input values; `vl` is included only if set."""
        named = {
            "RA": self.RA,
            "mem": self.mem,
        }  # type: dict[str, SSAVal[Any]]
        if self.vl is not None:
            named["vl"] = self.vl
        return named

    def outputs(self):
        # type: () -> dict[str, SSAVal]
        """Return the named output values."""
        return {"RT": self.RT}

    def __init__(self, fn, RA, offset, mem, vl=None):
        # type: (Fn, SSAGPR, int, SSAVal[GlobalMemType], SSAKnownVL | None) -> None
        """Create the load op in `fn`.

        Raises ValueError if `offset` does not fit in a signed 16-bit
        immediate or is not a multiple of 4.
        """
        super().__init__(fn)
        length = 1 if vl is None else vl.ty.length
        self.RT = SSAVal(self, "RT", GPRRangeType(length))
        self.RA = RA
        lowest, highest = -(1 << 15), (1 << 15) - 1
        if not lowest <= offset <= highest:
            raise ValueError(f"offset out of range: {offset}")
        if offset % 4:
            raise ValueError(f"offset not aligned: {offset}")
        self.offset = offset
        self.mem = mem
        self.vl = vl
        assert_vl_is(vl, length)

    def get_extra_interferences(self):
        # type: () -> Iterable[tuple[SSAVal, SSAVal]]
        """Yield `(RT, RA)` when `RT` spans more than one register."""
        if self.RT.ty.length > 1:
            yield (self.RT, self.RA)

    def get_asm_lines(self, ctx):
        # type: (AsmContext) -> list[str]
        """Return the single `ld` line, `sv.`-prefixed when needed."""
        RT = ctx.gpr(self.RT, vec=self.RT.ty.length != 1)
        RA = ctx.sgpr(self.RA)
        prefix = "sv." if ctx.needs_sv(self.RT, self.RA) else ""
        return [f"{prefix}ld {RT}, {self.offset}({RA})"]

    def pre_ra_sim(self, state):
        # type: (PreRASimState) -> None
        """Simulate the load, reading absent bytes as zero.

        Raises ValueError if any word address is not a multiple of
        `GPR_SIZE_IN_BYTES`.
        """
        base = state.gprs[self.RA][0] + self.offset
        mem = state.global_mems[self.mem]
        loaded = []
        for word_index in range(self.RT.ty.length):
            cur_addr = base + word_index * GPR_SIZE_IN_BYTES
            cur_addr &= GPR_VALUE_MASK
            if cur_addr % GPR_SIZE_IN_BYTES != 0:
                raise ValueError(f"can't load from unaligned address: "
                                 f"{cur_addr:#x}")
            word = 0
            # byte at the lowest address becomes the least-significant byte
            for byte_index in range(GPR_SIZE_IN_BYTES):
                byte_val = mem.get(cur_addr + byte_index, 0) & 0xFF
                word |= byte_val << (byte_index * 8)
            loaded.append(word)
        state.gprs[self.RT] = tuple(loaded)
1345
1346
@plain_data(unsafe_hash=True, frozen=True, repr=False)
@final
class OpStore(Op):
    """Store the GPR range `RS` to memory at `RA + offset`.

    Memory is threaded through as values: the op reads `mem_in` and
    produces `mem_out`, which has the same type as `mem_in`.
    """
    __slots__ = "RS", "RA", "offset", "mem_in", "mem_out", "vl"

    def inputs(self):
        # type: () -> dict[str, SSAVal]
        """Return the named input values; `vl` is included only if set."""
        named = {
            "RS": self.RS,
            "RA": self.RA,
            "mem_in": self.mem_in,
        }  # type: dict[str, SSAVal[Any]]
        if self.vl is not None:
            named["vl"] = self.vl
        return named

    def outputs(self):
        # type: () -> dict[str, SSAVal]
        """Return the named output values."""
        return {"mem_out": self.mem_out}

    def __init__(self, fn, RS, RA, offset, mem_in, vl=None):
        # type: (Fn, SSAGPRRange, SSAGPR, int, SSAVal[GlobalMemType], SSAKnownVL | None) -> None
        """Create the store op in `fn`.

        Raises ValueError if `offset` does not fit in a signed 16-bit
        immediate or is not a multiple of 4.
        """
        super().__init__(fn)
        self.RS = RS
        self.RA = RA
        lowest, highest = -(1 << 15), (1 << 15) - 1
        if not lowest <= offset <= highest:
            raise ValueError(f"offset out of range: {offset}")
        if offset % 4:
            raise ValueError(f"offset not aligned: {offset}")
        self.offset = offset
        self.mem_in = mem_in
        self.mem_out = SSAVal(self, "mem_out", mem_in.ty)
        self.vl = vl
        assert_vl_is(vl, RS.ty.length)

    def get_asm_lines(self, ctx):
        # type: (AsmContext) -> list[str]
        """Return the single `std` line, `sv.`-prefixed when needed."""
        RS = ctx.gpr(self.RS, vec=self.RS.ty.length != 1)
        RA = ctx.sgpr(self.RA)
        prefix = "sv." if ctx.needs_sv(self.RS, self.RA) else ""
        return [f"{prefix}std {RS}, {self.offset}({RA})"]

    def pre_ra_sim(self, state):
        # type: (PreRASimState) -> None
        """Simulate the store into a copy of the incoming memory.

        Raises ValueError if any word address is not a multiple of
        `GPR_SIZE_IN_BYTES`.
        """
        # work on a copy so the mapping stored for mem_in is left untouched
        new_mem = dict(state.global_mems[self.mem_in])
        base = state.gprs[self.RA][0] + self.offset
        words = state.gprs[self.RS]
        for word_index in range(self.RS.ty.length):
            cur_addr = base + word_index * GPR_SIZE_IN_BYTES
            cur_addr &= GPR_VALUE_MASK
            if cur_addr % GPR_SIZE_IN_BYTES != 0:
                raise ValueError(f"can't store to unaligned address: "
                                 f"{cur_addr:#x}")
            # least-significant byte goes to the lowest address
            for byte_index in range(GPR_SIZE_IN_BYTES):
                shifted = words[word_index] >> (byte_index * 8)
                new_mem[cur_addr + byte_index] = shifted & 0xFF
        state.global_mems[self.mem_out] = FMap(new_mem)
1403
1404
@plain_data(unsafe_hash=True, frozen=True, repr=False)
@final
class OpFuncArg(Op):
    """Introduce a function argument as an SSA value of the given type.

    Emits no assembly.
    """
    __slots__ = "out",

    def inputs(self):
        # type: () -> dict[str, SSAVal]
        """This op has no inputs."""
        return {}

    def outputs(self):
        # type: () -> dict[str, SSAVal]
        """Return the named output values."""
        return {"out": self.out}

    def __init__(self, fn, ty):
        # type: (Fn, FixedGPRRangeType) -> None
        """Create the op in `fn` with an output of type `ty`."""
        super().__init__(fn)
        self.out = SSAVal(self, "out", ty)

    def get_asm_lines(self, ctx):
        # type: (AsmContext) -> list[str]
        """Return no assembly lines."""
        return []

    def pre_ra_sim(self, state):
        # type: (PreRASimState) -> None
        """Default the argument to all zeros unless already present."""
        if self.out in state.fixed_gprs:
            # keep whatever value is already recorded for this argument
            return
        state.fixed_gprs[self.out] = (0,) * self.out.ty.length
1431
1432
@plain_data(unsafe_hash=True, frozen=True, repr=False)
@final
class OpInputMem(Op):
    """Introduce the initial memory as an SSA value of `GlobalMemType`.

    Emits no assembly.
    """
    __slots__ = "out",

    def inputs(self):
        # type: () -> dict[str, SSAVal]
        """This op has no inputs."""
        return {}

    def outputs(self):
        # type: () -> dict[str, SSAVal]
        """Return the named output values."""
        return {"out": self.out}

    def __init__(self, fn):
        # type: (Fn) -> None
        """Create the op in `fn`."""
        super().__init__(fn)
        self.out = SSAVal(self, "out", GlobalMemType())

    def get_asm_lines(self, ctx):
        # type: (AsmContext) -> list[str]
        """Return no assembly lines."""
        return []

    def pre_ra_sim(self, state):
        # type: (PreRASimState) -> None
        """Default the memory to an empty `FMap` unless already present."""
        if self.out in state.global_mems:
            # keep whatever memory is already recorded for this value
            return
        state.global_mems[self.out] = FMap()
1459
1460
@plain_data(unsafe_hash=True, frozen=True, repr=False)
@final
class OpSetVLImm(Op):
    """Produce a `KnownVLType` value for a constant vector length."""
    __slots__ = "out",

    def inputs(self):
        # type: () -> dict[str, SSAVal]
        """This op has no inputs."""
        return {}

    def outputs(self):
        # type: () -> dict[str, SSAVal]
        """Return the named output values."""
        return {"out": self.out}

    def __init__(self, fn, length):
        # type: (Fn, int) -> None
        """Create the op in `fn`; `length` is the vector length to set."""
        super().__init__(fn)
        self.out = SSAVal(self, "out", KnownVLType(length))

    def get_asm_lines(self, ctx):
        # type: (AsmContext) -> list[str]
        """Return the single `setvl` line carrying the constant length."""
        length = self.out.ty.length
        return [f"setvl 0, 0, {length}, 0, 1, 1"]

    def pre_ra_sim(self, state):
        # type: (PreRASimState) -> None
        """Simulate the op by recording the length as the VL for `out`."""
        state.VLs[self.out] = self.out.ty.length
1486
1487
def op_set_to_list(ops):
    # type: (Iterable[Op]) -> list[Op]
    """Order `ops` so that every op follows the ops that write its inputs.

    Each op must provide `inputs()` and `outputs()` returning mappings
    whose values are hashable SSA values.  Ops that become ready at the
    same time keep the order in which they became ready.

    An op may read the same SSA value through several input slots; the
    value is then counted as a single pending dependency.

    Raises ValueError if two ops write the same SSA value, or if some op
    can never become ready (dependency loop, or an input no op writes).
    """
    # worklists[n] holds the ops still waiting on n distinct input values;
    # dicts are used as insertion-ordered sets
    worklists = [{}]  # type: list[dict[Op, None]]
    inps_to_ops_map = defaultdict(dict)  # type: dict[SSAVal, dict[Op, None]]
    ops_to_pending_input_count_map = {}  # type: dict[Op, int]
    for op in ops:
        # De-duplicate the inputs: the op is registered only once per
        # value below, so counting a repeated value more than once would
        # leave the pending count stuck above zero forever.
        unique_inputs = dict.fromkeys(op.inputs().values())
        input_count = len(unique_inputs)
        for val in unique_inputs:
            inps_to_ops_map[val][op] = None
        while len(worklists) <= input_count:
            worklists.append({})
        ops_to_pending_input_count_map[op] = input_count
        worklists[input_count][op] = None
    retval = []  # type: list[Op]
    # only membership is ever tested, so a plain set is sufficient
    ready_vals = set()  # type: set[SSAVal]
    while len(worklists[0]) != 0:
        writing_op = next(iter(worklists[0]))
        del worklists[0][writing_op]
        retval.append(writing_op)
        for val in writing_op.outputs().values():
            if val in ready_vals:
                raise ValueError(f"multiple instructions must not write "
                                 f"to the same SSA value: {val}")
            ready_vals.add(val)
            # every reader of val now waits on one fewer value; move it
            # down to the matching worklist
            for reading_op in inps_to_ops_map[val]:
                pending = ops_to_pending_input_count_map[reading_op]
                del worklists[pending][reading_op]
                pending -= 1
                worklists[pending][reading_op] = None
                ops_to_pending_input_count_map[reading_op] = pending
    # anything left over never had all of its inputs written
    for worklist in worklists:
        for op in worklist:
            raise ValueError(f"instruction is part of a dependency loop or "
                             f"its inputs are never written: {op}")
    return retval
1524
1525
def generate_assembly(ops, assigned_registers=None):
    # type: (list[Op], dict[SSAVal, RegLoc] | None) -> list[str]
    """Return the assembly lines for `ops`, followed by a final `bclr`.

    When `assigned_registers` is `None`, the register assignment is
    computed by `register_allocator.allocate_registers(ops)`.
    """
    if assigned_registers is None:
        # imported here rather than at module level
        # NOTE(review): presumably to avoid a circular import -- confirm
        from bigint_presentation_code.register_allocator import \
            allocate_registers
        assigned_registers = allocate_registers(ops)
    ctx = AsmContext(assigned_registers)
    asm_lines = [line for op in ops for line in op.get_asm_lines(ctx)]
    asm_lines.append("bclr 20, 0, 0")
    return asm_lines