# Copyright 2020 Google Research. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Bounding Box List definition.

BoxList represents a list of bounding boxes as PyTorch
tensors, where each bounding box is represented as a row of 4 numbers,
[y_min, x_min, y_max, x_max].  It is assumed that all bounding boxes
within a given list correspond to a single image.  See also
box_list.py for common box related operations (such as area, iou, etc).

Optionally, users can add additional related fields (such as weights).
We assume the following things to be true about fields:
* they correspond to boxes in the box_list along the 0th dimension
* they have inferable rank at graph construction time
* all dimensions except for possibly the 0th can be inferred
  (i.e., not None) at graph construction time.

Some other notes:
    * Following tensorflow conventions, we use height, width ordering,
        and correspondingly, y,x (or ymin, xmin, ymax, xmax) ordering
    * Tensors are always provided as (flat) [N, 4] tensors.
"""
import torch
from typing import Optional, List, Dict


@torch.jit.script
class BoxList(object):
    """Collection of bounding boxes plus optional per-box fields.

    Every tensor is kept in the single `data` dict, keyed by field name. The
    box corners live under the key 'boxes' as a [N, 4] tensor; any other key
    is an "extra" field (e.g. weights or labels) added via `add_field`.

    The class is compiled with `torch.jit.script`, which is why accessors are
    plain methods rather than properties and why list building uses explicit
    loops instead of comprehensions.
    """
    # Explicit attribute type so TorchScript knows the dict's key/value types.
    data: Dict[str, torch.Tensor]

    def __init__(self, boxes: torch.Tensor) -> None:
        """Constructs box collection.

        Args:
            boxes: a float32 tensor of shape [N, 4] representing box corners

        Raises:
            ValueError: if invalid dimensions for bbox data or if bbox data is not in float32 format.
        """
        if len(boxes.shape) != 2 or boxes.shape[-1] != 4:
            raise ValueError('Invalid dimensions for box data.')
        if boxes.dtype != torch.float32:
            # The check is against the torch dtype, so the message names it too.
            raise ValueError('Invalid tensor type: should be torch.float32')
        self.data = {'boxes': boxes}

    def num_boxes(self) -> int:
        """Returns number of boxes held in collection.

        Returns:
            the size of the 0th dimension of the 'boxes' tensor.
        """
        return self.data['boxes'].shape[0]

    def get_all_fields(self):
        """Returns the keys of all stored fields, including 'boxes'."""
        return self.data.keys()

    def get_extra_fields(self) -> List[str]:
        """Returns all non-box fields (i.e., everything not named 'boxes')."""
        # Explicit loop: the original author noted that TorchScript did not
        # support the equivalent filtered comprehension.
        extra: List[str] = []
        for k in self.data.keys():
            if k != 'boxes':
                extra.append(k)
        return extra

    def add_field(self, field: str, field_data: torch.Tensor) -> None:
        """Add field to box list.

        This method can be used to add related box data such as weights/labels, etc.
        If `field` already exists its tensor is replaced; no shape or dtype
        validation is performed here.

        Args:
            field: a string key to access the data via `get_field`
            field_data: a tensor containing the data to store in the BoxList
        """
        self.data[field] = field_data

    def has_field(self, field: str) -> bool:
        """Returns True if `field` is a key currently stored in the collection."""
        return field in self.data

    #@property  # FIXME for torchscript compat
    def boxes(self) -> torch.Tensor:
        """Convenience function for accessing box coordinates.

        Returns:
            a tensor with shape [N, 4] representing box coordinates.
        """
        return self.get_field('boxes')

    #@boxes.setter  # FIXME for torchscript compat
    def set_boxes(self, boxes: torch.Tensor) -> None:
        """Convenience function for setting box coordinates.

        Only the shape is validated here; unlike the constructor, the dtype
        is not checked.

        Args:
            boxes: a tensor of shape [N, 4] representing box corners

        Raises:
            ValueError: if invalid dimensions for bbox data
        """
        if len(boxes.shape) != 2 or boxes.shape[-1] != 4:
            raise ValueError('Invalid dimensions for box data.')
        self.data['boxes'] = boxes

    def get_field(self, field: str) -> torch.Tensor:
        """Accesses the tensor stored under a field name.

        Pass 'boxes' to retrieve the box coordinates, or the name of any
        field previously added with `add_field`.

        Args:
            field: name of the field to be accessed.

        Returns:
            a tensor representing the box collection or an associated field.

        Raises:
            ValueError: if invalid field
        """
        if not self.has_field(field):
            raise ValueError(f'field {field} does not exist')
        return self.data[field]

    def set_field(self, field: str, value: torch.Tensor) -> None:
        """Sets the value of an existing field.

        Unlike `add_field`, this refuses to create a new key.

        Args:
            field: (string) name of the field to set value.
            value: the value to assign to the field.

        Raises:
            ValueError: if the box_list does not have specified field.
        """
        if not self.has_field(field):
            raise ValueError(f'field {field} does not exist')
        self.data[field] = value

    def get_center_coordinates_and_sizes(self) -> List[torch.Tensor]:
        """Computes the center coordinates, height and width of the boxes.

        Returns:
            a list of 4 1-D tensors [ycenter, xcenter, height, width].
        """
        box_corners = self.boxes()
        # Transpose [N, 4] -> [4, N] so unbind yields one 1-D tensor per coordinate.
        ymin, xmin, ymax, xmax = box_corners.t().unbind()
        width = xmax - xmin
        height = ymax - ymin
        ycenter = ymin + height / 2.
        xcenter = xmin + width / 2.
        return [ycenter, xcenter, height, width]

    def transpose_coordinates(self) -> None:
        """Transpose the coordinate representation in a boxlist.

        Swaps the column order of the stored boxes in place, from
        [y_min, x_min, y_max, x_max] to [x_min, y_min, x_max, y_max].
        Calling it a second time swaps the columns back.
        """
        # chunk keeps each column as [N, 1], so cat along dim 1 rebuilds [N, 4].
        y_min, x_min, y_max, x_max = self.boxes().chunk(4, dim=1)
        self.set_boxes(torch.cat([x_min, y_min, x_max, y_max], 1))

    def as_tensor_dict(self, fields: Optional[List[str]] = None) -> Dict[str, torch.Tensor]:
        """Retrieves specified fields as a dictionary of tensors.

        The returned dict is a new container, but its values are the same
        tensor objects stored in the collection (no copy is made).

        Args:
            fields: (optional) list of fields to return in the dictionary.
                If None (default), all fields are returned.

        Returns:
            tensor_dict: A dictionary of tensors specified by fields.

        Raises:
            ValueError: if specified field is not contained in boxlist.
        """
        tensor_dict: Dict[str, torch.Tensor] = {}
        if fields is None:
            fields = self.get_all_fields()
        for field in fields:
            if not self.has_field(field):
                raise ValueError('boxlist must contain all specified fields')
            tensor_dict[field] = self.get_field(field)
        return tensor_dict

    #@property
    def device(self) -> torch.device:
        """Returns the device of the 'boxes' tensor."""
        return self.data['boxes'].device
