Skip to content

Conjugate Gradient

Preconditioned conjugate gradient solver for linear systems.

torchlinops.alg.conjugate_gradients

conjugate_gradients(
    A: Callable,
    y: Tensor,
    x0: Optional[Tensor] = None,
    max_num_iters: int = 20,
    gtol: float = 1e-3,
    ltol: float = 1e-05,
    disable_tracking: bool = False,
    tqdm_kwargs: Optional[dict] = None,
) -> Tensor | None

Solve \(Ax = y\) with the conjugate gradient method.

\(A\) must be Hermitian positive semidefinite. The algorithm iterates at most max_num_iters times or until both the loss-difference and gradient-norm convergence criteria are met.

PARAMETER DESCRIPTION
A

Function implementing the matrix-vector product \(A(x)\).

TYPE: Callable[[Tensor], Tensor]

y

Right-hand side of the linear system.

TYPE: Tensor

x0

Initial guess. Defaults to the zero vector.

TYPE: Tensor DEFAULT: None

max_num_iters

Maximum number of CG iterations.

TYPE: int DEFAULT: 20

gtol

Convergence tolerance on the gradient norm \(\|Ax - y\|\).

TYPE: float DEFAULT: 1e-3

ltol

Convergence tolerance on the absolute change in loss between successive iterations.

TYPE: float DEFAULT: 1e-5

disable_tracking

If True, skip loss/gradient tracking for speed (convergence checking is also disabled).

TYPE: bool DEFAULT: False

tqdm_kwargs

Extra keyword arguments forwarded to tqdm.

TYPE: dict DEFAULT: None

RETURNS DESCRIPTION
Tensor or None

The approximate solution \(x\), or None if the solver was not able to produce a result.

Source code in src/torchlinops/alg/pcg.py
def conjugate_gradients(
    A: Callable,
    y: Tensor,
    x0: Optional[Tensor] = None,
    max_num_iters: int = 20,
    gtol: float = 1e-3,
    ltol: float = 1e-5,
    disable_tracking: bool = False,
    tqdm_kwargs: Optional[dict] = None,
) -> Tensor | None:
    """Solve $Ax = y$ with the conjugate gradient method.

    $A$ must be Hermitian positive semidefinite. The algorithm iterates at
    most *max_num_iters* times or until both the loss-difference and
    gradient-norm convergence criteria are met. Iteration also stops early
    if the search direction has zero curvature ($p^H A p = 0$), which
    happens when the current iterate is already exact or when $p$ lies in
    the null space of a semidefinite $A$; continuing in either case would
    divide by zero and fill the iterate with NaNs.

    Parameters
    ----------
    A : Callable[[Tensor], Tensor]
        Function implementing the matrix-vector product $A(x)$.
    y : Tensor
        Right-hand side of the linear system.
    x0 : Tensor, optional
        Initial guess. Defaults to the zero vector.
    max_num_iters : int, default 20
        Maximum number of CG iterations.
    gtol : float, default 1e-3
        Convergence tolerance on the gradient norm $\\|Ax - y\\|$.
    ltol : float, default 1e-5
        Convergence tolerance on the absolute change in loss between
        successive iterations.
    disable_tracking : bool, default False
        If ``True``, skip loss/gradient tracking for speed (convergence
        checking is also disabled).
    tqdm_kwargs : dict, optional
        Extra keyword arguments forwarded to ``tqdm``.

    Returns
    -------
    Tensor or None
        The approximate solution $x$, or ``None`` if the solver was not
        able to produce a result.
    """
    # Default values
    if x0 is None:
        x = torch.zeros_like(y)
    else:
        x = x0.clone()
    tqdm_kwargs = default_to_dict(dict(desc="CG", leave=False), tqdm_kwargs)

    # Initialize run
    run = CGRun(ltol, gtol, A, y, disable=disable_tracking)
    run.update(x)

    r = y - A(x)
    p = r.clone()
    rs = zdot(r, r).real
    with tqdm(range(max_num_iters), **tqdm_kwargs) as pbar:
        for k in pbar:
            Ap = A(p)
            # Hermitian A guarantees p^H A p is real; discard any roundoff
            # imaginary part so alpha stays real.
            pAp = zdot(p, Ap).real
            # Zero curvature: either r = 0 (x is already exact, e.g. an
            # exact initial guess) or p is in the null space of A. Stepping
            # would divide by zero and produce NaNs — stop instead.
            if pAp == 0:
                break
            alpha = rs / pAp
            # Take step
            x = x + alpha * p
            r = r - alpha * Ap
            # rs is only ever rebound (never mutated in place), so no
            # clone is needed to snapshot it.
            rs_old = rs
            rs = zdot(r, r).real
            run.update(x)
            # Stopping criterion
            if run.is_converged():
                break

            run.set_postfix(pbar)
            beta = rs / rs_old
            p = beta * p + r
    return run.x_out