
Math & FFT Utilities

Centered FFT functions and mathematical helpers for testing.

Centered FFT

torchlinops.utils.cfftn

cfftn(x, dim=None, norm='ortho')

Compute the centered n-dimensional FFT.

Assumes the origin lies at the center of the array (i.e., that the array has been fftshifted).

PARAMETER DESCRIPTION
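x

Input tensor with its origin at the center of the array.

TYPE: Tensor

dim

The dimensions over which to take the FFT. If None, all dimensions are transformed.

TYPE: tuple[int, ...] DEFAULT: None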

norm

Normalization mode, where n is the number of elements along the transformed dimensions. For the forward transform (cfftn()), these correspond to:

  • "forward" - normalize by 1/n
  • "backward" - no normalization
  • "ortho" - normalize by 1/sqrt(n) (making the FFT orthonormal)

Calling the inverse transform (cifftn()) with the same normalization mode applies an overall normalization of 1/n across the two transforms, which makes cifftn() the exact inverse of cfftn(). Default is "ortho".

TYPE: str, optional DEFAULT: 'ortho'
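To make the normalization table concrete, here is a minimal sketch that re-implements cfftn() and cifftn() locally with torch.fft (it does not import torchlinops). It checks that every mode round-trips when the same mode is used on both sides, and records how each mode scales the L2 norm; only "ortho" should leave it unchanged. The (6, 7) input shape is an arbitrary choice, so n = 42.

```python
import torch
from torch import fft

# Local re-implementations mirroring the listings on this page:
# shift origin to index 0, transform, shift zero frequency back to the center.
def cfftn(x, dim=None, norm="ortho"):
    return fft.fftshift(fft.fftn(fft.ifftshift(x, dim=dim), dim=dim, norm=norm), dim=dim)

def cifftn(x, dim=None, norm="ortho"):
    return fft.fftshift(fft.ifftn(fft.ifftshift(x, dim=dim), dim=dim, norm=norm), dim=dim)

torch.manual_seed(0)
x = torch.randn(6, 7, dtype=torch.complex128)  # n = 42 elements, all dims transformed

ratios = {}
for mode in ("forward", "backward", "ortho"):
    X = cfftn(x, norm=mode)
    # Same mode on both sides: cifftn exactly undoes cfftn.
    assert torch.allclose(cifftn(X, norm=mode), x)
    # ||X|| / ||x||: sqrt(n) for "backward", 1/sqrt(n) for "forward", 1 for "ortho".
    ratios[mode] = (torch.linalg.vector_norm(X) / torch.linalg.vector_norm(x)).item()
    print(mode, ratios[mode])
```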

Source code in src/torchlinops/utils/_fft.py
def cfftn(x, dim=None, norm="ortho"):
    """Compute the centered n-dimensional FFT.

    Assumes the origin lies at the center of the array (i.e., that the array
    has been fftshifted).

    Parameters
    ----------
    x : Tensor
        Input tensor with its origin at the center of the array.
    dim : tuple[int, ...]
        The dimensions over which to take the FFT. If None, all dimensions
        are transformed.
    norm : str, optional
        Normalization mode. For the forward transform (cfftn()), these correspond to:

        - "forward" - normalize by 1/n
        - "backward" - no normalization
        - "ortho" - normalize by 1/sqrt(n) (making the FFT orthonormal)

        Calling the inverse transform (cifftn()) with the same normalization
        mode applies an overall normalization of 1/n across the two transforms,
        which makes cifftn() the exact inverse of cfftn(). Default is "ortho".
    """
    x = fft.ifftshift(x, dim=dim)
    x = fft.fftn(x, dim=dim, norm=norm)
    x = fft.fftshift(x, dim=dim)
    return x
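A quick way to see what "centered" means: an impulse placed at the array center (index n // 2 on each axis, the fftshift convention) should map to a flat, zero-phase spectrum, whereas a plain fftn of the same array picks up a phase ramp. The sketch below copies the three steps of the cfftn() listing into a local function so it runs without torchlinops; the shape (8, 9) is an arbitrary choice covering one even and one odd axis.

```python
import math
import torch
from torch import fft

# Local copy of the three steps in the cfftn listing (not imported from torchlinops).
def cfftn(x, dim=None, norm="ortho"):
    x = fft.ifftshift(x, dim=dim)        # move the centered origin to index 0
    x = fft.fftn(x, dim=dim, norm=norm)  # ordinary FFT
    return fft.fftshift(x, dim=dim)      # move the zero frequency back to the center

shape = (8, 9)                           # one even and one odd axis
delta = torch.zeros(shape, dtype=torch.complex64)
delta[shape[0] // 2, shape[1] // 2] = 1  # impulse at the array center

X = cfftn(delta)
# Flat, zero-phase spectrum: 1/sqrt(n) everywhere under norm="ortho".
expected = torch.full(shape, 1 / math.sqrt(delta.numel()), dtype=torch.complex64)
plain = fft.fftn(delta, norm="ortho")    # uncentered FFT of the same array

print(torch.allclose(X, expected))       # True
print(torch.allclose(plain, expected))   # False: plain fftn sees a shifted impulse
```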

torchlinops.utils.cifftn

cifftn(x, dim=None, norm='ortho')

Compute the centered n-dimensional inverse FFT.

Assumes the origin lies at the center of the array (i.e., that the array has been fftshifted).

PARAMETER DESCRIPTION
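x

Input tensor with its origin at the center of the array.

TYPE: Tensor

dim

The dimensions over which to take the inverse FFT. If None, all dimensions are transformed.

TYPE: tuple[int, ...] DEFAULT: None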

norm

Normalization mode, where n is the number of elements along the transformed dimensions. For the inverse transform (cifftn()), these correspond to:

  • "forward" - no normalization
  • "backward" - normalize by 1/n
  • "ortho" - normalize by 1/sqrt(n) (making the IFFT orthonormal)

Calling the forward transform (cfftn()) with the same normalization mode applies an overall normalization of 1/n across the two transforms, which makes cifftn() the exact inverse of cfftn(). Default is "ortho".

TYPE: str, optional DEFAULT: 'ortho'

Source code in src/torchlinops/utils/_fft.py
def cifftn(x, dim=None, norm="ortho"):
    """Compute the centered n-dimensional inverse FFT.

    Assumes the origin lies in the middle of the array (i.e., that the array has
    been fftshifted)

    Parameters
    ----------
    x : Tensor
        Input tensor with its origin at the center of the array.
    dim : tuple[int, ...]
        The dimensions over which to take the inverse FFT. If None, all
        dimensions are transformed.
    norm : str, optional
        Normalization mode. For the inverse transform (cifftn()), these correspond to:

        - "forward" - no normalization
        - "backward" - normalize by 1/n
        - "ortho" - normalize by 1/sqrt(n) (making the IFFT orthonormal)

        Calling the forward transform (cfftn()) with the same normalization mode
        applies an overall normalization of 1/n across the two transforms, which
        makes cifftn() the exact inverse of cfftn(). Default is "ortho".
    """
    x = fft.ifftshift(x, dim=dim)
    x = fft.ifftn(x, dim=dim, norm=norm)
    x = fft.fftshift(x, dim=dim)
    return x
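The shift order in cifftn() matters for odd lengths, where fftshift and ifftshift are different permutations. This sketch uses local copies of both functions rather than torchlinops imports, round-trips a batch of images over the two trailing axes only, and compares against a hypothetical cifftn_swapped() with the two shifts exchanged. cifftn_swapped() is not part of torchlinops; it exists only to show the failure.

```python
import torch
from torch import fft

# Local copies of the cfftn/cifftn listings on this page.
def cfftn(x, dim=None, norm="ortho"):
    return fft.fftshift(fft.fftn(fft.ifftshift(x, dim=dim), dim=dim, norm=norm), dim=dim)

def cifftn(x, dim=None, norm="ortho"):
    return fft.fftshift(fft.ifftn(fft.ifftshift(x, dim=dim), dim=dim, norm=norm), dim=dim)

# Hypothetical variant with fftshift/ifftshift swapped, for comparison only.
def cifftn_swapped(x, dim=None, norm="ortho"):
    return fft.ifftshift(fft.ifftn(fft.fftshift(x, dim=dim), dim=dim, norm=norm), dim=dim)

torch.manual_seed(0)
img = torch.randn(4, 32, 33, dtype=torch.complex128)  # batch of 4; last axis has odd length

ksp = cfftn(img, dim=(-2, -1))             # transform only the two trailing axes
recon = cifftn(ksp, dim=(-2, -1))          # inverse with the same (default) norm
wrong = cifftn_swapped(ksp, dim=(-2, -1))

print(torch.allclose(recon, img))  # True
print(torch.allclose(wrong, img))  # False: the shifts do not cancel on the odd-length axis
```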

torchlinops.utils.cfft

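cfft(x: Tensor, **kwargs)

Centered 1D FFT along the last dimension; equivalent to cfftn(x, dim=(-1,), **kwargs). Keyword arguments such as norm are forwarded to cfftn().

Source code in src/torchlinops/utils/_fft.py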
def cfft(x: Tensor, **kwargs):
    return cfftn(x, dim=(-1,), **kwargs)

torchlinops.utils.cifft

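cifft(x: Tensor, **kwargs)

Centered 1D inverse FFT along the last dimension; equivalent to cifftn(x, dim=(-1,), **kwargs). Keyword arguments such as norm are forwarded to cifftn().

Source code in src/torchlinops/utils/_fft.py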
def cifft(x: Tensor, **kwargs):
    return cifftn(x, dim=(-1,), **kwargs)
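cfft() and cifft() fix dim=(-1,). One consequence of the centered convention, sketched below with local stand-ins for the functions on this page (not torchlinops imports): a real signal that is symmetric about the array center has a real-valued cfft(), while plain torch.fft.fft shows a linear phase. The odd-length Gaussian is an arbitrary test signal; cifft() with the same default norm recovers it.

```python
import torch
from torch import fft

# Local stand-ins following the cfftn/cifftn/cfft/cifft listings.
def cfftn(x, dim=None, norm="ortho"):
    return fft.fftshift(fft.fftn(fft.ifftshift(x, dim=dim), dim=dim, norm=norm), dim=dim)

def cifftn(x, dim=None, norm="ortho"):
    return fft.fftshift(fft.ifftn(fft.ifftshift(x, dim=dim), dim=dim, norm=norm), dim=dim)

def cfft(x, **kwargs):
    return cfftn(x, dim=(-1,), **kwargs)

def cifft(x, **kwargs):
    return cifftn(x, dim=(-1,), **kwargs)

n = 65                                              # odd length
t = torch.arange(n, dtype=torch.float64) - n // 2  # t = 0 falls at index n // 2
g = torch.exp(-0.5 * (t / 4.0) ** 2)               # real Gaussian, symmetric about the center

G = cfft(g)
print(G.imag.abs().max().item() < 1e-12)          # True: centered and symmetric -> real spectrum
print(fft.fft(g).imag.abs().max().item() > 1e-3)  # True: plain fft sees a phase ramp
print(torch.allclose(cifft(G).real, g))           # True: cifft undoes cfft
```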

torchlinops.utils.cfft2

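cfft2(x: Tensor, **kwargs)

Centered 2D FFT over the last two dimensions; equivalent to cfftn(x, dim=(-2, -1), **kwargs). Keyword arguments such as norm are forwarded to cfftn().

Source code in src/torchlinops/utils/_fft.py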
def cfft2(x: Tensor, **kwargs):
    return cfftn(x, dim=(-2, -1), **kwargs)

torchlinops.utils.cifft2

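cifft2(x: Tensor, **kwargs)

Centered 2D inverse FFT over the last two dimensions; equivalent to cifftn(x, dim=(-2, -1), **kwargs). Keyword arguments such as norm are forwarded to cifftn().

Source code in src/torchlinops/utils/_fft.py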
def cifft2(x: Tensor, **kwargs):
    return cifftn(x, dim=(-2, -1), **kwargs)
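cfft2() and cifft2() act on the last two axes, so any leading axes behave as batch dimensions. Because the shifts and the FFT are separable per axis, a 2D centered FFT should match two 1D centered FFTs applied in turn. The sketch assumes local copies of cfftn() and cfft2() rather than torchlinops imports, and the (batch, channel, H, W) layout is only an illustration.

```python
import torch
from torch import fft

# Local copies of the cfftn/cfft2 listings on this page.
def cfftn(x, dim=None, norm="ortho"):
    return fft.fftshift(fft.fftn(fft.ifftshift(x, dim=dim), dim=dim, norm=norm), dim=dim)

def cfft2(x, **kwargs):
    return cfftn(x, dim=(-2, -1), **kwargs)

torch.manual_seed(0)
x = torch.randn(3, 2, 16, 15, dtype=torch.complex128)  # (batch, channel, H, W)

X2 = cfft2(x)  # leading (batch, channel) axes are left alone
# Separability: centered FFT along W, then along H, equals the 2D centered FFT.
X_sep = cfftn(cfftn(x, dim=(-1,)), dim=(-2,))

print(X2.shape == x.shape, torch.allclose(X2, X_sep))  # True True
```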

Adjoint Helpers

torchlinops.utils.inner

inner(x, y)

Complex inner product of x and y: the sum over all elements of conj(x) * y. The first argument is conjugated.

Source code in src/torchlinops/utils/_adjoint_helpers.py
def inner(x, y):
    """Complex inner product"""
    return torch.sum(x.conj() * y)
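Because the first argument of inner() is conjugated, it is conjugate-linear in x and linear in y, and inner(x, x) is the squared L2 norm. The sketch below checks these properties on random complex vectors, with inner() copied from the listing; for 1D inputs it should also agree with torch.vdot, which conjugates its first argument in the same way.

```python
import torch

# Copy of the inner() listing: sum over all elements of conj(x) * y.
def inner(x, y):
    return torch.sum(x.conj() * y)

torch.manual_seed(0)
x = torch.randn(5, dtype=torch.complex128)
y = torch.randn(5, dtype=torch.complex128)
a = 2.0 - 3.0j

# Conjugate-linear in the first argument, linear in the second.
print(torch.allclose(inner(a * x, y), a.conjugate() * inner(x, y)))  # True
print(torch.allclose(inner(x, a * y), a * inner(x, y)))              # True
# <x, x> is the squared L2 norm.
print(torch.allclose(inner(x, x).real, torch.linalg.vector_norm(x) ** 2))  # True
# Same convention as torch.vdot for 1D tensors.
print(torch.allclose(inner(x, y), torch.vdot(x, y)))  # True
```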

torchlinops.utils.is_adjoint

is_adjoint(
    A: Callable,
    x: Tensor,
    y: Tensor,
    atol: float = 1e-05,
    rtol: float = 1e-08,
)

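Adjoint (dot-product) test: if A and A.H are adjoints, then inner(y, A(x)) = inner(A.H(y), x). A must be callable and expose its adjoint as A.H. Returns a 0-dim boolean tensor that is True when the two inner products agree within atol and rtol, as computed by torch.isclose.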

Source code in src/torchlinops/utils/_adjoint_helpers.py
def is_adjoint(
    A: Callable,
    x: torch.Tensor,
    y: torch.Tensor,
    atol: float = 1e-5,
    rtol: float = 1e-8,
):
    """
    The adjoint test states that if A and AH are adjoints, then
    inner(y, Ax) = inner(AHy, x)
    """
    yAx = inner(y, A(x))
    xAHy = inner(A.H(y), x)
    return torch.isclose(yAx, xAHy, atol=atol, rtol=rtol).all()
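is_adjoint() only needs A to be callable and to expose its adjoint as A.H. The sketch below uses a hypothetical DenseOp class (not a torchlinops type) that wraps a complex matrix, to show the test passing for the conjugate transpose and failing for the plain transpose, a common mistake with complex operators. inner() and is_adjoint() are copied from the listings above; complex128 keeps rounding error far below the default atol.

```python
import torch

# Copies of the inner() and is_adjoint() listings.
def inner(x, y):
    return torch.sum(x.conj() * y)

def is_adjoint(A, x, y, atol=1e-5, rtol=1e-8):
    # <y, A x> should equal <A^H y, x>
    return torch.isclose(inner(y, A(x)), inner(A.H(y), x), atol=atol, rtol=rtol).all()

class DenseOp:
    """Hypothetical minimal operator: callable, with an `.H` property for its adjoint."""

    def __init__(self, M, adjoint_matrix=None):
        self.M = M
        # Default adjoint is the conjugate transpose; a custom one can be injected.
        self._MH = M.conj().T if adjoint_matrix is None else adjoint_matrix

    def __call__(self, x):
        return self.M @ x

    @property
    def H(self):
        return DenseOp(self._MH, adjoint_matrix=self.M)

torch.manual_seed(0)
M = torch.randn(4, 3, dtype=torch.complex128)
x = torch.randn(3, dtype=torch.complex128)
y = torch.randn(4, dtype=torch.complex128)

good = DenseOp(M)                     # adjoint = conjugate transpose
bad = DenseOp(M, adjoint_matrix=M.T)  # plain transpose: wrong for complex M
print(bool(is_adjoint(good, x, y)), bool(is_adjoint(bad, x, y)))  # True False
```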