import torch
import requests
from PIL import Image
from transformers import AutoModelForImageClassification, AutoImageProcessor


import urllib.request

# Load ImageNet class names: one label per line, and the line index is used as
# the class id when indexing with the model's argmax (presumably the same
# ordering as the checkpoint's output head — TODO confirm against
# model.config.id2label).
url = "https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt"
# Context manager closes the HTTP response; the timeout keeps a stalled
# connection from hanging the script indefinitely.
with urllib.request.urlopen(url, timeout=30) as response:
    classes = response.read().decode("utf-8").splitlines()


# Preprocessor paired with the checkpoint; the fast implementation is
# requested explicitly.
image_processor = AutoImageProcessor.from_pretrained(
    "google/vit-base-patch16-224", use_fast=True
)

# Classifier weights loaded in half precision (torch.half is torch.float16).
# device_map="auto" delegates device placement to transformers, and attention
# is computed with PyTorch's scaled-dot-product-attention backend.
model = AutoModelForImageClassification.from_pretrained(
    "google/vit-base-patch16-224",
    attn_implementation="sdpa",
    device_map="auto",
    torch_dtype=torch.half,
)
# Local image file to classify (a filesystem path, despite the variable name).
url = "image1.png"
# Force 3-channel RGB: PNGs are frequently RGBA, grayscale or palette images,
# while the processor presumably normalizes with per-channel RGB statistics.
# convert() materializes the pixels, so the file can be closed right away.
with Image.open(url) as img:
    image = img.convert("RGB")

# Send inputs to wherever device_map="auto" actually placed the model (which
# is not guaranteed to be CUDA) and match the model's half-precision dtype.
inputs = image_processor(image, return_tensors="pt").to(
    model.device, dtype=model.dtype
)

# inference_mode disables autograd tracking entirely for this forward pass.
with torch.inference_mode():
    logits = model(**inputs).logits
# Index of the highest-scoring class for the single image in the batch.
predicted_class_id = logits.argmax(dim=-1).item()

predicted_class_label = classes[predicted_class_id]
print(f"The predicted class label is: {predicted_class_label}")
