catenets.models.jax.snet module
This module implements the SNet class as discussed in Curth & van der Schaar (2021).
- class SNet(with_prop: bool = True, binary_y: bool = False, n_layers_r: int = 3, n_units_r: int = 100, n_layers_out: int = 2, n_units_r_small: int = 50, n_units_out: int = 100, n_units_out_prop: int = 100, n_layers_out_prop: int = 2, penalty_l2: float = 0.0001, penalty_orthogonal: float = 0.01, penalty_disc: float = 0, step_size: float = 0.0001, n_iter: int = 10000, batch_size: int = 100, val_split_prop: float = 0.3, early_stopping: bool = True, patience: int = 10, n_iter_min: int = 200, n_iter_print: int = 50, reg_diff: bool = False, penalty_diff: float = 0.0001, seed: int = 42, nonlin: str = 'elu', same_init: bool = False, ortho_reg_type: str = 'abs')
Bases: catenets.models.jax.base.BaseCATENet
Implements SNet as discussed in Curth & van der Schaar (2021). In addition to the version implemented in the AISTATS paper, this class also includes a variant without propensity heads (set with_prop=False).
- Parameters
with_prop (bool, default True) – Whether to include the propensity head
binary_y (bool, default False) – Whether the outcome is binary
n_layers_out (int) – Number of hypothesis layers (n_layers_out x n_units_out + 1 x Dense layer)
n_layers_out_prop (int) – Number of hypothesis layers for the propensity score (n_layers_out_prop x n_units_out_prop + 1 x Dense layer)
n_units_out (int) – Number of hidden units in each hypothesis layer
n_units_out_prop (int) – Number of hidden units in each propensity score hypothesis layer
n_layers_r (int) – Number of shared & private representation layers before hypothesis layers
n_units_r (int) – If with_prop=True: number of hidden units in the representation layer shared by the propensity score and outcome functions (the ‘confounding factor’) and in the representation layer of the ‘instrumental factor’. If with_prop=False: number of hidden units in the representation shared across PO functions
n_units_r_small (int) – If with_prop=True: number of hidden units in the representation layer of the ‘outcome factor’ and in each PO function's private representation. If with_prop=False: number of hidden units in each PO function's private representation
penalty_l2 (float) – L2 (ridge) penalty
step_size (float) – Learning rate for the optimizer
n_iter (int) – Maximum number of iterations
batch_size (int) – Batch size
val_split_prop (float) – Proportion of samples used for validation split (can be 0)
early_stopping (bool, default True) – Whether to use early stopping
patience (int) – Number of iterations to wait without an improvement in validation loss before stopping early
n_iter_min (int) – Minimum number of iterations to go through before starting early stopping
n_iter_print (int) – Number of iterations after which to print updates
seed (int) – Seed used
reg_diff (bool, default False) – Whether to regularize the difference between the two potential outcome heads
penalty_diff (float) – L2 penalty for regularizing the difference between output heads. Used only if train_separate=False
same_init (bool, default False) – Whether to initialise the two output heads with the same values
nonlin (string, default 'elu') – Nonlinearity to use in NN
penalty_disc (float, default zero) – Discrepancy penalty. Defaults to zero as this feature is not tested.
ortho_reg_type (str, default 'abs') – Which type of orthogonalization to use. ‘abs’ uses the (hard) disentanglement described in the AISTATS paper; ‘fro’ uses the Frobenius norm as in FlexTENet
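To make the two `ortho_reg_type` options more concrete, the following NumPy sketch contrasts an absolute-value penalty with a Frobenius-norm penalty. It is an illustration under assumptions, not the catenets implementation: the helper names `abs_ortho_penalty` and `fro_ortho_penalty` are hypothetical, and the inputs are assumed to be first-layer weight matrices of shape (n_features, n_units) belonging to two different representations.

```python
import numpy as np

def abs_ortho_penalty(w_a, w_b):
    # Hypothetical helper: per-feature usage is the sum of absolute
    # weights leaving each input feature (one row per feature).
    rows_a = np.abs(w_a).sum(axis=1)
    rows_b = np.abs(w_b).sum(axis=1)
    # Penalise input features that both representations rely on.
    return float(np.sum(rows_a * rows_b))

def fro_ortho_penalty(w_a, w_b):
    # Hypothetical helper: squared Frobenius norm of the cross-product
    # between the two weight matrices.
    return float(np.linalg.norm(w_a.T @ w_b, ord="fro") ** 2)

# Two representations that read disjoint input features.
w_a = np.array([[1.0, -2.0], [0.0, 0.0], [0.5, 0.5]])
w_b = np.array([[0.0, 0.0], [3.0, 1.0], [0.0, 0.0]])
disjoint_abs = abs_ortho_penalty(w_a, w_b)
disjoint_fro = fro_ortho_penalty(w_a, w_b)

# Letting the second representation also read feature 0 makes both penalties positive.
w_c = w_b.copy()
w_c[0, 0] = 1.0
shared_abs = abs_ortho_penalty(w_a, w_c)
shared_fro = fro_ortho_penalty(w_a, w_c)
```

In this sketch both penalties vanish when the representations use disjoint input features. The absolute-value variant is zero only in that case, which is one way to read the "hard" disentanglement mentioned above, whereas the Frobenius variant can also be zero when shared features are used with orthogonal weight columns.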
- _abc_impl = <_abc_data object>
- _get_predict_function() Callable
- _get_train_function() Callable
- predict_snet(X: jax._src.basearray.Array, trained_params: jax._src.basearray.Array, predict_funs: list, return_po: bool = False, return_prop: bool = False) jax._src.basearray.Array
- predict_snet_noprop(X: jax._src.basearray.Array, trained_params: jax._src.basearray.Array, predict_funs: list, return_po: bool = False, return_prop: bool = False) jax._src.basearray.Array
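The `return_po` and `return_prop` flags in the two prediction signatures suggest that potential outcomes and propensity estimates can be returned alongside the treatment effect. The sketch below shows one plausible return convention, assuming the CATE estimate is the difference between the two potential-outcome predictions. The helper `assemble_predictions`, the tuple ordering, and the error raised when no propensity estimate exists are assumptions for illustration, not the library's code.

```python
import numpy as np

def assemble_predictions(mu_0, mu_1, prop=None, return_po=False, return_prop=False):
    # CATE estimate: treated minus control potential outcome.
    cate = mu_1 - mu_0
    out = [cate]
    if return_po:
        out += [mu_0, mu_1]
    if return_prop:
        # Assumed behaviour: a model without a propensity head cannot return one.
        if prop is None:
            raise ValueError("no propensity estimate available")
        out.append(prop)
    # Return a bare array when only the CATE is requested.
    return out[0] if len(out) == 1 else tuple(out)

mu_0 = np.array([0.2, 1.0, -0.5])   # control potential outcomes
mu_1 = np.array([0.7, 1.0, 0.5])    # treated potential outcomes
prop = np.array([0.3, 0.5, 0.9])    # propensity scores

cate_only = assemble_predictions(mu_0, mu_1)
cate, po_0, po_1, e_x = assemble_predictions(
    mu_0, mu_1, prop, return_po=True, return_prop=True
)
```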
- train_snet(X: jax._src.basearray.Array, y: jax._src.basearray.Array, w: jax._src.basearray.Array, binary_y: bool = False, n_layers_r: int = 3, n_units_r: int = 100, n_units_r_small: int = 50, n_layers_out: int = 2, n_units_out: int = 100, n_units_out_prop: int = 100, n_layers_out_prop: int = 2, penalty_l2: float = 0.0001, penalty_disc: float = 0, penalty_orthogonal: float = 0.01, step_size: float = 0.0001, n_iter: int = 10000, batch_size: int = 100, val_split_prop: float = 0.3, early_stopping: bool = True, patience: int = 10, n_iter_min: int = 200, n_iter_print: int = 50, seed: int = 42, return_val_loss: bool = False, reg_diff: bool = False, penalty_diff: float = 0.0001, nonlin: str = 'elu', avg_objective: bool = True, with_prop: bool = True, same_init: bool = False, ortho_reg_type: str = 'abs') Tuple
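The interplay of `early_stopping`, `patience`, and `n_iter_min` can be pictured as a patience counter over validation losses. The following is a minimal sketch, assuming one validation check per iteration and assuming that an improvement means a strictly lower validation loss. The function name `stopping_iteration` is hypothetical, and the actual `train_snet` loop may differ in its details.

```python
def stopping_iteration(val_losses, patience=10, n_iter_min=200, early_stopping=True):
    """Return the index of the last iteration run, given per-iteration validation losses."""
    best = float("inf")
    waited = 0  # iterations since the last improvement
    for i, loss in enumerate(val_losses):
        if loss < best:
            best = loss
            waited = 0
        else:
            waited += 1
        # Stop only once patience is exhausted and the minimum iteration count is reached.
        if early_stopping and waited > patience and i + 1 >= n_iter_min:
            return i
    return len(val_losses) - 1

# Validation loss improves for five iterations, then plateaus.
losses = [1.0, 0.8, 0.6, 0.5, 0.4] + [0.4] * 20

stopped_at = stopping_iteration(losses, patience=3, n_iter_min=0)
ran_all = stopping_iteration(losses, patience=3, n_iter_min=0, early_stopping=False)
held_back = stopping_iteration(losses, patience=3, n_iter_min=15)
```

With these inputs, a larger `n_iter_min` delays the stop even though the patience budget was exhausted earlier, and disabling `early_stopping` runs through every iteration.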
- train_snet_noprop(X: jax._src.basearray.Array, y: jax._src.basearray.Array, w: jax._src.basearray.Array, binary_y: bool = False, n_layers_r: int = 3, n_units_r: int = 150, n_units_r_small: int = 50, n_layers_out: int = 2, n_units_out: int = 100, n_units_out_prop: int = 100, n_layers_out_prop: int = 2, penalty_l2: float = 0.0001, penalty_orthogonal: float = 0.01, step_size: float = 0.0001, n_iter: int = 10000, batch_size: int = 100, val_split_prop: float = 0.3, early_stopping: bool = True, n_iter_min: int = 200, patience: int = 10, n_iter_print: int = 50, seed: int = 42, return_val_loss: bool = False, reg_diff: bool = False, penalty_diff: float = 0.0001, nonlin: str = 'elu', avg_objective: bool = True, with_prop: bool = False, same_init: bool = False, ortho_reg_type: str = 'abs') Tuple
SNet but without the propensity head