Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion cmake/gauxc-onedft.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@ endif()

# store and restore CMAKE_CUDA_ARCHITECTURES if Torch clobbers it
set(_PREV_CUDA_ARCHS "${CMAKE_CUDA_ARCHITECTURES}")
find_package(Torch REQUIRED)

include(skala-torch)

if(CMAKE_CUDA_ARCHITECTURES STREQUAL "OFF")
set(CMAKE_CUDA_ARCHITECTURES "${_PREV_CUDA_ARCHS}" CACHE STRING "Restore CUDA archs after Torch override" FORCE)
message(WARNING "Torch set CMAKE_CUDA_ARCHITECTURES to OFF. Restored previous value: ${CMAKE_CUDA_ARCHITECTURES}")
Expand Down
76 changes: 76 additions & 0 deletions cmake/skala-torch.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
# Find or download LibTorch
find_package(Torch QUIET)

if(NOT Torch_FOUND)
message(STATUS "Torch not found. Downloading libtorch...")

# Set libtorch version and download URL
set(LIBTORCH_VERSION "2.9.1")
set(USE_CUDA_LIBTORCH FALSE) #default is not to use the cuda version but cpu version
if(GAUXC_HAS_CUDA)
find_package(CUDAToolkit)
set(SUPPORTED_CUDA_VERSION_NO_DOTS "126" "128" "130")
set(CUDA_VERSION_NO_DOT ${CUDAToolkit_VERSION_MAJOR}${CUDAToolkit_VERSION_MINOR})
if(CUDA_VERSION_NO_DOT IN_LIST SUPPORTED_CUDA_VERSION_NO_DOTS)
find_package(CUDA REQUIRED) # Needed for Caffe cmake, since it use the old CUDA inclusion
set(USE_CUDA_LIBTORCH TRUE)
message(STATUS "CUDAToolkit_INCLUDE_DIRS: ${CUDAToolkit_INCLUDE_DIRS}")
set(CUDA_INCLUDE_DIRS "${CUDAToolkit_INCLUDE_DIRS}") # and pass over the correct headers to Caffe cmake
else()
message(WARNING "CUDA toolkit version is ${CUDAToolkit_VERSION}, for which there is no libtorch to download.")
message(WARNING "Falling back to cpu version of libtorch.")
endif()
endif()


# Determine the appropriate libtorch variant based on platform and CUDA availability
if(UNIX AND NOT APPLE)
# Linux
if(USE_CUDA_LIBTORCH)
set(LIBTORCH_URL "https://download.pytorch.org/libtorch/cu${CUDA_VERSION_NO_DOT}/libtorch-shared-with-deps-${LIBTORCH_VERSION}%2Bcu${CUDA_VERSION_NO_DOT}.zip")
else()
set(LIBTORCH_URL "https://download.pytorch.org/libtorch/cpu/libtorch-shared-with-deps-${LIBTORCH_VERSION}%2Bcpu.zip")
endif()
elseif(APPLE)
# macOS (CPU only)
set(LIBTORCH_URL "https://download.pytorch.org/libtorch/cpu/libtorch-macos-arm64-${LIBTORCH_VERSION}.zip")
elseif(WIN32)
# Windows
if(USE_CUDA_LIBTORCH)
set(LIBTORCH_URL "https://download.pytorch.org/libtorch/cu${CUDA_VERSION_NO_DOT}/libtorch-win-shared-with-deps-${LIBTORCH_VERSION}%2Bcu${CUDA_VERSION_NO_DOT}.zip")
else()
set(LIBTORCH_URL "https://download.pytorch.org/libtorch/cpu/libtorch-win-shared-with-deps-${LIBTORCH_VERSION}%2Bcpu.zip")
endif()
endif()

set(LIBTORCH_DOWNLOAD_DIR "${CMAKE_BINARY_DIR}/libtorch")
set(LIBTORCH_ZIP "${CMAKE_BINARY_DIR}/libtorch.zip")

# Download libtorch
if(NOT EXISTS ${LIBTORCH_DOWNLOAD_DIR})
message(STATUS "Downloading libtorch from ${LIBTORCH_URL}")
file(DOWNLOAD ${LIBTORCH_URL} ${LIBTORCH_ZIP}
SHOW_PROGRESS
STATUS DOWNLOAD_STATUS)

list(GET DOWNLOAD_STATUS 0 STATUS_CODE)
if(NOT STATUS_CODE EQUAL 0)
message(FATAL_ERROR "Failed to download libtorch: ${DOWNLOAD_STATUS}")
endif()

# Extract libtorch
message(STATUS "Extracting libtorch...")
file(ARCHIVE_EXTRACT INPUT ${LIBTORCH_ZIP} DESTINATION ${CMAKE_BINARY_DIR})

# Clean up zip file
file(REMOVE ${LIBTORCH_ZIP})
endif()

# Set CMAKE_PREFIX_PATH to find the downloaded libtorch
set(CMAKE_PREFIX_PATH "${LIBTORCH_DOWNLOAD_DIR};${CMAKE_PREFIX_PATH}")

# Find Torch package again
find_package(Torch REQUIRED)
endif()

message(STATUS "Using Torch version: ${Torch_VERSION}")