Question

I'm using Mallet through Java, and I can't work out how to evaluate new documents against an existing topic model which I have trained.

My initial code to generate my model is very similar to that in the Mallet Developer's Guide for Topic Modelling, after which I simply save the model as a Java object. In a later process, I reload that Java object from file, add new instances via .addInstances(), and would then like to evaluate only these new instances against the topics found in the original training set.

This stats.SE thread provides some high-level suggestions, but I can't see how to work them into the Mallet framework.

Any help much appreciated.

Solution 2

And I've found the answer hidden in a slide-deck from Mallet's lead developer:

TopicInferencer inferencer = model.getInferencer();
// arguments: instance, number of sampling iterations, thinning interval, burn-in iterations
double[] topicProbs = inferencer.getSampledDistribution(newInstance, 100, 10, 10);

Other tips

Inference is also covered in the example linked in the question (see its last few lines).

For anyone interested in the whole code for saving/loading the trained model and then using it to infer the topic distribution of new documents, here are some snippets:

After model.estimate() has completed, you have the actual trained model so you can serialize it using a standard Java ObjectOutputStream (since ParallelTopicModel implements Serializable):

// try-with-resources closes the streams even if writeObject throws
try (ObjectOutputStream oos =
         new ObjectOutputStream(new FileOutputStream("model.ser"))) {
    oos.writeObject(model);
} catch (FileNotFoundException ex) {
    // handle this error
} catch (IOException ex) {
    // handle this error
}

Note, though, that at inference time the new sentences (as Instance objects) must pass through the same pipeline that was used in training, so that they are pre-processed in the same way (tokenized, etc.). You therefore also need to save the pipe list. Since we are using SerialPipes, we can create an instance and then serialize it:

// initialize the pipe list (the one used in model training)
SerialPipes pipes = new SerialPipes(pipeList);

// try-with-resources closes the streams even if writeObject throws
try (ObjectOutputStream oos =
         new ObjectOutputStream(new FileOutputStream("pipes.ser"))) {
    oos.writeObject(pipes);
} catch (FileNotFoundException ex) {
    // handle error
} catch (IOException ex) {
    // handle error
}
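
The two saving snippets above differ only in the object and the file name, so the stream handling can be factored into one small helper. The following is a minimal sketch using only the Java standard library. The class name SerializationSketch and the methods save and load are hypothetical (they are not part of Mallet), and a plain ArrayList stands in for the trained model so that the example runs without Mallet on the classpath.

```java
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Hypothetical helper class for illustration; not part of Mallet.
final class SerializationSketch {

    /** Writes any Serializable object to the given file. */
    static void save(Serializable object, File file) throws IOException {
        // try-with-resources closes both streams, even on failure
        try (ObjectOutputStream out =
                 new ObjectOutputStream(new FileOutputStream(file))) {
            out.writeObject(object);
        }
    }

    /** Reads an object back and casts it to the expected type. */
    static <T> T load(File file, Class<T> type)
            throws IOException, ClassNotFoundException {
        try (ObjectInputStream in =
                 new ObjectInputStream(new FileInputStream(file))) {
            // Class.cast throws ClassCastException if the file holds another type
            return type.cast(in.readObject());
        }
    }

    public static void main(String[] args) throws Exception {
        // Stand-in for the trained model: any Serializable object works here.
        ArrayList<String> standIn = new ArrayList<>(Arrays.asList("topic", "model"));
        File file = File.createTempFile("model", ".ser");
        file.deleteOnExit();

        save(standIn, file);
        List<?> restored = load(file, ArrayList.class);
        System.out.println(restored.equals(standIn)); // prints true
    }
}
```

With Mallet, the same helper would presumably be called as save(model, new File("model.ser")) and load(new File("model.ser"), ParallelTopicModel.class), and likewise for the SerialPipes object.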

To load the model and the pipeline and use them for inference, we need to deserialize both:

private static void InferByModel(String sentence) {
    // define model and pipeline
    ParallelTopicModel model = null;
    SerialPipes pipes = null;

    // load the model (try-with-resources closes the streams)
    try (ObjectInputStream ois =
             new ObjectInputStream(new FileInputStream("model.ser"))) {
        model = (ParallelTopicModel) ois.readObject();
    } catch (IOException ex) {
        System.out.println("Could not read model from file: " + ex);
    } catch (ClassNotFoundException ex) {
        System.out.println("Could not load the model: " + ex);
    }

    // load the pipeline
    try (ObjectInputStream ois =
             new ObjectInputStream(new FileInputStream("pipes.ser"))) {
        pipes = (SerialPipes) ois.readObject();
    } catch (IOException ex) {
        System.out.println("Could not read pipes from file: " + ex);
    } catch (ClassNotFoundException ex) {
        System.out.println("Could not load the pipes: " + ex);
    }

    // if both are properly loaded
    if (model != null && pipes != null){

        // Create a new instance named "test instance" with empty target
        // and source fields. Note that the loaded pipes are used here.
        InstanceList testing = new InstanceList(pipes);   
        testing.addThruPipe(
            new Instance(sentence, null, "test instance", null));

        // here we get an inferencer from our loaded model and use it
        TopicInferencer inferencer = model.getInferencer();
        double[] testProbabilities = inferencer
                   .getSampledDistribution(testing.get(0), 10, 1, 5);
        System.out.println("0\t" + testProbabilities[0]);
    }
}
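
The method above prints only the probability of topic 0. Since getSampledDistribution returns one probability per topic, it is often more useful to rank the topics. The following is a minimal sketch of that ranking on a plain double[], so it runs without Mallet. The class name TopicRanking, the method rankTopics and the sample values are hypothetical and are not taken from a real model.

```java
import java.util.stream.IntStream;

// Hypothetical helper class for illustration; not part of Mallet.
final class TopicRanking {

    /** Returns the topic indices ordered from most to least probable. */
    static int[] rankTopics(double[] probabilities) {
        return IntStream.range(0, probabilities.length)
                .boxed()
                // descending order of probability; ties keep index order
                .sorted((a, b) -> Double.compare(probabilities[b], probabilities[a]))
                .mapToInt(Integer::intValue)
                .toArray();
    }

    public static void main(String[] args) {
        // Made-up distribution standing in for the inferencer's output.
        double[] testProbabilities = {0.1, 0.6, 0.3};

        // Print one "topic<TAB>probability" line per topic, best first.
        for (int topic : rankTopics(testProbabilities)) {
            System.out.println(topic + "\t" + testProbabilities[topic]);
        }
    }
}
```

In InferByModel, the array returned by the inferencer could be passed to rankTopics in place of the made-up values.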

For some reason, I am not getting exactly the same inference results with the loaded model as with the original one, but that is a matter for another question (if anyone knows why, I'd be happy to hear).

Licensed under: CC-BY-SA with attribution
Not affiliated with StackOverflow