"""Pydantic models for parsing an OpenAPI spec."""

from __future__ import annotations

import logging
from enum import Enum
from typing import (
    TYPE_CHECKING,
    Any,
    Dict,
    List,
    Optional,
    Sequence,
    Tuple,
    Type,
    Union,
)

from pydantic import BaseModel, Field

from langchain_community.tools.openapi.utils.openapi_utils import HTTPVerb, OpenAPISpec

# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# OpenAPI / JSON Schema primitive type names mapped to Python counterparts.
# In the code visible in this module the table is only consulted for key
# membership (``x in PRIMITIVE_TYPES``); the mapped values are not read here.
# Note that "null" maps to the value ``None``, not to a class.
PRIMITIVE_TYPES = {
    "integer": int,
    "number": float,
    "string": str,
    "boolean": bool,
    "array": List,
    "object": Dict,
    "null": None,
}


# See https://github.com/OAI/OpenAPI-Specification/blob/main/versions/3.1.0.md#parameterIn
# for more info.
class APIPropertyLocation(Enum):
    """The location of the property.

    Member values are the strings used by the OpenAPI ``in`` field of a
    Parameter Object.
    """

    QUERY = "query"
    PATH = "path"
    HEADER = "header"
    COOKIE = "cookie"  # Not yet supported

    @classmethod
    def from_str(cls, location: str) -> "APIPropertyLocation":
        """Parse an APIPropertyLocation from its string value.

        Args:
            location: A member value such as ``"query"`` or ``"path"``.

        Returns:
            The member whose value equals ``location``.

        Raises:
            ValueError: If ``location`` is not a member value. The failed
                enum lookup is attached as ``__cause__``.
        """
        try:
            return cls(location)
        except ValueError as err:
            # List the accepted *values* (what callers must pass), not the
            # name -> member mapping, and keep the original error chained.
            valid_values = [member.value for member in cls]
            raise ValueError(
                f"Invalid APIPropertyLocation. Valid values are {valid_values}"
            ) from err


# Media types whose request-body schemas this module parses.
_SUPPORTED_MEDIA_TYPES = ("application/json",)

# Parameter locations that can be converted into an APIProperty.
SUPPORTED_LOCATIONS = {
    APIPropertyLocation.HEADER,
    APIPropertyLocation.QUERY,
    APIPropertyLocation.PATH,
}
# Error-message template; fill in with ``.format(location=..., name=...)``.
# The supported values are sorted because iterating a set of Enum members
# follows hash order, which can differ from one interpreter run to the next;
# sorting keeps the message identical across runs.
INVALID_LOCATION_TEMPL = (
    'Unsupported APIPropertyLocation "{location}"'
    " for parameter {name}. "
    + f"Valid values are {sorted(loc.value for loc in SUPPORTED_LOCATIONS)}"
)

# Accepted representations of a property's type in this module: a type name
# string, a class (including generated Enum classes), a tuple, or None.
SCHEMA_TYPE = Union[str, Type, tuple, None, Enum]


class APIPropertyBase(BaseModel):
    """Base model for an API property.

    Declares the fields common to the property models in this module:
    ``name``, ``required``, ``type``, ``default`` and ``description``.
    Every field's alias is identical to its attribute name.
    """

    # The name of the parameter is required and is case-sensitive.
    # If "in" is "path", the "name" field must correspond to a template expression
    # within the path field in the Paths Object.
    # If "in" is "header" and the "name" field is "Accept", "Content-Type",
    # or "Authorization", the parameter definition is ignored.
    # For all other cases, the "name" corresponds to the parameter
    # name used by the "in" property.
    name: str = Field(alias="name")
    """The name of the property."""

    # No default is declared, so a value must be supplied.
    required: bool = Field(alias="required")
    """Whether the property is required."""

    # SCHEMA_TYPE admits a string, a class, a tuple, an Enum or None.
    type: SCHEMA_TYPE = Field(alias="type")
    """The type of the property.
    
    Either a primitive type, a component/parameter type,
    or an array or 'object' (dict) of the above."""

    # Optional; defaults to None when not supplied.
    default: Optional[Any] = Field(alias="default", default=None)
    """The default value of the property."""

    # Optional; defaults to None when not supplied.
    description: Optional[str] = Field(alias="description", default=None)
    """The description of the property."""


if TYPE_CHECKING:
    from openapi_pydantic import (
        MediaType,
        Parameter,
        RequestBody,
        Schema,
    )


class APIProperty(APIPropertyBase):
    """A model for a property in the query, path, header, or cookie params.

    Instances are normally built with :meth:`from_parameter`, which validates
    the parameter location, resolves its schema and derives a type from it.
    """

    location: APIPropertyLocation = Field(alias="location")
    """The path/how it's being passed to the endpoint."""

    @staticmethod
    def _cast_schema_list_type(
        schema: Schema,
    ) -> Optional[Union[str, Tuple[str, ...]]]:
        """Return ``schema.type``, turning a list of types into a tuple.

        Non-list values (including ``None``) are returned unchanged.
        """
        type_ = schema.type
        if not isinstance(type_, list):
            return type_
        else:
            return tuple(type_)

    @staticmethod
    def _get_schema_type_for_enum(parameter: Parameter, schema: Schema) -> Type[Enum]:
        """Build an Enum class for a parameter whose schema lists ``enum`` values.

        The class is named ``"<parameter name>Enum"``; each member's name is
        ``str(value)`` and its value is the original enum entry.
        """
        param_name = f"{parameter.name}Enum"
        return Enum(param_name, {str(v): v for v in schema.enum})

    @staticmethod
    def _get_schema_type_for_array(
        schema: Schema,
    ) -> Optional[Union[str, Tuple[str, ...]]]:
        """Derive the type of an array schema from its ``items``.

        Returns:
            A tuple when the item type is a string or a list of types
            (a single string is wrapped in a 1-tuple), otherwise whatever
            the item schema's ``type`` is (possibly ``None``).

        Raises:
            ValueError: If ``items`` is neither a Schema nor a Reference.
        """
        from openapi_pydantic import (
            Reference,
            Schema,
        )

        items = schema.items
        if isinstance(items, Schema):
            schema_type = APIProperty._cast_schema_list_type(items)
        elif isinstance(items, Reference):
            # Use the last path segment of the reference as the type name.
            ref_name = items.ref.split("/")[-1]
            schema_type = ref_name  # TODO: Add ref definitions to make this valid
        else:
            raise ValueError(f"Unsupported array items: {items}")

        if isinstance(schema_type, str):
            # TODO: recurse
            schema_type = (schema_type,)

        return schema_type

    @staticmethod
    def _get_schema_type(parameter: Parameter, schema: Optional[Schema]) -> SCHEMA_TYPE:
        """Derive the property type from a resolved parameter schema.

        Returns:
            ``None`` when there is no schema, the array item type for arrays,
            a generated Enum class for primitives with ``enum`` values, or
            the primitive type name itself.

        Raises:
            NotImplementedError: For ``object`` schemas and for any type
                that is not a key of ``PRIMITIVE_TYPES``.
            ValueError: Propagated from array handling for unsupported items.
        """
        if schema is None:
            return None
        schema_type: SCHEMA_TYPE = APIProperty._cast_schema_list_type(schema)
        if schema_type == "array":
            schema_type = APIProperty._get_schema_type_for_array(schema)
        elif schema_type == "object":
            # TODO: Resolve array and object types to components.
            raise NotImplementedError("Objects not yet supported")
        elif schema_type in PRIMITIVE_TYPES:
            if schema.enum:
                schema_type = APIProperty._get_schema_type_for_enum(parameter, schema)
            else:
                # Directly use the primitive type
                pass
        else:
            raise NotImplementedError(f"Unsupported type: {schema_type}")

        return schema_type

    @staticmethod
    def _validate_location(location: APIPropertyLocation, name: str) -> None:
        """Raise NotImplementedError if ``location`` is not supported.

        The message shows the location's plain value (e.g. ``cookie``), the
        same form used where the template is filled from a raw ``in`` string.
        """
        if location not in SUPPORTED_LOCATIONS:
            # Fall back to the object itself if it has no ``value`` attribute.
            shown_location = getattr(location, "value", location)
            raise NotImplementedError(
                INVALID_LOCATION_TEMPL.format(location=shown_location, name=name)
            )

    @staticmethod
    def _validate_content(content: Optional[Dict[str, MediaType]]) -> None:
        """Raise ValueError if the parameter declares any media ``content``."""
        if content:
            raise ValueError(
                "API Properties with media content not supported. "
                "Media content only supported within APIRequestBodyProperty's"
            )

    @staticmethod
    def _get_schema(parameter: Parameter, spec: OpenAPISpec) -> Optional[Schema]:
        """Return the parameter's schema, resolving a reference through ``spec``.

        Returns:
            The schema, or ``None`` when the parameter has none.

        Raises:
            ValueError: If the schema is neither a Reference, ``None`` nor a
                Schema instance.
        """
        from openapi_pydantic import (
            Reference,
            Schema,
        )

        schema = parameter.param_schema
        if isinstance(schema, Reference):
            schema = spec.get_referenced_schema(schema)
        elif schema is None:
            return None
        elif not isinstance(schema, Schema):
            raise ValueError(f"Error dereferencing schema: {schema}")

        return schema

    @staticmethod
    def is_supported_location(location: str) -> bool:
        """Return whether the provided location is supported.

        Unknown location strings yield ``False`` rather than raising.
        """
        try:
            return APIPropertyLocation.from_str(location) in SUPPORTED_LOCATIONS
        except ValueError:
            return False

    @classmethod
    def from_parameter(cls, parameter: Parameter, spec: OpenAPISpec) -> "APIProperty":
        """Instantiate from an OpenAPI Parameter.

        Args:
            parameter: The OpenAPI parameter to convert.
            spec: The spec used to resolve a referenced schema.

        Raises:
            ValueError: For an unknown location, media ``content`` on the
                parameter, or a schema that cannot be dereferenced.
            NotImplementedError: For an unsupported location or schema type.
        """
        location = APIPropertyLocation.from_str(parameter.param_in)
        cls._validate_location(
            location,
            parameter.name,
        )
        cls._validate_content(parameter.content)
        schema = cls._get_schema(parameter, spec)
        schema_type = cls._get_schema_type(parameter, schema)
        # The default comes from the schema; the parameter itself has none here.
        default_val = schema.default if schema is not None else None
        return cls(
            name=parameter.name,
            location=location,
            default=default_val,
            description=parameter.description,
            required=parameter.required,
            type=schema_type,
        )


class APIRequestBodyProperty(APIPropertyBase):
    """A model for a request body property.

    Built recursively from an OpenAPI schema via :meth:`from_schema`; object
    schemas contribute nested ``properties``.
    """

    properties: List["APIRequestBodyProperty"] = Field(alias="properties")
    """The sub-properties of the property."""

    # This is useful for handling nested property cycles.
    # We can define separate types in that case.
    references_used: List[str] = Field(alias="references_used")
    """The references used by the property."""

    @classmethod
    def _process_object_schema(
        cls, schema: Schema, spec: OpenAPISpec, references_used: List[str]
    ) -> Tuple[Union[str, List[str], None], List["APIRequestBodyProperty"]]:
        """Convert each property of an object schema into a model.

        A referenced property is resolved the first time its reference name
        is seen and recorded in ``references_used`` (mutated in place);
        a reference whose name is already recorded is skipped.

        Returns:
            ``(schema.type, properties)``.

        Raises:
            ValueError: If ``schema.properties`` is ``None``.
        """
        from openapi_pydantic import (
            Reference,
        )

        properties = []
        required_props = schema.required or []
        if schema.properties is None:
            raise ValueError(
                f"No properties found when processing object schema: {schema}"
            )
        for prop_name, prop_schema in schema.properties.items():
            if isinstance(prop_schema, Reference):
                ref_name = prop_schema.ref.split("/")[-1]
                if ref_name in references_used:
                    # Already expanded once; skip to avoid re-expanding it.
                    continue
                references_used.append(ref_name)
                prop_schema = spec.get_referenced_schema(prop_schema)

            properties.append(
                cls.from_schema(
                    schema=prop_schema,
                    name=prop_name,
                    required=prop_name in required_props,
                    spec=spec,
                    references_used=references_used,
                )
            )
        return schema.type, properties

    @classmethod
    def _process_array_schema(
        cls,
        schema: Schema,
        name: str,
        spec: OpenAPISpec,
        references_used: List[str],
    ) -> str:
        """Return a type string describing an array schema.

        Returns:
            ``"Array<RefName>"`` for referenced items, ``"Array<type>"`` for
            inline item schemas, and ``"array"`` when ``items`` is missing or
            of another kind.
        """
        from openapi_pydantic import Reference, Schema

        items = schema.items
        if items is None:
            return "array"

        if isinstance(items, Reference):
            ref_name = items.ref.split("/")[-1]
            if ref_name not in references_used:
                references_used.append(ref_name)
                # The lookup result is not used (the array is described by
                # the reference name); the call is retained so any error it
                # raises still surfaces.
                spec.get_referenced_schema(items)
            return f"Array<{ref_name}>"

        if isinstance(items, Schema):
            array_type = cls.from_schema(
                schema=items,
                name=f"{name}Item",
                required=True,  # TODO: Add required
                spec=spec,
                references_used=references_used,
            )
            return f"Array<{array_type.type}>"

        return "array"

    @classmethod
    def from_schema(
        cls,
        schema: Schema,
        name: str,
        required: bool,
        spec: OpenAPISpec,
        references_used: Optional[List[str]] = None,
    ) -> "APIRequestBodyProperty":
        """Recursively populate from an OpenAPI Schema.

        Args:
            schema: The schema to convert.
            name: The property name.
            required: Whether the property is required.
            spec: The spec used to resolve references.
            references_used: Reference names already expanded; a new list is
                created when omitted. The list is mutated during processing.

        Raises:
            ValueError: If the schema type is unsupported. This includes a
                list-valued ``type``, which would otherwise fail the
                ``PRIMITIVE_TYPES`` lookup with an unhashable-type TypeError.
        """
        if references_used is None:
            references_used = []

        schema_type = schema.type
        properties: List[APIRequestBodyProperty] = []
        if schema_type == "object" and schema.properties:
            schema_type, properties = cls._process_object_schema(
                schema, spec, references_used
            )
        elif schema_type == "array":
            schema_type = cls._process_array_schema(schema, name, spec, references_used)
        elif isinstance(schema_type, list):
            # A list cannot be looked up in a dict; report it the same way
            # as every other unsupported type.
            raise ValueError(f"Unsupported type: {schema_type}")
        elif schema_type in PRIMITIVE_TYPES:
            # Use the primitive type directly
            pass
        elif schema_type is None:
            # No typing specified/parsed. Will map to 'any'
            pass
        else:
            raise ValueError(f"Unsupported type: {schema_type}")

        return cls(
            name=name,
            required=required,
            type=schema_type,
            default=schema.default,
            description=schema.description,
            properties=properties,
            references_used=references_used,
        )


# Model nesting: APIRequestBody holds a list of APIRequestBodyProperty models.
class APIRequestBody(BaseModel):
    """A model for a request body.

    Built with :meth:`from_request_body`, which parses the schemas of the
    supported media types into properties.
    """

    description: Optional[str] = Field(alias="description")
    """The description of the request body."""

    properties: List[APIRequestBodyProperty] = Field(alias="properties")

    # E.g., application/json - we only support JSON at the moment.
    media_type: str = Field(alias="media_type")
    """The media type of the request body."""

    @classmethod
    def _process_supported_media_type(
        cls,
        media_type_obj: MediaType,
        spec: OpenAPISpec,
    ) -> List[APIRequestBodyProperty]:
        """Process the media type of the request body.

        An object schema with properties yields one model per property;
        any other schema yields a single property named ``"body"``.

        Raises:
            ValueError: If the media type has no schema, or a referenced
                schema resolves to ``None``.
        """
        from openapi_pydantic import Reference

        references_used: List[str] = []
        schema = media_type_obj.media_type_schema
        if isinstance(schema, Reference):
            references_used.append(schema.ref.split("/")[-1])
            schema = spec.get_referenced_schema(schema)
        if schema is None:
            raise ValueError(
                f"Could not resolve schema for media type: {media_type_obj}"
            )
        api_request_body_properties = []
        required_properties = schema.required or []
        if schema.type == "object" and schema.properties:
            for prop_name, prop_schema in schema.properties.items():
                if isinstance(prop_schema, Reference):
                    prop_schema = spec.get_referenced_schema(prop_schema)

                # Each top-level property starts with its own fresh
                # reference-tracking list (none is passed in).
                api_request_body_properties.append(
                    APIRequestBodyProperty.from_schema(
                        schema=prop_schema,
                        name=prop_name,
                        required=prop_name in required_properties,
                        spec=spec,
                    )
                )
        else:
            api_request_body_properties.append(
                APIRequestBodyProperty(
                    name="body",
                    required=True,
                    type=schema.type,
                    default=schema.default,
                    description=schema.description,
                    properties=[],
                    references_used=references_used,
                )
            )

        return api_request_body_properties

    @classmethod
    def from_request_body(
        cls, request_body: RequestBody, spec: OpenAPISpec
    ) -> "APIRequestBody":
        """Instantiate from an OpenAPI RequestBody.

        Only media types in ``_SUPPORTED_MEDIA_TYPES`` are parsed. The
        reported ``media_type`` is the first supported one that was parsed;
        if none is supported, the last media type listed is reported and
        ``properties`` is empty.

        Raises:
            ValueError: If the request body lists no media types at all, or
                a supported media type's schema cannot be resolved.
        """
        properties = []
        parsed_media_type: Optional[str] = None
        last_skipped_media_type: Optional[str] = None
        for media_type, media_type_obj in request_body.content.items():
            if media_type not in _SUPPORTED_MEDIA_TYPES:
                last_skipped_media_type = media_type
                continue
            if parsed_media_type is None:
                parsed_media_type = media_type
            api_request_body_properties = cls._process_supported_media_type(
                media_type_obj,
                spec,
            )
            properties.extend(api_request_body_properties)

        # Prefer the media type the properties were actually parsed from,
        # rather than whichever key happened to be iterated last.
        reported_media_type = (
            parsed_media_type
            if parsed_media_type is not None
            else last_skipped_media_type
        )
        if reported_media_type is None:
            raise ValueError(
                "Request body does not declare any media type in its content."
            )

        return cls(
            description=request_body.description,
            properties=properties,
            media_type=reported_media_type,
        )


# Model nesting: APIOperation -> APIRequestBody -> APIRequestBodyProperty
# (each model holds the next one in a field).
class APIOperation(BaseModel):
    """A model for a single API operation.

    Combines the operation's identity (id, base URL, path, method) with its
    parameter properties and optional request body, and can render itself
    as a TypeScript-style type definition.
    """

    operation_id: str = Field(alias="operation_id")
    """The unique identifier of the operation."""

    description: Optional[str] = Field(alias="description")
    """The description of the operation."""

    base_url: str = Field(alias="base_url")
    """The base URL of the operation."""

    path: str = Field(alias="path")
    """The path of the operation."""

    method: HTTPVerb = Field(alias="method")
    """The HTTP method of the operation."""

    properties: Sequence[APIProperty] = Field(alias="properties")

    # TODO: Add parse in used components to be able to specify what type of
    # referenced object it is.
    # """The properties of the operation."""
    # components: Dict[str, BaseModel] = Field(alias="components")

    request_body: Optional[APIRequestBody] = Field(alias="request_body")
    """The request body of the operation."""

    @staticmethod
    def _get_properties_from_parameters(
        parameters: List[Parameter], spec: OpenAPISpec
    ) -> List[APIProperty]:
        """Convert OpenAPI parameters into APIProperty models.

        Parameters in an unsupported location raise ValueError when they
        are required and are logged and skipped when they are optional.
        """
        collected = []
        for parameter in parameters:
            if APIProperty.is_supported_location(parameter.param_in):
                collected.append(APIProperty.from_parameter(parameter, spec))
                continue

            if parameter.required:
                raise ValueError(
                    INVALID_LOCATION_TEMPL.format(
                        location=parameter.param_in, name=parameter.name
                    )
                )

            unsupported_msg = INVALID_LOCATION_TEMPL.format(
                location=parameter.param_in, name=parameter.name
            )
            logger.warning(unsupported_msg + " Ignoring optional parameter")
        return collected

    @classmethod
    def from_openapi_url(
        cls,
        spec_url: str,
        path: str,
        method: str,
    ) -> "APIOperation":
        """Load the spec at ``spec_url`` and build the operation from it."""
        return cls.from_openapi_spec(OpenAPISpec.from_url(spec_url), path, method)

    @classmethod
    def from_openapi_spec(
        cls,
        spec: OpenAPISpec,
        path: str,
        method: str,
    ) -> "APIOperation":
        """Build an APIOperation for ``method`` on ``path`` of ``spec``.

        The description is taken from the operation (description, then
        summary), then from the path item, and finally defaults to ``""``.
        """
        operation = spec.get_operation(path, method)
        properties = cls._get_properties_from_parameters(
            spec.get_parameters_for_operation(operation), spec
        )
        operation_id = OpenAPISpec.get_cleaned_operation_id(operation, path, method)

        raw_request_body = spec.get_request_body_for_operation(operation)
        api_request_body = None
        if raw_request_body is not None:
            api_request_body = APIRequestBody.from_request_body(raw_request_body, spec)

        description = operation.description or operation.summary
        if not description and spec.paths is not None:
            # Fall back to the path item's own description or summary.
            path_item = spec.paths[path]
            description = path_item.description or path_item.summary

        return cls(
            operation_id=operation_id,
            description=description or "",
            base_url=spec.base_url,
            path=path,
            method=method,  # type: ignore[arg-type]
            properties=properties,
            request_body=api_request_body,
        )

    @staticmethod
    def ts_type_from_python(type_: SCHEMA_TYPE) -> str:
        """Translate a parsed schema type into a TypeScript type string."""
        if type_ is None:
            # TODO: Handle Nones better. These often result when
            # parsing specs that are < v3
            return "any"

        if isinstance(type_, str):
            renamed = {
                "str": "string",
                "integer": "number",
                "float": "number",
                "date-time": "string",
            }
            # Names without an entry pass through unchanged.
            return renamed.get(type_, type_)

        if isinstance(type_, tuple):
            # Only the first element determines the array's item type.
            item_type = APIOperation.ts_type_from_python(type_[0])
            return f"Array<{item_type}>"

        if isinstance(type_, type) and issubclass(type_, Enum):
            # Enum classes become a union of quoted member values.
            return " | ".join(f"'{member.value}'" for member in type_)

        return str(type_)

    def _format_nested_properties(
        self, properties: List[APIRequestBodyProperty], indent: int = 2
    ) -> str:
        """Render request-body properties as indented TypeScript fields.

        Properties with sub-properties are rendered as an inline object
        literal, indented two further spaces per nesting level.
        """
        pad = " " * indent
        rendered = []

        for item in properties:
            ts_type = self.ts_type_from_python(item.type)
            if item.properties:
                children = self._format_nested_properties(item.properties, indent + 2)
                ts_type = f"{{\n{children}\n{pad}}}"

            optional_marker = "" if item.required else "?"
            comment = f"/* {item.description} */" if item.description else ""
            rendered.append(
                f"{comment}\n{pad}{item.name}{optional_marker}: {ts_type},"
            )

        return "\n".join(rendered)

    def to_typescript(self) -> str:
        """Get typescript string representation of the operation."""
        segments = []

        # Request-body fields come first, followed by the parameters.
        if self.request_body:
            segments.append(
                self._format_nested_properties(self.request_body.properties)
            )

        for parameter in self.properties:
            ts_type = self.ts_type_from_python(parameter.type)
            optional_marker = "" if parameter.required else "?"
            comment = f"/* {parameter.description} */" if parameter.description else ""
            segments.append(
                f"{comment}\n\t\t{parameter.name}{optional_marker}: {ts_type},"
            )

        body = "\n".join(segments).strip()
        header = f"/* {self.description} */" if self.description else ""
        pieces = [
            header,
            f"type {self.operation_id} = (_: {{",
            body,
            "}) => any;",
        ]
        return "\n".join(pieces).strip()

    @property
    def query_params(self) -> List[str]:
        """Names of the properties passed in the query string."""
        wanted = APIPropertyLocation.QUERY
        return [prop.name for prop in self.properties if prop.location == wanted]

    @property
    def path_params(self) -> List[str]:
        """Names of the properties passed in the URL path."""
        wanted = APIPropertyLocation.PATH
        return [prop.name for prop in self.properties if prop.location == wanted]

    @property
    def body_params(self) -> List[str]:
        """Names of the top-level request-body properties, if any."""
        body = self.request_body
        if body is None:
            return []
        return [prop.name for prop in body.properties]
