Skip to content

Splitting

torchlinops.linops.split_linop

split_linop(
    linop: NamedLinop,
    batch_sizes: dict[NamedDimension | str, int],
)

Split a linop into an nd-array of sub-linops according to batch sizes.

PARAMETER DESCRIPTION
linop

The linop to be split.

TYPE: NamedLinop

batch_sizes

Dictionary mapping dimension names to chunk sizes.

TYPE: dict[NamedDimension | str, int]

RETURNS DESCRIPTION
linops

Array of sub-linops with shape determined by the number of tiles per dimension.

TYPE: ndarray

input_batches

Corresponding input slices for each tile.

TYPE: ndarray

output_batches

Corresponding output slices for each tile.

TYPE: ndarray

Source code in src/torchlinops/linops/split.py
def split_linop(linop: NamedLinop, batch_sizes: dict[ND | str, int]):
    """Split a linop into an nd-array of sub-linops according to batch sizes.

    Parameters
    ----------
    linop : NamedLinop
        The linop to be split.
    batch_sizes : dict[ND | str, int]
        Dictionary mapping dimension names to chunk sizes.

    Returns
    -------
    linops : np.ndarray
        Array of sub-linops with shape determined by the number of tiles
        per dimension.
    input_batches : np.ndarray
        Corresponding input slices for each tile.
    output_batches : np.ndarray
        Corresponding output slices for each tile.
    """
    # Normalize keys to NamedDimension and look up every dimension's size once.
    batch_sizes = {ND.infer(name): chunk for name, chunk in batch_sizes.items()}
    dim_sizes = {d: linop.size(d) for d in linop.dims}

    # Build the tile set: each tile maps a dimension to (tile index, slice).
    batch_iterators = make_batch_iterators(dim_sizes, batch_sizes)
    tiles: list[dict[ND, Batch]] = list(dict_product(batch_iterators))

    # One grid cell per tile; ceil handles a final partial chunk.
    batched_dims = list(batch_sizes)
    grid_shape = tuple(ceil(dim_sizes[d] / batch_sizes[d]) for d in batched_dims)
    linops = np.empty(grid_shape, dtype=object)
    input_batches = np.empty(grid_shape, dtype=object)
    output_batches = np.empty(grid_shape, dtype=object)

    for tile in tiles:
        idx = _tile_get_idx(tile, batched_dims)
        sub_linop = _split_linop_with_tile(linop, tile)

        # Input slices follow the first linop in the chain, output slices the
        # last; dims not batched in this tile fall back to DEFAULT_BATCH.
        chain = sub_linop.flatten()
        first, last = chain[0], chain[-1]
        linops[idx] = sub_linop
        input_batches[idx] = [tile.get(d, DEFAULT_BATCH)[1] for d in first.ishape]
        output_batches[idx] = [tile.get(d, DEFAULT_BATCH)[1] for d in last.oshape]
    return linops, input_batches, output_batches

torchlinops.linops.create_batched_linop

create_batched_linop(
    linop,
    batch_specs: BatchSpec | list[BatchSpec],
    default_device: device = None,
    _mmap=None,
)

Split and distribute a linop across devices according to batch specs.

Recursively processes a list of BatchSpec objects: the first spec splits the linop into tiles, optionally places each tile on a target device, then passes remaining specs to each tile recursively. Tiles are reassembled via Concat (for partitioned dimensions) or Add (for reduced dimensions).

PARAMETER DESCRIPTION
linop

The operator to split and distribute.

TYPE: NamedLinop

batch_specs

One or more batch specifications to apply (processed in order).

TYPE: BatchSpec or list[BatchSpec]

_mmap

Internal memory map for efficient device transfers. Created automatically on the first call. Probably don't set this manually.

TYPE: ModuleMemoryMap DEFAULT: None

default_device

The default device to use if no device info is provided in the batch spec.

TYPE: device

RETURNS DESCRIPTION
NamedLinop

A composite linop (tree of Concat/Add/ToDevice operators) that is functionally equivalent to the original but distributed according to the batch specs.

Source code in src/torchlinops/linops/split.py
def create_batched_linop(
    linop,
    batch_specs: BatchSpec | list[BatchSpec],
    default_device: torch.device | None = None,
    _mmap=None,
):
    """Split and distribute a linop across devices according to batch specs.

    Recursively processes a list of ``BatchSpec`` objects: the first spec
    splits the linop into tiles, optionally places each tile on a target
    device, then passes remaining specs to each tile recursively. Tiles are
    reassembled via ``Concat`` (for partitioned dimensions) or ``Add`` (for
    reduced dimensions).

    Parameters
    ----------
    linop : NamedLinop
        The operator to split and distribute.
    batch_specs : BatchSpec or list[BatchSpec]
        One or more batch specifications to apply (processed in order).
    default_device : torch.device, optional
        The default device to use if no device info is provided in the batch
        spec. Defaults to ``torch.device("cpu")``.
    _mmap : ModuleMemoryMap, optional
        Internal memory map for efficient device transfers. Created
        automatically on the first call. Probably don't set this manually.

    Returns
    -------
    NamedLinop
        A composite linop (tree of ``Concat``/``Add``/``ToDevice`` operators)
        that is functionally equivalent to the original but distributed
        according to the batch specs.
    """
    if default_device is None:
        default_device = torch.device("cpu")
    if isinstance(batch_specs, BatchSpec):
        # Normalize a lone spec to a one-element list so the recursion can
        # uniformly peel specs off the front.
        batch_specs = [batch_specs]
    if _mmap is None:
        # Top-level (non-recursive) call: build the memory map once and thread
        # it through every recursive call below.
        _mmap = ModuleMemoryMap()
        _mmap.register_module(linop)
    if len(batch_specs) == 0:
        # Base case of the recursion: no specs left to apply.
        return linop
    # Copy the spec before filling in defaults so the caller's object is
    # never mutated.
    batch_spec = deepcopy(batch_specs[0])
    # Set defaults
    batch_spec.base_device = default_to(default_device, batch_spec.base_device)
    batch_spec.device_matrix = default_to(
        np.array([default_device]), batch_spec.device_matrix
    )

    # Split linop into tiles and broadcast device spec to the tile array.
    linops, ibatches, obatches = split_linop(linop, batch_spec.batch_sizes)
    device_matrix = batch_spec.broadcast_device_matrix(linop)
    if device_matrix.shape != linops.shape:
        raise ValueError(
            f"device_matrix and linops should have same shape after broadcasting, but got device_matrix: {device_matrix.shape} and linops: {linops.shape}"
        )

    # Data enters and leaves every tile from the spec's base device.
    source_device = batch_spec.base_device
    # Apply the remaining specs to each tile and place it on its device.
    # NOTE(review): the loop variable ``linop`` shadows the function parameter
    # of the same name from here on.
    for idx in np.ndindex(linops.shape):
        linop, target_device = linops[idx], device_matrix[idx]

        # Recursive call to batch the tile
        tiled_linop = create_batched_linop(
            linop, batch_specs[1:], default_device=target_device, _mmap=_mmap
        )

        # Move linop to device
        tiled_linop = _mmap.memory_aware_to(tiled_linop, target_device)

        # Wrap with device movement linops so inputs arrive on the tile's
        # device and outputs return to the base device.
        if source_device != target_device:
            tiled_linop = Chain(
                ToDevice(
                    source_device,
                    target_device,
                    ioshape=tiled_linop.ishape,
                ),
                tiled_linop,
                ToDevice(
                    target_device,
                    source_device,
                    ioshape=tiled_linop.oshape,
                ),
            )

        # Overwrite entry in linops
        linops[idx] = tiled_linop

    # Collapse the tile grid one axis at a time, last axis first, combining
    # tiles along each batched dimension.
    for dim in reversed(batch_spec.batch_sizes):
        # Manual axis reduction because I made Concat and Add too nice
        flat_linops = linops.reshape(-1, linops.shape[-1])
        new_linops = np.empty(flat_linops.shape[0], dtype=object)
        for i, linop_arr in enumerate(flat_linops):
            # Inspect one representative tile to decide how to recombine:
            # a dim present in a shape was partitioned (Concat along it);
            # a dim absent from both shapes was reduced away (Add).
            # NOTE(review): ``dim`` is a raw batch_sizes key (possibly a plain
            # str) compared against ishape/oshape entries — confirm keys are
            # normalized consistently with the shape entries.
            linop = linop_arr[0]
            if dim in linop.ishape and dim in linop.oshape:
                new_linop = Concat(*linop_arr, idim=dim, odim=dim)
            elif dim not in linop.ishape and dim in linop.oshape:
                new_linop = Concat(*linop_arr, odim=dim)
            elif dim in linop.ishape and dim not in linop.oshape:
                new_linop = Concat(*linop_arr, idim=dim)
            else:
                new_linop = Add(*linop_arr)
            new_linops[i] = new_linop
        linops = new_linops.reshape(linops.shape[:-1])
    # All batched axes collapsed: a 0-d object array holding the composite.
    linop = linops.item()
    return linop

torchlinops.linops.BatchSpec dataclass

Specification for splitting and distributing a linop across devices.

PARAMETER DESCRIPTION
batch_sizes

Mapping from dimension names to chunk sizes for tiling.

TYPE: dict[NamedDimension | str, int]

device_matrix

Array of torch.device objects specifying target devices for each tile. Broadcast to match the tile grid shape.

TYPE: ndarray or list DEFAULT: None

base_device

The device where input/output data resides. Default is CPU.

TYPE: device DEFAULT: None

Source code in src/torchlinops/linops/split.py
@dataclass
class BatchSpec:
    """Specification for splitting and distributing a linop across devices.

    Parameters
    ----------
    batch_sizes : dict[ND | str, int]
        Mapping from dimension names to chunk sizes for tiling.
    device_matrix : np.ndarray or list, optional
        Array of ``torch.device`` objects specifying target devices for each
        tile. Broadcast to match the tile grid shape.
    base_device : torch.device, optional
        The device where input/output data resides. Default is CPU.
    """

    batch_sizes: dict[ND | str, int]
    device_matrix: np.ndarray | None = None
    base_device: torch.device | None = None

    def __post_init__(self):
        # Soft validation only: warn (rather than raise) on a non-dict value.
        if not isinstance(self.batch_sizes, dict):
            warn(
                f"Got {self.batch_sizes} of type {type(self.batch_sizes).__name__} for batch_sizes instead of dict."
            )
        # Coerce list/tuple device specifications into an ndarray so later
        # broadcasting works uniformly.
        if isinstance(self.device_matrix, (list, tuple)):
            self.device_matrix = np.array(self.device_matrix)

    def broadcast_device_matrix(self, linop):
        """Broadcast ``device_matrix`` to the tile grid implied by ``linop``."""
        # Number of tiles along each batched dimension (ceil covers a final
        # partial chunk).
        dim_sizes = {d: linop.size(d) for d in linop.dims}
        grid_shape = tuple(
            ceil(dim_sizes[d] / chunk) for d, chunk in self.batch_sizes.items()
        )

        # Each tile in the grid receives exactly one device.
        return fuzzy_broadcast_to(self.device_matrix, grid_shape)

__init__

__init__(
    batch_sizes: dict[NamedDimension | str, int],
    device_matrix: ndarray | None = None,
    base_device: device | None = None,
) -> None

broadcast_device_matrix

broadcast_device_matrix(linop)
Source code in src/torchlinops/linops/split.py
def broadcast_device_matrix(self, linop):
    """Broadcast ``self.device_matrix`` to the tile grid implied by ``linop``.

    The grid has one axis per key of ``self.batch_sizes``, sized by how many
    chunks that dimension splits into.
    """
    # Compute the number of tiles along each batched axis/dimension
    # (ceil accounts for a final partial chunk).
    batch_dims = list(self.batch_sizes.keys())
    sizes = {dim: linop.size(dim) for dim in linop.dims}
    tiled_shape = tuple(
        ceil(sizes[dim] / self.batch_sizes[dim]) for dim in batch_dims
    )

    # Broadcast device_matrix over requested tiles.
    # Each tile should receive a single device.
    device_matrix = fuzzy_broadcast_to(self.device_matrix, tiled_shape)
    return device_matrix