dd9061b321108cfa455b3642b54d2e2e58983d07
[bigint-presentation-code.git] / src / bigint_presentation_code / toom_cook.py
1 from abc import ABCMeta, abstractmethod
2 from collections import defaultdict
3 from enum import Enum, unique
4 from typing import AbstractSet, Iterable, Mapping, TYPE_CHECKING
5
6 from nmutil.plain_data import plain_data
7
if not TYPE_CHECKING:
    def final(v):
        """runtime stand-in for `typing_extensions.final`: hands back `v`
        unchanged.
        """
        return v
else:
    # only evaluated by static type checkers
    from typing_extensions import final, Self
13
14
@plain_data(frozen=True, unsafe_hash=True)
class PhysLoc:
    """base class of the physical locations SSA values get assigned to."""
18
19
@plain_data(frozen=True, unsafe_hash=True)
class GPROrStackLoc(PhysLoc):
    """base class shared by the `GPR` and `StackSlot` physical locations."""
23
24
@final
class GPR(GPROrStackLoc, Enum):
    """the registers R0 through R127, plus the names `SP` and `TOC`.

    `SP` and `TOC` reuse the values of `R1` and `R2`, so `Enum` turns them
    into aliases of those members rather than new members.
    """

    def __init__(self, reg_num):
        # type: (int) -> None
        # a member's value is also stored as its register number
        self.reg_num = reg_num

    # fmt: off
    R0 = 0; R1 = 1; R2 = 2; R3 = 3
    R4 = 4; R5 = 5; R6 = 6; R7 = 7
    R8 = 8; R9 = 9; R10 = 10; R11 = 11
    R12 = 12; R13 = 13; R14 = 14; R15 = 15
    R16 = 16; R17 = 17; R18 = 18; R19 = 19
    R20 = 20; R21 = 21; R22 = 22; R23 = 23
    R24 = 24; R25 = 25; R26 = 26; R27 = 27
    R28 = 28; R29 = 29; R30 = 30; R31 = 31
    R32 = 32; R33 = 33; R34 = 34; R35 = 35
    R36 = 36; R37 = 37; R38 = 38; R39 = 39
    R40 = 40; R41 = 41; R42 = 42; R43 = 43
    R44 = 44; R45 = 45; R46 = 46; R47 = 47
    R48 = 48; R49 = 49; R50 = 50; R51 = 51
    R52 = 52; R53 = 53; R54 = 54; R55 = 55
    R56 = 56; R57 = 57; R58 = 58; R59 = 59
    R60 = 60; R61 = 61; R62 = 62; R63 = 63
    R64 = 64; R65 = 65; R66 = 66; R67 = 67
    R68 = 68; R69 = 69; R70 = 70; R71 = 71
    R72 = 72; R73 = 73; R74 = 74; R75 = 75
    R76 = 76; R77 = 77; R78 = 78; R79 = 79
    R80 = 80; R81 = 81; R82 = 82; R83 = 83
    R84 = 84; R85 = 85; R86 = 86; R87 = 87
    R88 = 88; R89 = 89; R90 = 90; R91 = 91
    R92 = 92; R93 = 93; R94 = 94; R95 = 95
    R96 = 96; R97 = 97; R98 = 98; R99 = 99
    R100 = 100; R101 = 101; R102 = 102; R103 = 103
    R104 = 104; R105 = 105; R106 = 106; R107 = 107
    R108 = 108; R109 = 109; R110 = 110; R111 = 111
    R112 = 112; R113 = 113; R114 = 114; R115 = 115
    R116 = 116; R117 = 117; R118 = 118; R119 = 119
    R120 = 120; R121 = 121; R122 = 122; R123 = 123
    R124 = 124; R125 = 125; R126 = 126; R127 = 127
    # fmt: on

    # aliases of R1 and R2
    SP = 1
    TOC = 2
56
57
# the GPRs that are left out of ALLOCATABLE_REGS
SPECIAL_GPRS = (GPR.R0, GPR.SP, GPR.TOC, GPR.R13)
59
60
@final
@unique
# Enum must be the last base: with a non-Enum class listed last, creating
# the enumeration raises TypeError. This now matches the base order of GPR.
class XERBit(PhysLoc, Enum):
    """physical locations that are XER bits; `CY` is the only one so far."""
    CY = "CY"
65
66
@final
@unique
# Enum must be the last base: with a non-Enum class listed last, creating
# the enumeration raises TypeError. This now matches the base order of GPR.
class GlobalMem(PhysLoc, Enum):
    """singleton representing all non-StackSlot memory"""
    GlobalMem = "GlobalMem"
72
73
# every GPR that is not in SPECIAL_GPRS, plus every XERBit and GlobalMem
# member
ALLOCATABLE_REGS = frozenset(
    set(GPR).difference(SPECIAL_GPRS).union(XERBit, GlobalMem))
76
77
@plain_data()
@final
class StackSlot(GPROrStackLoc):
    """a stack slot. Use OpCopy to load from/store into this stack slot."""
    __slots__ = ("offset",)

    def __init__(self, offset=None):
        # type: (int | None) -> None
        # the offset is optional and defaults to None
        self.offset = offset
87
88
class SSAVal(metaclass=ABCMeta):
    """abstract base class of SSA values.

    a value is identified by the Op that writes it (compared by identity),
    the name of that Op's argument and an element index.
    """
    __slots__ = ("op", "arg_name", "element_index")

    def __init__(self, op, arg_name, element_index):
        # type: (Op, str, int) -> None
        # the Op that writes this SSAVal
        self.op = op
        self.arg_name = arg_name
        self.element_index = element_index

    @final
    def __eq__(self, rhs):
        if not isinstance(rhs, SSAVal):
            return False
        same_op = self.op is rhs.op
        return (same_op
                and self.arg_name == rhs.arg_name
                and self.element_index == rhs.element_index)

    @final
    def __hash__(self):
        # hashes the op by identity, matching __eq__
        key = id(self.op), self.arg_name, self.element_index
        return hash(key)

    def _get_phys_loc(self, phys_loc_in, value_assignments=None):
        # type: (PhysLoc | None, dict[SSAVal, PhysLoc] | None) -> PhysLoc | None
        """return `phys_loc_in` unless it is None; otherwise return what
        `value_assignments` (when given) maps self to, else None.
        """
        if phys_loc_in is None and value_assignments is not None:
            return value_assignments.get(self)
        return phys_loc_in

    @abstractmethod
    def get_phys_loc(self, value_assignments=None):
        # type: (dict[SSAVal, PhysLoc] | None) -> PhysLoc | None
        ...

    @final
    def __repr__(self):
        fields = (
            f"op={object.__repr__(self.op)}",
            f"arg_name={self.arg_name}",
            f"element_index={self.element_index}",
            f"phys_loc={self.get_phys_loc()}",
        )
        return self.__class__.__name__ + "(" + ", ".join(fields) + ")"

    @final
    def like(self, op, arg_name):
        # type: (Op, str) -> Self
        """build a fresh value of self's class for `op`/`arg_name`, with
        element index 0. has same signature as VecArg.like.
        """
        cls = self.__class__
        return cls(op=op, arg_name=arg_name, element_index=0)
141
142
@final
class SSAGPRVal(SSAVal):
    """an SSAVal whose physical location is a GPR or a StackSlot."""
    __slots__ = "phys_loc",

    def __init__(self, op, arg_name, element_index, phys_loc=None):
        # type: (Op, str, int, GPROrStackLoc | None) -> None
        super().__init__(op, arg_name, element_index)
        # fixed location, if any; while None the location is looked up in
        # the value_assignments passed to the getters below
        self.phys_loc = phys_loc

    def __len__(self):
        return 1

    def get_phys_loc(self, value_assignments=None):
        # type: (dict[SSAVal, PhysLoc] | None) -> GPROrStackLoc | None
        """return this value's location, or None when it has none or the
        location found isn't a GPROrStackLoc.
        """
        loc = self._get_phys_loc(self.phys_loc, value_assignments)
        if isinstance(loc, GPROrStackLoc):
            return loc
        return None

    def get_reg_num(self, value_assignments=None):
        # type: (dict[SSAVal, PhysLoc] | None) -> int | None
        """return the number of the GPR this value is in, or None."""
        reg = self.get_reg(value_assignments)
        if reg is not None:
            return reg.reg_num
        return None

    def get_reg(self, value_assignments=None):
        # type: (dict[SSAVal, PhysLoc] | None) -> GPR | None
        """return the GPR this value is in, or None when it isn't in one."""
        loc = self.get_phys_loc(value_assignments)
        if isinstance(loc, GPR):
            return loc
        return None

    def get_stack_slot(self, value_assignments=None):
        # type: (dict[SSAVal, PhysLoc] | None) -> StackSlot | None
        """return the StackSlot this value is in, or None when it isn't in
        one.
        """
        loc = self.get_phys_loc(value_assignments)
        if isinstance(loc, StackSlot):
            return loc
        return None

    def possible_reg_assignments(self, value_assignments,
                                 conflicting_regs=frozenset()):
        # type: (dict[SSAVal, PhysLoc] | None, AbstractSet[GPR]) -> Iterable[GPR]
        """yield every GPR that isn't in `conflicting_regs`.

        this is a generator, so the ValueError for an already-assigned value
        is raised when iteration starts, not when the method is called.
        the default for `conflicting_regs` is an immutable empty set so no
        mutable object is shared between calls.
        """
        if self.get_phys_loc(value_assignments) is not None:
            raise ValueError("can't assign a already-assigned SSA value")
        for reg in GPR:
            if reg not in conflicting_regs:
                yield reg
191
192
@final
class SSAXERBitVal(SSAVal):
    """an SSAVal whose physical location is an XERBit."""
    __slots__ = ("phys_loc",)

    def __init__(self, op, arg_name, element_index, phys_loc=None):
        # type: (Op, str, int, XERBit | None) -> None
        super().__init__(op, arg_name, element_index)
        self.phys_loc = phys_loc

    def get_phys_loc(self, value_assignments=None):
        # type: (dict[SSAVal, PhysLoc] | None) -> XERBit | None
        """return this value's XERBit, or None when it has no location or
        the location found isn't an XERBit.
        """
        found = self._get_phys_loc(self.phys_loc, value_assignments)
        return found if isinstance(found, XERBit) else None
208
209
@final
class SSAMemory(SSAVal):
    """an SSAVal whose physical location is a GlobalMem member."""
    __slots__ = ("phys_loc",)

    def __init__(self, op, arg_name, element_index,
                 phys_loc=GlobalMem.GlobalMem):
        # type: (Op, str, int, GlobalMem) -> None
        super().__init__(op, arg_name, element_index)
        self.phys_loc = phys_loc

    def get_phys_loc(self, value_assignments=None):
        # type: (dict[SSAVal, PhysLoc] | None) -> GlobalMem
        """return the location found for this value when it is a GlobalMem,
        falling back to `self.phys_loc` otherwise.
        """
        found = self._get_phys_loc(self.phys_loc, value_assignments)
        return found if isinstance(found, GlobalMem) else self.phys_loc
226
227
@plain_data(unsafe_hash=True, frozen=True)
@final
class VecArg:
    """a vector operand: a tuple of SSAGPRVals, one per element."""
    __slots__ = "regs",

    def __init__(self, regs):
        # type: (Iterable[SSAGPRVal]) -> None
        self.regs = tuple(regs)

    def __len__(self):
        return len(self.regs)

    def is_unassigned(self, value_assignments=None):
        # type: (dict[SSAVal, PhysLoc] | None) -> bool
        """return True when no element has a physical location."""
        for val in self.regs:
            if val.get_phys_loc(value_assignments) is not None:
                return False
        return True

    def try_get_range(self, value_assignments=None, allow_unassigned=False,
                      raise_if_invalid=False):
        # type: (dict[SSAVal, PhysLoc] | None, bool, bool) -> range | None
        """return the range of register numbers this vector occupies.

        element `i` in register `reg` implies the range starting at
        `reg - i`; all located elements have to imply the same range.

        returns `range(0)` for an empty vector. returns None when no range
        can be determined: an element is unassigned and `allow_unassigned`
        is False, an element's location isn't a register, the elements imply
        different ranges, or (with `allow_unassigned`) every element is
        unassigned. with `raise_if_invalid` the first three cases raise
        ValueError instead.
        """
        if len(self.regs) == 0:
            return range(0)

        retval = None  # type: range | None
        for i, val in enumerate(self.regs):
            if val.get_phys_loc(value_assignments) is None:
                if not allow_unassigned:
                    if raise_if_invalid:
                        raise ValueError("not a valid register range: "
                                         "unassigned SSA value encountered")
                    return None
                continue
            reg = val.get_reg_num(value_assignments)
            if reg is None:
                if raise_if_invalid:
                    raise ValueError("not a valid register range: "
                                     "non-register encountered")
                return None
            expected_range = range(reg - i, reg - i + len(self.regs))
            if retval is None:
                retval = expected_range
            elif retval != expected_range:
                if raise_if_invalid:
                    raise ValueError("not a valid register range: "
                                     "register out of sequence")
                return None
        return retval

    def possible_reg_assignments(
        self,
        val,  # type: SSAVal
        value_assignments,  # type: dict[SSAVal, PhysLoc] | None
        conflicting_regs=frozenset(),  # type: AbstractSet[GPR]
    ):  # type: (...) -> Iterable[GPR]
        """yield the GPRs that element `val` of this vector may be put in.

        vectors are placed at register numbers that are a multiple of the
        vector length rounded up to a power of two. when other elements
        already have registers they fix the whole range, so only the one
        register matching `val`'s index is yielded; otherwise one register
        is yielded per aligned range that avoids `conflicting_regs`.

        raises ValueError when `val` isn't an element of this vector or when
        the already-fixed range is misaligned (raised once iteration starts,
        since this is a generator).
        """
        index = self.regs.index(val)
        alignment = 1
        while alignment < len(self.regs):
            alignment *= 2
        # allow_unassigned=True: `val` itself normally has no location yet,
        # and without this flag the lookup would always come back None,
        # ignoring the range already fixed by the vector's other elements
        r = self.try_get_range(value_assignments, allow_unassigned=True)
        if r is None:
            for i in range(0, len(GPR), alignment):
                r = range(i, i + len(self.regs))
                if any(GPR(reg) in conflicting_regs for reg in r):
                    continue
                yield GPR(r[index])
            return
        if r.start % alignment != 0:
            raise ValueError("must be a ascending aligned range of GPRs")
        yield GPR(r[index])

    def like(self, op, arg_name):
        # type: (Op, str) -> VecArg
        """create a new VecArg based off of self's type.
        has same signature as SSAVal.like.
        """
        return VecArg(
            SSAGPRVal(op, arg_name, i) for i in range(len(self.regs)))
307
308
def vec_or_scalar_arg(element_count, op, arg_name):
    # type: (int | None, Op, str) -> VecArg | SSAGPRVal
    """build an operand of `op` named `arg_name`: a single SSAGPRVal when
    `element_count` is None, otherwise a VecArg of `element_count` elements.
    """
    if element_count is not None:
        return VecArg(
            SSAGPRVal(op, arg_name, idx) for idx in range(element_count))
    return SSAGPRVal(op, arg_name, 0)
315
316
@final
@plain_data(unsafe_hash=True, frozen=True)
class EqualityConstraint:
    """a pair of SSAVals, `lhs` and `rhs`, that EqualitySets merges into the
    same set.
    """
    __slots__ = ("lhs", "rhs")

    def __init__(self, lhs, rhs):
        # type: (SSAVal, SSAVal) -> None
        self.lhs, self.rhs = lhs, rhs
326
327
@plain_data(unsafe_hash=True, frozen=True)
class Op(metaclass=ABCMeta):
    """abstract base class of all operations.

    subclasses describe their operands through `inputs` and `outputs`, and
    which physical locations an operand may get through
    `possible_reg_assignments`.
    """
    __slots__ = ()

    def __init__(self):
        pass

    @abstractmethod
    def inputs(self):
        # type: () -> dict[str, VecArg | SSAVal]
        ...

    @abstractmethod
    def outputs(self):
        # type: () -> dict[str, VecArg | SSAVal]
        ...

    @abstractmethod
    def possible_reg_assignments(self, val, value_assignments):
        # type: (SSAVal, dict[SSAVal, PhysLoc]) -> Iterable[PhysLoc]
        ...

    def input_ssa_vals(self):
        # type: () -> Iterable[SSAVal]
        """yield every input SSAVal, with each VecArg replaced by its
        elements.
        """
        for arg in self.inputs().values():
            if not isinstance(arg, VecArg):
                yield arg
                continue
            yield from arg.regs

    def output_ssa_vals(self):
        # type: () -> Iterable[SSAVal]
        """yield every output SSAVal, with each VecArg replaced by its
        elements.
        """
        for arg in self.outputs().values():
            if not isinstance(arg, VecArg):
                yield arg
                continue
            yield from arg.regs

    def get_equality_constraints(self):
        # type: () -> Iterable[EqualityConstraint]
        """yield this Op's EqualityConstraints; there are none by default."""
        yield from ()
370
371
@plain_data(unsafe_hash=True, frozen=True)
@final
class OpCopy(Op):
    """copies `src` into a new value `dest` of the same class."""
    __slots__ = ("dest", "src")

    def __init__(self, src):
        # type: (SSAGPRVal) -> None
        self.dest = src.like(op=self, arg_name="dest")
        self.src = src

    def inputs(self):
        # type: () -> dict[str, VecArg | SSAVal]
        return {"src": self.src}

    def outputs(self):
        # type: () -> dict[str, VecArg | SSAVal]
        return {"dest": self.dest}

    def possible_reg_assignments(self, val, value_assignments):
        # type: (SSAVal, dict[SSAVal, PhysLoc]) -> Iterable[PhysLoc]
        """return the locations `val` may be assigned to.

        raises ValueError when `val` isn't an operand of this Op, already
        has a location, or isn't an SSAGPRVal.
        """
        is_operand = (val in self.input_ssa_vals()
                      or val in self.output_ssa_vals())
        if not is_operand:
            raise ValueError(f"{val} must be an operand of {self}")
        if val.get_phys_loc(value_assignments) is not None:
            raise ValueError(f"{val} already assigned a physical location")
        if isinstance(val, SSAGPRVal):
            return val.possible_reg_assignments(value_assignments)
        raise ValueError("invalid operand type")
400
401
def range_overlaps(range1, range2):
    # type: (range, range) -> bool
    """return True when `range1` and `range2` have a value in common.

    works for ranges with any step; an empty range overlaps nothing.
    """
    if len(range1) == 0 or len(range2) == 0:
        return False
    if range1.step == 1 and range2.step == 1:
        # contiguous ranges overlap exactly when each one starts before the
        # other one stops
        return range1.start < range2.stop and range2.start < range1.stop
    # strided ranges can share only interior values, or interleave without
    # sharing any, so looking at the endpoints isn't enough: test every
    # value of the shorter range against the other range
    if len(range2) < len(range1):
        range1, range2 = range2, range1
    return any(i in range2 for i in range1)
410
411
@plain_data(unsafe_hash=True, frozen=True)
@final
class OpAddSubE(Op):
    """add or subtract (selected by `is_sub`) the vectors `RA` and `RB`
    with `CY_in`, producing the vector `RT` and `CY_out`.
    """
    __slots__ = ("RT", "RA", "RB", "CY_in", "CY_out", "is_sub")

    def __init__(self, RA, RB, CY_in, is_sub):
        # type: (VecArg, VecArg, SSAXERBitVal, bool) -> None
        if len(RA.regs) != len(RB.regs):
            raise TypeError(f"source lengths must match: "
                            f"{RA} doesn't match {RB}")
        # both outputs are modelled on the matching input
        self.RT = RA.like(op=self, arg_name="RT")
        self.RA = RA
        self.RB = RB
        self.CY_in = CY_in
        self.CY_out = CY_in.like(op=self, arg_name="CY_out")
        self.is_sub = is_sub

    def inputs(self):
        # type: () -> dict[str, VecArg | SSAVal]
        return {"RA": self.RA, "RB": self.RB, "CY_in": self.CY_in}

    def outputs(self):
        # type: () -> dict[str, VecArg | SSAVal]
        return {"RT": self.RT, "CY_out": self.CY_out}

    def possible_reg_assignments(self, val, value_assignments):
        # type: (SSAVal, dict[SSAVal, PhysLoc]) -> Iterable[PhysLoc]
        """yield the locations `val` may be assigned to.

        raises ValueError (once iteration starts) when `val` isn't an
        operand of this Op or already has a location.
        """
        is_operand = (val in self.input_ssa_vals()
                      or val in self.output_ssa_vals())
        if not is_operand:
            raise ValueError(f"{val} must be an operand of {self}")
        if val.get_phys_loc(value_assignments) is not None:
            raise ValueError(f"{val} already assigned a physical location")
        if self.CY_in == val or self.CY_out == val:
            # the carry operands can only be in XERBit.CY
            yield XERBit.CY
            return
        if val in self.RT.regs:
            # since possible_reg_assignments only returns aligned
            # vectors, all possible assignments either are the same as an
            # input or don't overlap with an input and we avoid the incorrect
            # results caused by partial overlaps overwriting input elements
            # before they're read
            vec = self.RT
        elif val in self.RA.regs:
            vec = self.RA
        else:
            vec = self.RB
        yield from vec.possible_reg_assignments(val, value_assignments)

    def get_equality_constraints(self):
        # type: () -> Iterable[EqualityConstraint]
        """yield the constraint tying `CY_in` to `CY_out`."""
        yield EqualityConstraint(self.CY_in, self.CY_out)
461
462
def to_reg_set(v):
    # type: (None | GPR | range) -> set[GPR]
    """convert `v` to a set of GPRs: None gives the empty set, a range of
    register numbers gives the matching GPRs, and anything else is wrapped
    in a one-element set.
    """
    if isinstance(v, range):
        return {GPR(reg_num) for reg_num in v}
    return set() if v is None else {v}
470
471
@plain_data(unsafe_hash=True, frozen=True)
@final
class OpBigIntMulDiv(Op):
    """multiply or divide (selected by `is_div`) the vector `RA` by the
    scalar `RB`, with the scalar `RC` as additional input; produces the
    vector `RT` and the scalar `RS`.
    """
    __slots__ = "RT", "RA", "RB", "RC", "RS", "is_div"

    def inputs(self):
        # type: () -> dict[str, VecArg | SSAVal]
        return {"RA": self.RA, "RB": self.RB, "RC": self.RC}

    def outputs(self):
        # type: () -> dict[str, VecArg | SSAVal]
        return {"RT": self.RT, "RS": self.RS}

    def __init__(self, RA, RB, RC, is_div):
        # type: (VecArg, SSAGPRVal, SSAGPRVal, bool) -> None
        # both outputs are modelled on the matching input
        self.RT = RA.like(op=self, arg_name="RT")
        self.RA = RA
        self.RB = RB
        self.RC = RC
        self.RS = RC.like(op=self, arg_name="RS")
        self.is_div = is_div

    def possible_reg_assignments(self, val, value_assignments):
        # type: (SSAVal, dict[SSAVal, PhysLoc]) -> Iterable[PhysLoc]
        """yield the locations `val` may be assigned to.

        raises ValueError (once iteration starts) when `val` isn't an
        operand of this Op, already has a location, or when `RT` or `RA`
        aren't valid register ranges.
        """
        if val not in self.input_ssa_vals() \
                and val not in self.output_ssa_vals():
            raise ValueError(f"{val} must be an operand of {self}")
        if val.get_phys_loc(value_assignments) is not None:
            raise ValueError(f"{val} already assigned a physical location")

        RT_range = self.RT.try_get_range(value_assignments,
                                         allow_unassigned=True,
                                         raise_if_invalid=True)
        RA_range = self.RA.try_get_range(value_assignments,
                                         allow_unassigned=True,
                                         raise_if_invalid=True)
        RB_reg = self.RB.get_reg(value_assignments)
        # RC and RS share a register (see get_equality_constraints), so use
        # whichever of the two already has one
        RC_RS_reg = self.RC.get_reg(value_assignments)
        if RC_RS_reg is None:
            RC_RS_reg = self.RS.get_reg(value_assignments)

        if self.RC == val or self.RS == val:
            if RC_RS_reg is not None:
                yield RC_RS_reg
            else:
                conflicting_regs = (to_reg_set(RT_range)
                                    | to_reg_set(RA_range)
                                    | to_reg_set(RB_reg))
                yield from self.RC.possible_reg_assignments(value_assignments,
                                                            conflicting_regs)
        elif val in self.RT.regs:
            # since possible_reg_assignments only returns aligned
            # vectors, all possible assignments either are the same as
            # RA or don't overlap with RA and we avoid the incorrect
            # results caused by partial overlaps overwriting input elements
            # before they're read
            yield from self.RT.possible_reg_assignments(
                val, value_assignments,
                conflicting_regs=(to_reg_set(RA_range)
                                  | to_reg_set(RC_RS_reg)
                                  | to_reg_set(RB_reg)))
        elif val in self.RA.regs:
            yield from self.RA.possible_reg_assignments(
                val, value_assignments,
                conflicting_regs=to_reg_set(RT_range) | to_reg_set(RC_RS_reg))
        else:
            # the only operand left is the scalar RB. it used to fall
            # through to the RA branch, where looking it up among RA's
            # elements raised ValueError. it has to stay clear of the
            # registers this Op writes (RT and RS)
            yield from self.RB.possible_reg_assignments(
                value_assignments,
                to_reg_set(RT_range) | to_reg_set(RC_RS_reg))

    def get_equality_constraints(self):
        # type: () -> Iterable[EqualityConstraint]
        """yield the constraint tying `RC` to `RS`."""
        yield EqualityConstraint(self.RC, self.RS)
536
537
@final
@unique
class ShiftKind(Enum):
    """the kind of shift an OpBigIntShift performs."""
    Sl = "sl"
    Sr = "sr"
    Sra = "sra"
544
545
@plain_data(unsafe_hash=True, frozen=True)
@final
class OpBigIntShift(Op):
    """shift the vector `inp` by the scalar `sh`, producing the vector
    `RT`; `kind` selects the ShiftKind.
    """
    __slots__ = ("RT", "inp", "sh", "kind")

    def __init__(self, inp, sh, kind):
        # type: (VecArg, SSAGPRVal, ShiftKind) -> None
        # the output is modelled on the input vector
        self.RT = inp.like(op=self, arg_name="RT")
        self.inp = inp
        self.sh = sh
        self.kind = kind

    def inputs(self):
        # type: () -> dict[str, VecArg | SSAVal]
        return {"inp": self.inp, "sh": self.sh}

    def outputs(self):
        # type: () -> dict[str, VecArg | SSAVal]
        return {"RT": self.RT}

    def possible_reg_assignments(self, val, value_assignments):
        # type: (SSAVal, dict[SSAVal, PhysLoc]) -> Iterable[PhysLoc]
        """yield the locations `val` may be assigned to.

        raises ValueError (once iteration starts) when `val` isn't an
        operand of this Op, already has a location, or when `RT` or `inp`
        aren't valid register ranges.
        """
        is_operand = (val in self.input_ssa_vals()
                      or val in self.output_ssa_vals())
        if not is_operand:
            raise ValueError(f"{val} must be an operand of {self}")
        if val.get_phys_loc(value_assignments) is not None:
            raise ValueError(f"{val} already assigned a physical location")

        RT_range = self.RT.try_get_range(
            value_assignments, allow_unassigned=True, raise_if_invalid=True)
        inp_range = self.inp.try_get_range(
            value_assignments, allow_unassigned=True, raise_if_invalid=True)
        sh_reg = self.sh.get_reg(value_assignments)

        if self.sh == val:
            # the shift amount has to stay out of RT's registers
            yield from self.sh.possible_reg_assignments(
                value_assignments, to_reg_set(RT_range))
            return
        if val in self.RT.regs:
            # RT is kept clear of the registers of both inputs so that no
            # input element is overwritten before it's read
            yield from self.RT.possible_reg_assignments(
                val, value_assignments,
                conflicting_regs=to_reg_set(inp_range) | to_reg_set(sh_reg))
            return
        yield from self.inp.possible_reg_assignments(
            val, value_assignments,
            conflicting_regs=to_reg_set(RT_range))
599
600
@plain_data(unsafe_hash=True, frozen=True)
@final
class OpLI(Op):
    """an Op with no inputs that holds an integer `value` and has one
    output `out`: a VecArg of `element_count` elements, or a scalar when
    `element_count` is None.
    """
    __slots__ = ("out", "value")

    def __init__(self, value, element_count=None):
        # type: (int, int | None) -> None
        self.out = vec_or_scalar_arg(element_count, op=self, arg_name="out")
        self.value = value

    def inputs(self):
        # type: () -> dict[str, VecArg | SSAVal]
        return {}

    def outputs(self):
        # type: () -> dict[str, VecArg | SSAVal]
        return {"out": self.out}

    def possible_reg_assignments(self, val, value_assignments):
        # type: (SSAVal, dict[SSAVal, PhysLoc]) -> Iterable[PhysLoc]
        """yield the locations `val` may be assigned to.

        raises ValueError (once iteration starts) when `val` isn't an
        operand of this Op or already has a location.
        """
        is_operand = (val in self.input_ssa_vals()
                      or val in self.output_ssa_vals())
        if not is_operand:
            raise ValueError(f"{val} must be an operand of {self}")
        if val.get_phys_loc(value_assignments) is not None:
            raise ValueError(f"{val} already assigned a physical location")

        out = self.out
        if not isinstance(out, VecArg):
            yield from out.possible_reg_assignments(value_assignments)
            return
        yield from out.possible_reg_assignments(val, value_assignments)
632
633
@plain_data(unsafe_hash=True, frozen=True)
@final
class OpClearCY(Op):
    """an Op with no inputs whose single output `out` is an SSAXERBitVal
    created with `XERBit.CY` as its location.
    """
    __slots__ = ("out",)

    def __init__(self):
        # type: () -> None
        self.out = SSAXERBitVal(
            op=self, arg_name="out", element_index=0, phys_loc=XERBit.CY)

    def inputs(self):
        # type: () -> dict[str, VecArg | SSAVal]
        return {}

    def outputs(self):
        # type: () -> dict[str, VecArg | SSAVal]
        return {"out": self.out}

    def possible_reg_assignments(self, val, value_assignments):
        # type: (SSAVal, dict[SSAVal, PhysLoc]) -> Iterable[PhysLoc]
        """yield `XERBit.CY`, the only location accepted for `val`.

        raises ValueError (once iteration starts) when `val` isn't an
        operand of this Op or already has a location.
        """
        is_operand = (val in self.input_ssa_vals()
                      or val in self.output_ssa_vals())
        if not is_operand:
            raise ValueError(f"{val} must be an operand of {self}")
        if val.get_phys_loc(value_assignments) is not None:
            raise ValueError(f"{val} already assigned a physical location")

        yield XERBit.CY
661
662
@plain_data(unsafe_hash=True, frozen=True)
@final
class OpLoad(Op):
    """a load with inputs `RA` and `mem`, an integer `offset`, and output
    `RT`: a VecArg of `element_count` elements, or a scalar when
    `element_count` is None.
    """
    __slots__ = ("RT", "RA", "offset", "mem")

    def __init__(self, RA, offset, mem, element_count=None):
        # type: (SSAGPRVal, int, SSAMemory, int | None) -> None
        self.RT = vec_or_scalar_arg(element_count, op=self, arg_name="RT")
        self.RA = RA
        self.offset = offset
        self.mem = mem

    def inputs(self):
        # type: () -> dict[str, VecArg | SSAVal]
        return {"RA": self.RA, "mem": self.mem}

    def outputs(self):
        # type: () -> dict[str, VecArg | SSAVal]
        return {"RT": self.RT}

    def possible_reg_assignments(self, val, value_assignments):
        # type: (SSAVal, dict[SSAVal, PhysLoc]) -> Iterable[PhysLoc]
        """yield the locations `val` may be assigned to.

        raises ValueError (once iteration starts) when `val` isn't an
        operand of this Op, already has a location, or (when assigning
        `RA`) a vector `RT` isn't a valid register range.
        """
        is_operand = (val in self.input_ssa_vals()
                      or val in self.output_ssa_vals())
        if not is_operand:
            raise ValueError(f"{val} must be an operand of {self}")
        if val.get_phys_loc(value_assignments) is not None:
            raise ValueError(f"{val} already assigned a physical location")

        RA_reg = self.RA.get_reg(value_assignments)
        RT_is_vec = isinstance(self.RT, VecArg)

        if self.mem == val:
            yield GlobalMem.GlobalMem
            return
        if self.RA == val:
            # RA only has to avoid RT when RT is a vector
            conflicting_regs = set()
            if RT_is_vec:
                conflicting_regs = to_reg_set(self.RT.try_get_range(
                    value_assignments, allow_unassigned=True,
                    raise_if_invalid=True))
            yield from self.RA.possible_reg_assignments(value_assignments,
                                                        conflicting_regs)
            return
        if RT_is_vec:
            yield from self.RT.possible_reg_assignments(
                val, value_assignments,
                conflicting_regs=to_reg_set(RA_reg))
            return
        yield from self.RT.possible_reg_assignments(value_assignments)
710
711
@plain_data(unsafe_hash=True, frozen=True)
@final
class OpStore(Op):
    """a store with inputs `RS` (vector or scalar), `RA` and `mem_in`, an
    integer `offset`, and output `mem_out`.
    """
    __slots__ = ("RS", "RA", "offset", "mem_in", "mem_out")

    def __init__(self, RS, RA, offset, mem_in):
        # type: (VecArg | SSAGPRVal, SSAGPRVal, int, SSAMemory) -> None
        self.RS = RS
        self.RA = RA
        self.offset = offset
        self.mem_in = mem_in
        # the output memory value is modelled on the input one
        self.mem_out = mem_in.like(op=self, arg_name="mem_out")

    def inputs(self):
        # type: () -> dict[str, VecArg | SSAVal]
        return {"RS": self.RS, "RA": self.RA, "mem_in": self.mem_in}

    def outputs(self):
        # type: () -> dict[str, VecArg | SSAVal]
        return {"mem_out": self.mem_out}

    def possible_reg_assignments(self, val, value_assignments):
        # type: (SSAVal, dict[SSAVal, PhysLoc]) -> Iterable[PhysLoc]
        """yield the locations `val` may be assigned to.

        raises ValueError (once iteration starts) when `val` isn't an
        operand of this Op or already has a location.
        """
        is_operand = (val in self.input_ssa_vals()
                      or val in self.output_ssa_vals())
        if not is_operand:
            raise ValueError(f"{val} must be an operand of {self}")
        if val.get_phys_loc(value_assignments) is not None:
            raise ValueError(f"{val} already assigned a physical location")

        if self.mem_in == val or self.mem_out == val:
            yield GlobalMem.GlobalMem
            return
        if self.RA == val:
            yield from self.RA.possible_reg_assignments(value_assignments)
            return
        if isinstance(self.RS, VecArg):
            yield from self.RS.possible_reg_assignments(val, value_assignments)
            return
        yield from self.RS.possible_reg_assignments(value_assignments)

    def get_equality_constraints(self):
        # type: () -> Iterable[EqualityConstraint]
        """yield the constraint tying `mem_in` to `mem_out`."""
        yield EqualityConstraint(self.mem_in, self.mem_out)
753
754
@plain_data(unsafe_hash=True, frozen=True)
@final
class OpFuncArg(Op):
    """an Op with no inputs whose output `out` is created with the given
    location(s): a scalar for a single GPROrStackLoc, otherwise a VecArg
    with one element per location.
    """
    __slots__ = ("out",)

    def __init__(self, phys_loc):
        # type: (GPROrStackLoc | Iterable[GPROrStackLoc]) -> None
        if isinstance(phys_loc, GPROrStackLoc):
            out = SSAGPRVal(self, "out", 0, phys_loc)
        else:
            out = VecArg(
                SSAGPRVal(self, "out", idx, loc)
                for idx, loc in enumerate(phys_loc))
        self.out = out

    def inputs(self):
        # type: () -> dict[str, VecArg | SSAVal]
        return {}

    def outputs(self):
        # type: () -> dict[str, VecArg | SSAVal]
        return {"out": self.out}

    def possible_reg_assignments(self, val, value_assignments):
        # type: (SSAVal, dict[SSAVal, PhysLoc]) -> Iterable[PhysLoc]
        """yield the locations `val` may be assigned to.

        raises ValueError (once iteration starts) when `val` isn't an
        operand of this Op or already has a location.
        """
        is_operand = (val in self.input_ssa_vals()
                      or val in self.output_ssa_vals())
        if not is_operand:
            raise ValueError(f"{val} must be an operand of {self}")
        if val.get_phys_loc(value_assignments) is not None:
            raise ValueError(f"{val} already assigned a physical location")

        out = self.out
        if not isinstance(out, VecArg):
            yield from out.possible_reg_assignments(value_assignments)
            return
        yield from out.possible_reg_assignments(val, value_assignments)
789
790
@plain_data(unsafe_hash=True, frozen=True)
@final
class OpInputMem(Op):
    """an Op with no inputs whose single output `out` is an SSAMemory."""
    __slots__ = ("out",)

    def __init__(self):
        # type: () -> None
        self.out = SSAMemory(op=self, arg_name="out", element_index=0)

    def inputs(self):
        # type: () -> dict[str, VecArg | SSAVal]
        return {}

    def outputs(self):
        # type: () -> dict[str, VecArg | SSAVal]
        return {"out": self.out}

    def possible_reg_assignments(self, val, value_assignments):
        # type: (SSAVal, dict[SSAVal, PhysLoc]) -> Iterable[PhysLoc]
        """yield `GlobalMem.GlobalMem`, the only location accepted for `val`.

        raises ValueError (once iteration starts) when `val` isn't an
        operand of this Op or already has a location.
        """
        is_operand = (val in self.input_ssa_vals()
                      or val in self.output_ssa_vals())
        if not is_operand:
            raise ValueError(f"{val} must be an operand of {self}")
        if val.get_phys_loc(value_assignments) is not None:
            raise ValueError(f"{val} already assigned a physical location")

        yield GlobalMem.GlobalMem
817
818
def op_set_to_list(ops):
    # type: (Iterable[Op]) -> list[Op]
    """order `ops` so that every Op comes after the Ops writing its inputs.

    `worklists[n]` holds the Ops that still wait for `n` distinct input
    values; an Op is emitted once it reaches `worklists[0]`.

    raises ValueError when two emitted Ops write the same SSA value, or when
    an Op can never be emitted because it is part of a dependency loop or
    one of its inputs is never written.
    """
    worklists = [set()]  # type: list[set[Op]]
    input_vals_to_ops_map = defaultdict(set)  # type: dict[SSAVal, set[Op]]
    ops_to_pending_input_count_map = {}  # type: dict[Op, int]
    for op in ops:
        # an Op may read the same value through several operands; it is
        # only notified once when that value becomes ready, so it has to be
        # counted once too -- otherwise the Op never reaches zero pending
        # inputs and gets reported as a dependency loop
        input_vals = set(op.input_ssa_vals())
        for val in input_vals:
            input_vals_to_ops_map[val].add(op)
        input_count = len(input_vals)
        while len(worklists) <= input_count:
            worklists.append(set())
        ops_to_pending_input_count_map[op] = input_count
        worklists[input_count].add(op)
    retval = []  # type: list[Op]
    ready_vals = set()  # type: set[SSAVal]
    while len(worklists[0]) != 0:
        writing_op = worklists[0].pop()
        retval.append(writing_op)
        for val in writing_op.output_ssa_vals():
            if val in ready_vals:
                raise ValueError(f"multiple instructions must not write "
                                 f"to the same SSA value: {val}")
            ready_vals.add(val)
            # move every reader of val one worklist closer to being ready
            for reading_op in input_vals_to_ops_map[val]:
                pending = ops_to_pending_input_count_map[reading_op]
                worklists[pending].remove(reading_op)
                pending -= 1
                worklists[pending].add(reading_op)
                ops_to_pending_input_count_map[reading_op] = pending
    # anything still waiting can never become ready
    for worklist in worklists:
        for op in worklist:
            raise ValueError(f"instruction is part of a dependency loop or "
                             f"its inputs are never written: {op}")
    return retval
854
855
@plain_data(unsafe_hash=True, order=True, frozen=True)
class LiveInterval:
    """the span of indexes from a value's first write to its last use."""
    __slots__ = ("first_write", "last_use")

    def __init__(self, first_write, last_use=None):
        # type: (int, int | None) -> None
        # without an explicit last use the interval covers just the write
        last_use = first_write if last_use is None else last_use
        if last_use < first_write:
            raise ValueError("uses must be after first_write")
        if first_write < 0 or last_use < 0:
            raise ValueError("indexes must be nonnegative")
        self.first_write = first_write
        self.last_use = last_use

    def overlaps(self, other):
        # type: (LiveInterval) -> bool
        """return True when self and `other` are written at the same index,
        or when each one is last used after the other is first written.
        """
        if self.first_write == other.first_write:
            return True
        return other.first_write < self.last_use \
            and self.first_write < other.last_use

    def __add__(self, use):
        # type: (int) -> LiveInterval
        """return a new interval whose last use is at least `use`."""
        return LiveInterval(first_write=self.first_write,
                            last_use=max(self.last_use, use))
882
883
@final
class EqualitySet(AbstractSet[SSAVal]):
    """a hashable, immutable set of SSAVals backed by a frozenset."""

    def __init__(self, items):
        # type: (Iterable[SSAVal]) -> None
        self.__items = frozenset(items)

    def __len__(self):
        return len(self.__items)

    def __iter__(self):
        return iter(self.__items)

    def __contains__(self, x):
        # type: (object) -> bool
        return x in self.__items

    def __hash__(self):
        # use the hash helper provided by the abstract set base class
        return super()._hash()
902
903
@final
class EqualitySets(Mapping[SSAVal, EqualitySet]):
    """maps every SSAVal read or written by `ops` to the EqualitySet it
    belongs to; values linked by an Op's equality constraints share a set.
    """

    def __init__(self, ops):
        # type: (Iterable[Op]) -> None
        indexes = {}  # type: dict[SSAVal, int]
        sets = []  # type: list[set[SSAVal]]
        for op in ops:
            # every value starts out in a set of its own
            for val in (*op.input_ssa_vals(), *op.output_ssa_vals()):
                if val not in indexes:
                    indexes[val] = len(sets)
                    sets.append({val})
            # merge the set of rhs into the set of lhs and repoint the
            # moved values; the old rhs set stays in `sets` but is no longer
            # referenced by `indexes`
            for e in op.get_equality_constraints():
                lhs_index = indexes[e.lhs]
                rhs_index = indexes[e.rhs]
                sets[lhs_index] |= sets[rhs_index]
                for val in sets[rhs_index]:
                    indexes[val] = lhs_index

        equality_sets = [EqualitySet(i) for i in sets]
        self.__map = {k: equality_sets[v] for k, v in indexes.items()}

    def __getitem__(self, key):
        # type: (SSAVal) -> EqualitySet
        return self.__map[key]

    def __iter__(self):
        return iter(self.__map)

    def __len__(self):
        # type: () -> int
        # Mapping declares __len__ abstract; without it this class can't be
        # instantiated
        return len(self.__map)
931
932
@final
class LiveIntervals(Mapping[EqualitySet, LiveInterval]):
    """maps each EqualitySet of `ops` to the LiveInterval covering the
    indexes of all Ops that read or write one of its values.
    """

    def __init__(self, ops):
        # type: (list[Op]) -> None
        self.__equality_sets = eqsets = EqualitySets(ops)
        live_intervals = {}  # type: dict[EqualitySet, LiveInterval]
        for op_idx, op in enumerate(ops):
            for val in op.input_ssa_vals():
                # the set must already have an interval, i.e. an earlier Op
                # in the list wrote one of its values; KeyError otherwise
                live_intervals[eqsets[val]] += op_idx
            for val in op.output_ssa_vals():
                if eqsets[val] not in live_intervals:
                    live_intervals[eqsets[val]] = LiveInterval(op_idx)
                else:
                    live_intervals[eqsets[val]] += op_idx
        self.__live_intervals = live_intervals

    @property
    def equality_sets(self):
        return self.__equality_sets

    def __getitem__(self, key):
        # type: (EqualitySet) -> LiveInterval
        return self.__live_intervals[key]

    def __iter__(self):
        return iter(self.__live_intervals)

    def __len__(self):
        # type: () -> int
        # Mapping declares __len__ abstract; without it this class can't be
        # instantiated
        return len(self.__live_intervals)
959
960
@final
class IGNode:
    """ interference graph node """
    __slots__ = ("equality_set", "edges")

    def __init__(self, equality_set, edges=()):
        # type: (EqualitySet, Iterable[IGNode]) -> None
        self.equality_set = equality_set
        self.edges = set(edges)

    def add_edge(self, other):
        # type: (IGNode) -> None
        """link self and `other` in both directions."""
        for node, neighbor in ((self, other), (other, self)):
            node.edges.add(neighbor)

    def __eq__(self, other):
        # type: (object) -> bool
        # nodes are equal when their equality sets are
        if not isinstance(other, IGNode):
            return NotImplemented
        return self.equality_set == other.equality_set

    def __hash__(self):
        return self.equality_set.__hash__()

    def __repr__(self, nodes=None):
        # type: (None | dict[IGNode, int]) -> str
        # `nodes` numbers the nodes already printed, so that a node reached
        # again through an edge is shown as a short back-reference instead
        # of recursing forever
        if nodes is None:
            nodes = {}
        if self in nodes:
            return f"<IGNode #{nodes[self]}>"
        number = nodes[self] = len(nodes)
        edge_reprs = [edge.__repr__(nodes) for edge in self.edges]
        edges = "{" + ", ".join(edge_reprs) + "}"
        return (f"IGNode(#{number}, "
                f"equality_set={self.equality_set}, "
                f"edges={edges})")
996
997
@final
class InterferenceGraph(Mapping[EqualitySet, IGNode]):
    """maps each EqualitySet to its own IGNode; nodes start without edges."""

    def __init__(self, equality_sets):
        # type: (Iterable[EqualitySet]) -> None
        self.__nodes = {i: IGNode(i) for i in equality_sets}

    def __getitem__(self, key):
        # type: (EqualitySet) -> IGNode
        return self.__nodes[key]

    def __iter__(self):
        return iter(self.__nodes)

    def __len__(self):
        # type: () -> int
        # Mapping declares __len__ abstract; without it this class can't be
        # instantiated
        return len(self.__nodes)
1010
1011
@plain_data()
class AllocationFailed:
    """record of an allocation failure: the index of an Op, one of its
    arguments, and the live intervals in use.
    """
    __slots__ = ("op_idx", "arg", "live_intervals")

    def __init__(self, op_idx, arg, live_intervals):
        # type: (int, SSAVal | VecArg, LiveIntervals) -> None
        self.op_idx = op_idx
        self.arg = arg
        self.live_intervals = live_intervals
1021
1022
def try_allocate_registers_without_spilling(ops):
    # type: (list[Op]) -> dict[SSAVal, PhysLoc] | AllocationFailed
    """unfinished: computes the live intervals of `ops`, then raises
    NotImplementedError.
    """

    def is_constrained(node):
        # type: (EqualitySet) -> bool
        raise NotImplementedError()

    live_intervals = LiveIntervals(ops)

    raise NotImplementedError()
1032
1033
def allocate_registers(ops):
    # type: (list[Op]) -> None
    """placeholder: always raises NotImplementedError."""
    raise NotImplementedError()