Fix bpermd and make tests pass
[soc.git] / src / soc / fu / logical / main_stage.py
# This stage is intended to do most of the work of executing Logical
# instructions: OR, AND, XOR, POPCNT, PRTY, CMPB, BPERMD and CNTLZ/CNTTZ.
# The input and output stages, however, also perform bit-negation on the
# input(s) and output, as well as carry and overflow generation.
# This module should not gate the carry or overflow; that is up to the
# output stage.
7
8 from nmigen import (Module, Signal, Cat, Repl, Mux, Const, Array)
9 from nmutil.pipemodbase import PipeModBase
10 from nmutil.clz import CLZ
11 from soc.fu.logical.pipe_data import LogicalInputData
12 from soc.fu.logical.bpermd import Bpermd
13 from soc.fu.logical.pipe_data import LogicalOutputData
14 from ieee754.part.partsig import PartitionedSignal
15 from soc.decoder.power_enums import InternalOp
16
17 from soc.decoder.power_fields import DecodeFields
18 from soc.decoder.power_fieldsn import SignalBitRange
19
20
def array_of(count, bitwidth):
    """Build a list of ``count`` reset-less Signals, each ``bitwidth`` wide.

    Used to hold one level of partial sums in the popcount adder tree.
    Each signal is named ``pop_<bitwidth>_<index>``.
    """
    return [Signal(bitwidth, reset_less=True, name=f"pop_{bitwidth}_{idx}")
            for idx in range(count)]
27
28
class LogicalMainStage(PipeModBase):
    """Main (combinatorial) stage of the Logical pipeline.

    Computes the result of AND/OR/XOR, CMPB, POPCNT, PRTY, CNTZ (count
    leading/trailing zeros) and BPERMD, selected by
    ``self.i.ctx.op.insn_type``.  Bit-negation of inputs/outputs and
    carry/overflow handling live in the neighbouring input and output
    stages, not here.
    """

    def __init__(self, pspec):
        super().__init__(pspec, "main")
        # decoded instruction bit-fields; used by CNTZ to select between
        # counting leading and trailing zeros
        self.fields = DecodeFields(SignalBitRange, [self.i.ctx.op.insn])
        self.fields.create_specs()

    def ispec(self):
        """Return the input data specification for this stage."""
        return LogicalInputData(self.pspec)

    def ospec(self):
        """Return the output data specification for this stage."""
        return LogicalOutputData(self.pspec)

    def elaborate(self, platform):
        """Build the combinatorial logic for all Logical operations.

        ``o.ok`` is driven high for every recognised ``insn_type`` and low
        (via the Default case) otherwise.  The context is passed through
        unchanged.
        """
        m = Module()
        comb = m.d.comb
        op, a, b, o = self.i.ctx.op, self.i.a, self.i.b, self.o.o

        comb += o.ok.eq(1)  # overridden if no op activates

        m.submodules.bpermd = bpermd = Bpermd(64)

        ##########################
        # main switch for logic ops AND, OR and XOR, cmpb, parity, and popcount

        with m.Switch(op.insn_type):

            ###### AND, OR, XOR #######
            with m.Case(InternalOp.OP_AND):
                comb += o.data.eq(a & b)
            with m.Case(InternalOp.OP_OR):
                comb += o.data.eq(a | b)
            with m.Case(InternalOp.OP_XOR):
                comb += o.data.eq(a ^ b)

            ###### cmpb #######
            with m.Case(InternalOp.OP_CMPB):
                # each result byte is all-ones where the bytes of a and b
                # are equal, all-zeros otherwise
                byte_masks = []
                for i in range(8):
                    slc = slice(i*8, (i+1)*8)
                    byte_masks.append(Repl(a[slc] == b[slc], 8))
                comb += o.data.eq(Cat(*byte_masks))

            ###### popcount #######
            with m.Case(InternalOp.OP_POPCNT):
                # starting from a, perform successive addition-reductions
                # creating arrays big enough to store the sum, each time.
                # NOTE: these Python loops run at elaboration time for every
                # Case, so their variables must not reuse the names of the
                # operands (a previous "for l, b in work" clobbered b).
                pc = [a]
                # (count, width): QTY32 2-bit (to take 2x 1-bit sums) etc.
                work = [(32, 2), (16, 3), (8, 4), (4, 5), (2, 6), (1, 7)]
                for count, width in work:
                    pc.append(array_of(count, width))
                pc8 = pc[3]      # array of 8 4-bit counts (popcntb)
                pc32 = pc[5]     # array of 2 6-bit counts (popcntw)
                popcnt = pc[-1]  # array of 1 7-bit count (popcntd)
                # cascade-tree of adds: each level sums adjacent pairs of
                # the level below, zero-extending each addend by one bit
                for idx, (count, _) in enumerate(work):
                    src, dst = pc[idx], pc[idx+1]
                    for i in range(count):
                        stt, end = i*2, i*2+1
                        comb += dst[i].eq(Cat(src[stt], Const(0, 1)) +
                                          Cat(src[end], Const(0, 1)))
                # decode operation length
                with m.If(op.data_len == 1):
                    # popcntb - pack 8x 4-bit answers into output
                    for i in range(8):
                        comb += o.data[i*8:(i+1)*8].eq(pc8[i])
                with m.Elif(op.data_len == 4):
                    # popcntw - pack 2x 6-bit answers into output
                    for i in range(2):
                        comb += o.data[i*32:(i+1)*32].eq(pc32[i])
                with m.Else():
                    # popcntd - put 1x 7-bit answer into output
                    comb += o.data.eq(popcnt[0])

            ###### parity #######
            with m.Case(InternalOp.OP_PRTY):
                # strange instruction which XORs together the LSBs of each byte
                par0 = Signal(reset_less=True)
                par1 = Signal(reset_less=True)
                comb += par0.eq(Cat(a[0], a[8], a[16], a[24]).xor())
                comb += par1.eq(Cat(a[32], a[40], a[48], a[56]).xor())
                with m.If(op.data_len[3] == 1):
                    # data_len bit 3 set (presumably prtyd): one parity bit
                    # over all eight byte LSBs
                    comb += o.data.eq(par0 ^ par1)
                with m.Else():
                    # otherwise (presumably prtyw): one parity bit per word
                    comb += o.data[0].eq(par0)
                    comb += o.data[32].eq(par1)

            ###### cntlz #######
            with m.Case(InternalOp.OP_CNTZ):
                # NOTE(review): count_right is taken from the XO bit that is
                # expected to distinguish cnttz* from cntlz* - confirm
                # against SignalBitRange's bit-ordering.
                XO = self.fields.FormX.XO[0:-1]
                count_right = Signal(reset_less=True)
                comb += count_right.eq(XO[-1])

                cntz_i = Signal(64, reset_less=True)
                a32 = Signal(32, reset_less=True)
                comb += a32.eq(a[0:32])

                # trailing zeros are counted as the leading zeros of the
                # bit-reversed input
                with m.If(op.is_32bit):
                    comb += cntz_i.eq(Mux(count_right, a32[::-1], a32))
                with m.Else():
                    comb += cntz_i.eq(Mux(count_right, a[::-1], a))

                m.submodules.clz = clz = CLZ(64)
                comb += clz.sig_in.eq(cntz_i)
                # a 32-bit operand is zero-extended into the 64-bit CLZ,
                # which adds 32 leading zeros that must be removed again
                comb += o.data.eq(Mux(op.is_32bit, clz.lz-32, clz.lz))

            ###### bpermd #######
            with m.Case(InternalOp.OP_BPERM):
                comb += bpermd.rs.eq(a)
                comb += bpermd.rb.eq(b)
                comb += o.data.eq(bpermd.ra)

            with m.Default():
                comb += o.ok.eq(0)

        ###### context, pass-through #####

        comb += self.o.ctx.eq(self.i.ctx)

        return m