I would like to improve my small fork/join example to show that work stealing occurs while the Java Fork/Join framework executes it.

What changes do I need to make to the following code? The purpose of the example is simply to do a linear search for a value, splitting the work across multiple threads.

package com.stackoverflow.questions;

import java.util.LinkedList;
import java.util.List;
import java.util.Random;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveTask;

public class CounterFJ<T extends Comparable<T>> extends RecursiveTask<Integer> {

    private static final long serialVersionUID = 5075739389907066763L;
    private List<T> _list;
    private T _test;
    private int _lastCount = -1;
    private int _start;
    private int _end;
    private int _divideFactor = 4;

    private static final int THRESHOLD = 20;

    public CounterFJ(List<T> list, T test, int start, int end, int factor) {
        _list = list;
        _test = test;
        _start = start;
        _end = end;
        _divideFactor = factor;

    }

    public CounterFJ(List<T> list, T test, int factor) {
        this(list, test, 0, list.size(), factor);
    }

    @Override
    protected Integer compute() {

        if (_end - _start < THRESHOLD) {
            int count = 0;
            for (int i = _start; i < _end; i++) {
                if (_list.get(i).compareTo(_test) == 0) {
                    count++;
                }
            }
            _lastCount = count;
            return count; // autoboxed; the Integer constructor is deprecated
        }
        LinkedList<CounterFJ<T>> taskList = new LinkedList<>();

        int step = (_end - _start) / _divideFactor;
        for (int j = 0; j < _divideFactor; j++) {
            CounterFJ<T> task = null;
            if (j == 0)
                task = new CounterFJ<T>(_list, _test, _start, _start + step, _divideFactor);
            else if (j == _divideFactor - 1)
                task = new CounterFJ<T>(_list, _test, _start + (step * j), _end, _divideFactor);
            else
                task = new CounterFJ<T>(_list, _test, _start + (step * j), _start + (step * (j + 1)), _divideFactor);

            // task.fork();
            taskList.add(task);
        }
        invokeAll(taskList);

        _lastCount = 0;
        for (CounterFJ<T> task : taskList) {
            _lastCount += task.join();
        }

        return _lastCount; // autoboxed; the Integer constructor is deprecated

    }

    public int getResult() {
        return _lastCount;
    }

    public static void main(String[] args) {

        LinkedList<Long> list = new LinkedList<Long>();
        long range = 200;
        Random r = new Random(42);

        for (int i = 0; i < 1000; i++) {
            list.add(Long.valueOf((long) (r.nextDouble() * range)));
        }

        CounterFJ<Long> counter = new CounterFJ<>(list, Long.valueOf(100L), 4);

        ForkJoinPool pool = new ForkJoinPool();
        long time = System.currentTimeMillis();
        pool.invoke(counter);

        System.out.println("Fork join counter in " + (System.currentTimeMillis() - time));
        System.out.println("Occurrences:" + counter.getResult());

    }

}

Solution

I finally worked out how to do it, and it is not difficult, so I am leaving this here for future readers.

In the constructor of the RecursiveTask, save the ID of the thread that created the instance. In the compute method, check whether the executing thread is the same one. If it is not, work stealing has occurred.

So I added these member variables:

private long _threadId = -1;
private static int stolen_tasks = 0;

changed the constructor like this:

public CounterFJ(List<T> list, T test, int start, int end, int factor) {
        _list = list;
        _threadId = Thread.currentThread().getId(); //added
        _test = test;
        _start = start;
        _end = end;
        _divideFactor = factor;
}

and added a comparison to the compute method:

   @Override
   protected Integer compute() {

      long thisThreadId = Thread.currentThread().getId();
      if (_threadId != thisThreadId){
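          // stolen_tasks is shared by all worker threads, so guard the increment
          synchronized (CounterFJ.class) {
              stolen_tasks++;
          }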
      }
   // rest of the method
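
For future readers, here is a minimal self-contained sketch of the same idea, not the exact code from the question. The class name StealCounterDemo, the helper countOccurrences, the binary split, and the use of ArrayList<Long> are my own assumptions for illustration. It uses an AtomicInteger because several worker threads may increment the counter at the same time, and it also prints ForkJoinPool.getStealCount(), which the JDK documents as an estimate of the number of tasks run by a thread other than their submitter.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveTask;
import java.util.concurrent.atomic.AtomicInteger;

// Hypothetical demo class: counts occurrences of a value and records how many
// tasks were executed by a thread other than the one that created them.
class StealCounterDemo extends RecursiveTask<Integer> {

    private static final long serialVersionUID = 1L;
    private static final int THRESHOLD = 20;

    // Shared by all worker threads, so an atomic counter is used.
    static final AtomicInteger STOLEN_TASKS = new AtomicInteger();

    private final List<Long> list;
    private final long test;
    private final int start;
    private final int end;
    // ID of the thread that constructed this task.
    private final long creatorThreadId = Thread.currentThread().getId();

    StealCounterDemo(List<Long> list, long test, int start, int end) {
        this.list = list;
        this.test = test;
        this.start = start;
        this.end = end;
    }

    @Override
    protected Integer compute() {
        // A different executing thread means the task moved between threads.
        if (Thread.currentThread().getId() != creatorThreadId) {
            STOLEN_TASKS.incrementAndGet();
        }
        if (end - start < THRESHOLD) {
            int count = 0;
            for (int i = start; i < end; i++) {
                if (list.get(i).longValue() == test) {
                    count++;
                }
            }
            return count;
        }
        int mid = start + (end - start) / 2;
        StealCounterDemo left = new StealCounterDemo(list, test, start, mid);
        StealCounterDemo right = new StealCounterDemo(list, test, mid, end);
        left.fork();                      // queued; another worker may steal it
        int rightCount = right.compute(); // runs directly on the current thread
        return rightCount + left.join();
    }

    // Convenience helper (an assumption of this sketch): runs the count in a fresh pool.
    static int countOccurrences(List<Long> list, long test) {
        ForkJoinPool pool = new ForkJoinPool();
        try {
            return pool.invoke(new StealCounterDemo(list, test, 0, list.size()));
        } finally {
            pool.shutdown();
        }
    }

    public static void main(String[] args) {
        List<Long> list = new ArrayList<>();
        Random r = new Random(42);
        for (int i = 0; i < 100_000; i++) {
            list.add((long) (r.nextDouble() * 200));
        }
        ForkJoinPool pool = new ForkJoinPool();
        int occurrences = pool.invoke(new StealCounterDemo(list, 100L, 0, list.size()));
        System.out.println("Occurrences: " + occurrences);
        System.out.println("Tasks run on a non-creating thread: " + STOLEN_TASKS.get());
        System.out.println("Pool steal count (estimate): " + pool.getStealCount());
        pool.shutdown();
    }
}
```

Two caveats about this sketch. The root task is created on the submitting thread, so if a pool worker runs it, it is counted once even though that is a submission rather than a steal. The numbers also vary from run to run, and with a very small input there may be no steals at all, which is why the sketch uses a larger list than the question does.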
Licensed under: CC-BY-SA with attribution
Not affiliated with StackOverflow