jaxmod package
Submodules
jaxmod.constants module
Physical, chemical, and mathematical constants for scientific modelling.
- jaxmod.constants.AVOGADRO: float = 6.02214076e+23
Avogadro constant in \(\mathrm{mol}^{-1}\)
- jaxmod.constants.GAS_CONSTANT: float = 8.31446261815324
Gas constant in \(\mathrm{J}\ \mathrm{K}^{-1}\ \mathrm{mol}^{-1}\)
- jaxmod.constants.GAS_CONSTANT_BAR: float = 8.31446261815324e-05
Gas constant in \(\mathrm{m}^3\ \mathrm{bar}^{-1}\ \mathrm{K}^{-1}\ \mathrm{mol}^{-1}\)
- jaxmod.constants.GRAVITATIONAL_CONSTANT: float = 6.6743e-11
Gravitational constant in \(\mathrm{m}^3\ \mathrm{kg}^{-1}\ \mathrm{s}^{-2}\)
- jaxmod.constants.ATMOSPHERE: float = 1.01325
Bar in 1 standard atmosphere (1 atm = 1.01325 bar)
- jaxmod.constants.BOLTZMANN_CONSTANT: float = 1.380649e-23
Boltzmann constant in \(\mathrm{J}\ \mathrm{K}^{-1}\)
- jaxmod.constants.BOLTZMANN_CONSTANT_BAR: float = 1.3806490000000002e-28
Boltzmann constant in \(\mathrm{bar}\ \mathrm{m}^3\ \mathrm{K}^{-1}\)
- jaxmod.constants.EARTH_MASS: float = 5.9722e+24
Mass of Earth in kg
- jaxmod.constants.OCEAN_MOLES: float = 7.68894973907177e+22
Moles of \(\mathrm{H}_2\) or \(\mathrm{H}_2\mathrm{O}\) in present-day Earth’s ocean
- jaxmod.constants.OCEAN_MASS_H2: float = 1.5500015377899477e+20
Mass of \(\mathrm{H}_2\) in one present-day Earth ocean in kg
- jaxmod.constants.OCEAN_MASS_H2O: float = 1.3851863627795307e+21
Mass of \(\mathrm{H}_2\mathrm{O}\) in one present-day Earth ocean in kg
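The constants above are linked by a few identities, which makes a quick sanity check possible. The sketch below uses only the standard library. It verifies R = k_B N_A, converts R to m^3 bar K^-1 mol^-1, applies it in the ideal gas law, and relates the ocean masses to OCEAN_MOLES. The molar masses of H2 and H2O are approximate values assumed here; they are not constants exported by jaxmod.constants.

```python
# Values copied from jaxmod.constants
AVOGADRO = 6.02214076e23  # mol^-1
BOLTZMANN_CONSTANT = 1.380649e-23  # J K^-1
GAS_CONSTANT = 8.31446261815324  # J K^-1 mol^-1
OCEAN_MOLES = 7.68894973907177e22  # mol
OCEAN_MASS_H2 = 1.5500015377899477e20  # kg
OCEAN_MASS_H2O = 1.3851863627795307e21  # kg

# Assumed approximate molar masses in kg mol^-1 (not part of jaxmod.constants)
MOLAR_MASS_H2 = 2.016e-3
MOLAR_MASS_H2O = 18.015e-3

# R = k_B * N_A holds exactly with the 2019 SI defining constants
r_from_k = BOLTZMANN_CONSTANT * AVOGADRO

# 1 J = 1 Pa m^3 = 1e-5 bar m^3, so R in m^3 bar K^-1 mol^-1 is R * 1e-5
gas_constant_bar = GAS_CONSTANT * 1e-5

# Ideal gas: moles in 1 m^3 at 1 bar and 298.15 K, n = PV / (RT)
n_ideal = (1.0 * 1.0) / (gas_constant_bar * 298.15)

# Ocean masses should equal OCEAN_MOLES times the molar mass (to about 1e-4 relative)
mass_h2 = OCEAN_MOLES * MOLAR_MASS_H2
mass_h2o = OCEAN_MOLES * MOLAR_MASS_H2O

print(f"R from k_B * N_A: {r_from_k}")
print(f"n in 1 m^3 at 1 bar and 298.15 K: {n_ideal:.2f} mol")
```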
jaxmod.solvers module
Solvers
- jaxmod.solvers.POSTCHECK_TOLERANCE: float = 1e-06
Default tolerance for the objective-based convergence validation performed after each solve attempt
- class jaxmod.solvers.RootFindParameters(solver: type[AbstractRootFinder | AbstractLeastSquaresSolver | AbstractMinimiser] = <class 'optimistix._solver.newton_chord.Newton'>, atol: float = 1e-06, rtol: float = 1e-06, linear_solver: AbstractLinearSolver = AutoLinearSolver(well_posed=None), norm: Callable = <function max_norm>, throw: bool = False, max_steps: int = 256, jac: Literal['fwd', 'bwd']='fwd')
Bases: Module
Parameters for Optimistix root finding
- Parameters:
solver – Solver. Defaults to optimistix.Newton.
atol – Absolute tolerance. Defaults to 1.0e-6.
rtol – Relative tolerance. Defaults to 1.0e-6.
linear_solver – Linear solver. Defaults to AutoLinearSolver(well_posed=None).
norm – Norm. Defaults to optimistix.max_norm().
throw – How to report any failures: if True, a failure raises an error; if False, it is reported in the result field of the returned solution. Defaults to False.
max_steps – The maximum number of steps the solver can take. Defaults to 256.
jac – Whether to use forward- or reverse-mode autodifferentiation to compute the Jacobian. Can be either fwd or bwd. Defaults to fwd.
- atol: float = 1e-06
Absolute tolerance
- rtol: float = 1e-06
Relative tolerance
- linear_solver: AbstractLinearSolver = AutoLinearSolver(well_posed=None)
Linear solver (see https://docs.kidger.site/lineax/api/solvers/)
- norm() Shaped[Array, '']
Norm
- throw: bool = False
How to report any failures
- max_steps: int = 256
Maximum number of steps the solver can take
- jac: Literal['fwd', 'bwd'] = 'fwd'
Whether to use forward- or reverse-mode autodifferentiation to compute the Jacobian
- get_solver_instance() AbstractRootFinder | AbstractLeastSquaresSolver | AbstractMinimiser
Instantiates the solver
- class jaxmod.solvers.MultiAttemptSolution(solution: Solution, _attempts: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray = 0)
Bases: Module
A solution wrapper for handling multiple solver attempts per problem
This class standardises solver outputs from multi-attempt strategies. Some attributes (e.g. converged, solver_success, num_steps) are broadcast to the batch dimension to ensure consistent shapes across all outputs, whether the underlying solver returns scalar or per-attempt values.
- Parameters:
solution – Optimistix solution
_attempts – Number of attempts required for each batch element to converge (0 indicates no successful attempt). Defaults to 0.
- property attempts: Integer[Array, 'batch']
Number of attempts required for each batch element to converge (0 indicates no successful attempt)
- property aux
- property batch_shape: tuple[int, ...]
Batch shape (all dimensions except the trailing solution dimension)
- property converged: Bool[Array, 'batch']
Boolean mask indicating objective-based convergence
- property num_steps: Integer[Array, 'batch']
Number of steps
- property value: Float[Array, 'batch solution']
- property solver_success: Bool[Array, 'batch']
Whether the underlying solver claims success
- property state: Any
- property stats: dict[str, PyTree[jax.Array | numpy.ndarray | numpy.bool | numpy.number | bool | int | float | complex | jax._src.literals.TypedNdArray]]
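The broadcasting described above can be pictured with jax.numpy.broadcast_to. In this sketch, broadcast_to_batch is a hypothetical helper (not part of jaxmod) that lifts either a scalar result, as a single solver call might return, or an already per-element result to the batch shape. Here batch_shape is every dimension of value except the trailing solution dimension, matching the batch_shape property.

```python
import jax.numpy as jnp

def broadcast_to_batch(x, batch_shape):
    """Hypothetical helper: lift a scalar or per-element value to batch_shape."""
    return jnp.broadcast_to(jnp.asarray(x), batch_shape)

value = jnp.zeros((4, 3))        # 4 batch elements, 3 solution components each
batch_shape = value.shape[:-1]   # every dimension except the trailing one: (4,)

num_steps = broadcast_to_batch(7, batch_shape)  # scalar, e.g. from one solver call
converged = broadcast_to_batch(jnp.array([True, False, True, True]), batch_shape)
```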
- jaxmod.solvers.max_norm(objective_function: Callable, solution: Float[Array, '... solution'], parameters: PyTree) Float[Array, '...']
Computes the L-infinity norm of batched objective residuals.
Evaluates the objective function for each model in the batch and returns the maximum absolute residual across all components of each system. This is a vectorised variant of optimistix.max_norm(), producing one scalar L-infinity norm per system in the batch.
See: https://docs.kidger.site/optimistix/api/norms/
- Parameters:
objective_function – A callable taking solution and parameters that returns the objective residuals for each model in the batch
solution – Batched array of candidate solutions
parameters – Parameters passed to the objective function
- Returns:
An array of L-infinity norms, one per system in the batch
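A minimal sketch of the batched L-infinity norm, assuming objective_function returns residuals with shape (... n) for a batched solution. Reducing with jax.numpy.max over the last axis then gives one norm per system. The names batched_max_norm and objective are hypothetical.

```python
import jax.numpy as jnp

def batched_max_norm(objective_function, solution, parameters):
    """Hypothetical sketch: L-infinity norm of the residuals of each system."""
    residuals = objective_function(solution, parameters)  # shape (... n)
    return jnp.max(jnp.abs(residuals), axis=-1)           # shape (...)

def objective(solution, parameters):
    """Hypothetical residuals: solution**2 - parameters, componentwise."""
    return solution**2 - parameters

solution = jnp.array([[1.0, 2.0], [3.0, 0.5]])
parameters = jnp.array([[1.0, 5.0], [9.0, 1.0]])
norms = batched_max_norm(objective, solution, parameters)
```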
- jaxmod.solvers.expand_mask(mask: Bool[Array, '...'], target: Float[Array, '... solution']) Bool[Array, '... 1']
Expands a batch mask to broadcast over trailing solution dimensions.
- Parameters:
mask – Boolean array indicating entries to update
target – Array with shape (... solution) that the mask will be expanded to match
- Returns:
Boolean array with shape (... 1) that can be broadcast to the shape of target
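The trailing size-1 axis lets the mask drive jax.numpy.where over whole solution vectors. The sketch below assumes expand_mask is equivalent to indexing with mask[..., None]; expand_mask_sketch is a hypothetical stand-in.

```python
import jax.numpy as jnp

def expand_mask_sketch(mask, target):
    """Hypothetical stand-in: shape (...) -> (... 1), broadcastable to (... solution)."""
    assert mask.shape == target.shape[:-1], "mask must match the batch dimensions"
    return mask[..., None]

old = jnp.zeros((3, 2))                # current solutions, batch of 3
new = jnp.ones((3, 2))                 # candidate solutions, e.g. from a retry
mask = jnp.array([True, False, True])  # which batch entries to update

updated = jnp.where(expand_mask_sketch(mask, old), new, old)
```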
- jaxmod.solvers.make_batch_retry_solver(solver_function: Callable, objective_function: Callable) Callable
Makes a batch retry solver.
solver_function and objective_function must be pure JAX-callable functions compatible with equinox.filter_jit(). They must not close over non-JAX state or produce Python side effects.
- Parameters:
solver_function – Callable that performs a single solve and returns a Solution object. It must accept an initial guess and a pytree of parameters as arguments.
objective_function – Callable for the objective function
- Returns:
Callable
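The retry strategy itself is not documented here. The following eager NumPy sketch shows one plausible scheme under stated assumptions. Every attempt re-solves the whole batch from a shifted initial guess (the shift rule is hypothetical). Only entries that have not yet converged are updated, and convergence is judged by the max-norm of the objective, in the spirit of POSTCHECK_TOLERANCE. The attempts array follows the MultiAttemptSolution convention, where 0 means no successful attempt. The real function is compiled with equinox.filter_jit and will differ in detail; toy_solver, objective and batch_retry_solve are hypothetical names.

```python
import numpy as np

def toy_solver(guess, params):
    """Hypothetical single solve: five Newton steps on y**2 - params = 0."""
    y = guess.copy()
    with np.errstate(divide="ignore", invalid="ignore"):  # a zero guess yields inf/nan
        for _ in range(5):
            y = y - (y**2 - params) / (2.0 * y)
    return y

def objective(y, params):
    return y**2 - params

def batch_retry_solve(solver, objective, guess, params, max_attempts=3, tol=1e-6):
    solution = guess.copy()
    attempts = np.zeros(guess.shape[:-1], dtype=int)  # 0 = no successful attempt yet
    for attempt in range(1, max_attempts + 1):
        pending = attempts == 0
        if not pending.any():
            break
        # Hypothetical retry rule: shift the initial guess by (attempt - 1)
        trial = solver(guess + (attempt - 1), params)
        with np.errstate(invalid="ignore"):
            # Max-norm check per batch element; nan residuals compare as False
            ok = np.max(np.abs(objective(trial, params)), axis=-1) < tol
        update = pending & ok
        solution = np.where(update[..., None], trial, solution)
        attempts = np.where(update, attempt, attempts)
    return solution, attempts

guess = np.array([[1.0], [0.0], [1.0]])  # the middle entry fails from a zero guess
params = np.array([[2.0], [4.0], [9.0]])
solution, attempts = batch_retry_solve(toy_solver, objective, guess, params)
```

In this run, the first and last entries converge on attempt 1 and the middle entry only on attempt 2.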
jaxmod.type_aliases module
Common type aliases
- jaxmod.type_aliases.NpBool
NumPy numpy.bool_ array
- jaxmod.type_aliases.NpFloat
NumPy numpy.float64 array
- jaxmod.type_aliases.NpInt
NumPy numpy.int_ array
- jaxmod.type_aliases.Scalar: TypeAlias = int | float
Scalar
- jaxmod.type_aliases.OptxSolver: TypeAlias = optimistix._root_find.AbstractRootFinder | optimistix._least_squares.AbstractLeastSquaresSolver | optimistix._minimise.AbstractMinimiser
Optimistix solver
jaxmod.units module
Unit conversion factors for scientific calculations
- class jaxmod.units.UnitConversion(atmosphere_to_bar: float = 1.01325, bar_to_Pa: float = 100000.0, bar_to_MPa: float = 0.1, bar_to_GPa: float = 0.0001, Pa_to_bar: float = 1e-05, MPa_to_bar: float = 10.0, GPa_to_bar: float = 10000.0, fraction_to_ppm: float = 1000000.0, ppm_to_fraction: float = 1e-06, ppm_to_percent: float = 0.0001, percent_to_ppm: float = 10000.0, g_to_kg: float = 0.001, cm3_to_m3: float = 1e-06, m3_to_cm3: float = 1000000.0, litre_to_m3: float = 0.001, m3_bar_to_J: float = 100000.0, J_to_m3_bar: float = 1e-05)
Bases: Module
Unit conversions
- atmosphere_to_bar: float = 1.01325
- bar_to_Pa: float = 100000.0
- bar_to_MPa: float = 0.1
- bar_to_GPa: float = 0.0001
- Pa_to_bar: float = 1e-05
- MPa_to_bar: float = 10.0
- GPa_to_bar: float = 10000.0
- fraction_to_ppm: float = 1000000.0
- ppm_to_fraction: float = 1e-06
- ppm_to_percent: float = 0.0001
- percent_to_ppm: float = 10000.0
- g_to_kg: float = 0.001
- cm3_to_m3: float = 1e-06
- m3_to_cm3: float = 1000000.0
- litre_to_m3: float = 0.001
- m3_bar_to_J: float = 100000.0
- J_to_m3_bar: float = 1e-05
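A short usage sketch follows. To keep it runnable without jaxmod, UnitConversionSketch is a hypothetical frozen dataclass that copies a subset of the documented defaults. The same attribute names appear on jaxmod.units.UnitConversion. The last line checks that m3_bar_to_J is consistent with GAS_CONSTANT and GAS_CONSTANT_BAR.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class UnitConversionSketch:
    """Hypothetical stand-in with a subset of UnitConversion's default factors."""
    atmosphere_to_bar: float = 1.01325
    bar_to_Pa: float = 100000.0
    GPa_to_bar: float = 10000.0
    ppm_to_fraction: float = 1e-06
    m3_bar_to_J: float = 100000.0

units = UnitConversionSketch()

pressure_Pa = 2.0 * units.atmosphere_to_bar * units.bar_to_Pa  # 2 atm -> Pa
pressure_bar = 0.5 * units.GPa_to_bar                          # 0.5 GPa -> bar
mole_fraction = 350.0 * units.ppm_to_fraction                  # 350 ppm -> fraction

# R*T in m^3 bar mol^-1 (using GAS_CONSTANT_BAR), converted to J mol^-1
rt_J = 8.31446261815324e-05 * 300.0 * units.m3_bar_to_J
```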
jaxmod.utils module
Utils
- jaxmod.utils.as_j64(x: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | tuple) Float[Array, '...']
Converts input to a jax.Array of dtype jax.numpy.float64.
This ensures that the array has a fixed dtype, preventing JAX from recompiling functions due to type changes between calls.
- Parameters:
x – Input to convert
- Returns:
jax.Array of dtype jax.numpy.float64
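A minimal sketch of the conversion, assuming as_j64 is essentially jax.numpy.asarray with an explicit float64 dtype; as_j64_sketch is a hypothetical name. JAX only produces float64 arrays when 64-bit mode (the jax_enable_x64 flag) is enabled; otherwise it warns and falls back to float32.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # required for genuine float64 arrays

def as_j64_sketch(x):
    """Hypothetical equivalent of as_j64: array-like or tuple -> float64 jax.Array."""
    return jnp.asarray(x, dtype=jnp.float64)

a = as_j64_sketch(3)                                    # Python int
b = as_j64_sketch((1, 2.5))                             # tuple
c = as_j64_sketch(jnp.array([1, 2], dtype=jnp.int32))   # integer JAX array
```

Because every input ends up with the same dtype, a jitted function called first with as_j64_sketch(3) and then with as_j64_sketch(3.0) sees identical argument types.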
- jaxmod.utils.get_batch_axis(x: Any) Literal[0, None]
Determines the batch axis for a JAX array.
Determines whether an object should be treated as batched along axis 0 for jax.vmap().
This function only considers JAX arrays for batching. While equinox.is_array() regards both JAX and NumPy arrays as arrays for tracing, NumPy arrays are treated here as static constants and are never batched. This allows fixed matrices to remain inside pytrees without being inadvertently vectorised.
- Rules:
1-D JAX arrays: Batched along axis 0
2-D JAX arrays: Batched along axis 0 if shape[0] > 1
0-D (scalar) JAX arrays: Not batched
NumPy arrays or other objects: Not batched
- Parameters:
x – Object to check for batching
- Returns:
0 if batched along axis 0, otherwise None
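The rules above translate almost line for line into code. This sketch is hypothetical. In particular, the rules do not say how JAX arrays with more than two dimensions are treated, so returning None for them here is an assumption.

```python
from typing import Any, Literal

import jax
import jax.numpy as jnp
import numpy as np

def get_batch_axis_sketch(x: Any) -> Literal[0, None]:
    """Hypothetical restatement of the documented batching rules."""
    if not isinstance(x, jax.Array):
        return None  # NumPy arrays and other objects are static constants
    if x.ndim == 1:
        return 0  # 1-D JAX array: batched
    if x.ndim == 2 and x.shape[0] > 1:
        return 0  # 2-D JAX array with more than one row: batched
    return None  # 0-D arrays, 2-D arrays with one row, and (by assumption) ndim > 2

axes = [get_batch_axis_sketch(v) for v in (jnp.ones(3), jnp.ones((1, 3)), np.ones(3))]
```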
- jaxmod.utils.get_batch_size(x: PyTree) int
Determines the maximum batch size (i.e., length along axis 0) amongst all array-like leaves.
This inspects every leaf in the pytree and checks whether it is an array. Scalars contribute a size of 1, while arrays contribute the length of their leading dimension (shape[0]). The result is the largest such size found.
Note
Unlike get_batch_axis(), which only considers JAX arrays for batching, this function counts both JAX and NumPy arrays as array-like leaves when computing the maximum batch size.
- Parameters:
x – Pytree of nested containers that may include arrays or scalars
- Returns:
The maximum leading dimension size across all array-like leaves
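A sketch of the size computation using jax.tree_util.tree_leaves, which yields JAX arrays, NumPy arrays and Python scalars as leaves. Treating 0-D arrays like scalars is an assumption, since they have no shape[0]. The helper name is hypothetical.

```python
import jax
import jax.numpy as jnp
import numpy as np

def get_batch_size_sketch(x) -> int:
    """Hypothetical sketch: largest leading dimension among array leaves, at least 1."""
    size = 1  # scalars (and, by assumption, 0-D arrays) contribute 1
    for leaf in jax.tree_util.tree_leaves(x):
        if isinstance(leaf, (jax.Array, np.ndarray)) and leaf.ndim > 0:
            size = max(size, leaf.shape[0])  # JAX and NumPy arrays both count here
    return size

tree = {
    "a": jnp.ones(3),
    "b": np.zeros((5, 2)),
    "c": 2.0,
    "d": (jnp.array(1.0), [np.ones(4)]),
}
batch_size = get_batch_size_sketch(tree)
```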
- jaxmod.utils.is_hashable(x: Any) None
Checks whether an object is hashable and prints the result.
- Parameters:
x – Object to check
- jaxmod.utils.is_jax_array(element: Any) bool
Checks if element is a JAX array.
Note
NumPy arrays are not considered JAX arrays
- Parameters:
element – Object to check
- Returns:
True if element is a JAX array, otherwise False
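One way to express this check, assuming it reduces to an isinstance test against jax.Array. That test also matches tracers inside transformations such as jax.jit. The helper name is hypothetical.

```python
import jax
import jax.numpy as jnp
import numpy as np

def is_jax_array_sketch(element) -> bool:
    """Hypothetical equivalent of is_jax_array."""
    return isinstance(element, jax.Array)

checks = {
    "jax": is_jax_array_sketch(jnp.ones(2)),
    "numpy": is_jax_array_sketch(np.ones(2)),
    "float": is_jax_array_sketch(1.0),
}
# Inside jax.jit the argument is a tracer, which also passes the isinstance test
traced = jax.jit(lambda x: jnp.asarray(is_jax_array_sketch(x)))(1.0)
```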
- jaxmod.utils.partial_rref(matrix: ndarray[tuple[Any, ...], dtype[_ScalarT]]) ndarray[tuple[Any, ...], dtype[_ScalarT]]
Computes a partial reduced row echelon form (RREF) to determine linear components.
This function performs the computation using NumPy in-place operations and is therefore not compatible with JAX transformations. The returned matrix represents the linear components of the input, extracted from the augmented RREF procedure.
- Parameters:
matrix – A 2-D NumPy array of shape (nrows, ncols).
- Returns:
A numpy.ndarray containing the linear components.
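How partial_rref extracts the linear components is not spelled out above, so the sketch below is only a point of reference. It computes a textbook (full) reduced row echelon form in NumPy, pivoting on the largest entry in each column. It is not jaxmod's partial procedure, and it works on a float copy of its input.

```python
import numpy as np

def rref(matrix: np.ndarray, tol: float = 1e-12) -> np.ndarray:
    """Textbook reduced row echelon form (not jaxmod's partial variant)."""
    a = np.array(matrix, dtype=float)  # work on a float copy of the input
    nrows, ncols = a.shape
    pivot_row = 0
    for col in range(ncols):
        if pivot_row >= nrows:
            break
        # Pick the row with the largest magnitude in this column for stability
        best = pivot_row + int(np.argmax(np.abs(a[pivot_row:, col])))
        if abs(a[best, col]) < tol:
            continue  # no usable pivot in this column
        a[[pivot_row, best]] = a[[best, pivot_row]]  # swap rows
        a[pivot_row] /= a[pivot_row, col]  # scale the pivot to 1
        for r in range(nrows):
            if r != pivot_row:
                a[r] -= a[r, col] * a[pivot_row]  # clear the rest of the column
        pivot_row += 1
    return a

reduced = rref(np.array([[1, 2, 3], [2, 4, 6], [1, 0, 1]]))
```

The second input row is twice the first, so the result has rank 2 and ends with a zero row.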
- jaxmod.utils.power_law(values: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray, constant: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray, exponent: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array
Power law
- Parameters:
values – Values
constant – Constant for the power law
exponent – Exponent for the power law
- Returns:
Evaluated power law
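The docstring does not state the functional form. Assuming the conventional power law y = constant * values ** exponent, applied elementwise with broadcasting, a sketch (with a hypothetical name) would be:

```python
import jax.numpy as jnp

def power_law_sketch(values, constant, exponent):
    """Assumed form (not confirmed by the docs): constant * values ** exponent."""
    return jnp.asarray(constant) * jnp.power(jnp.asarray(values), exponent)

y = power_law_sketch(jnp.array([1.0, 2.0, 4.0]), 3.0, 0.5)
```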
- jaxmod.utils.safe_exp(x: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array
Computes the elementwise exponential of x with input clipping to prevent overflow.
This function clips the input x to a maximum value defined by MAX_EXP_INPUT before applying jax.numpy.exp(), ensuring numerical stability for large values.
- Parameters:
x – Array-like input
- Returns:
Array of the same shape as x, where each element is the exponential of the clipped input
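A sketch of the clipping, assuming safe_exp is jax.numpy.exp applied to jax.numpy.minimum(x, MAX_EXP_INPUT). The bound is recomputed here from sys.float_info so the example runs without jaxmod. 64-bit mode is enabled because in float32 exp already overflows near 88.7.

```python
import math
import sys

import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # work in float64, as MAX_EXP_INPUT assumes

MAX_EXP_INPUT = math.log(sys.float_info.max)  # about 709.78

def safe_exp_sketch(x):
    """Hypothetical equivalent of safe_exp: clip from above, then exponentiate."""
    return jnp.exp(jnp.minimum(jnp.asarray(x), MAX_EXP_INPUT))

x = jnp.array([0.0, 700.0, 1000.0])
naive = jnp.exp(x)         # the last entry overflows to inf
safe = safe_exp_sketch(x)  # the last entry is clipped to MAX_EXP_INPUT first
```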
- jaxmod.utils.to_hashable(x: Callable) Callable
Wraps a callable to make it hashable for JAX transformations.
This wrapper is useful when passing bound methods of Equinox PyTrees (with JAX arrays as attributes) to transformations like jax.jit(), jax.vmap(), or lax.scan(). It wraps the callable in a lambda that forwards all arguments, which prevents JAX from trying to trace the method itself. See discussion: https://github.com/patrick-kidger/equinox/issues/1011
- Parameters:
x – A callable to wrap
- Returns:
A hashable lambda forwarding all arguments to the original callable.
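The mechanism can be shown in plain Python. UnhashableCallable below is a hypothetical stand-in for an Equinox bound method whose hash would require hashing JAX array attributes. Defining __eq__ without __hash__ makes its instances unhashable, whereas a forwarding lambda is hashed by identity. to_hashable_sketch mirrors the documented behaviour but is not jaxmod's source.

```python
from typing import Any, Callable

class UnhashableCallable:
    """Hypothetical stand-in for a bound method of an Equinox module holding arrays."""

    def __init__(self, scale: float) -> None:
        self.scale = scale

    def __eq__(self, other: Any) -> bool:  # __eq__ without __hash__ sets __hash__ = None
        return isinstance(other, UnhashableCallable) and self.scale == other.scale

    def __call__(self, x: float) -> float:
        return self.scale * x

def to_hashable_sketch(x: Callable) -> Callable:
    """Wrap x in a lambda that forwards all positional and keyword arguments."""
    return lambda *args, **kwargs: x(*args, **kwargs)

method = UnhashableCallable(2.0)
wrapped = to_hashable_sketch(method)
```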
- jaxmod.utils.to_native_floats(value: Any) Any
Recursively converts any structure to nested tuples of native floats.
- Parameters:
value – A scalar, list/tuple/array of floats, or nested thereof
- Returns:
A float or nested tuple of floats
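A recursive sketch under stated assumptions. NumPy arrays are first turned into nested Python lists with tolist() (a 0-D array becomes a plain scalar), lists and tuples become tuples, and anything else goes through float(). In this version, JAX arrays would need numpy.asarray first. The helper name is hypothetical. Tuples of Python floats are hashable, which is one plausible reason for such a conversion (for example, to pass values as static arguments); that motivation is an assumption.

```python
import numpy as np

def to_native_floats_sketch(value):
    """Hypothetical sketch: nested sequences/arrays -> nested tuples of Python floats."""
    if isinstance(value, np.ndarray):
        value = value.tolist()  # nested lists, or a plain scalar for a 0-D array
    if isinstance(value, (list, tuple)):
        return tuple(to_native_floats_sketch(v) for v in value)
    return float(value)

result = to_native_floats_sketch([1, np.float32(2.5), np.array([[3, 4]])])
```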
- jaxmod.utils.vmap_axes_spec(x: PyTree) PyTree[Literal[0, None]]
Recursively generates in_axes for jax.vmap() over a pytree.
Only JAX arrays are considered for batching. NumPy arrays and other objects are treated as static constants (not batched).
- Parameters:
x – A pytree potentially containing JAX arrays, NumPy arrays, or scalars
- Returns:
A pytree with the same structure as x. Each leaf is 0 if batched, or None if not.
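A sketch built on jax.tree_util.tree_map. It assumes each leaf is classified with the same rules as get_batch_axis(), and it treats JAX arrays with more than two dimensions as not batched, which the rules leave open. The spec is then passed to jax.vmap() with one entry per positional argument. The function and parameter names are hypothetical.

```python
import jax
import jax.numpy as jnp
import numpy as np

def leaf_axis(leaf):
    """Per-leaf rule following get_batch_axis(): only JAX arrays can be batched."""
    if isinstance(leaf, jax.Array) and (
        leaf.ndim == 1 or (leaf.ndim == 2 and leaf.shape[0] > 1)
    ):
        return 0
    return None

def vmap_axes_spec_sketch(tree):
    return jax.tree_util.tree_map(leaf_axis, tree)

params = {
    "scale": jnp.array([1.0, 2.0, 3.0]),  # batched: one value per model
    "offset": np.array([10.0, 20.0]),     # NumPy constant: never batched
    "power": 2.0,                         # Python scalar: not batched
}

def model(p):
    return p["scale"] ** p["power"] + p["offset"].sum()

in_axes = vmap_axes_spec_sketch(params)
out = jax.vmap(model, in_axes=(in_axes,))(params)
```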
Module contents
Package level variables
- jaxmod.MAX_EXP_INPUT: float = np.float64(709.782712893384)
Maximum x for which exp(x) is finite in 64-bit precision; used to prevent overflow
- jaxmod.MIN_EXP_INPUT: float = np.float64(-708.3964185322641)
Minimum x for which exp(x) is a normal (non-subnormal) number in 64-bit precision; used to prevent underflow
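Both bounds can be reproduced from the float64 limits in the standard library. They are the logarithm of the largest finite double and the logarithm of the smallest positive normal double (sys.float_info.min). Whether jaxmod derives them this way is an assumption, but the resulting values agree with those listed above.

```python
import math
import sys

max_exp_input = math.log(sys.float_info.max)  # log of the largest finite double
min_exp_input = math.log(sys.float_info.min)  # log of the smallest positive normal double
print(max_exp_input, min_exp_input)
```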
- jaxmod.simple_formatter() Formatter
Simple formatter for logging
- Returns:
Formatter for logging
- jaxmod.debug_logger() Logger
Sets up debug logging to the console.
- Returns:
A logger
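A minimal standard-library sketch of what these two helpers might look like. The format string, the logger name and the guard against adding duplicate handlers are assumptions, not jaxmod's documented behaviour.

```python
import logging

def simple_formatter_sketch() -> logging.Formatter:
    # Hypothetical format string; jaxmod's actual format is not documented here
    return logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")

def debug_logger_sketch(name: str = "example") -> logging.Logger:
    logger = logging.getLogger(name)
    logger.setLevel(logging.DEBUG)
    if not logger.handlers:  # avoid duplicate handlers when called repeatedly
        handler = logging.StreamHandler()  # writes to stderr (the console)
        handler.setFormatter(simple_formatter_sketch())
        logger.addHandler(handler)
    return logger

log = debug_logger_sketch()
log.debug("solver started")
```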