Apply function on a subset of columns (.SDcols) whilst applying a different function on another column (within groups)

StackOverflow https://stackoverflow.com/questions/20459519

  •  30-08-2022

Question

This is very similar to a question about applying a common function to multiple columns of a data.table using .SDcols, which was answered thoroughly here.

The difference is that I would like to simultaneously apply a different function on another column which is not part of the .SD subset. I post a simple example below to show my attempt to solve the problem:

dt = data.table(grp = sample(letters[1:3],100, replace = TRUE),
                v1 = rnorm(100), 
                v2 = rnorm(100), 
                v3 = rnorm(100))
sd.cols = c("v2", "v3")
dt.out = dt[, list(v1 = sum(v1),  lapply(.SD,mean)), by = grp, .SDcols = sd.cols]

Yields the following error:

Error in `[.data.table`(dt, , list(v1 = sum(v1), lapply(.SD, mean)), by = grp,  
: object 'v1' not found

Now this makes sense because the v1 column is not included in the subset of columns which must be evaluated first. So I explored further by including it in my subset of columns:

sd.cols = c("v1","v2", "v3")
dt.out = dt[, list(sum(v1), lapply(.SD,mean)), by = grp, .SDcols = sd.cols]

Now this does not cause an error but it provides an answer containing 9 rows (for 3 groups), with the sum repeated thrice in column V1 and the means for all 3 columns (as expected but not wanted) placed in V2 as shown below:

> dt.out 
   grp        V1                  V2
1:   c -1.070608 -0.0486639841313638
2:   c -1.070608  -0.178154270921521
3:   c -1.070608  -0.137625003604012
4:   b -2.782252 -0.0794929150464099
5:   b -2.782252  -0.149529237116445
6:   b -2.782252   0.199925178109264
7:   a  6.091355   0.141659419355985
8:   a  6.091355 -0.0272192037753071
9:   a  6.091355 0.00815760216214876

Workaround Solution using 2 steps

Clearly it is possible to solve the problem in multiple steps by calculating the mean by group for the subset of columns and joining it to the sum by group for the single column as follows:

dt.out1 = dt[, sum(v1), by = grp]
dt.out2 = dt[, lapply(.SD,mean), by = grp, .SDcols = sd.cols]
dt.out = merge(dt.out1, dt.out2, by = "grp")

> dt.out
   grp        V1         v2           v3
1:   a  6.091355 -0.0272192  0.008157602
2:   b -2.782252 -0.1495292  0.199925178
3:   c -1.070608 -0.1781543 -0.137625004

I'm sure it's a fairly simple thing I am missing. Thanks in advance for any guidance.

Solution

Update: Issue #495 is solved now with this recent commit, we can now do this just fine:

require(data.table) # v1.9.7+
set.seed(1L)
dt = data.table(grp = sample(letters[1:3],100, replace = TRUE),
                v1 = rnorm(100), 
                v2 = rnorm(100), 
                v3 = rnorm(100))
sd.cols = c("v2", "v3")
dt.out = dt[, list(v1 = sum(v1),  lapply(.SD,mean)), by = grp, .SDcols = sd.cols]

However, note that in this case the second column would be returned as a list column. That's because you're effectively doing list(val, list()). What you perhaps intend to do is:

dt[, c(list(v1=sum(v1)), lapply(.SD, mean)), by=grp, .SDcols = sd.cols]
#    grp        v1          v2         v3
# 1:   a -6.440273  0.16993940  0.2173324
# 2:   b  4.304350 -0.02553813  0.3381612
# 3:   c  0.377974 -0.03828672 -0.2489067
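The difference between list(val, list()) and c(list(val), list()) can be seen without data.table at all. The following is a minimal base R sketch; the values and the names total, means, nested and flat are made up for illustration and stand in for sum(v1) and lapply(.SD, mean).

```r
# Base R only: how list() and c() differ when one element is itself a list.
total <- 10                          # hypothetical stand-in for sum(v1)
means <- list(v2 = 1.5, v3 = 2.5)    # hypothetical stand-in for lapply(.SD, mean)

# list(val, list()): two elements, the second of which is itself a list
nested <- list(v1 = total, means)

# c(list(val), list()): three elements, one per intended output column
flat <- c(list(v1 = total), means)

length(nested)  # 2
length(flat)    # 3
names(flat)     # "v1" "v2" "v3"
```

Because the j expression should return one list element per output column, the flat form is the one that gives a separate v1, v2 and v3 column per group.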

See history for older answer.

OTHER TIPS

Try this:

dt[,list(sum(v1), mean(v2), mean(v3)), by=grp]

In data.table, wrapping expressions in list() in the second argument (j) lets you describe the set of columns that make up the resulting data.table.
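Unnamed list() elements get default names (V1, V2, ...) in the result. As a small sketch, assuming the same example data as in the question, the elements can be named so the output keeps the original column names; the object name named is only illustrative.

```r
library(data.table)

set.seed(1L)
dt <- data.table(grp = sample(letters[1:3], 100, replace = TRUE),
                 v1 = rnorm(100),
                 v2 = rnorm(100),
                 v3 = rnorm(100))

# Name each element of list() so the result has columns v1, v2, v3
# instead of the default V1, V2, V3.
named <- dt[, list(v1 = sum(v1), v2 = mean(v2), v3 = mean(v3)), by = grp]

names(named)  # "grp" "v1" "v2" "v3"
```

This stays readable for a handful of columns; with many columns, the .SDcols approach avoids writing each one out.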

For what it's worth, .SD can be quite slow [^1] so you may want to avoid it unless you truly need all of the data supplied in the subsetted data.table like you might for a more sophisticated function.

Another option, if you have many columns in .SDcols, would be to do the merge in one line using the data.table join syntax.

For example:

dt[, sum(v1), by=grp][dt[,lapply(.SD,mean), by=grp, .SDcols=sd.cols]]

This join syntax relies on a key, so you first need to call setkey() on your data.table so that it knows how to match rows:

setkey(dt, grp)

Then you can use the line above to produce an equivalent result.
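As an alternative sketch, and assuming a data.table version that supports the on= argument (1.9.6 or later), the same join can be written without setting a key first. The intermediate names sums, means and joined are illustrative only.

```r
library(data.table)

set.seed(1L)
dt <- data.table(grp = sample(letters[1:3], 100, replace = TRUE),
                 v1 = rnorm(100),
                 v2 = rnorm(100),
                 v3 = rnorm(100))
sd.cols <- c("v2", "v3")

# Step 1: sum of v1 per group
sums <- dt[, list(v1 = sum(v1)), by = grp]

# Step 2: mean of each .SDcols column per group
means <- dt[, lapply(.SD, mean), by = grp, .SDcols = sd.cols]

# Join the two results on grp; on= names the join column explicitly,
# so no setkey() call is needed.
joined <- sums[means, on = "grp"]

names(joined)  # "grp" "v1" "v2" "v3"
```

Naming the join column with on= keeps the original table's row order untouched, which setkey() would otherwise change by sorting dt by grp.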

[^1]: I find this to be especially true as the number of groups approaches the total number of rows. For example, this might happen where your key is an individual ID and many individuals have just one or two observations.

Licensed under: CC-BY-SA with attribution