catenets.models.jax.representation_nets module

Module implements SNet1 and SNet2, which are based on CFRNet/TARNet from Shalit et al (2017) and DragonNet from Shi et al (2019), respectively.

class DragonNet(binary_y: bool = False, n_layers_r: int = 3, n_units_r: int = 200, n_layers_out: int = 2, n_units_out: int = 100, penalty_l2: float = 0.0001, n_units_out_prop: int = 100, n_layers_out_prop: int = 0, step_size: float = 0.0001, n_iter: int = 10000, batch_size: int = 100, val_split_prop: float = 0.3, early_stopping: bool = True, patience: int = 10, n_iter_min: int = 200, n_iter_print: int = 50, seed: int = 42, reg_diff: bool = False, same_init: bool = False, penalty_diff: float = 0.0001, nonlin: str = 'elu')

Bases: catenets.models.jax.representation_nets.SNet2

Wrapper for DragonNet. Per the signature above, it is SNet2 with no hidden layers in the propensity head by default (n_layers_out_prop=0).
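
A minimal usage sketch on synthetic data, assuming the fit/predict interface inherited from catenets.models.jax.base.BaseCATENet; the data and hyperparameter choices here are invented for illustration:

    import numpy as np

    from catenets.models.jax.representation_nets import DragonNet

    # Synthetic data: 500 samples, 10 covariates, binary treatment w,
    # continuous outcome y with a unit treatment effect.
    rng = np.random.default_rng(42)
    X = rng.normal(size=(500, 10))
    w = rng.binomial(1, 0.5, size=(500, 1))
    y = X[:, :1] + w + rng.normal(scale=0.1, size=(500, 1))

    model = DragonNet(n_iter=1000)  # fewer iterations than default, for speed
    model.fit(X, y, w)              # covariates, outcomes, treatment indicators

    cate_hat = model.predict(X)     # estimated CATE per sample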

class SNet1(binary_y: bool = False, n_layers_r: int = 3, n_units_r: int = 200, n_layers_out: int = 2, n_units_out: int = 100, penalty_l2: float = 0.0001, step_size: float = 0.0001, n_iter: int = 10000, batch_size: int = 100, val_split_prop: float = 0.3, early_stopping: bool = True, patience: int = 10, n_iter_min: int = 200, n_iter_print: int = 50, seed: int = 42, reg_diff: bool = False, penalty_diff: float = 0.0001, same_init: bool = False, nonlin: str = 'elu', penalty_disc: float = 0)

Bases: catenets.models.jax.base.BaseCATENet

Class implements Shalit et al (2017)'s TARNet and CFR (the discrepancy regularization is NOT TESTED). Also referred to as SNet-1 in our paper. See the usage sketch after the parameter list below.

Parameters
  • binary_y (bool, default False) – Whether the outcome is binary

  • n_layers_out (int) – Number of hypothesis layers (n_layers_out x n_units_out + 1 x Dense layer)

  • n_units_out (int) – Number of hidden units in each hypothesis layer

  • n_layers_r (int) – Number of shared representation layers before hypothesis layers

  • n_units_r (int) – Number of hidden units in each representation layer

  • penalty_l2 (float) – l2 (ridge) penalty

  • step_size (float) – learning rate for optimizer

  • n_iter (int) – Maximum number of iterations

  • batch_size (int) – Batch size

  • val_split_prop (float) – Proportion of samples used for validation split (can be 0)

  • early_stopping (bool, default True) – Whether to use early stopping

  • patience (int) – Number of iterations without improvement in validation loss to wait before early stopping

  • n_iter_min (int) – Minimum number of iterations to go through before starting early stopping

  • n_iter_print (int) – Number of iterations after which to print updates

  • seed (int) – Seed used

  • reg_diff (bool, default False) – Whether to regularize the difference between the two potential outcome heads

  • penalty_diff (float) – l2-penalty for regularizing the difference between output heads. Used only if reg_diff=True

  • same_init (bool, default False) – Whether to initialise the two output heads with the same values

  • nonlin (string, default 'elu') – Nonlinearity to use in NN

  • penalty_disc (float, default 0) – Discrepancy penalty. Defaults to zero as this feature is not tested.
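
A usage sketch on synthetic data, assuming the fit/predict interface from BaseCATENet. Setting penalty_disc > 0 switches from TARNet-style training to CFR-style discrepancy regularization, which the authors flag as untested:

    import numpy as np

    from catenets.models.jax.representation_nets import SNet1

    rng = np.random.default_rng(0)
    X = rng.normal(size=(500, 10))
    w = rng.binomial(1, 0.5, size=(500, 1))
    y = X[:, :1] * (1 + w) + rng.normal(scale=0.1, size=(500, 1))

    # TARNet-style training: no discrepancy penalty (the tested default).
    model = SNet1(penalty_disc=0, n_iter=1000)
    model.fit(X, y, w)
    cate_hat = model.predict(X)

    # CFR-style training: adds the (NOT TESTED) discrepancy penalty on the
    # shared representation.
    cfr = SNet1(penalty_disc=0.1, n_iter=1000)
    cfr.fit(X, y, w)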

_get_predict_function() → Callable
_get_train_function() → Callable
class SNet2(binary_y: bool = False, n_layers_r: int = 3, n_units_r: int = 200, n_layers_out: int = 2, n_units_out: int = 100, penalty_l2: float = 0.0001, n_units_out_prop: int = 100, n_layers_out_prop: int = 2, step_size: float = 0.0001, n_iter: int = 10000, batch_size: int = 100, val_split_prop: float = 0.3, early_stopping: bool = True, patience: int = 10, n_iter_min: int = 200, n_iter_print: int = 50, seed: int = 42, reg_diff: bool = False, same_init: bool = False, penalty_diff: float = 0.0001, nonlin: str = 'elu')

Bases: catenets.models.jax.base.BaseCATENet

Class implements SNet-2, which is based on Shi et al (2019)'s DragonNet. This version does NOT use targeted regularization and has a (possibly deeper) propensity head. See the usage sketch after the parameter list below.

Parameters
  • binary_y (bool, default False) – Whether the outcome is binary

  • n_layers_out (int) – Number of hypothesis layers (n_layers_out x n_units_out + 1 x Dense layer)

  • n_layers_out_prop (int) – Number of hypothesis layers for the propensity score (n_layers_out_prop x n_units_out_prop + 1 x Dense layer)

  • n_units_out (int) – Number of hidden units in each hypothesis layer

  • n_units_out_prop (int) – Number of hidden units in each propensity score hypothesis layer

  • n_layers_r (int) – Number of shared representation layers before hypothesis layers

  • n_units_r (int) – Number of hidden units in each representation layer

  • penalty_l2 (float) – l2 (ridge) penalty

  • step_size (float) – learning rate for optimizer

  • n_iter (int) – Maximum number of iterations

  • batch_size (int) – Batch size

  • val_split_prop (float) – Proportion of samples used for validation split (can be 0)

  • early_stopping (bool, default True) – Whether to use early stopping

  • patience (int) – Number of iterations without improvement in validation loss to wait before early stopping

  • n_iter_min (int) – Minimum number of iterations to go through before starting early stopping

  • n_iter_print (int) – Number of iterations after which to print updates

  • seed (int) – Seed used

  • reg_diff (bool, default False) – Whether to regularize the difference between the two potential outcome heads

  • penalty_diff (float) – l2-penalty for regularizing the difference between output heads. Used only if reg_diff=True

  • same_init (bool, default False) – Whether to initialise the two output heads with the same values

  • nonlin (string, default 'elu') – Nonlinearity to use in NN
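
A usage sketch on synthetic data; the return_po/return_prop behaviour of the class predict method is assumed to mirror the flags of predict_snet2 documented below:

    import numpy as np

    from catenets.models.jax.representation_nets import SNet2

    rng = np.random.default_rng(0)
    X = rng.normal(size=(500, 10))
    w = rng.binomial(1, 0.5, size=(500, 1))
    y = X[:, :1] + w * X[:, 1:2] + rng.normal(scale=0.1, size=(500, 1))

    model = SNet2(n_layers_out_prop=2, n_units_out_prop=100, n_iter=1000)
    model.fit(X, y, w)

    cate_hat = model.predict(X)  # CATE estimates only

    # Assumed ordering: CATE estimate plus the two potential-outcome heads.
    cate_hat, mu0_hat, mu1_hat = model.predict(X, return_po=True)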

_get_predict_function() → Callable
_get_train_function() → Callable
class TARNet(binary_y: bool = False, n_layers_r: int = 3, n_units_r: int = 200, n_layers_out: int = 2, n_units_out: int = 100, penalty_l2: float = 0.0001, step_size: float = 0.0001, n_iter: int = 10000, batch_size: int = 100, val_split_prop: float = 0.3, early_stopping: bool = True, patience: int = 10, n_iter_min: int = 200, n_iter_print: int = 50, seed: int = 42, reg_diff: bool = False, penalty_diff: float = 0.0001, same_init: bool = False, nonlin: str = 'elu')

Bases: catenets.models.jax.representation_nets.SNet1

Wrapper for TARNet
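
TARNet's constructor omits penalty_disc, consistent with it being SNet1 with the discrepancy penalty left at its zero default. A small sketch:

    from catenets.models.jax.representation_nets import SNet1, TARNet

    model = TARNet(n_layers_r=3, n_units_r=200)
    print(isinstance(model, SNet1))  # True: TARNet subclasses SNet1
    # fit/predict work exactly as for SNet1 above.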

mmd2_lin(X: jax.Array, w: jax.Array) → jax.Array
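
The function carries no docstring; given its role in the discrepancy penalty of Shalit et al (2017), a plausible reading is the linear-kernel maximum mean discrepancy between treated and control rows of a representation X. A hypothetical re-implementation sketch, not the library's exact code:

    import jax.numpy as jnp

    def mmd2_lin_sketch(X: jnp.ndarray, w: jnp.ndarray) -> jnp.ndarray:
        # Squared distance between the mean representation of the treated
        # (w == 1) and control (w == 0) groups, i.e. MMD^2 with a linear
        # kernel. The actual mmd2_lin may differ in weighting and constants.
        w = w.reshape(-1, 1)
        mean_treated = jnp.sum(w * X, axis=0) / jnp.sum(w)
        mean_control = jnp.sum((1 - w) * X, axis=0) / jnp.sum(1 - w)
        return jnp.sum((mean_treated - mean_control) ** 2)
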
predict_snet1(X: jax.Array, trained_params: dict, predict_funs: list, return_po: bool = False, return_prop: bool = False) → jax.Array
predict_snet2(X: jax.Array, trained_params: dict, predict_funs: list, return_po: bool = False, return_prop: bool = False) → jax.Array
train_snet1(X: jax.Array, y: jax.Array, w: jax.Array, binary_y: bool = False, n_layers_r: int = 3, n_units_r: int = 200, n_layers_out: int = 2, n_units_out: int = 100, penalty_l2: float = 0.0001, penalty_disc: int = 0, step_size: float = 0.0001, n_iter: int = 10000, batch_size: int = 100, val_split_prop: float = 0.3, early_stopping: bool = True, patience: int = 10, n_iter_min: int = 200, n_iter_print: int = 50, seed: int = 42, return_val_loss: bool = False, reg_diff: bool = False, same_init: bool = False, penalty_diff: float = 0.0001, nonlin: str = 'elu', avg_objective: bool = True) → Any
train_snet2(X: jax.Array, y: jax.Array, w: jax.Array, binary_y: bool = False, n_layers_r: int = 3, n_units_r: int = 200, n_layers_out: int = 2, n_units_out: int = 100, penalty_l2: float = 0.0001, n_units_out_prop: int = 100, n_layers_out_prop: int = 2, step_size: float = 0.0001, n_iter: int = 10000, batch_size: int = 100, val_split_prop: float = 0.3, early_stopping: bool = True, patience: int = 10, n_iter_min: int = 200, n_iter_print: int = 50, seed: int = 42, return_val_loss: bool = False, reg_diff: bool = False, penalty_diff: float = 0.0001, nonlin: str = 'elu', avg_objective: bool = True, same_init: bool = False) → Any

SNet2 corresponds to DragonNet (Shi et al, 2019), without the targeted (TMLE) regularisation term.
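
The module-level train/predict pairs mirror the classes. A sketch, assuming train_snet2 returns the trained parameters and the prediction functions consumed by predict_snet2 (the Any return annotation above does not pin this down):

    import numpy as np

    from catenets.models.jax.representation_nets import (
        predict_snet2,
        train_snet2,
    )

    rng = np.random.default_rng(0)
    X = rng.normal(size=(500, 10))
    w = rng.binomial(1, 0.5, size=(500, 1))
    y = X[:, :1] + w + rng.normal(scale=0.1, size=(500, 1))

    # Assumption: a (params, predict_funs) pair comes back from training.
    trained_params, predict_funs = train_snet2(X, y, w, n_iter=1000)

    cate_hat = predict_snet2(X, trained_params, predict_funs)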