remove unused code I forgot
[bigint-presentation-code.git] / src / bigint_presentation_code / compiler_ir.py
1 """
2 Compiler IR for Toom-Cook algorithm generator for SVP64
3
4 This assumes VL != 0 throughout.
5 """
6
7 from abc import ABCMeta, abstractmethod
8 from collections import defaultdict
9 from enum import Enum, EnumMeta, unique
10 from functools import lru_cache
11 from typing import Any, Generic, Iterable, Sequence, Type, TypeVar, cast
12
13 from nmutil.plain_data import fields, plain_data
14
15 from bigint_presentation_code.util import FMap, OFSet, OSet, final
16
17
class ABCEnumMeta(EnumMeta, ABCMeta):
    """Metaclass combining EnumMeta and ABCMeta, so an Enum can also derive
    from an abstract base class such as RegLoc."""
20
21
class RegLoc(metaclass=ABCMeta):
    """Abstract base for a physical register location that the register
    allocator can assign to an SSAVal."""
    __slots__ = ()

    @abstractmethod
    def conflicts(self, other):
        # type: (RegLoc) -> bool
        """return True if `self` and `other` overlap"""
        ...

    def get_subreg_at_offset(self, subreg_type, offset):
        # type: (RegType, int) -> RegLoc
        """Return the sub-register of type `subreg_type` at `offset` within
        `self`.

        This default only supports the trivial case: `self` must be a member
        of `subreg_type.reg_class` and `offset` must be 0, in which case
        `self` is returned. Raises ValueError otherwise.
        """
        if self not in subreg_type.reg_class:
            raise ValueError(f"register not a member of subreg_type: "
                             f"reg={self} subreg_type={subreg_type}")
        if offset == 0:
            return self
        raise ValueError(f"non-zero sub-register offset not supported "
                         f"for register: {self}")
39
40
# number of GPRs (r0 .. r{GPR_COUNT - 1}); GPRRange rejects ranges that
# extend past it
GPR_COUNT = 128
42
43
@plain_data(frozen=True, unsafe_hash=True, repr=False)
@final
class GPRRange(RegLoc, Sequence["GPRRange"]):
    """A contiguous run of `length` GPRs starting at register `start`.

    Acts as a Sequence of length-1 GPRRanges, one per register. `in`,
    `index` and `count` operate on whole sub-ranges.
    """
    __slots__ = "start", "length"

    def __init__(self, start, length=None):
        # type: (int | range, int | None) -> None
        """Create a GPRRange.

        `start` is either the first register number (with `length`
        defaulting to 1) or a step-1 `range` of register numbers (in which
        case `length` must not be given).

        Raises TypeError if both a range and `length` are given. Raises
        ValueError for a range whose step isn't 1, or if the result would
        be empty or extend outside r0 .. r{GPR_COUNT - 1}.
        """
        if isinstance(start, range):
            if length is not None:
                raise TypeError("can't specify length when input is a range")
            if start.step != 1:
                raise ValueError("range must have a step of 1")
            length = len(start)
            start = start.start
        elif length is None:
            length = 1
        if length <= 0 or start < 0 or start + length > GPR_COUNT:
            raise ValueError("invalid GPRRange")
        self.start = start
        self.length = length

    @property
    def stop(self):
        """one past the last register number in this range"""
        return self.start + self.length

    @property
    def step(self):
        """always 1, mirroring `range`"""
        return 1

    @property
    def range(self):
        """the register numbers of this range as a `range`"""
        return range(self.start, self.stop, self.step)

    def __len__(self):
        return self.length

    def __getitem__(self, item):
        # type: (int | slice) -> GPRRange
        """Index or slice like `self.range`.

        An out-of-range int index raises IndexError; a slice must produce a
        non-empty step-1 range or GPRRange raises ValueError.
        """
        return GPRRange(self.range[item])

    def __contains__(self, value):
        # type: (object) -> bool
        """True if `value` is a GPRRange lying entirely within `self`."""
        # non-GPRRange values are simply not contained (instead of raising
        # AttributeError), as expected of Sequence.__contains__
        if not isinstance(value, GPRRange):
            return False
        return value.start >= self.start and value.stop <= self.stop

    def index(self, sub, start=None, end=None):
        # type: (GPRRange, int | None, int | None) -> int
        """Return the offset of `sub` within `self`, requiring `sub` to lie
        within `self[start:end]`; raises ValueError otherwise."""
        r = self.range[start:end]
        if sub.start < r.start or sub.stop > r.stop:
            raise ValueError("GPR range not found")
        return sub.start - self.start

    def count(self, sub, start=None, end=None):
        # type: (GPRRange, int | None, int | None) -> int
        """1 if `sub` lies within `self[start:end]`, else 0"""
        r = self.range[start:end]
        if len(r) == 0:
            return 0
        return int(sub in GPRRange(r))

    def conflicts(self, other):
        # type: (RegLoc) -> bool
        """True if `other` is a GPRRange sharing at least one register"""
        if isinstance(other, GPRRange):
            return self.stop > other.start and other.stop > self.start
        return False

    def get_subreg_at_offset(self, subreg_type, offset):
        # type: (RegType, int) -> GPRRange
        """Return the GPRRange of `subreg_type.length` registers starting
        `offset` registers into `self`.

        Raises ValueError if `subreg_type` isn't a GPR range type or the
        sub-register doesn't lie entirely within `self`.
        """
        if not isinstance(subreg_type, (GPRRangeType, FixedGPRRangeType)):
            raise ValueError(f"subreg_type is not a FixedGPRRangeType or "
                             f"GPRRangeType: {subreg_type}")
        # `offset` is relative to self.start, so bound it by self.length
        # (comparing with self.stop, an absolute register number, let
        # sub-registers extend past the end of `self`)
        if offset < 0 or offset + subreg_type.length > self.length:
            raise ValueError(f"sub-register offset is out of range: {offset}")
        return GPRRange(self.start + offset, subreg_type.length)

    def __repr__(self):
        if self.length == 1:
            return f"<r{self.start}>"
        return f"<r{self.start}..len={self.length}>"
121
122
# GPRs never handed out by the register allocator: GPRRangeType's reg class
# skips every range overlapping one of these. r1 is the base register used
# for stack-slot loads/stores. NOTE(review): r0, r2 and r13 are presumably
# reserved for ABI reasons (e.g. r2 TOC, r13 thread pointer) -- confirm.
SPECIAL_GPRS = GPRRange(0), GPRRange(1), GPRRange(2), GPRRange(13)
124
125
@final
@unique
class XERBit(RegLoc, Enum, metaclass=ABCEnumMeta):
    """Bits of XER that are allocated as registers (currently just CA)."""
    CA = "CA"

    def conflicts(self, other):
        # type: (RegLoc) -> bool
        # an XER bit only ever conflicts with itself
        return isinstance(other, XERBit) and self == other
136
137
@final
@unique
class GlobalMem(RegLoc, Enum, metaclass=ABCEnumMeta):
    """singleton representing all non-StackSlot memory -- treated as a single
    physical register for register allocation purposes.
    """
    GlobalMem = "GlobalMem"

    def conflicts(self, other):
        # type: (RegLoc) -> bool
        # the single GlobalMem member conflicts only with itself
        return isinstance(other, GlobalMem) and self == other
151
152
@final
@unique
class VL(RegLoc, Enum, metaclass=ABCEnumMeta):
    """The VL and MAXVL state, treated as one allocatable register."""
    VL_MAXVL = "VL_MAXVL"  # VL and MAXVL

    def conflicts(self, other):
        # type: (RegLoc) -> bool
        # only conflicts with itself
        return isinstance(other, VL) and self == other
164
165
@final
class RegClass(OFSet[RegLoc]):
    """ an ordered set of registers.
    earlier registers are preferred by the register allocator.
    """

    @lru_cache(maxsize=None, typed=True)
    def max_conflicts_with(self, other):
        # type: (RegClass | RegLoc) -> int
        """Return the largest number of registers in `self` that a single
        register from `other` can conflict with.

        `other` may be a single register or a whole RegClass. An empty
        RegClass has no register that could conflict, so the result is 0
        (instead of `max()` raising ValueError on an empty sequence).
        """
        if isinstance(other, RegClass):
            return max((self.max_conflicts_with(i) for i in other),
                       default=0)
        # bools sum as ints: the count of registers in self that conflict
        return sum(other.conflicts(i) for i in self)
182
183
@plain_data(frozen=True, unsafe_hash=True)
class RegType(metaclass=ABCMeta):
    """Abstract base for the type of an SSAVal; it determines which register
    locations the value may be assigned to."""
    __slots__ = ()

    @property
    @abstractmethod
    def reg_class(self):
        # type: () -> RegClass
        """the RegClass of register locations a value of this type may be
        assigned to; subclasses must override this"""
        return ...
193
194
# type variables used to parameterize SSAVal / OpCopy by register type and
# AsmContext.reg by register-location class
_RegType = TypeVar("_RegType", bound=RegType)
_RegLoc = TypeVar("_RegLoc", bound=RegLoc)
197
198
@plain_data(frozen=True, eq=False, repr=False)
@final
class GPRRangeType(RegType):
    """Type of a value held in `length` consecutive GPRs, where the register
    allocator may pick any such range not overlapping SPECIAL_GPRS."""
    __slots__ = "length",

    def __init__(self, length=1):
        # type: (int) -> None
        if length < 1 or length > GPR_COUNT:
            raise ValueError("invalid length")
        self.length = length

    @staticmethod
    @lru_cache(maxsize=None)
    def __get_reg_class(length):
        # type: (int) -> RegClass
        """Build (once per length) the RegClass of every GPRRange of
        `length` registers that doesn't overlap any of SPECIAL_GPRS, ordered
        by starting register."""
        regs = []
        # `+ 1` so the last range, ending exactly at GPR_COUNT (e.g. r127
        # for length=1), is included -- GPRRange accepts
        # start + length == GPR_COUNT
        for start in range(GPR_COUNT - length + 1):
            reg = GPRRange(start, length)
            if any(i in reg for i in SPECIAL_GPRS):
                continue
            regs.append(reg)
        return RegClass(regs)

    @property
    @final
    def reg_class(self):
        # type: () -> RegClass
        """every allocatable GPRRange of `self.length` registers"""
        return GPRRangeType.__get_reg_class(self.length)

    @final
    def __eq__(self, other):
        if isinstance(other, GPRRangeType):
            return self.length == other.length
        return False

    @final
    def __hash__(self):
        return hash(self.length)

    def __repr__(self):
        return f"<gpr_ty[{self.length}]>"
240
241
# GPRType is the very same class as GPRRangeType (not a subclass);
# GPRType() gets the default length=1. The alias is used in signatures
# (e.g. SSAGPR) to document that a single-register value is expected.
GPRType = GPRRangeType
"""a length=1 GPRRangeType"""
244
245
@plain_data(frozen=True, unsafe_hash=True, repr=False)
@final
class FixedGPRRangeType(RegType):
    """Type of a value that must live in exactly the GPRRange `reg`."""
    __slots__ = "reg",

    def __init__(self, reg):
        # type: (GPRRange) -> None
        self.reg = reg

    @property
    def length(self):
        # type: () -> int
        """number of GPRs in `reg`"""
        return self.reg.length

    @property
    def reg_class(self):
        # type: () -> RegClass
        """a RegClass whose only member is `reg`"""
        only_choice = [self.reg]
        return RegClass(only_choice)

    def __repr__(self):
        return f"<fixed({self.reg})>"
267
268
@plain_data(frozen=True, unsafe_hash=True)
@final
class CAType(RegType):
    """Type of a value held in the CA bit of XER."""
    __slots__ = ()

    @property
    def reg_class(self):
        # type: () -> RegClass
        """a RegClass containing only XERBit.CA"""
        carry_bit = XERBit.CA
        return RegClass([carry_bit])
278
279
@plain_data(frozen=True, unsafe_hash=True)
@final
class GlobalMemType(RegType):
    """Type of a value representing the state of all non-StackSlot memory."""
    __slots__ = ()

    @property
    def reg_class(self):
        # type: () -> RegClass
        """a RegClass containing only the GlobalMem singleton"""
        memory = GlobalMem.GlobalMem
        return RegClass([memory])
289
290
@plain_data(frozen=True, unsafe_hash=True)
@final
class KnownVLType(RegType):
    """Type of a VL/MAXVL value known to equal `length` (1 to 64
    inclusive)."""
    __slots__ = "length",

    def __init__(self, length):
        # type: (int) -> None
        if length <= 0 or length > 64:
            raise ValueError("invalid VL value")
        self.length = length

    @property
    def reg_class(self):
        # type: () -> RegClass
        """a RegClass containing only VL.VL_MAXVL"""
        vl_reg = VL.VL_MAXVL
        return RegClass([vl_reg])
306
307
def assert_vl_is(vl, expected_vl):
    # type: (SSAKnownVL | KnownVLType | int | None, int) -> None
    """Raise ValueError unless `vl` denotes a vector length of `expected_vl`.

    `vl` may be None (treated as a VL of 1), an SSAVal whose type is a
    KnownVLType, a KnownVLType, or a plain int.
    """
    if vl is None:
        actual = 1
    elif isinstance(vl, SSAVal):
        actual = vl.ty.length
    elif isinstance(vl, KnownVLType):
        actual = vl.length
    else:
        actual = vl
    if actual != expected_vl:
        raise ValueError(
            f"wrong VL: expected {expected_vl} got {actual}")
319
320
# size in bytes of one stack slot; StackSlot.start_byte = start_slot * this.
# Each slot holds one GPR (OpStoreToStackSlot allocates one slot per register)
STACK_SLOT_SIZE = 8
322
323
@plain_data(frozen=True, unsafe_hash=True)
@final
class StackSlot(RegLoc):
    """A run of `length_in_slots` consecutive stack slots starting at slot
    index `start_slot` (each slot is STACK_SLOT_SIZE bytes)."""
    __slots__ = "start_slot", "length_in_slots",

    def __init__(self, start_slot, length_in_slots):
        # type: (int, int) -> None
        self.start_slot = start_slot
        if length_in_slots < 1:
            raise ValueError("invalid length_in_slots")
        self.length_in_slots = length_in_slots

    @property
    def stop_slot(self):
        """one past the last slot index"""
        return self.start_slot + self.length_in_slots

    @property
    def start_byte(self):
        """byte offset of the first slot (start_slot * STACK_SLOT_SIZE)"""
        return self.start_slot * STACK_SLOT_SIZE

    def conflicts(self, other):
        # type: (RegLoc) -> bool
        """True if `other` is a StackSlot sharing at least one slot"""
        if isinstance(other, StackSlot):
            return (self.stop_slot > other.start_slot
                    and other.stop_slot > self.start_slot)
        return False

    def get_subreg_at_offset(self, subreg_type, offset):
        # type: (RegType, int) -> StackSlot
        """Return the StackSlot of `subreg_type.length_in_slots` slots
        starting `offset` slots into `self`.

        Raises ValueError if `subreg_type` isn't a StackSlotType or the
        sub-slot doesn't lie entirely within `self`.
        """
        if not isinstance(subreg_type, StackSlotType):
            raise ValueError(f"subreg_type is not a "
                             f"StackSlotType: {subreg_type}")
        # `offset` is relative to start_slot, so bound it by the length
        # (comparing with stop_slot, an absolute index, let the sub-slot
        # run past the end of `self`)
        if (offset < 0
                or offset + subreg_type.length_in_slots
                > self.length_in_slots):
            raise ValueError(f"sub-register offset is out of range: {offset}")
        return StackSlot(self.start_slot + offset, subreg_type.length_in_slots)
359
360
# number of stack slots that StackSlotType's reg class draws from
STACK_SLOT_COUNT = 128
362
363
@plain_data(frozen=True, eq=False)
@final
class StackSlotType(RegType):
    """Type of a value held in `length_in_slots` consecutive stack slots."""
    __slots__ = "length_in_slots",

    def __init__(self, length_in_slots=1):
        # type: (int) -> None
        if length_in_slots < 1:
            raise ValueError("invalid length_in_slots")
        self.length_in_slots = length_in_slots

    @staticmethod
    @lru_cache(maxsize=None)
    def __get_reg_class(length_in_slots):
        # type: (int) -> RegClass
        """Build (once per length) the RegClass of every StackSlot of
        `length_in_slots` slots within STACK_SLOT_COUNT slots."""
        regs = []
        # `+ 1` so the last run, ending exactly at STACK_SLOT_COUNT, is
        # included (matching GPRRangeType's reg class construction)
        for start in range(STACK_SLOT_COUNT - length_in_slots + 1):
            reg = StackSlot(start, length_in_slots)
            regs.append(reg)
        return RegClass(regs)

    @property
    def reg_class(self):
        # type: () -> RegClass
        """every StackSlot of `self.length_in_slots` slots"""
        return StackSlotType.__get_reg_class(self.length_in_slots)

    @final
    def __eq__(self, other):
        if isinstance(other, StackSlotType):
            return self.length_in_slots == other.length_in_slots
        return False

    @final
    def __hash__(self):
        return hash(self.length_in_slots)
399
400
@plain_data(frozen=True, eq=False, repr=False)
@final
class SSAVal(Generic[_RegType]):
    """A value in SSA form: the output named `arg_name` of the Op `op`, with
    register type `ty`.

    Equality and hashing use the identity of `op` together with `arg_name`;
    `ty` is not compared.
    """
    __slots__ = "op", "arg_name", "ty",

    def __init__(self, op, arg_name, ty):
        # type: (Op, str, _RegType) -> None
        # the Op that writes this SSAVal
        self.op = op
        # the name of the argument of self.op that writes this SSAVal
        self.arg_name = arg_name
        self.ty = ty

    def __eq__(self, rhs):
        if not isinstance(rhs, SSAVal):
            return False
        return self.op is rhs.op and self.arg_name == rhs.arg_name

    def __hash__(self):
        return hash((id(self.op), self.arg_name))

    def __repr__(self):
        return f"<#{self.op.id}.{self.arg_name}: {self.ty}>"
427
428
# convenience aliases for SSAVals of commonly used register types
SSAGPRRange = SSAVal[GPRRangeType]
# GPRType is the same class as GPRRangeType; this alias only documents that
# a single-register value is expected
SSAGPR = SSAVal[GPRType]
SSAKnownVL = SSAVal[KnownVLType]
432
433
@final
@plain_data(unsafe_hash=True, frozen=True)
class EqualityConstraint:
    """Requires the SSAVals in `lhs`, placed back to back, to occupy the
    same registers as the SSAVals in `rhs` placed back to back (presumably
    how OpConcat/OpSplit avoid emitting any instructions).

    NOTE(review): `lhs`/`rhs` are stored as the lists passed in, so despite
    unsafe_hash=True, hashing an instance presumably raises TypeError (lists
    are unhashable) -- confirm whether instances are ever hashed.
    """
    __slots__ = "lhs", "rhs"

    def __init__(self, lhs, rhs):
        # type: (list[SSAVal], list[SSAVal]) -> None
        if not len(lhs) or not len(rhs):
            raise ValueError("can't constrain an empty list to be equal")
        self.lhs = lhs
        self.rhs = rhs
445
446
@final
class Fn:
    """A function being compiled: its Ops in the order they were created.

    Each Op appends itself to `ops` when it is constructed.
    """
    __slots__ = "ops",

    def __init__(self):
        # type: () -> None
        self.ops = []  # type: list[Op]

    def __repr__(self, short=False):
        if not short:
            op_ids = ", ".join([op.__repr__(just_id=True) for op in self.ops])
            return f"<Fn([{op_ids}])>"
        return "<Fn>"

    def pre_ra_sim(self, state):
        # type: (PreRASimState) -> None
        """Simulate every op, in order, against `state` before register
        allocation."""
        for op in self.ops:
            op.pre_ra_sim(state)
465
466
class _NotSet:
    """Placeholder whose repr marks a field that hasn't been set yet
    (used by Op.__repr__)."""

    def __repr__(self):
        return "<not set>"
472
473
# shared instance passed as the getattr default in Op.__repr__ so unset
# fields show as "<not set>"
_NOT_SET = _NotSet()
475
476
@final
class AsmContext:
    """Context for Op.get_asm_lines: looks up the register allocated to
    each SSAVal and formats it as an assembly operand."""

    def __init__(self, assigned_registers):
        # type: (dict[SSAVal, RegLoc]) -> None
        self.__assigned_registers = assigned_registers

    def reg(self, ssa_val, expected_ty):
        # type: (SSAVal[Any], Type[_RegLoc]) -> _RegLoc
        """Return the register assigned to `ssa_val`.

        Raises ValueError if `ssa_val` has no assigned register. Raises
        TypeError if the register isn't an instance of `expected_ty`, or is
        a GPRRange whose length differs from `ssa_val.ty.length`.
        """
        try:
            reg = self.__assigned_registers[ssa_val]
        except KeyError as e:
            # chain the lookup failure so it stays visible in tracebacks
            raise ValueError(
                f"SSAVal not assigned a register: {ssa_val}") from e
        wrong_len = (isinstance(reg, GPRRange)
                     and reg.length != ssa_val.ty.length)
        if not isinstance(reg, expected_ty) or wrong_len:
            raise TypeError(
                f"SSAVal is assigned a register of the wrong type: "
                f"ssa_val={ssa_val} expected_ty={expected_ty} reg={reg}")
        return reg

    def gpr_range(self, ssa_val):
        # type: (SSAGPRRange | SSAVal[FixedGPRRangeType]) -> GPRRange
        """the GPRRange assigned to `ssa_val`"""
        return self.reg(ssa_val, GPRRange)

    def stack_slot(self, ssa_val):
        # type: (SSAVal[StackSlotType]) -> StackSlot
        """the StackSlot assigned to `ssa_val`"""
        return self.reg(ssa_val, StackSlot)

    def gpr(self, ssa_val, vec, offset=0):
        # type: (SSAGPRRange | SSAVal[FixedGPRRangeType], bool, int) -> str
        """Format the first register of `ssa_val` (plus `offset`) as an
        operand, prefixed with "*" when `vec` is true (vector operand)."""
        reg = self.gpr_range(ssa_val).start + offset
        return "*" * vec + str(reg)

    def vgpr(self, ssa_val, offset=0):
        # type: (SSAGPRRange | SSAVal[FixedGPRRangeType], int) -> str
        """`gpr` formatted as a vector operand"""
        return self.gpr(ssa_val=ssa_val, vec=True, offset=offset)

    def sgpr(self, ssa_val, offset=0):
        # type: (SSAGPR | SSAVal[FixedGPRRangeType], int) -> str
        """`gpr` formatted as a scalar operand"""
        return self.gpr(ssa_val=ssa_val, vec=False, offset=offset)

    def needs_sv(self, *regs):
        # type: (*SSAGPRRange | SSAVal[FixedGPRRangeType]) -> bool
        """True if any of `regs` is assigned more than one GPR or a GPR
        numbered 32 or above, so an `sv.`-prefixed instruction is needed
        (presumably plain encodings only reach scalar r0..r31)."""
        for reg in regs:
            reg = self.gpr_range(reg)
            if reg.length != 1 or reg.start >= 32:
                return True
        return False
525
526
# width of a single GPR
GPR_SIZE_IN_BYTES = 8
GPR_SIZE_IN_BITS = 8 * GPR_SIZE_IN_BYTES
# all-ones mask covering one GPR; used to truncate simulated values to
# register width
GPR_VALUE_MASK = 2 ** GPR_SIZE_IN_BITS - 1
530
531
@plain_data(frozen=True)
@final
class PreRASimState:
    """Machine state used by Op.pre_ra_sim, keyed by SSAVal rather than by
    physical register.

    Values are tuples of words for GPR ranges, fixed GPR ranges and stack
    slots; ints for VLs; bools for CA; and an FMap[int, int] for global
    memory (presumably address -> value; confirm with callers).
    """
    __slots__ = ("gprs", "VLs", "CAs",
                 "global_mems", "stack_slots",
                 "fixed_gprs")

    def __init__(
        self,
        gprs,  # type: dict[SSAGPRRange, tuple[int, ...]]
        VLs,  # type: dict[SSAKnownVL, int]
        CAs,  # type: dict[SSAVal[CAType], bool]
        global_mems,  # type: dict[SSAVal[GlobalMemType], FMap[int, int]]
        stack_slots,  # type: dict[SSAVal[StackSlotType], tuple[int, ...]]
        fixed_gprs,  # type: dict[SSAVal[FixedGPRRangeType], tuple[int, ...]]
    ):
        # type: (...) -> None
        # register-like state
        self.gprs = gprs
        self.VLs = VLs
        self.CAs = CAs
        # memory-like state
        self.global_mems = global_mems
        self.stack_slots = stack_slots
        # values pinned to specific GPRs
        self.fixed_gprs = fixed_gprs
555
556
@plain_data(unsafe_hash=True, frozen=True, repr=False)
class Op(metaclass=ABCMeta):
    """Abstract base class for an operation in a Fn.

    Constructing an Op assigns it the next `id` in `fn` and appends it to
    `fn.ops`.
    """
    __slots__ = "id", "fn"

    @abstractmethod
    def inputs(self):
        # type: () -> dict[str, SSAVal]
        """map from argument name to each SSAVal this Op reads"""
        ...

    @abstractmethod
    def outputs(self):
        # type: () -> dict[str, SSAVal]
        """map from argument name to each SSAVal this Op writes"""
        ...

    def get_equality_constraints(self):
        # type: () -> Iterable[EqualityConstraint]
        """register-placement constraints for this Op; none by default"""
        yield from ()

    def get_extra_interferences(self):
        # type: () -> Iterable[tuple[SSAVal, SSAVal]]
        """extra pairs of SSAVals that presumably must not be assigned
        conflicting registers; none by default"""
        yield from ()

    def __init__(self, fn):
        # type: (Fn) -> None
        self.id = len(fn.ops)
        fn.ops.append(self)
        self.fn = fn

    @final
    def __repr__(self, just_id=False):
        parts = [f"#{self.id}"]
        try:
            outputs = self.outputs()
        except AttributeError:
            # outputs() can't be computed while fields are still unset
            outputs = None
        if not just_id:
            for name in fields(self):
                if name in ("id", "fn"):
                    continue
                value = getattr(self, name, _NOT_SET)
                # outputs are shown bare (their repr already names them)
                is_output = (outputs is not None and name in outputs
                             and outputs[name] is value)
                if is_output:
                    parts.append(repr(value))
                else:
                    parts.append(f"{name}={value!r}")
        return f"{self.__class__.__name__}({', '.join(parts)})"

    @abstractmethod
    def get_asm_lines(self, ctx):
        # type: (AsmContext) -> list[str]
        """get the lines of assembly for this Op"""
        ...

    @abstractmethod
    def pre_ra_sim(self, state):
        # type: (PreRASimState) -> None
        """simulate op before register allocation"""
        ...
619
620
@plain_data(unsafe_hash=True, frozen=True, repr=False)
@final
class OpLoadFromStackSlot(Op):
    """Load the stack-slot value `src` into a new GPR range `dest` with one
    register per slot."""
    __slots__ = "dest", "src", "vl"

    def inputs(self):
        # type: () -> dict[str, SSAVal]
        ins = {"src": self.src}  # type: dict[str, SSAVal[Any]]
        if self.vl is not None:
            ins["vl"] = self.vl
        return ins

    def outputs(self):
        # type: () -> dict[str, SSAVal]
        return {"dest": self.dest}

    def __init__(self, fn, src, vl=None):
        # type: (Fn, SSAVal[StackSlotType], SSAKnownVL | None) -> None
        super().__init__(fn)
        dest_ty = GPRRangeType(src.ty.length_in_slots)
        self.dest = SSAVal(self, "dest", dest_ty)
        self.src = src
        self.vl = vl
        assert_vl_is(vl, self.dest.ty.length)

    def get_asm_lines(self, ctx):
        # type: (AsmContext) -> list[str]
        dest = ctx.gpr(self.dest, vec=self.dest.ty.length != 1)
        src = ctx.stack_slot(self.src)
        # stack slots are addressed relative to r1
        prefix = "sv." if ctx.needs_sv(self.dest) else ""
        return [f"{prefix}ld {dest}, {src.start_byte}(1)"]

    def pre_ra_sim(self, state):
        # type: (PreRASimState) -> None
        """simulate op before register allocation"""
        loaded = state.stack_slots[self.src]
        state.gprs[self.dest] = loaded
657
658
@plain_data(unsafe_hash=True, frozen=True, repr=False)
@final
class OpStoreToStackSlot(Op):
    """Store the GPR range `src` into a new stack-slot value `dest` with one
    slot per register."""
    __slots__ = "dest", "src", "vl"

    def inputs(self):
        # type: () -> dict[str, SSAVal]
        ins = {"src": self.src}  # type: dict[str, SSAVal[Any]]
        if self.vl is not None:
            ins["vl"] = self.vl
        return ins

    def outputs(self):
        # type: () -> dict[str, SSAVal]
        return {"dest": self.dest}

    def __init__(self, fn, src, vl=None):
        # type: (Fn, SSAGPRRange, SSAKnownVL | None) -> None
        super().__init__(fn)
        dest_ty = StackSlotType(src.ty.length)
        self.dest = SSAVal(self, "dest", dest_ty)
        self.src = src
        self.vl = vl
        assert_vl_is(vl, src.ty.length)

    def get_asm_lines(self, ctx):
        # type: (AsmContext) -> list[str]
        src = ctx.gpr(self.src, vec=self.src.ty.length != 1)
        dest = ctx.stack_slot(self.dest)
        # stack slots are addressed relative to r1
        prefix = "sv." if ctx.needs_sv(self.src) else ""
        return [f"{prefix}std {src}, {dest.start_byte}(1)"]

    def pre_ra_sim(self, state):
        # type: (PreRASimState) -> None
        """simulate op before register allocation"""
        stored = state.gprs[self.src]
        state.stack_slots[self.dest] = stored
695
696
# type variable for OpCopy's source register type (its destination type is
# parameterized by _RegType)
_RegSrcType = TypeVar("_RegSrcType", bound=RegType)
698
699
@plain_data(unsafe_hash=True, frozen=True, repr=False)
@final
class OpCopy(Op, Generic[_RegSrcType, _RegType]):
    """Copy `src` into a new SSAVal `dest` of type `dest_ty` (defaults to
    `src.ty`).

    Copies between a GPRRangeType and a FixedGPRRangeType of the same length
    are allowed; otherwise the types must be equal. Stack slots can't be
    copied. Raises ValueError for incompatible types.
    """
    __slots__ = "dest", "src", "vl"

    def inputs(self):
        # type: () -> dict[str, SSAVal]
        retval = {"src": self.src}  # type: dict[str, SSAVal[Any]]
        if self.vl is not None:
            retval["vl"] = self.vl
        return retval

    def outputs(self):
        # type: () -> dict[str, SSAVal]
        return {"dest": self.dest}

    def __init__(self, fn, src, dest_ty=None, vl=None):
        # type: (Fn, SSAVal[_RegSrcType], _RegType | None, SSAKnownVL | None) -> None
        super().__init__(fn)
        if dest_ty is None:
            dest_ty = cast(_RegType, src.ty)
        if isinstance(src.ty, GPRRangeType) \
                and isinstance(dest_ty, FixedGPRRangeType):
            if src.ty.length != dest_ty.reg.length:
                raise ValueError(f"incompatible source and destination "
                                 f"types: {src.ty} and {dest_ty}")
            length = src.ty.length
        elif isinstance(src.ty, FixedGPRRangeType) \
                and isinstance(dest_ty, GPRRangeType):
            if src.ty.reg.length != dest_ty.length:
                raise ValueError(f"incompatible source and destination "
                                 f"types: {src.ty} and {dest_ty}")
            length = src.ty.length
        elif src.ty != dest_ty:
            raise ValueError(f"incompatible source and destination "
                             f"types: {src.ty} and {dest_ty}")
        elif isinstance(src.ty, StackSlotType):
            raise ValueError("can't use OpCopy on stack slots")
        elif isinstance(src.ty, (GPRRangeType, FixedGPRRangeType)):
            length = src.ty.length
        else:
            # CA, VL and GlobalMem values each occupy a single register
            length = 1

        self.dest = SSAVal(self, "dest", dest_ty)  # type: SSAVal[_RegType]
        self.src = src
        self.vl = vl
        assert_vl_is(vl, length)

    def get_asm_lines(self, ctx):
        # type: (AsmContext) -> list[str]
        if ctx.reg(self.src, RegLoc) == ctx.reg(self.dest, RegLoc):
            # src and dest were allocated the same register: nothing to do
            return []
        if (isinstance(self.src.ty, (GPRRangeType, FixedGPRRangeType)) and
                isinstance(self.dest.ty, (GPRRangeType, FixedGPRRangeType))):
            vec = self.dest.ty.length != 1
            dest = ctx.gpr_range(self.dest)  # type: ignore
            src = ctx.gpr_range(self.src)  # type: ignore
            dest_s = ctx.gpr(self.dest, vec=vec)  # type: ignore
            src_s = ctx.gpr(self.src, vec=vec)  # type: ignore
            mrr = ""
            # Copying element 0 upwards is only safe if no source element is
            # overwritten before it's read. For overlapping ranges that
            # fails when dest starts above src (data moves to higher
            # registers), so the element order must be reversed with /mrr.
            if src.conflicts(dest) and dest.start > src.start:
                mrr = "/mrr"
            if ctx.needs_sv(self.src, self.dest):  # type: ignore
                return [f"sv.or{mrr} {dest_s}, {src_s}, {src_s}"]
            return [f"or {dest_s}, {src_s}, {src_s}"]
        raise NotImplementedError

    def pre_ra_sim(self, state):
        # type: (PreRASimState) -> None
        gpr_types = (GPRRangeType, FixedGPRRangeType)
        if (isinstance(self.src.ty, gpr_types) and
                isinstance(self.dest.ty, gpr_types)):
            # covers every combination of plain and fixed GPR ranges
            if isinstance(self.src.ty, GPRRangeType):
                v = state.gprs[self.src]  # type: ignore
            else:
                v = state.fixed_gprs[self.src]  # type: ignore
            if isinstance(self.dest.ty, GPRRangeType):
                state.gprs[self.dest] = v  # type: ignore
            else:
                state.fixed_gprs[self.dest] = v  # type: ignore
        elif (isinstance(self.src.ty, CAType) and
                self.src.ty == self.dest.ty):
            state.CAs[self.dest] = state.CAs[self.src]  # type: ignore
        elif (isinstance(self.src.ty, KnownVLType) and
                self.src.ty == self.dest.ty):
            state.VLs[self.dest] = state.VLs[self.src]  # type: ignore
        elif (isinstance(self.src.ty, GlobalMemType) and
                self.src.ty == self.dest.ty):
            v = state.global_mems[self.src]  # type: ignore
            state.global_mems[self.dest] = v  # type: ignore
        else:
            raise NotImplementedError
797
798
@plain_data(unsafe_hash=True, frozen=True, repr=False)
@final
class OpConcat(Op):
    """Concatenate the GPR ranges `sources` into one GPR range `dest`;
    sources[0] provides the lowest-index words.

    Emits no instructions. An EqualityConstraint makes `dest` occupy the
    registers the sources are already in.
    """
    __slots__ = "dest", "sources"

    def inputs(self):
        # type: () -> dict[str, SSAVal]
        return {f"sources[{idx}]": src for idx, src in enumerate(self.sources)}

    def outputs(self):
        # type: () -> dict[str, SSAVal]
        return {"dest": self.dest}

    def __init__(self, fn, sources):
        # type: (Fn, Iterable[SSAGPRRange]) -> None
        super().__init__(fn)
        sources = tuple(sources)
        total_length = sum(src.ty.length for src in sources)
        self.dest = SSAVal(self, "dest", GPRRangeType(total_length))
        self.sources = sources

    def get_equality_constraints(self):
        # type: () -> Iterable[EqualityConstraint]
        yield EqualityConstraint([self.dest], list(self.sources))

    def get_asm_lines(self, ctx):
        # type: (AsmContext) -> list[str]
        # handled entirely by the equality constraint
        return []

    def pre_ra_sim(self, state):
        # type: (PreRASimState) -> None
        state.gprs[self.dest] = tuple(
            word for src in self.sources for word in state.gprs[src])
834
835
@plain_data(unsafe_hash=True, frozen=True, repr=False)
@final
class OpSplit(Op):
    """Split the GPR range `src` into consecutive pieces `results`, cutting
    just before each index in `split_indexes`.

    The indexes must be strictly increasing and satisfy
    0 < i < src.ty.length; otherwise ValueError is raised. Emits no
    instructions: an EqualityConstraint makes the pieces occupy exactly the
    registers of `src`.
    """
    __slots__ = "results", "src"

    def inputs(self):
        # type: () -> dict[str, SSAVal]
        return {"src": self.src}

    def outputs(self):
        # type: () -> dict[str, SSAVal]
        return {i.arg_name: i for i in self.results}

    def __init__(self, fn, src, split_indexes):
        # type: (Fn, SSAGPRRange, Iterable[int]) -> None
        # Validate and compute the piece types *before* registering with
        # `fn`, so a bad split doesn't leave a half-built Op in fn.ops.
        ranges = []  # type: list[GPRRangeType]
        last = 0
        for i in split_indexes:
            if not (0 < i < src.ty.length):
                raise ValueError(f"invalid split index: {i}, must be in "
                                 f"0 < i < {src.ty.length}")
            if i <= last:
                # would otherwise surface as a confusing "invalid length"
                raise ValueError(f"invalid split index: {i}, split indexes "
                                 f"must be strictly increasing")
            ranges.append(GPRRangeType(i - last))
            last = i
        ranges.append(GPRRangeType(src.ty.length - last))
        super().__init__(fn)
        self.src = src
        self.results = tuple(
            SSAVal(self, f"results[{i}]", r) for i, r in enumerate(ranges))

    def get_equality_constraints(self):
        # type: () -> Iterable[EqualityConstraint]
        yield EqualityConstraint([*self.results], [self.src])

    def get_asm_lines(self, ctx):
        # type: (AsmContext) -> list[str]
        # handled entirely by the equality constraint
        return []

    def pre_ra_sim(self, state):
        # type: (PreRASimState) -> None
        # peel pieces off the end of src, last result first
        rest = state.gprs[self.src]
        for dest in reversed(self.results):
            state.gprs[dest] = rest[-dest.ty.length:]
            rest = rest[:-dest.ty.length]
879
880
@plain_data(unsafe_hash=True, frozen=True, repr=False)
@final
class OpBigIntAddSub(Op):
    """Multi-word add (`is_sub` false) or subtract (`is_sub` true) of the
    equal-typed GPR ranges `lhs` and `rhs`, with carry in `CA_in` and carry
    out `CA_out`.

    Subtraction computes lhs + ~rhs + CA_in, as `subfe` does.
    """
    __slots__ = "out", "lhs", "rhs", "CA_in", "CA_out", "is_sub", "vl"

    def inputs(self):
        # type: () -> dict[str, SSAVal]
        retval = {
            "lhs": self.lhs,
            "rhs": self.rhs,
            "CA_in": self.CA_in,
        }  # type: dict[str, SSAVal[Any]]
        if self.vl is not None:
            retval["vl"] = self.vl
        return retval

    def outputs(self):
        # type: () -> dict[str, SSAVal]
        return {"out": self.out, "CA_out": self.CA_out}

    def __init__(self, fn, lhs, rhs, CA_in, is_sub, vl=None):
        # type: (Fn, SSAGPRRange, SSAGPRRange, SSAVal[CAType], bool, SSAKnownVL | None) -> None
        super().__init__(fn)
        if lhs.ty != rhs.ty:
            raise TypeError(f"source types must match: "
                            f"{lhs} doesn't match {rhs}")
        self.out = SSAVal(self, "out", lhs.ty)
        self.lhs = lhs
        self.rhs = rhs
        self.CA_in = CA_in
        self.CA_out = SSAVal(self, "CA_out", CA_in.ty)
        self.is_sub = is_sub
        self.vl = vl
        assert_vl_is(vl, lhs.ty.length)

    def get_extra_interferences(self):
        # type: () -> Iterable[tuple[SSAVal, SSAVal]]
        yield self.out, self.lhs
        yield self.out, self.rhs

    def get_asm_lines(self, ctx):
        # type: (AsmContext) -> list[str]
        vec = self.out.ty.length != 1
        out = ctx.gpr(self.out, vec=vec)
        lhs = ctx.gpr(self.lhs, vec=vec)
        rhs = ctx.gpr(self.rhs, vec=vec)
        if self.is_sub:
            # subfe RT, RA, RB computes ~RA + RB + CA, so rhs goes in RA
            mnemonic, RA, RB = "subfe", rhs, lhs
        else:
            mnemonic, RA, RB = "adde", lhs, rhs
        prefix = "sv." if ctx.needs_sv(self.out, self.lhs, self.rhs) else ""
        return [f"{prefix}{mnemonic} {out}, {RA}, {RB}"]

    def pre_ra_sim(self, state):
        # type: (PreRASimState) -> None
        # words go from index 0 upwards, propagating the carry; for
        # subtraction each rhs word is bitwise-inverted first
        invert = GPR_VALUE_MASK if self.is_sub else 0
        carry = state.CAs[self.CA_in]
        words = []  # type: list[int]
        for l_word, r_word in zip(state.gprs[self.lhs], state.gprs[self.rhs]):
            total = l_word + (r_word ^ invert) + carry
            carry = total != (total & GPR_VALUE_MASK)
            words.append(total & GPR_VALUE_MASK)
        state.CAs[self.CA_out] = carry
        state.gprs[self.out] = tuple(words)
946
947
@plain_data(unsafe_hash=True, frozen=True, repr=False)
@final
class OpBigIntMulDiv(Op):
    """Multi-word multiply-add or divide of the GPR range `RA` by the scalar
    GPR `RB`.

    If `is_div` is false (`sv.maddedu`), computes RT = RA * RB + RC word by
    word from index 0 upwards; the final carry word goes to RS.

    If `is_div` is true (`sv.divmod2du/mrr`), divides the value RA, with RC
    prepended as the most-significant word, by RB. It works from the
    highest-index word down, leaving the quotient in RT and the remainder
    in RS. In the simulation, if a partial remainder isn't less than RB
    (including RB == 0), that RT word becomes all ones and the remainder
    resets to 0.

    RS is constrained to the same register as RC. `vl` defaults to None
    (VL of 1), like the other Ops.
    """
    __slots__ = "RT", "RA", "RB", "RC", "RS", "is_div", "vl"

    def inputs(self):
        # type: () -> dict[str, SSAVal]
        retval = {}  # type: dict[str, SSAVal[Any]]
        retval["RA"] = self.RA
        retval["RB"] = self.RB
        retval["RC"] = self.RC
        if self.vl is not None:
            retval["vl"] = self.vl
        return retval

    def outputs(self):
        # type: () -> dict[str, SSAVal]
        return {"RT": self.RT, "RS": self.RS}

    def __init__(self, fn, RA, RB, RC, is_div, vl=None):
        # type: (Fn, SSAGPRRange, SSAGPR, SSAGPR, bool, SSAKnownVL | None) -> None
        super().__init__(fn)
        self.RT = SSAVal(self, "RT", RA.ty)
        self.RA = RA
        self.RB = RB
        self.RC = RC
        self.RS = SSAVal(self, "RS", RC.ty)
        self.is_div = is_div
        self.vl = vl
        assert_vl_is(vl, RA.ty.length)

    def get_equality_constraints(self):
        # type: () -> Iterable[EqualityConstraint]
        # RS is written to the same register RC is read from
        yield EqualityConstraint([self.RC], [self.RS])

    def get_extra_interferences(self):
        # type: () -> Iterable[tuple[SSAVal, SSAVal]]
        yield self.RT, self.RA
        yield self.RT, self.RB
        yield self.RT, self.RC
        yield self.RT, self.RS
        yield self.RS, self.RA
        yield self.RS, self.RB

    def get_asm_lines(self, ctx):
        # type: (AsmContext) -> list[str]
        vec = self.RT.ty.length != 1
        RT = ctx.gpr(self.RT, vec=vec)
        RA = ctx.gpr(self.RA, vec=vec)
        RB = ctx.sgpr(self.RB)
        RC = ctx.sgpr(self.RC)
        mnemonic = "maddedu"
        if self.is_div:
            mnemonic = "divmod2du/mrr"
        return [f"sv.{mnemonic} {RT}, {RA}, {RB}, {RC}"]

    def pre_ra_sim(self, state):
        # type: (PreRASimState) -> None
        carry = state.gprs[self.RC][0]
        RA = state.gprs[self.RA]
        RB = state.gprs[self.RB][0]
        RT = [0] * self.RT.ty.length
        if self.is_div:
            # highest-index word first; `carry` holds the running remainder
            for i in reversed(range(self.RT.ty.length)):
                if carry < RB and RB != 0:
                    div, mod = divmod((carry << 64) | RA[i], RB)
                    RT[i] = div & GPR_VALUE_MASK
                    carry = mod & GPR_VALUE_MASK
                else:
                    RT[i] = GPR_VALUE_MASK
                    carry = 0
        else:
            # lowest-index word first; `carry` is the high half so far
            for i in range(self.RT.ty.length):
                v = RA[i] * RB + carry
                carry = v >> 64
                RT[i] = v & GPR_VALUE_MASK
        state.gprs[self.RS] = carry,
        state.gprs[self.RT] = tuple(RT)
1026
1027
@final
@unique
class ShiftKind(Enum):
    """Kind of shift: left (Sl), logical right (Sr) or arithmetic right
    (Sra)."""
    Sl = "sl"
    Sr = "sr"
    Sra = "sra"

    def make_big_int_carry_in(self, fn, inp):
        # type: (Fn, SSAGPRRange) -> tuple[SSAGPR, list[Op]]
        """Create the word shifted in by a big-int shift of `inp`.

        Returns `(carry_in, ops)`, where `ops` are the Ops created. Sl and
        Sr shift in zero (an OpLI of 0). Sra shifts in copies of the sign
        bit, computed by arithmetic-shifting `inp`'s highest-index word
        right by 63.
        """
        if self is ShiftKind.Sl or self is ShiftKind.Sr:
            li = OpLI(fn, 0)
            return li.out, [li]
        else:
            assert self is ShiftKind.Sra
            if inp.ty.length == 1:
                # a single word is already its own top word; splitting it
                # would need split index 0, which OpSplit rejects
                shr = OpShiftImm(fn, inp, sh=63, kind=ShiftKind.Sra)
                return shr.out, [shr]
            split = OpSplit(fn, inp, [inp.ty.length - 1])
            shr = OpShiftImm(fn, split.results[1], sh=63, kind=ShiftKind.Sra)
            return shr.out, [split, shr]

    def make_big_int_shift(self, fn, inp, sh, vl):
        # type: (Fn, SSAGPRRange, SSAGPR, SSAKnownVL | None) -> tuple[SSAGPRRange, list[Op]]
        """Create the Ops for a big-int shift of `inp` by the scalar `sh`.

        Returns `(result, ops)`, where `ops` are all the Ops created,
        including the carry-in computation.
        """
        carry_in, ops = self.make_big_int_carry_in(fn, inp)
        big_int_shift = OpBigIntShift(fn, inp, sh, carry_in, kind=self, vl=vl)
        ops.append(big_int_shift)
        return big_int_shift.out, ops
1052
1053
@plain_data(unsafe_hash=True, frozen=True, repr=False)
@final
class OpBigIntShift(Op):
    """Shift the multi-word GPR range `inp` by the scalar GPR `sh` (taken
    mod 64 by the simulator), in the direction given by `kind`. The
    boundary word shifted in comes from `carry_in` (see
    ShiftKind.make_big_int_carry_in).

    `_out_padding` is a length-1 output used only so that both sides of
    the equality constraint have equal total length (1 + n == n + 1).
    pre_ra_sim never writes it.
    """
    __slots__ = "out", "inp", "carry_in", "_out_padding", "sh", "kind", "vl"

    def inputs(self):
        # type: () -> dict[str, SSAVal]
        retval = {}  # type: dict[str, SSAVal[Any]]
        retval["inp"] = self.inp
        retval["sh"] = self.sh
        retval["carry_in"] = self.carry_in
        if self.vl is not None:
            retval["vl"] = self.vl
        return retval

    def outputs(self):
        # type: () -> dict[str, SSAVal]
        return {"out": self.out, "_out_padding": self._out_padding}

    def __init__(self, fn, inp, sh, carry_in, kind, vl=None):
        # type: (Fn, SSAGPRRange, SSAGPR, SSAGPR, ShiftKind, SSAKnownVL | None) -> None
        super().__init__(fn)
        self.out = SSAVal(self, "out", inp.ty)
        self._out_padding = SSAVal(self, "_out_padding", GPRRangeType())
        self.carry_in = carry_in
        self.inp = inp
        self.sh = sh
        self.kind = kind
        self.vl = vl
        assert_vl_is(vl, inp.ty.length)

    def get_extra_interferences(self):
        # type: () -> Iterable[tuple[SSAVal, SSAVal]]
        yield self.out, self.sh

    def get_equality_constraints(self):
        # type: () -> Iterable[EqualityConstraint]
        if self.kind is ShiftKind.Sl:
            # [carry_in, inp...] share registers with [out..., _out_padding]:
            # carry_in sits in out[0]'s register
            yield EqualityConstraint([self.carry_in, self.inp],
                                     [self.out, self._out_padding])
        else:
            assert self.kind is ShiftKind.Sr or self.kind is ShiftKind.Sra
            # [inp..., carry_in] share registers with [_out_padding, out...]:
            # carry_in sits in out[-1]'s register
            yield EqualityConstraint([self.inp, self.carry_in],
                                     [self._out_padding, self.out])

    def get_asm_lines(self, ctx):
        # type: (AsmContext) -> list[str]
        # NOTE(review): RA is `out` offset by -1 (Sl) or +1 (Sr/Sra)
        # registers; confirm these offsets against the register layout fixed
        # by get_equality_constraints and the dsld/dsrd operand definitions.
        vec = self.out.ty.length != 1
        if self.kind is ShiftKind.Sl:
            RT = ctx.gpr(self.out, vec=vec)
            RA = ctx.gpr(self.out, vec=vec, offset=-1)
            RB = ctx.sgpr(self.sh)
            # each element reads the register below the one it writes, so
            # vector elements are processed in reverse order
            mrr = "/mrr" if vec else ""
            return [f"sv.dsld{mrr} {RT}, {RA}, {RB}, 0"]
        else:
            assert self.kind is ShiftKind.Sr or self.kind is ShiftKind.Sra
            RT = ctx.gpr(self.out, vec=vec)
            RA = ctx.gpr(self.out, vec=vec, offset=1)
            RB = ctx.sgpr(self.sh)
            return [f"sv.dsrd {RT}, {RA}, {RB}, 1"]

    def pre_ra_sim(self, state):
        # type: (PreRASimState) -> None
        out = [0] * self.out.ty.length
        carry = state.gprs[self.carry_in][0]
        sh = state.gprs[self.sh][0] % 64
        if self.kind is ShiftKind.Sl:
            # carry_in is the word below inp[0]; out[i] is the high half of
            # (inp[i]:inp[i-1]) << sh. `inp` is an immutable tuple, so the
            # loop order doesn't affect the result.
            inp = carry, *state.gprs[self.inp]
            for i in reversed(range(self.out.ty.length)):
                v = inp[i] | (inp[i + 1] << 64)
                v <<= sh
                out[i] = (v >> 64) & GPR_VALUE_MASK
        else:
            assert self.kind is ShiftKind.Sr or self.kind is ShiftKind.Sra
            # carry_in is the word above inp[-1] (zero for Sr, sign word
            # for Sra); out[i] is the low half of (inp[i+1]:inp[i]) >> sh
            inp = *state.gprs[self.inp], carry
            for i in range(self.out.ty.length):
                v = inp[i] | (inp[i + 1] << 64)
                v >>= sh
                out[i] = v & GPR_VALUE_MASK
        # state.gprs[self._out_padding] is intentionally not written
        state.gprs[self.out] = tuple(out)
1135
1136
@plain_data(unsafe_hash=True, frozen=True, repr=False)
@final
class OpShiftImm(Op):
    """Shift a single GPR by the immediate amount `sh` (0 <= sh < 64).

    - `ShiftKind.Sl` is a logical left shift (`rldicr`).
    - `ShiftKind.Sr` is a logical right shift (`rldicl`).
    - `ShiftKind.Sra` is an arithmetic right shift (`sradi`). It also
      produces `ca_out`: the CA flag, set when a negative input has 1 bits
      shifted out.

    Raises ValueError from the constructor if `sh` is out of range.
    """
    __slots__ = "out", "inp", "sh", "kind", "ca_out"

    def inputs(self):
        # type: () -> dict[str, SSAVal]
        return {"inp": self.inp}

    def outputs(self):
        # type: () -> dict[str, SSAVal]
        if self.ca_out is not None:
            return {"out": self.out, "ca_out": self.ca_out}
        return {"out": self.out}

    def __init__(self, fn, inp, sh, kind):
        # type: (Fn, SSAGPR, int, ShiftKind) -> None
        super().__init__(fn)
        self.out = SSAVal(self, "out", inp.ty)
        self.inp = inp
        if not (0 <= sh < 64):
            raise ValueError("shift amount out of range")
        self.sh = sh
        self.kind = kind
        if self.kind is ShiftKind.Sra:
            self.ca_out = SSAVal(self, "ca_out", CAType())
        else:
            self.ca_out = None

    def get_asm_lines(self, ctx):
        # type: (AsmContext) -> list[str]
        out = ctx.sgpr(self.out)
        inp = ctx.sgpr(self.inp)
        if self.kind is ShiftKind.Sl:
            # sldi n == rldicr ra, rs, n, 63 - n
            mnemonic = "rldicr"
            args = f"{self.sh}, {63 - self.sh}"
        elif self.kind is ShiftKind.Sr:
            # srdi n == rldicl ra, rs, 64 - n, n  (rotate of 64 wraps to 0)
            mnemonic = "rldicl"
            v = (64 - self.sh) % 64
            args = f"{v}, {self.sh}"
        else:
            assert self.kind is ShiftKind.Sra
            mnemonic = "sradi"
            args = f"{self.sh}"
        if ctx.needs_sv(self.out, self.inp):
            return [f"sv.{mnemonic} {out}, {inp}, {args}"]
        return [f"{mnemonic} {out}, {inp}, {args}"]

    def pre_ra_sim(self, state):
        # type: (PreRASimState) -> None
        inp = state.gprs[self.inp][0]
        if self.kind is ShiftKind.Sl:
            assert self.ca_out is None
            out = inp << self.sh
        elif self.kind is ShiftKind.Sr:
            assert self.ca_out is None
            out = inp >> self.sh
        else:
            assert self.kind is ShiftKind.Sra
            assert self.ca_out is not None
            if inp & (1 << 63):  # sign extend
                inp -= 1 << 64
            out = inp >> self.sh
            # CA is set iff the input is negative and the shift discarded
            # at least one 1 bit (i.e. the shift was not exact).
            ca = inp < 0 and (out << self.sh) != inp
            state.CAs[self.ca_out] = ca
        # Truncate to a 64-bit register value. A left shift can produce bits
        # above bit 63, and an arithmetic shift of a negative value yields a
        # negative Python int.
        state.gprs[self.out] = (out & GPR_VALUE_MASK,)
1203
1204
@plain_data(unsafe_hash=True, frozen=True, repr=False)
@final
class OpLI(Op):
    """Load a signed 16-bit immediate into every element of a GPR range.

    Without `vl` the destination is one register. With `vl` the destination
    has `vl.ty.length` registers, all receiving the same value.

    Raises ValueError if `value` does not fit in a signed 16-bit field.
    """
    __slots__ = "out", "value", "vl"

    def inputs(self):
        # type: () -> dict[str, SSAVal]
        if self.vl is None:
            return {}
        return {"vl": self.vl}

    def outputs(self):
        # type: () -> dict[str, SSAVal]
        return {"out": self.out}

    def __init__(self, fn, value, vl=None):
        # type: (Fn, int, SSAKnownVL | None) -> None
        super().__init__(fn)
        length = 1 if vl is None else vl.ty.length
        self.out = SSAVal(self, "out", GPRRangeType(length))
        # must fit the signed 16-bit immediate field of addi
        if not -0x8000 <= value <= 0x7FFF:
            raise ValueError(f"value out of range: {value}")
        self.value = value
        self.vl = vl
        assert_vl_is(vl, length)

    def get_asm_lines(self, ctx):
        # type: (AsmContext) -> list[str]
        dest = ctx.gpr(self.out, vec=self.out.ty.length != 1)
        prefix = "sv." if ctx.needs_sv(self.out) else ""
        # addi with RA=0 uses the literal 0, i.e. `li`
        return [f"{prefix}addi {dest}, 0, {self.value}"]

    def pre_ra_sim(self, state):
        # type: (PreRASimState) -> None
        # negative immediates become their 64-bit two's complement
        word = self.value & GPR_VALUE_MASK
        state.gprs[self.out] = tuple(word for _ in range(self.out.ty.length))
1247
1248
@plain_data(unsafe_hash=True, frozen=True, repr=False)
@final
class OpSetCA(Op):
    """Set the CA (carry) flag to the constant `value`."""
    __slots__ = ("out", "value")

    def inputs(self):
        # type: () -> dict[str, SSAVal]
        return {}

    def outputs(self):
        # type: () -> dict[str, SSAVal]
        return {"out": self.out}

    def __init__(self, fn, value):
        # type: (Fn, bool) -> None
        super().__init__(fn)
        self.out = SSAVal(self, "out", CAType())
        self.value = value

    def get_asm_lines(self, ctx):
        # type: (AsmContext) -> list[str]
        # Both forms use r0 as source and destination (r0 is clobbered).
        # Only the resulting CA matters:
        #   subfic r0, r0, -1 computes ~r0 + (-1) + 1, which always carries
        #   addic  r0, r0, 0  adds zero, which never carries
        return ["subfic 0, 0, -1"] if self.value else ["addic 0, 0, 0"]

    def pre_ra_sim(self, state):
        # type: (PreRASimState) -> None
        state.CAs[self.out] = self.value
1277
1278
@plain_data(unsafe_hash=True, frozen=True, repr=False)
@final
class OpLoad(Op):
    """Load one or more GPR-sized words from `mem` at address RA + offset.

    Without `vl` a single word is loaded. With `vl`, `vl.ty.length`
    consecutive words are loaded into the GPR range `RT`.

    Raises ValueError from the constructor if `offset` is not a signed
    16-bit multiple of 4. The simulator also raises ValueError for an
    unaligned effective address.
    """
    __slots__ = "RT", "RA", "offset", "mem", "vl"

    def inputs(self):
        # type: () -> dict[str, SSAVal]
        result = {"RA": self.RA, "mem": self.mem}  # type: dict[str, SSAVal[Any]]
        if self.vl is not None:
            result["vl"] = self.vl
        return result

    def outputs(self):
        # type: () -> dict[str, SSAVal]
        return {"RT": self.RT}

    def __init__(self, fn, RA, offset, mem, vl=None):
        # type: (Fn, SSAGPR, int, SSAVal[GlobalMemType], SSAKnownVL | None) -> None
        super().__init__(fn)
        length = 1 if vl is None else vl.ty.length
        self.RT = SSAVal(self, "RT", GPRRangeType(length))
        self.RA = RA
        # ld is DS-form: signed 16-bit displacement, multiple of 4
        if not -0x8000 <= offset <= 0x7FFF:
            raise ValueError(f"offset out of range: {offset}")
        if offset % 4:
            raise ValueError(f"offset not aligned: {offset}")
        self.offset = offset
        self.mem = mem
        self.vl = vl
        assert_vl_is(vl, length)

    def get_extra_interferences(self):
        # type: () -> Iterable[tuple[SSAVal, SSAVal]]
        # For a vector load the RT elements are written one at a time, so
        # the base address register must not overlap any of them.
        if self.RT.ty.length > 1:
            yield self.RT, self.RA

    def get_asm_lines(self, ctx):
        # type: (AsmContext) -> list[str]
        RT = ctx.gpr(self.RT, vec=self.RT.ty.length != 1)
        RA = ctx.sgpr(self.RA)
        prefix = "sv." if ctx.needs_sv(self.RT, self.RA) else ""
        return [f"{prefix}ld {RT}, {self.offset}({RA})"]

    def pre_ra_sim(self, state):
        # type: (PreRASimState) -> None
        base = state.gprs[self.RA][0] + self.offset
        memory = state.global_mems[self.mem]
        words = []
        for index in range(self.RT.ty.length):
            word_addr = (base + index * GPR_SIZE_IN_BYTES) & GPR_VALUE_MASK
            if word_addr % GPR_SIZE_IN_BYTES:
                raise ValueError(f"can't load from unaligned address: "
                                 f"{word_addr:#x}")
            # memory maps address -> byte; bytes never written read as 0
            raw = bytes(memory.get(word_addr + j, 0) & 0xFF
                        for j in range(GPR_SIZE_IN_BYTES))
            words.append(int.from_bytes(raw, "little"))
        state.gprs[self.RT] = tuple(words)
1343
1344
@plain_data(unsafe_hash=True, frozen=True, repr=False)
@final
class OpStore(Op):
    """Store the GPR range `RS` to memory at address RA + offset.

    Takes the memory state `mem_in` and produces a new memory state
    `mem_out`. Words are written in little-endian byte order to consecutive
    GPR-sized slots.

    Raises ValueError from the constructor if `offset` is not a signed
    16-bit multiple of 4. The simulator also raises ValueError for an
    unaligned effective address.
    """
    __slots__ = "RS", "RA", "offset", "mem_in", "mem_out", "vl"

    def inputs(self):
        # type: () -> dict[str, SSAVal]
        result = {
            "RS": self.RS,
            "RA": self.RA,
            "mem_in": self.mem_in,
        }  # type: dict[str, SSAVal[Any]]
        if self.vl is not None:
            result["vl"] = self.vl
        return result

    def outputs(self):
        # type: () -> dict[str, SSAVal]
        return {"mem_out": self.mem_out}

    def __init__(self, fn, RS, RA, offset, mem_in, vl=None):
        # type: (Fn, SSAGPRRange, SSAGPR, int, SSAVal[GlobalMemType], SSAKnownVL | None) -> None
        super().__init__(fn)
        self.RS = RS
        self.RA = RA
        # std is DS-form: signed 16-bit displacement, multiple of 4
        if not -0x8000 <= offset <= 0x7FFF:
            raise ValueError(f"offset out of range: {offset}")
        if offset % 4:
            raise ValueError(f"offset not aligned: {offset}")
        self.offset = offset
        self.mem_in = mem_in
        self.mem_out = SSAVal(self, "mem_out", mem_in.ty)
        self.vl = vl
        assert_vl_is(vl, RS.ty.length)

    def get_asm_lines(self, ctx):
        # type: (AsmContext) -> list[str]
        RS = ctx.gpr(self.RS, vec=self.RS.ty.length != 1)
        RA = ctx.sgpr(self.RA)
        prefix = "sv." if ctx.needs_sv(self.RS, self.RA) else ""
        return [f"{prefix}std {RS}, {self.offset}({RA})"]

    def pre_ra_sim(self, state):
        # type: (PreRASimState) -> None
        # copy so the input memory state is left untouched
        memory = dict(state.global_mems[self.mem_in])
        base = state.gprs[self.RA][0] + self.offset
        values = state.gprs[self.RS]
        for index in range(self.RS.ty.length):
            word_addr = (base + index * GPR_SIZE_IN_BYTES) & GPR_VALUE_MASK
            if word_addr % GPR_SIZE_IN_BYTES:
                raise ValueError(f"can't store to unaligned address: "
                                 f"{word_addr:#x}")
            word = values[index]
            for byte_index in range(GPR_SIZE_IN_BYTES):
                memory[word_addr + byte_index] = (
                    word >> (8 * byte_index)) & 0xFF
        state.global_mems[self.mem_out] = FMap(memory)
1401
1402
@plain_data(unsafe_hash=True, frozen=True, repr=False)
@final
class OpFuncArg(Op):
    """Introduce a function argument as an SSA value of the given type.

    The type is a `FixedGPRRangeType` (presumably a pre-assigned register
    range). This op emits no instructions.
    """
    __slots__ = ("out",)

    def inputs(self):
        # type: () -> dict[str, SSAVal]
        return {}

    def outputs(self):
        # type: () -> dict[str, SSAVal]
        return {"out": self.out}

    def __init__(self, fn, ty):
        # type: (Fn, FixedGPRRangeType) -> None
        super().__init__(fn)
        self.out = SSAVal(self, "out", ty)

    def get_asm_lines(self, ctx):
        # type: (AsmContext) -> list[str]
        return []

    def pre_ra_sim(self, state):
        # type: (PreRASimState) -> None
        if self.out in state.fixed_gprs:
            return  # keep a value that was supplied before simulation
        # otherwise the argument defaults to all-zero words
        state.fixed_gprs[self.out] = (0,) * self.out.ty.length
1429
1430
@plain_data(unsafe_hash=True, frozen=True, repr=False)
@final
class OpInputMem(Op):
    """Introduce the initial global memory state; emits no instructions."""
    __slots__ = ("out",)

    def inputs(self):
        # type: () -> dict[str, SSAVal]
        return {}

    def outputs(self):
        # type: () -> dict[str, SSAVal]
        return {"out": self.out}

    def __init__(self, fn):
        # type: (Fn) -> None
        super().__init__(fn)
        self.out = SSAVal(self, "out", GlobalMemType())

    def get_asm_lines(self, ctx):
        # type: (AsmContext) -> list[str]
        return []

    def pre_ra_sim(self, state):
        # type: (PreRASimState) -> None
        if self.out in state.global_mems:
            return  # keep memory contents supplied before simulation
        # otherwise memory starts out empty
        state.global_mems[self.out] = FMap()
1457
1458
@plain_data(unsafe_hash=True, frozen=True, repr=False)
@final
class OpSetVLImm(Op):
    """Set VL to the constant `length`.

    Produces a `KnownVLType` value of that length. Other ops in this file
    take such a value as their `vl` input.
    """
    __slots__ = ("out",)

    def inputs(self):
        # type: () -> dict[str, SSAVal]
        return {}

    def outputs(self):
        # type: () -> dict[str, SSAVal]
        return {"out": self.out}

    def __init__(self, fn, length):
        # type: (Fn, int) -> None
        super().__init__(fn)
        self.out = SSAVal(self, "out", KnownVLType(length))

    def get_asm_lines(self, ctx):
        # type: (AsmContext) -> list[str]
        vl_length = self.out.ty.length
        return [f"setvl 0, 0, {vl_length}, 0, 1, 1"]

    def pre_ra_sim(self, state):
        # type: (PreRASimState) -> None
        state.VLs[self.out] = self.out.ty.length
1484
1485
def op_set_to_list(ops):
    # type: (Iterable[Op]) -> list[Op]
    """Topologically sort `ops` so each op follows the ops writing its inputs.

    Ops are emitted in the order they become ready; ops with no inputs are
    ready in the order they appear in `ops`.

    Raises ValueError if two ops write the same SSA value. Also raises
    ValueError if some op can never run, because it is part of a dependency
    cycle or reads a value that no op in `ops` writes.
    """
    # worklists[n] holds the ops still waiting on n distinct input values
    worklists = [{}]  # type: list[dict[Op, None]]
    inps_to_ops_map = defaultdict(dict)  # type: dict[SSAVal, dict[Op, None]]
    ops_to_pending_input_count_map = {}  # type: dict[Op, int]
    for op in ops:
        # Count each distinct input value once. An op may read the same
        # SSAVal through several operands, but that value is written (and
        # decrements the pending count) only once.
        input_vals = dict.fromkeys(op.inputs().values())
        for val in input_vals:
            inps_to_ops_map[val][op] = None
        input_count = len(input_vals)
        while len(worklists) <= input_count:
            worklists.append({})
        ops_to_pending_input_count_map[op] = input_count
        worklists[input_count][op] = None
    retval = []  # type: list[Op]
    ready_vals = set()  # type: set[SSAVal]
    while worklists[0]:
        writing_op = next(iter(worklists[0]))
        del worklists[0][writing_op]
        retval.append(writing_op)
        for val in writing_op.outputs().values():
            if val in ready_vals:
                raise ValueError(f"multiple instructions must not write "
                                 f"to the same SSA value: {val}")
            ready_vals.add(val)
            # every op reading `val` now waits on one fewer input
            for reading_op in inps_to_ops_map[val]:
                pending = ops_to_pending_input_count_map[reading_op]
                del worklists[pending][reading_op]
                pending -= 1
                worklists[pending][reading_op] = None
                ops_to_pending_input_count_map[reading_op] = pending
    # anything left over still waits on an input that was never produced
    for worklist in worklists:
        for op in worklist:
            raise ValueError(f"instruction is part of a dependency loop or "
                             f"its inputs are never written: {op}")
    return retval
1522
1523
def generate_assembly(ops, assigned_registers=None):
    # type: (list[Op], dict[SSAVal, RegLoc] | None) -> list[str]
    """Render `ops`, in order, as assembly lines followed by a return.

    If `assigned_registers` is None, registers are first allocated with
    `register_allocator.allocate_registers(ops)`.
    """
    if assigned_registers is None:
        # imported lazily, presumably to avoid a circular import with
        # register_allocator
        from bigint_presentation_code.register_allocator import \
            allocate_registers
        assigned_registers = allocate_registers(ops)
    ctx = AsmContext(assigned_registers)
    retval = []  # type: list[str]
    for op in ops:
        retval.extend(op.get_asm_lines(ctx))
    # bclr 20, 0, 0 is the unconditional branch to the link register (blr)
    retval.append("bclr 20, 0, 0")
    return retval