Multi-GPU Splitting
Linops should be able to take advantage of multi-GPU systems to leverage the larger total GPU memory available and to gain increased speed from parallelization across separate devices.
We assume CUDA devices with peer-to-peer memory access.
The splitting mechanism
At its core, multi-GPU distribution is built on the ability to split a linop into smaller sub-linops that each operate on a slice of the data.
split_forward
Every NamedLinop can implement split_forward(ibatch, obatch), where ibatch and obatch are lists of slices corresponding to the input and output dimensions. The method returns a new linop that operates only on the specified slice.
For example, a Diagonal linop with shape (Nx, Ny) -> (Nx, Ny) and a weight tensor of shape (256, 256) can be split along Nx into two sub-linops, each with a weight of shape (128, 256).
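To make the mechanism concrete, here is a minimal pure-Python sketch of the splitting idea, using nested lists as stand-ins for tensors. The class below is illustrative only, not the actual torchlinops `Diagonal`:

```python
# Sketch (not the torchlinops implementation): a diagonal linop y = weight * x,
# split along its first dimension. Nested lists stand in for (Nx, Ny) tensors.

class Diagonal:
    def __init__(self, weight):
        self.weight = weight

    def __call__(self, x):
        # Elementwise multiply: y[i][j] = weight[i][j] * x[i][j]
        return [[w * v for w, v in zip(wrow, xrow)]
                for wrow, xrow in zip(self.weight, x)]

    def split_forward(self, ibatch, obatch):
        # For a diagonal operator the input and output slices coincide,
        # so the sub-linop just keeps the matching slice of the weight.
        assert ibatch == obatch
        return Diagonal(self.weight[ibatch])

weight = [[2.0] * 4 for _ in range(4)]                     # (4, 4) weight
x = [[float(i + j) for j in range(4)] for i in range(4)]   # (4, 4) input

full = Diagonal(weight)
top = full.split_forward(slice(0, 2), slice(0, 2))

# The sub-linop applied to the input slice matches the sliced full output.
assert top(x[0:2]) == full(x)[0:2]
```

The key property is the last assertion: splitting commutes with slicing, so tiles can be computed independently and reassembled.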
The static split(linop, tile) method provides a higher-level interface that accepts a dictionary mapping dimension names to slices:
```python
from torchlinops.nameddim import NamedDimension as ND

tile = {ND("Nx"): slice(0, 128)}
sub_linop = NamedLinop.split(A, tile)
```
For adjoint splitting, use `adj_split(linop, tile)`, which constructs the adjoint, splits it according to `tile`, and returns the adjoint of the split.
split_linop
The split_linop() function automates the tiling process. Given a linop and a dictionary of batch sizes, it:
- Queries `linop.size(dim)` for every dimension to determine the total size.
- Creates a grid of tiles, where each tile maps dimensions to `(index, slice)` pairs.
- Calls `split()` for each tile, producing an nd-array of sub-linops.
```python
from torchlinops.linops.split import split_linop

# Split a linop into chunks of 128 along Nx and 64 along Ny
linops, ibatches, obatches = split_linop(A, {"Nx": 128, "Ny": 64})
# linops is a 2D numpy array of sub-linops
```
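The tiling logic can be sketched as follows (hypothetical helper, not the internal implementation): each dimension is cut into chunks, and the Cartesian product of the per-dimension chunks forms the tile grid.

```python
# Sketch of the tile-grid construction inside split_linop: build a grid of
# tiles, each mapping a dimension name to an (index, slice) pair.
from itertools import product

def make_tiles(sizes, batch_sizes):
    """sizes: total size per dim; batch_sizes: chunk size per dim."""
    per_dim = {}
    for dim, total in sizes.items():
        chunk = batch_sizes.get(dim, total)
        per_dim[dim] = [
            (i, slice(start, min(start + chunk, total)))
            for i, start in enumerate(range(0, total, chunk))
        ]
    dims = list(per_dim)
    # Cartesian product over dims yields the nd-grid of tiles.
    return [dict(zip(dims, combo)) for combo in product(*per_dim.values())]

tiles = make_tiles({"Nx": 256, "Ny": 128}, {"Nx": 128, "Ny": 64})
# 2 chunks along Nx x 2 chunks along Ny -> 4 tiles
assert len(tiles) == 4
assert tiles[0] == {"Nx": (0, slice(0, 128)), "Ny": (0, slice(0, 64))}
```

Note that the final chunk along a dimension may be smaller than the requested batch size when the total size is not an exact multiple.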
Chain splitting
When splitting a Chain (a composition of linops), each constituent linop is split independently according to the slices over its own dimensions. This means that a chain \(C \circ B \circ A\) is split into tiles where each tile is a chain of the corresponding sub-linops.
The Chain.split_forward() method receives lists of slices per constituent linop:
- ibatches: list of lists, one per linop in the chain
- obatches: list of lists, one per linop in the chain
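A pure-Python sketch of this composition rule (simplified one-dimensional stand-ins with one slice per constituent, not the torchlinops classes):

```python
# Sketch: each constituent linop is split with its own slices, and the tile
# is the chain of the resulting sub-linops.

class Scale:
    def __init__(self, factors):
        self.factors = factors
    def __call__(self, x):
        return [f * v for f, v in zip(self.factors, x)]
    def split_forward(self, ibatch, obatch):
        return Scale(self.factors[ibatch])

class Chain:
    def __init__(self, *linops):  # applied right to left, like composition
        self.linops = linops
    def __call__(self, x):
        for linop in reversed(self.linops):
            x = linop(x)
        return x
    def split_forward(self, ibatches, obatches):
        # One (ibatch, obatch) pair per constituent linop
        subs = [lin.split_forward(ib, ob)
                for lin, ib, ob in zip(self.linops, ibatches, obatches)]
        return Chain(*subs)

A = Scale([2.0] * 4)
B = Scale([3.0] * 4)
chain = Chain(B, A)  # y = B(A(x))
tile = chain.split_forward([slice(0, 2)] * 2, [slice(0, 2)] * 2)

# The tile of the chain equals the chain of the tiles.
assert tile([1.0, 1.0]) == chain([1.0] * 4)[:2]
```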
BatchSpec
BatchSpec is a dataclass that bundles all the information needed to split and distribute a linop:
| Field | Description |
|---|---|
| `batch_sizes` | `dict[dim, int]` -- how large each chunk should be |
| `device_matrix` | Optional array of `torch.device` objects, one per tile |
| `base_device` | The device where input/output tensors live |
The `device_matrix` is broadcast to match the tile grid shape. For example, if splitting creates a 4-tile grid and `device_matrix = ["cuda:0", "cuda:1"]`, it is repeated to `["cuda:0", "cuda:1", "cuda:0", "cuda:1"]`. The broadcasting uses a fuzzy strategy that tiles and truncates as needed, so the device list does not need to exactly match the number of tiles.
BatchSpec has a broadcast_device_matrix(linop) method that computes the number of tiles along each batched dimension and broadcasts the device matrix accordingly.
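The tile-and-truncate broadcast can be sketched as a small helper (the function name is hypothetical; devices are plain strings here):

```python
# Sketch of the "fuzzy" broadcast: repeat the device list enough times to
# cover all tiles, then truncate to exactly one device per tile.

def broadcast_devices(device_matrix, n_tiles):
    repeats = -(-n_tiles // len(device_matrix))  # ceiling division
    return (device_matrix * repeats)[:n_tiles]

assert broadcast_devices(["cuda:0", "cuda:1"], 4) == \
    ["cuda:0", "cuda:1", "cuda:0", "cuda:1"]
# An odd tile count simply truncates the repeated list.
assert broadcast_devices(["cuda:0", "cuda:1"], 3) == \
    ["cuda:0", "cuda:1", "cuda:0"]
```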
create_batched_linop
This is the main entry point for multi-GPU distribution. It takes a linop and one or more BatchSpec objects and returns a new composite linop that transparently handles splitting, device placement, and reassembly.
How it works
- Split the linop into tiles according to `batch_sizes`.
- Place each tile on its target device using `ModuleMemoryMap.memory_aware_to()`, which preserves tensor storage topology (see Copying Linops).
- Wrap each tile with `ToDevice` linops for input transfer (base -> target) and output collection (target -> base).
- Reassemble tiles by reducing along each split dimension:
  - If the dimension appears in both `ishape` and `oshape`: use `Concat` (the tiles partition the data along that dim).
  - If the dimension appears only in `ishape` or only in `oshape`: use `Concat` along the relevant side.
  - If the dimension appears in neither: use `Add` (the tiles produce partial results that must be summed).
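The case analysis for reassembly can be sketched as a small helper (names hypothetical; the two `Concat` cases are collapsed, since both stitch partitioned data back together):

```python
# Sketch of the reduction choice for a split dimension, based on where the
# dimension appears in the linop's overall input/output shapes.

def reduction_for(dim, ishape, oshape):
    if dim in ishape or dim in oshape:
        # The tiles partition data along this dim: stitch with Concat.
        return "Concat"
    # The dim was contracted away inside the linop (e.g. inside a chain):
    # tiles hold partial results that must be summed.
    return "Add"

# A diagonal-like linop (Nx, Ny) -> (Nx, Ny) split along Nx: concatenate.
assert reduction_for("Nx", ("Nx", "Ny"), ("Nx", "Ny")) == "Concat"
# "K" appears in neither overall shape: tiles hold partial sums.
assert reduction_for("K", ("Nx", "Ny"), ("Nx", "Ny")) == "Add"
```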
The result is a single composite linop that behaves identically to the original but executes across multiple devices.
Recursive batching
create_batched_linop accepts a list of BatchSpec objects and processes them recursively. This enables multi-level splitting -- for example, first splitting across GPUs along one dimension, then splitting within each GPU along another dimension for memory management.
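The multi-level idea can be sketched over plain index ranges (illustrative only, not the torchlinops API): an outer split produces coarse chunks, and the remaining batch sizes are applied within each chunk.

```python
# Sketch of recursive splitting: apply a list of chunk sizes level by level,
# e.g. first across GPUs, then within each GPU for memory management.

def split_slices(total, chunk):
    return [slice(s, min(s + chunk, total)) for s in range(0, total, chunk)]

def recursive_split(total, batch_sizes):
    if not batch_sizes:
        return slice(0, total)
    first, *rest = batch_sizes
    return [
        # Pair each coarse chunk with the finer splits inside it.
        (outer, recursive_split(outer.stop - outer.start, rest))
        for outer in split_slices(total, first)
    ]

# 256 along Nx: 2 GPU-level chunks of 128, each split into chunks of 64.
tiles = recursive_split(256, [128, 64])
assert len(tiles) == 2
outer, inner = tiles[0]
assert outer == slice(0, 128)
assert inner == [(slice(0, 64), slice(0, 64)), (slice(64, 128), slice(0, 64))]
```

Note that the inner slices are relative to their coarse chunk, mirroring how each per-GPU sub-linop only sees its own slice of the data.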
Data transfer and synchronization
DeviceSpec
DeviceSpec is a lightweight dataclass that holds useful CUDA-related objects for multi-GPU computation:
| Field | Description |
|---|---|
| `device` | The `torch.device` for this specification |
| `compute_stream` | Stream used for computation on this device |
| `transfer_stream` | Stream used for data transfers to/from this device |
DeviceSpec has a p2p_setup(other_device) method that configures compute and transfer streams for peer-to-peer transfers between devices. This is called automatically when creating ToDevice linops between CUDA devices.
The transfer stream is obtained from a registry (_TRANSFER_STREAMS_REGISTRY) to enable stream reuse. Each source/target device pair gets a dedicated transfer stream.
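The registry pattern can be sketched with an ordinary dict keyed by the device pair (stand-in objects replace CUDA streams; names are illustrative):

```python
# Sketch of a per-device-pair stream registry: the same (source, target)
# pair always reuses one transfer stream.

_TRANSFER_STREAMS = {}

def get_transfer_stream(source, target, make_stream=object):
    key = (source, target)
    if key not in _TRANSFER_STREAMS:
        # In the real library this would allocate a CUDA stream,
        # e.g. torch.cuda.Stream(device=...).
        _TRANSFER_STREAMS[key] = make_stream()
    return _TRANSFER_STREAMS[key]

s1 = get_transfer_stream("cuda:0", "cuda:1")
s2 = get_transfer_stream("cuda:0", "cuda:1")
s3 = get_transfer_stream("cuda:1", "cuda:0")
assert s1 is s2      # same pair reuses the same stream
assert s1 is not s3  # each direction gets its own stream
```

Reusing streams this way avoids allocating a new CUDA stream on every linop application, which would defeat cross-call kernel queuing.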
ToDevice
ToDevice is a specialized linop that moves tensors between devices. It is the glue between the base device (where input/output data lives) and the target devices (where computation happens).
Key attributes:
| Attribute | Type | Description |
|---|---|---|
| `ispec` | `DeviceSpec` | Source (input) device specification |
| `ospec` | `DeviceSpec` | Target (output) device specification |
| `input_listener` | `tuple(linop, str)` or `None` | Event to wait on before transferring |
| `is_gpu2gpu` | `bool` | `True` if both source and target are CUDA devices |
The adjoint of ToDevice(A -> B) is ToDevice(B -> A) -- it simply reverses the direction and swaps the device specs.
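A sketch of this adjoint rule with simplified dataclasses (using `.H` for the adjoint, a common linop convention; these are not the actual torchlinops classes):

```python
# Sketch: the adjoint of a device transfer just swaps the device specs.
from dataclasses import dataclass

@dataclass
class DeviceSpec:
    device: str

@dataclass
class ToDevice:
    ispec: DeviceSpec
    ospec: DeviceSpec

    @property
    def H(self):
        # Reverse the transfer direction.
        return ToDevice(ispec=self.ospec, ospec=self.ispec)

t = ToDevice(DeviceSpec("cuda:0"), DeviceSpec("cuda:1"))
assert t.H.ispec.device == "cuda:1"
assert t.H.ospec.device == "cuda:0"
assert t.H.H == t  # taking the adjoint twice recovers the original
```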
For CUDA-to-CUDA transfers, ToDevice uses non-blocking operations on specific streams:
```
Input on base_device
  -> transfer_stream: non-blocking .to(target_device)
  -> target_stream: wait for transfer, run computation
  -> base_stream: collect output back to base_device
```
Key implementation details:
- `x.record_stream(stream)` prevents PyTorch's caching allocator from freeing the source tensor's memory before the transfer completes.
- `ostream.wait_stream(istream)` ensures the target stream does not start computation until the data has arrived.
- The transfer stream is obtained via `DeviceSpec.get_transfer_stream(source, target)`.
Input Listeners
The input_listener attribute enables coordination between parallel GPU transfers. It specifies an event (via a tuple of (linop, attribute_name)) that the transfer should wait on before initiating.
This is particularly useful when multiple `ToDevice` operations need to be triggered in parallel: by setting both `ToDevice1.input_listener` and `ToDevice2.input_listener` to reference the start of a common linop `C`, both device movements can be triggered in parallel when `C` begins execution.
The NamedLinop.input_listener property uses ForwardedAttribute to enable this cross-linop attribute forwarding.
RepeatedEvent
RepeatedEvent is a lightweight wrapper around CUDA events that creates a fresh event on each record() call. This is used as the start_event on the top-level batched linop: when forward() is called, it records an event that all ToDevice input transfers wait on.
Rather than creating new events and re-registering them every time the linop needs to be run, the RepeatedEvent automatically refreshes itself on each call.
This start_event is necessary to prevent computation or data transfer on other streams from occurring before the start of the linop itself, since repeated linop applications automatically queue kernels on those other streams.
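The idea can be sketched with a stand-in event class (`FakeEvent` replaces `torch.cuda.Event`; this is illustrative, not the library's implementation):

```python
# Sketch of RepeatedEvent: produce a fresh event on every record() call,
# so waiters always observe the most recent recording.

class FakeEvent:
    def __init__(self):
        self.recorded = False
    def record(self):
        self.recorded = True

class RepeatedEvent:
    def __init__(self, event_factory=FakeEvent):
        self._factory = event_factory
        self.event = event_factory()

    def record(self):
        # Replace the underlying event instead of re-recording the old one.
        self.event = self._factory()
        self.event.record()
        return self.event

ev = RepeatedEvent()
first = ev.record()
second = ev.record()
assert first is not second  # a fresh event per call
assert ev.event is second   # waiters see the most recent event
assert second.recorded
```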
Stream workflow
The full execution flow for a multi-GPU forward pass:
- `start_event.record()` on the current stream -- signals that input data is ready.
- For each tile:
  - `transfer_stream.wait_event(start_event)` -- wait for input.
  - Transfer the input slice to the target device via `transfer_stream`.
  - `target_stream.wait_stream(transfer_stream)` -- wait for data arrival.
  - Compute on `target_stream`.
  - Transfer the output back to the base device.
- Reassemble outputs on the base device (via `Concat` or `Add`).
Additionally, there is a notion of a "base" device. The base device orchestrates all the transfers and is the device on which the input is required and on which the final output is ultimately produced.
For a two-GPU setup with GPU0 as the base device and GPU1 as a target, the default behavior is:
- GPU0
  - `default_stream`: computation
  - `transfer_stream`: moving tensors between GPU0 and GPU1
- GPU1
  - `default_stream`: computation
Configuration
Set torchlinops.config.log_device_transfers = True to enable debug logging of CUDA events, stream synchronization, and device transfers. This is useful for debugging multi-GPU workflows.
```python
import torchlinops.config as config

config.log_device_transfers = True   # Enable logging
config.log_device_transfers = False  # Disable logging
```
Limitations and future work
- Peer-to-peer access: The current implementation assumes efficient P2P memory access between GPUs. Systems without P2P will fall back to staging through host memory, which is slower.
- Manual tuning: Choosing optimal `batch_sizes` and `device_matrix` requires understanding the model's memory footprint and the hardware topology. No auto-tuning is provided.
- Single-node only: Running computations on distributed GPU nodes across multiple servers is possible in principle via standard PyTorch distributed APIs, but no simplified API is provided within this library.