517e5423e267ed031da12ef08c01af7c7090b39e
[bigint-presentation-code.git] / src / bigint_presentation_code / compiler_ir.py
1 """
2 Compiler IR for Toom-Cook algorithm generator for SVP64
3
4 This assumes VL != 0 throughout.
5 """
6
7 from abc import ABCMeta, abstractmethod
8 from collections import defaultdict
9 from enum import Enum, EnumMeta, unique
10 from functools import lru_cache
11 from typing import Any, Generic, Iterable, Sequence, Type, TypeVar, cast
12
13 from nmutil.plain_data import fields, plain_data
14
15 from bigint_presentation_code.util import OFSet, OSet, final
16
17
class ABCEnumMeta(EnumMeta, ABCMeta):
    """Metaclass for classes that are both an ``Enum`` and an ABC subclass.

    ``Enum`` classes are built by ``EnumMeta`` while ``RegLoc`` uses
    ``ABCMeta``; a class deriving from both needs a metaclass that derives
    from each of them, which is all this class provides.
    """
20
21
class RegLoc(metaclass=ABCMeta):
    """Abstract base for a physical register location.

    Subclasses must implement :meth:`conflicts`.  The default
    :meth:`get_subreg_at_offset` supports only the trivial sub-register: the
    location itself, at offset 0.
    """
    __slots__ = ()

    @abstractmethod
    def conflicts(self, other):
        # type: (RegLoc) -> bool
        """Return whether this location conflicts with ``other``."""
        ...

    def get_subreg_at_offset(self, subreg_type, offset):
        # type: (RegType, int) -> RegLoc
        """Return ``self`` if it is a valid ``subreg_type`` at ``offset``.

        Raises ValueError if ``self`` is not in ``subreg_type.reg_class`` or
        if ``offset`` is non-zero.
        """
        member = self in subreg_type.reg_class
        if not member:
            raise ValueError(f"register not a member of subreg_type: "
                             f"reg={self} subreg_type={subreg_type}")
        if offset == 0:
            return self
        raise ValueError(f"non-zero sub-register offset not supported "
                         f"for register: {self}")
39
40
# number of GPRs; GPRRange rejects ranges extending past it and
# GPRRangeType rejects lengths larger than it
GPR_COUNT = 128
42
43
@plain_data(frozen=True, unsafe_hash=True)
@final
class GPRRange(RegLoc, Sequence["GPRRange"]):
    """A contiguous run of GPRs from ``start`` up to (excluding) ``stop``.

    As a ``Sequence`` its elements are the single-GPR ``GPRRange``s it
    covers; ``in``/``index``/``count`` operate on contained sub-ranges.
    """
    __slots__ = "start", "length"

    def __init__(self, start, length=None):
        # type: (int | range, int | None) -> None
        """Create from ``(start, length)`` (``length`` defaults to 1) or from
        a step-1 ``range``.

        Raises TypeError when both a range and a length are given, and
        ValueError for a non-unit step, an empty range, or a range that does
        not fit within ``0 .. GPR_COUNT``.
        """
        if isinstance(start, range):
            if length is not None:
                raise TypeError("can't specify length when input is a range")
            if start.step != 1:
                raise ValueError("range must have a step of 1")
            length = len(start)
            start = start.start
        elif length is None:
            length = 1
        if length <= 0 or start < 0 or start + length > GPR_COUNT:
            raise ValueError("invalid GPRRange")
        self.start = start
        self.length = length

    @property
    def stop(self):
        """one past the last GPR number in the range"""
        return self.start + self.length

    @property
    def step(self):
        return 1

    @property
    def range(self):
        """the covered GPR numbers as a builtin ``range``"""
        return range(self.start, self.stop, self.step)

    def __len__(self):
        return self.length

    def __getitem__(self, item):
        # type: (int | slice) -> GPRRange
        # an int index yields a single GPR, a slice yields a sub-range
        return GPRRange(self.range[item])

    def __contains__(self, value):
        # type: (GPRRange) -> bool
        return value.start >= self.start and value.stop <= self.stop

    def index(self, sub, start=None, end=None):
        # type: (GPRRange, int | None, int | None) -> int
        """Return the offset of ``sub`` from ``self.start``; ValueError if
        ``sub`` is not inside ``self[start:end]``."""
        r = self.range[start:end]
        if sub.start < r.start or sub.stop > r.stop:
            raise ValueError("GPR range not found")
        return sub.start - self.start

    def count(self, sub, start=None, end=None):
        # type: (GPRRange, int | None, int | None) -> int
        """Return 1 if ``sub`` lies inside ``self[start:end]``, else 0."""
        r = self.range[start:end]
        if len(r) == 0:
            return 0
        return int(sub in GPRRange(r))

    def conflicts(self, other):
        # type: (RegLoc) -> bool
        """True if ``other`` is a GPRRange sharing at least one GPR."""
        if isinstance(other, GPRRange):
            return self.stop > other.start and other.stop > self.start
        return False

    def get_subreg_at_offset(self, subreg_type, offset):
        # type: (RegType, int) -> GPRRange
        """Return the ``subreg_type.length`` GPRs starting ``offset``
        registers into this range.

        Raises ValueError if ``subreg_type`` is not a GPR range type or the
        requested sub-range does not lie entirely within ``self``.
        """
        if not isinstance(subreg_type, (GPRRangeType, FixedGPRRangeType)):
            raise ValueError(f"subreg_type is not a FixedGPRRangeType or "
                             f"GPRRangeType: {subreg_type}")
        # ``offset`` is relative to ``self.start``, so it must be bounded by
        # the length, not by the absolute ``stop`` register number.
        if offset < 0 or offset + subreg_type.length > self.length:
            raise ValueError(f"sub-register offset is out of range: {offset}")
        return GPRRange(self.start + offset, subreg_type.length)
116
117
# GPRs that GPRRangeType never allocates: r0, r1 (the base register used by
# the stack-slot asm in this file), r2 and r13.
# NOTE(review): r2/r13 are presumably reserved by the ABI (TOC / thread
# pointer) -- confirm.
SPECIAL_GPRS = GPRRange(0), GPRRange(1), GPRRange(2), GPRRange(13)
119
120
@final
@unique
class XERBit(RegLoc, Enum, metaclass=ABCEnumMeta):
    """XER bits usable as register locations (currently only CA)."""
    CA = "CA"

    def conflicts(self, other):
        # type: (RegLoc) -> bool
        """An XER bit conflicts only with the identical XER bit."""
        return isinstance(other, XERBit) and self == other
131
132
@final
@unique
class GlobalMem(RegLoc, Enum, metaclass=ABCEnumMeta):
    """Singleton standing for all non-StackSlot memory; the register
    allocator treats it as one physical register.
    """
    GlobalMem = "GlobalMem"

    def conflicts(self, other):
        # type: (RegLoc) -> bool
        if not isinstance(other, GlobalMem):
            return False
        return self == other
146
147
@final
@unique
class VL(RegLoc, Enum, metaclass=ABCEnumMeta):
    """Register location holding the vector-length state."""
    VL_MAXVL = "VL_MAXVL"  # VL and MAXVL, allocated together

    def conflicts(self, other):
        # type: (RegLoc) -> bool
        if not isinstance(other, VL):
            return False
        return self == other
159
160
@final
class RegClass(OFSet[RegLoc]):
    """An ordered set of registers; the register allocator prefers
    registers that come earlier.
    """

    # NOTE: lru_cache on a method keys on ``self`` and keeps every RegClass
    # it has seen alive for the life of the process.
    @lru_cache(maxsize=None, typed=True)
    def max_conflicts_with(self, other):
        # type: (RegClass | RegLoc) -> int
        """Return the largest number of registers in ``self`` that any single
        register from ``other`` conflicts with.

        ``other`` may also be a single register, in which case the count for
        that register alone is returned.
        """
        if not isinstance(other, RegClass):
            return sum(other.conflicts(reg) for reg in self)
        return max(self.max_conflicts_with(reg) for reg in other)
177
178
@plain_data(frozen=True, unsafe_hash=True)
class RegType(metaclass=ABCMeta):
    """Abstract base for the type of an SSA value (``SSAVal.ty``).

    Each concrete type reports, via ``reg_class``, the register locations a
    value of that type may be assigned to.
    """
    __slots__ = ()

    @property
    @abstractmethod
    def reg_class(self):
        # type: () -> RegClass
        """the RegClass of locations a value of this type may occupy"""
        return ...
188
189
# type variable ranging over RegType subclasses (e.g. an SSAVal's ``ty``)
_RegType = TypeVar("_RegType", bound=RegType)
# type variable ranging over RegLoc subclasses (e.g. AsmContext.reg's result)
_RegLoc = TypeVar("_RegLoc", bound=RegLoc)
192
193
@plain_data(frozen=True, eq=False)
@final
class GPRRangeType(RegType):
    """The type of a value occupying ``length`` consecutive GPRs.

    Instances compare equal and hash by ``length`` alone.  Raises ValueError
    if ``length`` is not in ``1 .. GPR_COUNT``.
    """
    __slots__ = "length",

    def __init__(self, length=1):
        # type: (int) -> None
        if length < 1 or length > GPR_COUNT:
            raise ValueError("invalid length")
        self.length = length

    @staticmethod
    @lru_cache(maxsize=None)
    def __get_reg_class(length):
        # type: (int) -> RegClass
        """Build (once per length) the RegClass of every ``length``-GPR
        range that avoids all of SPECIAL_GPRS, in ascending start order."""
        regs = []
        # GPRRange accepts start + length == GPR_COUNT, so the last valid
        # start is GPR_COUNT - length *inclusive*; the previous
        # range(GPR_COUNT - length) silently dropped it (e.g. r127).
        for start in range(GPR_COUNT - length + 1):
            reg = GPRRange(start, length)
            if any(i in reg for i in SPECIAL_GPRS):
                continue
            regs.append(reg)
        return RegClass(regs)

    @property
    @final
    def reg_class(self):
        # type: () -> RegClass
        return GPRRangeType.__get_reg_class(self.length)

    @final
    def __eq__(self, other):
        if isinstance(other, GPRRangeType):
            return self.length == other.length
        return False

    @final
    def __hash__(self):
        return hash(self.length)
232
233
# alias: a single GPR is a GPRRangeType with its default length of 1
GPRType = GPRRangeType
"""a length=1 GPRRangeType"""
236
237
@plain_data(frozen=True, unsafe_hash=True)
@final
class FixedGPRRangeType(RegType):
    """The type of a value pinned to one specific GPRRange ``reg``."""
    __slots__ = "reg",

    def __init__(self, reg):
        # type: (GPRRange) -> None
        self.reg = reg

    @property
    def length(self):
        # type: () -> int
        """number of GPRs in the pinned range"""
        return self.reg.length

    @property
    def reg_class(self):
        # type: () -> RegClass
        """a RegClass containing only the pinned range"""
        only_choice = [self.reg]
        return RegClass(only_choice)
256
257
@plain_data(frozen=True, unsafe_hash=True)
@final
class CAType(RegType):
    """The type of a carry value; it can only live in XERBit.CA."""
    __slots__ = ()

    @property
    def reg_class(self):
        # type: () -> RegClass
        carry_bit = XERBit.CA
        return RegClass([carry_bit])
267
268
@plain_data(frozen=True, unsafe_hash=True)
@final
class GlobalMemType(RegType):
    """The type of a value representing global memory state."""
    __slots__ = ()

    @property
    def reg_class(self):
        # type: () -> RegClass
        memory = GlobalMem.GlobalMem
        return RegClass([memory])
278
279
@plain_data(frozen=True, unsafe_hash=True)
@final
class KnownVLType(RegType):
    """The type of a VL value known at compile time to equal ``length``.

    Raises ValueError unless ``0 < length <= 64``.
    """
    __slots__ = "length",

    def __init__(self, length):
        # type: (int) -> None
        in_range = 0 < length <= 64
        if not in_range:
            raise ValueError("invalid VL value")
        self.length = length

    @property
    def reg_class(self):
        # type: () -> RegClass
        """VL values always occupy the single VL/MAXVL location"""
        return RegClass([VL.VL_MAXVL])
295
296
def assert_vl_is(vl, expected_vl):
    # type: (SSAKnownVL | KnownVLType | int | None, int) -> None
    """Raise ValueError unless ``vl`` describes a length of ``expected_vl``.

    ``vl`` may be ``None`` (treated as a length of 1), an SSAVal whose type
    has a ``length``, a KnownVLType, or a plain int.
    """
    if vl is None:
        actual = 1
    elif isinstance(vl, SSAVal):
        actual = vl.ty.length
    elif isinstance(vl, KnownVLType):
        actual = vl.length
    else:
        actual = vl
    if actual != expected_vl:
        raise ValueError(
            f"wrong VL: expected {expected_vl} got {actual}")
308
309
# size of one stack slot in bytes; StackSlot.start_byte multiplies by this
STACK_SLOT_SIZE = 8
311
312
@plain_data(frozen=True, unsafe_hash=True)
@final
class StackSlot(RegLoc):
    """A run of ``length_in_slots`` consecutive stack slots starting at slot
    ``start_slot``.

    Raises ValueError if ``length_in_slots`` is less than 1.
    """
    __slots__ = "start_slot", "length_in_slots",

    def __init__(self, start_slot, length_in_slots):
        # type: (int, int) -> None
        self.start_slot = start_slot
        if length_in_slots < 1:
            raise ValueError("invalid length_in_slots")
        self.length_in_slots = length_in_slots

    @property
    def stop_slot(self):
        """one past the last slot index"""
        return self.start_slot + self.length_in_slots

    @property
    def start_byte(self):
        """byte offset of the first slot"""
        return self.start_slot * STACK_SLOT_SIZE

    def conflicts(self, other):
        # type: (RegLoc) -> bool
        """True if ``other`` is a StackSlot sharing at least one slot."""
        if isinstance(other, StackSlot):
            return (self.stop_slot > other.start_slot
                    and other.stop_slot > self.start_slot)
        return False

    def get_subreg_at_offset(self, subreg_type, offset):
        # type: (RegType, int) -> StackSlot
        """Return the ``subreg_type.length_in_slots`` slots starting
        ``offset`` slots into this one.

        Raises ValueError if ``subreg_type`` is not a StackSlotType or the
        sub-range does not lie entirely within ``self``.
        """
        if not isinstance(subreg_type, StackSlotType):
            raise ValueError(f"subreg_type is not a "
                             f"StackSlotType: {subreg_type}")
        # ``offset`` is relative to ``start_slot``, so bound it by the
        # length rather than by the absolute ``stop_slot`` index.
        end = offset + subreg_type.length_in_slots
        if offset < 0 or end > self.length_in_slots:
            raise ValueError(f"sub-register offset is out of range: {offset}")
        return StackSlot(self.start_slot + offset, subreg_type.length_in_slots)
348
349
# number of stack slots StackSlotType's register classes are drawn from
STACK_SLOT_COUNT = 128
351
352
@plain_data(frozen=True, eq=False)
@final
class StackSlotType(RegType):
    """The type of a value spilled to ``length_in_slots`` consecutive stack
    slots.

    Instances compare equal and hash by ``length_in_slots`` alone.  Raises
    ValueError if ``length_in_slots`` is less than 1.
    """
    __slots__ = "length_in_slots",

    def __init__(self, length_in_slots=1):
        # type: (int) -> None
        if length_in_slots < 1:
            raise ValueError("invalid length_in_slots")
        self.length_in_slots = length_in_slots

    @staticmethod
    @lru_cache(maxsize=None)
    def __get_reg_class(length_in_slots):
        # type: (int) -> RegClass
        """Build (once per length) the RegClass of every placement of
        ``length_in_slots`` slots within ``0 .. STACK_SLOT_COUNT``."""
        regs = []
        # a placement may end exactly at STACK_SLOT_COUNT, so the last start
        # is STACK_SLOT_COUNT - length_in_slots *inclusive* (the previous
        # bound never used the final slot).
        for start in range(STACK_SLOT_COUNT - length_in_slots + 1):
            reg = StackSlot(start, length_in_slots)
            regs.append(reg)
        return RegClass(regs)

    @property
    def reg_class(self):
        # type: () -> RegClass
        return StackSlotType.__get_reg_class(self.length_in_slots)

    @final
    def __eq__(self, other):
        if isinstance(other, StackSlotType):
            return self.length_in_slots == other.length_in_slots
        return False

    @final
    def __hash__(self):
        return hash(self.length_in_slots)
388
389
@plain_data(frozen=True, eq=False, repr=False)
@final
class SSAVal(Generic[_RegType]):
    """One SSA value: the output named ``arg_name`` of the Op ``op``.

    Equality and hashing use ``op`` (by identity) and ``arg_name`` only;
    ``ty`` does not participate.
    """
    __slots__ = "op", "arg_name", "ty",

    def __init__(self, op, arg_name, ty):
        # type: (Op, str, _RegType) -> None
        self.op = op  # the Op that writes this SSAVal
        # name of the output of self.op that writes this SSAVal
        self.arg_name = arg_name
        self.ty = ty

    def __eq__(self, rhs):
        if not isinstance(rhs, SSAVal):
            return False
        return self.op is rhs.op and self.arg_name == rhs.arg_name

    def __hash__(self):
        identity = (id(self.op), self.arg_name)
        return hash(identity)

    def __repr__(self, long=False):
        if not long:
            return f"<#{self.op.id}.{self.arg_name}>"
        parts = []
        for field_name in fields(self):
            value = getattr(self, field_name, None)
            if value is None:
                continue
            if field_name == "op":
                # avoid recursing into the Op's own (long) repr
                text = value.__repr__(just_id=True)
            else:
                text = repr(value)
            parts.append(f"{field_name}={text}")
        return "SSAVal(" + ", ".join(parts) + ")"
428
429
# shorthand aliases for SSAVals of commonly used register types
SSAGPRRange = SSAVal[GPRRangeType]
SSAGPR = SSAVal[GPRType]
SSAKnownVL = SSAVal[KnownVLType]
433
434
@final
@plain_data(unsafe_hash=True, frozen=True)
class EqualityConstraint:
    """A register-equality constraint between two lists of SSA values.

    For example, OpConcat constrains ``[dest]`` against its sources.  Raises
    ValueError if either list is empty.

    NOTE(review): presumably each side is laid out contiguously and both
    sides must occupy the same registers -- confirm against the allocator.
    NOTE(review): ``unsafe_hash=True`` with list-valued fields will likely
    make ``hash()`` raise TypeError; tuples may have been intended.
    """
    __slots__ = "lhs", "rhs"

    def __init__(self, lhs, rhs):
        # type: (list[SSAVal], list[SSAVal]) -> None
        self.lhs = lhs
        self.rhs = rhs
        if not lhs or not rhs:
            raise ValueError("can't constrain an empty list to be equal")
446
447
@final
class Fn:
    """A function: the ordered list of Ops it contains.

    Each Op appends itself to ``ops`` when constructed.
    """
    __slots__ = "ops",

    def __init__(self):
        # type: () -> None
        self.ops = []  # type: list[Op]

    def __repr__(self, short=False):
        if short:
            return "<Fn>"
        op_ids = [op.__repr__(just_id=True) for op in self.ops]
        return "<Fn([" + ", ".join(op_ids) + "])>"
461
462
class _NotSet:
    """Placeholder that ``__repr__`` implementations display for fields that
    have not been assigned yet."""

    def __repr__(self):
        placeholder = "<not set>"
        return placeholder


# shared placeholder instance
_NOT_SET = _NotSet()
471
472
@plain_data(frozen=True, unsafe_hash=True)
class AsmTemplateSegment(Generic[_RegType], metaclass=ABCMeta):
    """A piece of an AsmTemplate whose text depends on the register assigned
    to ``ssa_val``."""
    __slots__ = "ssa_val",

    def __init__(self, ssa_val):
        # type: (SSAVal[_RegType]) -> None
        self.ssa_val = ssa_val

    def render(self, regs):
        # type: (dict[SSAVal, RegLoc]) -> str
        """Render using the register ``regs`` assigns to ``ssa_val``; raises
        KeyError if there is none."""
        assigned = regs[self.ssa_val]
        return self._render(assigned)

    @abstractmethod
    def _render(self, reg):
        # type: (RegLoc) -> str
        """produce the text for the assigned register ``reg``"""
        ...
489
490
@plain_data(frozen=True, unsafe_hash=True)
@final
class ATSGPR(AsmTemplateSegment[GPRRangeType]):
    """Segment rendering a GPR number: the first register assigned to
    ``ssa_val`` plus ``offset``.  Raises TypeError for non-GPR registers."""
    __slots__ = "offset",

    def __init__(self, ssa_val, offset=0):
        # type: (SSAGPRRange, int) -> None
        super().__init__(ssa_val)
        self.offset = offset

    def _render(self, reg):
        # type: (RegLoc) -> str
        if isinstance(reg, GPRRange):
            return str(reg.start + self.offset)
        raise TypeError()
506
507
@plain_data(frozen=True, unsafe_hash=True)
@final
class ATSStackSlot(AsmTemplateSegment[StackSlotType]):
    """Segment rendering a stack slot as a D-form ``disp(1)`` memory operand.

    Raises TypeError if the assigned register is not a StackSlot.
    """
    __slots__ = ()

    def _render(self, reg):
        # type: (RegLoc) -> str
        if not isinstance(reg, StackSlot):
            raise TypeError()
        # the displacement is a byte offset; use start_byte (as the
        # OpLoadFromStackSlot/OpStoreToStackSlot asm does), not the slot index
        return f"{reg.start_byte}(1)"
518
519
@plain_data(frozen=True, unsafe_hash=True)
@final
class ATSCopyGPRRange(AsmTemplateSegment["GPRRangeType | FixedGPRRangeType"]):
    """Segment that copies the GPRs assigned to ``src_ssa_val`` into those
    assigned to ``ssa_val``."""
    __slots__ = "src_ssa_val",

    def __init__(self, ssa_val, src_ssa_val):
        # type: (SSAVal[GPRRangeType | FixedGPRRangeType], SSAVal[GPRRangeType | FixedGPRRangeType]) -> None
        self.ssa_val = ssa_val
        self.src_ssa_val = src_ssa_val

    def render(self, regs):
        # type: (dict[SSAVal, RegLoc]) -> str
        """Return the newline-terminated copy instruction, or "" when source
        and destination were assigned the same registers.

        Raises TypeError if either value is not assigned a GPRRange and
        ValueError if their lengths differ.
        """
        src = regs[self.src_ssa_val]
        dest = regs[self.ssa_val]
        if not isinstance(dest, GPRRange):
            raise TypeError()
        if not isinstance(src, GPRRange):
            raise TypeError()
        if src.length != dest.length:
            raise ValueError()
        if src == dest:
            return ""
        vec = src.length != 1
        # only vector operands take the "*" marker
        star = "*" if vec else ""
        # vectors, and scalars numbered 32 or above, need the SVP64 prefix
        # (same rule as AsmContext.needs_sv)
        sv_ = ""
        if vec or src.start >= 32 or dest.start >= 32:
            sv_ = "sv."
        mrr = ""
        # elements are copied in order, so an overlapping copy to a *higher*
        # range must run in reverse (like memmove) to avoid reading source
        # elements that were already overwritten.
        if src.conflicts(dest) and src.start < dest.start:
            mrr = "/mrr"
        return (f"{sv_}or{mrr} {star}{dest.start}, {star}{src.start}, "
                f"{star}{src.start}\n")

    def _render(self, reg):
        # type: (RegLoc) -> str
        raise TypeError("must call self.render")
553
554
@final
class AsmTemplate(Sequence["str | AsmTemplateSegment"]):
    """An immutable sequence of literal strings and AsmTemplateSegments.

    AsmTemplates passed to the constructor are flattened into their
    segments.
    """

    @staticmethod
    def __process_segments(segments):
        # type: (Iterable[str | AsmTemplateSegment | AsmTemplate]) -> Iterable[str | AsmTemplateSegment]
        """yield each segment, splicing in the contents of nested templates"""
        for seg in segments:
            if not isinstance(seg, AsmTemplate):
                yield seg
                continue
            for inner in seg:
                yield inner

    def __init__(self, segments=()):
        # type: (Iterable[str | AsmTemplateSegment | AsmTemplate]) -> None
        self.__segments = tuple(self.__process_segments(segments))

    def __getitem__(self, index):
        # type: (int) -> str | AsmTemplateSegment
        return self.__segments[index]

    def __len__(self):
        return len(self.__segments)

    def __iter__(self):
        return iter(self.__segments)

    def __hash__(self):
        return hash(self.__segments)

    def render(self, regs):
        # type: (dict[SSAVal, RegLoc]) -> str
        """Concatenate the template, rendering each segment against
        ``regs`` and passing literal strings through unchanged."""
        return "".join(
            seg.render(regs) if isinstance(seg, AsmTemplateSegment) else seg
            for seg in self)
592
593
@final
class AsmContext:
    """Answers register-assignment queries while Ops emit assembly.

    Wraps a mapping from SSAVal to its assigned RegLoc and formats GPR
    operands as assembly text.
    """

    def __init__(self, assigned_registers):
        # type: (dict[SSAVal, RegLoc]) -> None
        self.__assigned_registers = assigned_registers

    def reg(self, ssa_val, expected_ty):
        # type: (SSAVal[Any], Type[_RegLoc]) -> _RegLoc
        """Return the register assigned to ``ssa_val``.

        Raises ValueError if ``ssa_val`` has no assignment.  Raises TypeError
        if the assignment is not an ``expected_ty``, or is a GPRRange whose
        length differs from ``ssa_val.ty.length``.
        """
        try:
            reg = self.__assigned_registers[ssa_val]
        except KeyError as e:
            # chain the lookup failure instead of discarding it
            raise ValueError(
                f"SSAVal not assigned a register: {ssa_val}") from e
        wrong_len = (isinstance(reg, GPRRange)
                     and reg.length != ssa_val.ty.length)
        if not isinstance(reg, expected_ty) or wrong_len:
            raise TypeError(
                f"SSAVal is assigned a register of the wrong type: "
                f"ssa_val={ssa_val} expected_ty={expected_ty} reg={reg}")
        return reg

    def gpr_range(self, ssa_val):
        # type: (SSAGPRRange | SSAVal[FixedGPRRangeType]) -> GPRRange
        """the GPRRange assigned to ``ssa_val``"""
        return self.reg(ssa_val, GPRRange)

    def stack_slot(self, ssa_val):
        # type: (SSAVal[StackSlotType]) -> StackSlot
        """the StackSlot assigned to ``ssa_val``"""
        return self.reg(ssa_val, StackSlot)

    def gpr(self, ssa_val, vec, offset=0):
        # type: (SSAGPRRange | SSAVal[FixedGPRRangeType], bool, int) -> str
        """GPR operand text: first assigned GPR + ``offset``, prefixed with
        "*" when ``vec`` is true."""
        reg = self.gpr_range(ssa_val).start + offset
        return "*" * vec + str(reg)

    def vgpr(self, ssa_val, offset=0):
        # type: (SSAGPRRange | SSAVal[FixedGPRRangeType], int) -> str
        """``gpr`` with ``vec=True``"""
        return self.gpr(ssa_val=ssa_val, vec=True, offset=offset)

    def sgpr(self, ssa_val, offset=0):
        # type: (SSAGPR | SSAVal[FixedGPRRangeType], int) -> str
        """``gpr`` with ``vec=False``"""
        return self.gpr(ssa_val=ssa_val, vec=False, offset=offset)

    def needs_sv(self, *regs):
        # type: (*SSAGPRRange | SSAVal[FixedGPRRangeType]) -> bool
        """True if any given value occupies more than one GPR or starts at
        GPR 32 or above; Ops use this to choose the ``sv.`` prefix."""
        for reg in regs:
            reg = self.gpr_range(reg)
            if reg.length != 1 or reg.start >= 32:
                return True
        return False
642
643
@plain_data(unsafe_hash=True, frozen=True, repr=False)
class Op(metaclass=ABCMeta):
    """Abstract base class for an IR operation.

    Constructing an Op appends it to ``fn.ops`` and uses its index there as
    ``id``.  Subclasses describe their SSA inputs/outputs, any register
    constraints, and the assembly they emit.
    """
    __slots__ = "id", "fn"

    @abstractmethod
    def inputs(self):
        # type: () -> dict[str, SSAVal]
        """map from argument name to each SSAVal this Op reads"""
        ...

    @abstractmethod
    def outputs(self):
        # type: () -> dict[str, SSAVal]
        """map from argument name to each SSAVal this Op writes"""
        ...

    def get_equality_constraints(self):
        # type: () -> Iterable[EqualityConstraint]
        """register-equality constraints this Op requires (none by default)"""
        # the unreachable `yield` makes this a generator that yields nothing
        if False:
            yield ...

    def get_extra_interferences(self):
        # type: () -> Iterable[tuple[SSAVal, SSAVal]]
        """extra pairs of SSAVals to treat as interfering (none by default).

        NOTE(review): presumably "interfering" means they must not be
        assigned overlapping registers -- confirm against the allocator.
        """
        if False:
            yield ...

    def __init__(self, fn):
        # type: (Fn) -> None
        # register with fn; the id is this Op's position in fn.ops
        self.id = len(fn.ops)
        fn.ops.append(self)
        self.fn = fn

    @final
    def __repr__(self, just_id=False):
        fields_list = [f"#{self.id}"]
        outputs = None
        # outputs() reads fields that may not be set yet (e.g. an Op whose
        # constructor has not finished); on AttributeError every SSAVal
        # field is shown in long form
        try:
            outputs = self.outputs()
        except AttributeError:
            pass
        if not just_id:
            for name in fields(self):
                v = getattr(self, name, _NOT_SET)
                # outputs are shown in full; inputs use SSAVal's short repr
                if ((outputs is None or name in outputs)
                        and isinstance(v, SSAVal)):
                    v = v.__repr__(long=True)
                elif isinstance(v, Fn):
                    v = v.__repr__(short=True)
                else:
                    v = repr(v)
                fields_list.append(f"{name}={v}")
        fields_str = ', '.join(fields_list)
        return f"{self.__class__.__name__}({fields_str})"

    @abstractmethod
    def get_asm_lines(self, ctx):
        # type: (AsmContext) -> list[str]
        """get the lines of assembly for this Op"""
        ...
701
702
@plain_data(unsafe_hash=True, frozen=True, repr=False)
@final
class OpLoadFromStackSlot(Op):
    """Load the stack slot(s) ``src`` into a new GPR range ``dest`` with one
    GPR per slot."""
    __slots__ = "dest", "src", "vl"

    def __init__(self, fn, src, vl=None):
        # type: (Fn, SSAVal[StackSlotType], SSAKnownVL | None) -> None
        super().__init__(fn)
        dest_ty = GPRRangeType(src.ty.length_in_slots)
        self.dest = SSAVal(self, "dest", dest_ty)
        self.src = src
        self.vl = vl
        assert_vl_is(vl, dest_ty.length)

    def inputs(self):
        # type: () -> dict[str, SSAVal]
        ins = {"src": self.src}  # type: dict[str, SSAVal[Any]]
        if self.vl is None:
            return ins
        ins["vl"] = self.vl
        return ins

    def outputs(self):
        # type: () -> dict[str, SSAVal]
        return {"dest": self.dest}

    def get_asm_lines(self, ctx):
        # type: (AsmContext) -> list[str]
        dest = ctx.gpr(self.dest, vec=self.dest.ty.length != 1)
        offset = ctx.stack_slot(self.src).start_byte
        prefix = "sv." if ctx.needs_sv(self.dest) else ""
        return [f"{prefix}ld {dest}, {offset}(1)"]
734
735
@plain_data(unsafe_hash=True, frozen=True, repr=False)
@final
class OpStoreToStackSlot(Op):
    """Store the GPR range ``src`` into a new stack-slot value ``dest`` with
    one slot per GPR."""
    __slots__ = "dest", "src", "vl"

    def __init__(self, fn, src, vl=None):
        # type: (Fn, SSAGPRRange, SSAKnownVL | None) -> None
        super().__init__(fn)
        count = src.ty.length
        self.dest = SSAVal(self, "dest", StackSlotType(count))
        self.src = src
        self.vl = vl
        assert_vl_is(vl, count)

    def inputs(self):
        # type: () -> dict[str, SSAVal]
        ins = {"src": self.src}  # type: dict[str, SSAVal[Any]]
        if self.vl is None:
            return ins
        ins["vl"] = self.vl
        return ins

    def outputs(self):
        # type: () -> dict[str, SSAVal]
        return {"dest": self.dest}

    def get_asm_lines(self, ctx):
        # type: (AsmContext) -> list[str]
        src = ctx.gpr(self.src, vec=self.src.ty.length != 1)
        offset = ctx.stack_slot(self.dest).start_byte
        prefix = "sv." if ctx.needs_sv(self.src) else ""
        return [f"{prefix}std {src}, {offset}(1)"]
767
768
# type variable for the source RegType of an OpCopy
_RegSrcType = TypeVar("_RegSrcType", bound=RegType)
770
771
@plain_data(unsafe_hash=True, frozen=True, repr=False)
@final
class OpCopy(Op, Generic[_RegSrcType, _RegType]):
    """Copy ``src`` into a new SSA value ``dest`` of type ``dest_ty``.

    ``dest_ty`` defaults to ``src.ty``.  A GPRRangeType may be copied to or
    from a FixedGPRRangeType of the same length; otherwise the types must be
    equal.  Raises ValueError for incompatible types or a mismatched ``vl``.
    """
    __slots__ = "dest", "src", "vl"

    def inputs(self):
        # type: () -> dict[str, SSAVal]
        retval = {"src": self.src}  # type: dict[str, SSAVal[Any]]
        if self.vl is not None:
            retval["vl"] = self.vl
        return retval

    def outputs(self):
        # type: () -> dict[str, SSAVal]
        return {"dest": self.dest}

    def __init__(self, fn, src, dest_ty=None, vl=None):
        # type: (Fn, SSAVal[_RegSrcType], _RegType | None, SSAKnownVL | None) -> None
        if dest_ty is None:
            dest_ty = cast(_RegType, src.ty)
        # validate before super().__init__ registers this Op in fn.ops, so a
        # rejected copy doesn't leave a half-built Op in the function
        if isinstance(src.ty, GPRRangeType) \
                and isinstance(dest_ty, FixedGPRRangeType):
            if src.ty.length != dest_ty.reg.length:
                raise ValueError(f"incompatible source and destination "
                                 f"types: {src.ty} and {dest_ty}")
            length = src.ty.length
        elif isinstance(src.ty, FixedGPRRangeType) \
                and isinstance(dest_ty, GPRRangeType):
            if src.ty.reg.length != dest_ty.length:
                raise ValueError(f"incompatible source and destination "
                                 f"types: {src.ty} and {dest_ty}")
            length = src.ty.length
        elif src.ty != dest_ty:
            raise ValueError(f"incompatible source and destination "
                             f"types: {src.ty} and {dest_ty}")
        elif isinstance(src.ty, (GPRRangeType, FixedGPRRangeType)):
            length = src.ty.length
        else:
            length = 1
        assert_vl_is(vl, length)

        super().__init__(fn)
        self.dest = SSAVal(self, "dest", dest_ty)  # type: SSAVal[_RegType]
        self.src = src
        self.vl = vl

    def get_asm_lines(self, ctx):
        # type: (AsmContext) -> list[str]
        """Emit nothing if src and dest share a location, an ``or`` copy for
        GPR ranges, and raise NotImplementedError otherwise."""
        if ctx.reg(self.src, RegLoc) == ctx.reg(self.dest, RegLoc):
            return []
        if (isinstance(self.src.ty, (GPRRangeType, FixedGPRRangeType)) and
                isinstance(self.dest.ty, (GPRRangeType, FixedGPRRangeType))):
            vec = self.dest.ty.length != 1
            dest = ctx.gpr_range(self.dest)  # type: ignore
            src = ctx.gpr_range(self.src)  # type: ignore
            dest_s = ctx.gpr(self.dest, vec=vec)  # type: ignore
            src_s = ctx.gpr(self.src, vec=vec)  # type: ignore
            mrr = ""
            # elements are copied in order, so an overlapping copy to a
            # *higher* range must run in reverse (like memmove); copying
            # downward is safe in forward order.
            if src.conflicts(dest) and src.start < dest.start:
                mrr = "/mrr"
            if ctx.needs_sv(self.src, self.dest):  # type: ignore
                return [f"sv.or{mrr} {dest_s}, {src_s}, {src_s}"]
            return [f"or {dest_s}, {src_s}, {src_s}"]
        raise NotImplementedError
836
837
@plain_data(unsafe_hash=True, frozen=True, repr=False)
@final
class OpConcat(Op):
    """Concatenate the GPR ranges in ``sources`` into one range ``dest``.

    No instructions are emitted; an equality constraint ties ``dest`` to the
    sources' registers instead.
    """
    __slots__ = "dest", "sources"

    def inputs(self):
        # type: () -> dict[str, SSAVal]
        return {f"sources[{idx}]": src
                for idx, src in enumerate(self.sources)}

    def outputs(self):
        # type: () -> dict[str, SSAVal]
        return {"dest": self.dest}

    def __init__(self, fn, sources):
        # type: (Fn, Iterable[SSAGPRRange]) -> None
        super().__init__(fn)
        sources = tuple(sources)
        total_len = sum(src.ty.length for src in sources)
        self.dest = SSAVal(self, "dest", GPRRangeType(total_len))
        self.sources = sources

    def get_equality_constraints(self):
        # type: () -> Iterable[EqualityConstraint]
        yield EqualityConstraint([self.dest], list(self.sources))

    def get_asm_lines(self, ctx):
        # type: (AsmContext) -> list[str]
        return []
866
867
@plain_data(unsafe_hash=True, frozen=True, repr=False)
@final
class OpSplit(Op):
    """Split the GPR range ``src`` at ``split_indexes`` into ``results``.

    No instructions are emitted; an equality constraint ties the results to
    ``src``'s registers.  Raises ValueError unless every split index lies in
    ``0 < i < src.ty.length`` and the indexes are strictly increasing.
    """
    __slots__ = "results", "src"

    def inputs(self):
        # type: () -> dict[str, SSAVal]
        return {"src": self.src}

    def outputs(self):
        # type: () -> dict[str, SSAVal]
        return {i.arg_name: i for i in self.results}

    def __init__(self, fn, src, split_indexes):
        # type: (Fn, SSAGPRRange, Iterable[int]) -> None
        ranges = []  # type: list[GPRRangeType]
        last = 0
        for i in split_indexes:
            if not (0 < i < src.ty.length):
                raise ValueError(f"invalid split index: {i}, must be in "
                                 f"0 < i < {src.ty.length}")
            if i <= last:
                # previously surfaced as GPRRangeType's opaque "invalid length"
                raise ValueError(f"split indexes must be strictly "
                                 f"increasing: {i} after {last}")
            ranges.append(GPRRangeType(i - last))
            last = i
        ranges.append(GPRRangeType(src.ty.length - last))
        # register with fn only once the split is known to be valid
        super().__init__(fn)
        self.src = src
        self.results = tuple(
            SSAVal(self, f"results{i}", r) for i, r in enumerate(ranges))

    def get_equality_constraints(self):
        # type: () -> Iterable[EqualityConstraint]
        yield EqualityConstraint([*self.results], [self.src])

    def get_asm_lines(self, ctx):
        # type: (AsmContext) -> list[str]
        return []
904
905
@plain_data(unsafe_hash=True, frozen=True, repr=False)
@final
class OpBigIntAddSub(Op):
    """Multi-word add (``adde``) or subtract (``subfe``) of ``lhs`` and
    ``rhs`` with carry in ``CA_in`` and carry out ``CA_out``.

    Raises TypeError if ``lhs`` and ``rhs`` have different types.
    """
    __slots__ = "out", "lhs", "rhs", "CA_in", "CA_out", "is_sub", "vl"

    def inputs(self):
        # type: () -> dict[str, SSAVal]
        ins = {
            "lhs": self.lhs,
            "rhs": self.rhs,
            "CA_in": self.CA_in,
        }  # type: dict[str, SSAVal[Any]]
        if self.vl is not None:
            ins["vl"] = self.vl
        return ins

    def outputs(self):
        # type: () -> dict[str, SSAVal]
        return {"out": self.out, "CA_out": self.CA_out}

    def __init__(self, fn, lhs, rhs, CA_in, is_sub, vl=None):
        # type: (Fn, SSAGPRRange, SSAGPRRange, SSAVal[CAType], bool, SSAKnownVL | None) -> None
        super().__init__(fn)
        if lhs.ty != rhs.ty:
            raise TypeError(
                f"source types must match: {lhs} doesn't match {rhs}")
        self.out = SSAVal(self, "out", lhs.ty)
        self.lhs = lhs
        self.rhs = rhs
        self.CA_in = CA_in
        self.CA_out = SSAVal(self, "CA_out", CA_in.ty)
        self.is_sub = is_sub
        self.vl = vl
        assert_vl_is(vl, lhs.ty.length)

    def get_extra_interferences(self):
        # type: () -> Iterable[tuple[SSAVal, SSAVal]]
        for src in (self.lhs, self.rhs):
            yield self.out, src

    def get_asm_lines(self, ctx):
        # type: (AsmContext) -> list[str]
        vec = self.out.ty.length != 1
        out = ctx.gpr(self.out, vec=vec)
        lhs = ctx.gpr(self.lhs, vec=vec)
        rhs = ctx.gpr(self.rhs, vec=vec)
        if self.is_sub:
            # subfe RT,RA,RB computes ~RA + RB + CA, so pass rhs as RA
            mnemonic, RA, RB = "subfe", rhs, lhs
        else:
            mnemonic, RA, RB = "adde", lhs, rhs
        prefix = "sv." if ctx.needs_sv(self.out, self.lhs, self.rhs) else ""
        return [f"{prefix}{mnemonic} {out}, {RA}, {RB}"]
958
959
@plain_data(unsafe_hash=True, frozen=True, repr=False)
@final
class OpBigIntMulDiv(Op):
    """Multiply (``maddedu``) or divide (``divmod2du``) the multi-word value
    ``RA`` by the scalar ``RB``, writing the per-word results to ``RT``.

    ``RC`` and ``RS`` are constrained to the same register.  NOTE(review):
    presumably this chains the carry/remainder word from one element to the
    next -- confirm against the SVP64 instruction definitions.
    """
    __slots__ = "RT", "RA", "RB", "RC", "RS", "is_div", "vl"

    def inputs(self):
        # type: () -> dict[str, SSAVal]
        retval = {}  # type: dict[str, SSAVal[Any]]
        retval["RA"] = self.RA
        retval["RB"] = self.RB
        retval["RC"] = self.RC
        if self.vl is not None:
            retval["vl"] = self.vl
        return retval

    def outputs(self):
        # type: () -> dict[str, SSAVal]
        return {"RT": self.RT, "RS": self.RS}

    def __init__(self, fn, RA, RB, RC, is_div, vl=None):
        # type: (Fn, SSAGPRRange, SSAGPR, SSAGPR, bool, SSAKnownVL | None) -> None
        # ``vl`` now defaults to None like every other Op's (backward
        # compatible: existing positional/keyword callers are unaffected)
        super().__init__(fn)
        self.RT = SSAVal(self, "RT", RA.ty)
        self.RA = RA
        self.RB = RB
        self.RC = RC
        self.RS = SSAVal(self, "RS", RC.ty)
        self.is_div = is_div
        self.vl = vl
        assert_vl_is(vl, RA.ty.length)

    def get_equality_constraints(self):
        # type: () -> Iterable[EqualityConstraint]
        yield EqualityConstraint([self.RC], [self.RS])

    def get_extra_interferences(self):
        # type: () -> Iterable[tuple[SSAVal, SSAVal]]
        yield self.RT, self.RA
        yield self.RT, self.RB
        yield self.RT, self.RC
        yield self.RT, self.RS
        yield self.RS, self.RA
        yield self.RS, self.RB

    def get_asm_lines(self, ctx):
        # type: (AsmContext) -> list[str]
        vec = self.RT.ty.length != 1
        RT = ctx.gpr(self.RT, vec=vec)
        RA = ctx.gpr(self.RA, vec=vec)
        RB = ctx.sgpr(self.RB)
        RC = ctx.sgpr(self.RC)
        mnemonic = "maddedu"
        if self.is_div:
            # division runs from the last element down (reverse gear)
            mnemonic = "divmod2du/mrr"
        return [f"sv.{mnemonic} {RT}, {RA}, {RB}, {RC}"]
1015
1016
@final
@unique
class ShiftKind(Enum):
    """Shift kinds: left (Sl), logical right (Sr), arithmetic right (Sra)."""
    Sl = "sl"
    Sr = "sr"
    Sra = "sra"

    def make_big_int_carry_in(self, fn, inp):
        # type: (Fn, SSAGPRRange) -> tuple[SSAGPR, list[Op]]
        """Create the Ops producing the word a multi-word shift of this kind
        shifts into ``inp``.

        Returns that SSA value and the Ops created (already added to
        ``fn``).  Sl/Sr use a zero.  Sra uses the last word of ``inp``
        (presumably the most significant) arithmetically shifted right by 63,
        i.e. its sign fill.
        """
        if self is ShiftKind.Sra:
            top = OpSplit(fn, inp, [inp.ty.length - 1])
            sign = OpShiftImm(fn, top.results[1], sh=63, kind=ShiftKind.Sra)
            return sign.out, [top, sign]
        zero = OpLI(fn, 0)
        return zero.out, [zero]

    def make_big_int_shift(self, fn, inp, sh, vl):
        # type: (Fn, SSAGPRRange, SSAGPR, SSAKnownVL | None) -> tuple[SSAGPRRange, list[Op]]
        """Create the Ops shifting ``inp`` by ``sh``; return the result and
        every Op created, in creation order."""
        carry_in, created = self.make_big_int_carry_in(fn, inp)
        shift = OpBigIntShift(fn, inp, sh, carry_in, kind=self, vl=vl)
        return shift.out, created + [shift]
1041
1042
@plain_data(unsafe_hash=True, frozen=True, repr=False)
@final
class OpBigIntShift(Op):
    """Shift the multi-word value ``inp`` by the scalar amount ``sh``.

    ``kind`` selects Sl, Sr or Sra.  ``carry_in`` is the word shifted in;
    ShiftKind.make_big_int_carry_in builds 0 for Sl/Sr and a sign fill for
    Sra.  ``_out_padding`` is an extra one-GPR output that appears to exist
    only so the equality constraints can describe the register layout.
    """
    __slots__ = "out", "inp", "carry_in", "_out_padding", "sh", "kind", "vl"

    def inputs(self):
        # type: () -> dict[str, SSAVal]
        retval = {}  # type: dict[str, SSAVal[Any]]
        retval["inp"] = self.inp
        retval["sh"] = self.sh
        retval["carry_in"] = self.carry_in
        if self.vl is not None:
            retval["vl"] = self.vl
        return retval

    def outputs(self):
        # type: () -> dict[str, SSAVal]
        return {"out": self.out, "_out_padding": self._out_padding}

    def __init__(self, fn, inp, sh, carry_in, kind, vl=None):
        # type: (Fn, SSAGPRRange, SSAGPR, SSAGPR, ShiftKind, SSAKnownVL | None) -> None
        super().__init__(fn)
        self.out = SSAVal(self, "out", inp.ty)
        self._out_padding = SSAVal(self, "_out_padding", GPRRangeType())
        self.carry_in = carry_in
        self.inp = inp
        self.sh = sh
        self.kind = kind
        self.vl = vl
        assert_vl_is(vl, inp.ty.length)

    def get_extra_interferences(self):
        # type: () -> Iterable[tuple[SSAVal, SSAVal]]
        yield self.out, self.sh

    def get_equality_constraints(self):
        # type: () -> Iterable[EqualityConstraint]
        # Sl: carry_in followed by inp shares registers with out followed by
        # the padding word; Sr/Sra: inp followed by carry_in shares
        # registers with the padding word followed by out.
        if self.kind is ShiftKind.Sl:
            yield EqualityConstraint([self.carry_in, self.inp],
                                     [self.out, self._out_padding])
        else:
            assert self.kind is ShiftKind.Sr or self.kind is ShiftKind.Sra
            yield EqualityConstraint([self.inp, self.carry_in],
                                     [self._out_padding, self.out])

    def get_asm_lines(self, ctx):
        # type: (AsmContext) -> list[str]
        # NOTE(review): RA is taken from out's registers at offset -1 (Sl) or
        # +1 (Sr/Sra) rather than from inp.  Verify this against the layout
        # implied by get_equality_constraints and the dsld/dsrd operand
        # semantics.  Also confirm the element order: Sl uses /mrr (reverse
        # gear) when vectorised, Sr/Sra never do.
        vec = self.out.ty.length != 1
        if self.kind is ShiftKind.Sl:
            RT = ctx.gpr(self.out, vec=vec)
            RA = ctx.gpr(self.out, vec=vec, offset=-1)
            RB = ctx.sgpr(self.sh)
            mrr = "/mrr" if vec else ""
            return [f"sv.dsld{mrr} {RT}, {RA}, {RB}, 0"]
        else:
            assert self.kind is ShiftKind.Sr or self.kind is ShiftKind.Sra
            RT = ctx.gpr(self.out, vec=vec)
            RA = ctx.gpr(self.out, vec=vec, offset=1)
            RB = ctx.sgpr(self.sh)
            return [f"sv.dsrd {RT}, {RA}, {RB}, 1"]
1103
1104
@plain_data(unsafe_hash=True, frozen=True, repr=False)
@final
class OpShiftImm(Op):
    """Shift the single GPR ``inp`` by the immediate ``sh`` (0..63).

    Sra also produces ``ca_out`` (an XER.CA value, as written by
    ``sradi``); for Sl/Sr ``ca_out`` is None.  Raises ValueError if ``sh``
    is out of range.
    """
    __slots__ = "out", "inp", "sh", "kind", "ca_out"

    def inputs(self):
        # type: () -> dict[str, SSAVal]
        return {"inp": self.inp}

    def outputs(self):
        # type: () -> dict[str, SSAVal]
        if self.ca_out is not None:
            return {"out": self.out, "ca_out": self.ca_out}
        return {"out": self.out}

    def __init__(self, fn, inp, sh, kind):
        # type: (Fn, SSAGPR, int, ShiftKind) -> None
        # validate before super().__init__ registers this Op in fn.ops, so a
        # rejected shift doesn't leave a half-built Op in the function
        if not (0 <= sh < 64):
            raise ValueError("shift amount out of range")
        super().__init__(fn)
        self.out = SSAVal(self, "out", inp.ty)
        self.inp = inp
        self.sh = sh
        self.kind = kind
        if self.kind is ShiftKind.Sra:
            self.ca_out = SSAVal(self, "ca_out", CAType())
        else:
            self.ca_out = None

    def get_asm_lines(self, ctx):
        # type: (AsmContext) -> list[str]
        out = ctx.sgpr(self.out)
        inp = ctx.sgpr(self.inp)
        if self.kind is ShiftKind.Sl:
            # sldi n == rldicr RA, RS, n, 63 - n
            mnemonic = "rldicr"
            args = f"{self.sh}, {63 - self.sh}"
        elif self.kind is ShiftKind.Sr:
            # srdi n == rldicl RA, RS, 64 - n, n; "% 64" keeps sh == 0
            # encodable as a rotate of 0
            mnemonic = "rldicl"
            v = (64 - self.sh) % 64
            args = f"{v}, {self.sh}"
        else:
            assert self.kind is ShiftKind.Sra
            mnemonic = "sradi"
            args = f"{self.sh}"
        if ctx.needs_sv(self.out, self.inp):
            return [f"sv.{mnemonic} {out}, {inp}, {args}"]
        return [f"{mnemonic} {out}, {inp}, {args}"]
1152
1153
@plain_data(unsafe_hash=True, frozen=True, repr=False)
@final
class OpLI(Op):
    """Load the signed 16-bit immediate ``value`` with ``addi RT, 0, value``.

    Without ``vl`` the destination is one GPR. With ``vl`` it is a
    register range of ``vl.ty.length`` registers, and the instruction is
    emitted as ``sv.addi`` whenever ``ctx.needs_sv`` says so.
    """
    __slots__ = "out", "value", "vl"

    def inputs(self):
        # type: () -> dict[str, SSAVal]
        if self.vl is None:
            return {}
        return {"vl": self.vl}

    def outputs(self):
        # type: () -> dict[str, SSAVal]
        return dict(out=self.out)

    def __init__(self, fn, value, vl=None):
        # type: (Fn, int, SSAKnownVL | None) -> None
        super().__init__(fn)
        length = 1 if vl is None else vl.ty.length
        self.out = SSAVal(self, "out", GPRRangeType(length))
        # addi's SI field is a signed 16-bit immediate
        if not (-1 << 15 <= value <= (1 << 15) - 1):
            raise ValueError(f"value out of range: {value}")
        self.value = value
        self.vl = vl
        assert_vl_is(vl, length)

    def get_asm_lines(self, ctx):
        # type: (AsmContext) -> list[str]
        out = ctx.gpr(self.out, vec=self.out.ty.length != 1)
        if not ctx.needs_sv(self.out):
            return [f"addi {out}, 0, {self.value}"]
        return [f"sv.addi {out}, 0, {self.value}"]
1191
1192
@plain_data(unsafe_hash=True, frozen=True, repr=False)
@final
class OpSetCA(Op):
    """Set the CA (carry) flag to the constant ``value``.

    NOTE(review): both emitted instructions use GPR 0 as their target, so
    r0 is written (``subfic`` leaves it complemented). This presumably
    relies on the register allocator never assigning r0; confirm.
    """
    __slots__ = "out", "value"

    def inputs(self):
        # type: () -> dict[str, SSAVal]
        return dict()

    def outputs(self):
        # type: () -> dict[str, SSAVal]
        return dict(out=self.out)

    def __init__(self, fn, value):
        # type: (Fn, bool) -> None
        super().__init__(fn)
        self.out = SSAVal(self, "out", CAType())
        self.value = value

    def get_asm_lines(self, ctx):
        # type: (AsmContext) -> list[str]
        if not self.value:
            # r0 + 0 never carries out, so CA becomes 0
            return ["addic 0, 0, 0"]
        # ~r0 + (-1) + 1 always carries out of 64 bits, so CA becomes 1
        return ["subfic 0, 0, -1"]
1217
1218
@plain_data(unsafe_hash=True, frozen=True, repr=False)
@final
class OpLoad(Op):
    """Load doubleword(s) into ``RT`` from ``offset(RA)`` using ``ld``/``sv.ld``.

    ``RT`` is one register, or a range of ``vl.ty.length`` registers when
    ``vl`` is given. ``mem`` is the memory SSA value read as an input, so
    ``op_set_to_list`` schedules this op after whatever produces ``mem``.
    ``offset`` must be a signed 16-bit value that is a multiple of 4
    (``ld`` is a DS-form instruction).
    """
    __slots__ = "RT", "RA", "offset", "mem", "vl"

    def inputs(self):
        # type: () -> dict[str, SSAVal]
        retval = {"RA": self.RA, "mem": self.mem}  # type: dict[str, SSAVal[Any]]
        if self.vl is not None:
            retval["vl"] = self.vl
        return retval

    def outputs(self):
        # type: () -> dict[str, SSAVal]
        return dict(RT=self.RT)

    def __init__(self, fn, RA, offset, mem, vl=None):
        # type: (Fn, SSAGPR, int, SSAVal[GlobalMemType], SSAKnownVL | None) -> None
        super().__init__(fn)
        length = 1 if vl is None else vl.ty.length
        self.RT = SSAVal(self, "RT", GPRRangeType(length))
        self.RA = RA
        if not (-1 << 15 <= offset <= (1 << 15) - 1):
            raise ValueError(f"offset out of range: {offset}")
        if offset % 4 != 0:
            raise ValueError(f"offset not aligned: {offset}")
        self.offset = offset
        self.mem = mem
        self.vl = vl
        assert_vl_is(vl, length)

    def get_extra_interferences(self):
        # type: () -> Iterable[tuple[SSAVal, SSAVal]]
        # A single-register load may reuse RA as RT. A multi-register load
        # must not, presumably because RA is still needed after some RT
        # elements have been written.
        if self.RT.ty.length <= 1:
            return
        yield self.RT, self.RA

    def get_asm_lines(self, ctx):
        # type: (AsmContext) -> list[str]
        RT = ctx.gpr(self.RT, vec=self.RT.ty.length != 1)
        RA = ctx.sgpr(self.RA)
        if not ctx.needs_sv(self.RT, self.RA):
            return [f"ld {RT}, {self.offset}({RA})"]
        return [f"sv.ld {RT}, {self.offset}({RA})"]
1267
1268
@plain_data(unsafe_hash=True, frozen=True, repr=False)
@final
class OpStore(Op):
    """Store ``RS`` to ``offset(RA)`` using ``std``/``sv.std``.

    ``RS`` is one register, or a range whose length must match ``vl`` when
    ``vl`` is given. The op reads the memory SSA value ``mem_in`` and
    produces ``mem_out`` of the same type, which lets ``op_set_to_list``
    order it against other memory ops. ``offset`` must be a signed 16-bit
    value that is a multiple of 4 (``std`` is a DS-form instruction).
    """
    __slots__ = "RS", "RA", "offset", "mem_in", "mem_out", "vl"

    def inputs(self):
        # type: () -> dict[str, SSAVal]
        retval = {}  # type: dict[str, SSAVal[Any]]
        retval.update(RS=self.RS, RA=self.RA, mem_in=self.mem_in)
        if self.vl is not None:
            retval["vl"] = self.vl
        return retval

    def outputs(self):
        # type: () -> dict[str, SSAVal]
        return dict(mem_out=self.mem_out)

    def __init__(self, fn, RS, RA, offset, mem_in, vl=None):
        # type: (Fn, SSAGPRRange, SSAGPR, int, SSAVal[GlobalMemType], SSAKnownVL | None) -> None
        super().__init__(fn)
        self.RS = RS
        self.RA = RA
        if not (-1 << 15 <= offset <= (1 << 15) - 1):
            raise ValueError(f"offset out of range: {offset}")
        if offset % 4 != 0:
            raise ValueError(f"offset not aligned: {offset}")
        self.offset = offset
        self.mem_in = mem_in
        self.mem_out = SSAVal(self, "mem_out", mem_in.ty)
        self.vl = vl
        assert_vl_is(vl, RS.ty.length)

    def get_asm_lines(self, ctx):
        # type: (AsmContext) -> list[str]
        RS = ctx.gpr(self.RS, vec=self.RS.ty.length != 1)
        RA = ctx.sgpr(self.RA)
        if not ctx.needs_sv(self.RS, self.RA):
            return [f"std {RS}, {self.offset}({RA})"]
        return [f"sv.std {RS}, {self.offset}({RA})"]
1310
1311
@plain_data(unsafe_hash=True, frozen=True, repr=False)
@final
class OpFuncArg(Op):
    """Pseudo-op that introduces a function argument as the SSA value ``out``.

    ``out`` has the fixed register-range type ``ty``. The op has no
    inputs and emits no instructions.
    """
    __slots__ = ("out",)

    def inputs(self):
        # type: () -> dict[str, SSAVal]
        return dict()

    def outputs(self):
        # type: () -> dict[str, SSAVal]
        return dict(out=self.out)

    def __init__(self, fn, ty):
        # type: (Fn, FixedGPRRangeType) -> None
        super().__init__(fn)
        self.out = SSAVal(self, "out", ty)

    def get_asm_lines(self, ctx):
        # type: (AsmContext) -> list[str]
        # presumably the argument is already in its fixed registers on
        # entry, so nothing needs to be emitted
        return list()
1333
1334
@plain_data(unsafe_hash=True, frozen=True, repr=False)
@final
class OpInputMem(Op):
    """Pseudo-op that produces a ``GlobalMemType`` SSA value with no inputs.

    Presumably this is the memory state on function entry, which loads
    and stores then consume. It emits no instructions.
    """
    __slots__ = ("out",)

    def inputs(self):
        # type: () -> dict[str, SSAVal]
        return dict()

    def outputs(self):
        # type: () -> dict[str, SSAVal]
        return dict(out=self.out)

    def __init__(self, fn):
        # type: (Fn) -> None
        super().__init__(fn)
        self.out = SSAVal(self, "out", GlobalMemType())

    def get_asm_lines(self, ctx):
        # type: (AsmContext) -> list[str]
        return list()
1356
1357
@plain_data(unsafe_hash=True, frozen=True, repr=False)
@final
class OpSetVLImm(Op):
    """Set VL to the constant ``length`` with a ``setvl`` instruction.

    ``out`` has type ``KnownVLType(length)``. Ops that take an optional
    ``vl`` argument (e.g. ``OpLI``, ``OpLoad``) read ``vl.ty.length``
    from such a value.
    """
    __slots__ = ("out",)

    def inputs(self):
        # type: () -> dict[str, SSAVal]
        return dict()

    def outputs(self):
        # type: () -> dict[str, SSAVal]
        return dict(out=self.out)

    def __init__(self, fn, length):
        # type: (Fn, int) -> None
        super().__init__(fn)
        self.out = SSAVal(self, "out", KnownVLType(length))

    def get_asm_lines(self, ctx):
        # type: (AsmContext) -> list[str]
        return [f"setvl 0, 0, {self.out.ty.length}, 0, 1, 1"]
1379
1380
def op_set_to_list(ops):
    # type: (Iterable[Op]) -> list[Op]
    """Topologically sort ``ops`` so every op follows the writers of its inputs.

    This is a Kahn-style worklist sort. ``worklists[k]`` holds, in
    insertion order, the ops still waiting on ``k`` distinct input values.
    Ops are taken from ``worklists[0]`` in insertion order. The result is
    deterministic for a given input order.

    Raises ValueError if two ops write the same SSA value, or if an op can
    never be scheduled (it is part of a dependency cycle, or it reads a
    value that no op in ``ops`` writes).
    """
    # dicts with None values are used as insertion-ordered sets
    worklists = [{}]  # type: list[dict[Op, None]]
    inps_to_ops_map = defaultdict(dict)  # type: dict[SSAVal, dict[Op, None]]
    ops_to_pending_input_count_map = {}  # type: dict[Op, int]
    for op in ops:
        # An op may use the same SSA value for several inputs (e.g. x + x).
        # Count each distinct value once: writing a value decrements each
        # reading op's pending count only once.
        distinct_inputs = dict.fromkeys(op.inputs().values())
        input_count = len(distinct_inputs)
        for val in distinct_inputs:
            inps_to_ops_map[val][op] = None
        while len(worklists) <= input_count:
            worklists.append({})
        ops_to_pending_input_count_map[op] = input_count
        worklists[input_count][op] = None
    retval = []  # type: list[Op]
    ready_vals = set()  # type: set[SSAVal]
    while worklists[0]:
        writing_op = next(iter(worklists[0]))
        del worklists[0][writing_op]
        retval.append(writing_op)
        for val in writing_op.outputs().values():
            if val in ready_vals:
                raise ValueError(f"multiple instructions must not write "
                                 f"to the same SSA value: {val}")
            ready_vals.add(val)
            # every reader of val now waits on one fewer value
            for reading_op in inps_to_ops_map[val]:
                pending = ops_to_pending_input_count_map[reading_op]
                del worklists[pending][reading_op]
                pending -= 1
                worklists[pending][reading_op] = None
                ops_to_pending_input_count_map[reading_op] = pending
    # anything left over never had all of its inputs written
    for worklist in worklists:
        for op in worklist:
            raise ValueError(f"instruction is part of a dependency loop or "
                             f"its inputs are never written: {op}")
    return retval
1417
1418
def generate_assembly(ops, assigned_registers=None):
    # type: (list[Op], dict[SSAVal, RegLoc] | None) -> list[str]
    """Return the assembly lines for ``ops`` (in list order) plus a return.

    If ``assigned_registers`` is None, registers are allocated with
    ``register_allocator.allocate_registers(ops)``. Each op's
    ``get_asm_lines`` is called with an ``AsmContext`` built from the
    assignment. The function ends with ``bclr 20, 0, 0``, the expanded
    form of ``blr``.
    """
    if assigned_registers is None:
        # Imported lazily; presumably to avoid an import cycle with
        # register_allocator.
        from bigint_presentation_code.register_allocator import \
            allocate_registers
        assigned_registers = allocate_registers(ops)
    ctx = AsmContext(assigned_registers)
    retval = []  # type: list[str]
    for op in ops:
        retval.extend(op.get_asm_lines(ctx))
    # BO=20 means "branch always": return to the address in LR
    retval.append("bclr 20, 0, 0")
    return retval