catenets.models.jax.flextenet module

Module implements FlexTENet, also referred to as the ‘flexible approach’ in “On inductive biases for heterogeneous treatment effect estimation”, Curth & vd Schaar (2021).

DenseW(out_dim: int, W_init: Callable = <function variance_scaling.<locals>.init>, b_init: Callable = <function normal.<locals>.init>) → Tuple

Layer constructor function for a dense (fully-connected) layer. Adapted to allow passing the treatment indicator through the layer without using it.
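
For orientation, a minimal sketch of such a stax-style layer is shown below. It is illustrative only: the name dense_passthrough and the exact handling of the treatment indicator are assumptions, not the library's implementation. The layer applies an affine map to the features and carries the treatment indicator through unchanged.

    import jax.numpy as jnp
    from jax import random
    from jax.nn.initializers import glorot_normal, normal

    def dense_passthrough(out_dim, W_init=glorot_normal(), b_init=normal()):
        # Stax-style (init_fun, apply_fun) pair: affine map on X, treatment t passed through.
        def init_fun(rng, input_shape):
            k1, k2 = random.split(rng)
            W = W_init(k1, (input_shape[-1], out_dim))
            b = b_init(k2, (out_dim,))
            return input_shape[:-1] + (out_dim,), (W, b)

        def apply_fun(params, inputs, **kwargs):
            X, t = inputs              # treatment indicator rides along, unused by the affine map
            W, b = params
            return jnp.dot(X, W) + b, t

        return init_fun, apply_fun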

class FlexTENet(binary_y: bool = False, n_layers_out: int = 2, n_units_s_out: int = 50, n_units_p_out: int = 50, n_layers_r: int = 3, n_units_s_r: int = 100, n_units_p_r: int = 100, private_out: bool = False, penalty_l2: float = 0.0001, penalty_l2_p: float = 0.0001, penalty_orthogonal: float = 0.01, step_size: float = 0.0001, n_iter: int = 10000, batch_size: int = 100, val_split_prop: float = 0.3, early_stopping: bool = True, patience: int = 10, n_iter_min: int = 200, n_iter_print: int = 50, seed: int = 42, return_val_loss: bool = False, opt: str = 'adam', shared_repr: bool = False, pretrain_shared: bool = False, same_init: bool = True, lr_scale: float = 10, normalize_ortho: bool = False)

Bases: catenets.models.jax.base.BaseCATENet

Module implements FlexTENet, an architecture for treatment effect estimation that allows for both shared and private information in each layer of the network.

Parameters
  • binary_y (bool, default False) – Whether the outcome is binary

  • n_layers_out (int) – Number of hypothesis layers (n_layers_out hidden layers, each with n_units_s_out shared and n_units_p_out private units, plus a final output layer)

  • n_units_s_out (int) – Number of hidden units in each shared hypothesis layer

  • n_units_p_out (int) – Number of hidden units in each private hypothesis layer

  • n_layers_r (int) – Number of representation layers before hypothesis layers (distinction between hypothesis layers and representation layers is made to match TARNet & SNets)

  • n_units_s_r (int) – Number of hidden units in each shared representation layer

  • n_units_p_r (int) – Number of hidden units in each private representation layer

  • private_out (bool, False) – Whether the final prediction layer should be fully private, or retain a shared component.

  • penalty_l2 (float) – l2 (ridge) penalty

  • penalty_l2_p (float) – l2 (ridge) penalty for private layers

  • penalty_orthogonal (float) – orthogonalisation penalty

  • step_size (float) – learning rate for optimizer

  • n_iter (int) – Maximum number of iterations

  • batch_size (int) – Batch size

  • val_split_prop (float) – Proportion of samples used for validation split (can be 0)

  • early_stopping (bool, default True) – Whether to use early stopping

  • patience (int) – Number of iterations without improvement in validation loss to wait before early stopping

  • n_iter_min (int) – Minimum number of iterations to go through before starting early stopping

  • n_iter_print (int) – Number of iterations after which to print updates

  • seed (int) – Seed used

  • opt (str, default 'adam') – Optimizer to use, accepts ‘adam’ and ‘sgd’

  • shared_repr (bool, False) – Whether to use a shared representation block as in TARNet

  • pretrain_shared (bool, False) – Whether to pretrain the shared component of the network while freezing the private parameters

  • same_init (bool, True) – Whether to use the same initialisation for all private spaces

  • lr_scale (float) – Factor by which to scale down the learning rate after unfreezing the private components of the network (only used if pretrain_shared=True)

  • normalize_ortho (bool, False) – Whether to normalize the orthogonality penalty (by depth of network)
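
A minimal usage sketch follows, assuming the standard fit/predict interface that catenets estimators inherit from BaseCATENet; the toy data and hyperparameter choices are illustrative only.

    import numpy as np
    from catenets.models.jax.flextenet import FlexTENet

    # toy data: 500 samples, 10 covariates, binary treatment w, continuous outcome y
    rng = np.random.default_rng(42)
    X = rng.normal(size=(500, 10))
    w = rng.binomial(1, 0.5, size=(500, 1))
    y = X[:, [0]] + w * X[:, [1]] + rng.normal(scale=0.1, size=(500, 1))

    model = FlexTENet(n_iter=1000, batch_size=100)  # shortened training for the sketch
    model.fit(X, y, w)
    cate_hat = model.predict(X)  # estimated conditional average treatment effects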

_get_predict_function() → Callable
_get_train_function() → Callable
FlexTENetArchitecture(n_layers_out: int = 2, n_units_s_out: int = 50, n_units_p_out: int = 50, n_layers_r: int = 3, n_units_s_r: int = 100, n_units_p_r: int = 100, private_out: bool = False, binary_y: bool = False, shared_repr: bool = False, same_init: bool = True) → Any
SplitLayerAsymmetric(n_units_s: int, n_units_p: int, first_layer: bool = False, same_init: bool = True) → Tuple
TEOutputLayerAsymmetric(private: bool = True, same_init: bool = True) → Tuple
_compute_ortho_penalty_asymmetric(params: jax._src.basearray.Array, n_layers_out: int, n_layers_r: int, private_out: int, penalty_orthogonal: float, shared_repr: bool, normalize_ortho: bool, mode: int = 1) → float
_compute_penalty(params: jax._src.basearray.Array, n_layers_out: int, n_layers_r: int, private_out: int, penalty_l2: float, penalty_l2_p: float, penalty_orthogonal: float, shared_repr: bool, normalize_ortho: bool, mode: int = 1) → jax._src.basearray.Array
_compute_penalty_l2(params: jax._src.basearray.Array, n_layers_out: int, n_layers_r: int, private_out: int, penalty_l2: float, penalty_l2_p: float, shared_repr: bool, mode: int = 1) → jax._src.basearray.Array
_get_cos_reg(params_0: jax._src.basearray.Array, params_1: jax._src.basearray.Array, normalize: bool) → jax._src.basearray.Array
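
The orthogonalisation penalty (weighted by penalty_orthogonal) discourages overlap between shared and private weights. A hedged sketch of a cosine-similarity-style regulariser between two weight matrices is shown below; the argument names mirror the _get_cos_reg signature above, but the internals are an assumption, not the library's exact computation.

    import jax.numpy as jnp

    def cos_reg_sketch(params_0, params_1, normalize):
        # optionally normalise each column so only directions (not scales) are penalised
        if normalize:
            params_0 = params_0 / jnp.linalg.norm(params_0, axis=0)
            params_1 = params_1 / jnp.linalg.norm(params_1, axis=0)
        # squared Frobenius norm of the pairwise column inner products:
        # zero when the shared and private columns are mutually orthogonal
        return jnp.linalg.norm(jnp.dot(params_0.T, params_1), "fro") ** 2
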
elementwise_parallel(fun: Callable, **fun_kwargs: Any) → Tuple

Layer that applies a scalar function elementwise to its inputs. Adapted from the original jax.stax to allow three inputs and to skip the treatment indicator.

Input looks like: X_s, X_p0, X_p1, t = inputs
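
A minimal sketch of such a layer is shown below (illustrative; the function name and the empty-parameter convention are assumptions). It applies the nonlinearity to the shared and both private representations and forwards the treatment indicator unchanged; elementwise_split is the analogous two-input version operating on (X, t).

    from jax.nn import elu

    def elementwise_parallel_sketch(fun, **fun_kwargs):
        # stax-style layer: no trainable parameters, so init_fun returns empty params
        init_fun = lambda rng, input_shape: (input_shape, ())

        def apply_fun(params, inputs, **kwargs):
            X_s, X_p0, X_p1, t = inputs
            return (fun(X_s, **fun_kwargs),
                    fun(X_p0, **fun_kwargs),
                    fun(X_p1, **fun_kwargs),
                    t)  # treatment indicator passed through untouched
        return init_fun, apply_fun

    # e.g. an ELU nonlinearity applied in parallel to all three representations
    Elu_parallel = elementwise_parallel_sketch(elu)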

elementwise_split(fun: Callable, **fun_kwargs: Any) → Tuple

Layer that applies a scalar function elementwise to its inputs. Adapted from the original jax.stax to skip the treatment indicator.

Input looks like: X, t = inputs

predict_flextenet(X: jax._src.basearray.Array, trained_params: jax._src.basearray.Array, predict_funs: Callable, return_po: bool = False, return_prop: bool = False) → Any
train_flextenet(X: jax._src.basearray.Array, y: jax._src.basearray.Array, w: jax._src.basearray.Array, binary_y: bool = False, n_layers_out: int = 2, n_units_s_out: int = 50, n_units_p_out: int = 50, n_layers_r: int = 3, n_units_s_r: int = 100, n_units_p_r: int = 100, private_out: bool = False, penalty_l2: float = 0.0001, penalty_l2_p: float = 0.0001, penalty_orthogonal: float = 0.01, step_size: float = 0.0001, n_iter: int = 10000, batch_size: int = 100, val_split_prop: float = 0.3, early_stopping: bool = True, patience: int = 10, n_iter_min: int = 200, avg_objective: bool = True, n_iter_print: int = 50, seed: int = 42, return_val_loss: bool = False, opt: str = 'adam', shared_repr: bool = False, pretrain_shared: bool = False, same_init: bool = True, lr_scale: float = 10, normalize_ortho: bool = False, nonlin: str = 'elu', n_units_r: Optional[int] = None, n_units_out: Optional[int] = None) → Tuple
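
A hedged sketch of the lower-level functional flow implied by the two signatures above (reusing X, y, w from the class example; the composition of the returned tuples is an assumption, and the FlexTENet class wraps this logic and is the more convenient entry point):

    from catenets.models.jax.flextenet import train_flextenet, predict_flextenet

    # assumed to return the trained parameters and the predict functions
    # expected by predict_flextenet
    trained_params, predict_funs = train_flextenet(X, y, w, n_iter=1000)

    cate_hat = predict_flextenet(X, trained_params, predict_funs)
    # return_po=True is assumed to also yield the two potential-outcome heads
    cate_hat, mu0_hat, mu1_hat = predict_flextenet(
        X, trained_params, predict_funs, return_po=True
    )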