Creating Custom Linear Operators

This tutorial shows how to create your own NamedLinop subclasses. Every operator in torchlinops follows the same pattern, so once you understand the interface you can wrap any linear operation you need.

The NamedLinop Interface

To define a custom operator you need:

  1. __init__ — call super().__init__(NamedShape(ishape, oshape)) to declare the named input and output dimensions.
  2. fn(linop, x, /) — a @staticmethod that computes the forward operation \(y = A x\).
  3. adj_fn(linop, x, /) — a @staticmethod that computes the adjoint operation \(y = A^H x\).

Optionally, you can also override normal_fn (to compute \(A^H A x\) more efficiently than applying fn followed by adj_fn), split_forward (for multi-GPU tiling), and size (to report dimension sizes).
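As a point of reference, the three methods correspond to familiar matrix operations. The sketch below is plain PyTorch, not torchlinops: the functions `forward`, `adjoint`, and `normal` are hypothetical stand-ins for fn, adj_fn, and normal_fn applied to an explicit matrix `A_mat`. Precomputing the Gram matrix \(A^H A\) is just one example of the kind of shortcut a custom normal_fn might take.

```python
import torch

g = torch.Generator().manual_seed(0)  # local generator; leaves the global seed alone
A_mat = torch.randn(5, 8, generator=g, dtype=torch.complex64)
gram = A_mat.conj().T @ A_mat  # A^H A, computed once up front


def forward(x):
    # Role of fn: y = A x
    return A_mat @ x


def adjoint(y):
    # Role of adj_fn: the conjugate transpose applied to y, A^H y
    return A_mat.conj().T @ y


def normal(x):
    # Role of normal_fn: A^H A x in a single matrix-vector product
    return gram @ x


x = torch.randn(8, generator=g, dtype=torch.complex64)
print(forward(x).shape)  # torch.Size([5])
print(torch.allclose(normal(x), adjoint(forward(x)), atol=1e-4))
```

The last line should print True: the fused normal operator and the two-pass version agree up to floating-point rounding.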

Setup

import torch
from torch import Tensor

from torchlinops import Dense, Dim, NamedLinop
from torchlinops.nameddim import NamedShape as NS
from torchlinops.utils import is_adjoint

torch.manual_seed(0)

Out:

<torch._C.Generator object at 0x7f2c2490fc10>

Example 1: Diagonal Scaling

The simplest useful operator multiplies each element by a weight vector. This is mathematically \(y = w \odot x\) (elementwise product).

We store the weight as an nn.Parameter so that it moves with the module when you call .to(device). The fn and adj_fn static methods receive the linop instance as their first argument; that is how they can read linop.weight even though they are not bound methods.

class DiagScale(NamedLinop):
    """Elementwise scaling: y = w * x."""

    def __init__(self, weight: Tensor, ioshape):
        # For a diagonal operator the input and output shapes are the same,
        # so we pass ioshape for both.
        super().__init__(NS(ioshape, ioshape))
        self.weight = torch.nn.Parameter(weight, requires_grad=False)

    @staticmethod
    def fn(linop, x, /):
        return x * linop.weight

    @staticmethod
    def adj_fn(linop, x, /):
        # The adjoint of elementwise multiplication by w is multiplication
        # by conj(w).
        return x * torch.conj(linop.weight)

    @staticmethod
    def normal_fn(linop, x, /):
        # A^H A x = |w|^2 * x — avoids two separate passes.
        return x * torch.abs(linop.weight) ** 2
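DiagScale relies on fn and adj_fn being static methods that take the instance explicitly. The toy classes below sketch one way a base class could dispatch to such methods. `ToyLinop`, `ToyScale`, `apply`, and `apply_adjoint` are hypothetical names in plain Python and PyTorch; the tutorial does not show how NamedLinop dispatches internally, so treat this only as an illustration of the calling convention.

```python
import torch


class ToyLinop:
    """Hypothetical base class that dispatches to static fn/adj_fn methods."""

    def apply(self, x):
        # Look up fn on the concrete subclass and pass the instance in explicitly.
        return type(self).fn(self, x)

    def apply_adjoint(self, x):
        return type(self).adj_fn(self, x)


class ToyScale(ToyLinop):
    def __init__(self, weight):
        self.weight = weight

    @staticmethod
    def fn(linop, x, /):
        # No `self`: the instance arrives as the first positional argument.
        return x * linop.weight

    @staticmethod
    def adj_fn(linop, x, /):
        return x * torch.conj(linop.weight)


w_toy = torch.tensor([1 + 2j, 3 - 1j])
op = ToyScale(w_toy)
x_toy = torch.tensor([1 + 0j, 1j])
print(op.apply(x_toy))
print(op.apply_adjoint(x_toy))
```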

Let's create an instance and try it out.

N = 8
w = torch.randn(N, dtype=torch.complex64)
D = DiagScale(w, ioshape=Dim("N"))

x = torch.randn(N, dtype=torch.complex64)
y = D(x)
print("D(x) =", y)
print("D.H(y) =", D.H(y))

Out:

D(x) = tensor([-0.2157+1.7340j,  0.1011-0.2223j,  0.7922-0.4527j,  2.0137+0.0682j,
        -0.2490-0.5682j, -0.0586+0.0054j, -0.8984+0.9548j,  0.4241-0.5881j])
D.H(y) = tensor([-1.2412-1.5562j,  0.0503+0.0704j,  0.2539-0.6593j, -0.5520+2.9966j,
         0.4509-0.3519j, -0.0133+0.0141j,  0.7595+0.8672j,  0.4377-0.3903j])
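Since DiagScale is multiplication by the diagonal matrix \(\operatorname{diag}(w)\), each of its methods can be cross-checked against an explicit matrix. This plain-PyTorch sketch (independent of torchlinops; the variable names are illustrative) confirms that the elementwise formulas in fn, adj_fn, and normal_fn agree with \(\operatorname{diag}(w)\), \(\operatorname{diag}(w)^H = \operatorname{diag}(\bar w)\), and \(\operatorname{diag}(w)^H \operatorname{diag}(w) = \operatorname{diag}(|w|^2)\).

```python
import torch

g = torch.Generator().manual_seed(0)
n = 8
w_vec = torch.randn(n, generator=g, dtype=torch.complex64)
x_vec = torch.randn(n, generator=g, dtype=torch.complex64)
W = torch.diag(w_vec)  # explicit n x n diagonal matrix

fwd = x_vec * w_vec                  # what DiagScale.fn computes
adj = x_vec * torch.conj(w_vec)      # what DiagScale.adj_fn computes
nrm = x_vec * torch.abs(w_vec) ** 2  # what DiagScale.normal_fn computes

print(torch.allclose(fwd, W @ x_vec, atol=1e-5))
print(torch.allclose(adj, W.conj().T @ x_vec, atol=1e-5))
print(torch.allclose(nrm, W.conj().T @ (W @ x_vec), atol=1e-5))
```

All three lines should print True. The elementwise versions cost \(O(N)\) work instead of the \(O(N^2)\) of a dense matrix-vector product, which is the reason to write a dedicated operator.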

Testing the Adjoint

A correct adjoint must satisfy the identity \(\langle y,\, Ax \rangle = \langle A^H y,\, x \rangle\) for all \(x, y\). The helper is_adjoint checks this numerically.

x_test = torch.randn(N, dtype=torch.complex64)
y_test = torch.randn(N, dtype=torch.complex64)
passed = is_adjoint(D, x_test, y_test)
print(f"Adjoint test passed: {passed}")

Out:

Adjoint test passed: True
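For intuition, the dot-product test can also be written by hand. The helper `dot_test` below is hypothetical and works on an explicit matrix; it illustrates the identity and is not a description of how is_adjoint is implemented. Note that `torch.vdot` conjugates its first argument, so `torch.vdot(y, A @ x)` equals \(y^H A x = \langle y, Ax \rangle\).

```python
import torch


def dot_test(fwd, adj, x, y, rtol=1e-4, atol=1e-4):
    """Hypothetical helper: compare <y, A x> with <A^H y, x>."""
    lhs = torch.vdot(y, fwd(x))  # vdot conjugates its first argument: y^H (A x)
    rhs = torch.vdot(adj(y), x)  # (A^H y)^H x
    return torch.allclose(lhs, rhs, rtol=rtol, atol=atol)


g = torch.Generator().manual_seed(0)
A_mat = torch.randn(5, 8, generator=g, dtype=torch.complex64)
x_dt = torch.randn(8, generator=g, dtype=torch.complex64)
y_dt = torch.randn(5, generator=g, dtype=torch.complex64)

good = dot_test(lambda v: A_mat @ v, lambda v: A_mat.conj().T @ v, x_dt, y_dt)
# A common bug: using the plain transpose and forgetting the conjugate.
bad = dot_test(lambda v: A_mat @ v, lambda v: A_mat.T @ v, x_dt, y_dt)
print(good, bad)
```

`good` should be True and `bad` False. The plain transpose is only the adjoint for real matrices, which is why testing with complex vectors catches a missing conj.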

Example 2: Permutation Operator

A permutation operator reorders the elements of a vector according to a fixed index mapping. Its adjoint is the inverse permutation (which is also its transpose, since permutation matrices are orthogonal).

class Permute(NamedLinop):
    """Reorder elements: y[i] = x[perm[i]]."""

    def __init__(self, perm: Tensor, ishape, oshape):
        super().__init__(NS(ishape, oshape))
        # Store perm and its inverse as non-trainable parameters so they
        # travel with the module (e.g. when calling .to(device)).
        self.perm = torch.nn.Parameter(perm, requires_grad=False)
        inv = torch.empty_like(perm)
        inv[perm] = torch.arange(len(perm), device=perm.device)
        self.inv_perm = torch.nn.Parameter(inv, requires_grad=False)

    @staticmethod
    def fn(linop, x, /):
        return x[linop.perm]

    @staticmethod
    def adj_fn(linop, x, /):
        # The adjoint of a permutation is the inverse permutation.
        return x[linop.inv_perm]

    def size(self, dim):
        # Input and output dimensions both have length len(perm).
        if dim in self.ishape or dim in self.oshape:
            return len(self.perm)
        return None

Create a random permutation and verify it.

M = 6
perm = torch.randperm(M)
P = Permute(perm, ishape=Dim("X"), oshape=Dim("Y"))

x_perm = torch.arange(M, dtype=torch.float32)
print("x       =", x_perm)
print("P(x)    =", P(x_perm))
print("P.H(P(x)) =", P.H(P(x_perm)))  # should recover x

Out:

x       = tensor([0., 1., 2., 3., 4., 5.])
P(x)    = tensor([1., 0., 4., 2., 3., 5.])
P.H(P(x)) = tensor([0., 1., 2., 3., 4., 5.])
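The claim that a permutation's adjoint is its inverse (and its transpose) can be checked against an explicit permutation matrix. This plain-PyTorch sketch builds one from a random index map; `P_mat`, `perm_ex`, and `inv_ex` are illustrative names chosen so they do not overwrite the tutorial's `P` and `perm`.

```python
import torch

g = torch.Generator().manual_seed(0)
m = 6
perm_ex = torch.randperm(m, generator=g)

# Explicit permutation matrix: P_mat[i, perm_ex[i]] = 1, so (P_mat @ v)[i] = v[perm_ex[i]].
P_mat = torch.zeros(m, m)
P_mat[torch.arange(m), perm_ex] = 1.0

# Inverse index map, built with the same scatter trick as Permute.__init__.
inv_ex = torch.empty_like(perm_ex)
inv_ex[perm_ex] = torch.arange(m)

v = torch.randn(m, generator=g)
print(torch.allclose(P_mat @ v, v[perm_ex]))           # forward: matrix vs. indexing
print(torch.allclose(P_mat.T @ v, v[inv_ex]))          # adjoint: transpose vs. inverse indexing
print(torch.allclose(P_mat.T @ P_mat, torch.eye(m)))   # P is orthogonal
print(torch.equal(inv_ex, torch.argsort(perm_ex)))     # argsort also inverts a permutation
```

`torch.argsort(perm)` is an alternative way to obtain the inverse index map: sorting the values \(0, \dots, m-1\) returns the position at which each value sits.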

The adjoint test should pass for the permutation operator too.

x_test2 = torch.randn(M)
y_test2 = torch.randn(M)
passed2 = is_adjoint(P, x_test2, y_test2)
print(f"Permutation adjoint test passed: {passed2}")

Out:

Permutation adjoint test passed: True

Composing Custom and Built-in Operators

One of the main benefits of the NamedLinop system is easy composition via the @ operator. Here we compose our DiagScale with a built-in Dense matrix operator to form \(A = D \, M\), where \(D\) is a diagonal scaling and \(M\) is a dense matrix. As in matrix multiplication, the right-hand operator is applied first.

K = 5
M_weight = torch.randn(K, N, dtype=torch.complex64)
M_op = Dense(
    weight=M_weight,
    weightshape=Dim("KN"),
    ishape=Dim("N"),
    oshape=Dim("K"),
)

# Compose: M_op maps N -> K first, then D_k scales along the K dimension
D_k = DiagScale(
    torch.randn(K, dtype=torch.complex64),
    ioshape=Dim("K"),
)
A = D_k @ M_op
print("Composed operator:", A)

Out:

Composed operator: Chain(
    Dense((N,) -> (K,))
    DiagScale((K,) -> (K,))
)
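The printed Chain lists Dense first, which suggests it stores operators in the order they are applied: in `D_k @ M_op`, as in matrix multiplication, the right-hand operator acts first. In explicit matrix form this is \(A = \operatorname{diag}(w)\,M\). A plain-PyTorch sketch of the equivalence, with illustrative names so the tutorial's `A` is not overwritten:

```python
import torch

g = torch.Generator().manual_seed(0)
n, k = 8, 5
M_mat = torch.randn(k, n, generator=g, dtype=torch.complex64)  # role of M_op
w_k = torch.randn(k, generator=g, dtype=torch.complex64)       # role of D_k's weight
x_c = torch.randn(n, generator=g, dtype=torch.complex64)

# Right-hand factor first, then the left-hand one: A x = D (M x).
y_two_steps = w_k * (M_mat @ x_c)

# The same operator as one explicit matrix, A = diag(w) M.
A_mat = torch.diag(w_k) @ M_mat
print(A_mat.shape)  # torch.Size([5, 8])
print(torch.allclose(y_two_steps, A_mat @ x_c, atol=1e-4))
```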

The composed operator automatically supports adjoint and normal operations.

x_comp = torch.randn(N, dtype=torch.complex64)
y_comp = A(x_comp)
x_adj = A.H(y_comp)
print(f"Forward:  {x_comp.shape} -> {y_comp.shape}")
print(f"Adjoint:  {y_comp.shape} -> {x_adj.shape}")

x_t = torch.randn(N, dtype=torch.complex64)
y_t = torch.randn(K, dtype=torch.complex64)
print(f"Composed adjoint test passed: {is_adjoint(A, x_t, y_t)}")

Out:

Forward:  torch.Size([8]) -> torch.Size([5])
Adjoint:  torch.Size([5]) -> torch.Size([8])
Composed adjoint test passed: True
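The adjoint of a product reverses the order and conjugate-transposes each factor, \((DM)^H = M^H D^H\), and the normal operator is \(A^H A = M^H D^H D M\); both can therefore be assembled from the factors' own adjoints. The sketch below checks these identities with explicit tensors in plain PyTorch; it does not reproduce how Chain is implemented, and the names are illustrative.

```python
import torch

g = torch.Generator().manual_seed(0)
n, k = 8, 5
M_mat = torch.randn(k, n, generator=g, dtype=torch.complex64)
w_k = torch.randn(k, generator=g, dtype=torch.complex64)
x_c = torch.randn(n, generator=g, dtype=torch.complex64)
y_c = torch.randn(k, generator=g, dtype=torch.complex64)

A_mat = torch.diag(w_k) @ M_mat  # explicit composed matrix, A = D M

# Adjoint factor by factor, in reverse order: apply D^H first, then M^H.
x_adj = M_mat.conj().T @ (torch.conj(w_k) * y_c)
print(torch.allclose(x_adj, A_mat.conj().T @ y_c, atol=1e-4))

# Normal operator factor by factor: M^H (|w|^2 * (M x)), using D^H D = diag(|w|^2).
x_nrm = M_mat.conj().T @ (torch.abs(w_k) ** 2 * (M_mat @ x_c))
print(torch.allclose(x_nrm, A_mat.conj().T @ (A_mat @ x_c), atol=1e-4))
```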

Summary

To create a custom NamedLinop:

  1. Subclass NamedLinop and call super().__init__(NS(ishape, oshape)) in __init__.
  2. Define fn and adj_fn as @staticmethod methods with signature (linop, x, /).
  3. Optionally define normal_fn for an efficient \(A^H A\).
  4. Use is_adjoint to check numerically that adj_fn really is the adjoint of fn.
  5. Compose freely with other operators using @.

Total running time of the script: ( 0 minutes 0.013 seconds)

Gallery generated by mkdocs-gallery