soc/cores: add ECC (Error Correcting Code)
authorFlorent Kermarrec <florent@enjoy-digital.fr>
Sat, 13 Jul 2019 09:44:29 +0000 (11:44 +0200)
committerFlorent Kermarrec <florent@enjoy-digital.fr>
Sat, 13 Jul 2019 09:44:29 +0000 (11:44 +0200)
Hamming codes with additional parity (SECDED):
- Single Error Correction
- Double Error Detection

litex/soc/cores/ecc.py [new file with mode: 0644]
test/test_ecc.py [new file with mode: 0644]

diff --git a/litex/soc/cores/ecc.py b/litex/soc/cores/ecc.py
new file mode 100644 (file)
index 0000000..0a63c08
--- /dev/null
@@ -0,0 +1,155 @@
+# This file is Copyright (c) 2018-2019 Florent Kermarrec <florent@enjoy-digital.fr>
+# License: BSD
+
+"""
+Error Correcting Code
+
+Hamming codes with additional parity (SECDED):
+- Single Error Correction
+- Double Error Detection
+"""
+
+from functools import reduce
+from operator import xor
+
+from migen import *
+
+
+def compute_m_n(k):
+    m = 1
+    while (2**m < (m + k + 1)):
+        m = m + 1;
+    n = m + k
+    return m, n
+
+
+def compute_syndrome_positions(m):
+    r = []
+    i = 1
+    while i <= m:
+        r.append(i)
+        i = i << 1
+    return r
+
+
+def compute_data_positions(m):
+    r = []
+    e = compute_syndrome_positions(m)
+    for i in range(1, m + 1):
+        if not i in e:
+            r.append(i)
+    return r
+
+
+def compute_cover_positions(m, p):
+    r = []
+    i = p
+    while i <= m:
+        for j in range(min(p, m - i + 1)):
+            r.append(i + j)
+        i += 2*p
+    return r
+
+
+class SECDED:
+    def place_data(self, data, codeword):
+        d_pos = compute_data_positions(len(codeword))
+        for i, d in enumerate(d_pos):
+            self.comb += codeword[d-1].eq(data[i])
+
+    def extract_data(self, codeword, data):
+        d_pos = compute_data_positions(len(codeword))
+        for i, d in enumerate(d_pos):
+            self.comb += data[i].eq(codeword[d-1])
+
+    def compute_syndrome(self, codeword, syndrome):
+        p_pos = compute_syndrome_positions(len(codeword))
+        for i, p in enumerate(p_pos):
+            pn = Signal()
+            c_pos = compute_cover_positions(len(codeword), 2**i)
+            for c in c_pos:
+                new_pn = Signal()
+                self.comb += new_pn.eq(pn ^ codeword[c-1])
+                pn = new_pn
+            self.comb += syndrome[i].eq(pn)
+
+    def place_syndrome(self, syndrome, codeword):
+        p_pos = compute_syndrome_positions(len(codeword))
+        for i, p in enumerate(p_pos):
+            self.comb += codeword[p-1].eq(syndrome[i])
+
+    def compute_parity(self, codeword, parity):
+        self.comb += parity.eq(reduce(xor,
+            [codeword[i] for i in range(len(codeword))]))
+
+
+class ECCEncoder(SECDED, Module):
+    def __init__(self, k):
+        m, n = compute_m_n(k)
+
+        self.i = i = Signal(k)
+        self.o = o = Signal(n + 1)
+
+        # # #
+
+        syndrome = Signal(m)
+        parity = Signal()
+        codeword_d = Signal(n)
+        codeword_d_p = Signal(n)
+        codeword = Signal(n + 1)
+
+        # place data bits in codeword
+        self.place_data(i, codeword_d)
+        # compute and place syndrome bits
+        self.compute_syndrome(codeword_d, syndrome)
+        self.comb += codeword_d_p.eq(codeword_d)
+        self.place_syndrome(syndrome, codeword_d_p)
+        # compute parity
+        self.compute_parity(codeword_d_p, parity)
+        # output codeword + parity
+        self.comb += o.eq(Cat(parity, codeword_d_p))
+
+
+class ECCDecoder(SECDED, Module):
+    def __init__(self, k):
+        m, n = compute_m_n(k)
+
+        self.enable = Signal()
+        self.i = i = Signal(n + 1)
+        self.o = o = Signal(k)
+
+        self.sec = sec = Signal()
+        self.ded = ded = Signal()
+
+        # # #
+
+        syndrome = Signal(m)
+        parity = Signal()
+        codeword = Signal(n)
+        codeword_c = Signal(n)
+
+        # input codeword + parity
+        self.compute_parity(i, parity)
+        self.comb += codeword.eq(i[1:])
+        # compute_syndrome
+        self.compute_syndrome(codeword, syndrome)
+        self.comb += If(~self.enable, syndrome.eq(0))
+        # locate/correct codeword error bit if any and flip it
+        cases = {}
+        cases["default"] = codeword_c.eq(codeword)
+        for i in range(1, 2**len(syndrome)):
+            cases[i] = codeword_c.eq(codeword ^ (1<<(i-1)))
+        self.comb += Case(syndrome, cases)
+        # extract data / status
+        self.extract_data(codeword_c, o)
+        self.comb += [
+            If(syndrome != 0,
+                 # double error detected
+                If(~parity,
+                    ded.eq(1)
+                # single error corrected
+                ).Else(
+                    sec.eq(1)
+                )
+            )
+        ]
diff --git a/test/test_ecc.py b/test/test_ecc.py
new file mode 100644 (file)
index 0000000..ca7929f
--- /dev/null
@@ -0,0 +1,100 @@
+# This file is Copyright (c) 2018-2019 Florent Kermarrec <florent@enjoy-digital.fr>
+# License: BSD
+
+import unittest
+import random
+
+from migen import *
+
+from litedram.common import *
+from litedram.frontend.ecc import *
+
+from litex.gen.sim import *
+
+
+class TestECC(unittest.TestCase):
+    def test_m_n(self):
+        m, n = compute_m_n(15)
+        self.assertEqual(m, 5)
+        self.assertEqual(n, 20)
+
+    def test_syndrome_positions(self):
+        p_pos = compute_syndrome_positions(20)
+        p_pos_ref = [1, 2, 4, 8, 16]
+        self.assertEqual(p_pos, p_pos_ref)
+
+    def test_data_positions(self):
+        d_pos = compute_data_positions(20)
+        d_pos_ref = [3, 5, 6, 7, 9, 10, 11, 12, 13, 14, 15, 17, 18, 19, 20]
+        self.assertEqual(d_pos, d_pos_ref)
+
+    def test_cover_positions(self):
+        c_pos_ref = {
+            0 : [1, 3, 5, 7, 9, 11, 13, 15, 17, 19],
+            1 : [2, 3, 6, 7, 10, 11, 14, 15, 18, 19],
+            2 : [4, 5, 6, 7, 12, 13, 14, 15, 20],
+            3 : [8, 9, 10, 11, 12, 13, 14, 15],
+            4 : [16, 17, 18, 19, 20]
+        }
+        for i in range(5):
+            c_pos = compute_cover_positions(20, 2**i)
+            self.assertEqual(c_pos, c_pos_ref[i])
+
+    def test_ecc(self, k=15):
+        class DUT(Module):
+            def __init__(self, k):
+                m, n = compute_m_n(k)
+                self.flip = Signal(n + 1)
+
+                # # #
+
+                self.submodules.encoder = ECCEncoder(k)
+                self.submodules.decoder = ECCDecoder(k)
+
+                self.comb += self.decoder.i.eq(self.encoder.o ^ self.flip)
+
+        def generator(dut, k, nvalues, nerrors):
+            dut.errors = 0
+            prng = random.Random(42)
+            yield dut.decoder.enable.eq(1)
+            for i in range(nvalues):
+                data = prng.randrange(2**k-1)
+                yield dut.encoder.i.eq(data)
+                # FIXME: error when fliping parity bit
+                if nerrors == 1:
+                    flip_bit1 = (prng.randrange(len(dut.flip)-2) + 1)
+                    yield dut.flip.eq(1<<flip_bit1)
+                elif nerrors == 2:
+                    flip_bit1 = (prng.randrange(len(dut.flip)-2) + 1)
+                    flip_bit2 = flip_bit1
+                    while flip_bit2 == flip_bit1:
+                        flip_bit2 = (prng.randrange(len(dut.flip)-2) + 1)
+                    yield dut.flip.eq((1<<flip_bit1) | (1<<flip_bit2))
+                yield
+                # if less than 2 errors, check data
+                if nerrors < 2:
+                    if (yield dut.decoder.o) != data:
+                        dut.errors += 1
+                # if 0 error, verify sec == 0 / ded == 0
+                if nerrors == 0:
+                    if (yield dut.decoder.sec) != 0:
+                        dut.errors += 1
+                    if (yield dut.decoder.ded) != 0:
+                        dut.errors += 1
+                # if 1 error, verify sec == 1 / dec == 0
+                elif nerrors == 1:
+                    if (yield dut.decoder.sec) != 1:
+                        dut.errors += 1
+                    if (yield dut.decoder.ded) != 0:
+                        dut.errors += 1
+                # if 2 errors, verify sec == 0 / ded == 1
+                elif nerrors == 2:
+                    if (yield dut.decoder.sec) != 0:
+                        dut.errors += 1
+                    if (yield dut.decoder.ded) != 1:
+                        dut.errors += 1
+
+        for i in range(3):
+            dut = DUT(k)
+            run_simulation(dut, generator(dut, k, 128, i))
+            self.assertEqual(dut.errors, 0)