Creating Custom Linear Operators
This tutorial shows how to create your own NamedLinop subclasses.
Every operator in torchlinops follows the same pattern, so once you
understand the interface you can wrap any linear operation you need.
The NamedLinop Interface
To define a custom operator you need:
- __init__ — call super().__init__(NamedShape(ishape, oshape)) to declare the named input and output dimensions.
- fn(linop, x, /) — a @staticmethod that computes the forward operation \(y = A x\).
- adj_fn(linop, x, /) — a @staticmethod that computes the adjoint operation \(y = A^H x\).
Optionally you can also override normal_fn (for an efficient \(A^H A\)),
split_forward (for multi-GPU tiling), and size (to report dimension
sizes).
Setup
import torch
from torch import Tensor
from torchlinops import Dense, Dim, NamedLinop
from torchlinops.nameddim import NamedShape as NS
from torchlinops.utils import is_adjoint
torch.manual_seed(0)
Example 1: Diagonal Scaling
The simplest useful operator multiplies each element by a weight vector. This is mathematically \(y = w \odot x\) (elementwise product).
We store the weight as an nn.Parameter so that it moves with the
module when you call .to(device). The fn and adj_fn static
methods receive the linop instance as their first argument — this is how
they access self.weight without being regular methods.
import torch.nn as nn

class DiagScale(NamedLinop):
    """Elementwise scaling: y = w * x."""

    def __init__(self, weight: Tensor, ioshape):
        # For a diagonal operator the input and output shapes are the same,
        # so we pass ioshape for both.
        super().__init__(NS(ioshape, ioshape))
        self.weight = nn.Parameter(weight, requires_grad=False)

    @staticmethod
    def fn(linop, x, /):
        return x * linop.weight

    @staticmethod
    def adj_fn(linop, x, /):
        # The adjoint of elementwise multiplication by w is multiplication
        # by conj(w).
        return x * torch.conj(linop.weight)

    @staticmethod
    def normal_fn(linop, x, /):
        # A^H A x = |w|^2 * x — avoids two separate passes.
        return x * torch.abs(linop.weight) ** 2
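The normal_fn shortcut relies on the identity \(\overline{w} \, (w \, x) = |w|^2 x\). Independent of torchlinops, this can be checked in plain Python with a single complex number:

```python
# Plain-Python check that the normal_fn shortcut agrees with applying
# fn followed by adj_fn: conj(w) * (w * x) == |w|^2 * x elementwise.
w = complex(0.6, -0.8)
x = complex(1.5, 2.0)

two_pass = w.conjugate() * (w * x)   # adj_fn(fn(x))
shortcut = abs(w) ** 2 * x           # normal_fn(x)
print(abs(two_pass - shortcut) < 1e-12)  # True
```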
Let's create an instance and try it out.
N = 8
w = torch.randn(N, dtype=torch.complex64)
D = DiagScale(w, ioshape=Dim("N"))
x = torch.randn(N, dtype=torch.complex64)
y = D(x)
print("D(x) =", y)
print("D.H(y) =", D.H(y))
Out:
D(x) = tensor([-0.2157+1.7340j, 0.1011-0.2223j, 0.7922-0.4527j, 2.0137+0.0682j,
-0.2490-0.5682j, -0.0586+0.0054j, -0.8984+0.9548j, 0.4241-0.5881j])
D.H(y) = tensor([-1.2412-1.5562j, 0.0503+0.0704j, 0.2539-0.6593j, -0.5520+2.9966j,
0.4509-0.3519j, -0.0133+0.0141j, 0.7595+0.8672j, 0.4377-0.3903j])
Testing the Adjoint
A correct adjoint must satisfy the identity
\(\langle y,\, Ax \rangle = \langle A^H y,\, x \rangle\) for all \(x, y\).
The helper is_adjoint checks this numerically.
x_test = torch.randn(N, dtype=torch.complex64)
y_test = torch.randn(N, dtype=torch.complex64)
passed = is_adjoint(D, x_test, y_test)
print(f"Adjoint test passed: {passed}")
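The same dot-test can be reproduced without torchlinops. The sketch below checks \(\langle Ax, y \rangle = \langle x, A^H y \rangle\) for elementwise scaling, using the complex inner product \(\langle a, b \rangle = \sum_i \overline{a_i}\, b_i\) (this mirrors what is_adjoint does numerically, not its actual implementation):

```python
# Plain-Python adjoint dot-test for A = diag(w): <Ax, y> == <x, A^H y>.
import random

def inner(a, b):
    # Complex inner product <a, b> = sum(conj(a_i) * b_i)
    return sum(ai.conjugate() * bi for ai, bi in zip(a, b))

random.seed(0)
N = 8
w = [complex(random.gauss(0, 1), random.gauss(0, 1)) for _ in range(N)]
x = [complex(random.gauss(0, 1), random.gauss(0, 1)) for _ in range(N)]
y = [complex(random.gauss(0, 1), random.gauss(0, 1)) for _ in range(N)]

Ax = [wi * xi for wi, xi in zip(w, x)]               # forward: diag(w) x
AHy = [wi.conjugate() * yi for wi, yi in zip(w, y)]  # adjoint: conj(w) * y

lhs = inner(Ax, y)
rhs = inner(x, AHy)
print(abs(lhs - rhs) < 1e-12)  # True
```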
Example 2: Permutation Operator
A permutation operator reorders the elements of a vector according to a fixed index mapping. Its adjoint is the inverse permutation (which is also its transpose, since permutation matrices are orthogonal).
import torch.nn as nn

class Permute(NamedLinop):
    """Reorder elements: y[i] = x[perm[i]]."""

    def __init__(self, perm: Tensor, ishape, oshape):
        super().__init__(NS(ishape, oshape))
        # Store perm and its inverse as non-trainable parameters so they
        # travel with the module when it moves between devices.
        self.perm = nn.Parameter(perm, requires_grad=False)
        inv = torch.empty_like(perm)
        inv[perm] = torch.arange(len(perm))
        self.inv_perm = nn.Parameter(inv, requires_grad=False)

    @staticmethod
    def fn(linop, x, /):
        return x[linop.perm]

    @staticmethod
    def adj_fn(linop, x, /):
        # The adjoint of a permutation is the inverse permutation.
        return x[linop.inv_perm]

    def size(self, dim):
        # Input and output dimensions have the same length.
        if dim in self.ishape or dim in self.oshape:
            return len(self.perm)
        return None
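The inverse-permutation construction in __init__ (inv[perm] = arange) sets inv[perm[i]] = i, so indexing by perm and then by inv recovers the original order. A plain-Python illustration:

```python
# Plain-Python illustration of the inverse permutation used in
# Permute.__init__: inv[perm[i]] = i, so x[perm] followed by [inv]
# recovers x.
perm = [2, 0, 3, 1]
inv = [0] * len(perm)
for i, p in enumerate(perm):
    inv[p] = i                 # mirrors inv[perm] = arange(len(perm))

x = ["a", "b", "c", "d"]
y = [x[p] for p in perm]       # forward: y[i] = x[perm[i]]
x_back = [y[q] for q in inv]   # adjoint/inverse: undo the reorder
print(y)       # ['c', 'a', 'd', 'b']
print(x_back)  # ['a', 'b', 'c', 'd']
```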
Create a random permutation and verify it.
M = 6
perm = torch.randperm(M)
P = Permute(perm, ishape=Dim("X"), oshape=Dim("Y"))
x_perm = torch.arange(M, dtype=torch.float32)
print("x =", x_perm)
print("P(x) =", P(x_perm))
print("P.H(P(x)) =", P.H(P(x_perm))) # should recover x
Out:
x = tensor([0., 1., 2., 3., 4., 5.])
P(x) = tensor([1., 0., 4., 2., 3., 5.])
P.H(P(x)) = tensor([0., 1., 2., 3., 4., 5.])
The adjoint test should pass for the permutation operator too.
x_test2 = torch.randn(M)
y_test2 = torch.randn(M)
passed2 = is_adjoint(P, x_test2, y_test2)
print(f"Permutation adjoint test passed: {passed2}")
Composing Custom and Built-in Operators
One of the main benefits of the NamedLinop system is easy composition
via the @ operator. Here we compose our DiagScale with a built-in
Dense matrix operator to form \(A = D \, M\) where \(D\) is diagonal
scaling and \(M\) is a dense matrix.
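The adjoint of a composition reverses the order of application: \((D M)^H = M^H D^H\). Before relying on torchlinops to handle this automatically, the rule can be checked in plain Python with a small real-valued example (real entries, so \(A^H = A^T\)):

```python
# Plain-Python sketch: (D M)^H = M^H D^H for D = diag(d) and a 2x3 matrix M.
d = [2.0, -1.0]                          # diagonal of D
M = [[1.0, 0.0, 3.0], [0.0, 2.0, 1.0]]   # 2x3 matrix

def matvec(A, x):
    # y = A x
    return [sum(aij * xj for aij, xj in zip(row, x)) for row in A]

def matvec_T(A, y):
    # y -> A^T y (real case, so A^H = A^T)
    return [sum(A[i][j] * y[i] for i in range(len(A))) for j in range(len(A[0]))]

x = [1.0, -2.0, 0.5]
y = [0.5, 3.0]

# forward: A x = D (M x)
Ax = [di * v for di, v in zip(d, matvec(M, x))]
# adjoint: A^H y = M^H (D^H y) -- note the reversed order
AHy = matvec_T(M, [di * yi for di, yi in zip(d, y)])

lhs = sum(a * b for a, b in zip(Ax, y))    # <Ax, y>
rhs = sum(a * b for a, b in zip(x, AHy))   # <x, A^H y>
print(abs(lhs - rhs) < 1e-12)  # True
```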
K = 5
M_weight = torch.randn(K, N, dtype=torch.complex64)
M_op = Dense(
weight=M_weight,
weightshape=Dim("KN"),
ishape=Dim("N"),
oshape=Dim("K"),
)
# Compose: first apply M, then apply D in the K-space
D_k = DiagScale(
torch.randn(K, dtype=torch.complex64),
ioshape=Dim("K"),
)
A = D_k @ M_op
print("Composed operator:", A)
The composed operator automatically supports adjoint and normal operations.
x_comp = torch.randn(N, dtype=torch.complex64)
y_comp = A(x_comp)
x_adj = A.H(y_comp)
print(f"Forward: {x_comp.shape} -> {y_comp.shape}")
print(f"Adjoint: {y_comp.shape} -> {x_adj.shape}")
x_t = torch.randn(N, dtype=torch.complex64)
y_t = torch.randn(K, dtype=torch.complex64)
print(f"Composed adjoint test passed: {is_adjoint(A, x_t, y_t)}")
Out:
Forward: torch.Size([8]) -> torch.Size([5])
Adjoint: torch.Size([5]) -> torch.Size([8])
Composed adjoint test passed: True
Summary
To create a custom NamedLinop:
- Subclass NamedLinop and call super().__init__(NS(ishape, oshape)) in __init__.
- Define fn and adj_fn as @staticmethods with signature (linop, x, /).
- Optionally define normal_fn for an efficient \(A^H A\).
- Use is_adjoint to verify correctness.
- Compose freely with other operators using @.