Question

I have created a random forest prediction model in R using the randomForest function:

model = randomForest(classification ~., data=train, ntree=100, proximity=T)

Next I plotted the model in order to see the overall error of the model:

plot(model, log="y")

This gives me the following plot of the error rate against the number of trees (image not preserved):

My question is: how do I add a legend to this plot so that I can see which color corresponds to each level of the factor used for the classification? The factor variable is data$classification. I can't figure out the right legend() call to do this.


Solution

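The S3 plot method for randomForest objects (plot.randomForest) draws the error curves with matplot() and does not add a legend, so you have to add one manually. The code below places the error plot and the legend in two side-by-side panels. This should be a good start: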

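library(randomForest)
model <- randomForest(Species ~ ., data = iris, ntree = 100, proximity = TRUE)
# Two panels side by side: the error plot (4/5 of the width) and the legend (1/5)
layout(matrix(c(1, 2), nrow = 1), widths = c(4, 1))
par(mar = c(5, 4, 4, 0))  # no margin on the right side
plot(model, log = "y")
par(mar = c(5, 0, 4, 2))  # no margin on the left side
plot(c(0, 1), type = "n", axes = FALSE, xlab = "", ylab = "")  # empty panel to hold the legend
# One entry per err.rate column (OOB + one per class); colors 1:4 match matplot()'s defaults
legend("top", colnames(model$err.rate), col = 1:4, cex = 0.8, fill = 1:4)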

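If you would rather keep the legend inside the plotting region instead of in a separate panel, here is a minimal sketch. It assumes no test set was passed to randomForest(), so the plot shows only the err.rate columns (OOB plus one per class), and it relies on matplot()'s default cycling of col = 1:6 and lty = 1:5. The seed and the legend position are arbitrary choices.

```r
library(randomForest)

set.seed(1)  # arbitrary seed, only for reproducibility
model <- randomForest(Species ~ ., data = iris, ntree = 100)

# Columns that plot(model) draws when there is no test set: OOB + one per class
err <- model$err.rate
n <- ncol(err)

plot(model)  # draws the error curves via matplot()
# Mirror matplot()'s default col = 1:6 and lty = 1:5 so each key matches its line
legend("topright",
       legend = colnames(err),
       col = rep_len(1:6, n),
       lty = rep_len(1:5, n),
       cex = 0.8)
```

Using rep_len() rather than a fixed 1:4 keeps the legend in step with the curves for any number of classes, including factors with more than five levels, where matplot() starts reusing line types.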

Another answer

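You can use this (it assumes `model` is a caret train() fit, whose underlying randomForest object is stored in model$finalModel):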

rf <- model$finalModel  # the randomForest object inside the caret train() fit
err_names <- if (is.null(rf$test$err.rate)) colnames(rf$err.rate) else colnames(rf$test$err.rate)
n <- length(err_names)  # matplot() cycles col = 1:6 and lty = 1:5 by default, so mirror that
legend("top", cex = 0.7, legend = err_names, lty = rep_len(1:5, n), col = rep_len(1:6, n), horiz = TRUE)
License: CC-BY-SA with attribution
Not affiliated with StackOverflow