Question

I have created a random forest prediction model in R using the randomForest function:

model = randomForest(classification ~., data=train, ntree=100, proximity=T)

Next I plotted the model in order to see the overall error of the model:

plot(model, log="y")

This gives me the following plot: [image not shown]

My question is how do I put a legend on this so that I can see which color corresponds to each value in the factor used for the classification? The factor variable is data$classification. I can't figure out the legend() call to do this.


Solution

The S3 plot method for randomForest objects uses matplot to draw the error curves, and it does not add a legend, so you have to add one manually. This should be a good start:

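library(randomForest)
model <- randomForest(Species ~ ., data = iris, ntree = 100, proximity = TRUE)
# Two panels side by side: the error plot (wide) and the legend (narrow)
layout(matrix(c(1, 2), nrow = 1), widths = c(4, 1))
par(mar = c(5, 4, 4, 0))  # no margin on the right side
plot(model, log = "y")
par(mar = c(5, 0, 4, 2))  # no margin on the left side
plot(c(0, 1), type = "n", axes = FALSE, xlab = "", ylab = "")  # empty panel that holds the legend
# matplot defaults to col = 1:6 and lty = 1:5, so the legend uses the same values
legend("top", colnames(model$err.rate), col = 1:4, lty = 1:4, cex = 0.8)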

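If a separate legend panel is not needed, the legend can also be drawn inside the plot itself. The following is a minimal sketch under a few assumptions: it uses the iris data as above, the seed and the "topright" position are arbitrary choices, and the colours and line types rely on matplot's default cycling (col = 1:6, lty = 1:5) being left unchanged in the plot() call.

```r
library(randomForest)

set.seed(42)  # arbitrary seed, only so the sketch is reproducible
model <- randomForest(Species ~ ., data = iris, ntree = 100)

# One column per curve: "OOB" first, then one column per class level
curve_names <- colnames(model$err.rate)
n_curves <- length(curve_names)

plot(model, log = "y")
# Recycle matplot's default colours (1:6) and line types (1:5)
# so each legend entry matches the curve drawn for that column
legend("topright", legend = curve_names,
       col = rep_len(1:6, n_curves),
       lty = rep_len(1:5, n_curves),
       cex = 0.8)
```

Taking the labels from colnames(model$err.rate) rather than typing them in keeps the legend in the same order as the curves, whatever the levels of the classification factor are.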

Other tips

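You can use this after calling plot(). It assumes the random forest is stored in model$finalModel, as it is for a model trained through caret's train():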

# Use the test-set error-rate columns if present, otherwise the OOB ones
rf <- model$finalModel
legend_labels <- if (is.null(rf$test$err.rate)) colnames(rf$err.rate) else colnames(rf$test$err.rate)
legend("top", cex = 0.7, legend = legend_labels, lty = seq_along(legend_labels), col = seq_along(legend_labels), horiz = TRUE)
License: CC-BY-SA with attribution
Not affiliated with StackOverflow