Data and gRESIT algorithm

gRESIT expects the data as a dictionary (data_dict below) in which each key names a group and each value is a numpy.ndarray holding the samples for that group, one row per sample. The arrays may differ in their number of columns (variables per group), but they must all contain the same number of samples.

import numpy as np

rng = np.random.default_rng(42)  # Set seed for reproducibility

# Two independent source groups with 2 and 3 variables, 2000 samples each
X_1 = rng.multivariate_normal(mean=np.zeros(2), cov=np.eye(2), size=2000)
X_2 = rng.multivariate_normal(mean=np.zeros(3), cov=np.eye(3), size=2000)

# X_3 is a nonlinear function of X_1 and X_2 plus additive Gaussian noise
X_3 = np.column_stack(
    [X_1[:, 0] * X_2[:, 0] + X_1[:, 1] * X_2[:, 1], X_1[:, 1] * X_2[:, 2]]
) + 0.1 * rng.multivariate_normal(mean=np.zeros(2), cov=np.eye(2), size=2000)

data_dict = {
    "X_1": X_1,
    "X_2": X_2,
    "X_3": X_3,
}
Key   Shape      Dtype    Example Values
X_1   (2000, 2)  float64  [[0.305, -1.04], [0.75, 0.941]]
X_2   (2000, 3)  float64  [[0.253, 0.895, 0.273], [2.239, 1.43, -0.308]]
X_3   (2000, 2)  float64  [[-0.836, -0.194], [2.878, -0.318]]
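
Before passing a dictionary to gRESIT, it can help to confirm that it follows the layout described above. The sketch below is not part of the gresit package: check_group_data is a hypothetical helper, and the requirement that every array be two-dimensional (samples by variables) is an assumption based on the example shapes, not a documented constraint. It uses only NumPy and a small stand-in dictionary.

```python
import numpy as np


def check_group_data(data_dict):
    """Hypothetical helper (not part of gresit): check the group layout.

    Assumes each value is a 2-D array of shape (n_samples, n_variables).
    Returns the shared number of samples, or None for an empty dict.
    """
    n_samples = None
    for name, values in data_dict.items():
        if not isinstance(values, np.ndarray):
            raise TypeError(f"group {name!r} is not a numpy.ndarray")
        if values.ndim != 2:
            raise ValueError(
                f"group {name!r} should be 2-D (samples, variables), "
                f"got ndim={values.ndim}"
            )
        if n_samples is None:
            n_samples = values.shape[0]
        elif values.shape[0] != n_samples:
            raise ValueError(
                f"group {name!r} has {values.shape[0]} samples, "
                f"expected {n_samples}"
            )
    return n_samples


rng = np.random.default_rng(0)
example = {
    "A": rng.normal(size=(100, 2)),
    "B": rng.normal(size=(100, 3)),
    # A single-variable group stored as a column rather than a flat vector
    "C": rng.normal(size=100).reshape(-1, 1),
}

n = check_group_data(example)

# Print a per-group summary similar to the table above
for name, values in example.items():
    print(name, values.shape, values.dtype)
```

Calling check_group_data(data_dict) on the dictionary built earlier would return the common sample count if the groups are consistent, and raise an error naming the offending group otherwise.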

Given this data, we can run the gRESIT algorithm as follows:

from gresit.group_resit import GroupResit
from gresit.independence_tests import HSIC
from gresit.torch_models import Multioutcome_MLP

gresit = GroupResit(regressor=Multioutcome_MLP(), test=HSIC)
gresit.learn_graph(data_dict)
gresit.show_interactive()

This produces the following interactive graph:

See the gRESIT section for details on all of its arguments and hyperparameters.