hmm start adding st in (half done)
authorLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Mon, 23 Mar 2020 18:05:23 +0000 (18:05 +0000)
committerLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Mon, 23 Mar 2020 18:05:23 +0000 (18:05 +0000)
src/soc/scoreboard/addr_split.py

index bf89e0970e9a8b44c76018660114172f5a3061f4..afcd2dc0eedf2464790b96d86166d84466e551ad 100644 (file)
@@ -53,17 +53,24 @@ class LDSTSplitter(Elaboratable):
         self.dwidth, self.awidth, self.dlen = dwidth, awidth, dlen
         self.addr_i = Signal(awidth, reset_less=True)
         self.len_i = Signal(dlen, reset_less=True)
+        self.valid_i = Signal(reset_less=True)
+        self.valid_o = Signal(reset_less=True)
+
         self.is_ld_i = Signal(reset_less=True)
+        self.is_st_i = Signal(reset_less=True)
+
         self.ld_data_o = LDData(dwidth, "ld_data_o")
-        self.ld_valid_i = Signal(reset_less=True)
-        self.valid_o = Signal(reset_less=True)
+        self.st_data_i = LDData(dwidth, "st_data_i")
+
         self.sld_valid_o = Signal(2, reset_less=True)
         self.sld_valid_i = Signal(2, reset_less=True)
         self.sld_data_i = Array((LDData(dwidth, "ld_data_i1"),
                                 LDData(dwidth, "ld_data_i2")))
 
-        #self.is_st_i = Signal(reset_less=True)
-        #self.st_data_i = Signal(dwidth, reset_less=True)
+        self.sst_valid_o = Signal(2, reset_less=True)
+        self.sst_valid_i = Signal(2, reset_less=True)
+        self.sst_data_o = Array((LDData(dwidth, "st_data_i1"),
+                                LDData(dwidth, "st_data_i2")))
 
     def elaborate(self, platform):
         m = Module()
@@ -86,35 +93,48 @@ class LDSTSplitter(Elaboratable):
         comb += ld1.addr_i.eq(self.addr_i[dlen:])
         comb += ld2.addr_i.eq(self.addr_i[dlen:] + 1) # TODO exception if rolls
 
-        # set up connections to LD-split.  note: not active if mask is zero
-        mzero = Const(0, mlen)
-        for i, (ld, mask) in enumerate(((ld1, mask1),
-                                        (ld2, mask2))):
-            ld_valid = Signal(name="ldvalid_i%d" % i, reset_less=True)
-            comb += ld_valid.eq(self.ld_valid_i & self.sld_valid_i[i])
-            comb += ld.valid_i.eq(ld_valid & (mask != mzero))
-            comb += ld.ld_i.eq(self.sld_data_i[i])
-            comb += self.sld_valid_o[i].eq(ld.valid_o)
-
-        # sort out valid: mask2 zero we ignore 2nd LD
-        with m.If(mask2 == mzero):
-            comb += self.valid_o.eq(self.sld_valid_o[0])
-        with m.Else():
-            comb += self.valid_o.eq(self.sld_valid_o.all())
-
-        # all bits valid (including when a data error occurs!) decode ld1/ld2
-        with m.If(self.valid_o):
-            # errors cause error condition
-            comb += self.ld_data_o.err.eq(ld1.ld_o.err | ld2.ld_o.err)
-            # data needs recombining via shifting.
-            ashift1 = Signal(self.dlen)
-            ashift2 = Signal(self.dlen)
-            comb += ashift1.eq(self.addr_i[:self.dlen])
-            comb += ashift2.eq((1<<dlen)-ashift1)
-            # note that data from LD1 will be in *cache-line* byte position
-            # likewise from LD2 but we *know* it is at the start of the line
-            comb += self.ld_data_o.data.eq((ld1.ld_o.data >> ashift1) |
-                                            (ld2.ld_o.data << ashift2))
+        with m.If(self.is_ld_i):
+            # set up connections to LD-split.  note: not active if mask is zero
+            mzero = Const(0, mlen)
+            for i, (ld, mask) in enumerate(((ld1, mask1),
+                                            (ld2, mask2))):
+                ld_valid = Signal(name="ldvalid_i%d" % i, reset_less=True)
+                comb += ld_valid.eq(self.valid_i & self.sld_valid_i[i])
+                comb += ld.valid_i.eq(ld_valid & (mask != mzero))
+                comb += ld.ld_i.eq(self.sld_data_i[i])
+                comb += self.sld_valid_o[i].eq(ld.valid_o)
+
+            # sort out valid: mask2 zero we ignore 2nd LD
+            with m.If(mask2 == mzero):
+                comb += self.valid_o.eq(self.sld_valid_o[0])
+            with m.Else():
+                comb += self.valid_o.eq(self.sld_valid_o.all())
+
+            # all bits valid (including when data error occurs!) decode ld1/ld2
+            with m.If(self.valid_o):
+                # errors cause error condition
+                comb += self.ld_data_o.err.eq(ld1.ld_o.err | ld2.ld_o.err)
+                # data needs recombining via shifting.
+                ashift1 = Signal(self.dlen)
+                ashift2 = Signal(self.dlen)
+                comb += ashift1.eq(self.addr_i[:self.dlen])
+                comb += ashift2.eq((1<<dlen)-ashift1)
+                # note that data from LD1 will be in *cache-line* byte position
+                # likewise from LD2 but we *know* it is at the start of the line
+                comb += self.ld_data_o.data.eq((ld1.ld_o.data >> ashift1) |
+                                                (ld2.ld_o.data << ashift2))
+
+        with m.If(self.is_st_i):
+            mzero = Const(0, mlen)
+            for i, (ld, mask) in enumerate(((ld1, mask1),
+                                            (ld2, mask2))):
+                valid = Signal(name="stvalid_i%d" % i, reset_less=True)
+                comb += valid.eq(self.valid_i & self.sst_valid_i[i])
+                comb += ld.valid_i.eq(valid & (mask != mzero))
+                comb += self.sld_valid_o[i].eq(ld.valid_o)
+
+            comb += ld1.ld_i.eq((self.st_data_i & mask1) << ashift1)
+            comb += ld2.ld_i.eq((self.st_data_i & mask2) >> ashift2)
 
         return m
 
@@ -124,7 +144,7 @@ class LDSTSplitter(Elaboratable):
         yield self.is_ld_i
         yield self.ld_data_o.err
         yield self.ld_data_o.data
-        yield self.ld_valid_i
+        yield self.valid_i
         yield self.valid_o
         yield self.sld_valid_i
         for i in range(2):
@@ -156,9 +176,10 @@ def sim(dut):
 
     def send_in():
         print ("send_in")
+        yield dut.is_ld_i.eq(1)
         yield dut.len_i.eq(ld_len)
         yield dut.addr_i.eq(addr)
-        yield dut.ld_valid_i.eq(1)
+        yield dut.valid_i.eq(1)
         print ("waiting")
         while True:
             valid_o = yield dut.valid_o
@@ -166,6 +187,7 @@ def sim(dut):
                 break
             yield
         ld_data_o = yield dut.ld_data_o.data
+        yield dut.is_ld_i.eq(0)
         yield
 
         print (bin(ld_data_o), bin(data))
@@ -174,7 +196,7 @@ def sim(dut):
     def lds():
         print ("lds")
         while True:
-            valid_i = yield dut.ld_valid_i
+            valid_i = yield dut.valid_i
             if valid_i:
                 break
             yield