catenets.models.torch.base module
- class BaseCATEEstimator
Bases: torch.nn.modules.module.Module
Interface for estimators of the conditional average treatment effect (CATE).
The interface provides a train/forward API for PyTorch-based models and a fit/predict API for sklearn-based models.
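The fit/predict contract can be illustrated without PyTorch. The sketch below is a minimal, torch-free stand-in under stated assumptions: `CATEInterface` and `DifferenceInMeans` are hypothetical names, not catenets classes, and the toy estimator predicts the same difference in group means for every row.

```python
from abc import ABC, abstractmethod
from typing import List, Sequence

Matrix = Sequence[Sequence[float]]


class CATEInterface(ABC):
    """Hypothetical stand-in for the fit/predict half of BaseCATEEstimator."""

    @abstractmethod
    def fit(self, X: Matrix, y: Sequence[float], w: Sequence[int]) -> "CATEInterface":
        """Learn from covariates X, outcomes y and binary treatments w."""

    @abstractmethod
    def predict(self, X: Matrix) -> List[float]:
        """Return one estimated treatment effect per row of X."""


class DifferenceInMeans(CATEInterface):
    """Toy estimator: the same effect, mean(y | w=1) - mean(y | w=0), for every row."""

    def fit(self, X: Matrix, y: Sequence[float], w: Sequence[int]) -> "DifferenceInMeans":
        treated = [yi for yi, wi in zip(y, w) if wi == 1]
        control = [yi for yi, wi in zip(y, w) if wi == 0]
        if not treated or not control:
            raise ValueError("need at least one treated and one control sample")
        self.effect_ = sum(treated) / len(treated) - sum(control) / len(control)
        return self

    def predict(self, X: Matrix) -> List[float]:
        return [self.effect_ for _ in X]


X = [[0.0], [1.0], [2.0], [3.0]]
tau_hat = DifferenceInMeans().fit(X, [1.0, 3.0, 2.0, 6.0], [0, 1, 0, 1]).predict(X)
print(tau_hat)  # [3.0, 3.0, 3.0, 3.0]
```

A real subclass would learn covariate-dependent effects; the constant predictor only shows the shape of the API.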
- _backward_hooks: Dict[int, Callable]
- _buffers: Dict[str, Optional[torch.Tensor]]
- _check_tensor(X: torch.Tensor) → torch.Tensor
- _forward_hooks: Dict[int, Callable]
- _forward_pre_hooks: Dict[int, Callable]
- _is_full_backward_hook: Optional[bool]
- _load_state_dict_post_hooks: Dict[int, Callable]
- _load_state_dict_pre_hooks: Dict[int, Callable]
- _modules: Dict[str, Optional[Module]]
- _non_persistent_buffers_set: Set[str]
- _parameters: Dict[str, Optional[torch.nn.parameter.Parameter]]
- _state_dict_hooks: Dict[int, Callable]
- abstract fit(X: torch.Tensor, y: torch.Tensor, w: torch.Tensor) → Any
- forward(**kwargs: Any) → Any
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
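- abstract predict(**kwargs: Any) → Any
- score(X: torch.Tensor, y: torch.Tensor) → float
Return the square root of the PEHE (precision in estimation of heterogeneous effects), an oracle metric that requires the true potential outcomes.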
- Parameters
X (torch.Tensor) – Covariate matrix
y (torch.Tensor) – Expected potential outcome vector
- training: bool
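The sqrt PEHE reported by score is the root mean squared error between true and estimated effects, sqrt((1/n) * sum_i (tau(x_i) - tau_hat(x_i))^2). The plain-Python sketch below computes it. `sqrt_pehe` is a hypothetical helper, and accepting a pair (mu_0, mu_1) of potential outcomes per row, whose difference is the true effect, is an assumption about how such a target might be supplied.

```python
import math
from typing import Sequence, Tuple, Union

Target = Union[float, Tuple[float, float]]


def sqrt_pehe(y_true: Sequence[Target], tau_hat: Sequence[float]) -> float:
    """Root mean squared error between true and estimated treatment effects.

    Assumption: each entry of y_true is either the true effect tau(x_i) or a
    pair (mu_0(x_i), mu_1(x_i)) of potential outcomes, whose difference is the effect.
    """
    if len(y_true) != len(tau_hat):
        raise ValueError("y_true and tau_hat must have the same length")
    if not tau_hat:
        raise ValueError("need at least one sample")
    tau_true = [
        row[1] - row[0] if isinstance(row, (list, tuple)) else row
        for row in y_true
    ]
    mse = sum((t - e) ** 2 for t, e in zip(tau_true, tau_hat)) / len(tau_true)
    return math.sqrt(mse)


# True effects are 1.0 and 3.0; predicting 2.0 for both gives errors of -1 and +1.
print(sqrt_pehe([(0.0, 1.0), (0.0, 3.0)], [2.0, 2.0]))  # 1.0
```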
- class BasicNet(name: str, n_unit_in: int, n_layers_out: int = 2, n_units_out: int = 100, binary_y: bool = False, nonlin: str = 'elu', lr: float = 0.0001, weight_decay: float = 0.0001, n_iter: int = 10000, batch_size: int = 100, n_iter_print: int = 50, seed: int = 42, val_split_prop: float = 0.3, patience: int = 10, n_iter_min: int = 200, clipping_value: int = 1, batch_norm: bool = True, early_stopping: bool = True, dropout: bool = False, dropout_prob: float = 0.2)
Bases: torch.nn.modules.module.Module
Basic hypothesis neural net.
- Parameters
name (str) – Display name
n_unit_in (int) – Number of features
n_layers_out (int) – Number of hypothesis layers (n_layers_out x n_units_out + 1 x Linear layer)
n_units_out (int) – Number of hidden units in each hypothesis layer
binary_y (bool, default False) – Whether the outcome is binary. Impacts the loss function.
nonlin (string, default 'elu') – Nonlinearity to use in NN. Can be ‘elu’, ‘relu’, ‘selu’ or ‘leaky_relu’.
lr (float) – Learning rate for the optimizer; equivalent to step_size in the JAX version.
weight_decay (float) – L2 (ridge) penalty on the weights.
n_iter (int) – Maximum number of iterations.
batch_size (int) – Batch size
n_iter_print (int) – Number of iterations after which to print updates and check the validation loss.
seed (int) – Random seed
val_split_prop (float) – Proportion of samples used for the validation split (can be 0)
patience (int) – Number of iterations without a decrease in validation loss to wait before stopping early
n_iter_min (int) – Minimum number of iterations to go through before early stopping can start
clipping_value (int, default 1) – Gradient clipping value
batch_norm (bool, default True) – Whether to use batch normalization
early_stopping (bool, default True) – Whether to use early stopping
dropout (bool, default False) – Whether to use dropout
dropout_prob (float, default 0.2) – Dropout probability
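The interplay of patience and n_iter_min can be sketched as a pure function over a sequence of validation losses. This is a hypothetical rule, not the library's training loop: it assumes one validation loss per iteration and does not model n_iter_print, whereas the real loop may only check the validation loss every n_iter_print iterations.

```python
from typing import Optional, Sequence


def early_stop_iteration(
    val_losses: Sequence[float], patience: int = 10, n_iter_min: int = 200
) -> Optional[int]:
    """Return the iteration at which training would stop, or None if it never stops early.

    Hypothetical rule: val_losses[i] is the validation loss after iteration i.
    Training stops once `patience` consecutive iterations bring no new best loss,
    but never before iteration `n_iter_min`.
    """
    best = float("inf")
    waited = 0  # iterations since the last new best validation loss
    for i, loss in enumerate(val_losses):
        if loss < best:
            best, waited = loss, 0
        else:
            waited += 1
        if waited >= patience and i >= n_iter_min:
            return i
    return None


losses = [5.0, 4.0, 3.0, 2.0, 1.0] + [1.0] * 10
print(early_stop_iteration(losses, patience=3, n_iter_min=0))   # 7
print(early_stop_iteration(losses, patience=3, n_iter_min=10))  # 10
```

Raising n_iter_min delays the stop even when the patience budget is already used up, which is the purpose of a minimum iteration count.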
- _backward_hooks: Dict[int, Callable]
- _buffers: Dict[str, Optional[torch.Tensor]]
- _check_tensor(X: torch.Tensor) → torch.Tensor
- _forward_hooks: Dict[int, Callable]
- _forward_pre_hooks: Dict[int, Callable]
- _is_full_backward_hook: Optional[bool]
- _load_state_dict_post_hooks: Dict[int, Callable]
- _load_state_dict_pre_hooks: Dict[int, Callable]
- _modules: Dict[str, Optional[Module]]
- _non_persistent_buffers_set: Set[str]
- _parameters: Dict[str, Optional[torch.nn.parameter.Parameter]]
- _state_dict_hooks: Dict[int, Callable]
- fit(X: torch.Tensor, y: torch.Tensor, weight: Optional[torch.Tensor] = None) → catenets.models.torch.base.BasicNet
- forward(X: torch.Tensor) → torch.Tensor
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- training: bool
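The Note about calling the Module instance rather than forward can be seen in a toy model. The sketch below is a hypothetical, torch-free imitation, not PyTorch's implementation: `TinyModule` and `Doubler` are illustrative names, the hook receives the raw input instead of PyTorch's tuple of positional arguments, and registration returns nothing instead of a removable handle.

```python
from typing import Any, Callable, List, Optional

Hook = Callable[["TinyModule", Any, Any], Optional[Any]]


class TinyModule:
    """Hypothetical, torch-free imitation of how calling a module wraps forward()."""

    def __init__(self) -> None:
        self._forward_hooks: List[Hook] = []

    def register_forward_hook(self, hook: Hook) -> None:
        self._forward_hooks.append(hook)

    def forward(self, x: Any) -> Any:
        raise NotImplementedError("subclasses define the forward pass")

    def __call__(self, x: Any) -> Any:
        out = self.forward(x)
        for hook in self._forward_hooks:
            result = hook(self, x, out)
            if result is not None:  # a hook may replace the output
                out = result
        return out


class Doubler(TinyModule):
    def forward(self, x: Any) -> Any:
        return 2 * x


m = Doubler()
m.register_forward_hook(lambda module, inp, out: out + 1)
print(m(3))          # 7: forward, then the hook
print(m.forward(3))  # 6: the hook is silently skipped
```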
- class PropensityNet(name: str, n_unit_in: int, n_unit_out: int, weighting_strategy: str, n_units_out_prop: int = 100, n_layers_out_prop: int = 0, nonlin: str = 'elu', lr: float = 0.0001, weight_decay: float = 0.0001, n_iter: int = 10000, batch_size: int = 100, n_iter_print: int = 50, seed: int = 42, val_split_prop: float = 0.3, patience: int = 10, n_iter_min: int = 200, clipping_value: int = 1, batch_norm: bool = True, early_stopping: bool = True, dropout: bool = False, dropout_prob: float = 0.2)
Bases: torch.nn.modules.module.Module
Basic propensity neural net.
- Parameters
name (str) – Display name
n_unit_in (int) – Number of features
n_unit_out (int) – Number of output features
weighting_strategy (str) – Weighting strategy
n_units_out_prop (int) – Number of hidden units in each propensity score hypothesis layer
n_layers_out_prop (int) – Number of hypothesis layers for the propensity score (n_layers_out_prop x n_units_out_prop + 1 x Linear layer)
nonlin (string, default 'elu') – Nonlinearity to use in NN. Can be ‘elu’, ‘relu’, ‘selu’ or ‘leaky_relu’.
lr (float) – Learning rate for the optimizer; equivalent to step_size in the JAX version.
weight_decay (float) – L2 (ridge) penalty on the weights.
n_iter (int) – Maximum number of iterations.
batch_size (int) – Batch size
n_iter_print (int) – Number of iterations after which to print updates and check the validation loss.
seed (int) – Random seed
val_split_prop (float) – Proportion of samples used for the validation split (can be 0)
patience (int) – Number of iterations without a decrease in validation loss to wait before stopping early
n_iter_min (int) – Minimum number of iterations to go through before early stopping can start
clipping_value (int, default 1) – Gradient clipping value
batch_norm (bool, default True) – Whether to use batch normalization
early_stopping (bool, default True) – Whether to use early stopping
dropout (bool, default False) – Whether to use dropout
dropout_prob (float, default 0.2) – Dropout probability
- _backward_hooks: Dict[int, Callable]
- _buffers: Dict[str, Optional[torch.Tensor]]
- _check_tensor(X: torch.Tensor) → torch.Tensor
- _forward_hooks: Dict[int, Callable]
- _forward_pre_hooks: Dict[int, Callable]
- _is_full_backward_hook: Optional[bool]
- _load_state_dict_post_hooks: Dict[int, Callable]
- _load_state_dict_pre_hooks: Dict[int, Callable]
- _modules: Dict[str, Optional[Module]]
- _non_persistent_buffers_set: Set[str]
- _parameters: Dict[str, Optional[torch.nn.parameter.Parameter]]
- _state_dict_hooks: Dict[int, Callable]
- fit(X: torch.Tensor, y: torch.Tensor) → catenets.models.torch.base.PropensityNet
- forward(X: torch.Tensor) → torch.Tensor
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- get_importance_weights(X: torch.Tensor, w: Optional[torch.Tensor] = None) → torch.Tensor
- loss(y_pred: torch.Tensor, y_target: torch.Tensor) → torch.Tensor
- training: bool
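Importance weights are derived from estimated propensities p_i = P(W = 1 | x_i). The sketch below shows two standard choices, inverse propensity weighting and overlap weighting, on precomputed propensities. The strategy names "ipw" and "overlap" and the helper `importance_weights` are assumptions for illustration; the values weighting_strategy actually accepts, and how get_importance_weights computes propensities from X, may differ.

```python
from typing import List, Sequence


def importance_weights(
    propensity: Sequence[float], w: Sequence[int], strategy: str = "ipw"
) -> List[float]:
    """Per-sample weights from propensities p_i = P(W=1 | x_i) (hypothetical helper).

    'ipw'     : 1/p for treated, 1/(1-p) for control (inverse propensity weighting)
    'overlap' : 1-p for treated, p for control (overlap weights)
    """
    weights = []
    for p, wi in zip(propensity, w):
        if not 0.0 < p < 1.0:
            raise ValueError("propensities must lie strictly between 0 and 1")
        if strategy == "ipw":
            weights.append(1.0 / p if wi == 1 else 1.0 / (1.0 - p))
        elif strategy == "overlap":
            weights.append(1.0 - p if wi == 1 else p)
        else:
            raise ValueError(f"unknown strategy: {strategy!r}")
    return weights


print(importance_weights([0.5, 0.75], [1, 0]))                      # [2.0, 4.0]
print(importance_weights([0.5, 0.75], [1, 0], strategy="overlap"))  # [0.5, 0.75]
```

IPW weights grow without bound as propensities approach 0 or 1, while overlap weights stay in [0, 1]; that trade-off is the usual reason to offer more than one strategy.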
- class RepresentationNet(n_unit_in: int, n_layers: int = 3, n_units: int = 200, nonlin: str = 'elu', batch_norm: bool = True)
Bases: torch.nn.modules.module.Module
Basic representation neural net.
- Parameters
n_unit_in (int) – Number of features
n_layers (int) – Number of shared representation layers before hypothesis layers
n_units (int) – Number of hidden units in each representation layer
nonlin (string, default 'elu') – Nonlinearity to use in NN. Can ‘elu’, ‘relu’, ‘selu’ or ‘leaky_relu’.
batch_norm (bool, default True) – Whether to use batch normalization
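The four accepted nonlin values name standard activations. The sketch below gives scalar plain-Python versions using PyTorch's default parameters (ELU alpha = 1.0, LeakyReLU negative_slope = 0.01, and the fixed SELU constants); in PyTorch these correspond to nn.ELU, nn.ReLU, nn.SELU and nn.LeakyReLU. The `NONLINEARITIES` table and `get_nonlin` are hypothetical, not catenets helpers.

```python
import math
from typing import Callable, Dict

# Fixed SELU constants (same values PyTorch uses).
SELU_ALPHA = 1.6732632423543772
SELU_SCALE = 1.0507009873554805


def elu(x: float, alpha: float = 1.0) -> float:
    return x if x > 0 else alpha * (math.exp(x) - 1.0)


def relu(x: float) -> float:
    return max(0.0, x)


def selu(x: float) -> float:
    return SELU_SCALE * (x if x > 0 else SELU_ALPHA * (math.exp(x) - 1.0))


def leaky_relu(x: float, negative_slope: float = 0.01) -> float:
    return x if x >= 0 else negative_slope * x


NONLINEARITIES: Dict[str, Callable[[float], float]] = {
    "elu": elu,
    "relu": relu,
    "selu": selu,
    "leaky_relu": leaky_relu,
}


def get_nonlin(name: str) -> Callable[[float], float]:
    """Look up an activation by name, rejecting anything outside the four options."""
    try:
        return NONLINEARITIES[name]
    except KeyError:
        raise ValueError(f"unknown nonlinearity: {name!r}") from None


print(get_nonlin("relu")(-2.0), get_nonlin("elu")(1.0))  # 0.0 1.0
```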
- _backward_hooks: Dict[int, Callable]
- _buffers: Dict[str, Optional[torch.Tensor]]
- _forward_hooks: Dict[int, Callable]
- _forward_pre_hooks: Dict[int, Callable]
- _is_full_backward_hook: Optional[bool]
- _load_state_dict_post_hooks: Dict[int, Callable]
- _load_state_dict_pre_hooks: Dict[int, Callable]
- _modules: Dict[str, Optional[Module]]
- _non_persistent_buffers_set: Set[str]
- _parameters: Dict[str, Optional[torch.nn.parameter.Parameter]]
- _state_dict_hooks: Dict[int, Callable]
- forward(X: torch.Tensor) → torch.Tensor
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- training: bool
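Because the representation layers sit before the hypothesis layers, the width of the representation fixes the input size of every downstream head. The sketch below lists the (in_features, out_features) pairs of the linear layers and counts their parameters, under the assumption that the first layer maps n_unit_in to n_units and each later layer keeps n_units. Under that assumption a BasicNet head on top would use n_unit_in = n_units. Both helpers are hypothetical and ignore batch-norm parameters.

```python
from typing import List, Sequence, Tuple

Shape = Tuple[int, int]


def representation_layer_shapes(
    n_unit_in: int, n_layers: int = 3, n_units: int = 200
) -> List[Shape]:
    """(in_features, out_features) of each linear layer, under the stated assumption."""
    if n_layers < 1:
        raise ValueError("n_layers must be at least 1")
    return [(n_unit_in, n_units)] + [(n_units, n_units)] * (n_layers - 1)


def linear_param_count(shapes: Sequence[Shape]) -> int:
    """Weights plus biases of the linear layers; batch-norm parameters are ignored."""
    return sum(i * o + o for i, o in shapes)


shapes = representation_layer_shapes(25)
print(shapes)                      # [(25, 200), (200, 200), (200, 200)]
print(linear_param_count(shapes))  # 85600
```

With the defaults, the two 200 x 200 layers account for most of the parameters, so n_units is the main lever on model size.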