simplify lut.py
authorJacob Lifshay <programmerjake@gmail.com>
Fri, 17 Dec 2021 01:20:50 +0000 (17:20 -0800)
committerJacob Lifshay <programmerjake@gmail.com>
Fri, 17 Dec 2021 01:20:50 +0000 (17:20 -0800)
src/nmutil/lut.py

index 8edc2b487f84680c2573a8b854ff01fe1fdb1508..8776b0e80e7b5694e1be23e5a52ef9bc91dae576 100644 (file)
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: LGPL-3-or-later
 # See Notices.txt for copyright information
 
-from nmigen.hdl.ast import Array, Repl, Signal
+from nmigen.hdl.ast import Array, Cat, Repl, Signal
 from nmigen.hdl.dsl import Module
 from nmigen.hdl.ir import Elaboratable
 
@@ -26,8 +26,6 @@ class BitwiseMux(Elaboratable):
 
 class BitwiseLut(Elaboratable):
     def __init__(self, input_count, width):
-        assert isinstance(input_count, int)
-        assert isinstance(width, int)
         self.input_count = input_count
         self.width = width
 
@@ -37,17 +35,12 @@ class BitwiseLut(Elaboratable):
         self.output = Signal(width)
         self.lut = Signal(2 ** input_count)
 
-        def lut_index(i):
-            return Signal(input_count, name=f"lut_index_{i}")
-        self._lut_indexes = [lut_index(i) for i in range(width)]
-
     def elaborate(self, platform):
         m = Module()
-        lut = Array(self.lut[i] for i in range(self.lut.width))
-        for i in range(self.width):
-            for j in range(self.input_count):
-                m.d.comb += self._lut_indexes[i][j].eq(self.inputs[j][i])
-            m.d.comb += self.output[i].eq(lut[self._lut_indexes[i]])
+        lut_array = Array(self.lut)
+        for bit in range(self.width):
+            index = Cat(inp[bit] for inp in self.inputs)
+            m.d.comb += self.output[bit].eq(lut_array[index])
         return m
 
 
@@ -55,8 +48,6 @@ class TreeBitwiseLut(Elaboratable):
     """tree-based version of BitwiseLut"""
 
     def __init__(self, input_count, width):
-        assert isinstance(input_count, int)
-        assert isinstance(width, int)
         self.input_count = input_count
         self.width = width