Question

I am using the party package in R.

I would like to get various statistics (mean, median, etc.) from various nodes of the resulting tree, but I cannot see how to do this. For example:

library(party)

airq <- subset(airquality, !is.na(Ozone))
airct <- ctree(Ozone ~ ., data = airq,
               controls = ctree_control(maxsurrogate = 3))
airct
plot(airct)

results in a tree with five terminal nodes. How would I get the mean Ozone value for each of those nodes?


Solution 2

This is harder than I expected. Try something like this:

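# colMeans can be replaced by whatever function you want to apply per node;
# na.rm = TRUE is passed on to colMeans because Solar.R still contains NAs
a <- by(airq, where(airct), colMeans, na.rm = TRUE)
a
a$"3"      # access node 3
a[["3"]]   # same thing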

You can find other useful examples in the help page ?`BinaryTree-class`.
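If you only need statistics for one column, a base-R alternative is tapply() over the node IDs returned by where(). The sketch below rebuilds the tree from the question so it runs on its own, then computes the size, mean and median of Ozone per terminal node; the names nodeID and nodeStats are just illustrative.

```r
library(party)

airq  <- subset(airquality, !is.na(Ozone))
airct <- ctree(Ozone ~ ., data = airq,
               controls = ctree_control(maxsurrogate = 3))

# where() gives the terminal node ID of every observation used to fit the tree
nodeID <- where(airct)

# tapply() and table() both order their results by the sorted node IDs
ozoneMean   <- tapply(airq$Ozone, nodeID, mean)
ozoneMedian <- tapply(airq$Ozone, nodeID, median)

nodeStats <- data.frame(node   = names(ozoneMean),
                        n      = as.vector(table(nodeID)),
                        mean   = as.vector(ozoneMean),
                        median = as.vector(ozoneMedian))
nodeStats
```

Swap mean or median for any function that takes a numeric vector and returns a single value.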

OTHER TIPS

I couldn't work out which node variable holds the air quality value, but here is how to customize the tree plot:

library(grid)  # grid.circle(), grid.text() and gpar()

# Custom panel for inner nodes: splitting variable plus the node's prediction
innerWeights <- function(node){
  grid.circle(gp = gpar(fill = "white", col = 1))
  mainlab <- node$psplit$variableName
  label   <- paste(mainlab, paste0("prediction=", round(node$prediction, 2)), sep = "\n")
  grid.text(label = label, gp = gpar(col = "red"))
}

plot(airct, inner_panel = innerWeights)
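The same idea works for the leaves through the terminal_panel argument, which puts the per-node mean directly on the plot. This is a sketch that assumes, as for this regression tree, that node$prediction holds the node's mean response and that node$weights marks the observations falling in the node; meanPanel is a hypothetical name.

```r
library(party)
library(grid)

airq  <- subset(airquality, !is.na(Ozone))
airct <- ctree(Ozone ~ ., data = airq,
               controls = ctree_control(maxsurrogate = 3))

# Custom terminal panel: label each leaf with its node ID, number of
# observations and mean Ozone.
# Assumption: node$weights are 0/1 case weights for the observations in the
# node, and node$prediction is their mean response (regression tree).
meanPanel <- function(node) {
  n <- sum(node$weights)
  label <- paste0("Node ", node$nodeID, "\n",
                  "n = ", n, "\n",
                  "mean = ", round(node$prediction, 2))
  grid.rect(gp = gpar(fill = "white"))
  grid.text(label)
}

plot(airct, terminal_panel = meanPanel)
```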


Edit: getting statistics by node

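library(grid)
library(gridExtra)  # grid.table()

# Show each inner node's test statistics (one per candidate split variable)
# as a small table; round() replaces plyr::round_any(x, 0.01), as plyr was not loaded
innerWeights <- function(node){
  dat <- round(node$criterion$statistic, 2)
  grid.table(t(dat))
}
plot(airct, inner_panel = innerWeights)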
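To get these numbers as a data frame instead of a plot, you can walk the recursive list stored in the tree slot yourself. This sketch assumes the node fields used by the panel functions (terminal, nodeID, psplit, criterion, left, right); innerStats is a hypothetical helper, and maxStatistic is simply the largest statistic among the candidate variables, which is not necessarily the one chosen for the split.

```r
library(party)

airq  <- subset(airquality, !is.na(Ozone))
airct <- ctree(Ozone ~ ., data = airq,
               controls = ctree_control(maxsurrogate = 3))

# Recursively collect one row per inner node: its ID, the variable it splits
# on, and the largest test statistic over the candidate variables
innerStats <- function(node) {
  if (node$terminal) return(NULL)
  row <- data.frame(nodeID       = node$nodeID,
                    variable     = node$psplit$variableName,
                    maxStatistic = max(node$criterion$statistic))
  # rbind() silently drops the NULLs returned by terminal children
  rbind(row, innerStats(node$left), innerStats(node$right))
}

stats <- innerStats(airct@tree)
stats
```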


How to get there if you are lost in R-space (and the documentation does not help you immediately)

First, try str(airct). The output is lengthy because the object is complex, but for simpler cases, e.g. the result of a t-test, this is all you need.
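For example, with a two-sample t-test on R's built-in sleep data, str() lists every component of the result, and each one can then be pulled out with $.

```r
# A simple result object: two-sample t-test on the built-in sleep data
tt <- t.test(extra ~ group, data = sleep)

str(tt)       # lists the components: statistic, parameter, p.value, estimate, ...
tt$p.value    # pick a single component out with $
tt$estimate   # the two group means
```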

Since print(airct), or simply typing airct, gives quite useful information, how does print work here? Try class(airct) or check the documentation: the result is of class BinaryTree.

OK, we could have seen this from the documentation, and in this case the information on the BinaryTree help page is good enough (see the examples on that page).

But suppose the author had been lazy: then try getAnywhere(print.BinaryTree). Near the top you will find y <- x@responses, so try airct@responses next.
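The exploration steps in this answer can be collected into a few lines. This is a sketch; slotNames() and isS4() are base R helpers added here, not part of the original answer, and str(..., max.level = 2) just keeps the output short.

```r
library(party)

airq  <- subset(airquality, !is.na(Ozone))
airct <- ctree(Ozone ~ ., data = airq,
               controls = ctree_control(maxsurrogate = 3))

class(airct)                         # the class tells you which help page and methods apply
isS4(airct)                          # TRUE means components are slots, accessed with @ rather than $
slotNames(airct)                     # everything reachable as airct@...
getAnywhere("print.BinaryTree")      # source of the print method, if one is defined
str(airct@responses, max.level = 2)  # the slot the print method starts from
```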

You can also do this using the dplyr package.

First, find which node each observation belongs to and store it in the data frame.

airq$node <- where(airct)

Then use group_by to group the observations by node, and use summarise to calculate the mean of the Ozone measurement. You can swap mean out for whatever summary statistic function you like.

library(dplyr)  # provides %>%, group_by() and summarise()

airq %>% group_by(node) %>% summarise(avg = mean(Ozone))

This gives the following results:

    node     avg
   (int)    (dbl)
1     3 55.60000
2     5 18.47917
3     6 31.14286
4     8 81.63333
5     9 48.71429
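The same pattern extends to several statistics at once, including summaries of other columns such as Temp. A sketch; the column names n, avg, med, sd and temp are arbitrary choices.

```r
library(party)
library(dplyr)

airq  <- subset(airquality, !is.na(Ozone))
airct <- ctree(Ozone ~ ., data = airq,
               controls = ctree_control(maxsurrogate = 3))
airq$node <- where(airct)

# Several statistics per terminal node in a single summarise() call
nodeSummary <- airq %>%
  group_by(node) %>%
  summarise(n    = n(),
            avg  = mean(Ozone),
            med  = median(Ozone),
            sd   = sd(Ozone),
            temp = mean(Temp))

nodeSummary
```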
Licensed under: CC-BY-SA with attribution