1 # SPDX-License-Identifier: LGPL-3-or-later
2 # Copyright 2021 Jacob Lifshay programmerjake@gmail.com
4 # Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
5 # of Horizon 2020 EU Programme 957073.
7 """ Carry-less Multiplication.
9 https://bugs.libre-soc.org/show_bug.cgi?id=784
12 from functools
import reduce
13 from operator
import xor
14 from nmigen
.hdl
.ir
import Elaboratable
15 from nmigen
.hdl
.ast
import Signal
, Cat
, Repl
, Value
16 from nmigen
.hdl
.dsl
import Module
class BitwiseXorReduce(Elaboratable):
    """Bitwise Xor lots of stuff together by using tree-reduction on each bit.

    Properties:
    input_values: tuple[Value, ...]
        the inputs, each converted with `Value.cast`
    output: Signal
        output, set to `input_values[0] ^ input_values[1] ^ input_values[2]...`
    """

    def __init__(self, input_values):
        """Store the inputs and create the `output` Signal.

        input_values: iterable of objects accepted by `Value.cast`;
            must contain at least one item (checked with an assert).
        """
        self.input_values = tuple(map(Value.cast, input_values))
        assert len(self.input_values) > 0, "can't xor-reduce nothing"
        # the xor-expression built here is only used to get the shape of
        # the result; the actual logic is constructed in `elaborate`
        self.output = Signal(reduce(xor, self.input_values).shape())

    def elaborate(self, platform):
        """Build and return the Module that drives `self.output`."""
        m = Module()
        # collect inputs into full-width Signals
        inputs = []
        for i, inp_v in enumerate(self.input_values):
            inp = Signal.like(self.output, name=f"input_{i}")
            # sign/zero-extend inp_v to full-width
            m.d.comb += inp.eq(inp_v)
            inputs.append(inp)
        for bit in range(self.output.width):
            # construct a tree-reduction for bit index `bit` of all inputs
            m.d.comb += self.output[bit].eq(Cat(i[bit] for i in inputs).xor())
        return m
class CLMulAdd(Elaboratable):
    """Carry-less multiply-add.

    Computes:
    ```
    self.output = (clmul(self.factor1, self.factor2) ^ self.terms[0]
        ^ self.terms[1] ^ self.terms[2] ...)
    ```

    Properties:
    factor_width: int
        the bit-width of `factor1` and `factor2`
    term_widths: tuple[int, ...]
        the bit-width of each Signal in `terms`
    factor1: Signal of width self.factor_width
        the first input to the carry-less multiplication section
    factor2: Signal of width self.factor_width
        the second input to the carry-less multiplication section
    terms: tuple[Signal, ...]
        inputs to be carry-less added (really XOR)
    output: Signal
        the result; its width is the maximum of `factor_width * 2 - 1`
        and every entry of `term_widths`
    """

    def __init__(self, factor_width, term_widths=()):
        """Create the input and output Signals.

        factor_width: int >= 1 (checked with an assert)
        term_widths: iterable of values convertible with `int`, one per
            extra term to XOR into the product; defaults to no terms.
        """
        assert isinstance(factor_width, int) and factor_width >= 1
        self.factor_width = factor_width
        self.term_widths = tuple(map(int, term_widths))

        # build Signals
        self.factor1 = Signal(self.factor_width)
        self.factor2 = Signal(self.factor_width)
        self.terms = tuple(
            Signal(width, name=f"term_{i}")
            for i, width in enumerate(self.term_widths))
        # the carry-less product of two factor_width-bit values needs
        # factor_width * 2 - 1 bits; the output must also fit every term
        self.output = Signal(max((self.factor_width * 2 - 1,
                                  *self.term_widths)))

    def __reduce_inputs(self):
        """Yield every Value that has to be XOR-ed together for `output`."""
        for shift in range(self.factor_width):
            # partial product: factor1 if bit `shift` of factor2 is set,
            # otherwise zero, moved into position `shift`
            mask = Repl(self.factor2[shift], self.factor_width)
            yield (self.factor1 & mask) << shift
        # the "add" part of multiply-add
        yield from self.terms

    def elaborate(self, platform):
        """Build and return the Module that drives `self.output`."""
        m = Module()
        xor_reduce = BitwiseXorReduce(self.__reduce_inputs())
        m.submodules.xor_reduce = xor_reduce
        m.d.comb += self.output.eq(xor_reduce.output)
        return m