Skip to content

Concat & Stack

torchlinops.linops.Concat

Bases: Threadable, NamedLinop

Concatenate some linops along an existing dimension.

Linops need not output tensors of the same size, but they should output tensors of the same number of dimensions.

Stacking type depends on dimensions provided:

Horizontal stacking (stacking along an input dimension)::

A B C

Vertical stacking (stacking along an output dimension)::

A
B
C

Diagonal stacking (stacking along separate input and output dimensions)::

A . .
. B .
. . C

Inherits from Threadable to support parallel execution of sub-linops. When threaded=True (default), each sub-linop is executed in parallel using a ThreadPoolExecutor.

Note that shared linops (e.g., Concat(A, A, idim="x")) are automatically shallow-copied to ensure independent identity for threading, while still sharing tensor data. See Threadable for details.

ATTRIBUTE DESCRIPTION
linops

The list of linops being concatenated.

TYPE: ModuleList

threaded

Whether to run sub-linops in parallel. Default is True.

TYPE: bool

num_workers

Number of worker threads. If None, defaults to the number of sub-linops.

TYPE: int | None

idim

Input dimension along which to concatenate.

TYPE: NamedDimension | None

odim

Output dimension along which to concatenate.

TYPE: NamedDimension | None

Source code in src/torchlinops/linops/concat.py
class Concat(Threadable, NamedLinop):
    """Concatenate some linops along an existing dimension.

    Linops need not output tensors of the same size, but they should
    output tensors of the same number of dimensions.

    Stacking type depends on dimensions provided:

    Horizontal stacking (stacking along an input dimension)::

        A B C

    Vertical stacking (stacking along an output dimension)::

        A
        B
        C

    Diagonal stacking (stacking along separate input and output dimensions)::

        A . .
        . B .
        . . C

    Inherits from ``Threadable`` to support parallel execution of sub-linops.
    When ``threaded=True`` (default), each sub-linop is executed in parallel
    using a ThreadPoolExecutor.

    Note that shared linops (e.g., ``Concat(A, A, idim="x")``) are automatically
    shallow-copied to ensure independent identity for threading, while still
    sharing tensor data. See ``Threadable`` for details.

    Attributes
    ----------
    linops : nn.ModuleList
        The list of linops being concatenated.
    threaded : bool
        Whether to run sub-linops in parallel. Default is True.
    num_workers : int | None
        Number of worker threads. If None, defaults to the number of sub-linops.
    idim : ND | None
        Input dimension along which to concatenate.
    odim : ND | None
        Output dimension along which to concatenate.
    """

    def __init__(
        self,
        *linops,
        idim: Optional[ND | str] = None,
        odim: Optional[ND | str] = None,
        **kwargs,
    ):
        """
        Parameters
        ----------
        *linops : NamedLinop
            The linops to concatenate.
        idim : str or ND, optional
            Input dimension along which to concatenate. If ``None``, the input
            is not concatenated (all linops receive the same input).
        odim : str or ND, optional
            Output dimension along which to concatenate. If ``None``, the output
            is not concatenated (outputs are summed).

        Raises
        ------
        ValueError
            If the linops have incompatible named shapes, or if both ``idim``
            and ``odim`` are None.
        """
        self._check_linop_compatibility(linops)
        super().__init__(NS(linops[0].ishape, linops[0].oshape), **kwargs)
        self.linops = nn.ModuleList(list(linops))
        self._setup_indices(idim, odim)

    @staticmethod
    def fn(concat, x):
        """Forward pass: split/share ``x`` among sub-linops and combine outputs."""
        return concat._fn(
            x,
            concat.linops,
            concat.idim_idx,
            concat.odim_idx,
            concat.islices,
            concat.oslices,
            concat.threaded,
            concat.num_workers,
        )

    @staticmethod
    def adj_fn(concat, x):
        """Adjoint pass: apply adjoint sub-linops with input/output roles swapped."""
        adj_linops = [linop.H for linop in concat.linops]
        return concat._fn(
            x,
            adj_linops,
            concat.odim_idx,
            concat.idim_idx,
            concat.oslices,
            concat.islices,
            concat.threaded,
            concat.num_workers,
        )

    @staticmethod
    def _fn(
        x: Tensor,
        linops,
        idim_idx,
        odim_idx,
        islices,
        oslices,
        threaded: bool = False,
        num_workers: int | None = None,
    ):
        """Unifies forward and adjoint functionality for stacked linops"""
        if idim_idx is not None:  # Diagonal, Horizontal
            # islices is a cumulative-sum tensor; its final entry is the total
            # expected size along the concatenated input dimension.
            if islices[-1] != x.shape[idim_idx]:
                raise ValueError(
                    f"Concat Linop expecting input of size {islices[-1]} got input of size {x.shape} with non-matching concat size {x.shape[idim_idx]}"
                )
            xs = x.tensor_split(islices, dim=idim_idx)[:-1]  # Omit final slice
        else:  # Vertical
            xs = [x] * len(oslices)

        if odim_idx is not None:  # Diagonal, Vertical
            if threaded:
                ys = _threaded_apply(list(linops), xs, num_workers)
            else:
                ys = [linop(xi) for xi, linop in zip(xs, linops)]
            return torch.concatenate(ys, dim=odim_idx)

        # Horizontal: all outputs live in the same space, so reduce by summing
        if threaded:
            y = _threaded_apply_sum_reduce(list(linops), xs, num_workers)
        else:
            y = 0.0
            for xi, linop in zip(xs, linops):
                y += linop(xi)
        return y

    def size(self, dim):
        """Return the size of ``dim``: summed over sub-linops along a concat
        dimension, otherwise taken from the first sub-linop that knows it.

        Returns None if no sub-linop defines a size for ``dim``.
        """
        if dim == self.idim:
            return sum(self.isizes)
        elif dim == self.odim:
            return sum(self.osizes)
        else:
            for linop in self.linops:
                if linop.size(dim) is not None:
                    return linop.size(dim)
            return None

    def split_forward(self, ibatch, obatch):
        """Split concat linop, making a new concat linop if necessary"""
        ibatches = self.subslice(ibatch, self.idim_idx, self.islices, len(self.linops))
        obatches = self.subslice(obatch, self.odim_idx, self.oslices, len(self.linops))

        # Only linops covered by both the input and output slices survive
        output_linop_idxs = sorted(ibatches.keys() & obatches.keys())
        if len(output_linop_idxs) == 0:
            # No linops satisfy this slice (diagonal stacking)
            return Zero(self.ishape, self.oshape)
        elif len(output_linop_idxs) == 1:
            # Singleton linop: return its split directly, no wrapper needed
            linop_idx = output_linop_idxs.pop()
            linop = self.linops[linop_idx]
            ibatch, obatch = ibatches[linop_idx], obatches[linop_idx]
            return linop.split_forward(ibatch, obatch)
        else:
            output_linops = []
            for i in output_linop_idxs:
                linop = self.linops[i]
                ibatch, obatch = ibatches[i], obatches[i]
                output_linops.append(linop.split_forward(ibatch, obatch))
            return self.spinoff(output_linops, idim=self.idim, odim=self.odim)

    @staticmethod
    def subslice(batch: list[slice], dim_idx: Optional[int], slices, num_linops):
        """Given a slice over some dims of a concat linop,
        return a mapping from the linop index to the relevant sub-slice for that linop.
        """
        linops_batch = {}
        if dim_idx is not None:
            slc = batch[dim_idx]
            # Convert cumulative sizes to a python partition [0, s1, s1+s2, ...]
            slice_partition = slices.detach().cpu().numpy().tolist()
            slice_partition.insert(0, 0)
            sub_linop_slices = partition_slices(slice_partition, slc)
            for i, slc in sub_linop_slices:
                sub_linop_batch = copy(batch)
                sub_linop_batch[dim_idx] = slc
                linops_batch[i] = sub_linop_batch
        else:
            # No concat dim on this side: every linop sees the same batch
            for i in range(num_linops):
                linops_batch[i] = batch
        return linops_batch

    def adjoint(self):
        """Return the adjoint: adjoint sub-linops with idim/odim swapped."""
        adj_linops = [linop.H for linop in self.linops]
        adj_shape = adj_linops[0].shape
        return self.spinoff(
            linops=adj_linops,
            shape=adj_shape,
            idim=self.odim,
            odim=self.idim,
        )

    def normal(self, inner=None):
        """Return the normal linop ``A.H @ A``; defers to the parent
        implementation when an ``inner`` linop is supplied."""
        if inner is None:
            # Standardize on this shape
            max_ishape = max_shape([linop.N.ishape for linop in self.linops])
            max_oshape = max_shape([linop.N.oshape for linop in self.linops])
            new_shape = NS(max_ishape, max_oshape)
            if self.idim is None:  # Vertical (inner product)
                # Normal of a vertical concat collapses to a sum of normals
                linops = [linop.N for linop in self.linops]
                linops = standardize_shapes(linops, new_shape)
                new = Add(*linops)
                new.settings = self.settings  # Copy Threadable settings
                return new
            elif self.odim is None:  # Horizontal (outer product)
                # Normal of a horizontal concat is a full block matrix of
                # pairwise products A_i.H @ A_j
                rows = []
                new_idim, new_odim = self._get_new_normal_io_dims(new_shape, self.idim)
                for linop_left in self.linops:
                    row = []
                    for linop_right in self.linops:
                        if linop_left == linop_right:
                            new_linop = linop_right.N
                        else:
                            new_linop = linop_left.H @ linop_right
                        row.append(new_linop)
                        # NOTE(review): re-standardizes the partial row on every
                        # inner iteration; assumed idempotent — confirm it could
                        # not be hoisted out of the loop.
                        row = standardize_shapes(row, new_shape)
                    rows.append(
                        self.spinoff(
                            linops=row, shape=new_shape, idim=new_idim, odim=None
                        )
                    )
                return self.spinoff(rows, shape=new_shape, idim=None, odim=new_odim)
            else:  # Diagonal
                # Normal of a diagonal concat is diagonal in the normals
                diag = []
                new_idim, new_odim = self._get_new_normal_io_dims(new_shape, self.idim)
                for linop in self.linops:
                    diag.append(linop.N)
                diag = standardize_shapes(diag, new_shape)
                return self.spinoff(diag, shape=new_shape, idim=new_idim, odim=new_odim)
        return super().normal(inner)

    @staticmethod
    def _get_new_normal_io_dims(new_shape, dim) -> tuple:
        """Map ``dim`` to its (input, output) counterpart dims in ``new_shape``."""
        i = new_shape.ishape.index(dim)
        new_idim = new_shape.ishape[i]
        new_odim = new_shape.oshape[i]
        return new_idim, new_odim

    @staticmethod
    def _check_linop_compatibility(linops: list[NamedLinop]):
        """Ensure linops can actually be concatenated along the requested dimension"""
        target_shape = linops[0].shape
        for linop in linops:
            if not (
                isequal(target_shape.ishape, linop.ishape)
                and isequal(target_shape.oshape, linop.oshape)
            ):
                raise ValueError(
                    f"Incompatible linops being stacked. Target shape: {target_shape} but got linop shape: {linop.shape}"
                )

    def _setup_indices(self, idim, odim):
        """Resolve idim/odim names into sizes, cumulative slices, and indices."""
        ishape = self.linops[0].ishape
        oshape = self.linops[0].oshape
        self.idim, self.isizes, self.islices = self._setup_dim(idim, ishape)
        self.odim, self.osizes, self.oslices = self._setup_dim(odim, oshape)

        if self.idim is None and self.odim is None:
            raise ValueError("At least one of idim and odim must not be None.")

        self.idim_idx = self._infer_dim_idx(self.idim, ishape)
        self.odim_idx = self._infer_dim_idx(self.odim, oshape)

    def _setup_dim(self, dim, shape):
        """Return ``(dim, per-linop sizes, cumulative slice boundaries)``.

        All three are None when ``dim`` is None (no concat along this side).
        """
        if dim is not None:
            _dim = ND.infer(dim)
            if any(linop.size(_dim) is None for linop in self.linops):
                raise ValueError(
                    f"Found linop with undefined size for dim {_dim} when attempting concat."
                )
            _sizes = [linop.size(_dim) for linop in self.linops]
            _slices = torch.tensor(_sizes).cumsum(0)  # Keep on CPU
        else:
            _dim = None
            _sizes = None
            _slices = None
        return _dim, _sizes, _slices

    @staticmethod
    def _infer_dim_idx(dim: ND, shape: tuple[ND, ...]) -> Optional[int]:
        """Get index of dim within requested shape tuple

        Tries to infer index in the presence of ellipses "..." shapes
        Returns positive int if possible
        Otherwise, tries to return negative int
        Fails if neither is possible.

        """
        if dim is None:
            return None
        if dim not in shape:
            raise ValueError(
                f"Provided concat dimension {dim} not found in shape {shape}"
            )
        shape_list = [str(s) for s in shape]
        dim_idx = shape_list.index(str(dim))
        pre, post = shape_list[:dim_idx], shape_list[dim_idx + 1 :]
        if ELLIPSES in pre:
            if ELLIPSES in post:
                raise ValueError(
                    f"Cannot infer concat dimension for dim {dim} from shape {shape}"
                )
            else:
                # Count from the right when an ellipsis precedes the dim
                return -(len(post) + 1)
        return len(pre)

    def __getitem__(self, idx):
        """Index into the sub-linops; a slice yields a smaller Concat."""
        linops = self.linops[idx]
        if isinstance(linops, NamedLinop):
            return linops
        return self.spinoff(linops, idim=self.idim, odim=self.odim)

    def spinoff(self, linops=None, shape=None, idim=None, odim=None):
        """Helper function for creating a new linop using the provided inputs.

        Preserves settings from the original linop.
        """
        linops = linops if linops is not None else self.linops
        shape = shape if shape is not None else self.shape
        new = copy(self)
        new.shape = shape
        new.linops = nn.ModuleList(linops)
        new._setup_indices(idim=idim, odim=odim)
        return new

    def __len__(self):
        """Number of sub-linops."""
        return len(self.linops)

    def __repr__(self):
        output = ""
        output += INDENT.indent(self.repr_name + f"({self._shape}\n")
        with INDENT:
            for linop in self.linops:
                output += repr(linop) + "\n"
            output += INDENT.indent(f"idim = {self.idim}, odim = {self.odim}\n")
        output += INDENT.indent(")")
        return output

__init__

__init__(
    *linops,
    idim: Optional[NamedDimension | str] = None,
    odim: Optional[NamedDimension | str] = None,
    **kwargs,
)
PARAMETER DESCRIPTION
*linops

The linops to concatenate.

TYPE: NamedLinop DEFAULT: ()

idim

Input dimension along which to concatenate. If None, the input is not concatenated (all linops receive the same input).

TYPE: str or NamedDimension DEFAULT: None

odim

Output dimension along which to concatenate. If None, the output is not concatenated (outputs are summed).

TYPE: str or NamedDimension DEFAULT: None

Source code in src/torchlinops/linops/concat.py
def __init__(
    self,
    *linops,
    idim: Optional[ND | str] = None,
    odim: Optional[ND | str] = None,
    **kwargs,
):
    """
    Parameters
    ----------
    *linops : NamedLinop
        The linops to concatenate.
    idim : str or ND, optional
        Input dimension along which to concatenate. If ``None``, the input
        is not concatenated (all linops receive the same input).
    odim : str or ND, optional
        Output dimension along which to concatenate. If ``None``, the output
        is not concatenated (outputs are summed).
    """
    # Fail fast if the linops disagree on input/output named shapes
    self._check_linop_compatibility(linops)
    # The concat linop's shape mirrors the shared shape of its sub-linops
    super().__init__(NS(linops[0].ishape, linops[0].oshape), **kwargs)
    self.linops = nn.ModuleList(list(linops))
    # Resolve idim/odim names into indices and cumulative slice boundaries
    self._setup_indices(idim, odim)

normal

normal(inner=None)
Source code in src/torchlinops/linops/concat.py
def normal(self, inner=None):
    # Build the normal linop A.H @ A directly when no inner linop is given;
    # otherwise defer to the generic parent implementation.
    if inner is None:
        # Standardize on this shape
        max_ishape = max_shape([linop.N.ishape for linop in self.linops])
        max_oshape = max_shape([linop.N.oshape for linop in self.linops])
        new_shape = NS(max_ishape, max_oshape)
        if self.idim is None:  # Vertical (inner product)
            # Normal of a vertical concat collapses to a sum of normals
            linops = [linop.N for linop in self.linops]
            linops = standardize_shapes(linops, new_shape)
            new = Add(*linops)
            new.settings = self.settings  # Copy Threadable settings
            return new
        elif self.odim is None:  # Horizontal (outer product)
            # Normal of a horizontal concat is a full block matrix of
            # pairwise products A_i.H @ A_j
            rows = []
            new_idim, new_odim = self._get_new_normal_io_dims(new_shape, self.idim)
            for linop_left in self.linops:
                row = []
                for linop_right in self.linops:
                    if linop_left == linop_right:
                        # Diagonal entry: use the sub-linop's own normal
                        new_linop = linop_right.N
                    else:
                        new_linop = linop_left.H @ linop_right
                    row.append(new_linop)
                    # NOTE(review): re-standardizes the partial row every
                    # iteration; assumed idempotent — confirm
                    row = standardize_shapes(row, new_shape)
                rows.append(
                    self.spinoff(
                        linops=row, shape=new_shape, idim=new_idim, odim=None
                    )
                )
            # rows = standardize_shapes(rows, new_shape)
            return self.spinoff(rows, shape=new_shape, idim=None, odim=new_odim)
        else:  # Diagonal
            # Normal of a diagonal concat is diagonal in the normals
            diag = []
            new_idim, new_odim = self._get_new_normal_io_dims(new_shape, self.idim)
            for linop in self.linops:
                diag.append(linop.N)
            diag = standardize_shapes(diag, new_shape)
            return self.spinoff(diag, shape=new_shape, idim=new_idim, odim=new_odim)
    return super().normal(inner)

torchlinops.linops.Stack

Bases: Threadable, NamedLinop

Concatenate some linops along a new dimension.

Linops need not output tensors of the same size, but they should output tensors of the same number of dimensions.

Stacking type depends on dimensions provided:

Horizontal stacking (stacking along an input dimension)::

A B C

Vertical stacking (stacking along an output dimension)::

A
B
C

Diagonal stacking (stacking along separate input and output dimensions)::

A . .
. B .
. . C

Inherits from Threadable to support parallel execution of sub-linops. When threaded=True (default), each sub-linop is executed in parallel using a ThreadPoolExecutor.

Note that shared linops (e.g., Stack(A, A, odim_and_idx=("L", 0))) are automatically shallow-copied to ensure independent identity for threading, while still sharing tensor data. See Threadable for details.

ATTRIBUTE DESCRIPTION
linops

The list of linops being stacked.

TYPE: ModuleList

threaded

Whether to run sub-linops in parallel. Default is True.

TYPE: bool

num_workers

Number of worker threads. If None, defaults to the number of sub-linops.

TYPE: int | None

idim

Input stacking dimension name.

TYPE: NamedDimension | None

idim_idx

Index position of the input stacking dimension.

TYPE: int | None

odim

Output stacking dimension name.

TYPE: NamedDimension | None

odim_idx

Index position of the output stacking dimension.

TYPE: int | None

Source code in src/torchlinops/linops/stack.py
class Stack(Threadable, NamedLinop):
    """Concatenate some linops along a new dimension.

    Linops need not output tensors of the same size, but they should
    output tensors of the same number of dimensions.

    Stacking type depends on dimensions provided:

    Horizontal stacking (stacking along an input dimension)::

        A B C

    Vertical stacking (stacking along an output dimension)::

        A
        B
        C

    Diagonal stacking (stacking along separate input and output dimensions)::

        A . .
        . B .
        . . C

    Inherits from ``Threadable`` to support parallel execution of sub-linops.
    When ``threaded=True`` (default), each sub-linop is executed in parallel
    using a ThreadPoolExecutor.

    Note that shared linops (e.g., ``Stack(A, A, odim_and_idx=("L", 0))``) are
    automatically shallow-copied to ensure independent identity for threading,
    while still sharing tensor data. See ``Threadable`` for details.

    Attributes
    ----------
    linops : nn.ModuleList
        The list of linops being stacked.
    threaded : bool
        Whether to run sub-linops in parallel. Default is True.
    num_workers : int | None
        Number of worker threads. If None, defaults to the number of sub-linops.
    idim : ND | None
        Input stacking dimension name.
    idim_idx : int | None
        Index position of the input stacking dimension.
    odim : ND | None
        Output stacking dimension name.
    odim_idx : int | None
        Index position of the output stacking dimension.
    """

    def __init__(
        self,
        *linops: NamedLinop,
        idim_and_idx: tuple[Optional[ND | str], Optional[int]] = (None, None),
        odim_and_idx: tuple[Optional[ND | str], Optional[int]] = (None, None),
        **kwargs,
    ):
        """
        Parameters
        ----------
        *linops : NamedLinop
            The linops to stack.
        idim_and_idx : tuple, optional
            Tuple of ``(dim_name, index_tensor)`` for the input stacking dimension.
        odim_and_idx : tuple, optional
            Tuple of ``(dim_name, index_tensor)`` for the output stacking dimension.
        """

        self.idim, self.idim_idx, ishape = self._get_dim_and_idx(
            *idim_and_idx, linops[0].ishape
        )
        self.odim, self.odim_idx, oshape = self._get_dim_and_idx(
            *odim_and_idx, linops[0].oshape
        )

        # Initialize parent class
        super().__init__(NS(ishape, oshape), **kwargs)
        self.linops = nn.ModuleList(list(linops))
        self._check_linop_compatibility()

    @staticmethod
    def _get_dim_and_idx(dim, idx, shape):
        if dim is not None:
            dim = ND.infer(dim)
            if dim in shape:
                raise ValueError(
                    f"Stack linop attempting to add dim {dim} to shape {shape} but shape already contains {dim}"
                )
            shape = list(shape)
            shape.insert(idx, dim)
        else:
            dim = None
            idx = None
        return dim, idx, shape

    @staticmethod
    def fn(stack, x, /):
        return stack._fn(
            x,
            stack.linops,
            stack.idim_idx,
            stack.odim_idx,
            stack.threaded,
            stack.num_workers,
        )

    @staticmethod
    def adj_fn(stack, x, /):
        adj_linops = [linop.H for linop in stack.linops]
        return stack._fn(
            x,
            adj_linops,
            stack.odim_idx,
            stack.idim_idx,
            stack.threaded,
            stack.num_workers,
        )

    @staticmethod
    def _fn(
        x: Tensor,
        linops,
        idim_idx,
        odim_idx,
        threaded: bool = False,
        num_workers: int | None = None,
    ):
        """Unifies forward and adjoint functionality for stacked linops"""
        if idim_idx is not None:  # Diagonal, Horizontal
            if len(linops) != x.shape[idim_idx]:
                raise ValueError(
                    f"Stack Linop expecting input of size {len(linops)} at dim {idim_idx} got input of size {x.shape} with non-matching stack size {x.shape[idim_idx]}"
                )
            xs = x.tensor_split(len(linops), idim_idx)
            xs = [xi.squeeze(idim_idx) for xi in xs]
        else:  # Vertical
            xs = [x] * len(linops)

        if odim_idx is not None:  # Diagonal, Vertical
            if threaded:
                ys = _threaded_apply(list(linops), xs, num_workers)
            else:
                ys = [linop(xi) for xi, linop in zip(xs, linops)]
            return torch.stack(ys, dim=odim_idx)

        # Horizontal
        if threaded:
            y = _threaded_apply_sum_reduce(list(linops), xs, num_workers)
        else:
            y = 0
            for xi, linop in zip(xs, linops):
                y += linop(xi)
        return y

    def size(self, dim) -> int | None:
        if dim == self.idim or dim == self.odim:
            return len(self.linops)
        else:
            # https://github.com/pytorch/pytorch/issues/80821
            return self.linops[0].size(dim)  # type: ignore

    def split_forward(self, ibatch, obatch):
        """Split stack linop"""
        linop_idxs = set(range(len(self.linops)))
        for i, slc in enumerate(ibatch):
            if i == self.idim_idx:
                linop_idxs &= set(slice2range(slc, len(self.linops)))
        for i, slc in enumerate(obatch):
            if i == self.odim_idx:
                linop_idxs &= set(slice2range(slc, len(self.linops)))

        if len(linop_idxs) == 0:
            # No linops satisfy this slice (diagonal stacking)
            return Zero(self.ishape, self.oshape)
        linop_idxs = sorted(list(linop_idxs))
        output_linops = []
        # Remove stack dims from slice batch
        if self.idim_idx is not None:
            ibatch = ibatch.copy()
            ibatch.pop(self.idim_idx)
        if self.odim_idx is not None:
            obatch = obatch.copy()
            obatch.pop(self.odim_idx)

        # Slice each sub-linop
        for i in linop_idxs:
            linop = self.linops[i]
            islices = {dim: slc for dim, slc in zip(linop.ishape, ibatch)}
            oslices = {dim: slc for dim, slc in zip(linop.oshape, obatch)}
            slices = strict_update(islices, oslices)
            output_linops.append(linop.split(linop, slices))
        return self.spinoff(output_linops)

    def split_data(self, ibatch, obatch, data_list):
        """Split stack linop, making a new stack linop if necessary

        Parameters
        ----------
        data_list : list, same length as linops
            List of data for each linop in this stack linop

        """
        linop_idxs = set(range(len(self.linops)))
        for i, slc in enumerate(ibatch):
            if i == self.idim_idx:
                linop_idxs &= set(slice2range(slc, len(self.linops)))
        for i, slc in enumerate(obatch):
            if i == self.odim_idx:
                linop_idxs &= set(slice2range(slc, len(self.linops)))

        if len(linop_idxs) == 0:
            # No linops satisfy this slice (diagonal stacking)
            return 0.0  # TODO is this ok
        linop_idxs = sorted(list(linop_idxs))
        output_linop_data = []
        # Remove stack dims from slice batch
        if self.idim_idx is not None:
            ibatch = copy(ibatch)
            ibatch.pop(self.idim_idx)
        if self.odim_idx is not None:
            obatch = copy(obatch)
            obatch.pop(self.odim_idx)

        # Slice each sub-linop
        for i in linop_idxs:
            linop = self.linops[i]
            data = data_list[i]
            output_linop_data.append(linop.split_data(ibatch, obatch, data))
        return output_linop_data

    def adjoint(self):
        adj_linops = [linop.H for linop in self.linops]
        return self.spinoff(
            adj_linops,
            idim_and_idx=(self.odim, self.odim_idx),
            odim_and_idx=(self.idim, self.idim_idx),
        )

    def normal(self, inner=None):
        if inner is None:
            if self.idim is None:  # Vertical (inner product)
                # self.odim is not None
                return Add(*(linop.N for linop in self.linops))
            elif self.odim is None:  # Horizontal (outer product)
                # self.idim is not None
                new_idim, new_odim = self._get_new_normal_io_dims(
                    self._shape, self.idim
                )
                rows = []
                new_shape = self.linops[0].shape.N
                for linop_left in self.linops:
                    row = []
                    for linop_right in self.linops:
                        if linop_left == linop_right:
                            new_linop = linop_right.N
                        else:
                            new_linop = linop_left.H @ linop_right
                            new_linop.ishape = new_shape.ishape
                            new_linop.oshape = new_shape.oshape
                        row.append(new_linop)
                    row = self.spinoff(
                        row,
                        idim_and_idx=(new_idim, self.idim_idx),
                        odim_and_idx=(None, None),
                    )
                    rows.append(row)
                return self.spinoff(
                    rows,
                    idim_and_idx=(None, None),
                    odim_and_idx=(new_odim, self.idim_idx),
                )
            else:  # Diagonal
                # self.idim and self.odim are not None
                diag = []
                new_idim, new_odim = self._get_new_normal_io_dims(
                    self._shape, self.idim
                )
                for linop in self.linops:
                    diag.append(linop.N)
                return self.spinoff(
                    diag,
                    idim_and_idx=(new_idim, self.idim_idx),
                    odim_and_idx=(new_odim, self.odim_idx),
                )
        return super().normal(inner)

    @staticmethod
    def _get_new_normal_io_dims(shape, dim) -> tuple:
        new_shape = shape.N
        i = new_shape.ishape.index(dim)
        new_idim = new_shape.ishape[i]
        new_odim = new_shape.oshape[i]
        return new_idim, new_odim

    def _check_linop_compatibility(self):
        """Ensure linops can actually be concatenated along the requested dimension"""
        target_shape = self.linops[0].shape
        for linop in self.linops:
            if not (
                isequal(target_shape.ishape, linop.ishape)
                and isequal(target_shape.oshape, linop.oshape)
            ):
                raise ValueError(
                    f"Incompatible linops being stacked. Target shape: {target_shape} but got linop shape: {linop.shape}"
                )

    def __getitem__(self, idx):
        linops = self.linops[idx]
        if isinstance(linops, NamedLinop):
            return linops
        return self.spinoff(linops)

    def spinoff(
        self,
        linops=None,
        shape=None,
        idim_and_idx=(None, None),
        odim_and_idx=(None, None),
    ):
        """Helper function for creating a new linop using the provided inputs.

        Preserves settings from the original linop.

        Parameters
        ----------
        linops : list[NamedLinop], optional
            The linops for the new instance. Defaults to self.linops.
        shape : NamedShape, optional
            The shape for the new instance. If None, computed from linops
            and idim/odim.
        idim_and_idx : tuple[ND | None, int | None], optional
            Tuple of (dim_name, index) for input stacking dimension.
            Defaults to (self.idim, self.idim_idx).
        odim_and_idx : tuple[ND | None, int | None], optional
            Tuple of (dim_name, index) for output stacking dimension.
            Defaults to (self.odim, self.odim_idx).
        """
        linops = linops if linops is not None else self.linops

        idim, idim_idx = idim_and_idx
        odim, odim_idx = odim_and_idx

        # Compute shape from linops and dimensions if not provided
        if shape is None:
            _, _, ishape = self._get_dim_and_idx(idim, idim_idx, linops[0].ishape)
            _, _, oshape = self._get_dim_and_idx(odim, odim_idx, linops[0].oshape)
            shape = NS(ishape, oshape)

        new = copy(self)
        new.shape = shape
        new.linops = nn.ModuleList(linops)
        new.idim = idim
        new.idim_idx = idim_idx
        new.odim = odim
        new.odim_idx = odim_idx
        return new

    def __len__(self):
        return len(self.linops)

    def __repr__(self):
        output = ""
        output += INDENT.indent(self.repr_name + f"({self._shape}\n")
        with INDENT:
            for linop in self.linops:
                output += repr(linop) + "\n"
            output += INDENT.indent(f"idim = {self.idim}, odim = {self.odim}\n")
        output += INDENT.indent(")")
        return output

__init__

__init__(
    *linops: NamedLinop,
    idim_and_idx: tuple[
        Optional[NamedDimension | str], Optional[int]
    ] = (None, None),
    odim_and_idx: tuple[
        Optional[NamedDimension | str], Optional[int]
    ] = (None, None),
    **kwargs,
)
PARAMETER DESCRIPTION
*linops

The linops to stack.

TYPE: NamedLinop DEFAULT: ()

idim_and_idx

Tuple of (dim_name, index) for the input stacking dimension, where index is the integer position at which the new dimension is inserted.

TYPE: tuple DEFAULT: (None, None)

odim_and_idx

Tuple of (dim_name, index) for the output stacking dimension, where index is the integer position at which the new dimension is inserted.

TYPE: tuple DEFAULT: (None, None)

Source code in src/torchlinops/linops/stack.py
def __init__(
    self,
    *linops: NamedLinop,
    idim_and_idx: tuple[Optional[ND | str], Optional[int]] = (None, None),
    odim_and_idx: tuple[Optional[ND | str], Optional[int]] = (None, None),
    **kwargs,
):
    """
    Parameters
    ----------
    *linops : NamedLinop
        The linops to stack.
    idim_and_idx : tuple, optional
        Tuple of ``(dim_name, index)`` for the input stacking dimension,
        where ``index`` is the integer position at which the new dimension
        is inserted into the input shape.
    odim_and_idx : tuple, optional
        Tuple of ``(dim_name, index)`` for the output stacking dimension,
        where ``index`` is the integer position at which the new dimension
        is inserted into the output shape.
    """

    # Insert the new stacking dims (if any) into the first linop's shapes,
    # recording the dim names and their index positions.
    self.idim, self.idim_idx, ishape = self._get_dim_and_idx(
        *idim_and_idx, linops[0].ishape
    )
    self.odim, self.odim_idx, oshape = self._get_dim_and_idx(
        *odim_and_idx, linops[0].oshape
    )

    # Initialize parent class
    super().__init__(NS(ishape, oshape), **kwargs)
    self.linops = nn.ModuleList(list(linops))
    self._check_linop_compatibility()

normal

normal(inner=None)
Source code in src/torchlinops/linops/stack.py
def normal(self, inner=None):
    # Build the normal linop A.H @ A directly when no inner linop is given;
    # otherwise defer to the generic parent implementation.
    if inner is None:
        if self.idim is None:  # Vertical (inner product)
            # self.odim is not None
            # Normal of a vertical stack collapses to a sum of normals
            return Add(*(linop.N for linop in self.linops))
        elif self.odim is None:  # Horizontal (outer product)
            # self.idim is not None
            # Normal of a horizontal stack is a full block matrix of
            # pairwise products A_i.H @ A_j
            new_idim, new_odim = self._get_new_normal_io_dims(
                self._shape, self.idim
            )
            rows = []
            new_shape = self.linops[0].shape.N
            for linop_left in self.linops:
                row = []
                for linop_right in self.linops:
                    if linop_left == linop_right:
                        # Diagonal entry: use the sub-linop's own normal
                        new_linop = linop_right.N
                    else:
                        new_linop = linop_left.H @ linop_right
                        new_linop.ishape = new_shape.ishape
                        new_linop.oshape = new_shape.oshape
                    row.append(new_linop)
                row = self.spinoff(
                    row,
                    idim_and_idx=(new_idim, self.idim_idx),
                    odim_and_idx=(None, None),
                )
                rows.append(row)
            # NOTE(review): output dim index reuses self.idim_idx (odim_idx is
            # None on this branch) — presumably mirrors the input position;
            # confirm intended.
            return self.spinoff(
                rows,
                idim_and_idx=(None, None),
                odim_and_idx=(new_odim, self.idim_idx),
            )
        else:  # Diagonal
            # self.idim and self.odim are not None
            # Normal of a diagonal stack is diagonal in the normals
            diag = []
            new_idim, new_odim = self._get_new_normal_io_dims(
                self._shape, self.idim
            )
            for linop in self.linops:
                diag.append(linop.N)
            return self.spinoff(
                diag,
                idim_and_idx=(new_idim, self.idim_idx),
                odim_and_idx=(new_odim, self.odim_idx),
            )
    return super().normal(inner)