catenets.models.jax.base module
Base modules shared across different nets
- class BaseCATENet
Bases: sklearn.base.BaseEstimator, sklearn.base.RegressorMixin, abc.ABC
Base CATENet class to serve as template for all other nets
- _abc_impl = <_abc_data object>
- static _check_inputs(w: jax._src.basearray.Array, p: jax._src.basearray.Array) -> None
- abstract _get_predict_function() -> Callable
- abstract _get_train_function() -> Callable
- fit(X: jax._src.basearray.Array, y: jax._src.basearray.Array, w: jax._src.basearray.Array, p: Optional[jax._src.basearray.Array] = None) -> catenets.models.jax.base.BaseCATENet
Fit method for a CATENet. Takes covariates, an outcome variable, and a treatment indicator as input.
- Parameters
X (pd.DataFrame or np.array) – Covariate matrix
y (np.array) – Outcome vector
w (np.array) – Treatment indicator
p (np.array) – Vector of (known) treatment propensities. Currently only supported for TwoStepNets.
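The page does not say what `_check_inputs` validates. The parameter descriptions above suggest two natural constraints: a binary treatment indicator, and propensities that are valid probabilities. The following is a minimal plain-Python sketch under that assumption. The function name `check_inputs`, the length check, and the error messages are hypothetical and are not taken from the library.

```python
from typing import Optional, Sequence


def check_inputs(w: Sequence[float], p: Optional[Sequence[float]] = None) -> None:
    """Hypothetical validator: w must be binary; p, if given, must lie in [0, 1]."""
    # Treatment indicator: every entry is either 0 (control) or 1 (treated).
    if any(w_i not in (0, 1) for w_i in w):
        raise ValueError("w should be binary (0 or 1)")

    # Propensities are optional; when supplied they must be probabilities.
    if p is not None:
        if len(p) != len(w):
            raise ValueError("p and w should have the same length")
        if any(p_i < 0 or p_i > 1 for p_i in p):
            raise ValueError("p should lie in [0, 1]")


# Valid inputs pass silently; invalid ones raise ValueError.
check_inputs([0, 1, 1, 0], [0.5, 0.5, 0.2, 0.9])
```

Validating before training keeps malformed treatment or propensity vectors from surfacing later as hard-to-trace numerical errors.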
- fit_and_select_params(X: jax._src.basearray.Array, y: jax._src.basearray.Array, w: jax._src.basearray.Array, p: Optional[jax._src.basearray.Array] = None, param_grid: dict = {}) -> catenets.models.jax.base.BaseCATENet
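`fit_and_select_params` is undocumented on this page. The `param_grid` argument suggests a grid search over hyperparameter combinations. The sketch below shows one way such a search could work, assuming `param_grid` maps parameter names to lists of candidate values and that the combination with the lowest loss wins. The helpers `iter_param_grid`, `select_params`, and `toy_loss`, the selection criterion, and the parameter names in the example grid are all illustrative assumptions, not the library's implementation.

```python
from itertools import product
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple


def iter_param_grid(param_grid: Dict[str, List[Any]]) -> Iterator[Dict[str, Any]]:
    """Yield one dict per combination of the candidate values."""
    names = sorted(param_grid)
    for values in product(*(param_grid[name] for name in names)):
        yield dict(zip(names, values))


def select_params(
    param_grid: Dict[str, List[Any]],
    evaluate: Callable[[Dict[str, Any]], float],
) -> Tuple[Optional[Dict[str, Any]], float]:
    """Return the combination with the lowest loss (assumed criterion)."""
    best_params, best_loss = None, float("inf")
    for params in iter_param_grid(param_grid):
        loss = evaluate(params)
        if loss < best_loss:
            best_params, best_loss = params, loss
    return best_params, best_loss


def toy_loss(params: Dict[str, Any]) -> float:
    """Toy stand-in for 'fit with these settings and report a validation loss'."""
    return abs(params["n_units"] - 100) + params["penalty_l2"]


grid = {"n_units": [50, 100, 200], "penalty_l2": [1e-4, 1e-2]}
best, loss = select_params(grid, toy_loss)
print(best, loss)
```

In a real search, `evaluate` would fit the estimator with the given settings and return a held-out loss. The toy function only keeps the sketch self-contained.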
- predict(X: jax._src.basearray.Array, return_po: bool = False, return_prop: bool = False) -> jax._src.basearray.Array
Predict treatment effect estimates using a CATENet. Depending on the method, it can also return potential outcome estimates and propensity score estimates.
- Parameters
X (pd.DataFrame or np.array) – Covariate matrix
return_po (bool, default False) – Whether to return potential outcome estimate
return_prop (bool, default False) – Whether to return propensity estimate
- Returns
Array of CATE estimates; optionally also potential outcome and propensity estimates.
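- score(X: jax._src.basearray.Array, y: jax._src.basearray.Array, sample_weight: Optional[jax._src.basearray.Array] = None) -> float
Return the root PEHE, the square root of the Precision in Estimation of Heterogeneous Effects. This is an oracle metric: computing it requires the true expected potential outcomes, which are usually available only for simulated data.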
- Parameters
X (pd.DataFrame or np.array) – Covariate matrix
y (np.array) – Expected potential outcome vector
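For reference, the root PEHE is the root mean squared error between the estimated and the true treatment effects. The sketch below computes it in plain Python, assuming the true effect for each unit is the difference between its two expected potential outcomes. The function name `root_pehe` and the separate `mu0`/`mu1` arguments are illustrative; the page does not specify how `y` is laid out.

```python
import math
from typing import Sequence


def root_pehe(
    cate_hat: Sequence[float],
    mu0: Sequence[float],
    mu1: Sequence[float],
) -> float:
    """Root PEHE: sqrt of the mean squared error between estimated and true CATE."""
    if not (len(cate_hat) == len(mu0) == len(mu1)):
        raise ValueError("all inputs must have the same length")
    if len(cate_hat) == 0:
        raise ValueError("inputs must not be empty")

    # True effect per unit: expected outcome under treatment minus under control.
    squared_errors = [
        (t_hat - (m1 - m0)) ** 2 for t_hat, m0, m1 in zip(cate_hat, mu0, mu1)
    ]
    return math.sqrt(sum(squared_errors) / len(squared_errors))


# Two units: the first estimate is exact, the second is off by 2.
print(root_pehe([1.0, 2.0], mu0=[0.0, 0.0], mu1=[1.0, 4.0]))
```

Lower values indicate better effect estimates, with 0 meaning every unit's effect is recovered exactly.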
- OutputHead(n_layers_out: int = 2, n_units_out: int = 100, binary_y: bool = False, n_layers_r: int = 0, n_units_r: int = 200, nonlin: str = 'elu') -> Any
- ReprBlock(n_layers: int = 3, n_units: int = 100, nonlin: str = 'elu') -> Any
- train_output_net_only(X: jax._src.basearray.Array, y: jax._src.basearray.Array, binary_y: bool = False, n_layers_out: int = 2, n_units_out: int = 100, n_layers_r: int = 0, n_units_r: int = 200, penalty_l2: float = 0.0001, step_size: float = 0.0001, n_iter: int = 10000, batch_size: int = 100, val_split_prop: float = 0.3, early_stopping: bool = True, patience: int = 10, n_iter_min: int = 200, n_iter_print: int = 50, seed: int = 42, return_val_loss: bool = False, nonlin: str = 'elu', avg_objective: bool = False) -> Any
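The `early_stopping`, `patience`, and `n_iter_min` arguments are not described on this page. Their names suggest patience-based early stopping on a validation loss. The sketch below illustrates that general technique on a precomputed list of validation losses. The function `early_stopping_index` and its exact stopping rule are assumptions for illustration; the library's own logic may differ, for example in how often the loss is checked.

```python
from typing import Sequence


def early_stopping_index(
    val_losses: Sequence[float],
    patience: int = 10,
    n_iter_min: int = 200,
) -> int:
    """Return the index of the validation check at which training would stop.

    Assumed rule: stop once at least `n_iter_min` checks have run and the loss
    has failed to improve for `patience` consecutive checks.
    """
    if len(val_losses) == 0:
        raise ValueError("val_losses must not be empty")

    best_loss = float("inf")
    checks_without_improvement = 0
    for i, loss in enumerate(val_losses):
        if loss < best_loss:
            best_loss = loss
            checks_without_improvement = 0
        else:
            checks_without_improvement += 1
        if i + 1 >= n_iter_min and checks_without_improvement >= patience:
            return i
    # Never triggered: training runs through the final check.
    return len(val_losses) - 1


losses = [1.0, 0.8, 0.7, 0.75, 0.72, 0.71, 0.9]
print(early_stopping_index(losses, patience=2, n_iter_min=3))
```

A minimum iteration count guards against stopping on early noise, and the patience counter tolerates short plateaus before giving up.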