Question

As part of a class code project, part of the assignment is to construct an algorithm that will multiply two matrices using the Strassen method.

I have created the algorithms necessary to do so, however I am getting a Stack Overflow error on a getSize() call. Any help with figuring out why this is would be appreciated.

I have not appended all code, just that which pertains to the question.

    // CODE CALLING THE MULTIPLICATION
    Matrix m = new Matrix(Integer.parseInt(size.getText()), 20);
    Matrix m2 = new Matrix(Integer.parseInt(size.getText()), 20);
    Matrix m3 = m.strassenMult(m2);

    // THE MATRIX CLASS
    public class Matrix {

    private int[][] vals;
    private int size;
    private boolean stable = false;
    private long time;

    // Constructor for known data
    public Matrix(int[][] vals) {
    this.vals = vals;
    this.size = vals.length;
    stable = true;
    }

    // Constructor for matrix with random size, if bounds = 0 it is empty
    public Matrix(int size, int bounds) {
    this.size = size;
    vals = new int[size][size];
    stable = true;
    if (bounds != 0) {
        Random rand = new Random(System.currentTimeMillis());
        for (int i = 0; i < size; i++) {
            for (int j = 0; j < size; j++) {
                int x = rand.nextInt();
                x = x % bounds;
                vals[j][i] = x;
            }
        }
    }
    }

    private Matrix add(Matrix m2) {
    ///////////////////////// THE ERROR IS ON THE BELOW LINE
    int n = m2.getSize();
    int[][] newVals = new int[n][n];
    for (int i = 0; i < n; i++) {
        for (int j = 0; j < n; j++) {
            newVals[i][j] = vals[i][j] + m2.getVal(i, j);
        }
    }
    Matrix sum = new Matrix(newVals);
    return sum;
    }
    private Matrix subtract(Matrix m2) {
    int n = m2.getSize();
    int[][] newVals = new int[n][n];
    for (int i = 0; i < n; i++) {
        for (int j = 0; j < n; j++) {
            newVals[i][j] = vals[i][j] - m2.getVal(i, j);
        }
    }
    return new Matrix(newVals);
    }


    // The strassen multiplication algorithm
    public Matrix strassenMult(Matrix m2) {
    int n = m2.getSize();
    int newSize = n / 2;
    // initialize new sub-matricies
    Matrix a11 = new Matrix(newSize, 0);
    Matrix a12 = new Matrix(newSize, 0);
    Matrix a21 = new Matrix(newSize, 0);
    Matrix a22 = new Matrix(newSize, 0);
    Matrix b11 = new Matrix(newSize, 0);
    Matrix b12 = new Matrix(newSize, 0);
    Matrix b21 = new Matrix(newSize, 0);
    Matrix b22 = new Matrix(newSize, 0);
    Matrix aResult = new Matrix(newSize, 0);
    Matrix bResult = new Matrix(newSize, 0);

    // divide existing matries into the sub-matricies
    for (int i = 0; i < newSize; i++) {
        for (int j = 0; j < newSize; j++) {
            a11.set(i, j, vals[i][j]); // top left
            a12.set(i, j, vals[i][j + newSize]); // top right
            a21.set(i, j, vals[i + newSize][j]); // bottom left
            a22.set(i, j, vals[i + newSize][j + newSize]); // bottom right

            b11.set(i, j, m2.getVal(i, j)); // top left
            b12.set(i, j,m2.getVal(i, j + newSize)); // top right
            b21.set(i, j, m2.getVal(i + newSize, j)); // bottom left
            b22.set(i, j, m2.getVal(i + newSize, j + newSize)); // bottom right
        }
    }
    // Calculating p1 to p7:
    ////////////////////////////////// ERROR IS CALLED BY THIS METHOD
    aResult = a11.add(a22);
    bResult = b11.add(b22);
    Matrix p1 = aResult.strassenMult(bResult);
    // p1 = (a11+a22) * (b11+b22)

    aResult = a21.add(a22); // a21 + a22
    Matrix p2 = aResult.strassenMult(b11); // p2 = (a21+a22) * (b11)

    bResult = b12.subtract(b22); // b12 - b22
    Matrix p3 = a11.strassenMult(bResult);
    // p3 = (a11) * (b12 - b22)

    bResult = b21.subtract(b11); // b21 - b11
    Matrix p4 = a22.strassenMult(bResult);
    // p4 = (a22) * (b21 - b11)

    aResult = a11.add(a12); // a11 + a12
    Matrix p5 = aResult.strassenMult(b22);
    // p5 = (a11+a12) * (b22)

    aResult = a21.subtract(a11); // a21 - a11
    bResult = b11.add(b12); // b11 + b12
    Matrix p6 = aResult.strassenMult(bResult);
    // p6 = (a21-a11) * (b11+b12)

    aResult = a12.subtract(a22); // a12 - a22
    bResult = b21.add(b22); // b21 + b22
    Matrix p7 = aResult.strassenMult(bResult);
    // p7 = (a12-a22) * (b21+b22)

    // calculating c21, c21, c11 e c22:
    Matrix c12 = p3.add(p5); // c12 = p3 + p5
    Matrix c21 = p2.add(p4); // c21 = p2 + p4

    aResult = p1.add(p4); // p1 + p4
    bResult = aResult.add(p7); // p1 + p4 + p7
    Matrix c11 = bResult.subtract(p5);
    // c11 = p1 + p4 - p5 + p7

    aResult = p1.add(p3); // p1 + p3
    bResult = aResult.add(p6); // p1 + p3 + p6
    Matrix c22 = bResult.subtract(p2);
    // c22 = p1 + p3 - p2 + p6

    // Grouping the results obtained in a single matrix:
    Matrix result = new Matrix(n, 0);
    for (int i = 0; i < newSize; i++) {
        for (int j = 0; j < newSize; j++) {
            result.set(i, j, c11.getVal(i, j));
            result.set(i, j + newSize, c12.getVal(i, j));
            result.set(i + newSize, j, c21.getVal(i, j));
            result.set(i + newSize, j + newSize, c22.getVal(i, j));
        }
    }
    return result;
    }

    public int getSize() {
    return size;
    }
Was it helpful?

Solution

Your Strassen multiplication function is missing a base case. At some point, you need to stop recursing and call a different matrix multiplication algorithm.

Licensed under: CC-BY-SA with attribution
Not affiliated with StackOverflow
scroll top