Add rudimentary branch unit test bench
[soc.git] / src / soc / branch / main_stage.py
index b50afc278f0356a3eb9c51d62e4872675c1e8b7b..e4e522b6aaa4621bc69e7a2e3692bcc5c2c1964a 100644 (file)
@@ -7,11 +7,8 @@
 
 from nmigen import (Module, Signal, Cat, Repl, Mux, Const, Array)
 from nmutil.pipemodbase import PipeModBase
-from soc.logical.pipe_data import ALUInputData
-from soc.alu.pipe_data import ALUOutputData
-from ieee754.part.partsig import PartitionedSignal
+from soc.branch.pipe_data import BranchInputData, BranchOutputData
 from soc.decoder.power_enums import InternalOp
-from soc.countzero.countzero import ZeroCounter
 
 from soc.decoder.power_fields import DecodeFields
 from soc.decoder.power_fieldsn import SignalBitRange
@@ -24,105 +21,32 @@ def array_of(count, bitwidth):
     return res
 
 
-class LogicalMainStage(PipeModBase):
+class BranchMainStage(PipeModBase):
     def __init__(self, pspec):
         super().__init__(pspec, "main")
         self.fields = DecodeFields(SignalBitRange, [self.i.ctx.op.insn])
         self.fields.create_specs()
 
     def ispec(self):
-        return ALUInputData(self.pspec)
+        return BranchInputData(self.pspec)
 
     def ospec(self):
-        return ALUOutputData(self.pspec) # TODO: ALUIntermediateData
+        return BranchOutputData(self.pspec) # TODO: ALUIntermediateData
 
     def elaborate(self, platform):
         m = Module()
         comb = m.d.comb
-        op, a, b, o = self.i.ctx.op, self.i.a, self.i.b, self.o.o
+        op = self.i.ctx.op
 
         ##########################
         # main switch for logic ops AND, OR and XOR, cmpb, parity, and popcount
 
         with m.Switch(op.insn_type):
+            pass
 
-            ###### AND, OR, XOR #######
-            with m.Case(InternalOp.OP_AND):
-                comb += o.eq(a & b)
-            with m.Case(InternalOp.OP_OR):
-                comb += o.eq(a | b)
-            with m.Case(InternalOp.OP_XOR):
-                comb += o.eq(a ^ b)
-
-            ###### cmpb #######
-            with m.Case(InternalOp.OP_CMPB):
-                l = []
-                for i in range(8):
-                    slc = slice(i*8, (i+1)*8)
-                    l.append(Repl(a[slc] == b[slc], 8))
-                comb += o.eq(Cat(*l))
-
-            ###### popcount #######
-            with m.Case(InternalOp.OP_POPCNT):
-                # starting from a, perform successive addition-reductions
-                # creating arrays big enough to store the sum, each time
-                pc = [a]
-                # QTY32 2-bit (to take 2x 1-bit sums) etc.
-                work = [(32, 2), (16, 3), (8, 4), (4, 5), (2, 6), (1, 6)]
-                for l, b in work:
-                    pc.append(array_of(l, b))
-                pc8 = pc[3]     # array of 8 8-bit counts (popcntb)
-                pc32 = pc[5]    # array of 2 32-bit counts (popcntw)
-                popcnt = pc[-1] # array of 1 64-bit count (popcntd)
-                # cascade-tree of adds
-                for idx, (l, b) in enumerate(work):
-                    for i in range(l):
-                        stt, end = i*2, i*2+1
-                        src, dst = pc[idx], pc[idx+1]
-                        comb += dst[i].eq(Cat(src[stt], Const(0, 1)) +
-                                          Cat(src[end], Const(0, 1)))
-                # decode operation length
-                with m.If(op.data_len[2:4] == 0b00):
-                    # popcntb - pack 8x 4-bit answers into output
-                    for i in range(8):
-                        comb += o[i*8:i*8+4].eq(pc8[i])
-                with m.Elif(op.data_len[3] == 0):
-                    # popcntw - pack 2x 5-bit answers into output
-                    for i in range(2):
-                        comb += o[i*32:i*32+5].eq(pc32[i])
-                with m.Else():
-                    # popcntd - put 1x 6-bit answer into output
-                    comb += o.eq(popcnt[0])
-
-            ###### parity #######
-            with m.Case(InternalOp.OP_PRTY):
-                # strange instruction which XORs together the LSBs of each byte
-                par0 = Signal(reset_less=True)
-                par1 = Signal(reset_less=True)
-                comb += par0.eq(Cat(a[0] , a[8] , a[16], a[24]).xor())
-                comb += par1.eq(Cat(a[32], a[40], a[48], a[56]).xor())
-                with m.If(op.data_len[3] == 1):
-                    comb += o.eq(par0 ^ par1)
-                with m.Else():
-                    comb += o[0].eq(par0)
-                    comb += o[32].eq(par1)
-
-            ###### cntlz #######
-            with m.Case(InternalOp.OP_CNTZ):
-                x_fields = self.fields.instrs['X']
-                XO = Signal(x_fields['XO'][0:-1].shape())
-                m.submodules.countz = countz = ZeroCounter()
-                comb += countz.rs_i.eq(a)
-                comb += countz.is_32bit_i.eq(op.is_32bit)
-                comb += countz.count_right_i.eq(XO[-1])
-                comb += o.eq(countz.result_o)
-
-            ###### bpermd #######
-            # TODO with m.Case(InternalOp.OP_BPERM): - not in microwatt
 
         ###### sticky overflow and context, both pass-through #####
 
-        comb += self.o.so.eq(self.i.so)
         comb += self.o.ctx.eq(self.i.ctx)
 
         return m