add SimdMap and SimdScope and XLEN
authorJacob Lifshay <programmerjake@gmail.com>
Wed, 13 Oct 2021 08:47:38 +0000 (01:47 -0700)
committerJacob Lifshay <programmerjake@gmail.com>
Wed, 13 Oct 2021 08:47:38 +0000 (01:47 -0700)
src/ieee754/part/util.py [new file with mode: 0644]

diff --git a/src/ieee754/part/util.py b/src/ieee754/part/util.py
new file mode 100644 (file)
index 0000000..8cadf8c
--- /dev/null
@@ -0,0 +1,397 @@
+# SPDX-License-Identifier: LGPL-3-or-later
+# See Notices.txt for copyright information
+
+from enum import Enum
+from typing import Mapping
+import operator
+import math
+from types import MappingProxyType
+from contextlib import contextmanager
+
+from nmigen.hdl.ast import Signal
+
+
+class ElWid(Enum):
+    def __repr__(self):
+        return super().__str__()
+
+
+class FpElWid(ElWid):
+    F64 = 0
+    F32 = 1
+    F16 = 2
+    BF16 = 3
+
+
+class IntElWid(ElWid):
+    I64 = 0
+    I32 = 1
+    I16 = 2
+    I8 = 3
+
+
+class SimdMap:
+    """A map from ElWid values to Python values.
+    SimdMap instances are immutable."""
+
+    ALL_ELWIDTHS = (*FpElWid, *IntElWid)
+    __slots__ = ("__map",)
+
+    @classmethod
+    def extract_value(cls, elwid, values, default=None):
+        """get the value for elwid.
+        if `values` is a `SimdMap` or a `Mapping`, then return the
+        corresponding value for `elwid`, recursing until finding a non-map.
+        if `values` ever ends up not existing (in the case of a map) or being
+        `None`, return `default`.
+
+        Examples:
+        SimdMap.extract_value(IntElWid.I8, 5) == 5
+        SimdMap.extract_value(IntElWid.I8, None) == None
+        SimdMap.extract_value(IntElWid.I8, None, 3) == 3
+        SimdMap.extract_value(IntElWid.I8, {}) == None
+        SimdMap.extract_value(IntElWid.I8, {IntElWid.I8: 5}) == 5
+        SimdMap.extract_value(IntElWid.I8, {
+            IntElWid.I8: {IntElWid.I8: 5},
+        }) == 5
+        SimdMap.extract_value(IntElWid.I8, {
+            IntElWid.I8: SimdMap({IntElWid.I8: 5}),
+        }) == 5
+        """
+        assert elwid in cls.ALL_ELWIDTHS
+        step = 0
+        while values is not None:
+            # specifically use base class to catch all SimdMap instances
+            if isinstance(values, SimdMap):
+                values = values.__map.get(elwid)
+            elif isinstance(values, Mapping):
+                values = values.get(elwid)
+            else:
+                return values
+            step += 1
+            # use object.__repr__ since repr() would probably recurse forever
+            assert step < 10000, (f"can't resolve infinitely recursive "
+                                  f"value {object.__repr__(values)}")
+        return default
+
+    def __init__(self, values=None):
+        """construct a SimdMap"""
+        mapping = {}
+        for elwid in self.ALL_ELWIDTHS:
+            v = self.extract_value(elwid, values)
+            if v is not None:
+                mapping[elwid] = v
+        self.__map = MappingProxyType(mapping)
+
+    @property
+    def mapping(self):
+        """the values as a read-only Mapping[ElWid, Any]"""
+        return self.__map
+
+    def values(self):
+        return self.__map.values()
+
+    def keys(self):
+        return self.__map.keys()
+
+    def items(self):
+        return self.__map.items()
+
+    @classmethod
+    def map_with_elwid(cls, f, *args):
+        """get the SimdMap of the results of calling
+        `f(elwid, value1, value2, value3, ...)` where
+        `value1`, `value2`, `value3`, ... are the results of calling
+        `cls.extract_value` on each `args`.
+
+        This is similar to Python's built-in `map` function.
+
+        Examples:
+        SimdMap.map_with_elwid(lambda elwid, a: a + 1, {IntElWid.I32: 5}) ==
+            SimdMap({IntElWid.I32: 6})
+        SimdMap.map_with_elwid(lambda elwid, a: a + 1, 3) ==
+            SimdMap({IntElWid.I8: 4, IntElWid.I16: 4, ...})
+        SimdMap.map_with_elwid(lambda elwid, a, b: a + b,
+            3, {IntElWid.I8: 4},
+        ) == SimdMap({IntElWid.I8: 7})
+        SimdMap.map_with_elwid(lambda elwid: elwid.name) ==
+            SimdMap({IntElWid.I8: "I8", IntElWid.I16: "I16"})
+        """
+        retval = {}
+        for elwid in cls.ALL_ELWIDTHS:
+            extracted_args = [cls.extract_value(elwid, arg) for arg in args]
+            if None not in extracted_args:
+                retval[elwid] = f(elwid, *extracted_args)
+        return cls(retval)
+
+    @classmethod
+    def map(cls, f, *args):
+        """get the SimdMap of the results of calling
+        `f(value1, value2, value3, ...)` where
+        `value1`, `value2`, `value3`, ... are the results of calling
+        `cls.extract_value` on each `args`.
+
+        This is similar to Python's built-in `map` function.
+
+        Examples:
+        SimdMap.map(lambda a: a + 1, {IntElWid.I32: 5}) ==
+            SimdMap({IntElWid.I32: 6})
+        SimdMap.map(lambda a: a + 1, 3) ==
+            SimdMap({IntElWid.I8: 4, IntElWid.I16: 4, ...})
+        SimdMap.map(lambda a, b: a + b,
+            3, {IntElWid.I8: 4},
+        ) == SimdMap({IntElWid.I8: 7})
+        """
+        return cls.map_with_elwid(lambda elwid, *args2: f(*args2), *args)
+
+    def get(self, elwid, default=None, *, raise_key_error=False):
+        if raise_key_error:
+            retval = self.extract_value(elwid, self)
+            if retval is None:
+                raise KeyError()
+            return retval
+        return self.extract_value(elwid, self, default)
+
+    def __iter__(self):
+        """return an iterator of (elwid, value) pairs"""
+        return self.__map.items()
+
+    def __add__(self, other):
+        return self.map(operator.add, self, other)
+
+    def __radd__(self, other):
+        return self.map(operator.add, other, self)
+
+    def __sub__(self, other):
+        return self.map(operator.sub, self, other)
+
+    def __rsub__(self, other):
+        return self.map(operator.sub, other, self)
+
+    def __mul__(self, other):
+        return self.map(operator.mul, self, other)
+
+    def __rmul__(self, other):
+        return self.map(operator.mul, other, self)
+
+    def __floordiv__(self, other):
+        return self.map(operator.floordiv, self, other)
+
+    def __rfloordiv__(self, other):
+        return self.map(operator.floordiv, other, self)
+
+    def __truediv__(self, other):
+        return self.map(operator.truediv, self, other)
+
+    def __rtruediv__(self, other):
+        return self.map(operator.truediv, other, self)
+
+    def __mod__(self, other):
+        return self.map(operator.mod, self, other)
+
+    def __rmod__(self, other):
+        return self.map(operator.mod, other, self)
+
+    def __abs__(self):
+        return self.map(abs, self)
+
+    def __and__(self, other):
+        return self.map(operator.and_, self, other)
+
+    def __rand__(self, other):
+        return self.map(operator.and_, other, self)
+
+    def __divmod__(self, other):
+        return self.map(divmod, self, other)
+
+    def __ceil__(self):
+        return self.map(math.ceil, self)
+
+    def __float__(self):
+        return self.map(float, self)
+
+    def __floor__(self):
+        return self.map(math.floor, self)
+
+    def __eq__(self, other):
+        if isinstance(other, SimdMap):
+            return self.mapping == other.mapping
+        return NotImplemented
+
+    def __hash__(self):
+        return hash(tuple(self.mapping.get(i) for i in self.ALL_ELWIDTHS))
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}({dict(self.mapping)})"
+
+    def __invert__(self):
+        return self.map(operator.invert, self)
+
+    def __lshift__(self, other):
+        return self.map(operator.lshift, self, other)
+
+    def __rlshift__(self, other):
+        return self.map(operator.lshift, other, self)
+
+    def __rshift__(self, other):
+        return self.map(operator.rshift, self, other)
+
+    def __rrshift__(self, other):
+        return self.map(operator.rshift, other, self)
+
+    def __neg__(self):
+        return self.map(operator.neg, self)
+
+    def __pos__(self):
+        return self.map(operator.pos, self)
+
+    def __or__(self, other):
+        return self.map(operator.or_, self, other)
+
+    def __ror__(self, other):
+        return self.map(operator.or_, other, self)
+
+    def __xor__(self, other):
+        return self.map(operator.xor, self, other)
+
+    def __rxor__(self, other):
+        return self.map(operator.xor, other, self)
+
+    def missing_elwidths(self, *, all_elwidths=None):
+        """an iterator of the elwidths where self doesn't have a corresponding
+        value"""
+        if all_elwidths is None:
+            all_elwidths = self.ALL_ELWIDTHS
+        for elwid in all_elwidths:
+            if elwid not in self.keys():
+                yield elwid
+
+
+def _check_for_missing_elwidths(name, all_elwidths=None):
+    missing = list(globals()[name].missing_elwidths(all_elwidths=all_elwidths))
+    assert missing == [], f"{name} is missing entries for {missing}"
+
+
+XLEN = SimdMap({
+    IntElWid.I64: 64,
+    IntElWid.I32: 32,
+    IntElWid.I16: 16,
+    IntElWid.I8: 8,
+    FpElWid.F64: 64,
+    FpElWid.F32: 32,
+    FpElWid.F16: 16,
+    FpElWid.BF16: 16,
+})
+
+DEFAULT_FP_PART_COUNTS = SimdMap({
+    FpElWid.F64: 4,
+    FpElWid.F32: 2,
+    FpElWid.F16: 1,
+    FpElWid.BF16: 1,
+})
+
+DEFAULT_INT_PART_COUNTS = SimdMap({
+    IntElWid.I64: 8,
+    IntElWid.I32: 4,
+    IntElWid.I16: 2,
+    IntElWid.I8: 1,
+})
+
+_check_for_missing_elwidths("XLEN")
+_check_for_missing_elwidths("DEFAULT_FP_PART_COUNTS", FpElWid)
+_check_for_missing_elwidths("DEFAULT_INT_PART_COUNTS", IntElWid)
+
+
+class SimdScope:
+    """The global scope object for SimdSignal and friends
+
+    Members:
+    * part_counts: SimdMap
+        a map from `ElWid` values `k` to the number of parts in an element
+        when `self.elwid == k`. Values should be minimized, since higher values
+        often create bigger circuits.
+
+        Example:
+        # here, an I8 element is 1 part wide
+        part_counts = {ElWid.I8: 1, ElWid.I16: 2, ElWid.I32: 4, ElWid.I64: 8}
+
+        Another Example:
+        # here, an F16 element is 1 part wide
+        part_counts = {ElWid.F16: 1, ElWid.BF16: 1, ElWid.F32: 2, ElWid.F64: 4}
+    * simd_full_width_hint: int
+        the default value for SimdLayout's full_width argument, the full number
+        of bits in a SIMD value.
+    * elwid: ElWid or nmigen Value with a shape of some ElWid class
+        the current elwid (simd element type)
+    """
+
+    __SCOPE_STACK = []
+
+    @classmethod
+    def get(cls):
+        """get the current SimdScope.
+
+        Example:
+        SimdScope.get(None) is None
+        SimdScope.get() raises ValueError
+        with SimdScope(...) as s:
+            SimdScope.get() is s
+        """
+        if len(cls.__SCOPE_STACK) > 0:
+            retval = cls.__SCOPE_STACK[-1]
+            assert isinstance(retval, SimdScope), "inconsistent scope stack"
+            return retval
+        raise ValueError("not in a `with SimdScope()` statement")
+
+    def __enter__(self):
+        self.__SCOPE_STACK.append(self)
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        assert self.__SCOPE_STACK.pop() is self, "inconsistent scope stack"
+        return False
+
+    def __init__(self, *, simd_full_width_hint=64, elwid=None,
+                 part_counts=None, elwid_type=IntElWid, scalar=False):
+        # TODO: add more arguments/members and processing for integration with
+        self.simd_full_width_hint = simd_full_width_hint
+        if isinstance(elwid, (IntElWid, FpElWid)):
+            elwid_type = type(elwid)
+            if part_counts is None:
+                part_counts = SimdMap({elwid: 1})
+        assert issubclass(elwid_type, (IntElWid, FpElWid))
+        self.elwid_type = elwid_type
+        scalar_elwid = elwid_type(0)
+        if part_counts is None:
+            if scalar:
+                part_counts = SimdMap({scalar_elwid: 1})
+            elif issubclass(elwid_type, FpElWid):
+                part_counts = DEFAULT_FP_PART_COUNTS
+            else:
+                part_counts = DEFAULT_INT_PART_COUNTS
+
+        def check(elwid, part_count):
+            assert type(elwid) == elwid_type, "inconsistent ElWid types"
+            part_count = int(part_count)
+            assert part_count != 0 and (part_count & (part_count - 1)) == 0,\
+                "part_counts values must all be powers of two"
+            return part_count
+
+        self.part_counts = SimdMap.map_with_elwid(check, part_counts)
+        self.full_part_count = max(part_counts.values())
+        assert self.simd_full_width_hint % self.full_part_count == 0,\
+            "simd_full_width_hint must be a multiple of full_part_count"
+        if elwid is not None:
+            self.elwid = elwid
+        elif scalar:
+            self.elwid = scalar_elwid
+        else:
+            self.elwid = Signal(elwid_type)
+
+    def __repr__(self):
+        return (f"SimdScope(\n"
+                f"        simd_full_width_hint={self.simd_full_width_hint},\n"
+                f"        elwid={self.elwid},\n"
+                f"        elwid_type={self.elwid_type},\n"
+                f"        part_counts={self.part_counts},\n"
+                f"        full_part_count={self.full_part_count})")