3 geom_node_plot - martin-borkovec/ggparty GitHub Wiki
If we want to plot the data contained within the individual nodes of
the tree, we need to add geom_node_plot() to our ggparty() call. To
understand why this is necessary, let’s reiterate what ggparty() does
and how it uses the ggplot() function. Every ggplot() call needs a
'data.frame', so, as we’ve seen above, ggparty() creates one from the
'party' object. In this 'data.frame' every row corresponds to a node
of the tree. Each column of a node’s data is stored as a 'list' in its
own column. In this form it is not directly usable by ggplot(), since
ggplot() can’t handle lists inside its data. This is where
geom_node_plot() comes into play: each instance of geom_node_plot()
creates a completely separate ggplot() call after transforming all the
columns containing lists of data (created by ggparty()) into a new
'data.frame' for that separate call.

All the other columns of ggparty’s 'data.frame' (like kids, parent,
etc.) get lost in this process, since usually we will not be
interested in these when plotting the node data and they could
potentially cause naming conflicts. In case we do want to use them,
there is a fairly easy way to do so. By default, then, we can access
anything that can be found in the data slot of the 'party' object, the
fitted_nodes and, if the 'party' object contains any, the
fitted.values and the residuals of the included model.
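The idea of list columns can be illustrated with a small base R sketch. The object nodes below is a made-up stand-in for the 'data.frame' that ggparty() builds, not its actual layout, and the flattening step only mimics in spirit what geom_node_plot() has to do before it can call ggplot().

```r
# Toy stand-in (hypothetical) for ggparty()'s data.frame: one row per node,
# with each node's observations stored in list columns.
nodes <- data.frame(id = c(2L, 3L))
nodes$play  <- list(c("yes", "no", "yes"), c("no", "no"))
nodes$humid <- list(c(70, 85, 65), c(90, 95))

# Flatten the list columns into one row per observation, keeping the node id.
flat <- do.call(rbind, lapply(seq_len(nrow(nodes)), function(i) {
  data.frame(id    = nodes$id[i],
             play  = nodes$play[[i]],
             humid = nodes$humid[[i]])
}))

str(nodes)  # two rows, list columns
str(flat)   # five rows, ordinary atomic columns
```

The flattened version has one row per observation and plain atomic columns, which is the shape ggplot() expects.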
Now let’s take a look at a constparty object created from the same data.
# packages used throughout this page
library("ggplot2")
library("partykit")
library("ggparty")

data("WeatherPlay", package = "partykit")
sp_o <- partysplit(1L, index = 1:3)
sp_h <- partysplit(3L, breaks = 75)
sp_w <- partysplit(4L, index = 1:2)
pn <- partynode(1L, split = sp_o, kids = list(
  partynode(2L, split = sp_h, kids = list(
    partynode(3L, info = "yes"),
    partynode(4L, info = "no"))),
  partynode(5L, info = "yes"),
  partynode(6L, split = sp_w, kids = list(
    partynode(7L, info = "yes"),
    partynode(8L, info = "no")))))
py <- party(pn, WeatherPlay)
n1 <- partynode(id = 1L, split = sp_o, kids = lapply(2L:4L, partynode))
t2 <- party(n1,
            data = WeatherPlay,
            fitted = data.frame(
              "(fitted)" = fitted_node(n1, data = WeatherPlay),
              "(response)" = WeatherPlay$play,
              check.names = FALSE),
            terms = terms(play ~ ., data = WeatherPlay)
)
t2 <- as.constparty(t2)
To visualize the distribution of the variable play we will use the
geom_node_plot() function. It allows us to show the data of each node
in its separate plot. For this to work, we have to specify the
argument gglist. Basically we have to provide a 'list' of all the 'gg'
components we would add to a ggplot() call on the data element of a
node.
ggplot(t2[2]$data) +
  geom_bar(aes(x = "", fill = play),
           position = position_fill()) +
  xlab("play")
So if we were to use the above code to create the desired plot for one
node, we can instead pass a 'list' of the two components to gglist and
geom_node_plot() will create a version of it for every specified node
(by default the terminal nodes). Keep in mind that since it’s a 'list'
we need to use "," instead of "+" to combine the components.
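The switch from "+" to "," is less arbitrary than it may look: ggplot2 itself accepts a 'list' of components on the right-hand side of "+". The sketch below uses a small made-up data frame (df is hypothetical) and only shows this ggplot2 behaviour; it makes no claim about how geom_node_plot() applies the list internally.

```r
library("ggplot2")

# Hypothetical data standing in for the data element of a single node.
df <- data.frame(play = c("yes", "no", "yes", "yes"))

# The same two components as above, collected in a list with ","
components <- list(geom_bar(aes(x = "", fill = play),
                            position = position_fill()),
                   xlab("play"))

# Adding the list is equivalent to chaining its elements with "+"
p <- ggplot(df) + components
```

A list built this way can be passed unchanged as the gglist argument.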
ggparty(t2) +
  geom_edge() +
  geom_edge_label() +
  geom_node_splitvar() +
  # pass list to gglist containing all ggplot components we want to plot for each
  # (default: terminal) node
  geom_node_plot(gglist = list(geom_bar(aes(x = "", fill = play),
                                        position = position_fill()),
                               xlab("play")))
Setting shared_axis_labels to TRUE allows us to use the space more
efficiently and legend_separator = TRUE draws a line between the tree
and the legend.
ggparty(t2) +
  geom_edge() +
  geom_edge_label() +
  geom_node_splitvar() +
  geom_node_plot(gglist = list(geom_bar(aes(x = "", fill = play),
                                        position = position_fill()),
                               xlab("play")),
                 # draw only one label for each axis
                 shared_axis_labels = TRUE,
                 # draw line between tree and legend
                 legend_separator = TRUE
  )
Setting shared_legend to FALSE draws an individual legend at each plot
instead of one common legend at the bottom of the plot. This might be
necessary if we use multiple different geom_node_plot() calls which
lead to various legends. In case we want to remove the legend
altogether (i.e. theme(legend.position = "none")), shared_legend has
to be set to FALSE.
ggparty(t2) +
  geom_edge() +
  geom_edge_label() +
  geom_node_splitvar() +
  geom_node_plot(gglist = list(geom_bar(aes(x = "", fill = play),
                                        position = position_fill()),
                               xlab("play")),
                 # draw individual legend for each plot
                 shared_legend = FALSE
  )
Thanks to the versatility of ggplot2 we are also very flexible in
creating these node plots. For example, the bar plot can easily be
changed into a pie chart. The argument size of geom_node_plot() can be
set to "nodesize", which scales each node plot according to the number
of observations in the respective node.
ggparty(t2) +
  geom_edge() +
  geom_edge_label() +
  geom_node_splitvar() +
  # draw pie charts with their size relative to nodesize
  geom_node_plot(gglist = list(geom_bar(aes(x = "", fill = play),
                                        position = position_fill()),
                               coord_polar("y"),
                               theme_void()),
                 size = "nodesize")
If the 'party' object contains a model with only one predictor, we can
set the argument predict to the name of that predictor to draw a
prediction line. Additional arguments for the geom_line() drawing this
line can be passed via predict_gpar.
So let’s take a look at this 'lmtree' containing linear models
explaining eval with beauty.
data("TeachingRatings", package = "AER")
tr <- subset(TeachingRatings, credits == "more")
tr_tree <- lmtree(eval ~ beauty | minority + age + gender + division + native +
                    tenure, data = tr, weights = students, caseweights = FALSE)
ggparty(tr_tree) +
  geom_edge() +
  geom_edge_label() +
  geom_node_splitvar() +
  geom_node_plot(gglist = list(geom_point(aes(x = beauty,
                                              y = eval,
                                              col = tenure,
                                              shape = minority),
                                          alpha = 0.8),
                               theme_bw(base_size = 10)),
                 shared_axis_labels = TRUE,
                 legend_separator = TRUE,
                 # predict based on variable
                 predict = "beauty",
                 # graphical parameters for geom_line of predictions
                 predict_gpar = list(col = "blue",
                                     size = 1.2)
  )
In case we want to generate predictions for a more complicated model,
we need to do this beforehand and pass the new data through the data
argument of a 'gg' component inside geom_node_plot()’s gglist.
First the tree of class 'party' is created using the partykit
infrastructure.
data("GBSG2", package = "TH.data")
GBSG2$time <- GBSG2$time/365
library("survival")
wbreg <- function(y, x, start = NULL, weights = NULL, offset = NULL, ...) {
  survreg(y ~ 0 + x, weights = weights, dist = "weibull", ...)
}
logLik.survreg <- function(object, ...)
  structure(object$loglik[2], df = sum(object$df), class = "logLik")
gbsg2_tree <- mob(Surv(time, cens) ~ horTh + pnodes | age + tsize +
                    tgrade + progrec + estrec + menostat, data = GBSG2,
                  fit = wbreg, control = mob_control(minsize = 80))
So in this case we want to create a sequence over the range of the
metric variable pnodes and combine it once with the first level of the
binary variable horTh and once with the second. Using this data we
then (in this case) need to generate predictions of the type
"quantile" with p set to 0.5. The function get_predictions() can help
us with the second part since it applies a newdata function defined by
us to each node and returns a suitable 'data.frame'.
If we want to use it, we need to supply the 'party' object, a function
(newdata_fun) that creates the new data from each node’s data and,
optionally, predict_arg, a 'list' of additional arguments to pass to
the predict() call.
# function to generate newdata for predictions
generate_newdata <- function(data) {
  z <- data.frame(horTh = factor(rep(c("yes", "no"),
                                     each = length(data$pnodes))),
                  pnodes = rep(seq(from = min(data$pnodes),
                                   to = max(data$pnodes),
                                   length.out = length(data$pnodes)),
                               2))
  z$x <- model.matrix(~ ., data = z)
  z
}
# convenience function to create dataframe for predictions
pred_df <- get_predictions(gbsg2_tree,
                           # IMPORTANT to set same ids as in geom_node_plot
                           # later used for plotting
                           ids = "terminal",
                           newdata_fun = generate_newdata,
                           predict_arg = list(type = "quantile",
                                              p = 0.5)
)
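To get a feel for what such a newdata function hands back, it can be run on a small made-up data set before it is used with get_predictions(). The snippet below repeats generate_newdata() so that it runs on its own with base R; toy is a hypothetical stand-in for one node’s data and is not taken from GBSG2.

```r
# Copy of generate_newdata() from above, repeated so the snippet is self-contained.
generate_newdata <- function(data) {
  z <- data.frame(horTh = factor(rep(c("yes", "no"),
                                     each = length(data$pnodes))),
                  pnodes = rep(seq(from = min(data$pnodes),
                                   to = max(data$pnodes),
                                   length.out = length(data$pnodes)),
                               2))
  # model matrix stored as a matrix column named x, matching the x in wbreg()
  z$x <- model.matrix(~ ., data = z)
  z
}

# Hypothetical node data with three observations.
toy <- data.frame(pnodes = c(1, 4, 10))
nd <- generate_newdata(toy)

nrow(nd)         # one pnodes sequence per level of horTh
levels(nd$horTh)
colnames(nd$x)
```

The result holds the pnodes sequence twice, once for each level of horTh, together with the model matrix column x.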
The 'data.frame' created this way can then be passed to any 'gg'
component in geom_node_plot()’s gglist. In this case we want to draw
a line for both values of horTh and separate them by color.
ggparty(gbsg2_tree, terminal_space = 0.8, horizontal = TRUE) +
  geom_edge() +
  geom_node_splitvar() +
  geom_edge_label() +
  geom_node_plot(
    gglist = list(geom_point(aes(y = `Surv(time, cens).time`,
                                 x = pnodes,
                                 col = horTh),
                             alpha = 0.6),
                  # supply pred_df as data argument of geom_line
                  geom_line(data = pred_df,
                            aes(x = pnodes,
                                y = prediction,
                                col = horTh),
                            size = 1.2),
                  theme_bw(),
                  ylab("Survival Time")
    ),
    ids = "terminal", # not necessary since default
    shared_axis_labels = TRUE
  )
The object passed to gglist has to be a 'list', and therefore we must
not use "+" to combine the components of a geom_node_plot() but
instead ",".
As we now know, each geom_node_plot() is basically a completely
separate plot with its own arguments and specifications which are
independent from the base plot of the tree (i.e. the ggparty() call
with edges, labels, etc.). For that reason, if, for example, we want
to remove the legend of a geom_node_plot(), we must not pass the
corresponding theme() at the base level (as a component of the tree)
but inside the gglist of the geom_node_plot().
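As a minimal sketch of this last point, the legend can be removed by putting theme(legend.position = "none") inside gglist and setting shared_legend = FALSE, as described further up. The tree t2 is rebuilt here only so the snippet runs on its own; it assumes the ggparty, partykit and ggplot2 packages are installed.

```r
library("ggplot2")
library("partykit")
library("ggparty")

# Rebuild the small constparty from above.
data("WeatherPlay", package = "partykit")
sp_o <- partysplit(1L, index = 1:3)
n1 <- partynode(id = 1L, split = sp_o, kids = lapply(2L:4L, partynode))
t2 <- as.constparty(party(
  n1,
  data = WeatherPlay,
  fitted = data.frame("(fitted)" = fitted_node(n1, data = WeatherPlay),
                      "(response)" = WeatherPlay$play,
                      check.names = FALSE),
  terms = terms(play ~ ., data = WeatherPlay)))

p <- ggparty(t2) +
  geom_edge() +
  geom_edge_label() +
  geom_node_splitvar() +
  geom_node_plot(gglist = list(geom_bar(aes(x = "", fill = play),
                                        position = position_fill()),
                               xlab("play"),
                               # legend is removed inside gglist,
                               # not at the base level of the tree plot
                               theme(legend.position = "none")),
                 # required when removing the legend
                 shared_legend = FALSE)
```

Printing p draws the tree; adding the same theme() call after ggparty(t2) instead would only affect the base plot of the tree.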