get test_mul.py operational
[ieee754fpu.git] / src / ieee754 / fpmul / fmul.py
1 from nmigen import Module, Signal, Cat, Mux, Array, Const
2 from nmigen.cli import main, verilog
3
4 from ieee754.fpcommon.fpbase import (FPNumIn, FPNumOut, FPOpIn,
5 FPOpOut, Overflow, FPBase, FPState)
6 from ieee754.fpcommon.getop import FPGetOp
7 from nmutil.nmoperator import eq
8
9
class FPMUL(FPBase):
    """Multi-cycle floating-point multiplier described as an nmigen FSM.

    Operands are presented on ``in_a`` / ``in_b`` and the result is
    delivered on ``out_z``.  The per-state work is mostly delegated to
    helper methods that this class does not define itself (they are
    reached through ``self`` and so come from the ``FPBase`` base class).
    """

    def __init__(self, width):
        """Create the operand and result ports.

        width -- value forwarded to FPOpIn / FPOpOut and, later, to the
                 FPNumIn / FPNumOut latches built in elaborate().
        """
        FPBase.__init__(self)
        self.width = width

        # two operand inputs, one result output
        self.in_a = FPOpIn(width)
        self.in_b = FPOpIn(width)
        self.out_z = FPOpOut(width)

        # populated only through add_state(); not read within this class
        self.states = []

    def add_state(self, state):
        """Record ``state`` in ``self.states`` and hand it back."""
        self.states.append(state)
        return state

    def elaborate(self, platform=None):
        """Build and return the Module holding the multiplier FSM.

        States, in the order they are declared:
        get_a, get_b, special_cases, normalise_a, normalise_b,
        multiply_0, multiply_1, normalise_1, normalise_2, round,
        corrections, pack, put_z.

        Special-case operands (NaN, infinity, zero) jump straight from
        special_cases to put_z; everything else goes on to normalise_a.
        """
        m = Module()

        # Latches for the two decoded operands and the result
        a = FPNumIn(None, self.width, False)
        b = FPNumIn(None, self.width, False)
        z = FPNumOut(self.width, False)

        # (2 * mantissa) - 1 bits, plus 3 for the guard/round/sticky bits
        product_width = 2 * z.m_width - 1 + 3
        product = Signal(product_width)

        of = Overflow()
        m.submodules.of = of
        m.submodules.a = a
        m.submodules.b = b
        m.submodules.z = z

        # the raw operand values feed the operand latches combinatorially
        m.d.comb += [
            a.v.eq(self.in_a.v),
            b.v.eq(self.in_b.v),
        ]

        # sign of every signed result below: XOR of the operand signs
        result_sign = a.s ^ b.s

        # number of result mantissa bits; used to carve up the product
        mant_width = z.m_width

        with m.FSM():

            # ******
            # fetch operand a, then continue with operand b

            with m.State("get_a"):
                fetched = self.get_op(m, self.in_a, a, "get_b")
                m.d.sync += eq([a, self.in_a.ack], fetched)

            # ******
            # fetch operand b, then check for special cases

            with m.State("get_b"):
                fetched = self.get_op(m, self.in_b, b, "special_cases")
                m.d.sync += eq([b, self.in_b.ack], fetched)

            # ******
            # special cases: each branch below skips directly to put_z,
            # except the final Else which starts the regular data path

            with m.State("special_cases"):
                # either operand NaN: result is NaN
                with m.If(a.is_nan | b.is_nan):
                    m.next = "put_z"
                    m.d.sync += z.nan(1)
                # a infinite: result is infinity ...
                with m.Elif(a.is_inf):
                    m.next = "put_z"
                    m.d.sync += z.inf(result_sign)
                    # ... unless b is zero, in which case NaN
                    # (added after the inf assignment on purpose)
                    with m.If(b.is_zero):
                        m.d.sync += z.nan(1)
                # b infinite: result is infinity ...
                with m.Elif(b.is_inf):
                    m.next = "put_z"
                    m.d.sync += z.inf(result_sign)
                    # ... unless a is zero, in which case NaN
                    with m.If(a.is_zero):
                        m.next = "put_z"
                        m.d.sync += z.nan(1)
                # a zero: result is zero
                with m.Elif(a.is_zero):
                    m.next = "put_z"
                    m.d.sync += z.zero(result_sign)
                # b zero: result is zero
                with m.Elif(b.is_zero):
                    m.next = "put_z"
                    m.d.sync += z.zero(result_sign)
                # no special case: run the denormalised-number checks
                # on both operands and carry on
                with m.Else():
                    m.next = "normalise_a"
                    self.denormalise(m, a)
                    self.denormalise(m, b)

            # ******
            # normalise operand a

            with m.State("normalise_a"):
                self.op_normalise(m, a, "normalise_b")

            # ******
            # normalise operand b

            with m.State("normalise_b"):
                self.op_normalise(m, b, "multiply_0")

            # ******
            # multiply_0: latch sign, exponent sum (+1) and the mantissa
            # product multiplied by 4

            with m.State("multiply_0"):
                m.next = "multiply_1"
                m.d.sync += [
                    z.s.eq(result_sign),
                    z.e.eq(a.e + b.e + 1),
                    product.eq(a.m * b.m * 4),
                ]

            # ******
            # multiply_1: split the product into the result mantissa
            # (upper bits) and the guard / round / sticky information
            # (the bits below it)

            with m.State("multiply_1"):
                m.next = "normalise_1"
                m.d.sync += [
                    z.m.eq(product[mant_width + 2:]),
                    of.guard.eq(product[mant_width + 1]),
                    of.round_bit.eq(product[mant_width]),
                    of.sticky.eq(product[:mant_width] != 0),
                ]

            # ******
            # first stage of normalisation

            with m.State("normalise_1"):
                self.normalise_1(m, z, of, "normalise_2")

            # ******
            # second stage of normalisation

            with m.State("normalise_2"):
                self.normalise_2(m, z, of, "round")

            # ******
            # rounding stage

            with m.State("round"):
                self.roundz(m, z, of.roundz)
                m.next = "corrections"

            # ******
            # correction stage

            with m.State("corrections"):
                self.corrections(m, z, "pack")

            # ******
            # pack stage

            with m.State("pack"):
                self.pack(m, z, "put_z")

            # ******
            # put_z stage: deliver the result, then start over at get_a

            with m.State("put_z"):
                self.put_z(m, z, self.out_z, "get_a")

        return m
169
170
if __name__ == "__main__":
    # Script entry point: build an FPMUL with width=32 and pass it to the
    # nmigen CLI together with the ports of both operands and the result.
    alu = FPMUL(width=32)
    port_list = alu.in_a.ports() + alu.in_b.ports() + alu.out_z.ports()
    main(alu, ports=port_list)