Skip to content

Unfold / Fold

Block extraction (unfold) and block reassembly (fold) operations, supporting 1D/2D/3D with optional masking and Triton backends.

torchlinops.functional.unfold

unfold(
    x: Shaped[Tensor, ...],
    block_size: tuple,
    stride: Optional[tuple] = None,
    mask: Optional[Bool[Tensor, ...]] = None,
    output: Optional[Tensor] = None,
) -> Tensor

Wrapper that dispatches complex and real tensors. Also precomputes some shapes.

PARAMETER DESCRIPTION
x

Shape [B..., *im_size]

TYPE: Tensor

RETURNS DESCRIPTION
Tensor

If mask is not None, block_size will be an int equal to the number of True elements in the mask. Otherwise, it will be the full block shape.

TYPE: Shape [B..., *blocks, *block_size]

Source code in src/torchlinops/functional/_unfold/unfold.py
def unfold(
    x: Shaped[Tensor, "..."],
    block_size: tuple,
    stride: Optional[tuple] = None,
    mask: Optional[Bool[Tensor, "..."]] = None,
    output: Optional[Tensor] = None,
) -> Tensor:
    """Extract blocks from an array, accepting real and complex tensors alike.

    Complex input is reinterpreted as real (real/imag parts interleaved along
    the last dimension) before ``_unfold`` is called, and converted back to
    complex afterwards. The shapes needed along the way are precomputed by
    ``_prep_unfold``.

    Parameters
    ----------
    x : Tensor
        Shape [B..., *im_size]
    block_size : tuple
        Forwarded to ``_prep_unfold``.
    stride : tuple, optional
        Forwarded to ``_prep_unfold``.
    mask : Tensor, optional
        Boolean mask; when given, it indexes the trailing dimensions of the
        result.
    output : Tensor, optional
        Forwarded to ``_unfold`` as its ``output`` argument.

    Returns
    -------
    Tensor: Shape [B..., *blocks, *block_size]
        If mask is not None, block_size will be an int equal to the number of
        True elements in the mask. Otherwise it will be the full block shape.
    """
    flat_in, shape_info, complex_input = _prep_unfold(x, block_size, stride, mask)
    if complex_input:
        # complex (..., n) -> real (..., n, 2) -> real (..., 2n)
        flat_in = torch.view_as_real(flat_in).flatten(-2, -1)

    flat_blocks = _unfold(flat_in, output=output, **shape_info)
    blocks = flat_blocks.reshape(
        *shape_info["batch_shape"],
        *shape_info["nblocks"],
        *shape_info["block_size"],
    )

    if complex_input:
        # Undo the interleaving: split the last axis back into (n, 2) pairs.
        leading = blocks.shape[:-1]
        doubled = blocks.shape[-1]
        blocks = torch.view_as_complex(blocks.reshape(*leading, doubled // 2, 2))

    if mask is None:
        return blocks
    return blocks[..., mask]

torchlinops.functional.fold

fold(
    x,
    im_size: tuple,
    block_size: tuple,
    stride: tuple,
    mask: Optional[Bool[Tensor, ...]] = None,
    output: Optional[Tensor] = None,
) -> Tensor

Accumulate an array of blocks into a full array

PARAMETER DESCRIPTION
x

Shape [B..., blocks, block_size]

TYPE: Tensor

RETURNS DESCRIPTION
Tensor

If mask is not None, block_size will be an int equal to the number of True elements in the mask. Otherwise, it will be the full block shape.

TYPE: Shape [B..., *im_size]

Source code in src/torchlinops/functional/_unfold/fold.py
def fold(
    x,
    im_size: tuple,
    block_size: tuple,
    stride: tuple,
    mask: Optional[Bool[Tensor, "..."]] = None,
    output: Optional[Tensor] = None,
) -> Tensor:
    """Accumulate an array of blocks into a full array.

    Complex input is reinterpreted as real (real/imag parts interleaved along
    the last dimension) before ``_fold`` is called, and converted back to
    complex afterwards. Shape bookkeeping is delegated to ``_prep_fold``.

    Parameters
    ----------
    x : Tensor
        Shape [B..., blocks, block_size]
        NOTE(review): with a mask, block_size is presumably a single int equal
        to the number of True elements in the mask (mirroring the output of
        ``unfold``); otherwise the full block shape. Confirm in ``_prep_fold``.
    im_size : tuple
        Forwarded to ``_prep_fold``.
    block_size : tuple
        Forwarded to ``_prep_fold``.
    stride : tuple
        Forwarded to ``_prep_fold``.
    mask : Tensor, optional
        Forwarded to ``_prep_fold``.
    output : Tensor, optional
        Forwarded to ``_fold`` as its ``output`` argument.

    Returns
    -------
    Tensor: Shape [B..., *im_size]
    """
    flat_in, shape_info, complex_input = _prep_fold(
        x, im_size, block_size, stride, mask
    )
    if complex_input:
        # complex (..., n) -> real (..., n, 2) -> real (..., 2n)
        flat_in = torch.view_as_real(flat_in).flatten(-2, -1)

    flat_image = _fold(flat_in, output=output, **shape_info)
    image = flat_image.reshape(*shape_info["batch_shape"], *shape_info["im_size"])

    if not complex_input:
        return image

    # Undo the interleaving: split the last axis back into (n, 2) pairs.
    leading = image.shape[:-1]
    doubled = image.shape[-1]
    return torch.view_as_complex(image.reshape(*leading, doubled // 2, 2))

torchlinops.functional.array_to_blocks

array_to_blocks(
    input,
    block_shape: tuple[int, ...],
    stride: Optional[tuple[int, ...]] = None,
    mask: Optional[Bool[Tensor, ...]] = None,
    out: Optional[Tensor] = None,
)

Wrapper for default arguments

Source code in src/torchlinops/functional/_unfold/array_to_blocks.py
def array_to_blocks(
    input,
    block_shape: tuple[int, ...],
    stride: Optional[tuple[int, ...]] = None,
    mask: Optional[Bool[Tensor, "..."]] = None,
    out: Optional[Tensor] = None,
):
    """Call ``ArrayToBlocksFn.apply`` with defaults filled in.

    All five arguments are forwarded positionally, in signature order; the
    optional ones (``stride``, ``mask``, ``out``) default to ``None``.

    NOTE(review): ``ArrayToBlocksFn`` looks like a ``torch.autograd.Function``,
    whose ``apply`` has no defaults of its own — confirm.
    """
    forwarded = (input, block_shape, stride, mask, out)
    return ArrayToBlocksFn.apply(*forwarded)

torchlinops.functional.blocks_to_array

blocks_to_array(
    input,
    im_size: tuple,
    block_shape: tuple,
    stride: Optional[tuple] = None,
    mask: Optional[Bool[Tensor, ...]] = None,
    out: Optional[Tensor] = None,
)

Wrapper for default arguments

Source code in src/torchlinops/functional/_unfold/array_to_blocks.py
def blocks_to_array(
    input,
    im_size: tuple,
    block_shape: tuple,
    stride: Optional[tuple] = None,
    mask: Optional[Bool[Tensor, "..."]] = None,
    out: Optional[Tensor] = None,
):
    """Call ``BlocksToArrayFn.apply`` with defaults filled in.

    All six arguments are forwarded positionally, in signature order; the
    optional ones (``stride``, ``mask``, ``out``) default to ``None``.

    NOTE(review): ``BlocksToArrayFn`` looks like a ``torch.autograd.Function``,
    whose ``apply`` has no defaults of its own — confirm.
    """
    forwarded = (input, im_size, block_shape, stride, mask, out)
    return BlocksToArrayFn.apply(*forwarded)