catenets.models.torch.flextenet module

class ElementWiseParallelActivation(act: Callable, **act_kwargs: Any)

Bases: torch.nn.modules.module.Module

Layer that applies a scalar function elementwise on its inputs.

Input looks like: X_s, X_p0, X_p1, t = inputs
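
A minimal usage sketch (assuming the layer applies the given activation to X_s, X_p0 and X_p1 elementwise and passes the treatment indicator t through unchanged; the data below is illustrative):

    import torch
    from catenets.models.torch.flextenet import ElementWiseParallelActivation

    # ELU applied to the shared and both private representations; t is left untouched.
    act = ElementWiseParallelActivation(torch.nn.functional.elu)

    X_s = torch.randn(8, 50)   # shared representation
    X_p0 = torch.randn(8, 50)  # private representation, control arm
    X_p1 = torch.randn(8, 50)  # private representation, treated arm
    t = torch.zeros(8, 1)      # treatment indicator

    out = act([X_s, X_p0, X_p1, t])  # forward(tensors) -> List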

forward(tensors: List[torch.Tensor]) → List

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

training: bool
class ElementWiseSplitActivation(act: Callable, **act_kwargs: Any)

Bases: torch.nn.modules.module.Module

Layer that applies a scalar function elementwise on its inputs.

Input looks like: X, t = inputs
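
A similar sketch for the split variant (assuming the activation is applied to X only and t is passed through; keyword arguments are forwarded to the callable via **act_kwargs):

    import torch
    from catenets.models.torch.flextenet import ElementWiseSplitActivation

    # Leaky ReLU on X; negative_slope is forwarded through **act_kwargs.
    act = ElementWiseSplitActivation(torch.nn.functional.leaky_relu, negative_slope=0.1)

    X = torch.randn(8, 25)
    t = torch.ones(8, 1)

    X_out, t_out = act([X, t])  # forward(tensors) -> List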

forward(tensors: List[torch.Tensor]) → List

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

training: bool
class FlexTELinearLayer(name: str, dropout: bool = False, dropout_prob: float = 0.5, *args: Any, **kwargs: Any)

Bases: torch.nn.modules.module.Module

Layer constructor for a fully-connected layer, adapted to pass the treatment indicator through the layer without using it.
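
A hedged sketch of instantiating such a layer (the assumption here is that the extra *args/**kwargs are forwarded to torch.nn.Linear; the layer name and sizes are hypothetical):

    import torch
    from catenets.models.torch.flextenet import FlexTELinearLayer

    # Hypothetical: name, dropout flag, dropout probability, then nn.Linear(in_features, out_features) args.
    layer = FlexTELinearLayer("shared_layer_0", False, 0.5, 25, 100)

    X = torch.randn(8, 25)
    t = torch.ones(8, 1)

    X_out, t_out = layer([X, t])  # the treatment indicator is passed through unused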

forward(tensors: List[torch.Tensor]) → List

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

training: bool
class FlexTENet(n_unit_in: int, binary_y: bool, n_layers_out: int = 2, n_units_s_out: int = 50, n_units_p_out: int = 50, n_layers_r: int = 3, n_units_s_r: int = 100, n_units_p_r: int = 100, private_out: bool = False, weight_decay: float = 0.0001, penalty_orthogonal: float = 0.01, lr: float = 0.0001, n_iter: int = 10000, batch_size: int = 100, val_split_prop: float = 0.3, early_stopping: bool = True, patience: int = 10, n_iter_min: int = 200, n_iter_print: int = 50, seed: int = 42, shared_repr: bool = False, normalize_ortho: bool = False, mode: int = 1, clipping_value: int = 1, dropout: bool = False, dropout_prob: float = 0.5)

Bases: catenets.models.torch.base.BaseCATEEstimator

Class implementing FlexTENet, an architecture for treatment effect estimation that allows for both shared and private information in each layer of the network. A usage sketch follows the parameter list below.

Parameters
  • n_unit_in (int) – Number of features

  • binary_y (bool) – Whether the outcome is binary

  • n_layers_out (int) – Number of hypothesis layers (n_layers_out x n_units_out + 1 x Linear layer)

  • n_units_s_out (int) – Number of hidden units in each shared hypothesis layer

  • n_units_p_out (int) – Number of hidden units in each private hypothesis layer

  • n_layers_r (int) – Number of representation layers before hypothesis layers (distinction between hypothesis layers and representation layers is made to match TARNet & SNets)

  • n_units_s_r (int) – Number of hidden units in each shared representation layer

  • n_units_p_r (int) – Number of hidden units in each private representation layer

  • private_out (bool, default False) – Whether the final prediction layer should be fully private, or retain a shared component.

  • weight_decay (float) – l2 (ridge) penalty

  • penalty_orthogonal (float) – orthogonalisation penalty

  • lr (float) – learning rate for optimizer

  • n_iter (int) – Maximum number of iterations

  • batch_size (int) – Batch size

  • val_split_prop (float) – Proportion of samples used for validation split (can be 0)

  • early_stopping (bool, default True) – Whether to use early stopping

  • patience (int) – Number of iterations to wait after the last decrease in validation loss before early stopping

  • n_iter_min (int) – Minimum number of iterations to go through before starting early stopping

  • n_iter_print (int) – Number of iterations after which to print updates

  • seed (int) – Seed used

  • opt (str, default 'adam') – Optimizer to use, accepts ‘adam’ and ‘sgd’

  • shared_repr (bool, default False) – Whether to use a shared representation block as in TARNet

  • lr_scale (float) – Scaling factor applied to the learning rate after unfreezing the private components of the network (only used if pretrain_shared=True)

  • normalize_ortho (bool, default False) – Whether to normalize the orthogonality penalty (by depth of network)

  • clipping_value (int, default 1) – Gradient clipping value
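
A minimal end-to-end sketch with simulated data (the data-generating process and parameter values are illustrative only, not recommendations):

    import torch
    from catenets.models.torch.flextenet import FlexTENet

    # Toy data: 500 samples, 10 covariates, binary treatment, continuous outcome.
    n, d = 500, 10
    X = torch.randn(n, d)
    w = torch.bernoulli(0.5 * torch.ones(n))
    y = X[:, 0] + 2.0 * w * X[:, 1] + 0.1 * torch.randn(n)

    model = FlexTENet(
        n_unit_in=d,
        binary_y=False,
        n_iter=1000,      # reduced from the default for a quick run
        batch_size=100,
    )
    model.fit(X, y, w)     # fit(X, y, w) -> FlexTENet
    te = model.predict(X)  # predicted treatment effects, shape (n_samples,)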

_ortho_penalty_asymmetric() → torch.Tensor
fit(X: torch.Tensor, y: torch.Tensor, w: torch.Tensor) → catenets.models.torch.flextenet.FlexTENet

Fit treatment models.

Parameters
  • X (torch.Tensor of shape (n_samples, n_features)) – The features to fit to

  • y (torch.Tensor of shape (n_samples,) or (n_samples, 1)) – The outcome variable

  • w (torch.Tensor of shape (n_samples,)) – The treatment indicator

loss(y0_pred: torch.Tensor, y1_pred: torch.Tensor, y_true: torch.Tensor, t_true: torch.Tensor) → torch.Tensor
predict(X: torch.Tensor, return_po: bool = False, training: bool = False) → torch.Tensor

Predict treatment effects and potential outcomes

Parameters
  • X (array-like of shape (n_samples, n_features)) – Test-sample features

  • return_po (bool, default False) – Whether to also return the predicted potential outcomes

Returns

y – Predicted treatment effects

Return type

array-like of shape (n_samples,)
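
Continuing the sketch above (X_test is a hypothetical test set; the output order when return_po=True is an assumption and should be checked against the implementation):

    te = model.predict(X_test)                                  # treatment effect estimates only
    te, y0_hat, y1_hat = model.predict(X_test, return_po=True)  # assumed order: effect, PO under control, PO under treatment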

training: bool
class FlexTEOutputLayer(n_units_in: int, n_units_in_p: int, private: bool, dropout: bool = False, dropout_prob: float = 0.5)

Bases: torch.nn.modules.module.Module

forward(tensors: List[torch.Tensor]) → torch.Tensor

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

training: bool
class FlexTESplitLayer(name: str, n_units_in: int, n_units_in_p: int, n_units_s: int, n_units_p: int, first_layer: bool, dropout: bool = False, dropout_prob: float = 0.5)

Bases: torch.nn.modules.module.Module

Creates a multitask layer whose output has the shape [shared, private_0, private_1].
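
A hedged sketch of wiring a first split layer together with the parallel activation above (the input/output conventions are inferred from the signatures and are therefore assumptions; the layer name and sizes are hypothetical):

    import torch
    from catenets.models.torch.flextenet import (
        ElementWiseParallelActivation,
        FlexTESplitLayer,
    )

    # Assumed behaviour: with first_layer=True the layer takes [X, t] and
    # returns [shared, private_0, private_1, t].
    split = FlexTESplitLayer(
        "split_layer_0",
        n_units_in=25,    # input width of the shared path
        n_units_in_p=25,  # input width of the private paths
        n_units_s=100,    # shared hidden units
        n_units_p=100,    # private hidden units
        first_layer=True,
    )
    act = ElementWiseParallelActivation(torch.nn.functional.elu)

    X = torch.randn(8, 25)
    t = torch.ones(8, 1)

    X_s, X_p0, X_p1, t_out = act(split([X, t]))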

forward(tensors: List[torch.Tensor]) → List

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

training: bool