NumPy gave Python its foundation for numerical computing. But datasets grew larger than RAM. GPUs became standard hardware. Machine learning demanded automatic differentiation. Scientists needed parallelism across clusters. No single library could handle all of it.
The result is an ecosystem of array computation libraries that extend, accelerate, and reimagine what NumPy started. This article covers the major players—CuPy, JAX, Dask, Numba, and the Array API Standard that is trying to hold them all together. Real code. Real tradeoffs. Real understanding of when and why you would reach for each one.
The Problem of Fragmentation
In August 2020, Ralf Gommers, a long-time NumPy maintainer and director of the Consortium for Python Data API Standards, identified the core issue in the consortium's announcement post: new frameworks pushing forward the state of the art appear every year, but one unintended consequence has been fragmentation in the fundamental building blocks—multidimensional arrays and dataframe libraries—that underpin the whole Python data ecosystem. He noted that arrays are fragmented across TensorFlow, PyTorch, NumPy, CuPy, Dask, and others, and that this fragmentation comes with significant costs, from entire libraries being reimplemented for a different array backend to users having to relearn APIs when switching between frameworks.
The 2020 Nature paper on NumPy (Harris et al.) made the same observation from the library's perspective: several projects targeting audiences with specialized needs have developed their own NumPy-like interfaces and array objects. NumPy increasingly acts as an interoperability layer between these array computation libraries. It described how exploring new ways of working with arrays is experimental by nature, and that each time users decide to try a new technology, they must change import statements and verify that the new library implements all the NumPy API they currently use.
That is the landscape. Now let's walk through it.
CuPy: NumPy on the GPU
CuPy is the most direct answer to the question "what if my NumPy arrays lived on a GPU?" Developed originally by Preferred Networks (the team behind the Chainer deep learning framework), CuPy implements a near-identical interface to NumPy's ndarray, but backed by CUDA memory on NVIDIA GPUs. CuPy has supported AMD ROCm experimentally since its earlier releases, and CuPy v14 (released February 17, 2026) moved to ROCm 7 as its minimum supported version (dropping ROCm 6.x), aligned with NumPy 2 semantics, and added initial bfloat16 support.
The migration path is deliberately minimal. In many cases, you change one import:
import cupy as cp
# Create a 10,000 x 10,000 matrix on the GPU
a = cp.random.randn(10000, 10000, dtype=cp.float32)
b = cp.random.randn(10000, 10000, dtype=cp.float32)
# Matrix multiplication on GPU -- same syntax as NumPy
c = a @ b
# Move result back to CPU as a NumPy array
c_cpu = cp.asnumpy(c)
CuPy supports broadcasting, fancy indexing, reductions, and linear algebra, all running on the GPU. Because CuPy arrays implement __array_ufunc__ and __array_function__, NumPy functions applied to CuPy arrays automatically dispatch to CuPy's CUDA implementations. That means existing code that calls np.sum() or np.linalg.qr() can operate on GPU data without modification, as long as the arrays are CuPy arrays and CuPy implements the function in question.
GPU computing shines when you are operating on large, regular arrays with high arithmetic intensity. A matrix multiplication on a 10,000-by-10,000 matrix will see dramatic speedups. But transferring small arrays between CPU and GPU memory incurs latency that can erase any computational gains. The rule of thumb: if your array fits comfortably in L3 cache, NumPy on the CPU may actually be faster.
CuPy also provides cp.RawKernel, which compiles custom CUDA C++ kernels supplied as Python strings, giving advanced users direct access to GPU hardware without leaving the Python environment. This matters when your algorithm cannot be expressed purely through array operations, for example when different elements of the array need different numbers of iterations to converge.
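To make the convergence case concrete, here is a sketch under a few assumptions: the hypothetical kernel `collatz_steps` counts how many Collatz steps each positive integer needs to reach 1, so every GPU thread runs a different number of loop iterations, which is awkward to express with whole-array operations. The GPU path needs CuPy and a CUDA device; when either is missing, the sketch falls back to a plain Python reference loop, which is also handy for checking the kernel's output.

```python
import numpy as np

# CUDA C++ source for a hypothetical kernel: one thread per element,
# each looping until its own value reaches 1.
COLLATZ_SRC = r'''
extern "C" __global__
void collatz_steps(const long long* n, int* steps, int size) {
    int i = blockDim.x * blockIdx.x + threadIdx.x;
    if (i < size) {
        long long x = n[i];
        int count = 0;
        while (x != 1) {
            x = (x % 2 == 0) ? x / 2 : 3 * x + 1;
            count++;
        }
        steps[i] = count;
    }
}
'''

def _gpu_available():
    """True only if CuPy imports and sees at least one CUDA device."""
    try:
        import cupy as cp
        return cp.cuda.runtime.getDeviceCount() > 0
    except Exception:
        return False

def collatz_steps_cpu(n):
    """Reference implementation: a plain Python loop per element."""
    out = np.empty(len(n), dtype=np.int32)
    for i, value in enumerate(n):
        x, count = int(value), 0
        while x != 1:
            x = x // 2 if x % 2 == 0 else 3 * x + 1
            count += 1
        out[i] = count
    return out

def collatz_steps(n):
    """Run the RawKernel on the GPU when possible, else the CPU reference."""
    n = np.asarray(n, dtype=np.int64)
    if not _gpu_available():
        return collatz_steps_cpu(n)
    import cupy as cp
    kernel = cp.RawKernel(COLLATZ_SRC, "collatz_steps")  # compiled on first launch
    n_gpu = cp.asarray(n)
    steps_gpu = cp.empty(n.size, dtype=cp.int32)
    threads = 256
    blocks = (n.size + threads - 1) // threads
    # Scalar arguments must match the C parameter types, hence np.int32
    kernel((blocks,), (threads,), (n_gpu, steps_gpu, np.int32(n.size)))
    return cp.asnumpy(steps_gpu)

print(collatz_steps([1, 2, 3, 6, 7, 27]))
```

For the inputs 1, 2, 3, 6, 7, and 27, both paths should report 0, 1, 7, 8, 16, and 111 steps. Inputs must be positive, since the loop never terminates for 0.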
CuPy requires a supported GPU: NVIDIA with CUDA, or AMD with ROCm (experimental support has been available since earlier CuPy versions; AMD provided ROCm wheels for its CuPy v13 fork, and CuPy v14 now offers official upstream ROCm 7 wheels via pip install cupy-rocm-7-0). If you are on Apple Silicon, neither CuPy nor CUDA applies; consider JAX with its Metal plugin (experimental) or stick with NumPy plus Numba for CPU-bound acceleration.
JAX: Composable Transformations of NumPy Programs
JAX is different from every other library on this list. Where CuPy and Dask take the NumPy API and run it on different hardware or at different scales, JAX takes the NumPy API and wraps it in a system of composable function transformations. The official project description calls JAX "composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more."
A nascent version of JAX, supporting automatic differentiation and compilation to XLA, was first described in a paper presented at SysML 2018 by Roy Frostig, Matthew James Johnson, and Chris Leary. The full project was open-sourced in December 2018. Google DeepMind adopted it extensively. In a blog post on their research practices, DeepMind noted that JIT-compilation via XLA, together with JAX's NumPy-consistent API, allows researchers with no previous experience in high-performance computing to easily scale to one or many accelerators.
Here is where JAX becomes powerful. It provides four core transformations that can be composed arbitrarily:
import jax
import jax.numpy as jnp
# jit: compile a function with XLA for faster execution
@jax.jit
def selu(x, alpha=1.67, lmbda=1.05):
    return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)
# grad: automatic differentiation (grad needs a scalar output,
# so grad_selu expects a scalar input, e.g. grad_selu(2.0))
grad_selu = jax.grad(selu)
# vmap: automatic vectorization (batch a function over an axis)
batched_selu = jax.vmap(selu)
# pmap: parallel execution across multiple devices (GPUs/TPUs);
# the mapped axis can be no larger than jax.local_device_count()
parallel_selu = jax.pmap(selu)
The jit transformation compiles your Python function into optimized XLA (Accelerated Linear Algebra) code that runs on CPU, GPU, or TPU. The grad transformation computes gradients through arbitrary Python code written with JAX operations, including loops and function calls, rather than only through a fixed set of predefined layers. The vmap transformation takes a function that operates on a single example and automatically vectorizes it across a batch. And pmap distributes computation across multiple accelerators.
What makes JAX genuinely novel is that these transformations compose. You can take the gradient of a JIT-compiled function. You can vectorize a gradient computation. You can parallelize a vectorized gradient. This composability is why JAX has become the tool of choice for machine learning researchers who need to implement custom algorithms from scratch.
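A minimal sketch of that composition, reusing the SELU definition from above without the decorator: `jax.grad` turns the scalar function into its derivative, `jax.vmap` maps the derivative over a batch, and `jax.jit` compiles the whole stack. The name `selu_derivative` is illustrative.

```python
import jax
import jax.numpy as jnp
import numpy as np

def selu(x, alpha=1.67, lmbda=1.05):
    return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

# grad: derivative of a scalar -> scalar function
# vmap: apply that derivative to every element of a batch
# jit:  compile the whole composed pipeline with XLA
selu_derivative = jax.jit(jax.vmap(jax.grad(selu)))

xs = jnp.array([-1.0, 0.0, 2.0])
print(np.asarray(selu_derivative(xs)))
```

For x > 0 the derivative is the constant lmbda (1.05); for x <= 0 it is lmbda * alpha * exp(x), so the value at 0 comes from the second branch of jnp.where.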
There are important tradeoffs. JAX arrays are immutable—you cannot modify them in place. JAX functions must be "pure" in the functional programming sense: no side effects, no global state mutation. This is not a limitation of the implementation; it is a requirement for the transformations to work correctly. If you come from a NumPy background where a[0] = 5 is natural, JAX requires a different style:
import jax.numpy as jnp
a = jnp.zeros(3)
# NumPy style (WILL NOT WORK in JAX: raises a TypeError)
# a[0] = 5
# JAX style -- returns a new array with the update applied
a = a.at[0].set(5)
JAX has a compilation cost. The first call to a jit-decorated function triggers XLA compilation, which can take seconds for complex functions. Subsequent calls with the same input shapes and types reuse the compiled code and are fast. If you call the same function with different shapes, it recompiles. JAX is best suited for workloads where the same computation is repeated many times on data of consistent shape—exactly the pattern seen in training neural networks.
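One way to watch this caching happen, as a sketch: Python side effects inside a jitted function run only while JAX traces it, so a hypothetical module-level counter increments once per compilation rather than once per call.

```python
import jax
import jax.numpy as jnp

trace_count = 0

@jax.jit
def scaled_sum(x):
    global trace_count
    # This line runs during tracing only, so it counts compilations, not calls
    trace_count += 1
    return jnp.sum(x * 2.0)

scaled_sum(jnp.ones(3))   # traced and compiled for shape (3,)
scaled_sum(jnp.ones(3))   # cache hit: same shape and dtype
scaled_sum(jnp.ones(4))   # new shape, so traced and compiled again
print(trace_count)
```

After the three calls `trace_count` should be 2: one trace for shape (3,) and one for shape (4,). Counting traces this way is a quick check for accidental recompilation in a training loop; the `jax_log_compiles` configuration option reports compilations more completely.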
Dask: Scaling Out to Clusters
Dask solves a different problem than CuPy or JAX. It is not about running on different hardware; it is about running on more hardware, or on data that does not fit in memory.
Dask Array implements a subset of the NumPy ndarray interface using blocked algorithms. It divides a large array into many smaller "chunks," each of which is a regular NumPy array (or CuPy array, or any NumPy-compatible array), and schedules computation across chunks using a task graph:
import dask.array as da
import numpy as np
# Create a 100,000 x 100,000 array -- ~80 GB of float64 data
# Dask never creates this in memory all at once
x = da.random.normal(0, 1, size=(100_000, 100_000), chunks=(10_000, 10_000))
# Operations are lazy -- they build a task graph
result = (x + x.T).mean(axis=0)
# Nothing has been computed yet. Trigger execution:
output = result.compute()
print(output.shape) # (100000,)
The key insight is that Dask is lazy. When you write x + x.T, Dask does not perform the addition. It records the operation in a directed acyclic graph. Only when you call .compute() does Dask execute the graph, processing chunks in parallel and managing memory so that the full 80 GB dataset never needs to reside in RAM simultaneously.
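The same laziness is easy to observe on a small array. This sketch wraps an in-memory NumPy array (sizes chosen arbitrarily) in 16 chunks, builds the expression from above, and only produces a NumPy result when `.compute()` runs; the chunked result should match plain NumPy up to floating-point rounding.

```python
import numpy as np
import dask.array as da

rng = np.random.default_rng(0)
x_np = rng.standard_normal((1_000, 1_000))

# Wrap the in-memory array as 4 x 4 = 16 chunks of 250 x 250
x = da.from_array(x_np, chunks=(250, 250))

result = (x + x.T).mean(axis=0)   # only builds a task graph
print(type(result))               # a lazy dask Array, not a NumPy array

output = result.compute()         # now the chunks are actually processed
print(type(output), output.shape)
```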
Dask works with multiple schedulers. The default threaded scheduler uses multiple threads on a single machine. The distributed scheduler can coordinate computation across a cluster of machines. And because Dask operates on chunks of NumPy-compatible arrays, you can swap the chunk type. Replace NumPy chunks with CuPy chunks, and you have GPU-accelerated distributed array computing:
import cupy
import dask.array as da
# Generate chunks backed by CuPy arrays (on GPU)
rs = da.random.RandomState(RandomState=cupy.random.RandomState)
x = rs.normal(0, 1, size=(50_000, 50_000), chunks=(10_000, 10_000))
# This runs on GPU, distributed across chunks
result = x.mean().compute()
Dask is the right tool when your bottleneck is data size or when you need to scale across machines. It is not the right tool when you need microsecond-level latency on small arrays—the task graph overhead adds milliseconds that matter in those cases.
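A rough way to see both points, scheduler choice and overhead, on your own machine: the sketch below runs one small reduction under two local schedulers, selected with the `scheduler=` argument to `.compute()`, and then does the same arithmetic in plain NumPy. The array sizes are arbitrary and timings will vary by hardware, but on an array this small NumPy usually finishes first.

```python
import time
import numpy as np
import dask.array as da

x = da.ones((2_000, 2_000), chunks=(500, 500))   # 16 chunks
expr = (x * 2 + 1).sum()

# The scheduler can be chosen per call: "threads" (the default for arrays),
# "synchronous" (runs in the calling thread; handy with a debugger),
# or "processes".
for scheduler in ("threads", "synchronous"):
    start = time.perf_counter()
    total = expr.compute(scheduler=scheduler)
    print(f"{scheduler:>11}: {total:.0f} in {time.perf_counter() - start:.4f}s")

# The same arithmetic in plain NumPy: no graph, no scheduler
arr = np.ones((2_000, 2_000))
start = time.perf_counter()
numpy_total = (arr * 2 + 1).sum()
print(f"{'numpy':>11}: {numpy_total:.0f} in {time.perf_counter() - start:.4f}s")
```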
Numba: Compiling Python to Machine Code
Numba occupies a unique niche. It does not provide a new array type. Instead, it compiles your existing Python and NumPy code into optimized machine code at runtime using the LLVM compiler infrastructure. Travis Oliphant, who created NumPy, was also a co-creator of Numba.
The primary use case is code that cannot be expressed as vectorized NumPy operations—code with loops, conditional logic, and element-wise algorithms where NumPy's array-at-a-time model creates too many temporary arrays:
from numba import njit
import numpy as np
@njit
def custom_filter(data, threshold):
    """Apply a filter that NumPy can't vectorize efficiently:
    each element depends on the running state."""
    result = np.empty_like(data)
    state = 0.0
    for i in range(len(data)):
        if data[i] > threshold:
            state = 0.9 * state + 0.1 * data[i]
        else:
            state *= 0.95
        result[i] = state
    return result
data = np.random.randn(10_000_000)
filtered = custom_filter(data, 0.5) # First call compiles; subsequent calls are fast
Without Numba, that loop would take seconds in pure Python. With the @njit decorator, Numba compiles it to machine code that runs at C/Fortran speed. The Numba documentation states directly that compiled numerical algorithms in Python can approach the speeds of C or Fortran.
Numba works by analyzing the Python bytecode of a decorated function, inferring data types from the arguments, and generating LLVM IR (intermediate representation) that is then compiled to native machine code. In nopython mode (@njit, or @jit, which has defaulted to nopython since version 0.59, released January 2024), Numba-compiled functions run without the Python interpreter's involvement. They do not release the GIL by default, however: you opt in with @njit(nogil=True), which lets several Python threads run the compiled code simultaneously, or use @njit(parallel=True) with numba.prange to have Numba spread loop iterations across threads itself.
Numba also provides GPU compilation through its CUDA backend. Note that as of Numba 0.61 (released January 16, 2025), the built-in numba.cuda module is deprecated in favor of the standalone numba-cuda package, which receives more frequent updates. The built-in target will continue to be provided through at least Numba 0.62, but new features and bug fixes are being added only to numba-cuda:
from numba import cuda
import numpy as np
@cuda.jit
def vector_add(a, b, c):
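    i = cuda.grid(1)  # this thread's global index
    if i < a.size:  # guard: the last block may extend past the array
        c[i] = a[i] + b[i]
n = 1_000_000
a = np.ones(n)
b = np.ones(n)
c = np.zeros(n)
# Configure grid dimensions
threads_per_block = 256
blocks = (n + threads_per_block - 1) // threads_per_block
# NumPy (host) arrays are copied to the GPU and back automatically on each
# launch; for repeated launches, cuda.to_device() keeps data resident instead
vector_add[blocks, threads_per_block](a, b, c)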
The critical insight about Numba is knowing when to use it versus NumPy. If your operation can be expressed as array operations (additions, multiplications, reductions, broadcasting), NumPy is already fast because those operations run in compiled C code. Numba's strength is loops that carry state, branching logic that differs per element, and algorithms that would create excessive temporary arrays if vectorized. Numba likes loops—which is the opposite of the conventional NumPy wisdom that says to avoid loops at all costs.
The Array API Standard: Stitching It Together
The proliferation of array libraries created a real problem. Code written for NumPy would not run on CuPy without changes. JAX arrays behaved slightly differently from PyTorch tensors. A function written for one backend could not easily be ported to another. The Consortium for Python Data API Standards, announced in August 2020 and initiated by Quansight Labs, exists to fix this.
The consortium includes stakeholders from NumPy, CuPy, PyTorch, JAX, Dask, TensorFlow, and others. Their goal: define a common API specification that all array libraries can implement, so that downstream consumers like SciPy and scikit-learn can write code once and have it work across backends.
The 2024 revision of the Array API Standard (published February 27, 2025) standardizes a comprehensive set of functions covering creation, manipulation, statistical operations, linear algebra, FFT operations, and—new in this revision—integer array indexing for vectorized fancy indexing. The revision also relaxed requirements for binary element-wise functions to accept scalars alongside arrays, a practical improvement requested by many downstream users. The array-api-compat compatibility layer, published on PyPI by the consortium, smooths over behavioral differences between libraries as they work toward full conformance:
from array_api_compat import array_namespace
import numpy as np
def normalize(x):
"""Works with NumPy, CuPy, PyTorch, JAX arrays."""
xp = array_namespace(x) # Detect which library 'x' comes from
return (x - xp.mean(x)) / xp.std(x)
# Works with NumPy
result_np = normalize(np.array([1.0, 2.0, 3.0, 4.0, 5.0]))
# Would also work with CuPy, JAX, or PyTorch arrays
# without changing a single line of the function
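The array_namespace() function inspects the input array and returns the matching standard-compatible namespace for the library that created it: either the library's own module or a thin array-api-compat wrapper around it where the library's defaults still deviate from the standard. This means xp.mean() ends up calling NumPy's mean for NumPy arrays, CuPy's for CuPy arrays, or jnp.mean() for JAX arrays.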
Notably, NumPy 2.0 (released June 16, 2024) adopted the Array API Standard directly in its main namespace—a significant milestone that the consortium highlighted in their 2024 blog post. Both SciPy and scikit-learn have added experimental support, meaning that algorithms from those libraries can now operate on GPU tensors through the standard API. The Chan Zuckerberg Initiative awarded the consortium an EOSS Cycle 6 grant in November 2024 to accelerate adoption further.
The data interchange mechanism underlying all of this is DLPack, an open in-memory structure for sharing tensors between frameworks without copying data. Array objects that conform to the standard implement __dlpack__ and __dlpack_device__ methods, and a standardized from_dlpack() function constructs arrays across libraries. DLPack supports data on devices other than CPU—including GPU memory—making zero-copy transfer between, say, PyTorch and CuPy possible.
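NumPy itself implements both sides of the protocol, so the zero-copy behavior can be shown without a GPU. In this sketch `np.from_dlpack` wraps an existing buffer rather than copying it. NumPy can only import CPU-resident data, so moving GPU data between libraries would go through the GPU library's own constructor instead, such as `cupy.from_dlpack` or `torch.from_dlpack`.

```python
import numpy as np

x = np.arange(6, dtype=np.float32)

# (device_type, device_id); in DLPack's enumeration, 1 means CPU
print(x.__dlpack_device__())

# Import through the protocol: y wraps x's buffer instead of copying it
y = np.from_dlpack(x)
print(np.shares_memory(x, y))

x[0] = 42.0
print(y[0])   # the write through x is visible through y
```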
PEPs That Enable the Ecosystem
Several Python Enhancement Proposals underpin how these libraries interoperate.
PEP 3118 — Revising the Buffer Protocol. As discussed in our previous NumPy article, this PEP standardized how Python objects share memory. Every array library that interoperates with NumPy on CPU relies on this protocol. Authored by Travis Oliphant and Carl Banks, it enabled the zero-copy data sharing that makes the ecosystem practical.
PEP 465 — The @ Operator for Matrix Multiplication. Every library covered here supports @ for matrix multiplication, because PEP 465 added __matmul__, __rmatmul__, and __imatmul__ to the Python data model. CuPy arrays, JAX arrays, Dask arrays, and NumPy arrays all use the same operator for the same operation.
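The data-model hooks can be seen with a small hypothetical wrapper class. Defining `__matmul__` handles `A @ v`; defining `__rmatmul__` handles `v @ A` when the left operand gives up. Setting `__array_ufunc__ = None` is what makes a NumPy array on the left give up (the opt-out described in NEP 13); without it, NumPy would try to convert the wrapper into an array itself.

```python
import numpy as np

class Operator:
    """Hypothetical wrapper showing the PEP 465 data-model hooks."""

    # Tells NumPy arrays to return NotImplemented from binary operators,
    # so Python falls back to this class's reflected method.
    __array_ufunc__ = None

    def __init__(self, matrix):
        self.matrix = np.asarray(matrix)

    def __matmul__(self, other):        # Operator @ other
        return self.matrix @ other

    def __rmatmul__(self, other):       # other @ Operator
        return other @ self.matrix

A = Operator([[1, 2], [3, 4]])
v = np.array([1, 1])
print(A @ v)   # A.__matmul__(v)
print(v @ A)   # ndarray defers, then A.__rmatmul__(v)
```

`A @ v` computes [[1, 2], [3, 4]] times [1, 1], giving [3, 7]; `v @ A` gives [4, 6].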
PEP 703 — Making the GIL Optional. Free-threaded Python arrived as officially supported in Python 3.14 (released October 7, 2025), promoted from experimental status by PEP 779. This is especially significant for Numba and Dask. Numba can already release the GIL in nopython mode for functions compiled with nogil=True, but free-threaded Python makes multi-threaded execution safer and more practical across the entire stack. Numba 0.63 (released December 8, 2025) added Python 3.14 support (including compatibility with free-threaded builds), and NumPy 2.1 through 2.3 have been progressively improving free-threaded compatibility. According to the official CPython documentation, the single-threaded overhead of the free-threaded build on the pyperformance benchmark suite dropped from about 40% in Python 3.13 to roughly 1–8% in 3.14 (varying by platform), making it viable for production workloads.
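Because the free-threaded interpreter is a separate build (typically installed as `python3.14t`), library code sometimes needs to check what it is running on. A minimal sketch using the documented `Py_GIL_DISABLED` build variable and `sys._is_gil_enabled()`; the function name is hypothetical, and the `getattr` fallback covers interpreters older than 3.13, where that function does not exist.

```python
import sys
import sysconfig

def free_threading_status():
    """Report whether this is a free-threaded build and whether the GIL is active."""
    # Py_GIL_DISABLED is 1 on free-threaded builds, 0 or None otherwise
    free_threaded_build = bool(sysconfig.get_config_var("Py_GIL_DISABLED"))
    # The GIL can be re-enabled at runtime on a free-threaded build, e.g. when an
    # extension module that has not declared free-threading support is imported.
    is_gil_enabled = getattr(sys, "_is_gil_enabled", lambda: True)
    return {"free_threaded_build": free_threaded_build,
            "gil_enabled": bool(is_gil_enabled())}

print(sys.version)
print(free_threading_status())
```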
NEP 13 and NEP 18 — Array Function and Ufunc Protocols. These NumPy Enhancement Proposals (not PEPs, but analogous within the NumPy project) introduced __array_ufunc__ and __array_function__, which allow third-party libraries to override NumPy functions. When you call np.sum() on a CuPy array, it dispatches to CuPy's implementation instead of trying to process the data as a NumPy array. This mechanism is what makes the entire interoperability layer possible.
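A minimal sketch of both protocols on a hypothetical `TracedArray` type, following the registration pattern shown in NEP 18's examples (the `implements` decorator and `HANDLED_FUNCTIONS` table are names chosen here, not NumPy API). `np.sum` is routed through `__array_function__` to the registered replacement, `np.add` through `__array_ufunc__`, and a function without a registered implementation, such as `np.mean`, makes NumPy raise a `TypeError` instead of silently converting the data.

```python
import numpy as np

HANDLED_FUNCTIONS = {}

def implements(numpy_function):
    """Register a replacement implementation for a NumPy function."""
    def decorator(func):
        HANDLED_FUNCTIONS[numpy_function] = func
        return func
    return decorator

class TracedArray:
    """Hypothetical array type that records which NumPy calls it receives."""

    calls = []

    def __init__(self, data):
        self.data = np.asarray(data)

    # NEP 18: intercepts functions such as np.sum, np.concatenate, ...
    def __array_function__(self, func, types, args, kwargs):
        if func not in HANDLED_FUNCTIONS:
            return NotImplemented          # NumPy then raises TypeError
        TracedArray.calls.append(func.__name__)
        return HANDLED_FUNCTIONS[func](*args, **kwargs)

    # NEP 13: intercepts ufuncs such as np.add, np.exp, ...
    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
        if method != "__call__":
            return NotImplemented
        TracedArray.calls.append(ufunc.__name__)
        unwrapped = [x.data if isinstance(x, TracedArray) else x for x in inputs]
        return TracedArray(ufunc(*unwrapped, **kwargs))

@implements(np.sum)
def _sum(a, axis=None):
    return a.data.sum(axis=axis)

t = TracedArray([1.0, 2.0, 3.0])
print(np.sum(t))               # dispatched to _sum
print(np.add(t, 1).data)       # dispatched through __array_ufunc__
try:
    np.mean(t)
    mean_supported = True
except TypeError:              # no registered implementation
    mean_supported = False
print(TracedArray.calls, mean_supported)
```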
Mental Models: How to Think About Execution
The deepest gap in understanding these libraries is not syntax. It is execution model. Each library thinks about computation differently, and unless you internalize those mental models, you will write code that technically runs but performs poorly or fails in production at scale.
NumPy thinks in bulk operations. Every call to a NumPy function hands control to a compiled C loop that processes the entire array. Python never touches individual elements. The performance contract: one Python-level call, one fast C-level sweep. If you write a Python for loop over a NumPy array's elements, you are violating this contract and paying the interpreter overhead on every iteration.
CuPy thinks in kernel launches. Each CuPy operation dispatches a GPU kernel—a compiled function that runs across thousands of CUDA cores simultaneously. But launching a kernel has latency (tens of microseconds). If your operation does very little work per launch, the launch overhead dominates. CuPy rewards large, regular, compute-dense operations and punishes small, irregular ones.
JAX thinks in compiled computation graphs. When you call a jit-decorated function, JAX does not execute it line by line. It traces the function with abstract placeholder values, captures the entire computation as a graph, compiles it via XLA into optimized, fused kernels, and only then executes it. The first call is slow (compilation). Subsequent calls with the same input shapes reuse the compiled artifact and are fast. If your shapes change frequently, you are paying compilation cost repeatedly. JAX rewards static, repetitive computation patterns, exactly the pattern of training loops in machine learning.
Dask thinks in task graphs. Every operation on a Dask array adds nodes to a directed acyclic graph. No computation happens until .compute() is called. Dask then optimizes the graph, fusing operations where possible, schedules work across threads or machines, and keeps memory bounded by releasing intermediate chunks as soon as no remaining task needs them. This lazy evaluation means that Dask can handle datasets far larger than RAM, but the overhead of graph construction and scheduling makes it a poor choice for small, fast computations.
Numba thinks in compiled Python functions. Numba takes your Python bytecode and, at the first call, compiles it through LLVM into machine code specific to the input types. The compiled function replaces the Python function for subsequent calls. Numba's mental model is closest to a traditional compiler: write the code naturally, let the compiler optimize. But the compiler only understands a subset of Python: numerical code with typed arrays and scalars. Stray outside that subset and Numba raises a typing error; since version 0.59 it no longer silently falls back to slow object mode unless you explicitly request it.
These mental models are not just conceptual aids. They predict performance. If you understand that JAX traces symbolically, you understand why Python if statements inside a jit function behave differently: a condition on a static Python value is evaluated once during tracing, while a condition on a traced array value raises an error, because no concrete value exists yet. If you understand that CuPy launches kernels, you understand why transferring a 100-element array to the GPU and back is slower than just using NumPy. If you understand that Dask builds graphs lazily, you understand why chaining a hundred operations before calling .compute() is efficient: Dask can fuse them into fewer actual tasks.
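The difference is easy to reproduce. In this sketch, the first function (a hypothetical ReLU) uses a Python `if` on a traced value and fails under `jax.jit` with a `ConcretizationTypeError` (recent JAX versions raise a more specific subclass); the other two express the same branch with `jnp.where` or `jax.lax.cond`, which keeps the decision inside the compiled computation.

```python
import jax
import jax.numpy as jnp

@jax.jit
def relu_python_if(x):
    # A Python `if` needs a concrete bool, but under jit x is an abstract tracer
    if x > 0:
        return x
    return 0.0 * x

@jax.jit
def relu_traced(x):
    # jnp.where evaluates both branches and selects elementwise
    return jnp.where(x > 0, x, 0.0)

@jax.jit
def relu_lax_cond(x):
    # lax.cond compiles both branches and picks one at run time
    return jax.lax.cond(x > 0, lambda v: v, lambda v: 0.0 * v, x)

try:
    relu_python_if(jnp.float32(2.0))
    error_name = None
except jax.errors.ConcretizationTypeError as err:
    error_name = type(err).__name__

print(error_name)
print(relu_traced(jnp.array([-1.0, 2.0])))
print(relu_lax_cond(jnp.float32(-3.0)))
```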
Profiling and Diagnosing the Real Bottleneck
Choosing a library is only the first decision. The harder question is: is your performance problem actually where you think it is? Developers routinely migrate to GPU acceleration before confirming that their bottleneck is compute-bound rather than I/O-bound or memory-bound. This is like buying a faster car when your problem is traffic.
Before reaching for CuPy or JAX, profile. Python's built-in cProfile reveals where wall-clock time is spent. But for numerical code, it often shows that all the time is in a single NumPy call—which is not very illuminating. The real question is whether you are compute-bound, memory-bound, or transfer-bound.
import time
import numpy as np
import cupy as cp
# Profile data transfer vs. computation
data = np.random.randn(10_000, 10_000).astype(np.float32)
# 1. Measure just the CPU computation
start = time.perf_counter()
result = data @ data.T
cpu_time = time.perf_counter() - start
# Warm up the GPU so one-time costs (CUDA context creation,
# cuBLAS initialization) don't pollute the measurements below
warm = cp.asarray(data[:256])
warm @ warm.T
cp.cuda.Stream.null.synchronize()
# 2. Measure transfer + computation
start = time.perf_counter()
gpu_data = cp.asarray(data)  # CPU -> GPU transfer
gpu_result = gpu_data @ gpu_data.T  # GPU computation (launched asynchronously)
cp.cuda.Stream.null.synchronize()  # Wait for GPU to finish
gpu_total = time.perf_counter() - start
# 3. Measure just the transfer
start = time.perf_counter()
gpu_data = cp.asarray(data)
cp.cuda.Stream.null.synchronize()
transfer_time = time.perf_counter() - start
print(f"CPU compute: {cpu_time:.3f}s")
print(f"GPU total (transfer + compute): {gpu_total:.3f}s")
print(f"GPU transfer alone: {transfer_time:.3f}s")
If the transfer time dominates, GPU acceleration will not help unless you can keep data on the GPU for the entire pipeline. This is the single most important profiling insight for GPU computing: measure transfer and compute separately.
For JAX, the equivalent diagnostic is measuring JIT compilation time versus execution time. Call the function once to trigger compilation, then time subsequent calls:
import time
import jax
import jax.numpy as jnp
@jax.jit
def f(x):
    return jnp.linalg.svd(x, full_matrices=False)
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (5000, 5000))
# First call: includes tracing and XLA compilation
start = time.perf_counter()
jax.block_until_ready(f(x))  # svd returns a tuple, so block on the whole result
first_call = time.perf_counter() - start
# Second call: uses cached compiled code
start = time.perf_counter()
jax.block_until_ready(f(x))
second_call = time.perf_counter() - start
print(f"First call (compile + run): {first_call:.3f}s")
print(f"Second call (run only): {second_call:.3f}s")
For Dask, the diagnostic is whether your graph has become excessively complex. Dask's dashboard (accessible when using the distributed scheduler) visualizes task graphs, memory usage, and worker utilization in real time. If you see workers sitting idle while the scheduler plans, your graph is too complex—increase chunk sizes or reduce the number of chained operations before .compute().
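When the dashboard is not available, a quick proxy is to count the tasks in the graph before computing anything. This sketch builds the same expression on the same array shape with fine and coarse chunks; the `task_count` helper is hypothetical and relies on the graph supporting `len()`, which materializes it, so it is best used on moderately sized graphs.

```python
import dask.array as da

def task_count(arr):
    """Number of tasks in a Dask array's graph (materializes the graph)."""
    return len(arr.__dask_graph__())

shape = (20_000, 20_000)
fine = da.ones(shape, chunks=(500, 500))        # 40 x 40 = 1,600 chunks
coarse = da.ones(shape, chunks=(5_000, 5_000))  # 4 x 4 = 16 chunks

for name, x in (("fine", fine), ("coarse", coarse)):
    expr = ((x + 1) * 2 - x.T).mean(axis=0)     # nothing is computed here
    print(f"{name:>6}: {x.numblocks} blocks, {task_count(expr):,} tasks")
```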
For Numba, the first call to a decorated function includes compilation time. If you are benchmarking, always call the function once before timing. You can also pass an explicit type signature to @njit (a string such as "float64(float64[:])", or types from numba.types), which compiles eagerly when the decorator runs, and add cache=True, which stores the compiled machine code on disk so later runs of the script skip recompilation.
Memory Pitfalls Across Libraries
Memory management is where these libraries diverge sharply, and where production code breaks in ways that benchmarks never reveal.
CPU/GPU memory boundaries. CuPy arrays live in GPU memory (VRAM), which is typically 8–80 GB depending on the card. Calling cp.asnumpy() copies data from GPU to CPU memory. Calling cp.asarray() copies the other way. If you are not conscious of these transfers, you can saturate the PCIe bus and lose all GPU benefit. The discipline: once data moves to the GPU, keep it there for the entire pipeline. Only transfer back for final results or I/O.
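The keep-it-on-the-device discipline can be written so that it degrades gracefully. This sketch picks CuPy when a GPU is usable and NumPy otherwise, does one host-to-device copy at the start, keeps every intermediate step in the same library, and makes a single explicit copy back at the end; the `to_host` helper and the array sizes are illustrative.

```python
import numpy as np

try:
    import cupy as cp
    xp = cp if cp.cuda.runtime.getDeviceCount() > 0 else np
except Exception:          # CuPy not installed, or no usable GPU/driver
    cp, xp = None, np

def to_host(arr):
    """The single, explicit device-to-host copy at the end of the pipeline."""
    if cp is not None and isinstance(arr, cp.ndarray):
        return cp.asnumpy(arr)
    return arr

host_input = np.random.default_rng(0).standard_normal((2_000, 512)).astype(np.float32)

x = xp.asarray(host_input)                 # one host-to-device copy
x = (x - x.mean(axis=0)) / x.std(axis=0)   # every step stays on the device
gram = x.T @ x
eigvals = xp.linalg.eigvalsh(gram)
result = to_host(eigvals)                  # one device-to-host copy

print(type(result).__name__, result.shape)
```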
JAX's functional memory model. Because JAX arrays are immutable, every "modification" creates a new array. This can cause memory to spike if you are not careful:
import jax.numpy as jnp
# This creates three arrays in memory: x, intermediate, result
x = jnp.ones((10_000, 10_000))
intermediate = x + 1 # New allocation
result = intermediate * 2 # New allocation
# Inside @jax.jit, XLA fuses these into one operation
# and avoids the intermediate allocation.
# Outside @jit, you pay for every temporary.
The solution is to wrap multi-step computations in @jax.jit wherever possible. XLA's compiler fuses operations and eliminates temporary allocations automatically. If you are writing JAX code and not using jit, you are paying for Python-level allocations that the compiler would otherwise optimize away.
Dask's chunking discipline. Dask chunks that are too small create excessive scheduler overhead (millions of tiny tasks). Chunks that are too large defeat the purpose of out-of-core computing by exceeding available memory. The Dask documentation recommends chunks of roughly 100 MB as a starting point, but the right size depends on your available memory, the number of workers, and the complexity of your operations. Dask's .rechunk() method lets you adjust chunk sizes after creation.
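A sketch of adjusting chunk sizes after creation, assuming you want Dask to pick the shape: `rechunk("auto")` sizes chunks against the `array.chunk-size` configuration value, temporarily set here to 64 MiB with `dask.config.set`. Nothing is computed, so the 3.2 GB array below never exists in memory.

```python
import numpy as np
import dask
import dask.array as da

x = da.ones((20_000, 20_000), chunks=(500, 500))   # 1,600 chunks of about 2 MB

# "auto" chunking targets the array.chunk-size configuration value
with dask.config.set({"array.chunk-size": "64MiB"}):
    y = x.rechunk("auto")

def chunk_mib(arr):
    """Size of the largest chunk in MiB."""
    return np.prod(arr.chunksize) * arr.dtype.itemsize / 2**20

print(x.numblocks, f"{chunk_mib(x):.1f} MiB per chunk")
print(y.numblocks, f"{chunk_mib(y):.1f} MiB per chunk")
```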
Numba's memory allocation inside compiled functions. Numba can allocate NumPy arrays inside compiled functions (using np.empty, np.zeros, etc.). These allocations are handled by Numba's own runtime rather than the Python interpreter, but each one still costs an allocation and a later deallocation, which adds up when a small function is called millions of times or allocates inside a hot loop. For performance-critical code, pre-allocate output arrays outside the Numba function and pass them in as arguments. This is an easy optimization that many developers miss.
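A common failure mode: data crosses the CPU/GPU boundary, or fails to, when you mix libraries. CuPy refuses implicit conversion, so passing a CuPy array to a function that calls np.array() internally raises a TypeError rather than copying. Other array types are less strict: JAX arrays support NumPy's conversion protocol, so the same call silently copies the entire array from device to host and hands back a NumPy array, destroying performance with no visible error. Use cupy.get_array_module(x) or the Array API's array_namespace() to dispatch on the input's library instead of hard-coding np.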
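Array-type checks can also be made explicit. This sketch uses a hypothetical `preserves_backend` decorator that raises when a helper returns a different array type than it received: `np.asarray` on a JAX array quietly produces a NumPy array (copying it to host memory first if it lives on an accelerator), and the decorator turns that into a visible error. In real code, dispatching through `array_namespace()` or `cupy.get_array_module()` avoids the hard-coded `np` in the first place.

```python
import numpy as np
import jax.numpy as jnp

def preserves_backend(func):
    """Hypothetical decorator: raise if func returns a different array type."""
    def wrapper(x, *args, **kwargs):
        out = func(x, *args, **kwargs)
        if type(out) is not type(x):
            raise TypeError(
                f"{func.__name__} turned {type(x).__name__} into {type(out).__name__}"
            )
        return out
    return wrapper

@preserves_backend
def leaky_scale(x):
    return np.asarray(x) * 2      # hard-codes NumPy: silently converts

@preserves_backend
def portable_scale(x):
    return x * 2                  # stays in x's own library (and device)

x = jnp.arange(3.0)
print(type(portable_scale(x)).__name__)
try:
    leaky_scale(x)
    leak_detected = False
except TypeError as err:
    leak_detected = True
    print("caught:", err)
```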
How to Choose
There is no universal best library. The right choice depends on your bottleneck, and frequently the best solution is a combination:
- Your data fits in memory and you are on CPU. Use NumPy. It is the most mature, best documented, and has the widest downstream support. Add Numba if you have hot loops that cannot be vectorized. For free-threaded Python 3.14, NumPy 2.3+ has improved thread-safety, making multi-threaded NumPy workflows more viable than before.
- You need GPU acceleration for array operations. Use CuPy if your workflow is NumPy-centric and you want the simplest migration path. Use JAX if you also need automatic differentiation, JIT compilation, or plan to scale across multiple GPUs/TPUs. If you are on AMD hardware, CuPy v14 provides official upstream ROCm 7 wheels, and JAX also offers ROCm support through an AMD-supported plugin.
- Your data does not fit in memory, or you need distributed computing. Use Dask. It can wrap NumPy arrays for CPU workloads or CuPy arrays for GPU workloads, and scale across clusters. For data that fits in memory but benefits from parallelism, consider whether Numba's parallel=True option or NumPy with free-threaded Python might be a simpler solution before introducing Dask's graph-building overhead.
- You are building custom numerical algorithms with loops and branching. Use Numba for CPU compilation. Use Numba's CUDA backend or JAX for GPU. Consider whether your algorithm can be reformulated as vectorized operations first; Numba is for when it genuinely cannot.
- You want your code to work across multiple backends. Write against the Array API Standard using array-api-compat. This future-proofs your code as the ecosystem continues to evolve.
- You are doing machine learning research. JAX is the strongest choice for custom training loops, novel architectures, and research that demands composable differentiation. Its functional purity constraint, while initially unfamiliar, forces clean code that is inherently more reproducible.
A decision that many articles miss: you can and should combine these libraries. A realistic pipeline might load data with Dask (because it exceeds RAM), process chunks with CuPy (for GPU acceleration), train a model with JAX (for automatic differentiation), and use the Array API Standard to keep the pipeline portable. The interoperability infrastructure—DLPack, __array_function__, the Array API—exists precisely to make this kind of composition practical.
A Practical Example: The Same Computation, Four Ways
To make the differences concrete, here is the same computation—normalizing columns of a large matrix and computing column-wise variances—in four libraries:
# --- NumPy (CPU, single-threaded) ---
import numpy as np
data = np.random.randn(50_000, 1_000)
normalized = (data - data.mean(axis=0)) / data.std(axis=0)
variances = normalized.var(axis=0)
# --- CuPy (GPU) ---
import cupy as cp
data = cp.random.randn(50_000, 1_000)
normalized = (data - data.mean(axis=0)) / data.std(axis=0)
variances = normalized.var(axis=0)
# Note: identical syntax to NumPy
# --- JAX (JIT-compiled, GPU/TPU) ---
import jax.numpy as jnp
import jax
@jax.jit
def compute_variances(data):
    normalized = (data - jnp.mean(data, axis=0)) / jnp.std(data, axis=0)
    return jnp.var(normalized, axis=0)
key = jax.random.PRNGKey(0)
data = jax.random.normal(key, (50_000, 1_000))
variances = compute_variances(data)
# --- Dask (larger-than-memory, parallel) ---
import dask.array as da
data = da.random.normal(0, 1, size=(50_000, 1_000), chunks=(10_000, 1_000))
normalized = (data - data.mean(axis=0)) / data.std(axis=0)
variances = normalized.var(axis=0).compute()
# .compute() triggers actual execution
The syntax is almost identical across all four. The semantics differ: NumPy executes immediately, CuPy executes immediately on GPU, JAX compiles then executes, and Dask builds a task graph then executes on .compute(). Understanding those execution models—not just the API—is what separates competent use from effective use.
The Bigger Picture
The Python array computation ecosystem in 2026 is not fragmented by accident. It is fragmented because the problems are genuinely different. GPU computation, automatic differentiation, distributed computing, and JIT compilation are not the same problem, and no single library handles all of them optimally. What has changed is that the community now has infrastructure—the Array API Standard, DLPack, __array_function__, __array_ufunc__—to make these libraries work together rather than in isolation.
The Consortium's 2023 SciPy Proceedings paper (Reines et al.) describes how direct participation by library maintainers has accelerated ecosystem-wide coordination, and how growing adoption of the Array API Standard by array libraries has encouraged downstream projects like SciPy and scikit-learn to decouple their implementations from any single array backend.
The arrival of free-threaded Python 3.14 adds another dimension. Libraries that previously could not benefit from multi-threaded Python execution—because the GIL serialized everything—now have a path to genuine CPU parallelism within a single process. NumPy 2.3, Numba 0.63, and Dask 2026.1.2 have all added Python 3.14 support, and are actively improving their free-threading compatibility. This does not replace GPU acceleration for compute-dense workloads, but it does mean that the CPU-bound coordination code between GPU kernels, the data loading pipelines, and the post-processing stages can all run in parallel without resorting to multiprocessing.
There is a deeper lesson here that goes beyond any single library. The history of scientific computing is a history of abstraction layers: from FORTRAN subroutines to BLAS/LAPACK, from BLAS to NumPy, from NumPy to the Array API Standard. Each layer lets you think at a higher level while trusting the layer below to handle the details. The libraries covered in this article are the current front of that abstraction. The Array API Standard is the latest attempt to create a stable interface between what you mean (normalize these columns, compute this gradient) and how it gets done (on which hardware, with which compiler, across how many machines).
Write your code against the standard APIs. Understand the execution model of your chosen backend. Know when to switch. Profile before optimizing. And keep data where the computation happens. That is what real array programming in Python looks like.
Sources and further reading: Harris et al., "Array programming with NumPy," Nature 585, 357–362 (2020). Gommers, "Announcing the Consortium for Python Data API Standards," data-apis.org (August 17, 2020). Reines et al., "Python Array API Standard: Toward Array Interoperability in the Scientific Python Ecosystem," SciPy Proceedings (2023). "2024 release of the Array API Standard," data-apis.org (February 27, 2025). "CZI EOSS 6 Award to Advance Array Interoperability," data-apis.org (November 11, 2024). Google DeepMind, "Using JAX to accelerate our research," deepmind.google. Frostig et al., "Compiling machine learning programs via high-level tracing," SysML 2018. Okuta et al., "CuPy: A NumPy-Compatible Library for NVIDIA GPU Calculations," NeurIPS 2017 LearningSys Workshop. Maehashi, "Announcing CuPy v14," medium.com/cupy-team (February 2026). NumPy Interoperability documentation at numpy.org. Numba 0.61.0 release notes and Numba 0.63.0 release notes at numba.readthedocs.io. "Free-Threaded Python," docs.python.org/3/howto/free-threading-python.html. "Free-Threaded Python," docs.python.org/3.13/howto/free-threading-python.html. PEP 3118, PEP 465, PEP 703, PEP 779 at peps.python.org. "Python 3.14 What's New," docs.python.org/3/whatsnew/3.14.html.