# type: () -> int
return self.value_range.output_size
+ @property
+ def current_debugging_value(self):
+ # type: () -> tuple[int, ...]
+ """Get the current value, for debugging in pdb or similar.
+
+ Forwards to `self.output.current_debugging_value`.
+
+ This is intended for use with
+ `PreRASimState.set_current_debugging_state`.
+
+ This is only intended for debugging; do not use it in unit tests or
+ production code.
+ """
+ return self.output.current_debugging_value
+
@plain_data(frozen=True, unsafe_hash=True)
@final
if self.min_value > self.max_value:
raise ValueError("invalid value range")
+ @property
+ def current_debugging_value(self):
+ # type: () -> tuple[int, ...]
+ """Get the current value, for debugging in pdb or similar.
+
+ Forwards to `self.ssa_val.current_debugging_value`.
+
+ This is intended for use with
+ `PreRASimState.set_current_debugging_state`.
+
+ This is only intended for debugging; do not use it in unit tests or
+ production code.
+ """
+ return self.ssa_val.current_debugging_value
+
@plain_data(frozen=True)
@final
def split_into_exact_sized_parts(fn, ssa_val, part_count, part_size, name):
- # type: (Fn, SSAVal, int, int, str) -> list[SSAVal]
+ # type: (Fn, SSAVal, int, int, str) -> tuple[SSAVal, ...]
"""split ssa_val into part_count parts, where all but the last part have
`part.ty.reg_len == part_size`.
"""
if part_count <= 0:
raise ValueError("invalid part count, must be positive")
if part_count == 1:
- return [ssa_val]
+ return (ssa_val,)
too_short_reg_len = (part_count - 1) * part_size
if ssa_val.ty.reg_len <= too_short_reg_len:
raise ValueError(f"ssa_val is too short to split, must have "
input_vals=[*spread.outputs[start:stop], part_setvl.outputs[0]],
name=f"{name}_{part}_concat", maxvl=part_maxvl)
retval.append(concat.outputs[0])
- return retval
+ return tuple(retval)
-__TCIs = Tuple[ToomCookInstance, ...]
+# Tuple of `ToomCookInstance`s, the type of the `instances` parameters below.
+# NOTE(review): presumably renamed from `__TCIs` because a double-underscore
+# name is name-mangled when referenced inside a class body -- confirm.
+_TCIs = Tuple[ToomCookInstance, ...]
+
+
+@plain_data(frozen=True)
+@final
+class ToomCookMul:
+    """Builds the IR for one multiplication and records every intermediate.
+
+    Constructing an instance appends the ops for `lhs * rhs` to `fn`; the
+    final product is available as `retval`.  All intermediate values are
+    kept as attributes so they can be inspected afterwards.
+
+    The instance used is the first entry of `instances`, starting at
+    `start_instance_index`, for which the computed `part_size` is positive.
+    When there is no such entry, `instance` is `None`, the product is
+    produced by `simple_mul`, and every intermediate attribute holds an
+    empty container or `None`.
+    """
+    __slots__ = (
+        "fn", "lhs", "lhs_signed", "rhs", "rhs_signed", "instances",
+        "retval_size", "start_instance_index", "instance", "part_size",
+        "lhs_parts", "lhs_inputs", "lhs_eval_state", "lhs_outputs",
+        "rhs_parts", "rhs_inputs", "rhs_eval_state", "rhs_outputs",
+        "prod_inputs", "prod_eval_state", "prod_parts",
+        "partial_products", "retval",
+    )
+
+    def __init__(self, fn, lhs, lhs_signed, rhs, rhs_signed, instances,
+                 retval_size=None, start_instance_index=0):
+        # type: (Fn, SSAVal, bool, SSAVal, bool, _TCIs, None | int, int) -> None
+        """Append the ops computing `lhs * rhs` to `fn`.
+
+        `retval_size` defaults to `lhs.ty.reg_len + rhs.ty.reg_len`.
+        `start_instance_index` is the first index into `instances` to try;
+        the recursive sub-multiplications start one past the index that was
+        actually selected.
+
+        Raises `ValueError` if `start_instance_index` is negative.
+        """
+        self.fn = fn
+        self.lhs = lhs
+        self.lhs_signed = lhs_signed
+        self.rhs = rhs
+        self.rhs_signed = rhs_signed
+        self.instances = instances
+        if retval_size is None:
+            retval_size = lhs.ty.reg_len + rhs.ty.reg_len
+        self.retval_size = retval_size
+        if start_instance_index < 0:
+            raise ValueError("start_instance_index must be non-negative")
+        # records the index as passed in; the local below may advance
+        self.start_instance_index = start_instance_index
+        self.instance = None
+        self.part_size = 0  # type: int
+        # skip instances that would give zero-length parts for these operands
+        while start_instance_index < len(instances):
+            self.instance = instances[start_instance_index]
+            self.part_size = max(
+                lhs.ty.reg_len // self.instance.lhs_part_count,
+                rhs.ty.reg_len // self.instance.rhs_part_count)
+            if self.part_size <= 0:
+                self.instance = None
+                start_instance_index += 1
+            else:
+                break
+        if self.instance is None:
+            # Base case: no usable instance.  Give every remaining slot a
+            # value so that reading it doesn't raise AttributeError.
+            self.lhs_parts = ()
+            self.lhs_inputs = []
+            self.lhs_eval_state = None
+            self.lhs_outputs = []
+            self.rhs_parts = ()
+            self.rhs_inputs = []
+            self.rhs_eval_state = None
+            self.rhs_outputs = []
+            self.prod_inputs = []
+            self.prod_eval_state = None
+            self.prod_parts = []
+            self.partial_products = ()
+            self.retval = simple_mul(fn=fn,
+                                     lhs=lhs, lhs_signed=lhs_signed,
+                                     rhs=rhs, rhs_signed=rhs_signed,
+                                     name="toom_cook_base_case")
+            return
+        self.lhs_parts = split_into_exact_sized_parts(
+            fn=fn, ssa_val=lhs, part_count=self.instance.lhs_part_count,
+            part_size=self.part_size, name="lhs")
+        self.lhs_inputs = []  # type: list[EvalOpGenIrInput]
+        for part, ssa_val in enumerate(self.lhs_parts):
+            # only the last part can be signed
+            self.lhs_inputs.append(EvalOpGenIrInput(
+                ssa_val=ssa_val,
+                is_signed=lhs_signed and part == len(self.lhs_parts) - 1))
+        self.lhs_eval_state = EvalOpGenIrState(fn=fn, inputs=self.lhs_inputs)
+        lhs_eval_ops = self.instance.lhs_eval_ops
+        self.lhs_outputs = [
+            self.lhs_eval_state.get_output(i) for i in lhs_eval_ops]
+        self.rhs_parts = split_into_exact_sized_parts(
+            fn=fn, ssa_val=rhs, part_count=self.instance.rhs_part_count,
+            part_size=self.part_size, name="rhs")
+        self.rhs_inputs = []  # type: list[EvalOpGenIrInput]
+        for part, ssa_val in enumerate(self.rhs_parts):
+            # only the last part can be signed
+            self.rhs_inputs.append(EvalOpGenIrInput(
+                ssa_val=ssa_val,
+                is_signed=rhs_signed and part == len(self.rhs_parts) - 1))
+        self.rhs_eval_state = EvalOpGenIrState(fn=fn, inputs=self.rhs_inputs)
+        rhs_eval_ops = self.instance.rhs_eval_ops
+        self.rhs_outputs = [
+            self.rhs_eval_state.get_output(i) for i in rhs_eval_ops]
+        self.prod_inputs = []  # type: list[EvalOpGenIrInput]
+        for lhs_output, rhs_output in zip(self.lhs_outputs, self.rhs_outputs):
+            # multiply the paired outputs, recursing with the next instance
+            ssa_val = toom_cook_mul(
+                fn=fn,
+                lhs=lhs_output.output, lhs_signed=lhs_output.is_signed,
+                rhs=rhs_output.output, rhs_signed=rhs_output.is_signed,
+                instances=instances,
+                start_instance_index=start_instance_index + 1)
+            # the product's range is bounded by the products of the
+            # endpoints of the two operand ranges
+            products = (lhs_output.min_value * rhs_output.min_value,
+                        lhs_output.min_value * rhs_output.max_value,
+                        lhs_output.max_value * rhs_output.min_value,
+                        lhs_output.max_value * rhs_output.max_value)
+            self.prod_inputs.append(EvalOpGenIrInput(
+                ssa_val=ssa_val,
+                is_signed=None,
+                min_value=min(products),
+                max_value=max(products)))
+        self.prod_eval_state = EvalOpGenIrState(fn=fn, inputs=self.prod_inputs)
+        prod_eval_ops = self.instance.prod_eval_ops
+        self.prod_parts = [
+            self.prod_eval_state.get_output(i) for i in prod_eval_ops]
+
+        def partial_products():
+            # type: () -> Iterable[PartialProduct]
+            # spread each product part and place it `part * part_size`
+            # words up in the final sum
+            for part, prod_part in enumerate(self.prod_parts):
+                part_maxvl = prod_part.output.ty.reg_len
+                part_setvl = fn.append_new_op(
+                    OpKind.SetVLI, immediates=[part_maxvl],
+                    name=f"prod_{part}_setvl", maxvl=part_maxvl)
+                spread_part = fn.append_new_op(
+                    OpKind.Spread,
+                    input_vals=[prod_part.output, part_setvl.outputs[0]],
+                    name=f"prod_{part}_spread", maxvl=part_maxvl)
+                yield PartialProduct(
+                    spread_part.outputs, shift_in_words=part * self.part_size,
+                    is_signed=prod_part.is_signed, subtract=False)
+        # materialized so the ops are appended now and the result is kept
+        self.partial_products = tuple(partial_products())
+        self.retval = sum_partial_products(
+            fn=fn, partial_products=self.partial_products,
+            retval_size=retval_size, name="prod")
def toom_cook_mul(fn, lhs, lhs_signed, rhs, rhs_signed, instances,
retval_size=None, start_instance_index=0):
- # type: (Fn, SSAVal, bool, SSAVal, bool, __TCIs, None | int, int) -> SSAVal
- if retval_size is None:
- retval_size = lhs.ty.reg_len + rhs.ty.reg_len
- if start_instance_index < 0:
- raise ValueError("start_instance_index must be non-negative")
- instance = None
- part_size = 0
- while start_instance_index < len(instances):
- instance = instances[start_instance_index]
- part_size = max(lhs.ty.reg_len // instance.lhs_part_count,
- rhs.ty.reg_len // instance.rhs_part_count)
- if part_size <= 0:
- instance = None
- start_instance_index += 1
- else:
- break
- if instance is None:
- return simple_mul(fn=fn,
- lhs=lhs, lhs_signed=lhs_signed,
- rhs=rhs, rhs_signed=rhs_signed,
- name="toom_cook_base_case")
- lhs_parts = split_into_exact_sized_parts(
- fn=fn, ssa_val=lhs, part_count=instance.lhs_part_count,
- part_size=part_size, name="lhs")
- lhs_inputs = [] # type: list[EvalOpGenIrInput]
- for part, ssa_val in enumerate(lhs_parts):
- lhs_inputs.append(EvalOpGenIrInput(
- ssa_val=ssa_val,
- is_signed=lhs_signed and part == len(lhs_parts) - 1))
- lhs_eval_state = EvalOpGenIrState(fn=fn, inputs=lhs_inputs)
- lhs_outputs = [lhs_eval_state.get_output(i) for i in instance.lhs_eval_ops]
- rhs_parts = split_into_exact_sized_parts(
- fn=fn, ssa_val=rhs, part_count=instance.rhs_part_count,
- part_size=part_size, name="rhs")
- rhs_inputs = [] # type: list[EvalOpGenIrInput]
- for part, ssa_val in enumerate(rhs_parts):
- rhs_inputs.append(EvalOpGenIrInput(
- ssa_val=ssa_val,
- is_signed=rhs_signed and part == len(rhs_parts) - 1))
- rhs_eval_state = EvalOpGenIrState(fn=fn, inputs=rhs_inputs)
- rhs_outputs = [rhs_eval_state.get_output(i) for i in instance.rhs_eval_ops]
- prod_inputs = [] # type: list[EvalOpGenIrInput]
- for lhs_output, rhs_output in zip(lhs_outputs, rhs_outputs):
- ssa_val = toom_cook_mul(
- fn=fn,
- lhs=lhs_output.output, lhs_signed=lhs_output.is_signed,
- rhs=rhs_output.output, rhs_signed=rhs_output.is_signed,
- instances=instances, start_instance_index=start_instance_index + 1)
- products = (lhs_output.min_value * rhs_output.min_value,
- lhs_output.min_value * rhs_output.max_value,
- lhs_output.max_value * rhs_output.min_value,
- lhs_output.max_value * rhs_output.max_value)
- prod_inputs.append(EvalOpGenIrInput(
- ssa_val=ssa_val,
- is_signed=None,
- min_value=min(products),
- max_value=max(products)))
- prod_eval_state = EvalOpGenIrState(fn=fn, inputs=prod_inputs)
- prod_parts = [
- prod_eval_state.get_output(i) for i in instance.prod_eval_ops]
-
- def partial_products():
- # type: () -> Iterable[PartialProduct]
- for part, prod_part in enumerate(prod_parts):
- part_maxvl = prod_part.output.ty.reg_len
- part_setvl = fn.append_new_op(
- OpKind.SetVLI, immediates=[part_maxvl],
- name=f"prod_{part}_setvl", maxvl=part_maxvl)
- spread_part = fn.append_new_op(
- OpKind.Spread,
- input_vals=[prod_part.output, part_setvl.outputs[0]],
- name=f"prod_{part}_spread", maxvl=part_maxvl)
- yield PartialProduct(
- spread_part.outputs, shift_in_words=part * part_size,
- is_signed=prod_part.is_signed, subtract=False)
- return sum_partial_products(fn=fn, partial_products=partial_products(),
- retval_size=retval_size, name="prod")
+ # type: (Fn, SSAVal, bool, SSAVal, bool, _TCIs, None | int, int) -> SSAVal
+ """Multiply `lhs` by `rhs`, returning the resulting SSAVal.
+
+ Constructs a `ToomCookMul` from the arguments (forwarded unchanged) and
+ returns its `retval`; the intermediate state is discarded.
+ """
+ return ToomCookMul(
+ fn=fn, lhs=lhs, lhs_signed=lhs_signed, rhs=rhs, rhs_signed=rhs_signed,
+ instances=instances, retval_size=retval_size,
+ start_instance_index=start_instance_index).retval