sim._pyrtl: optimize uses of reflexive operators.
authorwhitequark <whitequark@whitequark.org>
Wed, 26 Aug 2020 13:26:38 +0000 (13:26 +0000)
committerwhitequark <whitequark@whitequark.org>
Wed, 26 Aug 2020 13:26:58 +0000 (13:26 +0000)
When a literal is used on the left-hand side of a numeric operator,
Python is able to constant-fold some expressions:

    >>> dis.dis(lambda x: 0 + 0 + x)
      1           0 LOAD_CONST               1 (0)
                  2 LOAD_FAST                0 (x)
                  4 BINARY_ADD
                  6 RETURN_VALUE

If a literal is used on the right-hand side such that the left-hand
side is variable, this doesn't happen:

    >>> dis.dis(lambda x: x + 0 + 0)
      1           0 LOAD_FAST                0 (x)
                  2 LOAD_CONST               1 (0)
                  4 BINARY_ADD
                  6 LOAD_CONST               1 (0)
                  8 BINARY_ADD
                 10 RETURN_VALUE

PyRTL generates fairly redundant code due to the pervasive masking,
and because of that, transforming expressions into the former form,
where possible, improves runtime by about 10% on Minerva SRAM SoC.

nmigen/sim/_pyrtl.py

index c2e9367bcb1d0c36cc3ac8efd7bacb8c4ca00710..baf9a5bbc24724633cfe502d8164e53be191d4dd 100644 (file)
@@ -103,7 +103,7 @@ class _RHSValueCompiler(_ValueCompiler):
     def on_Operator(self, value):
         def mask(value):
             value_mask = (1 << len(value)) - 1
-            return f"({self(value)} & {value_mask})"
+            return f"({value_mask} & {self(value)})"
 
         def sign(value):
             if value.shape().signed:
@@ -120,9 +120,9 @@ class _RHSValueCompiler(_ValueCompiler):
             if value.operator == "b":
                 return f"bool({mask(arg)})"
             if value.operator == "r|":
-                return f"({mask(arg)} != 0)"
+                return f"(0 != {mask(arg)})"
             if value.operator == "r&":
-                return f"({mask(arg)} == {(1 << len(arg)) - 1})"
+                return f"({(1 << len(arg)) - 1} == {mask(arg)})"
             if value.operator == "r^":
                 # Believe it or not, this is the fastest way to compute a sideways XOR in Python.
                 return f"(format({mask(arg)}, 'b').count('1') % 2)"
@@ -172,20 +172,20 @@ class _RHSValueCompiler(_ValueCompiler):
         raise NotImplementedError("Operator '{}' not implemented".format(value.operator)) # :nocov:
 
     def on_Slice(self, value):
-        return f"(({self(value.value)} >> {value.start}) & {(1 << len(value)) - 1})"
+        return f"({(1 << len(value)) - 1} & ({self(value.value)} >> {value.start}))"
 
     def on_Part(self, value):
         offset_mask = (1 << len(value.offset)) - 1
-        offset = f"(({self(value.offset)} & {offset_mask}) * {value.stride})"
-        return f"({self(value.value)} >> {offset} & " \
-               f"{(1 << value.width) - 1})"
+        offset = f"({value.stride} * ({offset_mask} & {self(value.offset)}))"
+        return f"({(1 << value.width) - 1} & " \
+               f"{self(value.value)} >> {offset})"
 
     def on_Cat(self, value):
         gen_parts = []
         offset = 0
         for part in value.parts:
             part_mask = (1 << len(part)) - 1
-            gen_parts.append(f"(({self(part)} & {part_mask}) << {offset})")
+            gen_parts.append(f"(({part_mask} & {self(part)}) << {offset})")
             offset += len(part)
         if gen_parts:
             return f"({' | '.join(gen_parts)})"
@@ -193,7 +193,7 @@ class _RHSValueCompiler(_ValueCompiler):
 
     def on_Repl(self, value):
         part_mask = (1 << len(value.value)) - 1
-        gen_part = self.emitter.def_var("repl", f"{self(value.value)} & {part_mask}")
+        gen_part = self.emitter.def_var("repl", f"{part_mask} & {self(value.value)}")
         gen_parts = []
         offset = 0
         for _ in range(value.count):
@@ -205,15 +205,15 @@ class _RHSValueCompiler(_ValueCompiler):
 
     def on_ArrayProxy(self, value):
         index_mask = (1 << len(value.index)) - 1
-        gen_index = self.emitter.def_var("rhs_index", f"{self(value.index)} & {index_mask}")
+        gen_index = self.emitter.def_var("rhs_index", f"{index_mask} & {self(value.index)}")
         gen_value = self.emitter.gen_var("rhs_proxy")
         if value.elems:
             gen_elems = []
             for index, elem in enumerate(value.elems):
                 if index == 0:
-                    self.emitter.append(f"if {gen_index} == {index}:")
+                    self.emitter.append(f"if {index} == {gen_index}:")
                 else:
-                    self.emitter.append(f"elif {gen_index} == {index}:")
+                    self.emitter.append(f"elif {index} == {gen_index}:")
                 with self.emitter.indent():
                     self.emitter.append(f"{gen_value} = {self(elem)}")
             self.emitter.append(f"else:")
@@ -253,9 +253,9 @@ class _LHSValueCompiler(_ValueCompiler):
         def gen(arg):
             value_mask = (1 << len(value)) - 1
             if value.shape().signed:
-                value_sign = f"sign({arg} & {value_mask}, {-1 << (len(value) - 1)})"
+                value_sign = f"sign({value_mask} & {arg}, {-1 << (len(value) - 1)})"
             else: # unsigned
-                value_sign = f"{arg} & {value_mask}"
+                value_sign = f"{value_mask} & {arg}"
             self.emitter.append(f"next_{self.state.get_signal(value)} = {value_sign}")
         return gen
 
@@ -267,17 +267,17 @@ class _LHSValueCompiler(_ValueCompiler):
             width_mask = (1 << (value.stop - value.start)) - 1
             self(value.value)(f"({self.lrhs(value.value)} & " \
                 f"{~(width_mask << value.start)} | " \
-                f"(({arg} & {width_mask}) << {value.start}))")
+                f"(({width_mask} & {arg}) << {value.start}))")
         return gen
 
     def on_Part(self, value):
         def gen(arg):
             width_mask = (1 << value.width) - 1
             offset_mask = (1 << len(value.offset)) - 1
-            offset = f"(({self.rrhs(value.offset)} & {offset_mask}) * {value.stride})"
+            offset = f"({value.stride} * ({offset_mask} & {self.rrhs(value.offset)}))"
             self(value.value)(f"({self.lrhs(value.value)} & " \
                 f"~({width_mask} << {offset}) | " \
-                f"(({arg} & {width_mask}) << {offset}))")
+                f"(({width_mask} & {arg}) << {offset}))")
         return gen
 
     def on_Cat(self, value):
@@ -287,7 +287,7 @@ class _LHSValueCompiler(_ValueCompiler):
             offset = 0
             for part in value.parts:
                 part_mask = (1 << len(part)) - 1
-                self(part)(f"(({gen_arg} >> {offset}) & {part_mask})")
+                self(part)(f"({part_mask} & ({gen_arg} >> {offset}))")
                 offset += len(part)
         return gen
 
@@ -302,9 +302,9 @@ class _LHSValueCompiler(_ValueCompiler):
                 gen_elems = []
                 for index, elem in enumerate(value.elems):
                     if index == 0:
-                        self.emitter.append(f"if {gen_index} == {index}:")
+                        self.emitter.append(f"if {index} == {gen_index}:")
                     else:
-                        self.emitter.append(f"elif {gen_index} == {index}:")
+                        self.emitter.append(f"elif {index} == {gen_index}:")
                     with self.emitter.indent():
                         self(elem)(arg)
                 self.emitter.append(f"else:")
@@ -332,7 +332,7 @@ class _StatementCompiler(StatementVisitor, _Compiler):
 
     def on_Switch(self, stmt):
         gen_test = self.emitter.def_var("test",
-            f"{self.rhs(stmt.test)} & {(1 << len(stmt.test)) - 1}")
+            f"{(1 << len(stmt.test)) - 1} & {self.rhs(stmt.test)}")
         for index, (patterns, stmts) in enumerate(stmt.cases.items()):
             gen_checks = []
             if not patterns:
@@ -342,10 +342,10 @@ class _StatementCompiler(StatementVisitor, _Compiler):
                     if "-" in pattern:
                         mask  = int("".join("0" if b == "-" else "1" for b in pattern), 2)
                         value = int("".join("0" if b == "-" else  b  for b in pattern), 2)
-                        gen_checks.append(f"({gen_test} & {mask}) == {value}")
+                        gen_checks.append(f"{value} == ({mask} & {gen_test})")
                     else:
                         value = int(pattern, 2)
-                        gen_checks.append(f"{gen_test} == {value}")
+                        gen_checks.append(f"{value} == {gen_test}")
             if index == 0:
                 self.emitter.append(f"if {' or '.join(gen_checks)}:")
             else: