/*
 * This class implements a Compressed-Sparse-Row (CSR) format for a sparse
 * matrix.  Matrices are assumed square in this implementation (in
 * particular the toString method).
 *
 * Storage layout, for an n x n matrix:
 *   data[k]       value of the k-th stored entry; rows are stored one
 *                 after another.
 *   rowIndex[k]   column of the k-th stored entry.  Despite the field's
 *                 name it is used as a column index: matvec uses it to
 *                 index the source vector and toString compares it with
 *                 the column counter.
 *   rowStarts[i]  position in data/rowIndex of the first entry of row i;
 *                 rowStarts[n] is one past the last stored entry.
 */
class CSR implements Sparse {

  /*
   * Wraps already-built CSR arrays.  The three arrays are stored by
   * reference (no copy is made), so later changes made by the caller are
   * visible through this matrix.
   */
  public CSR (double [1d] data, int [1d] rowIndex, int [1d] rowStarts) {
    this.data = data;
    this.rowIndex = rowIndex;
    this.rowStarts = rowStarts;
  }

  /*
   * Builds the diagonal matrix whose diagonal entries are the elements of
   * `diagonal`.  The array is stored by reference, not copied.
   *
   * NOTE(review): the index arrays are built over [0:n-1], so this assumes
   * the domain of `diagonal` is exactly [0:n-1] -- confirm against callers.
   */
  public CSR (double [1d] diagonal) {
    int n = diagonal.domain().size();
    rowIndex = new int [0:n-1];
    rowStarts = new int [0:n];
    data = diagonal;
    foreach (i in rowIndex.domain()) {
      // Row i holds exactly one entry, in column i, stored at position i.
      rowIndex[i] = i[1];
      rowStarts[i] = i[1];
    }
    rowStarts[n] = n;  // last index is off the end
  }

  /*
   * Builds a fully populated n x n matrix in which every entry is `val`.
   * All n*n entries are stored explicitly.
   */
  public CSR (int n, double val) {
    rowStarts = new int [0:n];
    rowIndex = new int [0:n*n-1];
    data = new double [0:n*n-1];
    foreach (i in data.domain()) {
      data[i] = val;
    }
    foreach (i in [0:n-1]) {
      // Row i occupies positions i*n .. i*n+n-1, with columns 0 .. n-1.
      rowStarts[i] = i[1]*n;
      foreach (j in [0:n-1]) {
        rowIndex[i*n+j] = j[1];
      }
    }
    rowStarts[n] = n*n;
  }

  /* Returns the number of rows, i.e. the size of rowStarts minus one. */
  public int dim () {
    return rowStarts.domain().size()-1;
  }

  /*
   * Renders the matrix as dense text: one line per row, every entry
   * followed by a tab, with "0" printed where no entry is stored.
   *
   * The matrix is treated as square (n columns for n rows), and each
   * row's stored column indices must be in increasing order: the cursor
   * into the row only advances when the stored column equals the column
   * being printed, so an out-of-order entry is printed as 0.
   */
  public String toString () {
    // StringBuffer avoids re-copying the whole result on every append.
    StringBuffer result = new StringBuffer();
    int n = rowStarts.domain().size()-1;
    for (int i = 0; i < n; i++) {
      int curr = rowStarts[i];  // next stored entry of row i not yet printed
      for (int j = 0; j < n; j++) {
        if (curr < rowStarts[i+1] && rowIndex[curr] == j) {
          result.append(data[curr]).append("\t");
          curr++;
        } else {
          result.append("0\t");
        }
      }
      result.append("\n");
    }
    return result.toString();
  }

  /*
   * Computes y = A * x, overwriting y.
   *
   * One output element is produced for every point of y's domain, so that
   * domain is expected to be the row range of the matrix.  y[i] is zeroed
   * and then accumulated in place while x is still being read, so x and y
   * must not be the same array.
   */
  public void matvec (double [1d] y, double [1d] x) {
    foreach (i in y.domain()) {
      y[i] = 0;
      // Positions of row i's entries; an empty row yields an empty domain.
      foreach (j in [rowStarts[i]:rowStarts[i+[1]]-1]) {
        y[i] += data[j] * x[rowIndex[j]];
      }
    }
  }

  /*
   * Preconditioner solve.  This implementation applies no preconditioning:
   * it copies x into minvx and does not use mdata.
   */
  public void psolve(double [1d] minvx, Sparse mdata, double [1d] x) {
    minvx.copy(x);
  }

  /* Internal State */
  protected double [1d] data;   // stored entry values
  protected int [1d] rowIndex;  // column of each stored entry (see header)
  protected int [1d] rowStarts; // start of each row in data, plus end marker

  /* Test code */

  /*
   * Small driver: args[0] is the matrix size n.  Multiplies the vector
   * (0, 1, ..., n-1) by 2*I and by the constant matrix of 3s, printing the
   * matrices and vectors.  Prints the usage message and returns if the
   * argument is missing, is not an integer, or is negative.
   */
  public static void tester (String [] args) {
    if (args.length < 1) {
      printUsage();
      return;
    }

    int n;
    try {
      n = Integer.parseInt(args[0]);
    } catch (NumberFormatException e) {
      printUsage();
      return;
    }
    if (n < 0) {
      // A negative size would index outside the (empty) rowStarts array.
      printUsage();
      return;
    }

    // Allocate source and destination vectors
    double [1d] vec1 = new double [0:n-1];
    double [1d] vec2 = new double [0:n-1];
    foreach (i in vec1.domain()) {
      vec1[i] = i[1];
    }

    // Test 1: diagonal 2*I matrix times vector
    double [1d] diag = new double [0:n-1];
    foreach (i in diag.domain()) {
      diag[i] = 2.0;
    }
    CSR twoIdent = new CSR(diag);
    System.out.println("Diagonal matrix: 2*I: \n" + twoIdent);
    twoIdent.matvec(vec2, vec1);
    System.out.println("Source vector v1: \n" + Vector.toString(vec1) + "\n");
    System.out.println("Result vector v2: \n" + Vector.toString(vec2) + "\n");

    // Test 2: dense matrix (constant entries) times vector
    CSR allThrees = new CSR(n, 3.0);
    System.out.println("Constant matrix: \n" + allThrees);
    allThrees.matvec(vec2, vec1);
    System.out.println("Source vector v1: \n" + Vector.toString(vec1) + "\n");
    System.out.println("Result vector v2: \n" + Vector.toString(vec2) + "\n");
  }

  /* Prints the command-line usage of the test driver. */
  private static void printUsage() {
    System.out.println("CSR <n>\n where n is the size of the matrix");
  }
}