From 0759ba7a4a56812e195ee324bf87855084797ec0 Mon Sep 17 00:00:00 2001 From: Jacob Lifshay Date: Fri, 19 Aug 2022 00:28:30 -0700 Subject: [PATCH] add fixed PLRU --- src/nmutil/formal/test_plru.py | 231 +++++++++++++++++--------- src/nmutil/plru2.py | 286 +++++++++++++++++---------------- 2 files changed, 303 insertions(+), 214 deletions(-) diff --git a/src/nmutil/formal/test_plru.py b/src/nmutil/formal/test_plru.py index 49b8599..f7c3870 100644 --- a/src/nmutil/formal/test_plru.py +++ b/src/nmutil/formal/test_plru.py @@ -2,12 +2,12 @@ # Copyright 2022 Jacob Lifshay import unittest -from nmigen.hdl.ast import (AnySeq, Assert, Signal, Assume, Const, - unsigned, AnyConst, Value) +from nmigen.hdl.ast import (AnySeq, Assert, Signal, Value, Array, Value) from nmigen.hdl.dsl import Module +from nmigen.sim import Delay, Tick from nmutil.formaltest import FHDLTestCase -from nmutil.plru import PLRU, PLRUs -from nmutil.sim_util import write_il +from nmutil.plru2 import PLRU # , PLRUs +from nmutil.sim_util import write_il, do_sim from nmutil.plain_data import plain_data @@ -33,18 +33,29 @@ class PrettyPrintState: @plain_data() class PLRUNode: - __slots__ = "state", "left_child", "right_child" + __slots__ = "id", "state", "left_child", "right_child" - def __init__(self, state, left_child=None, right_child=None): - # type: (Signal, PLRUNode | None, PLRUNode | None) -> None - self.state = state + def __init__(self, id, left_child=None, right_child=None): + # type: (int, PLRUNode | None, PLRUNode | None) -> None + self.id = id + self.state = Signal(name=f"state_{id}") self.left_child = left_child self.right_child = right_child + @property + def depth(self): + depth = 0 + if self.left_child is not None: + depth = max(depth, 1 + self.left_child.depth) + if self.right_child is not None: + depth = max(depth, 1 + self.right_child.depth) + return depth + def __pretty_print(self, state): # type: (PrettyPrintState) -> None state.write("PLRUNode(") state.indent += 1 + state.write(f"id={self.id!r},\n") state.write(f"state={self.state!r},\n") state.write("left_child=") if self.left_child is None: @@ -61,59 +72,147 @@ class PLRUNode: def pretty_print(self, file=None): self.__pretty_print(PrettyPrintState(file=file)) + print(file=file) - def set_states_from_index(self, m, index): - # type: (Module, Value) -> None - m.d.sync += self.state.eq(index[-1]) + def set_states_from_index(self, m, index, ids): + # type: (Module, Value, list[Signal]) -> None + m.d.sync += self.state.eq(~index[-1]) + m.d.comb += ids[0].eq(self.id) with m.If(index[-1]): - if self.left_child is not None: - self.left_child.set_states_from_index(m, index[:-1]) + if self.right_child is not None: + self.right_child.set_states_from_index(m, index[:-1], ids[1:]) with m.Else(): + if self.left_child is not None: + self.left_child.set_states_from_index(m, index[:-1], ids[1:]) + + def get_lru(self, m, ids): + # type: (Module, list[Signal]) -> Signal + retval = Signal(1 + self.depth, name=f"lru_{self.id}", reset=0) + m.d.comb += retval[-1].eq(self.state) + m.d.comb += ids[0].eq(self.id) + with m.If(self.state): if self.right_child is not None: - self.right_child.set_states_from_index(m, index[:-1]) + right_lru = self.right_child.get_lru(m, ids[1:]) + m.d.comb += retval[:-1].eq(right_lru) + with m.Else(): + if self.left_child is not None: + left_lru = self.left_child.get_lru(m, ids[1:]) + m.d.comb += retval[:-1].eq(left_lru) + return retval class TestPLRU(FHDLTestCase): - @unittest.skip("not finished yet") - def tst(self, BITS): - # type: (int) -> None - - # FIXME: figure out what BITS is supposed to mean -- I would have - # expected it to be the number of cache ways, or the number of state - # bits in PLRU, but it's neither of those, making me think whoever - # converted the code botched their math. - # - # Until that's figured out, this test is broken. - - dut = PLRU(BITS) + def tst(self, log2_num_ways, test_seq=None): + # type: (int, list[int | None] | None) -> None + + @plain_data() + class MyAssert: + __slots__ = "test", "en" + + def __init__(self, test, en): + # type: (Value, Signal) -> None + self.test = test + self.en = en + + asserts = [] # type: list[MyAssert] + + def assert_(test): + if test_seq is None: + return [Assert(test, src_loc_at=1)] + assert_en = Signal(name="assert_en", src_loc_at=1, reset=False) + asserts.append(MyAssert(test=test, en=assert_en)) + return [assert_en.eq(True)] + + dut = PLRU(log2_num_ways, debug=True) # check debug works + write_il(self, dut, ports=dut.ports()) + # debug clutters up vcd, so disable it for formal proofs + dut = PLRU(log2_num_ways, debug=test_seq is not None) + num_ways = 1 << log2_num_ways + self.assertEqual(dut.log2_num_ways, log2_num_ways) + self.assertEqual(dut.num_ways, num_ways) + self.assertIsInstance(dut.acc_i, Signal) + self.assertIsInstance(dut.acc_en_i, Signal) + self.assertIsInstance(dut.lru_o, Signal) + self.assertEqual(len(dut.acc_i), log2_num_ways) + self.assertEqual(len(dut.acc_en_i), 1) + self.assertEqual(len(dut.lru_o), log2_num_ways) write_il(self, dut, ports=dut.ports()) m = Module() - nodes = [PLRUNode(Signal(name=f"state_{i}")) for i in range(dut.TLBSZ)] - self.assertEqual(len(dut._plru_tree), len(nodes)) - for i in range(1, dut.TLBSZ): - parent = (i + 1) // 2 - 1 - if i % 2: - nodes[parent].left_child = nodes[i] - else: - nodes[parent].right_child = nodes[i] - m.d.comb += Assert(nodes[i].state == dut._plru_tree[i]) - - in_index = Signal(range(BITS)) - - m.d.comb += [ - in_index.eq(AnySeq(range(BITS))), - Assume(in_index < BITS), - dut.acc_i.eq(1 << in_index), - dut.acc_en.eq(AnySeq(1)), - ] - - with m.If(dut.acc_en): - nodes[0].set_states_from_index(m, in_index) + nodes = [PLRUNode(i) for i in range(num_ways - 1)] + self.assertIsInstance(dut._tree, Array) + self.assertEqual(len(dut._tree), len(nodes)) + for i in range(len(nodes)): + if i != 0: + parent = (i + 1) // 2 - 1 + if i % 2: + nodes[parent].left_child = nodes[i] + else: + nodes[parent].right_child = nodes[i] + self.assertIsInstance(dut._tree[i], Signal) + self.assertEqual(len(dut._tree[i]), 1) + m.d.comb += assert_(nodes[i].state == dut._tree[i]) + + if test_seq is None: + m.d.comb += [ + dut.acc_i.eq(AnySeq(log2_num_ways)), + dut.acc_en_i.eq(AnySeq(1)), + ] + + l2nwr = range(log2_num_ways) + upd_ids = [Signal(log2_num_ways, name=f"upd_id_{i}") for i in l2nwr] + with m.If(dut.acc_en_i): + nodes[0].set_states_from_index(m, dut.acc_i, upd_ids) + + self.assertEqual(len(dut._upd_lru_nodes), len(upd_ids)) + for l, r in zip(dut._upd_lru_nodes, upd_ids): + m.d.comb += assert_(l == r) + + get_ids = [Signal(log2_num_ways, name=f"get_id_{i}") for i in l2nwr] + lru = Signal(log2_num_ways) + m.d.comb += lru.eq(nodes[0].get_lru(m, get_ids)) + m.d.comb += assert_(dut.lru_o == lru) + self.assertEqual(len(dut._get_lru_nodes), len(get_ids)) + for l, r in zip(dut._get_lru_nodes, get_ids): + m.d.comb += assert_(l == r) nodes[0].pretty_print() m.submodules.dut = dut - self.assertFormal(m, mode="prove") + if test_seq is None: + self.assertFormal(m, mode="prove", depth=2) + else: + traces = [dut.acc_i, dut.acc_en_i, *dut._tree] + for node in nodes: + traces.append(node.state) + traces += [ + dut.lru_o, lru, *dut._get_lru_nodes, *get_ids, + *dut._upd_lru_nodes, *upd_ids, + ] + + def subtest(acc_i, acc_en_i): + yield dut.acc_i.eq(acc_i) + yield dut.acc_en_i.eq(acc_en_i) + yield Tick() + yield Delay(0.7e-6) + for a in asserts: + if (yield a.en): + with self.subTest( + assert_loc=':'.join(map(str, a.en.src_loc))): + self.assertTrue((yield a.test)) + + def process(): + for test_item in test_seq: + if test_item is None: + with self.subTest(test_item="None"): + yield from subtest(acc_i=0, acc_en_i=0) + else: + with self.subTest(test_item=hex(test_item)): + yield from subtest(acc_i=test_item, acc_en_i=1) + + with do_sim(self, m, traces) as sim: + sim.add_clock(1e-6) + sim.add_process(process) + sim.run() def test_bits_1(self): self.tst(1) @@ -133,35 +232,13 @@ class TestPLRU(FHDLTestCase): def test_bits_6(self): self.tst(6) - def test_bits_7(self): - self.tst(7) - - def test_bits_8(self): - self.tst(8) - - def test_bits_9(self): - self.tst(9) - - def test_bits_10(self): - self.tst(10) - - def test_bits_11(self): - self.tst(11) - - def test_bits_12(self): - self.tst(12) - - def test_bits_13(self): - self.tst(13) - - def test_bits_14(self): - self.tst(14) - - def test_bits_15(self): - self.tst(15) - - def test_bits_16(self): - self.tst(16) + def test_bits_3_sim(self): + self.tst(3, [ + 0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, + None, + 0x0, 0x4, 0x2, 0x6, 0x1, 0x5, 0x3, 0x7, + None, + ]) if __name__ == "__main__": diff --git a/src/nmutil/plru2.py b/src/nmutil/plru2.py index aab9ac8..d766f6c 100644 --- a/src/nmutil/plru2.py +++ b/src/nmutil/plru2.py @@ -1,12 +1,12 @@ -# based on ariane plru, from tlb.sv +# based on microwatt plru.vhdl +# https://github.com/antonblanchard/microwatt/blob/f67b1431655c291fc1c99857a5c1ef624d5b264c/plru.vhdl # new PLRU API, once all users have migrated to new API in plru2.py, then # plru2.py will be renamed to plru.py. -from nmigen import Signal, Module, Cat, Const, Repl, Array -from nmigen.hdl.ir import Elaboratable +from nmigen.hdl.ir import Elaboratable, Display, Signal, Array, Const, Value +from nmigen.hdl.dsl import Module from nmigen.cli import rtlil -from nmigen.utils import log2_int from nmigen.lib.coding import Decoder @@ -17,149 +17,161 @@ class PLRU(Elaboratable): lvl0 0 / \ / \ - lvl1 1 2 - / \ / \ - lvl2 3 4 5 6 - / \ /\/\ /\ + / \ + lvl1 1 2 + / \ / \ + lvl2 3 4 5 6 + / \ / \ / \ / \ ... ... ... ... """ - def __init__(self, BITS): - self.BITS = BITS - self.acc_i = Signal(BITS) - self.acc_en = Signal() - self.lru_o = Signal(BITS) - - self._plru_tree = Signal(self.TLBSZ) + def __init__(self, log2_num_ways, debug=False): + # type: (int, bool) -> None + """ + Arguments: + log2_num_ways: int + the log-base-2 of the number of cache ways -- BITS in plru.vhdl + debug: bool + true if this should print debugging messages at simulation time. + """ + assert log2_num_ways > 0 + self.log2_num_ways = log2_num_ways + self.debug = debug + self.acc_i = Signal(log2_num_ways) + self.acc_en_i = Signal() + self.lru_o = Signal(log2_num_ways) + + def mk_tree(i): + return Signal(name=f"tree_{i}", reset=0) + + # original vhdl has array 1 too big, last entry is never used, + # subtract 1 to compensate + self._tree = Array(mk_tree(i) for i in range(self.num_ways - 1)) """ exposed only for testing """ - @property - def TLBSZ(self): - return 2 * (self.BITS - 1) - - def elaborate(self, platform=None): - m = Module() + def mk_node(i, prefix): + return Signal(range(self.num_ways), name=f"{prefix}_node_{i}", + reset=0) - # Tree (bit per entry) - - # Just predefine which nodes will be set/cleared - # E.g. for a TLB with 8 entries, the for-loop is semantically - # equivalent to the following pseudo-code: - # unique case (1'b1) - # acc_en[7]: plru_tree[0, 2, 6] = {1, 1, 1}; - # acc_en[6]: plru_tree[0, 2, 6] = {1, 1, 0}; - # acc_en[5]: plru_tree[0, 2, 5] = {1, 0, 1}; - # acc_en[4]: plru_tree[0, 2, 5] = {1, 0, 0}; - # acc_en[3]: plru_tree[0, 1, 4] = {0, 1, 1}; - # acc_en[2]: plru_tree[0, 1, 4] = {0, 1, 0}; - # acc_en[1]: plru_tree[0, 1, 3] = {0, 0, 1}; - # acc_en[0]: plru_tree[0, 1, 3] = {0, 0, 0}; - # default: begin /* No hit */ end - # endcase - - LOG_TLB = log2_int(self.BITS, False) - hit = Signal(self.BITS, reset_less=True) - m.d.comb += hit.eq(Repl(self.acc_en, self.BITS) & self.acc_i) - - for i in range(self.BITS): - # we got a hit so update the pointer as it was least recently used - with m.If(hit[i]): - # Set the nodes to the values we would expect - for lvl in range(LOG_TLB): - idx_base = (1 << lvl)-1 - # lvl0 <=> MSB, lvl1 <=> MSB-1, ... - shift = LOG_TLB - lvl - new_idx = Const(~((i >> (shift-1)) & 1), 1) - plru_idx = idx_base + (i >> shift) - # print("plru", i, lvl, hex(idx_base), - # plru_idx, shift, new_idx) - m.d.sync += self._plru_tree[plru_idx].eq(new_idx) - - # Decode tree to write enable signals - # Next for-loop basically creates the following logic for e.g. - # an 8 entry TLB (note: pseudo-code obviously): - # replace_en[7] = &plru_tree[ 6, 2, 0]; #plru_tree[0,2,6]=={1,1,1} - # replace_en[6] = &plru_tree[~6, 2, 0]; #plru_tree[0,2,6]=={1,1,0} - # replace_en[5] = &plru_tree[ 5,~2, 0]; #plru_tree[0,2,5]=={1,0,1} - # replace_en[4] = &plru_tree[~5,~2, 0]; #plru_tree[0,2,5]=={1,0,0} - # replace_en[3] = &plru_tree[ 4, 1,~0]; #plru_tree[0,1,4]=={0,1,1} - # replace_en[2] = &plru_tree[~4, 1,~0]; #plru_tree[0,1,4]=={0,1,0} - # replace_en[1] = &plru_tree[ 3,~1,~0]; #plru_tree[0,1,3]=={0,0,1} - # replace_en[0] = &plru_tree[~3,~1,~0]; #plru_tree[0,1,3]=={0,0,0} - # For each entry traverse the tree. If every tree-node matches - # the corresponding bit of the entry's index, this is - # the next entry to replace. - replace = [] - for i in range(self.BITS): - en = [] - for lvl in range(LOG_TLB): - idx_base = (1 << lvl)-1 - # lvl0 <=> MSB, lvl1 <=> MSB-1, ... - shift = LOG_TLB - lvl - new_idx = (i >> (shift-1)) & 1 - plru_idx = idx_base + (i >> shift) - plru = Signal(reset_less=True, - name="plru-%d-%d-%d-%d" % - (i, lvl, plru_idx, new_idx)) - m.d.comb += plru.eq(self._plru_tree[plru_idx]) - if new_idx: - en.append(~plru) # yes inverted (using bool() below) - else: - en.append(plru) # yes inverted (using bool() below) - #print("plru", i, en) - # boolean logic manipulation: - # plru0 & plru1 & plru2 == ~(~plru0 | ~plru1 | ~plru2) - replace.append(~Cat(*en).bool()) - m.d.comb += self.lru_o.eq(Cat(*replace)) + nodes_range = range(self.log2_num_ways) - return m - - def ports(self): - return [self.acc_en, self.lru_o, self.acc_i] + self._get_lru_nodes = [mk_node(i, "get_lru") for i in nodes_range] + """ exposed only for testing """ + self._upd_lru_nodes = [mk_node(i, "upd_lru") for i in nodes_range] + """ exposed only for testing """ -class PLRUs(Elaboratable): - def __init__(self, n_plrus, n_bits): - self.n_plrus = n_plrus - self.n_bits = n_bits - self.valid = Signal() - self.way = Signal(n_bits) - self.index = Signal(n_plrus.bit_length()) - self.isel = Signal(n_plrus.bit_length()) - self.o_index = Signal(n_bits) + @property + def num_ways(self): + return 1 << self.log2_num_ways + + def _display(self, msg, *args): + if not self.debug: + return [] + # work around not yet having + # https://gitlab.com/nmigen/nmigen/-/merge_requests/10 + # by sending through Value.cast() + return [Display(msg, *map(Value.cast, args))] + + def _get_lru(self, m): + """ get_lru process in plru.vhdl """ + # XXX Check if we can turn that into a little ROM instead that + # takes the tree bit vector and returns the LRU. See if it's better + # in term of FPGA resource usage... + m.d.comb += self._get_lru_nodes[0].eq(0) + for i in range(self.log2_num_ways): + node = self._get_lru_nodes[i] + val = self._tree[node] + m.d.comb += self._display("GET: i:%i node:%#x val:%i", + i, node, val) + m.d.comb += self.lru_o[self.log2_num_ways - 1 - i].eq(val) + if i != self.log2_num_ways - 1: + # modified from microwatt version, it uses `node * 2` value + # to index into tree, rather than using node like is used + # earlier in this loop iteration + node <<= 1 + with m.If(val): + m.d.comb += self._get_lru_nodes[i + 1].eq(node + 2) + with m.Else(): + m.d.comb += self._get_lru_nodes[i + 1].eq(node + 1) + + def _update_lru(self, m): + """ update_lru process in plru.vhdl """ + with m.If(self.acc_en_i): + m.d.comb += self._upd_lru_nodes[0].eq(0) + for i in range(self.log2_num_ways): + node = self._upd_lru_nodes[i] + abit = self.acc_i[self.log2_num_ways - 1 - i] + m.d.sync += [ + self._tree[node].eq(~abit), + self._display("UPD: i:%i node:%#x val:%i", + i, node, ~abit), + ] + if i != self.log2_num_ways - 1: + node <<= 1 + with m.If(abit): + m.d.comb += self._upd_lru_nodes[i + 1].eq(node + 2) + with m.Else(): + m.d.comb += self._upd_lru_nodes[i + 1].eq(node + 1) - def elaborate(self, platform): - """Generate TLB PLRUs - """ + def elaborate(self, platform=None): m = Module() - comb = m.d.comb - - if self.n_plrus == 0: - return m - - # Binary-to-Unary one-hot, enabled by valid - m.submodules.te = te = Decoder(self.n_plrus) - comb += te.n.eq(~self.valid) - comb += te.i.eq(self.index) - - out = Array(Signal(self.n_bits, name="plru_out%d" % x) - for x in range(self.n_plrus)) - - for i in range(self.n_plrus): - # PLRU interface - m.submodules["plru_%d" % i] = plru = PLRU(self.n_bits) - - comb += plru.acc_en.eq(te.o[i]) - comb += plru.acc_i.eq(self.way) - comb += out[i].eq(plru.lru_o) - - # select output based on index - comb += self.o_index.eq(out[self.isel]) - + self._get_lru(m) + self._update_lru(m) return m + def __iter__(self): + yield self.acc_i + yield self.acc_en_i + yield self.lru_o + def ports(self): - return [self.valid, self.way, self.index, self.isel, self.o_index] + return list(self) + + +# FIXME: convert PLRUs to new API +# class PLRUs(Elaboratable): +# def __init__(self, n_plrus, n_bits): +# self.n_plrus = n_plrus +# self.n_bits = n_bits +# self.valid = Signal() +# self.way = Signal(n_bits) +# self.index = Signal(n_plrus.bit_length()) +# self.isel = Signal(n_plrus.bit_length()) +# self.o_index = Signal(n_bits) +# +# def elaborate(self, platform): +# """Generate TLB PLRUs +# """ +# m = Module() +# comb = m.d.comb +# +# if self.n_plrus == 0: +# return m +# +# # Binary-to-Unary one-hot, enabled by valid +# m.submodules.te = te = Decoder(self.n_plrus) +# comb += te.n.eq(~self.valid) +# comb += te.i.eq(self.index) +# +# out = Array(Signal(self.n_bits, name="plru_out%d" % x) +# for x in range(self.n_plrus)) +# +# for i in range(self.n_plrus): +# # PLRU interface +# m.submodules["plru_%d" % i] = plru = PLRU(self.n_bits) +# +# comb += plru.acc_en.eq(te.o[i]) +# comb += plru.acc_i.eq(self.way) +# comb += out[i].eq(plru.lru_o) +# +# # select output based on index +# comb += self.o_index.eq(out[self.isel]) +# +# return m +# +# def ports(self): +# return [self.valid, self.way, self.index, self.isel, self.o_index] if __name__ == '__main__': @@ -168,7 +180,7 @@ if __name__ == '__main__': with open("test_plru.il", "w") as f: f.write(vl) - dut = PLRUs(4, 2) - vl = rtlil.convert(dut, ports=dut.ports()) - with open("test_plrus.il", "w") as f: - f.write(vl) + # dut = PLRUs(4, 2) + # vl = rtlil.convert(dut, ports=dut.ports()) + # with open("test_plrus.il", "w") as f: + # f.write(vl) -- 2.30.2