From a880ce700de461f646fe356f149924bbbf0cc25e Mon Sep 17 00:00:00 2001 From: 36000 Date: Thu, 15 Jan 2026 20:39:59 -0800 Subject: [PATCH 1/2] [FIX] Add dependencies --- pyproject.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 5d7d989..700a6bb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,6 +14,8 @@ dependencies = [ "tqdm", "dipy", "trx-python", + "nvidia-cuda-runtime", + "nvidia-curand", "cuda-python", "cuda-core", "cuda-cccl" From 7a67c0a8ec8782eaa2664c3abca888439a1b1326 Mon Sep 17 00:00:00 2001 From: 36000 Date: Thu, 15 Jan 2026 22:43:46 -0800 Subject: [PATCH 2/2] fix batched propogator --- cuslines/cuda_python/cu_propagate_seeds.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/cuslines/cuda_python/cu_propagate_seeds.py b/cuslines/cuda_python/cu_propagate_seeds.py index b8991e5..f9a401a 100644 --- a/cuslines/cuda_python/cu_propagate_seeds.py +++ b/cuslines/cuda_python/cu_propagate_seeds.py @@ -30,8 +30,8 @@ def __init__(self, gpu_tracker): self.nSlines_old = np.zeros(self.ngpus, dtype=np.int32) self.nSlines = np.zeros(self.ngpus, dtype=np.int32) - self.slines = np.zeros(self.ngpus, dtype=np.ndarray) - self.sline_lens = np.zeros(self.ngpus, dtype=np.ndarray) + self.slines = [None] * self.ngpus + self.sline_lens = [None] * self.ngpus self.seeds_d = np.empty(self.ngpus, dtype=DEV_PTR) self.slineSeed_d = np.empty(self.ngpus, dtype=DEV_PTR) @@ -140,19 +140,19 @@ def _allocate_tracking_memory(self): ) if self.nSlines[ii] > EXCESS_ALLOC_FACT * self.nSlines_old[ii]: - self.slines[ii] = 0 - self.sline_lens[ii] = 0 + self.slines[ii] = None + self.sline_lens[ii] = None gc.collect() buffer_size = self._get_sl_buffer_size(ii) logger.debug(f"Streamline buffer size: {buffer_size}") - if not self.slines[ii]: + if self.slines[ii] is None: self.slines[ii] = np.empty( (EXCESS_ALLOC_FACT * self.nSlines[ii], MAX_SLINE_LEN * 2, 3), dtype=REAL_DTYPE, ) - if not self.sline_lens[ii]: + if self.sline_lens[ii] is None: self.sline_lens[ii] = np.empty( EXCESS_ALLOC_FACT * self.nSlines[ii], dtype=np.int32 )