implement more of new compiler ir
authorJacob Lifshay <programmerjake@gmail.com>
Sun, 30 Oct 2022 09:20:48 +0000 (02:20 -0700)
committerJacob Lifshay <programmerjake@gmail.com>
Sun, 30 Oct 2022 09:20:48 +0000 (02:20 -0700)
src/bigint_presentation_code/compiler_ir2.py

index 3e2ad0c15a5e1effa184cd67916ab8cf660b5d61..b04848b53a37be825fc16d2d07692edf5a01931b 100644 (file)
@@ -1,16 +1,16 @@
-from collections import defaultdict
 import enum
+from abc import abstractmethod
 from enum import Enum, unique
-from typing import AbstractSet, Any, Iterable, Iterator, NoReturn, Tuple, Union, Mapping, overload
+from functools import lru_cache
+from typing import (AbstractSet, Any, Generic, Iterable, Iterator, Sequence,
+                    TypeVar, overload)
 from weakref import WeakValueDictionary as _WeakVDict
 
 from cached_property import cached_property
 from nmutil.plain_data import plain_data
 
 from bigint_presentation_code.type_util import Self, assert_never, final
-from bigint_presentation_code.util import (BaseBitSet, BitSet, FBitSet, OFSet,
-                                           OSet, FMap)
-from functools import lru_cache
+from bigint_presentation_code.util import BitSet, FBitSet, FMap, OFSet
 
 
 @final
@@ -461,11 +461,11 @@ class GenericOperandDesc:
                 return
             for sub_kind in self.sub_kinds:
                 yield from sub_kind.allocatable_locs(ty)
-        loc_set = LocSet(locs())
+        loc_set_before_spread = LocSet(locs())
         for idx in range(rep_count):
             if not self.spread:
                 idx = None
-            yield OperandDesc(loc_set=loc_set,
+            yield OperandDesc(loc_set_before_spread=loc_set_before_spread,
                               tied_input_index=self.tied_input_index,
                               spread_index=idx)
 
@@ -474,16 +474,33 @@ class GenericOperandDesc:
 @final
 class OperandDesc:
     """Op operand descriptor"""
-    __slots__ = "loc_set", "tied_input_index", "spread_index"
+    __slots__ = "loc_set_before_spread", "tied_input_index", "spread_index"
 
-    def __init__(self, loc_set, tied_input_index, spread_index):
+    def __init__(self, loc_set_before_spread, tied_input_index, spread_index):
         # type: (LocSet, int | None, int | None) -> None
-        if len(loc_set) == 0:
-            raise ValueError("loc_set must not be empty")
-        self.loc_set = loc_set
+        if len(loc_set_before_spread) == 0:
+            raise ValueError("loc_set_before_spread must not be empty")
+        self.loc_set_before_spread = loc_set_before_spread
         self.tied_input_index = tied_input_index
         if self.tied_input_index is not None and self.spread_index is not None:
             raise ValueError("operand can't be both spread and tied")
+        self.spread_index = spread_index
+
+    @cached_property
+    def ty_before_spread(self):
+        # type: () -> Ty
+        ty = self.loc_set_before_spread.ty
+        assert ty is not None, (
+            "__init__ checked that the LocSet isn't empty, "
+            "non-empty LocSets should always have ty set")
+        return ty
+
+    @cached_property
+    def ty(self):
+        """ Ty after any spread is applied """
+        if self.spread_index is not None:
+            return Ty(base_ty=self.ty_before_spread.base_ty, reg_len=1)
+        return self.ty_before_spread
 
 
 OD_BASE_SGPR = GenericOperandDesc(
@@ -554,7 +571,7 @@ class GenericOpProperties:
 @plain_data(frozen=True, unsafe_hash=True)
 @final
 class OpProperties:
-    __slots__ = "kind", "inputs", "outputs"
+    __slots__ = "kind", "inputs", "outputs", "maxvl"
 
     def __init__(self, kind, maxvl):
         # type: (OpKind, int) -> None
@@ -567,6 +584,7 @@ class OpProperties:
         for out in self.generic.outputs:
             outputs.extend(out.instantiate(maxvl=maxvl))
         self.outputs = tuple(outputs)
+        self.maxvl = maxvl
 
     @property
     def generic(self):
@@ -715,137 +733,186 @@ class OpKind(Enum):
     )
 
 
-# FIXME: rewrite from here
-
-
 @plain_data(frozen=True, unsafe_hash=True, repr=False)
 @final
 class SSAVal:
-    __slots__ = "sliced_op_outputs",
+    __slots__ = "op", "output_idx"
 
-    _SlicedOpOutputIn = Union["tuple[Op, int, int | range | slice]",
-                              "tuple[Op, int]", "SSAVal"]
+    def __init__(self, op, output_idx):
+        # type: (Op, int) -> None
+        self.op = op
+        if output_idx < 0 or output_idx >= len(op.properties.outputs):
+            raise ValueError("invalid output_idx")
+        self.output_idx = output_idx
 
-    @staticmethod
-    def __process_sliced_op_outputs(inp):
-        # type: (Iterable[_SlicedOpOutputIn]) -> Iterable[Tuple["Op", int, range]]
-        for v in inp:
-            if isinstance(v, SSAVal):
-                yield from v.sliced_op_outputs
-                continue
-            op = v[0]
-            output_index = v[1]
-            if output_index < 0 or output_index >= len(op.properties.outputs):
-                raise ValueError("invalid output_index")
-            cur_len = op.properties.outputs[output_index].get_length(op.maxvl)
-            slice_ = slice(None) if len(v) == 2 else v[2]
-            if isinstance(slice_, range):
-                slice_ = slice(slice_.start, slice_.stop, slice_.step)
-            if isinstance(slice_, int):
-                # raise exception for out-of-range values
-                idx = range(cur_len)[slice_]
-                range_ = range(idx, idx + 1)
-            else:
-                # raise exception for out-of-range values
-                range_ = range(cur_len)[slice_]
-                if range_.step != 1:
-                    raise ValueError("slice step must be 1")
-                if len(range_) == 0:
-                    continue
-            yield op, output_index, range_
+    def __repr__(self):
+        # type: () -> str
+        return f"<{self.op.name}#{self.output_idx}>"
 
-    def __init__(self, sliced_op_outputs):
-        # type: (Iterable[_SlicedOpOutputIn] | SSAVal) -> None
-        # we have length arg so plain_data.replace works
-        if isinstance(sliced_op_outputs, SSAVal):
-            inp = sliced_op_outputs.sliced_op_outputs
-        else:
-            inp = SSAVal.__process_sliced_op_outputs(sliced_op_outputs)
-        processed = []  # type: list[tuple[Op, int, range]]
-        length = 0
-        for op, output_index, range_ in inp:
-            length += len(range_)
-            if len(processed) == 0:
-                processed.append((op, output_index, range_))
-                continue
-            last_op, last_output_index, last_range_ = processed[-1]
-            if last_op == op and last_output_index == output_index \
-                    and last_range_.stop == range_.start:
-                # merge slices
-                range_ = range(last_range_.start, range_.stop)
-                processed[-1] = op, output_index, range_
-            else:
-                processed.append((op, output_index, range_))
-        self.sliced_op_outputs = tuple(processed)
-
-    def __add__(self, other):
-        # type: (SSAVal | Any) -> SSAVal
-        if not isinstance(other, SSAVal):
-            return NotImplemented
-        return SSAVal(self.sliced_op_outputs + other.sliced_op_outputs)
-
-    def __radd__(self, other):
-        # type: (SSAVal | Any) -> SSAVal
-        if isinstance(other, SSAVal):
-            return other.__add__(self)
-        return NotImplemented
+    @cached_property
+    def defining_descriptor(self):
+        # type: () -> OperandDesc
+        return self.op.properties.outputs[self.output_idx]
 
     @cached_property
-    def expanded_sliced_op_outputs(self):
-        # type: () -> tuple[tuple[Op, int, int], ...]
-        retval = []  # type: list[tuple[Op, int, int]]
-        for op, output_index, range_ in self.sliced_op_outputs:
-            for i in range_:
-                retval.append((op, output_index, i))
-        # must be tuple to not be modifiable since it's cached
-        return tuple(retval)
+    def loc_set_before_spread(self):
+        # type: () -> LocSet
+        return self.defining_descriptor.loc_set_before_spread
 
+    @cached_property
+    def ty(self):
+        # type: () -> Ty
+        return self.defining_descriptor.ty
+
+    @cached_property
+    def ty_before_spread(self):
+        # type: () -> Ty
+        return self.defining_descriptor.ty_before_spread
+
+
+_T = TypeVar("_T")
+_Desc = TypeVar("_Desc")
+
+
+class OpInputSeq(Sequence[_T], Generic[_T, _Desc]):
+    @abstractmethod
+    def _verify_write_with_desc(self, idx, item, desc):
+        # type: (int, _T | Any, _Desc) -> None
+        raise NotImplementedError
+
+    @final
+    def _verify_write(self, idx, item):
+        # type: (int | Any, _T | Any) -> int
+        if not isinstance(idx, int):
+            if isinstance(idx, slice):
+                raise TypeError(
+                    f"can't write to slice of {self.__class__.__name__}")
+            raise TypeError(f"can't write with index {idx!r}")
+        # normalize idx, raising IndexError if it is out of range
+        idx = range(len(self.descriptors))[idx]
+        desc = self.descriptors[idx]
+        self._verify_write_with_desc(idx, item, desc)
+        return idx
+
+    @abstractmethod
+    def _get_descriptors(self):
+        # type: () -> tuple[_Desc, ...]
+        raise NotImplementedError
+
+    @cached_property
+    @final
+    def descriptors(self):
+        # type: () -> tuple[_Desc, ...]
+        return self._get_descriptors()
+
+    @property
+    @final
+    def op(self):
+        return self.__op
+
+    def __init__(self, items, op):
+        # type: (Iterable[_T], Op) -> None
+        self.__op = op
+        self.__items = []  # type: list[_T]
+        for idx, item in enumerate(items):
+            if idx >= len(self.descriptors):
+                raise ValueError("too many items")
+            self._verify_write(idx, item)
+            self.__items.append(item)
+        if len(self.__items) < len(self.descriptors):
+            raise ValueError("not enough items")
+
+    @final
+    def __iter__(self):
+        # type: () -> Iterator[_T]
+        yield from self.__items
+
+    @overload
     def __getitem__(self, idx):
-        # type: (int | slice) -> SSAVal
-        if isinstance(idx, int):
-            return SSAVal([self.expanded_sliced_op_outputs[idx]])
-        return SSAVal(self.expanded_sliced_op_outputs[idx])
+        # type: (int) -> _T
+        ...
 
+    @overload
+    def __getitem__(self, idx):
+        # type: (slice) -> list[_T]
+        ...
+
+    @final
+    def __getitem__(self, idx):
+        # type: (int | slice) -> _T | list[_T]
+        return self.__items[idx]
+
+    @final
+    def __setitem__(self, idx, item):
+        # type: (int, _T) -> None
+        idx = self._verify_write(idx, item)
+        self.__items[idx] = item
+
+    @final
     def __len__(self):
-        return len(self.expanded_sliced_op_outputs)
+        # type: () -> int
+        return len(self.__items)
 
-    def __iter__(self):
-        # type: () -> Iterator[SSAVal]
-        for v in self.expanded_sliced_op_outputs:
-            yield SSAVal([v])
 
-    def __repr__(self):
-        # type: () -> str
-        if len(self.sliced_op_outputs) == 0:
-            return "SSAVal([])"
-        parts = []  # type: list[str]
-        for op, output_index, range_ in self.sliced_op_outputs:
-            out_len = op.properties.outputs[output_index].get_length(op.maxvl)
-            parts.append(f"<{op.name}#{output_index}>")
-            if range_ != range(out_len):
-                parts[-1] += f"[{range_.start}:{range_.stop}]"
-        return " + ".join(parts)
+@final
+class OpInputs(OpInputSeq[SSAVal, OperandDesc]):
+    def _get_descriptors(self):
+        # type: () -> tuple[OperandDesc, ...]
+        return self.op.properties.inputs
+
+    def _verify_write_with_desc(self, idx, item, desc):
+        # type: (int, SSAVal | Any, OperandDesc) -> None
+        if not isinstance(item, SSAVal):
+            raise TypeError("expected value of type SSAVal")
+        if item.ty != desc.ty:
+            raise ValueError(f"assigned item's type {item.ty!r} doesn't match "
+                             f"corresponding input's type {desc.ty!r}")
+
+    def __init__(self, items, op):
+        # type: (Iterable[SSAVal], Op) -> None
+        if hasattr(op, "inputs"):
+            raise ValueError("Op.inputs already set")
+        super().__init__(items, op)
+
+
+@final
+class OpImmediates(OpInputSeq[int, range]):
+    def _get_descriptors(self):
+        # type: () -> tuple[range, ...]
+        return self.op.properties.immediates
+
+    def _verify_write_with_desc(self, idx, item, desc):
+        # type: (int, int | Any, range) -> None
+        if not isinstance(item, int):
+            raise TypeError("expected value of type int")
+        if item not in desc:
+            raise ValueError(f"immediate value {item!r} not in {desc!r}")
+
+    def __init__(self, items, op):
+        # type: (Iterable[int], Op) -> None
+        if hasattr(op, "immediates"):
+            raise ValueError("Op.immediates already set")
+        super().__init__(items, op)
 
 
 @plain_data(frozen=True, eq=False)
 @final
 class Op:
-    __slots__ = "fn", "kind", "inputs", "immediates", "outputs", "maxvl", "name"
+    __slots__ = "fn", "properties", "inputs", "immediates", "outputs", "name"
 
-    def __init__(self, fn, kind, inputs, immediates, maxvl, name=""):
-        # type: (Fn, OpKind, Iterable[SSAVal], Iterable[int], int, str) -> None
+    def __init__(self, fn, properties, inputs, immediates, name=""):
+        # type: (Fn, OpProperties, Iterable[SSAVal], Iterable[int], str) -> None
         self.fn = fn
-        self.kind = kind
-        self.inputs = list(inputs)
-        self.immediates = list(immediates)
-        self.maxvl = maxvl
+        self.properties = properties
+        self.inputs = OpInputs(inputs, op=self)
+        self.immediates = OpImmediates(immediates, op=self)
         outputs_len = len(self.properties.outputs)
-        self.outputs = tuple(SSAVal([(self, i)]) for i in range(outputs_len))
+        self.outputs = tuple(SSAVal(self, i) for i in range(outputs_len))
         self.name = fn._add_op_with_unused_name(self, name)  # type: ignore
 
     @property
-    def properties(self):
-        return self.kind.properties
+    def kind(self):
+        return self.properties.kind
 
     def __eq__(self, other):
         # type: (Op | Any) -> bool