fix up Logical pipeline to produce HDL with XLEN=32
[soc.git] / src / soc / fu / logical / popcount.py
1 """Popcount: a successive (cascading) sum-reduction algorithm for counting bits
2
3 starting from single-bit adds and reducing down to one final answer:
4 the total number of bits set to "1" in the input.
5
6 unfortunately there is a bit of a "trick" going on which you have to
7 watch out for: whilst the first list added to pc is a single entry (the
8 input, a), subsequent rows in the cascade are arrays of partial results,
9 yet it turns out that referring to them using the exact same start/end
10 slicing is perfect. this comes down to nmigen's transparent use of
11 python features to make Signals iterable.
12 """
13
14 from nmigen import (Elaboratable, Module, Signal, Cat, Const)
15
16
def array_of(count, bitwidth):
    """Build one row of the popcount reduction tree.

    Returns a list of ``count`` reset-less Signals, each ``bitwidth`` bits
    wide, named ``pop_<bitwidth>_<index>``.
    """
    return [Signal(bitwidth, reset_less=True, name=f"pop_{bitwidth}_{idx}")
            for idx in range(count)]
23
24
class Popcount(Elaboratable):
    """Combinatorial popcount: per-byte, per-word or whole-register bit counts.

    Signals:
      a        -- input whose "1" bits are counted
      b        -- kept for interface compatibility; not used by elaborate()
      data_len -- operation selector decoded in elaborate():
                    1          -> popcntb (one count per byte)
                    4          -> popcntw (one count per 32-bit word)
                    otherwise  -> popcntd (one count of the whole input)
      o        -- result; each count is zero-extended into its output field
    """

    def __init__(self, width=64):
        assert width in [32, 64], "only 32 or 64 bit supported for now"
        self.width = width
        self.a = Signal(width, reset_less=True)
        self.b = Signal(width, reset_less=True)
        # operand-length selector; elaborate() decodes 1 and 4 explicitly and
        # treats every other value as popcntd (presumably 8 -- confirm)
        self.data_len = Signal(4, reset_less=True)
        self.o = Signal(width, reset_less=True)

    def elaborate(self, platform):
        m = Module()
        comb = m.d.comb
        a, data_len, o = self.a, self.data_len, self.o

        # Starting from a, perform successive addition-reductions.  Row k
        # (1-based) holds width>>k partial counts, each covering 2**k input
        # bits.  The largest possible value is 2**k, so k+1 bits suffice.
        # width=64 -> (32, 2), (16, 3), (8, 4), (4, 5), (2, 6), (1, 7)
        # width=32 -> (16, 2), (8, 3), (4, 4), (2, 5), (1, 6)
        levels = self.width.bit_length() - 1   # log2(width)
        work = [(self.width >> k, k + 1) for k in range(1, levels + 1)]

        pc = [a]
        for l, bw in work:  # l=number of add-reductions, bw=bitwidth
            pc.append(array_of(l, bw))

        # Index rows from the front: pc[k] always covers 2**k input bits.
        # Negative indices shift with width; at width=32 they selected the
        # nibble row for popcntb instead of the byte row.
        pc8 = pc[3]      # width//8 4-bit counts, one per byte (popcntb)
        popcnt = pc[-1]  # single count of the whole input (popcntd)

        # cascade-tree of adds: dst[i] = src[2i] + src[2i+1]
        for idx, (l, bw) in enumerate(work):
            src, dst = pc[idx], pc[idx + 1]
            for i in range(l):
                stt, end = i * 2, i * 2 + 1
                comb += dst[i].eq(Cat(src[stt], Const(0, 1)) +
                                  Cat(src[end], Const(0, 1)))

        # decode operation length
        with m.If(data_len == 1):
            # popcntb - pack width//8 4-bit answers into 8-bit output fields
            for i in range(self.width // 8):
                comb += o[i * 8:(i + 1) * 8].eq(pc8[i])
        with m.Elif(data_len == 4):
            if self.width == 64:
                pc32 = pc[5]  # 2 6-bit counts, one per 32-bit word
                # popcntw - pack 2x 6-bit answers into 2x 32-bit output fields
                for i in range(2):
                    comb += o[i * 32:(i + 1) * 32].eq(pc32[i])
            else:
                # width=32: one word spans the whole input, same as popcntd
                comb += o.eq(popcnt[0])
        with m.Else():
            # popcntd - zero-extend the single total count into o
            comb += o.eq(popcnt[0])

        return m