2 Compiler IR for Toom-Cook algorithm generator for SVP64
4 This assumes VL != 0 throughout.
7 from abc
import ABCMeta
, abstractmethod
8 from collections
import defaultdict
9 from enum
import Enum
, EnumMeta
, unique
10 from functools
import lru_cache
11 from typing
import Any
, Generic
, Iterable
, Sequence
, Type
, TypeVar
, cast
13 from nmutil
.plain_data
import fields
, plain_data
15 from bigint_presentation_code
.util
import OFSet
, OSet
, final
18 class ABCEnumMeta(EnumMeta
, ABCMeta
):
22 class RegLoc(metaclass
=ABCMeta
):
26 def conflicts(self
, other
):
27 # type: (RegLoc) -> bool
30 def get_subreg_at_offset(self
, subreg_type
, offset
):
31 # type: (RegType, int) -> RegLoc
32 if self
not in subreg_type
.reg_class
:
33 raise ValueError(f
"register not a member of subreg_type: "
34 f
"reg={self} subreg_type={subreg_type}")
36 raise ValueError(f
"non-zero sub-register offset not supported "
37 f
"for register: {self}")
44 @plain_data(frozen
=True, unsafe_hash
=True)
46 class GPRRange(RegLoc
, Sequence
["GPRRange"]):
47 __slots__
= "start", "length"
49 def __init__(self
, start
, length
=None):
50 # type: (int | range, int | None) -> None
51 if isinstance(start
, range):
52 if length
is not None:
53 raise TypeError("can't specify length when input is a range")
55 raise ValueError("range must have a step of 1")
60 if length
<= 0 or start
< 0 or start
+ length
> GPR_COUNT
:
61 raise ValueError("invalid GPRRange")
67 return self
.start
+ self
.length
75 return range(self
.start
, self
.stop
, self
.step
)
80 def __getitem__(self
, item
):
81 # type: (int | slice) -> GPRRange
82 return GPRRange(self
.range[item
])
84 def __contains__(self
, value
):
85 # type: (GPRRange) -> bool
86 return value
.start
>= self
.start
and value
.stop
<= self
.stop
88 def index(self
, sub
, start
=None, end
=None):
89 # type: (GPRRange, int | None, int | None) -> int
90 r
= self
.range[start
:end
]
91 if sub
.start
< r
.start
or sub
.stop
> r
.stop
:
92 raise ValueError("GPR range not found")
93 return sub
.start
- self
.start
95 def count(self
, sub
, start
=None, end
=None):
96 # type: (GPRRange, int | None, int | None) -> int
97 r
= self
.range[start
:end
]
100 return int(sub
in GPRRange(r
))
102 def conflicts(self
, other
):
103 # type: (RegLoc) -> bool
104 if isinstance(other
, GPRRange
):
105 return self
.stop
> other
.start
and other
.stop
> self
.start
108 def get_subreg_at_offset(self
, subreg_type
, offset
):
109 # type: (RegType, int) -> GPRRange
110 if not isinstance(subreg_type
, (GPRRangeType
, FixedGPRRangeType
)):
111 raise ValueError(f
"subreg_type is not a FixedGPRRangeType or "
112 f
"GPRRangeType: {subreg_type}")
113 if offset
< 0 or offset
+ subreg_type
.length
> self
.stop
:
114 raise ValueError(f
"sub-register offset is out of range: {offset}")
115 return GPRRange(self
.start
+ offset
, subreg_type
.length
)
118 SPECIAL_GPRS
= GPRRange(0), GPRRange(1), GPRRange(2), GPRRange(13)
123 class XERBit(RegLoc
, Enum
, metaclass
=ABCEnumMeta
):
126 def conflicts(self
, other
):
127 # type: (RegLoc) -> bool
128 if isinstance(other
, XERBit
):
135 class GlobalMem(RegLoc
, Enum
, metaclass
=ABCEnumMeta
):
136 """singleton representing all non-StackSlot memory -- treated as a single
137 physical register for register allocation purposes.
139 GlobalMem
= "GlobalMem"
141 def conflicts(self
, other
):
142 # type: (RegLoc) -> bool
143 if isinstance(other
, GlobalMem
):
150 class VL(RegLoc
, Enum
, metaclass
=ABCEnumMeta
):
151 VL_MAXVL
= "VL_MAXVL"
154 def conflicts(self
, other
):
155 # type: (RegLoc) -> bool
156 if isinstance(other
, VL
):
162 class RegClass(OFSet
[RegLoc
]):
163 """ an ordered set of registers.
164 earlier registers are preferred by the register allocator.
167 @lru_cache(maxsize
=None, typed
=True)
168 def max_conflicts_with(self
, other
):
169 # type: (RegClass | RegLoc) -> int
170 """the largest number of registers in `self` that a single register
171 from `other` can conflict with
173 if isinstance(other
, RegClass
):
174 return max(self
.max_conflicts_with(i
) for i
in other
)
176 return sum(other
.conflicts(i
) for i
in self
)
179 @plain_data(frozen
=True, unsafe_hash
=True)
180 class RegType(metaclass
=ABCMeta
):
186 # type: () -> RegClass
190 _RegType
= TypeVar("_RegType", bound
=RegType
)
191 _RegLoc
= TypeVar("_RegLoc", bound
=RegLoc
)
194 @plain_data(frozen
=True, eq
=False)
196 class GPRRangeType(RegType
):
197 __slots__
= "length",
199 def __init__(self
, length
=1):
200 # type: (int) -> None
201 if length
< 1 or length
> GPR_COUNT
:
202 raise ValueError("invalid length")
206 @lru_cache(maxsize
=None)
207 def __get_reg_class(length
):
208 # type: (int) -> RegClass
210 for start
in range(GPR_COUNT
- length
):
211 reg
= GPRRange(start
, length
)
212 if any(i
in reg
for i
in SPECIAL_GPRS
):
215 return RegClass(regs
)
220 # type: () -> RegClass
221 return GPRRangeType
.__get
_reg
_class
(self
.length
)
224 def __eq__(self
, other
):
225 if isinstance(other
, GPRRangeType
):
226 return self
.length
== other
.length
231 return hash(self
.length
)
234 GPRType
= GPRRangeType
235 """a length=1 GPRRangeType"""
238 @plain_data(frozen
=True, unsafe_hash
=True)
240 class FixedGPRRangeType(RegType
):
243 def __init__(self
, reg
):
244 # type: (GPRRange) -> None
249 # type: () -> RegClass
250 return RegClass([self
.reg
])
255 return self
.reg
.length
258 @plain_data(frozen
=True, unsafe_hash
=True)
260 class CAType(RegType
):
265 # type: () -> RegClass
266 return RegClass([XERBit
.CA
])
269 @plain_data(frozen
=True, unsafe_hash
=True)
271 class GlobalMemType(RegType
):
276 # type: () -> RegClass
277 return RegClass([GlobalMem
.GlobalMem
])
280 @plain_data(frozen
=True, unsafe_hash
=True)
282 class KnownVLType(RegType
):
283 __slots__
= "length",
285 def __init__(self
, length
):
286 # type: (int) -> None
287 if not (0 < length
<= 64):
288 raise ValueError("invalid VL value")
293 # type: () -> RegClass
294 return RegClass([VL
.VL_MAXVL
])
297 def assert_vl_is(vl
, expected_vl
):
298 # type: (SSAKnownVL | KnownVLType | int | None, int) -> None
301 elif isinstance(vl
, SSAVal
):
303 elif isinstance(vl
, KnownVLType
):
305 if vl
!= expected_vl
:
307 f
"wrong VL: expected {expected_vl} got {vl}")
313 @plain_data(frozen
=True, unsafe_hash
=True)
315 class StackSlot(RegLoc
):
316 __slots__
= "start_slot", "length_in_slots",
318 def __init__(self
, start_slot
, length_in_slots
):
319 # type: (int, int) -> None
320 self
.start_slot
= start_slot
321 if length_in_slots
< 1:
322 raise ValueError("invalid length_in_slots")
323 self
.length_in_slots
= length_in_slots
327 return self
.start_slot
+ self
.length_in_slots
330 def start_byte(self
):
331 return self
.start_slot
* STACK_SLOT_SIZE
333 def conflicts(self
, other
):
334 # type: (RegLoc) -> bool
335 if isinstance(other
, StackSlot
):
336 return (self
.stop_slot
> other
.start_slot
337 and other
.stop_slot
> self
.start_slot
)
340 def get_subreg_at_offset(self
, subreg_type
, offset
):
341 # type: (RegType, int) -> StackSlot
342 if not isinstance(subreg_type
, StackSlotType
):
343 raise ValueError(f
"subreg_type is not a "
344 f
"StackSlotType: {subreg_type}")
345 if offset
< 0 or offset
+ subreg_type
.length_in_slots
> self
.stop_slot
:
346 raise ValueError(f
"sub-register offset is out of range: {offset}")
347 return StackSlot(self
.start_slot
+ offset
, subreg_type
.length_in_slots
)
350 STACK_SLOT_COUNT
= 128
353 @plain_data(frozen
=True, eq
=False)
355 class StackSlotType(RegType
):
356 __slots__
= "length_in_slots",
358 def __init__(self
, length_in_slots
=1):
359 # type: (int) -> None
360 if length_in_slots
< 1:
361 raise ValueError("invalid length_in_slots")
362 self
.length_in_slots
= length_in_slots
365 @lru_cache(maxsize
=None)
366 def __get_reg_class(length_in_slots
):
367 # type: (int) -> RegClass
369 for start
in range(STACK_SLOT_COUNT
- length_in_slots
):
370 reg
= StackSlot(start
, length_in_slots
)
372 return RegClass(regs
)
376 # type: () -> RegClass
377 return StackSlotType
.__get
_reg
_class
(self
.length_in_slots
)
380 def __eq__(self
, other
):
381 if isinstance(other
, StackSlotType
):
382 return self
.length_in_slots
== other
.length_in_slots
387 return hash(self
.length_in_slots
)
390 @plain_data(frozen
=True, eq
=False, repr=False)
392 class SSAVal(Generic
[_RegType
]):
393 __slots__
= "op", "arg_name", "ty",
395 def __init__(self
, op
, arg_name
, ty
):
396 # type: (Op, str, _RegType) -> None
398 """the Op that writes this SSAVal"""
400 self
.arg_name
= arg_name
401 """the name of the argument of self.op that writes this SSAVal"""
405 def __eq__(self
, rhs
):
406 if isinstance(rhs
, SSAVal
):
407 return (self
.op
is rhs
.op
408 and self
.arg_name
== rhs
.arg_name
)
412 return hash((id(self
.op
), self
.arg_name
))
414 def __repr__(self
, long=False):
416 return f
"<#{self.op.id}.{self.arg_name}>"
418 for name
in fields(self
):
419 v
= getattr(self
, name
, None)
422 v
= v
.__repr
__(just_id
=True)
425 fields_list
.append(f
"{name}={v}")
426 fields_str
= ", ".join(fields_list
)
427 return f
"SSAVal({fields_str})"
430 SSAGPRRange
= SSAVal
[GPRRangeType
]
431 SSAGPR
= SSAVal
[GPRType
]
432 SSAKnownVL
= SSAVal
[KnownVLType
]
436 @plain_data(unsafe_hash
=True, frozen
=True)
437 class EqualityConstraint
:
438 __slots__
= "lhs", "rhs"
440 def __init__(self
, lhs
, rhs
):
441 # type: (list[SSAVal], list[SSAVal]) -> None
444 if len(lhs
) == 0 or len(rhs
) == 0:
445 raise ValueError("can't constrain an empty list to be equal")
454 self
.ops
= [] # type: list[Op]
456 def __repr__(self
, short
=False):
459 ops
= ", ".join(op
.__repr
__(just_id
=True) for op
in self
.ops
)
460 return f
"<Fn([{ops}])>"
464 """ helper for __repr__ for when fields aren't set """
473 @plain_data(frozen
=True, unsafe_hash
=True)
474 class AsmTemplateSegment(Generic
[_RegType
], metaclass
=ABCMeta
):
475 __slots__
= "ssa_val",
477 def __init__(self
, ssa_val
):
478 # type: (SSAVal[_RegType]) -> None
479 self
.ssa_val
= ssa_val
481 def render(self
, regs
):
482 # type: (dict[SSAVal, RegLoc]) -> str
483 return self
._render
(regs
[self
.ssa_val
])
486 def _render(self
, reg
):
487 # type: (RegLoc) -> str
491 @plain_data(frozen
=True, unsafe_hash
=True)
493 class ATSGPR(AsmTemplateSegment
[GPRRangeType
]):
494 __slots__
= "offset",
496 def __init__(self
, ssa_val
, offset
=0):
497 # type: (SSAGPRRange, int) -> None
498 super().__init
__(ssa_val
)
501 def _render(self
, reg
):
502 # type: (RegLoc) -> str
503 if not isinstance(reg
, GPRRange
):
505 return str(reg
.start
+ self
.offset
)
508 @plain_data(frozen
=True, unsafe_hash
=True)
510 class ATSStackSlot(AsmTemplateSegment
[StackSlotType
]):
513 def _render(self
, reg
):
514 # type: (RegLoc) -> str
515 if not isinstance(reg
, StackSlot
):
517 return f
"{reg.start_slot}(1)"
520 @plain_data(frozen
=True, unsafe_hash
=True)
522 class ATSCopyGPRRange(AsmTemplateSegment
["GPRRangeType | FixedGPRRangeType"]):
523 __slots__
= "src_ssa_val",
525 def __init__(self
, ssa_val
, src_ssa_val
):
526 # type: (SSAVal[GPRRangeType | FixedGPRRangeType], SSAVal[GPRRangeType | FixedGPRRangeType]) -> None
527 self
.ssa_val
= ssa_val
528 self
.src_ssa_val
= src_ssa_val
530 def render(self
, regs
):
531 # type: (dict[SSAVal, RegLoc]) -> str
532 src
= regs
[self
.src_ssa_val
]
533 dest
= regs
[self
.ssa_val
]
534 if not isinstance(dest
, GPRRange
):
536 if not isinstance(src
, GPRRange
):
538 if src
.length
!= dest
.length
:
546 elif src
.conflicts(dest
) and src
.start
> dest
.start
:
548 return f
"{sv_}or{mrr} *{dest.start}, *{src.start}, *{src.start}\n"
550 def _render(self
, reg
):
551 # type: (RegLoc) -> str
552 raise TypeError("must call self.render")
556 class AsmTemplate(Sequence
["str | AsmTemplateSegment"]):
558 def __process_segments(segments
):
559 # type: (Iterable[str | AsmTemplateSegment | AsmTemplate]) -> Iterable[str | AsmTemplateSegment]
561 if isinstance(i
, AsmTemplate
):
566 def __init__(self
, segments
=()):
567 # type: (Iterable[str | AsmTemplateSegment | AsmTemplate]) -> None
568 self
.__segments
= tuple(self
.__process
_segments
(segments
))
570 def __getitem__(self
, index
):
571 # type: (int) -> str | AsmTemplateSegment
572 return self
.__segments
[index
]
575 return len(self
.__segments
)
578 return iter(self
.__segments
)
581 return hash(self
.__segments
)
583 def render(self
, regs
):
584 # type: (dict[SSAVal, RegLoc]) -> str
585 retval
= [] # type: list[str]
587 if isinstance(segment
, AsmTemplateSegment
):
588 retval
.append(segment
.render(regs
))
590 retval
.append(segment
)
591 return "".join(retval
)
596 def __init__(self
, assigned_registers
):
597 # type: (dict[SSAVal, RegLoc]) -> None
598 self
.__assigned
_registers
= assigned_registers
600 def reg(self
, ssa_val
, expected_ty
):
601 # type: (SSAVal[Any], Type[_RegLoc]) -> _RegLoc
603 reg
= self
.__assigned
_registers
[ssa_val
]
604 except KeyError as e
:
605 raise ValueError(f
"SSAVal not assigned a register: {ssa_val}")
606 wrong_len
= (isinstance(reg
, GPRRange
)
607 and reg
.length
!= ssa_val
.ty
.length
)
608 if not isinstance(reg
, expected_ty
) or wrong_len
:
610 f
"SSAVal is assigned a register of the wrong type: "
611 f
"ssa_val={ssa_val} expected_ty={expected_ty} reg={reg}")
614 def gpr_range(self
, ssa_val
):
615 # type: (SSAGPRRange | SSAVal[FixedGPRRangeType]) -> GPRRange
616 return self
.reg(ssa_val
, GPRRange
)
618 def stack_slot(self
, ssa_val
):
619 # type: (SSAVal[StackSlotType]) -> StackSlot
620 return self
.reg(ssa_val
, StackSlot
)
622 def gpr(self
, ssa_val
, vec
, offset
=0):
623 # type: (SSAGPRRange | SSAVal[FixedGPRRangeType], bool, int) -> str
624 reg
= self
.gpr_range(ssa_val
).start
+ offset
625 return "*" * vec
+ str(reg
)
627 def vgpr(self
, ssa_val
, offset
=0):
628 # type: (SSAGPRRange | SSAVal[FixedGPRRangeType], int) -> str
629 return self
.gpr(ssa_val
=ssa_val
, vec
=True, offset
=offset
)
631 def sgpr(self
, ssa_val
, offset
=0):
632 # type: (SSAGPR | SSAVal[FixedGPRRangeType], int) -> str
633 return self
.gpr(ssa_val
=ssa_val
, vec
=False, offset
=offset
)
635 def needs_sv(self
, *regs
):
636 # type: (*SSAGPRRange | SSAVal[FixedGPRRangeType]) -> bool
638 reg
= self
.gpr_range(reg
)
639 if reg
.length
!= 1 or reg
.start
>= 32:
644 @plain_data(unsafe_hash
=True, frozen
=True, repr=False)
645 class Op(metaclass
=ABCMeta
):
646 __slots__
= "id", "fn"
650 # type: () -> dict[str, SSAVal]
655 # type: () -> dict[str, SSAVal]
658 def get_equality_constraints(self
):
659 # type: () -> Iterable[EqualityConstraint]
663 def get_extra_interferences(self
):
664 # type: () -> Iterable[tuple[SSAVal, SSAVal]]
668 def __init__(self
, fn
):
670 self
.id = len(fn
.ops
)
675 def __repr__(self
, just_id
=False):
676 fields_list
= [f
"#{self.id}"]
679 outputs
= self
.outputs()
680 except AttributeError:
683 for name
in fields(self
):
684 v
= getattr(self
, name
, _NOT_SET
)
685 if ((outputs
is None or name
in outputs
)
686 and isinstance(v
, SSAVal
)):
687 v
= v
.__repr
__(long=True)
688 elif isinstance(v
, Fn
):
689 v
= v
.__repr
__(short
=True)
692 fields_list
.append(f
"{name}={v}")
693 fields_str
= ', '.join(fields_list
)
694 return f
"{self.__class__.__name__}({fields_str})"
697 def get_asm_lines(self
, ctx
):
698 # type: (AsmContext) -> list[str]
699 """get the lines of assembly for this Op"""
703 @plain_data(unsafe_hash
=True, frozen
=True, repr=False)
705 class OpLoadFromStackSlot(Op
):
706 __slots__
= "dest", "src", "vl"
709 # type: () -> dict[str, SSAVal]
710 retval
= {"src": self
.src
} # type: dict[str, SSAVal[Any]]
711 if self
.vl
is not None:
712 retval
["vl"] = self
.vl
716 # type: () -> dict[str, SSAVal]
717 return {"dest": self
.dest
}
719 def __init__(self
, fn
, src
, vl
=None):
720 # type: (Fn, SSAVal[StackSlotType], SSAKnownVL | None) -> None
722 self
.dest
= SSAVal(self
, "dest", GPRRangeType(src
.ty
.length_in_slots
))
725 assert_vl_is(vl
, self
.dest
.ty
.length
)
727 def get_asm_lines(self
, ctx
):
728 # type: (AsmContext) -> list[str]
729 dest
= ctx
.gpr(self
.dest
, vec
=self
.dest
.ty
.length
!= 1)
730 src
= ctx
.stack_slot(self
.src
)
731 if ctx
.needs_sv(self
.dest
):
732 return [f
"sv.ld {dest}, {src.start_byte}(1)"]
733 return [f
"ld {dest}, {src.start_byte}(1)"]
736 @plain_data(unsafe_hash
=True, frozen
=True, repr=False)
738 class OpStoreToStackSlot(Op
):
739 __slots__
= "dest", "src", "vl"
742 # type: () -> dict[str, SSAVal]
743 retval
= {"src": self
.src
} # type: dict[str, SSAVal[Any]]
744 if self
.vl
is not None:
745 retval
["vl"] = self
.vl
749 # type: () -> dict[str, SSAVal]
750 return {"dest": self
.dest
}
752 def __init__(self
, fn
, src
, vl
=None):
753 # type: (Fn, SSAGPRRange, SSAKnownVL | None) -> None
755 self
.dest
= SSAVal(self
, "dest", StackSlotType(src
.ty
.length
))
758 assert_vl_is(vl
, src
.ty
.length
)
760 def get_asm_lines(self
, ctx
):
761 # type: (AsmContext) -> list[str]
762 src
= ctx
.gpr(self
.src
, vec
=self
.src
.ty
.length
!= 1)
763 dest
= ctx
.stack_slot(self
.dest
)
764 if ctx
.needs_sv(self
.src
):
765 return [f
"sv.std {src}, {dest.start_byte}(1)"]
766 return [f
"std {src}, {dest.start_byte}(1)"]
769 _RegSrcType
= TypeVar("_RegSrcType", bound
=RegType
)
772 @plain_data(unsafe_hash
=True, frozen
=True, repr=False)
774 class OpCopy(Op
, Generic
[_RegSrcType
, _RegType
]):
775 __slots__
= "dest", "src", "vl"
778 # type: () -> dict[str, SSAVal]
779 retval
= {"src": self
.src
} # type: dict[str, SSAVal[Any]]
780 if self
.vl
is not None:
781 retval
["vl"] = self
.vl
785 # type: () -> dict[str, SSAVal]
786 return {"dest": self
.dest
}
788 def __init__(self
, fn
, src
, dest_ty
=None, vl
=None):
789 # type: (Fn, SSAVal[_RegSrcType], _RegType | None, SSAKnownVL | None) -> None
792 dest_ty
= cast(_RegType
, src
.ty
)
793 if isinstance(src
.ty
, GPRRangeType
) \
794 and isinstance(dest_ty
, FixedGPRRangeType
):
795 if src
.ty
.length
!= dest_ty
.reg
.length
:
796 raise ValueError(f
"incompatible source and destination "
797 f
"types: {src.ty} and {dest_ty}")
798 length
= src
.ty
.length
799 elif isinstance(src
.ty
, FixedGPRRangeType
) \
800 and isinstance(dest_ty
, GPRRangeType
):
801 if src
.ty
.reg
.length
!= dest_ty
.length
:
802 raise ValueError(f
"incompatible source and destination "
803 f
"types: {src.ty} and {dest_ty}")
804 length
= src
.ty
.length
805 elif src
.ty
!= dest_ty
:
806 raise ValueError(f
"incompatible source and destination "
807 f
"types: {src.ty} and {dest_ty}")
808 elif isinstance(src
.ty
, (GPRRangeType
, FixedGPRRangeType
)):
809 length
= src
.ty
.length
813 self
.dest
= SSAVal(self
, "dest", dest_ty
) # type: SSAVal[_RegType]
816 assert_vl_is(vl
, length
)
818 def get_asm_lines(self
, ctx
):
819 # type: (AsmContext) -> list[str]
820 if ctx
.reg(self
.src
, RegLoc
) == ctx
.reg(self
.dest
, RegLoc
):
822 if (isinstance(self
.src
.ty
, (GPRRangeType
, FixedGPRRangeType
)) and
823 isinstance(self
.dest
.ty
, (GPRRangeType
, FixedGPRRangeType
))):
824 vec
= self
.dest
.ty
.length
!= 1
825 dest
= ctx
.gpr_range(self
.dest
) # type: ignore
826 src
= ctx
.gpr_range(self
.src
) # type: ignore
827 dest_s
= ctx
.gpr(self
.dest
, vec
=vec
) # type: ignore
828 src_s
= ctx
.gpr(self
.src
, vec
=vec
) # type: ignore
830 if src
.conflicts(dest
) and src
.start
> dest
.start
:
832 if ctx
.needs_sv(self
.src
, self
.dest
): # type: ignore
833 return [f
"sv.or{mrr} {dest_s}, {src_s}, {src_s}"]
834 return [f
"or {dest_s}, {src_s}, {src_s}"]
835 raise NotImplementedError
838 @plain_data(unsafe_hash
=True, frozen
=True, repr=False)
841 __slots__
= "dest", "sources"
844 # type: () -> dict[str, SSAVal]
845 return {f
"sources[{i}]": v
for i
, v
in enumerate(self
.sources
)}
848 # type: () -> dict[str, SSAVal]
849 return {"dest": self
.dest
}
851 def __init__(self
, fn
, sources
):
852 # type: (Fn, Iterable[SSAGPRRange]) -> None
854 sources
= tuple(sources
)
855 self
.dest
= SSAVal(self
, "dest", GPRRangeType(
856 sum(i
.ty
.length
for i
in sources
)))
857 self
.sources
= sources
859 def get_equality_constraints(self
):
860 # type: () -> Iterable[EqualityConstraint]
861 yield EqualityConstraint([self
.dest
], [*self
.sources
])
863 def get_asm_lines(self
, ctx
):
864 # type: (AsmContext) -> list[str]
868 @plain_data(unsafe_hash
=True, frozen
=True, repr=False)
871 __slots__
= "results", "src"
874 # type: () -> dict[str, SSAVal]
875 return {"src": self
.src
}
878 # type: () -> dict[str, SSAVal]
879 return {i
.arg_name
: i
for i
in self
.results
}
881 def __init__(self
, fn
, src
, split_indexes
):
882 # type: (Fn, SSAGPRRange, Iterable[int]) -> None
884 ranges
= [] # type: list[GPRRangeType]
886 for i
in split_indexes
:
887 if not (0 < i
< src
.ty
.length
):
888 raise ValueError(f
"invalid split index: {i}, must be in "
889 f
"0 < i < {src.ty.length}")
890 ranges
.append(GPRRangeType(i
- last
))
892 ranges
.append(GPRRangeType(src
.ty
.length
- last
))
894 self
.results
= tuple(
895 SSAVal(self
, f
"results{i}", r
) for i
, r
in enumerate(ranges
))
897 def get_equality_constraints(self
):
898 # type: () -> Iterable[EqualityConstraint]
899 yield EqualityConstraint([*self
.results
], [self
.src
])
901 def get_asm_lines(self
, ctx
):
902 # type: (AsmContext) -> list[str]
906 @plain_data(unsafe_hash
=True, frozen
=True, repr=False)
908 class OpBigIntAddSub(Op
):
909 __slots__
= "out", "lhs", "rhs", "CA_in", "CA_out", "is_sub", "vl"
912 # type: () -> dict[str, SSAVal]
913 retval
= {} # type: dict[str, SSAVal[Any]]
914 retval
["lhs"] = self
.lhs
915 retval
["rhs"] = self
.rhs
916 retval
["CA_in"] = self
.CA_in
917 if self
.vl
is not None:
918 retval
["vl"] = self
.vl
922 # type: () -> dict[str, SSAVal]
923 return {"out": self
.out
, "CA_out": self
.CA_out
}
925 def __init__(self
, fn
, lhs
, rhs
, CA_in
, is_sub
, vl
=None):
926 # type: (Fn, SSAGPRRange, SSAGPRRange, SSAVal[CAType], bool, SSAKnownVL | None) -> None
929 raise TypeError(f
"source types must match: "
930 f
"{lhs} doesn't match {rhs}")
931 self
.out
= SSAVal(self
, "out", lhs
.ty
)
935 self
.CA_out
= SSAVal(self
, "CA_out", CA_in
.ty
)
938 assert_vl_is(vl
, lhs
.ty
.length
)
940 def get_extra_interferences(self
):
941 # type: () -> Iterable[tuple[SSAVal, SSAVal]]
942 yield self
.out
, self
.lhs
943 yield self
.out
, self
.rhs
945 def get_asm_lines(self
, ctx
):
946 # type: (AsmContext) -> list[str]
947 vec
= self
.out
.ty
.length
!= 1
948 out
= ctx
.gpr(self
.out
, vec
=vec
)
949 RA
= ctx
.gpr(self
.lhs
, vec
=vec
)
950 RB
= ctx
.gpr(self
.rhs
, vec
=vec
)
954 RA
, RB
= RB
, RA
# reorder to match subfe
955 if ctx
.needs_sv(self
.out
, self
.lhs
, self
.rhs
):
956 return [f
"sv.{mnemonic} {out}, {RA}, {RB}"]
957 return [f
"{mnemonic} {out}, {RA}, {RB}"]
960 @plain_data(unsafe_hash
=True, frozen
=True, repr=False)
962 class OpBigIntMulDiv(Op
):
963 __slots__
= "RT", "RA", "RB", "RC", "RS", "is_div", "vl"
966 # type: () -> dict[str, SSAVal]
967 retval
= {} # type: dict[str, SSAVal[Any]]
968 retval
["RA"] = self
.RA
969 retval
["RB"] = self
.RB
970 retval
["RC"] = self
.RC
971 if self
.vl
is not None:
972 retval
["vl"] = self
.vl
976 # type: () -> dict[str, SSAVal]
977 return {"RT": self
.RT
, "RS": self
.RS
}
979 def __init__(self
, fn
, RA
, RB
, RC
, is_div
, vl
):
980 # type: (Fn, SSAGPRRange, SSAGPR, SSAGPR, bool, SSAKnownVL | None) -> None
982 self
.RT
= SSAVal(self
, "RT", RA
.ty
)
986 self
.RS
= SSAVal(self
, "RS", RC
.ty
)
989 assert_vl_is(vl
, RA
.ty
.length
)
991 def get_equality_constraints(self
):
992 # type: () -> Iterable[EqualityConstraint]
993 yield EqualityConstraint([self
.RC
], [self
.RS
])
995 def get_extra_interferences(self
):
996 # type: () -> Iterable[tuple[SSAVal, SSAVal]]
997 yield self
.RT
, self
.RA
998 yield self
.RT
, self
.RB
999 yield self
.RT
, self
.RC
1000 yield self
.RT
, self
.RS
1001 yield self
.RS
, self
.RA
1002 yield self
.RS
, self
.RB
1004 def get_asm_lines(self
, ctx
):
1005 # type: (AsmContext) -> list[str]
1006 vec
= self
.RT
.ty
.length
!= 1
1007 RT
= ctx
.gpr(self
.RT
, vec
=vec
)
1008 RA
= ctx
.gpr(self
.RA
, vec
=vec
)
1009 RB
= ctx
.sgpr(self
.RB
)
1010 RC
= ctx
.sgpr(self
.RC
)
1011 mnemonic
= "maddedu"
1013 mnemonic
= "divmod2du/mrr"
1014 return [f
"sv.{mnemonic} {RT}, {RA}, {RB}, {RC}"]
1019 class ShiftKind(Enum
):
1024 def make_big_int_carry_in(self
, fn
, inp
):
1025 # type: (Fn, SSAGPRRange) -> tuple[SSAGPR, list[Op]]
1026 if self
is ShiftKind
.Sl
or self
is ShiftKind
.Sr
:
1030 assert self
is ShiftKind
.Sra
1031 split
= OpSplit(fn
, inp
, [inp
.ty
.length
- 1])
1032 shr
= OpShiftImm(fn
, split
.results
[1], sh
=63, kind
=ShiftKind
.Sra
)
1033 return shr
.out
, [split
, shr
]
1035 def make_big_int_shift(self
, fn
, inp
, sh
, vl
):
1036 # type: (Fn, SSAGPRRange, SSAGPR, SSAKnownVL | None) -> tuple[SSAGPRRange, list[Op]]
1037 carry_in
, ops
= self
.make_big_int_carry_in(fn
, inp
)
1038 big_int_shift
= OpBigIntShift(fn
, inp
, sh
, carry_in
, kind
=self
, vl
=vl
)
1039 ops
.append(big_int_shift
)
1040 return big_int_shift
.out
, ops
1043 @plain_data(unsafe_hash
=True, frozen
=True, repr=False)
1045 class OpBigIntShift(Op
):
1046 __slots__
= "out", "inp", "carry_in", "_out_padding", "sh", "kind", "vl"
1049 # type: () -> dict[str, SSAVal]
1050 retval
= {} # type: dict[str, SSAVal[Any]]
1051 retval
["inp"] = self
.inp
1052 retval
["sh"] = self
.sh
1053 retval
["carry_in"] = self
.carry_in
1054 if self
.vl
is not None:
1055 retval
["vl"] = self
.vl
1059 # type: () -> dict[str, SSAVal]
1060 return {"out": self
.out
, "_out_padding": self
._out
_padding
}
1062 def __init__(self
, fn
, inp
, sh
, carry_in
, kind
, vl
=None):
1063 # type: (Fn, SSAGPRRange, SSAGPR, SSAGPR, ShiftKind, SSAKnownVL | None) -> None
1064 super().__init
__(fn
)
1065 self
.out
= SSAVal(self
, "out", inp
.ty
)
1066 self
._out
_padding
= SSAVal(self
, "_out_padding", GPRRangeType())
1067 self
.carry_in
= carry_in
1072 assert_vl_is(vl
, inp
.ty
.length
)
1074 def get_extra_interferences(self
):
1075 # type: () -> Iterable[tuple[SSAVal, SSAVal]]
1076 yield self
.out
, self
.sh
1078 def get_equality_constraints(self
):
1079 # type: () -> Iterable[EqualityConstraint]
1080 if self
.kind
is ShiftKind
.Sl
:
1081 yield EqualityConstraint([self
.carry_in
, self
.inp
],
1082 [self
.out
, self
._out
_padding
])
1084 assert self
.kind
is ShiftKind
.Sr
or self
.kind
is ShiftKind
.Sra
1085 yield EqualityConstraint([self
.inp
, self
.carry_in
],
1086 [self
._out
_padding
, self
.out
])
1088 def get_asm_lines(self
, ctx
):
1089 # type: (AsmContext) -> list[str]
1090 vec
= self
.out
.ty
.length
!= 1
1091 if self
.kind
is ShiftKind
.Sl
:
1092 RT
= ctx
.gpr(self
.out
, vec
=vec
)
1093 RA
= ctx
.gpr(self
.out
, vec
=vec
, offset
=-1)
1094 RB
= ctx
.sgpr(self
.sh
)
1095 mrr
= "/mrr" if vec
else ""
1096 return [f
"sv.dsld{mrr} {RT}, {RA}, {RB}, 0"]
1098 assert self
.kind
is ShiftKind
.Sr
or self
.kind
is ShiftKind
.Sra
1099 RT
= ctx
.gpr(self
.out
, vec
=vec
)
1100 RA
= ctx
.gpr(self
.out
, vec
=vec
, offset
=1)
1101 RB
= ctx
.sgpr(self
.sh
)
1102 return [f
"sv.dsrd {RT}, {RA}, {RB}, 1"]
1105 @plain_data(unsafe_hash
=True, frozen
=True, repr=False)
1107 class OpShiftImm(Op
):
1108 __slots__
= "out", "inp", "sh", "kind", "ca_out"
1111 # type: () -> dict[str, SSAVal]
1112 return {"inp": self
.inp
}
1115 # type: () -> dict[str, SSAVal]
1116 if self
.ca_out
is not None:
1117 return {"out": self
.out
, "ca_out": self
.ca_out
}
1118 return {"out": self
.out
}
1120 def __init__(self
, fn
, inp
, sh
, kind
):
1121 # type: (Fn, SSAGPR, int, ShiftKind) -> None
1122 super().__init
__(fn
)
1123 self
.out
= SSAVal(self
, "out", inp
.ty
)
1125 if not (0 <= sh
< 64):
1126 raise ValueError("shift amount out of range")
1129 if self
.kind
is ShiftKind
.Sra
:
1130 self
.ca_out
= SSAVal(self
, "ca_out", CAType())
1134 def get_asm_lines(self
, ctx
):
1135 # type: (AsmContext) -> list[str]
1136 out
= ctx
.sgpr(self
.out
)
1137 inp
= ctx
.sgpr(self
.inp
)
1138 if self
.kind
is ShiftKind
.Sl
:
1140 args
= f
"{self.sh}, {63 - self.sh}"
1141 elif self
.kind
is ShiftKind
.Sr
:
1143 v
= (64 - self
.sh
) % 64
1144 args
= f
"{v}, {self.sh}"
1146 assert self
.kind
is ShiftKind
.Sra
1149 if ctx
.needs_sv(self
.out
, self
.inp
):
1150 return [f
"sv.{mnemonic} {out}, {inp}, {args}"]
1151 return [f
"{mnemonic} {out}, {inp}, {args}"]
1154 @plain_data(unsafe_hash
=True, frozen
=True, repr=False)
1157 __slots__
= "out", "value", "vl"
1160 # type: () -> dict[str, SSAVal]
1161 retval
= {} # type: dict[str, SSAVal[Any]]
1162 if self
.vl
is not None:
1163 retval
["vl"] = self
.vl
1167 # type: () -> dict[str, SSAVal]
1168 return {"out": self
.out
}
1170 def __init__(self
, fn
, value
, vl
=None):
1171 # type: (Fn, int, SSAKnownVL | None) -> None
1172 super().__init
__(fn
)
1176 length
= vl
.ty
.length
1177 self
.out
= SSAVal(self
, "out", GPRRangeType(length
))
1178 if not (-1 << 15 <= value
<= (1 << 15) - 1):
1179 raise ValueError(f
"value out of range: {value}")
1182 assert_vl_is(vl
, length
)
1184 def get_asm_lines(self
, ctx
):
1185 # type: (AsmContext) -> list[str]
1186 vec
= self
.out
.ty
.length
!= 1
1187 out
= ctx
.gpr(self
.out
, vec
=vec
)
1188 if ctx
.needs_sv(self
.out
):
1189 return [f
"sv.addi {out}, 0, {self.value}"]
1190 return [f
"addi {out}, 0, {self.value}"]
1193 @plain_data(unsafe_hash
=True, frozen
=True, repr=False)
1196 __slots__
= "out", "value"
1199 # type: () -> dict[str, SSAVal]
1203 # type: () -> dict[str, SSAVal]
1204 return {"out": self
.out
}
1206 def __init__(self
, fn
, value
):
1207 # type: (Fn, bool) -> None
1208 super().__init
__(fn
)
1209 self
.out
= SSAVal(self
, "out", CAType())
1212 def get_asm_lines(self
, ctx
):
1213 # type: (AsmContext) -> list[str]
1215 return ["subfic 0, 0, -1"]
1216 return ["addic 0, 0, 0"]
1219 @plain_data(unsafe_hash
=True, frozen
=True, repr=False)
1222 __slots__
= "RT", "RA", "offset", "mem", "vl"
1225 # type: () -> dict[str, SSAVal]
1226 retval
= {} # type: dict[str, SSAVal[Any]]
1227 retval
["RA"] = self
.RA
1228 retval
["mem"] = self
.mem
1229 if self
.vl
is not None:
1230 retval
["vl"] = self
.vl
1234 # type: () -> dict[str, SSAVal]
1235 return {"RT": self
.RT
}
1237 def __init__(self
, fn
, RA
, offset
, mem
, vl
=None):
1238 # type: (Fn, SSAGPR, int, SSAVal[GlobalMemType], SSAKnownVL | None) -> None
1239 super().__init
__(fn
)
1243 length
= vl
.ty
.length
1244 self
.RT
= SSAVal(self
, "RT", GPRRangeType(length
))
1246 if not (-1 << 15 <= offset
<= (1 << 15) - 1):
1247 raise ValueError(f
"offset out of range: {offset}")
1249 raise ValueError(f
"offset not aligned: {offset}")
1250 self
.offset
= offset
1253 assert_vl_is(vl
, length
)
1255 def get_extra_interferences(self
):
1256 # type: () -> Iterable[tuple[SSAVal, SSAVal]]
1257 if self
.RT
.ty
.length
> 1:
1258 yield self
.RT
, self
.RA
1260 def get_asm_lines(self
, ctx
):
1261 # type: (AsmContext) -> list[str]
1262 RT
= ctx
.gpr(self
.RT
, vec
=self
.RT
.ty
.length
!= 1)
1263 RA
= ctx
.sgpr(self
.RA
)
1264 if ctx
.needs_sv(self
.RT
, self
.RA
):
1265 return [f
"sv.ld {RT}, {self.offset}({RA})"]
1266 return [f
"ld {RT}, {self.offset}({RA})"]
1269 @plain_data(unsafe_hash
=True, frozen
=True, repr=False)
1272 __slots__
= "RS", "RA", "offset", "mem_in", "mem_out", "vl"
1275 # type: () -> dict[str, SSAVal]
1276 retval
= {} # type: dict[str, SSAVal[Any]]
1277 retval
["RS"] = self
.RS
1278 retval
["RA"] = self
.RA
1279 retval
["mem_in"] = self
.mem_in
1280 if self
.vl
is not None:
1281 retval
["vl"] = self
.vl
1285 # type: () -> dict[str, SSAVal]
1286 return {"mem_out": self
.mem_out
}
1288 def __init__(self
, fn
, RS
, RA
, offset
, mem_in
, vl
=None):
1289 # type: (Fn, SSAGPRRange, SSAGPR, int, SSAVal[GlobalMemType], SSAKnownVL | None) -> None
1290 super().__init
__(fn
)
1293 if not (-1 << 15 <= offset
<= (1 << 15) - 1):
1294 raise ValueError(f
"offset out of range: {offset}")
1296 raise ValueError(f
"offset not aligned: {offset}")
1297 self
.offset
= offset
1298 self
.mem_in
= mem_in
1299 self
.mem_out
= SSAVal(self
, "mem_out", mem_in
.ty
)
1301 assert_vl_is(vl
, RS
.ty
.length
)
1303 def get_asm_lines(self
, ctx
):
1304 # type: (AsmContext) -> list[str]
1305 RS
= ctx
.gpr(self
.RS
, vec
=self
.RS
.ty
.length
!= 1)
1306 RA
= ctx
.sgpr(self
.RA
)
1307 if ctx
.needs_sv(self
.RS
, self
.RA
):
1308 return [f
"sv.std {RS}, {self.offset}({RA})"]
1309 return [f
"std {RS}, {self.offset}({RA})"]
1312 @plain_data(unsafe_hash
=True, frozen
=True, repr=False)
1314 class OpFuncArg(Op
):
1318 # type: () -> dict[str, SSAVal]
1322 # type: () -> dict[str, SSAVal]
1323 return {"out": self
.out
}
1325 def __init__(self
, fn
, ty
):
1326 # type: (Fn, FixedGPRRangeType) -> None
1327 super().__init
__(fn
)
1328 self
.out
= SSAVal(self
, "out", ty
)
1330 def get_asm_lines(self
, ctx
):
1331 # type: (AsmContext) -> list[str]
1335 @plain_data(unsafe_hash
=True, frozen
=True, repr=False)
1337 class OpInputMem(Op
):
1341 # type: () -> dict[str, SSAVal]
1345 # type: () -> dict[str, SSAVal]
1346 return {"out": self
.out
}
1348 def __init__(self
, fn
):
1349 # type: (Fn) -> None
1350 super().__init
__(fn
)
1351 self
.out
= SSAVal(self
, "out", GlobalMemType())
1353 def get_asm_lines(self
, ctx
):
1354 # type: (AsmContext) -> list[str]
1358 @plain_data(unsafe_hash
=True, frozen
=True, repr=False)
1360 class OpSetVLImm(Op
):
1364 # type: () -> dict[str, SSAVal]
1368 # type: () -> dict[str, SSAVal]
1369 return {"out": self
.out
}
1371 def __init__(self
, fn
, length
):
1372 # type: (Fn, int) -> None
1373 super().__init
__(fn
)
1374 self
.out
= SSAVal(self
, "out", KnownVLType(length
))
1376 def get_asm_lines(self
, ctx
):
1377 # type: (AsmContext) -> list[str]
1378 return [f
"setvl 0, 0, {self.out.ty.length}, 0, 1, 1"]
1381 def op_set_to_list(ops
):
1382 # type: (Iterable[Op]) -> list[Op]
1383 worklists
= [{}] # type: list[dict[Op, None]]
1384 inps_to_ops_map
= defaultdict(dict) # type: dict[SSAVal, dict[Op, None]]
1385 ops_to_pending_input_count_map
= {} # type: dict[Op, int]
1388 for val
in op
.inputs().values():
1390 inps_to_ops_map
[val
][op
] = None
1391 while len(worklists
) <= input_count
:
1392 worklists
.append({})
1393 ops_to_pending_input_count_map
[op
] = input_count
1394 worklists
[input_count
][op
] = None
1395 retval
= [] # type: list[Op]
1396 ready_vals
= OSet() # type: OSet[SSAVal]
1397 while len(worklists
[0]) != 0:
1398 writing_op
= next(iter(worklists
[0]))
1399 del worklists
[0][writing_op
]
1400 retval
.append(writing_op
)
1401 for val
in writing_op
.outputs().values():
1402 if val
in ready_vals
:
1403 raise ValueError(f
"multiple instructions must not write "
1404 f
"to the same SSA value: {val}")
1406 for reading_op
in inps_to_ops_map
[val
]:
1407 pending
= ops_to_pending_input_count_map
[reading_op
]
1408 del worklists
[pending
][reading_op
]
1410 worklists
[pending
][reading_op
] = None
1411 ops_to_pending_input_count_map
[reading_op
] = pending
1412 for worklist
in worklists
:
1414 raise ValueError(f
"instruction is part of a dependency loop or "
1415 f
"its inputs are never written: {op}")
1419 def generate_assembly(ops
, assigned_registers
=None):
1420 # type: (list[Op], dict[SSAVal, RegLoc] | None) -> list[str]
1421 if assigned_registers
is None:
1422 from bigint_presentation_code
.register_allocator
import \
1424 assigned_registers
= allocate_registers(ops
)
1425 ctx
= AsmContext(assigned_registers
)
1426 retval
= [] # list[str]
1428 retval
.extend(op
.get_asm_lines(ctx
))
1429 retval
.append("bclr 20, 0, 0")