catenets.models.torch.snet module

class SNet(n_unit_in: int, binary_y: bool = False, n_layers_r: int = 3, n_units_r: int = 100, n_layers_out: int = 2, n_units_r_small: int = 50, n_units_out: int = 100, n_units_out_prop: int = 100, n_layers_out_prop: int = 2, weight_decay: float = 0.0001, penalty_orthogonal: float = 0.01, penalty_disc: float = 0, lr: float = 0.0001, n_iter: int = 10000, n_iter_min: int = 200, batch_size: int = 100, val_split_prop: float = 0.3, n_iter_print: int = 50, seed: int = 42, nonlin: str = 'elu', ortho_reg_type: str = 'abs', patience: int = 10, clipping_value: int = 1, batch_norm: bool = True, with_prop: bool = True, early_stopping: bool = True, prop_loss_multiplier: float = 1, dropout: bool = False, dropout_prob: float = 0.2)

Bases: catenets.models.torch.base.BaseCATEEstimator

Class implementing SNet as discussed in Curth & van der Schaar (2021). In addition to the version implemented in the AISTATS paper, this class also includes an implementation without propensity heads (set with_prop=False). A usage sketch follows the parameter list below.

Parameters
  • n_unit_in (int) – Number of features

  • binary_y (bool, default False) – Whether the outcome is binary

  • n_layers_r (int) – Number of shared & private representation layers before the hypothesis layers

  • n_units_r (int) – Number of hidden units in the shared representation before the hypothesis layers

  • n_layers_out (int) – Number of hypothesis layers (n_layers_out x n_units_out + 1 x Linear layer)

  • n_layers_out_prop (int) – Number of hypothesis layers for the propensity score (n_layers_out x n_units_out + 1 x Linear layer)

  • n_units_out (int) – Number of hidden units in each hypothesis layer

  • n_units_out_prop (int) – Number of hidden units in each propensity score hypothesis layer

  • n_units_r_small (int) – Number of hidden units in each potential-outcome (PO) function's private representation

  • weight_decay (float) – l2 (ridge) penalty

  • lr (float) – learning rate for optimizer

  • n_iter (int) – Maximum number of iterations

  • batch_size (int) – Batch size

  • val_split_prop (float) – Proportion of samples used for validation split (can be 0)

  • patience (int) – Number of iterations without improvement in validation loss to wait before early stopping

  • n_iter_min (int) – Minimum number of iterations to go through before starting early stopping

  • n_iter_print (int) – Number of iterations after which to print updates

  • seed (int) – Seed used

  • nonlin (string, default 'elu') – Nonlinearity to use in the neural net. Can be ‘elu’, ‘relu’, ‘selu’ or ‘leaky_relu’.

  • penalty_disc (float, default zero) – Discrepancy penalty. Defaults to zero as this feature is not tested.

  • clipping_value (int, default 1) – Gradient clipping value
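
A minimal construction sketch (illustrative, not taken from the library's documentation; the argument values are arbitrary and the import path is inferred from this module's location):

    from catenets.models.torch.snet import SNet

    model = SNet(
        n_unit_in=25,             # number of input features
        binary_y=False,           # continuous outcome
        n_layers_r=2,             # shared & private representation layers
        n_units_r=64,             # hidden units in the shared representation
        n_units_out=64,           # hidden units per hypothesis layer
        penalty_orthogonal=0.01,  # orthogonal regularization strength (see _ortho_reg)
        n_iter=1000,              # maximum training iterations
        batch_size=128,
        seed=42,
    )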

_backward_hooks: Dict[int, Callable]
_buffers: Dict[str, Optional[torch.Tensor]]
_forward(X: torch.Tensor) → Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
_forward_hooks: Dict[int, Callable]
_forward_pre_hooks: Dict[int, Callable]
_is_full_backward_hook: Optional[bool]
_load_state_dict_post_hooks: Dict[int, Callable]
_load_state_dict_pre_hooks: Dict[int, Callable]
_maximum_mean_discrepancy(X: torch.Tensor, w: torch.Tensor) → torch.Tensor
_modules: Dict[str, Optional[Module]]
_non_persistent_buffers_set: Set[str]
_ortho_reg() → float
_parameters: Dict[str, Optional[torch.nn.parameter.Parameter]]
_state_dict_hooks: Dict[int, Callable]
_step(X: torch.Tensor, w: torch.Tensor) → Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
fit(X: torch.Tensor, y: torch.Tensor, w: torch.Tensor) → catenets.models.torch.snet.SNet

Fit treatment models.

Parameters
  • X (torch.Tensor of shape (n_samples, n_features)) – The features to fit to

  • y (torch.Tensor of shape (n_samples,)) – The outcome variable

  • w (torch.Tensor of shape (n_samples,)) – The treatment indicator
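
A minimal fitting sketch on synthetic data; the toy data-generating process and all numeric choices below are illustrative assumptions, not part of catenets:

    import torch

    from catenets.models.torch.snet import SNet

    n, d = 500, 25
    X = torch.randn(n, d)                            # covariates
    w = torch.bernoulli(torch.full((n,), 0.5))       # binary treatment indicator
    y = X[:, 0] + 2.0 * w + 0.1 * torch.randn(n)     # outcome with a constant treatment effect of 2

    model = SNet(n_unit_in=d, n_iter=500, val_split_prop=0.3)
    model.fit(X, y, w)                               # returns the fitted SNet instance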

loss(y0_pred: torch.Tensor, y1_pred: torch.Tensor, t_pred: torch.Tensor, discrepancy: torch.Tensor, y_true: torch.Tensor, t_true: torch.Tensor) → torch.Tensor
predict(X: torch.Tensor, return_po: bool = False, training: bool = False) → torch.Tensor

Predict treatment effects and potential outcomes

Parameters

X (array-like of shape (n_samples, n_features)) – Test-sample features

Returns

y – predicted treatment effects

Return type

array-like of shape (n_samples,)
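
Continuing the fit sketch above, a prediction sketch; the exact structure returned when return_po=True (the effect estimate together with the two potential-outcome predictions) is an assumption based on the method description, not a documented guarantee:

    X_test = torch.randn(10, d)

    cate = model.predict(X_test)   # shape (n_samples,): estimated treatment effects

    # Assumed: with return_po=True the call also yields the potential-outcome predictions.
    cate, y0_hat, y1_hat = model.predict(X_test, return_po=True)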

training: bool