+ name=f"{name}_{shift_in_words}_mul_rt_spread", maxvl=maxvl)
+ yield PartialProduct(
+ ssa_val_spread=[*mul_rt_spread.outputs, mul.outputs[1]],
+ shift_in_words=shift_in_words,
+ is_signed=False)
+ return sum_partial_products(fn=fn, partial_products=partial_products(),
+ name=name)
+
+
def cast_to_size(fn, ssa_val, src_signed, dest_size, name):
    # type: (Fn, SSAVal, bool, int, str) -> SSAVal
    """Return a value derived from `ssa_val` that is `dest_size` registers
    long.

    When `ssa_val.ty.reg_len` already equals `dest_size`, `ssa_val` itself is
    returned and nothing is appended to `fn`.  Otherwise ops are appended to
    `fn` that spread `ssa_val`, resize the spread values with
    `cast_to_size_spread` (truncating, or extending with sign/zero values
    according to `src_signed`), and concatenate the result again.

    fn: function the new ops are appended to.
    ssa_val: value to resize.
    src_signed: passed through to `cast_to_size_spread`.
    dest_size: requested length; must be positive.
    name: prefix used for the names of all appended ops.

    Returns the first output of the final Concat op (built with
    `maxvl=dest_size`), or `ssa_val` unchanged when no resize is needed.

    Raises ValueError if `dest_size <= 0`.
    """
    if dest_size <= 0:
        raise ValueError("invalid dest_size -- must be a positive integer")
    src_size = ssa_val.ty.reg_len
    if src_size == dest_size:
        # already the requested length -- don't emit any ops
        return ssa_val

    # take the source apart, using a VL equal to the source length
    src_vl = fn.append_new_op(
        OpKind.SetVLI, immediates=[src_size], maxvl=src_size,
        name=f"{name}_in_setvl").outputs[0]
    src_parts = fn.append_new_op(
        OpKind.Spread, input_vals=[ssa_val, src_vl],
        name=f"{name}_spread", maxvl=src_size).outputs

    # truncate or extend the spread values to exactly dest_size entries
    dest_parts = cast_to_size_spread(
        fn=fn, ssa_vals=src_parts, src_signed=src_signed,
        dest_size=dest_size, name=name)

    # put the resized values back together, using a VL equal to dest_size
    dest_vl = fn.append_new_op(
        OpKind.SetVLI, immediates=[dest_size], maxvl=dest_size,
        name=f"{name}_out_setvl").outputs[0]
    joined = fn.append_new_op(
        OpKind.Concat, input_vals=[*dest_parts, dest_vl],
        name=f"{name}_concat", maxvl=dest_size)
    return joined.outputs[0]
+
+
def cast_to_size_spread(fn, ssa_vals, src_signed, dest_size, name):
    # type: (Fn, Iterable[SSAVal], bool, int, str) -> list[SSAVal]
    """Truncate or extend a sequence of single-register values so that it
    has exactly `dest_size` entries.

    fn: function that any needed fill op is appended to.
    ssa_vals: the values to resize; every one must have type
        `Ty(base_ty=BaseTy.I64, reg_len=1)`.  May be any iterable, including
        a one-shot iterator.
    src_signed: when extending, selects the fill value: if true, the output
        of an SRADI op shifting the last input value by
        `GPR_SIZE_IN_BITS - 1`; if false, the output of an `LI 0` op.
    dest_size: requested number of entries; must be positive.
    name: prefix used for the name of the fill op.

    Returns a new list with `dest_size` entries: the inputs cut off after
    `dest_size` entries, or the inputs followed by repeated references to a
    single fill value.

    Raises ValueError if `dest_size <= 0`, if any input has the wrong type,
    or if a signed extension is requested for an empty input (there is no
    last value to take the sign from).
    """
    if dest_size <= 0:
        raise ValueError("invalid dest_size -- must be a positive integer")
    # Materialize exactly once: `ssa_vals` may be a one-shot iterator, so
    # everything after this line -- validation included -- must read the
    # list rather than the (possibly exhausted) argument.
    spread_values = list(ssa_vals)
    for ssa_val in spread_values:
        if ssa_val.ty != Ty(base_ty=BaseTy.I64, reg_len=1):
            raise ValueError("invalid ssa_val.ty")
    missing = dest_size - len(spread_values)
    if missing == 0:
        return spread_values
    if missing < 0:
        # too long: drop everything past dest_size
        del spread_values[dest_size:]
        return spread_values
    if src_signed:
        if not spread_values:
            raise ValueError(
                "can't sign-extend an empty sequence of values")
        # arithmetic shift right by (register width - 1) leaves every bit
        # equal to the top bit of the last input value
        fill = fn.append_new_op(
            OpKind.SRADI, input_vals=[spread_values[-1]],
            immediates=[GPR_SIZE_IN_BITS - 1], name=f"{name}_sign")
    else:
        fill = fn.append_new_op(
            OpKind.LI, immediates=[0], name=f"{name}_zero")
    # the same SSA value is reused for every added entry
    spread_values.extend([fill.outputs[0]] * missing)
    return spread_values
+
+
def split_into_exact_sized_parts(fn, ssa_val, part_count, part_size, name):
    # type: (Fn, SSAVal, int, int, str) -> list[SSAVal]
    """split ssa_val into part_count parts, where all but the last part have
    `part.ty.reg_len == part_size`.

    Part `i` is built from spread outputs `[i * part_size, (i + 1) *
    part_size)`, clipped to `ssa_val.ty.reg_len`; the last part may therefore
    be shorter than `part_size`.  Outputs at index `part_count * part_size`
    and beyond are not placed in any part.

    fn: function the new ops are appended to.
    ssa_val: value to split.
    part_count: number of parts to produce; must be positive.
    part_size: length of every part except possibly the last; must be
        positive.
    name: prefix used for the names of all appended ops.

    Returns a list of `part_count` values.  When `part_count == 1` this is
    `[ssa_val]` and nothing is appended to `fn`.

    Raises ValueError if `part_size` or `part_count` is not positive, or if
    `ssa_val.ty.reg_len <= (part_count - 1) * part_size` (the last part would
    be empty).
    """
    if part_size <= 0:
        raise ValueError("invalid part size, must be positive")
    if part_count <= 0:
        raise ValueError("invalid part count, must be positive")
    if part_count == 1:
        return [ssa_val]
    too_short_reg_len = (part_count - 1) * part_size
    if ssa_val.ty.reg_len <= too_short_reg_len:
        raise ValueError(f"ssa_val is too short to split, must have "
                         f"reg_len > {too_short_reg_len}: {ssa_val}")
    maxvl = ssa_val.ty.reg_len
    setvl = fn.append_new_op(OpKind.SetVLI, immediates=[maxvl],
                             maxvl=maxvl, name=f"{name}_setvl")
    spread = fn.append_new_op(
        OpKind.Spread, input_vals=[ssa_val, setvl.outputs[0]],
        name=f"{name}_spread", maxvl=maxvl)
    retval = []  # type: list[SSAVal]
    for part in range(part_count):
        start = part * part_size
        stop = min(maxvl, start + part_size)
        # equals part_size except for a short last part
        part_maxvl = stop - start
        # The VL must describe the part actually being built, so it uses
        # part_maxvl -- the same value the Concat below is created with --
        # rather than part_size, which is too large for a short last part.
        part_setvl = fn.append_new_op(
            OpKind.SetVLI, immediates=[part_maxvl], maxvl=part_maxvl,
            name=f"{name}_{part}_setvl")
        concat = fn.append_new_op(
            OpKind.Concat,
            input_vals=[*spread.outputs[start:stop], part_setvl.outputs[0]],
            name=f"{name}_{part}_concat", maxvl=part_maxvl)
        retval.append(concat.outputs[0])
    return retval
+
+
+def toom_cook_mul(fn, lhs, lhs_signed, rhs, rhs_signed, instances,
+ start_instance_index=0):
+ # type: (Fn, SSAVal, bool, SSAVal, bool, tuple[ToomCookInstance, ...], int) -> SSAVal
+ if start_instance_index < 0:
+ raise ValueError("start_instance_index must be non-negative")
+ instance = None
+ part_size = 0
+ while start_instance_index < len(instances):
+ instance = instances[start_instance_index]
+ part_size = max(lhs.ty.reg_len // instance.lhs_part_count,
+ rhs.ty.reg_len // instance.rhs_part_count)
+ if part_size <= 0:
+ instance = None
+ start_instance_index += 1