#!/usr/bin/env python

# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os

from ...utils.constants import SAGEMAKER_PARALLEL_EC2_INSTANCES, TORCH_DYNAMO_MODES
from ...utils.dataclasses import ComputeEnvironment, SageMakerDistributedType
from ...utils.imports import is_boto3_available
from .config_args import SageMakerConfig
from .config_utils import (
    DYNAMO_BACKENDS,
    _ask_field,
    _ask_options,
    _convert_dynamo_backend,
    _convert_mixed_precision,
    _convert_sagemaker_distributed_mode,
    _convert_yes_no_to_bool,
)


if is_boto3_available():
    import boto3  # noqa: F401


def _create_iam_role_for_sagemaker(role_name):
    """
    Create an IAM role that the SageMaker service may assume and give it an inline permission policy.

    The role is created with a trust policy allowing `sagemaker.amazonaws.com` to call `sts:AssumeRole`. An inline
    policy named `{role_name}_policy_permission` is then put on the role, allowing the SageMaker, ECR, CloudWatch,
    CloudWatch Logs and S3 actions listed below on every resource.

    If IAM raises `EntityAlreadyExistsException`, a notice is printed and the function returns normally; the
    remaining statements of the `try` body are skipped in that case.

    Args:
        role_name (`str`): Name of the IAM role to create.
    """
    allowed_actions = [
        "sagemaker:*",
        "ecr:GetDownloadUrlForLayer",
        "ecr:BatchGetImage",
        "ecr:BatchCheckLayerAvailability",
        "ecr:GetAuthorizationToken",
        "cloudwatch:PutMetricData",
        "cloudwatch:GetMetricData",
        "cloudwatch:GetMetricStatistics",
        "cloudwatch:ListMetrics",
        "logs:CreateLogGroup",
        "logs:CreateLogStream",
        "logs:DescribeLogStreams",
        "logs:PutLogEvents",
        "logs:GetLogEvents",
        "s3:CreateBucket",
        "s3:ListBucket",
        "s3:GetBucketLocation",
        "s3:GetObject",
        "s3:PutObject",
    ]

    # Trust policy: who is allowed to assume the role (the SageMaker service principal).
    trust_document = json.dumps(
        {
            "Version": "2012-10-17",
            "Statement": [
                {"Effect": "Allow", "Principal": {"Service": "sagemaker.amazonaws.com"}, "Action": "sts:AssumeRole"}
            ],
        },
        indent=2,
    )
    # Permission policy: what the role is allowed to do once assumed.
    permission_document = json.dumps(
        {
            "Version": "2012-10-17",
            "Statement": [{"Effect": "Allow", "Action": allowed_actions, "Resource": "*"}],
        },
        indent=2,
    )

    client = boto3.client("iam")
    try:
        # Both calls share one `try` so that an already-existing role short-circuits the policy upload as well.
        client.create_role(RoleName=role_name, AssumeRolePolicyDocument=trust_document)
        client.put_role_policy(
            RoleName=role_name,
            PolicyName=f"{role_name}_policy_permission",
            PolicyDocument=permission_document,
        )
    except client.exceptions.EntityAlreadyExistsException:
        print(f"role {role_name} already exists. Using existing one")


def _get_iam_role_arn(role_name):
    """
    Look up an IAM role by name and return its ARN.

    Args:
        role_name (`str`): Name of the IAM role to query.

    Returns:
        The value stored under `["Role"]["Arn"]` in the IAM `get_role` response.
    """
    response = boto3.client("iam").get_role(RoleName=role_name)
    role_description = response["Role"]
    return role_description["Arn"]


def _ask_yes_no(prompt):
    # Ask a yes/no question whose default answer is "no" (False), using the shared error message.
    return _ask_field(
        prompt,
        _convert_yes_no_to_bool,
        default=False,
        error_message="Please enter yes or no.",
    )


def get_sagemaker_input():
    """
    Interactively collect the settings needed to run a training job on Amazon SageMaker.

    The questionnaire covers, in order: AWS authorization, region, IAM role, optional custom Docker image, optional
    input-channel and metrics TSV files, distributed mode, torch dynamo options, EC2 instance type, debug checks,
    number of machines and mixed precision.

    Side effects:
        - Sets `AWS_PROFILE`, or `AWS_ACCESS_KEY_ID` and `AWS_SECRET_ACCESS_KEY`, in `os.environ` depending on the
          chosen authorization method.
        - Sets `AWS_DEFAULT_REGION` in `os.environ`.
        - Calls `_create_iam_role_for_sagemaker` when the user asks for a new IAM role.

    Returns:
        `SageMakerConfig`: The configuration built from the answers.
    """
    # --- AWS credentials -------------------------------------------------------------------------------------------
    credentials_configuration = _ask_options(
        "How do you want to authorize?",
        ["AWS Profile", "Credentials (AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY) "],
        int,
    )
    aws_profile = None
    if credentials_configuration == 0:
        aws_profile = _ask_field("Enter your AWS Profile name: [default] ", default="default")
        os.environ["AWS_PROFILE"] = aws_profile
    else:
        print(
            "Note you will need to provide AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY when you launch your training script with, "
            "`accelerate launch --aws_access_key_id XXX --aws_secret_access_key YYY`"
        )
        # The keys are only exported to the environment of this process; they are not part of the returned config.
        os.environ["AWS_ACCESS_KEY_ID"] = _ask_field("AWS Access Key ID: ")
        os.environ["AWS_SECRET_ACCESS_KEY"] = _ask_field("AWS Secret Access Key: ")

    aws_region = _ask_field("Enter your AWS Region: [us-east-1]", default="us-east-1")
    os.environ["AWS_DEFAULT_REGION"] = aws_region

    # --- IAM role --------------------------------------------------------------------------------------------------
    role_management = _ask_options(
        "Do you already have an IAM Role for executing Amazon SageMaker Training Jobs?",
        ["Provide IAM Role name", "Create new IAM role using credentials"],
        int,
    )
    if role_management == 0:
        iam_role_name = _ask_field("Enter your IAM role name: ")
    else:
        iam_role_name = "accelerate_sagemaker_execution_role"
        print(f'Accelerate will create an iam role "{iam_role_name}" using the provided credentials')
        _create_iam_role_for_sagemaker(iam_role_name)

    # --- Container image and optional TSV files --------------------------------------------------------------------
    docker_image = None
    if _ask_yes_no("Do you want to use custom Docker image? [yes/NO]: "):
        docker_image = _ask_field("Enter your Docker image: ", lambda x: str(x).lower())

    # File paths are kept exactly as typed: lower-casing them would point at a different (usually missing) file on
    # case-sensitive filesystems.
    sagemaker_inputs_file = None
    if _ask_yes_no("Do you want to provide SageMaker input channels with data locations? [yes/NO]: "):
        sagemaker_inputs_file = _ask_field(
            "Enter the path to the SageMaker inputs TSV file with columns (channel_name, data_location): ",
            str,
        )

    sagemaker_metrics_file = None
    if _ask_yes_no("Do you want to enable SageMaker metrics? [yes/NO]: "):
        sagemaker_metrics_file = _ask_field(
            "Enter the path to the SageMaker metrics TSV file with columns (metric_name, metric_regex): ",
            str,
        )

    # --- Distributed mode ------------------------------------------------------------------------------------------
    distributed_type = _ask_options(
        "What is the distributed mode?",
        ["No distributed training", "Data parallelism"],
        _convert_sagemaker_distributed_mode,
    )

    # --- Torch dynamo ----------------------------------------------------------------------------------------------
    dynamo_config = {}
    use_dynamo = _ask_yes_no("Do you wish to optimize your script with torch dynamo?[yes/NO]:")
    if use_dynamo:
        prefix = "dynamo_"
        dynamo_config[prefix + "backend"] = _ask_options(
            "Which dynamo backend would you like to use?",
            [x.lower() for x in DYNAMO_BACKENDS],
            _convert_dynamo_backend,
            default=2,
        )

        if _ask_yes_no("Do you want to customize the defaults sent to torch.compile? [yes/NO]: "):
            # The converter indexes TORCH_DYNAMO_MODES with `int(x)`, so the default has to be an option index,
            # like `default=2` for the backend menu above. Index 0 is assumed to be the "default" mode --
            # TODO confirm against TORCH_DYNAMO_MODES.
            dynamo_config[prefix + "mode"] = _ask_options(
                "Which mode do you want to use?",
                TORCH_DYNAMO_MODES,
                lambda x: TORCH_DYNAMO_MODES[int(x)],
                default=0,
            )
            dynamo_config[prefix + "use_fullgraph"] = _ask_yes_no(
                "Do you want the fullgraph mode or it is ok to break model into several subgraphs? [yes/NO]: "
            )
            dynamo_config[prefix + "use_dynamic"] = _ask_yes_no(
                "Do you want to enable dynamic shape tracing? [yes/NO]: "
            )

    # --- Hardware --------------------------------------------------------------------------------------------------
    ec2_instance_query = "Which EC2 instance type you want to use for your training?"
    debug = False
    if distributed_type != SageMakerDistributedType.NO:
        # Distributed runs pick from the fixed list of supported instance types.
        ec2_instance_type = _ask_options(
            ec2_instance_query, SAGEMAKER_PARALLEL_EC2_INSTANCES, lambda x: SAGEMAKER_PARALLEL_EC2_INSTANCES[int(x)]
        )
        debug = _ask_yes_no(
            "Should distributed operations be checked while running for errors? This can avoid timeout issues but will be slower. [yes/NO]: "
        )
    else:
        # The base question already ends with "?"; only the default hint is appended.
        ec2_instance_query += " [ml.p3.2xlarge]:"
        ec2_instance_type = _ask_field(ec2_instance_query, lambda x: str(x).lower(), default="ml.p3.2xlarge")

    num_machines = 1
    if distributed_type in (SageMakerDistributedType.DATA_PARALLEL, SageMakerDistributedType.MODEL_PARALLEL):
        num_machines = _ask_field(
            "How many machines do you want to use? [1]: ",
            int,
            default=1,
        )

    # --- Precision -------------------------------------------------------------------------------------------------
    mixed_precision = _ask_options(
        "Do you wish to use FP16 or BF16 (mixed precision)?",
        ["no", "fp16", "bf16", "fp8"],
        _convert_mixed_precision,
    )

    if use_dynamo and mixed_precision == "no":
        print(
            "Torch dynamo used without mixed precision requires TF32 to be efficient. Accelerate will enable it by default when launching your scripts."
        )

    return SageMakerConfig(
        image_uri=docker_image,
        compute_environment=ComputeEnvironment.AMAZON_SAGEMAKER,
        distributed_type=distributed_type,
        use_cpu=False,
        dynamo_config=dynamo_config,
        ec2_instance_type=ec2_instance_type,
        profile=aws_profile,
        region=aws_region,
        iam_role_name=iam_role_name,
        mixed_precision=mixed_precision,
        num_machines=num_machines,
        sagemaker_inputs_file=sagemaker_inputs_file,
        sagemaker_metrics_file=sagemaker_metrics_file,
        debug=debug,
    )
