from collections import defaultdict
from enum import Enum, EnumMeta, unique
from functools import lru_cache
-from typing import (TYPE_CHECKING, AbstractSet, Generic, Iterable, Sequence,
- TypeVar, cast)
+from typing import TYPE_CHECKING, Generic, Iterable, Sequence, TypeVar, cast
from cached_property import cached_property
from nmutil.plain_data import fields, plain_data
+from bigint_presentation_code.ordered_set import OFSet, OSet
+
if TYPE_CHECKING:
from typing_extensions import final
else:
def get_subreg_at_offset(self, subreg_type, offset):
# type: (RegType, int) -> GPRRange
- if not isinstance(subreg_type, GPRRangeType):
- raise ValueError(f"subreg_type is not a "
+ if not isinstance(subreg_type, (GPRRangeType, FixedGPRRangeType)):
+ raise ValueError(f"subreg_type is not a FixedGPRRangeType or "
f"GPRRangeType: {subreg_type}")
if offset < 0 or offset + subreg_type.length > self.stop:
raise ValueError(f"sub-register offset is out of range: {offset}")
@final
-class RegClass(AbstractSet[RegLoc]):
+class RegClass(OFSet[RegLoc]):
""" an ordered set of registers.
earlier registers are preferred by the register allocator.
"""
- def __init__(self, regs):
- # type: (Iterable[RegLoc]) -> None
-
- # use dict to maintain order
- self.__regs = dict.fromkeys(regs) # type: dict[RegLoc, None]
-
- def __len__(self):
- return len(self.__regs)
-
- def __iter__(self):
- return iter(self.__regs)
-
- def __contains__(self, v):
- # type: (RegLoc) -> bool
- return v in self.__regs
-
- def __hash__(self):
- return super()._hash()
-
@lru_cache(maxsize=None, typed=True)
def max_conflicts_with(self, other):
# type: (RegClass | RegLoc) -> int
@plain_data(frozen=True, unsafe_hash=True)
@final
-class FixedGPRRangeType(GPRRangeType):
+class FixedGPRRangeType(RegType):
__slots__ = "reg",
def __init__(self, reg):
# type: (GPRRange) -> None
- super().__init__(length=reg.length)
self.reg = reg
@property
# type: () -> RegClass
return RegClass([self.reg])
+ @property
+ def length(self):
+ # type: () -> int
+ return self.reg.length
+
@plain_data(frozen=True, unsafe_hash=True)
@final
def __hash__(self):
return hash((id(self.op), self.arg_name))
- def __repr__(self):
+ def __repr__(self, long=False):
+ if not long:
+ return f"<#{self.op.id}.{self.arg_name}>"
fields_list = []
for name in fields(self):
v = getattr(self, name, None)
@cached_property
def id(self):
+ # type: () -> int
+ # use cached_property rather than done in init so id is usable even if
+ # init hasn't run
retval = Op.__NEXT_ID
Op.__NEXT_ID += 1
return retval
+ def __init__(self):
+ self.id # initialize
+
@final
def __repr__(self, just_id=False):
fields_list = [f"#{self.id}"]
+ outputs = None
+ try:
+ outputs = self.outputs()
+ except AttributeError:
+ pass
if not just_id:
for name in fields(self):
v = getattr(self, name, _NOT_SET)
- fields_list.append(f"{name}={v!r}")
+ if ((outputs is None or name in outputs)
+ and isinstance(v, SSAVal)):
+ v = v.__repr__(long=True)
+ else:
+ v = repr(v)
+ fields_list.append(f"{name}={v}")
fields_str = ', '.join(fields_list)
return f"{self.__class__.__name__}({fields_str})"
def __init__(self, src):
# type: (SSAVal[GPRRangeType]) -> None
+ super().__init__()
self.dest = SSAVal(self, "dest", StackSlotType(src.ty.length))
self.src = src
def __init__(self, src):
# type: (SSAVal[StackSlotType]) -> None
+ super().__init__()
self.dest = SSAVal(self, "dest", GPRRangeType(src.ty.length_in_slots))
self.src = src
def __init__(self, src, dest_ty=None):
# type: (SSAVal[_RegSrcType], _RegType | None) -> None
+ super().__init__()
if dest_ty is None:
dest_ty = cast(_RegType, src.ty)
if isinstance(src.ty, GPRRangeType) \
+ and isinstance(dest_ty, FixedGPRRangeType):
+ if src.ty.length != dest_ty.reg.length:
+ raise ValueError(f"incompatible source and destination "
+ f"types: {src.ty} and {dest_ty}")
+ elif isinstance(src.ty, FixedGPRRangeType) \
and isinstance(dest_ty, GPRRangeType):
- if src.ty.length != dest_ty.length:
+ if src.ty.reg.length != dest_ty.length:
raise ValueError(f"incompatible source and destination "
f"types: {src.ty} and {dest_ty}")
elif src.ty != dest_ty:
def __init__(self, sources):
# type: (Iterable[SSAVal[GPRRangeType]]) -> None
+ super().__init__()
sources = tuple(sources)
self.dest = SSAVal(self, "dest", GPRRangeType(
sum(i.ty.length for i in sources)))
def __init__(self, src, split_indexes):
# type: (SSAVal[GPRRangeType], Iterable[int]) -> None
+ super().__init__()
ranges = [] # type: list[GPRRangeType]
last = 0
for i in split_indexes:
def __init__(self, RA, RB, CY_in, is_sub):
# type: (SSAVal[GPRRangeType], SSAVal[GPRRangeType], SSAVal[CYType], bool) -> None
+ super().__init__()
if RA.ty != RB.ty:
raise TypeError(f"source types must match: "
f"{RA} doesn't match {RB}")
def __init__(self, RA, RB, RC, is_div):
# type: (SSAVal[GPRRangeType], SSAVal[GPRType], SSAVal[GPRType], bool) -> None
+ super().__init__()
self.RT = SSAVal(self, "RT", RA.ty)
self.RA = RA
self.RB = RB
def __init__(self, inp, sh, kind):
# type: (SSAVal[GPRRangeType], SSAVal[GPRType], ShiftKind) -> None
+ super().__init__()
self.RT = SSAVal(self, "RT", inp.ty)
self.inp = inp
self.sh = sh
def __init__(self, value, length=1):
# type: (int, int) -> None
+ super().__init__()
self.out = SSAVal(self, "out", GPRRangeType(length))
self.value = value
def __init__(self):
# type: () -> None
+ super().__init__()
self.out = SSAVal(self, "out", CYType())
def __init__(self, RA, offset, mem, length=1):
# type: (SSAVal[GPRType], int, SSAVal[GlobalMemType], int) -> None
+ super().__init__()
self.RT = SSAVal(self, "RT", GPRRangeType(length))
self.RA = RA
self.offset = offset
def __init__(self, RS, RA, offset, mem_in):
# type: (SSAVal[GPRRangeType], SSAVal[GPRType], int, SSAVal[GlobalMemType]) -> None
+ super().__init__()
self.RS = RS
self.RA = RA
self.offset = offset
def __init__(self, ty):
# type: (FixedGPRRangeType) -> None
+ super().__init__()
self.out = SSAVal(self, "out", ty)
def __init__(self):
# type: () -> None
+ super().__init__()
self.out = SSAVal(self, "out", GlobalMemType())
ops_to_pending_input_count_map[op] = input_count
worklists[input_count][op] = None
retval = [] # type: list[Op]
- ready_vals = set() # type: set[SSAVal]
+ ready_vals = OSet() # type: OSet[SSAVal]
while len(worklists[0]) != 0:
writing_op = next(iter(worklists[0]))
del worklists[0][writing_op]