/* This class implements a Compressed-Sparse-Row (CSR) format for a sparse
 * matrix.  Matrices are assumed square in this implementation (in
 * particular the toString method).
 *
 * Storage is split by rows across processors.  Each processor keeps three
 * local arrays:
 *   myData     - the nonzero values it owns,
 *   myColIndex - the column of each of those values (same indexing as myData),
 *   myRowStart - for every local row r, the index into myData/myColIndex of
 *                the first entry of row r; it has one extra trailing slot so
 *                that row r always occupies [myRowStart[r] : myRowStart[r+1]-1].
 * The matching all* arrays are directories, indexed by processor number,
 * filled via exchange() with every processor's local piece. */
class CSR implements Sparse {

    /* Builds an n x n matrix from an explicit list of (i, j, val) entries. */
    public CSR (MatrixEntry [1d] entries, int n) {
        fillEntries(entries, n);
    }

    /* Builds a diagonal matrix.  diagonal[p] is processor p's slice of the
     * diagonal; its index domain determines which rows this processor owns. */
    public CSR (double [1d] single [1d] diagonal) {
        myData = new double [diagonal[Ti.thisProc()].domain()];
        myData.copy(diagonal[Ti.thisProc()]);
        allData = new double [0:Ti.numProcs()-1][1d];
        allData.exchange(myData);

        myColIndex = new int [myData.domain()];
        allColIndex = new int [0:Ti.numProcs()-1][1d];
        allColIndex.exchange(myColIndex);

        /* One slot per local row plus the trailing end marker. */
        myRowStart = new int [myData.domain().min():myData.domain().max()+[1]];
        foreach (p in myData.domain()) {
            /* Row p holds exactly one entry, stored at index p, in column p. */
            myRowStart[p] = p[1];
            myColIndex[p] = p[1];
        }
        Point <1> high = myRowStart.domain().max();
        myRowStart[high] = high[1];
        allRowStart = new int [0:Ti.numProcs()-1][1d];
        allRowStart.exchange(myRowStart);
    }

    /* Builds a dense n x n matrix whose entries all equal val. */
    public single CSR (int n, double val) {
        /* For simplicity, reuse the general construction path.  Entries are
         * generated in row-major order. */
        MatrixEntry [1d] entries = new MatrixEntry [0:n*n-1];
        foreach (i in [0:n-1]) {
            foreach (j in [0:n-1]) {
                entries[i*n+j] = new MatrixEntry(i[1],j[1],val);
            }
        }
        fillEntries(entries, n);
    }

    /* Returns the number of rows (equivalently columns) of the matrix.
     *
     * The last processor's row-start array ends with the trailing end-marker
     * slot, whose index is one past its last row, i.e. the row count itself.
     * No further +1 is needed; toString derives n from the same expression. */
    public int dim () {
        return allRowStart[allRowStart.domain().max()].domain().max()[1];
    }

    /* Distributes the given entries of an n x n matrix over the processors
     * and fills in the local and directory arrays.
     *
     * Processor 0 sorts the entries and decides how many nonzeros each
     * processor receives; every processor then copies its index range out of
     * its own `entries` argument.
     * NOTE(review): only processor 0's array is sorted here, so the other
     * processors presumably need to be handed entries already in sorted
     * order -- confirm against callers. */
    private single void fillEntries (MatrixEntry [1d] entries, int n) {
        int thisP = Ti.thisProc();
        int single numP = Ti.numProcs();
        int myRowCount = PVector.numPer(n);
        int myLow = PVector.lowPer(n);

        /* Rows myLow .. myLow+myRowCount-1, plus the trailing end marker. */
        myRowStart = new int [myLow:myLow+myRowCount];
        allRowStart = new int [0:numP-1][1d];
        allRowStart.exchange(myRowStart);

        int [1d] nzCounts;
        int [1d] nzStart;
        if (thisP == 0) {
            nzCounts = sortAndCount(entries, n);
            /* Turn the scan result into the starting offset of each
             * processor's range by removing that processor's own count
             * (assumes Vector.addScan is an inclusive prefix sum -- TODO
             * confirm). */
            nzStart = Vector.addScan(nzCounts);
            foreach (i in nzStart.domain()) {
                nzStart[i] -= nzCounts[i];
            }
        }
        /* The elements will still be on Processor 0 */
        nzCounts = broadcast nzCounts from 0;
        nzStart = broadcast nzStart from 0;
        int myNzCount = nzCounts[thisP];
        int myNzStart = nzStart[thisP];
        int myNzEnd = myNzStart + myNzCount - 1;
        RectDomain <1> myIndices = [myNzStart:myNzEnd];

        /* build and exchange arrays, make copies, ... */
        allData = new double [0:numP-1][1d];
        myData = new double [myIndices];
        allData.exchange(myData);
        allColIndex = new int [0:numP-1][1d];
        myColIndex = new int [myIndices];
        allColIndex.exchange(myColIndex);

        MatrixEntry [1d] myEntries = new MatrixEntry [myIndices];
        /* Bulk copy of all the array elements pointers */
        /* Might speed up later accesses by making MatrixEntry immutable */
        myEntries.copy(entries);

        /* Unpack column indices and values.  The loop bounds come from the
         * counts rather than from the domain's min()/max(), so a processor
         * that owns no nonzeros simply skips the loop. */
        for (int k = myNzStart; k <= myNzEnd; k++) {
            myColIndex[k] = myEntries[k].j;
            myData[k] = myEntries[k].val;
        }

        /* Build the row-start table row by row: the start of a row is the
         * index of the first entry whose row number is not smaller than it.
         * Unlike recording "k+1" only when an entry is seen, this also gives
         * rows without any nonzeros a valid (empty) range instead of leaving
         * their slot at its default value. */
        int next = myNzStart;
        for (int row = myLow; row < myLow + myRowCount; row++) {
            myRowStart[row] = next;
            while (next <= myNzEnd && myEntries[next].i <= row) {
                next++;
            }
        }
        myRowStart[myLow + myRowCount] = next;
    }

    /* Computes y = A*x.  allX[p] / allY[p] are processor p's slices of the
     * source and result vectors; each processor writes only its own slice
     * of y. */
    public void matvec (double [1d] single [1d] allY,
                        double [1d] single [1d] allX) {
        /* make a complete copy of x locally */
        double [1d] copyX = new double [0:dim()-1];
        /* Lots of optimization opportunities here */
        foreach (p in Ti.myTeam().domain()) {
            copyX.copy(allX[p]);
        }
        double [1d] myY = allY[Ti.thisProc()];
        foreach (i in myY.domain()) {
            myY[i] = 0;
            /* Entries of row i live at indices myRowStart[i] .. myRowStart[i+1]-1. */
            foreach (j in [myRowStart[i]:myRowStart[i+[1]]-1]) {
                myY[i] += myData[j] * copyX[myColIndex[j]];
            }
        }
    }

    /* Renders the whole matrix densely: one text line per row, with
     * tab-terminated values and "0" for positions without a stored entry.
     * Reads every processor's piece through the directory arrays.  A stored
     * entry is only printed when its column is reached in increasing order,
     * so this relies on each row's entries being stored by ascending column. */
    public String toString () {
        String result = "";
        int numP = Ti.numProcs();
        int n = allRowStart[numP-1].domain().max()[1];
        foreach (p in [0:numP-1]) {
            for (int i = allRowStart[p].domain().min()[1];
                 i < allRowStart[p].domain().max()[1]; i++) {
                /* curr walks through the stored entries of row i. */
                int curr = allRowStart[p][i];
                for (int j = 0; j < n; j++) {
                    if (curr < allRowStart[p][i+1] && allColIndex[p][curr] == j) {
                        result += allData[p][curr] + "\t";
                        curr++;
                    } else {
                        result += "0\t";
                    }
                }
                result += "\n";
            }
        }
        return result;
    }

    /* Local computation to sort entries and count the nonzeros per
     * processor.  Sorts `entries` in place, then walks them once: entry e is
     * attributed to the first processor p with e.i < PVector.lowPer(n, p+1).
     * Returns an array indexed 0..numProcs-1 holding those counts. */
    int [1d] sortAndCount(MatrixEntry [1d] entries, int n) {
        QSort.sort(entries);
        int numP = Ti.numProcs();
        int [1d] nzCount = new int [0:numP-1];
        nzCount.set(0);
        int curr = entries.domain().min()[1];
        for (int i = 0; i < numP; i++) {
            int lowNext = PVector.lowPer(n, i+1);
            while (curr <= entries.domain().max()[1] && entries[curr].i < lowNext) {
                nzCount[i]++;
                curr++;
            }
        }
        return nzCount;
    }

    /* Internal State */
    /* Directory of every processor's value array, and this processor's own. */
    protected double [1d] single [1d] allData;
    protected double [1d] myData;
    /* Directory of every processor's column-index array, and the local one. */
    protected int [1d] single [1d] allColIndex;
    protected int [1d] myColIndex;
    /* Directory of every processor's row-start array, and the local one. */
    protected int [1d] single [1d] allRowStart;
    protected int [1d] myRowStart;

    /* Test code */
    /* Runs two matrix-vector products (a diagonal matrix, then a constant
     * dense matrix) on an n x n problem and prints the operands and results
     * from processor 0.  args[0] is n; if it cannot be parsed, n = 10 is
     * used.  With no arguments, the usage text is printed instead. */
    public static void tester (String [] single args) {
        int single numP = Ti.numProcs();
        int thisP = Ti.thisProc();
        if (args.length < 1) {
            if (thisP == 0) {
                printUsage();
            }
            return;
        }
        int n = 10;
        try {
            n = Integer.parseInt(args[0]);
        } catch (Exception e) {
            if (thisP == 0) {
                System.out.println("Incorrect argument format, using n = 10.");
            }
        }

        // Allocate source and destination vectors
        int myN = PVector.numPer(n);
        int myLow = PVector.lowPer(n);
        double [1d] myVec1 = new double [myLow:myLow+myN-1];
        double [1d] myVec2 = new double [myLow:myLow+myN-1];
        double [1d] single [1d] allVec1 = new double [0:numP-1][1d];
        double [1d] single [1d] allVec2 = new double [0:numP-1][1d];
        allVec1.exchange(myVec1);
        allVec2.exchange(myVec2);
        // v1[i] = i, v2 = all 2.0
        foreach (i in myVec1.domain()) {
            myVec1[i] = i[1];
        }
        myVec2.set(2.0);
        if (thisP == 0) {
            System.out.println("Diagonal vector v2: " + PVector.toString(allVec2));
        }

        // Test 1: diagonal 2*I matrix times vector
        CSR single twoIdent = new CSR(allVec2);
        Ti.barrier();
        if (thisP == 0) {
            System.out.println("Diagonal matrix: 2*I:\n" + twoIdent);
        }
        // The constructor copied v2, so v2 can be reused as the destination.
        twoIdent.matvec(allVec2, allVec1);
        Ti.barrier();
        if (thisP == 0) {
            System.out.println("Source vector v1: " + PVector.toString(allVec1));
            System.out.println("Result vector v2: " + PVector.toString(allVec2));
        }

        // Test 2: dense matrix (constant entries) times vector
        CSR single allThrees = new CSR(n, 3.0);
        Ti.barrier();
        if (thisP == 0) {
            System.out.println("Constant matrix of all 3.0s:\n" + allThrees);
        }
        Ti.barrier();
        allThrees.matvec(allVec2, allVec1);
        Ti.barrier();
        if (thisP == 0) {
            System.out.println("Source vector v1: " + PVector.toString(allVec1));
            System.out.println("Result vector v2: " + PVector.toString(allVec2));
        }
    }

    /* Prints the command-line usage text. */
    private static void printUsage() {
        System.out.println("CSR \n for nxn test problem");
    }
}