diff --git a/dev_env_builder.nix b/dev_env_builder.nix index 5b6a14cd..08ebb18a 100644 --- a/dev_env_builder.nix +++ b/dev_env_builder.nix @@ -1,40 +1,43 @@ -{ nixpkgs -, system -, cudaSupport -, rocmSupport -}: +{ nixpkgs, system, cudaSupport ? false, rocmSupport ? false +, vulkanSupport ? false }: let pkgs = import nixpkgs { inherit system; config.allowUnfree = true; config.cudaSupport = cudaSupport; config.rocmSupport = rocmSupport; + config.vulkanSupport = vulkanSupport; }; - pythonPackages = pkgs.python3Packages; -in -pkgs.mkShell rec { +in pkgs.mkShell { name = "impureEvoXPythonEnv"; venvDir = "./.venv"; - buildInputs = with pythonPackages; [ - python - # This executes some shell code to initialize a venv in $venvDir before - # dropping into the shell - venvShellHook + nativeBuildInputs = with pkgs; [ + (python313.withPackages (py-pkgs: + with py-pkgs; [ + # This executes some shell code to initialize a venv in $venvDir before + # dropping into the shell + venvShellHook - # Those are dependencies that we would like to use from nixpkgs, which will - # add them to PYTHONPATH and thus make them accessible from within the venv. - numpy - torch - torchvision - ] ++ (with pkgs; [ + # Those are dependencies that we would like to use from nixpkgs, which will + # add them to PYTHONPATH and thus make them accessible from within the venv. + numpy + (if cudaSupport then + torchWithCuda + else if rocmSupport then + torchWithRocm + else if vulkanSupport then + torchWithVulkan + else + torch) + torchvision + ])) pre-commit ruff - ]); + ]; # Run this command, only after creating the virtual environment postVenvCreation = '' unset SOURCE_DATE_EPOCH - pip install -e . ''; # Now we can execute any commands within the virtual environment. @@ -43,5 +46,4 @@ pkgs.mkShell rec { # allow pip to install wheels unset SOURCE_DATE_EPOCH ''; - } diff --git a/flake.nix b/flake.nix index f8608ddc..4fa28452 100644 --- a/flake.nix +++ b/flake.nix @@ -1,7 +1,7 @@ { description = "evox"; inputs = { - nixpkgs.url = "nixpkgs/nixos-24.11"; + nixpkgs.url = "nixpkgs/nixos-unstable"; utils.url = "github:numtide/flake-utils"; }; @@ -11,15 +11,24 @@ eachSystem (with system; [ x86_64-linux ]) (system: let builder = import ./dev_env_builder.nix; - cuda-env = builder { inherit system nixpkgs; cudaSupport = true; rocmSupport = false; }; - rocm-env = builder { inherit system nixpkgs; cudaSupport = false; rocmSupport = true; }; - cpu-env = builder { inherit system nixpkgs; cudaSupport = false; rocmSupport = false; }; - in - { + cuda-env = builder { + inherit system nixpkgs; + cudaSupport = true; + }; + rocm-env = builder { + inherit system nixpkgs; + rocmSupport = true; + }; + vulkan-env = builder { + inherit system nixpkgs; + vulkanSupport = true; + }; + cpu-env = builder { inherit system nixpkgs; }; + in { devShells.default = cpu-env; devShells.cpu = cpu-env; devShells.cuda = cuda-env; devShells.rocm = rocm-env; - } - ); + devShells.vulkan = vulkan-env; + }); }