adding more tests smtlib2-expr-support
authorJacob Lifshay <programmerjake@gmail.com>
Fri, 20 May 2022 08:55:20 +0000 (01:55 -0700)
committerJacob Lifshay <programmerjake@gmail.com>
Fri, 20 May 2022 08:55:20 +0000 (01:55 -0700)
nmigen/hdl/smtlib2.py
tests/test_hdl_smtlib2.py

index 04846dd3e1d0cfacd526d8ec41975b1859f182b6..a32bde80d89326e61505b00ce84ed8f7eceb5483 100644 (file)
@@ -475,8 +475,9 @@ class _SmtUnary(SmtValue):
         assert isinstance(expr_state, _ExprState)  # :nocov:
         return str(...)  # :nocov:
 
+    @abstractmethod
     def _expected_input_class(self):
-        return SmtValue
+        return SmtValue  # :nocov:
 
     def __init__(self, inp):
         object.__setattr__(self, "inp", inp)
@@ -499,8 +500,9 @@ class _SmtBinary(SmtValue):
         assert isinstance(expr_state, _ExprState)  # :nocov:
         return str(...)  # :nocov:
 
+    @abstractmethod
     def _expected_input_class(self):
-        return SmtValue
+        return SmtValue  # :nocov:
 
     @property
     def input_sort(self):
@@ -542,13 +544,16 @@ class SmtDistinct(_SmtNAry, SmtBool):
 class SmtBoolConst(SmtBool):
     value: bool
 
+    def __post_init__(self):
+        assert isinstance(self.value, bool)
+
     def _smtlib2_expr(self, expr_state):
         assert isinstance(expr_state, _ExprState)
         return "true" if self.value else "false"
 
 
 @final
-@dataclass(frozen=True, unsafe_hash=False, eq=False)
+@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False)
 class SmtBoolNot(_SmtUnary, SmtBool):
     inp: SmtBool
 
@@ -720,7 +725,7 @@ class SmtRealConst(SmtReal):
 
 
 @final
-@dataclass(frozen=True, unsafe_hash=False, eq=False)
+@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False)
 class SmtRealNeg(_SmtUnary, SmtReal):
     inp: SmtReal
 
@@ -733,7 +738,7 @@ class SmtRealNeg(_SmtUnary, SmtReal):
 
 
 @final
-@dataclass(frozen=True, unsafe_hash=False, eq=False)
+@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False)
 class SmtRealIsInt(_SmtUnary, SmtBool):
     inp: SmtReal
 
@@ -927,7 +932,7 @@ class SmtIntConst(SmtInt):
 
 
 @final
-@dataclass(frozen=True, unsafe_hash=False, eq=False)
+@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False)
 class SmtIntNeg(_SmtUnary, SmtInt):
     inp: SmtInt
 
@@ -940,7 +945,7 @@ class SmtIntNeg(_SmtUnary, SmtInt):
 
 
 @final
-@dataclass(frozen=True, unsafe_hash=False, eq=False)
+@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False)
 class SmtIntAbs(_SmtUnary, SmtInt):
     inp: SmtInt
 
@@ -953,7 +958,7 @@ class SmtIntAbs(_SmtUnary, SmtInt):
 
 
 @final
-@dataclass(frozen=True, unsafe_hash=False, eq=False)
+@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False)
 class SmtIntToReal(_SmtUnary, SmtReal):
     inp: SmtInt
 
@@ -966,7 +971,7 @@ class SmtIntToReal(_SmtUnary, SmtReal):
 
 
 @final
-@dataclass(frozen=True, unsafe_hash=False, eq=False)
+@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False)
 class SmtRealToInt(_SmtUnary, SmtInt):
     inp: SmtReal
 
@@ -1407,7 +1412,7 @@ class SmtBitVecNot(_SmtBitVecUnary):
 
 
 @final
-@dataclass(frozen=True, unsafe_hash=False, eq=False)
+@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False)
 class SmtBitVecToNat(_SmtUnary, SmtInt):
     inp: SmtBitVec
 
@@ -2180,7 +2185,7 @@ class SmtFloatingPointGE(_SmtFloatingPointCompareOp):
         return "fp.geq"
 
 
-@dataclass(frozen=True, unsafe_hash=False, eq=False)
+@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False)
 class _SmtFloatingPointToBoolUnary(_SmtUnary, SmtBool):
     inp: SmtFloatingPoint
 
@@ -2330,7 +2335,7 @@ class SmtFloatingPointFromBits(_SmtUnary, SmtFloatingPoint):
 
 
 @final
-@dataclass(frozen=True, unsafe_hash=False, eq=False)
+@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False)
 class SmtFloatingPointToReal(_SmtUnary, SmtReal):
     inp: SmtFloatingPoint
 
@@ -2343,7 +2348,7 @@ class SmtFloatingPointToReal(_SmtUnary, SmtReal):
 
 
 @final
-@dataclass(frozen=True, unsafe_hash=False, eq=False)
+@dataclass(frozen=True, unsafe_hash=False, eq=False, init=False)
 class SmtFloatingPointToUnsignedBV(_SmtUnary, SmtBitVec):
     inp: SmtFloatingPoint
 
index 423ea8f7cd9b38a2449ec5a7b6136a01935e1f02..825c81fc84ee729314ab4d3d5c0094a2521fb283 100644 (file)
@@ -1,6 +1,7 @@
 from nmigen.hdl.smtlib2 import *
 from nmigen.hdl.smtlib2 import _ExprState
-from nmigen.hdl.ast import ValueDict, Signal
+from nmigen.hdl.ast import AnyConst, Assert, Mux, ValueDict, Signal, Value
+from nmigen.hdl.dsl import Module
 from .utils import FHDLTestCase
 
 
@@ -13,26 +14,46 @@ class SmtTestCase(FHDLTestCase):
         self.assertEqual(state.input_bit_ranges, ValueDict(input_bit_ranges))
         self.assertEqual(state.input_width, input_width)
 
+    def assertSmtSame(self, expr, expected, *, inputs, assumptions=()):
+        assert isinstance(expr, (SmtBitVec, SmtBool))
+        expected = Value.cast(expected)
+        m = Module()
+        inputs = list(inputs)
+        for sig in inputs:
+            assert isinstance(sig, Signal)
+            any_const = AnyConst(sig.shape())
+            any_const.src_loc = sig.src_loc  # makes it easier to debug
+            m.d.comb += sig.eq(any_const)
+        expr_s = expr.to_signal()
+        m.submodules += expr_s
+        expected_s = Signal(expected.shape())
+        m.d.comb += [
+            expected_s.eq(expected),
+            Assert(expr_s.o_sig == expected_s),
+            *assumptions,
+        ]
+        self.assertFormal(m)
+
 
 class TestSorts(SmtTestCase):
-    def test_bool(self):
+    def test_sort_bool(self):
         self.assertSmtExpr(SmtSortBool(), SmtSortBool, "Bool")
 
-    def test_int(self):
+    def test_sort_int(self):
         self.assertSmtExpr(SmtSortInt(), SmtSortInt, "Int")
 
-    def test_real(self):
+    def test_sort_real(self):
         self.assertSmtExpr(SmtSortReal(), SmtSortReal, "Real")
 
-    def test_bv(self):
+    def test_sort_bv(self):
         self.assertSmtExpr(SmtSortBitVec(1), SmtSortBitVec, "(_ BitVec 1)")
         self.assertSmtExpr(SmtSortBitVec(16), SmtSortBitVec, "(_ BitVec 16)")
 
-    def test_rm(self):
+    def test_sort_rm(self):
         self.assertSmtExpr(SmtSortRoundingMode(),
                            SmtSortRoundingMode, "RoundingMode")
 
-    def test_fp(self):
+    def test_sort_fp(self):
         self.assertSmtExpr(SmtSortFloat16(), SmtSortFloatingPoint, "Float16")
         self.assertSmtExpr(SmtSortFloat32(), SmtSortFloatingPoint, "Float32")
         self.assertSmtExpr(SmtSortFloat64(), SmtSortFloatingPoint, "Float64")
@@ -44,7 +65,7 @@ class TestSorts(SmtTestCase):
 
 
 class TestBool(SmtTestCase):
-    def test_make(self):
+    def test_bool_make(self):
         self.assertSmtExpr(SmtBool.make(False), SmtBoolConst, "false")
         self.assertSmtExpr(SmtBool.make(True), SmtBoolConst, "true")
         sig = Signal()
@@ -58,12 +79,54 @@ class TestBool(SmtTestCase):
                            input_bit_ranges=[(sig2.bool(), range(0, 1))],
                            input_width=1)
 
-    def test_ite(self):
+    def test_bool_same(self):
+        a = Signal()
+        b = Signal()
+        c = Signal()
+        expr = SmtBool.make(a).same(SmtBool.make(b), SmtBool.make(c))
+        self.assertSmtExpr(
+            expr,
+            SmtSame,
+            "(= (distinct #b0 ((_ extract 0 0) A)) "
+            "(distinct #b0 ((_ extract 1 1) A)) "
+            "(distinct #b0 ((_ extract 2 2) A)))",
+            input_bit_ranges=[
+                (a, range(0, 1)),
+                (b, range(1, 2)),
+                (c, range(2, 3)),
+            ],
+            input_width=3,
+        )
+        self.assertSmtSame(expr, (a == b) & (b == c), inputs=(a, b))
+
+    def test_bool_distinct(self):
+        # only check 2 inputs since if we were to check 3 inputs it would
+        # always return false since there are only 2 possible Bool values
+        # but you'd need at least 3 to make distinct return True since every
+        # input needs to be different than all others.
+        a = Signal()
+        b = Signal()
+        expr = SmtBool.make(a).distinct(SmtBool.make(b))
+        self.assertSmtExpr(
+            expr,
+            SmtDistinct,
+            "(distinct (distinct #b0 ((_ extract 0 0) A)) "
+            "(distinct #b0 ((_ extract 1 1) A)))",
+            input_bit_ranges=[
+                (a, range(0, 1)),
+                (b, range(1, 2)),
+            ],
+            input_width=2,
+        )
+        self.assertSmtSame(expr, a != b, inputs=(a, b))
+
+    def test_bool_ite(self):
         a = Signal()
         b = Signal()
         c = Signal()
+        expr = SmtBool.make(a).ite(SmtBool.make(b), SmtBool.make(c))
         self.assertSmtExpr(
-            SmtBool.make(a).ite(SmtBool.make(b), SmtBool.make(c)),
+            expr,
             SmtBoolITE,
             "(ite (distinct #b0 ((_ extract 0 0) A)) "
             "(distinct #b0 ((_ extract 1 1) A)) "
@@ -75,5 +138,133 @@ class TestBool(SmtTestCase):
             ],
             input_width=3,
         )
+        self.assertSmtSame(expr, Mux(a, b, c), inputs=(a, b, c))
+
+    def test_bool_bool(self):
+        with self.assertRaises(TypeError):
+            bool(SmtBool.make(False))
+
+    def test_bool_invert(self):
+        a = Signal()
+        expr = ~SmtBool.make(a)
+        self.assertSmtExpr(
+            expr,
+            SmtBoolNot,
+            "(not (distinct #b0 ((_ extract 0 0) A)))",
+            input_bit_ranges=[
+                (a, range(0, 1)),
+            ],
+            input_width=1,
+        )
+        self.assertSmtSame(expr, ~a, inputs=(a,))
+
+    def test_bool_and(self):
+        a = Signal()
+        b = Signal()
+        expr = SmtBool.make(a) & SmtBool.make(b)
+        self.assertSmtExpr(
+            expr,
+            SmtBoolAnd,
+            "(and (distinct #b0 ((_ extract 0 0) A)) "
+            "(distinct #b0 ((_ extract 1 1) A)))",
+            input_bit_ranges=[
+                (a, range(0, 1)),
+                (b, range(1, 2)),
+            ],
+            input_width=2,
+        )
+        self.assertEqual(repr(expr),
+                         repr(SmtBool.make(b).__rand__(SmtBool.make(a))))
+        self.assertSmtSame(expr, a & b, inputs=(a, b))
+
+    def test_bool_xor(self):
+        a = Signal()
+        b = Signal()
+        expr = SmtBool.make(a) ^ SmtBool.make(b)
+        self.assertSmtExpr(
+            expr,
+            SmtBoolXor,
+            "(xor (distinct #b0 ((_ extract 0 0) A)) "
+            "(distinct #b0 ((_ extract 1 1) A)))",
+            input_bit_ranges=[
+                (a, range(0, 1)),
+                (b, range(1, 2)),
+            ],
+            input_width=2,
+        )
+        self.assertEqual(repr(expr),
+                         repr(SmtBool.make(b).__rxor__(SmtBool.make(a))))
+        self.assertSmtSame(expr, a ^ b, inputs=(a, b))
+
+    def test_bool_or(self):
+        a = Signal()
+        b = Signal()
+        expr = SmtBool.make(a) | SmtBool.make(b)
+        self.assertSmtExpr(
+            expr,
+            SmtBoolOr,
+            "(or (distinct #b0 ((_ extract 0 0) A)) "
+            "(distinct #b0 ((_ extract 1 1) A)))",
+            input_bit_ranges=[
+                (a, range(0, 1)),
+                (b, range(1, 2)),
+            ],
+            input_width=2,
+        )
+        self.assertEqual(repr(expr),
+                         repr(SmtBool.make(b).__ror__(SmtBool.make(a))))
+        self.assertSmtSame(expr, a | b, inputs=(a, b))
+
+    def test_bool_eq(self):
+        a = Signal()
+        b = Signal()
+        expr = SmtBool.make(a) == SmtBool.make(b)
+        self.assertSmtExpr(
+            expr,
+            SmtSame,
+            "(= (distinct #b0 ((_ extract 0 0) A)) "
+            "(distinct #b0 ((_ extract 1 1) A)))",
+            input_bit_ranges=[
+                (a, range(0, 1)),
+                (b, range(1, 2)),
+            ],
+            input_width=2,
+        )
+        self.assertSmtSame(expr, a == b, inputs=(a, b))
+
+    def test_bool_ne(self):
+        a = Signal()
+        b = Signal()
+        expr = SmtBool.make(a) != SmtBool.make(b)
+        self.assertSmtExpr(
+            expr,
+            SmtDistinct,
+            "(distinct (distinct #b0 ((_ extract 0 0) A)) "
+            "(distinct #b0 ((_ extract 1 1) A)))",
+            input_bit_ranges=[
+                (a, range(0, 1)),
+                (b, range(1, 2)),
+            ],
+            input_width=2,
+        )
+        self.assertSmtSame(expr, a != b, inputs=(a, b))
+
+    def test_bool_implies(self):
+        a = Signal()
+        b = Signal()
+        expr = SmtBool.make(a).implies(SmtBool.make(b))
+        self.assertSmtExpr(
+            expr,
+            SmtBoolImplies,
+            "(=> (distinct #b0 ((_ extract 0 0) A)) "
+            "(distinct #b0 ((_ extract 1 1) A)))",
+            input_bit_ranges=[
+                (a, range(0, 1)),
+                (b, range(1, 2)),
+            ],
+            input_width=2,
+        )
+        self.assertSmtSame(expr, a.implies(b), inputs=(a, b))
+
 
 # FIXME: add more tests