3 Compiler IR for Toom-Cook algorithm generator for SVP64
5 This assumes VL != 0 throughout.
8 from abc
import ABCMeta
, abstractmethod
9 from collections
import defaultdict
10 from enum
import Enum
, EnumMeta
, unique
11 from functools
import lru_cache
12 from typing
import Any
, Generic
, Iterable
, Sequence
, Type
, TypeVar
, cast
14 from nmutil
.plain_data
import fields
, plain_data
16 from bigint_presentation_code
.type_util
import final
17 from bigint_presentation_code
.util
import FMap
, OFSet
, OSet
20 class ABCEnumMeta(EnumMeta
, ABCMeta
):
24 class RegLoc(metaclass
=ABCMeta
):
28 def conflicts(self
, other
):
29 # type: (RegLoc) -> bool
32 def get_subreg_at_offset(self
, subreg_type
, offset
):
33 # type: (RegType, int) -> RegLoc
34 if self
not in subreg_type
.reg_class
:
35 raise ValueError(f
"register not a member of subreg_type: "
36 f
"reg={self} subreg_type={subreg_type}")
38 raise ValueError(f
"non-zero sub-register offset not supported "
39 f
"for register: {self}")
46 @plain_data(frozen
=True, unsafe_hash
=True, repr=False)
48 class GPRRange(RegLoc
, Sequence
["GPRRange"]):
49 __slots__
= "start", "length"
51 def __init__(self
, start
, length
=None):
52 # type: (int | range, int | None) -> None
53 if isinstance(start
, range):
54 if length
is not None:
55 raise TypeError("can't specify length when input is a range")
57 raise ValueError("range must have a step of 1")
62 if length
<= 0 or start
< 0 or start
+ length
> GPR_COUNT
:
63 raise ValueError("invalid GPRRange")
69 return self
.start
+ self
.length
77 return range(self
.start
, self
.stop
, self
.step
)
82 def __getitem__(self
, item
):
83 # type: (int | slice) -> GPRRange
84 return GPRRange(self
.range[item
])
86 def __contains__(self
, value
):
87 # type: (GPRRange) -> bool
88 return value
.start
>= self
.start
and value
.stop
<= self
.stop
90 def index(self
, sub
, start
=None, end
=None):
91 # type: (GPRRange, int | None, int | None) -> int
92 r
= self
.range[start
:end
]
93 if sub
.start
< r
.start
or sub
.stop
> r
.stop
:
94 raise ValueError("GPR range not found")
95 return sub
.start
- self
.start
97 def count(self
, sub
, start
=None, end
=None):
98 # type: (GPRRange, int | None, int | None) -> int
99 r
= self
.range[start
:end
]
102 return int(sub
in GPRRange(r
))
104 def conflicts(self
, other
):
105 # type: (RegLoc) -> bool
106 if isinstance(other
, GPRRange
):
107 return self
.stop
> other
.start
and other
.stop
> self
.start
110 def get_subreg_at_offset(self
, subreg_type
, offset
):
111 # type: (RegType, int) -> GPRRange
112 if not isinstance(subreg_type
, (GPRRangeType
, FixedGPRRangeType
)):
113 raise ValueError(f
"subreg_type is not a FixedGPRRangeType or "
114 f
"GPRRangeType: {subreg_type}")
115 if offset
< 0 or offset
+ subreg_type
.length
> self
.stop
:
116 raise ValueError(f
"sub-register offset is out of range: {offset}")
117 return GPRRange(self
.start
+ offset
, subreg_type
.length
)
121 return f
"<r{self.start}>"
122 return f
"<r{self.start}..len={self.length}>"
125 SPECIAL_GPRS
= GPRRange(0), GPRRange(1), GPRRange(2), GPRRange(13)
130 class XERBit(RegLoc
, Enum
, metaclass
=ABCEnumMeta
):
133 def conflicts(self
, other
):
134 # type: (RegLoc) -> bool
135 if isinstance(other
, XERBit
):
142 class GlobalMem(RegLoc
, Enum
, metaclass
=ABCEnumMeta
):
143 """singleton representing all non-StackSlot memory -- treated as a single
144 physical register for register allocation purposes.
146 GlobalMem
= "GlobalMem"
148 def conflicts(self
, other
):
149 # type: (RegLoc) -> bool
150 if isinstance(other
, GlobalMem
):
157 class VL(RegLoc
, Enum
, metaclass
=ABCEnumMeta
):
158 VL_MAXVL
= "VL_MAXVL"
161 def conflicts(self
, other
):
162 # type: (RegLoc) -> bool
163 if isinstance(other
, VL
):
169 class RegClass(OFSet
[RegLoc
]):
170 """ an ordered set of registers.
171 earlier registers are preferred by the register allocator.
174 @lru_cache(maxsize
=None, typed
=True)
175 def max_conflicts_with(self
, other
):
176 # type: (RegClass | RegLoc) -> int
177 """the largest number of registers in `self` that a single register
178 from `other` can conflict with
180 if isinstance(other
, RegClass
):
181 return max(self
.max_conflicts_with(i
) for i
in other
)
183 return sum(other
.conflicts(i
) for i
in self
)
186 @plain_data(frozen
=True, unsafe_hash
=True)
187 class RegType(metaclass
=ABCMeta
):
193 # type: () -> RegClass
197 _RegType
= TypeVar("_RegType", bound
=RegType
)
198 _RegLoc
= TypeVar("_RegLoc", bound
=RegLoc
)
201 @plain_data(frozen
=True, eq
=False, repr=False)
203 class GPRRangeType(RegType
):
204 __slots__
= "length",
206 def __init__(self
, length
=1):
207 # type: (int) -> None
208 if length
< 1 or length
> GPR_COUNT
:
209 raise ValueError("invalid length")
213 @lru_cache(maxsize
=None)
214 def __get_reg_class(length
):
215 # type: (int) -> RegClass
217 for start
in range(GPR_COUNT
- length
):
218 reg
= GPRRange(start
, length
)
219 if any(i
in reg
for i
in SPECIAL_GPRS
):
222 return RegClass(regs
)
227 # type: () -> RegClass
228 return GPRRangeType
.__get
_reg
_class
(self
.length
)
231 def __eq__(self
, other
):
232 if isinstance(other
, GPRRangeType
):
233 return self
.length
== other
.length
238 return hash(self
.length
)
241 return f
"<gpr_ty[{self.length}]>"
244 GPRType
= GPRRangeType
245 """a length=1 GPRRangeType"""
248 @plain_data(frozen
=True, unsafe_hash
=True, repr=False)
250 class FixedGPRRangeType(RegType
):
253 def __init__(self
, reg
):
254 # type: (GPRRange) -> None
259 # type: () -> RegClass
260 return RegClass([self
.reg
])
265 return self
.reg
.length
268 return f
"<fixed({self.reg})>"
271 @plain_data(frozen
=True, unsafe_hash
=True)
273 class CAType(RegType
):
278 # type: () -> RegClass
279 return RegClass([XERBit
.CA
])
282 @plain_data(frozen
=True, unsafe_hash
=True)
284 class GlobalMemType(RegType
):
289 # type: () -> RegClass
290 return RegClass([GlobalMem
.GlobalMem
])
293 @plain_data(frozen
=True, unsafe_hash
=True)
295 class KnownVLType(RegType
):
296 __slots__
= "length",
298 def __init__(self
, length
):
299 # type: (int) -> None
300 if not (0 < length
<= 64):
301 raise ValueError("invalid VL value")
306 # type: () -> RegClass
307 return RegClass([VL
.VL_MAXVL
])
310 def assert_vl_is(vl
, expected_vl
):
311 # type: (SSAKnownVL | KnownVLType | int | None, int) -> None
314 elif isinstance(vl
, SSAVal
):
316 elif isinstance(vl
, KnownVLType
):
318 if vl
!= expected_vl
:
320 f
"wrong VL: expected {expected_vl} got {vl}")
326 @plain_data(frozen
=True, unsafe_hash
=True)
328 class StackSlot(RegLoc
):
329 __slots__
= "start_slot", "length_in_slots",
331 def __init__(self
, start_slot
, length_in_slots
):
332 # type: (int, int) -> None
333 self
.start_slot
= start_slot
334 if length_in_slots
< 1:
335 raise ValueError("invalid length_in_slots")
336 self
.length_in_slots
= length_in_slots
340 return self
.start_slot
+ self
.length_in_slots
343 def start_byte(self
):
344 return self
.start_slot
* STACK_SLOT_SIZE
346 def conflicts(self
, other
):
347 # type: (RegLoc) -> bool
348 if isinstance(other
, StackSlot
):
349 return (self
.stop_slot
> other
.start_slot
350 and other
.stop_slot
> self
.start_slot
)
353 def get_subreg_at_offset(self
, subreg_type
, offset
):
354 # type: (RegType, int) -> StackSlot
355 if not isinstance(subreg_type
, StackSlotType
):
356 raise ValueError(f
"subreg_type is not a "
357 f
"StackSlotType: {subreg_type}")
358 if offset
< 0 or offset
+ subreg_type
.length_in_slots
> self
.stop_slot
:
359 raise ValueError(f
"sub-register offset is out of range: {offset}")
360 return StackSlot(self
.start_slot
+ offset
, subreg_type
.length_in_slots
)
363 STACK_SLOT_COUNT
= 128
366 @plain_data(frozen
=True, eq
=False)
368 class StackSlotType(RegType
):
369 __slots__
= "length_in_slots",
371 def __init__(self
, length_in_slots
=1):
372 # type: (int) -> None
373 if length_in_slots
< 1:
374 raise ValueError("invalid length_in_slots")
375 self
.length_in_slots
= length_in_slots
378 @lru_cache(maxsize
=None)
379 def __get_reg_class(length_in_slots
):
380 # type: (int) -> RegClass
382 for start
in range(STACK_SLOT_COUNT
- length_in_slots
):
383 reg
= StackSlot(start
, length_in_slots
)
385 return RegClass(regs
)
389 # type: () -> RegClass
390 return StackSlotType
.__get
_reg
_class
(self
.length_in_slots
)
393 def __eq__(self
, other
):
394 if isinstance(other
, StackSlotType
):
395 return self
.length_in_slots
== other
.length_in_slots
400 return hash(self
.length_in_slots
)
403 @plain_data(frozen
=True, eq
=False, repr=False)
405 class SSAVal(Generic
[_RegType
]):
406 __slots__
= "op", "arg_name", "ty",
408 def __init__(self
, op
, arg_name
, ty
):
409 # type: (Op, str, _RegType) -> None
411 """the Op that writes this SSAVal"""
413 self
.arg_name
= arg_name
414 """the name of the argument of self.op that writes this SSAVal"""
418 def __eq__(self
, rhs
):
419 if isinstance(rhs
, SSAVal
):
420 return (self
.op
is rhs
.op
421 and self
.arg_name
== rhs
.arg_name
)
425 return hash((id(self
.op
), self
.arg_name
))
428 return f
"<#{self.op.id}.{self.arg_name}: {self.ty}>"
431 SSAGPRRange
= SSAVal
[GPRRangeType
]
432 SSAGPR
= SSAVal
[GPRType
]
433 SSAKnownVL
= SSAVal
[KnownVLType
]
437 @plain_data(unsafe_hash
=True, frozen
=True)
438 class EqualityConstraint
:
439 __slots__
= "lhs", "rhs"
441 def __init__(self
, lhs
, rhs
):
442 # type: (list[SSAVal], list[SSAVal]) -> None
445 if len(lhs
) == 0 or len(rhs
) == 0:
446 raise ValueError("can't constrain an empty list to be equal")
455 self
.ops
= [] # type: list[Op]
457 def __repr__(self
, short
=False):
460 ops
= ", ".join(op
.__repr
__(just_id
=True) for op
in self
.ops
)
461 return f
"<Fn([{ops}])>"
463 def pre_ra_sim(self
, state
):
464 # type: (PreRASimState) -> None
470 """ helper for __repr__ for when fields aren't set """
481 def __init__(self
, assigned_registers
):
482 # type: (dict[SSAVal, RegLoc]) -> None
483 self
.__assigned
_registers
= assigned_registers
485 def reg(self
, ssa_val
, expected_ty
):
486 # type: (SSAVal[Any], Type[_RegLoc]) -> _RegLoc
488 reg
= self
.__assigned
_registers
[ssa_val
]
489 except KeyError as e
:
490 raise ValueError(f
"SSAVal not assigned a register: {ssa_val}")
491 wrong_len
= (isinstance(reg
, GPRRange
)
492 and reg
.length
!= ssa_val
.ty
.length
)
493 if not isinstance(reg
, expected_ty
) or wrong_len
:
495 f
"SSAVal is assigned a register of the wrong type: "
496 f
"ssa_val={ssa_val} expected_ty={expected_ty} reg={reg}")
499 def gpr_range(self
, ssa_val
):
500 # type: (SSAGPRRange | SSAVal[FixedGPRRangeType]) -> GPRRange
501 return self
.reg(ssa_val
, GPRRange
)
503 def stack_slot(self
, ssa_val
):
504 # type: (SSAVal[StackSlotType]) -> StackSlot
505 return self
.reg(ssa_val
, StackSlot
)
507 def gpr(self
, ssa_val
, vec
, offset
=0):
508 # type: (SSAGPRRange | SSAVal[FixedGPRRangeType], bool, int) -> str
509 reg
= self
.gpr_range(ssa_val
).start
+ offset
510 return "*" * vec
+ str(reg
)
512 def vgpr(self
, ssa_val
, offset
=0):
513 # type: (SSAGPRRange | SSAVal[FixedGPRRangeType], int) -> str
514 return self
.gpr(ssa_val
=ssa_val
, vec
=True, offset
=offset
)
516 def sgpr(self
, ssa_val
, offset
=0):
517 # type: (SSAGPR | SSAVal[FixedGPRRangeType], int) -> str
518 return self
.gpr(ssa_val
=ssa_val
, vec
=False, offset
=offset
)
520 def needs_sv(self
, *regs
):
521 # type: (*SSAGPRRange | SSAVal[FixedGPRRangeType]) -> bool
523 reg
= self
.gpr_range(reg
)
524 if reg
.length
!= 1 or reg
.start
>= 32:
529 GPR_SIZE_IN_BYTES
= 8
530 GPR_SIZE_IN_BITS
= GPR_SIZE_IN_BYTES
* 8
531 GPR_VALUE_MASK
= (1 << GPR_SIZE_IN_BITS
) - 1
534 @plain_data(frozen
=True)
537 __slots__
= ("gprs", "VLs", "CAs",
538 "global_mems", "stack_slots",
543 gprs
, # type: dict[SSAGPRRange, tuple[int, ...]]
544 VLs
, # type: dict[SSAKnownVL, int]
545 CAs
, # type: dict[SSAVal[CAType], bool]
546 global_mems
, # type: dict[SSAVal[GlobalMemType], FMap[int, int]]
547 stack_slots
, # type: dict[SSAVal[StackSlotType], tuple[int, ...]]
548 fixed_gprs
, # type: dict[SSAVal[FixedGPRRangeType], tuple[int, ...]]
550 # type: (...) -> None
554 self
.global_mems
= global_mems
555 self
.stack_slots
= stack_slots
556 self
.fixed_gprs
= fixed_gprs
559 @plain_data(unsafe_hash
=True, frozen
=True, repr=False)
560 class Op(metaclass
=ABCMeta
):
561 __slots__
= "id", "fn"
565 # type: () -> dict[str, SSAVal]
570 # type: () -> dict[str, SSAVal]
573 def get_equality_constraints(self
):
574 # type: () -> Iterable[EqualityConstraint]
578 def get_extra_interferences(self
):
579 # type: () -> Iterable[tuple[SSAVal, SSAVal]]
583 def __init__(self
, fn
):
585 self
.id = len(fn
.ops
)
590 def __repr__(self
, just_id
=False):
591 fields_list
= [f
"#{self.id}"]
594 outputs
= self
.outputs()
595 except AttributeError:
598 for name
in fields(self
):
599 if name
in ("id", "fn"):
601 v
= getattr(self
, name
, _NOT_SET
)
602 if (outputs
is not None and name
in outputs
603 and outputs
[name
] is v
):
604 fields_list
.append(repr(v
))
606 fields_list
.append(f
"{name}={v!r}")
607 fields_str
= ', '.join(fields_list
)
608 return f
"{self.__class__.__name__}({fields_str})"
611 def get_asm_lines(self
, ctx
):
612 # type: (AsmContext) -> list[str]
613 """get the lines of assembly for this Op"""
617 def pre_ra_sim(self
, state
):
618 # type: (PreRASimState) -> None
619 """simulate op before register allocation"""
623 @plain_data(unsafe_hash
=True, frozen
=True, repr=False)
625 class OpLoadFromStackSlot(Op
):
626 __slots__
= "dest", "src", "vl"
629 # type: () -> dict[str, SSAVal]
630 retval
= {"src": self
.src
} # type: dict[str, SSAVal[Any]]
631 if self
.vl
is not None:
632 retval
["vl"] = self
.vl
636 # type: () -> dict[str, SSAVal]
637 return {"dest": self
.dest
}
639 def __init__(self
, fn
, src
, vl
=None):
640 # type: (Fn, SSAVal[StackSlotType], SSAKnownVL | None) -> None
642 self
.dest
= SSAVal(self
, "dest", GPRRangeType(src
.ty
.length_in_slots
))
645 assert_vl_is(vl
, self
.dest
.ty
.length
)
647 def get_asm_lines(self
, ctx
):
648 # type: (AsmContext) -> list[str]
649 dest
= ctx
.gpr(self
.dest
, vec
=self
.dest
.ty
.length
!= 1)
650 src
= ctx
.stack_slot(self
.src
)
651 if ctx
.needs_sv(self
.dest
):
652 return [f
"sv.ld {dest}, {src.start_byte}(1)"]
653 return [f
"ld {dest}, {src.start_byte}(1)"]
655 def pre_ra_sim(self
, state
):
656 # type: (PreRASimState) -> None
657 """simulate op before register allocation"""
658 state
.gprs
[self
.dest
] = state
.stack_slots
[self
.src
]
661 @plain_data(unsafe_hash
=True, frozen
=True, repr=False)
663 class OpStoreToStackSlot(Op
):
664 __slots__
= "dest", "src", "vl"
667 # type: () -> dict[str, SSAVal]
668 retval
= {"src": self
.src
} # type: dict[str, SSAVal[Any]]
669 if self
.vl
is not None:
670 retval
["vl"] = self
.vl
674 # type: () -> dict[str, SSAVal]
675 return {"dest": self
.dest
}
677 def __init__(self
, fn
, src
, vl
=None):
678 # type: (Fn, SSAGPRRange, SSAKnownVL | None) -> None
680 self
.dest
= SSAVal(self
, "dest", StackSlotType(src
.ty
.length
))
683 assert_vl_is(vl
, src
.ty
.length
)
685 def get_asm_lines(self
, ctx
):
686 # type: (AsmContext) -> list[str]
687 src
= ctx
.gpr(self
.src
, vec
=self
.src
.ty
.length
!= 1)
688 dest
= ctx
.stack_slot(self
.dest
)
689 if ctx
.needs_sv(self
.src
):
690 return [f
"sv.std {src}, {dest.start_byte}(1)"]
691 return [f
"std {src}, {dest.start_byte}(1)"]
693 def pre_ra_sim(self
, state
):
694 # type: (PreRASimState) -> None
695 """simulate op before register allocation"""
696 state
.stack_slots
[self
.dest
] = state
.gprs
[self
.src
]
699 _RegSrcType
= TypeVar("_RegSrcType", bound
=RegType
)
702 @plain_data(unsafe_hash
=True, frozen
=True, repr=False)
704 class OpCopy(Op
, Generic
[_RegSrcType
, _RegType
]):
705 __slots__
= "dest", "src", "vl"
708 # type: () -> dict[str, SSAVal]
709 retval
= {"src": self
.src
} # type: dict[str, SSAVal[Any]]
710 if self
.vl
is not None:
711 retval
["vl"] = self
.vl
715 # type: () -> dict[str, SSAVal]
716 return {"dest": self
.dest
}
718 def __init__(self
, fn
, src
, dest_ty
=None, vl
=None):
719 # type: (Fn, SSAVal[_RegSrcType], _RegType | None, SSAKnownVL | None) -> None
722 dest_ty
= cast(_RegType
, src
.ty
)
723 if isinstance(src
.ty
, GPRRangeType
) \
724 and isinstance(dest_ty
, FixedGPRRangeType
):
725 if src
.ty
.length
!= dest_ty
.reg
.length
:
726 raise ValueError(f
"incompatible source and destination "
727 f
"types: {src.ty} and {dest_ty}")
728 length
= src
.ty
.length
729 elif isinstance(src
.ty
, FixedGPRRangeType
) \
730 and isinstance(dest_ty
, GPRRangeType
):
731 if src
.ty
.reg
.length
!= dest_ty
.length
:
732 raise ValueError(f
"incompatible source and destination "
733 f
"types: {src.ty} and {dest_ty}")
734 length
= src
.ty
.length
735 elif src
.ty
!= dest_ty
:
736 raise ValueError(f
"incompatible source and destination "
737 f
"types: {src.ty} and {dest_ty}")
738 elif isinstance(src
.ty
, StackSlotType
):
739 raise ValueError("can't use OpCopy on stack slots")
740 elif isinstance(src
.ty
, (GPRRangeType
, FixedGPRRangeType
)):
741 length
= src
.ty
.length
745 self
.dest
= SSAVal(self
, "dest", dest_ty
) # type: SSAVal[_RegType]
748 assert_vl_is(vl
, length
)
750 def get_asm_lines(self
, ctx
):
751 # type: (AsmContext) -> list[str]
752 if ctx
.reg(self
.src
, RegLoc
) == ctx
.reg(self
.dest
, RegLoc
):
754 if (isinstance(self
.src
.ty
, (GPRRangeType
, FixedGPRRangeType
)) and
755 isinstance(self
.dest
.ty
, (GPRRangeType
, FixedGPRRangeType
))):
756 vec
= self
.dest
.ty
.length
!= 1
757 dest
= ctx
.gpr_range(self
.dest
) # type: ignore
758 src
= ctx
.gpr_range(self
.src
) # type: ignore
759 dest_s
= ctx
.gpr(self
.dest
, vec
=vec
) # type: ignore
760 src_s
= ctx
.gpr(self
.src
, vec
=vec
) # type: ignore
762 if src
.conflicts(dest
) and src
.start
> dest
.start
:
764 if ctx
.needs_sv(self
.src
, self
.dest
): # type: ignore
765 return [f
"sv.or{mrr} {dest_s}, {src_s}, {src_s}"]
766 return [f
"or {dest_s}, {src_s}, {src_s}"]
767 raise NotImplementedError
769 def pre_ra_sim(self
, state
):
770 # type: (PreRASimState) -> None
771 if (isinstance(self
.src
.ty
, (GPRRangeType
, FixedGPRRangeType
)) and
772 isinstance(self
.dest
.ty
, (GPRRangeType
, FixedGPRRangeType
))):
773 if isinstance(self
.src
.ty
, GPRRangeType
):
774 v
= state
.gprs
[self
.src
] # type: ignore
776 v
= state
.fixed_gprs
[self
.src
] # type: ignore
777 if isinstance(self
.dest
.ty
, GPRRangeType
):
778 state
.gprs
[self
.dest
] = v
# type: ignore
780 state
.fixed_gprs
[self
.dest
] = v
# type: ignore
781 elif (isinstance(self
.src
.ty
, FixedGPRRangeType
) and
782 isinstance(self
.dest
.ty
, GPRRangeType
)):
783 state
.gprs
[self
.dest
] = state
.fixed_gprs
[self
.src
] # type: ignore
784 elif (isinstance(self
.src
.ty
, GPRRangeType
) and
785 isinstance(self
.dest
.ty
, FixedGPRRangeType
)):
786 state
.fixed_gprs
[self
.dest
] = state
.gprs
[self
.src
] # type: ignore
787 elif (isinstance(self
.src
.ty
, CAType
) and
788 self
.src
.ty
== self
.dest
.ty
):
789 state
.CAs
[self
.dest
] = state
.CAs
[self
.src
] # type: ignore
790 elif (isinstance(self
.src
.ty
, KnownVLType
) and
791 self
.src
.ty
== self
.dest
.ty
):
792 state
.VLs
[self
.dest
] = state
.VLs
[self
.src
] # type: ignore
793 elif (isinstance(self
.src
.ty
, GlobalMemType
) and
794 self
.src
.ty
== self
.dest
.ty
):
795 v
= state
.global_mems
[self
.src
] # type: ignore
796 state
.global_mems
[self
.dest
] = v
# type: ignore
798 raise NotImplementedError
801 @plain_data(unsafe_hash
=True, frozen
=True, repr=False)
804 __slots__
= "dest", "sources"
807 # type: () -> dict[str, SSAVal]
808 return {f
"sources[{i}]": v
for i
, v
in enumerate(self
.sources
)}
811 # type: () -> dict[str, SSAVal]
812 return {"dest": self
.dest
}
814 def __init__(self
, fn
, sources
):
815 # type: (Fn, Iterable[SSAGPRRange]) -> None
817 sources
= tuple(sources
)
818 self
.dest
= SSAVal(self
, "dest", GPRRangeType(
819 sum(i
.ty
.length
for i
in sources
)))
820 self
.sources
= sources
822 def get_equality_constraints(self
):
823 # type: () -> Iterable[EqualityConstraint]
824 yield EqualityConstraint([self
.dest
], [*self
.sources
])
826 def get_asm_lines(self
, ctx
):
827 # type: (AsmContext) -> list[str]
830 def pre_ra_sim(self
, state
):
831 # type: (PreRASimState) -> None
833 for src
in self
.sources
:
834 v
.extend(state
.gprs
[src
])
835 state
.gprs
[self
.dest
] = tuple(v
)
838 @plain_data(unsafe_hash
=True, frozen
=True, repr=False)
841 __slots__
= "results", "src"
844 # type: () -> dict[str, SSAVal]
845 return {"src": self
.src
}
848 # type: () -> dict[str, SSAVal]
849 return {i
.arg_name
: i
for i
in self
.results
}
851 def __init__(self
, fn
, src
, split_indexes
):
852 # type: (Fn, SSAGPRRange, Iterable[int]) -> None
854 ranges
= [] # type: list[GPRRangeType]
856 for i
in split_indexes
:
857 if not (0 < i
< src
.ty
.length
):
858 raise ValueError(f
"invalid split index: {i}, must be in "
859 f
"0 < i < {src.ty.length}")
860 ranges
.append(GPRRangeType(i
- last
))
862 ranges
.append(GPRRangeType(src
.ty
.length
- last
))
864 self
.results
= tuple(
865 SSAVal(self
, f
"results[{i}]", r
) for i
, r
in enumerate(ranges
))
867 def get_equality_constraints(self
):
868 # type: () -> Iterable[EqualityConstraint]
869 yield EqualityConstraint([*self
.results
], [self
.src
])
871 def get_asm_lines(self
, ctx
):
872 # type: (AsmContext) -> list[str]
875 def pre_ra_sim(self
, state
):
876 # type: (PreRASimState) -> None
877 rest
= state
.gprs
[self
.src
]
878 for dest
in reversed(self
.results
):
879 state
.gprs
[dest
] = rest
[-dest
.ty
.length
:]
880 rest
= rest
[:-dest
.ty
.length
]
883 @plain_data(unsafe_hash
=True, frozen
=True, repr=False)
885 class OpBigIntAddSub(Op
):
886 __slots__
= "out", "lhs", "rhs", "CA_in", "CA_out", "is_sub", "vl"
889 # type: () -> dict[str, SSAVal]
890 retval
= {} # type: dict[str, SSAVal[Any]]
891 retval
["lhs"] = self
.lhs
892 retval
["rhs"] = self
.rhs
893 retval
["CA_in"] = self
.CA_in
894 if self
.vl
is not None:
895 retval
["vl"] = self
.vl
899 # type: () -> dict[str, SSAVal]
900 return {"out": self
.out
, "CA_out": self
.CA_out
}
902 def __init__(self
, fn
, lhs
, rhs
, CA_in
, is_sub
, vl
=None):
903 # type: (Fn, SSAGPRRange, SSAGPRRange, SSAVal[CAType], bool, SSAKnownVL | None) -> None
906 raise TypeError(f
"source types must match: "
907 f
"{lhs} doesn't match {rhs}")
908 self
.out
= SSAVal(self
, "out", lhs
.ty
)
912 self
.CA_out
= SSAVal(self
, "CA_out", CA_in
.ty
)
915 assert_vl_is(vl
, lhs
.ty
.length
)
917 def get_extra_interferences(self
):
918 # type: () -> Iterable[tuple[SSAVal, SSAVal]]
919 yield self
.out
, self
.lhs
920 yield self
.out
, self
.rhs
922 def get_asm_lines(self
, ctx
):
923 # type: (AsmContext) -> list[str]
924 vec
= self
.out
.ty
.length
!= 1
925 out
= ctx
.gpr(self
.out
, vec
=vec
)
926 RA
= ctx
.gpr(self
.lhs
, vec
=vec
)
927 RB
= ctx
.gpr(self
.rhs
, vec
=vec
)
931 RA
, RB
= RB
, RA
# reorder to match subfe
932 if ctx
.needs_sv(self
.out
, self
.lhs
, self
.rhs
):
933 return [f
"sv.{mnemonic} {out}, {RA}, {RB}"]
934 return [f
"{mnemonic} {out}, {RA}, {RB}"]
936 def pre_ra_sim(self
, state
):
937 # type: (PreRASimState) -> None
938 carry
= state
.CAs
[self
.CA_in
]
939 out
= [] # type: list[int]
940 for l
, r
in zip(state
.gprs
[self
.lhs
], state
.gprs
[self
.rhs
]):
942 r
= r ^ GPR_VALUE_MASK
944 carry
= s
!= (s
& GPR_VALUE_MASK
)
945 out
.append(s
& GPR_VALUE_MASK
)
946 state
.CAs
[self
.CA_out
] = carry
947 state
.gprs
[self
.out
] = tuple(out
)
950 @plain_data(unsafe_hash
=True, frozen
=True, repr=False)
952 class OpBigIntMulDiv(Op
):
953 __slots__
= "RT", "RA", "RB", "RC", "RS", "is_div", "vl"
956 # type: () -> dict[str, SSAVal]
957 retval
= {} # type: dict[str, SSAVal[Any]]
958 retval
["RA"] = self
.RA
959 retval
["RB"] = self
.RB
960 retval
["RC"] = self
.RC
961 if self
.vl
is not None:
962 retval
["vl"] = self
.vl
966 # type: () -> dict[str, SSAVal]
967 return {"RT": self
.RT
, "RS": self
.RS
}
969 def __init__(self
, fn
, RA
, RB
, RC
, is_div
, vl
):
970 # type: (Fn, SSAGPRRange, SSAGPR, SSAGPR, bool, SSAKnownVL | None) -> None
972 self
.RT
= SSAVal(self
, "RT", RA
.ty
)
976 self
.RS
= SSAVal(self
, "RS", RC
.ty
)
979 assert_vl_is(vl
, RA
.ty
.length
)
981 def get_equality_constraints(self
):
982 # type: () -> Iterable[EqualityConstraint]
983 yield EqualityConstraint([self
.RC
], [self
.RS
])
985 def get_extra_interferences(self
):
986 # type: () -> Iterable[tuple[SSAVal, SSAVal]]
987 yield self
.RT
, self
.RA
988 yield self
.RT
, self
.RB
989 yield self
.RT
, self
.RC
990 yield self
.RT
, self
.RS
991 yield self
.RS
, self
.RA
992 yield self
.RS
, self
.RB
994 def get_asm_lines(self
, ctx
):
995 # type: (AsmContext) -> list[str]
996 vec
= self
.RT
.ty
.length
!= 1
997 RT
= ctx
.gpr(self
.RT
, vec
=vec
)
998 RA
= ctx
.gpr(self
.RA
, vec
=vec
)
999 RB
= ctx
.sgpr(self
.RB
)
1000 RC
= ctx
.sgpr(self
.RC
)
1001 mnemonic
= "maddedu"
1003 mnemonic
= "divmod2du/mrr"
1004 return [f
"sv.{mnemonic} {RT}, {RA}, {RB}, {RC}"]
1006 def pre_ra_sim(self
, state
):
1007 # type: (PreRASimState) -> None
1008 carry
= state
.gprs
[self
.RC
][0]
1009 RA
= state
.gprs
[self
.RA
]
1010 RB
= state
.gprs
[self
.RB
][0]
1011 RT
= [0] * self
.RT
.ty
.length
1013 for i
in reversed(range(self
.RT
.ty
.length
)):
1014 if carry
< RB
and RB
!= 0:
1015 div
, mod
= divmod((carry
<< 64) | RA
[i
], RB
)
1016 RT
[i
] = div
& GPR_VALUE_MASK
1017 carry
= mod
& GPR_VALUE_MASK
1019 RT
[i
] = GPR_VALUE_MASK
1022 for i
in range(self
.RT
.ty
.length
):
1023 v
= RA
[i
] * RB
+ carry
1025 RT
[i
] = v
& GPR_VALUE_MASK
1026 state
.gprs
[self
.RS
] = carry
,
1027 state
.gprs
[self
.RT
] = tuple(RT
)
1032 class ShiftKind(Enum
):
1037 def make_big_int_carry_in(self
, fn
, inp
):
1038 # type: (Fn, SSAGPRRange) -> tuple[SSAGPR, list[Op]]
1039 if self
is ShiftKind
.Sl
or self
is ShiftKind
.Sr
:
1043 assert self
is ShiftKind
.Sra
1044 split
= OpSplit(fn
, inp
, [inp
.ty
.length
- 1])
1045 shr
= OpShiftImm(fn
, split
.results
[1], sh
=63, kind
=ShiftKind
.Sra
)
1046 return shr
.out
, [split
, shr
]
1048 def make_big_int_shift(self
, fn
, inp
, sh
, vl
):
1049 # type: (Fn, SSAGPRRange, SSAGPR, SSAKnownVL | None) -> tuple[SSAGPRRange, list[Op]]
1050 carry_in
, ops
= self
.make_big_int_carry_in(fn
, inp
)
1051 big_int_shift
= OpBigIntShift(fn
, inp
, sh
, carry_in
, kind
=self
, vl
=vl
)
1052 ops
.append(big_int_shift
)
1053 return big_int_shift
.out
, ops
1056 @plain_data(unsafe_hash
=True, frozen
=True, repr=False)
1058 class OpBigIntShift(Op
):
1059 __slots__
= "out", "inp", "carry_in", "_out_padding", "sh", "kind", "vl"
1062 # type: () -> dict[str, SSAVal]
1063 retval
= {} # type: dict[str, SSAVal[Any]]
1064 retval
["inp"] = self
.inp
1065 retval
["sh"] = self
.sh
1066 retval
["carry_in"] = self
.carry_in
1067 if self
.vl
is not None:
1068 retval
["vl"] = self
.vl
1072 # type: () -> dict[str, SSAVal]
1073 return {"out": self
.out
, "_out_padding": self
._out
_padding
}
1075 def __init__(self
, fn
, inp
, sh
, carry_in
, kind
, vl
=None):
1076 # type: (Fn, SSAGPRRange, SSAGPR, SSAGPR, ShiftKind, SSAKnownVL | None) -> None
1077 super().__init
__(fn
)
1078 self
.out
= SSAVal(self
, "out", inp
.ty
)
1079 self
._out
_padding
= SSAVal(self
, "_out_padding", GPRRangeType())
1080 self
.carry_in
= carry_in
1085 assert_vl_is(vl
, inp
.ty
.length
)
1087 def get_extra_interferences(self
):
1088 # type: () -> Iterable[tuple[SSAVal, SSAVal]]
1089 yield self
.out
, self
.sh
1091 def get_equality_constraints(self
):
1092 # type: () -> Iterable[EqualityConstraint]
1093 if self
.kind
is ShiftKind
.Sl
:
1094 yield EqualityConstraint([self
.carry_in
, self
.inp
],
1095 [self
.out
, self
._out
_padding
])
1097 assert self
.kind
is ShiftKind
.Sr
or self
.kind
is ShiftKind
.Sra
1098 yield EqualityConstraint([self
.inp
, self
.carry_in
],
1099 [self
._out
_padding
, self
.out
])
1101 def get_asm_lines(self
, ctx
):
1102 # type: (AsmContext) -> list[str]
1103 vec
= self
.out
.ty
.length
!= 1
1104 if self
.kind
is ShiftKind
.Sl
:
1105 RT
= ctx
.gpr(self
.out
, vec
=vec
)
1106 RA
= ctx
.gpr(self
.out
, vec
=vec
, offset
=-1)
1107 RB
= ctx
.sgpr(self
.sh
)
1108 mrr
= "/mrr" if vec
else ""
1109 return [f
"sv.dsld{mrr} {RT}, {RA}, {RB}, 0"]
1111 assert self
.kind
is ShiftKind
.Sr
or self
.kind
is ShiftKind
.Sra
1112 RT
= ctx
.gpr(self
.out
, vec
=vec
)
1113 RA
= ctx
.gpr(self
.out
, vec
=vec
, offset
=1)
1114 RB
= ctx
.sgpr(self
.sh
)
1115 return [f
"sv.dsrd {RT}, {RA}, {RB}, 1"]
1117 def pre_ra_sim(self
, state
):
1118 # type: (PreRASimState) -> None
1119 out
= [0] * self
.out
.ty
.length
1120 carry
= state
.gprs
[self
.carry_in
][0]
1121 sh
= state
.gprs
[self
.sh
][0] % 64
1122 if self
.kind
is ShiftKind
.Sl
:
1123 inp
= carry
, *state
.gprs
[self
.inp
]
1124 for i
in reversed(range(self
.out
.ty
.length
)):
1125 v
= inp
[i
] |
(inp
[i
+ 1] << 64)
1127 out
[i
] = (v
>> 64) & GPR_VALUE_MASK
1129 assert self
.kind
is ShiftKind
.Sr
or self
.kind
is ShiftKind
.Sra
1130 inp
= *state
.gprs
[self
.inp
], carry
1131 for i
in range(self
.out
.ty
.length
):
1132 v
= inp
[i
] |
(inp
[i
+ 1] << 64)
1134 out
[i
] = v
& GPR_VALUE_MASK
1135 # state.gprs[self._out_padding] is intentionally not written
1136 state
.gprs
[self
.out
] = tuple(out
)
1139 @plain_data(unsafe_hash
=True, frozen
=True, repr=False)
1141 class OpShiftImm(Op
):
1142 __slots__
= "out", "inp", "sh", "kind", "ca_out"
1145 # type: () -> dict[str, SSAVal]
1146 return {"inp": self
.inp
}
1149 # type: () -> dict[str, SSAVal]
1150 if self
.ca_out
is not None:
1151 return {"out": self
.out
, "ca_out": self
.ca_out
}
1152 return {"out": self
.out
}
1154 def __init__(self
, fn
, inp
, sh
, kind
):
1155 # type: (Fn, SSAGPR, int, ShiftKind) -> None
1156 super().__init
__(fn
)
1157 self
.out
= SSAVal(self
, "out", inp
.ty
)
1159 if not (0 <= sh
< 64):
1160 raise ValueError("shift amount out of range")
1163 if self
.kind
is ShiftKind
.Sra
:
1164 self
.ca_out
= SSAVal(self
, "ca_out", CAType())
1168 def get_asm_lines(self
, ctx
):
1169 # type: (AsmContext) -> list[str]
1170 out
= ctx
.sgpr(self
.out
)
1171 inp
= ctx
.sgpr(self
.inp
)
1172 if self
.kind
is ShiftKind
.Sl
:
1174 args
= f
"{self.sh}, {63 - self.sh}"
1175 elif self
.kind
is ShiftKind
.Sr
:
1177 v
= (64 - self
.sh
) % 64
1178 args
= f
"{v}, {self.sh}"
1180 assert self
.kind
is ShiftKind
.Sra
1183 if ctx
.needs_sv(self
.out
, self
.inp
):
1184 return [f
"sv.{mnemonic} {out}, {inp}, {args}"]
1185 return [f
"{mnemonic} {out}, {inp}, {args}"]
1187 def pre_ra_sim(self
, state
):
1188 # type: (PreRASimState) -> None
1189 inp
= state
.gprs
[self
.inp
][0]
1190 if self
.kind
is ShiftKind
.Sl
:
1191 assert self
.ca_out
is None
1192 out
= inp
<< self
.sh
1193 elif self
.kind
is ShiftKind
.Sr
:
1194 assert self
.ca_out
is None
1195 out
= inp
>> self
.sh
1197 assert self
.kind
is ShiftKind
.Sra
1198 assert self
.ca_out
is not None
1199 if inp
& (1 << 63): # sign extend
1201 out
= inp
>> self
.sh
1202 ca
= inp
< 0 and (out
<< self
.sh
) != inp
1203 state
.CAs
[self
.ca_out
] = ca
1204 state
.gprs
[self
.out
] = out
,
1207 @plain_data(unsafe_hash
=True, frozen
=True, repr=False)
1210 __slots__
= "out", "value", "vl"
1213 # type: () -> dict[str, SSAVal]
1214 retval
= {} # type: dict[str, SSAVal[Any]]
1215 if self
.vl
is not None:
1216 retval
["vl"] = self
.vl
1220 # type: () -> dict[str, SSAVal]
1221 return {"out": self
.out
}
1223 def __init__(self
, fn
, value
, vl
=None):
1224 # type: (Fn, int, SSAKnownVL | None) -> None
1225 super().__init
__(fn
)
1229 length
= vl
.ty
.length
1230 self
.out
= SSAVal(self
, "out", GPRRangeType(length
))
1231 if not (-1 << 15 <= value
<= (1 << 15) - 1):
1232 raise ValueError(f
"value out of range: {value}")
1235 assert_vl_is(vl
, length
)
1237 def get_asm_lines(self
, ctx
):
1238 # type: (AsmContext) -> list[str]
1239 vec
= self
.out
.ty
.length
!= 1
1240 out
= ctx
.gpr(self
.out
, vec
=vec
)
1241 if ctx
.needs_sv(self
.out
):
1242 return [f
"sv.addi {out}, 0, {self.value}"]
1243 return [f
"addi {out}, 0, {self.value}"]
1245 def pre_ra_sim(self
, state
):
1246 # type: (PreRASimState) -> None
1247 value
= self
.value
& GPR_VALUE_MASK
1248 state
.gprs
[self
.out
] = (value
,) * self
.out
.ty
.length
1251 @plain_data(unsafe_hash
=True, frozen
=True, repr=False)
1254 __slots__
= "out", "value"
1257 # type: () -> dict[str, SSAVal]
1261 # type: () -> dict[str, SSAVal]
1262 return {"out": self
.out
}
1264 def __init__(self
, fn
, value
):
1265 # type: (Fn, bool) -> None
1266 super().__init
__(fn
)
1267 self
.out
= SSAVal(self
, "out", CAType())
1270 def get_asm_lines(self
, ctx
):
1271 # type: (AsmContext) -> list[str]
1273 return ["subfic 0, 0, -1"]
1274 return ["addic 0, 0, 0"]
1276 def pre_ra_sim(self
, state
):
1277 # type: (PreRASimState) -> None
1278 state
.CAs
[self
.out
] = self
.value
1281 @plain_data(unsafe_hash
=True, frozen
=True, repr=False)
1284 __slots__
= "RT", "RA", "offset", "mem", "vl"
1287 # type: () -> dict[str, SSAVal]
1288 retval
= {} # type: dict[str, SSAVal[Any]]
1289 retval
["RA"] = self
.RA
1290 retval
["mem"] = self
.mem
1291 if self
.vl
is not None:
1292 retval
["vl"] = self
.vl
1296 # type: () -> dict[str, SSAVal]
1297 return {"RT": self
.RT
}
1299 def __init__(self
, fn
, RA
, offset
, mem
, vl
=None):
1300 # type: (Fn, SSAGPR, int, SSAVal[GlobalMemType], SSAKnownVL | None) -> None
1301 super().__init
__(fn
)
1305 length
= vl
.ty
.length
1306 self
.RT
= SSAVal(self
, "RT", GPRRangeType(length
))
1308 if not (-1 << 15 <= offset
<= (1 << 15) - 1):
1309 raise ValueError(f
"offset out of range: {offset}")
1311 raise ValueError(f
"offset not aligned: {offset}")
1312 self
.offset
= offset
1315 assert_vl_is(vl
, length
)
1317 def get_extra_interferences(self
):
1318 # type: () -> Iterable[tuple[SSAVal, SSAVal]]
1319 if self
.RT
.ty
.length
> 1:
1320 yield self
.RT
, self
.RA
1322 def get_asm_lines(self
, ctx
):
1323 # type: (AsmContext) -> list[str]
1324 RT
= ctx
.gpr(self
.RT
, vec
=self
.RT
.ty
.length
!= 1)
1325 RA
= ctx
.sgpr(self
.RA
)
1326 if ctx
.needs_sv(self
.RT
, self
.RA
):
1327 return [f
"sv.ld {RT}, {self.offset}({RA})"]
1328 return [f
"ld {RT}, {self.offset}({RA})"]
1330 def pre_ra_sim(self
, state
):
1331 # type: (PreRASimState) -> None
1332 addr
= state
.gprs
[self
.RA
][0]
1334 RT
= [0] * self
.RT
.ty
.length
1335 mem
= state
.global_mems
[self
.mem
]
1336 for i
in range(self
.RT
.ty
.length
):
1337 cur_addr
= (addr
+ i
* GPR_SIZE_IN_BYTES
) & GPR_VALUE_MASK
1338 if cur_addr
% GPR_SIZE_IN_BYTES
!= 0:
1339 raise ValueError(f
"can't load from unaligned address: "
1341 for j
in range(GPR_SIZE_IN_BYTES
):
1342 byte_val
= mem
.get(cur_addr
+ j
, 0) & 0xFF
1343 RT
[i
] |
= byte_val
<< (j
* 8)
1344 state
.gprs
[self
.RT
] = tuple(RT
)
1347 @plain_data(unsafe_hash
=True, frozen
=True, repr=False)
1350 __slots__
= "RS", "RA", "offset", "mem_in", "mem_out", "vl"
1353 # type: () -> dict[str, SSAVal]
1354 retval
= {} # type: dict[str, SSAVal[Any]]
1355 retval
["RS"] = self
.RS
1356 retval
["RA"] = self
.RA
1357 retval
["mem_in"] = self
.mem_in
1358 if self
.vl
is not None:
1359 retval
["vl"] = self
.vl
1363 # type: () -> dict[str, SSAVal]
1364 return {"mem_out": self
.mem_out
}
1366 def __init__(self
, fn
, RS
, RA
, offset
, mem_in
, vl
=None):
1367 # type: (Fn, SSAGPRRange, SSAGPR, int, SSAVal[GlobalMemType], SSAKnownVL | None) -> None
1368 super().__init
__(fn
)
1371 if not (-1 << 15 <= offset
<= (1 << 15) - 1):
1372 raise ValueError(f
"offset out of range: {offset}")
1374 raise ValueError(f
"offset not aligned: {offset}")
1375 self
.offset
= offset
1376 self
.mem_in
= mem_in
1377 self
.mem_out
= SSAVal(self
, "mem_out", mem_in
.ty
)
1379 assert_vl_is(vl
, RS
.ty
.length
)
1381 def get_asm_lines(self
, ctx
):
1382 # type: (AsmContext) -> list[str]
1383 RS
= ctx
.gpr(self
.RS
, vec
=self
.RS
.ty
.length
!= 1)
1384 RA
= ctx
.sgpr(self
.RA
)
1385 if ctx
.needs_sv(self
.RS
, self
.RA
):
1386 return [f
"sv.std {RS}, {self.offset}({RA})"]
1387 return [f
"std {RS}, {self.offset}({RA})"]
1389 def pre_ra_sim(self
, state
):
1390 # type: (PreRASimState) -> None
1391 mem
= dict(state
.global_mems
[self
.mem_in
])
1392 addr
= state
.gprs
[self
.RA
][0]
1394 RS
= state
.gprs
[self
.RS
]
1395 for i
in range(self
.RS
.ty
.length
):
1396 cur_addr
= (addr
+ i
* GPR_SIZE_IN_BYTES
) & GPR_VALUE_MASK
1397 if cur_addr
% GPR_SIZE_IN_BYTES
!= 0:
1398 raise ValueError(f
"can't store to unaligned address: "
1400 for j
in range(GPR_SIZE_IN_BYTES
):
1401 mem
[cur_addr
+ j
] = (RS
[i
] >> (j
* 8)) & 0xFF
1402 state
.global_mems
[self
.mem_out
] = FMap(mem
)
1405 @plain_data(unsafe_hash
=True, frozen
=True, repr=False)
1407 class OpFuncArg(Op
):
1411 # type: () -> dict[str, SSAVal]
1415 # type: () -> dict[str, SSAVal]
1416 return {"out": self
.out
}
1418 def __init__(self
, fn
, ty
):
1419 # type: (Fn, FixedGPRRangeType) -> None
1420 super().__init
__(fn
)
1421 self
.out
= SSAVal(self
, "out", ty
)
1423 def get_asm_lines(self
, ctx
):
1424 # type: (AsmContext) -> list[str]
1427 def pre_ra_sim(self
, state
):
1428 # type: (PreRASimState) -> None
1429 if self
.out
not in state
.fixed_gprs
:
1430 state
.fixed_gprs
[self
.out
] = (0,) * self
.out
.ty
.length
1433 @plain_data(unsafe_hash
=True, frozen
=True, repr=False)
1435 class OpInputMem(Op
):
1439 # type: () -> dict[str, SSAVal]
1443 # type: () -> dict[str, SSAVal]
1444 return {"out": self
.out
}
1446 def __init__(self
, fn
):
1447 # type: (Fn) -> None
1448 super().__init
__(fn
)
1449 self
.out
= SSAVal(self
, "out", GlobalMemType())
1451 def get_asm_lines(self
, ctx
):
1452 # type: (AsmContext) -> list[str]
1455 def pre_ra_sim(self
, state
):
1456 # type: (PreRASimState) -> None
1457 if self
.out
not in state
.global_mems
:
1458 state
.global_mems
[self
.out
] = FMap()
1461 @plain_data(unsafe_hash
=True, frozen
=True, repr=False)
1463 class OpSetVLImm(Op
):
1467 # type: () -> dict[str, SSAVal]
1471 # type: () -> dict[str, SSAVal]
1472 return {"out": self
.out
}
1474 def __init__(self
, fn
, length
):
1475 # type: (Fn, int) -> None
1476 super().__init
__(fn
)
1477 self
.out
= SSAVal(self
, "out", KnownVLType(length
))
1479 def get_asm_lines(self
, ctx
):
1480 # type: (AsmContext) -> list[str]
1481 return [f
"setvl 0, 0, {self.out.ty.length}, 0, 1, 1"]
1483 def pre_ra_sim(self
, state
):
1484 # type: (PreRASimState) -> None
1485 state
.VLs
[self
.out
] = self
.out
.ty
.length
1488 def op_set_to_list(ops
):
1489 # type: (Iterable[Op]) -> list[Op]
1490 worklists
= [{}] # type: list[dict[Op, None]]
1491 inps_to_ops_map
= defaultdict(dict) # type: dict[SSAVal, dict[Op, None]]
1492 ops_to_pending_input_count_map
= {} # type: dict[Op, int]
1495 for val
in op
.inputs().values():
1497 inps_to_ops_map
[val
][op
] = None
1498 while len(worklists
) <= input_count
:
1499 worklists
.append({})
1500 ops_to_pending_input_count_map
[op
] = input_count
1501 worklists
[input_count
][op
] = None
1502 retval
= [] # type: list[Op]
1503 ready_vals
= OSet() # type: OSet[SSAVal]
1504 while len(worklists
[0]) != 0:
1505 writing_op
= next(iter(worklists
[0]))
1506 del worklists
[0][writing_op
]
1507 retval
.append(writing_op
)
1508 for val
in writing_op
.outputs().values():
1509 if val
in ready_vals
:
1510 raise ValueError(f
"multiple instructions must not write "
1511 f
"to the same SSA value: {val}")
1513 for reading_op
in inps_to_ops_map
[val
]:
1514 pending
= ops_to_pending_input_count_map
[reading_op
]
1515 del worklists
[pending
][reading_op
]
1517 worklists
[pending
][reading_op
] = None
1518 ops_to_pending_input_count_map
[reading_op
] = pending
1519 for worklist
in worklists
:
1521 raise ValueError(f
"instruction is part of a dependency loop or "
1522 f
"its inputs are never written: {op}")
1526 def generate_assembly(ops
, assigned_registers
=None):
1527 # type: (list[Op], dict[SSAVal, RegLoc] | None) -> list[str]
1528 if assigned_registers
is None:
1529 from bigint_presentation_code
.register_allocator
import \
1531 assigned_registers
= allocate_registers(ops
)
1532 ctx
= AsmContext(assigned_registers
)
1533 retval
= [] # list[str]
1535 retval
.extend(op
.get_asm_lines(ctx
))
1536 retval
.append("bclr 20, 0, 0")