catenets.models.jax package

JAX-based implementations for the CATE estimators.

class DRNet(first_stage_strategy: str = 'T', data_split: bool = False, cross_fit: bool = False, n_cf_folds: int = 2, binary_y: bool = False, n_layers_out: int = 2, n_layers_r: int = 3, n_layers_out_t: int = 2, n_layers_r_t: int = 3, n_units_out: int = 100, n_units_r: int = 200, n_units_out_t: int = 100, n_units_r_t: int = 200, penalty_l2: float = 0.0001, penalty_l2_t: float = 0.0001, step_size: float = 0.0001, step_size_t: float = 0.0001, n_iter: int = 10000, batch_size: int = 100, n_iter_min: int = 200, val_split_prop: float = 0.3, early_stopping: bool = True, patience: int = 10, n_iter_print: int = 50, seed: int = 42, rescale_transformation: bool = False, nonlin: str = 'elu', first_stage_args: Optional[dict] = None)

Bases: catenets.models.jax.pseudo_outcome_nets.PseudoOutcomeNet

Wrapper for DR-learner using PseudoOutcomeNet

class DragonNet(binary_y: bool = False, n_layers_r: int = 3, n_units_r: int = 200, n_layers_out: int = 2, n_units_out: int = 100, penalty_l2: float = 0.0001, n_units_out_prop: int = 100, n_layers_out_prop: int = 0, step_size: float = 0.0001, n_iter: int = 10000, batch_size: int = 100, val_split_prop: float = 0.3, early_stopping: bool = True, patience: int = 10, n_iter_min: int = 200, n_iter_print: int = 50, seed: int = 42, reg_diff: bool = False, same_init: bool = False, penalty_diff: float = 0.0001, nonlin: str = 'elu')

Bases: catenets.models.jax.representation_nets.SNet2

Wrapper for DragonNet

class FlexTENet(binary_y: bool = False, n_layers_out: int = 2, n_units_s_out: int = 50, n_units_p_out: int = 50, n_layers_r: int = 3, n_units_s_r: int = 100, n_units_p_r: int = 100, private_out: bool = False, penalty_l2: float = 0.0001, penalty_l2_p: float = 0.0001, penalty_orthogonal: float = 0.01, step_size: float = 0.0001, n_iter: int = 10000, batch_size: int = 100, val_split_prop: float = 0.3, early_stopping: bool = True, patience: int = 10, n_iter_min: int = 200, n_iter_print: int = 50, seed: int = 42, return_val_loss: bool = False, opt: str = 'adam', shared_repr: bool = False, pretrain_shared: bool = False, same_init: bool = True, lr_scale: float = 10, normalize_ortho: bool = False)

Bases: catenets.models.jax.base.BaseCATENet

Module implements FlexTENet, an architecture for treatment effect estimation that allows for both shared and private information in each layer of the network.

Parameters
  • binary_y (bool, default False) – Whether the outcome is binary

  • n_layers_out (int) – Number of hypothesis layers (n_layers_out x n_units_out + 1 x Dense layer)

  • n_units_s_out (int) – Number of hidden units in each shared hypothesis layer

  • n_units_p_out (int) – Number of hidden units in each private hypothesis layer

  • n_layers_r (int) – Number of representation layers before hypothesis layers (distinction between hypothesis layers and representation layers is made to match TARNet & SNets)

  • n_units_s_r (int) – Number of hidden units in each shared representation layer

  • n_units_p_r (int) – Number of hidden units in each private representation layer

  • private_out (bool, default False) – Whether the final prediction layer should be fully private, or retain a shared component.

  • penalty_l2 (float) – l2 (ridge) penalty

  • penalty_l2_p (float) – l2 (ridge) penalty for private layers

  • penalty_orthogonal (float) – orthogonalisation penalty

  • step_size (float) – learning rate for optimizer

  • n_iter (int) – Maximum number of iterations

  • batch_size (int) – Batch size

  • val_split_prop (float) – Proportion of samples used for validation split (can be 0)

  • early_stopping (bool, default True) – Whether to use early stopping

  • patience (int) – Number of iterations without a decrease in validation loss to wait before early stopping

  • n_iter_min (int) – Minimum number of iterations to go through before starting early stopping

  • n_iter_print (int) – Number of iterations after which to print updates

  • seed (int) – Seed used

  • opt (str, default 'adam') – Optimizer to use, accepts ‘adam’ and ‘sgd’

  • shared_repr (bool, default False) – Whether to use a shared representation block as in TARNet

  • pretrain_shared (bool, default False) – Whether to pretrain the shared component of the network while freezing the private parameters

  • same_init (bool, default True) – Whether to use the same initialisation for all private spaces

  • lr_scale (float) – Factor by which to scale down the learning rate after unfreezing the private components of the network (only used if pretrain_shared=True)

  • normalize_ortho (bool, default False) – Whether to normalize the orthogonality penalty (by the depth of the network)

_get_predict_function() → Callable
_get_train_function() → Callable
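For intuition, the orthogonality regularization used by FlexTENet penalizes overlap between shared and private weight matrices. A minimal, library-independent numpy sketch of a Frobenius-style ('fro') penalty — `ortho_penalty` is a hypothetical helper, not part of the catenets API:

```python
import numpy as np

def ortho_penalty(W_s, W_p, normalize=False):
    # ||W_s^T W_p||_F^2: zero when the columns of the shared weight
    # matrix W_s are orthogonal to those of the private matrix W_p
    val = np.sum((W_s.T @ W_p) ** 2)
    if normalize:
        # optional rescaling, e.g. by the size of the Gram matrix
        val = val / (W_s.shape[1] * W_p.shape[1])
    return val

# orthogonal shared/private directions incur no penalty
W_s = np.array([[1.0], [0.0]])
W_p = np.array([[0.0], [1.0]])
print(ortho_penalty(W_s, W_p))       # 0.0
print(ortho_penalty(W_s, W_s) > 0)   # True
```

In FlexTENet such a penalty is applied per layer and scaled by penalty_orthogonal; normalize_ortho corresponds to rescaling the penalty (e.g. by network depth).
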
class OffsetNet(binary_y: bool = False, n_layers_r: int = 3, n_units_r: int = 200, n_layers_out: int = 2, n_units_out: int = 100, penalty_l2: float = 0.0001, penalty_l2_p: float = 0.0001, step_size: float = 0.0001, n_iter: int = 10000, batch_size: int = 100, val_split_prop: float = 0.3, early_stopping: bool = True, patience: int = 10, n_iter_min: int = 200, n_iter_print: int = 50, seed: int = 42, nonlin: str = 'elu')

Bases: catenets.models.jax.base.BaseCATENet

Module implements OffsetNet, also referred to as the ‘reparametrization approach’ or ‘hard approach’ in Curth & van der Schaar (2021); it models the POs using a shared prognostic function and an offset (the treatment effect).

Parameters
  • binary_y (bool, default False) – Whether the outcome is binary

  • n_layers_out (int) – Number of hypothesis layers (n_layers_out x n_units_out + 1 x Dense layer)

  • n_units_out (int) – Number of hidden units in each hypothesis layer

  • n_layers_r (int) – Number of representation layers before hypothesis layers (distinction between hypothesis layers and representation layers is made to match TARNet & SNets)

  • n_units_r (int) – Number of hidden units in each representation layer

  • penalty_l2 (float) – l2 (ridge) penalty

  • step_size (float) – learning rate for optimizer

  • n_iter (int) – Maximum number of iterations

  • batch_size (int) – Batch size

  • val_split_prop (float) – Proportion of samples used for validation split (can be 0)

  • early_stopping (bool, default True) – Whether to use early stopping

  • patience (int) – Number of iterations without a decrease in validation loss to wait before early stopping

  • n_iter_min (int) – Minimum number of iterations to go through before starting early stopping

  • n_iter_print (int) – Number of iterations after which to print updates

  • seed (int) – Seed used

  • penalty_l2_p (float) – l2-penalty for regularizing the offset

  • nonlin (string, default 'elu') – Nonlinearity to use in NN

_get_predict_function() → Callable
_get_train_function() → Callable
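The ‘reparametrization approach’ can be sketched in a few lines of numpy (the function names here are illustrative, not the catenets API): the prognostic function gives mu_0(x), and the offset is exactly the CATE.

```python
import numpy as np

# OffsetNet-style parametrization (sketch): mu_w(x) = prognostic(x) + w * offset(x)
def offset_predict(prognostic, offset, X, w):
    return prognostic(X) + w * offset(X)

prognostic = lambda X: 2.0 * X[:, 0]        # hypothetical baseline mu_0
offset = lambda X: 1.5 * np.ones(len(X))    # hypothetical constant effect

X = np.array([[1.0], [2.0]])
mu0 = offset_predict(prognostic, offset, X, np.zeros(2))
mu1 = offset_predict(prognostic, offset, X, np.ones(2))
print(mu1 - mu0)  # the CATE equals the offset: [1.5 1.5]
```

This is why OffsetNet regularizes the offset separately (penalty_l2_p): shrinking it pulls the estimated treatment effect toward zero without affecting the shared prognostic part.
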
class PWNet(first_stage_strategy: str = 'T', data_split: bool = False, cross_fit: bool = False, n_cf_folds: int = 2, binary_y: bool = False, n_layers_out: int = 2, n_layers_r: int = 3, n_layers_out_t: int = 2, n_layers_r_t: int = 3, n_units_out: int = 100, n_units_r: int = 200, n_units_out_t: int = 100, n_units_r_t: int = 200, penalty_l2: float = 0.0001, penalty_l2_t: float = 0.0001, step_size: float = 0.0001, step_size_t: float = 0.0001, n_iter: int = 10000, batch_size: int = 100, n_iter_min: int = 200, val_split_prop: float = 0.3, early_stopping: bool = True, patience: int = 10, n_iter_print: int = 50, seed: int = 42, rescale_transformation: bool = False, nonlin: str = 'elu', first_stage_args: Optional[dict] = None)

Bases: catenets.models.jax.pseudo_outcome_nets.PseudoOutcomeNet

Wrapper for PW-learner using PseudoOutcomeNet

class PseudoOutcomeNet(first_stage_strategy: str = 'T', first_stage_args: Optional[dict] = None, data_split: bool = False, cross_fit: bool = False, n_cf_folds: int = 2, transformation: str = 'DR', binary_y: bool = False, n_layers_out: int = 2, n_layers_r: int = 3, n_layers_out_t: int = 2, n_layers_r_t: int = 3, n_units_out: int = 100, n_units_r: int = 200, n_units_out_t: int = 100, n_units_r_t: int = 200, penalty_l2: float = 0.0001, penalty_l2_t: float = 0.0001, step_size: float = 0.0001, step_size_t: float = 0.0001, n_iter: int = 10000, batch_size: int = 100, n_iter_min: int = 200, val_split_prop: float = 0.3, early_stopping: bool = True, patience: int = 10, n_iter_print: int = 50, seed: int = 42, rescale_transformation: bool = False, nonlin: str = 'elu')

Bases: catenets.models.jax.base.BaseCATENet

Class implements the two-step learners based on pseudo-outcome regression discussed in Curth & van der Schaar (2021): the RA-learner, PW-learner and DR-learner

Parameters
  • first_stage_strategy (str, default 'T') – Which nuisance estimator to use in the first stage

  • first_stage_args (dict) – Any additional arguments to pass to first stage training function

  • data_split (bool, default False) – Whether to split the data in two folds for estimation

  • cross_fit (bool, default False) – Whether to perform cross fitting

  • n_cf_folds (int) – Number of crossfitting folds to use

  • transformation (str, default 'DR') – Pseudo-outcome transformation to use ('DR' for the DR-learner, 'PW' for the PW-learner, 'RA' for the RA-learner)

  • binary_y (bool, default False) – Whether the outcome is binary

  • n_layers_out (int) – First stage: number of hypothesis layers (n_layers_out x n_units_out + 1 x Dense layer)

  • n_units_out (int) – First stage: number of hidden units in each hypothesis layer

  • n_layers_r (int) – First stage: number of representation layers before hypothesis layers (distinction between hypothesis layers and representation layers is made to match TARNet & SNets)

  • n_units_r (int) – First stage: number of hidden units in each representation layer

  • n_layers_out_t (int) – Second stage: number of hypothesis layers (n_layers_out_t x n_units_out_t + 1 x Dense layer)

  • n_units_out_t (int) – Second stage: number of hidden units in each hypothesis layer

  • n_layers_r_t (int) – Second stage: number of representation layers before hypothesis layers (distinction between hypothesis layers and representation layers is made to match TARNet & SNets)

  • n_units_r_t (int) – Second stage: number of hidden units in each representation layer

  • penalty_l2 (float) – First stage l2 (ridge) penalty

  • penalty_l2_t (float) – Second stage l2 (ridge) penalty

  • step_size (float) – First stage learning rate for optimizer

  • step_size_t (float) – Second stage learning rate for optimizer

  • n_iter (int) – Maximum number of iterations

  • batch_size (int) – Batch size

  • val_split_prop (float) – Proportion of samples used for validation split (can be 0)

  • early_stopping (bool, default True) – Whether to use early stopping

  • patience (int) – Number of iterations without a decrease in validation loss to wait before early stopping

  • n_iter_min (int) – Minimum number of iterations to go through before starting early stopping

  • n_iter_print (int) – Number of iterations after which to print updates

  • seed (int) – Seed used

  • nonlin (string, default 'elu') – Nonlinearity to use in NN

_get_predict_function() → Callable
_get_train_function() → Callable
fit(X: jax._src.basearray.Array, y: jax._src.basearray.Array, w: jax._src.basearray.Array, p: Optional[jax._src.basearray.Array] = None) → catenets.models.jax.pseudo_outcome_nets.PseudoOutcomeNet

Fit method for a CATENet. Takes covariates, outcome variable and treatment indicator as input

Parameters
  • X (pd.DataFrame or np.array) – Covariate matrix

  • y (np.array) – Outcome vector

  • w (np.array) – Treatment indicator

  • p (np.array) – Vector of (known) treatment propensities. Currently only supported for TwoStepNets.

predict(X: jax._src.basearray.Array, return_po: bool = False, return_prop: bool = False) → jax._src.basearray.Array

Predict treatment effect estimates using a CATENet. Depending on method, can also return potential outcome estimate and propensity score estimate.

Parameters
  • X (pd.DataFrame or np.array) – Covariate matrix

  • return_po (bool, default False) – Whether to return potential outcome estimate

  • return_prop (bool, default False) – Whether to return propensity estimate

Returns

array of CATE estimates, optionally also potential outcomes and propensity

Return type

jax._src.basearray.Array

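For intuition, the second-stage regression targets can be sketched directly in numpy (the helper names are hypothetical; catenets computes these internally): the DR, PW and RA pseudo-outcomes, given fitted first-stage nuisance estimates mu0, mu1 and propensity p.

```python
import numpy as np

def dr_pseudo(y, w, p, mu0, mu1):
    # DR (doubly-robust / AIPW) pseudo-outcome
    return (mu1 - mu0
            + w * (y - mu1) / p
            - (1 - w) * (y - mu0) / (1 - p))

def pw_pseudo(y, w, p):
    # PW (propensity-weighting / Horvitz-Thompson) pseudo-outcome
    return (w / p - (1 - w) / (1 - p)) * y

def ra_pseudo(y, w, mu0, mu1):
    # RA (regression-adjustment) pseudo-outcome
    return w * (y - mu0) + (1 - w) * (mu1 - y)

y = np.array([2.0]); w = np.array([1.0]); p = np.array([0.5])
mu0 = np.array([1.0]); mu1 = np.array([2.0])
print(dr_pseudo(y, w, p, mu0, mu1))  # [1.]
```

The second-stage network then regresses the chosen pseudo-outcome on X to obtain the CATE estimate; data_split and cross_fit control whether the nuisance estimates are computed on held-out folds.
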
class RANet(first_stage_strategy: str = 'T', data_split: bool = False, cross_fit: bool = False, n_cf_folds: int = 2, binary_y: bool = False, n_layers_out: int = 2, n_layers_r: int = 3, n_layers_out_t: int = 2, n_layers_r_t: int = 3, n_units_out: int = 100, n_units_r: int = 200, n_units_out_t: int = 100, n_units_r_t: int = 200, penalty_l2: float = 0.0001, penalty_l2_t: float = 0.0001, step_size: float = 0.0001, step_size_t: float = 0.0001, n_iter: int = 10000, batch_size: int = 100, n_iter_min: int = 200, val_split_prop: float = 0.3, early_stopping: bool = True, patience: int = 10, n_iter_print: int = 50, seed: int = 42, rescale_transformation: bool = False, nonlin: str = 'elu', first_stage_args: Optional[dict] = None)

Bases: catenets.models.jax.pseudo_outcome_nets.PseudoOutcomeNet

Wrapper for RA-learner using PseudoOutcomeNet

class RNet(second_stage_strategy: str = 'R', data_split: bool = False, cross_fit: bool = False, n_cf_folds: int = 2, n_layers_out: int = 2, n_layers_r: int = 3, n_layers_out_t: int = 2, n_layers_r_t: int = 3, n_units_out: int = 100, n_units_r: int = 200, n_units_out_t: int = 100, n_units_r_t: int = 200, penalty_l2: float = 0.0001, penalty_l2_t: float = 0.0001, step_size: float = 0.0001, step_size_t: float = 0.0001, n_iter: int = 10000, batch_size: int = 100, n_iter_min: int = 200, val_split_prop: float = 0.3, early_stopping: bool = True, patience: int = 10, n_iter_print: int = 50, seed: int = 42, nonlin: str = 'elu', binary_y: bool = False)

Bases: catenets.models.jax.base.BaseCATENet

Class implements R-learner and U-learner using NNs

Parameters
  • second_stage_strategy (str, default 'R') – Which strategy to use in the second stage (‘R’ for R-learner, ‘U’ for U-learner)

  • data_split (bool, default False) – Whether to split the data in two folds for estimation

  • cross_fit (bool, default False) – Whether to perform cross fitting

  • n_cf_folds (int) – Number of crossfitting folds to use

  • n_layers_out (int) – First stage: number of hypothesis layers (n_layers_out x n_units_out + 1 x Dense layer)

  • n_units_out (int) – First stage: number of hidden units in each hypothesis layer

  • n_layers_r (int) – First stage: number of representation layers before hypothesis layers (distinction between hypothesis layers and representation layers is made to match TARNet & SNets)

  • n_units_r (int) – First stage: number of hidden units in each representation layer

  • n_layers_out_t (int) – Second stage: number of hypothesis layers (n_layers_out_t x n_units_out_t + 1 x Dense layer)

  • n_units_out_t (int) – Second stage: number of hidden units in each hypothesis layer

  • n_layers_r_t (int) – Second stage: number of representation layers before hypothesis layers (distinction between hypothesis layers and representation layers is made to match TARNet & SNets)

  • n_units_r_t (int) – Second stage: number of hidden units in each representation layer

  • penalty_l2 (float) – First stage l2 (ridge) penalty

  • penalty_l2_t (float) – Second stage l2 (ridge) penalty

  • step_size (float) – First stage learning rate for optimizer

  • step_size_t (float) – Second stage learning rate for optimizer

  • n_iter (int) – Maximum number of iterations

  • batch_size (int) – Batch size

  • val_split_prop (float) – Proportion of samples used for validation split (can be 0)

  • early_stopping (bool, default True) – Whether to use early stopping

  • patience (int) – Number of iterations without a decrease in validation loss to wait before early stopping

  • n_iter_min (int) – Minimum number of iterations to go through before starting early stopping

  • n_iter_print (int) – Number of iterations after which to print updates

  • seed (int) – Seed used

  • nonlin (string, default 'elu') – Nonlinearity to use in NN

_get_predict_function() → Callable
_get_train_function() → Callable
fit(X: jax._src.basearray.Array, y: jax._src.basearray.Array, w: jax._src.basearray.Array, p: Optional[jax._src.basearray.Array] = None) → catenets.models.jax.rnet.RNet

Fit method for a CATENet. Takes covariates, outcome variable and treatment indicator as input

Parameters
  • X (pd.DataFrame or np.array) – Covariate matrix

  • y (np.array) – Outcome vector

  • w (np.array) – Treatment indicator

  • p (np.array) – Vector of (known) treatment propensities. Currently only supported for TwoStepNets.

predict(X: jax._src.basearray.Array, return_po: bool = False, return_prop: bool = False) → jax._src.basearray.Array

Predict treatment effect estimates using a CATENet. Depending on method, can also return potential outcome estimate and propensity score estimate.

Parameters
  • X (pd.DataFrame or np.array) – Covariate matrix

  • return_po (bool, default False) – Whether to return potential outcome estimate

  • return_prop (bool, default False) – Whether to return propensity estimate

Returns

array of CATE estimates, optionally also potential outcomes and propensity

Return type

jax._src.basearray.Array

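The R-learner's residual-on-residual idea can be illustrated with oracle nuisance values and a constant treatment effect — a toy numpy sketch, not the catenets implementation:

```python
import numpy as np

rng = np.random.default_rng(0)
n = 500
X = rng.normal(size=(n, 1))
w = rng.binomial(1, 0.5, size=n)
tau_true = 3.0
y = X[:, 0] + w * tau_true  # noiseless outcomes for illustration

# R-learner: regress outcome residuals on treatment residuals.
# With pi(x) = 0.5, the oracle conditional mean is m(x) = b(x) + 0.5 * tau.
pi = 0.5
m = X[:, 0] + 0.5 * tau_true
res_y = y - m               # outcome residual
res_w = w - pi              # treatment residual
# closed-form solution for a constant tau(x) = tau:
tau_hat = np.sum(res_w * res_y) / np.sum(res_w ** 2)
print(tau_hat)  # approximately 3.0 with oracle nuisances
```

RNet replaces the oracle m and pi with first-stage neural-network estimates, and fits a second-stage network for tau(x) by minimizing the weighted residual-on-residual loss; second_stage_strategy='U' instead regresses res_y / res_w directly.
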
class SNet(with_prop: bool = True, binary_y: bool = False, n_layers_r: int = 3, n_units_r: int = 100, n_layers_out: int = 2, n_units_r_small: int = 50, n_units_out: int = 100, n_units_out_prop: int = 100, n_layers_out_prop: int = 2, penalty_l2: float = 0.0001, penalty_orthogonal: float = 0.01, penalty_disc: float = 0, step_size: float = 0.0001, n_iter: int = 10000, batch_size: int = 100, val_split_prop: float = 0.3, early_stopping: bool = True, patience: int = 10, n_iter_min: int = 200, n_iter_print: int = 50, reg_diff: bool = False, penalty_diff: float = 0.0001, seed: int = 42, nonlin: str = 'elu', same_init: bool = False, ortho_reg_type: str = 'abs')

Bases: catenets.models.jax.base.BaseCATENet

Class implements SNet as discussed in Curth & van der Schaar (2021). In addition to the version in the AISTATS paper, we also include an implementation without propensity heads (set with_prop=False)

Parameters
  • with_prop (bool, default True) – Whether to include the propensity head

  • binary_y (bool, default False) – Whether the outcome is binary

  • n_layers_out (int) – Number of hypothesis layers (n_layers_out x n_units_out + 1 x Dense layer)

  • n_layers_out_prop (int) – Number of hypothesis layers for the propensity score (n_layers_out_prop x n_units_out_prop + 1 x Dense layer)

  • n_units_out (int) – Number of hidden units in each hypothesis layer

  • n_units_out_prop (int) – Number of hidden units in each propensity score hypothesis layer

  • n_layers_r (int) – Number of shared & private representation layers before hypothesis layers

  • n_units_r (int) – If with_prop=True: number of hidden units in the representation layers shared by propensity score and outcome functions (the ‘confounding factor’) and in the ‘instrumental factor’ representation. If with_prop=False: number of hidden units in the representation shared across PO functions

  • n_units_r_small (int) – If with_prop=True: number of hidden units in the representation layer of the ‘outcome factor’ and in each PO function’s private representation. If with_prop=False: number of hidden units in each PO function’s private representation

  • penalty_l2 (float) – l2 (ridge) penalty

  • step_size (float) – learning rate for optimizer

  • n_iter (int) – Maximum number of iterations

  • batch_size (int) – Batch size

  • val_split_prop (float) – Proportion of samples used for validation split (can be 0)

  • early_stopping (bool, default True) – Whether to use early stopping

  • patience (int) – Number of iterations without a decrease in validation loss to wait before early stopping

  • n_iter_min (int) – Minimum number of iterations to go through before starting early stopping

  • n_iter_print (int) – Number of iterations after which to print updates

  • seed (int) – Seed used

  • reg_diff (bool, default False) – Whether to regularize the difference between the two potential outcome heads

  • penalty_diff (float) – l2 penalty for regularizing the difference between output heads; used only if reg_diff=True

  • same_init (bool, default False) – Whether to initialise the two output heads with the same values

  • nonlin (string, default 'elu') – Nonlinearity to use in NN

  • penalty_disc (float, default 0) – Discrepancy penalty. Defaults to zero as this feature is not tested.

  • ortho_reg_type (str, default 'abs') – Which type of orthogonalization to use: ‘abs’ uses the (hard) disentanglement described in the AISTATS paper, ‘fro’ uses the Frobenius norm as in FlexTENet

_get_predict_function() → Callable
_get_train_function() → Callable
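The ‘abs’ orthogonalization can be pictured as penalizing overlap in which input features the different representation blocks use. A hedged numpy sketch of this idea — `abs_ortho_penalty` is a hypothetical helper, not the catenets internals:

```python
import numpy as np

def abs_ortho_penalty(weight_list):
    # For each representation block, measure how strongly each input
    # feature is used (absolute row sums of the first-layer weights),
    # then penalize overlap in feature usage between every pair of blocks.
    usage = [np.sum(np.abs(W), axis=1) for W in weight_list]
    penalty = 0.0
    for i in range(len(usage)):
        for j in range(i + 1, len(usage)):
            penalty += np.sum(usage[i] * usage[j])
    return penalty

# blocks relying on disjoint input features incur no penalty
W_a = np.array([[1.0, 2.0], [0.0, 0.0]])  # uses feature 0 only
W_b = np.array([[0.0, 0.0], [3.0, 1.0]])  # uses feature 1 only
print(abs_ortho_penalty([W_a, W_b]))      # 0.0
print(abs_ortho_penalty([W_a, W_a]) > 0)  # True
```

Driving this penalty to zero forces a hard split of the input features across the confounding, outcome and instrumental factors, which is the disentanglement SNet aims for.
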
class SNet1(binary_y: bool = False, n_layers_r: int = 3, n_units_r: int = 200, n_layers_out: int = 2, n_units_out: int = 100, penalty_l2: float = 0.0001, step_size: float = 0.0001, n_iter: int = 10000, batch_size: int = 100, val_split_prop: float = 0.3, early_stopping: bool = True, patience: int = 10, n_iter_min: int = 200, n_iter_print: int = 50, seed: int = 42, reg_diff: bool = False, penalty_diff: float = 0.0001, same_init: bool = False, nonlin: str = 'elu', penalty_disc: float = 0)

Bases: catenets.models.jax.base.BaseCATENet

Class implements Shalit et al (2017)’s TARNet & CFR (discrepancy regularization is NOT TESTED). Also referred to as SNet-1 in our paper.

Parameters
  • binary_y (bool, default False) – Whether the outcome is binary

  • n_layers_out (int) – Number of hypothesis layers (n_layers_out x n_units_out + 1 x Dense layer)

  • n_units_out (int) – Number of hidden units in each hypothesis layer

  • n_layers_r (int) – Number of shared representation layers before hypothesis layers

  • n_units_r (int) – Number of hidden units in each representation layer

  • penalty_l2 (float) – l2 (ridge) penalty

  • step_size (float) – learning rate for optimizer

  • n_iter (int) – Maximum number of iterations

  • batch_size (int) – Batch size

  • val_split_prop (float) – Proportion of samples used for validation split (can be 0)

  • early_stopping (bool, default True) – Whether to use early stopping

  • patience (int) – Number of iterations without a decrease in validation loss to wait before early stopping

  • n_iter_min (int) – Minimum number of iterations to go through before starting early stopping

  • n_iter_print (int) – Number of iterations after which to print updates

  • seed (int) – Seed used

  • reg_diff (bool, default False) – Whether to regularize the difference between the two potential outcome heads

  • penalty_diff (float) – l2 penalty for regularizing the difference between output heads; used only if reg_diff=True

  • same_init (bool, default False) – Whether to initialise the two output heads with the same values

  • nonlin (string, default 'elu') – Nonlinearity to use in NN

  • penalty_disc (float, default 0) – Discrepancy penalty. Defaults to zero as this feature is not tested.

_get_predict_function() → Callable
_get_train_function() → Callable
class SNet2(binary_y: bool = False, n_layers_r: int = 3, n_units_r: int = 200, n_layers_out: int = 2, n_units_out: int = 100, penalty_l2: float = 0.0001, n_units_out_prop: int = 100, n_layers_out_prop: int = 2, step_size: float = 0.0001, n_iter: int = 10000, batch_size: int = 100, val_split_prop: float = 0.3, early_stopping: bool = True, patience: int = 10, n_iter_min: int = 200, n_iter_print: int = 50, seed: int = 42, reg_diff: bool = False, same_init: bool = False, penalty_diff: float = 0.0001, nonlin: str = 'elu')

Bases: catenets.models.jax.base.BaseCATENet

Class implements SNet-2, which is based on Shi et al (2019)’s DragonNet. This version does NOT use targeted regularization and has a (possibly deeper) propensity head.

Parameters
  • binary_y (bool, default False) – Whether the outcome is binary

  • n_layers_out (int) – Number of hypothesis layers (n_layers_out x n_units_out + 1 x Dense layer)

  • n_layers_out_prop (int) – Number of hypothesis layers for the propensity score (n_layers_out_prop x n_units_out_prop + 1 x Dense layer)

  • n_units_out (int) – Number of hidden units in each hypothesis layer

  • n_units_out_prop (int) – Number of hidden units in each propensity score hypothesis layer

  • n_layers_r (int) – Number of shared representation layers before hypothesis layers

  • n_units_r (int) – Number of hidden units in each representation layer

  • penalty_l2 (float) – l2 (ridge) penalty

  • step_size (float) – learning rate for optimizer

  • n_iter (int) – Maximum number of iterations

  • batch_size (int) – Batch size

  • val_split_prop (float) – Proportion of samples used for validation split (can be 0)

  • early_stopping (bool, default True) – Whether to use early stopping

  • patience (int) – Number of iterations without a decrease in validation loss to wait before early stopping

  • n_iter_min (int) – Minimum number of iterations to go through before starting early stopping

  • n_iter_print (int) – Number of iterations after which to print updates

  • seed (int) – Seed used

  • reg_diff (bool, default False) – Whether to regularize the difference between the two potential outcome heads

  • penalty_diff (float) – l2 penalty for regularizing the difference between output heads; used only if reg_diff=True

  • same_init (bool, default False) – Whether to initialise the two output heads with the same values

  • nonlin (string, default 'elu') – Nonlinearity to use in NN

_get_predict_function() → Callable
_get_train_function() → Callable
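The reg_diff/penalty_diff mechanism used by SNet-2 (and the other two-head networks) amounts to an l2 penalty on the difference between the parameters of the two potential-outcome heads. A minimal sketch — `head_diff_penalty` is a hypothetical helper, not the catenets API:

```python
import numpy as np

def head_diff_penalty(params0, params1, penalty_diff=1e-4):
    # l2 penalty on the parameter-wise difference between the two
    # potential-outcome heads; zero when the heads share parameters
    return penalty_diff * sum(
        np.sum((p0 - p1) ** 2) for p0, p1 in zip(params0, params1)
    )

head0 = [np.array([1.0, 2.0]), np.array([0.5])]
head1 = [np.array([1.0, 2.0]), np.array([0.5])]
print(head_diff_penalty(head0, head1))  # 0.0 for identical heads
```

Shrinking this penalty term regularizes the estimated treatment effect toward zero, which is useful when the two POs are expected to be similar; same_init makes the two heads start from identical values so the penalty is zero at initialisation.
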
class SNet3(binary_y: bool = False, n_layers_r: int = 3, n_units_r: int = 150, n_layers_out: int = 2, n_units_r_small: int = 50, n_units_out: int = 100, n_units_out_prop: int = 100, n_layers_out_prop: int = 2, penalty_l2: float = 0.0001, penalty_orthogonal: float = 0.01, penalty_disc: float = 0, step_size: float = 0.0001, n_iter: int = 10000, batch_size: int = 100, val_split_prop: float = 0.3, early_stopping: bool = True, patience: int = 10, n_iter_min: int = 200, n_iter_print: int = 50, seed: int = 42, nonlin: str = 'elu', reg_diff: bool = False, penalty_diff: float = 0.0001, same_init: bool = False)

Bases: catenets.models.jax.base.BaseCATENet

Class implements SNet-3, which is based on Hassanpour & Greiner (2020)’s DR-CFR (without propensity weighting), using an orthogonal regularizer to enforce a decomposition similar to Wu et al (2020).

Parameters
  • binary_y (bool, default False) – Whether the outcome is binary

  • n_layers_out (int) – Number of hypothesis layers (n_layers_out x n_units_out + 1 x Dense layer)

  • n_layers_out_prop (int) – Number of hypothesis layers for the propensity score (n_layers_out_prop x n_units_out_prop + 1 x Dense layer)

  • n_units_out (int) – Number of hidden units in each hypothesis layer

  • n_units_out_prop (int) – Number of hidden units in each propensity score hypothesis layer

  • n_layers_r (int) – Number of shared & private representation layers before hypothesis layers

  • n_units_r (int) – Number of hidden units in representation layer shared by propensity score and outcome function (the ‘confounding factor’)

  • n_units_r_small (int) – Number of hidden units in representation layer NOT shared by propensity score and outcome functions (the ‘outcome factor’ and the ‘instrumental factor’)

  • penalty_l2 (float) – l2 (ridge) penalty

  • step_size (float) – learning rate for optimizer

  • n_iter (int) – Maximum number of iterations

  • batch_size (int) – Batch size

  • val_split_prop (float) – Proportion of samples used for validation split (can be 0)

  • early_stopping (bool, default True) – Whether to use early stopping

  • patience (int) – Number of iterations without a decrease in validation loss to wait before early stopping

  • n_iter_min (int) – Minimum number of iterations to go through before starting early stopping

  • n_iter_print (int) – Number of iterations after which to print updates

  • seed (int) – Seed used

  • reg_diff (bool, default False) – Whether to regularize the difference between the two potential outcome heads

  • penalty_diff (float) – l2 penalty for regularizing the difference between output heads; used only if reg_diff=True

  • same_init (bool, default False) – Whether to initialise the two output heads with the same values

  • nonlin (string, default 'elu') – Nonlinearity to use in NN

  • penalty_disc (float, default 0) – Discrepancy penalty. Defaults to zero as this feature is not tested.

_get_predict_function() → Callable
_get_train_function() → Callable
class TARNet(binary_y: bool = False, n_layers_r: int = 3, n_units_r: int = 200, n_layers_out: int = 2, n_units_out: int = 100, penalty_l2: float = 0.0001, step_size: float = 0.0001, n_iter: int = 10000, batch_size: int = 100, val_split_prop: float = 0.3, early_stopping: bool = True, patience: int = 10, n_iter_min: int = 200, n_iter_print: int = 50, seed: int = 42, reg_diff: bool = False, penalty_diff: float = 0.0001, same_init: bool = False, nonlin: str = 'elu')

Bases: catenets.models.jax.representation_nets.SNet1

Wrapper for TARNet

class TNet(binary_y: bool = False, n_layers_out: int = 2, n_units_out: int = 100, n_layers_r: int = 3, n_units_r: int = 200, penalty_l2: float = 0.0001, step_size: float = 0.0001, n_iter: int = 10000, batch_size: int = 100, val_split_prop: float = 0.3, early_stopping: bool = True, patience: int = 10, n_iter_min: int = 200, n_iter_print: int = 50, seed: int = 42, train_separate: bool = True, penalty_diff: float = 0.0001, nonlin: str = 'elu')

Bases: catenets.models.jax.base.BaseCATENet

TNet class – two separate functions, one learned for each potential outcome

Parameters
  • binary_y (bool, default False) – Whether the outcome is binary

  • n_layers_out (int) – Number of hypothesis layers (n_layers_out x n_units_out + 1 x Dense layer)

  • n_units_out (int) – Number of hidden units in each hypothesis layer

  • n_layers_r (int) – Number of representation layers before hypothesis layers (distinction between hypothesis layers and representation layers is made to match TARNet & SNets)

  • n_units_r (int) – Number of hidden units in each representation layer

  • penalty_l2 (float) – l2 (ridge) penalty

  • step_size (float) – learning rate for optimizer

  • n_iter (int) – Maximum number of iterations

  • batch_size (int) – Batch size

  • val_split_prop (float) – Proportion of samples used for validation split (can be 0)

  • early_stopping (bool, default True) – Whether to use early stopping

  • patience (int) – Number of iterations without a decrease in validation loss to wait before early stopping

  • n_iter_min (int) – Minimum number of iterations to go through before starting early stopping

  • n_iter_print (int) – Number of iterations after which to print updates

  • seed (int) – Seed used

  • train_separate (bool, default True) – Whether to train the two output heads completely separately or whether to regularize their difference

  • penalty_diff (float) – l2 penalty for regularizing the difference between output heads; used only if train_separate=False

  • nonlin (string, default 'elu') – Nonlinearity to use in NN

_get_predict_function() → Callable
_get_train_function() → Callable
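The T-learner strategy itself is simple to sketch with linear models in place of TNet's neural networks (illustrative only; `t_learner_cate` is a hypothetical helper):

```python
import numpy as np

def t_learner_cate(X, y, w, X_test):
    # T-learner sketch: fit one model per treatment arm,
    # then take the difference of the two predictions as the CATE
    def fit(Xa, ya):
        A = np.c_[np.ones(len(Xa)), Xa]          # add intercept column
        coef, *_ = np.linalg.lstsq(A, ya, rcond=None)
        return coef
    c0 = fit(X[w == 0], y[w == 0])               # PO model for control
    c1 = fit(X[w == 1], y[w == 1])               # PO model for treated
    A = np.c_[np.ones(len(X_test)), X_test]
    return A @ c1 - A @ c0

rng = np.random.default_rng(0)
X = rng.normal(size=(100, 1))
w = rng.binomial(1, 0.5, size=100)
y = 1.0 + 2.0 * X[:, 0] + 3.0 * w   # constant treatment effect tau = 3
print(t_learner_cate(X, y, w, X[:5]))  # approximately [3. 3. 3. 3. 3.]
```

TNet does the same with two neural networks; train_separate=False instead trains them jointly while penalizing the difference between the heads (penalty_diff).
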
class XNet(weight_strategy: Optional[int] = None, first_stage_strategy: str = 'T', first_stage_args: Optional[dict] = None, binary_y: bool = False, n_layers_out: int = 2, n_layers_r: int = 3, n_layers_out_t: int = 2, n_layers_r_t: int = 3, n_units_out: int = 100, n_units_r: int = 200, n_units_out_t: int = 100, n_units_r_t: int = 200, penalty_l2: float = 0.0001, penalty_l2_t: float = 0.0001, step_size: float = 0.0001, step_size_t: float = 0.0001, n_iter: int = 10000, batch_size: int = 100, n_iter_min: int = 200, val_split_prop: float = 0.3, early_stopping: bool = True, patience: int = 10, n_iter_print: int = 50, seed: int = 42, nonlin: str = 'elu')

Bases: catenets.models.jax.base.BaseCATENet

Class implements X-learner using NNs.

Parameters
  • weight_strategy (int, default None) – Which strategy to use to weight the two CATE estimators in the second stage. For tau(x) = g(x)tau_0(x) + (1-g(x))tau_1(x) [eq. 9, Kuenzel et al (2019)], weight_strategy=0 sets g(x)=0, weight_strategy=1 sets g(x)=1, weight_strategy=None sets g(x)=pi(x) [the propensity score], and weight_strategy=-1 sets g(x)=1-pi(x)

  • binary_y (bool, default False) – Whether the outcome is binary

  • n_layers_out (int) – First stage: number of hypothesis layers (n_layers_out x n_units_out + 1 x Dense layer)

  • n_units_out (int) – First stage: number of hidden units in each hypothesis layer

  • n_layers_r (int) – First stage: number of representation layers before hypothesis layers (distinction between hypothesis layers and representation layers is made to match TARNet & SNets)

  • n_units_r (int) – First stage: number of hidden units in each representation layer

  • n_layers_out_t (int) – Second stage: number of hypothesis layers (n_layers_out_t x n_units_out_t + 1 x Dense layer)

  • n_units_out_t (int) – Second stage: number of hidden units in each hypothesis layer

  • n_layers_r_t (int) – Second stage: number of representation layers before hypothesis layers (distinction between hypothesis layers and representation layers is made to match TARNet & SNets)

  • n_units_r_t (int) – Second stage: number of hidden units in each representation layer

  • penalty_l2 (float) – First stage l2 (ridge) penalty

  • penalty_l2_t (float) – Second stage l2 (ridge) penalty

  • step_size (float) – First stage learning rate for optimizer

  • step_size_t (float) – Second stage learning rate for optimizer

  • n_iter (int) – Maximum number of iterations

  • batch_size (int) – Batch size

  • val_split_prop (float) – Proportion of samples used for validation split (can be 0)

  • early_stopping (bool, default True) – Whether to use early stopping

  • patience (int) – Number of iterations without a decrease in validation loss to wait before early stopping

  • n_iter_min (int) – Minimum number of iterations to go through before starting early stopping

  • n_iter_print (int) – Number of iterations after which to print updates

  • seed (int) – Seed used

  • nonlin (string, default 'elu') – Nonlinearity to use in NN

_get_predict_function() → Callable
_get_train_function() → Callable
predict(X: jax._src.basearray.Array, return_po: bool = False, return_prop: bool = False) → jax._src.basearray.Array

Predict treatment effect estimates using a CATENet. Depending on method, can also return potential outcome estimate and propensity score estimate.

Parameters
  • X (pd.DataFrame or np.array) – Covariate matrix

  • return_po (bool, default False) – Whether to return potential outcome estimate

  • return_prop (bool, default False) – Whether to return propensity estimate

Returns

array of CATE estimates, optionally also potential outcomes and propensity

Return type

jax._src.basearray.Array

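The weight_strategy coding for XNet can be made concrete with a small sketch (`combine_cate` is a hypothetical helper, not part of catenets):

```python
import numpy as np

def combine_cate(tau0, tau1, pi, weight_strategy=None):
    # weight_strategy coding from the docstring, for
    # tau(x) = g(x) * tau_0(x) + (1 - g(x)) * tau_1(x):
    #   0 -> g(x) = 0, 1 -> g(x) = 1,
    #   None -> g(x) = pi(x), -1 -> g(x) = 1 - pi(x)
    if weight_strategy == 0:
        g = np.zeros_like(pi)
    elif weight_strategy == 1:
        g = np.ones_like(pi)
    elif weight_strategy == -1:
        g = 1 - pi
    else:  # None: propensity weighting
        g = pi
    return g * tau0 + (1 - g) * tau1

tau0 = np.array([1.0]); tau1 = np.array([3.0]); pi = np.array([0.25])
print(combine_cate(tau0, tau1, pi))  # [2.5] with propensity weighting
```

Here tau_0 is the imputed-effect estimator fit on the control arm and tau_1 the one fit on the treated arm; propensity weighting (the default) gives more weight to the estimator fit on the larger arm at each x.
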
Submodules