detect duplicate comment fields
[simplev-cpp.git] / generate_headers.py
index 633683c72e0552d7d581fdc4cbba53f1e3a6065d..7ead2d42b4ff5b20671cdfcffa142b4eedb79c89 100755 (executable)
 
 import sys
 from io import StringIO
-from typing import List
+from typing import Any, Dict, List
 from soc.decoder.pseudo.pagereader import ISA
 from soc.decoder.power_svp64 import SVP64RM
 from soc.decoder.power_fields import DecodeFields
 from soc.decoder.power_decoder import create_pdecode
 
 
+# generated text should wrap at 80 columns
+OUTPUT_WIDTH = 80
+
+
+def wcwidth(text: str) -> int:
+    """ return how many terminal columns `text` occupies when printed,
+        counting every character as one column.
+        returns -1 if `text` contains any non-printable character.
+    """
+    # TODO: replace with wcwidth from wcwidth
+    # package when we add pip packaging files
+    return len(text) if text.isprintable() else -1
+
+
+class TableEntry:
+    """ one table column: the `repr` of a key and of its value, each also
+        split into right-stripped lines, plus the width of the widest line.
+    """
+
+    def __init__(self, key: str, value: Any):
+        self.key = repr(key)
+        self.key_lines = [text.rstrip() for text in self.key.splitlines()]
+        self.value = repr(value)
+        self.value_lines = [text.rstrip()
+                            for text in self.value.splitlines()]
+        # the column must be wide enough for its widest key or value line
+        widest_key_line = max(wcwidth(text) for text in self.key_lines)
+        widest_value_line = max(wcwidth(text) for text in self.value_lines)
+        self.width = max(widest_key_line, widest_value_line)
+
+
+def format_dict_as_tables(d: Dict[str, Any], column_width: int) -> List[str]:
+    """ render `d` as one or more text tables and return their lines.
+
+        every item of `d` becomes one table column: the lines of `repr(key)`
+        in the upper cell and the lines of `repr(value)` in the lower cell,
+        separated by a row of `=`. columns are ordered by `repr(key)`.
+
+        columns are packed greedily: a new table is started as soon as
+        adding the next column would make the table `column_width` columns
+        wide or wider. a table always gets at least one column, so a single
+        entry wider than `column_width` still produces an over-wide table.
+
+        returns an empty list when `d` is empty.
+    """
+    entries = [TableEntry(k, v) for k, v in d.items()]
+    entries.sort(key=lambda entry: entry.key)
+    # reversed so that `entries.pop()` below hands entries out in
+    # ascending order
+    entries.reverse()
+    # separator pieces: `start` is the left edge, `mid` goes between two
+    # columns and `end` is the right edge. each has one variant for text
+    # rows and one for each of the top/middle/bottom rule rows.
+    col_sep_start = '| '
+    col_sep_start_top_line = '+-'
+    col_sep_start_mid_line = '+='
+    col_sep_start_bottom_line = '+-'
+    col_sep_start_width = wcwidth(col_sep_start)
+    col_sep_mid = ' | '
+    col_sep_mid_top_line = '-+-'
+    col_sep_mid_mid_line = '=+='
+    col_sep_mid_bottom_line = '-+-'
+    col_sep_mid_width = wcwidth(col_sep_mid)
+    col_sep_end = ' |'
+    col_sep_end_top_line = '-+'
+    col_sep_end_mid_line = '=+'
+    col_sep_end_bottom_line = '-+'
+    col_sep_end_width = wcwidth(col_sep_end)
+    col_top_line_char = '-'
+    col_mid_line_char = '='
+    col_bottom_line_char = '-'
+    retval: List[str] = []
+    # each iteration of the outer loop emits one complete table
+    while len(entries) != 0:
+        # start one `mid` separator short: the loop below adds a `mid`
+        # separator for every entry, including the first, which has none
+        total_width = col_sep_start_width - col_sep_mid_width
+        column_entries: List[TableEntry] = []
+        key_line_count = 0
+        value_line_count = 0
+        # take entries for this table until the next one no longer fits
+        while len(entries) != 0:
+            entry = entries.pop()
+            next_total_width = total_width + col_sep_mid_width
+            next_total_width += entry.width
+            if len(column_entries) != 0 and \
+                    next_total_width + col_sep_end_width >= column_width:
+                # doesn't fit: put it back for the next table
+                entries.append(entry)
+                break
+            total_width = next_total_width
+            column_entries.append(entry)
+            # all cells in a row of the table share the tallest cell's height
+            key_line_count = max(key_line_count, len(entry.key_lines))
+            value_line_count = max(value_line_count, len(entry.value_lines))
+        top_line = col_sep_start_top_line
+        mid_line = col_sep_start_mid_line
+        bottom_line = col_sep_start_bottom_line
+        key_lines = [col_sep_start] * key_line_count
+        value_lines = [col_sep_start] * value_line_count
+        # build every output line left to right, one column at a time
+        for i in range(len(column_entries)):
+            last = (i == len(column_entries) - 1)
+            entry = column_entries[i]
+
+            # NOTE: reads `entry` and `last` from the current loop iteration
+            def extend_line(line, entry_text, fill_char,
+                            col_sep_mid, col_sep_end):
+                line += entry_text
+                # pad the cell out to the column's width
+                line += fill_char * (entry.width - wcwidth(entry_text))
+                line += col_sep_end if last else col_sep_mid
+                return line
+
+            top_line = extend_line(line=top_line, entry_text='',
+                                   fill_char=col_top_line_char,
+                                   col_sep_mid=col_sep_mid_top_line,
+                                   col_sep_end=col_sep_end_top_line)
+            mid_line = extend_line(line=mid_line, entry_text='',
+                                   fill_char=col_mid_line_char,
+                                   col_sep_mid=col_sep_mid_mid_line,
+                                   col_sep_end=col_sep_end_mid_line)
+            bottom_line = extend_line(line=bottom_line, entry_text='',
+                                      fill_char=col_bottom_line_char,
+                                      col_sep_mid=col_sep_mid_bottom_line,
+                                      col_sep_end=col_sep_end_bottom_line)
+
+            # appends this column's cell to every line of `lines` in place;
+            # rows past the end of `entry_lines` get a blank, padded cell
+            def extend_lines(lines, entry_lines):
+                for j in range(len(lines)):
+                    entry_text = ''
+                    if j < len(entry_lines):
+                        entry_text = entry_lines[j]
+                    lines[j] = extend_line(line=lines[j],
+                                           entry_text=entry_text,
+                                           fill_char=' ',
+                                           col_sep_mid=col_sep_mid,
+                                           col_sep_end=col_sep_end)
+
+            extend_lines(key_lines, entry.key_lines)
+            extend_lines(value_lines, entry.value_lines)
+        # table layout: top rule, keys, `=` rule, values, bottom rule
+        retval.append(top_line)
+        retval.extend(key_lines)
+        retval.append(mid_line)
+        retval.extend(value_lines)
+        retval.append(bottom_line)
+    return retval
+
+
 class InclusiveRange:
     __slots__ = "start", "stop"
 
@@ -140,23 +259,38 @@ def flatten(v):
         yield v
 
 
-def find_opcode(internal_op):
-    retval = None
-    for primary_subdecoder in decoder.dec:
-        for extended_subdecoder in flatten(primary_subdecoder.subdecoders):
-            for opcode in extended_subdecoder.opcodes:
-                if opcode['internal op'] == internal_op:
-                    if retval is not None:
-                        raise ValueError(f"internal_op={internal_op!r} "
-                                         "found more than once")
-                    retval = extended_subdecoder.pattern, \
-                        int(opcode['opcode'], base=0)
-    if retval is None:
-        raise ValueError(f"internal_op={internal_op!r} not found")
+def subdecoders():
+    """ yield every subdecoder reachable from `decoder.dec`; each subdecoder
+        is yielded before the subdecoders nested inside it.
+    """
+    def walk(nested):
+        for current in flatten(nested):
+            yield current
+            yield from walk(current.subdecoders)
+
+    yield from walk(decoder.dec)
+
+
+def make_opcodes_dict() -> Dict[str, Dict[str, Any]]:
+    """ build a dict mapping each opcode's 'comment' field to a copy of that
+        opcode's fields, with the owning subdecoder added under the
+        'subdecoder' key.
+
+        when several opcodes share a comment, the first one visited is kept
+        and a diagnostic for each later one is printed to stderr.
+    """
+    retval = {}
+    for subdecoder in subdecoders():
+        for opcode in subdecoder.opcodes:
+            # work on a copy so the subdecoder's own entry isn't modified
+            opcode = dict(opcode)
+            opcode['subdecoder'] = subdecoder
+            comment = opcode['comment']
+            if comment not in retval:
+                retval[comment] = opcode
+            else:
+                # diagnostics go to stderr, not stdout
+                print(f"duplicate comment field: {comment!r}"
+                      f" subdecoder.pattern={subdecoder.pattern}",
+                      file=sys.stderr)
     return retval
 
 
-SETVL_PO, SETVL_XO = find_opcode("OP_SETVL")
+OPCODES_DICT = make_opcodes_dict()
+
+
+def find_opcode(comment):
+    """ look up `comment` in OPCODES_DICT and return a tuple of the owning
+        subdecoder's pattern and the 'opcode' field parsed as an integer
+        (base prefix such as `0x`/`0b` is honoured).
+
+        raises KeyError if no opcode has that comment.
+    """
+    entry = OPCODES_DICT[comment]
+    pattern = entry['subdecoder'].pattern
+    return pattern, int(entry['opcode'], base=0)
+
+
+SETVL_PO, SETVL_XO = find_opcode("setvl")
 PO_FIELD: RangesField = decode_fields.PO
 RT_FIELD: RangesField = decode_fields.RT
 RA_FIELD: RangesField = decode_fields.RA
@@ -225,7 +359,7 @@ struct VecTypeStruct;
 template <typename ElementType, std::size_t SUB_VL, std::size_t MAX_VL>
 using VecType = typename VecTypeStruct<ElementType, SUB_VL, MAX_VL>::Type;
 
-#define SIMPLEV_MAKE_VEC_TYPE(size)                                                         \\
+#define SIMPLEV_MAKE_VEC_TYPE(size, underlying_size)                                        \\
     template <typename ElementType, std::size_t SUB_VL, std::size_t MAX_VL>                 \\
     struct VecTypeStruct<ElementType,                                                       \\
                          SUB_VL,                                                            \\
@@ -233,21 +367,37 @@ using VecType = typename VecTypeStruct<ElementType, SUB_VL, MAX_VL>::Type;
                          std::enable_if_t<sizeof(ElementType) * SUB_VL * MAX_VL == (size)>> \\
         final                                                                               \\
     {                                                                                       \\
-        typedef ElementType Type __attribute__((vector_size(size)));                        \\
+        typedef ElementType Type __attribute__((vector_size(underlying_size)));             \\
     };
 
+template <typename ElementType, std::size_t SUB_VL, std::size_t MAX_VL>
+struct Vec final
+{
+    static_assert(MAX_VL > 0 && MAX_VL <= 64);
+    static_assert(SUB_VL >= 1 && SUB_VL <= 4);
+    using Type = VecType<ElementType, SUB_VL, MAX_VL>;
+    Type value;
+};
+
+// power-of-2
+#define SIMPLEV_MAKE_VEC_TYPE_POT(size) SIMPLEV_MAKE_VEC_TYPE(size, size)
+
+// non-power-of-2
 #ifdef SIMPLEV_USE_NONPOT_VECTORS
-#define SIMPLEV_MAKE_VEC_TYPE_NONPOT(size) SIMPLEV_MAKE_VEC_TYPE(size)
+#define SIMPLEV_MAKE_VEC_TYPE_NONPOT(size, rounded_up_size) SIMPLEV_MAKE_VEC_TYPE(size, size)
 #else
-#define SIMPLEV_MAKE_VEC_TYPE_NONPOT(size)
+#define SIMPLEV_MAKE_VEC_TYPE_NONPOT(size, rounded_up_size) \\
+    SIMPLEV_MAKE_VEC_TYPE(size, rounded_up_size)
 #endif
 
 """)
-    for i in range(1, 128 + 1):
+    for i in range(64 * 4):
+        i += 1
         if is_power_of_2(i):
-            o.write(f"SIMPLEV_MAKE_VEC_TYPE({i})\n")
+            o.write(f"SIMPLEV_MAKE_VEC_TYPE_POT({i})\n")
         else:
-            o.write(f"SIMPLEV_MAKE_VEC_TYPE_NONPOT({i})\n")
+            rounded_up_size = 1 << i.bit_length()
+            o.write(f"SIMPLEV_MAKE_VEC_TYPE_NONPOT({i}, {rounded_up_size})\n")
         if i == 8:
             o.write("#ifdef SIMPLEV_USE_BIGGER_THAN_8_BYTE_VECTORS\n")
     o.write(f"""#endif // SIMPLEV_USE_BIGGER_THAN_8_BYTE_VECTORS
@@ -279,10 +429,38 @@ inline __attribute__((always_inline)) VL<MAX_VL> setvl(std::size_t vl)
         : "memory");
     return retval;
 }}
-}} // namespace sv
+""")
+    for opcode in {i: None for i in svp64rm.instrs}:
+        try:
+            instr = OPCODES_DICT[opcode]
+        except KeyError as e:
+            print(repr(e), file=sys.stderr)
+            o.write(f"\n// skipped invalid opcode: {opcode!r}\n")
+            continue
+        o.write(f"\n/// {opcode}\n")
+        function_name = "sv_" + opcode.split('/')[0].replace('.', '_')
+        SV_Ptype = instr['SV_Ptype']
+        if SV_Ptype == 'P2':
+            pass
+            # TODO
+        else:
+            assert SV_Ptype == 'P1'
+            # TODO
+        o.write("/// (not yet implemented)\n")
+        instr_without_subdecoder = instr.copy()
+        del instr_without_subdecoder['subdecoder']
+        comment = "/// "
+        for line in format_dict_as_tables(instr_without_subdecoder,
+                                          OUTPUT_WIDTH - wcwidth(comment)):
+            o.write((comment + line).rstrip() + '\n')
+        o.write(f"""template <typename... Args>
+void {function_name}(Args &&...) = delete;
+""")
+    o.write(f"""}} // namespace sv
 
 #undef SIMPLEV_MAKE_VEC_TYPE
 #undef SIMPLEV_MAKE_VEC_TYPE_NONPOT
+#undef SIMPLEV_MAKE_VEC_TYPE_POT
 #undef SIMPLEV_USE_NONPOT_VECTORS
 #undef SIMPLEV_USE_BIGGER_THAN_8_BYTE_VECTORS
 """)