Data and gRESIT algorithm
gRESIT expects its input as a dictionary, here called data_dict, in which each key names a group and each value is a numpy.ndarray holding the samples for that group. The arrays need not have the same shape (groups may contain different numbers of variables), but they must all contain the same number of samples.
import numpy as np
rng = np.random.default_rng(42) # Set seed for reproducibility
X_1 = rng.multivariate_normal(mean=np.zeros(2), cov=np.eye(2), size=2000)
X_2 = rng.multivariate_normal(mean=np.zeros(3), cov=np.eye(3), size=2000)
X_3 = np.column_stack(
    [X_1[:, 0] * X_2[:, 0] + X_1[:, 1] * X_2[:, 1], X_1[:, 1] * X_2[:, 2]]
) + 0.1 * rng.multivariate_normal(mean=np.zeros(2), cov=np.eye(2), size=2000)
data_dict = {
    "X_1": X_1,
    "X_2": X_2,
    "X_3": X_3,
}
| Key | Shape | Dtype | Example Values |
|---|---|---|---|
| X_1 | (2000, 2) | float64 | [[0.305, -1.04], [0.75, 0.941]] |
| X_2 | (2000, 3) | float64 | [[0.253, 0.895, 0.273], [2.239, 1.43, -0.308]] |
| X_3 | (2000, 2) | float64 | [[-0.836, -0.194], [2.878, -0.318]] |
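Because every group must contain the same number of samples, it can help to check a dictionary before passing it to the algorithm. The following is a minimal sketch of such a check; the helper name check_group_data is hypothetical and not part of gRESIT, and the requirement that each array be two-dimensional (samples by variables) is an assumption based on the example above.

```python
import numpy as np


def check_group_data(data_dict):
    """Return the shared sample count, or raise if the groups disagree.

    Hypothetical helper for illustration; it is not part of gRESIT.
    """
    if not data_dict:
        raise ValueError("data_dict is empty")
    n_samples = {}
    for name, values in data_dict.items():
        arr = np.asarray(values)
        # Assumption: each group is stored as a (samples, variables) array.
        if arr.ndim != 2:
            raise ValueError(f"{name!r} must be 2-D, got ndim={arr.ndim}")
        n_samples[name] = arr.shape[0]
    # Groups may differ in width, but not in the number of rows.
    if len(set(n_samples.values())) > 1:
        raise ValueError(f"Groups differ in sample count: {n_samples}")
    return next(iter(n_samples.values()))


# Toy groups with different widths but the same number of rows.
example = {"A": np.zeros((5, 2)), "B": np.ones((5, 3))}
n = check_group_data(example)
```

Applied to the data_dict built above, the same call would return the common number of rows, or raise a ValueError naming the groups whose sample counts differ.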
Given this data, we can run the gRESIT algorithm as follows:
from gresit.group_resit import GroupResit
from gresit.independence_tests import HSIC
from gresit.torch_models import Multioutcome_MLP
gresit = GroupResit(regressor=Multioutcome_MLP(), test=HSIC)
gresit.learn_graph(data_dict)
gresit.show_interactive()
This produces an interactive graph of the learned structure.
See the gRESIT section for details on all arguments and hyperparameters of gRESIT.