Rectangular Separability Filters (RSF) for circular object detection.
%load_ext autoreload
%autoreload 2
# Demo cell: box-blur a random single-channel image with two kernel sizes
# and show the input next to both results.
X = t.rand((1, 500, 500))

N1 = 15  # side length of the small box kernel
N2 = 50  # side length of the large box kernel

# Uniform kernels: N*N entries of 1/N**2 each, so every kernel sums to one.
W1 = t.tensor([[1 / N1**2 for _ in range(N1)] for _ in range(N1)])
W2 = t.tensor([[1 / N2**2 for _ in range(N2)] for _ in range(N2)])

# conv2d returns a 4-D result; squeeze(0) drops its leading dimension again.
Y1, Y2 = (conv2d(X, kernel, padding=1).squeeze(0) for kernel in (W1, W2))
plot_images([X, Y1, Y2])
def RecSum(SAT, h, w):
    """Combine four taps of ``SAT`` spaced ``h`` rows and ``w`` columns apart.

    The taps are weighted +1 on the main diagonal and -1 on the
    anti-diagonal of a 2x2 kernel and applied through ``conv2d`` with
    ``dilation=(h, w)``. When ``SAT`` is a summed-area table this is the
    usual four-corner rule for the sum over an ``h`` x ``w`` rectangle.

    Args:
        SAT: Input forwarded unchanged to ``conv2d``.
        h: Row spacing between the taps.
        w: Column spacing between the taps.

    Returns:
        Whatever ``conv2d`` returns for this kernel and dilation.
    """
    corner_signs = [[1, -1], [-1, 1]]
    return conv2d(SAT, corner_signs, dilation=(h, w))
def integral_image(X):
    """Return the integral image (summed-area table) of ``X``.

    Only the last two axes (rows, columns) are accumulated. Any leading
    axes, such as batch or channel, are left independent of each other.
    The earlier loop over ``range(1, X.ndim)`` skipped the row axis for a
    plain 2-D image and accumulated across the channel axis for
    (N, C, H, W) input; results for (B, H, W) and (N, 1, H, W) inputs are
    unchanged.

    Args:
        X: Array-like object exposing ``ndim`` and ``cumsum(axis)``, with at
            least two dimensions.

    Returns:
        An object of the same shape as ``X`` where entry ``[..., i, j]`` is
        the sum of ``X[..., :i + 1, :j + 1]``.

    Raises:
        ValueError: If ``X`` has fewer than two dimensions.
    """
    if X.ndim < 2:
        raise ValueError(
            f"integral_image expects at least 2 dimensions, got {X.ndim}"
        )
    # Cumulative sum down the rows, then across the columns.
    for axis in (-2, -1):
        X = X.cumsum(axis)
    return X
def conv2d(X, W, normalize_weights=True, **kwargs):
    """Convolve every channel of ``X`` with one shared 2-D kernel ``W``.

    The kernel is flipped along both spatial axes before being handed to
    ``F.conv2d`` (which computes a cross-correlation), so the result is a
    true convolution. The same kernel is applied to each channel
    independently via ``groups``.

    Args:
        X: Input image(s) as a tensor, NumPy array or nested list. Accepted
            layouts are (H, W), (C, H, W) and (N, C, H, W); the first two
            are promoted to 4-D by adding leading singleton dimensions.
        W: Kernel whose last two dimensions are (kh, kw); it must hold
            exactly ``kh * kw`` elements.
        normalize_weights: Accepted for interface compatibility; this
            function does not read it.
        **kwargs: Forwarded to ``torch.nn.functional.conv2d`` (for example
            ``padding`` or ``dilation``).

    Returns:
        A float32 tensor of shape (N, C, H_out, W_out).

    Raises:
        ValueError: If ``X`` does not have 2, 3 or 4 dimensions.
    """
    # as_tensor accepts any dtype and device, unlike the legacy
    # torch.FloatTensor constructor, and avoids a copy when nothing changes.
    X = torch.as_tensor(X, dtype=torch.float32)
    W = torch.as_tensor(W, dtype=torch.float32, device=X.device)

    # Promote the input to (N, C, H, W).
    if X.ndim == 2:
        X = X[None, None]
    elif X.ndim == 3:
        X = X[None]
    elif X.ndim != 4:
        raise ValueError(
            f"conv2d expects a 2-D, 3-D or 4-D input, got {X.ndim} dimensions"
        )

    # One flipped copy of the kernel per channel (depthwise convolution).
    channels = X.shape[1]
    kh, kw = W.shape[-2:]
    kernel = W.reshape(1, 1, kh, kw).flip((2, 3)).repeat(channels, 1, 1, 1)

    return F.conv2d(X, kernel, groups=channels, **kwargs)
# Separability filter: pads and rotates X by -angle, derives the region
# statistics P1, P2, P and Psq, forms the map sepmap = Sb / St, rotates the
# map back by +angle, center-crops it and zeroes inf / NaN / negative entries.
#
# NOTE(review): this definition is not valid Python as written and looks
# like an unfinished refactor; the individual problems are flagged inline.
# The indentation of this file has been lost, so where the `if` / `for` /
# `elif` bodies end cannot be told from here -- TODO confirm against the
# notebook before fixing.
# NOTE(review): the calls further down this file pass three positional
# arguments (X, angle, size), which does not match the (t1, t2, t3)
# parameters declared here.
def seperability_filter(X, angle, t1, t2, t3, gpu=True):
# NOTE(review): `=` is an assignment, not a comparison, so this line is a
# syntax error. `==` was presumably intended (t1 an exact multiple of t2)
# -- confirm.
assert t1//t2 = t1/t2
# R is not read again anywhere in this function.
R = t1 // t2
# H is the second-to-last and W the last dimension of X.
H, W = X.shape[-2:]
# Image diagonal, rounded up to an int, then floor-halved.
diag = int(np.ceil(np.sqrt(W**2 + H**2))) // 2
# NOTE(review): padw is derived from H and padh from W, while F.pad reads
# its tuple as (left, right, top, bottom) -- confirm the crossing is intended.
padw = (H-diag)//2
padh = (W-diag)//2
X = F.pad(X, (padw,padw,padh, padh), mode='constant')
# Rotate by -angle here; the result map is rotated by +angle further down.
X = TF.rotate(X, angle=-angle, expand=False)
if not gpu:
# Integral images of X and of X squared.
II = integral_image(X)
II_sq = integral_image(torch.pow(X, 2))
# NOTE(review): `2_t2` is not a valid token (an operator between `2` and
# `t2` seems to be missing), and `itertools` is not among the imports
# visible in this file.
for x, y in itertools.product(range(W-t3), range(H-(t1*2_t2*2))):
# NOTE(review): RecSum is defined in this file with three parameters
# (SAT, h, w) but receives five arguments here; `t` is neither a parameter
# nor assigned in this function.
P1 = RecSum(II, x, y, x+t1, t+t3)
# NOTE(review): the next three lines are bare names with no assignment --
# placeholders, presumably; P2 / P / Psq are never computed on this path.
P2
P
Psq
elif gpu:
# NOTE(review): II and II_sq are assigned only under `if not gpu` above,
# so they are undefined when this branch runs.
t_area_sum = RecSum(II, 2*t1+2*t2, t3)
t_area_sum_sq = RecSum(II_sq, 2*t1+2*t2, t3)
# Four-tap column kernels: P1 averages the two middle taps, P2 the two
# outer taps, P (and Psq, on the squared sums) all four taps.
P1_w = np.array([[0],[1],[1],[0]]) / 2
P1 = conv2d(t_area_sum, P1_w, dilation=(t1, t3))
P2_w = np.array([[1],[0],[0],[1]]) / 2
P2 = conv2d(t_area_sum, P2_w, dilation=(t2, t3))
P_w = np.array([[1],[1],[1],[1]]) / 4
# NOTE(review): `(t)` is just `t`, not a tuple, and `t` is undefined in
# this function. P1 / P2 / P also use three different dilations.
P = conv2d(t_area_sum, P_w, dilation=(t))
Psq = conv2d(t_area_sum_sq, P_w, dilation=(t))
# Weights for the two regions; these also depend on the undefined `t`.
N1 = t*2 * t
N2 = t*t
N = N1 + N2
# Sb: weighted squared deviations of P1 and P2 from P.
Sb = (N1/N)*torch.pow((P1 - P),2) + (N2/N)*torch.pow((P2 - P),2)
# NOTE(review): a variance is normally E[x^2] - E[x]^2; here the order is
# P^2 - Psq -- confirm the sign.
St = torch.pow(P, 2) - Psq
sepmap = Sb/St
# Undo the initial rotation, then crop to (H-4*t, W-t).
sepmap = TF.rotate(sepmap, angle=angle, expand=False)
sepmap = TF.center_crop(sepmap, (H-4*t, W-t))
# Zero out non-finite entries (such as those produced when St is zero)
# and negative entries.
sepmap[sepmap.isinf()] = 0
sepmap[sepmap.isnan()] = 0
sepmap[sepmap<0] = 0
return sepmap
import glob
import torch
import torch
import torchvision.transforms.functional as TF
import torch.nn.functional as F
import numpy as np
from torchvision import transforms
from PIL import Image
from RSF.utils import *
# Load every sample JPEG, convert each one to a grayscale tensor of a common
# size, and stack them into a single batch X.
#
# glob only recurses when "**" is a whole path component; the earlier
# pattern "**.jpg" matched like "*.jpg" and never descended into
# sub-directories despite recursive=True. Paths are sorted so the batch
# order is reproducible across runs and platforms.
images = [
    Image.open(path)
    for path in sorted(glob.glob('../data/samples/**/*.jpg', recursive=True))
]
if not images:
    # torch.stack on an empty list fails with a much less helpful message.
    raise FileNotFoundError("no .jpg files found under ../data/samples")

c, h, w = 1, 300, 400
image_transform = transforms.Compose([
    transforms.Grayscale(),
    transforms.ToTensor(),
    transforms.Resize(w),
    transforms.CenterCrop((h, w)),
])
X = torch.stack([image_transform(im) for im in images])
plot_images(X)
# Single-scale run: apply the filter once at angle 0 with size 60 and plot it.
# NOTE(review): this call passes three positional arguments; check it against
# the current seperability_filter signature.
H, W = X.size(-2), X.size(-1)
sepmap = torch.zeros(H, W)  # placeholder, overwritten by the filter output
# Disabled multi-scale variant (running max over several sizes), kept for
# reference:
# for i in range(5,6):
# sepmap = torch.max(sepmap, F.interpolate(seperability_filter(X, 0, i), size=(H,W)))
# plot_images(sepmap)
sepmap = seperability_filter(X, 0, 60)
plot_images(sepmap)
%%time
# from scipy.stats.mstats import gmean
# Run the filter with size T at four orientations 45 degrees apart, stack
# the four maps along a new leading axis, and plot their gmean
# (presumably the geometric mean -- gmean is not defined in this file).
T = 15
P = torch.stack(
    [seperability_filter(X, angle, T) for angle in (0, 45, 90, 135)]
)
plot_images(gmean(P))