Question

I want to train Mahout for classification. My text comes from a database, and I really don't want to write it to files just for Mahout training. I checked out the MIA (Mahout in Action) source code and adapted the following code for a very basic training task. The usual issue with Mahout examples is that they either show how to use Mahout from the command line with the 20 Newsgroups data set, or the code has a lot of dependencies on Hadoop, ZooKeeper, etc. I would really appreciate it if someone could have a look at my code, or point me to a very simple tutorial that shows how to train a model and then use it.

As of now, the following code never gets past if (best != null), because learningAlgorithm.getBest() always returns null!

Sorry for posting the whole code, but I didn't see any other option.

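// Package names assume a pre-1.0 (0.x) Mahout release plus Guava; adjust for your version.
import com.google.common.collect.LinkedListMultimap;
import com.google.common.collect.ListMultimap;
import java.util.Iterator;
import java.util.List;
import org.apache.mahout.classifier.sgd.AdaptiveLogisticRegression;
import org.apache.mahout.classifier.sgd.CrossFoldLearner;
import org.apache.mahout.classifier.sgd.L1;
import org.apache.mahout.classifier.sgd.ModelSerializer;
import org.apache.mahout.classifier.sgd.OnlineLogisticRegression;
import org.apache.mahout.ep.State;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.function.DoubleFunction;
import org.apache.mahout.math.function.Functions;
import org.apache.mahout.vectorizer.encoders.ConstantValueEncoder;
import org.apache.mahout.vectorizer.encoders.Dictionary;
import org.apache.mahout.vectorizer.encoders.FeatureVectorEncoder;
import org.apache.mahout.vectorizer.encoders.TextValueEncoder;

public class Classifier {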

    private static final int FEATURES = 10000;
    private static final TextValueEncoder encoder = new TextValueEncoder("body");
    private static final FeatureVectorEncoder bias = new ConstantValueEncoder("Intercept");
    private static final String[] LEAK_LABELS = {"none", "month-year", "day-month-year"};

    /**
     * @param args the command line arguments
     */
    public static void main(String[] args) throws Exception {
        int leakType = 0;
        AdaptiveLogisticRegression learningAlgorithm = new AdaptiveLogisticRegression(20, FEATURES, new L1());
        Dictionary newsGroups = new Dictionary();
        //ModelDissector md = new ModelDissector();
        ListMultimap<String, String> noteBySection = LinkedListMultimap.create();
        noteBySection.put("good", "I love this product, the screen is a pleasure to work with and is a great choice for any business");
        noteBySection.put("good", "What a product!! Really amazing clarity and works pretty well");
        noteBySection.put("good", "This product has good battery life and is a little bit heavy but I like it");

        noteBySection.put("bad", "I am really bored with the same UI, this is their 5th version(or fourth or sixth, who knows) and it looks just like the first one");
        noteBySection.put("bad", "The phone is bulky and useless");
        noteBySection.put("bad", "I wish i had never bought this laptop. It died in the first year and now i am not able to return it");


        encoder.setProbes(2);
        double step = 0;
        int[] bumps = {1, 2, 5};
        double averageCorrect = 0;
        double averageLL = 0;
        int k = 0;
        //-------------------------------------
        //notes.keySet()
        for (String key : noteBySection.keySet()) {
            System.out.println(key);
            List<String> notes = noteBySection.get(key);
            for (Iterator<String> it = notes.iterator(); it.hasNext();) {
                String note = it.next();


                int actual = newsGroups.intern(key);
                Vector v = encodeFeatureVector(note);
                learningAlgorithm.train(actual, v);

                k++;
                int bump = bumps[(int) Math.floor(step) % bumps.length];
                int scale = (int) Math.pow(10, Math.floor(step / bumps.length));
                State<AdaptiveLogisticRegression.Wrapper, CrossFoldLearner> best = learningAlgorithm.getBest();
                double maxBeta;
                double nonZeros;
                double positive;
                double norm;

                double lambda = 0;
                double mu = 0;
                if (best != null) {
                    CrossFoldLearner state = best.getPayload().getLearner();
                    averageCorrect = state.percentCorrect();
                    averageLL = state.logLikelihood();

                    OnlineLogisticRegression model = state.getModels().get(0);
                    // finish off pending regularization
                    model.close();

                    Matrix beta = model.getBeta();
                    maxBeta = beta.aggregate(Functions.MAX, Functions.ABS);
                    nonZeros = beta.aggregate(Functions.PLUS, new DoubleFunction() {

                        @Override
                        public double apply(double v) {
                            return Math.abs(v) > 1.0e-6 ? 1 : 0;
                        }
                    });
                    positive = beta.aggregate(Functions.PLUS, new DoubleFunction() {

                        @Override
                        public double apply(double v) {
                            return v > 0 ? 1 : 0;
                        }
                    });
                    norm = beta.aggregate(Functions.PLUS, Functions.ABS);

                    lambda = learningAlgorithm.getBest().getMappedParams()[0];
                    mu = learningAlgorithm.getBest().getMappedParams()[1];
                } else {
                    maxBeta = 0;
                    nonZeros = 0;
                    positive = 0;
                    norm = 0;
                }
                System.out.println(k % (bump * scale));
                if (k % (bump * scale) == 0) {

                    if (learningAlgorithm.getBest() != null) {
                        System.out.println("----------------------------");
                        ModelSerializer.writeBinary("c:/tmp/news-group-" + k + ".model",
                                learningAlgorithm.getBest().getPayload().getLearner().getModels().get(0));
                    }

                    step += 0.25;
                    System.out.printf("%.2f\t%.2f\t%.2f\t%.2f\t%.8g\t%.8g\t", maxBeta, nonZeros, positive, norm, lambda, mu);
                    System.out.printf("%d\t%.3f\t%.2f\t%s\n",
                            k, averageLL, averageCorrect * 100, LEAK_LABELS[leakType % 3]);
                }
            }

        }
        learningAlgorithm.close();
    }

    private static Vector encodeFeatureVector(String text) {
        encoder.addText(text.toLowerCase());
        //System.out.println(encoder.asString(text));
        Vector v = new RandomAccessSparseVector(FEATURES);
        bias.addToVector((byte[]) null, 1, v);
        encoder.flush(1, v);
        return v;
    }
}
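The progress-reporting part of the loop above (bumps, scale and step) decides when a line is printed and a model is written. As a hedged, stdlib-only sketch that copies just that arithmetic (the class name ReportSchedule is hypothetical and no Mahout classes are involved), the following lists the values of k at which the loop would report:

```java
import java.util.ArrayList;
import java.util.List;

// Reproduces only the reporting arithmetic from the question's loop
// (bumps, scale and step); no Mahout classes are involved.
class ReportSchedule {
    private static final int[] BUMPS = {1, 2, 5};

    // Returns the 1-based example counts k at which the question's loop
    // would print a progress line, for n training examples.
    static List<Integer> reportPoints(int n) {
        List<Integer> points = new ArrayList<>();
        double step = 0;
        for (int k = 1; k <= n; k++) {
            int bump = BUMPS[(int) Math.floor(step) % BUMPS.length];
            int scale = (int) Math.pow(10, Math.floor(step / BUMPS.length));
            if (k % (bump * scale) == 0) {
                points.add(k);
                step += 0.25; // step only advances when a line is printed
            }
        }
        return points;
    }

    public static void main(String[] args) {
        System.out.println(reportPoints(6)); // the question has six notes
    }
}
```

For the six notes in the question this yields [1, 2, 3, 4, 6]. Because step advances by only 0.25 per report, each bump value is reused for four reports before the schedule moves on, so reporting happens often here; best is fetched on every example regardless of this schedule.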

Solution

You need to add the words to your feature vector correctly. It looks like the following code:

        bias.addToVector((byte[]) null, 1, v);

is not doing what you'd expect. It's just adding the null bytes to the feature vector with weight 1.

Then you're calling a wrapper around the WordValueEncoder.addToVector(byte[] originalForm, double w, Vector data) method.

Make sure to loop over the words in each note stored in your map and add them to the feature vector accordingly.
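As a toy illustration of what looping over the words and adding each one to a fixed-size feature vector means, here is a stdlib-only sketch of the hashing trick. It is not Mahout's encoder: the class HashedNoteEncoder is hypothetical, it uses String.hashCode with a single location per word, whereas Mahout's encoders use their own hashing and can place each value in several locations (setProbes).

```java
import java.util.Locale;

// Hypothetical toy encoder, NOT Mahout's TextValueEncoder: it hashes each word
// of a note into one slot of a fixed-size dense vector using String.hashCode.
class HashedNoteEncoder {
    private final int features;

    HashedNoteEncoder(int features) {
        this.features = features;
    }

    double[] encode(String note) {
        double[] v = new double[features];
        // Constant intercept term, analogous to the question's "Intercept" bias encoder.
        v[slot("Intercept")] += 1.0;
        for (String word : note.toLowerCase(Locale.ROOT).split("[^a-z0-9]+")) {
            if (!word.isEmpty()) {
                // Prefix keeps word features apart from the intercept; one unit of weight per occurrence.
                v[slot("body:" + word)] += 1.0;
            }
        }
        return v;
    }

    // Maps a feature name to an index in [0, features); collisions simply add up.
    private int slot(String feature) {
        return Math.floorMod(feature.hashCode(), features);
    }
}
```

Encoding "The phone is bulky and useless" with 10000 features puts a total weight of 7.0 into the vector: one unit for each of the six words plus one for the intercept.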

OTHER TIPS

This happened to me earlier today. I see that you have very few training samples, since you're experimenting with the code like I was. My issue was that, because this is an adaptive algorithm, I needed to set the interval and averaging window for "adapting" very low, like this; otherwise it would never find a new best model:

learningAlgorithm.setInterval(1);
learningAlgorithm.setAveragingWindow(1);

This way, the algorithm is forced to "adapt" after every single vector it sees, which is critical since your example code has only 6 vectors.
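To make this concrete, here is a deliberately simplified, stdlib-only model of the behaviour described above. ToyAdaptiveLearner is a hypothetical class, not Mahout's AdaptiveLogisticRegression: its "candidates" just always predict a fixed label, a best candidate is chosen only at the end of each interval, and candidates are scored over a sliding averaging window of recent examples.

```java
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.List;

// Hypothetical toy, NOT Mahout's AdaptiveLogisticRegression: it only mimics the idea
// that a "best" model is picked at interval boundaries, scored over an averaging window.
class ToyAdaptiveLearner {
    private final int[] candidates;                     // candidate i always predicts candidates[i]
    private final List<Deque<Integer>> recentHits = new ArrayList<>();
    private final int interval;
    private final int averagingWindow;
    private int seen = 0;
    private Integer best = null;                        // stays null until the first evaluation

    ToyAdaptiveLearner(int[] candidates, int interval, int averagingWindow) {
        this.candidates = candidates.clone();
        this.interval = interval;
        this.averagingWindow = averagingWindow;
        for (int i = 0; i < candidates.length; i++) {
            recentHits.add(new ArrayDeque<>());
        }
    }

    void train(int actual) {
        // Record whether each candidate was right, keeping only the last averagingWindow results.
        for (int i = 0; i < candidates.length; i++) {
            Deque<Integer> hits = recentHits.get(i);
            hits.addLast(candidates[i] == actual ? 1 : 0);
            if (hits.size() > averagingWindow) {
                hits.removeFirst();
            }
        }
        seen++;
        // Only at an interval boundary is a best candidate (re)selected; ties keep the first.
        if (seen % interval == 0) {
            int bestIndex = 0;
            double bestScore = Double.NEGATIVE_INFINITY;
            for (int i = 0; i < candidates.length; i++) {
                double score = recentHits.get(i).stream().mapToInt(Integer::intValue).average().orElse(0);
                if (score > bestScore) {
                    bestScore = score;
                    bestIndex = i;
                }
            }
            best = bestIndex;
        }
    }

    Integer getBest() {
        return best;
    }

    // Trains two candidates (predicting 0 and 1) on the labels and returns the best index, or null.
    static Integer bestAfter(int[] labels, int interval, int averagingWindow) {
        ToyAdaptiveLearner learner = new ToyAdaptiveLearner(new int[] {0, 1}, interval, averagingWindow);
        for (int label : labels) {
            learner.train(label);
        }
        return learner.getBest();
    }
}
```

With six labels mirroring the three "good" and three "bad" notes, an interval of 1000 is never reached, so bestAfter returns null, just like the symptom in the question; with an interval and averaging window of 1 it returns a candidate after every example.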

I strongly suggest you also direct your question to the very helpful people on the Mahout mailing list: https://mahout.apache.org/general/mailing-lists,-irc-and-archives.html

Licensed under: CC-BY-SA with attribution