catenets.models.torch.representation_nets module
- class BasicDragonNet(name: str, n_unit_in: int, propensity_estimator: torch.nn.modules.module.Module, binary_y: bool = False, n_layers_r: int = 3, n_units_r: int = 200, n_layers_out: int = 2, n_units_out: int = 100, weight_decay: float = 0.0001, lr: float = 0.0001, n_iter: int = 10000, batch_size: int = 100, val_split_prop: float = 0.3, n_iter_print: int = 50, seed: int = 42, nonlin: str = 'elu', weighting_strategy: Optional[str] = None, penalty_disc: float = 0, batch_norm: bool = True, early_stopping: bool = True, prop_loss_multiplier: float = 1, n_iter_min: int = 200, patience: int = 10, dropout: bool = False, dropout_prob: float = 0.2)
Bases:
catenets.models.torch.base.BaseCATEEstimator
Base class for TARNet and DragonNet.
- Parameters
name (str) – Estimator name
n_unit_in (int) – Number of features
propensity_estimator (nn.Module) – Propensity estimator
binary_y (bool, default False) – Whether the outcome is binary
n_layers_out (int) – Number of hypothesis layers (n_layers_out x n_units_out + 1 x Dense layer)
n_units_out (int) – Number of hidden units in each hypothesis layer
n_layers_r (int) – Number of shared & private representation layers before the hypothesis layers.
n_units_r (int) – Number of hidden units in representation before the hypothesis layers.
weight_decay (float) – L2 (ridge) penalty
lr (float) – Learning rate for the optimizer
n_iter (int) – Maximum number of iterations
batch_size (int) – Batch size
val_split_prop (float) – Proportion of samples used for validation split (can be 0)
n_iter_print (int) – Number of iterations after which to print updates
seed (int) – Random seed
nonlin (str, default 'elu') – Nonlinearity to use in the neural net. One of 'elu', 'relu', 'selu', 'leaky_relu'.
weighting_strategy (optional str, default None) – Whether to include a propensity head and which weighting strategy to use
penalty_disc (float, default 0) – Weight of the discrepancy (maximum mean discrepancy) penalty.
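The layer parameters above can be read as a shared representation followed by one hypothesis head per potential outcome. The sketch below, in plain Python, lists the (in_features, out_features) of each dense layer under that reading. The helper dragonnet_layer_shapes is hypothetical and assumes every representation layer has n_units_r units and each head ends in a single-unit dense layer; it ignores batch norm, dropout, and the propensity head.

```python
from typing import List, Tuple

Shape = Tuple[int, int]  # (in_features, out_features) of one dense layer


def dragonnet_layer_shapes(
    n_unit_in: int,
    n_layers_r: int = 3,
    n_units_r: int = 200,
    n_layers_out: int = 2,
    n_units_out: int = 100,
) -> Tuple[List[Shape], List[Shape]]:
    """Hypothetical helper: dense-layer shapes of the representation and ONE outcome head.

    Assumes each representation layer has n_units_r units and each hypothesis
    head stacks n_layers_out hidden layers of n_units_out units, then one
    dense output layer with a single unit.
    """
    representation: List[Shape] = []
    d = n_unit_in
    for _ in range(n_layers_r):
        representation.append((d, n_units_r))
        d = n_units_r

    head: List[Shape] = []
    for _ in range(n_layers_out):
        head.append((d, n_units_out))
        d = n_units_out
    head.append((d, 1))  # the "+ 1 x Dense layer" producing one potential outcome
    return representation, head


rep, head = dragonnet_layer_shapes(n_unit_in=25)
print(rep)   # [(25, 200), (200, 200), (200, 200)]
print(head)  # [(200, 100), (100, 100), (100, 1)]
```

With the defaults, both outcome heads would share the same three representation layers and differ only in their own weights.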
- _backward_hooks: Dict[int, Callable]
- _buffers: Dict[str, Optional[torch.Tensor]]
- _forward(X: torch.Tensor) torch.Tensor
- _forward_hooks: Dict[int, Callable]
- _forward_pre_hooks: Dict[int, Callable]
- _is_full_backward_hook: Optional[bool]
- _load_state_dict_post_hooks: Dict[int, Callable]
- _load_state_dict_pre_hooks: Dict[int, Callable]
- _maximum_mean_discrepancy(X: torch.Tensor, w: torch.Tensor) torch.Tensor
- _modules: Dict[str, Optional[Module]]
- _non_persistent_buffers_set: Set[str]
- _parameters: Dict[str, Optional[torch.nn.parameter.Parameter]]
- _state_dict_hooks: Dict[int, Callable]
- abstract _step(X: torch.Tensor, w: torch.Tensor) Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
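_maximum_mean_discrepancy(X, w) measures how far apart the treated and control representations are, and penalty_disc presumably scales that term in the loss. A minimal sketch, assuming a linear-kernel MMD, i.e. the squared Euclidean distance between the treated and control mean representations. The library may standardize features or use a different kernel, and linear_mmd is a hypothetical name.

```python
from typing import Sequence


def linear_mmd(X: Sequence[Sequence[float]], w: Sequence[int]) -> float:
    """Hypothetical linear-kernel MMD: squared distance between group means.

    X holds one representation vector per sample; w holds 0/1 treatment flags.
    """
    treated = [x for x, t in zip(X, w) if t == 1]
    control = [x for x, t in zip(X, w) if t == 0]
    if not treated or not control:
        raise ValueError("need at least one treated and one control sample")
    dim = len(X[0])
    mean_t = [sum(row[j] for row in treated) / len(treated) for j in range(dim)]
    mean_c = [sum(row[j] for row in control) / len(control) for j in range(dim)]
    return sum((a - b) ** 2 for a, b in zip(mean_t, mean_c))


phi = [[0.0, 0.0], [1.0, 1.0], [2.0, 2.0], [3.0, 3.0]]
print(linear_mmd(phi, [0, 0, 1, 1]))  # 8.0
```

A penalty of this kind pushes the representation toward balance between groups; it is zero when both groups have the same mean representation.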
- fit(X: torch.Tensor, y: torch.Tensor, w: torch.Tensor) catenets.models.torch.representation_nets.BasicDragonNet
Fit the treatment models.
- Parameters
X (torch.Tensor of shape (n_samples, n_features)) – The features to fit to
y (torch.Tensor of shape (n_samples,) or (n_samples, 1)) – The outcome variable
w (torch.Tensor of shape (n_samples,)) – The treatment indicator
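fit holds out val_split_prop of the samples for validation, and the early_stopping, n_iter_min and patience parameters in the class signature suggest a patience-based stopping rule. The sketch below shows one common form of such a rule. The assumption that the validation loss is checked every n_iter_print iterations, and the helper early_stop_iteration itself, are hypothetical; the library's exact bookkeeping may differ.

```python
from typing import Sequence


def early_stop_iteration(
    val_losses: Sequence[float],
    n_iter_print: int = 50,
    n_iter_min: int = 200,
    patience: int = 10,
) -> int:
    """Hypothetical patience rule: return the iteration at which training stops.

    val_losses[k] is assumed to be the validation loss measured at iteration
    k * n_iter_print. Training stops once the loss has failed to improve for
    `patience` consecutive checks, but never before n_iter_min. If the rule
    never fires, the last checked iteration is returned.
    """
    best = float("inf")
    checks_without_improvement = 0
    iteration = 0
    for k, loss in enumerate(val_losses):
        iteration = k * n_iter_print
        if loss < best:
            best = loss
            checks_without_improvement = 0
        else:
            checks_without_improvement += 1
        if checks_without_improvement >= patience and iteration >= n_iter_min:
            return iteration
    return iteration


print(early_stop_iteration([5.0, 4.0, 3.0] + [3.0] * 20))  # 600
```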
- loss(po_pred: torch.Tensor, t_pred: torch.Tensor, y_true: torch.Tensor, t_true: torch.Tensor, discrepancy: torch.Tensor) torch.Tensor
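loss receives potential-outcome predictions, propensity predictions, the observed outcome and treatment, and a discrepancy term. A hedged sketch of how such terms are commonly combined: squared error (or binary cross-entropy when binary_y is set) on the factual outcome only, plus prop_loss_multiplier times the propensity cross-entropy, plus penalty_disc times the discrepancy. The function total_loss is hypothetical, works on plain Python lists rather than tensors, and is not the library's implementation; whether the propensity term is included may depend on the estimator and on weighting_strategy.

```python
import math
from typing import Sequence, Tuple


def _bce(p: float, y: float, eps: float = 1e-7) -> float:
    # binary cross-entropy for one probability, clipped for numerical safety
    p = min(max(p, eps), 1 - eps)
    return -(y * math.log(p) + (1 - y) * math.log(1 - p))


def total_loss(
    po_pred: Sequence[Tuple[float, float]],  # (mu0, mu1) per sample
    t_pred: Sequence[float],                  # predicted propensities
    y_true: Sequence[float],
    t_true: Sequence[int],
    discrepancy: float,
    binary_y: bool = False,
    prop_loss_multiplier: float = 1.0,
    penalty_disc: float = 0.0,
) -> float:
    """Hypothetical composition of factual, propensity and discrepancy terms."""
    n = len(y_true)
    outcome = 0.0
    for (mu0, mu1), y, t in zip(po_pred, y_true, t_true):
        y_hat = mu1 if t == 1 else mu0  # only the factual outcome is observed
        outcome += _bce(y_hat, y) if binary_y else (y_hat - y) ** 2
    propensity = sum(_bce(e, t) for e, t in zip(t_pred, t_true))
    return outcome / n + prop_loss_multiplier * propensity / n + penalty_disc * discrepancy
```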
- predict(X: torch.Tensor, return_po: bool = False, training: bool = False) torch.Tensor
Predict the treatment effects.
- Parameters
X (torch.Tensor of shape (n_samples, n_features)) – Test-sample features
- Returns
y
- Return type
array-like of shape (n_samples,)
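A treatment-effect (CATE) estimate is the difference between the two predicted potential outcomes, tau(x) = mu1(x) - mu0(x). The sketch assumes per-sample predictions mu0 and mu1 are already available; cate_from_potential_outcomes is hypothetical, and the (cate, mu0, mu1) ordering returned when return_po=True is an assumption for illustration, not the documented return value.

```python
from typing import List, Sequence, Tuple, Union


def cate_from_potential_outcomes(
    mu0: Sequence[float], mu1: Sequence[float], return_po: bool = False
) -> Union[List[float], Tuple[List[float], List[float], List[float]]]:
    """Hypothetical helper: tau(x) = mu1(x) - mu0(x) for each test sample."""
    cate = [m1 - m0 for m0, m1 in zip(mu0, mu1)]
    if return_po:
        # assumed ordering, for illustration only
        return cate, list(mu0), list(mu1)
    return cate


print(cate_from_potential_outcomes([1.0, 2.0], [1.5, 1.0]))  # [0.5, -1.0]
```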
- training: bool
- class DragonNet(n_unit_in: int, binary_y: bool = False, n_units_out_prop: int = 100, n_layers_out_prop: int = 0, nonlin: str = 'elu', n_units_r: int = 200, batch_norm: bool = True, dropout: bool = False, dropout_prob: float = 0.2, **kwargs: Any)
Bases:
catenets.models.torch.representation_nets.BasicDragonNet
Implements a variant of DragonNet (Shi et al., 2019).
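In Shi et al. (2019), the propensity head reads the shared representation rather than the raw features, so predicting treatment also shapes the representation. The sketch computes a propensity from a representation vector z with a single linear layer and a sigmoid, i.e. a head with no hidden layers, as the n_layers_out_prop=0 default suggests. propensity_head and its arguments are hypothetical.

```python
import math
from typing import Sequence


def propensity_head(z: Sequence[float], weights: Sequence[float], bias: float) -> float:
    """Hypothetical propensity head: e(x) = sigmoid(v . z + c) on the representation z."""
    logit = sum(v * zi for v, zi in zip(weights, z)) + bias
    return 1.0 / (1.0 + math.exp(-logit))


print(propensity_head([1.0, 2.0], [0.0, 0.0], 0.0))  # 0.5
```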
- _backward_hooks: Dict[int, Callable]
- _buffers: Dict[str, Optional[torch.Tensor]]
- _forward_hooks: Dict[int, Callable]
- _forward_pre_hooks: Dict[int, Callable]
- _is_full_backward_hook: Optional[bool]
- _load_state_dict_post_hooks: Dict[int, Callable]
- _load_state_dict_pre_hooks: Dict[int, Callable]
- _modules: Dict[str, Optional[Module]]
- _non_persistent_buffers_set: Set[str]
- _parameters: Dict[str, Optional[torch.nn.parameter.Parameter]]
- _state_dict_hooks: Dict[int, Callable]
- _step(X: torch.Tensor, w: torch.Tensor) Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
- training: bool
- class TARNet(n_unit_in: int, binary_y: bool = False, n_units_out_prop: int = 100, n_layers_out_prop: int = 0, nonlin: str = 'elu', penalty_disc: float = 0, batch_norm: bool = True, dropout: bool = False, dropout_prob: float = 0.2, **kwargs: Any)
Bases:
catenets.models.torch.representation_nets.BasicDragonNet
Implements TARNet (Shalit et al., 2017).
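Shalit et al. (2017) reweight the factual loss by w_i = t_i/(2u) + (1 - t_i)/(2(1 - u)), where u is the fraction of treated samples, so the treated and control groups each carry half of the total weight. The sketch computes these weights; whether this TARNet applies them (for example through weighting_strategy) is not stated here, and cfr_sample_weights is a hypothetical helper.

```python
from typing import List, Sequence


def cfr_sample_weights(w: Sequence[int]) -> List[float]:
    """Hypothetical helper: w_i = t_i/(2u) + (1 - t_i)/(2(1 - u)), u = mean(t).

    Each group receives half of the total weight, so the weights sum to n.
    """
    n = len(w)
    u = sum(w) / n
    if u in (0.0, 1.0):
        raise ValueError("need both treated and control samples")
    return [t / (2 * u) + (1 - t) / (2 * (1 - u)) for t in w]


weights = cfr_sample_weights([1, 0, 0, 0])
print(weights)
```

With one treated and three control samples, the treated sample gets weight 2 and each control sample 2/3, so both groups sum to 2.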
- _backward_hooks: Dict[int, Callable]
- _buffers: Dict[str, Optional[torch.Tensor]]
- _forward_hooks: Dict[int, Callable]
- _forward_pre_hooks: Dict[int, Callable]
- _is_full_backward_hook: Optional[bool]
- _load_state_dict_post_hooks: Dict[int, Callable]
- _load_state_dict_pre_hooks: Dict[int, Callable]
- _modules: Dict[str, Optional[Module]]
- _non_persistent_buffers_set: Set[str]
- _parameters: Dict[str, Optional[torch.nn.parameter.Parameter]]
- _state_dict_hooks: Dict[int, Callable]
- _step(X: torch.Tensor, w: torch.Tensor) Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
- training: bool