I'm curious about the following construct in Java 8:
double[] doubles = //...
double sum = DoubleStream.of(doubles).parallel().sum();
To cut to the chase:
- Will the value of sum always be the same, e.g. when run on different computers?
More background...
Floating point arithmetic is lossy and (unlike real-valued arithmetic) is not associative. So unless care is taken in how the work is divided and reassembled, it could lead to non-deterministic results.
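To make the non-associativity concrete, here is a minimal sketch that sums the same three values with two different groupings. The class and method names are made up for illustration, and the two-chunk method only loosely imitates how a parallel reduction might divide the work; it is not how the stream library actually splits an array.

```java
public class AssociativityDemo {

    // Sums strictly left to right, like a simple for loop.
    public static double sumLeftToRight(double[] values) {
        double sum = 0.0;
        for (double v : values) {
            sum += v;
        }
        return sum;
    }

    // Sums values[0..split) and values[split..length) separately, then adds
    // the two partial sums. The split point is a hypothetical parameter.
    public static double sumInTwoChunks(double[] values, int split) {
        double left = 0.0;
        for (int i = 0; i < split; i++) {
            left += values[i];
        }
        double right = 0.0;
        for (int i = split; i < values.length; i++) {
            right += values[i];
        }
        return left + right;
    }

    public static void main(String[] args) {
        double[] values = {0.1, 0.2, 0.3};
        // (0.1 + 0.2) + 0.3
        System.out.println(sumLeftToRight(values));
        // 0.1 + (0.2 + 0.3)
        System.out.println(sumInTwoChunks(values, 1));
    }
}
```

The two printed values differ in the last bit, which is exactly the kind of discrepancy that a different division of work could introduce.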
I was happy to discover that the sum() method employs Kahan summation under the hood. This significantly reduces the error, but still does not give precise* results.
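For reference, the textbook form of Kahan (compensated) summation looks roughly like the sketch below. This is not the JDK's actual code, and the class and method names are invented for illustration; it only shows the idea of carrying a correction term alongside the running sum.

```java
public class KahanSketch {

    // Plain left-to-right summation, for comparison.
    public static double naiveSum(double[] values) {
        double sum = 0.0;
        for (double v : values) {
            sum += v;
        }
        return sum;
    }

    // Textbook Kahan summation: 'compensation' tracks the low-order bits
    // that were lost when the previous term was added to 'sum'.
    public static double kahanSum(double[] values) {
        double sum = 0.0;
        double compensation = 0.0;
        for (double v : values) {
            double y = v - compensation;   // apply the correction to the next term
            double t = sum + y;            // low-order bits of y may be lost here
            compensation = (t - sum) - y;  // recover what was lost (negated)
            sum = t;
        }
        return sum;
    }

    public static void main(String[] args) {
        // One large value followed by ten values too small to register individually.
        double[] values = new double[11];
        values[0] = 1.0;
        for (int i = 1; i < values.length; i++) {
            values[i] = 1e-16;
        }
        System.out.println("naive: " + naiveSum(values));
        System.out.println("kahan: " + kahanSum(values));
    }
}
```

In this example the naive loop never moves away from 1.0 because each small term is rounded away, while the compensated version accumulates them. The compensated result is still a rounded double, so it remains sensitive, to a much smaller degree, to the order of operations.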
In my testing, repeated calls appear to return the same result each time, but I'd like to know how stable we can safely assume it is, e.g.:
- Stable in all circumstances?
- Stable across computers with the same number of cores?
- Stable only on a given computer?
- Can't depend on it being stable at all?
I'm happy to assume the same JVM version on each computer.
Here's a test I whipped up:
import java.util.Arrays;
import java.util.DoubleSummaryStatistics;
import java.util.Random;
import java.util.stream.DoubleStream;

// The class name is arbitrary; it is only here to make the test compile.
public class ParallelSumTest {

    // Run with -ea, otherwise the assert below is not checked.
    public static void main(String[] args) {
        Random random = new Random(42L);
        for (int j = 1; j < 20; j++) {
            // Stream increases in size and the magnitude of the values at each iteration.
            double[] doubles = generate(random, j*100, j);

            // Like a simple for loop
            double sum1 = DoubleStream.of(doubles).reduce(0, Double::sum);
            double sum2 = DoubleStream.of(doubles).sum();
            double sum3 = DoubleStream.of(doubles).parallel().sum();
            System.out.println(printStats(doubles, sum1, sum2, sum3));

            // Is the parallel computation stable?
            for (int i = 0; i < 1000; i++) {
                double sum4 = DoubleStream.of(doubles).parallel().sum();
                assert sum4 == sum3;
            }
            Arrays.sort(doubles);
        }
    }

    /**
     * @param spread When odd, returns a mix of +ve and -ve numbers.
     *               When even, returns only +ve numbers.
     *               Higher values cause a wider spread of magnitudes in the returned values.
     *               Must not be negative.
     */
    private static double[] generate(Random random, int count, int spread) {
        return random.doubles(count).map(x -> Math.pow(4*x-2, spread)).toArray();
    }

    private static String printStats(double[] doubles, double sum1, double sum2, double sum3) {
        DoubleSummaryStatistics stats = DoubleStream.of(doubles).summaryStatistics();
        return String.format("-----%nMin: %g, Max: %g, Average: %g%n"
                + "Serial difference: %g%n"
                + "Parallel difference: %g",
                stats.getMin(), stats.getMax(), stats.getAverage(), sum2-sum1, sum3-sum1);
    }
}
When I run this, the first few iterations are:
-----
Min: -1.89188, Max: 1.90414, Average: 0.0541140
Serial difference: -2.66454e-15
Parallel difference: -2.66454e-15
-----
Min: 0.000113827, Max: 3.99513, Average: 1.17402
Serial difference: 1.70530e-13
Parallel difference: 1.42109e-13
-----
Min: -7.95673, Max: 7.87757, Average: 0.0658356
Serial difference: 0.00000
Parallel difference: -7.10543e-15
-----
Min: 2.53794e-09, Max: 15.8122, Average: 2.96504
Serial difference: -4.54747e-13
Parallel difference: -6.82121e-13
Notice that while sum2 and sum3 can be assumed to be more accurate than sum1, they might not be the same as each other!
I seeded Random with 42, so if anyone gets a different result to me, that would immediately prove something. :-)
* For the curious...
- Here are some (Python) algorithms that give precise results
- The precise-sum algorithm with the best-sounding performance characteristics I've heard of is given here (ACM subscription or fee required). It takes 5 flops per input, but is written (in C) to exploit instruction-level parallelism and runs only 2-3 times slower than naive summation, which sounds rather good for a precise result. (cf. Kahan summation at 4 flops per input)
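As a point of comparison, a slow but simple way to get an order-independent reference value in Java is to accumulate in BigDecimal, since every finite double converts to BigDecimal exactly. This is not either of the algorithms mentioned above, and the class and method names are invented for illustration; it is only a sketch for checking results against.

```java
import java.math.BigDecimal;

public class ExactSumSketch {

    // Adds the exact decimal expansion of each double, so no rounding occurs
    // during accumulation and the input order cannot affect the total.
    // Only the final conversion back to double rounds.
    public static double exactSum(double[] values) {
        BigDecimal total = BigDecimal.ZERO;
        for (double v : values) {
            // new BigDecimal(double) is an exact conversion; it throws
            // NumberFormatException for NaN and infinities.
            total = total.add(new BigDecimal(v));
        }
        return total.doubleValue();
    }

    public static void main(String[] args) {
        double[] values = {1e16, 1.0, -1e16, 1.0};
        System.out.println(exactSum(values));
    }
}
```

This is far too slow for large inputs, but it can serve as a yardstick for deciding which of the floating-point sums is closer to the true value.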