
API Reference

This page documents the public API of JAX-AMG.

Solver

jaxamg.solve(A, b, config=None, comm=None, nglobal=None, partition_info=None, save_stats_file=None, **kwargs)

Solve Ax=b using the AmgX backend. See Examples for usage.

Parameters:

Name Type Description Default
A MatrixOrOperator

Matrix or callable operator A(x). All matrices and operators are converted internally to jax.experimental.sparse.BCSR sparse matrices. In MPI mode, this is the local partition.

required
b ArrayLike

Right-hand-side vector. In MPI mode this is the local RHS partition.

required
config dict | None

AmgX configuration dictionary (see Solver Configuration for details). If None, JAX-AMG defaults are used.

None
comm Comm | None

MPI communicator (typically mpi4py.MPI.COMM_WORLD). If provided, the solve runs in MPI mode. If omitted, MPI mode can still be used when MPI metadata has already been attached via with_cache(..., mpi=...).

None
nglobal int | None

Global matrix row count for MPI mode. Required when comm is provided and MPI metadata is not pre-attached to A.

None
partition_info tuple[int, int] | None

Row range (row_start, row_end) owned by this rank in MPI mode. Required when comm is provided and MPI metadata is not pre-attached to A.

None
save_stats_file str | PathLike | None

Optional file path to save detailed AmgX solver statistics. If None, no file is created.

None
**kwargs Any

Additional AmgX config parameters. These override values in config when both are provided.

{}
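The A parameter accepts either an explicit matrix or a callable operator. As a minimal sketch using only standard JAX (the small tridiagonal matrix is purely illustrative), the same linear map can be built in both forms: as a BCSR matrix via jax.experimental.sparse.BCSR.fromdense, and as a matrix-free function returning A @ x.

```python
import jax.numpy as jnp
from jax.experimental import sparse

# Small symmetric positive-definite tridiagonal matrix (illustrative only).
dense = jnp.array([[4.0, -1.0, 0.0],
                   [-1.0, 4.0, -1.0],
                   [0.0, -1.0, 4.0]])

# Form 1: an explicit sparse matrix in BCSR layout (stores only the nonzeros).
A_bcsr = sparse.BCSR.fromdense(dense)

# Form 2: a matrix-free operator that returns A @ x.
def A_op(x):
    return dense @ x

# Right-hand side for a system A x = b.
b = jnp.ones(3)
```

Either A_bcsr or A_op is the kind of object the A parameter describes; per the description above, a callable operator is converted to BCSR internally before the solve.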

Returns:

Name Type Description
x Array

Solution vector (float32 or float64). In MPI mode, this is the local portion of the solution.

info dict

Dictionary containing iterations, residual, and status.
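In MPI mode without pre-attached metadata, every rank passes nglobal and its own partition_info. The hypothetical helper below sketches one common way to choose contiguous, balanced row blocks. It assumes row_end is exclusive (a half-open range), which this page does not state explicitly.

```python
def block_row_partition(nglobal: int, nranks: int, rank: int) -> tuple[int, int]:
    """Split nglobal rows into nranks contiguous blocks and return this rank's range.

    Assumption: the range is half-open, [row_start, row_end).
    The first nglobal % nranks ranks each receive one extra row.
    """
    if not 0 <= rank < nranks:
        raise ValueError("rank must be in [0, nranks)")
    base, extra = divmod(nglobal, nranks)
    # Every earlier rank contributed `base` rows, plus one if it was among the first `extra`.
    row_start = rank * base + min(rank, extra)
    row_end = row_start + base + (1 if rank < extra else 0)
    return row_start, row_end
```

In an mpi4py program, rank and nranks would typically come from comm.Get_rank() and comm.Get_size(), and the returned tuple would be passed as partition_info.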

Status Codes

jaxamg.AMGXStatus

Bases: IntEnum

High-level AmgX solve status codes returned in info["status"] after calling jaxamg.solve.

These values are mapped from the native backend status, for quick checks in Python code and for reference in documentation.

Members
  • SUCCESS: Solve converged successfully.
  • FAILED: Solver failed due to an internal/runtime error.
  • DIVERGED: Iterations diverged.
  • NOT_CONVERGED: A stopping criterion was reached before the solve converged.

Caching

jaxamg.with_cache(A, *, coloring=None, mpi=None, is_symmetric=False)

Attach cached metadata (coloring, MPI info, or symmetry) to a matrix or operator.

This cache allows using matrices/operators inside JIT-compiled functions without recomputing metadata or passing it as separate arguments. See Caching Guide for more details.

Parameters:

Name Type Description Default
A MatrixOrOperator

A matrix or operator.

required
coloring tuple[ndarray, ndarray, ndarray, int, tuple[int, int]] | None

Cached coloring information from cache_coloring().

None
mpi dict[str, Any] | None

Cached MPI metadata from cache_mpi_metadata().

None
is_symmetric bool

If True, declares that the matrix is symmetric, enabling optimizations such as skipping the transpose in the backward pass.

False

Returns:

Type Description
MatrixOrOperator

The same matrix/operator with requested cache attached.
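Setting is_symmetric=True asserts a property that, per the description above, allows the transpose to be skipped in the backward pass, so a wrong flag would presumably yield incorrect gradients. For small test problems, a cheap dense check before setting the flag is one option; looks_symmetric is a hypothetical helper, not part of jaxamg.

```python
import numpy as np


def looks_symmetric(A_dense, rtol: float = 1e-10, atol: float = 1e-12) -> bool:
    """Return True if a square dense matrix equals its transpose within tolerance."""
    A_dense = np.asarray(A_dense)
    # Non-square or non-2D inputs cannot be symmetric.
    if A_dense.ndim != 2 or A_dense.shape[0] != A_dense.shape[1]:
        return False
    return bool(np.allclose(A_dense, A_dense.T, rtol=rtol, atol=atol))
```

For large matrices, a dense copy is impractical; the check is meant only as a sanity test on small instances of the same structure.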

jaxamg.cache_coloring(operator, shape)

Compute and cache coloring information for a callable operator.

Sparsity detection uses two methods, so the result is correct for any operator:

  1. Tracing: interpret the operator's jaxpr to recover the exact sparsity pattern from a single trace (no probing), then color and materialize it. This applies to operators written in JAX, except those that cannot be traced structurally (opaque calls, data-dependent indexing), for which it is skipped.
  2. Probing (probe_sparsity_pattern + get_column_coloring): exhaustive probing with one-hot basis vectors. It is correct for any operator and serves as the fallback when tracing is unavailable.
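The probing fallback can be illustrated with a minimal NumPy sketch: apply the operator to each basis vector e_j to read off column j, then color the columns so that no two columns sharing a nonzero row get the same color. This shows the idea behind probe_sparsity_pattern and get_column_coloring but is not their implementation; probe_pattern and greedy_column_coloring are hypothetical names, and greedy first-fit coloring is one common choice, not necessarily the one jaxamg uses.

```python
import numpy as np


def probe_pattern(op, n: int) -> np.ndarray:
    """Boolean sparsity pattern of an n x n linear operator via one-hot probes.

    Column j of the pattern is the nonzero set of op(e_j); this costs n
    operator applications. Entries whose computed value is exactly 0.0 are
    treated as structural zeros.
    """
    pattern = np.zeros((n, n), dtype=bool)
    for j in range(n):
        e_j = np.zeros(n)
        e_j[j] = 1.0
        pattern[:, j] = np.asarray(op(e_j)) != 0
    return pattern


def greedy_column_coloring(pattern: np.ndarray) -> np.ndarray:
    """First-fit coloring: columns that share a nonzero row get different colors."""
    n_cols = pattern.shape[1]
    colors = np.full(n_cols, -1, dtype=int)
    for j in range(n_cols):
        rows = pattern[:, j]
        # Columns with a nonzero in at least one of column j's rows.
        neighbors = np.flatnonzero(pattern[rows].any(axis=0))
        used = {int(colors[k]) for k in neighbors if colors[k] >= 0}
        c = 0
        while c in used:
            c += 1
        colors[j] = c
    return colors
```

For a tridiagonal operator this needs n probes but yields only three colors, which is what makes the materialization step cheap.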

Parameters:

Name Type Description Default
operator Any

A callable operator A(x) that returns A @ x.

required
shape tuple[int, int] | int

Shape (n, m) of the operator, or an int n for a square n×n operator. For a distributed operator, this is the shape of the local block, (n_local, n_global).

required

Returns:

Type Description
tuple[ndarray, ndarray, ndarray, int, tuple[int, int]]

Cached coloring information for reattachment with with_cache(..., coloring=...).
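The benefit of a column coloring is that the matrix can be materialized with one operator call per color instead of one per column. A minimal NumPy sketch of that step follows. materialize_from_coloring is a hypothetical name, and since this page does not document the fields of the returned tuple, the sketch works from a plain boolean pattern and a color array instead.

```python
import numpy as np


def materialize_from_coloring(op, pattern: np.ndarray, colors: np.ndarray) -> np.ndarray:
    """Recover a dense matrix from a linear operator with one call per color.

    Columns of the same color share no nonzero row, so applying op to the sum
    of their basis vectors lets each output entry be attributed to exactly one
    column.
    """
    n_rows, n_cols = pattern.shape
    A = np.zeros((n_rows, n_cols))
    for c in range(int(colors.max()) + 1):
        seed = (colors == c).astype(float)  # sum of e_j over columns of color c
        y = np.asarray(op(seed))            # sum of those columns of A
        for j in np.flatnonzero(colors == c):
            rows = pattern[:, j]
            A[rows, j] = y[rows]
    return A
```

For the 6×6 tridiagonal example with colors [0, 1, 2, 0, 1, 2], this recovers the full matrix with three operator calls rather than six.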

jaxamg.cache_mpi_metadata(config, comm, nglobal, partition_info, A)

Pre-compute and cache MPI metadata for JIT-compatible solver usage.

The cached metadata can be reused across multiple JIT-compiled function calls with different matrices or operators, as long as they share the same structure.

Note

This function performs all non-traceable MPI operations outside the JIT boundary:

  • Computes static MPI communication metadata (recvcounts, displs)
  • Prepares MPI communicator pointer and local rank
  • Prepares config string
  • Computes max nnz across all ranks

Parameters:

Name Type Description Default
config dict

AmgX configuration dict or string.

required
comm Comm

MPI communicator (typically mpi4py.MPI.COMM_WORLD).

required
nglobal int

Global matrix size (total number of rows across all ranks).

required
partition_info tuple[int, int]

Row range (row_start, row_end) owned by this rank.

required
A MatrixOrOperator

Matrix or operator used to compute the maximum nnz for buffer sizing.

required

Returns:

Type Description
dict[str, Any]

A dictionary containing MPI metadata.

Note

The returned dictionary includes the following keys:

  • recvcounts_tuple: Tuple of row counts per rank
  • displs_tuple: Tuple of displacement offsets
  • comm_ptr: MPI communicator pointer
  • lrank: Local GPU rank
  • nglobal: Global matrix size
  • config_str: Prepared configuration string
  • max_nnz: Maximum nnz across all ranks
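The recvcounts_tuple and displs_tuple entries follow the layout MPI uses for variable-size gathers such as Allgatherv: one count per rank, and displacements given by the exclusive prefix sum of those counts. The hypothetical helper below sketches that relationship from per-rank partitions; it is an assumption that jaxamg derives them exactly this way, and it assumes half-open, contiguous (row_start, row_end) ranges in rank order.

```python
def allgatherv_layout(partitions):
    """Row counts and displacements for gathering row blocks from every rank.

    partitions: list of (row_start, row_end) pairs, one per rank, in rank order.
    Returns (recvcounts, displs) as tuples; displs is the exclusive prefix sum.
    """
    recvcounts = tuple(end - start for start, end in partitions)
    displs = []
    offset = 0
    for count in recvcounts:
        displs.append(offset)  # block for this rank starts after all earlier blocks
        offset += count
    return recvcounts, tuple(displs)
```

In an mpi4py program, each rank's partition could be collected with comm.allgather(partition_info) before calling such a helper.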

Preconditioner

jaxamg.make_preconditioner(A, config=None, *, comm=None, nglobal=None, partition_info=None, save_stats_file=None, return_info=False, **kwargs)

Create a callable approximate inverse for external Krylov solvers.

The returned callable can be passed directly as the M argument to jax.scipy.sparse.linalg.cg(...) or jax.scipy.sparse.linalg.bicgstab(...).

Parameters:

Name Type Description Default
A MatrixOrOperator

Matrix or callable operator to precondition.

required
config dict[str, Any] | None

Optional AmgX configuration. If omitted, an AMG-only approximate-inverse config is used.

None
comm Comm | None

Optional MPI communicator for distributed solves. If A already has MPI metadata attached via jaxamg.with_cache(..., mpi=...), this may be omitted.

None
nglobal int | None

Global matrix row count for MPI mode.

None
partition_info tuple[int, int] | None

Local row partition (row_start, row_end) for MPI mode.

None
save_stats_file str | None

Optional stats output path passed to jaxamg.solve(...).

None
return_info bool

If True, the returned callable yields (x, info) instead of only x.

False
**kwargs Any

Additional solver config overrides.

{}

Returns:

Type Description
Callable

A callable that applies the approximate inverse M^{-1} (an approximation of A^{-1}) to a vector.
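The returned callable plugs into the M argument of jax.scipy.sparse.linalg.cg. To show that contract without AmgX, the sketch below uses a Jacobi (diagonal) preconditioner as a stand-in; this is not what make_preconditioner builds, which by default uses an AMG-only approximate-inverse config. make_jacobi_preconditioner is a hypothetical helper.

```python
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg

n = 50
# 1D Poisson matrix: symmetric positive definite and tridiagonal.
A = 2.0 * jnp.eye(n) - jnp.eye(n, k=1) - jnp.eye(n, k=-1)
b = jnp.ones(n)


def make_jacobi_preconditioner(A):
    """Stand-in approximate inverse: apply diag(A)^{-1} to a residual vector."""
    inv_diag = 1.0 / jnp.diag(A)

    def apply_M(r):
        return inv_diag * r

    return apply_M


M = make_jacobi_preconditioner(A)
x, _ = cg(A, b, M=M, tol=1e-6, maxiter=500)
rel_residual = jnp.linalg.norm(A @ x - b) / jnp.linalg.norm(b)
```

With JAX-AMG, M = jaxamg.make_preconditioner(A) would take the place of make_jacobi_preconditioner(A) (with A given as a matrix or operator that JAX-AMG accepts); the cg call itself stays the same.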

Runtime Utilities

jaxamg.clear_solver_cache()

Clear the internal C++ AmgX solver cache. This releases all cached AmgX resources (matrices, solvers, vectors).

jaxamg.finalize()

Manually finalize AmgX resources. This clears the cache and calls AMGX_finalize. Calling it manually is normally only needed in MPI mode, to avoid resource warnings at shutdown.