Rectangular Separability Filters (RSF) for circular object detection.
%load_ext autoreload
%autoreload 2

conv2d[source]

conv2d(X, W, normalize_weights=True, **kwargs)

# Random single-channel test image of shape (1, 500, 500).
X = t.rand((1, 500, 500))

# Two box-blur kernels: every tap is 1/N**2, so each kernel sums to one.
N1 = 15  # Box blur with kernel size 15
W1 = t.full((N1, N1), 1 / N1**2)
N2 = 50  # Box blur with kernel size 50
W2 = t.full((N2, N2), 1 / N2**2)

# Blur with each kernel, drop the leading batch axis and show the results
# next to the input.
Y1 = conv2d(X, W1, padding=1).squeeze(0)
Y2 = conv2d(X, W2, padding=1).squeeze(0)
plot_images([X, Y1, Y2])

cvtIntegralImage[source]

cvtIntegralImage(X)

cvtCombSimpRectFilter[source]

cvtCombSimpRectFilter(I, P, sh)

tmpFnc[source]

tmpFnc(I, P, bh, bw, sh, sw, dh, dw)

def RecSum(SAT, h, w):
    """Combine four taps of ``SAT`` that sit ``h`` rows and ``w`` columns apart.

    ``SAT`` is convolved (through this file's ``conv2d``) with the 2x2 kernel
    ``[[1, -1], [-1, 1]]`` dilated by ``(h, w)``.  The kernel is unchanged by
    the 180-degree flip that ``conv2d`` applies, so each output element is::

        SAT[i, j] - SAT[i, j + w] - SAT[i + h, j] + SAT[i + h, j + w]

    When ``SAT`` is a summed-area table this is the sum of the source image
    over an ``h`` x ``w`` rectangle.

    Args:
        SAT: 2-D, 3-D or 4-D array/tensor accepted by ``conv2d``.
        h: Vertical distance between taps (rectangle height).
        w: Horizontal distance between taps (rectangle width).

    Returns:
        The 4-D tensor produced by ``conv2d``.
    """
    corner_signs = [[1, -1], [-1, 1]]
    return conv2d(SAT, corner_signs, dilation=(h, w))

def integral_image(X):
    """Return the summed-area table (integral image) of ``X``.

    The cumulative sum runs over the last two axes only (rows, then columns).
    Any leading axes (batch, channel) are treated as independent images, so a
    ``(N, C, H, W)`` batch is no longer accumulated across channels.  For
    ``(B, H, W)`` input the result equals cumulative sums over axes 1 and 2.

    Args:
        X: Array or tensor with at least two dimensions whose last two axes
            are the image height and width.

    Returns:
        An object of the same shape as ``X`` where element ``[..., i, j]`` is
        the sum of ``X[..., :i + 1, :j + 1]``.

    Raises:
        ValueError: If ``X`` has fewer than two dimensions.
    """
    if X.ndim < 2:
        raise ValueError(
            f"integral_image expects at least 2 dimensions, got {X.ndim}"
        )
    # Negative axes keep the leading batch/channel layout irrelevant.
    return X.cumsum(-2).cumsum(-1)

def conv2d(X, W, normalize_weights=True, **kwargs):
    """Convolve every channel of ``X`` with the same 2-D kernel ``W``.

    This is a true convolution: the kernel is flipped along both spatial axes
    before being handed to ``F.conv2d`` (which computes cross-correlation).
    The kernel is replicated once per channel and applied depth-wise
    (``groups`` equals the channel count), so channels never mix.

    Args:
        X: Input image(s) as a tensor, array or nested sequence.  A 2-D input
            is treated as ``(H, W)``, a 3-D input as ``(C, H, W)`` and a 4-D
            input as ``(N, C, H, W)``.  Values are converted to float32.
        W: 2-D kernel (anything reshapeable to ``(1, 1, h, w)`` using its last
            two axes).  Converted to float32 on the same device as ``X``.
        normalize_weights: Accepted for interface compatibility; currently
            ignored (the kernel is used as given).
        **kwargs: Forwarded to ``torch.nn.functional.conv2d`` (for example
            ``stride``, ``padding``, ``dilation``).  ``groups`` is set by this
            function and must not be passed.

    Returns:
        A float32 tensor of shape ``(N, C, H_out, W_out)``.

    Raises:
        ValueError: If ``X`` is not 2-D, 3-D or 4-D, or ``W`` has fewer than
            two dimensions.
    """
    # as_tensor accepts lists, NumPy arrays and tensors of any dtype/device,
    # unlike the legacy torch.FloatTensor constructor.
    X = torch.as_tensor(X, dtype=torch.float32)
    W = torch.as_tensor(W, dtype=torch.float32, device=X.device)

    # Bring the input to (N, C, H, W).
    if X.ndim == 2:
        X = X[None, None]
    elif X.ndim == 3:
        X = X[None]
    elif X.ndim != 4:
        raise ValueError(f"X must have 2, 3 or 4 dimensions, got {X.ndim}")

    if W.ndim < 2:
        raise ValueError(f"W must have at least 2 dimensions, got {W.ndim}")

    # One copy of the (flipped) kernel per channel for the depth-wise conv.
    channels = X.shape[1]
    kh, kw = W.shape[-2:]
    kernel = W.reshape(1, 1, kh, kw).flip(2, 3).repeat(channels, 1, 1, 1)

    return F.conv2d(X, kernel, groups=channels, **kwargs)

def _window_sums(img, h, w):
    """Sum ``img`` over sliding h-by-w windows using a summed-area table.

    Output element ``[..., i, j]`` is the sum of
    ``img[..., i + 1:i + 1 + h, j + 1:j + 1 + w]``; the result is ``h`` rows
    and ``w`` columns smaller than ``img``.
    """
    sat = img.cumsum(-2).cumsum(-1)
    return (sat[..., h:, w:] - sat[..., :-h, w:]
            - sat[..., h:, :-w] + sat[..., :-h, :-w])


def seperability_filter(X, angle, t1, t2=None, t3=None, gpu=True):
    """Compute a rectangular separability map of ``X`` at one orientation.

    Every window is ``2 * t1 + 2 * t2`` rows tall and ``t3`` columns wide.
    The central ``2 * t1`` rows form the inner region; the ``t2`` rows above
    and the ``t2`` rows below form the outer region.  The response is the
    between-region variance divided by the total variance of the window,
    which lies in ``[0, 1]`` (1 means the two regions are perfectly
    separated, 0 means no separation or a flat window).

    For a non-zero ``angle`` the image is zero-padded to a square whose side
    is the image diagonal, rotated by ``-angle``, filtered, and the response
    is rotated back by ``angle`` and centre-cropped.

    Args:
        X: Image tensor whose last two axes are height and width.  Leading
            axes are treated as independent images.
        angle: Filter orientation in degrees.
        t1: Half-height of the inner region, positive integer.
        t2: Height of each outer band, positive integer.  Defaults to ``t1``.
        t3: Window width, positive integer.  Defaults to ``t1``.
        gpu: Accepted for interface compatibility; ignored.  The computation
            is vectorised and runs on whatever device ``X`` is on.

    Returns:
        Tensor with the leading axes of ``X`` and spatial size
        ``(H - 2 * t1 - 2 * t2, W - t3)``, in the dtype of ``X`` (float32 for
        non-floating input).

    Raises:
        ValueError: If a size is not a positive integer or the window does
            not fit inside the image.
    """
    t2 = t1 if t2 is None else t2
    t3 = t1 if t3 is None else t3
    for label, size in (("t1", t1), ("t2", t2), ("t3", t3)):
        if size != int(size) or size < 1:
            raise ValueError(f"{label} must be a positive integer, got {size!r}")
    t1, t2, t3 = int(t1), int(t2), int(t3)

    win_h, win_w = 2 * (t1 + t2), t3
    H, W = X.shape[-2:]
    out_h, out_w = H - win_h, W - win_w
    if out_h < 1 or out_w < 1:
        raise ValueError(
            f"a {win_h}x{win_w} window does not fit inside a {H}x{W} image"
        )

    # A rotation by a whole number of turns is the identity, so padding and
    # resampling are skipped in that case.
    rotated = angle % 360 != 0
    if rotated:
        # Pad to a square with the diagonal as side so that no image content
        # leaves the frame for any rotation angle.
        side = int(np.ceil(np.sqrt(W**2 + H**2)))
        pad_h, pad_w = side - H, side - W
        X = F.pad(
            X,
            (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2),
            mode='constant',
        )
        X = TF.rotate(X, angle=-angle, expand=False)

    out_dtype = X.dtype if X.is_floating_point() else torch.float32
    # Summed-area tables of squares cancel badly in float32.
    img = X.to(torch.float64)

    # Pixel counts of the inner region, the outer region and the window.
    N1 = 2 * t1 * t3
    N2 = 2 * t2 * t3
    N = N1 + N2

    total = _window_sums(img, win_h, win_w)
    total_sq = _window_sums(img * img, win_h, win_w)
    # The inner region starts t2 rows below the top of its window.
    inner = _window_sums(img, 2 * t1, win_w)
    inner = inner[..., t2:inner.shape[-2] - t2, :]

    P = total / N                # window mean
    Psq = total_sq / N           # window mean of squares
    P1 = inner / N1              # inner-region mean
    P2 = (total - inner) / N2    # outer-region mean

    Sb = (N1 / N) * (P1 - P) ** 2 + (N2 / N) * (P2 - P) ** 2
    St = Psq - P ** 2            # variance = E[x^2] - E[x]^2

    # Flat windows (variance at rounding-noise level) and NaN input get a
    # response of zero instead of inf/NaN.
    valid = St > 1e-12 * Psq
    safe_St = torch.where(valid, St, torch.ones_like(St))
    sepmap = torch.where(valid, Sb / safe_St, torch.zeros_like(St))
    sepmap = sepmap.clamp_(0, 1).to(out_dtype)

    if rotated:
        sepmap = TF.rotate(sepmap, angle=angle, expand=False)
        sepmap = TF.center_crop(sepmap, (out_h, out_w))

    return sepmap
import glob
import torch
import torch
import torchvision.transforms.functional as TF
import torch.nn.functional as F
import numpy as np
from torchvision import transforms
from PIL import Image
from RSF.utils import *

# Load every sample photo.  Each file is closed as soon as its pixels have
# been copied into memory.
# NOTE(review): '**.jpg' only matches files directly inside samples/;
# recursive=True needs '**/*.jpg' to descend into sub-folders -- confirm
# which one is intended.
images = []
for path in glob.glob('../data/samples/**.jpg', recursive=True):
    with Image.open(path) as im:
        images.append(im.copy())
if not images:
    raise FileNotFoundError("no .jpg images found under ../data/samples/")

# Convert each photo to a single-channel h x w tensor.
c, h, w = 1, 300, 400
image_transform = transforms.Compose([
    transforms.Grayscale(),
    transforms.ToTensor(),
    transforms.Resize(w),
    transforms.CenterCrop((h, w))
])

X = torch.stack([image_transform(im) for im in images])
plot_images(X)

H, W = X.shape[-2:]
sepmap = torch.zeros((H, W))
# Multi-scale sketch, kept for reference:
# for i in range(5,6):
#     sepmap = torch.max(sepmap, F.interpolate(seperability_filter(X, 0, i, i, i), size=(H,W)))
# plot_images(sepmap)

# Single-scale, axis-aligned filter; all three region sizes are 60 px.
sepmap = seperability_filter(X, 0, 60, 60, 60)
plot_images(sepmap)
torch.Size([9, 1, 300, 400]) 25 75
torch.Size([9, 1, 450, 450])
torch.Size([9, 1, 450, 450])
torch.Size([9, 1, 210, 390])
torch.Size([9, 1, 210, 390])
torch.Size([9, 1, 60, 340]) (60, 340)
%%time
# from scipy.stats.mstats import gmean

# Filter the batch at four orientations with all region sizes equal to T,
# then combine the responses with the geometric mean.
T = 15
angles = (0, 45, 90, 135)
P = torch.stack([seperability_filter(X, angle, T, T, T) for angle in angles])

plot_images(gmean(P))
torch.Size([9, 1, 300, 400]) 40 105
torch.Size([9, 1, 510, 480])
torch.Size([9, 1, 510, 480])
torch.Size([9, 1, 450, 465])
torch.Size([9, 1, 450, 465])
torch.Size([9, 1, 300, 400]) (270, 385)
torch.Size([9, 1, 300, 400]) 40 105
torch.Size([9, 1, 510, 480])
torch.Size([9, 1, 510, 480])
torch.Size([9, 1, 450, 465])
torch.Size([9, 1, 450, 465])
torch.Size([9, 1, 300, 400]) (270, 385)
torch.Size([9, 1, 300, 400]) 40 105
torch.Size([9, 1, 510, 480])
torch.Size([9, 1, 510, 480])
torch.Size([9, 1, 450, 465])
torch.Size([9, 1, 450, 465])
torch.Size([9, 1, 300, 400]) (270, 385)
torch.Size([9, 1, 300, 400]) 40 105
torch.Size([9, 1, 510, 480])
torch.Size([9, 1, 510, 480])
torch.Size([9, 1, 450, 465])
torch.Size([9, 1, 450, 465])
torch.Size([9, 1, 300, 400]) (270, 385)
CPU times: user 2.71 s, sys: 631 ms, total: 3.34 s
Wall time: 1.05 s