8f79c5e3aee4b961570f8e9967a96ea8f529e5e2
[nmigen-gf.git] / src / nmigen_gf / hdl / clmul.py
1 # SPDX-License-Identifier: LGPL-3-or-later
2 # Copyright 2022 Jacob Lifshay programmerjake@gmail.com
3
4 # Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
5 # of Horizon 2020 EU Programme 957073.
6
7 """ Carry-less Multiplication.
8
9 https://bugs.libre-soc.org/show_bug.cgi?id=784
10 """
11
12 import operator
13 from nmigen.hdl.ir import Elaboratable
14 from nmigen.hdl.ast import Signal, Repl
15 from nmigen.hdl.dsl import Module
16 from nmutil.util import treereduce
17
18
class CLMulAdd(Elaboratable):
    """Carry-less multiply-add. (optional add)

    Computes:
    ```
    self.output = (clmul(self.factor1, self.factor2) ^
                   self.terms[0] ^
                   self.terms[1] ^
                   self.terms[2] ...)
    ```

    Properties:
    factor_width: int
        the bit-width of `factor1` and `factor2`
    term_widths: tuple[int, ...]
        the bit-width of each Signal in `terms`
    factor1: Signal of width self.factor_width
        the first input to the carry-less multiplication section
    factor2: Signal of width self.factor_width
        the second input to the carry-less multiplication section
    terms: list[Signal]
        inputs to be carry-less added (really XOR), one per entry of
        `term_widths`, named `term_0`, `term_1`, ...
    output: Signal of width max(2 * self.factor_width - 1, *self.term_widths)
        the final output
    """

    def __init__(self, factor_width, term_widths=()):
        assert isinstance(factor_width, int) and factor_width >= 1
        self.factor_width = factor_width
        self.term_widths = tuple(map(int, term_widths))

        # build Signals
        self.factor1 = Signal(self.factor_width)
        self.factor2 = Signal(self.factor_width)

        # build terms at requested widths (if any). This stays a list, as in
        # earlier versions, in case callers rely on list behaviour.
        self.terms = [Signal(width, name=f"term_{i}")
                      for i, width in enumerate(self.term_widths)]

        # build output at the maximum bit-width covering all inputs: the
        # carry-less product of two N-bit factors has at most 2*N-1 bits.
        self.output = Signal(max((self.factor_width * 2 - 1,
                                  *self.term_widths)))

    def elaborate(self, platform):
        m = Module()

        # one partial product per bit of factor2: factor1 AND-masked by a
        # replicated copy of that bit, then shifted left by its position.
        part_prods = []
        for shift in range(self.factor_width):
            part_prod = Signal(self.output.width, name=f"part_prod_{shift}")
            mask = Repl(self.factor2[shift], self.factor_width)
            m.d.comb += part_prod.eq((self.factor1 & mask) << shift)
            part_prods.append(part_prod)

        # XOR all partial products and the extra terms together (via
        # nmutil's treereduce). Unpacking into a fresh list also works if
        # `self.terms` has been replaced by a tuple.
        output = treereduce([*part_prods, *self.terms], operator.xor)
        m.d.comb += self.output.eq(output)
        return m