catenets.models.jax.flextenet module
This module implements FlexTENet, also referred to as the ‘flexible approach’ in “On inductive biases for heterogeneous treatment effect estimation”, Curth & van der Schaar (2021).
- DenseW(out_dim: int, W_init: Callable = <function variance_scaling.<locals>.init>, b_init: Callable = <function normal.<locals>.init>) → Tuple
Layer constructor function for a dense (fully connected) layer, adapted so that the treatment indicator is passed through the layer without being used.
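The pass-through behaviour can be sketched in the (init_fun, apply_fun) convention that stax layers use. This is a minimal NumPy sketch, not the catenets code: the name `dense_w_sketch`, the NumPy random generator standing in for a JAX PRNG key, and the Glorot-style initial scale are all illustrative assumptions.

```python
import numpy as np


def dense_w_sketch(out_dim):
    """Hypothetical stax-style dense layer that forwards a treatment indicator.

    Returns an (init_fun, apply_fun) pair following the stax convention.
    """

    def init_fun(rng, input_shape):
        # rng is a NumPy Generator here; stax itself uses JAX PRNG keys.
        in_dim = input_shape[-1]
        # Glorot-style scale is an illustrative assumption, not DenseW's exact initialiser.
        W = rng.normal(scale=np.sqrt(2.0 / (in_dim + out_dim)), size=(in_dim, out_dim))
        b = rng.normal(scale=1e-2, size=(out_dim,))
        return tuple(input_shape[:-1]) + (out_dim,), (W, b)

    def apply_fun(params, inputs):
        X, t = inputs  # t is carried along but not used in the computation
        W, b = params
        return X @ W + b, t

    return init_fun, apply_fun


# Usage: the treatment indicator comes out exactly as it went in.
init_fun, apply_fun = dense_w_sketch(3)
out_shape, params = init_fun(np.random.default_rng(0), (-1, 5))
X = np.ones((4, 5))
t = np.array([0, 1, 0, 1])
out, t_out = apply_fun(params, (X, t))
```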
- class FlexTENet(binary_y: bool = False, n_layers_out: int = 2, n_units_s_out: int = 50, n_units_p_out: int = 50, n_layers_r: int = 3, n_units_s_r: int = 100, n_units_p_r: int = 100, private_out: bool = False, penalty_l2: float = 0.0001, penalty_l2_p: float = 0.0001, penalty_orthogonal: float = 0.01, step_size: float = 0.0001, n_iter: int = 10000, batch_size: int = 100, val_split_prop: float = 0.3, early_stopping: bool = True, patience: int = 10, n_iter_min: int = 200, n_iter_print: int = 50, seed: int = 42, return_val_loss: bool = False, opt: str = 'adam', shared_repr: bool = False, pretrain_shared: bool = False, same_init: bool = True, lr_scale: float = 10, normalize_ortho: bool = False)
Bases: catenets.models.jax.base.BaseCATENet
Implements FlexTENet, an architecture for treatment effect estimation in which every layer of the network can hold both shared information (common to both potential outcomes) and private information (specific to each treatment arm).
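To make the shared/private split concrete, here is a hypothetical NumPy sketch of one hidden layer. It assumes that the shared block sees only the shared input while each private block sees the shared input concatenated with its own private input; this wiring is an assumption for illustration, and `split_layer_sketch`, `init_split_params` and `elu` are illustrative names rather than catenets functions.

```python
import numpy as np


def elu(x):
    # ELU with alpha=1; np.minimum avoids overflow in expm1 for large positive x.
    return np.where(x > 0, x, np.expm1(np.minimum(x, 0.0)))


def init_split_params(rng, d_s, d_p, n_s, n_p):
    """Hypothetical initialiser: the shared block maps d_s -> n_s,
    each private block maps (d_s + d_p) -> n_p."""
    def dense(d_in, d_out):
        return rng.normal(size=(d_in, d_out)), np.zeros(d_out)
    return {"s": dense(d_s, n_s), "p0": dense(d_s + d_p, n_p), "p1": dense(d_s + d_p, n_p)}


def split_layer_sketch(params, inputs):
    """One hidden layer with a shared block and two private blocks (sketch).

    Assumption: shared block <- shared input only;
    private block k <- [shared input, private input k].
    """
    X_s, X_p0, X_p1, t = inputs
    (W_s, b_s), (W_0, b_0), (W_1, b_1) = params["s"], params["p0"], params["p1"]
    h_s = elu(X_s @ W_s + b_s)
    h_0 = elu(np.concatenate([X_s, X_p0], axis=1) @ W_0 + b_0)  # arm 0 (control)
    h_1 = elu(np.concatenate([X_s, X_p1], axis=1) @ W_1 + b_1)  # arm 1 (treated)
    return h_s, h_0, h_1, t


rng = np.random.default_rng(0)
params = init_split_params(rng, d_s=4, d_p=2, n_s=3, n_p=5)
n = 6
inputs = (rng.normal(size=(n, 4)), rng.normal(size=(n, 2)), rng.normal(size=(n, 2)), np.zeros(n))
h_s, h_0, h_1, t = split_layer_sketch(params, inputs)
```

Under this assumption information flows from the shared stream into the private streams but never the other way, so the shared stream can only encode structure common to both arms.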
- Parameters
binary_y (bool, default False) – Whether the outcome is binary
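n_layers_out (int) – Number of hypothesis layers (n_layers_out hidden layers, each with n_units_s_out shared units and n_units_p_out private units per treatment arm, followed by one output Dense layer)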
n_units_s_out (int) – Number of hidden units in each shared hypothesis layer
n_units_p_out (int) – Number of hidden units in each private hypothesis layer
n_layers_r (int) – Number of representation layers before the hypothesis layers (the distinction between hypothesis and representation layers is made to match TARNet and SNet)
n_units_s_r (int) – Number of hidden units in each shared representation layer
n_units_p_r (int) – Number of hidden units in each private representation layer
private_out (bool, default False) – Whether the final prediction layer should be fully private or retain a shared component.
penalty_l2 (float) – l2 (ridge) penalty
penalty_l2_p (float) – l2 (ridge) penalty for private layers
penalty_orthogonal (float) – orthogonalisation penalty
step_size (float) – learning rate for optimizer
n_iter (int) – Maximum number of iterations
batch_size (int) – Batch size
val_split_prop (float) – Proportion of samples used for validation split (can be 0)
early_stopping (bool, default True) – Whether to use early stopping
patience (int) – Number of consecutive iterations without a decrease in validation loss to tolerate before stopping early
n_iter_min (int) – Minimum number of iterations to go through before starting early stopping
n_iter_print (int) – Number of iterations after which to print updates
seed (int) – Random seed
opt (str, default 'adam') – Optimizer to use, accepts ‘adam’ and ‘sgd’
shared_repr (bool, default False) – Whether to use a fully shared representation block, as in TARNet
pretrain_shared (bool, default False) – Whether to pretrain the shared component of the network while freezing the private parameters
same_init (bool, default True) – Whether to use the same initialisation for all private spaces
lr_scale (float) – Factor by which to scale down the learning rate after unfreezing the private components of the network (only used if pretrain_shared=True)
normalize_ortho (bool, default False) – Whether to normalize the orthogonality penalty by the depth of the network
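The interaction of early_stopping, patience and n_iter_min can be illustrated with a generic patience-based stopping rule. The sketch is an assumption about the semantics, not the catenets training loop: it checks one validation loss per iteration, ignores the first n_iter_min iterations, and stops once the loss has failed to improve more than patience consecutive times. `stopping_iteration` is a hypothetical helper.

```python
def stopping_iteration(val_losses, patience=10, n_iter_min=200):
    """Return (iteration at which training stops, iteration of best loss).

    Hypothetical sketch of patience-based early stopping; the best iteration
    is None if early stopping never became active.
    """
    best_loss = float("inf")
    best_iter = None
    n_without_improvement = 0
    for i, loss in enumerate(val_losses):
        if i < n_iter_min:
            continue  # early stopping is not active yet
        if loss < best_loss:
            best_loss, best_iter = loss, i
            n_without_improvement = 0
        else:
            n_without_improvement += 1
        if n_without_improvement > patience:
            return i, best_iter  # a real loop would restore the best parameters here
    return len(val_losses) - 1, best_iter


# Loss bottoms out at iteration 3; with patience=2 training stops at iteration 6.
print(stopping_iteration([5, 4, 3, 2, 2.5, 2.6, 2.7, 2.8], patience=2, n_iter_min=0))
```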
- _get_predict_function() → Callable
- _get_train_function() → Callable
- FlexTENetArchitecture(n_layers_out: int = 2, n_units_s_out: int = 50, n_units_p_out: int = 50, n_layers_r: int = 3, n_units_s_r: int = 100, n_units_p_r: int = 100, private_out: bool = False, binary_y: bool = False, shared_repr: bool = False, same_init: bool = True) → Any
- SplitLayerAsymmetric(n_units_s: int, n_units_p: int, first_layer: bool = False, same_init: bool = True) → Tuple
- TEOutputLayerAsymmetric(private: bool = True, same_init: bool = True) → Tuple
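A hypothetical sketch of how an output layer with a private flag can turn the shared and private streams into a single factual prediction. It assumes that each arm has its own linear head on the concatenated shared and private inputs, that private=False adds a shared linear head to both arms, and that t selects which arm's output is returned; `te_output_sketch` is illustrative, not the catenets implementation.

```python
import numpy as np


def te_output_sketch(params, inputs, private=True):
    """Final layer selecting the arm-specific output via t (hypothetical sketch).

    Assumption: with private=False a shared linear output is added to both arms.
    """
    X_s, X_p0, X_p1, t = inputs
    (W_s, b_s), (W_0, b_0), (W_1, b_1) = params["s"], params["p0"], params["p1"]
    out_0 = np.concatenate([X_s, X_p0], axis=1) @ W_0 + b_0
    out_1 = np.concatenate([X_s, X_p1], axis=1) @ W_1 + b_1
    if not private:
        shared = X_s @ W_s + b_s
        out_0, out_1 = out_0 + shared, out_1 + shared
    t = t.reshape(-1, 1)
    return (1 - t) * out_0 + t * out_1  # out_0 where t == 0, out_1 where t == 1


rng = np.random.default_rng(1)
params = {"s": (rng.normal(size=(3, 1)), np.zeros(1)),
          "p0": (rng.normal(size=(5, 1)), np.zeros(1)),
          "p1": (rng.normal(size=(5, 1)), np.zeros(1))}
X_s, X_p0, X_p1 = rng.normal(size=(4, 3)), rng.normal(size=(4, 2)), rng.normal(size=(4, 2))
y_control = te_output_sketch(params, (X_s, X_p0, X_p1, np.zeros(4)), private=False)
```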
- _compute_ortho_penalty_asymmetric(params: jax.Array, n_layers_out: int, n_layers_r: int, private_out: int, penalty_orthogonal: float, shared_repr: bool, normalize_ortho: bool, mode: int = 1) → float
- _compute_penalty(params: jax.Array, n_layers_out: int, n_layers_r: int, private_out: int, penalty_l2: float, penalty_l2_p: float, penalty_orthogonal: float, shared_repr: bool, normalize_ortho: bool, mode: int = 1) → jax.Array
- _compute_penalty_l2(params: jax.Array, n_layers_out: int, n_layers_r: int, private_out: int, penalty_l2: float, penalty_l2_p: float, shared_repr: bool, mode: int = 1) → jax.Array
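The split between penalty_l2 and penalty_l2_p can be sketched as a weighted sum of squared weights. This assumes penalty_l2 weights the shared matrices and penalty_l2_p the private ones, which matches the parameter descriptions of FlexTENet but has not been checked against `_compute_penalty_l2`; `l2_penalty_sketch` is a hypothetical helper, and the caller decides which arrays count as shared or private.

```python
import numpy as np


def l2_penalty_sketch(shared_weights, private_weights, penalty_l2, penalty_l2_p):
    """Weighted sum of squared weights (hypothetical sketch).

    Assumption: penalty_l2 applies to shared weight arrays and
    penalty_l2_p to private ones.
    """
    sq_shared = sum(float(np.sum(W ** 2)) for W in shared_weights)
    sq_private = sum(float(np.sum(W ** 2)) for W in private_weights)
    return penalty_l2 * sq_shared + penalty_l2_p * sq_private


# 0.5 * 4 (shared) + 0.25 * 12 (private)
penalty = l2_penalty_sketch([np.ones((2, 2))], [2 * np.ones(3)], penalty_l2=0.5, penalty_l2_p=0.25)
```

Keeping two separate weights lets the private spaces be regularised more (or less) strongly than the shared space, which controls how much the two arms are allowed to diverge.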
- _get_cos_reg(params_0: jax.Array, params_1: jax.Array, normalize: bool) → jax.Array
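The name `_get_cos_reg` suggests a cosine-style orthogonality regulariser. A common form, used here as an assumption, is the squared Frobenius norm of W_0^T W_1, which is zero exactly when every column of W_0 is orthogonal to every column of W_1. In this sketch normalize=True rescales columns to unit norm so each entry becomes a cosine similarity; whether `normalize` in `_get_cos_reg` means the same thing is not confirmed, and `cos_reg_sketch` is a hypothetical name.

```python
import numpy as np


def cos_reg_sketch(W_0, W_1, normalize=False):
    """Squared Frobenius norm of W_0^T W_1 (hypothetical sketch).

    Both matrices must act on the same input features (same number of rows).
    With normalize=True, columns are rescaled to unit norm first.
    """
    if normalize:
        W_0 = W_0 / np.linalg.norm(W_0, axis=0, keepdims=True)
        W_1 = W_1 / np.linalg.norm(W_1, axis=0, keepdims=True)
    return float(np.linalg.norm(W_0.T @ W_1, "fro") ** 2)


orthogonal = cos_reg_sketch(np.array([[1.0], [0.0]]), np.array([[0.0], [1.0]]))
parallel = cos_reg_sketch(np.array([[2.0], [0.0]]), np.array([[3.0], [0.0]]))
parallel_normalized = cos_reg_sketch(np.array([[2.0], [0.0]]), np.array([[3.0], [0.0]]), normalize=True)
```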
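- elementwise_parallel(fun: Callable, **fun_kwargs: Any) → Tuple
Layer that applies a scalar function elementwise to its inputs. Adapted from the original stax library (jax.experimental.stax, now jax.example_libraries.stax) to accept three representation inputs and to pass the treatment indicator through unchanged.
Inputs are unpacked as X_s, X_p0, X_p1, t = inputs: the shared representation, the two private representations (one per treatment arm) and the treatment indicator.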
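A minimal NumPy sketch of this layer in the stax (init_fun, apply_fun) convention. The activation is applied to each of the three representations and t is returned as-is; `elementwise_parallel_sketch` is a hypothetical name, and real stax layers operate on JAX arrays with JAX PRNG keys.

```python
import numpy as np


def elementwise_parallel_sketch(fun, **fun_kwargs):
    """stax-style (init_fun, apply_fun) pair applying fun to three streams (sketch)."""

    def init_fun(rng, input_shape):
        return input_shape, ()  # no trainable parameters

    def apply_fun(params, inputs):
        X_s, X_p0, X_p1, t = inputs
        # The activation is applied to each representation; t is returned unchanged.
        return fun(X_s, **fun_kwargs), fun(X_p0, **fun_kwargs), fun(X_p1, **fun_kwargs), t

    return init_fun, apply_fun


init_fun, apply_fun = elementwise_parallel_sketch(np.tanh)
t = np.array([0.0, 1.0])
outputs = apply_fun((), (np.zeros((2, 3)), np.ones((2, 4)), -np.ones((2, 4)), t))
```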
- elementwise_split(fun: Callable, **fun_kwargs: Any) → Tuple
Layer that applies a scalar function elementwise to its inputs. Adapted from the original stax library (jax.experimental.stax, now jax.example_libraries.stax) to pass the treatment indicator through unchanged.
Inputs are unpacked as X, t = inputs.
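The single-stream variant follows the same pattern. This hypothetical sketch also shows how fun_kwargs are forwarded to the activation on every call, using an illustrative `leaky_relu` with a `negative_slope` keyword; neither `elementwise_split_sketch` nor `leaky_relu` is part of catenets.

```python
import numpy as np


def elementwise_split_sketch(fun, **fun_kwargs):
    """stax-style (init_fun, apply_fun) pair for a single stream plus t (sketch)."""

    def init_fun(rng, input_shape):
        return input_shape, ()  # shape-preserving, no parameters

    def apply_fun(params, inputs):
        X, t = inputs
        return fun(X, **fun_kwargs), t  # t is passed through untouched

    return init_fun, apply_fun


def leaky_relu(x, negative_slope):
    return np.where(x > 0, x, negative_slope * x)


# fun_kwargs (here negative_slope=0.1) are forwarded to the activation on every call.
_, apply_fun = elementwise_split_sketch(leaky_relu, negative_slope=0.1)
t = np.array([1, 0])
out, t_out = apply_fun((), (np.array([[-2.0, 3.0]]), t))
```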
- predict_flextenet(X: jax.Array, trained_params: jax.Array, predict_funs: Callable, return_po: bool = False, return_prop: bool = False) → Any
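A common way to obtain treatment-effect estimates from a network that takes (X, t) is to evaluate it once with t = 0 and once with t = 1 and take the difference; return_po=True would then plausibly return the potential outcomes as well. The sketch assumes exactly that and uses a single hypothetical `predict_fun`, whereas `predict_flextenet` takes `predict_funs`; `predict_sketch` and the toy model are illustrative only.

```python
import numpy as np


def predict_sketch(predict_fun, params, X, return_po=False):
    """Estimate CATE by evaluating the network under both treatment values (sketch)."""
    n = X.shape[0]
    mu_0 = predict_fun(params, (X, np.zeros(n)))  # potential outcome under control
    mu_1 = predict_fun(params, (X, np.ones(n)))   # potential outcome under treatment
    te = mu_1 - mu_0
    return (te, mu_0, mu_1) if return_po else te


# Toy model: outcome = sum of features + params * t, so the true effect equals params.
def toy_predict_fun(params, inputs):
    X, t = inputs
    return X.sum(axis=1) + params * t


X = np.arange(6.0).reshape(3, 2)
te, mu_0, mu_1 = predict_sketch(toy_predict_fun, 2.0, X, return_po=True)
```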
- train_flextenet(X: jax.Array, y: jax.Array, w: jax.Array, binary_y: bool = False, n_layers_out: int = 2, n_units_s_out: int = 50, n_units_p_out: int = 50, n_layers_r: int = 3, n_units_s_r: int = 100, n_units_p_r: int = 100, private_out: bool = False, penalty_l2: float = 0.0001, penalty_l2_p: float = 0.0001, penalty_orthogonal: float = 0.01, step_size: float = 0.0001, n_iter: int = 10000, batch_size: int = 100, val_split_prop: float = 0.3, early_stopping: bool = True, patience: int = 10, n_iter_min: int = 200, avg_objective: bool = True, n_iter_print: int = 50, seed: int = 42, return_val_loss: bool = False, opt: str = 'adam', shared_repr: bool = False, pretrain_shared: bool = False, same_init: bool = True, lr_scale: float = 10, normalize_ortho: bool = False, nonlin: str = 'elu', n_units_r: Optional[int] = None, n_units_out: Optional[int] = None) → Tuple
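The factual part of the training objective can be sketched as follows, with `pred` standing for the network output at the observed treatment w. It assumes that binary_y=True switches from squared error to binary cross-entropy on predicted probabilities and that avg_objective chooses between the mean and the sum over samples; neither is confirmed here, and the penalties computed by `_compute_penalty` would be added on top. `factual_loss_sketch` is a hypothetical helper.

```python
import numpy as np


def factual_loss_sketch(pred, y, binary_y=False, avg_objective=True, eps=1e-7):
    """Loss on observed (factual) outcomes (hypothetical sketch).

    Assumptions: binary_y selects cross-entropy on probabilities,
    avg_objective selects mean (True) versus sum (False).
    """
    if binary_y:
        p = np.clip(pred, eps, 1 - eps)  # guard against log(0)
        losses = -(y * np.log(p) + (1 - y) * np.log(1 - p))
    else:
        losses = (pred - y) ** 2
    return losses.mean() if avg_objective else losses.sum()


mse_mean = factual_loss_sketch(np.array([1.0, 2.0]), np.array([0.0, 0.0]))
mse_sum = factual_loss_sketch(np.array([1.0, 2.0]), np.array([0.0, 0.0]), avg_objective=False)
bce = factual_loss_sketch(np.array([0.5, 0.5]), np.array([1.0, 0.0]), binary_y=True)
```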