Optimisation of recursive algorithm in Java
26-06-2021
Question
Background
I have an ordered set of data points stored as a TreeSet<DataPoint>. Each data point has a position and a Set of Event objects (HashSet<Event>).
There are 4 possible Event objects: A, B, C, and D. Every DataPoint has 2 of these, e.g. A and C, except the first and last DataPoint objects in the set, which have T of size 1.
My algorithm is to find the probability of a new DataPoint Q at position x having Event q in this set.
I do this by calculating a value S for this data set, then adding Q to the set and calculating S again. I then divide the second S by the first to isolate the probability for the new DataPoint Q.
Algorithm
The formula for calculating S is:
http://mathbin.net/equations/105225_0.png
where
http://mathbin.net/equations/105225_1.png
http://mathbin.net/equations/105225_2.png
for http://mathbin.net/equations/105225_3.png
and
http://mathbin.net/equations/105225_4.png
Here http://mathbin.net/equations/105225_5.png is an expensive probability function that only depends on its arguments and nothing else (and http://mathbin.net/equations/105225_6.png), http://mathbin.net/equations/105225_7.png is the last DataPoint in the set (righthand node), http://mathbin.net/equations/105225_8.png is the first DataPoint (lefthand node), http://mathbin.net/equations/105225_9.png is the rightmost DataPoint that isn't the node, http://mathbin.net/equations/105225_10.png is a DataPoint, and http://mathbin.net/equations/105225_12.png is the Set of events for this DataPoint.
So the probability for Q with Event q is:
http://mathbin.net/equations/105225_11.png
Implementation
I implemented this algorithm in Java like so:
import java.util.NavigableSet;

public class ProbabilityCalculator {

    private Double p(DataPoint right, Event rightEvent, DataPoint left, Event leftEvent) {
        // do some stuff
    }

    // Recursive term: combines p for the pair (right, left) with f for the left neighbour.
    private Double f(DataPoint right, Event rightEvent, NavigableSet<DataPoint> points) {
        DataPoint left = points.lower(right);
        Double result = 0.0;
        if (left.isLefthandNode()) {
            result = 0.25 * p(right, rightEvent, left, null);
        } else if (left.isQ()) {
            result = p(right, rightEvent, left, left.getQEvent()) * f(left, left.getQEvent(), points);
        } else { // if M_k
            for (Event leftEvent : left.getEvents()) {
                result += p(right, rightEvent, left, leftEvent) * f(left, leftEvent, points);
            }
        }
        return result;
    }

    public Double S(NavigableSet<DataPoint> points) {
        return f(points.last(), points.last().getRightNodeEvent(), points);
    }
}
So to find the probability of Q at x with q:
Double S1 = S(points);
points.add(Q);
Double S2 = S(points);
Double probability = S2/S1;
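In case the formula images do not load, the recurrence can be read back off the Java code. The notation here is an assumption and may differ from the original images: D_r is the current (right) data point with event e_r, D_l is its left neighbour, and T_l is the event set of D_l.

```latex
f(D_r, e_r) =
\begin{cases}
\tfrac{1}{4}\, p(D_r, e_r, D_l, \varnothing) & \text{if } D_l \text{ is the lefthand node} \\[4pt]
p(D_r, e_r, D_l, q)\, f(D_l, q) & \text{if } D_l = Q \\[4pt]
\sum_{e \in T_l} p(D_r, e_r, D_l, e)\, f(D_l, e) & \text{otherwise}
\end{cases}
\qquad
S = f(D_{\text{last}}, e_{\text{last}}),
\qquad
P(Q \text{ has } q) = \frac{S_2}{S_1}
```

S_1 is S before Q is added and S_2 is S after, matching the S1 and S2 variables in the snippet.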
Problem
As it stands, the implementation follows the mathematical algorithm closely. However, this turns out not to be a particularly good idea in practice, as f calls itself twice for each DataPoint. So for http://mathbin.net/equations/105225_9.png, f is called twice, then for the (n-1)th DataPoint f is called twice again for each of the previous calls, and so on. This leads to a complexity of O(2^n), which is pretty terrible considering there can be over 1000 DataPoints in each Set. Because p() is independent of everything except its parameters, I have included a caching function: if p() has already been calculated for these parameters, it just returns the previous result. But this doesn't solve the inherent complexity problem. Am I missing something here with regard to repeated computations, or is the complexity unavoidable in this algorithm?
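To see why caching p() alone leaves the blow-up in place, the call counts can be compared on a toy recursion with the same shape as f. This is only a sketch, not the original code: ToyRecursion, weight, naive and memoized are made-up names, data points are reduced to indices with two events each, and weight stands in for p.

```java
import java.util.HashMap;
import java.util.Map;

class ToyRecursion {
    static long calls;
    static final Map<Long, Double> cache = new HashMap<>();

    // Made-up stand-in for the expensive p(): a weight for a pair of events.
    static double weight(int rightEvent, int leftEvent) {
        return rightEvent == leftEvent ? 0.6 : 0.4;
    }

    // Same shape as f: each call branches over the two events of the left neighbour.
    static double naive(int index, int rightEvent) {
        calls++;
        if (index == 1) {
            return 0.25; // the left neighbour is the lefthand node
        }
        double result = 0.0;
        for (int leftEvent = 0; leftEvent < 2; leftEvent++) {
            result += weight(rightEvent, leftEvent) * naive(index - 1, leftEvent);
        }
        return result;
    }

    // Identical recursion, but each (index, rightEvent) pair is computed only once.
    static double memoized(int index, int rightEvent) {
        long key = 2L * index + rightEvent;
        Double cached = cache.get(key);
        if (cached != null) {
            return cached;
        }
        calls++; // counts cache misses only
        double result = 0.0;
        if (index == 1) {
            result = 0.25;
        } else {
            for (int leftEvent = 0; leftEvent < 2; leftEvent++) {
                result += weight(rightEvent, leftEvent) * memoized(index - 1, leftEvent);
            }
        }
        cache.put(key, result);
        return result;
    }

    static long countNaive(int n) {
        calls = 0;
        naive(n, 0);
        return calls;
    }

    static long countMemoized(int n) {
        calls = 0;
        cache.clear();
        memoized(n, 0);
        return calls;
    }

    public static void main(String[] args) {
        System.out.println("naive calls:    " + countNaive(20));
        System.out.println("memoized calls: " + countMemoized(20));
    }
}
```

For n = 20 the naive version makes 2^20 - 1 = 1048575 calls, while the memoized one computes only 2 * 20 - 1 = 39 distinct values. Caching the weight function would not change the first number, because the repeated work is in the recursion itself.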
Solution
Thanks for all your suggestions. I implemented my solution by creating new nested classes for the values of P and F already calculated, then used a HashMap to store the results. The HashMap is then queried for the result before computation takes place; if it is present it just returns the result, if it is not it computes the result and adds it to the HashMap.
The final product looks a bit like this:
import java.util.HashMap;
import java.util.Map;
import java.util.NavigableSet;

public class ProbabilityCalculator {

    private NavigableSet<DataPoint> points;

    private ProbabilityCalculator(NavigableSet<DataPoint> points) {
        this.points = points;
    }

    // Cache key for p(): the two data points and their events.
    private static class P {
        public final DataPoint left;
        public final Event leftEvent;
        public final DataPoint right;
        public final Event rightEvent;

        public P(DataPoint left, Event leftEvent, DataPoint right, Event rightEvent) {
            this.left = left;
            this.leftEvent = leftEvent;
            this.right = right;
            this.rightEvent = rightEvent;
        }

        @Override
        public boolean equals(Object o) {
            if (!(o instanceof P)) return false;
            P p = (P) o;
            if (!(this.leftEvent == null ? p.leftEvent == null : this.leftEvent.equals(p.leftEvent)))
                return false;
            if (!(this.rightEvent == null ? p.rightEvent == null : this.rightEvent.equals(p.rightEvent)))
                return false;
            return this.left.equals(p.left) && this.right.equals(p.right);
        }

        @Override
        public int hashCode() {
            int result = 93;
            result = 31 * result + this.left.hashCode();
            result = 31 * result + this.right.hashCode();
            result = this.leftEvent != null ? 31 * result + this.leftEvent.hashCode() : 31 * result;
            result = this.rightEvent != null ? 31 * result + this.rightEvent.hashCode() : 31 * result;
            return result;
        }
    }

    private Map<P, Double> usedPs = new HashMap<P, Double>();

    // Cache key for f(): the data point, its event, and the points to its left.
    private static class F {
        public final DataPoint dataPoint;
        public final Event dataPointEvent;
        public final NavigableSet<DataPoint> dataPointsToLeft;

        public F(DataPoint dataPoint, Event dataPointEvent, NavigableSet<DataPoint> dataPointsToLeft) {
            this.dataPoint = dataPoint;
            this.dataPointEvent = dataPointEvent;
            this.dataPointsToLeft = dataPointsToLeft;
        }

        @Override
        public boolean equals(Object o) {
            if (!(o instanceof F)) return false;
            F f = (F) o;
            return this.dataPoint.equals(f.dataPoint) && this.dataPointEvent.equals(f.dataPointEvent) && this.dataPointsToLeft.equals(f.dataPointsToLeft);
        }

        @Override
        public int hashCode() {
            int result = 7;
            result = 31 * result + this.dataPoint.hashCode();
            result = 31 * result + this.dataPointEvent.hashCode();
            result = 31 * result + this.dataPointsToLeft.hashCode();
            return result;
        }
    }

    private Map<F, Double> usedFs = new HashMap<F, Double>();

    private Double p(DataPoint right, Event rightEvent, DataPoint left, Event leftEvent) {
        P newP = new P(left, leftEvent, right, rightEvent);
        if (this.usedPs.containsKey(newP)) return usedPs.get(newP);
        // do some stuff
        usedPs.put(newP, result);
        return result;
    }

    private Double f(DataPoint right, Event rightEvent) {
        // Note: headSet returns a live view of points, so it reflects later additions to the set.
        NavigableSet<DataPoint> dataPointsToLeft = points.headSet(right, false);
        F newF = new F(right, rightEvent, dataPointsToLeft);
        if (usedFs.containsKey(newF)) return usedFs.get(newF);
        DataPoint left = points.lower(right);
        Double result = 0.0;
        if (left.isLefthandNode()) {
            result = 0.25 * p(right, rightEvent, left, null);
        } else if (left.isQ()) {
            result = p(right, rightEvent, left, left.getQEvent()) * f(left, left.getQEvent());
        } else { // if M_k
            for (Event leftEvent : left.getEvents()) {
                result += p(right, rightEvent, left, leftEvent) * f(left, leftEvent);
            }
        }
        usedFs.put(newF, result);
        return result;
    }

    public Double S() {
        return f(points.last(), points.last().getRightNodeEvent());
    }

    public static Double probabilityOfQ(DataPoint q, NavigableSet<DataPoint> points) {
        ProbabilityCalculator pc = new ProbabilityCalculator(points);
        Double S1 = pc.S();
        points.add(q);
        Double S2 = pc.S();
        return S2 / S1;
    }
}
Other answers
You also need to memoize f on the first 2 arguments (the 3rd is always passed through, so you don't need to worry about that). This will reduce the time complexity of your code from O(2^n) to O(n).
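Memoizing f on its first two arguments could look roughly like the sketch below. It is not the asker's code. Point, Event, PairProbability and MemoizedF are hypothetical stand-ins, and the points are assumed to sit in a List already sorted by position, so that an index can replace points.lower(right). Q is assumed to be modelled as a point with a single event, which makes the Q branch a one-element loop.

```java
import java.util.Arrays;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;

class MemoizedF {
    enum Event { A, B, C, D }

    // Hypothetical stand-in for DataPoint: a position plus its events.
    static final class Point {
        final double position;
        final Set<Event> events;

        Point(double position, Set<Event> events) {
            this.position = position;
            this.events = events;
        }
    }

    // Hypothetical stand-in for the expensive p(right, rightEvent, left, leftEvent).
    interface PairProbability {
        double p(Point right, Event rightEvent, Point left, Event leftEvent);
    }

    private final List<Point> points; // sorted by position; index 0 is the lefthand node
    private final PairProbability pairProbability;
    private final Map<Long, Double> cache = new HashMap<>();

    MemoizedF(List<Point> points, PairProbability pairProbability) {
        this.points = points;
        this.pairProbability = pairProbability;
    }

    // f for the point at rightIndex (must be >= 1), cached per (rightIndex, rightEvent).
    double f(int rightIndex, Event rightEvent) {
        long key = (long) rightIndex * Event.values().length + rightEvent.ordinal();
        Double cached = cache.get(key);
        if (cached != null) {
            return cached;
        }
        Point right = points.get(rightIndex);
        Point left = points.get(rightIndex - 1);
        double result = 0.0;
        if (rightIndex - 1 == 0) {
            // The left neighbour is the lefthand node.
            result = 0.25 * pairProbability.p(right, rightEvent, left, null);
        } else {
            for (Event leftEvent : left.events) {
                result += pairProbability.p(right, rightEvent, left, leftEvent)
                        * f(rightIndex - 1, leftEvent);
            }
        }
        cache.put(key, result);
        return result;
    }

    // S starts from the righthand node, which is assumed to carry exactly one event.
    double s() {
        int last = points.size() - 1;
        Event lastEvent = points.get(last).events.iterator().next();
        return f(last, lastEvent);
    }

    public static void main(String[] args) {
        List<Point> points = Arrays.asList(
                new Point(0.0, EnumSet.of(Event.A)),
                new Point(1.0, EnumSet.of(Event.A, Event.C)),
                new Point(2.0, EnumSet.of(Event.B, Event.D)),
                new Point(3.0, EnumSet.of(Event.C)));
        MemoizedF calculator = new MemoizedF(points, (r, re, l, le) -> 0.5);
        System.out.println(calculator.s());
    }
}
```

The cache is only valid for one fixed list of points, so a fresh MemoizedF would be needed after Q is inserted. With more than 1000 points the recursion depth grows with the list, so an iterative left-to-right pass over the same table may be preferable.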
UPDATED:
Since, as commented below, order cannot be used to help optimize, another method must be used. Since most of the P values will be calculated multiple times (and, as noted, this is expensive), one optimization would be to cache them. I am not sure what the best key would be, but you could imagine changing the code to something like:
....
private Map<String, Double> previousResultMap = new ....

private Double p(DataPoint right, Event rightEvent, DataPoint left, Event leftEvent) {
    String key = // calculate unique key from inputs
    Double previousResult = previousResultMap.get(key);
    if (previousResult != null) {
        return previousResult;
    }
    // do some stuff
    previousResultMap.put(key, result);
    return result;
}
This approach should eliminate a lot of the redundant calculations. However, as you know the data much better than I do, you will need to determine the best way to build the key (and whether String is even the best representation for it).
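One alternative to a String key is a small value class holding the four arguments, with equals and hashCode delegated to java.util.Objects so that a null event is handled. The sketch below is only an illustration: PairKey and PCache are hypothetical names, and the arguments are typed as Object because DataPoint and Event are not shown here. It assumes those types implement equals and hashCode consistently.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.function.DoubleSupplier;

// Hypothetical composite key for p(right, rightEvent, left, leftEvent).
final class PairKey {
    private final Object right;
    private final Object rightEvent;
    private final Object left;
    private final Object leftEvent; // may be null for the lefthand node

    PairKey(Object right, Object rightEvent, Object left, Object leftEvent) {
        this.right = right;
        this.rightEvent = rightEvent;
        this.left = left;
        this.leftEvent = leftEvent;
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        if (!(o instanceof PairKey)) return false;
        PairKey other = (PairKey) o;
        return Objects.equals(right, other.right)
                && Objects.equals(rightEvent, other.rightEvent)
                && Objects.equals(left, other.left)
                && Objects.equals(leftEvent, other.leftEvent);
    }

    @Override
    public int hashCode() {
        return Objects.hash(right, rightEvent, left, leftEvent);
    }
}

// Hypothetical cache wrapper: runs the expensive computation once per key.
final class PCache {
    private final Map<PairKey, Double> results = new HashMap<>();

    double get(PairKey key, DoubleSupplier expensive) {
        // The supplier must not touch this map; that holds for p, which is not recursive.
        return results.computeIfAbsent(key, k -> expensive.getAsDouble());
    }

    int size() {
        return results.size();
    }
}
```

A call such as cache.get(new PairKey(right, rightEvent, left, leftEvent), () -> expensiveP(...)) would then evaluate the supplier only on a miss, where expensiveP is a placeholder for the real computation. The same pattern should not be used for the recursive f, because HashMap.computeIfAbsent does not allow the mapping function to modify the map.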