Skip to content

Index

torchlinops.linops

INDENT module-attribute

INDENT = Indenter(start_level=0)

PadLast module-attribute

PadLast = Pad

Shape module-attribute

Shape = Sequence[NamedDimension | str]

logger module-attribute

logger = getLogger('torchlinops')

Add

Bases: Threadable, NamedLinop

The sum of one or more linear operators.

Inherits from Threadable to support parallel execution of sub-linops. When threaded=True (default), each sub-linop is executed in parallel using a ThreadPoolExecutor, which is useful for I/O-bound operations or operations that release the GIL (e.g., PyTorch tensor operations).

Note that shared linops (e.g., Add(A, A)) are automatically shallow- copied to ensure independent identity for threading, while still sharing tensor data. See Threadable for details.

ATTRIBUTE DESCRIPTION
linops

The list of linops being added together.

TYPE: ModuleList

threaded

Whether to run sub-linops in parallel. Default is True.

TYPE: bool

num_workers

Number of worker threads. If None, defaults to the number of sub-linops.

TYPE: int | None

Source code in src/torchlinops/linops/add.py
class Add(Threadable, NamedLinop):
    """The sum of one or more linear operators.

    Inherits from ``Threadable`` to support parallel execution of sub-linops.
    When ``threaded=True`` (default), each sub-linop is executed in parallel
    using a ThreadPoolExecutor, which is useful for I/O-bound operations or
    operations that release the GIL (e.g., PyTorch tensor operations).

    Note that shared linops (e.g., ``Add(A, A)``) are automatically shallow-
    copied to ensure independent identity for threading, while still sharing
    tensor data. See ``Threadable`` for details.

    Attributes
    ----------
    linops : nn.ModuleList
        The list of linops being added together.
    threaded : bool
        Whether to run sub-linops in parallel. Default is True.
    num_workers : int | None
        Number of worker threads. If None, defaults to the number of sub-linops.
    """

    def __init__(self, *linops, **kwargs):
        """
        Parameters
        ----------
        *linops : tuple[NamedLinop]
            The linear operators to be added together. At least one linop is
            required; all must share the same input and output shapes.
        """
        # Guard before indexing linops[0]: Add() with no operands previously
        # died with an opaque IndexError.
        assert linops, "Add: requires at least one linop"
        # A + B is only defined when all operands map between the same spaces.
        assert all(isequal(linop.ishape, linops[0].ishape) for linop in linops), (
            f"Add: All linops must share same ishape. Found {linops}"
        )
        assert all(isequal(linop.oshape, linops[0].oshape) for linop in linops), (
            f"Add: All linops must share same oshape. Linops: {linops}"
        )
        super().__init__(NS(linops[0].ishape, linops[0].oshape), **kwargs)
        self.linops = nn.ModuleList(linops)

    @staticmethod
    def fn(add, x: torch.Tensor, /):
        """Apply every sub-linop to ``x`` and sum the results."""
        if add.threaded:
            # One copy of x per sub-linop; the Threadable mixin fans out and
            # sum-reduces the results.
            return add.threaded_apply_sum_reduce([x] * len(add.linops), add.num_workers)
        return sum(linop(x) for linop in add.linops)

    @staticmethod
    def adj_fn(add, x: torch.Tensor, /):
        """Apply every sub-linop's adjoint to ``x`` and sum the results."""
        if add.threaded:
            # BUGFIX: previously a list of adjoint linops was built but never
            # used -- threaded_apply_sum_reduce ran against the *forward*
            # linops, so the threaded adjoint silently computed the forward
            # sum. Thread through the adjoint operator instead, whose
            # ``.linops`` are the sub-adjoints (see adjoint()).
            adj = add.adjoint()
            return adj.threaded_apply_sum_reduce([x] * len(adj.linops), adj.num_workers)
        return sum(linop.H(x) for linop in add.linops)

    def split_forward(self, ibatch, obatch):
        # Splitting a sum splits every summand with the same i/o slices.
        split = copy(self)
        linops = [linop.split_forward(ibatch, obatch) for linop in self.linops]
        split.linops = nn.ModuleList(linops)
        return split

    def adjoint(self):
        # (A + B)^H = A^H + B^H
        adj = copy(self)
        adj.linops = nn.ModuleList([linop.adjoint() for linop in self.linops])
        adj.shape = self.shape.adjoint()
        return adj

    def normal(self, inner=None):
        """Normal operator: (sum_i A_i)^H (sum_j A_j) = sum_{i,j} A_i^H A_j."""
        if inner is None:
            # Cross terms may carry different shapes; compute a common shape
            # and standardize all terms to it.
            max_ishape = max_shape([linop.N.ishape for linop in self.linops])
            max_oshape = max_shape([linop.N.oshape for linop in self.linops])
            new_shape = NS(max_ishape, max_oshape)
            all_combinations = []
            for left_linop in self.linops:
                for right_linop in self.linops:
                    if left_linop == right_linop:
                        # Diagonal terms use the (possibly optimized) normal.
                        all_combinations.append(left_linop.N)
                    else:
                        all_combinations.append(left_linop.H @ right_linop)
            all_combinations = standardize_shapes(all_combinations, new_shape)
            normal = copy(self)
            normal.linops = nn.ModuleList(list(all_combinations))
            normal.shape = normal.linops[0].shape
            return normal
        return super().normal(inner)

    def size(self, dim):
        """Return the size of ``dim`` from the first sub-linop that knows it."""
        for linop in self.linops:
            out = linop.size(dim)
            if out is not None:
                return out
        return None

    @property
    def dims(self):
        """Union of all named dims appearing in any summand."""
        return set().union(*[linop.dims for linop in self.linops])

    @property
    def H(self):
        """Adjoint operator (optionally cached via ``config.cache_adjoint_normal``)."""
        try:
            if config.cache_adjoint_normal:
                config._warn_if_caching_enabled()
                if self._adjoint is None:
                    # Stored inside a list (presumably to keep nn.Module from
                    # registering the cached linop as a submodule -- confirm).
                    self._adjoint = [self.adjoint()]
                return self._adjoint[0]
            return self.adjoint()
        except AttributeError as e:
            raise RuntimeError(f"AttributeError in {type(self).__name__}.H: {e}") from e

    @property
    def N(self):
        """Normal operator (optionally cached via ``config.cache_adjoint_normal``)."""
        try:
            if config.cache_adjoint_normal:
                config._warn_if_caching_enabled()
                if self._normal is None:
                    self._normal = [self.normal()]
                return self._normal[0]
            return self.normal()
        except AttributeError as e:
            raise RuntimeError(f"AttributeError in {type(self).__name__}.N: {e}") from e

    def flatten(self):
        # An Add is treated as atomic: it does not flatten into its summands.
        return [self]

    def __getitem__(self, idx):
        # Integer index returns the sub-linop; slice returns a smaller Add.
        linops = self.linops[idx]
        if isinstance(linops, NamedLinop):
            return linops
        new = copy(self)
        new.linops = nn.ModuleList(linops)
        return new

    def __len__(self):
        return len(self.linops)

    def __repr__(self):
        linop_chain = " + ".join(repr(linop) for linop in self.linops)
        return linop_chain

__init__

__init__(*linops, **kwargs)
PARAMETER DESCRIPTION
*linops

The linear operators to be added together.

TYPE: tuple[NamedLinop] DEFAULT: ()

Source code in src/torchlinops/linops/add.py
def __init__(self, *linops, **kwargs):
    """Build the sum of one or more linops.

    Parameters
    ----------
    *linops : tuple[NamedLinop]
        The linear operators to be added together. All must share the same
        input shape and the same output shape.
    **kwargs
        Forwarded to the parent constructor.
    """
    # A + B is only defined when all summands map between the same spaces.
    assert all(isequal(linop.ishape, linops[0].ishape) for linop in linops), (
        f"Add: All linops must share same ishape. Found {linops}"
    )
    assert all(isequal(linop.oshape, linops[0].oshape) for linop in linops), (
        f"Add: All linops must share same oshape. Linops: {linops}"
    )
    # The sum shares the (common) shape of its first summand.
    super().__init__(NS(linops[0].ishape, linops[0].oshape), **kwargs)
    self.linops = nn.ModuleList(linops)

ArrayToBlocks

Bases: NamedLinop

Extract sliding windows from an array.

Adjoint of BlocksToArray.

Source code in src/torchlinops/linops/array_to_blocks.py
class ArrayToBlocks(NamedLinop):
    """Slide a window over an array, producing a stack of blocks.

    This is the adjoint of [BlocksToArray](#BlocksToArray).

    """

    def __init__(
        self,
        grid_size: tuple[int, ...],
        block_size: tuple[int, ...],
        stride: tuple[int, ...],
        mask: Optional[Tensor] = None,
        batch_shape: Optional[Shape] = None,
        array_shape: Optional[Shape] = None,
        blocks_shape: Optional[Shape] = None,
    ):
        """
        Parameters
        ----------
        grid_size : tuple[int, ...]
            Size of the input array spatial dimensions.
        block_size : tuple[int, ...]
            Size of each extracted block.
        stride : tuple[int, ...]
            Stride between consecutive blocks.
        mask : Tensor, optional
            Boolean mask selecting a subset of blocks.
        batch_shape : Shape, optional
            Named shape for batch dimensions.
        array_shape : Shape, optional
            Named shape for the input array dimensions.
        blocks_shape : Shape, optional
            Named shape for the output block dimensions.
        """
        self.grid_size = grid_size
        self.block_size = block_size
        self.stride = stride
        # One spatial dimension per grid axis.
        self.ndim = len(grid_size)

        # Omitted named shapes fall back to a wildcard ("...") shape.
        self.batch_shape = default_to(("...",), batch_shape)
        self.array_shape = default_to(("...",), array_shape)
        self.blocks_shape = default_to(("...",), blocks_shape)
        super().__init__(
            NS(self.batch_shape) + NS(self.array_shape, self.blocks_shape)
        )

        # Register the mask as a frozen parameter so it follows device moves.
        self.mask = mask if mask is None else nn.Parameter(mask, requires_grad=False)

    @staticmethod
    def fn(arraytoblocks, x, /):
        op = arraytoblocks
        return F.array_to_blocks(x, op.block_size, op.stride, op.mask)

    @staticmethod
    def adj_fn(arraytoblocks, x, /):
        op = arraytoblocks
        return F.blocks_to_array(x, op.grid_size, op.block_size, op.stride, op.mask)

    @staticmethod
    def normal_fn(arraytoblocks, x, /):
        # A^H A x: extract the blocks, then scatter them back into an array.
        forward = arraytoblocks.fn(arraytoblocks, x)
        return arraytoblocks.adj_fn(arraytoblocks, forward)

    def split_forward(self, ibatch, obatch):
        # No per-batch state to slice; a shallow copy is sufficient.
        return copy(self)

    def adjoint(self):
        # The adjoint maps blocks back to the array, so the roles of
        # array_shape and blocks_shape are swapped at the call site.
        return BlocksToArray(
            self.grid_size,
            self.block_size,
            self.stride,
            self.mask,
            self.batch_shape,
            self.blocks_shape,
            self.array_shape,
        )

    def size(self, dim):
        spatial = self.ishape[-self.ndim :]
        if dim not in spatial:
            return None
        # Negative position of `dim` in ishape aligns with grid_size entries.
        return self.grid_size[self.ishape.index(dim) - len(self.ishape)]

__init__

__init__(
    grid_size: tuple[int, ...],
    block_size: tuple[int, ...],
    stride: tuple[int, ...],
    mask: Optional[Tensor] = None,
    batch_shape: Optional[Shape] = None,
    array_shape: Optional[Shape] = None,
    blocks_shape: Optional[Shape] = None,
)
PARAMETER DESCRIPTION
grid_size

Size of the input array spatial dimensions.

TYPE: tuple[int, ...]

block_size

Size of each extracted block.

TYPE: tuple[int, ...]

stride

Stride between consecutive blocks.

TYPE: tuple[int, ...]

mask

Boolean mask selecting a subset of blocks.

TYPE: Tensor DEFAULT: None

batch_shape

Named shape for batch dimensions.

TYPE: Shape DEFAULT: None

array_shape

Named shape for the input array dimensions.

TYPE: Shape DEFAULT: None

blocks_shape

Named shape for the output block dimensions.

TYPE: Shape DEFAULT: None

Source code in src/torchlinops/linops/array_to_blocks.py
def __init__(
    self,
    grid_size: tuple[int, ...],
    block_size: tuple[int, ...],
    stride: tuple[int, ...],
    mask: Optional[Tensor] = None,
    batch_shape: Optional[Shape] = None,
    array_shape: Optional[Shape] = None,
    blocks_shape: Optional[Shape] = None,
):
    """Initialize a sliding-window block extractor.

    Parameters
    ----------
    grid_size : tuple[int, ...]
        Size of the input array spatial dimensions.
    block_size : tuple[int, ...]
        Size of each extracted block.
    stride : tuple[int, ...]
        Stride between consecutive blocks.
    mask : Tensor, optional
        Boolean mask selecting a subset of blocks.
    batch_shape : Shape, optional
        Named shape for batch dimensions.
    array_shape : Shape, optional
        Named shape for the input array dimensions.
    blocks_shape : Shape, optional
        Named shape for the output block dimensions.
    """
    self.grid_size = grid_size
    # Number of spatial dimensions, inferred from the grid.
    self.ndim = len(self.grid_size)
    self.block_size = block_size
    self.stride = stride

    # Omitted named shapes fall back to a wildcard ("...") shape.
    self.batch_shape = default_to(("...",), batch_shape)
    self.array_shape = default_to(("...",), array_shape)
    self.blocks_shape = default_to(("...",), blocks_shape)
    shape = NS(self.batch_shape) + NS(self.array_shape, self.blocks_shape)
    super().__init__(shape)

    if mask is not None:
        # Frozen parameter: moves with the module but is never trained.
        self.mask = nn.Parameter(mask, requires_grad=False)
    else:
        self.mask = mask

BatchSpec dataclass

Specification for splitting and distributing a linop across devices.

PARAMETER DESCRIPTION
batch_sizes

Mapping from dimension names to chunk sizes for tiling.

TYPE: dict[NamedDimension | str, int]

device_matrix

Array of torch.device objects specifying target devices for each tile. Broadcast to match the tile grid shape.

TYPE: ndarray or list DEFAULT: None

base_device

The device where input/output data resides. Default is CPU.

TYPE: device DEFAULT: None

Source code in src/torchlinops/linops/split.py
@dataclass
class BatchSpec:
    """Specification for splitting and distributing a linop across devices.

    Parameters
    ----------
    batch_sizes : dict[ND | str, int]
        Mapping from dimension names to chunk sizes for tiling.
    device_matrix : np.ndarray or list, optional
        Array of ``torch.device`` objects specifying target devices for each
        tile. Broadcast to match the tile grid shape.
    base_device : torch.device, optional
        The device where input/output data resides. Default is CPU.
    """

    batch_sizes: dict[ND | str, int]
    device_matrix: np.ndarray | None = None
    base_device: torch.device | None = None

    def __post_init__(self):
        # Best-effort validation: warn (don't raise) on a non-dict batch_sizes
        # to preserve the existing lenient behavior.
        if not isinstance(self.batch_sizes, dict):
            warn(
                f"Got {self.batch_sizes} of type {type(self.batch_sizes).__name__} for batch_sizes instead of dict."
            )
        # Ensure ndarray
        if isinstance(self.device_matrix, list | tuple):
            self.device_matrix = np.array(self.device_matrix)

    def broadcast_device_matrix(self, linop):
        """Broadcast ``device_matrix`` to the tile grid implied by ``linop``.

        Parameters
        ----------
        linop : NamedLinop
            The linop whose dimension sizes determine the number of tiles.

        Returns
        -------
        np.ndarray
            Device matrix broadcast to one entry per tile.

        Raises
        ------
        ValueError
            If a requested batch dimension's size cannot be determined from
            ``linop``.
        """
        # Compute the number of tiles along each batched axis/dimension
        batch_dims = list(self.batch_sizes.keys())
        sizes = {dim: linop.size(dim) for dim in linop.dims}
        # Fail fast with a clear message instead of an opaque KeyError /
        # TypeError (linop.size may return None for unknown dims, and a batch
        # dim may be absent from linop.dims entirely).
        missing = [dim for dim in batch_dims if sizes.get(dim) is None]
        if missing:
            raise ValueError(
                f"Cannot tile over dims {missing}: sizes are unknown for linop {linop}."
            )
        tiled_shape = tuple(
            ceil(sizes[dim] / self.batch_sizes[dim]) for dim in batch_dims
        )

        # Broadcast device_matrix over requested tiles.
        # Each tile should receive a single device.
        device_matrix = fuzzy_broadcast_to(self.device_matrix, tiled_shape)
        return device_matrix

BlocksToArray

Bases: NamedLinop

Compose several equally-sized blocks into a larger array.

Adjoint of ArrayToBlocks.

Source code in src/torchlinops/linops/array_to_blocks.py
class BlocksToArray(NamedLinop):
    """Assemble several equally-sized blocks into one larger array.

    This is the adjoint of [ArrayToBlocks](#ArrayToBlocks).
    """

    def __init__(
        self,
        grid_size: tuple[int, ...],
        block_size: tuple[int, ...],
        stride: tuple[int, ...],
        mask: Optional[Tensor] = None,
        batch_shape: Optional = None,
        blocks_shape: Optional = None,
        array_shape: Optional = None,
    ):
        """
        Parameters
        ----------
        grid_size : tuple[int, ...]
            Size of the output array spatial dimensions.
        block_size : tuple[int, ...]
            Size of each block.
        stride : tuple[int, ...]
            Stride between consecutive blocks.
        mask : Tensor, optional
            Boolean mask selecting a subset of blocks.
        batch_shape : optional
            Named shape for batch dimensions.
        blocks_shape : optional
            Named shape for the input block dimensions.
        array_shape : optional
            Named shape for the output array dimensions.
        """
        self.grid_size = grid_size
        self.block_size = block_size
        self.stride = stride
        # One spatial dimension per grid axis.
        self.ndim = len(grid_size)

        # Omitted named shapes fall back to a wildcard ("...") shape.
        self.batch_shape = default_to(("...",), batch_shape)
        self.blocks_shape = default_to(("...",), blocks_shape)
        self.array_shape = default_to(("...",), array_shape)
        super().__init__(
            NS(self.batch_shape) + NS(self.blocks_shape, self.array_shape)
        )
        # Register the mask as a frozen parameter so it follows device moves.
        self.mask = mask if mask is None else nn.Parameter(mask, requires_grad=False)

    @staticmethod
    def fn(blockstoarray, x, /):
        op = blockstoarray
        return F.blocks_to_array(x, op.grid_size, op.block_size, op.stride, op.mask)

    @staticmethod
    def adj_fn(blockstoarray, x, /):
        # The adjoint (block extraction) is only defined on a full-size array.
        if x.shape[-blockstoarray.ndim :] != blockstoarray.grid_size:
            raise RuntimeError(
                f"BlocksToArray expected input with full size {blockstoarray.grid_size} but got {x.shape}"
            )
        op = blockstoarray
        return F.array_to_blocks(x, op.block_size, op.stride, op.mask)

    def split_forward(self, ibatch, obatch):
        # No per-batch state to slice; a shallow copy is sufficient.
        return copy(self)

    def adjoint(self):
        # The adjoint maps the array back to blocks, so the roles of
        # array_shape and blocks_shape are swapped at the call site.
        return ArrayToBlocks(
            self.grid_size,
            self.block_size,
            self.stride,
            self.mask,
            self.batch_shape,
            self.array_shape,
            self.blocks_shape,
        )

    def size(self, dim):
        spatial = self.oshape[-self.ndim :]
        if dim not in spatial:
            return None
        # Negative position of `dim` in oshape aligns with grid_size entries.
        return self.grid_size[self.oshape.index(dim) - len(self.oshape)]

__init__

__init__(
    grid_size: tuple[int, ...],
    block_size: tuple[int, ...],
    stride: tuple[int, ...],
    mask: Optional[Tensor] = None,
    batch_shape: Optional = None,
    blocks_shape: Optional = None,
    array_shape: Optional = None,
)
PARAMETER DESCRIPTION
grid_size

Size of the output array spatial dimensions.

TYPE: tuple[int, ...]

block_size

Size of each block.

TYPE: tuple[int, ...]

stride

Stride between consecutive blocks.

TYPE: tuple[int, ...]

mask

Boolean mask selecting a subset of blocks.

TYPE: Tensor DEFAULT: None

batch_shape

Named shape for batch dimensions.

TYPE: optional DEFAULT: None

blocks_shape

Named shape for the input block dimensions.

TYPE: optional DEFAULT: None

array_shape

Named shape for the output array dimensions.

TYPE: optional DEFAULT: None

Source code in src/torchlinops/linops/array_to_blocks.py
def __init__(
    self,
    grid_size: tuple[int, ...],
    block_size: tuple[int, ...],
    stride: tuple[int, ...],
    mask: Optional[Tensor] = None,
    batch_shape: Optional = None,
    blocks_shape: Optional = None,
    array_shape: Optional = None,
):
    """Initialize a block-to-array composition operator.

    Parameters
    ----------
    grid_size : tuple[int, ...]
        Size of the output array spatial dimensions.
    block_size : tuple[int, ...]
        Size of each block.
    stride : tuple[int, ...]
        Stride between consecutive blocks.
    mask : Tensor, optional
        Boolean mask selecting a subset of blocks.
    batch_shape : optional
        Named shape for batch dimensions.
    blocks_shape : optional
        Named shape for the input block dimensions.
    array_shape : optional
        Named shape for the output array dimensions.
    """
    self.grid_size = grid_size
    # Number of spatial dimensions, inferred from the grid.
    self.ndim = len(self.grid_size)
    self.block_size = block_size
    self.stride = stride

    # Omitted named shapes fall back to a wildcard ("...") shape.
    self.batch_shape = default_to(("...",), batch_shape)
    self.blocks_shape = default_to(("...",), blocks_shape)
    self.array_shape = default_to(("...",), array_shape)
    shape = NS(self.batch_shape) + NS(self.blocks_shape, self.array_shape)
    super().__init__(shape)
    if mask is not None:
        # Frozen parameter: moves with the module but is never trained.
        self.mask = nn.Parameter(mask, requires_grad=False)
    else:
        self.mask = mask

BreakpointLinop

Bases: NamedLinop

Debugging identity operator that drops into pdb on forward/adjoint.

Useful for inspecting intermediate tensor values inside a Chain.

Source code in src/torchlinops/linops/breakpt.py
class BreakpointLinop(NamedLinop):
    """Debugging identity operator that drops into ``pdb`` on forward/adjoint.

    Useful for inspecting intermediate tensor values inside a ``Chain``.
    Both directions return their input unchanged, so inserting this linop
    anywhere in a chain does not alter the computation.
    """

    def __init__(self, ioshape: Optional[Shape] = None):
        # Identity operator: input and output share the same (optional) shape.
        super().__init__(NS(ioshape))

    @staticmethod
    def fn(linop, x, /):
        breakpoint()  # pause here to inspect the forward intermediate `x`
        return x

    @staticmethod
    def adj_fn(linop, x, /):
        breakpoint()  # pause here to inspect the adjoint intermediate `x`
        return x

    def split_forward(self, ibatch, obatch):
        # Stateless identity; splitting never needs a copy.
        return self

Chain

Bases: NamedLinop

Composition (sequential application) of named linear operators.

If Chain(A, B, C) is created, then the forward pass applies \(A\) first, then \(B\), then \(C\): mathematically the operator is \(C B A\).

ATTRIBUTE DESCRIPTION
linops

The constituent linops in execution order (inner to outer).

TYPE: ModuleList

Source code in src/torchlinops/linops/chain.py
class Chain(NamedLinop):
    """Composition (sequential application) of named linear operators.

    If ``Chain(A, B, C)`` is created, then the forward pass applies
    $A$ first, then $B$, then $C$: mathematically the operator is $C B A$.

    Attributes
    ----------
    linops : nn.ModuleList
        The constituent linops in **execution order** (inner to outer).
    """

    def __init__(self, *linops, name: Optional[str] = None):
        """
        Parameters
        ----------
        *linops : NamedLinop
            Linops in order of execution. If ``linops = (A, B, C)``, the
            mathematical operator is $C B A$.
        name : str, optional
            Display name for this chain.

        Raises
        ------
        ValueError
            If consecutive linops have mismatched input/output shapes.
        """
        # The chain's input is the first linop's input; its output is the
        # last linop's output.
        super().__init__(NS(linops[0].ishape, linops[-1].oshape), name=name)
        self.linops = nn.ModuleList(list(linops))
        self._check_inputs_outputs()

    @property
    def linops(self):
        # Property (rather than a plain attribute) so that every assignment
        # re-runs _setup_events via the setter below.
        return self._linops

    @linops.setter
    def linops(self, new_linops):
        self._linops = new_linops
        self._setup_events()

    def __setattr__(self, name, value):
        """Bypasses pytorch's setattr, just for linops"""
        if name == "linops":
            # Force descriptor lookup for this name
            # (nn.Module.__setattr__ would otherwise register the ModuleList
            # directly and skip the property setter above).
            type(self).linops.fset(self, value)
        else:
            super().__setattr__(name, value)

    def _check_inputs_outputs(self):
        # Walk the chain, verifying each linop's input shape matches the
        # previous linop's output shape.
        curr_shape = self.ishape
        for i, linop in enumerate(self.linops):
            if not isequal(linop.ishape, curr_shape):
                raise ValueError(
                    f"Mismatched shape: expected {linop.ishape}, got {curr_shape} at input to {linop}. Full stack: {self}, index {i}"
                )
            curr_shape = linop.oshape

    def _setup_events(self):
        """Copy every linop and point initial linop listener at Chain's input listener."""
        # Shallow-copy each sub-linop so this chain owns independent
        # event wiring, then attach the chain's input listener to the first
        # linop (a (owner, attribute-name) pair -- see NamedLinop events).
        self._linops = nn.ModuleList([copy(linop) for linop in self._linops])
        self._linops[0].input_listener = (self, "input_listener")

    @staticmethod
    def fn(chain, x: torch.Tensor, /):
        # Apply linops in execution (inner-to-outer) order.
        for linop in chain.linops:
            x = linop(x)
        return x

    @staticmethod
    def adj_fn(chain, x: torch.Tensor, /):
        # Adjoint of a composition applies the adjoints in reverse order.
        for linop in reversed(chain.linops):
            x = linop.H(x)
        return x

    # @staticmethod
    # def normal_fn(chain, x: torch.Tensor):
    #     # fn does the reversing so it's unnecessary to do it here
    #     # If the normal hasn't been explicitly formed with`.N`, do things the naive way
    #     return chain.adj_fn(chain, chain.fn(chain, x))

    def split_forward(self, ibatches, obatches):
        """Split each constituent linop according to per-linop batch slices.

        Parameters
        ----------
        ibatches : list[list[slice]]
            Per-linop input slices. Each element is a list of slices corresponding
            to the input dimensions of one linop in the chain.
        obatches : list[list[slice]]
            Per-linop output slices. Each element is a list of slices corresponding
            to the output dimensions of one linop in the chain.

        Returns
        -------
        Chain
            A new chain of the split sub-linops.
        """
        linops = [
            linop.split_forward(ibatch, obatch)
            for linop, ibatch, obatch in zip(self.linops, ibatches, obatches)
        ]
        split = copy(self)
        split.linops = nn.ModuleList(linops)
        return split

    def size(self, dim):
        # Query every linop; all linops that know `dim` must agree on it.
        out = None
        for linop in self.linops:
            tmp = linop.size(dim)
            if tmp is not None:
                if out is None:
                    out = tmp
                elif out != tmp:
                    raise ValueError(
                        f"Conflicting linop sizes found: {out} and {tmp} for dim {dim} in linop {linop} out of all linops {self.linops}"
                    )
        return out

    @property
    def dims(self):
        """Get the dims that appear anywhere in this linop chain."""
        return set().union(*[linop.dims for linop in self.linops])

    def adjoint(self):
        # (C B A)^H = A^H B^H C^H: adjoint each linop and reverse the order.
        linops = list(linop.adjoint() for linop in reversed(self.linops))
        adj = copy(self)
        adj.linops = nn.ModuleList(linops)
        return adj

    def normal(self, inner=None):
        """Compute the normal operator by folding through the chain.

        For a chain $C B A$, the normal is computed as
        $A^H (B^H (C^H C (B (A \\cdot))))$ by iterating ``linop.normal(inner)``
        in reverse order. This enables Toeplitz embedding and other per-linop
        normal optimizations to compose correctly.

        Parameters
        ----------
        inner : NamedLinop, optional
            An inner operator seeded from an outer chain or ``None``.

        Returns
        -------
        NamedLinop
            The composed normal operator.
        """
        # Fold each linop's normal around the accumulated inner operator.
        for linop in reversed(self.linops):
            inner = linop.normal(inner)
        return inner

    @staticmethod
    def split(chain, tile: Mapping[ND | str, slice]):
        """Split a linop into sub-linops.

        Parameters
        ----------
        chain : Chain
            The chain linop to split.
        tile : Mapping[ND | str, slice]
            Dictionary specifying how to slice the linop dimensions
        """
        # Dims not mentioned in `tile` are left unsliced (slice(None)).
        ibatches = [
            [tile.get(dim, slice(None)) for dim in linop.ishape]
            for linop in chain.linops
        ]
        obatches = [
            [tile.get(dim, slice(None)) for dim in linop.oshape]
            for linop in chain.linops
        ]
        return chain.split_forward(ibatches, obatches)

    @staticmethod
    def adj_split(chain, tile: Mapping[ND | str, slice]):
        """Split an adjoint linop into sub-linops.

        Parameters
        ----------
        chain : Chain
            The chain linop to split.
        tile : Mapping[ND | str, slice]
            Dictionary specifying how to slice the linop dimensions
        """
        ibatches = [
            [tile.get(dim, slice(None)) for dim in linop.ishape]
            for linop in chain.linops
        ]
        obatches = [
            [tile.get(dim, slice(None)) for dim in linop.oshape]
            for linop in chain.linops
        ]
        # Split the adjoint chain (i/o batch roles swap) and re-adjoint.
        return chain.H.split_forward(obatches, ibatches).H

    @property
    def shape(self):
        # Shape is derived from the sub-linops rather than stored locally.
        return NS(self.linops[0].ishape, self.linops[-1].oshape)

    @shape.setter
    def shape(self, val):
        self.ishape = val.ishape
        self.oshape = val.oshape

    @property
    def ishape(self):
        # Delegates to the first (innermost) linop.
        return self.linops[0].ishape

    @ishape.setter
    def ishape(self, val):
        self.linops[0].ishape = val

    @property
    def oshape(self):
        # Delegates to the last (outermost) linop.
        return self.linops[-1].oshape

    @oshape.setter
    def oshape(self, val):
        self.linops[-1].oshape = val

    def flatten(self):
        # A chain flattens into its constituent linops.
        return list(self.linops)

    def __getitem__(self, idx):
        # Integer index returns a sub-linop; slice returns a sub-chain.
        linops = self.linops[idx]
        if isinstance(linops, NamedLinop):
            return linops
        return Chain(*linops, name=self._name)

    def __len__(self):
        return len(self.linops)

    def __repr__(self):
        # Multi-line repr with one indented line per sub-linop; start/end
        # events are shown (as hex ids) only when present.
        output = ""
        output += INDENT.indent(self.repr_name + "(\n")
        with INDENT:
            for linop in self.linops:
                output += repr(linop) + "\n"

            if self.start_event is not None:
                output += INDENT.indent(f"start: {self.start_event.event_id:x}\n")
            if self.end_event is not None:
                output += INDENT.indent(f"end: {self.end_event.event_id:x}\n")
        output += INDENT.indent(")")
        return output

    def __copy__(self):
        # After a generic copy, re-copy sub-linops and rewire event listeners
        # so the copy does not share event state with the original.
        new = super().__copy__()
        new._setup_events()
        return new

dims property

dims

Get the dims that appear anywhere in this linop chain.

__init__

__init__(*linops, name: Optional[str] = None)
PARAMETER DESCRIPTION
*linops

Linops in order of execution. If linops = (A, B, C), the mathematical operator is \(C B A\).

TYPE: NamedLinop DEFAULT: ()

name

Display name for this chain.

TYPE: str DEFAULT: None

Source code in src/torchlinops/linops/chain.py
def __init__(self, *linops, name: Optional[str] = None):
    """
    Parameters
    ----------
    *linops : NamedLinop
        Linops in order of execution. If ``linops = (A, B, C)``, the
        mathematical operator is $C B A$.
    name : str, optional
        Display name for this chain.

    Raises
    ------
    ValueError
        If consecutive linops have mismatched input/output shapes.
    """
    # Chain input is the first linop's input; output is the last's output.
    super().__init__(NS(linops[0].ishape, linops[-1].oshape), name=name)
    self.linops = nn.ModuleList(list(linops))
    # Verify that each linop's output shape feeds the next linop's input.
    self._check_inputs_outputs()

__setattr__

__setattr__(name, value)

Bypasses pytorch's setattr, just for linops

Source code in src/torchlinops/linops/chain.py
def __setattr__(self, name, value):
    """Bypasses pytorch's setattr, just for linops"""
    if name == "linops":
        # Force descriptor lookup for this name
        # (nn.Module.__setattr__ would otherwise register the ModuleList
        # directly and skip the `linops` property setter).
        type(self).linops.fset(self, value)
    else:
        super().__setattr__(name, value)

adj_split staticmethod

adj_split(
    chain, tile: Mapping[NamedDimension | str, slice]
)

Split an adjoint linop into sub-linops.

PARAMETER DESCRIPTION
chain

The chain linop to split.

TYPE: Chain

tile

Dictionary specifying how to slice the linop dimensions

TYPE: Mapping[NamedDimension | str, slice]

Source code in src/torchlinops/linops/chain.py
@staticmethod
def adj_split(chain, tile: Mapping[ND | str, slice]):
    """Split an adjoint linop into sub-linops.

    Parameters
    ----------
    chain : Chain
        The chain linop to split.
    tile : Mapping[ND | str, slice]
        Dictionary specifying how to slice the linop dimensions

    Returns
    -------
    NamedLinop
        The split adjoint chain.
    """
    # Dims not mentioned in `tile` are left unsliced (slice(None)).
    ibatches = [
        [tile.get(dim, slice(None)) for dim in linop.ishape]
        for linop in chain.linops
    ]
    obatches = [
        [tile.get(dim, slice(None)) for dim in linop.oshape]
        for linop in chain.linops
    ]
    # Split the adjoint chain (i/o batch roles swap), then re-adjoint.
    return chain.H.split_forward(obatches, ibatches).H

normal

normal(inner=None)

Compute the normal operator by folding through the chain.

For a chain \(C B A\), the normal is computed as \(A^H (B^H (C^H C (B (A \cdot))))\) by iterating linop.normal(inner) in reverse order. This enables Toeplitz embedding and other per-linop normal optimizations to compose correctly.

PARAMETER DESCRIPTION
inner

An inner operator seeded from an outer chain or None.

TYPE: NamedLinop DEFAULT: None

RETURNS DESCRIPTION
NamedLinop

The composed normal operator.

Source code in src/torchlinops/linops/chain.py
def normal(self, inner=None):
    """Compute the normal operator by folding through the chain.

    For a chain $C B A$, the normal is computed as
    $A^H (B^H (C^H C (B (A \\cdot))))$ by iterating ``linop.normal(inner)``
    in reverse order. This enables Toeplitz embedding and other per-linop
    normal optimizations to compose correctly.

    Parameters
    ----------
    inner : NamedLinop, optional
        An inner operator seeded from an outer chain or ``None``.

    Returns
    -------
    NamedLinop
        The composed normal operator.
    """
    # Fold each linop's normal around the accumulated inner operator,
    # starting from the outermost linop.
    for linop in reversed(self.linops):
        inner = linop.normal(inner)
    return inner

split staticmethod

split(chain, tile: Mapping[NamedDimension | str, slice])

Split a linop into sub-linops.

PARAMETER DESCRIPTION
chain

The chain linop to split.

TYPE: Chain

tile

Dictionary specifying how to slice the linop dimensions

TYPE: Mapping[NamedDimension | str, slice]

Source code in src/torchlinops/linops/chain.py
@staticmethod
def split(chain, tile: Mapping[ND | str, slice]):
    """Split a linop into sub-linops.

    Parameters
    ----------
    chain : Chain
        The chain linop to split.
    tile : Mapping[ND | str, slice]
        Dictionary specifying how to slice the linop dimensions. Dimensions
        absent from ``tile`` are left whole (``slice(None)``).
    """
    # For a given shape, look up each dim's slice, defaulting to "take all".
    def _dim_slices(shape):
        return [tile.get(dim, slice(None)) for dim in shape]

    ibatches = [_dim_slices(op.ishape) for op in chain.linops]
    obatches = [_dim_slices(op.oshape) for op in chain.linops]
    return chain.split_forward(ibatches, obatches)

split_forward

split_forward(ibatches, obatches)

Split each constituent linop according to per-linop batch slices.

PARAMETER DESCRIPTION
ibatches

Per-linop input slices. Each element is a list of slices corresponding to the input dimensions of one linop in the chain.

TYPE: list[list[slice]]

obatches

Per-linop output slices. Each element is a list of slices corresponding to the output dimensions of one linop in the chain.

TYPE: list[list[slice]]

RETURNS DESCRIPTION
Chain

A new chain of the split sub-linops.

Source code in src/torchlinops/linops/chain.py
def split_forward(self, ibatches, obatches):
    """Split each constituent linop according to per-linop batch slices.

    Parameters
    ----------
    ibatches : list[list[slice]]
        Per-linop input slices. Each element is a list of slices corresponding
        to the input dimensions of one linop in the chain.
    obatches : list[list[slice]]
        Per-linop output slices. Each element is a list of slices corresponding
        to the output dimensions of one linop in the chain.

    Returns
    -------
    Chain
        A new chain of the split sub-linops.
    """
    new_linops = []
    for sub, ib, ob in zip(self.linops, ibatches, obatches):
        new_linops.append(sub.split_forward(ib, ob))
    # Shallow-copy the chain so settings carry over, then swap in the
    # freshly split sub-linops.
    out = copy(self)
    out.linops = nn.ModuleList(new_linops)
    return out

Concat

Bases: Threadable, NamedLinop

Concatenate some linops along an existing dimension.

Linops need not output tensors of the same size, but they should output tensors of the same number of dimensions.

Stacking type depends on dimensions provided:

Horizontal stacking (stacking along an input dimension)::

A B C

Vertical stacking (stacking along an output dimension)::

A
B
C

Diagonal stacking (stacking along separate input and output dimensions)::

A . .
. B .
. . C

Inherits from Threadable to support parallel execution of sub-linops. When threaded=True (default), each sub-linop is executed in parallel using a ThreadPoolExecutor.

Note that shared linops (e.g., Concat(A, A, idim="x")) are automatically shallow-copied to ensure independent identity for threading, while still sharing tensor data. See Threadable for details.

ATTRIBUTE DESCRIPTION
linops

The list of linops being concatenated.

TYPE: ModuleList

threaded

Whether to run sub-linops in parallel. Default is True.

TYPE: bool

num_workers

Number of worker threads. If None, defaults to the number of sub-linops.

TYPE: int | None

idim

Input dimension along which to concatenate.

TYPE: NamedDimension | None

odim

Output dimension along which to concatenate.

TYPE: NamedDimension | None

Source code in src/torchlinops/linops/concat.py
class Concat(Threadable, NamedLinop):
    """Concatenate some linops along an existing dimension.

    Linops need not output tensors of the same size, but they should
    output tensors of the same number of dimensions.

    Stacking type depends on dimensions provided:

    Horizontal stacking (stacking along an input dimension)::

        A B C

    Vertical stacking (stacking along an output dimension)::

        A
        B
        C

    Diagonal stacking (stacking along separate input and output dimensions)::

        A . .
        . B .
        . . C

    Inherits from ``Threadable`` to support parallel execution of sub-linops.
    When ``threaded=True`` (default), each sub-linop is executed in parallel
    using a ThreadPoolExecutor.

    Note that shared linops (e.g., ``Concat(A, A, idim="x")``) are automatically
    shallow-copied to ensure independent identity for threading, while still
    sharing tensor data. See ``Threadable`` for details.

    Attributes
    ----------
    linops : nn.ModuleList
        The list of linops being concatenated.
    threaded : bool
        Whether to run sub-linops in parallel. Default is True.
    num_workers : int | None
        Number of worker threads. If None, defaults to the number of sub-linops.
    idim : ND | None
        Input dimension along which to concatenate.
    odim : ND | None
        Output dimension along which to concatenate.
    """

    def __init__(
        self,
        *linops,
        idim: Optional[ND | str] = None,
        odim: Optional[ND | str] = None,
        **kwargs,
    ):
        """
        Parameters
        ----------
        *linops : NamedLinop
            The linops to concatenate.
        idim : str or ND, optional
            Input dimension along which to concatenate. If ``None``, the input
            is not concatenated (all linops receive the same input).
        odim : str or ND, optional
            Output dimension along which to concatenate. If ``None``, the output
            is not concatenated (outputs are summed).
        """
        # Validate shape compatibility before adopting the first linop's
        # shape as this Concat's overall shape.
        self._check_linop_compatibility(linops)
        super().__init__(NS(linops[0].ishape, linops[0].oshape), **kwargs)
        self.linops = nn.ModuleList(list(linops))
        # Resolve concat dims into indices, per-linop sizes, and cumulative
        # slice boundaries.
        self._setup_indices(idim, odim)

    @staticmethod
    def fn(concat, x):
        """Apply the forward concatenated operator to ``x``."""
        return concat._fn(
            x,
            concat.linops,
            concat.idim_idx,
            concat.odim_idx,
            concat.islices,
            concat.oslices,
            concat.threaded,
            concat.num_workers,
        )

    @staticmethod
    def adj_fn(concat, x):
        """Apply the adjoint: adjoint sub-linops with input/output roles swapped."""
        adj_linops = [linop.H for linop in concat.linops]
        return concat._fn(
            x,
            adj_linops,
            concat.odim_idx,
            concat.idim_idx,
            concat.oslices,
            concat.islices,
            concat.threaded,
            concat.num_workers,
        )

    @staticmethod
    def _fn(
        x: Tensor,
        linops,
        idim_idx,
        odim_idx,
        islices,
        oslices,
        threaded: bool = False,
        num_workers: int | None = None,
    ):
        """Unifies forward and adjoint functionality for stacked linops"""
        if idim_idx is not None:  # Diagonal, Horizontal
            # islices holds cumulative sizes; the last entry is the total
            # expected extent of the concat dimension.
            if islices[-1] != x.shape[idim_idx]:
                raise ValueError(
                    f"Concat Linop expecting input of size {islices[-1]} got input of size {x.shape} with non-matching concat size {x.shape[idim_idx]}"
                )
            # tensor_split with cumulative boundaries yields one trailing
            # empty chunk, which we drop.
            xs = x.tensor_split(islices, dim=idim_idx)[:-1]  # Omit final slice
        else:  # Vertical
            xs = [x] * len(oslices)

        if odim_idx is not None:  # Diagonal, Vertical
            if threaded:
                ys = _threaded_apply(list(linops), xs, num_workers)
            else:
                ys = [linop(xi) for xi, linop in zip(xs, linops)]
            return torch.concatenate(ys, dim=odim_idx)

        # Horizontal: outputs share a dimension, so sum them.
        if threaded:
            y = _threaded_apply_sum_reduce(list(linops), xs, num_workers)
        else:
            y = 0.0
            for xi, linop in zip(xs, linops):
                y += linop(xi)
        return y

    def size(self, dim):
        """Return the total size along ``dim``, or None if unknown.

        Concat dims report the sum of sub-linop sizes; any other dim is
        delegated to the first sub-linop that knows it.
        """
        if dim == self.idim:
            return sum(self.isizes)
        elif dim == self.odim:
            return sum(self.osizes)
        else:
            for linop in self.linops:
                if linop.size(dim) is not None:
                    return linop.size(dim)

    def split_forward(self, ibatch, obatch):
        """Split concat linop, making a new concat linop if necessary"""
        # Map requested slices onto each sub-linop's portion of the concat dims.
        ibatches = self.subslice(ibatch, self.idim_idx, self.islices, len(self.linops))
        obatches = self.subslice(obatch, self.odim_idx, self.oslices, len(self.linops))

        # Only linops selected on both the input and output side survive.
        output_linop_idxs = ibatches.keys() & obatches.keys()
        output_linop_idxs = sorted(list(output_linop_idxs))
        if len(output_linop_idxs) == 0:
            # No linops satisfy this slice (diagonal stacking)
            return Zero(self.ishape, self.oshape)
        elif len(output_linop_idxs) == 1:
            # Singleton linop
            linop_idx = output_linop_idxs.pop()
            linop = self.linops[linop_idx]
            ibatch, obatch = ibatches[linop_idx], obatches[linop_idx]
            return linop.split_forward(ibatch, obatch)
        else:
            # output_linop_idxs is already sorted above.
            output_linops = []
            for i in output_linop_idxs:
                linop = self.linops[i]
                ibatch, obatch = ibatches[i], obatches[i]
                output_linops.append(linop.split_forward(ibatch, obatch))
            return self.spinoff(output_linops, idim=self.idim, odim=self.odim)

    @staticmethod
    def subslice(batch: list[slice], dim_idx: Optional[int], slices, num_linops):
        """Given a slice over some dims of a concat linop,
        return a mapping from the linop index to the relevant sub-slice for that linop.
        """
        linops_batch = {}
        if dim_idx is not None:
            slc = batch[dim_idx]
            # Convert cumulative sizes into partition boundaries [0, s1, s2, ...].
            slice_partition = slices.detach().cpu().numpy().tolist()
            slice_partition.insert(0, 0)
            sub_linop_slices = partition_slices(slice_partition, slc)
            for i, slc in sub_linop_slices:
                sub_linop_batch = copy(batch)
                sub_linop_batch[dim_idx] = slc
                linops_batch[i] = sub_linop_batch
        else:
            # No concat dim here: every linop sees the same batch.
            for i in range(num_linops):
                linops_batch[i] = batch
        return linops_batch

    def adjoint(self):
        """Return the adjoint: adjoint sub-linops with idim/odim swapped."""
        adj_linops = [linop.H for linop in self.linops]
        adj_shape = adj_linops[0].shape
        return self.spinoff(
            linops=adj_linops,
            shape=adj_shape,
            idim=self.odim,
            odim=self.idim,
        )

    def normal(self, inner=None):
        """Compute the normal operator, specializing per stacking type."""
        if inner is None:
            # Standardize on this shape
            max_ishape = max_shape([linop.N.ishape for linop in self.linops])
            max_oshape = max_shape([linop.N.oshape for linop in self.linops])
            new_shape = NS(max_ishape, max_oshape)
            if self.idim is None:  # Vertical (inner product)
                linops = [linop.N for linop in self.linops]
                linops = standardize_shapes(linops, new_shape)
                new = Add(*linops)
                new.settings = self.settings  # Copy Threadable settings
                return new
            elif self.odim is None:  # Horizontal (outer product)
                rows = []
                new_idim, new_odim = self._get_new_normal_io_dims(new_shape, self.idim)
                for linop_left in self.linops:
                    row = []
                    for linop_right in self.linops:
                        # NOTE(review): `==` here falls back to identity for
                        # default nn.Modules — confirm NamedLinop does not
                        # override __eq__ with value semantics.
                        if linop_left == linop_right:
                            new_linop = linop_right.N
                        else:
                            new_linop = linop_left.H @ linop_right
                        row.append(new_linop)
                        # NOTE(review): standardize_shapes is re-applied on
                        # every inner iteration; confirm this is intended
                        # rather than standardizing once after the loop.
                        row = standardize_shapes(row, new_shape)
                    rows.append(
                        self.spinoff(
                            linops=row, shape=new_shape, idim=new_idim, odim=None
                        )
                    )
                # rows = standardize_shapes(rows, new_shape)
                return self.spinoff(rows, shape=new_shape, idim=None, odim=new_odim)
            else:  # Diagonal
                diag = []
                new_idim, new_odim = self._get_new_normal_io_dims(new_shape, self.idim)
                for linop in self.linops:
                    diag.append(linop.N)
                diag = standardize_shapes(diag, new_shape)
                return self.spinoff(diag, shape=new_shape, idim=new_idim, odim=new_odim)
        return super().normal(inner)

    @staticmethod
    def _get_new_normal_io_dims(new_shape, dim) -> tuple:
        """Map ``dim``'s position in the new ishape to its updated i/o dims."""
        i = new_shape.ishape.index(dim)
        new_idim = new_shape.ishape[i]
        new_odim = new_shape.oshape[i]
        return new_idim, new_odim

    @staticmethod
    def _check_linop_compatibility(linops: list[NamedLinop]):
        """Ensure linops can actually be concatenated along the requested dimension"""
        target_shape = linops[0].shape
        for linop in linops:
            if not (
                isequal(target_shape.ishape, linop.ishape)
                and isequal(target_shape.oshape, linop.oshape)
            ):
                raise ValueError(
                    f"Incompatible linops being stacked. Target shape: {target_shape} but got linop shape: {linop.shape}"
                )

    def _setup_indices(self, idim, odim):
        """Resolve idim/odim into dims, sizes, slice boundaries, and indices."""
        ishape = self.linops[0].ishape
        oshape = self.linops[0].oshape
        self.idim, self.isizes, self.islices = self._setup_dim(idim, ishape)
        self.odim, self.osizes, self.oslices = self._setup_dim(odim, oshape)

        if self.idim is None and self.odim is None:
            # Fixed: message previously read "At least one of idim and odim
            # cannot be None", which stated the opposite of the constraint,
            # and carried a placeholder-less f-prefix.
            raise ValueError("At least one of idim and odim must be provided (both were None).")

        self.idim_idx = self._infer_dim_idx(self.idim, ishape)
        self.odim_idx = self._infer_dim_idx(self.odim, oshape)

    def _setup_dim(self, dim, shape):
        """Return (dim, per-linop sizes, cumulative slice boundaries) or Nones."""
        if dim is not None:
            _dim = ND.infer(dim)
            if any(linop.size(_dim) is None for linop in self.linops):
                raise ValueError(
                    f"Found linop with undefined size for dim {_dim} when attempting concat."
                )
            _sizes = [linop.size(_dim) for linop in self.linops]
            _slices = torch.tensor(_sizes).cumsum(0)  # Keep on CPU
        else:
            _dim = None
            _sizes = None
            _slices = None
        return _dim, _sizes, _slices

    @staticmethod
    def _infer_dim_idx(dim: Optional[ND], shape: tuple[ND, ...]) -> Optional[int]:
        """Get index of dim within requested shape tuple

        Tries to infer index in the presence of ellipses "..." shapes
        Returns positive int if possible
        Otherwise, tries to return negative int
        Fails if neither is possible.
        Returns None if ``dim`` is None.
        """
        if dim is None:
            return None
        if dim not in shape:
            raise ValueError(
                f"Provided concat dimension {dim} not found in shape {shape}"
            )
        shape_list = [str(s) for s in shape]
        dim_idx = shape_list.index(str(dim))
        pre, post = shape_list[:dim_idx], shape_list[dim_idx + 1 :]
        if ELLIPSES in pre:
            if ELLIPSES in post:
                raise ValueError(
                    f"Cannot infer concat dimension for dim {dim} from shape {shape}"
                )
            else:
                # Ellipsis before dim: only a negative (from-the-end) index
                # is well-defined.
                return -(len(post) + 1)
        return len(pre)

    def __getitem__(self, idx):
        """Index or slice the stacked linops; slices yield a new Concat."""
        linops = self.linops[idx]
        if isinstance(linops, NamedLinop):
            return linops
        return self.spinoff(linops, idim=self.idim, odim=self.odim)

    def spinoff(self, linops=None, shape=None, idim=None, odim=None):
        """Helper function for creating a new linop using the provided inputs.

        Preserves settings from the original linop.
        """
        linops = linops if linops is not None else self.linops
        shape = shape if shape is not None else self.shape
        new = copy(self)
        new.shape = shape
        new.linops = nn.ModuleList(linops)
        new._setup_indices(idim=idim, odim=odim)
        return new

    def __len__(self):
        return len(self.linops)

    def __repr__(self):
        output = ""
        output += INDENT.indent(self.repr_name + f"({self._shape}\n")
        with INDENT:
            for linop in self.linops:
                output += repr(linop) + "\n"
            output += INDENT.indent(f"idim = {self.idim}, odim = {self.odim}\n")
        output += INDENT.indent(")")
        return output

__init__

__init__(
    *linops,
    idim: Optional[NamedDimension | str] = None,
    odim: Optional[NamedDimension | str] = None,
    **kwargs,
)
PARAMETER DESCRIPTION
*linops

The linops to concatenate.

TYPE: NamedLinop DEFAULT: ()

idim

Input dimension along which to concatenate. If None, the input is not concatenated (all linops receive the same input).

TYPE: str or NamedDimension DEFAULT: None

odim

Output dimension along which to concatenate. If None, the output is not concatenated (outputs are summed).

TYPE: str or NamedDimension DEFAULT: None

Source code in src/torchlinops/linops/concat.py
def __init__(
    self,
    *linops,
    idim: Optional[ND | str] = None,
    odim: Optional[ND | str] = None,
    **kwargs,
):
    """
    Parameters
    ----------
    *linops : NamedLinop
        The linops to concatenate.
    idim : str or ND, optional
        Input dimension along which to concatenate. If ``None``, the input
        is not concatenated (all linops receive the same input).
    odim : str or ND, optional
        Output dimension along which to concatenate. If ``None``, the output
        is not concatenated (outputs are summed).
    """
    # Validate shape compatibility first, so the first linop's shape can be
    # safely adopted as this Concat's overall shape.
    self._check_linop_compatibility(linops)
    super().__init__(NS(linops[0].ishape, linops[0].oshape), **kwargs)
    self.linops = nn.ModuleList(list(linops))
    # Resolve idim/odim into indices, per-linop sizes, and slice boundaries.
    self._setup_indices(idim, odim)

spinoff

spinoff(linops=None, shape=None, idim=None, odim=None)

Helper function for creating a new linop using the provided inputs.

Preserves settings from the original linop.

Source code in src/torchlinops/linops/concat.py
def spinoff(self, linops=None, shape=None, idim=None, odim=None):
    """Helper function for creating a new linop using the provided inputs.

    Preserves settings from the original linop. Arguments left as ``None``
    fall back to this linop's current values (except idim/odim, which are
    passed through to ``_setup_indices`` as given).
    """
    # Shallow-copy first so Threadable settings and other attributes carry over.
    new = copy(self)
    new.shape = self.shape if shape is None else shape
    new.linops = nn.ModuleList(self.linops if linops is None else linops)
    new._setup_indices(idim=idim, odim=odim)
    return new

split_forward

split_forward(ibatch, obatch)

Split concat linop, making a new concat linop if necessary

Source code in src/torchlinops/linops/concat.py
def split_forward(self, ibatch, obatch):
    """Split concat linop, making a new concat linop if necessary"""
    # Map the requested slices onto each sub-linop's portion of the
    # concatenated dimension(s).
    ibatches = self.subslice(ibatch, self.idim_idx, self.islices, len(self.linops))
    obatches = self.subslice(obatch, self.odim_idx, self.oslices, len(self.linops))

    # Only linops selected on BOTH the input and output side survive.
    output_linop_idxs = ibatches.keys() & obatches.keys()
    output_linop_idxs = sorted(list(output_linop_idxs))
    if len(output_linop_idxs) == 0:
        # No linops satisfy this slice (diagonal stacking)
        return Zero(self.ishape, self.oshape)
    elif len(output_linop_idxs) == 1:
        # Singleton linop
        linop_idx = output_linop_idxs.pop()
        linop = self.linops[linop_idx]
        ibatch, obatch = ibatches[linop_idx], obatches[linop_idx]
        return linop.split_forward(ibatch, obatch)
    else:
        # NOTE(review): output_linop_idxs was already sorted above; this
        # second sort is redundant but harmless.
        output_linop_idxs = sorted(list(output_linop_idxs))
        output_linops = []
        for i in output_linop_idxs:
            linop = self.linops[i]
            ibatch, obatch = ibatches[i], obatches[i]
            output_linops.append(linop.split_forward(ibatch, obatch))
        # Multiple survivors: rebuild a smaller Concat, preserving settings.
        return self.spinoff(output_linops, idim=self.idim, odim=self.odim)

subslice staticmethod

subslice(
    batch: list[slice],
    dim_idx: Optional[int],
    slices,
    num_linops,
)

Given a slice over some dims of a concat linop, return a mapping from the linop index to the relevant sub-slice for that linop.

Source code in src/torchlinops/linops/concat.py
@staticmethod
def subslice(batch: list[slice], dim_idx: Optional[int], slices, num_linops):
    """Given a slice over some dims of a concat linop,
    return a mapping from the linop index to the relevant sub-slice for that linop.
    """
    if dim_idx is None:
        # No concat dim here: every linop receives the same batch.
        return {i: batch for i in range(num_linops)}
    requested = batch[dim_idx]
    # Convert cumulative sizes into partition boundaries [0, s1, s2, ...].
    boundaries = [0] + slices.detach().cpu().numpy().tolist()
    result = {}
    for idx, sub_slc in partition_slices(boundaries, requested):
        entry = copy(batch)
        entry[dim_idx] = sub_slc
        result[idx] = entry
    return result

Dense

Bases: NamedLinop

Dense matrix-vector multiply.

"Dense" is used to distinguish from "sparse" linear operators. This operator performs a matrix-vector multiplication, potentially with batch and broadcast dimensions, implemented via einops.einsum.

The core operation is:

\(y_{o\dots} = \sum_{i\dots} W_{i\dots, o\dots} x_{i\dots}\)

where \(x\) is the input, \(W\) is the weight matrix, and \(y\) is the output. \(i\dots\) and \(o\dots\) represent the input and output dimensions involved in the multiplication. Other dimensions are treated as batch or broadcast dimensions.

Examples:

A simple batched multiplication:

  • Input \(x\) shape: \((A, N_x, N_y)\)
  • Weight \(W\) shape: \((A, T)\)
  • Output \(y\) shape: \((T, N_x, N_y)\)

Here, \(A\) is the input feature dimension, \(T\) is the output feature dimension, and \((N_x, N_y)\) are broadcast dimensions. The operation is:

\(y_{t, n_x, n_y} = \sum_{a} W_{a, t} x_{a, n_x, n_y}\)

Another example with a batch dimension \(C\) shared between input and weights:

  • Input \(x\) shape: \((C, A, N_x, N_y)\)
  • Weight \(W\) shape: \((C, A, A_1)\)
  • Output \(y\) shape: \((C, A_1, N_x, N_y)\)

The operation is:

\(y_{c, a_1, n_x, n_y} = \sum_{a} W_{c, a, a_1} x_{c, a, n_x, n_y}\)

Source code in src/torchlinops/linops/dense.py
class Dense(NamedLinop):
    r"""Dense matrix-vector multiply.

    "Dense" is used to distinguish from "sparse" linear operators. This
    operator performs a matrix-vector multiplication, potentially with batch
    and broadcast dimensions, implemented via ``einops.einsum``.

    The core operation is:

    $y_{o\dots} = \sum_{i\dots} W_{i\dots, o\dots} x_{i\dots}$

    where $x$ is the input, $W$ is the weight matrix, and
    $y$ is the output. $i\dots$ and $o\dots$ represent
    the input and output dimensions involved in the multiplication. Other
    dimensions are treated as batch or broadcast dimensions.

    Examples
    --------
    A simple batched multiplication:

    - Input $x$ shape: $(A, N_x, N_y)$
    - Weight $W$ shape: $(A, T)$
    - Output $y$ shape: $(T, N_x, N_y)$

    Here, $A$ is the input feature dimension, $T$ is the output
    feature dimension, and $(N_x, N_y)$ are broadcast dimensions.
    The operation is:

    $y_{t, n_x, n_y} = \sum_{a} W_{a, t} x_{a, n_x, n_y}$

    Another example with a batch dimension $C$ shared between input
    and weights:

    - Input $x$ shape: $(C, A, N_x, N_y)$
    - Weight $W$ shape: $(C, A, A_1)$
    - Output $y$ shape: $(C, A_1, N_x, N_y)$

    The operation is:

    $y_{c, a_1, n_x, n_y} = \sum_{a} W_{c, a, a_1} x_{c, a, n_x, n_y}$

    """

    def __init__(
        self,
        weight: Tensor,
        weightshape: Shape,
        ishape: Shape,
        oshape: Shape,
        broadcast_dims: Optional[list] = None,
    ):
        """
        Parameters
        ----------
        weight : Tensor
            The dense matrix used for this linop.
        weightshape : Shape
            The shape of the matrix, in symbolic form.
        ishape : Shape
            The input shape of the matrix.
        oshape : Shape
            The output shape of the matrix.
        broadcast_dims : list
            A list of the dimensions of weight that are intended to be broadcasted over the input.
            As such, they are excluded from splitting.
        """
        super().__init__(NS(ishape, oshape))
        self.weight = nn.Parameter(weight, requires_grad=False)
        self._shape.weightshape = weightshape

        broadcast_dims = broadcast_dims if broadcast_dims is not None else []
        self._shape.broadcast_dims = broadcast_dims

    @property
    def weightshape(self) -> Shape:
        weightshape = self._shape.weightshape
        if not isinstance(weightshape, Sequence):
            raise ValueError(
                f"Expected weightshape to be a sequence but got {type(weightshape)}: {weightshape}"
            )
        return weightshape

    @property
    def broadcast_dims(self):
        return self._shape.broadcast_dims

    @property
    def forward_einstr(self):
        return f"{self.einstr(self.ishape)},{self.einstr(self.weightshape)}->{self.einstr(self.oshape)}"

    @property
    def adj_einstr(self):
        return f"{self.einstr(self.oshape)},{self.einstr(self.weightshape)}->{self.einstr(self.ishape)}"

    @staticmethod
    def einstr(arr):
        return " ".join(str(s) for s in arr)

    @staticmethod
    def fn(dense, x, /):
        return einsum(x, dense.weight, dense.forward_einstr)

    @staticmethod
    def adj_fn(dense, x, /):
        return einsum(x, dense.weight.conj(), dense.adj_einstr)

    def adjoint(self):
        adj = copy(self)
        adj.weight = nn.Parameter(
            self.weight.conj(), requires_grad=adj.weight.requires_grad
        )
        adj._shape = adj._shape.H
        adj._update_suffix(adjoint=self._name is not None)
        return adj

    def normal(self, inner=None):
        """Compute the normal operator (adjoint times forward).

        Parameters
        ----------
        inner : NamedLinop, optional
            An optional inner operator to sandwich between the adjoint and
            forward. If None, consolidates two Dense operators into a single
            Dense.

        Returns
        -------
        NamedLinop
            The normal operator.

        Notes
        -----
        If inner is None, consolidate two Dense's into a single Dense
        ishape: [A B X Y]
        oshape: [C D X Y]
        wshape: [A B C D]

        Needs to become
        ishape: [A B X Y]
        oshape: [A1 B1 X Y]
        wshape: [A B A1 B1]

        New weight is attained as
        einsum(weight.conj(), weight, 'A1 B1 C D, A B C D -> A B A1 B1')

        -----
        ishape: [C A]
        oshape: [C1 A]
        wshape = [C C1]

        Needs to become
        ishape: [C A]
        oshape: [C2 A]
        wshape = [C C2]

        einsum(weight.conj(), weight, 'C1 C2, C C1 -> C C2')


        """
        new_oshape = []
        weight_conj_shape = list(deepcopy(self.weightshape))
        wdiag_shape = []
        wout_shape = []
        win_shape = []
        used_shapes = self.ishape + self.oshape
        shape_updates = {}
        # Make new oshape and weight shape
        # Rules:
        # New weightshape
        #   If dim appears in ishape and weightshape but not oshape -> increment
        #   If dim appears in ishape and weightshape AND oshape -> don't increment
        #   If dim doesn't appear in ishape or weightshape -> don't add it to new weightshape
        # Other rules:
        # new ishape is same as old ishape
        # new oshape is ishape but updated with new dimensions
        for dim in self.ishape:
            if dim in self.weightshape:
                if dim not in self.oshape:
                    win_shape.append(dim)
                    new_dim = dim.next_unused(used_shapes)
                    shape_updates[dim] = new_dim
                    wout_shape.append(new_dim)
                else:
                    wdiag_shape.append(dim)
                    new_dim = dim
                i = weight_conj_shape.index(dim)
                weight_conj_shape[i] = new_dim
            else:
                new_dim = dim
            new_oshape.append(new_dim)

        if config.inner_not_relevant(inner):
            # Consolidate dense and dense adjoint into single dense
            new_weight_shape = wdiag_shape + wout_shape + win_shape
            einstr = shapes2einstr(
                self.weightshape,
                weight_conj_shape,
                new_weight_shape,
            )
            new_weight = einsum(self.weight, self.weight.conj(), einstr)
            normal = type(self)(
                new_weight,
                tuple(new_weight_shape),
                self.ishape,
                new_oshape,
            )
            normal._name = self._name
            normal._update_suffix(normal=self._name is not None)
            normal._shape_updates = shape_updates
            return normal
        _shape_updates = getattr(inner, "_shape_updates", {})
        _shape_updates.update(shape_updates)
        pre = copy(self)
        pre.oshape = inner.ishape
        post = self.adjoint()  # Copy happens inside adjoint
        post.ishape = inner.oshape
        post.oshape = new_oshape
        normal = post @ inner @ pre
        normal._shape_updates = _shape_updates
        return normal

    def split_forward(self, ibatch, obatch):
        weight = self.split_weight(ibatch, obatch, self.weight)
        out = copy(self)
        out.weight = nn.Parameter(weight, requires_grad=self.weight.requires_grad)
        return out

    def split_weight(self, ibatch, obatch, /, weight):
        weightbatch = [slice(None)] * len(self.weightshape)
        for dim, batch in zip(self.ishape, ibatch):
            if dim in self.weightshape and dim not in self.broadcast_dims:
                weightbatch[self.weightshape.index(dim)] = batch
        for dim, batch in zip(self.oshape, obatch):
            if dim in self.weightshape and dim not in self.broadcast_dims:
                weightbatch[self.weightshape.index(dim)] = batch
        return weight[tuple(weightbatch)]

    def size(self, dim: str):
        if dim in self.broadcast_dims:
            return None
        if dim in self.weightshape:
            return self.weight.shape[self.weightshape.index(dim)]
        return None

__init__

__init__(
    weight: Tensor,
    weightshape: Shape,
    ishape: Shape,
    oshape: Shape,
    broadcast_dims: Optional[list] = None,
)
PARAMETER DESCRIPTION
weight

The dense matrix used for this linop.

TYPE: Tensor

weightshape

The shape of the matrix, in symbolic form.

TYPE: Shape

ishape

The input shape of the matrix.

TYPE: Shape

oshape

The output shape of the matrix.

TYPE: Shape

broadcast_dims

A list of the dimensions of weight that are intended to be broadcasted over the input. As such, they are excluded from splitting.

TYPE: list DEFAULT: None

Source code in src/torchlinops/linops/dense.py
def __init__(
    self,
    weight: Tensor,
    weightshape: Shape,
    ishape: Shape,
    oshape: Shape,
    broadcast_dims: Optional[list] = None,
):
    """
    Parameters
    ----------
    weight : Tensor
        The dense matrix used for this linop.
    weightshape : Shape
        The shape of the matrix, in symbolic form.
    ishape : Shape
        The input shape of the matrix.
    oshape : Shape
        The output shape of the matrix.
    broadcast_dims : list
        A list of the dimensions of weight that are intended to be broadcasted over the input.
        As such, they are excluded from splitting.
    """
    super().__init__(NS(ishape, oshape))
    # Register the weight as a frozen Parameter so it tracks device moves.
    self.weight = nn.Parameter(weight, requires_grad=False)
    self._shape.weightshape = weightshape
    self._shape.broadcast_dims = [] if broadcast_dims is None else broadcast_dims

normal

normal(inner=None)

Compute the normal operator (adjoint times forward).

PARAMETER DESCRIPTION
inner

An optional inner operator to sandwich between the adjoint and forward. If None, consolidates two Dense operators into a single Dense.

TYPE: NamedLinop DEFAULT: None

RETURNS DESCRIPTION
NamedLinop

The normal operator.

Notes

If inner is None, the two Dense factors are consolidated into a single Dense.
For example, an operator with ishape [A B X Y], oshape [C D X Y], and
wshape [A B C D] has a normal operator with
ishape [A B X Y], oshape [A1 B1 X Y], and wshape [A B A1 B1].

The new weight is obtained as einsum(weight.conj(), weight, 'A1 B1 C D, A B C D -> A B A1 B1')

A second example: ishape [C A], oshape [C1 A], wshape [C C1]
must become ishape [C A], oshape [C2 A], wshape [C C2], with

einsum(weight.conj(), weight, 'C1 C2, C C1 -> C C2')

Source code in src/torchlinops/linops/dense.py
def normal(self, inner=None):
    """Compute the normal operator (adjoint times forward).

    Parameters
    ----------
    inner : NamedLinop, optional
        An optional inner operator to sandwich between the adjoint and
        forward. If None, consolidates two Dense operators into a single
        Dense.

    Returns
    -------
    NamedLinop
        The normal operator.

    Notes
    -----
    If inner is None, consolidate two Dense's into a single Dense
    ishape: [A B X Y]
    oshape: [C D X Y]
    wshape: [A B C D]

    Needs to become
    ishape: [A B X Y]
    oshape: [A1 B1 X Y]
    wshape: [A B A1 B1]

    New weight is attained as
    einsum(weight.conj(), weight, 'A1 B1 C D, A B C D -> A B A1 B1')

    -----
    ishape: [C A]
    oshape: [C1 A]
    wshape = [C C1]

    Needs to become
    ishape: [C A]
    oshape: [C2 A]
    wshape = [C C2]

    einsum(weight.conj(), weight, 'C1 C2, C C1 -> C C2')


    """
    new_oshape = []
    weight_conj_shape = list(deepcopy(self.weightshape))
    wdiag_shape = []  # dims shared by ishape, weightshape AND oshape (kept as-is)
    wout_shape = []  # freshly renamed output dims of the normal operator
    win_shape = []  # input dims that the normal operator contracts over
    used_shapes = self.ishape + self.oshape
    shape_updates = {}  # old dim -> renamed dim, exposed for downstream linops
    # Make new oshape and weight shape
    # Rules:
    # New weightshape
    #   If dim appears in ishape and weightshape but not oshape -> increment
    #   If dim appears in ishape and weightshape AND oshape -> don't increment
    #   If dim doesn't appear in ishape or weightshape -> don't add it to new weightshape
    # Other rules:
    # new ishape is same as old ishape
    # new oshape is ishape but updated with new dimensions
    for dim in self.ishape:
        if dim in self.weightshape:
            if dim not in self.oshape:
                # Contracted dim: the normal operator needs a fresh name
                # (e.g. A -> A1) for its output copy of this dim.
                win_shape.append(dim)
                new_dim = dim.next_unused(used_shapes)
                shape_updates[dim] = new_dim
                wout_shape.append(new_dim)
            else:
                wdiag_shape.append(dim)
                new_dim = dim
            i = weight_conj_shape.index(dim)
            weight_conj_shape[i] = new_dim
        else:
            new_dim = dim
        new_oshape.append(new_dim)

    if config.inner_not_relevant(inner):
        # Consolidate dense and dense adjoint into single dense
        new_weight_shape = wdiag_shape + wout_shape + win_shape
        einstr = shapes2einstr(
            self.weightshape,
            weight_conj_shape,
            new_weight_shape,
        )
        new_weight = einsum(self.weight, self.weight.conj(), einstr)
        normal = type(self)(
            new_weight,
            tuple(new_weight_shape),
            self.ishape,
            new_oshape,
        )
        normal._name = self._name
        normal._update_suffix(normal=self._name is not None)
        normal._shape_updates = shape_updates
        return normal
    # Otherwise build the sandwich A^H @ inner @ A, stitching the copies'
    # shapes to match inner's input/output shapes.
    _shape_updates = getattr(inner, "_shape_updates", {})
    _shape_updates.update(shape_updates)
    pre = copy(self)
    pre.oshape = inner.ishape
    post = self.adjoint()  # Copy happens inside adjoint
    post.ishape = inner.oshape
    post.oshape = new_oshape
    normal = post @ inner @ pre
    normal._shape_updates = _shape_updates
    return normal

DeviceSpec dataclass

Lightweight data structure for holding useful CUDA-related objects for multi-GPU computation.

ATTRIBUTE DESCRIPTION
device

The device for computation and transfers.

TYPE: device

compute_stream

Stream used for computation on this device. Set automatically by p2p_setup.

TYPE: (Stream, optional)

transfer_stream

Stream used for data transfers to/from this device. Obtained from a registry to enable stream reuse across transfers.

TYPE: (Stream, optional)

METHOD DESCRIPTION
p2p_setup

Configure compute and transfer streams for peer-to-peer transfers.

get_transfer_stream

Get or create a transfer stream for a source/target device pair.

Source code in src/torchlinops/linops/device.py
@dataclass
class DeviceSpec:
    """Bundle of CUDA-related objects used for multi-GPU computation.

    Pairs a device with the streams used for compute and for peer-to-peer
    transfers. Streams are filled in lazily by ``p2p_setup``; transfer
    streams come from a module-level registry so the same source/target
    device pair always reuses a single stream.

    Methods
    -------
    p2p_setup(other_device)
        Configure compute and transfer streams for peer-to-peer transfers.
    get_transfer_stream(source_device, target_device)
        Get or create a transfer stream for a source/target device pair.
    """

    device: Any = field(default_factory=lambda: torch.device("cpu"))
    """Device for the streams."""
    compute_stream: Optional[Stream] = None
    """Stream used for computation."""
    transfer_stream: Optional[Stream] = None
    """Stream used for data transfer."""

    def __post_init__(self):
        """Coerce string device specifiers into ``torch.device`` objects."""
        dev = self.device
        if isinstance(dev, str):
            self.device = torch.device(dev)

    def p2p_setup(self, other_device):
        """Lazily populate compute/transfer streams for a CUDA<->CUDA pair.

        Parameters
        ----------
        other_device : torch.device
            The other device involved in the peer-to-peer transfer.
        """
        both_cuda = self.device.type == "cuda" and other_device.type == "cuda"
        if not both_cuda:
            return
        if self.compute_stream is None:  # pragma: no cover
            self.compute_stream = default_stream(self.device)
        if self.transfer_stream is None:  # pragma: no cover
            self.transfer_stream = self.get_transfer_stream(
                self.device, other_device
            )

    @property
    def type(self):
        """Mirror of ``torch.device.type`` (e.g. ``"cpu"`` or ``"cuda"``)."""
        return self.device.type

    @staticmethod
    def get_transfer_stream(
        source_device: torch.device, target_device: torch.device
    ):  # pragma: no cover
        """Return the cached transfer stream for a (source, target) pair.

        A dedicated stream is created and stored in the module-level registry
        on first use, so repeated transfers between the same device pair
        share one stream.

        Parameters
        ----------
        source_device : torch.device
            The source device for transfers.
        target_device : torch.device
            The target device for transfers.

        Returns
        -------
        Stream
            A CUDA stream for performing transfers.
        """
        key = (source_device, target_device)
        cached = _TRANSFER_STREAMS_REGISTRY.get(key)
        if cached is not None:
            return cached
        new_stream = Stream(source_device)
        _TRANSFER_STREAMS_REGISTRY[key] = new_stream
        return new_stream

compute_stream class-attribute instance-attribute

compute_stream: Optional[Stream] = None

Stream used for computation.

device class-attribute instance-attribute

device: Any = field(default_factory=lambda: device('cpu'))

Device for the streams.

transfer_stream class-attribute instance-attribute

transfer_stream: Optional[Stream] = None

Stream used for data transfer.

type property

type

Passthrough for torch.device.type.

__post_init__

__post_init__()

Ensure self.device is a proper torch.device.

Source code in src/torchlinops/linops/device.py
def __post_init__(self):
    """Coerce a string device specifier into a ``torch.device`` object."""
    dev = self.device
    if isinstance(dev, str):
        self.device = torch.device(dev)

get_transfer_stream staticmethod

get_transfer_stream(
    source_device: device, target_device: device
)

Return the stream used for device transfers associated with this device.

Streams are cached in a registry to enable reuse. Each source/target device pair gets a dedicated transfer stream.

PARAMETER DESCRIPTION
source_device

The source device for transfers.

TYPE: device

target_device

The target device for transfers.

TYPE: device

RETURNS DESCRIPTION
Stream

A CUDA stream for performing transfers.

Source code in src/torchlinops/linops/device.py
@staticmethod
def get_transfer_stream(
    source_device: torch.device, target_device: torch.device
):  # pragma: no cover
    """Return the cached transfer stream for a (source, target) device pair.

    A dedicated stream is created and stored in the module-level registry on
    first use, so repeated transfers between the same pair reuse one stream.

    Parameters
    ----------
    source_device : torch.device
        The source device for transfers.
    target_device : torch.device
        The target device for transfers.

    Returns
    -------
    Stream
        A CUDA stream for performing transfers.
    """
    key = (source_device, target_device)
    cached = _TRANSFER_STREAMS_REGISTRY.get(key)
    if cached is not None:
        return cached
    new_stream = Stream(source_device)
    _TRANSFER_STREAMS_REGISTRY[key] = new_stream
    return new_stream

p2p_setup

p2p_setup(other_device)

Sets up compute and transfer streams for peer2peer transfers, if not set yet.

PARAMETER DESCRIPTION
other_device

The other device involved in the peer-to-peer transfer.

TYPE: device

Source code in src/torchlinops/linops/device.py
def p2p_setup(self, other_device):
    """Lazily populate compute and transfer streams for a CUDA<->CUDA pair.

    Does nothing unless both devices are CUDA devices.

    Parameters
    ----------
    other_device : torch.device
        The other device involved in the peer-to-peer transfer.
    """
    both_cuda = self.device.type == "cuda" and other_device.type == "cuda"
    if not both_cuda:
        return
    if self.compute_stream is None:  # pragma: no cover
        self.compute_stream = default_stream(self.device)
    if self.transfer_stream is None:  # pragma: no cover
        self.transfer_stream = self.get_transfer_stream(self.device, other_device)

Diagonal

Bases: NamedLinop

Elementwise diagonal linear operator \(D(x) = w \odot x\).

The forward operation is pointwise multiplication by a weight tensor w. The adjoint is \(D^H(x) = \bar{w} \odot x\) and the normal is \(D^N(x) = |w|^2 \odot x\).

Because the input and output shapes are identical, Diagonal sets oshape = ishape and keeps them synchronized.

ATTRIBUTE DESCRIPTION
weight

The diagonal weight tensor \(w\).

TYPE: Parameter

broadcast_dims

Dimensions along which the weight is broadcast (not stored explicitly).

TYPE: list

Source code in src/torchlinops/linops/diagonal.py
class Diagonal(NamedLinop):
    """Elementwise diagonal linear operator $D(x) = w \\odot x$.

    The forward operation is pointwise multiplication by a weight tensor *w*.
    The adjoint is $D^H(x) = \\bar{w} \\odot x$ and the normal is
    $D^N(x) = |w|^2 \\odot x$.

    Because the input and output shapes are identical, ``Diagonal`` sets
    ``oshape = ishape`` and keeps them synchronized.

    Attributes
    ----------
    weight : nn.Parameter
        The diagonal weight tensor $w$.
    broadcast_dims : list
        Dimensions along which the weight is broadcast (not stored explicitly).
    """

    def __init__(
        self,
        weight: torch.Tensor,
        ioshape: Optional[Shape] = None,
        broadcast_dims: Optional[Shape] = None,
    ):
        """
        Parameters
        ----------
        weight : Tensor
            The diagonal weight tensor.
        ioshape : Shape, optional
            Named dimensions for input and output (they are the same).
        broadcast_dims : Shape, optional
            Dimensions along which *weight* should be broadcast rather than
            indexed. Useful when the weight has fewer dimensions than the input.

        Raises
        ------
        ValueError
            If ``weight`` has more dimensions than ``ioshape``.
        """
        # Weight may have FEWER dims than ioshape (leading dims broadcast),
        # but never more.
        if ioshape is not None and len(weight.shape) > len(ioshape):
            raise ValueError(
                f"All dimensions must be named or broadcastable, but got weight shape {weight.shape} and ioshape {ioshape}"
            )
        # if broadcast_dims is not None:
        #     warn(
        #         f"broadcast_dims argument is deprecated for torchlinops Diagonal but got {broadcast_dims}",
        #         DeprecationWarning,
        #         stacklevel=2,
        #     )
        super().__init__(NS(ioshape))
        # Frozen parameter: registered on the module but not trainable.
        self.weight = nn.Parameter(weight, requires_grad=False)
        # assert (
        #     len(self.ishape) >= len(self.weight.shape)
        # ), f"Weight cannot have fewer dimensions than the input shape: ishape: {self.ishape}, weight: {weight.shape}"
        broadcast_dims = broadcast_dims if broadcast_dims is not None else []
        # The wildcard ANY dim always counts as broadcast.
        if ANY in self.ishape:
            broadcast_dims.append(ANY)
        self._shape.broadcast_dims = broadcast_dims

    @classmethod
    def from_weight(
        cls,
        weight: Tensor,
        weight_shape: Shape,
        ioshape: Shape,
        shape_kwargs: Optional[dict] = None,
    ):
        """Construct a ``Diagonal`` by expanding *weight* to match *ioshape* via einops.

        Parameters
        ----------
        weight : Tensor
            The weight tensor in its original (possibly lower-dimensional) shape.
        weight_shape : Shape
            Named dimensions labeling the axes of *weight*.
        ioshape : Shape
            Target named dimensions for the expanded weight.
        shape_kwargs : dict, optional
            Extra keyword arguments forwarded to ``einops.repeat``.

        Returns
        -------
        Diagonal
            A new diagonal linop with the expanded weight.
        """
        shape_kwargs = shape_kwargs if shape_kwargs is not None else {}
        if len(weight.shape) > len(ioshape):
            raise ValueError(
                f"All dimensions must be named or broadcastable, but got weight shape {weight.shape} and ioshape {ioshape}"
            )
        # Materialize the broadcast, e.g. "A B -> A B C" with C given via
        # shape_kwargs.
        weight = repeat(
            weight,
            f"{' '.join(weight_shape)} -> {' '.join(ioshape)}",
            **shape_kwargs,
        )
        return cls(weight, ioshape)

    @property
    def broadcast_dims(self):
        """Input dimensions that the weight broadcasts over (excluded from splitting)."""
        return self._shape.broadcast_dims

    @broadcast_dims.setter
    def broadcast_dims(self, val):
        self._shape.broadcast_dims = val

    # Override shape setters too
    # A diagonal operator is "square": setting either shape updates both so
    # ishape and oshape never drift apart.
    @NamedLinop.ishape.setter
    def ishape(self, val):
        self._shape.ishape = val
        self._shape.oshape = val

    @NamedLinop.oshape.setter
    def oshape(self, val):
        self._shape.oshape = val
        self._shape.ishape = val

    @staticmethod
    def fn(diagonal, x, /):
        """Forward: elementwise multiply by the weight."""
        return x * diagonal.weight

    @staticmethod
    def adj_fn(diagonal, x, /):
        """Adjoint: elementwise multiply by the conjugated weight."""
        return x * torch.conj(diagonal.weight)

    @staticmethod
    def normal_fn(diagonal, x, /):
        """Normal: elementwise multiply by the squared magnitude of the weight."""
        return x * torch.abs(diagonal.weight) ** 2

    def adjoint(self):
        """Return a copy of this linop with conjugated weight."""
        adj = copy(self)
        adj.weight = nn.Parameter(
            self.weight.conj(),
            requires_grad=self.weight.requires_grad,
        )
        return adj

    def normal(self, inner=None):
        """Return the normal operator, folding |w|^2 into one Diagonal when possible."""
        if inner is None:
            normal = copy(self)
            normal.weight = nn.Parameter(
                torch.abs(self.weight) ** 2,
                requires_grad=self.weight.requires_grad,
            )
            return normal
        return super().normal(inner)

    def split_forward(self, ibatch, obatch):
        """Return a copy of this linop restricted to the given batch slices."""
        weight = self.split_weight(ibatch, obatch, self.weight)
        split = copy(self)
        split.weight = nn.Parameter(weight, requires_grad=self.weight.requires_grad)
        return split

    def split_weight(self, ibatch, obatch, /, weight):
        """Slice the weight according to ibatch; broadcast dims are left whole."""
        assert ibatch == obatch, "Diagonal linop must be split identically"
        # Filter out broadcastable dims
        ibatch = [
            slice(None) if dim in self.broadcast_dims else slc
            for slc, dim in zip(ibatch, self.ishape)
        ]
        # The weight may have fewer dims than ishape, so only the trailing
        # slices apply to it.
        return weight[tuple(ibatch[-len(weight.shape) :])]

    def size(self, dim: str):
        """Return the size of named dimension ``dim``, or None if unconstrained."""
        if dim in self.ishape:
            # Leading ishape dims with no matching weight dim are implicitly
            # broadcast and impose no size constraint.
            n_broadcast = len(self.ishape) - len(self.weight.shape)
            if self.ishape.index(dim) < n_broadcast or dim in self.broadcast_dims:
                return None
            else:
                return self.weight.shape[self.ishape.index(dim) - n_broadcast]
        return None

    def __pow__(self, exponent):
        """Elementwise power: return a copy whose weight is ``weight ** exponent``."""
        new = copy(self)
        new.weight = nn.Parameter(
            self.weight**exponent,
            requires_grad=self.weight.requires_grad,
        )
        return new

__init__

__init__(
    weight: Tensor,
    ioshape: Optional[Shape] = None,
    broadcast_dims: Optional[Shape] = None,
)
PARAMETER DESCRIPTION
weight

The diagonal weight tensor.

TYPE: Tensor

ioshape

Named dimensions for input and output (they are the same).

TYPE: Shape DEFAULT: None

broadcast_dims

Dimensions along which weight should be broadcast rather than indexed. Useful when the weight has fewer dimensions than the input.

TYPE: Shape DEFAULT: None

Source code in src/torchlinops/linops/diagonal.py
def __init__(
    self,
    weight: torch.Tensor,
    ioshape: Optional[Shape] = None,
    broadcast_dims: Optional[Shape] = None,
):
    """Create a diagonal linop from a weight tensor.

    Parameters
    ----------
    weight : Tensor
        The diagonal weight tensor.
    ioshape : Shape, optional
        Named dimensions for input and output (they are the same).
    broadcast_dims : Shape, optional
        Dimensions along which *weight* should be broadcast rather than
        indexed. Useful when the weight has fewer dimensions than the input.

    Raises
    ------
    ValueError
        If *weight* has more dimensions than *ioshape*.
    """
    # Weight may have fewer dims than ioshape (leading dims broadcast),
    # but never more.
    if ioshape is not None and len(weight.shape) > len(ioshape):
        raise ValueError(
            f"All dimensions must be named or broadcastable, but got weight shape {weight.shape} and ioshape {ioshape}"
        )
    super().__init__(NS(ioshape))
    # Frozen parameter: registered on the module but not trainable.
    self.weight = nn.Parameter(weight, requires_grad=False)
    if broadcast_dims is None:
        broadcast_dims = []
    # The wildcard ANY dim always counts as broadcast.
    if ANY in self.ishape:
        broadcast_dims.append(ANY)
    self._shape.broadcast_dims = broadcast_dims

from_weight classmethod

from_weight(
    weight: Tensor,
    weight_shape: Shape,
    ioshape: Shape,
    shape_kwargs: Optional[dict] = None,
)

Construct a Diagonal by expanding weight to match ioshape via einops.

PARAMETER DESCRIPTION
weight

The weight tensor in its original (possibly lower-dimensional) shape.

TYPE: Tensor

weight_shape

Named dimensions labeling the axes of weight.

TYPE: Shape

ioshape

Target named dimensions for the expanded weight.

TYPE: Shape

shape_kwargs

Extra keyword arguments forwarded to einops.repeat.

TYPE: dict DEFAULT: None

RETURNS DESCRIPTION
Diagonal

A new diagonal linop with the expanded weight.

Source code in src/torchlinops/linops/diagonal.py
@classmethod
def from_weight(
    cls,
    weight: Tensor,
    weight_shape: Shape,
    ioshape: Shape,
    shape_kwargs: Optional[dict] = None,
):
    """Build a ``Diagonal`` by expanding *weight* to *ioshape* with einops.

    Parameters
    ----------
    weight : Tensor
        The weight tensor in its original (possibly lower-dimensional) shape.
    weight_shape : Shape
        Named dimensions labeling the axes of *weight*.
    ioshape : Shape
        Target named dimensions for the expanded weight.
    shape_kwargs : dict, optional
        Extra keyword arguments forwarded to ``einops.repeat``.

    Returns
    -------
    Diagonal
        A new diagonal linop with the expanded weight.

    Raises
    ------
    ValueError
        If *weight* has more dimensions than *ioshape*.
    """
    if shape_kwargs is None:
        shape_kwargs = {}
    if len(weight.shape) > len(ioshape):
        raise ValueError(
            f"All dimensions must be named or broadcastable, but got weight shape {weight.shape} and ioshape {ioshape}"
        )
    pattern = f"{' '.join(weight_shape)} -> {' '.join(ioshape)}"
    expanded = repeat(weight, pattern, **shape_kwargs)
    return cls(expanded, ioshape)

FFT

Bases: NamedLinop

\(n\)-dimensional Fast Fourier Transform as a named linear operator.

With norm="ortho" (the default), the FFT is unitary: \(F^H F = I\). This means the normal operator is the identity and the adjoint is the inverse FFT.

ATTRIBUTE DESCRIPTION
ndim

Number of spatial dimensions to transform.

TYPE: int

norm

FFT normalization mode.

TYPE: str or None

centered

Whether to treat the array center as the origin (sigpy convention).

TYPE: bool

Source code in src/torchlinops/linops/fft.py
class FFT(NamedLinop):
    """$n$-dimensional Fast Fourier Transform as a named linear operator.

    With ``norm="ortho"`` (the default), the FFT is unitary: $F^H F = I$.
    This means the normal operator is the identity and the adjoint is the
    inverse FFT.

    Attributes
    ----------
    ndim : int
        Number of spatial dimensions to transform.
    norm : str or None
        FFT normalization mode.
    centered : bool
        Whether to treat the array center as the origin (sigpy convention).
    """

    def __init__(
        self,
        ndim: int,
        batch_shape: Optional[Shape] = None,
        grid_shapes: Optional[tuple[Shape, Shape]] = None,
        norm: Optional[str] = "ortho",
        centered: bool = False,
    ):
        """
        Parameters
        ----------
        ndim : int
            Number of dimensions to transform (1, 2, or 3).
        batch_shape : Shape, optional
            Named batch dimensions prepended to the grid dimensions.
            Defaults to ``("...",)``.
        grid_shapes : tuple[Shape, Shape], optional
            Pair of shapes ``(primal, dual)`` naming the input (image-space)
            and output (k-space) grid dimensions. Defaults to
            ``(Nx[, Ny[, Nz]])`` and ``(Kx[, Ky[, Kz]])``.
        norm : str or None, default ``"ortho"``
            Normalization applied to the FFT. Only ``"ortho"`` gives a true
            unitary forward/adjoint pair.
        centered : bool, default False
            If ``True``, treat the center of the array (``N // 2``) as the
            origin via ``fftshift`` / ``ifftshift``. Mimics sigpy convention.

        Raises
        ------
        ValueError
            If ``grid_shapes`` is not a pair, or if its two shapes have
            different lengths.
        """
        self.ndim = ndim
        # The FFT always acts on the trailing `ndim` axes.
        self.dim = tuple(range(-self.ndim, 0))
        self.grid_shapes = grid_shapes
        if grid_shapes is None:
            grid_shapes = get_nd_shape(self.dim), get_nd_shape(self.dim, kspace=True)
        elif len(grid_shapes) != 2:
            raise ValueError(
                f"grid_shapes should consist of two shape tuples but got {grid_shapes}"
            )
        if len(grid_shapes[0]) != len(grid_shapes[1]):
            # Bug fix: the previous message had unbalanced parentheses and
            # printed the shapes wrapped in literal "len(...)" text.
            raise ValueError(
                "Input and output shapes of FFT must have same length but got "
                f"{grid_shapes[0]} and {grid_shapes[1]}"
            )
        batch_shape = default_to(("...",), batch_shape)
        dim_shape = NS(*grid_shapes)
        shape = NS(batch_shape) + dim_shape
        super().__init__(shape)
        self._shape.batch_shape = batch_shape
        self._shape.input_grid_shape = grid_shapes[0]
        self._shape.output_grid_shape = grid_shapes[1]
        self.norm = norm
        self.centered = centered

    @property
    def batch_shape(self):
        """Named batch dimensions preceding the grid dimensions."""
        return self._shape.batch_shape

    @staticmethod
    def fn(linop, x):
        """Forward FFT over the trailing ``ndim`` axes of ``x``."""
        if linop.centered:
            x = fft.ifftshift(x, dim=linop.dim)
        x = fft.fftn(x, dim=linop.dim, norm=linop.norm)
        if linop.centered:
            x = fft.fftshift(x, dim=linop.dim)
        return x

    @staticmethod
    def adj_fn(linop, x):
        """Adjoint (inverse) FFT over the trailing ``ndim`` axes of ``x``."""
        if linop.centered:
            x = fft.ifftshift(x, dim=linop.dim)
        x = fft.ifftn(x, dim=linop.dim, norm=linop.norm)
        if linop.centered:
            x = fft.fftshift(x, dim=linop.dim)
        return x

    @staticmethod
    def normal_fn(linop, x):
        """Normal operator application: identity.

        NOTE(review): F^H F = I holds exactly only when norm == "ortho";
        assumes callers use the default normalization — confirm.
        """
        return x

    def split_forward(self, ibatch, obatch):
        """Splitting does nothing."""
        # TODO: raise an error if the FFT is split along an input or output grid dim
        new = copy(self)
        return new

    def normal(self, inner=None):
        """Return the normal operator $F^H F$.

        With orthonormal normalization, $F^H F = I$, so this returns an
        ``Identity`` when no inner operator is provided.

        Parameters
        ----------
        inner : NamedLinop, optional
            Inner operator for Toeplitz embedding.

        Returns
        -------
        NamedLinop
            ``Identity`` if *inner* is ``None``, otherwise the composed normal.
        """
        if inner is None:
            return Identity(self.ishape)
        return super().normal(inner)

__init__

__init__(
    ndim: int,
    batch_shape: Optional[Shape] = None,
    grid_shapes: Optional[tuple[Shape, Shape]] = None,
    norm: Optional[str] = "ortho",
    centered: bool = False,
)
PARAMETER DESCRIPTION
ndim

Number of dimensions to transform (1, 2, or 3).

TYPE: int

batch_shape

Named batch dimensions prepended to the grid dimensions. Defaults to an empty shape.

TYPE: Shape DEFAULT: None

grid_shapes

Pair of shapes (primal, dual) naming the input (image-space) and output (k-space) grid dimensions. Defaults to (Nx[, Ny[, Nz]]) and (Kx[, Ky[, Kz]]).

TYPE: tuple[Shape, Shape] DEFAULT: None

norm

Normalization applied to the FFT. Only "ortho" gives a true unitary forward/adjoint pair.

TYPE: str or None DEFAULT: "ortho"

centered

If True, treat the center of the array (N // 2) as the origin via fftshift / ifftshift. Mimics sigpy convention.

TYPE: bool DEFAULT: False

Source code in src/torchlinops/linops/fft.py
def __init__(
    self,
    ndim: int,
    batch_shape: Optional[Shape] = None,
    grid_shapes: Optional[tuple[Shape, Shape]] = None,
    norm: Optional[str] = "ortho",
    centered: bool = False,
):
    """
    Parameters
    ----------
    ndim : int
        Number of dimensions to transform (1, 2, or 3).
    batch_shape : Shape, optional
        Named batch dimensions prepended to the grid dimensions.
        Defaults to ``("...",)``.
    grid_shapes : tuple[Shape, Shape], optional
        Pair of shapes ``(primal, dual)`` naming the input (image-space)
        and output (k-space) grid dimensions. Defaults to
        ``(Nx[, Ny[, Nz]])`` and ``(Kx[, Ky[, Kz]])``.
    norm : str or None, default ``"ortho"``
        Normalization applied to the FFT. Only ``"ortho"`` gives a true
        unitary forward/adjoint pair.
    centered : bool, default False
        If ``True``, treat the center of the array (``N // 2``) as the
        origin via ``fftshift`` / ``ifftshift``. Mimics sigpy convention.

    Raises
    ------
    ValueError
        If ``grid_shapes`` is not a pair, or if its two shapes have
        different lengths.
    """
    self.ndim = ndim
    # The FFT always acts on the trailing `ndim` axes.
    self.dim = tuple(range(-self.ndim, 0))
    self.grid_shapes = grid_shapes
    if grid_shapes is None:
        grid_shapes = get_nd_shape(self.dim), get_nd_shape(self.dim, kspace=True)
    elif len(grid_shapes) != 2:
        raise ValueError(
            f"grid_shapes should consist of two shape tuples but got {grid_shapes}"
        )
    if len(grid_shapes[0]) != len(grid_shapes[1]):
        # Bug fix: the previous message had unbalanced parentheses and
        # printed the shapes wrapped in literal "len(...)" text.
        raise ValueError(
            "Input and output shapes of FFT must have same length but got "
            f"{grid_shapes[0]} and {grid_shapes[1]}"
        )
    batch_shape = default_to(("...",), batch_shape)
    dim_shape = NS(*grid_shapes)
    shape = NS(batch_shape) + dim_shape
    super().__init__(shape)
    self._shape.batch_shape = batch_shape
    self._shape.input_grid_shape = grid_shapes[0]
    self._shape.output_grid_shape = grid_shapes[1]
    self.norm = norm
    self.centered = centered

normal

normal(inner=None)

Return the normal operator \(F^H F\).

With orthonormal normalization, \(F^H F = I\), so this returns an Identity when no inner operator is provided.

PARAMETER DESCRIPTION
inner

Inner operator for Toeplitz embedding.

TYPE: NamedLinop DEFAULT: None

RETURNS DESCRIPTION
NamedLinop

Identity if inner is None, otherwise the composed normal.

Source code in src/torchlinops/linops/fft.py
def normal(self, inner=None):
    """Return the normal operator $F^H F$.

    The orthonormal FFT satisfies $F^H F = I$, so without an inner operator
    the result collapses to ``Identity``.

    Parameters
    ----------
    inner : NamedLinop, optional
        Inner operator for Toeplitz embedding.

    Returns
    -------
    NamedLinop
        ``Identity`` if *inner* is ``None``, otherwise the composed normal.
    """
    if inner is not None:
        return super().normal(inner)
    return Identity(self.ishape)

split_forward

split_forward(ibatch, obatch)

Splitting does nothing.

Source code in src/torchlinops/linops/fft.py
def split_forward(self, ibatch, obatch):
    """Batch-splitting leaves the operator itself unchanged; return a copy."""
    # TODO: raise an error if the FFT is split along an input or output grid dim
    return copy(self)

Identity

Bases: NamedLinop

Identity operator \(I(x) = x\).

Returns the input unchanged. The adjoint, normal, and any power of the identity are also the identity.

Source code in src/torchlinops/linops/identity.py
class Identity(NamedLinop):
    """Identity operator $I(x) = x$: passes its input through unchanged.

    The adjoint, normal, and any elementwise power of the identity are
    again the identity.
    """

    def __init__(self, ishape=("...",), oshape=None):
        super().__init__(NS(ishape, oshape))

    def adjoint(self):
        return self

    def normal(self, inner=None):
        # I^H I = I, so the normal is either ourselves or just `inner`.
        return self if inner is None else inner

    @staticmethod
    def fn(linop: NamedLinop, x, /):
        return x

    @staticmethod
    def adj_fn(linop: NamedLinop, x, /):
        return x

    @staticmethod
    def normal_fn(linop: NamedLinop, x, /):
        # Shortcut: skip the adjoint-then-forward round trip.
        return x

    def split_forward(self, ibatch, obatch):
        # TODO: Allow non-diagonal splitting
        assert ibatch == obatch, "Identity linop must be split identically"
        return self

    def __pow__(self, _: float | Tensor):
        return copy(self)

Interpolate

Bases: NamedLinop

Interpolate from a grid to a set of off-grid points.

Input/output pattern::

(batch_shape, grid_shape) -> (batch_shape, locs_batch_shape)
ATTRIBUTE DESCRIPTION
locs

The target interpolation locations.

TYPE: Parameter

grid_size

The expected input grid size.

TYPE: tuple[int, ...]

interp_params

Dictionary of arguments for interpolation kernel.

TYPE: dict

Source code in src/torchlinops/linops/interp.py
class Interpolate(NamedLinop):
    """Sample values off a regular grid at arbitrary target locations.

    Input/output pattern::

        (batch_shape, grid_shape) -> (batch_shape, locs_batch_shape)

    Attributes
    ----------
    locs : nn.Parameter
        The target interpolation locations (stored with ``requires_grad=False``).
    grid_size : tuple[int, ...]
        The expected spatial size of the input grid.
    interp_params : dict
        Keyword arguments forwarded to the interpolation kernels.
    """

    def __init__(
        self,
        locs: Float[Tensor, "... D"],
        grid_size: tuple[int, ...],
        batch_shape: Optional[Shape] = None,
        locs_batch_shape: Optional[Shape] = None,
        grid_shape: Optional[Shape] = None,
        # Interp params
        width: float = 4.0,
        kernel: str = "kaiser_bessel",
        norm: int = 1,
        pad_mode: str = "circular",
        kernel_params: Optional[dict] = None,
    ):
        """
        Parameters
        ----------
        locs : Float[Tensor, "... D"]
            Target interpolation locations of size (*locs_batch_size, num_dimensions),
            using 'ij' indexing.
        grid_size : tuple[int, ...]
            Expected input grid size; one entry per spatial dimension.
        batch_shape : Shape, optional
            Input/output batch shape. Defaults to "...".
        locs_batch_shape : Shape, optional
            Shape of the locs batch dims. Defaults to "...".
        grid_shape : Shape, optional
            Shape of the grid dims. Defaults to "...".
        width : float
            Width of the interpolation kernel.
        kernel : str
            Kernel type; current options are "kaiser_bessel" and "spline".
        norm : int
            Norm used to measure distances; current options are 1 and 2.
        pad_mode : str
            Type of padding to apply.
        kernel_params : dict, optional
            Extra kernel arguments. Defaults to ``{"beta": 1.0}``.
        """
        # locs needs at least len(locs_batch_shape) leading batch dims; the
        # trailing dim of locs holds the spatial coordinates.
        if locs_batch_shape is not None and len(locs_batch_shape) > len(locs.shape) - 1:
            raise ValueError(
                f"locs_batch_shape has length longer than batch dim of locs. locs_batch_shape: {locs_batch_shape}, locs: {locs.shape}"
            )
        batch_shape = default_to(("...",), batch_shape)
        locs_batch_shape = default_to(("...",), locs_batch_shape)
        grid_shape = default_to(("...",), grid_shape)
        # Overall linop shape: (batch, grid) -> (batch, locs batch)
        super().__init__(NS(batch_shape) + NS(grid_shape, locs_batch_shape))
        self._shape.batch_shape = batch_shape
        self._shape.locs_batch_shape = locs_batch_shape
        self._shape.grid_shape = grid_shape
        self.locs = nn.Parameter(locs, requires_grad=False)
        self.grid_size = grid_size

        # Assembled once here so fn() and adj_fn() share identical settings.
        self.interp_params = dict(
            width=width,
            kernel=kernel,
            norm=norm,
            pad_mode=pad_mode,
            kernel_params=default_to_dict(dict(beta=1.0), kernel_params),
        )

    @staticmethod
    def fn(interp, x, /):
        return F.interpolate(x, interp.locs, **interp.interp_params)

    @staticmethod
    def adj_fn(interp, x, /):
        return F.interpolate_adjoint(
            x, interp.locs, interp.grid_size, **interp.interp_params
        )

    def split_forward(self, ibatch, obatch):
        # Rebuild with only the selected subset of locs; all other state is
        # carried over unchanged.
        return type(self)(
            self.split_locs(ibatch, obatch, self.locs),
            self.grid_size,
            self._shape.batch_shape,
            self._shape.locs_batch_shape,
            self._shape.grid_shape,
            **self.interp_params,
        )

    def split_locs(self, ibatch, obatch, /, locs):
        """Can only split on locs dimensions"""
        if self._shape.locs_batch_shape == ELLIPSES:
            return locs

        n_locs_dims = len(self._shape.locs_batch_shape)
        # Trailing output batch slices map onto the locs batch dims; keep the
        # final (spatial coordinate) dim whole.
        index = (*obatch[-n_locs_dims:], slice(None))
        return locs[index]

    def size(self, dim):
        locs_dims = self._shape.locs_batch_shape
        grid_dims = self._shape.grid_shape
        if dim in locs_dims:
            return self.locs.shape[locs_dims.index(dim)]
        if dim in grid_dims:
            return self.grid_size[grid_dims.index(dim)]
        return None

__init__

__init__(
    locs: Float[Tensor, "... D"],
    grid_size: tuple[int, ...],
    batch_shape: Optional[Shape] = None,
    locs_batch_shape: Optional[Shape] = None,
    grid_shape: Optional[Shape] = None,
    width: float = 4.0,
    kernel: str = "kaiser_bessel",
    norm: int = 1,
    pad_mode: str = "circular",
    kernel_params: Optional[dict] = None,
)
PARAMETER DESCRIPTION
locs

The target interpolation locations, as a tensor of size (*locs_batch_size, num_dimensions). Uses 'ij' indexing.

TYPE: Float[Tensor, '... D']

grid_size

The expected input grid size. Should have same length as number of dimensions.

TYPE: tuple[int, ...]

batch_shape

The input/output batch shape. Defaults to "...".

TYPE: Shape DEFAULT: None

locs_batch_shape

The shape of the locs. Defaults to "...".

TYPE: Shape DEFAULT: None

grid_shape

The shape of the grid. Defaults to "...".

TYPE: Shape DEFAULT: None

width

The width of the interpolation kernel.

TYPE: float DEFAULT: 4.0

kernel

The type of kernel to use. Current options are "kaiser_bessel" and "spline".

TYPE: str DEFAULT: 'kaiser_bessel'

norm

The type of norm to use to measure distances. Current options are 1 and 2.

TYPE: int DEFAULT: 1

pad_mode

The type of padding to apply.

TYPE: str DEFAULT: 'circular'

Source code in src/torchlinops/linops/interp.py
def __init__(
    self,
    locs: Float[Tensor, "... D"],
    grid_size: tuple[int, ...],
    batch_shape: Optional[Shape] = None,
    locs_batch_shape: Optional[Shape] = None,
    grid_shape: Optional[Shape] = None,
    # Interp params
    width: float = 4.0,
    kernel: str = "kaiser_bessel",
    norm: int = 1,
    pad_mode: str = "circular",
    kernel_params: Optional[dict] = None,
):
    """
    Parameters
    ----------
    locs : Float[Tensor, "... D"]
        The target interpolation locations, as a tensor of size (*locs_batch_size, num_dimensions).
        Uses 'ij' indexing.
    grid_size : tuple[int, ...]
        The expected input grid size. Should have same length as number of dimensions.
    batch_shape : Shape, optional
        The input/output batch shape. Defaults to "...".
    locs_batch_shape : Shape, optional
        The shape of the locs. Defaults to "...".
    grid_shape : Shape, optional
        The shape of the grid. Defaults to "...".
    width : float
        The width of the interpolation kernel.
    kernel : str
        The type of kernel to use. Current options are "kaiser_bessel" and "spline".
    norm : int
        The type of norm to use to measure distances. Current options are 1 and 2.
    pad_mode : str
        The type of padding to apply.
    kernel_params : dict, optional
        Extra kernel arguments. Defaults to ``{"beta": 1.0}``.
    """
    # locs must carry at least len(locs_batch_shape) leading batch dims; its
    # final dim holds the spatial coordinates.
    if locs_batch_shape is not None:
        if len(locs_batch_shape) > len(locs.shape) - 1:
            raise ValueError(
                f"locs_batch_shape has length longer than batch dim of locs. locs_batch_shape: {locs_batch_shape}, locs: {locs.shape}"
            )
    batch_shape = default_to(("...",), batch_shape)
    locs_batch_shape = default_to(("...",), locs_batch_shape)
    grid_shape = default_to(("...",), grid_shape)
    # Overall linop shape: (batch, grid) -> (batch, locs batch)
    shape = NS(batch_shape) + NS(grid_shape, locs_batch_shape)
    super().__init__(shape)
    self._shape.batch_shape = batch_shape
    self._shape.locs_batch_shape = locs_batch_shape
    self._shape.grid_shape = grid_shape
    # locs is a non-trainable parameter so it moves with the module's device/dtype.
    self.locs = nn.Parameter(locs, requires_grad=False)
    self.grid_size = grid_size

    # Do this here instead of repeating it in both fn() and adjoint_fn()
    kernel_params = default_to_dict(dict(beta=1.0), kernel_params)
    self.interp_params = {
        "width": width,
        "kernel": kernel,
        "norm": norm,
        "pad_mode": pad_mode,
        "kernel_params": kernel_params,
    }

split_locs

split_locs(ibatch, obatch, /, locs)

Can only split on locs dimensions

Source code in src/torchlinops/linops/interp.py
def split_locs(self, ibatch, obatch, /, locs):
    """Slice ``locs`` along its batch dims; only locs dimensions may be split."""
    if self._shape.locs_batch_shape == ELLIPSES:
        # Unnamed locs batch: nothing to slice.
        return locs

    n_locs_dims = len(self._shape.locs_batch_shape)
    # Trailing output batch slices address the locs batch dims; the final
    # (spatial coordinate) dim is always kept whole.
    index = (*obatch[-n_locs_dims:], slice(None))
    return locs[index]

ND dataclass

Fundamental named dimension type used throughout the library.

Each dimension has a name and an optional integer index i for creating indexed variants (e.g. A1, A2). Two NamedDimension instances are considered equal when their string representations match; the index is folded into the representation rather than compared separately.

PARAMETER DESCRIPTION
name

The base name of the dimension (e.g. 'A', 'Nx').

TYPE: str

i

Integer index for indexed variants. Defaults to 0, which is omitted from the string representation.

TYPE: int DEFAULT: 0

Examples:

>>> NamedDimension("A")
A
>>> NamedDimension("A", 1)
A1
>>> NamedDimension("A") == "A"
True
Source code in src/torchlinops/nameddim/_nameddim.py
@dataclass(slots=True, frozen=True)
class NamedDimension:
    """Fundamental named dimension type used throughout the library.

    Each dimension has a ``name`` and an optional integer index ``i`` for
    creating indexed variants (e.g. ``A1``, ``A2``).  Two
    ``NamedDimension`` instances are considered equal when their string
    representations match; the index is folded into the representation
    rather than compared separately.

    Parameters
    ----------
    name : str
        The base name of the dimension (e.g. ``'A'``, ``'Nx'``).
    i : int, optional
        Integer index for indexed variants.  Defaults to ``0``, which is
        omitted from the string representation.

    Examples
    --------
    >>> NamedDimension("A")
    A
    >>> NamedDimension("A", 1)
    A1
    >>> NamedDimension("A") == "A"
    True
    """

    name: str
    i: int = 0

    @classmethod
    def infer(cls, dim: Any):
        """Create a NamedDimension by inferring the name and optional index.

        If *dim* is already a ``NamedDimension`` it is returned as-is.
        A two-character string whose second character is a digit is
        interpreted as ``name=dim[0], i=int(dim[1])``.  Sequences are
        inferred element-wise.

        Parameters
        ----------
        dim : Any
            A ``NamedDimension``, a string, or a list/tuple of those.

        Returns
        -------
        NamedDimension or sequence thereof
            The inferred dimension(s).
        """
        if isinstance(dim, cls):
            return dim
        if isinstance(dim, str) and len(dim) == 2:
            if dim[1].isdigit():
                return cls(dim[0], int(dim[1]))
        elif dim == ELLIPSES:
            return cls(ELLIPSES)
        elif isinstance(dim, Tuple) or isinstance(dim, List):
            return type(dim)(cls.infer(d) for d in dim)
        return cls(dim)

    def next_unused(self, avoid):
        """Get the next dim by index that does not occur in tup"""
        curr = copy(self)
        if self.name == ELLIPSES or self.name == ANY:
            return curr
        while curr in avoid:
            curr = curr + 1
        return curr

    def __repr__(self):
        return self.name + ("" if self.i == 0 else str(self.i))

    def __add__(self, k):
        if self.name == ELLIPSES:
            return self
        try:
            return type(self)(self.name, self.i + k)
        except TypeError as e:
            raise TypeError(f"Unsupported NamedDimension add: {self} + {k}", e)

    def __eq__(self, other):
        """Tests for simple string equality"""
        return repr(self) == other

    def __hash__(self):
        """Allow dictionary lookups to work with strings too."""
        return hash(repr(self))

__eq__

__eq__(other)

Tests for simple string equality

Source code in src/torchlinops/nameddim/_nameddim.py
def __eq__(self, other):
    """String-equality comparison: a dim equals anything matching its repr."""
    self_as_str = repr(self)
    return self_as_str == other

__hash__

__hash__()

Allow dictionary lookups to work with strings too.

Source code in src/torchlinops/nameddim/_nameddim.py
def __hash__(self):
    """Hash by repr so string keys and dims land in the same dict slot."""
    string_form = repr(self)
    return hash(string_form)

infer classmethod

infer(dim: Any)

Create a NamedDimension by inferring the name and optional index.

If dim is already a NamedDimension it is returned as-is. A two-character string whose second character is a digit is interpreted as name=dim[0], i=int(dim[1]). Sequences are inferred element-wise.

PARAMETER DESCRIPTION
dim

A NamedDimension, a string, or a list/tuple of those.

TYPE: Any

RETURNS DESCRIPTION
NamedDimension or sequence thereof

The inferred dimension(s).

Source code in src/torchlinops/nameddim/_nameddim.py
@classmethod
def infer(cls, dim: Any):
    """Create a NamedDimension by inferring the name and optional index.

    If *dim* is already a ``NamedDimension`` it is returned as-is.
    A two-character string whose second character is a digit is
    interpreted as ``name=dim[0], i=int(dim[1])``.  Sequences are
    inferred element-wise.

    Parameters
    ----------
    dim : Any
        A ``NamedDimension``, a string, or a list/tuple of those.

    Returns
    -------
    NamedDimension or sequence thereof
        The inferred dimension(s).
    """
    if isinstance(dim, cls):
        return dim
    if isinstance(dim, str) and len(dim) == 2:
        # NOTE: when the second char is NOT a digit, this outer branch is
        # still taken, so the elif chain below is skipped and control falls
        # through to the final `return cls(dim)`.
        if dim[1].isdigit():
            return cls(dim[0], int(dim[1]))
    elif dim == ELLIPSES:
        return cls(ELLIPSES)
    elif isinstance(dim, Tuple) or isinstance(dim, List):
        # Sequences: infer each element, preserving the container type.
        return type(dim)(cls.infer(d) for d in dim)
    return cls(dim)

next_unused

next_unused(avoid)

Get the next dim by index that does not occur in `avoid`

Source code in src/torchlinops/nameddim/_nameddim.py
def next_unused(self, avoid):
    """Return the first indexed variant of this dim not present in `avoid`."""
    candidate = copy(self)
    # Wildcard dims are never re-indexed.
    if self.name == ELLIPSES or self.name == ANY:
        return candidate
    while candidate in avoid:
        candidate = candidate + 1
    return candidate

NS

Bases: NamedDimCollection

A linop shape with input and output dimensions. Inherit from this to define custom behavior — e.g. splitting ishape and oshape into subparts that are linked.

Source code in src/torchlinops/nameddim/_namedshape.py
class NamedShape(NamedDimCollection):
    """A linop shape with named input and output dimensions.

    Subclass to customize behavior — e.g. to split ishape and oshape into
    linked sub-parts.
    """

    def __init__(
        self,
        ishape: Optional["Shape | NamedShape"],
        oshape: Optional[Shape] = None,
        **other_shapes,
    ):
        """Construct a NamedShape from input and output dimension names.

        Parameters
        ----------
        ishape : Shape or NamedShape or None
            Input dimension names.  A ``NamedShape`` instance is copied
            directly and *oshape* / *other_shapes* are ignored.  ``None``
            defaults to the ellipsis shape ``('...',)``.
        oshape : Shape or None, optional
            Output dimension names.  ``None`` with *ishape* given means a
            diagonal operator (``oshape = ishape``); ``None`` for both means
            both default to ``('...',)``.
        **other_shapes
            Additional named shape sequences stored alongside ishape and
            oshape (e.g. auxiliary dimensions for specialised operators).
        """
        # Copy constructor: reuse another NamedShape's stored shapes verbatim.
        if isinstance(ishape, type(self)):
            super().__init__(**ishape.shapes)
            return

        if oshape is None:
            # Empty shape when ishape is also None; diagonal otherwise.
            oshape = ("...",) if ishape is None else ishape
        if ishape is None:
            ishape = ("...",)
        super().__init__(ishape=ishape, oshape=oshape, **other_shapes)

    @property
    def other_shapes(self):
        """Shapes that are not ishape or oshape."""
        extras = self.shapes.copy()
        extras.pop("ishape")  # Special attributes
        extras.pop("oshape")
        return extras

    def adjoint(self):
        """Return a new NamedShape with ishape and oshape swapped.

        Override in subclasses that need custom adjoint behaviour (e.g.
        swapping auxiliary shapes as well).

        Returns
        -------
        NamedShape
            A new instance with ``ishape`` and ``oshape`` exchanged.
        """
        return type(self)(self.oshape, self.ishape, **self.other_shapes)

    def normal(self):
        """Return the NamedShape for the normal operator (A^H A).

        The result keeps ``ishape`` and derives ``oshape`` from ``ishape``
        with indices incremented to avoid collisions, representing the
        domain-to-domain mapping of the normal equation.

        Returns
        -------
        NamedShape
            A new instance representing the normal operator shape.
        """
        # Dims appearing in both ishape and oshape are "diagonal" and kept;
        # the rest are re-indexed to avoid name collisions.
        new_oshape = [
            d if d in self.oshape else d.next_unused(self.ishape)
            for d in self.ishape
        ]
        return type(self)(self.ishape, new_oshape, **self.other_shapes)

    @property
    def H(self) -> "NamedShape":
        """The adjoint NamedShape (ishape and oshape swapped)."""
        return self.adjoint()

    @property
    def N(self) -> "NamedShape":
        """The normal NamedShape for the operator A^H A."""
        return self.normal()

    def __repr__(self):
        return f"{self.ishape} -> {self.oshape}"

    def __add__(self, right) -> "NamedShape":
        try:
            _ishape = self.ishape + right.ishape
        except TypeError as e:
            raise TypeError(
                f"Problem combining shapes {self.ishape} + {right.ishape}"
            ) from e
        try:
            _oshape = self.oshape + right.oshape
        except TypeError as e:
            raise TypeError(
                f"Problem combining shapes {self.oshape} + {right.oshape}"
            ) from e
        combined = type(self)(ishape=_ishape, oshape=_oshape)
        combined.update(self.other_shapes)
        combined.update(right.other_shapes)
        return combined

    def __radd__(self, left):
        # `None + shape` acts as the additive identity.
        return self if left is None else left.__add__(self)

    def __eq__(self, other):
        return isequal(self.ishape, other.ishape) and isequal(self.oshape, other.oshape)

H property

H: NamedShape

The adjoint NamedShape (ishape and oshape swapped).

N property

N: NamedShape

The normal NamedShape for the operator A^H A.

other_shapes property

other_shapes

Shapes that are not ishape or oshape.

__init__

__init__(
    ishape: Optional[Shape | NamedShape],
    oshape: Optional[Shape] = None,
    **other_shapes,
)

Construct a NamedShape from input and output dimension names.

PARAMETER DESCRIPTION
ishape

Input dimension names. If a NamedShape instance is passed, it is copied directly and oshape / other_shapes are ignored. If None, defaults to the ellipsis shape ('...',).

TYPE: Shape or NamedShape or None

oshape

Output dimension names. If None while ishape is provided, the operator is treated as diagonal (oshape = ishape). If both ishape and oshape are None, both default to ('...',).

TYPE: Shape or None DEFAULT: None

**other_shapes

Additional named shape sequences stored alongside ishape and oshape (e.g. auxiliary dimensions for specialised operators).

DEFAULT: {}

Source code in src/torchlinops/nameddim/_namedshape.py
def __init__(
    self,
    ishape: Optional["Shape | NamedShape"],
    oshape: Optional[Shape] = None,
    **other_shapes,
):
    """Construct a NamedShape from input and output dimension names.

    Parameters
    ----------
    ishape : Shape or NamedShape or None
        Input dimension names.  If a ``NamedShape`` instance is passed,
        it is copied directly and *oshape* / *other_shapes* are ignored.
        If ``None``, defaults to the ellipsis shape ``('...',)``.
    oshape : Shape or None, optional
        Output dimension names.  If ``None`` while *ishape* is provided,
        the operator is treated as diagonal (``oshape = ishape``).  If
        both *ishape* and *oshape* are ``None``, both default to
        ``('...',)``.
    **other_shapes
        Additional named shape sequences stored alongside ishape and
        oshape (e.g. auxiliary dimensions for specialised operators).
    """
    # Pass-through: copy another NamedShape's stored shapes verbatim;
    # the oshape and other_shapes arguments are ignored in this case.
    if isinstance(ishape, type(self)):
        super().__init__(**ishape.shapes)
        return

    if oshape is None:
        if ishape is None:
            # Empty shape
            oshape = ("...",)
        else:
            # Diagonal
            oshape = ishape
    if ishape is None:
        ishape = ("...",)
    super().__init__(ishape=ishape, oshape=oshape, **other_shapes)

adjoint

adjoint()

Return a new NamedShape with ishape and oshape swapped.

Override this method in subclasses that need custom adjoint behaviour (e.g. swapping auxiliary shapes as well).

RETURNS DESCRIPTION
NamedShape

A new instance with ishape and oshape exchanged.

Source code in src/torchlinops/nameddim/_namedshape.py
def adjoint(self):
    """Return a new NamedShape with ishape and oshape swapped.

    Subclasses needing custom adjoint behaviour (e.g. swapping auxiliary
    shapes too) should override this method.

    Returns
    -------
    NamedShape
        A fresh instance with ``ishape`` and ``oshape`` exchanged.
    """
    swapped = type(self)(self.oshape, self.ishape, **self.other_shapes)
    return swapped

normal

normal()

Return the NamedShape for the normal operator (A^H A).

The resulting shape has ishape equal to the original ishape and oshape derived from ishape with indices incremented to avoid collisions, representing the domain-to-domain mapping of the normal equation.

RETURNS DESCRIPTION
NamedShape

A new instance representing the normal operator shape.

Source code in src/torchlinops/nameddim/_namedshape.py
def normal(self):
    """Return the NamedShape for the normal operator (A^H A).

    The result keeps ``ishape`` and derives ``oshape`` from ``ishape``
    with indices incremented to avoid collisions, representing the
    domain-to-domain mapping of the normal equation.

    Returns
    -------
    NamedShape
        A new instance representing the normal operator shape.
    """
    # Dims present in both ishape and oshape are considered "diagonal"
    # and kept as-is; all others are re-indexed to avoid name collisions.
    new_oshape = [
        d if d in self.oshape else d.next_unused(self.ishape)
        for d in self.ishape
    ]
    return type(self)(self.ishape, new_oshape, **self.other_shapes)

NUFFT

Bases: Chain

Non-uniform Fast Fourier Transform (type II) as a named linear operator.

Implemented as a Chain of zero-padding, FFT, and interpolation. Supports forward (image-to-kspace) and adjoint (kspace-to-image) operations.

ATTRIBUTE DESCRIPTION
ndim

Number of spatial dimensions.

TYPE: int

oversamp

Oversampling factor for the padded grid.

TYPE: float

width

Interpolation kernel width.

TYPE: int

Source code in src/torchlinops/linops/nufft.py
class NUFFT(Chain):
    """Non-uniform Fast Fourier Transform (type II) as a named linear operator.

    Implemented as a ``Chain`` of zero-padding, FFT, and interpolation. Supports
    forward (image-to-kspace) and adjoint (kspace-to-image) operations.

    Attributes
    ----------
    ndim : int
        Number of spatial dimensions.
    oversamp : float
        Oversampling factor for the padded grid.
    width : float
        Interpolation kernel width.
    """

    def __init__(
        self,
        locs: Float[Tensor, "... D"],
        grid_size: tuple[int, ...],
        output_shape: Shape,
        input_shape: Optional[Shape] = None,
        input_kshape: Optional[Shape] = None,
        batch_shape: Optional[Shape] = None,
        oversamp: float = 1.25,
        width: float = 4.0,
        mode: Literal["interpolate", "sampling"] = "interpolate",
        do_prep_locs: bool = True,
        apodize_weights: Optional[Float[Tensor, "..."]] = None,
        **options,
    ):
        """
        Parameters
        ----------
        locs : Tensor, float
            Shape [... D] Tensor where last dimension is the spatial dimension.
            locs[..., i] Should be in the range [-N//2, N//2] where N is the grid_size[i], i.e.
            the grid size associated with that dimension
        grid_size : tuple of ints
            The expected spatial dimension of the input tensor.
        output_shape : Shape
        input_shape : Shape, optional
        input_kshape : Shape, optional
        batch_shape : Shape, optional
            NUFFT is implemented as a chain of padding, FFT, and interpolation
            Named Dimensions are set as follows:

            Pad: (*batch_shape, *input_shape) -> (*batch_shape, *next_unused(input_shape))
            FFT: (*batch_shape, *next_unused(input_shape)) -> (*batch_shape, *input_kshape)
            Interp: (*batch_shape, *input_kshape) -> (*batch_shape, *output_shape)

        oversamp : float
            Oversampling factor for fourier domain grid
        width : float
            Width of kernel to use for interpolation
        mode : str, "interpolate" or "sampling"
        do_prep_locs : bool, default True
            Whether to scale, shift, and clamp the locs to be amenable to interpolation
            By default (=True), assumes the locs lie in [-N/2, N/2]
                Scales, shifts and clamps them to [0, oversamp*N - 1]
            If False, does not do this, which can have some benefits for memory reasons
        apodize_weights : Optional[Tensor]
            Provide apodization weights
            Only relevant for "interpolate" mode
            Can have memory benefits
        **options : dict
            Additional options
            toeplitz : bool
                If True, normal() performs toeplitz embedding calculation
            toeplitz_dtype : torch.dtype
                Data type for the toeplitz embedding. Probably should be torch.complex64
            toeplitz_oversamp : float
                Oversampling factor for the toeplitz embedding grid. Default is 2.0.

        """
        device = locs.device
        self.mode = mode
        self.options = options
        # Infer shapes
        self.input_shape = ND.infer(default_to(get_nd_shape(grid_size), input_shape))
        self.input_kshape = ND.infer(
            default_to(get_nd_shape(grid_size, kspace=True), input_kshape)
        )
        self.output_shape = ND.infer(output_shape)
        self.batch_shape = ND.infer(default_to(("...",), batch_shape))
        # NOTE(review): uses the raw `batch_shape` argument (possibly None) rather
        # than the inferred `self.batch_shape` — confirm this is intentional.
        batched_input_shape = NS(batch_shape) + NS(self.input_shape)

        # Initialize variables
        ndim = len(grid_size)
        padded_size = tuple(int(i * oversamp) for i in grid_size)

        # Create Padding
        pad = Pad(
            padded_size,
            grid_size,
            in_shape=self.input_shape,
            batch_shape=self.batch_shape,
        )

        # Create FFT
        fft = FFT(
            ndim=locs.shape[-1],
            centered=True,
            norm="ortho",
            batch_shape=self.batch_shape,
            grid_shapes=(pad.out_im_shape, self.input_kshape),
        )

        # Create Interpolator
        grid_shape = fft._shape.output_grid_shape
        if do_prep_locs:
            locs_prepared = self.prep_locs(
                locs,
                grid_size,
                padded_size,
                nufft_mode=mode,
            )
        else:
            locs_prepared = locs
        if self.mode == "interpolate":
            beta = self.beta(width, oversamp)
            # Create Apodization
            if apodize_weights is None:
                weight = self.apodize_weights(
                    grid_size, padded_size, oversamp, width, beta
                ).to(device)  # Helps with batching later
            else:
                weight = apodize_weights
            if weight.isnan().any() or weight.isinf().any():
                raise ValueError(
                    f"Nan/Inf values detected in apodization weight (width={width}, oversamp={oversamp})."
                )
            apodize = Diagonal(weight, batched_input_shape.ishape)
            apodize.name = "Apodize"

            # Create Interpolator
            interp = Interpolate(
                locs_prepared,
                padded_size,
                batch_shape=self.batch_shape,
                locs_batch_shape=self.output_shape,
                grid_shape=grid_shape,
                width=width,
                kernel="kaiser_bessel",
                kernel_params=dict(beta=beta),
            )
            # Create scaling
            scale_factor = width**ndim * (prod(grid_size) / prod(padded_size)) ** 0.5
            scale = Scalar(weight=1.0 / scale_factor, ioshape=interp.oshape)
            scale.to(device)  # Helps with batching later
            linops = [apodize, pad, fft, interp, scale]
        elif self.mode == "sampling":
            if locs_prepared.is_complex() or locs_prepared.is_floating_point():
                raise ValueError(
                    f"Sampling linop requires integer-type locs but got {locs_prepared.dtype}"
                )
            # Sampling consumes pre-wrapped integer indices directly (no kernel)
            interp = Sampling.from_stacked_idx(
                locs_prepared,
                dim=-1,
                # Arguments for Sampling
                input_size=padded_size,
                output_shape=self.output_shape,
                input_shape=grid_shape,
                batch_shape=self.batch_shape,
            )
            # No apodization or scaling needed
            linops = [pad, fft, interp]
        else:
            raise ValueError(f"Unrecognized NUFFT mode: {mode}")

        super().__init__(*linops, name="NUFFT")
        # Useful parameters to save
        self.locs = locs
        self.grid_size = grid_size
        self.oversamp = oversamp
        self.width = width

        # Handles to get modules directly
        self.pad = pad
        self.fft = fft
        self.interp = interp

    def adjoint(self):
        """Return the adjoint NUFFT.

        Hybrid of chain adjoint and namedlinop adjoint: the shape is
        conjugated and the sub-linops are individually adjointed and reversed.
        """
        adj = copy(self)
        adj._shape = adj._shape.H

        linops = list(linop.adjoint() for linop in reversed(self.linops))
        adj.linops = nn.ModuleList(linops)
        return adj

    def normal(self, inner=None):
        """Return the normal operator, optionally incorporating ``inner``.

        If the ``toeplitz`` option is set, builds the normal operator via
        Toeplitz embedding (pad + FFT applied to a PSF kernel) instead of
        chaining the adjoint with the forward.
        """
        if self.options.get("toeplitz", False):
            dtype = self.options.get("toeplitz_dtype")
            oversamp = self.options.get("toeplitz_oversamp", 2.0)
            toep_kernel = toeplitz_psf(self, inner, dtype=dtype, oversamp=oversamp)
            pad = Pad(
                scale_int(self.grid_size, oversamp),
                self.grid_size,
                in_shape=self.input_shape,
                batch_shape=self.batch_shape,
            )
            fft = self.fft
            return pad.normal(fft.normal(toep_kernel))
        return super().normal(inner)

    @staticmethod
    def prep_locs(
        locs: Shaped[Tensor, "... D"],
        grid_size: tuple,
        padded_size: tuple,
        pad_mode: Literal["zero", "circular"] = "circular",
        nufft_mode: Literal["interpolate", "sampling"] = "interpolate",
    ):
        """
        Parameters
        ----------
        locs : Shaped[Tensor, "... D"]
            Input tensor representing locations in the grid. The last dimension corresponds to spatial dimensions.
            Range is [-N//2, N//2]
        grid_size : tuple
            The original size of the grid before padding.
        padded_size : tuple
            The size of the grid after padding.
        pad_mode : Literal["zero", "circular"], optional
            The type of padding applied. Can be "zero" for zero-padding or "circular" for circular padding.
            Default is "circular".
        nufft_mode : Literal["interpolate", "sampling"], optional
            The mode of the NUFFT operation. Can be "interpolate" for interpolation or "sampling" for sampling.
            Default is "interpolate".

        Returns
        -------
        Shaped[Tensor, "... D"]
            Adjusted locations tensor based on the specified padding and NUFFT modes.
            Range is [0, N_pad].
            dtype is floating-point if nufft_mode is "interpolate", and integer
            if nufft_mode is "sampling"

        Raises
        ------
        ValueError
            If an unrecognized `pad_mode` is provided.

        Examples
        --------
        >>> _ = torch.manual_seed(0);
        >>> locs = torch.rand(1000, 3) * 64 - 32 # [-32, 32]
        >>> locs.min()
        tensor(-31.9949)
        >>> locs.max()
        tensor(31.9896)
        >>> grid_size = (64, 64, 64)
        >>> padded_size = (80, 80, 80) # oversamp = 1.25
        >>> locs_scaled_shifted = NUFFT.prep_locs(locs, grid_size, padded_size)
        >>> locs_scaled_shifted.min()
        tensor(0.0064)
        >>> locs_scaled_shifted.max()
        tensor(79.9871)

        >>> _ = torch.manual_seed(0);
        >>> locs = torch.rand(1000, 3) * 64 - 32 # [-32, 32]
        >>> locs = torch.round(locs * 1.25) / 1.25
        >>> grid_size = (64, 64, 64)
        >>> padded_size = (80, 80, 80) # oversamp = 1.25
        >>> locs_scaled_shifted = NUFFT.prep_locs(locs, grid_size, padded_size, nufft_mode='sampling')
        >>> locs_scaled_shifted.min()
        tensor(0)
        >>> locs_scaled_shifted.max()
        tensor(79)



        Notes
        -----
        - Assumes that the input `locs` are centered.
        - Adjusts the locations by scaling and shifting them according to the grid and padded sizes.
        - Applies clamping or remainder operations based on the padding mode and NUFFT mode.
        """
        # Clone to prevent in-place scaling from modifying the original
        out = locs.clone()
        for i in range(-len(grid_size), 0):
            out[..., i] *= padded_size[i] / grid_size[i]
            out[..., i] += padded_size[i] // 2
            if pad_mode == "zero":
                out[..., i] = torch.clamp(out[..., i], 0, padded_size[i] - 1)
            elif pad_mode == "circular":
                if nufft_mode == "interpolate":
                    out[..., i] = torch.remainder(
                        out[..., i], torch.tensor(padded_size[i])
                    )
                elif nufft_mode == "sampling":
                    # Wrap rounded index to other side of kspace
                    out[..., i] = torch.round(out[..., i])
                    out[..., i] = torch.remainder(out[..., i], padded_size[i])
            else:
                raise ValueError(f"Unrecognized padding mode during prep: {pad_mode}")
        if nufft_mode == "sampling":
            out = out.to(torch.int64)
        return out

    @staticmethod
    def beta(width, oversamp):
        """Kaiser-Bessel shape parameter for the given width and oversampling.

        https://sigpy.readthedocs.io/en/latest/_modules/sigpy/fourier.html#nufft

        References
        ----------
        Beatty PJ, Nishimura DG, Pauly JM. Rapid gridding reconstruction with a minimal oversampling ratio.
        IEEE Trans Med Imaging. 2005 Jun;24(6):799-808. doi: 10.1109/TMI.2005.848376. PMID: 15959939.
        """
        return torch.pi * (((width / oversamp) * (oversamp - 0.5)) ** 2 - 0.8) ** 0.5

    @staticmethod
    def apodize_weights(grid_size, padded_size, oversamp, width: float, beta: float):
        """Compute apodization weights on the un-padded grid (sigpy-compatible form)."""
        grid_size = torch.tensor(grid_size)
        padded_size = torch.tensor(padded_size)
        grid = torch.meshgrid(*(torch.arange(s) for s in grid_size), indexing="ij")
        grid = torch.stack(grid, dim=-1)

        # Sigpy compatibility
        apod = (
            beta**2 - (torch.pi * width * (grid - grid_size // 2) / padded_size) ** 2
        ) ** 0.5
        apod /= torch.sinh(apod)

        # Beatty paper
        # apod = (torch.pi * width * (grid - grid_size // 2) / padded_size) ** 2 - beta**2
        # print(apod)
        apod = torch.prod(apod, dim=-1)
        return apod

    def split_forward(self, ibatch, obatch):
        """Split each sub-linop by matching batch slices to its dims by name.

        Dimensions not present in ``ibatch``/``obatch`` default to
        ``slice(None)``.
        """
        ibatch_lookup = {d: slc for d, slc in zip(self.ishape, ibatch)}
        obatch_lookup = {d: slc for d, slc in zip(self.oshape, obatch)}
        split_linops = []
        for linop in self.linops:
            sub_ibatch = [ibatch_lookup.get(dim, slice(None)) for dim in linop.ishape]
            sub_obatch = [obatch_lookup.get(dim, slice(None)) for dim in linop.oshape]
            split_linops.append(linop.split_forward(sub_ibatch, sub_obatch))
        out = copy(self)
        out.linops = nn.ModuleList(split_linops)
        return out

    def flatten(self):
        """Don't combine constituent linops into a chain with other linops
        Informs how split_forward should behave
        """
        return [self]

    @property
    def device(self):
        """Tracks device of interpolating/sampling linop
        Useful for toeplitz
        """
        if self.mode == "interpolate":
            return self.interp.locs.device
        elif self.mode == "sampling":
            return self.interp.idx[0].device
        raise ValueError(f"Unrecognized NUFFT mode: {self.mode}")

device property

device

Tracks device of interpolating/sampling linop Useful for toeplitz

__init__

__init__(
    locs: Float[Tensor, "... D"],
    grid_size: tuple[int, ...],
    output_shape: Shape,
    input_shape: Optional[Shape] = None,
    input_kshape: Optional[Shape] = None,
    batch_shape: Optional[Shape] = None,
    oversamp: float = 1.25,
    width: float = 4.0,
    mode: Literal[
        "interpolate", "sampling"
    ] = "interpolate",
    do_prep_locs: bool = True,
    apodize_weights: Optional[Float[Tensor, ...]] = None,
    **options,
)
PARAMETER DESCRIPTION
locs

Shape [... D] Tensor where last dimension is the spatial dimension. locs[..., i] Should be in the range [-N//2, N//2] where N is the grid_size[i], i.e. the grid size associated with that dimension

TYPE: (Tensor, float)

grid_size

The expected spatial dimension of the input tensor.

TYPE: tuple of ints

output_shape

TYPE: Shape

input_shape

TYPE: Shape DEFAULT: None

input_kshape

TYPE: Shape DEFAULT: None

batch_shape

NUFFT is implemented as a chain of padding, FFT, and interpolation Named Dimensions are set as follows:

Pad: (batch_shape, input_shape) -> (batch_shape, next_unused(input_shape)) FFT: (batch_shape, next_unused(input_shape)) -> (batch_shape, input_kshape) Interp: (batch_shape, input_kshape) -> (batch_shape, output_shape)

TYPE: Shape DEFAULT: None

oversamp

Oversampling factor for fourier domain grid

TYPE: float DEFAULT: 1.25

width

Width of kernel to use for interpolation

TYPE: float DEFAULT: 4.0

mode

TYPE: (str, interpolate or sampling) DEFAULT: 'interpolate'

do_prep_locs

Whether to scale, shift, and clamp the locs to be amenable to interpolation By default (=True), assumes the locs lie in [-N/2, N/2] Scales, shifts and clamps them to [0, oversamp*N - 1] If False, does not do this, which can have some benefits for memory reasons

TYPE: bool DEFAULT: True

apodize_weights

Provide apodization weights Only relevant for "interpolate" mode Can have memory benefits

TYPE: Optional[Tensor] DEFAULT: None

**options

Additional options toeplitz : bool If True, normal() performs toeplitz embedding calculation toeplitz_dtype : torch.dtype Data type for the toeplitz embedding. Probably should be torch.complex64

TYPE: dict DEFAULT: {}

Source code in src/torchlinops/linops/nufft.py
def __init__(
    self,
    locs: Float[Tensor, "... D"],
    grid_size: tuple[int, ...],
    output_shape: Shape,
    input_shape: Optional[Shape] = None,
    input_kshape: Optional[Shape] = None,
    batch_shape: Optional[Shape] = None,
    oversamp: float = 1.25,
    width: float = 4.0,
    mode: Literal["interpolate", "sampling"] = "interpolate",
    do_prep_locs: bool = True,
    apodize_weights: Optional[Float[Tensor, "..."]] = None,
    **options,
):
    """
    Parameters
    ----------
    locs : Tensor, float
        Shape [... D] Tensor where last dimension is the spatial dimension.
        locs[..., i] Should be in the range [-N//2, N//2] where N is the grid_size[i], i.e.
        the grid size associated with that dimension
    grid_size : tuple of ints
        The expected spatial dimension of the input tensor.
    output_shape : Shape
    input_shape : Shape, optional
    input_kshape : Shape, optional
    batch_shape : Shape, optional
        NUFFT is implemented as a chain of padding, FFT, and interpolation
        Named Dimensions are set as follows:

        Pad: (*batch_shape, *input_shape) -> (*batch_shape, *next_unused(input_shape))
        FFT: (*batch_shape, *next_unused(input_shape)) -> (*batch_shape, *input_kshape)
        Interp: (*batch_shape, *input_kshape) -> (*batch_shape, *output_shape)

    oversamp : float
        Oversampling factor for fourier domain grid
    width : float
        Width of kernel to use for interpolation
    mode : str, "interpolate" or "sampling"
    do_prep_locs : bool, default True
        Whether to scale, shift, and clamp the locs to be amenable to interpolation
        By default (=True), assumes the locs lie in [-N/2, N/2]
            Scales, shifts and clamps them to [0, oversamp*N - 1]
        If False, does not do this, which can have some benefits for memory reasons
    apodize_weights : Optional[Tensor]
        Provide apodization weights
        Only relevant for "interpolate" mode
        Can have memory benefits
    **options : dict
        Additional options
        toeplitz : bool
            If True, normal() performs toeplitz embedding calculation
        toeplitz_dtype : torch.dtype
            Data type for the toeplitz embedding. Probably should be torch.complex64
        toeplitz_oversamp : float
            Oversampling factor for the toeplitz embedding grid. Default is 2.0.

    """
    device = locs.device
    self.mode = mode
    self.options = options
    # Infer shapes
    self.input_shape = ND.infer(default_to(get_nd_shape(grid_size), input_shape))
    self.input_kshape = ND.infer(
        default_to(get_nd_shape(grid_size, kspace=True), input_kshape)
    )
    self.output_shape = ND.infer(output_shape)
    self.batch_shape = ND.infer(default_to(("...",), batch_shape))
    # NOTE(review): uses the raw `batch_shape` argument (possibly None) rather
    # than the inferred `self.batch_shape` — confirm this is intentional.
    batched_input_shape = NS(batch_shape) + NS(self.input_shape)

    # Initialize variables
    ndim = len(grid_size)
    padded_size = tuple(int(i * oversamp) for i in grid_size)

    # Create Padding
    pad = Pad(
        padded_size,
        grid_size,
        in_shape=self.input_shape,
        batch_shape=self.batch_shape,
    )

    # Create FFT
    fft = FFT(
        ndim=locs.shape[-1],
        centered=True,
        norm="ortho",
        batch_shape=self.batch_shape,
        grid_shapes=(pad.out_im_shape, self.input_kshape),
    )

    # Create Interpolator
    grid_shape = fft._shape.output_grid_shape
    if do_prep_locs:
        locs_prepared = self.prep_locs(
            locs,
            grid_size,
            padded_size,
            nufft_mode=mode,
        )
    else:
        locs_prepared = locs
    if self.mode == "interpolate":
        beta = self.beta(width, oversamp)
        # Create Apodization
        if apodize_weights is None:
            weight = self.apodize_weights(
                grid_size, padded_size, oversamp, width, beta
            ).to(device)  # Helps with batching later
        else:
            weight = apodize_weights
        if weight.isnan().any() or weight.isinf().any():
            raise ValueError(
                f"Nan/Inf values detected in apodization weight (width={width}, oversamp={oversamp})."
            )
        apodize = Diagonal(weight, batched_input_shape.ishape)
        apodize.name = "Apodize"

        # Create Interpolator
        interp = Interpolate(
            locs_prepared,
            padded_size,
            batch_shape=self.batch_shape,
            locs_batch_shape=self.output_shape,
            grid_shape=grid_shape,
            width=width,
            kernel="kaiser_bessel",
            kernel_params=dict(beta=beta),
        )
        # Create scaling
        scale_factor = width**ndim * (prod(grid_size) / prod(padded_size)) ** 0.5
        scale = Scalar(weight=1.0 / scale_factor, ioshape=interp.oshape)
        scale.to(device)  # Helps with batching later
        linops = [apodize, pad, fft, interp, scale]
    elif self.mode == "sampling":
        if locs_prepared.is_complex() or locs_prepared.is_floating_point():
            raise ValueError(
                f"Sampling linop requires integer-type locs but got {locs_prepared.dtype}"
            )
        # Sampling consumes pre-wrapped integer indices directly (no kernel)
        interp = Sampling.from_stacked_idx(
            locs_prepared,
            dim=-1,
            # Arguments for Sampling
            input_size=padded_size,
            output_shape=self.output_shape,
            input_shape=grid_shape,
            batch_shape=self.batch_shape,
        )
        # No apodization or scaling needed
        linops = [pad, fft, interp]
    else:
        raise ValueError(f"Unrecognized NUFFT mode: {mode}")

    super().__init__(*linops, name="NUFFT")
    # Useful parameters to save
    self.locs = locs
    self.grid_size = grid_size
    self.oversamp = oversamp
    self.width = width

    # Handles to get modules directly
    self.pad = pad
    self.fft = fft
    self.interp = interp

beta staticmethod

beta(width, oversamp)

https://sigpy.readthedocs.io/en/latest/_modules/sigpy/fourier.html#nufft

References

Beatty PJ, Nishimura DG, Pauly JM. Rapid gridding reconstruction with a minimal oversampling ratio. IEEE Trans Med Imaging. 2005 Jun;24(6):799-808. doi: 10.1109/TMI.2005.848376. PMID: 15959939.

Source code in src/torchlinops/linops/nufft.py
@staticmethod
def beta(width, oversamp):
    """
    https://sigpy.readthedocs.io/en/latest/_modules/sigpy/fourier.html#nufft

    References
    ----------
    Beatty PJ, Nishimura DG, Pauly JM. Rapid gridding reconstruction with a minimal oversampling ratio.
    IEEE Trans Med Imaging. 2005 Jun;24(6):799-808. doi: 10.1109/TMI.2005.848376. PMID: 15959939.
    """
    return torch.pi * (((width / oversamp) * (oversamp - 0.5)) ** 2 - 0.8) ** 0.5

flatten

flatten()

Don't combine constituent linops into a chain with other linops Informs how split_forward should behave

Source code in src/torchlinops/linops/nufft.py
def flatten(self):
    """Keep this linop intact when flattening.

    Returns ``[self]`` so the constituent linops are not merged into a
    chain with other linops; this informs how ``split_forward`` behaves.
    """
    flat = [self]
    return flat

prep_locs staticmethod

prep_locs(
    locs: Shaped[Tensor, "... D"],
    grid_size: tuple,
    padded_size: tuple,
    pad_mode: Literal["zero", "circular"] = "circular",
    nufft_mode: Literal[
        "interpolate", "sampling"
    ] = "interpolate",
)
PARAMETER DESCRIPTION
locs

Input tensor representing locations in the grid. The last dimension corresponds to spatial dimensions. Range is [-N//2, N//2]

TYPE: Shaped[Tensor, '... D']

grid_size

The original size of the grid before padding.

TYPE: tuple

padded_size

The size of the grid after padding.

TYPE: tuple

pad_mode

The type of padding applied. Can be "zero" for zero-padding or "circular" for circular padding. Default is "circular".

TYPE: Literal['zero', 'circular'] DEFAULT: 'circular'

nufft_mode

The mode of the NUFFT operation. Can be "interpolate" for interpolation or "sampling" for sampling. Default is "interpolate".

TYPE: Literal['interpolate', 'sampling'] DEFAULT: 'interpolate'

RETURNS DESCRIPTION
Shaped[Tensor, '... D']

Adjusted locations tensor based on the specified padding and NUFFT modes. Range is [0, N_pad]. dtype is floating-point if nufft_mode is "interpolate", and integer if nufft_mode is "sampling"

RAISES DESCRIPTION
ValueError

If an unrecognized pad_mode is provided.

Examples:

>>> _ = torch.manual_seed(0);
>>> locs = torch.rand(1000, 3) * 64 - 32 # [-32, 32]
>>> locs.min()
tensor(-31.9949)
>>> locs.max()
tensor(31.9896)
>>> grid_size = (64, 64, 64)
>>> padded_size = (80, 80, 80) # oversamp = 1.25
>>> locs_scaled_shifted = NUFFT.prep_locs(locs, grid_size, padded_size)
>>> locs_scaled_shifted.min()
tensor(0.0064)
>>> locs_scaled_shifted.max()
tensor(79.9871)
>>> _ = torch.manual_seed(0);
>>> locs = torch.rand(1000, 3) * 64 - 32 # [-32, 32]
>>> locs = torch.round(locs * 1.25) / 1.25
>>> grid_size = (64, 64, 64)
>>> padded_size = (80, 80, 80) # oversamp = 1.25
>>> locs_scaled_shifted = NUFFT.prep_locs(locs, grid_size, padded_size, nufft_mode='sampling')
>>> locs_scaled_shifted.min()
tensor(0)
>>> locs_scaled_shifted.max()
tensor(79)
Notes
  • Assumes that the input locs are centered.
  • Adjusts the locations by scaling and shifting them according to the grid and padded sizes.
  • Applies clamping or remainder operations based on the padding mode and NUFFT mode.
Source code in src/torchlinops/linops/nufft.py
@staticmethod
def prep_locs(
    locs: Shaped[Tensor, "... D"],
    grid_size: tuple,
    padded_size: tuple,
    pad_mode: Literal["zero", "circular"] = "circular",
    nufft_mode: Literal["interpolate", "sampling"] = "interpolate",
):
    """Scale, shift, and wrap/clamp ``locs`` onto the padded grid.

    Parameters
    ----------
    locs : Shaped[Tensor, "... D"]
        Locations in the grid; the last dimension indexes the spatial
        dimensions. Assumed centered, i.e. in the range [-N//2, N//2].
    grid_size : tuple
        Grid size before padding.
    padded_size : tuple
        Grid size after padding.
    pad_mode : Literal["zero", "circular"], optional
        "zero" clamps out-of-range locations; "circular" wraps them.
        Default is "circular".
    nufft_mode : Literal["interpolate", "sampling"], optional
        "interpolate" keeps fractional locations; "sampling" rounds them
        to integer indices. Default is "interpolate".

    Returns
    -------
    Shaped[Tensor, "... D"]
        Adjusted locations in the range [0, N_pad]. Floating-point for
        "interpolate" mode, int64 for "sampling" mode.

    Raises
    ------
    ValueError
        If `pad_mode` is not one of the recognized modes.

    Examples
    --------
    >>> _ = torch.manual_seed(0);
    >>> locs = torch.rand(1000, 3) * 64 - 32 # [-32, 32]
    >>> locs.min()
    tensor(-31.9949)
    >>> locs.max()
    tensor(31.9896)
    >>> grid_size = (64, 64, 64)
    >>> padded_size = (80, 80, 80) # oversamp = 1.25
    >>> locs_scaled_shifted = NUFFT.prep_locs(locs, grid_size, padded_size)
    >>> locs_scaled_shifted.min()
    tensor(0.0064)
    >>> locs_scaled_shifted.max()
    tensor(79.9871)

    >>> _ = torch.manual_seed(0);
    >>> locs = torch.rand(1000, 3) * 64 - 32 # [-32, 32]
    >>> locs = torch.round(locs * 1.25) / 1.25
    >>> grid_size = (64, 64, 64)
    >>> padded_size = (80, 80, 80) # oversamp = 1.25
    >>> locs_scaled_shifted = NUFFT.prep_locs(locs, grid_size, padded_size, nufft_mode='sampling')
    >>> locs_scaled_shifted.min()
    tensor(0)
    >>> locs_scaled_shifted.max()
    tensor(79)
    """
    # Work on a copy so the in-place scaling does not touch the caller's tensor
    out = locs.clone()
    n_spatial = len(grid_size)
    for ax in range(n_spatial):
        i = ax - n_spatial  # negative index into the trailing spatial dims
        # Rescale from the original grid onto the padded grid, then re-center
        out[..., i] *= padded_size[i] / grid_size[i]
        out[..., i] += padded_size[i] // 2
        if pad_mode == "zero":
            out[..., i] = torch.clamp(out[..., i], 0, padded_size[i] - 1)
        elif pad_mode == "circular":
            if nufft_mode == "sampling":
                # Round first, then wrap the index to the other side of kspace
                out[..., i] = torch.round(out[..., i]).remainder(padded_size[i])
            elif nufft_mode == "interpolate":
                out[..., i] = torch.remainder(
                    out[..., i], torch.tensor(padded_size[i])
                )
        else:
            raise ValueError(f"Unrecognized padding mode during prep: {pad_mode}")
    # Sampling mode produces integer grid indices
    return out.to(torch.int64) if nufft_mode == "sampling" else out

NamedLinop

Bases: Module

Base class for all named linear operators.

A NamedLinop represents a linear map \(A : X \to Y\) where the input and output tensor dimensions are identified by name (e.g. ("Nx", "Ny") -> ("Kx", "Ky")).

Subclass this to implement concrete operators. At minimum, override fn and adj_fn as static methods.

ATTRIBUTE DESCRIPTION
shape

The named shape of the linop, containing ishape and oshape.

TYPE: NamedShape

stream

Optional cuda Stream to run this linop on.

TYPE: Stream

start_event

An event that signals when the linop has started. Useful for synchronizing multiple linops across multiple devices.

TYPE: (Event, optional)

end_event

An event that signals when the linop has completed. Useful for synchronizing multiple linops across multiple devices.

TYPE: (Event, optional)

input_listener

Pointer to another linop's event attribute. Used to coordinate GPU-to-GPU transfers in parallel execution contexts. When set to a tuple like (some_linop, "start_event"), the device transfer will wait for that event to be recorded before initiating the transfer.

TYPE: tuple(linop, str) or None

Source code in src/torchlinops/linops/namedlinop.py
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
class NamedLinop(nn.Module):
    """Base class for all named linear operators.

    A ``NamedLinop`` represents a linear map $A : X \\to Y$ where the input and
    output tensor dimensions are identified by name (e.g. ``("Nx", "Ny") -> ("Kx", "Ky")``).

    Subclass this to implement concrete operators. At minimum, override ``fn``
    and ``adj_fn`` as static methods.

    Attributes
    ----------
    shape : NamedShape
        The named shape of the linop, containing ``ishape`` and ``oshape``.
    stream : torch.cuda.Stream
        Optional cuda Stream to run this linop on.
    start_event : Event, optional
        An event that signals when the linop has started. Useful for synchronizing
        multiple linops across multiple devices.
    end_event : Event, optional
        An event that signals when the linop has completed. Useful for synchronizing multiple
        linops across multiple devices.
    input_listener : tuple(linop, str) or None
        Pointer to another linop's event attribute. Used to coordinate GPU-to-GPU
        transfers in parallel execution contexts. When set to a tuple like
        ``(some_linop, "start_event")``, the device transfer will wait for that
        event to be recorded before initiating the transfer.
    """

    def __init__(self, shape: NamedShape, name: Optional[str] = None, **kwargs):
        """
        Parameters
        ----------
        shape : NamedShape
            The shape of this linop, e.g. ``NamedShape(("N",), ("M",))``
        name : str, optional
            Optional name to display for this linop.
        """
        super().__init__(**kwargs)
        # Note: this attribute is private because the `.shape` attribute may be derived
        # dynamically
        self._shape = shape
        self._suffix = ""
        self._name = name
        self._setup()

    def _setup(self):
        """Helper method that should be called to reset the linop's state.
        Should be performed after any substantial changes to the linop."""
        self.reset_adjoint_and_normal()
        self.stream = None
        self.start_event = None
        self.end_event = None
        self._input_listener = ForwardedAttribute()
        # By default, listen for the start of this linop
        self.input_listener = (self, "start_event")

    @final
    def forward(self, x: Tensor) -> Tensor:
        """Apply the forward operation $y = A(x)$.

        If a CUDA stream is assigned, execution is dispatched to that stream.
        If a ``start_event`` is set, it is recorded before execution begins,
        allowing other operators to synchronize on it.

        Do not override this method. Instead, override .fn() and .adj_fn().

        Parameters
        ----------
        x : Tensor
            Input tensor.

        Returns
        -------
        Tensor
            The result of applying this linop to *x*.
        """
        if x.is_cuda:  # pragma: no cover
            stream = default_to(default_stream(x.device), self.stream)
            self.start_event = stream.record_event()
            with torch.cuda.stream(stream):
                y = self.fn(self, x)
            # Tie x's lifetime to `stream` so the caching allocator does not
            # reuse its memory while the stream may still be reading it.
            x.record_stream(stream)
            self.end_event = stream.record_event()
        else:
            y = self.fn(self, x)
        return y

    def apply(self, x: Tensor) -> Tensor:
        """Apply the linear operator to a tensor."""
        return LinopFunction.apply(x, self)

    # Override
    @staticmethod
    def fn(linop, x: Tensor, /) -> Tensor:
        """Compute the forward operation $y = A(x)$.

        Override this in subclasses to define the linop's forward behavior.

        Parameters
        ----------
        linop : NamedLinop
            The linop instance (passed explicitly because this is a staticmethod).
        x : Tensor
            Input tensor.

        Returns
        -------
        Tensor
            Result of applying the linop to *x*.

        Notes
        -----
        Declared as a staticmethod so that ``adjoint()`` can swap ``fn`` and
        ``adj_fn`` on a shallow copy without bound-method complications.
        """
        return x

    # Override
    @staticmethod
    def adj_fn(linop, x: Tensor, /) -> Tensor:
        """Compute the adjoint operation $y = A^H(x)$.

        Override this in subclasses to define the linop's adjoint behavior.

        Parameters
        ----------
        linop : NamedLinop
            The linop instance.
        x : Tensor
            Input tensor.

        Returns
        -------
        Tensor
            Result of applying the adjoint $A^H$ to *x*.
        """
        return x

    # Override
    @staticmethod
    def normal_fn(linop, x: Tensor, /) -> Tensor:
        """Compute the normal operation $y = A^H A(x)$.

        The default implementation composes ``adj_fn(fn(x))``. Override this
        in subclasses that have an efficient closed-form normal (e.g.
        ``Diagonal``, ``FFT``).

        Parameters
        ----------
        linop : NamedLinop
            The linop instance.
        x : Tensor
            Input tensor.

        Returns
        -------
        Tensor
            Result of applying $A^H A$ to *x*.
        """
        return linop.adj_fn(linop, linop.fn(linop, x))

    # Override
    def split_forward(self, ibatch, obatch) -> "NamedLinop":
        """Split this linop into a sub-linop according to slices over its dimensions.

        Override this in subclasses to define how the linop decomposes when tiled
        along its named dimensions. For the companion method that handles adjoints,
        see ``adj_split``.

        Parameters
        ----------
        ibatch : tuple[slice, ...]
            Slices over the input dimensions, one per element of ``ishape``.
        obatch : tuple[slice, ...]
            Slices over the output dimensions, one per element of ``oshape``.

        Returns
        -------
        NamedLinop
            A new linop that operates on the specified slice of the data.
        """

        return type(self)(self._shape)

    # Override
    def size(self, dim: str) -> int | None:
        """Return the concrete size of *dim*, or ``None`` if this linop does not determine it.

        Parameters
        ----------
        dim : str
            The named dimension to query.

        Returns
        -------
        int or None
            The size of the dimension, or ``None``.
        """
        return None

    @final
    @property
    def dims(self) -> set:
        """Get the set of dims that appear in this linop."""
        return set(self.ishape).union(set(self.oshape))

    @final
    @property
    def H(self) -> "NamedLinop":
        """Adjoint operator $A^H$.

        By default, creates a new adjoint on each access. Set
        ``torchlinops.config.cache_adjoint_normal = True`` to enable caching
        (deprecated).
        """
        try:
            if config.cache_adjoint_normal:
                config._warn_if_caching_enabled()
                if self._adjoint is None:
                    try:
                        _adjoint = self.adjoint()
                        # Cached in single-element lists so the reference does not
                        # register as a submodule of this nn.Module.
                        _adjoint._adjoint = [self]
                        self._adjoint = [_adjoint]
                    except AttributeError as e:
                        traceback.print_exc()
                        raise e
                    logger.debug(
                        f"{type(self).__name__}: Making new adjoint {_adjoint._shape}"
                    )
                return self._adjoint[0]
            return self.adjoint()
        except AttributeError as e:
            # Re-raise as RuntimeError: a raw AttributeError inside a property
            # would be swallowed by nn.Module.__getattr__ and misreported.
            raise RuntimeError(f"AttributeError in {type(self).__name__}.H: {e}") from e

    def adjoint(self) -> "NamedLinop":
        """Create the adjoint operator $A^H$.

        The default implementation shallow-copies this linop, swaps ``fn`` and
        ``adj_fn``, and flips the shape. Override this in subclasses that need
        special adjoint construction (e.g. conjugating weights).

        Returns
        -------
        NamedLinop
            The adjoint operator, sharing the same underlying data.
        """
        adj = copy(self)  # Retains data
        adj._shape = adj._shape.H
        # Swap functions (requires staticmethod)
        adj.fn, adj.adj_fn = adj.adj_fn, adj.fn
        adj.split, adj.adj_split = adj.adj_split, adj.split
        adj._update_suffix(adjoint=True)
        return adj

    @final
    def _update_suffix(self, adjoint: bool = False, normal: bool = False):
        # ".H.H" cancels out; ".N" accumulates (normal of normal is distinct).
        if adjoint:
            if self._suffix.endswith(".H"):
                self._suffix = self._suffix[:-2]
            else:
                self._suffix += ".H"
        elif normal:
            self._suffix += ".N"

    @final
    @property
    def N(self) -> "NamedLinop":
        """Normal operator $A^H A$.

        Note that the naive normal operator can always be created via ``A.H @ A``.
        This function is reserved for custom behavior, as many linops have
        optimized normal forms.

        By default, creates a new normal on each access. Set
        ``torchlinops.config.cache_adjoint_normal = True`` to enable caching
        (deprecated).
        """
        try:
            if config.cache_adjoint_normal:
                config._warn_if_caching_enabled()
                if self._normal is None:
                    try:
                        _normal = self.normal()
                        self._normal = [_normal]
                    except AttributeError as e:
                        traceback.print_exc()
                        raise e
                return self._normal[0]
            return self.normal()
        except AttributeError as e:
            raise RuntimeError(f"AttributeError in {type(self).__name__}.N: {e}") from e

    def normal(self, inner=None) -> "NamedLinop":
        """Create the normal operator $A^H A$, optionally with an inner operator.

        When *inner* is ``None`` (or ``Identity`` with the reduce-identity config
        enabled), creates a linop whose forward pass calls ``normal_fn``.

        When *inner* is provided, constructs the composition $A^H \\cdot \\text{inner} \\cdot A$,
        which is used for Toeplitz embedding and similar optimizations.

        Parameters
        ----------
        inner : NamedLinop, optional
            An optional inner operator for Toeplitz embedding. If ``None``,
            the standard normal $A^H A$ is computed.

        Returns
        -------
        NamedLinop
            The normal operator.
        """
        if config.inner_not_relevant(inner):
            normal = copy(self)
            normal._shape = self._shape.N

            # Auxiliary object
            # Avoids creating lambda functions, which enables multiprocessing
            function_table = NormalFunctionLookup(self)
            # Static
            normal.fn = function_table.new_forward_adjoint_fn
            normal.adj_fn = function_table.new_forward_adjoint_fn
            normal.normal_fn = function_table.new_normal_fn
            # Bind `self` with partial to avoid weird multiprocessing-only error?
            normal.adjoint = partial(new_normal_adjoint, self=normal)
            # normal.adjoint = new_normal_adjoint.__get__(normal) # This one doesn't work

            # Assume that none of the dims are the same anymore
            # Override this behavior for e.g. diagonal linops
            normal.oshape = tuple(d.next_unused(normal.ishape) for d in normal.oshape)
            # Remember which shapes were updated
            # NOTE(review): `normal.oshape` was already renamed on the line
            # above, so this dict maps the *renamed* dims to their next unused
            # versions rather than the original dims to the renamed ones --
            # confirm this is intended.
            normal._shape_updates = {
                d: d.next_unused(normal.ishape) for d in normal.oshape
            }
            normal._update_suffix(normal=True)
            return normal
        pre = copy(self)
        pre.oshape = inner.ishape
        post = self.adjoint()  # Copy happens inside adjoint
        post.ishape = inner.oshape
        normal = post @ inner @ pre
        normal._shape_updates = getattr(inner, "_shape_updates", {})
        return normal

    @final
    @staticmethod
    def split(linop, tile: Mapping[ND | str, slice]) -> "NamedLinop":
        """Split a linop into a sub-linop for a given tile.

        Translates a tile dictionary into per-dimension slices and delegates
        to ``split_forward``.

        Parameters
        ----------
        linop : NamedLinop
            The linop to split.
        tile : Mapping[ND | str, slice]
            Dictionary mapping dimension names to slices.

        Returns
        -------
        NamedLinop
            The sub-linop operating on the specified tile.
        """
        ibatch = [tile.get(dim, slice(None)) for dim in linop.ishape]
        obatch = [tile.get(dim, slice(None)) for dim in linop.oshape]
        return linop.split_forward(ibatch, obatch)

    @final
    @staticmethod
    def adj_split(linop, tile: Mapping[ND | str, slice]) -> "NamedLinop":
        """Split the adjoint of this linop for a given tile.

        Constructs the adjoint, splits it according to *tile*, and returns the
        adjoint of the split.

        Parameters
        ----------
        linop : NamedLinop
            The linop whose adjoint should be split.
        tile : Mapping[ND | str, slice]
            Dictionary mapping dimension names to slices.

        Returns
        -------
        NamedLinop
            The split adjoint sub-linop.
        """
        ibatch = [tile.get(dim, slice(None)) for dim in linop.ishape]
        obatch = [tile.get(dim, slice(None)) for dim in linop.oshape]
        splitH = linop.adjoint().split_forward(obatch, ibatch).adjoint()
        return splitH

    def flatten(self) -> list["NamedLinop"]:
        """Get a flattened list of constituent linops for composition."""
        return [self]

    def compose(self, inner) -> "NamedLinop":
        """Compose this linop with another linop.

        Parameters
        ----------
        inner : NamedLinop
            The linop to call before this one.

        Returns
        -------
        NamedLinop
            The composition of self and inner. If A = self and B = inner then this returns
            C = AB.
        """
        before = inner.flatten()
        after = self.flatten()
        return torchlinops.Chain(*(before + after))

    def __add__(self, right) -> "NamedLinop":
        return torchlinops.Add(self, right)

    def __radd__(self, left) -> "NamedLinop":
        return torchlinops.Add(left, self)

    def __mul__(self, right) -> "NamedLinop":
        # Scalar multiplication only; linop @ linop composition uses __matmul__.
        if isinstance(right, (int, float)) or isinstance(right, torch.Tensor):
            right = torchlinops.Scalar(weight=right, ioshape=self.ishape)
            return self.compose(right)
        return NotImplemented

    def __rmul__(self, left) -> "NamedLinop":
        if isinstance(left, (int, float)) or isinstance(left, torch.Tensor):
            left = torchlinops.Scalar(weight=left, ioshape=self.oshape)
            return left.compose(self)
        return NotImplemented

    def __neg__(self) -> "NamedLinop":
        return (-1) * self

    def __sub__(self, right) -> "NamedLinop":
        return torchlinops.Add(self, -right)

    def __rsub__(self, left) -> "NamedLinop":
        if isinstance(left, NamedLinop):
            return torchlinops.Add(left, -self)
        return NotImplemented

    def __matmul__(self, right) -> "NamedLinop":
        # A @ B composes; A @ tensor applies the linop.
        if isinstance(right, NamedLinop):
            return self.compose(right)
        if isinstance(right, torch.Tensor):
            return self(right)
        return NotImplemented

    def __rmatmul__(self, left) -> "NamedLinop":
        if not isinstance(left, NamedLinop):
            raise ValueError(
                f"__rmatmul__ of linop {type(self)} with non-linop of type {type(left)} is undefined."
            )
        return left.compose(self)

    @property
    def name(self):
        # Falls back to the class name when no explicit name was given.
        if self._name is not None:
            return self._name
        return type(self).__name__

    @name.setter
    def name(self, new_name):
        self._name = new_name

    @property
    def repr_name(self):
        # Display name plus any accumulated ".H"/".N" suffixes.
        return self.name + self._suffix

    def __repr__(self):
        out = f"{self.repr_name}({self.ishape} -> {self.oshape})"
        if self.start_event is not None:  # pragma: no cover
            out += f", start: {self.start_event.event_id:x}"
        if self.end_event is not None:  # pragma: no cover
            out += f", end: {self.end_event.event_id:x}"
        out = INDENT.indent(out)
        return out

    def reset_adjoint_and_normal(self):
        # Drop any cached adjoint/normal (used with cache_adjoint_normal).
        self._adjoint = None
        self._normal = None

    @property
    def shape(self) -> Shape:
        return self._shape

    @shape.setter
    def shape(self, val):
        self._shape = val

    @property
    def ishape(self):
        return self._shape.ishape

    @ishape.setter
    def ishape(self, val):
        self._shape.ishape = val

    @property
    def oshape(self):
        return self._shape.oshape

    @oshape.setter
    def oshape(self, val):
        self._shape.oshape = val

    def to(self, device, memory_aware: bool = False, called_by_adjoint: bool = False):
        """Move this linop (and its cached adjoint/normal) to *device*.

        Parameters
        ----------
        device : torch.device or str
            Target device.
        memory_aware : bool, default False
            If ``True``, use ``memory_aware_to`` which preserves shared-storage
            topology when moving tensors.
        called_by_adjoint : bool, default False
            Internal flag to prevent infinite recursion when the adjoint
            also calls ``.to()``. Will be deprecated along with cache_adjoint_normal.

        Returns
        -------
        NamedLinop
            The linop on the target device.
        """

        if config.cache_adjoint_normal:  # pragma: no cover
            config._warn_if_caching_enabled()
            if self._adjoint and not called_by_adjoint:
                # bool flag avoids infinite recursion
                self._adjoint[0] = self._adjoint[0].to(
                    device, memory_aware, called_by_adjoint=True
                )
            if self._normal:
                self._normal[0] = self._normal[0].to(device, memory_aware)
        if memory_aware:
            return memory_aware_to(self, device)
        return super().to(device)

    @property
    def input_listener(self):
        """Pointer to another linop event attribute.

        Useful for facilitating gpu-gpu transfers in parallel.

        For example, if ToDevice occurs inside a composing linop that allows for
        parallel execution, e.g.

        C = Concat(
            Chain(ToDevice1, A, ...),
            Chain(ToDevice2, B, ...),
            ...
        )

        Then we may want to set ToDevice1 and ToDevice2 to both listen for the beginning of C.
        That way, both device movements can be triggered in parallel.

        This attribute is a universal attribute so that it can be chained in cases of nesting, e.g.
        Add(
            Concat(
                Chain(ToDevice, ...), ...
                ...
            )
        )
        The innermost ToDevice can listen to Chain, which listens to Concat, which listens to Add.
        This is good because Concat and Add both can parallelize efficiently across multiple GPUs.
        """
        return self._input_listener.value

    @input_listener.setter
    def input_listener(self, value):
        # A (linop, attr_name) tuple forwards reads to that linop's attribute;
        # any other value is stored directly.
        if isinstance(value, tuple):
            _log_transfer(
                f"Setting {type(self).__name__}.input_listener to reference {type(value[0]).__name__}.{value[1]}"
            )
            self._input_listener.forward_to(*value)
        else:
            _log_transfer(f"Setting {type(self).__name__}.input_listener to {value}")
            self._input_listener = value

    def __copy__(self):
        """Specialized copying for linops.

        Notes
        -----
        - Shares previous data
        - Removes references to adjoint and normal
        - Creates a new shape object, rather than using the old one
        """
        cls = type(self)
        new = cls.__new__(cls)
        new.__dict__ = self.__dict__.copy()
        # Pytorch-specific module state dictionaries
        # Mirror those used in ``__getattr__``
        # See https://github.com/pytorch/pytorch/blob/1eba9b3aa3c43f86f4a2c807ac8e12c4a7767340/torch/nn/modules/module.py#L1915
        new._parameters = new._parameters.copy()
        new._modules = new._modules.copy()
        new._buffers = new._buffers.copy()

        # Create new shape
        new._shape = deepcopy(self._shape)
        new._setup()
        return new

    @final
    def __deepcopy__(self, _):
        return memory_aware_deepcopy(self)

H property

H: NamedLinop

Adjoint operator \(A^H\).

By default, creates a new adjoint on each access. Set torchlinops.config.cache_adjoint_normal = True to enable caching (deprecated).

N property

N: NamedLinop

Normal operator \(A^H A\).

Note that the naive normal operator can always be created via A.H @ A. This function is reserved for custom behavior, as many linops have optimized normal forms.

By default, creates a new normal on each access. Set torchlinops.config.cache_adjoint_normal = True to enable caching (deprecated).

dims property

dims: set

Get the set of dims that appear in this linop.

input_listener property writable

input_listener

Pointer to another linop event attribute.

Useful for facilitating gpu-gpu transfers in parallel.

For example, if ToDevice occurs inside a composing linop that allows for parallel execution, e.g.

C = Concat( Chain(ToDevice1, A, ...), Chain(ToDevice2, B, ...), ... )

Then we may want to set ToDevice1 and ToDevice2 to both listen for the beginning of C. That way, both device movements can be triggered in parallel.

This attribute is a universal attribute so that it can be chained in cases of nesting, e.g. Add( Concat( Chain(ToDevice, ...), ... ... ) ) The innermost ToDevice can listen to Chain, which listens to Concat, which listens to Add. This is good because Concat and Add both can parallelize efficiently across multiple GPUs.

__copy__

__copy__()

Specialized copying for linops.

Notes
  • Shares previous data
  • Removes references to adjoint and normal
  • Creates a new shape object, rather than using the old one
Source code in src/torchlinops/linops/namedlinop.py
def __copy__(self):
    """Specialized copying for linops.

    Notes
    -----
    - Shares previous data
    - Removes references to adjoint and normal
    - Creates a new shape object, rather than using the old one
    """
    cls = type(self)
    new = cls.__new__(cls)
    new.__dict__ = self.__dict__.copy()
    # Pytorch-specific module state dictionaries
    # Mirror those used in `__getattr__``
    # See https://github.com/pytorch/pytorch/blob/1eba9b3aa3c43f86f4a2c807ac8e12c4a7767340/torch/nn/modules/module.py#L1915
    new._parameters = new._parameters.copy()
    new._modules = new._modules.copy()
    new._buffers = new._buffers.copy()

    # Create new shape
    new._shape = deepcopy(self._shape)
    new._setup()
    return new

__init__

__init__(
    shape: NamedShape, name: Optional[str] = None, **kwargs
)
PARAMETER DESCRIPTION
shape

The shape of this linop, e.g. NamedShape(("N",), ("M",))

TYPE: NamedShape

name

Optional name to display for this linop.

TYPE: str DEFAULT: None

Source code in src/torchlinops/linops/namedlinop.py
def __init__(self, shape: NamedShape, name: Optional[str] = None, **kwargs):
    """
    Parameters
    ----------
    shape : NamedShape
        The shape of this linop, e.g. ``NamedShape(("N",), ("M",))``
    name : str, optional
        Optional name to display for this linop.
    """
    super().__init__(**kwargs)
    # Note: this attribute is private because the `.shape` attribute may be derived
    # dynamically
    self._shape = shape
    self._suffix = ""
    self._name = name
    self._setup()

adj_fn staticmethod

adj_fn(linop, x: Tensor) -> Tensor

Compute the adjoint operation \(y = A^H(x)\).

Override this in subclasses to define the linop's adjoint behavior.

PARAMETER DESCRIPTION
linop

The linop instance.

TYPE: NamedLinop

x

Input tensor.

TYPE: Tensor

RETURNS DESCRIPTION
Tensor

Result of applying the adjoint \(A^H\) to x.

Source code in src/torchlinops/linops/namedlinop.py
@staticmethod
def adj_fn(linop, x: Tensor, /) -> Tensor:
    """Compute the adjoint operation $y = A^H(x)$.

    Override this in subclasses to define the linop's adjoint behavior.

    Parameters
    ----------
    linop : NamedLinop
        The linop instance.
    x : Tensor
        Input tensor.

    Returns
    -------
    Tensor
        Result of applying the adjoint $A^H$ to *x*.
    """
    return x

adj_split staticmethod

adj_split(
    linop, tile: Mapping[NamedDimension | str, slice]
) -> NamedLinop

Split the adjoint of this linop for a given tile.

Constructs the adjoint, splits it according to tile, and returns the adjoint of the split.

PARAMETER DESCRIPTION
linop

The linop whose adjoint should be split.

TYPE: NamedLinop

tile

Dictionary mapping dimension names to slices.

TYPE: Mapping[NamedDimension | str, slice]

RETURNS DESCRIPTION
NamedLinop

The split adjoint sub-linop.

Source code in src/torchlinops/linops/namedlinop.py
@final
@staticmethod
def adj_split(linop, tile: Mapping[ND | str, slice]) -> "NamedLinop":
    """Split the adjoint of this linop for a given tile.

    Constructs the adjoint, splits it according to *tile*, and returns the
    adjoint of the split.

    Parameters
    ----------
    linop : NamedLinop
        The linop whose adjoint should be split.
    tile : Mapping[ND | str, slice]
        Dictionary mapping dimension names to slices.

    Returns
    -------
    NamedLinop
        The split adjoint sub-linop.
    """
    ibatch = [tile.get(dim, slice(None)) for dim in linop.ishape]
    obatch = [tile.get(dim, slice(None)) for dim in linop.oshape]
    splitH = linop.adjoint().split_forward(obatch, ibatch).adjoint()
    return splitH

adjoint

adjoint() -> NamedLinop

Create the adjoint operator \(A^H\).

The default implementation shallow-copies this linop, swaps fn and adj_fn, and flips the shape. Override this in subclasses that need special adjoint construction (e.g. conjugating weights).

RETURNS DESCRIPTION
NamedLinop

The adjoint operator, sharing the same underlying data.

Source code in src/torchlinops/linops/namedlinop.py
def adjoint(self) -> "NamedLinop":
    """Create the adjoint operator $A^H$.

    The default implementation shallow-copies this linop, swaps ``fn`` and
    ``adj_fn``, and flips the shape. Override this in subclasses that need
    special adjoint construction (e.g. conjugating weights).

    Returns
    -------
    NamedLinop
        The adjoint operator, sharing the same underlying data.
    """
    adj = copy(self)  # Retains data
    adj._shape = adj._shape.H
    # Swap functions (requires staticmethod)
    adj.fn, adj.adj_fn = adj.adj_fn, adj.fn
    adj.split, adj.adj_split = adj.adj_split, adj.split
    adj._update_suffix(adjoint=True)
    return adj

apply

apply(x: Tensor) -> Tensor

Apply the linear operator to a tensor.

Source code in src/torchlinops/linops/namedlinop.py
def apply(self, x: Tensor) -> Tensor:
    """Apply the linear operator to a tensor."""
    return LinopFunction.apply(x, self)

compose

compose(inner) -> NamedLinop

Compose this linop with another linop.

PARAMETER DESCRIPTION
inner

The linop to call before this one.

TYPE: NamedLinop

RETURNS DESCRIPTION
NamedLinop

The composition of self and inner. If A = self and B = inner then this returns C = AB.

Source code in src/torchlinops/linops/namedlinop.py
def compose(self, inner) -> "NamedLinop":
    """Compose this linop with another linop.

    Parameters
    ----------
    inner : NamedLinop
        The linop to call before this one.

    Returns
    -------
    NamedLinop
        The composition of self and inner. If A = self and B = inner then this returns
        C = AB.
    """
    before = inner.flatten()
    after = self.flatten()
    return torchlinops.Chain(*(before + after))

flatten

flatten() -> list[NamedLinop]

Get a flattened list of constituent linops for composition.

Source code in src/torchlinops/linops/namedlinop.py
def flatten(self) -> list["NamedLinop"]:
    """Get a flattened list of constituent linops for composition."""
    return [self]

fn staticmethod

fn(linop, x: Tensor) -> Tensor

Compute the forward operation \(y = A(x)\).

Override this in subclasses to define the linop's forward behavior.

PARAMETER DESCRIPTION
linop

The linop instance (passed explicitly because this is a staticmethod).

TYPE: NamedLinop

x

Input tensor.

TYPE: Tensor

RETURNS DESCRIPTION
Tensor

Result of applying the linop to x.

Notes

Declared as a staticmethod so that adjoint() can swap fn and adj_fn on a shallow copy without bound-method complications.

Source code in src/torchlinops/linops/namedlinop.py
@staticmethod
def fn(linop, x: Tensor, /) -> Tensor:
    """Compute the forward operation $y = A(x)$.

    Override this in subclasses to define the linop's forward behavior.

    Parameters
    ----------
    linop : NamedLinop
        The linop instance (passed explicitly because this is a staticmethod).
    x : Tensor
        Input tensor.

    Returns
    -------
    Tensor
        Result of applying the linop to *x*.

    Notes
    -----
    Declared as a staticmethod so that ``adjoint()`` can swap ``fn`` and
    ``adj_fn`` on a shallow copy without bound-method complications.
    """
    return x

forward

forward(x: Tensor) -> Tensor

Apply the forward operation \(y = A(x)\).

If a CUDA stream is assigned, execution is dispatched to that stream. If a start_event is set, it is recorded before execution begins, allowing other operators to synchronize on it.

Do not override this method. Instead, override .fn() and .adj_fn().

PARAMETER DESCRIPTION
x

Input tensor.

TYPE: Tensor

RETURNS DESCRIPTION
Tensor

The result of applying this linop to x.

Source code in src/torchlinops/linops/namedlinop.py
@final
def forward(self, x: Tensor) -> Tensor:
    """Apply the forward operation $y = A(x)$.

    If a CUDA stream is assigned, execution is dispatched to that stream.
    If a ``start_event`` is set, it is recorded before execution begins,
    allowing other operators to synchronize on it.

    Do not override this method. Instead, override .fn() and .adj_fn().

    Parameters
    ----------
    x : Tensor
        Input tensor.

    Returns
    -------
    Tensor
        The result of applying this linop to *x*.
    """
    if x.is_cuda:  # pragma: no cover
        stream = default_to(default_stream(x.device), self.stream)
        self.start_event = stream.record_event()
        with torch.cuda.stream(stream):
            y = self.fn(self, x)
        x.record_stream(stream)
        self.end_event = stream.record_event()
    else:
        y = self.fn(self, x)
    return y

normal

normal(inner=None) -> NamedLinop

Create the normal operator \(A^H A\), optionally with an inner operator.

When inner is None (or Identity with the reduce-identity config enabled), creates a linop whose forward pass calls normal_fn.

When inner is provided, constructs the composition \(A^H \cdot \text{inner} \cdot A\), which is used for Toeplitz embedding and similar optimizations.

PARAMETER DESCRIPTION
inner

An optional inner operator for Toeplitz embedding. If None, the standard normal \(A^H A\) is computed.

TYPE: NamedLinop DEFAULT: None

RETURNS DESCRIPTION
NamedLinop

The normal operator.

Source code in src/torchlinops/linops/namedlinop.py
def normal(self, inner=None) -> "NamedLinop":
    """Create the normal operator $A^H A$, optionally with an inner operator.

    When *inner* is ``None`` (or ``Identity`` with the reduce-identity config
    enabled), creates a linop whose forward pass calls ``normal_fn``.

    When *inner* is provided, constructs the composition $A^H \\cdot \\text{inner} \\cdot A$,
    which is used for Toeplitz embedding and similar optimizations.

    Parameters
    ----------
    inner : NamedLinop, optional
        An optional inner operator for Toeplitz embedding. If ``None``,
        the standard normal $A^H A$ is computed.

    Returns
    -------
    NamedLinop
        The normal operator.
    """
    if config.inner_not_relevant(inner):
        # Fast path: no meaningful inner operator. Represent A^H A as a
        # shallow copy of this linop with its functions rebound below.
        normal = copy(self)
        # .N presumably derives the normal operator's shape from this
        # linop's shape — TODO confirm against the shape object's API.
        normal._shape = self._shape.N

        # Auxiliary object
        # Avoids creating lambda functions, which enables multiprocessing
        function_table = NormalFunctionLookup(self)
        # Static
        # A^H A is self-adjoint, so forward and adjoint share one function.
        normal.fn = function_table.new_forward_adjoint_fn
        normal.adj_fn = function_table.new_forward_adjoint_fn
        normal.normal_fn = function_table.new_normal_fn
        # Bind `self` with partial to avoid weird multiprocessing-only error?
        normal.adjoint = partial(new_normal_adjoint, self=normal)
        # normal.adjoint = new_normal_adjoint.__get__(normal) # This one doesn't work

        # Assume that none of the dims are the same anymore
        # Override this behavior for e.g. diagonal linops
        normal.oshape = tuple(d.next_unused(normal.ishape) for d in normal.oshape)
        # Remember which shapes were updated
        # NOTE(review): oshape was just replaced above, so these keys are the
        # already-renamed dims — confirm this is the intended mapping.
        normal._shape_updates = {
            d: d.next_unused(normal.ishape) for d in normal.oshape
        }
        normal._update_suffix(normal=True)
        return normal
    # Composition path: build A^H @ inner @ A, stitching the boundary
    # shapes so the pieces line up with `inner`.
    pre = copy(self)
    pre.oshape = inner.ishape
    post = self.adjoint()  # Copy happens inside adjoint
    post.ishape = inner.oshape
    normal = post @ inner @ pre
    # Propagate any dimension renames recorded by the inner operator.
    normal._shape_updates = getattr(inner, "_shape_updates", {})
    return normal

normal_fn staticmethod

normal_fn(linop, x: Tensor) -> Tensor

Compute the normal operation \(y = A^H A(x)\).

The default implementation composes adj_fn(fn(x)). Override this in subclasses that have an efficient closed-form normal (e.g. Diagonal, FFT).

PARAMETER DESCRIPTION
linop

The linop instance.

TYPE: NamedLinop

x

Input tensor.

TYPE: Tensor

RETURNS DESCRIPTION
Tensor

Result of applying \(A^H A\) to x.

Source code in src/torchlinops/linops/namedlinop.py
@staticmethod
def normal_fn(linop, x: Tensor, /) -> Tensor:
    """Apply the normal operator $A^H A$ by chaining forward and adjoint.

    This default simply feeds the forward result into the adjoint.
    Subclasses with a cheaper closed-form normal (e.g. ``Diagonal``,
    ``FFT``) should override it.

    Parameters
    ----------
    linop : NamedLinop
        The operator whose ``fn``/``adj_fn`` are composed.
    x : Tensor
        Input tensor.

    Returns
    -------
    Tensor
        ``adj_fn`` applied to the result of ``fn`` on *x*.
    """
    forwarded = linop.fn(linop, x)
    return linop.adj_fn(linop, forwarded)

size

size(dim: str) -> int | None

Return the concrete size of dim, or None if this linop does not determine it.

PARAMETER DESCRIPTION
dim

The named dimension to query.

TYPE: str

RETURNS DESCRIPTION
int or None

The size of the dimension, or None.

Source code in src/torchlinops/linops/namedlinop.py
def size(self, dim: str) -> int | None:
    """Return the concrete size of *dim*, or ``None`` if this linop does not determine it.

    Parameters
    ----------
    dim : str
        The named dimension to query.

    Returns
    -------
    int or None
        The size of the dimension, or ``None``.
    """
    return None

split staticmethod

split(
    linop, tile: Mapping[NamedDimension | str, slice]
) -> NamedLinop

Split a linop into a sub-linop for a given tile.

Translates a tile dictionary into per-dimension slices and delegates to split_forward.

PARAMETER DESCRIPTION
linop

The linop to split.

TYPE: NamedLinop

tile

Dictionary mapping dimension names to slices.

TYPE: Mapping[NamedDimension | str, slice]

RETURNS DESCRIPTION
NamedLinop

The sub-linop operating on the specified tile.

Source code in src/torchlinops/linops/namedlinop.py
@final
@staticmethod
def split(linop, tile: Mapping[ND | str, slice]) -> "NamedLinop":
    """Return the sub-linop of *linop* selected by *tile*.

    Each named dimension present in *tile* is restricted to its slice;
    every other dimension keeps a full ``slice(None)``. The resulting
    per-dimension slice lists are handed to ``split_forward``, which
    performs the actual decomposition.

    Parameters
    ----------
    linop : NamedLinop
        The linop to split.
    tile : Mapping[ND | str, slice]
        Maps dimension names to the slice to keep along that dimension.

    Returns
    -------
    NamedLinop
        The sub-linop operating on the specified tile.
    """
    full = slice(None)
    input_slices = [tile.get(dim, full) for dim in linop.ishape]
    output_slices = [tile.get(dim, full) for dim in linop.oshape]
    return linop.split_forward(input_slices, output_slices)

split_forward

split_forward(ibatch, obatch) -> NamedLinop

Split this linop into a sub-linop according to slices over its dimensions.

Override this in subclasses to define how the linop decomposes when tiled along its named dimensions. For the companion method that handles adjoints, see adj_split.

PARAMETER DESCRIPTION
ibatch

Slices over the input dimensions, one per element of ishape.

TYPE: tuple[slice, ...]

obatch

Slices over the output dimensions, one per element of oshape.

TYPE: tuple[slice, ...]

RETURNS DESCRIPTION
NamedLinop

A new linop that operates on the specified slice of the data.

Source code in src/torchlinops/linops/namedlinop.py
def split_forward(self, ibatch, obatch) -> "NamedLinop":
    """Build the sub-linop corresponding to the given dimension slices.

    The base implementation ignores the slices and reconstructs an
    operator of the same concrete type from the current shape; subclasses
    that actually tile data must override it. The adjoint-side companion
    is ``adj_split``.

    Parameters
    ----------
    ibatch : tuple[slice, ...]
        One slice per input dimension in ``ishape``.
    obatch : tuple[slice, ...]
        One slice per output dimension in ``oshape``.

    Returns
    -------
    NamedLinop
        A fresh linop of the same type built from this linop's shape.
    """
    cls = type(self)
    return cls(self._shape)

to

to(
    device,
    memory_aware: bool = False,
    called_by_adjoint: bool = False,
)

Move this linop (and its cached adjoint/normal) to device.

PARAMETER DESCRIPTION
device

Target device.

TYPE: device or str

memory_aware

If True, use memory_aware_to which preserves shared-storage topology when moving tensors.

TYPE: bool DEFAULT: False

called_by_adjoint

Internal flag to prevent infinite recursion when the adjoint also calls .to(). Will be deprecated along with cache_adjoint_normal.

TYPE: bool DEFAULT: False

RETURNS DESCRIPTION
NamedLinop

The linop on the target device.

Source code in src/torchlinops/linops/namedlinop.py
def to(self, device, memory_aware: bool = False, called_by_adjoint: bool = False):
    """Move this linop (and its cached adjoint/normal) to *device*.

    Parameters
    ----------
    device : torch.device or str
        Target device.
    memory_aware : bool, default False
        If ``True``, use ``memory_aware_to`` which preserves shared-storage
        topology when moving tensors.
    called_by_adjoint : bool, default False
        Internal flag to prevent infinite recursion when the adjoint
        also calls ``.to()``. Will be deprecated along with cache_adjoint_normal.

    Returns
    -------
    NamedLinop
        The linop on the target device.
    """

    if config.cache_adjoint_normal:  # pragma: no cover
        config._warn_if_caching_enabled()
        # Keep any cached adjoint/normal operators on the same device.
        if self._adjoint and not called_by_adjoint:
            # bool flag avoids infinite recursion
            self._adjoint[0] = self._adjoint[0].to(
                device, memory_aware, called_by_adjoint=True
            )
        if self._normal:
            self._normal[0] = self._normal[0].to(device, memory_aware)
    if memory_aware:
        # memory_aware_to presumably keeps tensors that shared storage
        # before the move sharing storage afterwards — TODO confirm.
        return memory_aware_to(self, device)
    return super().to(device)

Pad

Bases: NamedLinop

Zero-pad the last dimensions of the input volume. Padding is centered (TODO: support non-centered padding?). ishape: [B... Nx Ny [Nz]]; oshape: [B... Nx1 Ny1 [Nz1]]

Source code in src/torchlinops/linops/pad_last.py
class Pad(NamedLinop):
    """Zero Pad the last dimensions of the input volume
    Padding is centered:
    - TODO: support non-centered padding?
    ishape: [B... Nx Ny [Nz]]
    oshape: [B... Nx1 Ny1 [Nz1]]

    """

    def __init__(
        self,
        pad_im_size: tuple[int, ...],
        im_size: tuple[int, ...],
        in_shape: Optional[Shape] = None,
        out_shape: Optional[Shape] = None,
        batch_shape: Optional[Shape] = None,
    ):
        """
        Parameters
        ----------
        pad_im_size : tuple[int, ...]
            Target (padded) size for the last dimensions.
        im_size : tuple[int, ...]
            Original (unpadded) size for the last dimensions.
        in_shape : Shape, optional
            Named shape for the input spatial dimensions.
        out_shape : Shape, optional
            Named shape for the output spatial dimensions.
        batch_shape : Shape, optional
            Named shape for batch dimensions.

        Raises
        ------
        ValueError
            If ``pad_im_size`` and ``im_size`` have different lengths.
        """
        if len(pad_im_size) != len(im_size):
            raise ValueError(
                f"Padded and unpadded dims should be the same length. padded: {pad_im_size} unpadded: {im_size}"
            )

        if in_shape is None:
            # No names given: infer generic spatial dim names from ndim.
            self.in_im_shape = ND.infer(get_nd_shape(im_size))
        else:
            self.in_im_shape = ND.infer(in_shape)
        if out_shape is None:
            # Output dims default to fresh ("next unused") variants of the
            # input dims, since padding changes their sizes.
            self.out_im_shape = tuple(
                d.next_unused(self.in_im_shape) for d in self.in_im_shape
            )
        else:
            # NOTE(review): unlike in_shape, out_shape is not passed through
            # ND.infer — confirm this asymmetry is intentional.
            self.out_im_shape = out_shape
        batch_shape = default_to(("...",), batch_shape)

        shape = NS(batch_shape) + NS(self.in_im_shape, self.out_im_shape)
        super().__init__(shape)
        # D: number of trailing (spatial) dims this linop pads.
        self.D = len(im_size)
        self.im_size = tuple(im_size)
        self.pad_im_size = tuple(pad_im_size)
        self.in_im_size = tuple(im_size)
        self.out_im_size = tuple(pad_im_size)
        # for psz in pad_im_size:
        #     assert not (psz % 2), "Pad sizes must be even"

        # sizes = [
        #     [(psz - isz) // 2] * 2
        #     for psz, isz in zip(self.out_im_size, self.in_im_size)
        # ]
        # self.pad = sum(sizes, start=[])
        # self.pad.reverse()

        self.pad = pad_to_size(self.im_size, self.pad_im_size)

        # Make crop slice that undoes padding
        # Need to reverse crop_slice because padding is reversed
        self.crop_slice = crop_slice_from_pad(self.pad)

    @staticmethod
    def fn(padlast, x, /):
        """Zero-pad the last ``padlast.D`` dims of ``x`` up to ``pad_im_size``."""
        if tuple(x.shape[-padlast.D :]) != padlast.im_size:
            raise ValueError(
                f"Mismatched shapes: expected {padlast.im_size} but got {x.shape[-padlast.D :]}"
            )
        # F.pad takes (before, after) pairs starting from the LAST dim;
        # append zero-pairs so leading (batch) dims are left untouched.
        pad = padlast.pad + [0, 0] * (x.ndim - padlast.D)
        return F.pad(x, pad)

    @staticmethod
    def adj_fn(padlast, y, /):
        """Crop the last n dimensions of y"""
        if tuple(y.shape[-padlast.D :]) != padlast.pad_im_size:
            raise ValueError(
                f"Mismatched shapes: expected {padlast.pad_im_size} but got {y.shape[-padlast.D :]}"
            )
        # Full slices over leading dims, then the precomputed crop slices.
        slc = [slice(None)] * (y.ndim - padlast.D) + padlast.crop_slice
        return y[tuple(slc)]

    def adjoint(self):
        """Return the adjoint (crop) linop with shapes and sizes swapped."""
        adj = super().adjoint()
        # Flip the display name so reprs read naturally in both directions.
        if self.name == "Pad":
            adj.name = "Crop"
        elif self.name == "Crop":
            adj.name = "Pad"

        adj.in_im_shape, adj.out_im_shape = self.out_im_shape, self.in_im_shape
        adj.in_im_size, adj.out_im_size = self.out_im_size, self.in_im_size
        return adj

    def split_forward(self, ibatch, obatch):
        """Allow splitting over batch dims only, never the padded image dims."""
        for islc, oslc in zip(ibatch[-self.D :], obatch[-self.D :]):
            if islc != slice(None) or oslc != slice(None):
                raise ValueError(
                    f"{type(self).__name__} cannot be split along image dim"
                )
        return self

    def size(self, dim: str):
        """Return the size of an image dim this linop determines, else None."""
        if dim in self.ishape[-self.D :]:
            return self.in_im_size[self.in_im_shape.index(dim)]
        elif dim in self.oshape[-self.D :]:
            return self.out_im_size[self.out_im_shape.index(dim)]
        return None

__init__

__init__(
    pad_im_size: tuple[int, ...],
    im_size: tuple[int, ...],
    in_shape: Optional[Shape] = None,
    out_shape: Optional[Shape] = None,
    batch_shape: Optional[Shape] = None,
)
PARAMETER DESCRIPTION
pad_im_size

Target (padded) size for the last dimensions.

TYPE: tuple[int, ...]

im_size

Original (unpadded) size for the last dimensions.

TYPE: tuple[int, ...]

in_shape

Named shape for the input spatial dimensions.

TYPE: Shape DEFAULT: None

out_shape

Named shape for the output spatial dimensions.

TYPE: Shape DEFAULT: None

batch_shape

Named shape for batch dimensions.

TYPE: Shape DEFAULT: None

Source code in src/torchlinops/linops/pad_last.py
def __init__(
    self,
    pad_im_size: tuple[int, ...],
    im_size: tuple[int, ...],
    in_shape: Optional[Shape] = None,
    out_shape: Optional[Shape] = None,
    batch_shape: Optional[Shape] = None,
):
    """Construct a centered zero-padding linop over the trailing dims.

    Parameters
    ----------
    pad_im_size : tuple[int, ...]
        Target (padded) size for the last dimensions.
    im_size : tuple[int, ...]
        Original (unpadded) size for the last dimensions.
    in_shape : Shape, optional
        Named shape for the input spatial dimensions.
    out_shape : Shape, optional
        Named shape for the output spatial dimensions.
    batch_shape : Shape, optional
        Named shape for batch dimensions.

    Raises
    ------
    ValueError
        If ``pad_im_size`` and ``im_size`` have different lengths.
    """
    if len(pad_im_size) != len(im_size):
        raise ValueError(
            f"Padded and unpadded dims should be the same length. padded: {pad_im_size} unpadded: {im_size}"
        )

    # Resolve the named input shape, inferring generic spatial names when
    # none were given.
    named_input = get_nd_shape(im_size) if in_shape is None else in_shape
    self.in_im_shape = ND.infer(named_input)
    if out_shape is not None:
        self.out_im_shape = out_shape
    else:
        # Default output dims are fresh ("next unused") variants of the
        # input dims, since padding changes their sizes.
        self.out_im_shape = tuple(
            d.next_unused(self.in_im_shape) for d in self.in_im_shape
        )
    batch_shape = default_to(("...",), batch_shape)

    shape = NS(batch_shape) + NS(self.in_im_shape, self.out_im_shape)
    super().__init__(shape)
    # D: number of trailing (spatial) dims this linop pads.
    self.D = len(im_size)
    self.im_size = tuple(im_size)
    self.pad_im_size = tuple(pad_im_size)
    self.in_im_size = tuple(im_size)
    self.out_im_size = tuple(pad_im_size)

    # Per-dim padding amounts; pad_to_size presumably emits centered
    # padding in F.pad's last-dim-first ordering — TODO confirm.
    self.pad = pad_to_size(self.im_size, self.pad_im_size)
    # Slices that undo the padding, used by the adjoint (crop).
    self.crop_slice = crop_slice_from_pad(self.pad)

adj_fn staticmethod

adj_fn(padlast, y)

Crop the last n dimensions of y

Source code in src/torchlinops/linops/pad_last.py
@staticmethod
def adj_fn(padlast, y, /):
    """Undo the padding by cropping the last ``padlast.D`` dims of *y*."""
    trailing = tuple(y.shape[-padlast.D :])
    if trailing != padlast.pad_im_size:
        raise ValueError(
            f"Mismatched shapes: expected {padlast.pad_im_size} but got {y.shape[-padlast.D :]}"
        )
    n_leading = y.ndim - padlast.D
    slc = [slice(None)] * n_leading + padlast.crop_slice
    return y[tuple(slc)]

PadDim

Bases: NamedLinop

Zero-padding operator along a specified dimension.

Pads the input with zeros. The adjoint truncates (slices) back to the original size.

Source code in src/torchlinops/linops/trunc_pad.py
class PadDim(NamedLinop):
    """Zero-padding operator along a specified dimension.

    Pads the input with zeros at the end of dimension ``dim``. The
    adjoint truncates (slices) back to the original size.
    """

    def __init__(self, dim, from_length, to_length, ishape, oshape):
        """
        Parameters
        ----------
        dim : int
            Index of the dimension to pad.
        from_length : int
            Input length along ``dim``; must satisfy
            ``0 <= from_length <= to_length``.
        to_length : int
            Output length along ``dim`` after padding; must be nonnegative.
        ishape, oshape : Shape
            Named input/output shape specifications.

        Raises
        ------
        ValueError
            If the lengths are negative or ``from_length > to_length``.
        """
        self.dim = dim
        self.from_length = from_length
        self.to_length = to_length
        if self.to_length < 0:
            raise ValueError(f"to_length must be nonnegative but got {self.to_length}")

        if self.from_length < 0 or self.from_length > self.to_length:
            # BUG FIX: this message previously blamed `to_length` even though
            # the check constrains `from_length`.
            raise ValueError(
                f"from_length must be in [0, {self.to_length}] but got {self.from_length}"
            )

        # Slice that recovers the original extent along `dim`; used by adj_fn.
        self.slc = [slice(None)] * len(ishape)
        self.slc[dim] = slice(0, self.from_length)
        self.slc = tuple(self.slc)
        super().__init__(NS(ishape, oshape))

    def adjoint(self):
        """The adjoint of end-padding is truncation back to ``from_length``."""
        return Truncate(
            self.dim, self.to_length, self.from_length, self.oshape, self.ishape
        )

    def normal(self, inner=None):
        """Return $A^H A$; truncating after zero-padding restores the input."""
        if inner is None:
            return Identity(self.ishape)
        pre = copy(self)
        post = copy(self).H
        pre.oshape = inner.ishape
        post.ishape = inner.oshape
        return post @ inner @ pre

    @staticmethod
    def fn(padend, x, /):
        """Zero-pad *x* along ``padend.dim`` from ``from_length`` to ``to_length``."""
        if x.shape[padend.dim] != padend.from_length:
            raise ValueError(
                f"padend expecting size {padend.from_length} at x.shape[{padend.dim}] but got {x.shape[padend.dim]} (x.shape: {x.shape})"
            )
        return end_pad_with_zeros(x, padend.dim, padend.to_length - padend.from_length)

    @staticmethod
    def adj_fn(padend, y, /):
        """Truncate *y* along ``padend.dim`` back to ``from_length``."""
        if y.shape[padend.dim] != padend.to_length:
            raise ValueError(
                f"PadDim (adjoint) expecting size {padend.to_length} at x.shape[{padend.dim}] but got {y.shape[padend.dim]} (y.shape: {y.shape})"
            )
        return y[padend.slc]

    @staticmethod
    def normal_fn(padend, x, /):
        """$A^H A$ is the identity here; clone so the result does not alias *x*."""
        x = x.clone()
        return x

    def split_forward(self, ibatch, obatch):
        """Splitting is allowed everywhere except the padded dimension itself."""
        if ibatch[self.dim] != slice(None) or obatch[self.dim] != slice(None):
            # BUG FIX: message previously referred to the old class name "PadEnd".
            raise ValueError("Cannot slice a PadDim linop along truncation dimension")
        return type(self)(
            self.dim, self.from_length, self.to_length, self.ishape, self.oshape
        )

normal

normal(inner=None)

Diagonal in all dims except the last one

Source code in src/torchlinops/linops/trunc_pad.py
def normal(self, inner=None):
    """Return the normal operator $A^H A$.

    Without an *inner* operator this collapses to the identity on the
    input shape (truncation undoes the zero-padding exactly). With an
    *inner* operator, the composition $A^H \\cdot inner \\cdot A$ is built.
    """
    if inner is None:
        return Identity(self.ishape)
    front = copy(self)
    back = copy(self).H
    front.oshape = inner.ishape
    back.ishape = inner.oshape
    return back @ inner @ front

Rearrange

Bases: NamedLinop

Dimension rearrangement via einops.rearrange.

Wraps einops.rearrange as a named linear operator. The adjoint performs the inverse rearrangement.

Source code in src/torchlinops/linops/einops.py
class Rearrange(NamedLinop):
    """Dimension rearrangement via ``einops.rearrange``.

    Wraps ``einops.rearrange`` as a named linear operator. The adjoint
    performs the inverse rearrangement.
    """

    def __init__(
        self,
        ipattern,
        opattern,
        ishape: Shape,
        oshape: Shape,
        axes_lengths: Optional[Mapping] = None,
    ):
        """
        Parameters
        ----------
        ipattern : str
            Input pattern string for ``einops.rearrange`` (left-hand side of
            the ``->`` arrow).
        opattern : str
            Output pattern string for ``einops.rearrange`` (right-hand side
            of the ``->`` arrow).
        ishape : Shape
            Named input shape specification.
        oshape : Shape
            Named output shape specification.
        axes_lengths : Mapping, optional
            Mapping from axis names to their sizes, passed as keyword
            arguments to ``einops.rearrange`` for resolving ambiguous
            dimensions (e.g., when splitting or merging axes).
        """
        super().__init__(NS(ishape, oshape))
        self.ipattern = ipattern
        self.opattern = opattern
        axes_lengths = axes_lengths if axes_lengths is not None else {}
        # Stored on the shape object so shape-level operations (splits,
        # adjoints) carry the axis sizes along.
        self._shape.axes_lengths = axes_lengths

    @property
    def axes_lengths(self):
        # Axis-name -> size hints forwarded to einops calls.
        return self._shape.axes_lengths

    @staticmethod
    def fn(linop, x, /):
        """Apply the forward rearrangement ``ipattern -> opattern``."""
        # einops requires plain-string axis names as keyword arguments.
        axes_lengths = {str(k): v for k, v in linop.axes_lengths.items()}
        return rearrange(x, f"{linop.ipattern} -> {linop.opattern}", **axes_lengths)

    @staticmethod
    def adj_fn(linop, x, /):
        """Apply the inverse rearrangement ``opattern -> ipattern``."""
        axes_lengths = {str(k): v for k, v in linop.axes_lengths.items()}
        return rearrange(x, f"{linop.opattern} -> {linop.ipattern}", **axes_lengths)

    def split_forward(self, ibatch, obatch):
        """TODO: Add compound shapes so splitting through rearrange can work."""
        warn(
            f"Splitting Rearrange linop with shape {self._shape} - splitting a rearrange may behave unusually."
        )
        # Shrink any recorded axis length that a slice restricts. For a dim
        # named on both sides, the output slice overwrites the input one.
        new_axes_lengths = deepcopy(self.axes_lengths)
        for dim, slc in zip(self.ishape, ibatch):
            if dim in self.axes_lengths:
                n = self.axes_lengths[dim]
                new_axes_lengths[dim] = slicelen(n, slc)
        for dim, slc in zip(self.oshape, obatch):
            if dim in self.axes_lengths:
                n = self.axes_lengths[dim]
                new_axes_lengths[dim] = slicelen(n, slc)

        out = type(self)(
            self.ipattern, self.opattern, self.ishape, self.oshape, new_axes_lengths
        )
        return out

    def size(self, dim: str):
        """Rearranging does not determine any dimensions"""
        return None

    def normal(self, inner=None):
        """Return $A^H A$; a bare rearrangement is undone exactly by its adjoint."""
        if inner is None:
            return Identity(self.ishape)
        return super().normal(inner)

__init__

__init__(
    ipattern,
    opattern,
    ishape: Shape,
    oshape: Shape,
    axes_lengths: Optional[Mapping] = None,
)
PARAMETER DESCRIPTION
ipattern

Input pattern string for einops.rearrange (left-hand side of the -> arrow).

TYPE: str

opattern

Output pattern string for einops.rearrange (right-hand side of the -> arrow).

TYPE: str

ishape

Named input shape specification.

TYPE: Shape

oshape

Named output shape specification.

TYPE: Shape

axes_lengths

Mapping from axis names to their sizes, passed as keyword arguments to einops.rearrange for resolving ambiguous dimensions (e.g., when splitting or merging axes).

TYPE: Mapping DEFAULT: None

Source code in src/torchlinops/linops/einops.py
def __init__(
    self,
    ipattern,
    opattern,
    ishape: Shape,
    oshape: Shape,
    axes_lengths: Optional[Mapping] = None,
):
    """Set up a named rearrangement between two einops patterns.

    Parameters
    ----------
    ipattern : str
        Left-hand side of the ``einops.rearrange`` pattern (before ``->``).
    opattern : str
        Right-hand side of the ``einops.rearrange`` pattern (after ``->``).
    ishape : Shape
        Named input shape specification.
    oshape : Shape
        Named output shape specification.
    axes_lengths : Mapping, optional
        Axis-name -> size hints forwarded to ``einops.rearrange`` to
        disambiguate split or merged axes.
    """
    super().__init__(NS(ishape, oshape))
    self.ipattern = ipattern
    self.opattern = opattern
    if axes_lengths is None:
        axes_lengths = {}
    # Keep the sizes on the shape object so splits/adjoints see them.
    self._shape.axes_lengths = axes_lengths

size

size(dim: str)

Rearranging does not determine any dimensions

Source code in src/torchlinops/linops/einops.py
def size(self, dim: str):
    """A pure rearrangement never fixes a dimension's size; always ``None``."""
    return None

split_forward

split_forward(ibatch, obatch)

TODO: Add compound shapes so splitting through rearrange can work.

Source code in src/torchlinops/linops/einops.py
def split_forward(self, ibatch, obatch):
    """Split by shrinking recorded axis lengths to match the given slices.

    TODO: Add compound shapes so splitting through rearrange can work.
    """
    warn(
        f"Splitting Rearrange linop with shape {self._shape} - splitting a rearrange may behave unusually."
    )
    updated = deepcopy(self.axes_lengths)
    # Output slices are processed after input slices, so for a dim named
    # on both sides the output restriction wins (matching the original).
    for shape, batch in ((self.ishape, ibatch), (self.oshape, obatch)):
        for dim, slc in zip(shape, batch):
            if dim in self.axes_lengths:
                updated[dim] = slicelen(self.axes_lengths[dim], slc)
    return type(self)(
        self.ipattern, self.opattern, self.ishape, self.oshape, updated
    )

Repeat

Bases: NamedLinop

Repeat (expand) operator along specified dimensions (adjoint of SumReduce).

Wraps einops.repeat as a named linear operator.

Source code in src/torchlinops/linops/einops.py
class Repeat(NamedLinop):
    """Repeat (expand) operator along specified dimensions (adjoint of ``SumReduce``).

    Wraps ``einops.repeat`` as a named linear operator.
    """

    def __init__(
        self,
        n_repeats: Mapping,
        ishape: Shape,
        oshape: Shape,
        broadcast_dims: Optional[list] = None,
    ):
        """
        Parameters
        ----------
        n_repeats : Mapping
            Mapping from dimension names to the number of repetitions
            along each new dimension.
        ishape : Shape
            Named input shape specification.
        oshape : Shape
            Named output shape specification. Must have more dimensions
            than ``ishape``.
        broadcast_dims : list, optional
            Dimensions that are broadcast (size unknown until runtime)
            rather than having a fixed repeat count.
        """
        super().__init__(NS(ishape, oshape))
        # NOTE(review): assert is stripped under `python -O`; raise ValueError
        # instead if this validation must always run.
        assert len(self.oshape) > len(self.ishape), (
            f"Repeat must add at least one dimension: got {self.ishape} -> {self.oshape}"
        )
        self._shape.axes_lengths = n_repeats
        # Only override the shape's broadcast_dims when explicitly given.
        if broadcast_dims is not None:
            self._shape.broadcast_dims = broadcast_dims

    @property
    def axes_lengths(self):
        # Repeat counts, keyed by the new (output-only) dimension names.
        return self._shape.axes_lengths

    @property
    def broadcast_dims(self):
        # Dims whose size is only known at runtime (no fixed repeat count).
        return self._shape.broadcast_dims

    def forward(self, x):
        # NOTE(review): NamedLinop.forward is declared @final and performs
        # CUDA stream/event bookkeeping; this override bypasses both —
        # confirm intentional.
        return self.fn(self, x)

    @staticmethod
    def fn(linop, x, /):
        """Expand *x* along the new axes via ``einops.repeat``."""
        x = repeat(
            x,
            f"{linop.ipattern} -> {linop.opattern}",
            # einops requires plain-string axis names as kwargs.
            **{str(k): v for k, v in linop.axes_lengths.items()},
        )
        return x

    @staticmethod
    def adj_fn(linop, x, /):
        """Sum over the repeated axes (summation is the adjoint of repetition)."""
        x = reduce(x, f"{linop.opattern} -> {linop.ipattern}", "sum")
        return x

    def split_forward(self, ibatch, obatch):
        """Repeat fewer times, depending on the size of obatch"""
        new_axes_lengths = deepcopy(self.axes_lengths)
        for dim, slc in zip(self.oshape, obatch):
            # Broadcast dims have no fixed count to shrink.
            if dim in self.axes_lengths and dim not in self.broadcast_dims:
                new_axes_lengths[dim] = slicelen(self.size(dim), slc)
        return type(self)(
            new_axes_lengths, self.ishape, self.oshape, self.broadcast_dims
        )

    def size(self, dim: str):
        """Return the repeat count for *dim*, or ``None`` if unknown/broadcast."""
        if dim in self.broadcast_dims:
            return None
        return self.axes_lengths.get(dim, None)

    def adjoint(self):
        # The flipped shape (.H) turns this repeat into a sum-reduction.
        return SumReduce(self._shape.H, None)

    def normal(self, inner=None):
        pre = copy(self)
        post = self.adjoint()
        if inner is not None:
            pre.oshape = inner.ishape
            post.ishape = inner.oshape
            return post @ inner @ pre
        # No updated dims because Repeat -> SumReduce
        # Gets rid of the new dimensions immediately
        # TODO: simplify this more?
        return post @ pre

    @property
    def adj_ishape(self):
        # NOTE(review): arguments are (oshape, ishape) despite the helper's
        # parameter names — see fill_singleton_dims below.
        return self.fill_singleton_dims(self.oshape, self.ishape)

    @property
    def adj_ipattern(self):
        # None entries render as "()" (a singleton axis) in einops patterns.
        return " ".join(str(d) if d is not None else "()" for d in self.adj_ishape)

    @property
    def ipattern(self):
        # Space-separated einops pattern for the input side.
        return " ".join(str(d) for d in self.ishape)

    @property
    def opattern(self):
        # Space-separated einops pattern for the output side.
        return " ".join(str(d) for d in self.oshape)

    @staticmethod
    def fill_singleton_dims(ishape, oshape):
        """Replace dims of *ishape* that do not appear in *oshape* with ``None``."""
        out = []
        for idim in ishape:
            if idim in oshape:
                out.append(idim)
            else:
                out.append(None)
        return tuple(out)

__init__

__init__(
    n_repeats: Mapping,
    ishape: Shape,
    oshape: Shape,
    broadcast_dims: Optional[list] = None,
)
PARAMETER DESCRIPTION
n_repeats

Mapping from dimension names to the number of repetitions along each new dimension.

TYPE: Mapping

ishape

Named input shape specification.

TYPE: Shape

oshape

Named output shape specification. Must have more dimensions than ishape.

TYPE: Shape

broadcast_dims

Dimensions that are broadcast (size unknown until runtime) rather than having a fixed repeat count.

TYPE: list DEFAULT: None

Source code in src/torchlinops/linops/einops.py
def __init__(
    self,
    n_repeats: Mapping,
    ishape: Shape,
    oshape: Shape,
    broadcast_dims: Optional[list] = None,
):
    """
    Parameters
    ----------
    n_repeats : Mapping
        Mapping from dimension names to the number of repetitions
        along each new dimension.
    ishape : Shape
        Named input shape specification.
    oshape : Shape
        Named output shape specification. Must have more dimensions
        than ``ishape``.
    broadcast_dims : list, optional
        Dimensions that are broadcast (size unknown until runtime)
        rather than having a fixed repeat count.
    """
    super().__init__(NS(ishape, oshape))
    # NOTE(review): assert is stripped under `python -O`; raise ValueError
    # instead if this validation must always run.
    assert len(self.oshape) > len(self.ishape), (
        f"Repeat must add at least one dimension: got {self.ishape} -> {self.oshape}"
    )
    # Repeat counts live on the shape object so splits propagate them.
    self._shape.axes_lengths = n_repeats
    # Only override the shape's broadcast_dims when explicitly given.
    if broadcast_dims is not None:
        self._shape.broadcast_dims = broadcast_dims

split_forward

split_forward(ibatch, obatch)

Repeat fewer times, depending on the size of obatch

Source code in src/torchlinops/linops/einops.py
def split_forward(self, ibatch, obatch):
    """Shrink the repeat counts to match the sliced output batch.

    Broadcast dims keep no fixed count, so they are left untouched.
    """
    counts = deepcopy(self.axes_lengths)
    for dim, slc in zip(self.oshape, obatch):
        if dim in self.axes_lengths and dim not in self.broadcast_dims:
            counts[dim] = slicelen(self.size(dim), slc)
    cls = type(self)
    return cls(counts, self.ishape, self.oshape, self.broadcast_dims)

RepeatedEvent

Manage a FIFO queue of CUDA events for stream synchronization.

.. deprecated:: This class is deprecated and will be removed in version 0.7.0. The functionality is no longer used internally.

Keeps only the most recent event, dropping old references to free resources. The wrapper itself can be passed directly to wait_event().

Source code in src/torchlinops/utils/_event.py
class RepeatedEvent:
    """Manage a FIFO queue of CUDA events for stream synchronization.

    .. deprecated::
        This class is deprecated and will be removed in version 0.7.0.
        The functionality is no longer used internally.

    Keeps only the most recent event, dropping old references to free
    resources. The wrapper itself can be passed directly to wait_event().
    """

    def __init__(self, **event_kwargs):
        """
        A wrapper so each record() creates a fresh CUDA event,
        but the wrapper itself can be passed directly to wait_event().

        Parameters
        ----------
        **event_kwargs
            Keyword arguments passed to ``torch.cuda.Event(...)``.
        """
        # Fix: this docstring was previously placed *after* the
        # warnings.warn() call, making it a dead string statement
        # instead of the method's docstring.
        warnings.warn(
            "RepeatedEvent is deprecated and will be removed in version 0.7.0. "
            "This class is no longer used internally.",
            FutureWarning,
            stacklevel=2,
        )
        self._event_kwargs = event_kwargs
        # Only the most recent event is retained (None until first record()).
        self._last_event = None

    def record(self, stream=None):  # pragma: no cover
        """
        Create a new CUDA event and record it on the given stream.
        Old events are dropped immediately to free resources.
        """
        # Drop old event reference
        self._last_event = None

        # Create and record new event
        ev = torch.cuda.Event(**self._event_kwargs)
        if stream is None:
            stream = torch.cuda.current_stream()
        ev.record(stream)

        # Store and return self for chaining
        self._last_event = ev
        return self

    @property
    def last_event(self):
        """The most recently recorded CUDA event, or None if none recorded."""
        return self._last_event

    def __repr__(self):
        return f"<RepeatedEvent wrapping {self._last_event!r}>"

record

record(stream=None)

Create a new CUDA event and record it on the given stream. Old events are dropped immediately to free resources.

Source code in src/torchlinops/utils/_event.py
def record(self, stream=None):  # pragma: no cover
    """
    Create a new CUDA event and record it on the given stream.
    Old events are dropped immediately to free resources.
    """
    # Release the previous event before allocating a fresh one.
    self._last_event = None

    # Record a brand-new event on the requested (or current) stream.
    target_stream = torch.cuda.current_stream() if stream is None else stream
    new_event = torch.cuda.Event(**self._event_kwargs)
    new_event.record(target_stream)

    # Keep a handle to the new event; return self so the wrapper can be
    # passed directly to wait_event() in a fluent style.
    self._last_event = new_event
    return self

Sampling

Bases: NamedLinop

Sample a tensor at some specified integer locations.

Input: (batch_shape, input_shape)
Output: (batch_shape, output_shape)
Source code in src/torchlinops/linops/sampling.py
class Sampling(NamedLinop):
    """Sample a tensor at some specified integer locations.

    ```
    Input: (batch_shape, input_shape)
    Output: (batch_shape, output_shape)
    ```

    """

    def __init__(
        self,
        idx: tuple[Integer[Tensor, "..."], ...],
        input_size: tuple[int, ...],
        output_shape: Optional[Shape] = None,
        input_shape: Optional[Shape] = None,
        batch_shape: Optional[Shape] = None,
    ):
        """
        Parameters
        ----------
        idx : tuple[Integer[Tensor, "..."], ...]
            tuple of D tensors, each of shape [M...]
            One index for each "sampled" axis of the input tensor
            Use `canonicalize_idx` to turn a tensor of shape [M... D] to a D-tuple of index tensors.
            idx is in range [0, size-1]
        input_size : tuple[int, ...]
            Actual shape of the input interpolated tensor, without the batch dimensions.
        output_shape : Shape, optional
            Named dimensions for the output.
        input_shape : Shape, optional
            Named dimensions for the input.
        batch_shape : Shape, optional
            Named batch dimensions.

        Notes
        -----
        Sampling: (B..., N...) -> (B..., M...)
        """
        dim = len(input_size)
        if len(idx) != dim:
            raise ValueError(
                f"Input size {input_size} doesn't match index with length {len(idx)}."
            )
        self.input_size = input_size
        # Unspecified shapes default to a wildcard ("...") dimension.
        batch_shape = default_to(("...",), batch_shape)
        input_shape = default_to(("...",), input_shape)
        output_shape = default_to(("...",), output_shape)
        shape = NS(batch_shape) + NS(input_shape, output_shape)
        super().__init__(shape)
        self._shape.batch_shape = batch_shape
        self._shape.input_shape = input_shape
        self._shape.output_shape = output_shape
        idx = F.ensure_tensor_indexing(idx, self.input_size)
        # Validate that every index lies inside the sampled extent.
        for d, (t, s) in enumerate(zip(idx, self.input_size)):
            if (t < 0).any() or (t >= s).any():
                raise ValueError(
                    f"Sampling index must lie within range [0, {s - 1}] but got range [{t.min().item()}, {t.max().item()}] for dim {d}"
                )
        self.idx = nn.ParameterList([nn.Parameter(i, requires_grad=False) for i in idx])

    @property
    def locs(self):
        """for compatibility with Interpolate linop"""
        return torch.stack(tuple(self.idx), dim=-1)

    @classmethod
    def from_mask(cls, mask, *args, **kwargs):
        """Alternative constructor for mask-based sampling"""
        idx = F.mask2idx(mask.bool())
        return cls(idx, *args, **kwargs)

    @classmethod
    def from_stacked_idx(cls, idx: Tensor, *args, dim=-1, **kwargs):
        """Alternative constructor for index in [M... D] form"""
        # Bug fix: forward the caller-supplied ``dim`` (it was previously
        # hard-coded to -1, silently ignoring the parameter).
        idx = F.canonicalize_idx(idx, dim=dim)
        return cls(idx, *args, **kwargs)

    @staticmethod
    def fn(sampling, x, /):
        # Gather values of x at the stored integer locations.
        return F.index(x, tuple(sampling.idx))

    @staticmethod
    def adj_fn(sampling, x, /):
        # Scatter-add samples back into a tensor of shape input_size.
        return F.index_adjoint(x, tuple(sampling.idx), sampling.input_size)

    def split_forward(self, ibatch, obatch):
        if self._shape.output_shape == ELLIPSES:
            # Cannot split if idx batch shape is not split
            return self
        return type(self)(
            self.split_idx(ibatch, obatch, self.idx),
            self.input_size,
            self._shape.output_shape,
            self._shape.input_shape,
            self._shape.batch_shape,
        )

    def split_idx(self, ibatch, obatch, idx):
        # Slice the trailing (output) dims of each index tensor.
        num_output_dims = len(idx[0].shape)
        if num_output_dims > 0:
            idx_slc = tuple(obatch[-num_output_dims:])
            return [i[idx_slc] for i in idx]
        return idx

    def register_shape(self, name, shape: tuple):
        """Register a named shape on the underlying shape object."""
        self._shape[name] = shape

    def size(self, dim):
        """Return the size of named dim, or None if unconstrained."""
        if dim in self._shape.output_shape:
            dim_idx = self._shape.output_shape.index(dim)
            return self.locs.shape[dim_idx]
        elif dim in self._shape.input_shape:
            dim_idx = self._shape.input_shape.index(dim)
            return self.input_size[dim_idx]
        return None

locs property

locs

for compatibility with Interpolate linop

__init__

__init__(
    idx: tuple[Integer[Tensor, ...], ...],
    input_size: tuple[int, ...],
    output_shape: Optional[Shape] = None,
    input_shape: Optional[Shape] = None,
    batch_shape: Optional[Shape] = None,
)
PARAMETER DESCRIPTION
idx

tuple of D tensors, each of shape [M...]. One index for each "sampled" axis of the input tensor. Use canonicalize_idx to turn a tensor of shape [M... D] into a D-tuple of index tensors. idx is in range [0, size-1].

TYPE: tuple[Integer[Tensor, ...], ...]

input_size

Actual shape of the input interpolated tensor, without the batch dimensions.

TYPE: tuple[int, ...]

output_shape

Named dimensions for the output.

TYPE: Shape DEFAULT: None

input_shape

Named dimensions for the input.

TYPE: Shape DEFAULT: None

batch_shape

Named batch dimensions.

TYPE: Shape DEFAULT: None

Notes

Sampling: (B..., N...) -> (B..., M...)

Source code in src/torchlinops/linops/sampling.py
def __init__(
    self,
    idx: tuple[Integer[Tensor, "..."], ...],
    input_size: tuple[int, ...],
    output_shape: Optional[Shape] = None,
    input_shape: Optional[Shape] = None,
    batch_shape: Optional[Shape] = None,
):
    """Build a sampling linop from per-axis integer index tensors.

    Parameters
    ----------
    idx : tuple[Integer[Tensor, "..."], ...]
        tuple of D tensors, each of shape [M...]; one index tensor per
        "sampled" axis of the input. Use `canonicalize_idx` to turn a
        tensor of shape [M... D] into a D-tuple of index tensors.
        Entries must lie in range [0, size-1].
    input_size : tuple[int, ...]
        Actual shape of the input interpolated tensor, without the batch dimensions.
    output_shape : Shape, optional
        Named dimensions for the output.
    input_shape : Shape, optional
        Named dimensions for the input.
    batch_shape : Shape, optional
        Named batch dimensions.

    Notes
    -----
    Sampling: (B..., N...) -> (B..., M...)
    """
    if len(idx) != len(input_size):
        raise ValueError(
            f"Input size {input_size} doesn't match index with length {len(idx)}."
        )
    self.input_size = input_size
    # Every unspecified shape falls back to a wildcard ("...") dimension.
    batch_shape, input_shape, output_shape = (
        default_to(("...",), s) for s in (batch_shape, input_shape, output_shape)
    )
    super().__init__(NS(batch_shape) + NS(input_shape, output_shape))
    self._shape.batch_shape = batch_shape
    self._shape.input_shape = input_shape
    self._shape.output_shape = output_shape
    idx = F.ensure_tensor_indexing(idx, self.input_size)
    # Reject any index outside the sampled extent.
    for d, (t, s) in enumerate(zip(idx, self.input_size)):
        if (t < 0).any() or (t >= s).any():
            raise ValueError(
                f"Sampling index must lie within range [0, {s - 1}] but got range [{t.min().item()}, {t.max().item()}] for dim {d}"
            )
    self.idx = nn.ParameterList([nn.Parameter(i, requires_grad=False) for i in idx])

from_mask classmethod

from_mask(mask, *args, **kwargs)

Alternative constructor for mask-based sampling

Source code in src/torchlinops/linops/sampling.py
@classmethod
def from_mask(cls, mask, *args, **kwargs):
    """Alternative constructor for mask-based sampling"""
    # Convert the boolean mask into explicit per-axis index tensors.
    return cls(F.mask2idx(mask.bool()), *args, **kwargs)

from_stacked_idx classmethod

from_stacked_idx(idx: Tensor, *args, dim=-1, **kwargs)

Alternative constructor for index in [M... D] form

Source code in src/torchlinops/linops/sampling.py
@classmethod
def from_stacked_idx(cls, idx: Tensor, *args, dim=-1, **kwargs):
    """Alternative constructor for index in [M... D] form

    Parameters
    ----------
    idx : Tensor
        Stacked index tensor whose coordinate axis sits at position ``dim``.
    dim : int, optional
        Axis of ``idx`` holding the D coordinates. Default is -1.
    """
    # Bug fix: forward the caller-supplied ``dim`` instead of hard-coding
    # -1, so indices stacked along a non-final axis are split correctly.
    idx = F.canonicalize_idx(idx, dim=dim)
    return cls(idx, *args, **kwargs)

Scalar

Bases: Diagonal

Scalar multiplication operator \(S(x) = \alpha x\).

A special case of Diagonal where the weight is a scalar, making it trivially splittable (the same scalar applies to every tile).

Source code in src/torchlinops/linops/scalar.py
class Scalar(Diagonal):
    """Scalar multiplication operator $S(x) = \\alpha x$.

    A special case of ``Diagonal`` whose weight is a single scalar, so it is
    trivially splittable: every tile of the input sees the same multiplier.
    """

    def __init__(self, weight, ioshape: Optional[Shape] = None):
        """
        Parameters
        ----------
        weight : float or Tensor
            The scalar multiplier $\\alpha$.
        ioshape : Shape, optional
            Named dimensions (same for input and output).
        """
        # Promote plain numbers to tensors before delegating to Diagonal.
        if not isinstance(weight, torch.Tensor):
            weight = torch.tensor(weight)
        super().__init__(weight, ioshape=default_to(("...",), ioshape))

    def split_weight(self, ibatch, obatch, /, weight):
        # The same scalar applies to every tile, so no slicing is required.
        assert ibatch == obatch, "Scalar linop must be split identically"
        return weight

    def size(self, dim: str):
        # A scalar places no constraint on any named dimension.
        return None

__init__

__init__(weight, ioshape: Optional[Shape] = None)
PARAMETER DESCRIPTION
weight

The scalar multiplier \(\alpha\).

TYPE: float or Tensor

ioshape

Named dimensions (same for input and output).

TYPE: Shape DEFAULT: None

Source code in src/torchlinops/linops/scalar.py
def __init__(self, weight, ioshape: Optional[Shape] = None):
    """Create a scalar-multiplication linop.

    Parameters
    ----------
    weight : float or Tensor
        The scalar multiplier $\\alpha$.
    ioshape : Shape, optional
        Named dimensions (same for input and output).
    """
    # Promote plain numbers to tensors before delegating to Diagonal.
    if not isinstance(weight, torch.Tensor):
        weight = torch.tensor(weight)
    super().__init__(weight, ioshape=default_to(("...",), ioshape))

ShapeSpec

Bases: Identity

Identity operator that renames dimensions.

Functionally identical to Identity but maps from one set of named dimensions to another, acting as a shape adapter between linops.

Source code in src/torchlinops/linops/identity.py
class ShapeSpec(Identity):
    """Identity operator that renames dimensions.

    Functionally identical to ``Identity`` but maps from one set of named
    dimensions to another, acting as a shape adapter between linops.
    """

    def adjoint(self):
        # Adjoint of a renaming is the reverse renaming on a shallow copy.
        adj = copy(self)
        adj.shape = self.shape.adjoint()
        return adj

    def normal(self, inner=None):
        if inner is None:
            # A^H A of a pure renaming acts diagonally on the input shape.
            return ShapeSpec(self.ishape, self.ishape)
        # Sandwich the inner linop: rename into it, then rename back out.
        pre, post = copy(self), self.adjoint()
        pre.oshape = inner.ishape
        post.ishape = inner.oshape
        result = post @ inner @ pre
        result._shape_updates = getattr(inner, "_shape_updates", {})
        return result

Stack

Bases: Threadable, NamedLinop

Concatenate some linops along a new dimension.

Linops need not output tensors of the same size, but they should output tensors of the same number of dimensions.

Stacking type depends on dimensions provided:

Horizontal stacking (stacking along an input dimension)::

A B C

Vertical stacking (stacking along an output dimension)::

A
B
C

Diagonal stacking (stacking along separate input and output dimensions)::

A . .
. B .
. . C

Inherits from Threadable to support parallel execution of sub-linops. When threaded=True (default), each sub-linop is executed in parallel using a ThreadPoolExecutor.

Note that shared linops (e.g., Stack(A, A, odim_and_idx=("L", 0))) are automatically shallow-copied to ensure independent identity for threading, while still sharing tensor data. See Threadable for details.

ATTRIBUTE DESCRIPTION
linops

The list of linops being stacked.

TYPE: ModuleList

threaded

Whether to run sub-linops in parallel. Default is True.

TYPE: bool

num_workers

Number of worker threads. If None, defaults to the number of sub-linops.

TYPE: int | None

idim

Input stacking dimension name.

TYPE: NamedDimension | None

idim_idx

Index position of the input stacking dimension.

TYPE: int | None

odim

Output stacking dimension name.

TYPE: NamedDimension | None

odim_idx

Index position of the output stacking dimension.

TYPE: int | None

Source code in src/torchlinops/linops/stack.py
class Stack(Threadable, NamedLinop):
    """Concatenate some linops along a new dimension.

    Linops need not output tensors of the same size, but they should
    output tensors of the same number of dimensions.

    Stacking type depends on dimensions provided:

    Horizontal stacking (stacking along an input dimension)::

        A B C

    Vertical stacking (stacking along an output dimension)::

        A
        B
        C

    Diagonal stacking (stacking along separate input and output dimensions)::

        A . .
        . B .
        . . C

    Inherits from ``Threadable`` to support parallel execution of sub-linops.
    When ``threaded=True`` (default), each sub-linop is executed in parallel
    using a ThreadPoolExecutor.

    Note that shared linops (e.g., ``Stack(A, A, odim_and_idx=("L", 0))``) are
    automatically shallow-copied to ensure independent identity for threading,
    while still sharing tensor data. See ``Threadable`` for details.

    Attributes
    ----------
    linops : nn.ModuleList
        The list of linops being stacked.
    threaded : bool
        Whether to run sub-linops in parallel. Default is True.
    num_workers : int | None
        Number of worker threads. If None, defaults to the number of sub-linops.
    idim : ND | None
        Input stacking dimension name.
    idim_idx : int | None
        Index position of the input stacking dimension.
    odim : ND | None
        Output stacking dimension name.
    odim_idx : int | None
        Index position of the output stacking dimension.
    """

    def __init__(
        self,
        *linops: NamedLinop,
        idim_and_idx: tuple[Optional[ND | str], Optional[int]] = (None, None),
        odim_and_idx: tuple[Optional[ND | str], Optional[int]] = (None, None),
        **kwargs,
    ):
        """
        Parameters
        ----------
        *linops : NamedLinop
            The linops to stack.
        idim_and_idx : tuple, optional
            Tuple of ``(dim_name, index_tensor)`` for the input stacking dimension.
        odim_and_idx : tuple, optional
            Tuple of ``(dim_name, index_tensor)`` for the output stacking dimension.
        """

        self.idim, self.idim_idx, ishape = self._get_dim_and_idx(
            *idim_and_idx, linops[0].ishape
        )
        self.odim, self.odim_idx, oshape = self._get_dim_and_idx(
            *odim_and_idx, linops[0].oshape
        )

        # Initialize parent class
        super().__init__(NS(ishape, oshape), **kwargs)
        self.linops = nn.ModuleList(list(linops))
        self._check_linop_compatibility()

    @staticmethod
    def _get_dim_and_idx(dim, idx, shape):
        if dim is not None:
            dim = ND.infer(dim)
            if dim in shape:
                raise ValueError(
                    f"Stack linop attempting to add dim {dim} to shape {shape} but shape already contains {dim}"
                )
            shape = list(shape)
            shape.insert(idx, dim)
        else:
            dim = None
            idx = None
        return dim, idx, shape

    @staticmethod
    def fn(stack, x, /):
        return stack._fn(
            x,
            stack.linops,
            stack.idim_idx,
            stack.odim_idx,
            stack.threaded,
            stack.num_workers,
        )

    @staticmethod
    def adj_fn(stack, x, /):
        adj_linops = [linop.H for linop in stack.linops]
        return stack._fn(
            x,
            adj_linops,
            stack.odim_idx,
            stack.idim_idx,
            stack.threaded,
            stack.num_workers,
        )

    @staticmethod
    def _fn(
        x: Tensor,
        linops,
        idim_idx,
        odim_idx,
        threaded: bool = False,
        num_workers: int | None = None,
    ):
        """Unifies forward and adjoint functionality for stacked linops"""
        if idim_idx is not None:  # Diagonal, Horizontal
            if len(linops) != x.shape[idim_idx]:
                raise ValueError(
                    f"Stack Linop expecting input of size {len(linops)} at dim {idim_idx} got input of size {x.shape} with non-matching stack size {x.shape[idim_idx]}"
                )
            xs = x.tensor_split(len(linops), idim_idx)
            xs = [xi.squeeze(idim_idx) for xi in xs]
        else:  # Vertical
            xs = [x] * len(linops)

        if odim_idx is not None:  # Diagonal, Vertical
            if threaded:
                ys = _threaded_apply(list(linops), xs, num_workers)
            else:
                ys = [linop(xi) for xi, linop in zip(xs, linops)]
            return torch.stack(ys, dim=odim_idx)

        # Horizontal
        if threaded:
            y = _threaded_apply_sum_reduce(list(linops), xs, num_workers)
        else:
            y = 0
            for xi, linop in zip(xs, linops):
                y += linop(xi)
        return y

    def size(self, dim) -> int | None:
        if dim == self.idim or dim == self.odim:
            return len(self.linops)
        else:
            # https://github.com/pytorch/pytorch/issues/80821
            return self.linops[0].size(dim)  # type: ignore

    def split_forward(self, ibatch, obatch):
        """Split stack linop"""
        linop_idxs = set(range(len(self.linops)))
        for i, slc in enumerate(ibatch):
            if i == self.idim_idx:
                linop_idxs &= set(slice2range(slc, len(self.linops)))
        for i, slc in enumerate(obatch):
            if i == self.odim_idx:
                linop_idxs &= set(slice2range(slc, len(self.linops)))

        if len(linop_idxs) == 0:
            # No linops satisfy this slice (diagonal stacking)
            return Zero(self.ishape, self.oshape)
        linop_idxs = sorted(list(linop_idxs))
        output_linops = []
        # Remove stack dims from slice batch
        if self.idim_idx is not None:
            ibatch = ibatch.copy()
            ibatch.pop(self.idim_idx)
        if self.odim_idx is not None:
            obatch = obatch.copy()
            obatch.pop(self.odim_idx)

        # Slice each sub-linop
        for i in linop_idxs:
            linop = self.linops[i]
            islices = {dim: slc for dim, slc in zip(linop.ishape, ibatch)}
            oslices = {dim: slc for dim, slc in zip(linop.oshape, obatch)}
            slices = strict_update(islices, oslices)
            output_linops.append(linop.split(linop, slices))
        return self.spinoff(output_linops)

    def split_data(self, ibatch, obatch, data_list):
        """Split stack linop, making a new stack linop if necessary

        Parameters
        ----------
        data_list : list, same length as linops
            List of data for each linop in this stack linop

        """
        linop_idxs = set(range(len(self.linops)))
        for i, slc in enumerate(ibatch):
            if i == self.idim_idx:
                linop_idxs &= set(slice2range(slc, len(self.linops)))
        for i, slc in enumerate(obatch):
            if i == self.odim_idx:
                linop_idxs &= set(slice2range(slc, len(self.linops)))

        if len(linop_idxs) == 0:
            # No linops satisfy this slice (diagonal stacking)
            return 0.0  # TODO is this ok
        linop_idxs = sorted(list(linop_idxs))
        output_linop_data = []
        # Remove stack dims from slice batch
        if self.idim_idx is not None:
            ibatch = copy(ibatch)
            ibatch.pop(self.idim_idx)
        if self.odim_idx is not None:
            obatch = copy(obatch)
            obatch.pop(self.odim_idx)

        # Slice each sub-linop
        for i in linop_idxs:
            linop = self.linops[i]
            data = data_list[i]
            output_linop_data.append(linop.split_data(ibatch, obatch, data))
        return output_linop_data

    def adjoint(self):
        adj_linops = [linop.H for linop in self.linops]
        return self.spinoff(
            adj_linops,
            idim_and_idx=(self.odim, self.odim_idx),
            odim_and_idx=(self.idim, self.idim_idx),
        )

    def normal(self, inner=None):
        if inner is None:
            if self.idim is None:  # Vertical (inner product)
                # self.odim is not None
                return Add(*(linop.N for linop in self.linops))
            elif self.odim is None:  # Horizontal (outer product)
                # self.idim is not None
                new_idim, new_odim = self._get_new_normal_io_dims(
                    self._shape, self.idim
                )
                rows = []
                new_shape = self.linops[0].shape.N
                for linop_left in self.linops:
                    row = []
                    for linop_right in self.linops:
                        if linop_left == linop_right:
                            new_linop = linop_right.N
                        else:
                            new_linop = linop_left.H @ linop_right
                            new_linop.ishape = new_shape.ishape
                            new_linop.oshape = new_shape.oshape
                        row.append(new_linop)
                    row = self.spinoff(
                        row,
                        idim_and_idx=(new_idim, self.idim_idx),
                        odim_and_idx=(None, None),
                    )
                    rows.append(row)
                return self.spinoff(
                    rows,
                    idim_and_idx=(None, None),
                    odim_and_idx=(new_odim, self.idim_idx),
                )
            else:  # Diagonal
                # self.idim and self.odim are not None
                diag = []
                new_idim, new_odim = self._get_new_normal_io_dims(
                    self._shape, self.idim
                )
                for linop in self.linops:
                    diag.append(linop.N)
                return self.spinoff(
                    diag,
                    idim_and_idx=(new_idim, self.idim_idx),
                    odim_and_idx=(new_odim, self.odim_idx),
                )
        return super().normal(inner)

    @staticmethod
    def _get_new_normal_io_dims(shape, dim) -> tuple:
        new_shape = shape.N
        i = new_shape.ishape.index(dim)
        new_idim = new_shape.ishape[i]
        new_odim = new_shape.oshape[i]
        return new_idim, new_odim

    def _check_linop_compatibility(self):
        """Ensure linops can actually be concatenated along the requested dimension"""
        target_shape = self.linops[0].shape
        for linop in self.linops:
            if not (
                isequal(target_shape.ishape, linop.ishape)
                and isequal(target_shape.oshape, linop.oshape)
            ):
                raise ValueError(
                    f"Incompatible linops being stacked. Target shape: {target_shape} but got linop shape: {linop.shape}"
                )

    def __getitem__(self, idx):
        linops = self.linops[idx]
        if isinstance(linops, NamedLinop):
            return linops
        return self.spinoff(linops)

    def spinoff(
        self,
        linops=None,
        shape=None,
        idim_and_idx=(None, None),
        odim_and_idx=(None, None),
    ):
        """Helper function for creating a new linop using the provided inputs.

        Preserves settings from the original linop.

        Parameters
        ----------
        linops : list[NamedLinop], optional
            The linops for the new instance. Defaults to self.linops.
        shape : NamedShape, optional
            The shape for the new instance. If None, computed from linops
            and idim/odim.
        idim_and_idx : tuple[ND | None, int | None], optional
            Tuple of (dim_name, index) for input stacking dimension.
            Defaults to (self.idim, self.idim_idx).
        odim_and_idx : tuple[ND | None, int | None], optional
            Tuple of (dim_name, index) for output stacking dimension.
            Defaults to (self.odim, self.odim_idx).
        """
        linops = linops if linops is not None else self.linops

        idim, idim_idx = idim_and_idx
        odim, odim_idx = odim_and_idx

        # Compute shape from linops and dimensions if not provided
        if shape is None:
            _, _, ishape = self._get_dim_and_idx(idim, idim_idx, linops[0].ishape)
            _, _, oshape = self._get_dim_and_idx(odim, odim_idx, linops[0].oshape)
            shape = NS(ishape, oshape)

        new = copy(self)
        new.shape = shape
        new.linops = nn.ModuleList(linops)
        new.idim = idim
        new.idim_idx = idim_idx
        new.odim = odim
        new.odim_idx = odim_idx
        return new

    def __len__(self):
        return len(self.linops)

    def __repr__(self):
        output = ""
        output += INDENT.indent(self.repr_name + f"({self._shape}\n")
        with INDENT:
            for linop in self.linops:
                output += repr(linop) + "\n"
            output += INDENT.indent(f"idim = {self.idim}, odim = {self.odim}\n")
        output += INDENT.indent(")")
        return output

__init__

__init__(
    *linops: NamedLinop,
    idim_and_idx: tuple[
        Optional[NamedDimension | str], Optional[int]
    ] = (None, None),
    odim_and_idx: tuple[
        Optional[NamedDimension | str], Optional[int]
    ] = (None, None),
    **kwargs,
)
PARAMETER DESCRIPTION
*linops

The linops to stack.

TYPE: NamedLinop DEFAULT: ()

idim_and_idx

Tuple of (dim_name, index) for the input stacking dimension: the name of the new dimension and the integer position at which it is inserted into the input shape.

TYPE: tuple DEFAULT: (None, None)

odim_and_idx

Tuple of (dim_name, index) for the output stacking dimension: the name of the new dimension and the integer position at which it is inserted into the output shape.

TYPE: tuple DEFAULT: (None, None)

Source code in src/torchlinops/linops/stack.py
def __init__(
    self,
    *linops: NamedLinop,
    idim_and_idx: tuple[Optional[ND | str], Optional[int]] = (None, None),
    odim_and_idx: tuple[Optional[ND | str], Optional[int]] = (None, None),
    **kwargs,
):
    """
    Parameters
    ----------
    *linops : NamedLinop
        The linops to stack.
    idim_and_idx : tuple, optional
        Tuple of ``(dim_name, index)`` for the input stacking dimension:
        the name of the new dimension and the integer position at which
        it is inserted into the input shape.
    odim_and_idx : tuple, optional
        Tuple of ``(dim_name, index)`` for the output stacking dimension:
        the name of the new dimension and the integer position at which
        it is inserted into the output shape.
    **kwargs
        Forwarded to the parent constructor (e.g. ``Threadable`` options
        such as ``threaded`` and ``num_workers``).
    """

    # Insert the new stacking dims into shapes taken from the first linop.
    self.idim, self.idim_idx, ishape = self._get_dim_and_idx(
        *idim_and_idx, linops[0].ishape
    )
    self.odim, self.odim_idx, oshape = self._get_dim_and_idx(
        *odim_and_idx, linops[0].oshape
    )

    # Initialize parent class
    super().__init__(NS(ishape, oshape), **kwargs)
    self.linops = nn.ModuleList(list(linops))
    self._check_linop_compatibility()

spinoff

spinoff(
    linops=None,
    shape=None,
    idim_and_idx=(None, None),
    odim_and_idx=(None, None),
)

Helper function for creating a new linop using the provided inputs.

Preserves settings from the original linop.

PARAMETER DESCRIPTION
linops

The linops for the new instance. Defaults to self.linops.

TYPE: list[NamedLinop] DEFAULT: None

shape

The shape for the new instance. If None, computed from linops and idim/odim.

TYPE: NamedShape DEFAULT: None

idim_and_idx

Tuple of (dim_name, index) for input stacking dimension. Defaults to (self.idim, self.idim_idx).

TYPE: tuple[NamedDimension | None, int | None] DEFAULT: (None, None)

odim_and_idx

Tuple of (dim_name, index) for output stacking dimension. Defaults to (self.odim, self.odim_idx).

TYPE: tuple[NamedDimension | None, int | None] DEFAULT: (None, None)

Source code in src/torchlinops/linops/stack.py
def spinoff(
    self,
    linops=None,
    shape=None,
    idim_and_idx=(None, None),
    odim_and_idx=(None, None),
):
    """Helper function for creating a new linop using the provided inputs.

    Preserves settings from the original linop.

    Parameters
    ----------
    linops : list[NamedLinop], optional
        The linops for the new instance. Defaults to self.linops.
    shape : NamedShape, optional
        The shape for the new instance. If None, computed from linops
        and idim/odim.
    idim_and_idx : tuple[ND | None, int | None], optional
        Tuple of (dim_name, index) for input stacking dimension.
        Defaults to (None, None), i.e. the new instance has no input
        stacking dimension unless one is passed explicitly.
    odim_and_idx : tuple[ND | None, int | None], optional
        Tuple of (dim_name, index) for output stacking dimension.
        Defaults to (None, None), i.e. the new instance has no output
        stacking dimension unless one is passed explicitly.
    """
    linops = linops if linops is not None else self.linops

    idim, idim_idx = idim_and_idx
    odim, odim_idx = odim_and_idx

    # Compute shape from linops and dimensions if not provided
    if shape is None:
        _, _, ishape = self._get_dim_and_idx(idim, idim_idx, linops[0].ishape)
        _, _, oshape = self._get_dim_and_idx(odim, odim_idx, linops[0].oshape)
        shape = NS(ishape, oshape)

    # Shallow-copy so Threadable bookkeeping and other settings carry over.
    new = copy(self)
    new.shape = shape
    new.linops = nn.ModuleList(linops)
    new.idim = idim
    new.idim_idx = idim_idx
    new.odim = odim
    new.odim_idx = odim_idx
    return new

split_data

split_data(ibatch, obatch, data_list)

Split stack linop, making a new stack linop if necessary

PARAMETER DESCRIPTION
data_list

List of data for each linop in this stack linop

TYPE: list, same length as linops

Source code in src/torchlinops/linops/stack.py
def split_data(self, ibatch, obatch, data_list):
    """Split the per-linop data of a stack linop along the batch slices.

    Parameters
    ----------
    data_list : list, same length as linops
        List of data for each linop in this stack linop

    """
    # Which sub-linops survive the slice along the stacking dims?
    keep = set(range(len(self.linops)))
    for axis, slc in enumerate(ibatch):
        if axis == self.idim_idx:
            keep &= set(slice2range(slc, len(self.linops)))
    for axis, slc in enumerate(obatch):
        if axis == self.odim_idx:
            keep &= set(slice2range(slc, len(self.linops)))

    if len(keep) == 0:
        # No linops satisfy this slice (diagonal stacking)
        return 0.0  # TODO is this ok

    # Strip the stacking dims from the batch slices before recursing.
    if self.idim_idx is not None:
        ibatch = copy(ibatch)
        ibatch.pop(self.idim_idx)
    if self.odim_idx is not None:
        obatch = copy(obatch)
        obatch.pop(self.odim_idx)

    # Recurse into each surviving sub-linop, in index order.
    return [
        self.linops[i].split_data(ibatch, obatch, data_list[i])
        for i in sorted(keep)
    ]

split_forward

split_forward(ibatch, obatch)

Split stack linop

Source code in src/torchlinops/linops/stack.py
def split_forward(self, ibatch, obatch):
    """Split stack linop"""
    # Which sub-linops survive the slice along the stacking dims?
    keep = set(range(len(self.linops)))
    for axis, slc in enumerate(ibatch):
        if axis == self.idim_idx:
            keep &= set(slice2range(slc, len(self.linops)))
    for axis, slc in enumerate(obatch):
        if axis == self.odim_idx:
            keep &= set(slice2range(slc, len(self.linops)))

    if len(keep) == 0:
        # No linops satisfy this slice (diagonal stacking)
        return Zero(self.ishape, self.oshape)

    # Strip the stacking dims from the batch slices before recursing.
    if self.idim_idx is not None:
        ibatch = ibatch.copy()
        ibatch.pop(self.idim_idx)
    if self.odim_idx is not None:
        obatch = obatch.copy()
        obatch.pop(self.odim_idx)

    # Slice each surviving sub-linop, combining input/output slices
    # (strict_update errors on conflicting entries for the same dim).
    sub_linops = []
    for i in sorted(keep):
        linop = self.linops[i]
        islices = dict(zip(linop.ishape, ibatch))
        oslices = dict(zip(linop.oshape, obatch))
        merged = strict_update(islices, oslices)
        sub_linops.append(linop.split(linop, merged))
    return self.spinoff(sub_linops)

SumReduce

Bases: NamedLinop

Sum-reduction operator (adjoint of Repeat).

Wraps einops.reduce with 'sum' reduction. Reduces (sums over) specified dimensions.

Source code in src/torchlinops/linops/einops.py
class SumReduce(NamedLinop):
    """Sum-reduction operator (adjoint of ``Repeat``).

    Wraps ``einops.reduce`` with ``'sum'`` reduction. Reduces (sums over)
    specified dimensions.
    """

    def __init__(self, ishape, oshape):
        """
        Parameters
        ----------
        ishape : Shape
            Input shape spec, einops style.
        oshape : Shape
            Output shape spec, einops style.
        """
        super().__init__(NS(ishape, oshape))
        # A reduction must actually remove at least one dimension.
        assert len(self.oshape) < len(self.ishape), (
            f"Reduce must be over at least one dimension: got {self.ishape} -> {self.oshape}"
        )

    @staticmethod
    def fn(sumreduce, x, /):
        # Sum over every input dim that is absent from the output pattern.
        x = reduce(x, f"{sumreduce.ipattern} -> {sumreduce.opattern}", "sum")
        return x

    @staticmethod
    def adj_fn(sumreduce, x, /):
        # Adjoint of a sum is a broadcast: repeat back along the reduced dims.
        x = repeat(x, f"{sumreduce.opattern} -> {sumreduce.adj_ipattern}")
        return x

    def split_forward(self, ibatch, obatch):
        # Batching does not change the reduction itself; reuse this instance.
        return self

    def size(self, dim: str):
        """Reducing does not determine any dimensions"""
        return None

    def adjoint(self):
        """Return a ``Repeat`` that broadcasts once along each summed-over dim."""
        broadcast_dims = [d for d in self.ishape if d not in self.oshape]
        n_repeats = {d: 1 for d in broadcast_dims}
        return Repeat(n_repeats, self._shape.H, None, broadcast_dims)

    def normal(self, inner=None):
        """Build the normal operator A^H A (or A^H inner A when ``inner`` is given)."""
        pre = copy(self)
        post = self.adjoint()
        # New post output shape (post = Repeat)
        # If dimension is not summed over (i.e. it is in pre_adj_ishape), it stays the same
        # Otherwise, if dimension is summed over, its name changes
        # This automatically updates the axes_lengths as well.
        post.oshape = tuple(
            d.next_unused(self.ishape) if d not in post.ishape else d
            for d in post.oshape
        )
        # Record which dims were renamed so downstream code can track them.
        shape_updates = {
            d: new_d for d, new_d in zip(pre.ishape, post.oshape) if d != new_d
        }
        if inner is not None:
            pre.oshape = inner.ishape
            post.ishape = inner.oshape
            # Merge any renames the inner operator already performed.
            inner_shape_updates = getattr(inner, "_shape_updates", {})
            shape_updates.update(inner_shape_updates)
            normal = post @ inner @ pre
            normal._shape_updates = shape_updates
        else:
            normal = post @ pre
            normal._shape_updates = shape_updates
        return normal

    @property
    def adj_ishape(self):
        # Adjoint input shape: reduced dims become singleton placeholders (None).
        return self.fill_singleton_dims(self.ishape, self.oshape)

    @property
    def adj_ipattern(self):
        # einops pattern for the adjoint input; None renders as "()".
        return " ".join(str(d) if d is not None else "()" for d in self.adj_ishape)

    @property
    def ipattern(self):
        # einops pattern string for the input shape.
        return " ".join(str(d) for d in self.ishape)

    @property
    def opattern(self):
        # einops pattern string for the output shape.
        return " ".join(str(d) for d in self.oshape)

    @staticmethod
    def fill_singleton_dims(ishape, oshape):
        # Keep dims also present in oshape; replace reduced dims with None.
        out = []
        for idim in ishape:
            if idim in oshape:
                out.append(idim)
            else:
                out.append(None)
        return tuple(out)

__init__

__init__(ishape, oshape)
PARAMETER DESCRIPTION
ishape

Input shape spec, einops style.

TYPE: Shape

oshape

Output shape spec, einops style.

TYPE: Shape

Source code in src/torchlinops/linops/einops.py
def __init__(self, ishape, oshape):
    """
    Parameters
    ----------
    ishape : Shape
        Input shape spec, einops style.
    oshape : Shape
        Output shape spec, einops style.
    """
    super().__init__(NS(ishape, oshape))
    # A reduction must actually remove at least one dimension; equal-length
    # shapes would make the reduce pattern invalid.
    assert len(self.oshape) < len(self.ishape), (
        f"Reduce must be over at least one dimension: got {self.ishape} -> {self.oshape}"
    )

size

size(dim: str)

Reducing does not determine any dimensions

Source code in src/torchlinops/linops/einops.py
def size(self, dim: str):
    """Report the size of ``dim`` as unknown.

    A reduction constrains no dimension sizes, so this always returns None.
    """
    return None

Threadable

Mixin to enable parallel execution of sub-linops using Python threads.

When threaded=True, the linop's fn and adj_fn methods will run each sub-linop in parallel using a ThreadPoolExecutor. This is useful when sub-linops are I/O bound or release the GIL (e.g., PyTorch tensor operations).

The mixin manages sub-linops through the linops property, which automatically creates shallow copies of each linop when assigned. This ensures that shared linops (e.g., Add(A, A)) have independent identities for threading while still sharing tensor data.

ATTRIBUTE DESCRIPTION
linops

The list of linops to run in parallel. Setting this property triggers automatic shallow copying and input listener setup.

TYPE: ModuleList

threaded

Whether to run sub-linops in parallel. Default is True.

TYPE: bool

num_workers

Number of worker threads. If None, defaults to the larger of the number of sub-linops and the number of inputs.

TYPE: int | None

settings

Dictionary with threaded and num_workers keys for easy copying of threading configuration.

TYPE: dict

Source code in src/torchlinops/linops/threadable.py
class Threadable:
    """Mixin to enable parallel execution of sub-linops using Python threads.

    When ``threaded=True``, the linop's ``fn`` and ``adj_fn`` methods will run
    each sub-linop in parallel using a ThreadPoolExecutor. This is useful when
    sub-linops are I/O bound or release the GIL (e.g., PyTorch tensor operations).

    The mixin manages sub-linops through the ``linops`` property, which automatically
    creates shallow copies of each linop when assigned. This ensures that shared
    linops (e.g., ``Add(A, A)``) have independent identities for threading while
    still sharing tensor data.

    Attributes
    ----------
    linops : nn.ModuleList
        The list of linops to run in parallel. Setting this property triggers
        automatic shallow copying and input listener setup.
    threaded : bool
        Whether to run sub-linops in parallel. Default is True.
    num_workers : int | None
        Number of worker threads. If None, defaults to the larger of the
        number of sub-linops and the number of inputs.
    settings : dict
        Dictionary with ``threaded`` and ``num_workers`` keys for easy copying
        of threading configuration.
    """

    def __init__(
        self,
        *args,
        threaded: bool = True,
        num_workers: Optional[int] = None,
        **kwargs,
    ):
        """
        Parameters
        ----------
        threaded : bool, optional
            Whether to run sub-linops in parallel. Default is True.
        num_workers : int | None, optional
            Number of worker threads. If None, defaults to the larger of the
            number of sub-linops and the number of inputs when
            ``threaded_apply`` or ``threaded_apply_sum_reduce`` is called.
        linops : list[NamedLinop], optional
            The list of linops to run in parallel. If assigned via the
            ``linops`` property, input listeners will be set up automatically.
        """
        # Cooperative init: forward everything else up the MRO (e.g. NamedLinop).
        super().__init__(*args, **kwargs)
        self.threaded = threaded
        self.num_workers = num_workers

    @property
    def linops(self):
        """The list of sub-linops managed by this Threadable.

        This is a property rather than a direct attribute to intercept assignment
        and perform automatic housekeeping whenever linops are set. The setter
        creates shallow copies of each linop (preserving tensor data sharing)
        and sets up input listeners for event coordination.

        Returns
        -------
        nn.ModuleList
            The list of sub-linops.
        """
        return self._linops

    @linops.setter
    def linops(self, new_linops):
        """Set sub-linops with automatic copying and event setup.

        When linops are assigned, this setter:
        1. Creates shallow copies of each linop using ``copy()``, ensuring
           shared linops have independent identities (for threading safety)
           while still sharing tensor data.
        2. Sets up input listeners on each copied linop.

        Parameters
        ----------
        new_linops : list[NamedLinop]
            The linops to manage.
        """
        self._linops = new_linops
        # Copying + listener wiring happens here, not in the raw assignment above.
        self._setup_events()

    def __setattr__(self, name, value):
        """Set attribute, with special handling for ``linops``.

        PyTorch's ``nn.Module.__setattr__`` intercepts attribute assignment and
        performs special handling for modules, parameters, and buffers. This
        override ensures that ``linops`` assignment goes through the property
        descriptor rather than being intercepted by PyTorch's logic.

        Parameters
        ----------
        name : str
            Attribute name.
        value : Any
            Attribute value.
        """
        if name == "linops":
            # Invoke the property setter explicitly, bypassing nn.Module's
            # module-registration path.
            type(self).linops.fset(self, value)
        else:
            super().__setattr__(name, value)

    def _setup_events(self):
        """Set up input listeners on all sub-linops.

        This method is called automatically when ``linops`` is assigned via
        the property setter. It performs two operations:

        1. Creates shallow copies of each linop using ``copy()``, ensuring that
           linops shared by identity (e.g., ``Add(A, A)``) become independent
           objects while still sharing tensor data. This prevents race conditions
           when the same linop appears multiple times in a threaded context.

        2. Attaches an input listener to each linop, enabling coordination
           between the parent Threadable and its sub-linops.
        """
        self._linops = nn.ModuleList([copy(linop) for linop in self._linops])
        for linop in self._linops:
            # NOTE(review): listener is registered as a (parent, attribute-name)
            # pair — confirm the consumer contract in the linop event machinery.
            linop.input_listener = (self, "input_listener")

    def _apply_defaults(self, x, num_workers):
        """Validate ``linops`` and normalize inputs and worker count.

        Raises
        ------
        AttributeError
            If ``linops`` is missing or empty.
        """
        if not hasattr(self, "linops") or len(self.linops) == 0:
            raise AttributeError("Threadable class must have `linops` attribute.")
        # Normalize a single tensor into a one-element list.
        xs = list(x) if isinstance(x, (list, tuple)) else [x]
        if num_workers is None:
            # One worker per sub-linop or per input, whichever is larger.
            num_workers = max(len(self.linops), len(xs))
        return xs, num_workers

    def threaded_apply_sum_reduce(
        self, x: Tensor | list[Tensor], num_workers: Optional[int] = None
    ) -> Tensor:
        """Wrapper around _threaded_apply_sum_reduce."""
        xs, num_workers = self._apply_defaults(x, num_workers)
        return _threaded_apply_sum_reduce(self.linops, xs, num_workers)

    def threaded_apply(
        self, x: Tensor | list[Tensor], num_workers: Optional[int] = None
    ):
        """Wrapper around _threaded_apply"""
        xs, num_workers = self._apply_defaults(x, num_workers)
        return _threaded_apply(self.linops, xs, num_workers)

    @property
    def settings(self):
        """Get threading settings as a dictionary.

        Returns
        -------
        dict
            Dictionary with ``threaded`` and ``num_workers`` keys.
        """
        return {"threaded": self.threaded, "num_workers": self.num_workers}

    @settings.setter
    def settings(self, new_settings):
        """Set threading settings from a dictionary.

        Parameters
        ----------
        new_settings : dict
            Dictionary with ``threaded`` and ``num_workers`` keys.
        """
        self.threaded = new_settings["threaded"]
        self.num_workers = new_settings["num_workers"]

linops property writable

linops

The list of sub-linops managed by this Threadable.

This is a property rather than a direct attribute to intercept assignment and perform automatic housekeeping whenever linops are set. The setter creates shallow copies of each linop (preserving tensor data sharing) and sets up input listeners for event coordination.

RETURNS DESCRIPTION
ModuleList

The list of sub-linops.

settings property writable

settings

Get threading settings as a dictionary.

RETURNS DESCRIPTION
dict

Dictionary with threaded and num_workers keys.

__init__

__init__(
    *args,
    threaded: bool = True,
    num_workers: Optional[int] = None,
    **kwargs,
)
PARAMETER DESCRIPTION
threaded

Whether to run sub-linops in parallel. Default is True.

TYPE: bool DEFAULT: True

num_workers

Number of worker threads. If None, defaults to the larger of the number of sub-linops and the number of inputs when threaded_apply or threaded_apply_sum_reduce is called.

TYPE: int | None DEFAULT: None

linops

The list of linops to run in parallel. If assigned via the linops property, input listeners will be set up automatically.

TYPE: list[NamedLinop]

Source code in src/torchlinops/linops/threadable.py
def __init__(
    self,
    *args,
    threaded: bool = True,
    num_workers: Optional[int] = None,
    **kwargs,
):
    """
    Parameters
    ----------
    threaded : bool, optional
        Whether to run sub-linops in parallel. Default is True.
    num_workers : int | None, optional
        Number of worker threads. If None, defaults to the larger of the
        number of sub-linops and the number of inputs when
        ``threaded_apply`` or ``threaded_apply_sum_reduce`` is called.
    linops : list[NamedLinop], optional
        The list of linops to run in parallel. If assigned via the
        ``linops`` property, input listeners will be set up automatically.
    """
    # Cooperative init: forward remaining args up the MRO.
    super().__init__(*args, **kwargs)
    self.threaded = threaded
    self.num_workers = num_workers

__setattr__

__setattr__(name, value)

Set attribute, with special handling for linops.

PyTorch's nn.Module.__setattr__ intercepts attribute assignment and performs special handling for modules, parameters, and buffers. This override ensures that linops assignment goes through the property descriptor rather than being intercepted by PyTorch's logic.

PARAMETER DESCRIPTION
name

Attribute name.

TYPE: str

value

Attribute value.

TYPE: Any

Source code in src/torchlinops/linops/threadable.py
def __setattr__(self, name, value):
    """Set attribute, with special handling for ``linops``.

    PyTorch's ``nn.Module.__setattr__`` intercepts attribute assignment and
    performs special handling for modules, parameters, and buffers. This
    override ensures that ``linops`` assignment goes through the property
    descriptor rather than being intercepted by PyTorch's logic.

    Parameters
    ----------
    name : str
        Attribute name.
    value : Any
        Attribute value.
    """
    if name == "linops":
        # Invoke the property setter explicitly, bypassing nn.Module's
        # module-registration path.
        type(self).linops.fset(self, value)
    else:
        super().__setattr__(name, value)

threaded_apply

threaded_apply(
    x: Tensor | list[Tensor],
    num_workers: Optional[int] = None,
)

Wrapper around _threaded_apply

Source code in src/torchlinops/linops/threadable.py
def threaded_apply(
    self, x: Tensor | list[Tensor], num_workers: Optional[int] = None
):
    """Wrapper around _threaded_apply.

    Parameters
    ----------
    x : Tensor | list[Tensor]
        Input(s); a lone tensor is wrapped into a one-element list.
    num_workers : int | None, optional
        Worker-thread count; if None, uses max(#linops, #inputs).
    """
    xs, num_workers = self._apply_defaults(x, num_workers)
    return _threaded_apply(self.linops, xs, num_workers)

threaded_apply_sum_reduce

threaded_apply_sum_reduce(
    x: Tensor | list[Tensor],
    num_workers: Optional[int] = None,
) -> Tensor

Wrapper around _threaded_apply_sum_reduce.

Source code in src/torchlinops/linops/threadable.py
def threaded_apply_sum_reduce(
    self, x: Tensor | list[Tensor], num_workers: Optional[int] = None
) -> Tensor:
    """Wrapper around _threaded_apply_sum_reduce.

    Parameters
    ----------
    x : Tensor | list[Tensor]
        Input(s); a lone tensor is wrapped into a one-element list.
    num_workers : int | None, optional
        Worker-thread count; if None, uses max(#linops, #inputs).
    """
    xs, num_workers = self._apply_defaults(x, num_workers)
    return _threaded_apply_sum_reduce(self.linops, xs, num_workers)

ToDevice

Bases: NamedLinop

Transfer tensors between devices as a named linear operator.

The forward operation moves a tensor from idevice to odevice. The adjoint reverses the direction. The normal \(T^H T\) is the identity (device round-trip is lossless).

For CUDA-to-CUDA transfers, streams and events are used for asynchronous pipelined execution.

ATTRIBUTE DESCRIPTION
ispec

Source (input) device specification containing device and stream info.

TYPE: DeviceSpec

ospec

Target (output) device specification containing device and stream info.

TYPE: DeviceSpec

is_gpu2gpu

True if both source and target devices are CUDA devices.

TYPE: bool

Source code in src/torchlinops/linops/device.py
class ToDevice(NamedLinop):
    """Transfer tensors between devices as a named linear operator.

    The forward operation moves a tensor from ``idevice`` to ``odevice``.
    The adjoint reverses the direction. The normal $T^H T$ is the identity
    (device round-trip is lossless).

    For CUDA-to-CUDA transfers, streams and events are used for asynchronous
    pipelined execution.

    Attributes
    ----------
    ispec : DeviceSpec
        Source (input) device specification containing device and stream info.
    ospec : DeviceSpec
        Target (output) device specification containing device and stream info.
    is_gpu2gpu : bool
        True if both source and target devices are CUDA devices.
    """

    def __init__(
        self,
        idevice: DeviceSpec | torch.device | None,
        odevice: DeviceSpec | torch.device | None,
        ioshape: Optional[Shape] = None,
    ):
        """
        Parameters
        ----------
        idevice : DeviceSpec | torch.device | None
            Source (input) device specification. None defaults to CPU.
        odevice : DeviceSpec | torch.device | None
            Target (output) device specification. None defaults to CPU.
        ioshape : Shape, optional
            Named dimensions (same for input and output since this is diagonal).
        """
        super().__init__(NS(ioshape))

        # Coerce bare torch.device (or None -> CPU) into DeviceSpec wrappers.
        idevice = default_to(torch.device("cpu"), idevice)
        odevice = default_to(torch.device("cpu"), odevice)
        if not isinstance(idevice, DeviceSpec):
            self.ispec = DeviceSpec(idevice)
        else:
            self.ispec = idevice
        if not isinstance(odevice, DeviceSpec):
            self.ospec = DeviceSpec(odevice)
        else:
            self.ospec = odevice

        # Perform any necessary setup for data transfer between these devices.
        self.ispec.p2p_setup(self.ospec.device)
        self.ospec.p2p_setup(self.ispec.device)

        if (
            self.ispec.device.type == "cuda" and self.ospec.device.type == "cuda"
        ):  # pragma: no cover
            self.is_gpu2gpu = True
        else:
            self.is_gpu2gpu = False

    @staticmethod
    def _fn(
        x,
        ispec: DeviceSpec,
        ospec: DeviceSpec,
        input_listener: Optional[Event] = None,
    ):
        """Move ``x`` from ``ispec``'s device to ``ospec``'s device.

        Raises
        ------
        RuntimeError
            If ``x`` is not already on the expected source device.
        """
        idevice, odevice = ispec.device, ospec.device
        if x.device != idevice:
            raise RuntimeError(
                f"Got input to ToDevice on {x.device} but expected {idevice}"
            )

        # GPU -> GPU
        if idevice.type == "cuda" and odevice.type == "cuda":  # pragma: no cover
            if input_listener is None:
                warn(
                    "Peer-to-peer device transfer with input_listener = None detected. Results may not be accurate."
                )
            return _gpu2gpu_transfer(
                x,
                ospec.compute_stream,
                ispec.transfer_stream,
                input_listener,
            )
        elif idevice.type == "cuda" and odevice.type == "cpu":  # pragma: no cover
            # GPU -> CPU requires additional synchronization, see:
            # https://github.com/pytorch/pytorch/issues/127612
            return x.to(odevice, non_blocking=False)

        # CPU -> GPU or CPU -> CPU
        return x.to(odevice, non_blocking=True)

    @staticmethod
    def fn(todevice, x, /):
        # Forward: source -> target.
        return todevice._fn(
            x,
            todevice.ispec,
            todevice.ospec,
            todevice.input_listener,
        )

    @staticmethod
    def adj_fn(todevice, x, /):
        # Adjoint: swap the specs so the transfer runs target -> source.
        return todevice._fn(
            x,
            todevice.ospec,
            todevice.ispec,
            todevice.input_listener,
        )

    def adjoint(self):
        """Return a shallow copy with shapes and device specs swapped."""
        adj = copy(self)
        adj._shape = adj._shape.H
        adj.ispec, adj.ospec = self.ospec, self.ispec
        return adj

    def normal(self, inner=None):
        """T^H T is identity: a round-trip transfer changes no values."""
        if inner is None:
            return Identity()
        return super().normal(inner)

    def split_forward(self, ibatch, obatch):
        """Return a new instance"""
        return copy(self)

    def __repr__(self):
        """Helps prevent recursion error caused by .H and .N"""
        # Show raw CUDA stream handles (hex) when streams are configured.
        if (
            self.ispec.compute_stream is not None
            or self.ispec.transfer_stream is not None
        ):  # pragma: no cover
            irepr = f"{self.ispec.device}, compute: 0x{self.ispec.compute_stream.cuda_stream:x}, transfer: 0x{self.ispec.transfer_stream.cuda_stream:x}"
        else:
            irepr = f"{self.ispec.device}"
        if (
            self.ospec.compute_stream is not None
            or self.ospec.transfer_stream is not None
        ):  # pragma: no cover
            orepr = f"{self.ospec.device}, compute: 0x{self.ospec.compute_stream.cuda_stream:x}, transfer: 0x{self.ospec.transfer_stream.cuda_stream:x}"
        else:
            orepr = f"{self.ospec.device}"
        if self.input_listener is not None and self.is_gpu2gpu:  # pragma: no cover
            input_listener_repr = f"on: {self.input_listener.event_id:x}"
        else:
            input_listener_repr = ""
        out = f"({input_listener_repr} | {irepr} -> {orepr})"
        out = INDENT.indent(out)
        return out

__init__

__init__(
    idevice: DeviceSpec | device | None,
    odevice: DeviceSpec | device | None,
    ioshape: Optional[Shape] = None,
)
PARAMETER DESCRIPTION
idevice

Source (input) device specification.

TYPE: DeviceSpec | device | None

odevice

Target (output) device specification.

TYPE: DeviceSpec | device | None

ioshape

Named dimensions (same for input and output since this is diagonal).

TYPE: Shape DEFAULT: None

Source code in src/torchlinops/linops/device.py
def __init__(
    self,
    idevice: DeviceSpec | torch.device | None,
    odevice: DeviceSpec | torch.device | None,
    ioshape: Optional[Shape] = None,
):
    """
    Parameters
    ----------
    idevice : DeviceSpec | torch.device | None
        Source (input) device specification. None defaults to CPU.
    odevice : DeviceSpec | torch.device | None
        Target (output) device specification. None defaults to CPU.
    ioshape : Shape, optional
        Named dimensions (same for input and output since this is diagonal).
    """
    super().__init__(NS(ioshape))

    # Coerce bare torch.device (or None -> CPU) into DeviceSpec wrappers.
    idevice = default_to(torch.device("cpu"), idevice)
    odevice = default_to(torch.device("cpu"), odevice)
    if not isinstance(idevice, DeviceSpec):
        self.ispec = DeviceSpec(idevice)
    else:
        self.ispec = idevice
    if not isinstance(odevice, DeviceSpec):
        self.ospec = DeviceSpec(odevice)
    else:
        self.ospec = odevice

    # Perform any necessary setup for data transfer between these devices.
    self.ispec.p2p_setup(self.ospec.device)
    self.ospec.p2p_setup(self.ispec.device)

    if (
        self.ispec.device.type == "cuda" and self.ospec.device.type == "cuda"
    ):  # pragma: no cover
        self.is_gpu2gpu = True
    else:
        self.is_gpu2gpu = False

__repr__

__repr__()

Helps prevent recursion error caused by .H and .N

Source code in src/torchlinops/linops/device.py
def __repr__(self):
    """Helps prevent recursion error caused by .H and .N"""
    # Show raw CUDA stream handles (hex) when streams are configured.
    if (
        self.ispec.compute_stream is not None
        or self.ispec.transfer_stream is not None
    ):  # pragma: no cover
        irepr = f"{self.ispec.device}, compute: 0x{self.ispec.compute_stream.cuda_stream:x}, transfer: 0x{self.ispec.transfer_stream.cuda_stream:x}"
    else:
        irepr = f"{self.ispec.device}"
    if (
        self.ospec.compute_stream is not None
        or self.ospec.transfer_stream is not None
    ):  # pragma: no cover
        orepr = f"{self.ospec.device}, compute: 0x{self.ospec.compute_stream.cuda_stream:x}, transfer: 0x{self.ospec.transfer_stream.cuda_stream:x}"
    else:
        orepr = f"{self.ospec.device}"
    # Listener id only matters for asynchronous GPU-to-GPU transfers.
    if self.input_listener is not None and self.is_gpu2gpu:  # pragma: no cover
        input_listener_repr = f"on: {self.input_listener.event_id:x}"
    else:
        input_listener_repr = ""
    out = f"({input_listener_repr} | {irepr} -> {orepr})"
    out = INDENT.indent(out)
    return out

split_forward

split_forward(ibatch, obatch)

Return a new instance

Source code in src/torchlinops/linops/device.py
def split_forward(self, ibatch, obatch):
    """Device transfer is diagonal; any batch slice yields a fresh copy."""
    sliced = copy(self)
    return sliced

Truncate

Bases: NamedLinop

Truncation (slicing) operator along the last dimension.

Extracts a contiguous slice from the input. The adjoint zero-pads back to the original size.

Source code in src/torchlinops/linops/trunc_pad.py
class Truncate(NamedLinop):
    """Truncation (slicing) operator along the last dimension.

    Extracts a contiguous slice from the input. The adjoint zero-pads
    back to the original size.
    """

    def __init__(
        self,
        dim: int,
        from_length: int,
        to_length: int,
        ishape: Shape,
        oshape: Shape,
    ):
        """
        Parameters
        ----------
        dim : int
            Dimension index to truncate along.
        from_length : int
            Size of ``dim`` on the input; must be nonnegative.
        to_length : int
            Size of ``dim`` on the output; must lie in ``[0, from_length]``.
        ishape : Shape
            Named input shape.
        oshape : Shape
            Named output shape.
        """
        self.dim = dim
        self.from_length = from_length
        self.to_length = to_length
        if self.from_length < 0:
            raise ValueError(
                f"from_length must be nonnegative but got {self.from_length}"
            )
        if self.to_length < 0 or self.to_length > from_length:
            raise ValueError(
                f"to_length must be in [0, {self.from_length}] but got {self.to_length}"
            )
        # Forward slice: keep the leading ``to_length`` entries along ``dim``.
        self.slc = [slice(None)] * len(ishape)
        self.slc[dim] = slice(0, self.to_length)
        self.slc = tuple(self.slc)

        # Tail region discarded by truncation; the normal op zeroes it in place.
        self.end_slc = [slice(None)] * len(ishape)
        self.end_slc[dim] = slice(self.to_length - self.from_length, None)
        self.end_slc = tuple(self.end_slc)
        super().__init__(NS(ishape, oshape))

    @staticmethod
    def fn(truncate, x, /):
        if x.shape[truncate.dim] != truncate.from_length:
            raise ValueError(
                f"Truncate expecting size {truncate.from_length} at x.shape[{truncate.dim}] but got {x.shape[truncate.dim]} (x.shape: {x.shape})"
            )
        return x[truncate.slc]

    @staticmethod
    def adj_fn(truncate, y, /):
        # Adjoint of slicing is zero-padding back to the original length.
        if y.shape[truncate.dim] != truncate.to_length:
            raise ValueError(
                f"Truncate (adjoint) expecting size {truncate.to_length} at x.shape[{truncate.dim}] but got {y.shape[truncate.dim]} (y.shape: {y.shape})"
            )
        return end_pad_with_zeros(
            y,
            truncate.dim,
            truncate.from_length - truncate.to_length,
        )

    @staticmethod
    def normal_fn(truncate, x, /):
        # A^H A keeps the head and zeroes the tail; clone to avoid mutating input.
        if x.shape[truncate.dim] != truncate.from_length:
            raise ValueError(
                f"Truncate (normal) expecting size {truncate.from_length} at x.shape[{truncate.dim}] but got {x.shape[truncate.dim]} (x.shape: {x.shape})"
            )
        x = x.clone()
        x[truncate.end_slc] = 0.0
        return x

    def split_forward(self, ibatch, obatch):
        # Slicing along the truncated dim would change the truncation itself.
        if ibatch[self.dim] != slice(None) or obatch[self.dim] != slice(None):
            raise ValueError("Cannot slice a Truncate linop along truncation dimension")
        return type(self)(
            self.dim, self.from_length, self.to_length, self.ishape, self.oshape
        )

    def adjoint(self):
        """Return the zero-padding operator that inverts this truncation's shape."""
        return PadDim(
            self.dim,
            self.to_length,
            self.from_length,
            self.oshape,
            self.ishape,
        )

    def normal(self, inner=None):
        """Diagonal in all dims except the last one"""
        pre = copy(self)
        post = self.adjoint()
        if inner is None:
            return post @ pre
        # Thread the inner operator's shapes through the truncate/pad pair.
        pre.oshape = inner.ishape
        post.ishape = inner.oshape
        new_oshape = list(inner.oshape)
        new_oshape[self.dim] = post.oshape[self.dim]
        post.oshape = tuple(new_oshape)
        return post @ inner @ pre

normal

normal(inner=None)

Diagonal in all dims except the last one

Source code in src/torchlinops/linops/trunc_pad.py
def normal(self, inner=None):
    """Diagonal in all dims except the last one"""
    pre = copy(self)
    post = self.adjoint()
    if inner is None:
        return post @ pre
    # Thread the inner operator's shapes through the truncate/pad pair.
    pre.oshape = inner.ishape
    post.ishape = inner.oshape
    new_oshape = list(inner.oshape)
    # The truncated dim's output name/size comes from the adjoint (pad) output.
    new_oshape[self.dim] = post.oshape[self.dim]
    post.oshape = tuple(new_oshape)
    return post @ inner @ pre

Zero

Bases: NamedLinop

Zero operator \(0(x) = 0\).

Always returns a zero tensor with the same shape as the input.

Source code in src/torchlinops/linops/identity.py
class Zero(NamedLinop):
    """Zero operator $0(x) = 0$.

    Always returns a zero tensor with the same shape as the input.
    Note: the input tensor is zeroed *in place* via ``Tensor.zero_`` and
    returned, rather than allocating a fresh tensor.
    """

    def __init__(self, ishape=("...",), oshape=None):
        super().__init__(NS(ishape, oshape))

    @staticmethod
    def fn(zero, x, /):
        # In place: the caller's tensor is zeroed and returned.
        return x.zero_()

    @staticmethod
    def adj_fn(zero, x, /):
        # Adjoint of zero is zero.
        return x.zero_()

    @staticmethod
    def normal_fn(zero, x, /):
        # Normal (0^H 0) is also zero.
        return x.zero_()

    def split_forward(self, ibatch, obatch):
        # Any slice of a zero operator is the same zero operator.
        return self

Crop

Crop(
    crop_im_size, im_size, in_shape, out_shape, batch_shape
)
Source code in src/torchlinops/linops/pad_last.py
def Crop(crop_im_size, im_size, in_shape, out_shape, batch_shape):
    """Build a crop operator as the adjoint of the corresponding pad.

    Cropping from ``im_size`` down to ``crop_im_size`` is obtained by
    constructing the pad from ``crop_im_size`` up to ``im_size`` (with the
    in/out shape roles swapped) and taking its adjoint via ``.H``.
    """
    pad = Pad(im_size, crop_im_size, out_shape, in_shape, batch_shape)
    return pad.H

clear_transfer_streams_registry

clear_transfer_streams_registry() -> None

Clear the transfer streams registry.

This is useful for testing to ensure a clean state between tests. The registry caches CUDA streams to enable reuse across transfers.

Source code in src/torchlinops/linops/device.py
def clear_transfer_streams_registry() -> None:
    """Empty the cached transfer-streams registry.

    Primarily intended for tests that need a clean slate between runs;
    the registry caches CUDA streams so they can be reused across
    transfers.
    """
    _TRANSFER_STREAMS_REGISTRY.clear()

create_batched_linop

create_batched_linop(
    linop,
    batch_specs: BatchSpec | list[BatchSpec],
    default_device: device = None,
    _mmap=None,
)

Split and distribute a linop across devices according to batch specs.

Recursively processes a list of BatchSpec objects: the first spec splits the linop into tiles, optionally places each tile on a target device, then passes remaining specs to each tile recursively. Tiles are reassembled via Concat (for partitioned dimensions) or Add (for reduced dimensions).

PARAMETER DESCRIPTION
linop

The operator to split and distribute.

TYPE: NamedLinop

batch_specs

One or more batch specifications to apply (processed in order).

TYPE: BatchSpec or list[BatchSpec]

_mmap

Internal memory map for efficient device transfers. Created automatically on the first call. Probably don't set this manually.

TYPE: ModuleMemoryMap DEFAULT: None

default_device

The default device to use if no device info is provided in the batch spec.

TYPE: device

RETURNS DESCRIPTION
NamedLinop

A composite linop (tree of Concat/Add/ToDevice operators) that is functionally equivalent to the original but distributed according to the batch specs.

Source code in src/torchlinops/linops/split.py
def create_batched_linop(
    linop,
    batch_specs: BatchSpec | list[BatchSpec],
    default_device: torch.device = None,
    _mmap=None,
):
    """Split and distribute a linop across devices according to batch specs.

    Recursively processes a list of ``BatchSpec`` objects: the first spec
    splits the linop into tiles, optionally places each tile on a target
    device, then passes remaining specs to each tile recursively. Tiles are
    reassembled via ``Concat`` (for partitioned dimensions) or ``Add`` (for
    reduced dimensions).

    Parameters
    ----------
    linop : NamedLinop
        The operator to split and distribute.
    batch_specs : BatchSpec or list[BatchSpec]
        One or more batch specifications to apply (processed in order).
    default_device : torch.device, optional
        The default device to use if no device info is provided in the batch
        spec. Defaults to CPU when None.
    _mmap : ModuleMemoryMap, optional
        Internal memory map for efficient device transfers. Created
        automatically on the first call. Probably don't set this manually.

    Returns
    -------
    NamedLinop
        A composite linop (tree of ``Concat``/``Add``/``ToDevice`` operators)
        that is functionally equivalent to the original but distributed
        according to the batch specs.
    """
    if default_device is None:
        default_device = torch.device("cpu")
    if isinstance(batch_specs, BatchSpec):
        # Ensure list
        batch_specs = [batch_specs]
    if _mmap is None:
        # Outermost call: build the memory map once and reuse it for all
        # recursive calls below.
        _mmap = ModuleMemoryMap()
        _mmap.register_module(linop)
    if len(batch_specs) == 0:
        # Recursive ending
        return linop
    # Deep-copy so that filling in defaults does not mutate the caller's spec.
    batch_spec = deepcopy(batch_specs[0])
    # Set defaults
    batch_spec.base_device = default_to(default_device, batch_spec.base_device)
    batch_spec.device_matrix = default_to(
        np.array([default_device]), batch_spec.device_matrix
    )

    # Split linop into tiles and broadcast device spec to the tile array.
    linops, ibatches, obatches = split_linop(linop, batch_spec.batch_sizes)
    device_matrix = batch_spec.broadcast_device_matrix(linop)
    if device_matrix.shape != linops.shape:
        raise ValueError(
            f"device_matrix and linops should have same shape after broadcasting, but got device_matrix: {device_matrix.shape} and linops: {linops.shape}"
        )

    # Data enters and leaves every tile on the spec's base device.
    source_device = batch_spec.base_device
    # Process each tile: recurse with the remaining specs, then place it.
    for idx in np.ndindex(linops.shape):
        linop, target_device = linops[idx], device_matrix[idx]

        # Recursive call to batch the tile
        tiled_linop = create_batched_linop(
            linop, batch_specs[1:], default_device=target_device, _mmap=_mmap
        )

        # Move linop to device
        tiled_linop = _mmap.memory_aware_to(tiled_linop, target_device)

        # Wrap with device movement linops
        if source_device != target_device:
            # Sandwich the tile between transfers so it consumes input from
            # source_device and delivers output back to source_device.
            tiled_linop = Chain(
                ToDevice(
                    source_device,
                    target_device,
                    ioshape=tiled_linop.ishape,
                ),
                tiled_linop,
                ToDevice(
                    target_device,
                    source_device,
                    ioshape=tiled_linop.oshape,
                ),
            )

        # Overwrite entry in linops
        linops[idx] = tiled_linop

    # Reassemble tiles one batched dimension at a time, collapsing the last
    # axis of the tile array on each pass.
    for dim in reversed(batch_spec.batch_sizes):
        # Manual axis reduction because I made Concat and Add too nice
        flat_linops = linops.reshape(-1, linops.shape[-1])
        new_linops = np.empty(flat_linops.shape[0], dtype=object)
        for i, linop_arr in enumerate(flat_linops):
            linop = linop_arr[0]
            if dim in linop.ishape and dim in linop.oshape:
                # dim appears on both sides: concatenate inputs and outputs.
                new_linop = Concat(*linop_arr, idim=dim, odim=dim)
            elif dim not in linop.ishape and dim in linop.oshape:
                # dim only in the output shape.
                new_linop = Concat(*linop_arr, odim=dim)
            elif dim in linop.ishape and dim not in linop.oshape:
                # dim only in the input shape.
                new_linop = Concat(*linop_arr, idim=dim)
            else:
                # dim was reduced away by the tiles: sum the partial results.
                new_linop = Add(*linop_arr)
            new_linops[i] = new_linop
        linops = new_linops.reshape(linops.shape[:-1])
    # All batch dims reduced: a single composite linop remains.
    linop = linops.item()
    return linop

default_to

default_to(*vals, typecast: bool = False)

Get the first non-None value, right to left order.

Most "default" value goes first.

Source code in src/torchlinops/utils/_defaults.py
def default_to(*vals, typecast: bool = False):
    """Get the first non-None value, right to left order.

    Most "default" value goes first; later arguments override earlier ones
    when they are not None.

    Parameters
    ----------
    *vals
        Candidate values, ordered from most-default to most-specific.
    typecast : bool, optional
        If True, cast the chosen value to the type of the first non-None
        candidate. Default is False.

    Returns
    -------
    The last non-None value (optionally typecast), or None if no arguments
    are given or every argument is None.
    """
    if len(vals) == 0:
        return None
    if len(vals) == 1:
        return vals[0]
    for val in reversed(vals):
        if val is not None:
            if typecast:
                # Cast to the type of the first non-None candidate.
                # (Previously this used type(vals[0]) unconditionally, which
                # raised TypeError when vals[0] was None.)
                for ref in vals:
                    if ref is not None:
                        return type(ref)(val)
            return val
    # Every argument was None.
    return None

get_nd_shape

get_nd_shape(im_size, kspace=False)

Return spatial dimension names for a given image size.

Maps a 1-D, 2-D, or 3-D image size tuple to the corresponding named dimension tuple (e.g. ('Nx', 'Ny') or ('Kx', 'Ky')).

PARAMETER DESCRIPTION
im_size

Image size tuple whose length (1, 2, or 3) determines the spatial dimensionality.

TYPE: tuple

kspace

If True, return k-space dimension names (Kx, Ky, …) instead of image-space names (Nx, Ny, …). Defaults to False.

TYPE: bool DEFAULT: False

RETURNS DESCRIPTION
tuple of str

Named dimension strings for each spatial axis.

RAISES DESCRIPTION
ValueError

If im_size does not have length 1, 2, or 3.

Source code in src/torchlinops/nameddim/_shapes.py
def get_nd_shape(im_size, kspace=False):
    """Return spatial dimension names for a given image size.

    Maps a 1-D, 2-D, or 3-D image size tuple to the corresponding named
    dimension tuple (e.g. ``('Nx', 'Ny')`` or ``('Kx', 'Ky')``).

    Parameters
    ----------
    im_size : tuple
        Image size tuple whose length (1, 2, or 3) determines the spatial
        dimensionality.
    kspace : bool, optional
        If ``True``, return k-space dimension names (``Kx``, ``Ky``, …)
        instead of image-space names (``Nx``, ``Ny``, …).  Defaults to
        ``False``.

    Returns
    -------
    tuple of str
        Named dimension strings for each spatial axis.

    Raises
    ------
    ValueError
        If ``im_size`` does not have length 1, 2, or 3.
    """
    ndim = len(im_size)
    if not 1 <= ndim <= 3:
        # Bug fix: the error message previously said "length 2 or 3" even
        # though length 1 is accepted.
        raise ValueError(f"Image size {im_size} - should have length 1, 2, or 3")
    # Slice the full 3-D name tuple down to the requested dimensionality.
    dims = ("Kx", "Ky", "Kz") if kspace else ("Nx", "Ny", "Nz")
    return dims[:ndim]

isequal

isequal(
    shape1: Sequence,
    shape2: Sequence,
    return_assignments: bool = False,
) -> bool | tuple[bool, Optional[dict[int, list]]]

Test if two sequences with ellipses are length-compatible and value-compatible.

Implemented with bottom-up DP

PARAMETER DESCRIPTION
shape1

The sequences of tokens to compare.

TYPE: Sequence

shape2

The sequences of tokens to compare.

TYPE: Sequence

ELLIPSES

The wildcard that can match any number of tokens.

TYPE: str DEFAULT: "..."

RETURNS DESCRIPTION
bool

Whether shape1 and shape2 are compatible.

Examples:

>>> isequal(("A", "B"), ("A", "B"))
True
>>> isequal(("A", "C"), ("A",))
False
>>> isequal(("A", "C"), tuple())
False
>>> isequal(("A", "C"), ("...",))
True
>>> isequal(("A", "C", "..."), ("...",))
True
>>> isequal(("A", "B", "C"), ("A", "...", "C"))
True
>>> isequal(("...", "A", "C", "..."), ("...",))
True
>>> isequal(("...", "A", "C"), ("B", "C"))
False

Wildcards

>>> isequal(("A", "B"), ("A", "()"))
True
>>> isequal(("A",), ("()", "()"))
False

Think about this one...

>>> isequal(("...", "A", "C", "..."), ("...", "A"))
True
Source code in src/torchlinops/nameddim/_matching.py
def isequal(
    shape1: Sequence,
    shape2: Sequence,
    return_assignments: bool = False,
) -> bool | tuple[bool, Optional[dict[int, list]]]:
    """Test if two sequences with ellipses are length-compatible and value-compatible.

    Implemented with bottom-up DP

    Parameters
    ----------
    shape1, shape2 : Sequence
        The sequences of tokens to compare.
    ELLIPSES : str, default = "..."
        The wildcard that can match any number of tokens.

    Returns
    -------
    bool
        Whether shape1 and shape2 are compatible.

    Examples
    --------

    >>> isequal(("A", "B"), ("A", "B"))
    True
    >>> isequal(("A", "C"), ("A",))
    False
    >>> isequal(("A", "C"), tuple())
    False
    >>> isequal(("A", "C"), ("...",))
    True
    >>> isequal(("A", "C", "..."), ("...",))
    True
    >>> isequal(("A", "B", "C"), ("A", "...", "C"))
    True
    >>> isequal(("...", "A", "C", "..."), ("...",))
    True
    >>> isequal(("...", "A", "C"), ("B", "C"))
    False

    # Wildcards
    >>> isequal(("A", "B"), ("A", "()"))
    True
    >>> isequal(("A",), ("()", "()"))
    False

    # Think about this one...
    >>> isequal(("...", "A", "C", "..."), ("...", "A"))
    True
    """
    ptrs = [[(0, 0) for _ in range(len(shape2) + 1)] for _ in range(len(shape1) + 1)]
    # Base cases
    ptrs[0][0] = (0, 0)  # True (note that bool(tuple()) == False)
    for i in range(1, len(shape1) + 1):
        ptrs[i][0] = (-1, 0) if shape1[0] == ELLIPSES else None
    for j in range(1, len(shape2) + 1):
        ptrs[0][j] = (0, -1) if shape2[0] == ELLIPSES else None
    for i in range(1, len(shape1) + 1):
        for j in range(1, len(shape2) + 1):
            if ptrs[i - 1][j - 1]:
                if shape1[i - 1] == ELLIPSES or shape2[j - 1] == ELLIPSES:
                    val = (-1, -1)
                elif shape1[i - 1] == shape2[j - 1]:
                    val = (-1, -1)
                elif shape1[i - 1] == ANY or shape2[j - 1] == ANY:
                    val = (-1, -1)
                else:
                    val = None
            elif ptrs[i - 1][j]:
                if shape2[j - 1] == ELLIPSES:
                    val = (-1, 0)
                else:
                    val = None
            elif ptrs[i][j - 1]:
                if shape1[i - 1] == ELLIPSES:
                    val = (0, -1)
                else:
                    val = None
            else:
                val = None
            ptrs[i][j] = val

    if return_assignments:
        if not ptrs[-1][-1]:
            return False, None
        # Traverse in reverse order
        assignments = defaultdict(list)
        row, col = len(shape1), len(shape2)
        while row > 0 and col > 0:
            assignments[row - 1].append(col - 1)
            drow, dcol = ptrs[row][col]
            row = row + drow
            col = col + dcol
        return True, assignments
    return bool(ptrs[-1][-1])

split_linop

split_linop(
    linop: NamedLinop,
    batch_sizes: dict[NamedDimension | str, int],
)

Split a linop into an nd-array of sub-linops according to batch sizes.

PARAMETER DESCRIPTION
linop

The linop to be split.

TYPE: NamedLinop

batch_sizes

Dictionary mapping dimension names to chunk sizes.

TYPE: dict[NamedDimension | str, int]

RETURNS DESCRIPTION
linops

Array of sub-linops with shape determined by the number of tiles per dimension.

TYPE: ndarray

input_batches

Corresponding input slices for each tile.

TYPE: ndarray

output_batches

Corresponding output slices for each tile.

TYPE: ndarray

Source code in src/torchlinops/linops/split.py
def split_linop(linop: NamedLinop, batch_sizes: dict[ND | str, int]):
    """Split a linop into an nd-array of sub-linops according to batch sizes.

    Parameters
    ----------
    linop : NamedLinop
        The linop to be split.
    batch_sizes : dict[ND | str, int]
        Dictionary mapping dimension names to chunk sizes.

    Returns
    -------
    linops : np.ndarray
        Array of sub-linops with shape determined by the number of tiles
        per dimension.
    input_batches : np.ndarray
        Corresponding input slices for each tile.
    output_batches : np.ndarray
        Corresponding output slices for each tile.
    """
    # Precompute sizes and shapes
    # Normalize dimension keys to named-dimension objects.
    batch_sizes = {ND.infer(k): v for k, v in batch_sizes.items()}
    sizes = {dim: linop.size(dim) for dim in linop.dims}

    # Make tiles. Each tile is a dictionary mapping a dimension to an integer
    # index of the tile and a slice over that dimension.
    batch_iterators = make_batch_iterators(sizes, batch_sizes)
    tiles: list[dict[ND, Batch]] = list(dict_product(batch_iterators))

    # Allocate outputs
    batch_dims = list(batch_sizes.keys())
    # One tile per ceil(size / batch_size) chunk along each batched dimension.
    tiled_shape = tuple(ceil(sizes[dim] / batch_sizes[dim]) for dim in batch_dims)
    linops = np.ndarray(tiled_shape, dtype=object)
    input_batches = np.ndarray(tiled_shape, dtype=object)
    output_batches = np.ndarray(tiled_shape, dtype=object)

    for tile in tiles:
        # Locate this tile in the output grid and slice the linop accordingly.
        idx = _tile_get_idx(tile, batch_dims)
        linop_tile = _split_linop_with_tile(linop, tile)
        linop_flat = linop_tile.flatten()
        # Input slices come from the first linop in the flattened tile and
        # output slices from the last.
        first_linop, last_linop = linop_flat[0], linop_flat[-1]
        linops[idx] = linop_tile
        # Unbatched dims fall back to DEFAULT_BATCH; element [1] is the slice.
        input_batches[idx] = [
            tile.get(dim, DEFAULT_BATCH)[1] for dim in first_linop.ishape
        ]
        output_batches[idx] = [
            tile.get(dim, DEFAULT_BATCH)[1] for dim in last_linop.oshape
        ]
    return linops, input_batches, output_batches