diff --git a/Cargo.lock b/Cargo.lock
index b619866ea58..ac44d8fd678 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -10152,6 +10152,7 @@ dependencies = [
  "mimalloc",
  "noodles-bgzf",
  "noodles-vcf",
+ "object_store",
  "parking_lot",
  "parquet 57.2.0",
  "rand 0.9.2",
@@ -10173,6 +10174,8 @@ dependencies = [
  "url",
  "uuid",
  "vortex",
+ "vortex-cuda",
+ "vortex-scan",
 ]
 
 [[package]]
@@ -10264,8 +10267,11 @@ name = "vortex-cuda"
 version = "0.1.0"
 dependencies = [
  "async-trait",
+ "bytes",
  "criterion",
  "cudarc",
+ "futures",
+ "parking_lot",
  "tokio",
  "tracing",
  "vortex-array",
@@ -10273,6 +10279,11 @@ dependencies = [
  "vortex-dtype",
  "vortex-error",
  "vortex-fastlanes",
+ "vortex-file",
+ "vortex-io",
+ "vortex-layout",
+ "vortex-metrics",
+ "vortex-scan",
  "vortex-session",
  "vortex-utils",
 ]
@@ -10601,6 +10612,7 @@ dependencies = [
  "tempfile",
  "tokio",
  "tracing",
+ "vortex-array",
  "vortex-buffer",
  "vortex-error",
  "vortex-metrics",
diff --git a/Cargo.toml b/Cargo.toml
index 6a60714b37d..0828ab7949e 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -221,6 +221,7 @@ vortex-btrblocks = { version = "0.1.0", path = "./vortex-btrblocks", default-fea
 vortex-buffer = { version = "0.1.0", path = "./vortex-buffer", default-features = false }
 vortex-bytebool = { version = "0.1.0", path = "./encodings/bytebool", default-features = false }
 vortex-compute = { version = "0.1.0", path = "./vortex-compute", default-features = false }
+vortex-cuda = { version = "0.1.0", path = "./vortex-cuda", default-features = false }
 vortex-datafusion = { version = "0.1.0", path = "./vortex-datafusion", default-features = false }
 vortex-datetime-parts = { version = "0.1.0", path = "./encodings/datetime-parts", default-features = false }
 vortex-decimal-byte-parts = { version = "0.1.0", path = "encodings/decimal-byte-parts", default-features = false }
diff --git a/vortex-bench/Cargo.toml b/vortex-bench/Cargo.toml
index aee90d34b77..84bac8cf9d4 100644
--- a/vortex-bench/Cargo.toml
+++ b/vortex-bench/Cargo.toml
@@ -32,6 +32,7 @@ humansize = { workspace = true }
 indicatif = { workspace = true, features = ["futures"] }
 itertools = { workspace = true }
 mimalloc = { workspace = true }
+object_store = { workspace = true, features = ["aws", "http", "fs"] }
 noodles-bgzf = { workspace = true, features = ["async"] }
 noodles-vcf = { workspace = true, features = ["async"] }
 parking_lot = { workspace = true }
@@ -64,3 +65,5 @@ vortex = { workspace = true, features = [
     "zstd",
     "unstable_encodings",
 ] }
+vortex-scan = { workspace = true }
+vortex-cuda = { workspace = true }
diff --git a/vortex-bench/src/bin/scan_io_bench.rs b/vortex-bench/src/bin/scan_io_bench.rs
new file mode 100644
index 00000000000..6dfb4026610
--- /dev/null
+++ b/vortex-bench/src/bin/scan_io_bench.rs
@@ -0,0 +1,639 @@
+// SPDX-License-Identifier: Apache-2.0
+// SPDX-FileCopyrightText: Copyright the Vortex contributors
+
+use std::path::Path;
+use std::path::PathBuf;
+use std::sync::atomic::AtomicBool;
+use std::sync::atomic::Ordering;
+use std::time::Instant;
+
+use anyhow::Result;
+use clap::Parser;
+use clap::ValueEnum;
+use futures::StreamExt;
+use futures::TryStreamExt;
+use object_store::ObjectStore;
+use object_store::ObjectStoreScheme;
+use object_store::aws::AmazonS3Builder;
+use object_store::http::HttpBuilder;
+use object_store::local::LocalFileSystem;
+use object_store::path::Path as ObjectStorePath;
+use parking_lot::Mutex;
+use tracing_subscriber::EnvFilter;
+use url::Url;
+use vortex::array::Array;
+use vortex::array::MaskFuture;
+use vortex::array::expr::Expression;
+use vortex::array::expr::col;
+use vortex::array::expr::eq;
+use vortex::array::expr::gt;
+use vortex::array::expr::gt_eq;
+use vortex::array::expr::lit;
+use vortex::array::expr::lt;
+use vortex::array::expr::lt_eq;
+use vortex::array::expr::not_eq;
+use vortex::array::expr::root;
+use vortex::array::expr::select;
+use vortex::dtype::FieldNames;
+use vortex::error::VortexResult;
+use vortex::error::vortex_err;
+use vortex::file::OpenOptionsSessionExt;
+use vortex::io::BufferAllocator;
+use vortex::layout::LayoutReader;
+use vortex::layout::collect_segment_ids;
+use vortex::mask::Mask;
+use vortex::metrics::VortexMetrics;
+use vortex_bench::SESSION;
+use vortex_cuda::CudaSessionExt;
+use vortex_cuda::PinnedByteBufferPool;
+use vortex_cuda::PinnedDeviceAllocator;
+use vortex_scan::ScanBuilder;
+
+#[derive(Parser, Debug)]
+#[command(
+    version,
+    about = "Benchmark Vortex scans over local files vs object stores"
+)]
+struct Args {
+    /// File path, directory, or object store URL (e.g. file:/..., s3://bucket/path, https://host/path)
+    #[arg(long)]
+    source: String,
+    /// Use object_store even for file: URLs.
+    #[arg(long, default_value_t = false)]
+    force_object_store: bool,
+    /// Run a predefined scan shape.
+    #[arg(long, value_enum)]
+    preset: Option<Preset>,
+    /// Projection field names (comma-separated).
+    #[arg(long, value_delimiter = ',')]
+    projection: Option<Vec<String>>,
+    /// Filter column name.
+    #[arg(long)]
+    filter_col: Option<String>,
+    /// Filter operator.
+    #[arg(long, value_enum)]
+    filter_op: Option<FilterOp>,
+    /// Filter literal value (integer).
+    #[arg(long)]
+    filter_value: Option<i64>,
+    /// Filter literal type.
+    #[arg(long, value_enum, default_value_t = LiteralType::I64)]
+    filter_type: LiteralType,
+    /// Number of scan iterations.
+    #[arg(long, default_value_t = 1)]
+    iterations: usize,
+    /// Scan concurrency (tasks per thread).
+    #[arg(long, default_value_t = 4)]
+    concurrency: usize,
+    /// Max files scanned in parallel (file-level readahead).
+    #[arg(long, default_value_t = 1)]
+    file_concurrency: usize,
+    /// Reopen the file for each iteration to avoid caching effects.
+    #[arg(long, default_value_t = false)]
+    reopen: bool,
+    /// Which scan path to use.
+    #[arg(long, value_enum, default_value_t = ScanMode::Full)]
+    mode: ScanMode,
+    /// Only read segments and drop buffers (skip decode/projection).
+    #[arg(long, default_value_t = false)]
+    io_only: bool,
+    /// Only prune whole segments (no intra-segment pruning on CPU).
+    #[arg(long, default_value_t = false)]
+    prune_segments: bool,
+    /// Enable CUDA pinned read + H2D transfer.
+    #[arg(long, default_value_t = false)]
+    gpu: bool,
+}
+
+#[derive(ValueEnum, Clone, Debug)]
+enum ScanMode {
+    /// Read segments only (no decode).
+    Io,
+    /// Decode arrays without filter evaluation.
+    Decode,
+    /// Decode arrays with full filter/projection evaluation.
+    Full,
+}
+
+#[derive(ValueEnum, Clone, Debug)]
+enum Preset {
+    /// ClickBench query #2: AdvEngineID != 0, projecting AdvEngineID.
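+    /// Roughly `SELECT COUNT(*) FROM hits WHERE AdvEngineID <> 0` (counting queries 1-indexed).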
+    Clickbench2,
+}
+
+#[derive(ValueEnum, Clone, Debug)]
+enum FilterOp {
+    Eq,
+    Neq,
+    Gt,
+    Gte,
+    Lt,
+    Lte,
+}
+
+#[derive(ValueEnum, Clone, Debug, Copy)]
+enum LiteralType {
+    I16,
+    I32,
+    I64,
+}
+
+#[tokio::main]
+async fn main() -> Result<()> {
+    tracing_subscriber::fmt()
+        .with_env_filter(EnvFilter::from_default_env())
+        .init();
+
+    let args = Args::parse();
+    let mode = if args.io_only {
+        ScanMode::Io
+    } else {
+        args.mode.clone()
+    };
+
+    let (projection, filter) = build_scan_exprs(&args)?;
+    let metrics = VortexMetrics::new_with_tags([("bench", "scan-io")]);
+    let read_bytes = metrics.counter("vortex.io.read.total_size");
+
+    #[allow(clippy::if_then_some_else_none)]
+    let gpu_allocator = if args.gpu {
+        let cuda_session = SESSION.cuda_session();
+        let pool = std::sync::Arc::new(PinnedByteBufferPool::new(cuda_session.context().clone()));
+        Some(std::sync::Arc::new(PinnedDeviceAllocator::from_session(
+            pool, &SESSION,
+        )?))
+    } else {
+        None
+    };
+    let allocator: Option<std::sync::Arc<dyn BufferAllocator>> = gpu_allocator
+        .as_ref()
+        .map(|alloc| alloc.clone() as std::sync::Arc<dyn BufferAllocator>);
+
+    let targets = resolve_targets(&args).await?;
+    let cached_files = if args.reopen {
+        None
+    } else {
+        Some(std::sync::Arc::new(
+            open_all_targets(
+                &targets,
+                metrics.clone(),
+                args.file_concurrency,
+                allocator.clone(),
+            )
+            .await?,
+        ))
+    };
+    read_bytes.clear();
+
+    let start = Instant::now();
+    let bytes_before = read_bytes.count();
+    let first_seen = std::sync::Arc::new(AtomicBool::new(false));
+    let first_info = std::sync::Arc::new(Mutex::new(None::<(f64, i64)>));
+    let targets = targets.clone();
+
+    let rows = futures::stream::iter(0..args.iterations)
+        .flat_map(|_| futures::stream::iter(targets.clone().into_iter().enumerate()))
+        .map(|(idx, target)| {
+            let cached_files = cached_files.clone();
+            let projection = projection.clone();
+            let filter = filter.clone();
+            let metrics = metrics.clone();
+            let read_bytes = read_bytes.clone();
+            let first_seen = first_seen.clone();
+            let first_info = first_info.clone();
+            let mode = mode.clone();
+            let allocator = allocator.clone();
+            async move {
+                let file = match &cached_files {
+                    Some(files) => files[idx].clone(),
+                    None => {
+                        open_vortex_file_for_target(&target, metrics.clone(), allocator).await?
+                    }
+                };
+
+                if args.prune_segments
+                    && let Some(filter) = filter.as_ref()
+                    && file.can_prune(filter)?
+                {
+                    return Ok::<_, anyhow::Error>(0);
+                }
+
+                if matches!(mode, ScanMode::Io) {
+                    read_all_segments(&file, args.concurrency).await?;
+                    if !first_seen.load(Ordering::Relaxed)
+                        && !first_seen.swap(true, Ordering::Relaxed)
+                    {
+                        let latency = start.elapsed().as_secs_f64();
+                        let bytes = read_bytes.count() - bytes_before;
+                        *first_info.lock() = Some((latency, bytes));
+                    }
+                    let file_rows = usize::try_from(file.row_count())
+                        .map_err(|_| anyhow::anyhow!("row_count exceeds usize"))?;
+                    drop(file);
+                    return Ok::<_, anyhow::Error>(file_rows);
+                }
+
+                let (scan_projection, scan_filter, bypass_filter) = match mode {
+                    ScanMode::Decode => {
+                        let scan_filter = if args.prune_segments {
+                            filter.clone()
+                        } else {
+                            None
+                        };
+                        (root(), scan_filter, true)
+                    }
+                    ScanMode::Full => (projection.clone(), filter.clone(), false),
+                    ScanMode::Io => unreachable!("io-only handled above"),
+                };
+
+                let layout_reader = file.layout_reader()?;
+                let layout_reader = if args.prune_segments || bypass_filter {
+                    std::sync::Arc::new(BenchLayoutReader::new(
+                        layout_reader,
+                        args.prune_segments,
+                        bypass_filter,
+                    )) as std::sync::Arc<dyn LayoutReader>
+                } else {
+                    layout_reader
+                };
+
+                let scan = ScanBuilder::new(SESSION.clone(), layout_reader)
+                    .with_metrics(metrics.clone())
+                    .with_projection(scan_projection)
+                    .with_some_filter(scan_filter)
+                    .with_concurrency(args.concurrency)
+                    .map(|array| Ok(array.len()));
+
+                let mut stream = scan.into_stream()?;
+                let mut file_rows = 0usize;
+                while let Some(rows) = stream.try_next().await? {
+                    if !first_seen.load(Ordering::Relaxed)
+                        && !first_seen.swap(true, Ordering::Relaxed)
+                    {
+                        let latency = start.elapsed().as_secs_f64();
+                        let bytes = read_bytes.count() - bytes_before;
+                        *first_info.lock() = Some((latency, bytes));
+                    }
+                    file_rows += rows;
+                }
+
+                drop(file);
+                Ok::<_, anyhow::Error>(file_rows)
+            }
+        })
+        .buffer_unordered(args.file_concurrency.max(1))
+        .try_fold(
+            0usize,
+            |rows, file_rows| async move { Ok(rows + file_rows) },
+        )
+        .await?;
+
+    let elapsed = start.elapsed().as_secs_f64();
+    let gpu_sync_ms = if let Some(allocator) = gpu_allocator {
+        let sync_start = Instant::now();
+        allocator.synchronize()?;
+        sync_start.elapsed().as_secs_f64() * 1000.0
+    } else {
+        0.0
+    };
+    let bytes = read_bytes.count();
+    let (first_latency, first_bytes) = first_info
+        .lock()
+        .unwrap_or((elapsed, read_bytes.count() - bytes_before));
+
+    let avg_elapsed = elapsed / args.iterations as f64;
+    let avg_bytes = bytes as f64 / args.iterations as f64;
+    let avg_first_latency = first_latency / args.iterations as f64;
+    let avg_first_bytes = first_bytes as f64 / args.iterations as f64;
+    let steady_bytes = (avg_bytes - avg_first_bytes).max(0.0);
+    let steady_time = (avg_elapsed - avg_first_latency).max(0.0);
+    let total_mb_s = if avg_elapsed > 0.0 {
+        avg_bytes / (1024.0 * 1024.0) / avg_elapsed
+    } else {
+        0.0
+    };
+    let steady_mb_s = if steady_time > 0.0 {
+        steady_bytes / (1024.0 * 1024.0) / steady_time
+    } else {
+        0.0
+    };
+
+    println!("files={}", targets.len());
+    println!("rows={}", rows / args.iterations);
+    println!("avg_time_s={:.3}", avg_elapsed);
+    println!("avg_bytes={:.0}", avg_bytes);
+    println!("avg_mb_s={:.2}", total_mb_s);
+    println!("avg_first_latency_ms={:.2}", avg_first_latency * 1000.0);
+    println!("steady_mb_s={:.2}", steady_mb_s);
+    if args.gpu {
+        println!("gpu_sync_ms={:.2}", gpu_sync_ms);
+    }
+
+    Ok(())
+}
+
+fn build_scan_exprs(args: &Args) -> VortexResult<(Expression, Option<Expression>)> {
+    if let Some(preset) = &args.preset {
+        return build_preset_exprs(preset);
+    }
+
+    let projection = match &args.projection {
+        Some(fields) if !fields.is_empty() => {
+            let names = FieldNames::from_iter(fields.iter().map(|s| s.as_str()));
+            select(names, root())
+        }
+        _ => root(),
+    };
+
+    let filter = match (&args.filter_col, &args.filter_op, args.filter_value) {
+        (Some(col_name), Some(op), Some(value)) => {
+            let lhs = col(col_name.as_str());
+            let rhs = match args.filter_type {
+                LiteralType::I16 => lit(i16::try_from(value)
+                    .map_err(|_| vortex_err!("filter_value does not fit in i16"))?),
+                LiteralType::I32 => lit(i32::try_from(value)
+                    .map_err(|_| vortex_err!("filter_value does not fit in i32"))?),
+                LiteralType::I64 => lit(value),
+            };
+            Some(apply_filter_op(op.clone(), lhs, rhs))
+        }
+        _ => None,
+    };
+
+    Ok((projection, filter))
+}
+
+fn build_preset_exprs(preset: &Preset) -> VortexResult<(Expression, Option<Expression>)> {
+    match preset {
+        Preset::Clickbench2 => {
+            let projection = select(["AdvEngineID"], root());
+            let filter = not_eq(col("AdvEngineID"), lit(0_i16));
+            Ok((projection, Some(filter)))
+        }
+    }
+}
+
+fn apply_filter_op(op: FilterOp, lhs: Expression, rhs: Expression) -> Expression {
+    match op {
+        FilterOp::Eq => eq(lhs, rhs),
+        FilterOp::Neq => not_eq(lhs, rhs),
+        FilterOp::Gt => gt(lhs, rhs),
+        FilterOp::Gte => gt_eq(lhs, rhs),
+        FilterOp::Lt => lt(lhs, rhs),
+        FilterOp::Lte => lt_eq(lhs, rhs),
+    }
+}
+
+#[derive(Clone)]
+enum ScanTarget {
+    Local(PathBuf),
+    ObjectStore {
+        store: std::sync::Arc<dyn ObjectStore>,
+        path: ObjectStorePath,
+    },
+}
+
+async fn resolve_targets(args: &Args) -> Result<Vec<ScanTarget>> {
+    let source = &args.source;
+
+    if let Ok(url) = Url::parse(source) {
+        if url.scheme() == "file" && !args.force_object_store {
+            let path = url
+                .to_file_path()
+                .map_err(|_| anyhow::anyhow!("Invalid file URL: {source}"))?;
+            return Ok(resolve_local_targets(&path));
+        }
+
+        let (scheme, store, path) = object_store_from_url(source)?;
+        if is_prefix(source) {
+            if matches!(scheme, ObjectStoreScheme::Http) {
+                anyhow::bail!("HTTP object stores do not support listing prefixes");
+            }
+            let mut entries = store.list(Some(&path));
+            let mut targets = Vec::new();
+            while let Some(entry) = entries.try_next().await? {
+                targets.push(ScanTarget::ObjectStore {
+                    store: store.clone(),
+                    path: entry.location.clone(),
+                });
+            }
+            return Ok(targets);
+        }
+
+        return Ok(vec![ScanTarget::ObjectStore { store, path }]);
+    }
+
+    let path = PathBuf::from(source);
+    Ok(resolve_local_targets(&path))
+}
+
+fn resolve_local_targets(path: &Path) -> Vec<ScanTarget> {
+    if path.is_dir() {
+        let mut entries = match std::fs::read_dir(path) {
+            Ok(entries) => entries
+                .filter_map(|entry| entry.ok())
+                .map(|entry| entry.path())
+                .filter(|entry| entry.extension().is_some_and(|e| e == "vortex"))
+                .collect::<Vec<_>>(),
+            Err(_) => Vec::new(),
+        };
+        entries.sort();
+        entries.into_iter().map(ScanTarget::Local).collect()
+    } else {
+        vec![ScanTarget::Local(path.to_path_buf())]
+    }
+}
+
+fn is_prefix(source: &str) -> bool {
+    source.ends_with('/')
+}
+
+async fn open_vortex_file_for_target(
+    target: &ScanTarget,
+    metrics: VortexMetrics,
+    allocator: Option<std::sync::Arc<dyn BufferAllocator>>,
+) -> Result<vortex::file::VortexFile> {
+    let session = SESSION.clone();
+    match target {
+        ScanTarget::Local(path) => {
+            let mut options = session.open_options().with_metrics(metrics);
+            if let Some(allocator) = allocator {
+                options = options.with_allocator(allocator);
+            }
+            Ok(options.open_path(path).await?)
+        }
+        ScanTarget::ObjectStore { store, path } => {
+            let path_str = path.to_string();
+            let mut options = session.open_options().with_metrics(metrics);
+            if let Some(allocator) = allocator {
+                options = options.with_allocator(allocator);
+            }
+            Ok(options.open_object_store(store, &path_str).await?)
+        }
+    }
+}
+
+async fn open_all_targets(
+    targets: &[ScanTarget],
+    metrics: VortexMetrics,
+    concurrency: usize,
+    allocator: Option<std::sync::Arc<dyn BufferAllocator>>,
+) -> Result<Vec<vortex::file::VortexFile>> {
+    let mut files = vec![None; targets.len()];
+    let results = futures::stream::iter(targets.iter().enumerate())
+        .map(|(idx, target)| {
+            let metrics = metrics.clone();
+            let allocator = allocator.clone();
+            async move {
+                let file = open_vortex_file_for_target(target, metrics, allocator).await?;
+                Ok::<_, anyhow::Error>((idx, file))
+            }
+        })
+        .buffer_unordered(concurrency.max(1))
+        .try_collect::<Vec<_>>()
+        .await?;
+
+    for (idx, file) in results {
+        files[idx] = Some(file);
+    }
+
+    files
+        .into_iter()
+        .map(|file| file.ok_or_else(|| anyhow::anyhow!("file open missing")))
+        .collect()
+}
+
+fn object_store_from_url(
+    url_str: &str,
+) -> Result<(
+    ObjectStoreScheme,
+    std::sync::Arc<dyn ObjectStore>,
+    ObjectStorePath,
+)> {
+    let url = Url::parse(url_str)?;
+    let (scheme, path) = ObjectStoreScheme::parse(&url).map_err(object_store::Error::from)?;
+    let store: std::sync::Arc<dyn ObjectStore> = match scheme {
+        ObjectStoreScheme::Local => std::sync::Arc::new(LocalFileSystem::default()),
+        ObjectStoreScheme::AmazonS3 => {
+            std::sync::Arc::new(AmazonS3Builder::from_env().with_url(url_str).build()?)
+        }
+        ObjectStoreScheme::Http => std::sync::Arc::new(
+            HttpBuilder::new()
+                .with_url(&url[..url::Position::BeforePath])
+                .build()?,
+        ),
+        otherwise => anyhow::bail!("unsupported object store scheme: {otherwise:?}"),
+    };
+
+    Ok((scheme, store, path))
+}
+
+async fn read_all_segments(file: &vortex::file::VortexFile, concurrency: usize) -> Result<()> {
+    let layout = file.footer().layout().clone();
+    let segment_ids = collect_segment_ids(&layout)?;
+    let segment_source = file.segment_source();
+
+    futures::stream::iter(segment_ids)
+        .map(|segment_id| {
+            let segment_source = segment_source.clone();
+            async move {
+                let buffer = segment_source.request(segment_id).await?;
+                drop(buffer);
+                Ok::<_, anyhow::Error>(())
+            }
+        })
+        .buffer_unordered(concurrency.max(1))
+        .try_collect::<Vec<_>>()
+        .await?;
+
+    Ok(())
+}
+
+#[derive(Clone)]
+struct BenchLayoutReader {
+    inner: std::sync::Arc<dyn LayoutReader>,
+    segment_pruning: bool,
+    bypass_filter: bool,
+}
+
+impl BenchLayoutReader {
+    fn new(
+        inner: std::sync::Arc<dyn LayoutReader>,
+        segment_pruning: bool,
+        bypass_filter: bool,
+    ) -> Self {
+        Self {
+            inner,
+            segment_pruning,
+            bypass_filter,
+        }
+    }
+}
+
+impl LayoutReader for BenchLayoutReader {
+    fn name(&self) -> &std::sync::Arc<str> {
+        self.inner.name()
+    }
+
+    fn dtype(&self) -> &vortex::dtype::DType {
+        self.inner.dtype()
+    }
+
+    fn row_count(&self) -> u64 {
+        self.inner.row_count()
+    }
+
+    fn register_splits(
+        &self,
+        field_mask: &[vortex::dtype::FieldMask],
+        row_range: &std::ops::Range<u64>,
+        splits: &mut std::collections::BTreeSet<u64>,
+    ) -> VortexResult<()> {
+        self.inner.register_splits(field_mask, row_range, splits)
+    }
+
+    fn pruning_evaluation(
+        &self,
+        row_range: &std::ops::Range<u64>,
+        expr: &Expression,
+        mask: Mask,
+    ) -> VortexResult<MaskFuture> {
+        if !self.segment_pruning {
+            return self.inner.pruning_evaluation(row_range, expr, mask);
+        }
+
+        let len = mask.len();
+        let fut = self.inner.pruning_evaluation(row_range, expr, mask)?;
+        Ok(MaskFuture::new(len, async move {
+            let mask = fut.await?;
+            if mask.all_false() {
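+                // Segment-level pruning only: an all-false mask (the whole range is
+                // pruned) is kept as-is; any partial mask is widened to all-true below.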
+                Ok(mask)
+            } else {
+                Ok(Mask::new_true(len))
+            }
+        }))
+    }
+
+    fn filter_evaluation(
+        &self,
+        row_range: &std::ops::Range<u64>,
+        expr: &Expression,
+        mask: MaskFuture,
+    ) -> VortexResult<MaskFuture> {
+        if self.bypass_filter {
+            Ok(mask)
+        } else {
+            self.inner.filter_evaluation(row_range, expr, mask)
+        }
+    }
+
+    fn projection_evaluation(
+        &self,
+        row_range: &std::ops::Range<u64>,
+        expr: &Expression,
+        mask: MaskFuture,
+    ) -> VortexResult>>
+    {
+        self.inner.projection_evaluation(row_range, expr, mask)
+    }
+}
diff --git a/vortex-cuda/Cargo.toml b/vortex-cuda/Cargo.toml
index cf822f1ec02..2181440b01e 100644
--- a/vortex-cuda/Cargo.toml
+++ b/vortex-cuda/Cargo.toml
@@ -18,22 +18,38 @@ workspace = true
 
 [dependencies]
 async-trait = { workspace = true }
+bytes = { workspace = true }
 cudarc = { workspace = true }
+parking_lot = { workspace = true }
 tracing = { workspace = true }
 vortex-array = { workspace = true }
 vortex-buffer = { workspace = true }
 vortex-dtype = { workspace = true }
 vortex-error = { workspace = true }
 vortex-fastlanes = { workspace = true }
+vortex-io = { workspace = true }
 vortex-session = { workspace = true }
 vortex-utils = { workspace = true }
 
 [dev-dependencies]
 criterion = { workspace = true }
-tokio = { workspace = true, features = ["rt", "macros"] }
+futures = { workspace = true }
+tokio = { workspace = true, features = ["rt", "rt-multi-thread", "macros"] }
+vortex-file = { workspace = true, features = ["tokio"] }
+vortex-layout = { workspace = true }
+vortex-metrics = { workspace = true }
+vortex-scan = { workspace = true }
 
 [build-dependencies]
 
 [[bench]]
 name = "for_cuda"
 harness = false
+
+[[bench]]
+name = "h2d_pinned"
+harness = false
+
+[[bench]]
+name = "pinned_scan"
+harness = false
diff --git a/vortex-cuda/benches/h2d_pinned.rs b/vortex-cuda/benches/h2d_pinned.rs
new file mode 100644
index 00000000000..8d19ba5840b
--- /dev/null
+++ b/vortex-cuda/benches/h2d_pinned.rs
@@ -0,0 +1,349 @@
+// SPDX-License-Identifier: Apache-2.0
+// SPDX-FileCopyrightText: Copyright the Vortex contributors
+
+//! Benchmarks for H2D transfer throughput with pinned vs regular memory.
+//!
+//! Run with: cargo bench -p vortex-cuda --bench h2d_pinned
+//!
+//! This benchmark measures:
+//! 1. Pure H2D transfer: pinned memory vs a regular `Vec<u8>`
+//! 2. Read-into-pinned: reading from a RAM buffer into pinned memory
+//! 3. Full pipeline: RAM -> pinned -> GPU
+
+#![allow(clippy::unwrap_used)]
+#![allow(clippy::expect_used)]
+#![allow(clippy::redundant_clone)]
+
+use std::sync::Arc;
+use std::time::Duration;
+use std::time::Instant;
+
+use criterion::BenchmarkId;
+use criterion::Criterion;
+use criterion::Throughput;
+use criterion::criterion_group;
+use criterion::criterion_main;
+use cudarc::driver::CudaContext;
+use vortex_cuda::PinnedByteBufferPool;
+use vortex_cuda::has_nvcc;
+
+// Buffer sizes to test: 1KB, 64KB, 1MB, 16MB, 64MB, 256MB
+const SIZES: &[(usize, &str)] = &[
+    (1 << 10, "1KB"),
+    (1 << 16, "64KB"),
+    (1 << 20, "1MB"),
+    (16 << 20, "16MB"),
+    (64 << 20, "64MB"),
+    (256 << 20, "256MB"),
+];
+
+/// Benchmark H2D transfer from regular (pageable) memory.
+/// CUDA internally stages through a pinned buffer, so this measures the slower path.
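+///
+/// Roughly the driver-level sequence being timed (a sketch, not the cudarc API):
+/// ```text
+/// cuMemcpyHtoDAsync(dst, pageable_ptr, len, stream); // driver stages via an internal pinned buffer
+/// cuStreamSynchronize(stream);
+/// ```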
+fn bench_h2d_regular(c: &mut Criterion) {
+    let ctx = CudaContext::new(0).expect("Failed to create CUDA context");
+    let stream = ctx.new_stream().expect("Failed to create stream");
+
+    let mut group = c.benchmark_group("h2d_regular");
+    group.sample_size(10);
+
+    for (size, label) in SIZES {
+        // Skip very large sizes for regular memory test to save time
+        if *size > 64 << 20 {
+            continue;
+        }
+
+        group.throughput(Throughput::Bytes(*size as u64));
+        group.bench_with_input(BenchmarkId::new("regular", label), size, |b, &size| {
+            // Allocate regular memory and touch it
+            let data: Vec<u8> = vec![0x42u8; size];
+
+            // Pre-allocate device buffer
+            let mut device = unsafe { stream.alloc::<u8>(size) }.expect("Failed to alloc device");
+
+            b.iter_custom(|iters| {
+                let mut total = Duration::ZERO;
+                for _ in 0..iters {
+                    let start = Instant::now();
+                    stream.memcpy_htod(&data, &mut device).expect("H2D failed");
+                    stream.synchronize().expect("Sync failed");
+                    total += start.elapsed();
+                }
+                total
+            });
+        });
+    }
+
+    group.finish();
+}
+
+/// Benchmark H2D transfer from pinned memory.
+/// This uses DMA and should be faster than regular memory.
+fn bench_h2d_pinned(c: &mut Criterion) {
+    let ctx = CudaContext::new(0).expect("Failed to create CUDA context");
+    let stream = ctx.new_stream().expect("Failed to create stream");
+    let pool = Arc::new(PinnedByteBufferPool::new(ctx.clone()));
+
+    let mut group = c.benchmark_group("h2d_pinned");
+    group.sample_size(10);
+
+    for (size, label) in SIZES {
+        group.throughput(Throughput::Bytes(*size as u64));
+        group.bench_with_input(BenchmarkId::new("pinned", label), size, |b, &size| {
+            // Allocate pinned memory and touch it
+            let mut pinned = pool.get(size).expect("Failed to get pinned buffer");
+            pinned.as_mut_slice().expect("slice").fill(0x42);
+
+            // Pre-allocate device buffer
+            let mut device = unsafe { stream.alloc::<u8>(size) }.expect("Failed to alloc device");
+
+            b.iter_custom(|iters| {
+                let mut total = Duration::ZERO;
+                for _ in 0..iters {
+                    let start = Instant::now();
+                    stream
+                        .memcpy_htod(&pinned, &mut device)
+                        .expect("H2D failed");
+                    stream.synchronize().expect("Sync failed");
+                    total += start.elapsed();
+                }
+                total
+            });
+
+            // Return to pool
+            pool.put(pinned).ok();
+        });
+    }
+
+    group.finish();
+}
+
+/// Benchmark the PooledPinnedBuffer path (what the allocator uses).
+fn bench_h2d_pooled_pinned(c: &mut Criterion) {
+    let ctx = CudaContext::new(0).expect("Failed to create CUDA context");
+    let stream = ctx.new_stream().expect("Failed to create stream");
+    let pool = Arc::new(PinnedByteBufferPool::new(ctx.clone()));
+
+    let mut group = c.benchmark_group("h2d_pooled_pinned");
+    group.sample_size(10);
+
+    for (size, label) in SIZES {
+        group.throughput(Throughput::Bytes(*size as u64));
+        group.bench_with_input(BenchmarkId::new("pooled", label), size, |b, &size| {
+            // Pre-allocate device buffer
+            let mut device = unsafe { stream.alloc::<u8>(size) }.expect("Failed to alloc device");
+
+            b.iter_custom(|iters| {
+                let mut total = Duration::ZERO;
+                for _ in 0..iters {
+                    // Get from pool, fill, transfer, return to pool
+                    let mut pooled = pool.get_pooled(size).expect("Failed to get pooled buffer");
+                    pooled.as_mut_slice().fill(0x42);
+
+                    let start = Instant::now();
+                    stream
+                        .memcpy_htod(&pooled, &mut device)
+                        .expect("H2D failed");
+                    stream.synchronize().expect("Sync failed");
+                    total += start.elapsed();
+
+                    // pooled is returned to pool on drop
+                }
+                total
+            });
+        });
+    }
+
+    group.finish();
+}
+
+/// Benchmark copying from RAM into pinned buffer, then H2D.
+/// This simulates: read from file/network into RAM, copy to pinned, transfer to GPU.
+fn bench_ram_to_pinned_to_gpu(c: &mut Criterion) {
+    let ctx = CudaContext::new(0).expect("Failed to create CUDA context");
+    let stream = ctx.new_stream().expect("Failed to create stream");
+    let pool = Arc::new(PinnedByteBufferPool::new(ctx.clone()));
+
+    let mut group = c.benchmark_group("ram_pinned_gpu");
+    group.sample_size(10);
+
+    for (size, label) in SIZES {
+        // Skip very large for this combined test
+        if *size > 64 << 20 {
+            continue;
+        }
+
+        group.throughput(Throughput::Bytes(*size as u64));
+        group.bench_with_input(
+            BenchmarkId::new("ram_to_pinned_to_gpu", label),
+            size,
+            |b, &size| {
+                // Source data in regular RAM
+                let ram_data: Vec<u8> = vec![0x42u8; size];
+
+                // Pre-allocate device buffer
+                let mut device =
+                    unsafe { stream.alloc::<u8>(size) }.expect("Failed to alloc device");
+
+                b.iter_custom(|iters| {
+                    let mut total = Duration::ZERO;
+                    for _ in 0..iters {
+                        let start = Instant::now();
+
+                        // Get pinned buffer
+                        let mut pinned =
+                            pool.get_pooled(size).expect("Failed to get pooled buffer");
+
+                        // Copy RAM -> pinned
+                        pinned.as_mut_slice().copy_from_slice(&ram_data);
+
+                        // Transfer pinned -> GPU
+                        stream
+                            .memcpy_htod(&pinned, &mut device)
+                            .expect("H2D failed");
+                        stream.synchronize().expect("Sync failed");
+
+                        total += start.elapsed();
+                    }
+                    total
+                });
+            },
+        );
+    }
+
+    group.finish();
+}
+
+/// Benchmark direct RAM to GPU (baseline without pinned intermediate).
+fn bench_ram_to_gpu_direct(c: &mut Criterion) {
+    let ctx = CudaContext::new(0).expect("Failed to create CUDA context");
+    let stream = ctx.new_stream().expect("Failed to create stream");
+
+    let mut group = c.benchmark_group("ram_gpu_direct");
+    group.sample_size(10);
+
+    for (size, label) in SIZES {
+        if *size > 64 << 20 {
+            continue;
+        }
+
+        group.throughput(Throughput::Bytes(*size as u64));
+        group.bench_with_input(
+            BenchmarkId::new("ram_to_gpu_direct", label),
+            size,
+            |b, &size| {
+                let ram_data: Vec<u8> = vec![0x42u8; size];
+                let mut device =
+                    unsafe { stream.alloc::<u8>(size) }.expect("Failed to alloc device");
+
+                b.iter_custom(|iters| {
+                    let mut total = Duration::ZERO;
+                    for _ in 0..iters {
+                        let start = Instant::now();
+                        stream
+                            .memcpy_htod(&ram_data, &mut device)
+                            .expect("H2D failed");
+                        stream.synchronize().expect("Sync failed");
+                        total += start.elapsed();
+                    }
+                    total
+                });
+            },
+        );
+    }
+
+    group.finish();
+}
+
+/// Quick sanity check that prints bandwidth numbers.
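+/// Bandwidth is computed as (bytes * iterations) / elapsed_seconds / 1e9, reported in GB/s;
+/// e.g. a 256 MB transfer completing in 10 ms works out to roughly 26.8 GB/s.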
+fn print_bandwidth_summary() {
+    println!("\n=== H2D Bandwidth Quick Test ===\n");
+
+    let ctx = CudaContext::new(0).expect("Failed to create CUDA context");
+    let stream = ctx.new_stream().expect("Failed to create stream");
+    let pool = Arc::new(PinnedByteBufferPool::new(ctx.clone()));
+
+    let size = 256 << 20; // 256MB
+    let iterations = 10;
+
+    // Pinned test
+    let mut pinned = pool.get(size).expect("Failed to get pinned buffer");
+    pinned.as_mut_slice().expect("slice").fill(0x42);
+    let mut device = unsafe { stream.alloc::<u8>(size) }.expect("Failed to alloc device");
+
+    // Warmup
+    for _ in 0..3 {
+        stream
+            .memcpy_htod(&pinned, &mut device)
+            .expect("H2D failed");
+        stream.synchronize().expect("Sync failed");
+    }
+
+    let start = Instant::now();
+    for _ in 0..iterations {
+        stream
+            .memcpy_htod(&pinned, &mut device)
+            .expect("H2D failed");
+        stream.synchronize().expect("Sync failed");
+    }
+    let pinned_time = start.elapsed();
+    let pinned_bw = (size * iterations) as f64 / pinned_time.as_secs_f64() / 1e9;
+
+    pool.put(pinned).ok();
+
+    // Regular test
+    let regular: Vec<u8> = vec![0x42u8; size];
+
+    // Warmup
+    for _ in 0..3 {
+        stream
+            .memcpy_htod(&regular, &mut device)
+            .expect("H2D failed");
+        stream.synchronize().expect("Sync failed");
+    }
+
+    let start = Instant::now();
+    for _ in 0..iterations {
+        stream
+            .memcpy_htod(&regular, &mut device)
+            .expect("H2D failed");
+        stream.synchronize().expect("Sync failed");
+    }
+    let regular_time = start.elapsed();
+    let regular_bw = (size * iterations) as f64 / regular_time.as_secs_f64() / 1e9;
+
+    println!("Buffer size: {} MB", size >> 20);
+    println!("Iterations: {}", iterations);
+    println!();
+    println!(
+        "Pinned memory:  {:.2} GB/s ({:.2} ms per transfer)",
+        pinned_bw,
+        pinned_time.as_secs_f64() * 1000.0 / iterations as f64
+    );
+    println!(
+        "Regular memory: {:.2} GB/s ({:.2} ms per transfer)",
+        regular_bw,
+        regular_time.as_secs_f64() * 1000.0 / iterations as f64
+    );
+    println!("Speedup: {:.2}x", pinned_bw / regular_bw);
+    println!();
+}
+
+fn all_benchmarks(c: &mut Criterion) {
+    if !has_nvcc() {
+        eprintln!("nvcc not found, skipping CUDA benchmarks");
+        return;
+    }
+
+    // Print quick summary first
+    print_bandwidth_summary();
+
+    // Run detailed benchmarks
+    bench_h2d_pinned(c);
+    bench_h2d_regular(c);
+    bench_h2d_pooled_pinned(c);
+    bench_ram_to_pinned_to_gpu(c);
+    bench_ram_to_gpu_direct(c);
+}
+
+criterion_group!(benches, all_benchmarks);
+criterion_main!(benches);
diff --git a/vortex-cuda/benches/pinned_scan.rs b/vortex-cuda/benches/pinned_scan.rs
new file mode 100644
index 00000000000..f6dc1568ff7
--- /dev/null
+++ b/vortex-cuda/benches/pinned_scan.rs
@@ -0,0 +1,395 @@
+// SPDX-License-Identifier: Apache-2.0
+// SPDX-FileCopyrightText: Copyright the Vortex contributors
+
+//! Benchmark Vortex file scanning with pinned buffer allocator.
+//!
+//! Run with: cargo bench -p vortex-cuda --bench pinned_scan
+//!
+//! This benchmark:
+//! 1. Creates a synthetic Vortex file in memory
+//! 2. Scans it with default allocator vs pinned allocator
+//! 3. Measures total I/O + decode time
+
+#![allow(clippy::unwrap_used)]
+#![allow(clippy::expect_used)]
+#![allow(clippy::len_zero)]
+
+use std::sync::Arc;
+use std::time::Duration;
+use std::time::Instant;
+
+use criterion::BenchmarkId;
+use criterion::Criterion;
+use criterion::Throughput;
+use criterion::criterion_group;
+use criterion::criterion_main;
+use cudarc::driver::CudaContext;
+use tokio::runtime::Runtime;
+use vortex_array::IntoArray;
+use vortex_array::arrays::PrimitiveArray;
+use vortex_array::expr::session::ExprSession;
+use vortex_array::session::ArraySession;
+use vortex_array::stream::ArrayStreamExt;
+use vortex_array::validity::Validity;
+use vortex_buffer::Buffer;
+use vortex_buffer::ByteBuffer;
+use vortex_buffer::ByteBufferMut;
+use vortex_cuda::PinnedBufferAllocator;
+use vortex_cuda::PinnedByteBufferPool;
+use vortex_cuda::PinnedDeviceAllocator;
+use vortex_cuda::has_nvcc;
+use vortex_file::OpenOptionsSessionExt;
+use vortex_file::WriteOptionsSessionExt;
+use vortex_file::register_default_encodings;
+use vortex_io::session::RuntimeSession;
+use vortex_layout::session::LayoutSession;
+use vortex_metrics::VortexMetrics;
+use vortex_session::VortexSession;
+
+// Test sizes: 1M, 10M, 100M rows of i64 (8 bytes each)
+const ROW_COUNTS: &[(usize, &str)] = &[
+    (1_000_000, "1M_rows"),
+    (10_000_000, "10M_rows"),
+    (100_000_000, "100M_rows"),
+];
+
+fn create_session() -> VortexSession {
+    let mut session = VortexSession::empty()
+        .with::<ArraySession>()
+        .with::<ExprSession>()
+        .with::<RuntimeSession>()
+        .with::<LayoutSession>()
+        .with::<VortexMetrics>();
+    register_default_encodings(&mut session);
+    session
+}
+
+/// Create a synthetic Vortex file in memory with the given number of rows.
+fn create_vortex_buffer(session: &VortexSession, num_rows: usize) -> ByteBuffer {
+    let rt = Runtime::new().unwrap();
+
+    // Create a simple i64 array with predictable data
+    let data: Vec<i64> = (0..num_rows as i64).collect();
+    let array = PrimitiveArray::new(Buffer::from(data), Validity::NonNullable).into_array();
+
+    let mut buf = ByteBufferMut::empty();
+    rt.block_on(async {
+        session
+            .write_options()
+            .write(&mut buf, array.to_array_stream())
+            .await
+            .expect("Failed to write Vortex file");
+    });
+
+    ByteBuffer::from(buf)
+}
+
+/// Scan with default allocator (regular memory).
+fn scan_default(session: &VortexSession, buffer: &ByteBuffer) -> Duration {
+    let rt = Runtime::new().unwrap();
+
+    rt.block_on(async {
+        let file = session
+            .open_options()
+            .open_buffer(buffer.clone())
+            .expect("Failed to open file");
+
+        let start = Instant::now();
+
+        let result = file
+            .scan()
+            .expect("Failed to create scan")
+            .into_array_stream()
+            .expect("Failed to create stream")
+            .read_all()
+            .await
+            .expect("Scan failed");
+
+        let elapsed = start.elapsed();
+        assert!(result.len() > 0);
+        elapsed
+    })
+}
+
+/// Scan with pinned allocator (data stays on host in pinned memory).
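+///
+/// Segment reads are served by `PinnedBufferAllocator`, so I/O lands directly in pooled
+/// page-locked memory rather than pageable allocations; decoding still runs on the host.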
+fn scan_pinned(
+    session: &VortexSession,
+    buffer: &ByteBuffer,
+    pool: &Arc<PinnedByteBufferPool>,
+) -> Duration {
+    let rt = Runtime::new().unwrap();
+    let allocator = Arc::new(PinnedBufferAllocator::new(pool.clone()));
+
+    rt.block_on(async {
+        let file = session
+            .open_options()
+            .with_allocator(allocator)
+            .open_buffer(buffer.clone())
+            .expect("Failed to open file");
+
+        let start = Instant::now();
+
+        let result = file
+            .scan()
+            .expect("Failed to create scan")
+            .into_array_stream()
+            .expect("Failed to create stream")
+            .read_all()
+            .await
+            .expect("Scan failed");
+
+        let elapsed = start.elapsed();
+        assert!(result.len() > 0);
+        elapsed
+    })
+}
+
+/// Scan with pinned device allocator (data transferred to GPU).
+fn scan_device(
+    session: &VortexSession,
+    buffer: &ByteBuffer,
+    pool: &Arc<PinnedByteBufferPool>,
+    stream: &Arc<cudarc::driver::CudaStream>,
+) -> Duration {
+    let rt = Runtime::new().unwrap();
+    let allocator = Arc::new(PinnedDeviceAllocator::new(pool.clone(), stream.clone()));
+
+    rt.block_on(async {
+        let file = session
+            .open_options()
+            .with_allocator(allocator.clone())
+            .open_buffer(buffer.clone())
+            .expect("Failed to open file");
+
+        let start = Instant::now();
+
+        let result = file
+            .scan()
+            .expect("Failed to create scan")
+            .into_array_stream()
+            .expect("Failed to create stream")
+            .read_all()
+            .await
+            .expect("Scan failed");
+
+        // Synchronize to ensure all H2D transfers complete
+        allocator.synchronize().expect("Failed to synchronize");
+
+        let elapsed = start.elapsed();
+        assert!(result.len() > 0);
+        elapsed
+    })
+}
+
+fn bench_scan_default(c: &mut Criterion) {
+    let session = create_session();
+
+    let mut group = c.benchmark_group("scan_default");
+    group.sample_size(10);
+
+    for (num_rows, label) in ROW_COUNTS {
+        // Skip very large for CI
+        if *num_rows > 10_000_000 {
+            continue;
+        }
+
+        let buffer = create_vortex_buffer(&session, *num_rows);
+        let bytes = buffer.len();
+
+        group.throughput(Throughput::Bytes(bytes as u64));
+        group.bench_with_input(BenchmarkId::new("default", label), &buffer, |b, buffer| {
+            b.iter_custom(|iters| {
+                let mut total = Duration::ZERO;
+                for _ in 0..iters {
+                    total += scan_default(&session, buffer);
+                }
+                total
+            });
+        });
+    }
+
+    group.finish();
+}
+
+fn bench_scan_pinned(c: &mut Criterion) {
+    if !has_nvcc() {
+        eprintln!("nvcc not found, skipping pinned scan benchmark");
+        return;
+    }
+
+    let ctx = CudaContext::new(0).expect("Failed to create CUDA context");
+    let pool = Arc::new(PinnedByteBufferPool::new(ctx));
+    let session = create_session();
+
+    let mut group = c.benchmark_group("scan_pinned");
+    group.sample_size(10);
+
+    for (num_rows, label) in ROW_COUNTS {
+        if *num_rows > 10_000_000 {
+            continue;
+        }
+
+        let buffer = create_vortex_buffer(&session, *num_rows);
+        let bytes = buffer.len();
+
+        group.throughput(Throughput::Bytes(bytes as u64));
+        group.bench_with_input(BenchmarkId::new("pinned", label), &buffer, |b, buffer| {
+            b.iter_custom(|iters| {
+                let mut total = Duration::ZERO;
+                for _ in 0..iters {
+                    total += scan_pinned(&session, buffer, &pool);
+                }
+                total
+            });
+        });
+    }
+
+    group.finish();
+}
+
+fn bench_scan_device(c: &mut Criterion) {
+    if !has_nvcc() {
+        eprintln!("nvcc not found, skipping device scan benchmark");
+        return;
+    }
+
+    let ctx = CudaContext::new(0).expect("Failed to create CUDA context");
+    let stream = Arc::new(ctx.new_stream().expect("Failed to create stream"));
+    let pool = Arc::new(PinnedByteBufferPool::new(ctx));
+    let session = create_session();
+
+    let mut group = c.benchmark_group("scan_device");
+    group.sample_size(10);
+
+    for (num_rows, label) in ROW_COUNTS {
+        if *num_rows > 10_000_000 {
+            continue;
+        }
+
+        let buffer = create_vortex_buffer(&session, *num_rows);
+        let bytes = buffer.len();
+
+        group.throughput(Throughput::Bytes(bytes as u64));
+        group.bench_with_input(BenchmarkId::new("device", label), &buffer, |b, buffer| {
+            b.iter_custom(|iters| {
+                let mut total = Duration::ZERO;
+                for _ in 0..iters {
+                    total += scan_device(&session, buffer, &pool, &stream);
+                }
+                total
+            });
+        });
+    }
+
+    group.finish();
+}
+
+/// Quick comparison that prints results directly.
+fn print_scan_comparison() {
+    if !has_nvcc() {
+        eprintln!("nvcc not found, skipping scan comparison");
+        return;
+    }
+
+    println!("\n=== Vortex Scan: Default vs Pinned vs Device Allocator ===\n");
+
+    let ctx = CudaContext::new(0).expect("Failed to create CUDA context");
+    let stream = Arc::new(ctx.new_stream().expect("Failed to create stream"));
+    let pool = Arc::new(PinnedByteBufferPool::new(ctx));
+    let session = create_session();
+
+    let num_rows = 10_000_000; // 10M rows
+    println!("Creating Vortex file with {} rows...", num_rows);
+    let buffer = create_vortex_buffer(&session, num_rows);
+    println!(
+        "File size: {:.2} MB ({} bytes)\n",
+        buffer.len() as f64 / 1e6,
+        buffer.len()
+    );
+
+    let iterations = 5;
+
+    // Warmup
+    println!("Warming up...");
+    scan_default(&session, &buffer);
+    scan_pinned(&session, &buffer, &pool);
+    scan_device(&session, &buffer, &pool, &stream);
+
+    // Default allocator
+    println!(
+        "Running {} iterations with default allocator...",
+        iterations
+    );
+    let start = Instant::now();
+    for _ in 0..iterations {
+        scan_default(&session, &buffer);
+    }
+    let default_time = start.elapsed();
+    let default_throughput = (buffer.len() * iterations) as f64 / default_time.as_secs_f64() / 1e9;
+
+    // Pinned allocator (host)
+    println!(
+        "Running {} iterations with pinned allocator (host)...",
+        iterations
+    );
+    let start = Instant::now();
+    for _ in 0..iterations {
+        scan_pinned(&session, &buffer, &pool);
+    }
+    let pinned_time = start.elapsed();
+    let pinned_throughput = (buffer.len() * iterations) as f64 / pinned_time.as_secs_f64() / 1e9;
+
+    // Device allocator (pinned + H2D)
+    println!(
+        "Running {} iterations with device allocator (pinned + H2D)...",
+        iterations
+    );
+    let start = Instant::now();
+    for _ in 0..iterations {
+        scan_device(&session, &buffer, &pool, &stream);
+    }
+    let device_time = start.elapsed();
+    let device_throughput = (buffer.len() * iterations) as f64 / device_time.as_secs_f64() / 1e9;
+
+    println!();
+    println!("Results:");
+    println!(
+        "  Default allocator:       {:.2} GB/s ({:.2} ms avg)",
+        default_throughput,
+        default_time.as_secs_f64() * 1000.0 / iterations as f64
+    );
+    println!(
+        "  Pinned allocator (host): {:.2} GB/s ({:.2} ms avg)",
+        pinned_throughput,
+        pinned_time.as_secs_f64() * 1000.0 / iterations as f64
+    );
+    println!(
+        "  Device allocator (H2D):  {:.2} GB/s ({:.2} ms avg)",
+        device_throughput,
+        device_time.as_secs_f64() * 1000.0 / iterations as f64
+    );
+    println!();
+    println!("Ratios vs default:");
+    println!(
+        "  Pinned (host): {:.2}x",
+        pinned_throughput / default_throughput.max(0.001)
+    );
+    println!(
+        "  Device (H2D):  {:.2}x",
+        device_throughput / default_throughput.max(0.001)
+    );
+    println!();
+}
+
+fn all_benchmarks(c: &mut Criterion) {
+    // Print quick summary first
+    print_scan_comparison();
+
+    // Run detailed benchmarks
+    bench_scan_default(c);
+    bench_scan_pinned(c);
+    bench_scan_device(c);
+}
+
+criterion_group!(benches, all_benchmarks);
+criterion_main!(benches);
diff --git a/vortex-cuda/src/device_buffer.rs b/vortex-cuda/src/device_buffer.rs
new file mode 100644
index 00000000000..4d76d712fb5
--- /dev/null
+++ b/vortex-cuda/src/device_buffer.rs
@@ -0,0 +1,124 @@
+// SPDX-License-Identifier: Apache-2.0
+// SPDX-FileCopyrightText: Copyright the Vortex contributors
+
+use std::fmt;
+use std::hash::Hash;
+use std::hash::Hasher;
+use std::ops::Range;
+use std::sync::Arc;
+
+use cudarc::driver::CudaEvent;
+use cudarc::driver::CudaSlice;
+use cudarc::driver::CudaStream;
+use vortex_array::buffer::DeviceBuffer;
+use vortex_buffer::Alignment;
+use vortex_buffer::ByteBuffer;
+use vortex_buffer::ByteBufferMut;
+use vortex_error::VortexResult;
+use vortex_error::vortex_err;
+use vortex_error::vortex_panic;
+
+use crate::PooledPinnedBuffer;
+
+/// A device buffer backed by CUDA device memory.
+pub struct CudaDeviceBuffer {
+    data: Arc<CudaSlice<u8>>,
+    offset: usize,
+    len: usize,
+    stream: Arc<CudaStream>,
+    completion: Arc<CudaEvent>,
+    host_buffer: Arc<parking_lot::Mutex<Option<PooledPinnedBuffer>>>,
+}
+
+impl CudaDeviceBuffer {
+    pub fn new(
+        data: Arc<CudaSlice<u8>>,
+        offset: usize,
+        len: usize,
+        stream: Arc<CudaStream>,
+        completion: CudaEvent,
+        host_buffer: PooledPinnedBuffer,
+    ) -> Self {
+        Self {
+            data,
+            offset,
+            len,
+            stream,
+            completion: Arc::new(completion),
+            host_buffer: Arc::new(parking_lot::Mutex::new(Some(host_buffer))),
+        }
+    }
+
+    fn view(&self) -> cudarc::driver::CudaView<'_, u8> {
+        self.data.slice(self.offset..self.offset + self.len)
+    }
+}
+
+impl DeviceBuffer for CudaDeviceBuffer {
+    fn len(&self) -> usize {
+        self.len
+    }
+
+    fn copy_to_host(&self) -> VortexResult<ByteBuffer> {
+        let mut host = ByteBufferMut::with_capacity_aligned(self.len, Alignment::of::<u8>());
+        unsafe { host.set_len(self.len) };
+        self.stream
+            .memcpy_dtoh(&self.view(), host.as_mut_slice())
+            .map_err(|e| vortex_err!("Failed to copy from device: {e}"))?;
+        Ok(host.freeze())
+    }
+
+    fn slice(&self, range: Range<usize>) -> Arc<dyn DeviceBuffer> {
+        if range.start > range.end || range.end > self.len {
+            vortex_panic!(
+                "range out of bounds: {}..{} for length {}",
+                range.start,
+                range.end,
+                self.len
+            );
+        }
+        Arc::new(Self {
+            data: self.data.clone(),
+            offset: self.offset + range.start,
+            len: range.end - range.start,
+            stream: self.stream.clone(),
+            completion: self.completion.clone(),
+            host_buffer: self.host_buffer.clone(),
+        })
+    }
+}
+
+impl Drop for CudaDeviceBuffer {
+    fn drop(&mut self) {
+        let _ = self.completion.synchronize();
+        if let Some(buffer) = self.host_buffer.lock().take() {
+            drop(buffer);
+        }
+    }
+}
+
+impl PartialEq for CudaDeviceBuffer {
+    fn eq(&self, other: &Self) -> bool {
+        Arc::ptr_eq(&self.data, &other.data) && self.offset == other.offset && self.len == other.len
+    }
+}
+
+impl Eq for CudaDeviceBuffer {}
+
+impl Hash for CudaDeviceBuffer {
+    fn hash<H: Hasher>(&self, state: &mut H) {
+        let ptr = Arc::as_ptr(&self.data) as usize;
+        ptr.hash(state);
+        self.offset.hash(state);
+        self.len.hash(state);
+    }
+}
+
+impl fmt::Debug for CudaDeviceBuffer {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        f.debug_struct("CudaDeviceBuffer")
+            .field("offset", &self.offset)
+            .field("len", &self.len)
+            .finish()
+    }
+}
diff --git a/vortex-cuda/src/lib.rs b/vortex-cuda/src/lib.rs
index ca798939493..03075a16371 100644
--- a/vortex-cuda/src/lib.rs
+++ b/vortex-cuda/src/lib.rs
@@ -3,9 +3,12 @@
 //! CUDA support for Vortex arrays.
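+//!
+//! Also provides pinned host-buffer pooling ([`PinnedByteBufferPool`]) and allocators
+//! ([`PinnedBufferAllocator`], [`PinnedDeviceAllocator`]) that stage reads in page-locked
+//! memory for device transfer.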
 
+mod device_buffer;
 pub mod executor;
 mod for_;
 mod kernel;
+pub mod pinned;
+pub mod pinned_allocator;
 mod session;
 
 use std::process::Command;
@@ -13,7 +16,13 @@ use std::process::Command;
 pub use executor::CudaExecutionCtx;
 pub use executor::CudaKernelEvents;
 use for_::ForExecutor;
+pub use pinned::PinnedByteBuffer;
+pub use pinned::PinnedByteBufferPool;
+pub use pinned::PooledPinnedBuffer;
+pub use pinned_allocator::PinnedBufferAllocator;
+pub use pinned_allocator::PinnedDeviceAllocator;
 pub use session::CudaSession;
+pub use session::CudaSessionExt;
 
 /// Check if the NVIDIA CUDA Compiler is available.
 pub fn has_nvcc() -> bool {
diff --git a/vortex-cuda/src/pinned.rs b/vortex-cuda/src/pinned.rs
new file mode 100644
index 00000000000..3a1acd43484
--- /dev/null
+++ b/vortex-cuda/src/pinned.rs
@@ -0,0 +1,321 @@
+// SPDX-License-Identifier: Apache-2.0
+// SPDX-FileCopyrightText: Copyright the Vortex contributors
+
+use std::sync::Arc;
+
+use bytes::Bytes;
+use cudarc::driver::CudaContext;
+use cudarc::driver::CudaStream;
+use cudarc::driver::HostSlice;
+use cudarc::driver::PinnedHostSlice;
+use cudarc::driver::SyncOnDrop;
+use parking_lot::Mutex;
+use vortex_buffer::ByteBuffer;
+use vortex_error::VortexResult;
+use vortex_error::vortex_err;
+use vortex_error::vortex_panic;
+use vortex_utils::aliases::hash_map::HashMap;
+
+/// A page-locked host buffer allocated by CUDA.
+///
+/// This is intended as a staging buffer for H2D transfers. Contents are uninitialized after
+/// allocation.
+pub struct PinnedByteBuffer {
+    inner: PinnedHostSlice<u8>,
+}
+
+#[allow(clippy::same_name_method)]
+impl PinnedByteBuffer {
+    /// Allocate a pinned host buffer with uninitialized contents.
+    ///
+    /// # Safety
+    /// The returned buffer's contents are uninitialized. The caller must initialize them
+    /// before reading.
+    pub unsafe fn uninit(ctx: &Arc<CudaContext>, len: usize) -> VortexResult<Self> {
+        let inner = unsafe {
+            ctx.alloc_pinned::<u8>(len)
+                .map_err(|e| vortex_err!("failed to allocate pinned host buffer: {e}"))?
+        };
+        Ok(Self { inner })
+    }
+
+    /// Returns the length of the buffer in bytes.
+    pub fn len(&self) -> usize {
+        self.inner.len()
+    }
+
+    /// Returns true if the buffer is empty.
+    pub fn is_empty(&self) -> bool {
+        self.inner.is_empty()
+    }
+
+    /// Returns the buffer as an immutable slice.
+    pub fn as_slice(&self) -> VortexResult<&[u8]> {
+        self.inner
+            .as_slice()
+            .map_err(|e| vortex_err!("failed to access pinned host buffer: {e}"))
+    }
+
+    /// Returns the buffer as a mutable slice.
+    pub fn as_mut_slice(&mut self) -> VortexResult<&mut [u8]> {
+        self.inner
+            .as_mut_slice()
+            .map_err(|e| vortex_err!("failed to access pinned host buffer: {e}"))
+    }
+
+    /// Returns a raw pointer to the buffer.
+    pub fn as_ptr(&self) -> VortexResult<*const u8> {
+        self.inner
+            .as_ptr()
+            .map_err(|e| vortex_err!("failed to access pinned host buffer: {e}"))
+    }
+
+    /// Returns a mutable raw pointer to the buffer.
+    pub fn as_mut_ptr(&mut self) -> VortexResult<*mut u8> {
+        self.inner
+            .as_mut_ptr()
+            .map_err(|e| vortex_err!("failed to access pinned host buffer: {e}"))
+    }
+
+    /// Returns the CUDA context that owns this allocation.
+    pub fn context(&self) -> &Arc<CudaContext> {
+        self.inner.context()
+    }
+}
+
+#[allow(clippy::same_name_method)]
+impl HostSlice<u8> for PinnedByteBuffer {
+    fn len(&self) -> usize {
+        self.len()
+    }
+
+    unsafe fn stream_synced_slice<'a>(
+        &'a self,
+        stream: &'a CudaStream,
+    ) -> (&'a [u8], SyncOnDrop<'a>) {
+        unsafe { <PinnedHostSlice<u8> as HostSlice<u8>>::stream_synced_slice(&self.inner, stream) }
+    }
+
+    unsafe fn stream_synced_mut_slice<'a>(
+        &'a mut self,
+        stream: &'a CudaStream,
+    ) -> (&'a mut [u8], SyncOnDrop<'a>) {
+        unsafe {
+            <PinnedHostSlice<u8> as HostSlice<u8>>::stream_synced_mut_slice(&mut self.inner, stream)
+        }
+    }
+}
+
+/// A simple pinned host buffer pool keyed by allocation size.
+pub struct PinnedByteBufferPool {
+    ctx: Arc<CudaContext>,
+    max_keep_per_size: usize,
+    buckets: Mutex<HashMap<usize, Vec<PinnedByteBuffer>>>,
+}
+
+impl PinnedByteBufferPool {
+    /// Create a new pool with default limits.
+    pub fn new(ctx: Arc<CudaContext>) -> Self {
+        Self::with_limits(ctx, 4)
+    }
+
+    /// Create a new pool with a maximum number of cached buffers per size.
+    pub fn with_limits(ctx: Arc<CudaContext>, max_keep_per_size: usize) -> Self {
+        Self {
+            ctx,
+            max_keep_per_size: max_keep_per_size.max(1),
+            buckets: Mutex::new(HashMap::new()),
+        }
+    }
+
+    /// Acquire a pinned buffer of the given size in bytes.
+    pub fn get(&self, len: usize) -> VortexResult<PinnedByteBuffer> {
+        let mut buckets = self.buckets.lock();
+        if let Some(bucket) = buckets.get_mut(&len)
+            && let Some(buf) = bucket.pop()
+        {
+            return Ok(buf);
+        }
+        unsafe { PinnedByteBuffer::uninit(&self.ctx, len) }
+    }
+
+    /// Return a buffer to the pool.
+    pub fn put(&self, buf: PinnedByteBuffer) -> VortexResult<()> {
+        let len = buf.len();
+        let mut buckets = self.buckets.lock();
+        let bucket = buckets.entry(len).or_default();
+        if bucket.len() < self.max_keep_per_size {
+            bucket.push(buf);
+        }
+        Ok(())
+    }
+
+    /// Get a pooled pinned buffer that will be returned to the pool on drop.
+    pub fn get_pooled(self: &Arc<Self>, len: usize) -> VortexResult<PooledPinnedBuffer> {
+        let inner = self.get(len)?;
+        Ok(PooledPinnedBuffer {
+            inner: Some(inner),
+            pool: self.clone(),
+        })
+    }
+}
+
+/// A pinned buffer that is returned to its pool when dropped.
+///
+/// This wrapper owns a [`PinnedByteBuffer`] and ensures it gets returned to the
+/// [`PinnedByteBufferPool`] when the buffer is no longer needed. This enables efficient
+/// buffer reuse for I/O operations.
+pub struct PooledPinnedBuffer {
+    inner: Option<PinnedByteBuffer>,
+    pool: Arc<PinnedByteBufferPool>,
+}
+
+#[allow(clippy::same_name_method)]
+impl PooledPinnedBuffer {
+    /// Create a new pooled buffer.
+    pub fn new(inner: PinnedByteBuffer, pool: Arc<PinnedByteBufferPool>) -> Self {
+        Self {
+            inner: Some(inner),
+            pool,
+        }
+    }
+
+    /// Returns the length of the buffer in bytes.
+    pub fn len(&self) -> usize {
+        self.inner
+            .as_ref()
+            .map(|b| b.len())
+            .unwrap_or_else(|| vortex_panic!("buffer already consumed"))
+    }
+
+    /// Returns true if the buffer is empty.
+    pub fn is_empty(&self) -> bool {
+        self.len() == 0
+    }
+
+    /// Returns the buffer as a mutable slice.
+    ///
+    /// # Panics
+    ///
+    /// Panics if the buffer has already been consumed or if the CUDA context is invalid.
+    pub fn as_mut_slice(&mut self) -> &mut [u8] {
+        let inner = self
+            .inner
+            .as_mut()
+            .unwrap_or_else(|| vortex_panic!("buffer already consumed"));
+        inner
+            .as_mut_slice()
+            .unwrap_or_else(|e| vortex_panic!("failed to access pinned host buffer: {e}"))
+    }
+
+    /// Convert this pooled buffer into a [`ByteBuffer`].
+    ///
+    /// The returned buffer will return the underlying pinned memory to the pool when dropped.
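+    /// Internally the pinned allocation is handed to `Bytes::from_owner` rather than copied.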
+    /// This enables zero-copy conversion to the standard Vortex buffer type while maintaining
+    /// pool-based memory reuse.
+    pub fn into_byte_buffer(mut self) -> ByteBuffer {
+        let inner = self
+            .inner
+            .take()
+            .unwrap_or_else(|| vortex_panic!("buffer already consumed"));
+        let len = inner.len();
+        let pool = self.pool.clone();
+
+        // Create a wrapper that will return the buffer to the pool on drop
+        let wrapper = PooledPinnedBufferOwner::new(inner, pool);
+
+        // Use Bytes::from_owner to create a Bytes that owns the wrapper
+        let bytes = Bytes::from_owner(wrapper);
+
+        // The ByteBuffer should have the full length
+        assert_eq!(bytes.len(), len);
+
+        ByteBuffer::from(bytes)
+    }
+}
+
+#[allow(clippy::same_name_method)]
+impl HostSlice<u8> for PooledPinnedBuffer {
+    fn len(&self) -> usize {
+        self.len()
+    }
+
+    unsafe fn stream_synced_slice<'a>(
+        &'a self,
+        stream: &'a CudaStream,
+    ) -> (&'a [u8], SyncOnDrop<'a>) {
+        let inner = self
+            .inner
+            .as_ref()
+            .unwrap_or_else(|| vortex_panic!("buffer already consumed"));
+        unsafe { HostSlice::stream_synced_slice(inner, stream) }
+    }
+
+    unsafe fn stream_synced_mut_slice<'a>(
+        &'a mut self,
+        stream: &'a CudaStream,
+    ) -> (&'a mut [u8], SyncOnDrop<'a>) {
+        let inner = self
+            .inner
+            .as_mut()
+            .unwrap_or_else(|| vortex_panic!("buffer already consumed"));
+        unsafe { HostSlice::stream_synced_mut_slice(inner, stream) }
+    }
+}
+
+impl Drop for PooledPinnedBuffer {
+    fn drop(&mut self) {
+        if let Some(inner) = self.inner.take() {
+            // Return the buffer to the pool, ignoring errors
+            drop(self.pool.put(inner));
+        }
+    }
+}
+
+/// Internal wrapper that owns a [`PinnedByteBuffer`] and returns it to the pool on drop.
+///
+/// This is used by `Bytes::from_owner` to manage the lifecycle of pooled pinned buffers.
+struct PooledPinnedBufferOwner {
+    // We use Option so we can take the buffer out in Drop
+    inner: Option<PinnedByteBuffer>,
+    // Cached pointer and length for the AsRef implementation
+    ptr: *const u8,
+    len: usize,
+    pool: Arc<PinnedByteBufferPool>,
+}
+
+// SAFETY: The pinned buffer is allocated by CUDA and is safe to send across threads.
+// The pointer is derived from the buffer and remains valid as long as the buffer exists.
+unsafe impl Send for PooledPinnedBufferOwner {}
+unsafe impl Sync for PooledPinnedBufferOwner {}
+
+impl PooledPinnedBufferOwner {
+    fn new(inner: PinnedByteBuffer, pool: Arc<PinnedByteBufferPool>) -> Self {
+        let ptr = inner
+            .as_ptr()
+            .unwrap_or_else(|e| vortex_panic!("failed to get pointer to pinned buffer: {e}"));
+        let len = inner.len();
+        Self {
+            inner: Some(inner),
+            ptr,
+            len,
+            pool,
+        }
+    }
+}
+
+impl AsRef<[u8]> for PooledPinnedBufferOwner {
+    fn as_ref(&self) -> &[u8] {
+        // SAFETY: The pointer and length were captured when the buffer was created
+        // and remain valid as long as this struct exists (the buffer is kept alive in `self.inner`).
+        unsafe { std::slice::from_raw_parts(self.ptr, self.len) }
+    }
+}
+
+impl Drop for PooledPinnedBufferOwner {
+    fn drop(&mut self) {
+        // Take the buffer out and return it to the pool
+        if let Some(buffer) = self.inner.take() {
+            drop(self.pool.put(buffer));
+        }
+    }
+}
diff --git a/vortex-cuda/src/pinned_allocator.rs b/vortex-cuda/src/pinned_allocator.rs
new file mode 100644
index 00000000000..805dae91070
--- /dev/null
+++ b/vortex-cuda/src/pinned_allocator.rs
@@ -0,0 +1,127 @@
+// SPDX-License-Identifier: Apache-2.0
+// SPDX-FileCopyrightText: Copyright the Vortex contributors
+
+use std::sync::Arc;
+
+use cudarc::driver::CudaStream;
+use vortex_array::buffer::BufferHandle;
+use vortex_buffer::Alignment;
+use vortex_error::VortexResult;
+use vortex_error::vortex_err;
+use vortex_io::BufferAllocator;
+use vortex_io::WriteTarget;
+use vortex_session::VortexSession;
+
+use crate::PinnedByteBufferPool;
+use crate::PooledPinnedBuffer;
+use crate::device_buffer::CudaDeviceBuffer;
+use crate::session::CudaSessionExt;
+
+/// Allocator that sources buffers from a CUDA pinned pool.
+pub struct PinnedBufferAllocator {
+    pool: Arc<PinnedByteBufferPool>,
+}
+
+impl PinnedBufferAllocator {
+    pub fn new(pool: Arc<PinnedByteBufferPool>) -> Self {
+        Self { pool }
+    }
+}
+
+impl BufferAllocator for PinnedBufferAllocator {
+    fn allocate(&self, len: usize, _alignment: Alignment) -> VortexResult<Box<dyn WriteTarget>> {
+        let buffer = self.pool.get_pooled(len)?;
+        Ok(Box::new(buffer))
+    }
+}
+
+impl WriteTarget for PooledPinnedBuffer {
+    fn as_mut_slice(&mut self) -> &mut [u8] {
+        PooledPinnedBuffer::as_mut_slice(self)
+    }
+
+    fn len(&self) -> usize {
+        PooledPinnedBuffer::len(self)
+    }
+
+    fn into_handle(self: Box<Self>) -> VortexResult<BufferHandle> {
+        Ok(BufferHandle::new_host(self.into_byte_buffer()))
+    }
+}
+
+/// Allocator that reads into pinned buffers and transfers to device memory.
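+///
+/// Each `allocate` call hands out a pinned staging buffer; `into_handle` then allocates
+/// device memory, issues an async `memcpy_htod` on the allocator's stream, and records a
+/// completion event. A sketch of the intended flow (names from this crate; `segment_bytes`
+/// is a stand-in for whatever the reader writes):
+/// ```ignore
+/// let alloc = PinnedDeviceAllocator::new(pool, stream);
+/// let mut target = alloc.allocate(len, Alignment::of::<u8>())?; // pinned staging buffer
+/// target.as_mut_slice().copy_from_slice(&segment_bytes);        // I/O writes into pinned memory
+/// let handle = target.into_handle()?;                           // async H2D + recorded event
+/// ```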
+pub struct PinnedDeviceAllocator {
+    pool: Arc<PinnedByteBufferPool>,
+    stream: Arc<CudaStream>,
+}
+
+impl PinnedDeviceAllocator {
+    pub fn new(pool: Arc<PinnedByteBufferPool>, stream: Arc<CudaStream>) -> Self {
+        Self { pool, stream }
+    }
+
+    pub fn from_session(
+        pool: Arc<PinnedByteBufferPool>,
+        session: &VortexSession,
+    ) -> VortexResult<Self> {
+        let stream = session.cuda_session().new_stream()?;
+        Ok(Self::new(pool, stream))
+    }
+
+    pub fn synchronize(&self) -> VortexResult<()> {
+        self.stream
+            .synchronize()
+            .map_err(|e| vortex_err!("Failed to synchronize CUDA stream: {e}"))
+    }
+}
+
+impl BufferAllocator for PinnedDeviceAllocator {
+    fn allocate(&self, len: usize, _alignment: Alignment) -> VortexResult<Box<dyn WriteTarget>> {
+        let buffer = self.pool.get_pooled(len)?;
+        Ok(Box::new(PinnedDeviceWriteTarget {
+            buffer,
+            stream: self.stream.clone(),
+        }))
+    }
+}
+
+struct PinnedDeviceWriteTarget {
+    buffer: PooledPinnedBuffer,
+    stream: Arc<CudaStream>,
+}
+
+impl WriteTarget for PinnedDeviceWriteTarget {
+    fn as_mut_slice(&mut self) -> &mut [u8] {
+        self.buffer.as_mut_slice()
+    }
+
+    fn len(&self) -> usize {
+        self.buffer.len()
+    }
+
+    fn into_handle(self: Box<Self>) -> VortexResult<BufferHandle> {
+        let len = self.buffer.len();
+        let mut device = unsafe { self.stream.alloc::<u8>(len) }
+            .map_err(|e| vortex_err!("Failed to allocate device memory: {e}"))?;
+
+        self.stream
+            .memcpy_htod(&self.buffer, &mut device)
+            .map_err(|e| vortex_err!("Failed to copy to device: {e}"))?;
+
+        let event = self
+            .stream
+            .record_event(None)
+            .map_err(|e| vortex_err!("Failed to record CUDA event: {e}"))?;
+
+        let device_buffer = CudaDeviceBuffer::new(
+            Arc::new(device),
+            0,
+            len,
+            self.stream.clone(),
+            event,
+            self.buffer,
+        );
+
+        Ok(BufferHandle::new_device(Arc::new(device_buffer)))
+    }
+}
diff --git a/vortex-cuda/src/session.rs b/vortex-cuda/src/session.rs
index 6b33d10e753..bbdd502059b 100644
--- a/vortex-cuda/src/session.rs
+++ b/vortex-cuda/src/session.rs
@@ -5,6 +5,7 @@ use std::fmt::Debug;
 use std::sync::Arc;
 
 use cudarc::driver::CudaContext;
+use cudarc::driver::CudaStream;
 use vortex_array::vtable::ArrayId;
 use vortex_dtype::PType;
 use vortex_error::VortexResult;
@@ -50,6 +51,18 @@ impl CudaSession {
         Ok(CudaExecutionCtx::new(stream, vortex_session))
     }
 
+    /// Creates a new CUDA stream.
+    pub fn new_stream(&self) -> VortexResult<Arc<CudaStream>> {
+        self.context
+            .new_stream()
+            .map_err(|e| vortex_err!("Failed to create CUDA stream: {}", e))
+    }
+
+    /// Returns the CUDA context.
+    pub fn context(&self) -> &Arc<CudaContext> {
+        &self.context
+    }
+
     /// Registers CUDA support for an array encoding.
     ///
     /// # Arguments
diff --git a/vortex-file/src/open.rs b/vortex-file/src/open.rs
index 03232856531..6a8b23aecef 100644
--- a/vortex-file/src/open.rs
+++ b/vortex-file/src/open.rs
@@ -12,6 +12,7 @@ use vortex_dtype::DType;
 use vortex_error::VortexError;
 use vortex_error::VortexExpect;
 use vortex_error::VortexResult;
+use vortex_io::BufferAllocator;
 use vortex_io::InstrumentedReadAt;
 use vortex_io::VortexReadAt;
 use vortex_io::session::RuntimeSessionExt;
@@ -53,6 +54,8 @@ pub struct VortexOpenOptions {
     footer: Option