X-Git-Url: https://git.libre-soc.org/?a=blobdiff_plain;f=src%2Fadd%2Ffmul.py;h=130d49e814d05028f5248206c4e87611e304a35b;hb=286fdefc4bbe8c7b4bb34ae33b513e8bb81b3d7e;hp=4d54c98dffb813f2ab4999576cd94d89f1be863e;hpb=e1db3d985344622c4782ec9c4f126995e2b795de;p=ieee754fpu.git diff --git a/src/add/fmul.py b/src/add/fmul.py index 4d54c98d..130d49e8 100644 --- a/src/add/fmul.py +++ b/src/add/fmul.py @@ -1,8 +1,8 @@ -from nmigen import Module, Signal +from nmigen import Module, Signal, Cat, Mux, Array, Const from nmigen.cli import main, verilog -from fpbase import FPNum, FPOp, Overflow, FPBase - +from fpbase import FPNumIn, FPNumOut, FPOp, Overflow, FPBase +from nmigen_add_experiment import FPState class FPMUL(FPBase): @@ -20,77 +20,67 @@ class FPMUL(FPBase): m = Module() # Latches - a = FPNum(self.width) - b = FPNum(self.width) - z = FPNum(self.width, False) + a = FPNumIn(None, self.width, False) + b = FPNumIn(None, self.width, False) + z = FPNumOut(self.width, False) - tot = Signal(28) # sticky/round/guard bits, 23 result, 1 overflow + mw = (z.m_width)*2 - 1 + 3 # sticky/round/guard bits + (2*mant) - 1 + product = Signal(mw) of = Overflow() + m.submodules.of = of + m.submodules.a = a + m.submodules.b = b + m.submodules.z = z with m.FSM() as fsm: - with m.State("get_a"): - m.next += "get_b" - m.d.sync += s.in_a.ack.eq(1) - with m.If(s.in_a.ack & in_a.stb): - m.d.sync += [ - a.eq(in_a), - s.in_a.ack(0) - ] - - with m.State("get_b"): - m.next += "unpack" - m.d.sync += s.in_b.ack.eq(1) - with m.If(s.in_b.ack & in_b.stb): - m.d.sync += [ - b.eq(in_b), - s.in_b.ack(0) - ] - - with m.State("unpack"): - m.next += "special_cases" - m.d.sync += [ - a.m.eq(a[0:22]), - b.m.eq(b[0:22]), - a.e.eq(a[23:31] - 127), - b.e.eq(b[23:31] - 127), - a.s.eq(a[31]), - b.s.eq(b[31]) - ] - - with m.State("special_cases"): - m.next = "normalise_a" - #if a or b is NaN return NaN - with m.If(a.is_nan() | b.is_nan()): - m.next += "put_z" - m.d.sync += z.nan(1) - #if a is inf return inf - with m.Elif(a.is_inf()): - m.next += "put_z" - m.d.sync += z.inf(0) - #if b is zero return NaN - with m.If(b.is_zero()): - m.d.sync += z.nan(1) - #if b is inf return inf - with m.Elif(b.is_inf()): - m.next += "put_z" - m.d.sync += z.inf(0) - #if a is zero return NaN - with m.If(a.is_zero()): - m.next += "put_z" - m.d.sync += z.nan(1) - #if a is zero return zero - with m.Elif(a.is_zero()): - m.next += "put_z" - m.d.sync += z.zero(0) - #if b is zero return zero - with m.Elif(b.is_zero()): - m.next += "put_z" - m.d.sync += z.zero(0) + # ****** + # gets operand a + + with m.State("get_a"): + self.get_op(m, self.in_a, a, "get_b") + + # ****** + # gets operand b + + with m.State("get_b"): + self.get_op(m, self.in_b, b, "special_cases") + + # ****** + # special cases + + with m.State("special_cases"): + #if a or b is NaN return NaN + with m.If(a.is_nan | b.is_nan): + m.next = "put_z" + m.d.sync += z.nan(1) + #if a is inf return inf + with m.Elif(a.is_inf): + m.next = "put_z" + m.d.sync += z.inf(a.s ^ b.s) + #if b is zero return NaN + with m.If(b.is_zero): + m.d.sync += z.nan(1) + #if b is inf return inf + with m.Elif(b.is_inf): + m.next = "put_z" + m.d.sync += z.inf(a.s ^ b.s) + #if a is zero return NaN + with m.If(a.is_zero): + m.next = "put_z" + m.d.sync += z.nan(1) + #if a is zero return zero + with m.Elif(a.is_zero): + m.next = "put_z" + m.d.sync += z.zero(a.s ^ b.s) + #if b is zero return zero + with m.Elif(b.is_zero): + m.next = "put_z" + m.d.sync += z.zero(a.s ^ b.s) # Denormalised Number checks - with m.Else(): - m.next = "normalise_a" + with m.Else(): + m.next = "normalise_a" self.denormalise(m, a) self.denormalise(m, b) @@ -106,173 +96,64 @@ class FPMUL(FPBase): with m.State("normalise_b"): self.op_normalise(m, b, "multiply_0") + #multiply_0 + with m.State("multiply_0"): + m.next = "multiply_1" + m.d.sync += [ + z.s.eq(a.s ^ b.s), + z.e.eq(a.e + b.e + 1), + product.eq(a.m * b.m * 4) + ] + + #multiply_1 + with m.State("multiply_1"): + mw = z.m_width + m.next = "normalise_1" + m.d.sync += [ + z.m.eq(product[mw+2:]), + of.guard.eq(product[mw+1]), + of.round_bit.eq(product[mw]), + of.sticky.eq(product[0:mw] != 0) + ] + + # ****** + # First stage of normalisation. + with m.State("normalise_1"): + self.normalise_1(m, z, of, "normalise_2") + + # ****** + # Second stage of normalisation. + + with m.State("normalise_2"): + self.normalise_2(m, z, of, "round") + + # ****** + # rounding stage + + with m.State("round"): + self.roundz(m, z, of.roundz) + m.next = "corrections" + + # ****** + # correction stage + + with m.State("corrections"): + self.corrections(m, z, "pack") + + # ****** + # pack stage + with m.State("pack"): + self.pack(m, z, "put_z") + + # ****** + # put_z stage + + with m.State("put_z"): + self.put_z(m, z, self.out_z, "get_a") + + return m -""" -special_cases: - begin - //if a is NaN or b is NaN return NaN - if ((a_e == 128 && a_m != 0) || (b_e == 128 && b_m != 0)) begin - z[31] <= 1; - z[30:23] <= 255; - z[22] <= 1; - z[21:0] <= 0; - state <= put_z; - //if a is inf return inf - end else if (a_e == 128) begin - z[31] <= a_s ^ b_s; - z[30:23] <= 255; - z[22:0] <= 0; - //if b is zero return NaN - if (($signed(b_e) == -127) && (b_m == 0)) begin - z[31] <= 1; - z[30:23] <= 255; - z[22] <= 1; - z[21:0] <= 0; - end - state <= put_z; - //if b is inf return inf - end else if (b_e == 128) begin - z[31] <= a_s ^ b_s; - z[30:23] <= 255; - z[22:0] <= 0; - //if a is zero return NaN - if (($signed(a_e) == -127) && (a_m == 0)) begin - z[31] <= 1; - z[30:23] <= 255; - z[22] <= 1; - z[21:0] <= 0; - end - state <= put_z; - //if a is zero return zero - end else if (($signed(a_e) == -127) && (a_m == 0)) begin - z[31] <= a_s ^ b_s; - z[30:23] <= 0; - z[22:0] <= 0; - state <= put_z; - //if b is zero return zero - end else if (($signed(b_e) == -127) && (b_m == 0)) begin - z[31] <= a_s ^ b_s; - z[30:23] <= 0; - z[22:0] <= 0; - state <= put_z; - //^ done up to here - end else begin - //Denormalised Number - if ($signed(a_e) == -127) begin - a_e <= -126; - end else begin - a_m[23] <= 1; - end - //Denormalised Number - if ($signed(b_e) == -127) begin - b_e <= -126; - end else begin - b_m[23] <= 1; - end - state <= normalise_a; - end - end - - normalise_a: - begin - if (a_m[23]) begin - state <= normalise_b; - end else begin - a_m <= a_m << 1; - a_e <= a_e - 1; - end - end - - normalise_b: - begin - if (b_m[23]) begin - state <= multiply_0; - end else begin - b_m <= b_m << 1; - b_e <= b_e - 1; - end - end - - multiply_0: - begin - z_s <= a_s ^ b_s; - z_e <= a_e + b_e + 1; - product <= a_m * b_m * 4; - state <= multiply_1; - end - - multiply_1: - begin - z_m <= product[49:26]; - guard <= product[25]; - round_bit <= product[24]; - sticky <= (product[23:0] != 0); - state <= normalise_1; - end - - normalise_1: - begin - if (z_m[23] == 0) begin - z_e <= z_e - 1; - z_m <= z_m << 1; - z_m[0] <= guard; - guard <= round_bit; - round_bit <= 0; - end else begin - state <= normalise_2; - end - end - - normalise_2: - begin - if ($signed(z_e) < -126) begin - z_e <= z_e + 1; - z_m <= z_m >> 1; - guard <= z_m[0]; - round_bit <= guard; - sticky <= sticky | round_bit; - end else begin - state <= round; - end - end - - round: - begin - if (guard && (round_bit | sticky | z_m[0])) begin - z_m <= z_m + 1; - if (z_m == 24'hffffff) begin - z_e <=z_e + 1; - end - end - state <= pack; - end - - pack: - begin - z[22 : 0] <= z_m[22:0]; - z[30 : 23] <= z_e[7:0] + 127; - z[31] <= z_s; - if ($signed(z_e) == -126 && z_m[23] == 0) begin - z[30 : 23] <= 0; - end - //if overflow occurs, return inf - if ($signed(z_e) > 127) begin - z[22 : 0] <= 0; - z[30 : 23] <= 255; - z[31] <= z_s; - end - state <= put_z; - end - - put_z: - begin - s_output_z_stb <= 1; - s_output_z <= z; - if (s_output_z_stb && output_z_ack) begin - s_output_z_stb <= 0; - state <= get_a; - end -end - -""" +if __name__ == "__main__": + alu = FPMUL(width=32) + main(alu, ports=alu.in_a.ports() + alu.in_b.ports() + alu.out_z.ports())