CausalForest

The CausalForest class is initialized with parameters that control the behavior of the forest and of its trees. Randomness is made reproducible through a seed counter: whenever randomness is needed, the current seed is used and then incremented, starting from the initial value seed_counter, which defaults to 1.
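As a minimal sketch of this seeding scheme, each request for randomness consumes the current seed and advances the counter. The SeedCounter class and its rng method are hypothetical (not part of cforest), and the standard-library random module stands in for whatever generator cforest actually uses.

```python
import random


class SeedCounter:
    """Hypothetical helper (not cforest API): hands out consecutive seeds."""

    def __init__(self, seed_counter: int = 1) -> None:
        self.next_seed = seed_counter

    def rng(self) -> random.Random:
        # Consume the current seed and advance the counter, so every source of
        # randomness gets its own reproducible generator.
        seed = self.next_seed
        self.next_seed += 1
        return random.Random(seed)


counter = SeedCounter(seed_counter=1)
bootstrap_rng = counter.rng()  # seeded with 1
feature_rng = counter.rng()    # seeded with 2
```

Because the seeds depend only on the starting value and the order of requests, rerunning with the same seed_counter reproduces every draw.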

Fitting and prediction can be parallelized with the num_workers argument (default 1), which distributes the work over processes using joblib.
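One way to keep parallel results reproducible is to fix each tree's seed before the work is dispatched, so the output does not depend on how many workers run or in which order they finish. The sketch below assumes this design, which is not confirmed to be cforest's. fit_tree and fit_trees are hypothetical, and concurrent.futures.ThreadPoolExecutor stands in for joblib's process-based Parallel(n_jobs=num_workers)(delayed(fit_tree)(seed) for seed in seeds) to keep the example dependency-free.

```python
import random
from concurrent.futures import ThreadPoolExecutor


def fit_tree(seed: int) -> float:
    """Hypothetical stand-in for fitting one tree; all randomness uses `seed`."""
    rng = random.Random(seed)
    return sum(rng.random() for _ in range(10))  # placeholder for a fitted tree


def fit_trees(num_trees: int, num_workers: int = 1, seed_counter: int = 1) -> list:
    # Assumed design: one seed per tree, assigned up front
    # (seed_counter, seed_counter + 1, ...).
    seeds = range(seed_counter, seed_counter + num_trees)
    if num_workers == 1:
        return [fit_tree(seed) for seed in seeds]
    # Threads stand in for joblib's worker processes; pool.map returns results
    # in input order, regardless of completion order.
    with ThreadPoolExecutor(max_workers=num_workers) as pool:
        return list(pool.map(fit_tree, seeds))
```

For example, fit_trees(8, num_workers=1) and fit_trees(8, num_workers=4) return identical lists.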

class cforest.forest.CausalForest(num_trees, split_ratio, min_leaf, max_depth, use_transformed_outcomes, num_workers=1, seed_counter=1)

Estimator class to fit a causal forest.

Estimator class to fit a causal forest on numerical data given hyperparameters set in forest_params and tree_params. Provides methods to fit the model, predict using the fitted model, save the fitted model and load a fitted model.

Note that the structure of this estimator is based on sklearn's BaseEstimator and RegressorMixin. However, the estimator predicts treatment effects, which are unobservable: because the true effect is never observed, a loss can never be computed directly on observed data. Regular model validation and model selection techniques (e.g. cross-validated grid search) therefore do not work, and a tighter integration into the sklearn workflow is unlikely for now.
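The difficulty can be made concrete with simulated potential outcomes. In the minimal sketch below, the data-generating process and the functions simulate and mse_against_truth are hypothetical. The true effect tau is available only because the data are simulated. An observed data set contains x, t and y but never tau, so the same loss cannot be computed on real data.

```python
import random


def simulate(n: int, seed: int = 1) -> list:
    """Hypothetical DGP: draw both potential outcomes, reveal only one as y."""
    rng = random.Random(seed)
    units = []
    for _ in range(n):
        x = rng.uniform(0, 1)
        tau = 2 * x                      # true effect, known only in simulation
        y0 = x + rng.gauss(0, 0.1)       # untreated potential outcome
        y1 = y0 + tau                    # treated potential outcome
        t = int(rng.random() < 0.5)      # randomized treatment
        units.append({"x": x, "t": t, "y": y1 if t else y0, "tau": tau})
    return units


def mse_against_truth(units: list, predict_effect) -> float:
    """Mean squared error of predicted effects; requires the unobservable tau."""
    return sum((predict_effect(u["x"]) - u["tau"]) ** 2 for u in units) / len(units)


units = simulate(1000)
observed = [{k: u[k] for k in ("x", "t", "y")} for u in units]  # what real data offers
```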

Attributes:
forestparams (dict):

Hyperparameters for the forest. Includes 'num_trees' (int) and 'split_ratio' (in [0, 1]). Example: forestparams = {'num_trees': 100, 'split_ratio': 0.7}

treeparams (dict):

Parameters for the trees. Includes 'min_leaf' (int) and 'max_depth' (int). Example: treeparams = {'min_leaf': 5, 'max_depth': 25}

num_workers (int):

Number of workers to use for the parallelization.

_is_fitted (bool):

True if the fit method was called or a fitted model was loaded using the load method and False otherwise.

num_features (int):

Number of features in the design matrix that was used to fit the model. Set to None if the forest has not been fitted.

fitted_model (pd.DataFrame):

Data frame representing the fitted model; see the function _assert_df_is_valid_cforest for how a causal forest model is represented as a data frame.

seed_counter (int):

Initial value of the seed counter.

__init__(num_trees, split_ratio, min_leaf, max_depth, use_transformed_outcomes, num_workers=1, seed_counter=1)

Initializes CausalForest estimator with hyperparameters.

Initializes the CausalForest estimator with hyperparameters for the forest, forest_params, which contains the number of trees to fit and the ratio of features randomly considered at each split, and with hyperparameters for the individual trees, tree_params, which contains the minimum number of observations in a leaf node and the maximum depth of a tree.
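A minimal sketch of how the flat constructor arguments could be grouped into the two dictionaries described under Attributes. group_params is hypothetical, and every range check other than split_ratio in [0, 1] is an assumption rather than documented behavior.

```python
def group_params(num_trees: int, split_ratio: float, min_leaf: int, max_depth: int):
    """Hypothetical helper: bundle constructor arguments into forestparams/treeparams."""
    if not isinstance(num_trees, int) or num_trees < 1:
        raise ValueError("num_trees must be a positive integer")
    if not 0 <= split_ratio <= 1:
        raise ValueError("split_ratio must lie in [0, 1]")
    # Assumed checks (not documented): positive leaf size and depth.
    if not isinstance(min_leaf, int) or min_leaf < 1:
        raise ValueError("min_leaf must be a positive integer")
    if not isinstance(max_depth, int) or max_depth < 1:
        raise ValueError("max_depth must be a positive integer")
    forestparams = {"num_trees": num_trees, "split_ratio": split_ratio}
    treeparams = {"min_leaf": min_leaf, "max_depth": max_depth}
    return forestparams, treeparams


forestparams, treeparams = group_params(100, 0.7, 5, 25)
```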

Args:
num_trees (int):

Number of (causal) trees to use in the forest.

split_ratio (float):

Ratio of features to (randomly) consider for each tree. Has to be in the range [0, 1].

min_leaf (int):

Minimum number of observations of each type (treated, untreated) allowed in a leaf.

max_depth (int):

Maximum depth a single tree is allowed to grow.

use_transformed_outcomes (bool):

Whether the transformed outcomes should be used to evaluate the quality of splits when building a tree.

num_workers (int):

Number of workers for parallelization.

seed_counter (int):

Initial value of the seed counter.
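To illustrate how min_leaf and use_transformed_outcomes might enter split selection, the sketch below rejects any candidate split that leaves fewer than min_leaf treated or untreated observations in a child. Otherwise it scores the split by the squared error of the transformed outcome y* = y (t - e) / (e (1 - e)) around each child's mean (lower is better). The sketch makes several assumptions:

- the propensity score e is known (0.5 in a balanced randomized experiment);
- only the use_transformed_outcomes=True branch is shown;
- transformed_outcome and split_loss are hypothetical, and whether cforest uses exactly this criterion is not confirmed.

```python
def transformed_outcome(y: float, t: int, e: float) -> float:
    """y* = y * (t - e) / (e * (1 - e)); unbiased for the effect if e is the true propensity."""
    return y * (t - e) / (e * (1 - e))


def split_loss(y, t, go_left, min_leaf: int, e: float = 0.5):
    """Hypothetical criterion: SSE of y* around child means, or None if inadmissible."""
    children = ([], [])
    for yi, ti, left in zip(y, t, go_left):
        children[0 if left else 1].append((yi, ti))
    # Admissibility: each child needs at least min_leaf treated and untreated units.
    for child in children:
        n_treated = sum(ti for _, ti in child)
        if not child or n_treated < min_leaf or len(child) - n_treated < min_leaf:
            return None
    loss = 0.0
    for child in children:
        ystar = [transformed_outcome(yi, ti, e) for yi, ti in child]
        mean = sum(ystar) / len(ystar)
        loss += sum((v - mean) ** 2 for v in ystar)
    return loss
```

With e equal to the sample share of treated units, the mean of y* equals the difference in mean outcomes between treated and untreated units.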

fit(X, t, y)

Fits Causal Forest on supplied data.

Fits a Causal Forest on outcomes y with treatment status t and features X, if data has no missing values and is of consistent shape.

Args:
X (pd.DataFrame or np.ndarray):

Data on features.

t (pd.Series or np.ndarray):

Data on treatment status.

y (pd.Series or np.ndarray):

Data on outcomes.

Returns:
self:

The fitted regressor.

Raises:
  • TypeError, if the data are not of the types listed under Args.

  • ValueError, if data has inconsistent shapes.
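A minimal sketch of the input checks listed under Raises, assuming NumPy is available. validate_fit_data is hypothetical, and pandas objects are accepted by duck typing on their to_numpy method. The ValueError for missing values is an assumption based on the description above rather than a documented exception.

```python
import numpy as np


def validate_fit_data(X, t, y):
    """Hypothetical check mirroring fit(): types, shapes and missing values."""
    arrays = []
    for name, obj in (("X", X), ("t", t), ("y", y)):
        if hasattr(obj, "to_numpy"):  # pd.DataFrame / pd.Series
            obj = obj.to_numpy()
        if not isinstance(obj, np.ndarray):
            raise TypeError(f"{name} must be a pd.DataFrame, pd.Series or np.ndarray")
        arrays.append(obj.astype(float))
    X, t, y = arrays
    if X.ndim != 2 or t.ndim != 1 or y.ndim != 1:
        raise ValueError("X must be 2-D and t, y must be 1-D")
    if not len(X) == len(t) == len(y):
        raise ValueError("X, t and y must have the same number of rows")
    # Assumed check: missing values are rejected with a ValueError.
    if np.isnan(X).any() or np.isnan(t).any() or np.isnan(y).any():
        raise ValueError("data must not contain missing values")
    return X, t, y
```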

predict(X, num_workers=None)

Predicts treatment effects for new features X.

If the regressor has been fitted, predicts treatment effects for new features X, provided X is an np.ndarray or pd.DataFrame of the correct shape.

Args:
X (pd.DataFrame or np.ndarray):

Data on new features.

num_workers (int):

Number of workers for parallelization. Defaults to the number passed to the init method.

Returns:
predictions (np.ndarray):

Predictions per row of X.
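A minimal sketch of the documented fallback for num_workers. It assumes that the forest's prediction is the average of the per-tree predictions, as in standard random forests; this is not confirmed for cforest. ForestSketch is hypothetical, its trees are plain callables, the ValueError for an unfitted model is an assumed choice, and a thread pool stands in for joblib.

```python
from concurrent.futures import ThreadPoolExecutor

import numpy as np


class ForestSketch:
    """Hypothetical forest holding per-tree effect predictors (callables on X)."""

    def __init__(self, trees, num_workers: int = 1) -> None:
        self.trees = list(trees)
        self.num_workers = num_workers
        self._is_fitted = bool(self.trees)

    def predict(self, X, num_workers=None) -> np.ndarray:
        if not self._is_fitted:
            # Assumed exception type; the documentation does not specify one.
            raise ValueError("fit or load a model before calling predict")
        # Fall back to the worker count given at construction, as documented.
        num_workers = self.num_workers if num_workers is None else num_workers
        X = np.asarray(X, dtype=float)
        with ThreadPoolExecutor(max_workers=num_workers) as pool:
            per_tree = list(pool.map(lambda tree: tree(X), self.trees))
        # Assumed aggregation: average the trees' predictions row by row.
        return np.mean(per_tree, axis=0)
```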

save(filename, overwrite=True)

Save fitted model as a csv file.

Args:
filename (str):

Full path, including the file name, where the fitted model is saved.

overwrite (bool):

Overwrite an existing file if True; otherwise do nothing.

Returns:

None
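A minimal sketch of the overwrite behavior, assuming the fitted model can be written row by row. To stay self-contained, it represents the model table as a list of dicts and uses the standard-library csv module; with pandas, the fitted_model data frame would be written with DataFrame.to_csv instead. save_rows is hypothetical.

```python
import csv
import os


def save_rows(rows, filename, overwrite=True):
    """Hypothetical stand-in for save(): write a list of dicts to a CSV file.

    An existing file is replaced only if overwrite is True; otherwise the call
    silently does nothing, as documented for save().
    """
    if os.path.exists(filename) and not overwrite:
        return None
    with open(filename, "w", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=list(rows[0]))
        writer.writeheader()
        writer.writerows(rows)
    return None
```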

load(filename, overwrite_fitted_model=False)

Load fitted model from disk.

Args:
filename (str):

Full path, including the file name, from which the fitted model is loaded.

overwrite_fitted_model (bool):

Overwrite self.fitted_model if True and do nothing otherwise.

Returns:

self: the fitted regressor.
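A minimal sketch of one reading of overwrite_fitted_model: an estimator without a model always loads, while an already fitted estimator replaces its model only when the flag is True. This interpretation is an assumption, as are the LoaderSketch class and its list-of-dicts model read with the standard-library csv module. According to the Attributes above, cforest stores fitted_model as a pd.DataFrame.

```python
import csv


class LoaderSketch:
    """Hypothetical holder mimicking the fitted_model / _is_fitted bookkeeping."""

    def __init__(self) -> None:
        self.fitted_model = None
        self._is_fitted = False

    def load(self, filename, overwrite_fitted_model=False):
        # Assumed reading of the docs: keep an existing model unless the flag is set.
        if self._is_fitted and not overwrite_fitted_model:
            return self
        with open(filename, newline="") as f:
            self.fitted_model = list(csv.DictReader(f))
        self._is_fitted = True
        return self
```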
