Question

I'm working with decision trees in Python's scikit-learn. Unlike most use cases, I'm less interested in the classifier's accuracy at this point than in extracting the specific path a data point takes through the tree when I call .predict() on it. Has anyone done this before? I'd like to build a data frame containing ($X_{i}$, path$_{i}$) pairs for use in a downstream analysis.


Solution

This looks easier to do in R, using the rpart library in combination with the partykit library. Ideally I'd like to find a way to do this in Python, but here's the R code for anyone who is interested (taken from here):

pathpred <- function(object, ...){
    ## coerce to "party" object if necessary
    if(!inherits(object, "party")) object <- as.party(object)

    ## get standard predictions (response/prob) and collect in data frame
    rval <- data.frame(response = predict(object, type = "response", ...))
    rval$prob <- predict(object, type = "prob", ...)

    ## get rules for each node
    ## (.list.rules.party is unexported, hence ":::"; it may change between partykit versions)
    rls <- partykit:::.list.rules.party(object)

    ## get predicted node and select corresponding rule
    rval$rule <- rls[as.character(predict(object, type = "node", ...))]

    return(rval)
}

Illustration using the iris data and rpart():

library("rpart")
library("partykit")
rp <- rpart(Species ~ ., data = iris)
rp_pred <- pathpred(rp)
rp_pred[c(1, 51, 101), ]

This yields:

       response prob.setosa prob.versicolor prob.virginica
 1       setosa  1.00000000      0.00000000     0.00000000
 51  versicolor  0.00000000      0.90740741     0.09259259
 101  virginica  0.00000000      0.02173913     0.97826087
                                           rule
 1                          Petal.Length < 2.45
 51   Petal.Length >= 2.45 & Petal.Width < 1.75
 101 Petal.Length >= 2.45 & Petal.Width >= 1.75

This looks like something I could at least use to derive shared parent node information.
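For the Python side of the question, here is a minimal sketch of the same idea, assuming scikit-learn 0.18 or later, where fitted tree estimators expose a `decision_path` method alongside the `tree_` attribute. The helper name `path_frame` and the column names are my own, not part of any library. Note that scikit-learn sends a sample left when the feature value is `<=` the threshold, whereas the rpart rules above use `<` and `>=`, so the rule strings will not match the R output character for character.

```python
import pandas as pd
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier


def path_frame(clf, X, feature_names):
    """Hypothetical helper: one row per sample with prediction, leaf id,
    visited node ids, and a human-readable rule string."""
    tree = clf.tree_
    indicator = clf.decision_path(X)  # sparse CSR matrix, samples x nodes
    predictions = clf.predict(X)
    rows = []
    for i in range(X.shape[0]):
        start, stop = indicator.indptr[i], indicator.indptr[i + 1]
        # Children always get larger ids than their parent, so sorting the
        # visited node ids gives the root-to-leaf order.
        node_ids = sorted(int(n) for n in indicator.indices[start:stop])
        conditions = []
        for node, child in zip(node_ids[:-1], node_ids[1:]):
            name = feature_names[tree.feature[node]]
            # Read the branch direction from the tree structure itself
            # rather than re-comparing the feature value.
            op = "<=" if child == tree.children_left[node] else ">"
            conditions.append(f"{name} {op} {tree.threshold[node]:.2f}")
        rows.append({
            "prediction": predictions[i],
            "leaf": node_ids[-1],
            "nodes": node_ids,
            "rule": " & ".join(conditions),
        })
    return pd.DataFrame(rows)


iris = load_iris()
clf = DecisionTreeClassifier(max_depth=2, random_state=0)
clf.fit(iris.data, iris.target)

paths = path_frame(clf, iris.data, iris.feature_names)
print(paths.iloc[[0, 50, 100]])
```

Two samples share a parent node whenever their `nodes` lists have a common prefix, so the `nodes` column can be used directly for the shared-parent analysis; `clf.apply(X)` returns only the leaf ids if the full path is not needed.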

Licensed under: CC-BY-SA with attribution