Question

I'm trying to implement a multithreaded merge sort in Java. The idea is to hand each recursive call to a new thread. Everything works properly, but the problem is that the regular single-threaded version appears to be much faster. Please help me fix it. I've tried playing with .join(), but it hasn't brought any success. My code:

import java.util.logging.Level;
import java.util.logging.Logger;

public class MergeThread implements Runnable {

    private final int begin;
    private final int end;

    public MergeThread(int b, int e) {
        this.begin = b;
        this.end = e;
    }

    @Override
    public void run() {
        try {
            MergeSort.mergesort(begin, end);
        } catch (InterruptedException ex) {
            Logger.getLogger(MergeThread.class.getName()).log(Level.SEVERE, null, ex);
        }
    }
}

public class MergeSort {
    private static volatile int[] numbers;
    private static volatile int[] helper;

    private int number;

    public void sort(int[] values) throws InterruptedException {
        MergeSort.numbers = values;
        number = values.length;
        MergeSort.helper = new int[number];
        mergesort(0, number - 1);
    }

    public static void mergesort(int low, int high) throws InterruptedException {
        // check if low is smaller than high, if not then the array
        // is sorted
        if (low < high) {
            // Get the index of the element which is in the middle
            int middle = low + (high - low) / 2;
            // Sort each half of the array in its own new thread
            Thread left = new Thread(new MergeThread(low, middle));
            Thread right = new Thread(new MergeThread(middle + 1, high));

            left.start();
            right.start();
            left.join();
            right.join();

            // combine the sides
            merge(low, middle, high);
        }
    }

    private static void merge(int low, int middle, int high) {
        // Copy both parts into the helper array
        for (int i = low; i <= high; i++) {
            helper[i] = numbers[i];
        }

        int i = low;
        int j = middle + 1;
        int k = low;
        // Copy the smallest value from either the left or right side
        // back to the original array
        while (i <= middle && j <= high) {
            if (helper[i] <= helper[j]) {
                numbers[k] = helper[i];
                i++;
            } else {
                numbers[k] = helper[j];
                j++;
            }
            k++;
        }
        // Copy the rest of the left side of the array
        // (leftover right-side elements are already in their final place)
        while (i <= middle) {
            numbers[k] = helper[i];
            k++;
            i++;
        }
    }

    public static void main(String[] args) throws InterruptedException {
        int[] array = new int[1000];
        for(int pos = 0; pos<1000; pos++) {
            array[pos] = 1000-pos;
        }
        long start = System.currentTimeMillis();
        new MergeSort().sort(array);
        long finish = System.currentTimeMillis();

        for(int i = 0; i<array.length; i++) {
            System.out.print(array[i]+" ");
        }
        System.out.println();
        System.out.println(finish-start);

    }
}

Solution

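There are several factors here. First of all, you are spawning too many threads, far more than the number of cores your processor has. If I understand your algorithm correctly, every call on a range of more than one element starts two new threads, so the recursion tree is about log2(n) levels deep and ends with roughly n threads at the bottom level. For n = 1000 that is close to 2000 threads in total.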
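To make the thread count concrete, here is a small sketch that mirrors the recursion of the question's mergesort(low, high) but only counts the threads instead of starting them. The class and method names are made up for illustration.

```java
// Hypothetical helper: counts the threads that the question's
// mergesort(low, high) would start, without starting any.
final class ThreadCount {
    static long threadsStarted(int low, int high) {
        if (low >= high) {
            return 0; // one element or fewer: the method returns without spawning
        }
        int middle = low + (high - low) / 2;
        // two new threads for this call, plus whatever the two halves start
        return 2 + threadsStarted(low, middle) + threadsStarted(middle + 1, high);
    }
}
```

For the 1000-element array in the question, ThreadCount.threadsStarted(0, 999) returns 1998, that is 2 * (n - 1).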

Given that you're doing processor-intensive computations that involve no I/O, once the thread count passes the number of cores, performance starts degrading pretty fast. Several thousand threads will slow the VM down and can eventually crash it.

If you want to actually benefit from a multi-core processor in this computation, use a fixed-size thread pool (bounded by the number of cores or thereabouts) or an equivalent thread-reuse policy.
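One way to get such a thread-reuse policy is the JDK's fork/join framework. The sketch below is an assumption-laden illustration, not the only way to do it: the class name ForkJoinMergeSort and the THRESHOLD value are invented here, and the cutoff would need tuning on real hardware. It uses a ForkJoinPool rather than a plain fixed-size pool because tasks that block while waiting for their own subtasks can exhaust a plain fixed pool, whereas fork/join is designed for this kind of recursive splitting.

```java
import java.util.Arrays;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveAction;

final class ForkJoinMergeSort {
    // Assumed cutoff: ranges smaller than this are sorted sequentially.
    private static final int THRESHOLD = 8_192;

    static void sort(int[] a) {
        int[] helper = new int[a.length];
        // Pool size bounded by the number of cores.
        ForkJoinPool pool = new ForkJoinPool(Runtime.getRuntime().availableProcessors());
        try {
            pool.invoke(new SortTask(a, helper, 0, a.length - 1));
        } finally {
            pool.shutdown();
        }
    }

    private static final class SortTask extends RecursiveAction {
        private final int[] a;
        private final int[] helper;
        private final int low;
        private final int high;

        SortTask(int[] a, int[] helper, int low, int high) {
            this.a = a;
            this.helper = helper;
            this.low = low;
            this.high = high;
        }

        @Override
        protected void compute() {
            if (high - low < THRESHOLD) {
                Arrays.sort(a, low, high + 1); // toIndex is exclusive
                return;
            }
            int middle = low + (high - low) / 2;
            // Runs both halves as pool tasks and waits for both to finish.
            invokeAll(new SortTask(a, helper, low, middle),
                      new SortTask(a, helper, middle + 1, high));
            merge(a, helper, low, middle, high);
        }
    }

    private static void merge(int[] a, int[] helper, int low, int middle, int high) {
        System.arraycopy(a, low, helper, low, high - low + 1);
        int i = low, j = middle + 1, k = low;
        while (i <= middle && j <= high) {
            a[k++] = (helper[i] <= helper[j]) ? helper[i++] : helper[j++];
        }
        while (i <= middle) {
            a[k++] = helper[i++];
        }
        // Leftover right-side elements are already in their final place.
    }
}
```

Calling ForkJoinMergeSort.sort(array) sorts the array in place. With the 1000-element input from the question the whole range falls under the cutoff, so no tasks are forked at all, which is the intended behaviour for small inputs.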

Second, if you want a valid comparison, you should use computations that last longer (sorting 1000 numbers doesn't qualify). Otherwise, the cost of creating threads makes up a significant share of the measured time.
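A rough way to set up a fairer comparison is sketched below. It is only a minimal harness under stated assumptions: the class name SortTiming and its methods are hypothetical, it repeats each measurement and keeps the best time to soften JIT warm-up effects, and it is no substitute for a proper benchmarking tool.

```java
import java.util.Random;
import java.util.function.Consumer;

final class SortTiming {
    /** Returns the best wall-clock time in milliseconds over the given number of runs. */
    static double bestMillis(Consumer<int[]> sorter, int[] input, int runs) {
        long best = Long.MAX_VALUE;
        for (int r = 0; r < runs; r++) {
            int[] copy = input.clone(); // every run sorts the same unsorted data
            long start = System.nanoTime();
            sorter.accept(copy);
            best = Math.min(best, System.nanoTime() - start);
        }
        return best / 1e6;
    }

    /** Builds a reproducible random input of n elements. */
    static int[] randomInput(int n, long seed) {
        return new Random(seed).ints(n).toArray();
    }
}
```

For example, with int[] input = SortTiming.randomInput(5_000_000, 1L), compare SortTiming.bestMillis(java.util.Arrays::sort, input, 10) against the same call with java.util.Arrays::parallelSort, or with your own serial and threaded sorts wrapped as lambdas.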

Other suggestions

Set the thread count up front, to the number of cores or fewer.

The link below also includes a performance analysis.

Here is a good example: https://courses.cs.washington.edu/courses/cse373/13wi/lectures/03-13/MergeSort.java
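As a minimal sketch of the thread-budget idea (written for this answer, not copied from the linked file), the version below hands each recursive call a budget of threads and falls back to a sequential sort once the budget is down to one. The class name BoundedThreadMergeSort and the choice of Arrays.sort for the sequential part are assumptions.

```java
import java.util.Arrays;

final class BoundedThreadMergeSort {
    static void sort(int[] a) {
        int cores = Runtime.getRuntime().availableProcessors();
        sort(a, new int[a.length], 0, a.length - 1, cores);
    }

    // threads = how many threads this call may use, including the current one
    private static void sort(int[] a, int[] helper, int low, int high, int threads) {
        if (low >= high) {
            return;
        }
        if (threads <= 1) {
            Arrays.sort(a, low, high + 1); // budget used up: sort sequentially
            return;
        }
        int middle = low + (high - low) / 2;
        int leftThreads = threads / 2;
        int rightThreads = threads - leftThreads;

        // One new thread for the left half; the current thread takes the right half.
        Thread left = new Thread(() -> sort(a, helper, low, middle, leftThreads));
        left.start();
        sort(a, helper, middle + 1, high, rightThreads);
        try {
            left.join();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException("interrupted while sorting", e);
        }
        merge(a, helper, low, middle, high);
    }

    private static void merge(int[] a, int[] helper, int low, int middle, int high) {
        System.arraycopy(a, low, helper, low, high - low + 1);
        int i = low, j = middle + 1, k = low;
        while (i <= middle && j <= high) {
            a[k++] = (helper[i] <= helper[j]) ? helper[i++] : helper[j++];
        }
        while (i <= middle) {
            a[k++] = helper[i++];
        }
        // Leftover right-side elements are already in their final place.
    }
}
```

With a budget of 4, this runs on 4 threads in total (the caller plus 3 new ones), however large the array is.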

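Below is an iterative (bottom-up) serial version of MergeSort. It avoids recursion overhead, which usually makes it faster than the recursive version, and it never computes a midpoint from low and high, so it avoids the overflow that calculation can cause. Overflow can still occur in other index arithmetic (for example LIMIT << 1) for very large arrays. Note that the method swaps its two buffers after every pass, so the sorted data may end up in the temporary array rather than the one you passed in: always use the returned array. You can try parallelizing it if you are interested.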

protected static int[] ASC(int input_array[]) // Sorts in ascending order
{
    int num = input_array.length;
    int[] temp_array = new int[num];
    int temp_indx;
    int left;
    int mid,j;
    int right;
    int[] swap;
    int LIMIT = 1;
    while (LIMIT < num)
    {
        left = 0;
        mid = LIMIT ; // The mid point
        right = LIMIT << 1;
        while (mid < num)
        {
            if (right > num){ right = num; }
            temp_indx = left;
            j = mid;
            while ((left < mid) && (j < right))
            {
                if (input_array[left] < input_array[j]){  temp_array[temp_indx++] = input_array[left++];  }
                else{  temp_array[temp_indx++] = input_array[j++];  }
            }
            while (left < mid){  temp_array[temp_indx++] = input_array[left++];  }
            while (j < right){  temp_array[temp_indx++] = input_array[j++];  }

            // Do not copy back the elements to input_array
            left = right;
            mid = left + LIMIT;
            right = mid + LIMIT;
        }
        // Instead of copying back in previous loop, copy remaining elements to temp_array, then swap the array pointers
        while (left < num){  temp_array[left] = input_array[left++];  }

        swap = input_array;
        input_array = temp_array;
        temp_array = swap;

        LIMIT <<= 1;
    }
    return input_array ;
}

Use the Java ExecutorService. It reuses a pool of threads instead of creating a new thread per task, which is a lot faster, even with more threads than cores (you can build scalable multithreaded applications with it). I have code that uses only raw threads, and it is very slow. I'm new to executors, so I can't help much, but it's an interesting area to explore.
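As a starting point, here is a minimal ExecutorService sketch. It is a simplification made under stated assumptions rather than a tuned implementation: the class name ExecutorChunkSort is invented, each chunk is sorted with Arrays.sort in a pool task, and the sorted chunks are then merged one after another on the calling thread, which is acceptable only while the number of chunks stays small.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

final class ExecutorChunkSort {
    /** Returns a new sorted array; the input is left untouched. */
    static int[] sort(int[] input) {
        int workers = Runtime.getRuntime().availableProcessors();
        int chunk = Math.max(1, (input.length + workers - 1) / workers);
        ExecutorService pool = Executors.newFixedThreadPool(workers);
        try {
            List<Callable<int[]>> tasks = new ArrayList<>();
            for (int from = 0; from < input.length; from += chunk) {
                final int lo = from;
                final int hi = Math.min(input.length, from + chunk);
                tasks.add(() -> {
                    int[] part = Arrays.copyOfRange(input, lo, hi);
                    Arrays.sort(part); // each chunk is sorted on a pool thread
                    return part;
                });
            }
            int[] result = new int[0];
            // invokeAll blocks until every chunk has been sorted.
            for (Future<int[]> sortedChunk : pool.invokeAll(tasks)) {
                result = merge(result, sortedChunk.get());
            }
            return result;
        } catch (InterruptedException | ExecutionException e) {
            if (e instanceof InterruptedException) {
                Thread.currentThread().interrupt();
            }
            throw new IllegalStateException("parallel sort failed", e);
        } finally {
            pool.shutdown();
        }
    }

    /** Merges two sorted arrays into a new sorted array. */
    private static int[] merge(int[] x, int[] y) {
        int[] out = new int[x.length + y.length];
        int i = 0, j = 0, k = 0;
        while (i < x.length && j < y.length) {
            out[k++] = (x[i] <= y[j]) ? x[i++] : y[j++];
        }
        while (i < x.length) {
            out[k++] = x[i++];
        }
        while (j < y.length) {
            out[k++] = y[j++];
        }
        return out;
    }
}
```

The pool never holds more than one thread per core, and no task waits on another task, so the fixed-size pool cannot starve itself.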

Parallelism also has a cost, because thread management is expensive, so only go parallel for large N. If you are looking for a serial alternative to merge sort, I suggest Dual-Pivot-QuickSort or 3-Partition-Quick-Sort, as they are known to beat merge sort often. The reason is that they have lower constant factors than MergeSort, and their quadratic worst case is very unlikely on randomly ordered input: its probability falls off roughly like 1/(n!), so for large N you almost always get average-case behaviour. You could multithread both and see which of the 4 programs (1 serial and 1 multithreaded each for DPQ and 3PQ) runs the fastest.

However, Dual-Pivot-QuickSort works best when there are no, or almost no, duplicate keys, and 3-Partition-Quick-Sort works best when there are many duplicate keys. I have never seen 3-Partition-Quick-Sort beat Dual-Pivot-QuickSort when there are no or very few duplicate keys, but I have seen Dual-Pivot-QuickSort beat 3-Partition-Quick-Sort a small number of times when there are many duplicate keys. In case you are interested, serial DPQ code is below (both ascending and descending). The div parameter is the divisor used to pick the two pivot candidates (third = len / div), so the initial call would presumably be ASC(a, 0, a.length - 1, 3).

protected static void ASC(int[]a, int left, int right, int div)
{
    int len = 1 + right - left;
    if (len < 27)
    {
        // insertion sort for small array
        int P1 = left + 1;
        int P2 = left;
        while ( P1 <= right )
        {
            div = a[P1];
            while(( P2 >= left )&&( a[P2] > div ))
            {
                a[P2 + 1] = a[P2];
                P2--;
            }
            a[P2 + 1] = div;
            P2 = P1;
            P1++;
        }
        return;
    }
    int third = len / div;
    // "medians"
    int P1 = left + third;
    int P2 = right - third;
    if (P1 <= left)
    {
        P1 = left + 1;
    }
    if (P2 >= right)
    {
        P2 = right - 1;
    }
    int temp;
    if (a[P1] < a[P2])
    {
        temp = a[P1]; a[P1] = a[left]; a[left] = temp;
        temp = a[P2]; a[P2] = a[right]; a[right] = temp;
    }
    else
    {
        temp = a[P1];  a[P1] = a[right];  a[right] = temp;
        temp = a[P2];  a[P2] = a[left];  a[left] = temp;
    }
    // pivots
    int pivot1 = a[left];
    int pivot2 = a[right];
    // pointers
    int less = left + 1;
    int great = right - 1;
    // sorting
    for (int k = less; k <= great; k++)
    {
        if (a[k] < pivot1)
        {
            temp = a[k];  a[k] = a[less];  a[less] = temp;
            less++;
        }
        else if (a[k] > pivot2)
        {
            while (k < great && a[great] > pivot2)
            {
                great--;
            }
            temp = a[k];  a[k] = a[great];  a[great] = temp;
            great--;
            if (a[k] < pivot1)
            {
                temp = a[k];  a[k] = a[less];  a[less] = temp;
                less++;
            }
        }
    }
    int dist = great - less;
    if (dist < 13)
    {
        div++;
    }
    temp = a[less-1];  a[less-1] = a[left];  a[left] = temp;
    temp = a[great+1];  a[great+1] = a[right];  a[right] = temp;
    // subarrays
    ASC(a, left, less - 2, div);
    ASC(a, great + 2, right, div);
    // equal elements
    if (dist > len - 13 && pivot1 != pivot2)
    {
        for (int k = less; k <= great; k++)
        {
            if (a[k] == pivot1)
            {
                temp = a[k];  a[k] = a[less];  a[less] = temp;
                less++;
            }
            else if (a[k] == pivot2)
            {
                temp = a[k];  a[k] = a[great];  a[great] = temp;
                great--;
                if (a[k] == pivot1)
                {
                    temp = a[k];  a[k] = a[less];  a[less] = temp;
                    less++;
                }
            }
        }
    }
    // subarray
    if (pivot1 < pivot2)
    {
        ASC(a, less, great, div);
    }
}

protected static void DSC(int[]a, int left, int right, int div)
{
    int len = 1 + right - left;
    if (len < 27)
    {
        // insertion sort for small array
        int P1 = left + 1;
        int P2 = left;
        while ( P1 <= right )
        {
            div = a[P1];
            while(( P2 >= left )&&( a[P2] < div ))
            {
                a[P2 + 1] = a[P2];
                P2--;
            }
            a[P2 + 1] = div;
            P2 = P1;
            P1++;
        }
        return;
    }
    int third = len / div;
    // "medians"
    int P1 = left + third;
    int P2 = right - third;
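    // These are index bounds checks, not key comparisons, so they must stay
    // the same as in ASC; flipping them would always force P1 = left + 1
    // and P2 = right - 1.
    if (P1 <= left)
    {
        P1 = left + 1;
    }
    if (P2 >= right)
    {
        P2 = right - 1;
    }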
    int temp;
    if (a[P1] > a[P2])
    {
        temp = a[P1]; a[P1] = a[left]; a[left] = temp;
        temp = a[P2]; a[P2] = a[right]; a[right] = temp;
    }
    else
    {
        temp = a[P1];  a[P1] = a[right];  a[right] = temp;
        temp = a[P2];  a[P2] = a[left];  a[left] = temp;
    }
    // pivots
    int pivot1 = a[left];
    int pivot2 = a[right];
    // pointers
    int less = left + 1;
    int great = right - 1;
    // sorting
    for (int k = less; k <= great; k++)
    {
        if (a[k] > pivot1)
        {
            temp = a[k];  a[k] = a[less];  a[less] = temp;
            less++;
        }
        else if (a[k] < pivot2)
        {
            while (k < great && a[great] < pivot2)
            {
                great--;
            }
            temp = a[k];  a[k] = a[great];  a[great] = temp;
            great--;
            if (a[k] > pivot1)
            {
                temp = a[k];  a[k] = a[less];  a[less] = temp;
                less++;
            }
        }
    }
    int dist = great - less;
    if (dist < 13)
    {
        div++;
    }
    temp = a[less-1];  a[less-1] = a[left];  a[left] = temp;
    temp = a[great+1];  a[great+1] = a[right];  a[right] = temp;
    // subarrays
    DSC(a, left, less - 2, div);
    DSC(a, great + 2, right, div);
    // equal elements
    if (dist > len - 13 && pivot1 != pivot2)
    {
        for (int k = less; k <= great; k++)
        {
            if (a[k] == pivot1)
            {
                temp = a[k];  a[k] = a[less];  a[less] = temp;
                less++;
            }
            else if (a[k] == pivot2)
            {
                temp = a[k];  a[k] = a[great];  a[great] = temp;
                great--;
                if (a[k] == pivot1)
                {
                    temp = a[k];  a[k] = a[less];  a[less] = temp;
                    less++;
                }
            }
        }
    }
    // subarray
    if (pivot1 > pivot2)
    {
        DSC(a, less, great, div);
    }
}
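For comparison with the dual-pivot code above, here is a minimal sketch of a quicksort with three-way partitioning, assuming that is what "3-Partition-Quick-Sort" refers to. The class name ThreeWayQuickSort and the middle-element pivot choice are assumptions made for illustration. Keys equal to the pivot are gathered in the middle and never revisited, which is why this variant does well on inputs with many duplicates.

```java
final class ThreeWayQuickSort {
    static void sort(int[] a) {
        sort(a, 0, a.length - 1);
    }

    private static void sort(int[] a, int left, int right) {
        if (left >= right) {
            return;
        }
        int pivot = a[left + (right - left) / 2];
        // Invariant: a[left..lt-1] < pivot, a[lt..i-1] == pivot, a[gt+1..right] > pivot
        int lt = left;
        int i = left;
        int gt = right;
        while (i <= gt) {
            if (a[i] < pivot) {
                swap(a, lt++, i++);
            } else if (a[i] > pivot) {
                swap(a, i, gt--); // do not advance i: the swapped-in key is unexamined
            } else {
                i++;
            }
        }
        // Keys equal to the pivot are already in their final positions.
        sort(a, left, lt - 1);
        sort(a, gt + 1, right);
    }

    private static void swap(int[] a, int x, int y) {
        int temp = a[x];
        a[x] = a[y];
        a[y] = temp;
    }
}
```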
Licensed under: CC-BY-SA with attribution
Not affiliated with StackOverflow