speed up ==, hash, <, >, <=, and >= for plain_data
[nmutil.git] / src / nmutil / extend.py
index 38b5e7dd27b64701aa8f09f8837a74f42975543b..7b7d1acb78cdc52422040e85479ad8c680091184 100644 (file)
@@ -1,24 +1,41 @@
+# SPDX-License-Identifier: LGPL-2-or-later
+# Copyright (C) Luke Kenneth Casson Leighton 2020,2021 <lkcl@lkcl.net>
 """
-    This work is funded through NLnet under Grant 2019-02-012
-
-    License: LGPLv3+
-
+Provides sign/unsigned extension/truncation utility functions.
 
+This work is funded through NLnet under Grant 2019-02-012
 """
 from nmigen import Repl, Cat, Const
 
 
 def exts(exts_data, width, fullwidth):
+    diff = fullwidth-width
+    if diff == 0:
+        return exts_data
     exts_data = exts_data[0:width]
+    if diff <= 0:
+        return exts_data[:fullwidth]
     topbit = exts_data[-1]
-    signbits = Repl(topbit, fullwidth-width)
+    signbits = Repl(topbit, diff)
     return Cat(exts_data, signbits)
 
 
-def extz(exts_data, width, fullwidth):
-    exts_data = exts_data[0:width]
+def extz(extz_data, width, fullwidth):
+    diff = fullwidth-width
+    if diff == 0:
+        return extz_data
+    extz_data = extz_data[0:width]
+    if diff <= 0:
+        return extz_data[:fullwidth]
     topbit = Const(0)
-    signbits = Repl(topbit, fullwidth-width)
-    return Cat(exts_data, signbits)
+    signbits = Repl(topbit, diff)
+    return Cat(extz_data, signbits)
 
 
+def ext(data, shape, newwidth):
+    """extend/truncate data to new width, preserving sign
+    """
+    width, signed = shape
+    if signed:
+        return exts(data, width, newwidth)
+    return extz(data, width, newwidth)