CUDA working with poetry2nix

Has anyone managed to get CUDA working with poetry2nix? I’ve seen a number of questions related to CUDA and python but can’t find a solution with poetry.

I have set up nvidia globally and nvidia-smi successfully shows my GPU. I also followed this nixos + cuda + docker guide and can successfully use my GPU, both on the host and inside docker (podman).

Running an official pytorch docker image with docker run -ti --rm --gpus all pytorch/pytorch:2.5.1-cuda12.4-cudnn9-runtime, pytorch finds the GPU with import torch; print(torch.cuda.is_available()). If I install jax with cuda inside the container, it also finds the GPU fine.

I cannot get jax or pytorch to find it with a flake using poetry2nix. Trying jax.devices() gives the error

RuntimeError: Unable to initialize backend 'cuda': FAILED_PRECONDITION: No visible GPU devices. (you may need to uninstall the failing plugin package, or set JAX_PLATFORMS=cpu to skip this backend.)

Using a flake.nix:

{
  description = "Test CUDA and poetry2nix";

  inputs = {
    nixpkgs.url = "github:nixos/nixpkgs";
    poetry2nix = {
      url = "github:nix-community/poetry2nix";
      inputs.nixpkgs.follows = "nixpkgs";
    };
  };

  outputs = { self, nixpkgs, poetry2nix }:
    let
      system = "x86_64-linux";
      pkgs = import nixpkgs {
        inherit system;
        overlays = [ poetry2nix.overlays.default ];
        config.allowUnfreePredicate = pkg:
          builtins.elem (pkgs.lib.getName pkg) [
            "cuda-merged"
            "cuda_cuobjdump"
            "cudnn"
            "cuda_gdb"
            "cuda_nvcc"
            "cuda_nvdisasm"
            "cuda_nvprune"
            "cuda_cccl"
            "cuda_cudart"
            "cuda_cupti"
            "cuda_cuxxfilt"
            "cuda_nvml_dev"
            "cuda_nvrtc"
            "cuda_nvtx"
            "cuda_profiler_api"
            "cuda_sanitizer_api"
            "libcublas"
            "libcufft"
            "libcurand"
            "libcusolver"
            "libnvjitlink"
            "libcusparse"
            "libnpp"
            "nvidia-settings"
            "nvidia-x11"
          ];
      };
      pyEnv = pkgs.poetry2nix.mkPoetryEnv {
        projectDir = ./.;
        editablePackageSources = { foo = ./.; };
        preferWheels = true;
      };
    in {
      devShells.${system}.default = pkgs.mkShell {
        # Not sure what build inputs I need, but I've seen these in various
        # places related to CUDA + python.
        buildInputs = with pkgs; [
          cudatoolkit
          linuxPackages.nvidia_x11
          cudaPackages.cudnn
          libGLU
          libGL
          xorg.libXi
          xorg.libXmu
          freeglut
          xorg.libXext
          xorg.libX11
          xorg.libXv
          xorg.libXrandr
          zlib
          ncurses5
          stdenv.cc
          binutils
        ];
        packages = [ pyEnv pkgs.poetry ];

        # Similarly to the build inputs, I've seen various env variables that
        # might need to be set.
        shellHook = ''
          # export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:${pkgs.cudatoolkit}/targets/x86_64-linux/lib"
          # export EXTRA_LDFLAGS="-L/lib -L${pkgs.linuxPackages.nvidia_x11}/lib"
          # export CUDA_PATH=${pkgs.cudatoolkit}
          export LD_LIBRARY_PATH="${pkgs.linuxPackages.nvidia_x11}/lib"
          # export EXTRA_CCFLAGS="-I/usr/include"
        '';
      };
    };
}

And a pyproject.toml:

[tool.poetry]
name = "foo"
version = "0.1.0"
description = "Use CUDA"
authors = ["me"]
packages = [{ include = "foo" }]

[tool.poetry.dependencies]
python = "^3.10,<3.13"
jax = {version = "^0.4.35", extras = ["cuda12"]}
torch = "^2.5.1"

[tool.poetry.group.dev.dependencies]
ipython = "^8.16.1"

And the relevant lines from my global config:

  services.xserver = {
    videoDrivers = [ "nvidia" ];
  };

  hardware = {
    graphics.enable = true;
    nvidia = {
      modesetting.enable = true;
      powerManagement.enable = false;
      powerManagement.finegrained = true;
      open = true;
      nvidiaSettings = true;
      package = config.boot.kernelPackages.nvidiaPackages.stable;
      prime = {
        offload.enable = true;
        intelBusId = "PCI:0:2:0";
        nvidiaBusId = "PCI:1:0:0";
      };
    };
    nvidia-container-toolkit.enable = true;
  };

Within this flake, running poetry lock to generate the poetry.lock file, building the shell, and running ipython -c 'import jax; print(jax.devices())' gives the error above.

A few updates:

If I change the pyproject file to use jax without the cuda12 extra, when I check jax.devices() I get the warning “An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.” So it seems jax is aware of the GPU; it’s just failing to initialize it correctly.

Confirming this a bit, I had previously missed the error:

E external/xla/xla/stream_executor/cuda/cuda_platform.cc:51] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)

This is the actual cause of jax not getting the GPU.

If I go the python3.withPackages route and set cudaSupport = true; in pkgs, the resulting jax finds the GPU. (It seems the jax wiki is out of date: you no longer need to include jaxlib (or jaxlibWithCuda, which no longer exists in unstable); see this release note.)
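
Roughly, the withPackages route that worked looks like this. This is a sketch from memory, reusing the nixpkgs input from the flake above; exact attribute names and the python version may differ:

let
  pkgs = import nixpkgs {
    system = "x86_64-linux";
    config = {
      allowUnfree = true;
      # cudaSupport makes nixpkgs provide the CUDA-enabled jaxlib
      cudaSupport = true;
    };
  };
  pyEnv = pkgs.python311.withPackages (ps: [ ps.jax ps.ipython ]);
in
pkgs.mkShell { packages = [ pyEnv ]; }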

I don’t think the cudaSupport option has any impact with poetry2nix, but in any case it still fails when this is set.

Building from source with poetry2nix (i.e. setting preferWheels = false;) also fails.

The issue persists across python versions (at least 3.11–3.13), as well as if I downgrade jax.

Hi David, I was stumbling on almost exactly the same thing (except I switched over from poetry2nix to uv2nix) and it took me a while to solve this (strace is your friend).

Your flake looks alright, but your LD_LIBRARY_PATH seems incomplete: it does not include cudatoolkit, hence JAX (or rather, XLA) won’t find nvlink or nvcc. Also, you need to make sure the cuda libraries exported by python are discoverable by the linker. You were on the right track, as the correct definition (although incomplete, due to the missing python libs) was commented out just two lines above.
Basically something like this:

let
  # ...
in pkgs.mkShell {
  packages = with pkgs.cudaPackages; [ pyEnv cudatoolkit ];
  nativeBuildInputs = with pkgs; [ linuxPackages.nvidia_x11 ];
  env = {
    LD_LIBRARY_PATH = "${pkgs.linuxPackages.nvidia_x11}/lib:${pkgs.cudaPackages.cudatoolkit}/lib:${pyEnv}/lib";
  };
}
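
Slotted into your original flake (reusing your pkgs and pyEnv as defined above), the devShell would look roughly like this. Treat it as a sketch; the exact package list may need adjusting for your setup:

devShells.${system}.default = pkgs.mkShell {
  packages = [ pyEnv pkgs.poetry pkgs.cudaPackages.cudatoolkit ];
  nativeBuildInputs = [ pkgs.linuxPackages.nvidia_x11 ];
  env = {
    # driver libs (libcuda) + toolkit libs + whatever the python
    # environment itself exposes under lib/
    LD_LIBRARY_PATH = "${pkgs.linuxPackages.nvidia_x11}/lib:${pkgs.cudaPackages.cudatoolkit}/lib:${pyEnv}/lib";
  };
};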

Note for those using uv: you need to do some mangling to get the lib*.so files in each nvidia-*-cu11 package right under ${nvidia-*-cu11}/site-packages/lib. I need to polish my gist a bit so it stops looking like something written by a script kiddie.


Hey thanks, I think this gets me in the right direction. I’m still getting the same error. I’m also considering moving to uv (it seems to be replacing poetry), but I’m not ready to go for it yet. In the meantime, I set up the dependencies with python3Packages. Not ideal, but it’s good enough for now. I think I’ll make the switch to uv when I’m ready to put a little more time into it; hopefully your gist will work for that.

I hadn’t seen the env attribute for mkShell. That looks much nicer than using shellHook.
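
For anyone else who hadn’t seen it, the difference is roughly the following. A minimal sketch, assuming pkgs is in scope as in the flake above; viaHook and viaEnv are just placeholder names:

{
  # shellHook: arbitrary shell code that runs each time the dev shell is entered
  viaHook = pkgs.mkShell {
    shellHook = ''
      export LD_LIBRARY_PATH="${pkgs.linuxPackages.nvidia_x11}/lib"
    '';
  };

  # env: plain environment variables set on the shell derivation itself
  viaEnv = pkgs.mkShell {
    env.LD_LIBRARY_PATH = "${pkgs.linuxPackages.nvidia_x11}/lib";
  };
}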

Hi David,

sorry for taking so long.

Have a look at this:

(edited to remove personal information)

A few remarks:

  1. I used to use jax[cuda12], but I thought that having two copies of the cuda libraries (the pip ones and the system ones, part of which was unavoidable due to nvidia_x11) was quite wasteful, and it also relied on ugly patchelf-ing overrides straight from poetry2nix which I had to copy over, so I went for jax[cuda12-local]. You can reasonably ignore this step if you are still using poetry2nix, since the overrides will still be available. If you plan on using jax[cuda12] with uv2nix instead, you’ll have to copy the overrides over manually, but I believe this goes against uv2nix’s philosophy.
  2. As a consequence, the XLA_FLAGS variable is not needed if you are not going to use the nixpkgs CUDA libraries (see the sketch after this list).
  3. (EDIT for clarity) A trick you can always use is hacks, which I used for pyqt and numpy, and which people have tried with torch; but I prefer my solution because it does not require such a hack on JAX and is officially supported (see point 1).
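
Concretely, with jax[cuda12-local] the shell ends up looking roughly like the following. This is only a sketch (my actual setup lived in the gist above, and reuses your pyEnv naming); the XLA_FLAGS line assumes XLA’s --xla_gpu_cuda_data_dir flag and is only relevant here because the CUDA libraries come from nixpkgs:

pkgs.mkShell {
  packages = [ pyEnv pkgs.cudaPackages.cudatoolkit pkgs.cudaPackages.cudnn ];
  nativeBuildInputs = [ pkgs.linuxPackages.nvidia_x11 ];
  env = {
    LD_LIBRARY_PATH = pkgs.lib.makeLibraryPath [
      pkgs.linuxPackages.nvidia_x11
      pkgs.cudaPackages.cudatoolkit
      pkgs.cudaPackages.cudnn
    ];
    # Point XLA at the nixpkgs CUDA installation. Only relevant for
    # cuda12-local; with the pip-shipped jax[cuda12] wheels this is not
    # needed (see point 2).
    XLA_FLAGS = "--xla_gpu_cuda_data_dir=${pkgs.cudaPackages.cudatoolkit}";
  };
}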

Hey thanks a lot, it really helped in getting a working environment. I converted to uv. I spent a few hours on it but could not get cuda working with uv2nix; the same error as previously, even with the gist. Not sure what I’m doing wrong.

But I went to base pyproject-nix and got it working. I believe this uses python3Packages, but the nice thing is that it still reads the packages off pyproject.toml, which is all I really needed: previously I was having to keep the python dependencies listed in flake.nix in sync with those in pyproject.toml.
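
In case it helps anyone, the pyproject-nix side ended up looking roughly like this. A sketch from memory, so take it with a grain of salt: it assumes pyproject.toml now uses standard [project] metadata, that pyproject-nix is a flake input, that pkgs is in scope, and that loadPyproject / renderers.withPackages are still the right entry points:

let
  # Reads the dependency list straight from pyproject.toml ...
  project = pyproject-nix.lib.project.loadPyproject { projectRoot = ./.; };
  python = pkgs.python312;
  # ... and resolves it against nixpkgs' python3Packages.
  pythonEnv = python.withPackages (project.renderers.withPackages { inherit python; });
in
pkgs.mkShell { packages = [ pythonEnv ]; }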

I spent a few hours on it but could not get cuda working with uv2nix; the same error as previously, even with the gist. Not sure what I’m doing wrong.

Are you using jax[cuda12-local] or jax[cuda12]? Would you mind sharing your flake and pyproject.toml file?

But I went to base pyproject-nix and got it working. I believe this uses python3Packages, but the nice thing is that it still reads the packages off pyproject.toml, which is all I really needed: previously I was having to keep the python dependencies listed in flake.nix in sync with those in pyproject.toml.

I am happy you got it working, but you likely got this backwards: by default it creates new derivations based on what pyproject.toml declares, and only fetches what’s in nixpkgs if you explicitly tell it to (hence the hacks). However, dependencies may be in a different ballpark; I can’t say.