# Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
# of Horizon 2020 EU Programme 957073.
+from dataclasses import fields, replace
import math
import unittest
from nmutil.formaltest import FHDLTestCase
-from nmutil.sim_util import do_sim
+from nmutil.sim_util import do_sim, hash_256
from nmigen.sim import Tick, Delay
from nmigen.hdl.ast import Signal
from nmigen.hdl.dsl import Module
from soc.fu.div.experiment.goldschmidt_div_sqrt import (
- GoldschmidtDivHDL, GoldschmidtDivParams, ParamsNotAccurateEnough,
- goldschmidt_div, FixedPoint, RoundDir, goldschmidt_sqrt_rsqrt)
+ GoldschmidtDivHDL, GoldschmidtDivHDLState, GoldschmidtDivOp, GoldschmidtDivParams,
+ GoldschmidtDivState, ParamsNotAccurateEnough, goldschmidt_div,
+ FixedPoint, RoundDir, goldschmidt_sqrt_rsqrt)
class TestFixedPoint(FHDLTestCase):
table_addr_bits=1, table_data_bits=5,
iter_count=1)
- def tst(self, io_width):
+ @staticmethod
+ def cases(io_width, cases=None):
+ assert isinstance(io_width, int) and io_width >= 1
+ if cases is not None:
+ for n, d in cases:
+ assert isinstance(d, int) \
+ and 0 < d < (1 << io_width), "invalid case"
+ assert isinstance(n, int) \
+ and 0 <= n < (d << io_width), "invalid case"
+ yield (n, d)
+ elif io_width > 6:
+ assert io_width * 2 <= 256, \
+ "can't generate big enough numbers for test cases"
+ for i in range(10000):
+ d = hash_256(f'd {i}') % (1 << io_width)
+ if d == 0:
+ d = 1
+ n = hash_256(f'n {i}') % (d << io_width)
+ yield (n, d)
+ else:
+ for d in range(1, 1 << io_width):
+ for n in range(d << io_width):
+ yield (n, d)
+
+ def tst(self, io_width, cases=None):
assert isinstance(io_width, int)
params = GoldschmidtDivParams.get(io_width)
with self.subTest(params=str(params)):
- for d in range(1, 1 << io_width):
- for n in range(d << io_width):
- expected_q, expected_r = divmod(n, d)
- with self.subTest(n=hex(n), d=hex(d),
- expected_q=hex(expected_q),
- expected_r=hex(expected_r)):
- q, r = goldschmidt_div(n, d, params)
- with self.subTest(q=hex(q), r=hex(r)):
- self.assertEqual((q, r), (expected_q, expected_r))
-
- @unittest.skip("hdl/simulation currently broken")
+ for n, d in self.cases(io_width, cases):
+ expected_q, expected_r = divmod(n, d)
+ with self.subTest(n=hex(n), d=hex(d),
+ expected_q=hex(expected_q),
+ expected_r=hex(expected_r)):
+ trace = []
+
+ def trace_fn(state):
+ assert isinstance(state, GoldschmidtDivState)
+ trace.append((replace(state)))
+ q, r = goldschmidt_div(n, d, params, trace=trace_fn)
+ with self.subTest(q=hex(q), r=hex(r), trace=repr(trace)):
+ self.assertEqual((q, r), (expected_q, expected_r))
+
def tst_sim(self, io_width, cases=None, pipe_reg_indexes=(),
sync_rom=False):
- # FIXME: finish getting hdl/simulation to work
assert isinstance(io_width, int)
params = GoldschmidtDivParams.get(io_width)
m = Module()
# make sync domain get added
m.d.sync += Signal().eq(0)
- def iter_cases():
- if cases is not None:
- yield from cases
- return
- for d in range(1, 1 << io_width):
- for n in range(d << io_width):
- yield (n, d)
-
def inputs_proc():
yield Tick()
- for n, d in iter_cases():
+ for n, d in self.cases(io_width, cases):
yield dut.n.eq(n)
yield dut.d.eq(d)
yield Tick()
+ def check_interals(n, d):
+ # check internals only if dut is completely combinatorial
+ # so we don't have to figure out how to read values in
+ # previous clock cycles
+ if dut.total_pipeline_registers != 0:
+ return
+ ref_trace = []
+
+ def ref_trace_fn(state):
+ assert isinstance(state, GoldschmidtDivState)
+ ref_trace.append((replace(state)))
+ goldschmidt_div(n=n, d=d, params=params, trace=ref_trace_fn)
+ self.assertEqual(len(dut.trace), len(ref_trace))
+ for index, state in enumerate(dut.trace):
+ ref_state = ref_trace[index]
+ last_op = None if index == 0 else params.ops[index - 1]
+ with self.subTest(index=index, state=repr(state),
+ ref_state=repr(ref_state),
+ last_op=str(last_op)):
+ for field in fields(GoldschmidtDivHDLState):
+ sig = getattr(state, field.name)
+ if not isinstance(sig, Signal):
+ continue
+ ref_value = getattr(ref_state, field.name)
+ ref_value_str = repr(ref_value)
+ if isinstance(ref_value, int):
+ ref_value_str = hex(ref_value)
+ value = yield sig
+ with self.subTest(field_name=field.name,
+ sig=repr(sig),
+ sig_shape=repr(sig.shape()),
+ value=hex(value),
+ ref_value=ref_value_str):
+ if isinstance(ref_value, int):
+ self.assertEqual(value, ref_value)
+ else:
+ assert isinstance(ref_value, FixedPoint)
+ self.assertEqual(value, ref_value.bits)
+
def check_outputs():
yield Tick()
for _ in range(dut.total_pipeline_registers):
yield Tick()
- for n, d in iter_cases():
+ for n, d in self.cases(io_width, cases):
yield Delay(0.1e-6)
expected_q, expected_r = divmod(n, d)
with self.subTest(n=hex(n), d=hex(d),
r = yield dut.r
with self.subTest(q=hex(q), r=hex(r)):
self.assertEqual((q, r), (expected_q, expected_r))
+ yield from check_interals(n, d)
+
yield Tick()
with self.subTest(params=str(params)):
def test_6(self):
self.tst(6)
+ def test_8(self):
+ self.tst(8)
+
+ def test_16(self):
+ self.tst(16)
+
+ def test_32(self):
+ self.tst(32)
+
+ def test_64(self):
+ self.tst(64)
+
def test_sim_5(self):
self.tst_sim(5)
+ def test_sim_8(self):
+ self.tst_sim(8)
+
+ def test_sim_16(self):
+ self.tst_sim(16)
+
+ def test_sim_32(self):
+ self.tst_sim(32)
+
+ def test_sim_64(self):
+ self.tst_sim(64)
+
def tst_params(self, io_width):
assert isinstance(io_width, int)
params = GoldschmidtDivParams.get(io_width)