2 Compiler IR for Toom-Cook algorithm generator for SVP64
5 from abc
import ABCMeta
, abstractmethod
6 from collections
import defaultdict
7 from enum
import Enum
, EnumMeta
, unique
8 from functools
import lru_cache
9 from typing
import (TYPE_CHECKING
, AbstractSet
, Generic
, Iterable
, Sequence
,
12 from cached_property
import cached_property
13 from nmutil
.plain_data
import fields
, plain_data
16 from typing_extensions
import final
22 class ABCEnumMeta(EnumMeta
, ABCMeta
):
26 class RegLoc(metaclass
=ABCMeta
):
30 def conflicts(self
, other
):
31 # type: (RegLoc) -> bool
34 def get_subreg_at_offset(self
, subreg_type
, offset
):
35 # type: (RegType, int) -> RegLoc
36 if self
not in subreg_type
.reg_class
:
37 raise ValueError(f
"register not a member of subreg_type: "
38 f
"reg={self} subreg_type={subreg_type}")
40 raise ValueError(f
"non-zero sub-register offset not supported "
41 f
"for register: {self}")
48 @plain_data(frozen
=True, unsafe_hash
=True)
50 class GPRRange(RegLoc
, Sequence
["GPRRange"]):
51 __slots__
= "start", "length"
53 def __init__(self
, start
, length
=None):
54 # type: (int | range, int | None) -> None
55 if isinstance(start
, range):
56 if length
is not None:
57 raise TypeError("can't specify length when input is a range")
59 raise ValueError("range must have a step of 1")
64 if length
<= 0 or start
< 0 or start
+ length
> GPR_COUNT
:
65 raise ValueError("invalid GPRRange")
71 return self
.start
+ self
.length
79 return range(self
.start
, self
.stop
, self
.step
)
def __getitem__(self, item):
    # type: (int | slice) -> GPRRange
    """Return the part of this range selected by `item` as a new GPRRange.

    `item` is applied to `self.range` and the result is handed straight to
    the `GPRRange` constructor, so any error raised by the indexing or by
    the constructor propagates to the caller.
    """
    selected = self.range[item]
    return GPRRange(selected)
def __contains__(self, value):
    # type: (GPRRange) -> bool
    """Return True when `value` lies entirely inside this range.

    Only the `start` and `stop` of the two ranges are compared.
    """
    starts_inside = value.start >= self.start
    return starts_inside and value.stop <= self.stop
def index(self, sub, start=None, end=None):
    # type: (GPRRange, int | None, int | None) -> int
    """Return the position of `sub` relative to the start of this range.

    `start` and `end` slice `self.range` to give the window that `sub` must
    fit inside; the returned index is still measured from `self.start`.

    Raises ValueError if `sub` does not lie within that window.
    """
    window = self.range[start:end]
    fits = window.start <= sub.start and sub.stop <= window.stop
    if not fits:
        raise ValueError("GPR range not found")
    return sub.start - self.start
99 def count(self
, sub
, start
=None, end
=None):
100 # type: (GPRRange, int | None, int | None) -> int
101 r
= self
.range[start
:end
]
104 return int(sub
in GPRRange(r
))
106 def conflicts(self
, other
):
107 # type: (RegLoc) -> bool
108 if isinstance(other
, GPRRange
):
109 return self
.stop
> other
.start
and other
.stop
> self
.start
def get_subreg_at_offset(self, subreg_type, offset):
    # type: (RegType, int) -> GPRRange
    """Return the sub-range of `subreg_type.length` registers that begins
    `offset` registers past `self.start`.

    Raises ValueError if `subreg_type` is not a GPRRangeType, or if the
    requested sub-range does not fit inside this range.
    """
    if not isinstance(subreg_type, GPRRangeType):
        raise ValueError(f"subreg_type is not a "
                         f"GPRRangeType: {subreg_type}")
    # `offset` is relative to `self.start`, so the upper bound is this
    # range's length -- comparing against the absolute register number
    # `self.stop` would accept sub-ranges that run past the end whenever
    # `self.start` is non-zero.
    if offset < 0 or offset + subreg_type.length > self.length:
        raise ValueError(f"sub-register offset is out of range: {offset}")
    return GPRRange(self.start + offset, subreg_type.length)
122 SPECIAL_GPRS
= GPRRange(0), GPRRange(1), GPRRange(2), GPRRange(13)
127 class XERBit(RegLoc
, Enum
, metaclass
=ABCEnumMeta
):
130 def conflicts(self
, other
):
131 # type: (RegLoc) -> bool
132 if isinstance(other
, XERBit
):
139 class GlobalMem(RegLoc
, Enum
, metaclass
=ABCEnumMeta
):
140 """singleton representing all non-StackSlot memory -- treated as a single
141 physical register for register allocation purposes.
143 GlobalMem
= "GlobalMem"
145 def conflicts(self
, other
):
146 # type: (RegLoc) -> bool
147 if isinstance(other
, GlobalMem
):
153 class RegClass(AbstractSet
[RegLoc
]):
154 """ an ordered set of registers.
155 earlier registers are preferred by the register allocator.
158 def __init__(self
, regs
):
159 # type: (Iterable[RegLoc]) -> None
161 # use dict to maintain order
162 self
.__regs
= dict.fromkeys(regs
) # type: dict[RegLoc, None]
165 return len(self
.__regs
)
168 return iter(self
.__regs
)
170 def __contains__(self
, v
):
171 # type: (RegLoc) -> bool
172 return v
in self
.__regs
175 return super()._hash
()
177 @lru_cache(maxsize
=None, typed
=True)
178 def max_conflicts_with(self
, other
):
179 # type: (RegClass | RegLoc) -> int
180 """the largest number of registers in `self` that a single register
181 from `other` can conflict with
183 if isinstance(other
, RegClass
):
184 return max(self
.max_conflicts_with(i
) for i
in other
)
186 return sum(other
.conflicts(i
) for i
in self
)
189 @plain_data(frozen
=True, unsafe_hash
=True)
190 class RegType(metaclass
=ABCMeta
):
196 # type: () -> RegClass
200 _RegType
= TypeVar("_RegType", bound
=RegType
)
203 @plain_data(frozen
=True, eq
=False)
204 class GPRRangeType(RegType
):
205 __slots__
= "length",
207 def __init__(self
, length
):
208 # type: (int) -> None
209 if length
< 1 or length
> GPR_COUNT
:
210 raise ValueError("invalid length")
214 @lru_cache(maxsize
=None)
215 def __get_reg_class(length
):
216 # type: (int) -> RegClass
218 for start
in range(GPR_COUNT
- length
):
219 reg
= GPRRange(start
, length
)
220 if any(i
in reg
for i
in SPECIAL_GPRS
):
223 return RegClass(regs
)
227 # type: () -> RegClass
228 return GPRRangeType
.__get
_reg
_class
(self
.length
)
231 def __eq__(self
, other
):
232 if isinstance(other
, GPRRangeType
):
233 return self
.length
== other
.length
238 return hash(self
.length
)
241 @plain_data(frozen
=True, eq
=False)
243 class GPRType(GPRRangeType
):
246 def __init__(self
, length
=1):
248 raise ValueError("length must be 1")
249 super().__init
__(length
=1)
252 @plain_data(frozen
=True, unsafe_hash
=True)
254 class FixedGPRRangeType(GPRRangeType
):
257 def __init__(self
, reg
):
258 # type: (GPRRange) -> None
259 super().__init
__(length
=reg
.length
)
264 # type: () -> RegClass
265 return RegClass([self
.reg
])
268 @plain_data(frozen
=True, unsafe_hash
=True)
270 class CYType(RegType
):
275 # type: () -> RegClass
276 return RegClass([XERBit
.CY
])
279 @plain_data(frozen
=True, unsafe_hash
=True)
281 class GlobalMemType(RegType
):
286 # type: () -> RegClass
287 return RegClass([GlobalMem
.GlobalMem
])
290 @plain_data(frozen
=True, unsafe_hash
=True)
292 class StackSlot(RegLoc
):
293 __slots__
= "start_slot", "length_in_slots",
def __init__(self, start_slot, length_in_slots):
    # type: (int, int) -> None
    """Record the first slot and the number of slots covered.

    `start_slot` is stored before the length is checked and is not itself
    validated here.

    Raises ValueError if `length_in_slots` is less than 1.
    """
    self.start_slot = start_slot
    too_short = length_in_slots < 1
    if too_short:
        raise ValueError("invalid length_in_slots")
    self.length_in_slots = length_in_slots
304 return self
.start_slot
+ self
.length_in_slots
306 def conflicts(self
, other
):
307 # type: (RegLoc) -> bool
308 if isinstance(other
, StackSlot
):
309 return (self
.stop_slot
> other
.start_slot
310 and other
.stop_slot
> self
.start_slot
)
def get_subreg_at_offset(self, subreg_type, offset):
    # type: (RegType, int) -> StackSlot
    """Return the StackSlot of `subreg_type.length_in_slots` slots that
    begins `offset` slots past `self.start_slot`.

    Raises ValueError if `subreg_type` is not a StackSlotType, or if the
    requested sub-slot does not fit inside this stack slot.
    """
    if not isinstance(subreg_type, StackSlotType):
        raise ValueError(f"subreg_type is not a "
                         f"StackSlotType: {subreg_type}")
    # `offset` is relative to `self.start_slot`, so the upper bound is this
    # slot's own length -- comparing against the absolute `self.stop_slot`
    # would accept sub-slots that run past the end whenever
    # `self.start_slot` is non-zero.
    if offset < 0 \
            or offset + subreg_type.length_in_slots > self.length_in_slots:
        raise ValueError(f"sub-register offset is out of range: {offset}")
    return StackSlot(self.start_slot + offset, subreg_type.length_in_slots)
323 STACK_SLOT_COUNT
= 128
326 @plain_data(frozen
=True, eq
=False)
328 class StackSlotType(RegType
):
329 __slots__
= "length_in_slots",
def __init__(self, length_in_slots=1):
    # type: (int) -> None
    """Record how many stack slots a value of this type occupies.

    Raises ValueError if `length_in_slots` is less than 1; nothing is stored
    in that case.
    """
    too_short = length_in_slots < 1
    if too_short:
        raise ValueError("invalid length_in_slots")
    self.length_in_slots = length_in_slots
338 @lru_cache(maxsize
=None)
339 def __get_reg_class(length_in_slots
):
340 # type: (int) -> RegClass
342 for start
in range(STACK_SLOT_COUNT
- length_in_slots
):
343 reg
= StackSlot(start
, length_in_slots
)
345 return RegClass(regs
)
349 # type: () -> RegClass
350 return StackSlotType
.__get
_reg
_class
(self
.length_in_slots
)
353 def __eq__(self
, other
):
354 if isinstance(other
, StackSlotType
):
355 return self
.length_in_slots
== other
.length_in_slots
360 return hash(self
.length_in_slots
)
363 @plain_data(frozen
=True, eq
=False, repr=False)
365 class SSAVal(Generic
[_RegType
]):
366 __slots__
= "op", "arg_name", "ty",
368 def __init__(self
, op
, arg_name
, ty
):
369 # type: (Op, str, _RegType) -> None
371 """the Op that writes this SSAVal"""
373 self
.arg_name
= arg_name
374 """the name of the argument of self.op that writes this SSAVal"""
378 def __eq__(self
, rhs
):
379 if isinstance(rhs
, SSAVal
):
380 return (self
.op
is rhs
.op
381 and self
.arg_name
== rhs
.arg_name
)
385 return hash((id(self
.op
), self
.arg_name
))
389 for name
in fields(self
):
390 v
= getattr(self
, name
, None)
393 v
= v
.__repr
__(just_id
=True)
396 fields_list
.append(f
"{name}={v}")
397 fields_str
= ", ".join(fields_list
)
398 return f
"SSAVal({fields_str})"
402 @plain_data(unsafe_hash
=True, frozen
=True)
403 class EqualityConstraint
:
404 __slots__
= "lhs", "rhs"
406 def __init__(self
, lhs
, rhs
):
407 # type: (list[SSAVal], list[SSAVal]) -> None
410 if len(lhs
) == 0 or len(rhs
) == 0:
411 raise ValueError("can't constrain an empty list to be equal")
415 """ helper for __repr__ for when fields aren't set """
424 @plain_data(unsafe_hash
=True, frozen
=True, repr=False)
425 class Op(metaclass
=ABCMeta
):
430 # type: () -> dict[str, SSAVal]
435 # type: () -> dict[str, SSAVal]
438 def get_equality_constraints(self
):
439 # type: () -> Iterable[EqualityConstraint]
443 def get_extra_interferences(self
):
444 # type: () -> Iterable[tuple[SSAVal, SSAVal]]
452 retval
= Op
.__NEXT
_ID
457 def __repr__(self
, just_id
=False):
458 fields_list
= [f
"#{self.id}"]
460 for name
in fields(self
):
461 v
= getattr(self
, name
, _NOT_SET
)
462 fields_list
.append(f
"{name}={v!r}")
463 fields_str
= ', '.join(fields_list
)
464 return f
"{self.__class__.__name__}({fields_str})"
467 @plain_data(unsafe_hash
=True, frozen
=True, repr=False)
469 class OpLoadFromStackSlot(Op
):
470 __slots__
= "dest", "src"
473 # type: () -> dict[str, SSAVal]
474 return {"src": self
.src
}
477 # type: () -> dict[str, SSAVal]
478 return {"dest": self
.dest
}
480 def __init__(self
, src
):
481 # type: (SSAVal[GPRRangeType]) -> None
482 self
.dest
= SSAVal(self
, "dest", StackSlotType(src
.ty
.length
))
486 @plain_data(unsafe_hash
=True, frozen
=True, repr=False)
488 class OpStoreToStackSlot(Op
):
489 __slots__
= "dest", "src"
492 # type: () -> dict[str, SSAVal]
493 return {"src": self
.src
}
496 # type: () -> dict[str, SSAVal]
497 return {"dest": self
.dest
}
499 def __init__(self
, src
):
500 # type: (SSAVal[StackSlotType]) -> None
501 self
.dest
= SSAVal(self
, "dest", GPRRangeType(src
.ty
.length_in_slots
))
505 _RegSrcType
= TypeVar("_RegSrcType", bound
=RegType
)
508 @plain_data(unsafe_hash
=True, frozen
=True, repr=False)
510 class OpCopy(Op
, Generic
[_RegSrcType
, _RegType
]):
511 __slots__
= "dest", "src"
514 # type: () -> dict[str, SSAVal]
515 return {"src": self
.src
}
518 # type: () -> dict[str, SSAVal]
519 return {"dest": self
.dest
}
521 def __init__(self
, src
, dest_ty
=None):
522 # type: (SSAVal[_RegSrcType], _RegType | None) -> None
524 dest_ty
= cast(_RegType
, src
.ty
)
525 if isinstance(src
.ty
, GPRRangeType
) \
526 and isinstance(dest_ty
, GPRRangeType
):
527 if src
.ty
.length
!= dest_ty
.length
:
528 raise ValueError(f
"incompatible source and destination "
529 f
"types: {src.ty} and {dest_ty}")
530 elif src
.ty
!= dest_ty
:
531 raise ValueError(f
"incompatible source and destination "
532 f
"types: {src.ty} and {dest_ty}")
534 self
.dest
= SSAVal(self
, "dest", dest_ty
) # type: SSAVal[_RegType]
538 @plain_data(unsafe_hash
=True, frozen
=True, repr=False)
541 __slots__
= "dest", "sources"
544 # type: () -> dict[str, SSAVal]
545 return {f
"sources[{i}]": v
for i
, v
in enumerate(self
.sources
)}
548 # type: () -> dict[str, SSAVal]
549 return {"dest": self
.dest
}
def __init__(self, sources):
    # type: (Iterable[SSAVal[GPRRangeType]]) -> None
    """Build the op from its source values.

    `sources` is materialised into a tuple. `dest` is a new SSAVal written
    by this op whose type is a GPRRangeType as long as all the sources'
    `ty.length` values added together.
    """
    sources = tuple(sources)
    total_length = sum(src.ty.length for src in sources)
    self.dest = SSAVal(self, "dest", GPRRangeType(total_length))
    self.sources = sources
def get_equality_constraints(self):
    # type: () -> Iterable[EqualityConstraint]
    """Yield one EqualityConstraint with `[dest]` as its left-hand side and
    a list of all `sources` as its right-hand side.
    """
    lhs = [self.dest]
    rhs = list(self.sources)
    yield EqualityConstraint(lhs, rhs)
563 @plain_data(unsafe_hash
=True, frozen
=True, repr=False)
566 __slots__
= "results", "src"
569 # type: () -> dict[str, SSAVal]
570 return {"src": self
.src
}
573 # type: () -> dict[str, SSAVal]
574 return {i
.arg_name
: i
for i
in self
.results
}
576 def __init__(self
, src
, split_indexes
):
577 # type: (SSAVal[GPRRangeType], Iterable[int]) -> None
578 ranges
= [] # type: list[GPRRangeType]
580 for i
in split_indexes
:
581 if not (0 < i
< src
.ty
.length
):
582 raise ValueError(f
"invalid split index: {i}, must be in "
583 f
"0 < i < {src.ty.length}")
584 ranges
.append(GPRRangeType(i
- last
))
586 ranges
.append(GPRRangeType(src
.ty
.length
- last
))
588 self
.results
= tuple(
589 SSAVal(self
, f
"results{i}", r
) for i
, r
in enumerate(ranges
))
def get_equality_constraints(self):
    # type: () -> Iterable[EqualityConstraint]
    """Yield one EqualityConstraint with a list of all `results` as its
    left-hand side and `[src]` as its right-hand side.
    """
    lhs = list(self.results)
    rhs = [self.src]
    yield EqualityConstraint(lhs, rhs)
596 @plain_data(unsafe_hash
=True, frozen
=True, repr=False)
599 __slots__
= "RT", "RA", "RB", "CY_in", "CY_out", "is_sub"
602 # type: () -> dict[str, SSAVal]
603 return {"RA": self
.RA
, "RB": self
.RB
, "CY_in": self
.CY_in
}
606 # type: () -> dict[str, SSAVal]
607 return {"RT": self
.RT
, "CY_out": self
.CY_out
}
609 def __init__(self
, RA
, RB
, CY_in
, is_sub
):
610 # type: (SSAVal[GPRRangeType], SSAVal[GPRRangeType], SSAVal[CYType], bool) -> None
612 raise TypeError(f
"source types must match: "
613 f
"{RA} doesn't match {RB}")
614 self
.RT
= SSAVal(self
, "RT", RA
.ty
)
618 self
.CY_out
= SSAVal(self
, "CY_out", CY_in
.ty
)
def get_extra_interferences(self):
    # type: () -> Iterable[tuple[SSAVal, SSAVal]]
    """Yield the pairs `(RT, RA)` and then `(RT, RB)`."""
    for source_name in ("RA", "RB"):
        yield self.RT, getattr(self, source_name)
627 @plain_data(unsafe_hash
=True, frozen
=True, repr=False)
629 class OpBigIntMulDiv(Op
):
630 __slots__
= "RT", "RA", "RB", "RC", "RS", "is_div"
633 # type: () -> dict[str, SSAVal]
634 return {"RA": self
.RA
, "RB": self
.RB
, "RC": self
.RC
}
637 # type: () -> dict[str, SSAVal]
638 return {"RT": self
.RT
, "RS": self
.RS
}
640 def __init__(self
, RA
, RB
, RC
, is_div
):
641 # type: (SSAVal[GPRRangeType], SSAVal[GPRType], SSAVal[GPRType], bool) -> None
642 self
.RT
= SSAVal(self
, "RT", RA
.ty
)
646 self
.RS
= SSAVal(self
, "RS", RC
.ty
)
def get_equality_constraints(self):
    # type: () -> Iterable[EqualityConstraint]
    """Yield one EqualityConstraint with `[RC]` as its left-hand side and
    `[RS]` as its right-hand side.
    """
    lhs = [self.RC]
    rhs = [self.RS]
    yield EqualityConstraint(lhs, rhs)
def get_extra_interferences(self):
    # type: () -> Iterable[tuple[SSAVal, SSAVal]]
    """Yield, in order, `RT` paired with each of `RA`, `RB`, `RC` and `RS`,
    followed by `RS` paired with `RA` and then `RB`.
    """
    pair_names = (
        ("RT", "RA"),
        ("RT", "RB"),
        ("RT", "RC"),
        ("RT", "RS"),
        ("RS", "RA"),
        ("RS", "RB"),
    )
    for first, second in pair_names:
        yield getattr(self, first), getattr(self, second)
665 class ShiftKind(Enum
):
671 @plain_data(unsafe_hash
=True, frozen
=True, repr=False)
673 class OpBigIntShift(Op
):
674 __slots__
= "RT", "inp", "sh", "kind"
677 # type: () -> dict[str, SSAVal]
678 return {"inp": self
.inp
, "sh": self
.sh
}
681 # type: () -> dict[str, SSAVal]
682 return {"RT": self
.RT
}
684 def __init__(self
, inp
, sh
, kind
):
685 # type: (SSAVal[GPRRangeType], SSAVal[GPRType], ShiftKind) -> None
686 self
.RT
= SSAVal(self
, "RT", inp
.ty
)
def get_extra_interferences(self):
    # type: () -> Iterable[tuple[SSAVal, SSAVal]]
    """Yield the pairs `(RT, inp)` and then `(RT, sh)`."""
    for input_name in ("inp", "sh"):
        yield self.RT, getattr(self, input_name)
697 @plain_data(unsafe_hash
=True, frozen
=True, repr=False)
700 __slots__
= "out", "value"
703 # type: () -> dict[str, SSAVal]
707 # type: () -> dict[str, SSAVal]
708 return {"out": self
.out
}
710 def __init__(self
, value
, length
=1):
711 # type: (int, int) -> None
712 self
.out
= SSAVal(self
, "out", GPRRangeType(length
))
716 @plain_data(unsafe_hash
=True, frozen
=True, repr=False)
722 # type: () -> dict[str, SSAVal]
726 # type: () -> dict[str, SSAVal]
727 return {"out": self
.out
}
731 self
.out
= SSAVal(self
, "out", CYType())
734 @plain_data(unsafe_hash
=True, frozen
=True, repr=False)
737 __slots__
= "RT", "RA", "offset", "mem"
740 # type: () -> dict[str, SSAVal]
741 return {"RA": self
.RA
, "mem": self
.mem
}
744 # type: () -> dict[str, SSAVal]
745 return {"RT": self
.RT
}
747 def __init__(self
, RA
, offset
, mem
, length
=1):
748 # type: (SSAVal[GPRType], int, SSAVal[GlobalMemType], int) -> None
749 self
.RT
= SSAVal(self
, "RT", GPRRangeType(length
))
def get_extra_interferences(self):
    # type: () -> Iterable[tuple[SSAVal, SSAVal]]
    """Yield the pair `(RT, RA)` when `RT` spans more than one register.

    Nothing is yielded when `RT.ty.length` is 1 or less.
    """
    if self.RT.ty.length <= 1:
        return
    yield self.RT, self.RA
760 @plain_data(unsafe_hash
=True, frozen
=True, repr=False)
763 __slots__
= "RS", "RA", "offset", "mem_in", "mem_out"
766 # type: () -> dict[str, SSAVal]
767 return {"RS": self
.RS
, "RA": self
.RA
, "mem_in": self
.mem_in
}
770 # type: () -> dict[str, SSAVal]
771 return {"mem_out": self
.mem_out
}
773 def __init__(self
, RS
, RA
, offset
, mem_in
):
774 # type: (SSAVal[GPRRangeType], SSAVal[GPRType], int, SSAVal[GlobalMemType]) -> None
779 self
.mem_out
= SSAVal(self
, "mem_out", mem_in
.ty
)
782 @plain_data(unsafe_hash
=True, frozen
=True, repr=False)
788 # type: () -> dict[str, SSAVal]
792 # type: () -> dict[str, SSAVal]
793 return {"out": self
.out
}
def __init__(self, ty):
    # type: (FixedGPRRangeType) -> None
    """Create the op's single output: an SSAVal named "out", written by
    this op, with the given type `ty`.
    """
    out = SSAVal(self, "out", ty)
    self.out = out
800 @plain_data(unsafe_hash
=True, frozen
=True, repr=False)
802 class OpInputMem(Op
):
806 # type: () -> dict[str, SSAVal]
810 # type: () -> dict[str, SSAVal]
811 return {"out": self
.out
}
815 self
.out
= SSAVal(self
, "out", GlobalMemType())
818 def op_set_to_list(ops
):
819 # type: (Iterable[Op]) -> list[Op]
820 worklists
= [{}] # type: list[dict[Op, None]]
821 inps_to_ops_map
= defaultdict(dict) # type: dict[SSAVal, dict[Op, None]]
822 ops_to_pending_input_count_map
= {} # type: dict[Op, int]
825 for val
in op
.inputs().values():
827 inps_to_ops_map
[val
][op
] = None
828 while len(worklists
) <= input_count
:
830 ops_to_pending_input_count_map
[op
] = input_count
831 worklists
[input_count
][op
] = None
832 retval
= [] # type: list[Op]
833 ready_vals
= set() # type: set[SSAVal]
834 while len(worklists
[0]) != 0:
835 writing_op
= next(iter(worklists
[0]))
836 del worklists
[0][writing_op
]
837 retval
.append(writing_op
)
838 for val
in writing_op
.outputs().values():
839 if val
in ready_vals
:
840 raise ValueError(f
"multiple instructions must not write "
841 f
"to the same SSA value: {val}")
843 for reading_op
in inps_to_ops_map
[val
]:
844 pending
= ops_to_pending_input_count_map
[reading_op
]
845 del worklists
[pending
][reading_op
]
847 worklists
[pending
][reading_op
] = None
848 ops_to_pending_input_count_map
[reading_op
] = pending
849 for worklist
in worklists
:
851 raise ValueError(f
"instruction is part of a dependency loop or "
852 f
"its inputs are never written: {op}")