Skip to content

Padding & Truncation

torchlinops.linops.Pad

Bases: NamedLinop

Zero-pad the last dimensions of the input volume. Padding is centered (TODO: support non-centered padding?). ishape: [B... Nx Ny [Nz]]; oshape: [B... Nx1 Ny1 [Nz1]]

Source code in src/torchlinops/linops/pad_last.py
class Pad(NamedLinop):
    """Zero-pad the last dimensions of the input volume.

    Padding is centered:
    - TODO: support non-centered padding?
    ishape: [B... Nx Ny [Nz]]
    oshape: [B... Nx1 Ny1 [Nz1]]

    """

    def __init__(
        self,
        pad_im_size: tuple[int, ...],
        im_size: tuple[int, ...],
        in_shape: Optional[Shape] = None,
        out_shape: Optional[Shape] = None,
        batch_shape: Optional[Shape] = None,
    ):
        """
        Parameters
        ----------
        pad_im_size : tuple[int, ...]
            Target (padded) size for the last dimensions.
        im_size : tuple[int, ...]
            Original (unpadded) size for the last dimensions.
        in_shape : Shape, optional
            Named shape for the input spatial dimensions.
        out_shape : Shape, optional
            Named shape for the output spatial dimensions.
        batch_shape : Shape, optional
            Named shape for batch dimensions.

        Raises
        ------
        ValueError
            If ``pad_im_size`` and ``im_size`` have different lengths.
        """
        if len(pad_im_size) != len(im_size):
            raise ValueError(
                f"Padded and unpadded dims should be the same length. padded: {pad_im_size} unpadded: {im_size}"
            )

        # Resolve named input dims, defaulting to names derived from im_size.
        if in_shape is None:
            self.in_im_shape = ND.infer(get_nd_shape(im_size))
        else:
            self.in_im_shape = ND.infer(in_shape)
        if out_shape is None:
            # Derive fresh output dim names so input/output dims stay distinct.
            self.out_im_shape = tuple(
                d.next_unused(self.in_im_shape) for d in self.in_im_shape
            )
        else:
            # Normalized with ND.infer for consistency with the in_shape branch
            # (previously stored raw, unlike in_shape).
            self.out_im_shape = ND.infer(out_shape)
        batch_shape = default_to(("...",), batch_shape)

        shape = NS(batch_shape) + NS(self.in_im_shape, self.out_im_shape)
        super().__init__(shape)
        self.D = len(im_size)  # number of trailing (spatial) dims
        self.im_size = tuple(im_size)
        self.pad_im_size = tuple(pad_im_size)
        # Aliases swapped by adjoint() when roles reverse.
        self.in_im_size = tuple(im_size)
        self.out_im_size = tuple(pad_im_size)

        # F.pad-style pad amounts (last dim first).
        self.pad = pad_to_size(self.im_size, self.pad_im_size)

        # Make crop slice that undoes padding
        # Need to reverse crop_slice because padding is reversed
        self.crop_slice = crop_slice_from_pad(self.pad)

    @staticmethod
    def fn(padlast, x, /):
        """Zero-pad the last ``padlast.D`` dims of ``x`` up to ``pad_im_size``."""
        if tuple(x.shape[-padlast.D :]) != padlast.im_size:
            raise ValueError(
                f"Mismatched shapes: expected {padlast.im_size} but got {x.shape[-padlast.D :]}"
            )
        # Extend the pad spec with zeros so leading (batch) dims are untouched.
        pad = padlast.pad + [0, 0] * (x.ndim - padlast.D)
        return F.pad(x, pad)

    @staticmethod
    def adj_fn(padlast, y, /):
        """Crop the last n dimensions of y"""
        if tuple(y.shape[-padlast.D :]) != padlast.pad_im_size:
            raise ValueError(
                f"Mismatched shapes: expected {padlast.pad_im_size} but got {y.shape[-padlast.D :]}"
            )
        slc = [slice(None)] * (y.ndim - padlast.D) + padlast.crop_slice
        return y[tuple(slc)]

    def adjoint(self):
        """Return the adjoint linop (a crop), swapping cached sizes/shapes."""
        adj = super().adjoint()
        if self.name == "Pad":
            adj.name = "Crop"
        elif self.name == "Crop":
            adj.name = "Pad"

        adj.in_im_shape, adj.out_im_shape = self.out_im_shape, self.in_im_shape
        adj.in_im_size, adj.out_im_size = self.out_im_size, self.in_im_size
        return adj

    def split_forward(self, ibatch, obatch):
        """Allow splitting only along batch dims, never along image dims."""
        for islc, oslc in zip(ibatch[-self.D :], obatch[-self.D :]):
            if islc != slice(None) or oslc != slice(None):
                raise ValueError(
                    f"{type(self).__name__} cannot be split along image dim"
                )
        return self

    def size(self, dim: str):
        """Return the size of named dim ``dim``, or None if not an image dim."""
        if dim in self.ishape[-self.D :]:
            return self.in_im_size[self.in_im_shape.index(dim)]
        elif dim in self.oshape[-self.D :]:
            return self.out_im_size[self.out_im_shape.index(dim)]
        return None

__init__

__init__(
    pad_im_size: tuple[int, ...],
    im_size: tuple[int, ...],
    in_shape: Optional[Shape] = None,
    out_shape: Optional[Shape] = None,
    batch_shape: Optional[Shape] = None,
)
PARAMETER DESCRIPTION
pad_im_size

Target (padded) size for the last dimensions.

TYPE: tuple[int, ...]

im_size

Original (unpadded) size for the last dimensions.

TYPE: tuple[int, ...]

in_shape

Named shape for the input spatial dimensions.

TYPE: Shape DEFAULT: None

out_shape

Named shape for the output spatial dimensions.

TYPE: Shape DEFAULT: None

batch_shape

Named shape for batch dimensions.

TYPE: Shape DEFAULT: None

Source code in src/torchlinops/linops/pad_last.py
def __init__(
    self,
    pad_im_size: tuple[int, ...],
    im_size: tuple[int, ...],
    in_shape: Optional[Shape] = None,
    out_shape: Optional[Shape] = None,
    batch_shape: Optional[Shape] = None,
):
    """
    Parameters
    ----------
    pad_im_size : tuple[int, ...]
        Target (padded) size for the last dimensions.
    im_size : tuple[int, ...]
        Original (unpadded) size for the last dimensions.
    in_shape : Shape, optional
        Named shape for the input spatial dimensions.
    out_shape : Shape, optional
        Named shape for the output spatial dimensions.
    batch_shape : Shape, optional
        Named shape for batch dimensions.

    Raises
    ------
    ValueError
        If ``pad_im_size`` and ``im_size`` have different lengths.
    """
    if len(pad_im_size) != len(im_size):
        raise ValueError(
            f"Padded and unpadded dims should be the same length. padded: {pad_im_size} unpadded: {im_size}"
        )

    # Resolve named input dims, defaulting to names derived from im_size.
    if in_shape is None:
        self.in_im_shape = ND.infer(get_nd_shape(im_size))
    else:
        self.in_im_shape = ND.infer(in_shape)
    if out_shape is None:
        # Derive fresh output dim names so input/output dims stay distinct.
        self.out_im_shape = tuple(
            d.next_unused(self.in_im_shape) for d in self.in_im_shape
        )
    else:
        # NOTE(review): not passed through ND.infer, unlike in_shape — confirm intended.
        self.out_im_shape = out_shape
    batch_shape = default_to(("...",), batch_shape)

    shape = NS(batch_shape) + NS(self.in_im_shape, self.out_im_shape)
    super().__init__(shape)
    self.D = len(im_size)  # number of trailing (spatial) dims
    self.im_size = tuple(im_size)
    self.pad_im_size = tuple(pad_im_size)
    # Aliases swapped by adjoint() when roles reverse.
    self.in_im_size = tuple(im_size)
    self.out_im_size = tuple(pad_im_size)
    # for psz in pad_im_size:
    #     assert not (psz % 2), "Pad sizes must be even"

    # sizes = [
    #     [(psz - isz) // 2] * 2
    #     for psz, isz in zip(self.out_im_size, self.in_im_size)
    # ]
    # self.pad = sum(sizes, start=[])
    # self.pad.reverse()

    # F.pad-style pad amounts (last dim first).
    self.pad = pad_to_size(self.im_size, self.pad_im_size)

    # Make crop slice that undoes padding
    # Need to reverse crop_slice because padding is reversed
    self.crop_slice = crop_slice_from_pad(self.pad)

torchlinops.linops.Crop

Crop(
    crop_im_size, im_size, in_shape, out_shape, batch_shape
)
Source code in src/torchlinops/linops/pad_last.py
def Crop(crop_im_size, im_size, in_shape, out_shape, batch_shape):
    """Build a crop linop from ``im_size`` down to ``crop_im_size``.

    Implemented as the adjoint of a centered ``Pad`` going the other way;
    shapes are swapped because the adjoint reverses input/output roles.
    """
    padder = Pad(
        pad_im_size=im_size,
        im_size=crop_im_size,
        in_shape=out_shape,
        out_shape=in_shape,
        batch_shape=batch_shape,
    )
    return padder.H

torchlinops.linops.Truncate

Bases: NamedLinop

Truncation (slicing) operator along a specified dimension.

Extracts a contiguous slice from the input. The adjoint zero-pads back to the original size.

Source code in src/torchlinops/linops/trunc_pad.py
class Truncate(NamedLinop):
    """Truncation (slicing) operator along a specified dimension.

    Extracts a contiguous slice of length ``to_length`` from the start of
    dimension ``dim``. The adjoint zero-pads back to the original size.
    """

    def __init__(
        self,
        dim: int,
        from_length: int,
        to_length: int,
        ishape: Shape,
        oshape: Shape,
    ):
        # dim: axis to truncate; from_length: input size along dim;
        # to_length: output size along dim, must satisfy 0 <= to_length <= from_length.
        self.dim = dim
        self.from_length = from_length
        self.to_length = to_length
        if self.from_length < 0:
            raise ValueError(
                f"from_length must be nonnegative but got {self.from_length}"
            )
        if self.to_length < 0 or self.to_length > from_length:
            raise ValueError(
                f"to_length must be in [0, {self.from_length}] but got {self.to_length}"
            )
        # Forward slice: keep the first to_length entries along dim.
        self.slc = [slice(None)] * len(ishape)
        self.slc[dim] = slice(0, self.to_length)
        self.slc = tuple(self.slc)

        # Tail slice covering the truncated-away region (zeroed by normal_fn);
        # negative start selects the last (from_length - to_length) entries.
        self.end_slc = [slice(None)] * len(ishape)
        self.end_slc[dim] = slice(self.to_length - self.from_length, None)
        self.end_slc = tuple(self.end_slc)
        super().__init__(NS(ishape, oshape))

    @staticmethod
    def fn(truncate, x, /):
        """Slice ``x`` down to ``to_length`` along ``dim``."""
        if x.shape[truncate.dim] != truncate.from_length:
            raise ValueError(
                f"Truncate expecting size {truncate.from_length} at x.shape[{truncate.dim}] but got {x.shape[truncate.dim]} (x.shape: {x.shape})"
            )
        return x[truncate.slc]

    @staticmethod
    def adj_fn(truncate, y, /):
        """Zero-pad ``y`` back to ``from_length`` along ``dim``."""
        if y.shape[truncate.dim] != truncate.to_length:
            raise ValueError(
                f"Truncate (adjoint) expecting size {truncate.to_length} at x.shape[{truncate.dim}] but got {y.shape[truncate.dim]} (y.shape: {y.shape})"
            )
        return end_pad_with_zeros(
            y,
            truncate.dim,
            truncate.from_length - truncate.to_length,
        )

    @staticmethod
    def normal_fn(truncate, x, /):
        """Adjoint-after-forward: zero the truncated tail, leave the rest intact."""
        if x.shape[truncate.dim] != truncate.from_length:
            raise ValueError(
                f"Truncate (normal) expecting size {truncate.from_length} at x.shape[{truncate.dim}] but got {x.shape[truncate.dim]} (x.shape: {x.shape})"
            )
        # Clone so the caller's tensor is not mutated in place.
        x = x.clone()
        x[truncate.end_slc] = 0.0
        return x

    def split_forward(self, ibatch, obatch):
        # Splitting along the truncated dim would change lengths; forbid it.
        if ibatch[self.dim] != slice(None) or obatch[self.dim] != slice(None):
            raise ValueError("Cannot slice a Truncate linop along truncation dimension")
        return type(self)(
            self.dim, self.from_length, self.to_length, self.ishape, self.oshape
        )

    def adjoint(self):
        # Adjoint of truncation is zero-padding, with lengths and shapes swapped.
        return PadDim(
            self.dim,
            self.to_length,
            self.from_length,
            self.oshape,
            self.ishape,
        )

    def normal(self, inner=None):
        """Diagonal in all dims except the truncated one (``self.dim``)."""
        pre = copy(self)
        post = self.adjoint()
        if inner is None:
            return post @ pre
        # Rewire named shapes so the composition post @ inner @ pre chains:
        # pre feeds inner's input, post consumes inner's output, and post's
        # output keeps inner's dims except the re-padded truncation dim.
        pre.oshape = inner.ishape
        post.ishape = inner.oshape
        new_oshape = list(inner.oshape)
        new_oshape[self.dim] = post.oshape[self.dim]
        post.oshape = tuple(new_oshape)
        return post @ inner @ pre

__init__

__init__(
    dim: int,
    from_length: int,
    to_length: int,
    ishape: Shape,
    oshape: Shape,
)
Source code in src/torchlinops/linops/trunc_pad.py
def __init__(
    self,
    dim: int,
    from_length: int,
    to_length: int,
    ishape: Shape,
    oshape: Shape,
):
    """Truncate dimension ``dim`` from ``from_length`` down to ``to_length``.

    Raises
    ------
    ValueError
        If ``from_length`` is negative or ``to_length`` is outside
        ``[0, from_length]``.
    """
    self.dim = dim
    self.from_length = from_length
    self.to_length = to_length
    if self.from_length < 0:
        raise ValueError(
            f"from_length must be nonnegative but got {self.from_length}"
        )
    if self.to_length < 0 or self.to_length > from_length:
        raise ValueError(
            f"to_length must be in [0, {self.from_length}] but got {self.to_length}"
        )
    # Forward slice: keep the first to_length entries along dim.
    self.slc = [slice(None)] * len(ishape)
    self.slc[dim] = slice(0, self.to_length)
    self.slc = tuple(self.slc)

    # Tail slice covering the truncated-away region; negative start selects
    # the last (from_length - to_length) entries.
    self.end_slc = [slice(None)] * len(ishape)
    self.end_slc[dim] = slice(self.to_length - self.from_length, None)
    self.end_slc = tuple(self.end_slc)
    super().__init__(NS(ishape, oshape))

torchlinops.linops.PadDim

Bases: NamedLinop

Zero-padding operator along a specified dimension.

Pads the input with zeros. The adjoint truncates (slices) back to the original size.

Source code in src/torchlinops/linops/trunc_pad.py
class PadDim(NamedLinop):
    """Zero-padding operator along a specified dimension.

    Pads the input with zeros at the end of dimension ``dim``, growing it
    from ``from_length`` to ``to_length``. The adjoint truncates (slices)
    back to the original size.
    """

    def __init__(self, dim, from_length, to_length, ishape, oshape):
        """
        Parameters
        ----------
        dim : int
            Axis to pad.
        from_length : int
            Input size along ``dim``.
        to_length : int
            Output (padded) size along ``dim``; must satisfy
            ``0 <= from_length <= to_length``.
        ishape, oshape : Shape
            Named input and output shapes.

        Raises
        ------
        ValueError
            If the length constraints above are violated.
        """
        self.dim = dim
        self.from_length = from_length
        self.to_length = to_length
        if self.to_length < 0:
            raise ValueError(f"to_length must be nonnegative but got {self.to_length}")

        if self.from_length < 0 or self.from_length > self.to_length:
            # Fixed message: previously said "to_length" although from_length
            # is the value being validated and reported.
            raise ValueError(
                f"from_length must be in [0, {self.to_length}] but got {self.from_length}"
            )

        # Adjoint slice: keep the first from_length entries along dim.
        self.slc = [slice(None)] * len(ishape)
        self.slc[dim] = slice(0, self.from_length)
        self.slc = tuple(self.slc)
        super().__init__(NS(ishape, oshape))

    def adjoint(self):
        # Adjoint of zero-padding is truncation, with lengths/shapes swapped.
        return Truncate(
            self.dim, self.to_length, self.from_length, self.oshape, self.ishape
        )

    def normal(self, inner=None):
        """Diagonal in all dims except the padded one (``self.dim``)."""
        if inner is None:
            # Crop-after-pad recovers the input exactly, so the normal op is identity.
            return Identity(self.ishape)
        pre = copy(self)
        post = copy(self).H
        # Rewire named shapes so post @ inner @ pre chains correctly.
        pre.oshape = inner.ishape
        post.ishape = inner.oshape
        return post @ inner @ pre

    @staticmethod
    def fn(padend, x, /):
        """Zero-pad ``x`` from ``from_length`` to ``to_length`` along ``dim``."""
        if x.shape[padend.dim] != padend.from_length:
            raise ValueError(
                f"padend expecting size {padend.from_length} at x.shape[{padend.dim}] but got {x.shape[padend.dim]} (x.shape: {x.shape})"
            )
        return end_pad_with_zeros(x, padend.dim, padend.to_length - padend.from_length)

    @staticmethod
    def adj_fn(padend, y, /):
        """Truncate ``y`` back to ``from_length`` along ``dim``."""
        if y.shape[padend.dim] != padend.to_length:
            raise ValueError(
                f"PadDim (adjoint) expecting size {padend.to_length} at x.shape[{padend.dim}] but got {y.shape[padend.dim]} (y.shape: {y.shape})"
            )
        return y[padend.slc]

    @staticmethod
    def normal_fn(padend, x, /):
        # Crop-after-pad returns the input unchanged; clone to avoid aliasing
        # the caller's tensor.
        x = x.clone()
        return x

    def split_forward(self, ibatch, obatch):
        # Splitting along the padded dim would change lengths; forbid it.
        # Fixed message: previously referred to the stale class name "PadEnd".
        if ibatch[self.dim] != slice(None) or obatch[self.dim] != slice(None):
            raise ValueError(
                f"Cannot slice a {type(self).__name__} linop along padded dimension"
            )
        return type(self)(
            self.dim, self.from_length, self.to_length, self.ishape, self.oshape
        )

__init__

__init__(dim, from_length, to_length, ishape, oshape)
Source code in src/torchlinops/linops/trunc_pad.py
def __init__(self, dim, from_length, to_length, ishape, oshape):
    """Pad dimension ``dim`` from ``from_length`` up to ``to_length``.

    Raises
    ------
    ValueError
        If ``to_length`` is negative or ``from_length`` is outside
        ``[0, to_length]``.
    """
    self.dim = dim
    self.from_length = from_length
    self.to_length = to_length
    if self.to_length < 0:
        raise ValueError(f"to_length must be nonnegative but got {self.to_length}")

    if self.from_length < 0 or self.from_length > self.to_length:
        # Fixed message: previously said "to_length" although from_length
        # is the value being validated and reported.
        raise ValueError(
            f"from_length must be in [0, {self.to_length}] but got {self.from_length}"
        )

    # Adjoint slice: keep the first from_length entries along dim.
    self.slc = [slice(None)] * len(ishape)
    self.slc[dim] = slice(0, self.from_length)
    self.slc = tuple(self.slc)
    super().__init__(NS(ishape, oshape))