remove uses of BitwiseXorReduce
[nmigen-gf.git] / src / nmigen_gf / hdl / clmul.py
1 # SPDX-License-Identifier: LGPL-3-or-later
2 # Copyright 2022 Jacob Lifshay programmerjake@gmail.com
3
4 # Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
5 # of Horizon 2020 EU Programme 957073.
6
7 """ Carry-less Multiplication.
8
9 https://bugs.libre-soc.org/show_bug.cgi?id=784
10 """
11
12 from functools import reduce
13 import operator
14 from nmigen.hdl.ir import Elaboratable
15 from nmigen.hdl.ast import Signal, Cat, Repl, Value
16 from nmigen.hdl.dsl import Module
17 from nmutil.util import treereduce
18
19
# XXX class to be removed https://bugs.libre-soc.org/show_bug.cgi?id=784
# functionality covered already entirely by nmutil.util.treereduce
class BitwiseXorReduce(Elaboratable):
    """Bitwise Xor lots of stuff together by using tree-reduction on each bit.

    Properties:
    input_values: tuple[Value, ...]
        input nmigen Values (anything accepted by `Value.cast`)
    output: Signal
        output, set to `input_values[0] ^ input_values[1] ^ input_values[2]...`
    """

    def __init__(self, input_values):
        # normalize every input to a nmigen Value
        self.input_values = tuple(map(Value.cast, input_values))
        assert len(self.input_values) > 0, "can't xor-reduce nothing"
        # the output takes the shape nmigen infers for the xor of all inputs
        self.output = Signal(reduce(operator.xor, self.input_values).shape())

    def elaborate(self, platform):
        m = Module()
        # collect inputs into full-width Signals
        inputs = []
        for idx, inp_v in enumerate(self.input_values):
            # `like` does not use the instance it is called on, so call it on
            # the class and pass the template Signal explicitly
            inp = Signal.like(self.output, name=f"input_{idx}")
            # sign/zero-extend inp_v to full-width
            m.d.comb += inp.eq(inp_v)
            inputs.append(inp)
        for bit in range(self.output.width):
            # construct a tree-reduction for bit index `bit` of all inputs
            bits = Cat(inp[bit] for inp in inputs)
            m.d.comb += self.output[bit].eq(bits.xor())
        return m
50
51
class CLMulAdd(Elaboratable):
    """Carry-less multiply-add. (optional add)

    Computes:
    ```
    self.output = (clmul(self.factor1, self.factor2) ^
                   self.terms[0] ^
                   self.terms[1] ^
                   self.terms[2] ...)
    ```

    Properties:
    factor_width: int
        the bit-width of `factor1` and `factor2`
    term_widths: tuple[int, ...]
        the bit-width of each Signal in `terms`
    factor1: Signal of width self.factor_width
        the first input to the carry-less multiplication section
    factor2: Signal of width self.factor_width
        the second input to the carry-less multiplication section
    terms: list[Signal]
        inputs to be carry-less added (really XOR), one per entry of
        `term_widths`, named `term_0`, `term_1`, ...
    output: Signal
        the final output, `max(2 * factor_width - 1, *term_widths)` bits wide
    """

    def __init__(self, factor_width, term_widths=()):
        assert isinstance(factor_width, int) and factor_width >= 1
        self.factor_width = factor_width
        self.term_widths = tuple(map(int, term_widths))

        # build Signals
        self.factor1 = Signal(self.factor_width)
        self.factor2 = Signal(self.factor_width)

        # build terms at requested widths (if any); kept as a list because
        # elaborate() concatenates it with the list of partial products
        self.terms = [Signal(width, name=f"term_{i}")
                      for i, width in enumerate(self.term_widths)]

        # build output at the maximum bit-width covering all inputs: the
        # carry-less product of two N-bit factors needs 2*N-1 bits
        self.output = Signal(max((self.factor_width * 2 - 1,
                                  *self.term_widths)))

    def elaborate(self, platform):
        m = Module()

        # partial product `shift` is factor1 shifted left by `shift` when
        # bit `shift` of factor2 is set, and zero otherwise
        part_prods = []
        for shift in range(self.factor_width):
            part_prod = Signal(self.output.width, name=f"part_prod_{shift}")
            # replicate the factor2 bit across factor1's width to gate it
            mask = Repl(self.factor2[shift], self.factor_width)
            m.d.comb += part_prod.eq((self.factor1 & mask) << shift)
            part_prods.append(part_prod)

        # carry-less addition is XOR: combine all partial products and all
        # additive terms (part_prods is never empty since factor_width >= 1)
        output = treereduce(part_prods + self.terms, operator.xor)
        m.d.comb += self.output.eq(output)
        return m