working partitioned eqs function
[ieee754fpu.git] / src / ieee754 / part_cmp / equal.py
1 # SPDX-License-Identifier: LGPL-2.1-or-later
2 # See Notices.txt for copyright information
3
4 """
5 Copyright (C) 2020 Luke Kenneth Casson Leighton <lkcl@lkcl.net>
6
7 dynamically-partitionable "comparison" class, directly equivalent
8 to Signal.__eq__ except SIMD-partitionable
9
10 See:
11
12 * http://libre-riscv.org/3d_gpu/architecture/dynamic_simd/eq
13 * http://bugs.libre-riscv.org/show_bug.cgi?id=132
14 """
15
16 from nmigen import Signal, Module, Elaboratable, Cat, C, Mux, Repl
17 from nmigen.cli import main
18
19 from ieee754.part_mul_add.partpoints import PartitionPoints
20
class PartitionedEq(Elaboratable):
    """Dynamically-partitionable equality comparator.

    ``a`` and ``b`` are cut into chunks at the partition points.  The
    run-time value of the partition-point signals selects which adjacent
    chunks are merged into one partition.  For each partition, the output
    bit at the index of its lowest chunk is 1 when every chunk of the
    partition is equal in ``a`` and ``b``; all other output bits are 0.

    Attributes:
        width: total bit-width of the operands.
        a, b: the two ``width``-bit operands to compare.
        partition_points: ``PartitionPoints`` built from the constructor
            argument.
        mwidth: number of chunks (number of partition points plus one).
        output: ``mwidth``-bit result, one bit per chunk position.
    """

    def __init__(self, width, partition_points):
        """Create a ``PartitionedEq`` operator.

        :param width: bit-width of the ``a`` and ``b`` operands.
        :param partition_points: passed to ``PartitionPoints``; describes
            the bit positions at which the operands may be split.
        :raises ValueError: if ``fits_in_width(width)`` reports that the
            partition points do not fit in ``width``.
        """
        self.width = width
        self.a = Signal(width, reset_less=True)
        self.b = Signal(width, reset_less=True)
        self.partition_points = PartitionPoints(partition_points)
        self.mwidth = len(self.partition_points)+1
        self.output = Signal(self.mwidth, reset_less=True)
        if not self.partition_points.fits_in_width(width):
            raise ValueError("partition_points doesn't fit in width")

    @staticmethod
    def _partition_lengths(pval, mwidth):
        """Return, per chunk index, the length of the partition starting there.

        Bit ``n`` of ``pval`` being set means there is a break between chunk
        ``n`` and chunk ``n+1``.  The returned list has ``mwidth`` entries:
        the entry at the first chunk of each partition holds the number of
        chunks in that partition, every other entry is 0.
        For example ``pval=0b010, mwidth=4`` gives ``[2, 0, 2, 0]``.
        """
        lengths = [0] * mwidth
        start = 0
        count = 1
        for ipdx in range(mwidth-1):
            if pval & (1 << ipdx):
                # break found: close the current partition, open a new one
                lengths[start] = count
                start = ipdx + 1
                count = 1
            else:
                count += 1
        lengths[start] = count  # update last point (or create it)
        return lengths

    def elaborate(self, platform):
        """Build the combinatorial comparison logic and return the Module.

        :param platform: unused here.
        :returns: the ``Module`` driving ``self.output``.
        """
        m = Module()
        comb = m.d.comb

        # Per-chunk "not equal" bits: bit n of eqs is 1 when chunk n of a
        # differs from chunk n of b.  The inversion is deliberate: a
        # partition is equal exactly when none of its chunks differ, so
        # the result below is computed as NOT(any chunk differs).
        # NOTE(review): assumes partition_points.keys() is in ascending
        # bit order -- confirm against PartitionPoints.
        eqs = Signal(self.mwidth, reset_less=True)
        ends = list(self.partition_points.keys()) + [self.width]
        eql = []
        start = 0
        for end in ends:
            eql.append(self.a[start:end] != self.b[start:end])
            start = end  # for next time round loop
        comb += eqs.eq(Cat(*eql))

        # now, based on the partition points, create the (multi-)boolean
        # result.  this enumerates every possible partition-point pattern,
        # which is laborious but works; optimisations come later.
        eqsigs = [Signal(name="eqsig%d" % i, reset_less=True)
                  for i in range(self.mwidth)]

        # NOTE(review): bit n of ppoints is assumed to correspond to the
        # n-th partition point; as_sig() is defined outside this file.
        ppoints = Signal(self.mwidth-1)
        comb += ppoints.eq(self.partition_points.as_sig())

        for pval in range(1 << (self.mwidth-1)):  # each partition pattern
            cpv = C(pval, self.mwidth-1)
            with m.If(ppoints == cpv):
                # identify (find-first) transition points, and how long
                # each partition is
                idx = self._partition_lengths(pval, self.mwidth)

                print (pval, bin(pval), idx)  # debug trace
                for i in range(self.mwidth):
                    name = "andsig_%d_%d" % (pval, i)
                    # Look up the length for *this* output bit.  Indexing
                    # with the scan's final start position would give
                    # every bit the last partition's length and make the
                    # else branch unreachable.
                    if idx[i]:
                        # chunk i starts a partition of idx[i] chunks:
                        # OR together their "not equal" bits
                        ands = eqs[i:i+idx[i]]
                        # declared as wide as the slice; only bit 0
                        # carries the boolean
                        andsig = Signal(len(ands), name=name,
                                        reset_less=True)
                        ands = ands.bool()  # 1 if any chunk differs
                        print ("ands", pval, i, ands)  # debug trace
                    else:
                        # chunk i is inside some other partition: force
                        # "differs" so that its output bit reads 0
                        andsig = Signal(name=name, reset_less=True)
                        ands = C(1)
                    comb += andsig.eq(ands)
                    comb += eqsigs[i].eq(~andsig)

        print ("eqsigs", eqsigs, self.output.shape())  # debug trace

        # assign cascade-SIMD-compares to output
        comb += self.output.eq(Cat(*eqsigs))

        return m