3 geom_node_plot - martin-borkovec/ggparty GitHub Wiki

Node Plots

If we want to plot the data contained within the individual nodes of the tree, we need to add geom_node_plot() to our ggparty() call. To understand why this is necessary, let’s reiterate what ggparty() does and how it uses the ggplot() function. Every ggplot() call needs a 'data.frame', so, as we’ve seen above, ggparty() creates one from the 'party' object. In this 'data.frame' every row corresponds to a node of the tree.
Each column of a node’s data is stored as a 'list' in its own column. In this form it is not directly usable by ggplot(), since ggplot() can’t handle lists inside its data. This is where geom_node_plot() comes into play: each instance of geom_node_plot() creates a completely separate ggplot() call after transforming all the columns containing lists of data (created by ggparty()) into a new 'data.frame' for that call.
All the other columns of ggparty’s 'data.frame' (like kids, parent, etc.) are dropped in this process, since we are usually not interested in them when plotting the node data and they could cause naming conflicts. In case we do want to use them, there is a fairly easy way to do so. By default we can therefore access anything found in the data slot of the 'party' object, the fitted_nodes and, if the 'party' object contains any, the fitted.values and the residuals of the included model.
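
The transformation described above can be illustrated with a minimal base R sketch. It is not ggparty's actual implementation; the toy data frame, its column names and the unnesting code are assumptions made purely for illustration. It shows a 'data.frame' with one row per node and list columns, and how such list columns can be flattened into a long 'data.frame' that ggplot() could work with.

```r
# Toy stand-in for the plot data of a tree: one row per node, with each
# variable of a node's data stored as one list element.
# All names and values here are hypothetical, not ggparty's real columns.
tree_df <- data.frame(id = 1:2, kids = c(0L, 0L))
tree_df$play     <- list(c("yes", "no", "yes"), c("no", "no"))
tree_df$humidity <- list(c(70, 90, 65), c(85, 95))

# Flatten the list columns into one long data.frame with one row per
# observation. Structural columns such as 'kids' are dropped on purpose.
node_df <- do.call(rbind, lapply(seq_len(nrow(tree_df)), function(i) {
  data.frame(id       = tree_df$id[i],
             play     = tree_df$play[[i]],
             humidity = tree_df$humidity[[i]])
}))

node_df
```

The resulting node_df has one row per observation and an id column identifying the node, which is the shape a regular ggplot() call expects.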

Now let’s take a look at a constparty object created from the same data.

# ggparty attaches ggplot2 and partykit as dependencies
library("ggparty")

data("WeatherPlay", package = "partykit")
sp_o <- partysplit(1L, index = 1:3)
sp_h <- partysplit(3L, breaks = 75)
sp_w <- partysplit(4L, index = 1:2)
pn <- partynode(1L, split = sp_o, kids = list(
  partynode(2L, split = sp_h, kids = list(
    partynode(3L, info = "yes"),
    partynode(4L, info = "no"))),
  partynode(5L, info = "yes"),
  partynode(6L, split = sp_w, kids = list(
    partynode(7L, info = "yes"),
    partynode(8L, info = "no")))))
py <- party(pn, WeatherPlay)

n1 <- partynode(id = 1L, split = sp_o, kids = lapply(2L:4L, partynode))
t2 <- party(n1,
            data = WeatherPlay,
            fitted = data.frame(
              "(fitted)" = fitted_node(n1, data = WeatherPlay),
              "(response)" = WeatherPlay$play,
              check.names = FALSE),
            terms = terms(play ~ ., data = WeatherPlay)
)
t2 <- as.constparty(t2)

To visualize the distribution of the variable play we will use the geom_node_plot() function. It allows us to show the data of each node in its own plot. For this to work, we have to specify the argument gglist: a 'list' of all the 'gg' components we would add to a ggplot() call on the data element of a node.

ggplot(t2[2]$data) +
  geom_bar(aes(x = "", fill = play),
           position = position_fill()) +
  xlab("play")

If we were to use the above code to create the desired plot for one node, we can instead pass a 'list' of its two components to gglist, and geom_node_plot() will create a version of the plot for every specified node (by default the terminal nodes). Keep in mind that, since it’s a 'list', we need to use "," instead of "+" to combine the components.

ggparty(t2) +
  geom_edge() +
  geom_edge_label() +
  geom_node_splitvar() +
  # pass list to gglist containing all ggplot components we want to plot for each
  # (default: terminal) node
  geom_node_plot(gglist = list(geom_bar(aes(x = "", fill = play),
                                        position = position_fill()),
                               xlab("play")))

Axes and Legends

Setting shared_axis_labels to TRUE allows us to use the space more efficiently and legend_separator = TRUE draws a line between the tree and the legend.

ggparty(t2) +
  geom_edge() +
  geom_edge_label() +
  geom_node_splitvar() +
  geom_node_plot(gglist = list(geom_bar(aes(x = "", fill = play),
                                        position = position_fill()),
                               xlab("play")),
                 # draw only one label for each axis
                 shared_axis_labels = TRUE,
                 # draw line between tree and legend
                 legend_separator = TRUE
                 )

Setting shared_legend to FALSE draws an individual legend at each plot instead of one common legend at the bottom of the plot. This might be necessary if we use multiple different geom_node_plot() calls which lead to various legends. In case we want to remove the legend altogether (i.e. theme(legend.position = "none")), shared_legend has to be set to FALSE.

ggparty(t2) +
  geom_edge() +
  geom_edge_label() +
  geom_node_splitvar() +
  geom_node_plot(gglist = list(geom_bar(aes(x = "", fill = play),
                                        position = position_fill()),
                               xlab("play")),
                 # draw individual legend for each plot
                 shared_legend = FALSE
  )

Thanks to the versatility of ggplot2 we are also very flexible in creating these node plots. For example the barplot can be easily changed into a pie chart. The argument size of geom_node_plot() can be set to "nodesize" which changes the size of the node plot relative to the number of observations in the respective node.

ggparty(t2) +
  geom_edge() +
  geom_edge_label() +
  geom_node_splitvar() +
  # draw pie charts with their size relative to nodesize
  geom_node_plot(gglist = list(geom_bar(aes(x = "", fill = play),
                                        position = position_fill()),
                               coord_polar("y"),
                               theme_void()),
                 size = "nodesize")

Predictions

If the 'party' object contains a model with only one predictor, we can use the argument predict to show a prediction line. Additional arguments for the geom_line() drawing this line can be passed via predict_gpar.

So let’s take a look at this 'lmtree' containing linear models explaining eval with beauty.

data("TeachingRatings", package = "AER")
tr <- subset(TeachingRatings, credits == "more")

tr_tree <- lmtree(eval ~ beauty | minority + age + gender + division + native +
                    tenure, data = tr, weights = students, caseweights = FALSE)
ggparty(tr_tree) +
  geom_edge() +
  geom_edge_label() +
  geom_node_splitvar() +
  geom_node_plot(gglist = list(geom_point(aes(x = beauty,
                                             y = eval,
                                             col = tenure,
                                             shape = minority),
                                         alpha = 0.8),
                              theme_bw(base_size = 10)),
                shared_axis_labels = TRUE,
                legend_separator = TRUE,
                # predict based on variable
                predict = "beauty",
                # graphical parameters for geom_line of predictions
                predict_gpar = list(col = "blue",
                                   size = 1.2)
                )

In case we want to generate predictions for a more complicated model, we need to do this beforehand and pass the new data through the data argument inside geom_node_plot()’s gglist.

First the tree of class 'party' is created using the partykit infrastructure.

data("GBSG2", package = "TH.data")
GBSG2$time <- GBSG2$time/365

library("survival")
wbreg <- function(y, x, start = NULL, weights = NULL, offset = NULL, ...) {
  survreg(y ~ 0 + x, weights = weights, dist = "weibull", ...)
}


logLik.survreg <- function(object, ...)
  structure(object$loglik[2], df = sum(object$df), class = "logLik")

gbsg2_tree <- mob(Surv(time, cens) ~ horTh + pnodes | age + tsize +
                    tgrade + progrec + estrec + menostat, data = GBSG2,
                  fit = wbreg, control = mob_control(minsize = 80))

So in this case we want to create a sequence over the range of the metric variable pnodes and combine it once with the first level of the binary variable horTh and once with the second. Using this data we then (in this case) need to generate predictions of the type "quantile" with p set to 0.5. The function get_predictions() can help us with the second part since it applies a newdata function defined by us to each node and returns a suitable 'data.frame'.
If we want to use it, we need to supply the 'party' object, a function that creates the new data from each node’s data and optionally predict_arg, additional arguments to pass to the predict() call.

# function to generate newdata for predictions
generate_newdata <- function(data) {
  z <- data.frame(horTh = factor(rep(c("yes", "no"),
                                     each = length(data$pnodes))),
                  pnodes = rep(seq(from = min(data$pnodes),
                                   to = max(data$pnodes),
                                   length.out = length(data$pnodes)),
                               2))
  # wbreg() fits survreg(y ~ 0 + x), so predict() needs a matrix column named x
  z$x <- model.matrix(~ ., data = z)
  z
}

# convenience function to create dataframe for predictions
pred_df <- get_predictions(gbsg2_tree,
                           # IMPORTANT to set same ids as in geom_node_plot
                           # later used for plotting
                           ids = "terminal",
                           newdata_fun = generate_newdata,
                           predict_arg = list(type = "quantile",
                                              p = 0.5)
)
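
To see what a newdata function like the one above returns, it can be called by hand on a small data frame. The sketch below repeats generate_newdata() from the text and applies it to a hypothetical four-row node (toy_node and its values are made up for illustration; real node data would contain more columns).

```r
# Same newdata function as in the text above
generate_newdata <- function(data) {
  z <- data.frame(horTh = factor(rep(c("yes", "no"),
                                     each = length(data$pnodes))),
                  pnodes = rep(seq(from = min(data$pnodes),
                                   to = max(data$pnodes),
                                   length.out = length(data$pnodes)),
                               2))
  # design matrix stored as a matrix column named x
  z$x <- model.matrix(~ ., data = z)
  z
}

# Hypothetical node data: only pnodes is used by generate_newdata()
toy_node <- data.frame(pnodes = c(1, 3, 5, 9))
nd <- generate_newdata(toy_node)

nrow(nd)        # two copies of the pnodes sequence, one per horTh level
colnames(nd$x)  # columns of the design matrix
```

Each node thus gets a grid covering its own range of pnodes, once for each level of horTh, which is what allows one prediction line per treatment group to be drawn in every node plot.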

The 'data.frame' created this way can then be passed to any 'gg' component in geom_node_plot()’s gglist. In this case we want to draw a line for both values of horTh and separate them by color.

ggparty(gbsg2_tree, terminal_space = 0.8, horizontal = TRUE) +
  geom_edge() +
  geom_node_splitvar() +
  geom_edge_label() +
  geom_node_plot(
    gglist = list(geom_point(aes(y = `Surv(time, cens).time`,
                                 x = pnodes,
                                 col = horTh),
                             alpha = 0.6),
                  # supply pred_df as data argument of geom_line
                  geom_line(data = pred_df,
                            aes(x = pnodes,
                                y = prediction,
                                col = horTh),
                            size = 1.2),
                  theme_bw(),
                  ylab("Survival Time")
                  ),
    ids = "terminal", # not necessary since default
    shared_axis_labels = TRUE
  )

Potential Pitfalls

Combining 'gg' Components in gglist with "+"

The object passed to gglist has to be a 'list' and therefore we must not use "+" to combine the components of a geom_node_plot() but instead ",".
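
The difference can be tried out in plain ggplot2, independently of ggparty. The following sketch uses a small made-up data frame (toy is hypothetical) and relies on the standard ggplot2 behaviour that a 'list' of components can be added to a ggplot() call in one go. The list built with "," is the kind of object gglist expects.

```r
library("ggplot2")

# Hypothetical stand-in for the data of a single node
toy <- data.frame(play = c("yes", "no", "yes", "yes"))

# Components combined with "," inside list(): this is what gglist expects
components <- list(geom_bar(aes(x = "", fill = play),
                            position = position_fill()),
                   xlab("play"))

# The same plot written in the usual way, chaining components with "+"
p_plus <- ggplot(toy) +
  geom_bar(aes(x = "", fill = play), position = position_fill()) +
  xlab("play")

# Adding the whole list at once gives an equivalent plot
p_list <- ggplot(toy) + components
```

Writing the components as geom_bar(...) + xlab(...) without a leading ggplot() call does not produce such a list, which is why "+" cannot be used inside gglist.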

Passing Components at the Wrong Place

As we now know, each geom_node_plot() is basically a completely separate plot with its own arguments and specifications, which are independent of the base plot of the tree (i.e. the ggparty() call with edges, labels, etc.). For that reason, if we want, for example, to remove the legend of a geom_node_plot(), we must not pass the corresponding component at the base level (as a component of the tree) but inside the gglist of the geom_node_plot().
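
As a sketch of the legend case, the example below rebuilds the constparty tree from the beginning of this page and removes the legend of the node plots. The theme() component goes inside gglist, and shared_legend is set to FALSE as described in the section on axes and legends. It assumes that ggparty is installed and attaches ggplot2 and partykit.

```r
library("ggparty")

# Rebuild the constparty object used earlier on this page
data("WeatherPlay", package = "partykit")
sp_o <- partysplit(1L, index = 1:3)
n1 <- partynode(id = 1L, split = sp_o, kids = lapply(2L:4L, partynode))
t2 <- party(n1,
            data = WeatherPlay,
            fitted = data.frame(
              "(fitted)" = fitted_node(n1, data = WeatherPlay),
              "(response)" = WeatherPlay$play,
              check.names = FALSE),
            terms = terms(play ~ ., data = WeatherPlay))
t2 <- as.constparty(t2)

p <- ggparty(t2) +
  geom_edge() +
  geom_edge_label() +
  geom_node_splitvar() +
  geom_node_plot(gglist = list(geom_bar(aes(x = "", fill = play),
                                        position = position_fill()),
                               xlab("play"),
                               # legend is removed inside gglist,
                               # not at the level of the tree
                               theme(legend.position = "none")),
                 # required when the legend is removed altogether
                 shared_legend = FALSE)
```

Adding theme(legend.position = "none") after the closing parenthesis of geom_node_plot() would instead apply to the base plot of the tree and leave the node plots unchanged.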
