catenets.models.jax.pseudo_outcome_nets module
Implements pseudo-outcome-based two-step nets, namely the DR-learner, the PW-learner and the RA-learner.
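For reference, the three pseudo-outcomes are usually written as follows. Here $\hat\mu_w(x)$ is the estimated outcome regression for treatment arm $w$, $\hat\pi(x)$ is the estimated (or known) propensity score and $W \in \{0, 1\}$ is the treatment indicator. This is the textbook form discussed in Curth & van der Schaar (2021); implementation details in this module, such as any propensity clipping, are not shown. In every case the second stage regresses $\tilde{Y}$ on $X$ to obtain the CATE estimate $\hat\tau(x)$.

```latex
\begin{aligned}
\tilde{Y}_{\mathrm{DR}} &= \Big(\frac{W}{\hat\pi(X)} - \frac{1-W}{1-\hat\pi(X)}\Big) Y
  + \Big(1 - \frac{W}{\hat\pi(X)}\Big)\hat\mu_1(X)
  - \Big(1 - \frac{1-W}{1-\hat\pi(X)}\Big)\hat\mu_0(X) \\
\tilde{Y}_{\mathrm{PW}} &= \Big(\frac{W}{\hat\pi(X)} - \frac{1-W}{1-\hat\pi(X)}\Big) Y \\
\tilde{Y}_{\mathrm{RA}} &= W\big(Y - \hat\mu_0(X)\big) + (1-W)\big(\hat\mu_1(X) - Y\big)
\end{aligned}
```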
- class DRNet(first_stage_strategy: str = 'T', data_split: bool = False, cross_fit: bool = False, n_cf_folds: int = 2, binary_y: bool = False, n_layers_out: int = 2, n_layers_r: int = 3, n_layers_out_t: int = 2, n_layers_r_t: int = 3, n_units_out: int = 100, n_units_r: int = 200, n_units_out_t: int = 100, n_units_r_t: int = 200, penalty_l2: float = 0.0001, penalty_l2_t: float = 0.0001, step_size: float = 0.0001, step_size_t: float = 0.0001, n_iter: int = 10000, batch_size: int = 100, n_iter_min: int = 200, val_split_prop: float = 0.3, early_stopping: bool = True, patience: int = 10, n_iter_print: int = 50, seed: int = 42, rescale_transformation: bool = False, nonlin: str = 'elu', first_stage_args: Optional[dict] = None)
Bases:
catenets.models.jax.pseudo_outcome_nets.PseudoOutcomeNet
Wrapper for DR-learner using PseudoOutcomeNet
- class PWNet(first_stage_strategy: str = 'T', data_split: bool = False, cross_fit: bool = False, n_cf_folds: int = 2, binary_y: bool = False, n_layers_out: int = 2, n_layers_r: int = 3, n_layers_out_t: int = 2, n_layers_r_t: int = 3, n_units_out: int = 100, n_units_r: int = 200, n_units_out_t: int = 100, n_units_r_t: int = 200, penalty_l2: float = 0.0001, penalty_l2_t: float = 0.0001, step_size: float = 0.0001, step_size_t: float = 0.0001, n_iter: int = 10000, batch_size: int = 100, n_iter_min: int = 200, val_split_prop: float = 0.3, early_stopping: bool = True, patience: int = 10, n_iter_print: int = 50, seed: int = 42, rescale_transformation: bool = False, nonlin: str = 'elu', first_stage_args: Optional[dict] = None)
Bases:
catenets.models.jax.pseudo_outcome_nets.PseudoOutcomeNet
Wrapper for PW-learner using PseudoOutcomeNet
- class PseudoOutcomeNet(first_stage_strategy: str = 'T', first_stage_args: Optional[dict] = None, data_split: bool = False, cross_fit: bool = False, n_cf_folds: int = 2, transformation: str = 'DR', binary_y: bool = False, n_layers_out: int = 2, n_layers_r: int = 3, n_layers_out_t: int = 2, n_layers_r_t: int = 3, n_units_out: int = 100, n_units_r: int = 200, n_units_out_t: int = 100, n_units_r_t: int = 200, penalty_l2: float = 0.0001, penalty_l2_t: float = 0.0001, step_size: float = 0.0001, step_size_t: float = 0.0001, n_iter: int = 10000, batch_size: int = 100, n_iter_min: int = 200, val_split_prop: float = 0.3, early_stopping: bool = True, patience: int = 10, n_iter_print: int = 50, seed: int = 42, rescale_transformation: bool = False, nonlin: str = 'elu')
Bases:
catenets.models.jax.base.BaseCATENet
Class implementing two-step learners based on pseudo-outcome regression, as discussed in Curth & van der Schaar (2021): the RA-learner, the PW-learner and the DR-learner
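The three transformations can be sketched in NumPy as below, assuming the first-stage estimates mu0, mu1 and the propensities pi are already available as arrays. pseudo_outcome is a hypothetical helper, not part of catenets; the library's own transformation functions may differ in details such as clipping propensities away from 0 and 1.

```python
import numpy as np


def pseudo_outcome(y, w, mu0, mu1, pi, transformation="DR"):
    """Compute DR, PW or RA pseudo-outcomes from first-stage estimates.

    y, w, mu0, mu1 and pi are 1-D arrays of equal length; w is 0/1 and
    pi holds propensity scores strictly between 0 and 1.
    """
    y, w, mu0, mu1, pi = (np.asarray(a, dtype=float) for a in (y, w, mu0, mu1, pi))
    w1 = w / pi                  # inverse-propensity weight for treated units
    w0 = (1.0 - w) / (1.0 - pi)  # inverse-propensity weight for control units
    if transformation == "DR":
        # AIPW / doubly robust: weighted outcome plus regression correction
        return (w1 - w0) * y + (1.0 - w1) * mu1 - (1.0 - w0) * mu0
    if transformation == "PW":
        # Propensity weighting (Horvitz-Thompson style)
        return (w1 - w0) * y
    if transformation == "RA":
        # Regression adjustment: impute the unobserved potential outcome
        return w * (y - mu0) + (1.0 - w) * (mu1 - y)
    raise ValueError(f"unknown transformation: {transformation!r}")
```

For a treated unit with y = 3, mu0 = 1, mu1 = 2 and pi = 0.5, the DR, PW and RA pseudo-outcomes are 3, 6 and 2 respectively.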
- Parameters
first_stage_strategy (str, default 'T') – Which nuisance estimator to use in the first stage
first_stage_args (dict, optional) – Additional arguments to pass to the first-stage training function
data_split (bool, default False) – Whether to split the data into two folds, so that the first-stage nuisance models and the second-stage regression are fitted on different samples
cross_fit (bool, default False) – Whether to perform cross-fitting
n_cf_folds (int, default 2) – Number of cross-fitting folds to use
transformation (str, default 'DR') – Pseudo-outcome to use ('DR' for the DR-learner, 'PW' for the PW-learner, 'RA' for the RA-learner)
binary_y (bool, default False) – Whether the outcome is binary
n_layers_out (int) – First stage: number of hypothesis layers (n_layers_out x n_units_out + 1 x Dense layer)
n_units_out (int) – First stage: number of hidden units in each hypothesis layer
n_layers_r (int) – First stage: number of representation layers before the hypothesis layers (the distinction between hypothesis and representation layers is made to match TARNet & SNets)
n_units_r (int) – First stage: number of hidden units in each representation layer
n_layers_out_t (int) – Second stage: number of hypothesis layers (n_layers_out_t x n_units_out_t + 1 x Dense layer)
n_units_out_t (int) – Second stage: number of hidden units in each hypothesis layer
n_layers_r_t (int) – Second stage: number of representation layers before the hypothesis layers (the distinction between hypothesis and representation layers is made to match TARNet & SNets)
n_units_r_t (int) – Second stage: number of hidden units in each representation layer
penalty_l2 (float) – First stage: L2 (ridge) penalty
penalty_l2_t (float) – Second stage: L2 (ridge) penalty
step_size (float) – First stage: learning rate for the optimizer
step_size_t (float) – Second stage: learning rate for the optimizer
n_iter (int) – Maximum number of iterations
batch_size (int) – Batch size
val_split_prop (float) – Proportion of samples used for validation split (can be 0)
early_stopping (bool, default True) – Whether to use early stopping
patience (int) – Number of iterations to wait without an improvement in validation loss before stopping early
n_iter_min (int) – Minimum number of iterations to run before early stopping can be triggered
n_iter_print (int) – Number of iterations between printed progress updates
seed (int) – Random seed
nonlin (str, default 'elu') – Nonlinearity to use in the neural networks
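Read literally, the layer parameters above describe each network as n_layers_r representation layers of width n_units_r, followed by n_layers_out hypothesis layers of width n_units_out and a single Dense output layer. The sketch below assumes that reading and ignores nonlinearities and the fact that some first-stage strategies use more than one output head; layer_widths and n_params are hypothetical helpers. Pass the _t parameters instead to size a second-stage network.

```python
def layer_widths(n_layers_r=3, n_units_r=200, n_layers_out=2, n_units_out=100):
    """Widths of the Dense layers in one network (input layer excluded)."""
    # representation block, then hypothesis block, then one scalar output unit
    return [n_units_r] * n_layers_r + [n_units_out] * n_layers_out + [1]


def n_params(n_features, widths):
    """Count weights and biases of a fully connected stack of Dense layers."""
    total, fan_in = 0, n_features
    for width in widths:
        total += fan_in * width + width  # weight matrix plus bias vector
        fan_in = width
    return total
```

With the default sizes and 10 input covariates, this gives widths [200, 200, 200, 100, 100, 1] and 112,901 trainable parameters for a single network.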
- _get_predict_function() → Callable
- _get_train_function() → Callable
- fit(X: jax._src.basearray.Array, y: jax._src.basearray.Array, w: jax._src.basearray.Array, p: Optional[jax._src.basearray.Array] = None) → catenets.models.jax.pseudo_outcome_nets.PseudoOutcomeNet
Fit method for a CATENet. Takes covariates, outcome variable and treatment indicator as input, plus optional known treatment propensities.
- Parameters
X (pd.DataFrame or np.array) – Covariate matrix
y (np.array) – Outcome vector
w (np.array) – Treatment indicator
p (np.array) – Vector of (known) treatment propensities. Currently only supported for TwoStepNets.
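As an illustration of the expected input shapes, the sketch below simulates a randomized experiment in which the propensities p are known exactly, which is the situation where passing p makes sense. The data-generating process, including the true effect tau, is invented for illustration only, and the fit call appears only as a comment.

```python
import numpy as np

rng = np.random.default_rng(42)
n, d = 500, 5

X = rng.normal(size=(n, d))   # covariate matrix, shape (n, d)
p = np.full(n, 0.5)           # known propensities: a 50/50 randomized trial
w = rng.binomial(1, p)        # treatment indicator, shape (n,), values in {0, 1}
tau = X[:, 0]                 # illustrative true CATE, used only to simulate y
y = X[:, 1] + w * tau + rng.normal(scale=0.1, size=n)  # outcome vector, shape (n,)

# With catenets installed, these arrays would be passed as
#   DRNet().fit(X, y, w, p=p)
```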
- predict(X: jax._src.basearray.Array, return_po: bool = False, return_prop: bool = False) → jax._src.basearray.Array
Predict treatment effect estimates using a CATENet. Depending on the method, it can also return potential outcome estimates and propensity score estimates.
- Parameters
X (pd.DataFrame or np.array) – Covariate matrix
return_po (bool, default False) – Whether to return potential outcome estimate
return_prop (bool, default False) – Whether to return propensity estimate
- Returns
array of CATE estimates, optionally also potential outcomes and propensity
- Return type
jax._src.basearray.Array
- class RANet(first_stage_strategy: str = 'T', data_split: bool = False, cross_fit: bool = False, n_cf_folds: int = 2, binary_y: bool = False, n_layers_out: int = 2, n_layers_r: int = 3, n_layers_out_t: int = 2, n_layers_r_t: int = 3, n_units_out: int = 100, n_units_r: int = 200, n_units_out_t: int = 100, n_units_r_t: int = 200, penalty_l2: float = 0.0001, penalty_l2_t: float = 0.0001, step_size: float = 0.0001, step_size_t: float = 0.0001, n_iter: int = 10000, batch_size: int = 100, n_iter_min: int = 200, val_split_prop: float = 0.3, early_stopping: bool = True, patience: int = 10, n_iter_print: int = 50, seed: int = 42, rescale_transformation: bool = False, nonlin: str = 'elu', first_stage_args: Optional[dict] = None)
Bases:
catenets.models.jax.pseudo_outcome_nets.PseudoOutcomeNet
Wrapper for RA-learner using PseudoOutcomeNet
- _train_and_predict_first_stage(X: jax._src.basearray.Array, y: jax._src.basearray.Array, w: jax._src.basearray.Array, fit_mask: jax._src.basearray.Array, pred_mask: jax._src.basearray.Array, first_stage_strategy: str, binary_y: bool = False, n_layers_out: int = 2, n_layers_r: int = 3, n_units_out: int = 100, n_units_r: int = 200, penalty_l2: float = 0.0001, step_size: float = 0.0001, n_iter: int = 10000, batch_size: int = 100, val_split_prop: float = 0.3, early_stopping: bool = True, patience: int = 10, n_iter_min: int = 200, n_iter_print: int = 50, seed: int = 42, nonlin: str = 'elu', avg_objective: bool = False, transformation: str = 'DR', first_stage_args: Optional[dict] = None) → Tuple
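The fit_mask and pred_mask arguments follow the usual cross-fitting pattern: first-stage models are trained on one subset of samples and used to predict nuisance values on a disjoint subset. The generator below sketches how such mask pairs could be built for n_cf_folds folds; its name and the random, balanced fold assignment are assumptions, not the library's implementation.

```python
import numpy as np


def cross_fit_masks(n, n_folds=2, seed=42):
    """Yield (fit_mask, pred_mask) boolean pairs for K-fold cross-fitting.

    Each sample lands in exactly one prediction fold; the nuisance models
    for that fold are fitted on all remaining samples.
    """
    rng = np.random.default_rng(seed)
    fold_id = rng.permutation(n) % n_folds  # balanced random fold labels
    for k in range(n_folds):
        pred_mask = fold_id == k
        yield ~pred_mask, pred_mask
```

Roughly, data_split=True without cross-fitting corresponds to using a single such pair, whereas cross-fitting loops over all pairs so that every sample receives a pseudo-outcome from nuisance models that never saw it.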
- train_pseudooutcome_net(X: jax._src.basearray.Array, y: jax._src.basearray.Array, w: jax._src.basearray.Array, p: Optional[jax._src.basearray.Array] = None, first_stage_strategy: str = 'T', data_split: bool = False, cross_fit: bool = False, n_cf_folds: int = 2, transformation: str = 'DR', binary_y: bool = False, n_layers_out: int = 2, n_layers_r: int = 3, n_layers_r_t: int = 3, n_layers_out_t: int = 2, n_units_out: int = 100, n_units_r: int = 200, n_units_out_t: int = 100, n_units_r_t: int = 200, penalty_l2: float = 0.0001, penalty_l2_t: float = 0.0001, step_size: float = 0.0001, step_size_t: float = 0.0001, n_iter: int = 10000, batch_size: int = 100, val_split_prop: float = 0.3, early_stopping: bool = True, patience: int = 10, n_iter_min: int = 200, n_iter_print: int = 50, seed: int = 42, rescale_transformation: bool = False, return_val_loss: bool = False, nonlin: str = 'elu', avg_objective: bool = True, first_stage_args: Optional[dict] = None) → Tuple
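The overall flow of a pseudo-outcome learner (first stage, pseudo-outcome, second stage) can be illustrated with a toy DR-learner in which every neural network is replaced by ordinary least squares. This is a deliberately simplified sketch: the linear models, the single 50/50 sample split (mirroring data_split=True) and the requirement of known propensities p are assumptions made to keep it short, and the helper names are hypothetical.

```python
import numpy as np


def _ols(X, y):
    """Least-squares fit with intercept; returns a predict function."""
    A = np.column_stack([np.ones(len(X)), X])
    coef, *_ = np.linalg.lstsq(A, y, rcond=None)
    return lambda Z: np.column_stack([np.ones(len(Z)), Z]) @ coef


def two_step_dr_linear(X, y, w, p, seed=0):
    """Toy DR-learner: linear nuisances, sample splitting, linear second stage."""
    n = len(y)
    rng = np.random.default_rng(seed)
    fit = rng.permutation(n) < n // 2  # first half: nuisance estimation
    pred = ~fit                        # second half: pseudo-outcome regression
    # Stage 1: one outcome regression per treatment arm (T-learner style)
    mu0 = _ols(X[fit & (w == 0)], y[fit & (w == 0)])
    mu1 = _ols(X[fit & (w == 1)], y[fit & (w == 1)])
    # DR pseudo-outcome on the held-out half, using the known propensities p
    Xp, yp, wp, pp = X[pred], y[pred], w[pred], p[pred]
    m0, m1 = mu0(Xp), mu1(Xp)
    w1, w0 = wp / pp, (1 - wp) / (1 - pp)
    pseudo = (w1 - w0) * yp + (1 - w1) * m1 - (1 - w0) * m0
    # Stage 2: regress the pseudo-outcome on covariates to get the CATE model
    return _ols(Xp, pseudo)
```

When the outcome models are correctly specified and the true effect is linear, the returned predictor should recover the effect function up to sampling noise; the neural-network version trades this simplicity for flexibility.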