diff --git a/cmake/gauxc-onedft.cmake b/cmake/gauxc-onedft.cmake index 6e799849..4fbcae8d 100644 --- a/cmake/gauxc-onedft.cmake +++ b/cmake/gauxc-onedft.cmake @@ -24,7 +24,9 @@ endif() # store and restore CMAKE_CUDA_ARCHITECTURES if Torch clobbers it set(_PREV_CUDA_ARCHS "${CMAKE_CUDA_ARCHITECTURES}") -find_package(Torch REQUIRED) + +include(skala-torch) + if(CMAKE_CUDA_ARCHITECTURES STREQUAL "OFF") set(CMAKE_CUDA_ARCHITECTURES "${_PREV_CUDA_ARCHS}" CACHE STRING "Restore CUDA archs after Torch override" FORCE) message(WARNING "Torch set CMAKE_CUDA_ARCHITECTURES to OFF. Restored previous value: ${CMAKE_CUDA_ARCHITECTURES}") diff --git a/cmake/skala-torch.cmake b/cmake/skala-torch.cmake new file mode 100644 index 00000000..6e6214d0 --- /dev/null +++ b/cmake/skala-torch.cmake @@ -0,0 +1,76 @@ +# Find or download LibTorch +find_package(Torch QUIET) + +if(NOT Torch_FOUND) + message(STATUS "Torch not found. Downloading libtorch...") + + # Set libtorch version and download URL + set(LIBTORCH_VERSION "2.9.1") + set(USE_CUDA_LIBTORCH FALSE) #default is not to use the cuda version but cpu version + if(GAUXC_HAS_CUDA) + find_package(CUDAToolkit) + set(SUPPORTED_CUDA_VERSION_NO_DOTS "126" "128" "130") + set(CUDA_VERSION_NO_DOT ${CUDAToolkit_VERSION_MAJOR}${CUDAToolkit_VERSION_MINOR}) + if(CUDA_VERSION_NO_DOT IN_LIST SUPPORTED_CUDA_VERSION_NO_DOTS) + find_package(CUDA REQUIRED) # Needed for Caffe cmake, since it use the old CUDA inclusion + set(USE_CUDA_LIBTORCH TRUE) + message(STATUS "CUDAToolkit_INCLUDE_DIRS: ${CUDAToolkit_INCLUDE_DIRS}") + set(CUDA_INCLUDE_DIRS "${CUDAToolkit_INCLUDE_DIRS}") # and pass over the correct headers to Caffe cmake + else() + message(WARNING "CUDA toolkit version is ${CUDAToolkit_VERSION}, for which there is no libtorch to download.") + message(WARNING "Falling back to cpu version of libtorch.") + endif() + endif() + + + # Determine the appropriate libtorch variant based on platform and CUDA availability + if(UNIX AND NOT APPLE) + # Linux + if(USE_CUDA_LIBTORCH) + set(LIBTORCH_URL "https://download.pytorch.org/libtorch/cu${CUDA_VERSION_NO_DOT}/libtorch-shared-with-deps-${LIBTORCH_VERSION}%2Bcu${CUDA_VERSION_NO_DOT}.zip") + else() + set(LIBTORCH_URL "https://download.pytorch.org/libtorch/cpu/libtorch-shared-with-deps-${LIBTORCH_VERSION}%2Bcpu.zip") + endif() + elseif(APPLE) + # macOS (CPU only) + set(LIBTORCH_URL "https://download.pytorch.org/libtorch/cpu/libtorch-macos-arm64-${LIBTORCH_VERSION}.zip") + elseif(WIN32) + # Windows + if(USE_CUDA_LIBTORCH) + set(LIBTORCH_URL "https://download.pytorch.org/libtorch/cu${CUDA_VERSION_NO_DOT}/libtorch-win-shared-with-deps-${LIBTORCH_VERSION}%2Bcu${CUDA_VERSION_NO_DOT}.zip") + else() + set(LIBTORCH_URL "https://download.pytorch.org/libtorch/cpu/libtorch-win-shared-with-deps-${LIBTORCH_VERSION}%2Bcpu.zip") + endif() + endif() + + set(LIBTORCH_DOWNLOAD_DIR "${CMAKE_BINARY_DIR}/libtorch") + set(LIBTORCH_ZIP "${CMAKE_BINARY_DIR}/libtorch.zip") + + # Download libtorch + if(NOT EXISTS ${LIBTORCH_DOWNLOAD_DIR}) + message(STATUS "Downloading libtorch from ${LIBTORCH_URL}") + file(DOWNLOAD ${LIBTORCH_URL} ${LIBTORCH_ZIP} + SHOW_PROGRESS + STATUS DOWNLOAD_STATUS) + + list(GET DOWNLOAD_STATUS 0 STATUS_CODE) + if(NOT STATUS_CODE EQUAL 0) + message(FATAL_ERROR "Failed to download libtorch: ${DOWNLOAD_STATUS}") + endif() + + # Extract libtorch + message(STATUS "Extracting libtorch...") + file(ARCHIVE_EXTRACT INPUT ${LIBTORCH_ZIP} DESTINATION ${CMAKE_BINARY_DIR}) + + # Clean up zip file + file(REMOVE ${LIBTORCH_ZIP}) + endif() + + # Set CMAKE_PREFIX_PATH to find the downloaded libtorch + set(CMAKE_PREFIX_PATH "${LIBTORCH_DOWNLOAD_DIR};${CMAKE_PREFIX_PATH}") + + # Find Torch package again + find_package(Torch REQUIRED) +endif() + +message(STATUS "Using Torch version: ${Torch_VERSION}")