- for extra_precision in range(io_width * 2 + 4):
- for table_addr_bits in range(1, 7 + 1):
- table_data_bits = io_width + extra_precision
- for iter_count in range(1, 2 * io_width.bit_length()):
- try:
- return GoldschmidtDivParams(
- io_width=io_width,
- extra_precision=extra_precision,
- table_addr_bits=table_addr_bits,
- table_data_bits=table_data_bits,
- iter_count=iter_count)
- except ParamsNotAccurateEnough as e:
- last_params = (f"GoldschmidtDivParams("
- f"io_width={io_width!r}, "
- f"extra_precision={extra_precision!r}, "
- f"table_addr_bits={table_addr_bits!r}, "
- f"table_data_bits={table_data_bits!r}, "
- f"iter_count={iter_count!r})")
- last_error = e
- raise ValueError(f"can't find working parameters for a goldschmidt "
- f"division algorithm: last params: {last_params}"
- ) from last_error
+ last_error_params = None
+
+ def cached_new(base_params):
+ def handle_error(e):
+ nonlocal last_error, last_error_params
+ last_error = e
+ last_error_params = base_params
+ return None
+
+ retval = GoldschmidtDivParams.cached_new(base_params, handle_error)
+ if retval is None:
+ logging.debug(f"GoldschmidtDivParams.get: err: {base_params}")
+ else:
+ logging.debug(f"GoldschmidtDivParams.get: ok: {base_params}")
+ return retval
+
+ @lru_cache(maxsize=None)
+ def get_cost(base_params):
+ params = cached_new(base_params)
+ if params is None:
+ return math.inf
+ retval = cost_fn(params)
+ logging.debug(f"GoldschmidtDivParams.get: cost={retval}: {params}")
+ return retval
+
+ # start with parameters big enough to always work.
+ initial_extra_precision = io_width * 2 + 4
+ initial_params = GoldschmidtDivParamsBase(
+ io_width=io_width,
+ extra_precision=initial_extra_precision,
+ table_addr_bits=min(max_table_addr_bits, io_width),
+ table_data_bits=io_width + initial_extra_precision,
+ iter_count=1 + io_width.bit_length())
+
+ if cached_new(initial_params) is None:
+ raise ValueError(f"initial goldschmidt division algorithm "
+ f"parameters are invalid: {initial_params}"
+ ) from last_error
+
+ # find good initial `iter_count`
+ params = initial_params
+ for iter_count in range(1, initial_params.iter_count):
+ trial_params = replace(params, iter_count=iter_count)
+ if cached_new(trial_params) is not None:
+ params = trial_params
+ break
+
+ # now find `table_addr_bits`
+ cost = get_cost(params)
+ for table_addr_bits in range(1, max_table_addr_bits):
+ trial_params = replace(params, table_addr_bits=table_addr_bits)
+ trial_cost = get_cost(trial_params)
+ if trial_cost < cost:
+ params = trial_params
+ cost = trial_cost
+ break
+
+ # check one higher `iter_count` to see if it has lower cost
+ for table_addr_bits in range(1, max_table_addr_bits + 1):
+ trial_params = replace(params,
+ table_addr_bits=table_addr_bits,
+ iter_count=params.iter_count + 1)
+ trial_cost = get_cost(trial_params)
+ if trial_cost < cost:
+ params = trial_params
+ cost = trial_cost
+ break
+
+ # now shrink `table_data_bits`
+ while True:
+ trial_params = replace(params,
+ table_data_bits=params.table_data_bits - 1)
+ trial_cost = get_cost(trial_params)
+ if trial_cost < cost:
+ params = trial_params
+ cost = trial_cost
+ else:
+ break
+
+ # and shrink `extra_precision`
+ while True:
+ trial_params = replace(params,
+ extra_precision=params.extra_precision - 1)
+ trial_cost = get_cost(trial_params)
+ if trial_cost < cost:
+ params = trial_params
+ cost = trial_cost
+ else:
+ break
+
+ return cached_new(params)