1 # SPDX-License-Identifier: LGPL-2.1-or-later
2 # See Notices.txt for copyright information
5 Copyright (C) 2020 Luke Kenneth Casson Leighton <lkcl@lkcl.net>
7 dynamically-partitionable "comparison" class, directly equivalent
8 to Signal.__eq__ except SIMD-partitionable
12 * http://libre-riscv.org/3d_gpu/architecture/dynamic_simd/eq
13 * http://bugs.libre-riscv.org/show_bug.cgi?id=132
16 from nmigen
import Signal
, Module
, Elaboratable
, Cat
, C
, Mux
, Repl
17 from nmigen
.cli
import main
19 from ieee754
.part_mul_add
.partpoints
import PartitionPoints
class PartitionedEq(Elaboratable):
    """Dynamically-partitionable equality comparator.

    The inputs ``a`` and ``b`` are cut into ``mwidth`` chunks at the
    partition points.  At run time the partition-point signals decide how
    adjacent chunks are grouped into lanes.  ``output`` carries one bit per
    chunk: the bit at the first chunk of a lane is 1 when every chunk of
    that lane compares equal.
    """

    def __init__(self, width, partition_points):
        """Create a ``PartitionedEq`` operator.

        :param width: bit width of the ``a`` and ``b`` inputs.
        :param partition_points: partition description, handed unchanged to
            ``PartitionPoints``.
        :raises ValueError: if ``fits_in_width(width)`` reports that the
            partition points do not fit.
        """
        # elaborate() reads self.width as the end of the last chunk
        self.width = width
        self.a = Signal(width, reset_less=True)
        self.b = Signal(width, reset_less=True)
        self.partition_points = PartitionPoints(partition_points)
        # N partition points split the word into N+1 chunks
        self.mwidth = len(self.partition_points) + 1
        self.output = Signal(self.mwidth, reset_less=True)
        if not self.partition_points.fits_in_width(width):
            raise ValueError("partition_points doesn't fit in width")

    @staticmethod
    def _partition_lengths(pval, nchunks):
        """Return, per chunk, the length of the lane that starts there.

        Bit ``k`` of ``pval`` set means a lane boundary sits between chunk
        ``k`` and chunk ``k+1``.  Entry ``s`` of the result is the number of
        chunks in the lane starting at chunk ``s``, or 0 when chunk ``s``
        does not start a lane.  Example: ``pval=0b100, nchunks=4`` gives
        ``[3, 0, 0, 1]``.
        """
        lengths = [0] * nchunks
        if nchunks == 0:
            return lengths
        start = 0
        lengths[start] = 1  # chunk 0 always opens the first lane
        for ipdx in range(nchunks - 1):
            if pval & (1 << ipdx):
                # boundary after chunk ipdx: a new lane opens at ipdx+1
                start = ipdx + 1
                lengths[start] = 1
            else:
                # no boundary: chunk ipdx+1 extends the current lane
                lengths[start] += 1
        return lengths

    def elaborate(self, platform):
        """Build the combinatorial comparison logic and return the module.

        :param platform: not used by this method.
        """
        m = Module()
        comb = m.d.comb
        nchunks = self.mwidth

        # make a series of "eqs", splitting a and b into partition chunks.
        # each bit is a per-chunk *not-equal* flag (inverted further down).
        # NOTE(review): assumes partition_points.keys() yields the chunk
        # boundaries in ascending order - confirm against PartitionPoints.
        eqs = Signal(nchunks, reset_less=True)
        eql = []
        start = 0
        for end in list(self.partition_points.keys()) + [self.width]:
            eql.append(self.a[start:end] != self.b[start:end])
            start = end  # next chunk begins where this one stopped
        comb += eqs.eq(Cat(*eql))

        # now, based on the partition points, create the (multi-)boolean
        # result.  this enumerates every partition-point combination, which
        # is laborious but works; optimisations come later.
        eqsigs = [Signal(name="eqsig%d" % i, reset_less=True)
                  for i in range(nchunks)]

        ppoints = Signal(nchunks - 1)
        comb += ppoints.eq(self.partition_points.as_sig())

        for pval in range(1 << (nchunks - 1)):  # each partition setting
            cpv = C(pval, nchunks - 1)
            with m.If(ppoints == cpv):
                # identify (find-first) lane starts and how long each is
                idx = self._partition_lengths(pval, nchunks)
                print (pval, bin(pval), idx)
                for i in range(nchunks):
                    name = "andsig_%d_%d" % (pval, i)
                    if idx[i]:
                        # lane starts here: gather exactly its own chunks.
                        # (idx[i], not the stale find-first "start" index,
                        # which only matched the last lane.)
                        ands = eqs[i:i+idx[i]]
                        andsig = Signal(len(ands), name=name, reset_less=True)
                        # bool() is 1 if any not-equal flag is set; the
                        # inversion below makes that "all chunks equal"
                        ands = ands.bool()
                        print ("ands", pval, i, ands)
                    else:
                        andsig = Signal(name=name, reset_less=True)
                        # NOTE(review): value for chunks that do not start
                        # a lane is assumed to be "not equal" so that the
                        # output bit reads 0 - confirm the intended value.
                        ands = C(1)
                    comb += andsig.eq(ands)
                    comb += eqsigs[i].eq(~andsig)

        print ("eqsigs", eqsigs, self.output.shape())

        # assign cascade-SIMD-compares to output
        comb += self.output.eq(Cat(*eqsigs))

        return m