24b86494e37c71f10c9491e697efb5f0823e37b8
[bigint-presentation-code.git] / src / bigint_presentation_code / compiler_ir.py
1 """
2 Compiler IR for Toom-Cook algorithm generator for SVP64
3 """
4
5 from abc import ABCMeta, abstractmethod
6 from collections import defaultdict
7 from enum import Enum, EnumMeta, unique
8 from functools import lru_cache
9 from typing import (TYPE_CHECKING, AbstractSet, Generic, Iterable, Sequence,
10 TypeVar, cast)
11
12 from cached_property import cached_property
13 from nmutil.plain_data import fields, plain_data
14
if not TYPE_CHECKING:
    def final(v):
        """Runtime stand-in for `typing_extensions.final`.

        Returns `v` unchanged, so decorating a class or method with it has
        no effect at runtime.
        """
        return v
else:
    # type checkers see the real decorator
    from typing_extensions import final
20
21
class ABCEnumMeta(EnumMeta, ABCMeta):
    """Metaclass combining `EnumMeta` and `ABCMeta`.

    Lets a class derive from both `Enum` and a class whose metaclass is
    `ABCMeta` without a metaclass conflict.
    """
24
25
class RegLoc(metaclass=ABCMeta):
    """Abstract base class for register locations."""
    __slots__ = ()

    @abstractmethod
    def conflicts(self, other):
        # type: (RegLoc) -> bool
        """Return whether `self` conflicts with `other`.

        Must be implemented by every concrete subclass.
        """

    def get_subreg_at_offset(self, subreg_type, offset):
        # type: (RegType, int) -> RegLoc
        """Return the sub-register of type `subreg_type` at `offset`.

        This default implementation only supports the trivial case where
        the sub-register is the whole register: `self` must be a member of
        `subreg_type.reg_class` and `offset` must be zero.

        Raises:
            ValueError: if `self` is not in `subreg_type.reg_class`, or if
                `offset` is non-zero.
        """
        is_member = self in subreg_type.reg_class
        if not is_member:
            raise ValueError("register not a member of subreg_type: "
                             f"reg={self} subreg_type={subreg_type}")
        if offset != 0:
            raise ValueError("non-zero sub-register offset not supported "
                             f"for register: {self}")
        return self
43
44
# number of GPRs: `GPRRange` rejects any range with `start + length` above
# this, and `GPRRangeType` rejects any length above it
GPR_COUNT = 128
46
47
@plain_data(frozen=True, unsafe_hash=True)
@final
class GPRRange(RegLoc, Sequence["GPRRange"]):
    """A contiguous range of GPRs, `start` up to but not including `stop`.

    Also usable as a `Sequence`: indexing or slicing yields the `GPRRange`
    covering the selected registers.
    """
    __slots__ = "start", "length"

    def __init__(self, start, length=None):
        # type: (int | range, int | None) -> None
        """Create a range from `start` and `length`, or from a `range`.

        Args:
            start: index of the first GPR, or a `range` with a step of 1
                giving both the start and the length.
            length: number of GPRs; defaults to 1. Must be omitted when
                `start` is a `range`.

        Raises:
            TypeError: if `length` is given together with a `range`.
            ValueError: if the `range` has a step other than 1, or the
                resulting range is empty or not within `0..GPR_COUNT`.
        """
        if isinstance(start, range):
            if length is not None:
                raise TypeError("can't specify length when input is a range")
            if start.step != 1:
                raise ValueError("range must have a step of 1")
            length = len(start)
            start = start.start
        elif length is None:
            length = 1
        if length <= 0 or start < 0 or start + length > GPR_COUNT:
            raise ValueError("invalid GPRRange")
        self.start = start
        self.length = length

    @property
    def stop(self):
        """Index one past the last GPR in this range."""
        return self.start + self.length

    @property
    def step(self):
        """Always 1; provided for symmetry with `range`."""
        return 1

    @property
    def range(self):
        """The equivalent builtin `range` of GPR indexes."""
        return range(self.start, self.stop, self.step)

    def __len__(self):
        return self.length

    def __getitem__(self, item):
        # type: (int | slice) -> GPRRange
        # indexing the builtin range yields an int or a range, both of
        # which the constructor accepts as `start`
        return GPRRange(self.range[item])

    def __contains__(self, value):
        # type: (GPRRange) -> bool
        """Return whether `value` lies entirely within this range."""
        return value.start >= self.start and value.stop <= self.stop

    def index(self, sub, start=None, end=None):
        # type: (GPRRange, int | None, int | None) -> int
        """Return the offset of `sub` from the start of this range.

        Raises:
            ValueError: if `sub` is not entirely within
                `self.range[start:end]`.
        """
        r = self.range[start:end]
        if sub.start < r.start or sub.stop > r.stop:
            raise ValueError("GPR range not found")
        return sub.start - self.start

    def count(self, sub, start=None, end=None):
        # type: (GPRRange, int | None, int | None) -> int
        """Return 1 if `sub` lies within `self.range[start:end]`, else 0."""
        r = self.range[start:end]
        if len(r) == 0:
            # an empty slice can't be turned into a GPRRange
            return 0
        return int(sub in GPRRange(r))

    def conflicts(self, other):
        # type: (RegLoc) -> bool
        """Return whether `other` is a `GPRRange` overlapping this one."""
        if isinstance(other, GPRRange):
            return self.stop > other.start and other.stop > self.start
        return False

    def get_subreg_at_offset(self, subreg_type, offset):
        # type: (RegType, int) -> GPRRange
        """Return the sub-range of type `subreg_type` at `offset`.

        `offset` is relative to `self.start`.

        Raises:
            ValueError: if `subreg_type` is not a `GPRRangeType`, or the
                requested sub-range does not fit inside this range.
        """
        if not isinstance(subreg_type, GPRRangeType):
            raise ValueError(f"subreg_type is not a "
                             f"GPRRangeType: {subreg_type}")
        # `offset` is relative, so it has to be checked against the length
        # of this range rather than against the absolute `stop` index
        if offset < 0 or offset + subreg_type.length > self.length:
            raise ValueError(f"sub-register offset is out of range: {offset}")
        return GPRRange(self.start + offset, subreg_type.length)
120
121
# single-register ranges that are kept out of `GPRRangeType`'s register
# class: any candidate range containing one of them is skipped
# NOTE(review): presumably these are reserved by the ABI -- confirm
SPECIAL_GPRS = GPRRange(0), GPRRange(1), GPRRange(2), GPRRange(13)
123
124
@final
@unique
class XERBit(RegLoc, Enum, metaclass=ABCEnumMeta):
    """Register locations that are individual XER bits."""
    CY = "CY"

    def conflicts(self, other):
        # type: (RegLoc) -> bool
        """Return whether `other` is the very same XER bit."""
        return isinstance(other, XERBit) and self == other
135
136
@final
@unique
class GlobalMem(RegLoc, Enum, metaclass=ABCEnumMeta):
    """singleton representing all non-StackSlot memory -- treated as a single
    physical register for register allocation purposes.
    """
    GlobalMem = "GlobalMem"

    def conflicts(self, other):
        # type: (RegLoc) -> bool
        """Return whether `other` is this same `GlobalMem` member."""
        return isinstance(other, GlobalMem) and self == other
150
151
@final
class RegClass(AbstractSet[RegLoc]):
    """ an ordered set of registers.
    earlier registers are preferred by the register allocator.
    """

    def __init__(self, regs):
        # type: (Iterable[RegLoc]) -> None
        """Build the class from `regs`, dropping duplicates."""
        # a dict (with unused values) keeps the registers in insertion order
        self.__regs = {reg: None for reg in regs}  # type: dict[RegLoc, None]

    def __len__(self):
        return len(self.__regs)

    def __iter__(self):
        return iter(self.__regs)

    def __contains__(self, v):
        # type: (RegLoc) -> bool
        return v in self.__regs

    def __hash__(self):
        # hash as a set, using the helper inherited from the Set ABC
        return super()._hash()

    @lru_cache(maxsize=None, typed=True)
    def max_conflicts_with(self, other):
        # type: (RegClass | RegLoc) -> int
        """the largest number of registers in `self` that a single register
        from `other` can conflict with
        """
        if not isinstance(other, RegClass):
            # a single register: count the members of `self` it conflicts with
            return sum(map(other.conflicts, self))
        # a whole class: take the worst case over its registers
        return max(map(self.max_conflicts_with, other))
187
188
@plain_data(frozen=True, unsafe_hash=True)
class RegType(metaclass=ABCMeta):
    """Abstract base class for the types that an `SSAVal` can have."""
    __slots__ = ()

    @property
    @abstractmethod
    def reg_class(self):
        # type: () -> RegClass
        """The `RegClass` of registers usable for values of this type."""
        return ...
198
199
# type variable for the `RegType` of a value; parameterizes `SSAVal` and is
# the destination type parameter of `OpCopy`
_RegType = TypeVar("_RegType", bound=RegType)
201
202
@plain_data(frozen=True, eq=False)
class GPRRangeType(RegType):
    """Register type for a range of `length` consecutive GPRs.

    Instances compare equal, and hash the same, when their lengths match.
    """
    __slots__ = "length",

    def __init__(self, length):
        # type: (int) -> None
        """Create the type for ranges of `length` GPRs.

        Raises:
            ValueError: if `length` is not in `1..GPR_COUNT`.
        """
        if length < 1 or length > GPR_COUNT:
            raise ValueError("invalid length")
        self.length = length

    @staticmethod
    @lru_cache(maxsize=None)
    def __get_reg_class(length):
        # type: (int) -> RegClass
        """Build (once per `length`) the class of all allocatable ranges."""
        regs = []
        # a range with `start + length == GPR_COUNT` is still a valid
        # `GPRRange`, so the last candidate start is `GPR_COUNT - length`
        # inclusive
        for start in range(GPR_COUNT - length + 1):
            reg = GPRRange(start, length)
            # skip ranges that would include any of the special GPRs
            if any(i in reg for i in SPECIAL_GPRS):
                continue
            regs.append(reg)
        return RegClass(regs)

    @property
    def reg_class(self):
        # type: () -> RegClass
        """All `GPRRange`s of this length that avoid `SPECIAL_GPRS`."""
        return GPRRangeType.__get_reg_class(self.length)

    @final
    def __eq__(self, other):
        if isinstance(other, GPRRangeType):
            return self.length == other.length
        return False

    @final
    def __hash__(self):
        return hash(self.length)
239
240
@plain_data(frozen=True, eq=False)
@final
class GPRType(GPRRangeType):
    """A `GPRRangeType` whose length is always 1."""
    __slots__ = ()

    def __init__(self, length=1):
        """Create the type; `length` is accepted but must be 1.

        Raises:
            ValueError: if `length` is anything other than 1.
        """
        if length == 1:
            super().__init__(length=1)
            return
        raise ValueError("length must be 1")
250
251
@plain_data(frozen=True, unsafe_hash=True)
@final
class FixedGPRRangeType(GPRRangeType):
    """A `GPRRangeType` pinned to one specific `GPRRange`."""
    __slots__ = ("reg",)

    def __init__(self, reg):
        # type: (GPRRange) -> None
        """Create the type for exactly `reg`; the length is taken from it."""
        super().__init__(length=reg.length)
        self.reg = reg

    @property
    def reg_class(self):
        # type: () -> RegClass
        """A register class containing only `self.reg`."""
        only_reg = (self.reg,)
        return RegClass(only_reg)
266
267
@plain_data(frozen=True, unsafe_hash=True)
@final
class CYType(RegType):
    """Register type whose only register is `XERBit.CY`."""
    __slots__ = ()

    @property
    def reg_class(self):
        # type: () -> RegClass
        """A register class containing only `XERBit.CY`."""
        only_reg = (XERBit.CY,)
        return RegClass(only_reg)
277
278
@plain_data(frozen=True, unsafe_hash=True)
@final
class GlobalMemType(RegType):
    """Register type whose only register is `GlobalMem.GlobalMem`."""
    __slots__ = ()

    @property
    def reg_class(self):
        # type: () -> RegClass
        """A register class containing only `GlobalMem.GlobalMem`."""
        only_reg = (GlobalMem.GlobalMem,)
        return RegClass(only_reg)
288
289
@plain_data(frozen=True, unsafe_hash=True)
@final
class StackSlot(RegLoc):
    """A run of `length_in_slots` consecutive stack slots from `start_slot`."""
    __slots__ = "start_slot", "length_in_slots",

    def __init__(self, start_slot, length_in_slots):
        # type: (int, int) -> None
        """Create the stack slot run.

        Raises:
            ValueError: if `length_in_slots` is less than 1.
        """
        self.start_slot = start_slot
        if length_in_slots < 1:
            raise ValueError("invalid length_in_slots")
        self.length_in_slots = length_in_slots

    @property
    def stop_slot(self):
        """Index one past the last slot in this run."""
        return self.start_slot + self.length_in_slots

    def conflicts(self, other):
        # type: (RegLoc) -> bool
        """Return whether `other` is a `StackSlot` overlapping this one."""
        if isinstance(other, StackSlot):
            return (self.stop_slot > other.start_slot
                    and other.stop_slot > self.start_slot)
        return False

    def get_subreg_at_offset(self, subreg_type, offset):
        # type: (RegType, int) -> StackSlot
        """Return the sub-slot of type `subreg_type` at `offset`.

        `offset` is relative to `self.start_slot`.

        Raises:
            ValueError: if `subreg_type` is not a `StackSlotType`, or the
                requested sub-slot does not fit inside this one.
        """
        if not isinstance(subreg_type, StackSlotType):
            raise ValueError(f"subreg_type is not a "
                             f"StackSlotType: {subreg_type}")
        # `offset` is relative, so it has to be checked against the length
        # of this slot run rather than against the absolute `stop_slot`
        if offset < 0 \
                or offset + subreg_type.length_in_slots > self.length_in_slots:
            raise ValueError(f"sub-register offset is out of range: {offset}")
        return StackSlot(self.start_slot + offset, subreg_type.length_in_slots)
321
322
# bound used by `StackSlotType` when enumerating the candidate stack slots
# of its register class
STACK_SLOT_COUNT = 128
324
325
@plain_data(frozen=True, eq=False)
@final
class StackSlotType(RegType):
    """Register type for a run of `length_in_slots` consecutive stack slots.

    Instances compare equal, and hash the same, when their lengths match.
    """
    __slots__ = "length_in_slots",

    def __init__(self, length_in_slots=1):
        # type: (int) -> None
        """Create the type.

        Raises:
            ValueError: if `length_in_slots` is less than 1.
        """
        if length_in_slots < 1:
            raise ValueError("invalid length_in_slots")
        self.length_in_slots = length_in_slots

    @staticmethod
    @lru_cache(maxsize=None)
    def __get_reg_class(length_in_slots):
        # type: (int) -> RegClass
        """Build (once per length) the class of all candidate stack slots."""
        regs = []
        # a run ending exactly at STACK_SLOT_COUNT still fits, so the last
        # candidate start is `STACK_SLOT_COUNT - length_in_slots` inclusive
        for start in range(STACK_SLOT_COUNT - length_in_slots + 1):
            reg = StackSlot(start, length_in_slots)
            regs.append(reg)
        return RegClass(regs)

    @property
    def reg_class(self):
        # type: () -> RegClass
        """All `StackSlot`s of this length within `STACK_SLOT_COUNT`."""
        return StackSlotType.__get_reg_class(self.length_in_slots)

    @final
    def __eq__(self, other):
        if isinstance(other, StackSlotType):
            return self.length_in_slots == other.length_in_slots
        return False

    @final
    def __hash__(self):
        return hash(self.length_in_slots)
361
362
@plain_data(frozen=True, eq=False, repr=False)
@final
class SSAVal(Generic[_RegType]):
    """A value written by one output argument of one `Op`.

    Two `SSAVal`s are equal when they refer to the same `Op` object (by
    identity) and the same argument name; `ty` takes no part in equality
    or hashing.
    """
    __slots__ = ("op", "arg_name", "ty")

    def __init__(self, op, arg_name, ty):
        # type: (Op, str, _RegType) -> None
        # the Op that writes this SSAVal
        self.op = op

        # the name of the argument of self.op that writes this SSAVal
        self.arg_name = arg_name

        self.ty = ty

    def __eq__(self, rhs):
        if not isinstance(rhs, SSAVal):
            return False
        return self.op is rhs.op and self.arg_name == rhs.arg_name

    def __hash__(self):
        # keyed on the identity of the op, matching `__eq__`
        key = (id(self.op), self.arg_name)
        return hash(key)

    def __repr__(self):
        parts = []
        for name in fields(self):
            value = getattr(self, name, None)
            if value is None:
                # leave out fields that are unset or None
                continue
            if name == "op":
                # show only the op's id rather than its full repr
                text = value.__repr__(just_id=True)
            else:
                text = repr(value)
            parts.append(f"{name}={text}")
        return "SSAVal(" + ", ".join(parts) + ")"
399
400
@final
@plain_data(unsafe_hash=True, frozen=True)
class EqualityConstraint:
    """A constraint that the values in `lhs` equal the values in `rhs`.

    Both sides must be non-empty.
    """
    __slots__ = ("lhs", "rhs")

    def __init__(self, lhs, rhs):
        # type: (list[SSAVal], list[SSAVal]) -> None
        """Store both sides of the constraint.

        Raises:
            ValueError: if `lhs` or `rhs` is empty.
        """
        self.lhs = lhs
        self.rhs = rhs
        if not (len(lhs) and len(rhs)):
            raise ValueError("can't constrain an empty list to be equal")
410 if len(lhs) == 0 or len(rhs) == 0:
411 raise ValueError("can't constrain an empty list to be equal")
412
413
class _NotSet:
    """ helper for __repr__ for when fields aren't set """

    def __repr__(self):
        placeholder = "<not set>"
        return placeholder


# shared instance used as the `getattr` default when building reprs
_NOT_SET = _NotSet()
422
423
@plain_data(unsafe_hash=True, frozen=True, repr=False)
class Op(metaclass=ABCMeta):
    """Abstract base class for the operations of the IR."""
    __slots__ = ()

    @abstractmethod
    def inputs(self):
        # type: () -> dict[str, SSAVal]
        """Return the values this op reads, keyed by argument name."""

    @abstractmethod
    def outputs(self):
        # type: () -> dict[str, SSAVal]
        """Return the values this op writes, keyed by argument name."""

    def get_equality_constraints(self):
        # type: () -> Iterable[EqualityConstraint]
        """Yield this op's equality constraints; none by default."""
        yield from ()

    def get_extra_interferences(self):
        # type: () -> Iterable[tuple[SSAVal, SSAVal]]
        """Yield extra pairs of interfering values; none by default."""
        yield from ()

    # next value handed out by `id`; shared by every Op
    __NEXT_ID = 0

    @cached_property
    def id(self):
        """Number assigned to this op the first time `id` is read."""
        assigned = Op.__NEXT_ID
        Op.__NEXT_ID = assigned + 1
        return assigned

    @final
    def __repr__(self, just_id=False):
        """Return `ClassName(#id, field=value, ...)`.

        With `just_id=True` only the id is shown.
        """
        parts = [f"#{self.id}"]
        if not just_id:
            # unset fields are shown with the `<not set>` placeholder
            parts.extend(f"{name}={getattr(self, name, _NOT_SET)!r}"
                         for name in fields(self))
        return f"{self.__class__.__name__}({', '.join(parts)})"
465
466
@plain_data(unsafe_hash=True, frozen=True, repr=False)
@final
class OpLoadFromStackSlot(Op):
    """Op that reads a stack-slot value `src` and writes a GPR range `dest`.

    `dest` has one GPR per slot of `src`.
    """
    __slots__ = "dest", "src"

    def inputs(self):
        # type: () -> dict[str, SSAVal]
        return {"src": self.src}

    def outputs(self):
        # type: () -> dict[str, SSAVal]
        return {"dest": self.dest}

    def __init__(self, src):
        # type: (SSAVal[StackSlotType]) -> None
        # a load goes from the stack into registers, so the source is the
        # stack slot and the destination is a GPR range of the same size
        self.dest = SSAVal(self, "dest", GPRRangeType(src.ty.length_in_slots))
        self.src = src
484
485
@plain_data(unsafe_hash=True, frozen=True, repr=False)
@final
class OpStoreToStackSlot(Op):
    """Op that reads a GPR range `src` and writes a stack-slot value `dest`.

    `dest` has one slot per GPR of `src`.
    """
    __slots__ = "dest", "src"

    def inputs(self):
        # type: () -> dict[str, SSAVal]
        return {"src": self.src}

    def outputs(self):
        # type: () -> dict[str, SSAVal]
        return {"dest": self.dest}

    def __init__(self, src):
        # type: (SSAVal[GPRRangeType]) -> None
        # a store goes from registers onto the stack, so the source is the
        # GPR range and the destination is a stack slot of the same size
        self.dest = SSAVal(self, "dest", StackSlotType(src.ty.length))
        self.src = src
503
504
# type variable for the `RegType` of the source value of an `OpCopy`
_RegSrcType = TypeVar("_RegSrcType", bound=RegType)
506
507
@plain_data(unsafe_hash=True, frozen=True, repr=False)
@final
class OpCopy(Op, Generic[_RegSrcType, _RegType]):
    """Op that copies `src` into a new value `dest`."""
    __slots__ = ("dest", "src")

    def inputs(self):
        # type: () -> dict[str, SSAVal]
        return dict(src=self.src)

    def outputs(self):
        # type: () -> dict[str, SSAVal]
        return dict(dest=self.dest)

    def __init__(self, src, dest_ty=None):
        # type: (SSAVal[_RegSrcType], _RegType | None) -> None
        """Create a copy of `src` whose result has type `dest_ty`.

        `dest_ty` defaults to the type of `src`. Two GPR range types are
        compatible when their lengths match; any other pair of types must
        compare equal.

        Raises:
            ValueError: if the source and destination types are
                incompatible.
        """
        if dest_ty is None:
            dest_ty = cast(_RegType, src.ty)
        both_gpr_ranges = (isinstance(src.ty, GPRRangeType)
                           and isinstance(dest_ty, GPRRangeType))
        if both_gpr_ranges:
            incompatible = src.ty.length != dest_ty.length
        else:
            incompatible = src.ty != dest_ty
        if incompatible:
            raise ValueError(f"incompatible source and destination "
                             f"types: {src.ty} and {dest_ty}")

        self.dest = SSAVal(self, "dest", dest_ty)  # type: SSAVal[_RegType]
        self.src = src
536
537
@plain_data(unsafe_hash=True, frozen=True, repr=False)
@final
class OpConcat(Op):
    """Op whose `dest` is a GPR range as long as all `sources` together."""
    __slots__ = ("dest", "sources")

    def inputs(self):
        # type: () -> dict[str, SSAVal]
        named = {}  # type: dict[str, SSAVal]
        for index, source in enumerate(self.sources):
            named[f"sources[{index}]"] = source
        return named

    def outputs(self):
        # type: () -> dict[str, SSAVal]
        return dict(dest=self.dest)

    def __init__(self, sources):
        # type: (Iterable[SSAVal[GPRRangeType]]) -> None
        """Create the op; `sources` is stored as a tuple."""
        sources = tuple(sources)
        total_length = sum(source.ty.length for source in sources)
        self.dest = SSAVal(self, "dest", GPRRangeType(total_length))
        self.sources = sources

    def get_equality_constraints(self):
        # type: () -> Iterable[EqualityConstraint]
        """Yield one constraint tying `dest` to all of the sources."""
        yield EqualityConstraint([self.dest], list(self.sources))
561
562
@plain_data(unsafe_hash=True, frozen=True, repr=False)
@final
class OpSplit(Op):
    """Op that splits the GPR range `src` into consecutive `results`."""
    __slots__ = ("results", "src")

    def inputs(self):
        # type: () -> dict[str, SSAVal]
        return dict(src=self.src)

    def outputs(self):
        # type: () -> dict[str, SSAVal]
        return {result.arg_name: result for result in self.results}

    def __init__(self, src, split_indexes):
        # type: (SSAVal[GPRRangeType], Iterable[int]) -> None
        """Split `src` at each index in `split_indexes`.

        Each piece runs from the previous split index (or 0) up to the
        next one; the final piece runs to the end of `src`.

        Raises:
            ValueError: if a split index is not strictly between 0 and the
                length of `src`, or if a piece would have an invalid
                length (raised by `GPRRangeType`).
        """
        total = src.ty.length
        pieces = []  # type: list[GPRRangeType]
        previous = 0
        for index in split_indexes:
            if index <= 0 or index >= total:
                raise ValueError(f"invalid split index: {index}, must be in "
                                 f"0 < i < {total}")
            pieces.append(GPRRangeType(index - previous))
            previous = index
        # whatever is left after the last split index
        pieces.append(GPRRangeType(total - previous))
        self.src = src
        self.results = tuple(
            SSAVal(self, f"results{position}", piece)
            for position, piece in enumerate(pieces))

    def get_equality_constraints(self):
        # type: () -> Iterable[EqualityConstraint]
        """Yield one constraint tying all of the results to `src`."""
        yield EqualityConstraint(list(self.results), [self.src])
594
595
@plain_data(unsafe_hash=True, frozen=True, repr=False)
@final
class OpAddSubE(Op):
    """Op reading `RA`, `RB` and `CY_in`, and writing `RT` and `CY_out`.

    `is_sub` selects between the add and subtract forms.
    """
    __slots__ = ("RT", "RA", "RB", "CY_in", "CY_out", "is_sub")

    def inputs(self):
        # type: () -> dict[str, SSAVal]
        return dict(RA=self.RA, RB=self.RB, CY_in=self.CY_in)

    def outputs(self):
        # type: () -> dict[str, SSAVal]
        return dict(RT=self.RT, CY_out=self.CY_out)

    def __init__(self, RA, RB, CY_in, is_sub):
        # type: (SSAVal[GPRRangeType], SSAVal[GPRRangeType], SSAVal[CYType], bool) -> None
        """Create the op; `RT` gets the type of `RA`, `CY_out` that of `CY_in`.

        Raises:
            TypeError: if `RA` and `RB` have different types.
        """
        if RA.ty != RB.ty:
            raise TypeError(f"source types must match: "
                            f"{RA} doesn't match {RB}")
        self.RT = SSAVal(self, "RT", RA.ty)
        self.RA = RA
        self.RB = RB
        self.CY_in = CY_in
        self.CY_out = SSAVal(self, "CY_out", CY_in.ty)
        self.is_sub = is_sub

    def get_extra_interferences(self):
        # type: () -> Iterable[tuple[SSAVal, SSAVal]]
        """Yield `RT` paired with each of the two GPR sources."""
        for source in (self.RA, self.RB):
            yield self.RT, source
625
626
@plain_data(unsafe_hash=True, frozen=True, repr=False)
@final
class OpBigIntMulDiv(Op):
    """Op reading `RA`, `RB` and `RC`, and writing `RT` and `RS`.

    `is_div` selects between the multiply and divide forms.
    """
    __slots__ = ("RT", "RA", "RB", "RC", "RS", "is_div")

    def inputs(self):
        # type: () -> dict[str, SSAVal]
        return dict(RA=self.RA, RB=self.RB, RC=self.RC)

    def outputs(self):
        # type: () -> dict[str, SSAVal]
        return dict(RT=self.RT, RS=self.RS)

    def __init__(self, RA, RB, RC, is_div):
        # type: (SSAVal[GPRRangeType], SSAVal[GPRType], SSAVal[GPRType], bool) -> None
        """Create the op; `RT` gets the type of `RA`, `RS` that of `RC`."""
        self.RT = SSAVal(self, "RT", RA.ty)
        self.RA = RA
        self.RB = RB
        self.RC = RC
        self.RS = SSAVal(self, "RS", RC.ty)
        self.is_div = is_div

    def get_equality_constraints(self):
        # type: () -> Iterable[EqualityConstraint]
        """Yield the constraint tying input `RC` to output `RS`."""
        yield EqualityConstraint([self.RC], [self.RS])

    def get_extra_interferences(self):
        # type: () -> Iterable[tuple[SSAVal, SSAVal]]
        """Yield the pairs of values declared to interfere.

        `RT` interferes with every input and with `RS`; `RS` also
        interferes with `RA` and `RB`.
        """
        for other in (self.RA, self.RB, self.RC, self.RS):
            yield self.RT, other
        for other in (self.RA, self.RB):
            yield self.RS, other
661
662
@final
@unique
class ShiftKind(Enum):
    """The kinds of shift that an `OpBigIntShift` can perform."""
    # NOTE(review): presumably shift-left, shift-right and arithmetic
    # shift-right respectively -- confirm
    Sl = "sl"
    Sr = "sr"
    Sra = "sra"
669
670
@plain_data(unsafe_hash=True, frozen=True, repr=False)
@final
class OpBigIntShift(Op):
    """Op reading `inp` and the shift amount `sh`, and writing `RT`.

    `kind` is the `ShiftKind` to perform.
    """
    __slots__ = ("RT", "inp", "sh", "kind")

    def inputs(self):
        # type: () -> dict[str, SSAVal]
        return dict(inp=self.inp, sh=self.sh)

    def outputs(self):
        # type: () -> dict[str, SSAVal]
        return dict(RT=self.RT)

    def __init__(self, inp, sh, kind):
        # type: (SSAVal[GPRRangeType], SSAVal[GPRType], ShiftKind) -> None
        """Create the op; `RT` gets the same type as `inp`."""
        self.RT = SSAVal(self, "RT", inp.ty)
        self.inp = inp
        self.sh = sh
        self.kind = kind

    def get_extra_interferences(self):
        # type: () -> Iterable[tuple[SSAVal, SSAVal]]
        """Yield `RT` paired with each of the two inputs."""
        for source in (self.inp, self.sh):
            yield self.RT, source
695
696
@plain_data(unsafe_hash=True, frozen=True, repr=False)
@final
class OpLI(Op):
    """Op with no inputs that writes `out`, carrying the constant `value`."""
    __slots__ = ("out", "value")

    def inputs(self):
        # type: () -> dict[str, SSAVal]
        return dict()

    def outputs(self):
        # type: () -> dict[str, SSAVal]
        return dict(out=self.out)

    def __init__(self, value, length=1):
        # type: (int, int) -> None
        """Create the op; `out` is a GPR range of `length` registers."""
        out_ty = GPRRangeType(length)
        self.out = SSAVal(self, "out", out_ty)
        self.value = value
714
715
@plain_data(unsafe_hash=True, frozen=True, repr=False)
@final
class OpClearCY(Op):
    """Op with no inputs that writes `out`, a value of type `CYType`."""
    __slots__ = ("out",)

    def inputs(self):
        # type: () -> dict[str, SSAVal]
        return dict()

    def outputs(self):
        # type: () -> dict[str, SSAVal]
        return dict(out=self.out)

    def __init__(self):
        # type: () -> None
        out_ty = CYType()
        self.out = SSAVal(self, "out", out_ty)
732
733
@plain_data(unsafe_hash=True, frozen=True, repr=False)
@final
class OpLoad(Op):
    """Op reading `RA` and `mem`, and writing the GPR range `RT`.

    `offset` is stored alongside as a plain integer.
    """
    __slots__ = ("RT", "RA", "offset", "mem")

    def inputs(self):
        # type: () -> dict[str, SSAVal]
        return dict(RA=self.RA, mem=self.mem)

    def outputs(self):
        # type: () -> dict[str, SSAVal]
        return dict(RT=self.RT)

    def __init__(self, RA, offset, mem, length=1):
        # type: (SSAVal[GPRType], int, SSAVal[GlobalMemType], int) -> None
        """Create the op; `RT` is a GPR range of `length` registers."""
        rt_ty = GPRRangeType(length)
        self.RT = SSAVal(self, "RT", rt_ty)
        self.RA = RA
        self.offset = offset
        self.mem = mem

    def get_extra_interferences(self):
        # type: () -> Iterable[tuple[SSAVal, SSAVal]]
        """Yield `(RT, RA)` when `RT` spans more than one register."""
        if self.RT.ty.length <= 1:
            return
        yield self.RT, self.RA
758
759
@plain_data(unsafe_hash=True, frozen=True, repr=False)
@final
class OpStore(Op):
    """Op reading `RS`, `RA` and `mem_in`, and writing `mem_out`.

    `offset` is stored alongside as a plain integer.
    """
    __slots__ = ("RS", "RA", "offset", "mem_in", "mem_out")

    def inputs(self):
        # type: () -> dict[str, SSAVal]
        return dict(RS=self.RS, RA=self.RA, mem_in=self.mem_in)

    def outputs(self):
        # type: () -> dict[str, SSAVal]
        return dict(mem_out=self.mem_out)

    def __init__(self, RS, RA, offset, mem_in):
        # type: (SSAVal[GPRRangeType], SSAVal[GPRType], int, SSAVal[GlobalMemType]) -> None
        """Create the op; `mem_out` gets the same type as `mem_in`."""
        self.RS = RS
        self.RA = RA
        self.offset = offset
        self.mem_in = mem_in
        # the memory state after the store is a new value of the same type
        self.mem_out = SSAVal(self, "mem_out", mem_in.ty)
780
781
@plain_data(unsafe_hash=True, frozen=True, repr=False)
@final
class OpFuncArg(Op):
    """Op with no inputs that writes `out`, a value of a fixed-GPR type."""
    __slots__ = ("out",)

    def inputs(self):
        # type: () -> dict[str, SSAVal]
        return dict()

    def outputs(self):
        # type: () -> dict[str, SSAVal]
        return dict(out=self.out)

    def __init__(self, ty):
        # type: (FixedGPRRangeType) -> None
        """Create the op; `out` has type `ty`."""
        self.out = SSAVal(self, "out", ty)
798
799
@plain_data(unsafe_hash=True, frozen=True, repr=False)
@final
class OpInputMem(Op):
    """Op with no inputs that writes `out`, a value of `GlobalMemType`."""
    __slots__ = ("out",)

    def inputs(self):
        # type: () -> dict[str, SSAVal]
        return dict()

    def outputs(self):
        # type: () -> dict[str, SSAVal]
        return dict(out=self.out)

    def __init__(self):
        # type: () -> None
        out_ty = GlobalMemType()
        self.out = SSAVal(self, "out", out_ty)
816
817
def op_set_to_list(ops):
    # type: (Iterable[Op]) -> list[Op]
    """Return `ops` ordered so each op follows the ops writing its inputs.

    Ops are kept in worklists indexed by how many of their distinct input
    values have not been written yet. Ops in worklist 0 are ready: they
    are emitted one at a time, and each value they write moves its readers
    down one worklist.

    Args:
        ops: the ops to order; each must provide `inputs()` and
            `outputs()`.

    Returns:
        A list containing every op exactly once, in dependency order.

    Raises:
        ValueError: if two ops write the same SSA value, or if some op can
            never become ready (it is part of a dependency loop or reads
            a value that no op writes).
    """
    worklists = [{}]  # type: list[dict[Op, None]]
    inps_to_ops_map = defaultdict(dict)  # type: dict[SSAVal, dict[Op, None]]
    ops_to_pending_input_count_map = {}  # type: dict[Op, int]
    for op in ops:
        # count each distinct input value once: an op reading the same
        # value through several arguments is registered only once as a
        # reader of it, so its pending count drops just once when that
        # value becomes ready
        unique_inputs = dict.fromkeys(op.inputs().values())
        input_count = len(unique_inputs)
        for val in unique_inputs:
            inps_to_ops_map[val][op] = None
        while len(worklists) <= input_count:
            worklists.append({})
        ops_to_pending_input_count_map[op] = input_count
        worklists[input_count][op] = None
    retval = []  # type: list[Op]
    ready_vals = set()  # type: set[SSAVal]
    while len(worklists[0]) != 0:
        # take the oldest ready op
        writing_op = next(iter(worklists[0]))
        del worklists[0][writing_op]
        retval.append(writing_op)
        for val in writing_op.outputs().values():
            if val in ready_vals:
                raise ValueError(f"multiple instructions must not write "
                                 f"to the same SSA value: {val}")
            ready_vals.add(val)
            # every reader of `val` now has one fewer pending input
            for reading_op in inps_to_ops_map[val]:
                pending = ops_to_pending_input_count_map[reading_op]
                del worklists[pending][reading_op]
                pending -= 1
                worklists[pending][reading_op] = None
                ops_to_pending_input_count_map[reading_op] = pending
    # anything still in a worklist could never become ready
    for worklist in worklists:
        for op in worklist:
            raise ValueError(f"instruction is part of a dependency loop or "
                             f"its inputs are never written: {op}")
    return retval