redo grev
authorJacob Lifshay <programmerjake@gmail.com>
Thu, 23 Dec 2021 04:43:52 +0000 (20:43 -0800)
committerJacob Lifshay <programmerjake@gmail.com>
Thu, 23 Dec 2021 04:43:52 +0000 (20:43 -0800)
src/nmutil/grev.py
src/nmutil/test/test_grev.py

index 1ca501fb4a64e9a96884f1242a93dcc3c125ae9f..35b45657eb5895028a88d7927454fbe2398cc27a 100644 (file)
@@ -5,18 +5,85 @@
 # Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
 # of Horizon 2020 EU Programme 957073.
 
+r"""Generalized bit-reverse.
 
-"""Generalized bit-reverse. See `GRev` for docs. - no: move the
-module docstring here, to describe the Grev concept.
-* module docs tell you "about the concept and anything generally useful to know"
-* class docs are for "how to actually use the class".
+https://bugs.libre-soc.org/show_bug.cgi?id=755
+
+A generalized bit-reverse is the following operation:
+grev(input, chunk_sizes):
+    for i in range(input.width):
+        j = i XOR chunk_sizes
+        output bit i = input bit j
+    return output
+
+This is useful because many bit/byte reverse operations can be created by
+setting `chunk_sizes` to different values. Some examples for a 64-bit
+`grev` operation:
+* `0b111111` -- reverse all bits in the 64-bit word
+* `0b111000` -- reverse bytes in the 64-bit word
+* `0b011000` -- reverse bytes in each 32-bit word independently
+* `0b110000` -- reverse order of 16-bit words
+
+This is implemented by using a series of `log2_width`
+`width`-bit wide 2:1 muxes, arranged just like a butterfly network:
+https://en.wikipedia.org/wiki/Butterfly_network
+
+To compute `out = grev(inp, 0bxyz)`, where `x`, `y`, and `z` are single bits,
+the following permutation network is used:
+
+                inp[0]  inp[1]  inp[2]  inp[3]  inp[4]  inp[5]  inp[6]  inp[7]
+                  |       |       |       |       |       |       |       |
+the value here is |       |       |       |       |       |       |       |
+grev(inp, 0b000): |       |       |       |       |       |       |       |
+                  |       |       |       |       |       |       |       |
+                  +       +       +       +       +       +       +       +
+                  |\     /|       |\     /|       |\     /|       |\     /|
+                  | \   / |       | \   / |       | \   / |       | \   / |
+                  |  \ /  |       |  \ /  |       |  \ /  |       |  \ /  |
+swap 1-bit words: |   X   |       |   X   |       |   X   |       |   X   |
+                  |  / \  |       |  / \  |       |  / \  |       |  / \  |
+                  | /   \ |       | /   \ |       | /   \ |       | /   \ |
+              z--Mux  z--Mux  z--Mux  z--Mux  z--Mux  z--Mux  z--Mux  z--Mux
+                  |       |       |       |       |       |       |       |
+the value here is |       |       |       |       |       |       |       |
+grev(inp, 0b00z): |       |       |       |       |       |       |       |
+                  |       |       |       |       |       |       |       |
+                  |       | +-----|-------+       |       | +-----|-------+
+                  | +-----|-|-----+       |       | +-----|-|-----+       |
+                  | |     | |     |       |       | |     | |     |       |
+swap 2-bit words: | |     +-|-----|-----+ |       | |     +-|-----|-----+ |
+                  +-|-----|-|---+ |     | |       +-|-----|-|---+ |     | |
+                  | |     | |   | |     | |       | |     | |   | |     | |
+                  | /     | /   \ |     \ |       | /     | /   \ |     \ |
+              y--Mux  y--Mux  y--Mux  y--Mux  y--Mux  y--Mux  y--Mux  y--Mux
+                  |       |       |       |       |       |       |       |
+the value here is |       |       |       |       |       |       |       |
+grev(inp, 0b0yz): |       |       |       |       |       |       |       |
+                  |       |       |       |       |       |       |       |
+                  |       |       |       | +-----|-------|-------|-------+
+                  |       |       | +-----|-|-----|-------|-------+       |
+                  |       | +-----|-|-----|-|-----|-------+       |       |
+                  | +-----|-|-----|-|-----|-|-----+       |       |       |
+swap 4-bit words: | |     | |     | |     | |     |       |       |       |
+                  | |     | |     | |     +-|-----|-------|-------|-----+ |
+                  | |     | |     +-|-----|-|-----|-------|-----+ |     | |
+                  | |     +-|-----|-|-----|-|-----|-----+ |     | |     | |
+                  +-|-----|-|-----|-|-----|-|---+ |     | |     | |     | |
+                  | |     | |     | |     | |   | |     | |     | |     | |
+                  | /     | /     | /     | /   \ |     \ |     \ |     \ |
+              x--Mux  x--Mux  x--Mux  x--Mux  x--Mux  x--Mux  x--Mux  x--Mux
+                  |       |       |       |       |       |       |       |
+the value here is |       |       |       |       |       |       |       |
+grev(inp, 0bxyz): |       |       |       |       |       |       |       |
+                  |       |       |       |       |       |       |       |
+                out[0]  out[1]  out[2]  out[3]  out[4]  out[5]  out[6]  out[7]
 """
 
 from nmigen.hdl.ast import Signal, Mux, Cat
 from nmigen.hdl.ast import Assert
 from nmigen.hdl.dsl import Module
 from nmigen.hdl.ir import Elaboratable
-from nmigen.cli import rtlil
+import string
 
 
 def grev(inval, chunk_sizes, log2_width):
@@ -38,89 +105,141 @@ def grev(inval, chunk_sizes, log2_width):
 class GRev(Elaboratable):
     """Generalized bit-reverse.
 
-    https://bugs.libre-soc.org/show_bug.cgi?id=755
-
-    XXX this is documentation about Grev (the concept) which should be in
-    the docstring.  the class string is reserved for describing how to
-    *use* the class (describe its inputs and outputs)
-
-    A generalized bit-reverse - also known as a butterfly network - is where
-    every output bit is the input bit at index `output_bit_index XOR
-    chunk_sizes` where `chunk_sizes` is the control input.
-
-    This is useful because many bit/byte reverse operations can be created by
-    setting `chunk_sizes` to different values. Some examples for a 64-bit
-    `grev` operation:
-    * `0b111111` -- reverse all bits in the 64-bit word
-    * `0b111000` -- reverse bytes in the 64-bit word
-    * `0b011000` -- reverse bytes in each 32-bit word independently
-    * `0b110000` -- reverse order of 16-bit words
-
-    This is implemented by using a series of `log2_width` 2:1 muxes, exactly
-    as in a butterfly network: https://en.wikipedia.org/wiki/Butterfly_network
-
-    The 2:1 muxes are arranged to calculate successive `grev`-ed values where
-    each intermediate value's corresponding `chunk_sizes` is progressively
-    changed from all zeros to the input `chunk_sizes` by adding one bit at a
-    time from the LSB to MSB.  (XXX i don't understand this at all!)
+    See the module's documentation for a description of generalized
+    bit-reverse, as well as the permutation network created by this class.
 
-    :reverse_order: if True the butterfly steps are performed
-                    at offsets of 2^N ... 8 4 2.
-                    if False, the order is 2 4 8 ... 2^N
+    Attributes:
+    log2_width: int
+        see __init__'s docs.
+    msb_first: bool
+        see __init__'s docs.
+    width: int
+        the input/output width of the grev operation. The value is
+        `2 ** self.log2_width`.
+    input: Signal with width=self.width
+        the input value of the grev operation.
+    chunk_sizes: Signal with width=self.log2_width
+        the input that describes which bits get swapped. See the module docs
+        for additional details.
+    output: Signal with width=self.width
+        the output value of the grev operation.
     """
 
-    def __init__(self, log2_width, reverse_order=False):
-        self.reverse_order = reverse_order    # reverses the order of steps
+    def __init__(self, log2_width, msb_first=False):
+        """Create a `GRev` instance.
+
+        log2_width: int
+            the base-2 logarithm of the input/output width of the grev
+            operation.
+        msb_first: bool
+            If `msb_first` is True, then the order will be the reverse of the
+            standard order -- swapping adjacent 8-bit words, then 4-bit words,
+            then 2-bit words, then 1-bit words -- using the bits of
+            `chunk_sizes` from MSB to LSB.
+            If `msb_first` is False (the default), then the order will be the
+            standard order -- swapping adjacent 1-bit words, then 2-bit words,
+            then 4-bit words, then 8-bit words -- using the bits of
+            `chunk_sizes` from LSB to MSB.
+        """
         self.log2_width = log2_width
+        self.msb_first = msb_first
         self.width = 1 << log2_width
         self.input = Signal(self.width)
-        # XXX is this an input or output?
         self.chunk_sizes = Signal(log2_width)
         self.output = Signal(self.width)
 
+        # internal signals exposed for unit tests, should be ignored by
+        # external users. The signals are created in the constructor because
+        # that's where all class member variables should *always* be created.
+        # If we were to create the members in elaborate() instead, it would
+        # just make the class very confusing to use.
+        #
+        # `_intermediates[step_count]`` is the value after `step_count` steps
+        # of muxing. e.g. (for `msb_first == False`) `_intermediates[4]` is the
+        # result of 4 steps of muxing, being the value `grev(inp,0b00wxyz)`.
+        self._intermediates = [self.__inter(i) for i in range(log2_width + 1)]
+
+    def _get_cs_bit_index(self, step_index):
+        """get the index of the bit of `chunk_sizes` that this step should mux
+        based off of."""
+        assert 0 <= step_index < self.log2_width
+        if self.msb_first:
+            # reverse so we start from the MSB, producing intermediate values
+            # like, for `step_index == 4`, `0buvwx00` rather than `0b00wxyz`
+            return self.log2_width - step_index - 1
+        return step_index
+
+    def __inter(self, step_count):
+        """make a signal with a name like `grev(inp,0b000xyz)` to match the
+        diagram in the module-level docs."""
+        # make the list of bits in LSB to MSB order
+        chunk_sizes_bits = ['0'] * self.log2_width
+        # for all steps already completed
+        for step_index in range(step_count):
+            bit_num = self._get_cs_bit_index(step_index)
+            ch = string.ascii_lowercase[-1 - bit_num]  # count from z to a
+            chunk_sizes_bits[bit_num] = ch
+        # reverse cuz text is MSB first
+        chunk_sizes_val = '0b' + ''.join(reversed(chunk_sizes_bits))
+        # name works according to Verilog's rules for escaped identifiers cuz
+        # it has no spaces
+        name = f"grev(inp,{chunk_sizes_val})"
+        return Signal(self.width, name=name)
+
+    def __get_permutation(self, step_index):
+        """get the bit permutation for the current step. the returned value is
+        a list[int] where `retval[i] == j` means that this step's input bit `i`
+        goes to this step's output bit `j`."""
+        # we can extract just the latest bit for this step, since the previous
+        # step effectively has it's value's grev arg as `0b000xyz`, and this
+        # step has it's value's grev arg as `0b00wxyz`, so we only need to
+        # compute `grev(prev_step_output,0b00w000)` to get
+        # `grev(inp,0b00wxyz)`. `cur_chunk_sizes` is the `0b00w000`.
+        cur_chunk_sizes = 1 << self._get_cs_bit_index(step_index)
+        # compute bit permutation for `grev(...,0b00w000)`.
+        return [i ^ cur_chunk_sizes for i in range(self.width)]
+
+    def _sigs_and_expected(self, inp, chunk_sizes):
+        """the intermediate signals and the expected values, based off of the
+        passed-in `inp` and `chunk_sizes`."""
+        # we accumulate a mask of which chunk_sizes bits we have accounted for
+        # so far
+        chunk_sizes_mask = 0
+        for step_count, intermediate in enumerate(self._intermediates):
+            # mask out chunk_sizes to get the value
+            cur_chunk_sizes = chunk_sizes & chunk_sizes_mask
+            expected = grev(inp, cur_chunk_sizes, self.log2_width)
+            yield (intermediate, expected)
+            # if step_count is in-range for being a valid step_index
+            if step_count < self.log2_width:
+                # add current step's bit to the mask
+                chunk_sizes_mask |= 1 << self._get_cs_bit_index(step_count)
+        assert chunk_sizes_mask == 2 ** self.log2_width - 1, \
+            "should have got all the bits in chunk_sizes"
+
     def elaborate(self, platform):
         m = Module()
-        comb = m.d.comb
-
-        # accumulate list of internal signals, exposed only for unit testing.
-        # contains the input, intermediary steps, and the output.
-        self._steps = [self.input]
-
-        # TODO: no. "see class doc comment for algorithm docs." <-- document
-        #           *in* the code, not "see another location elsewhere"
-        #           (unless it is a repeated text/concept of course, like
-        #            with BitwiseLut, and that's because the API is identical)
-        #           "see elsewhere" entirely defeats the object of the exercise.
-        #           jumping back and forth (page-up, page-down)
-        #           between the text and the code splits attention.
-        #           the purpose of comments is to be able to understand
-        #           (in plain english) the code *at* the point of seeing it
-        #           it should contain "the thoughts going through your head"
-        #
-        #           demonstrated below (with a rewrite)
-
-        step_i = self.input  # start with input as the first step
-
-        # create (reversed?) list of steps
-        steps = list(range(self.log2_width))
-        if self.reverse_order:
-            steps.reverse()
-
-        for i in steps:
-            # each chunk is a power-2 jump.
-            chunk_size = 1 << i
-            # prepare a list of XOR-swapped bits of this layer/step
-            butterfly = [step_i[j ^ chunk_size] for j in range(self.width)]
-            # create muxes here: 1 bit of chunk_sizes decides swap/no-swap
-            step_o = Signal(self.width, name="step%d" % chunk_size)
-            comb += step_o.eq(Mux(self.chunk_sizes[i],
-                                  Cat(*butterfly), step_i))
-            # output becomes input to next layer
-            step_i = step_o
-            self._steps.append(step_o)  # record steps for test purposes (only)
-
-        # last layer is also the output
-        comb += self.output.eq(step_o)
+
+        # value after zero steps is just the input
+        m.d.comb += self._intermediates[0].eq(self.input)
+
+        for step_index in range(self.log2_width):
+            step_inp = self._intermediates[step_index]
+            step_out = self._intermediates[step_index + 1]
+            # get permutation for current step
+            permutation = self.__get_permutation(step_index)
+            # figure out which `chunk_sizes` bit we want to pay attention to
+            # for this step.
+            sel = self.chunk_sizes[self._get_cs_bit_index(step_index)]
+            for in_index, out_index in enumerate(permutation):
+                # use in_index so we get the permuted bit
+                permuted_bit = step_inp[in_index]
+                # use out_index so we copy the bit straight thru
+                straight_bit = step_inp[out_index]
+                bit = Mux(sel, permuted_bit, straight_bit)
+                m.d.comb += step_out[out_index].eq(bit)
+        # value after all steps is just the output
+        m.d.comb += self.output.eq(self._intermediates[-1])
 
         if platform != 'formal':
             return m
@@ -129,20 +248,16 @@ class GRev(Elaboratable):
         m.d.comb += Assert(self.output == grev(self.input,
                                                self.chunk_sizes,
                                                self.log2_width))
-        for i, step in enumerate(self._steps):
-            cur_chunk_sizes = self.chunk_sizes & (2 ** i - 1)
-            step_expected = grev(self.input, cur_chunk_sizes, self.log2_width)
-            m.d.comb += Assert(step == step_expected)
 
+        for value, expected in self._sigs_and_expected(self.input,
+                                                       self.chunk_sizes):
+            m.d.comb += Assert(value == expected)
         return m
 
     def ports(self):
         return [self.input, self.chunk_sizes, self.output]
 
 
-# useful to see what is going on: use yosys "read_ilang test_grev.il; show top"
-if __name__ == '__main__':
-    dut = GRev(3)
-    vl = rtlil.convert(dut, ports=dut.ports())
-    with open("test_grev.il", "w") as f:
-        f.write(vl)
+# useful to see what is going on:
+# python3 src/nmutil/test/test_grev.py
+# yosys <<<"read_ilang sim_test_out/__main__.TestGrev.test_small/0.il; proc; clean -purge; show top"
index 15e71489a738b601646cdb6bc2683550e7e58f13..780239d8a13b2954a7953d5d2e312dd517a80347 100644 (file)
@@ -14,11 +14,11 @@ from nmutil.sim_util import do_sim, hash_256
 
 
 class TestGrev(FHDLTestCase):
-    def test(self):
-        log2_width = 6
+    def tst(self, msb_first, log2_width=6):
         width = 2 ** log2_width
-        dut = GRev(log2_width)
+        dut = GRev(log2_width, msb_first)
         self.assertEqual(width, dut.width)
+        self.assertEqual(len(dut._intermediates), log2_width + 1)
 
         def case(inval, chunk_sizes):
             expected = grev(inval, chunk_sizes, log2_width)
@@ -30,17 +30,14 @@ class TestGrev(FHDLTestCase):
                 output = yield dut.output
                 with self.subTest(output=hex(output)):
                     self.assertEqual(expected, output)
-                for i, step in enumerate(dut._steps):
-                    cur_chunk_sizes = chunk_sizes & (2 ** i - 1)
-                    step_expected = grev(inval, cur_chunk_sizes, log2_width)
-                    step = yield step
-                    with self.subTest(i=i, step=hex(step),
-                                      cur_chunk_sizes=bin(cur_chunk_sizes),
-                                      step_expected=hex(step_expected)):
-                        self.assertEqual(step, step_expected)
+                for sig, expected in dut._sigs_and_expected(inval,
+                                                            chunk_sizes):
+                    value = yield sig
+                    with self.subTest(sig=sig.name, value=hex(value),
+                                      expected=hex(expected)):
+                        self.assertEqual(value, expected)
 
         def process():
-            self.assertEqual(len(dut._steps), log2_width + 1)
             for count in range(width + 1):
                 inval = (1 << count) - 1
                 for chunk_sizes in range(2 ** log2_width):
@@ -56,9 +53,21 @@ class TestGrev(FHDLTestCase):
             sim.add_process(process)
             sim.run()
 
-    def test_formal(self):
+    def test(self):
+        self.tst(msb_first=False)
+
+    def test_msb_first(self):
+        self.tst(msb_first=True)
+
+    def test_small(self):
+        self.tst(msb_first=False, log2_width=3)
+
+    def test_small_msb_first(self):
+        self.tst(msb_first=True, log2_width=3)
+
+    def tst_formal(self, msb_first):
         log2_width = 4
-        dut = GRev(log2_width)
+        dut = GRev(log2_width, msb_first)
         m = Module()
         m.submodules.dut = dut
         m.d.comb += dut.input.eq(AnyConst(2 ** log2_width))
@@ -66,6 +75,12 @@ class TestGrev(FHDLTestCase):
         # actual formal correctness proof is inside the module itself, now
         self.assertFormal(m)
 
+    def test_formal(self):
+        self.tst_formal(msb_first=False)
+
+    def test_formal_msb_first(self):
+        self.tst_formal(msb_first=True)
+
 
 if __name__ == "__main__":
     unittest.main()