Skip to content

FFT

torchlinops.linops.FFT

Bases: NamedLinop

\(n\)-dimensional Fast Fourier Transform as a named linear operator.

With norm="ortho" (the default), the FFT is unitary: \(F^H F = I\). This means the normal operator is the identity and the adjoint is the inverse FFT.

ATTRIBUTE DESCRIPTION
ndim

Number of spatial dimensions to transform.

TYPE: int

norm

FFT normalization mode.

TYPE: str or None

centered

Whether to treat the array center as the origin (sigpy convention).

TYPE: bool

Source code in src/torchlinops/linops/fft.py
class FFT(NamedLinop):
    """$n$-dimensional Fast Fourier Transform as a named linear operator.

    The forward pass applies ``fftn`` over the last ``ndim`` axes; the adjoint
    pass applies ``ifftn`` over the same axes with the same ``norm``.

    With ``norm="ortho"`` (the default), the FFT is unitary: $F^H F = I$.
    This means the normal operator is the identity and the adjoint is the
    inverse FFT.

    Attributes
    ----------
    ndim : int
        Number of spatial dimensions to transform.
    dim : tuple of int
        Negative axis indices ``(-ndim, ..., -1)`` that are transformed.
    grid_shapes : tuple[Shape, Shape] or None
        The ``grid_shapes`` argument exactly as passed to the constructor
        (``None`` when the default names were used).
    norm : str or None
        FFT normalization mode.
    centered : bool
        Whether to treat the array center as the origin (sigpy convention).
    """

    def __init__(
        self,
        ndim: int,
        batch_shape: Optional[Shape] = None,
        grid_shapes: Optional[tuple[Shape, Shape]] = None,
        norm: Optional[str] = "ortho",
        centered: bool = False,
    ):
        """
        Parameters
        ----------
        ndim : int
            Number of dimensions to transform (1, 2, or 3).
        batch_shape : Shape, optional
            Named batch dimensions prepended to the grid dimensions.
            Defaults to ``("...",)``.
        grid_shapes : tuple[Shape, Shape], optional
            Pair of shapes ``(primal, dual)`` naming the input (image-space)
            and output (k-space) grid dimensions. Defaults to
            ``(Nx[, Ny[, Nz]])`` and ``(Kx[, Ky[, Kz]])``.
        norm : str or None, default ``"ortho"``
            Normalization applied to the FFT. Only ``"ortho"`` gives a true
            unitary forward/adjoint pair.
        centered : bool, default False
            If ``True``, treat the center of the array (``N // 2``) as the
            origin via ``fftshift`` / ``ifftshift``. Mimics sigpy convention.

        Raises
        ------
        ValueError
            If ``grid_shapes`` does not contain exactly two shapes, or if the
            two shapes differ in length.
        """
        self.ndim = ndim
        # The transform always acts on the trailing ``ndim`` axes.
        self.dim = tuple(range(-self.ndim, 0))
        # NOTE: this keeps the caller's raw argument, so it stays ``None``
        # when the default dimension names are generated below.
        self.grid_shapes = grid_shapes
        if grid_shapes is None:
            grid_shapes = get_nd_shape(self.dim), get_nd_shape(self.dim, kspace=True)
        elif len(grid_shapes) != 2:
            raise ValueError(
                f"grid_shapes should consist of two shape tuples but got {grid_shapes}"
            )
        if len(grid_shapes[0]) != len(grid_shapes[1]):
            raise ValueError(
                "Input and output shapes of FFT must have same length but got "
                f"len({grid_shapes[0]}) != len({grid_shapes[1]})"
            )
        batch_shape = default_to(("...",), batch_shape)
        dim_shape = NS(*grid_shapes)
        shape = NS(batch_shape) + dim_shape
        super().__init__(shape)
        # Record the individual pieces of the shape on the shape object so
        # they can be looked up later (see the ``batch_shape`` property).
        self._shape.batch_shape = batch_shape
        self._shape.input_grid_shape = grid_shapes[0]
        self._shape.output_grid_shape = grid_shapes[1]
        self.norm = norm
        self.centered = centered

    @property
    def batch_shape(self):
        """Batch dimension names stored on the operator's shape."""
        return self._shape.batch_shape

    @staticmethod
    def fn(linop, x):
        """Apply the forward FFT to ``x`` over ``linop.dim``.

        When ``linop.centered`` is set, the input is ``ifftshift``-ed before
        the transform and the result is ``fftshift``-ed afterwards.
        """
        if linop.centered:
            x = fft.ifftshift(x, dim=linop.dim)
        x = fft.fftn(x, dim=linop.dim, norm=linop.norm)
        if linop.centered:
            x = fft.fftshift(x, dim=linop.dim)
        return x

    @staticmethod
    def adj_fn(linop, x):
        """Apply the inverse FFT to ``x`` over ``linop.dim``.

        Uses the same ``norm`` and the same shift convention as ``fn``, so it
        exactly undoes ``fn``. It is the mathematical adjoint of ``fn`` only
        when ``norm="ortho"``.
        """
        if linop.centered:
            x = fft.ifftshift(x, dim=linop.dim)
        x = fft.ifftn(x, dim=linop.dim, norm=linop.norm)
        if linop.centered:
            x = fft.fftshift(x, dim=linop.dim)
        return x

    @staticmethod
    def normal_fn(linop, x):
        """Return ``x`` unchanged.

        ``adj_fn`` inverts ``fn``, so applying one after the other is the
        identity and no transform needs to be computed.
        """
        return x

    def split_forward(self, ibatch, obatch):
        """Splitting does nothing.

        ``ibatch`` and ``obatch`` are ignored; a shallow copy of this operator
        is returned.
        """
        # TODO: raise an error if the FFT is split along an input or output grid dim
        new = copy(self)
        return new

    def normal(self, inner=None):
        """Return the normal operator $F^H F$.

        With orthonormal normalization, $F^H F = I$, so this returns an
        ``Identity`` when no inner operator is provided.

        Parameters
        ----------
        inner : NamedLinop, optional
            Inner operator for Toeplitz embedding.

        Returns
        -------
        NamedLinop
            ``Identity`` if *inner* is ``None``, otherwise the composed normal.
        """
        if inner is None:
            return Identity(self.ishape)
        return super().normal(inner)

__init__

__init__(
    ndim: int,
    batch_shape: Optional[Shape] = None,
    grid_shapes: Optional[tuple[Shape, Shape]] = None,
    norm: Optional[str] = "ortho",
    centered: bool = False,
)
PARAMETER DESCRIPTION
ndim

Number of dimensions to transform (1, 2, or 3).

TYPE: int

batch_shape

Named batch dimensions prepended to the grid dimensions. Defaults to ("...",).

TYPE: Shape DEFAULT: None

grid_shapes

Pair of shapes (primal, dual) naming the input (image-space) and output (k-space) grid dimensions. Defaults to (Nx[, Ny[, Nz]]) and (Kx[, Ky[, Kz]]).

TYPE: tuple[Shape, Shape] DEFAULT: None

norm

Normalization applied to the FFT. Only "ortho" gives a true unitary forward/adjoint pair.

TYPE: str or None DEFAULT: "ortho"

centered

If True, treat the center of the array (N // 2) as the origin via fftshift / ifftshift. Mimics sigpy convention.

TYPE: bool DEFAULT: False

Source code in src/torchlinops/linops/fft.py
def __init__(
    self,
    ndim: int,
    batch_shape: Optional[Shape] = None,
    grid_shapes: Optional[tuple[Shape, Shape]] = None,
    norm: Optional[str] = "ortho",
    centered: bool = False,
):
    """
    Parameters
    ----------
    ndim : int
        Number of dimensions to transform (1, 2, or 3).
    batch_shape : Shape, optional
        Named batch dimensions prepended to the grid dimensions.
        Defaults to ``("...",)``.
    grid_shapes : tuple[Shape, Shape], optional
        Pair of shapes ``(primal, dual)`` naming the input (image-space)
        and output (k-space) grid dimensions. Defaults to
        ``(Nx[, Ny[, Nz]])`` and ``(Kx[, Ky[, Kz]])``.
    norm : str or None, default ``"ortho"``
        Normalization applied to the FFT. Only ``"ortho"`` gives a true
        unitary forward/adjoint pair.
    centered : bool, default False
        If ``True``, treat the center of the array (``N // 2``) as the
        origin via ``fftshift`` / ``ifftshift``. Mimics sigpy convention.

    Raises
    ------
    ValueError
        If ``grid_shapes`` does not contain exactly two shapes, or if the
        two shapes differ in length.
    """
    self.ndim = ndim
    # The transform always acts on the trailing ``ndim`` axes.
    self.dim = tuple(range(-self.ndim, 0))
    # NOTE: this keeps the caller's raw argument, so it stays ``None``
    # when the default dimension names are generated below.
    self.grid_shapes = grid_shapes
    if grid_shapes is None:
        grid_shapes = get_nd_shape(self.dim), get_nd_shape(self.dim, kspace=True)
    elif len(grid_shapes) != 2:
        raise ValueError(
            f"grid_shapes should consist of two shape tuples but got {grid_shapes}"
        )
    if len(grid_shapes[0]) != len(grid_shapes[1]):
        raise ValueError(
            "Input and output shapes of FFT must have same length but got "
            f"len({grid_shapes[0]}) != len({grid_shapes[1]})"
        )
    batch_shape = default_to(("...",), batch_shape)
    dim_shape = NS(*grid_shapes)
    shape = NS(batch_shape) + dim_shape
    super().__init__(shape)
    # Record the individual pieces of the shape on the shape object.
    self._shape.batch_shape = batch_shape
    self._shape.input_grid_shape = grid_shapes[0]
    self._shape.output_grid_shape = grid_shapes[1]
    self.norm = norm
    self.centered = centered

normal

normal(inner=None)

Return the normal operator \(F^H F\).

With orthonormal normalization, \(F^H F = I\), so this returns an Identity when no inner operator is provided.

PARAMETER DESCRIPTION
inner

Inner operator for Toeplitz embedding.

TYPE: NamedLinop DEFAULT: None

RETURNS DESCRIPTION
NamedLinop

Identity if inner is None, otherwise the composed normal.

Source code in src/torchlinops/linops/fft.py
def normal(self, inner=None):
    """Build the normal operator $F^H F$ for this FFT.

    Under orthonormal normalization $F^H F = I$, so when there is no inner
    operator the result is simply an ``Identity`` over this operator's
    input shape.

    Parameters
    ----------
    inner : NamedLinop, optional
        Inner operator for Toeplitz embedding.

    Returns
    -------
    NamedLinop
        ``Identity`` if *inner* is ``None``, otherwise the composed normal
        produced by the parent class.
    """
    # An inner operator means the shortcut does not apply: defer to the parent.
    if inner is not None:
        return super().normal(inner)
    return Identity(self.ishape)