RUNLOCALAI · v38

Independently operated catalog for local-AI hardware and software. Hand-written verdicts. Source-cited claims. Reproducible commands when we have them.

OP·Fredoline Eruo
© 2026 runlocalai.co · Independently operated
Frameworks & tools

JAX

JAX is a numerical computing library from Google that combines a NumPy-like array API with automatic differentiation and just-in-time (JIT) compilation via XLA. Operators encounter JAX primarily as the backend for model training and inference in libraries such as Flax and Haiku, and in the Flax model classes of Hugging Face Transformers. Unlike PyTorch or TensorFlow, JAX uses a functional programming model: arrays are immutable, and random state is passed around explicitly as PRNG keys. JIT compilation can accelerate large matrix operations, but the functional style and compilation overhead make JAX less common for local inference than for research or TPU-based workloads.
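The two properties named above, immutable arrays and explicit random state, can be illustrated with a minimal sketch. The variable names are made up for illustration; only `jax.numpy`, the `.at[...].set(...)` update syntax, and the `jax.random` functions come from JAX itself.

```python
import jax
import jax.numpy as jnp

# Arrays are immutable: an "update" returns a new array and leaves the original alone.
x = jnp.zeros(3)
y = x.at[0].set(1.0)

# Randomness is explicit: the same key always yields the same numbers,
# so a key is split before each independent draw.
key = jax.random.PRNGKey(0)
key_a, key_b = jax.random.split(key)
a1 = jax.random.normal(key_a, (2,))
a2 = jax.random.normal(key_a, (2,))  # same key, so identical samples
b = jax.random.normal(key_b, (2,))   # different key, so different samples
```

Reusing `key_a` reproduces the same draw, which is why JAX code threads freshly split keys through every function that needs randomness.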

Deeper dive

JAX's core design is functional: operations return new arrays rather than modifying existing ones, and randomness requires explicit PRNG keys. This purity enables composable transformations: grad (automatic differentiation), vmap (automatic vectorization), pmap (parallel execution across devices), and jit (JIT compilation). XLA (Accelerated Linear Algebra) compiles jitted functions into optimized GPU/TPU kernels. For local AI operators, JAX is rarely used for inference because the compilation step adds startup latency and the functional style complicates stateful model serving. However, some Hugging Face models (e.g., BERT, GPT-2) ship Flax weights, and Flax is a popular JAX neural network library. JAX's training performance on consumer GPUs can match or exceed PyTorch's, depending on the workload, but the ecosystem for local deployment (quantization, KV-cache management) is less mature.
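As a rough illustration of how these transformations compose, the sketch below applies grad, vmap, and jit to one small function. The function `loss` and its inputs are invented for this example; pmap is left out because it needs more than one device to show anything.

```python
import jax
import jax.numpy as jnp

def loss(w, x):
    # Scalar function of w: the squared dot product of w and x.
    return jnp.dot(w, x) ** 2

w = jnp.array([1.0, 2.0])
x = jnp.array([3.0, 4.0])

# grad: derivative with respect to the first argument, here 2 * (w . x) * x.
g = jax.grad(loss)(w, x)

# vmap: evaluate the loss over a batch of x rows while holding w fixed.
batch = jnp.stack([x, 2.0 * x])
batched_loss = jax.vmap(loss, in_axes=(None, 0))(w, batch)

# jit: compile the gradient function with XLA; results match the uncompiled version.
fast_grad = jax.jit(jax.grad(loss))
g_fast = fast_grad(w, x)
```

Because each transformation takes a function and returns a function, they can be stacked in any order, for example `jax.jit(jax.vmap(jax.grad(loss), in_axes=(None, 0)))` for compiled per-example gradients.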

Practical example

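A Hugging Face model like bert-base-uncased offers Flax (JAX) weights. Loading it with FlaxBertForSequenceClassification.from_pretrained('bert-base-uncased') downloads the Flax checkpoint; the base checkpoint has no fine-tuned classification head, so that layer is randomly initialized until the model is fine-tuned. Inference uses a JIT-compiled forward pass: the first call may take several seconds to compile, but subsequent calls with the same input shapes reuse the compiled code and run fast. On an RTX 4090, compiled BERT inference can run roughly 2-3x faster than the equivalent PyTorch eager mode, though the gap depends on batch size and sequence length. Because jit recompiles for every new input shape, variable-length requests pay the compilation cost again unless inputs are padded to fixed sizes, which is why JAX is a poor fit for interactive chat.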
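The compile-once behavior can be observed without downloading a model. The sketch below uses a hypothetical stand-in function, `forward`, rather than a real BERT forward pass, and records each time JAX traces it. It relies on the fact that Python side effects inside a jitted function run only during tracing.

```python
import jax
import jax.numpy as jnp

trace_log = []  # records the input shape each time JAX traces the function

@jax.jit
def forward(x):
    # Python side effect: executes only while tracing, not on cached calls.
    trace_log.append(x.shape)
    return jnp.tanh(x @ x.T).sum()

forward(jnp.ones((4, 8)))   # first call: trace and compile
forward(jnp.ones((4, 8)))   # same shape and dtype: reuses the compiled code
traces_same_shape = len(trace_log)

forward(jnp.ones((4, 16)))  # new shape: traced and compiled again
traces_new_shape = len(trace_log)
```

Here `traces_same_shape` is 1 and `traces_new_shape` is 2. A new shape triggers a second compilation, which is why serving code built on jit usually pads inputs to a small set of fixed lengths.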

Workflow example

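An operator training a small model locally might use JAX with Flax: pip install jax flax (jaxlib is installed as a dependency; NVIDIA GPU support needs the CUDA extra, e.g. pip install "jax[cuda12]"). A training script defines a loss function, applies grad to compute gradients, and uses jit to compile the update step. For inference, the operator would typically convert the trained weights rather than serve from JAX, because neither vLLM nor llama.cpp loads JAX checkpoints: vLLM reads Hugging Face-format checkpoints, and llama.cpp reads GGUF files. In Transformers versions that include Flax support, Flax weights can be loaded into the matching PyTorch class with from_pretrained(..., from_flax=True) and re-saved before any further conversion. JAX is more common in research workflows (e.g., fine-tuning on a single GPU) than in production local-AI pipelines.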
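The loss, grad, and jit pattern described above might look like the following minimal sketch. It uses plain JAX with a hand-written gradient-descent step instead of Flax modules or an optimizer library, and the linear model, synthetic data, and learning rate are all assumptions chosen to keep the example self-contained.

```python
import jax
import jax.numpy as jnp

def loss_fn(params, x, y):
    # Mean squared error of a linear model: pred = x @ w + b.
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)

@jax.jit
def update(params, x, y, lr):
    # grad differentiates with respect to the first argument (the params dict).
    grads = jax.grad(loss_fn)(params, x, y)
    # Parameters are immutable, so the step returns a new pytree.
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (64, 3))
true_w = jnp.array([1.0, -2.0, 0.5])
y = x @ true_w + 0.3  # synthetic targets, made up for this sketch

params = {"w": jnp.zeros(3), "b": jnp.array(0.0)}
initial_loss = loss_fn(params, x, y)
for _ in range(200):
    params = update(params, x, y, 0.1)
final_loss = loss_fn(params, x, y)
```

The update step is compiled on its first call and reused for the remaining iterations, since the shapes of `params`, `x`, and `y` never change. A Flax-based script would follow the same shape, with the params dict produced by a module's init method.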

Reviewed by Fredoline Eruo. See our editorial policy.
