Skip to content

Diagonal & Scalar

torchlinops.linops.Diagonal

Bases: NamedLinop

Elementwise diagonal linear operator \(D(x) = w \odot x\).

The forward operation is pointwise multiplication by a weight tensor w. The adjoint is \(D^H(x) = \bar{w} \odot x\) and the normal is \(D^N(x) = |w|^2 \odot x\).

Because the input and output shapes are identical, Diagonal sets oshape = ishape and keeps them synchronized.

ATTRIBUTE DESCRIPTION
weight

The diagonal weight tensor \(w\).

TYPE: Parameter

broadcast_dims

Dimensions along which the weight is broadcast (not stored explicitly).

TYPE: list

Source code in src/torchlinops/linops/diagonal.py
class Diagonal(NamedLinop):
    """Elementwise diagonal linear operator $D(x) = w \\odot x$.

    The forward operation is pointwise multiplication by a weight tensor *w*.
    The adjoint is $D^H(x) = \\bar{w} \\odot x$ and the normal is
    $D^N(x) = |w|^2 \\odot x$.

    Because the input and output shapes are identical, ``Diagonal`` sets
    ``oshape = ishape`` and keeps them synchronized.

    Attributes
    ----------
    weight : nn.Parameter
        The diagonal weight tensor $w$.
    broadcast_dims : list
        Dimensions along which the weight is broadcast (not stored explicitly).
    """

    def __init__(
        self,
        weight: torch.Tensor,
        ioshape: Optional[Shape] = None,
        broadcast_dims: Optional[Shape] = None,
    ):
        """
        Parameters
        ----------
        weight : Tensor
            The diagonal weight tensor.
        ioshape : Shape, optional
            Named dimensions for input and output (they are the same).
        broadcast_dims : Shape, optional
            Dimensions along which *weight* should be broadcast rather than
            indexed. Useful when the weight has fewer dimensions than the input.

        Raises
        ------
        ValueError
            If *weight* has more dimensions than *ioshape* names.
        """
        if ioshape is not None and len(weight.shape) > len(ioshape):
            raise ValueError(
                f"All dimensions must be named or broadcastable, but got weight shape {weight.shape} and ioshape {ioshape}"
            )
        super().__init__(NS(ioshape))
        # Weight is data, not a learnable parameter.
        self.weight = nn.Parameter(weight, requires_grad=False)
        # Defensive copy: ANY may be appended below, and the caller's list
        # must not be mutated by this constructor.
        broadcast_dims = list(broadcast_dims) if broadcast_dims is not None else []
        if ANY in self.ishape:
            broadcast_dims.append(ANY)
        self._shape.broadcast_dims = broadcast_dims

    @classmethod
    def from_weight(
        cls,
        weight: Tensor,
        weight_shape: Shape,
        ioshape: Shape,
        shape_kwargs: Optional[dict] = None,
    ):
        """Construct a ``Diagonal`` by expanding *weight* to match *ioshape* via einops.

        Parameters
        ----------
        weight : Tensor
            The weight tensor in its original (possibly lower-dimensional) shape.
        weight_shape : Shape
            Named dimensions labeling the axes of *weight*.
        ioshape : Shape
            Target named dimensions for the expanded weight.
        shape_kwargs : dict, optional
            Extra keyword arguments forwarded to ``einops.repeat``.

        Returns
        -------
        Diagonal
            A new diagonal linop with the expanded weight.
        """
        shape_kwargs = shape_kwargs if shape_kwargs is not None else {}
        if len(weight.shape) > len(ioshape):
            raise ValueError(
                f"All dimensions must be named or broadcastable, but got weight shape {weight.shape} and ioshape {ioshape}"
            )
        # Expand the weight so every dimension of ioshape is materialized.
        weight = repeat(
            weight,
            f"{' '.join(weight_shape)} -> {' '.join(ioshape)}",
            **shape_kwargs,
        )
        return cls(weight, ioshape)

    @property
    def broadcast_dims(self):
        # Stored on the shape object so splits/copies share the metadata.
        return self._shape.broadcast_dims

    @broadcast_dims.setter
    def broadcast_dims(self, val):
        self._shape.broadcast_dims = val

    # Override shape setters so ishape and oshape always stay equal.
    @NamedLinop.ishape.setter
    def ishape(self, val):
        self._shape.ishape = val
        self._shape.oshape = val

    @NamedLinop.oshape.setter
    def oshape(self, val):
        self._shape.oshape = val
        self._shape.ishape = val

    @staticmethod
    def fn(diagonal, x, /):
        # Forward: elementwise multiply by the weight.
        return x * diagonal.weight

    @staticmethod
    def adj_fn(diagonal, x, /):
        # Adjoint: multiply by the complex conjugate of the weight.
        return x * torch.conj(diagonal.weight)

    @staticmethod
    def normal_fn(diagonal, x, /):
        # Normal (D^H D): multiply by |w|^2.
        return x * torch.abs(diagonal.weight) ** 2

    def adjoint(self):
        """Return a new ``Diagonal`` with conjugated weight."""
        adj = copy(self)
        adj.weight = nn.Parameter(
            self.weight.conj(),
            requires_grad=self.weight.requires_grad,
        )
        return adj

    def normal(self, inner=None):
        """Return the normal operator; collapses to a single Diagonal when ``inner`` is None."""
        if inner is None:
            normal = copy(self)
            normal.weight = nn.Parameter(
                torch.abs(self.weight) ** 2,
                requires_grad=self.weight.requires_grad,
            )
            return normal
        return super().normal(inner)

    def split_forward(self, ibatch, obatch):
        """Return a copy of this linop restricted to the given batch slices."""
        weight = self.split_weight(ibatch, obatch, self.weight)
        split = copy(self)
        split.weight = nn.Parameter(weight, requires_grad=self.weight.requires_grad)
        return split

    def split_weight(self, ibatch, obatch, /, weight):
        """Slice *weight* according to *ibatch*, ignoring broadcast dims."""
        assert ibatch == obatch, "Diagonal linop must be split identically"
        # Broadcast dims carry no real extent in the weight, so take the
        # full slice there instead of indexing.
        ibatch = [
            slice(None) if dim in self.broadcast_dims else slc
            for slc, dim in zip(ibatch, self.ishape)
        ]
        # The weight aligns with the trailing dims of ishape.
        return weight[tuple(ibatch[-len(weight.shape) :])]

    def size(self, dim: str):
        """Return the size of named dimension ``dim``, or ``None`` if unconstrained."""
        if dim in self.ishape:
            # Leading ishape dims not covered by the weight are implicitly
            # broadcast and therefore unconstrained.
            n_broadcast = len(self.ishape) - len(self.weight.shape)
            if self.ishape.index(dim) < n_broadcast or dim in self.broadcast_dims:
                return None
            else:
                return self.weight.shape[self.ishape.index(dim) - n_broadcast]
        return None

    def __pow__(self, exponent):
        """Return a new ``Diagonal`` with the weight raised elementwise to *exponent*."""
        new = copy(self)
        new.weight = nn.Parameter(
            self.weight**exponent,
            requires_grad=self.weight.requires_grad,
        )
        return new

__init__

__init__(
    weight: Tensor,
    ioshape: Optional[Shape] = None,
    broadcast_dims: Optional[Shape] = None,
)
PARAMETER DESCRIPTION
weight

The diagonal weight tensor.

TYPE: Tensor

ioshape

Named dimensions for input and output (they are the same).

TYPE: Shape DEFAULT: None

broadcast_dims

Dimensions along which weight should be broadcast rather than indexed. Useful when the weight has fewer dimensions than the input.

TYPE: Shape DEFAULT: None

Source code in src/torchlinops/linops/diagonal.py
def __init__(
    self,
    weight: torch.Tensor,
    ioshape: Optional[Shape] = None,
    broadcast_dims: Optional[Shape] = None,
):
    """
    Parameters
    ----------
    weight : Tensor
        The diagonal weight tensor.
    ioshape : Shape, optional
        Named dimensions for input and output (they are the same).
    broadcast_dims : Shape, optional
        Dimensions along which *weight* should be broadcast rather than
        indexed. Useful when the weight has fewer dimensions than the input.

    Raises
    ------
    ValueError
        If *weight* has more dimensions than *ioshape* names.
    """
    if ioshape is not None and len(weight.shape) > len(ioshape):
        raise ValueError(
            f"All dimensions must be named or broadcastable, but got weight shape {weight.shape} and ioshape {ioshape}"
        )
    super().__init__(NS(ioshape))
    # Weight is data, not a learnable parameter.
    self.weight = nn.Parameter(weight, requires_grad=False)
    # Defensive copy: ANY may be appended below, and the caller's list
    # must not be mutated by this constructor.
    broadcast_dims = list(broadcast_dims) if broadcast_dims is not None else []
    if ANY in self.ishape:
        broadcast_dims.append(ANY)
    self._shape.broadcast_dims = broadcast_dims

from_weight classmethod

from_weight(
    weight: Tensor,
    weight_shape: Shape,
    ioshape: Shape,
    shape_kwargs: Optional[dict] = None,
)

Construct a Diagonal by expanding weight to match ioshape via einops.

PARAMETER DESCRIPTION
weight

The weight tensor in its original (possibly lower-dimensional) shape.

TYPE: Tensor

weight_shape

Named dimensions labeling the axes of weight.

TYPE: Shape

ioshape

Target named dimensions for the expanded weight.

TYPE: Shape

shape_kwargs

Extra keyword arguments forwarded to einops.repeat.

TYPE: dict DEFAULT: None

RETURNS DESCRIPTION
Diagonal

A new diagonal linop with the expanded weight.

Source code in src/torchlinops/linops/diagonal.py
@classmethod
def from_weight(
    cls,
    weight: Tensor,
    weight_shape: Shape,
    ioshape: Shape,
    shape_kwargs: Optional[dict] = None,
):
    """Build a ``Diagonal`` whose weight is *weight* expanded to *ioshape* via einops.

    Parameters
    ----------
    weight : Tensor
        The weight tensor in its original (possibly lower-dimensional) shape.
    weight_shape : Shape
        Named dimensions labeling the axes of *weight*.
    ioshape : Shape
        Target named dimensions for the expanded weight.
    shape_kwargs : dict, optional
        Extra keyword arguments forwarded to ``einops.repeat``.

    Returns
    -------
    Diagonal
        A new diagonal linop with the expanded weight.

    Raises
    ------
    ValueError
        If *weight* has more dimensions than *ioshape* names.
    """
    if len(weight.shape) > len(ioshape):
        raise ValueError(
            f"All dimensions must be named or broadcastable, but got weight shape {weight.shape} and ioshape {ioshape}"
        )
    kwargs = {} if shape_kwargs is None else shape_kwargs
    # einops pattern maps the named weight axes onto the full ioshape.
    pattern = f"{' '.join(weight_shape)} -> {' '.join(ioshape)}"
    expanded = repeat(weight, pattern, **kwargs)
    return cls(expanded, ioshape)

normal

normal(inner=None)
Source code in src/torchlinops/linops/diagonal.py
def normal(self, inner=None):
    """Return the normal operator ``D^H D``.

    With no ``inner`` operator this collapses to a single diagonal linop
    whose weight is ``|w|**2``; otherwise defer to the generic
    ``NamedLinop`` composition.
    """
    if inner is not None:
        return super().normal(inner)
    collapsed = copy(self)
    collapsed.weight = nn.Parameter(
        self.weight.abs() ** 2,
        requires_grad=self.weight.requires_grad,
    )
    return collapsed

__pow__

__pow__(exponent)
Source code in src/torchlinops/linops/diagonal.py
def __pow__(self, exponent):
    new = copy(self)
    new.weight = nn.Parameter(
        self.weight**exponent,
        requires_grad=self.weight.requires_grad,
    )
    return new

torchlinops.linops.Scalar

Bases: Diagonal

Scalar multiplication operator \(S(x) = \alpha x\).

A special case of Diagonal where the weight is a scalar, making it trivially splittable (the same scalar applies to every tile).

Source code in src/torchlinops/linops/scalar.py
class Scalar(Diagonal):
    """Scalar multiplication operator $S(x) = \\alpha x$.

    A special case of ``Diagonal`` where the weight is a scalar, making it
    trivially splittable (the same scalar applies to every tile).
    """

    def __init__(self, weight, ioshape: Optional[Shape] = None):
        """
        Parameters
        ----------
        weight : float or Tensor
            The scalar multiplier $\\alpha$.
        ioshape : Shape, optional
            Named dimensions (same for input and output).
        """
        # Wrap plain Python scalars in a tensor so Diagonal can store them.
        weight = weight if isinstance(weight, torch.Tensor) else torch.tensor(weight)
        super().__init__(weight, ioshape=default_to(("...",), ioshape))

    def split_weight(self, ibatch, obatch, /, weight):
        """A scalar weight is identical on every tile, so return it unchanged."""
        assert ibatch == obatch, "Scalar linop must be split identically"
        return weight

    def size(self, dim: str):
        """A scalar constrains no named dimension."""
        return None

__init__

__init__(weight, ioshape: Optional[Shape] = None)
PARAMETER DESCRIPTION
weight

The scalar multiplier \(\alpha\).

TYPE: float or Tensor

ioshape

Named dimensions (same for input and output).

TYPE: Shape DEFAULT: None

Source code in src/torchlinops/linops/scalar.py
def __init__(self, weight, ioshape: Optional[Shape] = None):
    """
    Parameters
    ----------
    weight : float or Tensor
        The scalar multiplier $\\alpha$.
    ioshape : Shape, optional
        Named dimensions (same for input and output).
    """
    # Wrap plain Python scalars in a tensor so Diagonal can store them.
    weight = weight if isinstance(weight, torch.Tensor) else torch.tensor(weight)
    super().__init__(weight, ioshape=default_to(("...",), ioshape))

torchlinops.linops.Zero

Bases: NamedLinop

Zero operator \(0(x) = 0\).

Always returns a zero tensor with the same shape as the input.

Source code in src/torchlinops/linops/identity.py
class Zero(NamedLinop):
    """Zero operator $0(x) = 0$.

    Always returns a zero tensor with the same shape, dtype, and device as
    the input.
    """

    def __init__(self, ishape=("...",), oshape=None):
        """
        Parameters
        ----------
        ishape : Shape, optional
            Named input dimensions.
        oshape : Shape, optional
            Named output dimensions.
        """
        super().__init__(NS(ishape, oshape))

    @staticmethod
    def fn(linop, x, /):
        # Return a fresh zero tensor; the previous ``x.zero_()`` destroyed
        # the caller's input in place.
        return x.new_zeros(x.shape)

    @staticmethod
    def adj_fn(linop, x, /):
        # The adjoint of the zero operator is also zero.
        return x.new_zeros(x.shape)

    @staticmethod
    def normal_fn(linop, x, /):
        # The normal (0^H 0) of the zero operator is zero as well.
        return x.new_zeros(x.shape)

    def split_forward(self, ibatch, obatch):
        # Zero is shape-agnostic, so any split is the same operator.
        return self

__init__

__init__(ishape=('...',), oshape=None)
Source code in src/torchlinops/linops/identity.py
def __init__(self, ishape=("...",), oshape=None):
    """Initialize the zero linop with named shapes wrapped in ``NS``."""
    super().__init__(NS(ishape, oshape))