jaxmod package

Submodules

jaxmod.constants module

Physical, chemical, and mathematical constants for scientific modelling.

jaxmod.constants.AVOGADRO: float = 6.02214076e+23

Avogadro constant in \(\mathrm{mol}^{-1}\)

jaxmod.constants.GAS_CONSTANT: float = 8.31446261815324

Gas constant in \(\mathrm{J}\ \mathrm{K}^{-1}\ \mathrm{mol}^{-1}\)

jaxmod.constants.GAS_CONSTANT_BAR: float = 8.31446261815324e-05

Gas constant in \(\mathrm{m}^3\ \mathrm{bar}\ \mathrm{K}^{-1}\ \mathrm{mol}^{-1}\)

jaxmod.constants.GRAVITATIONAL_CONSTANT: float = 6.6743e-11

Gravitational constant in \(\mathrm{m}^3\ \mathrm{kg}^{-1}\ \mathrm{s}^{-2}\)

jaxmod.constants.ATMOSPHERE: float = 1.01325

Bar in one standard atmosphere (1 atm = 1.01325 bar)

jaxmod.constants.BOLTZMANN_CONSTANT: float = 1.380649e-23

Boltzmann constant in \(\mathrm{J}\ \mathrm{K}^{-1}\)

jaxmod.constants.BOLTZMANN_CONSTANT_BAR: float = 1.3806490000000002e-28

Boltzmann constant in \(\mathrm{bar}\ \mathrm{m}^3\ \mathrm{K}^{-1}\)

jaxmod.constants.EARTH_MASS: float = 5.9722e+24

Mass of Earth in kg

jaxmod.constants.OCEAN_MOLES: float = 7.68894973907177e+22

Moles of \(\mathrm{H}_2\) or \(\mathrm{H}_2\mathrm{O}\) in present-day Earth’s ocean

jaxmod.constants.OCEAN_MASS_H2: float = 1.5500015377899477e+20

Mass of \(\mathrm{H}_2\) in one present-day Earth ocean in kg

jaxmod.constants.OCEAN_MASS_H2O: float = 1.3851863627795307e+21

Mass of \(\mathrm{H}_2\mathrm{O}\) in one present-day Earth ocean in kg
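As a consistency check on the values above, the sketch below copies the constants from this listing (it does not import jaxmod). It confirms that GAS_CONSTANT equals AVOGADRO × BOLTZMANN_CONSTANT, uses GAS_CONSTANT_BAR with ATMOSPHERE in the ideal gas law, and recovers approximate molar masses of H2 and H2O (about 2.016 and 18.015 g/mol) from the ocean constants. The reference molar masses are standard values assumed for illustration, not taken from jaxmod.

```python
# Constants copied from the listing above (not imported from jaxmod)
AVOGADRO = 6.02214076e23  # mol^-1
BOLTZMANN_CONSTANT = 1.380649e-23  # J K^-1
GAS_CONSTANT = 8.31446261815324  # J K^-1 mol^-1
GAS_CONSTANT_BAR = 8.31446261815324e-05  # m^3 bar K^-1 mol^-1
ATMOSPHERE = 1.01325  # bar in one standard atmosphere
OCEAN_MOLES = 7.68894973907177e22  # mol
OCEAN_MASS_H2 = 1.5500015377899477e20  # kg
OCEAN_MASS_H2O = 1.3851863627795307e21  # kg

# R = N_A * k_B (exact in the 2019 SI)
gas_constant_from_k = AVOGADRO * BOLTZMANN_CONSTANT

# Ideal-gas molar volume V = R T / P at 273.15 K and 1 atm, in m^3 mol^-1
molar_volume = GAS_CONSTANT_BAR * 273.15 / ATMOSPHERE

# Implied molar masses (kg mol^-1) from mass = moles * molar mass
molar_mass_h2 = OCEAN_MASS_H2 / OCEAN_MOLES
molar_mass_h2o = OCEAN_MASS_H2O / OCEAN_MOLES

print(f"R from N_A*k_B: {gas_constant_from_k}")
print(f"Molar volume at 273.15 K, 1 atm: {molar_volume:.6e} m^3/mol")
print(f"Implied M(H2): {molar_mass_h2:.6e} kg/mol, M(H2O): {molar_mass_h2o:.6e} kg/mol")
```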

jaxmod.solvers module

Solvers

jaxmod.solvers.POSTCHECK_TOLERANCE: float = 1e-06

Default tolerance for the objective-based convergence validation performed after each solve attempt

class jaxmod.solvers.RootFindParameters(solver: type[AbstractRootFinder | AbstractLeastSquaresSolver | AbstractMinimiser] = <class 'optimistix._solver.newton_chord.Newton'>, atol: float = 1e-06, rtol: float = 1e-06, linear_solver: AbstractLinearSolver = AutoLinearSolver(well_posed=None), norm: Callable = <function max_norm>, throw: bool = False, max_steps: int = 256, jac: Literal['fwd', 'bwd']='fwd')

Bases: Module

Parameters for Optimistix root finding

Parameters:
  • solver – Solver. Defaults to optimistix.Newton.

  • atol – Absolute tolerance. Defaults to 1.0e-6.

  • rtol – Relative tolerance. Defaults to 1.0e-6.

  • linear_solver – Linear solver. Defaults to lineax.AutoLinearSolver(well_posed=None).

  • norm – Norm. Defaults to optimistix.max_norm().

  • throw – How to report any failures. If True, a failure raises an error; if False, it is reported through the result field of the returned solution. Defaults to False.

  • max_steps – The maximum number of steps the solver can take. Defaults to 256.

  • jac – Whether to use forward- or reverse-mode autodifferentiation to compute the Jacobian. Can be either 'fwd' or 'bwd'. Defaults to 'fwd'.

solver

Solver

alias of Newton

atol: float = 1e-06

Absolute tolerance

rtol: float = 1e-06

Relative tolerance

linear_solver: AbstractLinearSolver = AutoLinearSolver(well_posed=None)

Linear solver (see https://docs.kidger.site/lineax/api/solvers/)

norm() Shaped[Array, '']

Norm

throw: bool = False

How to report any failures

max_steps: int = 256

Maximum number of steps the solver can take

jac: Literal['fwd', 'bwd'] = 'fwd'

Whether to use forward- or reverse-mode autodifferentiation to compute the Jacobian

get_solver_instance() AbstractRootFinder | AbstractLeastSquaresSolver | AbstractMinimiser

Instantiates the solver
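To illustrate what atol, rtol, max_steps and jac control, here is a minimal pure-JAX Newton iteration. It is a swapped-in sketch, not Optimistix's implementation. The names newton_root and residual are hypothetical, the fn(y, args) calling convention mirrors Optimistix, and the stopping rule (the step divided by atol + rtol·|y| falling below 1 in the max norm) is an assumption loosely modelled on the scaled-step test used by Newton-type solvers.

```python
import jax
import jax.numpy as jnp


def newton_root(fn, y0, args, *, atol=1e-6, rtol=1e-6, max_steps=256, jac="fwd"):
    """Hypothetical Newton solver illustrating the RootFindParameters fields."""
    # 'fwd' uses forward-mode, 'bwd' reverse-mode autodiff for the Jacobian
    if jac == "fwd":
        jacobian_fn = jax.jacfwd(fn)
    elif jac == "bwd":
        jacobian_fn = jax.jacrev(fn)
    else:
        raise ValueError("jac must be 'fwd' or 'bwd'")

    y = jnp.asarray(y0)
    for step in range(1, max_steps + 1):
        # Solve J(y) delta = f(y) and take a full Newton step
        delta = jnp.linalg.solve(jacobian_fn(y, args), fn(y, args))
        y = y - delta
        # Converged when every component of the step is small relative to
        # atol + rtol * |y|, i.e. the max norm of the scaled step is below 1
        scale = atol + rtol * jnp.abs(y)
        if jnp.max(jnp.abs(delta) / scale) < 1.0:
            return y, True, step
    # Out of steps: report failure instead of raising (as with throw=False)
    return y, False, max_steps


def residual(y, args):
    # The root of y**2 - args is sqrt(args)
    return y**2 - args


root, converged, num_steps = newton_root(residual, jnp.array([1.0]), 2.0)
print(root, converged, num_steps)
```

For a square system such as a root-finding problem, jax.jacfwd and jax.jacrev produce the same Jacobian; the choice of 'fwd' or 'bwd' only changes the computational cost.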

class jaxmod.solvers.MultiAttemptSolution(solution: Solution, _attempts: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray = 0)

Bases: Module

A solution wrapper for handling multiple solver attempts per problem

This class standardises solver outputs from multi-attempt strategies. Some attributes (e.g. converged, solver_success, num_steps) are broadcast to the batch dimension to ensure consistent shapes across all outputs, whether the underlying solver returns scalar or per-attempt values.

Parameters:
  • solution – Optimistix solution

  • _attempts – Number of attempts required for each batch element to converge (0 indicates no successful attempt). Defaults to 0.

solution: Solution
property attempts: Integer[Array, 'batch']
property aux
property batch_shape: tuple[int, ...]

Batch shape (all dimensions except the trailing solution dimension)

property converged: Bool[Array, 'batch']

Boolean mask indicating objective-based convergence

property num_steps: Integer[Array, 'batch']

Number of steps

property result: RESULTS
property value: Float[Array, 'batch solution']
property solver_success: Bool[Array, 'batch']

Whether the underlying solver claims success

property state: Any
property stats: dict[str, PyTree[jax.Array | numpy.ndarray | numpy.bool | numpy.number | bool | int | float | complex | jax._src.literals.TypedNdArray]]
asdict() dict[str, Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray]

Converts pertinent solution statistics to a dictionary

jaxmod.solvers.max_norm(objective_function: Callable, solution: Float[Array, '... solution'], parameters: PyTree) Float[Array, '...']

Computes the L-infinity norm of batched objective residuals.

Evaluates the objective function for each model in the batch and returns the maximum absolute residual across all components of each system. This is a vectorised variant of optimistix.max_norm(), producing one scalar L-infinity norm per system in the batch.

See: https://docs.kidger.site/optimistix/api/norms/

Parameters:
  • objective_function – A callable taking solution and parameters that returns the objective residuals for each model in the batch

  • solution – Batched array of candidate solutions

  • parameters – Parameters passed to the objective function

Returns:

An array of L-infinity norms, one per batch element
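A minimal sketch of the batched L-infinity norm, assuming the objective function already returns residuals with shape (... solution) for a batched solution, so the norm reduces to a maximum of absolute values over the trailing axis. The name batched_max_norm and the toy objective are illustrative, not the jaxmod source.

```python
import jax.numpy as jnp


def batched_max_norm(objective_function, solution, parameters):
    """Hypothetical sketch: one L-infinity norm per batch element."""
    residuals = objective_function(solution, parameters)  # shape (..., solution)
    return jnp.max(jnp.abs(residuals), axis=-1)  # shape (...)


def objective(solution, parameters):
    # Toy residual: distance of each component from a target value
    return solution - parameters["target"]


solution = jnp.array([[1.0, -3.0], [0.5, 0.25], [0.0, 0.0]])
norms = batched_max_norm(objective, solution, {"target": 0.0})
print(norms)  # expected values 3.0, 0.5 and 0.0, one per row
```

By contrast, optimistix.max_norm reduces over every element it is given and returns a single scalar, which is why a per-system variant is needed for batched convergence checks.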

jaxmod.solvers.expand_mask(mask: Bool[Array, '...'], target: Float[Array, '... solution']) Bool[Array, '... 1']

Expands a batch mask so that it broadcasts over the trailing solution dimension.

Parameters:
  • mask – Boolean array indicating entries to update

  • target – Array with shape (... solution) that the mask will be expanded to match

Returns:

Boolean array with shape (... 1) that can be broadcast to the shape of target
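Hypothetically, the expansion amounts to appending one trailing axis; a typical use is selecting whole rows per batch element with jnp.where. The shape check in this sketch is an assumption added for illustration.

```python
import jax.numpy as jnp


def expand_mask(mask, target):
    """Hypothetical sketch: add a trailing axis so mask broadcasts over target."""
    # Assumption: the mask shape equals all but the last axis of target
    if mask.shape != target.shape[:-1]:
        raise ValueError("mask shape must match target.shape[:-1]")
    return mask[..., None]  # shape (..., 1)


converged = jnp.array([True, False])
previous = jnp.zeros((2, 3))
candidate = jnp.ones((2, 3))

# Keep the candidate solution only where that batch element converged
merged = jnp.where(expand_mask(converged, candidate), candidate, previous)
print(merged)
```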

jaxmod.solvers.make_batch_retry_solver(solver_function: Callable, objective_function: Callable) Callable

Makes a batch retry solver.

solver_function and objective_function must be pure JAX-callable functions compatible with equinox.filter_jit(). They must not close over non-JAX state or produce Python side effects.

Parameters:
  • solver_function – Callable that performs a single solve and returns an Optimistix Solution object. It must accept an initial guess and a pytree of parameters as arguments.

  • objective_function – Callable for the objective function

Returns:

Callable

jaxmod.type_aliases module

Common type aliases

jaxmod.type_aliases.NpArray

NumPy array

alias of ndarray[tuple[Any, …], dtype[_ScalarT]]

jaxmod.type_aliases.NpBool

NumPy numpy.bool_ array

alias of ndarray[tuple[Any, …], dtype[bool]]

jaxmod.type_aliases.NpFloat

NumPy numpy.float64 array

alias of ndarray[tuple[Any, …], dtype[float64]]

jaxmod.type_aliases.NpInt

NumPy numpy.int_ array

alias of ndarray[tuple[Any, …], dtype[int64]]

jaxmod.type_aliases.Scalar: TypeAlias = int | float

Scalar

jaxmod.type_aliases.OptxSolver: TypeAlias = optimistix._root_find.AbstractRootFinder | optimistix._least_squares.AbstractLeastSquaresSolver | optimistix._minimise.AbstractMinimiser

Optimistix solver

jaxmod.units module

Unit conversion factors for scientific calculations

class jaxmod.units.UnitConversion(atmosphere_to_bar: float = 1.01325, bar_to_Pa: float = 100000.0, bar_to_MPa: float = 0.1, bar_to_GPa: float = 0.0001, Pa_to_bar: float = 1e-05, MPa_to_bar: float = 10.0, GPa_to_bar: float = 10000.0, fraction_to_ppm: float = 1000000.0, ppm_to_fraction: float = 1e-06, ppm_to_percent: float = 0.0001, percent_to_ppm: float = 10000.0, g_to_kg: float = 0.001, cm3_to_m3: float = 1e-06, m3_to_cm3: float = 1000000.0, litre_to_m3: float = 0.001, m3_bar_to_J: float = 100000.0, J_to_m3_bar: float = 1e-05)

Bases: Module

Unit conversions

atmosphere_to_bar: float = 1.01325
bar_to_Pa: float = 100000.0
bar_to_MPa: float = 0.1
bar_to_GPa: float = 0.0001
Pa_to_bar: float = 1e-05
MPa_to_bar: float = 10.0
GPa_to_bar: float = 10000.0
fraction_to_ppm: float = 1000000.0
ppm_to_fraction: float = 1e-06
ppm_to_percent: float = 0.0001
percent_to_ppm: float = 10000.0
g_to_kg: float = 0.001
cm3_to_m3: float = 1e-06
m3_to_cm3: float = 1000000.0
litre_to_m3: float = 0.001
m3_bar_to_J: float = 100000.0
J_to_m3_bar: float = 1e-05
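Each field is a multiplicative factor: a value in the source unit times the factor gives the value in the target unit. Because the real class is an Equinox Module, the stand-in below uses a frozen dataclass with a subset of the field names and defaults listed above; the pressure and volume are arbitrary example values.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class UnitConversion:
    """Stand-in mirroring a subset of jaxmod.units.UnitConversion defaults."""

    atmosphere_to_bar: float = 1.01325
    bar_to_Pa: float = 100000.0
    Pa_to_bar: float = 1e-05
    cm3_to_m3: float = 1e-06
    m3_bar_to_J: float = 100000.0


units = UnitConversion()

# 2 atm -> bar -> Pa
pressure_bar = 2.0 * units.atmosphere_to_bar
pressure_pa = pressure_bar * units.bar_to_Pa

# Pressure-volume work of 500 cm^3 at that pressure: m^3 * bar -> J
volume_m3 = 500.0 * units.cm3_to_m3
work_j = volume_m3 * pressure_bar * units.m3_bar_to_J

print(f"{pressure_pa:.1f} Pa, {work_j:.3f} J")
```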

jaxmod.utils module

Utils

jaxmod.utils.as_j64(x: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | tuple) Float[Array, '...']

Converts input to a jax.Array of dtype jax.numpy.float64.

This ensures that the array has a fixed dtype, preventing JAX from recompiling functions due to type changes between calls.

Parameters:

x – Input to convert

Returns:

jax.Array of dtype jax.numpy.float64
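A minimal sketch, assuming the conversion is a jnp.asarray call with an explicit dtype. JAX uses 32-bit floats by default: unless jax_enable_x64 is set, a float64 request is truncated to float32 with a warning, so the sketch enables 64-bit mode first.

```python
import jax
import jax.numpy as jnp

# float64 arrays require 64-bit mode; set this before creating any arrays
jax.config.update("jax_enable_x64", True)


def as_j64(x):
    """Hypothetical sketch: convert scalars, tuples or arrays to float64."""
    return jnp.asarray(x, dtype=jnp.float64)


a = as_j64((1, 2, 3))  # tuple of ints
b = as_j64(2.5)  # Python float
print(a.dtype, b.dtype, a.shape, b.shape)
```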

jaxmod.utils.get_batch_axis(x: Any) Literal[0, None]

Determines the batch axis for a JAX array.

Determines whether an object should be treated as batched along axis 0 for jax.vmap().

This function only considers JAX arrays for batching. While equinox.is_array() regards both JAX and NumPy arrays as arrays for tracing, NumPy arrays are treated here as static constants and are never batched. This allows fixed matrices to remain inside pytrees without being inadvertently vectorised.

Rules:
  • 1-D JAX arrays: Batched along axis 0

  • 2-D JAX arrays: Batched along axis 0 if shape[0]>1

  • 0-D (scalar) JAX arrays: Not batched

  • NumPy arrays or other objects: Not batched

Parameters:

x – Object to check for batching

Returns:

0 if batched along axis 0, otherwise None
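The rules above can be sketched as a few isinstance and shape checks. This is a hypothetical re-implementation, not the jaxmod source; arrays with more than two dimensions are not covered by the listed rules, and the sketch assumes they are not batched.

```python
import jax
import jax.numpy as jnp
import numpy as np


def get_batch_axis(x):
    """Hypothetical sketch of the batching rules listed above."""
    if not isinstance(x, jax.Array):
        return None  # NumPy arrays and other objects: never batched
    if x.ndim == 1:
        return 0  # 1-D JAX arrays: batched along axis 0
    if x.ndim == 2 and x.shape[0] > 1:
        return 0  # 2-D JAX arrays with more than one row
    return None  # scalars, single-row 2-D arrays and (by assumption) higher ranks


for obj in (jnp.ones(3), jnp.ones((2, 4)), jnp.ones((1, 4)), jnp.array(1.0), np.ones(3), 1.0):
    print(type(obj).__name__, getattr(obj, "shape", None), get_batch_axis(obj))
```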

jaxmod.utils.get_batch_size(x: PyTree) int

Determines the maximum batch size (i.e., length along axis 0) amongst all array-like leaves.

This inspects every leaf in the pytree and checks whether it is an array. Scalars contribute a size of 1, while arrays contribute the length of their leading dimension (shape[0]). The result is the largest such size found.

Note

Unlike get_batch_axis(), which only considers JAX arrays for batching, this function counts both JAX and NumPy arrays as array-like leaves when computing the maximum batch size.

Parameters:

x – Pytree of nested containers that may include arrays or scalars

Returns:

The maximum leading dimension size across all array-like leaves
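A hypothetical sketch using jax.tree_util.tree_leaves. Unlike the get_batch_axis rules, both JAX and NumPy arrays contribute their leading dimension here. Returning 1 for an empty or all-scalar pytree is an assumption.

```python
import jax
import jax.numpy as jnp
import numpy as np


def get_batch_size(x):
    """Hypothetical sketch: largest leading dimension amongst array leaves."""
    sizes = [1]  # assumption: an empty pytree gives 1
    for leaf in jax.tree_util.tree_leaves(x):
        if isinstance(leaf, (jax.Array, np.ndarray)) and leaf.ndim > 0:
            sizes.append(leaf.shape[0])  # JAX and NumPy arrays both count
        else:
            sizes.append(1)  # scalars (including 0-D arrays) contribute 1
    return max(sizes)


params = {
    "temperature": jnp.linspace(300.0, 400.0, 4),
    "matrix": np.ones((7, 2)),
    "pressure": 1.0,
}
print(get_batch_size(params))  # the NumPy matrix has the longest axis 0
```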

jaxmod.utils.is_hashable(x: Any) None

Checks whether an object is hashable and prints the result.

Parameters:

x – Object to check

jaxmod.utils.is_jax_array(element: Any) bool

Checks if element is a JAX array.

Note

NumPy arrays are not considered JAX arrays

Parameters:

element – Object to check

Returns:

True if element is a JAX array, otherwise False
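A plausible one-line implementation is an isinstance check against jax.Array, which excludes NumPy arrays and Python scalars. This differs from equinox.is_array(), which accepts both JAX and NumPy arrays. The sketch is an assumption, not the jaxmod source.

```python
import jax
import jax.numpy as jnp
import numpy as np


def is_jax_array(element):
    """Hypothetical sketch: True only for jax.Array instances."""
    return isinstance(element, jax.Array)


print(is_jax_array(jnp.ones(2)), is_jax_array(np.ones(2)), is_jax_array(1.0))
```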

jaxmod.utils.partial_rref(matrix: ndarray[tuple[Any, ...], dtype[_ScalarT]]) ndarray[tuple[Any, ...], dtype[_ScalarT]]

Computes a partial reduced row echelon form (RREF) to determine linear components.

This function performs the computation using NumPy in-place operations and is therefore not compatible with JAX transformations. The returned matrix represents the linear components of the input, extracted from the augmented RREF procedure.

Parameters:

matrix – A 2-D NumPy array of shape (nrows, ncols).

Returns:

A numpy.ndarray containing the linear components.

jaxmod.utils.power_law(values: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray, constant: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray, exponent: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array

Power law

Parameters:
  • values – Values

  • constant – Constant for the power law

  • exponent – Exponent for the power law

Returns:

Evaluated power law
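The functional form is not stated above. Assuming the conventional y = constant · values^exponent, a sketch is shown below; it is hypothetical and may differ from jaxmod's definition.

```python
import jax.numpy as jnp


def power_law(values, constant, exponent):
    """Hypothetical sketch assuming y = constant * values**exponent."""
    return constant * jnp.power(values, exponent)


y = power_law(jnp.array([1.0, 4.0, 9.0]), 2.0, 0.5)
print(y)  # 2 * sqrt(values)
```

All three arguments broadcast under the usual jax.numpy rules, so constant and exponent may be scalars or arrays compatible with values.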

jaxmod.utils.safe_exp(x: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array

Computes the elementwise exponential of x with input clipping to prevent overflow.

This function clips the input x to a maximum value defined by MAX_EXP_INPUT before applying jax.numpy.exp(), ensuring numerical stability for large values.

Parameters:

x – Array-like input

Returns:

Array of the same shape as x, where each element is the exponential of the clipped input
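A minimal sketch, assuming 64-bit mode and clipping from above with jnp.minimum. MAX_EXP_INPUT is defined locally as the natural log of the largest finite float64, matching the package-level value listed at the end of this page.

```python
import jax
import jax.numpy as jnp
import numpy as np

jax.config.update("jax_enable_x64", True)  # the threshold assumes float64

# Largest x whose exp(x) is finite in float64 (about 709.78)
MAX_EXP_INPUT = float(np.log(np.finfo(np.float64).max))


def safe_exp(x):
    """Hypothetical sketch: clip from above, then exponentiate elementwise."""
    x = jnp.asarray(x, dtype=jnp.float64)
    return jnp.exp(jnp.minimum(x, MAX_EXP_INPUT))


x = jnp.array([0.0, 1.0, 1000.0])
print(jnp.exp(x))  # the last entry overflows to inf
print(safe_exp(x))  # the last entry is clipped before exponentiation
```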

jaxmod.utils.to_hashable(x: Callable) Callable

Wraps a callable to make it hashable for JAX transformations.

This wrapper is useful when passing bound methods of Equinox PyTrees (with JAX arrays as attributes) to transformations such as jax.jit(), jax.vmap(), or lax.scan(). It wraps the callable in a lambda that forwards all arguments, so that JAX does not attempt to trace the method itself. See discussion: https://github.com/patrick-kidger/equinox/issues/1011

Parameters:

x – A callable to wrap

Returns:

A hashable lambda forwarding all arguments to the original callable.

jaxmod.utils.to_native_floats(value: Any) Any

Recursively converts any structure to nested tuples of native floats.

Parameters:

value – A scalar, list/tuple/array of floats, or nested thereof

Returns:

A float or nested tuple of floats
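One plausible use is producing plain, hashable Python values, for example for static fields or serialisation. The recursion below is an assumption about how the conversion could work. It handles lists, tuples, and NumPy or JAX arrays via numpy.asarray.

```python
import numpy as np


def to_native_floats(value):
    """Hypothetical sketch: nested sequences/arrays -> nested tuples of float."""
    if isinstance(value, (list, tuple)):
        return tuple(to_native_floats(item) for item in value)
    array = np.asarray(value)
    if array.ndim == 0:
        return float(array)  # NumPy/JAX scalars and Python numbers
    return tuple(to_native_floats(row) for row in array)  # recurse over axis 0


result = to_native_floats([1, np.float32(2.5), np.array([[3, 4], [5, 6]])])
print(result)
```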

jaxmod.utils.vmap_axes_spec(x: PyTree) PyTree[Literal[0, None]]

Recursively generates in_axes for jax.vmap() over a pytree.

Only JAX arrays are considered for batching. NumPy arrays and other objects are treated as static constants (not batched).

Parameters:

x – A pytree potentially containing JAX arrays, NumPy arrays, or scalars

Returns:

A pytree with the same structure as x. Each leaf is 0 if batched, or None if not.
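A hypothetical sketch: map a batching rule over every leaf with jax.tree_util.tree_map, then pass the result as a pytree prefix in in_axes. JAX treats None entries in in_axes as unbatched, so the NumPy matrix and Python scalar are passed through unchanged. The simplified get_batch_axis here follows the rules listed for that function; the model function is illustrative.

```python
import jax
import jax.numpy as jnp
import numpy as np


def get_batch_axis(x):
    """Simplified batching rule (see the rules listed for get_batch_axis)."""
    if isinstance(x, jax.Array) and (x.ndim == 1 or (x.ndim == 2 and x.shape[0] > 1)):
        return 0
    return None


def vmap_axes_spec(x):
    """Hypothetical sketch: map every leaf to 0 (batched) or None (static)."""
    return jax.tree_util.tree_map(get_batch_axis, x)


params = {"scale": jnp.array([1.0, 2.0, 3.0]), "offset": 0.5, "matrix": np.eye(2)}
spec = vmap_axes_spec(params)  # scale -> 0, offset -> None, matrix -> None


def model(p):
    # p["scale"] is a scalar inside vmap; the NumPy matrix stays whole
    return p["scale"] * 2.0 + p["offset"] + p["matrix"].sum()


out = jax.vmap(model, in_axes=(spec,))(params)
print(out)
```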

Module contents

Package-level variables

jaxmod.MAX_EXP_INPUT: float = np.float64(709.782712893384)

Maximum x for which exp(x) is finite in 64-bit precision to prevent overflow

jaxmod.MIN_EXP_INPUT: float = np.float64(-708.3964185322641)

Minimum x for which exp(x) is a normal (non-subnormal) float in 64-bit precision, to prevent underflow
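Both values appear consistent with the float64 limits reported by numpy.finfo. MAX_EXP_INPUT matches the natural log of the largest finite float64, and MIN_EXP_INPUT the natural log of the smallest positive normal float64 (finfo.tiny). Below MIN_EXP_INPUT, exp(x) becomes subnormal before it finally rounds to zero near x ≈ -745. The check below reproduces both values.

```python
import numpy as np

finfo = np.finfo(np.float64)
max_exp_input = np.log(finfo.max)  # log of the largest finite float64
min_exp_input = np.log(finfo.tiny)  # log of the smallest positive normal float64

print(max_exp_input, min_exp_input)
print(np.exp(min_exp_input) > 0.0)  # exp at the lower bound is still a normal float
print(np.exp(-720.0))  # subnormal, but not yet zero
```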

jaxmod.simple_formatter() Formatter

Simple formatter for logging

Returns:

Formatter for logging

jaxmod.debug_logger() Logger

Sets up debug logging to the console.

Returns:

A logger
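A minimal sketch using the standard logging module. The logger name 'jaxmod' and the format string are assumptions, not taken from the package. The handler check avoids attaching duplicate console handlers when the function is called more than once.

```python
import logging


def simple_formatter():
    """Hypothetical formatter: timestamp, logger name, level and message."""
    return logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")


def debug_logger():
    """Hypothetical console logger at DEBUG level."""
    logger = logging.getLogger("jaxmod")
    logger.setLevel(logging.DEBUG)
    # Avoid adding a duplicate handler if called more than once
    if not any(isinstance(h, logging.StreamHandler) for h in logger.handlers):
        handler = logging.StreamHandler()
        handler.setFormatter(simple_formatter())
        logger.addHandler(handler)
    return logger


log = debug_logger()
log.debug("Solver started")
```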