X-Git-Url: https://git.libre-soc.org/?a=blobdiff_plain;f=src%2Fadd%2Ffsqrt.py;h=a87af61db4ded516df1f93d6cb3af933d6318af3;hb=7541ca979084de96ebdf292e1baa0a03af64d3fc;hp=2171646c51a54e7da9284105698ab18b84fe8349;hpb=d8ce060336180aae35957f375ab93a16787bdb1a;p=ieee754fpu.git diff --git a/src/add/fsqrt.py b/src/add/fsqrt.py index 2171646c..a87af61d 100644 --- a/src/add/fsqrt.py +++ b/src/add/fsqrt.py @@ -23,30 +23,24 @@ def sqrtsimple(num): def sqrt(num): D = num # D is input (from num) Q = 0 - R = 0 - r = 0 # remainder + R = 0 # remainder for i in range(64, -1, -1): # negative ranges are weird... - if (R>=0): - - R = (R<<2)|((D>>(i+i))&3) - R = R-((Q<<2)|1) #/*-Q01*/ + R = (R<<2)|((D>>(i+i))&3) + if R >= 0: + R -= ((Q<<2)|1) # -Q01 else: + R += ((Q<<2)|3) # +Q11 - R = (R<<2)|((D>>(i+i))&3) - R = R+((Q<<2)|3) #/*+Q11*/ - - if (R>=0): - Q = (Q<<1)|1 #/*new Q:*/ - else: - Q = (Q<<1)|0 #/*new Q:*/ + Q <<= 1 + if R >= 0: + Q |= 1 # new Q + if R < 0: + R = R + ((Q<<1)|1) - if (R<0): - R = R+((Q<<1)|1) - r = R - return Q + return Q, R # grabbed these from unit_test_single (convenience, this is just experimenting) @@ -80,9 +74,10 @@ def decode_fp32(x): def main(mantissa, exponent): if exponent & 1 != 0: # shift mantissa up, subtract 1 from exp to compensate - return sqrt(mantissa << 1), (exponent - 1) >> 1 - # mantissa as-is, no compensating needed on exp - return sqrt(mantissa), (exponent >> 1) + mantissa <<= 1 + exponent -= 1 + m, r = sqrt(mantissa) + return m, r, exponent >> 1 def fsqrt_test(x): @@ -97,13 +92,17 @@ def fsqrt_test(x): print("x decode", s, e, m, hex(m)) m |= 1<<23 # set top bit (the missing "1" from mantissa) - m <<= 25 + m <<= 27 - sm, se = main(m, e) - sm >>= 1 + sm, sr, se = main(m, e) + lowbits = sm & 0x3 + sm >>= 2 sm = get_mantissa(sm) #sm += 2 - print("our sqrt", s, se, sm, hex(sm), bin(sm)) + print("our sqrt", s, se, sm, hex(sm), bin(sm), "lowbits", lowbits, + "rem", hex(sr)) + if lowbits >= 2: + print ("probably needs rounding (+1 on mantissa)") sq_xbits = sq_test.bits s, e, m = decode_fp32(sq_xbits) @@ -116,13 +115,13 @@ if __name__ == '__main__': for Q in range(1, int(1e4)): print(Q, sqrt(Q), sqrtsimple(Q), int(Q**0.5)) assert int(Q**0.5) == sqrtsimple(Q), "Q sqrtsimpl fail %d" % Q - assert int(Q**0.5) == sqrt(Q), "Q sqrt fail %d" % Q + assert int(Q**0.5) == sqrt(Q)[0], "Q sqrt fail %d" % Q # quick mantissa/exponent demo for e in range(26): for m in range(26): - ms, es = main(m, e) - print("m:%d e:%d sqrt: m:%d e:%d" % (m, e, ms, es)) + ms, mr, es = main(m, e) + print("m:%d e:%d sqrt: m:%d-%d e:%d" % (m, e, ms, mr, es)) x = Float32(1234.123456789) fsqrt_test(x) @@ -134,6 +133,12 @@ if __name__ == '__main__': fsqrt_test(x) x = Float32(8.5) fsqrt_test(x) + x = Float32(3.14159265358979323) + fsqrt_test(x) + x = Float32(12.99392923123123) + fsqrt_test(x) + x = Float32(0.123456) + fsqrt_test(x) """