import unittest

from nmutil.formaltest import FHDLTestCase
from ieee754.fpadd.pipeline import FPADDBasePipe
from nmigen.hdl.dsl import Module
from nmigen.hdl.ast import Initial, Assert, AnyConst, Signal, Assume, Mux
from nmigen.hdl.smtlib2 import SmtFloatingPoint, SmtSortFloatingPoint, \
    SmtSortFloat16, SmtSortFloat32, SmtSortFloat64, SmtBool, \
    SmtRoundingMode, ROUND_TOWARD_POSITIVE, ROUND_TOWARD_NEGATIVE
from ieee754.fpcommon.fpbase import FPRoundingMode
from ieee754.pipeline import PipelineSpec


class TestFAddFSubFormal(FHDLTestCase):
    """Formal checks of FPADDBasePipe against the SMT floating-point model.

    Each ``test_*_formal`` method picks one format (f16/f32/f64), one
    rounding mode and one operation (fadd/fsub) and delegates to
    ``tst_fadd_fsub_formal``.
    """

    def tst_fadd_fsub_formal(self, sort, rm, is_sub):
        """Formally check one add/sub configuration of FPADDBasePipe.

        A single operation with arbitrary (``AnyConst``) operands is
        presented to the pipeline on the initial cycle only.  The result is
        latched in the cycle the output handshake fires, and is then
        asserted to equal the reference result computed with
        ``SmtFloatingPoint``, both as an SMT float and bit-for-bit.

        :param sort: SMT floating-point sort selecting the format under test.
        :param rm: FPRoundingMode presented to the DUT.
        :param is_sub: True to check subtraction, False to check addition.
        """
        assert isinstance(sort, SmtSortFloatingPoint)
        assert isinstance(rm, FPRoundingMode)
        assert isinstance(is_sub, bool)

        # total encoding width: sign bit + exponent field + mantissa field
        # (the same field layout quiet_bit / nan_exponent rely on below)
        width = 1 + sort.eb + sort.mantissa_field_width
        dut = FPADDBasePipe(PipelineSpec(width, id_width=4))
        m = Module()
        m.submodules.dut = dut

        # operands: arbitrary, but constant over the whole trace
        a = Signal(width)
        b = Signal(width)
        m.d.comb += a.eq(AnyConst(width))
        m.d.comb += b.eq(AnyConst(width))

        # always accept output; feed exactly one operation, on the initial
        # cycle only, and drive zeros on all later cycles
        m.d.comb += dut.n.i_ready.eq(True)
        m.d.comb += dut.p.i_valid.eq(Initial())
        m.d.comb += dut.p.i_data.rm.eq(Mux(Initial(), rm, 0))
        m.d.comb += dut.p.i_data.a.eq(Mux(Initial(), a, 0))
        m.d.comb += dut.p.i_data.b.eq(Mux(Initial(), b, 0))
        m.d.comb += dut.p.i_data.is_sub.eq(Mux(Initial(), is_sub, 0))

        # latch the DUT result in the cycle its output handshake fires
        out = Signal(width)
        out_full = Signal(reset=False)
        with m.If(dut.n.trigger):
            # check we only got output for one cycle
            m.d.comb += Assert(~out_full)
            m.d.sync += out.eq(dut.n.o_data.z)
            m.d.sync += out_full.eq(True)

        # reference model
        smt_add_sub = SmtFloatingPoint.sub if is_sub else SmtFloatingPoint.add
        a_fp = SmtFloatingPoint.from_bits(a, sort=sort)
        b_fp = SmtFloatingPoint.from_bits(b, sort=sort)
        out_fp = SmtFloatingPoint.from_bits(out, sort=sort)
        if rm in (FPRoundingMode.ROUND_TO_ODD_UNSIGNED_ZEROS_ARE_POSITIVE,
                  FPRoundingMode.ROUND_TO_ODD_UNSIGNED_ZEROS_ARE_NEGATIVE):
            # round-to-odd is not an SMT-LIB rounding mode, so build it from
            # the rounded-up and rounded-down results: take the rounded-up
            # result when its bit pattern has the LSB set, otherwise the
            # rounded-down one.
            rounded_up = Signal(width)
            m.d.comb += rounded_up.eq(AnyConst(width))
            rounded_up_fp = smt_add_sub(a_fp, b_fp, rm=ROUND_TOWARD_POSITIVE)
            rounded_down_fp = smt_add_sub(a_fp, b_fp, rm=ROUND_TOWARD_NEGATIVE)
            # tie rounded_up's bits to the rounded-up SMT value so that its
            # LSB can be inspected
            m.d.comb += Assume(SmtFloatingPoint.from_bits(
                rounded_up, sort=sort).same(rounded_up_fp).as_value())
            use_rounded_up = SmtBool.make(rounded_up[0])
            if rm is FPRoundingMode.ROUND_TO_ODD_UNSIGNED_ZEROS_ARE_POSITIVE:
                # both candidates zero: pick the round-toward-positive zero
                is_zero = rounded_up_fp.is_zero() & rounded_down_fp.is_zero()
                use_rounded_up |= is_zero
            expected_fp = use_rounded_up.ite(rounded_up_fp, rounded_down_fp)
        else:
            smt_rm = SmtRoundingMode.make(rm.to_smtlib2())
            expected_fp = smt_add_sub(a_fp, b_fp, rm=smt_rm)

        # expected output bit pattern.  A NaN has many encodings, so the
        # NaN result is pinned explicitly (exactly one alternative applies):
        # propagate NaN operand a, else NaN operand b, with the quiet bit
        # set, else produce the positive canonical quiet NaN.
        expected = Signal(width)
        m.d.comb += expected.eq(AnyConst(width))
        quiet_bit = 1 << (sort.mantissa_field_width - 1)
        nan_exponent = ((1 << sort.eb) - 1) << sort.mantissa_field_width
        with m.If(expected_fp.is_nan().as_value()):
            with m.If(a_fp.is_nan().as_value()):
                m.d.comb += Assume(expected == (a | quiet_bit))
            with m.Elif(b_fp.is_nan().as_value()):
                m.d.comb += Assume(expected == (b | quiet_bit))
            with m.Else():
                m.d.comb += Assume(expected == (nan_exponent | quiet_bit))
        with m.Else():
            m.d.comb += Assume(SmtFloatingPoint.from_bits(expected, sort=sort)
                               .same(expected_fp).as_value())

        # only meaningful once a result has been latched; before that `out`
        # still holds its reset value
        # NOTE(review): nothing here asserts that a result actually arrives
        # within `depth` cycles.
        with m.If(out_full):
            m.d.comb += Assert(out_fp.same(expected_fp).as_value())
            m.d.comb += Assert(out == expected)
        self.assertFormal(m, depth=5, solver="bitwuzla")

    # FIXME: check exception flags

    def test_fadd_f16_rne_formal(self):
        self.tst_fadd_fsub_formal(SmtSortFloat16(), FPRoundingMode.RNE, False)

    def test_fadd_f32_rne_formal(self):
        self.tst_fadd_fsub_formal(SmtSortFloat32(), FPRoundingMode.RNE, False)

    @unittest.skip("too slow")
    def test_fadd_f64_rne_formal(self):
        self.tst_fadd_fsub_formal(SmtSortFloat64(), FPRoundingMode.RNE, False)

    def test_fadd_f16_rtz_formal(self):
        self.tst_fadd_fsub_formal(SmtSortFloat16(), FPRoundingMode.RTZ, False)

    def test_fadd_f32_rtz_formal(self):
        self.tst_fadd_fsub_formal(SmtSortFloat32(), FPRoundingMode.RTZ, False)

    @unittest.skip("too slow")
    def test_fadd_f64_rtz_formal(self):
        self.tst_fadd_fsub_formal(SmtSortFloat64(), FPRoundingMode.RTZ, False)

    def test_fadd_f16_rtp_formal(self):
        self.tst_fadd_fsub_formal(SmtSortFloat16(), FPRoundingMode.RTP, False)

    def test_fadd_f32_rtp_formal(self):
        self.tst_fadd_fsub_formal(SmtSortFloat32(), FPRoundingMode.RTP, False)

    @unittest.skip("too slow")
    def test_fadd_f64_rtp_formal(self):
        self.tst_fadd_fsub_formal(SmtSortFloat64(), FPRoundingMode.RTP, False)

    def test_fadd_f16_rtn_formal(self):
        self.tst_fadd_fsub_formal(SmtSortFloat16(), FPRoundingMode.RTN, False)

    def test_fadd_f32_rtn_formal(self):
        self.tst_fadd_fsub_formal(SmtSortFloat32(), FPRoundingMode.RTN, False)

    @unittest.skip("too slow")
    def test_fadd_f64_rtn_formal(self):
        self.tst_fadd_fsub_formal(SmtSortFloat64(), FPRoundingMode.RTN, False)

    def test_fadd_f16_rna_formal(self):
        self.tst_fadd_fsub_formal(SmtSortFloat16(), FPRoundingMode.RNA, False)

    def test_fadd_f32_rna_formal(self):
        self.tst_fadd_fsub_formal(SmtSortFloat32(), FPRoundingMode.RNA, False)

    @unittest.skip("too slow")
    def test_fadd_f64_rna_formal(self):
        self.tst_fadd_fsub_formal(SmtSortFloat64(), FPRoundingMode.RNA, False)

    def test_fadd_f16_rtop_formal(self):
        self.tst_fadd_fsub_formal(SmtSortFloat16(), FPRoundingMode.RTOP, False)

    def test_fadd_f32_rtop_formal(self):
        self.tst_fadd_fsub_formal(SmtSortFloat32(), FPRoundingMode.RTOP, False)

    @unittest.skip("too slow")
    def test_fadd_f64_rtop_formal(self):
        self.tst_fadd_fsub_formal(SmtSortFloat64(), FPRoundingMode.RTOP, False)

    def test_fadd_f16_rton_formal(self):
        self.tst_fadd_fsub_formal(SmtSortFloat16(), FPRoundingMode.RTON, False)

    def test_fadd_f32_rton_formal(self):
        self.tst_fadd_fsub_formal(SmtSortFloat32(), FPRoundingMode.RTON, False)

    @unittest.skip("too slow")
    def test_fadd_f64_rton_formal(self):
        self.tst_fadd_fsub_formal(SmtSortFloat64(), FPRoundingMode.RTON, False)

    def test_fsub_f16_rne_formal(self):
        self.tst_fadd_fsub_formal(SmtSortFloat16(), FPRoundingMode.RNE, True)

    def test_fsub_f32_rne_formal(self):
        self.tst_fadd_fsub_formal(SmtSortFloat32(), FPRoundingMode.RNE, True)

    @unittest.skip("too slow")
    def test_fsub_f64_rne_formal(self):
        self.tst_fadd_fsub_formal(SmtSortFloat64(), FPRoundingMode.RNE, True)

    def test_fsub_f16_rtz_formal(self):
        self.tst_fadd_fsub_formal(SmtSortFloat16(), FPRoundingMode.RTZ, True)

    def test_fsub_f32_rtz_formal(self):
        self.tst_fadd_fsub_formal(SmtSortFloat32(), FPRoundingMode.RTZ, True)

    @unittest.skip("too slow")
    def test_fsub_f64_rtz_formal(self):
        self.tst_fadd_fsub_formal(SmtSortFloat64(), FPRoundingMode.RTZ, True)

    def test_fsub_f16_rtp_formal(self):
        self.tst_fadd_fsub_formal(SmtSortFloat16(), FPRoundingMode.RTP, True)

    def test_fsub_f32_rtp_formal(self):
        self.tst_fadd_fsub_formal(SmtSortFloat32(), FPRoundingMode.RTP, True)

    @unittest.skip("too slow")
    def test_fsub_f64_rtp_formal(self):
        self.tst_fadd_fsub_formal(SmtSortFloat64(), FPRoundingMode.RTP, True)

    def test_fsub_f16_rtn_formal(self):
        self.tst_fadd_fsub_formal(SmtSortFloat16(), FPRoundingMode.RTN, True)

    def test_fsub_f32_rtn_formal(self):
        self.tst_fadd_fsub_formal(SmtSortFloat32(), FPRoundingMode.RTN, True)

    @unittest.skip("too slow")
    def test_fsub_f64_rtn_formal(self):
        self.tst_fadd_fsub_formal(SmtSortFloat64(), FPRoundingMode.RTN, True)

    def test_fsub_f16_rna_formal(self):
        self.tst_fadd_fsub_formal(SmtSortFloat16(), FPRoundingMode.RNA, True)

    def test_fsub_f32_rna_formal(self):
        self.tst_fadd_fsub_formal(SmtSortFloat32(), FPRoundingMode.RNA, True)

    @unittest.skip("too slow")
    def test_fsub_f64_rna_formal(self):
        self.tst_fadd_fsub_formal(SmtSortFloat64(), FPRoundingMode.RNA, True)

    def test_fsub_f16_rtop_formal(self):
        self.tst_fadd_fsub_formal(SmtSortFloat16(), FPRoundingMode.RTOP, True)

    def test_fsub_f32_rtop_formal(self):
        self.tst_fadd_fsub_formal(SmtSortFloat32(), FPRoundingMode.RTOP, True)

    @unittest.skip("too slow")
    def test_fsub_f64_rtop_formal(self):
        self.tst_fadd_fsub_formal(SmtSortFloat64(), FPRoundingMode.RTOP, True)

    def test_fsub_f16_rton_formal(self):
        self.tst_fadd_fsub_formal(SmtSortFloat16(), FPRoundingMode.RTON, True)

    def test_fsub_f32_rton_formal(self):
        self.tst_fadd_fsub_formal(SmtSortFloat32(), FPRoundingMode.RTON, True)

    @unittest.skip("too slow")
    def test_fsub_f64_rton_formal(self):
        self.tst_fadd_fsub_formal(SmtSortFloat64(), FPRoundingMode.RTON, True)

    def test_all_rounding_modes_covered(self):
        """Check a test method exists for every width/mode/op combination.

        Uses unittest assertions rather than ``assert`` so the check still
        runs under ``python -O``.  A missing method is reported as a test
        failure naming that method, instead of an AttributeError.
        """
        for width in 16, 32, 64:
            for rm in FPRoundingMode:
                rm_s = rm.name.lower()
                for op in "fadd", "fsub":
                    name = f"test_{op}_f{width}_{rm_s}_formal"
                    with self.subTest(name=name):
                        self.assertTrue(callable(getattr(self, name, None)),
                                        f"missing test method {name}")
231 if __name__
== '__main__':