fix up Logical pipeline to produce HDL with XLEN=32
[soc.git] / src / soc / fu / logical / main_stage.py
index bb6efaf290417de72264210e60d8b06846263524..6a90395783e798165bd55576c769c9bf73144952 100644 (file)
@@ -5,23 +5,21 @@
 # This module however should not gate the carry or overflow, that's up
 # to the output stage
 
+# Copyright (C) 2020 Michael Nolan <mtnolan2640@gmail.com>
+# Copyright (C) 2020, 2021 Luke Kenneth Casson Leighton <lkcl@lkcl.net>
+
 from nmigen import (Module, Signal, Cat, Repl, Mux, Const, Array)
 from nmutil.pipemodbase import PipeModBase
-from soc.fu.logical.pipe_data import ALUInputData
-from soc.fu.alu.pipe_data import ALUOutputData
-from ieee754.part.partsig import PartitionedSignal
-from soc.decoder.power_enums import InternalOp
-from soc.fu.logical.countzero import ZeroCounter
-
-from soc.decoder.power_fields import DecodeFields
-from soc.decoder.power_fieldsn import SignalBitRange
-
+from nmutil.clz import CLZ
+from soc.fu.logical.pipe_data import LogicalInputData
+from soc.fu.logical.bpermd import Bpermd
+from soc.fu.logical.popcount import Popcount
+from soc.fu.logical.pipe_data import LogicalOutputData
+from ieee754.part.partsig import SimdSignal
+from openpower.decoder.power_enums import MicrOp
 
-def array_of(count, bitwidth):
-    res = []
-    for i in range(count):
-        res.append(Signal(bitwidth, reset_less=True))
-    return res
+from openpower.decoder.power_fields import DecodeFields
+from openpower.decoder.power_fieldsn import SignalBitRange
 
 
 class LogicalMainStage(PipeModBase):
@@ -31,97 +29,108 @@ class LogicalMainStage(PipeModBase):
         self.fields.create_specs()
 
     def ispec(self):
-        return ALUInputData(self.pspec)
+        return LogicalInputData(self.pspec)
 
     def ospec(self):
-        return ALUOutputData(self.pspec) # TODO: ALUIntermediateData
+        return LogicalOutputData(self.pspec)
 
     def elaborate(self, platform):
+        XLEN = self.pspec.XLEN
         m = Module()
         comb = m.d.comb
         op, a, b, o = self.i.ctx.op, self.i.a, self.i.b, self.o.o
 
+        comb += o.ok.eq(1) # overridden if no op activates
+
+        m.submodules.bpermd = bpermd = Bpermd(XLEN)
+        m.submodules.popcount = popcount = Popcount(XLEN)
+
         ##########################
         # main switch for logic ops AND, OR and XOR, cmpb, parity, and popcount
 
         with m.Switch(op.insn_type):
 
-            ###### AND, OR, XOR #######
-            with m.Case(InternalOp.OP_AND):
-                comb += o.eq(a & b)
-            with m.Case(InternalOp.OP_OR):
-                comb += o.eq(a | b)
-            with m.Case(InternalOp.OP_XOR):
-                comb += o.eq(a ^ b)
+            ###################
+            ###### AND, OR, XOR  v3.0B p92-95
 
-            ###### cmpb #######
-            with m.Case(InternalOp.OP_CMPB):
+            with m.Case(MicrOp.OP_AND):
+                comb += o.data.eq(a & b)
+            with m.Case(MicrOp.OP_OR):
+                comb += o.data.eq(a | b)
+            with m.Case(MicrOp.OP_XOR):
+                comb += o.data.eq(a ^ b)
+
+            ###################
+            ###### cmpb  v3.0B p97
+
+            with m.Case(MicrOp.OP_CMPB):
                 l = []
                 for i in range(8):
                     slc = slice(i*8, (i+1)*8)
                     l.append(Repl(a[slc] == b[slc], 8))
-                comb += o.eq(Cat(*l))
-
-            ###### popcount #######
-            with m.Case(InternalOp.OP_POPCNT):
-                # starting from a, perform successive addition-reductions
-                # creating arrays big enough to store the sum, each time
-                pc = [a]
-                # QTY32 2-bit (to take 2x 1-bit sums) etc.
-                work = [(32, 2), (16, 3), (8, 4), (4, 5), (2, 6), (1, 6)]
-                for l, b in work:
-                    pc.append(array_of(l, b))
-                pc8 = pc[3]     # array of 8 8-bit counts (popcntb)
-                pc32 = pc[5]    # array of 2 32-bit counts (popcntw)
-                popcnt = pc[-1] # array of 1 64-bit count (popcntd)
-                # cascade-tree of adds
-                for idx, (l, b) in enumerate(work):
-                    for i in range(l):
-                        stt, end = i*2, i*2+1
-                        src, dst = pc[idx], pc[idx+1]
-                        comb += dst[i].eq(Cat(src[stt], Const(0, 1)) +
-                                          Cat(src[end], Const(0, 1)))
-                # decode operation length
-                with m.If(op.data_len[2:4] == 0b00):
-                    # popcntb - pack 8x 4-bit answers into output
-                    for i in range(8):
-                        comb += o[i*8:i*8+4].eq(pc8[i])
-                with m.Elif(op.data_len[3] == 0):
-                    # popcntw - pack 2x 5-bit answers into output
-                    for i in range(2):
-                        comb += o[i*32:i*32+5].eq(pc32[i])
-                with m.Else():
-                    # popcntd - put 1x 6-bit answer into output
-                    comb += o.eq(popcnt[0])
+                comb += o.data.eq(Cat(*l))
+
+            ###################
+            ###### popcount v3.0B p97, p98
 
-            ###### parity #######
-            with m.Case(InternalOp.OP_PRTY):
+            with m.Case(MicrOp.OP_POPCNT):
+                comb += popcount.a.eq(a)
+                comb += popcount.b.eq(b)
+                comb += popcount.data_len.eq(op.data_len)
+                comb += o.data.eq(popcount.o)
+
+            ###################
+            ###### parity v3.0B p98
+
+            with m.Case(MicrOp.OP_PRTY):
                 # strange instruction which XORs together the LSBs of each byte
                 par0 = Signal(reset_less=True)
                 par1 = Signal(reset_less=True)
-                comb += par0.eq(Cat(a[0] , a[8] , a[16], a[24]).xor())
-                comb += par1.eq(Cat(a[32], a[40], a[48], a[56]).xor())
+                comb += par0.eq(Cat(a[0], a[8], a[16], a[24]).xor())
+                if XLEN == 64:
+                    comb += par1.eq(Cat(a[32], a[40], a[48], a[56]).xor())
                 with m.If(op.data_len[3] == 1):
-                    comb += o.eq(par0 ^ par1)
+                    comb += o.data.eq(par0 ^ par1)
                 with m.Else():
                     comb += o[0].eq(par0)
-                    comb += o[32].eq(par1)
+                    if XLEN == 64:
+                        comb += o[32].eq(par1)
+
+            ###################
+            ###### cntlz v3.0B p99
 
-            ###### cntlz #######
-            with m.Case(InternalOp.OP_CNTZ):
+            with m.Case(MicrOp.OP_CNTZ):
                 XO = self.fields.FormX.XO[0:-1]
-                m.submodules.countz = countz = ZeroCounter()
-                comb += countz.rs_i.eq(a)
-                comb += countz.is_32bit_i.eq(op.is_32bit)
-                comb += countz.count_right_i.eq(XO[-1])
-                comb += o.eq(countz.result_o)
+                count_right = Signal(reset_less=True)
+                comb += count_right.eq(XO[-1])
+
+                cntz_i = Signal(XLEN, reset_less=True)
+                a32 = Signal(32, reset_less=True)
+                comb += a32.eq(a[0:32])
+
+                with m.If(op.is_32bit):
+                    comb += cntz_i.eq(Mux(count_right, a32[::-1], a32))
+                with m.Else():
+                    comb += cntz_i.eq(Mux(count_right, a[::-1], a))
+
+                m.submodules.clz = clz = CLZ(XLEN)
+                comb += clz.sig_in.eq(cntz_i)
+                comb += o.data.eq(Mux(op.is_32bit, clz.lz-32, clz.lz))
+
+            ###################
+            ###### bpermd v3.0B p100
+
+            with m.Case(MicrOp.OP_BPERM):
+                comb += bpermd.rs.eq(a)
+                comb += bpermd.rb.eq(b)
+                comb += o.data.eq(bpermd.ra)
 
-            ###### bpermd #######
-            # TODO with m.Case(InternalOp.OP_BPERM): - not in microwatt
+            with m.Default():
+                comb += o.ok.eq(0)
 
         ###### sticky overflow and context, both pass-through #####
 
-        comb += self.o.so.eq(self.i.so)
+        comb += self.o.xer_so.data.eq(self.i.xer_so)
         comb += self.o.ctx.eq(self.i.ctx)
 
         return m