R6 class to represent data to be used in estimating a model

Details

This class provides consistent names and a consistent interface for data that will be used to fit a supervised regression or classification model.

Public fields

label

The labels for the eventual model as a vector.

features

The matrix representation of the data to be used for model fitting. Constructed using stats::model.matrix.

model_frame

The data-frame representation of the data as constructed by stats::model.frame.

split_id

The split identifiers as a vector.

num_splits

The integer number of splits in the data.

cluster

A cluster ID as a vector, constructed using the unit identifiers.

weights

The case-weights as a vector.
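
To make the relationship between the features and model_frame fields concrete, the sketch below builds both representations by hand with stats::model.frame and stats::model.matrix. The toy data, the formula, and the variable names are assumptions made for illustration; how the class itself calls these functions (for example, whether it keeps the intercept column) is not stated in this documentation.

```r
# Toy data; column names are illustrative only.
set.seed(1)
toy <- data.frame(
    y = rnorm(6),
    x1 = rnorm(6),
    x3 = factor(c(1, 2, 3, 1, 2, 3))
)

# Data-frame representation: label and feature columns, factors kept as factors.
mf <- stats::model.frame(y ~ x1 + x3, data = toy)

# Matrix representation: factors are expanded into indicator columns.
X <- stats::model.matrix(attr(mf, "terms"), data = mf)

# The label as a plain numeric vector.
label <- stats::model.response(mf)

colnames(X)
dim(X)
```

With the default treatment contrasts, the three-level factor x3 contributes two indicator columns to the matrix, while the model frame keeps it as a single factor column.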

Methods


Method new()

Creates an R6 object to represent data to be used in a prediction model.

Usage

Model_data$new(data, label_col, ..., .weight_col = NULL)

Arguments

data

The full dataset to populate the class with.

label_col

The unquoted name of the column to use as the label in supervised learning models.

...

The unquoted names of features to use in the model.

.weight_col

The unquoted name of the column to use as case-weights in subsequent models.

Returns

A Model_data object.
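
Because label_col, the feature names in ..., and .weight_col are all passed unquoted, the constructor has to capture the expressions rather than evaluate them. The following is a minimal base-R sketch of that idea; capture_cols is a hypothetical helper written for this illustration, and the package may well use a different mechanism (such as tidy evaluation) internally.

```r
# Hypothetical helper: turn unquoted column names into character strings.
capture_cols <- function(data, label_col, ..., .weight_col = NULL) {
    # substitute() returns the unevaluated expression supplied by the caller.
    label <- deparse(substitute(label_col))
    # Capture every expression in `...` and drop the leading `list` symbol.
    feature_exprs <- as.list(substitute(list(...)))[-1]
    features <- vapply(feature_exprs, deparse, character(1))
    # When .weight_col is not supplied, the captured expression is NULL.
    weight_expr <- substitute(.weight_col)
    weight <- if (is.null(weight_expr)) NULL else deparse(weight_expr)
    list(label = label, features = features, weight = weight)
}

toy <- data.frame(y = 1:3, x1 = 4:6, x2 = 7:9, w = c(1, 1, 2))
cols <- capture_cols(toy, y, x1, x2, .weight_col = w)
cols$label
cols$features
cols$weight
```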

Examples

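library("dplyr")
# Simulate 100 units with two numeric features and one categorical feature.
df <- dplyr::tibble(
    uid = 1:100,
    x1 = rnorm(100),
    x2 = rnorm(100),
    x3 = sample(4, 100, replace = TRUE)
) %>% dplyr::mutate(
    # Build the label while x3 is still numeric, then convert x3 to a factor.
    y = x1 + x2 + x3 + rnorm(100),
    x3 = factor(x3)
)
# Assign units to 5 splits, using uid as the unit identifier.
df <- make_splits(df, uid, .num_splits = 5)
# Use y as the label and x1, x2, x3 as the features.
data <- Model_data$new(df, y, x1, x2, x3)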


Method SL_cv_control()

A helper method that creates the cross-validation control options to be used by SuperLearner.

Usage

Model_data$SL_cv_control()
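
As a rough illustration of what such options can look like, the sketch below turns a vector of split identifiers into a list with the element names that SuperLearner's cvControl argument accepts (V, stratifyCV, shuffle, validRows). This is an assumption about the general shape only: the return value of SL_cv_control() is not documented here, and the split_id vector and the chosen settings are made up for the example.

```r
# Illustrative split identifiers for 20 rows in 5 splits (assumed data).
split_id <- rep(1:5, times = 4)

# One element per split, holding the row numbers used for validation.
valid_rows <- split(seq_along(split_id), split_id)

# A plain list using the element names of SuperLearner's cvControl.
cv_control <- list(
    V = length(valid_rows),   # number of folds
    stratifyCV = FALSE,       # folds are fixed in advance, so no stratification
    shuffle = FALSE,          # keep the rows in their given order
    validRows = valid_rows    # pre-specified validation rows for each fold
)

cv_control$V
cv_control$validRows[["1"]]
```

Supplying validRows keeps the folds used inside SuperLearner aligned with the splits already assigned to the data, rather than letting new folds be drawn at random.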


Method clone()

The objects of this class are cloneable with this method.

Usage

Model_data$clone(deep = FALSE)

Arguments

deep

Whether to make a deep clone.
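
The difference between a shallow and a deep clone only shows up when a field holds another object with reference semantics. The sketch below uses two toy R6 classes, Inner and Outer, which are hypothetical and unrelated to Model_data, to show the general R6 behaviour.

```r
library(R6)

# Toy classes for illustration only.
Inner <- R6Class("Inner", public = list(value = 1))
Outer <- R6Class("Outer", public = list(
    inner = NULL,
    initialize = function() {
        self$inner <- Inner$new()
    }
))

a <- Outer$new()
shallow <- a$clone()            # copies the reference to `inner`
deep <- a$clone(deep = TRUE)    # also clones the nested R6 object

# Modify the nested object through the original.
a$inner$value <- 2

shallow$inner$value  # reflects the change: the inner object is shared
deep$inner$value     # unchanged: the deep clone has its own copy
```

Fields that hold ordinary vectors, matrices, or data frames follow R's usual copy-on-modify semantics, so deep = TRUE matters mainly for fields that contain R6 objects.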
