NamedLinop

torchlinops.linops.NamedLinop

Bases: Module

Base class for all named linear operators.

A NamedLinop represents a linear map \(A : X \to Y\) where the input and output tensor dimensions are identified by name (e.g. ("Nx", "Ny") -> ("Kx", "Ky")).

Subclass this to implement concrete operators. At minimum, override fn and adj_fn as static methods.
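The subclassing pattern can be sketched without the real library. The mock base class below only mirrors the documented call convention ``fn(linop, x)`` with ``fn``/``adj_fn`` as staticmethods; ``MockNamedLinop`` and ``Scale`` are illustrative stand-ins, not torchlinops classes:

```python
# Schematic stand-in for the documented pattern: subclass and override
# `fn` / `adj_fn` as staticmethods. This mock only mirrors the call
# convention `fn(linop, x)`; it is not the torchlinops implementation.

class MockNamedLinop:
    def __init__(self, ishape, oshape):
        self.ishape, self.oshape = ishape, oshape

    def __call__(self, x):
        return self.fn(self, x)

    @staticmethod
    def fn(linop, x):  # forward A(x); identity by default
        return x

    @staticmethod
    def adj_fn(linop, x):  # adjoint A^H(x); identity by default
        return x


class Scale(MockNamedLinop):
    """y = w * x with named dims ("N",) -> ("N",)."""

    def __init__(self, weight):
        super().__init__(ishape=("N",), oshape=("N",))
        self.weight = weight

    @staticmethod
    def fn(linop, x):
        return [linop.weight * xi for xi in x]

    @staticmethod
    def adj_fn(linop, x):  # real weight, so the adjoint equals the forward
        return [linop.weight * xi for xi in x]


A = Scale(2.0)
print(A([1.0, 2.0, 3.0]))  # [2.0, 4.0, 6.0]
```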

ATTRIBUTE DESCRIPTION
shape

The named shape of the linop, containing ishape and oshape.

TYPE: NamedShape

stream

Optional CUDA stream on which to run this linop.

TYPE: Stream

start_event

An event that signals when the linop has started. Useful for synchronizing multiple linops across multiple devices.

TYPE: (Event, optional)

end_event

An event that signals when the linop has completed. Useful for synchronizing multiple linops across multiple devices.

TYPE: (Event, optional)

input_listener

Pointer to another linop's event attribute. Used to coordinate GPU-to-GPU transfers in parallel execution contexts. When set to a tuple like (some_linop, "start_event"), the device transfer will wait for that event to be recorded before initiating the transfer.

TYPE: tuple(linop, str) or None
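The "pointer to another linop's event attribute" idea can be sketched standalone. ``AttributeForwarder`` below is a hypothetical stand-in for the internal ``ForwardedAttribute``; only the read-through behavior is shown, so later updates to the target's attribute are visible to the listener:

```python
# Hedged sketch of attribute forwarding behind `input_listener`.
# `AttributeForwarder` is a hypothetical stand-in, not the library's
# ForwardedAttribute; only read-through resolution is demonstrated.

class AttributeForwarder:
    def __init__(self):
        self._target = None
        self._attr = None

    def forward_to(self, target, attr):
        self._target, self._attr = target, attr

    @property
    def value(self):
        if self._target is None:
            return None
        # Resolved at read time, so updates to the target are seen.
        return getattr(self._target, self._attr)


class DummyLinop:
    start_event = None


producer = DummyLinop()
listener = AttributeForwarder()
listener.forward_to(producer, "start_event")  # like (some_linop, "start_event")

producer.start_event = "event-123"  # recorded later by the producer
print(listener.value)               # event-123
```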

Source code in src/torchlinops/linops/namedlinop.py
class NamedLinop(nn.Module):
    """Base class for all named linear operators.

    A ``NamedLinop`` represents a linear map $A : X \\to Y$ where the input and
    output tensor dimensions are identified by name (e.g. ``("Nx", "Ny") -> ("Kx", "Ky")``).

    Subclass this to implement concrete operators. At minimum, override ``fn``
    and ``adj_fn`` as static methods.

    Attributes
    ----------
    shape : NamedShape
        The named shape of the linop, containing ``ishape`` and ``oshape``.
    stream : torch.cuda.Stream
        Optional CUDA stream on which to run this linop.
    start_event : Event, optional
        An event that signals when the linop has started. Useful for synchronizing
        multiple linops across multiple devices.
    end_event : Event, optional
        An event that signals when the linop has completed. Useful for synchronizing multiple
        linops across multiple devices.
    input_listener : tuple(linop, str) or None
        Pointer to another linop's event attribute. Used to coordinate GPU-to-GPU
        transfers in parallel execution contexts. When set to a tuple like
        ``(some_linop, "start_event")``, the device transfer will wait for that
        event to be recorded before initiating the transfer.
    """

    def __init__(self, shape: NamedShape, name: Optional[str] = None, **kwargs):
        """
        Parameters
        ----------
        shape : NamedShape
            The shape of this linop, e.g. ``NamedShape(("N",), ("M",))``
        name : str, optional
            Optional name to display for this linop.
        """
        super().__init__(**kwargs)
        # Note: this attribute is private because the `.shape` attribute may be derived
        # dynamically
        self._shape = shape
        self._suffix = ""
        self._name = name
        self._setup()

    def _setup(self):
        """Helper method that should be called to reset the linop's state.
        Should be performed after any substantial changes to the linop."""
        self.reset_adjoint_and_normal()
        self.stream = None
        self.start_event = None
        self.end_event = None
        self._input_listener = ForwardedAttribute()
        # By default, listen for the start of this linop
        self.input_listener = (self, "start_event")

    @final
    def forward(self, x: Tensor) -> Tensor:
        """Apply the forward operation $y = A(x)$.

        If a CUDA stream is assigned, execution is dispatched to that stream.
        If a ``start_event`` is set, it is recorded before execution begins,
        allowing other operators to synchronize on it.

        Do not override this method. Instead, override .fn() and .adj_fn().

        Parameters
        ----------
        x : Tensor
            Input tensor.

        Returns
        -------
        Tensor
            The result of applying this linop to *x*.
        """
        if x.is_cuda:  # pragma: no cover
            stream = default_to(default_stream(x.device), self.stream)
            self.start_event = stream.record_event()
            with torch.cuda.stream(stream):
                y = self.fn(self, x)
            x.record_stream(stream)
            self.end_event = stream.record_event()
        else:
            y = self.fn(self, x)
        return y

    def apply(self, x: Tensor) -> Tensor:
        """Apply the linear operator to a tensor."""
        return LinopFunction.apply(x, self)

    # Override
    @staticmethod
    def fn(linop, x: Tensor, /) -> Tensor:
        """Compute the forward operation $y = A(x)$.

        Override this in subclasses to define the linop's forward behavior.

        Parameters
        ----------
        linop : NamedLinop
            The linop instance (passed explicitly because this is a staticmethod).
        x : Tensor
            Input tensor.

        Returns
        -------
        Tensor
            Result of applying the linop to *x*.

        Notes
        -----
        Declared as a staticmethod so that ``adjoint()`` can swap ``fn`` and
        ``adj_fn`` on a shallow copy without bound-method complications.
        """
        return x

    # Override
    @staticmethod
    def adj_fn(linop, x: Tensor, /) -> Tensor:
        """Compute the adjoint operation $y = A^H(x)$.

        Override this in subclasses to define the linop's adjoint behavior.

        Parameters
        ----------
        linop : NamedLinop
            The linop instance.
        x : Tensor
            Input tensor.

        Returns
        -------
        Tensor
            Result of applying the adjoint $A^H$ to *x*.
        """
        return x

    # Override
    @staticmethod
    def normal_fn(linop, x: Tensor, /) -> Tensor:
        """Compute the normal operation $y = A^H A(x)$.

        The default implementation composes ``adj_fn(fn(x))``. Override this
        in subclasses that have an efficient closed-form normal (e.g.
        ``Diagonal``, ``FFT``).

        Parameters
        ----------
        linop : NamedLinop
            The linop instance.
        x : Tensor
            Input tensor.

        Returns
        -------
        Tensor
            Result of applying $A^H A$ to *x*.
        """
        return linop.adj_fn(linop, linop.fn(linop, x))

    # Override
    def split_forward(self, ibatch, obatch) -> "NamedLinop":
        """Split this linop into a sub-linop according to slices over its dimensions.

        Override this in subclasses to define how the linop decomposes when tiled
        along its named dimensions. For the companion method that handles adjoints,
        see ``adj_split``.

        Parameters
        ----------
        ibatch : tuple[slice, ...]
            Slices over the input dimensions, one per element of ``ishape``.
        obatch : tuple[slice, ...]
            Slices over the output dimensions, one per element of ``oshape``.

        Returns
        -------
        NamedLinop
            A new linop that operates on the specified slice of the data.
        """

        return type(self)(self._shape)

    # Override
    def size(self, dim: str) -> int | None:
        """Return the concrete size of *dim*, or ``None`` if this linop does not determine it.

        Parameters
        ----------
        dim : str
            The named dimension to query.

        Returns
        -------
        int or None
            The size of the dimension, or ``None``.
        """
        return None

    @final
    @property
    def dims(self) -> set:
        """Get the set of dims that appear in this linop."""
        return set(self.ishape).union(set(self.oshape))

    @final
    @property
    def H(self) -> "NamedLinop":
        """Adjoint operator $A^H$.

        By default, creates a new adjoint on each access. Set
        ``torchlinops.config.cache_adjoint_normal = True`` to enable caching
        (deprecated).
        """
        try:
            if config.cache_adjoint_normal:
                config._warn_if_caching_enabled()
                if self._adjoint is None:
                    try:
                        _adjoint = self.adjoint()
                        _adjoint._adjoint = [self]
                        self._adjoint = [_adjoint]
                    except AttributeError as e:
                        traceback.print_exc()
                        raise e
                    logger.debug(
                        f"{type(self).__name__}: Making new adjoint {_adjoint._shape}"
                    )
                return self._adjoint[0]
            return self.adjoint()
        except AttributeError as e:
            raise RuntimeError(f"AttributeError in {type(self).__name__}.H: {e}") from e

    def adjoint(self) -> "NamedLinop":
        """Create the adjoint operator $A^H$.

        The default implementation shallow-copies this linop, swaps ``fn`` and
        ``adj_fn``, and flips the shape. Override this in subclasses that need
        special adjoint construction (e.g. conjugating weights).

        Returns
        -------
        NamedLinop
            The adjoint operator, sharing the same underlying data.
        """
        adj = copy(self)  # Retains data
        adj._shape = adj._shape.H
        # Swap functions (requires staticmethod)
        adj.fn, adj.adj_fn = adj.adj_fn, adj.fn
        adj.split, adj.adj_split = adj.adj_split, adj.split
        adj._update_suffix(adjoint=True)
        return adj

    @final
    def _update_suffix(self, adjoint: bool = False, normal: bool = False):
        if adjoint:
            if self._suffix.endswith(".H"):
                self._suffix = self._suffix[:-2]
            else:
                self._suffix += ".H"
        elif normal:
            self._suffix += ".N"

    @final
    @property
    def N(self) -> "NamedLinop":
        """Normal operator $A^H A$.

        Note that the naive normal operator can always be created via ``A.H @ A``.
        This function is reserved for custom behavior, as many linops have
        optimized normal forms.

        By default, creates a new normal on each access. Set
        ``torchlinops.config.cache_adjoint_normal = True`` to enable caching
        (deprecated).
        """
        try:
            if config.cache_adjoint_normal:
                config._warn_if_caching_enabled()
                if self._normal is None:
                    try:
                        _normal = self.normal()
                        self._normal = [_normal]
                    except AttributeError as e:
                        traceback.print_exc()
                        raise e
                return self._normal[0]
            return self.normal()
        except AttributeError as e:
            raise RuntimeError(f"AttributeError in {type(self).__name__}.N: {e}") from e

    def normal(self, inner=None) -> "NamedLinop":
        """Create the normal operator $A^H A$, optionally with an inner operator.

        When *inner* is ``None`` (or ``Identity`` with the reduce-identity config
        enabled), creates a linop whose forward pass calls ``normal_fn``.

        When *inner* is provided, constructs the composition $A^H \\cdot \\text{inner} \\cdot A$,
        which is used for Toeplitz embedding and similar optimizations.

        Parameters
        ----------
        inner : NamedLinop, optional
            An optional inner operator for Toeplitz embedding. If ``None``,
            the standard normal $A^H A$ is computed.

        Returns
        -------
        NamedLinop
            The normal operator.
        """
        if config.inner_not_relevant(inner):
            normal = copy(self)
            normal._shape = self._shape.N

            # Auxiliary object
            # Avoids creating lambda functions, which enables multiprocessing
            function_table = NormalFunctionLookup(self)
            # Static
            normal.fn = function_table.new_forward_adjoint_fn
            normal.adj_fn = function_table.new_forward_adjoint_fn
            normal.normal_fn = function_table.new_normal_fn
            # Bind `self` with partial to avoid weird multiprocessing-only error?
            normal.adjoint = partial(new_normal_adjoint, self=normal)
            # normal.adjoint = new_normal_adjoint.__get__(normal) # This one doesn't work

            # Assume that none of the dims are the same anymore
            # Override this behavior for e.g. diagonal linops
            normal.oshape = tuple(d.next_unused(normal.ishape) for d in normal.oshape)
            # Remember which shapes were updated
            normal._shape_updates = {
                d: d.next_unused(normal.ishape) for d in normal.oshape
            }
            normal._update_suffix(normal=True)
            return normal
        pre = copy(self)
        pre.oshape = inner.ishape
        post = self.adjoint()  # Copy happens inside adjoint
        post.ishape = inner.oshape
        normal = post @ inner @ pre
        normal._shape_updates = getattr(inner, "_shape_updates", {})
        return normal

    @final
    @staticmethod
    def split(linop, tile: Mapping[ND | str, slice]) -> "NamedLinop":
        """Split a linop into a sub-linop for a given tile.

        Translates a tile dictionary into per-dimension slices and delegates
        to ``split_forward``.

        Parameters
        ----------
        linop : NamedLinop
            The linop to split.
        tile : Mapping[ND | str, slice]
            Dictionary mapping dimension names to slices.

        Returns
        -------
        NamedLinop
            The sub-linop operating on the specified tile.
        """
        ibatch = [tile.get(dim, slice(None)) for dim in linop.ishape]
        obatch = [tile.get(dim, slice(None)) for dim in linop.oshape]
        return linop.split_forward(ibatch, obatch)

    @final
    @staticmethod
    def adj_split(linop, tile: Mapping[ND | str, slice]) -> "NamedLinop":
        """Split the adjoint of this linop for a given tile.

        Constructs the adjoint, splits it according to *tile*, and returns the
        adjoint of the split.

        Parameters
        ----------
        linop : NamedLinop
            The linop whose adjoint should be split.
        tile : Mapping[ND | str, slice]
            Dictionary mapping dimension names to slices.

        Returns
        -------
        NamedLinop
            The split adjoint sub-linop.
        """
        ibatch = [tile.get(dim, slice(None)) for dim in linop.ishape]
        obatch = [tile.get(dim, slice(None)) for dim in linop.oshape]
        splitH = linop.adjoint().split_forward(obatch, ibatch).adjoint()
        return splitH

    def flatten(self) -> list["NamedLinop"]:
        """Get a flattened list of constituent linops for composition."""
        return [self]

    def compose(self, inner) -> "NamedLinop":
        """Compose this linop with another linop.

        Parameters
        ----------
        inner : NamedLinop
            The linop to call before this one.

        Returns
        -------
        NamedLinop
            The composition of self and inner. If A = self and B = inner, then
            this returns C = AB.
        """
        before = inner.flatten()
        after = self.flatten()
        return torchlinops.Chain(*(before + after))

    def __add__(self, right) -> "NamedLinop":
        return torchlinops.Add(self, right)

    def __radd__(self, left) -> "NamedLinop":
        return torchlinops.Add(left, self)

    def __mul__(self, right) -> "NamedLinop":
        if isinstance(right, (int, float)) or isinstance(right, torch.Tensor):
            right = torchlinops.Scalar(weight=right, ioshape=self.ishape)
            return self.compose(right)
        return NotImplemented

    def __rmul__(self, left) -> "NamedLinop":
        if isinstance(left, (int, float)) or isinstance(left, torch.Tensor):
            left = torchlinops.Scalar(weight=left, ioshape=self.oshape)
            return left.compose(self)
        return NotImplemented

    def __neg__(self) -> "NamedLinop":
        return (-1) * self

    def __sub__(self, right) -> "NamedLinop":
        return torchlinops.Add(self, -right)

    def __rsub__(self, left) -> "NamedLinop":
        if isinstance(left, NamedLinop):
            return torchlinops.Add(left, -self)
        return NotImplemented

    def __matmul__(self, right) -> "NamedLinop":
        if isinstance(right, NamedLinop):
            return self.compose(right)
        if isinstance(right, torch.Tensor):
            return self(right)
        return NotImplemented

    def __rmatmul__(self, left) -> "NamedLinop":
        if not isinstance(left, NamedLinop):
            raise ValueError(
                f"__rmatmul__ of linop {type(self)} with non-linop of type {type(left)} is undefined."
            )
        return left.compose(self)

    @property
    def name(self):
        if self._name is not None:
            return self._name
        return type(self).__name__

    @name.setter
    def name(self, new_name):
        self._name = new_name

    @property
    def repr_name(self):
        return self.name + self._suffix

    def __repr__(self):
        out = f"{self.repr_name}({self.ishape} -> {self.oshape})"
        if self.start_event is not None:  # pragma: no cover
            out += f", start: {self.start_event.event_id:x}"
        if self.end_event is not None:  # pragma: no cover
            out += f", end: {self.end_event.event_id:x}"
        out = INDENT.indent(out)
        return out

    def reset_adjoint_and_normal(self):
        self._adjoint = None
        self._normal = None

    @property
    def shape(self) -> Shape:
        return self._shape

    @shape.setter
    def shape(self, val):
        self._shape = val

    @property
    def ishape(self):
        return self._shape.ishape

    @ishape.setter
    def ishape(self, val):
        self._shape.ishape = val

    @property
    def oshape(self):
        return self._shape.oshape

    @oshape.setter
    def oshape(self, val):
        self._shape.oshape = val

    def to(self, device, memory_aware: bool = False, called_by_adjoint: bool = False):
        """Move this linop (and its cached adjoint/normal) to *device*.

        Parameters
        ----------
        device : torch.device or str
            Target device.
        memory_aware : bool, default False
            If ``True``, use ``memory_aware_to`` which preserves shared-storage
            topology when moving tensors.
        called_by_adjoint : bool, default False
            Internal flag to prevent infinite recursion when the adjoint
            also calls ``.to()``. Will be deprecated along with cache_adjoint_normal.

        Returns
        -------
        NamedLinop
            The linop on the target device.
        """

        if config.cache_adjoint_normal:  # pragma: no cover
            config._warn_if_caching_enabled()
            if self._adjoint and not called_by_adjoint:
                # bool flag avoids infinite recursion
                self._adjoint[0] = self._adjoint[0].to(
                    device, memory_aware, called_by_adjoint=True
                )
            if self._normal:
                self._normal[0] = self._normal[0].to(device, memory_aware)
        if memory_aware:
            return memory_aware_to(self, device)
        return super().to(device)

    @property
    def input_listener(self):
        """Pointer to another linop event attribute.

        Useful for facilitating gpu-gpu transfers in parallel.

        For example, if ToDevice occurs inside a composing linop that allows for
        parallel execution, e.g.

        C = Concat(
            Chain(ToDevice1, A, ...),
            Chain(ToDevice2, B, ...),
            ...
        )

        Then we may want to set ToDevice1 and ToDevice2 to both listen for the beginning of C.
        That way, both device movements can be triggered in parallel.

        This attribute is a universal attribute so that it can be chained in cases of nesting, e.g.
        Add(
            Concat(
                Chain(ToDevice, ...), ...
                ...
            )
        )
        The innermost ToDevice can listen to Chain, which listens to Concat, which listens to Add.
        This is good because Concat and Add both can parallelize efficiently across multiple GPUs.
        """
        return self._input_listener.value

    @input_listener.setter
    def input_listener(self, value):
        if isinstance(value, tuple):
            _log_transfer(
                f"Setting {type(self).__name__}.input_listener to reference {type(value[0]).__name__}.{value[1]}"
            )
            self._input_listener.forward_to(*value)
        else:
            _log_transfer(f"Setting {type(self).__name__}.input_listener to {value}")
            self._input_listener = value

    def __copy__(self):
        """Specialized copying for linops.

        Notes
        -----
        - Shares previous data
        - Removes references to adjoint and normal
        - Creates a new shape object, rather than using the old one
        """
        cls = type(self)
        new = cls.__new__(cls)
        new.__dict__ = self.__dict__.copy()
        # Pytorch-specific module state dictionaries
        # Mirror those used in `__getattr__``
        # See https://github.com/pytorch/pytorch/blob/1eba9b3aa3c43f86f4a2c807ac8e12c4a7767340/torch/nn/modules/module.py#L1915
        new._parameters = new._parameters.copy()
        new._modules = new._modules.copy()
        new._buffers = new._buffers.copy()

        # Create new shape
        new._shape = deepcopy(self._shape)
        new._setup()
        return new

    @final
    def __deepcopy__(self, _):
        return memory_aware_deepcopy(self)

dims property

dims: set

Get the set of dims that appear in this linop.

H property

H: NamedLinop

Adjoint operator \(A^H\).

By default, creates a new adjoint on each access. Set torchlinops.config.cache_adjoint_normal = True to enable caching (deprecated).

N property

N: NamedLinop

Normal operator \(A^H A\).

Note that the naive normal operator can always be created via A.H @ A. This function is reserved for custom behavior, as many linops have optimized normal forms.

By default, creates a new normal on each access. Set torchlinops.config.cache_adjoint_normal = True to enable caching (deprecated).
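The identity behind the naive construction A.H @ A can be checked numerically. A small pure-Python sketch with a real matrix (so the adjoint is just the transpose), independent of torchlinops:

```python
# Numeric check of the idea behind A.N ~ A.H @ A: applying the normal
# operator is applying the adjoint after the forward. Real matrix, so
# the adjoint is the transpose. Pure Python, no torchlinops.

def matvec(A, x):
    return [sum(a * xi for a, xi in zip(row, x)) for row in A]

def transpose(A):
    return [list(col) for col in zip(*A)]

A = [[1.0, 2.0],
     [3.0, 4.0],
     [5.0, 6.0]]   # 3x2 "linop": R^2 -> R^3
x = [1.0, -1.0]

y_normal = matvec(transpose(A), matvec(A, x))  # A^T (A x)
print(y_normal)  # [-9.0, -12.0]
```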

__init__

__init__(
    shape: NamedShape, name: Optional[str] = None, **kwargs
)
PARAMETER DESCRIPTION
shape

The shape of this linop, e.g. NamedShape(("N",), ("M",))

TYPE: NamedShape

name

Optional name to display for this linop.

TYPE: str DEFAULT: None

Source code in src/torchlinops/linops/namedlinop.py
def __init__(self, shape: NamedShape, name: Optional[str] = None, **kwargs):
    """
    Parameters
    ----------
    shape : NamedShape
        The shape of this linop, e.g. ``NamedShape(("N",), ("M",))``
    name : str, optional
        Optional name to display for this linop.
    """
    super().__init__(**kwargs)
    # Note: this attribute is private because the `.shape` attribute may be derived
    # dynamically
    self._shape = shape
    self._suffix = ""
    self._name = name
    self._setup()

forward

forward(x: Tensor) -> Tensor

Apply the forward operation \(y = A(x)\).

If a CUDA stream is assigned, execution is dispatched to that stream. If a start_event is set, it is recorded before execution begins, allowing other operators to synchronize on it.

Do not override this method. Instead, override .fn() and .adj_fn().

PARAMETER DESCRIPTION
x

Input tensor.

TYPE: Tensor

RETURNS DESCRIPTION
Tensor

The result of applying this linop to x.

Source code in src/torchlinops/linops/namedlinop.py
@final
def forward(self, x: Tensor) -> Tensor:
    """Apply the forward operation $y = A(x)$.

    If a CUDA stream is assigned, execution is dispatched to that stream.
    If a ``start_event`` is set, it is recorded before execution begins,
    allowing other operators to synchronize on it.

    Do not override this method. Instead, override .fn() and .adj_fn().

    Parameters
    ----------
    x : Tensor
        Input tensor.

    Returns
    -------
    Tensor
        The result of applying this linop to *x*.
    """
    if x.is_cuda:  # pragma: no cover
        stream = default_to(default_stream(x.device), self.stream)
        self.start_event = stream.record_event()
        with torch.cuda.stream(stream):
            y = self.fn(self, x)
        x.record_stream(stream)
        self.end_event = stream.record_event()
    else:
        y = self.fn(self, x)
    return y

fn staticmethod

fn(linop, x: Tensor) -> Tensor

Compute the forward operation \(y = A(x)\).

Override this in subclasses to define the linop's forward behavior.

PARAMETER DESCRIPTION
linop

The linop instance (passed explicitly because this is a staticmethod).

TYPE: NamedLinop

x

Input tensor.

TYPE: Tensor

RETURNS DESCRIPTION
Tensor

Result of applying the linop to x.

Notes

Declared as a staticmethod so that adjoint() can swap fn and adj_fn on a shallow copy without bound-method complications.

Source code in src/torchlinops/linops/namedlinop.py
@staticmethod
def fn(linop, x: Tensor, /) -> Tensor:
    """Compute the forward operation $y = A(x)$.

    Override this in subclasses to define the linop's forward behavior.

    Parameters
    ----------
    linop : NamedLinop
        The linop instance (passed explicitly because this is a staticmethod).
    x : Tensor
        Input tensor.

    Returns
    -------
    Tensor
        Result of applying the linop to *x*.

    Notes
    -----
    Declared as a staticmethod so that ``adjoint()`` can swap ``fn`` and
    ``adj_fn`` on a shallow copy without bound-method complications.
    """
    return x

adj_fn staticmethod

adj_fn(linop, x: Tensor) -> Tensor

Compute the adjoint operation \(y = A^H(x)\).

Override this in subclasses to define the linop's adjoint behavior.

PARAMETER DESCRIPTION
linop

The linop instance.

TYPE: NamedLinop

x

Input tensor.

TYPE: Tensor

RETURNS DESCRIPTION
Tensor

Result of applying the adjoint \(A^H\) to x.

Source code in src/torchlinops/linops/namedlinop.py
@staticmethod
def adj_fn(linop, x: Tensor, /) -> Tensor:
    """Compute the adjoint operation $y = A^H(x)$.

    Override this in subclasses to define the linop's adjoint behavior.

    Parameters
    ----------
    linop : NamedLinop
        The linop instance.
    x : Tensor
        Input tensor.

    Returns
    -------
    Tensor
        Result of applying the adjoint $A^H$ to *x*.
    """
    return x

normal_fn staticmethod

normal_fn(linop, x: Tensor) -> Tensor

Compute the normal operation \(y = A^H A(x)\).

The default implementation composes adj_fn(fn(x)). Override this in subclasses that have an efficient closed-form normal (e.g. Diagonal, FFT).

PARAMETER DESCRIPTION
linop

The linop instance.

TYPE: NamedLinop

x

Input tensor.

TYPE: Tensor

RETURNS DESCRIPTION
Tensor

Result of applying \(A^H A\) to x.

Source code in src/torchlinops/linops/namedlinop.py
@staticmethod
def normal_fn(linop, x: Tensor, /) -> Tensor:
    """Compute the normal operation $y = A^H A(x)$.

    The default implementation composes ``adj_fn(fn(x))``. Override this
    in subclasses that have an efficient closed-form normal (e.g.
    ``Diagonal``, ``FFT``).

    Parameters
    ----------
    linop : NamedLinop
        The linop instance.
    x : Tensor
        Input tensor.

    Returns
    -------
    Tensor
        Result of applying $A^H A$ to *x*.
    """
    return linop.adj_fn(linop, linop.fn(linop, x))

split_forward

split_forward(ibatch, obatch) -> NamedLinop

Split this linop into a sub-linop according to slices over its dimensions.

Override this in subclasses to define how the linop decomposes when tiled along its named dimensions. For the companion method that handles adjoints, see adj_split.

PARAMETER DESCRIPTION
ibatch

Slices over the input dimensions, one per element of ishape.

TYPE: tuple[slice, ...]

obatch

Slices over the output dimensions, one per element of oshape.

TYPE: tuple[slice, ...]

RETURNS DESCRIPTION
NamedLinop

A new linop that operates on the specified slice of the data.

Source code in src/torchlinops/linops/namedlinop.py
def split_forward(self, ibatch, obatch) -> "NamedLinop":
    """Split this linop into a sub-linop according to slices over its dimensions.

    Override this in subclasses to define how the linop decomposes when tiled
    along its named dimensions. For the companion method that handles adjoints,
    see ``adj_split``.

    Parameters
    ----------
    ibatch : tuple[slice, ...]
        Slices over the input dimensions, one per element of ``ishape``.
    obatch : tuple[slice, ...]
        Slices over the output dimensions, one per element of ``oshape``.

    Returns
    -------
    NamedLinop
        A new linop that operates on the specified slice of the data.
    """

    return type(self)(self._shape)
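The slice construction that feeds ``split_forward`` can be illustrated standalone: a tile dictionary is expanded into one slice per named dimension, with dims absent from the tile defaulting to ``slice(None)`` (mirroring the documented ``split`` behavior):

```python
# Standalone illustration of how a tile mapping becomes the per-dimension
# slices passed to split_forward. Dims absent from the tile default to
# slice(None), i.e. "take everything along this dimension".

ishape = ("Nx", "Ny")
oshape = ("Kx", "Ky")
tile = {"Nx": slice(0, 4), "Kx": slice(2, 6)}

ibatch = [tile.get(dim, slice(None)) for dim in ishape]
obatch = [tile.get(dim, slice(None)) for dim in oshape]

print(ibatch)  # [slice(0, 4, None), slice(None, None, None)]
print(obatch)  # [slice(2, 6, None), slice(None, None, None)]
```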

size

size(dim: str) -> int | None

Return the concrete size of dim, or None if this linop does not determine it.

PARAMETER DESCRIPTION
dim

The named dimension to query.

TYPE: str

RETURNS DESCRIPTION
int or None

The size of the dimension, or None.

Source code in src/torchlinops/linops/namedlinop.py
def size(self, dim: str) -> int | None:
    """Return the concrete size of *dim*, or ``None`` if this linop does not determine it.

    Parameters
    ----------
    dim : str
        The named dimension to query.

    Returns
    -------
    int or None
        The size of the dimension, or ``None``.
    """
    return None

adjoint

adjoint() -> NamedLinop

Create the adjoint operator \(A^H\).

The default implementation shallow-copies this linop, swaps fn and adj_fn, and flips the shape. Override this in subclasses that need special adjoint construction (e.g. conjugating weights).

RETURNS DESCRIPTION
NamedLinop

The adjoint operator, sharing the same underlying data.

Source code in src/torchlinops/linops/namedlinop.py
def adjoint(self) -> "NamedLinop":
    """Create the adjoint operator $A^H$.

    The default implementation shallow-copies this linop, swaps ``fn`` and
    ``adj_fn``, and flips the shape. Override this in subclasses that need
    special adjoint construction (e.g. conjugating weights).

    Returns
    -------
    NamedLinop
        The adjoint operator, sharing the same underlying data.
    """
    adj = copy(self)  # Retains data
    adj._shape = adj._shape.H
    # Swap functions (requires staticmethod)
    adj.fn, adj.adj_fn = adj.adj_fn, adj.fn
    adj.split, adj.adj_split = adj.adj_split, adj.split
    adj._update_suffix(adjoint=True)
    return adj
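The swap trick above works because `fn` and `adj_fn` are static methods, so the shallow copy can freely reassign them as instance attributes. A minimal sketch with a toy `Scale` operator (a hypothetical class for illustration, where multiplying by a complex scalar \(c\) has adjoint multiplication by \(\bar{c}\)):

```python
from copy import copy

class Scale:
    """Toy linop: multiply by complex c; its adjoint multiplies by conj(c)."""
    c = 2 + 1j

    @staticmethod
    def fn(op, x):
        return op.c * x

    @staticmethod
    def adj_fn(op, x):
        return op.c.conjugate() * x

    def adjoint(self):
        adj = copy(self)                          # shallow copy retains data
        adj.fn, adj.adj_fn = adj.adj_fn, adj.fn   # swap (requires staticmethod)
        return adj

A = Scale()
AH = A.adjoint()
# Adjoint identity <A x, y> == <x, A^H y> holds for scalars:
x, y = 3 + 0j, 1 + 2j
assert A.fn(A, x) * y.conjugate() == x * AH.fn(AH, y).conjugate()
```

The double adjoint restores the original forward: `AH.adjoint().fn` is `Scale.fn` again.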

normal

normal(inner=None) -> NamedLinop

Create the normal operator \(A^H A\), optionally with an inner operator.

When inner is None (or Identity with the reduce-identity config enabled), creates a linop whose forward pass calls normal_fn.

When inner is provided, constructs the composition \(A^H \cdot \text{inner} \cdot A\), which is used for Toeplitz embedding and similar optimizations.

PARAMETER DESCRIPTION
inner

An optional inner operator for Toeplitz embedding. If None, the standard normal \(A^H A\) is computed.

TYPE: NamedLinop DEFAULT: None

RETURNS DESCRIPTION
NamedLinop

The normal operator.

Source code in src/torchlinops/linops/namedlinop.py
def normal(self, inner=None) -> "NamedLinop":
    """Create the normal operator $A^H A$, optionally with an inner operator.

    When *inner* is ``None`` (or ``Identity`` with the reduce-identity config
    enabled), creates a linop whose forward pass calls ``normal_fn``.

    When *inner* is provided, constructs the composition $A^H \\cdot \\text{inner} \\cdot A$,
    which is used for Toeplitz embedding and similar optimizations.

    Parameters
    ----------
    inner : NamedLinop, optional
        An optional inner operator for Toeplitz embedding. If ``None``,
        the standard normal $A^H A$ is computed.

    Returns
    -------
    NamedLinop
        The normal operator.
    """
    if config.inner_not_relevant(inner):
        normal = copy(self)
        normal._shape = self._shape.N

        # Auxiliary object
        # Avoids creating lambda functions, which enables multiprocessing
        function_table = NormalFunctionLookup(self)
        # Static
        normal.fn = function_table.new_forward_adjoint_fn
        normal.adj_fn = function_table.new_forward_adjoint_fn
        normal.normal_fn = function_table.new_normal_fn
        # Bind `self` with partial to avoid weird multiprocessing-only error?
        normal.adjoint = partial(new_normal_adjoint, self=normal)
        # normal.adjoint = new_normal_adjoint.__get__(normal) # This one doesn't work

        # Assume that none of the dims are the same anymore
        # Override this behavior for e.g. diagonal linops
        normal.oshape = tuple(d.next_unused(normal.ishape) for d in normal.oshape)
        # Remember which shapes were updated
        normal._shape_updates = {
            d: d.next_unused(normal.ishape) for d in normal.oshape
        }
        normal._update_suffix(normal=True)
        return normal
    pre = copy(self)
    pre.oshape = inner.ishape
    post = self.adjoint()  # Copy happens inside adjoint
    post.ishape = inner.oshape
    normal = post @ inner @ pre
    normal._shape_updates = getattr(inner, "_shape_updates", {})
    return normal
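The two branches above can be mirrored with dense matrices (illustrative stand-ins for linops): without `inner`, the normal operator is \(A^H A\); with an inner operator \(W\), the composition `post @ inner @ pre` gives \(A^H W A\):

```python
import numpy as np

rng = np.random.default_rng(1)
A = rng.standard_normal((5, 3))
x = rng.standard_normal(3)

# inner is None: the normal operator applies A, then A^H
pre, post = A, A.T               # pre = copy of A, post = its adjoint
normal = post @ (pre @ x)
assert np.allclose(normal, A.T @ A @ x)

# inner provided (e.g. Toeplitz embedding weights W): post @ inner @ pre
W = np.diag(rng.standard_normal(5))
normal_inner = post @ (W @ (pre @ x))
assert np.allclose(normal_inner, A.T @ W @ A @ x)
```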

split staticmethod

split(
    linop, tile: Mapping[NamedDimension | str, slice]
) -> NamedLinop

Split a linop into a sub-linop for a given tile.

Translates a tile dictionary into per-dimension slices and delegates to split_forward.

PARAMETER DESCRIPTION
linop

The linop to split.

TYPE: NamedLinop

tile

Dictionary mapping dimension names to slices.

TYPE: Mapping[NamedDimension | str, slice]

RETURNS DESCRIPTION
NamedLinop

The sub-linop operating on the specified tile.

Source code in src/torchlinops/linops/namedlinop.py
@final
@staticmethod
def split(linop, tile: Mapping[ND | str, slice]) -> "NamedLinop":
    """Split a linop into a sub-linop for a given tile.

    Translates a tile dictionary into per-dimension slices and delegates
    to ``split_forward``.

    Parameters
    ----------
    linop : NamedLinop
        The linop to split.
    tile : Mapping[ND | str, slice]
        Dictionary mapping dimension names to slices.

    Returns
    -------
    NamedLinop
        The sub-linop operating on the specified tile.
    """
    ibatch = [tile.get(dim, slice(None)) for dim in linop.ishape]
    obatch = [tile.get(dim, slice(None)) for dim in linop.oshape]
    return linop.split_forward(ibatch, obatch)
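The tile-to-slices translation above is simple enough to check by hand: dimensions absent from the tile default to `slice(None)`, i.e. "take everything" along that axis:

```python
# Named shapes as plain tuples, mirroring the list comprehensions in `split`
ishape = ("Nx", "Ny")
oshape = ("Kx", "Ky")
tile = {"Nx": slice(0, 8), "Kx": slice(16, 32)}

ibatch = [tile.get(dim, slice(None)) for dim in ishape]
obatch = [tile.get(dim, slice(None)) for dim in oshape]

assert ibatch == [slice(0, 8), slice(None)]
assert obatch == [slice(16, 32), slice(None)]
```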

adj_split staticmethod

adj_split(
    linop, tile: Mapping[NamedDimension | str, slice]
) -> NamedLinop

Split the adjoint of this linop for a given tile.

Constructs the adjoint, splits it according to tile, and returns the adjoint of the split.

PARAMETER DESCRIPTION
linop

The linop whose adjoint should be split.

TYPE: NamedLinop

tile

Dictionary mapping dimension names to slices.

TYPE: Mapping[NamedDimension | str, slice]

RETURNS DESCRIPTION
NamedLinop

The split adjoint sub-linop.

Source code in src/torchlinops/linops/namedlinop.py
@final
@staticmethod
def adj_split(linop, tile: Mapping[ND | str, slice]) -> "NamedLinop":
    """Split the adjoint of this linop for a given tile.

    Constructs the adjoint, splits it according to *tile*, and returns the
    adjoint of the split.

    Parameters
    ----------
    linop : NamedLinop
        The linop whose adjoint should be split.
    tile : Mapping[ND | str, slice]
        Dictionary mapping dimension names to slices.

    Returns
    -------
    NamedLinop
        The split adjoint sub-linop.
    """
    ibatch = [tile.get(dim, slice(None)) for dim in linop.ishape]
    obatch = [tile.get(dim, slice(None)) for dim in linop.oshape]
    splitH = linop.adjoint().split_forward(obatch, ibatch).adjoint()
    return splitH
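For a dense matrix, the adjoint-split-adjoint sandwich lands on the same slice of \(A\) as a direct forward split, which is why the slice roles are swapped when splitting the adjoint. A NumPy check (matrix as an illustrative stand-in):

```python
import numpy as np

rng = np.random.default_rng(2)
A = rng.standard_normal((6, 4)) + 1j * rng.standard_normal((6, 4))
o_sl, i_sl = slice(1, 4), slice(0, 2)

# Forward split of a dense matrix: rows by output slice, cols by input slice
split = A[o_sl, i_sl]

# adj_split: adjoint, split with the slice roles swapped, adjoint again
AH = A.conj().T
splitH = AH[i_sl, o_sl].conj().T

assert np.allclose(split, splitH)
```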

flatten

flatten() -> list[NamedLinop]

Get a flattened list of constituent linops for composition.

Source code in src/torchlinops/linops/namedlinop.py
def flatten(self) -> list["NamedLinop"]:
    """Get a flattened list of constituent linops for composition."""
    return [self]

compose

compose(inner) -> NamedLinop

Compose this linop with another linop.

PARAMETER DESCRIPTION
inner

The linop to call before this one.

TYPE: NamedLinop

RETURNS DESCRIPTION
NamedLinop

The composition of self and inner: if A = self and B = inner, this returns C = AB, i.e. inner is applied first.

Source code in src/torchlinops/linops/namedlinop.py
def compose(self, inner) -> "NamedLinop":
    """Compose this linop with another linop.

    Parameters
    ----------
    inner : NamedLinop
        The linop to call before this one.

    Returns
    -------
    NamedLinop
        The composition of self and inner. If A = self and B = inner then this returns
        C = AB.
    """
    before = inner.flatten()
    after = self.flatten()
    return torchlinops.Chain(*(before + after))
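The flatten-then-chain pattern above preserves application order: `inner` (B) runs first, then `self` (A), so the result is C = AB. A sketch with plain functions standing in for linops and a hypothetical `compose` helper:

```python
def B(x):
    # inner: applied first
    return x + 1

def A(x):
    # outer (self): applied second
    return 2 * x

def compose(outer, inner):
    chain = [inner, outer]       # flattened: before + after
    def C(x):
        for f in chain:
            x = f(x)
        return x
    return C

C = compose(A, B)
assert C(3) == 2 * (3 + 1)       # C = AB: B first, then A
```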

apply

apply(x: Tensor) -> Tensor

Apply the linear operator to a tensor.

Source code in src/torchlinops/linops/namedlinop.py
def apply(self, x: Tensor) -> Tensor:
    """Apply the linear operator to a tensor."""
    return LinopFunction.apply(x, self)

to

to(
    device,
    memory_aware: bool = False,
    called_by_adjoint: bool = False,
)

Move this linop (and its cached adjoint/normal) to device.

PARAMETER DESCRIPTION
device

Target device.

TYPE: device or str

memory_aware

If True, use memory_aware_to which preserves shared-storage topology when moving tensors.

TYPE: bool DEFAULT: False

called_by_adjoint

Internal flag to prevent infinite recursion when the adjoint also calls .to(). Will be deprecated along with cache_adjoint_normal.

TYPE: bool DEFAULT: False

RETURNS DESCRIPTION
NamedLinop

The linop on the target device.

Source code in src/torchlinops/linops/namedlinop.py
def to(self, device, memory_aware: bool = False, called_by_adjoint: bool = False):
    """Move this linop (and its cached adjoint/normal) to *device*.

    Parameters
    ----------
    device : torch.device or str
        Target device.
    memory_aware : bool, default False
        If ``True``, use ``memory_aware_to`` which preserves shared-storage
        topology when moving tensors.
    called_by_adjoint : bool, default False
        Internal flag to prevent infinite recursion when the adjoint
        also calls ``.to()``. Will be deprecated along with cache_adjoint_normal.

    Returns
    -------
    NamedLinop
        The linop on the target device.
    """

    if config.cache_adjoint_normal:  # pragma: no cover
        config._warn_if_caching_enabled()
        if self._adjoint and not called_by_adjoint:
            # bool flag avoids infinite recursion
            self._adjoint[0] = self._adjoint[0].to(
                device, memory_aware, called_by_adjoint=True
            )
        if self._normal:
            self._normal[0] = self._normal[0].to(device, memory_aware)
    if memory_aware:
        return memory_aware_to(self, device)
    return super().to(device)

reset_adjoint_and_normal

reset_adjoint_and_normal()

Clear the cached adjoint and normal operators, if any.

Source code in src/torchlinops/linops/namedlinop.py
def reset_adjoint_and_normal(self):
    self._adjoint = None
    self._normal = None