Question

I wrote two matrix classes in Java just to compare the performance of their matrix multiplications. One class (Mat1) stores a double[][] A member where row i of the matrix is A[i]. The other class (Mat2) stores A and T where T is the transpose of A.

Let's say we have a square matrix M and we want the product of M.mult(M). Call the product P.

When M is a Mat1 instance the algorithm used was the straightforward one:

for (int k = 0; k < M.A.length; k++)
    P[i][j] += M.A[i][k] * M.A[k][j];

In the case where M is a Mat2 I used:

P[i][j] += M.A[i][k] * M.T[j][k]

which is the same algorithm because T[j][k]==A[k][j]. On 1000x1000 matrices the second algorithm takes about 1.2 seconds on my machine, while the first one takes at least 25 seconds. I was expecting the second one to be faster, but not by this much. The question is, why is it this much faster?
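For reference, here is a minimal sketch of the two loops described above. The original Mat1 and Mat2 classes were not posted, so every name here (MatMulSketch, multNaive, multTransposed, transpose, random) is hypothetical, and the single-shot timing in main is crude: it has no JIT warm-up, so treat the printed numbers as a rough indication only.

```java
import java.util.Random;

// Hypothetical names throughout; the original Mat1/Mat2 classes were not posted.
final class MatMulSketch {

    // Straightforward product: the inner loop walks down a column of b,
    // so consecutive reads of b[k][j] come from different row arrays.
    static double[][] multNaive(double[][] a, double[][] b) {
        int n = a.length, m = b[0].length, inner = b.length;
        double[][] p = new double[n][m];
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < m; j++) {
                double sum = 0.0;
                for (int k = 0; k < inner; k++) {
                    sum += a[i][k] * b[k][j];
                }
                p[i][j] = sum;
            }
        }
        return p;
    }

    // Same product, but bT is the transpose of b, so both operands are read along rows.
    static double[][] multTransposed(double[][] a, double[][] bT) {
        int n = a.length, m = bT.length, inner = bT[0].length;
        double[][] p = new double[n][m];
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < m; j++) {
                double sum = 0.0;
                for (int k = 0; k < inner; k++) {
                    sum += a[i][k] * bT[j][k];
                }
                p[i][j] = sum;
            }
        }
        return p;
    }

    static double[][] transpose(double[][] a) {
        double[][] t = new double[a[0].length][a.length];
        for (int i = 0; i < a.length; i++) {
            for (int j = 0; j < a[0].length; j++) {
                t[j][i] = a[i][j];
            }
        }
        return t;
    }

    // Fills an n x n matrix with pseudo-random values from a fixed seed.
    static double[][] random(int n, long seed) {
        Random rnd = new Random(seed);
        double[][] a = new double[n][n];
        for (double[] row : a) {
            for (int j = 0; j < n; j++) {
                row[j] = rnd.nextDouble();
            }
        }
        return a;
    }

    public static void main(String[] args) {
        int n = 500; // smaller than the 1000 in the question, to keep the run short
        double[][] a = random(n, 42L);
        double[][] t = transpose(a);

        long t0 = System.nanoTime();
        multNaive(a, a);
        long t1 = System.nanoTime();
        multTransposed(a, t);
        long t2 = System.nanoTime();

        System.out.printf("naive: %d ms, transposed: %d ms%n",
                (t1 - t0) / 1_000_000, (t2 - t1) / 1_000_000);
    }
}
```

Both methods add the same terms in the same order, so they should produce the same product; only the memory access pattern differs.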

My only guess is that the second one makes better use of the CPU caches. Data is pulled into the caches in chunks larger than one word, and the second algorithm benefits from this by traversing only rows. The first one ignores the data just pulled into the cache and goes immediately to the row below, which is ~1000 words away in memory (because arrays are stored in row-major order) and none of which is cached.

I asked someone and he thought it was because of friendlier memory access patterns (i.e. that the second version would result in fewer TLB soft faults). I didn't think of this at all but I can sort of see how it results in fewer TLB faults.

So, which is it? Or is there some other reason for the performance difference?

Solution

This is because of the locality of your data.

In RAM a matrix, although two-dimensional from your point of view, is stored as linear memory. The only difference from a 1D array is that the offset is calculated by combining the two indices that you use. (Strictly speaking, a Java double[][] is an array of references to separate row arrays, so only each row is guaranteed to be contiguous, but the locality argument is the same.)

This means that if you access the element at position x,y, the offset used to reference it is x*row_length + y.
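To make the offset calculation concrete, here is a small sketch of a matrix stored in one flat row-major array. FlatMatrix and its members are hypothetical names used only for illustration; this is not how the JVM lays out a double[][], just the x*row_length + y idea written out by hand.

```java
// Hypothetical illustration of row-major storage in a single flat array.
final class FlatMatrix {
    final int rows;
    final int cols;      // plays the role of row_length
    final double[] data; // row x occupies data[x*cols] .. data[x*cols + cols - 1]

    FlatMatrix(int rows, int cols) {
        this.rows = rows;
        this.cols = cols;
        this.data = new double[rows * cols];
    }

    // Offset of element (x, y): skip x full rows, then move y elements into the row.
    int offset(int x, int y) {
        return x * cols + y;
    }

    double get(int x, int y) {
        return data[offset(x, y)];
    }

    void set(int x, int y, double value) {
        data[offset(x, y)] = value;
    }

    public static void main(String[] args) {
        FlatMatrix m = new FlatMatrix(3, 4);
        m.set(2, 1, 7.5);
        System.out.println(m.offset(2, 1)); // 2*4 + 1, prints 9
        System.out.println(m.get(2, 1));    // prints 7.5
    }
}
```

Moving one step along a row changes the offset by 1, while moving one step down a column changes it by cols, which is where the large stride in the slow version comes from.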

A big matrix doesn't fit in a single page of memory (pages are the chunks into which your OS splits RAM), let alone in a single cache line, so whenever you access an element that is not already in the CPU cache, the block of memory containing it has to be loaded first.

As long as you traverse the data contiguously during your multiplication this causes no problems, since you use most of the coefficients in a block before moving on to the next one. If you invert the indices, however, every single element may live in a different block, and possibly a different page, so the CPU has to go back to RAM for almost every multiplication you do. This is why the difference is so stark.
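A rough way to see the size of the effect is to count how many distinct cache lines each access pattern touches. The sketch below assumes a 64-byte cache line, 8-byte doubles, a flat row-major layout and an array that starts on a line boundary; all of these are assumptions for illustration, and CacheLineCount and linesTouched are hypothetical names.

```java
import java.util.HashSet;
import java.util.Set;

// Hypothetical back-of-the-envelope model; it does not measure real hardware.
final class CacheLineCount {
    static final int LINE_BYTES = 64;  // assumed cache-line size
    static final int DOUBLE_BYTES = 8; // size of a Java double

    // Number of distinct cache lines touched when reading `count` doubles,
    // starting at element 0 and advancing `strideElems` elements per read.
    static int linesTouched(int count, int strideElems) {
        Set<Long> lines = new HashSet<>();
        for (int i = 0; i < count; i++) {
            long byteOffset = (long) i * strideElems * DOUBLE_BYTES;
            lines.add(byteOffset / LINE_BYTES);
        }
        return lines.size();
    }

    public static void main(String[] args) {
        int n = 1000;
        // Walking along a row: stride of 1 element.
        System.out.println("row walk:    " + linesTouched(n, 1) + " lines");
        // Walking down a column: stride of one full row (n elements).
        System.out.println("column walk: " + linesTouched(n, n) + " lines");
    }
}
```

Under these assumptions a row walk over 1000 doubles touches 125 lines (8 doubles per line), while a column walk touches 1000 lines, one per element read.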

(I have simplified the whole explanation quite a bit; it's just to give you the basic idea behind this problem.)

In any case I don't think this is caused by the JVM itself. It may be related to how your OS manages the memory of the Java process.

OTHER TIPS

The cache and TLB hypotheses are both reasonable, but I'd like to see the complete code of your benchmark ... not just pseudo-code snippets.

Another possibility is that the performance difference is a result of your application using 50% more memory for the data arrays in the version with the transpose. If your JVM's heap size is small, it is possible that this is causing the GC to run too often. This could well be a result of using the default heap size. (Three lots of 1000 x 1000 x 8 bytes is ~24 MB.)

Try setting the initial and max heap sizes to (say) double the current max size. If that makes no difference, then this is not a simple heap size issue.
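As a sketch of how you might check this, the standard -Xms and -Xmx launcher options set the initial and maximum heap size, for example java -Xms512m -Xmx512m HeapInfo (512m is just an arbitrary example value, and HeapInfo is a hypothetical class name). The small program below prints what the running JVM actually reports, next to the approximate size of the three matrices.

```java
// Hypothetical helper for inspecting heap limits; not part of the original benchmark.
final class HeapInfo {

    static long toMiB(long bytes) {
        return bytes / (1024L * 1024L);
    }

    // Raw payload of `count` square n x n double matrices, ignoring object headers.
    static long matrixBytes(int n, int count) {
        return (long) count * n * n * Double.BYTES;
    }

    public static void main(String[] args) {
        Runtime rt = Runtime.getRuntime();
        System.out.println("max heap (MiB):       " + toMiB(rt.maxMemory()));
        System.out.println("current heap (MiB):   " + toMiB(rt.totalMemory()));
        System.out.println("free in heap (MiB):   " + toMiB(rt.freeMemory()));
        System.out.println("3 matrices (MiB):     " + toMiB(matrixBytes(1000, 3)));
    }
}
```

If the reported maximum is comfortably larger than the matrix data and changing the flags makes no difference to the timings, heap size is unlikely to be the cause.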

It's easy to guess that the problem might be locality, and maybe it is, but that's still a guess.

It's not necessary to guess. Two techniques might give you the answer - single stepping and random pausing.

If you single-step the slow code you might find out that it's doing a lot of stuff you never dreamed of. Such as, you ask? Try it and find out. What you should see it doing, at the machine-language level, is efficiently stepping through the inner loop with no waste motion.

If it actually is stepping through the inner loop with no waste motion, then random pausing will give you information. Since the slow one is taking 20 times longer than the fast one, that implies 95% of the time it is doing something it doesn't have to. So see what it is. Each time you pause it, the chance is 95% that you will see what that is, and why.

If in the slow case, the instructions it is executing appear just as efficient as the fast case, then cache locality is a reasonable guess of why it is slow. I'm sure, once you've eliminated any other silliness that may be going on, that cache locality will dominate.

You might try comparing performance between JDK6 and OpenJDK7, given this set of results...

Licensed under: CC-BY-SA with attribution