#!/usr/bin/env python3
from pathlib import Path
import re
from typing import Iterable

# Markdown page that this script rewrites in place (relative to the current
# working directory).
FILENAME = Path("openpower", "power_trans_ops.mdwn")
# The rewritten page is written here first ...
NEW_FILENAME = FILENAME.with_suffix(".new.mdwn")
# ... and the original page is moved here before the new one takes its name.
OLD_FILENAME = FILENAME.with_suffix(".old.mdwn")

# Heading line that introduces the PO=59/63 opcode tables.
PO_59_63_HEADER = "# Opcode Tables for PO=59/63 XO=1---011--"
# Header text of the mnemonic column in tables whose XO column is filled in.
MNEMONIC_COLUMN_NAME = "opcode"
# Header text of the XO column that is filled in / checked.
XO_COLUMN_NAME = "Major 59 and 63"


class LineReader:
    """Hand out lines one at a time together with their 1-based line number.

    Once every line has been returned, one more ``read()`` yields
    ``(None, len(lines) + 1)`` to signal end-of-file; reading again after
    that fails an assertion.
    """

    def __init__(self, lines):
        # type: (list[str]) -> None
        self.__lines = lines
        # index of the line the next read() will return
        self.__next_line_index = 0

    def read(self):
        # type: () -> tuple[str | None, int]
        """Return ``(line, lineno)`` for the next line.

        ``line`` has trailing whitespace stripped, or is ``None`` at
        end-of-file; ``lineno`` is the 1-based number of that line.
        """
        index = self.__next_line_index
        if index != len(self.__lines):
            assert index < len(self.__lines), "read past end-of-file"
            line = self.__lines[index].rstrip()
        else:
            line = None
        self.__next_line_index = index + 1
        return line, self.__next_line_index


def process(lr):
    # type: (LineReader) -> Iterable[str]
    """Rewrite the markdown page line by line, filling in XO values.

    First, the three PO=59/63 opcode tables following ``PO_59_63_HEADER``
    are parsed into a map from mnemonic to XO value.  Every later table that
    has both a ``MNEMONIC_COLUMN_NAME`` and an ``XO_COLUMN_NAME`` column then
    gets empty XO cells filled in from that map.  Non-empty XO cells of
    mnemonics in the map are checked against it.  A table with only one of
    those two columns is rejected.  All other lines, and tables with neither
    column, are passed through unchanged.

    Yields the output lines (without line terminators).

    Raises AssertionError, prefixed with ``FILENAME`` and the current line
    number, when the input does not have the expected layout.  NOTE: all
    validation is done with ``assert``, so it is disabled under
    ``python -O``.
    """

    line, lineno = lr.read()

    mnemonic_to_xo_map = {}  # type: dict[str, str]

    # One cell of a PO=59/63 opcode table: the XO bits inside
    # ``<small>`...`</small>``, a line break, then either a mnemonic or one of
    # a parenthesised name, ``**TBD**``, ``&nbsp;`` or nothing, optionally
    # followed by ``(draft)``.  Compiled once instead of once per cell.
    cell_pattern = re.compile(
        r" *<small> *` *(?P<xo>[01./][01 ./]*[01./]) *` *</small>"
        r" *<br/?> *(?P<mnemonic>[a-zA-Z0-9_.][a-zA-Z0-9_.()]*)?"
        r"(?(mnemonic)|(?:\([a-zA-Z0-9_.()]+\)|"
        r"\*\*TBD\*\*|&nbsp;|))"
        r"(?: *\(draft\))? *")

    def normalize_xo(xo):
        # type: (str) -> str
        """Canonical form of an XO bit string for comparisons: spaces are
        dropped and '.' is treated the same as '/'."""
        return xo.replace(" ", "").replace('.', '/')

    def split_table_header(header_line):
        # type: (str | None) -> list[str]
        """Assert ``header_line`` is a table header row with at least one
        column and return its ``|``-split parts."""
        assert header_line is not None \
            and header_line.startswith("|") and header_line.endswith('|'), \
            "invalid table header"
        line_parts = header_line.split("|")
        assert len(line_parts) >= 3, "invalid table header"
        return line_parts

    def split_table_row(row_line, len_line_parts):
        # type: (str, int) -> list[str]
        """Split a table body row on ``|``, asserting it starts and ends with
        ``|`` and has the same number of parts as the header."""
        line_parts = row_line.split('|')
        assert row_line.startswith('|') and row_line.endswith('|'), \
            "invalid table line, must start and end with |"
        assert len(line_parts) == len_line_parts, (
            f"invalid table line, wrong part count: found "
            f"{len(line_parts)} expected {len_line_parts}")
        return line_parts

    def parse_table_separator_line(len_line_parts):
        # type: (int) -> Iterable[str]
        """Pass through the ``|---|---|`` line after a table header, then
        require the table to have at least one body row."""
        nonlocal line, lineno

        assert line is not None \
            and line.startswith('|') \
            and line.endswith('|') \
            and len(line.split('|')) == len_line_parts \
            and line.strip(" |-") == "", "invalid table separator line"

        yield line
        line, lineno = lr.read()

        assert line is not None and line != "", "empty table"

    def parse_single_mnemonic_to_opcode_map():
        # type: () -> Iterable[str]
        """Pass through one PO=59/63 opcode table (and the blank lines after
        it), recording each mnemonic it names in ``mnemonic_to_xo_map``.

        The top row gives the LSB half of the XO bits for each column and the
        first cell of each body row gives the MSB half; each non-empty cell's
        own XO must equal row bits + column bits.
        """
        nonlocal line, lineno

        assert line is not None and line.startswith(
            "| XO LSB half &#x2192;<br> XO MSB half &#x2193; |"), \
            "can't find PO=59/63 table"
        line_parts = line.split('|')
        len_line_parts = len(line_parts)
        assert line_parts[-1] == "", "invalid PO=59/63 table top row"
        columns = []  # type: list[str]
        columns_range = range(2, len_line_parts - 1)
        for i in columns_range:
            column = line_parts[i].strip()
            if column.startswith('`') and column.endswith('`'):
                column = column[1:-1].strip()
            assert column.lstrip(" 01") == "", (f"invalid table top row "
                                                f"contents -- must be a "
                                                f"binary string: {column}")
            columns.append(column)

        yield line
        line, lineno = lr.read()

        yield from parse_table_separator_line(len_line_parts)

        while line is not None and line != "":
            line_parts = split_table_row(line, len_line_parts)
            row = line_parts[1].strip()
            if row.startswith('`') and row.endswith('`'):
                row = row[1:-1].strip()
            assert row.lstrip(" 01/.") == "", (
                f"invalid table line header-cell contents -- must be a "
                f"binary string: {row}")
            for i, column in zip(columns_range, columns):
                cell = line_parts[i]
                if cell.strip() == "":
                    continue
                match = cell_pattern.fullmatch(cell)
                assert match is not None, f"invalid table cell: {cell!r}"
                xo, mnemonic = match.group("xo", "mnemonic")
                shrunk_xo = normalize_xo(xo)
                expected_xo = normalize_xo(row + column)
                assert shrunk_xo == expected_xo, \
                    f"incorrect XO: found {shrunk_xo} expected {expected_xo}"
                if mnemonic is None:
                    continue
                assert mnemonic.endswith('(s)'), \
                    f"PO=59/63 fptrans mnemonic must end in `(s)`: {mnemonic}"
                assert mnemonic not in mnemonic_to_xo_map, (
                    f"duplicate mnemonic: {mnemonic} -- has opcode "
                    f"{xo} and {mnemonic_to_xo_map[mnemonic]}")

                # the raw (un-normalized) XO text is what gets copied into
                # the other tables later
                mnemonic_to_xo_map[mnemonic] = xo

            yield line
            line, lineno = lr.read()

        # pass through the blank lines separating this table from what follows
        while line == "":
            yield line
            line, lineno = lr.read()

    def parse_mnemonic_to_opcode_map():
        # type: () -> Iterable[str]
        """Pass through everything up to and including the three PO=59/63
        opcode tables under ``PO_59_63_HEADER``, filling
        ``mnemonic_to_xo_map`` from them."""
        nonlocal line, lineno

        while line != PO_59_63_HEADER:
            assert line is not None, "missing PO=59/63 header"
            yield line
            line, lineno = lr.read()

        yield line
        line, lineno = lr.read()

        # pass through any prose between the header and the first table
        while line is not None and not line.startswith(("#", "|")):
            yield line
            line, lineno = lr.read()

        for _ in range(3):
            yield from parse_single_mnemonic_to_opcode_map()

    def skip_table():
        # type: () -> Iterable[str]
        """Pass through a table that isn't an opcode table unchanged, after
        checking that it is well-formed."""
        nonlocal line, lineno

        len_line_parts = len(split_table_header(line))

        yield line
        line, lineno = lr.read()

        yield from parse_table_separator_line(len_line_parts)
        while line is not None and line != "":
            split_table_row(line, len_line_parts)  # validation only

            yield line
            line, lineno = lr.read()

    def handle_table():
        # type: () -> Iterable[str]
        """Pass through one table, filling in / checking its XO column when
        it has both a ``MNEMONIC_COLUMN_NAME`` and an ``XO_COLUMN_NAME``
        column."""
        nonlocal line, lineno

        line_parts = split_table_header(line)
        len_line_parts = len(line_parts)
        mnemonic_index = None
        xo_index = None
        xo_column_width = 0
        for i, column in enumerate(line_parts):
            column_width = len(column)  # should use wcswidth here
            column = column.strip()
            if column == MNEMONIC_COLUMN_NAME:
                assert mnemonic_index is None, \
                    f"two {MNEMONIC_COLUMN_NAME!r} columns in table " \
                    f"-- can't handle that"
                mnemonic_index = i
            if column == XO_COLUMN_NAME:
                assert xo_index is None, \
                    f"two {XO_COLUMN_NAME!r} columns in table " \
                    f"-- can't handle that"
                xo_index = i
                xo_column_width = column_width
        if mnemonic_index is None and xo_index is None:
            # not an opcode table -- skip it
            yield from skip_table()
            return

        assert mnemonic_index is not None, \
            f"missing {MNEMONIC_COLUMN_NAME} column"
        assert xo_index is not None, f"missing {XO_COLUMN_NAME} column"

        yield line
        line, lineno = lr.read()

        yield from parse_table_separator_line(len_line_parts)
        while line is not None and line != "":
            line_parts = split_table_row(line, len_line_parts)

            mnemonic = line_parts[mnemonic_index].strip()
            xo = line_parts[xo_index].strip()
            if mnemonic not in mnemonic_to_xo_map:
                print(f"mnemonic not assigned an XO value: {mnemonic!r}")
            elif xo == "":
                xo = mnemonic_to_xo_map[mnemonic]
                xo_width = len(xo)  # should use wcswidth here
                if xo_width < xo_column_width:
                    # pad to the header cell's width (with one leading space)
                    # so the column stays aligned
                    # should use wc_ljust here
                    xo = (" " + xo).ljust(xo_column_width)
                line_parts[xo_index] = xo
            else:
                # compare with the same normalization used when the PO=59/63
                # tables were parsed, so '.' and '/' are interchangeable
                expected_xo = normalize_xo(mnemonic_to_xo_map[mnemonic])
                assert normalize_xo(xo) == expected_xo, (
                    f"mismatch in {XO_COLUMN_NAME} column: expected "
                    f"{mnemonic_to_xo_map[mnemonic]} found {xo!r}")

            yield '|'.join(line_parts)
            line, lineno = lr.read()

    try:
        yield from parse_mnemonic_to_opcode_map()

        print(mnemonic_to_xo_map)

        while line is not None:
            if line.startswith('|'):
                yield from handle_table()
            else:
                yield line
                line, lineno = lr.read()

    except AssertionError as e:
        # re-raise located at the line being processed, compiler-style,
        # keeping the original assertion as the explicit cause
        raise AssertionError(f"\n{FILENAME}:{lineno}: error: {e}") from e


# Rewrite FILENAME in place.  The processed text goes to NEW_FILENAME, the
# untouched original is moved to OLD_FILENAME, and NEW_FILENAME then takes
# the original's name.  All lines are produced before any file is written,
# so an error raised by process() leaves the page untouched.
inp = FILENAME.read_text(encoding="utf-8")
output_lines = [*process(LineReader(inp.splitlines()))]
if output_lines[-1]:
    # a trailing "" makes the "\n".join below end the file with a newline
    output_lines.append("")
NEW_FILENAME.write_text("\n".join(output_lines), encoding="utf-8")
FILENAME.replace(OLD_FILENAME)
NEW_FILENAME.rename(FILENAME)