API Reference
This page documents the public API of JAX-AMG.
Solver
jaxamg.solve(A, b, config=None, comm=None, nglobal=None, partition_info=None, save_stats_file=None, **kwargs)
Solve Ax=b using the AmgX backend. See Examples for usage.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| A | MatrixOrOperator | Matrix or callable operator A(x). All matrices/operators are converted to | required |
| b | ArrayLike | Right-hand-side vector. In MPI mode this is the local RHS partition. | required |
| config | dict \| None | AmgX configuration dictionary (see Solver Configuration for details). If | None |
| comm | Comm \| None | MPI communicator (typically | None |
| nglobal | int \| None | Global matrix row count for MPI mode. Required when | None |
| partition_info | tuple[int, int] \| None |  | None |
| save_stats_file | str \| PathLike \| None | Optional file path to save detailed AmgX solver statistics. If None, no file is created. | None |
| **kwargs | Any | Additional AmgX config parameters. These override values in | {} |
Returns:
| Name | Type | Description |
|---|---|---|
| x | Array | Solution vector (float32 or float64). In MPI mode, returns local portion. |
| info | dict | Dictionary containing |
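The sketch below shows one way to validate the x returned by jaxamg.solve on a single process. It checks the relative residual against a matrix-free operator. The operator laplacian_1d and the helper relative_residual are hypothetical. A dense NumPy solve stands in for the AmgX call, so the snippet runs without a GPU.

```python
import numpy as np


def laplacian_1d(x):
    # Hypothetical matrix-free operator: tridiagonal [-1, 2, -1] (1D Poisson
    # with Dirichlet boundaries). It avoids in-place updates, so swapping
    # np for jax.numpy keeps the same expressions valid.
    zero = np.zeros(1, dtype=x.dtype)
    left = np.concatenate([zero, x[:-1]])   # x[i-1], with 0 in the first row
    right = np.concatenate([x[1:], zero])   # x[i+1], with 0 in the last row
    return 2.0 * x - left - right


def relative_residual(A, b, x):
    # ||b - A(x)|| / ||b||: a backend-independent sanity check on a solution.
    return float(np.linalg.norm(b - A(x)) / np.linalg.norm(b))


# Stand-in for the solver: densify the operator one column at a time
# (A @ e_j) and solve with NumPy, then check the residual.
n = 8
dense = np.column_stack([laplacian_1d(e) for e in np.eye(n)])
b = np.ones(n)
x = np.linalg.solve(dense, b)
print(relative_residual(laplacian_1d, b, x) < 1e-12)  # True
```

With the library installed, the stand-in solve would become `x, info = jaxamg.solve(laplacian_1d, b)`. This assumes the two documented return values unpack as a tuple and that the operator is written with jax.numpy.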
Status Codes
jaxamg.AMGXStatus
Bases: IntEnum
High-level AmgX solve status codes returned in info["status"] after calling jaxamg.solve.
These values are mapped from the native backend status for quick checks in Python code and in docs.
Members
- SUCCESS: Solve converged successfully.
- FAILED: Solver failed due to an internal/runtime error.
- DIVERGED: Iterations diverged.
- NOT_CONVERGED: Reached stopping criteria without convergence.
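Calling code usually needs a guard on info["status"]. The sketch below is minimal and assumes the status is a concrete value, so the check runs outside jit. ensure_converged is a hypothetical helper. DemoStatus is a stand-in with the same member names as jaxamg.AMGXStatus, but its numeric values are arbitrary, so the snippet runs without the library.

```python
from enum import IntEnum


class DemoStatus(IntEnum):
    # Stand-in with the same member names as jaxamg.AMGXStatus.
    # The numeric values are arbitrary, NOT the library's values.
    SUCCESS = 0
    FAILED = 1
    DIVERGED = 2
    NOT_CONVERGED = 3


def ensure_converged(info, status_enum):
    # Normalize the raw status (a plain int or an IntEnum member) into the
    # given enum, then raise unless the solve reported SUCCESS.
    status = status_enum(int(info["status"]))
    if status != status_enum.SUCCESS:
        raise RuntimeError(f"AmgX solve ended with status {status.name}")
    return status


print(ensure_converged({"status": DemoStatus.SUCCESS}, DemoStatus).name)  # SUCCESS
```

With the library installed, the call would be `ensure_converged(info, jaxamg.AMGXStatus)`.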
Caching
jaxamg.with_cache(A, *, coloring=None, mpi=None, is_symmetric=False)
Attach cached metadata (coloring, MPI info, or symmetry) to a matrix or operator.
This cache allows using matrices/operators inside JIT-compiled functions without recomputing metadata or passing it as separate arguments. See Caching Guide for more details.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| A | MatrixOrOperator | A matrix or operator. | required |
| coloring | tuple[ndarray, ndarray, ndarray, int, tuple[int, int]] \| None | Cached coloring information from | None |
| mpi | dict[str, Any] \| None | Cached MPI metadata from | None |
| is_symmetric | bool | If True, indicates the matrix is symmetric, allowing optimizations such as skipping the transpose in the backward pass. | False |
Returns:
| Type | Description |
|---|---|
| MatrixOrOperator | The same matrix/operator with the requested cache attached. |
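The is_symmetric flag relates to the standard reverse-mode rule for a linear solve. The identity below is the textbook adjoint, shown for orientation only. It does not describe how jaxamg implements its backward pass. For x = A^{-1} b and an incoming cotangent \bar{x}:

```latex
\begin{aligned}
x &= A^{-1} b, \qquad \bar{b} = A^{-\mathsf{T}} \bar{x}, \qquad \bar{A} = -\,\bar{b}\, x^{\mathsf{T}}, \\
A &= A^{\mathsf{T}} \;\Longrightarrow\; \bar{b} = A^{-1} \bar{x}.
\end{aligned}
```

When A is symmetric, the cotangent for b is one more solve with the same operator, so no transposed operator has to be formed or set up.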
jaxamg.cache_coloring(operator, shape)
Compute and cache coloring information for a callable operator.
Detection uses two methods, so the result is correct for any operator:
- Tracing: interpret the operator's jaxpr to recover the exact sparsity in a single trace (no probing), then color and materialize it. Works for any JAX-expressed operator; skipped for operators that cannot be traced structurally (opaque calls, data-dependent indexing).
- Probing (probe_sparsity_pattern + get_column_coloring): exhaustive one-hot basis-vector probing, correct for any operator; this is the fallback when tracing is unavailable.
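The probing fallback can be illustrated with the classic compressed-column scheme (Curtis-Powell-Reid style). The sketch below is in NumPy. First, one-hot probes recover the sparsity pattern. Columns that never share a nonzero row then receive the same color. Finally, one probe per color recovers every entry. The function names are hypothetical and are not the library's probe_sparsity_pattern / get_column_coloring.

```python
import numpy as np


def probe_pattern(op, n):
    # Exhaustive probing: op(e_j) is column j of the operator; its
    # nonzeros form column j of the boolean sparsity mask.
    return np.column_stack([op(e) for e in np.eye(n)]) != 0


def greedy_column_coloring(mask):
    # Columns that share a nonzero row conflict and need different colors;
    # give each column the smallest color not used by its conflicts.
    colors = np.full(mask.shape[1], -1, dtype=int)
    for j in range(mask.shape[1]):
        conflicts = np.any(mask[mask[:, j]], axis=0)  # columns sharing a row with j
        used = set(colors[conflicts][colors[conflicts] >= 0].tolist())
        c = 0
        while c in used:
            c += 1
        colors[j] = c
    return colors


def compressed_recover(op, mask, colors):
    # One probe per color: apply op to the sum of that color's one-hot
    # vectors. In each row, at most one column of a color is nonzero, so
    # the probe value in that row belongs to that column alone.
    dense = np.zeros(mask.shape)
    for c in range(colors.max() + 1):
        members = colors == c
        y = op(members.astype(float))
        for j in np.flatnonzero(members):
            rows = mask[:, j]
            dense[rows, j] = y[rows]
    return dense


def tridiag(x):
    # Hypothetical test operator: tridiagonal [-1, 2, -1].
    zero = np.zeros(1)
    return 2.0 * x - np.concatenate([zero, x[:-1]]) - np.concatenate([x[1:], zero])


n = 6
mask = probe_pattern(tridiag, n)
colors = greedy_column_coloring(mask)
recovered = compressed_recover(tridiag, mask, colors)
print(colors.max() + 1)  # 3
```

For a banded operator, the number of probes depends on the bandwidth rather than on n. The coloring is what makes caching worthwhile.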
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| operator | Any | A callable operator A(x) that returns | required |
| shape | tuple[int, int] \| int | Shape of the operator (n, m) or int size (for an n×n matrix). For a distributed operator this is the local block | required |
Returns:
| Type | Description |
|---|---|
| tuple[ndarray, ndarray, ndarray, int, tuple[int, int]] | Cached coloring information for reattachment with |
jaxamg.cache_mpi_metadata(config, comm, nglobal, partition_info, A)
Pre-compute and cache MPI metadata for JIT-compatible solver usage.
The cached metadata can be reused across multiple JIT-compiled function calls with different matrices or operators, provided they share the same structure.
Note
This function performs all non-traceable MPI operations outside the JIT boundary:
- Computes static MPI communication metadata (recvcounts, displs)
- Prepares MPI communicator pointer and local rank
- Prepares config string
- Computes max nnz across all ranks
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | dict | AmgX configuration dict or string | required |
| comm | Comm | MPI communicator (from mpi4py.MPI.COMM_WORLD) | required |
| nglobal | int | Global matrix size (total rows across all ranks) | required |
| partition_info | tuple[int, int] | Tuple (row_start, row_end) indicating which rows this rank owns | required |
| A | MatrixOrOperator | Matrix or operator used to compute max nnz for buffer sizing | required |
Returns:
| Type | Description |
|---|---|
| dict[str, Any] | A dictionary containing MPI metadata. |
Note
The returned dictionary includes the following keys:
- recvcounts_tuple: Tuple of row counts per rank
- displs_tuple: Tuple of displacement offsets
- comm_ptr: MPI communicator pointer
- lrank: Local GPU rank
- nglobal: Global matrix size
- config_str: Prepared configuration string
- max_nnz: Maximum nnz across all ranks
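The recvcounts_tuple and displs_tuple entries match the usual MPI_Allgatherv layout. The pure-Python sketch below assumes contiguous row blocks and an exclusive row_end in partition_info; the exact convention is not stated above. block_partition and gather_layout are hypothetical helpers, not part of jaxamg.

```python
from itertools import accumulate


def block_partition(nglobal, size, rank):
    # Hypothetical contiguous row split: the first (nglobal % size) ranks
    # get one extra row. Returns (row_start, row_end), end exclusive.
    base, extra = divmod(nglobal, size)
    start = rank * base + min(rank, extra)
    end = start + base + (1 if rank < extra else 0)
    return start, end


def gather_layout(partitions):
    # Allgatherv layout: rows contributed by each rank, and the offset of
    # each rank's block in the packed global vector (exclusive prefix sum).
    recvcounts = tuple(end - start for start, end in partitions)
    displs = tuple(accumulate(recvcounts, initial=0))[:-1]
    return recvcounts, displs


parts = [block_partition(10, 3, r) for r in range(3)]
print(parts)                 # [(0, 4), (4, 7), (7, 10)]
print(gather_layout(parts))  # ((4, 3, 3), (0, 4, 7))
```

On each rank, the matching block_partition(nglobal, comm.Get_size(), comm.Get_rank()) result would be the partition_info argument, under the same exclusive-end assumption.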
Preconditioner
jaxamg.make_preconditioner(A, config=None, *, comm=None, nglobal=None, partition_info=None, save_stats_file=None, return_info=False, **kwargs)
Create a callable approximate inverse for external Krylov solvers.
The returned callable can be passed directly as the M argument to jax.scipy.sparse.linalg.cg(...) or jax.scipy.sparse.linalg.bicgstab(...).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| A | MatrixOrOperator | Matrix or callable operator to precondition. | required |
| config | dict[str, Any] \| None | Optional AmgX configuration. If omitted, an AMG-only approximate-inverse config is used. | None |
| comm | Comm \| None | Optional MPI communicator for distributed solves. If | None |
| nglobal | int \| None | Global matrix row count for MPI mode. | None |
| partition_info | tuple[int, int] \| None | Local row partition | None |
| save_stats_file | str \| None | Optional stats output path passed to | None |
| return_info | bool | If | False |
| **kwargs | Any | Additional solver config overrides. | {} |
Returns:
| Type | Description |
|---|---|
| Callable | A callable representing an approximate inverse |
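A preconditioner callable must satisfy M(r) ≈ A^{-1} r. The NumPy sketch of preconditioned CG below shows where such a callable is applied: once per iteration, to the residual. A Jacobi preconditioner stands in for the AMG-based approximate inverse, so the example runs without AmgX. pcg, A, and jacobi are hypothetical names.

```python
import numpy as np


def pcg(A, b, M, tol=1e-10, maxiter=200):
    # Preconditioned conjugate gradient for an SPD operator A.
    # A and M are callables; M(r) should approximate A^{-1} r.
    x = np.zeros_like(b)
    r = b - A(x)
    z = M(r)
    p = z.copy()
    rz = r @ z
    for _ in range(maxiter):
        Ap = A(p)
        alpha = rz / (p @ Ap)
        x = x + alpha * p
        r = r - alpha * Ap
        if np.linalg.norm(r) <= tol * np.linalg.norm(b):
            break
        z = M(r)                      # the preconditioner is applied here
        rz_new = r @ z
        p = z + (rz_new / rz) * p
        rz = rz_new
    return x


n = 50
d = np.linspace(2.0, 100.0, n)


def A(v):
    # SPD test operator: varying diagonal plus a weak -0.5 tridiagonal coupling.
    zero = np.zeros(1)
    return d * v - 0.5 * (np.concatenate([zero, v[:-1]]) + np.concatenate([v[1:], zero]))


def jacobi(r):
    # Stand-in approximate inverse: M(r) = D^{-1} r.
    return r / d


b = np.ones(n)
x = pcg(A, b, jacobi)
print(np.linalg.norm(b - A(x)) <= 1e-8)  # True
```

With jaxamg, the same role would be filled by `M = jaxamg.make_preconditioner(A)` followed by `x, _ = jax.scipy.sparse.linalg.cg(A, b, M=M)`, with A written in jax.numpy.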
Runtime Utilities
jaxamg.clear_solver_cache()
Clear the internal C++ AmgX solver cache. This releases all cached AmgX resources (matrices, solvers, vectors).
jaxamg.finalize()
Manually finalize AmgX resources. This clears the cache and calls AMGX_finalize. Normally, it only needs to be called manually in MPI mode, to avoid resource warnings at shutdown.