Question

I have created a random forest prediction model in R using the randomForest function:

model = randomForest(classification ~., data=train, ntree=100, proximity=T)

Next I plotted the model in order to see the overall error of the model:

plot(model, log="y")

This gives me the following plot: [image: plot of the model's error rates against the number of trees]

My question is: how do I add a legend to this plot so that I can see which color corresponds to each level of the factor used for the classification? The factor variable is data$classification. I can't figure out the right legend() call to do this.


Solution

The S3 plot method for randomForest objects uses matplot to draw one error curve per column of model$err.rate, and it does not add a legend, so you have to add one manually. This should be a good start:

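library(randomForest)
# Fit an example classifier on the built-in iris data
model <- randomForest(Species ~ ., data = iris, ntree = 100, proximity = TRUE)
# Split the device into a wide panel for the plot and a narrow one for the legend
layout(matrix(c(1, 2), nrow = 1), widths = c(4, 1))
par(mar = c(5, 4, 4, 0))  # no margin on the right side
plot(model, log = "y")
par(mar = c(5, 0, 4, 2))  # no margin on the left side
plot(c(0, 1), type = "n", axes = FALSE, xlab = "", ylab = "")  # empty panel to hold the legend
# One entry per err.rate column, styled like matplot's defaults (col = 1:6, lty = 1:5)
legend("top", legend = colnames(model$err.rate), col = 1:4, lty = 1:4, cex = 0.8)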
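The legend call above hard-codes four entries because the iris model has four error curves (OOB plus one per species). A minimal sketch of a version that adapts to any number of classes, assuming the curves were drawn with matplot's default styles (that is, you did not pass your own col or lty to plot()). The helper name matplot_legend_style is hypothetical, and the synthetic matrix err only stands in for model$err.rate so the example runs without fitting a forest:

```r
# Hypothetical helper: legend styling that matches matplot's defaults,
# col = 1:6 and lty = 1:5, each recycled across the plotted columns
matplot_legend_style <- function(err) {
  n <- ncol(err)
  list(legend = colnames(err),
       col = rep_len(1:6, n),
       lty = rep_len(1:5, n))
}

# Synthetic stand-in for model$err.rate: 100 trees x 7 curves
set.seed(1)
err <- matrix(runif(700, 0.01, 0.3), nrow = 100,
              dimnames = list(NULL, c("OOB", paste0("class", 1:6))))

# Draw the curves the same way (default matplot styles), then add the legend
matplot(1:100, err, type = "l", log = "y", xlab = "trees", ylab = "Error")
style <- matplot_legend_style(err)
legend("topright", legend = style$legend, col = style$col,
       lty = style$lty, cex = 0.8)
```

With a real model you would call plot(model) and then pass model$err.rate to the helper instead of the synthetic err.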


Other tips

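If the model was trained with caret's train(), the random forest itself is stored in model$finalModel. After plot(model$finalModel), you can add the legend with: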

err <- if (is.null(model$finalModel$test$err.rate)) model$finalModel$err.rate else model$finalModel$test$err.rate  # test-set errors if present, else training (OOB) errors
n <- ncol(err)  # one legend entry per error curve
legend("top", legend = colnames(err), cex = 0.7, horiz = TRUE,
       lty = rep_len(1:5, n), col = rep_len(1:6, n))  # matplot's default line types and colors
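The if/else in this answer picks the test-set error matrix when the forest has one and the training (OOB) matrix otherwise. A small sketch that wraps that choice in a function; rf_error_matrix is a hypothetical name, and the mock lists only imitate the err.rate and test$err.rate fields (their column names are placeholders) so the example runs without fitting anything:

```r
# Hypothetical helper: return the error matrix the legend should describe,
# preferring the test-set matrix when one is present
rf_error_matrix <- function(rf) {
  if (is.null(rf$test$err.rate)) rf$err.rate else rf$test$err.rate
}

# Mock objects imitating the relevant fields of a classification forest
oob <- matrix(0.1, nrow = 5, ncol = 3,
              dimnames = list(NULL, c("OOB", "a", "b")))
tst <- matrix(0.2, nrow = 5, ncol = 3,
              dimnames = list(NULL, c("Test", "a", "b")))
rf_train_only <- list(err.rate = oob)
rf_with_test  <- list(err.rate = oob, test = list(err.rate = tst))

colnames(rf_error_matrix(rf_train_only))  # c("OOB", "a", "b")
colnames(rf_error_matrix(rf_with_test))   # c("Test", "a", "b")
```

With a real caret model you would pass model$finalModel and use the column names of the result as the legend labels.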
Licensed under: CC-BY-SA with attribution
Not affiliated with StackOverflow