[Help] How to get JAX + CUDA 12 working under Nix (Ubuntu 24.04, RTX 2080)

Hi everyone,

I’m trying to create a Nix flake that includes JAX with CUDA 12 support, but the build keeps failing.

I’ve tested both jaxlib-bin and jaxWithCuda, and neither works on my Ubuntu 24.04 machine.
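"Neither works" can mean either that the build fails or that a jaxlib builds but quietly falls back to the CPU. Once something does build, one way to tell these apart is to ask JAX which backend it selected. Below is a minimal sketch. It assumes the script is run with the Python interpreter from the flake's dev shell, and `jax_backend_report` is a hypothetical helper name, not part of JAX:

```python
"""Report which backend JAX selected (diagnostic sketch, not part of any flake)."""


def jax_backend_report():
    """Return (backend, device_strings), or (None, []) if jax cannot be imported."""
    try:
        import jax
    except ImportError:
        # jax (or its jaxlib) is not importable from this interpreter at all.
        return None, []
    # "gpu" means a CUDA-enabled jaxlib found a device; "cpu" means JAX is
    # running on the CPU, e.g. after falling back from a missing CUDA setup.
    backend = jax.default_backend()
    devices = [str(d) for d in jax.devices()]
    return backend, devices


if __name__ == "__main__":
    backend, devices = jax_backend_report()
    if backend is None:
        print("jax is not importable from this interpreter")
    else:
        print(f"default backend: {backend}")
        for d in devices:
            print(f"  {d}")
```

If the report says `cpu` on a machine with working GPUs, the problem is probably the runtime environment and not the build itself.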

Here are two Gists with details and build logs:

For comparison, I was able to get PyTorch running on the GPU (CUDA 12.2) in a similar Nix environment:

System info:

$ nix --version
nix (Nix) 2.20.6

$ lsb_release -a
No LSB modules are available.
Distributor ID: Ubuntu
Description:    Ubuntu 24.04.3 LTS
Release:        24.04
Codename:       noble

$ nvidia-smi
Fri Nov  7 00:33:29 2025       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.274.02             Driver Version: 535.274.02   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA GeForce RTX 2080        Off | 00000000:17:00.0 Off |                  N/A |
| 20%   34C    P8              16W / 215W |      3MiB /  8192MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
|   1  NVIDIA GeForce RTX 2080 Ti     Off | 00000000:65:00.0 Off |                  N/A |
| 32%   42C    P8              30W / 250W |      3MiB / 11264MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                                         
+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|  No running processes found                                                           |
+---------------------------------------------------------------------------------------+
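Note that the `CUDA Version: 12.2` reported by nvidia-smi is the highest CUDA version the 535 driver supports, not an installed toolkit. A Nix-built program still has to load the host driver's `libcuda.so.1` at runtime, and that library may not be on the loader's search path inside a Nix environment on Ubuntu. The sketch below checks this from Python by calling the driver API function `cuDriverGetVersion` through `ctypes`. It is a diagnostic, not a fix. The library name `libcuda.so.1` and the helper names `decode_cuda_version` and `host_driver_cuda_version` are assumptions made for illustration:

```python
import ctypes


def decode_cuda_version(raw):
    """Split the integer from cuDriverGetVersion (1000*major + 10*minor)."""
    return raw // 1000, (raw % 1000) // 10


def host_driver_cuda_version(libname="libcuda.so.1"):
    """Return (major, minor) supported by the loaded driver, or None.

    None means the driver library could not be loaded from this process
    (e.g. it is not on the dynamic loader's search path) or the call failed.
    """
    try:
        libcuda = ctypes.CDLL(libname)
    except OSError:
        return None
    version = ctypes.c_int(0)
    # CUresult cuDriverGetVersion(int *driverVersion); 0 is CUDA_SUCCESS.
    if libcuda.cuDriverGetVersion(ctypes.byref(version)) != 0:
        return None
    return decode_cuda_version(version.value)


if __name__ == "__main__":
    result = host_driver_cuda_version()
    if result is None:
        print("libcuda.so.1 could not be loaded from this environment")
    else:
        print("driver supports CUDA %d.%d" % result)
```

Running the script both inside and outside the Nix shell shows whether the environment can see the same driver that nvidia-smi reports.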

Has anyone successfully built JAX with CUDA 12 in a similar setup?
Any hints on which Nixpkgs revision or overlay might work would be greatly appreciated!

Best,
Herman