Can't get PyTorch to recognise override

I’m really struggling to figure out how to get CUDA 11 in PyTorch. No matter what combination of overrides I try, I always get 10.2, and 10.2 doesn’t work with my 3060 Ti. I’m on NixOS unstable, and the project uses a flake. The flake.nix is defined as:

{
  description = "Project";

  inputs.nixpkgs.url = "github:NixOS/nixpkgs/nixos-unstable";
  inputs.flake-utils.url = "github:numtide/flake-utils";

  outputs = { self, nixpkgs, flake-utils }:
    flake-utils.lib.eachDefaultSystem (system:
      let
        pkgs = import nixpkgs {
          inherit system;
          config = {
            allowUnfree = true;
            cudaSupport = true;
          };
        };
      in { devShell = import ./shell.nix { inherit pkgs; }; });
}

The shell.nix is defined as:

{ pkgs ? import <nixpkgs> { } }:
let
  cudatoolkit = pkgs.cudaPackages.cudatoolkit_11;
  cudnn = pkgs.cudnn_cudatoolkit_11;
  nccl = pkgs.nccl_cudatoolkit_11;
  magma = pkgs.magma.override { cudatoolkit = cudatoolkit; };

  python = pkgs.python38.withPackages (ps:
    with ps; [
      (pytorch.override
      {
        cudaSupport = true;
        cudatoolkit = cudatoolkit;
        cudnn = cudnn;
        nccl = nccl;
        magma = magma;
      })
      pytorch-lightning
      transformers
      ipython
      mypy
      black
      flake8
    ]);

in pkgs.mkShell {
  buildInputs =
    [ pkgs.linuxPackages.nvidia_x11 cudatoolkit cudnn nccl python magma ];
}

If anyone could help me out, I would really appreciate it. Each time I change something to try a new approach, I have to wait for PyTorch to recompile, which is incredibly slow.

Did you ever figure this out? I’m having the same issue.

I have never used flakes, but could this be happening because, while you overrode CUDA in the expression you use to install pytorch, you didn’t override the definition of pytorch itself, so every other package that depends on pytorch (e.g. pytorch-lightning) pulled the standard pytorch package from the binary cache? Do flakes have overlays?
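
On the overlay question: a flake that imports nixpkgs itself, as the flake.nix above does, can pass an overlays list to that import. The following is a minimal, untested sketch of what that might look like. It assumes the attribute names from the question's shell.nix (cudaPackages.cudatoolkit_11, cudnn_cudatoolkit_11, nccl_cudatoolkit_11) exist in the pinned nixpkgs revision; cuda11Overlay is a hypothetical name chosen for illustration.

```nix
{
  description = "Project";

  inputs.nixpkgs.url = "github:NixOS/nixpkgs/nixos-unstable";
  inputs.flake-utils.url = "github:numtide/flake-utils";

  outputs = { self, nixpkgs, flake-utils }:
    flake-utils.lib.eachDefaultSystem (system:
      let
        # Hypothetical overlay: redefines pytorch inside the python38 package
        # set, so packages that depend on pytorch see the same CUDA 11 build.
        cuda11Overlay = final: prev: {
          python38 = prev.python38.override {
            packageOverrides = pyFinal: pyPrev: {
              pytorch = pyPrev.pytorch.override {
                cudaSupport = true;
                # Attribute names taken from the question's shell.nix.
                cudatoolkit = final.cudaPackages.cudatoolkit_11;
                cudnn = final.cudnn_cudatoolkit_11;
                nccl = final.nccl_cudatoolkit_11;
                magma = final.magma.override {
                  cudatoolkit = final.cudaPackages.cudatoolkit_11;
                };
              };
            };
          };
        };

        pkgs = import nixpkgs {
          inherit system;
          config = {
            allowUnfree = true;
            cudaSupport = true;
          };
          overlays = [ cuda11Overlay ];
        };
      in { devShell = import ./shell.nix { inherit pkgs; }; });
}
```

With an overlay like this, shell.nix would list plain pytorch inside withPackages instead of calling pytorch.override there.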

I think @alexv could be right.

I was in the same situation. In my case, removing opencv revealed that it was the package pulling in CUDA 10.2, despite my overriding everything for pytorch. My guess is that pytorch-lightning is doing the same here.

Alexv is most likely right: pytorch-lightning must be bringing a different version of pytorch into your environment. You probably want to use packageOverrides for the Python package set:

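let
  py = pkgs.python3.override {
    # Replace pytorch for the whole Python package set, so that dependents
    # such as pytorch-lightning pick up the overridden build as well.
    packageOverrides = python-final: python-prev: {
      pytorch = python-prev.pytorch.override {
        blas = ...;
        cudaSupport = ...;
        cudaArchList = ...;
      };
    };
    self = py;
  };
in py.withPackages (ps: ...);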
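
Applied to the shell.nix from the question, the same idea might look like the sketch below. It is untested and reuses the question's attribute names and override arguments as-is, so it assumes they are valid for the nixpkgs revision in use; pythonEnv is a name introduced here for illustration.

```nix
{ pkgs ? import <nixpkgs> { config.allowUnfree = true; } }:
let
  # Attribute names copied from the question; they may differ in other
  # nixpkgs revisions.
  cudatoolkit = pkgs.cudaPackages.cudatoolkit_11;
  cudnn = pkgs.cudnn_cudatoolkit_11;
  nccl = pkgs.nccl_cudatoolkit_11;
  magma = pkgs.magma.override { inherit cudatoolkit; };

  # Override pytorch at the package-set level rather than inside
  # withPackages, so pytorch-lightning and transformers use the same build.
  python = pkgs.python38.override {
    packageOverrides = python-final: python-prev: {
      pytorch = python-prev.pytorch.override {
        cudaSupport = true;
        inherit cudatoolkit cudnn nccl magma;
      };
    };
    self = python;
  };

  # Hypothetical name for the resulting environment.
  pythonEnv = python.withPackages (ps:
    with ps; [
      pytorch
      pytorch-lightning
      transformers
      ipython
    ]);

in pkgs.mkShell {
  buildInputs =
    [ pkgs.linuxPackages.nvidia_x11 cudatoolkit cudnn nccl magma pythonEnv ];
}
```

The difference from the original shell.nix is that pytorch is listed plainly in withPackages, and the override lives in packageOverrides where every dependent package can see it.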

You can also verify whether that’s the case by visualizing the dependency tree, e.g. with nix-tree (utdemir/nix-tree on GitHub), which lets you interactively browse dependency graphs of Nix derivations.

Actually, I think I ran into exactly the same problem recently. The UX of substituting the CUDA version in pytorch etc. could probably see some improvement…

Wow, that’s a wonderful tool. This is going to save me so much time.
