catenets.models.jax.model_utils module

Model utilities shared across the different networks.

check_X_is_np(X: pandas.core.frame.DataFrame) → jax._src.basearray.Array
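Only the signature is documented here. Judging from the name and annotations, the helper coerces a pandas `DataFrame` into a JAX array so that downstream code handles arrays only. The sketch below assumes exactly that. `check_X_is_np_sketch` is a hypothetical stand-in, not the catenets source, and the sketch needs `pandas` installed.

```python
import jax
import jax.numpy as jnp
import numpy as np
import pandas as pd


def check_X_is_np_sketch(X: pd.DataFrame) -> jax.Array:
    """Hypothetical sketch: coerce tabular input to a JAX array.

    Assumption: the real helper does little more than an array conversion.
    """
    # np.asarray handles DataFrames (via __array__) as well as plain arrays
    return jnp.asarray(np.asarray(X))


df = pd.DataFrame({"age": [30.0, 45.0], "dose": [1.0, 0.5]})
X = check_X_is_np_sketch(df)
print(X.shape, X.dtype)
```

Converting with `np.asarray` first means a plain NumPy array also passes through unchanged. In JAX's default 32-bit mode, `float64` columns come back as `float32`.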
check_shape_1d_data(y: jax._src.basearray.Array) → jax._src.basearray.Array
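One plausible reading of `check_shape_1d_data` is that it turns a flat target of shape `(n_obs,)` into a column of shape `(n_obs, 1)` and leaves other shapes alone. The sketch below assumes that behaviour. The name `check_shape_1d_data_sketch` is hypothetical. The sketch also shows why such a check matters: subtracting an `(n,)` vector from `(n, 1)` predictions silently broadcasts to an `(n, n)` matrix.

```python
import jax
import jax.numpy as jnp


def check_shape_1d_data_sketch(y: jax.Array) -> jax.Array:
    """Hypothetical sketch: return y as a column vector if it is 1-D."""
    if y.ndim == 1:
        # (n_obs,) -> (n_obs, 1) so it lines up with (n_obs, 1) network outputs
        return y.reshape((y.shape[0], 1))
    # Already 2-D (or higher): leave untouched
    return y


y = jnp.array([0.0, 1.0, 1.0])
preds = jnp.zeros((3, 1))
print((preds - y).shape)  # (3, 3): silent broadcasting bug
print((preds - check_shape_1d_data_sketch(y)).shape)  # (3, 1)
```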
heads_l2_penalty(params_0: jax._src.basearray.Array, params_1: jax._src.basearray.Array, n_layers_out: jax._src.basearray.Array, reg_diff: jax._src.basearray.Array, penalty_0: jax._src.basearray.Array, penalty_1: jax._src.basearray.Array) → jax._src.basearray.Array
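The parameter names suggest an L2 penalty on two output heads, `params_0` and `params_1`. The flag `reg_diff` would then switch the second head from being shrunk towards zero to being shrunk towards the first head. With W0 and W1 denoting matching weight matrices in the two heads, the penalty is penalty_0 · Σ‖W0‖² + penalty_1 · Σ‖W1 − W0‖² when `reg_diff` is true. Otherwise the second term is penalty_1 · Σ‖W1‖².

The sketch below illustrates that reading under stated assumptions and is not the catenets code:

- Each head is a plain list of `(W, b)` pairs.
- Biases are left unpenalized.
- `n_layers_out` is dropped because the list length plays its role.

The name `heads_l2_penalty_sketch` is hypothetical.

```python
from typing import List, Tuple

import jax
import jax.numpy as jnp

# Hypothetical parameter layout: one (W, b) pair per dense layer of a head
Params = List[Tuple[jax.Array, jax.Array]]


def heads_l2_penalty_sketch(params_0: Params, params_1: Params, reg_diff: bool,
                            penalty_0: float, penalty_1: float) -> jax.Array:
    # Squared L2 norm of head 0's weight matrices (biases not penalized)
    sq_0 = sum(jnp.sum(w0 ** 2) for w0, _b0 in params_0)
    if reg_diff:
        # Shrink head 1 towards head 0 rather than towards zero
        sq_1 = sum(jnp.sum((w1 - w0) ** 2)
                   for (w1, _b1), (w0, _b0) in zip(params_1, params_0))
    else:
        # Shrink head 1 towards zero, independently of head 0
        sq_1 = sum(jnp.sum(w1 ** 2) for w1, _b1 in params_1)
    return penalty_0 * sq_0 + penalty_1 * sq_1


head_0 = [(jnp.ones((2, 2)), jnp.zeros(2)), (jnp.ones((2, 1)), jnp.zeros(1))]
head_1 = [(2 * jnp.ones((2, 2)), jnp.zeros(2)), (2 * jnp.ones((2, 1)), jnp.zeros(1))]
print(float(heads_l2_penalty_sketch(head_0, head_1, True, 0.1, 1.0)))  # ≈ 0.1 * 6 + 1.0 * 6
```

Penalizing the difference lets the two heads share structure while still allowing them to diverge where the data demand it.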
make_val_split(X: jax._src.basearray.Array, y: jax._src.basearray.Array, w: Optional[jax._src.basearray.Array] = None, val_split_prop: float = 0.3, seed: int = 42, stratify_w: bool = True) → Any
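`make_val_split` presumably holds out a validation fraction `val_split_prop` of the rows of `(X, y)`. When a treatment vector `w` is given and `stratify_w` is true, it presumably stratifies on `w` so both parts keep the same treated fraction. The return type is only annotated as `Any`, so the tuple layout in the sketch below is an assumption. The NumPy-only splitting logic is also an assumption, since the real helper may delegate to another library. `make_val_split_sketch` is a hypothetical name.

```python
from typing import Any, Optional

import numpy as np


def make_val_split_sketch(X: Any, y: Any, w: Optional[Any] = None,
                          val_split_prop: float = 0.3, seed: int = 42,
                          stratify_w: bool = True) -> tuple:
    """Hypothetical sketch: split rows into training and validation parts."""
    rng = np.random.default_rng(seed)
    n = X.shape[0]
    if w is not None and stratify_w:
        # Stratify: draw the same fraction of validation rows from each treatment group
        w_flat = np.asarray(w).reshape(-1)
        parts = []
        for group in np.unique(w_flat):
            group_idx = np.flatnonzero(w_flat == group)
            n_val = int(round(val_split_prop * len(group_idx)))
            parts.append(rng.choice(group_idx, size=n_val, replace=False))
        val_idx = np.concatenate(parts)
    else:
        # Plain random split over all rows
        n_val = int(round(val_split_prop * n))
        val_idx = rng.choice(n, size=n_val, replace=False)

    is_val = np.zeros(n, dtype=bool)
    is_val[val_idx] = True
    # Integer index arrays work for NumPy and (eager) JAX arrays alike;
    # rows keep their original order within each split
    train_idx, val_idx = np.flatnonzero(~is_val), np.flatnonzero(is_val)

    if w is None:
        return X[train_idx], y[train_idx], X[val_idx], y[val_idx]
    return (X[train_idx], y[train_idx], w[train_idx],
            X[val_idx], y[val_idx], w[val_idx])


X = np.arange(20, dtype=float).reshape(10, 2)
y = np.arange(10, dtype=float)
w = np.array([0] * 5 + [1] * 5)
X_t, y_t, w_t, X_v, y_v, w_v = make_val_split_sketch(X, y, w, val_split_prop=0.4, seed=0)
print(len(y_t), len(y_v), int(w_v.sum()))  # 6 4 2
```

Stratifying on `w` guarantees that both treatment arms appear in the validation set in the same proportion as in training, which matters for models with one output head per arm.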