192x192->384-bit O(n^2) mul works in SSA form, reg-alloc gives incorrect results...
[bigint-presentation-code.git] / src / bigint_presentation_code / compiler_ir.py
1 """
2 Compiler IR for Toom-Cook algorithm generator for SVP64
3
4 This assumes VL != 0 throughout.
5 """
6
7 from abc import ABCMeta, abstractmethod
8 from collections import defaultdict
9 from enum import Enum, EnumMeta, unique
10 from functools import lru_cache
11 from typing import Any, Generic, Iterable, Sequence, Type, TypeVar, cast
12
13 from nmutil.plain_data import fields, plain_data
14
15 from bigint_presentation_code.util import FMap, OFSet, OSet, final
16
17
class ABCEnumMeta(EnumMeta, ABCMeta):
    """Metaclass that is both an `EnumMeta` and an `ABCMeta`.

    Used as the metaclass of enums in this file that also derive from the
    `ABCMeta`-based `RegLoc`, which avoids a metaclass conflict.
    """
20
21
class RegLoc(metaclass=ABCMeta):
    """Abstract base class for a location a value can be allocated to."""
    __slots__ = ()

    @abstractmethod
    def conflicts(self, other):
        # type: (RegLoc) -> bool
        """Return whether `self` and `other` conflict with each other."""
        ...

    def get_subreg_at_offset(self, subreg_type, offset):
        # type: (RegType, int) -> RegLoc
        """Return the sub-register of type `subreg_type` at `offset`.

        This default implementation only supports the trivial case: `self`
        must be a member of `subreg_type.reg_class` and `offset` must be 0,
        in which case `self` is returned.

        Raises `ValueError` otherwise.
        """
        is_member = self in subreg_type.reg_class
        if not is_member:
            raise ValueError(f"register not a member of subreg_type: "
                             f"reg={self} subreg_type={subreg_type}")
        if offset == 0:
            return self
        raise ValueError(f"non-zero sub-register offset not supported "
                         f"for register: {self}")
39
40
# total number of GPRs; a GPRRange must satisfy start + length <= GPR_COUNT
GPR_COUNT = 128
42
43
@plain_data(frozen=True, unsafe_hash=True, repr=False)
@final
class GPRRange(RegLoc, Sequence["GPRRange"]):
    """A contiguous range of `length` GPRs beginning at register `start`.

    Also behaves as a `Sequence` of sub-ranges: indexing or slicing gives a
    new `GPRRange`, and `sub in rng` tests whether `sub` lies entirely
    within `rng`.
    """
    __slots__ = "start", "length"

    def __init__(self, start, length=None):
        # type: (int | range, int | None) -> None
        """Create a range from `start` and `length`, or from a `range`.

        `start` may be a `range` with a step of 1, in which case `length`
        must not be given. When `start` is an int, `length` defaults to 1.

        Raises `TypeError` if `length` is given together with a `range`,
        and `ValueError` if the range has a step other than 1, is empty,
        starts below 0, or extends past `GPR_COUNT`.
        """
        if isinstance(start, range):
            if length is not None:
                raise TypeError("can't specify length when input is a range")
            if start.step != 1:
                raise ValueError("range must have a step of 1")
            length = len(start)
            start = start.start
        elif length is None:
            length = 1
        if length <= 0 or start < 0 or start + length > GPR_COUNT:
            raise ValueError("invalid GPRRange")
        self.start = start
        self.length = length

    @property
    def stop(self):
        """One past the last register number in this range."""
        return self.start + self.length

    @property
    def step(self):
        """Always 1; provided for symmetry with `range`."""
        return 1

    @property
    def range(self):
        """The register numbers of this range as a builtin `range`."""
        return range(self.start, self.stop, self.step)

    def __len__(self):
        return self.length

    def __getitem__(self, item):
        # type: (int | slice) -> GPRRange
        # an int index gives a length-1 range; a slice gives the sub-range
        # (the constructor rejects slices with a step other than 1)
        return GPRRange(self.range[item])

    def __contains__(self, value):
        # type: (GPRRange) -> bool
        """Whether `value` is entirely contained in this range."""
        return value.start >= self.start and value.stop <= self.stop

    def index(self, sub, start=None, end=None):
        # type: (GPRRange, int | None, int | None) -> int
        """Return the offset of `sub` relative to `self.start`.

        Raises `ValueError` if `sub` isn't contained in `self[start:end]`.
        """
        r = self.range[start:end]
        if sub.start < r.start or sub.stop > r.stop:
            raise ValueError("GPR range not found")
        return sub.start - self.start

    def count(self, sub, start=None, end=None):
        # type: (GPRRange, int | None, int | None) -> int
        """Return 1 if `sub` is contained in `self[start:end]`, else 0."""
        r = self.range[start:end]
        if len(r) == 0:
            return 0
        return int(sub in GPRRange(r))

    def conflicts(self, other):
        # type: (RegLoc) -> bool
        """Whether `other` is a `GPRRange` overlapping this range."""
        if isinstance(other, GPRRange):
            return self.stop > other.start and other.stop > self.start
        return False

    def get_subreg_at_offset(self, subreg_type, offset):
        # type: (RegType, int) -> GPRRange
        """Return the sub-range of `subreg_type.length` registers that
        starts `offset` registers after `self.start`.

        Raises `ValueError` if `subreg_type` is not a GPR range type or if
        the sub-range doesn't fit inside this range.
        """
        if not isinstance(subreg_type, (GPRRangeType, FixedGPRRangeType)):
            raise ValueError(f"subreg_type is not a FixedGPRRangeType or "
                             f"GPRRangeType: {subreg_type}")
        # `offset` is relative to `self.start`, so the bound to compare
        # against is `self.length`, not the absolute `self.stop`.
        if offset < 0 or offset + subreg_type.length > self.length:
            raise ValueError(f"sub-register offset is out of range: {offset}")
        return GPRRange(self.start + offset, subreg_type.length)

    def __repr__(self):
        if self.length == 1:
            return f"<r{self.start}>"
        return f"<r{self.start}..len={self.length}>"
121
122
# GPRs that are excluded from the allocatable register classes built by
# GPRRangeType (any candidate range containing one of these is skipped).
# NOTE(review): presumably r0/r1/r2/r13 are excluded for ABI reasons (stack
# slots in this file are addressed relative to r1) -- confirm.
SPECIAL_GPRS = GPRRange(0), GPRRange(1), GPRRange(2), GPRRange(13)
124
125
@final
@unique
class XERBit(RegLoc, Enum, metaclass=ABCEnumMeta):
    """Register locations named after XER bits; only `CA` is defined."""
    CA = "CA"

    def conflicts(self, other):
        # type: (RegLoc) -> bool
        """Whether `other` is the same `XERBit` as `self`."""
        return isinstance(other, XERBit) and self == other
136
137
@final
@unique
class GlobalMem(RegLoc, Enum, metaclass=ABCEnumMeta):
    """Single-member enum standing for all memory that isn't a StackSlot.

    For register allocation purposes it is handled as if it were one
    physical register.
    """
    GlobalMem = "GlobalMem"

    def conflicts(self, other):
        # type: (RegLoc) -> bool
        """Whether `other` is the same `GlobalMem` member as `self`."""
        return isinstance(other, GlobalMem) and self == other
151
152
@final
@unique
class VL(RegLoc, Enum, metaclass=ABCEnumMeta):
    """Register location holding the vector length."""
    # VL and MAXVL
    VL_MAXVL = "VL_MAXVL"

    def conflicts(self, other):
        # type: (RegLoc) -> bool
        """Whether `other` is the same `VL` member as `self`."""
        return isinstance(other, VL) and self == other
164
165
@final
class RegClass(OFSet[RegLoc]):
    """An ordered set of register locations.

    The register allocator prefers registers that come earlier in the set.
    """

    @lru_cache(maxsize=None, typed=True)
    def max_conflicts_with(self, other):
        # type: (RegClass | RegLoc) -> int
        """Return the greatest number of registers in `self` that any one
        register from `other` can conflict with.

        `other` is either a whole `RegClass` (the maximum over its members
        is returned) or a single `RegLoc`. Results are memoized.
        """
        if not isinstance(other, RegClass):
            # single register: count the members of `self` it conflicts with
            return sum(other.conflicts(reg) for reg in self)
        return max(self.max_conflicts_with(reg) for reg in other)
182
183
@plain_data(frozen=True, unsafe_hash=True)
class RegType(metaclass=ABCMeta):
    """Abstract base class for the type of an SSA value.

    Every concrete type provides the `RegClass` of locations that values
    of that type may be placed in.
    """
    __slots__ = ()

    @property
    @abstractmethod
    def reg_class(self):
        # type: () -> RegClass
        """The set of register locations usable for this type."""
        return ...
193
194
# type variable for generics parameterized by a RegType subclass
_RegType = TypeVar("_RegType", bound=RegType)
# type variable for functions returning a specific RegLoc subclass
_RegLoc = TypeVar("_RegLoc", bound=RegLoc)
197
198
@plain_data(frozen=True, eq=False, repr=False)
@final
class GPRRangeType(RegType):
    """Type of a value occupying `length` consecutive GPRs.

    Two `GPRRangeType`s are equal exactly when their lengths are equal.
    """
    __slots__ = "length",

    def __init__(self, length=1):
        # type: (int) -> None
        """Raises `ValueError` unless `1 <= length <= GPR_COUNT`."""
        if length < 1 or length > GPR_COUNT:
            raise ValueError("invalid length")
        self.length = length

    @staticmethod
    @lru_cache(maxsize=None)
    def __get_reg_class(length):
        # type: (int) -> RegClass
        """Build (and memoize) the register class for the given length:
        every in-bounds `GPRRange` of that length, in ascending start
        order, that doesn't contain any of `SPECIAL_GPRS`.
        """
        regs = []
        # `start + length == GPR_COUNT` is still a valid GPRRange, so the
        # last candidate start is `GPR_COUNT - length` inclusive.
        for start in range(GPR_COUNT - length + 1):
            reg = GPRRange(start, length)
            if any(i in reg for i in SPECIAL_GPRS):
                continue
            regs.append(reg)
        return RegClass(regs)

    @property
    @final
    def reg_class(self):
        # type: () -> RegClass
        """All allocatable GPR ranges of this type's length."""
        return GPRRangeType.__get_reg_class(self.length)

    @final
    def __eq__(self, other):
        if isinstance(other, GPRRangeType):
            return self.length == other.length
        return False

    @final
    def __hash__(self):
        return hash(self.length)

    def __repr__(self):
        return f"<gpr_ty[{self.length}]>"
240
241
# NOTE: this is a plain alias for the GPRRangeType class itself; the
# "length=1" comes from GPRRangeType's default `length=1` argument and is
# not enforced by the alias.
GPRType = GPRRangeType
"""a length=1 GPRRangeType"""
244
245
@plain_data(frozen=True, unsafe_hash=True, repr=False)
@final
class FixedGPRRangeType(RegType):
    """Type of a value that must live in one specific `GPRRange`."""
    __slots__ = "reg",

    def __init__(self, reg):
        # type: (GPRRange) -> None
        self.reg = reg

    @property
    def reg_class(self):
        # type: () -> RegClass
        """A register class whose only member is `self.reg`."""
        only_choice = [self.reg]
        return RegClass(only_choice)

    @property
    def length(self):
        # type: () -> int
        """Number of GPRs in the fixed range."""
        fixed = self.reg
        return fixed.length

    def __repr__(self):
        return f"<fixed({self.reg})>"
267
268
@plain_data(frozen=True, unsafe_hash=True)
@final
class CAType(RegType):
    """Type of a value stored in `XERBit.CA`."""
    __slots__ = ()

    @property
    def reg_class(self):
        # type: () -> RegClass
        """A register class whose only member is `XERBit.CA`."""
        members = [XERBit.CA]
        return RegClass(members)
278
279
@plain_data(frozen=True, unsafe_hash=True)
@final
class GlobalMemType(RegType):
    """Type of a value stored in `GlobalMem.GlobalMem`."""
    __slots__ = ()

    @property
    def reg_class(self):
        # type: () -> RegClass
        """A register class whose only member is `GlobalMem.GlobalMem`."""
        members = [GlobalMem.GlobalMem]
        return RegClass(members)
289
290
@plain_data(frozen=True, unsafe_hash=True)
@final
class KnownVLType(RegType):
    """Type of a VL value whose length is known up front."""
    __slots__ = "length",

    def __init__(self, length):
        # type: (int) -> None
        """Raises `ValueError` unless `0 < length <= 64`."""
        if length <= 0 or length > 64:
            raise ValueError("invalid VL value")
        self.length = length

    @property
    def reg_class(self):
        # type: () -> RegClass
        """A register class whose only member is `VL.VL_MAXVL`."""
        members = [VL.VL_MAXVL]
        return RegClass(members)
306
307
def assert_vl_is(vl, expected_vl):
    # type: (SSAKnownVL | KnownVLType | int | None, int) -> None
    """Check that `vl` stands for a vector length of `expected_vl`.

    `vl` may be `None` (treated as a length of 1), an `SSAVal` (its type's
    `length` is used), a `KnownVLType` (its `length` is used), or any
    other value, which is compared directly.

    Raises `ValueError` when the lengths differ.
    """
    if vl is None:
        vl = 1
    elif isinstance(vl, (SSAVal, KnownVLType)):
        vl = vl.ty.length if isinstance(vl, SSAVal) else vl.length
    if vl == expected_vl:
        return
    raise ValueError(
        f"wrong VL: expected {expected_vl} got {vl}")
319
320
# size of one stack slot in bytes; StackSlot.start_byte multiplies the slot
# index by this to get a byte offset
STACK_SLOT_SIZE = 8
322
323
@plain_data(frozen=True, unsafe_hash=True)
@final
class StackSlot(RegLoc):
    """A run of `length_in_slots` consecutive stack slots that starts at
    slot index `start_slot`.
    """
    __slots__ = "start_slot", "length_in_slots",

    def __init__(self, start_slot, length_in_slots):
        # type: (int, int) -> None
        """Raises `ValueError` if `length_in_slots` is less than 1."""
        self.start_slot = start_slot
        if length_in_slots < 1:
            raise ValueError("invalid length_in_slots")
        self.length_in_slots = length_in_slots

    @property
    def stop_slot(self):
        """One past the index of the last slot."""
        return self.start_slot + self.length_in_slots

    @property
    def start_byte(self):
        """Byte offset of the first slot (`start_slot * STACK_SLOT_SIZE`)."""
        return self.start_slot * STACK_SLOT_SIZE

    def conflicts(self, other):
        # type: (RegLoc) -> bool
        """Whether `other` is a `StackSlot` overlapping this one."""
        if isinstance(other, StackSlot):
            return (self.stop_slot > other.start_slot
                    and other.stop_slot > self.start_slot)
        return False

    def get_subreg_at_offset(self, subreg_type, offset):
        # type: (RegType, int) -> StackSlot
        """Return the sub-slot of `subreg_type.length_in_slots` slots that
        starts `offset` slots after `self.start_slot`.

        Raises `ValueError` if `subreg_type` is not a `StackSlotType` or if
        the sub-slot doesn't fit inside this slot.
        """
        if not isinstance(subreg_type, StackSlotType):
            raise ValueError(f"subreg_type is not a "
                             f"StackSlotType: {subreg_type}")
        # `offset` is relative to `self.start_slot`, so the bound to compare
        # against is `self.length_in_slots`, not the absolute `stop_slot`.
        if offset < 0 \
                or offset + subreg_type.length_in_slots > self.length_in_slots:
            raise ValueError(f"sub-register offset is out of range: {offset}")
        return StackSlot(self.start_slot + offset, subreg_type.length_in_slots)
359
360
# number of stack slots considered when StackSlotType enumerates the
# candidate StackSlots for its register class
STACK_SLOT_COUNT = 128
362
363
@plain_data(frozen=True, eq=False)
@final
class StackSlotType(RegType):
    """Type of a value occupying `length_in_slots` consecutive stack slots.

    Two `StackSlotType`s are equal exactly when their lengths are equal.
    """
    __slots__ = "length_in_slots",

    def __init__(self, length_in_slots=1):
        # type: (int) -> None
        """Raises `ValueError` if `length_in_slots` is less than 1."""
        if length_in_slots < 1:
            raise ValueError("invalid length_in_slots")
        self.length_in_slots = length_in_slots

    @staticmethod
    @lru_cache(maxsize=None)
    def __get_reg_class(length_in_slots):
        # type: (int) -> RegClass
        """Build (and memoize) the register class for the given length:
        every `StackSlot` of that length that fits within the first
        `STACK_SLOT_COUNT` slots, in ascending start order.
        """
        regs = []
        # a slot starting at `STACK_SLOT_COUNT - length_in_slots` still ends
        # within the first STACK_SLOT_COUNT slots, so include that start.
        for start in range(STACK_SLOT_COUNT - length_in_slots + 1):
            reg = StackSlot(start, length_in_slots)
            regs.append(reg)
        return RegClass(regs)

    @property
    def reg_class(self):
        # type: () -> RegClass
        """All candidate stack slots of this type's length."""
        return StackSlotType.__get_reg_class(self.length_in_slots)

    @final
    def __eq__(self, other):
        if isinstance(other, StackSlotType):
            return self.length_in_slots == other.length_in_slots
        return False

    @final
    def __hash__(self):
        return hash(self.length_in_slots)
399
400
@plain_data(frozen=True, eq=False, repr=False)
@final
class SSAVal(Generic[_RegType]):
    """A value in SSA form, identified by the `Op` that writes it and the
    name of that op's argument.

    Equality and hashing use the identity of `op` plus `arg_name`; `ty` does
    not take part.
    """
    __slots__ = "op", "arg_name", "ty",

    def __init__(self, op, arg_name, ty):
        # type: (Op, str, _RegType) -> None
        # the Op that writes this SSAVal
        self.op = op

        # the name of the argument of self.op that writes this SSAVal
        self.arg_name = arg_name

        # the register type of this value
        self.ty = ty

    def __eq__(self, rhs):
        if not isinstance(rhs, SSAVal):
            return False
        return self.op is rhs.op and self.arg_name == rhs.arg_name

    def __hash__(self):
        key = (id(self.op), self.arg_name)
        return hash(key)

    def __repr__(self):
        return f"<#{self.op.id}.{self.arg_name}: {self.ty}>"
427
428
# convenience aliases for SSAVal specialized to commonly used register types
SSAGPRRange = SSAVal[GPRRangeType]
# NOTE: GPRType is an alias of GPRRangeType, so this is the same
# parameterization as SSAGPRRange
SSAGPR = SSAVal[GPRType]
SSAKnownVL = SSAVal[KnownVLType]
432
433
@final
@plain_data(unsafe_hash=True, frozen=True)
class EqualityConstraint:
    """A constraint relating two non-empty lists of SSA values, `lhs` and
    `rhs`, that must be equal.
    """
    __slots__ = "lhs", "rhs"

    def __init__(self, lhs, rhs):
        # type: (list[SSAVal], list[SSAVal]) -> None
        """Raises `ValueError` if either list is empty."""
        self.lhs = lhs
        self.rhs = rhs
        for side in (lhs, rhs):
            if len(side) == 0:
                raise ValueError("can't constrain an empty list to be equal")
445
446
@final
class Fn:
    """A function: an ordered list of `Op`s."""
    __slots__ = "ops",

    def __init__(self):
        # type: () -> None
        self.ops = []  # type: list[Op]

    def __repr__(self, short=False):
        """Return `"<Fn>"` when `short`, else a listing of each op's
        id-only repr.
        """
        if not short:
            ops = ", ".join([op.__repr__(just_id=True) for op in self.ops])
            return f"<Fn([{ops}])>"
        return "<Fn>"

    def pre_ra_sim(self, state):
        # type: (PreRASimState) -> None
        """Run the pre-register-allocation simulation of every op, in
        order, on `state`.
        """
        for current_op in self.ops:
            current_op.pre_ra_sim(state)
465
466
467 class _NotSet:
468 """ helper for __repr__ for when fields aren't set """
469
470 def __repr__(self):
471 return "<not set>"
472
473
# shared placeholder used as the `getattr` default when building a repr of
# an object whose fields may not all be assigned yet
_NOT_SET = _NotSet()
475
476
@plain_data(frozen=True, unsafe_hash=True)
class AsmTemplateSegment(Generic[_RegType], metaclass=ABCMeta):
    """Abstract piece of an assembly template whose text depends on the
    register assigned to `ssa_val`.
    """
    __slots__ = "ssa_val",

    def __init__(self, ssa_val):
        # type: (SSAVal[_RegType]) -> None
        self.ssa_val = ssa_val

    def render(self, regs):
        # type: (dict[SSAVal, RegLoc]) -> str
        """Look up the register assigned to `self.ssa_val` in `regs` and
        return the text produced by `_render` for it.
        """
        assigned = regs[self.ssa_val]
        return self._render(assigned)

    @abstractmethod
    def _render(self, reg):
        # type: (RegLoc) -> str
        """Return the text for this segment given its assigned register."""
        ...
493
494
@plain_data(frozen=True, unsafe_hash=True)
@final
class ATSGPR(AsmTemplateSegment[GPRRangeType]):
    """Template segment that renders as a GPR number: the start of the
    assigned `GPRRange` plus `offset`.
    """
    __slots__ = "offset",

    def __init__(self, ssa_val, offset=0):
        # type: (SSAGPRRange, int) -> None
        super().__init__(ssa_val)
        self.offset = offset

    def _render(self, reg):
        # type: (RegLoc) -> str
        """Raises `TypeError` if `reg` isn't a `GPRRange`."""
        if isinstance(reg, GPRRange):
            return str(reg.start + self.offset)
        raise TypeError()
510
511
@plain_data(frozen=True, unsafe_hash=True)
@final
class ATSStackSlot(AsmTemplateSegment[StackSlotType]):
    """Template segment that renders a stack slot as a `disp(1)` memory
    operand, i.e. a displacement relative to r1.
    """
    __slots__ = ()

    def _render(self, reg):
        # type: (RegLoc) -> str
        """Raises `TypeError` if `reg` isn't a `StackSlot`."""
        if not isinstance(reg, StackSlot):
            raise TypeError()
        # the displacement must be the byte offset, not the slot index --
        # this matches the `{start_byte}(1)` operands emitted by
        # OpLoadFromStackSlot and OpStoreToStackSlot.
        return f"{reg.start_byte}(1)"
522
523
@plain_data(frozen=True, unsafe_hash=True)
@final
class ATSCopyGPRRange(AsmTemplateSegment["GPRRangeType | FixedGPRRangeType"]):
    """Template segment that renders a copy of the GPR range assigned to
    `src_ssa_val` into the GPR range assigned to `ssa_val`.
    """
    __slots__ = "src_ssa_val",

    def __init__(self, ssa_val, src_ssa_val):
        # type: (SSAVal[GPRRangeType | FixedGPRRangeType], SSAVal[GPRRangeType | FixedGPRRangeType]) -> None
        self.ssa_val = ssa_val
        self.src_ssa_val = src_ssa_val

    def render(self, regs):
        # type: (dict[SSAVal, RegLoc]) -> str
        """Return one line of assembly (newline-terminated) copying source
        to destination, or `""` when both are assigned the same range.

        Raises `TypeError` if either assigned location isn't a `GPRRange`
        and `ValueError` if the two ranges differ in length.
        """
        src = regs[self.src_ssa_val]
        dest = regs[self.ssa_val]
        if not isinstance(dest, GPRRange):
            raise TypeError()
        if not isinstance(src, GPRRange):
            raise TypeError()
        if src.length != dest.length:
            raise ValueError()
        if src == dest:
            return ""
        vec = src.length != 1
        # only vector operands get the `*` marker; a scalar copy must use
        # plain register numbers.
        star = "*" if vec else ""
        # same criterion as AsmContext.needs_sv: vectors and registers
        # numbered 32 and up need the `sv.` prefix.
        sv_ = ""
        if vec or src.start >= 32 or dest.start >= 32:
            sv_ = "sv."
        mrr = ""
        # for overlapping ranges where the destination starts above the
        # source, copying in ascending order would overwrite source
        # elements before they are read, so the copy has to run in reverse.
        # NOTE(review): assumes `/mrr` runs the element loop in reverse
        # order -- confirm against the SVP64 spec.
        if vec and src.conflicts(dest) and src.start < dest.start:
            mrr = "/mrr"
        return (f"{sv_}or{mrr} {star}{dest.start}, "
                f"{star}{src.start}, {star}{src.start}\n")

    def _render(self, reg):
        # type: (RegLoc) -> str
        """Not supported; this segment needs two registers, use `render`."""
        raise TypeError("must call self.render")
557
558
@final
class AsmTemplate(Sequence["str | AsmTemplateSegment"]):
    """An immutable sequence of literal strings and `AsmTemplateSegment`s.

    Nested `AsmTemplate`s passed to the constructor are flattened into
    their segments.
    """

    @staticmethod
    def __process_segments(segments):
        # type: (Iterable[str | AsmTemplateSegment | AsmTemplate]) -> Iterable[str | AsmTemplateSegment]
        """Yield the items of `segments`, expanding any `AsmTemplate` into
        its own segments.
        """
        for item in segments:
            if not isinstance(item, AsmTemplate):
                yield item
                continue
            yield from item

    def __init__(self, segments=()):
        # type: (Iterable[str | AsmTemplateSegment | AsmTemplate]) -> None
        flattened = self.__process_segments(segments)
        self.__segments = tuple(flattened)

    def __getitem__(self, index):
        # type: (int) -> str | AsmTemplateSegment
        return self.__segments[index]

    def __len__(self):
        return len(self.__segments)

    def __iter__(self):
        return iter(self.__segments)

    def __hash__(self):
        return hash(self.__segments)

    def render(self, regs):
        # type: (dict[SSAVal, RegLoc]) -> str
        """Return the concatenation of all segments, with every
        `AsmTemplateSegment` rendered using the assignments in `regs`.
        """
        pieces = [
            segment.render(regs) if isinstance(segment, AsmTemplateSegment)
            else segment
            for segment in self
        ]  # type: list[str]
        return "".join(pieces)
596
597
@final
class AsmContext:
    """Lookup helper used while generating assembly: maps SSA values to
    their assigned registers and formats GPR operands.
    """

    def __init__(self, assigned_registers):
        # type: (dict[SSAVal, RegLoc]) -> None
        self.__assigned_registers = assigned_registers

    def reg(self, ssa_val, expected_ty):
        # type: (SSAVal[Any], Type[_RegLoc]) -> _RegLoc
        """Return the register assigned to `ssa_val`.

        Raises `ValueError` if `ssa_val` has no assigned register, and
        `TypeError` if the register isn't an instance of `expected_ty` or
        is a `GPRRange` whose length differs from `ssa_val.ty.length`.
        """
        try:
            reg = self.__assigned_registers[ssa_val]
        except KeyError as e:
            # chain explicitly so the original lookup failure is kept as
            # the cause of the error
            raise ValueError(
                f"SSAVal not assigned a register: {ssa_val}") from e
        wrong_len = (isinstance(reg, GPRRange)
                     and reg.length != ssa_val.ty.length)
        if not isinstance(reg, expected_ty) or wrong_len:
            raise TypeError(
                f"SSAVal is assigned a register of the wrong type: "
                f"ssa_val={ssa_val} expected_ty={expected_ty} reg={reg}")
        return reg

    def gpr_range(self, ssa_val):
        # type: (SSAGPRRange | SSAVal[FixedGPRRangeType]) -> GPRRange
        """Return the `GPRRange` assigned to `ssa_val`."""
        return self.reg(ssa_val, GPRRange)

    def stack_slot(self, ssa_val):
        # type: (SSAVal[StackSlotType]) -> StackSlot
        """Return the `StackSlot` assigned to `ssa_val`."""
        return self.reg(ssa_val, StackSlot)

    def gpr(self, ssa_val, vec, offset=0):
        # type: (SSAGPRRange | SSAVal[FixedGPRRangeType], bool, int) -> str
        """Return the operand text for `ssa_val`'s first register plus
        `offset`, prefixed with `*` when `vec` is true.
        """
        reg = self.gpr_range(ssa_val).start + offset
        return "*" * vec + str(reg)

    def vgpr(self, ssa_val, offset=0):
        # type: (SSAGPRRange | SSAVal[FixedGPRRangeType], int) -> str
        """Shorthand for `gpr(..., vec=True)`."""
        return self.gpr(ssa_val=ssa_val, vec=True, offset=offset)

    def sgpr(self, ssa_val, offset=0):
        # type: (SSAGPR | SSAVal[FixedGPRRangeType], int) -> str
        """Shorthand for `gpr(..., vec=False)`."""
        return self.gpr(ssa_val=ssa_val, vec=False, offset=offset)

    def needs_sv(self, *regs):
        # type: (*SSAGPRRange | SSAVal[FixedGPRRangeType]) -> bool
        """Return whether any of `regs` is assigned a range that is longer
        than one register or that starts at register 32 or above.
        """
        for ssa_val in regs:
            assigned = self.gpr_range(ssa_val)
            if assigned.length != 1 or assigned.start >= 32:
                return True
        return False
646
647
# size of one GPR in bytes
GPR_SIZE_IN_BYTES = 8
# size of one GPR in bits
GPR_SIZE_IN_BITS = GPR_SIZE_IN_BYTES * 8
# mask with the low GPR_SIZE_IN_BITS bits set; the pre-RA simulation uses it
# to truncate results to the width of one GPR
GPR_VALUE_MASK = (1 << GPR_SIZE_IN_BITS) - 1
651
652
@plain_data(frozen=True)
@final
class PreRASimState:
    """Values of SSA variables during simulation before register
    allocation: one dict per kind of register type, keyed by `SSAVal`.
    """
    __slots__ = ("gprs", "VLs", "CAs",
                 "global_mems", "stack_slots",
                 "fixed_gprs")

    def __init__(
        self,
        gprs,  # type: dict[SSAGPRRange, tuple[int, ...]]
        VLs,  # type: dict[SSAKnownVL, int]
        CAs,  # type: dict[SSAVal[CAType], bool]
        global_mems,  # type: dict[SSAVal[GlobalMemType], FMap[int, int]]
        stack_slots,  # type: dict[SSAVal[StackSlotType], tuple[int, ...]]
        fixed_gprs,  # type: dict[SSAVal[FixedGPRRangeType], tuple[int, ...]]
    ):
        # type: (...) -> None
        # GPR-backed values, one int per register
        self.gprs = gprs
        self.fixed_gprs = fixed_gprs
        # values held in stack slots, one int per slot
        self.stack_slots = stack_slots
        # vector lengths and carry bits
        self.VLs = VLs
        self.CAs = CAs
        # memory contents
        self.global_mems = global_mems
676
677
@plain_data(unsafe_hash=True, frozen=True, repr=False)
class Op(metaclass=ABCMeta):
    """Abstract base class of all operations.

    Constructing an op appends it to `fn.ops` and assigns it an `id` equal
    to its index in that list.
    """
    __slots__ = "id", "fn"

    @abstractmethod
    def inputs(self):
        # type: () -> dict[str, SSAVal]
        """Return the SSA values this op reads, keyed by argument name."""
        ...

    @abstractmethod
    def outputs(self):
        # type: () -> dict[str, SSAVal]
        """Return the SSA values this op writes, keyed by argument name."""
        ...

    def get_equality_constraints(self):
        # type: () -> Iterable[EqualityConstraint]
        """Yield this op's equality constraints; none by default."""
        yield from ()

    def get_extra_interferences(self):
        # type: () -> Iterable[tuple[SSAVal, SSAVal]]
        """Yield extra pairs of SSA values that interfere; none by default.
        """
        yield from ()

    def __init__(self, fn):
        # type: (Fn) -> None
        ops = fn.ops
        self.id = len(ops)
        ops.append(self)
        self.fn = fn

    @final
    def __repr__(self, just_id=False):
        """Return `ClassName(#id, ...)`.

        Unless `just_id`, every field other than `id` and `fn` is listed:
        output SSA values as their bare repr, all other fields as
        `name=value`. Fields that aren't set are shown as `<not set>`.
        """
        fields_list = [f"#{self.id}"]
        try:
            outputs = self.outputs()
        except AttributeError:
            # outputs() may read fields that haven't been assigned yet
            outputs = None
        if not just_id:
            for name in fields(self):
                if name == "id" or name == "fn":
                    continue
                v = getattr(self, name, _NOT_SET)
                is_output = (outputs is not None and name in outputs
                             and outputs[name] is v)
                if is_output:
                    fields_list.append(repr(v))
                else:
                    fields_list.append(f"{name}={v!r}")
        fields_str = ', '.join(fields_list)
        return f"{self.__class__.__name__}({fields_str})"

    @abstractmethod
    def get_asm_lines(self, ctx):
        # type: (AsmContext) -> list[str]
        """get the lines of assembly for this Op"""
        ...

    @abstractmethod
    def pre_ra_sim(self, state):
        # type: (PreRASimState) -> None
        """simulate op before register allocation"""
        ...
740
741
@plain_data(unsafe_hash=True, frozen=True, repr=False)
@final
class OpLoadFromStackSlot(Op):
    """Load a stack slot `src` into a GPR range `dest` of the same length.
    """
    __slots__ = "dest", "src", "vl"

    def inputs(self):
        # type: () -> dict[str, SSAVal]
        if self.vl is None:
            return {"src": self.src}
        return {"src": self.src, "vl": self.vl}

    def outputs(self):
        # type: () -> dict[str, SSAVal]
        return {"dest": self.dest}

    def __init__(self, fn, src, vl=None):
        # type: (Fn, SSAVal[StackSlotType], SSAKnownVL | None) -> None
        """Raises `ValueError` if `vl` doesn't match the slot length."""
        super().__init__(fn)
        dest_ty = GPRRangeType(src.ty.length_in_slots)
        self.dest = SSAVal(self, "dest", dest_ty)
        self.src = src
        self.vl = vl
        assert_vl_is(vl, self.dest.ty.length)

    def get_asm_lines(self, ctx):
        # type: (AsmContext) -> list[str]
        """Return a single `ld` (or `sv.ld` when needed) instruction with
        the slot's byte offset as displacement off r1.
        """
        is_vec = self.dest.ty.length != 1
        dest = ctx.gpr(self.dest, vec=is_vec)
        src = ctx.stack_slot(self.src)
        if not ctx.needs_sv(self.dest):
            return [f"ld {dest}, {src.start_byte}(1)"]
        return [f"sv.ld {dest}, {src.start_byte}(1)"]

    def pre_ra_sim(self, state):
        # type: (PreRASimState) -> None
        """simulate op before register allocation"""
        loaded = state.stack_slots[self.src]
        state.gprs[self.dest] = loaded
778
779
@plain_data(unsafe_hash=True, frozen=True, repr=False)
@final
class OpStoreToStackSlot(Op):
    """Store a GPR range `src` into a stack slot `dest` of the same length.
    """
    __slots__ = "dest", "src", "vl"

    def inputs(self):
        # type: () -> dict[str, SSAVal]
        if self.vl is None:
            return {"src": self.src}
        return {"src": self.src, "vl": self.vl}

    def outputs(self):
        # type: () -> dict[str, SSAVal]
        return {"dest": self.dest}

    def __init__(self, fn, src, vl=None):
        # type: (Fn, SSAGPRRange, SSAKnownVL | None) -> None
        """Raises `ValueError` if `vl` doesn't match the source length."""
        super().__init__(fn)
        dest_ty = StackSlotType(src.ty.length)
        self.dest = SSAVal(self, "dest", dest_ty)
        self.src = src
        self.vl = vl
        assert_vl_is(vl, src.ty.length)

    def get_asm_lines(self, ctx):
        # type: (AsmContext) -> list[str]
        """Return a single `std` (or `sv.std` when needed) instruction with
        the slot's byte offset as displacement off r1.
        """
        is_vec = self.src.ty.length != 1
        src = ctx.gpr(self.src, vec=is_vec)
        dest = ctx.stack_slot(self.dest)
        if not ctx.needs_sv(self.src):
            return [f"std {src}, {dest.start_byte}(1)"]
        return [f"sv.std {src}, {dest.start_byte}(1)"]

    def pre_ra_sim(self, state):
        # type: (PreRASimState) -> None
        """simulate op before register allocation"""
        stored = state.gprs[self.src]
        state.stack_slots[self.dest] = stored
816
817
# type variable for the register type of a copy's source (see OpCopy)
_RegSrcType = TypeVar("_RegSrcType", bound=RegType)
819
820
@plain_data(unsafe_hash=True, frozen=True, repr=False)
@final
class OpCopy(Op, Generic[_RegSrcType, _RegType]):
    """Copy `src` to a new SSA value `dest`.

    Source and destination must have the same type, except that a
    `GPRRangeType` and a `FixedGPRRangeType` of the same length may be
    copied to one another. Stack slots can't be copied with this op.
    """
    __slots__ = "dest", "src", "vl"

    def inputs(self):
        # type: () -> dict[str, SSAVal]
        retval = {"src": self.src}  # type: dict[str, SSAVal[Any]]
        if self.vl is not None:
            retval["vl"] = self.vl
        return retval

    def outputs(self):
        # type: () -> dict[str, SSAVal]
        return {"dest": self.dest}

    def __init__(self, fn, src, dest_ty=None, vl=None):
        # type: (Fn, SSAVal[_RegSrcType], _RegType | None, SSAKnownVL | None) -> None
        """`dest_ty` defaults to the type of `src`.

        Raises `ValueError` if the types are incompatible, if `src` is a
        stack slot, or if `vl` doesn't match the copied length (1 for
        non-GPR types).
        """
        super().__init__(fn)
        if dest_ty is None:
            dest_ty = cast(_RegType, src.ty)
        if isinstance(src.ty, GPRRangeType) \
                and isinstance(dest_ty, FixedGPRRangeType):
            if src.ty.length != dest_ty.reg.length:
                raise ValueError(f"incompatible source and destination "
                                 f"types: {src.ty} and {dest_ty}")
            length = src.ty.length
        elif isinstance(src.ty, FixedGPRRangeType) \
                and isinstance(dest_ty, GPRRangeType):
            if src.ty.reg.length != dest_ty.length:
                raise ValueError(f"incompatible source and destination "
                                 f"types: {src.ty} and {dest_ty}")
            length = src.ty.length
        elif src.ty != dest_ty:
            raise ValueError(f"incompatible source and destination "
                             f"types: {src.ty} and {dest_ty}")
        elif isinstance(src.ty, StackSlotType):
            raise ValueError("can't use OpCopy on stack slots")
        elif isinstance(src.ty, (GPRRangeType, FixedGPRRangeType)):
            length = src.ty.length
        else:
            length = 1

        self.dest = SSAVal(self, "dest", dest_ty)  # type: SSAVal[_RegType]
        self.src = src
        self.vl = vl
        assert_vl_is(vl, length)

    def get_asm_lines(self, ctx):
        # type: (AsmContext) -> list[str]
        """Return the instruction for the copy, or no lines at all when
        source and destination are assigned the same register.

        Raises `NotImplementedError` for copies that aren't GPR to GPR.
        """
        if ctx.reg(self.src, RegLoc) == ctx.reg(self.dest, RegLoc):
            return []
        if (isinstance(self.src.ty, (GPRRangeType, FixedGPRRangeType)) and
                isinstance(self.dest.ty, (GPRRangeType, FixedGPRRangeType))):
            vec = self.dest.ty.length != 1
            dest = ctx.gpr_range(self.dest)  # type: ignore
            src = ctx.gpr_range(self.src)  # type: ignore
            dest_s = ctx.gpr(self.dest, vec=vec)  # type: ignore
            src_s = ctx.gpr(self.src, vec=vec)  # type: ignore
            mrr = ""
            # for overlapping ranges where the destination starts above the
            # source, copying in ascending order would overwrite source
            # elements before they are read, so the copy has to run in
            # reverse.
            # NOTE(review): assumes `/mrr` runs the element loop in reverse
            # order -- confirm against the SVP64 spec.
            if src.conflicts(dest) and src.start < dest.start:
                mrr = "/mrr"
            if ctx.needs_sv(self.src, self.dest):  # type: ignore
                return [f"sv.or{mrr} {dest_s}, {src_s}, {src_s}"]
            return [f"or {dest_s}, {src_s}, {src_s}"]
        raise NotImplementedError

    def pre_ra_sim(self, state):
        # type: (PreRASimState) -> None
        """simulate op before register allocation"""
        gpr_types = (GPRRangeType, FixedGPRRangeType)
        if (isinstance(self.src.ty, gpr_types) and
                isinstance(self.dest.ty, gpr_types)):
            # this branch handles every combination of plain and fixed GPR
            # ranges, so no separate mixed-type branches are needed.
            if isinstance(self.src.ty, GPRRangeType):
                v = state.gprs[self.src]  # type: ignore
            else:
                v = state.fixed_gprs[self.src]  # type: ignore
            if isinstance(self.dest.ty, GPRRangeType):
                state.gprs[self.dest] = v  # type: ignore
            else:
                state.fixed_gprs[self.dest] = v  # type: ignore
        elif (isinstance(self.src.ty, CAType) and
                self.src.ty == self.dest.ty):
            state.CAs[self.dest] = state.CAs[self.src]  # type: ignore
        elif (isinstance(self.src.ty, KnownVLType) and
                self.src.ty == self.dest.ty):
            state.VLs[self.dest] = state.VLs[self.src]  # type: ignore
        elif (isinstance(self.src.ty, GlobalMemType) and
                self.src.ty == self.dest.ty):
            v = state.global_mems[self.src]  # type: ignore
            state.global_mems[self.dest] = v  # type: ignore
        else:
            raise NotImplementedError
918
919
@plain_data(unsafe_hash=True, frozen=True, repr=False)
@final
class OpConcat(Op):
    """Concatenate several GPR ranges into one GPR range `dest` whose
    length is the sum of the source lengths.

    Emits no assembly; `dest` is tied to the sources by an equality
    constraint.
    """
    __slots__ = "dest", "sources"

    def inputs(self):
        # type: () -> dict[str, SSAVal]
        retval = {}  # type: dict[str, SSAVal]
        for i, v in enumerate(self.sources):
            retval[f"sources[{i}]"] = v
        return retval

    def outputs(self):
        # type: () -> dict[str, SSAVal]
        return {"dest": self.dest}

    def __init__(self, fn, sources):
        # type: (Fn, Iterable[SSAGPRRange]) -> None
        super().__init__(fn)
        sources = tuple(sources)
        total_length = sum(i.ty.length for i in sources)
        self.dest = SSAVal(self, "dest", GPRRangeType(total_length))
        self.sources = sources

    def get_equality_constraints(self):
        # type: () -> Iterable[EqualityConstraint]
        yield EqualityConstraint([self.dest], list(self.sources))

    def get_asm_lines(self, ctx):
        # type: (AsmContext) -> list[str]
        return []

    def pre_ra_sim(self, state):
        # type: (PreRASimState) -> None
        """simulate op before register allocation"""
        state.gprs[self.dest] = tuple(
            word for src in self.sources for word in state.gprs[src])
955
956
@plain_data(unsafe_hash=True, frozen=True, repr=False)
@final
class OpSplit(Op):
    """Split a GPR range `src` into consecutive pieces `results` at the
    given split indexes.

    Emits no assembly; the results are tied to `src` by an equality
    constraint.
    """
    __slots__ = "results", "src"

    def inputs(self):
        # type: () -> dict[str, SSAVal]
        return {"src": self.src}

    def outputs(self):
        # type: () -> dict[str, SSAVal]
        retval = {}  # type: dict[str, SSAVal]
        for result in self.results:
            retval[result.arg_name] = result
        return retval

    def __init__(self, fn, src, split_indexes):
        # type: (Fn, SSAGPRRange, Iterable[int]) -> None
        """Each split index is where one piece ends and the next begins.

        Raises `ValueError` if an index isn't strictly between 0 and the
        length of `src`.
        """
        super().__init__(fn)
        ranges = []  # type: list[GPRRangeType]
        piece_start = 0
        for i in split_indexes:
            if i <= 0 or i >= src.ty.length:
                raise ValueError(f"invalid split index: {i}, must be in "
                                 f"0 < i < {src.ty.length}")
            ranges.append(GPRRangeType(i - piece_start))
            piece_start = i
        # the final piece runs from the last split index to the end
        ranges.append(GPRRangeType(src.ty.length - piece_start))
        self.src = src
        self.results = tuple(
            SSAVal(self, f"results[{i}]", r) for i, r in enumerate(ranges))

    def get_equality_constraints(self):
        # type: () -> Iterable[EqualityConstraint]
        yield EqualityConstraint(list(self.results), [self.src])

    def get_asm_lines(self, ctx):
        # type: (AsmContext) -> list[str]
        return []

    def pre_ra_sim(self, state):
        # type: (PreRASimState) -> None
        """simulate op before register allocation"""
        remaining = state.gprs[self.src]
        # peel the pieces off the end of the value, last result first
        for dest in self.results[::-1]:
            piece_len = dest.ty.length
            state.gprs[dest] = remaining[-piece_len:]
            remaining = remaining[:-piece_len]
1000
1001
@plain_data(unsafe_hash=True, frozen=True, repr=False)
@final
class OpBigIntAddSub(Op):
    """Multi-word add or subtract with carry.

    The simulation computes `lhs + rhs + CA_in` word by word, propagating
    the carry, with each word of `rhs` bit-inverted first when `is_sub`.
    """
    __slots__ = "out", "lhs", "rhs", "CA_in", "CA_out", "is_sub", "vl"

    def inputs(self):
        # type: () -> dict[str, SSAVal]
        retval = {
            "lhs": self.lhs,
            "rhs": self.rhs,
            "CA_in": self.CA_in,
        }  # type: dict[str, SSAVal[Any]]
        if self.vl is not None:
            retval["vl"] = self.vl
        return retval

    def outputs(self):
        # type: () -> dict[str, SSAVal]
        return {"out": self.out, "CA_out": self.CA_out}

    def __init__(self, fn, lhs, rhs, CA_in, is_sub, vl=None):
        # type: (Fn, SSAGPRRange, SSAGPRRange, SSAVal[CAType], bool, SSAKnownVL | None) -> None
        """Raises `TypeError` if `lhs` and `rhs` have different types and
        `ValueError` if `vl` doesn't match their length.
        """
        super().__init__(fn)
        if lhs.ty != rhs.ty:
            raise TypeError(f"source types must match: "
                            f"{lhs} doesn't match {rhs}")
        self.lhs = lhs
        self.rhs = rhs
        self.CA_in = CA_in
        self.is_sub = is_sub
        self.vl = vl
        self.out = SSAVal(self, "out", lhs.ty)
        self.CA_out = SSAVal(self, "CA_out", CA_in.ty)
        assert_vl_is(vl, lhs.ty.length)

    def get_extra_interferences(self):
        # type: () -> Iterable[tuple[SSAVal, SSAVal]]
        """The output interferes with both inputs."""
        for operand in (self.lhs, self.rhs):
            yield self.out, operand

    def get_asm_lines(self, ctx):
        # type: (AsmContext) -> list[str]
        """Return a single `adde`/`subfe` instruction, `sv.`-prefixed when
        any operand needs it.
        """
        vec = self.out.ty.length != 1
        out = ctx.gpr(self.out, vec=vec)
        RA = ctx.gpr(self.lhs, vec=vec)
        RB = ctx.gpr(self.rhs, vec=vec)
        if self.is_sub:
            mnemonic = "subfe"
            RA, RB = RB, RA  # reorder to match subfe
        else:
            mnemonic = "adde"
        if not ctx.needs_sv(self.out, self.lhs, self.rhs):
            return [f"{mnemonic} {out}, {RA}, {RB}"]
        return [f"sv.{mnemonic} {out}, {RA}, {RB}"]

    def pre_ra_sim(self, state):
        # type: (PreRASimState) -> None
        """simulate op before register allocation"""
        carry = state.CAs[self.CA_in]
        words = []  # type: list[int]
        for l, r in zip(state.gprs[self.lhs], state.gprs[self.rhs]):
            addend = r ^ GPR_VALUE_MASK if self.is_sub else r
            total = l + addend + carry
            truncated = total & GPR_VALUE_MASK
            # a carry-out happened exactly when truncation changed the sum
            carry = total != truncated
            words.append(truncated)
        state.CAs[self.CA_out] = carry
        state.gprs[self.out] = tuple(words)
1067
1068
@plain_data(unsafe_hash=True, frozen=True, repr=False)
@final
class OpBigIntMulDiv(Op):
    """Multiply or divide a multi-word value `RA` by the single word `RB`.

    `RC` is the carry/remainder input and `RS` the carry/remainder output;
    the two are tied together by an equality constraint. `RT` receives the
    product or quotient words.
    """
    __slots__ = "RT", "RA", "RB", "RC", "RS", "is_div", "vl"

    def inputs(self):
        # type: () -> dict[str, SSAVal]
        retval = {
            "RA": self.RA,
            "RB": self.RB,
            "RC": self.RC,
        }  # type: dict[str, SSAVal[Any]]
        if self.vl is not None:
            retval["vl"] = self.vl
        return retval

    def outputs(self):
        # type: () -> dict[str, SSAVal]
        return {"RT": self.RT, "RS": self.RS}

    def __init__(self, fn, RA, RB, RC, is_div, vl):
        # type: (Fn, SSAGPRRange, SSAGPR, SSAGPR, bool, SSAKnownVL | None) -> None
        """Raises `ValueError` if `vl` doesn't match the length of `RA`."""
        super().__init__(fn)
        self.RA = RA
        self.RB = RB
        self.RC = RC
        self.is_div = is_div
        self.vl = vl
        self.RT = SSAVal(self, "RT", RA.ty)
        self.RS = SSAVal(self, "RS", RC.ty)
        assert_vl_is(vl, RA.ty.length)

    def get_equality_constraints(self):
        # type: () -> Iterable[EqualityConstraint]
        yield EqualityConstraint([self.RC], [self.RS])

    def get_extra_interferences(self):
        # type: () -> Iterable[tuple[SSAVal, SSAVal]]
        """`RT` interferes with every other operand, and `RS` with `RA`
        and `RB`.
        """
        for operand in (self.RA, self.RB, self.RC, self.RS):
            yield self.RT, operand
        for operand in (self.RA, self.RB):
            yield self.RS, operand

    def get_asm_lines(self, ctx):
        # type: (AsmContext) -> list[str]
        """Return a single `sv.maddedu` or `sv.divmod2du/mrr` instruction.
        """
        vec = self.RT.ty.length != 1
        RT = ctx.gpr(self.RT, vec=vec)
        RA = ctx.gpr(self.RA, vec=vec)
        RB = ctx.sgpr(self.RB)
        RC = ctx.sgpr(self.RC)
        mnemonic = "divmod2du/mrr" if self.is_div else "maddedu"
        return [f"sv.{mnemonic} {RT}, {RA}, {RB}, {RC}"]

    def pre_ra_sim(self, state):
        # type: (PreRASimState) -> None
        """simulate op before register allocation"""
        carry = state.gprs[self.RC][0]
        RA = state.gprs[self.RA]
        RB = state.gprs[self.RB][0]
        word_count = self.RT.ty.length
        RT = [0] * word_count
        if not self.is_div:
            # multiply: lowest word first, carrying the high part upward
            for i in range(word_count):
                product = RA[i] * RB + carry
                RT[i] = product & GPR_VALUE_MASK
                carry = product >> 64
        else:
            # divide: highest word first, the remainder feeds the next word
            for i in range(word_count - 1, -1, -1):
                if RB == 0 or carry >= RB:
                    # quotient wouldn't fit in one word (or divide by zero)
                    RT[i] = GPR_VALUE_MASK
                    carry = 0
                    continue
                div, mod = divmod((carry << 64) | RA[i], RB)
                RT[i] = div & GPR_VALUE_MASK
                carry = mod & GPR_VALUE_MASK
        state.gprs[self.RS] = (carry,)
        state.gprs[self.RT] = tuple(RT)
1147
1148
@final
@unique
class ShiftKind(Enum):
    """The kind of shift: left (`Sl`), right (`Sr`) or right-arithmetic
    (`Sra`)."""
    Sl = "sl"
    Sr = "sr"
    Sra = "sra"

    def make_big_int_carry_in(self, fn, inp):
        # type: (Fn, SSAGPRRange) -> tuple[SSAGPR, list[Op]]
        """Build the carry-in word for a big-integer shift of `inp`.

        Returns the carry-in SSA value together with the list of ops that
        produce it.  `Sl` and `Sr` use a constant zero; `Sra` takes the
        element split off at index `inp.ty.length - 1` and shifts it
        right arithmetically by 63.
        """
        if self is not ShiftKind.Sl and self is not ShiftKind.Sr:
            assert self is ShiftKind.Sra
            split = OpSplit(fn, inp, [inp.ty.length - 1])
            sign_fill = OpShiftImm(
                fn, split.results[1], sh=63, kind=ShiftKind.Sra)
            return sign_fill.out, [split, sign_fill]
        zero = OpLI(fn, 0)
        return zero.out, [zero]

    def make_big_int_shift(self, fn, inp, sh, vl):
        # type: (Fn, SSAGPRRange, SSAGPR, SSAKnownVL | None) -> tuple[SSAGPRRange, list[Op]]
        """Build a big-integer shift of `inp` by `sh` of this kind.

        Returns the shifted SSA value and every op needed to compute it
        (the carry-in ops followed by the `OpBigIntShift` itself).
        """
        carry_in, ops = self.make_big_int_carry_in(fn, inp)
        shift_op = OpBigIntShift(fn, inp, sh, carry_in, kind=self, vl=vl)
        ops.append(shift_op)
        return shift_op.out, ops
1173
1174
@plain_data(unsafe_hash=True, frozen=True, repr=False)
@final
class OpBigIntShift(Op):
    """Shift a multi-word integer `inp` by the amount in `sh`.

    `carry_in` supplies the word shifted in beyond the end of `inp`
    (below the first word for `Sl`, above the last word otherwise).
    `_out_padding` is an extra one-register output that only takes part
    in the equality constraints; the simulator never writes it.
    """
    __slots__ = "out", "inp", "carry_in", "_out_padding", "sh", "kind", "vl"

    def inputs(self):
        # type: () -> dict[str, SSAVal]
        named_inputs = {
            "inp": self.inp,
            "sh": self.sh,
            "carry_in": self.carry_in,
        }  # type: dict[str, SSAVal[Any]]
        if self.vl is not None:
            named_inputs["vl"] = self.vl
        return named_inputs

    def outputs(self):
        # type: () -> dict[str, SSAVal]
        named_outputs = {}  # type: dict[str, SSAVal[Any]]
        named_outputs["out"] = self.out
        named_outputs["_out_padding"] = self._out_padding
        return named_outputs

    def __init__(self, fn, inp, sh, carry_in, kind, vl=None):
        # type: (Fn, SSAGPRRange, SSAGPR, SSAGPR, ShiftKind, SSAKnownVL | None) -> None
        super().__init__(fn)
        self.out = SSAVal(self, "out", inp.ty)
        self._out_padding = SSAVal(self, "_out_padding", GPRRangeType())
        self.carry_in = carry_in
        self.inp = inp
        self.sh = sh
        self.kind = kind
        self.vl = vl
        assert_vl_is(vl, inp.ty.length)

    def get_extra_interferences(self):
        # type: () -> Iterable[tuple[SSAVal, SSAVal]]
        yield self.out, self.sh

    def get_equality_constraints(self):
        # type: () -> Iterable[EqualityConstraint]
        if self.kind is ShiftKind.Sl:
            # carry_in sits just before inp; out overlays inp shifted up by
            # one register with the padding after it
            yield EqualityConstraint([self.carry_in, self.inp],
                                     [self.out, self._out_padding])
            return
        assert self.kind is ShiftKind.Sr or self.kind is ShiftKind.Sra
        # carry_in sits just after inp; the padding comes first
        yield EqualityConstraint([self.inp, self.carry_in],
                                 [self._out_padding, self.out])

    def get_asm_lines(self, ctx):
        # type: (AsmContext) -> list[str]
        is_vector = self.out.ty.length != 1
        shifting_left = self.kind is ShiftKind.Sl
        if not shifting_left:
            assert self.kind is ShiftKind.Sr or self.kind is ShiftKind.Sra
        RT = ctx.gpr(self.out, vec=is_vector)
        # RA is named relative to `out`: one register below it for left
        # shifts, one register above it for right shifts
        RA = ctx.gpr(self.out, vec=is_vector,
                     offset=-1 if shifting_left else 1)
        RB = ctx.sgpr(self.sh)
        if shifting_left:
            mrr = "/mrr" if is_vector else ""
            return [f"sv.dsld{mrr} {RT}, {RA}, {RB}, 0"]
        return [f"sv.dsrd {RT}, {RA}, {RB}, 1"]

    def pre_ra_sim(self, state):
        # type: (PreRASimState) -> None
        """Simulate the shift; only the low 6 bits of `sh` are used."""
        length = self.out.ty.length
        result = [0] * length
        carry = state.gprs[self.carry_in][0]
        amount = state.gprs[self.sh][0] % 64
        if self.kind is ShiftKind.Sl:
            words = (carry, *state.gprs[self.inp])
            for idx in range(length - 1, -1, -1):
                pair = words[idx] | (words[idx + 1] << 64)
                # keep the upper word of the shifted 2-word value
                result[idx] = ((pair << amount) >> 64) & GPR_VALUE_MASK
        else:
            assert self.kind is ShiftKind.Sr or self.kind is ShiftKind.Sra
            words = (*state.gprs[self.inp], carry)
            for idx in range(length):
                pair = words[idx] | (words[idx + 1] << 64)
                # keep the lower word of the shifted 2-word value
                result[idx] = (pair >> amount) & GPR_VALUE_MASK
        # state.gprs[self._out_padding] is intentionally not written
        state.gprs[self.out] = tuple(result)
1256
1257
@plain_data(unsafe_hash=True, frozen=True, repr=False)
@final
class OpShiftImm(Op):
    """Shift a single register `inp` by the immediate amount `sh`.

    `sh` must satisfy `0 <= sh < 64`.  For `ShiftKind.Sra` the op also
    produces a carry output `ca_out`; for the other kinds `ca_out` is
    `None`.
    """
    __slots__ = "out", "inp", "sh", "kind", "ca_out"

    def inputs(self):
        # type: () -> dict[str, SSAVal]
        return {"inp": self.inp}

    def outputs(self):
        # type: () -> dict[str, SSAVal]
        if self.ca_out is not None:
            return {"out": self.out, "ca_out": self.ca_out}
        return {"out": self.out}

    def __init__(self, fn, inp, sh, kind):
        # type: (Fn, SSAGPR, int, ShiftKind) -> None
        """Create the op.

        Raises `ValueError` if `sh` is not in `range(64)`.
        """
        super().__init__(fn)
        self.out = SSAVal(self, "out", inp.ty)
        self.inp = inp
        if not (0 <= sh < 64):
            raise ValueError("shift amount out of range")
        self.sh = sh
        self.kind = kind
        if self.kind is ShiftKind.Sra:
            self.ca_out = SSAVal(self, "ca_out", CAType())
        else:
            self.ca_out = None

    def get_asm_lines(self, ctx):
        # type: (AsmContext) -> list[str]
        out = ctx.sgpr(self.out)
        inp = ctx.sgpr(self.inp)
        if self.kind is ShiftKind.Sl:
            # shift left: rotate left by sh, keep bits 0..63-sh
            mnemonic = "rldicr"
            args = f"{self.sh}, {63 - self.sh}"
        elif self.kind is ShiftKind.Sr:
            # shift right: rotate left by 64-sh, clear the top sh bits
            mnemonic = "rldicl"
            v = (64 - self.sh) % 64
            args = f"{v}, {self.sh}"
        else:
            assert self.kind is ShiftKind.Sra
            mnemonic = "sradi"
            args = f"{self.sh}"
        if ctx.needs_sv(self.out, self.inp):
            return [f"sv.{mnemonic} {out}, {inp}, {args}"]
        return [f"{mnemonic} {out}, {inp}, {args}"]

    def pre_ra_sim(self, state):
        # type: (PreRASimState) -> None
        """Simulate the shift, writing `out` (and `ca_out` for `Sra`).

        The stored result is always masked with `GPR_VALUE_MASK` so that
        bits shifted out of the top by `Sl`, and the negative Python int
        produced by the sign-extended `Sra` computation, never leak into
        the simulated register.
        """
        inp = state.gprs[self.inp][0]
        if self.kind is ShiftKind.Sl:
            assert self.ca_out is None
            out = inp << self.sh
        elif self.kind is ShiftKind.Sr:
            assert self.ca_out is None
            out = inp >> self.sh
        else:
            assert self.kind is ShiftKind.Sra
            assert self.ca_out is not None
            if inp & (1 << 63):  # sign extend
                inp -= 1 << 64
            out = inp >> self.sh
            # CA is set when a negative input had any 1 bits shifted out;
            # this must be computed from the unmasked signed values
            ca = inp < 0 and (out << self.sh) != inp
            state.CAs[self.ca_out] = ca
        state.gprs[self.out] = (out & GPR_VALUE_MASK,)
1324
1325
@plain_data(unsafe_hash=True, frozen=True, repr=False)
@final
class OpLI(Op):
    """Load the immediate `value` into every element of `out`.

    `out` has one element when `vl` is `None`, otherwise `vl.ty.length`
    elements.  `value` must fit in a signed 16-bit immediate.
    """
    __slots__ = "out", "value", "vl"

    def inputs(self):
        # type: () -> dict[str, SSAVal]
        if self.vl is None:
            return {}
        return {"vl": self.vl}

    def outputs(self):
        # type: () -> dict[str, SSAVal]
        named_outputs = {}  # type: dict[str, SSAVal[Any]]
        named_outputs["out"] = self.out
        return named_outputs

    def __init__(self, fn, value, vl=None):
        # type: (Fn, int, SSAKnownVL | None) -> None
        """Create the op.

        Raises `ValueError` if `value` is outside `-2**15 .. 2**15 - 1`.
        """
        super().__init__(fn)
        length = 1 if vl is None else vl.ty.length
        self.out = SSAVal(self, "out", GPRRangeType(length))
        lowest = -1 << 15
        highest = (1 << 15) - 1
        if not (lowest <= value <= highest):
            raise ValueError(f"value out of range: {value}")
        self.value = value
        self.vl = vl
        assert_vl_is(vl, length)

    def get_asm_lines(self, ctx):
        # type: (AsmContext) -> list[str]
        is_vector = self.out.ty.length != 1
        out = ctx.gpr(self.out, vec=is_vector)
        prefix = "sv." if ctx.needs_sv(self.out) else ""
        return [f"{prefix}addi {out}, 0, {self.value}"]

    def pre_ra_sim(self, state):
        # type: (PreRASimState) -> None
        # negative immediates are stored masked with GPR_VALUE_MASK
        word = self.value & GPR_VALUE_MASK
        state.gprs[self.out] = (word,) * self.out.ty.length
1368
1369
@plain_data(unsafe_hash=True, frozen=True, repr=False)
@final
class OpSetCA(Op):
    """Set a carry (CA) value `out` to the constant boolean `value`."""
    __slots__ = "out", "value"

    def inputs(self):
        # type: () -> dict[str, SSAVal]
        return dict()

    def outputs(self):
        # type: () -> dict[str, SSAVal]
        named_outputs = {}  # type: dict[str, SSAVal[Any]]
        named_outputs["out"] = self.out
        return named_outputs

    def __init__(self, fn, value):
        # type: (Fn, bool) -> None
        super().__init__(fn)
        self.out = SSAVal(self, "out", CAType())
        self.value = value

    def get_asm_lines(self, ctx):
        # type: (AsmContext) -> list[str]
        # one fixed instruction per constant: `subfic` when `value` is
        # true, `addic` when it is false
        return ["subfic 0, 0, -1"] if self.value else ["addic 0, 0, 0"]

    def pre_ra_sim(self, state):
        # type: (PreRASimState) -> None
        state.CAs[self.out] = self.value
1398
1399
@plain_data(unsafe_hash=True, frozen=True, repr=False)
@final
class OpLoad(Op):
    """Load `RT` from the memory `mem` at address `RA + offset`.

    `RT` has one element when `vl` is `None`, otherwise `vl.ty.length`
    elements, read from consecutive `GPR_SIZE_IN_BYTES`-sized slots.
    """
    __slots__ = "RT", "RA", "offset", "mem", "vl"

    def inputs(self):
        # type: () -> dict[str, SSAVal]
        named_inputs = {
            "RA": self.RA,
            "mem": self.mem,
        }  # type: dict[str, SSAVal[Any]]
        if self.vl is not None:
            named_inputs["vl"] = self.vl
        return named_inputs

    def outputs(self):
        # type: () -> dict[str, SSAVal]
        named_outputs = {}  # type: dict[str, SSAVal[Any]]
        named_outputs["RT"] = self.RT
        return named_outputs

    def __init__(self, fn, RA, offset, mem, vl=None):
        # type: (Fn, SSAGPR, int, SSAVal[GlobalMemType], SSAKnownVL | None) -> None
        """Create the op.

        Raises `ValueError` if `offset` doesn't fit in a signed 16-bit
        immediate or isn't a multiple of 4.
        """
        super().__init__(fn)
        length = 1 if vl is None else vl.ty.length
        self.RT = SSAVal(self, "RT", GPRRangeType(length))
        self.RA = RA
        lowest = -1 << 15
        highest = (1 << 15) - 1
        if not (lowest <= offset <= highest):
            raise ValueError(f"offset out of range: {offset}")
        if offset % 4 != 0:
            raise ValueError(f"offset not aligned: {offset}")
        self.offset = offset
        self.mem = mem
        self.vl = vl
        assert_vl_is(vl, length)

    def get_extra_interferences(self):
        # type: () -> Iterable[tuple[SSAVal, SSAVal]]
        # only a multi-register destination is kept apart from RA
        if self.RT.ty.length > 1:
            yield self.RT, self.RA

    def get_asm_lines(self, ctx):
        # type: (AsmContext) -> list[str]
        RT = ctx.gpr(self.RT, vec=self.RT.ty.length != 1)
        RA = ctx.sgpr(self.RA)
        prefix = "sv." if ctx.needs_sv(self.RT, self.RA) else ""
        return [f"{prefix}ld {RT}, {self.offset}({RA})"]

    def pre_ra_sim(self, state):
        # type: (PreRASimState) -> None
        """Simulate the load; bytes missing from memory read as zero.

        Raises `ValueError` if any element address is not a multiple of
        `GPR_SIZE_IN_BYTES`.
        """
        base = state.gprs[self.RA][0] + self.offset
        length = self.RT.ty.length
        loaded = [0] * length
        mem = state.global_mems[self.mem]
        for index in range(length):
            cur_addr = (base + index * GPR_SIZE_IN_BYTES) & GPR_VALUE_MASK
            if cur_addr % GPR_SIZE_IN_BYTES != 0:
                raise ValueError(f"can't load from unaligned address: "
                                 f"{cur_addr:#x}")
            word = 0
            # assemble the word from bytes, lowest address = lowest byte
            for byte_index in range(GPR_SIZE_IN_BYTES):
                byte_val = mem.get(cur_addr + byte_index, 0) & 0xFF
                word |= byte_val << (byte_index * 8)
            loaded[index] = word
        state.gprs[self.RT] = tuple(loaded)
1464
1465
@plain_data(unsafe_hash=True, frozen=True, repr=False)
@final
class OpStore(Op):
    """Store `RS` to memory at address `RA + offset`.

    Memory is threaded through the op as SSA values: `mem_in` is the
    memory before the store and `mem_out` the memory after it.
    """
    __slots__ = "RS", "RA", "offset", "mem_in", "mem_out", "vl"

    def inputs(self):
        # type: () -> dict[str, SSAVal]
        named_inputs = {
            "RS": self.RS,
            "RA": self.RA,
            "mem_in": self.mem_in,
        }  # type: dict[str, SSAVal[Any]]
        if self.vl is not None:
            named_inputs["vl"] = self.vl
        return named_inputs

    def outputs(self):
        # type: () -> dict[str, SSAVal]
        named_outputs = {}  # type: dict[str, SSAVal[Any]]
        named_outputs["mem_out"] = self.mem_out
        return named_outputs

    def __init__(self, fn, RS, RA, offset, mem_in, vl=None):
        # type: (Fn, SSAGPRRange, SSAGPR, int, SSAVal[GlobalMemType], SSAKnownVL | None) -> None
        """Create the op.

        Raises `ValueError` if `offset` doesn't fit in a signed 16-bit
        immediate or isn't a multiple of 4.
        """
        super().__init__(fn)
        self.RS = RS
        self.RA = RA
        lowest = -1 << 15
        highest = (1 << 15) - 1
        if not (lowest <= offset <= highest):
            raise ValueError(f"offset out of range: {offset}")
        if offset % 4 != 0:
            raise ValueError(f"offset not aligned: {offset}")
        self.offset = offset
        self.mem_in = mem_in
        self.mem_out = SSAVal(self, "mem_out", mem_in.ty)
        self.vl = vl
        assert_vl_is(vl, RS.ty.length)

    def get_asm_lines(self, ctx):
        # type: (AsmContext) -> list[str]
        RS = ctx.gpr(self.RS, vec=self.RS.ty.length != 1)
        RA = ctx.sgpr(self.RA)
        prefix = "sv." if ctx.needs_sv(self.RS, self.RA) else ""
        return [f"{prefix}std {RS}, {self.offset}({RA})"]

    def pre_ra_sim(self, state):
        # type: (PreRASimState) -> None
        """Simulate the store into a copy of the `mem_in` contents.

        Raises `ValueError` if any element address is not a multiple of
        `GPR_SIZE_IN_BYTES`.
        """
        # copy, so the memory value belonging to `mem_in` stays untouched
        contents = dict(state.global_mems[self.mem_in])
        base = state.gprs[self.RA][0] + self.offset
        words = state.gprs[self.RS]
        for index in range(self.RS.ty.length):
            cur_addr = (base + index * GPR_SIZE_IN_BYTES) & GPR_VALUE_MASK
            if cur_addr % GPR_SIZE_IN_BYTES != 0:
                raise ValueError(f"can't store to unaligned address: "
                                 f"{cur_addr:#x}")
            # lowest byte of the word goes to the lowest address
            for byte_index in range(GPR_SIZE_IN_BYTES):
                byte_val = (words[index] >> (byte_index * 8)) & 0xFF
                contents[cur_addr + byte_index] = byte_val
        state.global_mems[self.mem_out] = FMap(contents)
1522
1523
@plain_data(unsafe_hash=True, frozen=True, repr=False)
@final
class OpFuncArg(Op):
    """Declare a function argument `out` of the fixed register type `ty`.

    Emits no assembly.
    """
    __slots__ = ("out",)

    def inputs(self):
        # type: () -> dict[str, SSAVal]
        return dict()

    def outputs(self):
        # type: () -> dict[str, SSAVal]
        named_outputs = {}  # type: dict[str, SSAVal[Any]]
        named_outputs["out"] = self.out
        return named_outputs

    def __init__(self, fn, ty):
        # type: (Fn, FixedGPRRangeType) -> None
        super().__init__(fn)
        self.out = SSAVal(self, "out", ty)

    def get_asm_lines(self, ctx):
        # type: (AsmContext) -> list[str]
        return list()

    def pre_ra_sim(self, state):
        # type: (PreRASimState) -> None
        # a value already supplied in the simulator state is kept;
        # otherwise the argument defaults to all zeros
        if self.out in state.fixed_gprs:
            return
        state.fixed_gprs[self.out] = (0,) * self.out.ty.length
1550
1551
@plain_data(unsafe_hash=True, frozen=True, repr=False)
@final
class OpInputMem(Op):
    """Declare the memory value `out` the function starts with.

    Emits no assembly.
    """
    __slots__ = ("out",)

    def inputs(self):
        # type: () -> dict[str, SSAVal]
        return dict()

    def outputs(self):
        # type: () -> dict[str, SSAVal]
        named_outputs = {}  # type: dict[str, SSAVal[Any]]
        named_outputs["out"] = self.out
        return named_outputs

    def __init__(self, fn):
        # type: (Fn) -> None
        super().__init__(fn)
        self.out = SSAVal(self, "out", GlobalMemType())

    def get_asm_lines(self, ctx):
        # type: (AsmContext) -> list[str]
        return list()

    def pre_ra_sim(self, state):
        # type: (PreRASimState) -> None
        # memory already supplied in the simulator state is kept;
        # otherwise start from an empty memory
        if self.out in state.global_mems:
            return
        state.global_mems[self.out] = FMap()
1578
1579
@plain_data(unsafe_hash=True, frozen=True, repr=False)
@final
class OpSetVLImm(Op):
    """Set the vector length to the constant `length`.

    The length is carried in the type of `out` (`KnownVLType(length)`).
    """
    __slots__ = ("out",)

    def inputs(self):
        # type: () -> dict[str, SSAVal]
        return dict()

    def outputs(self):
        # type: () -> dict[str, SSAVal]
        named_outputs = {}  # type: dict[str, SSAVal[Any]]
        named_outputs["out"] = self.out
        return named_outputs

    def __init__(self, fn, length):
        # type: (Fn, int) -> None
        super().__init__(fn)
        self.out = SSAVal(self, "out", KnownVLType(length))

    def get_asm_lines(self, ctx):
        # type: (AsmContext) -> list[str]
        length = self.out.ty.length
        return [f"setvl 0, 0, {length}, 0, 1, 1"]

    def pre_ra_sim(self, state):
        # type: (PreRASimState) -> None
        state.VLs[self.out] = self.out.ty.length
1605
1606
def op_set_to_list(ops):
    # type: (Iterable[Op]) -> list[Op]
    """Order `ops` so that every op comes after the ops producing its inputs.

    Each op must provide `inputs()` and `outputs()` returning dicts whose
    values are hashable SSA values.  Ops that are ready at the same time
    keep the order in which they became ready.

    An op that reads the same SSA value through several of its inputs
    waits on that value only once.

    Raises `ValueError` if two ops write the same SSA value, or if some op
    can never run because it is part of a dependency loop or one of its
    inputs is never written.
    """
    # worklists[n] holds (as an ordered set) the ops still waiting on
    # exactly n distinct input values
    worklists = [{}]  # type: list[dict[Op, None]]
    inps_to_ops_map = defaultdict(dict)  # type: dict[SSAVal, dict[Op, None]]
    ops_to_pending_input_count_map = {}  # type: dict[Op, int]
    for op in ops:
        # deduplicate: each value is released only once when it becomes
        # ready, so it must only be counted once here
        distinct_inputs = dict.fromkeys(op.inputs().values())
        input_count = len(distinct_inputs)
        for val in distinct_inputs:
            inps_to_ops_map[val][op] = None
        while len(worklists) <= input_count:
            worklists.append({})
        ops_to_pending_input_count_map[op] = input_count
        worklists[input_count][op] = None
    retval = []  # type: list[Op]
    # only used for membership tests, so ordering doesn't matter
    ready_vals = set()  # type: set[SSAVal]
    while len(worklists[0]) != 0:
        writing_op = next(iter(worklists[0]))
        del worklists[0][writing_op]
        retval.append(writing_op)
        for val in writing_op.outputs().values():
            if val in ready_vals:
                raise ValueError(f"multiple instructions must not write "
                                 f"to the same SSA value: {val}")
            ready_vals.add(val)
            # move every reader of val one worklist closer to ready
            for reading_op in inps_to_ops_map[val]:
                pending = ops_to_pending_input_count_map[reading_op]
                del worklists[pending][reading_op]
                pending -= 1
                worklists[pending][reading_op] = None
                ops_to_pending_input_count_map[reading_op] = pending
    # anything left over never had all of its inputs written
    for worklist in worklists:
        for op in worklist:
            raise ValueError(f"instruction is part of a dependency loop or "
                             f"its inputs are never written: {op}")
    return retval
1643
1644
def generate_assembly(ops, assigned_registers=None):
    # type: (list[Op], dict[SSAVal, RegLoc] | None) -> list[str]
    """Return the assembly lines for `ops`, in order.

    When `assigned_registers` is `None`, the register allocator is run on
    `ops` first.  A final `bclr 20, 0, 0` line is always appended after
    the lines of the last op.
    """
    if assigned_registers is None:
        # imported here rather than at the top of the file
        from bigint_presentation_code.register_allocator import \
            allocate_registers
        assigned_registers = allocate_registers(ops)
    ctx = AsmContext(assigned_registers)
    lines = [
        line
        for op in ops
        for line in op.get_asm_lines(ctx)
    ]  # type: list[str]
    lines.append("bclr 20, 0, 0")
    return lines