Dense

torchlinops.linops.Dense

Bases: NamedLinop

Dense matrix-vector multiply.

"Dense" is used to distinguish from "sparse" linear operators. This operator performs a matrix-vector multiplication, potentially with batch and broadcast dimensions, implemented via einops.einsum.

The core operation is:

\(y_{o\dots} = \sum_{i\dots} W_{i\dots, o\dots} x_{i\dots}\)

where \(x\) is the input, \(W\) is the weight matrix, and \(y\) is the output. \(i\dots\) and \(o\dots\) represent the input and output dimensions involved in the multiplication. Other dimensions are treated as batch or broadcast dimensions.

Examples:

A simple multiplication with broadcast dimensions:

  • Input \(x\) shape: \((A, N_x, N_y)\)
  • Weight \(W\) shape: \((A, T)\)
  • Output \(y\) shape: \((T, N_x, N_y)\)

Here, \(A\) is the input feature dimension, \(T\) is the output feature dimension, and \((N_x, N_y)\) are broadcast dimensions. The operation is:

\(y_{t, n_x, n_y} = \sum_{a} W_{a, t} x_{a, n_x, n_y}\)
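As a rough numerical check of this index pattern, the sketch below reproduces the example with NumPy. It uses `numpy.einsum` with single-letter indices as a stand-in for `einops.einsum`; the sizes and variable names are made up for illustration and are not part of torchlinops.

```python
import numpy as np

rng = np.random.default_rng(0)
A, T, Nx, Ny = 3, 4, 5, 6  # illustrative sizes, not taken from the library

x = rng.standard_normal((A, Nx, Ny))  # input, shape (A, Nx, Ny)
W = rng.standard_normal((A, T))       # weight, shape (A, T)

# y[t, nx, ny] = sum_a W[a, t] * x[a, nx, ny]
y = np.einsum("axy,at->txy", x, W)

# Same contraction written as an explicit matrix product over the A axis.
y_ref = (W.T @ x.reshape(A, Nx * Ny)).reshape(T, Nx, Ny)
```

The broadcast dimensions \((N_x, N_y)\) pass through untouched, which is why flattening them into a single axis gives the same result.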

Another example with a batch dimension \(C\) shared between input and weights:

  • Input \(x\) shape: \((C, A, N_x, N_y)\)
  • Weight \(W\) shape: \((C, A, A_1)\)
  • Output \(y\) shape: \((C, A_1, N_x, N_y)\)

The operation is:

\(y_{c, a_1, n_x, n_y} = \sum_{a} W_{c, a, a_1} x_{c, a, n_x, n_y}\)
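The batched case can be sketched the same way. Again this is only an illustration with `numpy.einsum` (not the library's `einops.einsum` call), and the index `b` plays the role of \(a_1\); all sizes are assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
C, A, A1, Nx, Ny = 2, 3, 4, 5, 6  # illustrative sizes

x = rng.standard_normal((C, A, Nx, Ny))  # input, shape (C, A, Nx, Ny)
W = rng.standard_normal((C, A, A1))      # weight, shape (C, A, A1)

# y[c, b, nx, ny] = sum_a W[c, a, b] * x[c, a, nx, ny]
y = np.einsum("caxy,cab->cbxy", x, W)

# C is not summed over, so each slice along C is an independent product.
y_ref = np.stack([np.einsum("axy,ab->bxy", x[c], W[c]) for c in range(C)])
```

Because \(C\) appears in the input, the weight, and the output, it acts as a batch index rather than a contracted one.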

Source code in src/torchlinops/linops/dense.py
class Dense(NamedLinop):
    r"""Dense matrix-vector multiply.

    "Dense" is used to distinguish from "sparse" linear operators. This
    operator performs a matrix-vector multiplication, potentially with batch
    and broadcast dimensions, implemented via ``einops.einsum``.

    The core operation is:

    $y_{o\dots} = \sum_{i\dots} W_{i\dots, o\dots} x_{i\dots}$

    where $x$ is the input, $W$ is the weight matrix, and
    $y$ is the output. $i\dots$ and $o\dots$ represent
    the input and output dimensions involved in the multiplication. Other
    dimensions are treated as batch or broadcast dimensions.

    Examples
    --------
    A simple batched multiplication:

    - Input $x$ shape: $(A, N_x, N_y)$
    - Weight $W$ shape: $(A, T)$
    - Output $y$ shape: $(T, N_x, N_y)$

    Here, $A$ is the input feature dimension, $T$ is the output
    feature dimension, and $(N_x, N_y)$ are broadcast dimensions.
    The operation is:

    $y_{t, n_x, n_y} = \sum_{a} W_{a, t} x_{a, n_x, n_y}$

    Another example with a batch dimension $C$ shared between input
    and weights:

    - Input $x$ shape: $(C, A, N_x, N_y)$
    - Weight $W$ shape: $(C, A, A_1)$
    - Output $y$ shape: $(C, A_1, N_x, N_y)$

    The operation is:

    $y_{c, a_1, n_x, n_y} = \sum_{a} W_{c, a, a_1} x_{c, a, n_x, n_y}$

    """

    def __init__(
        self,
        weight: Tensor,
        weightshape: Shape,
        ishape: Shape,
        oshape: Shape,
        broadcast_dims: Optional[list] = None,
    ):
        """
        Parameters
        ----------
        weight : Tensor
            The dense matrix used for this linop.
        weightshape : Shape
            The shape of the matrix, in symbolic form.
        ishape : Shape
            The input shape of the matrix.
        oshape : Shape
            The output shape of the matrix.
        broadcast_dims : list, optional
            A list of the dimensions of weight that are intended to be broadcast over the input.
            As such, they are excluded from splitting.
        """
        super().__init__(NS(ishape, oshape))
        self.weight = nn.Parameter(weight, requires_grad=False)
        self._shape.weightshape = weightshape

        broadcast_dims = broadcast_dims if broadcast_dims is not None else []
        self._shape.broadcast_dims = broadcast_dims

    @property
    def weightshape(self) -> Shape:
        weightshape = self._shape.weightshape
        if not isinstance(weightshape, Sequence):
            raise ValueError(
                f"Expected weightshape to be a sequence but got {type(weightshape)}: {weightshape}"
            )
        return weightshape

    @property
    def broadcast_dims(self):
        return self._shape.broadcast_dims

    @property
    def forward_einstr(self):
        return f"{self.einstr(self.ishape)},{self.einstr(self.weightshape)}->{self.einstr(self.oshape)}"

    @property
    def adj_einstr(self):
        return f"{self.einstr(self.oshape)},{self.einstr(self.weightshape)}->{self.einstr(self.ishape)}"

    @staticmethod
    def einstr(arr):
        return " ".join(str(s) for s in arr)

    @staticmethod
    def fn(dense, x, /):
        return einsum(x, dense.weight, dense.forward_einstr)

    @staticmethod
    def adj_fn(dense, x, /):
        return einsum(x, dense.weight.conj(), dense.adj_einstr)

    def adjoint(self):
        adj = copy(self)
        adj.weight = nn.Parameter(
            self.weight.conj(), requires_grad=adj.weight.requires_grad
        )
        adj._shape = adj._shape.H
        adj._update_suffix(adjoint=self._name is not None)
        return adj

    def normal(self, inner=None):
        """Compute the normal operator (adjoint times forward).

        Parameters
        ----------
        inner : NamedLinop, optional
            An optional inner operator to sandwich between the adjoint and
            forward. If None, consolidates two Dense operators into a single
            Dense.

        Returns
        -------
        NamedLinop
            The normal operator.

        Notes
        -----
        If inner is None, consolidate two Dense's into a single Dense
        ishape: [A B X Y]
        oshape: [C D X Y]
        wshape: [A B C D]

        Needs to become
        ishape: [A B X Y]
        oshape: [A1 B1 X Y]
        wshape: [A B A1 B1]

        New weight is attained as
        einsum(weight.conj(), weight, 'A1 B1 C D, A B C D -> A B A1 B1')

        -----
        ishape: [C A]
        oshape: [C1 A]
        wshape = [C C1]

        Needs to become
        ishape: [C A]
        oshape: [C2 A]
        wshape = [C C2]

        einsum(weight.conj(), weight, 'C2 C1, C C1 -> C C2')


        """
        new_oshape = []
        weight_conj_shape = list(deepcopy(self.weightshape))
        wdiag_shape = []
        wout_shape = []
        win_shape = []
        used_shapes = self.ishape + self.oshape
        shape_updates = {}
        # Make new oshape and weight shape
        # Rules:
        # New weightshape
        #   If dim appears in ishape and weightshape but not oshape -> increment
        #   If dim appears in ishape and weightshape AND oshape -> don't increment
        #   If dim doesn't appear in ishape or weightshape -> don't add it to new weightshape
        # Other rules:
        # new ishape is same as old ishape
        # new oshape is ishape but updated with new dimensions
        for dim in self.ishape:
            if dim in self.weightshape:
                if dim not in self.oshape:
                    win_shape.append(dim)
                    new_dim = dim.next_unused(used_shapes)
                    shape_updates[dim] = new_dim
                    wout_shape.append(new_dim)
                else:
                    wdiag_shape.append(dim)
                    new_dim = dim
                i = weight_conj_shape.index(dim)
                weight_conj_shape[i] = new_dim
            else:
                new_dim = dim
            new_oshape.append(new_dim)

        if config.inner_not_relevant(inner):
            # Consolidate dense and dense adjoint into single dense
            new_weight_shape = wdiag_shape + wout_shape + win_shape
            einstr = shapes2einstr(
                self.weightshape,
                weight_conj_shape,
                new_weight_shape,
            )
            new_weight = einsum(self.weight, self.weight.conj(), einstr)
            normal = type(self)(
                new_weight,
                tuple(new_weight_shape),
                self.ishape,
                new_oshape,
            )
            normal._name = self._name
            normal._update_suffix(normal=self._name is not None)
            normal._shape_updates = shape_updates
            return normal
        _shape_updates = getattr(inner, "_shape_updates", {})
        _shape_updates.update(shape_updates)
        pre = copy(self)
        pre.oshape = inner.ishape
        post = self.adjoint()  # Copy happens inside adjoint
        post.ishape = inner.oshape
        post.oshape = new_oshape
        normal = post @ inner @ pre
        normal._shape_updates = _shape_updates
        return normal

    def split_forward(self, ibatch, obatch):
        weight = self.split_weight(ibatch, obatch, self.weight)
        out = copy(self)
        out.weight = nn.Parameter(weight, requires_grad=self.weight.requires_grad)
        return out

    def split_weight(self, ibatch, obatch, /, weight):
        weightbatch = [slice(None)] * len(self.weightshape)
        for dim, batch in zip(self.ishape, ibatch):
            if dim in self.weightshape and dim not in self.broadcast_dims:
                weightbatch[self.weightshape.index(dim)] = batch
        for dim, batch in zip(self.oshape, obatch):
            if dim in self.weightshape and dim not in self.broadcast_dims:
                weightbatch[self.weightshape.index(dim)] = batch
        return weight[tuple(weightbatch)]

    def size(self, dim: str):
        if dim in self.broadcast_dims:
            return None
        if dim in self.weightshape:
            return self.weight.shape[self.weightshape.index(dim)]
        return None
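To make the `forward_einstr` and `adj_einstr` properties concrete, here is a small standalone sketch of the strings they would produce for the batched example in the class docstring. It assumes dimension names are plain strings (in the library they may be richer objects whose `str()` gives the name) and reimplements `einstr` locally rather than importing torchlinops.

```python
def einstr(shape):
    """Join dimension names with spaces, mirroring Dense.einstr."""
    return " ".join(str(s) for s in shape)


# Shapes from the batched example in the class docstring.
ishape = ("C", "A", "Nx", "Ny")
weightshape = ("C", "A", "A1")
oshape = ("C", "A1", "Nx", "Ny")

# The forward pattern maps input and weight to output; the adjoint pattern
# swaps the roles of ishape and oshape and reuses the same weight labels.
forward = f"{einstr(ishape)},{einstr(weightshape)}->{einstr(oshape)}"
adjoint = f"{einstr(oshape)},{einstr(weightshape)}->{einstr(ishape)}"

print(forward)  # C A Nx Ny,C A A1->C A1 Nx Ny
print(adjoint)  # C A1 Nx Ny,C A A1->C A Nx Ny
```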

__init__

__init__(
    weight: Tensor,
    weightshape: Shape,
    ishape: Shape,
    oshape: Shape,
    broadcast_dims: Optional[list] = None,
)
PARAMETER DESCRIPTION

  • weight (TYPE: Tensor): The dense matrix used for this linop.
  • weightshape (TYPE: Shape): The shape of the matrix, in symbolic form.
  • ishape (TYPE: Shape): The input shape of the matrix.
  • oshape (TYPE: Shape): The output shape of the matrix.
  • broadcast_dims (TYPE: list, DEFAULT: None): A list of the dimensions of weight that are intended to be broadcast over the input. As such, they are excluded from splitting.
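The effect of `broadcast_dims` on splitting can be illustrated with a standalone sketch of the slicing logic in `split_weight`. This is a simplified reimplementation on a NumPy array, assuming string dimension names and a weight that has size 1 along the broadcast dimension; the helper signature and the example values are hypothetical.

```python
import numpy as np


def split_weight(weight, weightshape, ishape, oshape, ibatch, obatch, broadcast_dims=()):
    """Slice `weight` to match a tile of the input/output, skipping broadcast dims."""
    weightbatch = [slice(None)] * len(weightshape)
    for shape, batch in ((ishape, ibatch), (oshape, obatch)):
        for dim, slc in zip(shape, batch):
            if dim in weightshape and dim not in broadcast_dims:
                weightbatch[weightshape.index(dim)] = slc
    return weight[tuple(weightbatch)]


W = np.arange(4).reshape(1, 4)        # weightshape ("A", "T"), size 1 along A
ishape, oshape = ("A", "Nx"), ("T", "Nx")
ibatch = (slice(2, 4), slice(None))   # tile covering A[2:4]
obatch = (slice(1, 3), slice(None))   # tile covering T[1:3]

# Without broadcast_dims, the A slice is applied to the size-1 axis and empties it.
naive = split_weight(W, ("A", "T"), ishape, oshape, ibatch, obatch)
# With A declared as a broadcast dim, only T is sliced.
kept = split_weight(W, ("A", "T"), ishape, oshape, ibatch, obatch, broadcast_dims=("A",))
```

In this sketch `naive` ends up with an empty first axis, while `kept` retains the single row that is meant to be shared across every tile of `A`.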

Source code in src/torchlinops/linops/dense.py
def __init__(
    self,
    weight: Tensor,
    weightshape: Shape,
    ishape: Shape,
    oshape: Shape,
    broadcast_dims: Optional[list] = None,
):
    """
    Parameters
    ----------
    weight : Tensor
        The dense matrix used for this linop.
    weightshape : Shape
        The shape of the matrix, in symbolic form.
    ishape : Shape
        The input shape of the matrix.
    oshape : Shape
        The output shape of the matrix.
    broadcast_dims : list, optional
        A list of the dimensions of weight that are intended to be broadcast over the input.
        As such, they are excluded from splitting.
    """
    super().__init__(NS(ishape, oshape))
    self.weight = nn.Parameter(weight, requires_grad=False)
    self._shape.weightshape = weightshape

    broadcast_dims = broadcast_dims if broadcast_dims is not None else []
    self._shape.broadcast_dims = broadcast_dims

normal

normal(inner=None)

Compute the normal operator (adjoint times forward).

PARAMETER DESCRIPTION

  • inner (TYPE: NamedLinop, DEFAULT: None): An optional inner operator to sandwich between the adjoint and forward. If None, consolidates two Dense operators into a single Dense.

RETURNS DESCRIPTION

  • NamedLinop: The normal operator.

Notes

If inner is None, the two Dense operators are consolidated into a single Dense. For example, given:

  • ishape: [A B X Y]
  • oshape: [C D X Y]
  • wshape: [A B C D]

the normal operator needs to become:

  • ishape: [A B X Y]
  • oshape: [A1 B1 X Y]
  • wshape: [A B A1 B1]

The new weight is obtained as einsum(weight.conj(), weight, 'A1 B1 C D, A B C D -> A B A1 B1').

A second example, given:

  • ishape: [C A]
  • oshape: [C1 A]
  • wshape: [C C1]

the normal operator needs to become:

  • ishape: [C A]
  • oshape: [C2 A]
  • wshape: [C C2]

The new weight is obtained as einsum(weight.conj(), weight, 'C2 C1, C C1 -> C C2').
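The second case can be checked numerically. The sketch below uses `numpy.einsum` with single-letter indices (c = C, k = C1, d = C2) in place of the named `einops` patterns, and compares applying forward then adjoint against applying one consolidated weight. The conjugate factor is labelled with its input axis renamed to C2, by analogy with the first case; the sizes and variable names are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
C, C1, A = 3, 4, 5  # illustrative sizes

W = rng.standard_normal((C, C1)) + 1j * rng.standard_normal((C, C1))
x = rng.standard_normal((C, A)) + 1j * rng.standard_normal((C, A))

# Forward: 'C A, C C1 -> C1 A'
z = np.einsum("ca,ck->ka", x, W)
# Adjoint: 'C1 A, C C1 -> C A', using the conjugated weight
y_two_step = np.einsum("ka,ck->ca", z, W.conj())

# Consolidated weight: rename C -> C2 on the conjugate factor, sum over C1.
N = np.einsum("dk,ck->cd", W.conj(), W)    # 'C2 C1, C C1 -> C C2'
y_one_step = np.einsum("ca,cd->da", x, N)  # 'C A, C C2 -> C2 A'
```

Both routes should agree up to floating-point error, and the consolidated weight is Hermitian when viewed as a matrix, as expected for a normal operator.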

Source code in src/torchlinops/linops/dense.py
def normal(self, inner=None):
    """Compute the normal operator (adjoint times forward).

    Parameters
    ----------
    inner : NamedLinop, optional
        An optional inner operator to sandwich between the adjoint and
        forward. If None, consolidates two Dense operators into a single
        Dense.

    Returns
    -------
    NamedLinop
        The normal operator.

    Notes
    -----
    If inner is None, consolidate two Dense's into a single Dense
    ishape: [A B X Y]
    oshape: [C D X Y]
    wshape: [A B C D]

    Needs to become
    ishape: [A B X Y]
    oshape: [A1 B1 X Y]
    wshape: [A B A1 B1]

    New weight is attained as
    einsum(weight.conj(), weight, 'A1 B1 C D, A B C D -> A B A1 B1')

    -----
    ishape: [C A]
    oshape: [C1 A]
    wshape = [C C1]

    Needs to become
    ishape: [C A]
    oshape: [C2 A]
    wshape = [C C2]

    einsum(weight.conj(), weight, 'C2 C1, C C1 -> C C2')


    """
    new_oshape = []
    weight_conj_shape = list(deepcopy(self.weightshape))
    wdiag_shape = []
    wout_shape = []
    win_shape = []
    used_shapes = self.ishape + self.oshape
    shape_updates = {}
    # Make new oshape and weight shape
    # Rules:
    # New weightshape
    #   If dim appears in ishape and weightshape but not oshape -> increment
    #   If dim appears in ishape and weightshape AND oshape -> don't increment
    #   If dim doesn't appear in ishape or weightshape -> don't add it to new weightshape
    # Other rules:
    # new ishape is same as old ishape
    # new oshape is ishape but updated with new dimensions
    for dim in self.ishape:
        if dim in self.weightshape:
            if dim not in self.oshape:
                win_shape.append(dim)
                new_dim = dim.next_unused(used_shapes)
                shape_updates[dim] = new_dim
                wout_shape.append(new_dim)
            else:
                wdiag_shape.append(dim)
                new_dim = dim
            i = weight_conj_shape.index(dim)
            weight_conj_shape[i] = new_dim
        else:
            new_dim = dim
        new_oshape.append(new_dim)

    if config.inner_not_relevant(inner):
        # Consolidate dense and dense adjoint into single dense
        new_weight_shape = wdiag_shape + wout_shape + win_shape
        einstr = shapes2einstr(
            self.weightshape,
            weight_conj_shape,
            new_weight_shape,
        )
        new_weight = einsum(self.weight, self.weight.conj(), einstr)
        normal = type(self)(
            new_weight,
            tuple(new_weight_shape),
            self.ishape,
            new_oshape,
        )
        normal._name = self._name
        normal._update_suffix(normal=self._name is not None)
        normal._shape_updates = shape_updates
        return normal
    _shape_updates = getattr(inner, "_shape_updates", {})
    _shape_updates.update(shape_updates)
    pre = copy(self)
    pre.oshape = inner.ishape
    post = self.adjoint()  # Copy happens inside adjoint
    post.ishape = inner.oshape
    post.oshape = new_oshape
    normal = post @ inner @ pre
    normal._shape_updates = _shape_updates
    return normal
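The shape bookkeeping in the loop above can be sketched without any tensors. The snippet below follows the same rules on plain string dimension names. `next_unused` here is a hypothetical stand-in for the dimension method of the same name, assumed to append the smallest integer suffix that is not already in use; the real behaviour may differ.

```python
def next_unused(dim, used):
    """Hypothetical stand-in: smallest suffix k such that f"{dim}{k}" is unused."""
    k = 1
    while f"{dim}{k}" in used:
        k += 1
    return f"{dim}{k}"


def normal_shapes(ishape, oshape, weightshape):
    """Return (new_oshape, weight_conj_shape, new_weight_shape) for the normal op."""
    used = list(ishape) + list(oshape)
    conj_shape = list(weightshape)
    wdiag, wout, win, new_oshape = [], [], [], []
    for dim in ishape:
        new_dim = dim
        if dim in weightshape:
            if dim not in oshape:   # contracted by the forward op: needs a fresh name
                new_dim = next_unused(dim, used)
                win.append(dim)
                wout.append(new_dim)
            else:                   # shared batch dim: kept under the same name
                wdiag.append(dim)
            conj_shape[conj_shape.index(dim)] = new_dim
        new_oshape.append(new_dim)
    return new_oshape, conj_shape, wdiag + wout + win


new_oshape, conj_shape, new_wshape = normal_shapes(
    ("A", "B", "X", "Y"), ("C", "D", "X", "Y"), ("A", "B", "C", "D")
)
print(new_oshape)  # ['A1', 'B1', 'X', 'Y']
print(conj_shape)  # ['A1', 'B1', 'C', 'D']
print(new_wshape)  # ['A1', 'B1', 'A', 'B']
```

Note that the consolidated weight shape comes out ordered as diagonal, output, then input dimensions. Since the dimensions are named, this ordering only affects the memory layout of the new weight, not the operator it represents.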