Skip to content

Documentation

Below, you will find the documentation of the causalAssembly project code.

Utility classes and functions related to causalAssembly. Copyright (c) 2023 Robert Bosch GmbH

This program is free software: you can redistribute it and/or modify it under the terms of the GNU Affero General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later version. This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Affero General Public License for more details. You should have received a copy of the GNU Affero General Public License along with this program. If not, see https://www.gnu.org/licenses/.

ProcessCell

Representation of a single Production Line Cell (to model a station / a process in a production line environment).

A Cell can contain multiple modules (sub-graphs, which are nx.DiGraph objects).

Note that none of the terms Cell, Process or Module has a strict definition. The convention is based on a production line, consisting of several cells which are to be modeled by means of smaller graphs (modules) by a user of the repository.

Source code in causalAssembly/models_dag.py
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
class ProcessCell:
    """
    Representation of a single Production Line Cell
    (to model a station / a process in a production line
    environment).

    A Cell can contain multiple modules (sub-graphs, which are nx.DiGraph objects).

    Note that none of the terms Cell, Process or Module has a strict definition.
    The convention is based on a production line, consisting of several cells which
    are to be modeled by means of smaller graphs (modules) by a user of the repository.

    """

    def __init__(self, name: str):
        self.name = name
        # Composed graph of all modules added so far (node names carry cell/module prefixes).
        self.graph: nx.DiGraph = nx.DiGraph()

        self.description: str = ""  # free-text description of the cell.

        # Prefix used to auto-name modules (M1, M2, ...).  # M01 vs M1?
        self.__module_prefix = "M"
        self.modules: dict[str, nx.DiGraph] = dict()  # {'M1': nx.DiGraph, 'M2': nx.DiGraph}
        self.module_connectors: list[tuple] = list()  # edges added between modules

        self.is_eol = False  # end-of-line marker -- semantics set by caller, TODO confirm
        # Presumably a numpy random Generator (uses .shuffle/.choice); set externally.
        self.random_state = None
        self.drf: dict = dict()  # trained DRF objects, filled externally

    @property
    def nodes(self) -> list[str]:
        """Nodes in the graph.

        Returns:
            list[str]
        """
        return list(self.graph.nodes())

    @property
    def edges(self) -> list[tuple]:
        """Edges in the graph.

        Returns:
            list[tuple]
        """
        return list(self.graph.edges())

    @property
    def num_nodes(self) -> int:
        """Number of nodes in the graph

        Returns:
            int
        """
        return len(self.nodes)

    @property
    def num_edges(self) -> int:
        """Number of edges in the graph

        Returns:
            int
        """
        return len(self.edges)

    @property
    def sparsity(self) -> float:
        """Sparsity of the graph, i.e. the fraction of possible
        (undirected) node pairs that carry an edge.

        Returns:
            float: in [0,1]. Returns 0.0 for graphs with fewer than
                two nodes, where no edge is possible.
        """
        n = self.num_nodes
        if n < 2:
            # Fewer than two nodes admit no edges; the original formula
            # would divide by zero here.
            return 0.0
        return self.num_edges / n / (n - 1) * 2

    @property
    def ground_truth(self) -> pd.DataFrame:
        """Returns the current ground truth as
        pandas adjacency.

        Returns:
            pd.DataFrame: Adjacency matrix.
        """
        return nx.to_pandas_adjacency(self.graph, weight=None)

    @property
    def causal_order(self) -> list[str]:
        """Returns the causal order of the current graph.
        Note that this order is in general not unique.

        Returns:
            list[str]: Causal order
        """
        return list(nx.topological_sort(self.graph))

    def parents(self, of_node: str) -> list[str]:
        """Return parents of node in question.

        Args:
            of_node (str): Node in question.

        Returns:
            list[str]: parent set.
        """
        return list(self.graph.predecessors(of_node))

    def save_drf(self, filename: str, location: str | None = None):
        """Writes a drf dict to file. Please provide the .pkl suffix!

        Args:
            filename (str): name of the file to be written e.g. examplefile.pkl
            location (str, optional): path to file in case it's not located in
                the current working directory. Defaults to None.
        """

        if not location:
            # Default to the current working directory.
            location = Path().resolve()

        location_path = Path(location, filename)

        with open(location_path, "wb") as f:
            pickle.dump(self.drf, f)

    def add_module(
        self,
        graph: nx.DiGraph,
        allow_in_edges: bool = True,
        mark_hidden: bool | list = False,
    ) -> str:
        """Adds module to cell graph. Module has to be as nx.DiGraph object

        Args:
            graph (nx.DiGraph): Graph to add to cell.
            allow_in_edges (bool, optional):
                whether nodes in the module are allowed to
                have in-edges. Defaults to True.
            mark_hidden (bool | list, optional):
                If False all nodes' 'is_hidden' attribute is set to False.
                If list of node names is provided these get overwritten to True.
                Defaults to False.

        Returns:
            str: prefix of Module created
        """

        next_module_prefix = self.next_module_prefix()

        node_renaming_dict = {
            old_node_name: f"{self.name}_{next_module_prefix}_{old_node_name}"
            for old_node_name in graph.nodes()
        }

        # Store an unrelabeled copy under the computed prefix.  Reuse the
        # cached prefix rather than calling next_module_prefix() a second
        # time (which only happened to agree because `modules` had not yet
        # been mutated).
        self.modules[next_module_prefix] = graph.copy()
        graph = nx.relabel_nodes(graph, node_renaming_dict)

        if allow_in_edges:  # for later: mark nodes to not have incoming edges
            nx.set_node_attributes(graph, True, NodeAttributes.ALLOW_IN_EDGES)
        else:
            nx.set_node_attributes(graph, False, NodeAttributes.ALLOW_IN_EDGES)

        nx.set_node_attributes(
            graph, False, NodeAttributes.HIDDEN
        )  # set all non-hidden -> visible by default
        if isinstance(mark_hidden, list):
            # Node names in mark_hidden are the ORIGINAL names; apply the
            # same renaming used above before overwriting attributes.
            mark_hidden_renamed = [
                f"{self.name}_{next_module_prefix}_{new_name}" for new_name in mark_hidden
            ]
            overwrite_dict = {node: {NodeAttributes.HIDDEN: True} for node in mark_hidden_renamed}
            nx.set_node_attributes(
                graph, values=overwrite_dict
            )  # only overwrite the ones specified
        # TODO relabel attributes, i.e. name of the parents has changed now?
        # .update_attributes or so or keep and remove prefixes in bayesian network creation?
        self.graph = nx.compose(self.graph, graph)

        return next_module_prefix

    def input_cellgraph_directly(self, graph: nx.DiGraph, allow_in_edges: bool = False):
        """Allow to input graphs on a cell-level. This should only be done if the graph
        is already available for the entire cell, otherwise `add_module()` is preferred.

        Args:
            graph (nx.DiGraph): Cell graph to input
            allow_in_edges (bool, optional): Defaults to False.
        """
        if allow_in_edges:  # for later: mark nodes to not have incoming edges
            nx.set_node_attributes(graph, True, NodeAttributes.ALLOW_IN_EDGES)
        else:
            nx.set_node_attributes(graph, False, NodeAttributes.ALLOW_IN_EDGES)

        # Only the cell name is prefixed here (no module prefix), since the
        # graph covers the whole cell.
        node_renaming_dict = {
            old_node_name: f"{self.name}_{old_node_name}" for old_node_name in graph.nodes()
        }
        graph = nx.relabel_nodes(graph, node_renaming_dict)

        self.graph = nx.compose(self.graph, graph)

    def sample_from_drf(self, size=10, smoothed: bool = True) -> pd.DataFrame:
        """Draw from the trained DRF.

        Args:
            size (int, optional): Number of samples to be drawn. Defaults to 10.
            smoothed (bool, optional): If set to true, marginal distributions will
                be sampled from smoothed bootstraps. Defaults to True.

        Returns:
            pd.DataFrame: Data frame that follows the distribution implied by the ground truth.
        """
        return _sample_from_drf(prod_object=self, size=size, smoothed=smoothed)

    def interventional_sample_from_drf(self, size=10, smoothed: bool = True) -> pd.DataFrame:
        """Draw interventional samples from the trained DRF.

        Args:
            size (int, optional): Number of samples to be drawn. Defaults to 10.
            smoothed (bool, optional): If set to true, marginal distributions will
                be sampled from smoothed bootstraps. Defaults to True.

        Returns:
            pd.DataFrame: Data frame that follows the distribution implied by the ground truth.
        """
        return _interventional_sample_from_drf(prod_object=self, size=size, smoothed=smoothed)

    def _generate_random_dag(self, n_nodes: int = 5, p: float = 0.1) -> nx.DiGraph:
        """
        Creates a random DAG by
        taking an arbitrary ordering of the specified number of nodes,
        and then considers edges from node i to j only if i < j.
        That constraint leads to DAGness by construction.

        Requires `self.random_state` to be set (numpy Generator-like,
        providing `.shuffle` and `.choice`).

        Args:
            n_nodes (int, optional): Defaults to 5.
            p (float, optional): Defaults to .1.

        Returns:
            nx.DiGraph:
        """
        dag = nx.DiGraph()
        dag.add_nodes_from(range(0, n_nodes))

        causal_order = list(dag.nodes)
        self.random_state.shuffle(causal_order)

        # Only "forward" edges w.r.t. the shuffled order are considered,
        # which guarantees acyclicity.
        all_forward_edges = itertools.combinations(causal_order, 2)
        edges = np.array(list(all_forward_edges))

        random_choice = self.random_state.choice([False, True], p=[1 - p, p], size=edges.shape[0])

        dag.add_edges_from(edges[random_choice])
        return dag

    def add_random_module(self, n_nodes: int = 7, p: float = 0.10):
        """Generate a random DAG and add it to the cell as a new module.

        Args:
            n_nodes (int, optional): Number of nodes in the random DAG. Defaults to 7.
            p (float, optional): Edge probability. Defaults to 0.10.
        """
        randomdag = self._generate_random_dag(n_nodes=n_nodes, p=p)
        self.add_module(graph=randomdag, allow_in_edges=True, mark_hidden=False)

    def connect_by_module(self, m1: str, m2: str, edges: list[tuple]):
        """Connect two modules (by name, e.g. M2, M4) of the cell by a list
        of edges with the original node names.

        Args:
            m1: str
            m2: str
            edges: list[tuple]: use the original node names before they have entered into the cell,
                i.e. not with Cy_Mx prefix
        """
        self.__verify_edges_are_allowed(m1=m1, m2=m2, edges=edges)

        node_prefix_m1 = f"{self.name}_{m1}"
        node_prefix_m2 = f"{self.name}_{m2}"

        new_edges = [
            (f"{node_prefix_m1}_{edge[0]}", f"{node_prefix_m2}_{edge[1]}") for edge in edges
        ]

        # Record the connectors before adding them to the graph.
        self.module_connectors.extend(new_edges)

        self.graph.add_edges_from(new_edges)

    def connect_by_random_edges(self, sparsity: float = 0.1) -> nx.DiGraph:
        """
        Add random edges to graph according to proportion,
        with restriction specified in node attributes.

        Args:
            sparsity (float, optional): Sparsity parameter in (0,1). Defaults to 0.1.

        Raises:
            NotImplementedError: when node attributes are not set.
            TypeError: when resulting graph contains cycles.

        Returns:
            nx.DiGraph: DAG with new edges added.
        """

        arrow_head_candidates = get_arrow_head_candidates_from_graph(
            graph=self.graph, node_attributes_to_filter=NodeAttributes.ALLOW_IN_EDGES
        )

        arrow_tail_candidates = [node for node in self.nodes if node not in arrow_head_candidates]

        potential_edges = tuples_from_cartesian_product(
            l1=arrow_tail_candidates, l2=arrow_head_candidates
        )
        num_choices = int(np.ceil(sparsity * len(potential_edges)))

        ### choose edges uniformly according to sparsity parameter
        chosen_edges = [
            potential_edges[i]
            for i in self.random_state.choice(
                a=len(potential_edges), size=num_choices, replace=False
            )
        ]

        self.graph.add_edges_from(chosen_edges)

        if not nx.is_directed_acyclic_graph(self.graph):
            raise TypeError(
                "The randomly chosen edges induce cycles, this is not supposed to happen."
            )
        return self.graph

    def __repr__(self):
        return f"ProcessCell(name={self.name})"

    def __str__(self):
        cell_description = {
            "Cell Name: ": self.name,
            "Description:": self.description if self.description else "n.a.",
            "Modules:": self.no_of_modules,
            "Nodes: ": self.num_nodes,
        }
        s = str()
        for info, info_text in cell_description.items():
            s += f"{info:<14}{info_text:>5}\n"

        return s

    def __verify_edges_are_allowed(self, m1: str, m2: str, edges: list[tuple]):
        """Check whether all starting point nodes
        (first value in edge tuple) are allowed.

        Args:
            m1 (str): Module1
            m2 (str): Module2
            edges (list[tuple]): Edges

        Raises:
            ValueError: starting node not in M1
            ValueError: ending node not in M2
        """
        source_nodes = set([e[0] for e in edges])
        target_nodes = set([e[1] for e in edges])
        m1_nodes = set(self.modules.get(m1).nodes())
        m2_nodes = set(self.modules.get(m2).nodes())

        if not source_nodes.issubset(m1_nodes):
            raise ValueError(
                f"source nodes: {source_nodes} not included in {m1}s nodes: {m1_nodes}"
            )
        if not target_nodes.issubset(m2_nodes):
            raise ValueError(
                f"target nodes: {target_nodes} not included in {m2}s nodes: {m2_nodes}"
            )

    def next_module_prefix(self) -> str:
        """Return the next module prefix, e.g.
        if there are already 3 modules connected to the cell,
        will return module_prefix4

        Returns:
            str: module_prefix
        """
        return f"{self.__module_prefix}{self.no_of_modules + 1}"

    @property
    def module_prefix(self) -> str:
        return self.__module_prefix

    @module_prefix.setter
    def module_prefix(self, module_prefix: str):
        if not isinstance(module_prefix, str):
            raise ValueError("please only use strings as module prefix")

        self.__module_prefix = module_prefix

    @property
    def no_of_modules(self) -> int:
        return len(self.modules)

    def get_nodes_by_attribute(self, attr_name: str, submodule: str = None) -> list:
        # TODO(review): unimplemented stub kept as-is (currently returns None);
        # implementing it would require pinning down the intended semantics of
        # `submodule`.
        pass

    def get_available_attributes(self) -> list:
        """Collect the set of attribute names used by any node in the cell graph.

        Returns:
            list: distinct attribute names (order unspecified).
        """
        available_attributes = set()
        for _, node_data in self.graph.nodes(data=True):
            available_attributes.update(node_data.keys())

        return list(available_attributes)

    def to_cpdag(self) -> PDAG:
        """Return the CPDAG (Markov equivalence class) of the cell's DAG."""
        return dag2cpdag(dag=self.graph)

    def show(
        self,
        meta_desc: str = "",
    ):
        """Plots the cell graph by giving extra weight to nodes
        with high in- and out-degree.

        Args:
            meta_desc (str, optional): Defaults to "".

        """
        cmap = plt.get_cmap("cividis")
        fig, ax = plt.subplots()
        center: np.ndarray = np.array([0, 0])
        pos = nx.spring_layout(
            self.graph,
            center=center,
            seed=10,
            k=50,
        )

        # default=0 keeps an empty graph from raising
        # "max() arg is an empty sequence".
        max_in_degree = max((d for _, d in self.graph.in_degree()), default=0)
        max_out_degree = max((d for _, d in self.graph.out_degree()), default=0)

        nx.draw_networkx_nodes(
            self.graph,
            pos=pos,
            ax=ax,
            cmap=cmap,
            vmin=-0.2,
            vmax=1,
            # Shade by (shifted) relative in-degree; size by relative out-degree.
            node_color=[
                (d + 10) / (max_in_degree + 10) for _, d in self.graph.in_degree(self.nodes)
            ],
            node_size=[
                500 * (d + 1) / (max_out_degree + 1) for _, d in self.graph.out_degree(self.nodes)
            ],
        )

        nx.draw_networkx_edges(
            self.graph,
            pos=pos,
            ax=ax,
            alpha=0.2,
            arrowsize=8,
            width=0.5,
            connectionstyle="arc3,rad=0.3",
        )

        ax.text(
            center[0],
            center[1] + 1.2,
            self.name + f"\n{meta_desc}",
            horizontalalignment="center",
            fontsize=12,
        )

        ax.axis("off")

    def _plot_cellgraph(
        self,
        ax,
        node_color,
        node_size,
        center=np.array([0, 0]),
        with_edges=True,
        with_box=True,
        meta_desc="",
    ):
        """Plots the cell graph by giving extra weight to nodes
        with high in- and out-degree.

        Args:
            ax: matplotlib axes to draw on.
            node_color: per-node colors passed to nx.draw_networkx_nodes.
            node_size: per-node sizes passed to nx.draw_networkx_nodes.
            center (_type_, optional): Defaults to np.array([0, 0]).
            with_edges (bool, optional): Defaults to True.
            with_box (bool, optional): Defaults to True.
            meta_desc (str, optional): Defaults to "".

        Returns:
            dict: node positions computed by the spring layout.
        """
        cmap = plt.get_cmap("cividis")

        pos = nx.spring_layout(
            self.graph,
            center=center,
            seed=10,
            k=50,
        )

        nx.draw_networkx_nodes(
            self.graph,
            pos=pos,
            ax=ax,
            cmap=cmap,
            vmin=-0.2,
            vmax=1,
            node_color=node_color,
            node_size=node_size,
        )

        if with_edges:
            nx.draw_networkx_edges(
                self.graph,
                pos=pos,
                ax=ax,
                alpha=0.2,
                arrowsize=8,
                width=0.5,
                connectionstyle="arc3,rad=0.3",
            )

        ax.text(
            center[0],
            center[1] + 1.2,
            self.name + f"\n{meta_desc}",
            horizontalalignment="center",
            fontsize=12,
        )

        if with_box:
            # Draw a translucent rounded box behind the cell for grouping.
            ax.add_collection(
                PatchCollection(
                    [
                        FancyBboxPatch(
                            center - [2, 1],
                            4,
                            2.6,
                            boxstyle=BoxStyle("Round", pad=0.02),
                        )
                    ],
                    alpha=0.2,
                    color="gray",
                )
            )

        ax.axis("off")
        return pos

causal_order: list[str] property

Returns the causal order of the current graph. Note that this order is in general not unique.

Returns:

Type Description
list[str]

list[str]: Causal order

edges: list[tuple] property

Edges in the graph.

Returns:

Type Description
list[tuple]

list[tuple]

ground_truth: pd.DataFrame property

Returns the current ground truth as pandas adjacency.

Returns:

Type Description
DataFrame

pd.DataFrame: Adjacency matrix.

nodes: list[str] property

Nodes in the graph.

Returns:

Type Description
list[str]

list[str]

num_edges: int property

Number of edges in the graph

Returns:

Type Description
int

int

num_nodes: int property

Number of nodes in the graph

Returns:

Type Description
int

int

sparsity: float property

Sparsity of the graph

Returns:

Name Type Description
float float

in [0,1]

__verify_edges_are_allowed(m1, m2, edges)

Check whether all starting point nodes (first value in edge tuple) are allowed.

Parameters:

Name Type Description Default
m1 str

Module1

required
m2 str

Module2

required
edges list[tuple]

Edges

required

Raises:

Type Description
ValueError

starting node not in M1

ValueError

ending node not in M2

Source code in causalAssembly/models_dag.py
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
def __verify_edges_are_allowed(self, m1: str, m2: str, edges: list[tuple]):
    """Check whether all starting point nodes
    (first value in edge tuple) are allowed.

    Args:
        m1 (str): Module1
        m2 (str): Module2
        edges (list[tuple]): Edges

    Raises:
        ValueError: starting node not in M1
        ValueError: ending node not in M2
    """
    source_nodes = set([e[0] for e in edges])
    target_nodes = set([e[1] for e in edges])
    m1_nodes = set(self.modules.get(m1).nodes())
    m2_nodes = set(self.modules.get(m2).nodes())

    if not source_nodes.issubset(m1_nodes):
        raise ValueError(f"source nodes: {source_nodes} not include in {m1}s nodes: {m1_nodes}")
    if not target_nodes.issubset(m2_nodes):
        raise ValueError(f"target nodes: {target_nodes} not include in {m2}s nodes: {m2_nodes}")

add_module(graph, allow_in_edges=True, mark_hidden=False)

Adds module to cell graph. Module has to be as nx.DiGraph object

Parameters:

Name Type Description Default
graph DiGraph

Graph to add to cell.

required
allow_in_edges bool

whether nodes in the module are allowed to have in-edges. Defaults to True.

True
mark_hidden bool | list

If False all nodes' 'is_hidden' attribute is set to False. If list of node names is provided these get overwritten to True. Defaults to False.

False

Returns:

Name Type Description
str str

prefix of Module created

Source code in causalAssembly/models_dag.py
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
def add_module(
    self,
    graph: nx.DiGraph,
    allow_in_edges: bool = True,
    mark_hidden: bool | list = False,
) -> str:
    """Adds module to cell graph. Module has to be as nx.DiGraph object

    Args:
        graph (nx.DiGraph): Graph to add to cell.
        allow_in_edges (bool, optional):
            whether nodes in the module are allowed to
            have in-edges. Defaults to True.
        mark_hidden (bool | list, optional):
            If False all nodes' 'is_hidden' attribute is set to False.
            If list of node names is provided these get overwritten to True.
            Defaults to False.

    Returns:
        str: prefix of Module created
    """

    next_module_prefix = self.next_module_prefix()

    node_renaming_dict = {
        old_node_name: f"{self.name}_{next_module_prefix}_{old_node_name}"
        for old_node_name in graph.nodes()
    }

    self.modules[self.next_module_prefix()] = graph.copy()
    graph = nx.relabel_nodes(graph, node_renaming_dict)

    if allow_in_edges:  # for later: mark nodes to not have incoming edges
        nx.set_node_attributes(graph, True, NodeAttributes.ALLOW_IN_EDGES)
    else:
        nx.set_node_attributes(graph, False, NodeAttributes.ALLOW_IN_EDGES)

    nx.set_node_attributes(
        graph, False, NodeAttributes.HIDDEN
    )  # set all non-hidden -> visible by default
    if isinstance(mark_hidden, list):
        mark_hidden_renamed = [
            f"{self.name}_{next_module_prefix}_{new_name}" for new_name in mark_hidden
        ]
        overwrite_dict = {node: {NodeAttributes.HIDDEN: True} for node in mark_hidden_renamed}
        nx.set_node_attributes(
            graph, values=overwrite_dict
        )  # only overwrite the ones specified
    # TODO relabel attributes, i.e. name of the parents has changed now?
    # .update_attributes or so or keep and remove prefixes in bayesian network creation?
    self.graph = nx.compose(self.graph, graph)

    return next_module_prefix

connect_by_module(m1, m2, edges)

Connect two modules (by name, e.g. M2, M4) of the cell by a list of edges with the original node names.

Parameters:

Name Type Description Default
m1 str

str

required
m2 str

str

required
edges list[tuple]

list[tuple]: use the original node names before they have entered into the cell, i.e. not with Cy_Mx prefix

required
Source code in causalAssembly/models_dag.py
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
def connect_by_module(self, m1: str, m2: str, edges: list[tuple]):
    """Connect two modules (by name, e.g. M2, M4) of the cell by a list
    of edges with the original node names.

    Args:
        m1: str
        m2: str
        edges: list[tuple]: use the original node names before they have entered into the cell,
            i.e. not with Cy_Mx prefix
    """
    self.__verify_edges_are_allowed(m1=m1, m2=m2, edges=edges)

    node_prefix_m1 = f"{self.name}_{m1}"
    node_prefix_m2 = f"{self.name}_{m2}"

    new_edges = [
        (f"{node_prefix_m1}_{edge[0]}", f"{node_prefix_m2}_{edge[1]}") for edge in edges
    ]

    [self.module_connectors.append(edge) for edge in new_edges]

    self.graph.add_edges_from(new_edges)

connect_by_random_edges(sparsity=0.1)

Add random edges to graph according to proportion, with restriction specified in node attributes.

Parameters:

Name Type Description Default
sparsity float

Sparsity parameter in (0,1). Defaults to 0.1.

0.1

Raises:

Type Description
NotImplementedError

when node attributes are not set.

TypeError

when resulting graph contains cycles.

Returns:

Type Description
DiGraph

nx.DiGraph: DAG with new edges added.

Source code in causalAssembly/models_dag.py
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
def connect_by_random_edges(self, sparsity: float = 0.1) -> nx.DiGraph:
    """
    Add random edges to graph according to proportion,
    with restriction specified in node attributes.

    Args:
        sparsity (float, optional): Sparsity parameter in (0,1). Defaults to 0.1.

    Raises:
        NotImplementedError: when node attributes are not set.
        TypeError: when resulting graph contains cycles.

    Returns:
        nx.DiGraph: DAG with new edges added.
    """

    arrow_head_candidates = get_arrow_head_candidates_from_graph(
        graph=self.graph, node_attributes_to_filter=NodeAttributes.ALLOW_IN_EDGES
    )

    arrow_tail_candidates = [node for node in self.nodes if node not in arrow_head_candidates]

    potential_edges = tuples_from_cartesian_product(
        l1=arrow_tail_candidates, l2=arrow_head_candidates
    )
    num_choices = int(np.ceil(sparsity * len(potential_edges)))

    ### choose edges uniformly according to sparsity parameter
    chosen_edges = [
        potential_edges[i]
        for i in self.random_state.choice(
            a=len(potential_edges), size=num_choices, replace=False
        )
    ]

    self.graph.add_edges_from(chosen_edges)

    if not nx.is_directed_acyclic_graph(self.graph):
        raise TypeError(
            "The randomly chosen edges induce cycles, this is not supposed to happen."
        )
    return self.graph

input_cellgraph_directly(graph, allow_in_edges=False)

Allow to input graphs on a cell-level. This should only be done if the graph is already available for the entire cell, otherwise add_module() is preferred.

Parameters:

Name Type Description Default
graph DiGraph

Cell graph to input

required
allow_in_edges bool

Defaults to False.

False
Source code in causalAssembly/models_dag.py
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
def input_cellgraph_directly(self, graph: nx.DiGraph, allow_in_edges: bool = False):
    """Allow to input graphs on a cell-level. This should only be done if the graph
    is already available for the entire cell, otherwise `add_module()` is preferred.

    Args:
        graph (nx.DiGraph): Cell graph to input
        allow_in_edges (bool, optional): Defaults to False.
    """
    if allow_in_edges:  # for later: mark nodes to not have incoming edges
        nx.set_node_attributes(graph, True, NodeAttributes.ALLOW_IN_EDGES)
    else:
        nx.set_node_attributes(graph, False, NodeAttributes.ALLOW_IN_EDGES)

    node_renaming_dict = {
        old_node_name: f"{self.name}_{old_node_name}" for old_node_name in graph.nodes()
    }
    graph = nx.relabel_nodes(graph, node_renaming_dict)

    self.graph = nx.compose(self.graph, graph)

interventional_sample_from_drf(size=10, smoothed=True)

Draw from the trained DRF.

Parameters:

Name Type Description Default
size int

Number of samples to be drawn. Defaults to 10.

10
smoothed bool

If set to true, marginal distributions will be sampled from smoothed bootstraps. Defaults to True.

True

Returns:

Type Description
DataFrame

pd.DataFrame: Data frame that follows the distribution implied by the ground truth.

Source code in causalAssembly/models_dag.py
365
366
367
368
369
370
371
372
373
374
375
376
def interventional_sample_from_drf(self, size=10, smoothed: bool = True) -> pd.DataFrame:
    """Draw from the trained DRF.

    Args:
        size (int, optional): Number of samples to be drawn. Defaults to 10.
        smoothed (bool, optional): If set to true, marginal distributions will
            be sampled from smoothed bootstraps. Defaults to True.

    Returns:
        pd.DataFrame: Data frame that follows the distribution implied by the ground truth.
    """
    return _interventional_sample_from_drf(prod_object=self, size=size, smoothed=smoothed)

next_module_prefix()

Return the next module prefix, e.g. if there are already 3 modules connected to the cell, will return module_prefix4

Returns:

Name Type Description
str str

module_prefix

Source code in causalAssembly/models_dag.py
515
516
517
518
519
520
521
522
523
def next_module_prefix(self) -> str:
    """Return the next module prefix, e.g.
    if there are already 3 modules connected to the cell,
    will return module_prefix4

    Returns:
        str: module_prefix
    """
    return f"{self.__module_prefix}{self.no_of_modules + 1}"

parents(of_node)

Return parents of node in question.

Parameters:

Name Type Description Default
of_node str

Node in question.

required

Returns:

Type Description
list[str]

list[str]: parent set.

Source code in causalAssembly/models_dag.py
250
251
252
253
254
255
256
257
258
259
def parents(self, of_node: str) -> list[str]:
    """Look up the direct causes (graph parents) of a node.

    Args:
        of_node (str): Node in question.

    Returns:
        list[str]: parent set.
    """
    # Predecessors in the DAG are exactly the node's parents.
    return [predecessor for predecessor in self.graph.predecessors(of_node)]

sample_from_drf(size=10, smoothed=True)

Draw from the trained DRF.

Parameters:

Name Type Description Default
size int

Number of samples to be drawn. Defaults to 10.

10
smoothed bool

If set to true, marginal distributions will be sampled from smoothed bootstraps. Defaults to True.

True

Returns:

Type Description
DataFrame

pd.DataFrame: Data frame that follows the distribution implied by the ground truth.

Source code in causalAssembly/models_dag.py
352
353
354
355
356
357
358
359
360
361
362
363
def sample_from_drf(self, size=10, smoothed: bool = True) -> pd.DataFrame:
    """Sample observational data from the fitted DRF.

    Args:
        size (int, optional): Number of samples to be drawn. Defaults to 10.
        smoothed (bool, optional): If set to true, marginal distributions will
            be sampled from smoothed bootstraps. Defaults to True.

    Returns:
        pd.DataFrame: Data frame that follows the distribution implied by the ground truth.
    """
    # Delegates to the module-level helper; `self` carries the trained DRF.
    return _sample_from_drf(prod_object=self, smoothed=smoothed, size=size)

save_drf(filename, location=None)

Writes a drf dict to file. Please provide the .pkl suffix!

Parameters:

Name Type Description Default
filename str

name of the file to be written e.g. examplefile.pkl

required
location str

path to file in case it's not located in the current working directory. Defaults to None.

None
Source code in causalAssembly/models_dag.py
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
def save_drf(self, filename: str, location: str = None):
    """Pickle the trained DRF dict to disk. Please provide the .pkl suffix!

    Args:
        filename (str): name of the file to be written e.g. examplefile.pkl
        location (str, optional): path to file in case it's not located in
            the current working directory. Defaults to None.
    """
    # Fall back to the current working directory when no location is given.
    target_dir = location if location else Path().resolve()
    target_path = Path(target_dir, filename)

    with open(target_path, "wb") as out_file:
        pickle.dump(self.drf, out_file)

show(meta_desc='')

Plots the cell graph by giving extra weight to nodes with high in- and out-degree.

Parameters:

Name Type Description Default
meta_desc str

Defaults to "".

''
Source code in causalAssembly/models_dag.py
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
def show(
    self,
    meta_desc: str = "",
):
    """Visualize the cell graph.

    Node color scales with in-degree and node size with out-degree,
    so highly connected nodes stand out visually.

    Args:
        meta_desc (str, optional): Extra text shown below the cell name.
            Defaults to "".

    """
    color_map = plt.get_cmap("cividis")
    _, axis = plt.subplots()
    origin: np.ndarray = np.array([0, 0])
    layout = nx.spring_layout(
        self.graph,
        center=origin,
        seed=10,
        k=50,
    )

    # Normalization constants for color (in-degree) and size (out-degree).
    in_deg_max = max(deg for _, deg in self.graph.in_degree())
    out_deg_max = max(deg for _, deg in self.graph.out_degree())

    node_colors = [
        (deg + 10) / (in_deg_max + 10) for _, deg in self.graph.in_degree(self.nodes)
    ]
    node_sizes = [
        500 * (deg + 1) / (out_deg_max + 1) for _, deg in self.graph.out_degree(self.nodes)
    ]

    nx.draw_networkx_nodes(
        self.graph,
        pos=layout,
        ax=axis,
        cmap=color_map,
        vmin=-0.2,
        vmax=1,
        node_color=node_colors,
        node_size=node_sizes,
    )

    nx.draw_networkx_edges(
        self.graph,
        pos=layout,
        ax=axis,
        alpha=0.2,
        arrowsize=8,
        width=0.5,
        connectionstyle="arc3,rad=0.3",
    )

    # Cell name (plus optional description) above the layout center.
    axis.text(
        origin[0],
        origin[1] + 1.2,
        self.name + f"\n{meta_desc}",
        horizontalalignment="center",
        fontsize=12,
    )

    axis.axis("off")

ProductionLineGraph

Blueprint of a Production Line Graph.

A Production Line consists of multiple Cells, each Cell can contain multiple modules. Modules can be instantiated randomly or manually. Cellgraphs and linegraphs can be instantiated directly from nx.DiGraph objects. Similarly, edges can be drawn at random (obeying certain probability choices that can be set by the user) between cells/modules or manually.

Besides populating a production line with cell/module-graphs one can obtain semi-synthetic data obeying the standard causal assumptions:

1. Markov Property
2. Faithfulness

This can be achieved by fitting distributional random forests to the line/cell-graphs and draw data from these. A random number stream is initiated when calling this class. If desired this can be overwritten manually.

Source code in causalAssembly/models_dag.py
 772
 773
 774
 775
 776
 777
 778
 779
 780
 781
 782
 783
 784
 785
 786
 787
 788
 789
 790
 791
 792
 793
 794
 795
 796
 797
 798
 799
 800
 801
 802
 803
 804
 805
 806
 807
 808
 809
 810
 811
 812
 813
 814
 815
 816
 817
 818
 819
 820
 821
 822
 823
 824
 825
 826
 827
 828
 829
 830
 831
 832
 833
 834
 835
 836
 837
 838
 839
 840
 841
 842
 843
 844
 845
 846
 847
 848
 849
 850
 851
 852
 853
 854
 855
 856
 857
 858
 859
 860
 861
 862
 863
 864
 865
 866
 867
 868
 869
 870
 871
 872
 873
 874
 875
 876
 877
 878
 879
 880
 881
 882
 883
 884
 885
 886
 887
 888
 889
 890
 891
 892
 893
 894
 895
 896
 897
 898
 899
 900
 901
 902
 903
 904
 905
 906
 907
 908
 909
 910
 911
 912
 913
 914
 915
 916
 917
 918
 919
 920
 921
 922
 923
 924
 925
 926
 927
 928
 929
 930
 931
 932
 933
 934
 935
 936
 937
 938
 939
 940
 941
 942
 943
 944
 945
 946
 947
 948
 949
 950
 951
 952
 953
 954
 955
 956
 957
 958
 959
 960
 961
 962
 963
 964
 965
 966
 967
 968
 969
 970
 971
 972
 973
 974
 975
 976
 977
 978
 979
 980
 981
 982
 983
 984
 985
 986
 987
 988
 989
 990
 991
 992
 993
 994
 995
 996
 997
 998
 999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
1442
1443
1444
1445
1446
1447
1448
1449
1450
1451
1452
1453
1454
1455
1456
1457
1458
1459
1460
1461
1462
1463
1464
1465
1466
1467
1468
1469
1470
1471
1472
1473
1474
1475
1476
1477
1478
1479
1480
1481
1482
1483
1484
1485
1486
1487
1488
1489
1490
1491
1492
1493
1494
1495
1496
1497
1498
1499
1500
1501
1502
1503
1504
1505
1506
1507
1508
1509
1510
1511
1512
1513
1514
1515
1516
1517
1518
1519
1520
1521
1522
1523
1524
1525
1526
1527
1528
1529
1530
1531
1532
1533
1534
1535
1536
1537
1538
1539
1540
1541
1542
1543
1544
1545
1546
1547
1548
1549
1550
1551
1552
1553
1554
1555
1556
1557
1558
1559
1560
1561
1562
1563
class ProductionLineGraph:
    """Blueprint of a Production Line Graph.

    A Production Line consists of multiple Cells, each Cell can contain multiple modules.
    Modules can be instantiated randomly or manually. Cellgraphs and linegraphs can be
    instantiated directly from nx.DiGraph objects. Similarly, edges can be drawn at random
    (obeying certain probability choices that can be set by the user) between cells/moduls
    or manually.

    Besides populating a production line with cell/module-graphs one can obtain
    semi-synthetic data obeying the standard causal assumptions:

        1. Markov Property
        2. Faithfulness

    This can be achieved by fitting distributional random forests to the line/cell-graphs
    and draw data from these. A random number stream is initiated when calling this class.
    If desired this can be overwritten manually.

    """

    def __init__(self):
        self._random_state = np.random.default_rng(seed=2023)
        self.cells: dict[str, ProcessCell] = dict()
        self.cell_prefix = "C"
        self.cell_connectors: list[tuple] = list()
        self.cell_connector_edges = list()
        self.cell_order = list()
        self.drf: dict = dict()
        self.interventional_drf: dict = dict()
        self.__init_mutilated_dag()

    @property
    def random_state(self):
        return self._random_state

    @random_state.setter
    def random_state(self, r: np.random.Generator):
        if not isinstance(r, np.random.Generator):
            raise AssertionError("Specify numpy random number generator object!")
        self._random_state = r

    def __init_mutilated_dag(self):
        self.mutilated_dags = dict()

    @property
    def graph(self) -> nx.DiGraph:
        """
        Returns a nx.DiGraph object of the actual graph.

        The graph is only built HERE, i.e. all ProcessCells exist standalone in self.cells,
        with no connections between their nodes yet.

        The edges are stored in self.cell_connetor_edges, where they are added by random methods
        or by user (the dag builder) himself.

        ATTENTION: you can not work on self.graph and add manually edges, nodes and expect them to
        work.

        Returns nx.DiGraph

        -------

        """
        g = nx.DiGraph()
        for cell in self.cells.values():
            g = nx.compose(g, cell.graph)

        g.add_edges_from(self.cell_connector_edges)

        if not nx.is_directed_acyclic_graph(g):
            raise TypeError(
                "There are cycles in the graph, \
                this is not supposed to happen."
            )

        return g

    @property
    def nodes(self) -> list[str]:
        """Nodes in the graph.

        Returns:
            list[str]
        """
        return list(self.graph.nodes())

    @property
    def edges(self) -> list[tuple]:
        """Edges in the graph.

        Returns:
            list[tuple]
        """
        return list(self.graph.edges())

    @property
    def num_nodes(self) -> int:
        """Number of nodes in the graph

        Returns:
            int
        """
        return len(self.nodes)

    @property
    def num_edges(self) -> int:
        """Number of edges in the graph

        Returns:
            int
        """
        return len(self.edges)

    @property
    def sparsity(self) -> float:
        """Sparsity of the graph

        Returns:
            float: in [0,1]
        """
        s = self.num_nodes
        return self.num_edges / s / (s - 1) * 2

    @property
    def ground_truth(self) -> pd.DataFrame:
        """Returns the current ground truth as
        pandas adjacency.

        Returns:
            pd.DataFrame: Adjacenccy matrix.
        """
        return nx.to_pandas_adjacency(self.graph, weight=None)

    def _get_union_graph(self) -> nx.DiGraph:
        if not self.cells:
            raise AssertionError("Your pline has no cells. Within has no meaning.")
        union_graph = nx.DiGraph()
        for _, station_graph in self.cells.items():
            union_graph = nx.union(union_graph, station_graph.graph)
        return union_graph

    @property
    def within_adjacency(self) -> pd.DataFrame:
        """Returns adjacency matrix ignoring all
        between-cell edges.

        Returns:
            pd.DataFrame: adjacency matrix
        """
        union_graph = self._get_union_graph()
        return nx.to_pandas_adjacency(union_graph)

    @property
    def between_adjacency(self) -> pd.DataFrame:
        """Returns adjacency matrix ignoring all
        within-cell edges.

        Returns:
            pd.DataFrame: adjacency matrix
        """
        union_graph = self._get_union_graph()
        return nx.to_pandas_adjacency(nx.difference(self.graph, union_graph))

    @property
    def causal_order(self) -> list[str]:
        """Returns the causal order of the current graph.
        Note that this order is in general not unique.

        Returns:
            list[str]: Causal order
        """
        return list(nx.topological_sort(self.graph))

    def parents(self, of_node: str) -> list[str]:
        """Return parents of node in question.

        Args:
            of_node (str): Node in question.

        Returns:
            list[str]: parent set.
        """
        return list(self.graph.predecessors(of_node))

    def to_cpdag(self) -> PDAG:
        return dag2cpdag(dag=self.graph)

    def get_nodes_of_station(self, station_name: str) -> list:
        """Returns nodes in chosen Station.

        Args:
            station_name (str): name of station.

        Raises:
            AssertionError: if station name doesn't match pline.

        Returns:
            list: nodes in chosen station
        """
        if station_name not in self.cell_order:
            raise AssertionError("Station name not among cells.")

        return self.cells[station_name].nodes

    def __add_cell(self, cell: ProcessCell) -> ProcessCell:
        cell_names = [cell_name for cell_name in self.cells.keys()]

        if cell.is_eol and any([cell.is_eol for cell in self.cells.values()]):
            raise AssertionError(
                f"Cell: {[cell for cell in self.cells.values() if cell.is_eol]} "
                f"is already EOL Cell in ProductionLineGraph."
            )

        if cell.name not in cell_names:
            self.cells[cell.name] = cell

            return cell

        raise ValueError(f"A cell with name: {cell.name} is already in the Production Line.")

    def new_cell(self, name: str = None, is_eol: bool = False) -> ProcessCell:
        """Add a new cell to the production line.

        If no name is given, cell name is given by counting available cells + 1

        Args:
            name (str, optional): Defaults to None.
            is_eol (bool, optional): Whether cell is end-of-line. Defaults to False.

        Returns:
            ProcessCell
        """
        if name:
            c = ProcessCell(name=name)

        else:
            actual_no_of_cells = len(self.cells.values())
            c = ProcessCell(name=f"{self.cell_prefix}{actual_no_of_cells}")

        c.random_state = self.random_state

        c.is_eol = is_eol
        self.__add_cell(cell=c)
        self.cell_order.append(c.name)
        return c

    def connect_cells(
        self,
        forward_probs: list[float] = [0.1, 0.05],
    ):
        """Randomly connects cells in a ProductionLineGraph according to a forwards logic.

        Args:
            forward_probs (list[float], optional): Array of sparsity scalars of
                dimension max_forward. Defaults to [0.1, 0.05].
        """
        # assume cells are initiated in order
        # otherwise allow to change order

        max_forward = len(forward_probs)
        cells = list(self.cells.values())
        no_of_cells = len(self.cells)

        for cell_idx, cell in enumerate(cells):
            prob_it = 0  # prob it(erator)

            for forwards in range(1, max_forward + 1):
                forward_cell_idx = cell_idx + forwards

                if forward_cell_idx < no_of_cells:
                    forward_cell = cells[forward_cell_idx]
                    chosen_edges = choose_edges_from_cells_randomly(
                        from_cell=cell,
                        to_cell=forward_cell,
                        probability=forward_probs[prob_it],
                        rng=self.random_state,
                    )

                    prob_it += 1  # FIXME: a bit ugly and hard to read
                    self.cell_connector_edges.extend(chosen_edges)

            if eol_cell := self.eol_cell:
                eol_cell_idx = cells.index(eol_cell)

                if cell_idx + max_forward < eol_cell_idx:
                    eol_prob = forward_probs[-1]
                    chosen_eol_edges = choose_edges_from_cells_randomly(
                        from_cell=cell,
                        to_cell=eol_cell,
                        probability=eol_prob,
                        rng=self.random_state,
                    )

                    self.cell_connector_edges.extend(chosen_eol_edges)

    def copy(self) -> ProductionLineGraph:
        """Makes a full copy of the current
        ProductionLineGraph object

        Returns:
            ProductionLineGraph: copyied object.
        """
        copy_graph = ProductionLineGraph()
        for station in self.cell_order:
            copy_graph.new_cell(station)
            # make sure its sorted
            sorted_graph = nx.DiGraph()
            sorted_graph.add_nodes_from(
                sorted(self.cells[station].nodes, key=lambda x: int(x.rpartition("_")[2]))
            )
            sorted_graph.add_edges_from(self.cells[station].edges)
            copy_graph.cells[station].graph = sorted_graph

        between_cell_edges = nx.difference(self.graph, copy_graph.graph).edges()
        copy_graph.connect_across_cells_manually(edges=between_cell_edges)
        return copy_graph

    def connect_across_cells_manually(self, edges: list[tuple]):
        """Add edges manually across cells. You need to give the full name
        Args:
            edges (list[tuple]): list of edges to add
        """
        self.cell_connector_edges.extend(edges)

    def intervene_on(self, nodes_values: dict[str, RandomSymbol | float]):
        """Specify hard or soft intervention. If you want to intervene
        upon more than one node provide a list of nodes to intervene on
        and a list of corresponding values to set these nodes to.
        (see example). The mutilated dag will automatically be
        stored in `mutiliated_dags`.

        Args:
            nodes_values (dict[str, RandomSymbol | float]): either single real
                number or sympy.stats.RandomSymbol. If you like to intervene on
                more than one node, just provide more key-value pairs.

        Raises:
            AssertionError: If node(s) are not in the graph
        """
        if not self.drf:
            raise AssertionError("You need to train a drf first.")
        drf_replace = {}

        if not set(nodes_values.keys()).issubset(set(self.nodes)):
            raise AssertionError(
                "One or more nodes you want to intervene upon are not in the graph."
            )

        mutilated_dag = self.graph.copy()

        for node, value in nodes_values.items():
            old_incoming = self.parents(of_node=node)
            edges_to_remove = [(old, node) for old in old_incoming]
            mutilated_dag.remove_edges_from(edges_to_remove)
            drf_replace[node] = value

        self.mutilated_dags[
            f"do({list(nodes_values.keys())})"
        ] = mutilated_dag  # specifiying the same set twice will override

        self.interventional_drf[f"do({list(nodes_values.keys())})"] = drf_replace

    @property
    def interventions(self) -> list:
        """Returns all interventions performed on the original graph

        Returns:
            list: list of intervened upon nodes in do(x) notation.
        """
        return list(self.mutilated_dags.keys())

    def interventional_amat(self, which_intervention: int | str) -> pd.DataFrame:
        """Returns the adjacency matrix of a chosen mutilated DAG.

        Args:
            which_intervention (int | str): Integer count of your chosen intervention or
                literal string.

        Raises:
            ValueError: "The intervention you provide does not exist."

        Returns:
            pd.DataFrame: Adjacency matrix.
        """
        if isinstance(which_intervention, str) and which_intervention not in self.interventions:
            raise ValueError("The intervention you provide does not exist.")

        if isinstance(which_intervention, int) and which_intervention > len(self.interventions):
            raise ValueError("The intervention you index does not exist.")

        if isinstance(which_intervention, int):
            which_intervention = self.interventions[which_intervention]

        mutilated_dag = self.mutilated_dags[which_intervention].copy()
        return nx.to_pandas_adjacency(mutilated_dag, weight=None)

    @classmethod
    def get_ground_truth(cls) -> ProductionLineGraph:
        """Loads in the ground_truth as described in the paper:
        causalAssembly: Generating Realistic Production Data for
        Benchmarking Causal Discovery
        Returns:
            ProductionLineGraph: ground_truth for cells and line.
        """

        gt_response = requests.get(DATA_GROUND_TRUTH, timeout=5)
        ground_truth = json.loads(gt_response.text)

        assembly_line = json_graph.adjacency_graph(ground_truth)

        stations = ["Station1", "Station2", "Station3", "Station4", "Station5"]
        ground_truth_line = ProductionLineGraph()

        for station in stations:
            ground_truth_line.new_cell(station)
            station_nodes = [node for node in assembly_line.nodes if node.startswith(station)]
            station_subgraph = nx.subgraph(assembly_line, station_nodes)
            # make sure its sorted
            sorted_graph = nx.DiGraph()
            sorted_graph.add_nodes_from(
                sorted(station_nodes, key=lambda x: int(x.rpartition("_")[2]))
            )
            sorted_graph.add_edges_from(station_subgraph.edges)
            ground_truth_line.cells[station].graph = sorted_graph

        between_cell_edges = nx.difference(assembly_line, ground_truth_line.graph).edges()
        ground_truth_line.connect_across_cells_manually(edges=between_cell_edges)
        return ground_truth_line

    @classmethod
    def get_data(cls) -> pd.DataFrame:
        """Load in semi-synthetic data as described in the paper:
        causalAssembly: Generating Realistic Production Data for
        Benchmarking Causal Discovery
        Returns:
            pd.DataFrame: Data from which data should be generated.
        """
        return pd.read_csv(DATA_DATASET)

    @classmethod
    def from_nx(cls, g: nx.DiGraph, cell_mapper: dict[str, list]):
        """Convert nx.DiGraph to ProductionLineGraph. Requires a dict mapping
        where keys are cell names and values correspond to nodes within these cells.

        Args:
            g (nx.DiGraph): graph to be converted
            cell_mapper (dict[str, list]): dict to indicate what nodes belong to which cell

        Returns:
            ProductionLineGraph (ProductionLineGraph): the graph as a ProductionLineGraph object.
        """
        if not isinstance(g, nx.DiGraph):
            raise ValueError("Graph must be of type nx.DiGraph")
        pline = ProductionLineGraph()
        if cell_mapper:
            for cellname, cols in cell_mapper.items():
                pline.new_cell(name=cellname)
                cell_graph = nx.induced_subgraph(g, cols)
                pline.cells[cellname].input_cellgraph_directly(cell_graph, allow_in_edges=True)
        relabel_dict = {}
        for cellname, cols in cell_mapper.items():
            for col in cols:
                relabel_dict[col] = cellname + "_" + col

        g_rename = nx.relabel_nodes(g, relabel_dict)
        between_cell_edges = nx.difference(g_rename, pline.graph).edges()
        pline.connect_across_cells_manually(edges=between_cell_edges)
        return pline

    @classmethod
    def load_drf(cls, filename: str, location: str = None):
        """Loads a drf dict from a .pkl file into the workspace.

        Args:
            filename (str): name of the file e.g. examplefile.pkl
            location (str, optional): path to file in case it's not located
                in the current working directory. Defaults to None.

        Returns:
            DRF (dict): dict of trained drf objects
        """
        if not location:
            location = Path().resolve()

        location_path = Path(location, filename)

        with open(location_path, "rb") as drf:
            pickle_drf = pickle.load(drf)

        return pickle_drf

    @classmethod
    def load_pline_from_pickle(cls, filename: str, location: str = None):
        if not location:
            location = Path().resolve()

        location_path = Path(location, filename)

        with open(location_path, "rb") as pline:
            pickle_line = pickle.load(pline)

        if not isinstance(pickle_line, ProductionLineGraph):
            raise TypeError("You didn't refer to a ProductionLineGraph.")

        return pickle_line

    def save_drf(self, filename: str, location: str = None):
        """Writes a drf dict to file. Please provide the .pkl suffix!

        Args:
            filename (str): name of the file to be written e.g. examplefile.pkl
            location (str, optional): path to file in case it's not located in
                the current working directory. Defaults to None.
        """

        if not location:
            location = Path().resolve()

        location_path = Path(location, filename)

        with open(location_path, "wb") as f:
            pickle.dump(self.drf, f)

    def sample_from_drf(self, size=10, smoothed: bool = True) -> pd.DataFrame:
        """Draw from the trained DRF.

        Args:
            size (int, optional): Number of samples to be drawn. Defaults to 10.
            smoothed (bool, optional): If set to true, marginal distributions will
                be sampled from smoothed bootstraps. Defaults to True.

        Returns:
            pd.DataFrame: Data frame that follows the distribution implied by the ground truth.
        """
        return _sample_from_drf(prod_object=self, size=size, smoothed=smoothed)

    def sample_from_interventional_drf(
        self, which_intervention: str | int = 0, size=10, smoothed: bool = True
    ) -> pd.DataFrame:
        """Draw from the trained and intervened upon DRF.

        Args:
            size (int, optional): Number of samples to be drawn. Defaults to 10.
            which_intervention (str | int): Which intervention to choose from.
                Both the literal name (see the property `interventions`) and the index
                are possible. Defaults to the first intervention.
            smoothed (bool, optional): If set to true, marginal distributions will
                be sampled from smoothed bootstraps. Defaults to True.

        Returns:
            pd.DataFrame: Data frame that follows the interventional distribution
                implied by the ground truth.
        """
        return _interventional_sample_from_drf(
            prod_object=self, which_intervention=which_intervention, size=size, smoothed=smoothed
        )

    def hidden_nodes(self) -> list:
        """Returns list of nodes marked as hidden

        Returns:
            list: of hidden nodes
        """
        return [
            node
            for node, hidden in nx.get_node_attributes(self.graph, NodeAttributes.HIDDEN).items()
            if hidden is True
        ]

    def visible_nodes(self):
        return [node for node in self.nodes if node not in self.hidden_nodes()]

    @property
    def eol_cell(self) -> ProcessCell | None:
        """

        Returns ProcessCell: the EOL cell
            (if any single cell has attr .is_eol = True), otherwise returns None
        -------

        """
        for cell in self.cells.values():
            if cell.is_eol:
                return cell

    @property
    def ground_truth_visible(self) -> pd.DataFrame:
        """Generates a ground truth graph in form of a pandas adjacency matrix.
        Row and column names correspond to visible.
        The following integers can occur:

        amat[i,j] = 1 indicates i -> j
        amat[i,j] = 0 indicates no edge
        amat[i,j] = amat[j,i] = 2 indicates i <-> j and there exists a common hidden confounder

        Returns:
            pd.DataFrame: amat with values in {0,1,2}.
        """

        if len(self.hidden_nodes()) == 0:
            return self.ground_truth
        else:
            mediators = self._pairs_with_hidden_mediators()
            confounders = self._pairs_with_hidden_confounders()

            # here 1 indicates that ROWS has edge to COLUMNS!
            amat = nx.to_pandas_adjacency(self.graph)
            amat_visible = amat.loc[self.visible_nodes(), self.visible_nodes()]

            for pairs in mediators:
                amat_visible.loc[pairs] = 1

            # reverse = lambda tuples: tuples[::-1]

            def reverse(tuples):
                """
                Simple function to reverse tuple order

                Args:
                    tuples (tuple): tuple to reverse order

                Returns:
                    tuple: tuple in reversed order
                """
                new_tup = tuples[::-1]
                return new_tup

            for pair, _ in confounders.items():
                amat_visible.loc[pair] = 2
                amat_visible.loc[reverse(pair)] = 2

            return amat_visible

    def show(self, meta_description: list | None = None, fig_size: tuple = (15, 8)):
        """Plot full assembly line

        Args:
            meta_description (list | None, optional): Specify additional cell info.
                Defaults to None.
            fig_size (tuple, optional): Adjust depending on number of cells.
                Defaults to (15, 8).

        Raises:
            AssertionError: Meta list entry needs to exist for each cell!
        """
        _, ax = plt.subplots(figsize=fig_size)

        pos = {}

        if meta_description is None:
            meta_description = ["" for _ in range(len(self.cells))]

        if len(meta_description) != len(self.cells):
            raise AssertionError("Meta list entry needs to exist for each cell!")

        # default=0 guards against an empty graph; a bare max() over an empty
        # degree view would raise ValueError before anything is drawn
        max_in_degree = max((d for _, d in self.graph.in_degree()), default=0)
        max_out_degree = max((d for _, d in self.graph.out_degree()), default=0)

        for i, (station_name, meta_desc) in enumerate(zip(self.cell_order, meta_description)):
            pos.update(
                self.cells[station_name]._plot_cellgraph(
                    ax=ax,
                    with_edges=False,
                    with_box=True,
                    meta_desc=meta_desc,
                    # cells are laid out left-to-right, 8 units apart
                    center=np.array([8 * i, 0]),
                    # shade nodes by (shifted) in-degree relative to the line maximum
                    node_color=[
                        (d + 10) / (max_in_degree + 10)
                        for _, d in self.graph.in_degree(self.get_nodes_of_station(station_name))
                    ],
                    # scale node size by out-degree relative to the line maximum
                    node_size=[
                        500 * (d + 1) / (max_out_degree + 1)
                        for _, d in self.graph.out_degree(self.get_nodes_of_station(station_name))
                    ],
                )
            )

        # edges are drawn in one pass across all cells, using the positions
        # accumulated from the per-cell plots above
        nx.draw_networkx_edges(
            self.graph,
            pos=pos,
            ax=ax,
            alpha=0.2,
            arrowsize=8,
            width=0.5,
            connectionstyle="arc3,rad=0.3",
        )

    def __str__(self):
        """Human-readable summary: a header followed by one line per cell name."""
        lines = ["ProductionLine", ""]
        lines.extend(f"{cell}" for cell in self.cells)
        return "\n".join(lines) + "\n"

    def __getattr__(self, attrname):
        """Fallback attribute lookup: resolve cell names as attributes.

        Looks `cells` up via ``self.__dict__`` instead of ``self.cells`` because
        ``__getattr__`` is itself the fallback: if the instance has no ``cells``
        yet (e.g. while unpickling or copying, before ``__setstate__`` has run),
        accessing ``self.cells`` here would recurse infinitely.

        Args:
            attrname (str): attribute (cell) name being looked up.

        Raises:
            AttributeError: if attrname is not a known cell name.
        """
        cells = self.__dict__.get("cells", {})
        if attrname not in cells:
            raise AttributeError(f"{attrname} is not a valid attribute (cell name?)")
        return cells[attrname]

    # https://docs.python.org/3/library/pickle.html#pickle-protocol
    # TODO why is .cells enough, are the other member vars directly pickable?
    def __getstate__(self):
        """Return the picklable state: the instance dict together with the cells mapping."""
        state = (self.__dict__, self.cells)
        return state

    def __setstate__(self, state):
        """Restore state produced by ``__getstate__``."""
        instance_dict, cell_mapping = state
        self.__dict__ = instance_dict
        self.cells = cell_mapping

    @classmethod
    def via_cell_number(cls, n_cells: int, cell_prefix: str = "C"):
        """Inits a ProductionLineGraph with predefined number of cells, e.g. n_cells = 3

        Will create empty  C0, C1 and C2 as cells if no other cell_prefix is given.

        Args:
            n_cells (int): Number of cells the graph will have
            cell_prefix (str, optional): If you like other cell names pass them here.
                Defaults to "C".

        Returns:
            ProductionLineGraph: line initialized with n_cells empty cells.
        """
        pl = cls()
        pl.cell_prefix = cell_prefix

        # plain loop: new_cell() is called purely for its side effect, so
        # building a throwaway list via a comprehension was misleading
        for _ in range(n_cells):
            pl.new_cell()

        return pl

    def _pairs_with_hidden_mediators(self):
        """
        Return pairs of visible nodes connected by at least one directed path
        whose intermediate nodes are all hidden mediators (hidden nodes that
        do not already act as confounders).

        Returns:
            list: list of tuples with pairs of nodes with hidden mediator
        """
        visible = self.visible_nodes()
        hidden_all = self.hidden_nodes()
        # NOTE: .pop() deliberately mirrors the original behavior of taking one
        # confounder per pair (and mutates the returned containers).
        confounders = [node.pop() for _, node in self._pairs_with_hidden_confounders().items()]
        # hidden nodes acting purely as mediators; set for O(1) membership tests
        mediators = {node for node in hidden_all if node not in confounders}

        any_paths = []
        for source in visible:
            for target in visible:
                for path in sorted(nx.all_simple_paths(self.graph, source, target)):
                    any_paths.append(path)

        # keep only paths with a non-empty interior (len > 2) that lies
        # entirely within the mediator set
        return [
            (path[0], path[-1])
            for path in any_paths
            if len(path) > 2 and all(step in mediators for step in path[1:-1])
        ]

    def _pairs_with_hidden_confounders(self) -> dict:
        """
        Returns node-pairs that have a common hidden confounder

        Returns:
            dict: Dictionary with keys equal to tuples of node-pairs
            and values their corresponding hidden confounder(s)
        """
        confounder_pairs = {}
        visible = self.visible_nodes()
        visible_set = set(visible)  # O(1) membership tests below

        for node1, node2 in itertools.combinations(visible, 2):
            ancestors1 = nx.ancestors(self.graph, node1)
            ancestors2 = nx.ancestors(self.graph, node2)
            # Set-disjointness replaces the previous numpy isin/concatenate
            # construction: True iff no ancestor of either node is visible
            # (vacuously True when an ancestor set is empty, as before).
            if not (ancestors1 & visible_set) and not (ancestors2 & visible_set):
                confounder = ancestors1.intersection(ancestors2)
                if confounder:  # only populate if set is non-empty
                    confounder_pairs[(node1, node2)] = confounder
            else:
                # some ancestor is visible: fall back to shared *direct*
                # hidden parents only
                direct_parents1 = set(self.graph.predecessors(node1))
                direct_parents2 = set(self.graph.predecessors(node2))
                direct_confounder = [
                    node
                    for node in direct_parents1 & direct_parents2
                    if node not in visible_set
                ]
                if direct_confounder:
                    confounder_pairs[(node1, node2)] = direct_confounder

        return confounder_pairs

between_adjacency: pd.DataFrame property

Returns adjacency matrix ignoring all within-cell edges.

Returns:

Type Description
DataFrame

pd.DataFrame: adjacency matrix

causal_order: list[str] property

Returns the causal order of the current graph. Note that this order is in general not unique.

Returns:

Type Description
list[str]

list[str]: Causal order

edges: list[tuple] property

Edges in the graph.

Returns:

Type Description
list[tuple]

list[tuple]

eol_cell: ProcessCell | None property

the EOL cell

(if any single cell has attr .is_eol = True), otherwise returns None


graph: nx.DiGraph property

Returns a nx.DiGraph object of the actual graph.

The graph is only built HERE, i.e. all ProcessCells exist standalone in self.cells, with no connections between their nodes yet.

The edges are stored in self.cell_connector_edges, where they are added by random methods or by the user (the DAG builder).

ATTENTION: you cannot manually add edges or nodes to self.graph and expect them to persist.

Returns nx.DiGraph


ground_truth: pd.DataFrame property

Returns the current ground truth as pandas adjacency.

Returns:

Type Description
DataFrame

pd.DataFrame: Adjacency matrix.

ground_truth_visible: pd.DataFrame property

Generates a ground truth graph in form of a pandas adjacency matrix. Row and column names correspond to visible. The following integers can occur:

amat[i,j] = 1 indicates i -> j amat[i,j] = 0 indicates no edge amat[i,j] = amat[j,i] = 2 indicates i <-> j and there exists a common hidden confounder

Returns:

Type Description
DataFrame

pd.DataFrame: amat with values in {0,1,2}.

interventions: list property

Returns all interventions performed on the original graph

Returns:

Name Type Description
list list

list of intervened upon nodes in do(x) notation.

nodes: list[str] property

Nodes in the graph.

Returns:

Type Description
list[str]

list[str]

num_edges: int property

Number of edges in the graph

Returns:

Type Description
int

int

num_nodes: int property

Number of nodes in the graph

Returns:

Type Description
int

int

sparsity: float property

Sparsity of the graph

Returns:

Name Type Description
float float

in [0,1]

within_adjacency: pd.DataFrame property

Returns adjacency matrix ignoring all between-cell edges.

Returns:

Type Description
DataFrame

pd.DataFrame: adjacency matrix

connect_across_cells_manually(edges)

Add edges manually across cells. You need to give the full name Args: edges (list[tuple]): list of edges to add

Source code in causalAssembly/models_dag.py
1090
1091
1092
1093
1094
1095
def connect_across_cells_manually(self, edges: list[tuple]):
    """Add edges manually across cells. You need to give the full name
    Args:
        edges (list[tuple]): list of edges to add
    """
    self.cell_connector_edges.extend(edges)

connect_cells(forward_probs=[0.1, 0.05])

Randomly connects cells in a ProductionLineGraph according to a forwards logic.

Parameters:

Name Type Description Default
forward_probs list[float]

Array of sparsity scalars of dimension max_forward. Defaults to [0.1, 0.05].

[0.1, 0.05]
Source code in causalAssembly/models_dag.py
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
def connect_cells(
    self,
    forward_probs: list[float] = [0.1, 0.05],
):
    """Randomly connects cells in a ProductionLineGraph according to a forwards logic.

    Args:
        forward_probs (list[float], optional): Array of sparsity scalars of
            dimension max_forward. Defaults to [0.1, 0.05].
    """
    # assume cells are initiated in order
    # otherwise allow to change order

    max_forward = len(forward_probs)
    cells = list(self.cells.values())
    no_of_cells = len(self.cells)

    for cell_idx, cell in enumerate(cells):
        prob_it = 0  # prob it(erator)

        for forwards in range(1, max_forward + 1):
            forward_cell_idx = cell_idx + forwards

            if forward_cell_idx < no_of_cells:
                forward_cell = cells[forward_cell_idx]
                chosen_edges = choose_edges_from_cells_randomly(
                    from_cell=cell,
                    to_cell=forward_cell,
                    probability=forward_probs[prob_it],
                    rng=self.random_state,
                )

                prob_it += 1  # FIXME: a bit ugly and hard to read
                self.cell_connector_edges.extend(chosen_edges)

        if eol_cell := self.eol_cell:
            eol_cell_idx = cells.index(eol_cell)

            if cell_idx + max_forward < eol_cell_idx:
                eol_prob = forward_probs[-1]
                chosen_eol_edges = choose_edges_from_cells_randomly(
                    from_cell=cell,
                    to_cell=eol_cell,
                    probability=eol_prob,
                    rng=self.random_state,
                )

                self.cell_connector_edges.extend(chosen_eol_edges)

copy()

Makes a full copy of the current ProductionLineGraph object

Returns:

Name Type Description
ProductionLineGraph ProductionLineGraph

copied object.

Source code in causalAssembly/models_dag.py
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
def copy(self) -> ProductionLineGraph:
    """Makes a full copy of the current
    ProductionLineGraph object

    Returns:
        ProductionLineGraph: copyied object.
    """
    copy_graph = ProductionLineGraph()
    for station in self.cell_order:
        copy_graph.new_cell(station)
        # make sure its sorted
        sorted_graph = nx.DiGraph()
        sorted_graph.add_nodes_from(
            sorted(self.cells[station].nodes, key=lambda x: int(x.rpartition("_")[2]))
        )
        sorted_graph.add_edges_from(self.cells[station].edges)
        copy_graph.cells[station].graph = sorted_graph

    between_cell_edges = nx.difference(self.graph, copy_graph.graph).edges()
    copy_graph.connect_across_cells_manually(edges=between_cell_edges)
    return copy_graph

from_nx(g, cell_mapper) classmethod

Convert nx.DiGraph to ProductionLineGraph. Requires a dict mapping where keys are cell names and values correspond to nodes within these cells.

Parameters:

Name Type Description Default
g DiGraph

graph to be converted

required
cell_mapper dict[str, list]

dict to indicate what nodes belong to which cell

required

Returns:

Name Type Description
ProductionLineGraph ProductionLineGraph

the graph as a ProductionLineGraph object.

Source code in causalAssembly/models_dag.py
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
@classmethod
def from_nx(cls, g: nx.DiGraph, cell_mapper: dict[str, list]):
    """Convert nx.DiGraph to ProductionLineGraph. Requires a dict mapping
    where keys are cell names and values correspond to nodes within these cells.

    Args:
        g (nx.DiGraph): graph to be converted
        cell_mapper (dict[str, list]): dict to indicate what nodes belong to which cell

    Returns:
        ProductionLineGraph (ProductionLineGraph): the graph as a ProductionLineGraph object.
    """
    if not isinstance(g, nx.DiGraph):
        raise ValueError("Graph must be of type nx.DiGraph")
    pline = ProductionLineGraph()
    if cell_mapper:
        for cellname, cols in cell_mapper.items():
            pline.new_cell(name=cellname)
            cell_graph = nx.induced_subgraph(g, cols)
            pline.cells[cellname].input_cellgraph_directly(cell_graph, allow_in_edges=True)
    relabel_dict = {}
    for cellname, cols in cell_mapper.items():
        for col in cols:
            relabel_dict[col] = cellname + "_" + col

    g_rename = nx.relabel_nodes(g, relabel_dict)
    between_cell_edges = nx.difference(g_rename, pline.graph).edges()
    pline.connect_across_cells_manually(edges=between_cell_edges)
    return pline

get_data() classmethod

Load in semi-synthetic data as described in the paper: causalAssembly: Generating Realistic Production Data for Benchmarking Causal Discovery Returns: pd.DataFrame: Data from which data should be generated.

Source code in causalAssembly/models_dag.py
1202
1203
1204
1205
1206
1207
1208
1209
1210
@classmethod
def get_data(cls) -> pd.DataFrame:
    """Load in semi-synthetic data as described in the paper:
    causalAssembly: Generating Realistic Production Data for
    Benchmarking Causal Discovery
    Returns:
        pd.DataFrame: Data from which data should be generated.
    """
    return pd.read_csv(DATA_DATASET)

get_ground_truth() classmethod

Loads in the ground_truth as described in the paper: causalAssembly: Generating Realistic Production Data for Benchmarking Causal Discovery Returns: ProductionLineGraph: ground_truth for cells and line.

Source code in causalAssembly/models_dag.py
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
@classmethod
def get_ground_truth(cls) -> ProductionLineGraph:
    """Loads in the ground_truth as described in the paper:
    causalAssembly: Generating Realistic Production Data for
    Benchmarking Causal Discovery
    Returns:
        ProductionLineGraph: ground_truth for cells and line.
    """

    gt_response = requests.get(DATA_GROUND_TRUTH, timeout=5)
    ground_truth = json.loads(gt_response.text)

    assembly_line = json_graph.adjacency_graph(ground_truth)

    stations = ["Station1", "Station2", "Station3", "Station4", "Station5"]
    ground_truth_line = ProductionLineGraph()

    for station in stations:
        ground_truth_line.new_cell(station)
        station_nodes = [node for node in assembly_line.nodes if node.startswith(station)]
        station_subgraph = nx.subgraph(assembly_line, station_nodes)
        # make sure its sorted
        sorted_graph = nx.DiGraph()
        sorted_graph.add_nodes_from(
            sorted(station_nodes, key=lambda x: int(x.rpartition("_")[2]))
        )
        sorted_graph.add_edges_from(station_subgraph.edges)
        ground_truth_line.cells[station].graph = sorted_graph

    between_cell_edges = nx.difference(assembly_line, ground_truth_line.graph).edges()
    ground_truth_line.connect_across_cells_manually(edges=between_cell_edges)
    return ground_truth_line

get_nodes_of_station(station_name)

Returns nodes in chosen Station.

Parameters:

Name Type Description Default
station_name str

name of station.

required

Raises:

Type Description
AssertionError

if station name doesn't match pline.

Returns:

Name Type Description
list list

nodes in chosen station

Source code in causalAssembly/models_dag.py
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
def get_nodes_of_station(self, station_name: str) -> list:
    """Returns nodes in chosen Station.

    Args:
        station_name (str): name of station.

    Raises:
        AssertionError: if station name doesn't match pline.

    Returns:
        list: nodes in chosen station
    """
    if station_name not in self.cell_order:
        raise AssertionError("Station name not among cells.")

    return self.cells[station_name].nodes

hidden_nodes()

Returns list of nodes marked as hidden

Returns:

Name Type Description
list list

of hidden nodes

Source code in causalAssembly/models_dag.py
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
def hidden_nodes(self) -> list:
    """Returns list of nodes marked as hidden

    Returns:
        list: of hidden nodes
    """
    return [
        node
        for node, hidden in nx.get_node_attributes(self.graph, NodeAttributes.HIDDEN).items()
        if hidden is True
    ]

intervene_on(nodes_values)

Specify hard or soft intervention. If you want to intervene upon more than one node provide a list of nodes to intervene on and a list of corresponding values to set these nodes to. (see example). The mutilated dag will automatically be stored in mutilated_dags.

Parameters:

Name Type Description Default
nodes_values dict[str, RandomSymbol | float]

either single real number or sympy.stats.RandomSymbol. If you like to intervene on more than one node, just provide more key-value pairs.

required

Raises:

Type Description
AssertionError

If node(s) are not in the graph

Source code in causalAssembly/models_dag.py
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
def intervene_on(self, nodes_values: dict[str, RandomSymbol | float]):
    """Specify hard or soft intervention. If you want to intervene
    upon more than one node provide a list of nodes to intervene on
    and a list of corresponding values to set these nodes to.
    (see example). The mutilated dag will automatically be
    stored in `mutiliated_dags`.

    Args:
        nodes_values (dict[str, RandomSymbol | float]): either single real
            number or sympy.stats.RandomSymbol. If you like to intervene on
            more than one node, just provide more key-value pairs.

    Raises:
        AssertionError: If node(s) are not in the graph
    """
    if not self.drf:
        raise AssertionError("You need to train a drf first.")
    drf_replace = {}

    if not set(nodes_values.keys()).issubset(set(self.nodes)):
        raise AssertionError(
            "One or more nodes you want to intervene upon are not in the graph."
        )

    mutilated_dag = self.graph.copy()

    for node, value in nodes_values.items():
        old_incoming = self.parents(of_node=node)
        edges_to_remove = [(old, node) for old in old_incoming]
        mutilated_dag.remove_edges_from(edges_to_remove)
        drf_replace[node] = value

    self.mutilated_dags[
        f"do({list(nodes_values.keys())})"
    ] = mutilated_dag  # specifiying the same set twice will override

    self.interventional_drf[f"do({list(nodes_values.keys())})"] = drf_replace

interventional_amat(which_intervention)

Returns the adjacency matrix of a chosen mutilated DAG.

Parameters:

Name Type Description Default
which_intervention int | str

Integer count of your chosen intervention or literal string.

required

Raises:

Type Description
ValueError

"The intervention you provide does not exist."

Returns:

Type Description
DataFrame

pd.DataFrame: Adjacency matrix.

Source code in causalAssembly/models_dag.py
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
def interventional_amat(self, which_intervention: int | str) -> pd.DataFrame:
    """Returns the adjacency matrix of a chosen mutilated DAG.

    Args:
        which_intervention (int | str): Integer count of your chosen intervention or
            literal string.

    Raises:
        ValueError: "The intervention you provide does not exist."

    Returns:
        pd.DataFrame: Adjacency matrix.
    """
    if isinstance(which_intervention, str) and which_intervention not in self.interventions:
        raise ValueError("The intervention you provide does not exist.")

    if isinstance(which_intervention, int) and which_intervention > len(self.interventions):
        raise ValueError("The intervention you index does not exist.")

    if isinstance(which_intervention, int):
        which_intervention = self.interventions[which_intervention]

    mutilated_dag = self.mutilated_dags[which_intervention].copy()
    return nx.to_pandas_adjacency(mutilated_dag, weight=None)

load_drf(filename, location=None) classmethod

Loads a drf dict from a .pkl file into the workspace.

Parameters:

Name Type Description Default
filename str

name of the file e.g. examplefile.pkl

required
location str

path to file in case it's not located in the current working directory. Defaults to None.

None

Returns:

Name Type Description
DRF dict

dict of trained drf objects

Source code in causalAssembly/models_dag.py
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
@classmethod
def load_drf(cls, filename: str, location: str = None):
    """Loads a drf dict from a .pkl file into the workspace.

    Args:
        filename (str): name of the file e.g. examplefile.pkl
        location (str, optional): path to file in case it's not located
            in the current working directory. Defaults to None.

    Returns:
        DRF (dict): dict of trained drf objects
    """
    if not location:
        location = Path().resolve()

    location_path = Path(location, filename)

    with open(location_path, "rb") as drf:
        pickle_drf = pickle.load(drf)

    return pickle_drf

new_cell(name=None, is_eol=False)

Add a new cell to the production line.

If no name is given, cell name is given by counting available cells + 1

Parameters:

Name Type Description Default
name str

Defaults to None.

None
is_eol bool

Whether cell is end-of-line. Defaults to False.

False

Returns:

Type Description
ProcessCell

ProcessCell

Source code in causalAssembly/models_dag.py
 993
 994
 995
 996
 997
 998
 999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
def new_cell(self, name: str = None, is_eol: bool = False) -> ProcessCell:
    """Add a new cell to the production line.

    If no name is given, cell name is given by counting available cells + 1

    Args:
        name (str, optional): Defaults to None.
        is_eol (bool, optional): Whether cell is end-of-line. Defaults to False.

    Returns:
        ProcessCell
    """
    if name:
        c = ProcessCell(name=name)

    else:
        actual_no_of_cells = len(self.cells.values())
        c = ProcessCell(name=f"{self.cell_prefix}{actual_no_of_cells}")

    c.random_state = self.random_state

    c.is_eol = is_eol
    self.__add_cell(cell=c)
    self.cell_order.append(c.name)
    return c

parents(of_node)

Return parents of node in question.

Parameters:

Name Type Description Default
of_node str

Node in question.

required

Returns:

Type Description
list[str]

list[str]: parent set.

Source code in causalAssembly/models_dag.py
946
947
948
949
950
951
952
953
954
955
def parents(self, of_node: str) -> list[str]:
    """Return parents of node in question.

    Args:
        of_node (str): Node in question.

    Returns:
        list[str]: parent set.
    """
    return list(self.graph.predecessors(of_node))

sample_from_drf(size=10, smoothed=True)

Draw from the trained DRF.

Parameters:

Name Type Description Default
size int

Number of samples to be drawn. Defaults to 10.

10
smoothed bool

If set to true, marginal distributions will be sampled from smoothed bootstraps. Defaults to True.

True

Returns:

Type Description
DataFrame

pd.DataFrame: Data frame that follows the distribution implied by the ground truth.

Source code in causalAssembly/models_dag.py
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
def sample_from_drf(self, size=10, smoothed: bool = True) -> pd.DataFrame:
    """Draw from the trained DRF.

    Args:
        size (int, optional): Number of samples to be drawn. Defaults to 10.
        smoothed (bool, optional): If set to true, marginal distributions will
            be sampled from smoothed bootstraps. Defaults to True.

    Returns:
        pd.DataFrame: Data frame that follows the distribution implied by the ground truth.
    """
    return _sample_from_drf(prod_object=self, size=size, smoothed=smoothed)

sample_from_interventional_drf(which_intervention=0, size=10, smoothed=True)

Draw from the trained and intervened upon DRF.

Parameters:

Name Type Description Default
size int

Number of samples to be drawn. Defaults to 10.

10
which_intervention str | int

Which intervention to choose from. Both the literal name (see the property interventions) and the index are possible. Defaults to the first intervention.

0
smoothed bool

If set to true, marginal distributions will be sampled from smoothed bootstraps. Defaults to True.

True

Returns:

Type Description
DataFrame

pd.DataFrame: Data frame that follows the interventional distribution implied by the ground truth.

Source code in causalAssembly/models_dag.py
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
def sample_from_interventional_drf(
    self, which_intervention: str | int = 0, size=10, smoothed: bool = True
) -> pd.DataFrame:
    """Draw from the trained and intervened upon DRF.

    Args:
        size (int, optional): Number of samples to be drawn. Defaults to 10.
        which_intervention (str | int): Which intervention to choose from.
            Both the literal name (see the property `interventions`) and the index
            are possible. Defaults to the first intervention.
        smoothed (bool, optional): If set to true, marginal distributions will
            be sampled from smoothed bootstraps. Defaults to True.

    Returns:
        pd.DataFrame: Data frame that follows the interventional distribution
            implied by the ground truth.
    """
    return _interventional_sample_from_drf(
        prod_object=self, which_intervention=which_intervention, size=size, smoothed=smoothed
    )

save_drf(filename, location=None)

Writes a drf dict to file. Please provide the .pkl suffix!

Parameters:

Name Type Description Default
filename str

name of the file to be written e.g. examplefile.pkl

required
location str

path to file in case it's not located in the current working directory. Defaults to None.

None
Source code in causalAssembly/models_dag.py
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
def save_drf(self, filename: str, location: str = None):
    """Writes a drf dict to file. Please provide the .pkl suffix!

    Args:
        filename (str): name of the file to be written e.g. examplefile.pkl
        location (str, optional): path to file in case it's not located in
            the current working directory. Defaults to None.
    """

    if not location:
        location = Path().resolve()

    location_path = Path(location, filename)

    with open(location_path, "wb") as f:
        pickle.dump(self.drf, f)

show(meta_description=None, fig_size=(15, 8))

Plot full assembly line

Parameters:

Name Type Description Default
meta_description list | None

Specify additional cell info. Defaults to None.

None
fig_size tuple

Adjust depending on number of cells. Defaults to (15, 8).

(15, 8)

Raises:

Type Description
AssertionError

Meta list entry needs to exist for each cell!

Source code in causalAssembly/models_dag.py
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
1442
1443
1444
1445
1446
1447
1448
1449
1450
1451
1452
1453
1454
1455
1456
1457
1458
def show(self, meta_description: list | None = None, fig_size: tuple = (15, 8)):
    """Plot full assembly line

    Args:
        meta_description (list | None, optional): Specify additional cell info.
            Defaults to None.
        fig_size (tuple, optional): Adjust depending on number of cells.
            Defaults to (15, 8).

    Raises:
        AssertionError: Meta list entry needs to exist for each cell!
    """
    _, ax = plt.subplots(figsize=fig_size)

    pos = {}

    if meta_description is None:
        meta_description = ["" for _ in range(len(self.cells))]

    if len(meta_description) != len(self.cells):
        raise AssertionError("Meta list entry needs to exist for each cell!")

    max_in_degree = max([d for _, d in self.graph.in_degree()])
    max_out_degree = max([d for _, d in self.graph.out_degree()])

    for i, (station_name, meta_desc) in enumerate(zip(self.cell_order, meta_description)):
        pos.update(
            self.cells[station_name]._plot_cellgraph(
                ax=ax,
                with_edges=False,
                with_box=True,
                meta_desc=meta_desc,
                center=np.array([8 * i, 0]),
                node_color=[
                    (d + 10) / (max_in_degree + 10)
                    for _, d in self.graph.in_degree(self.get_nodes_of_station(station_name))
                ],
                node_size=[
                    500 * (d + 1) / (max_out_degree + 1)
                    for _, d in self.graph.out_degree(self.get_nodes_of_station(station_name))
                ],
            )
        )

    nx.draw_networkx_edges(
        self.graph,
        pos=pos,
        ax=ax,
        alpha=0.2,
        arrowsize=8,
        width=0.5,
        connectionstyle="arc3,rad=0.3",
    )

via_cell_number(n_cells, cell_prefix='C') classmethod

Inits a ProductionLineGraph with predefined number of cells, e.g. n_cells = 3

Will create empty C0, C1 and C2 as cells if no other cell_prefix is given.

Parameters:

Name Type Description Default
n_cells int

Number of cells the graph will have

required
cell_prefix str

If you like other cell names pass them here. Defaults to "C".

'C'
Source code in causalAssembly/models_dag.py
1479
1480
1481
1482
1483
1484
1485
1486
1487
1488
1489
1490
1491
1492
1493
1494
1495
1496
@classmethod
def via_cell_number(cls, n_cells: int, cell_prefix: str = "C"):
    """Inits a ProductionLineGraph with predefined number of cells, e.g. n_cells = 3

    Will create empty  C0, C1 and C2 as cells if no other cell_prefix is given.

    Args:
        n_cells (int): Number of cells the graph will have
        cell_prefix (str, optional): If you like other cell names pass them here.
            Defaults to "C".

    """
    pl = cls()
    pl.cell_prefix = cell_prefix

    [pl.new_cell() for _ in range(n_cells)]

    return pl

choose_edges_from_cells_randomly(from_cell, to_cell, probability, rng)

From two given cells (graphs), we take the cartesian product (end up with from_cell.number_of_nodes x to_cell.number_of_nodes possible edges (node tuples).

From this product we draw probability x cartesian product number of edges randomly.

In case we have a float number, we ceil the value, e.g. 17.3 edges will lead to 18 edges drawn.

Parameters:

Name Type Description Default
from_cell ProcessCell

ProcessCell from where we want the edges

required
to_cell ProcessCell

ProcessCell to where we want the edges

required
probability float

between 0 and 1

required

Returns:

Type Description
list[tuple[str, str]]

list[tuple[str, str]]: Chosen edges.

Source code in causalAssembly/models_dag.py
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
def choose_edges_from_cells_randomly(
    from_cell: ProcessCell,
    to_cell: ProcessCell,
    probability: float,
    rng: np.random.Generator,
) -> list[tuple[str, str]]:
    """
    From two given cells (graphs), we take the cartesian product, ending up with
    from_cell.number_of_nodes x to_cell.number_of_nodes possible edges (node tuples).

    From this product we draw probability x cartesian product number of edges randomly.

    In case we have a float number, we ceil the value,
    e.g. 17.3 edges will lead to 18 edges drawn.

    Args:
        from_cell: ProcessCell from where we want the edges
        to_cell: ProcessCell to where we want the edges
        probability: between 0 and 1
        rng: Random generator used to draw the edges reproducibly.

    Raises:
        ValueError: if `probability` is outside [0, 1].

    Returns:
        list[tuple[str, str]]: Chosen edges.
    """
    # Validate explicitly instead of `assert`, which is stripped under `python -O`.
    if not 0 <= probability <= 1.0:
        raise ValueError(f"probability must be in [0, 1], got {probability}")

    arrow_tail_candidates = list(from_cell.graph.nodes)
    arrow_head_candidates = get_arrow_head_candidates_from_graph(graph=to_cell.graph)

    potential_edges = tuples_from_cartesian_product(
        l1=arrow_tail_candidates, l2=arrow_head_candidates
    )

    num_to_choose = int(np.ceil(probability * len(potential_edges)))

    chosen_edges = [
        potential_edges[i]
        for i in rng.choice(a=len(potential_edges), size=num_to_choose, replace=False)
    ]

    return chosen_edges

get_arrow_head_candidates_from_graph(graph, node_attributes_to_filter=NodeAttributes.ALLOW_IN_EDGES)

Returns all arrow head (nodes where an arrow points to) nodes as list of candidates. To later build a list of tuples of potential edges.

Parameters:

Name Type Description Default
graph DiGraph

DAG

required
node_attributes_to_filter str

see NodeAttributes. Defaults to NodeAttributes.ALLOW_IN_EDGES.

ALLOW_IN_EDGES

Returns:

Type Description
list[str]

list[str]: list of nodes

Source code in causalAssembly/models_dag.py
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
def get_arrow_head_candidates_from_graph(
    graph: nx.DiGraph, node_attributes_to_filter: str = NodeAttributes.ALLOW_IN_EDGES
) -> list[str]:
    """Collect all nodes of `graph` that may serve as arrow heads (edge targets).

    The result is later used to assemble the list of candidate edge tuples.

    Args:
        graph (nx.DiGraph): DAG
        node_attributes_to_filter (str, optional): see NodeAttributes.
            Defaults to NodeAttributes.ALLOW_IN_EDGES.

    Returns:
        list[str]: list of nodes
    """
    attr_map = nx.get_node_attributes(graph, node_attributes_to_filter)

    # Nodes whose attribute is set to exactly True are eligible targets.
    candidates = [node for node, flag in attr_map.items() if flag is True]

    # Nodes carrying no such attribute at all are allowed by default.
    unflagged = list(set(graph.nodes).difference(set(attr_map.keys())))

    if len(candidates) == 0 and len(unflagged) == 0:
        logger.warning(
            f"None of the nodes in cell {graph} \
            are allowed to have in-edges."
        )

    candidates.extend(unflagged)

    return candidates

Utility classes and functions related to causalAssembly. Copyright (c) 2023 Robert Bosch GmbH

This program is free software: you can redistribute it and/or modify it under the terms of the GNU Affero General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later version. This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Affero General Public License for more details. You should have received a copy of the GNU Affero General Public License along with this program. If not, see https://www.gnu.org/licenses/.

merge_dags(dag_to_insert, target_dag, mapping, remove_in_edges_in_target_dag=False)

Dag_to_insert will be connected to target_dag via mapping dict.

Parameters:

Name Type Description Default
dag_to_insert DiGraph

DAG to insert.

required
target_dag DiGraph

DAG on which to map.

required
mapping dict

Mapping from insert to target dag e.g. {C1: D1, C5: D4} where node C1 from insert dag will be mapped to node D1 of target dag.

required
remove_in_edges_in_target_dag bool

Defaults to False.

False

Raises:

Type Description
ValueError

node does not exist in target_dag

ValueError

node does not exist in dag_to_insert

Returns: nx.DiGraph: merged DAG

Source code in causalAssembly/dag_utils.py
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
def merge_dags(
    dag_to_insert: nx.DiGraph,
    target_dag: nx.DiGraph,
    mapping: dict,
    remove_in_edges_in_target_dag: bool = False,
) -> nx.DiGraph:
    """Connect `dag_to_insert` to `target_dag` through the given node mapping.

    Args:
        dag_to_insert (nx.DiGraph): DAG to insert.
        target_dag (nx.DiGraph): DAG on which to map.
        mapping (dict): Mapping from insert to target dag, e.g.
            {C1: D1, C5: D4} where node C1 from the insert dag will
            be mapped to node D1 of the target dag.
        remove_in_edges_in_target_dag (bool, optional): Defaults to False.

    Raises:
        ValueError: node does not exist in target_dag
        ValueError: node does not exist in dag_to_insert
    Returns:
        nx.DiGraph: merged DAG
    """
    # Fail fast if the mapping references unknown nodes on either side.
    for old_node_name, new_node_name in mapping.items():
        if new_node_name not in target_dag.nodes():
            raise ValueError(f"{new_node_name} does not exist in target_dag")
        if old_node_name not in dag_to_insert.nodes():
            raise ValueError(f"{old_node_name} does not exist in dag_to_insert")

    no_of_nodes_insert_dag = len(dag_to_insert.nodes())
    no_of_nodes_target_dag = len(target_dag.nodes())
    if no_of_nodes_insert_dag > no_of_nodes_target_dag:
        logger.warning(
            f"you are trying to merge a DAG of size={no_of_nodes_insert_dag} "
            f"into a smaller DAG of size={no_of_nodes_target_dag}. "
            f"If this is not intentional consider changing the order"
        )

    if remove_in_edges_in_target_dag:
        # Cut all existing influences on the mapped-to nodes first.
        for node in mapping.values():
            target_dag.remove_edges_from(list(target_dag.in_edges(node)))

    # Rename the insert dag's nodes so both graphs glue together on shared names.
    relabeled = nx.relabel_nodes(dag_to_insert, mapping)
    return nx.compose(relabeled, target_dag)

merge_dags_via_edges(left_dag, right_dag, edges=None, isolate_target_nodes=False)

Merges two dags via a list of edges.

Parameters:

Name Type Description Default
left_dag DiGraph

dag to merge to right_dag

required
right_dag DiGraph

dag to merge left_dag into

required
edges list[tuple]

list of edges that connect the two dags. Defaults to None.

None
isolate_target_nodes bool

bool; if True all incoming edges from the right_dag into the target nodes are removed, so that the only influence on them comes from the left_dag via the edges list. Defaults to False.

False

Raises:

Type Description
ValueError

source or target nodes are not available in left dag

ValueError

source or target nodes are not available in right dag

Source code in causalAssembly/dag_utils.py
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
def merge_dags_via_edges(
    left_dag: nx.DiGraph,
    right_dag: nx.DiGraph,
    edges: list[tuple] | None = None,
    isolate_target_nodes: bool = False,
) -> nx.DiGraph:
    """Merges two dags via a list of edges.

    Args:
        left_dag (nx.DiGraph): dag to merge to right_dag
        right_dag (nx.DiGraph):  dag to merge left_dag into
        edges (list[tuple], optional): list of edges that connect the two dags.
            Defaults to None.
        isolate_target_nodes (bool, optional): if True all incoming edges
            from the right_dag into the target nodes are removed, so that
            the only influence on them comes from the left_dag via the
            edges list. Defaults to False.

    Raises:
        ValueError: source or target nodes are not available in left dag
        ValueError: source or target nodes are not available in right dag

    Returns:
        nx.DiGraph: merged DAG.
    """
    if not edges:
        edges = []
    # Set comprehensions instead of set([...]) wrappers.
    source_nodes = {edge[0] for edge in edges}
    target_nodes = {edge[1] for edge in edges}

    if not source_nodes.issubset(set(left_dag.nodes)):
        raise ValueError(
            f"At least one of the source nodes: {source_nodes} "
            f"cannot be found in the left DAGs nodes: {left_dag.nodes}"
        )

    if not target_nodes.issubset(set(right_dag.nodes)):
        raise ValueError(
            f"At least one of the target nodes: {target_nodes} "
            f"cannot be found in right DAGs nodes: {right_dag.nodes}"
        )

    if isolate_target_nodes:
        for node in target_nodes:
            # cast to list to work with a copy, not a live edge view
            edges_to_target_node = list(right_dag.in_edges(node))
            right_dag.remove_edges_from(edges_to_target_node)

    merged_dag = nx.compose(left_dag, right_dag)

    # TODO experimental: mark connecting edges so they can be identified later.
    merged_dag.add_edges_from(edges, **{"connector": True})

    return merged_dag

tuples_from_cartesian_product(l1, l2)

Given two lists l1 and l2 this creates the cartesian product and returns all tuples.

Parameters:

Name Type Description Default
l1 list

First list of nodes

required
l2 list

Second list of nodes typically

required

Returns:

Type Description
list[tuple]

list[tuple]: list of edges typically

Examples::

l1 = [0,1,2]
l2 = ['a','b','c']
>>> tuples_from_cartesian_product(l1,l2)
[(0,'a'), (0,'b'), (0,'c'), (1,'a'), (1,'b'), (1,'c'), (2,'a'), (2,'b'), (2,'c')]
Source code in causalAssembly/dag_utils.py
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
def tuples_from_cartesian_product(l1: list, l2: list) -> list[tuple]:
    """Given two lists l1 and l2 this creates the cartesian product and returns all tuples.

    Args:
        l1 (list): First list of nodes
        l2 (list): Second list of nodes typically

    Returns:
        list[tuple]: list of edges typically

    Examples::

        l1 = [0,1,2]
        l2 = ['a','b','c']
        >>> tuples_from_cartesian_product(l1,l2)
        [(0,'a'), (0,'b'), (0,'c'), (1,'a'), (1,'b'), (1,'c'), (2,'a'), (2,'b'), (2,'c')]

    """
    # itertools.product already yields the (tail, head) tuples; the former
    # comprehension that unpacked and re-packed each pair was redundant.
    # NOTE: runtime/memory is O(len(l1) * len(l2)) — large for big graphs.
    return list(itertools.product(l1, l2))

Utility classes and functions related to causalAssembly. Copyright (c) 2023 Robert Bosch GmbH

This program is free software: you can redistribute it and/or modify it under the terms of the GNU Affero General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later version. This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Affero General Public License for more details. You should have received a copy of the GNU Affero General Public License along with this program. If not, see https://www.gnu.org/licenses/.

FCM

Class to define, intervene and sample from an FCM.

Examples:

from sympy import symbols
from sympy.stats import Normal, Uniform, Gamma

x, y, z = symbols('x,y,z')

eq_x = Eq(x, Uniform("noise", left=-1, right=1))
eq_y = Eq(y, 2 * x ** 2 + Normal("error", 0, .5))
eq_z = Eq(z, 9 * y * x * Gamma("some_name", .5, .5))

eq_list = [eq_x, eq_y, eq_z]


self = FCM(name='test', seed=2023)
self.input_fcm(eq_list)
self.draw(size=10)
Source code in causalAssembly/models_fcm.py
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
class FCM:
    """Class to define, intervene and sample from an FCM.

    Examples:
        ```python
        from sympy import symbols
        from sympy.stats import Normal, Uniform, Gamma

        x, y, z = symbols('x,y,z')

        eq_x = Eq(x, Uniform("noise", left=-1, right=1))
        eq_y = Eq(y, 2 * x ** 2 + Normal("error", 0, .5))
        eq_z = Eq(z, 9 * y * x * Gamma("some_name", .5, .5))

        eq_list = [eq_x, eq_y, eq_z]


        self = FCM(name='test', seed=2023)
        self.input_fcm(eq_list)
        self.draw(size=10)
        ```

    """

    def __init__(self, name: str | None = None, seed: int = 2023):
        """Set up an empty FCM.

        Args:
            name (str | None, optional): Name given to the underlying DAG. Defaults to None.
            seed (int, optional): Seed for the internal random generator. Defaults to 2023.
        """
        self.name = name
        self._random_state = np.random.default_rng(seed=seed)
        self.__init_dag()
        # Annotation only; populated elsewhere (not assigned here).
        self.last_df: pd.DataFrame
        self.__init_mutilated_dag()

    def __init_dag(self):
        # Start from a fresh, empty DAG that carries the FCM's name.
        self.graph = nx.DiGraph(name=self.name)

    def __init_mutilated_dag(self):
        # Maps intervention labels to their mutilated DAGs; empty until
        # interventions are performed.
        self.mutilated_dags = dict()

    @property
    def source_nodes(self) -> list:
        """Returns source nodes in the current DAG.

        Returns:
            list: List of source nodes.
        """
        # A source node is one without any parents.
        return [n for n in self.nodes if not self.parents(of_node=n)]

    @property
    def causal_order(self) -> list[Symbol]:
        """The causal order of the current graph.

        The order is in general not unique; sorting lexicographically
        by node name makes it deterministic.

        Returns:
            list[Symbol]: Causal order
        """
        order = nx.lexicographical_topological_sort(self.graph, key=str)
        return list(order)

    @property
    def nodes(self) -> list[Symbol]:
        """All nodes of the underlying graph.

        Returns:
            list[Symbol]
        """
        return [*self.graph.nodes()]

    @property
    def edges(self) -> list[tuple]:
        """All directed edges of the underlying graph.

        Returns:
            list[tuple]
        """
        return [*self.graph.edges()]

    @property
    def num_nodes(self) -> int:
        """Number of nodes in the graph.

        Returns:
            int
        """
        node_count = len(self.nodes)
        return node_count

    @property
    def num_edges(self) -> int:
        """Number of edges in the graph.

        Returns:
            int
        """
        edge_count = len(self.edges)
        return edge_count

    @property
    def sparsity(self) -> float:
        """Edge density of the graph relative to a complete DAG.

        Returns:
            float: in [0,1]
        """
        n = self.num_nodes
        # Ratio of present edges to the n*(n-1)/2 possible ones; the
        # division order mirrors the original to keep float results identical.
        density = self.num_edges / n
        density /= n - 1
        return density * 2

    @property
    def ground_truth(self) -> pd.DataFrame:
        """The current ground truth as a pandas adjacency matrix.

        Returns:
            pd.DataFrame: Adjacency matrix.
        """
        return nx.to_pandas_adjacency(self.graph, weight=None)

    @property
    def interventions(self) -> list:
        """All interventions performed on the original graph.

        Returns:
            list: list of intervened upon nodes in do(x) notation.
        """
        return [*self.mutilated_dags.keys()]

    def interventional_amat(self, which_intervention: int | str) -> pd.DataFrame:
        """Returns the adjacency matrix of a chosen mutilated DAG.

        Args:
            which_intervention (int | str): Integer count of your chosen intervention or
                literal string.

        Raises:
            ValueError: "The intervention you provide does not exist."

        Returns:
            pd.DataFrame: Adjacency matrix.
        """
        if isinstance(which_intervention, str) and which_intervention not in self.interventions:
            raise ValueError("The intervention you provide does not exist.")

        # `>=` (not `>`): an index equal to len(...) is already out of range,
        # which previously slipped through and raised a bare IndexError below.
        if isinstance(which_intervention, int) and which_intervention >= len(self.interventions):
            raise ValueError("The intervention you index does not exist.")

        if isinstance(which_intervention, int):
            which_intervention = self.interventions[which_intervention]

        # Copy so callers cannot mutate the stored mutilated DAG.
        mutilated_dag = self.mutilated_dags[which_intervention].copy()
        return nx.to_pandas_adjacency(mutilated_dag, weight=None)

    def parents(self, of_node: Symbol) -> list[Symbol]:
        """Return parents of node in question.

        Args:
            of_node (Symbol): Node in question.

        Returns:
            list[Symbol]: parent set.
        """
        return [*self.graph.predecessors(of_node)]

    def parents_of(self, node: Symbol, which_graph: nx.DiGraph) -> list[Symbol]:
        """Return parents of node in question for a chosen DAG.

        Args:
            node (Symbol): node whose parents to return.
            which_graph (nx.DiGraph): which graph along the interventions.

        Returns:
            list[Symbol]: list of parents.
        """
        return [*which_graph.predecessors(node)]

    def causal_order_of(self, which_graph: nx.DiGraph) -> list[Symbol]:
        """Returns the causal order of the chosen graph.

        The order is in general not unique; sorting lexicographically
        by node name makes it deterministic.

        Args:
            which_graph (nx.DiGraph): graph whose causal order to compute.

        Returns:
            list[Symbol]: Causal order
        """
        return list(nx.lexicographical_topological_sort(which_graph, key=str))

    def source_nodes_of(self, which_graph: nx.DiGraph) -> list:
        """Returns the source nodes of a chosen graph. This is mainly for
        choosing different mutilated DAGs.

        Args:
            which_graph (nx.DiGraph): DAG from which source nodes should be returned.

        Returns:
            list: List of nodes.
        """
        sources = []
        for node in which_graph.nodes:
            # A source node has no parents in the chosen graph.
            if not self.parents_of(node=node, which_graph=which_graph):
                sources.append(node)
        return sources

    def input_fcm(self, fcm: list[Eq]):
        """
        Automatically builds up DAG according to the FCM fed in.
        Args:
            fcm (list): list of sympy equations generated as:
                    ```[python]
                    x,y = symbols('x,y')
                    term_x = Eq(x, Normal('x', 0,1))
                    term_y = Eq(y, 2*x**2*Normal('noise', 0,1))
                    fcm = [term_x, term_y]
                    ```
        """
        # One node per equation: the (single) free symbol on the lhs.
        nodes_implied = [node.lhs.free_symbols.pop() for node in fcm]
        edges_implied = []
        for term in fcm:
            # An rhs that is purely a RandomSymbol is a source node: no in-edges.
            if not isinstance(term.rhs, RandomSymbol):
                if term.rhs.atoms(RandomSymbol):
                    # Every free symbol except the noise term becomes a parent.
                    # NOTE(review): pops a single RandomSymbol and compares by str,
                    # which assumes at most one noise symbol per equation — confirm.
                    edges_implied.extend(
                        [
                            (atom, term.lhs)
                            for atom in term.rhs.free_symbols
                            if str(atom) != str(term.rhs.atoms(RandomSymbol).pop())
                        ]
                    )
                else:
                    # Fully deterministic rhs: all symbols become parents.
                    edges_implied.extend([(atom, term.lhs) for atom in term.rhs.atoms(Symbol)])

        g = self.graph
        g.add_nodes_from(nodes_implied)
        g.add_edges_from(edges_implied)

        # Attach each rhs as the node's functional assignment ("term" attribute).
        term_dict = {}
        for term in fcm:
            term_dict[term.lhs] = {"term": term.rhs}

        nx.set_node_attributes(g, term_dict)

    def function_of(self, node: Symbol) -> dict:
        """Returns functional assignment for node in question.

        Args:
            node (Symbol): node corresponding to lhs.

        Returns:
            dict: key is node and value rhs of functional assignment.
        """
        if node not in self.graph.nodes:
            # Distinguish "caller passed a string" from "node genuinely unknown".
            if node in [str(n) for n in self.nodes]:
                raise AssertionError(
                    "You probably defined a string. Node has to be a symbol, check out",
                    list(self.graph.nodes),
                )
            raise AssertionError("Node has to be in the graph")

        return {node: self.graph.nodes[node]["term"]}

    def display_functions(self) -> dict:
        """Displays all functional assignments inputted into the FCM.

        Returns:
            dict: Dict with keys equal to nodes and values equal to
                functional assignments.
        """
        # Insertion in causal order, matching the original loop.
        return {node: self.graph.nodes[node]["term"] for node in self.causal_order}

    def sample(
        self,
        size: int,
        additive_gaussian_noise: bool = False,
        snr: None | float = 1 / 2,
        source_df: None | pd.DataFrame = None,
    ) -> pd.DataFrame:
        """Draw samples from the joint distribution implied by the FCM.

        The joint distribution factorizes according to the DAG built from the
        equations fed in. To avoid unexpected/unintended behavior, avoid
        defining fully deterministic equation systems. If parameters in noise
        terms are additive and left unevaluated, they are set according to the
        chosen signal-to-noise ratio (SNR). For convenience, additive Gaussian
        noise can be attached to each equation; this never overwrites any of
        the chosen noise distributions. Source node data may also be supplied
        via a data frame (see below).

        Args:
            size (int): Number of samples to draw.
            additive_gaussian_noise (bool, optional): Attach additive Gaussian
                noise to all terms without a RandomSymbol that are not source
                nodes. Acts merely as a convenience option; the variance is
                then chosen according to SNR. Defaults to False.
            snr (None | float, optional): Signal-to-noise ratio
                \\( SNR =  \\frac{\\text{Var}(\\hat{X})}{\\hat\\sigma^2}. \\).
                Defaults to 1/2.
            source_df (None | pd.DataFrame, optional): Data frame containing source node data.
                The sample size must be at least as large as the number of samples
                you'd like to draw. Defaults to None.

        Raises:
            AttributeError: if source node parameters are not given explicitly.
            ValueError: if source node sample size is too small.
            ValueError: if scale parameters are left unevaluated for non-additive terms.

        Returns:
            pd.DataFrame:  Data frame with rows of length `size` and columns equal to the
                number of nodes in the graph.
        """
        # Delegate to the shared sampler using the non-mutilated graph.
        return self._sample(
            size=size,
            which_graph=self.graph,
            additive_gaussian_noise=additive_gaussian_noise,
            snr=snr,
            source_df=source_df,
        )

    def interventional_sample(
        self,
        size: int,
        which_intervention: str | int = 0,
        additive_gaussian_noise: bool = False,
        snr: None | float = 1 / 2,
        source_df: None | pd.DataFrame = None,
    ) -> pd.DataFrame:
        """Draw samples from a chosen interventional distribution.

        The samples factorize according to the mutilated DAG obtained after
        performing one or multiple interventions. Apart from the choice of
        intervention, this behaves like sampling from the non-interventional
        joint distribution. By default, samples are drawn from the first
        intervention performed; if you intervened upon more than one node,
        switch interventions via `which_intervention`.

        Args:
            size (int): Number of samples to draw.
            which_intervention (str | int): Which interventional distribution to draw
                from. We recommend using integer counts starting from zero. But you can
                also provide the literal string here, e.g. if you intervened on say two
                nodes `x,y` then you would need to provide here: `do([x, y])`.
            additive_gaussian_noise (bool, optional): Attach additive Gaussian noise
                to all terms without a RandomSymbol that are not source nodes. Acts
                merely as a convenience option; the variance is then chosen according
                to SNR. Defaults to False.
            snr (None | float, optional): Signal-to-noise ratio
                \\( SNR =  \\frac{\\text{Var}(\\hat{X})}{\\hat\\sigma^2}. \\). Defaults to 1/2.
            source_df (None | pd.DataFrame, optional): Data frame containing source node data.
                The sample size must be at least as large as the number of samples
                you'd like to draw. Defaults to None.

        Raises:
            NotImplementedError: Raised when `which_intervention` is not of correct form.

        Returns:
            pd.DataFrame: Data frame with rows of length `size` and columns equal to the
                number of nodes in the graph.
        """
        # Guard clause: reject unsupported selector types up front.
        if not isinstance(which_intervention, (str, int)):
            raise NotImplementedError(
                f"which_intervention has to be \
                the literal string or an integer \
                starting at count zero indicating \
                which intervention in {self.interventions} \
                to use."
            )

        if isinstance(which_intervention, str):
            int_choice = which_intervention
        else:
            int_choice = self.interventions[which_intervention]

        return self._sample(
            size=size,
            additive_gaussian_noise=additive_gaussian_noise,
            snr=snr,
            source_df=source_df,
            which_graph=self.mutilated_dags[int_choice],
        )

    def _sample(
        self,
        size: int,
        which_graph: nx.DiGraph,
        additive_gaussian_noise: bool = False,
        snr: None | float = 1 / 2,
        source_df: None | pd.DataFrame = None,
    ) -> pd.DataFrame:
        """Draw samples from the joint distribution that factorizes
        according to the DAG implied by the FCM fed in. To avoid
        unexpected/unintended behavior, avoid defining fully
        deterministic equation systems.
        If parameters in noise terms are additive and left unevaluated,
        they're set according to a chosen Signal-To-Noise (SNR) ratio.
        For convenience, you can add additive Gaussian noise to each equation.
        This will never overwrite any of the chosen noise distributions.
        You may also feed in a data frame for noise distributions (see below
        for more details).

        Args:
            size (int): Number of samples to draw.
            additive_gaussian_noise (bool, optional): _description_. Defaults to False.
            snr (None | float, optional): Signal-to-noise ratio
                \\( SNR =  \\frac{\\text{Var}(\\hat{X})}{\\hat\\sigma^2}. \\).
                Defaults to 1/2.
            source_df (None | pd.DataFrame, optional): Data frame conaining source node data.
                The sample size must be at least as large as the number of samples
                you'd like to draw. Defaults to None.

        Raises:
            AttributeError: if source node parameters are not given explicitly.
            ValueError: if source node sample size is too small.
            ValueError: if scale parameters are left unevaluated for non-additive terms.

        Returns:
            pd.DataFrame:  Data frame with rows of lenght size and columns equal to the
                number of nodes in the graph.
        """

        if source_df is not None and not self.__source_df_condition(source_df):
            raise AssertionError("Names in source_df don't match nodenames in graph.")

        df = pd.DataFrame()
        for order in self.causal_order_of(which_graph=which_graph):
            if order in self.source_nodes_of(which_graph=which_graph):
                if source_df is not None and str(order) in source_df.columns:
                    if source_df[str(order)].shape[0] < size:
                        raise ValueError(
                            "Sample size of source node data must be at least \
                            as large as the number of samples you'd like to draw."
                        )
                    df[str(order)] = source_df[str(order)].sample(
                        n=size,
                        replace=False,
                        random_state=self._random_state,
                        ignore_index=True,
                    )

                elif isinstance(which_graph.nodes[order]["term"], RandomSymbol):
                    if not self.__distribution_parameters_explicit(order, which_graph=which_graph):
                        raise AttributeError("Source node parameters need to be given explicitly.")
                    df[str(order)] = sympy_sample(
                        which_graph.nodes[order]["term"],
                        seed=self._random_state,
                        size=size,
                    )

                elif isinstance(which_graph.nodes[order]["term"], Number):
                    df[str(order)] = np.repeat(which_graph.nodes[order]["term"], repeats=size)

                else:
                    raise NotImplementedError(
                        "Source nodes need to have a fully parameterized distribution, \
                        or need to be drawn from an appropriate data frame, or fixed to \
                        a single real number."
                    )
                continue

            fcm_expr = which_graph.nodes[order]["term"]

            if fcm_expr.atoms(RandomSymbol):
                if self.__distribution_parameters_explicit(order, which_graph=which_graph):
                    df[str(fcm_expr.atoms(RandomSymbol).pop())] = sympy_sample(
                        fcm_expr.atoms(RandomSymbol).pop(),
                        size=size,
                        seed=self._random_state,
                    )
                else:
                    df[str(fcm_expr.atoms(RandomSymbol).pop())] = np.zeros(size)

            df[str(order)] = self.__eval_expression(df=df, fcm_expr=fcm_expr)

            if fcm_expr.atoms(RandomSymbol) and not self.__distribution_parameters_explicit(
                order, which_graph=which_graph
            ):
                if not fcm_expr.is_Add:
                    raise ValueError(
                        "Noise term in "
                        + str(order)
                        + "="
                        + str(fcm_expr)
                        + " not additive. Scale parameter selection via SNR \
                        makes sense only for additive noise."
                    )
                logger.warning(
                    "I'll choose the noise scale in "
                    + str(order)
                    + "="
                    + str(fcm_expr)
                    + " according to the given SNR."
                )
                noise_var = df[str(order)].var() / snr
                df[str(order)] = df[str(order)] + sympy_sample(
                    fcm_expr.atoms(RandomSymbol)
                    .pop()
                    .subs(self.__unfree_symbol(fcm_expr), np.sqrt(noise_var)),
                    size=size,
                    seed=self._random_state,
                )

            if additive_gaussian_noise:
                if fcm_expr.atoms(RandomSymbol):
                    logger.warning(
                        "Noise already defined in "
                        + str(order)
                        + "="
                        + str(fcm_expr)
                        + ". I won't override this."
                    )
                else:
                    noise = symbols("noise")
                    noise_var = df[str(order)].var() / snr
                    df[str(noise)] = self._random_state.normal(
                        loc=0, scale=np.sqrt(noise_var), size=size
                    )
                    fcm_expr = which_graph.nodes[order]["term"] + noise
                    df[str(order)] = self.__eval_expression(df=df, fcm_expr=fcm_expr)

        self.last_df = df[[str(order) for order in self.causal_order]]

        return self.last_df

    def __unfree_symbol(self, fcm_expr) -> Symbol:
        """Return the unevaluated scale-parameter symbol in `fcm_expr`.

        Symbols bound inside a RandomSymbol's distribution are not free
        symbols of the surrounding expression; among those, the one that
        is not the RandomSymbol itself is the unevaluated distribution
        parameter (e.g. a symbolic standard deviation).

        Args:
            fcm_expr: Sympy expression assumed to contain a RandomSymbol
                with exactly one symbolic parameter — TODO confirm callers
                guarantee this; `.pop()` raises KeyError otherwise.

        Returns:
            Symbol: The unevaluated parameter symbol.
        """
        # Symbols that appear only inside the distribution, not free in the expression.
        bound_symbols = fcm_expr.atoms(Symbol).difference(fcm_expr.free_symbols)
        noise_name = str(fcm_expr.atoms(RandomSymbol).pop())
        # Fix: the original annotated the return as set[Symbol], but
        # set.pop() yields a single Symbol — annotate accordingly.
        return {s for s in bound_symbols if str(s) != noise_name}.pop()

    def __eval_expression(self, df: pd.DataFrame, fcm_expr: Expr) -> pd.DataFrame:
        """Numerically evaluate a functional assignment on the given data.

        Args:
            df (pd.DataFrame): Data frame holding a column for every free
                symbol occurring in `fcm_expr`.
            fcm_expr (Expr): Sympy expression to evaluate.

        Returns:
            pd.DataFrame: Column-wise evaluation of the expression on `df`.
        """
        # free_symbols is a set; fix a deterministic argument order first.
        arg_symbols = list(ordered(fcm_expr.free_symbols))
        column_names = [str(symbol) for symbol in arg_symbols]
        evaluate = lambdify(arg_symbols, fcm_expr, "scipy")
        columns = [df[name] for name in column_names]
        return evaluate(*columns)

    def __distribution_parameters_explicit(self, order: Symbol, which_graph: nx.DiGraph) -> bool:
        """Returns true if distribution parameters
        are given explicitly, not symbolically.

        If every Symbol occurring in the node's term is also a free symbol
        of the term, no distribution parameter is left unevaluated.

        Args:
            order (node): node in graph
            which_graph (nx.DiGraph): graph containing the node.

        Returns:
            bool:
        """
        term = which_graph.nodes[order]["term"]
        return len(term.free_symbols) == len(term.atoms(Symbol))

    def __source_df_condition(self, source_df: pd.DataFrame) -> bool:
        """Returns true if every column name in source_df matches the
        name of a source node in the graph.

        Args:
            source_df (None | pd.DataFrame): data frame containing source node data.

        Returns:
            bool: True if the column names form a subset of the source node names.
        """
        column_names = {str(column) for column in source_df.columns}
        source_names = {str(node) for node in self.source_nodes}
        return column_names.issubset(source_names)

    def intervene_on(self, nodes_values: dict[Symbol, RandomSymbol | float]):
        """Specify hard or soft interventions. To intervene
        upon more than one node, simply provide one key-value
        pair per node (see example). The mutilated DAG will
        automatically be stored in `mutilated_dags`.

        Args:
            nodes_values (dict[Symbol, RandomSymbol | float]): either a single real
                number or a RandomSymbol per node. If you provide more than one
                intervention just provide more key-value pairs.

        Raises:
            AssertionError: If node(s) are not in the graph

        Example:
            ```python
            x,y = symbols("x,y")
            eq_x = Eq(x, Gamma("source", 1,1))
            eq_y = Eq(y, 4*x**3 + Uniform("noise", left=-0.5, right=0.5))

            example_fcm = FCM()
            example_fcm.input_fcm([eq_x,eq_y])
            # Hard intervention
            example_fcm.intervene_on(nodes_values = {y : 4})
            # Soft intervention
            example_fcm.intervene_on(nodes_values = {y : Normal("noise",0,1)})

            ```
        """

        unknown_nodes = set(nodes_values.keys()) - set(self.nodes)
        if unknown_nodes:
            raise AssertionError(
                "One or more nodes you want to intervene upon are not in the graph."
            )

        mutilated_dag = self.graph.copy()

        for target, assignment in nodes_values.items():
            new_equation = Eq(target, assignment)
            # Cut all incoming edges so the node depends only on the intervention.
            incoming_edges = [(parent, target) for parent in self.parents(of_node=target)]
            mutilated_dag.remove_edges_from(incoming_edges)
            mutilated_dag.nodes[target]["term"] = new_equation.rhs

        # Specifying the same set of nodes twice overrides the stored DAG.
        self.mutilated_dags[f"do({list(nodes_values.keys())})"] = mutilated_dag

    def show(self, header: str | None = None, with_nodenames: bool = True) -> plt:
        """Plots the current DAG.

        Args:
            header (str | None, optional): Header for the DAG. Defaults to None.
            with_nodenames (bool, optional): Whether or not to use nodenames as
                labels in the plot. Defaults to True.

        Returns:
            plt: Plot of the DAG.
        """
        title = "" if header is None else header
        return self._show(which_graph=self.graph, header=title, with_nodenames=with_nodenames)

    def show_mutilated_dag(
        self, which_intervention: str | int = 0, with_nodenames: bool = True
    ) -> plt:
        """Plot mutilated DAG

        Args:
            which_intervention (str | int, optional): Which interventional distribution
                should be represented by a DAG. Defaults to 0.
            with_nodenames (bool, optional): Whether or not to use nodenames as
                labels in the plot. Defaults to True.

        Returns:
            plt: Plot of the mutilated DAG.
        """
        intervention_key = which_intervention
        if isinstance(intervention_key, int):
            # Integer counts index into the interventions performed so far.
            intervention_key = self.interventions[intervention_key]

        return self._show(
            which_graph=self.mutilated_dags[intervention_key],
            header=intervention_key,
            with_nodenames=with_nodenames,
        )

    def _show(self, which_graph: nx.DiGraph, header: str, with_nodenames: bool):
        """Plots the graph by giving extra weight to nodes
        with high in- and out-degree.

        Args:
            which_graph (nx.DiGraph): Graph to draw.
            header (str): Title placed above the plot.
            with_nodenames (bool): Whether to draw node names as labels.

        Returns:
            The matplotlib figure containing the drawn DAG.
        """
        cmap = plt.get_cmap("Blues")
        fig, ax = plt.subplots()
        center: np.ndarray = np.array([0, 0])
        pos = nx.spring_layout(
            which_graph,
            center=center,
            seed=10,  # fixed seed keeps the layout reproducible across calls
            k=50,
        )

        labels = {node: node for node in self.nodes}

        max_in_degree = max(d for _, d in which_graph.in_degree())
        max_out_degree = max(d for _, d in which_graph.out_degree())

        nx.draw_networkx_nodes(
            which_graph,
            pos=pos,
            ax=ax,
            cmap=cmap,
            vmin=-0.2,
            vmax=1,
            # Color scales with in-degree, size with out-degree; the +10/+1
            # offsets keep degree-0 nodes visible.
            node_color=[
                (d + 10) / (max_in_degree + 10) for _, d in which_graph.in_degree(self.nodes)
            ],
            node_size=[
                500 * (d + 1) / (max_out_degree + 1) for _, d in which_graph.out_degree(self.nodes)
            ],
        )

        if with_nodenames:
            nx.draw_networkx_labels(
                which_graph,
                pos=pos,
                labels=labels,
                font_size=8,
                font_color="w",
                alpha=0.4,
            )

        nx.draw_networkx_edges(
            which_graph,
            pos=pos,
            ax=ax,
            alpha=0.2,
            arrowsize=8,
            width=0.5,
            connectionstyle="arc3,rad=0.3",
        )

        ax.text(
            center[0],
            center[1] + 1.2,
            f"{header}",
            horizontalalignment="center",
        )

        ax.axis("off")

        # Fix: callers (`show`, `show_mutilated_dag`) are annotated and
        # documented to return the plot, but nothing was returned before.
        return fig

causal_order: list[Symbol] property

Returns the causal order of the current graph. Note that this order is in general not unique. To ensure uniqueness, we additionally sort lexicographically.

Returns:

Type Description
list[Symbol]

list[Symbol]: Causal order

edges: list[tuple] property

Edges in the graph.

Returns:

Type Description
list[tuple]

list[tuple]

ground_truth: pd.DataFrame property

Returns the current ground truth as pandas adjacency.

Returns:

Type Description
DataFrame

pd.DataFrame: Adjacency matrix.

interventions: list property

Returns all interventions performed on the original graph

Returns:

Name Type Description
list list

list of intervened upon nodes in do(x) notation.

nodes: list[Symbol] property

Nodes in the graph.

Returns:

Type Description
list[Symbol]

list[str]

num_edges: int property

Number of edges in the graph

Returns:

Type Description
int

int

num_nodes: int property

Number of nodes in the graph

Returns:

Type Description
int

int

source_nodes: list property

Returns source nodes in the current DAG.

Returns:

Name Type Description
list list

List of source nodes.

sparsity: float property

Sparsity of the graph

Returns:

Name Type Description
float float

in [0,1]

__distribution_parameters_explicit(order, which_graph)

Returns true if distribution parameters are given explicitly, not symbolically.

Parameters:

Name Type Description Default
order node

node in graph

required

Returns:

Name Type Description
bool bool
Source code in causalAssembly/models_fcm.py
571
572
573
574
575
576
577
578
579
580
581
582
583
def __distribution_parameters_explicit(self, order: Symbol, which_graph: nx.DiGraph) -> bool:
    """Returns true if distribution parameters
    are given explicitly, not symbolically.

    Args:
        order (node): node in graph

    Returns:
        bool:
    """
    return len(which_graph.nodes[order]["term"].free_symbols) == len(
        which_graph.nodes[order]["term"].atoms(Symbol)
    )

__eval_expression(df, fcm_expr)

Eval given fcm_expression with the values in given dataframe

Parameters:

Name Type Description Default
df DataFrame

Data frame.

required
fcm_expr Expr

Sympy expression.

required

Returns:

Type Description
DataFrame

pd.DataFrame: Data frame after eval.

Source code in causalAssembly/models_fcm.py
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
def __eval_expression(self, df: pd.DataFrame, fcm_expr: Expr) -> pd.DataFrame:
    """Eval given fcm_expression with the values in given dataframe

    Args:
        df (pd.DataFrame): Data frame.
        fcm_expr (Expr): Sympy expression.

    Returns:
        pd.DataFrame: Data frame after eval.
    """

    correct_order = list(ordered(fcm_expr.free_symbols))  # self.__return_ordered_args(fcm_expr)
    cols = [str(col) for col in correct_order]
    evaluator = lambdify(correct_order, fcm_expr, "scipy")

    return evaluator(*[df[col] for col in cols])

__source_df_condition(source_df)

Returns true if source_df colnames and graph nodenames agree.

Parameters:

Name Type Description Default
source_df None | DataFrame

data frame containing source node data.

required

Returns:

Name Type Description
bool bool

True if names agree

Source code in causalAssembly/models_fcm.py
585
586
587
588
589
590
591
592
593
594
595
596
def __source_df_condition(self, source_df: pd.DataFrame) -> bool:
    """Returns true if source_df colnames and graph nodenames agree.

    Args:
        source_df (None | pd.DataFrame): data frame containing source node data.

    Returns:
        bool: True if names agree
    """
    return {str(col) for col in source_df.columns}.issubset(
        {str(node) for node in self.source_nodes}
    )

causal_order_of(which_graph)

Returns the causal order of the chosen graph. Note that this order is in general not unique. To ensure uniqueness, we additionally sort lexicographically.

Returns:

Type Description
list[Symbol]

list[Symbol]: Causal order

Source code in causalAssembly/models_fcm.py
200
201
202
203
204
205
206
207
208
def causal_order_of(self, which_graph: nx.DiGraph) -> list[Symbol]:
    """Returns the causal order of the chosen graph.
    Note that this order is in general not unique. To
    ensure uniqueness, we additionally sort lexicograpically.

    Returns:
        list[Symbol]: Causal order
    """
    return list(nx.lexicographical_topological_sort(which_graph, key=lambda x: str(x)))

display_functions()

Displays all functional assignments inputted into the FCM.

Returns:

Name Type Description
dict dict

Dict with keys equal to nodes and values equal to functional assignments.

Source code in causalAssembly/models_fcm.py
283
284
285
286
287
288
289
290
291
292
293
294
def display_functions(self) -> dict:
    """Displays all functional assignments inputted into the FCM.

    Returns:
        dict: Dict with keys equal to nodes and values equal to
            functional assignments.
    """
    fcm_dict = {}
    for node in self.causal_order:
        fcm_dict[node] = self.graph.nodes[node]["term"]

    return fcm_dict

function_of(node)

Returns functional assignment for node in question.

Parameters:

Name Type Description Default
node Symbol

node corresponding to lhs.

required

Returns:

Name Type Description
dict dict

key is node and value rhs of functional assignment.

Source code in causalAssembly/models_fcm.py
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
def function_of(self, node: Symbol) -> dict:
    """Returns functional assignment for node in question.

    Args:
        node (Symbol): node corresponding to lhs.

    Returns:
        dict: key is node and value rhs of functional assignment.
    """
    if node not in self.graph.nodes:
        if node in [str(node) for node in self.nodes]:
            raise AssertionError(
                "You probably defined a string. Node has to be a symbol, check out",
                list(self.graph.nodes),
            )
        else:
            raise AssertionError("Node has to be in the graph")

    return {node: self.graph.nodes[node]["term"]}

input_fcm(fcm)

Automatically builds up DAG according to the FCM fed in. Args: fcm (list): list of sympy equations generated as: [python] x,y = symbols('x,y') term_x = Eq(x, Normal('x', 0,1)) term_y = Eq(y, 2*x**2*Normal('noise', 0,1)) fcm = [term_x, term_y]

Source code in causalAssembly/models_fcm.py
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
def input_fcm(self, fcm: list[Eq]):
    """
    Automatically builds up DAG according to the FCM fed in.
    Args:
        fcm (list): list of sympy equations generated as:
                ```[python]
                x,y = symbols('x,y')
                term_x = Eq(x, Normal('x', 0,1))
                term_y = Eq(y, 2*x**2*Normal('noise', 0,1))
                fcm = [term_x, term_y]
                ```
    """
    nodes_implied = [node.lhs.free_symbols.pop() for node in fcm]
    edges_implied = []
    for term in fcm:
        if not isinstance(term.rhs, RandomSymbol):
            if term.rhs.atoms(RandomSymbol):
                edges_implied.extend(
                    [
                        (atom, term.lhs)
                        for atom in term.rhs.free_symbols
                        if str(atom) != str(term.rhs.atoms(RandomSymbol).pop())
                    ]
                )
            else:
                edges_implied.extend([(atom, term.lhs) for atom in term.rhs.atoms(Symbol)])

    g = self.graph
    g.add_nodes_from(nodes_implied)
    g.add_edges_from(edges_implied)

    term_dict = {}
    for term in fcm:
        term_dict[term.lhs] = {"term": term.rhs}

    nx.set_node_attributes(g, term_dict)

intervene_on(nodes_values)

Specify hard or soft intervention. If you want to intervene upon more than one node provide a list of nodes to intervene on and a list of corresponding values to set these nodes to. (see example). The mutilated dag will automatically be stored in mutilated_dags.

Parameters:

Name Type Description Default
nodes_values dict[Symbol, RandomSymbol | float]

either a single real number or a RandomSymbol. If you provide more than one intervention just provide more key-value pairs.

required

Raises:

Type Description
AssertionError

If node(s) are not in the graph

Example
x,y = symbols("x,y")
eq_x = Eq(x, Gamma("source", 1,1))
eq_y = Eq(y, 4*x**3 + Uniform("noise", left=-0.5, right=0.5))

example_fcm = FCM()
example_fcm.input_fcm([eq_x,eq_y])
# Hard intervention
example_fcm.intervene_on(nodes_values = {y : 4})
# Soft intervention
example_fcm.intervene_on(nodes_values = {y : Normal("noise",0,1)})

Source code in causalAssembly/models_fcm.py
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
def intervene_on(self, nodes_values: dict[Symbol, RandomSymbol | float]):
    """Specify hard or soft intervention. If you want to intervene
    upon more than one node provide a list of nodes to intervene on
    and a list of corresponding values to set these nodes to.
    (see example). The mutilated dag will automatically be
    stored in `mutiliated_dags`.

    Args:
        nodes_values (dict[Symbol, RandomSymbol | float]): either single real
            number or RandmSymbol. If you provide more than one
            intervention just provide more key-value pairs.

    Raises:
        AssertionError: If node(s) are not in the graph

    Example:
        ```python
        x,y = symbols("x,y")
        eq_x = Eq(x, Gamma("source", 1,1))
        eq_y = Eq(y, 4*x**3 + Uniform("noise", left=-0.5, right=0.5))

        example_fcm = FCM()
        example_fcm.input_fcm([eq_x,eq_y])
        # Hard intervention
        example_fcm.intervene_on(nodes_values = {y : 4})
        # Soft intervention
        example_fcm.intervene_on(nodes_values = {y : Normal("noise",0,1)})

        ```
    """

    if not set(nodes_values.keys()).issubset(set(self.nodes)):
        raise AssertionError(
            "One or more nodes you want to intervene upon are not in the graph."
        )

    mutilated_dag = self.graph.copy()

    for node, value in nodes_values.items():
        intervention = Eq(node, value)
        old_incoming = self.parents(of_node=node)
        edges_to_remove = [(old, node) for old in old_incoming]
        mutilated_dag.remove_edges_from(edges_to_remove)
        mutilated_dag.nodes[node]["term"] = intervention.rhs

    self.mutilated_dags[
        f"do({list(nodes_values.keys())})"
    ] = mutilated_dag  # specifiying the same set twice will override

interventional_amat(which_intervention)

Returns the adjacency matrix of a chosen mutilated DAG.

Parameters:

Name Type Description Default
which_intervention int | str

Integer count of your chosen intervention or literal string.

required

Raises:

Type Description
ValueError

"The intervention you provide does not exist."

Returns:

Type Description
DataFrame

pd.DataFrame: Adjacency matrix.

Source code in causalAssembly/models_fcm.py
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
def interventional_amat(self, which_intervention: int | str) -> pd.DataFrame:
    """Returns the adjacency matrix of a chosen mutilated DAG.

    Args:
        which_intervention (int | str): Integer count of your chosen intervention or
            literal string.

    Raises:
        ValueError: "The intervention you provide does not exist."

    Returns:
        pd.DataFrame: Adjacency matrix.
    """
    if isinstance(which_intervention, str) and which_intervention not in self.interventions:
        raise ValueError("The intervention you provide does not exist.")

    if isinstance(which_intervention, int) and which_intervention > len(self.interventions):
        raise ValueError("The intervention you index does not exist.")

    if isinstance(which_intervention, int):
        which_intervention = self.interventions[which_intervention]

    mutilated_dag = self.mutilated_dags[which_intervention].copy()
    return nx.to_pandas_adjacency(mutilated_dag, weight=None)

interventional_sample(size, which_intervention=0, additive_gaussian_noise=False, snr=1 / 2, source_df=None)

Draw samples from the interventional distribution that factorizes according to the mutilated DAG after performing one or multiple interventions. Otherwise the method behaves similar to sampling from the non-interventional joint distribution. By default samples are drawn from the first intervention you performed. If you intervened upon more than one node, you'll have to switch to another intervention for sampling from the corresponding interventional distribution.

Parameters:

Name Type Description Default
size int

Number of samples to draw.

required
which_intervention str | int

Which interventional distribution to draw from. We recommend using integer counts starting from zero. But you can also provide the literal string here, e.g. if you intervened on say two nodes x,y then you would need to provide here: do([x, y]).

0
additive_gaussian_noise bool

This will attach additive Gaussian noise to all terms without a RandomSymbol that are not source nodes. It acts merely as a convenience option. Variance will then be chosen according to SNR. Defaults to False.

False
snr None | float

Signal-to-noise ratio \( SNR = \frac{\text{Var}(\hat{X})}{\hat\sigma^2}. \). Defaults to 1/2.

1 / 2
source_df None | DataFrame

Data frame containing source node data. The sample size must be at least as large as the number of samples you'd like to draw. Defaults to None.

None

Raises:

Type Description
NotImplementedError

Raised when which_intervention is not of correct form.

Returns:

Type Description
DataFrame

pd.DataFrame: Data frame with rows of length size and columns equal to the number of nodes in the graph.

Source code in causalAssembly/models_fcm.py
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
def interventional_sample(
    self,
    size: int,
    which_intervention: str | int = 0,
    additive_gaussian_noise: bool = False,
    snr: None | float = 1 / 2,
    source_df: None | pd.DataFrame = None,
) -> pd.DataFrame:
    """Draw samples from the interventional distribution that factorizes
    according to the mutilated DAG after performing one or multiple
    interventions. Otherwise the method behaves similar to sampling from the
    non-interventional joint distribution. By default samples are drawn from the
    first intervention you performed. If you intervened upon more than one node,
    you'll have swith to another intervention for sampling from the corresponding
    interventional distribution.

    Args:
        size (int): Number of samples to draw.
        which_intervention (str | int): Which interventional distribution to draw
            from. We recommend using integer counts starting from zero. But you can
            also provide the literal string here, e.g. if you intervened on say two
            nodes `x,y` then you would need to provide here: `do([x, y])`.
        additive_gaussian_noise (bool, optional): This will attach additive Gaussian noise
            to all terms without a RandomSymbol that are not source nodes. It acts merely as
            a convenience option. Variance will then be chosen according to SNR.
            Defaults to False.
        snr (None | float, optional): Signal-to-noise ratio
            \\( SNR =  \\frac{\\text{Var}(\\hat{X})}{\\hat\\sigma^2}. \\). Defaults to 1/2.
        source_df (None | pd.DataFrame, optional): Data frame containing source node data.
            The sample size must be at least as large as the number of samples
            you'd like to draw. Defaults to None.

    Raises:
        NotImplementedError: Raised when `which_intervention` is not of correct form.

    Returns:
        pd.DataFrame: Data frame with rows of length `size` and columns equal to the
            number of nodes in the graph.
    """
    if isinstance(which_intervention, str):
        int_choice = which_intervention
    elif isinstance(which_intervention, int):
        int_choice = self.interventions[which_intervention]
    else:
        raise NotImplementedError(
            f"which_intervention has to be \
            the literal string or an integer \
            starting at count zero indicating \
            which intervention in {self.interventions} \
            to use."
        )

    return self._sample(
        size=size,
        additive_gaussian_noise=additive_gaussian_noise,
        snr=snr,
        source_df=source_df,
        which_graph=self.mutilated_dags[int_choice],
    )

parents(of_node)

Return parents of node in question.

Parameters:

Name Type Description Default
of_node str

Node in question.

required

Returns:

Type Description
list[Symbol]

list[str]: parent set.

Source code in causalAssembly/models_fcm.py
177
178
179
180
181
182
183
184
185
186
def parents(self, of_node: Symbol) -> list[Symbol]:
    """Return parents of node in question.

    Args:
        of_node (str): Node in question.

    Returns:
        list[str]: parent set.
    """
    return list(self.graph.predecessors(of_node))

parents_of(node, which_graph)

Return parents of node in question for a chosen DAG.

Parameters:

Name Type Description Default
node Symbol

node whose parents to return.

required
which_graph DiGraph

which graph along the interventions.

required

Returns:

Type Description
list[Symbol]

list[Symbol]: list of parents.

Source code in causalAssembly/models_fcm.py
188
189
190
191
192
193
194
195
196
197
198
def parents_of(self, node: Symbol, which_graph: nx.DiGraph) -> list[Symbol]:
    """Return parents of node in question for a chosen DAG.

    Args:
        node (Symbol): node whose parents to return.
        which_graph (nx.DiGraph): which graph along the interventions.

    Returns:
        list[Symbol]: list of parents.
    """
    return list(which_graph.predecessors(node))

sample(size, additive_gaussian_noise=False, snr=1 / 2, source_df=None)

Draw samples from the joint distribution that factorizes according to the DAG implied by the FCM fed in. To avoid unexpected/unintended behavior, avoid defining fully deterministic equation systems. If parameters in noise terms are additive and left unevaluated, they're set according to a chosen Signal-To-Noise (SNR) ratio. For convenience, you can add additive Gaussian noise to each equation. This will never overwrite any of the chosen noise distributions. You may also feed in a data frame for noise distributions (see below for more details).

Parameters:

Name Type Description Default
size int

Number of samples to draw.

required
additive_gaussian_noise bool

This will attach additive Gaussian noise to all terms without a RandomSymbol that are not source nodes. It acts merely as a convenience option. Variance will then be chosen according to SNR. Defaults to False.

False
snr None | float

Signal-to-noise ratio \( SNR = \frac{\text{Var}(\hat{X})}{\hat\sigma^2}. \). Defaults to 1/2.

1 / 2
source_df None | DataFrame

Data frame containing source node data. The sample size must be at least as large as the number of samples you'd like to draw. Defaults to None.

None

Raises:

Type Description
AttributeError

if source node parameters are not given explicitly.

ValueError

if source node sample size is too small.

ValueError

if scale parameters are left unevaluated for non-additive terms.

Returns:

Type Description
DataFrame

pd.DataFrame: Data frame with rows of length size and columns equal to the number of nodes in the graph.

Source code in causalAssembly/models_fcm.py
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
def sample(
    self,
    size: int,
    additive_gaussian_noise: bool = False,
    snr: None | float = 1 / 2,
    source_df: None | pd.DataFrame = None,
) -> pd.DataFrame:
    """Draw samples from the joint distribution that factorizes
    according to the DAG implied by the FCM fed in. To avoid
    unexpected/unintended behavior, avoid defining fully
    deterministic equation systems.
    If parameters in noise terms are additive and left unevaluated,
    they're set according to a chosen Signal-To-Noise (SNR) ratio.
    For convenience, you can add additive Gaussian noise to each equation.
    This will never overwrite any of the chosen noise distributions.
    You may also feed in a data frame for noise distributions (see below
    for more details).

    Args:
        size (int): Number of samples to draw.
        additive_gaussian_noise (bool, optional): This will attach additive
            Gaussian noise to all terms without a RandomSymbol that are not
            source nodes. It acts merely as a convenience option. Variance
            will then be chosen according to SNR. Defaults to False.
        snr (None | float, optional): Signal-to-noise ratio
            \\( SNR =  \\frac{\\text{Var}(\\hat{X})}{\\hat\\sigma^2}. \\).
            Defaults to 1/2.
        source_df (None | pd.DataFrame, optional): Data frame containing source node data.
            The sample size must be at least as large as the number of samples
            you'd like to draw. Defaults to None.

    Raises:
        AttributeError: if source node parameters are not given explicitly.
        ValueError: if source node sample size is too small.
        ValueError: if scale parameters are left unevaluated for non-additive terms.

    Returns:
        pd.DataFrame:  Data frame with rows of length `size` and columns equal to the
            number of nodes in the graph.
    """
    return self._sample(
        size=size,
        additive_gaussian_noise=additive_gaussian_noise,
        snr=snr,
        source_df=source_df,
        which_graph=self.graph,
    )

show(header=None, with_nodenames=True)

Plots the current DAG.

Parameters:

Name Type Description Default
header str | None

Header for the DAG. Defaults to None.

None
with_nodenames bool

Whether or not to use nodenames as labels in the plot. Defaults to True.

True

Returns:

Name Type Description
plt matplotlib

Plot of the DAG.

Source code in causalAssembly/models_fcm.py
647
648
649
650
651
652
653
654
655
656
657
658
659
660
def show(self, header: str | None = None, with_nodenames: bool = True) -> plt:
    """Plots the current DAG.

    Args:
        header (str | None, optional): Header for the DAG. Defaults to None,
            which results in an empty header.
        with_nodenames (bool, optional): Whether or not to use nodenames as
            labels in the plot. Defaults to True.

    Returns:
        plt: Plot of the DAG.
    """
    # Delegate to the shared plotting helper; a missing header becomes "".
    return self._show(
        which_graph=self.graph,
        header="" if header is None else header,
        with_nodenames=with_nodenames,
    )

show_mutilated_dag(which_intervention=0, with_nodenames=True)

Plot mutilated DAG

Parameters:

Name Type Description Default
which_intervention str | int

Which interventional distribution should be represented by a DAG. Defaults to 0.

0
with_nodenames bool

Whether or not to use nodenames as labels in the plot. Defaults to True.

True

Returns:

Name Type Description
plt matplotlib

Plot of the mutilated DAG.

Source code in causalAssembly/models_fcm.py
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
def show_mutilated_dag(
    self, which_intervention: str | int = 0, with_nodenames: bool = True
) -> plt:
    """Plot mutilated DAG.

    Args:
        which_intervention (str | int, optional): Which interventional distribution
            should be represented by a DAG. An integer is interpreted as an index
            into `self.interventions`. Defaults to 0.
        with_nodenames (bool, optional): Whether or not to use nodenames as
            labels in the plot. Defaults to True.

    Returns:
        plt: Plot of the mutilated DAG.
    """
    # Resolve an integer index to the corresponding intervention name.
    key = (
        self.interventions[which_intervention]
        if isinstance(which_intervention, int)
        else which_intervention
    )

    return self._show(
        which_graph=self.mutilated_dags[key],
        header=key,
        with_nodenames=with_nodenames,
    )

source_nodes_of(which_graph)

Returns the source nodes of a chosen graph. This is mainly for choosing different mutilated DAGs.

Parameters:

Name Type Description Default
which_graph DiGraph

DAG from which source nodes should be returned.

required

Returns:

Name Type Description
list list

List of nodes.

Source code in causalAssembly/models_fcm.py
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
def source_nodes_of(self, which_graph: nx.DiGraph) -> list:
    """Returns the source nodes of a chosen graph. This is mainly for
    choosing different mutilated DAGs.

    Args:
        which_graph (nx.DiGraph): DAG from which source nodes should be returned.

    Returns:
        list: List of nodes.
    """
    # A source node is a node without any parents in `which_graph`.
    sources = []
    for candidate in which_graph.nodes:
        if len(self.parents_of(node=candidate, which_graph=which_graph)) == 0:
            sources.append(candidate)
    return sources

Utility classes and functions related to causalAssembly. Copyright (c) 2023 Robert Bosch GmbH

This program is free software: you can redistribute it and/or modify it under the terms of the GNU Affero General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later version. This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Affero General Public License for more details. You should have received a copy of the GNU Affero General Public License along with this program. If not, see https://www.gnu.org/licenses/.

DAGmetrics

Class to calculate performance metrics for DAGs. Make sure that the ground truth and the estimated DAG have the same order of rows/columns. If these objects are nx.DiGraphs, make sure that graph.nodes() have the same order or pass a new nodelist to the class when initiating. The same can be done for pd.DataFrames. In case truth and est are np.ndarray objects it is the user's responsibility to make sure that the objects are indeed comparable.

Source code in causalAssembly/metrics.py
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
class DAGmetrics:
    """Class to calculate performance metrics for DAGs.
    Make sure that the ground truth and the estimated DAG have the same order of
    rows/columns. If these objects are nx.DiGraphs, make sure that graph.nodes()
    have the same order or pass a new nodelist to the class when initiating. The
    same can be done for pd.DataFrames. In case `truth` and `est` are np.ndarray
    objects it is the user's responsibility to make sure that the objects are
    indeed comparable.
    """

    def __init__(
        self,
        truth: nx.DiGraph | pd.DataFrame | np.ndarray,
        est: nx.DiGraph | pd.DataFrame | np.ndarray,
        nodelist: list | None = None,
    ):
        """Store both graphs as numpy adjacency matrices.

        Args:
            truth (nx.DiGraph | pd.DataFrame | np.ndarray): ground-truth DAG.
            est (nx.DiGraph | pd.DataFrame | np.ndarray): estimated DAG.
            nodelist (list | None, optional): common row/column order to enforce
                on both graphs. Defaults to None.

        Raises:
            TypeError: if `truth` or `est` is not one of the permitted classes.
        """
        if not isinstance(truth, nx.DiGraph | pd.DataFrame | np.ndarray):
            raise TypeError("Ground truth graph has to be one of the permitted classes.")

        if not isinstance(est, nx.DiGraph | pd.DataFrame | np.ndarray):
            raise TypeError("Estimated graph has to be one of the permitted classes")

        self.truth = DAGmetrics._convert_to_numpy(truth, nodelist=nodelist)
        self.est = DAGmetrics._convert_to_numpy(est, nodelist=nodelist)

        # Populated by `collect_metrics`.
        self.metrics = None

    def _calculate_scores(self) -> dict[str, float]:
        """Calculate Precision, Recall and F1 and g score.

        Returns:
            dict[str, float]: with entries
                precision: TP/(TP + FP)
                recall: TP/(TP + FN)
                f1: 2*(recall*precision)/(recall+precision)
                gscore: max(0, (TP-FP))/(TP+FN)
        """
        assert self.est.shape == self.truth.shape and self.est.shape[0] == self.est.shape[1]
        # An edge is a true positive iff present in both adjacency matrices.
        TP = np.where((self.est + self.truth) == 2, 1, 0).sum(axis=1).sum()
        TP_FP = self.est.sum(axis=1).sum()
        FP = TP_FP - TP
        TP_FN = self.truth.sum(axis=1).sum()

        # `max(..., 1)` guards against division by zero for edgeless graphs.
        precision = TP / max(TP_FP, 1)
        recall = TP / max(TP_FN, 1)
        F1 = 2 * (recall * precision) / max((recall + precision), 1)
        gscore = max(0, (TP - FP)) / max(TP_FN, 1)

        return {"precision": precision, "recall": recall, "f1": F1, "gscore": gscore}

    def _shd(self, count_anticausal_twice: bool = True):
        """Calculate Structural Hamming Distance (SHD).

        Args:
            count_anticausal_twice (bool, optional): If edge is pointing in the wrong direction
                it's also missing in the right direction and is counted twice. Defaults to True.
        """
        dist = np.abs(self.truth - self.est)
        if count_anticausal_twice:
            return np.sum(dist)
        else:
            # Collapse mismatches of both orientations into a single difference.
            dist = dist + dist.transpose()
            dist[dist > 1] = 1
            return np.sum(dist) / 2

    def collect_metrics(self) -> dict[str, float | int]:
        """Collects all metrics defined in this class in a dict.

        The result is also stored in `self.metrics`.

        Returns:
            dict[str, float|int]: Metrics calculated
        """
        metrics = self._calculate_scores()
        metrics["shd"] = self._shd()
        self.metrics = metrics
        # BUGFIX: previously this method implicitly returned None despite its
        # return annotation and docstring; now it returns the collected metrics.
        return metrics

    @classmethod
    def _convert_to_numpy(
        cls,
        graph: nx.DiGraph | pd.DataFrame | np.ndarray,
        nodelist: list | None = None,
    ) -> np.ndarray:
        """Convert a supported graph representation to a numpy adjacency matrix.

        Args:
            graph (nx.DiGraph | pd.DataFrame | np.ndarray): graph to convert.
            nodelist (list | None, optional): row/column order to enforce;
                ignored for plain arrays. Defaults to None.

        Returns:
            np.ndarray: adjacency matrix.

        Raises:
            TypeError: for unsupported graph types (previously returned None silently).
        """
        if isinstance(graph, np.ndarray):
            return copy.deepcopy(graph)
        elif isinstance(graph, pd.DataFrame):
            if nodelist:
                return copy.deepcopy(graph.reindex(nodelist)[nodelist].to_numpy())
            else:
                return copy.deepcopy(graph.to_numpy())
        elif isinstance(graph, nx.DiGraph):
            return nx.to_numpy_array(G=graph, nodelist=nodelist)
        raise TypeError("Graph has to be one of the permitted classes.")

collect_metrics()

Collects all metrics defined in this class in a dict.

Returns:

Type Description
dict[str, float | int]

dict[str, float|int]: Metrics calculated

Source code in causalAssembly/metrics.py
90
91
92
93
94
95
96
97
98
def collect_metrics(self) -> dict[str, float | int]:
    """Collects all metrics defined in this class in a dict.

    Returns:
        dict[str, float|int]: Metrics calculated
    """
    metrics = self._calculate_scores()
    metrics["shd"] = self._shd()
    self.metrics = metrics

Utility classes and functions related to causalAssembly. Copyright (c) 2023 Robert Bosch GmbH

This program is free software: you can redistribute it and/or modify it under the terms of the GNU Affero General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later version. This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Affero General Public License for more details. You should have received a copy of the GNU Affero General Public License along with this program. If not, see https://www.gnu.org/licenses/.

DRF

Wrapper around the corresponding R package: Distributional Random Forests (Cevid et al., 2020). Closely adopted from their python wrapper.

Source code in causalAssembly/drf_fitting.py
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
class DRF:
    """Wrapper around the corresponding R package:
    Distributional Random Forests (Cevid et al., 2020).
    Closely adopted from their python wrapper."""

    def __init__(self, **fit_params):
        # Keyword arguments are forwarded verbatim to the R `drf` call.
        self.fit_params = fit_params

    def fit(self, X: pd.DataFrame, Y: pd.DataFrame):
        """Fit DRF in order to estimate conditional
        distribution P(Y|X=x).

        Args:
            X (pd.DataFrame): Conditioning set.
            Y (pd.DataFrame): Variable of interest (can be vector-valued).
        """
        self.X_train = X
        self.Y_train = Y

        # Hand the converted frames to the R implementation.
        self.r_fit_object = drf_r_package.drf(
            ro.conversion.py2rpy(X), ro.conversion.py2rpy(Y), **self.fit_params
        )

    def produce_sample(
        self,
        newdata: pd.DataFrame,
        random_state: np.random.Generator,
        n: int = 1,
    ) -> np.ndarray:
        """Sample data from fitted drf.

        Args:
            newdata (pd.DataFrame): Data samples to predict from.
            random_state (np.random.Generator): control random state.
            n (int, optional): Number of n-samples to draw. Defaults to 1.

        Returns:
            np.ndarray: New predicted sample of Y.
        """
        r_output = drf_r_package.predict_drf(
            self.r_fit_object, ro.conversion.py2rpy(newdata)
        )

        # First result slot holds per-row sampling weights over the
        # training responses, the second slot holds the responses.
        weights = base_r_package.as_matrix(r_output[0])
        Y = pd.DataFrame(base_r_package.as_matrix(r_output[1])).apply(pd.Series)

        n_rows, n_train = newdata.shape[0], Y.shape[0]
        sample = np.zeros((n_rows, Y.shape[1], n))
        for row in range(n_rows):
            for draw in range(n):
                pick = random_state.choice(range(n_train), 1, p=weights[row, :])[0]
                sample[row, :, draw] = Y.iloc[pick, :]

        # NOTE(review): only the first response coordinate of the first
        # draw is returned; remaining draws/columns are discarded — confirm
        # this is intended for vector-valued Y or n > 1.
        return sample[:, 0, 0]

fit(X, Y)

Fit DRF in order to estimate conditional distribution P(Y|X=x).

Parameters:

Name Type Description Default
X DataFrame

Conditioning set.

required
Y DataFrame

Variable of interest (can be vector-valued).

required
Source code in causalAssembly/drf_fitting.py
42
43
44
45
46
47
48
49
50
51
52
53
54
55
def fit(self, X: pd.DataFrame, Y: pd.DataFrame):
    """Fit DRF in order to estimate conditional
    distribution P(Y|X=x).

    Args:
        X (pd.DataFrame): Conditioning set.
        Y (pd.DataFrame): Variable of interest (can be vector-valued).
    """
    self.X_train = X
    self.Y_train = Y

    # Convert both frames to R objects and fit via the R package.
    self.r_fit_object = drf_r_package.drf(
        ro.conversion.py2rpy(X), ro.conversion.py2rpy(Y), **self.fit_params
    )

produce_sample(newdata, random_state, n=1)

Sample data from fitted drf.

Parameters:

Name Type Description Default
newdata DataFrame

Data samples to predict from.

required
random_state Generator

control random state.

required
n int

Number of n-samples to draw. Defaults to 1.

1

Returns:

Type Description
ndarray

np.ndarray: New predicted samlpe of Y.

Source code in causalAssembly/drf_fitting.py
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
def produce_sample(
    self,
    newdata: pd.DataFrame,
    random_state: np.random.Generator,
    n: int = 1,
) -> np.ndarray:
    """Sample data from fitted drf.

    Args:
        newdata (pd.DataFrame): Data samples to predict from.
        random_state (np.random.Generator): control random state.
        n (int, optional): Number of n-samples to draw. Defaults to 1.

    Returns:
        np.ndarray: New predicted sample of Y.
    """
    r_output = drf_r_package.predict_drf(
        self.r_fit_object, ro.conversion.py2rpy(newdata)
    )

    # Per-row sampling weights over the training responses.
    weights = base_r_package.as_matrix(r_output[0])
    Y = pd.DataFrame(base_r_package.as_matrix(r_output[1])).apply(pd.Series)

    n_rows, n_train = newdata.shape[0], Y.shape[0]
    sample = np.zeros((n_rows, Y.shape[1], n))
    for row in range(n_rows):
        for draw in range(n):
            pick = random_state.choice(range(n_train), 1, p=weights[row, :])[0]
            sample[row, :, draw] = Y.iloc[pick, :]

    # NOTE(review): only the first coordinate of the first draw is
    # returned; remaining draws/columns are discarded — confirm intended.
    return sample[:, 0, 0]

fit_drf(graph, data)

Fit distributional random forests to the factorization implied by the current graph. Args: data (pd.DataFrame): Columns of the dataframe need to match the name and order of the graph's nodes.

Raises:

Type Description
ValueError

Raises error if columns don't meet this requirement

Returns:

Type Description
dict

dict of fitted DRFs.

Source code in causalAssembly/drf_fitting.py
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
def fit_drf(graph: "ProductionLineGraph | ProcessCell | DAG", data: pd.DataFrame) -> dict:
    """Fit distributional random forests to the
    factorization implied by the current graph.

    Source nodes (nodes without parents) are fitted with a Gaussian kernel
    density estimate; every other node gets a DRF conditional on its parents.

    Args:
        graph (ProductionLineGraph | ProcessCell | DAG): Graph whose nodes and
            parent sets define the factorization.
        data (pd.DataFrame): Columns of dataframe need to match name and order of the graph

    Raises:
        ValueError: Raises error if columns don't meet this requirement

    Returns:
        (dict): dict of fitted DRFs.
    """
    tempdata = data.copy()

    if not set(graph.nodes).issubset(tempdata.columns):
        raise ValueError("Data columns don't match node names.")
    # Reorder the columns to match the graph's node order.
    tempdata = tempdata[graph.nodes]

    drf_dict = {}
    for node in graph.nodes:
        parents = graph.parents(of_node=node)
        if not parents:
            # Source node: marginal density via Gaussian KDE.
            drf_dict[node] = gaussian_kde(tempdata[node].to_numpy())
        else:
            # Non-source node: conditional distribution P(node | parents).
            # Default settings as suggested in the DRF paper.
            drf_object = DRF(min_node_size=15, num_trees=2000, splitting_rule="FourierMMD")
            drf_object.fit(tempdata[parents], tempdata[node])
            drf_dict[node] = drf_object
    return drf_dict

Utility classes and functions related to causalAssembly. Copyright (c) 2023 Robert Bosch GmbH

This program is free software: you can redistribute it and/or modify it under the terms of the GNU Affero General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later version. This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Affero General Public License for more details. You should have received a copy of the GNU Affero General Public License along with this program. If not, see https://www.gnu.org/licenses/.

PDAG

Class for dealing with partially directed graphs, i.e. graphs that contain both directed and undirected edges.

Source code in causalAssembly/pdag.py
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
class PDAG:
    """
    Class for dealing with partially directed graph i.e.
    graphs that contain both directed and undirected edges.
    """

    def __init__(
        self,
        nodes: list | None = None,
        dir_edges: list[tuple] | None = None,
        undir_edges: list[tuple] | None = None,
    ):
        """Set up the internal node and edge bookkeeping.

        Args:
            nodes (list | None, optional): nodes of the graph; endpoints of
                edges are added automatically. Defaults to None.
            dir_edges (list[tuple] | None, optional): directed edges given as
                (tail, head) tuples. Defaults to None.
            undir_edges (list[tuple] | None, optional): undirected edges given
                as (i, j) tuples. Defaults to None.
        """
        # Avoid mutable default arguments.
        if nodes is None:
            nodes = []
        if dir_edges is None:
            dir_edges = []
        if undir_edges is None:
            undir_edges = []

        self._nodes = set(nodes)
        self._undir_edges = set()
        self._dir_edges = set()
        self._parents = defaultdict(set)
        self._children = defaultdict(set)
        self._neighbors = defaultdict(set)
        self._undirected_neighbors = defaultdict(set)

        # Populate edge sets and all adjacency records.
        for dir_edge in dir_edges:
            self._add_dir_edge(*dir_edge)
        for unir_edge in undir_edges:
            self._add_undir_edge(*unir_edge)

    def _add_dir_edge(self, i, j):
        """Insert the directed edge i -> j and register both endpoints."""
        self._nodes.add(i)
        self._nodes.add(j)
        self._dir_edges.add((i, j))

        # A directed edge makes the endpoints neighbors of each other ...
        self._neighbors[i].add(j)
        self._neighbors[j].add(i)

        # ... and establishes the parent/child relation.
        self._children[i].add(j)
        self._parents[j].add(i)

    def _add_undir_edge(self, i, j):
        """Insert the undirected edge i - j and register both endpoints."""
        self._nodes.add(i)
        self._nodes.add(j)
        self._undir_edges.add((i, j))

        self._neighbors[i].add(j)
        self._neighbors[j].add(i)

        # Undirected adjacency is additionally tracked separately.
        self._undirected_neighbors[i].add(j)
        self._undirected_neighbors[j].add(i)

    def children(self, node: str) -> set:
        """Gives all children of node `node`.

        Args:
            node (str): node in current PDAG.

        Returns:
            set: set of children.
        """
        if node in self._children.keys():
            return self._children[node]
        else:
            return set()

    def parents(self, node: str) -> set:
        """Gives all parents of node `node`.

        Args:
            node (str): node in current PDAG.

        Returns:
            set: set of parents.
        """
        if node in self._parents.keys():
            return self._parents[node]
        else:
            return set()

    def neighbors(self, node: str) -> set:
        """Gives all neighbors of node `node`.

        Args:
            node (str): node in current PDAG.

        Returns:
            set: set of neighbors.
        """
        if node in self._neighbors.keys():
            return self._neighbors[node]
        else:
            return set()

    def undir_neighbors(self, node: str) -> set:
        """Gives all undirected neighbors
        of node `node`.

        Args:
            node (str): node in current PDAG.

        Returns:
            set: set of undirected neighbors.
        """
        if node in self._undirected_neighbors.keys():
            return self._undirected_neighbors[node]
        else:
            return set()

    def is_adjacent(self, i: str, j: str) -> bool:
        """Return True if the graph contains an directed
        or undirected edge between i and j.

        Args:
            i (str): node i.
            j (str): node j.

        Returns:
            bool: True if i-j or i->j or i<-j
        """
        return any(
            (
                (j, i) in self.dir_edges or (j, i) in self.undir_edges,
                (i, j) in self.dir_edges or (i, j) in self.undir_edges,
            )
        )

    def is_clique(self, potential_clique: set) -> bool:
        """
        Check that every pair of nodes in `potential_clique` is adjacent
        (vacuously True for fewer than two nodes).
        """
        return all(self.is_adjacent(i, j) for i, j in combinations(potential_clique, 2))

    @classmethod
    def from_pandas_adjacency(cls, pd_amat: pd.DataFrame) -> PDAG:
        """Build PDAG from a Pandas adjacency matrix.

        A nonzero (i, j) entry encodes an edge i -> j; if both (i, j) and
        (j, i) are nonzero, the pair is treated as one undirected edge.

        Args:
            pd_amat (pd.DataFrame): input adjacency matrix.

        Returns:
            PDAG
        """
        assert pd_amat.shape[0] == pd_amat.shape[1]
        nodes = pd_amat.columns

        # Collect every nonzero entry as a (from, to) node pair.
        all_connections = []
        start, end = np.where(pd_amat != 0)
        for idx, _ in enumerate(start):
            all_connections.append((pd_amat.columns[start[idx]], pd_amat.columns[end[idx]]))

        # Pairs whose reversed counterpart is also present are undirected.
        temp = [set(i) for i in all_connections]
        temp2 = [arc for arc in all_connections if temp.count(set(arc)) > 1]
        # Deduplicate (i, j)/(j, i) into a single undirected edge each.
        undir_edges = [tuple(item) for item in set(frozenset(item) for item in temp2)]

        dir_edges = [edge for edge in all_connections if edge not in temp2]

        return PDAG(nodes=nodes, dir_edges=dir_edges, undir_edges=undir_edges)

    def remove_edge(self, i: str, j: str):
        """Removes edge in question

        Args:
            i (str): tail
            j (str): head

        Raises:
            AssertionError: if edge does not exist
        """
        if (i, j) not in self.dir_edges and (i, j) not in self.undir_edges:
            raise AssertionError("Edge does not exist in current PDAG")

        self._undir_edges.discard((i, j))
        self._dir_edges.discard((i, j))
        self._children[i].discard(j)
        self._parents[j].discard(i)
        self._neighbors[i].discard(j)
        self._neighbors[j].discard(i)
        self._undirected_neighbors[i].discard(j)
        self._undirected_neighbors[j].discard(i)

    def undir_to_dir_edge(self, tail: str, head: str):
        """Takes a undirected edge and turns it into a directed one.
        tail indicates the starting node of the edge and head the end node, i.e.
        tail -> head.

        Args:
            tail (str): starting node
            head (str): end node

        Raises:
            AssertionError: if edge does not exist or is not undirected.
        """
        if (tail, head) not in self.undir_edges and (
            head,
            tail,
        ) not in self.undir_edges:
            raise AssertionError("Edge seems not to be undirected or even there at all.")
        self._undir_edges.discard((tail, head))
        self._undir_edges.discard((head, tail))
        self._neighbors[tail].discard(head)
        self._neighbors[head].discard(tail)
        self._undirected_neighbors[tail].discard(head)
        self._undirected_neighbors[head].discard(tail)

        self._add_dir_edge(i=tail, j=head)

    def remove_node(self, node):
        """Remove a node from the graph

        Deletes `node` together with all directed and undirected edges
        incident to it and cleans up all adjacency bookkeeping.
        """
        self._nodes.remove(node)

        # Rebuild the edge sets without any edge touching `node`.
        self._dir_edges = {(i, j) for i, j in self._dir_edges if i != node and j != node}

        self._undir_edges = {(i, j) for i, j in self._undir_edges if i != node and j != node}

        # Unlink `node` from the adjacency records of all incident nodes.
        for child in self._children[node]:
            self._parents[child].remove(node)
            self._neighbors[child].remove(node)

        for parent in self._parents[node]:
            self._children[parent].remove(node)
            self._neighbors[parent].remove(node)

        for u_nbr in self._undirected_neighbors[node]:
            self._undirected_neighbors[u_nbr].remove(node)
            self._neighbors[u_nbr].remove(node)

        # Finally drop the node's own entries; the default argument makes
        # `pop` a no-op when the key is absent.
        self._parents.pop(node, "I was never here")
        self._children.pop(node, "I was never here")
        self._neighbors.pop(node, "I was never here")
        self._undirected_neighbors.pop(node, "I was never here")

    def to_dag(self) -> nx.DiGraph:
        """
        Algorithm as described in Chickering (2002):

            1. From PDAG P create DAG G containing all directed edges from P
            2. Repeat the following: Select node v in P s.t.
                i. v has no outgoing edges (children) i.e. \\(ch(v) = \\emptyset \\)

                ii. \\(neigh(v) \\neq \\emptyset\\)
                    Then \\( (pa(v) \\cup (neigh(v) \\) form a clique.
                    For each v that is in a clique and is part of an undirected edge in P
                    i.e. w - v, insert a directed edge w -> v in G.
                    Remove v and all incident edges from P and continue with next node.
                    Until all nodes have been deleted from P.

        Returns:
            nx.DiGraph: DAG that belongs to the MEC implied by the PDAG
        """

        # Work on a copy so this PDAG itself stays untouched.
        pdag = self.copy()

        # Seed the DAG with all nodes and the already-directed edges.
        dag = nx.DiGraph()
        dag.add_nodes_from(pdag.nodes)
        dag.add_edges_from(pdag.dir_edges)

        if pdag.num_undir_edges == 0:
            return dag
        else:
            while pdag.nnodes > 0:
                # find node with (1) no directed outgoing edges and
                #                (2) the set of undirected neighbors is either empty or
                #                    undirected neighbors + parents of X are a clique
                found = False
                for node in pdag.nodes:
                    children = pdag.children(node)
                    neighbors = pdag.neighbors(node)
                    parents = pdag.parents(node)
                    potential_clique_members = neighbors.union(parents)

                    is_clique = pdag.is_clique(potential_clique_members)

                    if not children and (not neighbors or is_clique):
                        found = True
                        # add all edges of node as outgoing edges to dag
                        for edge in pdag.undir_edges:
                            if node in edge:
                                # Orient the undirected edge towards `node`.
                                incident_node = set(edge) - {node}
                                dag.add_edge(*incident_node, node)

                        pdag.remove_node(node)
                        break

                if not found:
                    # Not extendible: fall back to a random DAG on the skeleton.
                    logger.warning("PDAG not extendible: Random DAG on skeleton drawn.")

                    dag = nx.from_pandas_adjacency(self._amat_to_dag(), create_using=nx.DiGraph)

                    break

            return dag

    @property
    def adjacency_matrix(self) -> pd.DataFrame:
        """Returns adjacency matrix where the i,jth
        entry being one indicates that there is an edge
        from i to j. A zero indicates that there is no edge.

        Returns:
            pd.DataFrame: adjacency matrix
        """
        amat = pd.DataFrame(
            np.zeros([self.nnodes, self.nnodes]),
            index=self.nodes,
            columns=self.nodes,
        )
        # Directed edges get a single asymmetric entry ...
        for edge in self.dir_edges:
            amat.loc[edge] = 1
        # ... while undirected edges are encoded symmetrically.
        for edge in self.undir_edges:
            amat.loc[edge] = amat.loc[edge[::-1]] = 1
        return amat

    def _amat_to_dag(self) -> pd.DataFrame:
        """Transform the adjacency matrix of an PDAG to the adjacency
        matrix of a SOME DAG in the Markov equivalence class.

        A random node order is drawn, every edge of the skeleton is
        oriented along that order (guaranteeing acyclicity), and the
        result is mapped back to the original node order.

        Returns:
            pd.DataFrame: DAG, a member of the MEC.
        """
        pdag_amat = self.adjacency_matrix.to_numpy()

        p = pdag_amat.shape[0]
        ## amat to skel: symmetrize and clip to get the undirected skeleton
        skel = pdag_amat + pdag_amat.T
        skel[np.where(skel > 1)] = 1
        ## permute skel with a random node order
        # NOTE(review): uses the global numpy RNG, so results are only
        # reproducible via np.random.seed — confirm this is intended.
        permute_ord = np.random.choice(a=p, size=p, replace=False)
        skel = skel[:, permute_ord][permute_ord]

        ## skel to dag: zero the lower triangle so every remaining edge
        ## points from earlier to later in the random order
        for i in range(1, p):
            for j in range(0, i + 1):
                if skel[i, j] == 1:
                    skel[i, j] = 0

        ## inverse permutation
        # BUGFIX: the inverse permutation is argsort(permute_ord);
        # np.sort(permute_ord) is just arange(p), which left the matrix in
        # permuted order while relabeling it with the original node names,
        # scrambling the skeleton.
        i_ord = np.argsort(permute_ord)
        skel = skel[:, i_ord][i_ord]
        return pd.DataFrame(
            skel,
            index=self.adjacency_matrix.index,
            columns=self.adjacency_matrix.columns,
        )

    def vstructs(self) -> set:
        """Retrieve v-structures

        Returns:
            set: set of all v-structures
        """
        vstructures = set()
        for node in self._nodes:
            # Two parents of `node` without a directed edge between them
            # form a v-structure p1 -> node <- p2.
            for p1, p2 in combinations(self._parents[node], 2):
                if p1 not in self._parents[p2] and p2 not in self._parents[p1]:
                    # NOTE(review): only parent relations are checked here; an
                    # undirected edge p1 - p2 would not rule the pair out —
                    # confirm this is the intended notion of adjacency.
                    vstructures.add((p1, node))
                    vstructures.add((p2, node))
        return vstructures

    def copy(self):
        """Return a copy of the graph

        The constructor rebuilds all internal sets from the passed
        collections, so the copy shares no mutable state with the original.
        """
        return PDAG(nodes=self._nodes, dir_edges=self._dir_edges, undir_edges=self._undir_edges)

    def show(self):
        """Plot PDAG."""
        graph = self.to_networkx()
        # Circular layout gives a deterministic node arrangement.
        pos = nx.circular_layout(graph)
        nx.draw(graph, pos=pos, with_labels=True)

    def to_networkx(self) -> nx.MultiDiGraph:
        """Convert to networkx graph.

        Undirected edges are represented by a pair of reciprocal directed
        edges, hence the MultiDiGraph return type.

        Returns:
            nx.MultiDiGraph: Graph with directed and undirected edges.
        """
        nx_pdag = nx.MultiDiGraph()
        nx_pdag.add_nodes_from(self.nodes)
        nx_pdag.add_edges_from(self.dir_edges)
        # Encode each undirected edge i - j as i -> j plus j -> i.
        for edge in self.undir_edges:
            nx_pdag.add_edge(*edge)
            nx_pdag.add_edge(*edge[::-1])

        return nx_pdag

    def _meek_mec_enumeration(self, pdag: PDAG, dag_list: list):
        """Recursion algorithm which recursively applies the
        following steps:
            1. Orient the first undirected edge found.
            2. Apply Meek rules.
            3. Recurse with each direction of the oriented edge.
        This corresponds to Algorithm 2 in Wienöbst et al. (2023).

        Args:
            pdag (PDAG): partially directed graph in question.
            dag_list (list): list of currently found DAGs (mutated in place).

        References:
            Wienöbst, Marcel, et al. "Efficient enumeration of Markov equivalent DAGs."
            Proceedings of the AAAI Conference on Artificial Intelligence.
            Vol. 37. No. 10. 2023.
        """
        g_copy = pdag.copy()
        g_copy = self._apply_meek_rules(g_copy)  # Apply Meek rules

        # NOTE(review): indexing with [0] assumes the `undir_edges` property
        # returns a sequence (not a raw set) — confirm against its definition.
        undir_edges = g_copy.undir_edges
        if undir_edges:
            i, j = undir_edges[0]  # Take first undirected edge

        if not g_copy.undir_edges:
            # makes sure that floating nodes are preserved
            new_member = nx.DiGraph()
            new_member.add_nodes_from(g_copy.nodes)
            new_member.add_edges_from(g_copy.dir_edges)
            dag_list.append(new_member)
            return  # Add DAG to current list

        # Recursion first orientation:
        g_copy.undir_to_dir_edge(i, j)
        self._meek_mec_enumeration(pdag=g_copy, dag_list=dag_list)
        g_copy.remove_edge(i, j)

        # Recursion second orientation
        g_copy._add_dir_edge(j, i)
        self._meek_mec_enumeration(pdag=g_copy, dag_list=dag_list)

    def to_allDAGs(self) -> list[nx.DiGraph]:
        """Enumerate every DAG in the Markov equivalence class of this PDAG.

        Recursively orients the first undirected edge found, applies the
        Meek rules, and recurses into both orientations of that edge.
        This corresponds to Algorithm 2 in Wienöbst et al. (2023).

        References:
            Wienöbst, Marcel, et al. "Efficient enumeration of Markov equivalent DAGs."
            Proceedings of the AAAI Conference on Artificial Intelligence.
            Vol. 37. No. 10. 2023.
        """
        enumeration: list[nx.DiGraph] = []
        self._meek_mec_enumeration(pdag=self, dag_list=enumeration)
        return enumeration

    # use Meek's cpdag2alldag
    def _apply_meek_rules(self, G: PDAG) -> PDAG:
        """Close a PDAG under Meek's four orientation
        rules, turning it into a CPDAG.

        Args:
            G (PDAG): PDAG to complete

        Returns:
            PDAG: completed PDAG.
        """
        completed = G.copy()
        # Apply Meek Rules 1-4 in order
        for meek_rule in (rule_1, rule_2, rule_3, rule_4):
            completed = meek_rule(pdag=completed)
        return completed

    def to_random_dag(self) -> nx.DiGraph:
        """Draw a random DAG from the MEC represented by this PDAG.

        Repeatedly picks a random undirected edge, orients it in a random
        direction, and re-applies the Meek rules until no undirected
        edges remain.

        Returns:
            nx.DiGraph: random DAG living in MEC
        """
        candidate = self.copy()

        while candidate.num_undir_edges > 0:
            # Sample one of the remaining undirected edges uniformly.
            picked = candidate.undir_edges[np.random.choice(candidate.num_undir_edges)]
            # Flip a fair coin for its orientation.
            both_ways = [picked, picked[::-1]]
            node_i, node_j = both_ways[np.random.choice(len(both_ways))]

            candidate.undir_to_dir_edge(tail=node_i, head=node_j)
            candidate = candidate._apply_meek_rules(G=candidate)

        return nx.from_pandas_adjacency(candidate.adjacency_matrix, create_using=nx.DiGraph)

    @property
    def nodes(self) -> list:
        """All nodes of the current PDAG in sorted order.

        Returns:
            list: sorted list of node labels.
        """
        return sorted(self._nodes)

    @property
    def nnodes(self) -> int:
        """Total number of nodes in the current PDAG.

        Returns:
            int: node count.
        """
        return len(self._nodes)

    @property
    def num_undir_edges(self) -> int:
        """Count of undirected edges in the current PDAG.

        Returns:
            int: undirected edge count.
        """
        return len(self._undir_edges)

    @property
    def num_dir_edges(self) -> int:
        """Count of directed edges in the current PDAG.

        Returns:
            int: directed edge count.
        """
        return len(self._dir_edges)

    @property
    def num_adjacencies(self) -> int:
        """Total number of adjacent node pairs (directed
        plus undirected edges) in the current PDAG.

        Returns:
            int: number of adjacencies.
        """
        return len(self._undir_edges) + len(self._dir_edges)

    @property
    def undir_edges(self) -> list[tuple]:
        """All undirected edges of the current PDAG.

        Returns:
            list[tuple]: undirected edges as tuples.
        """
        return [*self._undir_edges]

    @property
    def dir_edges(self) -> list[tuple]:
        """All directed edges of the current PDAG.

        Returns:
            list[tuple]: directed edges as (tail, head) tuples.
        """
        return [*self._dir_edges]

adjacency_matrix: pd.DataFrame property

Returns adjacency matrix where the i,jth entry being one indicates that there is an edge from i to j. A zero indicates that there is no edge.

Returns:

Type Description
DataFrame

pd.DataFrame: adjacency matrix

dir_edges: list[tuple] property

Gives all directed edges in current PDAG.

Returns:

Type Description
list[tuple]

list[tuple]: List of directed edges.

nnodes: int property

Number of nodes in current PDAG.

Returns:

Name Type Description
int int

Number of nodes

nodes: list property

Get all nodes in current PDAG.

Returns:

Name Type Description
list list

list of nodes.

num_adjacencies: int property

Number of adjacent nodes in current PDAG.

Returns:

Name Type Description
int int

Number of adjacent nodes

num_dir_edges: int property

Number of directed edges in current PDAG.

Returns:

Name Type Description
int int

Number of directed edges

num_undir_edges: int property

Number of undirected edges in current PDAG.

Returns:

Name Type Description
int int

Number of undirected edges

undir_edges: list[tuple] property

Gives all undirected edges in current PDAG.

Returns:

Type Description
list[tuple]

list[tuple]: List of undirected edges.

children(node)

Gives all children of node node.

Parameters:

Name Type Description Default
node str

node in current PDAG.

required

Returns:

Name Type Description
set set

set of children.

Source code in causalAssembly/pdag.py
83
84
85
86
87
88
89
90
91
92
93
94
95
def children(self, node: str) -> set:
    """Gives all children of node `node`.

    Args:
        node (str): node in current PDAG.

    Returns:
        set: set of children.
    """
    if node in self._children.keys():
        return self._children[node]
    else:
        return set()

copy()

Return a copy of the graph

Source code in causalAssembly/pdag.py
391
392
393
def copy(self):
    """Return a copy of the graph"""
    return PDAG(nodes=self._nodes, dir_edges=self._dir_edges, undir_edges=self._undir_edges)

from_pandas_adjacency(pd_amat) classmethod

Build PDAG from a Pandas adjacency matrix.

Parameters:

Name Type Description Default
pd_amat DataFrame

input adjacency matrix.

required

Returns:

Type Description
PDAG

PDAG

Source code in causalAssembly/pdag.py
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
@classmethod
def from_pandas_adjacency(cls, pd_amat: pd.DataFrame) -> PDAG:
    """Build PDAG from a Pandas adjacency matrix.

    Args:
        pd_amat (pd.DataFrame): input adjacency matrix.

    Returns:
        PDAG
    """
    assert pd_amat.shape[0] == pd_amat.shape[1]
    nodes = pd_amat.columns

    all_connections = []
    start, end = np.where(pd_amat != 0)
    for idx, _ in enumerate(start):
        all_connections.append((pd_amat.columns[start[idx]], pd_amat.columns[end[idx]]))

    temp = [set(i) for i in all_connections]
    temp2 = [arc for arc in all_connections if temp.count(set(arc)) > 1]
    undir_edges = [tuple(item) for item in set(frozenset(item) for item in temp2)]

    dir_edges = [edge for edge in all_connections if edge not in temp2]

    return PDAG(nodes=nodes, dir_edges=dir_edges, undir_edges=undir_edges)

is_adjacent(i, j)

Return True if the graph contains a directed or undirected edge between i and j.

Parameters:

Name Type Description Default
i str

node i.

required
j str

node j.

required

Returns:

Name Type Description
bool bool

True if i-j or i->j or i<-j

Source code in causalAssembly/pdag.py
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
def is_adjacent(self, i: str, j: str) -> bool:
    """Return True if the graph contains an directed
    or undirected edge between i and j.

    Args:
        i (str): node i.
        j (str): node j.

    Returns:
        bool: True if i-j or i->j or i<-j
    """
    return any(
        (
            (j, i) in self.dir_edges or (j, i) in self.undir_edges,
            (i, j) in self.dir_edges or (i, j) in self.undir_edges,
        )
    )

is_clique(potential_clique)

Check every pair of node X potential_clique is adjacent.

Source code in causalAssembly/pdag.py
158
159
160
161
162
def is_clique(self, potential_clique: set) -> bool:
    """
    Check every pair of node X potential_clique is adjacent.
    """
    return all(self.is_adjacent(i, j) for i, j in combinations(potential_clique, 2))

neighbors(node)

Gives all neighbors of node node.

Parameters:

Name Type Description Default
node str

node in current PDAG.

required

Returns:

Name Type Description
set set

set of neighbors.

Source code in causalAssembly/pdag.py
111
112
113
114
115
116
117
118
119
120
121
122
123
def neighbors(self, node: str) -> set:
    """Gives all neighbors of node `node`.

    Args:
        node (str): node in current PDAG.

    Returns:
        set: set of neighbors.
    """
    if node in self._neighbors.keys():
        return self._neighbors[node]
    else:
        return set()

parents(node)

Gives all parents of node node.

Parameters:

Name Type Description Default
node str

node in current PDAG.

required

Returns:

Name Type Description
set set

set of parents.

Source code in causalAssembly/pdag.py
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
def parents(self, node: str) -> set:
    """Gives all parents of node `node`.

    Args:
        node (str): node in current PDAG.

    Returns:
        set: set of parents.
    """
    if node in self._parents.keys():
        return self._parents[node]
    else:
        return set()

remove_edge(i, j)

Removes edge in question

Parameters:

Name Type Description Default
i str

tail

required
j str

head

required

Raises:

Type Description
AssertionError

if edge does not exist

Source code in causalAssembly/pdag.py
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
def remove_edge(self, i: str, j: str):
    """Removes edge in question

    Args:
        i (str): tail
        j (str): head

    Raises:
        AssertionError: if edge does not exist
    """
    if (i, j) not in self.dir_edges and (i, j) not in self.undir_edges:
        raise AssertionError("Edge does not exist in current PDAG")

    self._undir_edges.discard((i, j))
    self._dir_edges.discard((i, j))
    self._children[i].discard(j)
    self._parents[j].discard(i)
    self._neighbors[i].discard(j)
    self._neighbors[j].discard(i)
    self._undirected_neighbors[i].discard(j)
    self._undirected_neighbors[j].discard(i)

remove_node(node)

Remove a node from the graph

Source code in causalAssembly/pdag.py
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
def remove_node(self, node):
    """Remove a node from the graph"""
    self._nodes.remove(node)

    self._dir_edges = {(i, j) for i, j in self._dir_edges if i != node and j != node}

    self._undir_edges = {(i, j) for i, j in self._undir_edges if i != node and j != node}

    for child in self._children[node]:
        self._parents[child].remove(node)
        self._neighbors[child].remove(node)

    for parent in self._parents[node]:
        self._children[parent].remove(node)
        self._neighbors[parent].remove(node)

    for u_nbr in self._undirected_neighbors[node]:
        self._undirected_neighbors[u_nbr].remove(node)
        self._neighbors[u_nbr].remove(node)

    self._parents.pop(node, "I was never here")
    self._children.pop(node, "I was never here")
    self._neighbors.pop(node, "I was never here")
    self._undirected_neighbors.pop(node, "I was never here")

show()

Plot PDAG.

Source code in causalAssembly/pdag.py
395
396
397
398
399
def show(self):
    """Plot PDAG."""
    graph = self.to_networkx()
    pos = nx.circular_layout(graph)
    nx.draw(graph, pos=pos, with_labels=True)

to_allDAGs()

Recursion algorithm which recursively applies the following steps: 1. Orient the first undirected edge found. 2. Apply Meek rules. 3. Recurse with each direction of the oriented edge. This corresponds to Algorithm 2 in Wienöbst et al. (2023).

References

Wienöbst, Marcel, et al. "Efficient enumeration of Markov equivalent DAGs." Proceedings of the AAAI Conference on Artificial Intelligence. Vol. 37. No. 10. 2023.

Source code in causalAssembly/pdag.py
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
def to_allDAGs(self) -> list[nx.DiGraph]:
    """Recursion algorithm which recursively applies the
    following steps:
        1. Orient the first undirected edge found.
        2. Apply Meek rules.
        3. Recurse with each direction of the oriented edge.
    This corresponds to Algorithm 2 in Wienöbst et al. (2023).

    References:
        Wienöbst, Marcel, et al. "Efficient enumeration of Markov equivalent DAGs."
        Proceedings of the AAAI Conference on Artificial Intelligence.
        Vol. 37. No. 10. 2023.
    """
    all_dags = []
    self._meek_mec_enumeration(pdag=self, dag_list=all_dags)
    return all_dags

to_dag()

Algorithm as described in Chickering (2002):

1. From PDAG P create DAG G containing all directed edges from P
2. Repeat the following: Select node v in P s.t.
    i. v has no outgoing edges (children) i.e. \(ch(v) = \emptyset \)

    ii. \(neigh(v) \neq \emptyset\)
        Then \( (pa(v) \cup (neigh(v) \) form a clique.
        For each v that is in a clique and is part of an undirected edge in P
        i.e. w - v, insert a directed edge w -> v in G.
        Remove v and all incident edges from P and continue with next node.
        Until all nodes have been deleted from P.

Returns:

Type Description
DiGraph

nx.DiGraph: DAG that belongs to the MEC implied by the PDAG

Source code in causalAssembly/pdag.py
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
def to_dag(self) -> nx.DiGraph:
    """
    Algorithm as described in Chickering (2002):

        1. From PDAG P create DAG G containing all directed edges from P
        2. Repeat the following: Select node v in P s.t.
            i. v has no outgoing edges (children) i.e. \\(ch(v) = \\emptyset \\)

            ii. \\(neigh(v) \\neq \\emptyset\\)
                Then \\( (pa(v) \\cup (neigh(v) \\) form a clique.
                For each v that is in a clique and is part of an undirected edge in P
                i.e. w - v, insert a directed edge w -> v in G.
                Remove v and all incident edges from P and continue with next node.
                Until all nodes have been deleted from P.

    Returns:
        nx.DiGraph: DAG that belongs to the MEC implied by the PDAG
    """

    pdag = self.copy()

    dag = nx.DiGraph()
    dag.add_nodes_from(pdag.nodes)
    dag.add_edges_from(pdag.dir_edges)

    if pdag.num_undir_edges == 0:
        return dag
    else:
        while pdag.nnodes > 0:
            # find node with (1) no directed outgoing edges and
            #                (2) the set of undirected neighbors is either empty or
            #                    undirected neighbors + parents of X are a clique
            found = False
            for node in pdag.nodes:
                children = pdag.children(node)
                neighbors = pdag.neighbors(node)
                # pdag._undirected_neighbors[node]
                parents = pdag.parents(node)
                potential_clique_members = neighbors.union(parents)

                is_clique = pdag.is_clique(potential_clique_members)

                if not children and (not neighbors or is_clique):
                    found = True
                    # add all edges of node as outgoing edges to dag
                    for edge in pdag.undir_edges:
                        if node in edge:
                            incident_node = set(edge) - {node}
                            dag.add_edge(*incident_node, node)

                    pdag.remove_node(node)
                    break

            if not found:
                logger.warning("PDAG not extendible: Random DAG on skeleton drawn.")

                dag = nx.from_pandas_adjacency(self._amat_to_dag(), create_using=nx.DiGraph)

                break

        return dag

to_networkx()

Convert to networkx graph.

Returns:

Type Description
MultiDiGraph

nx.MultiDiGraph: Graph with directed and undirected edges.

Source code in causalAssembly/pdag.py
401
402
403
404
405
406
407
408
409
410
411
412
413
414
def to_networkx(self) -> nx.MultiDiGraph:
    """Convert to networkx graph.

    Returns:
        nx.MultiDiGraph: Graph with directed and undirected edges.
    """
    nx_pdag = nx.MultiDiGraph()
    nx_pdag.add_nodes_from(self.nodes)
    nx_pdag.add_edges_from(self.dir_edges)
    for edge in self.undir_edges:
        nx_pdag.add_edge(*edge)
        nx_pdag.add_edge(*edge[::-1])

    return nx_pdag

to_random_dag()

Provides a random DAG residing in the MEC.

Returns:

Type Description
DiGraph

nx.DiGraph: random DAG living in MEC

Source code in causalAssembly/pdag.py
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
def to_random_dag(self) -> nx.DiGraph:
    """Provides a random DAG residing in the MEC.

    Returns:
        nx.DiGraph: random DAG living in MEC
    """
    to_dag_candidate = self.copy()

    while to_dag_candidate.num_undir_edges > 0:
        chosen_edge = to_dag_candidate.undir_edges[
            np.random.choice(to_dag_candidate.num_undir_edges)
        ]
        choose_orientation = [chosen_edge, chosen_edge[::-1]]
        node_i, node_j = choose_orientation[np.random.choice(len(choose_orientation))]

        to_dag_candidate.undir_to_dir_edge(tail=node_i, head=node_j)
        to_dag_candidate = to_dag_candidate._apply_meek_rules(G=to_dag_candidate)

    return nx.from_pandas_adjacency(to_dag_candidate.adjacency_matrix, create_using=nx.DiGraph)

undir_neighbors(node)

Gives all undirected neighbors of node node.

Parameters:

Name Type Description Default
node str

node in current PDAG.

required

Returns:

Name Type Description
set set

set of undirected neighbors.

Source code in causalAssembly/pdag.py
125
126
127
128
129
130
131
132
133
134
135
136
137
138
def undir_neighbors(self, node: str) -> set:
    """Gives all undirected neighbors
    of node `node`.

    Args:
        node (str): node in current PDAG.

    Returns:
        set: set of undirected neighbors.
    """
    if node in self._undirected_neighbors.keys():
        return self._undirected_neighbors[node]
    else:
        return set()

undir_to_dir_edge(tail, head)

Takes an undirected edge and turns it into a directed one. tail indicates the starting node of the edge and head the end node, i.e. tail -> head.

Parameters:

Name Type Description Default
tail str

starting node

required
head str

end node

required

Raises:

Type Description
AssertionError

if edge does not exist or is not undirected.

Source code in causalAssembly/pdag.py
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
def undir_to_dir_edge(self, tail: str, head: str):
    """Takes a undirected edge and turns it into a directed one.
    tail indicates the starting node of the edge and head the end node, i.e.
    tail -> head.

    Args:
        tail (str): starting node
        head (str): end node

    Raises:
        AssertionError: if edge does not exist or is not undirected.
    """
    if (tail, head) not in self.undir_edges and (
        head,
        tail,
    ) not in self.undir_edges:
        raise AssertionError("Edge seems not to be undirected or even there at all.")
    self._undir_edges.discard((tail, head))
    self._undir_edges.discard((head, tail))
    self._neighbors[tail].discard(head)
    self._neighbors[head].discard(tail)
    self._undirected_neighbors[tail].discard(head)
    self._undirected_neighbors[head].discard(tail)

    self._add_dir_edge(i=tail, j=head)

vstructs()

Retrieve v-structures

Returns:

Name Type Description
set set

set of all v-structures

Source code in causalAssembly/pdag.py
377
378
379
380
381
382
383
384
385
386
387
388
389
def vstructs(self) -> set:
    """Retrieve v-structures

    Returns:
        set: set of all v-structures
    """
    vstructures = set()
    for node in self._nodes:
        for p1, p2 in combinations(self._parents[node], 2):
            if p1 not in self._parents[p2] and p2 not in self._parents[p1]:
                vstructures.add((p1, node))
                vstructures.add((p2, node))
    return vstructures

dag2cpdag(dag)

Converts a DAG into its unique CPDAG

Parameters:

Name Type Description Default
dag DiGraph

DAG the CPDAG corresponds to.

required

Returns:

Name Type Description
PDAG PDAG

unique CPDAG

Source code in causalAssembly/pdag.py
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
def dag2cpdag(dag: nx.DiGraph) -> PDAG:
    """Convertes a DAG into its unique CPDAG

    Args:
        dag (nx.DiGraph): DAG the CPDAG corresponds to.

    Returns:
        PDAG: unique CPDAG
    """
    copy_dag = dag.copy()
    # Skeleton
    skeleton = nx.to_pandas_adjacency(copy_dag.to_undirected())
    # v-Structures
    vstructures = vstructs(dag=copy_dag)

    for edge in vstructures:  # orient v-structures
        skeleton.loc[edge[::-1]] = 0

    pdag_init = PDAG.from_pandas_adjacency(skeleton)

    # Apply Meek Rules
    cpdag = rule_1(pdag=pdag_init)
    cpdag = rule_2(pdag=cpdag)
    cpdag = rule_3(pdag=cpdag)
    cpdag = rule_4(pdag=cpdag)

    return cpdag

rule_1(pdag)

Given the following pattern X -> Y - Z. Orient Y - Z to Y -> Z if X and Z are non-adjacent (otherwise a new v-structure arises).

Parameters:

Name Type Description Default
pdag PDAG

PDAG before application of rule.

required

Returns:

Name Type Description
PDAG PDAG

PDAG after application of rule.

Source code in causalAssembly/pdag.py
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
def rule_1(pdag: PDAG) -> PDAG:
    """Given the following pattern X -> Y - Z. Orient Y - Z to Y -> Z
    if X and Z are non-adjacent (otherwise a new v-structure arises).

    Args:
        pdag (PDAG): PDAG before application of rule.

    Returns:
        PDAG: PDAG after application of rule.
    """
    copy_pdag = pdag.copy()
    for edge in copy_pdag.undir_edges:
        reverse_edge = edge[::-1]
        test_edges = [edge, reverse_edge]
        for tail, head in test_edges:
            orient = False
            undir_parents = copy_pdag.parents(tail)
            if undir_parents:
                for parent in undir_parents:
                    if not copy_pdag.is_adjacent(parent, head):
                        orient = True
            if orient:
                copy_pdag.undir_to_dir_edge(tail=tail, head=head)
                break
    return copy_pdag

rule_2(pdag)

Given the following directed triple X -> Y -> Z where X - Z are indeed adjacent. Orient X - Z to X -> Z otherwise a cycle arises.

Parameters:

Name Type Description Default
pdag PDAG

PDAG before application of rule.

required

Returns:

Name Type Description
PDAG PDAG

PDAG after application of rule.

Source code in causalAssembly/pdag.py
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
def rule_2(pdag: PDAG) -> PDAG:
    """Given the following directed triple
    X -> Y -> Z where X - Z are indeed adjacent.
    Orient X - Z to X -> Z otherwise a cycle arises.

    Args:
        pdag (PDAG): PDAG before application of rule.

    Returns:
        PDAG: PDAG after application of rule.
    """
    copy_pdag = pdag.copy()
    for edge in copy_pdag.undir_edges:
        reverse_edge = edge[::-1]
        test_edges = [edge, reverse_edge]
        for tail, head in test_edges:
            orient = False
            undir_children = copy_pdag.children(tail)
            if undir_children:
                for child in undir_children:
                    if head in copy_pdag.children(child):
                        orient = True
            if orient:
                copy_pdag.undir_to_dir_edge(tail=tail, head=head)
                break
    return copy_pdag

rule_3(pdag)

Orient X - Z to X -> Z, whenever there are two triples X - Y1 -> Z and X - Y2 -> Z such that Y1 and Y2 are non-adjacent.

Parameters:

Name Type Description Default
pdag PDAG

PDAG before application of rule.

required

Returns:

Name Type Description
PDAG PDAG

PDAG after application of rule.

Source code in causalAssembly/pdag.py
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
def rule_3(pdag: PDAG) -> PDAG:
    """Orient X - Z to X -> Z, whenever there are two triples
    X - Y1 -> Z and X - Y2 -> Z such that Y1 and Y2 are non-adjacent.

    Args:
        pdag (PDAG): PDAG before application of rule.

    Returns:
        PDAG: PDAG after application of rule.
    """
    copy_pdag = pdag.copy()
    for edge in copy_pdag.undir_edges:
        reverse_edge = edge[::-1]
        test_edges = [edge, reverse_edge]
        for tail, head in test_edges:
            # if true that tail - node1 -> head and tail - node2 -> head
            # while {node1 U node2} = 0 then orient tail -> head
            orient = False
            if len(copy_pdag.undir_neighbors(tail)) >= 2:
                undir_n = copy_pdag.undir_neighbors(tail)
                selection = [
                    (node1, node2)
                    for node1, node2 in combinations(undir_n, 2)
                    if not copy_pdag.is_adjacent(node1, node2)
                ]
                if selection:
                    for node1, node2 in selection:
                        if head in copy_pdag.parents(node1).intersection(copy_pdag.parents(node2)):
                            orient = True
            if orient:
                copy_pdag.undir_to_dir_edge(tail=tail, head=head)
                break
    return pdag

rule_4(pdag)

Orient X - Y1 to X -> Y1, whenever there are two triples with X - Z and X - Y1 <- Z and X - Y2 -> Z such that Y1 and Y2 are non-adjacent.

Parameters:

Name Type Description Default
pdag PDAG

PDAG before application of rule.

required

Returns:

Name Type Description
PDAG PDAG

PDAG after application of rule.

Source code in causalAssembly/pdag.py
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
def rule_4(pdag: PDAG) -> PDAG:
    """Orient X - Y1 to X -> Y1, whenever there are
    two triples with X - Z and X - Y1 <- Z and X - Y2 -> Z
    such that Y1 and Y2 are non-adjacent.

    Args:
        pdag (PDAG): PDAG before application of rule.

    Returns:
        PDAG: PDAG after application of rule.
    """
    copy_pdag = pdag.copy()
    for edge in copy_pdag.undir_edges:
        reverse_edge = edge[::-1]
        test_edges = [edge, reverse_edge]
        for tail, head in test_edges:
            orient = False
            if len(copy_pdag.undir_neighbors(tail)) > 0:
                undirected_n = copy_pdag.undir_neighbors(tail)
                for undir_n in undirected_n:
                    if tail in copy_pdag.children(undir_n):
                        children_select = list(copy_pdag.children(undir_n))
                        if children_select:
                            for parent in children_select:
                                if head in copy_pdag.children(parent):
                                    orient = True
            if orient:
                copy_pdag.undir_to_dir_edge(tail=tail, head=head)
                break
    return pdag

vstructs(dag)

Retrieve all v-structures in a DAG.

Parameters:

Name Type Description Default
dag DiGraph

DAG in question

required

Returns:

Name Type Description
set set

Set of all v-structures.

Source code in causalAssembly/pdag.py
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
def vstructs(dag: nx.DiGraph) -> set:
    """Retrieve all v-structures in a DAG.

    Args:
        dag (nx.DiGraph): DAG in question

    Returns:
        set: Set of all v-structures.
    """
    vstructures = set()
    for node in dag.nodes():
        for p1, p2 in combinations(list(dag.predecessors(node)), 2):  # get all parents of node
            if not dag.has_edge(p1, p2) and not dag.has_edge(p2, p1):
                vstructures.add((p1, node))
                vstructures.add((p2, node))
    return vstructures

DAG class

DAG

General class for dealing with directed acyclic graphs, i.e. graphs that are directed and must not contain any cycles.

Source code in causalAssembly/dag.py
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
class DAG:
    """
    General class for dealing with directed acyclic graphs, i.e.
    graphs that are directed and must not contain any cycles.
    """

    def __init__(
        self,
        nodes: list | None = None,
        edges: list[tuple] | None = None,
    ):
        if nodes is None:
            nodes = []
        if edges is None:
            edges = []

        self._nodes = set(nodes)
        self._edges = set()
        self._parents = defaultdict(set)
        self._children = defaultdict(set)
        # Container for trained DRF objects (see save_drf/load_drf/sample_from_drf).
        self.drf: dict = dict()
        self._random_state: np.random.Generator = np.random.default_rng(seed=2023)

        for edge in edges:
            self._add_edge(*edge)

    def _add_edge(self, i, j):
        """Insert the directed edge i -> j while keeping the graph acyclic.

        If the new edge would induce a cycle, the insertion is rolled back
        (edge and any newly introduced nodes are removed again) so the graph
        is left exactly as it was, and a ValueError is raised.

        Args:
            i: tail node.
            j: head node.

        Raises:
            ValueError: if adding i -> j would create a cycle.
        """
        # Remember pre-insertion membership so a failed check can be undone.
        i_was_present = i in self._nodes
        j_was_present = j in self._nodes

        self._nodes.add(i)
        self._nodes.add(j)
        self._edges.add((i, j))

        # Check if graph is acyclic
        # TODO: Make check really after each edge is added?
        if not self.is_acyclic():
            # Roll back to a consistent state before raising, so a caught
            # exception does not leave a corrupted graph behind.
            self._edges.discard((i, j))
            if not i_was_present:
                self._nodes.discard(i)
            if not j_was_present:
                self._nodes.discard(j)
            raise ValueError(
                "The edge set you provided \
                induces one or more cycles.\
                Check your input!"
            )

        self._children[i].add(j)
        self._parents[j].add(i)

    @property
    def random_state(self) -> np.random.Generator:
        """Random number generator used for sampling. Defaults to a
        generator seeded with 2023."""
        return self._random_state

    @random_state.setter
    def random_state(self, r: np.random.Generator):
        if not isinstance(r, np.random.Generator):
            raise AssertionError("Specify numpy random number generator object!")
        self._random_state = r

    def add_edge(self, edge: tuple[str, str]):
        """Add edge to DAG

        Args:
            edge (tuple[str, str]): Edge to add

        Raises:
            ValueError: if the edge would induce a cycle.
        """
        self._add_edge(*edge)

    def add_edges_from(self, edges: list[tuple[str, str]]):
        """Add multiple edges to DAG

        Args:
            edges (list[tuple[str, str]]): Edges to add

        Raises:
            ValueError: if any edge would induce a cycle.
        """
        for edge in edges:
            self.add_edge(edge=edge)

    def children(self, of_node: str) -> list[str]:
        """Gives all children of node `of_node`.

        Args:
            of_node (str): node in current DAG.

        Returns:
            list: of children.
        """
        # Use .get to avoid growing the defaultdict on mere lookups.
        return list(self._children.get(of_node, set()))

    def parents(self, of_node: str) -> list[str]:
        """Gives all parents of node `of_node`.

        Args:
            of_node (str): node in current DAG.

        Returns:
            list: of parents.
        """
        return list(self._parents.get(of_node, set()))

    def induced_subgraph(self, nodes: list[str]) -> DAG:
        """Returns the induced subgraph on the nodes in `nodes`.

        Args:
            nodes (list[str]): List of nodes.

        Returns:
            DAG: Induced subgraph.
        """
        edges = [(i, j) for i, j in self.edges if i in nodes and j in nodes]
        return DAG(nodes=nodes, edges=edges)

    def is_adjacent(self, i: str, j: str) -> bool:
        """Return True if the graph contains a directed
        edge between i and j.

        Args:
            i (str): node i.
            j (str): node j.

        Returns:
            bool: True if i->j or i<-j
        """
        return (j, i) in self._edges or (i, j) in self._edges

    def is_clique(self, potential_clique: set) -> bool:
        """
        Check that every pair of nodes in `potential_clique` is adjacent.
        """
        return all(self.is_adjacent(i, j) for i, j in combinations(potential_clique, 2))

    def is_acyclic(self) -> bool:
        """Check if the graph is acyclic.

        Returns:
            bool: True if graph is acyclic.
        """
        nx_dag = self.to_networkx()
        return nx.is_directed_acyclic_graph(nx_dag)

    @classmethod
    def from_pandas_adjacency(cls, pd_amat: pd.DataFrame) -> DAG:
        """Build DAG from a Pandas adjacency matrix.

        Entries that appear in both directions (i.e. 2-cycles in the
        matrix) are treated as undirected and dropped, since a DAG cannot
        represent them.

        Args:
            pd_amat (pd.DataFrame): input adjacency matrix.

        Returns:
            DAG
        """
        assert pd_amat.shape[0] == pd_amat.shape[1]
        nodes = pd_amat.columns

        all_connections = []
        start, end = np.where(pd_amat != 0)
        for idx, _ in enumerate(start):
            all_connections.append((pd_amat.columns[start[idx]], pd_amat.columns[end[idx]]))

        # Connections present in both orientations correspond to undirected
        # edges; exclude them from the directed edge set.
        temp = [set(i) for i in all_connections]
        temp2 = [arc for arc in all_connections if temp.count(set(arc)) > 1]

        dir_edges = [edge for edge in all_connections if edge not in temp2]

        return DAG(nodes=nodes, edges=dir_edges)

    def remove_edge(self, i: str, j: str):
        """Removes edge in question

        Args:
            i (str): tail
            j (str): head

        Raises:
            AssertionError: if edge does not exist
        """
        if (i, j) not in self._edges:
            raise AssertionError("Edge does not exist in current DAG")

        self._edges.discard((i, j))
        self._children[i].discard(j)
        self._parents[j].discard(i)

    def remove_node(self, node):
        """Remove a node from the graph

        Raises:
            KeyError: if the node is not in the graph.
        """
        self._nodes.remove(node)

        self._edges = {(i, j) for i, j in self._edges if i != node and j != node}

        # Detach the node from its neighbors' bookkeeping before dropping
        # its own entries.
        for child in self._children.get(node, set()):
            self._parents[child].discard(node)

        for parent in self._parents.get(node, set()):
            self._children[parent].discard(node)

        self._parents.pop(node, "I was never here")
        self._children.pop(node, "I was never here")

    @property
    def adjacency_matrix(self) -> pd.DataFrame:
        """Returns adjacency matrix where the i,jth
        entry being one indicates that there is an edge
        from i to j. A zero indicates that there is no edge.

        Returns:
            pd.DataFrame: adjacency matrix
        """
        amat = pd.DataFrame(
            np.zeros([self.num_nodes, self.num_nodes]),
            index=self.nodes,
            columns=self.nodes,
        )
        for edge in self.edges:
            amat.loc[edge] = 1
        return amat

    def vstructs(self) -> set:
        """Retrieve v-structures

        A v-structure is a triple p1 -> node <- p2 where p1 and p2 are
        not adjacent. Both edges of each v-structure are returned.

        Returns:
            set: set of all v-structures
        """
        vstructures = set()
        for node in self._nodes:
            # .get avoids inserting empty entries into the defaultdict
            # for nodes without parents.
            for p1, p2 in combinations(self._parents.get(node, set()), 2):
                if p1 not in self._parents.get(p2, set()) and p2 not in self._parents.get(
                    p1, set()
                ):
                    vstructures.add((p1, node))
                    vstructures.add((p2, node))
        return vstructures

    def copy(self):
        """Return a copy of the graph.

        NOTE: only nodes and edges are carried over; `drf` and
        `random_state` are reset to their defaults on the copy.
        """
        return DAG(nodes=self._nodes, edges=self._edges)

    def show(self):
        """Plot DAG."""
        graph = self.to_networkx()
        pos = nx.circular_layout(graph)
        nx.draw(graph, pos=pos, with_labels=True)

    def to_networkx(self) -> nx.DiGraph:
        """Convert to networkx graph.

        Returns:
            nx.DiGraph: DAG.
        """
        nx_dag = nx.DiGraph()
        nx_dag.add_nodes_from(self.nodes)
        nx_dag.add_edges_from(self.edges)

        return nx_dag

    @property
    def nodes(self) -> list:
        """Get all nodes in current DAG.

        Returns:
            list: sorted list of nodes.
        """
        return sorted(self._nodes)

    @property
    def num_nodes(self) -> int:
        """Number of nodes in current DAG.

        Returns:
            int: Number of nodes
        """
        return len(self._nodes)

    @property
    def num_edges(self) -> int:
        """Number of directed edges
        in current DAG.

        Returns:
            int: Number of directed edges
        """
        return len(self._edges)

    @property
    def sparsity(self) -> float:
        """Sparsity of the graph: fraction of possible edges present.

        Returns:
            float: in [0,1]
        """
        s = self.num_nodes
        if s < 2:
            # Fewer than two nodes admit no edges; avoid division by zero.
            return 0.0
        return self.num_edges / s / (s - 1) * 2

    @property
    def edges(self) -> list[tuple]:
        """Gives all directed edges in
        current DAG.

        Returns:
            list[tuple]: List of directed edges.
        """
        return list(self._edges)

    @property
    def causal_order(self) -> list[str]:
        """Returns the causal order of the current graph.
        Note that this order is in general not unique.

        Returns:
            list[str]: Causal order
        """
        return list(nx.lexicographical_topological_sort(self.to_networkx()))

    @property
    def max_in_degree(self) -> int:
        """Maximum in-degree of the graph.

        Returns:
            int: Maximum in-degree (0 for an empty graph)
        """
        # default=0 keeps this well-defined for a graph without nodes.
        return max((len(self._parents.get(node, set())) for node in self._nodes), default=0)

    @property
    def max_out_degree(self) -> int:
        """Maximum out-degree of the graph.

        Returns:
            int: Maximum out-degree (0 for an empty graph)
        """
        return max((len(self._children.get(node, set())) for node in self._nodes), default=0)

    @classmethod
    def from_nx(cls, nx_dag: nx.DiGraph) -> DAG:
        """Convert to DAG from nx.DiGraph.

        Args:
            nx_dag (nx.DiGraph): DAG in question.

        Raises:
            TypeError: If DAG is not nx.DiGraph

        Returns:
            DAG
        """
        if not isinstance(nx_dag, nx.DiGraph):
            raise TypeError("DAG must be of type nx.DiGraph")
        return DAG(nodes=list(nx_dag.nodes), edges=list(nx_dag.edges))

    def save_drf(self, filename: str, location: str = None):
        """Writes a drf dict to file. Please provide the .pkl suffix!

        Args:
            filename (str): name of the file to be written e.g. examplefile.pkl
            location (str, optional): path to file in case it's not located in
                the current working directory. Defaults to None.
        """

        if not location:
            location = Path().resolve()

        location_path = Path(location, filename)

        with open(location_path, "wb") as f:
            pickle.dump(self.drf, f)

    def sample_from_drf(self, size=10, smoothed: bool = True) -> pd.DataFrame:
        """Draw from the trained DRF.

        Args:
            size (int, optional): Number of samples to be drawn. Defaults to 10.
            smoothed (bool, optional): If set to true, marginal distributions will
                be sampled from smoothed bootstraps. Defaults to True.

        Returns:
            pd.DataFrame: Data frame that follows the distribution implied by the ground truth.
        """
        return _sample_from_drf(graph=self, size=size, smoothed=smoothed)

    def to_cpdag(self) -> PDAG:
        """Convert the DAG to the CPDAG representing its Markov
        equivalence class.

        Returns:
            PDAG: the corresponding CPDAG.
        """
        return dag2cpdag(dag=self.to_networkx())

    @classmethod
    def load_drf(cls, filename: str, location: str = None) -> dict:
        """Loads a drf dict from a .pkl file into the workspace.

        NOTE(security): pickle.load executes arbitrary code on untrusted
        input — only load files you trust.

        Args:
            filename (str): name of the file e.g. examplefile.pkl
            location (str, optional): path to file in case it's not located
                in the current working directory. Defaults to None.

        Returns:
            DRF (dict): dict of trained drf objects
        """
        if not location:
            location = Path().resolve()

        location_path = Path(location, filename)

        with open(location_path, "rb") as drf:
            pickle_drf = pickle.load(drf)

        return pickle_drf
adjacency_matrix: pd.DataFrame property

Returns adjacency matrix where the i,jth entry being one indicates that there is an edge from i to j. A zero indicates that there is no edge.

Returns:

Type Description
DataFrame

pd.DataFrame: adjacency matrix

causal_order: list[str] property

Returns the causal order of the current graph. Note that this order is in general not unique.

Returns:

Type Description
list[str]

list[str]: Causal order

edges: list[tuple] property

Gives all directed edges in current DAG.

Returns:

Type Description
list[tuple]

list[tuple]: List of directed edges.

max_in_degree: int property

Maximum in-degree of the graph.

Returns:

Name Type Description
int int

Maximum in-degree

max_out_degree: int property

Maximum out-degree of the graph.

Returns:

Name Type Description
int int

Maximum out-degree

nodes: list property

Get all nodes in the current DAG.

Returns:

Name Type Description
list list

list of nodes.

num_edges: int property

Number of directed edges in current DAG.

Returns:

Name Type Description
int int

Number of directed edges

num_nodes: int property

Number of nodes in current DAG.

Returns:

Name Type Description
int int

Number of nodes

sparsity: float property

Sparsity of the graph

Returns:

Name Type Description
float float

in [0,1]

add_edge(edge)

Add edge to DAG

Parameters:

Name Type Description Default
edge tuple[str, str]

Edge to add

required
Source code in causalAssembly/dag.py
74
75
76
77
78
79
80
def add_edge(self, edge: tuple[str, str]):
    """Add edge to DAG

    Args:
        edge (tuple[str, str]): Edge to add
    """
    self._add_edge(*edge)

add_edges_from(edges)

Add multiple edges to DAG

Parameters:

Name Type Description Default
edges list[tuple[str, str]]

Edges to add

required
Source code in causalAssembly/dag.py
82
83
84
85
86
87
88
89
def add_edges_from(self, edges: list[tuple[str, str]]):
    """Add multiple edges to DAG

    Args:
        edges (list[tuple[str, str]]): Edges to add
    """
    for edge in edges:
        self.add_edge(edge=edge)

children(of_node)

Gives all children of node of_node.

Parameters:

Name Type Description Default
of_node str

node in current DAG.

required

Returns:

Name Type Description
list list[str]

of children.

Source code in causalAssembly/dag.py
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
def children(self, of_node: str) -> list[str]:
    """Gives all children of node `of_node`.

    Args:
        of_node (str): node in current DAG.

    Returns:
        list: of children.
    """
    if of_node in self._children.keys():
        return list(self._children[of_node])
    else:
        return []

copy()

Return a copy of the graph

Source code in causalAssembly/dag.py
248
249
250
def copy(self):
    """Return a copy of the graph"""
    return DAG(nodes=self._nodes, edges=self._edges)

from_nx(nx_dag) classmethod

Convert to DAG from nx.DiGraph.

Parameters:

Name Type Description Default
nx_dag DiGraph

DAG in question.

required

Raises:

Type Description
TypeError

If DAG is not nx.DiGraph

Returns:

Type Description
DAG

DAG

Source code in causalAssembly/dag.py
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
@classmethod
def from_nx(cls, nx_dag: nx.DiGraph) -> DAG:
    """Convert to DAG from nx.DiGraph.

    Args:
        nx_dag (nx.DiGraph): DAG in question.

    Raises:
        TypeError: If DAG is not nx.DiGraph

    Returns:
        DAG
    """
    if not isinstance(nx_dag, nx.DiGraph):
        raise TypeError("DAG must be of type nx.DiGraph")
    return DAG(nodes=list(nx_dag.nodes), edges=list(nx_dag.edges))

from_pandas_adjacency(pd_amat) classmethod

Build DAG from a Pandas adjacency matrix.

Parameters:

Name Type Description Default
pd_amat DataFrame

input adjacency matrix.

required

Returns:

Type Description
DAG

DAG

Source code in causalAssembly/dag.py
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
@classmethod
def from_pandas_adjacency(cls, pd_amat: pd.DataFrame) -> DAG:
    """Build DAG from a Pandas adjacency matrix.

    Args:
        pd_amat (pd.DataFrame): input adjacency matrix.

    Returns:
        DAG
    """
    assert pd_amat.shape[0] == pd_amat.shape[1]
    nodes = pd_amat.columns

    all_connections = []
    start, end = np.where(pd_amat != 0)
    for idx, _ in enumerate(start):
        all_connections.append((pd_amat.columns[start[idx]], pd_amat.columns[end[idx]]))

    temp = [set(i) for i in all_connections]
    temp2 = [arc for arc in all_connections if temp.count(set(arc)) > 1]

    dir_edges = [edge for edge in all_connections if edge not in temp2]

    return DAG(nodes=nodes, edges=dir_edges)

induced_subgraph(nodes)

Returns the induced subgraph on the nodes in nodes.

Parameters:

Name Type Description Default
nodes list[str]

List of nodes.

required

Returns:

Name Type Description
DAG DAG

Induced subgraph.

Source code in causalAssembly/dag.py
119
120
121
122
123
124
125
126
127
128
129
def induced_subgraph(self, nodes: list[str]) -> DAG:
    """Returns the induced subgraph on the nodes in `nodes`.

    Args:
        nodes (list[str]): List of nodes.

    Returns:
        DAG: Induced subgraph.
    """
    edges = [(i, j) for i, j in self.edges if i in nodes and j in nodes]
    return DAG(nodes=nodes, edges=edges)

is_acyclic()

Check if the graph is acyclic.

Returns:

Name Type Description
bool bool

True if graph is acyclic.

Source code in causalAssembly/dag.py
150
151
152
153
154
155
156
157
def is_acyclic(self) -> bool:
    """Check if the graph is acyclic.

    Returns:
        bool: True if graph is acyclic.
    """
    nx_dag = self.to_networkx()
    return nx.is_directed_acyclic_graph(nx_dag)

is_adjacent(i, j)

Return True if the graph contains a directed edge between i and j.

Parameters:

Name Type Description Default
i str

node i.

required
j str

node j.

required

Returns:

Name Type Description
bool bool

True if i->j or i<-j

Source code in causalAssembly/dag.py
131
132
133
134
135
136
137
138
139
140
141
142
def is_adjacent(self, i: str, j: str) -> bool:
    """Return True if the graph contains an directed
    edge between i and j.

    Args:
        i (str): node i.
        j (str): node j.

    Returns:
        bool: True if i->j or i<-j
    """
    return (j, i) in self.edges or (i, j) in self.edges

is_clique(potential_clique)

Check that every pair of nodes in potential_clique is adjacent.

Source code in causalAssembly/dag.py
144
145
146
147
148
def is_clique(self, potential_clique: set) -> bool:
    """
    Check every pair of node X potential_clique is adjacent.
    """
    return all(self.is_adjacent(i, j) for i, j in combinations(potential_clique, 2))

load_drf(filename, location=None) classmethod

Loads a drf dict from a .pkl file into the workspace.

Parameters:

Name Type Description Default
filename str

name of the file e.g. examplefile.pkl

required
location str

path to file in case it's not located in the current working directory. Defaults to None.

None

Returns:

Name Type Description
DRF dict

dict of trained drf objects

Source code in causalAssembly/dag.py
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
@classmethod
def load_drf(cls, filename: str, location: str = None) -> dict:
    """Loads a drf dict from a .pkl file into the workspace.

    Args:
        filename (str): name of the file e.g. examplefile.pkl
        location (str, optional): path to file in case it's not located
            in the current working directory. Defaults to None.

    Returns:
        DRF (dict): dict of trained drf objects
    """
    if not location:
        location = Path().resolve()

    location_path = Path(location, filename)

    with open(location_path, "rb") as drf:
        pickle_drf = pickle.load(drf)

    return pickle_drf

parents(of_node)

Gives all parents of node of_node.

Parameters:

Name Type Description Default
of_node str

node in current DAG.

required

Returns:

Name Type Description
list list[str]

of parents.

Source code in causalAssembly/dag.py
105
106
107
108
109
110
111
112
113
114
115
116
117
def parents(self, of_node: str) -> list[str]:
    """Gives all parents of node `of_node`.

    Args:
        of_node (str): node in current DAG.

    Returns:
        list: of parents.
    """
    if of_node in self._parents.keys():
        return list(self._parents[of_node])
    else:
        return []

remove_edge(i, j)

Removes edge in question

Parameters:

Name Type Description Default
i str

tail

required
j str

head

required

Raises:

Type Description
AssertionError

if edge does not exist

Source code in causalAssembly/dag.py
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
def remove_edge(self, i: str, j: str):
    """Removes edge in question

    Args:
        i (str): tail
        j (str): head

    Raises:
        AssertionError: if edge does not exist
    """
    if (i, j) not in self.edges:
        raise AssertionError("Edge does not exist in current DAG")

    self._edges.discard((i, j))
    self._children[i].discard(j)
    self._parents[j].discard(i)

remove_node(node)

Remove a node from the graph

Source code in causalAssembly/dag.py
201
202
203
204
205
206
207
208
209
210
211
212
213
214
def remove_node(self, node):
    """Remove a node from the graph"""
    self._nodes.remove(node)

    self._edges = {(i, j) for i, j in self._edges if i != node and j != node}

    for child in self._children[node]:
        self._parents[child].remove(node)

    for parent in self._parents[node]:
        self._children[parent].remove(node)

    self._parents.pop(node, "I was never here")
    self._children.pop(node, "I was never here")

sample_from_drf(size=10, smoothed=True)

Draw from the trained DRF.

Parameters:

Name Type Description Default
size int

Number of samples to be drawn. Defaults to 10.

10
smoothed bool

If set to true, marginal distributions will be sampled from smoothed bootstraps. Defaults to True.

True

Returns:

Type Description
DataFrame

pd.DataFrame: Data frame that follows the distribution implied by the ground truth.

Source code in causalAssembly/dag.py
380
381
382
383
384
385
386
387
388
389
390
391
def sample_from_drf(self, size=10, smoothed: bool = True) -> pd.DataFrame:
    """Draw from the trained DRF.

    Args:
        size (int, optional): Number of samples to be drawn. Defaults to 10.
        smoothed (bool, optional): If set to true, marginal distributions will
            be sampled from smoothed bootstraps. Defaults to True.

    Returns:
        pd.DataFrame: Data frame that follows the distribution implied by the ground truth.
    """
    return _sample_from_drf(graph=self, size=size, smoothed=smoothed)

save_drf(filename, location=None)

Writes a drf dict to file. Please provide the .pkl suffix!

Parameters:

Name Type Description Default
filename str

name of the file to be written e.g. examplefile.pkl

required
location str

path to file in case it's not located in the current working directory. Defaults to None.

None
Source code in causalAssembly/dag.py
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
def save_drf(self, filename: str, location: str = None):
    """Writes a drf dict to file. Please provide the .pkl suffix!

    Args:
        filename (str): name of the file to be written e.g. examplefile.pkl
        location (str, optional): path to file in case it's not located in
            the current working directory. Defaults to None.
    """

    if not location:
        location = Path().resolve()

    location_path = Path(location, filename)

    with open(location_path, "wb") as f:
        pickle.dump(self.drf, f)

show()

Plot DAG.

Source code in causalAssembly/dag.py
252
253
254
255
256
def show(self):
    """Plot DAG."""
    graph = self.to_networkx()
    pos = nx.circular_layout(graph)
    nx.draw(graph, pos=pos, with_labels=True)

to_networkx()

Convert to networkx graph.

Returns:

Type Description
DiGraph

nx.DiGraph: DAG.

Source code in causalAssembly/dag.py
258
259
260
261
262
263
264
265
266
267
268
def to_networkx(self) -> nx.DiGraph:
    """Convert to networkx graph.

    Returns:
        nx.DiGraph: DAG.
    """
    nx_dag = nx.DiGraph()
    nx_dag.add_nodes_from(self.nodes)
    nx_dag.add_edges_from(self.edges)

    return nx_dag

vstructs()

Retrieve v-structures

Returns:

Name Type Description
set set

set of all v-structures

Source code in causalAssembly/dag.py
234
235
236
237
238
239
240
241
242
243
244
245
246
def vstructs(self) -> set:
    """Retrieve v-structures

    Returns:
        set: set of all v-structures
    """
    vstructures = set()
    for node in self._nodes:
        for p1, p2 in combinations(self._parents[node], 2):
            if p1 not in self._parents[p2] and p2 not in self._parents[p1]:
                vstructures.add((p1, node))
                vstructures.add((p2, node))
    return vstructures