move clmul files from nmutil.git
authorJacob Lifshay <programmerjake@gmail.com>
Mon, 4 Apr 2022 05:37:46 +0000 (22:37 -0700)
committerJacob Lifshay <programmerjake@gmail.com>
Mon, 4 Apr 2022 05:37:46 +0000 (22:37 -0700)
https://git.libre-soc.org/?p=nmutil.git;a=commit;h=2ef87c06d25b692ede35aa6340a108ba410b440a
nmutil.git/src/nmutil/clmul.py => src/nmigen_gf/hdl/clmul.py
nmutil.git/src/nmutil/test/test_clmul.py => src/nmigen_gf/hdl/test/test_clmul.py

src/nmigen_gf/hdl/clmul.py [new file with mode: 0644]
src/nmigen_gf/hdl/test/test_clmul.py [new file with mode: 0644]

diff --git a/src/nmigen_gf/hdl/clmul.py b/src/nmigen_gf/hdl/clmul.py
new file mode 100644 (file)
index 0000000..d25174f
--- /dev/null
@@ -0,0 +1,100 @@
+# SPDX-License-Identifier: LGPL-3-or-later
+# Copyright 2022 Jacob Lifshay programmerjake@gmail.com
+
+# Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
+# of Horizon 2020 EU Programme 957073.
+
+""" Carry-less Multiplication.
+
+https://bugs.libre-soc.org/show_bug.cgi?id=784
+"""
+
+from functools import reduce
+from operator import xor
+from nmigen.hdl.ir import Elaboratable
+from nmigen.hdl.ast import Signal, Cat, Repl, Value
+from nmigen.hdl.dsl import Module
+
+
+class BitwiseXorReduce(Elaboratable):
+    """Bitwise Xor lots of stuff together by using tree-reduction on each bit.
+
+    Properties:
+    input_values: tuple[Value, ...]
+        input nmigen Values
+    output: Signal
+        output, set to `input_values[0] ^ input_values[1] ^ input_values[2]...`
+    """
+
+    def __init__(self, input_values):
+        self.input_values = tuple(map(Value.cast, input_values))
+        assert len(self.input_values) > 0, "can't xor-reduce nothing"
+        self.output = Signal(reduce(xor, self.input_values).shape())
+
+    def elaborate(self, platform):
+        m = Module()
+        # collect inputs into full-width Signals
+        inputs = []
+        for i, inp_v in enumerate(self.input_values):
+            inp = self.output.like(self.output, name=f"input_{i}")
+            # sign/zero-extend inp_v to full-width
+            m.d.comb += inp.eq(inp_v)
+            inputs.append(inp)
+        for bit in range(self.output.width):
+            # construct a tree-reduction for bit index `bit` of all inputs
+            m.d.comb += self.output[bit].eq(Cat(i[bit] for i in inputs).xor())
+        return m
+
+
+class CLMulAdd(Elaboratable):
+    """Carry-less multiply-add.
+
+        Computes:
+        ```
+        self.output = (clmul(self.factor1, self.factor2) ^ self.terms[0]
+            ^ self.terms[1] ^ self.terms[2] ...)
+        ```
+
+        Properties:
+        factor_width: int
+            the bit-width of `factor1` and `factor2`
+        term_widths: tuple[int, ...]
+            the bit-width of each Signal in `terms`
+        factor1: Signal of width self.factor_width
+            the first input to the carry-less multiplication section
+        factor2: Signal of width self.factor_width
+            the second input to the carry-less multiplication section
+        terms: tuple[Signal, ...]
+            inputs to be carry-less added (really XOR)
+        output: Signal
+            the final output
+    """
+
+    def __init__(self, factor_width, term_widths=()):
+        assert isinstance(factor_width, int) and factor_width >= 1
+        self.factor_width = factor_width
+        self.term_widths = tuple(map(int, term_widths))
+
+        # build Signals
+        self.factor1 = Signal(self.factor_width)
+        self.factor2 = Signal(self.factor_width)
+
+        def terms():
+            for i, inp in enumerate(self.term_widths):
+                yield Signal(inp, name=f"term_{i}")
+        self.terms = tuple(terms())
+        self.output = Signal(max((self.factor_width * 2 - 1,
+                                  *self.term_widths)))
+
+    def __reduce_inputs(self):
+        for shift in range(self.factor_width):
+            mask = Repl(self.factor2[shift], self.factor_width)
+            yield (self.factor1 & mask) << shift
+        yield from self.terms
+
+    def elaborate(self, platform):
+        m = Module()
+        xor_reduce = BitwiseXorReduce(self.__reduce_inputs())
+        m.submodules.xor_reduce = xor_reduce
+        m.d.comb += self.output.eq(xor_reduce.output)
+        return m
diff --git a/src/nmigen_gf/hdl/test/test_clmul.py b/src/nmigen_gf/hdl/test/test_clmul.py
new file mode 100644 (file)
index 0000000..ffff3f0
--- /dev/null
@@ -0,0 +1,170 @@
+# SPDX-License-Identifier: LGPL-3-or-later
+# Copyright 2022 Jacob Lifshay
+
+# Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
+# of Horizon 2020 EU Programme 957073.
+
+from functools import reduce
+from operator import xor
+import unittest
+from nmigen.hdl.ast import (AnyConst, Assert, Signal, Const, unsigned, signed,
+                            Mux)
+from nmigen.hdl.dsl import Module
+from nmutil.formaltest import FHDLTestCase
+from nmutil.openpower_sv_bitmanip_in_wiki.clmul import clmul
+from nmutil.clmul import BitwiseXorReduce, CLMulAdd
+from nmigen.sim import Delay
+from nmutil.sim_util import do_sim, hash_256
+
+
+class TestBitwiseXorReduce(FHDLTestCase):
+    def tst(self, input_shapes):
+        dut = BitwiseXorReduce(Signal(w, name=f"input_{i}")
+                               for i, w in enumerate(input_shapes))
+        self.assertEqual(reduce(xor, dut.input_values).shape(),
+                         dut.output.shape())
+
+        def case(inputs):
+            expected = reduce(xor, inputs)
+            with self.subTest(inputs=list(map(hex, inputs)),
+                              expected=hex(expected)):
+                for i, inp in enumerate(inputs):
+                    yield dut.input_values[i].eq(inp)
+                yield Delay(1e-6)
+                output = yield dut.output
+                with self.subTest(output=hex(output)):
+                    self.assertEqual(expected, output)
+
+        def process():
+            for i in range(100):
+                inputs = []
+                for inp in dut.input_values:
+                    v = hash_256(f"bxorr input {i} {inp.name}")
+                    inputs.append(Const.normalize(v, inp.shape()))
+                yield from case(inputs)
+
+        with do_sim(self, dut, [*dut.input_values, dut.output]) as sim:
+            sim.add_process(process)
+            sim.run()
+
+    def tst_formal(self, input_shapes):
+        dut = BitwiseXorReduce(Signal(w, name=f"input_{i}")
+                               for i, w in enumerate(input_shapes))
+        m = Module()
+        m.submodules.dut = dut
+        for i in dut.input_values:
+            m.d.comb += i.eq(AnyConst(i.shape()))
+        m.d.comb += Assert(dut.output == reduce(xor, dut.input_values))
+        self.assertFormal(m)
+
+    def test_65_of_u64(self):
+        self.tst([64] * 65)
+
+    def test_formal_65_of_u64(self):
+        self.tst_formal([64] * 65)
+
+    def test_5_of_u6(self):
+        self.tst([6] * 5)
+
+    def test_formal_5_of_u6(self):
+        self.tst_formal([6] * 5)
+
+    def test_u5_i6_u3_i10(self):
+        self.tst([unsigned(5), signed(6), unsigned(3), signed(10)])
+
+    def test_formal_u5_i6_u3_i10(self):
+        self.tst_formal([unsigned(5), signed(6), unsigned(3), signed(10)])
+
+
+class TestCLMulAdd(FHDLTestCase):
+    def tst(self, factor_width, terms_width):
+        dut = CLMulAdd(factor_width, terms_width)
+        self.assertEqual(dut.output.width,
+                         max((factor_width * 2 - 1, *terms_width)))
+
+        def case(factor1, factor2, terms):
+            expected = reduce(xor, terms, clmul(factor1, factor2))
+            with self.subTest(factor1=hex(factor1),
+                              factor2=bin(factor2),
+                              terms=list(map(hex, terms)),
+                              expected=hex(expected)):
+                yield dut.factor1.eq(factor1)
+                yield dut.factor2.eq(factor2)
+                for i, term in enumerate(terms):
+                    yield dut.terms[i].eq(term)
+                yield Delay(1e-6)
+                output = yield dut.output
+                with self.subTest(output=hex(output)):
+                    self.assertEqual(expected, output)
+
+        def process():
+            for i in range(100):
+                v = hash_256(f"clmuladd term {i} factor1")
+                factor1 = Const.normalize(v, unsigned(factor_width))
+                v = hash_256(f"clmuladd term {i} factor2")
+                factor2 = Const.normalize(v, unsigned(factor_width))
+                terms = []
+                for j, term_width in enumerate(terms_width):
+                    v = hash_256(f"clmuladd term {i} {j}")
+                    terms.append(Const.normalize(v, unsigned(term_width)))
+                yield from case(factor1, factor2, terms)
+        with do_sim(self, dut, [dut.factor1, dut.factor2, *dut.terms,
+                                dut.output]) as sim:
+            sim.add_process(process)
+            sim.run()
+
+    def test_4x4(self):
+        self.tst(4, ())
+
+    def test_4x4_8(self):
+        self.tst(4, (8,))
+
+    def test_64x64(self):
+        self.tst(64, ())
+
+    def test_64x64_64(self):
+        self.tst(64, (64,))
+
+    def test_8x8_16_16_16(self):
+        self.tst(8, (16, 16, 16))
+
+    def tst_formal(self, factor_width, terms_width):
+        dut = CLMulAdd(factor_width, terms_width)
+        m = Module()
+        m.submodules.dut = dut
+        m.d.comb += dut.factor1.eq(AnyConst(factor_width))
+        m.d.comb += dut.factor2.eq(AnyConst(factor_width))
+        reduce_inputs = []
+        for shift in range(factor_width):
+            reduce_inputs.append(
+                Mux(dut.factor1[shift], dut.factor2 << shift, 0))
+        for i in dut.terms:
+            m.d.comb += i.eq(AnyConst(i.shape()))
+            reduce_inputs.append(i)
+        for i in range(len(reduce_inputs)):
+            sig = Signal(reduce_inputs[i].shape(), name=f"reduce_input_{i}")
+            m.d.comb += sig.eq(reduce_inputs[i])
+            reduce_inputs[i] = sig
+        expected = Signal(reduce(xor, reduce_inputs).shape())
+        m.d.comb += expected.eq(reduce(xor, reduce_inputs))
+        m.d.comb += Assert(dut.output == expected)
+        self.assertFormal(m)
+
+    def test_formal_4x4(self):
+        self.tst_formal(4, ())
+
+    def test_formal_4x4_8(self):
+        self.tst_formal(4, (8,))
+
+    def test_formal_64x64(self):
+        self.tst_formal(64, ())
+
+    def test_formal_64x64_64(self):
+        self.tst_formal(64, (64,))
+
+    def test_formal_8x8_16_16_16(self):
+        self.tst_formal(8, (16, 16, 16))
+
+
+if __name__ == "__main__":
+    unittest.main()