"""A common module for NVIDIA Riva Runnables."""

import asyncio
import logging
import pathlib
import queue
import tempfile
import threading
import wave
from enum import Enum
from typing import (
    TYPE_CHECKING,
    Any,
    AsyncGenerator,
    AsyncIterator,
    Dict,
    Generator,
    Iterator,
    List,
    Optional,
    Tuple,
    Union,
    cast,
)

from langchain_core.messages import AnyMessage, BaseMessage
from langchain_core.prompt_values import PromptValue
from langchain_core.runnables import RunnableConfig, RunnableSerializable
from pydantic import (
    AnyHttpUrl,
    BaseModel,
    Field,
    parse_obj_as,
    root_validator,
    validator,
)

if TYPE_CHECKING:
    import riva.client
    import riva.client.proto.riva_asr_pb2 as rasr

_LOGGER = logging.getLogger(__name__)
_QUEUE_GET_TIMEOUT = 0.5
_MAX_TEXT_LENGTH = 400
_SENTENCE_TERMINATORS = ("\n", ".", "!", "?", "¡", "¿")


# COMMON utilities used by all Riva Runnables
def _import_riva_client() -> "riva.client":
    """Import the riva client and raise an error on failure."""
    try:
        # pylint: disable-next=import-outside-toplevel # this client library is optional
        import riva.client
    except ImportError as err:
        raise ImportError(
            "Could not import the NVIDIA Riva client library. "
            "Please install it with `pip install nvidia-riva-client`."
        ) from err
    return riva.client


class SentinelT:  # pylint: disable=too-few-public-methods
    """An empty Sentinel type."""


HANGUP = SentinelT()
_TRANSFORM_END = SentinelT()


class RivaAudioEncoding(str, Enum):
    """An enum of the possible choices for Riva audio encoding.

    The list of types exposed by the Riva GRPC Protobuf files can be found
    with the following commands:
    ```python
    import riva.client
    print(riva.client.AudioEncoding.keys())  # noqa: T201
    ```
    """

    ALAW = "ALAW"
    ENCODING_UNSPECIFIED = "ENCODING_UNSPECIFIED"
    FLAC = "FLAC"
    LINEAR_PCM = "LINEAR_PCM"
    MULAW = "MULAW"
    OGGOPUS = "OGGOPUS"

    @classmethod
    def from_wave_format_code(cls, format_code: int) -> "RivaAudioEncoding":
        """Return the audio encoding specified by the format code in the wave file.

        ref: https://mmsp.ece.mcgill.ca/Documents/AudioFormats/WAVE/WAVE.html
        """
        try:
            return {1: cls.LINEAR_PCM, 6: cls.ALAW, 7: cls.MULAW}[format_code]
        except KeyError as err:
            raise NotImplementedError(
                "The following wave file format code is "
                f"not supported by Riva: {format_code}"
            ) from err

    @property
    def riva_pb2(self) -> "riva.client.AudioEncoding":
        """Returns the Riva API object for the encoding."""
        riva_client = _import_riva_client()
        return getattr(riva_client.AudioEncoding, self)


class RivaAuthMixin(BaseModel):
    """Configuration for the authentication to a Riva service connection."""

    url: Union[AnyHttpUrl, str] = Field(
        AnyHttpUrl("http://localhost:50051"),
        description="The full URL where the Riva service can be found.",
        examples=["http://localhost:50051", "https://user@pass:riva.example.com"],
    )
    ssl_cert: Optional[str] = Field(
        None,
        description="A full path to the file where Riva's public ssl key can be read.",
    )

    @property
    def auth(self) -> "riva.client.Auth":
        """Return a riva client auth object."""
        riva_client = _import_riva_client()
        url = cast(AnyHttpUrl, self.url)
        use_ssl = url.scheme == "https"  # pylint: disable=no-member # false positive
        url_no_scheme = str(self.url).split("/")[2]
        return riva_client.Auth(
            ssl_cert=self.ssl_cert, use_ssl=use_ssl, uri=url_no_scheme
        )

    @validator("url", pre=True, allow_reuse=True)
    @classmethod
    def _validate_url(cls, val: Any) -> AnyHttpUrl:
        """Do some initial conversations for the URL before checking."""
        if isinstance(val, str):
            return cast(AnyHttpUrl, parse_obj_as(AnyHttpUrl, val))
        return cast(AnyHttpUrl, val)


class RivaCommonConfigMixin(BaseModel):
    """A collection of common Riva settings."""

    encoding: RivaAudioEncoding = Field(
        default=RivaAudioEncoding.LINEAR_PCM,
        description="The encoding on the audio stream.",
    )
    sample_rate_hertz: int = Field(
        default=8000, description="The sample rate frequency of audio stream."
    )
    language_code: str = Field(
        default="en-US",
        description=(
            "The [BCP-47 language code]"
            "(https://www.rfc-editor.org/rfc/bcp/bcp47.txt) for "
            "the target language."
        ),
    )


class _Event:
    """A combined event that is threadsafe and async safe."""

    _event: threading.Event
    _aevent: asyncio.Event

    def __init__(self) -> None:
        """Initialize the event."""
        self._event = threading.Event()
        self._aevent = asyncio.Event()

    def set(self) -> None:
        """Set the event."""
        self._event.set()
        self._aevent.set()

    def clear(self) -> None:
        """Set the event."""
        self._event.clear()
        self._aevent.clear()

    def is_set(self) -> bool:
        """Indicate if the event is set."""
        return self._event.is_set()

    def wait(self) -> None:
        """Wait for the event to be set."""
        self._event.wait()

    async def async_wait(self) -> None:
        """Async wait for the event to be set."""
        await self._aevent.wait()


def _mk_wave_file(
    output_directory: Optional[str], sample_rate: float
) -> Tuple[Optional[str], Optional[wave.Wave_write]]:
    """Create a new wave file and return the wave write object and filename."""
    if output_directory:
        with tempfile.NamedTemporaryFile(
            mode="bx", suffix=".wav", delete=False, dir=output_directory
        ) as f:
            wav_file_name = f.name
        wav_file = wave.open(wav_file_name, "wb")
        wav_file.setnchannels(1)
        wav_file.setsampwidth(2)
        wav_file.setframerate(sample_rate)
        return (wav_file_name, wav_file)
    return (None, None)


def _coerce_string(val: "TTSInputType") -> str:
    """Attempt to coerce the input value to a string.

    This is particularly useful for converting LangChain message to strings.
    """
    if isinstance(val, PromptValue):
        return val.to_string()
    if isinstance(val, BaseMessage):
        return str(val.content)
    return str(val)


def _process_chunks(inputs: Iterator["TTSInputType"]) -> Generator[str, None, None]:
    """Filter the input chunks are return strings ready for TTS."""
    buffer = ""
    for chunk in inputs:
        chunk = _coerce_string(chunk)

        # return the buffer if an end of sentence character is detected
        for terminator in _SENTENCE_TERMINATORS:
            while terminator in chunk:
                last_sentence, chunk = chunk.split(terminator, 1)
                yield buffer + last_sentence + terminator
                buffer = ""

        buffer += chunk

        # return the buffer if is too long
        if len(buffer) > _MAX_TEXT_LENGTH:
            for idx in range(0, len(buffer), _MAX_TEXT_LENGTH):
                yield buffer[idx : idx + 5]
            buffer = ""

    # return remaining buffer
    if buffer:
        yield buffer


# Riva AudioStream Type
StreamInputType = Union[bytes, SentinelT]
StreamOutputType = str


class AudioStream:
    """A message containing streaming audio."""

    _put_lock: threading.Lock
    _queue: queue.Queue
    output: queue.Queue
    hangup: _Event
    user_talking: _Event
    user_quiet: _Event
    _worker: Optional[threading.Thread]

    def __init__(self, maxsize: int = 0) -> None:
        """Initialize the queue."""
        self._put_lock = threading.Lock()
        self._queue = queue.Queue(maxsize=maxsize)
        self.output = queue.Queue()
        self.hangup = _Event()
        self.user_quiet = _Event()
        self.user_talking = _Event()
        self._worker = None

    def __iter__(self) -> Generator[bytes, None, None]:
        """Return an error."""
        while True:
            # get next item
            try:
                next_val = self._queue.get(True, _QUEUE_GET_TIMEOUT)
            except queue.Empty:
                continue

            # hangup when requested
            if next_val == HANGUP:
                break

            # yield next item
            yield next_val
            self._queue.task_done()

    async def __aiter__(self) -> AsyncIterator[StreamInputType]:
        """Iterate through all items in the queue until HANGUP."""
        while True:
            # get next item
            try:
                next_val = await asyncio.get_event_loop().run_in_executor(
                    None, self._queue.get, True, _QUEUE_GET_TIMEOUT
                )
            except queue.Empty:
                continue

            # hangup when requested
            if next_val == HANGUP:
                break

            # yield next item
            yield next_val
            self._queue.task_done()

    @property
    def hungup(self) -> bool:
        """Indicate if the audio stream has hungup."""
        return self.hangup.is_set()

    @property
    def empty(self) -> bool:
        """Indicate in the input stream buffer is empty."""
        return self._queue.empty()

    @property
    def complete(self) -> bool:
        """Indicate if the audio stream has hungup and been processed."""
        input_done = self.hungup and self.empty
        output_done = (
            self._worker is not None
            and not self._worker.is_alive()
            and self.output.empty()
        )
        return input_done and output_done

    @property
    def running(self) -> bool:
        """Indicate if the ASR stream is running."""
        if self._worker:
            return self._worker.is_alive()
        return False

    def put(self, item: StreamInputType, timeout: Optional[int] = None) -> None:
        """Put a new item into the queue."""
        with self._put_lock:
            if self.hungup:
                raise RuntimeError(
                    "The audio stream has already been hungup. Cannot put more data."
                )
            if item is HANGUP:
                self.hangup.set()
            self._queue.put(item, timeout=timeout)

    async def aput(self, item: StreamInputType, timeout: Optional[int] = None) -> None:
        """Async put a new item into the queue."""
        loop = asyncio.get_event_loop()
        await asyncio.wait_for(loop.run_in_executor(None, self.put, item), timeout)

    def close(self, timeout: Optional[int] = None) -> None:
        """Send the hangup signal."""
        self.put(HANGUP, timeout)

    async def aclose(self, timeout: Optional[int] = None) -> None:
        """Async send the hangup signal."""
        await self.aput(HANGUP, timeout)

    def register(self, responses: Iterator["rasr.StreamingRecognizeResponse"]) -> None:
        """Drain the responses from the provided iterator and put them into a queue."""
        if self.running:
            raise RuntimeError("An ASR instance has already been registered.")

        has_started = threading.Barrier(2, timeout=5)

        def worker() -> None:
            """Consume the ASR Generator."""
            has_started.wait()
            for response in responses:
                if not response.results:
                    continue

                for result in response.results:
                    if not result.alternatives:
                        continue

                    if result.is_final:
                        self.user_talking.clear()
                        self.user_quiet.set()
                        transcript = cast(str, result.alternatives[0].transcript)
                        self.output.put(transcript)

                    elif not self.user_talking.is_set():
                        self.user_talking.set()
                        self.user_quiet.clear()

        self._worker = threading.Thread(target=worker)
        self._worker.daemon = True
        self._worker.start()
        has_started.wait()


# RivaASR Runnable
ASRInputType = AudioStream
ASROutputType = str


class RivaASR(  # type: ignore[override]
    RivaAuthMixin,
    RivaCommonConfigMixin,
    RunnableSerializable[ASRInputType, ASROutputType],
):
    """A runnable that performs Automatic Speech Recognition (ASR) using NVIDIA Riva."""

    name: str = "nvidia_riva_asr"
    description: str = (
        "A Runnable for converting audio bytes to a string."
        "This is useful for feeding an audio stream into a chain and"
        "preprocessing that audio to create an LLM prompt."
    )

    # riva options
    audio_channel_count: int = Field(
        1, description="The number of audio channels in the input audio stream."
    )
    profanity_filter: bool = Field(
        True,
        description=(
            "Controls whether or not Riva should attempt to filter "
            "profanity out of the transcribed text."
        ),
    )
    enable_automatic_punctuation: bool = Field(
        True,
        description=(
            "Controls whether Riva should attempt to correct "
            "senetence puncuation in the transcribed text."
        ),
    )

    @root_validator(pre=True)
    @classmethod
    def _validate_environment(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Validate the Python environment and input arguments."""
        _ = _import_riva_client()
        return values

    @property
    def config(self) -> "riva.client.StreamingRecognitionConfig":
        """Create and return the riva config object."""
        riva_client = _import_riva_client()
        return riva_client.StreamingRecognitionConfig(
            interim_results=True,
            config=riva_client.RecognitionConfig(
                encoding=self.encoding,
                sample_rate_hertz=self.sample_rate_hertz,
                audio_channel_count=self.audio_channel_count,
                max_alternatives=1,
                profanity_filter=self.profanity_filter,
                enable_automatic_punctuation=self.enable_automatic_punctuation,
                language_code=self.language_code,
            ),
        )

    def _get_service(self) -> "riva.client.ASRService":
        """Connect to the riva service and return the a client object."""
        riva_client = _import_riva_client()
        try:
            return riva_client.ASRService(self.auth)
        except Exception as err:
            raise ValueError(
                "Error raised while connecting to the Riva ASR server."
            ) from err

    def invoke(
        self,
        input: ASRInputType,
        config: Optional[RunnableConfig] = None,
        **kwargs: Any,
    ) -> ASROutputType:
        """Transcribe the audio bytes into a string with Riva."""
        # create an output text generator with Riva
        if not input.running:
            service = self._get_service()
            responses = service.streaming_response_generator(
                audio_chunks=input,
                streaming_config=self.config,
            )
            input.register(responses)

        # return the first valid result
        full_response: List[str] = []
        while not input.complete:
            with input.output.not_empty:
                ready = input.output.not_empty.wait(0.1)

            if ready:
                while not input.output.empty():
                    try:
                        full_response += [input.output.get_nowait()]
                    except queue.Empty:
                        continue
                    input.output.task_done()
                _LOGGER.debug("Riva ASR returning: %s", repr(full_response))
                return " ".join(full_response).strip()

        return ""


# RivaTTS Runnable
# pylint: disable-next=invalid-name
TTSInputType = Union[str, AnyMessage, PromptValue]
TTSOutputType = bytes


class RivaTTS(  # type: ignore[override]
    RivaAuthMixin,
    RivaCommonConfigMixin,
    RunnableSerializable[TTSInputType, TTSOutputType],
):
    """A runnable that performs Text-to-Speech (TTS) with NVIDIA Riva."""

    name: str = "nvidia_riva_tts"
    description: str = (
        "A tool for converting text to speech."
        "This is useful for converting LLM output into audio bytes."
    )

    # riva options
    voice_name: str = Field(
        "English-US.Female-1",
        description=(
            "The voice model in Riva to use for speech. "
            "Pre-trained models are documented in "
            "[the Riva documentation]"
            "(https://docs.nvidia.com/deeplearning/riva/user-guide/docs/tts/tts-overview.html)."
        ),
    )
    output_directory: Optional[str] = Field(
        None,
        description=(
            "The directory where all audio files should be saved. "
            "A null value indicates that wave files should not be saved. "
            "This is useful for debugging purposes."
        ),
    )

    @root_validator(pre=True)
    @classmethod
    def _validate_environment(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Validate the Python environment and input arguments."""
        _ = _import_riva_client()
        return values

    @validator("output_directory")
    @classmethod
    def _output_directory_validator(cls, v: str) -> str:
        if v:
            dirpath = pathlib.Path(v)
            dirpath.mkdir(parents=True, exist_ok=True)
            return str(dirpath.absolute())
        return v

    def _get_service(self) -> "riva.client.SpeechSynthesisService":
        """Connect to the riva service and return the a client object."""
        riva_client = _import_riva_client()
        try:
            return riva_client.SpeechSynthesisService(self.auth)
        except Exception as err:
            raise ValueError(
                "Error raised while connecting to the Riva TTS server."
            ) from err

    def invoke(
        self,
        input: TTSInputType,
        config: Optional[RunnableConfig] = None,
        **kwargs: Any,
    ) -> TTSOutputType:
        """Perform TTS by taking a string and outputting the entire audio file."""
        return b"".join(self.transform(iter([input])))

    def transform(
        self,
        input: Iterator[TTSInputType],
        config: Optional[RunnableConfig] = None,
        **kwargs: Optional[Any],
    ) -> Iterator[TTSOutputType]:
        """Perform TTS by taking a stream of characters and streaming output bytes."""
        service = self._get_service()

        # create an output wave file
        wav_file_name, wav_file = _mk_wave_file(
            self.output_directory, self.sample_rate_hertz
        )

        # split the input text and perform tts
        for chunk in _process_chunks(input):
            _LOGGER.debug("Riva TTS chunk: %s", chunk)

            # start riva tts streaming
            responses = service.synthesize_online(
                text=chunk,
                voice_name=self.voice_name,
                language_code=self.language_code,
                encoding=self.encoding.riva_pb2,
                sample_rate_hz=self.sample_rate_hertz,
            )

            # stream audio bytes out
            for resp in responses:
                audio = cast(bytes, resp.audio)
                if wav_file:
                    wav_file.writeframesraw(audio)
                yield audio

        # close the wave file when we are done
        if wav_file:
            wav_file.close()
            _LOGGER.debug("Riva TTS wrote file: %s", wav_file_name)

    async def atransform(
        self,
        input: AsyncIterator[TTSInputType],
        config: Optional[RunnableConfig] = None,
        **kwargs: Optional[Any],
    ) -> AsyncGenerator[TTSOutputType, None]:
        """Intercept async transforms and route them to the synchronous transform."""
        loop = asyncio.get_running_loop()
        input_queue: queue.Queue = queue.Queue()
        out_queue: asyncio.Queue = asyncio.Queue()

        async def _producer() -> None:
            """Produce input into the input queue."""
            async for val in input:
                input_queue.put_nowait(val)
            input_queue.put_nowait(_TRANSFORM_END)

        def _input_iterator() -> Iterator[TTSInputType]:
            """Iterate over the input_queue."""
            while True:
                try:
                    val = input_queue.get(timeout=0.5)
                except queue.Empty:
                    continue
                if val == _TRANSFORM_END:
                    break
                yield val

        def _consumer() -> None:
            """Consume the input with transform."""
            for val in self.transform(_input_iterator()):
                out_queue.put_nowait(val)
            out_queue.put_nowait(_TRANSFORM_END)

        async def _consumer_coro() -> None:
            """Coroutine that wraps the consumer."""
            await loop.run_in_executor(None, _consumer)

        producer = loop.create_task(_producer())
        consumer = loop.create_task(_consumer_coro())

        while True:
            try:
                val = await asyncio.wait_for(out_queue.get(), 0.5)
            except asyncio.exceptions.TimeoutError:
                continue
            out_queue.task_done()

            if val is _TRANSFORM_END:
                break
            yield val

        await producer
        await consumer


# Backwards compatibility:
NVIDIARivaASR = RivaASR
NVIDIARivaTTS = RivaTTS
NVIDIARivaStream = AudioStream
