detect duplicate comment fields
[simplev-cpp.git] / generate_headers.py
index e498a213289bff2275d24898b1898fb0c3da8a8a..7ead2d42b4ff5b20671cdfcffa142b4eedb79c89 100755 (executable)
 # SPDX-License-Identifier: LGPL-2.1-or-later
 # See Notices.txt for copyright information
 
-from typing import List
-from soc.decoder.isa.caller import (SVP64PrefixFields, SV64P_MAJOR_SIZE,
-                                    SV64P_PID_SIZE, SVP64RMFields,
-                                    SVP64RM_EXTRA2_SPEC_SIZE,
-                                    SVP64RM_EXTRA3_SPEC_SIZE,
-                                    SVP64RM_MODE_SIZE, SVP64RM_SMASK_SIZE,
-                                    SVP64RM_MMODE_SIZE, SVP64RM_MASK_SIZE,
-                                    SVP64RM_SUBVL_SIZE, SVP64RM_EWSRC_SIZE,
-                                    SVP64RM_ELWIDTH_SIZE)
+import sys
+from io import StringIO
+from typing import Any, Dict, List
 from soc.decoder.pseudo.pagereader import ISA
-from soc.decoder.power_svp64 import SVP64RM, get_regtype, decode_extra
+from soc.decoder.power_svp64 import SVP64RM
 from soc.decoder.power_fields import DecodeFields
+from soc.decoder.power_decoder import create_pdecode
+
+
+# generated text should wrap at 80 columns
+OUTPUT_WIDTH = 80
+
+
+def wcwidth(text: str) -> int:
+    """ return the number of columns that `text` takes up if printed to a terminal.
+        returns -1 if any characters are not printable.
+    """
+    # TODO: replace with wcwidth from wcwidth
+    # package when we add pip packaging files
+    if text.isprintable():
+        return len(text)
+    return -1
+
+
+class TableEntry:
+    def __init__(self, key: str, value: Any):
+        self.key = repr(key)
+        self.key_lines = self.key.splitlines()
+        for i in range(len(self.key_lines)):
+            self.key_lines[i] = self.key_lines[i].rstrip()
+        self.value = repr(value)
+        self.value_lines = self.value.splitlines()
+        for i in range(len(self.value_lines)):
+            self.value_lines[i] = self.value_lines[i].rstrip()
+        self.width = max(max(map(wcwidth, self.key_lines)),
+                         max(map(wcwidth, self.value_lines)))
+
+
+def format_dict_as_tables(d: Dict[str, Any], column_width: int) -> List[str]:
+    entries = [TableEntry(k, v) for k, v in d.items()]
+    entries.sort(key=lambda entry: entry.key)
+    entries.reverse()
+    col_sep_start = '| '
+    col_sep_start_top_line = '+-'
+    col_sep_start_mid_line = '+='
+    col_sep_start_bottom_line = '+-'
+    col_sep_start_width = wcwidth(col_sep_start)
+    col_sep_mid = ' | '
+    col_sep_mid_top_line = '-+-'
+    col_sep_mid_mid_line = '=+='
+    col_sep_mid_bottom_line = '-+-'
+    col_sep_mid_width = wcwidth(col_sep_mid)
+    col_sep_end = ' |'
+    col_sep_end_top_line = '-+'
+    col_sep_end_mid_line = '=+'
+    col_sep_end_bottom_line = '-+'
+    col_sep_end_width = wcwidth(col_sep_end)
+    col_top_line_char = '-'
+    col_mid_line_char = '='
+    col_bottom_line_char = '-'
+    retval: List[str] = []
+    while len(entries) != 0:
+        total_width = col_sep_start_width - col_sep_mid_width
+        column_entries: List[TableEntry] = []
+        key_line_count = 0
+        value_line_count = 0
+        while len(entries) != 0:
+            entry = entries.pop()
+            next_total_width = total_width + col_sep_mid_width
+            next_total_width += entry.width
+            if len(column_entries) != 0 and \
+                    next_total_width + col_sep_end_width >= column_width:
+                entries.append(entry)
+                break
+            total_width = next_total_width
+            column_entries.append(entry)
+            key_line_count = max(key_line_count, len(entry.key_lines))
+            value_line_count = max(value_line_count, len(entry.value_lines))
+        top_line = col_sep_start_top_line
+        mid_line = col_sep_start_mid_line
+        bottom_line = col_sep_start_bottom_line
+        key_lines = [col_sep_start] * key_line_count
+        value_lines = [col_sep_start] * value_line_count
+        for i in range(len(column_entries)):
+            last = (i == len(column_entries) - 1)
+            entry = column_entries[i]
+
+            def extend_line(line, entry_text, fill_char,
+                            col_sep_mid, col_sep_end):
+                line += entry_text
+                line += fill_char * (entry.width - wcwidth(entry_text))
+                line += col_sep_end if last else col_sep_mid
+                return line
+
+            top_line = extend_line(line=top_line, entry_text='',
+                                   fill_char=col_top_line_char,
+                                   col_sep_mid=col_sep_mid_top_line,
+                                   col_sep_end=col_sep_end_top_line)
+            mid_line = extend_line(line=mid_line, entry_text='',
+                                   fill_char=col_mid_line_char,
+                                   col_sep_mid=col_sep_mid_mid_line,
+                                   col_sep_end=col_sep_end_mid_line)
+            bottom_line = extend_line(line=bottom_line, entry_text='',
+                                      fill_char=col_bottom_line_char,
+                                      col_sep_mid=col_sep_mid_bottom_line,
+                                      col_sep_end=col_sep_end_bottom_line)
+
+            def extend_lines(lines, entry_lines):
+                for j in range(len(lines)):
+                    entry_text = ''
+                    if j < len(entry_lines):
+                        entry_text = entry_lines[j]
+                    lines[j] = extend_line(line=lines[j],
+                                           entry_text=entry_text,
+                                           fill_char=' ',
+                                           col_sep_mid=col_sep_mid,
+                                           col_sep_end=col_sep_end)
+
+            extend_lines(key_lines, entry.key_lines)
+            extend_lines(value_lines, entry.value_lines)
+        retval.append(top_line)
+        retval.extend(key_lines)
+        retval.append(mid_line)
+        retval.extend(value_lines)
+        retval.append(bottom_line)
+    return retval
 
 
 class InclusiveRange:
@@ -97,7 +211,8 @@ class RangesField:
         self.ranges.append(InclusiveRange(value))
 
     def endian_reversed(self, word_length) -> "RangesField":
-        return RangesField([i.endian_reversed(word_length) for i in reversed(self.ranges)])
+        return RangesField([i.endian_reversed(word_length)
+                            for i in reversed(self.ranges)])
 
     @property
     def contiguous(self) -> bool:
@@ -119,30 +234,100 @@ class RangesField:
         return self.ranges[0].width
 
 
-isa = ISA()
-svp64rm = SVP64RM()
-decode_fields = DecodeFields(bitkls=RangesField)
-decode_fields.create_specs()
+try:
+    # shut-up excessive printing
+    old_stdout = sys.stdout
+    new_stdout = StringIO()
+    sys.stdout = new_stdout
+    isa = ISA()
+    svp64rm = SVP64RM()
+    decode_fields = DecodeFields(bitkls=RangesField)
+    decode_fields.create_specs()
+    decoder = create_pdecode()
+    sys.stdout = old_stdout
+except:
+    sys.stdout = old_stdout
+    print(new_stdout.getvalue())
+    raise
+
+
+def flatten(v):
+    if isinstance(v, list):
+        for i in v:
+            yield from flatten(i)
+    else:
+        yield v
+
+
+def subdecoders():
+    def visit_subdecoders(subdecoders):
+        for subdecoder in flatten(subdecoders):
+            yield subdecoder
+            yield from visit_subdecoders(subdecoder.subdecoders)
+    yield from visit_subdecoders(decoder.dec)
+
+
+def make_opcodes_dict() -> Dict[str, Dict[str, Any]]:
+    retval = {}
+    for subdecoder in subdecoders():
+        for opcode in subdecoder.opcodes:
+            opcode = dict(opcode)
+            opcode['subdecoder'] = subdecoder
+            comment = opcode['comment']
+            if comment not in retval:
+                retval[comment] = opcode
+            else:
+                print(f"duplicate comment field: {comment!r}"
+                      f" subdecoder.pattern={subdecoder.pattern}")
+    return retval
+
+
+OPCODES_DICT = make_opcodes_dict()
+
+
+def find_opcode(comment):
+    opcode = OPCODES_DICT[comment]
+    return opcode['subdecoder'].pattern, int(opcode['opcode'], base=0)
+
+
+SETVL_PO, SETVL_XO = find_opcode("setvl")
 PO_FIELD: RangesField = decode_fields.PO
 RT_FIELD: RangesField = decode_fields.RT
 RA_FIELD: RangesField = decode_fields.RA
+SVL_SVi_FIELD: RangesField = decode_fields.FormSVL.SVi
+SVL_vs_FIELD: RangesField = decode_fields.FormSVL.vs
+SVL_ms_FIELD: RangesField = decode_fields.FormSVL.ms
+SVL_XO_FIELD: RangesField = decode_fields.FormSVL.XO
+SVL_Rc_FIELD: RangesField = decode_fields.FormSVL.Rc
+print(f"SETVL_PO={SETVL_PO}")
+print(f"SETVL_XO={SETVL_XO}")
 print(f"PO_FIELD={PO_FIELD}")
 print(f"RT_FIELD={RT_FIELD}")
 print(f"RA_FIELD={RA_FIELD}")
+print(f"SVL_SVi_FIELD={SVL_SVi_FIELD}")
+print(f"SVL_vs_FIELD={SVL_vs_FIELD}")
+print(f"SVL_ms_FIELD={SVL_ms_FIELD}")
+print(f"SVL_XO_FIELD={SVL_XO_FIELD}")
+print(f"SVL_Rc_FIELD={SVL_Rc_FIELD}")
 
 WORD_LENGTH = 32
-PRIMARY_OPCODE_SHIFT = PO_FIELD.endian_reversed(WORD_LENGTH).start
-# unofficial value. see https://libre-soc.org/openpower/sv/setvl/
-# FIXME: incorrect extended opcode value
-SETVL_OPCODE = (19 << PRIMARY_OPCODE_SHIFT) | (0 << 1)
-SETVL_IMMEDIATE_SHIFT = 32 - 7 - 16
-REG_FIELD_WIDTH = 5
+PO_SHIFT = PO_FIELD.endian_reversed(WORD_LENGTH).start
 RT_SHIFT = RT_FIELD.endian_reversed(WORD_LENGTH).start
 RA_SHIFT = RA_FIELD.endian_reversed(WORD_LENGTH).start
+SVL_SVi_SHIFT = SVL_SVi_FIELD.endian_reversed(WORD_LENGTH).start
+SVL_vs_SHIFT = SVL_vs_FIELD.endian_reversed(WORD_LENGTH).start
+SVL_ms_SHIFT = SVL_ms_FIELD.endian_reversed(WORD_LENGTH).start
+SVL_XO_SHIFT = SVL_XO_FIELD.endian_reversed(WORD_LENGTH).start
+SVL_Rc_SHIFT = SVL_Rc_FIELD.endian_reversed(WORD_LENGTH).start
 
 print(f"RT_SHIFT={RT_SHIFT}")
 print(f"RA_SHIFT={RA_SHIFT}")
-print(f"PRIMARY_OPCODE_SHIFT={PRIMARY_OPCODE_SHIFT}")
+print(f"PO_SHIFT={PO_SHIFT}")
+print(f"SVL_SVi_SHIFT={SVL_SVi_SHIFT}")
+print(f"SVL_vs_SHIFT={SVL_vs_SHIFT}")
+print(f"SVL_ms_SHIFT={SVL_ms_SHIFT}")
+print(f"SVL_XO_SHIFT={SVL_XO_SHIFT}")
+print(f"SVL_Rc_SHIFT={SVL_Rc_SHIFT}")
 
 
 def is_power_of_2(i):
@@ -153,6 +338,8 @@ def is_power_of_2(i):
 with open("include/simplev_cpp_generated.h", mode="w", encoding="utf-8") as o:
     o.write("""// SPDX-License-Identifier: LGPL-2.1-or-later
 // See Notices.txt for copyright information
+// This file is automatically generated by generate_headers.py,
+// do not edit by hand
 #pragma once
 #include <cstddef>
 #include <cstdint>
@@ -172,7 +359,7 @@ struct VecTypeStruct;
 template <typename ElementType, std::size_t SUB_VL, std::size_t MAX_VL>
 using VecType = typename VecTypeStruct<ElementType, SUB_VL, MAX_VL>::Type;
 
-#define SIMPLEV_MAKE_VEC_TYPE(size)                                                         \\
+#define SIMPLEV_MAKE_VEC_TYPE(size, underlying_size)                                        \\
     template <typename ElementType, std::size_t SUB_VL, std::size_t MAX_VL>                 \\
     struct VecTypeStruct<ElementType,                                                       \\
                          SUB_VL,                                                            \\
@@ -180,21 +367,37 @@ using VecType = typename VecTypeStruct<ElementType, SUB_VL, MAX_VL>::Type;
                          std::enable_if_t<sizeof(ElementType) * SUB_VL * MAX_VL == (size)>> \\
         final                                                                               \\
     {                                                                                       \\
-        typedef ElementType Type __attribute__((vector_size(size)));                        \\
+        typedef ElementType Type __attribute__((vector_size(underlying_size)));             \\
     };
 
+template <typename ElementType, std::size_t SUB_VL, std::size_t MAX_VL>
+struct Vec final
+{
+    static_assert(MAX_VL > 0 && MAX_VL <= 64);
+    static_assert(SUB_VL >= 1 && SUB_VL <= 4);
+    using Type = VecType<ElementType, SUB_VL, MAX_VL>;
+    Type value;
+};
+
+// power-of-2
+#define SIMPLEV_MAKE_VEC_TYPE_POT(size) SIMPLEV_MAKE_VEC_TYPE(size, size)
+
+// non-power-of-2
 #ifdef SIMPLEV_USE_NONPOT_VECTORS
-#define SIMPLEV_MAKE_VEC_TYPE_NONPOT(size) SIMPLEV_MAKE_VEC_TYPE(size)
+#define SIMPLEV_MAKE_VEC_TYPE_NONPOT(size, rounded_up_size) SIMPLEV_MAKE_VEC_TYPE(size)
 #else
-#define SIMPLEV_MAKE_VEC_TYPE_NONPOT(size)
+#define SIMPLEV_MAKE_VEC_TYPE_NONPOT(size, rounded_up_size) \\
+    SIMPLEV_MAKE_VEC_TYPE(size, rounded_up_size)
 #endif
 
 """)
-    for i in range(1, 128 + 1):
+    for i in range(64 * 4):
+        i += 1
         if is_power_of_2(i):
-            o.write(f"SIMPLEV_MAKE_VEC_TYPE({i})\n")
+            o.write(f"SIMPLEV_MAKE_VEC_TYPE_POT({i})\n")
         else:
-            o.write(f"SIMPLEV_MAKE_VEC_TYPE_NONPOT({i})\n")
+            rounded_up_size = 1 << i.bit_length()
+            o.write(f"SIMPLEV_MAKE_VEC_TYPE_NONPOT({i}, {rounded_up_size})\n")
         if i == 8:
             o.write("#ifdef SIMPLEV_USE_BIGGER_THAN_8_BYTE_VECTORS\n")
     o.write(f"""#endif // SIMPLEV_USE_BIGGER_THAN_8_BYTE_VECTORS
@@ -214,13 +417,50 @@ inline __attribute__((always_inline)) VL<MAX_VL> setvl(std::size_t vl)
         "# setvl %[retval], %[vl], MVL=%[max_vl]\\n\\t"
         ".long %[instr] | (%[retval] << %[rt_shift]) | (%[vl] << %[ra_shift])"
         : [retval] "=b"(retval.value)
-        : [vl] "b"(vl), [max_vl] "n"(MAX_VL), [instr] "n"({FIXME___FINISH}));
+        : [vl] "b"(vl),
+          [max_vl] "n"(MAX_VL),
+          [instr] "n"(((MAX_VL - 1) << {SVL_SVi_SHIFT}) | {hex(
+              (1 << SVL_vs_SHIFT)
+              | (1 << SVL_ms_SHIFT)
+              | (SETVL_XO << SVL_XO_SHIFT)
+              | (SETVL_PO << PO_SHIFT))}),
+          [rt_shift] "n"({RT_SHIFT}),
+          [ra_shift] "n"({RA_SHIFT})
+        : "memory");
     return retval;
 }}
-}} // namespace sv
+""")
+    for opcode in {i: None for i in svp64rm.instrs}:
+        try:
+            instr = OPCODES_DICT[opcode]
+        except KeyError as e:
+            print(repr(e), file=sys.stderr)
+            o.write(f"\n// skipped invalid opcode: {opcode!r}\n")
+            continue
+        o.write(f"\n/// {opcode}\n")
+        function_name = "sv_" + opcode.split('/')[0].replace('.', '_')
+        SV_Ptype = instr['SV_Ptype']
+        if SV_Ptype == 'P2':
+            pass
+            # TODO
+        else:
+            assert SV_Ptype == 'P1'
+            # TODO
+        o.write("/// (not yet implemented)\n")
+        instr_without_subdecoder = instr.copy()
+        del instr_without_subdecoder['subdecoder']
+        comment = "/// "
+        for line in format_dict_as_tables(instr_without_subdecoder,
+                                          OUTPUT_WIDTH - wcwidth(comment)):
+            o.write((comment + line).rstrip() + '\n')
+        o.write(f"""template <typename... Args>
+void {function_name}(Args &&...) = delete;
+""")
+    o.write(f"""}} // namespace sv
 
 #undef SIMPLEV_MAKE_VEC_TYPE
 #undef SIMPLEV_MAKE_VEC_TYPE_NONPOT
+#undef SIMPLEV_MAKE_VEC_TYPE_POT
 #undef SIMPLEV_USE_NONPOT_VECTORS
 #undef SIMPLEV_USE_BIGGER_THAN_8_BYTE_VECTORS
 """)