hdl.ast: simplify Mux implementation.
authorwhitequark <whitequark@whitequark.org>
Sat, 2 Oct 2021 14:18:02 +0000 (14:18 +0000)
committerwhitequark <whitequark@whitequark.org>
Sat, 2 Oct 2021 14:18:02 +0000 (14:18 +0000)
nmigen/back/rtlil.py
nmigen/hdl/ast.py
tests/test_hdl_ast.py
tests/test_sim.py

index 21986392366e547d62d99b426ae7e7b8d360715c..6da4d9cd3c7bc2065d95bc6ccb58a35e85b864ce 100644 (file)
@@ -562,6 +562,8 @@ class _RHSValueCompiler(_ValueCompiler):
 
     def on_Operator_mux(self, value):
         sel, val1, val0 = value.operands
+        if len(sel) != 1:
+            sel = sel.bool()
         val1_bits, val1_sign = val1.shape()
         val0_bits, val0_sign = val0.shape()
         res_bits, res_sign = value.shape()
index 685924ac2f9481e12cb2d9424db4e619ec71d14a..5ed3a77dd3cdbb9934dfb57054908c2b11b6a8a0 100644 (file)
@@ -735,9 +735,6 @@ def Mux(sel, val1, val0):
     Value, out
         Output ``Value``. If ``sel`` is asserted, the Mux returns ``val1``, else ``val0``.
     """
-    sel = Value.cast(sel)
-    if len(sel) != 1:
-        sel = sel.bool()
     return Operator("m", [sel, val1, val0])
 
 
index 9f0fec6ab4900c898f973e057fc546cb9d92b311..3604433c4aeddcaf4b788e1cf980f065f9f1e65c 100644 (file)
@@ -542,7 +542,7 @@ class OperatorTestCase(FHDLTestCase):
     def test_mux_wide(self):
         s = Const(0b100)
         v = Mux(s, Const(0, unsigned(4)), Const(0, unsigned(6)))
-        self.assertEqual(repr(v), "(m (b (const 3'd4)) (const 4'd0) (const 6'd0))")
+        self.assertEqual(repr(v), "(m (const 3'd4) (const 4'd0) (const 6'd0))")
 
     def test_mux_bool(self):
         v = Mux(True, Const(0), Const(0))
index ab31bd6a5b65b06e6f45bd89e3a898a49ece73c1..e4bd5c840b9fdbb8e33e1f98c64e7c706aca43c0 100644 (file)
@@ -191,6 +191,12 @@ class SimulatorUnitTestCase(FHDLTestCase):
         self.assertStatement(stmt, [C(2, 4), C(3, 4), C(0)], C(2, 4))
         self.assertStatement(stmt, [C(2, 4), C(3, 4), C(1)], C(3, 4))
 
+    def test_mux_wide(self):
+        stmt = lambda y, a, b, c: y.eq(Mux(c, a, b))
+        self.assertStatement(stmt, [C(2, 4), C(3, 4), C(0, 2)], C(3, 4))
+        self.assertStatement(stmt, [C(2, 4), C(3, 4), C(1, 2)], C(2, 4))
+        self.assertStatement(stmt, [C(2, 4), C(3, 4), C(2, 2)], C(2, 4))
+
     def test_abs(self):
         stmt = lambda y, a: y.eq(abs(a))
         self.assertStatement(stmt, [C(3,  unsigned(8))], C(3,  unsigned(8)))