2 Compiler IR for Toom-Cook algorithm generator for SVP64
4 This assumes VL != 0 throughout.
7 from abc
import ABCMeta
, abstractmethod
8 from collections
import defaultdict
9 from enum
import Enum
, EnumMeta
, unique
10 from functools
import lru_cache
11 from typing
import Any
, Generic
, Iterable
, Sequence
, Type
, TypeVar
, cast
13 from nmutil
.plain_data
import fields
, plain_data
15 from bigint_presentation_code
.util
import FMap
, OFSet
, OSet
, final
18 class ABCEnumMeta(EnumMeta
, ABCMeta
):
22 class RegLoc(metaclass
=ABCMeta
):
26 def conflicts(self
, other
):
27 # type: (RegLoc) -> bool
30 def get_subreg_at_offset(self
, subreg_type
, offset
):
31 # type: (RegType, int) -> RegLoc
32 if self
not in subreg_type
.reg_class
:
33 raise ValueError(f
"register not a member of subreg_type: "
34 f
"reg={self} subreg_type={subreg_type}")
36 raise ValueError(f
"non-zero sub-register offset not supported "
37 f
"for register: {self}")
44 @plain_data(frozen
=True, unsafe_hash
=True, repr=False)
46 class GPRRange(RegLoc
, Sequence
["GPRRange"]):
47 __slots__
= "start", "length"
49 def __init__(self
, start
, length
=None):
50 # type: (int | range, int | None) -> None
51 if isinstance(start
, range):
52 if length
is not None:
53 raise TypeError("can't specify length when input is a range")
55 raise ValueError("range must have a step of 1")
60 if length
<= 0 or start
< 0 or start
+ length
> GPR_COUNT
:
61 raise ValueError("invalid GPRRange")
67 return self
.start
+ self
.length
75 return range(self
.start
, self
.stop
, self
.step
)
def __getitem__(self, item):
    # type: (int | slice) -> GPRRange
    """Index or slice this range of registers.

    The index/slice is applied to ``self.range`` (so negative indices work
    as for ``range``) and the result is wrapped back into a GPRRange. An int
    index gives a GPRRange built from that single register number
    (presumably of length 1 -- the int-start default of GPRRange's
    constructor isn't visible here). A slice gives the selected sub-range;
    GPRRange's constructor raises ValueError if its step isn't 1.
    """
    selected = self.range[item]
    return GPRRange(selected)
def __contains__(self, value):
    # type: (object) -> bool
    """Return True if `value` is a GPRRange lying entirely inside `self`.

    Non-GPRRange objects (e.g. other RegLoc kinds such as XERBit) are never
    contained. Previously they raised AttributeError when ``.start`` was
    read.
    """
    if not isinstance(value, GPRRange):
        return False
    return value.start >= self.start and value.stop <= self.stop
def index(self, sub, start=None, end=None):
    # type: (GPRRange, int | None, int | None) -> int
    """Return the offset of `sub` relative to ``self.start``.

    `start`/`end` narrow the search window to ``self.range[start:end]``.
    Raises ValueError if `sub` does not lie entirely inside that window.
    """
    window = self.range[start:end]
    if window.start <= sub.start and sub.stop <= window.stop:
        return sub.start - self.start
    raise ValueError("GPR range not found")
95 def count(self
, sub
, start
=None, end
=None):
96 # type: (GPRRange, int | None, int | None) -> int
97 r
= self
.range[start
:end
]
100 return int(sub
in GPRRange(r
))
102 def conflicts(self
, other
):
103 # type: (RegLoc) -> bool
104 if isinstance(other
, GPRRange
):
105 return self
.stop
> other
.start
and other
.stop
> self
.start
def get_subreg_at_offset(self, subreg_type, offset):
    # type: (RegType, int) -> GPRRange
    """Return the sub-range of type `subreg_type` that starts `offset`
    registers after ``self.start``.

    Raises ValueError if `subreg_type` is not a GPRRangeType or
    FixedGPRRangeType, or if the requested sub-range does not fit inside
    `self`.
    """
    if not isinstance(subreg_type, (GPRRangeType, FixedGPRRangeType)):
        raise ValueError(f"subreg_type is not a FixedGPRRangeType or "
                         f"GPRRangeType: {subreg_type}")
    # `offset` is relative to self.start, so the end of the sub-range must
    # be bounded by self.length. Comparing against self.stop (an absolute
    # register number) would accept sub-ranges that run past the end of
    # `self`.
    if offset < 0 or offset + subreg_type.length > self.length:
        raise ValueError(f"sub-register offset is out of range: {offset}")
    return GPRRange(self.start + offset, subreg_type.length)
119 return f
"<r{self.start}>"
120 return f
"<r{self.start}..len={self.length}>"
# Single registers r0, r1, r2 and r13. GPRRangeType's register-class
# builder checks every candidate range against these (presumably to leave
# them out of allocation -- the handling of a match isn't visible here).
# r1 is the base register used by the stack-slot ld/std instructions.
SPECIAL_GPRS = tuple(GPRRange(reg_num) for reg_num in (0, 1, 2, 13))
128 class XERBit(RegLoc
, Enum
, metaclass
=ABCEnumMeta
):
131 def conflicts(self
, other
):
132 # type: (RegLoc) -> bool
133 if isinstance(other
, XERBit
):
140 class GlobalMem(RegLoc
, Enum
, metaclass
=ABCEnumMeta
):
141 """singleton representing all non-StackSlot memory -- treated as a single
142 physical register for register allocation purposes.
144 GlobalMem
= "GlobalMem"
146 def conflicts(self
, other
):
147 # type: (RegLoc) -> bool
148 if isinstance(other
, GlobalMem
):
155 class VL(RegLoc
, Enum
, metaclass
=ABCEnumMeta
):
156 VL_MAXVL
= "VL_MAXVL"
159 def conflicts(self
, other
):
160 # type: (RegLoc) -> bool
161 if isinstance(other
, VL
):
167 class RegClass(OFSet
[RegLoc
]):
168 """ an ordered set of registers.
169 earlier registers are preferred by the register allocator.
172 @lru_cache(maxsize
=None, typed
=True)
173 def max_conflicts_with(self
, other
):
174 # type: (RegClass | RegLoc) -> int
175 """the largest number of registers in `self` that a single register
176 from `other` can conflict with
178 if isinstance(other
, RegClass
):
179 return max(self
.max_conflicts_with(i
) for i
in other
)
181 return sum(other
.conflicts(i
) for i
in self
)
184 @plain_data(frozen
=True, unsafe_hash
=True)
185 class RegType(metaclass
=ABCMeta
):
191 # type: () -> RegClass
195 _RegType
= TypeVar("_RegType", bound
=RegType
)
196 _RegLoc
= TypeVar("_RegLoc", bound
=RegLoc
)
199 @plain_data(frozen
=True, eq
=False, repr=False)
201 class GPRRangeType(RegType
):
202 __slots__
= "length",
204 def __init__(self
, length
=1):
205 # type: (int) -> None
206 if length
< 1 or length
> GPR_COUNT
:
207 raise ValueError("invalid length")
211 @lru_cache(maxsize
=None)
212 def __get_reg_class(length
):
213 # type: (int) -> RegClass
215 for start
in range(GPR_COUNT
- length
):
216 reg
= GPRRange(start
, length
)
217 if any(i
in reg
for i
in SPECIAL_GPRS
):
220 return RegClass(regs
)
225 # type: () -> RegClass
226 return GPRRangeType
.__get
_reg
_class
(self
.length
)
229 def __eq__(self
, other
):
230 if isinstance(other
, GPRRangeType
):
231 return self
.length
== other
.length
236 return hash(self
.length
)
239 return f
"<gpr_ty[{self.length}]>"
242 GPRType
= GPRRangeType
243 """a length=1 GPRRangeType"""
246 @plain_data(frozen
=True, unsafe_hash
=True, repr=False)
248 class FixedGPRRangeType(RegType
):
251 def __init__(self
, reg
):
252 # type: (GPRRange) -> None
257 # type: () -> RegClass
258 return RegClass([self
.reg
])
263 return self
.reg
.length
266 return f
"<fixed({self.reg})>"
269 @plain_data(frozen
=True, unsafe_hash
=True)
271 class CAType(RegType
):
276 # type: () -> RegClass
277 return RegClass([XERBit
.CA
])
280 @plain_data(frozen
=True, unsafe_hash
=True)
282 class GlobalMemType(RegType
):
287 # type: () -> RegClass
288 return RegClass([GlobalMem
.GlobalMem
])
291 @plain_data(frozen
=True, unsafe_hash
=True)
293 class KnownVLType(RegType
):
294 __slots__
= "length",
296 def __init__(self
, length
):
297 # type: (int) -> None
298 if not (0 < length
<= 64):
299 raise ValueError("invalid VL value")
304 # type: () -> RegClass
305 return RegClass([VL
.VL_MAXVL
])
308 def assert_vl_is(vl
, expected_vl
):
309 # type: (SSAKnownVL | KnownVLType | int | None, int) -> None
312 elif isinstance(vl
, SSAVal
):
314 elif isinstance(vl
, KnownVLType
):
316 if vl
!= expected_vl
:
318 f
"wrong VL: expected {expected_vl} got {vl}")
324 @plain_data(frozen
=True, unsafe_hash
=True)
326 class StackSlot(RegLoc
):
327 __slots__
= "start_slot", "length_in_slots",
def __init__(self, start_slot, length_in_slots):
    # type: (int, int) -> None
    """A run of `length_in_slots` consecutive stack slots beginning at
    slot index `start_slot`.

    Raises ValueError if `length_in_slots` is less than 1.
    """
    self.start_slot = start_slot
    # a zero-length (or negative-length) run of slots is meaningless
    if length_in_slots < 1:
        raise ValueError("invalid length_in_slots")
    self.length_in_slots = length_in_slots
338 return self
.start_slot
+ self
.length_in_slots
341 def start_byte(self
):
342 return self
.start_slot
* STACK_SLOT_SIZE
344 def conflicts(self
, other
):
345 # type: (RegLoc) -> bool
346 if isinstance(other
, StackSlot
):
347 return (self
.stop_slot
> other
.start_slot
348 and other
.stop_slot
> self
.start_slot
)
def get_subreg_at_offset(self, subreg_type, offset):
    # type: (RegType, int) -> StackSlot
    """Return the sub-run of slots of type `subreg_type` starting `offset`
    slots after ``self.start_slot``.

    Raises ValueError if `subreg_type` is not a StackSlotType, or if the
    requested sub-run does not fit inside `self`.
    """
    if not isinstance(subreg_type, StackSlotType):
        raise ValueError(f"subreg_type is not a "
                         f"StackSlotType: {subreg_type}")
    # `offset` is relative to self.start_slot, so the end of the sub-run
    # must be bounded by self.length_in_slots. Comparing against
    # self.stop_slot (an absolute slot index) would accept sub-runs that
    # extend past the end of `self`.
    if offset < 0 or offset + subreg_type.length_in_slots > self.length_in_slots:
        raise ValueError(f"sub-register offset is out of range: {offset}")
    return StackSlot(self.start_slot + offset, subreg_type.length_in_slots)
# Total number of stack slots. StackSlotType's register-class builder only
# generates StackSlots whose start index is derived from this count.
STACK_SLOT_COUNT = 128
364 @plain_data(frozen
=True, eq
=False)
366 class StackSlotType(RegType
):
367 __slots__
= "length_in_slots",
def __init__(self, length_in_slots=1):
    # type: (int) -> None
    """Register type for a run of `length_in_slots` consecutive stack slots.

    Raises ValueError if `length_in_slots` is less than 1.
    """
    if length_in_slots < 1:
        raise ValueError("invalid length_in_slots")
    self.length_in_slots = length_in_slots
376 @lru_cache(maxsize
=None)
377 def __get_reg_class(length_in_slots
):
378 # type: (int) -> RegClass
380 for start
in range(STACK_SLOT_COUNT
- length_in_slots
):
381 reg
= StackSlot(start
, length_in_slots
)
383 return RegClass(regs
)
387 # type: () -> RegClass
388 return StackSlotType
.__get
_reg
_class
(self
.length_in_slots
)
391 def __eq__(self
, other
):
392 if isinstance(other
, StackSlotType
):
393 return self
.length_in_slots
== other
.length_in_slots
398 return hash(self
.length_in_slots
)
401 @plain_data(frozen
=True, eq
=False, repr=False)
403 class SSAVal(Generic
[_RegType
]):
404 __slots__
= "op", "arg_name", "ty",
406 def __init__(self
, op
, arg_name
, ty
):
407 # type: (Op, str, _RegType) -> None
409 """the Op that writes this SSAVal"""
411 self
.arg_name
= arg_name
412 """the name of the argument of self.op that writes this SSAVal"""
416 def __eq__(self
, rhs
):
417 if isinstance(rhs
, SSAVal
):
418 return (self
.op
is rhs
.op
419 and self
.arg_name
== rhs
.arg_name
)
423 return hash((id(self
.op
), self
.arg_name
))
426 return f
"<#{self.op.id}.{self.arg_name}: {self.ty}>"
429 SSAGPRRange
= SSAVal
[GPRRangeType
]
430 SSAGPR
= SSAVal
[GPRType
]
431 SSAKnownVL
= SSAVal
[KnownVLType
]
435 @plain_data(unsafe_hash
=True, frozen
=True)
436 class EqualityConstraint
:
437 __slots__
= "lhs", "rhs"
439 def __init__(self
, lhs
, rhs
):
440 # type: (list[SSAVal], list[SSAVal]) -> None
443 if len(lhs
) == 0 or len(rhs
) == 0:
444 raise ValueError("can't constrain an empty list to be equal")
453 self
.ops
= [] # type: list[Op]
455 def __repr__(self
, short
=False):
458 ops
= ", ".join(op
.__repr
__(just_id
=True) for op
in self
.ops
)
459 return f
"<Fn([{ops}])>"
461 def pre_ra_sim(self
, state
):
462 # type: (PreRASimState) -> None
468 """ helper for __repr__ for when fields aren't set """
def __init__(self, assigned_registers):
    # type: (dict[SSAVal, RegLoc]) -> None
    """Wrap the register allocator's result: a mapping from each SSAVal to
    the RegLoc assigned to it. It is queried via ``reg()`` and the helpers
    built on it."""
    # double underscore: name-mangled, private to this class
    self.__assigned_registers = assigned_registers
483 def reg(self
, ssa_val
, expected_ty
):
484 # type: (SSAVal[Any], Type[_RegLoc]) -> _RegLoc
486 reg
= self
.__assigned
_registers
[ssa_val
]
487 except KeyError as e
:
488 raise ValueError(f
"SSAVal not assigned a register: {ssa_val}")
489 wrong_len
= (isinstance(reg
, GPRRange
)
490 and reg
.length
!= ssa_val
.ty
.length
)
491 if not isinstance(reg
, expected_ty
) or wrong_len
:
493 f
"SSAVal is assigned a register of the wrong type: "
494 f
"ssa_val={ssa_val} expected_ty={expected_ty} reg={reg}")
def gpr_range(self, ssa_val):
    # type: (SSAGPRRange | SSAVal[FixedGPRRangeType]) -> GPRRange
    """Return the GPRRange assigned to `ssa_val`. ``self.reg`` raises
    ValueError if the assignment is missing or isn't a matching GPRRange."""
    return self.reg(ssa_val, expected_ty=GPRRange)
def stack_slot(self, ssa_val):
    # type: (SSAVal[StackSlotType]) -> StackSlot
    """Return the StackSlot assigned to `ssa_val`. ``self.reg`` raises
    ValueError if the assignment is missing or isn't a StackSlot."""
    return self.reg(ssa_val, expected_ty=StackSlot)
def gpr(self, ssa_val, vec, offset=0):
    # type: (SSAGPRRange | SSAVal[FixedGPRRangeType], bool, int) -> str
    """Format the first register assigned to `ssa_val`, shifted by
    `offset`, as an assembly operand.

    When `vec` is true the operand gets a ``*`` prefix, which is how vector
    operands are written in the emitted ``sv.`` instructions.
    """
    reg_num = self.gpr_range(ssa_val).start + offset
    prefix = "*" * vec
    return prefix + str(reg_num)
def vgpr(self, ssa_val, offset=0):
    # type: (SSAGPRRange | SSAVal[FixedGPRRangeType], int) -> str
    """Vector-operand shorthand: same as ``self.gpr`` with ``vec=True``."""
    return self.gpr(ssa_val, True, offset)
def sgpr(self, ssa_val, offset=0):
    # type: (SSAGPR | SSAVal[FixedGPRRangeType], int) -> str
    """Scalar-operand shorthand: same as ``self.gpr`` with ``vec=False``."""
    return self.gpr(ssa_val, False, offset)
518 def needs_sv(self
, *regs
):
519 # type: (*SSAGPRRange | SSAVal[FixedGPRRangeType]) -> bool
521 reg
= self
.gpr_range(reg
)
522 if reg
.length
!= 1 or reg
.start
>= 32:
# width of one general-purpose register
GPR_SIZE_IN_BYTES = 8
GPR_SIZE_IN_BITS = 8 * GPR_SIZE_IN_BYTES
# all-ones mask; the pre-RA simulators AND values with it to truncate them
# to a single register's width
GPR_VALUE_MASK = 2 ** GPR_SIZE_IN_BITS - 1
532 @plain_data(frozen
=True)
535 __slots__
= ("gprs", "VLs", "CAs",
536 "global_mems", "stack_slots",
541 gprs
, # type: dict[SSAGPRRange, tuple[int, ...]]
542 VLs
, # type: dict[SSAKnownVL, int]
543 CAs
, # type: dict[SSAVal[CAType], bool]
544 global_mems
, # type: dict[SSAVal[GlobalMemType], FMap[int, int]]
545 stack_slots
, # type: dict[SSAVal[StackSlotType], tuple[int, ...]]
546 fixed_gprs
, # type: dict[SSAVal[FixedGPRRangeType], tuple[int, ...]]
548 # type: (...) -> None
552 self
.global_mems
= global_mems
553 self
.stack_slots
= stack_slots
554 self
.fixed_gprs
= fixed_gprs
557 @plain_data(unsafe_hash
=True, frozen
=True, repr=False)
558 class Op(metaclass
=ABCMeta
):
559 __slots__
= "id", "fn"
563 # type: () -> dict[str, SSAVal]
568 # type: () -> dict[str, SSAVal]
571 def get_equality_constraints(self
):
572 # type: () -> Iterable[EqualityConstraint]
576 def get_extra_interferences(self
):
577 # type: () -> Iterable[tuple[SSAVal, SSAVal]]
581 def __init__(self
, fn
):
583 self
.id = len(fn
.ops
)
588 def __repr__(self
, just_id
=False):
589 fields_list
= [f
"#{self.id}"]
592 outputs
= self
.outputs()
593 except AttributeError:
596 for name
in fields(self
):
597 if name
in ("id", "fn"):
599 v
= getattr(self
, name
, _NOT_SET
)
600 if (outputs
is not None and name
in outputs
601 and outputs
[name
] is v
):
602 fields_list
.append(repr(v
))
604 fields_list
.append(f
"{name}={v!r}")
605 fields_str
= ', '.join(fields_list
)
606 return f
"{self.__class__.__name__}({fields_str})"
609 def get_asm_lines(self
, ctx
):
610 # type: (AsmContext) -> list[str]
611 """get the lines of assembly for this Op"""
615 def pre_ra_sim(self
, state
):
616 # type: (PreRASimState) -> None
617 """simulate op before register allocation"""
621 @plain_data(unsafe_hash
=True, frozen
=True, repr=False)
623 class OpLoadFromStackSlot(Op
):
624 __slots__
= "dest", "src", "vl"
627 # type: () -> dict[str, SSAVal]
628 retval
= {"src": self
.src
} # type: dict[str, SSAVal[Any]]
629 if self
.vl
is not None:
630 retval
["vl"] = self
.vl
634 # type: () -> dict[str, SSAVal]
635 return {"dest": self
.dest
}
637 def __init__(self
, fn
, src
, vl
=None):
638 # type: (Fn, SSAVal[StackSlotType], SSAKnownVL | None) -> None
640 self
.dest
= SSAVal(self
, "dest", GPRRangeType(src
.ty
.length_in_slots
))
643 assert_vl_is(vl
, self
.dest
.ty
.length
)
def get_asm_lines(self, ctx):
    # type: (AsmContext) -> list[str]
    """Emit one ``ld`` that loads `self.dest` from its stack slot, addressed
    as ``start_byte(1)`` (relative to r1).

    The destination is written as a vector operand when it spans more than
    one register. The ``sv.`` prefix is added when ``ctx.needs_sv`` reports
    the destination requires it.
    """
    is_vector = self.dest.ty.length != 1
    dest = ctx.gpr(self.dest, vec=is_vector)
    src = ctx.stack_slot(self.src)
    prefix = "sv." if ctx.needs_sv(self.dest) else ""
    return [f"{prefix}ld {dest}, {src.start_byte}(1)"]
def pre_ra_sim(self, state):
    # type: (PreRASimState) -> None
    """simulate op before register allocation: the destination GPR values
    become the simulated contents of the source stack slot"""
    loaded = state.stack_slots[self.src]
    state.gprs[self.dest] = loaded
659 @plain_data(unsafe_hash
=True, frozen
=True, repr=False)
661 class OpStoreToStackSlot(Op
):
662 __slots__
= "dest", "src", "vl"
665 # type: () -> dict[str, SSAVal]
666 retval
= {"src": self
.src
} # type: dict[str, SSAVal[Any]]
667 if self
.vl
is not None:
668 retval
["vl"] = self
.vl
672 # type: () -> dict[str, SSAVal]
673 return {"dest": self
.dest
}
675 def __init__(self
, fn
, src
, vl
=None):
676 # type: (Fn, SSAGPRRange, SSAKnownVL | None) -> None
678 self
.dest
= SSAVal(self
, "dest", StackSlotType(src
.ty
.length
))
681 assert_vl_is(vl
, src
.ty
.length
)
def get_asm_lines(self, ctx):
    # type: (AsmContext) -> list[str]
    """Emit one ``std`` that stores `self.src` into the destination stack
    slot, addressed as ``start_byte(1)`` (relative to r1).

    The source is written as a vector operand when it spans more than one
    register. The ``sv.`` prefix is added when ``ctx.needs_sv`` reports the
    source requires it.
    """
    is_vector = self.src.ty.length != 1
    src = ctx.gpr(self.src, vec=is_vector)
    dest = ctx.stack_slot(self.dest)
    prefix = "sv." if ctx.needs_sv(self.src) else ""
    return [f"{prefix}std {src}, {dest.start_byte}(1)"]
def pre_ra_sim(self, state):
    # type: (PreRASimState) -> None
    """simulate op before register allocation: the destination stack slot
    receives the simulated values of the source GPRs"""
    stored = state.gprs[self.src]
    state.stack_slots[self.dest] = stored
697 _RegSrcType
= TypeVar("_RegSrcType", bound
=RegType
)
700 @plain_data(unsafe_hash
=True, frozen
=True, repr=False)
702 class OpCopy(Op
, Generic
[_RegSrcType
, _RegType
]):
703 __slots__
= "dest", "src", "vl"
706 # type: () -> dict[str, SSAVal]
707 retval
= {"src": self
.src
} # type: dict[str, SSAVal[Any]]
708 if self
.vl
is not None:
709 retval
["vl"] = self
.vl
713 # type: () -> dict[str, SSAVal]
714 return {"dest": self
.dest
}
716 def __init__(self
, fn
, src
, dest_ty
=None, vl
=None):
717 # type: (Fn, SSAVal[_RegSrcType], _RegType | None, SSAKnownVL | None) -> None
720 dest_ty
= cast(_RegType
, src
.ty
)
721 if isinstance(src
.ty
, GPRRangeType
) \
722 and isinstance(dest_ty
, FixedGPRRangeType
):
723 if src
.ty
.length
!= dest_ty
.reg
.length
:
724 raise ValueError(f
"incompatible source and destination "
725 f
"types: {src.ty} and {dest_ty}")
726 length
= src
.ty
.length
727 elif isinstance(src
.ty
, FixedGPRRangeType
) \
728 and isinstance(dest_ty
, GPRRangeType
):
729 if src
.ty
.reg
.length
!= dest_ty
.length
:
730 raise ValueError(f
"incompatible source and destination "
731 f
"types: {src.ty} and {dest_ty}")
732 length
= src
.ty
.length
733 elif src
.ty
!= dest_ty
:
734 raise ValueError(f
"incompatible source and destination "
735 f
"types: {src.ty} and {dest_ty}")
736 elif isinstance(src
.ty
, StackSlotType
):
737 raise ValueError("can't use OpCopy on stack slots")
738 elif isinstance(src
.ty
, (GPRRangeType
, FixedGPRRangeType
)):
739 length
= src
.ty
.length
743 self
.dest
= SSAVal(self
, "dest", dest_ty
) # type: SSAVal[_RegType]
746 assert_vl_is(vl
, length
)
748 def get_asm_lines(self
, ctx
):
749 # type: (AsmContext) -> list[str]
750 if ctx
.reg(self
.src
, RegLoc
) == ctx
.reg(self
.dest
, RegLoc
):
752 if (isinstance(self
.src
.ty
, (GPRRangeType
, FixedGPRRangeType
)) and
753 isinstance(self
.dest
.ty
, (GPRRangeType
, FixedGPRRangeType
))):
754 vec
= self
.dest
.ty
.length
!= 1
755 dest
= ctx
.gpr_range(self
.dest
) # type: ignore
756 src
= ctx
.gpr_range(self
.src
) # type: ignore
757 dest_s
= ctx
.gpr(self
.dest
, vec
=vec
) # type: ignore
758 src_s
= ctx
.gpr(self
.src
, vec
=vec
) # type: ignore
760 if src
.conflicts(dest
) and src
.start
> dest
.start
:
762 if ctx
.needs_sv(self
.src
, self
.dest
): # type: ignore
763 return [f
"sv.or{mrr} {dest_s}, {src_s}, {src_s}"]
764 return [f
"or {dest_s}, {src_s}, {src_s}"]
765 raise NotImplementedError
767 def pre_ra_sim(self
, state
):
768 # type: (PreRASimState) -> None
769 if (isinstance(self
.src
.ty
, (GPRRangeType
, FixedGPRRangeType
)) and
770 isinstance(self
.dest
.ty
, (GPRRangeType
, FixedGPRRangeType
))):
771 if isinstance(self
.src
.ty
, GPRRangeType
):
772 v
= state
.gprs
[self
.src
] # type: ignore
774 v
= state
.fixed_gprs
[self
.src
] # type: ignore
775 if isinstance(self
.dest
.ty
, GPRRangeType
):
776 state
.gprs
[self
.dest
] = v
# type: ignore
778 state
.fixed_gprs
[self
.dest
] = v
# type: ignore
779 elif (isinstance(self
.src
.ty
, FixedGPRRangeType
) and
780 isinstance(self
.dest
.ty
, GPRRangeType
)):
781 state
.gprs
[self
.dest
] = state
.fixed_gprs
[self
.src
] # type: ignore
782 elif (isinstance(self
.src
.ty
, GPRRangeType
) and
783 isinstance(self
.dest
.ty
, FixedGPRRangeType
)):
784 state
.fixed_gprs
[self
.dest
] = state
.gprs
[self
.src
] # type: ignore
785 elif (isinstance(self
.src
.ty
, CAType
) and
786 self
.src
.ty
== self
.dest
.ty
):
787 state
.CAs
[self
.dest
] = state
.CAs
[self
.src
] # type: ignore
788 elif (isinstance(self
.src
.ty
, KnownVLType
) and
789 self
.src
.ty
== self
.dest
.ty
):
790 state
.VLs
[self
.dest
] = state
.VLs
[self
.src
] # type: ignore
791 elif (isinstance(self
.src
.ty
, GlobalMemType
) and
792 self
.src
.ty
== self
.dest
.ty
):
793 v
= state
.global_mems
[self
.src
] # type: ignore
794 state
.global_mems
[self
.dest
] = v
# type: ignore
796 raise NotImplementedError
799 @plain_data(unsafe_hash
=True, frozen
=True, repr=False)
802 __slots__
= "dest", "sources"
805 # type: () -> dict[str, SSAVal]
806 return {f
"sources[{i}]": v
for i
, v
in enumerate(self
.sources
)}
809 # type: () -> dict[str, SSAVal]
810 return {"dest": self
.dest
}
812 def __init__(self
, fn
, sources
):
813 # type: (Fn, Iterable[SSAGPRRange]) -> None
815 sources
= tuple(sources
)
816 self
.dest
= SSAVal(self
, "dest", GPRRangeType(
817 sum(i
.ty
.length
for i
in sources
)))
818 self
.sources
= sources
def get_equality_constraints(self):
    # type: () -> Iterable[EqualityConstraint]
    """Yield one EqualityConstraint with ``[dest]`` on the left and the list
    of `sources` (in order) on the right."""
    yield EqualityConstraint([self.dest], list(self.sources))
824 def get_asm_lines(self
, ctx
):
825 # type: (AsmContext) -> list[str]
828 def pre_ra_sim(self
, state
):
829 # type: (PreRASimState) -> None
831 for src
in self
.sources
:
832 v
.extend(state
.gprs
[src
])
833 state
.gprs
[self
.dest
] = tuple(v
)
836 @plain_data(unsafe_hash
=True, frozen
=True, repr=False)
839 __slots__
= "results", "src"
842 # type: () -> dict[str, SSAVal]
843 return {"src": self
.src
}
846 # type: () -> dict[str, SSAVal]
847 return {i
.arg_name
: i
for i
in self
.results
}
849 def __init__(self
, fn
, src
, split_indexes
):
850 # type: (Fn, SSAGPRRange, Iterable[int]) -> None
852 ranges
= [] # type: list[GPRRangeType]
854 for i
in split_indexes
:
855 if not (0 < i
< src
.ty
.length
):
856 raise ValueError(f
"invalid split index: {i}, must be in "
857 f
"0 < i < {src.ty.length}")
858 ranges
.append(GPRRangeType(i
- last
))
860 ranges
.append(GPRRangeType(src
.ty
.length
- last
))
862 self
.results
= tuple(
863 SSAVal(self
, f
"results[{i}]", r
) for i
, r
in enumerate(ranges
))
def get_equality_constraints(self):
    # type: () -> Iterable[EqualityConstraint]
    """Yield one EqualityConstraint with the list of `results` (in order) on
    the left and ``[src]`` on the right."""
    yield EqualityConstraint(list(self.results), [self.src])
869 def get_asm_lines(self
, ctx
):
870 # type: (AsmContext) -> list[str]
def pre_ra_sim(self, state):
    # type: (PreRASimState) -> None
    """simulate op before register allocation: carve the simulated value of
    `src` into consecutive pieces, one per result, sized by each result's
    length. Pieces are assigned starting from the last result."""
    values = state.gprs[self.src]
    end = len(values)
    for result in reversed(self.results):
        # clamp at 0 so a short `values` behaves like slicing from the end
        begin = max(end - result.ty.length, 0)
        state.gprs[result] = values[begin:end]
        end = begin
881 @plain_data(unsafe_hash
=True, frozen
=True, repr=False)
883 class OpBigIntAddSub(Op
):
884 __slots__
= "out", "lhs", "rhs", "CA_in", "CA_out", "is_sub", "vl"
887 # type: () -> dict[str, SSAVal]
888 retval
= {} # type: dict[str, SSAVal[Any]]
889 retval
["lhs"] = self
.lhs
890 retval
["rhs"] = self
.rhs
891 retval
["CA_in"] = self
.CA_in
892 if self
.vl
is not None:
893 retval
["vl"] = self
.vl
897 # type: () -> dict[str, SSAVal]
898 return {"out": self
.out
, "CA_out": self
.CA_out
}
900 def __init__(self
, fn
, lhs
, rhs
, CA_in
, is_sub
, vl
=None):
901 # type: (Fn, SSAGPRRange, SSAGPRRange, SSAVal[CAType], bool, SSAKnownVL | None) -> None
904 raise TypeError(f
"source types must match: "
905 f
"{lhs} doesn't match {rhs}")
906 self
.out
= SSAVal(self
, "out", lhs
.ty
)
910 self
.CA_out
= SSAVal(self
, "CA_out", CA_in
.ty
)
913 assert_vl_is(vl
, lhs
.ty
.length
)
def get_extra_interferences(self):
    # type: () -> Iterable[tuple[SSAVal, SSAVal]]
    """Declare that `out` interferes with both input operands: yields
    ``(out, lhs)`` then ``(out, rhs)``."""
    for operand in (self.lhs, self.rhs):
        yield self.out, operand
920 def get_asm_lines(self
, ctx
):
921 # type: (AsmContext) -> list[str]
922 vec
= self
.out
.ty
.length
!= 1
923 out
= ctx
.gpr(self
.out
, vec
=vec
)
924 RA
= ctx
.gpr(self
.lhs
, vec
=vec
)
925 RB
= ctx
.gpr(self
.rhs
, vec
=vec
)
929 RA
, RB
= RB
, RA
# reorder to match subfe
930 if ctx
.needs_sv(self
.out
, self
.lhs
, self
.rhs
):
931 return [f
"sv.{mnemonic} {out}, {RA}, {RB}"]
932 return [f
"{mnemonic} {out}, {RA}, {RB}"]
934 def pre_ra_sim(self
, state
):
935 # type: (PreRASimState) -> None
936 carry
= state
.CAs
[self
.CA_in
]
937 out
= [] # type: list[int]
938 for l
, r
in zip(state
.gprs
[self
.lhs
], state
.gprs
[self
.rhs
]):
940 r
= r ^ GPR_VALUE_MASK
942 carry
= s
!= (s
& GPR_VALUE_MASK
)
943 out
.append(s
& GPR_VALUE_MASK
)
944 state
.CAs
[self
.CA_out
] = carry
945 state
.gprs
[self
.out
] = tuple(out
)
948 @plain_data(unsafe_hash
=True, frozen
=True, repr=False)
950 class OpBigIntMulDiv(Op
):
951 __slots__
= "RT", "RA", "RB", "RC", "RS", "is_div", "vl"
954 # type: () -> dict[str, SSAVal]
955 retval
= {} # type: dict[str, SSAVal[Any]]
956 retval
["RA"] = self
.RA
957 retval
["RB"] = self
.RB
958 retval
["RC"] = self
.RC
959 if self
.vl
is not None:
960 retval
["vl"] = self
.vl
964 # type: () -> dict[str, SSAVal]
965 return {"RT": self
.RT
, "RS": self
.RS
}
967 def __init__(self
, fn
, RA
, RB
, RC
, is_div
, vl
):
968 # type: (Fn, SSAGPRRange, SSAGPR, SSAGPR, bool, SSAKnownVL | None) -> None
970 self
.RT
= SSAVal(self
, "RT", RA
.ty
)
974 self
.RS
= SSAVal(self
, "RS", RC
.ty
)
977 assert_vl_is(vl
, RA
.ty
.length
)
def get_equality_constraints(self):
    # type: () -> Iterable[EqualityConstraint]
    """Yield one EqualityConstraint tying input `RC` to output `RS`."""
    yield EqualityConstraint(lhs=[self.RC], rhs=[self.RS])
def get_extra_interferences(self):
    # type: () -> Iterable[tuple[SSAVal, SSAVal]]
    """Yield the extra interference pairs, in order:
    RT with each of RA, RB, RC, RS, then RS with each of RA, RB."""
    for other in (self.RA, self.RB, self.RC, self.RS):
        yield self.RT, other
    for other in (self.RA, self.RB):
        yield self.RS, other
992 def get_asm_lines(self
, ctx
):
993 # type: (AsmContext) -> list[str]
994 vec
= self
.RT
.ty
.length
!= 1
995 RT
= ctx
.gpr(self
.RT
, vec
=vec
)
996 RA
= ctx
.gpr(self
.RA
, vec
=vec
)
997 RB
= ctx
.sgpr(self
.RB
)
998 RC
= ctx
.sgpr(self
.RC
)
1001 mnemonic
= "divmod2du/mrr"
1002 return [f
"sv.{mnemonic} {RT}, {RA}, {RB}, {RC}"]
1004 def pre_ra_sim(self
, state
):
1005 # type: (PreRASimState) -> None
1006 carry
= state
.gprs
[self
.RC
][0]
1007 RA
= state
.gprs
[self
.RA
]
1008 RB
= state
.gprs
[self
.RB
][0]
1009 RT
= [0] * self
.RT
.ty
.length
1011 for i
in reversed(range(self
.RT
.ty
.length
)):
1012 if carry
< RB
and RB
!= 0:
1013 div
, mod
= divmod((carry
<< 64) | RA
[i
], RB
)
1014 RT
[i
] = div
& GPR_VALUE_MASK
1015 carry
= mod
& GPR_VALUE_MASK
1017 RT
[i
] = GPR_VALUE_MASK
1020 for i
in range(self
.RT
.ty
.length
):
1021 v
= RA
[i
] * RB
+ carry
1023 RT
[i
] = v
& GPR_VALUE_MASK
1024 state
.gprs
[self
.RS
] = carry
,
1025 state
.gprs
[self
.RT
] = tuple(RT
)
1030 class ShiftKind(Enum
):
1035 def make_big_int_carry_in(self
, fn
, inp
):
1036 # type: (Fn, SSAGPRRange) -> tuple[SSAGPR, list[Op]]
1037 if self
is ShiftKind
.Sl
or self
is ShiftKind
.Sr
:
1041 assert self
is ShiftKind
.Sra
1042 split
= OpSplit(fn
, inp
, [inp
.ty
.length
- 1])
1043 shr
= OpShiftImm(fn
, split
.results
[1], sh
=63, kind
=ShiftKind
.Sra
)
1044 return shr
.out
, [split
, shr
]
def make_big_int_shift(self, fn, inp, sh, vl):
    # type: (Fn, SSAGPRRange, SSAGPR, SSAKnownVL | None) -> tuple[SSAGPRRange, list[Op]]
    """Build the ops that shift big integer `inp` by the amount held in
    `sh`, using this ShiftKind.

    Returns ``(result, ops)``. ``ops`` is the list returned by
    ``make_big_int_carry_in``, with the new OpBigIntShift appended in place.
    ``result`` is that OpBigIntShift's output.
    """
    carry_in, ops = self.make_big_int_carry_in(fn, inp)
    shift_op = OpBigIntShift(fn=fn, inp=inp, sh=sh, carry_in=carry_in,
                             kind=self, vl=vl)
    ops += [shift_op]
    return shift_op.out, ops
1054 @plain_data(unsafe_hash
=True, frozen
=True, repr=False)
1056 class OpBigIntShift(Op
):
1057 __slots__
= "out", "inp", "carry_in", "_out_padding", "sh", "kind", "vl"
1060 # type: () -> dict[str, SSAVal]
1061 retval
= {} # type: dict[str, SSAVal[Any]]
1062 retval
["inp"] = self
.inp
1063 retval
["sh"] = self
.sh
1064 retval
["carry_in"] = self
.carry_in
1065 if self
.vl
is not None:
1066 retval
["vl"] = self
.vl
1070 # type: () -> dict[str, SSAVal]
1071 return {"out": self
.out
, "_out_padding": self
._out
_padding
}
1073 def __init__(self
, fn
, inp
, sh
, carry_in
, kind
, vl
=None):
1074 # type: (Fn, SSAGPRRange, SSAGPR, SSAGPR, ShiftKind, SSAKnownVL | None) -> None
1075 super().__init
__(fn
)
1076 self
.out
= SSAVal(self
, "out", inp
.ty
)
1077 self
._out
_padding
= SSAVal(self
, "_out_padding", GPRRangeType())
1078 self
.carry_in
= carry_in
1083 assert_vl_is(vl
, inp
.ty
.length
)
def get_extra_interferences(self):
    # type: () -> Iterable[tuple[SSAVal, SSAVal]]
    """Declare that `out` interferes with the shift-amount register `sh`."""
    pair = (self.out, self.sh)
    yield pair
1089 def get_equality_constraints(self
):
1090 # type: () -> Iterable[EqualityConstraint]
1091 if self
.kind
is ShiftKind
.Sl
:
1092 yield EqualityConstraint([self
.carry_in
, self
.inp
],
1093 [self
.out
, self
._out
_padding
])
1095 assert self
.kind
is ShiftKind
.Sr
or self
.kind
is ShiftKind
.Sra
1096 yield EqualityConstraint([self
.inp
, self
.carry_in
],
1097 [self
._out
_padding
, self
.out
])
1099 def get_asm_lines(self
, ctx
):
1100 # type: (AsmContext) -> list[str]
1101 vec
= self
.out
.ty
.length
!= 1
1102 if self
.kind
is ShiftKind
.Sl
:
1103 RT
= ctx
.gpr(self
.out
, vec
=vec
)
1104 RA
= ctx
.gpr(self
.out
, vec
=vec
, offset
=-1)
1105 RB
= ctx
.sgpr(self
.sh
)
1106 mrr
= "/mrr" if vec
else ""
1107 return [f
"sv.dsld{mrr} {RT}, {RA}, {RB}, 0"]
1109 assert self
.kind
is ShiftKind
.Sr
or self
.kind
is ShiftKind
.Sra
1110 RT
= ctx
.gpr(self
.out
, vec
=vec
)
1111 RA
= ctx
.gpr(self
.out
, vec
=vec
, offset
=1)
1112 RB
= ctx
.sgpr(self
.sh
)
1113 return [f
"sv.dsrd {RT}, {RA}, {RB}, 1"]
1115 def pre_ra_sim(self
, state
):
1116 # type: (PreRASimState) -> None
1117 out
= [0] * self
.out
.ty
.length
1118 carry
= state
.gprs
[self
.carry_in
][0]
1119 sh
= state
.gprs
[self
.sh
][0] % 64
1120 if self
.kind
is ShiftKind
.Sl
:
1121 inp
= carry
, *state
.gprs
[self
.inp
]
1122 for i
in reversed(range(self
.out
.ty
.length
)):
1123 v
= inp
[i
] |
(inp
[i
+ 1] << 64)
1125 out
[i
] = (v
>> 64) & GPR_VALUE_MASK
1127 assert self
.kind
is ShiftKind
.Sr
or self
.kind
is ShiftKind
.Sra
1128 inp
= *state
.gprs
[self
.inp
], carry
1129 for i
in range(self
.out
.ty
.length
):
1130 v
= inp
[i
] |
(inp
[i
+ 1] << 64)
1132 out
[i
] = v
& GPR_VALUE_MASK
1133 # state.gprs[self._out_padding] is intentionally not written
1134 state
.gprs
[self
.out
] = tuple(out
)
1137 @plain_data(unsafe_hash
=True, frozen
=True, repr=False)
1139 class OpShiftImm(Op
):
1140 __slots__
= "out", "inp", "sh", "kind", "ca_out"
1143 # type: () -> dict[str, SSAVal]
1144 return {"inp": self
.inp
}
1147 # type: () -> dict[str, SSAVal]
1148 if self
.ca_out
is not None:
1149 return {"out": self
.out
, "ca_out": self
.ca_out
}
1150 return {"out": self
.out
}
1152 def __init__(self
, fn
, inp
, sh
, kind
):
1153 # type: (Fn, SSAGPR, int, ShiftKind) -> None
1154 super().__init
__(fn
)
1155 self
.out
= SSAVal(self
, "out", inp
.ty
)
1157 if not (0 <= sh
< 64):
1158 raise ValueError("shift amount out of range")
1161 if self
.kind
is ShiftKind
.Sra
:
1162 self
.ca_out
= SSAVal(self
, "ca_out", CAType())
1166 def get_asm_lines(self
, ctx
):
1167 # type: (AsmContext) -> list[str]
1168 out
= ctx
.sgpr(self
.out
)
1169 inp
= ctx
.sgpr(self
.inp
)
1170 if self
.kind
is ShiftKind
.Sl
:
1172 args
= f
"{self.sh}, {63 - self.sh}"
1173 elif self
.kind
is ShiftKind
.Sr
:
1175 v
= (64 - self
.sh
) % 64
1176 args
= f
"{v}, {self.sh}"
1178 assert self
.kind
is ShiftKind
.Sra
1181 if ctx
.needs_sv(self
.out
, self
.inp
):
1182 return [f
"sv.{mnemonic} {out}, {inp}, {args}"]
1183 return [f
"{mnemonic} {out}, {inp}, {args}"]
1185 def pre_ra_sim(self
, state
):
1186 # type: (PreRASimState) -> None
1187 inp
= state
.gprs
[self
.inp
][0]
1188 if self
.kind
is ShiftKind
.Sl
:
1189 assert self
.ca_out
is None
1190 out
= inp
<< self
.sh
1191 elif self
.kind
is ShiftKind
.Sr
:
1192 assert self
.ca_out
is None
1193 out
= inp
>> self
.sh
1195 assert self
.kind
is ShiftKind
.Sra
1196 assert self
.ca_out
is not None
1197 if inp
& (1 << 63): # sign extend
1199 out
= inp
>> self
.sh
1200 ca
= inp
< 0 and (out
<< self
.sh
) != inp
1201 state
.CAs
[self
.ca_out
] = ca
1202 state
.gprs
[self
.out
] = out
,
1205 @plain_data(unsafe_hash
=True, frozen
=True, repr=False)
1208 __slots__
= "out", "value", "vl"
1211 # type: () -> dict[str, SSAVal]
1212 retval
= {} # type: dict[str, SSAVal[Any]]
1213 if self
.vl
is not None:
1214 retval
["vl"] = self
.vl
1218 # type: () -> dict[str, SSAVal]
1219 return {"out": self
.out
}
1221 def __init__(self
, fn
, value
, vl
=None):
1222 # type: (Fn, int, SSAKnownVL | None) -> None
1223 super().__init
__(fn
)
1227 length
= vl
.ty
.length
1228 self
.out
= SSAVal(self
, "out", GPRRangeType(length
))
1229 if not (-1 << 15 <= value
<= (1 << 15) - 1):
1230 raise ValueError(f
"value out of range: {value}")
1233 assert_vl_is(vl
, length
)
def get_asm_lines(self, ctx):
    # type: (AsmContext) -> list[str]
    """Emit ``addi out, 0, value`` with RA=0 to load the immediate into
    `out`.

    `out` is written as a vector operand when it spans more than one
    register. The ``sv.`` prefix is added when ``ctx.needs_sv`` requires it.
    """
    out = ctx.gpr(self.out, vec=self.out.ty.length != 1)
    mnemonic = "sv.addi" if ctx.needs_sv(self.out) else "addi"
    return [f"{mnemonic} {out}, 0, {self.value}"]
def pre_ra_sim(self, state):
    # type: (PreRASimState) -> None
    """simulate op before register allocation: every output register holds
    `value` truncated to register width. Negative values become their
    two's-complement bit pattern via the mask."""
    truncated = self.value & GPR_VALUE_MASK
    state.gprs[self.out] = tuple(truncated for _ in range(self.out.ty.length))
1249 @plain_data(unsafe_hash
=True, frozen
=True, repr=False)
1252 __slots__
= "out", "value"
1255 # type: () -> dict[str, SSAVal]
1259 # type: () -> dict[str, SSAVal]
1260 return {"out": self
.out
}
1262 def __init__(self
, fn
, value
):
1263 # type: (Fn, bool) -> None
1264 super().__init
__(fn
)
1265 self
.out
= SSAVal(self
, "out", CAType())
1268 def get_asm_lines(self
, ctx
):
1269 # type: (AsmContext) -> list[str]
1271 return ["subfic 0, 0, -1"]
1272 return ["addic 0, 0, 0"]
1274 def pre_ra_sim(self
, state
):
1275 # type: (PreRASimState) -> None
1276 state
.CAs
[self
.out
] = self
.value
1279 @plain_data(unsafe_hash
=True, frozen
=True, repr=False)
1282 __slots__
= "RT", "RA", "offset", "mem", "vl"
1285 # type: () -> dict[str, SSAVal]
1286 retval
= {} # type: dict[str, SSAVal[Any]]
1287 retval
["RA"] = self
.RA
1288 retval
["mem"] = self
.mem
1289 if self
.vl
is not None:
1290 retval
["vl"] = self
.vl
1294 # type: () -> dict[str, SSAVal]
1295 return {"RT": self
.RT
}
1297 def __init__(self
, fn
, RA
, offset
, mem
, vl
=None):
1298 # type: (Fn, SSAGPR, int, SSAVal[GlobalMemType], SSAKnownVL | None) -> None
1299 super().__init
__(fn
)
1303 length
= vl
.ty
.length
1304 self
.RT
= SSAVal(self
, "RT", GPRRangeType(length
))
1306 if not (-1 << 15 <= offset
<= (1 << 15) - 1):
1307 raise ValueError(f
"offset out of range: {offset}")
1309 raise ValueError(f
"offset not aligned: {offset}")
1310 self
.offset
= offset
1313 assert_vl_is(vl
, length
)
1315 def get_extra_interferences(self
):
1316 # type: () -> Iterable[tuple[SSAVal, SSAVal]]
1317 if self
.RT
.ty
.length
> 1:
1318 yield self
.RT
, self
.RA
1320 def get_asm_lines(self
, ctx
):
1321 # type: (AsmContext) -> list[str]
1322 RT
= ctx
.gpr(self
.RT
, vec
=self
.RT
.ty
.length
!= 1)
1323 RA
= ctx
.sgpr(self
.RA
)
1324 if ctx
.needs_sv(self
.RT
, self
.RA
):
1325 return [f
"sv.ld {RT}, {self.offset}({RA})"]
1326 return [f
"ld {RT}, {self.offset}({RA})"]
1328 def pre_ra_sim(self
, state
):
1329 # type: (PreRASimState) -> None
1330 addr
= state
.gprs
[self
.RA
][0]
1332 RT
= [0] * self
.RT
.ty
.length
1333 mem
= state
.global_mems
[self
.mem
]
1334 for i
in range(self
.RT
.ty
.length
):
1335 cur_addr
= (addr
+ i
* GPR_SIZE_IN_BYTES
) & GPR_VALUE_MASK
1336 if cur_addr
% GPR_SIZE_IN_BYTES
!= 0:
1337 raise ValueError(f
"can't load from unaligned address: "
1339 for j
in range(GPR_SIZE_IN_BYTES
):
1340 byte_val
= mem
.get(cur_addr
+ j
, 0) & 0xFF
1341 RT
[i
] |
= byte_val
<< (j
* 8)
1342 state
.gprs
[self
.RT
] = tuple(RT
)
1345 @plain_data(unsafe_hash
=True, frozen
=True, repr=False)
1348 __slots__
= "RS", "RA", "offset", "mem_in", "mem_out", "vl"
1351 # type: () -> dict[str, SSAVal]
1352 retval
= {} # type: dict[str, SSAVal[Any]]
1353 retval
["RS"] = self
.RS
1354 retval
["RA"] = self
.RA
1355 retval
["mem_in"] = self
.mem_in
1356 if self
.vl
is not None:
1357 retval
["vl"] = self
.vl
1361 # type: () -> dict[str, SSAVal]
1362 return {"mem_out": self
.mem_out
}
1364 def __init__(self
, fn
, RS
, RA
, offset
, mem_in
, vl
=None):
1365 # type: (Fn, SSAGPRRange, SSAGPR, int, SSAVal[GlobalMemType], SSAKnownVL | None) -> None
1366 super().__init
__(fn
)
1369 if not (-1 << 15 <= offset
<= (1 << 15) - 1):
1370 raise ValueError(f
"offset out of range: {offset}")
1372 raise ValueError(f
"offset not aligned: {offset}")
1373 self
.offset
= offset
1374 self
.mem_in
= mem_in
1375 self
.mem_out
= SSAVal(self
, "mem_out", mem_in
.ty
)
1377 assert_vl_is(vl
, RS
.ty
.length
)
1379 def get_asm_lines(self
, ctx
):
1380 # type: (AsmContext) -> list[str]
1381 RS
= ctx
.gpr(self
.RS
, vec
=self
.RS
.ty
.length
!= 1)
1382 RA
= ctx
.sgpr(self
.RA
)
1383 if ctx
.needs_sv(self
.RS
, self
.RA
):
1384 return [f
"sv.std {RS}, {self.offset}({RA})"]
1385 return [f
"std {RS}, {self.offset}({RA})"]
1387 def pre_ra_sim(self
, state
):
1388 # type: (PreRASimState) -> None
1389 mem
= dict(state
.global_mems
[self
.mem_in
])
1390 addr
= state
.gprs
[self
.RA
][0]
1392 RS
= state
.gprs
[self
.RS
]
1393 for i
in range(self
.RS
.ty
.length
):
1394 cur_addr
= (addr
+ i
* GPR_SIZE_IN_BYTES
) & GPR_VALUE_MASK
1395 if cur_addr
% GPR_SIZE_IN_BYTES
!= 0:
1396 raise ValueError(f
"can't store to unaligned address: "
1398 for j
in range(GPR_SIZE_IN_BYTES
):
1399 mem
[cur_addr
+ j
] = (RS
[i
] >> (j
* 8)) & 0xFF
1400 state
.global_mems
[self
.mem_out
] = FMap(mem
)
1403 @plain_data(unsafe_hash
=True, frozen
=True, repr=False)
1405 class OpFuncArg(Op
):
1409 # type: () -> dict[str, SSAVal]
1413 # type: () -> dict[str, SSAVal]
1414 return {"out": self
.out
}
1416 def __init__(self
, fn
, ty
):
1417 # type: (Fn, FixedGPRRangeType) -> None
1418 super().__init
__(fn
)
1419 self
.out
= SSAVal(self
, "out", ty
)
1421 def get_asm_lines(self
, ctx
):
1422 # type: (AsmContext) -> list[str]
1425 def pre_ra_sim(self
, state
):
1426 # type: (PreRASimState) -> None
1427 if self
.out
not in state
.fixed_gprs
:
1428 state
.fixed_gprs
[self
.out
] = (0,) * self
.out
.ty
.length
1431 @plain_data(unsafe_hash
=True, frozen
=True, repr=False)
1433 class OpInputMem(Op
):
1437 # type: () -> dict[str, SSAVal]
1441 # type: () -> dict[str, SSAVal]
1442 return {"out": self
.out
}
1444 def __init__(self
, fn
):
1445 # type: (Fn) -> None
1446 super().__init
__(fn
)
1447 self
.out
= SSAVal(self
, "out", GlobalMemType())
1449 def get_asm_lines(self
, ctx
):
1450 # type: (AsmContext) -> list[str]
1453 def pre_ra_sim(self
, state
):
1454 # type: (PreRASimState) -> None
1455 if self
.out
not in state
.global_mems
:
1456 state
.global_mems
[self
.out
] = FMap()
1459 @plain_data(unsafe_hash
=True, frozen
=True, repr=False)
1461 class OpSetVLImm(Op
):
1465 # type: () -> dict[str, SSAVal]
1469 # type: () -> dict[str, SSAVal]
1470 return {"out": self
.out
}
1472 def __init__(self
, fn
, length
):
1473 # type: (Fn, int) -> None
1474 super().__init
__(fn
)
1475 self
.out
= SSAVal(self
, "out", KnownVLType(length
))
1477 def get_asm_lines(self
, ctx
):
1478 # type: (AsmContext) -> list[str]
1479 return [f
"setvl 0, 0, {self.out.ty.length}, 0, 1, 1"]
1481 def pre_ra_sim(self
, state
):
1482 # type: (PreRASimState) -> None
1483 state
.VLs
[self
.out
] = self
.out
.ty
.length
1486 def op_set_to_list(ops
):
1487 # type: (Iterable[Op]) -> list[Op]
1488 worklists
= [{}] # type: list[dict[Op, None]]
1489 inps_to_ops_map
= defaultdict(dict) # type: dict[SSAVal, dict[Op, None]]
1490 ops_to_pending_input_count_map
= {} # type: dict[Op, int]
1493 for val
in op
.inputs().values():
1495 inps_to_ops_map
[val
][op
] = None
1496 while len(worklists
) <= input_count
:
1497 worklists
.append({})
1498 ops_to_pending_input_count_map
[op
] = input_count
1499 worklists
[input_count
][op
] = None
1500 retval
= [] # type: list[Op]
1501 ready_vals
= OSet() # type: OSet[SSAVal]
1502 while len(worklists
[0]) != 0:
1503 writing_op
= next(iter(worklists
[0]))
1504 del worklists
[0][writing_op
]
1505 retval
.append(writing_op
)
1506 for val
in writing_op
.outputs().values():
1507 if val
in ready_vals
:
1508 raise ValueError(f
"multiple instructions must not write "
1509 f
"to the same SSA value: {val}")
1511 for reading_op
in inps_to_ops_map
[val
]:
1512 pending
= ops_to_pending_input_count_map
[reading_op
]
1513 del worklists
[pending
][reading_op
]
1515 worklists
[pending
][reading_op
] = None
1516 ops_to_pending_input_count_map
[reading_op
] = pending
1517 for worklist
in worklists
:
1519 raise ValueError(f
"instruction is part of a dependency loop or "
1520 f
"its inputs are never written: {op}")
1524 def generate_assembly(ops
, assigned_registers
=None):
1525 # type: (list[Op], dict[SSAVal, RegLoc] | None) -> list[str]
1526 if assigned_registers
is None:
1527 from bigint_presentation_code
.register_allocator
import \
1529 assigned_registers
= allocate_registers(ops
)
1530 ctx
= AsmContext(assigned_registers
)
1531 retval
= [] # list[str]
1533 retval
.extend(op
.get_asm_lines(ctx
))
1534 retval
.append("bclr 20, 0, 0")