format code
[nmigen-gf.git] / src / nmigen_gf / hdl / clmul.py
1 # SPDX-License-Identifier: LGPL-3-or-later
2 # Copyright 2022 Jacob Lifshay programmerjake@gmail.com
3
4 # Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
5 # of Horizon 2020 EU Programme 957073.
6
7 """ Carry-less Multiplication.
8
9 https://bugs.libre-soc.org/show_bug.cgi?id=784
10 """
11
12 from functools import reduce
13 from operator import xor # XXX import operator then use operator.xor
14 from nmigen.hdl.ir import Elaboratable
15 from nmigen.hdl.ast import Signal, Cat, Repl, Value
16 from nmigen.hdl.dsl import Module
17
18
# XXX class to be removed https://bugs.libre-soc.org/show_bug.cgi?id=784
# functionality covered already entirely by nmutil.util.tree_reduce
class BitwiseXorReduce(Elaboratable):
    """XOR an arbitrary number of nmigen Values together.

    Every input is first copied into a Signal of the output's shape, then
    each output bit is produced as the XOR-reduction of that bit position
    across all of those widened inputs (one reduction per bit).

    Properties:
    input_values: tuple[Value, ...]
        the inputs, each converted with `Value.cast`
    output: Signal
        output, set to `input_values[0] ^ input_values[1] ^ input_values[2]...`
    """

    def __init__(self, input_values):
        values = tuple(Value.cast(v) for v in input_values)
        self.input_values = values
        assert values, "can't xor-reduce nothing"
        # the output takes whatever shape nmigen gives a plain `^` chain of
        # all the inputs, so it is wide enough for every one of them
        combined = reduce(xor, values)
        self.output = Signal(combined.shape())

    def elaborate(self, platform):
        m = Module()
        # copy each input into a Signal shaped like `output`; the `.eq`
        # assignment sign/zero-extends the input to the full width
        widened = []
        for idx, value in enumerate(self.input_values):
            sig = Signal.like(self.output, name=f"input_{idx}")
            m.d.comb += sig.eq(value)
            widened.append(sig)
        # every output bit is the XOR-reduction of the same bit of all inputs
        for bit in range(self.output.width):
            column = Cat(*(sig[bit] for sig in widened))
            m.d.comb += self.output[bit].eq(column.xor())
        return m
49
50
class CLMulAdd(Elaboratable):
    """Carry-less multiply-add. (optional add)

    Computes:
    ```
    self.output = (clmul(self.factor1, self.factor2) ^
                   self.terms[0] ^
                   self.terms[1] ^
                   self.terms[2] ...)
    ```

    Properties:
    factor_width: int
        the bit-width of `factor1` and `factor2`
    term_widths: tuple[int, ...]
        the bit-width of each Signal in `terms`
    factor1: Signal of width self.factor_width
        the first input to the carry-less multiplication section
    factor2: Signal of width self.factor_width
        the second input to the carry-less multiplication section
    terms: tuple[Signal, ...]
        inputs to be carry-less added (really XOR)
    output: Signal
        the final output; its width is the maximum of
        `factor_width * 2 - 1` (the widest possible carry-less product)
        and every entry of `term_widths`
    """

    def __init__(self, factor_width, term_widths=()):
        assert isinstance(factor_width, int) and factor_width >= 1
        self.factor_width = factor_width
        self.term_widths = tuple(map(int, term_widths))

        # build Signals
        self.factor1 = Signal(self.factor_width)
        self.factor2 = Signal(self.factor_width)

        # build terms at requested widths (if any); a tuple, as documented
        self.terms = tuple(Signal(width, name=f"term_{i}")
                           for i, width in enumerate(self.term_widths))

        # build output at the maximum bit-width covering all inputs: the
        # highest partial product is factor1 << (factor_width - 1), which
        # spans factor_width * 2 - 1 bits
        self.output = Signal(max((self.factor_width * 2 - 1,
                                  *self.term_widths)))

    def __reduce_inputs(self, m):
        """Yield every value that is XORed together to form `output`.

        First one partial product per bit of `factor2`: `factor1` masked by
        that bit (so it is zero when the bit is clear) and shifted left by
        the bit's index. Each partial product is driven into its own named
        temporary Signal in `m`'s comb domain. Then each of `terms` as-is.
        """
        for shift in range(self.factor_width):
            # all-ones when factor2[shift] is set, all-zeros otherwise
            mask = Repl(self.factor2[shift], self.factor_width)
            partial = (self.factor1 & mask) << shift
            partial_sig = Signal(partial.shape(), name=f"partial_{shift}")
            m.d.comb += partial_sig.eq(partial)
            yield partial_sig
        # terms are already Signals, so they need no temporaries
        yield from self.terms

    def elaborate(self, platform):
        m = Module()
        # XXX to be replaced with nmutil.util.tree_reduce()
        # `m` must be passed: __reduce_inputs adds its temporary Signals'
        # assignments to this module (omitting it raised TypeError)
        xor_reduce = BitwiseXorReduce(self.__reduce_inputs(m))
        m.submodules.xor_reduce = xor_reduce
        m.d.comb += self.output.eq(xor_reduce.output)
        return m