16 changes: 8 additions & 8 deletions src/doc/crate_feature_flags.rs
@@ -14,9 +14,14 @@
//! ## `serde`
//! - Enables serialization support for serde 1.x
//!
//! ## `rayon`
//! - Enables parallel iterators, parallelized methods, the [`parallel`] module and [`par_azip!`].
//! - Implies std
#![cfg_attr(
not(feature = "rayon"),
doc = "//! ## `rayon`\n//! - Enables parallel iterators, parallelized methods, and the `par_azip!` macro.\n//! - Implies std\n"
)]
#![cfg_attr(
feature = "rayon",
doc = "//! ## `rayon`\n//! - Enables parallel iterators, parallelized methods, the [`crate::parallel`] module and [`crate::parallel::par_azip`].\n//! - Implies std\n"
)]
//!
//! ## `approx`
//! - Enables implementations of traits of the [`approx`] crate.
@@ -28,8 +33,3 @@
//!
//! ## `matrixmultiply-threading`
//! - Enable the ``threading`` feature in the matrixmultiply package
//!
//! [`parallel`]: crate::parallel

#[cfg(doc)]
use crate::parallel::par_azip;
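Both this file and `src/lib.rs` below rely on conditional doc attributes: the same feature bullet is always emitted, but the variant containing intra-doc links to `crate::parallel` and `par_azip` is only used when the `rayon` feature is actually enabled, so rustdoc never sees a link to an item that does not exist. A minimal sketch of the mechanism, using a hypothetical `fast` feature and `speedy` module rather than the PR's actual strings:

```rust
// Crate-root sketch of the conditional-doc pattern used in this PR.
// The `fast` feature and `speedy` module are hypothetical names for illustration.
#![cfg_attr(
    not(feature = "fast"),
    doc = "- `fast`: enables the `speedy` module."
)]
#![cfg_attr(
    feature = "fast",
    doc = "- `fast`: enables the [`speedy`](crate::speedy) module."
)]

// The module the linked doc line refers to only exists when the feature is on,
// which is exactly why the link itself must also be feature-gated.
#[cfg(feature = "fast")]
pub mod speedy
{
    //! Feature-gated functionality.
}
```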
13 changes: 10 additions & 3 deletions src/lib.rs
@@ -84,11 +84,18 @@
//! ## Crate Feature Flags
//!
//! The following crate feature flags are available. They are configured in your
//! `Cargo.toml`. See [`doc::crate_feature_flags`] for more information.
//! `Cargo.toml`. See [`crate::doc::crate_feature_flags`] for more information.
//!
//! - `std`: Rust standard library-using functionality (enabled by default)
//! - `serde`: serialization support for serde 1.x
//! - `rayon`: Parallel iterators, parallelized methods, the [`parallel`] module and [`par_azip!`].
#![cfg_attr(
not(feature = "rayon"),
doc = "//! - `rayon`: Parallel iterators, parallelized methods, and the `par_azip!` macro."
)]
#![cfg_attr(
feature = "rayon",
doc = "//! - `rayon`: Parallel iterators, parallelized methods, the [`parallel`] module and [`par_azip!`]."
)]
//! - `approx` Implementations of traits from the [`approx`] crate.
//! - `blas`: transparent BLAS support for matrix multiplication, needs configuration.
//! - `matrixmultiply-threading`: Use threading from `matrixmultiply`.
@@ -129,7 +136,7 @@ extern crate std;
#[cfg(feature = "blas")]
extern crate cblas_sys;

#[cfg(docsrs)]
#[cfg(any(doc, docsrs))]
pub mod doc;

use alloc::fmt::Debug;
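The `#[cfg(docsrs)]` → `#[cfg(any(doc, docsrs))]` change makes the `doc` module visible to any rustdoc invocation, not only docs.rs builds. A brief sketch of the distinction (the module name here is illustrative, not the crate's):

```rust
// `doc` is set automatically by rustdoc, while `docsrs` is a user-defined cfg
// that docs.rs conventionally passes via RUSTDOCFLAGS="--cfg docsrs".
// Accepting either means a plain local `cargo doc` also renders the module.
#[cfg(any(doc, docsrs))]
pub mod documentation_only
{
    //! Never compiled into ordinary builds; exists purely for rendered docs.
}
```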
114 changes: 57 additions & 57 deletions src/linalg/impl_linalg.rs
@@ -968,6 +968,63 @@ where
is_blas_2d(a._dim(), a._strides(), BlasOrder::F)
}

/// Dot product for dynamic-dimensional arrays (`ArrayD`).
///
/// For one-dimensional arrays, computes the vector dot product, which is the sum
/// of the elementwise products (no conjugation of complex operands).
/// Both arrays must have the same length.
///
/// For two-dimensional arrays, performs matrix multiplication. The array shapes
/// must be compatible in the following ways:
/// - If `self` is *M* × *N*, then `rhs` must be *N* × *K* for matrix-matrix multiplication
/// - If `self` is *M* × *N* and `rhs` is *N*, returns a vector of length *M*
/// - If `self` is *M* and `rhs` is *M* × *N*, returns a vector of length *N*
/// - If both arrays are one-dimensional of length *N*, returns a scalar
///
/// **Panics** if:
/// - The arrays have dimensions other than 1 or 2
/// - The array shapes are incompatible for the operation
/// - For vector dot product: the vectors have different lengths
impl<A> Dot<ArrayRef<A, IxDyn>> for ArrayRef<A, IxDyn>
where A: LinalgScalar
{
type Output = Array<A, IxDyn>;

fn dot(&self, rhs: &ArrayRef<A, IxDyn>) -> Self::Output
{
match (self.ndim(), rhs.ndim()) {
(1, 1) => {
let a = self.view().into_dimensionality::<Ix1>().unwrap();
let b = rhs.view().into_dimensionality::<Ix1>().unwrap();
let result = a.dot(&b);
ArrayD::from_elem(vec![], result)
}
(2, 2) => {
// Matrix-matrix multiplication
let a = self.view().into_dimensionality::<Ix2>().unwrap();
let b = rhs.view().into_dimensionality::<Ix2>().unwrap();
let result = a.dot(&b);
result.into_dimensionality::<IxDyn>().unwrap()
}
(2, 1) => {
// Matrix-vector multiplication
let a = self.view().into_dimensionality::<Ix2>().unwrap();
let b = rhs.view().into_dimensionality::<Ix1>().unwrap();
let result = a.dot(&b);
result.into_dimensionality::<IxDyn>().unwrap()
}
(1, 2) => {
// Vector-matrix multiplication
let a = self.view().into_dimensionality::<Ix1>().unwrap();
let b = rhs.view().into_dimensionality::<Ix2>().unwrap();
let result = a.dot(&b);
result.into_dimensionality::<IxDyn>().unwrap()
}
_ => panic!("Dot product for ArrayD is only supported for 1D and 2D arrays"),
}
}
}

#[cfg(test)]
#[cfg(feature = "blas")]
mod blas_tests
@@ -1083,60 +1140,3 @@ mod blas_tests
}
}
}

/// Dot product for dynamic-dimensional arrays (`ArrayD`).
///
/// For one-dimensional arrays, computes the vector dot product, which is the sum
/// of the elementwise products (no conjugation of complex operands).
/// Both arrays must have the same length.
///
/// For two-dimensional arrays, performs matrix multiplication. The array shapes
/// must be compatible in the following ways:
/// - If `self` is *M* × *N*, then `rhs` must be *N* × *K* for matrix-matrix multiplication
/// - If `self` is *M* × *N* and `rhs` is *N*, returns a vector of length *M*
/// - If `self` is *M* and `rhs` is *M* × *N*, returns a vector of length *N*
/// - If both arrays are one-dimensional of length *N*, returns a scalar
///
/// **Panics** if:
/// - The arrays have dimensions other than 1 or 2
/// - The array shapes are incompatible for the operation
/// - For vector dot product: the vectors have different lengths
impl<A> Dot<ArrayRef<A, IxDyn>> for ArrayRef<A, IxDyn>
where A: LinalgScalar
{
type Output = Array<A, IxDyn>;

fn dot(&self, rhs: &ArrayRef<A, IxDyn>) -> Self::Output
{
match (self.ndim(), rhs.ndim()) {
(1, 1) => {
let a = self.view().into_dimensionality::<Ix1>().unwrap();
let b = rhs.view().into_dimensionality::<Ix1>().unwrap();
let result = a.dot(&b);
ArrayD::from_elem(vec![], result)
}
(2, 2) => {
// Matrix-matrix multiplication
let a = self.view().into_dimensionality::<Ix2>().unwrap();
let b = rhs.view().into_dimensionality::<Ix2>().unwrap();
let result = a.dot(&b);
result.into_dimensionality::<IxDyn>().unwrap()
}
(2, 1) => {
// Matrix-vector multiplication
let a = self.view().into_dimensionality::<Ix2>().unwrap();
let b = rhs.view().into_dimensionality::<Ix1>().unwrap();
let result = a.dot(&b);
result.into_dimensionality::<IxDyn>().unwrap()
}
(1, 2) => {
// Vector-matrix multiplication
let a = self.view().into_dimensionality::<Ix1>().unwrap();
let b = rhs.view().into_dimensionality::<Ix2>().unwrap();
let result = a.dot(&b);
result.into_dimensionality::<IxDyn>().unwrap()
}
_ => panic!("Dot product for ArrayD is only supported for 1D and 2D arrays"),
}
}
}
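A minimal usage sketch of the `Dot` impl added above, assuming this branch's ndarray (where `ArrayD` dereferences to `ArrayRef<A, IxDyn>`, so the impl applies through method resolution); the values and shapes are illustrative only:

```rust
use ndarray::{ArrayD, IxDyn};

fn main()
{
    // 2-D · 2-D: (2×3) × (3×2) -> (2×2) matrix product
    let a = ArrayD::from_shape_vec(IxDyn(&[2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
    let b = ArrayD::from_shape_vec(IxDyn(&[3, 2]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
    assert_eq!(a.dot(&b).shape(), &[2, 2]);

    // 2-D · 1-D: (2×3) × (3,) -> (2,)
    let v = ArrayD::from_shape_vec(IxDyn(&[3]), vec![1.0, 0.0, 1.0]).unwrap();
    assert_eq!(a.dot(&v).shape(), &[2]);

    // 1-D · 1-D: vector dot product, returned as a 0-dimensional ArrayD
    let w = ArrayD::from_shape_vec(IxDyn(&[3]), vec![4.0, 5.0, 6.0]).unwrap();
    let s = v.dot(&w);
    assert_eq!(s.ndim(), 0);
    assert_eq!(s.sum(), 10.0); // 1*4 + 0*5 + 1*6

    // Operands with dimensionality other than 1 or 2 panic, per the doc comment above.
}
```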
1 change: 1 addition & 0 deletions src/zip/mod.rs
@@ -424,6 +424,7 @@ where D: Dimension
}

#[cfg(feature = "rayon")]
#[allow(dead_code)]
pub(crate) fn uninitialized_for_current_layout<T>(&self) -> Array<MaybeUninit<T>, D>
{
let is_f = self.prefer_f();
5 changes: 3 additions & 2 deletions tests/array.rs
@@ -4,6 +4,7 @@
)]

use approx::assert_relative_eq;
use core::panic;
use defmac::defmac;
#[allow(deprecated)]
use itertools::{zip, Itertools};
@@ -1005,7 +1006,7 @@ fn iter_size_hint()
fn zero_axes()
{
let mut a = arr1::<f32>(&[]);
if let Some(_) = a.iter().next() {
if a.iter().next().is_some() {
panic!();
}
a.map(|_| panic!());
@@ -2080,7 +2081,7 @@ fn test_contiguous()
assert!(c.as_slice_memory_order().is_some());
let v = c.slice(s![.., 0..1, ..]);
assert!(!v.is_standard_layout());
assert!(!v.as_slice_memory_order().is_some());
assert!(v.as_slice_memory_order().is_none());

let v = c.slice(s![1..2, .., ..]);
assert!(v.is_standard_layout());
6 changes: 4 additions & 2 deletions tests/iterators.rs
@@ -195,7 +195,7 @@ fn inner_iter_corner_cases()
assert_equal(a0.rows(), vec![aview1(&[0])]);

let a2 = ArcArray::<i32, _>::zeros((0, 3));
assert_equal(a2.rows(), vec![aview1(&[]); 0]);
assert_equal(a2.rows(), Vec::<ArrayView1<'_, i32>>::new());

let a2 = ArcArray::<i32, _>::zeros((3, 0));
assert_equal(a2.rows(), vec![aview1(&[]); 3]);
@@ -359,11 +359,13 @@ fn axis_iter_zip_partially_consumed_discontiguous()
}
}

use ndarray::ArrayView1;

#[test]
fn outer_iter_corner_cases()
{
let a2 = ArcArray::<i32, _>::zeros((0, 3));
assert_equal(a2.outer_iter(), vec![aview1(&[]); 0]);
assert_equal(a2.outer_iter(), Vec::<ArrayView1<'_, i32>>::new());

let a2 = ArcArray::<i32, _>::zeros((3, 0));
assert_equal(a2.outer_iter(), vec![aview1(&[]); 3]);
2 changes: 1 addition & 1 deletion tests/numeric.rs
@@ -175,7 +175,7 @@ fn var_too_large_ddof()
fn var_nan_ddof()
{
let a = Array2::<f64>::zeros((2, 3));
let v = a.var(std::f64::NAN);
let v = a.var(f64::NAN);
assert!(v.is_nan());
}

8 changes: 4 additions & 4 deletions tests/oper.rs
@@ -19,20 +19,20 @@ fn test_oper(op: &str, a: &[f32], b: &[f32], c: &[f32])
let aa = CowArray::from(arr1(a));
let bb = CowArray::from(arr1(b));
let cc = CowArray::from(arr1(c));
test_oper_arr::<f32, _>(op, aa.clone(), bb.clone(), cc.clone());
test_oper_arr(op, aa.clone(), bb.clone(), cc.clone());
let dim = (2, 2);
let aa = aa.to_shape(dim).unwrap();
let bb = bb.to_shape(dim).unwrap();
let cc = cc.to_shape(dim).unwrap();
test_oper_arr::<f32, _>(op, aa.clone(), bb.clone(), cc.clone());
test_oper_arr(op, aa.clone(), bb.clone(), cc.clone());
let dim = (1, 2, 1, 2);
let aa = aa.to_shape(dim).unwrap();
let bb = bb.to_shape(dim).unwrap();
let cc = cc.to_shape(dim).unwrap();
test_oper_arr::<f32, _>(op, aa.clone(), bb.clone(), cc.clone());
test_oper_arr(op, aa.clone(), bb.clone(), cc.clone());
}

fn test_oper_arr<A, D>(op: &str, mut aa: CowArray<f32, D>, bb: CowArray<f32, D>, cc: CowArray<f32, D>)
fn test_oper_arr<D>(op: &str, mut aa: CowArray<f32, D>, bb: CowArray<f32, D>, cc: CowArray<f32, D>)
where D: Dimension
{
match op {