diff --git a/src/main/java/org/apache/sysds/runtime/data/SparseBlock.java b/src/main/java/org/apache/sysds/runtime/data/SparseBlock.java index 864569358f6..ac2876d4c44 100644 --- a/src/main/java/org/apache/sysds/runtime/data/SparseBlock.java +++ b/src/main/java/org/apache/sysds/runtime/data/SparseBlock.java @@ -26,6 +26,7 @@ import java.util.List; import org.apache.sysds.runtime.matrix.data.IJV; +import org.apache.sysds.runtime.util.UtilFunctions; /** * This SparseBlock is an abstraction for different sparse matrix formats. @@ -94,6 +95,12 @@ public enum Type { * @param r row index */ public abstract void compact(int r); + + /** + * In-place compaction of non-zero-entries; removes zero entries + * and shifts non-zero entries to the left if necessary. + */ + public abstract void compact(); //////////////////////// //obtain basic meta data @@ -501,16 +508,20 @@ public boolean contains(double pattern, int rl, int ru) { } public List contains(double[] pattern, boolean earlyAbort) { + int nnz = UtilFunctions.computeNnz(pattern, 0, pattern.length); List ret = new ArrayList<>(); int rlen = numRows(); + for( int i=0; i0 && posFIndexGTE(_curRow, cl) < 0)) ) + while(_curRow < _rlen){ + if(isEmpty(_curRow)){ + _curRow++; + continue; + } + + int pos = (cl == 0)? 0 : posFIndexGTE(_curRow, cl); + if(pos < 0){ + _curRow++; + continue; + } + + int sizeRow = size(_curRow); + int endPos = (_cu == Integer.MAX_VALUE)? sizeRow : posFIndexGTE(_curRow, _cu); + if(endPos < 0) endPos = sizeRow; + + if(pos < endPos){ + _curColIx = pos(_curRow)+pos; + _curIndexes = indexes(_curRow); + _curValues = values(_curRow); + return; + } _curRow++; - if(_curRow >= _rlen) - _noNext = true; - else { - _curColIx = (cl==0) ? - pos(_curRow) : posFIndexGTE(_curRow, cl); - _curIndexes = indexes(_curRow); - _curValues = values(_curRow); } + _noNext = true; } } diff --git a/src/main/java/org/apache/sysds/runtime/data/SparseBlockCOO.java b/src/main/java/org/apache/sysds/runtime/data/SparseBlockCOO.java index c4e60c10cfd..c6d104e3c6e 100644 --- a/src/main/java/org/apache/sysds/runtime/data/SparseBlockCOO.java +++ b/src/main/java/org/apache/sysds/runtime/data/SparseBlockCOO.java @@ -193,6 +193,20 @@ public void compact(int r) { //do nothing everything preallocated } + @Override + public void compact() { + int pos = 0; + for(int i=0; i< _values.length; i++) { + if(_values[i] != 0){ + _values[pos] = _values[i]; + _rindexes[pos] = _rindexes[i]; + _cindexes[pos] = _cindexes[i]; + pos++; + } + } + _size = pos; + } + @Override public int numRows() { return _rlen; @@ -221,12 +235,12 @@ public boolean checkValidity(int rlen, int clen, long nnz, boolean strict) { } //2. correct array lengths - if(_size != nnz && _cindexes.length < nnz && _rindexes.length < nnz && _values.length < nnz) { + if(_size != nnz || _cindexes.length < nnz || _rindexes.length < nnz || _values.length < nnz) { throw new RuntimeException("Incorrect array lengths."); } //3.1. sort order of row indices - for( int i=1; i<=nnz; i++ ) { + for( int i=1; i= _cindexes[k] ) + for(int k=apos+1; k _cindexes[k]) throw new RuntimeException("Wrong sparse row ordering: " + k + " "+_cindexes[k-1]+" "+_cindexes[k]); - for( int k=apos; k Integer.MAX_VALUE) - throw new RuntimeException("SparseBlockCSC supports nnz<=Integer.MAX_VALUE but got " + _size); + long size = sblock.size(); + if(size > Integer.MAX_VALUE) + throw new RuntimeException("SparseBlockCSC supports nnz<=Integer.MAX_VALUE but got " + size); //special case SparseBlockCSC if(sblock instanceof SparseBlockCSC) { @@ -223,27 +225,6 @@ public SparseBlockCSC(int cols, int[] rowInd, int[] colInd, double[] values) { } - public SparseBlockCSC(int cols, int nnz, int[] rowInd) { - - _clenInferred = cols; - _ptr = new int[cols + 1]; - _indexes = Arrays.copyOf(rowInd, nnz); - _values = new double[nnz]; - Arrays.fill(_values, 1); - _size = nnz; - - //single-pass construction of col pointers - //and copy of row indexes if necessary - for(int i = 0, pos = 0; i < cols; i++) { - if(rowInd[i] >= 0) { - if(cols > nnz) - _indexes[pos] = rowInd[i]; - pos++; - } - _ptr[i + 1] = pos; - } - } - /** * Initializes the CSC sparse block from an ordered input stream of ultra-sparse ijv triples. * @@ -288,7 +269,6 @@ public void initSparse(int clen, int nnz, DataInput in) throws IOException { // Allocate space if necessary if(_values.length < nnz) { resize(newCapacity(nnz)); - System.out.println("hallo"); } // Read sparse columns, append and update pointers _ptr[0] = 0; @@ -377,12 +357,31 @@ public void compact(int r) { //do nothing everything preallocated } + @Override + public void compact() { + int pos = 0; + for(int i=0; i -1) return _rlen; else { - int rlen = Arrays.stream(_indexes).max().getAsInt(); + int rlen = Arrays.stream(_indexes).max().getAsInt()+1; _rlen = rlen; return rlen; } @@ -550,12 +549,12 @@ public boolean checkValidity(int rlen, int clen, long nnz, boolean strict) { } //2. correct array lengths - if(_size != nnz && _ptr.length < clen + 1 && _values.length < nnz && _indexes.length < nnz) { + if(_size != nnz || _ptr.length < clen + 1 || _values.length < nnz || _indexes.length < nnz) { throw new RuntimeException("Incorrect array lengths."); } - //3. non-decreasing row pointers - for(int i = 1; i < clen; i++) { + //3. non-decreasing col pointers + for(int i = 1; i <= clen; i++) { if(_ptr[i - 1] > _ptr[i] && strict) throw new RuntimeException( "Column pointers are decreasing at column: " + i + ", with pointers " + _ptr[i - 1] + " > " + @@ -569,10 +568,9 @@ public boolean checkValidity(int rlen, int clen, long nnz, boolean strict) { for(int k = apos + 1; k < apos + alen; k++) if(_indexes[k - 1] >= _indexes[k]) throw new RuntimeException( - "Wrong sparse column ordering: " + k + " " + _indexes[k - 1] + " " + _indexes[k]); - for(int k = apos; k < apos + alen; k++) - if(_values[k] == 0) - throw new RuntimeException("Wrong sparse column: zero at " + k + " at row index " + _indexes[k]); + "Wrong sparse column ordering, at column=" + i + ", pos=" + k + " with row indexes " + + _indexes[k - 1] + ">=" + _indexes[k] + ); } //5. non-existing zero values @@ -581,11 +579,13 @@ public boolean checkValidity(int rlen, int clen, long nnz, boolean strict) { throw new RuntimeException( "The values array should not contain zeros." + " The " + i + "th value is " + _values[i]); } + if(_indexes[i] < 0) + throw new RuntimeException("Invalid index at pos=" + i); } //6. a capacity that is no larger than nnz times resize factor. int capacity = _values.length; - if(capacity > nnz * RESIZE_FACTOR1) { + if(capacity > INIT_CAPACITY && capacity > nnz * RESIZE_FACTOR1) { throw new RuntimeException( "Capacity is larger than the nnz times a resize factor." + " Current size: " + capacity + ", while Expected size:" + nnz * RESIZE_FACTOR1); @@ -938,7 +938,7 @@ public void deleteIndexRangeCol(int c, int rl, int ru) { int len = sizeCol(c); int end = internPosFIndexGTECol(ru, c); if(end < 0) //delete all remaining - end = start + len; + end = posCol(c) + len; //overlapping array copy (shift rhs values left) System.arraycopy(_indexes, end, _indexes, start, _size - end); @@ -1059,7 +1059,7 @@ public int posFIndexGTCol(int r, int c) { @Override public String toString() { StringBuilder sb = new StringBuilder(); - sb.append("SparseBlockCSR: clen="); + sb.append("SparseBlockCSC: clen="); sb.append(numCols()); sb.append(", nnz="); sb.append(size()); diff --git a/src/main/java/org/apache/sysds/runtime/data/SparseBlockCSR.java b/src/main/java/org/apache/sysds/runtime/data/SparseBlockCSR.java index a40c567dfb2..33f9273a0d8 100644 --- a/src/main/java/org/apache/sysds/runtime/data/SparseBlockCSR.java +++ b/src/main/java/org/apache/sysds/runtime/data/SparseBlockCSR.java @@ -350,7 +350,8 @@ public void allocate(int r, int ennz, int maxnnz) { public void compact(int r) { //do nothing everything preallocated } - + + @Override public void compact() { int pos = 0; for(int i=0; i _ptr[i] && strict) throw new RuntimeException("Row pointers are decreasing at row: "+i + ", with pointers "+_ptr[i-1]+" > "+_ptr[i]); @@ -956,10 +957,6 @@ public boolean checkValidity(int rlen, int clen, long nnz, boolean strict) { if( _indexes[k-1] >= _indexes[k] ) throw new RuntimeException("Wrong sparse row ordering: " + k + " "+_indexes[k-1]+" "+_indexes[k]); - for( int k=apos; k _rowidx[i]) throw new RuntimeException("Row indices are decreasing at row: " + i + ", with indices " + _rowidx[i-1] + " > " +_rowidx[i]); } - for (int i = 1; i < _rowptr.length; i++ ) { + for (int i = 1; i < _nnzr+1; i++ ) { if (_rowptr[i - 1] > _rowptr[i]) { throw new RuntimeException("Row pointers are decreasing at row: " + i + ", with pointers " + _rowptr[i-1] + " > " +_rowptr[i]); @@ -724,19 +738,14 @@ public boolean checkValidity(int rlen, int clen, long nnz, boolean strict) { } //4. sorted column indexes per row - for ( int rowIdx = 0; rowIdx < _rowidx.length; rowIdx++ ) { - int apos = _rowidx[rowIdx]; - int alen = _rowidx[rowIdx+1] - apos; + for (int i = 0; i < _nnzr; i++) { + int apos = _rowptr[i]; + int alen = _rowptr[i+1] - apos; for( int k = apos + 1; k < apos + alen; k++) if( _colidx[k-1] >= _colidx[k] ) throw new RuntimeException("Wrong sparse row ordering: " + k + " " + _colidx[k-1] + " " + _colidx[k]); - - for( int k=apos; k columnSizes = new HashMap<>(); - if(_clenInferred == -1) { - for(SparseRow row : originalRows) { - if(row != null && !row.isEmpty()) { - for(int i = 0; i < row.size(); i++) { - int rowIndex = row.indexes()[i]; - columnSizes.put(rowIndex, columnSizes.getOrDefault(rowIndex, 0) + 1); - } + for(SparseRow row : originalRows) { + if(row != null && !row.isEmpty()) { + for(int i = 0; i < row.size(); i++) { + int rowIndex = row.indexes()[i]; + columnSizes.put(rowIndex, columnSizes.getOrDefault(rowIndex, 0) + 1); } } + } + if(_clenInferred == -1) { clen = columnSizes.keySet().stream().max(Integer::compare).orElseThrow(NoSuchElementException::new); _columns = new SparseRow[clen + 1]; } @@ -113,7 +113,14 @@ else if(columnSize == 1) { } rowPosition++; } - + } + else if(sblock instanceof SparseBlockCSC) { + clen = ((SparseBlockCSC) sblock).numCols(); + _columns = new SparseRow[clen]; + for(int i = 0; i < clen; i++) { + if(!((SparseBlockCSC) sblock).isEmptyCol(i)) + _columns[i] = ((SparseBlockCSC) sblock).getCol(i); + } } // general case SparseBlock else { @@ -259,7 +266,7 @@ public void allocate(int r, int nnz) { } public void allocateCol(int c, int nnz) { - if(!isAllocated(c)) { + if(!isAllocatedCol(c)) { _columns[c] = (nnz == 1) ? new SparseRowScalar() : new SparseRowVector(nnz); } } @@ -270,7 +277,7 @@ public void allocate(int r, int ennz, int maxnnz) { } public void allocateCol(int c, int ennz, int maxnnz) { - if(!isAllocated(c)) { + if(!isAllocatedCol(c)) { _columns[c] = (ennz == 1) ? new SparseRowScalar() : new SparseRowVector(ennz, maxnnz); } } @@ -283,7 +290,7 @@ public void compact(int r) { } public void compactCol(int c) { - if(isAllocated(c)) { + if(isAllocatedCol(c)) { if(_columns[c] instanceof SparseRowVector && _columns[c].size() > SparseBlock.INIT_CAPACITY && _columns[c].size() * SparseBlock.RESIZE_FACTOR1 < ((SparseRowVector) _columns[c]).capacity()) { ((SparseRowVector) _columns[c]).compact(); @@ -296,6 +303,24 @@ else if(_columns[c] instanceof SparseRowScalar) { } } + @Override + public void compact() { + for(int i = 0; i < numCols(); i++) { + if(isAllocatedCol(i)) { + if(_columns[i] instanceof SparseRowVector) { + _columns[i].compact(); + if(_columns[i].isEmpty()) + _columns[i] = null; + } + else if(_columns[i] instanceof SparseRowScalar) { + SparseRowScalar s = (SparseRowScalar) _columns[i]; + if(s.getValue() == 0) + _columns[i] = null; + } + } + } + } + @Override public int numRows() { return _rlen; @@ -386,7 +411,7 @@ public int size(int r) { public int sizeCol(int c) { //prior check with isEmpty(r) expected - return isAllocated(c) ? _columns[c].size() : 0; + return isAllocatedCol(c) ? _columns[c].size() : 0; } @Override @@ -404,7 +429,7 @@ public long size(int rl, int ru) { public long sizeCol(int cl, int cu) { long nnz = 0; for(int i = cl; i < cu; i++) { - nnz += isAllocated(i) ? _columns[i].size() : 0; + nnz += isAllocatedCol(i) ? _columns[i].size() : 0; } return nnz; } @@ -449,31 +474,34 @@ public boolean checkValidity(int rlen, int clen, long nnz, boolean strict) { //3. Sorted column indices per row for(int i = 0; i < clen; i++) { - if(isEmpty(i)) - continue; + if(isEmptyCol(i)) continue; int apos = pos(i); - int alen = size(i); - int[] aix = indexes(i); - double[] avals = values(i); - for(int k = apos + 1; k < apos + alen; k++) { - if(aix[k - 1] >= aix[k] | aix[k - 1] < 0) { + int alen = sizeCol(i); + int[] aix = indexesCol(i); + double[] avals = valuesCol(i); + + int prevRow = -1; + for(int k = apos; k < apos + alen; k++) { + if(aix[k] < 0) + throw new RuntimeException("Invalid index, at column=" + i + ", pos=" + k); + if(aix[k] <= prevRow) throw new RuntimeException( "Wrong sparse column ordering, at column=" + i + ", pos=" + k + " with row indexes " + - aix[k - 1] + ">=" + aix[k]); - } - if(avals[k] == 0) { + prevRow + ">=" + aix[k]); + if(avals[k] == 0) throw new RuntimeException( "The values are expected to be non zeros " + "but zero at column: " + i + ", row pos: " + k); - } + prevRow = aix[k]; } } + //4. A capacity that is no larger than nnz times resize factor for(int i = 0; i < clen; i++) { long max_size = (long) Math.max(nnz * RESIZE_FACTOR1, INIT_CAPACITY); - if(!isEmpty(i) && values(i).length > max_size) { + if(!isEmptyCol(i) && valuesCol(i).length > max_size) { throw new RuntimeException( "The capacity is larger than nnz times a resize factor(=2). " + "Actual length = " + - values(i).length + ", should not exceed " + max_size); + valuesCol(i).length + ", should not exceed " + max_size); } } diff --git a/src/main/java/org/apache/sysds/runtime/data/SparseBlockMCSR.java b/src/main/java/org/apache/sysds/runtime/data/SparseBlockMCSR.java index f94b6bf7f45..b797560c428 100644 --- a/src/main/java/org/apache/sysds/runtime/data/SparseBlockMCSR.java +++ b/src/main/java/org/apache/sysds/runtime/data/SparseBlockMCSR.java @@ -199,6 +199,22 @@ else if(_rows[r] instanceof SparseRowScalar) { } } } + + @Override + public void compact() { + for(int i = 0; i < numRows(); i++) { + if(isAllocated(i)) { + if(_rows[i] instanceof SparseRowVector) { + _rows[i].compact(); + if(_rows[i].isEmpty()) _rows[i] = null; + } + else if(_rows[i] instanceof SparseRowScalar) { + SparseRowScalar s = (SparseRowScalar) _rows[i]; + if(s.getValue() == 0) _rows[i] = null; + } + } + } + } @Override public int numRows() { @@ -238,13 +254,20 @@ public boolean checkValidity(int rlen, int clen, long nnz, boolean strict) { int alen = size(i); int[] aix = indexes(i); double[] avals = values(i); - for (int k = apos + 1; k < apos + alen; k++) { - if (aix[k-1] >= aix[k] | aix[k-1] < 0 ) - throw new RuntimeException("Wrong sparse row ordering, at row="+i+", pos="+k - + " with column indexes " + aix[k-1] + ">=" + aix[k]); - if (avals[k] == 0) - throw new RuntimeException("The values are expected to be non zeros " - + "but zero at row: "+ i + ", col pos: " + k); + + int prevCol = -1; + for (int k = apos; k < apos + alen; k++) { + if(aix[k] < 0) + throw new RuntimeException( + "Invalid index, at column=" + i + ", pos=" + k); + if(aix[k] <= prevCol) + throw new RuntimeException( + "Wrong sparse row ordering, at row=" + i + ", pos=" + k + " with column indexes " + + prevCol + ">=" + aix[k]); + if(avals[k] == 0) + throw new RuntimeException( + "The values are expected to be non zeros " + "but zero at column: " + i + ", row pos: " + k); + prevCol = aix[k]; } } diff --git a/src/main/java/org/apache/sysds/runtime/data/SparseRowVector.java b/src/main/java/org/apache/sysds/runtime/data/SparseRowVector.java index 50229e15df6..e59bf2402ee 100644 --- a/src/main/java/org/apache/sysds/runtime/data/SparseRowVector.java +++ b/src/main/java/org/apache/sysds/runtime/data/SparseRowVector.java @@ -190,6 +190,10 @@ public void setEstimatedNzs(int estnnz){ estimatedNzs = estnnz; } + public int getEstimatedNzs(){ + return estimatedNzs; + } + private void recap(int newCap) { if( newCap<=values.length ) return; @@ -314,7 +318,7 @@ public int searchIndexesFirstLTE(int col) { //search lt col index (see binary search) index = Math.abs( index+1 ); - return (index-1 < size) ? index-1 : -1; + return (index-1 >= 0) ? index-1 : -1; } @Override diff --git a/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockAlignment.java b/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockAlignment.java index 3c2ed30adc6..006fac8f03a 100644 --- a/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockAlignment.java +++ b/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockAlignment.java @@ -270,6 +270,12 @@ else if( i<37 ) {//CSR/COO different after update pos Assert.fail("Wrong row alignment indicated: "+rowsAligned37+", expected: "+positive); if( !rowsAlignedRest ) Assert.fail("Wrong row alignment rest indicated: false."); + + //init third sparse block with different number of rows + SparseBlock sblock3 =SparseBlockFactory.createSparseBlock(btype, rows+1); + if (sblock.isAligned(sblock3)) { + Assert.fail("Wrong alignment different rows indicated: true."); + } } catch(Exception ex) { ex.printStackTrace(); diff --git a/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockCheckValidityTest.java b/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockCheckValidityTest.java new file mode 100644 index 00000000000..52e53ea5690 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockCheckValidityTest.java @@ -0,0 +1,544 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.component.sparse; + +import org.apache.sysds.runtime.data.SparseBlock; +import org.apache.sysds.runtime.data.SparseBlockCOO; +import org.apache.sysds.runtime.data.SparseBlockCSC; +import org.apache.sysds.runtime.data.SparseBlockCSR; +import org.apache.sysds.runtime.data.SparseBlockDCSR; +import org.apache.sysds.runtime.data.SparseBlockFactory; +import org.apache.sysds.runtime.data.SparseBlockMCSC; +import org.apache.sysds.runtime.data.SparseBlockMCSR; + +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.util.DataConverter; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestUtils; +import org.junit.Test; + +import java.lang.reflect.Field; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; + +public class SparseBlockCheckValidityTest extends AutomatedTestBase +{ + private final static int _rows = 123; + private final static int _cols = 97; + private final static double _sparsity = 0.22; + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + } + + @Test + public void testSparseBlockCOOValid() { + runSparseBlockValidTest(SparseBlock.Type.COO); + } + + @Test + public void testSparseBlockCSCValid() { + runSparseBlockValidTest(SparseBlock.Type.CSC); + } + + @Test + public void testSparseBlockCSRValid() { + runSparseBlockValidTest(SparseBlock.Type.CSR); + } + + @Test + public void testSparseBlockDCSRValid() { + runSparseBlockValidTest(SparseBlock.Type.DCSR); + } + + @Test + public void testSparseBlockMCSCValid() { + runSparseBlockValidTest(SparseBlock.Type.MCSC); + } + + @Test + public void testSparseBlockMCSRValid() { + runSparseBlockValidTest(SparseBlock.Type.MCSR); + } + + @Test + public void testSparseBlockCOOInvalidDimensions() { + runSparseBlockInvalidDimensionsTest(new SparseBlockCOO(-1, 0)); + } + + @Test + public void testSparseBlockCSCInvalidDimensions() { + runSparseBlockInvalidDimensionsTest(new SparseBlockCSC(-1, 0)); + } + + @Test + public void testSparseBlockCSRInvalidDimensions() { + runSparseBlockInvalidDimensionsTest(new SparseBlockCSR(-1, 0)); + } + + @Test + public void testSparseBlockDCSRInvalidDimensions() { + runSparseBlockInvalidDimensionsTest(new SparseBlockDCSR(0, 0)); + } + + @Test + public void testSparseBlockMCSCInvalidDimensions() { + runSparseBlockInvalidDimensionsTest(new SparseBlockMCSC(-1, 0)); + } + + @Test + public void testSparseBlockMCSRInvalidDimensions() { + runSparseBlockInvalidDimensionsTest(new SparseBlockMCSR(0, -1)); + } + + @Test + public void testSparseBlockCOOIncorrectArrayLengths() { + SparseBlockCOO sblock = new SparseBlockCOO(getFixedSparseBlock()); + + int size = (int) sblock.size(); + RuntimeException ex = assertThrows(RuntimeException.class, + () -> sblock.checkValidity(4, 4, size+2, false)); + assertEquals("Incorrect array lengths.", ex.getMessage()); + + checkValidityFailsWhenArrayLengthIsTemporarilyModified(sblock, "_cindexes", new int[size-1]); + checkValidityFailsWhenArrayLengthIsTemporarilyModified(sblock, "_rindexes", new int[size-1]); + checkValidityFailsWhenArrayLengthIsTemporarilyModified(sblock, "_values", new double[size-1]); + } + + @Test + public void testSparseBlockCSCIncorrectArrayLengths() { + SparseBlockCSC sblock = new SparseBlockCSC(getFixedSparseBlock()); + + int size = (int) sblock.size(); + RuntimeException ex = assertThrows(RuntimeException.class, + () -> sblock.checkValidity(4, 4, size+2, false)); + assertEquals("Incorrect array lengths.", ex.getMessage()); + + int clen = 4; + checkValidityFailsWhenArrayLengthIsTemporarilyModified(sblock, "_ptr", new int[clen]); // should be clen+1 + checkValidityFailsWhenArrayLengthIsTemporarilyModified(sblock, "_values", new double[size-1]); + checkValidityFailsWhenArrayLengthIsTemporarilyModified(sblock, "_indexes", new int[size-1]); + } + + @Test + public void testSparseBlockCSRIncorrectArrayLengths() { + SparseBlockCSR sblock = new SparseBlockCSR(getFixedSparseBlock()); + + int size = (int) sblock.size(); + RuntimeException ex = assertThrows(RuntimeException.class, + () -> sblock.checkValidity(4, 4, size+2, false)); + assertEquals("Incorrect array lengths.", ex.getMessage()); + + int rlen = sblock.numRows(); + checkValidityFailsWhenArrayLengthIsTemporarilyModified(sblock, "_ptr", new int[rlen]); // should be rlen+1 + checkValidityFailsWhenArrayLengthIsTemporarilyModified(sblock, "_values", new double[size-1]); + checkValidityFailsWhenArrayLengthIsTemporarilyModified(sblock, "_indexes", new int[size-1]); + } + + @Test + public void testSparseBlockDCSRIncorrectArrayLengths() { + SparseBlockDCSR sblock = new SparseBlockDCSR(getFixedSparseBlock()); + + int size = (int) sblock.size(); + RuntimeException ex = assertThrows(RuntimeException.class, + () -> sblock.checkValidity(4, 4, size+2, false)); + assertEquals("Incorrect array lengths.", ex.getMessage()); + + int rows = sblock.numRows(); + checkValidityFailsWhenArrayLengthIsTemporarilyModified(sblock, "_rowptr", new int[rows]); // should be rows+1 + checkValidityFailsWhenArrayLengthIsTemporarilyModified(sblock, "_colidx", new int[size-1]); + checkValidityFailsWhenArrayLengthIsTemporarilyModified(sblock, "_values", new double[size-1]); + } + + @Test + public void testSparseBlockMCSCIncorrectArrayLengths() { + SparseBlockMCSC sblock = new SparseBlockMCSC(getFixedSparseBlock()); + + int size = (int) sblock.size(); + RuntimeException ex = assertThrows(RuntimeException.class, + () -> sblock.checkValidity(4, 4, size+2, false)); + assertTrue(ex.getMessage().startsWith("Incorrect size")); + } + + @Test + public void testSparseBlockMCSRIncorrectArrayLengths() { + SparseBlockMCSR sblock = new SparseBlockMCSR(getFixedSparseBlock()); + + int size = (int) sblock.size(); + RuntimeException ex = assertThrows(RuntimeException.class, + () -> sblock.checkValidity(4, 4, size+2, false)); + assertTrue(ex.getMessage().startsWith("Incorrect size")); + } + + @Test + public void testSparseBlockCOOUnsortedRowIndices() { + SparseBlockCOO sblock = new SparseBlockCOO(getFixedSparseBlock()); + int[] r = new int[]{0, 2, 1, 2, 3, 3}; // unsorted + setField(sblock, "_rindexes", r); + + RuntimeException ex = assertThrows(RuntimeException.class, + () -> sblock.checkValidity(4, 4, 6, false)); + assertEquals("Wrong sorted order of row indices", ex.getMessage()); + } + + @Test + public void testSparseBlockCSCDecreasingColPointers() { + SparseBlockCSC sblock = new SparseBlockCSC(getFixedSparseBlock()); + int[] ptr = new int[]{0, 2, 1, 4, 6}; // unsorted + setField(sblock, "_ptr", ptr); + + RuntimeException ex = assertThrows(RuntimeException.class, + () -> sblock.checkValidity(4, 4, 6, true)); + assertTrue(ex.getMessage().startsWith("Column pointers are decreasing at column")); + } + + @Test + public void testSparseBlockCSRDecreasingRowPointers() { + SparseBlockCSR sblock = new SparseBlockCSR(getFixedSparseBlock()); + int[] ptr = new int[]{0, 2, 1, 4, 6}; // unsorted + setField(sblock, "_ptr", ptr); + + RuntimeException ex = assertThrows(RuntimeException.class, + () -> sblock.checkValidity(4, 4, 6, true)); + assertTrue(ex.getMessage().startsWith("Row pointers are decreasing at row")); + } + + @Test + public void testSparseBlockDCSRDecreasingRowIndices() { + SparseBlockDCSR sblock = new SparseBlockDCSR(getFixedSparseBlock()); + int[] rowIdxs = new int[]{0, 2, 1, 3}; // unsorted + setField(sblock, "_rowidx", rowIdxs); + + RuntimeException ex = assertThrows(RuntimeException.class, + () -> sblock.checkValidity(4, 4, 6, false)); + assertTrue(ex.getMessage().startsWith("Row indices are decreasing at row")); + } + + @Test + public void testSparseBlockDCSRDecreasingRowPointers() { + SparseBlockDCSR sblock = new SparseBlockDCSR(getFixedSparseBlock()); + int[] rowPtr = new int[]{0, 1, 2, 6, 4}; // unsorted + setField(sblock, "_rowptr", rowPtr); + + RuntimeException ex = assertThrows(RuntimeException.class, + () -> sblock.checkValidity(4, 4, 6, false)); + assertTrue(ex.getMessage().startsWith("Row pointers are decreasing at row")); + } + + @Test + public void testSparseBlockCOOUnsortedColumnIndicesWithinRow() { + SparseBlockCOO sblock = new SparseBlockCOO(getFixedSparseBlock()); + int[] c = new int[]{0, 1, 3, 4, 4, 3}; // unsorted for last row + setField(sblock, "_cindexes", c); + + RuntimeException ex = assertThrows(RuntimeException.class, + () -> sblock.checkValidity(4, 4, 6, false)); + assertTrue(ex.getMessage().startsWith("Wrong sparse row ordering")); + } + + @Test + public void testSparseBlockCSCUnsortedRowIndicesWithinColumn() { + SparseBlockCSC sblock = new SparseBlockCSC(getFixedSparseBlock()); + int[] idxs = new int[]{0, 1, 2, 3, 3, 2}; // unsorted for last col + setField(sblock, "_indexes", idxs); + + RuntimeException ex = assertThrows(RuntimeException.class, + () -> sblock.checkValidity(4, 4, 6, false)); + assertTrue(ex.getMessage().startsWith("Wrong sparse column ordering")); + } + + @Test + public void testSparseBlockCSRUnsortedColumnIndicesWithinRow() { + SparseBlockCSR sblock = new SparseBlockCSR(getFixedSparseBlock()); + int[] idxs = new int[]{0, 1, 2, 3, 3, 2}; // unsorted for last row + setField(sblock, "_indexes", idxs); + + RuntimeException ex = assertThrows(RuntimeException.class, + () -> sblock.checkValidity(4, 4, 6, false)); + assertTrue(ex.getMessage().startsWith("Wrong sparse row ordering")); + } + + @Test + public void testSparseBlockDCSRUnsortedColumnIndicesWithinRow() { + SparseBlockDCSR sblock = new SparseBlockDCSR(getFixedSparseBlock()); + int[] colIdxs = new int[]{0, 1, 2, 3, 3, 2}; // unsorted for last row + setField(sblock, "_colidx", colIdxs); + + RuntimeException ex = assertThrows(RuntimeException.class, + () -> sblock.checkValidity(4, 4, 6, false)); + assertTrue(ex.getMessage().startsWith("Wrong sparse row ordering")); + } + + @Test + public void testSparseBlockMCSCUnsortedRowIndicesWithinColumn() { + SparseBlockMCSC sblock = new SparseBlockMCSC(getFixedSparseBlock()); + int[] indexes = new int[]{3, 2}; // unsorted + setField(sblock.getCols()[3], "indexes", indexes); + + RuntimeException ex = assertThrows(RuntimeException.class, + () -> sblock.checkValidity(4, 4, 6, false)); + assertTrue(ex.getMessage().startsWith("Wrong sparse column ordering")); + } + + @Test + public void testSparseBlockMCSRUnsortedColumnIndicesWithinRow() { + SparseBlockMCSR sblock = new SparseBlockMCSR(getFixedSparseBlock()); + int[] indexes = new int[]{3, 2}; // unsorted + setField(sblock.getRows()[3], "indexes", indexes); + + RuntimeException ex = assertThrows(RuntimeException.class, + () -> sblock.checkValidity(4, 4, 6, false)); + assertTrue(ex.getMessage().startsWith("Wrong sparse row ordering")); + } + + @Test + public void testSparseBlockMCSCInvalidIndices() { + SparseBlockMCSC sblock = new SparseBlockMCSC(getFixedSparseBlock()); + int[] indexes = sblock.getCols()[3].indexes(); + indexes[0] = -1; + + RuntimeException ex = assertThrows(RuntimeException.class, + () -> sblock.checkValidity(4, 4, 6, false)); + assertTrue(ex.getMessage().startsWith("Invalid index")); + } + + @Test + public void testSparseBlockMCSRInvalidIndices() { + SparseBlockMCSR sblock = new SparseBlockMCSR(getFixedSparseBlock()); + int[] indexes = sblock.getRows()[3].indexes(); + indexes[0] = -1; + + RuntimeException ex = assertThrows(RuntimeException.class, + () -> sblock.checkValidity(4, 4, 6, false)); + assertTrue(ex.getMessage().startsWith("Invalid index")); + } + + @Test + public void testSparseBlockCOOInvalidValue() { + runSparseBlockInvalidValueTest(SparseBlock.Type.COO); + } + + @Test + public void testSparseBlockCSCInvalidValue() { + runSparseBlockInvalidValueTest(SparseBlock.Type.CSC); + } + + @Test + public void testSparseBlockCSRInvalidValue() { + runSparseBlockInvalidValueTest(SparseBlock.Type.CSR); + } + + @Test + public void testSparseBlockDCSRInvalidValue() { + runSparseBlockInvalidValueTest(SparseBlock.Type.DCSR); + } + + @Test + public void testSparseBlockMCSCInvalidValue() { + SparseBlockMCSC sblock = new SparseBlockMCSC(getFixedSparseBlock()); + double[] values = sblock.valuesCol(3); + values[0] = 0; + + RuntimeException ex = assertThrows(RuntimeException.class, + () -> sblock.checkValidity(4, 4, 6, true)); + assertTrue(ex.getMessage().startsWith("The values are expected to be non zeros")); + } + + @Test + public void testSparseBlockMCSRInvalidValue() { + SparseBlockMCSR sblock = new SparseBlockMCSR(getFixedSparseBlock()); + double[] values = sblock.values(3); + values[0] = 0; + + RuntimeException ex = assertThrows(RuntimeException.class, + () -> sblock.checkValidity(4, 4, 6, true)); + assertTrue(ex.getMessage().startsWith("The values are expected to be non zeros")); + } + + @Test + public void testSparseBlockCOOInvalidRIndex() { + runSparseBlockInvalidIndexTest(SparseBlock.Type.COO, "_rindexes"); + } + + @Test + public void testSparseBlockCOOInvalidCIndex() { + runSparseBlockInvalidIndexTest(SparseBlock.Type.COO, "_cindexes"); + } + + + @Test + public void testSparseBlockCSCInvalidIndex() { + runSparseBlockInvalidIndexTest(SparseBlock.Type.CSC, "_indexes"); + } + + @Test + public void testSparseBlockCSRInvalidIndex() { + runSparseBlockInvalidIndexTest(SparseBlock.Type.CSR, "_indexes"); + } + + @Test + public void testSparseBlockDCSRInvalidIndex() { + runSparseBlockInvalidIndexTest(SparseBlock.Type.DCSR, "_colidx"); + } + + @Test + public void testSparseBlockCOOCapacityExceedsAllowedLimit() { + SparseBlockCOO sblock = new SparseBlockCOO(3, 50); + + RuntimeException ex = assertThrows(RuntimeException.class, + () -> sblock.checkValidity(3, 3, 0, false)); + assertTrue(ex.getMessage().startsWith("Capacity is larger than the nnz times a resize factor")); + } + + @Test + public void testSparseBlockCSCCapacityExceedsAllowedLimit() { + SparseBlockCSC sblock = new SparseBlockCSC(3, 3, 50); + + RuntimeException ex = assertThrows(RuntimeException.class, + () -> sblock.checkValidity(3, 3, 0, false)); + assertTrue(ex.getMessage().startsWith("Capacity is larger than the nnz times a resize factor")); + } + + @Test + public void testSparseBlockCSRCapacityExceedsAllowedLimit() { + SparseBlockCSR sblock = new SparseBlockCSR(3, 50, 0); + + RuntimeException ex = assertThrows(RuntimeException.class, + () -> sblock.checkValidity(3, 3, 0, false)); + assertTrue(ex.getMessage().startsWith("Capacity is larger than the nnz times a resize factor")); + } + + @Test + public void testSparseBlockDCSRCapacityExceedsAllowedLimit() { + SparseBlockDCSR sblock = new SparseBlockDCSR(3, 50); + + RuntimeException ex = assertThrows(RuntimeException.class, + () -> sblock.checkValidity(3, 3, 0, false)); + assertTrue(ex.getMessage().startsWith("Capacity is larger than the nnz times a resize factor")); + } + + @Test + public void testSparseBlockMCSCCapacityExceedsAllowedLimit() { + double[][] A = getRandomMatrix(_rows, _cols, -10, 10, _sparsity, 13); + + MatrixBlock mbtmp = DataConverter.convertToMatrixBlock(A); + SparseBlock srtmp = mbtmp.getSparseBlock(); + SparseBlockMCSC sblock = new SparseBlockMCSC(srtmp); + + RuntimeException ex = assertThrows(RuntimeException.class, + () -> sblock.checkValidity(_rows, _cols, 2, true)); + assertTrue(ex.getMessage().startsWith("The capacity is larger than nnz times a resize factor")); + } + + @Test + public void testSparseBlockMCSRCapacityExceedsAllowedLimit() { + double[][] A = getRandomMatrix(_rows, _cols, -10, 10, _sparsity, 13); + + MatrixBlock mbtmp = DataConverter.convertToMatrixBlock(A); + SparseBlock srtmp = mbtmp.getSparseBlock(); + SparseBlockMCSR sblock = new SparseBlockMCSR(srtmp); + + RuntimeException ex = assertThrows(RuntimeException.class, + () -> sblock.checkValidity(_rows, _cols, 2, true)); + assertTrue(ex.getMessage().startsWith("The capacity is larger than nnz times a resize factor")); + } + + private void runSparseBlockValidTest(SparseBlock.Type btype) { + double[][] A = getRandomMatrix(_rows, _cols, -10, 10, _sparsity, 13); + + MatrixBlock mbtmp = DataConverter.convertToMatrixBlock(A); + SparseBlock srtmp = mbtmp.getSparseBlock(); + SparseBlock sblock = SparseBlockFactory.copySparseBlock(btype, srtmp, true); + + assertTrue("should pass checkValidity", sblock.checkValidity(_rows, _cols, sblock.size(), true)); + } + + private void runSparseBlockInvalidDimensionsTest(SparseBlock sblock) { + RuntimeException ex1 = assertThrows(RuntimeException.class, + () -> sblock.checkValidity(-1, 1, 0, false)); + assertTrue(ex1.getMessage().startsWith("Invalid block dimensions")); + + RuntimeException ex2 = assertThrows(RuntimeException.class, + () -> sblock.checkValidity(1, -1, 0, false)); + assertTrue(ex2.getMessage().startsWith("Invalid block dimensions")); + } + + private void runSparseBlockInvalidIndexTest(SparseBlock.Type btype, String indexName) { + SparseBlock srtmp = getFixedSparseBlock(); + SparseBlock sblock = SparseBlockFactory.copySparseBlock(btype, srtmp, true); + + int[] indexes = (int[]) getField(sblock, indexName); + indexes[0] = -1; + + RuntimeException ex = assertThrows(RuntimeException.class, + () -> sblock.checkValidity(4, 4, 6, true)); + assertTrue(ex.getMessage().startsWith("Invalid index at pos")); + } + + private void runSparseBlockInvalidValueTest(SparseBlock.Type btype) { + SparseBlock srtmp = getFixedSparseBlock(); + SparseBlock sblock = SparseBlockFactory.copySparseBlock(btype, srtmp, true); + + double[] values = (double[]) getField(sblock, "_values"); + values[0] = 0; + + RuntimeException ex = assertThrows(RuntimeException.class, + () -> sblock.checkValidity(4, 4, 6, true)); + assertTrue(ex.getMessage().startsWith("The values array should not contain zeros")); + } + + private void checkValidityFailsWhenArrayLengthIsTemporarilyModified(SparseBlock sblock, String name, Object value){ + Object old = getField(sblock, name); + setField(sblock, name, value); + RuntimeException ex = assertThrows(RuntimeException.class, + () -> sblock.checkValidity(4, 4, 6, false)); + assertEquals("Incorrect array lengths.", ex.getMessage()); + setField(sblock, name, old); + } + + private SparseBlock getFixedSparseBlock(){ + double[][] A = new double[][] {{1, 0, 0, 0}, {0, 1, 0, 0}, {0, 0, 1, 1}, {0, 0, 1, 1}}; + MatrixBlock mbtmp = DataConverter.convertToMatrixBlock(A); + return mbtmp.getSparseBlock(); + } + + private static void setField(Object obj, String name, Object value) { + try { + Field f = obj.getClass().getDeclaredField(name); + f.setAccessible(true); + f.set(obj, value); + } catch (Exception ex) { + throw new RuntimeException("Reflection failed: " + ex.getMessage()); + } + } + + private static Object getField(Object obj, String name) { + try { + Field f = obj.getClass().getDeclaredField(name); + f.setAccessible(true); + return f.get(obj); + } catch (Exception ex) { + throw new RuntimeException("Reflection failed: " + ex.getMessage()); + } + } +} diff --git a/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockColTest.java b/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockColTest.java new file mode 100644 index 00000000000..19a313e0259 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockColTest.java @@ -0,0 +1,267 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.component.sparse; + +import org.apache.sysds.runtime.data.SparseBlock; +import org.apache.sysds.runtime.data.SparseBlockCSC; +import org.apache.sysds.runtime.data.SparseBlockMCSC; +import org.apache.sysds.runtime.data.SparseRow; +import org.apache.sysds.runtime.data.SparseRowScalar; +import org.apache.sysds.runtime.data.SparseRowVector; + +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.util.DataConverter; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestUtils; +import org.junit.Assert; +import org.junit.Test; + +import java.util.Arrays; + +public class SparseBlockColTest extends AutomatedTestBase +{ + private final static int _rows = 324; + private final static int _cols = 132; + private final static double _sparsity = 0.3; + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + } + + @Test + public void testSparseBlockCSCGetReset() { + double[][] A = getRandomMatrix(_rows, _cols, -10, 10, _sparsity, 1234); + MatrixBlock mbtmp = DataConverter.convertToMatrixBlock(A); + SparseBlockColWrapper b = wrap(new SparseBlockCSC(mbtmp.getSparseBlock())); + runSparseBlockGetResetTest(b, SparseBlock.Type.CSC); + } + + @Test + public void testSparseBlockMCSCGetReset() { + double[][] A = getRandomMatrix(_rows, _cols, -10, 10, _sparsity, 1234); + MatrixBlock mbtmp = DataConverter.convertToMatrixBlock(A); + SparseBlockColWrapper b = wrap(new SparseBlockMCSC(mbtmp.getSparseBlock())); + runSparseBlockGetResetTest(b, SparseBlock.Type.MCSC); + } + + @Test + public void testSparseBlockCSCSetSort() { + double[][] A = getRandomMatrix(_rows, _cols, -10, 10, _sparsity, 1234); + MatrixBlock mbtmp = DataConverter.convertToMatrixBlock(A); + SparseBlockColWrapper b = wrap(new SparseBlockCSC(mbtmp.getSparseBlock())); + SparseRow[] cols = (new SparseBlockMCSC(mbtmp.getSparseBlock())).getCols(); + runSparseBlockSetSortTest(b, cols); + } + + @Test + public void testSparseBlockMCSCSetSort() { + double[][] A = getRandomMatrix(_rows, _cols, -10, 10, _sparsity, 1234); + MatrixBlock mbtmp = DataConverter.convertToMatrixBlock(A); + SparseBlockColWrapper b = wrap(new SparseBlockMCSC(mbtmp.getSparseBlock())); + SparseRow[] cols = (new SparseBlockMCSC(mbtmp.getSparseBlock())).getCols(); + runSparseBlockSetSortTest(b, cols); + } + + @Test + public void testSparseBlockCSCSetDelIdxRange() { + double[][] A = getRandomMatrix(_rows, _cols, -10, 10, _sparsity, 1234); + MatrixBlock mbtmp = DataConverter.convertToMatrixBlock(A); + SparseBlockColWrapper b = wrap(new SparseBlockCSC(mbtmp.getSparseBlock())); + SparseRow[] cols = (new SparseBlockMCSC(mbtmp.getSparseBlock())).getCols(); + runSparseBlockSetDelIdxRangeTest(b, cols); + } + + @Test + public void testSparseBlockMCSCSetDelIdxRange() { + double[][] A = getRandomMatrix(_rows, _cols, -10, 10, _sparsity, 1234); + MatrixBlock mbtmp = DataConverter.convertToMatrixBlock(A); + SparseBlockColWrapper b = wrap(new SparseBlockMCSC(mbtmp.getSparseBlock())); + SparseRow[] cols = (new SparseBlockMCSC(mbtmp.getSparseBlock())).getCols(); + runSparseBlockSetDelIdxRangeTest(b, cols); + } + + private void runSparseBlockGetResetTest(SparseBlockColWrapper sblock, SparseBlock.Type btype) { + int c = _cols/3; + SparseRow col = sblock.getCol(c); + int size = sblock.sizeCol(c); + Assert.assertEquals(col.size(), size); + + sblock.resetCol(c); + col = sblock.getCol(c); + size = sblock.sizeCol(c); + Assert.assertEquals(0, size); + Assert.assertTrue(col.isEmpty()); + if(btype == SparseBlock.Type.CSC) Assert.assertTrue(col instanceof SparseRowScalar); + + // nothing changes + SparseBlockColWrapper sblock2 = sblock.copy(); + sblock.resetCol(c); + SparseRow col2 = sblock.getCol(c); + Assert.assertArrayEquals(col.indexes(), col2.indexes()); + Assert.assertArrayEquals(col.values(), col2.values(), 0); + Assert.assertEquals(sblock.getObject(), sblock2.getObject()); + } + + private void runSparseBlockSetSortTest(SparseBlockColWrapper sblock, SparseRow[] cols) { + int c = _cols/3; + SparseRow col = cols[c]; + double[] values = col.values().clone(); + int[] indexes = col.indexes().clone(); + int size = col.size(); + + // reverse + for (int i = 0; i < size/2; i++) { + double t = values[i]; + values[i] = values[size-1-i]; + values[size-1-i] = t; + int t2 = indexes[i]; + indexes[i] = indexes[size-1-i]; + indexes[size-1-i] = t2; + } + Assert.assertFalse(Arrays.equals(col.values(), values)); + Assert.assertFalse(Arrays.equals(col.indexes(), indexes)); + + SparseRow col2 = new SparseRowVector(values, indexes); + sblock.resetCol(c); + sblock.setCol(c, col2, true); + Assert.assertArrayEquals(col2.indexes(), sblock.getCol(c).indexes()); + Assert.assertArrayEquals(col2.values(), sblock.getCol(c).values(), 0); + + int nnz = (int) ((SparseBlock) sblock.getObject()).size(); + int rlen = ((SparseBlock) sblock.getObject()).numRows(); + int clen = cols.length; + RuntimeException ex = Assert.assertThrows(RuntimeException.class, + () -> ((SparseBlock) sblock.getObject()).checkValidity(rlen, clen, nnz, true)); + Assert.assertTrue(ex.getMessage().startsWith("Wrong sparse column ordering")); + + sblock.sortCol(c); + Assert.assertTrue(((SparseBlock)sblock.getObject()).checkValidity(rlen, clen, nnz, true)); + Assert.assertArrayEquals(col.indexes(), sblock.getCol(c).indexes()); + Assert.assertArrayEquals(col.values(), sblock.getCol(c).values(), 0); + } + + private void runSparseBlockSetDelIdxRangeTest(SparseBlockColWrapper sblock, SparseRow[] cols) { + int c = _cols/3; + int rl = _rows/4; + int ru = _rows/2; + + SparseRow[] cols2 = Arrays.copyOf(cols, cols.length); + double[] v = getRandomMatrix(1, _rows, -10, 10, 1, 1234)[0]; + for(int i=0; i sblock.checkValidity(_rows, _cols, sblock.size(), true)); + assertTrue(ex.getMessage().startsWith("The values array should not contain zeros")); + long size = sblock.size(); + + sblock.compact(); + + assertTrue("should pass checkValidity", sblock.checkValidity(_rows, _cols, sblock.size(), true)); + assertEquals(size-2, sblock.size()); + } + + private void runSparseBlockModifiedCompactZerosTest(SparseBlock.Type btype, String field) { + + double[][] A = getRandomMatrix(_rows, _cols, -10, 10, _sparsity, 13); + + MatrixBlock mbtmp = DataConverter.convertToMatrixBlock(A); + SparseBlock srtmp = mbtmp.getSparseBlock(); + SparseBlock sblock = SparseBlockFactory.copySparseBlock(btype, srtmp, true); + + SparseRow[] sr = (SparseRow[]) getField(sblock, field); + double[] values = sr[0].values(); + values[0] = 0.0; + values[values.length-1] = 0.0; + + RuntimeException ex = assertThrows(RuntimeException.class, + () -> sblock.checkValidity(_rows, _cols, sblock.size(), true)); + assertTrue(ex.getMessage().startsWith("The values are expected to be non zeros")); + long size = sblock.size(); + + sblock.compact(); + + assertTrue("should pass checkValidity", sblock.checkValidity(_rows, _cols, sblock.size(), true)); + assertEquals(size-2, sblock.size()); + } + + private static Object getField(Object obj, String name) { + try { + Field f = obj.getClass().getDeclaredField(name); + f.setAccessible(true); + return f.get(obj); + } catch (Exception ex) { + throw new RuntimeException("Reflection failed: " + ex.getMessage()); + } + } +} diff --git a/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockContainsTest.java b/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockContainsTest.java new file mode 100644 index 00000000000..59b74f40804 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockContainsTest.java @@ -0,0 +1,251 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.component.sparse; + +import org.apache.sysds.runtime.data.SparseBlock; +import org.apache.sysds.runtime.data.SparseBlockFactory; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.util.DataConverter; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestUtils; +import org.junit.Test; + +import java.util.List; + +import static org.junit.Assert.assertEquals; + +public class SparseBlockContainsTest extends AutomatedTestBase +{ + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + } + + @Test + public void testSparseBlockContainsNoMatchCOO() { + runSparseBlockContainsNoMatchTest(SparseBlock.Type.COO); + } + + @Test + public void testSparseBlockContainsNoMatchCSC() { + runSparseBlockContainsNoMatchTest(SparseBlock.Type.CSC); + } + + @Test + public void testSparseBlockContainsNoMatchCSR() { + runSparseBlockContainsNoMatchTest(SparseBlock.Type.CSR); + } + + @Test + public void testSparseBlockContainsNoMatchDCSR() { + runSparseBlockContainsNoMatchTest(SparseBlock.Type.DCSR); + } + + @Test + public void testSparseBlockContainsNoMatchMCSC() { + runSparseBlockContainsNoMatchTest(SparseBlock.Type.MCSC); + } + + @Test + public void testSparseBlockContainsNoMatchMCSR() { + runSparseBlockContainsNoMatchTest(SparseBlock.Type.MCSR); + } + + @Test + public void testSparseBlockContainsNaNCOO() { + runSparseBlockContainsNaNTest(SparseBlock.Type.COO); + } + + @Test + public void testSparseBlockContainsNaNCSC() { + runSparseBlockContainsNaNTest(SparseBlock.Type.CSC); + } + + @Test + public void testSparseBlockContainsNaNCSR() { + runSparseBlockContainsNaNTest(SparseBlock.Type.CSR); + } + + @Test + public void testSparseBlockContainsNaNDCSR() { + runSparseBlockContainsNaNTest(SparseBlock.Type.DCSR); + } + + @Test + public void testSparseBlockContainsNaNMCSC() { + runSparseBlockContainsNaNTest(SparseBlock.Type.MCSC); + } + + @Test + public void testSparseBlockContainsNaNMCSR() { + runSparseBlockContainsNaNTest(SparseBlock.Type.MCSR); + } + + @Test + public void testSparseBlockContainsEarlyAbortCOO() { + runSparseBlockContainsEarlyAbortTest(SparseBlock.Type.COO); + } + + @Test + public void testSparseBlockContainsEarlyAbortCSC() { + runSparseBlockContainsEarlyAbortTest(SparseBlock.Type.CSC); + } + + @Test + public void testSparseBlockContainsEarlyAbortCSR() { + runSparseBlockContainsEarlyAbortTest(SparseBlock.Type.CSR); + } + + @Test + public void testSparseBlockContainsEarlyAbortDCSR() { + runSparseBlockContainsEarlyAbortTest(SparseBlock.Type.DCSR); + } + + @Test + public void testSparseBlockContainsEarlyAbortMCSC() { + runSparseBlockContainsEarlyAbortTest(SparseBlock.Type.MCSC); + } + + @Test + public void testSparseBlockContainsEarlyAbortMCSR() { + runSparseBlockContainsEarlyAbortTest(SparseBlock.Type.MCSR); + } + + @Test + public void testSparseBlockContainsPatternLongerThanRowsCOO() { + runSparseBlockContainsPatternLongerThanRowsTest(SparseBlock.Type.COO); + } + + @Test + public void testSparseBlockContainsPatternLongerThanRowsCSC() { + runSparseBlockContainsPatternLongerThanRowsTest(SparseBlock.Type.CSC); + } + + @Test + public void testSparseBlockContainsPatternLongerThanRowsCSR() { + runSparseBlockContainsPatternLongerThanRowsTest(SparseBlock.Type.CSR); + } + + @Test + public void testSparseBlockContainsPatternLongerThanRowsDCSR() { + runSparseBlockContainsPatternLongerThanRowsTest(SparseBlock.Type.DCSR); + } + + @Test + public void testSparseBlockContainsPatternLongerThanRowsMCSC() { + runSparseBlockContainsPatternLongerThanRowsTest(SparseBlock.Type.MCSC); + } + + @Test + public void testSparseBlockContainsPatternLongerThanRowsMCSR() { + runSparseBlockContainsPatternLongerThanRowsTest(SparseBlock.Type.MCSR); + } + + @Test + public void testSparseBlockContainsPatternContainsZeroCOO() { + runSparseBlockContainsPatternContainsZeroTest(SparseBlock.Type.COO); + } + + @Test + public void testSparseBlockContainsPatternContainsZeroCSC() { + runSparseBlockContainsPatternContainsZeroTest(SparseBlock.Type.CSC); + } + + @Test + public void testSparseBlockContainsPatternContainsZeroCSR() { + runSparseBlockContainsPatternContainsZeroTest(SparseBlock.Type.CSR); + } + + @Test + public void testSparseBlockContainsPatternContainsZeroDCSR() { + runSparseBlockContainsPatternContainsZeroTest(SparseBlock.Type.DCSR); + } + + @Test + public void testSparseBlockContainsPatternContainsZeroMCSC() { + runSparseBlockContainsPatternContainsZeroTest(SparseBlock.Type.MCSC); + } + + @Test + public void testSparseBlockContainsPatternContainsZeroMCSR() { + runSparseBlockContainsPatternContainsZeroTest(SparseBlock.Type.MCSR); + } + + private void runSparseBlockContainsNoMatchTest(SparseBlock.Type btype) { + double[] pattern = new double[]{1., 2., 3.}; + double[][] A = new double[][]{{4., 5., 6.}, {7., 8., 9.}, {0., 0., 0.}, {0., 0., 0.}, {0., 0., 0.}, {0., 0., 0.}}; + + MatrixBlock mbtmp = DataConverter.convertToMatrixBlock(A); + SparseBlock srtmp = mbtmp.getSparseBlock(); + SparseBlock sblock = SparseBlockFactory.copySparseBlock(btype, srtmp, true); + + List result = sblock.contains(pattern, false); + assertEquals(List.of(), result); + } + + private void runSparseBlockContainsNaNTest(SparseBlock.Type btype) { + double[] pattern = new double[]{Double.NaN, 2., 3.}; + double[][] A = new double[][]{{Double.NaN, 2., 3.}, {1., 2., 3.}, {0., 0., 0.}, {0., 0., 0.}, {0., 0., 0.}, {0., 0., 0.}}; + + MatrixBlock mbtmp = DataConverter.convertToMatrixBlock(A); + SparseBlock srtmp = mbtmp.getSparseBlock(); + SparseBlock sblock = SparseBlockFactory.copySparseBlock(btype, srtmp, true); + + List result = sblock.contains(pattern, false); + assertEquals(List.of(0), result); + } + + private void runSparseBlockContainsEarlyAbortTest(SparseBlock.Type btype) { + double[] pattern = new double[]{1., 2., 3.}; + double[][] A = new double[][]{{0., 0., 0.}, {1., 2., 3.}, {1., 2., 3.}, {0., 0., 0.}, {0., 0., 0.}, {0., 0., 0.}}; + + MatrixBlock mbtmp = DataConverter.convertToMatrixBlock(A); + SparseBlock srtmp = mbtmp.getSparseBlock(); + SparseBlock sblock = SparseBlockFactory.copySparseBlock(btype, srtmp, true); + + List result = sblock.contains(pattern, true); + assertEquals(List.of(1), result); + } + + private void runSparseBlockContainsPatternLongerThanRowsTest(SparseBlock.Type btype) { + double[] pattern = new double[]{1., 2., 3., 4.}; + double[][] A = new double[][]{{0., 0., 0.}, {1., 2., 3.}, {1., 2., 3.}, {0., 0., 0.}, {0., 0., 0.}, {0., 0., 0.}}; + + MatrixBlock mbtmp = DataConverter.convertToMatrixBlock(A); + SparseBlock srtmp = mbtmp.getSparseBlock(); + SparseBlock sblock = SparseBlockFactory.copySparseBlock(btype, srtmp, true); + + List result = sblock.contains(pattern, false); + assertEquals(List.of(), result); + } + + private void runSparseBlockContainsPatternContainsZeroTest(SparseBlock.Type btype) { + double[] pattern = new double[]{0., 1., 2.}; + double[][] A = new double[][]{{0., 1., 2.}, {0., 0., 0.}, {0., 0., 0.}, {0., 0., 0.}, {0., 1., 2.}, {1., 2., 0.}}; + + MatrixBlock mbtmp = DataConverter.convertToMatrixBlock(A); + SparseBlock srtmp = mbtmp.getSparseBlock(); + SparseBlock sblock = SparseBlockFactory.copySparseBlock(btype, srtmp, true); + + List result = sblock.contains(pattern, false); + assertEquals(List.of(0, 4), result); + } +} diff --git a/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockEqualsTest.java b/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockEqualsTest.java new file mode 100644 index 00000000000..43dbae3004d --- /dev/null +++ b/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockEqualsTest.java @@ -0,0 +1,222 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.component.sparse; + +import org.apache.sysds.runtime.data.SparseBlock; +import org.apache.sysds.runtime.data.SparseBlockFactory; +import org.apache.sysds.runtime.data.SparseRowVector; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.util.DataConverter; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestUtils; + +import org.junit.Test; +import org.junit.experimental.runners.Enclosed; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import java.util.ArrayList; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotEquals; + +@RunWith(Enclosed.class) +public class SparseBlockEqualsTest { + + @RunWith(Parameterized.class) + public static class SparseBlockEqualsSparseBlockTest extends AutomatedTestBase { + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + } + + private final SparseBlock.Type type1; + private final SparseBlock.Type type2; + + public SparseBlockEqualsSparseBlockTest(SparseBlock.Type type1, SparseBlock.Type type2) { + this.type1 = type1; + this.type2 = type2; + } + + @Parameterized.Parameters(name = "{0} vs {1}") + public static Iterable types() { + SparseBlock.Type[] types = SparseBlock.Type.values(); + ArrayList params = new ArrayList<>(); + + for (int i = 0; i < types.length; i++) { + for (int j = i; j < types.length; j++) { + params.add(new Object[]{types[i], types[j]}); + } + } + + return params; + } + + @Test + public void testSparseBlockEquals() { + runSparseBlockEqualsTest(type1, type2); + } + + @Test + public void testSparseBlockNotEqualsColIdx() { + runSparseBlockNotEqualsColIdxTest(type1, type2); + } + + @Test + public void testSparseBlockNotEqualsEmptyRow() { + runSparseBlockNotEqualsEmptyRowTest(type1, type2); + } + } + + @RunWith(Parameterized.class) + public static class SparseBlockEqualsDenseValuesTest extends AutomatedTestBase { + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + } + + private final SparseBlock.Type type; + + public SparseBlockEqualsDenseValuesTest(SparseBlock.Type type) { + this.type = type; + } + + @Parameterized.Parameters(name = "{0}") + public static Iterable types() { + ArrayList params = new ArrayList<>(); + for (SparseBlock.Type t : SparseBlock.Type.values()) { + params.add(new Object[]{t}); + } + return params; + } + + @Test + public void testSparseBlockNotEqualsNonSparseBlock() { + runSparseBlockNotEqualsNonSparseBlockTest(type); + } + + @Test + public void testSparseBlockNotEqualsDenseValuesEmptyRow() { + runSparseBlockNotEqualsDenseValuesEmptyRowTest(type); + } + + @Test + public void testSparseBlockNotEqualsDenseValuesNonZero() { + runSparseBlockNotEqualsDenseValuesNonZeroTest(type); + } + + @Test + public void testSparseBlockNotEqualsDenseValuesAdditionalNonZero() { + runSparseBlockNotEqualsDenseValuesAdditionalNonZeroTest(type); + } + } + + private static void runSparseBlockEqualsTest(SparseBlock.Type type1, SparseBlock.Type type2) { + double[][] A = new double[][]{{1., 2., 3.}, {0., 0., 0.}, {0., 4., 0.}, {0., 0., 5.}, {6., 0., 0.}, {0., 0., 7.}}; + + MatrixBlock mbtmp = DataConverter.convertToMatrixBlock(A); + SparseBlock srtmp = mbtmp.getSparseBlock(); + SparseBlock sblock1 = SparseBlockFactory.copySparseBlock(type1, srtmp, true); + SparseBlock sblock2 = SparseBlockFactory.copySparseBlock(type2, srtmp, true); + + assertEquals(sblock1, sblock2); + } + + private static void runSparseBlockNotEqualsColIdxTest(SparseBlock.Type type1, SparseBlock.Type type2) { + double[][] A = new double[][]{{1., 2., 3.}, {0., 0., 0.}, {0., 4., 0.}, {0., 0., 5.}, {6., 0., 0.}, {0., 0., 7.}}; + double[][] B = new double[][]{{1., 2., 3.}, {0., 0., 0.}, {0., 0., 4.}, {0., 0., 5.}, {6., 0., 0.}, {0., 0., 7.}}; + + MatrixBlock mbtmp1 = DataConverter.convertToMatrixBlock(A); + SparseBlock srtmp1 = mbtmp1.getSparseBlock(); + SparseBlock sblock1 = SparseBlockFactory.copySparseBlock(type1, srtmp1, true); + + MatrixBlock mbtmp2 = DataConverter.convertToMatrixBlock(B); + SparseBlock srtmp2 = mbtmp2.getSparseBlock(); + SparseBlock sblock2 = SparseBlockFactory.copySparseBlock(type2, srtmp2, true); + + assertNotEquals("should not be equal: " + type1 + " " + type2, sblock1, sblock2); + } + + private static void runSparseBlockNotEqualsEmptyRowTest(SparseBlock.Type type1, SparseBlock.Type type2) { + double[][] A = new double[][]{{1., 2., 3.}, {0., 0., 0.}, {0., 4., 0.}, {0., 0., 5.}, {6., 0., 0.}, {0., 0., 7.}}; + double[][] B = new double[][]{{1., 2., 3.}, {0., 4., 0.}, {0., 0., 0.}, {0., 0., 5.}, {6., 0., 0.}, {0., 0., 7.}}; + + MatrixBlock mbtmp1 = DataConverter.convertToMatrixBlock(A); + SparseBlock srtmp1 = mbtmp1.getSparseBlock(); + SparseBlock sblock1 = SparseBlockFactory.copySparseBlock(type1, srtmp1, true); + + MatrixBlock mbtmp2 = DataConverter.convertToMatrixBlock(B); + SparseBlock srtmp2 = mbtmp2.getSparseBlock(); + SparseBlock sblock2 = SparseBlockFactory.copySparseBlock(type2, srtmp2, true); + + assertNotEquals("should not be equal: " + type1 + " " + type2, sblock1, sblock2); + } + + private static void runSparseBlockNotEqualsNonSparseBlockTest(SparseBlock.Type type) { + double[][] A = new double[][]{{1., 0., 3.}, {0., 0., 0.}, {0., 0., 0.}, {0., 0., 0.}}; + + MatrixBlock mbtmp = DataConverter.convertToMatrixBlock(A); + SparseBlock srtmp = mbtmp.getSparseBlock(); + SparseBlock sblock = SparseBlockFactory.copySparseBlock(type, srtmp, true); + + SparseRowVector srv = new SparseRowVector(A[0], new int[]{0, 1, 2}); + + assertNotEquals("should not be equal: " + type, sblock, srv); + } + + private static void runSparseBlockNotEqualsDenseValuesEmptyRowTest(SparseBlock.Type type) { + double[][] A = new double[][]{{1., 0., 3.}, {0., 0., 0.}, {0., 0., 0.}, {4., 0., 6.}}; + + MatrixBlock mbtmp = DataConverter.convertToMatrixBlock(A); + SparseBlock srtmp = mbtmp.getSparseBlock(); + SparseBlock sblock = SparseBlockFactory.copySparseBlock(type, srtmp, true); + + double[] denseValues = new double[]{1., 0., 3., 0., 0., 0., 1., 1., 1., 4., 0., 6.}; + + assertFalse("should not be equal: " + type, sblock.equals(denseValues, 3, 1e-10)); + } + + private static void runSparseBlockNotEqualsDenseValuesNonZeroTest(SparseBlock.Type type) { + double[][] A = new double[][]{{1., 0., 3.}, {0., 0., 0.}, {0., 0., 0.}, {0., 0., 1.}, {4., 0., 6.}}; + + MatrixBlock mbtmp = DataConverter.convertToMatrixBlock(A); + SparseBlock srtmp = mbtmp.getSparseBlock(); + SparseBlock sblock = SparseBlockFactory.copySparseBlock(type, srtmp, true); + + double[] denseValues = new double[]{1., 0., 3., 0., 0., 0., 0., 0., 0., 0., 1., 1., 4., 0., 6.}; + + assertFalse("should not be equal: " + type, sblock.equals(denseValues, 3, 1e-10)); + } + + private static void runSparseBlockNotEqualsDenseValuesAdditionalNonZeroTest(SparseBlock.Type type) { + double[][] A = new double[][]{{1., 0., 3.}, {0., 0., 0.}, {0., 0., 0.}, {4., 0., 0.}}; + + MatrixBlock mbtmp = DataConverter.convertToMatrixBlock(A); + SparseBlock srtmp = mbtmp.getSparseBlock(); + SparseBlock sblock = SparseBlockFactory.copySparseBlock(type, srtmp, true); + + double[] denseValues = new double[]{1., 0., 3., 0., 0., 0., 0., 0., 0., 4., 0., 6.}; + + assertFalse("should not be equal: " + type, sblock.equals(denseValues, 3, 1e-10)); + } +} diff --git a/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockInitializationTest.java b/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockInitializationTest.java new file mode 100644 index 00000000000..57185053a7b --- /dev/null +++ b/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockInitializationTest.java @@ -0,0 +1,474 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.component.sparse; + +import org.apache.sysds.runtime.data.SparseBlock; +import org.apache.sysds.runtime.data.SparseBlockCOO; +import org.apache.sysds.runtime.data.SparseBlockCSC; +import org.apache.sysds.runtime.data.SparseBlockCSR; +import org.apache.sysds.runtime.data.SparseBlockDCSR; +import org.apache.sysds.runtime.data.SparseBlockFactory; +import org.apache.sysds.runtime.data.SparseBlockMCSC; +import org.apache.sysds.runtime.data.SparseBlockMCSR; +import org.apache.sysds.runtime.data.SparseRow; +import org.apache.sysds.runtime.data.SparseRowVector; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.util.DataConverter; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestUtils; +import org.junit.Test; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.DataInputStream; +import java.io.DataOutputStream; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotSame; + + +public class SparseBlockInitializationTest extends AutomatedTestBase +{ + private final static int _rows = 324; + private final static int _cols = 132; + private final static double _sparsity = 0.22; + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + } + + @Test + public void testSparseBlockCreationCOO() { + runSparseBlockCreationTest(SparseBlock.Type.COO); + } + + @Test + public void testSparseBlockCreationDCSR() { + runSparseBlockCreationTest(SparseBlock.Type.DCSR); + } + + @Test + public void testSparseBlockCreationMCSC() { + runSparseBlockCreationTest(SparseBlock.Type.MCSC); + } + + @Test + public void testSparseBlockCreationCSC() { + runSparseBlockCreationTest(SparseBlock.Type.CSC); + } + + private void runSparseBlockCreationTest(SparseBlock.Type type) { + SparseBlock sblock = SparseBlockFactory.createSparseBlock(type, _cols); + assertEquals(SparseBlockFactory.getSparseBlockType(sblock), type); + } + + @Test + public void testSparseBlockCOOInitCapacity() { + int init_capacity = 4; + SparseBlockCOO sblock = new SparseBlockCOO(_cols); + assertEquals("INIT_CAPACITY should be 4", init_capacity, sblock.values(1).length); + } + + @Test + public void testSparseBlockCOORows() { + double[][] A = getRandomMatrix(_rows, _cols, -10, 10, _sparsity, 1234); + MatrixBlock mbtmp = DataConverter.convertToMatrixBlock(A); + SparseBlock sblock = new SparseBlockCOO(mbtmp.getSparseBlock()); + + int totalNnz = 0; + int rows = A.length; + SparseRow[] sparseRows = new SparseRow[rows]; + + for (int i = 0; i < rows; i++) { + SparseRow srv = new SparseRowVector(A[i]); + sparseRows[i] = srv; + totalNnz += srv.size(); + } + + SparseBlockCOO sblock2 = new SparseBlockCOO(sparseRows, totalNnz); + assertEquals(sblock, sblock2); + } + + @Test + public void testSparseBlockCOORowsValuesIndexes() { + double[][] A = getRandomMatrix(_rows, _cols, -10, 10, _sparsity, 1234); + MatrixBlock mbtmp = DataConverter.convertToMatrixBlock(A); + SparseBlock sblock = new SparseBlockCOO(mbtmp.getSparseBlock()); + + int totalNnz = 0; + int rows = A.length; + SparseRow[] sparseRows = new SparseRow[rows]; + + for (int i = 0; i < rows; i++) { + int[] indexes = new int[A[i].length]; + for (int j = 0; j < A[i].length; j++) indexes[j] = j; + SparseRow srv = new SparseRowVector(A[i], indexes); + srv.compact(); + sparseRows[i] = srv; + totalNnz += srv.size(); + } + + SparseBlockCOO sblock2 = new SparseBlockCOO(sparseRows, totalNnz); + assertEquals(sblock, sblock2); + } + + @Test + public void testSparseBlockCSCInitCapacity() { + int rlen = 4; + int clen = 5; + int capacity = 4; + SparseBlockCSC sblock = new SparseBlockCSC(rlen, clen, capacity); + + assertEquals("num rows should be equal to rlen", rlen, sblock.numRows()); + assertEquals("length ptr should be equal to clen+1", clen+1, sblock.colPointers().length); + assertEquals("length values should be equal to capacity", capacity, sblock.valuesCol(0).length); + assertEquals("length indexes should be equal to capacity", capacity, sblock.indexesCol(0).length); + } + + @Test + public void testSparseBlockCSCInitPointer() { + double[][] A = getRandomMatrix(_rows, _cols, -10, 10, _sparsity, 1234); + MatrixBlock mbtmp = DataConverter.convertToMatrixBlock(A); + SparseBlockCSC sblock = new SparseBlockCSC(mbtmp.getSparseBlock()); + + int[] colPtr = sblock.colPointers(); + int[] rowInd = sblock.indexesCol(0); + double[] values = sblock.valuesCol(0); + int nnz = sblock.sizeCol(0); + SparseBlockCSC sblock2 = new SparseBlockCSC(colPtr, rowInd, values, nnz); + + assertEquals(sblock, sblock2); + } + + @Test + public void testSparseBlockCSCInitMSCS() { + double[][] A = getRandomMatrix(_rows, _cols, -10, 10, _sparsity, 1234); + MatrixBlock mbtmp = DataConverter.convertToMatrixBlock(A); + SparseBlock sblock = new SparseBlockMCSC(mbtmp.getSparseBlock()); + + SparseBlockCSC sblock2 = new SparseBlockCSC(sblock); + assertEquals(sblock, sblock2); + } + + @Test + public void testSparseBlockCSCInitCols() { + double[][] A = getRandomMatrix(_rows, _cols, -10, 10, _sparsity, 1234); + MatrixBlock mbtmp = DataConverter.convertToMatrixBlock(A); + SparseBlockMCSC sblock = new SparseBlockMCSC(mbtmp.getSparseBlock()); + + SparseRow[] cols = sblock.getCols(); + int totalNnz = (int) sblock.size(); + + SparseBlock sblock2 = new SparseBlockCSC(cols, totalNnz); + assertEquals(sblock, sblock2); + + } + + @Test + public void testSparseBlockCSCInitRowColInd() { + double[][] A = getRandomMatrix(_rows, _cols, -10, 10, _sparsity, 1234); + MatrixBlock mbtmp = DataConverter.convertToMatrixBlock(A); + SparseBlockCSC sblock = new SparseBlockCSC(mbtmp.getSparseBlock()); + + int[] ptr = sblock.colPointers(); + int[] rowInd = sblock.indexesCol(0); + double[] values = sblock.valuesCol(0); + + int clen = ptr.length-1; + int[] colInd = new int[rowInd.length]; + for(int i=0; i iter = !partial ? sblock.getIterator() : sblock.getIterator(rl, rows); + Iterator iter = rlPartial && ruPartial ? sblock.getIterator(rl, ru, cl, cu): rlPartial? sblock.getIterator(rl, rows) : ruPartial? sblock.getIterator(ru) : sblock.getIterator(); int count = 0; while(iter.hasNext()) { IJV cell = iter.next(); + if(cell.getI() < rl || cell.getI() >= ru) + Assert.fail("iterator row outside of range"); + if(cell.getJ() < cl || cell.getJ() >= cu) + Assert.fail("iterator column outside of range"); if(cell.getV() != A[cell.getI()][cell.getJ()]) Assert.fail("Wrong value returned by iterator: " + cell.getV() + ", expected: " + A[cell.getI()][cell.getJ()]); @@ -277,11 +469,9 @@ private void runSparseBlockIteratorTest(SparseBlock.Type btype, double sparsity, // check iterator over non-zero rows List manualNonZeroRows = new ArrayList<>(); List iteratorNonZeroRows = new ArrayList<>(); - Iterator iterRows = !partial - ? sblock.getNonEmptyRowsIterator(0, rows) - : sblock.getNonEmptyRowsIterator(rl, rows); + Iterator iterRows = sblock.getNonEmptyRowsIterator(rl, ru); - for(int i = rl; i < rows; i++) + for(int i = rl; i < ru; i++) if(!sblock.isEmpty(i)) manualNonZeroRows.add(i); while(iterRows.hasNext()) { @@ -293,6 +483,16 @@ private void runSparseBlockIteratorTest(SparseBlock.Type btype, double sparsity, Assert.fail("Verification of iterator over non-zero rows failed."); } + // check second iterator over non-zero rows + Iterator iterRows2 = !rlPartial && !ruPartial? sblock.getNonEmptyRows().iterator() : sblock.getNonEmptyRows(rl, ru).iterator(); + List iter2NonZeroRows = new ArrayList<>(); + + while(iterRows2.hasNext()) { + iter2NonZeroRows.add(iterRows2.next()); + } + if(!manualNonZeroRows.equals(iter2NonZeroRows)) { + Assert.fail("Verification of second iterator over non-zero rows failed."); + } } catch(Exception ex) { ex.printStackTrace(); diff --git a/src/test/java/org/apache/sysds/test/component/sparse/SparseRowTest.java b/src/test/java/org/apache/sysds/test/component/sparse/SparseRowTest.java new file mode 100644 index 00000000000..307b4335c49 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/component/sparse/SparseRowTest.java @@ -0,0 +1,216 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.component.sparse; + +import org.apache.sysds.runtime.data.SparseRow; +import org.apache.sysds.runtime.data.SparseRowScalar; +import org.apache.sysds.runtime.data.SparseRowVector; +import org.apache.sysds.runtime.util.UtilFunctions; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestUtils; +import org.junit.Test; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotEquals; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; + + +public class SparseRowTest extends AutomatedTestBase +{ + private final static int cols = 121; + private final static int minVal = -10; + private final static int maxVal = 10; + private final static double sparsity = 0.3; + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + } + + @Test + public void testSparseRowEmptyToString() { + SparseRowScalar srs = new SparseRowScalar(); + assertEquals("", srs.toString()); + } + + @Test + public void testSparseRowScalarInitZeroVal() { + SparseRowScalar srs = new SparseRowScalar(5, 0); + srs.compact(); + assertEquals(-1, srs.getIndex()); + } + + @Test + public void testSparseRowScalarSetNewVal() { + SparseRowScalar srs = new SparseRowScalar(); + assertTrue(srs.set(3, 5.0)); + } + + @Test + public void testSparseRowScalarInvalidSet() { + SparseRowScalar srs = new SparseRowScalar(1, 1.0); + RuntimeException ex = assertThrows(RuntimeException.class, () -> srs.set(3, 5.0)); + assertEquals("Invalid set to sparse row scalar.", ex.getMessage()); + } + + @Test + public void testSparseRowScalarAppendZero() { + SparseRowScalar srs = new SparseRowScalar(1, 1.0); + SparseRow srs2 = srs.append(2, 0.0); + assertEquals(srs, srs2); + assertNotEquals(0, srs2.values()[0]); + } + + @Test + public void testSparseRowScalarCompactZero() { + SparseRowScalar srs = new SparseRowScalar(1, 0.0); + srs.compact(); + assertEquals(-1, srs.getIndex()); + } + + @Test + public void testSparseRowScalarCompactNonZero() { + SparseRowScalar srs = new SparseRowScalar(1, 1.0); + srs.compact(); + assertEquals(1, srs.getIndex()); + } + + @Test + public void testSparseRowScalarCopy() { + SparseRowScalar srs = new SparseRowScalar(1, 1.0); + SparseRowScalar srs2 = (SparseRowScalar) srs.copy(true); + assertEquals(srs.getIndex(), srs2.getIndex()); + assertEquals(srs.getValue(), srs2.getValue(), 0.0); + assertNotEquals(srs, srs2); + } + + @Test + public void testSparseRowVectorSetValues() { + double[] v = getRandomMatrix(1, cols, minVal, maxVal, sparsity, 7)[0]; + SparseRowVector srv = new SparseRowVector(UtilFunctions.computeNnz(v, 0, v.length), v, v.length); + + srv.compact(); + int nnz = srv.size(); + double[] w = getRandomMatrix(1, nnz, minVal, maxVal, 1, 13)[0]; + srv.setValues(w); + + assertArrayEquals(w, srv.values(), 0.0); + assertEquals(srv.indexes().length, srv.values().length); + } + + @Test + public void testSparseRowVectorSetIndexes() { + double[] v = getRandomMatrix(1, cols, minVal, maxVal, 1, 7)[0]; + int nnz = UtilFunctions.computeNnz(v, 0, v.length); + SparseRowVector srv = new SparseRowVector(nnz, v, v.length); + + int[] indexes = new int[nnz]; + for(int i = 0; i < nnz; i++) indexes[i] = i; + srv.setIndexes(indexes); + + int idx = (int)(Math.random() * nnz); + assertEquals(idx, srv.getIndex(idx)); + assertEquals(-1, srv.getIndex(nnz)); + assertEquals(srv.values().length, srv.indexes().length); + } + + @Test + public void testSparseRowVectorCopyFromLargerArray() { + double[] v = getRandomMatrix(1, cols, minVal, maxVal, sparsity, 7)[0]; + double[] w = getRandomMatrix(1, 2*cols, minVal, maxVal, sparsity, 7)[0]; + SparseRowVector srv = new SparseRowVector(UtilFunctions.computeNnz(v, 0, v.length), v, v.length); + SparseRowVector other = new SparseRowVector(UtilFunctions.computeNnz(w, 0, w.length), w, w.length); + srv.copy(other); + + assertArrayEquals(other.indexes(), srv.indexes()); + assertArrayEquals(other.values(), srv.values(), 0.0); + assertNotEquals(other, srv); + } + + @Test + public void testSparseRowVectorSetEstimatedNzs() { + double[] v = getRandomMatrix(1, cols, minVal, maxVal, sparsity, 7)[0]; + int nnz = UtilFunctions.computeNnz(v, 0, v.length); + SparseRowVector srv = new SparseRowVector(nnz, v, v.length); + srv.setEstimatedNzs(nnz+1); + assertEquals(nnz+1, srv.getEstimatedNzs()); + } + + @Test + public void testSparseRowVectorSetAtPos() { + double[] v = getRandomMatrix(1, cols, minVal, maxVal, sparsity, 7)[0]; + int nnz = UtilFunctions.computeNnz(v, 0, v.length); + SparseRowVector srv = new SparseRowVector(nnz, v, v.length); + + int pos = nnz-1; + int col = 2; + double val = 2.0; + srv.setAtPos(pos, col, val); + + assertEquals(col, srv.indexes()[pos]); + assertEquals(val, srv.indexes()[pos],0.0); + } + + @Test + public void testSparseRowVectorGetIndex() { + double[] v = getRandomMatrix(1, cols, minVal, maxVal, sparsity, 7)[0]; + int nnz = UtilFunctions.computeNnz(v, 0, v.length); + SparseRowVector srv = new SparseRowVector(nnz, v, v.length); + + int pos = 0; + srv.setAtPos(pos, 5, 2.0); + int index = srv.getIndex(5); + assertEquals(pos, index); + + int col2 = cols+1; + int index2 = srv.getIndex(col2); + assertEquals(-1, index2); + } + + @Test + public void testSparseRowVectorSearchIndexesFirstLTESizeZero() { + SparseRowVector srv = new SparseRowVector(); + int index = srv.searchIndexesFirstLTE(1); + assertEquals(-1, index); + } + + @Test + public void testSparseRowVectorSearchIndexesFirstLTENotFound() { + SparseRowVector srv = new SparseRowVector(new double[] {1.0, 3.0}, new int[] {1, 3}); + int index = srv.searchIndexesFirstLTE(0); + assertEquals(-1, index); + int index2 = srv.searchIndexesFirstLTE(2); + assertEquals(0, index2); + int index3 = srv.searchIndexesFirstLTE(5); + assertEquals(1, index3); + } + + @Test + public void testSparseRowVectorSetIndexRangeWithoutRecap() { + SparseRowVector srv = new SparseRowVector(); + int capacity = srv.capacity(); + + double[] v = getRandomMatrix(1, capacity, minVal, maxVal, sparsity, 7)[0]; + srv.setIndexRange(0, capacity, v, 0, capacity); + assertEquals(capacity, srv.capacity()); + } +}