redo grev
author Jacob Lifshay <programmerjake@gmail.com>
Thu, 23 Dec 2021 04:43:52 +0000 (20:43 -0800)
committer Jacob Lifshay <programmerjake@gmail.com>
Thu, 23 Dec 2021 04:43:52 +0000 (20:43 -0800)
src/nmutil/grev.py
src/nmutil/test/test_grev.py

index 1ca501fb4a64e9a96884f1242a93dcc3c125ae9f..35b45657eb5895028a88d7927454fbe2398cc27a 100644 (file)
@@ -5,18 +5,85 @@
 # Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
 # of Horizon 2020 EU Programme 957073.
 
 # Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
 # of Horizon 2020 EU Programme 957073.
 
+r"""Generalized bit-reverse.
 
 
-"""Generalized bit-reverse. See `GRev` for docs. - no: move the
-module docstring here, to describe the Grev concept.
-* module docs tell you "about the concept and anything generally useful to know"
-* class docs are for "how to actually use the class".
+https://bugs.libre-soc.org/show_bug.cgi?id=755
+
+A generalized bit-reverse is the following operation:
+grev(input, chunk_sizes):
+    for i in range(input.width):
+        j = i XOR chunk_sizes
+        output bit i = input bit j
+    return output
+
+This is useful because many bit/byte reverse operations can be created by
+setting `chunk_sizes` to different values. Some examples for a 64-bit
+`grev` operation:
+* `0b111111` -- reverse all bits in the 64-bit word
+* `0b111000` -- reverse bytes in the 64-bit word
+* `0b011000` -- reverse bytes in each 32-bit word independently
+* `0b110000` -- reverse order of 16-bit words
+
+This is implemented by using a series of `log2_width`
+`width`-bit wide 2:1 muxes, arranged just like a butterfly network:
+https://en.wikipedia.org/wiki/Butterfly_network
+
+To compute `out = grev(inp, 0bxyz)`, where `x`, `y`, and `z` are single bits,
+the following permutation network is used:
+
+                inp[0]  inp[1]  inp[2]  inp[3]  inp[4]  inp[5]  inp[6]  inp[7]
+                  |       |       |       |       |       |       |       |
+the value here is |       |       |       |       |       |       |       |
+grev(inp, 0b000): |       |       |       |       |       |       |       |
+                  |       |       |       |       |       |       |       |
+                  +       +       +       +       +       +       +       +
+                  |\     /|       |\     /|       |\     /|       |\     /|
+                  | \   / |       | \   / |       | \   / |       | \   / |
+                  |  \ /  |       |  \ /  |       |  \ /  |       |  \ /  |
+swap 1-bit words: |   X   |       |   X   |       |   X   |       |   X   |
+                  |  / \  |       |  / \  |       |  / \  |       |  / \  |
+                  | /   \ |       | /   \ |       | /   \ |       | /   \ |
+              z--Mux  z--Mux  z--Mux  z--Mux  z--Mux  z--Mux  z--Mux  z--Mux
+                  |       |       |       |       |       |       |       |
+the value here is |       |       |       |       |       |       |       |
+grev(inp, 0b00z): |       |       |       |       |       |       |       |
+                  |       |       |       |       |       |       |       |
+                  |       | +-----|-------+       |       | +-----|-------+
+                  | +-----|-|-----+       |       | +-----|-|-----+       |
+                  | |     | |     |       |       | |     | |     |       |
+swap 2-bit words: | |     +-|-----|-----+ |       | |     +-|-----|-----+ |
+                  +-|-----|-|---+ |     | |       +-|-----|-|---+ |     | |
+                  | |     | |   | |     | |       | |     | |   | |     | |
+                  | /     | /   \ |     \ |       | /     | /   \ |     \ |
+              y--Mux  y--Mux  y--Mux  y--Mux  y--Mux  y--Mux  y--Mux  y--Mux
+                  |       |       |       |       |       |       |       |
+the value here is |       |       |       |       |       |       |       |
+grev(inp, 0b0yz): |       |       |       |       |       |       |       |
+                  |       |       |       |       |       |       |       |
+                  |       |       |       | +-----|-------|-------|-------+
+                  |       |       | +-----|-|-----|-------|-------+       |
+                  |       | +-----|-|-----|-|-----|-------+       |       |
+                  | +-----|-|-----|-|-----|-|-----+       |       |       |
+swap 4-bit words: | |     | |     | |     | |     |       |       |       |
+                  | |     | |     | |     +-|-----|-------|-------|-----+ |
+                  | |     | |     +-|-----|-|-----|-------|-----+ |     | |
+                  | |     +-|-----|-|-----|-|-----|-----+ |     | |     | |
+                  +-|-----|-|-----|-|-----|-|---+ |     | |     | |     | |
+                  | |     | |     | |     | |   | |     | |     | |     | |
+                  | /     | /     | /     | /   \ |     \ |     \ |     \ |
+              x--Mux  x--Mux  x--Mux  x--Mux  x--Mux  x--Mux  x--Mux  x--Mux
+                  |       |       |       |       |       |       |       |
+the value here is |       |       |       |       |       |       |       |
+grev(inp, 0bxyz): |       |       |       |       |       |       |       |
+                  |       |       |       |       |       |       |       |
+                out[0]  out[1]  out[2]  out[3]  out[4]  out[5]  out[6]  out[7]
 """
 
 from nmigen.hdl.ast import Signal, Mux, Cat
 from nmigen.hdl.ast import Assert
 from nmigen.hdl.dsl import Module
 from nmigen.hdl.ir import Elaboratable
 """
 
 from nmigen.hdl.ast import Signal, Mux, Cat
 from nmigen.hdl.ast import Assert
 from nmigen.hdl.dsl import Module
 from nmigen.hdl.ir import Elaboratable
-from nmigen.cli import rtlil
+import string
 
 
 def grev(inval, chunk_sizes, log2_width):
 
 
 def grev(inval, chunk_sizes, log2_width):
@@ -38,89 +105,141 @@ def grev(inval, chunk_sizes, log2_width):
 class GRev(Elaboratable):
     """Generalized bit-reverse.
 
 class GRev(Elaboratable):
     """Generalized bit-reverse.
 
-    https://bugs.libre-soc.org/show_bug.cgi?id=755
-
-    XXX this is documentation about Grev (the concept) which should be in
-    the docstring.  the class string is reserved for describing how to
-    *use* the class (describe its inputs and outputs)
-
-    A generalized bit-reverse - also known as a butterfly network - is where
-    every output bit is the input bit at index `output_bit_index XOR
-    chunk_sizes` where `chunk_sizes` is the control input.
-
-    This is useful because many bit/byte reverse operations can be created by
-    setting `chunk_sizes` to different values. Some examples for a 64-bit
-    `grev` operation:
-    * `0b111111` -- reverse all bits in the 64-bit word
-    * `0b111000` -- reverse bytes in the 64-bit word
-    * `0b011000` -- reverse bytes in each 32-bit word independently
-    * `0b110000` -- reverse order of 16-bit words
-
-    This is implemented by using a series of `log2_width` 2:1 muxes, exactly
-    as in a butterfly network: https://en.wikipedia.org/wiki/Butterfly_network
-
-    The 2:1 muxes are arranged to calculate successive `grev`-ed values where
-    each intermediate value's corresponding `chunk_sizes` is progressively
-    changed from all zeros to the input `chunk_sizes` by adding one bit at a
-    time from the LSB to MSB.  (XXX i don't understand this at all!)
+    See the module's documentation for a description of generalized
+    bit-reverse, as well as the permutation network created by this class.
 
 
-    :reverse_order: if True the butterfly steps are performed
-                    at offsets of 2^N ... 8 4 2.
-                    if False, the order is 2 4 8 ... 2^N
+    Attributes:
+    log2_width: int
+        see __init__'s docs.
+    msb_first: bool
+        see __init__'s docs.
+    width: int
+        the input/output width of the grev operation. The value is
+        `2 ** self.log2_width`.
+    input: Signal with width=self.width
+        the input value of the grev operation.
+    chunk_sizes: Signal with width=self.log2_width
+        the input that describes which bits get swapped. See the module docs
+        for additional details.
+    output: Signal with width=self.width
+        the output value of the grev operation.
     """
 
     """
 
-    def __init__(self, log2_width, reverse_order=False):
-        self.reverse_order = reverse_order    # reverses the order of steps
+    def __init__(self, log2_width, msb_first=False):
+        """Create a `GRev` instance.
+
+        log2_width: int
+            the base-2 logarithm of the input/output width of the grev
+            operation.
+        msb_first: bool
+            If `msb_first` is True, then the order will be the reverse of the
+            standard order -- swapping adjacent 8-bit words, then 4-bit words,
+            then 2-bit words, then 1-bit words -- using the bits of
+            `chunk_sizes` from MSB to LSB.
+            If `msb_first` is False (the default), then the order will be the
+            standard order -- swapping adjacent 1-bit words, then 2-bit words,
+            then 4-bit words, then 8-bit words -- using the bits of
+            `chunk_sizes` from LSB to MSB.
+        """
         self.log2_width = log2_width
         self.log2_width = log2_width
+        self.msb_first = msb_first
         self.width = 1 << log2_width
         self.input = Signal(self.width)
         self.width = 1 << log2_width
         self.input = Signal(self.width)
-        # XXX is this an input or output?
         self.chunk_sizes = Signal(log2_width)
         self.output = Signal(self.width)
 
         self.chunk_sizes = Signal(log2_width)
         self.output = Signal(self.width)
 
+        # internal signals exposed for unit tests, should be ignored by
+        # external users. The signals are created in the constructor because
+        # that's where all class member variables should *always* be created.
+        # If we were to create the members in elaborate() instead, it would
+        # just make the class very confusing to use.
+        #
+        # `_intermediates[step_count]` is the value after `step_count` steps
+        # of muxing. e.g. (for `msb_first == False`) `_intermediates[4]` is the
+        # result of 4 steps of muxing, being the value `grev(inp,0b00wxyz)`.
+        self._intermediates = [self.__inter(i) for i in range(log2_width + 1)]
+
+    def _get_cs_bit_index(self, step_index):
+        """get the index of the bit of `chunk_sizes` that this step should mux
+        based off of."""
+        assert 0 <= step_index < self.log2_width
+        if self.msb_first:
+            # reverse so we start from the MSB, producing intermediate values
+            # like, for `step_index == 4`, `0buvwx00` rather than `0b00wxyz`
+            return self.log2_width - step_index - 1
+        return step_index
+
+    def __inter(self, step_count):
+        """make a signal with a name like `grev(inp,0b000xyz)` to match the
+        diagram in the module-level docs."""
+        # make the list of bits in LSB to MSB order
+        chunk_sizes_bits = ['0'] * self.log2_width
+        # for all steps already completed
+        for step_index in range(step_count):
+            bit_num = self._get_cs_bit_index(step_index)
+            ch = string.ascii_lowercase[-1 - bit_num]  # count from z to a
+            chunk_sizes_bits[bit_num] = ch
+        # reverse because text is MSB first
+        chunk_sizes_val = '0b' + ''.join(reversed(chunk_sizes_bits))
+        # name works according to Verilog's rules for escaped identifiers
+        # because it has no spaces
+        name = f"grev(inp,{chunk_sizes_val})"
+        return Signal(self.width, name=name)
+
+    def __get_permutation(self, step_index):
+        """get the bit permutation for the current step. the returned value is
+        a list[int] where `retval[i] == j` means that this step's input bit `i`
+        goes to this step's output bit `j`."""
+        # we can extract just the latest bit for this step, since the previous
+        # step effectively has its value's grev arg as `0b000xyz`, and this
+        # step has its value's grev arg as `0b00wxyz`, so we only need to
+        # compute `grev(prev_step_output,0b00w000)` to get
+        # `grev(inp,0b00wxyz)`. `cur_chunk_sizes` is the `0b00w000`.
+        cur_chunk_sizes = 1 << self._get_cs_bit_index(step_index)
+        # compute bit permutation for `grev(...,0b00w000)`.
+        return [i ^ cur_chunk_sizes for i in range(self.width)]
+
+    def _sigs_and_expected(self, inp, chunk_sizes):
+        """the intermediate signals and the expected values, based off of the
+        passed-in `inp` and `chunk_sizes`."""
+        # we accumulate a mask of which chunk_sizes bits we have accounted for
+        # so far
+        chunk_sizes_mask = 0
+        for step_count, intermediate in enumerate(self._intermediates):
+            # mask out chunk_sizes to get the value
+            cur_chunk_sizes = chunk_sizes & chunk_sizes_mask
+            expected = grev(inp, cur_chunk_sizes, self.log2_width)
+            yield (intermediate, expected)
+            # if step_count is in-range for being a valid step_index
+            if step_count < self.log2_width:
+                # add current step's bit to the mask
+                chunk_sizes_mask |= 1 << self._get_cs_bit_index(step_count)
+        assert chunk_sizes_mask == 2 ** self.log2_width - 1, \
+            "should have got all the bits in chunk_sizes"
+
     def elaborate(self, platform):
         m = Module()
     def elaborate(self, platform):
         m = Module()
-        comb = m.d.comb
-
-        # accumulate list of internal signals, exposed only for unit testing.
-        # contains the input, intermediary steps, and the output.
-        self._steps = [self.input]
-
-        # TODO: no. "see class doc comment for algorithm docs." <-- document
-        #           *in* the code, not "see another location elsewhere"
-        #           (unless it is a repeated text/concept of course, like
-        #            with BitwiseLut, and that's because the API is identical)
-        #           "see elsewhere" entirely defeats the object of the exercise.
-        #           jumping back and forth (page-up, page-down)
-        #           between the text and the code splits attention.
-        #           the purpose of comments is to be able to understand
-        #           (in plain english) the code *at* the point of seeing it
-        #           it should contain "the thoughts going through your head"
-        #
-        #           demonstrated below (with a rewrite)
-
-        step_i = self.input  # start with input as the first step
-
-        # create (reversed?) list of steps
-        steps = list(range(self.log2_width))
-        if self.reverse_order:
-            steps.reverse()
-
-        for i in steps:
-            # each chunk is a power-2 jump.
-            chunk_size = 1 << i
-            # prepare a list of XOR-swapped bits of this layer/step
-            butterfly = [step_i[j ^ chunk_size] for j in range(self.width)]
-            # create muxes here: 1 bit of chunk_sizes decides swap/no-swap
-            step_o = Signal(self.width, name="step%d" % chunk_size)
-            comb += step_o.eq(Mux(self.chunk_sizes[i],
-                                  Cat(*butterfly), step_i))
-            # output becomes input to next layer
-            step_i = step_o
-            self._steps.append(step_o)  # record steps for test purposes (only)
-
-        # last layer is also the output
-        comb += self.output.eq(step_o)
+
+        # value after zero steps is just the input
+        m.d.comb += self._intermediates[0].eq(self.input)
+
+        for step_index in range(self.log2_width):
+            step_inp = self._intermediates[step_index]
+            step_out = self._intermediates[step_index + 1]
+            # get permutation for current step
+            permutation = self.__get_permutation(step_index)
+            # figure out which `chunk_sizes` bit we want to pay attention to
+            # for this step.
+            sel = self.chunk_sizes[self._get_cs_bit_index(step_index)]
+            for in_index, out_index in enumerate(permutation):
+                # use in_index so we get the permuted bit
+                permuted_bit = step_inp[in_index]
+                # use out_index so we copy the bit straight through
+                straight_bit = step_inp[out_index]
+                bit = Mux(sel, permuted_bit, straight_bit)
+                m.d.comb += step_out[out_index].eq(bit)
+        # value after all steps is just the output
+        m.d.comb += self.output.eq(self._intermediates[-1])
 
         if platform != 'formal':
             return m
 
         if platform != 'formal':
             return m
@@ -129,20 +248,16 @@ class GRev(Elaboratable):
         m.d.comb += Assert(self.output == grev(self.input,
                                                self.chunk_sizes,
                                                self.log2_width))
         m.d.comb += Assert(self.output == grev(self.input,
                                                self.chunk_sizes,
                                                self.log2_width))
-        for i, step in enumerate(self._steps):
-            cur_chunk_sizes = self.chunk_sizes & (2 ** i - 1)
-            step_expected = grev(self.input, cur_chunk_sizes, self.log2_width)
-            m.d.comb += Assert(step == step_expected)
 
 
+        for value, expected in self._sigs_and_expected(self.input,
+                                                       self.chunk_sizes):
+            m.d.comb += Assert(value == expected)
         return m
 
     def ports(self):
         return [self.input, self.chunk_sizes, self.output]
 
 
         return m
 
     def ports(self):
         return [self.input, self.chunk_sizes, self.output]
 
 
-# useful to see what is going on: use yosys "read_ilang test_grev.il; show top"
-if __name__ == '__main__':
-    dut = GRev(3)
-    vl = rtlil.convert(dut, ports=dut.ports())
-    with open("test_grev.il", "w") as f:
-        f.write(vl)
+# useful to see what is going on:
+# python3 src/nmutil/test/test_grev.py
+# yosys <<<"read_ilang sim_test_out/__main__.TestGrev.test_small/0.il; proc; clean -purge; show top"
index 15e71489a738b601646cdb6bc2683550e7e58f13..780239d8a13b2954a7953d5d2e312dd517a80347 100644 (file)
@@ -14,11 +14,11 @@ from nmutil.sim_util import do_sim, hash_256
 
 
 class TestGrev(FHDLTestCase):
 
 
 class TestGrev(FHDLTestCase):
-    def test(self):
-        log2_width = 6
+    def tst(self, msb_first, log2_width=6):
         width = 2 ** log2_width
         width = 2 ** log2_width
-        dut = GRev(log2_width)
+        dut = GRev(log2_width, msb_first)
         self.assertEqual(width, dut.width)
         self.assertEqual(width, dut.width)
+        self.assertEqual(len(dut._intermediates), log2_width + 1)
 
         def case(inval, chunk_sizes):
             expected = grev(inval, chunk_sizes, log2_width)
 
         def case(inval, chunk_sizes):
             expected = grev(inval, chunk_sizes, log2_width)
@@ -30,17 +30,14 @@ class TestGrev(FHDLTestCase):
                 output = yield dut.output
                 with self.subTest(output=hex(output)):
                     self.assertEqual(expected, output)
                 output = yield dut.output
                 with self.subTest(output=hex(output)):
                     self.assertEqual(expected, output)
-                for i, step in enumerate(dut._steps):
-                    cur_chunk_sizes = chunk_sizes & (2 ** i - 1)
-                    step_expected = grev(inval, cur_chunk_sizes, log2_width)
-                    step = yield step
-                    with self.subTest(i=i, step=hex(step),
-                                      cur_chunk_sizes=bin(cur_chunk_sizes),
-                                      step_expected=hex(step_expected)):
-                        self.assertEqual(step, step_expected)
+                for sig, expected in dut._sigs_and_expected(inval,
+                                                            chunk_sizes):
+                    value = yield sig
+                    with self.subTest(sig=sig.name, value=hex(value),
+                                      expected=hex(expected)):
+                        self.assertEqual(value, expected)
 
         def process():
 
         def process():
-            self.assertEqual(len(dut._steps), log2_width + 1)
             for count in range(width + 1):
                 inval = (1 << count) - 1
                 for chunk_sizes in range(2 ** log2_width):
             for count in range(width + 1):
                 inval = (1 << count) - 1
                 for chunk_sizes in range(2 ** log2_width):
@@ -56,9 +53,21 @@ class TestGrev(FHDLTestCase):
             sim.add_process(process)
             sim.run()
 
             sim.add_process(process)
             sim.run()
 
-    def test_formal(self):
+    def test(self):
+        self.tst(msb_first=False)
+
+    def test_msb_first(self):
+        self.tst(msb_first=True)
+
+    def test_small(self):
+        self.tst(msb_first=False, log2_width=3)
+
+    def test_small_msb_first(self):
+        self.tst(msb_first=True, log2_width=3)
+
+    def tst_formal(self, msb_first):
         log2_width = 4
         log2_width = 4
-        dut = GRev(log2_width)
+        dut = GRev(log2_width, msb_first)
         m = Module()
         m.submodules.dut = dut
         m.d.comb += dut.input.eq(AnyConst(2 ** log2_width))
         m = Module()
         m.submodules.dut = dut
         m.d.comb += dut.input.eq(AnyConst(2 ** log2_width))
@@ -66,6 +75,12 @@ class TestGrev(FHDLTestCase):
         # actual formal correctness proof is inside the module itself, now
         self.assertFormal(m)
 
         # actual formal correctness proof is inside the module itself, now
         self.assertFormal(m)
 
+    def test_formal(self):
+        self.tst_formal(msb_first=False)
+
+    def test_formal_msb_first(self):
+        self.tst_formal(msb_first=True)
+
 
 if __name__ == "__main__":
     unittest.main()
 
 if __name__ == "__main__":
     unittest.main()