From: Jacob Lifshay Date: Thu, 23 Dec 2021 04:43:52 +0000 (-0800) Subject: redo grev X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=24fb96b47682cf40d3f5abd9103067122d0641ea;p=nmutil.git redo grev --- diff --git a/src/nmutil/grev.py b/src/nmutil/grev.py index 1ca501f..35b4565 100644 --- a/src/nmutil/grev.py +++ b/src/nmutil/grev.py @@ -5,18 +5,85 @@ # Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part # of Horizon 2020 EU Programme 957073. +r"""Generalized bit-reverse. -"""Generalized bit-reverse. See `GRev` for docs. - no: move the -module docstring here, to describe the Grev concept. -* module docs tell you "about the concept and anything generally useful to know" -* class docs are for "how to actually use the class". +https://bugs.libre-soc.org/show_bug.cgi?id=755 + +A generalized bit-reverse is the following operation: +grev(input, chunk_sizes): + for i in range(input.width): + j = i XOR chunk_sizes + output bit i = input bit j + return output + +This is useful because many bit/byte reverse operations can be created by +setting `chunk_sizes` to different values. 
Some examples for a 64-bit +`grev` operation: +* `0b111111` -- reverse all bits in the 64-bit word +* `0b111000` -- reverse bytes in the 64-bit word +* `0b011000` -- reverse bytes in each 32-bit word independently +* `0b110000` -- reverse order of 16-bit words + +This is implemented by using a series of `log2_width` +`width`-bit wide 2:1 muxes, arranged just like a butterfly network: +https://en.wikipedia.org/wiki/Butterfly_network + +To compute `out = grev(inp, 0bxyz)`, where `x`, `y`, and `z` are single bits, +the following permutation network is used: + + inp[0] inp[1] inp[2] inp[3] inp[4] inp[5] inp[6] inp[7] + | | | | | | | | +the value here is | | | | | | | | +grev(inp, 0b000): | | | | | | | | + | | | | | | | | + + + + + + + + + + |\ /| |\ /| |\ /| |\ /| + | \ / | | \ / | | \ / | | \ / | + | \ / | | \ / | | \ / | | \ / | +swap 1-bit words: | X | | X | | X | | X | + | / \ | | / \ | | / \ | | / \ | + | / \ | | / \ | | / \ | | / \ | + z--Mux z--Mux z--Mux z--Mux z--Mux z--Mux z--Mux z--Mux + | | | | | | | | +the value here is | | | | | | | | +grev(inp, 0b00z): | | | | | | | | + | | | | | | | | + | | +-----|-------+ | | +-----|-------+ + | +-----|-|-----+ | | +-----|-|-----+ | + | | | | | | | | | | | | +swap 2-bit words: | | +-|-----|-----+ | | | +-|-----|-----+ | + +-|-----|-|---+ | | | +-|-----|-|---+ | | | + | | | | | | | | | | | | | | | | + | / | / \ | \ | | / | / \ | \ | + y--Mux y--Mux y--Mux y--Mux y--Mux y--Mux y--Mux y--Mux + | | | | | | | | +the value here is | | | | | | | | +grev(inp, 0b0yz): | | | | | | | | + | | | | | | | | + | | | | +-----|-------|-------|-------+ + | | | +-----|-|-----|-------|-------+ | + | | +-----|-|-----|-|-----|-------+ | | + | +-----|-|-----|-|-----|-|-----+ | | | +swap 4-bit words: | | | | | | | | | | | | + | | | | | | +-|-----|-------|-------|-----+ | + | | | | +-|-----|-|-----|-------|-----+ | | | + | | +-|-----|-|-----|-|-----|-----+ | | | | | + +-|-----|-|-----|-|-----|-|---+ | | | | | | | + | | | | | | | | | | | | | | | 
| + | / | / | / | / \ | \ | \ | \ | + x--Mux x--Mux x--Mux x--Mux x--Mux x--Mux x--Mux x--Mux + | | | | | | | | +the value here is | | | | | | | | +grev(inp, 0bxyz): | | | | | | | | + | | | | | | | | + out[0] out[1] out[2] out[3] out[4] out[5] out[6] out[7] """ from nmigen.hdl.ast import Signal, Mux, Cat from nmigen.hdl.ast import Assert from nmigen.hdl.dsl import Module from nmigen.hdl.ir import Elaboratable -from nmigen.cli import rtlil +import string def grev(inval, chunk_sizes, log2_width): @@ -38,89 +105,141 @@ def grev(inval, chunk_sizes, log2_width): class GRev(Elaboratable): """Generalized bit-reverse. - https://bugs.libre-soc.org/show_bug.cgi?id=755 - - XXX this is documentation about Grev (the concept) which should be in - the docstring. the class string is reserved for describing how to - *use* the class (describe its inputs and outputs) - - A generalized bit-reverse - also known as a butterfly network - is where - every output bit is the input bit at index `output_bit_index XOR - chunk_sizes` where `chunk_sizes` is the control input. - - This is useful because many bit/byte reverse operations can be created by - setting `chunk_sizes` to different values. Some examples for a 64-bit - `grev` operation: - * `0b111111` -- reverse all bits in the 64-bit word - * `0b111000` -- reverse bytes in the 64-bit word - * `0b011000` -- reverse bytes in each 32-bit word independently - * `0b110000` -- reverse order of 16-bit words - - This is implemented by using a series of `log2_width` 2:1 muxes, exactly - as in a butterfly network: https://en.wikipedia.org/wiki/Butterfly_network - - The 2:1 muxes are arranged to calculate successive `grev`-ed values where - each intermediate value's corresponding `chunk_sizes` is progressively - changed from all zeros to the input `chunk_sizes` by adding one bit at a - time from the LSB to MSB. (XXX i don't understand this at all!) 
+ See the module's documentation for a description of generalized + bit-reverse, as well as the permutation network created by this class. - :reverse_order: if True the butterfly steps are performed - at offsets of 2^N ... 8 4 2. - if False, the order is 2 4 8 ... 2^N + Attributes: + log2_width: int + see __init__'s docs. + msb_first: bool + see __init__'s docs. + width: int + the input/output width of the grev operation. The value is + `2 ** self.log2_width`. + input: Signal with width=self.width + the input value of the grev operation. + chunk_sizes: Signal with width=self.log2_width + the input that describes which bits get swapped. See the module docs + for additional details. + output: Signal with width=self.width + the output value of the grev operation. """ - def __init__(self, log2_width, reverse_order=False): - self.reverse_order = reverse_order # reverses the order of steps + def __init__(self, log2_width, msb_first=False): + """Create a `GRev` instance. + + log2_width: int + the base-2 logarithm of the input/output width of the grev + operation. + msb_first: bool + If `msb_first` is True, then the order will be the reverse of the + standard order -- swapping adjacent 8-bit words, then 4-bit words, + then 2-bit words, then 1-bit words -- using the bits of + `chunk_sizes` from MSB to LSB. + If `msb_first` is False (the default), then the order will be the + standard order -- swapping adjacent 1-bit words, then 2-bit words, + then 4-bit words, then 8-bit words -- using the bits of + `chunk_sizes` from LSB to MSB. + """ self.log2_width = log2_width + self.msb_first = msb_first self.width = 1 << log2_width self.input = Signal(self.width) - # XXX is this an input or output? self.chunk_sizes = Signal(log2_width) self.output = Signal(self.width) + # internal signals exposed for unit tests, should be ignored by + # external users. The signals are created in the constructor because + # that's where all class member variables should *always* be created. 
+ # If we were to create the members in elaborate() instead, it would + # just make the class very confusing to use. + # + # `_intermediates[step_count]` is the value after `step_count` steps + # of muxing. e.g. (for `msb_first == False`) `_intermediates[4]` is the + # result of 4 steps of muxing, being the value `grev(inp,0b00wxyz)`. + self._intermediates = [self.__inter(i) for i in range(log2_width + 1)] + + def _get_cs_bit_index(self, step_index): + """get the index of the bit of `chunk_sizes` that this step should mux + based off of.""" + assert 0 <= step_index < self.log2_width + if self.msb_first: + # reverse so we start from the MSB, producing intermediate values + # like, for `step_index == 4`, `0buvwx00` rather than `0b00wxyz` + return self.log2_width - step_index - 1 + return step_index + + def __inter(self, step_count): + """make a signal with a name like `grev(inp,0b000xyz)` to match the + diagram in the module-level docs.""" + # make the list of bits in LSB to MSB order + chunk_sizes_bits = ['0'] * self.log2_width + # for all steps already completed + for step_index in range(step_count): + bit_num = self._get_cs_bit_index(step_index) + ch = string.ascii_lowercase[-1 - bit_num] # count from z to a + chunk_sizes_bits[bit_num] = ch + # reverse cuz text is MSB first + chunk_sizes_val = '0b' + ''.join(reversed(chunk_sizes_bits)) + # name works according to Verilog's rules for escaped identifiers cuz + # it has no spaces + name = f"grev(inp,{chunk_sizes_val})" + return Signal(self.width, name=name) + + def __get_permutation(self, step_index): + """get the bit permutation for the current step. 
the returned value is + a list[int] where `retval[i] == j` means that this step's input bit `i` + goes to this step's output bit `j`.""" + # we can extract just the latest bit for this step, since the previous + # step effectively has its value's grev arg as `0b000xyz`, and this + # step has its value's grev arg as `0b00wxyz`, so we only need to + # compute `grev(prev_step_output,0b00w000)` to get + # `grev(inp,0b00wxyz)`. `cur_chunk_sizes` is the `0b00w000`. + cur_chunk_sizes = 1 << self._get_cs_bit_index(step_index) + # compute bit permutation for `grev(...,0b00w000)`. + return [i ^ cur_chunk_sizes for i in range(self.width)] + + def _sigs_and_expected(self, inp, chunk_sizes): + """the intermediate signals and the expected values, based off of the + passed-in `inp` and `chunk_sizes`.""" + # we accumulate a mask of which chunk_sizes bits we have accounted for + # so far + chunk_sizes_mask = 0 + for step_count, intermediate in enumerate(self._intermediates): + # mask out chunk_sizes to get the value + cur_chunk_sizes = chunk_sizes & chunk_sizes_mask + expected = grev(inp, cur_chunk_sizes, self.log2_width) + yield (intermediate, expected) + # if step_count is in-range for being a valid step_index + if step_count < self.log2_width: + # add current step's bit to the mask + chunk_sizes_mask |= 1 << self._get_cs_bit_index(step_count) + assert chunk_sizes_mask == 2 ** self.log2_width - 1, \ + "should have got all the bits in chunk_sizes" + def elaborate(self, platform): m = Module() - comb = m.d.comb - - # accumulate list of internal signals, exposed only for unit testing. - # contains the input, intermediary steps, and the output. - self._steps = [self.input] - - # TODO: no. "see class doc comment for algorithm docs." <-- document - # *in* the code, not "see another location elsewhere" - # (unless it is a repeated text/concept of course, like - # with BitwiseLut, and that's because the API is identical) - # "see elsewhere" entirely defeats the object of the exercise. 
- # jumping back and forth (page-up, page-down) - # between the text and the code splits attention. - # the purpose of comments is to be able to understand - # (in plain english) the code *at* the point of seeing it - # it should contain "the thoughts going through your head" - # - # demonstrated below (with a rewrite) - - step_i = self.input # start with input as the first step - - # create (reversed?) list of steps - steps = list(range(self.log2_width)) - if self.reverse_order: - steps.reverse() - - for i in steps: - # each chunk is a power-2 jump. - chunk_size = 1 << i - # prepare a list of XOR-swapped bits of this layer/step - butterfly = [step_i[j ^ chunk_size] for j in range(self.width)] - # create muxes here: 1 bit of chunk_sizes decides swap/no-swap - step_o = Signal(self.width, name="step%d" % chunk_size) - comb += step_o.eq(Mux(self.chunk_sizes[i], - Cat(*butterfly), step_i)) - # output becomes input to next layer - step_i = step_o - self._steps.append(step_o) # record steps for test purposes (only) - - # last layer is also the output - comb += self.output.eq(step_o) + + # value after zero steps is just the input + m.d.comb += self._intermediates[0].eq(self.input) + + for step_index in range(self.log2_width): + step_inp = self._intermediates[step_index] + step_out = self._intermediates[step_index + 1] + # get permutation for current step + permutation = self.__get_permutation(step_index) + # figure out which `chunk_sizes` bit we want to pay attention to + # for this step. 
+ sel = self.chunk_sizes[self._get_cs_bit_index(step_index)] + for in_index, out_index in enumerate(permutation): + # use in_index so we get the permuted bit + permuted_bit = step_inp[in_index] + # use out_index so we copy the bit straight thru + straight_bit = step_inp[out_index] + bit = Mux(sel, permuted_bit, straight_bit) + m.d.comb += step_out[out_index].eq(bit) + # value after all steps is just the output + m.d.comb += self.output.eq(self._intermediates[-1]) if platform != 'formal': return m @@ -129,20 +248,16 @@ class GRev(Elaboratable): m.d.comb += Assert(self.output == grev(self.input, self.chunk_sizes, self.log2_width)) - for i, step in enumerate(self._steps): - cur_chunk_sizes = self.chunk_sizes & (2 ** i - 1) - step_expected = grev(self.input, cur_chunk_sizes, self.log2_width) - m.d.comb += Assert(step == step_expected) + for value, expected in self._sigs_and_expected(self.input, + self.chunk_sizes): + m.d.comb += Assert(value == expected) return m def ports(self): return [self.input, self.chunk_sizes, self.output] -# useful to see what is going on: use yosys "read_ilang test_grev.il; show top" -if __name__ == '__main__': - dut = GRev(3) - vl = rtlil.convert(dut, ports=dut.ports()) - with open("test_grev.il", "w") as f: - f.write(vl) +# useful to see what is going on: +# python3 src/nmutil/test/test_grev.py +# yosys <<<"read_ilang sim_test_out/__main__.TestGrev.test_small/0.il; proc; clean -purge; show top" diff --git a/src/nmutil/test/test_grev.py b/src/nmutil/test/test_grev.py index 15e7148..780239d 100644 --- a/src/nmutil/test/test_grev.py +++ b/src/nmutil/test/test_grev.py @@ -14,11 +14,11 @@ from nmutil.sim_util import do_sim, hash_256 class TestGrev(FHDLTestCase): - def test(self): - log2_width = 6 + def tst(self, msb_first, log2_width=6): width = 2 ** log2_width - dut = GRev(log2_width) + dut = GRev(log2_width, msb_first) self.assertEqual(width, dut.width) + self.assertEqual(len(dut._intermediates), log2_width + 1) def case(inval, chunk_sizes): 
expected = grev(inval, chunk_sizes, log2_width) @@ -30,17 +30,14 @@ class TestGrev(FHDLTestCase): output = yield dut.output with self.subTest(output=hex(output)): self.assertEqual(expected, output) - for i, step in enumerate(dut._steps): - cur_chunk_sizes = chunk_sizes & (2 ** i - 1) - step_expected = grev(inval, cur_chunk_sizes, log2_width) - step = yield step - with self.subTest(i=i, step=hex(step), - cur_chunk_sizes=bin(cur_chunk_sizes), - step_expected=hex(step_expected)): - self.assertEqual(step, step_expected) + for sig, expected in dut._sigs_and_expected(inval, + chunk_sizes): + value = yield sig + with self.subTest(sig=sig.name, value=hex(value), + expected=hex(expected)): + self.assertEqual(value, expected) def process(): - self.assertEqual(len(dut._steps), log2_width + 1) for count in range(width + 1): inval = (1 << count) - 1 for chunk_sizes in range(2 ** log2_width): @@ -56,9 +53,21 @@ class TestGrev(FHDLTestCase): sim.add_process(process) sim.run() - def test_formal(self): + def test(self): + self.tst(msb_first=False) + + def test_msb_first(self): + self.tst(msb_first=True) + + def test_small(self): + self.tst(msb_first=False, log2_width=3) + + def test_small_msb_first(self): + self.tst(msb_first=True, log2_width=3) + + def tst_formal(self, msb_first): log2_width = 4 - dut = GRev(log2_width) + dut = GRev(log2_width, msb_first) m = Module() m.submodules.dut = dut m.d.comb += dut.input.eq(AnyConst(2 ** log2_width)) @@ -66,6 +75,12 @@ class TestGrev(FHDLTestCase): # actual formal correctness proof is inside the module itself, now self.assertFormal(m) + def test_formal(self): + self.tst_formal(msb_first=False) + + def test_formal_msb_first(self): + self.tst_formal(msb_first=True) + if __name__ == "__main__": unittest.main()