1. Getting Started
1.1. User installation
Standard Installation (NumPy backend)
The package is available on PyPI:
pip install specular-differentiation
Check the version:
import specular
print("version:", specular.__version__)
version: 1.1.0
ODE solvers
pip install "specular-differentiation[ode]"
Optimization routines
pip install "specular-differentiation[optimization]"
Numba backend
pip install "specular-differentiation[numba]"
If Numba is installed and available, the package may use the Numba-accelerated CPU backend.
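To see which optional backends can be imported in the current environment, you can use the standard library's importlib.util.find_spec. This is only a sketch. It reports whether a module can be found, not which backend specular actually selects. The helper is_installed is hypothetical and not part of the package.

```python
import importlib.util


def is_installed(module_name: str) -> bool:
    # Hypothetical helper: find_spec returns None when a top-level module cannot be found.
    return importlib.util.find_spec(module_name) is not None


# Report availability of the optional dependencies mentioned in this section.
for name in ("numba", "jax", "scipy", "torch"):
    status = "installed" if is_installed(name) else "not installed"
    print(f"{name}: {status}")
```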
JAX backend
By default, the package uses the NumPy backend (CPU). To enable hardware acceleration (GPU/TPU), install the package with the JAX extra. This adds the following dependencies:
- JAX (jax, jaxlib >= 0.4)
pip install "specular-differentiation[jax]"
Note
The JAX backend is currently experimental. See 2.4 Backend.
Developer installation
To install all dependencies, including those needed for tests, docs, and examples, install the dev extra in editable mode. This adds the following dependencies:
- optional extras: ode, optimization, numba, and jax
- SciPy (scipy >= 1.10.0)
- PyTorch (torch >= 2.0.0)
- Pytest (pytest >= 7.0)
pip install -e ".[dev]"
1.2. Quick start
The following simple example calculates the specular derivative of the ReLU function \(f(x) = \max(0, x)\) at the origin.
import specular
ReLU = lambda x: max(x, 0)
specular.derivative(ReLU, x=0)
0.41421356237309515
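The value above equals \(\sqrt{2} - 1\). As an independent cross-check, the sketch below assumes that the specular derivative at a point is the slope of the line bisecting the angle between the left and right tangent lines, that is, \(\tan\big((\arctan f'_+(x) + \arctan f'_-(x))/2\big)\). For ReLU at the origin, \(f'_+(0) = 1\) and \(f'_-(0) = 0\), which gives \(\tan(\pi/8) = \sqrt{2} - 1\). The helpers one_sided_derivatives and specular_derivative_sketch, as well as the finite-difference step h, are hypothetical. This is not the package's implementation.

```python
import math


def one_sided_derivatives(f, x, h=1e-6):
    # Hypothetical helper: forward and backward difference quotients
    # approximate the right and left derivatives at x.
    right = (f(x + h) - f(x)) / h
    left = (f(x) - f(x - h)) / h
    return right, left


def specular_derivative_sketch(f, x, h=1e-6):
    # Hypothetical helper: slope of the bisector of the two one-sided tangent lines,
    # i.e. tan of the mean of their inclination angles.
    right, left = one_sided_derivatives(f, x, h)
    return math.tan((math.atan(right) + math.atan(left)) / 2)


relu = lambda x: max(x, 0)
print(specular_derivative_sketch(relu, 0.0))
print(math.sqrt(2) - 1)
```

For a differentiable function both one-sided derivatives coincide, so the sketch reduces to the ordinary derivative (for example, about 2.0 for \(x^2\) at \(x = 1\)).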
1.3. Backend Usage
To use the JAX backend, install the JAX extra and select the backend explicitly:
import jax.numpy as jnp
import specular
specular.change_backend("cpu_jax")
ReLU = lambda x: jnp.maximum(x, 0)
specular.derivative(ReLU, 0.0)
Array(0.41421354, dtype=float32)
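The JAX example defines ReLU with jnp.maximum instead of Python's built-in max. The snippet below illustrates one general difference between the two. It uses NumPy instead of JAX so that it runs without the JAX extra: built-in max compares whole objects and fails on multi-element arrays, whereas an elementwise maximum works on arrays of any shape. Whether this is the exact reason the JAX backend needs jnp.maximum is an assumption. The function names are illustrative only.

```python
import numpy as np


def relu_builtin(t):
    # Python's max compares whole objects; on arrays the comparison is elementwise,
    # so its truth value is ambiguous and a ValueError is raised.
    return max(t, 0)


def relu_elementwise(t):
    # np.maximum broadcasts the scalar 0 and takes the maximum elementwise.
    return np.maximum(t, 0)


x = np.array([-1.0, 0.0, 2.0])
print(relu_elementwise(x))  # [0. 0. 2.]

try:
    relu_builtin(x)
    builtin_ok = True
except ValueError as err:
    builtin_ok = False
    print("built-in max failed:", err)
```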
To enable 64-bit precision (double precision), update the JAX configuration as follows:
import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
import specular
specular.change_backend("cpu_jax")
ReLU = lambda x: jnp.maximum(x, 0)
specular.derivative(ReLU, 0.0)
Array(0.41421356, dtype=float64)
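The two JAX outputs differ only in their last printed digits. The following sketch uses only the Python standard library. It compares the default-precision result shown earlier (0.41421354) with \(\sqrt{2} - 1\) and with the nearest single-precision number, and shows that the float32 result is accurate to within a small multiple of float32 machine epsilon (about 1.2e-7). The helper to_float32 is hypothetical.

```python
import math
import struct


def to_float32(x: float) -> float:
    # Hypothetical helper: round a Python float (IEEE 754 double)
    # to the nearest IEEE 754 single-precision value.
    return struct.unpack("f", struct.pack("f", x))[0]


FLOAT32_EPS = 2.0 ** -23      # machine epsilon of IEEE 754 single precision

exact = math.sqrt(2) - 1      # sqrt(2) - 1 in double precision
reported_f32 = 0.41421354     # float32 output shown for the default JAX configuration
nearest_f32 = to_float32(exact)

rel_err_reported = abs(reported_f32 - exact) / exact
rel_err_nearest = abs(nearest_f32 - exact) / exact

print(f"sqrt(2) - 1 (float64): {exact!r}")
print(f"nearest float32:       {nearest_f32:.8f}")
print(f"relative error of reported float32 result: {rel_err_reported:.1e}")
print(f"float32 machine epsilon:                   {FLOAT32_EPS:.1e}")
```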