speed up ==, hash, <, >, <=, and >= for plain_data
[nmutil.git] / src / nmutil / grev.py
1 # SPDX-License-Identifier: LGPL-3-or-later
2 # Copyright 2021 Jacob Lifshay programmerjake@gmail.com
3 # Copyright (C) 2021 Luke Kenneth Casson Leighton <lkcl@lkcl.net>
4
5 # Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
6 # of Horizon 2020 EU Programme 957073.
7
8 r"""Generalized bit-reverse.
9
10 https://bugs.libre-soc.org/show_bug.cgi?id=755
11
12 A generalized bit-reverse is the following operation:
13 grev(input, chunk_sizes):
14 for i in range(input.width):
15 j = i XOR chunk_sizes
16 output bit i = input bit j
17 return output
18
19 This is useful because many bit/byte reverse operations can be created by
20 setting `chunk_sizes` to different values. Some examples for a 64-bit
21 `grev` operation:
22 * `0b111111` -- reverse all bits in the 64-bit word
23 * `0b111000` -- reverse bytes in the 64-bit word
24 * `0b011000` -- reverse bytes in each 32-bit word independently
25 * `0b110000` -- reverse order of 16-bit words
26
27 This is implemented by using a series of `log2_width`
28 `width`-bit wide 2:1 muxes, arranged just like a butterfly network:
29 https://en.wikipedia.org/wiki/Butterfly_network
30
31 To compute `out = grev(inp, 0bxyz)`, where `x`, `y`, and `z` are single bits,
32 the following permutation network is used:
33
34 inp[0] inp[1] inp[2] inp[3] inp[4] inp[5] inp[6] inp[7]
35 | | | | | | | |
36 the value here is | | | | | | | |
37 grev(inp, 0b000): | | | | | | | |
38 | | | | | | | |
39 + + + + + + + +
40 |\ /| |\ /| |\ /| |\ /|
41 | \ / | | \ / | | \ / | | \ / |
42 | \ / | | \ / | | \ / | | \ / |
43 swap 1-bit words: | X | | X | | X | | X |
44 | / \ | | / \ | | / \ | | / \ |
45 | / \ | | / \ | | / \ | | / \ |
46 z--Mux z--Mux z--Mux z--Mux z--Mux z--Mux z--Mux z--Mux
47 | | | | | | | |
48 the value here is | | | | | | | |
49 grev(inp, 0b00z): | | | | | | | |
50 | | | | | | | |
51 | | +-----|-------+ | | +-----|-------+
52 | +-----|-|-----+ | | +-----|-|-----+ |
53 | | | | | | | | | | | |
54 swap 2-bit words: | | +-|-----|-----+ | | | +-|-----|-----+ |
55 +-|-----|-|---+ | | | +-|-----|-|---+ | | |
56 | | | | | | | | | | | | | | | |
57 | / | / \ | \ | | / | / \ | \ |
58 y--Mux y--Mux y--Mux y--Mux y--Mux y--Mux y--Mux y--Mux
59 | | | | | | | |
60 the value here is | | | | | | | |
61 grev(inp, 0b0yz): | | | | | | | |
62 | | | | | | | |
63 | | | | +-----|-------|-------|-------+
64 | | | +-----|-|-----|-------|-------+ |
65 | | +-----|-|-----|-|-----|-------+ | |
66 | +-----|-|-----|-|-----|-|-----+ | | |
67 swap 4-bit words: | | | | | | | | | | | |
68 | | | | | | +-|-----|-------|-------|-----+ |
69 | | | | +-|-----|-|-----|-------|-----+ | | |
70 | | +-|-----|-|-----|-|-----|-----+ | | | | |
71 +-|-----|-|-----|-|-----|-|---+ | | | | | | |
72 | | | | | | | | | | | | | | | |
73 | / | / | / | / \ | \ | \ | \ |
74 x--Mux x--Mux x--Mux x--Mux x--Mux x--Mux x--Mux x--Mux
75 | | | | | | | |
76 the value here is | | | | | | | |
77 grev(inp, 0bxyz): | | | | | | | |
78 | | | | | | | |
79 out[0] out[1] out[2] out[3] out[4] out[5] out[6] out[7]
80 """
81
82 from nmigen.hdl.ast import Signal, Mux, Cat
83 from nmigen.hdl.ast import Assert
84 from nmigen.hdl.dsl import Module
85 from nmigen.hdl.ir import Elaboratable
86 from nmigen.back import rtlil
87 import string
88
89
def grev(inval, chunk_sizes, log2_width):
    """Python reference model of the generalized bit-reverse done by `GRev`.

    See `GRev` and the module docs for the full description. Input bit `i`
    is moved to output bit `i ^ chunk_sizes`. Only the low
    `2 ** log2_width` bits of `inval` and the low `log2_width` bits of
    `chunk_sizes` are used.

    The body sticks to `&`, `!=`, `<<`, `^` and `|` (no `if`) so that the
    arguments may also be nmigen values, e.g. for formal verification.
    """
    width = 2 ** log2_width
    # discard any bits outside the operand widths
    inval &= 2 ** width - 1
    chunk_sizes &= width - 1
    result = 0
    for src_index in range(width):
        # compare instead of branching so nmigen values keep working
        is_set = (inval & (1 << src_index)) != 0
        result |= is_set << (src_index ^ chunk_sizes)
    return result
104
105
class GRev(Elaboratable):
    """Generalized bit-reverse hardware block.

    The operation, and the butterfly-style mux network used to implement
    it, are described in this module's docstring.

    Attributes:
    log2_width: int
        base-2 logarithm of `width`; see `__init__`.
    msb_first: bool
        order in which the bits of `chunk_sizes` are applied; see `__init__`.
    width: int
        input/output width in bits, equal to `2 ** self.log2_width`.
    input: Signal with width=self.width
        the value to be bit-permuted.
    chunk_sizes: Signal with width=self.log2_width
        selects which bits get swapped: bit `k` swaps adjacent
        `2 ** k`-bit words. See the module docs for details.
    output: Signal with width=self.width
        the result, `grev(input, chunk_sizes)`.
    """

    def __init__(self, log2_width, msb_first=False):
        """Build a `GRev` whose input and output are `2 ** log2_width` bits.

        log2_width: int
            base-2 logarithm of the input/output width.
        msb_first: bool
            When False (the default), the mux steps consume `chunk_sizes`
            from LSB to MSB -- the standard order of swapping adjacent
            1-bit words, then 2-bit words, then 4-bit words, and so on.
            When True, the steps consume `chunk_sizes` from MSB to LSB,
            which reverses that order (e.g. 8-bit, 4-bit, 2-bit, then 1-bit
            words for a 16-bit GRev).
        """
        self.log2_width = log2_width
        self.msb_first = msb_first
        self.width = 1 << log2_width
        # NOTE(review): keep these as direct `self.xxx = Signal(...)`
        # assignments; nmigen appears to name unnamed Signals after their
        # assignment target -- confirm before restructuring.
        self.input = Signal(self.width)
        self.chunk_sizes = Signal(log2_width)
        self.output = Signal(self.width)

        # Internal signals, exposed only so unit tests can inspect them;
        # external users should ignore them. They are built here rather
        # than in elaborate() so that every member exists as soon as the
        # object is constructed.
        #
        # `_intermediates[k]` is the value after `k` mux steps. For example
        # (with `msb_first == False` and a 6-bit `chunk_sizes`),
        # `_intermediates[4]` holds `grev(inp,0b00wxyz)`.
        self._intermediates = list(map(self.__inter, range(log2_width + 1)))

    def _get_cs_bit_index(self, step_index):
        """Return which bit of `chunk_sizes` drives the muxes of step
        number `step_index`."""
        assert 0 <= step_index < self.log2_width
        if not self.msb_first:
            return step_index
        # Walk down from the MSB instead. After 4 steps (6-bit chunk_sizes)
        # the intermediate is then `0buvwx00` rather than `0b00wxyz`.
        return self.log2_width - step_index - 1

    def __inter(self, step_count):
        """Create the signal holding the value after `step_count` mux steps.

        It is named like `grev(inp,0b000xyz)` so that it matches the diagram
        in the module-level docs.
        """
        finished_bits = {self._get_cs_bit_index(step)
                         for step in range(step_count)}
        digits = []
        # binary literals are written MSB first
        for bit_num in reversed(range(self.log2_width)):
            if bit_num in finished_bits:
                # bit 0 is 'z', bit 1 is 'y', ... counting back from 'z'
                digits.append(string.ascii_lowercase[-1 - bit_num])
            else:
                digits.append('0')
        chunk_sizes_val = '0b' + ''.join(digits)
        # contains no spaces, so it is also valid as a Verilog escaped
        # identifier
        name = f"grev(inp,{chunk_sizes_val})"
        return Signal(self.width, name=name)

    def __get_permutation(self, step_index):
        """Return this step's bit permutation as a list where `perm[i] == j`
        means input bit `i` of the step goes to output bit `j`."""
        # Composing grev steps XORs their arguments. The previous step's
        # value is therefore `grev(inp,0b000xyz)`, and this step only needs
        # to apply the one new bit (`0b00w000`) to reach `grev(inp,0b00wxyz)`.
        flip = 1 << self._get_cs_bit_index(step_index)
        return [bit ^ flip for bit in range(self.width)]

    def _sigs_and_expected(self, inp, chunk_sizes):
        """Yield `(intermediate_signal, expected_value)` pairs, computing the
        expected values from `inp` and `chunk_sizes` with `grev`."""
        # bits of chunk_sizes that the steps taken so far have applied
        seen_mask = 0
        for step_count, intermediate in enumerate(self._intermediates):
            partial_chunk_sizes = chunk_sizes & seen_mask
            yield (intermediate,
                   grev(inp, partial_chunk_sizes, self.log2_width))
            # the final intermediate has no following step to account for
            if step_count < self.log2_width:
                seen_mask |= 1 << self._get_cs_bit_index(step_count)
        assert seen_mask == 2 ** self.log2_width - 1, \
            "should have got all the bits in chunk_sizes"

    def elaborate(self, platform):
        m = Module()
        stages = self._intermediates

        # before any mux step, the value is just the input
        m.d.comb += stages[0].eq(self.input)

        step_pairs = zip(stages[:-1], stages[1:])
        for step_index, (step_inp, step_out) in enumerate(step_pairs):
            permutation = self.__get_permutation(step_index)
            # the single chunk_sizes bit that controls this step
            sel = self.chunk_sizes[self._get_cs_bit_index(step_index)]
            for src, dst in enumerate(permutation):
                # When `sel` is set, output bit `dst` takes the swapped-in
                # input bit `src`; otherwise bit `dst` passes straight
                # through.
                chosen = Mux(sel, step_inp[src], step_inp[dst])
                m.d.comb += step_out[dst].eq(chosen)

        # after the last step, the value is the result
        m.d.comb += self.output.eq(stages[-1])

        if platform == 'formal':
            # check the whole network against the simpler Python model...
            expected_output = grev(self.input, self.chunk_sizes,
                                   self.log2_width)
            m.d.comb += Assert(self.output == expected_output)
            # ...and check every intermediate stage as well
            for value, expected in self._sigs_and_expected(self.input,
                                                           self.chunk_sizes):
                m.d.comb += Assert(value == expected)
        return m

    def ports(self):
        """Return the externally visible signals: input, chunk_sizes, output."""
        return [self.input, self.chunk_sizes, self.output]
260
261
# useful to see what is going on:
# python3 src/nmutil/test/test_grev.py
# yosys <<<"read_ilang sim_test_out/__main__.TestGrev.test_small/0.il; proc; clean -purge; show top"

if __name__ == '__main__':
    # write the RTLIL for an 8-bit (log2_width=3) GRev to grev3.il;
    # conversion runs first, so a failure leaves no partial file behind
    dut = GRev(3)
    il_text = rtlil.convert(dut, ports=dut.ports())
    with open("grev3.il", "w") as il_file:
        il_file.write(il_text)