2209d8b780f3189588dbe7b1317157615d8d836f
[bigint-presentation-code.git] / src / bigint_presentation_code / toom_cook.py
1 from abc import ABCMeta, abstractmethod
2 import builtins
3 from collections import defaultdict
4 from enum import Enum, unique
5 from typing import Iterable, Mapping, TYPE_CHECKING
6
7 from nmutil.plain_data import plain_data
8
if not TYPE_CHECKING:
    def final(v):
        """Runtime stand-in for the `final` decorator: hand `v` back as-is."""
        return v
else:
    # only static type checkers see the real decorator
    from typing_extensions import final
14
15
@plain_data(frozen=True, unsafe_hash=True)
class PhysLoc:
    """Field-less root class that the physical-location types derive from."""
19
20
@plain_data(frozen=True, unsafe_hash=True)
class GPROrStackLoc(PhysLoc):
    """Field-less `PhysLoc` subclass shared by `GPR` and `StackSlot`."""
24
25
@final
class GPR(GPROrStackLoc, Enum):
    """Enumeration of the registers `R0` through `R127`.

    Each member's value is an integer; `__init__` stores that same integer
    on the member as `reg_num`.
    """

    def __init__(self, reg_num):
        # type: (int) -> None
        # Enum passes the member's value to `__init__`
        self.reg_num = reg_num
    # fmt: off
    R0 = 0; R1 = 1; R2 = 2; R3 = 3; R4 = 4; R5 = 5
    R6 = 6; R7 = 7; R8 = 8; R9 = 9; R10 = 10; R11 = 11
    R12 = 12; R13 = 13; R14 = 14; R15 = 15; R16 = 16; R17 = 17
    R18 = 18; R19 = 19; R20 = 20; R21 = 21; R22 = 22; R23 = 23
    R24 = 24; R25 = 25; R26 = 26; R27 = 27; R28 = 28; R29 = 29
    R30 = 30; R31 = 31; R32 = 32; R33 = 33; R34 = 34; R35 = 35
    R36 = 36; R37 = 37; R38 = 38; R39 = 39; R40 = 40; R41 = 41
    R42 = 42; R43 = 43; R44 = 44; R45 = 45; R46 = 46; R47 = 47
    R48 = 48; R49 = 49; R50 = 50; R51 = 51; R52 = 52; R53 = 53
    R54 = 54; R55 = 55; R56 = 56; R57 = 57; R58 = 58; R59 = 59
    R60 = 60; R61 = 61; R62 = 62; R63 = 63; R64 = 64; R65 = 65
    R66 = 66; R67 = 67; R68 = 68; R69 = 69; R70 = 70; R71 = 71
    R72 = 72; R73 = 73; R74 = 74; R75 = 75; R76 = 76; R77 = 77
    R78 = 78; R79 = 79; R80 = 80; R81 = 81; R82 = 82; R83 = 83
    R84 = 84; R85 = 85; R86 = 86; R87 = 87; R88 = 88; R89 = 89
    R90 = 90; R91 = 91; R92 = 92; R93 = 93; R94 = 94; R95 = 95
    R96 = 96; R97 = 97; R98 = 98; R99 = 99; R100 = 100; R101 = 101
    R102 = 102; R103 = 103; R104 = 104; R105 = 105; R106 = 106; R107 = 107
    R108 = 108; R109 = 109; R110 = 110; R111 = 111; R112 = 112; R113 = 113
    R114 = 114; R115 = 115; R116 = 116; R117 = 117; R118 = 118; R119 = 119
    R120 = 120; R121 = 121; R122 = 122; R123 = 123; R124 = 124; R125 = 125
    R126 = 126; R127 = 127
    # fmt: on
    # `SP` and `TOC` repeat the values 1 and 2, so Enum treats them as
    # aliases of `R1` and `R2` rather than as additional members
    SP = 1
    TOC = 2
57
58
# registers that try_allocate_registers_without_spilling removes from its
# pool of free registers
SPECIAL_GPRS = (GPR.R0, GPR.SP, GPR.TOC, GPR.R13)
60
61
@final
@unique
class XERBit(PhysLoc, Enum):
    """XER bits usable as physical locations; `CY` is the only member."""
    # Mixin classes have to precede `Enum` in the base list: the enum
    # machinery raises TypeError at class creation when the last base is
    # not an Enum. This is the same order that `GPR` uses.
    CY = "CY"
66
67
@final
@unique
class GlobalMem(PhysLoc, Enum):
    """singleton representing all non-StackSlot memory"""
    # Mixin classes have to precede `Enum` in the base list: the enum
    # machinery raises TypeError at class creation when the last base is
    # not an Enum. This is the same order that `GPR` uses.
    GlobalMem = "GlobalMem"
73
74
@plain_data()
@final
class StackSlot(GPROrStackLoc):
    """A slot on the stack; OpCopy is what loads from / stores into it."""
    __slots__ = ("offset",)

    def __init__(self, offset=None):
        # type: (int | None) -> None
        # `offset` defaults to None when the caller does not supply one
        self.offset = offset
84
85
@plain_data(eq=False)
class SSAVal(metaclass=ABCMeta):
    """Abstract SSA value, compared and hashed by its `id` field."""
    __slots__ = ("id",)

    def __init__(self, id=None):
        # type: (int | None) -> None
        # without an explicit id, fall back to the object's identity
        self.id = builtins.id(self) if id is None else id

    def __eq__(self, rhs):
        # only another SSAVal with the same id compares equal
        return isinstance(rhs, SSAVal) and self.id == rhs.id

    def __hash__(self):
        return hash(self.id)

    def _get_phys_loc(self, phys_loc_in, value_assignments=None):
        # type: (PhysLoc | None, dict[SSAVal, PhysLoc] | None) -> PhysLoc | None
        """Return `phys_loc_in` when it is set, otherwise whatever
        `value_assignments` holds for this value (None if there is nothing).
        """
        if phys_loc_in is not None:
            return phys_loc_in
        if value_assignments is None:
            return None
        return value_assignments.get(self)

    @abstractmethod
    def get_phys_loc(self, value_assignments=None):
        # type: (dict[SSAVal, PhysLoc] | None) -> PhysLoc | None
        """Return this value's physical location, or None."""
        ...
116
117
@plain_data(eq=False)
@final
class SSAGPRVal(SSAVal):
    """SSA value whose physical location is a `GPROrStackLoc`."""
    __slots__ = ("phys_loc",)

    def __init__(self, phys_loc=None):
        # type: (GPROrStackLoc | None) -> None
        self.phys_loc = phys_loc
        super().__init__()

    def __len__(self):
        # a scalar value always counts as one element
        return 1

    def get_phys_loc(self, value_assignments=None):
        # type: (dict[SSAVal, PhysLoc] | None) -> GPROrStackLoc | None
        """Return the location if it is a `GPROrStackLoc`, else None."""
        found = self._get_phys_loc(self.phys_loc, value_assignments)
        return found if isinstance(found, GPROrStackLoc) else None

    def get_reg_num(self, value_assignments=None):
        # type: (dict[SSAVal, PhysLoc] | None) -> int | None
        """Return the register number, or None when not in a GPR."""
        gpr = self.get_reg(value_assignments)
        return None if gpr is None else gpr.reg_num

    def get_reg(self, value_assignments=None):
        # type: (dict[SSAVal, PhysLoc] | None) -> GPR | None
        """Return the location if it is a `GPR`, else None."""
        found = self.get_phys_loc(value_assignments)
        return found if isinstance(found, GPR) else None

    def get_stack_slot(self, value_assignments=None):
        # type: (dict[SSAVal, PhysLoc] | None) -> StackSlot | None
        """Return the location if it is a `StackSlot`, else None."""
        found = self.get_phys_loc(value_assignments)
        return found if isinstance(found, StackSlot) else None

    def possible_reg_assignments(self, value_assignments,
                                 conflicting_regs=set()):
        # type: (dict[SSAVal, PhysLoc] | None, set[GPR]) -> Iterable[GPR]
        """Yield every `GPR` member that is not in `conflicting_regs`.

        Raises ValueError (on first iteration) if this value already has a
        physical location.
        """
        if self.get_phys_loc(value_assignments) is not None:
            raise ValueError("can't assign a already-assigned SSA value")
        yield from (gpr for gpr in GPR if gpr not in conflicting_regs)
167
168
@plain_data(eq=False)
@final
class SSAXERBitVal(SSAVal):
    """SSA value whose physical location is an `XERBit`."""
    __slots__ = "phys_loc",

    def __init__(self, phys_loc=None):
        # type: (XERBit | None) -> None
        self.phys_loc = phys_loc
        # SSAVal.__init__ assigns `id`, which __eq__ and __hash__ read;
        # without this call the value cannot be compared or used as a dict
        # key (SSAGPRVal makes the same call)
        super().__init__()

    def get_phys_loc(self, value_assignments=None):
        # type: (dict[SSAVal, PhysLoc] | None) -> XERBit | None
        """Return the location if it is an `XERBit`, else None."""
        loc = self._get_phys_loc(self.phys_loc, value_assignments)
        if isinstance(loc, XERBit):
            return loc
        return None
184
185
@plain_data(eq=False)
@final
class SSAMemory(SSAVal):
    """SSA value whose physical location is a `GlobalMem`."""
    __slots__ = "phys_loc",

    def __init__(self, phys_loc=GlobalMem.GlobalMem):
        # type: (GlobalMem) -> None
        self.phys_loc = phys_loc
        # SSAVal.__init__ assigns `id`, which __eq__ and __hash__ read;
        # without this call the value cannot be compared or used as a dict
        # key (SSAGPRVal makes the same call)
        super().__init__()

    def get_phys_loc(self, value_assignments=None):
        # type: (dict[SSAVal, PhysLoc] | None) -> GlobalMem | None
        """Return the location if it is a `GlobalMem`, else None."""
        loc = self._get_phys_loc(self.phys_loc, value_assignments)
        if isinstance(loc, GlobalMem):
            return loc
        return None
201
202
@plain_data(unsafe_hash=True, frozen=True)
@final
class VecArg:
    """Vector operand: a tuple of `SSAGPRVal` elements stored in `regs`."""
    __slots__ = ("regs",)

    def __init__(self, regs):
        # type: (Iterable[SSAGPRVal]) -> None
        self.regs = tuple(regs)

    def __len__(self):
        return len(self.regs)

    def is_unassigned(self, value_assignments=None):
        # type: (dict[SSAVal, PhysLoc] | None) -> bool
        """Return True when no element has a physical location."""
        return all(elem.get_phys_loc(value_assignments) is None
                   for elem in self.regs)

    def try_get_range(self, value_assignments=None, allow_unassigned=False,
                      raise_if_invalid=False):
        # type: (dict[SSAVal, PhysLoc] | None, bool, bool) -> range | None
        """Work out the range of register numbers the elements occupy.

        Element `i` must sit in register `start + i`. Unassigned elements
        are skipped when `allow_unassigned` is true and make the vector
        invalid otherwise. An invalid vector gives None, or ValueError when
        `raise_if_invalid` is true. An empty vector gives `range(0)`; a
        vector with only skipped elements gives None.
        """
        count = len(self.regs)
        if count == 0:
            return range(0)

        def invalid(message):
            # type: (str) -> None
            if raise_if_invalid:
                raise ValueError(message)
            return None

        found = None  # type: range | None
        for index, elem in enumerate(self.regs):
            if elem.get_phys_loc(value_assignments) is None:
                if allow_unassigned:
                    continue
                return invalid("not a valid register range: "
                               "unassigned SSA value encountered")
            reg_num = elem.get_reg_num(value_assignments)
            if reg_num is None:
                return invalid("not a valid register range: "
                               "non-register encountered")
            # the range this element implies for the whole vector
            first = reg_num - index
            implied = range(first, first + count)
            if found is None:
                found = implied
            elif found != implied:
                return invalid("not a valid register range: "
                               "register out of sequence")
        return found

    def possible_reg_assignments(
        self,
        val,  # type: SSAVal
        value_assignments,  # type: dict[SSAVal, PhysLoc] | None
        conflicting_regs=set(),  # type: set[GPR]
    ):  # type: (...) -> Iterable[GPR]
        """Yield the registers that element `val` of this vector may use.

        Candidate ranges start at multiples of the vector length rounded up
        to a power of two; ranges touching `conflicting_regs` are skipped.
        If the vector already forms a valid range, only that range's
        register for `val` is yielded (ValueError if it is misaligned).
        """
        position = self.regs.index(val)
        alignment = 1
        while alignment < len(self.regs):
            alignment <<= 1
        existing = self.try_get_range(value_assignments)
        if existing is not None:
            if existing.start % alignment != 0:
                raise ValueError("must be a ascending aligned range of GPRs")
            yield GPR(existing[position])
            return
        for first in range(0, len(GPR), alignment):
            candidate = range(first, first + len(self.regs))
            if not any(GPR(num) in conflicting_regs for num in candidate):
                yield GPR(candidate[position])
274
275
@final
@plain_data(unsafe_hash=True, frozen=True)
class EqualityConstraint:
    """Pair of SSA values (`lhs`, `rhs`) yielded by
    `Op.get_equality_constraints`."""
    __slots__ = ("lhs", "rhs")

    def __init__(self, lhs, rhs):
        # type: (SSAVal, SSAVal) -> None
        self.lhs, self.rhs = lhs, rhs
285
286
@plain_data(unsafe_hash=True, frozen=True)
class Op(metaclass=ABCMeta):
    """Abstract base class of all operations.

    Subclasses provide named `inputs()` and `outputs()` and enumerate the
    physical locations each operand may use.
    """
    __slots__ = ()

    def input_ssa_vals(self):
        # type: () -> Iterable[SSAVal]
        """Yield every input SSA value, expanding vectors element-wise."""
        for operand in self.inputs().values():
            if not isinstance(operand, VecArg):
                yield operand
                continue
            yield from operand.regs

    def output_ssa_vals(self):
        # type: () -> Iterable[SSAVal]
        """Yield every output SSA value, expanding vectors element-wise."""
        for operand in self.outputs().values():
            if not isinstance(operand, VecArg):
                yield operand
                continue
            yield from operand.regs

    @abstractmethod
    def inputs(self):
        # type: () -> dict[str, VecArg | SSAVal]
        """Return the input operands keyed by name."""
        ...

    @abstractmethod
    def outputs(self):
        # type: () -> dict[str, VecArg | SSAVal]
        """Return the output operands keyed by name."""
        ...

    @abstractmethod
    def possible_reg_assignments(self, val, value_assignments):
        # type: (SSAVal, dict[SSAVal, PhysLoc]) -> Iterable[PhysLoc]
        """Return the physical locations operand `val` may be assigned."""
        ...

    def get_equality_constraints(self):
        # type: () -> Iterable[EqualityConstraint]
        """Yield this op's equality constraints; none by default."""
        # an empty generator
        yield from ()

    def __init__(self):
        pass
329
330
@plain_data(unsafe_hash=True, frozen=True)
@final
class OpCopy(Op):
    """Copy `src` into `dest`.

    Both operands must be vectors of equal length, or otherwise have
    exactly the same type.
    """
    __slots__ = "dest", "src"

    def inputs(self):
        # type: () -> dict[str, VecArg | SSAVal]
        return {"src": self.src}

    def outputs(self):
        # type: () -> dict[str, VecArg | SSAVal]
        return {"dest": self.dest}

    def __init__(self, dest, src):
        # type: (VecArg | SSAVal, VecArg | SSAVal) -> None
        """Raises TypeError when the operand kinds or lengths differ."""
        if isinstance(dest, VecArg) and isinstance(src, VecArg):
            if len(src.regs) != len(dest.regs):
                raise TypeError(f"source length must match dest "
                                f"length: {src} doesn't match {dest}")
        elif type(dest) != type(src):
            raise TypeError(f"source argument type must match dest "
                            f"argument type: {src} doesn't match {dest}")
        self.dest = dest
        self.src = src

    def possible_reg_assignments(self, val, value_assignments):
        # type: (SSAVal, dict[SSAVal, PhysLoc]) -> Iterable[PhysLoc]
        """Return the registers that the unassigned operand `val` may use.

        For an element of a vector destination, registers already taken by
        other destination elements are excluded.

        Raises ValueError if `val` is not an operand, is already assigned,
        is not an `SSAGPRVal`, or if two destination elements share a
        location.
        """
        if val not in self.input_ssa_vals() \
                and val not in self.output_ssa_vals():
            raise ValueError(f"{val} must be an operand of {self}")
        if val.get_phys_loc(value_assignments) is not None:
            raise ValueError(f"{val} already assigned a physical location")
        conflicting_regs = set()  # type: set[GPR]
        if val in self.output_ssa_vals() and isinstance(self.dest, VecArg):
            # OpCopy is the only op that can write to physical locations in
            # any order, it handles figuring out the right instruction sequence
            dest_locs = {}  # type: dict[GPROrStackLoc, SSAVal]
            # the loop variable must not be named `val`: that would replace
            # the operand being queried with the last destination element
            for dest_val in self.dest.regs:
                loc = dest_val.get_phys_loc(value_assignments)
                if loc is None:
                    continue
                if loc in dest_locs:
                    raise ValueError(
                        f"duplicate destination location not allowed: "
                        f"{dest_val} is assigned to {loc} which is also "
                        f"written by {dest_locs[loc]}")
                dest_locs[loc] = dest_val
                if isinstance(loc, GPR):
                    conflicting_regs.add(loc)
        if not isinstance(val, SSAGPRVal):
            raise ValueError("invalid operand type")
        return val.possible_reg_assignments(value_assignments,
                                            conflicting_regs)
384
385
def range_overlaps(range1, range2):
    # type: (range, range) -> bool
    """Return True if `range1` and `range2` have at least one common element.

    An empty range overlaps nothing.
    """
    if len(range1) == 0 or len(range2) == 0:
        return False
    if abs(range1.step) == 1 and abs(range2.step) == 1:
        # both are contiguous runs of integers, which intersect exactly
        # when an endpoint of one lies inside the other
        return (range1.start in range2 or range1[-1] in range2 or
                range2.start in range1 or range2[-1] in range1)
    # strided ranges can share an interior element without containing each
    # other's endpoints, so test the members of the shorter range one by one
    # (membership of an int in a range is a constant-time check)
    shorter, longer = sorted((range1, range2), key=len)
    return any(item in longer for item in shorter)
394
395
@plain_data(unsafe_hash=True, frozen=True)
@final
class OpAddSubE(Op):
    """Op with vector operands `RT`, `RA`, `RB`, carry values `CY_in` and
    `CY_out`, and an `is_sub` flag.

    Inputs are `RA`, `RB` and `CY_in`; outputs are `RT` and `CY_out`.
    """
    __slots__ = ("RT", "RA", "RB", "CY_in", "CY_out", "is_sub")

    def inputs(self):
        # type: () -> dict[str, VecArg | SSAVal]
        return dict(RA=self.RA, RB=self.RB, CY_in=self.CY_in)

    def outputs(self):
        # type: () -> dict[str, VecArg | SSAVal]
        return dict(RT=self.RT, CY_out=self.CY_out)

    def __init__(self, RT, RA, RB, CY_in, CY_out, is_sub):
        # type: (VecArg, VecArg, VecArg, SSAXERBitVal, SSAXERBitVal, bool) -> None
        """Raises TypeError unless `RA` and `RB` are as long as `RT`."""
        for src in (RA, RB):
            if len(src.regs) != len(RT.regs):
                raise TypeError(f"source length must match dest "
                                f"length: {src} doesn't match {RT}")
        self.RT = RT
        self.RA = RA
        self.RB = RB
        self.CY_in = CY_in
        self.CY_out = CY_out
        self.is_sub = is_sub

    def possible_reg_assignments(self, val, value_assignments):
        # type: (SSAVal, dict[SSAVal, PhysLoc]) -> Iterable[PhysLoc]
        """Yield the physical locations the unassigned operand `val` may use.

        Raises ValueError (on first iteration) if `val` is not an operand
        of this op or already has a location.
        """
        is_operand = (val in self.input_ssa_vals()
                      or val in self.output_ssa_vals())
        if not is_operand:
            raise ValueError(f"{val} must be an operand of {self}")
        if val.get_phys_loc(value_assignments) is not None:
            raise ValueError(f"{val} already assigned a physical location")
        if self.CY_in == val or self.CY_out == val:
            yield XERBit.CY
            return
        # Vectors are only ever offered aligned ranges, so an RT assignment
        # either coincides with an input vector or is disjoint from it; a
        # partial overlap that would overwrite input elements before they
        # are read cannot occur.
        if val in self.RT.regs:
            vec = self.RT
        elif val in self.RA.regs:
            vec = self.RA
        else:
            vec = self.RB
        yield from vec.possible_reg_assignments(val, value_assignments)

    def get_equality_constraints(self):
        # type: () -> Iterable[EqualityConstraint]
        """Yield the constraint tying `CY_in` to `CY_out`."""
        yield EqualityConstraint(self.CY_in, self.CY_out)
448
449
def to_reg_set(v):
    # type: (None | GPR | range) -> set[GPR]
    """Normalise `v` into a set of registers.

    A range is mapped element-wise through `GPR`, None becomes the empty
    set, and anything else becomes a one-element set.
    """
    if isinstance(v, range):
        return {GPR(reg_num) for reg_num in v}
    return set() if v is None else {v}
457
458
@plain_data(unsafe_hash=True, frozen=True)
@final
class OpBigIntMulDiv(Op):
    """Op with vector operands `RT`/`RA`, scalar operands `RB`/`RC`/`RS`
    and an `is_div` flag.

    Inputs are `RA`, `RB` and `RC`; outputs are `RT` and `RS`. `RC` and
    `RS` are tied together by an equality constraint.
    """
    __slots__ = "RT", "RA", "RB", "RC", "RS", "is_div"

    def inputs(self):
        # type: () -> dict[str, VecArg | SSAVal]
        return {"RA": self.RA, "RB": self.RB, "RC": self.RC}

    def outputs(self):
        # type: () -> dict[str, VecArg | SSAVal]
        return {"RT": self.RT, "RS": self.RS}

    def __init__(self, RT, RA, RB, RC, RS, is_div):
        # type: (VecArg, VecArg, SSAGPRVal, SSAGPRVal, SSAGPRVal, bool) -> None
        """Raises TypeError unless `RA` is as long as `RT`."""
        if len(RA.regs) != len(RT.regs):
            raise TypeError(f"source length must match dest "
                            f"length: {RA} doesn't match {RT}")
        self.RT = RT
        self.RA = RA
        self.RB = RB
        self.RC = RC
        self.RS = RS
        self.is_div = is_div

    def possible_reg_assignments(self, val, value_assignments):
        # type: (SSAVal, dict[SSAVal, PhysLoc]) -> Iterable[PhysLoc]
        """Yield the registers the unassigned operand `val` may use.

        Raises ValueError (on first iteration) if `val` is not an operand
        of this op, already has a location, or if `RT`/`RA` are partially
        assigned to something that is not a valid register range.
        """
        if val not in self.input_ssa_vals() \
                and val not in self.output_ssa_vals():
            raise ValueError(f"{val} must be an operand of {self}")
        if val.get_phys_loc(value_assignments) is not None:
            raise ValueError(f"{val} already assigned a physical location")

        RT_range = self.RT.try_get_range(value_assignments,
                                         allow_unassigned=True,
                                         raise_if_invalid=True)
        RA_range = self.RA.try_get_range(value_assignments,
                                         allow_unassigned=True,
                                         raise_if_invalid=True)
        RB_reg = self.RB.get_reg(value_assignments)
        # RC and RS share one register, so whichever is assigned decides it
        RC_RS_reg = self.RC.get_reg(value_assignments)
        if RC_RS_reg is None:
            RC_RS_reg = self.RS.get_reg(value_assignments)

        if self.RC == val or self.RS == val:
            if RC_RS_reg is not None:
                yield RC_RS_reg
            else:
                # RS is an output, so it must also stay off RB's register
                conflicting_regs = (to_reg_set(RT_range)
                                    | to_reg_set(RA_range)
                                    | to_reg_set(RB_reg))
                yield from self.RC.possible_reg_assignments(value_assignments,
                                                            conflicting_regs)
        elif val in self.RT.regs:
            # since possible_reg_assignments only returns aligned
            # vectors, all possible assignments either are the same as
            # RA or don't overlap with RA and we avoid the incorrect
            # results caused by partial overlaps overwriting input elements
            # before they're read
            yield from self.RT.possible_reg_assignments(
                val, value_assignments,
                conflicting_regs=(to_reg_set(RA_range)
                                  | to_reg_set(RC_RS_reg)
                                  | to_reg_set(RB_reg)))
        elif val in self.RA.regs:
            yield from self.RA.possible_reg_assignments(
                val, value_assignments,
                conflicting_regs=to_reg_set(RT_range) | to_reg_set(RC_RS_reg))
        else:
            # `val` is the scalar input RB. It is not an element of RA, so
            # it must not be looked up there. Keep it out of the registers
            # written by the outputs RT and RS.
            yield from self.RB.possible_reg_assignments(
                value_assignments,
                to_reg_set(RT_range) | to_reg_set(RC_RS_reg))

    def get_equality_constraints(self):
        # type: () -> Iterable[EqualityConstraint]
        """Yield the constraint tying `RC` to `RS`."""
        yield EqualityConstraint(self.RC, self.RS)
526
527
@final
@unique
class ShiftKind(Enum):
    """The kinds of shift an `OpBigIntShift` can carry."""
    # member values are the lower-case spellings of the member names
    Sl = "sl"
    Sr = "sr"
    Sra = "sra"
534
535
@plain_data(unsafe_hash=True, frozen=True)
@final
class OpBigIntShift(Op):
    """Op with vector operands `RT` (output) and `inp` (input), a scalar
    input `sh`, and a `ShiftKind` in `kind`.
    """
    __slots__ = ("RT", "inp", "sh", "kind")

    def inputs(self):
        # type: () -> dict[str, VecArg | SSAVal]
        return dict(inp=self.inp, sh=self.sh)

    def outputs(self):
        # type: () -> dict[str, VecArg | SSAVal]
        return dict(RT=self.RT)

    def __init__(self, RT, inp, sh, kind):
        # type: (VecArg, VecArg, SSAGPRVal, ShiftKind) -> None
        """Raises TypeError unless `inp` is as long as `RT`."""
        if len(inp.regs) != len(RT.regs):
            raise TypeError(f"source length must match dest "
                            f"length: {inp} doesn't match {RT}")
        self.RT = RT
        self.inp = inp
        self.sh = sh
        self.kind = kind

    def possible_reg_assignments(self, val, value_assignments):
        # type: (SSAVal, dict[SSAVal, PhysLoc]) -> Iterable[PhysLoc]
        """Yield the registers the unassigned operand `val` may use.

        Raises ValueError (on first iteration) if `val` is not an operand
        of this op, already has a location, or if `RT`/`inp` are partially
        assigned to something that is not a valid register range.
        """
        is_operand = (val in self.input_ssa_vals()
                      or val in self.output_ssa_vals())
        if not is_operand:
            raise ValueError(f"{val} must be an operand of {self}")
        if val.get_phys_loc(value_assignments) is not None:
            raise ValueError(f"{val} already assigned a physical location")

        out_range = self.RT.try_get_range(value_assignments,
                                          allow_unassigned=True,
                                          raise_if_invalid=True)
        in_range = self.inp.try_get_range(value_assignments,
                                          allow_unassigned=True,
                                          raise_if_invalid=True)
        shift_reg = self.sh.get_reg(value_assignments)

        if self.sh == val:
            # the shift amount must stay out of the output range
            yield from self.sh.possible_reg_assignments(
                value_assignments, to_reg_set(out_range))
            return
        if val in self.RT.regs:
            # Vectors are only ever offered aligned ranges, so partial
            # overlaps that would overwrite input elements before they are
            # read cannot occur.
            avoid = to_reg_set(in_range) | to_reg_set(shift_reg)
            yield from self.RT.possible_reg_assignments(
                val, value_assignments, conflicting_regs=avoid)
            return
        yield from self.inp.possible_reg_assignments(
            val, value_assignments, conflicting_regs=to_reg_set(out_range))
592
593
@plain_data(unsafe_hash=True, frozen=True)
@final
class OpLI(Op):
    """Op with no inputs, an output `out` (vector or scalar) and an
    integer `value`.
    """
    __slots__ = ("out", "value")

    def inputs(self):
        # type: () -> dict[str, VecArg | SSAVal]
        return dict()

    def outputs(self):
        # type: () -> dict[str, VecArg | SSAVal]
        return dict(out=self.out)

    def __init__(self, out, value):
        # type: (VecArg | SSAGPRVal, int) -> None
        self.out = out
        self.value = value

    def possible_reg_assignments(self, val, value_assignments):
        # type: (SSAVal, dict[SSAVal, PhysLoc]) -> Iterable[PhysLoc]
        """Yield the registers the unassigned operand `val` may use.

        Raises ValueError (on first iteration) if `val` is not an operand
        of this op or already has a location.
        """
        is_operand = (val in self.input_ssa_vals()
                      or val in self.output_ssa_vals())
        if not is_operand:
            raise ValueError(f"{val} must be an operand of {self}")
        if val.get_phys_loc(value_assignments) is not None:
            raise ValueError(f"{val} already assigned a physical location")

        target = self.out
        if isinstance(target, VecArg):
            candidates = target.possible_reg_assignments(val,
                                                         value_assignments)
        else:
            candidates = target.possible_reg_assignments(value_assignments)
        yield from candidates
625
626
@plain_data(unsafe_hash=True, frozen=True)
@final
class OpClearCY(Op):
    """Op with no inputs and a single `SSAXERBitVal` output `out`."""
    __slots__ = ("out",)

    def inputs(self):
        # type: () -> dict[str, VecArg | SSAVal]
        return dict()

    def outputs(self):
        # type: () -> dict[str, VecArg | SSAVal]
        return dict(out=self.out)

    def __init__(self, out):
        # type: (SSAXERBitVal) -> None
        self.out = out

    def possible_reg_assignments(self, val, value_assignments):
        # type: (SSAVal, dict[SSAVal, PhysLoc]) -> Iterable[PhysLoc]
        """Yield `XERBit.CY`, the only location offered for the operand.

        Raises ValueError (on first iteration) if `val` is not an operand
        of this op or already has a location.
        """
        is_operand = (val in self.input_ssa_vals()
                      or val in self.output_ssa_vals())
        if not is_operand:
            raise ValueError(f"{val} must be an operand of {self}")
        if val.get_phys_loc(value_assignments) is not None:
            raise ValueError(f"{val} already assigned a physical location")

        yield XERBit.CY
653
654
@plain_data(unsafe_hash=True, frozen=True)
@final
class OpLoad(Op):
    """Op with output `RT` (vector or scalar), inputs `RA` and `mem`, and
    an integer `offset`.
    """
    __slots__ = ("RT", "RA", "offset", "mem")

    def inputs(self):
        # type: () -> dict[str, VecArg | SSAVal]
        return dict(RA=self.RA, mem=self.mem)

    def outputs(self):
        # type: () -> dict[str, VecArg | SSAVal]
        return dict(RT=self.RT)

    def __init__(self, RT, RA, offset, mem):
        # type: (VecArg | SSAGPRVal, SSAGPRVal, int, SSAMemory) -> None
        self.RT = RT
        self.RA = RA
        self.offset = offset
        self.mem = mem

    def possible_reg_assignments(self, val, value_assignments):
        # type: (SSAVal, dict[SSAVal, PhysLoc]) -> Iterable[PhysLoc]
        """Yield the physical locations the unassigned operand `val` may use.

        Raises ValueError (on first iteration) if `val` is not an operand
        of this op or already has a location.
        """
        is_operand = (val in self.input_ssa_vals()
                      or val in self.output_ssa_vals())
        if not is_operand:
            raise ValueError(f"{val} must be an operand of {self}")
        if val.get_phys_loc(value_assignments) is not None:
            raise ValueError(f"{val} already assigned a physical location")

        addr_reg = self.RA.get_reg(value_assignments)

        if self.mem == val:
            yield GlobalMem.GlobalMem
            return
        if self.RA == val:
            # with a vector destination, RA must stay out of the
            # destination's register range
            conflicting_regs = set()
            if isinstance(self.RT, VecArg):
                conflicting_regs = to_reg_set(self.RT.try_get_range(
                    value_assignments, allow_unassigned=True,
                    raise_if_invalid=True))
            yield from self.RA.possible_reg_assignments(value_assignments,
                                                        conflicting_regs)
            return
        if isinstance(self.RT, VecArg):
            # a vector destination must not cover RA's register
            yield from self.RT.possible_reg_assignments(
                val, value_assignments,
                conflicting_regs=to_reg_set(addr_reg))
        else:
            yield from self.RT.possible_reg_assignments(value_assignments)
702
703
@plain_data(unsafe_hash=True, frozen=True)
@final
class OpStore(Op):
    """Op with inputs `RS` (vector or scalar), `RA` and `mem_in`, output
    `mem_out`, and an integer `offset`.

    `mem_in` and `mem_out` are tied together by an equality constraint.
    """
    __slots__ = ("RS", "RA", "offset", "mem_in", "mem_out")

    def inputs(self):
        # type: () -> dict[str, VecArg | SSAVal]
        return dict(RS=self.RS, RA=self.RA, mem_in=self.mem_in)

    def outputs(self):
        # type: () -> dict[str, VecArg | SSAVal]
        return dict(mem_out=self.mem_out)

    def __init__(self, RS, RA, offset, mem_in, mem_out):
        # type: (VecArg | SSAGPRVal, SSAGPRVal, int, SSAMemory, SSAMemory) -> None
        self.RS = RS
        self.RA = RA
        self.offset = offset
        self.mem_in = mem_in
        self.mem_out = mem_out

    def possible_reg_assignments(self, val, value_assignments):
        # type: (SSAVal, dict[SSAVal, PhysLoc]) -> Iterable[PhysLoc]
        """Yield the physical locations the unassigned operand `val` may use.

        Raises ValueError (on first iteration) if `val` is not an operand
        of this op or already has a location.
        """
        is_operand = (val in self.input_ssa_vals()
                      or val in self.output_ssa_vals())
        if not is_operand:
            raise ValueError(f"{val} must be an operand of {self}")
        if val.get_phys_loc(value_assignments) is not None:
            raise ValueError(f"{val} already assigned a physical location")

        if self.mem_in == val or self.mem_out == val:
            yield GlobalMem.GlobalMem
            return
        if self.RA == val:
            yield from self.RA.possible_reg_assignments(value_assignments)
            return
        if isinstance(self.RS, VecArg):
            candidates = self.RS.possible_reg_assignments(val,
                                                          value_assignments)
        else:
            candidates = self.RS.possible_reg_assignments(value_assignments)
        yield from candidates

    def get_equality_constraints(self):
        # type: () -> Iterable[EqualityConstraint]
        """Yield the constraint tying `mem_in` to `mem_out`."""
        yield EqualityConstraint(self.mem_in, self.mem_out)
745
746
@plain_data(unsafe_hash=True, frozen=True)
@final
class OpFuncArg(Op):
    """Op with no inputs and a single output `out` (vector or scalar)."""
    __slots__ = ("out",)

    def inputs(self):
        # type: () -> dict[str, VecArg | SSAVal]
        return dict()

    def outputs(self):
        # type: () -> dict[str, VecArg | SSAVal]
        return dict(out=self.out)

    def __init__(self, out):
        # type: (VecArg | SSAGPRVal) -> None
        self.out = out

    def possible_reg_assignments(self, val, value_assignments):
        # type: (SSAVal, dict[SSAVal, PhysLoc]) -> Iterable[PhysLoc]
        """Yield the registers the unassigned operand `val` may use.

        Raises ValueError (on first iteration) if `val` is not an operand
        of this op or already has a location.
        """
        is_operand = (val in self.input_ssa_vals()
                      or val in self.output_ssa_vals())
        if not is_operand:
            raise ValueError(f"{val} must be an operand of {self}")
        if val.get_phys_loc(value_assignments) is not None:
            raise ValueError(f"{val} already assigned a physical location")

        target = self.out
        if isinstance(target, VecArg):
            candidates = target.possible_reg_assignments(val,
                                                         value_assignments)
        else:
            candidates = target.possible_reg_assignments(value_assignments)
        yield from candidates
777
778
@plain_data(unsafe_hash=True, frozen=True)
@final
class OpInputMem(Op):
    """Op with no inputs and a single `SSAMemory` output `out`."""
    __slots__ = ("out",)

    def inputs(self):
        # type: () -> dict[str, VecArg | SSAVal]
        return dict()

    def outputs(self):
        # type: () -> dict[str, VecArg | SSAVal]
        return dict(out=self.out)

    def __init__(self, out):
        # type: (SSAMemory) -> None
        self.out = out

    def possible_reg_assignments(self, val, value_assignments):
        # type: (SSAVal, dict[SSAVal, PhysLoc]) -> Iterable[PhysLoc]
        """Yield `GlobalMem.GlobalMem`, the only location offered.

        Raises ValueError (on first iteration) if `val` is not an operand
        of this op or already has a location.
        """
        is_operand = (val in self.input_ssa_vals()
                      or val in self.output_ssa_vals())
        if not is_operand:
            raise ValueError(f"{val} must be an operand of {self}")
        if val.get_phys_loc(value_assignments) is not None:
            raise ValueError(f"{val} already assigned a physical location")

        yield GlobalMem.GlobalMem
805
806
def op_set_to_list(ops):
    # type: (Iterable[Op]) -> list[Op]
    """Order `ops` so that every op comes after the ops writing its inputs.

    Ops are bucketed by how many distinct input values are still unwritten;
    an op is emitted once that count reaches zero.

    Raises ValueError if two ops write the same SSA value, or if an op can
    never become ready (dependency loop or an input nobody writes).
    """
    # worklists[n] holds the ops that still wait for n distinct inputs
    worklists = [set()]  # type: list[set[Op]]
    input_vals_to_ops_map = defaultdict(set)  # type: dict[SSAVal, set[Op]]
    ops_to_pending_input_count_map = {}  # type: dict[Op, int]
    for op in ops:
        # Count each distinct input once: a value read through several
        # operands becomes ready only once, so counting it per operand
        # would leave the op pending forever.
        input_vals = set(op.input_ssa_vals())
        for val in input_vals:
            input_vals_to_ops_map[val].add(op)
        input_count = len(input_vals)
        while len(worklists) <= input_count:
            worklists.append(set())
        ops_to_pending_input_count_map[op] = input_count
        worklists[input_count].add(op)
    retval = []  # type: list[Op]
    ready_vals = set()  # type: set[SSAVal]
    while len(worklists[0]) != 0:
        writing_op = worklists[0].pop()
        retval.append(writing_op)
        for val in writing_op.output_ssa_vals():
            if val in ready_vals:
                raise ValueError(f"multiple instructions must not write "
                                 f"to the same SSA value: {val}")
            ready_vals.add(val)
            # move every reader of `val` one bucket closer to ready
            for reading_op in input_vals_to_ops_map[val]:
                pending = ops_to_pending_input_count_map[reading_op]
                worklists[pending].remove(reading_op)
                pending -= 1
                worklists[pending].add(reading_op)
                ops_to_pending_input_count_map[reading_op] = pending
    # anything still waiting can never be scheduled
    for worklist in worklists:
        for op in worklist:
            raise ValueError(f"instruction is part of a dependency loop or "
                             f"its inputs are never written: {op}")
    return retval
842
843
@plain_data(unsafe_hash=True, order=True, frozen=True)
class LiveInterval:
    """Span of op indexes from a value's `assignment` to its `last_use`."""
    __slots__ = ("assignment", "last_use")

    def __init__(self, assignment, last_use=None):
        # type: (int, int | None) -> None
        """`last_use` defaults to `assignment`.

        Raises ValueError if `last_use` precedes `assignment` or if either
        index is negative.
        """
        end = assignment if last_use is None else last_use
        if end < assignment:
            raise ValueError("uses must be after assignment")
        if assignment < 0 or end < 0:
            raise ValueError("indexes must be nonnegative")
        self.assignment = assignment
        self.last_use = end

    def overlaps(self, other):
        # type: (LiveInterval) -> bool
        """Return True if the intervals start at the same index, or if each
        one ends strictly after the other one starts."""
        return self.assignment == other.assignment or (
            self.last_use > other.assignment
            and other.last_use > self.assignment)

    def __add__(self, use):
        # type: (int) -> LiveInterval
        """Return a new interval extended to cover the use at index `use`."""
        return LiveInterval(assignment=self.assignment,
                            last_use=max(self.last_use, use))
870
871
class LiveIntervals(Mapping[SSAVal, LiveInterval]):
    """Read-only mapping from each SSA value to its `LiveInterval`."""

    def __init__(self, ops):
        # type: (list[Op]) -> None
        """Build the intervals from `ops`, using list indexes as positions.

        Raises ValueError if two ops write the same SSA value. Reading a
        value that no earlier op wrote raises KeyError.
        """
        live_intervals = {}  # type: dict[SSAVal, LiveInterval]
        for op_idx, op in enumerate(ops):
            for val in op.input_ssa_vals():
                # extend the value's interval to cover this use
                live_intervals[val] += op_idx
            for val in op.output_ssa_vals():
                if val in live_intervals:
                    raise ValueError(f"multiple instructions must not write "
                                     f"to the same SSA value: {val}")
                live_intervals[val] = LiveInterval(op_idx)
        self.__live_intervals = live_intervals

    def __getitem__(self, key):
        # type: (SSAVal) -> LiveInterval
        return self.__live_intervals[key]

    def __iter__(self):
        return iter(self.__live_intervals)

    def __len__(self):
        # type: () -> int
        # Mapping declares __len__ abstract; without it this class cannot
        # be instantiated at all
        return len(self.__live_intervals)
892
893
@plain_data()
class AllocationFailed:
    """Record stored when register allocation fails: the op index, the
    argument involved, the live intervals and the free registers."""
    __slots__ = ("op_idx", "arg", "live_intervals", "free_regs")

    def __init__(self, op_idx, arg, live_intervals, free_regs):
        # type: (int, SSAVal | VecArg, LiveIntervals, set[GPR | XERBit]) -> None
        self.op_idx, self.arg = op_idx, arg
        self.live_intervals, self.free_regs = live_intervals, free_regs
904
905
def try_allocate_registers_without_spilling(ops):
    # type: (list[Op]) -> dict[SSAVal, PhysLoc] | AllocationFailed
    """Unfinished: computes the live intervals and the initial set of free
    registers, then raises NotImplementedError."""
    live_intervals = LiveIntervals(ops)
    # every GPR except the special ones, plus the XER bits
    free_regs = set(GPR)  # type: set[GPR | XERBit]
    free_regs -= set(SPECIAL_GPRS)
    free_regs |= set(XERBit)
    raise NotImplementedError
914
915
def allocate_registers(ops):
    # type: (list[Op]) -> None
    """Placeholder: not implemented yet, always raises NotImplementedError."""
    raise NotImplementedError()