disable fadd f32 formal proofs by default -- they're too slow
[ieee754fpu.git] / src / ieee754 / fpadd / test / test_add_formal.py
1 import unittest
2 from nmutil.formaltest import FHDLTestCase
3 from ieee754.fpadd.pipeline import FPADDBasePipe
4 from nmigen.hdl.dsl import Module
5 from nmigen.hdl.ast import Initial, Assert, AnyConst, Signal, Assume, Mux
6 from nmigen.hdl.smtlib2 import SmtFloatingPoint, SmtSortFloatingPoint, \
7 SmtSortFloat16, SmtSortFloat32, SmtSortFloat64, SmtBool, \
8 SmtRoundingMode, ROUND_TOWARD_POSITIVE, ROUND_TOWARD_NEGATIVE
9 from ieee754.fpcommon.fpbase import FPRoundingMode
10 from ieee754.pipeline import PipelineSpec
11 import os
12
# The f32 formal proofs are slow, so they are opt-in: they only run when the
# ENABLE_FADD_F32_FORMAL environment variable is present (any value, even an
# empty string, counts as enabled).
ENABLE_FADD_F32_FORMAL = "ENABLE_FADD_F32_FORMAL" in os.environ
14
15
class TestFAddFSubFormal(FHDLTestCase):
    """Formal checks of ``FPADDBasePipe`` fadd/fsub against SMT-LIB floats.

    Each ``test_f{add,sub}_f{16,32,64}_<rm>_formal`` method runs
    ``tst_fadd_fsub_formal`` for one format and one ``FPRoundingMode``.
    The f16 proofs always run. The f32 proofs run only when the
    ``ENABLE_FADD_F32_FORMAL`` environment variable is set. The f64 proofs
    are unconditionally skipped as too slow.
    """

    def tst_fadd_fsub_formal(self, sort, rm, is_sub):
        """Formally check one fadd/fsub configuration of the pipeline.

        :param sort: SMT floating-point sort (f16/f32/f64); its ``width``
            sets the DUT operand width.
        :param rm: the ``FPRoundingMode`` to drive into the DUT.
        :param is_sub: ``True`` checks fsub, ``False`` checks fadd.

        The DUT receives arbitrary constant operands ``a``/``b`` only in
        the initial cycle (zero afterwards). Its result is latched into
        ``out`` and asserted to be produced at most once. The latched
        value must be bit-identical to the reference result, including
        the exact NaN encoding. The check is a bounded proof of depth 5
        using the ``bitwuzla`` solver.
        """
        assert isinstance(sort, SmtSortFloatingPoint)
        assert isinstance(rm, FPRoundingMode)
        assert isinstance(is_sub, bool)
        width = sort.width
        dut = FPADDBasePipe(PipelineSpec(width, id_width=4))
        m = Module()
        m.submodules.dut = dut
        m.d.comb += dut.n.i_ready.eq(True)
        # only present a single operation, in the initial cycle
        m.d.comb += dut.p.i_valid.eq(Initial())
        m.d.comb += dut.p.i_data.rm.eq(Mux(Initial(), rm, 0))
        out = Signal(width)
        out_full = Signal(reset=False)
        with m.If(dut.n.trigger):
            # check we only got output for one cycle
            m.d.comb += Assert(~out_full)
            m.d.sync += out.eq(dut.n.o_data.z)
            m.d.sync += out_full.eq(True)
        a = Signal(width)
        b = Signal(width)
        m.d.comb += dut.p.i_data.a.eq(Mux(Initial(), a, 0))
        m.d.comb += dut.p.i_data.b.eq(Mux(Initial(), b, 0))
        m.d.comb += dut.p.i_data.is_sub.eq(Mux(Initial(), is_sub, 0))

        smt_add_sub = SmtFloatingPoint.sub if is_sub else SmtFloatingPoint.add
        a_fp = SmtFloatingPoint.from_bits(a, sort=sort)
        b_fp = SmtFloatingPoint.from_bits(b, sort=sort)
        out_fp = SmtFloatingPoint.from_bits(out, sort=sort)
        if rm in (FPRoundingMode.ROUND_TO_ODD_UNSIGNED_ZEROS_ARE_POSITIVE,
                  FPRoundingMode.ROUND_TO_ODD_UNSIGNED_ZEROS_ARE_NEGATIVE):
            # SMT-LIB only defines RNE/RNA/RTP/RTN/RTZ. Round-to-odd is
            # therefore derived from the results rounded toward +inf and
            # toward -inf: pick the one whose encoding has an odd LSB.
            rounded_up = Signal(width)
            m.d.comb += rounded_up.eq(AnyConst(width))
            rounded_up_fp = smt_add_sub(a_fp, b_fp, rm=ROUND_TOWARD_POSITIVE)
            rounded_down_fp = smt_add_sub(a_fp, b_fp, rm=ROUND_TOWARD_NEGATIVE)
            # constrain rounded_up to the bit pattern of rounded_up_fp so
            # that its least-significant bit can be inspected
            m.d.comb += Assume(SmtFloatingPoint.from_bits(
                rounded_up, sort=sort).same(rounded_up_fp).as_value())
            use_rounded_up = SmtBool.make(rounded_up[0])
            if rm is FPRoundingMode.ROUND_TO_ODD_UNSIGNED_ZEROS_ARE_POSITIVE:
                # when both roundings give zero, take the toward-+inf
                # result (presumably so a sign-ambiguous zero comes out +0)
                is_zero = rounded_up_fp.is_zero() & rounded_down_fp.is_zero()
                use_rounded_up |= is_zero
            expected_fp = use_rounded_up.ite(rounded_up_fp, rounded_down_fp)
        else:
            smt_rm = SmtRoundingMode.make(rm.to_smtlib2())
            expected_fp = smt_add_sub(a_fp, b_fp, rm=smt_rm)
        expected = Signal(width)
        m.d.comb += expected.eq(AnyConst(width))
        # SMT-LIB floats have a single NaN value, so the exact NaN bit
        # pattern must be constrained explicitly:
        # - propagate a's NaN (quieted);
        # - else propagate b's NaN (quieted);
        # - else produce the default positive quiet NaN.
        quiet_bit = 1 << (sort.mantissa_field_width - 1)
        nan_exponent = ((1 << sort.eb) - 1) << sort.mantissa_field_width
        with m.If(expected_fp.is_nan().as_value()):
            with m.If(a_fp.is_nan().as_value()):
                m.d.comb += Assume(expected == (a | quiet_bit))
            with m.Elif(b_fp.is_nan().as_value()):
                m.d.comb += Assume(expected == (b | quiet_bit))
            with m.Else():
                m.d.comb += Assume(expected == (nan_exponent | quiet_bit))
        with m.Else():
            m.d.comb += Assume(SmtFloatingPoint.from_bits(expected, sort=sort)
                               .same(expected_fp).as_value())
        m.d.comb += a.eq(AnyConst(width))
        m.d.comb += b.eq(AnyConst(width))
        with m.If(out_full):
            m.d.comb += Assert(out_fp.same(expected_fp).as_value())
            m.d.comb += Assert(out == expected)
        self.assertFormal(m, depth=5, solver="bitwuzla")

    # FIXME: check exception flags

    def test_fadd_f16_rne_formal(self):
        self.tst_fadd_fsub_formal(SmtSortFloat16(), FPRoundingMode.RNE, False)

    @unittest.skipUnless(ENABLE_FADD_F32_FORMAL,
                         "ENABLE_FADD_F32_FORMAL not in environ")
    def test_fadd_f32_rne_formal(self):
        self.tst_fadd_fsub_formal(SmtSortFloat32(), FPRoundingMode.RNE, False)

    @unittest.skip("too slow")
    def test_fadd_f64_rne_formal(self):
        self.tst_fadd_fsub_formal(SmtSortFloat64(), FPRoundingMode.RNE, False)

    def test_fadd_f16_rtz_formal(self):
        self.tst_fadd_fsub_formal(SmtSortFloat16(), FPRoundingMode.RTZ, False)

    @unittest.skipUnless(ENABLE_FADD_F32_FORMAL,
                         "ENABLE_FADD_F32_FORMAL not in environ")
    def test_fadd_f32_rtz_formal(self):
        self.tst_fadd_fsub_formal(SmtSortFloat32(), FPRoundingMode.RTZ, False)

    @unittest.skip("too slow")
    def test_fadd_f64_rtz_formal(self):
        self.tst_fadd_fsub_formal(SmtSortFloat64(), FPRoundingMode.RTZ, False)

    def test_fadd_f16_rtp_formal(self):
        self.tst_fadd_fsub_formal(SmtSortFloat16(), FPRoundingMode.RTP, False)

    @unittest.skipUnless(ENABLE_FADD_F32_FORMAL,
                         "ENABLE_FADD_F32_FORMAL not in environ")
    def test_fadd_f32_rtp_formal(self):
        self.tst_fadd_fsub_formal(SmtSortFloat32(), FPRoundingMode.RTP, False)

    @unittest.skip("too slow")
    def test_fadd_f64_rtp_formal(self):
        self.tst_fadd_fsub_formal(SmtSortFloat64(), FPRoundingMode.RTP, False)

    def test_fadd_f16_rtn_formal(self):
        self.tst_fadd_fsub_formal(SmtSortFloat16(), FPRoundingMode.RTN, False)

    @unittest.skipUnless(ENABLE_FADD_F32_FORMAL,
                         "ENABLE_FADD_F32_FORMAL not in environ")
    def test_fadd_f32_rtn_formal(self):
        self.tst_fadd_fsub_formal(SmtSortFloat32(), FPRoundingMode.RTN, False)

    @unittest.skip("too slow")
    def test_fadd_f64_rtn_formal(self):
        self.tst_fadd_fsub_formal(SmtSortFloat64(), FPRoundingMode.RTN, False)

    def test_fadd_f16_rna_formal(self):
        self.tst_fadd_fsub_formal(SmtSortFloat16(), FPRoundingMode.RNA, False)

    @unittest.skipUnless(ENABLE_FADD_F32_FORMAL,
                         "ENABLE_FADD_F32_FORMAL not in environ")
    def test_fadd_f32_rna_formal(self):
        self.tst_fadd_fsub_formal(SmtSortFloat32(), FPRoundingMode.RNA, False)

    @unittest.skip("too slow")
    def test_fadd_f64_rna_formal(self):
        self.tst_fadd_fsub_formal(SmtSortFloat64(), FPRoundingMode.RNA, False)

    def test_fadd_f16_rtop_formal(self):
        self.tst_fadd_fsub_formal(SmtSortFloat16(), FPRoundingMode.RTOP, False)

    @unittest.skipUnless(ENABLE_FADD_F32_FORMAL,
                         "ENABLE_FADD_F32_FORMAL not in environ")
    def test_fadd_f32_rtop_formal(self):
        self.tst_fadd_fsub_formal(SmtSortFloat32(), FPRoundingMode.RTOP, False)

    @unittest.skip("too slow")
    def test_fadd_f64_rtop_formal(self):
        self.tst_fadd_fsub_formal(SmtSortFloat64(), FPRoundingMode.RTOP, False)

    def test_fadd_f16_rton_formal(self):
        self.tst_fadd_fsub_formal(SmtSortFloat16(), FPRoundingMode.RTON, False)

    @unittest.skipUnless(ENABLE_FADD_F32_FORMAL,
                         "ENABLE_FADD_F32_FORMAL not in environ")
    def test_fadd_f32_rton_formal(self):
        self.tst_fadd_fsub_formal(SmtSortFloat32(), FPRoundingMode.RTON, False)

    @unittest.skip("too slow")
    def test_fadd_f64_rton_formal(self):
        self.tst_fadd_fsub_formal(SmtSortFloat64(), FPRoundingMode.RTON, False)

    def test_fsub_f16_rne_formal(self):
        self.tst_fadd_fsub_formal(SmtSortFloat16(), FPRoundingMode.RNE, True)

    @unittest.skipUnless(ENABLE_FADD_F32_FORMAL,
                         "ENABLE_FADD_F32_FORMAL not in environ")
    def test_fsub_f32_rne_formal(self):
        self.tst_fadd_fsub_formal(SmtSortFloat32(), FPRoundingMode.RNE, True)

    @unittest.skip("too slow")
    def test_fsub_f64_rne_formal(self):
        self.tst_fadd_fsub_formal(SmtSortFloat64(), FPRoundingMode.RNE, True)

    def test_fsub_f16_rtz_formal(self):
        self.tst_fadd_fsub_formal(SmtSortFloat16(), FPRoundingMode.RTZ, True)

    @unittest.skipUnless(ENABLE_FADD_F32_FORMAL,
                         "ENABLE_FADD_F32_FORMAL not in environ")
    def test_fsub_f32_rtz_formal(self):
        self.tst_fadd_fsub_formal(SmtSortFloat32(), FPRoundingMode.RTZ, True)

    @unittest.skip("too slow")
    def test_fsub_f64_rtz_formal(self):
        self.tst_fadd_fsub_formal(SmtSortFloat64(), FPRoundingMode.RTZ, True)

    def test_fsub_f16_rtp_formal(self):
        self.tst_fadd_fsub_formal(SmtSortFloat16(), FPRoundingMode.RTP, True)

    @unittest.skipUnless(ENABLE_FADD_F32_FORMAL,
                         "ENABLE_FADD_F32_FORMAL not in environ")
    def test_fsub_f32_rtp_formal(self):
        self.tst_fadd_fsub_formal(SmtSortFloat32(), FPRoundingMode.RTP, True)

    @unittest.skip("too slow")
    def test_fsub_f64_rtp_formal(self):
        self.tst_fadd_fsub_formal(SmtSortFloat64(), FPRoundingMode.RTP, True)

    def test_fsub_f16_rtn_formal(self):
        self.tst_fadd_fsub_formal(SmtSortFloat16(), FPRoundingMode.RTN, True)

    @unittest.skipUnless(ENABLE_FADD_F32_FORMAL,
                         "ENABLE_FADD_F32_FORMAL not in environ")
    def test_fsub_f32_rtn_formal(self):
        self.tst_fadd_fsub_formal(SmtSortFloat32(), FPRoundingMode.RTN, True)

    @unittest.skip("too slow")
    def test_fsub_f64_rtn_formal(self):
        self.tst_fadd_fsub_formal(SmtSortFloat64(), FPRoundingMode.RTN, True)

    def test_fsub_f16_rna_formal(self):
        self.tst_fadd_fsub_formal(SmtSortFloat16(), FPRoundingMode.RNA, True)

    @unittest.skipUnless(ENABLE_FADD_F32_FORMAL,
                         "ENABLE_FADD_F32_FORMAL not in environ")
    def test_fsub_f32_rna_formal(self):
        self.tst_fadd_fsub_formal(SmtSortFloat32(), FPRoundingMode.RNA, True)

    @unittest.skip("too slow")
    def test_fsub_f64_rna_formal(self):
        self.tst_fadd_fsub_formal(SmtSortFloat64(), FPRoundingMode.RNA, True)

    def test_fsub_f16_rtop_formal(self):
        self.tst_fadd_fsub_formal(SmtSortFloat16(), FPRoundingMode.RTOP, True)

    @unittest.skipUnless(ENABLE_FADD_F32_FORMAL,
                         "ENABLE_FADD_F32_FORMAL not in environ")
    def test_fsub_f32_rtop_formal(self):
        self.tst_fadd_fsub_formal(SmtSortFloat32(), FPRoundingMode.RTOP, True)

    @unittest.skip("too slow")
    def test_fsub_f64_rtop_formal(self):
        self.tst_fadd_fsub_formal(SmtSortFloat64(), FPRoundingMode.RTOP, True)

    def test_fsub_f16_rton_formal(self):
        self.tst_fadd_fsub_formal(SmtSortFloat16(), FPRoundingMode.RTON, True)

    @unittest.skipUnless(ENABLE_FADD_F32_FORMAL,
                         "ENABLE_FADD_F32_FORMAL not in environ")
    def test_fsub_f32_rton_formal(self):
        self.tst_fadd_fsub_formal(SmtSortFloat32(), FPRoundingMode.RTON, True)

    @unittest.skip("too slow")
    def test_fsub_f64_rton_formal(self):
        self.tst_fadd_fsub_formal(SmtSortFloat64(), FPRoundingMode.RTON, True)

    def test_all_rounding_modes_covered(self):
        """Check an fadd and fsub test exists for every width and mode."""
        for width in 16, 32, 64:
            for rm in FPRoundingMode:
                rm_s = rm.name.lower()
                for op in "fadd", "fsub":
                    name = f"test_{op}_f{width}_{rm_s}_formal"
                    with self.subTest(name=name):
                        # unittest assertion rather than bare `assert`: it
                        # is not stripped under `python -O`, and it reports
                        # the missing name as a failure instead of an error
                        self.assertTrue(callable(getattr(self, name, None)),
                                        f"missing test method {name}")
260
261
if __name__ == "__main__":
    # Run as a script: hand this module's test cases to the stdlib unittest
    # runner (module="__main__" is the runner's default, spelled out here).
    unittest.main(module="__main__")