jaxlibWithCuda not using CUDA

Hello,

I’ve been trying to set up an environment with Jax and CUDA. My shell.nix is as follows, though I’ve tried variants:

{ pkgs ? import <nixpkgs> { } }:
pkgs.mkShell {
  buildInputs = with pkgs; [
    cudatoolkit
    python3
    python3Packages.jax
    python3Packages.jaxlibWithCuda
    python3Packages.huggingface-hub
    python3Packages.datasets
    python3Packages.wandb
    python3Packages.equinox
    python3Packages.optax
  ];
}

In the nix shell, nvidia-smi shows my GPU as expected. However, jaxlib seems to be compiled as CPU-only, despite my use of jaxlibWithCuda:

[nix-shell:~/Projects/ML/ramen]$ python -c 'import jax; print(jax.devices())'
An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
[CpuDevice(id=0)]

In my configuration.nix, I have:

  nixpkgs.config.allowUnfree = true;


  services.xserver.videoDrivers = ["nvidia"];
  hardware.opengl = {
    enable = true;
    driSupport = true;
    driSupport32Bit = true;
  };

  hardware.nvidia = {
    modesetting.enable = true;
    open = false;
    nvidiaSettings = true;
  };

  nixpkgs.config.cudaSupport = true;

  systemd.services.nvidia-control-devices = {
    wantedBy = [ "multi-user.target" ];
    serviceConfig.ExecStart = "${pkgs.linuxPackages.nvidia_x11.bin}/bin/nvidia-smi";
  };

Does anyone know what could be causing this?

Update:

Replacing the first line with { pkgs ? import <nixpkgs> { config.allowUnfree = true; config.cudaSupport = true; } }: causes a lot of package tests to run during the build.

The same error persists, though, and wandb now has to be removed because some of its tests fail.
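
For reference, here is a sketch of what the whole shell.nix might look like with that change applied. It is only the original file with the config attributes passed to the import and wandb dropped. It is not a confirmed fix, since the error above was still seen with this setup. The reason for repeating the options is that the nixpkgs.config settings in configuration.nix apply to the NixOS system package set, not to a separate import <nixpkgs> { } made from a shell.nix.

```nix
# shell.nix: sketch only, not confirmed to resolve the CPU fallback
{ pkgs ? import <nixpkgs> {
    # Repeated here because this import does not read the system-level
    # nixpkgs.config from configuration.nix.
    config = {
      allowUnfree = true;
      cudaSupport = true;
    };
  }
}:
pkgs.mkShell {
  buildInputs = with pkgs; [
    cudatoolkit
    python3
    python3Packages.jax
    python3Packages.jaxlibWithCuda
    python3Packages.huggingface-hub
    python3Packages.datasets
    # python3Packages.wandb  # removed: its tests fail with cudaSupport = true
    python3Packages.equinox
    python3Packages.optax
  ];
}
```

With cudaSupport enabled, the affected packages are likely to be built locally rather than fetched from the binary cache, which would explain the long test runs.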

I am having the same issue. On nixos-23.11, jaxlibWithCuda still falls back to the CPU. I get:

>>> import jax
>>> jax.devices()
An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
[CpuDevice(id=0)]