import sys
import copy
import getopt

import numpy as np

from collections import defaultdict, deque

def extract_words(line):
    """
    Split a line into normalized word tokens.

    Every whitespace-separated token is lower-cased, has 'ё' replaced by 'е'
    and is cut right after its last alphabetic or digit character, so
    trailing punctuation is dropped while leading and inner punctuation is
    kept. Tokens that contain no alphabetic or digit character are skipped.

    :param line: str, a line of text
    :return: list of strs, the normalized tokens in their original order
    """
    normalized = []
    for token in line.split():
        token = token.lower().replace('ё', 'е')
        # Move the cut position left past any trailing non-alphanumeric symbols.
        end = len(token)
        while end > 0 and not (token[end - 1].isalpha() or token[end - 1].isdigit()):
            end -= 1
        if end > 0:
            normalized.append(token[:end])
    return normalized

def make_levenstein_table(source, correct, allow_transpositions=False,
        removal_cost=1.0, insertion_cost=1.0, replace_cost=1.0, transposition_cost=1.0):
    """
    Build the dynamic-programming table of the Levenshtein distance together
    with the back-links needed to restore the optimal alignments.

    :param source: sequence (str or list of strs), the original sequence
    :param correct: sequence (str or list of strs), the corrected sequence
    :param allow_transpositions: bool, optional(default=False),
        whether swapping two neighbouring elements counts as one operation
    :param removal_cost: float, optional(default=1.0),
        cost of dropping an element of source
    :param insertion_cost: float, optional(default=1.0),
        cost of adding an element of correct
    :param replace_cost: float, optional(default=1.0),
        cost of replacing an element of source by a different one
    :param transposition_cost: float, optional(default=1.0),
        cost of swapping two neighbouring elements
    :return:
        table, numpy 2D-array of float, distances between prefixes,
            table[i][j] = d(source[:i], correct[:j])
        backtraces, 2D list of lists of pairs, backtraces[i][j] holds every
            cell from which the optimal value of table[i][j] is reached
            (backtraces[0][0] is None)
    """
    first_length, second_length = len(source), len(correct)
    table = np.zeros(shape=(first_length + 1, second_length + 1), dtype=float)
    backtraces = [([None] * (second_length + 1)) for _ in range(first_length + 1)]
    # Borders are accumulated with the same additions as the recurrence below,
    # so that the configured costs are honoured there as well.
    for j in range(1, second_length + 1):
        table[0][j] = table[0][j-1] + insertion_cost
        backtraces[0][j] = [(0, j-1)]
    for i in range(1, first_length + 1):
        table[i][0] = table[i-1][0] + removal_cost
        backtraces[i][0] = [(i-1, 0)]
    for i, first_word in enumerate(source, 1):
        for j, second_word in enumerate(correct, 1):
            if first_word == second_word:
                table[i][j] = table[i-1][j-1]
                backtraces[i][j] = [(i-1, j-1)]
                continue
            # Moving along j consumes an element of correct (insertion),
            # moving along i consumes an element of source (removal).
            candidates = [
                (table[i-1][j-1] + replace_cost, (i-1, j-1)),
                (table[i][j-1] + insertion_cost, (i, j-1)),
                (table[i-1][j] + removal_cost, (i-1, j)),
            ]
            # A transposition needs source[i-2:i] to be correct[j-2:j] reversed.
            if (allow_transpositions and min(i, j) >= 2
                    and first_word == correct[j-2] and second_word == source[i-2]):
                candidates.append((table[i-2][j-2] + transposition_cost, (i-2, j-2)))
            best_cost = min(cost for cost, _ in candidates)
            table[i][j] = best_cost
            # Each candidate cost is computed once, so exact comparison is safe.
            backtraces[i][j] = [cell for cost, cell in candidates if cost == best_cost]
    return table, backtraces

def extract_best_alignment(backtraces):
    """
    Enumerate all optimal paths encoded in a table of back-links.

    :param backtraces: 2D list of lists of pairs, backtraces[i][j] lists the
        cells from which cell (i, j) is reached on an optimal path
        (the entry of cell (0, 0) is never read)
    :return: best_paths, list of lists of pairs,
        every path leading from (0, 0) to (m, n) along the back-links, where
        m = len(backtraces) - 1 and n = len(backtraces[0]) - 1; each path
        starts with (0, 0) and ends with (m, n)
    """
    m, n = len(backtraces) - 1, len(backtraces[0]) - 1
    target = (m, n)
    # Breadth-first walk from the target along the back-links; it collects the
    # forward edges of the graph of optimal paths.
    forward_links = defaultdict(list)
    seen = {target}
    queue = deque([target])
    while queue:
        vertex = queue.popleft()
        if vertex == (0, 0):
            continue
        i, j = vertex
        for previous in backtraces[i][j]:
            forward_links[previous].append(vertex)
            if previous not in seen:
                seen.add(previous)
                queue.append(previous)

    def successors(vertex):
        # The target is a leaf: a path is complete as soon as it is reached.
        if vertex == target:
            return iter(())
        return iter(forward_links.get(vertex, ()))

    # Iterative depth-first enumeration (no recursion, paths may be long):
    # pending[k] yields the not yet visited successors of current_path[k].
    best_paths = []
    current_path = [(0, 0)]
    if current_path[-1] == target:
        best_paths.append(list(current_path))
    pending = [successors((0, 0))]
    while pending:
        vertex = next(pending[-1], None)
        if vertex is None:
            # All continuations of the last vertex are exhausted: step back.
            pending.pop()
            current_path.pop()
            continue
        current_path.append(vertex)
        if vertex == target:
            best_paths.append(list(current_path))
        pending.append(successors(vertex))
    return best_paths

def extract_basic_alignment_paths(paths_in_alignments, source, correct):
    """
    Extract the identical substitutions from paths in the Levenshtein table.

    A path cell (i, j) is an identical substitution when it is reached by a
    diagonal step from (i-1, j-1) and source[i-1] == correct[j-1].
    Insertions and removals are never reported, even when the symbols at
    the reached cell happen to be equal; requiring a diagonal step also
    guarantees i >= 1 and j >= 1, so no index wraps around to the end of
    a sequence.

    :param paths_in_alignments: list of lists of pairs, optimal paths
        in the table from (0, 0) to (len(source), len(correct))
    :param source: sequence, the original sequence
    :param correct: sequence, the corrected sequence
    :return:
        answer: list of tuples of pairs, the distinct variants of identical
        substitutions found in the given paths (order of variants is
        unspecified)
    """
    answer = set()
    for path in paths_in_alignments:
        identities = tuple(
            (i, j)
            for (prev_i, prev_j), (i, j) in zip(path, path[1:])
            if i == prev_i + 1 and j == prev_j + 1 and source[i-1] == correct[j-1]
        )
        answer.add(identities)
    return list(answer)

def extract_levenstein_alignments(source, correct):
    """
    Find the positions of identical substitutions in the optimal alignments
    between source and correct.

    The Levenshtein table is built with replace_cost=1.9 and the other costs
    left at their defaults.

    :param source: sequence, the original sequence
    :param correct: sequence, the corrected sequence
    :return: basic_alignment_paths, list of tuples of pairs of ints,
        the variants of positions of identical substitutions
        in the optimal alignments
    """
    _, backlinks = make_levenstein_table(source, correct, replace_cost=1.9)
    optimal_paths = extract_best_alignment(backlinks)
    return extract_basic_alignment_paths(optimal_paths, source, correct)

def get_partition_indexes(first, second):
    """
    Строит оптимальное разбиение на группы (ошибка, исправление)
    Группа заканчивается после first[i] и second[j], если пара из
    концов этих слов встречается в оптимальном пути в таблице Левенштейна
    для " ".join(first) и " ".join(second)
    Build optimal partition into groups (mistake, correction)
    The group ends after first[i] and second[j], if pair from ends of these words is found in optimal path in
    the Levenshtein table for " ".join(first) and " ".join(second)

    :param first: list of strs, список исходных слов List of original words
    :param second: list of strs, их исправление Their correction
    :return: answer, list of pairs of ints,
        список пар (f[0], s[0]), (f[1], s[1]), ...
        list of pairs
        отрезок second[s[i]: s[i+1]] является исправлением для first[f[i]: f[i+1]]
        segment second[s[i]: s[i+1]] is the correction for first[f[i]: f[i+1]]
    """
    m, n = len(first), len(second)
    answer = [(0, 0)]
    if m <= 1 or n <= 1:
        answer += [(m, n)]
    elif m == 2 and n == 2:
        answer += [(1, 1), (2, 2)]
    else:
        levenstein_table, backtraces = make_levenstein_table(" ".join(first), " ".join(second))
        best_paths_in_table = extract_best_alignment(backtraces)
        good_partitions, other_partitions = set(), set()
        word_ends = [0], [0]
        last = -1
        for i, word in enumerate(first):
            last = last + len(word) + 1
            word_ends[0].append(last)
        last = -1
        for i, word in enumerate(second):
            last = last + len(word) + 1
            word_ends[1].append(last)
        for path in best_paths_in_table:
            current_indexes = [(0, 0)]
            first_pos, second_pos = 0, 0
            is_partition_good = True
            for i, j in path[1:]:
                if i > word_ends[0][first_pos]:
                    first_pos += 1
                if j > word_ends[1][second_pos]:
                    second_pos += 1
                if i == word_ends[0][first_pos] and j == word_ends[1][second_pos]:
                    if first_pos > current_indexes[-1][0] or second_pos > current_indexes[-1][1]:
                        current_indexes.append((first_pos, second_pos))
                        first_pos += 1
                        second_pos += 1
                    else:
                        is_partition_good = False
            if current_indexes[-1] == (m, n):
                if is_partition_good:
                    good_partitions.add(tuple(current_indexes))
                else:
                    other_partitions.add(tuple(current_indexes))
        if len(good_partitions) == 1:
            answer = list(good_partitions)[0]
        else:
            answer = list(other_partitions)[0]
    return answer

def align_sents(source, correct, return_only_different=False):
    """
    Return the index ranges of the groups in an optimal alignment.

    The first variant returned by extract_levenstein_alignments is used.
    The words between two consecutive identical substitutions (and after the
    last one) are split into groups with get_partition_indexes.

    :param source: list of strs, the original sentence
    :param correct: list of strs, the corrected sentence
    :param return_only_different: bool, optional(default=False),
        if True, the groups of identical substitutions (equal words) are
        left out and only non-identical corrections are returned
    :return: answer, list of pairs of pairs of ints,
        the partition into groups. If answer[t] == ((i, j), (k, l)), then
        source[i:j] and correct[k:l] belong to one group.
    """
    alignments = extract_levenstein_alignments(source, correct)
    m, n = len(source), len(correct)
    prev = 0, 0
    answer = []
    for i, j in alignments[0]:
        # Words skipped on either side since the previous identical pair
        # form a segment that has to be partitioned.
        if i > prev[0] + 1 or j > prev[1] + 1:
            answer.extend(_segment_groups(source, correct, (prev[0], i-1), (prev[1], j-1)))
        if not return_only_different:
            answer.append(((i-1, i), (j-1, j)))
        prev = i, j
    if m > prev[0] or n > prev[1]:
        answer.extend(_segment_groups(source, correct, (prev[0], m), (prev[1], n)))
    return answer


def _segment_groups(source, correct, source_span, correct_span):
    """Split one unaligned segment into groups expressed in sentence-level indexes."""
    (source_start, source_end), (correct_start, correct_end) = source_span, correct_span
    partition_indexes = get_partition_indexes(
        source[source_start:source_end], correct[correct_start:correct_end])
    if partition_indexes is None:
        # No partition available: keep the whole segment as a single group.
        return [((source_start, source_end), (correct_start, correct_end))]
    return [((source_start + f, source_start + next_f),
             (correct_start + s, correct_start + next_s))
            for (f, s), (next_f, next_s) in zip(partition_indexes, partition_indexes[1:])]


if __name__ == "__main__":
    # Command-line entry point: compares the corrections made by a system
    # (answer_file) with the reference corrections (correct_file) relative to
    # source_file, prints precision / recall / F-measure and, with
    # -d FILE or --differences=FILE, writes the sentences with differing
    # corrections to FILE.

    def _read_sentences(path):
        """Read a UTF-8 text file and return the word lists of its non-blank lines."""
        with open(path, "r", encoding="utf8") as fin:
            return [extract_words(line.strip()) for line in fin if line.strip() != ""]

    def _ratio(numerator, denominator):
        """Divide, returning 0.0 instead of failing on a zero denominator."""
        return numerator / denominator if denominator else 0.0

    output_differences = False
    try:
        # getopt expects long option names without the leading "--".
        opts, args = getopt.getopt(sys.argv[1:], "d:", ["differences="])
    except getopt.GetoptError as err:
        sys.exit("Wrong option: {0}".format(err))
    for opt, val in opts:
        if opt in ("-d", "--differences"):
            output_differences = True
            diff_file = val
    if len(args) != 3:
        sys.exit("Использование: evaluate.py source_file correct_file answer_file\n"
                 "source_file: исходный файл\n"
                 "correct_file: файл с эталонными исправлениями\n"
                 "answer_file: файл с ответами системы\n")
    source_file, correct_file, answer_file = args
    source_sents = _read_sentences(source_file)
    correct_sents = _read_sentences(correct_file)
    answer_sents = _read_sentences(answer_file)
    if not (len(source_sents) == len(correct_sents) == len(answer_sents)):
        # zip() below stops at the shortest list, so report the mismatch
        # instead of silently dropping sentences.
        print("Warning: different numbers of sentences: source={0}, correct={1}, answer={2}; "
              "extra sentences are ignored".format(
                  len(source_sents), len(correct_sents), len(answer_sents)),
              file=sys.stderr)
    # Corrections are keyed by (sentence number, start, end) of the source span.
    etalon_corrections = dict()
    answer_corrections = dict()
    for num, (source, correct, answer) in\
            enumerate(zip(source_sents, correct_sents, answer_sents)):
        for ((i, j), (k, l)) in align_sents(source, correct, return_only_different=True):
            etalon_corrections[(num, i, j)] = tuple(correct[k:l])
        for ((i, j), (k, l)) in align_sents(source, answer, return_only_different=True):
            answer_corrections[(num, i, j)] = tuple(answer[k:l])
    # A true positive is a system correction of the same span with the same text.
    TP = sum(1 for triple, answer_correction in answer_corrections.items()
             if etalon_corrections.get(triple) == answer_correction)
    # Empty sets of corrections (or no hits at all) give 0.0 rather than a
    # ZeroDivisionError.
    precision = _ratio(TP, len(answer_corrections))
    recall = _ratio(TP, len(etalon_corrections))
    f_measure = _ratio(2 * precision * recall, precision + recall)
    print("Precision={0:.2f} Recall={1:.2f} FMeasure={2:.2f}".format(
        100 * precision, 100 * recall, 100 * f_measure))
    print(TP, len(answer_corrections), len(etalon_corrections))
    if output_differences:
        false_positives = defaultdict(list)
        false_negatives = defaultdict(list)
        miscorrections = defaultdict(list)
        for (num, i, j), answer_correction in answer_corrections.items():
            etalon_correction = etalon_corrections.get((num, i, j))
            if etalon_correction is None:
                false_positives[num].append(((i, j), answer_correction))
            elif etalon_correction != answer_correction:
                miscorrections[num].append(((i, j), answer_correction, etalon_correction))
        for (num, i, j), etalon_correction in etalon_corrections.items():
            if answer_corrections.get((num, i, j)) is None:
                false_negatives[num].append(((i, j), etalon_correction))
        with open(diff_file, "w", encoding="utf8") as fout:
            width = 24
            row_format = "{0:<{width}}{1:<{width}}{2:<{width}}\n"
            for num, sent in enumerate(source_sents):
                current_false_positives = false_positives[num]
                current_false_negatives = false_negatives[num]
                current_miscorrections = miscorrections[num]
                if not (current_false_positives or current_false_negatives
                        or current_miscorrections):
                    continue
                fout.write("{0}\n{1}\n{2}\n".format(
                    " ".join(sent), " ".join(answer_sents[num]), " ".join(correct_sents[num])))
                # Columns: source span, system output, reference correction.
                for (i, j), answer_correction in current_false_positives:
                    fout.write(row_format.format(" ".join(sent[i:j]),
                        " ".join(answer_correction), " ".join(sent[i:j]), width=width))
                for (i, j), etalon_correction in current_false_negatives:
                    fout.write(row_format.format(" ".join(sent[i:j]),
                        " ".join(sent[i:j]), " ".join(etalon_correction), width=width))
                for (i, j), answer_correction, etalon_correction in current_miscorrections:
                    fout.write(row_format.format(" ".join(sent[i:j]),
                        " ".join(answer_correction), " ".join(etalon_correction), width=width))
                fout.write("\n")



