We can locate a football using an explainable-AI technique known as Grad-CAM (here, the Guided Grad-CAM implementation from Captum).
import os
# NOTE(review): moves two directories up, presumably to the project root so the
# relative ./data paths below resolve — confirm against where the notebook lives.
os.chdir('../../')
os.getcwd()
from PIL import Image
import glob
# With recursive=True, '**' only recurses when it is a whole path component, so
# '**/*.jpg' matches .jpg files in data/samples/ and in all of its subfolders.
# glob returns paths in arbitrary order; sorting keeps the image order (and so
# the rows and plots further down) stable between runs.
im_paths = sorted(glob.glob('./data/samples/**/*.jpg', recursive=True))


def _load_rgb(path):
    """Read the image at *path* into memory as RGB and close the file handle."""
    # convert('RGB') guarantees three channels, matching the 3-element mean/std
    # normalisation applied later; grayscale or RGBA files would not match it.
    with Image.open(path) as im:
        return im.convert('RGB')


images = [_load_rgb(path) for path in im_paths]
from ball_detection.utils import plot_images
plot_images(images)
import torch
from torchvision import transforms

# Preprocessing: resize the shorter side to 256 px, take the central 224x224
# crop, convert to a tensor, then normalise each channel with the ImageNet
# mean/std values used for torchvision's pretrained models.
transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

# Stack every preprocessed sample into one batch tensor and mark it as
# requiring gradients (requires_grad_ works in place and returns the tensor).
X = torch.stack(list(map(transform, images))).requires_grad_()
from torchvision.models import resnet18

# Pretrained ResNet-18 switched to inference mode; eval() returns the module
# itself, so the call can be chained directly onto the constructor.
model = resnet18(pretrained=True).eval()
Download the ImageNet class-index file as well; it maps each class index to a human-readable label.
# IPython shell escape: fetch imagenet_class_index.json into $HOME/.torch/models,
# where the next cell reads it from.
# NOTE(review): wget's default when the file already exists is to save a numbered
# copy (e.g. ...json.1) rather than overwrite or skip — consider adding -nc.
!wget -P $HOME/.torch/models https://s3.amazonaws.com/deep-learning-models/image-models/imagenet_class_index.json
import json

# Location the wget cell writes to ($HOME/.torch/models). expanduser('~')
# resolves the home directory even where the HOME variable is unset, whereas
# os.getenv("HOME") + ... would raise TypeError on None.
labels_path = os.path.join(
    os.path.expanduser('~'), '.torch', 'models', 'imagenet_class_index.json'
)
# Keys are class indices as strings (see labels_to_use); element [1] of each
# value is used later as the readable class name.
with open(labels_path, encoding='utf-8') as json_data:
    idx_to_labels = json.load(json_data)
These are the ImageNet labels for the different types of ball.
# ImageNet class indices (string keys into idx_to_labels) for ball classes.
# NOTE(review): index '768' is believed to be 'rugby_ball' and is not listed —
# confirm whether leaving it out is intentional.
labels_to_use = ['429', '430', '522', '574', '722', '805', '852', '890']
# Print each class entry on its own line.
print(*(idx_to_labels[key] for key in labels_to_use), sep='\n')
import pandas as pd
import torch.nn.functional as F

# Softmax class probabilities for every image. This is inference only, so no
# autograd graph is needed. The batch axis is kept (no .squeeze()) so a single
# image still gives a (1, n_classes) array rather than a flat score vector.
with torch.no_grad():
    y_hat = F.softmax(model(X), dim=1).numpy()

# One row per image (indexed by file name), one column per ball class, with
# probabilities formatted as percentage strings. Building from rows plus an
# explicit index keeps images that share a base name as separate rows.
df = pd.DataFrame(
    [[f'{probs[int(label)]:.2%}' for label in labels_to_use] for probs in y_hat],
    index=[os.path.basename(path) for path in im_paths],
    columns=[idx_to_labels[label][1] for label in labels_to_use],
)
df
from captum.attr import GuidedGradCam
import matplotlib.pyplot as plt

# Guided Grad-CAM on ResNet-18's final residual stage (layer4), computed once
# for each ball label.
ggc = GuidedGradCam(model, model.layer4)
# No .squeeze(): the batch axis is kept even for a single image, so the later
# reductions (max over labels, mean over axis 1) always see a 4-D array
# (assumed (N, C, H, W), matching X).
attributions = [
    ggc.attribute(X, target=int(label)).detach().numpy()
    for label in labels_to_use
]
import numpy as np

# Collapse the per-label attributions into one heat-map per image: strongest
# response over all ball labels, then the mean over axis 1 (assumed to be the
# colour channels of (N, C, H, W) attributions).
max_attr = np.max(attributions, axis=0).mean(axis=1)
# Min-max scale each image's heat-map independently to the 0-1 range.
lo = np.amin(max_attr, axis=(1, 2), keepdims=True)
hi = np.amax(max_attr, axis=(1, 2), keepdims=True)
span = hi - lo
# A constant map has zero span; divide by 1 there so it becomes all zeros, not NaN.
scaled = (max_attr - lo) / np.where(span > 0, span, 1)
plot_images(scaled)
from scipy.ndimage import gaussian_filter

# Smooth every scaled heat-map with a wide Gaussian (sigma=10, in array-index
# units); the peak of these smoothed maps is what gets located next.
blured = [gaussian_filter(heat, sigma=10) for heat in scaled]
plot_images(blured)
One thing we could do is take the point of maximum activation from Grad-CAM. However, this point does not necessarily correspond to the center of the ball.
# Location of the maximum of each blurred heat-map. unravel_index yields
# (row, column), so x collects row indices and y collects column indices.
x, y = zip(*(
    np.unravel_index(np.argmax(heat), np.shape(heat))
    for heat in blured
))
def plot_ball(tensor, x, y):
img = tensor.permute((1,2,0))
plt.imshow(img)
plt.scatter(x,y, color='red')
plt.axis('off')
plt.show()
# Overlay each image's activation peak on the input image. x holds row indices
# and y holds column indices, so they are passed as (column, row) to match
# plot_ball's (horizontal, vertical) argument order. X.detach() is taken once.
for image, col, row in zip(X.detach(), y, x):
    plot_ball(image, col, row)
from torchvision import models

# Repeat the Guided Grad-CAM pipeline with ResNet-152. eval() puts modules such
# as BatchNorm into inference mode; without it, forward passes use batch
# statistics and update the running statistics.
resnet152 = models.resnet152(pretrained=True).eval()
ggc = GuidedGradCam(resnet152, resnet152.layer4)
# Keep the batch axis (no .squeeze()) so a single image still yields 4-D arrays.
attributions = [
    ggc.attribute(X, target=int(label)).detach().numpy()
    for label in labels_to_use
]
# Strongest response over labels, then the mean over axis 1 (assumed channels).
max_attr = np.max(attributions, axis=0).mean(axis=1)
# Per-image min-max scaling to [0, 1]; zero-span maps become all zeros.
lo = np.amin(max_attr, axis=(1, 2), keepdims=True)
hi = np.amax(max_attr, axis=(1, 2), keepdims=True)
span = hi - lo
scaled = (max_attr - lo) / np.where(span > 0, span, 1)
plot_images(scaled)
blured = [gaussian_filter(x, 10) for x in scaled]
plot_images(blured)
# Same pipeline with DenseNet-201, targeting the last module of its feature
# extractor. eval() puts modules such as BatchNorm into inference mode; without
# it, forward passes use batch statistics and update the running statistics.
densenet = models.densenet201(pretrained=True).eval()
ggc = GuidedGradCam(densenet, densenet.features[-1])
# Keep the batch axis (no .squeeze()) so a single image still yields 4-D arrays.
attributions = [
    ggc.attribute(X, target=int(label)).detach().numpy()
    for label in labels_to_use
]
# Strongest response over labels, then the mean over axis 1 (assumed channels).
max_attr = np.max(attributions, axis=0).mean(axis=1)
# Per-image min-max scaling to [0, 1]; zero-span maps become all zeros.
lo = np.amin(max_attr, axis=(1, 2), keepdims=True)
hi = np.amax(max_attr, axis=(1, 2), keepdims=True)
span = hi - lo
scaled = (max_attr - lo) / np.where(span > 0, span, 1)
plot_images(scaled)
blured = [gaussian_filter(x, 10) for x in scaled]
plot_images(blured)
from torchvision import models

# Same pipeline with MobileNetV2, targeting the last module of its feature
# extractor. eval() disables training-mode behaviour (BatchNorm batch
# statistics and running-stat updates, dropout), which would otherwise make
# the attributions depend on the batch and vary from run to run.
mobilenet = models.mobilenet_v2(pretrained=True).eval()
ggc = GuidedGradCam(mobilenet, mobilenet.features[-1])
# Keep the batch axis (no .squeeze()) so a single image still yields 4-D arrays.
attributions = [
    ggc.attribute(X, target=int(label)).detach().numpy()
    for label in labels_to_use
]
# Strongest response over labels, then the mean over axis 1 (assumed channels).
max_attr = np.max(attributions, axis=0).mean(axis=1)
# Per-image min-max scaling to [0, 1]; zero-span maps become all zeros.
lo = np.amin(max_attr, axis=(1, 2), keepdims=True)
hi = np.amax(max_attr, axis=(1, 2), keepdims=True)
span = hi - lo
scaled = (max_attr - lo) / np.where(span > 0, span, 1)
plot_images(scaled)
blured = [gaussian_filter(x, 10) for x in scaled]
plot_images(blured)