Skip to content

Add

torchlinops.linops.Add

Bases: Threadable, NamedLinop

The sum of one or more linear operators.

Inherits from Threadable to support parallel execution of sub-linops. When threaded=True (default), each sub-linop is executed in parallel using a ThreadPoolExecutor, which is useful for I/O-bound operations or operations that release the GIL (e.g., PyTorch tensor operations).

Note that shared linops (e.g., Add(A, A)) are automatically shallow-copied to ensure independent identity for threading, while still sharing tensor data. See Threadable for details.

ATTRIBUTE DESCRIPTION
linops

The list of linops being added together.

TYPE: ModuleList

threaded

Whether to run sub-linops in parallel. Default is True.

TYPE: bool

num_workers

Number of worker threads. If None, defaults to the number of sub-linops.

TYPE: int | None

Source code in src/torchlinops/linops/add.py
class Add(Threadable, NamedLinop):
    """The sum of one or more linear operators.

    Inherits from ``Threadable`` to support parallel execution of sub-linops.
    When ``threaded=True`` (default), each sub-linop is executed in parallel
    using a ThreadPoolExecutor, which is useful for I/O-bound operations or
    operations that release the GIL (e.g., PyTorch tensor operations).

    Note that shared linops (e.g., ``Add(A, A)``) are automatically
    shallow-copied to ensure independent identity for threading, while still
    sharing tensor data. See ``Threadable`` for details.

    Attributes
    ----------
    linops : nn.ModuleList
        The list of linops being added together.
    threaded : bool
        Whether to run sub-linops in parallel. Default is True.
    num_workers : int | None
        Number of worker threads. If None, defaults to the number of sub-linops.
    """

    def __init__(self, *linops, **kwargs):
        """
        Parameters
        ----------
        *linops : tuple[NamedLinop]
            The linear operators to be added together. At least one is
            required, and all must share the same ``ishape`` and ``oshape``.
        **kwargs
            Forwarded unchanged to the parent constructor (presumably
            ``threaded`` / ``num_workers`` for ``Threadable`` -- confirm there).

        Raises
        ------
        ValueError
            If no linops are given.
        AssertionError
            If the linops do not all share the same ``ishape`` / ``oshape``.
        """
        if not linops:
            # Without this guard an empty call dies with an opaque IndexError
            # on ``linops[0]`` below.
            raise ValueError("Add: requires at least one linop")
        assert all(isequal(linop.ishape, linops[0].ishape) for linop in linops), (
            f"Add: All linops must share same ishape. Found {linops}"
        )
        assert all(isequal(linop.oshape, linops[0].oshape) for linop in linops), (
            f"Add: All linops must share same oshape. Linops: {linops}"
        )
        # Every sub-linop has the same shape, so the first one defines the sum's.
        super().__init__(NS(linops[0].ishape, linops[0].oshape), **kwargs)
        self.linops = nn.ModuleList(linops)

    @staticmethod
    def fn(add, x: torch.Tensor, /):
        """Forward application: ``sum_i A_i(x)``."""
        if add.threaded:
            # One copy of ``x`` per sub-linop; results are summed.
            return add.threaded_apply_sum_reduce([x] * len(add.linops), add.num_workers)
        return sum(linop(x) for linop in add.linops)

    @staticmethod
    def adj_fn(add, x: torch.Tensor, /):
        """Adjoint application: ``sum_i A_i^H(x)``."""
        if add.threaded:
            # ``threaded_apply_sum_reduce`` is (presumably) driven by the
            # instance's own ``linops`` -- ``fn`` passes one input per entry of
            # ``add.linops``. Calling it on ``add`` itself would therefore apply
            # the *forward* sub-linops. Dispatch through the adjoint ``Add``
            # instead, whose ``linops`` are the sub-linop adjoints.
            adj = add.H
            return adj.threaded_apply_sum_reduce([x] * len(adj.linops), adj.num_workers)
        return sum(linop.H(x) for linop in add.linops)

    def split_forward(self, ibatch, obatch):
        """Return a copy whose sub-linops are each split with the same batches."""
        split = copy(self)
        linops = [linop.split_forward(ibatch, obatch) for linop in self.linops]
        split.linops = nn.ModuleList(linops)
        return split

    def adjoint(self):
        """Return the adjoint ``sum_i A_i^H`` as a new ``Add``."""
        adj = copy(self)
        adj.linops = nn.ModuleList([linop.adjoint() for linop in self.linops])
        adj.shape = self.shape.adjoint()
        return adj

    def normal(self, inner=None):
        """Return the normal operator of this sum.

        With ``inner=None`` the product ``(sum_i A_i)^H (sum_j A_j)`` is expanded
        into all ``len(self.linops) ** 2`` pairwise terms ``A_i^H A_j`` and
        returned as a new ``Add``. Diagonal terms use each sub-linop's own
        ``.N``. The terms are brought to a common shape via
        ``standardize_shapes``. When ``inner`` is given, the parent
        implementation is used.
        """
        if inner is not None:
            return super().normal(inner)

        # Compute each sub-linop's normal once. It is reused both for the
        # combined shape and for the diagonal (i == j) terms.
        normals = [linop.N for linop in self.linops]
        new_shape = NS(
            max_shape([n.ishape for n in normals]),
            max_shape([n.oshape for n in normals]),
        )

        all_combinations = []
        for left_linop, left_normal in zip(self.linops, normals):
            for right_linop in self.linops:
                if left_linop == right_linop:
                    all_combinations.append(left_normal)
                else:
                    all_combinations.append(left_linop.H @ right_linop)
        all_combinations = standardize_shapes(all_combinations, new_shape)

        normal_op = copy(self)
        normal_op.linops = nn.ModuleList(list(all_combinations))
        normal_op.shape = normal_op.linops[0].shape
        return normal_op

    def size(self, dim):
        """Return the size of ``dim`` from the first sub-linop that reports one.

        Returns ``None`` if no sub-linop knows the size of ``dim``.
        """
        for linop in self.linops:
            out = linop.size(dim)
            if out is not None:
                return out
        return None

    @property
    def dims(self):
        """Union of the dims of all sub-linops."""
        return set().union(*[linop.dims for linop in self.linops])

    @property
    def H(self):
        """Adjoint operator, cached when ``config.cache_adjoint_normal`` is set."""
        try:
            if config.cache_adjoint_normal:
                config._warn_if_caching_enabled()
                if self._adjoint is None:
                    # Stored in a list so the adjoint is not registered as a
                    # submodule attribute of ``self``.
                    self._adjoint = [self.adjoint()]
                return self._adjoint[0]
            return self.adjoint()
        except AttributeError as e:
            # Re-raise so an AttributeError inside adjoint() is not mistaken
            # for a missing ``H`` attribute by the property machinery.
            raise RuntimeError(f"AttributeError in {type(self).__name__}.H: {e}") from e

    @property
    def N(self):
        """Normal operator, cached when ``config.cache_adjoint_normal`` is set."""
        try:
            if config.cache_adjoint_normal:
                config._warn_if_caching_enabled()
                if self._normal is None:
                    self._normal = [self.normal()]
                return self._normal[0]
            return self.normal()
        except AttributeError as e:
            raise RuntimeError(f"AttributeError in {type(self).__name__}.N: {e}") from e

    def flatten(self):
        """An ``Add`` is treated as a single, non-decomposable unit."""
        return [self]

    def __getitem__(self, idx):
        """Index returns a single sub-linop; a slice returns a new ``Add``."""
        linops = self.linops[idx]
        if isinstance(linops, NamedLinop):
            return linops
        new = copy(self)
        new.linops = nn.ModuleList(linops)
        return new

    def __len__(self):
        """Number of sub-linops being summed."""
        return len(self.linops)

    def __repr__(self):
        """Render as ``repr(A_0) + repr(A_1) + ...``."""
        linop_chain = " + ".join(repr(linop) for linop in self.linops)
        return linop_chain

__init__

__init__(*linops, **kwargs)
PARAMETER DESCRIPTION
*linops

The linear operators to be added together. At least one is required, and all must share the same ishape and oshape.

TYPE: tuple[NamedLinop] DEFAULT: ()

Source code in src/torchlinops/linops/add.py
def __init__(self, *linops, **kwargs):
    """
    Parameters
    ----------
    *linops : tuple[NamedLinop]
        The linear operators to be added together. At least one is
        required, and all must share the same ``ishape`` and ``oshape``.
    **kwargs
        Forwarded unchanged to the parent constructor.

    Raises
    ------
    ValueError
        If no linops are given.
    AssertionError
        If the linops do not all share the same ``ishape`` / ``oshape``.
    """
    if not linops:
        # Without this guard an empty call dies with an opaque IndexError
        # on ``linops[0]`` below.
        raise ValueError("Add: requires at least one linop")
    assert all(isequal(linop.ishape, linops[0].ishape) for linop in linops), (
        f"Add: All linops must share same ishape. Found {linops}"
    )
    assert all(isequal(linop.oshape, linops[0].oshape) for linop in linops), (
        f"Add: All linops must share same oshape. Linops: {linops}"
    )
    # Every sub-linop has the same shape, so the first one defines the sum's.
    super().__init__(NS(linops[0].ishape, linops[0].oshape), **kwargs)
    self.linops = nn.ModuleList(linops)