add CLMulAdd and tests
[nmutil.git] / src / nmutil / clmul.py
1 # SPDX-License-Identifier: LGPL-3-or-later
2 # Copyright 2021 Jacob Lifshay programmerjake@gmail.com
3
4 # Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
5 # of Horizon 2020 EU Programme 957073.
6
7 """ Carry-less Multiplication.
8
9 https://bugs.libre-soc.org/show_bug.cgi?id=784
10 """
11
12 from functools import reduce
13 from operator import xor
14 from nmigen.hdl.ir import Elaboratable
15 from nmigen.hdl.ast import Signal, Cat, Repl, Value
16 from nmigen.hdl.dsl import Module
17
18
class BitwiseXorReduce(Elaboratable):
    """Xor an arbitrary number of values together, one output bit at a time.

    Each bit of `output` is produced by an xor-reduction over the
    same-numbered bit of every input, after the inputs have been assigned
    into Signals shaped like `output`.

    Properties:
    input_values: tuple[Value, ...]
        the inputs, each converted with `Value.cast`
    output: Signal
        the result, `input_values[0] ^ input_values[1] ^ input_values[2]...`;
        its shape is the shape of that xor expression
    """

    def __init__(self, input_values):
        casted = tuple(Value.cast(value) for value in input_values)
        self.input_values = casted
        assert casted, "can't xor-reduce nothing"
        # the xor expression is only built here to learn the result's shape
        combined = reduce(xor, casted)
        self.output = Signal(combined.shape())

    def elaborate(self, platform):
        m = Module()
        # Give every input its own Signal shaped like the output; the
        # assignment sign/zero-extends the input to the full width.
        widened = []
        for index, value in enumerate(self.input_values):
            full_width = Signal.like(self.output, name=f"input_{index}")
            m.d.comb += full_width.eq(value)
            widened.append(full_width)
        # Drive each output bit from the xor-reduction of the matching bit
        # of all widened inputs.
        for bit in range(self.output.width):
            column = Cat(signal[bit] for signal in widened)
            m.d.comb += self.output[bit].eq(column.xor())
        return m
47
48
class CLMulAdd(Elaboratable):
    """Carry-less multiplication of two factors, xor-ed with extra terms.

    Computes:
    ```
    self.output = (clmul(self.factor1, self.factor2) ^ self.terms[0]
                   ^ self.terms[1] ^ self.terms[2] ...)
    ```

    Properties:
    factor_width: int
        the bit-width shared by `factor1` and `factor2`, at least 1
    term_widths: tuple[int, ...]
        the bit-width of each Signal in `terms`, in the same order
    factor1: Signal of width self.factor_width
        first operand of the carry-less multiplication
    factor2: Signal of width self.factor_width
        second operand of the carry-less multiplication
    terms: tuple[Signal, ...]
        Signals named `term_0`, `term_1`, ... that are xor-ed into the
        product (carry-less addition is XOR)
    output: Signal
        the result; its width is the largest of
        `2 * factor_width - 1` and every entry of `term_widths`
    """

    def __init__(self, factor_width, term_widths=()):
        assert isinstance(factor_width, int) and factor_width >= 1
        self.factor_width = factor_width
        self.term_widths = tuple(int(width) for width in term_widths)

        # build Signals
        self.factor1 = Signal(self.factor_width)
        self.factor2 = Signal(self.factor_width)
        self.terms = tuple(
            Signal(width, name=f"term_{index}")
            for index, width in enumerate(self.term_widths))
        # max() is given one list so that it also works with no terms at all
        self.output = Signal(max([2 * self.factor_width - 1,
                                  *self.term_widths]))

    def __reduce_inputs(self):
        # One partial product per bit of factor2: factor1 and-ed with that
        # bit replicated across the factor width, shifted left by the bit's
        # index.
        for bit_index in range(self.factor_width):
            select = Repl(self.factor2[bit_index], self.factor_width)
            gated = self.factor1 & select
            yield gated << bit_index
        # The additional terms follow the partial products.
        for term in self.terms:
            yield term

    def elaborate(self, platform):
        m = Module()
        # xor all partial products and terms together in one submodule
        reducer = BitwiseXorReduce(self.__reduce_inputs())
        m.submodules.xor_reduce = reducer
        m.d.comb += self.output.eq(reducer.output)
        return m