We can locate a football using an explainable AI technique known as Grad-CAM.
import os
# Step two directories up so that the relative paths used further down
# (e.g. './data/samples/') are resolved from there.
two_levels_up = '../../'
os.chdir(two_levels_up)
# Bare trailing expression: the notebook echoes the new working directory.
os.getcwd()
'/home/me/github/RSF'

Preparation

Load images

from PIL import Image
import glob

# Collect every sample JPEG below ./data/samples, subdirectories included.
# '**' only recurses when it forms a whole path component, so the pattern has
# to be '**/*.jpg'; '**.jpg' behaves like '*.jpg' even with recursive=True.
# glob returns paths in arbitrary order, so sort them to keep the image order
# (and therefore every table and plot below) reproducible between runs.
im_paths = sorted(glob.glob('./data/samples/**/*.jpg', recursive=True))

# Force three-channel RGB: grayscale or CMYK JPEGs would otherwise not match
# the three-value mean/std used by the Normalize step further down.
images = [Image.open(path).convert('RGB') for path in im_paths]

from ball_detection.utils import plot_images
plot_images(images)
import torch
from torchvision import transforms

# Per-channel mean / standard deviation used by the normalisation step.
channel_mean = [0.485, 0.456, 0.406]
channel_std = [0.229, 0.224, 0.225]

# Resize, centre-crop to 224, convert to a tensor, then normalise per channel.
preprocessing_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=channel_mean, std=channel_std),
]
transform = transforms.Compose(preprocessing_steps)

# Stack the preprocessed images into a single batch tensor.
preprocessed = [transform(im) for im in images]
X = torch.stack(preprocessed)
# Gradients with respect to the input are needed by the attribution code below.
X.requires_grad = True

Load a Resnet model pretrained on ImageNet

from torchvision.models import resnet18

# Build the pretrained ResNet-18 and switch it to evaluation mode in one go
# (Module.eval() returns the module itself).
model = resnet18(pretrained=True).eval()

Download the ImageNet class index (the mapping from class id to label name) as well.

!wget -P $HOME/.torch/models https://s3.amazonaws.com/deep-learning-models/image-models/imagenet_class_index.json
--2021-03-02 07:16:32--  https://s3.amazonaws.com/deep-learning-models/image-models/imagenet_class_index.json
Resolving s3.amazonaws.com (s3.amazonaws.com)... 52.216.96.197
Connecting to s3.amazonaws.com (s3.amazonaws.com)|52.216.96.197|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 35363 (35K) [application/octet-stream]
Saving to: ‘/home/me/.torch/models/imagenet_class_index.json.5’

imagenet_class_inde 100%[===================>]  34.53K   191KB/s    in 0.2s    

2021-03-02 07:16:33 (191 KB/s) - ‘/home/me/.torch/models/imagenet_class_index.json.5’ saved [35363/35363]

import json

# The class index was downloaded into $HOME/.torch/models by the wget cell.
home_dir = os.getenv("HOME")
labels_path = home_dir + '/.torch/models/imagenet_class_index.json'

# Parse the JSON file; the result is indexed below by class-id strings such as
# '805' and yields [wordnet_id, class_name] pairs.
with open(labels_path) as labels_file:
    idx_to_labels = json.loads(labels_file.read())

These are the labels for any type of ball in ImageNet.

# Class ids (as the string keys used by the class-index JSON) of the ball
# classes we are interested in.
labels_to_use = ['429', '430', '522', '574', '722','805','852','890']

# Show the [wordnet_id, class_name] entry behind each selected id.
for entry in (idx_to_labels[class_id] for class_id in labels_to_use):
    print(entry)
['n02799071', 'baseball']
['n02802426', 'basketball']
['n03134739', 'croquet_ball']
['n03445777', 'golf_ball']
['n03942813', 'ping-pong_ball']
['n04254680', 'soccer_ball']
['n04409515', 'tennis_ball']
['n04540053', 'volleyball']
import pandas as pd
import torch.nn.functional as F

# Class probabilities per image (softmax over dim 1 of the model output).
# Deliberately no .squeeze(): with a single sample image it would drop the
# batch axis and the code below would iterate over scalars instead of rows.
y_hat = F.softmax(model(X), dim=1).detach().numpy()

# Map each image file name to its formatted probabilities for the ball classes.
d = {
    os.path.basename(path): [f'{prediction[int(label)]:.2%}' for label in labels_to_use]
    for path, prediction in zip(im_paths, y_hat)
}

# One row per image, one column per ball class, labelled with the class name.
df = pd.DataFrame.from_dict(
    d,
    orient='index',
    columns=[idx_to_labels[label][1] for label in labels_to_use],
)
# Bare trailing expression: the notebook renders the table.
df
baseball basketball croquet_ball golf_ball ping-pong_ball soccer_ball tennis_ball volleyball
ronaldo.jpg 1.91% 6.97% 0.15% 0.06% 3.89% 7.33% 6.35% 3.07%
000554.jpg 1.36% 0.59% 33.47% 2.45% 0.07% 30.11% 0.79% 1.56%
000001.jpg 3.04% 0.00% 3.26% 14.55% 0.06% 77.14% 0.11% 0.13%
Monke.jpg 0.00% 0.00% 0.00% 0.00% 0.00% 0.00% 0.00% 0.00%
federer.jpg 0.05% 5.85% 0.00% 0.00% 66.65% 0.53% 1.12% 8.20%
sea.jpg 0.00% 0.00% 0.00% 0.00% 0.00% 0.00% 0.00% 0.00%
DSC_3136.jpg 0.02% 0.00% 0.01% 0.00% 0.00% 0.02% 99.85% 0.00%
baby.jpg 0.00% 0.00% 0.00% 0.00% 0.00% 0.00% 0.00% 0.00%
doge.jpg 0.00% 0.00% 0.00% 0.00% 0.00% 0.02% 0.03% 0.00%
from captum.attr import GuidedGradCam
import matplotlib.pyplot as plt

# Guided Grad-CAM with respect to the last residual stage of the ResNet.
ggc = GuidedGradCam(model, model.layer4)

# One attribution array per ball class, each computed for the whole batch X.
# The arrays are deliberately not squeezed: with a single sample image,
# .squeeze() would drop the batch axis and the later axis-based reductions
# (max over classes, mean over axis 1) would act on the wrong dimensions.
# NOTE(review): assumes Captum returns attributions shaped like X -- confirm.
attributions = [
    ggc.attribute(X, target=int(label)).detach().numpy()
    for label in labels_to_use
]
/home/me/miniconda3/lib/python3.8/site-packages/captum/attr/_core/guided_backprop_deconvnet.py:60: UserWarning: Setting backward hooks on ReLU activations.The hooks will be removed after the attribution is finished
  warnings.warn(
import numpy as np

# Per image: keep the strongest attribution over all ball classes (axis 0 of
# the stacked attributions), then average over axis 1 of the result
# (presumably the colour channels -- confirm against the attribution shape).
max_attr = np.max(attributions, axis=0).mean(axis=1)

# Min-max scale every image to the 0-1 range.  `mn` / `mx` must be the
# per-image minimum / maximum respectively: with the two swapped the formula
# becomes (x - max) / (max - min), which lands in [-1, 0] instead of [0, 1].
mn = np.amin(max_attr, axis=(1, 2), keepdims=True)
mx = np.amax(max_attr, axis=(1, 2), keepdims=True)
value_range = mx - mn
# A completely flat map has zero range; divide by 1 there to get zeros, not NaNs.
value_range[value_range == 0] = 1

scaled = (max_attr - mn) / value_range

plot_images(scaled)
from scipy.ndimage import gaussian_filter

# Smooth each scaled attribution map with a Gaussian filter (sigma = 10).
blur_sigma = 10
blured = [gaussian_filter(attr_map, blur_sigma) for attr_map in scaled]
plot_images(blured)

Locating the ball with Grad-CAM

One thing we could do is take the point of maximum activation from Grad-CAM. However, this doesn't extract the center of the ball.

# Position of the strongest response in every blurred map.  unravel_index
# yields indices in array-axis order, so `x` collects the first-axis (row)
# indices and `y` the second-axis (column) indices.
peak_indices = [np.unravel_index(np.argmax(heat), heat.shape) for heat in blured]
x, y = zip(*peak_indices)

def plot_ball(tensor, x, y):
    """Show an image with a red marker drawn at position (x, y).

    Args:
        tensor: image tensor whose axes are reordered with ``permute(1, 2, 0)``
            before display (presumably channels-first, C x H x W -- confirm
            against the caller).
        x: horizontal coordinate(s) handed to ``plt.scatter``.
        y: vertical coordinate(s) handed to ``plt.scatter``.
    """
    channels_last = tensor.permute(1, 2, 0)
    plt.imshow(channels_last)
    plt.scatter(x, y, color='red')
    plt.axis('off')
    plt.show()
    

# Draw every input image with its peak location.  `y` holds the column indices
# and `x` the row indices, so they are passed as plot_ball's x and y arguments
# in that order.
for image, column, row in zip(X.detach(), y, x):
    plot_ball(image, column, row)
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).

How do other models behave?

ResNet152

from torchvision import models

# Switch to evaluation mode, as is done for resnet18 earlier in this notebook;
# without it the network stays in training mode and its normalisation layers
# use statistics of the current batch rather than the stored ones.
resnet152 = models.resnet152(pretrained=True).eval()
ggc = GuidedGradCam(resnet152, resnet152.layer4)

# One attribution array per ball class.  Not squeezed, so the batch axis
# survives even when there is only a single sample image.
attributions = [
    ggc.attribute(X, target=int(label)).detach().numpy()
    for label in labels_to_use
]

# Strongest attribution over the classes, averaged over axis 1 of the result
# (presumably the colour channels -- confirm against the attribution shape).
max_attr = np.max(attributions, axis=0).mean(axis=1)

# Min-max scale each image to 0-1.  `mn` is the per-image minimum and `mx` the
# maximum; with the two swapped the result would lie in [-1, 0].
mn = np.amin(max_attr, axis=(1, 2), keepdims=True)
mx = np.amax(max_attr, axis=(1, 2), keepdims=True)
value_range = mx - mn
# A completely flat map has zero range; divide by 1 there to get zeros, not NaNs.
value_range[value_range == 0] = 1
scaled = (max_attr - mn) / value_range

plot_images(scaled)
blured = [gaussian_filter(attr_map, 10) for attr_map in scaled]
plot_images(blured)
# Switch to evaluation mode, as is done for resnet18 earlier in this notebook;
# without it the network stays in training mode and its normalisation layers
# use statistics of the current batch rather than the stored ones.
densenet = models.densenet201(pretrained=True).eval()
# Attribute with respect to the last module of the feature extractor.
ggc = GuidedGradCam(densenet, densenet.features[-1])

# One attribution array per ball class.  Not squeezed, so the batch axis
# survives even when there is only a single sample image.
attributions = [
    ggc.attribute(X, target=int(label)).detach().numpy()
    for label in labels_to_use
]

# Strongest attribution over the classes, averaged over axis 1 of the result
# (presumably the colour channels -- confirm against the attribution shape).
max_attr = np.max(attributions, axis=0).mean(axis=1)

# Min-max scale each image to 0-1.  `mn` is the per-image minimum and `mx` the
# maximum; with the two swapped the result would lie in [-1, 0].
mn = np.amin(max_attr, axis=(1, 2), keepdims=True)
mx = np.amax(max_attr, axis=(1, 2), keepdims=True)
value_range = mx - mn
# A completely flat map has zero range; divide by 1 there to get zeros, not NaNs.
value_range[value_range == 0] = 1
scaled = (max_attr - mn) / value_range

plot_images(scaled)
blured = [gaussian_filter(attr_map, 10) for attr_map in scaled]
plot_images(blured)
from torchvision import models

# Switch to evaluation mode, as is done for resnet18 earlier in this notebook;
# without it the network stays in training mode, so layers that behave
# differently during training (batch normalisation, dropout) are not in their
# inference configuration.
mobilenet = models.mobilenet_v2(pretrained=True).eval()
# Attribute with respect to the last module of the feature extractor.
ggc = GuidedGradCam(mobilenet, mobilenet.features[-1])

# One attribution array per ball class.  Not squeezed, so the batch axis
# survives even when there is only a single sample image.
attributions = [
    ggc.attribute(X, target=int(label)).detach().numpy()
    for label in labels_to_use
]

# Strongest attribution over the classes, averaged over axis 1 of the result
# (presumably the colour channels -- confirm against the attribution shape).
max_attr = np.max(attributions, axis=0).mean(axis=1)

# Min-max scale each image to 0-1.  `mn` is the per-image minimum and `mx` the
# maximum; with the two swapped the result would lie in [-1, 0].
mn = np.amin(max_attr, axis=(1, 2), keepdims=True)
mx = np.amax(max_attr, axis=(1, 2), keepdims=True)
value_range = mx - mn
# A completely flat map has zero range; divide by 1 there to get zeros, not NaNs.
value_range[value_range == 0] = 1
scaled = (max_attr - mn) / value_range

plot_images(scaled)
blured = [gaussian_filter(attr_map, 10) for attr_map in scaled]
plot_images(blured)