catenets.models.jax.xnet module

This module implements the X-learner of Kuenzel et al. (2019) using neural networks (NNs).

class XNet(weight_strategy: Optional[int] = None, first_stage_strategy: str = 'T', first_stage_args: Optional[dict] = None, binary_y: bool = False, n_layers_out: int = 2, n_layers_r: int = 3, n_layers_out_t: int = 2, n_layers_r_t: int = 3, n_units_out: int = 100, n_units_r: int = 200, n_units_out_t: int = 100, n_units_r_t: int = 200, penalty_l2: float = 0.0001, penalty_l2_t: float = 0.0001, step_size: float = 0.0001, step_size_t: float = 0.0001, n_iter: int = 10000, batch_size: int = 100, n_iter_min: int = 200, val_split_prop: float = 0.3, early_stopping: bool = True, patience: int = 10, n_iter_print: int = 50, seed: int = 42, nonlin: str = 'elu')

Bases: catenets.models.jax.base.BaseCATENet

Implements the X-learner using NNs.
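
As an illustration of the two-stage procedure, here is a minimal sketch of the X-learner on one-dimensional data. It is not the catenets implementation: ordinary least squares lines stand in for the neural networks, and the helper names fit_line and x_learner, the toy data, and the constant g(x) = 0.5 are assumptions made for this example.

```python
# Illustrative X-learner on 1-D data; linear fits stand in for the neural nets.
# All names here are hypothetical and not part of catenets.


def fit_line(xs, ys):
    """Ordinary least squares for y = a + b * x; returns a prediction function."""
    n = len(xs)
    mx, my = sum(xs) / n, sum(ys) / n
    sxx = sum((x - mx) ** 2 for x in xs)
    b = sum((x - mx) * (y - my) for x, y in zip(xs, ys)) / sxx if sxx else 0.0
    a = my - b * mx
    return lambda x: a + b * x


def x_learner(X, y, w, g=lambda x: 0.5):
    """Return tau(x) = g(x) * tau_0(x) + (1 - g(x)) * tau_1(x)."""
    X0 = [x for x, wi in zip(X, w) if wi == 0]
    y0 = [yi for yi, wi in zip(y, w) if wi == 0]
    X1 = [x for x, wi in zip(X, w) if wi == 1]
    y1 = [yi for yi, wi in zip(y, w) if wi == 1]

    # First stage: one potential-outcome regression per treatment arm.
    mu0, mu1 = fit_line(X0, y0), fit_line(X1, y1)

    # Second stage: impute individual treatment effects, then regress them
    # on the covariates within each arm.
    d0 = [mu1(x) - yi for x, yi in zip(X0, y0)]  # controls
    d1 = [yi - mu0(x) for x, yi in zip(X1, y1)]  # treated
    tau0, tau1 = fit_line(X0, d0), fit_line(X1, d1)

    # Combine the two CATE estimators with the weighting function g.
    return lambda x: g(x) * tau0(x) + (1.0 - g(x)) * tau1(x)


# Toy data: y = x + 2 * w, so the true treatment effect is 2 everywhere.
X = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]
w = [0, 1, 0, 1, 0, 1]
y = [x + 2.0 * wi for x, wi in zip(X, w)]
tau = x_learner(X, y, w)
```

On this noiseless toy data both second-stage fits recover the constant effect, so the choice of g does not matter; with real data the weighting decides how much each arm's estimator contributes.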

Parameters
  • weight_strategy (int, default None) –

    Which strategy to use to weight the two CATE estimators in the second stage. With tau(x) = g(x) tau_0(x) + (1 - g(x)) tau_1(x) [eq. 9, Kuenzel et al. (2019)], weight_strategy is coded as follows:

      • weight_strategy=0 sets g(x) = 0
      • weight_strategy=1 sets g(x) = 1
      • weight_strategy=None sets g(x) = pi(x) [propensity score]
      • weight_strategy=-1 sets g(x) = 1 - pi(x)

  • binary_y (bool, default False) – Whether the outcome is binary

  • n_layers_out (int) – First stage: number of hypothesis layers (n_layers_out x n_units_out + 1 x Dense layer)

  • n_units_out (int) – First stage: number of hidden units in each hypothesis layer

  • n_layers_r (int) – First stage: number of representation layers before the hypothesis layers (the distinction between hypothesis layers and representation layers is made to match TARNet & SNets)

  • n_units_r (int) – First stage: number of hidden units in each representation layer

  • n_layers_out_t (int) – Second stage: number of hypothesis layers (n_layers_out_t x n_units_out_t + 1 x Dense layer)

  • n_units_out_t (int) – Second stage: number of hidden units in each hypothesis layer

  • n_layers_r_t (int) – Second stage: number of representation layers before the hypothesis layers (the distinction between hypothesis layers and representation layers is made to match TARNet & SNets)

  • n_units_r_t (int) – Second stage: number of hidden units in each representation layer

  • penalty_l2 (float) – First stage: L2 (ridge) penalty

  • penalty_l2_t (float) – Second stage: L2 (ridge) penalty

  • step_size (float) – First stage: learning rate for the optimizer

  • step_size_t (float) – Second stage: learning rate for the optimizer

  • n_iter (int) – Maximum number of iterations

  • batch_size (int) – Batch size

  • val_split_prop (float) – Proportion of samples used for validation split (can be 0)

  • early_stopping (bool, default True) – Whether to use early stopping

  • patience (int) – Number of iterations to wait without a decrease in validation loss before stopping early

  • n_iter_min (int) – Minimum number of iterations to go through before starting early stopping

  • n_iter_print (int) – Number of iterations after which to print updates

  • seed (int) – Random seed

  • nonlin (string, default 'elu') – Nonlinearity to use in NN
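
The weight_strategy coding can be read as a mapping from the option to g(x). The sketch below applies that mapping at a single point x; combine_cate and its arguments are hypothetical names used for illustration and are not part of catenets.

```python
from typing import Optional


def combine_cate(tau0: float, tau1: float, pi: float,
                 weight_strategy: Optional[int] = None) -> float:
    """Combine the two second-stage estimates at a single point x.

    Implements tau(x) = g(x) * tau_0(x) + (1 - g(x)) * tau_1(x), with g(x)
    chosen by weight_strategy as described in the parameter list.
    """
    if weight_strategy is None:
        g = pi          # g(x) = pi(x), the propensity score
    elif weight_strategy == -1:
        g = 1.0 - pi    # g(x) = 1 - pi(x)
    elif weight_strategy == 0:
        g = 0.0         # only tau_1(x) is used
    elif weight_strategy == 1:
        g = 1.0         # only tau_0(x) is used
    else:
        # Assumption: any other value is treated as an error in this sketch.
        raise ValueError(f"unsupported weight_strategy: {weight_strategy!r}")
    return g * tau0 + (1.0 - g) * tau1
```

For example, with tau_0(x) = 1.0, tau_1(x) = 3.0 and pi(x) = 0.25, the default weight_strategy=None weights tau_0 by 0.25 and tau_1 by 0.75.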

_get_predict_function() Callable
_get_train_function() Callable
predict(X: jax._src.basearray.Array, return_po: bool = False, return_prop: bool = False) jax._src.basearray.Array

Predict treatment effect estimates using a CATENet. Depending on method, can also return potential outcome estimate and propensity score estimate.

Parameters
  • X (pd.DataFrame or np.array) – Covariate matrix

  • return_po (bool, default False) – Whether to return potential outcome estimate

  • return_prop (bool, default False) – Whether to return propensity estimate

Returns

Array of CATE estimates, optionally also potential outcome and propensity estimates

_get_first_stage_pos(X: jax._src.basearray.Array, y: jax._src.basearray.Array, w: jax._src.basearray.Array, first_stage_strategy: str = 'T', first_stage_args: Optional[dict] = None, binary_y: bool = False, n_layers_out: int = 2, n_layers_r: int = 3, n_units_out: int = 100, n_units_r: int = 200, penalty_l2: float = 0.0001, step_size: float = 0.0001, n_iter: int = 10000, batch_size: int = 100, n_iter_min: int = 200, val_split_prop: float = 0.3, early_stopping: bool = True, patience: int = 10, n_iter_print: int = 50, seed: int = 42, nonlin: str = 'elu', avg_objective: bool = True) Tuple[jax._src.basearray.Array, jax._src.basearray.Array]
predict_x_net(X: jax._src.basearray.Array, trained_params: dict, predict_funs: list, return_po: bool = False, return_prop: bool = False, weight_strategy: Optional[int] = None) jax._src.basearray.Array
train_x_net(X: jax._src.basearray.Array, y: jax._src.basearray.Array, w: jax._src.basearray.Array, weight_strategy: Optional[int] = None, first_stage_strategy: str = 'T', first_stage_args: Optional[dict] = None, binary_y: bool = False, n_layers_out: int = 2, n_layers_r: int = 3, n_layers_out_t: int = 2, n_layers_r_t: int = 3, n_units_out: int = 100, n_units_r: int = 200, n_units_out_t: int = 100, n_units_r_t: int = 200, penalty_l2: float = 0.0001, penalty_l2_t: float = 0.0001, step_size: float = 0.0001, step_size_t: float = 0.0001, n_iter: int = 10000, batch_size: int = 100, n_iter_min: int = 200, val_split_prop: float = 0.3, early_stopping: bool = True, patience: int = 10, n_iter_print: int = 50, seed: int = 42, nonlin: str = 'elu', return_val_loss: bool = False, avg_objective: bool = True) Tuple
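
The early-stopping arguments (early_stopping, patience, n_iter_min) recur in the signatures above. One plausible reading of how patience and n_iter_min interact is sketched below. stopping_iteration is a hypothetical helper, and the exact bookkeeping inside catenets (for example, how often the validation loss is checked) may differ.

```python
def stopping_iteration(val_losses, patience=10, n_iter_min=200):
    """Return the index of the iteration at which training would stop.

    Assumed rule for this sketch: ignore the first n_iter_min iterations,
    then stop once the validation loss has failed to improve on its best
    value for more than `patience` consecutive checks.
    """
    best_loss = float("inf")
    stalled = 0  # consecutive checks without improvement
    for i, loss in enumerate(val_losses):
        if i < n_iter_min:
            continue  # early stopping is not active yet
        if loss < best_loss:
            best_loss, stalled = loss, 0
        else:
            stalled += 1
        if stalled > patience:
            return i
    # Never triggered: training runs through the last iteration.
    return len(val_losses) - 1
```

With val_losses = [5, 4, 3, 3, 3, 3, 3], patience=2 and n_iter_min=1, the loss stops improving after index 2, and the third stalled check at index 5 triggers the stop.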