remove plain_data Generic workaround
[bigint-presentation-code.git] / src / bigint_presentation_code / toom_cook.py
1 """
2 Toom-Cook algorithm generator for SVP64
3
4 the register allocator uses an algorithm based on:
5 [Retargetable Graph-Coloring Register Allocation for Irregular Architectures](https://user.it.uu.se/~svenolof/wpo/AllocSCOPES2003.20030626b.pdf)
6 """
7
8 from abc import ABCMeta, abstractmethod
9 from collections import defaultdict
10 from enum import Enum, unique
11 from functools import lru_cache
12 from typing import (Sequence, AbstractSet, Iterable, Mapping,
13 TYPE_CHECKING, Sequence, TypeVar, Generic)
14
15 from nmutil.plain_data import plain_data
16
if not TYPE_CHECKING:
    def final(v):
        """runtime stand-in for `typing_extensions.final`: returns `v`
        unchanged."""
        return v
else:
    # only type checkers take this branch, so typing_extensions is never
    # imported when the module actually runs.
    from typing_extensions import final, Self
22
23
@plain_data(frozen=True, unsafe_hash=True)
class PhysLoc(metaclass=ABCMeta):
    """common abstract base class of `RegLoc` and `GPRRangeOrStackLoc`.

    declares no fields of its own.
    """
    __slots__ = ()
27
28
@plain_data(frozen=True, unsafe_hash=True)
class RegLoc(PhysLoc):
    """abstract `PhysLoc` that can be tested for conflicts against another
    `RegLoc`."""
    __slots__ = ()

    @abstractmethod
    def conflicts(self, other):
        # type: (RegLoc) -> bool
        """return True if `self` conflicts with `other`.

        must be implemented by every concrete subclass.
        """
37
38
@plain_data(frozen=True, unsafe_hash=True)
class GPRRangeOrStackLoc(PhysLoc):
    """abstract `PhysLoc` that has a length; the base class of `GPRRange`
    and `StackSlot`."""
    __slots__ = ()

    @abstractmethod
    def __len__(self):
        # type: () -> int
        """return the length of this location.

        must be implemented by every concrete subclass.
        """
47
48
# number of GPRs: `GPRRange` rejects any range that extends past this, and
# `GPRRangeType` rejects any length larger than this.
GPR_COUNT = 128
50
51
@plain_data(frozen=True, unsafe_hash=True)
@final
class GPRRange(RegLoc, GPRRangeOrStackLoc, Sequence["GPRRange"]):
    """a contiguous range of GPRs: `length` registers starting at `start`.

    also acts as a sequence whose items are `GPRRange`s: indexing with an
    `int` gives the length-1 range for that register, slicing gives the
    corresponding sub-range.
    """
    __slots__ = "start", "length"

    def __init__(self, start, length=None):
        # type: (int | range, int | None) -> None
        """create a `GPRRange` from a start and optional length (default 1),
        or from a `range` with a step of 1.

        raises `TypeError` if `length` is given together with a `range`.
        raises `ValueError` if the range's step isn't 1, the range is empty,
        or it doesn't fit in `0 <= start < start + length <= GPR_COUNT`.
        """
        if isinstance(start, range):
            if length is not None:
                raise TypeError("can't specify length when input is a range")
            if start.step != 1:
                raise ValueError("range must have a step of 1")
            # the right-hand side is evaluated before `start` is rebound
            start, length = start.start, len(start)
        elif length is None:
            length = 1
        in_bounds = length > 0 and start >= 0 and start + length <= GPR_COUNT
        if not in_bounds:
            raise ValueError("invalid GPRRange")
        self.start = start
        self.length = length

    @property
    def stop(self):
        """one past the last register in this range."""
        return self.start + self.length

    @property
    def step(self):
        """always 1 -- provided so this looks like a `range`."""
        return 1

    @property
    def range(self):
        """the equivalent builtin `range`."""
        return range(self.start, self.stop, self.step)

    def __len__(self):
        return self.length

    def __getitem__(self, item):
        # type: (int | slice) -> GPRRange
        # indexing the builtin range yields a register number (for an int)
        # or a sub-range (for a slice); both are valid constructor inputs.
        selected = self.range[item]
        return GPRRange(selected)

    def __contains__(self, value):
        # type: (GPRRange) -> bool
        """True if `value` lies entirely inside this range."""
        starts_inside = value.start >= self.start
        return starts_inside and value.stop <= self.stop

    def index(self, sub, start=None, end=None):
        # type: (GPRRange, int | None, int | None) -> int
        """return the offset of `sub` from the start of this range.

        raises `ValueError` if `sub` isn't entirely inside
        `self.range[start:end]`.
        """
        window = self.range[start:end]
        fits = sub.start >= window.start and sub.stop <= window.stop
        if not fits:
            raise ValueError("GPR range not found")
        return sub.start - self.start

    def count(self, sub, start=None, end=None):
        # type: (GPRRange, int | None, int | None) -> int
        """return 1 if `sub` is entirely inside `self.range[start:end]`,
        otherwise 0."""
        window = self.range[start:end]
        if not window:
            # an empty window can't be turned into a GPRRange
            return 0
        return 1 if sub in GPRRange(window) else 0

    def conflicts(self, other):
        # type: (RegLoc) -> bool
        """True if `other` is a `GPRRange` that overlaps this one."""
        if not isinstance(other, GPRRange):
            return False
        return self.stop > other.start and other.stop > self.start
115
116
# GPRs 0, 1, 2 and 13, each as a length-1 `GPRRange`. `GPRRangeType` leaves
# every range containing one of these out of its register class.
SPECIAL_GPRS = GPRRange(0), GPRRange(1), GPRRange(2), GPRRange(13)
118
119
class _ABCEnumMeta(type(Enum), ABCMeta):
    """metaclass for enums that also derive from an `ABCMeta`-based class.

    `RegLoc` uses `ABCMeta` and `Enum` uses its own metaclass. Python
    requires a class's metaclass to be a subclass of the metaclasses of all
    its bases, so mixing the two needs this common subclass.
    """


# NOTE(review): `RegLoc` is decorated with frozen `plain_data` -- confirm that
# its generated methods don't interfere with enum member construction.
@final
@unique
class XERBit(RegLoc, Enum, metaclass=_ABCEnumMeta):
    """bits of XER that are treated as register locations."""
    # `Enum` has to be the last base class, mixins go before it.
    CY = "CY"

    def conflicts(self, other):
        # type: (RegLoc) -> bool
        """True if `other` is the same `XERBit` as `self`."""
        if isinstance(other, XERBit):
            return self == other
        return False


@final
@unique
class GlobalMem(RegLoc, Enum, metaclass=_ABCEnumMeta):
    """singleton representing all non-StackSlot memory -- treated as a single
    physical register for register allocation purposes.
    """
    # `Enum` has to be the last base class, mixins go before it.
    GlobalMem = "GlobalMem"

    def conflicts(self, other):
        # type: (RegLoc) -> bool
        """True if `other` is the same `GlobalMem` as `self`."""
        if isinstance(other, GlobalMem):
            return self == other
        return False
145
146
@final
class RegClass(AbstractSet[RegLoc]):
    """an immutable, hashable set of `RegLoc`s."""

    def __init__(self, regs):
        # type: (Iterable[RegLoc]) -> None
        """store the register locations in `regs` as a `frozenset`."""
        self.__regs = frozenset(regs)

    def __contains__(self, v):
        # type: (RegLoc) -> bool
        return v in self.__regs

    def __iter__(self):
        return iter(self.__regs)

    def __len__(self):
        return len(self.__regs)

    def __hash__(self):
        # the abstract set base class provides `_hash()` but no `__hash__`
        return super()._hash()
165
166
@plain_data(frozen=True, unsafe_hash=True)
class RegType(metaclass=ABCMeta):
    """abstract base class for the types of `SSAVal`s.

    each concrete type supplies the `RegClass` of locations it allows.
    """
    __slots__ = ()

    @property
    @abstractmethod
    def reg_class(self):
        # type: () -> RegClass
        """the `RegClass` for this type -- must be overridden."""
        return ...
176
177
@plain_data(frozen=True, eq=False)
class GPRRangeType(RegType):
    """type of values that occupy `length` consecutive GPRs.

    two `GPRRangeType`s (including subclasses) are equal when their lengths
    are equal.
    """
    __slots__ = "length",

    def __init__(self, length):
        # type: (int) -> None
        """raises `ValueError` unless `1 <= length <= GPR_COUNT`."""
        if length < 1 or length > GPR_COUNT:
            raise ValueError("invalid length")
        self.length = length

    @staticmethod
    @lru_cache()
    def __get_reg_class(length):
        # type: (int) -> RegClass
        """build (and cache per `length`) the register class containing every
        `GPRRange` of that length that contains none of `SPECIAL_GPRS`."""
        regs = []
        # the last valid start is `GPR_COUNT - length` (the range then ends
        # exactly at the last GPR), hence the `+ 1`.
        for start in range(GPR_COUNT - length + 1):
            reg = GPRRange(start, length)
            if any(i in reg for i in SPECIAL_GPRS):
                continue
            regs.append(reg)
        return RegClass(regs)

    @property
    def reg_class(self):
        # type: () -> RegClass
        """the `RegClass` of all allowed `GPRRange`s for this length."""
        return GPRRangeType.__get_reg_class(self.length)

    @final
    def __eq__(self, other):
        if isinstance(other, GPRRangeType):
            return self.length == other.length
        return False

    @final
    def __hash__(self):
        return hash(self.length)
214
215
@plain_data(frozen=True, eq=False)
@final
class GPRType(GPRRangeType):
    """a `GPRRangeType` whose length is always 1."""
    __slots__ = ()

    def __init__(self, length=1):
        """`length` is accepted only so the signature matches the base
        class; raises `ValueError` for anything other than 1."""
        if length == 1:
            super().__init__(length=1)
            return
        raise ValueError("length must be 1")
225
226
@plain_data(frozen=True, unsafe_hash=True)
@final
class CYType(RegType):
    """type of values whose only allowed location is `XERBit.CY`."""
    __slots__ = ()

    @property
    def reg_class(self):
        # type: () -> RegClass
        """a newly built `RegClass` containing only `XERBit.CY`."""
        allowed = [XERBit.CY]
        return RegClass(allowed)
236
237
@plain_data(frozen=True, unsafe_hash=True)
@final
class GlobalMemType(RegType):
    """type of values whose only allowed location is
    `GlobalMem.GlobalMem`."""
    __slots__ = ()

    @property
    def reg_class(self):
        # type: () -> RegClass
        """a newly built `RegClass` containing only `GlobalMem.GlobalMem`."""
        allowed = [GlobalMem.GlobalMem]
        return RegClass(allowed)
247
248
@plain_data()
@final
class StackSlot(GPRRangeOrStackLoc):
    """a stack slot. Use OpCopy to load from/store into this stack slot."""
    __slots__ = "offset", "length"

    def __init__(self, offset=None, length=1):
        # type: (int | None, int) -> None
        """`offset` may be left as `None`; `length` must be at least 1,
        otherwise `ValueError` is raised."""
        self.offset = offset
        if not length >= 1:
            raise ValueError("invalid length")
        self.length = length

    def __len__(self):
        return self.length
264
265
# type variable for the `RegType` subclass that a generic class (such as
# `SSAVal` or `OpCopy`) is parameterized over.
_RegType = TypeVar("_RegType", bound=RegType)
267
268
@plain_data(frozen=True, eq=False)
@final
class SSAVal(Generic[_RegType]):
    """a SSA value: the output argument named `arg_name` of `op`, with type
    `ty`.

    two `SSAVal`s are equal when they refer to the same `Op` object and have
    the same `arg_name`; `ty` takes no part in equality or hashing.
    """
    # NOTE(review): `arg_index` is declared but never assigned by `__init__`
    # -- confirm whether it is still needed.
    __slots__ = "op", "arg_name", "ty", "arg_index"

    def __init__(self, op, arg_name, ty):
        # type: (Op, str, _RegType) -> None
        # the Op that writes this SSAVal
        self.op = op

        # the name of the argument of self.op that writes this SSAVal
        self.arg_name = arg_name

        self.ty = ty

    def __eq__(self, rhs):
        if not isinstance(rhs, SSAVal):
            return False
        # ops are compared by identity, not by value
        return self.op is rhs.op and self.arg_name == rhs.arg_name

    def __hash__(self):
        # consistent with `__eq__`: identity of the op plus the arg name
        key = id(self.op), self.arg_name
        return hash(key)
292
293
@final
@plain_data(unsafe_hash=True, frozen=True)
class EqualityConstraint:
    """constraint between two non-empty lists of `SSAVal`s, `lhs` and
    `rhs`, that have to be equal.

    `MergedRegSets` lays each side out back-to-back from offset 0 and
    merges the two sides.
    """
    __slots__ = "lhs", "rhs"

    def __init__(self, lhs, rhs):
        # type: (list[SSAVal], list[SSAVal]) -> None
        """raises `ValueError` if either list is empty."""
        self.lhs = lhs
        self.rhs = rhs
        if min(len(lhs), len(rhs)) == 0:
            raise ValueError("can't constrain an empty list to be equal")
305
306
@plain_data(unsafe_hash=True, frozen=True)
class Op(metaclass=ABCMeta):
    """abstract base class of all operations.

    an `Op` reads the `SSAVal`s returned by `inputs()` and writes the
    `SSAVal`s returned by `outputs()`.
    """
    __slots__ = ()

    def __init__(self):
        pass

    @abstractmethod
    def inputs(self):
        # type: () -> dict[str, SSAVal]
        """return the values this op reads, keyed by argument name."""

    @abstractmethod
    def outputs(self):
        # type: () -> dict[str, SSAVal]
        """return the values this op writes, keyed by argument name."""

    def get_equality_constraints(self):
        # type: () -> Iterable[EqualityConstraint]
        """yield the `EqualityConstraint`s of this op -- none by default."""
        # an empty generator: the `yield from` makes this a generator
        # function without ever producing an item.
        yield from ()
328
329
@plain_data(unsafe_hash=True, frozen=True)
@final
class OpCopy(Op, Generic[_RegType]):
    """op with one input `src` and one output `dest` of the same type as
    `src`."""
    __slots__ = "dest", "src"

    def __init__(self, src):
        # type: (SSAVal[_RegType]) -> None
        # the output gets the same type as the input
        self.dest = SSAVal(self, "dest", src.ty)
        self.src = src

    def inputs(self):
        # type: () -> dict[str, SSAVal]
        retval = {"src": self.src}
        return retval

    def outputs(self):
        # type: () -> dict[str, SSAVal]
        retval = {"dest": self.dest}
        return retval
347
348
@plain_data(unsafe_hash=True, frozen=True)
@final
class OpConcat(Op):
    """op whose output `dest` is a GPR range as long as all of `sources`
    together; `dest` is constrained to equal the sequence of `sources`."""
    __slots__ = "dest", "sources"

    def __init__(self, sources):
        # type: (Iterable[SSAVal[GPRRangeType]]) -> None
        # materialize first: `sources` is iterated more than once
        sources = tuple(sources)
        total_length = sum(i.ty.length for i in sources)
        self.dest = SSAVal(self, "dest", GPRRangeType(total_length))
        self.sources = sources

    def inputs(self):
        # type: () -> dict[str, SSAVal]
        retval = {}
        for i, v in enumerate(self.sources):
            retval[f"sources[{i}]"] = v
        return retval

    def outputs(self):
        # type: () -> dict[str, SSAVal]
        retval = {"dest": self.dest}
        return retval

    def get_equality_constraints(self):
        # type: () -> Iterable[EqualityConstraint]
        yield EqualityConstraint([self.dest], list(self.sources))
372
373
@plain_data(unsafe_hash=True, frozen=True)
@final
class OpSplit(Op):
    """op that splits the GPR range `src` at `split_indexes` into the
    consecutive pieces `results`; the sequence of `results` is constrained
    to equal `src`."""
    __slots__ = "results", "src"

    def __init__(self, src, split_indexes):
        # type: (SSAVal[GPRRangeType], Iterable[int]) -> None
        """`split_indexes` must be increasing and strictly between 0 and the
        length of `src`, otherwise `ValueError` is raised."""
        piece_types = []  # type: list[GPRRangeType]
        prev_index = 0
        for i in split_indexes:
            if i <= 0 or i >= src.ty.length:
                raise ValueError(f"invalid split index: {i}, must be in "
                                 f"0 < i < {src.ty.length}")
            # GPRRangeType rejects a non-positive length, so indexes that
            # aren't increasing are caught here
            piece_types.append(GPRRangeType(i - prev_index))
            prev_index = i
        # whatever is left after the last split index
        piece_types.append(GPRRangeType(src.ty.length - prev_index))
        self.src = src
        # NOTE(review): these names have no brackets, unlike OpConcat's
        # `sources[{i}]` -- confirm whether that's intended.
        self.results = tuple(
            SSAVal(self, f"results{i}", r)
            for i, r in enumerate(piece_types))

    def inputs(self):
        # type: () -> dict[str, SSAVal]
        retval = {"src": self.src}
        return retval

    def outputs(self):
        # type: () -> dict[str, SSAVal]
        retval = {}
        for i in self.results:
            retval[i.arg_name] = i
        return retval

    def get_equality_constraints(self):
        # type: () -> Iterable[EqualityConstraint]
        yield EqualityConstraint(list(self.results), [self.src])
405
406
@plain_data(unsafe_hash=True, frozen=True)
@final
class OpAddSubE(Op):
    """op with inputs `RA`, `RB` and `CY_in`, outputs `RT` and `CY_out`,
    and the flag `is_sub`.

    `RT` has the type of `RA`, `CY_out` has the type of `CY_in`.
    """
    __slots__ = "RT", "RA", "RB", "CY_in", "CY_out", "is_sub"

    def __init__(self, RA, RB, CY_in, is_sub):
        # type: (SSAVal[GPRRangeType], SSAVal[GPRRangeType], SSAVal[CYType], bool) -> None
        """raises `TypeError` if `RA` and `RB` have different types."""
        if RA.ty != RB.ty:
            raise TypeError(f"source types must match: "
                            f"{RA} doesn't match {RB}")
        self.RT = SSAVal(self, "RT", RA.ty)
        self.RA = RA
        self.RB = RB
        self.CY_in = CY_in
        self.CY_out = SSAVal(self, "CY_out", CY_in.ty)
        self.is_sub = is_sub

    def inputs(self):
        # type: () -> dict[str, SSAVal]
        retval = {"RA": self.RA, "RB": self.RB, "CY_in": self.CY_in}
        return retval

    def outputs(self):
        # type: () -> dict[str, SSAVal]
        retval = {"RT": self.RT, "CY_out": self.CY_out}
        return retval
431
432
@plain_data(unsafe_hash=True, frozen=True)
@final
class OpBigIntMulDiv(Op):
    """op with inputs `RA`, `RB` and `RC`, outputs `RT` and `RS`, and the
    flag `is_div`.

    `RT` has the type of `RA`, `RS` has the type of `RC`, and `RC` is
    constrained to equal `RS`.
    """
    __slots__ = "RT", "RA", "RB", "RC", "RS", "is_div"

    def __init__(self, RA, RB, RC, is_div):
        # type: (SSAVal[GPRRangeType], SSAVal[GPRType], SSAVal[GPRType], bool) -> None
        self.RT = SSAVal(self, "RT", RA.ty)
        self.RA = RA
        self.RB = RB
        self.RC = RC
        self.RS = SSAVal(self, "RS", RC.ty)
        self.is_div = is_div

    def inputs(self):
        # type: () -> dict[str, SSAVal]
        retval = {"RA": self.RA, "RB": self.RB, "RC": self.RC}
        return retval

    def outputs(self):
        # type: () -> dict[str, SSAVal]
        retval = {"RT": self.RT, "RS": self.RS}
        return retval

    def get_equality_constraints(self):
        # type: () -> Iterable[EqualityConstraint]
        constraint = EqualityConstraint([self.RC], [self.RS])
        yield constraint
458
459
@final
@unique
class ShiftKind(Enum):
    """the kind of shift that an `OpBigIntShift` is constructed with."""
    Sl = "sl"
    Sr = "sr"
    Sra = "sra"
466
467
@plain_data(unsafe_hash=True, frozen=True)
@final
class OpBigIntShift(Op):
    """op with inputs `inp` and `sh`, the output `RT` (same type as `inp`),
    and the `ShiftKind` `kind`."""
    __slots__ = "RT", "inp", "sh", "kind"

    def __init__(self, inp, sh, kind):
        # type: (SSAVal[GPRRangeType], SSAVal[GPRType], ShiftKind) -> None
        self.RT = SSAVal(self, "RT", inp.ty)
        self.inp = inp
        self.sh = sh
        self.kind = kind

    def inputs(self):
        # type: () -> dict[str, SSAVal]
        retval = {"inp": self.inp, "sh": self.sh}
        return retval

    def outputs(self):
        # type: () -> dict[str, SSAVal]
        retval = {"RT": self.RT}
        return retval
487
488
@plain_data(unsafe_hash=True, frozen=True)
@final
class OpLI(Op):
    """op with no inputs that holds the integer `value` and writes the
    output `out`, a GPR range of `length` registers."""
    __slots__ = "out", "value"

    def __init__(self, value, length=1):
        # type: (int, int) -> None
        out_type = GPRRangeType(length)
        self.out = SSAVal(self, "out", out_type)
        self.value = value

    def inputs(self):
        # type: () -> dict[str, SSAVal]
        return dict()

    def outputs(self):
        # type: () -> dict[str, SSAVal]
        retval = {"out": self.out}
        return retval
506
507
@plain_data(unsafe_hash=True, frozen=True)
@final
class OpClearCY(Op):
    """op with no inputs that writes the output `out` of type `CYType`."""
    __slots__ = "out",

    def __init__(self):
        # type: () -> None
        out_type = CYType()
        self.out = SSAVal(self, "out", out_type)

    def inputs(self):
        # type: () -> dict[str, SSAVal]
        return dict()

    def outputs(self):
        # type: () -> dict[str, SSAVal]
        retval = {"out": self.out}
        return retval
524
525
@plain_data(unsafe_hash=True, frozen=True)
@final
class OpLoad(Op):
    """op with inputs `RA` and `mem`, the integer `offset`, and the output
    `RT`, a GPR range of `length` registers."""
    __slots__ = "RT", "RA", "offset", "mem"

    def __init__(self, RA, offset, mem, length=1):
        # type: (SSAVal[GPRType], int, SSAVal[GlobalMemType], int) -> None
        out_type = GPRRangeType(length)
        self.RT = SSAVal(self, "RT", out_type)
        self.RA = RA
        self.offset = offset
        self.mem = mem

    def inputs(self):
        # type: () -> dict[str, SSAVal]
        retval = {"RA": self.RA, "mem": self.mem}
        return retval

    def outputs(self):
        # type: () -> dict[str, SSAVal]
        retval = {"RT": self.RT}
        return retval
545
546
@plain_data(unsafe_hash=True, frozen=True)
@final
class OpStore(Op):
    """op with inputs `RS`, `RA` and `mem_in`, the integer `offset`, and the
    output `mem_out` (same type as `mem_in`)."""
    __slots__ = "RS", "RA", "offset", "mem_in", "mem_out"

    def __init__(self, RS, RA, offset, mem_in):
        # type: (SSAVal[GPRRangeType], SSAVal[GPRType], int, SSAVal[GlobalMemType]) -> None
        self.RS = RS
        self.RA = RA
        self.offset = offset
        self.mem_in = mem_in
        # memory after this op is a new SSA value
        self.mem_out = SSAVal(self, "mem_out", mem_in.ty)

    def inputs(self):
        # type: () -> dict[str, SSAVal]
        retval = {"RS": self.RS, "RA": self.RA, "mem_in": self.mem_in}
        return retval

    def outputs(self):
        # type: () -> dict[str, SSAVal]
        retval = {"mem_out": self.mem_out}
        return retval
567
568
@plain_data(unsafe_hash=True, frozen=True)
@final
class OpFuncArg(Op):
    """op with no inputs that writes the output `out` of the given type."""
    __slots__ = "out",

    def __init__(self, ty):
        # type: (RegType) -> None
        self.out = SSAVal(self, "out", ty)

    def inputs(self):
        # type: () -> dict[str, SSAVal]
        return dict()

    def outputs(self):
        # type: () -> dict[str, SSAVal]
        retval = {"out": self.out}
        return retval
585
586
@plain_data(unsafe_hash=True, frozen=True)
@final
class OpInputMem(Op):
    """op with no inputs that writes the output `out` of type
    `GlobalMemType`."""
    __slots__ = "out",

    def __init__(self):
        # type: () -> None
        out_type = GlobalMemType()
        self.out = SSAVal(self, "out", out_type)

    def inputs(self):
        # type: () -> dict[str, SSAVal]
        return dict()

    def outputs(self):
        # type: () -> dict[str, SSAVal]
        retval = {"out": self.out}
        return retval
603
604
def op_set_to_list(ops):
    # type: (Iterable[Op]) -> list[Op]
    """order `ops` so that every op comes after the ops that write its
    inputs.

    ops are bucketed by how many of their distinct input values haven't been
    written yet; ops with nothing pending are repeatedly moved to the output,
    which makes their outputs available to the ops that read them.

    the order among ops that are ready at the same time is unspecified.

    raises `ValueError` if two ops write the same SSA value, or if some op
    can never become ready (it is part of a dependency loop or reads a value
    that no op in `ops` writes).
    """
    # worklists[n] holds the ops that still wait for n distinct inputs
    worklists = [set()]  # type: list[set[Op]]
    input_vals_to_ops_map = defaultdict(set)  # type: dict[SSAVal, set[Op]]
    ops_to_pending_input_count_map = {}  # type: dict[Op, int]
    for op in ops:
        # count each distinct value once: an op that reads the same value
        # through several arguments is only notified once when that value
        # gets written, so counting per-argument would leave it stuck.
        unique_inputs = set(op.inputs().values())
        input_count = len(unique_inputs)
        for val in unique_inputs:
            input_vals_to_ops_map[val].add(op)
        while len(worklists) <= input_count:
            worklists.append(set())
        ops_to_pending_input_count_map[op] = input_count
        worklists[input_count].add(op)
    retval = []  # type: list[Op]
    ready_vals = set()  # type: set[SSAVal]
    while len(worklists[0]) != 0:
        writing_op = worklists[0].pop()
        retval.append(writing_op)
        for val in writing_op.outputs().values():
            if val in ready_vals:
                raise ValueError(f"multiple instructions must not write "
                                 f"to the same SSA value: {val}")
            ready_vals.add(val)
            # `.get` avoids adding empty entries for values nobody reads
            for reading_op in input_vals_to_ops_map.get(val, ()):
                pending = ops_to_pending_input_count_map[reading_op]
                worklists[pending].remove(reading_op)
                pending -= 1
                worklists[pending].add(reading_op)
                ops_to_pending_input_count_map[reading_op] = pending
    # anything left over never had all of its inputs written
    for worklist in worklists:
        for op in worklist:
            raise ValueError(f"instruction is part of a dependency loop or "
                             f"its inputs are never written: {op}")
    return retval
640
641
@plain_data(unsafe_hash=True, order=True, frozen=True)
class LiveInterval:
    """the span of op indexes from `first_write` to `last_use`."""
    __slots__ = "first_write", "last_use"

    def __init__(self, first_write, last_use=None):
        # type: (int, int | None) -> None
        """`last_use` defaults to `first_write`.

        raises `ValueError` if `last_use` is before `first_write` or if
        either index is negative.
        """
        last_use = first_write if last_use is None else last_use
        if last_use < first_write:
            raise ValueError("uses must be after first_write")
        if min(first_write, last_use) < 0:
            raise ValueError("indexes must be nonnegative")
        self.first_write = first_write
        self.last_use = last_use

    def overlaps(self, other):
        # type: (LiveInterval) -> bool
        """True if both intervals are first written at the same index, or
        if each one's last use is after the other's first write."""
        same_first_write = self.first_write == other.first_write
        return same_first_write or (
            self.last_use > other.first_write
            and other.last_use > self.first_write)

    def __add__(self, use):
        # type: (int) -> LiveInterval
        """return a new interval extended so it also covers the use at
        index `use`."""
        return LiveInterval(first_write=self.first_write,
                            last_use=max(self.last_use, use))
668
669
@final
class MergedRegSet(Mapping[SSAVal[_RegType], int]):
    """an immutable mapping from `SSAVal`s to their offsets within a merged
    set of values.

    for `GPRRangeType` values the set spans `start` to `stop` and `ty` is
    the `GPRRangeType` covering that whole span; for any other type all
    offsets must be 0 and all values must have the same type.
    """

    def __init__(self, reg_set):
        # type: (Iterable[tuple[SSAVal[_RegType], int]] | SSAVal[_RegType]) -> None
        """build from `(ssa_val, offset)` pairs, or from a single `SSAVal`
        which is placed at offset 0.

        raises `ValueError` if a value appears with two different offsets,
        if no values are given, or if the values' types can't be merged.
        """
        items = {}  # type: dict[SSAVal[_RegType], int]
        self.__items = items
        if isinstance(reg_set, SSAVal):
            reg_set = [(reg_set, 0)]
        for ssa_val, offset in reg_set:
            # keeps the existing offset if `ssa_val` was already added
            other = items.setdefault(ssa_val, offset)
            if offset != other:
                raise ValueError(
                    f"can't merge register sets: conflicting offsets: "
                    f"for {ssa_val}: {offset} != {other}")
        first_item = next(iter(items.items()), None)
        if first_item is None:
            raise ValueError("can't have empty MergedRegs")
        first_ssa_val, start = first_item
        first_ty = first_ssa_val.ty
        if isinstance(first_ty, GPRRangeType):
            # grow [start, stop) until it covers every value in the set
            stop = start + first_ty.length
            for ssa_val, offset in items.items():
                if not isinstance(ssa_val.ty, GPRRangeType):
                    raise ValueError(f"can't merge incompatible types: "
                                     f"{ssa_val.ty} and {first_ty}")
                stop = max(stop, offset + ssa_val.ty.length)
                start = min(start, offset)
            ty = GPRRangeType(stop - start)
        else:
            ty = first_ty
            stop = 1
            for ssa_val, offset in items.items():
                if offset != 0:
                    raise ValueError(f"can't have non-zero offset "
                                     f"for {ssa_val.ty}")
                if ty != ssa_val.ty:
                    raise ValueError(f"can't merge incompatible types: "
                                     f"{ssa_val.ty} and {ty}")
        self.__start = start  # type: int
        self.__stop = stop  # type: int
        self.__ty = ty  # type: RegType

    @staticmethod
    def from_equality_constraint(constraint_sequence):
        # type: (list[SSAVal[_RegType]]) -> MergedRegSet[_RegType]
        """build the set for one side of an `EqualityConstraint`: the values
        are laid out back-to-back starting at offset 0.

        raises `ValueError` if there's more than one value and any of them
        isn't of type `GPRRangeType`.
        """
        if len(constraint_sequence) == 1:
            # any type allowed with len = 1
            return MergedRegSet(constraint_sequence[0])
        pairs = []
        next_offset = 0
        for val in constraint_sequence:
            if not isinstance(val.ty, GPRRangeType):
                raise ValueError("equality constraint sequences must only "
                                 "have SSAVal type GPRRangeType")
            pairs.append((val, next_offset))
            next_offset += val.ty.length
        return MergedRegSet(pairs)

    @property
    def ty(self):
        """the type covering the whole merged set."""
        return self.__ty

    @property
    def stop(self):
        """one past the largest offset covered by any value in the set."""
        return self.__stop

    @property
    def start(self):
        """the smallest offset of any value in the set."""
        return self.__start

    @property
    def range(self):
        """`range(start, stop)`."""
        return range(self.__start, self.__stop)

    def offset_by(self, amount):
        # type: (int) -> MergedRegSet[_RegType]
        """return a copy with `amount` added to every offset."""
        shifted = ((k, v + amount) for k, v in self.items())
        return MergedRegSet(shifted)

    def normalized(self):
        # type: () -> MergedRegSet[_RegType]
        """return a copy shifted so that `start` is 0."""
        return self.offset_by(-self.start)

    def with_offset_to_match(self, target):
        # type: (MergedRegSet[_RegType]) -> MergedRegSet[_RegType]
        """return a copy shifted so that the first value also found in
        `target` has the same offset as it has in `target`.

        raises `ValueError` if no value is shared with `target`.
        """
        for ssa_val, offset in self.items():
            if ssa_val not in target:
                continue
            return self.offset_by(target[ssa_val] - offset)
        raise ValueError("can't change offset to match unrelated MergedRegSet")

    def __getitem__(self, item):
        # type: (SSAVal[_RegType]) -> int
        return self.__items[item]

    def __iter__(self):
        return iter(self.__items)

    def __len__(self):
        return len(self.__items)

    def __hash__(self):
        # order-independent, like the equality inherited from Mapping
        return hash(frozenset(self.items()))

    def __repr__(self):
        pairs = list(self.__items.items())
        return f"MergedRegSet({pairs})"
778
779
@final
class MergedRegSets(Mapping[SSAVal, MergedRegSet]):
    """immutable mapping from every `SSAVal` used by a collection of ops to
    the normalized `MergedRegSet` it belongs to.

    values that are tied together by the ops' `EqualityConstraint`s end up
    in the same `MergedRegSet`.
    """

    def __init__(self, ops):
        # type: (Iterable[Op]) -> None
        """compute the merged sets for all inputs and outputs of `ops`.

        raises `ValueError` (from `MergedRegSet`) if the constraints require
        conflicting offsets or merge incompatible types.
        """
        merged_sets = {}  # type: dict[SSAVal, MergedRegSet]
        for op in ops:
            # every value starts out in a set of its own
            for val in (*op.inputs().values(), *op.outputs().values()):
                if val not in merged_sets:
                    merged_sets[val] = MergedRegSet(val)
            for e in op.get_equality_constraints():
                lhs_set = MergedRegSet.from_equality_constraint(e.lhs)
                rhs_set = MergedRegSet.from_equality_constraint(e.rhs)
                # merge the existing sets of *all* values on both sides, not
                # just the first value of each side -- otherwise the other
                # values of a multi-value side would be left unmerged.
                # both sides are laid out from offset 0, so shifting each
                # existing set to match its side puts everything in one
                # coordinate system.
                items = []  # type: list[tuple[SSAVal, int]]
                for val in e.lhs:
                    shifted = merged_sets[val].with_offset_to_match(lhs_set)
                    items.extend(shifted.items())
                for val in e.rhs:
                    shifted = merged_sets[val].with_offset_to_match(rhs_set)
                    items.extend(shifted.items())
                full_set = MergedRegSet(items)
                for val in full_set.keys():
                    merged_sets[val] = full_set

        self.__map = {k: v.normalized() for k, v in merged_sets.items()}

    def __getitem__(self, key):
        # type: (SSAVal) -> MergedRegSet
        return self.__map[key]

    def __iter__(self):
        return iter(self.__map)

    def __len__(self):
        return len(self.__map)
809
810
@final
class LiveIntervals(Mapping[MergedRegSet, LiveInterval]):
    """immutable mapping from each `MergedRegSet` of a list of ops to its
    `LiveInterval`, measured in indexes into that list."""

    def __init__(self, ops):
        # type: (list[Op]) -> None
        """compute the live intervals for `ops`.

        raises `KeyError` if an op reads a value whose merged set hasn't
        been written by that op or an earlier one in `ops`.
        """
        self.__merges_reg_sets = MergedRegSets(ops)
        live_intervals = {}  # type: dict[MergedRegSet, LiveInterval]
        for op_idx, op in enumerate(ops):
            for val in op.inputs().values():
                # a read extends the interval of the set the value is in
                live_intervals[self.__merges_reg_sets[val]] += op_idx
            for val in op.outputs().values():
                reg_set = self.__merges_reg_sets[val]
                if reg_set not in live_intervals:
                    live_intervals[reg_set] = LiveInterval(op_idx)
                else:
                    # another value of the same merged set was written
                    # earlier, so just extend that interval
                    live_intervals[reg_set] += op_idx
        self.__live_intervals = live_intervals

    @property
    def merges_reg_sets(self):
        """the `MergedRegSets` computed from the ops."""
        return self.__merges_reg_sets

    def __getitem__(self, key):
        # type: (MergedRegSet) -> LiveInterval
        return self.__live_intervals[key]

    def __iter__(self):
        return iter(self.__live_intervals)

    def __len__(self):
        # required by Mapping: without it this class is abstract and can't
        # be instantiated
        return len(self.__live_intervals)
838
839
@final
class IGNode:
    """ interference graph node """
    __slots__ = "merged_reg_set", "edges"

    def __init__(self, merged_reg_set, edges=()):
        # type: (MergedRegSet, Iterable[IGNode]) -> None
        """`edges` is copied into a new set owned by this node."""
        self.merged_reg_set = merged_reg_set
        self.edges = set(edges)

    def add_edge(self, other):
        # type: (IGNode) -> None
        """connect `self` and `other` in both directions."""
        for src, dest in ((self, other), (other, self)):
            src.edges.add(dest)

    def __eq__(self, other):
        # type: (object) -> bool
        # nodes are identified by their merged_reg_set, edges are ignored
        if not isinstance(other, IGNode):
            return NotImplemented
        return self.merged_reg_set == other.merged_reg_set

    def __hash__(self):
        return hash(self.merged_reg_set)

    def __repr__(self, nodes=None):
        # type: (None | dict[IGNode, int]) -> str
        """`nodes` maps the nodes already printed to their numbers, so that
        a node reached again through a cycle is printed as a short
        back-reference instead of recursing forever."""
        seen = {} if nodes is None else nodes
        if self in seen:
            return f"<IGNode #{seen[self]}>"
        index = len(seen)
        # register before recursing so that the neighbors can refer back
        seen[self] = index
        edge_reprs = [i.__repr__(seen) for i in self.edges]
        edges = "{" + ", ".join(edge_reprs) + "}"
        return (f"IGNode(#{index}, "
                f"merged_reg_set={self.merged_reg_set}, "
                f"edges={edges})")
875
876
@final
class InterferenceGraph(Mapping[MergedRegSet, IGNode]):
    """immutable mapping from each `MergedRegSet` to its `IGNode`.

    the nodes start out without any edges.
    """

    def __init__(self, merged_reg_sets):
        # type: (Iterable[MergedRegSet]) -> None
        self.__nodes = {i: IGNode(i) for i in merged_reg_sets}

    def __getitem__(self, key):
        # type: (MergedRegSet) -> IGNode
        return self.__nodes[key]

    def __iter__(self):
        return iter(self.__nodes)

    def __len__(self):
        # required by Mapping: without it this class is abstract and can't
        # be instantiated
        return len(self.__nodes)
889
890
@plain_data()
class AllocationFailed:
    """result returned instead of an allocation when
    `try_allocate_registers_without_spilling` fails.

    holds the op index `op_idx`, the `SSAVal` `arg`, and the
    `LiveIntervals` that were in use.
    """
    __slots__ = "op_idx", "arg", "live_intervals"

    def __init__(self, op_idx, arg, live_intervals):
        # type: (int, SSAVal, LiveIntervals) -> None
        self.op_idx = op_idx
        self.arg = arg
        self.live_intervals = live_intervals
900
901
def try_allocate_registers_without_spilling(ops):
    # type: (list[Op]) -> dict[SSAVal, PhysLoc] | AllocationFailed
    """not implemented yet: computes the live intervals of `ops`, then
    always raises `NotImplementedError`."""

    # only the first step exists so far
    live_intervals = LiveIntervals(ops)

    raise NotImplementedError
908
909
def allocate_registers(ops):
    # type: (list[Op]) -> None
    """not implemented yet: always raises `NotImplementedError`."""
    raise NotImplementedError()