Question

This is an incomplete Python snippet for convolution with the FFT.

I want to modify it so that it supports 1) valid convolution and 2) full convolution.

import numpy as np
from numpy.fft import fft2, ifft2

image = np.array([[3,2,5,6,7,8],
                  [5,4,2,10,8,1]])

kernel = np.array([[4,5],
                   [1,2]])

fft_size =  # what size should I put here for,
            # 1) valid convolution
            # 2) full convolution

convolution = ifft2(fft2(image, fft_size) * fft2(kernel, fft_size))

Thank you in advance.


Solution

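In the case of 1-dimensional arrays x and y with lengths L and M, respectively, you need to zero-pad both inputs to an FFT size of at least L + M - 1 for mode="full". With a smaller size, the product of the FFTs is a circular convolution, so the ends of the result wrap around onto each other. For the 2-d case, apply that rule to each axis.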
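
A minimal 1-d sketch of this rule, using the first rows of the question's arrays and `np.convolve` as an independent reference; `fft_full_1d` is a hypothetical helper name:

```python
import numpy as np
from numpy.fft import fft, ifft


def fft_full_1d(x, y):
    """Full linear convolution of two real 1-d arrays via zero-padded FFTs."""
    n = len(x) + len(y) - 1          # L + M - 1 samples in the full result
    # fft(a, n) zero-pads a to length n before transforming, so the
    # product of the spectra is a linear, not circular, convolution.
    z = ifft(fft(x, n) * fft(y, n))
    return z.real                    # real inputs: drop imaginary round-off


x = np.array([3, 2, 5, 6, 7, 8])     # first row of the question's image
y = np.array([4, 5])                 # first row of the question's kernel
print(fft_full_1d(x, y))             # approx. 12, 23, 30, 49, 58, 67, 40
print(np.convolve(x, y, mode="full"))
```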

Using numpy, you can compute the size in the 2-d case with

np.array(x.shape) + np.array(y.shape) - 1
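
To see why the padding matters, the sketch below (variable names are illustrative) computes the padded size for the question's data and compares it with an unpadded transform. With the FFT size equal to `image.shape`, the FFT product is a circular convolution, which equals the full result folded back onto a 2 x 6 grid:

```python
import numpy as np
from numpy.fft import fft2, ifft2

image = np.array([[3, 2, 5, 6, 7, 8],
                  [5, 4, 2, 10, 8, 1]])
kernel = np.array([[4, 5],
                   [1, 2]])

# Padded size for "full" mode: len_x + len_y - 1 along each axis.
fft_size = np.array(image.shape) + np.array(kernel.shape) - 1
print(fft_size)  # [3 7]
full = ifft2(fft2(image, fft_size) * fft2(kernel, fft_size)).real

# Too little padding: the result is a circular convolution.
circular = ifft2(fft2(image, image.shape) * fft2(kernel, image.shape)).real

# Fold the full result modulo the image shape to reproduce the wrap-around.
wrapped = np.zeros(image.shape)
for i in range(full.shape[0]):
    for j in range(full.shape[1]):
        wrapped[i % image.shape[0], j % image.shape[1]] += full[i, j]
print(np.allclose(circular, wrapped))  # True
```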

To implement the "valid" mode, compute the "full" result and then slice out the valid part. For 1-d, assuming L >= M, the valid data is the L - M + 1 elements in the center of the full data, i.e. the full result with M - 1 elements trimmed from each end. Again, apply the same rule to each axis in the 2-d case.
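
A short 1-d sketch of that slicing, again checked against `np.convolve`; the variable names are only for illustration:

```python
import numpy as np
from numpy.fft import fft, ifft

x = np.array([3.0, 2.0, 5.0, 6.0, 7.0, 8.0])  # length L = 6
y = np.array([4.0, 5.0])                      # length M = 2, so L >= M
L, M = len(x), len(y)

n = L + M - 1
full = ifft(fft(x, n) * fft(y, n)).real

# The valid part is the centered block of L - M + 1 samples;
# its start index works out to M - 1.
valid_len = L - M + 1
start = (n - valid_len) // 2
valid = full[start:start + valid_len]
print(valid)                            # approx. 23, 30, 49, 58, 67
print(np.convolve(x, y, mode="valid"))
```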

For example,

import numpy as np
from numpy.fft import fft2, ifft2


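def fftconvolve2d(x, y, mode="full"):
    """
    Convolve two real 2-d numpy arrays using the FFT.

    mode must be "full" or "valid".
    """
    x_shape = np.array(x.shape)
    y_shape = np.array(y.shape)
    # Size of the full linear convolution along each axis: len_x + len_y - 1.
    z_shape = x_shape + y_shape - 1
    # fft2(a, s) zero-pads a to shape s before transforming, so the product
    # of the transforms is a linear (not circular) convolution.
    # .real drops the tiny imaginary round-off left by ifft2.
    z = ifft2(fft2(x, z_shape) * fft2(y, z_shape)).real

    if mode == "valid":
        # A non-empty valid result exists only if either
        # np.all(x_shape >= y_shape) or np.all(y_shape >= x_shape).
        valid_shape = x_shape - y_shape + 1
        if np.any(valid_shape < 1):
            valid_shape = y_shape - x_shape + 1
            if np.any(valid_shape < 1):
                raise ValueError("empty result for valid shape")
        # The valid block sits in the center of the full result.
        start = (z_shape - valid_shape) // 2
        end = start + valid_shape
        z = z[start[0]:end[0], start[1]:end[1]]

    return z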

Here's the function applied to your example data:

In [146]: image
Out[146]: 
array([[ 3,  2,  5,  6,  7,  8],
       [ 5,  4,  2, 10,  8,  1]])

In [147]: kernel
Out[147]: 
array([[4, 5],
       [1, 2]])

In [148]: fftconvolve2d(image, kernel, mode="full")
Out[148]: 
array([[  12.,   23.,   30.,   49.,   58.,   67.,   40.],
       [  23.,   49.,   37.,   66.,  101.,   66.,   21.],
       [   5.,   14.,   10.,   14.,   28.,   17.,    2.]])

In [149]: fftconvolve2d(image, kernel, mode="valid")
Out[149]: array([[  49.,   37.,   66.,  101.,   66.]])
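
Because the FFT route introduces floating-point round-off, it can be reassuring to check it against the definition of convolution directly. A minimal sketch, assuming small inputs (it loops over every element of x); `direct_convolve2d` is a hypothetical helper, not a NumPy function:

```python
import numpy as np


def direct_convolve2d(x, y):
    """Full 2-d linear convolution straight from the definition
    z[i, j] = sum over (m, n) of x[m, n] * y[i - m, j - n].
    Slow; intended only for checking small examples."""
    xr, xc = x.shape
    yr, yc = y.shape
    z = np.zeros((xr + yr - 1, xc + yc - 1), dtype=np.result_type(x, y))
    # Each element x[m, n] adds a copy of y scaled by x[m, n], shifted by (m, n).
    for m in range(xr):
        for n in range(xc):
            z[m:m + yr, n:n + yc] += x[m, n] * y
    return z


image = np.array([[3, 2, 5, 6, 7, 8],
                  [5, 4, 2, 10, 8, 1]])
kernel = np.array([[4, 5],
                   [1, 2]])

full = direct_convolve2d(image, kernel)
print(full)

# Valid part: trim kernel.shape - 1 rows/columns from each border
# (assumes the kernel is no larger than the image along each axis).
kr, kc = kernel.shape
valid = full[kr - 1:full.shape[0] - (kr - 1), kc - 1:full.shape[1] - (kc - 1)]
print(valid)
```

Both results agree with the FFT-based output above, up to the `.0` of the floating-point display.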

More error checking could be added, and the function could be extended to handle complex and n-dimensional arrays. It would also be nice to pad each axis a little further, to a length the FFT handles efficiently (for example, a product of small primes such as 2, 3 and 5). If you made all those enhancements, you might end up with something like scipy.signal.fftconvolve (https://github.com/scipy/scipy/blob/master/scipy/signal/signaltools.py#L210):
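
As a rough sketch of those enhancements (not SciPy's actual implementation), the version below accepts n-d real or complex arrays and pads each axis up to the next length whose only prime factors are 2, 3 and 5 before cropping. `fftconvolve_nd` and `next_smooth_len` are hypothetical helpers written for this example; in SciPy, `scipy.fft.next_fast_len` serves the purpose of the latter.

```python
import numpy as np


def next_smooth_len(n):
    """Smallest integer >= n whose only prime factors are 2, 3 and 5.

    Hypothetical helper; FFT libraries are usually fastest on such lengths.
    """
    n = max(int(n), 1)               # guard: 0 has no factorization
    while True:
        m = n
        for p in (2, 3, 5):
            while m % p == 0:
                m //= p
        if m == 1:
            return n
        n += 1


def fftconvolve_nd(x, y, mode="full"):
    """Sketch: n-d convolution of real or complex arrays via the FFT.

    x and y must have the same number of dimensions; mode is "full" or "valid".
    """
    x = np.asarray(x)
    y = np.asarray(y)
    if x.ndim != y.ndim:
        raise ValueError("x and y must have the same number of dimensions")
    if mode not in ("full", "valid"):
        raise ValueError('mode must be "full" or "valid"')

    x_shape = np.array(x.shape)
    y_shape = np.array(y.shape)
    full_shape = x_shape + y_shape - 1

    # Pad every axis up to a fast length; the extra zeros do not change the
    # linear convolution and are cropped away below.
    fft_shape = [next_smooth_len(n) for n in full_shape]
    axes = list(range(x.ndim))
    spectrum = np.fft.fftn(x, fft_shape, axes) * np.fft.fftn(y, fft_shape, axes)
    z = np.fft.ifftn(spectrum, axes=axes)
    z = z[tuple(slice(0, int(n)) for n in full_shape)]

    # Keep the complex result only if an input was complex.
    if not (np.iscomplexobj(x) or np.iscomplexobj(y)):
        z = z.real

    if mode == "valid":
        # Same centering rule as fftconvolve2d, applied to every axis.
        valid_shape = x_shape - y_shape + 1
        if np.any(valid_shape < 1):
            valid_shape = y_shape - x_shape + 1
            if np.any(valid_shape < 1):
                raise ValueError("empty result for valid shape")
        start = (full_shape - valid_shape) // 2
        z = z[tuple(slice(int(s), int(s + v))
                    for s, v in zip(start, valid_shape))]
    return z


image = np.array([[3, 2, 5, 6, 7, 8],
                  [5, 4, 2, 10, 8, 1]])
kernel = np.array([[4, 5],
                   [1, 2]])
print(np.round(fftconvolve_nd(image, kernel, mode="valid"), 10))  # 49, 37, 66, 101, 66
```

For purely real inputs, `np.fft.rfftn` and `np.fft.irfftn` would roughly halve the transform work; `fftn` is used here so that the same code path also covers complex arrays.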

In [152]: from scipy.signal import fftconvolve

In [153]: fftconvolve(image, kernel, mode="full")
Out[153]: 
array([[  12.,   23.,   30.,   49.,   58.,   67.,   40.],
       [  23.,   49.,   37.,   66.,  101.,   66.,   21.],
       [   5.,   14.,   10.,   14.,   28.,   17.,    2.]])

In [154]: fftconvolve(image, kernel, mode="valid")
Out[154]: array([[  49.,   37.,   66.,  101.,   66.]])
Licensed under: CC-BY-SA with attribution