From 98b4968397c921fc1e104543240e4be5b29e1615 Mon Sep 17 00:00:00 2001 From: Jacob Lifshay Date: Wed, 27 Oct 2021 02:10:24 -0700 Subject: [PATCH] add SimdWHintMap to support tracking width_hint for XLEN --- src/ieee754/part/util.py | 109 ++++++++++++++++++++++++++++++++------ src/ieee754/part/util.pyi | 105 ++++++++++++++++++++++++++++++++++-- 2 files changed, 193 insertions(+), 21 deletions(-) diff --git a/src/ieee754/part/util.py b/src/ieee754/part/util.py index 6a9bcb13..bdb33279 100644 --- a/src/ieee754/part/util.py +++ b/src/ieee754/part/util.py @@ -34,6 +34,23 @@ class SimdMap: ALL_ELWIDTHS = (*FpElWid, *IntElWid) __slots__ = ("__map",) + @staticmethod + def extract_value_algo(values, default=None, *, simd_map_get, mapping_get): + step = 0 + while values is not None: + # specifically use base class to catch all SimdMap instances + if isinstance(values, SimdMap): + values = simd_map_get(values) + elif isinstance(values, Mapping): + values = mapping_get(values) + else: + return values + step += 1 + # use object.__repr__ since repr() would probably recurse forever + assert step < 10000, (f"can't resolve infinitely recursive " + f"value {object.__repr__(values)}") + return default + @classmethod def extract_value(cls, elwid, values, default=None): """get the value for elwid. @@ -56,20 +73,10 @@ class SimdMap: }) == 5 """ assert elwid in cls.ALL_ELWIDTHS - step = 0 - while values is not None: - # specifically use base class to catch all SimdMap instances - if isinstance(values, SimdMap): - values = values.__map.get(elwid) - elif isinstance(values, Mapping): - values = values.get(elwid) - else: - return values - step += 1 - # use object.__repr__ since repr() would probably recurse forever - assert step < 10000, (f"can't resolve infinitely recursive " - f"value {object.__repr__(values)}") - return default + return SimdMap.extract_value_algo( + values, default, + simd_map_get=lambda v: v.__map.get(elwid), + mapping_get=lambda v: v.get(elwid)) def __init__(self, values=None): """construct a SimdMap""" @@ -151,7 +158,7 @@ class SimdMap: def __iter__(self): """return an iterator of (elwid, value) pairs""" - return self.__map.items() + return iter(self.__map.items()) def __add__(self, other): return self.map(operator.add, self, other) @@ -264,12 +271,80 @@ class SimdMap: yield elwid +class SimdWHintMap(SimdMap): + """SimdMap with a width hint.""" + + __slots__ = ("__width_hint",) + + @classmethod + def extract_width_hint(cls, values, default=None): + """get the value for width hint.""" + def simd_map_get(v): + return v.width_hint if isinstance(v, SimdWHintMap) else None + return SimdMap.extract_value_algo( + values, default, + simd_map_get=simd_map_get, + mapping_get=lambda _: None) + + def __init__(self, values=None, *, width_hint=None): + """construct a SimdWHintMap""" + super().__init__(values) + if width_hint is None: + width_hint = values + + self.__width_hint = self.extract_width_hint(width_hint) + + @property + def width_hint(self): + return self.__width_hint + + def __eq__(self, other): + if isinstance(other, SimdMap): + return self.mapping == other.mapping \ + and self.width_hint == self.extract_width_hint(other) + return NotImplemented + + def __hash__(self): + if self.width_hint is None: + return super().__hash__() + return hash((tuple(self.mapping.get(i) for i in self.ALL_ELWIDTHS), + self.width_hint)) + + def __repr__(self): + wh = "" + if self.width_hint is not None: + wh = f", width_hint={self.width_hint!r}" + return f"{self.__class__.__name__}({dict(self.mapping)}{wh})" + + @classmethod + def map(cls, f, *args): + """get the SimdWHintMap of the results of calling + `f(value1, value2, value3, ...)` where + `value1`, `value2`, `value3`, ... are the results of calling + `cls.extract_value` on each `args`. + """ + retval = {} + for elwid in cls.ALL_ELWIDTHS: + extracted_args = [cls.extract_value(elwid, arg) for arg in args] + if None not in extracted_args: + retval[elwid] = f(*extracted_args) + width_hint = None + try: + extracted_args = [cls.extract_width_hint(arg) for arg in args] + if None not in extracted_args: + width_hint = f(*extracted_args) + except (ArithmeticError, LookupError, ValueError): + # ignore some errors and just clear width_hint + pass + return cls(retval, width_hint=width_hint) + + def _check_for_missing_elwidths(name, all_elwidths=None): missing = list(globals()[name].missing_elwidths(all_elwidths=all_elwidths)) assert missing == [], f"{name} is missing entries for {missing}" -XLEN = SimdMap({ +XLEN = SimdWHintMap({ IntElWid.I64: 64, IntElWid.I32: 32, IntElWid.I16: 16, @@ -278,7 +353,7 @@ XLEN = SimdMap({ FpElWid.F32: 32, FpElWid.F16: 16, FpElWid.BF16: 16, -}) +}, width_hint=64) DEFAULT_FP_VEC_EL_COUNTS = SimdMap({ FpElWid.F64: 1, diff --git a/src/ieee754/part/util.pyi b/src/ieee754/part/util.pyi index 97486600..5d2189bc 100644 --- a/src/ieee754/part/util.pyi +++ b/src/ieee754/part/util.pyi @@ -3,8 +3,8 @@ from enum import Enum from typing import (Any, Callable, ClassVar, Generic, ItemsView, Iterable, - KeysView, Literal, Mapping, Optional, Tuple, TypeVar, - Union, ValuesView, overload) + Iterator, KeysView, Literal, Mapping, Optional, Tuple, + TypeVar, Union, ValuesView, overload) class ElWid(Enum): @@ -37,6 +37,38 @@ class SimdMap(Generic[_T]): __map: Mapping[_ElWid, _T] + @overload + @staticmethod + def extract_value_algo(values: None, + default: _T2 = None, *, + simd_map_get: Callable[["SimdMap[_T]"], _T], + mapping_get: Callable[[Mapping[_ElWid, _T]], _T], + ) -> _T2: ... + + @overload + @staticmethod + def extract_value_algo(values: SimdMap[_T], + default: _T2 = None, *, + simd_map_get: Callable[["SimdMap[_T]"], _T], + mapping_get: Callable[[Mapping[_ElWid, _T]], _T], + ) -> Union[_T, _T2]: ... + + @overload + @staticmethod + def extract_value_algo(values: Mapping[_ElWid, _T], + default: _T2 = None, *, + simd_map_get: Callable[["SimdMap[_T]"], _T], + mapping_get: Callable[[Mapping[_ElWid, _T]], _T], + ) -> Union[_T, _T2]: ... + + @overload + @staticmethod + def extract_value_algo(values: _T, + default: _T2 = None, *, + simd_map_get: Callable[["SimdMap[_T]"], _T], + mapping_get: Callable[[Mapping[_ElWid, _T]], _T], + ) -> Union[_T, _T2]: ... + @overload @classmethod def extract_value(cls, @@ -803,7 +835,7 @@ class SimdMap(Generic[_T]): def get(self, elwid: _ElWid, default: _T2 = None, *, raise_key_error: bool = False) -> Union[_T, _T2]: ... - def __iter__(self) -> Iterable[Tuple[_ElWid, _T]]: ... + def __iter__(self) -> Iterator[Tuple[_ElWid, _T]]: ... @overload def __add__(self, other: SimdMap[_T]) -> SimdMap[_T]: ... @@ -1109,7 +1141,72 @@ class SimdMap(Generic[_T]): ) -> Iterable[_ElWid]: ... -XLEN: SimdMap[int] = ... +class SimdWHintMap(SimdMap[_T]): + @overload + @classmethod + def extract_width_hint(cls, + values: None, + default: _T2 = None) -> _T2: ... + + @overload + @classmethod + def extract_width_hint(cls, + values: SimdMap[_T], + default: _T2 = None) -> Union[_T, _T2]: ... + + @overload + @classmethod + def extract_width_hint(cls, + values: Mapping[_ElWid, _T], + default: _T2 = None) -> Union[_T, _T2]: ... + + @overload + @classmethod + def extract_width_hint(cls, + values: _T, + default: _T2 = None) -> Union[_T, _T2]: ... + + @overload + def __init__(self, values: Optional[SimdMap[_T]] = None, *, + width_hint: Optional[SimdMap[_T]] = None): ... + + @overload + def __init__(self, values: Optional[Mapping[_ElWid, _T]] = None, *, + width_hint: Optional[SimdMap[_T]] = None): ... + + @overload + def __init__(self, values: Optional[_T] = None, *, + width_hint: Optional[SimdMap[_T]] = None): ... + + @overload + def __init__(self, values: Optional[SimdMap[_T]] = None, *, + width_hint: Optional[Mapping[_ElWid, _T]] = None): ... + + @overload + def __init__(self, values: Optional[Mapping[_ElWid, _T]] = None, *, + width_hint: Optional[Mapping[_ElWid, _T]] = None): ... + + @overload + def __init__(self, values: Optional[_T] = None, *, + width_hint: Optional[Mapping[_ElWid, _T]] = None): ... + + @overload + def __init__(self, values: Optional[SimdMap[_T]] = None, *, + width_hint: Optional[_T] = None): ... + + @overload + def __init__(self, values: Optional[Mapping[_ElWid, _T]] = None, *, + width_hint: Optional[_T] = None): ... + + @overload + def __init__(self, values: Optional[_T] = None, *, + width_hint: Optional[_T] = None): ... + + @property + def width_hint(self) -> _T: ... + + +XLEN: SimdWHintMap[int] = ... DEFAULT_FP_VEC_EL_COUNTS: SimdMap[int] = ... -- 2.30.2