How to use an AMD GPU with Python and PyTorch

I am using Python and PyTorch, and I want to use my AMD GPU (RX 5600). How can I modify this flake.nix so that I can install PyTorch with GPU (ROCm) support inside a virtualenv?

{
  description = "A Python project";

  inputs = {
    nixpkgs.url = "github:nixos/nixpkgs/nixos-unstable";
  };

  # Development shell for a Python project whose packages (for example a
  # ROCm build of PyTorch) are installed with pip into a local ./venv
  # rather than by Nix.
  outputs = { self, nixpkgs }:
    let
      system = "x86_64-linux";
      pkgs = nixpkgs.legacyPackages.${system};

      # Python interpreter plus `virtualenv`, used to create ./venv.
      pyEnv = pkgs.python3.withPackages (ps: with ps; [
        virtualenv
      ]);

      buildInputs = with pkgs; [
        pyEnv
        zlib
      ];

      # Library directories exported to binaries running in the shell:
      # the compiler runtime (libstdc++) first, then the lib dirs of
      # buildInputs -- the same order the two original exports produced.
      libraryPath = pkgs.lib.makeLibraryPath ([ pkgs.stdenv.cc.cc.lib ] ++ buildInputs);

      shell = pkgs.mkShell {
        inherit buildInputs;
        shellHook = ''
          # Prepend the Nix library path. The :+ expansion only appends the
          # old value (with its separator) when it is non-empty; a trailing
          # empty entry would make the loader search the current directory.
          export LD_LIBRARY_PATH="${libraryPath}''${LD_LIBRARY_PATH:+:$LD_LIBRARY_PATH}"

          # Create the virtualenv on first entry instead of failing on a
          # missing venv/bin/activate.
          if [ ! -f venv/bin/activate ]; then
            virtualenv venv
          fi
          source venv/bin/activate

          # GPU support: install a ROCm build of PyTorch inside the venv,
          # choosing the current rocmX.Y index listed on pytorch.org, e.g.
          #   pip install torch --index-url https://download.pytorch.org/whl/rocm6.0
          # NOTE(review): the RX 5600 (gfx1010) is not on the official ROCm
          # support list; users commonly report needing an override such as
          #   export HSA_OVERRIDE_GFX_VERSION=10.3.0
          # This is unverified for this card -- test before relying on it.
        '';
      };
    in
    {
      # Current flake output schema; picked up by `nix develop`.
      devShells.${system}.default = shell;
      # Deprecated attribute, kept so existing `.#devShell.<system>` users still work.
      devShell.${system} = shell;
    };
}
1 Like