In the case of 1-dimensional arrays `x` and `y` with lengths `L` and `M`, respectively, you need to pad the FFT to size `L + M - 1` for `mode="full"`. For the 2-d case, apply that rule to each axis.
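The padding rule can be checked in 1-d with a minimal numpy-only sketch. `fft_convolve_full_1d` is a hypothetical helper name, and `rfft`/`irfft` are used only because the inputs are assumed real. The sketch also shows the circular wraparound that appears when the transforms are not padded to `L + M - 1`:

```python
import numpy as np


def fft_convolve_full_1d(x, y):
    """Linear ("full") convolution of two real 1-d arrays via the FFT."""
    n = len(x) + len(y) - 1  # L + M - 1
    # rfft(..., n) zero-pads its input to length n before transforming.
    return np.fft.irfft(np.fft.rfft(x, n) * np.fft.rfft(y, n), n)


x = np.array([3.0, 2.0, 5.0, 6.0])  # L = 4
y = np.array([4.0, 5.0])            # M = 2

print(np.round(fft_convolve_full_1d(x, y), 6))  # [12. 23. 30. 49. 30.]

# Padding only to length L gives a circular convolution instead: the
# last full-mode sample (30) wraps around and is added to the first (12).
n = len(x)
circular = np.fft.irfft(np.fft.rfft(x, n) * np.fft.rfft(y, n), n)
print(np.round(circular, 6))  # [42. 23. 30. 49.]
```

The padded result agrees with `np.convolve(x, y)`, which computes the same full linear convolution directly.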
Using numpy, you can compute the size in the 2-d case with `np.array(x.shape) + np.array(y.shape) - 1`.
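For a 2×6 array and a 2×2 array (the shapes of the example data used in this answer), that expression gives a 3×7 padded shape. The sketch below also illustrates that passing the shape as the `s` argument of `np.fft.fft2` zero-pads each axis at the end, which is exactly the padding the rule calls for:

```python
import numpy as np

x = np.array([[3, 2, 5, 6, 7, 8],
              [5, 4, 2, 10, 8, 1]])
y = np.array([[4, 5],
              [1, 2]])

z_shape = np.array(x.shape) + np.array(y.shape) - 1
print(z_shape)  # [3 7]

# The `s` argument of fft2 zero-pads each axis at the end, so transforming
# with s=z_shape matches transforming an explicitly padded copy of x.
padded = np.zeros(tuple(z_shape))
padded[:x.shape[0], :x.shape[1]] = x
print(np.allclose(np.fft.fft2(x, s=tuple(z_shape)), np.fft.fft2(padded)))  # True
```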
To implement the "valid" mode, you'll have to compute the "full" result and then slice out the valid part. For 1-d, assuming `L >= M`, the valid data is the `L - M + 1` elements in the center of the full data (they start at index `M - 1`). Again, apply the same rule to each axis in the 2-d case.
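A small 1-d sketch of that slicing follows. `valid_from_full_1d` is a hypothetical helper, and the full result comes from `np.convolve` here only to isolate the slicing step; the centered `L - M + 1` samples should agree with `np.convolve(..., mode="valid")`:

```python
import numpy as np


def valid_from_full_1d(full, L, M):
    """Slice the "valid" part out of a length L + M - 1 "full" convolution.

    Assumes L >= M. The valid part has L - M + 1 samples and starts at
    index M - 1, i.e. it is centered in the full result.
    """
    n_valid = L - M + 1
    start = (len(full) - n_valid) // 2  # equals M - 1
    return full[start:start + n_valid]


x = np.array([3.0, 2.0, 5.0, 6.0, 7.0, 8.0])  # L = 6
y = np.array([4.0, 5.0, 1.0])                 # M = 3

full = np.convolve(x, y, mode="full")          # 8 samples
print(valid_from_full_1d(full, len(x), len(y)))  # [33. 51. 63. 73.]
```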
For example,
```python
import numpy as np
from numpy.fft import fft2, ifft2


def fftconvolve2d(x, y, mode="full"):
    """
    x and y must be real 2-d numpy arrays.

    mode must be "full" or "valid".
    """
    if mode not in ("full", "valid"):
        raise ValueError('mode must be "full" or "valid"')
    x_shape = np.array(x.shape)
    y_shape = np.array(y.shape)
    # Pad each axis to the length of the full linear convolution.
    z_shape = x_shape + y_shape - 1
    z = ifft2(fft2(x, tuple(z_shape)) * fft2(y, tuple(z_shape))).real
    if mode == "valid":
        # To compute a valid shape, either np.all(x_shape >= y_shape) or
        # np.all(y_shape >= x_shape).
        valid_shape = x_shape - y_shape + 1
        if np.any(valid_shape < 1):
            valid_shape = y_shape - x_shape + 1
            if np.any(valid_shape < 1):
                raise ValueError("empty result for valid shape")
        # The valid part sits in the center of the full result.
        start = (z_shape - valid_shape) // 2
        end = start + valid_shape
        z = z[start[0]:end[0], start[1]:end[1]]
    return z
```
Here's the function applied to your example data:
```
In [146]: image
Out[146]:
array([[ 3, 2, 5, 6, 7, 8],
       [ 5, 4, 2, 10, 8, 1]])

In [147]: kernel
Out[147]:
array([[4, 5],
       [1, 2]])

In [148]: fftconvolve2d(image, kernel, mode="full")
Out[148]:
array([[ 12., 23., 30., 49., 58., 67., 40.],
       [ 23., 49., 37., 66., 101., 66., 21.],
       [ 5., 14., 10., 14., 28., 17., 2.]])

In [149]: fftconvolve2d(image, kernel, mode="valid")
Out[149]: array([[ 49., 37., 66., 101., 66.]])
```
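To see where the "valid" numbers come from, a brute-force sketch (plain loops, for illustration only, not something to use for large inputs) recomputes them by sliding the kernel, flipped along both axes as convolution requires, over every 2×2 window that lies entirely inside the image:

```python
import numpy as np

image = np.array([[3, 2, 5, 6, 7, 8],
                  [5, 4, 2, 10, 8, 1]])
kernel = np.array([[4, 5],
                   [1, 2]])

# Convolution flips the kernel along both axes before sliding it.
flipped = kernel[::-1, ::-1]
kh, kw = kernel.shape
out_h = image.shape[0] - kh + 1
out_w = image.shape[1] - kw + 1

valid = np.empty((out_h, out_w))
for i in range(out_h):
    for j in range(out_w):
        # Sum of the elementwise product of one window with the flipped kernel.
        valid[i, j] = np.sum(image[i:i + kh, j:j + kw] * flipped)

print(valid)  # values 49, 37, 66, 101, 66, matching Out[149]
```

For example, the first window `[[3, 2], [5, 4]]` times the flipped kernel `[[2, 1], [5, 4]]` gives 6 + 2 + 25 + 16 = 49.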
More error checking could be added, and the function could be modified to handle complex and n-dimensional arrays. It would also be nice if the padded size were chosen to make the FFT calculation more efficient (FFTs are typically fastest when the transform length has only small prime factors). If you made all those enhancements, you might end up with something like `scipy.signal.fftconvolve` (https://github.com/scipy/scipy/blob/master/scipy/signal/signaltools.py#L210):
```
In [152]: from scipy.signal import fftconvolve

In [153]: fftconvolve(image, kernel, mode="full")
Out[153]:
array([[ 12., 23., 30., 49., 58., 67., 40.],
       [ 23., 49., 37., 66., 101., 66., 21.],
       [ 5., 14., 10., 14., 28., 17., 2.]])

In [154]: fftconvolve(image, kernel, mode="valid")
Out[154]: array([[ 49., 37., 66., 101., 66.]])
```
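If SciPy isn't available, the enhancements listed above (n-dimensional and complex inputs, padding to a faster FFT size) can be sketched in numpy alone. This is one possible approach, not SciPy's implementation: `fftconvolve_nd` and `next_fast_len_235` are hypothetical names, and rounding each padded length up to a "5-smooth" number (prime factors 2, 3 and 5 only) is an assumed heuristic for FFT speed. SciPy provides `scipy.fft.next_fast_len` for the same purpose.

```python
import numpy as np


def next_fast_len_235(n):
    """Hypothetical helper: smallest integer >= n with prime factors 2, 3, 5."""
    n = max(int(n), 1)
    while True:
        m = n
        for p in (2, 3, 5):
            while m % p == 0:
                m //= p
        if m == 1:
            return n
        n += 1


def fftconvolve_nd(x, y, mode="full"):
    """Hypothetical sketch: n-d FFT convolution for real or complex inputs."""
    x = np.asarray(x)
    y = np.asarray(y)
    if x.ndim != y.ndim:
        raise ValueError("x and y must have the same number of dimensions")
    if mode not in ("full", "valid"):
        raise ValueError('mode must be "full" or "valid"')
    x_shape = np.array(x.shape)
    y_shape = np.array(y.shape)
    full_shape = x_shape + y_shape - 1
    # Pad each axis past the full length to a 5-smooth size.
    fast_shape = tuple(next_fast_len_235(n) for n in full_shape)
    axes = tuple(range(x.ndim))
    z = np.fft.ifftn(np.fft.fftn(x, fast_shape, axes)
                     * np.fft.fftn(y, fast_shape, axes), axes=axes)
    # Crop the extra padding, leaving the "full" result.
    z = z[tuple(slice(0, int(n)) for n in full_shape)]
    # Only drop the imaginary part when both inputs are real.
    if not (np.iscomplexobj(x) or np.iscomplexobj(y)):
        z = z.real
    if mode == "valid":
        valid_shape = x_shape - y_shape + 1
        if np.any(valid_shape < 1):
            valid_shape = y_shape - x_shape + 1
            if np.any(valid_shape < 1):
                raise ValueError("empty result for valid shape")
        start = (full_shape - valid_shape) // 2
        z = z[tuple(slice(int(s), int(s + v))
                    for s, v in zip(start, valid_shape))]
    return z


image = np.array([[3, 2, 5, 6, 7, 8],
                  [5, 4, 2, 10, 8, 1]])
kernel = np.array([[4, 5],
                   [1, 2]])
print(fftconvolve_nd(image, kernel, mode="valid"))
```

For the example data the full shape is 3×7, so the transforms run at 3×8 (8 is the next 5-smooth length after 7) and the result is cropped back to 3×7 before the valid part is sliced out.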