import re
from collections import namedtuple
from typing import Any, Dict, List, Optional, Tuple

from langchain_core._api.deprecation import deprecated

# One allowed relationship in the graph schema, read positionally as
# (from node label, relationship type, to node label).
Schema = namedtuple("Schema", "left_node relation right_node")


@deprecated(
    since="0.3.8",
    removal="1.0",
    alternative_import="langchain_neo4j.chains.graph_qa.cypher_utils.CypherQueryCorrector",
)
class CypherQueryCorrector:
    """
    Used to correct relationship direction in generated Cypher statements.
    This code is copied from the winner's submission to the Cypher competition:
    https://github.com/sakusaku-rich/cypher-direction-competition

    Every single-hop ``(node)-[rel]-(node)`` pattern in a query is checked
    against ``schemas``. A relationship whose direction contradicts the
    schemas but whose reverse is allowed gets its arrow flipped. If any
    relationship is allowed in neither direction, the corrector returns an
    empty string instead of a query.
    """

    # A property map such as ``{name: 'Tom'}``. It is non-greedy, so nested
    # maps are not handled.
    property_pattern = re.compile(r"\{.+?\}")
    # Any parenthesised text. Besides nodes, this also matches function calls
    # such as ``count(n)``.
    node_pattern = re.compile(r"\(.+?\)")
    # One hop: left node (group 1, its property map 2), ``<?-`` (3), optional
    # ``[...]`` (4), ``->?`` (5) and right node (6, its property map 7).
    path_pattern = re.compile(
        r"(\([^\,\(\)]*?(\{.+\})?[^\,\(\)]*?\))(<?-)(\[.*?\])?(->?)(\([^\,\(\)]*?(\{.+\})?[^\,\(\)]*?\))"
    )
    # A single hop split into named left node, relation text and right node.
    node_relation_node_pattern = re.compile(
        r"(\()+(?P<left_node>[^()]*?)\)(?P<relation>.*?)\((?P<right_node>[^()]*?)(\))+"
    )
    # The ``:TYPE`` part of a relation, up to an optional property map and ``]``.
    relation_type_pattern = re.compile(r":(?P<relation_type>.+?)?(\{.+\})?]")

    def __init__(self, schemas: List[Schema]):
        """
        Args:
            schemas: list of schemas; each entry is read positionally as
                (from node label, relationship type, to node label)
        """
        self.schemas = schemas

    def clean_node(self, node: str) -> str:
        """
        Strip property maps and parentheses from a node.

        Args:
            node: node in string format, e.g. ``(p:Person {name: 'x'})``

        Returns:
            The remaining ``variable:Label`` text with surrounding whitespace
            removed, e.g. ``p:Person``.
        """
        node = re.sub(self.property_pattern, "", node)
        node = node.replace("(", "")
        node = node.replace(")", "")
        node = node.strip()
        return node

    def detect_node_variables(self, query: str) -> Dict[str, List[str]]:
        """
        Collect the labels written on each named node variable in a query.

        Args:
            query: cypher query

        Returns:
            Mapping from variable name to every label it was written with
            anywhere in the query (duplicates kept). Anonymous nodes such as
            ``(:Movie)`` are skipped. They have no variable through which to
            share labels, and pooling them under ``""`` would make every
            anonymous node appear to carry every anonymous label.
        """
        res: Dict[str, List[str]] = {}
        for raw_node in re.findall(self.node_pattern, query):
            parts = self.clean_node(raw_node).split(":")
            variable = parts[0].strip()
            if not variable:
                continue
            res.setdefault(variable, []).extend(label.strip() for label in parts[1:])
        return res

    def extract_paths(self, query: str) -> "List[str]":
        """
        Extract every single-hop ``(node)-[rel]-(node)`` pattern from a query.

        The hops of a longer chain are returned separately and share their
        middle node: ``(a)-->(b)-->(c)`` yields ``(a)-->(b)`` and ``(b)-->(c)``.

        Args:
            query: cypher query

        Returns:
            The matched hop texts, in query order.
        """
        paths = []
        idx = 0
        while match := self.path_pattern.search(query, idx):
            paths.append(match.group(0))
            # Resume at the right-hand node (group 6) so it can start the next
            # hop. The position is taken from the match itself. Searching the
            # whole query for the hop text would find an earlier duplicate,
            # move idx backwards and loop forever.
            idx = match.start(6)
        return paths

    def judge_direction(self, relation: str) -> str:
        """
        Classify the arrow of a relationship pattern.

        Args:
            relation: relation in string format, e.g. ``-[:ACTED_IN]->``

        Returns:
            ``"OUTGOING"`` if it ends with ``>``, otherwise ``"INCOMING"`` if
            it starts with ``<``, otherwise ``"BIDIRECTIONAL"`` (this includes
            the empty string).
        """
        if relation.endswith(">"):
            return "OUTGOING"
        if relation.startswith("<"):
            return "INCOMING"
        return "BIDIRECTIONAL"

    def extract_node_variable(self, part: str) -> Optional[str]:
        """
        Return the variable name of a node, or ``None`` for an anonymous node.

        Args:
            part: node in string format, e.g. ``(p:Person)``
        """
        part = part.lstrip("(").rstrip(")")
        idx = part.find(":")
        if idx != -1:
            part = part[:idx]
        return None if part == "" else part

    def detect_labels(
        self, str_node: str, node_variable_dict: Dict[str, Any]
    ) -> List[str]:
        """
        Resolve the labels of a single node.

        Args:
            str_node: node in string format, with or without the surrounding
                parentheses and property map
            node_variable_dict: dictionary of node variables, as returned by
                ``detect_node_variables``

        Returns:
            - The labels collected for the node's variable, when it has a
              known one.
            - The labels written on the node itself, when it is anonymous.
            - An empty list otherwise.
        """
        # Drop property maps first. Otherwise ``p {name: 'x'}`` would split on
        # the colon inside the map and yield the variable ``p {name``.
        splitted_node = self.clean_node(str_node).split(":")
        variable = splitted_node[0].strip()
        if variable in node_variable_dict:
            return node_variable_dict[variable]
        if variable == "" and len(splitted_node) > 1:
            return [label.strip() for label in splitted_node[1:]]
        return []

    def verify_schema(
        self,
        from_node_labels: List[str],
        relation_types: List[str],
        to_node_labels: List[str],
    ) -> bool:
        """
        Check whether any schema allows the given relationship.

        An empty list acts as a wildcard for that position. Backticks around
        labels and relationship types are ignored.

        Args:
            from_node_labels: labels of the from node
            relation_types: types of the relation
            to_node_labels: labels of the to node

        Returns:
            True if at least one schema satisfies every non-empty constraint.
        """
        from_labels = [label.strip("`") for label in from_node_labels]
        rel_types = [rel_type.strip("`") for rel_type in relation_types]
        to_labels = [label.strip("`") for label in to_node_labels]
        return any(
            (not from_labels or schema[0] in from_labels)
            and (not rel_types or schema[1] in rel_types)
            and (not to_labels or schema[2] in to_labels)
            for schema in self.schemas
        )

    def detect_relation_types(self, str_relation: str) -> Tuple[str, List[str]]:
        """
        Split a relation into its direction and its relationship types.

        Args:
            str_relation: relation in string format, e.g. ``-[:A|B]->``

        Returns:
            ``(direction, types)``. ``direction`` comes from
            ``judge_direction``. ``types`` holds the ``|``-separated type names
            with ``!`` stripped, or is empty when the relation has no type.
        """
        relation_direction = self.judge_direction(str_relation)
        relation_type = self.relation_type_pattern.search(str_relation)
        if relation_type is None or relation_type.group("relation_type") is None:
            return relation_direction, []
        relation_types = [
            t.strip().strip("!")
            for t in relation_type.group("relation_type").split("|")
        ]
        return relation_direction, relation_types

    def correct_query(self, query: str) -> str:
        """
        Reverse relationships whose direction contradicts the schemas.

        Args:
            query: cypher query

        Returns:
            The query with every wrongly directed relationship flipped, or an
            empty string if some relationship matches no schema in either
            direction. Relationships whose type text contains ``*``
            (variable-length) are not checked.
        """
        node_variable_dict = self.detect_node_variables(query)
        for path in self.extract_paths(query):
            start_idx = 0
            while start_idx < len(path):
                match_res = self.node_relation_node_pattern.match(path[start_idx:])
                if match_res is None:
                    break
                left_node = match_res.group("left_node")
                relation = match_res.group("relation")
                right_node = match_res.group("right_node")
                # Position of the right node, which starts the next hop.
                next_idx = start_idx + len(left_node) + len(relation) + 2

                relation_direction, relation_types = self.detect_relation_types(
                    relation
                )
                if any("*" in rel_type for rel_type in relation_types):
                    start_idx = next_idx
                    continue

                left_labels = self.detect_labels(left_node, node_variable_dict)
                right_labels = self.detect_labels(right_node, node_variable_dict)

                if relation_direction == "BIDIRECTIONAL":
                    if not (
                        self.verify_schema(left_labels, relation_types, right_labels)
                        or self.verify_schema(
                            right_labels, relation_types, left_labels
                        )
                    ):
                        return ""
                else:
                    # Normalise to (source, target) and pre-compute the
                    # relation text with its arrow pointing the other way.
                    if relation_direction == "OUTGOING":
                        source, target = left_labels, right_labels
                        corrected_relation = "<" + relation[:-1]
                    else:  # INCOMING
                        source, target = right_labels, left_labels
                        corrected_relation = relation[1:] + ">"
                    if not self.verify_schema(source, relation_types, target):
                        if not self.verify_schema(target, relation_types, source):
                            return ""
                        partial_path = match_res.group(0)
                        query = query.replace(
                            partial_path,
                            partial_path.replace(relation, corrected_relation),
                        )

                start_idx = next_idx
        return query

    def __call__(self, query: str) -> str:
        """Correct the query to make it valid.

        If some relationship cannot be matched to any schema in either
        direction, an empty string is returned instead of a query.

        Args:
            query: cypher query
        """
        return self.correct_query(query)
