#!/usr/bin/env python3

from collections import defaultdict
import sys
import re


def merge_continuation_lines(lines: "list[str]"):
    r"""Yield logical lines, joining physical lines until braces balance.

    Unescaped ``{`` and ``}`` are counted per line. A brace directly preceded
    by a backslash (``\{``) is not counted. Note that this also skips a brace
    that follows an escaped backslash. Whenever the running count returns to
    zero, the lines accumulated so far are concatenated and yielded. A brace
    group spanning several physical lines therefore comes out as one string.

    Args:
        lines: Input lines, normally each ending in ``"\n"``.

    Yields:
        The concatenation of one or more consecutive input lines whose
        unescaped braces balance.

    Raises:
        ValueError: if a ``}`` closes more groups than have been opened, or
            (once the generator is exhausted) if the input ends while a group
            is still open.
    """
    open_brace_re = re.compile(r"(?<!\\)\{")
    close_brace_re = re.compile(r"(?<!\\)\}")
    nest_level = 0
    cur = []
    for line in lines:
        cur.append(line)
        nest_level += len(open_brace_re.findall(line))
        nest_level -= len(close_brace_re.findall(line))
        # Raise explicitly rather than assert: asserts disappear under -O,
        # which would let malformed input through unnoticed.
        if nest_level < 0:
            raise ValueError("too many closing }")
        if nest_level == 0:
            yield ''.join(cur)
            cur.clear()
    # Without this check the unterminated tail in ``cur`` would be dropped.
    if nest_level != 0:
        raise ValueError("missing closing }")


def merge_footnotes(lines: "list[str]"):
    r"""Renumber and deduplicate footnotes in a LaTeX fragment.

    The input pairs bare ``\footnotemark{}`` commands with later
    ``\footnotetext{...}`` lines, using ``\addtocounter{footnote}{N}`` lines
    to move the footnote counter between them. A ``\footnotetext`` line is
    attributed to the most recent counter value, i.e. counter - 1 in the
    0-based numbering used here.

    This function replays that counter to pair every mark with its text.
    It then rewrites the fragment as follows:

    * every mark becomes ``\footnotemark[ID]`` with an explicit number;
    * marks whose footnote texts are identical share one ID. IDs start at 1
      and are handed out in the order in which distinct texts are first
      marked;
    * ``\addtocounter{footnote}{N}`` lines are consumed and not emitted;
    * each input ``\footnotetext`` line is replaced by the
      ``\footnotetext[ID]{...}`` lines for IDs first allocated since the
      previous ``\footnotetext`` line.

    Progress is printed to stdout.

    Args:
        lines: Input lines, normally each ending in ``"\n"``.

    Returns:
        The rewritten lines. Physical lines joined by
        ``merge_continuation_lines`` stay joined.

    Raises:
        ValueError: if ``\footnotetext`` appears other than at the very
            start of a (merged) line, or more than once in one line.
        KeyError: if a ``\footnotemark{}`` has no matching ``\footnotetext``.
        Unbalanced braces are reported by ``merge_continuation_lines``.
    """
    # ``\n?`` so the adjustment is still recognized on a final line that has
    # no trailing newline; otherwise it would leak into the output unapplied.
    addtocounter_re = re.compile(
        r"\\addtocounter\{footnote\}\{(-?[1-9][0-9]*)\}\n?")
    brace_split_re = re.compile(r'(?<!\\)(\{|\})')

    # ---- Pass 1: pair marks with texts by replaying the footnote counter.
    # ``inp_ctr`` is the number of marks consumed so far, adjusted by every
    # \addtocounter line; a \footnotetext belongs to mark ``inp_ctr - 1``.
    inp_ctr = 0
    footnote_inp_ctr_to_text_map: "dict[int, str]" = {}

    def replace_footnotemark(match):
        # Tag each bare \footnotemark{} with the counter value it consumes.
        nonlocal inp_ctr
        print(f"input footnote ref #{inp_ctr}")
        retval = "\\footnotemark{" + str(inp_ctr) + "}"
        inp_ctr += 1
        return retval

    tmpl_lines = []  # template lines, with \footnotetext{} placeholders
    for line in merge_continuation_lines(lines):
        parts = line.split(r'\footnotetext')
        if len(parts) > 1:
            if len(parts) != 2 or parts[0] != '':
                raise ValueError(
                    "\\footnotetext must only be at the beginning of a line")
            # Separate the outermost brace group (the footnote body, braces
            # stripped) from whatever follows its closing brace. Text before
            # the opening brace, such as an optional [...] argument, is
            # discarded.
            nest_level = 0
            footnote_parts = []
            trailing_parts = []
            after_footnote = False
            for part in brace_split_re.split(parts[1]):
                if part == '{':
                    nest_level += 1
                    if nest_level == 1 and not after_footnote:
                        continue  # remove leading {
                if part == '}':
                    nest_level -= 1
                    if nest_level == 0 and not after_footnote:
                        after_footnote = True
                        continue  # remove trailing }
                if after_footnote:
                    trailing_parts.append(part)
                elif nest_level:
                    footnote_parts.append(part)
            footnote_text = ''.join(footnote_parts)
            trailing_text = ''.join(trailing_parts)
            print(f"input footnote #{inp_ctr - 1}: {footnote_text[:30]}")
            footnote_inp_ctr_to_text_map[inp_ctr - 1] = footnote_text
            # Leave a placeholder; pass 2 emits the renumbered texts here.
            line = "\\footnotetext{}" + trailing_text

        match = addtocounter_re.fullmatch(line)
        if match:
            adj = int(match.group(1))
            inp_ctr += adj
            print(f"adjust input footnote counter by {adj} to {inp_ctr}")
            continue  # only used to replay the counter; not emitted
        line = re.sub(r"\\footnotemark\{\}", replace_footnotemark, line)
        tmpl_lines.append(line)

    # ---- Pass 2: assign output IDs and emit texts at the placeholders.
    footnote_text_to_id_map: "dict[str, int]" = {}
    next_footnote_id = 1
    # \footnotetext[ID]{...} strings for newly allocated IDs, waiting for the
    # next placeholder line.
    footnote_queue: "list[str]" = []

    def replace_footnotemark_tmpl(match: "re.Match[str]"):
        nonlocal next_footnote_id
        mark_ctr = int(match.group(1))
        text = footnote_inp_ctr_to_text_map[mark_ctr]
        footnote_id = footnote_text_to_id_map.get(text)
        if footnote_id is None:
            # First mark with this text: allocate an ID and queue its text.
            footnote_id = next_footnote_id
            next_footnote_id += 1
            footnote_text_to_id_map[text] = footnote_id
            footnote_queue.append(
                "\\footnotetext["
                + str(footnote_id) + "]{" + text + "}")
        return "\\footnotemark[" + str(footnote_id) + "]"

    retval = []
    for line in tmpl_lines:
        parts = line.split(r'\footnotetext{}')
        if len(parts) > 1:
            if not footnote_queue:
                # Nothing new to emit (e.g. every mark since the previous
                # placeholder reused an existing text): keep only the tail.
                line = parts[1]
            else:
                # Earlier queued texts get lines of their own; the last one
                # takes over the placeholder's line and its trailing text.
                line = footnote_queue.pop() + parts[1]
                for footnote in footnote_queue:
                    retval.append(footnote + "\n")
                footnote_queue.clear()
        line = re.sub(r"\\footnotemark\{([0-9]+)\}",
                      replace_footnotemark_tmpl, line)
        retval.append(line)

    if footnote_queue:
        # These IDs were allocated after the last placeholder, so their texts
        # are never written; report it instead of losing them silently.
        print(f"warning: {len(footnote_queue)} footnote text(s) not emitted: "
              "no \\footnotetext line follows their first \\footnotemark",
              file=sys.stderr)
    return retval


def main() -> None:
    r"""Copy the LaTeX file ``sys.argv[1]`` to ``sys.argv[2]``.

    Lines are copied unchanged unless the input path ends with
    ``comparison_table_pre.tex``. In that case:

    * a redefinition of ``\footnotesize`` (6pt font, 4pt baselineskip) is
      written first;
    * the footnotes are rewritten by ``merge_footnotes``;
    * ``\itemsep -0.6em`` is inserted after every line that starts with
      ``\begin{itemize}``.

    All input is read and processed before the output file is opened. A
    parse error therefore leaves any existing output untouched, instead of
    leaving a truncated file that a build system could take as up to date.
    """
    src_path, dst_path = sys.argv[1], sys.argv[2]
    is_comparison_table = src_path.endswith("comparison_table_pre.tex")

    # Explicit encoding: the default is locale-dependent.
    with open(src_path, "r", encoding="utf-8") as f:
        lines = f.readlines()

    out: "list[str]" = []
    if is_comparison_table:
        out.append("\\renewcommand{\\footnotesize}"
                   "{\\fontsize{6pt}{4pt}\\selectfont}\n")
        lines = merge_footnotes(lines)

    for line in lines:
        out.append(line)
        if is_comparison_table and line.startswith(r"\begin{itemize}"):
            # Negative item spacing, inserted right after the list opens.
            out.append("\\itemsep -0.6em\n")

    with open(dst_path, "w", encoding="utf-8") as o:
        o.writelines(out)


if __name__ == "__main__":
    main()