catenets.models.torch.base module

class BaseCATEEstimator

Bases: torch.nn.modules.module.Module

Interface for estimators of CATE (conditional average treatment effects).

The interface exposes a train/forward API for PyTorch-based models and a fit/predict API for sklearn-style usage.
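
Example

A minimal sketch of a concrete subclass (the estimator below is a toy for illustration and is not part of catenets; it only implements the abstract fit/predict pair shown in this listing):

    import torch
    from torch import nn

    from catenets.models.torch.base import BaseCATEEstimator

    class ConstantCATEEstimator(BaseCATEEstimator):
        """Toy estimator predicting one constant treatment effect."""

        def __init__(self) -> None:
            super().__init__()
            self.tau = nn.Parameter(torch.zeros(1))

        def fit(self, X: torch.Tensor, y: torch.Tensor, w: torch.Tensor) -> "ConstantCATEEstimator":
            # Difference in outcome means between treated and control units.
            self.tau.data = (y[w == 1].mean() - y[w == 0].mean()).reshape(1)
            return self

        def predict(self, X: torch.Tensor) -> torch.Tensor:
            # The same effect for every sample.
            return self.tau.expand(len(X))

        def forward(self, X: torch.Tensor) -> torch.Tensor:
            return self.predict(X)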

_check_tensor(X: torch.Tensor) → torch.Tensor
abstract fit(X: torch.Tensor, y: torch.Tensor, w: torch.Tensor) → Any
forward(**kwargs: Any) → Any

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the forward pass must be defined within this function, the Module instance itself should be called instead of this method, since calling the instance runs the registered hooks while calling forward() directly silently ignores them.

abstract predict(**kwargs: Any) → Any
score(X: torch.Tensor, y: torch.Tensor) → float

Return the square root of the PEHE (Precision in Estimation of Heterogeneous Effect), an oracle metric that requires the expected potential outcomes; a sketch of the metric follows the parameter list below.

Parameters
  • X (torch.Tensor) – Covariate matrix

  • y (torch.Tensor) – Expected potential outcome vector
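
Example

The metric itself is straightforward; a minimal sketch, assuming tensors of true and predicted effects are at hand (the function below is illustrative, not the method's internals):

    import torch

    def sqrt_pehe(true_cate: torch.Tensor, pred_cate: torch.Tensor) -> float:
        # Root of the mean squared error between true and predicted effects.
        return torch.sqrt(torch.mean((true_cate - pred_cate) ** 2)).item()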

training: bool
class BasicNet(name: str, n_unit_in: int, n_layers_out: int = 2, n_units_out: int = 100, binary_y: bool = False, nonlin: str = 'elu', lr: float = 0.0001, weight_decay: float = 0.0001, n_iter: int = 10000, batch_size: int = 100, n_iter_print: int = 50, seed: int = 42, val_split_prop: float = 0.3, patience: int = 10, n_iter_min: int = 200, clipping_value: int = 1, batch_norm: bool = True, early_stopping: bool = True, dropout: bool = False, dropout_prob: float = 0.2)

Bases: torch.nn.modules.module.Module

Basic hypothesis neural net.

Parameters
  • name (str) – Display name

  • n_unit_in (int) – Number of features

  • n_layers_out (int) – Number of hypothesis layers (n_layers_out fully connected layers of n_units_out units each, plus one final linear output layer)

  • n_units_out (int) – Number of hidden units in each hypothesis layer

  • binary_y (bool, default False) – Whether the outcome is binary. Impacts the loss function.

  • nonlin (string, default 'elu') – Nonlinearity to use in NN. Can be ‘elu’, ‘relu’, ‘selu’ or ‘leaky_relu’.

  • lr (float) – Learning rate for the optimizer (equivalent to step_size in the JAX version).

  • weight_decay (float) – L2 (ridge) penalty for the weights.

  • n_iter (int) – Maximum number of iterations.

  • batch_size (int) – Batch size

  • n_iter_print (int) – Number of iterations after which to print updates and check the validation loss.

  • seed (int) – Random seed used

  • val_split_prop (float) – Proportion of samples used for validation split (can be 0)

  • patience (int) – Number of iterations without improvement in validation loss to wait before early stopping

  • n_iter_min (int) – Minimum number of iterations to go through before starting early stopping

  • clipping_value (int, default 1) – Gradient clipping value

  • batch_norm (bool, default True) – Whether to use batch normalization

  • early_stopping (bool, default True) – Whether to use early stopping based on the validation loss

  • dropout (bool, default False) – Whether to use dropout

  • dropout_prob (float, default 0.2) – Dropout probability

_check_tensor(X: torch.Tensor) → torch.Tensor
fit(X: torch.Tensor, y: torch.Tensor, weight: Optional[torch.Tensor] = None) → catenets.models.torch.base.BasicNet
forward(X: torch.Tensor) → torch.Tensor

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the forward pass must be defined within this function, the Module instance itself should be called instead of this method, since calling the instance runs the registered hooks while calling forward() directly silently ignores them.

training: bool
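
Example

A hedged usage sketch of BasicNet on synthetic data (the data and the reduced n_iter are illustrative; fit returns the trained net, and calling the module instance runs forward together with any registered hooks):

    import torch

    from catenets.models.torch.base import BasicNet

    X = torch.randn(500, 10)
    y = X[:, 0] + 0.5 * torch.randn(500)  # synthetic continuous outcome

    net = BasicNet("basic_net", n_unit_in=10, n_iter=500)
    net.fit(X, y)   # trains in place and returns the BasicNet
    y_hat = net(X)  # predictions for the covariate matrix
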
class PropensityNet(name: str, n_unit_in: int, n_unit_out: int, weighting_strategy: str, n_units_out_prop: int = 100, n_layers_out_prop: int = 0, nonlin: str = 'elu', lr: float = 0.0001, weight_decay: float = 0.0001, n_iter: int = 10000, batch_size: int = 100, n_iter_print: int = 50, seed: int = 42, val_split_prop: float = 0.3, patience: int = 10, n_iter_min: int = 200, clipping_value: int = 1, batch_norm: bool = True, early_stopping: bool = True, dropout: bool = False, dropout_prob: float = 0.2)

Bases: torch.nn.modules.module.Module

Basic propensity neural net.

Parameters
  • name (str) – Display name

  • n_unit_in (int) – Number of features

  • n_unit_out (int) – Number of output features

  • weighting_strategy (str) – Strategy used to derive importance weights from the estimated propensity scores (see get_importance_weights())

  • n_units_out_prop (int) – Number of hidden units in each propensity score hypothesis layer

  • n_layers_out_prop (int) – Number of hypothesis layers for the propensity score (n_layers_out_prop fully connected layers of n_units_out_prop units each, plus one final linear output layer)

  • nonlin (string, default 'elu') – Nonlinearity to use in NN. Can be ‘elu’, ‘relu’, ‘selu’ or ‘leaky_relu’.

  • lr (float) – Learning rate for the optimizer (equivalent to step_size in the JAX version).

  • weight_decay (float) – L2 (ridge) penalty for the weights.

  • n_iter (int) – Maximum number of iterations.

  • batch_size (int) – Batch size

  • n_iter_print (int) – Number of iterations after which to print updates and check the validation loss.

  • seed (int) – Random seed used

  • val_split_prop (float) – Proportion of samples used for validation split (can be 0)

  • patience (int) – Number of iterations without improvement in validation loss to wait before early stopping

  • n_iter_min (int) – Minimum number of iterations to go through before starting early stopping

  • clipping_value (int, default 1) – Gradient clipping value

  • batch_norm (bool, default True) – Whether to use batch normalization

  • early_stopping (bool, default True) – Whether to use early stopping based on the validation loss

  • dropout (bool, default False) – Whether to use dropout

  • dropout_prob (float, default 0.2) – Dropout probability

_check_tensor(X: torch.Tensor) → torch.Tensor
fit(X: torch.Tensor, y: torch.Tensor) → catenets.models.torch.base.PropensityNet
forward(X: torch.Tensor) → torch.Tensor

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the forward pass must be defined within this function, the Module instance itself should be called instead of this method, since calling the instance runs the registered hooks while calling forward() directly silently ignores them.

get_importance_weights(X: torch.Tensor, w: Optional[torch.Tensor] = None) → torch.Tensor
loss(y_pred: torch.Tensor, y_target: torch.Tensor) → torch.Tensor
training: bool
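
Example

A hedged sketch of the fit-then-weight workflow (the synthetic treatment assignment and the "ipw" strategy string are assumptions for illustration; consult the catenets source for the supported weighting_strategy values):

    import torch

    from catenets.models.torch.base import PropensityNet

    X = torch.randn(500, 10)
    w = torch.bernoulli(torch.sigmoid(X[:, 0]))  # synthetic treatment labels

    prop_net = PropensityNet(
        "propensity_net",
        n_unit_in=10,
        n_unit_out=2,              # two classes: control and treated
        weighting_strategy="ipw",  # assumed value, see note above
        n_iter=500,
    )
    prop_net.fit(X, w)  # trains against the treatment labels
    weights = prop_net.get_importance_weights(X, w)
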
class RepresentationNet(n_unit_in: int, n_layers: int = 3, n_units: int = 200, nonlin: str = 'elu', batch_norm: bool = True)

Bases: torch.nn.modules.module.Module

Basic representation neural net.

Parameters
  • n_unit_in (int) – Number of features

  • n_layers (int) – Number of shared representation layers before hypothesis layers

  • n_units (int) – Number of hidden units in each representation layer

  • nonlin (string, default 'elu') – Nonlinearity to use in NN. Can be ‘elu’, ‘relu’, ‘selu’ or ‘leaky_relu’.

  • batch_norm (bool, default True) – Whether to use batch normalization

forward(X: torch.Tensor) → torch.Tensor

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the forward pass must be defined within this function, the Module instance itself should be called instead of this method, since calling the instance runs the registered hooks while calling forward() directly silently ignores them.

training: bool
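
Example

RepresentationNet is typically composed with hypothesis heads; a hedged sketch of a TARNet-style layout with one BasicNet head per treatment arm (the wiring and sizes are illustrative, not how catenets' estimators assemble it internally):

    import torch

    from catenets.models.torch.base import BasicNet, RepresentationNet

    repr_net = RepresentationNet(n_unit_in=10, n_layers=2, n_units=64)
    head_0 = BasicNet("po_control", n_unit_in=64)  # outcome head, control arm
    head_1 = BasicNet("po_treated", n_unit_in=64)  # outcome head, treated arm

    X = torch.randn(8, 10)
    phi = repr_net(X)                 # shared representation, shape (8, 64)
    cate = head_1(phi) - head_0(phi)  # per-sample treatment effect estimate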