Question

I am trying to build on one problem to solve a similar one. Below is code that finds the total number of subsets that sum to a particular value, and I am trying to modify it so that it returns all the subsets that sum to that value (instead of just the count).

Code for finding the total number of subsets that sum to 'sum':

/**
 * Method to return the number of subsets with a given sum.
 */
public static int count = 0;
public static void countSubsetSum2(int arr[], int k, int sum) {
    if(sum == 0) {
        count++;
        return;
    }
    if(sum != 0 && k == 0) {
        return;
    }
    if(sum < arr[k - 1]) {
        countSubsetSum2(arr, k-1, sum);
    }
    countSubsetSum2(arr, k-1, sum - arr[k-1]);
    countSubsetSum2(arr, k-1, sum);
}

Can someone propose some changes to this code, to make it return the subsets rather than the subset count?


Solution

Firstly, your code isn't correct.

At every step, the function recurses with the sum both including and excluding the current element (see note 1 below), then moves on to the next element, thanks to these lines:

countSubsetSum2(arr, k-1, sum - arr[k-1]);
countSubsetSum2(arr, k-1, sum);

But then there's also this:

if(sum < arr[k - 1]) {
    countSubsetSum2(arr, k-1, sum);
}

which, whenever sum < arr[k - 1], causes it to recurse twice with the sum excluding the current element (which it should never do).

Essentially you just need to remove that if-statement.

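If all the elements are positive and sum - arr[k-1] < 0, we'd keep recursing, but we could never get back to a sum of 0, because including more elements only makes the remaining sum smaller. So we'd be doing a lot of unnecessary work. If the elements are all positive, we can therefore guard the first call with if(arr[k - 1] <= sum) to improve the running time. If the elements aren't all positive, the code won't find all subsets (for example, it returns as soon as the remaining sum hits 0, so it misses subsets that continue with zeros or with elements that cancel each other out).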
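Putting the two fixes together (removing the duplicated if-statement and guarding the first call), a minimal sketch of the corrected counting code could look like the following. It assumes all elements are positive and keeps the question's static `count` and method name; the class name `SubsetCount` is hypothetical.

```java
// Hypothetical class name; assumes every element of arr is positive.
public class SubsetCount {
    public static int count = 0;

    // Counts the subsets of arr[0..k-1] that sum to `sum`.
    public static void countSubsetSum2(int[] arr, int k, int sum) {
        if (sum == 0) {
            count++;  // the elements chosen so far sum exactly to the target
            return;
        }
        if (k == 0) {
            return;   // no elements left, and sum != 0
        }
        // Include arr[k - 1], but only if it doesn't overshoot the target
        if (arr[k - 1] <= sum) {
            countSubsetSum2(arr, k - 1, sum - arr[k - 1]);
        }
        // Exclude arr[k - 1]
        countSubsetSum2(arr, k - 1, sum);
    }

    public static void main(String[] args) {
        int[] arr = {1, 2, 3, 4, 5};
        countSubsetSum2(arr, arr.length, 9);
        System.out.println(count);  // 3: {4, 5}, {2, 3, 4}, {1, 3, 5}
    }
}
```

Each call now makes exactly one "include" and one "exclude" decision per element, so every subset is counted once.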

Now, on to printing the subsets

If you understand the code well, changing it to print the subsets instead should be fairly easy. I suggest you work on understanding it a bit more: trace by hand what the program does, then trace what you want it to do.

And a hint for solving the actual problem: note that countSubsetSum2(arr, k-1, sum - arr[k-1]); recurses with the current element included in the subset (while the other recursive call recurses with it excluded). Once you see that, it should become clear what to do.


1: Well, technically it's reversed (we start with the target sum and decrease to 0 instead of starting at 0 and increasing to sum), but the same idea is there.

Other answers

This is code that works (assuming all the elements are positive); it prints each subset as soon as it is found:

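import java.util.ArrayList;
import java.util.List;

public class Subset {
    // Number of subsets found so far
    public static int count = 0;
    // Elements chosen so far on the current recursion path (higher indices first)
    public static List<Integer> list = new ArrayList<>();

    // Prints every subset of arr[0..k] that sums to `sum`, followed by the
    // elements already in `list`. Assumes all elements are positive.
    public static void countSubsetSum2(int[] arr, int k, int sum) {
        if (sum <= 0 || k < 0) {
            return;  // overshot the target or ran out of elements
        }
        if (sum == arr[k]) {
            // arr[k] completes a subset: print it with the elements chosen so far
            count++;
            System.out.print(arr[k]);
            for (int x : list)
                System.out.print("\t" + x);
            System.out.println();
        }
        // Include arr[k] ...
        list.add(arr[k]);
        countSubsetSum2(arr, k - 1, sum - arr[k]);
        list.remove(list.size() - 1);  // remove(int index): undo the include
        // ... or exclude it
        countSubsetSum2(arr, k - 1, sum);
    }

    public static void main(String[] args) {
        int[] array = {1, 4, 5, 6};
        countSubsetSum2(array, 3, 10);
    }
}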

First off, the code you have there doesn't seem to actually work (I tested it on input [1,2,3, ..., 10] with a sum of 3 and it output 128).

To get it working, first note that you implemented the algorithm in a pretty unorthodox way. Mathematical functions take input and produce output. (Arguably) the most elegant programming functions should also take input and produce output because then we can reason about them as we reason about math.

In your case you don't produce any output (the return type is void) and instead store the result in a static variable. This means it's hard to tell exactly what it means to call countSubsetSum2. In particular, what happens if you call it multiple times? It does something different each time (because the count variable will have a different starting value!) Instead, if you write countSubsetSum2 so that it returns a value then you can define its behavior to be: countSubsetSum2 returns the number of subsets of the input arr[0...k] that sum to sum. And then you can try proving why your implementation meets that specification.
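To make the "different each time" point concrete, here is a small sketch that calls a static-counter version twice with identical arguments. It uses the question's recursion with the duplicated if-statement removed; the class name `StaticCounterDemo` is hypothetical, and positive elements are assumed.

```java
// Hypothetical demo class; assumes every element of arr is positive.
public class StaticCounterDemo {
    // Static result variable, as in the question's code
    public static int count = 0;

    // Counts subsets of arr[0..k-1] summing to `sum`, but accumulates into
    // the static `count` instead of returning a value.
    public static void countSubsetSum2(int[] arr, int k, int sum) {
        if (sum == 0) {
            count++;
            return;
        }
        if (k == 0) {
            return;
        }
        countSubsetSum2(arr, k - 1, sum - arr[k - 1]);
        countSubsetSum2(arr, k - 1, sum);
    }

    public static void main(String[] args) {
        int[] arr = {1, 2, 3, 4, 5};
        countSubsetSum2(arr, arr.length, 9);
        System.out.println(count);  // 3
        countSubsetSum2(arr, arr.length, 9);  // exactly the same call again...
        System.out.println(count);  // ...prints 6, because count was never reset
    }
}
```

The answer to "how many subsets of {1, 2, 3, 4, 5} sum to 9?" hasn't changed, but what the program reports depends on how many times the method was called before.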

I'm not doing the greatest job of explaining, but I think a more natural way to write it would be:

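// Returns the number of subsets of arr[0..k] that sum to "sum"
public static int countSubsetSum2(int[] arr, int k, int sum) {
    // Algorithm stops once k reaches the first element of the array
    if (k == 0) {
        int ways = 0;
        if (sum == 0) {
            // The empty subset sums to 0
            ways++;
        }
        if (sum == arr[k]) {
            // The singleton {arr[0]} sums to "sum"
            ways++;
        }
        // ways is 0 if we can't sum to "sum" (and 2 if sum == 0 == arr[0])
        return ways;
    }

    // Otherwise, let's recursively see if we can sum to "sum"

    // Any valid subset either includes arr[k]
    return countSubsetSum2(arr, k - 1, sum - arr[k]) +
           // Or it doesn't
           countSubsetSum2(arr, k - 1, sum);
}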

As described above, this function takes an input and outputs a value that we can define, and whose correctness we can prove mathematically (caveat: it's usually not quite a proof, because most programming languages unfortunately have odd edge cases, such as integer overflow).

Anyway, back to your question. The issue with the above code is that it doesn't store any data; it just returns the count. Instead, let's build the actual subsets as we go. In particular, when I say "Any valid subset either includes arr[k]", I mean that the subset we're generating includes arr[k], so we add it. Below I've assumed that the code you wrote above is Java-ish. Hopefully it makes sense:

// Requires java.util.ArrayList and java.util.List.
// Returns every subset of arr[0..k] that sums to "sum".
// Subsets are stored as lists (not sets) so duplicate values in arr are kept.
public static List<List<Integer>> genSubsetSum(int[] arr, int k, int sum) {
    // Algorithm stops once k reaches the first element of the array
    if (k == 0) {
        // Return a list of all of the subsets of arr[0..0] that sum to "sum".
        // There are a few edge cases here, so we need to be careful.
        List<List<Integer>> ret = new ArrayList<>();

        // First consider if the singleton containing arr[k] could equal sum
        if (sum == arr[k]) {
            List<Integer> subSet = new ArrayList<>();
            subSet.add(arr[k]);
            ret.add(subSet);
        }

        // Now consider the empty set
        if (sum == 0) {
            ret.add(new ArrayList<>());
        }

        // If neither case applies, we can't sum to "sum" using just arr[0],
        // so the list stays empty (given our inputs!)
        return ret;
    }

    // Otherwise, let's recursively generate subsets summing to "sum"

    // Any valid subset either includes arr[k]
    List<List<Integer>> subsetsThatNeedKthElement = genSubsetSum(arr, k - 1, sum - arr[k]);
    // Or it doesn't
    List<List<Integer>> completeSubsets = genSubsetSum(arr, k - 1, sum);

    // Note that subsetsThatNeedKthElement only sum to "sum" - arr[k]... so we need to add
    // arr[k] to each of those subsets to create subsets which sum to "sum".
    // On the other hand, completeSubsets contains subsets which already sum to "sum",
    // so they're "complete".

    // Initialize it with the completed subsets
    List<List<Integer>> ret = new ArrayList<>(completeSubsets);
    // Now augment the incomplete subsets and add them to the final list
    for (List<Integer> subset : subsetsThatNeedKthElement) {
        subset.add(arr[k]);
        ret.add(subset);
    }

    return ret;
}

The code is pretty cluttered with all the comments, but the key point is that this implementation always returns what it's specified to return: a list of all the subsets of arr[0..k] that sum to whatever sum was passed in.
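Here is a self-contained sketch of the same generator in compilable Java, with the comments condensed and a `main` that calls it twice to show it has no hidden state. Subsets are stored as `List<Integer>` (an assumption: lists rather than sets keep duplicate values in `arr` distinct), and the class name `SubsetGenerator` is hypothetical.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical class name.
public class SubsetGenerator {
    // Returns every subset of arr[0..k] that sums to `sum`.
    public static List<List<Integer>> genSubsetSum(int[] arr, int k, int sum) {
        if (k == 0) {
            List<List<Integer>> ret = new ArrayList<>();
            if (sum == arr[0]) {
                List<Integer> singleton = new ArrayList<>();
                singleton.add(arr[0]);
                ret.add(singleton);          // {arr[0]}
            }
            if (sum == 0) {
                ret.add(new ArrayList<>());  // the empty subset
            }
            return ret;
        }
        // Subsets that include arr[k] (each still missing arr[k] itself)...
        List<List<Integer>> needKth = genSubsetSum(arr, k - 1, sum - arr[k]);
        // ...and subsets that don't include arr[k]
        List<List<Integer>> ret = new ArrayList<>(genSubsetSum(arr, k - 1, sum));
        for (List<Integer> subset : needKth) {
            subset.add(arr[k]);
            ret.add(subset);
        }
        return ret;
    }

    public static void main(String[] args) {
        int[] arr = {1, 2, 3, 4, 5};
        // Same call twice, same result: prints [[2, 3, 4], [1, 3, 5], [4, 5]] both times
        System.out.println(genSubsetSum(arr, arr.length - 1, 9));
        System.out.println(genSubsetSum(arr, arr.length - 1, 9));
    }
}
```

Mutating the subsets in `needKth` is safe here because every recursive call builds fresh lists, so nothing is shared between the two branches.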

FYI, there is another, "bottom-up" approach (i.e. one that doesn't use recursion) which should be more performant. If you implement it that way, you need to store extra data in a table (a "memoized table"), which is a bit ugly but practical. However, when you implement it this way, you need a more clever way of generating the subsets. Feel free to ask about that in a separate post after giving it a try.
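For the counting part, a minimal bottom-up sketch might look like this. It is the standard 0/1 subset-sum table (one row, updated once per element), not necessarily what the answerer had in mind, and it assumes non-negative elements and a non-negative target; the class and method names are hypothetical. The table is a local array here, and generating the subsets themselves from it is left out.

```java
// Hypothetical names; assumes all elements are >= 0 and target >= 0.
public class BottomUpSubsetCount {
    // Counts subsets of arr that sum to target, without recursion.
    public static long countSubsets(int[] arr, int target) {
        // ways[s] = number of subsets of the elements processed so far that sum to s
        long[] ways = new long[target + 1];
        ways[0] = 1;  // the empty subset
        for (int x : arr) {
            // Walk s downward so each element is used at most once per subset
            for (int s = target; s >= x; s--) {
                ways[s] += ways[s - x];
            }
        }
        return ways[target];
    }

    public static void main(String[] args) {
        System.out.println(countSubsets(new int[] {1, 2, 3, 4, 5}, 9));  // 3
    }
}
```

This runs in O(n × target) time and O(target) space, instead of exploring all 2^n include/exclude paths.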

Based on the comments and suggestions here, I was able to solve this problem as follows:

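// Requires java.util.ArrayList and java.util.List
public static int counter = 0;
public static List<List<Integer>> lists = new ArrayList<>();

// Collects every subset of arr[0..k-1] that sums to targetSum;
// "list" holds the elements chosen so far on this recursion path.
public static void getSubsetCountThatSumToTargetValue(int[] arr, int k, int targetSum, List<Integer> list) {
    if (targetSum == 0) {
        // The chosen elements sum to the target: record this subset
        counter++;
        lists.add(list);
        return;
    }

    if (k <= 0) {
        // No elements left to choose from
        return;
    }

    // Exclude arr[k - 1]
    getSubsetCountThatSumToTargetValue(arr, k - 1, targetSum, list);

    // Include arr[k - 1], in a copy so the list used by the other branch is untouched
    List<Integer> appendedlist = new ArrayList<>(list);
    appendedlist.add(arr[k - 1]);
    getSubsetCountThatSumToTargetValue(arr, k - 1, targetSum - arr[k - 1], appendedlist);
}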

The main method looks like this:

public static void main(String[] args) {

    int[] arr = {1, 2, 3, 4, 5};
    SubSetSum.getSubsetCountThatSumToTargetValue(arr, 5, 9, new ArrayList<Integer>());
    System.out.println("Result count: " + counter);
    System.out.println("lists: " + lists);

}

Output:

Result count: 3
lists: [[4, 3, 2], [5, 3, 1], [5, 4]]

A Python implementation, with k moving from 0 to len(numbers) - 1 (it assumes all the numbers are positive):

def sum_of_subsets(numbers, sum_original):

    def _sum_of_subsets(chosen, k, remaining):
        # Stop if we overshot the target or ran out of numbers
        if remaining < 0 or k == len(numbers):
            return

        if remaining == numbers[k]:
            # numbers[k] completes a subset: print it (sorted ascending)
            expression = " + ".join(str(num) for num in sorted(chosen + [numbers[k]]))
            print("%d = %s" % (sum_original, expression))
        else:
            # Include numbers[k] and recurse on the rest
            chosen.append(numbers[k])
            _sum_of_subsets(chosen, k + 1, remaining - numbers[k])
            chosen.pop()

        # Exclude numbers[k]; this must run even after a match, otherwise
        # subsets such as 1 + 2 in sum_of_subsets([3, 2, 1], 3) are missed
        _sum_of_subsets(chosen, k + 1, remaining)

    _sum_of_subsets([], 0, sum_original)

...

sum_of_subsets( [ 8, 6, 3, 4, 2, 5, 7, 1, 9, 11, 10, 13, 12, 14, 15 ], 15 )

...

15 = 1 + 6 + 8
15 = 3 + 4 + 8
15 = 1 + 2 + 4 + 8
15 = 2 + 5 + 8
15 = 7 + 8
15 = 2 + 3 + 4 + 6
15 = 1 + 3 + 5 + 6
15 = 4 + 5 + 6
15 = 2 + 6 + 7
15 = 6 + 9
15 = 1 + 2 + 3 + 4 + 5
15 = 1 + 3 + 4 + 7
15 = 1 + 2 + 3 + 9
15 = 2 + 3 + 10
15 = 3 + 5 + 7
15 = 1 + 3 + 11
15 = 3 + 12
15 = 2 + 4 + 9
15 = 1 + 4 + 10
15 = 4 + 11
15 = 1 + 2 + 5 + 7
15 = 1 + 2 + 12
15 = 2 + 13
15 = 1 + 5 + 9
15 = 5 + 10
15 = 1 + 14
15 = 15
Licensed under: CC-BY-SA with attribution
Not affiliated with StackOverflow