catenets.experiment_utils.base module

Utilities for experiments: evaluation metrics and helpers for retrieving sets of models.

eval_abs_error_ate(cate_pred: jax.Array, cate_true: jax.Array) → jax.Array
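The name suggests an absolute error in the average treatment effect (ATE) implied by predicted and true CATE vectors. A minimal sketch, assuming the ATE is the sample mean of each CATE vector and the error is the absolute difference of the two means; this illustrates the idea and is not the library's source.

```python
import jax
import jax.numpy as jnp


def eval_abs_error_ate(cate_pred: jax.Array, cate_true: jax.Array) -> jax.Array:
    # Assumption: each ATE is estimated as the sample mean of its CATE vector.
    ate_pred = jnp.mean(cate_pred)
    ate_true = jnp.mean(cate_true)
    # Absolute difference between the two ATE estimates (a 0-d array).
    return jnp.abs(ate_pred - ate_true)


cate_pred = jnp.array([1.0, 2.0, 3.0])  # mean 2.0
cate_true = jnp.array([0.5, 1.5, 2.5])  # mean 1.5
print(eval_abs_error_ate(cate_pred, cate_true))  # ≈ 0.5
```

Because the vectors are averaged before the absolute value is taken, individual over- and under-estimates can cancel, so this metric says nothing about per-unit accuracy.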
eval_mse(preds: jax.Array, targets: jax.Array) → jax.Array
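A minimal sketch of a mean squared error between predictions and targets, assuming both arrays have the same shape; it is an illustration of the metric, not the library's implementation.

```python
import jax
import jax.numpy as jnp


def eval_mse(preds: jax.Array, targets: jax.Array) -> jax.Array:
    # Mean of the squared element-wise residuals.
    # Assumes preds and targets share a shape so subtraction is element-wise.
    return jnp.mean((preds - targets) ** 2)


preds = jnp.array([1.0, 2.0, 3.0])
targets = jnp.array([1.0, 2.0, 5.0])
print(eval_mse(preds, targets))  # (0 + 0 + 4) / 3 ≈ 1.333
```

One pitfall worth guarding against: if `preds` has shape `(n, 1)` and `targets` has shape `(n,)`, broadcasting silently produces an `(n, n)` residual matrix, so reshape both to the same shape first.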
eval_mse_model(inputs: jax.Array, targets: jax.Array, predict_fun: Callable, params: jax.Array) → jax.Array
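A minimal sketch of evaluating a model's MSE from its prediction function and parameters. It assumes `predict_fun` is called as `predict_fun(params, inputs)`, the calling order used by `jax.example_libraries.stax` apply functions; the linear model `linear_predict` and its `(weights, bias)` parameter tuple are hypothetical. The signature annotates `params` as an array, but any structure works as long as `predict_fun` knows how to unpack it.

```python
from typing import Callable

import jax
import jax.numpy as jnp


def eval_mse_model(inputs: jax.Array, targets: jax.Array,
                   predict_fun: Callable, params) -> jax.Array:
    # Assumption: predict_fun takes (params, inputs), like stax apply functions.
    preds = predict_fun(params, inputs)
    return jnp.mean((preds - targets) ** 2)


# Hypothetical linear model: params is a (weights, bias) tuple.
def linear_predict(params, inputs):
    w, b = params
    return inputs @ w + b


X = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
y = jnp.array([1.0, 2.0, 3.0])
params = (jnp.array([1.0, 2.0]), 0.0)
print(eval_mse_model(X, y, linear_predict, params))  # exact fit, so 0.0
```

Since the metric is a pure function of `params`, it can also be wrapped in `jax.grad` or `jax.jit` when used as a training or validation loss.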
eval_root_mse(cate_pred: jax.Array, cate_true: jax.Array) → jax.Array
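A minimal sketch of a root mean squared error between predicted and true CATEs, the quantity commonly reported as the root PEHE (precision in estimation of heterogeneous effects). It assumes matching shapes and is an illustration, not the library's code.

```python
import jax
import jax.numpy as jnp


def eval_root_mse(cate_pred: jax.Array, cate_true: jax.Array) -> jax.Array:
    # Square root of the mean squared per-unit CATE error.
    return jnp.sqrt(jnp.mean((cate_pred - cate_true) ** 2))


cate_pred = jnp.array([1.0, 2.0, 3.0, 4.0])
cate_true = jnp.array([3.0, 0.0, 5.0, 2.0])
print(eval_root_mse(cate_pred, cate_true))  # every error is ±2, so 2.0
```

With these inputs both vectors have mean 2.5, so an error measured only on the average effect would be 0, while the per-unit RMSE is 2.0.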
get_all_pseudoout_models() → Dict
get_all_snets() → Dict
get_all_twostep_models() → Dict
get_model_set(model_selection: Union[str, list] = 'all', model_params: Optional[dict] = None) → Dict

Helper function to retrieve a set of models as a dictionary.
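A minimal sketch of how a selector in the style of `get_model_set` could combine dictionary-returning registries. Everything here is an assumption: registries map model names to model objects, a string picks a whole registry (`'all'` merges every registry) or a single model name, a list picks individual names, and `model_params` entries are forwarded only to models that accept them. The names `DummyModel`, `_REGISTRIES`, and `get_model_set_sketch`, and the `get_params`/`set_params` interface, are hypothetical stand-ins.

```python
from typing import Dict, Optional, Union


class DummyModel:
    # Hypothetical model with sklearn-style get_params/set_params.
    def __init__(self, n_layers: int = 2, penalty: float = 1e-4):
        self.n_layers = n_layers
        self.penalty = penalty

    def get_params(self) -> dict:
        return {"n_layers": self.n_layers, "penalty": self.penalty}

    def set_params(self, **params) -> "DummyModel":
        for key, val in params.items():
            setattr(self, key, val)
        return self


# Hypothetical registries standing in for get_all_snets() and friends.
_REGISTRIES = {
    "snet": lambda: {"snet_a": DummyModel(), "snet_b": DummyModel()},
    "pseudo": lambda: {"pseudo_a": DummyModel()},
}


def _all_models() -> Dict:
    # Merge fresh instances from every registry into one name -> model dict.
    merged = {}
    for factory in _REGISTRIES.values():
        merged.update(factory())
    return merged


def get_model_set_sketch(model_selection: Union[str, list] = "all",
                         model_params: Optional[dict] = None) -> Dict:
    if isinstance(model_selection, str):
        if model_selection == "all":
            models = _all_models()
        elif model_selection in _REGISTRIES:
            models = _REGISTRIES[model_selection]()
        else:
            models = {model_selection: _all_models()[model_selection]}
    elif isinstance(model_selection, list):
        available = _all_models()
        models = {name: available[name] for name in model_selection}
    else:
        raise ValueError("model_selection must be a str or a list")

    # Forward only the hyperparameters each model actually accepts.
    if model_params is not None:
        for model in models.values():
            accepted = model.get_params()
            model.set_params(**{k: v for k, v in model_params.items() if k in accepted})
    return models


models = get_model_set_sketch("snet", model_params={"n_layers": 3, "unknown": 1})
print(sorted(models))             # ['snet_a', 'snet_b']
print(models["snet_a"].n_layers)  # 3
```

Filtering `model_params` per model lets one shared hyperparameter dictionary be applied to a heterogeneous set of models without raising on keys that only some of them understand.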