add fixed PLRU
authorJacob Lifshay <programmerjake@gmail.com>
Fri, 19 Aug 2022 07:28:30 +0000 (00:28 -0700)
committerJacob Lifshay <programmerjake@gmail.com>
Fri, 19 Aug 2022 07:28:30 +0000 (00:28 -0700)
src/nmutil/formal/test_plru.py
src/nmutil/plru2.py

index 49b8599b4ce9dee56747a00faf7e747971519f36..f7c387071c7c53420867f5b2a4d47276d0353001 100644 (file)
@@ -2,12 +2,12 @@
 # Copyright 2022 Jacob Lifshay
 
 import unittest
-from nmigen.hdl.ast import (AnySeq, Assert, Signal, Assume, Const,
-                            unsigned, AnyConst, Value)
+from nmigen.hdl.ast import (AnySeq, Assert, Signal, Value, Array, Value)
 from nmigen.hdl.dsl import Module
+from nmigen.sim import Delay, Tick
 from nmutil.formaltest import FHDLTestCase
-from nmutil.plru import PLRU, PLRUs
-from nmutil.sim_util import write_il
+from nmutil.plru2 import PLRU  # , PLRUs
+from nmutil.sim_util import write_il, do_sim
 from nmutil.plain_data import plain_data
 
 
@@ -33,18 +33,29 @@ class PrettyPrintState:
 
 @plain_data()
 class PLRUNode:
-    __slots__ = "state", "left_child", "right_child"
+    __slots__ = "id", "state", "left_child", "right_child"
 
-    def __init__(self, state, left_child=None, right_child=None):
-        # type: (Signal, PLRUNode | None, PLRUNode | None) -> None
-        self.state = state
+    def __init__(self, id, left_child=None, right_child=None):
+        # type: (int, PLRUNode | None, PLRUNode | None) -> None
+        self.id = id
+        self.state = Signal(name=f"state_{id}")
         self.left_child = left_child
         self.right_child = right_child
 
+    @property
+    def depth(self):
+        depth = 0
+        if self.left_child is not None:
+            depth = max(depth, 1 + self.left_child.depth)
+        if self.right_child is not None:
+            depth = max(depth, 1 + self.right_child.depth)
+        return depth
+
     def __pretty_print(self, state):
         # type: (PrettyPrintState) -> None
         state.write("PLRUNode(")
         state.indent += 1
+        state.write(f"id={self.id!r},\n")
         state.write(f"state={self.state!r},\n")
         state.write("left_child=")
         if self.left_child is None:
@@ -61,59 +72,147 @@ class PLRUNode:
 
     def pretty_print(self, file=None):
         self.__pretty_print(PrettyPrintState(file=file))
+        print(file=file)
 
-    def set_states_from_index(self, m, index):
-        # type: (Module, Value) -> None
-        m.d.sync += self.state.eq(index[-1])
+    def set_states_from_index(self, m, index, ids):
+        # type: (Module, Value, list[Signal]) -> None
+        m.d.sync += self.state.eq(~index[-1])
+        m.d.comb += ids[0].eq(self.id)
         with m.If(index[-1]):
-            if self.left_child is not None:
-                self.left_child.set_states_from_index(m, index[:-1])
+            if self.right_child is not None:
+                self.right_child.set_states_from_index(m, index[:-1], ids[1:])
         with m.Else():
+            if self.left_child is not None:
+                self.left_child.set_states_from_index(m, index[:-1], ids[1:])
+
+    def get_lru(self, m, ids):
+        # type: (Module, list[Signal]) -> Signal
+        retval = Signal(1 + self.depth, name=f"lru_{self.id}", reset=0)
+        m.d.comb += retval[-1].eq(self.state)
+        m.d.comb += ids[0].eq(self.id)
+        with m.If(self.state):
             if self.right_child is not None:
-                self.right_child.set_states_from_index(m, index[:-1])
+                right_lru = self.right_child.get_lru(m, ids[1:])
+                m.d.comb += retval[:-1].eq(right_lru)
+        with m.Else():
+            if self.left_child is not None:
+                left_lru = self.left_child.get_lru(m, ids[1:])
+                m.d.comb += retval[:-1].eq(left_lru)
+        return retval
 
 
 class TestPLRU(FHDLTestCase):
-    @unittest.skip("not finished yet")
-    def tst(self, BITS):
-        # type: (int) -> None
-
-        # FIXME: figure out what BITS is supposed to mean -- I would have
-        # expected it to be the number of cache ways, or the number of state
-        # bits in PLRU, but it's neither of those, making me think whoever
-        # converted the code botched their math.
-        #
-        # Until that's figured out, this test is broken.
-
-        dut = PLRU(BITS)
+    def tst(self, log2_num_ways, test_seq=None):
+        # type: (int, list[int | None] | None) -> None
+
+        @plain_data()
+        class MyAssert:
+            __slots__ = "test", "en"
+
+            def __init__(self, test, en):
+                # type: (Value, Signal) -> None
+                self.test = test
+                self.en = en
+
+        asserts = []  # type: list[MyAssert]
+
+        def assert_(test):
+            if test_seq is None:
+                return [Assert(test, src_loc_at=1)]
+            assert_en = Signal(name="assert_en", src_loc_at=1, reset=False)
+            asserts.append(MyAssert(test=test, en=assert_en))
+            return [assert_en.eq(True)]
+
+        dut = PLRU(log2_num_ways, debug=True)  # check debug works
+        write_il(self, dut, ports=dut.ports())
+        # debug clutters up vcd, so disable it for formal proofs
+        dut = PLRU(log2_num_ways, debug=test_seq is not None)
+        num_ways = 1 << log2_num_ways
+        self.assertEqual(dut.log2_num_ways, log2_num_ways)
+        self.assertEqual(dut.num_ways, num_ways)
+        self.assertIsInstance(dut.acc_i, Signal)
+        self.assertIsInstance(dut.acc_en_i, Signal)
+        self.assertIsInstance(dut.lru_o, Signal)
+        self.assertEqual(len(dut.acc_i), log2_num_ways)
+        self.assertEqual(len(dut.acc_en_i), 1)
+        self.assertEqual(len(dut.lru_o), log2_num_ways)
         write_il(self, dut, ports=dut.ports())
         m = Module()
-        nodes = [PLRUNode(Signal(name=f"state_{i}")) for i in range(dut.TLBSZ)]
-        self.assertEqual(len(dut._plru_tree), len(nodes))
-        for i in range(1, dut.TLBSZ):
-            parent = (i + 1) // 2 - 1
-            if i % 2:
-                nodes[parent].left_child = nodes[i]
-            else:
-                nodes[parent].right_child = nodes[i]
-            m.d.comb += Assert(nodes[i].state == dut._plru_tree[i])
-
-        in_index = Signal(range(BITS))
-
-        m.d.comb += [
-            in_index.eq(AnySeq(range(BITS))),
-            Assume(in_index < BITS),
-            dut.acc_i.eq(1 << in_index),
-            dut.acc_en.eq(AnySeq(1)),
-        ]
-
-        with m.If(dut.acc_en):
-            nodes[0].set_states_from_index(m, in_index)
+        nodes = [PLRUNode(i) for i in range(num_ways - 1)]
+        self.assertIsInstance(dut._tree, Array)
+        self.assertEqual(len(dut._tree), len(nodes))
+        for i in range(len(nodes)):
+            if i != 0:
+                parent = (i + 1) // 2 - 1
+                if i % 2:
+                    nodes[parent].left_child = nodes[i]
+                else:
+                    nodes[parent].right_child = nodes[i]
+            self.assertIsInstance(dut._tree[i], Signal)
+            self.assertEqual(len(dut._tree[i]), 1)
+            m.d.comb += assert_(nodes[i].state == dut._tree[i])
+
+        if test_seq is None:
+            m.d.comb += [
+                dut.acc_i.eq(AnySeq(log2_num_ways)),
+                dut.acc_en_i.eq(AnySeq(1)),
+            ]
+
+        l2nwr = range(log2_num_ways)
+        upd_ids = [Signal(log2_num_ways, name=f"upd_id_{i}") for i in l2nwr]
+        with m.If(dut.acc_en_i):
+            nodes[0].set_states_from_index(m, dut.acc_i, upd_ids)
+
+            self.assertEqual(len(dut._upd_lru_nodes), len(upd_ids))
+            for l, r in zip(dut._upd_lru_nodes, upd_ids):
+                m.d.comb += assert_(l == r)
+
+        get_ids = [Signal(log2_num_ways, name=f"get_id_{i}") for i in l2nwr]
+        lru = Signal(log2_num_ways)
+        m.d.comb += lru.eq(nodes[0].get_lru(m, get_ids))
+        m.d.comb += assert_(dut.lru_o == lru)
+        self.assertEqual(len(dut._get_lru_nodes), len(get_ids))
+        for l, r in zip(dut._get_lru_nodes, get_ids):
+            m.d.comb += assert_(l == r)
 
         nodes[0].pretty_print()
 
         m.submodules.dut = dut
-        self.assertFormal(m, mode="prove")
+        if test_seq is None:
+            self.assertFormal(m, mode="prove", depth=2)
+        else:
+            traces = [dut.acc_i, dut.acc_en_i, *dut._tree]
+            for node in nodes:
+                traces.append(node.state)
+            traces += [
+                dut.lru_o, lru, *dut._get_lru_nodes, *get_ids,
+                *dut._upd_lru_nodes, *upd_ids,
+            ]
+
+            def subtest(acc_i, acc_en_i):
+                yield dut.acc_i.eq(acc_i)
+                yield dut.acc_en_i.eq(acc_en_i)
+                yield Tick()
+                yield Delay(0.7e-6)
+                for a in asserts:
+                    if (yield a.en):
+                        with self.subTest(
+                                assert_loc=':'.join(map(str, a.en.src_loc))):
+                            self.assertTrue((yield a.test))
+
+            def process():
+                for test_item in test_seq:
+                    if test_item is None:
+                        with self.subTest(test_item="None"):
+                            yield from subtest(acc_i=0, acc_en_i=0)
+                    else:
+                        with self.subTest(test_item=hex(test_item)):
+                            yield from subtest(acc_i=test_item, acc_en_i=1)
+
+            with do_sim(self, m, traces) as sim:
+                sim.add_clock(1e-6)
+                sim.add_process(process)
+                sim.run()
 
     def test_bits_1(self):
         self.tst(1)
@@ -133,35 +232,13 @@ class TestPLRU(FHDLTestCase):
     def test_bits_6(self):
         self.tst(6)
 
-    def test_bits_7(self):
-        self.tst(7)
-
-    def test_bits_8(self):
-        self.tst(8)
-
-    def test_bits_9(self):
-        self.tst(9)
-
-    def test_bits_10(self):
-        self.tst(10)
-
-    def test_bits_11(self):
-        self.tst(11)
-
-    def test_bits_12(self):
-        self.tst(12)
-
-    def test_bits_13(self):
-        self.tst(13)
-
-    def test_bits_14(self):
-        self.tst(14)
-
-    def test_bits_15(self):
-        self.tst(15)
-
-    def test_bits_16(self):
-        self.tst(16)
+    def test_bits_3_sim(self):
+        self.tst(3, [
+            0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7,
+            None,
+            0x0, 0x4, 0x2, 0x6, 0x1, 0x5, 0x3, 0x7,
+            None,
+        ])
 
 
 if __name__ == "__main__":
index aab9ac8d45852c82f0cec764e023421269dd4f0e..d766f6c16b44ea7e5b4c2952a09c9f6004b1431e 100644 (file)
@@ -1,12 +1,12 @@
-# based on ariane plru, from tlb.sv
+# based on microwatt plru.vhdl
+# https://github.com/antonblanchard/microwatt/blob/f67b1431655c291fc1c99857a5c1ef624d5b264c/plru.vhdl
 
 # new PLRU API, once all users have migrated to new API in plru2.py, then
 # plru2.py will be renamed to plru.py.
 
-from nmigen import Signal, Module, Cat, Const, Repl, Array
-from nmigen.hdl.ir import Elaboratable
+from nmigen.hdl.ir import Elaboratable, Display, Signal, Array, Const, Value
+from nmigen.hdl.dsl import Module
 from nmigen.cli import rtlil
-from nmigen.utils import log2_int
 from nmigen.lib.coding import Decoder
 
 
@@ -17,149 +17,161 @@ class PLRU(Elaboratable):
         lvl0        0
                    / \
                   /   \
-        lvl1     1     2
-                / \   / \
-        lvl2   3   4 5   6
-              / \ /\/\  /\
+                 /     \
+        lvl1    1       2
+               / \     / \
+        lvl2  3   4   5   6
+             / \ / \ / \ / \
              ... ... ... ...
     """
 
-    def __init__(self, BITS):
-        self.BITS = BITS
-        self.acc_i = Signal(BITS)
-        self.acc_en = Signal()
-        self.lru_o = Signal(BITS)
-
-        self._plru_tree = Signal(self.TLBSZ)
+    def __init__(self, log2_num_ways, debug=False):
+        # type: (int, bool) -> None
+        """
+        Arguments:
+        log2_num_ways: int
+            the log-base-2 of the number of cache ways -- BITS in plru.vhdl
+        debug: bool
+            true if this should print debugging messages at simulation time.
+        """
+        assert log2_num_ways > 0
+        self.log2_num_ways = log2_num_ways
+        self.debug = debug
+        self.acc_i = Signal(log2_num_ways)
+        self.acc_en_i = Signal()
+        self.lru_o = Signal(log2_num_ways)
+
+        def mk_tree(i):
+            return Signal(name=f"tree_{i}", reset=0)
+
+        # original vhdl has array 1 too big, last entry is never used,
+        # subtract 1 to compensate
+        self._tree = Array(mk_tree(i) for i in range(self.num_ways - 1))
         """ exposed only for testing """
 
-    @property
-    def TLBSZ(self):
-        return 2 * (self.BITS - 1)
-
-    def elaborate(self, platform=None):
-        m = Module()
+        def mk_node(i, prefix):
+            return Signal(range(self.num_ways), name=f"{prefix}_node_{i}",
+                          reset=0)
 
-        # Tree (bit per entry)
-
-        # Just predefine which nodes will be set/cleared
-        # E.g. for a TLB with 8 entries, the for-loop is semantically
-        # equivalent to the following pseudo-code:
-        # unique case (1'b1)
-        # acc_en[7]: plru_tree[0, 2, 6] = {1, 1, 1};
-        # acc_en[6]: plru_tree[0, 2, 6] = {1, 1, 0};
-        # acc_en[5]: plru_tree[0, 2, 5] = {1, 0, 1};
-        # acc_en[4]: plru_tree[0, 2, 5] = {1, 0, 0};
-        # acc_en[3]: plru_tree[0, 1, 4] = {0, 1, 1};
-        # acc_en[2]: plru_tree[0, 1, 4] = {0, 1, 0};
-        # acc_en[1]: plru_tree[0, 1, 3] = {0, 0, 1};
-        # acc_en[0]: plru_tree[0, 1, 3] = {0, 0, 0};
-        # default: begin /* No hit */ end
-        # endcase
-
-        LOG_TLB = log2_int(self.BITS, False)
-        hit = Signal(self.BITS, reset_less=True)
-        m.d.comb += hit.eq(Repl(self.acc_en, self.BITS) & self.acc_i)
-
-        for i in range(self.BITS):
-            # we got a hit so update the pointer as it was least recently used
-            with m.If(hit[i]):
-                # Set the nodes to the values we would expect
-                for lvl in range(LOG_TLB):
-                    idx_base = (1 << lvl)-1
-                    # lvl0 <=> MSB, lvl1 <=> MSB-1, ...
-                    shift = LOG_TLB - lvl
-                    new_idx = Const(~((i >> (shift-1)) & 1), 1)
-                    plru_idx = idx_base + (i >> shift)
-                    # print("plru", i, lvl, hex(idx_base),
-                    #      plru_idx, shift, new_idx)
-                    m.d.sync += self._plru_tree[plru_idx].eq(new_idx)
-
-        # Decode tree to write enable signals
-        # Next for-loop basically creates the following logic for e.g.
-        # an 8 entry TLB (note: pseudo-code obviously):
-        # replace_en[7] = &plru_tree[ 6, 2, 0]; #plru_tree[0,2,6]=={1,1,1}
-        # replace_en[6] = &plru_tree[~6, 2, 0]; #plru_tree[0,2,6]=={1,1,0}
-        # replace_en[5] = &plru_tree[ 5,~2, 0]; #plru_tree[0,2,5]=={1,0,1}
-        # replace_en[4] = &plru_tree[~5,~2, 0]; #plru_tree[0,2,5]=={1,0,0}
-        # replace_en[3] = &plru_tree[ 4, 1,~0]; #plru_tree[0,1,4]=={0,1,1}
-        # replace_en[2] = &plru_tree[~4, 1,~0]; #plru_tree[0,1,4]=={0,1,0}
-        # replace_en[1] = &plru_tree[ 3,~1,~0]; #plru_tree[0,1,3]=={0,0,1}
-        # replace_en[0] = &plru_tree[~3,~1,~0]; #plru_tree[0,1,3]=={0,0,0}
-        # For each entry traverse the tree. If every tree-node matches
-        # the corresponding bit of the entry's index, this is
-        # the next entry to replace.
-        replace = []
-        for i in range(self.BITS):
-            en = []
-            for lvl in range(LOG_TLB):
-                idx_base = (1 << lvl)-1
-                # lvl0 <=> MSB, lvl1 <=> MSB-1, ...
-                shift = LOG_TLB - lvl
-                new_idx = (i >> (shift-1)) & 1
-                plru_idx = idx_base + (i >> shift)
-                plru = Signal(reset_less=True,
-                              name="plru-%d-%d-%d-%d" %
-                              (i, lvl, plru_idx, new_idx))
-                m.d.comb += plru.eq(self._plru_tree[plru_idx])
-                if new_idx:
-                    en.append(~plru)  # yes inverted (using bool() below)
-                else:
-                    en.append(plru)  # yes inverted (using bool() below)
-            #print("plru", i, en)
-            # boolean logic manipulation:
-            # plru0 & plru1 & plru2 == ~(~plru0 | ~plru1 | ~plru2)
-            replace.append(~Cat(*en).bool())
-        m.d.comb += self.lru_o.eq(Cat(*replace))
+        nodes_range = range(self.log2_num_ways)
 
-        return m
-
-    def ports(self):
-        return [self.acc_en, self.lru_o, self.acc_i]
+        self._get_lru_nodes = [mk_node(i, "get_lru") for i in nodes_range]
+        """ exposed only for testing """
 
+        self._upd_lru_nodes = [mk_node(i, "upd_lru") for i in nodes_range]
+        """ exposed only for testing """
 
-class PLRUs(Elaboratable):
-    def __init__(self, n_plrus, n_bits):
-        self.n_plrus = n_plrus
-        self.n_bits = n_bits
-        self.valid = Signal()
-        self.way = Signal(n_bits)
-        self.index = Signal(n_plrus.bit_length())
-        self.isel = Signal(n_plrus.bit_length())
-        self.o_index = Signal(n_bits)
+    @property
+    def num_ways(self):
+        return 1 << self.log2_num_ways
+
+    def _display(self, msg, *args):
+        if not self.debug:
+            return []
+        # work around not yet having
+        # https://gitlab.com/nmigen/nmigen/-/merge_requests/10
+        # by sending through Value.cast()
+        return [Display(msg, *map(Value.cast, args))]
+
+    def _get_lru(self, m):
+        """ get_lru process in plru.vhdl """
+        # XXX Check if we can turn that into a little ROM instead that
+        # takes the tree bit vector and returns the LRU. See if it's better
+        # in term of FPGA resource usage...
+        m.d.comb += self._get_lru_nodes[0].eq(0)
+        for i in range(self.log2_num_ways):
+            node = self._get_lru_nodes[i]
+            val = self._tree[node]
+            m.d.comb += self._display("GET: i:%i node:%#x val:%i",
+                                      i, node, val)
+            m.d.comb += self.lru_o[self.log2_num_ways - 1 - i].eq(val)
+            if i != self.log2_num_ways - 1:
+                # modified from microwatt version, it uses `node * 2` value
+                # to index into tree, rather than using node like is used
+                # earlier in this loop iteration
+                node <<= 1
+                with m.If(val):
+                    m.d.comb += self._get_lru_nodes[i + 1].eq(node + 2)
+                with m.Else():
+                    m.d.comb += self._get_lru_nodes[i + 1].eq(node + 1)
+
+    def _update_lru(self, m):
+        """ update_lru process in plru.vhdl """
+        with m.If(self.acc_en_i):
+            m.d.comb += self._upd_lru_nodes[0].eq(0)
+            for i in range(self.log2_num_ways):
+                node = self._upd_lru_nodes[i]
+                abit = self.acc_i[self.log2_num_ways - 1 - i]
+                m.d.sync += [
+                    self._tree[node].eq(~abit),
+                    self._display("UPD: i:%i node:%#x val:%i",
+                                  i, node, ~abit),
+                ]
+                if i != self.log2_num_ways - 1:
+                    node <<= 1
+                    with m.If(abit):
+                        m.d.comb += self._upd_lru_nodes[i + 1].eq(node + 2)
+                    with m.Else():
+                        m.d.comb += self._upd_lru_nodes[i + 1].eq(node + 1)
 
-    def elaborate(self, platform):
-        """Generate TLB PLRUs
-        """
+    def elaborate(self, platform=None):
         m = Module()
-        comb = m.d.comb
-
-        if self.n_plrus == 0:
-            return m
-
-        # Binary-to-Unary one-hot, enabled by valid
-        m.submodules.te = te = Decoder(self.n_plrus)
-        comb += te.n.eq(~self.valid)
-        comb += te.i.eq(self.index)
-
-        out = Array(Signal(self.n_bits, name="plru_out%d" % x)
-                    for x in range(self.n_plrus))
-
-        for i in range(self.n_plrus):
-            # PLRU interface
-            m.submodules["plru_%d" % i] = plru = PLRU(self.n_bits)
-
-            comb += plru.acc_en.eq(te.o[i])
-            comb += plru.acc_i.eq(self.way)
-            comb += out[i].eq(plru.lru_o)
-
-        # select output based on index
-        comb += self.o_index.eq(out[self.isel])
-
+        self._get_lru(m)
+        self._update_lru(m)
         return m
 
+    def __iter__(self):
+        yield self.acc_i
+        yield self.acc_en_i
+        yield self.lru_o
+
     def ports(self):
-        return [self.valid, self.way, self.index, self.isel, self.o_index]
+        return list(self)
+
+
+# FIXME: convert PLRUs to new API
+# class PLRUs(Elaboratable):
+#     def __init__(self, n_plrus, n_bits):
+#         self.n_plrus = n_plrus
+#         self.n_bits = n_bits
+#         self.valid = Signal()
+#         self.way = Signal(n_bits)
+#         self.index = Signal(n_plrus.bit_length())
+#         self.isel = Signal(n_plrus.bit_length())
+#         self.o_index = Signal(n_bits)
+#
+#     def elaborate(self, platform):
+#         """Generate TLB PLRUs
+#         """
+#         m = Module()
+#         comb = m.d.comb
+#
+#         if self.n_plrus == 0:
+#             return m
+#
+#         # Binary-to-Unary one-hot, enabled by valid
+#         m.submodules.te = te = Decoder(self.n_plrus)
+#         comb += te.n.eq(~self.valid)
+#         comb += te.i.eq(self.index)
+#
+#         out = Array(Signal(self.n_bits, name="plru_out%d" % x)
+#                     for x in range(self.n_plrus))
+#
+#         for i in range(self.n_plrus):
+#             # PLRU interface
+#             m.submodules["plru_%d" % i] = plru = PLRU(self.n_bits)
+#
+#             comb += plru.acc_en.eq(te.o[i])
+#             comb += plru.acc_i.eq(self.way)
+#             comb += out[i].eq(plru.lru_o)
+#
+#         # select output based on index
+#         comb += self.o_index.eq(out[self.isel])
+#
+#         return m
+#
+#     def ports(self):
+#         return [self.valid, self.way, self.index, self.isel, self.o_index]
 
 
 if __name__ == '__main__':
@@ -168,7 +180,7 @@ if __name__ == '__main__':
     with open("test_plru.il", "w") as f:
         f.write(vl)
 
-    dut = PLRUs(4, 2)
-    vl = rtlil.convert(dut, ports=dut.ports())
-    with open("test_plrus.il", "w") as f:
-        f.write(vl)
+    dut = PLRUs(4, 2)
+    vl = rtlil.convert(dut, ports=dut.ports())
+    with open("test_plrus.il", "w") as f:
+        f.write(vl)