Hello,
I’ve been trying to set up an environment with JAX and CUDA. My shell.nix is below, though I’ve also tried variants of it:
{ pkgs ? import <nixpkgs> { } }:
pkgs.mkShell {
  buildInputs = with pkgs; [
    cudatoolkit
    python3
    python3Packages.jax
    python3Packages.jaxlibWithCuda
    python3Packages.huggingface-hub
    python3Packages.datasets
    python3Packages.wandb
    python3Packages.equinox
    python3Packages.optax
  ];
}
In the nix shell, nvidia-smi shows my GPU as expected. However, jaxlib appears to be a CPU-only build, even though I’m using jaxlibWithCuda:
[nix-shell:~/Projects/ML/ramen]$ python -c 'import jax; print(jax.devices())'
An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
[CpuDevice(id=0)]
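To narrow this down, a small script can list every jaxlib build visible inside the shell and report which one Python actually imports. In a nix-shell each Python package normally sits in its own store path, so if two jaxlib entries show up (for example one pulled in by python3Packages.jax and one from jaxlibWithCuda), the CPU one may be shadowing the CUDA one. This is a minimal sketch: it assumes the CUDA build registers under a distribution name starting with "jaxlib", which I haven’t verified, and it only reports what it finds rather than fixing anything.

```python
import importlib
import importlib.metadata as md


def jaxlib_copies() -> list[tuple[str, str, str]]:
    """Return (name, version, location) for each installed distribution whose
    name starts with "jaxlib" (assumption: the CUDA build uses that prefix)."""
    copies = []
    for dist in md.distributions():
        meta = dist.metadata
        name = (meta.get("Name") if meta is not None else None) or ""
        if name.lower().startswith("jaxlib"):
            # locate_file("") resolves to the site-packages directory holding the dist
            copies.append((name, str(dist.version), str(dist.locate_file(""))))
    return sorted(copies)


def loaded_jax() -> dict:
    """Report which jaxlib module Python imports and which backend JAX selects."""
    try:
        jax = importlib.import_module("jax")
        jaxlib = importlib.import_module("jaxlib")
    except ImportError as exc:
        return {"error": str(exc)}
    return {
        "jaxlib_file": jaxlib.__file__,
        "backend": jax.default_backend(),
        "devices": [str(d) for d in jax.devices()],
    }


if __name__ == "__main__":
    for name, version, location in jaxlib_copies():
        print(f"{name} {version} at {location}")
    print(loaded_jax())
```

Comparing the printed jaxlib_file against the listed locations would show which build wins on the import path.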
In my configuration.nix, I have:
nixpkgs.config.allowUnfree = true;
services.xserver.videoDrivers = [ "nvidia" ];
hardware.opengl = {
  enable = true;
  driSupport = true;
  driSupport32Bit = true;
};
hardware.nvidia = {
  modesetting.enable = true;
  open = false;
  nvidiaSettings = true;
};
nixpkgs.config.cudaSupport = true;
systemd.services.nvidia-control-devices = {
  wantedBy = [ "multi-user.target" ];
  serviceConfig.ExecStart = "${pkgs.linuxPackages.nvidia_x11.bin}/bin/nvidia-smi";
};
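As a separate sanity check on the driver side of this config: nvidia-smi working shows the kernel driver is up, but a CUDA-enabled jaxlib also has to load the driver’s libcuda at runtime. The sketch below tries to load it from /run/opengl-driver/lib, assuming that is where NixOS exposes the driver libraries when hardware.opengl.enable = true. The path and the function name check_libcuda are assumptions for illustration, not something from my setup.

```python
import ctypes
import os

# Assumption: with hardware.opengl.enable = true, NixOS links the NVIDIA
# driver's user-space libraries (including libcuda.so.1) into this directory.
DEFAULT_LIBCUDA = "/run/opengl-driver/lib/libcuda.so.1"


def check_libcuda(path: str = DEFAULT_LIBCUDA) -> tuple[bool, str]:
    """Return (ok, message) describing whether libcuda at `path` can be loaded."""
    if not os.path.exists(path):
        return False, f"{path} does not exist"
    try:
        # dlopen the library; this raises OSError if it or its deps can't load
        ctypes.CDLL(path)
    except OSError as exc:
        return False, f"{path} exists but could not be loaded: {exc}"
    return True, f"{path} loaded successfully"


if __name__ == "__main__":
    ok, message = check_libcuda()
    print(message)
```

If this loads fine, the driver side is probably not the problem and the jaxlib build itself is the more likely culprit.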
Does anyone know what could be causing this?