95d04d1705675518ab478f59ce5dce9f8d4299e6
[ieee754fpu.git] / src / ieee754 / fpadd / test / test_add_formal.py
1 import unittest
2 from nmutil.formaltest import FHDLTestCase
3 from ieee754.fpadd.pipeline import FPADDBasePipe
4 from nmigen.hdl.dsl import Module
5 from nmigen.hdl.ast import Initial, Assert, AnyConst, Signal, Assume, Mux
6 from nmigen.hdl.smtlib2 import SmtFloatingPoint, SmtSortFloatingPoint, \
7 SmtSortFloat16, SmtSortFloat32, SmtSortFloat64, SmtBool, \
8 SmtRoundingMode, ROUND_TOWARD_POSITIVE, ROUND_TOWARD_NEGATIVE
9 from ieee754.fpcommon.fpbase import FPRoundingMode
10 from ieee754.pipeline import PipelineSpec
11
12
class TestFAddFSubFormal(FHDLTestCase):
    """Formal (bounded model check) tests of ``FPADDBasePipe`` for fadd/fsub.

    Each ``test_*`` method checks one combination of operation (fadd or
    fsub), operand width (f16/f32/f64) and ``FPRoundingMode`` against an
    SMT-LIB floating-point reference model via ``tst_fadd_fsub_formal``.
    The f64 variants are skipped as too slow.
    """

    def tst_fadd_fsub_formal(self, sort, rm, is_sub):
        """Formally check one fadd/fsub configuration against SMT-LIB.

        A single operation is issued on the initial cycle with arbitrary
        (``AnyConst``) operands.  The pipeline's ``z`` output is captured
        when ``dut.n.trigger`` fires.  It is then compared against the
        expected result twice: as an SMT float (``same``) and bit-for-bit.

        :param sort: SMT floating-point sort; ``sort.width`` sets the operand
            width of the pipeline under test.
        :param rm: rounding mode issued with the operation.
        :param is_sub: ``True`` to check subtraction, ``False`` for addition.
        """
        assert isinstance(sort, SmtSortFloatingPoint)
        assert isinstance(rm, FPRoundingMode)
        assert isinstance(is_sub, bool)
        width = sort.width
        dut = FPADDBasePipe(PipelineSpec(width, id_width=4))
        m = Module()
        m.submodules.dut = dut
        # Output side is always ready.  Input is valid only on the initial
        # cycle, so exactly one operation enters the pipeline.  All later
        # cycles drive zeros.
        m.d.comb += dut.n.i_ready.eq(True)
        m.d.comb += dut.p.i_valid.eq(Initial())
        m.d.comb += dut.p.i_data.rm.eq(Mux(Initial(), rm, 0))
        # `out` is a registered copy of the pipeline result.  `out_full`
        # records that the result has been captured.
        out = Signal(width)
        out_full = Signal(reset=False)
        with m.If(dut.n.trigger):
            # check we only got output for one cycle
            m.d.comb += Assert(~out_full)
            m.d.sync += out.eq(dut.n.o_data.z)
            m.d.sync += out_full.eq(True)
        a = Signal(width)
        b = Signal(width)
        m.d.comb += dut.p.i_data.a.eq(Mux(Initial(), a, 0))
        m.d.comb += dut.p.i_data.b.eq(Mux(Initial(), b, 0))
        m.d.comb += dut.p.i_data.is_sub.eq(Mux(Initial(), is_sub, 0))

        smt_add_sub = SmtFloatingPoint.sub if is_sub else SmtFloatingPoint.add
        a_fp = SmtFloatingPoint.from_bits(a, sort=sort)
        b_fp = SmtFloatingPoint.from_bits(b, sort=sort)
        out_fp = SmtFloatingPoint.from_bits(out, sort=sort)
        if rm in (FPRoundingMode.ROUND_TO_ODD_UNSIGNED_ZEROS_ARE_POSITIVE,
                  FPRoundingMode.ROUND_TO_ODD_UNSIGNED_ZEROS_ARE_NEGATIVE):
            # SMT-LIB has no round-to-odd mode.  Derive it from the
            # round-toward-positive and round-toward-negative results.
            # When the result is exact, both agree.  Otherwise they are
            # adjacent encodings, so the one with its LSB set is the odd one.
            # `rounded_up` gives the rounded-up value a bit pattern: it is
            # an unconstrained constant assumed to be the same value.
            rounded_up = Signal(width)
            m.d.comb += rounded_up.eq(AnyConst(width))
            rounded_up_fp = smt_add_sub(a_fp, b_fp, rm=ROUND_TOWARD_POSITIVE)
            rounded_down_fp = smt_add_sub(a_fp, b_fp, rm=ROUND_TOWARD_NEGATIVE)
            m.d.comb += Assume(SmtFloatingPoint.from_bits(
                rounded_up, sort=sort).same(rounded_up_fp).as_value())
            use_rounded_up = SmtBool.make(rounded_up[0])
            if rm is FPRoundingMode.ROUND_TO_ODD_UNSIGNED_ZEROS_ARE_POSITIVE:
                # When both roundings are zero, force the rounded-up result.
                # In the negative-zero variant, a +0 rounded-up result has
                # its LSB clear, so the LSB test selects rounded-down.
                is_zero = rounded_up_fp.is_zero() & rounded_down_fp.is_zero()
                use_rounded_up |= is_zero
            expected_fp = use_rounded_up.ite(rounded_up_fp, rounded_down_fp)
        else:
            smt_rm = SmtRoundingMode.make(rm.to_smtlib2())
            expected_fp = smt_add_sub(a_fp, b_fp, rm=smt_rm)
        # `expected` is the exact bit pattern the DUT must produce.  For
        # non-NaN results it is the encoding of `expected_fp`.  For NaN
        # results it is chosen as follows:
        #   - a's NaN with the quiet bit set, if a is NaN;
        #   - else b's NaN with the quiet bit set, if b is NaN;
        #   - else the positive quiet NaN with an all-zero payload.
        expected = Signal(width)
        m.d.comb += expected.eq(AnyConst(width))
        quiet_bit = 1 << (sort.mantissa_field_width - 1)
        nan_exponent = ((1 << sort.eb) - 1) << sort.mantissa_field_width
        with m.If(expected_fp.is_nan().as_value()):
            with m.If(a_fp.is_nan().as_value()):
                m.d.comb += Assume(expected == (a | quiet_bit))
            with m.Elif(b_fp.is_nan().as_value()):
                m.d.comb += Assume(expected == (b | quiet_bit))
            with m.Else():
                m.d.comb += Assume(expected == (nan_exponent | quiet_bit))
        with m.Else():
            m.d.comb += Assume(SmtFloatingPoint.from_bits(expected, sort=sort)
                               .same(expected_fp).as_value())
        m.d.comb += a.eq(AnyConst(width))
        m.d.comb += b.eq(AnyConst(width))
        with m.If(out_full):
            m.d.comb += Assert(out_fp.same(expected_fp).as_value())
            m.d.comb += Assert(out == expected)
        # NOTE(review): these checks only fire once `out_full` is set.  If
        # the pipeline latency exceeded the BMC depth, they would pass
        # vacuously.  Confirm that depth=5 covers FPADDBasePipe's latency.
        self.assertFormal(m, depth=5, solver="bitwuzla")

    # FIXME: check exception flags

    def test_fadd_f16_rne_formal(self):
        self.tst_fadd_fsub_formal(SmtSortFloat16(), FPRoundingMode.RNE, False)

    def test_fadd_f32_rne_formal(self):
        self.tst_fadd_fsub_formal(SmtSortFloat32(), FPRoundingMode.RNE, False)

    @unittest.skip("too slow")
    def test_fadd_f64_rne_formal(self):
        self.tst_fadd_fsub_formal(SmtSortFloat64(), FPRoundingMode.RNE, False)

    def test_fadd_f16_rtz_formal(self):
        self.tst_fadd_fsub_formal(SmtSortFloat16(), FPRoundingMode.RTZ, False)

    def test_fadd_f32_rtz_formal(self):
        self.tst_fadd_fsub_formal(SmtSortFloat32(), FPRoundingMode.RTZ, False)

    @unittest.skip("too slow")
    def test_fadd_f64_rtz_formal(self):
        self.tst_fadd_fsub_formal(SmtSortFloat64(), FPRoundingMode.RTZ, False)

    def test_fadd_f16_rtp_formal(self):
        self.tst_fadd_fsub_formal(SmtSortFloat16(), FPRoundingMode.RTP, False)

    def test_fadd_f32_rtp_formal(self):
        self.tst_fadd_fsub_formal(SmtSortFloat32(), FPRoundingMode.RTP, False)

    @unittest.skip("too slow")
    def test_fadd_f64_rtp_formal(self):
        self.tst_fadd_fsub_formal(SmtSortFloat64(), FPRoundingMode.RTP, False)

    def test_fadd_f16_rtn_formal(self):
        self.tst_fadd_fsub_formal(SmtSortFloat16(), FPRoundingMode.RTN, False)

    def test_fadd_f32_rtn_formal(self):
        self.tst_fadd_fsub_formal(SmtSortFloat32(), FPRoundingMode.RTN, False)

    @unittest.skip("too slow")
    def test_fadd_f64_rtn_formal(self):
        self.tst_fadd_fsub_formal(SmtSortFloat64(), FPRoundingMode.RTN, False)

    def test_fadd_f16_rna_formal(self):
        self.tst_fadd_fsub_formal(SmtSortFloat16(), FPRoundingMode.RNA, False)

    def test_fadd_f32_rna_formal(self):
        self.tst_fadd_fsub_formal(SmtSortFloat32(), FPRoundingMode.RNA, False)

    @unittest.skip("too slow")
    def test_fadd_f64_rna_formal(self):
        self.tst_fadd_fsub_formal(SmtSortFloat64(), FPRoundingMode.RNA, False)

    def test_fadd_f16_rtop_formal(self):
        self.tst_fadd_fsub_formal(SmtSortFloat16(), FPRoundingMode.RTOP, False)

    def test_fadd_f32_rtop_formal(self):
        self.tst_fadd_fsub_formal(SmtSortFloat32(), FPRoundingMode.RTOP, False)

    @unittest.skip("too slow")
    def test_fadd_f64_rtop_formal(self):
        self.tst_fadd_fsub_formal(SmtSortFloat64(), FPRoundingMode.RTOP, False)

    def test_fadd_f16_rton_formal(self):
        self.tst_fadd_fsub_formal(SmtSortFloat16(), FPRoundingMode.RTON, False)

    def test_fadd_f32_rton_formal(self):
        self.tst_fadd_fsub_formal(SmtSortFloat32(), FPRoundingMode.RTON, False)

    @unittest.skip("too slow")
    def test_fadd_f64_rton_formal(self):
        self.tst_fadd_fsub_formal(SmtSortFloat64(), FPRoundingMode.RTON, False)

    def test_fsub_f16_rne_formal(self):
        self.tst_fadd_fsub_formal(SmtSortFloat16(), FPRoundingMode.RNE, True)

    def test_fsub_f32_rne_formal(self):
        self.tst_fadd_fsub_formal(SmtSortFloat32(), FPRoundingMode.RNE, True)

    @unittest.skip("too slow")
    def test_fsub_f64_rne_formal(self):
        self.tst_fadd_fsub_formal(SmtSortFloat64(), FPRoundingMode.RNE, True)

    def test_fsub_f16_rtz_formal(self):
        self.tst_fadd_fsub_formal(SmtSortFloat16(), FPRoundingMode.RTZ, True)

    def test_fsub_f32_rtz_formal(self):
        self.tst_fadd_fsub_formal(SmtSortFloat32(), FPRoundingMode.RTZ, True)

    @unittest.skip("too slow")
    def test_fsub_f64_rtz_formal(self):
        self.tst_fadd_fsub_formal(SmtSortFloat64(), FPRoundingMode.RTZ, True)

    def test_fsub_f16_rtp_formal(self):
        self.tst_fadd_fsub_formal(SmtSortFloat16(), FPRoundingMode.RTP, True)

    def test_fsub_f32_rtp_formal(self):
        self.tst_fadd_fsub_formal(SmtSortFloat32(), FPRoundingMode.RTP, True)

    @unittest.skip("too slow")
    def test_fsub_f64_rtp_formal(self):
        self.tst_fadd_fsub_formal(SmtSortFloat64(), FPRoundingMode.RTP, True)

    def test_fsub_f16_rtn_formal(self):
        self.tst_fadd_fsub_formal(SmtSortFloat16(), FPRoundingMode.RTN, True)

    def test_fsub_f32_rtn_formal(self):
        self.tst_fadd_fsub_formal(SmtSortFloat32(), FPRoundingMode.RTN, True)

    @unittest.skip("too slow")
    def test_fsub_f64_rtn_formal(self):
        self.tst_fadd_fsub_formal(SmtSortFloat64(), FPRoundingMode.RTN, True)

    def test_fsub_f16_rna_formal(self):
        self.tst_fadd_fsub_formal(SmtSortFloat16(), FPRoundingMode.RNA, True)

    def test_fsub_f32_rna_formal(self):
        self.tst_fadd_fsub_formal(SmtSortFloat32(), FPRoundingMode.RNA, True)

    @unittest.skip("too slow")
    def test_fsub_f64_rna_formal(self):
        self.tst_fadd_fsub_formal(SmtSortFloat64(), FPRoundingMode.RNA, True)

    def test_fsub_f16_rtop_formal(self):
        self.tst_fadd_fsub_formal(SmtSortFloat16(), FPRoundingMode.RTOP, True)

    def test_fsub_f32_rtop_formal(self):
        self.tst_fadd_fsub_formal(SmtSortFloat32(), FPRoundingMode.RTOP, True)

    @unittest.skip("too slow")
    def test_fsub_f64_rtop_formal(self):
        self.tst_fadd_fsub_formal(SmtSortFloat64(), FPRoundingMode.RTOP, True)

    def test_fsub_f16_rton_formal(self):
        self.tst_fadd_fsub_formal(SmtSortFloat16(), FPRoundingMode.RTON, True)

    def test_fsub_f32_rton_formal(self):
        self.tst_fadd_fsub_formal(SmtSortFloat32(), FPRoundingMode.RTON, True)

    @unittest.skip("too slow")
    def test_fsub_f64_rton_formal(self):
        self.tst_fadd_fsub_formal(SmtSortFloat64(), FPRoundingMode.RTON, True)

    def test_all_rounding_modes_covered(self):
        """Check a test method exists for every op/width/rounding mode.

        Uses ``assertTrue`` instead of a bare ``assert`` so the check still
        runs under ``python -O``.  A missing method is then reported as a
        failure naming it, rather than as an ``AttributeError``.
        """
        for width in 16, 32, 64:
            for rm in FPRoundingMode:
                rm_s = rm.name.lower()
                for op in "fadd", "fsub":
                    name = f"test_{op}_f{width}_{rm_s}_formal"
                    self.assertTrue(callable(getattr(self, name, None)),
                                    f"missing test method {name}")
229
230
if __name__ == "__main__":
    # Running this file directly hands control to the unittest runner,
    # which discovers and runs the TestCase classes defined above.
    unittest.main()