catenets.models.jax.model_utils module
Model utilities shared across different nets.
- check_X_is_np(X: pandas.core.frame.DataFrame) → jax._src.basearray.Array
- check_shape_1d_data(y: jax._src.basearray.Array) → jax._src.basearray.Array
- heads_l2_penalty(params_0: jax._src.basearray.Array, params_1: jax._src.basearray.Array, n_layers_out: jax._src.basearray.Array, reg_diff: jax._src.basearray.Array, penalty_0: jax._src.basearray.Array, penalty_1: jax._src.basearray.Array) → jax._src.basearray.Array
- make_val_split(X: jax._src.basearray.Array, y: jax._src.basearray.Array, w: Optional[jax._src.basearray.Array] = None, val_split_prop: float = 0.3, seed: int = 42, stratify_w: bool = True) → Any
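The two input helpers can be sketched as follows. This is a minimal illustration, not the library's code: NumPy stands in for jax.numpy (np.asarray and reshape have jax.numpy equivalents), DataFrames are detected by duck typing on to_numpy() so the sketch runs without pandas, and the "promote (n,) to (n, 1)" rule for check_shape_1d_data is an assumption. A plausible motivation for that rule is broadcasting: subtracting an (n, 1) prediction from an (n,) target silently produces an (n, n) array.

```python
from typing import Any

import numpy as np


def check_X_is_np(X: Any) -> np.ndarray:
    # Hypothetical sketch: a pandas DataFrame exposes .to_numpy(); any other
    # array-like input (lists, NumPy or JAX arrays) goes through np.asarray.
    if hasattr(X, "to_numpy"):
        return np.asarray(X.to_numpy())
    return np.asarray(X)


def check_shape_1d_data(y: Any) -> np.ndarray:
    # Assumed behaviour: turn a flat (n,) vector into an (n, 1) column so it
    # lines up with (n, 1) network outputs; 2-D inputs are returned unchanged.
    y = np.asarray(y)
    if y.ndim == 1:
        return y.reshape(-1, 1)
    return y


X = check_X_is_np([[0.1, 1.0], [0.2, 0.0], [0.3, 1.0]])
y = check_shape_1d_data(np.array([1.0, 0.0, 1.0]))
print(X.shape, y.shape)  # (3, 2) (3, 1)
```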
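The signature of heads_l2_penalty suggests an L2 penalty on two output heads (one per treatment arm), where reg_diff switches between shrinking each head towards zero and shrinking head 1 towards head 0. The sketch below illustrates that idea under stated assumptions: each head's parameters are a plain list of (W, b) pairs with n_layers_out + 1 dense layers, only weight matrices are penalised, and n_layers_out and reg_diff are ordinary Python int and bool values. The real function operates on JAX parameter structures whose layout may differ; NumPy is used here only for portability.

```python
from typing import List, Tuple

import numpy as np

# Assumed layout (hypothetical): one (W, b) pair per dense layer of a head.
Params = List[Tuple[np.ndarray, np.ndarray]]


def heads_l2_penalty(
    params_0: Params,
    params_1: Params,
    n_layers_out: int,
    reg_diff: bool,
    penalty_0: float,
    penalty_1: float,
) -> float:
    # Penalise the n_layers_out hidden layers plus the output layer,
    # weights only (biases are left unregularised in this sketch).
    layers = range(n_layers_out + 1)

    # Head 0: plain squared-L2 norm of its weights.
    sq_0 = sum(np.sum(params_0[i][0] ** 2) for i in layers)

    if reg_diff:
        # Head 1 is shrunk towards head 0 rather than towards zero.
        sq_1 = sum(np.sum((params_1[i][0] - params_0[i][0]) ** 2) for i in layers)
    else:
        sq_1 = sum(np.sum(params_1[i][0] ** 2) for i in layers)

    return float(penalty_0 * sq_0 + penalty_1 * sq_1)


W, b = np.ones((2, 2)), np.zeros(2)
head_0 = [(W, b), (np.ones((2, 1)), np.zeros(1))]
head_1 = [(2 * W, b), (np.ones((2, 1)), np.zeros(1))]

print(heads_l2_penalty(head_0, head_1, 1, False, 0.5, 0.5))  # 12.0
print(heads_l2_penalty(head_0, head_1, 1, True, 0.5, 0.5))   # 5.0
```

Regularising the difference lets the two heads share most of their structure while still allowing arm-specific deviations, which is why the penalty drops from 12.0 to 5.0 in the example: only the first layer differs between the heads.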
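A train/validation split with optional stratification on the treatment indicator w can be sketched as below. This keeps the documented parameter names and defaults (val_split_prop=0.3, seed=42, stratify_w=True) but swaps in a hand-written NumPy permutation split for whatever splitter the library actually uses. Because the documented return type is Any, the tuple layout here (training arrays first, then validation arrays) is an assumption, as is reusing the full data for both parts when val_split_prop is 0.

```python
from typing import Optional, Tuple

import numpy as np


def make_val_split(
    X: np.ndarray,
    y: np.ndarray,
    w: Optional[np.ndarray] = None,
    val_split_prop: float = 0.3,
    seed: int = 42,
    stratify_w: bool = True,
) -> Tuple[np.ndarray, ...]:
    n = X.shape[0]
    rng = np.random.default_rng(seed)

    if val_split_prop == 0:
        # Assumed: no hold-out, so the training data doubles as validation data.
        train_idx = val_idx = np.arange(n)
    elif w is not None and stratify_w:
        # Stratify on treatment: draw the same share from every arm.
        val_parts = []
        for arm in np.unique(w):
            arm_idx = rng.permutation(np.flatnonzero(w == arm))
            n_val = int(round(val_split_prop * len(arm_idx)))
            val_parts.append(arm_idx[:n_val])
        val_idx = np.sort(np.concatenate(val_parts))
        train_idx = np.setdiff1d(np.arange(n), val_idx)
    else:
        # Plain random split that ignores treatment assignment.
        perm = rng.permutation(n)
        n_val = int(round(val_split_prop * n))
        val_idx, train_idx = np.sort(perm[:n_val]), np.sort(perm[n_val:])

    if w is None:
        return X[train_idx], y[train_idx], X[val_idx], y[val_idx]
    return (X[train_idx], y[train_idx], w[train_idx],
            X[val_idx], y[val_idx], w[val_idx])


X = np.arange(20, dtype=float).reshape(10, 2)
y = np.arange(10, dtype=float)
w = np.array([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])
X_t, y_t, w_t, X_v, y_v, w_v = make_val_split(X, y, w, val_split_prop=0.4)
print(len(y_t), len(y_v), w_v.sum())  # 6 4 2
```

Stratifying on w keeps the treated/control ratio of the validation set close to that of the full sample, which matters when one arm is small and a purely random split could leave it nearly empty.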