### Online Colab Demo: https://colab.research.google.com/drive/14Dqg7oeBkFEtchaHLNpig2BcdkZEogba
### Hugging Face Spaces Demo: https://huggingface.co/spaces/ZhengPeng7/BiRefNet_demo

In [None]:
# Imports
from PIL import Image
import torch
from torchvision import transforms
from IPython.display import display

import sys
sys.path.insert(0, "../")
from models.birefnet import BiRefNet


# Load Model
# Options 2 and 3 are better suited to local runs, because the model code can be edited locally.

# # Option 1: load BiRefNet together with its weights through Hugging Face `transformers`:
# from transformers import AutoModelForImageSegmentation
# birefnet = AutoModelForImageSegmentation.from_pretrained('zhengpeng7/BiRefNet', trust_remote_code=True)

# Option 2: load the published weights with the local BiRefNet code.
# Select a checkpoint by changing the index below. 'zhengpeng7/BiRefNet_lite'
# additionally requires setting `bb` in `config.py` to `swin_v1_tiny`.
BIREFNET_CHECKPOINTS = (
    'zhengpeng7/BiRefNet',
    'zhengpeng7/BiRefNet-portrait',
    'zhengpeng7/BiRefNet-legacy',
    'zhengpeng7/BiRefNet-DIS5K-TR_TEs',
    'zhengpeng7/BiRefNet-DIS5K',
    'zhengpeng7/BiRefNet-HRSOD',
    'zhengpeng7/BiRefNet-COD',
    'zhengpeng7/BiRefNet_lite',
)
model = BIREFNET_CHECKPOINTS[0]
birefnet = BiRefNet.from_pretrained(model)
# Last component of the repo id (e.g. "BiRefNet"); used to name frame/output paths in this notebook.
model_name = model.rsplit('/', 1)[-1]

# # Option 3: load the model and weights from local disk:
# from utils import check_state_dict
#
# birefnet = BiRefNet(bb_pretrained=False)
# state_dict = torch.load('../BiRefNet-general-epoch_244.pth', map_location='cpu', weights_only=True)
# state_dict = check_state_dict(state_dict)
# birefnet.load_state_dict(state_dict)

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# 'high' allows faster reduced-precision float32 matmuls; switch to 'highest' for full fp32 precision.
torch.set_float32_matmul_precision('high')

birefnet.to(device).eval()
print('BiRefNet is ready to use.')
# NOTE(review): fp16 is applied even when device == 'cpu'; CPU fp16 kernels may be
# slow or unsupported depending on the torch version — confirm before running on CPU.
birefnet.half()

# Input Data: resize every frame to 1024x1024, convert to a tensor, and normalize
# with the standard ImageNet channel mean/std.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
transform_image = transforms.Compose([
    transforms.Resize((1024, 1024)),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

### Run inference on all videos in the **videos_todo** directory

In [None]:
import os
from glob import glob

# Container formats picked up from ../videos_todo (compared case-insensitively).
VIDEO_EXTENSIONS = ('.mp4', '.avi')


def video_sort_key(path):
    """Return a sort key ordering videos by the integer in their file stem.

    Files named like ``3.mp4`` come first, in numeric order (so ``2`` < ``10``);
    any other names follow, in lexical order, instead of crashing the sort.
    The file stem is the final tiebreaker so the order is deterministic.

    Args:
        path: Path of a video file, with '/' or the OS-native separator.

    Returns:
        A tuple usable as a ``sorted`` key.
    """
    # basename/splitext handle OS-native separators and dots inside the name.
    stem = os.path.splitext(os.path.basename(path))[0]
    try:
        return (0, int(stem), stem)
    except ValueError:
        return (1, 0, stem)


video_src_paths = sorted(
    (f for f in glob('../videos_todo/*') if os.path.splitext(f)[1].lower() in VIDEO_EXTENSIONS),
    key=video_sort_key,
)
print('video_src_paths:', video_src_paths)

In [None]:
import cv2
import numpy as np
from image_proc import refine_foreground
from time import time

# Chroma-key colour (RGB) composited behind the predicted subject.
background_rgb = np.array((0, 177, 64), dtype=np.float32)
# Number of frames per forward pass.
batch_size = 1
# Feed inputs in the same dtype as the model weights (fp16 after `birefnet.half()`).
model_dtype = next(birefnet.parameters()).dtype

for video_src_path in video_src_paths:
    print('\nvideo_src_path:', video_src_path)
    video_stem, video_ext = os.path.splitext(os.path.basename(video_src_path))
    src_dir = '../frames-{}-video_{}'.format(model_name, video_stem)
    # Derive outputs from the path root: str.replace(video_ext, ...) would also rewrite
    # any other occurrence of the extension text elsewhere in the path.
    video_root = os.path.splitext(video_src_path)[0]
    video_dst_path_mask = '{}-preds_mask-{}{}'.format(video_root, model_name, video_ext)
    video_dst_path_subject = '{}-preds_subject-{}{}'.format(video_root, model_name, video_ext)

    vidcap = cv2.VideoCapture(video_src_path)
    fps = vidcap.get(cv2.CAP_PROP_FPS)
    success, image = vidcap.read()
    if not success:
        # Unreadable or empty video: `image` is None, so there is no frame size for the writers.
        vidcap.release()
        print('Could not read any frame from {}, skipping.'.format(video_src_path))
        continue
    video_writer_shape = image.shape[:2][::-1]  # (width, height), as cv2.VideoWriter expects

    # Dump every frame to disk as frame_<index>.png, then free the capture handle.
    os.makedirs(src_dir, exist_ok=True)
    count = 0
    while success:
        cv2.imwrite(os.path.join(src_dir, 'frame_{}.png'.format(count)), image)
        success, image = vidcap.read()
        count += 1
    vidcap.release()
    # Exactly the frames written in this run, in frame order; globbing src_dir could
    # also pick up stale frames left over from an earlier run.
    image_paths = [os.path.join(src_dir, 'frame_{}.png'.format(i)) for i in range(count)]

    video_writer_mask = cv2.VideoWriter(video_dst_path_mask, cv2.VideoWriter_fourcc(*'mp4v'), fps, video_writer_shape, isColor=False)
    video_writer_subject = cv2.VideoWriter(video_dst_path_subject, cv2.VideoWriter_fourcc(*'mp4v'), fps, video_writer_shape, isColor=True)

    # Report roughly every 10% of batches. Clamp to 1: for fewer than 10 batches the
    # raw interval is 0 and `% 0` would raise ZeroDivisionError.
    report_every = max(1, int(len(image_paths) // batch_size * 0.1))
    time_st = time()
    try:
        for idx in range(0, len(image_paths), batch_size):
            if (idx // batch_size + 1) % report_every == 0:
                print('Processing {} / {} ...'.format(idx + 1, len(image_paths)))

            input_images_pil = []
            for frame_path in image_paths[idx:idx + batch_size]:
                # `with` closes the file handle; convert() returns an in-memory copy.
                with Image.open(frame_path) as frame:
                    input_images_pil.append(frame.convert('RGB'))
            input_images = torch.stack(
                [transform_image(input_image) for input_image in input_images_pil]
            ).to(device, dtype=model_dtype)

            # Prediction
            with torch.no_grad():
                # Sigmoid of the model's last output; cast to fp32 for ToPILImage.
                preds = birefnet(input_images)[-1].sigmoid().float().cpu()

            for pred, image in zip(preds, input_images_pil):
                # Mask at the original frame resolution, reused for refinement,
                # the alpha channel and the mask video (resized once).
                mask_pil = transforms.ToPILImage()(pred.squeeze()).resize(image.size)

                image_masked = refine_foreground(image, mask_pil)
                image_masked.putalpha(mask_pil)

                video_writer_mask.write(np.array(mask_pil.convert('L')))

                rgba = np.array(image_masked)
                array_foreground = rgba[:, :, :3].astype(np.float32)
                array_mask = (rgba[:, :, 3:] / 255).astype(np.float32)
                # Alpha-blend the refined foreground over the solid background colour.
                array_foreground_background = array_foreground * array_mask + background_rgb * (1 - array_mask)
                video_writer_subject.write(cv2.cvtColor(array_foreground_background, cv2.COLOR_RGB2BGR).astype(np.uint8))
    finally:
        # Always finalize the output files, even if inference fails mid-video.
        video_writer_mask.release()
        video_writer_subject.release()

    print('Mask video has been saved to:', video_dst_path_mask)
    print('Subject video has been saved to:', video_dst_path_subject)
    print('Time cost:', round(time() - time_st, 2))