From c3fa0e5264956b39313974d9750e58eb986131d8 Mon Sep 17 00:00:00 2001 From: Jessica Priebe Date: Tue, 6 Jan 2026 11:43:06 +0100 Subject: [PATCH 01/12] add SparseRow tests --- .../sysds/runtime/data/SparseRowVector.java | 6 +- .../test/component/sparse/SparseRowTest.java | 215 ++++++++++++++++++ 2 files changed, 220 insertions(+), 1 deletion(-) create mode 100644 src/test/java/org/apache/sysds/test/component/sparse/SparseRowTest.java diff --git a/src/main/java/org/apache/sysds/runtime/data/SparseRowVector.java b/src/main/java/org/apache/sysds/runtime/data/SparseRowVector.java index 50229e15df6..e59bf2402ee 100644 --- a/src/main/java/org/apache/sysds/runtime/data/SparseRowVector.java +++ b/src/main/java/org/apache/sysds/runtime/data/SparseRowVector.java @@ -190,6 +190,10 @@ public void setEstimatedNzs(int estnnz){ estimatedNzs = estnnz; } + public int getEstimatedNzs(){ + return estimatedNzs; + } + private void recap(int newCap) { if( newCap<=values.length ) return; @@ -314,7 +318,7 @@ public int searchIndexesFirstLTE(int col) { //search lt col index (see binary search) index = Math.abs( index+1 ); - return (index-1 < size) ? index-1 : -1; + return (index-1 >= 0) ? index-1 : -1; } @Override diff --git a/src/test/java/org/apache/sysds/test/component/sparse/SparseRowTest.java b/src/test/java/org/apache/sysds/test/component/sparse/SparseRowTest.java new file mode 100644 index 00000000000..d2f588361c6 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/component/sparse/SparseRowTest.java @@ -0,0 +1,215 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.component.sparse; + +import org.apache.sysds.runtime.data.SparseRow; +import org.apache.sysds.runtime.data.SparseRowScalar; +import org.apache.sysds.runtime.data.SparseRowVector; +import org.apache.sysds.runtime.util.UtilFunctions; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestUtils; +import org.junit.Test; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotEquals; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; + + +public class SparseRowTest extends AutomatedTestBase +{ + private final static int cols = 121; + private final static int minVal = -10; + private final static int maxVal = 10; + private final static double sparsity = 0.3; + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + } + + @Test + public void testSparseRowEmptyToString() { + SparseRowScalar srs = new SparseRowScalar(); + assertEquals("", srs.toString()); + } + + @Test + public void testSparseRowScalarInitZeroVal() { + SparseRowScalar srs = new SparseRowScalar(5, 0); + srs.compact(); + assertEquals(-1, srs.getIndex()); + } + + @Test + public void testSparseRowScalarSetNewVal() { + SparseRowScalar srs = new SparseRowScalar(); + assertTrue(srs.set(3, 5.0)); + } + + @Test + public void testSparseRowScalarInvalidSet() { + SparseRowScalar srs = new SparseRowScalar(1, 1.0); + RuntimeException ex = assertThrows(RuntimeException.class, () -> srs.set(3, 
5.0)); + assertEquals("Invalid set to sparse row scalar.", ex.getMessage()); + } + + @Test + public void testSparseRowScalarAppendZero() { + SparseRowScalar srs = new SparseRowScalar(1, 1.0); + SparseRow srs2 = srs.append(2, 0.0); + assertEquals(srs, srs2); + assertNotEquals(0, srs2.values()[0]); + } + + @Test + public void testSparseRowScalarCompactZero() { + SparseRowScalar srs = new SparseRowScalar(1, 0.0); + srs.compact(); + assertEquals(-1, srs.getIndex()); + } + + @Test + public void testSparseRowScalarCompactNonZero() { + SparseRowScalar srs = new SparseRowScalar(1, 1.0); + srs.compact(); + assertEquals(1, srs.getIndex()); + } + + @Test + public void testSparseRowScalarCopy() { + SparseRowScalar srs = new SparseRowScalar(1, 1.0); + SparseRowScalar srs2 = (SparseRowScalar) srs.copy(true); + assertEquals(srs.getIndex(), srs2.getIndex()); + assertEquals(srs.getValue(), srs2.getValue(), 0.0); + assertNotEquals(srs, srs2); + } + + @Test + public void testSparseRowVectorSetValues() { + double[] v = getRandomMatrix(1, cols, minVal, maxVal, sparsity, 7)[0]; + SparseRowVector srv = new SparseRowVector(UtilFunctions.computeNnz(v, 0, v.length), v, v.length); + + srv.compact(); + int nnz = srv.size(); + double[] w = getRandomMatrix(1, nnz, minVal, maxVal, 1, 13)[0]; + srv.setValues(w); + + assertArrayEquals(w, srv.values(), 0.0); + assertEquals(srv.indexes().length, srv.values().length); + } + + @Test + public void testSparseRowVectorSetIndexes() { + double[] v = getRandomMatrix(1, cols, minVal, maxVal, 1, 7)[0]; + int nnz = UtilFunctions.computeNnz(v, 0, v.length); + SparseRowVector srv = new SparseRowVector(nnz, v, v.length); + + int[] indexes = new int[nnz]; + for(int i = 0; i < nnz; i++) indexes[i] = i; + srv.setIndexes(indexes); + + int idx = (int)(Math.random() * nnz); + assertEquals(idx, srv.getIndex(idx)); + assertEquals(-1, srv.getIndex(nnz)); + assertEquals(srv.values().length, srv.indexes().length); + } + + @Test + public void 
testSparseRowVectorCopyFromLargerArray() { + double[] v = getRandomMatrix(1, cols, minVal, maxVal, sparsity, 7)[0]; + double[] w = getRandomMatrix(1, 2*cols, minVal, maxVal, sparsity, 7)[0]; + SparseRowVector srv = new SparseRowVector(UtilFunctions.computeNnz(v, 0, v.length), v, v.length); + SparseRowVector other = new SparseRowVector(UtilFunctions.computeNnz(w, 0, w.length), w, w.length); + srv.copy(other); + + assertArrayEquals(other.indexes(), srv.indexes()); + assertArrayEquals(other.values(), srv.values(), 0.0); + assertNotEquals(other, srv); + } + + @Test + public void testSparseRowVectorSetEstimatedNzs() { + double[] v = getRandomMatrix(1, cols, minVal, maxVal, sparsity, 7)[0]; + int nnz = UtilFunctions.computeNnz(v, 0, v.length); + SparseRowVector srv = new SparseRowVector(nnz, v, v.length); + srv.setEstimatedNzs(nnz+1); + assertEquals(nnz+1, srv.getEstimatedNzs()); + } + + @Test + public void testSparseRowVectorSetAtPos() { + double[] v = getRandomMatrix(1, cols, minVal, maxVal, sparsity, 7)[0]; + int nnz = UtilFunctions.computeNnz(v, 0, v.length); + SparseRowVector srv = new SparseRowVector(nnz, v, v.length); + + int pos = nnz-1; + int col = 2; + double val = 2.0; + srv.setAtPos(pos, col, val); + + assertEquals(col, srv.indexes()[pos]); + assertEquals(val, srv.indexes()[pos],0.0); + } + + @Test + public void testSparseRowVectorGetIndex() { + double[] v = getRandomMatrix(1, cols, minVal, maxVal, sparsity, 7)[0]; + int nnz = UtilFunctions.computeNnz(v, 0, v.length); + SparseRowVector srv = new SparseRowVector(nnz, v, v.length); + + int pos = 0; + srv.setAtPos(pos, 5, 2.0); + int index = srv.getIndex(5); + assertEquals(pos, index); + + int col2 = cols+1; + int index2 = srv.getIndex(col2); + assertEquals(-1, index2); + } + + @Test + public void testSparseRowVectorSearchIndexesFirstLTESizeZero() { + SparseRowVector srv = new SparseRowVector(); + int index = srv.searchIndexesFirstLTE(1); + assertEquals(-1, index); + } + + @Test + public void 
testSparseRowVectorSearchIndexesFirstLTENotFound() { + SparseRowVector srv = new SparseRowVector(new double[] {1.0, 3.0}, new int[] {1, 3}); + int index = srv.searchIndexesFirstLTE(0); + assertEquals(-1, index); + int index2 = srv.searchIndexesFirstLTE(2); + assertEquals(0, index2); + int index3 = srv.searchIndexesFirstLTE(5); + assertEquals(1, index3); + } + + @Test + public void testSparseRowVectorSetIndexRangeWithRecap() { + SparseRowVector srv = new SparseRowVector(); + srv.add(1, 1.0); + srv.add(4, 4.0); + srv.add(5, 5.0); + srv.setIndexRange(2, 3, new double[]{2.0, 3.0}, 0, 2); + } +} From 7326de47784bb9faa4e51bf059982ac77cbf7d4e Mon Sep 17 00:00:00 2001 From: Jessica Priebe Date: Tue, 6 Jan 2026 12:20:25 +0100 Subject: [PATCH 02/12] add SparseBlock tests --- .../sysds/runtime/data/SparseBlock.java | 39 ++- .../sparse/SparseBlockAlignment.java | 13 + .../sparse/SparseBlockContainsTest.java | 251 +++++++++++++++ .../sparse/SparseBlockEqualsTest.java | 222 +++++++++++++ .../component/sparse/SparseBlockIterator.java | 300 +++++++++++++++--- 5 files changed, 765 insertions(+), 60 deletions(-) create mode 100644 src/test/java/org/apache/sysds/test/component/sparse/SparseBlockContainsTest.java create mode 100644 src/test/java/org/apache/sysds/test/component/sparse/SparseBlockEqualsTest.java diff --git a/src/main/java/org/apache/sysds/runtime/data/SparseBlock.java b/src/main/java/org/apache/sysds/runtime/data/SparseBlock.java index 864569358f6..29ee6c79f78 100644 --- a/src/main/java/org/apache/sysds/runtime/data/SparseBlock.java +++ b/src/main/java/org/apache/sysds/runtime/data/SparseBlock.java @@ -26,6 +26,7 @@ import java.util.List; import org.apache.sysds.runtime.matrix.data.IJV; +import org.apache.sysds.runtime.util.UtilFunctions; /** * This SparseBlock is an abstraction for different sparse matrix formats. 
@@ -501,16 +502,20 @@ public boolean contains(double pattern, int rl, int ru) { } public List contains(double[] pattern, boolean earlyAbort) { + int nnz = UtilFunctions.computeNnz(pattern, 0, pattern.length); List ret = new ArrayList<>(); int rlen = numRows(); + for( int i=0; i0 && posFIndexGTE(_curRow, cl) < 0)) ) + while(_curRow < _rlen){ + if(isEmpty(_curRow)){ + _curRow++; + continue; + } + + int pos = (cl == 0)? 0 : posFIndexGTE(_curRow, cl); + if(pos < 0){ + _curRow++; + continue; + } + + int sizeRow = size(_curRow); + int endPos = (_cu == Integer.MAX_VALUE)? sizeRow : posFIndexGTE(_curRow, _cu); + if(endPos < 0) endPos = sizeRow; + + if(pos < endPos){ + _curColIx = pos(_curRow)+pos; + _curIndexes = indexes(_curRow); + _curValues = values(_curRow); + return; + } _curRow++; - if(_curRow >= _rlen) - _noNext = true; - else { - _curColIx = (cl==0) ? - pos(_curRow) : posFIndexGTE(_curRow, cl); - _curIndexes = indexes(_curRow); - _curValues = values(_curRow); } + _noNext = true; } } diff --git a/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockAlignment.java b/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockAlignment.java index 3c2ed30adc6..f1f044baa5e 100644 --- a/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockAlignment.java +++ b/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockAlignment.java @@ -276,4 +276,17 @@ else if( i<37 ) {//CSR/COO different after update pos throw new RuntimeException(ex); } } + + @Test + public void testSparseBlockDifferentNumRows() { + double[][] A = getRandomMatrix(rows, cols, -10, 10, sparsity3, 1234); + MatrixBlock mbtmp = DataConverter.convertToMatrixBlock(A); + SparseBlock sblock = mbtmp.getSparseBlock(); + + double[][] B = getRandomMatrix(2*rows, cols, -10, 10, sparsity3, 1234); + MatrixBlock mbtmp2 = DataConverter.convertToMatrixBlock(B); + SparseBlock sblock2 = mbtmp2.getSparseBlock(); + + Assert.assertFalse(sblock.isAligned(sblock2)); + } } diff --git 
a/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockContainsTest.java b/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockContainsTest.java new file mode 100644 index 00000000000..59b74f40804 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockContainsTest.java @@ -0,0 +1,251 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysds.test.component.sparse; + +import org.apache.sysds.runtime.data.SparseBlock; +import org.apache.sysds.runtime.data.SparseBlockFactory; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.util.DataConverter; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestUtils; +import org.junit.Test; + +import java.util.List; + +import static org.junit.Assert.assertEquals; + +public class SparseBlockContainsTest extends AutomatedTestBase +{ + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + } + + @Test + public void testSparseBlockContainsNoMatchCOO() { + runSparseBlockContainsNoMatchTest(SparseBlock.Type.COO); + } + + @Test + public void testSparseBlockContainsNoMatchCSC() { + runSparseBlockContainsNoMatchTest(SparseBlock.Type.CSC); + } + + @Test + public void testSparseBlockContainsNoMatchCSR() { + runSparseBlockContainsNoMatchTest(SparseBlock.Type.CSR); + } + + @Test + public void testSparseBlockContainsNoMatchDCSR() { + runSparseBlockContainsNoMatchTest(SparseBlock.Type.DCSR); + } + + @Test + public void testSparseBlockContainsNoMatchMCSC() { + runSparseBlockContainsNoMatchTest(SparseBlock.Type.MCSC); + } + + @Test + public void testSparseBlockContainsNoMatchMCSR() { + runSparseBlockContainsNoMatchTest(SparseBlock.Type.MCSR); + } + + @Test + public void testSparseBlockContainsNaNCOO() { + runSparseBlockContainsNaNTest(SparseBlock.Type.COO); + } + + @Test + public void testSparseBlockContainsNaNCSC() { + runSparseBlockContainsNaNTest(SparseBlock.Type.CSC); + } + + @Test + public void testSparseBlockContainsNaNCSR() { + runSparseBlockContainsNaNTest(SparseBlock.Type.CSR); + } + + @Test + public void testSparseBlockContainsNaNDCSR() { + runSparseBlockContainsNaNTest(SparseBlock.Type.DCSR); + } + + @Test + public void testSparseBlockContainsNaNMCSC() { + runSparseBlockContainsNaNTest(SparseBlock.Type.MCSC); + } + + @Test + public void 
testSparseBlockContainsNaNMCSR() { + runSparseBlockContainsNaNTest(SparseBlock.Type.MCSR); + } + + @Test + public void testSparseBlockContainsEarlyAbortCOO() { + runSparseBlockContainsEarlyAbortTest(SparseBlock.Type.COO); + } + + @Test + public void testSparseBlockContainsEarlyAbortCSC() { + runSparseBlockContainsEarlyAbortTest(SparseBlock.Type.CSC); + } + + @Test + public void testSparseBlockContainsEarlyAbortCSR() { + runSparseBlockContainsEarlyAbortTest(SparseBlock.Type.CSR); + } + + @Test + public void testSparseBlockContainsEarlyAbortDCSR() { + runSparseBlockContainsEarlyAbortTest(SparseBlock.Type.DCSR); + } + + @Test + public void testSparseBlockContainsEarlyAbortMCSC() { + runSparseBlockContainsEarlyAbortTest(SparseBlock.Type.MCSC); + } + + @Test + public void testSparseBlockContainsEarlyAbortMCSR() { + runSparseBlockContainsEarlyAbortTest(SparseBlock.Type.MCSR); + } + + @Test + public void testSparseBlockContainsPatternLongerThanRowsCOO() { + runSparseBlockContainsPatternLongerThanRowsTest(SparseBlock.Type.COO); + } + + @Test + public void testSparseBlockContainsPatternLongerThanRowsCSC() { + runSparseBlockContainsPatternLongerThanRowsTest(SparseBlock.Type.CSC); + } + + @Test + public void testSparseBlockContainsPatternLongerThanRowsCSR() { + runSparseBlockContainsPatternLongerThanRowsTest(SparseBlock.Type.CSR); + } + + @Test + public void testSparseBlockContainsPatternLongerThanRowsDCSR() { + runSparseBlockContainsPatternLongerThanRowsTest(SparseBlock.Type.DCSR); + } + + @Test + public void testSparseBlockContainsPatternLongerThanRowsMCSC() { + runSparseBlockContainsPatternLongerThanRowsTest(SparseBlock.Type.MCSC); + } + + @Test + public void testSparseBlockContainsPatternLongerThanRowsMCSR() { + runSparseBlockContainsPatternLongerThanRowsTest(SparseBlock.Type.MCSR); + } + + @Test + public void testSparseBlockContainsPatternContainsZeroCOO() { + runSparseBlockContainsPatternContainsZeroTest(SparseBlock.Type.COO); + } + + @Test + public void 
testSparseBlockContainsPatternContainsZeroCSC() { + runSparseBlockContainsPatternContainsZeroTest(SparseBlock.Type.CSC); + } + + @Test + public void testSparseBlockContainsPatternContainsZeroCSR() { + runSparseBlockContainsPatternContainsZeroTest(SparseBlock.Type.CSR); + } + + @Test + public void testSparseBlockContainsPatternContainsZeroDCSR() { + runSparseBlockContainsPatternContainsZeroTest(SparseBlock.Type.DCSR); + } + + @Test + public void testSparseBlockContainsPatternContainsZeroMCSC() { + runSparseBlockContainsPatternContainsZeroTest(SparseBlock.Type.MCSC); + } + + @Test + public void testSparseBlockContainsPatternContainsZeroMCSR() { + runSparseBlockContainsPatternContainsZeroTest(SparseBlock.Type.MCSR); + } + + private void runSparseBlockContainsNoMatchTest(SparseBlock.Type btype) { + double[] pattern = new double[]{1., 2., 3.}; + double[][] A = new double[][]{{4., 5., 6.}, {7., 8., 9.}, {0., 0., 0.}, {0., 0., 0.}, {0., 0., 0.}, {0., 0., 0.}}; + + MatrixBlock mbtmp = DataConverter.convertToMatrixBlock(A); + SparseBlock srtmp = mbtmp.getSparseBlock(); + SparseBlock sblock = SparseBlockFactory.copySparseBlock(btype, srtmp, true); + + List result = sblock.contains(pattern, false); + assertEquals(List.of(), result); + } + + private void runSparseBlockContainsNaNTest(SparseBlock.Type btype) { + double[] pattern = new double[]{Double.NaN, 2., 3.}; + double[][] A = new double[][]{{Double.NaN, 2., 3.}, {1., 2., 3.}, {0., 0., 0.}, {0., 0., 0.}, {0., 0., 0.}, {0., 0., 0.}}; + + MatrixBlock mbtmp = DataConverter.convertToMatrixBlock(A); + SparseBlock srtmp = mbtmp.getSparseBlock(); + SparseBlock sblock = SparseBlockFactory.copySparseBlock(btype, srtmp, true); + + List result = sblock.contains(pattern, false); + assertEquals(List.of(0), result); + } + + private void runSparseBlockContainsEarlyAbortTest(SparseBlock.Type btype) { + double[] pattern = new double[]{1., 2., 3.}; + double[][] A = new double[][]{{0., 0., 0.}, {1., 2., 3.}, {1., 2., 3.}, {0., 0., 0.}, {0., 
0., 0.}, {0., 0., 0.}}; + + MatrixBlock mbtmp = DataConverter.convertToMatrixBlock(A); + SparseBlock srtmp = mbtmp.getSparseBlock(); + SparseBlock sblock = SparseBlockFactory.copySparseBlock(btype, srtmp, true); + + List result = sblock.contains(pattern, true); + assertEquals(List.of(1), result); + } + + private void runSparseBlockContainsPatternLongerThanRowsTest(SparseBlock.Type btype) { + double[] pattern = new double[]{1., 2., 3., 4.}; + double[][] A = new double[][]{{0., 0., 0.}, {1., 2., 3.}, {1., 2., 3.}, {0., 0., 0.}, {0., 0., 0.}, {0., 0., 0.}}; + + MatrixBlock mbtmp = DataConverter.convertToMatrixBlock(A); + SparseBlock srtmp = mbtmp.getSparseBlock(); + SparseBlock sblock = SparseBlockFactory.copySparseBlock(btype, srtmp, true); + + List result = sblock.contains(pattern, false); + assertEquals(List.of(), result); + } + + private void runSparseBlockContainsPatternContainsZeroTest(SparseBlock.Type btype) { + double[] pattern = new double[]{0., 1., 2.}; + double[][] A = new double[][]{{0., 1., 2.}, {0., 0., 0.}, {0., 0., 0.}, {0., 0., 0.}, {0., 1., 2.}, {1., 2., 0.}}; + + MatrixBlock mbtmp = DataConverter.convertToMatrixBlock(A); + SparseBlock srtmp = mbtmp.getSparseBlock(); + SparseBlock sblock = SparseBlockFactory.copySparseBlock(btype, srtmp, true); + + List result = sblock.contains(pattern, false); + assertEquals(List.of(0, 4), result); + } +} diff --git a/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockEqualsTest.java b/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockEqualsTest.java new file mode 100644 index 00000000000..29f7c7e2463 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockEqualsTest.java @@ -0,0 +1,222 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.component.sparse; + +import org.apache.sysds.runtime.data.SparseBlock; +import org.apache.sysds.runtime.data.SparseBlockFactory; +import org.apache.sysds.runtime.data.SparseRowVector; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.util.DataConverter; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestUtils; + +import org.junit.Test; +import org.junit.experimental.runners.Enclosed; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import java.util.ArrayList; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotEquals; + +@RunWith(Enclosed.class) +public class SparseBlockEqualsTest { + + @RunWith(Parameterized.class) + public static class SparseBlockEqualsSparseBlockTest extends AutomatedTestBase { + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + } + + private final SparseBlock.Type type1; + private final SparseBlock.Type type2; + + public SparseBlockEqualsSparseBlockTest(SparseBlock.Type type1, SparseBlock.Type type2) { + this.type1 = type1; + this.type2 = type2; + } + + @Parameterized.Parameters(name = "{0} vs {1}") + public static Iterable types() { + SparseBlock.Type[] types = SparseBlock.Type.values(); + ArrayList params = new 
ArrayList<>(); + + for (int i = 0; i < types.length; i++) { + for (int j = i; j < types.length; j++) { + params.add(new Object[]{types[i], types[j]}); + } + } + + return params; + } + + @Test + public void testSparseBlockEquals() { + runSparseBlockEqualsTest(type1, type2); + } + + @Test + public void testSparseBlockNotEqualsColIdx() { + runSparseBlockNotEqualsColIdxTest(type1, type2); + } + + @Test + public void testSparseBlockNotEqualsEmptyRow() { + runSparseBlockNotEqualsEmptyRowTest(type1, type2); + } + } + + @RunWith(Parameterized.class) + public static class SparseBlockEqualsDenseValuesTest extends AutomatedTestBase { + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + } + + private final SparseBlock.Type type; + + public SparseBlockEqualsDenseValuesTest(SparseBlock.Type type) { + this.type = type; + } + + @Parameterized.Parameters(name = "{0}") + public static Iterable types() { + ArrayList params = new ArrayList<>(); + for (SparseBlock.Type t : SparseBlock.Type.values()) { + params.add(new Object[]{t}); + } + return params; + } + + @Test + public void testSparseBlockNotEqualsNonSparseBlock() { + runSparseBlockNotEqualsNonSparseBlockTest(type); + } + + @Test + public void testSparseBlockNotEqualsDenseValuesEmptyRow() { + runSparseBlockNotEqualsDenseValuesEmptyRowTest(type); + } + + @Test + public void testSparseBlockNotEqualsDenseValuesNonZero() { + runSparseBlockNotEqualsDenseValuesNonZeroTest(type); + } + + @Test + public void testSparseBlockNotEqualsDenseValuesAdditionalNonZero() { + runSparseBlockNotEqualsDenseValuesAdditionalNonZeroTest(type); + } + } + + private static void runSparseBlockEqualsTest(SparseBlock.Type type1, SparseBlock.Type type2) { + double[][] A = new double[][]{{1., 2., 3.}, {0., 0., 0.}, {0., 4., 0.}, {0., 0., 5.}, {6., 0., 0.}, {0., 0., 7.}}; + + MatrixBlock mbtmp = DataConverter.convertToMatrixBlock(A); + SparseBlock srtmp = mbtmp.getSparseBlock(); + SparseBlock sblock1 = 
SparseBlockFactory.copySparseBlock(type1, srtmp, true); + SparseBlock sblock2 = SparseBlockFactory.copySparseBlock(type2, srtmp, true); + + assertEquals(sblock1, sblock2); + } + + private static void runSparseBlockNotEqualsColIdxTest(SparseBlock.Type type1, SparseBlock.Type type2) { + double[][] A = new double[][]{{1., 2., 3.}, {0., 0., 0.}, {0., 4., 0.}, {0., 0., 5.}, {6., 0., 0.}, {0., 0., 7.}}; + double[][] B = new double[][]{{1., 2., 3.}, {0., 0., 0.}, {0., 0., 4.}, {0., 0., 5.}, {6., 0., 0.}, {0., 0., 7.}}; + + MatrixBlock mbtmp1 = DataConverter.convertToMatrixBlock(A); + SparseBlock srtmp1 = mbtmp1.getSparseBlock(); + SparseBlock sblock1 = SparseBlockFactory.copySparseBlock(type1, srtmp1, true); + + MatrixBlock mbtmp2 = DataConverter.convertToMatrixBlock(B); + SparseBlock srtmp2 = mbtmp2.getSparseBlock(); + SparseBlock sblock2 = SparseBlockFactory.copySparseBlock(type2, srtmp2, true); + + assertNotEquals("should not be equal: " + type1 + " " + type2, sblock1, sblock2); + } + + private static void runSparseBlockNotEqualsEmptyRowTest(SparseBlock.Type type1, SparseBlock.Type type2) { + double[][] A = new double[][]{{1., 2., 3.}, {0., 0., 0.}, {0., 4., 0.}, {0., 0., 5.}, {6., 0., 0.}, {0., 0., 7.}}; + double[][] B = new double[][]{{1., 2., 3.}, {0., 4., 0.}, {0., 0., 0.}, {0., 0., 5.}, {6., 0., 0.}, {0., 0., 7.}}; + + MatrixBlock mbtmp1 = DataConverter.convertToMatrixBlock(A); + SparseBlock srtmp1 = mbtmp1.getSparseBlock(); + SparseBlock sblock1 = SparseBlockFactory.copySparseBlock(type1, srtmp1, true); + + MatrixBlock mbtmp2 = DataConverter.convertToMatrixBlock(B); + SparseBlock srtmp2 = mbtmp2.getSparseBlock(); + SparseBlock sblock2 = SparseBlockFactory.copySparseBlock(type2, srtmp2, true); + + assertNotEquals("should not be equal: " + type1 + " " + type2, sblock1, sblock2); + } + + private static void runSparseBlockNotEqualsNonSparseBlockTest(SparseBlock.Type type) { + double[][] A = new double[][]{{1., 0., 3.}, {0., 0., 0.}, {0., 0., 0.}, {0., 0., 0.}}; + + 
MatrixBlock mbtmp = DataConverter.convertToMatrixBlock(A); + SparseBlock srtmp = mbtmp.getSparseBlock(); + SparseBlock sblock = SparseBlockFactory.copySparseBlock(type, srtmp, true); + + SparseRowVector srv = new SparseRowVector(A[0], new int[]{0, 1, 2}); + + assertNotEquals("should not be equal: " + type, sblock, srv); + } + + private static void runSparseBlockNotEqualsDenseValuesEmptyRowTest(SparseBlock.Type type) { + double[][] A = new double[][]{{1., 0., 3.}, {0., 0., 0.}, {0., 0., 0.}, {4., 0., 6.}}; + + MatrixBlock mbtmp = DataConverter.convertToMatrixBlock(A); + SparseBlock srtmp = mbtmp.getSparseBlock(); + SparseBlock sblock = SparseBlockFactory.copySparseBlock(type, srtmp, true); + + double[] denseValues = new double[]{1., 0., 3., 0., 0., 0., 1., 1., 1., 4., 0., 6.}; + + assertFalse("should not be equal: " + type, sblock.equals(denseValues, 3, 1e-10)); + } + + private static void runSparseBlockNotEqualsDenseValuesNonZeroTest(SparseBlock.Type type) { + double[][] A = new double[][]{{1., 0., 3.}, {0., 0., 0.}, {0., 0., 0.},{0., 0., 1.}, {4., 0., 6.}}; + + MatrixBlock mbtmp = DataConverter.convertToMatrixBlock(A); + SparseBlock srtmp = mbtmp.getSparseBlock(); + SparseBlock sblock = SparseBlockFactory.copySparseBlock(type, srtmp, true); + + double[] denseValues = new double[]{1., 0., 3., 0., 0., 0., 0., 0., 0., 0., 1., 1., 4., 0., 6.}; + + assertFalse("should not be equal: " + type, sblock.equals(denseValues, 3, 1e-10)); + } + + private static void runSparseBlockNotEqualsDenseValuesAdditionalNonZeroTest(SparseBlock.Type type) { + double[][] A = new double[][]{{1., 0., 3.}, {0., 0., 0.}, {0., 0., 0.}, {4., 0., 0.}}; + + MatrixBlock mbtmp = DataConverter.convertToMatrixBlock(A); + SparseBlock srtmp = mbtmp.getSparseBlock(); + SparseBlock sblock = SparseBlockFactory.copySparseBlock(type, srtmp, true); + + double[] denseValues = new double[]{1., 0., 3., 0., 0., 0., 0., 0., 0., 4., 0., 6.}; + + assertFalse("should not be equal: " + type, sblock.equals(denseValues, 
3, 1e-10)); + } +} diff --git a/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockIterator.java b/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockIterator.java index 57cee617fbb..f1429c4d650 100644 --- a/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockIterator.java +++ b/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockIterator.java @@ -42,7 +42,10 @@ public class SparseBlockIterator extends AutomatedTestBase { private final static int rows = 324; private final static int cols = 100; - private final static int rlPartial = 134; + private final static int rlVal = 134; + private final static int ruVal = 253; + private final static int clVal = 34; + private final static int cuVal = 53; private final static double sparsity1 = 0.1; private final static double sparsity2 = 0.2; private final static double sparsity3 = 0.3; @@ -54,187 +57,367 @@ public void setUp() { @Test public void testSparseBlockMCSR1Full() { - runSparseBlockIteratorTest(SparseBlock.Type.MCSR, sparsity1, false); + runSparseBlockIteratorTest(SparseBlock.Type.MCSR, sparsity1, false, false); } @Test public void testSparseBlockMCSR2Full() { - runSparseBlockIteratorTest(SparseBlock.Type.MCSR, sparsity2, false); + runSparseBlockIteratorTest(SparseBlock.Type.MCSR, sparsity2, false, false); } @Test public void testSparseBlockMCSR3Full() { - runSparseBlockIteratorTest(SparseBlock.Type.MCSR, sparsity3, false); + runSparseBlockIteratorTest(SparseBlock.Type.MCSR, sparsity3, false, false); + } + + @Test + public void testSparseBlockMCSR1RlPartial() { + runSparseBlockIteratorTest(SparseBlock.Type.MCSR, sparsity1, true, false); + } + + @Test + public void testSparseBlockMCSR2RlPartial() { + runSparseBlockIteratorTest(SparseBlock.Type.MCSR, sparsity2, true, false); + } + + @Test + public void testSparseBlockMCSR3RlPartial() { + runSparseBlockIteratorTest(SparseBlock.Type.MCSR, sparsity3, true, false); + } + + @Test + public void testSparseBlockMCSR1RuPartial() 
{ + runSparseBlockIteratorTest(SparseBlock.Type.MCSR, sparsity1, false, true); + } + + @Test + public void testSparseBlockMCSR2RuPartial() { + runSparseBlockIteratorTest(SparseBlock.Type.MCSR, sparsity2, false, true); + } + + @Test + public void testSparseBlockMCSR3RuPartial() { + runSparseBlockIteratorTest(SparseBlock.Type.MCSR, sparsity3, false, true); } @Test public void testSparseBlockMCSR1Partial() { - runSparseBlockIteratorTest(SparseBlock.Type.MCSR, sparsity1, true); + runSparseBlockIteratorTest(SparseBlock.Type.MCSR, sparsity1, true, true); } @Test public void testSparseBlockMCSR2Partial() { - runSparseBlockIteratorTest(SparseBlock.Type.MCSR, sparsity2, true); + runSparseBlockIteratorTest(SparseBlock.Type.MCSR, sparsity2, true, true); } @Test public void testSparseBlockMCSR3Partial() { - runSparseBlockIteratorTest(SparseBlock.Type.MCSR, sparsity3, true); + runSparseBlockIteratorTest(SparseBlock.Type.MCSR, sparsity3, true, true); } @Test public void testSparseBlockCSR1Full() { - runSparseBlockIteratorTest(SparseBlock.Type.CSR, sparsity1, false); + runSparseBlockIteratorTest(SparseBlock.Type.CSR, sparsity1, false, false); } @Test public void testSparseBlockCSR2Full() { - runSparseBlockIteratorTest(SparseBlock.Type.CSR, sparsity2, false); + runSparseBlockIteratorTest(SparseBlock.Type.CSR, sparsity2, false, false); } @Test public void testSparseBlockCSR3Full() { - runSparseBlockIteratorTest(SparseBlock.Type.CSR, sparsity3, false); + runSparseBlockIteratorTest(SparseBlock.Type.CSR, sparsity3, false, false); + } + + @Test + public void testSparseBlockCSR1RlPartial() { + runSparseBlockIteratorTest(SparseBlock.Type.CSR, sparsity1, true, false); + } + + @Test + public void testSparseBlockCSR2RlPartial() { + runSparseBlockIteratorTest(SparseBlock.Type.CSR, sparsity2, true, false); + } + + @Test + public void testSparseBlockCSR3RlPartial() { + runSparseBlockIteratorTest(SparseBlock.Type.CSR, sparsity3, true, false); + } + + @Test + public void 
testSparseBlockCSR1RuPartial() { + runSparseBlockIteratorTest(SparseBlock.Type.CSR, sparsity1, false, true); + } + + @Test + public void testSparseBlockCSR2RuPartial() { + runSparseBlockIteratorTest(SparseBlock.Type.CSR, sparsity2, false, true); + } + + @Test + public void testSparseBlockCSR3RuPartial() { + runSparseBlockIteratorTest(SparseBlock.Type.CSR, sparsity3, false, true); } @Test public void testSparseBlockCSR1Partial() { - runSparseBlockIteratorTest(SparseBlock.Type.CSR, sparsity1, true); + runSparseBlockIteratorTest(SparseBlock.Type.CSR, sparsity1, true, true); } @Test public void testSparseBlockCSR2Partial() { - runSparseBlockIteratorTest(SparseBlock.Type.CSR, sparsity2, true); + runSparseBlockIteratorTest(SparseBlock.Type.CSR, sparsity2, true, true); } @Test public void testSparseBlockCSR3Partial() { - runSparseBlockIteratorTest(SparseBlock.Type.CSR, sparsity3, true); + runSparseBlockIteratorTest(SparseBlock.Type.CSR, sparsity3, true, true); } @Test public void testSparseBlockCOO1Full() { - runSparseBlockIteratorTest(SparseBlock.Type.COO, sparsity1, false); + runSparseBlockIteratorTest(SparseBlock.Type.COO, sparsity1, false, false); } @Test public void testSparseBlockCOO2Full() { - runSparseBlockIteratorTest(SparseBlock.Type.COO, sparsity2, false); + runSparseBlockIteratorTest(SparseBlock.Type.COO, sparsity2, false, false); } @Test public void testSparseBlockCOO3Full() { - runSparseBlockIteratorTest(SparseBlock.Type.COO, sparsity3, false); + runSparseBlockIteratorTest(SparseBlock.Type.COO, sparsity3, false, false); + } + + @Test + public void testSparseBlockCOO1RlPartial() { + runSparseBlockIteratorTest(SparseBlock.Type.COO, sparsity1, true, false); + } + + @Test + public void testSparseBlockCOO2RlPartial() { + runSparseBlockIteratorTest(SparseBlock.Type.COO, sparsity2, true, false); + } + + @Test + public void testSparseBlockCOO3RlPartial() { + runSparseBlockIteratorTest(SparseBlock.Type.COO, sparsity3, true, false); + } + + @Test + public void 
testSparseBlockCOO1RuPartial() { + runSparseBlockIteratorTest(SparseBlock.Type.COO, sparsity1, false, true); + } + + @Test + public void testSparseBlockCOO2RuPartial() { + runSparseBlockIteratorTest(SparseBlock.Type.COO, sparsity2, false, true); + } + + @Test + public void testSparseBlockCOO3RuPartial() { + runSparseBlockIteratorTest(SparseBlock.Type.COO, sparsity3, false, true); } @Test public void testSparseBlockCOO1Partial() { - runSparseBlockIteratorTest(SparseBlock.Type.COO, sparsity1, true); + runSparseBlockIteratorTest(SparseBlock.Type.COO, sparsity1, true, true); } @Test public void testSparseBlockCOO2Partial() { - runSparseBlockIteratorTest(SparseBlock.Type.COO, sparsity2, true); + runSparseBlockIteratorTest(SparseBlock.Type.COO, sparsity2, true, true); } @Test public void testSparseBlockCOO3Partial() { - runSparseBlockIteratorTest(SparseBlock.Type.COO, sparsity3, true); + runSparseBlockIteratorTest(SparseBlock.Type.COO, sparsity3, true, true); } @Test public void testSparseBlockDCSR1Full() { - runSparseBlockIteratorTest(SparseBlock.Type.DCSR, sparsity1, false); + runSparseBlockIteratorTest(SparseBlock.Type.DCSR, sparsity1, false, false); } @Test public void testSparseBlockDCSR2Full() { - runSparseBlockIteratorTest(SparseBlock.Type.DCSR, sparsity2, false); + runSparseBlockIteratorTest(SparseBlock.Type.DCSR, sparsity2, false, false); } @Test public void testSparseBlockDCSR3Full() { - runSparseBlockIteratorTest(SparseBlock.Type.DCSR, sparsity3, false); + runSparseBlockIteratorTest(SparseBlock.Type.DCSR, sparsity3, false, false); + } + + @Test + public void testSparseBlockDCSR1RlPartial() { + runSparseBlockIteratorTest(SparseBlock.Type.DCSR, sparsity1, true, false); + } + + @Test + public void testSparseBlockDCSR2RlPartial() { + runSparseBlockIteratorTest(SparseBlock.Type.DCSR, sparsity2, true, false); + } + + @Test + public void testSparseBlockDCSR3RlPartial() { + runSparseBlockIteratorTest(SparseBlock.Type.DCSR, sparsity3, true, false); + } + + @Test + 
public void testSparseBlockDCSR1RuPartial() { + runSparseBlockIteratorTest(SparseBlock.Type.DCSR, sparsity1, false, true); + } + + @Test + public void testSparseBlockDCSR2RuPartial() { + runSparseBlockIteratorTest(SparseBlock.Type.DCSR, sparsity2, false, true); + } + + @Test + public void testSparseBlockDCSR3RuPartial() { + runSparseBlockIteratorTest(SparseBlock.Type.DCSR, sparsity3, false, true); } @Test public void testSparseBlockDCSR1Partial() { - runSparseBlockIteratorTest(SparseBlock.Type.DCSR, sparsity1, true); + runSparseBlockIteratorTest(SparseBlock.Type.DCSR, sparsity1, true, true); } @Test public void testSparseBlockDCSR2Partial() { - runSparseBlockIteratorTest(SparseBlock.Type.DCSR, sparsity2, true); + runSparseBlockIteratorTest(SparseBlock.Type.DCSR, sparsity2, true, true); } @Test public void testSparseBlockDCSR3Partial() { - runSparseBlockIteratorTest(SparseBlock.Type.DCSR, sparsity3, true); + runSparseBlockIteratorTest(SparseBlock.Type.DCSR, sparsity3, true, true); } @Test public void testSparseBlockMCSC1Full() { - runSparseBlockIteratorTest(SparseBlock.Type.MCSC, sparsity1, false); + runSparseBlockIteratorTest(SparseBlock.Type.MCSC, sparsity1, false, false); } @Test public void testSparseBlockMCSC2Full() { - runSparseBlockIteratorTest(SparseBlock.Type.MCSC, sparsity2, false); + runSparseBlockIteratorTest(SparseBlock.Type.MCSC, sparsity2, false, false); } @Test public void testSparseBlockMCSC3Full() { - runSparseBlockIteratorTest(SparseBlock.Type.MCSC, sparsity3, false); + runSparseBlockIteratorTest(SparseBlock.Type.MCSC, sparsity3, false, false); + } + + @Test + public void testSparseBlockMCSC1RlPartial() { + runSparseBlockIteratorTest(SparseBlock.Type.MCSC, sparsity1, true, false); + } + + @Test + public void testSparseBlockMCSC2RlPartial() { + runSparseBlockIteratorTest(SparseBlock.Type.MCSC, sparsity2, true, false); + } + + @Test + public void testSparseBlockMCSC3RlPartial() { + runSparseBlockIteratorTest(SparseBlock.Type.MCSC, sparsity3, true, 
false); + } + + @Test + public void testSparseBlockMCSC1RuPartial() { + runSparseBlockIteratorTest(SparseBlock.Type.MCSC, sparsity1, false, true); + } + + @Test + public void testSparseBlockMCSC2RuPartial() { + runSparseBlockIteratorTest(SparseBlock.Type.MCSC, sparsity2, false, true); + } + + @Test + public void testSparseBlockMCSC3RuPartial() { + runSparseBlockIteratorTest(SparseBlock.Type.MCSC, sparsity3, false, true); } @Test public void testSparseBlockMCSC1Partial() { - runSparseBlockIteratorTest(SparseBlock.Type.MCSC, sparsity1, true); + runSparseBlockIteratorTest(SparseBlock.Type.MCSC, sparsity1, true, true); } @Test public void testSparseBlockMCSC2Partial() { - runSparseBlockIteratorTest(SparseBlock.Type.MCSC, sparsity2, true); + runSparseBlockIteratorTest(SparseBlock.Type.MCSC, sparsity2, true, true); } @Test public void testSparseBlockMCSC3Partial() { - runSparseBlockIteratorTest(SparseBlock.Type.MCSC, sparsity3, true); + runSparseBlockIteratorTest(SparseBlock.Type.MCSC, sparsity3, true, true); } @Test public void testSparseBlockCSC1Full() { - runSparseBlockIteratorTest(SparseBlock.Type.CSC, sparsity1, false); + runSparseBlockIteratorTest(SparseBlock.Type.CSC, sparsity1, false, false); } @Test public void testSparseBlockCSC2Full() { - runSparseBlockIteratorTest(SparseBlock.Type.CSC, sparsity2, false); + runSparseBlockIteratorTest(SparseBlock.Type.CSC, sparsity2, false, false); } @Test public void testSparseBlockCSC3Full() { - runSparseBlockIteratorTest(SparseBlock.Type.CSC, sparsity3, false); + runSparseBlockIteratorTest(SparseBlock.Type.CSC, sparsity3, false, false); + } + + @Test + public void testSparseBlockCSC1RlPartial() { + runSparseBlockIteratorTest(SparseBlock.Type.CSC, sparsity1, true, false); + } + + @Test + public void testSparseBlockCSC2RlPartial() { + runSparseBlockIteratorTest(SparseBlock.Type.CSC, sparsity2, true, false); + } + + @Test + public void testSparseBlockCSC3RlPartial() { + runSparseBlockIteratorTest(SparseBlock.Type.CSC, 
sparsity3, true, false); + } + + @Test + public void testSparseBlockCSC1RuPartial() { + runSparseBlockIteratorTest(SparseBlock.Type.CSC, sparsity1, false, true); + } + + @Test + public void testSparseBlockCSC2RuPartial() { + runSparseBlockIteratorTest(SparseBlock.Type.CSC, sparsity2, false, true); + } + + @Test + public void testSparseBlockCSC3RuPartial() { + runSparseBlockIteratorTest(SparseBlock.Type.CSC, sparsity3, false, true); } @Test public void testSparseBlockCSC1Partial() { - runSparseBlockIteratorTest(SparseBlock.Type.CSC, sparsity1, true); + runSparseBlockIteratorTest(SparseBlock.Type.CSC, sparsity1, true, true); } @Test public void testSparseBlockCSC2Partial() { - runSparseBlockIteratorTest(SparseBlock.Type.CSC, sparsity2, true); + runSparseBlockIteratorTest(SparseBlock.Type.CSC, sparsity2, true, true); } @Test public void testSparseBlockCSC3Partial() { - runSparseBlockIteratorTest(SparseBlock.Type.CSC, sparsity3, true); + runSparseBlockIteratorTest(SparseBlock.Type.CSC, sparsity3, true, true); } - private void runSparseBlockIteratorTest(SparseBlock.Type btype, double sparsity, boolean partial) { + private void runSparseBlockIteratorTest(SparseBlock.Type btype, double sparsity, boolean rlPartial, boolean ruPartial) { try { //data generation double[][] A = getRandomMatrix(rows, cols, -10, 10, sparsity, 8765432); @@ -247,25 +430,34 @@ private void runSparseBlockIteratorTest(SparseBlock.Type btype, double sparsity, //check for correct number of non-zeros int[] rnnz = new int[rows]; int nnz = 0; - int rl = partial ? rlPartial : 0; - for(int i = rl; i < rows; i++) { - for(int j = 0; j < cols; j++) + int rl = rlPartial ? rlVal : 0; + int ru = ruPartial ? ruVal : rows; + int cl = rlPartial && ruPartial ? clVal : 0; + int cu = rlPartial && ruPartial ? cuVal : cols; + for(int i = rl; i < ru; i++) { + for(int j = cl; j < cu; j++) rnnz[i] += (A[i][j] != 0) ? 
1 : 0; nnz += rnnz[i]; } - if(!partial && nnz != sblock.size()) + if(!rlPartial && !ruPartial && nnz != sblock.size()) // no restriction Assert.fail("Wrong number of non-zeros: " + sblock.size() + ", expected: " + nnz); //check correct isEmpty return - for(int i = rl; i < rows; i++) - if(sblock.isEmpty(i) != (rnnz[i] == 0)) - Assert.fail("Wrong isEmpty(row) result for row nnz: " + rnnz[i]); + if(!(rlPartial && ruPartial)) { // cols not restricted + for(int i = rl; i < ru; i++) + if(sblock.isEmpty(i) != (rnnz[i] == 0)) + Assert.fail("Wrong isEmpty(row) result for row nnz: " + rnnz[i]); + } //check correct values - Iterator iter = !partial ? sblock.getIterator() : sblock.getIterator(rl, rows); + Iterator iter = rlPartial && ruPartial ? sblock.getIterator(rl, ru, cl, cu): rlPartial? sblock.getIterator(rl, rows) : ruPartial? sblock.getIterator(ru) : sblock.getIterator(); int count = 0; while(iter.hasNext()) { IJV cell = iter.next(); + if(cell.getI() < rl || cell.getI() >= ru) + Assert.fail("iterator row outside of range"); + if(cell.getJ() < cl || cell.getJ() >= cu) + Assert.fail("iterator column outside of range"); if(cell.getV() != A[cell.getI()][cell.getJ()]) Assert.fail("Wrong value returned by iterator: " + cell.getV() + ", expected: " + A[cell.getI()][cell.getJ()]); @@ -277,11 +469,9 @@ private void runSparseBlockIteratorTest(SparseBlock.Type btype, double sparsity, // check iterator over non-zero rows List manualNonZeroRows = new ArrayList<>(); List iteratorNonZeroRows = new ArrayList<>(); - Iterator iterRows = !partial - ? 
sblock.getNonEmptyRowsIterator(0, rows) - : sblock.getNonEmptyRowsIterator(rl, rows); + Iterator iterRows = sblock.getNonEmptyRowsIterator(rl, ru); - for(int i = rl; i < rows; i++) + for(int i = rl; i < ru; i++) if(!sblock.isEmpty(i)) manualNonZeroRows.add(i); while(iterRows.hasNext()) { @@ -293,6 +483,16 @@ private void runSparseBlockIteratorTest(SparseBlock.Type btype, double sparsity, Assert.fail("Verification of iterator over non-zero rows failed."); } + // check second iterator over non-zero rows + Iterator iterRows2 = !rlPartial && !ruPartial? sblock.getNonEmptyRows().iterator() : sblock.getNonEmptyRows(rl, ru).iterator(); + List iter2NonZeroRows = new ArrayList<>(); + + while(iterRows2.hasNext()) { + iter2NonZeroRows.add(iterRows2.next()); + } + if(!manualNonZeroRows.equals(iter2NonZeroRows)) { + Assert.fail("Verification of second iterator over non-zero rows failed."); + } } catch(Exception ex) { ex.printStackTrace(); From 263d945918ad695a4ef2b0d850dab9164f8d9c1c Mon Sep 17 00:00:00 2001 From: Jessica Priebe Date: Tue, 6 Jan 2026 12:37:00 +0100 Subject: [PATCH 03/12] add SparseBlockFactory tests --- .../runtime/data/SparseBlockFactory.java | 7 ++- .../sparse/SparseBlockInitializationTest.java | 63 +++++++++++++++++++ 2 files changed, 69 insertions(+), 1 deletion(-) create mode 100644 src/test/java/org/apache/sysds/test/component/sparse/SparseBlockInitializationTest.java diff --git a/src/main/java/org/apache/sysds/runtime/data/SparseBlockFactory.java b/src/main/java/org/apache/sysds/runtime/data/SparseBlockFactory.java index 22dfe5417eb..e97c24291d2 100644 --- a/src/main/java/org/apache/sysds/runtime/data/SparseBlockFactory.java +++ b/src/main/java/org/apache/sysds/runtime/data/SparseBlockFactory.java @@ -37,6 +37,8 @@ public static SparseBlock createSparseBlock(SparseBlock.Type type, int rlen) { case CSR: return new SparseBlockCSR(rlen); case COO: return new SparseBlockCOO(rlen); case DCSR: return new SparseBlockDCSR(rlen); + case MCSC: return new 
SparseBlockMCSC(rlen); + case CSC: return new SparseBlockCSC(rlen, 0); default: throw new RuntimeException("Unexpected sparse block type: "+type.toString()); } @@ -84,7 +86,10 @@ public static boolean isSparseBlockType(SparseBlock sblock, SparseBlock.Type typ public static SparseBlock.Type getSparseBlockType(SparseBlock sblock) { return (sblock instanceof SparseBlockMCSR) ? SparseBlock.Type.MCSR : (sblock instanceof SparseBlockCSR) ? SparseBlock.Type.CSR : - (sblock instanceof SparseBlockCOO) ? SparseBlock.Type.COO : null; + (sblock instanceof SparseBlockCOO) ? SparseBlock.Type.COO : + (sblock instanceof SparseBlockDCSR) ? SparseBlock.Type.DCSR : + (sblock instanceof SparseBlockMCSC) ? SparseBlock.Type.MCSC : + (sblock instanceof SparseBlockCSC) ? SparseBlock.Type.CSC : null; } public static long estimateSizeSparseInMemory(SparseBlock.Type type, long nrows, long ncols, double sparsity) { diff --git a/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockInitializationTest.java b/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockInitializationTest.java new file mode 100644 index 00000000000..79f4bd280af --- /dev/null +++ b/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockInitializationTest.java @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.component.sparse; + +import org.apache.sysds.runtime.data.SparseBlock; +import org.apache.sysds.runtime.data.SparseBlockFactory; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestUtils; +import org.junit.Test; + +import static org.junit.Assert.assertEquals; + +public class SparseBlockInitializationTest extends AutomatedTestBase +{ + private final static int _cols = 132; + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + } + + @Test + public void testSparseBlockCreationCOO() { + runSparseBlockCreationTest(SparseBlock.Type.COO); + } + + @Test + public void testSparseBlockCreationDCSR() { + runSparseBlockCreationTest(SparseBlock.Type.DCSR); + } + + @Test + public void testSparseBlockCreationMCSC() { + runSparseBlockCreationTest(SparseBlock.Type.MCSC); + } + + @Test + public void testSparseBlockCreationCSC() { + runSparseBlockCreationTest(SparseBlock.Type.CSC); + } + + private void runSparseBlockCreationTest(SparseBlock.Type type) { + SparseBlock sblock = SparseBlockFactory.createSparseBlock(type, _cols); + assertEquals(SparseBlockFactory.getSparseBlockType(sblock), type); + } +} From a3527c9687cf76fc4269eeb8293452037772ea02 Mon Sep 17 00:00:00 2001 From: Jessica Priebe Date: Tue, 6 Jan 2026 16:24:14 +0100 Subject: [PATCH 04/12] add SparseBlock checkValidity tests --- .../sysds/runtime/data/SparseBlockCOO.java | 12 +- .../sysds/runtime/data/SparseBlockCSC.java | 21 +- .../sysds/runtime/data/SparseBlockCSR.java | 8 +- .../sysds/runtime/data/SparseBlockDCSR.java | 15 +- .../sysds/runtime/data/SparseBlockMCSC.java | 39 +- .../sysds/runtime/data/SparseBlockMCSR.java | 21 +- .../sparse/SparseBlockCheckValidityTest.java | 633 ++++++++++++++++++ 7 files changed, 689 insertions(+), 60 deletions(-) create mode 100644 
src/test/java/org/apache/sysds/test/component/sparse/SparseBlockCheckValidityTest.java diff --git a/src/main/java/org/apache/sysds/runtime/data/SparseBlockCOO.java b/src/main/java/org/apache/sysds/runtime/data/SparseBlockCOO.java index c4e60c10cfd..d4d0467a098 100644 --- a/src/main/java/org/apache/sysds/runtime/data/SparseBlockCOO.java +++ b/src/main/java/org/apache/sysds/runtime/data/SparseBlockCOO.java @@ -226,7 +226,7 @@ public boolean checkValidity(int rlen, int clen, long nnz, boolean strict) { } //3.1. sort order of row indices - for( int i=1; i<=nnz; i++ ) { + for( int i=1; i= _cindexes[k] ) + for(int k=apos+1; k _cindexes[k] ) throw new RuntimeException("Wrong sparse row ordering: " + k + " "+_cindexes[k-1]+" "+_cindexes[k]); - for( int k=apos; k nnz*RESIZE_FACTOR1 ) { + if( capacity > INIT_CAPACITY && capacity > nnz*RESIZE_FACTOR1 ) { throw new RuntimeException("Capacity is larger than the nnz times a resize factor." + " Current size: "+capacity+ ", while Expected size:"+nnz*RESIZE_FACTOR1); } diff --git a/src/main/java/org/apache/sysds/runtime/data/SparseBlockCSC.java b/src/main/java/org/apache/sysds/runtime/data/SparseBlockCSC.java index b38c3525c97..1528c4e189d 100644 --- a/src/main/java/org/apache/sysds/runtime/data/SparseBlockCSC.java +++ b/src/main/java/org/apache/sysds/runtime/data/SparseBlockCSC.java @@ -71,8 +71,8 @@ public SparseBlockCSC(int clen, int capacity, int size) { _size = size; } - public SparseBlockCSC(int[] rowPtr, int[] rowInd, double[] values, int nnz) { - _ptr = rowPtr; + public SparseBlockCSC(int[] colPtr, int[] rowInd, double[] values, int nnz) { + _ptr = colPtr; _indexes = rowInd; _values = values; _size = nnz; @@ -382,7 +382,7 @@ public int numRows() { if(_rlen > -1) return _rlen; else { - int rlen = Arrays.stream(_indexes).max().getAsInt(); + int rlen = Arrays.stream(_indexes).max().getAsInt()+1; _rlen = rlen; return rlen; } @@ -554,8 +554,8 @@ public boolean checkValidity(int rlen, int clen, long nnz, boolean strict) { throw 
new RuntimeException("Incorrect array lengths."); } - //3. non-decreasing row pointers - for(int i = 1; i < clen; i++) { + //3. non-decreasing col pointers + for(int i = 1; i <= clen; i++) { if(_ptr[i - 1] > _ptr[i] && strict) throw new RuntimeException( "Column pointers are decreasing at column: " + i + ", with pointers " + _ptr[i - 1] + " > " + @@ -569,10 +569,9 @@ public boolean checkValidity(int rlen, int clen, long nnz, boolean strict) { for(int k = apos + 1; k < apos + alen; k++) if(_indexes[k - 1] >= _indexes[k]) throw new RuntimeException( - "Wrong sparse column ordering: " + k + " " + _indexes[k - 1] + " " + _indexes[k]); - for(int k = apos; k < apos + alen; k++) - if(_values[k] == 0) - throw new RuntimeException("Wrong sparse column: zero at " + k + " at row index " + _indexes[k]); + "Wrong sparse column ordering, at column=" + i + ", pos=" + k + " with row indexes " + + _indexes[k - 1] + ">=" + _indexes[k] + ); } //5. non-existing zero values @@ -585,7 +584,7 @@ public boolean checkValidity(int rlen, int clen, long nnz, boolean strict) { //6. a capacity that is no larger than nnz times resize factor. int capacity = _values.length; - if(capacity > nnz * RESIZE_FACTOR1) { + if(capacity > INIT_CAPACITY && capacity > nnz * RESIZE_FACTOR1) { throw new RuntimeException( "Capacity is larger than the nnz times a resize factor." 
+ " Current size: " + capacity + ", while Expected size:" + nnz * RESIZE_FACTOR1); @@ -1059,7 +1058,7 @@ public int posFIndexGTCol(int r, int c) { @Override public String toString() { StringBuilder sb = new StringBuilder(); - sb.append("SparseBlockCSR: clen="); + sb.append("SparseBlockCSC: clen="); sb.append(numCols()); sb.append(", nnz="); sb.append(size()); diff --git a/src/main/java/org/apache/sysds/runtime/data/SparseBlockCSR.java b/src/main/java/org/apache/sysds/runtime/data/SparseBlockCSR.java index a40c567dfb2..a6000ff965f 100644 --- a/src/main/java/org/apache/sysds/runtime/data/SparseBlockCSR.java +++ b/src/main/java/org/apache/sysds/runtime/data/SparseBlockCSR.java @@ -942,7 +942,7 @@ public boolean checkValidity(int rlen, int clen, long nnz, boolean strict) { } //3. non-decreasing row pointers - for( int i=1; i _ptr[i] && strict) throw new RuntimeException("Row pointers are decreasing at row: "+i + ", with pointers "+_ptr[i-1]+" > "+_ptr[i]); @@ -956,10 +956,6 @@ public boolean checkValidity(int rlen, int clen, long nnz, boolean strict) { if( _indexes[k-1] >= _indexes[k] ) throw new RuntimeException("Wrong sparse row ordering: " + k + " "+_indexes[k-1]+" "+_indexes[k]); - for( int k=apos; k nnz*RESIZE_FACTOR1 ) { + if(capacity > INIT_CAPACITY && capacity > nnz*RESIZE_FACTOR1 ) { throw new RuntimeException("Capacity is larger than the nnz times a resize factor." 
+ " Current size: "+capacity+ ", while Expected size:"+nnz*RESIZE_FACTOR1); } diff --git a/src/main/java/org/apache/sysds/runtime/data/SparseBlockDCSR.java b/src/main/java/org/apache/sysds/runtime/data/SparseBlockDCSR.java index b369992efa7..3308646c330 100644 --- a/src/main/java/org/apache/sysds/runtime/data/SparseBlockDCSR.java +++ b/src/main/java/org/apache/sysds/runtime/data/SparseBlockDCSR.java @@ -670,7 +670,7 @@ public int posFIndexGT(int r, int c) { @Override public String toString() { StringBuilder sb = new StringBuilder(); - sb.append("SparseBlockCSR: rlen="); + sb.append("SparseBlockDCSR: rlen="); sb.append(numRows()); sb.append(", nnz="); sb.append(size()); @@ -724,19 +724,14 @@ public boolean checkValidity(int rlen, int clen, long nnz, boolean strict) { } //4. sorted column indexes per row - for ( int rowIdx = 0; rowIdx < _rowidx.length; rowIdx++ ) { - int apos = _rowidx[rowIdx]; - int alen = _rowidx[rowIdx+1] - apos; + for (int i = 0; i < _rowptr.length-1; i++) { + int apos = _rowptr[i]; + int alen = _rowptr[i+1] - apos; for( int k = apos + 1; k < apos + alen; k++) if( _colidx[k-1] >= _colidx[k] ) throw new RuntimeException("Wrong sparse row ordering: " + k + " " + _colidx[k-1] + " " + _colidx[k]); - - for( int k=apos; k nnz*RESIZE_FACTOR1 ) { + if(capacity > INIT_CAPACITY && capacity > nnz*RESIZE_FACTOR1 ) { throw new RuntimeException("Capacity is larger than the nnz times a resize factor." 
+ " Current size: "+capacity+ ", while Expected size:"+nnz*RESIZE_FACTOR1); } diff --git a/src/main/java/org/apache/sysds/runtime/data/SparseBlockMCSC.java b/src/main/java/org/apache/sysds/runtime/data/SparseBlockMCSC.java index fd0b3906bcd..405ff578a00 100644 --- a/src/main/java/org/apache/sysds/runtime/data/SparseBlockMCSC.java +++ b/src/main/java/org/apache/sysds/runtime/data/SparseBlockMCSC.java @@ -259,7 +259,7 @@ public void allocate(int r, int nnz) { } public void allocateCol(int c, int nnz) { - if(!isAllocated(c)) { + if(!isAllocatedCol(c)) { _columns[c] = (nnz == 1) ? new SparseRowScalar() : new SparseRowVector(nnz); } } @@ -270,7 +270,7 @@ public void allocate(int r, int ennz, int maxnnz) { } public void allocateCol(int c, int ennz, int maxnnz) { - if(!isAllocated(c)) { + if(!isAllocatedCol(c)) { _columns[c] = (ennz == 1) ? new SparseRowScalar() : new SparseRowVector(ennz, maxnnz); } } @@ -283,7 +283,7 @@ public void compact(int r) { } public void compactCol(int c) { - if(isAllocated(c)) { + if(isAllocatedCol(c)) { if(_columns[c] instanceof SparseRowVector && _columns[c].size() > SparseBlock.INIT_CAPACITY && _columns[c].size() * SparseBlock.RESIZE_FACTOR1 < ((SparseRowVector) _columns[c]).capacity()) { ((SparseRowVector) _columns[c]).compact(); @@ -386,7 +386,7 @@ public int size(int r) { public int sizeCol(int c) { //prior check with isEmpty(r) expected - return isAllocated(c) ? _columns[c].size() : 0; + return isAllocatedCol(c) ? _columns[c].size() : 0; } @Override @@ -404,7 +404,7 @@ public long size(int rl, int ru) { public long sizeCol(int cl, int cu) { long nnz = 0; for(int i = cl; i < cu; i++) { - nnz += isAllocated(i) ? _columns[i].size() : 0; + nnz += isAllocatedCol(i) ? _columns[i].size() : 0; } return nnz; } @@ -449,31 +449,34 @@ public boolean checkValidity(int rlen, int clen, long nnz, boolean strict) { //3. 
Sorted column indices per row for(int i = 0; i < clen; i++) { - if(isEmpty(i)) - continue; + if(isEmptyCol(i)) continue; int apos = pos(i); - int alen = size(i); - int[] aix = indexes(i); - double[] avals = values(i); - for(int k = apos + 1; k < apos + alen; k++) { - if(aix[k - 1] >= aix[k] | aix[k - 1] < 0) { + int alen = sizeCol(i); + int[] aix = indexesCol(i); + double[] avals = valuesCol(i); + + int prevRow = -1; + for(int k = apos; k < apos + alen; k++) { + if(aix[k] < 0) + throw new RuntimeException("Invalid index, at column=" + i + ", pos=" + k); + if(aix[k] <= prevRow) throw new RuntimeException( "Wrong sparse column ordering, at column=" + i + ", pos=" + k + " with row indexes " + - aix[k - 1] + ">=" + aix[k]); - } - if(avals[k] == 0) { + prevRow + ">=" + aix[k]); + if(avals[k] == 0) throw new RuntimeException( "The values are expected to be non zeros " + "but zero at column: " + i + ", row pos: " + k); - } + prevRow = aix[k]; } } + //4. A capacity that is no larger than nnz times resize factor for(int i = 0; i < clen; i++) { long max_size = (long) Math.max(nnz * RESIZE_FACTOR1, INIT_CAPACITY); - if(!isEmpty(i) && values(i).length > max_size) { + if(!isEmptyCol(i) && valuesCol(i).length > max_size) { throw new RuntimeException( "The capacity is larger than nnz times a resize factor(=2). 
" + "Actual length = " + - values(i).length + ", should not exceed " + max_size); + valuesCol(i).length + ", should not exceed " + max_size); } } diff --git a/src/main/java/org/apache/sysds/runtime/data/SparseBlockMCSR.java b/src/main/java/org/apache/sysds/runtime/data/SparseBlockMCSR.java index f94b6bf7f45..9a77646da14 100644 --- a/src/main/java/org/apache/sysds/runtime/data/SparseBlockMCSR.java +++ b/src/main/java/org/apache/sysds/runtime/data/SparseBlockMCSR.java @@ -238,13 +238,20 @@ public boolean checkValidity(int rlen, int clen, long nnz, boolean strict) { int alen = size(i); int[] aix = indexes(i); double[] avals = values(i); - for (int k = apos + 1; k < apos + alen; k++) { - if (aix[k-1] >= aix[k] | aix[k-1] < 0 ) - throw new RuntimeException("Wrong sparse row ordering, at row="+i+", pos="+k - + " with column indexes " + aix[k-1] + ">=" + aix[k]); - if (avals[k] == 0) - throw new RuntimeException("The values are expected to be non zeros " - + "but zero at row: "+ i + ", col pos: " + k); + + int prevCol = -1; + for (int k = apos; k < apos + alen; k++) { + if(aix[k] < 0) + throw new RuntimeException( + "Invalid index, at column=" + i + ", pos=" + k); + if(aix[k] <= prevCol) + throw new RuntimeException( + "Wrong sparse row ordering, at row=" + i + ", pos=" + k + " with column indexes " + + prevCol + ">=" + aix[k]); + if(avals[k] == 0) + throw new RuntimeException( + "The values are expected to be non zeros " + "but zero at column: " + i + ", row pos: " + k); + prevCol = aix[k]; } } diff --git a/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockCheckValidityTest.java b/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockCheckValidityTest.java new file mode 100644 index 00000000000..1d6ad261201 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockCheckValidityTest.java @@ -0,0 +1,633 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.component.sparse; + +import org.apache.sysds.runtime.data.SparseBlock; +import org.apache.sysds.runtime.data.SparseBlockCOO; +import org.apache.sysds.runtime.data.SparseBlockCSC; +import org.apache.sysds.runtime.data.SparseBlockCSR; +import org.apache.sysds.runtime.data.SparseBlockDCSR; +import org.apache.sysds.runtime.data.SparseBlockFactory; +import org.apache.sysds.runtime.data.SparseBlockMCSC; +import org.apache.sysds.runtime.data.SparseBlockMCSR; +import org.apache.sysds.runtime.data.SparseRow; +import org.apache.sysds.runtime.data.SparseRowVector; + +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.util.DataConverter; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestUtils; +import org.junit.Test; + +import java.lang.reflect.Field; +import java.util.Arrays; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; + +public class SparseBlockCheckValidityTest extends AutomatedTestBase +{ + private final static int _rows = 123; + private final static int _cols = 97; + private final static double _sparsity = 0.22; + + @Override + public void setUp() { + 
TestUtils.clearAssertionInformation(); + } + + @Test + public void testSparseBlockCOOValid() { + runSparseBlockValidTest(SparseBlock.Type.COO); + } + + @Test + public void testSparseBlockCSCValid() { + runSparseBlockValidTest(SparseBlock.Type.CSC); + } + + @Test + public void testSparseBlockCSRValid() { + runSparseBlockValidTest(SparseBlock.Type.CSR); + } + + @Test + public void testSparseBlockDCSRValid() { + runSparseBlockValidTest(SparseBlock.Type.DCSR); + } + + @Test + public void testSparseBlockMCSCValid() { + runSparseBlockValidTest(SparseBlock.Type.MCSC); + } + + @Test + public void testSparseBlockMCSRValid() { + runSparseBlockValidTest(SparseBlock.Type.MCSR); + } + + @Test + public void testSparseBlockCOOInvalidDimensions() { + runSparseBlockInvalidDimensionsTest(new SparseBlockCOO(-1, 0)); + } + + @Test + public void testSparseBlockCSCInvalidDimensions() { + runSparseBlockInvalidDimensionsTest(new SparseBlockCSC(-1, 0)); + } + + @Test + public void testSparseBlockCSRInvalidDimensions() { + runSparseBlockInvalidDimensionsTest(new SparseBlockCSR(-1, 0)); + } + + @Test + public void testSparseBlockDCSRInvalidDimensions() { + runSparseBlockInvalidDimensionsTest(new SparseBlockDCSR(0, 0)); + } + + @Test + public void testSparseBlockMCSCInvalidDimensions() { + runSparseBlockInvalidDimensionsTest(new SparseBlockMCSC(-1, 0)); + } + + @Test + public void testSparseBlockMCSRInvalidDimensions() { + runSparseBlockInvalidDimensionsTest(new SparseBlockMCSR(0, -1)); + } + + @Test + public void testSparseBlockCOOIncorrectArrayLengths() { + SparseBlockCOO sblock = new SparseBlockCOO(2, 2); + // nnz > capacity + RuntimeException ex = assertThrows(RuntimeException.class, + () -> sblock.checkValidity(2, 2, 4, false)); + + assertEquals("Incorrect array lengths.", ex.getMessage()); + } + + @Test + public void testSparseBlockCSCIncorrectArrayLengths() { + SparseBlockCSC sblock = new SparseBlockCSC(2, 2, 1); + // nnz > capacity + RuntimeException ex = 
assertThrows(RuntimeException.class, + () -> sblock.checkValidity(2, 3, 6, false)); + + assertEquals("Incorrect array lengths.", ex.getMessage()); + } + + @Test + public void testSparseBlockCSRIncorrectArrayLengths() { + SparseBlockCSR sblock = new SparseBlockCSR(2, 2, 1); + // nnz > capacity + RuntimeException ex = assertThrows(RuntimeException.class, + () -> sblock.checkValidity(3, 2, 6, false)); + + assertEquals("Incorrect array lengths.", ex.getMessage()); + } + + @Test + public void testSparseBlockDCSRIncorrectArrayLengths() { + SparseBlockDCSR sblock = new SparseBlockDCSR(2, 1); + + // cut off last value + int[] rowptr = (int[]) getField(sblock,"_rowptr"); + setField(sblock, "_rowptr", Arrays.copyOfRange(rowptr, 0, rowptr.length-1)); + // nnz > capacity + RuntimeException ex = assertThrows(RuntimeException.class, + () -> sblock.checkValidity(3, 2, 6, false)); + + assertEquals("Incorrect array lengths.", ex.getMessage()); + } + + @Test + public void testSparseBlockMCSCIncorrectArrayLengths() { + SparseBlockMCSC sblock = new SparseBlockMCSC(2, 2); + + // nnz > capacity + RuntimeException ex = assertThrows(RuntimeException.class, + () -> sblock.checkValidity(3, 2, 1, false)); + + assertTrue(ex.getMessage().startsWith("Incorrect size")); + } + + @Test + public void testSparseBlockMCSRIncorrectArrayLengths() { + SparseBlockMCSR sblock = new SparseBlockMCSR(2, 2); + + // nnz > capacity + RuntimeException ex = assertThrows(RuntimeException.class, + () -> sblock.checkValidity(3, 2, 1, false)); + + assertTrue(ex.getMessage().startsWith("Incorrect size")); + } + + @Test + public void testSparseBlockCOOUnsortedRowIndices() { + SparseBlockCOO block = new SparseBlockCOO(10, 3); + + int[] r = new int[]{0, 5, 2}; // unsorted + int[] c = new int[]{0, 1, 2}; + double[] v = new double[]{1, 1, 1}; + + setField(block, "_rindexes", r); + setField(block, "_cindexes", c); + setField(block, "_values", v); + setField(block, "_size", 3); + + RuntimeException ex = 
assertThrows(RuntimeException.class, + () -> block.checkValidity(10, 10, 3, false)); + + assertEquals("Wrong sorted order of row indices", ex.getMessage()); + } + + @Test + public void testSparseBlockCSCDecreasingColPointers() { + SparseBlockCSC block = new SparseBlockCSC(10, 3); + + int[] ptr = new int[]{0, 2, 1, 3}; // unsorted col pointer + int[] idxs = new int[]{0, 1, 2}; + double[] v = new double[]{1, 1, 1}; + + setField(block, "_ptr", ptr); + setField(block, "_indexes", idxs); + setField(block, "_values", v); + setField(block, "_size", 3); + + RuntimeException ex = assertThrows(RuntimeException.class, + () -> block.checkValidity(10, 3, 3, true)); + + assertTrue(ex.getMessage().startsWith("Column pointers are decreasing at column")); + } + + @Test + public void testSparseBlockCSRDecreasingRowPointers() { + SparseBlockCSR block = new SparseBlockCSR(3, 3); + + int[] ptr = new int[]{0, 2, 1, 3}; // unsorted row pointer + int[] idxs = new int[]{0, 1, 2}; + double[] v = new double[]{1, 1, 1}; + + setField(block, "_ptr", ptr); + setField(block, "_indexes", idxs); + setField(block, "_values", v); + setField(block, "_size", 3); + + RuntimeException ex = assertThrows(RuntimeException.class, + () -> block.checkValidity(3, 10, 3, true)); + + assertTrue(ex.getMessage().startsWith("Row pointers are decreasing at row")); + } + + @Test + public void testSparseBlockDCSRDecreasingRowIndices() { + SparseBlockDCSR block = new SparseBlockDCSR(3, 3); + + int[] rowIdxs = new int[]{0, 2, 1}; // unsorted + int[] rowPtr = new int[]{0, 1, 2, 3}; + int[] colIdxs = new int[]{0, 1, 2}; + double[] v = new double[]{1, 1, 1}; + + setField(block, "_rowidx", rowIdxs); + setField(block, "_rowptr", rowPtr); + setField(block, "_colidx", colIdxs); + setField(block, "_values", v); + setField(block, "_size", 3); + + RuntimeException ex = assertThrows(RuntimeException.class, + () -> block.checkValidity(3, 10, 3, true)); + + assertTrue(ex.getMessage().startsWith("Row indices are decreasing at row")); 
+ } + + @Test + public void testSparseBlockDCSRDecreasingRowPointers() { + SparseBlockDCSR block = new SparseBlockDCSR(3, 3); + + int[] rowIdxs = new int[]{0, 1, 2}; + int[] rowPtr = new int[]{0, 1, 3, 2}; // unsorted + int[] colIdxs = new int[]{0, 1, 2}; + double[] v = new double[]{1, 1, 1}; + + setField(block, "_rowidx", rowIdxs); + setField(block, "_rowptr", rowPtr); + setField(block, "_colidx", colIdxs); + setField(block, "_values", v); + setField(block, "_size", 3); + + RuntimeException ex = assertThrows(RuntimeException.class, + () -> block.checkValidity(3, 10, 3, true)); + + assertTrue(ex.getMessage().startsWith("Row pointers are decreasing at row")); + } + + @Test + public void testSparseBlockCOOUnsortedColumnIndicesWithinRow() { + SparseBlockCOO block = new SparseBlockCOO(1, 3); + + int[] r = new int[]{0, 0, 0}; + int[] c = new int[]{0, 2, 1}; // unsorted for row 0 + double[] v = new double[]{1, 1, 1}; + + setField(block, "_rindexes", r); + setField(block, "_cindexes", c); + setField(block, "_values", v); + setField(block, "_size", 3); + + RuntimeException ex = assertThrows(RuntimeException.class, + () -> block.checkValidity(1, 3, 3, false)); + + assertTrue(ex.getMessage().startsWith("Wrong sparse row ordering")); + } + + @Test + public void testSparseBlockCSCUnsortedRowIndicesWithinColumn() { + SparseBlockCSC block = new SparseBlockCSC(10, 3); + + int[] ptr = new int[]{0, 3, 3, 3}; + int[] idxs = new int[]{0, 2, 1}; // unsorted + double[] v = new double[]{1, 1, 1}; + + setField(block, "_ptr", ptr); + setField(block, "_indexes", idxs); + setField(block, "_values", v); + setField(block, "_size", 3); + + RuntimeException ex = assertThrows(RuntimeException.class, + () -> block.checkValidity(10, 3, 3, true)); + + assertTrue(ex.getMessage().startsWith("Wrong sparse column ordering")); + } + + @Test + public void testSparseBlockCSRUnsortedColumnIndicesWithinRow() { + SparseBlockCSR block = new SparseBlockCSR(3, 3); + + int[] ptr = new int[]{0, 3, 3, 3}; + int[] 
idxs = new int[]{0, 2, 1}; // unsorted + double[] v = new double[]{1, 1, 1}; + + setField(block, "_ptr", ptr); + setField(block, "_indexes", idxs); + setField(block, "_values", v); + setField(block, "_size", 3); + + RuntimeException ex = assertThrows(RuntimeException.class, + () -> block.checkValidity(1, 3, 3, false)); + + assertTrue(ex.getMessage().startsWith("Wrong sparse row ordering")); + } + + @Test + public void testSparseBlockDCSRUnsortedColumnIndicesWithinRow() { + SparseBlockDCSR block = new SparseBlockDCSR(3, 3); + + int[] rowIdxs = new int[]{0, 2}; + int[] rowPtr = new int[]{0, 1, 3}; + int[] colIdxs = new int[]{0, 2, 1}; // for row 2 unsorted + double[] v = new double[]{1, 1, 1}; + + setField(block, "_rowidx", rowIdxs); + setField(block, "_rowptr", rowPtr); + setField(block, "_colidx", colIdxs); + setField(block, "_values", v); + setField(block, "_size", 3); + + RuntimeException ex = assertThrows(RuntimeException.class, + () -> block.checkValidity(1, 3, 3, false)); + + assertTrue(ex.getMessage().startsWith("Wrong sparse row ordering")); + } + + @Test + public void testSparseBlockMCSCUnsortedRowIndicesWithinColumn() { + SparseBlockMCSC block = new SparseBlockMCSC(10, 3); + + SparseRow col = new SparseRowVector(new double[]{1., 1., 1.}, new int[]{0, 2, 1}); // unsorted + SparseRow[] cols = new SparseRow[]{null, null, col}; + setField(block, "_columns", cols); + + RuntimeException ex = assertThrows(RuntimeException.class, + () -> block.checkValidity(10, 3, 3, true)); + + assertTrue(ex.getMessage().startsWith("Wrong sparse column ordering")); + } + + @Test + public void testSparseBlockMCSRUnsortedColumnIndicesWithinRow() { + SparseBlockMCSR block = new SparseBlockMCSR(3, 10); + + SparseRow row = new SparseRowVector(new double[]{1., 1., 1.}, new int[]{0, 2, 1}); // unsorted + SparseRow[] rows = new SparseRow[]{null, null, row}; + setField(block, "_rows", rows); + + RuntimeException ex = assertThrows(RuntimeException.class, + () -> block.checkValidity(3, 10, 
3, true)); + + assertTrue(ex.getMessage().startsWith("Wrong sparse row ordering")); + } + + @Test + public void testSparseBlockMCSCInvalidIndices() { + SparseBlockMCSC block = new SparseBlockMCSC(10, 3); + + SparseRow col = new SparseRowVector(new double[]{1., 1., 1.}, new int[]{-1, 0, 2}); + SparseRow[] cols = new SparseRow[]{null, null, col}; + setField(block, "_columns", cols); + + RuntimeException ex = assertThrows(RuntimeException.class, + () -> block.checkValidity(10, 3, 3, true)); + + assertTrue(ex.getMessage().startsWith("Invalid index")); + } + + @Test + public void testSparseBlockMCSRInvalidIndices() { + SparseBlockMCSR block = new SparseBlockMCSR(3, 10); + + SparseRow row = new SparseRowVector(new double[]{1., 1., 1.}, new int[]{-1, 0, 1}); + SparseRow[] rows = new SparseRow[]{null, null, row}; + setField(block, "_rows", rows); + + RuntimeException ex = assertThrows(RuntimeException.class, + () -> block.checkValidity(3, 10, 3, true)); + + assertTrue(ex.getMessage().startsWith("Invalid index")); + } + + @Test + public void testSparseBlockCOOInvalidValue() { + SparseBlockCOO block = new SparseBlockCOO(3, 3); + + int[] r = new int[]{0, 1, 2}; + int[] c = new int[]{0, 1, 2}; + double[] v = new double[]{1, 2, 0}; // contains 0 + + setField(block, "_rindexes", r); + setField(block, "_cindexes", c); + setField(block, "_values", v); + setField(block, "_size", 3); + + RuntimeException ex = assertThrows(RuntimeException.class, + () -> block.checkValidity(3, 3, 3, false)); + + assertTrue(ex.getMessage().startsWith("The values array should not contain zeros")); + } + + @Test + public void testSparseBlockCSCInvalidValue() { + SparseBlockCSC block = new SparseBlockCSC(3, 3); + + int[] ptr = new int[]{0, 3, 3, 3}; + int[] idxs = new int[]{0, 1, 2}; + double[] v = new double[]{1, 1, 0}; + + setField(block, "_ptr", ptr); + setField(block, "_indexes", idxs); + setField(block, "_values", v); + setField(block, "_size", 3); + + RuntimeException ex = 
assertThrows(RuntimeException.class, + () -> block.checkValidity(3, 3, 3, false)); + + assertTrue(ex.getMessage().startsWith("The values array should not contain zeros")); + } + + @Test + public void testSparseBlockCSRInvalidValue() { + SparseBlockCSR block = new SparseBlockCSR(3, 3); + + int[] ptr = new int[]{0, 3, 3, 3}; + int[] idxs = new int[]{0, 1, 2}; + double[] v = new double[]{1, 1, 0}; + + setField(block, "_ptr", ptr); + setField(block, "_indexes", idxs); + setField(block, "_values", v); + setField(block, "_size", 3); + + RuntimeException ex = assertThrows(RuntimeException.class, + () -> block.checkValidity(3, 3, 3, false)); + + assertTrue(ex.getMessage().startsWith("The values array should not contain zeros")); + } + + @Test + public void testSparseBlockDCSRInvalidValue() { + SparseBlockDCSR block = new SparseBlockDCSR(3, 3); + + int[] rowIdxs = new int[]{0, 1, 2}; + int[] rowPtr = new int[]{0, 1, 2, 3}; + int[] colIdxs = new int[]{0, 1, 2}; + double[] v = new double[]{1, 1, 0}; + + setField(block, "_rowidx", rowIdxs); + setField(block, "_rowptr", rowPtr); + setField(block, "_colidx", colIdxs); + setField(block, "_values", v); + setField(block, "_size", 3); + + RuntimeException ex = assertThrows(RuntimeException.class, + () -> block.checkValidity(1, 3, 3, false)); + + assertTrue(ex.getMessage().startsWith("The values array should not contain zeros")); + } + + @Test + public void testSparseBlockMCSCInvalidValue() { + SparseBlockMCSC block = new SparseBlockMCSC(10, 3); + + SparseRow col = new SparseRowVector(new double[]{1., 1., 0.}, new int[]{0, 1, 2}); + SparseRow[] cols = new SparseRow[]{null, null, col}; + setField(block, "_columns", cols); + + RuntimeException ex = assertThrows(RuntimeException.class, + () -> block.checkValidity(10, 3, 3, true)); + + assertTrue(ex.getMessage().startsWith("The values are expected to be non zeros")); + } + + @Test + public void testSparseBlockMCSRInvalidValue() { + SparseBlockMCSR block = new SparseBlockMCSR(3, 10); + + 
SparseRow row = new SparseRowVector(new double[]{1., 1., 0.}, new int[]{0, 1, 2}); + SparseRow[] rows = new SparseRow[]{null, null, row}; + setField(block, "_rows", rows); + + RuntimeException ex = assertThrows(RuntimeException.class, + () -> block.checkValidity(3, 10, 3, true)); + + assertTrue(ex.getMessage().startsWith("The values are expected to be non zeros")); + } + + @Test + public void testSparseBlockCOOCapacityExceedsAllowedLimit() { + SparseBlockCOO block = new SparseBlockCOO(3, 50); + + RuntimeException ex = assertThrows(RuntimeException.class, + () -> block.checkValidity(3, 3, 0, false)); + + // RESIZE_FACTOR1 is 2 + assertTrue(ex.getMessage().startsWith("Capacity is larger than the nnz times a resize factor")); + } + + @Test + public void testSparseBlockCSCCapacityExceedsAllowedLimit() { + SparseBlockCSC block = new SparseBlockCSC(3, 50, 0); + + RuntimeException ex = assertThrows(RuntimeException.class, + () -> block.checkValidity(3, 3, 0, false)); + + // RESIZE_FACTOR1 is 2 + assertTrue(ex.getMessage().startsWith("Capacity is larger than the nnz times a resize factor")); + } + + @Test + public void testSparseBlockCSRCapacityExceedsAllowedLimit() { + SparseBlockCSR block = new SparseBlockCSR(3, 50, 0); + + RuntimeException ex = assertThrows(RuntimeException.class, + () -> block.checkValidity(3, 3, 0, false)); + + // RESIZE_FACTOR1 is 2 + assertTrue(ex.getMessage().startsWith("Capacity is larger than the nnz times a resize factor")); + } + + @Test + public void testSparseBlockDCSRCapacityExceedsAllowedLimit() { + SparseBlockDCSR block = new SparseBlockDCSR(3, 50); + + RuntimeException ex = assertThrows(RuntimeException.class, + () -> block.checkValidity(3, 3, 0, false)); + + // RESIZE_FACTOR1 is 2 + assertTrue(ex.getMessage().startsWith("Capacity is larger than the nnz times a resize factor")); + } + + @Test + public void testSparseBlockMCSCCapacityExceedsAllowedLimit() { + SparseBlockMCSC block = new SparseBlockMCSC(10, 3); + + SparseRow col = new 
SparseRowVector(new double[]{1., 1., 1., 1., 1.}, new int[]{0, 1, 2, 3, 4}); + SparseRow[] cols = new SparseRow[]{null, null, col}; + setField(block, "_columns", cols); + + RuntimeException ex = assertThrows(RuntimeException.class, + () -> block.checkValidity(10, 3, 2, true)); + + assertTrue(ex.getMessage().startsWith("The capacity is larger than nnz times a resize factor")); + } + + @Test + public void testSparseBlockMCSRCapacityExceedsAllowedLimit() { + SparseBlockMCSR block = new SparseBlockMCSR(3, 10); + + SparseRow row = new SparseRowVector(new double[]{1., 1., 1., 1., 1.}, new int[]{0, 1, 2, 3, 4}); + SparseRow[] rows = new SparseRow[]{null, null, row}; + setField(block, "_rows", rows); + + RuntimeException ex = assertThrows(RuntimeException.class, + () -> block.checkValidity(3, 10, 2, true)); + + assertTrue(ex.getMessage().startsWith("The capacity is larger than nnz times a resize factor")); + } + + private void runSparseBlockValidTest(SparseBlock.Type btype) { + double[][] A = getRandomMatrix(_rows, _cols, -10, 10, _sparsity, 13); + + MatrixBlock mbtmp = DataConverter.convertToMatrixBlock(A); + SparseBlock srtmp = mbtmp.getSparseBlock(); + SparseBlock sblock = SparseBlockFactory.copySparseBlock(btype, srtmp, true); + + assertTrue("should pass checkValidity", sblock.checkValidity(_rows, _cols, sblock.size(), true)); + } + + private void runSparseBlockInvalidDimensionsTest(SparseBlock block) { + RuntimeException ex1 = assertThrows(RuntimeException.class, + () -> block.checkValidity(-1, 1, 0, false)); + assertTrue(ex1.getMessage().startsWith("Invalid block dimensions")); + + RuntimeException ex2 = assertThrows(RuntimeException.class, + () -> block.checkValidity(1, -1, 0, false)); + assertTrue(ex2.getMessage().startsWith("Invalid block dimensions")); + } + + private static void setField(Object obj, String name, Object value) { + try { + Field f = obj.getClass().getDeclaredField(name); + f.setAccessible(true); + f.set(obj, value); + } catch (Exception ex) { + 
throw new RuntimeException("Reflection failed: " + ex.getMessage(), ex); // keep the reflective cause + } + } + + private static Object getField(Object obj, String name) { + try { + Field f = obj.getClass().getDeclaredField(name); + f.setAccessible(true); + return f.get(obj); + } catch (Exception ex) { + throw new RuntimeException("Reflection failed: " + ex.getMessage(), ex); // keep the reflective cause + } + } +} From 02dbad2d6fc1fb8f6bed708a394d1ae8796f365e Mon Sep 17 00:00:00 2001 From: Jessica Priebe Date: Tue, 6 Jan 2026 16:51:38 +0100 Subject: [PATCH 05/12] add SparseBlock initialization tests --- .../sysds/runtime/data/SparseBlockCSC.java | 32 +- .../sysds/runtime/data/SparseBlockDCSR.java | 11 - .../sysds/runtime/data/SparseBlockMCSC.java | 14 +- .../sparse/SparseBlockCheckValidityTest.java | 4 +- .../sparse/SparseBlockInitializationTest.java | 389 ++++++++++++++++++ 5 files changed, 404 insertions(+), 46 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/data/SparseBlockCSC.java b/src/main/java/org/apache/sysds/runtime/data/SparseBlockCSC.java index 1528c4e189d..84d172afc36 100644 --- a/src/main/java/org/apache/sysds/runtime/data/SparseBlockCSC.java +++ b/src/main/java/org/apache/sysds/runtime/data/SparseBlockCSC.java @@ -64,11 +64,12 @@ public SparseBlockCSC(int rlen, int clen) { _size = 0; } - public SparseBlockCSC(int clen, int capacity, int size) { + public SparseBlockCSC(int rlen, int clen, int capacity) { + _rlen = rlen; _ptr = new int[clen + 1]; //ix0=0 _indexes = new int[capacity]; _values = new double[capacity]; - _size = size; + _size = 0; } public SparseBlockCSC(int[] colPtr, int[] rowInd, double[] values, int nnz) { @@ -94,8 +95,9 @@ public SparseBlockCSC(SparseBlock sblock) { private void initialize(SparseBlock sblock) { - if(_size > Integer.MAX_VALUE) - throw new RuntimeException("SparseBlockCSC supports nnz<=Integer.MAX_VALUE but got " + _size); + long size = sblock.size(); + if(size > Integer.MAX_VALUE) + throw new RuntimeException("SparseBlockCSC supports nnz<=Integer.MAX_VALUE but got " + 
size); //special case SparseBlockCSC if(sblock instanceof SparseBlockCSC) { @@ -223,27 +225,6 @@ public SparseBlockCSC(int cols, int[] rowInd, int[] colInd, double[] values) { } - public SparseBlockCSC(int cols, int nnz, int[] rowInd) { - - _clenInferred = cols; - _ptr = new int[cols + 1]; - _indexes = Arrays.copyOf(rowInd, nnz); - _values = new double[nnz]; - Arrays.fill(_values, 1); - _size = nnz; - - //single-pass construction of col pointers - //and copy of row indexes if necessary - for(int i = 0, pos = 0; i < cols; i++) { - if(rowInd[i] >= 0) { - if(cols > nnz) - _indexes[pos] = rowInd[i]; - pos++; - } - _ptr[i + 1] = pos; - } - } - /** * Initializes the CSC sparse block from an ordered input stream of ultra-sparse ijv triples. * @@ -288,7 +269,6 @@ public void initSparse(int clen, int nnz, DataInput in) throws IOException { // Allocate space if necessary if(_values.length < nnz) { resize(newCapacity(nnz)); - System.out.println("hallo"); } // Read sparse columns, append and update pointers _ptr[0] = 0; diff --git a/src/main/java/org/apache/sysds/runtime/data/SparseBlockDCSR.java b/src/main/java/org/apache/sysds/runtime/data/SparseBlockDCSR.java index 3308646c330..eb75cade617 100644 --- a/src/main/java/org/apache/sysds/runtime/data/SparseBlockDCSR.java +++ b/src/main/java/org/apache/sysds/runtime/data/SparseBlockDCSR.java @@ -63,17 +63,6 @@ public SparseBlockDCSR(int rlen, int capacity) { _nnzr = 0; } - public SparseBlockDCSR(int rlen, int capacity, int size, int nnzr){ - LOG.warn("Allocating a DCSR-block using row-length. This will lead to significant overhead!"); - _rowidx = new int[rlen]; - _rowptr = new int[rlen + 1]; - _colidx = new int[capacity]; - _values = new double[capacity]; - _rlen = rlen; - _size = size; - _nnzr = nnzr; - } - public SparseBlockDCSR(int[] rowIdx, int[] rowPtr, int[] colIdx, double[] values, int rlen, int nnz, int nnzr){ LOG.warn("Allocating a DCSR-block using row-length. 
This will lead to significant overhead!"); _rowidx = rowIdx; diff --git a/src/main/java/org/apache/sysds/runtime/data/SparseBlockMCSC.java b/src/main/java/org/apache/sysds/runtime/data/SparseBlockMCSC.java index 405ff578a00..d5c03fc03fd 100644 --- a/src/main/java/org/apache/sysds/runtime/data/SparseBlockMCSC.java +++ b/src/main/java/org/apache/sysds/runtime/data/SparseBlockMCSC.java @@ -69,15 +69,15 @@ private void initialize(SparseBlock sblock) { else if(sblock instanceof SparseBlockMCSR) { SparseRow[] originalRows = ((SparseBlockMCSR) sblock).getRows(); Map columnSizes = new HashMap<>(); - if(_clenInferred == -1) { - for(SparseRow row : originalRows) { - if(row != null && !row.isEmpty()) { - for(int i = 0; i < row.size(); i++) { - int rowIndex = row.indexes()[i]; - columnSizes.put(rowIndex, columnSizes.getOrDefault(rowIndex, 0) + 1); - } + for(SparseRow row : originalRows) { + if(row != null && !row.isEmpty()) { + for(int i = 0; i < row.size(); i++) { + int rowIndex = row.indexes()[i]; + columnSizes.put(rowIndex, columnSizes.getOrDefault(rowIndex, 0) + 1); } } + } + if(_clenInferred == -1) { clen = columnSizes.keySet().stream().max(Integer::compare).orElseThrow(NoSuchElementException::new); _columns = new SparseRow[clen + 1]; } diff --git a/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockCheckValidityTest.java b/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockCheckValidityTest.java index 1d6ad261201..b78b8642356 100644 --- a/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockCheckValidityTest.java +++ b/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockCheckValidityTest.java @@ -126,7 +126,7 @@ public void testSparseBlockCOOIncorrectArrayLengths() { @Test public void testSparseBlockCSCIncorrectArrayLengths() { - SparseBlockCSC sblock = new SparseBlockCSC(2, 2, 1); + SparseBlockCSC sblock = new SparseBlockCSC(2, 2, 2); // nnz > capacity RuntimeException ex = assertThrows(RuntimeException.class, () -> 
sblock.checkValidity(2, 3, 6, false)); @@ -532,7 +532,7 @@ public void testSparseBlockCOOCapacityExceedsAllowedLimit() { @Test public void testSparseBlockCSCCapacityExceedsAllowedLimit() { - SparseBlockCSC block = new SparseBlockCSC(3, 50, 0); + SparseBlockCSC block = new SparseBlockCSC(3, 3, 50); RuntimeException ex = assertThrows(RuntimeException.class, () -> block.checkValidity(3, 3, 0, false)); diff --git a/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockInitializationTest.java b/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockInitializationTest.java index 79f4bd280af..3d5a3362b07 100644 --- a/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockInitializationTest.java +++ b/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockInitializationTest.java @@ -20,16 +20,35 @@ package org.apache.sysds.test.component.sparse; import org.apache.sysds.runtime.data.SparseBlock; +import org.apache.sysds.runtime.data.SparseBlockCOO; +import org.apache.sysds.runtime.data.SparseBlockCSC; +import org.apache.sysds.runtime.data.SparseBlockCSR; +import org.apache.sysds.runtime.data.SparseBlockDCSR; import org.apache.sysds.runtime.data.SparseBlockFactory; +import org.apache.sysds.runtime.data.SparseBlockMCSC; +import org.apache.sysds.runtime.data.SparseBlockMCSR; +import org.apache.sysds.runtime.data.SparseRow; +import org.apache.sysds.runtime.data.SparseRowVector; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.util.DataConverter; import org.apache.sysds.test.AutomatedTestBase; import org.apache.sysds.test.TestUtils; import org.junit.Test; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.DataInputStream; +import java.io.DataOutputStream; + import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotSame; + public class SparseBlockInitializationTest extends AutomatedTestBase { + private final static int _rows = 324; 
private final static int _cols = 132; + private final static double _sparsity = 0.22; @Override public void setUp() { @@ -60,4 +79,374 @@ private void runSparseBlockCreationTest(SparseBlock.Type type) { SparseBlock sblock = SparseBlockFactory.createSparseBlock(type, _cols); assertEquals(SparseBlockFactory.getSparseBlockType(sblock), type); } + + @Test + public void testSparseBlockCOOInitCapacity() { + int init_capacity = 4; + SparseBlockCOO sblock = new SparseBlockCOO(_cols); + assertEquals("INIT_CAPACITY should be 4", init_capacity, sblock.values(1).length); + } + + @Test + public void testSparseBlockCOORows() { + double[][] A = getRandomMatrix(_rows, _cols, -10, 10, _sparsity, 1234); + MatrixBlock mbtmp = DataConverter.convertToMatrixBlock(A); + SparseBlock sblock = new SparseBlockCOO(mbtmp.getSparseBlock()); + + int totalNnz = 0; + int rows = A.length; + SparseRow[] sparseRows = new SparseRow[rows]; + + for (int i = 0; i < rows; i++) { + SparseRow srv = new SparseRowVector(A[i]); + sparseRows[i] = srv; + totalNnz += srv.size(); + } + + SparseBlockCOO sblock2 = new SparseBlockCOO(sparseRows, totalNnz); + assertEquals(sblock, sblock2); + } + + @Test + public void testSparseBlockCOORowsValuesIndexes() { + double[][] A = getRandomMatrix(_rows, _cols, -10, 10, _sparsity, 1234); + MatrixBlock mbtmp = DataConverter.convertToMatrixBlock(A); + SparseBlock sblock = new SparseBlockCOO(mbtmp.getSparseBlock()); + + int totalNnz = 0; + int rows = A.length; + SparseRow[] sparseRows = new SparseRow[rows]; + + for (int i = 0; i < rows; i++) { + int[] indexes = new int[A[i].length]; + for (int j = 0; j < A[i].length; j++) indexes[j] = j; + SparseRow srv = new SparseRowVector(A[i], indexes); + srv.compact(); + sparseRows[i] = srv; + totalNnz += srv.size(); + } + + SparseBlockCOO sblock2 = new SparseBlockCOO(sparseRows, totalNnz); + assertEquals(sblock, sblock2); + } + + @Test + public void testSparseBlockCSCInitCapacity() { + int rlen = 4; + int clen = 5; + int capacity = 4; + 
SparseBlockCSC sblock = new SparseBlockCSC(rlen, clen, capacity); + + assertEquals("num rows should be equal to rlen", rlen, sblock.numRows()); + assertEquals("length ptr should be equal to clen+1", clen+1, sblock.colPointers().length); + assertEquals("length values should be equal to capacity", capacity, sblock.valuesCol(0).length); + assertEquals("length indexes should be equal to capacity", capacity, sblock.indexesCol(0).length); + } + + @Test + public void testSparseBlockCSCInitPointer() { + double[][] A = getRandomMatrix(_rows, _cols, -10, 10, _sparsity, 1234); + MatrixBlock mbtmp = DataConverter.convertToMatrixBlock(A); + SparseBlockCSC sblock = new SparseBlockCSC(mbtmp.getSparseBlock()); + + int[] colPtr = sblock.colPointers(); + int[] rowInd = sblock.indexesCol(0); + double[] values = sblock.valuesCol(0); + int nnz = sblock.sizeCol(0); + SparseBlockCSC sblock2 = new SparseBlockCSC(colPtr, rowInd, values, nnz); + + assertEquals(sblock, sblock2); + } + + @Test + public void testSparseBlockCSCInitMSCS() { + double[][] A = getRandomMatrix(_rows, _cols, -10, 10, _sparsity, 1234); + MatrixBlock mbtmp = DataConverter.convertToMatrixBlock(A); + SparseBlock sblock = new SparseBlockMCSC(mbtmp.getSparseBlock()); + + SparseBlockCSC sblock2 = new SparseBlockCSC(sblock); + assertEquals(sblock, sblock2); + } + + @Test + public void testSparseBlockCSCInitCols() { + double[][] A = getRandomMatrix(_rows, _cols, -10, 10, _sparsity, 1234); + MatrixBlock mbtmp = DataConverter.convertToMatrixBlock(A); + SparseBlockMCSC sblock = new SparseBlockMCSC(mbtmp.getSparseBlock()); + + SparseRow[] cols = sblock.getCols(); + int totalNnz = (int) sblock.size(); + + SparseBlock sblock2 = new SparseBlockCSC(cols, totalNnz); + assertEquals(sblock, sblock2); + + } + + @Test + public void testSparseBlockCSCInitRowColInd() { + double[][] A = getRandomMatrix(_rows, _cols, -10, 10, _sparsity, 1234); + MatrixBlock mbtmp = DataConverter.convertToMatrixBlock(A); + SparseBlockCSC sblock = new 
SparseBlockCSC(mbtmp.getSparseBlock()); + + int[] ptr = sblock.colPointers(); + int[] rowInd = sblock.indexesCol(0); + double[] values = sblock.valuesCol(0); + + int clen = ptr.length-1; + int[] colInd = new int[rowInd.length]; + for(int i=0; i Date: Thu, 15 Jan 2026 10:43:11 +0100 Subject: [PATCH 06/12] add SparseBlock col tests --- .../sysds/runtime/data/SparseBlockCSC.java | 2 +- .../component/sparse/SparseBlockColTest.java | 273 ++++++++++++++++++ 2 files changed, 274 insertions(+), 1 deletion(-) create mode 100644 src/test/java/org/apache/sysds/test/component/sparse/SparseBlockColTest.java diff --git a/src/main/java/org/apache/sysds/runtime/data/SparseBlockCSC.java b/src/main/java/org/apache/sysds/runtime/data/SparseBlockCSC.java index 84d172afc36..8e2e5498100 100644 --- a/src/main/java/org/apache/sysds/runtime/data/SparseBlockCSC.java +++ b/src/main/java/org/apache/sysds/runtime/data/SparseBlockCSC.java @@ -917,7 +917,7 @@ public void deleteIndexRangeCol(int c, int rl, int ru) { int len = sizeCol(c); int end = internPosFIndexGTECol(ru, c); if(end < 0) //delete all remaining - end = start + len; + end = posCol(c) + len; //overlapping array copy (shift rhs values left) System.arraycopy(_indexes, end, _indexes, start, _size - end); diff --git a/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockColTest.java b/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockColTest.java new file mode 100644 index 00000000000..15461ca08f0 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockColTest.java @@ -0,0 +1,273 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.component.sparse; + +import org.apache.sysds.runtime.data.SparseBlock; +import org.apache.sysds.runtime.data.SparseBlockCSC; +import org.apache.sysds.runtime.data.SparseBlockMCSC; +import org.apache.sysds.runtime.data.SparseRow; +import org.apache.sysds.runtime.data.SparseRowScalar; +import org.apache.sysds.runtime.data.SparseRowVector; + +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.util.DataConverter; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestUtils; +import org.junit.Assert; +import org.junit.Test; + +import java.util.Arrays; + +public class SparseBlockColTest extends AutomatedTestBase +{ + private final static int _rows = 324; + private final static int _cols = 132; + private final static double _sparsity = 0.3; + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + } + + @Test + public void testSparseBlockCSCGetReset() { + double[][] A = getRandomMatrix(_rows, _cols, -10, 10, _sparsity, 1234); + MatrixBlock mbtmp = DataConverter.convertToMatrixBlock(A); + SparseBlockColWrapper b = wrap(new SparseBlockCSC(mbtmp.getSparseBlock())); + runSparseBlockGetResetTest(b, SparseBlock.Type.CSC); + } + + @Test + public void testSparseBlockMCSCGetReset() { + double[][] A = getRandomMatrix(_rows, _cols, -10, 10, _sparsity, 1234); + MatrixBlock mbtmp = DataConverter.convertToMatrixBlock(A); + SparseBlockColWrapper b = wrap(new SparseBlockMCSC(mbtmp.getSparseBlock())); + runSparseBlockGetResetTest(b, 
SparseBlock.Type.MCSC); + } + + @Test + public void testSparseBlockCSCSetSort() { + double[][] A = getRandomMatrix(_rows, _cols, -10, 10, _sparsity, 1234); + MatrixBlock mbtmp = DataConverter.convertToMatrixBlock(A); + SparseBlockColWrapper b = wrap(new SparseBlockCSC(mbtmp.getSparseBlock())); + SparseRow[] cols = (new SparseBlockMCSC(mbtmp.getSparseBlock())).getCols(); + runSparseBlockSetSortTest(b, cols); + } + + @Test + public void testSparseBlockMCSCSetSort() { + double[][] A = getRandomMatrix(_rows, _cols, -10, 10, _sparsity, 1234); + MatrixBlock mbtmp = DataConverter.convertToMatrixBlock(A); + SparseBlockColWrapper b = wrap(new SparseBlockMCSC(mbtmp.getSparseBlock())); + SparseRow[] cols = (new SparseBlockMCSC(mbtmp.getSparseBlock())).getCols(); + runSparseBlockSetSortTest(b, cols); + } + + @Test + public void testSparseBlockCSCSetDelIdxRange() { + double[][] A = getRandomMatrix(_rows, _cols, -10, 10, _sparsity, 1234); + MatrixBlock mbtmp = DataConverter.convertToMatrixBlock(A); + SparseBlockColWrapper b = wrap(new SparseBlockCSC(mbtmp.getSparseBlock())); + SparseRow[] cols = (new SparseBlockMCSC(mbtmp.getSparseBlock())).getCols(); + runSparseBlockSetDelIdxRangeTest(b, cols); + } + + @Test + public void testSparseBlockMCSCSetDelIdxRange() { + double ultraSparsity = 0.001; + double[][] A = getRandomMatrix(_rows, _cols, -10, 10, ultraSparsity, 1234); + MatrixBlock mbtmp = DataConverter.convertToMatrixBlock(A); + SparseBlockColWrapper b = wrap(new SparseBlockMCSC(mbtmp.getSparseBlock())); + SparseRow[] cols = (new SparseBlockMCSC(mbtmp.getSparseBlock())).getCols(); + runSparseBlockSetDelIdxRangeTest(b, cols); + } + + private void runSparseBlockGetResetTest(SparseBlockColWrapper sblock, SparseBlock.Type btype) { + int c = _cols/3; + SparseRow col = sblock.getCol(c); + int size = sblock.sizeCol(c); + Assert.assertEquals(col.size(), size); + + sblock.resetCol(c); + col = sblock.getCol(c); + size = sblock.sizeCol(c); + Assert.assertEquals(0, size); + 
Assert.assertTrue(col.isEmpty()); + if(btype == SparseBlock.Type.CSC) Assert.assertTrue(col instanceof SparseRowScalar); + + // nothing changes + SparseBlockColWrapper sblock2 = sblock.copy(); + sblock.resetCol(c); + SparseRow col2 = sblock.getCol(c); + Assert.assertArrayEquals(col.indexes(), col2.indexes()); + Assert.assertArrayEquals(col.values(), col2.values(), 0); + Assert.assertEquals(sblock.getObject(), sblock2.getObject()); + } + + private void runSparseBlockSetSortTest(SparseBlockColWrapper sblock, SparseRow[] cols) { + int c = _cols/3; + SparseRow col = cols[c]; + double[] values = col.values().clone(); + int[] indexes = col.indexes().clone(); + int size = col.size(); + + // reverse + for (int i = 0; i < size/2; i++) { + double t = values[i]; + values[i] = values[size-1-i]; + values[size-1-i] = t; + int t2 = indexes[i]; + indexes[i] = indexes[size-1-i]; + indexes[size-1-i] = t2; + } + Assert.assertFalse(Arrays.equals(col.values(), values)); + Assert.assertFalse(Arrays.equals(col.indexes(), indexes)); + + SparseRow col2 = new SparseRowVector(values, indexes); + sblock.resetCol(c); + sblock.setCol(c, col2, true); + Assert.assertArrayEquals(col2.indexes(), sblock.getCol(c).indexes()); + Assert.assertArrayEquals(col2.values(), sblock.getCol(c).values(), 0); + + int nnz = (int) ((SparseBlock) sblock.getObject()).size(); + int rlen = ((SparseBlock) sblock.getObject()).numRows(); + int clen = cols.length; + RuntimeException ex = Assert.assertThrows(RuntimeException.class, + () -> ((SparseBlock) sblock.getObject()).checkValidity(rlen, clen, nnz, true)); + Assert.assertTrue(ex.getMessage().startsWith("Wrong sparse column ordering")); + + sblock.sortCol(c); + Assert.assertTrue(((SparseBlock)sblock.getObject()).checkValidity(rlen, clen, nnz, true)); + Assert.assertArrayEquals(col.indexes(), sblock.getCol(c).indexes()); + Assert.assertArrayEquals(col.values(), sblock.getCol(c).values(), 0); + } + + private void runSparseBlockSetDelIdxRangeTest(SparseBlockColWrapper 
sblock, SparseRow[] cols) { + int c = _cols/3; + double[] v = getRandomMatrix(1, _rows, -10, 10, _sparsity, 1234)[0]; + // TODO: SHORTER RANGE THAN COL LENGTH + sblock.setIndexRangeCol(c, 0, _rows, v, 0, _rows); + cols[c] = new SparseRowVector(v); + SparseBlock sblock2 = new SparseBlockMCSC(cols, false, _rows); + Assert.assertEquals(sblock2, sblock.getObject()); + + int rl = _rows/4; + int ru = _rows/2; + sblock.deleteIndexRangeCol(c, rl, ru); + for(int i=rl; i= 0 always true?! + // sblock.deleteIndexRangeCol(c, -2, ru); + // Assert.assertEquals(sblock4, sblock.getObject()); + } + + private interface SparseBlockColWrapper { + SparseRow getCol(int c); + void setCol(int c, SparseRow col, boolean deep); + void setIndexRangeCol(int c, int rl, int ru, double[] v, int vix, int vlen); + void deleteIndexRangeCol(int c, int rl, int ru); + int sizeCol(int c); + void sortCol(int c); + void resetCol(int c); + SparseBlockColWrapper copy(); + Object getObject(); + } + + private SparseBlockColWrapper wrap(SparseBlockCSC b) { + return new SparseBlockColWrapper() { + @Override + public SparseRow getCol(int c) { return b.getCol(c); } + + @Override + public void setCol(int c, SparseRow col, boolean deep) { + b.setCol(c, col, deep); } + + @Override + public void setIndexRangeCol(int c, int rl, int ru, double[] v, int vix, int vlen){ + b.setIndexRangeCol(c, rl, ru, v, vix, vlen); + } + + @Override + public void deleteIndexRangeCol(int c, int rl, int ru){ + b.deleteIndexRangeCol(c, rl, ru); + } + + @Override + public int sizeCol(int c) { return b.sizeCol(c); } + + @Override + public void sortCol(int c) { b.sortCol(c); } + + @Override + public void resetCol(int c) { b.resetCol(c); } + + @Override + public SparseBlockColWrapper copy() { return wrap(new SparseBlockCSC(b)); } + + @Override + public Object getObject() { return b; } + }; + } + + private SparseBlockColWrapper wrap(SparseBlockMCSC b) { + return new SparseBlockColWrapper() { + @Override + public SparseRow getCol(int c) { return 
b.getCol(c); } + + @Override + public void setCol(int c, SparseRow col, boolean deep) { + b.setCol(c, col, deep); } + + @Override + public void setIndexRangeCol(int c, int rl, int ru, double[] v, int vix, int vlen){ + b.setIndexRangeCol(c, rl, ru, v, vix, vlen); + } + + @Override + public void deleteIndexRangeCol(int c, int rl, int ru){ + b.deleteIndexRangeCol(c, rl, ru); + } + + @Override + public int sizeCol(int c) { return b.sizeCol(c); } + + @Override + public void sortCol(int c) { b.sortCol(c); } + + @Override + public void resetCol(int c) { b.resetCol(c, 0, 0); } + + @Override + public SparseBlockColWrapper copy() { return wrap(new SparseBlockCSC(b)); } + + @Override + public Object getObject() { return b; } + }; + } +} From 483465b58bbffe38647160a32cd1420a819d3dcd Mon Sep 17 00:00:00 2001 From: Jessica Priebe Date: Thu, 15 Jan 2026 15:22:20 +0100 Subject: [PATCH 07/12] extend SparseBlock initialization tests --- .../sysds/runtime/data/SparseBlockMCSC.java | 9 +++++++- .../sparse/SparseBlockInitializationTest.java | 22 +++++++++++++++++++ 2 files changed, 30 insertions(+), 1 deletion(-) diff --git a/src/main/java/org/apache/sysds/runtime/data/SparseBlockMCSC.java b/src/main/java/org/apache/sysds/runtime/data/SparseBlockMCSC.java index d5c03fc03fd..5c3008add3b 100644 --- a/src/main/java/org/apache/sysds/runtime/data/SparseBlockMCSC.java +++ b/src/main/java/org/apache/sysds/runtime/data/SparseBlockMCSC.java @@ -113,7 +113,14 @@ else if(columnSize == 1) { } rowPosition++; } - + } + else if(sblock instanceof SparseBlockCSC) { + clen = ((SparseBlockCSC) sblock).numCols(); + _columns = new SparseRow[clen]; + for(int i = 0; i < clen; i++) { + if(!((SparseBlockCSC) sblock).isEmptyCol(i)) + _columns[i] = ((SparseBlockCSC) sblock).getCol(i); + } } // general case SparseBlock else { diff --git a/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockInitializationTest.java b/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockInitializationTest.java 
index 3d5a3362b07..57185053a7b 100644 --- a/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockInitializationTest.java +++ b/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockInitializationTest.java @@ -404,6 +404,16 @@ public void testSparseBlockMCSCInitMCSRClenInferred() { assertEquals(sblock, sblock2); } + @Test + public void testSparseBlockMCSCInitCSC() { + double[][] A = getRandomMatrix(_rows, _cols, -10, 10, _sparsity, 1234); + MatrixBlock mbtmp = DataConverter.convertToMatrixBlock(A); + SparseBlock sblock = new SparseBlockCSC(mbtmp.getSparseBlock()); + + SparseBlock sblock2 = new SparseBlockMCSC(sblock); + assertEquals(sblock, sblock2); + } + @Test public void testSparseBlockMCSCInitColsDeep() { double[][] A = getRandomMatrix(_rows, _cols, -10, 10, _sparsity, 1234); @@ -449,4 +459,16 @@ public void testSparseBlockMCSRInitRows() { assertEquals(sblock, sblock2); assertNotSame(sblock.getRows(), sblock2.getRows()); } + + @Test + public void testSparseBlockCSRInitSize() { + int rlen = 3; + int capacity = 7; + int size = 2; + SparseBlockCSR sblock = new SparseBlockCSR(rlen, capacity, size); + sblock.append(0, 1, 1.0); + sblock.append(0, 3, 3.0); + sblock.compact(); + assertEquals("size should be 2", 2, sblock.size()); + } } From 5c67157ce5f0a29db768ba98ebc4dab073ab29a9 Mon Sep 17 00:00:00 2001 From: Jessica Priebe Date: Thu, 15 Jan 2026 15:26:57 +0100 Subject: [PATCH 08/12] extend SparseBlock checkValidity tests --- .../sysds/runtime/data/SparseBlockCOO.java | 4 +- .../sysds/runtime/data/SparseBlockCSC.java | 2 + .../sysds/runtime/data/SparseBlockCSR.java | 2 + .../sysds/runtime/data/SparseBlockDCSR.java | 2 + .../sparse/SparseBlockCheckValidityTest.java | 125 +++++++++--------- 5 files changed, 72 insertions(+), 63 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/data/SparseBlockCOO.java b/src/main/java/org/apache/sysds/runtime/data/SparseBlockCOO.java index d4d0467a098..0512fb100d9 100644 --- 
a/src/main/java/org/apache/sysds/runtime/data/SparseBlockCOO.java +++ b/src/main/java/org/apache/sysds/runtime/data/SparseBlockCOO.java @@ -236,7 +236,7 @@ public boolean checkValidity(int rlen, int clen, long nnz, boolean strict) { int apos = pos(i); int alen = size(i); for(int k=apos+1; k _cindexes[k] ) + if(_cindexes[k-1] > _cindexes[k]) throw new RuntimeException("Wrong sparse row ordering: " + k + " "+_cindexes[k-1]+" "+_cindexes[k]); } @@ -246,6 +246,8 @@ public boolean checkValidity(int rlen, int clen, long nnz, boolean strict) { if( _values[i] == 0) throw new RuntimeException("The values array should not contain zeros." + " The " + i + "th value is "+_values[i]); + if(_cindexes[i] < 0 || _rindexes[i] < 0) + throw new RuntimeException("Invalid index at pos=" + i); } //5. a capacity that is no larger than nnz times the resize factor diff --git a/src/main/java/org/apache/sysds/runtime/data/SparseBlockCSC.java b/src/main/java/org/apache/sysds/runtime/data/SparseBlockCSC.java index 8e2e5498100..1528c660e4d 100644 --- a/src/main/java/org/apache/sysds/runtime/data/SparseBlockCSC.java +++ b/src/main/java/org/apache/sysds/runtime/data/SparseBlockCSC.java @@ -560,6 +560,8 @@ public boolean checkValidity(int rlen, int clen, long nnz, boolean strict) { throw new RuntimeException( "The values array should not contain zeros." + " The " + i + "th value is " + _values[i]); } + if(_indexes[i] < 0) + throw new RuntimeException("Invalid index at pos=" + i); } //6. a capacity that is no larger than nnz times resize factor. 
diff --git a/src/main/java/org/apache/sysds/runtime/data/SparseBlockCSR.java b/src/main/java/org/apache/sysds/runtime/data/SparseBlockCSR.java index a6000ff965f..c9fe912ec5c 100644 --- a/src/main/java/org/apache/sysds/runtime/data/SparseBlockCSR.java +++ b/src/main/java/org/apache/sysds/runtime/data/SparseBlockCSR.java @@ -964,6 +964,8 @@ public boolean checkValidity(int rlen, int clen, long nnz, boolean strict) { throw new RuntimeException("The values array should not contain zeros." + " The " + i + "th value is "+_values[i]); } + if(_indexes[i] < 0) + throw new RuntimeException("Invalid index at pos=" + i); } //6. a capacity that is no larger than nnz times resize factor. diff --git a/src/main/java/org/apache/sysds/runtime/data/SparseBlockDCSR.java b/src/main/java/org/apache/sysds/runtime/data/SparseBlockDCSR.java index eb75cade617..c3564c2e5b7 100644 --- a/src/main/java/org/apache/sysds/runtime/data/SparseBlockDCSR.java +++ b/src/main/java/org/apache/sysds/runtime/data/SparseBlockDCSR.java @@ -729,6 +729,8 @@ public boolean checkValidity(int rlen, int clen, long nnz, boolean strict) { throw new RuntimeException("The values array should not contain zeros." + " The " + i + "th value is "+_values[i]); } + if(_colidx[i] < 0) + throw new RuntimeException("Invalid index at pos=" + i); } //6. a capacity that is no larger than nnz times resize factor. 
diff --git a/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockCheckValidityTest.java b/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockCheckValidityTest.java index b78b8642356..b219209126e 100644 --- a/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockCheckValidityTest.java +++ b/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockCheckValidityTest.java @@ -415,80 +415,22 @@ public void testSparseBlockMCSRInvalidIndices() { @Test public void testSparseBlockCOOInvalidValue() { - SparseBlockCOO block = new SparseBlockCOO(3, 3); - - int[] r = new int[]{0, 1, 2}; - int[] c = new int[]{0, 1, 2}; - double[] v = new double[]{1, 2, 0}; // contains 0 - - setField(block, "_rindexes", r); - setField(block, "_cindexes", c); - setField(block, "_values", v); - setField(block, "_size", 3); - - RuntimeException ex = assertThrows(RuntimeException.class, - () -> block.checkValidity(3, 3, 3, false)); - - assertTrue(ex.getMessage().startsWith("The values array should not contain zeros")); + runSparseBlockInvalidValueTest(SparseBlock.Type.COO); } @Test public void testSparseBlockCSCInvalidValue() { - SparseBlockCSC block = new SparseBlockCSC(3, 3); - - int[] ptr = new int[]{0, 3, 3, 3}; - int[] idxs = new int[]{0, 1, 2}; - double[] v = new double[]{1, 1, 0}; - - setField(block, "_ptr", ptr); - setField(block, "_indexes", idxs); - setField(block, "_values", v); - setField(block, "_size", 3); - - RuntimeException ex = assertThrows(RuntimeException.class, - () -> block.checkValidity(3, 3, 3, false)); - - assertTrue(ex.getMessage().startsWith("The values array should not contain zeros")); + runSparseBlockInvalidValueTest(SparseBlock.Type.CSC); } @Test public void testSparseBlockCSRInvalidValue() { - SparseBlockCSR block = new SparseBlockCSR(3, 3); - - int[] ptr = new int[]{0, 3, 3, 3}; - int[] idxs = new int[]{0, 1, 2}; - double[] v = new double[]{1, 1, 0}; - - setField(block, "_ptr", ptr); - setField(block, "_indexes", idxs); - 
setField(block, "_values", v); - setField(block, "_size", 3); - - RuntimeException ex = assertThrows(RuntimeException.class, - () -> block.checkValidity(3, 3, 3, false)); - - assertTrue(ex.getMessage().startsWith("The values array should not contain zeros")); + runSparseBlockInvalidValueTest(SparseBlock.Type.CSR); } @Test public void testSparseBlockDCSRInvalidValue() { - SparseBlockDCSR block = new SparseBlockDCSR(3, 3); - - int[] rowIdxs = new int[]{0, 1, 2}; - int[] rowPtr = new int[]{0, 1, 2, 3}; - int[] colIdxs = new int[]{0, 1, 2}; - double[] v = new double[]{1, 1, 0}; - - setField(block, "_rowidx", rowIdxs); - setField(block, "_rowptr", rowPtr); - setField(block, "_colidx", colIdxs); - setField(block, "_values", v); - setField(block, "_size", 3); - - RuntimeException ex = assertThrows(RuntimeException.class, - () -> block.checkValidity(1, 3, 3, false)); - - assertTrue(ex.getMessage().startsWith("The values array should not contain zeros")); + runSparseBlockInvalidValueTest(SparseBlock.Type.DCSR); } @Test @@ -519,6 +461,32 @@ public void testSparseBlockMCSRInvalidValue() { assertTrue(ex.getMessage().startsWith("The values are expected to be non zeros")); } + @Test + public void testSparseBlockCOOInvalidRIndex() { + runSparseBlockInvalidIndexTest(SparseBlock.Type.COO, "_rindexes"); + } + + @Test + public void testSparseBlockCOOInvalidCIndex() { + runSparseBlockInvalidIndexTest(SparseBlock.Type.COO, "_cindexes"); + } + + + @Test + public void testSparseBlockCSCInvalidIndex() { + runSparseBlockInvalidIndexTest(SparseBlock.Type.CSC, "_indexes"); + } + + @Test + public void testSparseBlockCSRInvalidIndex() { + runSparseBlockInvalidIndexTest(SparseBlock.Type.CSR, "_indexes"); + } + + @Test + public void testSparseBlockDCSRInvalidIndex() { + runSparseBlockInvalidIndexTest(SparseBlock.Type.DCSR, "_colidx"); + } + @Test public void testSparseBlockCOOCapacityExceedsAllowedLimit() { SparseBlockCOO block = new SparseBlockCOO(3, 50); @@ -611,6 +579,39 @@ private void 
runSparseBlockInvalidDimensionsTest(SparseBlock block) { assertTrue(ex2.getMessage().startsWith("Invalid block dimensions")); } + private void runSparseBlockInvalidValueTest(SparseBlock.Type btype) { + double[][] A = getRandomMatrix(_rows, _cols, -10, 10, _sparsity, 13); + + MatrixBlock mbtmp = DataConverter.convertToMatrixBlock(A); + SparseBlock srtmp = mbtmp.getSparseBlock(); + SparseBlock sblock = SparseBlockFactory.copySparseBlock(btype, srtmp, true); + + double[] values = (double[]) getField(sblock, "_values"); + values[values.length-1] = 0.; + setField(sblock, "_values", Arrays.copyOfRange(values, 0, values.length)); + + RuntimeException ex = assertThrows(RuntimeException.class, + () -> sblock.checkValidity(_rows, _cols, sblock.size(), true)); + assertTrue(ex.getMessage().startsWith("The values array should not contain zeros")); + } + + + private void runSparseBlockInvalidIndexTest(SparseBlock.Type btype, String indexName) { + double[][] A = getRandomMatrix(_rows, _cols, -10, 10, _sparsity, 13); + + MatrixBlock mbtmp = DataConverter.convertToMatrixBlock(A); + SparseBlock srtmp = mbtmp.getSparseBlock(); + SparseBlock sblock = SparseBlockFactory.copySparseBlock(btype, srtmp, true); + + int[] indexes = (int[]) getField(sblock, indexName); + indexes[0] = -1; + setField(sblock, indexName, Arrays.copyOfRange(indexes, 0, indexes.length)); + + RuntimeException ex = assertThrows(RuntimeException.class, + () -> sblock.checkValidity(_rows, _cols, sblock.size(), true)); + assertTrue(ex.getMessage().startsWith("Invalid index at pos")); + } + private static void setField(Object obj, String name, Object value) { try { Field f = obj.getClass().getDeclaredField(name); From 5708be67a00d037ef85fc7e1920451a2272b1de8 Mon Sep 17 00:00:00 2001 From: Jessica Priebe Date: Fri, 16 Jan 2026 11:47:44 +0100 Subject: [PATCH 09/12] add SparseBlock compact functionality and corresponding tests + fix SparseBlock DCSR checkValidity --- .../sysds/runtime/data/SparseBlock.java | 6 + 
.../sysds/runtime/data/SparseBlockCOO.java | 14 ++ .../sysds/runtime/data/SparseBlockCSC.java | 19 +++ .../sysds/runtime/data/SparseBlockCSR.java | 3 +- .../sysds/runtime/data/SparseBlockDCSR.java | 31 +++- .../sysds/runtime/data/SparseBlockMCSC.java | 18 +++ .../sysds/runtime/data/SparseBlockMCSR.java | 16 ++ .../sparse/SparseBlockCheckValidityTest.java | 3 + .../sparse/SparseBlockCompactTest.java | 147 ++++++++++++++++++ 9 files changed, 253 insertions(+), 4 deletions(-) create mode 100644 src/test/java/org/apache/sysds/test/component/sparse/SparseBlockCompactTest.java diff --git a/src/main/java/org/apache/sysds/runtime/data/SparseBlock.java b/src/main/java/org/apache/sysds/runtime/data/SparseBlock.java index 29ee6c79f78..ac2876d4c44 100644 --- a/src/main/java/org/apache/sysds/runtime/data/SparseBlock.java +++ b/src/main/java/org/apache/sysds/runtime/data/SparseBlock.java @@ -95,6 +95,12 @@ public enum Type { * @param r row index */ public abstract void compact(int r); + + /** + * In-place compaction of non-zero-entries; removes zero entries + * and shifts non-zero entries to the left if necessary. 
+ */ + public abstract void compact(); //////////////////////// //obtain basic meta data diff --git a/src/main/java/org/apache/sysds/runtime/data/SparseBlockCOO.java b/src/main/java/org/apache/sysds/runtime/data/SparseBlockCOO.java index 0512fb100d9..8e57559ca3b 100644 --- a/src/main/java/org/apache/sysds/runtime/data/SparseBlockCOO.java +++ b/src/main/java/org/apache/sysds/runtime/data/SparseBlockCOO.java @@ -193,6 +193,20 @@ public void compact(int r) { //do nothing everything preallocated } + @Override + public void compact() { + int pos = 0; + for(int i=0; i< _values.length; i++) { + if(_values[i] != 0){ + _values[pos] = _values[i]; + _rindexes[pos] = _rindexes[i]; + _cindexes[pos] = _cindexes[i]; + pos++; + } + } + _size = pos; + } + @Override public int numRows() { return _rlen; diff --git a/src/main/java/org/apache/sysds/runtime/data/SparseBlockCSC.java b/src/main/java/org/apache/sysds/runtime/data/SparseBlockCSC.java index 1528c660e4d..16a3b3163fe 100644 --- a/src/main/java/org/apache/sysds/runtime/data/SparseBlockCSC.java +++ b/src/main/java/org/apache/sysds/runtime/data/SparseBlockCSC.java @@ -357,6 +357,25 @@ public void compact(int r) { //do nothing everything preallocated } + @Override + public void compact() { + int pos = 0; + for(int i=0; i -1) diff --git a/src/main/java/org/apache/sysds/runtime/data/SparseBlockCSR.java b/src/main/java/org/apache/sysds/runtime/data/SparseBlockCSR.java index c9fe912ec5c..947078b2171 100644 --- a/src/main/java/org/apache/sysds/runtime/data/SparseBlockCSR.java +++ b/src/main/java/org/apache/sysds/runtime/data/SparseBlockCSR.java @@ -350,7 +350,8 @@ public void allocate(int r, int ennz, int maxnnz) { public void compact(int r) { //do nothing everything preallocated } - + + @Override public void compact() { int pos = 0; for(int i=0; i _rowidx[i]) throw new RuntimeException("Row indices are decreasing at row: " + i + ", with indices " + _rowidx[i-1] + " > " +_rowidx[i]); } - for (int i = 1; i < _rowptr.length; i++ ) { + 
for (int i = 1; i < _nnzr+1; i++ ) { if (_rowptr[i - 1] > _rowptr[i]) { throw new RuntimeException("Row pointers are decreasing at row: " + i + ", with pointers " + _rowptr[i-1] + " > " +_rowptr[i]); @@ -713,7 +738,7 @@ public boolean checkValidity(int rlen, int clen, long nnz, boolean strict) { } //4. sorted column indexes per row - for (int i = 0; i < _rowptr.length-1; i++) { + for (int i = 0; i < _nnzr; i++) { int apos = _rowptr[i]; int alen = _rowptr[i+1] - apos; diff --git a/src/main/java/org/apache/sysds/runtime/data/SparseBlockMCSC.java b/src/main/java/org/apache/sysds/runtime/data/SparseBlockMCSC.java index 5c3008add3b..f8cdfebee4a 100644 --- a/src/main/java/org/apache/sysds/runtime/data/SparseBlockMCSC.java +++ b/src/main/java/org/apache/sysds/runtime/data/SparseBlockMCSC.java @@ -303,6 +303,24 @@ else if(_columns[c] instanceof SparseRowScalar) { } } + @Override + public void compact() { + for(int i = 0; i < numCols(); i++) { + if(isAllocatedCol(i)) { + if(_columns[i] instanceof SparseRowVector) { + _columns[i].compact(); + if(_columns[i].isEmpty()) + _columns[i] = null; + } + else if(_columns[i] instanceof SparseRowScalar) { + SparseRowScalar s = (SparseRowScalar) _columns[i]; + if(s.getValue() == 0) + _columns[i] = null; + } + } + } + } + @Override public int numRows() { return _rlen; diff --git a/src/main/java/org/apache/sysds/runtime/data/SparseBlockMCSR.java b/src/main/java/org/apache/sysds/runtime/data/SparseBlockMCSR.java index 9a77646da14..b797560c428 100644 --- a/src/main/java/org/apache/sysds/runtime/data/SparseBlockMCSR.java +++ b/src/main/java/org/apache/sysds/runtime/data/SparseBlockMCSR.java @@ -199,6 +199,22 @@ else if(_rows[r] instanceof SparseRowScalar) { } } } + + @Override + public void compact() { + for(int i = 0; i < numRows(); i++) { + if(isAllocated(i)) { + if(_rows[i] instanceof SparseRowVector) { + _rows[i].compact(); + if(_rows[i].isEmpty()) _rows[i] = null; + } + else if(_rows[i] instanceof SparseRowScalar) { + SparseRowScalar s 
= (SparseRowScalar) _rows[i]; + if(s.getValue() == 0) _rows[i] = null; + } + } + } + } @Override public int numRows() { diff --git a/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockCheckValidityTest.java b/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockCheckValidityTest.java index b219209126e..58913ce0957 100644 --- a/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockCheckValidityTest.java +++ b/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockCheckValidityTest.java @@ -251,6 +251,7 @@ public void testSparseBlockDCSRDecreasingRowIndices() { setField(block, "_colidx", colIdxs); setField(block, "_values", v); setField(block, "_size", 3); + setField(block, "_nnzr", 3); RuntimeException ex = assertThrows(RuntimeException.class, () -> block.checkValidity(3, 10, 3, true)); @@ -272,6 +273,7 @@ public void testSparseBlockDCSRDecreasingRowPointers() { setField(block, "_colidx", colIdxs); setField(block, "_values", v); setField(block, "_size", 3); + setField(block, "_nnzr", 3); RuntimeException ex = assertThrows(RuntimeException.class, () -> block.checkValidity(3, 10, 3, true)); @@ -350,6 +352,7 @@ public void testSparseBlockDCSRUnsortedColumnIndicesWithinRow() { setField(block, "_colidx", colIdxs); setField(block, "_values", v); setField(block, "_size", 3); + setField(block, "_nnzr", 2); RuntimeException ex = assertThrows(RuntimeException.class, () -> block.checkValidity(1, 3, 3, false)); diff --git a/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockCompactTest.java b/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockCompactTest.java new file mode 100644 index 00000000000..6505f15b711 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockCompactTest.java @@ -0,0 +1,147 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.component.sparse; + +import java.lang.reflect.Field; +import java.util.Arrays; + +import org.apache.sysds.runtime.data.SparseBlock; +import org.apache.sysds.runtime.data.SparseBlockFactory; +import org.apache.sysds.runtime.data.SparseRow; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.util.DataConverter; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestUtils; +import org.junit.Test; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; + +public class SparseBlockCompactTest extends AutomatedTestBase +{ + private final static int _rows = 324; + private final static int _cols = 132; + private final static double _sparsity = 0.22; + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + } + + @Test + public void testSparseBlockCompactCOO() { + runSparseBlockCompactZerosTest(SparseBlock.Type.COO); + } + + @Test + public void testSparseBlockCompactCSC() { + runSparseBlockCompactZerosTest(SparseBlock.Type.CSC); + } + + @Test + public void testSparseBlockCompactCSR() { + runSparseBlockCompactZerosTest(SparseBlock.Type.CSR); + } + + @Test + public void 
testSparseBlockCompactDCSR() { + runSparseBlockCompactZerosTest(SparseBlock.Type.DCSR); + } + + @Test + public void testSparseBlockCompactMCSC() { + runSparseBlockModifiedCompactZerosTest(SparseBlock.Type.MCSC, "_columns"); + } + + @Test + public void testSparseBlockCompactMCSR() { + runSparseBlockModifiedCompactZerosTest(SparseBlock.Type.MCSR, "_rows"); + } + + private void runSparseBlockCompactZerosTest(SparseBlock.Type btype) { + + double[][] A = getRandomMatrix(_rows, _cols, -10, 10, _sparsity, 13); + + MatrixBlock mbtmp = DataConverter.convertToMatrixBlock(A); + SparseBlock srtmp = mbtmp.getSparseBlock(); + SparseBlock sblock = SparseBlockFactory.copySparseBlock(btype, srtmp, true); + + double[] values = (double[]) getField(sblock, "_values"); + values[0] = 0.0; + values[values.length-1] = 0.0; + setField(sblock, "_values", Arrays.copyOfRange(values, 0, values.length)); + + RuntimeException ex = assertThrows(RuntimeException.class, + () -> sblock.checkValidity(_rows, _cols, sblock.size(), true)); + assertTrue(ex.getMessage().startsWith("The values array should not contain zeros")); + long size = sblock.size(); + + sblock.compact(); + + assertTrue("should pass checkValidity", sblock.checkValidity(_rows, _cols, sblock.size(), true)); + assertEquals(size-2, sblock.size()); + } + + private void runSparseBlockModifiedCompactZerosTest(SparseBlock.Type btype, String field) { + + double[][] A = getRandomMatrix(_rows, _cols, -10, 10, _sparsity, 13); + + MatrixBlock mbtmp = DataConverter.convertToMatrixBlock(A); + SparseBlock srtmp = mbtmp.getSparseBlock(); + SparseBlock sblock = SparseBlockFactory.copySparseBlock(btype, srtmp, true); + + SparseRow[] sr = (SparseRow[]) getField(sblock, field); + double[] values = sr[0].values(); + values[0] = 0.0; + values[values.length-1] = 0.0; + setField(sr[0], "values", Arrays.copyOfRange(values, 0, values.length)); + + RuntimeException ex = assertThrows(RuntimeException.class, + () -> sblock.checkValidity(_rows, _cols, 
sblock.size(), true)); + assertTrue(ex.getMessage().startsWith("The values are expected to be non zeros")); + long size = sblock.size(); + + sblock.compact(); + + assertTrue("should pass checkValidity", sblock.checkValidity(_rows, _cols, sblock.size(), true)); + assertEquals(size-2, sblock.size()); + } + + private static void setField(Object obj, String name, Object value) { + try { + Field f = obj.getClass().getDeclaredField(name); + f.setAccessible(true); + f.set(obj, value); + } catch (Exception ex) { + throw new RuntimeException("Reflection failed: " + ex.getMessage()); + } + } + + private static Object getField(Object obj, String name) { + try { + Field f = obj.getClass().getDeclaredField(name); + f.setAccessible(true); + return f.get(obj); + } catch (Exception ex) { + throw new RuntimeException("Reflection failed: " + ex.getMessage()); + } + } +} From ef2e3f3ec4ed62fb5f30b9f8b28329fa2546d119 Mon Sep 17 00:00:00 2001 From: Jessica Priebe Date: Mon, 19 Jan 2026 14:55:31 +0100 Subject: [PATCH 10/12] minor fixes --- .../sparse/SparseBlockAlignment.java | 19 +++------- .../component/sparse/SparseBlockColTest.java | 38 ++++++++----------- .../sparse/SparseBlockCompactTest.java | 13 ------- .../sparse/SparseBlockEqualsTest.java | 2 +- .../test/component/sparse/SparseRowTest.java | 11 +++--- 5 files changed, 29 insertions(+), 54 deletions(-) diff --git a/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockAlignment.java b/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockAlignment.java index f1f044baa5e..006fac8f03a 100644 --- a/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockAlignment.java +++ b/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockAlignment.java @@ -270,23 +270,16 @@ else if( i<37 ) {//CSR/COO different after update pos Assert.fail("Wrong row alignment indicated: "+rowsAligned37+", expected: "+positive); if( !rowsAlignedRest ) Assert.fail("Wrong row alignment rest indicated: false."); + + //init 
third sparse block with different number of rows + SparseBlock sblock3 =SparseBlockFactory.createSparseBlock(btype, rows+1); + if (sblock.isAligned(sblock3)) { + Assert.fail("Wrong alignment different rows indicated: true."); + } } catch(Exception ex) { ex.printStackTrace(); throw new RuntimeException(ex); } } - - @Test - public void testSparseBlockDifferentNumRows() { - double[][] A = getRandomMatrix(rows, cols, -10, 10, sparsity3, 1234); - MatrixBlock mbtmp = DataConverter.convertToMatrixBlock(A); - SparseBlock sblock = mbtmp.getSparseBlock(); - - double[][] B = getRandomMatrix(2*rows, cols, -10, 10, sparsity3, 1234); - MatrixBlock mbtmp2 = DataConverter.convertToMatrixBlock(B); - SparseBlock sblock2 = mbtmp2.getSparseBlock(); - - Assert.assertFalse(sblock.isAligned(sblock2)); - } } diff --git a/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockColTest.java b/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockColTest.java index 15461ca08f0..19a313e0259 100644 --- a/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockColTest.java +++ b/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockColTest.java @@ -91,8 +91,7 @@ public void testSparseBlockCSCSetDelIdxRange() { @Test public void testSparseBlockMCSCSetDelIdxRange() { - double ultraSparsity = 0.001; - double[][] A = getRandomMatrix(_rows, _cols, -10, 10, ultraSparsity, 1234); + double[][] A = getRandomMatrix(_rows, _cols, -10, 10, _sparsity, 1234); MatrixBlock mbtmp = DataConverter.convertToMatrixBlock(A); SparseBlockColWrapper b = wrap(new SparseBlockMCSC(mbtmp.getSparseBlock())); SparseRow[] cols = (new SparseBlockMCSC(mbtmp.getSparseBlock())).getCols(); @@ -161,30 +160,25 @@ private void runSparseBlockSetSortTest(SparseBlockColWrapper sblock, SparseRow[] private void runSparseBlockSetDelIdxRangeTest(SparseBlockColWrapper sblock, SparseRow[] cols) { int c = _cols/3; - double[] v = getRandomMatrix(1, _rows, -10, 10, _sparsity, 1234)[0]; - // TODO: SHORTER RANGE 
THAN COL LENGTH - sblock.setIndexRangeCol(c, 0, _rows, v, 0, _rows); - cols[c] = new SparseRowVector(v); - SparseBlock sblock2 = new SparseBlockMCSC(cols, false, _rows); - Assert.assertEquals(sblock2, sblock.getObject()); - int rl = _rows/4; int ru = _rows/2; + + SparseRow[] cols2 = Arrays.copyOf(cols, cols.length); + double[] v = getRandomMatrix(1, _rows, -10, 10, 1, 1234)[0]; + for(int i=0; i= 0 always true?! - // sblock.deleteIndexRangeCol(c, -2, ru); - // Assert.assertEquals(sblock4, sblock.getObject()); + for(int i=ru; i<_rows; i++) cols2[c].set(i, 0); + Assert.assertEquals(sblock2, sblock.getObject()); } private interface SparseBlockColWrapper { @@ -264,7 +258,7 @@ public void deleteIndexRangeCol(int c, int rl, int ru){ public void resetCol(int c) { b.resetCol(c, 0, 0); } @Override - public SparseBlockColWrapper copy() { return wrap(new SparseBlockCSC(b)); } + public SparseBlockColWrapper copy() { return wrap(new SparseBlockMCSC(b)); } @Override public Object getObject() { return b; } diff --git a/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockCompactTest.java b/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockCompactTest.java index 6505f15b711..013e603e927 100644 --- a/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockCompactTest.java +++ b/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockCompactTest.java @@ -20,7 +20,6 @@ package org.apache.sysds.test.component.sparse; import java.lang.reflect.Field; -import java.util.Arrays; import org.apache.sysds.runtime.data.SparseBlock; import org.apache.sysds.runtime.data.SparseBlockFactory; @@ -87,7 +86,6 @@ private void runSparseBlockCompactZerosTest(SparseBlock.Type btype) { double[] values = (double[]) getField(sblock, "_values"); values[0] = 0.0; values[values.length-1] = 0.0; - setField(sblock, "_values", Arrays.copyOfRange(values, 0, values.length)); RuntimeException ex = assertThrows(RuntimeException.class, () -> sblock.checkValidity(_rows, _cols, 
sblock.size(), true)); @@ -112,7 +110,6 @@ private void runSparseBlockModifiedCompactZerosTest(SparseBlock.Type btype, Stri double[] values = sr[0].values(); values[0] = 0.0; values[values.length-1] = 0.0; - setField(sr[0], "values", Arrays.copyOfRange(values, 0, values.length)); RuntimeException ex = assertThrows(RuntimeException.class, () -> sblock.checkValidity(_rows, _cols, sblock.size(), true)); @@ -125,16 +122,6 @@ private void runSparseBlockModifiedCompactZerosTest(SparseBlock.Type btype, Stri assertEquals(size-2, sblock.size()); } - private static void setField(Object obj, String name, Object value) { - try { - Field f = obj.getClass().getDeclaredField(name); - f.setAccessible(true); - f.set(obj, value); - } catch (Exception ex) { - throw new RuntimeException("Reflection failed: " + ex.getMessage()); - } - } - private static Object getField(Object obj, String name) { try { Field f = obj.getClass().getDeclaredField(name); diff --git a/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockEqualsTest.java b/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockEqualsTest.java index 29f7c7e2463..43dbae3004d 100644 --- a/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockEqualsTest.java +++ b/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockEqualsTest.java @@ -197,7 +197,7 @@ private static void runSparseBlockNotEqualsDenseValuesEmptyRowTest(SparseBlock.T } private static void runSparseBlockNotEqualsDenseValuesNonZeroTest(SparseBlock.Type type) { - double[][] A = new double[][]{{1., 0., 3.}, {0., 0., 0.}, {0., 0., 0.},{0., 0., 1.}, {4., 0., 6.}}; + double[][] A = new double[][]{{1., 0., 3.}, {0., 0., 0.}, {0., 0., 0.}, {0., 0., 1.}, {4., 0., 6.}}; MatrixBlock mbtmp = DataConverter.convertToMatrixBlock(A); SparseBlock srtmp = mbtmp.getSparseBlock(); diff --git a/src/test/java/org/apache/sysds/test/component/sparse/SparseRowTest.java b/src/test/java/org/apache/sysds/test/component/sparse/SparseRowTest.java index 
d2f588361c6..307b4335c49 100644 --- a/src/test/java/org/apache/sysds/test/component/sparse/SparseRowTest.java +++ b/src/test/java/org/apache/sysds/test/component/sparse/SparseRowTest.java @@ -205,11 +205,12 @@ public void testSparseRowVectorSearchIndexesFirstLTENotFound() { } @Test - public void testSparseRowVectorSetIndexRangeWithRecap() { + public void testSparseRowVectorSetIndexRangeWithoutRecap() { SparseRowVector srv = new SparseRowVector(); - srv.add(1, 1.0); - srv.add(4, 4.0); - srv.add(5, 5.0); - srv.setIndexRange(2, 3, new double[]{2.0, 3.0}, 0, 2); + int capacity = srv.capacity(); + + double[] v = getRandomMatrix(1, capacity, minVal, maxVal, sparsity, 7)[0]; + srv.setIndexRange(0, capacity, v, 0, capacity); + assertEquals(capacity, srv.capacity()); } } From e4cdd0017f644906b8b51948b770b0e2a839c326 Mon Sep 17 00:00:00 2001 From: Jessica Priebe Date: Tue, 20 Jan 2026 11:13:52 +0100 Subject: [PATCH 11/12] updated checkValidity tests --- .../sparse/SparseBlockCheckValidityTest.java | 328 ++++++------------ 1 file changed, 103 insertions(+), 225 deletions(-) diff --git a/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockCheckValidityTest.java b/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockCheckValidityTest.java index 58913ce0957..216e4e4ef89 100644 --- a/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockCheckValidityTest.java +++ b/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockCheckValidityTest.java @@ -27,8 +27,6 @@ import org.apache.sysds.runtime.data.SparseBlockFactory; import org.apache.sysds.runtime.data.SparseBlockMCSC; import org.apache.sysds.runtime.data.SparseBlockMCSR; -import org.apache.sysds.runtime.data.SparseRow; -import org.apache.sysds.runtime.data.SparseRowVector; import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.util.DataConverter; @@ -117,30 +115,27 @@ public void testSparseBlockMCSRInvalidDimensions() { @Test public void 
testSparseBlockCOOIncorrectArrayLengths() { SparseBlockCOO sblock = new SparseBlockCOO(2, 2); - // nnz > capacity + RuntimeException ex = assertThrows(RuntimeException.class, () -> sblock.checkValidity(2, 2, 4, false)); - assertEquals("Incorrect array lengths.", ex.getMessage()); } @Test public void testSparseBlockCSCIncorrectArrayLengths() { SparseBlockCSC sblock = new SparseBlockCSC(2, 2, 2); - // nnz > capacity + RuntimeException ex = assertThrows(RuntimeException.class, () -> sblock.checkValidity(2, 3, 6, false)); - assertEquals("Incorrect array lengths.", ex.getMessage()); } @Test public void testSparseBlockCSRIncorrectArrayLengths() { SparseBlockCSR sblock = new SparseBlockCSR(2, 2, 1); - // nnz > capacity + RuntimeException ex = assertThrows(RuntimeException.class, () -> sblock.checkValidity(3, 2, 6, false)); - assertEquals("Incorrect array lengths.", ex.getMessage()); } @@ -151,10 +146,9 @@ public void testSparseBlockDCSRIncorrectArrayLengths() { // cut off last value int[] rowptr = (int[]) getField(sblock,"_rowptr"); setField(sblock, "_rowptr", Arrays.copyOfRange(rowptr, 0, rowptr.length-1)); - // nnz > capacity + RuntimeException ex = assertThrows(RuntimeException.class, () -> sblock.checkValidity(3, 2, 6, false)); - assertEquals("Incorrect array lengths.", ex.getMessage()); } @@ -162,10 +156,8 @@ public void testSparseBlockDCSRIncorrectArrayLengths() { public void testSparseBlockMCSCIncorrectArrayLengths() { SparseBlockMCSC sblock = new SparseBlockMCSC(2, 2); - // nnz > capacity RuntimeException ex = assertThrows(RuntimeException.class, () -> sblock.checkValidity(3, 2, 1, false)); - assertTrue(ex.getMessage().startsWith("Incorrect size")); } @@ -173,246 +165,151 @@ public void testSparseBlockMCSCIncorrectArrayLengths() { public void testSparseBlockMCSRIncorrectArrayLengths() { SparseBlockMCSR sblock = new SparseBlockMCSR(2, 2); - // nnz > capacity RuntimeException ex = assertThrows(RuntimeException.class, () -> sblock.checkValidity(3, 2, 1, false)); - 
assertTrue(ex.getMessage().startsWith("Incorrect size")); } @Test public void testSparseBlockCOOUnsortedRowIndices() { - SparseBlockCOO block = new SparseBlockCOO(10, 3); - - int[] r = new int[]{0, 5, 2}; // unsorted - int[] c = new int[]{0, 1, 2}; - double[] v = new double[]{1, 1, 1}; - - setField(block, "_rindexes", r); - setField(block, "_cindexes", c); - setField(block, "_values", v); - setField(block, "_size", 3); + SparseBlockCOO sblock = new SparseBlockCOO(getFixedSparseBlock()); + int[] r = new int[]{0, 2, 1, 3}; // unsorted + setField(sblock, "_rindexes", r); RuntimeException ex = assertThrows(RuntimeException.class, - () -> block.checkValidity(10, 10, 3, false)); - + () -> sblock.checkValidity(4, 4, 6, false)); assertEquals("Wrong sorted order of row indices", ex.getMessage()); } @Test public void testSparseBlockCSCDecreasingColPointers() { - SparseBlockCSC block = new SparseBlockCSC(10, 3); - - int[] ptr = new int[]{0, 2, 1, 3}; // unsorted col pointer - int[] idxs = new int[]{0, 1, 2}; - double[] v = new double[]{1, 1, 1}; - - setField(block, "_ptr", ptr); - setField(block, "_indexes", idxs); - setField(block, "_values", v); - setField(block, "_size", 3); + SparseBlockCSC sblock = new SparseBlockCSC(getFixedSparseBlock()); + int[] ptr = new int[]{0, 2, 1, 4, 6}; // unsorted + setField(sblock, "_ptr", ptr); RuntimeException ex = assertThrows(RuntimeException.class, - () -> block.checkValidity(10, 3, 3, true)); - + () -> sblock.checkValidity(4, 4, 6, true)); assertTrue(ex.getMessage().startsWith("Column pointers are decreasing at column")); } @Test public void testSparseBlockCSRDecreasingRowPointers() { - SparseBlockCSR block = new SparseBlockCSR(3, 3); - - int[] ptr = new int[]{0, 2, 1, 3}; // unsorted row pointer - int[] idxs = new int[]{0, 1, 2}; - double[] v = new double[]{1, 1, 1}; - - setField(block, "_ptr", ptr); - setField(block, "_indexes", idxs); - setField(block, "_values", v); - setField(block, "_size", 3); + SparseBlockCSR sblock = new 
SparseBlockCSR(getFixedSparseBlock()); + int[] ptr = new int[]{0, 2, 1, 4, 6}; // unsorted + setField(sblock, "_ptr", ptr); RuntimeException ex = assertThrows(RuntimeException.class, - () -> block.checkValidity(3, 10, 3, true)); - + () -> sblock.checkValidity(4, 4, 6, true)); assertTrue(ex.getMessage().startsWith("Row pointers are decreasing at row")); } @Test public void testSparseBlockDCSRDecreasingRowIndices() { - SparseBlockDCSR block = new SparseBlockDCSR(3, 3); - - int[] rowIdxs = new int[]{0, 2, 1}; // unsorted - int[] rowPtr = new int[]{0, 1, 2, 3}; - int[] colIdxs = new int[]{0, 1, 2}; - double[] v = new double[]{1, 1, 1}; - - setField(block, "_rowidx", rowIdxs); - setField(block, "_rowptr", rowPtr); - setField(block, "_colidx", colIdxs); - setField(block, "_values", v); - setField(block, "_size", 3); - setField(block, "_nnzr", 3); + SparseBlockDCSR sblock = new SparseBlockDCSR(getFixedSparseBlock()); + int[] rowIdxs = new int[]{0, 2, 1, 3}; // unsorted + setField(sblock, "_rowidx", rowIdxs); RuntimeException ex = assertThrows(RuntimeException.class, - () -> block.checkValidity(3, 10, 3, true)); - + () -> sblock.checkValidity(4, 4, 6, false)); assertTrue(ex.getMessage().startsWith("Row indices are decreasing at row")); } @Test public void testSparseBlockDCSRDecreasingRowPointers() { - SparseBlockDCSR block = new SparseBlockDCSR(3, 3); - - int[] rowIdxs = new int[]{0, 1, 2}; - int[] rowPtr = new int[]{0, 1, 3, 2}; // unsorted - int[] colIdxs = new int[]{0, 1, 2}; - double[] v = new double[]{1, 1, 1}; - - setField(block, "_rowidx", rowIdxs); - setField(block, "_rowptr", rowPtr); - setField(block, "_colidx", colIdxs); - setField(block, "_values", v); - setField(block, "_size", 3); - setField(block, "_nnzr", 3); + SparseBlockDCSR sblock = new SparseBlockDCSR(getFixedSparseBlock()); + int[] rowPtr = new int[]{0, 1, 2, 6, 4}; // unsorted + setField(sblock, "_rowptr", rowPtr); RuntimeException ex = assertThrows(RuntimeException.class, - () -> 
block.checkValidity(3, 10, 3, true)); - + () -> sblock.checkValidity(4, 4, 6, false)); assertTrue(ex.getMessage().startsWith("Row pointers are decreasing at row")); } @Test public void testSparseBlockCOOUnsortedColumnIndicesWithinRow() { - SparseBlockCOO block = new SparseBlockCOO(1, 3); - - int[] r = new int[]{0, 0, 0}; - int[] c = new int[]{0, 2, 1}; // unsorted for row 0 - double[] v = new double[]{1, 1, 1}; - - setField(block, "_rindexes", r); - setField(block, "_cindexes", c); - setField(block, "_values", v); - setField(block, "_size", 3); + SparseBlockCOO sblock = new SparseBlockCOO(getFixedSparseBlock()); + int[] c = new int[]{0, 1, 3, 4, 4, 3}; // unsorted for last row + setField(sblock, "_cindexes", c); RuntimeException ex = assertThrows(RuntimeException.class, - () -> block.checkValidity(1, 3, 3, false)); - + () -> sblock.checkValidity(4, 4, 6, false)); assertTrue(ex.getMessage().startsWith("Wrong sparse row ordering")); } @Test public void testSparseBlockCSCUnsortedRowIndicesWithinColumn() { - SparseBlockCSC block = new SparseBlockCSC(10, 3); - - int[] ptr = new int[]{0, 3, 3, 3}; - int[] idxs = new int[]{0, 2, 1}; // unsorted - double[] v = new double[]{1, 1, 1}; - - setField(block, "_ptr", ptr); - setField(block, "_indexes", idxs); - setField(block, "_values", v); - setField(block, "_size", 3); + SparseBlockCSC sblock = new SparseBlockCSC(getFixedSparseBlock()); + int[] idxs = new int[]{0, 1, 2, 3, 3, 2}; // unsorted for last col + setField(sblock, "_indexes", idxs); RuntimeException ex = assertThrows(RuntimeException.class, - () -> block.checkValidity(10, 3, 3, true)); - + () -> sblock.checkValidity(4, 4, 6, false)); assertTrue(ex.getMessage().startsWith("Wrong sparse column ordering")); } @Test public void testSparseBlockCSRUnsortedColumnIndicesWithinRow() { - SparseBlockCSR block = new SparseBlockCSR(3, 3); - - int[] ptr = new int[]{0, 3, 3, 3}; - int[] idxs = new int[]{0, 2, 1}; // unsorted - double[] v = new double[]{1, 1, 1}; - - setField(block, 
"_ptr", ptr); - setField(block, "_indexes", idxs); - setField(block, "_values", v); - setField(block, "_size", 3); + SparseBlockCSR sblock = new SparseBlockCSR(getFixedSparseBlock()); + int[] idxs = new int[]{0, 1, 2, 3, 3, 2}; // unsorted for last row + setField(sblock, "_indexes", idxs); RuntimeException ex = assertThrows(RuntimeException.class, - () -> block.checkValidity(1, 3, 3, false)); - + () -> sblock.checkValidity(4, 4, 6, false)); assertTrue(ex.getMessage().startsWith("Wrong sparse row ordering")); } @Test public void testSparseBlockDCSRUnsortedColumnIndicesWithinRow() { - SparseBlockDCSR block = new SparseBlockDCSR(3, 3); - - int[] rowIdxs = new int[]{0, 2}; - int[] rowPtr = new int[]{0, 1, 3}; - int[] colIdxs = new int[]{0, 2, 1}; // for row 2 unsorted - double[] v = new double[]{1, 1, 1}; - - setField(block, "_rowidx", rowIdxs); - setField(block, "_rowptr", rowPtr); - setField(block, "_colidx", colIdxs); - setField(block, "_values", v); - setField(block, "_size", 3); - setField(block, "_nnzr", 2); + SparseBlockDCSR sblock = new SparseBlockDCSR(getFixedSparseBlock()); + int[] colIdxs = new int[]{0, 1, 2, 3, 3, 2}; // unsorted for last row + setField(sblock, "_colidx", colIdxs); RuntimeException ex = assertThrows(RuntimeException.class, - () -> block.checkValidity(1, 3, 3, false)); - + () -> sblock.checkValidity(4, 4, 6, false)); assertTrue(ex.getMessage().startsWith("Wrong sparse row ordering")); } @Test public void testSparseBlockMCSCUnsortedRowIndicesWithinColumn() { - SparseBlockMCSC block = new SparseBlockMCSC(10, 3); - - SparseRow col = new SparseRowVector(new double[]{1., 1., 1.}, new int[]{0, 2, 1}); // unsorted - SparseRow[] cols = new SparseRow[]{null, null, col}; - setField(block, "_columns", cols); + SparseBlockMCSC sblock = new SparseBlockMCSC(getFixedSparseBlock()); + int[] indexes = new int[]{3, 2}; // unsorted + setField(sblock.getCols()[3], "indexes", indexes); RuntimeException ex = assertThrows(RuntimeException.class, - () -> 
block.checkValidity(10, 3, 3, true)); - + () -> sblock.checkValidity(4, 4, 6, false)); assertTrue(ex.getMessage().startsWith("Wrong sparse column ordering")); } @Test public void testSparseBlockMCSRUnsortedColumnIndicesWithinRow() { - SparseBlockMCSR block = new SparseBlockMCSR(3, 10); - - SparseRow row = new SparseRowVector(new double[]{1., 1., 1.}, new int[]{0, 2, 1}); // unsorted - SparseRow[] rows = new SparseRow[]{null, null, row}; - setField(block, "_rows", rows); + SparseBlockMCSR sblock = new SparseBlockMCSR(getFixedSparseBlock()); + int[] indexes = new int[]{3, 2}; // unsorted + setField(sblock.getRows()[3], "indexes", indexes); RuntimeException ex = assertThrows(RuntimeException.class, - () -> block.checkValidity(3, 10, 3, true)); - + () -> sblock.checkValidity(4, 4, 6, false)); assertTrue(ex.getMessage().startsWith("Wrong sparse row ordering")); } @Test public void testSparseBlockMCSCInvalidIndices() { - SparseBlockMCSC block = new SparseBlockMCSC(10, 3); - - SparseRow col = new SparseRowVector(new double[]{1., 1., 1.}, new int[]{-1, 0, 2}); - SparseRow[] cols = new SparseRow[]{null, null, col}; - setField(block, "_columns", cols); + SparseBlockMCSC sblock = new SparseBlockMCSC(getFixedSparseBlock()); + int[] indexes = sblock.getCols()[3].indexes(); + indexes[0] = -1; RuntimeException ex = assertThrows(RuntimeException.class, - () -> block.checkValidity(10, 3, 3, true)); - + () -> sblock.checkValidity(4, 4, 6, false)); assertTrue(ex.getMessage().startsWith("Invalid index")); } @Test public void testSparseBlockMCSRInvalidIndices() { - SparseBlockMCSR block = new SparseBlockMCSR(3, 10); - - SparseRow row = new SparseRowVector(new double[]{1., 1., 1.}, new int[]{-1, 0, 1}); - SparseRow[] rows = new SparseRow[]{null, null, row}; - setField(block, "_rows", rows); + SparseBlockMCSR sblock = new SparseBlockMCSR(getFixedSparseBlock()); + int[] indexes = sblock.getRows()[3].indexes(); + indexes[0] = -1; RuntimeException ex = assertThrows(RuntimeException.class, - 
() -> block.checkValidity(3, 10, 3, true)); - + () -> sblock.checkValidity(4, 4, 6, false)); assertTrue(ex.getMessage().startsWith("Invalid index")); } @@ -438,29 +335,23 @@ public void testSparseBlockDCSRInvalidValue() { @Test public void testSparseBlockMCSCInvalidValue() { - SparseBlockMCSC block = new SparseBlockMCSC(10, 3); - - SparseRow col = new SparseRowVector(new double[]{1., 1., 0.}, new int[]{0, 1, 2}); - SparseRow[] cols = new SparseRow[]{null, null, col}; - setField(block, "_columns", cols); + SparseBlockMCSC sblock = new SparseBlockMCSC(getFixedSparseBlock()); + double[] values = sblock.valuesCol(3); + values[0] = 0; RuntimeException ex = assertThrows(RuntimeException.class, - () -> block.checkValidity(10, 3, 3, true)); - + () -> sblock.checkValidity(4, 4, 6, true)); assertTrue(ex.getMessage().startsWith("The values are expected to be non zeros")); } @Test public void testSparseBlockMCSRInvalidValue() { - SparseBlockMCSR block = new SparseBlockMCSR(3, 10); - - SparseRow row = new SparseRowVector(new double[]{1., 1., 0.}, new int[]{0, 1, 2}); - SparseRow[] rows = new SparseRow[]{null, null, row}; - setField(block, "_rows", rows); + SparseBlockMCSR sblock = new SparseBlockMCSR(getFixedSparseBlock()); + double[] values = sblock.values(3); + values[0] = 0; RuntimeException ex = assertThrows(RuntimeException.class, - () -> block.checkValidity(3, 10, 3, true)); - + () -> sblock.checkValidity(4, 4, 6, true)); assertTrue(ex.getMessage().startsWith("The values are expected to be non zeros")); } @@ -492,73 +383,63 @@ public void testSparseBlockDCSRInvalidIndex() { @Test public void testSparseBlockCOOCapacityExceedsAllowedLimit() { - SparseBlockCOO block = new SparseBlockCOO(3, 50); + SparseBlockCOO sblock = new SparseBlockCOO(3, 50); RuntimeException ex = assertThrows(RuntimeException.class, - () -> block.checkValidity(3, 3, 0, false)); - - // RESIZE_FACTOR1 is 2 + () -> sblock.checkValidity(3, 3, 0, false)); assertTrue(ex.getMessage().startsWith("Capacity is 
larger than the nnz times a resize factor")); } @Test public void testSparseBlockCSCCapacityExceedsAllowedLimit() { - SparseBlockCSC block = new SparseBlockCSC(3, 3, 50); + SparseBlockCSC sblock = new SparseBlockCSC(3, 3, 50); RuntimeException ex = assertThrows(RuntimeException.class, - () -> block.checkValidity(3, 3, 0, false)); - - // RESIZE_FACTOR1 is 2 + () -> sblock.checkValidity(3, 3, 0, false)); assertTrue(ex.getMessage().startsWith("Capacity is larger than the nnz times a resize factor")); } @Test public void testSparseBlockCSRCapacityExceedsAllowedLimit() { - SparseBlockCSR block = new SparseBlockCSR(3, 50, 0); + SparseBlockCSR sblock = new SparseBlockCSR(3, 50, 0); RuntimeException ex = assertThrows(RuntimeException.class, - () -> block.checkValidity(3, 3, 0, false)); - - // RESIZE_FACTOR1 is 2 + () -> sblock.checkValidity(3, 3, 0, false)); assertTrue(ex.getMessage().startsWith("Capacity is larger than the nnz times a resize factor")); } @Test public void testSparseBlockDCSRCapacityExceedsAllowedLimit() { - SparseBlockDCSR block = new SparseBlockDCSR(3, 50); + SparseBlockDCSR sblock = new SparseBlockDCSR(3, 50); RuntimeException ex = assertThrows(RuntimeException.class, - () -> block.checkValidity(3, 3, 0, false)); - - // RESIZE_FACTOR1 is 2 + () -> sblock.checkValidity(3, 3, 0, false)); assertTrue(ex.getMessage().startsWith("Capacity is larger than the nnz times a resize factor")); } @Test public void testSparseBlockMCSCCapacityExceedsAllowedLimit() { - SparseBlockMCSC block = new SparseBlockMCSC(10, 3); + double[][] A = getRandomMatrix(_rows, _cols, -10, 10, _sparsity, 13); - SparseRow col = new SparseRowVector(new double[]{1., 1., 1., 1., 1.}, new int[]{0, 1, 2, 3, 4}); - SparseRow[] cols = new SparseRow[]{null, null, col}; - setField(block, "_columns", cols); + MatrixBlock mbtmp = DataConverter.convertToMatrixBlock(A); + SparseBlock srtmp = mbtmp.getSparseBlock(); + SparseBlockMCSC sblock = new SparseBlockMCSC(srtmp); RuntimeException ex = 
assertThrows(RuntimeException.class, - () -> block.checkValidity(10, 3, 2, true)); - + () -> sblock.checkValidity(_rows, _cols, 2, true)); assertTrue(ex.getMessage().startsWith("The capacity is larger than nnz times a resize factor")); } @Test public void testSparseBlockMCSRCapacityExceedsAllowedLimit() { - SparseBlockMCSR block = new SparseBlockMCSR(3, 10); + double[][] A = getRandomMatrix(_rows, _cols, -10, 10, _sparsity, 13); - SparseRow row = new SparseRowVector(new double[]{1., 1., 1., 1., 1.}, new int[]{0, 1, 2, 3, 4}); - SparseRow[] rows = new SparseRow[]{null, null, row}; - setField(block, "_rows", rows); + MatrixBlock mbtmp = DataConverter.convertToMatrixBlock(A); + SparseBlock srtmp = mbtmp.getSparseBlock(); + SparseBlockMCSR sblock = new SparseBlockMCSR(srtmp); RuntimeException ex = assertThrows(RuntimeException.class, - () -> block.checkValidity(3, 10, 2, true)); - + () -> sblock.checkValidity(_rows, _cols, 2, true)); assertTrue(ex.getMessage().startsWith("The capacity is larger than nnz times a resize factor")); } @@ -572,47 +453,44 @@ private void runSparseBlockValidTest(SparseBlock.Type btype) { assertTrue("should pass checkValidity", sblock.checkValidity(_rows, _cols, sblock.size(), true)); } - private void runSparseBlockInvalidDimensionsTest(SparseBlock block) { + private void runSparseBlockInvalidDimensionsTest(SparseBlock sblock) { RuntimeException ex1 = assertThrows(RuntimeException.class, - () -> block.checkValidity(-1, 1, 0, false)); + () -> sblock.checkValidity(-1, 1, 0, false)); assertTrue(ex1.getMessage().startsWith("Invalid block dimensions")); RuntimeException ex2 = assertThrows(RuntimeException.class, - () -> block.checkValidity(1, -1, 0, false)); + () -> sblock.checkValidity(1, -1, 0, false)); assertTrue(ex2.getMessage().startsWith("Invalid block dimensions")); } - private void runSparseBlockInvalidValueTest(SparseBlock.Type btype) { - double[][] A = getRandomMatrix(_rows, _cols, -10, 10, _sparsity, 13); - - MatrixBlock mbtmp = 
DataConverter.convertToMatrixBlock(A); - SparseBlock srtmp = mbtmp.getSparseBlock(); + private void runSparseBlockInvalidIndexTest(SparseBlock.Type btype, String indexName) { + SparseBlock srtmp = getFixedSparseBlock(); SparseBlock sblock = SparseBlockFactory.copySparseBlock(btype, srtmp, true); - double[] values = (double[]) getField(sblock, "_values"); - values[values.length-1] = 0.; - setField(sblock, "_values", Arrays.copyOfRange(values, 0, values.length)); + int[] indexes = (int[]) getField(sblock, indexName); + indexes[0] = -1; RuntimeException ex = assertThrows(RuntimeException.class, - () -> sblock.checkValidity(_rows, _cols, sblock.size(), true)); - assertTrue(ex.getMessage().startsWith("The values array should not contain zeros")); + () -> sblock.checkValidity(4, 4, 6, true)); + assertTrue(ex.getMessage().startsWith("Invalid index at pos")); } - - private void runSparseBlockInvalidIndexTest(SparseBlock.Type btype, String indexName) { - double[][] A = getRandomMatrix(_rows, _cols, -10, 10, _sparsity, 13); - - MatrixBlock mbtmp = DataConverter.convertToMatrixBlock(A); - SparseBlock srtmp = mbtmp.getSparseBlock(); + private void runSparseBlockInvalidValueTest(SparseBlock.Type btype) { + SparseBlock srtmp = getFixedSparseBlock(); SparseBlock sblock = SparseBlockFactory.copySparseBlock(btype, srtmp, true); - int[] indexes = (int[]) getField(sblock, indexName); - indexes[0] = -1; - setField(sblock, indexName, Arrays.copyOfRange(indexes, 0, indexes.length)); + double[] values = (double[]) getField(sblock, "_values"); + values[0] = 0; RuntimeException ex = assertThrows(RuntimeException.class, - () -> sblock.checkValidity(_rows, _cols, sblock.size(), true)); - assertTrue(ex.getMessage().startsWith("Invalid index at pos")); + () -> sblock.checkValidity(4, 4, 6, true)); + assertTrue(ex.getMessage().startsWith("The values array should not contain zeros")); + } + + private SparseBlock getFixedSparseBlock(){ + double[][] A = new double[][] {{1, 0, 0, 0}, {0, 1, 0, 0}, 
{0, 0, 1, 1}, {0, 0, 1, 1}}; + MatrixBlock mbtmp = DataConverter.convertToMatrixBlock(A); + return mbtmp.getSparseBlock(); } private static void setField(Object obj, String name, Object value) { From 4aa48026cd5cc586e3bae3935dd6a474833fce18 Mon Sep 17 00:00:00 2001 From: Jessica Priebe Date: Tue, 20 Jan 2026 13:06:00 +0100 Subject: [PATCH 12/12] fix check validity correct array lengths --- .../sysds/runtime/data/SparseBlockCOO.java | 2 +- .../sysds/runtime/data/SparseBlockCSC.java | 2 +- .../sysds/runtime/data/SparseBlockCSR.java | 2 +- .../sysds/runtime/data/SparseBlockDCSR.java | 2 +- .../sparse/SparseBlockCheckValidityTest.java | 65 ++++++++++++++----- 5 files changed, 51 insertions(+), 22 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/data/SparseBlockCOO.java b/src/main/java/org/apache/sysds/runtime/data/SparseBlockCOO.java index 8e57559ca3b..c6d104e3c6e 100644 --- a/src/main/java/org/apache/sysds/runtime/data/SparseBlockCOO.java +++ b/src/main/java/org/apache/sysds/runtime/data/SparseBlockCOO.java @@ -235,7 +235,7 @@ public boolean checkValidity(int rlen, int clen, long nnz, boolean strict) { } //2. correct array lengths - if(_size != nnz && _cindexes.length < nnz && _rindexes.length < nnz && _values.length < nnz) { + if(_size != nnz || _cindexes.length < nnz || _rindexes.length < nnz || _values.length < nnz) { throw new RuntimeException("Incorrect array lengths."); } diff --git a/src/main/java/org/apache/sysds/runtime/data/SparseBlockCSC.java b/src/main/java/org/apache/sysds/runtime/data/SparseBlockCSC.java index 16a3b3163fe..08aaff9eda6 100644 --- a/src/main/java/org/apache/sysds/runtime/data/SparseBlockCSC.java +++ b/src/main/java/org/apache/sysds/runtime/data/SparseBlockCSC.java @@ -549,7 +549,7 @@ public boolean checkValidity(int rlen, int clen, long nnz, boolean strict) { } //2. 
correct array lengths - if(_size != nnz && _ptr.length < clen + 1 && _values.length < nnz && _indexes.length < nnz) { + if(_size != nnz || _ptr.length < clen + 1 || _values.length < nnz || _indexes.length < nnz) { throw new RuntimeException("Incorrect array lengths."); } diff --git a/src/main/java/org/apache/sysds/runtime/data/SparseBlockCSR.java b/src/main/java/org/apache/sysds/runtime/data/SparseBlockCSR.java index 947078b2171..33f9273a0d8 100644 --- a/src/main/java/org/apache/sysds/runtime/data/SparseBlockCSR.java +++ b/src/main/java/org/apache/sysds/runtime/data/SparseBlockCSR.java @@ -938,7 +938,7 @@ public boolean checkValidity(int rlen, int clen, long nnz, boolean strict) { } //2. correct array lengths - if(_size != nnz && _ptr.length < rlen+1 && _values.length < nnz && _indexes.length < nnz ) { + if( _size != nnz || _ptr.length < rlen+1 || _values.length < nnz || _indexes.length < nnz ) { throw new RuntimeException("Incorrect array lengths."); } diff --git a/src/main/java/org/apache/sysds/runtime/data/SparseBlockDCSR.java b/src/main/java/org/apache/sysds/runtime/data/SparseBlockDCSR.java index 85a0a100297..d0393db18a4 100644 --- a/src/main/java/org/apache/sysds/runtime/data/SparseBlockDCSR.java +++ b/src/main/java/org/apache/sysds/runtime/data/SparseBlockDCSR.java @@ -719,7 +719,7 @@ public boolean checkValidity(int rlen, int clen, long nnz, boolean strict) { } //2. 
correct array lengths - if (_size != nnz && _rowptr.length != _rowidx.length + 1 && _values.length < nnz && _colidx.length < nnz ) { + if ( _size != nnz || _rowptr.length != _rowidx.length + 1 || _values.length < nnz || _colidx.length < nnz ) { throw new RuntimeException("Incorrect array lengths."); } diff --git a/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockCheckValidityTest.java b/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockCheckValidityTest.java index 216e4e4ef89..52e53ea5690 100644 --- a/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockCheckValidityTest.java +++ b/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockCheckValidityTest.java @@ -35,7 +35,6 @@ import org.junit.Test; import java.lang.reflect.Field; -import java.util.Arrays; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertThrows; @@ -114,66 +113,87 @@ public void testSparseBlockMCSRInvalidDimensions() { @Test public void testSparseBlockCOOIncorrectArrayLengths() { - SparseBlockCOO sblock = new SparseBlockCOO(2, 2); + SparseBlockCOO sblock = new SparseBlockCOO(getFixedSparseBlock()); + int size = (int) sblock.size(); RuntimeException ex = assertThrows(RuntimeException.class, - () -> sblock.checkValidity(2, 2, 4, false)); + () -> sblock.checkValidity(4, 4, size+2, false)); assertEquals("Incorrect array lengths.", ex.getMessage()); + + checkValidityFailsWhenArrayLengthIsTemporarilyModified(sblock, "_cindexes", new int[size-1]); + checkValidityFailsWhenArrayLengthIsTemporarilyModified(sblock, "_rindexes", new int[size-1]); + checkValidityFailsWhenArrayLengthIsTemporarilyModified(sblock, "_values", new double[size-1]); } @Test public void testSparseBlockCSCIncorrectArrayLengths() { - SparseBlockCSC sblock = new SparseBlockCSC(2, 2, 2); + SparseBlockCSC sblock = new SparseBlockCSC(getFixedSparseBlock()); + int size = (int) sblock.size(); RuntimeException ex = assertThrows(RuntimeException.class, - () -> 
sblock.checkValidity(2, 3, 6, false)); + () -> sblock.checkValidity(4, 4, size+2, false)); assertEquals("Incorrect array lengths.", ex.getMessage()); + + int clen = 4; + checkValidityFailsWhenArrayLengthIsTemporarilyModified(sblock, "_ptr", new int[clen]); // should be clen+1 + checkValidityFailsWhenArrayLengthIsTemporarilyModified(sblock, "_values", new double[size-1]); + checkValidityFailsWhenArrayLengthIsTemporarilyModified(sblock, "_indexes", new int[size-1]); } @Test public void testSparseBlockCSRIncorrectArrayLengths() { - SparseBlockCSR sblock = new SparseBlockCSR(2, 2, 1); + SparseBlockCSR sblock = new SparseBlockCSR(getFixedSparseBlock()); + int size = (int) sblock.size(); RuntimeException ex = assertThrows(RuntimeException.class, - () -> sblock.checkValidity(3, 2, 6, false)); + () -> sblock.checkValidity(4, 4, size+2, false)); assertEquals("Incorrect array lengths.", ex.getMessage()); + + int rlen = sblock.numRows(); + checkValidityFailsWhenArrayLengthIsTemporarilyModified(sblock, "_ptr", new int[rlen]); // should be rlen+1 + checkValidityFailsWhenArrayLengthIsTemporarilyModified(sblock, "_values", new double[size-1]); + checkValidityFailsWhenArrayLengthIsTemporarilyModified(sblock, "_indexes", new int[size-1]); } @Test public void testSparseBlockDCSRIncorrectArrayLengths() { - SparseBlockDCSR sblock = new SparseBlockDCSR(2, 1); - - // cut off last value - int[] rowptr = (int[]) getField(sblock,"_rowptr"); - setField(sblock, "_rowptr", Arrays.copyOfRange(rowptr, 0, rowptr.length-1)); + SparseBlockDCSR sblock = new SparseBlockDCSR(getFixedSparseBlock()); + int size = (int) sblock.size(); RuntimeException ex = assertThrows(RuntimeException.class, - () -> sblock.checkValidity(3, 2, 6, false)); + () -> sblock.checkValidity(4, 4, size+2, false)); assertEquals("Incorrect array lengths.", ex.getMessage()); + + int rows = sblock.numRows(); + checkValidityFailsWhenArrayLengthIsTemporarilyModified(sblock, "_rowptr", new int[rows]); // should be rows+1 + 
checkValidityFailsWhenArrayLengthIsTemporarilyModified(sblock, "_colidx", new int[size-1]); + checkValidityFailsWhenArrayLengthIsTemporarilyModified(sblock, "_values", new double[size-1]); } @Test public void testSparseBlockMCSCIncorrectArrayLengths() { - SparseBlockMCSC sblock = new SparseBlockMCSC(2, 2); + SparseBlockMCSC sblock = new SparseBlockMCSC(getFixedSparseBlock()); + int size = (int) sblock.size(); RuntimeException ex = assertThrows(RuntimeException.class, - () -> sblock.checkValidity(3, 2, 1, false)); + () -> sblock.checkValidity(4, 4, size+2, false)); assertTrue(ex.getMessage().startsWith("Incorrect size")); } @Test public void testSparseBlockMCSRIncorrectArrayLengths() { - SparseBlockMCSR sblock = new SparseBlockMCSR(2, 2); + SparseBlockMCSR sblock = new SparseBlockMCSR(getFixedSparseBlock()); + int size = (int) sblock.size(); RuntimeException ex = assertThrows(RuntimeException.class, - () -> sblock.checkValidity(3, 2, 1, false)); + () -> sblock.checkValidity(4, 4, size+2, false)); assertTrue(ex.getMessage().startsWith("Incorrect size")); } @Test public void testSparseBlockCOOUnsortedRowIndices() { SparseBlockCOO sblock = new SparseBlockCOO(getFixedSparseBlock()); - int[] r = new int[]{0, 2, 1, 3}; // unsorted + int[] r = new int[]{0, 2, 1, 2, 3, 3}; // unsorted setField(sblock, "_rindexes", r); RuntimeException ex = assertThrows(RuntimeException.class, @@ -487,6 +507,15 @@ private void runSparseBlockInvalidValueTest(SparseBlock.Type btype) { assertTrue(ex.getMessage().startsWith("The values array should not contain zeros")); } + private void checkValidityFailsWhenArrayLengthIsTemporarilyModified(SparseBlock sblock, String name, Object value){ + Object old = getField(sblock, name); + setField(sblock, name, value); + RuntimeException ex = assertThrows(RuntimeException.class, + () -> sblock.checkValidity(4, 4, 6, false)); + assertEquals("Incorrect array lengths.", ex.getMessage()); + setField(sblock, name, old); + } + private SparseBlock 
getFixedSparseBlock(){ double[][] A = new double[][] {{1, 0, 0, 0}, {0, 1, 0, 0}, {0, 0, 1, 1}, {0, 0, 1, 1}}; MatrixBlock mbtmp = DataConverter.convertToMatrixBlock(A);