2 Compiler IR for Toom-Cook algorithm generator for SVP64
4 This assumes VL != 0 throughout.
7 from abc
import ABCMeta
, abstractmethod
8 from collections
import defaultdict
9 from enum
import Enum
, EnumMeta
, unique
10 from functools
import lru_cache
11 from typing
import Any
, Generic
, Iterable
, Sequence
, Type
, TypeVar
, cast
13 from nmutil
.plain_data
import fields
, plain_data
15 from bigint_presentation_code
.util
import FMap
, OFSet
, OSet
, final
18 class ABCEnumMeta(EnumMeta
, ABCMeta
):
22 class RegLoc(metaclass
=ABCMeta
):
26 def conflicts(self
, other
):
27 # type: (RegLoc) -> bool
30 def get_subreg_at_offset(self
, subreg_type
, offset
):
31 # type: (RegType, int) -> RegLoc
32 if self
not in subreg_type
.reg_class
:
33 raise ValueError(f
"register not a member of subreg_type: "
34 f
"reg={self} subreg_type={subreg_type}")
36 raise ValueError(f
"non-zero sub-register offset not supported "
37 f
"for register: {self}")
44 @plain_data(frozen
=True, unsafe_hash
=True, repr=False)
46 class GPRRange(RegLoc
, Sequence
["GPRRange"]):
47 __slots__
= "start", "length"
49 def __init__(self
, start
, length
=None):
50 # type: (int | range, int | None) -> None
51 if isinstance(start
, range):
52 if length
is not None:
53 raise TypeError("can't specify length when input is a range")
55 raise ValueError("range must have a step of 1")
60 if length
<= 0 or start
< 0 or start
+ length
> GPR_COUNT
:
61 raise ValueError("invalid GPRRange")
67 return self
.start
+ self
.length
75 return range(self
.start
, self
.stop
, self
.step
)
80 def __getitem__(self
, item
):
81 # type: (int | slice) -> GPRRange
82 return GPRRange(self
.range[item
])
84 def __contains__(self
, value
):
85 # type: (GPRRange) -> bool
86 return value
.start
>= self
.start
and value
.stop
<= self
.stop
88 def index(self
, sub
, start
=None, end
=None):
89 # type: (GPRRange, int | None, int | None) -> int
90 r
= self
.range[start
:end
]
91 if sub
.start
< r
.start
or sub
.stop
> r
.stop
:
92 raise ValueError("GPR range not found")
93 return sub
.start
- self
.start
95 def count(self
, sub
, start
=None, end
=None):
96 # type: (GPRRange, int | None, int | None) -> int
97 r
= self
.range[start
:end
]
100 return int(sub
in GPRRange(r
))
102 def conflicts(self
, other
):
103 # type: (RegLoc) -> bool
104 if isinstance(other
, GPRRange
):
105 return self
.stop
> other
.start
and other
.stop
> self
.start
108 def get_subreg_at_offset(self
, subreg_type
, offset
):
109 # type: (RegType, int) -> GPRRange
110 if not isinstance(subreg_type
, (GPRRangeType
, FixedGPRRangeType
)):
111 raise ValueError(f
"subreg_type is not a FixedGPRRangeType or "
112 f
"GPRRangeType: {subreg_type}")
113 if offset
< 0 or offset
+ subreg_type
.length
> self
.stop
:
114 raise ValueError(f
"sub-register offset is out of range: {offset}")
115 return GPRRange(self
.start
+ offset
, subreg_type
.length
)
119 return f
"<r{self.start}>"
120 return f
"<r{self.start}..len={self.length}>"
123 SPECIAL_GPRS
= GPRRange(0), GPRRange(1), GPRRange(2), GPRRange(13)
128 class XERBit(RegLoc
, Enum
, metaclass
=ABCEnumMeta
):
131 def conflicts(self
, other
):
132 # type: (RegLoc) -> bool
133 if isinstance(other
, XERBit
):
140 class GlobalMem(RegLoc
, Enum
, metaclass
=ABCEnumMeta
):
141 """singleton representing all non-StackSlot memory -- treated as a single
142 physical register for register allocation purposes.
144 GlobalMem
= "GlobalMem"
146 def conflicts(self
, other
):
147 # type: (RegLoc) -> bool
148 if isinstance(other
, GlobalMem
):
155 class VL(RegLoc
, Enum
, metaclass
=ABCEnumMeta
):
156 VL_MAXVL
= "VL_MAXVL"
159 def conflicts(self
, other
):
160 # type: (RegLoc) -> bool
161 if isinstance(other
, VL
):
167 class RegClass(OFSet
[RegLoc
]):
168 """ an ordered set of registers.
169 earlier registers are preferred by the register allocator.
172 @lru_cache(maxsize
=None, typed
=True)
173 def max_conflicts_with(self
, other
):
174 # type: (RegClass | RegLoc) -> int
175 """the largest number of registers in `self` that a single register
176 from `other` can conflict with
178 if isinstance(other
, RegClass
):
179 return max(self
.max_conflicts_with(i
) for i
in other
)
181 return sum(other
.conflicts(i
) for i
in self
)
184 @plain_data(frozen
=True, unsafe_hash
=True)
185 class RegType(metaclass
=ABCMeta
):
191 # type: () -> RegClass
195 _RegType
= TypeVar("_RegType", bound
=RegType
)
196 _RegLoc
= TypeVar("_RegLoc", bound
=RegLoc
)
199 @plain_data(frozen
=True, eq
=False, repr=False)
201 class GPRRangeType(RegType
):
202 __slots__
= "length",
204 def __init__(self
, length
=1):
205 # type: (int) -> None
206 if length
< 1 or length
> GPR_COUNT
:
207 raise ValueError("invalid length")
211 @lru_cache(maxsize
=None)
212 def __get_reg_class(length
):
213 # type: (int) -> RegClass
215 for start
in range(GPR_COUNT
- length
):
216 reg
= GPRRange(start
, length
)
217 if any(i
in reg
for i
in SPECIAL_GPRS
):
220 return RegClass(regs
)
225 # type: () -> RegClass
226 return GPRRangeType
.__get
_reg
_class
(self
.length
)
229 def __eq__(self
, other
):
230 if isinstance(other
, GPRRangeType
):
231 return self
.length
== other
.length
236 return hash(self
.length
)
239 return f
"<gpr_ty[{self.length}]>"
242 GPRType
= GPRRangeType
243 """a length=1 GPRRangeType"""
246 @plain_data(frozen
=True, unsafe_hash
=True, repr=False)
248 class FixedGPRRangeType(RegType
):
251 def __init__(self
, reg
):
252 # type: (GPRRange) -> None
257 # type: () -> RegClass
258 return RegClass([self
.reg
])
263 return self
.reg
.length
266 return f
"<fixed({self.reg})>"
269 @plain_data(frozen
=True, unsafe_hash
=True)
271 class CAType(RegType
):
276 # type: () -> RegClass
277 return RegClass([XERBit
.CA
])
280 @plain_data(frozen
=True, unsafe_hash
=True)
282 class GlobalMemType(RegType
):
287 # type: () -> RegClass
288 return RegClass([GlobalMem
.GlobalMem
])
291 @plain_data(frozen
=True, unsafe_hash
=True)
293 class KnownVLType(RegType
):
294 __slots__
= "length",
296 def __init__(self
, length
):
297 # type: (int) -> None
298 if not (0 < length
<= 64):
299 raise ValueError("invalid VL value")
304 # type: () -> RegClass
305 return RegClass([VL
.VL_MAXVL
])
308 def assert_vl_is(vl
, expected_vl
):
309 # type: (SSAKnownVL | KnownVLType | int | None, int) -> None
312 elif isinstance(vl
, SSAVal
):
314 elif isinstance(vl
, KnownVLType
):
316 if vl
!= expected_vl
:
318 f
"wrong VL: expected {expected_vl} got {vl}")
324 @plain_data(frozen
=True, unsafe_hash
=True)
326 class StackSlot(RegLoc
):
327 __slots__
= "start_slot", "length_in_slots",
329 def __init__(self
, start_slot
, length_in_slots
):
330 # type: (int, int) -> None
331 self
.start_slot
= start_slot
332 if length_in_slots
< 1:
333 raise ValueError("invalid length_in_slots")
334 self
.length_in_slots
= length_in_slots
338 return self
.start_slot
+ self
.length_in_slots
341 def start_byte(self
):
342 return self
.start_slot
* STACK_SLOT_SIZE
344 def conflicts(self
, other
):
345 # type: (RegLoc) -> bool
346 if isinstance(other
, StackSlot
):
347 return (self
.stop_slot
> other
.start_slot
348 and other
.stop_slot
> self
.start_slot
)
351 def get_subreg_at_offset(self
, subreg_type
, offset
):
352 # type: (RegType, int) -> StackSlot
353 if not isinstance(subreg_type
, StackSlotType
):
354 raise ValueError(f
"subreg_type is not a "
355 f
"StackSlotType: {subreg_type}")
356 if offset
< 0 or offset
+ subreg_type
.length_in_slots
> self
.stop_slot
:
357 raise ValueError(f
"sub-register offset is out of range: {offset}")
358 return StackSlot(self
.start_slot
+ offset
, subreg_type
.length_in_slots
)
361 STACK_SLOT_COUNT
= 128
364 @plain_data(frozen
=True, eq
=False)
366 class StackSlotType(RegType
):
367 __slots__
= "length_in_slots",
369 def __init__(self
, length_in_slots
=1):
370 # type: (int) -> None
371 if length_in_slots
< 1:
372 raise ValueError("invalid length_in_slots")
373 self
.length_in_slots
= length_in_slots
376 @lru_cache(maxsize
=None)
377 def __get_reg_class(length_in_slots
):
378 # type: (int) -> RegClass
380 for start
in range(STACK_SLOT_COUNT
- length_in_slots
):
381 reg
= StackSlot(start
, length_in_slots
)
383 return RegClass(regs
)
387 # type: () -> RegClass
388 return StackSlotType
.__get
_reg
_class
(self
.length_in_slots
)
391 def __eq__(self
, other
):
392 if isinstance(other
, StackSlotType
):
393 return self
.length_in_slots
== other
.length_in_slots
398 return hash(self
.length_in_slots
)
401 @plain_data(frozen
=True, eq
=False, repr=False)
403 class SSAVal(Generic
[_RegType
]):
404 __slots__
= "op", "arg_name", "ty",
406 def __init__(self
, op
, arg_name
, ty
):
407 # type: (Op, str, _RegType) -> None
409 """the Op that writes this SSAVal"""
411 self
.arg_name
= arg_name
412 """the name of the argument of self.op that writes this SSAVal"""
416 def __eq__(self
, rhs
):
417 if isinstance(rhs
, SSAVal
):
418 return (self
.op
is rhs
.op
419 and self
.arg_name
== rhs
.arg_name
)
423 return hash((id(self
.op
), self
.arg_name
))
426 return f
"<#{self.op.id}.{self.arg_name}: {self.ty}>"
429 SSAGPRRange
= SSAVal
[GPRRangeType
]
430 SSAGPR
= SSAVal
[GPRType
]
431 SSAKnownVL
= SSAVal
[KnownVLType
]
435 @plain_data(unsafe_hash
=True, frozen
=True)
436 class EqualityConstraint
:
437 __slots__
= "lhs", "rhs"
439 def __init__(self
, lhs
, rhs
):
440 # type: (list[SSAVal], list[SSAVal]) -> None
443 if len(lhs
) == 0 or len(rhs
) == 0:
444 raise ValueError("can't constrain an empty list to be equal")
453 self
.ops
= [] # type: list[Op]
455 def __repr__(self
, short
=False):
458 ops
= ", ".join(op
.__repr
__(just_id
=True) for op
in self
.ops
)
459 return f
"<Fn([{ops}])>"
461 def pre_ra_sim(self
, state
):
462 # type: (PreRASimState) -> None
468 """ helper for __repr__ for when fields aren't set """
477 @plain_data(frozen
=True, unsafe_hash
=True)
478 class AsmTemplateSegment(Generic
[_RegType
], metaclass
=ABCMeta
):
479 __slots__
= "ssa_val",
481 def __init__(self
, ssa_val
):
482 # type: (SSAVal[_RegType]) -> None
483 self
.ssa_val
= ssa_val
485 def render(self
, regs
):
486 # type: (dict[SSAVal, RegLoc]) -> str
487 return self
._render
(regs
[self
.ssa_val
])
490 def _render(self
, reg
):
491 # type: (RegLoc) -> str
495 @plain_data(frozen
=True, unsafe_hash
=True)
497 class ATSGPR(AsmTemplateSegment
[GPRRangeType
]):
498 __slots__
= "offset",
500 def __init__(self
, ssa_val
, offset
=0):
501 # type: (SSAGPRRange, int) -> None
502 super().__init
__(ssa_val
)
505 def _render(self
, reg
):
506 # type: (RegLoc) -> str
507 if not isinstance(reg
, GPRRange
):
509 return str(reg
.start
+ self
.offset
)
512 @plain_data(frozen
=True, unsafe_hash
=True)
514 class ATSStackSlot(AsmTemplateSegment
[StackSlotType
]):
517 def _render(self
, reg
):
518 # type: (RegLoc) -> str
519 if not isinstance(reg
, StackSlot
):
521 return f
"{reg.start_slot}(1)"
524 @plain_data(frozen
=True, unsafe_hash
=True)
526 class ATSCopyGPRRange(AsmTemplateSegment
["GPRRangeType | FixedGPRRangeType"]):
527 __slots__
= "src_ssa_val",
529 def __init__(self
, ssa_val
, src_ssa_val
):
530 # type: (SSAVal[GPRRangeType | FixedGPRRangeType], SSAVal[GPRRangeType | FixedGPRRangeType]) -> None
531 self
.ssa_val
= ssa_val
532 self
.src_ssa_val
= src_ssa_val
534 def render(self
, regs
):
535 # type: (dict[SSAVal, RegLoc]) -> str
536 src
= regs
[self
.src_ssa_val
]
537 dest
= regs
[self
.ssa_val
]
538 if not isinstance(dest
, GPRRange
):
540 if not isinstance(src
, GPRRange
):
542 if src
.length
!= dest
.length
:
550 elif src
.conflicts(dest
) and src
.start
> dest
.start
:
552 return f
"{sv_}or{mrr} *{dest.start}, *{src.start}, *{src.start}\n"
554 def _render(self
, reg
):
555 # type: (RegLoc) -> str
556 raise TypeError("must call self.render")
560 class AsmTemplate(Sequence
["str | AsmTemplateSegment"]):
562 def __process_segments(segments
):
563 # type: (Iterable[str | AsmTemplateSegment | AsmTemplate]) -> Iterable[str | AsmTemplateSegment]
565 if isinstance(i
, AsmTemplate
):
570 def __init__(self
, segments
=()):
571 # type: (Iterable[str | AsmTemplateSegment | AsmTemplate]) -> None
572 self
.__segments
= tuple(self
.__process
_segments
(segments
))
574 def __getitem__(self
, index
):
575 # type: (int) -> str | AsmTemplateSegment
576 return self
.__segments
[index
]
579 return len(self
.__segments
)
582 return iter(self
.__segments
)
585 return hash(self
.__segments
)
587 def render(self
, regs
):
588 # type: (dict[SSAVal, RegLoc]) -> str
589 retval
= [] # type: list[str]
591 if isinstance(segment
, AsmTemplateSegment
):
592 retval
.append(segment
.render(regs
))
594 retval
.append(segment
)
595 return "".join(retval
)
600 def __init__(self
, assigned_registers
):
601 # type: (dict[SSAVal, RegLoc]) -> None
602 self
.__assigned
_registers
= assigned_registers
604 def reg(self
, ssa_val
, expected_ty
):
605 # type: (SSAVal[Any], Type[_RegLoc]) -> _RegLoc
607 reg
= self
.__assigned
_registers
[ssa_val
]
608 except KeyError as e
:
609 raise ValueError(f
"SSAVal not assigned a register: {ssa_val}")
610 wrong_len
= (isinstance(reg
, GPRRange
)
611 and reg
.length
!= ssa_val
.ty
.length
)
612 if not isinstance(reg
, expected_ty
) or wrong_len
:
614 f
"SSAVal is assigned a register of the wrong type: "
615 f
"ssa_val={ssa_val} expected_ty={expected_ty} reg={reg}")
618 def gpr_range(self
, ssa_val
):
619 # type: (SSAGPRRange | SSAVal[FixedGPRRangeType]) -> GPRRange
620 return self
.reg(ssa_val
, GPRRange
)
622 def stack_slot(self
, ssa_val
):
623 # type: (SSAVal[StackSlotType]) -> StackSlot
624 return self
.reg(ssa_val
, StackSlot
)
626 def gpr(self
, ssa_val
, vec
, offset
=0):
627 # type: (SSAGPRRange | SSAVal[FixedGPRRangeType], bool, int) -> str
628 reg
= self
.gpr_range(ssa_val
).start
+ offset
629 return "*" * vec
+ str(reg
)
631 def vgpr(self
, ssa_val
, offset
=0):
632 # type: (SSAGPRRange | SSAVal[FixedGPRRangeType], int) -> str
633 return self
.gpr(ssa_val
=ssa_val
, vec
=True, offset
=offset
)
635 def sgpr(self
, ssa_val
, offset
=0):
636 # type: (SSAGPR | SSAVal[FixedGPRRangeType], int) -> str
637 return self
.gpr(ssa_val
=ssa_val
, vec
=False, offset
=offset
)
639 def needs_sv(self
, *regs
):
640 # type: (*SSAGPRRange | SSAVal[FixedGPRRangeType]) -> bool
642 reg
= self
.gpr_range(reg
)
643 if reg
.length
!= 1 or reg
.start
>= 32:
648 GPR_SIZE_IN_BYTES
= 8
649 GPR_SIZE_IN_BITS
= GPR_SIZE_IN_BYTES
* 8
650 GPR_VALUE_MASK
= (1 << GPR_SIZE_IN_BITS
) - 1
653 @plain_data(frozen
=True)
656 __slots__
= ("gprs", "VLs", "CAs",
657 "global_mems", "stack_slots",
662 gprs
, # type: dict[SSAGPRRange, tuple[int, ...]]
663 VLs
, # type: dict[SSAKnownVL, int]
664 CAs
, # type: dict[SSAVal[CAType], bool]
665 global_mems
, # type: dict[SSAVal[GlobalMemType], FMap[int, int]]
666 stack_slots
, # type: dict[SSAVal[StackSlotType], tuple[int, ...]]
667 fixed_gprs
, # type: dict[SSAVal[FixedGPRRangeType], tuple[int, ...]]
669 # type: (...) -> None
673 self
.global_mems
= global_mems
674 self
.stack_slots
= stack_slots
675 self
.fixed_gprs
= fixed_gprs
678 @plain_data(unsafe_hash
=True, frozen
=True, repr=False)
679 class Op(metaclass
=ABCMeta
):
680 __slots__
= "id", "fn"
684 # type: () -> dict[str, SSAVal]
689 # type: () -> dict[str, SSAVal]
692 def get_equality_constraints(self
):
693 # type: () -> Iterable[EqualityConstraint]
697 def get_extra_interferences(self
):
698 # type: () -> Iterable[tuple[SSAVal, SSAVal]]
702 def __init__(self
, fn
):
704 self
.id = len(fn
.ops
)
709 def __repr__(self
, just_id
=False):
710 fields_list
= [f
"#{self.id}"]
713 outputs
= self
.outputs()
714 except AttributeError:
717 for name
in fields(self
):
718 if name
in ("id", "fn"):
720 v
= getattr(self
, name
, _NOT_SET
)
721 if (outputs
is not None and name
in outputs
722 and outputs
[name
] is v
):
723 fields_list
.append(repr(v
))
725 fields_list
.append(f
"{name}={v!r}")
726 fields_str
= ', '.join(fields_list
)
727 return f
"{self.__class__.__name__}({fields_str})"
730 def get_asm_lines(self
, ctx
):
731 # type: (AsmContext) -> list[str]
732 """get the lines of assembly for this Op"""
736 def pre_ra_sim(self
, state
):
737 # type: (PreRASimState) -> None
738 """simulate op before register allocation"""
742 @plain_data(unsafe_hash
=True, frozen
=True, repr=False)
744 class OpLoadFromStackSlot(Op
):
745 __slots__
= "dest", "src", "vl"
748 # type: () -> dict[str, SSAVal]
749 retval
= {"src": self
.src
} # type: dict[str, SSAVal[Any]]
750 if self
.vl
is not None:
751 retval
["vl"] = self
.vl
755 # type: () -> dict[str, SSAVal]
756 return {"dest": self
.dest
}
758 def __init__(self
, fn
, src
, vl
=None):
759 # type: (Fn, SSAVal[StackSlotType], SSAKnownVL | None) -> None
761 self
.dest
= SSAVal(self
, "dest", GPRRangeType(src
.ty
.length_in_slots
))
764 assert_vl_is(vl
, self
.dest
.ty
.length
)
766 def get_asm_lines(self
, ctx
):
767 # type: (AsmContext) -> list[str]
768 dest
= ctx
.gpr(self
.dest
, vec
=self
.dest
.ty
.length
!= 1)
769 src
= ctx
.stack_slot(self
.src
)
770 if ctx
.needs_sv(self
.dest
):
771 return [f
"sv.ld {dest}, {src.start_byte}(1)"]
772 return [f
"ld {dest}, {src.start_byte}(1)"]
774 def pre_ra_sim(self
, state
):
775 # type: (PreRASimState) -> None
776 """simulate op before register allocation"""
777 state
.gprs
[self
.dest
] = state
.stack_slots
[self
.src
]
780 @plain_data(unsafe_hash
=True, frozen
=True, repr=False)
782 class OpStoreToStackSlot(Op
):
783 __slots__
= "dest", "src", "vl"
786 # type: () -> dict[str, SSAVal]
787 retval
= {"src": self
.src
} # type: dict[str, SSAVal[Any]]
788 if self
.vl
is not None:
789 retval
["vl"] = self
.vl
793 # type: () -> dict[str, SSAVal]
794 return {"dest": self
.dest
}
796 def __init__(self
, fn
, src
, vl
=None):
797 # type: (Fn, SSAGPRRange, SSAKnownVL | None) -> None
799 self
.dest
= SSAVal(self
, "dest", StackSlotType(src
.ty
.length
))
802 assert_vl_is(vl
, src
.ty
.length
)
804 def get_asm_lines(self
, ctx
):
805 # type: (AsmContext) -> list[str]
806 src
= ctx
.gpr(self
.src
, vec
=self
.src
.ty
.length
!= 1)
807 dest
= ctx
.stack_slot(self
.dest
)
808 if ctx
.needs_sv(self
.src
):
809 return [f
"sv.std {src}, {dest.start_byte}(1)"]
810 return [f
"std {src}, {dest.start_byte}(1)"]
812 def pre_ra_sim(self
, state
):
813 # type: (PreRASimState) -> None
814 """simulate op before register allocation"""
815 state
.stack_slots
[self
.dest
] = state
.gprs
[self
.src
]
818 _RegSrcType
= TypeVar("_RegSrcType", bound
=RegType
)
821 @plain_data(unsafe_hash
=True, frozen
=True, repr=False)
823 class OpCopy(Op
, Generic
[_RegSrcType
, _RegType
]):
824 __slots__
= "dest", "src", "vl"
827 # type: () -> dict[str, SSAVal]
828 retval
= {"src": self
.src
} # type: dict[str, SSAVal[Any]]
829 if self
.vl
is not None:
830 retval
["vl"] = self
.vl
834 # type: () -> dict[str, SSAVal]
835 return {"dest": self
.dest
}
837 def __init__(self
, fn
, src
, dest_ty
=None, vl
=None):
838 # type: (Fn, SSAVal[_RegSrcType], _RegType | None, SSAKnownVL | None) -> None
841 dest_ty
= cast(_RegType
, src
.ty
)
842 if isinstance(src
.ty
, GPRRangeType
) \
843 and isinstance(dest_ty
, FixedGPRRangeType
):
844 if src
.ty
.length
!= dest_ty
.reg
.length
:
845 raise ValueError(f
"incompatible source and destination "
846 f
"types: {src.ty} and {dest_ty}")
847 length
= src
.ty
.length
848 elif isinstance(src
.ty
, FixedGPRRangeType
) \
849 and isinstance(dest_ty
, GPRRangeType
):
850 if src
.ty
.reg
.length
!= dest_ty
.length
:
851 raise ValueError(f
"incompatible source and destination "
852 f
"types: {src.ty} and {dest_ty}")
853 length
= src
.ty
.length
854 elif src
.ty
!= dest_ty
:
855 raise ValueError(f
"incompatible source and destination "
856 f
"types: {src.ty} and {dest_ty}")
857 elif isinstance(src
.ty
, StackSlotType
):
858 raise ValueError("can't use OpCopy on stack slots")
859 elif isinstance(src
.ty
, (GPRRangeType
, FixedGPRRangeType
)):
860 length
= src
.ty
.length
864 self
.dest
= SSAVal(self
, "dest", dest_ty
) # type: SSAVal[_RegType]
867 assert_vl_is(vl
, length
)
869 def get_asm_lines(self
, ctx
):
870 # type: (AsmContext) -> list[str]
871 if ctx
.reg(self
.src
, RegLoc
) == ctx
.reg(self
.dest
, RegLoc
):
873 if (isinstance(self
.src
.ty
, (GPRRangeType
, FixedGPRRangeType
)) and
874 isinstance(self
.dest
.ty
, (GPRRangeType
, FixedGPRRangeType
))):
875 vec
= self
.dest
.ty
.length
!= 1
876 dest
= ctx
.gpr_range(self
.dest
) # type: ignore
877 src
= ctx
.gpr_range(self
.src
) # type: ignore
878 dest_s
= ctx
.gpr(self
.dest
, vec
=vec
) # type: ignore
879 src_s
= ctx
.gpr(self
.src
, vec
=vec
) # type: ignore
881 if src
.conflicts(dest
) and src
.start
> dest
.start
:
883 if ctx
.needs_sv(self
.src
, self
.dest
): # type: ignore
884 return [f
"sv.or{mrr} {dest_s}, {src_s}, {src_s}"]
885 return [f
"or {dest_s}, {src_s}, {src_s}"]
886 raise NotImplementedError
888 def pre_ra_sim(self
, state
):
889 # type: (PreRASimState) -> None
890 if (isinstance(self
.src
.ty
, (GPRRangeType
, FixedGPRRangeType
)) and
891 isinstance(self
.dest
.ty
, (GPRRangeType
, FixedGPRRangeType
))):
892 if isinstance(self
.src
.ty
, GPRRangeType
):
893 v
= state
.gprs
[self
.src
] # type: ignore
895 v
= state
.fixed_gprs
[self
.src
] # type: ignore
896 if isinstance(self
.dest
.ty
, GPRRangeType
):
897 state
.gprs
[self
.dest
] = v
# type: ignore
899 state
.fixed_gprs
[self
.dest
] = v
# type: ignore
900 elif (isinstance(self
.src
.ty
, FixedGPRRangeType
) and
901 isinstance(self
.dest
.ty
, GPRRangeType
)):
902 state
.gprs
[self
.dest
] = state
.fixed_gprs
[self
.src
] # type: ignore
903 elif (isinstance(self
.src
.ty
, GPRRangeType
) and
904 isinstance(self
.dest
.ty
, FixedGPRRangeType
)):
905 state
.fixed_gprs
[self
.dest
] = state
.gprs
[self
.src
] # type: ignore
906 elif (isinstance(self
.src
.ty
, CAType
) and
907 self
.src
.ty
== self
.dest
.ty
):
908 state
.CAs
[self
.dest
] = state
.CAs
[self
.src
] # type: ignore
909 elif (isinstance(self
.src
.ty
, KnownVLType
) and
910 self
.src
.ty
== self
.dest
.ty
):
911 state
.VLs
[self
.dest
] = state
.VLs
[self
.src
] # type: ignore
912 elif (isinstance(self
.src
.ty
, GlobalMemType
) and
913 self
.src
.ty
== self
.dest
.ty
):
914 v
= state
.global_mems
[self
.src
] # type: ignore
915 state
.global_mems
[self
.dest
] = v
# type: ignore
917 raise NotImplementedError
920 @plain_data(unsafe_hash
=True, frozen
=True, repr=False)
923 __slots__
= "dest", "sources"
926 # type: () -> dict[str, SSAVal]
927 return {f
"sources[{i}]": v
for i
, v
in enumerate(self
.sources
)}
930 # type: () -> dict[str, SSAVal]
931 return {"dest": self
.dest
}
933 def __init__(self
, fn
, sources
):
934 # type: (Fn, Iterable[SSAGPRRange]) -> None
936 sources
= tuple(sources
)
937 self
.dest
= SSAVal(self
, "dest", GPRRangeType(
938 sum(i
.ty
.length
for i
in sources
)))
939 self
.sources
= sources
941 def get_equality_constraints(self
):
942 # type: () -> Iterable[EqualityConstraint]
943 yield EqualityConstraint([self
.dest
], [*self
.sources
])
945 def get_asm_lines(self
, ctx
):
946 # type: (AsmContext) -> list[str]
949 def pre_ra_sim(self
, state
):
950 # type: (PreRASimState) -> None
952 for src
in self
.sources
:
953 v
.extend(state
.gprs
[src
])
954 state
.gprs
[self
.dest
] = tuple(v
)
957 @plain_data(unsafe_hash
=True, frozen
=True, repr=False)
960 __slots__
= "results", "src"
963 # type: () -> dict[str, SSAVal]
964 return {"src": self
.src
}
967 # type: () -> dict[str, SSAVal]
968 return {i
.arg_name
: i
for i
in self
.results
}
970 def __init__(self
, fn
, src
, split_indexes
):
971 # type: (Fn, SSAGPRRange, Iterable[int]) -> None
973 ranges
= [] # type: list[GPRRangeType]
975 for i
in split_indexes
:
976 if not (0 < i
< src
.ty
.length
):
977 raise ValueError(f
"invalid split index: {i}, must be in "
978 f
"0 < i < {src.ty.length}")
979 ranges
.append(GPRRangeType(i
- last
))
981 ranges
.append(GPRRangeType(src
.ty
.length
- last
))
983 self
.results
= tuple(
984 SSAVal(self
, f
"results[{i}]", r
) for i
, r
in enumerate(ranges
))
986 def get_equality_constraints(self
):
987 # type: () -> Iterable[EqualityConstraint]
988 yield EqualityConstraint([*self
.results
], [self
.src
])
990 def get_asm_lines(self
, ctx
):
991 # type: (AsmContext) -> list[str]
994 def pre_ra_sim(self
, state
):
995 # type: (PreRASimState) -> None
996 rest
= state
.gprs
[self
.src
]
997 for dest
in reversed(self
.results
):
998 state
.gprs
[dest
] = rest
[-dest
.ty
.length
:]
999 rest
= rest
[:-dest
.ty
.length
]
1002 @plain_data(unsafe_hash
=True, frozen
=True, repr=False)
1004 class OpBigIntAddSub(Op
):
1005 __slots__
= "out", "lhs", "rhs", "CA_in", "CA_out", "is_sub", "vl"
1008 # type: () -> dict[str, SSAVal]
1009 retval
= {} # type: dict[str, SSAVal[Any]]
1010 retval
["lhs"] = self
.lhs
1011 retval
["rhs"] = self
.rhs
1012 retval
["CA_in"] = self
.CA_in
1013 if self
.vl
is not None:
1014 retval
["vl"] = self
.vl
1018 # type: () -> dict[str, SSAVal]
1019 return {"out": self
.out
, "CA_out": self
.CA_out
}
1021 def __init__(self
, fn
, lhs
, rhs
, CA_in
, is_sub
, vl
=None):
1022 # type: (Fn, SSAGPRRange, SSAGPRRange, SSAVal[CAType], bool, SSAKnownVL | None) -> None
1023 super().__init
__(fn
)
1024 if lhs
.ty
!= rhs
.ty
:
1025 raise TypeError(f
"source types must match: "
1026 f
"{lhs} doesn't match {rhs}")
1027 self
.out
= SSAVal(self
, "out", lhs
.ty
)
1031 self
.CA_out
= SSAVal(self
, "CA_out", CA_in
.ty
)
1032 self
.is_sub
= is_sub
1034 assert_vl_is(vl
, lhs
.ty
.length
)
1036 def get_extra_interferences(self
):
1037 # type: () -> Iterable[tuple[SSAVal, SSAVal]]
1038 yield self
.out
, self
.lhs
1039 yield self
.out
, self
.rhs
1041 def get_asm_lines(self
, ctx
):
1042 # type: (AsmContext) -> list[str]
1043 vec
= self
.out
.ty
.length
!= 1
1044 out
= ctx
.gpr(self
.out
, vec
=vec
)
1045 RA
= ctx
.gpr(self
.lhs
, vec
=vec
)
1046 RB
= ctx
.gpr(self
.rhs
, vec
=vec
)
1050 RA
, RB
= RB
, RA
# reorder to match subfe
1051 if ctx
.needs_sv(self
.out
, self
.lhs
, self
.rhs
):
1052 return [f
"sv.{mnemonic} {out}, {RA}, {RB}"]
1053 return [f
"{mnemonic} {out}, {RA}, {RB}"]
1055 def pre_ra_sim(self
, state
):
1056 # type: (PreRASimState) -> None
1057 carry
= state
.CAs
[self
.CA_in
]
1058 out
= [] # type: list[int]
1059 for l
, r
in zip(state
.gprs
[self
.lhs
], state
.gprs
[self
.rhs
]):
1061 r
= r ^ GPR_VALUE_MASK
1063 carry
= s
!= (s
& GPR_VALUE_MASK
)
1064 out
.append(s
& GPR_VALUE_MASK
)
1065 state
.CAs
[self
.CA_out
] = carry
1066 state
.gprs
[self
.out
] = tuple(out
)
1069 @plain_data(unsafe_hash
=True, frozen
=True, repr=False)
1071 class OpBigIntMulDiv(Op
):
1072 __slots__
= "RT", "RA", "RB", "RC", "RS", "is_div", "vl"
1075 # type: () -> dict[str, SSAVal]
1076 retval
= {} # type: dict[str, SSAVal[Any]]
1077 retval
["RA"] = self
.RA
1078 retval
["RB"] = self
.RB
1079 retval
["RC"] = self
.RC
1080 if self
.vl
is not None:
1081 retval
["vl"] = self
.vl
1085 # type: () -> dict[str, SSAVal]
1086 return {"RT": self
.RT
, "RS": self
.RS
}
1088 def __init__(self
, fn
, RA
, RB
, RC
, is_div
, vl
):
1089 # type: (Fn, SSAGPRRange, SSAGPR, SSAGPR, bool, SSAKnownVL | None) -> None
1090 super().__init
__(fn
)
1091 self
.RT
= SSAVal(self
, "RT", RA
.ty
)
1095 self
.RS
= SSAVal(self
, "RS", RC
.ty
)
1096 self
.is_div
= is_div
1098 assert_vl_is(vl
, RA
.ty
.length
)
1100 def get_equality_constraints(self
):
1101 # type: () -> Iterable[EqualityConstraint]
1102 yield EqualityConstraint([self
.RC
], [self
.RS
])
1104 def get_extra_interferences(self
):
1105 # type: () -> Iterable[tuple[SSAVal, SSAVal]]
1106 yield self
.RT
, self
.RA
1107 yield self
.RT
, self
.RB
1108 yield self
.RT
, self
.RC
1109 yield self
.RT
, self
.RS
1110 yield self
.RS
, self
.RA
1111 yield self
.RS
, self
.RB
1113 def get_asm_lines(self
, ctx
):
1114 # type: (AsmContext) -> list[str]
1115 vec
= self
.RT
.ty
.length
!= 1
1116 RT
= ctx
.gpr(self
.RT
, vec
=vec
)
1117 RA
= ctx
.gpr(self
.RA
, vec
=vec
)
1118 RB
= ctx
.sgpr(self
.RB
)
1119 RC
= ctx
.sgpr(self
.RC
)
1120 mnemonic
= "maddedu"
1122 mnemonic
= "divmod2du/mrr"
1123 return [f
"sv.{mnemonic} {RT}, {RA}, {RB}, {RC}"]
1125 def pre_ra_sim(self
, state
):
1126 # type: (PreRASimState) -> None
1127 carry
= state
.gprs
[self
.RC
][0]
1128 RA
= state
.gprs
[self
.RA
]
1129 RB
= state
.gprs
[self
.RB
][0]
1130 RT
= [0] * self
.RT
.ty
.length
1132 for i
in reversed(range(self
.RT
.ty
.length
)):
1133 if carry
< RB
and RB
!= 0:
1134 div
, mod
= divmod((carry
<< 64) | RA
[i
], RB
)
1135 RT
[i
] = div
& GPR_VALUE_MASK
1136 carry
= mod
& GPR_VALUE_MASK
1138 RT
[i
] = GPR_VALUE_MASK
1141 for i
in range(self
.RT
.ty
.length
):
1142 v
= RA
[i
] * RB
+ carry
1144 RT
[i
] = v
& GPR_VALUE_MASK
1145 state
.gprs
[self
.RS
] = carry
,
1146 state
.gprs
[self
.RT
] = tuple(RT
)
1151 class ShiftKind(Enum
):
1156 def make_big_int_carry_in(self
, fn
, inp
):
1157 # type: (Fn, SSAGPRRange) -> tuple[SSAGPR, list[Op]]
1158 if self
is ShiftKind
.Sl
or self
is ShiftKind
.Sr
:
1162 assert self
is ShiftKind
.Sra
1163 split
= OpSplit(fn
, inp
, [inp
.ty
.length
- 1])
1164 shr
= OpShiftImm(fn
, split
.results
[1], sh
=63, kind
=ShiftKind
.Sra
)
1165 return shr
.out
, [split
, shr
]
1167 def make_big_int_shift(self
, fn
, inp
, sh
, vl
):
1168 # type: (Fn, SSAGPRRange, SSAGPR, SSAKnownVL | None) -> tuple[SSAGPRRange, list[Op]]
1169 carry_in
, ops
= self
.make_big_int_carry_in(fn
, inp
)
1170 big_int_shift
= OpBigIntShift(fn
, inp
, sh
, carry_in
, kind
=self
, vl
=vl
)
1171 ops
.append(big_int_shift
)
1172 return big_int_shift
.out
, ops
1175 @plain_data(unsafe_hash
=True, frozen
=True, repr=False)
1177 class OpBigIntShift(Op
):
1178 __slots__
= "out", "inp", "carry_in", "_out_padding", "sh", "kind", "vl"
1181 # type: () -> dict[str, SSAVal]
1182 retval
= {} # type: dict[str, SSAVal[Any]]
1183 retval
["inp"] = self
.inp
1184 retval
["sh"] = self
.sh
1185 retval
["carry_in"] = self
.carry_in
1186 if self
.vl
is not None:
1187 retval
["vl"] = self
.vl
1191 # type: () -> dict[str, SSAVal]
1192 return {"out": self
.out
, "_out_padding": self
._out
_padding
}
1194 def __init__(self
, fn
, inp
, sh
, carry_in
, kind
, vl
=None):
1195 # type: (Fn, SSAGPRRange, SSAGPR, SSAGPR, ShiftKind, SSAKnownVL | None) -> None
1196 super().__init
__(fn
)
1197 self
.out
= SSAVal(self
, "out", inp
.ty
)
1198 self
._out
_padding
= SSAVal(self
, "_out_padding", GPRRangeType())
1199 self
.carry_in
= carry_in
1204 assert_vl_is(vl
, inp
.ty
.length
)
1206 def get_extra_interferences(self
):
1207 # type: () -> Iterable[tuple[SSAVal, SSAVal]]
1208 yield self
.out
, self
.sh
1210 def get_equality_constraints(self
):
1211 # type: () -> Iterable[EqualityConstraint]
1212 if self
.kind
is ShiftKind
.Sl
:
1213 yield EqualityConstraint([self
.carry_in
, self
.inp
],
1214 [self
.out
, self
._out
_padding
])
1216 assert self
.kind
is ShiftKind
.Sr
or self
.kind
is ShiftKind
.Sra
1217 yield EqualityConstraint([self
.inp
, self
.carry_in
],
1218 [self
._out
_padding
, self
.out
])
1220 def get_asm_lines(self
, ctx
):
1221 # type: (AsmContext) -> list[str]
1222 vec
= self
.out
.ty
.length
!= 1
1223 if self
.kind
is ShiftKind
.Sl
:
1224 RT
= ctx
.gpr(self
.out
, vec
=vec
)
1225 RA
= ctx
.gpr(self
.out
, vec
=vec
, offset
=-1)
1226 RB
= ctx
.sgpr(self
.sh
)
1227 mrr
= "/mrr" if vec
else ""
1228 return [f
"sv.dsld{mrr} {RT}, {RA}, {RB}, 0"]
1230 assert self
.kind
is ShiftKind
.Sr
or self
.kind
is ShiftKind
.Sra
1231 RT
= ctx
.gpr(self
.out
, vec
=vec
)
1232 RA
= ctx
.gpr(self
.out
, vec
=vec
, offset
=1)
1233 RB
= ctx
.sgpr(self
.sh
)
1234 return [f
"sv.dsrd {RT}, {RA}, {RB}, 1"]
1236 def pre_ra_sim(self
, state
):
1237 # type: (PreRASimState) -> None
1238 out
= [0] * self
.out
.ty
.length
1239 carry
= state
.gprs
[self
.carry_in
][0]
1240 sh
= state
.gprs
[self
.sh
][0] % 64
1241 if self
.kind
is ShiftKind
.Sl
:
1242 inp
= carry
, *state
.gprs
[self
.inp
]
1243 for i
in reversed(range(self
.out
.ty
.length
)):
1244 v
= inp
[i
] |
(inp
[i
+ 1] << 64)
1246 out
[i
] = (v
>> 64) & GPR_VALUE_MASK
1248 assert self
.kind
is ShiftKind
.Sr
or self
.kind
is ShiftKind
.Sra
1249 inp
= *state
.gprs
[self
.inp
], carry
1250 for i
in range(self
.out
.ty
.length
):
1251 v
= inp
[i
] |
(inp
[i
+ 1] << 64)
1253 out
[i
] = v
& GPR_VALUE_MASK
1254 # state.gprs[self._out_padding] is intentionally not written
1255 state
.gprs
[self
.out
] = tuple(out
)
1258 @plain_data(unsafe_hash
=True, frozen
=True, repr=False)
1260 class OpShiftImm(Op
):
1261 __slots__
= "out", "inp", "sh", "kind", "ca_out"
1264 # type: () -> dict[str, SSAVal]
1265 return {"inp": self
.inp
}
1268 # type: () -> dict[str, SSAVal]
1269 if self
.ca_out
is not None:
1270 return {"out": self
.out
, "ca_out": self
.ca_out
}
1271 return {"out": self
.out
}
1273 def __init__(self
, fn
, inp
, sh
, kind
):
1274 # type: (Fn, SSAGPR, int, ShiftKind) -> None
1275 super().__init
__(fn
)
1276 self
.out
= SSAVal(self
, "out", inp
.ty
)
1278 if not (0 <= sh
< 64):
1279 raise ValueError("shift amount out of range")
1282 if self
.kind
is ShiftKind
.Sra
:
1283 self
.ca_out
= SSAVal(self
, "ca_out", CAType())
1287 def get_asm_lines(self
, ctx
):
1288 # type: (AsmContext) -> list[str]
1289 out
= ctx
.sgpr(self
.out
)
1290 inp
= ctx
.sgpr(self
.inp
)
1291 if self
.kind
is ShiftKind
.Sl
:
1293 args
= f
"{self.sh}, {63 - self.sh}"
1294 elif self
.kind
is ShiftKind
.Sr
:
1296 v
= (64 - self
.sh
) % 64
1297 args
= f
"{v}, {self.sh}"
1299 assert self
.kind
is ShiftKind
.Sra
1302 if ctx
.needs_sv(self
.out
, self
.inp
):
1303 return [f
"sv.{mnemonic} {out}, {inp}, {args}"]
1304 return [f
"{mnemonic} {out}, {inp}, {args}"]
1306 def pre_ra_sim(self
, state
):
1307 # type: (PreRASimState) -> None
1308 inp
= state
.gprs
[self
.inp
][0]
1309 if self
.kind
is ShiftKind
.Sl
:
1310 assert self
.ca_out
is None
1311 out
= inp
<< self
.sh
1312 elif self
.kind
is ShiftKind
.Sr
:
1313 assert self
.ca_out
is None
1314 out
= inp
>> self
.sh
1316 assert self
.kind
is ShiftKind
.Sra
1317 assert self
.ca_out
is not None
1318 if inp
& (1 << 63): # sign extend
1320 out
= inp
>> self
.sh
1321 ca
= inp
< 0 and (out
<< self
.sh
) != inp
1322 state
.CAs
[self
.ca_out
] = ca
1323 state
.gprs
[self
.out
] = out
,
1326 @plain_data(unsafe_hash
=True, frozen
=True, repr=False)
1329 __slots__
= "out", "value", "vl"
1332 # type: () -> dict[str, SSAVal]
1333 retval
= {} # type: dict[str, SSAVal[Any]]
1334 if self
.vl
is not None:
1335 retval
["vl"] = self
.vl
1339 # type: () -> dict[str, SSAVal]
1340 return {"out": self
.out
}
1342 def __init__(self
, fn
, value
, vl
=None):
1343 # type: (Fn, int, SSAKnownVL | None) -> None
1344 super().__init
__(fn
)
1348 length
= vl
.ty
.length
1349 self
.out
= SSAVal(self
, "out", GPRRangeType(length
))
1350 if not (-1 << 15 <= value
<= (1 << 15) - 1):
1351 raise ValueError(f
"value out of range: {value}")
1354 assert_vl_is(vl
, length
)
1356 def get_asm_lines(self
, ctx
):
1357 # type: (AsmContext) -> list[str]
1358 vec
= self
.out
.ty
.length
!= 1
1359 out
= ctx
.gpr(self
.out
, vec
=vec
)
1360 if ctx
.needs_sv(self
.out
):
1361 return [f
"sv.addi {out}, 0, {self.value}"]
1362 return [f
"addi {out}, 0, {self.value}"]
1364 def pre_ra_sim(self
, state
):
1365 # type: (PreRASimState) -> None
1366 value
= self
.value
& GPR_VALUE_MASK
1367 state
.gprs
[self
.out
] = (value
,) * self
.out
.ty
.length
1370 @plain_data(unsafe_hash
=True, frozen
=True, repr=False)
1373 __slots__
= "out", "value"
1376 # type: () -> dict[str, SSAVal]
1380 # type: () -> dict[str, SSAVal]
1381 return {"out": self
.out
}
1383 def __init__(self
, fn
, value
):
1384 # type: (Fn, bool) -> None
1385 super().__init
__(fn
)
1386 self
.out
= SSAVal(self
, "out", CAType())
1389 def get_asm_lines(self
, ctx
):
1390 # type: (AsmContext) -> list[str]
1392 return ["subfic 0, 0, -1"]
1393 return ["addic 0, 0, 0"]
1395 def pre_ra_sim(self
, state
):
1396 # type: (PreRASimState) -> None
1397 state
.CAs
[self
.out
] = self
.value
1400 @plain_data(unsafe_hash
=True, frozen
=True, repr=False)
1403 __slots__
= "RT", "RA", "offset", "mem", "vl"
1406 # type: () -> dict[str, SSAVal]
1407 retval
= {} # type: dict[str, SSAVal[Any]]
1408 retval
["RA"] = self
.RA
1409 retval
["mem"] = self
.mem
1410 if self
.vl
is not None:
1411 retval
["vl"] = self
.vl
1415 # type: () -> dict[str, SSAVal]
1416 return {"RT": self
.RT
}
1418 def __init__(self
, fn
, RA
, offset
, mem
, vl
=None):
1419 # type: (Fn, SSAGPR, int, SSAVal[GlobalMemType], SSAKnownVL | None) -> None
1420 super().__init
__(fn
)
1424 length
= vl
.ty
.length
1425 self
.RT
= SSAVal(self
, "RT", GPRRangeType(length
))
1427 if not (-1 << 15 <= offset
<= (1 << 15) - 1):
1428 raise ValueError(f
"offset out of range: {offset}")
1430 raise ValueError(f
"offset not aligned: {offset}")
1431 self
.offset
= offset
1434 assert_vl_is(vl
, length
)
1436 def get_extra_interferences(self
):
1437 # type: () -> Iterable[tuple[SSAVal, SSAVal]]
1438 if self
.RT
.ty
.length
> 1:
1439 yield self
.RT
, self
.RA
1441 def get_asm_lines(self
, ctx
):
1442 # type: (AsmContext) -> list[str]
1443 RT
= ctx
.gpr(self
.RT
, vec
=self
.RT
.ty
.length
!= 1)
1444 RA
= ctx
.sgpr(self
.RA
)
1445 if ctx
.needs_sv(self
.RT
, self
.RA
):
1446 return [f
"sv.ld {RT}, {self.offset}({RA})"]
1447 return [f
"ld {RT}, {self.offset}({RA})"]
1449 def pre_ra_sim(self
, state
):
1450 # type: (PreRASimState) -> None
1451 addr
= state
.gprs
[self
.RA
][0]
1453 RT
= [0] * self
.RT
.ty
.length
1454 mem
= state
.global_mems
[self
.mem
]
1455 for i
in range(self
.RT
.ty
.length
):
1456 cur_addr
= (addr
+ i
* GPR_SIZE_IN_BYTES
) & GPR_VALUE_MASK
1457 if cur_addr
% GPR_SIZE_IN_BYTES
!= 0:
1458 raise ValueError(f
"can't load from unaligned address: "
1460 for j
in range(GPR_SIZE_IN_BYTES
):
1461 byte_val
= mem
.get(cur_addr
+ j
, 0) & 0xFF
1462 RT
[i
] |
= byte_val
<< (j
* 8)
1463 state
.gprs
[self
.RT
] = tuple(RT
)
1466 @plain_data(unsafe_hash
=True, frozen
=True, repr=False)
1469 __slots__
= "RS", "RA", "offset", "mem_in", "mem_out", "vl"
1472 # type: () -> dict[str, SSAVal]
1473 retval
= {} # type: dict[str, SSAVal[Any]]
1474 retval
["RS"] = self
.RS
1475 retval
["RA"] = self
.RA
1476 retval
["mem_in"] = self
.mem_in
1477 if self
.vl
is not None:
1478 retval
["vl"] = self
.vl
1482 # type: () -> dict[str, SSAVal]
1483 return {"mem_out": self
.mem_out
}
1485 def __init__(self
, fn
, RS
, RA
, offset
, mem_in
, vl
=None):
1486 # type: (Fn, SSAGPRRange, SSAGPR, int, SSAVal[GlobalMemType], SSAKnownVL | None) -> None
1487 super().__init
__(fn
)
1490 if not (-1 << 15 <= offset
<= (1 << 15) - 1):
1491 raise ValueError(f
"offset out of range: {offset}")
1493 raise ValueError(f
"offset not aligned: {offset}")
1494 self
.offset
= offset
1495 self
.mem_in
= mem_in
1496 self
.mem_out
= SSAVal(self
, "mem_out", mem_in
.ty
)
1498 assert_vl_is(vl
, RS
.ty
.length
)
1500 def get_asm_lines(self
, ctx
):
1501 # type: (AsmContext) -> list[str]
1502 RS
= ctx
.gpr(self
.RS
, vec
=self
.RS
.ty
.length
!= 1)
1503 RA
= ctx
.sgpr(self
.RA
)
1504 if ctx
.needs_sv(self
.RS
, self
.RA
):
1505 return [f
"sv.std {RS}, {self.offset}({RA})"]
1506 return [f
"std {RS}, {self.offset}({RA})"]
1508 def pre_ra_sim(self
, state
):
1509 # type: (PreRASimState) -> None
1510 mem
= dict(state
.global_mems
[self
.mem_in
])
1511 addr
= state
.gprs
[self
.RA
][0]
1513 RS
= state
.gprs
[self
.RS
]
1514 for i
in range(self
.RS
.ty
.length
):
1515 cur_addr
= (addr
+ i
* GPR_SIZE_IN_BYTES
) & GPR_VALUE_MASK
1516 if cur_addr
% GPR_SIZE_IN_BYTES
!= 0:
1517 raise ValueError(f
"can't store to unaligned address: "
1519 for j
in range(GPR_SIZE_IN_BYTES
):
1520 mem
[cur_addr
+ j
] = (RS
[i
] >> (j
* 8)) & 0xFF
1521 state
.global_mems
[self
.mem_out
] = FMap(mem
)
1524 @plain_data(unsafe_hash
=True, frozen
=True, repr=False)
1526 class OpFuncArg(Op
):
1530 # type: () -> dict[str, SSAVal]
1534 # type: () -> dict[str, SSAVal]
1535 return {"out": self
.out
}
1537 def __init__(self
, fn
, ty
):
1538 # type: (Fn, FixedGPRRangeType) -> None
1539 super().__init
__(fn
)
1540 self
.out
= SSAVal(self
, "out", ty
)
1542 def get_asm_lines(self
, ctx
):
1543 # type: (AsmContext) -> list[str]
1546 def pre_ra_sim(self
, state
):
1547 # type: (PreRASimState) -> None
1548 if self
.out
not in state
.fixed_gprs
:
1549 state
.fixed_gprs
[self
.out
] = (0,) * self
.out
.ty
.length
1552 @plain_data(unsafe_hash
=True, frozen
=True, repr=False)
1554 class OpInputMem(Op
):
1558 # type: () -> dict[str, SSAVal]
1562 # type: () -> dict[str, SSAVal]
1563 return {"out": self
.out
}
1565 def __init__(self
, fn
):
1566 # type: (Fn) -> None
1567 super().__init
__(fn
)
1568 self
.out
= SSAVal(self
, "out", GlobalMemType())
1570 def get_asm_lines(self
, ctx
):
1571 # type: (AsmContext) -> list[str]
1574 def pre_ra_sim(self
, state
):
1575 # type: (PreRASimState) -> None
1576 if self
.out
not in state
.global_mems
:
1577 state
.global_mems
[self
.out
] = FMap()
1580 @plain_data(unsafe_hash
=True, frozen
=True, repr=False)
1582 class OpSetVLImm(Op
):
1586 # type: () -> dict[str, SSAVal]
1590 # type: () -> dict[str, SSAVal]
1591 return {"out": self
.out
}
1593 def __init__(self
, fn
, length
):
1594 # type: (Fn, int) -> None
1595 super().__init
__(fn
)
1596 self
.out
= SSAVal(self
, "out", KnownVLType(length
))
1598 def get_asm_lines(self
, ctx
):
1599 # type: (AsmContext) -> list[str]
1600 return [f
"setvl 0, 0, {self.out.ty.length}, 0, 1, 1"]
1602 def pre_ra_sim(self
, state
):
1603 # type: (PreRASimState) -> None
1604 state
.VLs
[self
.out
] = self
.out
.ty
.length
1607 def op_set_to_list(ops
):
1608 # type: (Iterable[Op]) -> list[Op]
1609 worklists
= [{}] # type: list[dict[Op, None]]
1610 inps_to_ops_map
= defaultdict(dict) # type: dict[SSAVal, dict[Op, None]]
1611 ops_to_pending_input_count_map
= {} # type: dict[Op, int]
1614 for val
in op
.inputs().values():
1616 inps_to_ops_map
[val
][op
] = None
1617 while len(worklists
) <= input_count
:
1618 worklists
.append({})
1619 ops_to_pending_input_count_map
[op
] = input_count
1620 worklists
[input_count
][op
] = None
1621 retval
= [] # type: list[Op]
1622 ready_vals
= OSet() # type: OSet[SSAVal]
1623 while len(worklists
[0]) != 0:
1624 writing_op
= next(iter(worklists
[0]))
1625 del worklists
[0][writing_op
]
1626 retval
.append(writing_op
)
1627 for val
in writing_op
.outputs().values():
1628 if val
in ready_vals
:
1629 raise ValueError(f
"multiple instructions must not write "
1630 f
"to the same SSA value: {val}")
1632 for reading_op
in inps_to_ops_map
[val
]:
1633 pending
= ops_to_pending_input_count_map
[reading_op
]
1634 del worklists
[pending
][reading_op
]
1636 worklists
[pending
][reading_op
] = None
1637 ops_to_pending_input_count_map
[reading_op
] = pending
1638 for worklist
in worklists
:
1640 raise ValueError(f
"instruction is part of a dependency loop or "
1641 f
"its inputs are never written: {op}")
1645 def generate_assembly(ops
, assigned_registers
=None):
1646 # type: (list[Op], dict[SSAVal, RegLoc] | None) -> list[str]
1647 if assigned_registers
is None:
1648 from bigint_presentation_code
.register_allocator
import \
1650 assigned_registers
= allocate_registers(ops
)
1651 ctx
= AsmContext(assigned_registers
)
1652 retval
= [] # list[str]
1654 retval
.extend(op
.get_asm_lines(ctx
))
1655 retval
.append("bclr 20, 0, 0")