catenets.models.torch.flextenet module
- class ElementWiseParallelActivation(act: Callable, **act_kwargs: Any)
Bases: torch.nn.modules.module.Module
Layer that applies a scalar function elementwise on its inputs.
Input looks like: X_s, X_p0, X_p1, t = inputs
- forward(tensors: List[torch.Tensor]) → List
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- training: bool
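A minimal usage sketch (illustrative shapes and values; this assumes the activation is applied to the shared and both private representations while the treatment indicator t is passed through unchanged):

```python
import torch
from catenets.models.torch.flextenet import ElementWiseParallelActivation

# Hypothetical batch of 10 samples with 3 shared and 5 private hidden units.
X_s = torch.randn(10, 3)    # shared representation
X_p0 = torch.randn(10, 5)   # private representation, control arm
X_p1 = torch.randn(10, 5)   # private representation, treated arm
t = torch.ones(10, 1)       # treatment indicator

act = ElementWiseParallelActivation(torch.nn.functional.elu)
out = act([X_s, X_p0, X_p1, t])  # list mirroring the input structure
```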
- class ElementWiseSplitActivation(act: Callable, **act_kwargs: Any)
Bases: torch.nn.modules.module.Module
Layer that applies a scalar function elementwise on its inputs.
Input looks like: X, t = inputs
- forward(tensors: List[torch.Tensor]) → List
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- training: bool
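A minimal usage sketch (illustrative values; this assumes the activation is applied to X while the treatment indicator t is passed through unchanged):

```python
import torch
from catenets.models.torch.flextenet import ElementWiseSplitActivation

X = torch.randn(10, 8)                    # features / representation
t = torch.randint(0, 2, (10, 1)).float()  # treatment indicator

act = ElementWiseSplitActivation(torch.nn.functional.relu)
out = act([X, t])  # assumed output: [activated X, t]
```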
- class FlexTELinearLayer(name: str, dropout: bool = False, dropout_prob: float = 0.5, *args: Any, **kwargs: Any)
Bases: torch.nn.modules.module.Module
Layer constructor function for a fully-connected layer, adapted to allow passing the treatment indicator through the layer without using it.
- forward(tensors: List[torch.Tensor]) → List
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- training: bool
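A usage sketch, assuming the extra positional arguments (in_features, out_features) are forwarded to the underlying torch.nn.Linear and that the forward pass takes [X, t] and carries t along unchanged:

```python
import torch
from catenets.models.torch.flextenet import FlexTELinearLayer

# Hypothetical 25 -> 50 fully-connected layer; the trailing 25, 50 are assumed
# to be passed on to torch.nn.Linear.
layer = FlexTELinearLayer("shared_layer_0", False, 0.5, 25, 50)

X = torch.randn(16, 25)
t = torch.randint(0, 2, (16, 1)).float()
out = layer([X, t])  # treatment indicator is passed through, not consumed
```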
- class FlexTENet(n_unit_in: int, binary_y: bool, n_layers_out: int = 2, n_units_s_out: int = 50, n_units_p_out: int = 50, n_layers_r: int = 3, n_units_s_r: int = 100, n_units_p_r: int = 100, private_out: bool = False, weight_decay: float = 0.0001, penalty_orthogonal: float = 0.01, lr: float = 0.0001, n_iter: int = 10000, batch_size: int = 100, val_split_prop: float = 0.3, early_stopping: bool = True, patience: int = 10, n_iter_min: int = 200, n_iter_print: int = 50, seed: int = 42, shared_repr: bool = False, normalize_ortho: bool = False, mode: int = 1, clipping_value: int = 1, dropout: bool = False, dropout_prob: float = 0.5)
Bases: catenets.models.torch.base.BaseCATEEstimator
Class implementing FlexTENet, an architecture for treatment effect estimation that allows for both shared and private information in each layer of the network (a construction sketch follows the parameter list below).
- Parameters
n_unit_in (int) – Number of features
binary_y (bool, default False) – Whether the outcome is binary
n_layers_out (int) – Number of hypothesis layers (n_layers_out x n_units_out + 1 x Linear layer)
n_units_s_out (int) – Number of hidden units in each shared hypothesis layer
n_units_p_out (int) – Number of hidden units in each private hypothesis layer
n_layers_r (int) – Number of representation layers before hypothesis layers (distinction between hypothesis layers and representation layers is made to match TARNet & SNets)
n_units_s_r (int) – Number of hidden units in each shared representation layer
n_units_p_r (int) – Number of hidden units in each private representation layer
private_out (bool, default False) – Whether the final prediction layer should be fully private, or retain a shared component.
weight_decay (float) – l2 (ridge) penalty
penalty_orthogonal (float) – orthogonalisation penalty
lr (float) – learning rate for optimizer
n_iter (int) – Maximum number of iterations
batch_size (int) – Batch size
val_split_prop (float) – Proportion of samples used for validation split (can be 0)
early_stopping (bool, default True) – Whether to use early stopping
patience (int) – Number of iterations without improvement in validation loss to wait before early stopping
n_iter_min (int) – Minimum number of iterations to go through before starting early stopping
n_iter_print (int) – Number of iterations after which to print updates
seed (int) – Seed used
opt (str, default 'adam') – Optimizer to use, accepts ‘adam’ and ‘sgd’
shared_repr (bool, default False) – Whether to use a shared representation block as in TARNet
lr_scale (float) – Factor by which to scale down the learning rate after unfreezing the private components of the network (only used if pretrain_shared=True)
normalize_ortho (bool, default False) – Whether to normalize the orthogonality penalty (by the depth of the network)
clipping_value (int, default 1) – Gradients clipping value
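A construction sketch using the documented constructor arguments (all values are illustrative, not recommendations):

```python
from catenets.models.torch.flextenet import FlexTENet

model = FlexTENet(
    n_unit_in=25,        # number of input features
    binary_y=False,      # continuous outcome
    n_layers_out=2,      # hypothesis layers
    n_units_s_out=50,    # shared hidden units per hypothesis layer
    n_units_p_out=50,    # private hidden units per hypothesis layer
    n_layers_r=3,        # representation layers
    n_units_s_r=100,     # shared hidden units per representation layer
    n_units_p_r=100,     # private hidden units per representation layer
    penalty_orthogonal=0.01,
    lr=1e-4,
    n_iter=10000,
    batch_size=100,
    val_split_prop=0.3,
    early_stopping=True,
    seed=42,
)
```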
- _ortho_penalty_asymmetric() → torch.Tensor
- fit(X: torch.Tensor, y: torch.Tensor, w: torch.Tensor) → catenets.models.torch.flextenet.FlexTENet
Fit treatment models.
- Parameters
X (torch.Tensor of shape (n_samples, n_features)) – The features to fit to
y (torch.Tensor of shape (n_samples,) or (n_samples, 1)) – The outcome variable
w (torch.Tensor of shape (n_samples,)) – The treatment indicator
- loss(y0_pred: torch.Tensor, y1_pred: torch.Tensor, y_true: torch.Tensor, t_true: torch.Tensor) → torch.Tensor
- predict(X: torch.Tensor, return_po: bool = False, training: bool = False) → torch.Tensor
Predict treatment effects and potential outcomes.
- Parameters
X (array-like of shape (n_samples, n_features)) – Test-sample features
- Returns
y – Predicted treatment effects
- Return type
array-like of shape (n_samples,)
- training: bool
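An end-to-end sketch on hypothetical toy data, assuming predict with return_po=True additionally returns the two potential outcomes:

```python
import torch
from catenets.models.torch.flextenet import FlexTENet

# Toy data: 500 samples, 25 covariates, binary treatment, constant effect of 2.
n, d = 500, 25
X = torch.randn(n, d)
w = torch.randint(0, 2, (n,)).float()            # treatment indicator
y = X[:, 0] + 2.0 * w + 0.1 * torch.randn(n)     # outcome

model = FlexTENet(n_unit_in=d, binary_y=False, n_iter=500)
model.fit(X, y, w)

cate = model.predict(X)                           # treatment effect estimates
cate, y0, y1 = model.predict(X, return_po=True)   # assumed (effect, PO_0, PO_1)
```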
- class FlexTEOutputLayer(n_units_in: int, n_units_in_p: int, private: bool, dropout: bool = False, dropout_prob: float = 0.5)
Bases: torch.nn.modules.module.Module
- forward(tensors: List[torch.Tensor]) → torch.Tensor
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- training: bool
- class FlexTESplitLayer(name: str, n_units_in: int, n_units_in_p: int, n_units_s: int, n_units_p: int, first_layer: bool, dropout: bool = False, dropout_prob: float = 0.5)
Bases: torch.nn.modules.module.Module
Creates a multitask layer whose output has the structure [shared, private_0, private_1].
- forward(tensors: List[torch.Tensor]) → List
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- training: bool
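A usage sketch for the first split layer, assuming that with first_layer=True the forward pass takes [X, t] and returns [X_s, X_p0, X_p1, t] (later layers would consume that four-element list instead):

```python
import torch
from catenets.models.torch.flextenet import FlexTESplitLayer

layer = FlexTESplitLayer(
    "split_layer_0",
    n_units_in=25,    # width of the incoming shared input
    n_units_in_p=25,  # width of the incoming private input
    n_units_s=100,    # shared hidden units produced by this layer
    n_units_p=100,    # private hidden units produced by this layer
    first_layer=True,
)

X = torch.randn(32, 25)
t = torch.randint(0, 2, (32, 1)).float()
out = layer([X, t])  # assumed output: [X_s, X_p0, X_p1, t]
```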