Question

How can I get the ID (or name) of the terminal node of an rpart model for every row? For a classification tree, predict.rpart can only return the predicted class (as a number or factor), the class probabilities, or a combination of these (using type = "matrix").

I would like to do something like:

library(rpart)
fit <- rpart(Kyphosis ~ Age + Number + Start, data = kyphosis)
plot(fit)  # there are 5 terminal nodes
predict(fit, type = "node_id")   # should return IDs of terminal nodes (e.g. 1-5) (does not work)

Solution 2

For that model there were 4 splits, yielding 5 "terminal nodes", or in rpart's terminology, <leaf>s. I do not see why there should be 5 predictions for anything: predictions are made for individual cases, and each leaf is reached through a varying number of the splits used to make those predictions. What you may want is which leaf each row of the original dataset ended up in, or how many rows ended up in each leaf; here are ways of getting those:

# For each training row, the row number in fit$frame of the leaf it fell into
fit$where

# counts of cases in each leaf
table(fit$where)
##  3  5  7  8  9 
## 29 12 14  7 19 
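To make clear what those numbers are, here is a small sketch (assuming the kyphosis fit from the question). It checks that every value of fit$where is a row index into fit$frame whose var column is "<leaf>", and that the frame's n column for those rows reproduces the counts above. The names leaf_rows and counts are only illustrative.

```r
library(rpart)

fit <- rpart(Kyphosis ~ Age + Number + Start, data = kyphosis)

# Rows of fit$frame that describe terminal nodes (leaves)
leaf_rows <- which(fit$frame$var == "<leaf>")

# fit$where holds, for each training row, an index into fit$frame
all(fit$where %in% leaf_rows)

# The n column of those frame rows equals the per-leaf counts of fit$where
counts <- table(fit$where)
cbind(where   = as.integer(names(counts)),
      n_table = as.vector(counts),
      n_frame = fit$frame$n[as.integer(names(counts))])
```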

To assemble the labels(fit) that apply to a particular leaf, you would need to traverse the tree and accumulate the labels of every split applied on the way down to that leaf. You probably want to look at:

?print.rpart    
?rpart.object
?text.rpart
?labels.rpart
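rpart also exports path.rpart(), which does this traversal and returns the split labels from the root down to each requested node. A minimal sketch, assuming the kyphosis fit from the question; leaf_ids, rules and row_rules are illustrative names, and joining the conditions with " & " is just one possible rule format:

```r
library(rpart)

fit <- rpart(Kyphosis ~ Age + Number + Start, data = kyphosis)

# Row names of fit$frame are rpart's node numbers; keep only the leaves
node_ids <- rownames(fit$frame)
leaf_ids <- as.integer(node_ids[fit$frame$var == "<leaf>"])

# path.rpart() collects the split labels from the root down to each node;
# print.it = FALSE suppresses printing and just returns the (named) list
paths <- path.rpart(fit, nodes = leaf_ids, print.it = FALSE)

# One rule string per leaf, dropping the leading "root" label
rules <- vapply(paths, function(p) paste(p[-1], collapse = " & "), character(1))

# Look up the rule for each training row via its leaf's node number
row_rules <- unname(rules[node_ids[fit$where]])
head(row_rules)
```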

OTHER TIPS

The partykit package supports predict(..., type = "node"), both in and out of sample. You can simply convert the rpart object to use this:

library("partykit")
predict(as.party(fit), type = "node")  
## 9 7 9 9 3 3 3 3 3 8 8 3 9 5 3 3 3 7 3 5 3 9 8 9 9 5 9 8 3 3 3 7 7 3 7 3 5 9 5 8 
## 9 5 9 9 3 7 3 7 9 7 8 3 9 3 3 3 5 9 5 8 9 9 9 3 3 5 3 7 5 3 7 7 3 7 3 3 7 5 7 9 
## 5 
table(predict(as.party(fit), type = "node")) 
##  3  5  7  8  9 
## 29 12 14  7 19 
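For the out-of-sample case, the same predict() call takes a newdata argument, and partykit's nodeids(..., terminal = TRUE) lists the leaf IDs in partykit's numbering. In this sketch the first five training rows stand in for new data (purely for illustration), so their node IDs can be compared with the in-sample ones:

```r
library(rpart)
library(partykit)

fit <- rpart(Kyphosis ~ Age + Number + Start, data = kyphosis)
pfit <- as.party(fit)

# IDs of the terminal nodes in partykit's numbering
leaf_ids <- nodeids(pfit, terminal = TRUE)

# Stand-in "new" data: the first five training rows (illustration only)
new_cases <- kyphosis[1:5, ]
new_nodes <- predict(pfit, newdata = new_cases, type = "node")

# For the same rows, out-of-sample node IDs match the in-sample ones
in_sample <- predict(pfit, type = "node")
cbind(new = unname(new_nodes), fitted = unname(in_sample[1:5]))
```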

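The $where approach above returns only the row number of each leaf in the tree's frame (fit$frame), not its node ID. So if you use kyphosis$ID <- fit$where, the values are frame row indices, which can coincide with the IDs of other (possibly non-terminal) nodes rather than being the leaf's own node ID. To get the actual leaf node ID, use the row names of fit$frame: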

MyID <- row.names(fit$frame)
kyphosis$ID <- MyID[fit$where]
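Cross-tabulating the two encodings shows the difference directly: each fit$where value (a frame row index) maps to exactly one node ID, but the two numbers are generally not the same. A short sketch on the kyphosis fit; node_id is an illustrative name:

```r
library(rpart)

fit <- rpart(Kyphosis ~ Age + Number + Start, data = kyphosis)

# fit$where: row index into fit$frame; node_id: rpart's node number
MyID <- row.names(fit$frame)
node_id <- MyID[fit$where]

# Which frame row corresponds to which node number
table(where = fit$where, node = node_id)
```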

To predict leaves for new data, you could use rpart.predict(fit, newdata, nn = TRUE) from the rpart.plot package, which adds the node numbers to the output.
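A minimal sketch of that call, assuming rpart.plot is installed. The first five kyphosis rows stand in for new data (illustration only), and type = "class" is passed through to predict.rpart. With nn = TRUE the result should come back as a data frame with an extra nn column holding each row's leaf node number:

```r
library(rpart)
library(rpart.plot)

fit <- rpart(Kyphosis ~ Age + Number + Start, data = kyphosis)

# Stand-in new data (first five rows); replace with your own data frame
new_cases <- kyphosis[1:5, ]

# nn = TRUE appends a column "nn" with the node number of each row's leaf
res <- rpart.predict(fit, newdata = new_cases, type = "class", nn = TRUE)
res
```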

Here is a standalone rpart leaf predictor:

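# For each row of newdata, return either the fit$frame row index of its leaf
# (type = "where", like fit$where) or the leaf's node ID (type = "leaf").
# Note: relies on unexported rpart internals (rpart:::), which may change between versions.
rpart_leaves <- function(fit, newdata, type = c("where", "leaf"), na.action = na.pass) {
  # Build a model frame from newdata unless it already is one (same steps as predict.rpart)
  if (is.null(attr(newdata, "terms"))) {
    Terms <- delete.response(fit$terms)
    newdata <- model.frame(Terms, newdata, na.action = na.action,
                           xlev = attr(fit, "xlevels"))
    if (!is.null(cl <- attr(Terms, "dataClasses")))
      .checkMFClasses(cl, newdata, TRUE)
  }
  # Convert to rpart's internal numeric matrix and drop each row down the tree
  newdata <- rpart:::rpart.matrix(newdata)
  where <- unname(rpart:::pred.rpart(fit, newdata))
  if (match.arg(type) == "where")
    return(where)
  # Row names of fit$frame are the node IDs
  rownames(fit$frame)[where]
}

# Example: rpart_leaves(fit, kyphosis[1:5, ], type = "leaf")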
Licensed under: CC-BY-SA with attribution