Documentation

Below, you will find the documentation of the causalAssembly project code.

Utility classes and functions related to causalAssembly.

Copyright (c) 2023 Robert Bosch GmbH

This program is free software: you can redistribute it and/or modify it under the terms of the GNU Affero General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later version. This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Affero General Public License for more details. You should have received a copy of the GNU Affero General Public License along with this program. If not, see https://www.gnu.org/licenses/.

NodeAttributes dataclass

Node Attributes.

Source code in causalAssembly/models_dag.py
@dataclass
class NodeAttributes:
    """Node Attributes."""

    ALLOW_IN_EDGES = "allow_in_edges"
    HIDDEN = "is_hidden"

ProcessCell

Representation of a single Production Line Cell (to model a station / a process in a production line environment).

A Cell can contain multiple modules (sub-graphs, which are nx.DiGraph objects).

Note that none of the terms Cell, Process, or Module has a strict definition. The convention is based on a production line consisting of several cells, which a user of the repository models by means of smaller graphs (modules).
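
A minimal usage sketch (the cell name and module size below are arbitrary; assumes causalAssembly is installed and ProcessCell is imported from causalAssembly/models_dag.py):

from causalAssembly.models_dag import ProcessCell

cell = ProcessCell(name="C0")
cell.description = "toy station"

# populate the cell with a randomly generated module; nodes get prefixed, e.g. C0_M1_0
cell.add_random_module(n_nodes=7, p=0.1)

print(cell.nodes)                     # e.g. ['C0_M1_0', 'C0_M1_1', ...]
print(cell.num_edges, cell.sparsity)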

Source code in causalAssembly/models_dag.py
class ProcessCell:
    """Representation of a single Production Line Cell.

    (to model a station / a process in a production line
    environment).

    A Cell can contain multiple modules (sub-graphs, which are nx.DiGraph objects).

    Note that none of the terms Cell, Process, or Module has a strict definition.
    The convention is based on a production line consisting of several cells, which
    a user of the repository models by means of smaller graphs (modules).

    """

    def __init__(self, name: str):
        """Inits Process cell class.

        Args:
            name (str): _description_
        """
        self.name = name
        self.graph: nx.DiGraph = nx.DiGraph()

        self.description: str = ""  # description of the cell.

        self.__module_prefix = "M"  # M01 vs M1?
        self.modules: dict[str, nx.DiGraph] = dict()  # {'M1': nx.DiGraph, 'M2': nx.DiGraph}
        self.module_connectors: list[tuple] = list()

        self.is_eol = False
        self.random_state = None
        self.drf: dict = dict()

    @property
    def nodes(self) -> list[str]:
        """Nodes in the graph.

        Returns:
            list[str]
        """
        return list(self.graph.nodes())

    @property
    def edges(self) -> list[tuple]:
        """Edges in the graph.

        Returns:
            list[tuple]
        """
        return list(self.graph.edges())

    @property
    def num_nodes(self) -> int:
        """Number of nodes in the graph.

        Returns:
            int
        """
        return len(self.nodes)

    @property
    def num_edges(self) -> int:
        """Number of edges in the graph.

        Returns:
            int
        """
        return len(self.edges)

    @property
    def sparsity(self) -> float:
        """Sparsity of the graph.

        Returns:
            float: in [0,1]
        """
        s = self.num_nodes
        return self.num_edges / s / (s - 1) * 2

    @property
    def ground_truth(self) -> pd.DataFrame:
        """Returns the current ground truth as pandas adjacency.

        Returns:
            pd.DataFrame: Adjacency matrix.
        """
        return nx.to_pandas_adjacency(self.graph, weight=None)

    @property
    def causal_order(self) -> list[str]:
        """Returns the causal order of the current graph.

        Note that this order is in general not unique.

        Returns:
            list[str]: Causal order
        """
        return list(nx.topological_sort(self.graph))

    def parents(self, of_node: str) -> list[str]:
        """Return parents of node in question.

        Args:
            of_node (str): Node in question.

        Returns:
            list[str]: parent set.
        """
        return list(self.graph.predecessors(of_node))

    def save_drf(self, filename: str, location: str | Path | None = None):
        """Writes a drf dict to file. Please provide the .pkl suffix!

        Args:
            filename (str): name of the file to be written e.g. examplefile.pkl
            location (str, optional): path to file in case it's not located in
                the current working directory. Defaults to None.
        """
        if not location:
            location = Path().resolve()

        location_path = Path(location, filename)

        with open(location_path, "wb") as f:
            pickle.dump(self.drf, f)

    def add_module(
        self,
        graph: nx.DiGraph,
        allow_in_edges: bool = True,
        mark_hidden: bool | list = False,
    ) -> str:
        """Adds module to cell graph. Module has to be as nx.DiGraph object.

        Args:
            graph (nx.DiGraph): Graph to add to cell.
            allow_in_edges (bool, optional):
                whether nodes in the module are allowed to
                have in-edges. Defaults to True.
            mark_hidden (bool | list, optional):
                If False all nodes' 'is_hidden' attribute is set to False.
                If list of node names is provided these get overwritten to True.
                Defaults to False.

        Returns:
            str: prefix of Module created
        """
        next_module_prefix = self.next_module_prefix()

        node_renaming_dict = {
            old_node_name: f"{self.name}_{next_module_prefix}_{old_node_name}"
            for old_node_name in graph.nodes()
        }
        self.modules[self.next_module_prefix()] = graph.copy()  # type: ignore
        graph = nx.relabel_nodes(graph, node_renaming_dict)

        if allow_in_edges:  # for later: mark nodes to not have incoming edges
            nx.set_node_attributes(graph, True, NodeAttributes.ALLOW_IN_EDGES)
        else:
            nx.set_node_attributes(graph, False, NodeAttributes.ALLOW_IN_EDGES)

        nx.set_node_attributes(
            graph, False, NodeAttributes.HIDDEN
        )  # set all non-hidden -> visible by default
        if isinstance(mark_hidden, list):
            mark_hidden_renamed = [
                f"{self.name}_{next_module_prefix}_{new_name}" for new_name in mark_hidden
            ]
            overwrite_dict = {node: {NodeAttributes.HIDDEN: True} for node in mark_hidden_renamed}
            nx.set_node_attributes(
                graph, values=overwrite_dict
            )  # only overwrite the ones specified
        self.graph = nx.compose(self.graph, graph)

        return next_module_prefix

    def input_cellgraph_directly(self, graph: nx.DiGraph, allow_in_edges: bool = False):
        """Allow to input graphs on a cell-level.

        This should only be done if the graph
        is already available for the entire cell, otherwise `add_module()` is preferred.

        Args:
            graph (nx.DiGraph): Cell graph to input
            allow_in_edges (bool, optional): Defaults to False.
        """
        if allow_in_edges:  # for later: mark nodes to not have incoming edges
            nx.set_node_attributes(graph, True, NodeAttributes.ALLOW_IN_EDGES)
        else:
            nx.set_node_attributes(graph, False, NodeAttributes.ALLOW_IN_EDGES)

        node_renaming_dict = {
            old_node_name: f"{self.name}_{old_node_name}" for old_node_name in graph.nodes()
        }
        graph = nx.relabel_nodes(graph, node_renaming_dict)

        self.graph = nx.compose(self.graph, graph)

    def sample_from_drf(self, size=10, smoothed: bool = True) -> pd.DataFrame:
        """Draw from the trained DRF.

        Args:
            size (int, optional): Number of samples to be drawn. Defaults to 10.
            smoothed (bool, optional): If set to true, marginal distributions will
                be sampled from smoothed bootstraps. Defaults to True.

        Returns:
            pd.DataFrame: Data frame that follows the distribution implied by the ground truth.
        """
        return _sample_from_drf(prod_object=self, size=size, smoothed=smoothed)

    def _generate_random_dag(self, n_nodes: int = 5, p: float = 0.1) -> nx.DiGraph:
        """Creates a random DAG.

        By taking an arbitrary ordering of the specified number of nodes,
        and then considers edges from node i to j only if i < j.
        That constraint leads to DAGness by construction.

        Args:
            n_nodes (int, optional): Defaults to 5.
            p (float, optional): Defaults to .1.

        Returns:
            nx.DiGraph:
        """
        rng = self.random_state
        if rng is None:
            rng = np.random.default_rng()
        dag = nx.DiGraph()
        dag.add_nodes_from(range(0, n_nodes))

        causal_order = list(dag.nodes)
        rng.shuffle(causal_order)

        all_forward_edges = itertools.combinations(causal_order, 2)
        edges = np.array(list(all_forward_edges))

        random_choice = rng.choice([False, True], p=[1 - p, p], size=edges.shape[0])

        dag.add_edges_from(edges[random_choice])
        return dag

    def add_random_module(self, n_nodes: int = 7, p: float = 0.10):
        """Add random module to the cell.

        Args:
            n_nodes (int, optional): _description_. Defaults to 7.
            p (float, optional): _description_. Defaults to 0.10.
        """
        randomdag = self._generate_random_dag(n_nodes=n_nodes, p=p)
        self.add_module(graph=randomdag, allow_in_edges=True, mark_hidden=False)

    def connect_by_module(self, m1: str, m2: str, edges: list[tuple]):
        """Connect two modules.

        (by name, e.g. M2, M4) of the cell by a list
        of edges with the original node names.

        Args:
            m1: str
            m2: str
            edges: list[tuple]: use the original node names before they have entered into the cell,
                i.e. not with Cy_Mx prefix
        """
        self.__verify_edges_are_allowed(m1=m1, m2=m2, edges=edges)

        node_prefix_m1 = f"{self.name}_{m1}"
        node_prefix_m2 = f"{self.name}_{m2}"

        new_edges = [
            (f"{node_prefix_m1}_{edge[0]}", f"{node_prefix_m2}_{edge[1]}") for edge in edges
        ]

        [self.module_connectors.append(edge) for edge in new_edges]

        self.graph.add_edges_from(new_edges)

    def connect_by_random_edges(self, sparsity: float = 0.1) -> nx.DiGraph:
        """Add random edges to graph.

        according to proportion
        with restriction specified in node attributes.

        Args:
            sparsity (float, optional): Sparsity parameter in (0,1). Defaults to 0.1.

        Raises:
            NotImplementedError: when node attributes are not set.
            TypeError: when resulting graph contains cycles.

        Returns:
            nx.DiGraph: DAG with new edges added.
        """
        rng = self.random_state
        if rng is None:
            rng = np.random.default_rng()
        arrow_head_candidates = get_arrow_head_candidates_from_graph(
            graph=self.graph, node_attributes_to_filter=NodeAttributes.ALLOW_IN_EDGES
        )

        arrow_tail_candidates = [node for node in self.nodes if node not in arrow_head_candidates]

        potential_edges = tuples_from_cartesian_product(
            l1=arrow_tail_candidates, l2=arrow_head_candidates
        )
        num_choices = int(np.ceil(sparsity * len(potential_edges)))

        ### choose edges uniformly according to sparsity parameter
        chosen_edges = [
            potential_edges[i]
            for i in rng.choice(a=len(potential_edges), size=num_choices, replace=False)
        ]

        self.graph.add_edges_from(chosen_edges)

        if not nx.is_directed_acyclic_graph(self.graph):
            raise TypeError(
                "The randomly chosen edges induce cycles, this is not supposed to happen."
            )
        return self.graph

    def __repr__(self):
        """Repr method.

        Returns:
            _type_: _description_
        """
        return f"ProcessCell(name={self.name})"

    def __str__(self):
        """Str method.

        Returns:
            _type_: _description_
        """
        cell_description = {
            "Cell Name: ": self.name,
            "Description:": self.description if self.description else "n.a.",
            "Modules:": self.no_of_modules,
            "Nodes: ": self.num_nodes,
        }
        s = ""
        for info, info_text in cell_description.items():
            s += f"{info:<14}{info_text:>5}\n"

        return s

    def __verify_edges_are_allowed(self, m1: str, m2: str, edges: list[tuple]):
        """Check whether all starting point nodes (first value in edge tuple) are allowed.

        Args:
            m1 (str): Module1
            m2 (str): Module2
            edges (list[tuple]): Edges

        Raises:
            ValueError: starting node not in M1
            ValueError: ending node not in M2
        """
        source_nodes = set([e[0] for e in edges])
        target_nodes = set([e[1] for e in edges])
        m1_nodes = set(self.modules.get(m1).nodes())  # type: ignore
        m2_nodes = set(self.modules.get(m2).nodes())  # type: ignore

        if not source_nodes.issubset(m1_nodes):
            raise ValueError(f"source nodes: {source_nodes} not include in {m1}s nodes: {m1_nodes}")
        if not target_nodes.issubset(m2_nodes):
            raise ValueError(f"target nodes: {target_nodes} not include in {m2}s nodes: {m2_nodes}")

    def next_module_prefix(self) -> str:
        """Return the next module prefix, e.g.

        if there are already 3 modules connected to the cell,
        will return module_prefix4

        Returns:
            str: module_prefix
        """
        return f"{self.__module_prefix}{self.no_of_modules + 1}"

    @property
    def module_prefix(self) -> str:
        """Module prefix.

        Returns:
            str: _description_
        """
        return self.__module_prefix

    @module_prefix.setter
    def module_prefix(self, module_prefix: str):
        if not isinstance(module_prefix, str):
            raise ValueError("please only use strings as module prefix")

        self.__module_prefix = module_prefix

    @property
    def no_of_modules(self) -> int:
        """Number of modules in the cell.

        Returns:
            int: _description_
        """
        return len(self.modules)

    def get_available_attributes(self):
        """Get available attributes of the nodes in the graph.

        Returns:
            _type_: _description_
        """
        available_attributes = set()
        for node_tuple in self.graph.nodes(data=True):
            for attribute_name in node_tuple[1].keys():
                available_attributes.add(attribute_name)

        return list(available_attributes)

    def to_cpdag(self) -> PDAG:
        """To CPDAG conversion.

        Returns:
            PDAG: _description_
        """
        return dag2cpdag(dag=self.graph)

    def show(
        self,
        meta_desc: str = "",
    ):
        """Plots the cell graph.

        by giving extra weight to nodes
        with high in- and out-degree.

        Args:
            meta_desc (str, optional): Defaults to "".

        """
        cmap = plt.get_cmap("cividis")
        fig, ax = plt.subplots()
        center: np.ndarray = np.array([0, 0])
        pos = nx.spring_layout(
            self.graph,
            center=center,
            seed=10,
            k=50,
        )

        max_in_degree = max([d for _, d in self.graph.in_degree()])
        max_out_degree = max([d for _, d in self.graph.out_degree()])

        nx.draw_networkx_nodes(
            self.graph,
            pos=pos,
            ax=ax,
            cmap=cmap,
            vmin=-0.2,
            vmax=1,
            node_color=[
                (d + 10) / (max_in_degree + 10)
                for _, d in self.graph.in_degree(self.nodes)  # type: ignore
            ],
            node_size=[
                500 * (d + 1) / (max_out_degree + 1) for _, d in self.graph.out_degree(self.nodes)
            ],  # type: ignore
        )

        nx.draw_networkx_edges(
            self.graph,
            pos=pos,
            ax=ax,
            alpha=0.2,
            arrowsize=8,
            width=0.5,
            connectionstyle="arc3,rad=0.3",
        )

        ax.text(
            center[0],
            center[1] + 1.2,
            self.name + f"\n{meta_desc}",
            horizontalalignment="center",
            fontsize=12,
        )

        ax.axis("off")

    def _plot_cellgraph(
        self,
        ax,
        node_color,
        node_size,
        center=np.array([0, 0]),
        with_edges=True,
        with_box=True,
        meta_desc="",
    ):
        """Plots the cell graph.

        by giving extra weight to nodes
        with high in- and out-degree.

        Args:
            ax (_type_): _description_
            node_color (_type_): _description_
            node_size (_type_): _description_
            center (_type_, optional): _description_. Defaults to np.array([0, 0]).
            with_edges (bool, optional): _description_. Defaults to True.
            with_box (bool, optional): _description_. Defaults to True.
            meta_desc (str, optional): _description_. Defaults to "".

        Returns:
            _type_: _description_
        """
        cmap = plt.get_cmap("cividis")

        pos = nx.spring_layout(
            self.graph,
            center=center,
            seed=10,
            k=50,
        )

        nx.draw_networkx_nodes(
            self.graph,
            pos=pos,
            ax=ax,
            cmap=cmap,
            vmin=-0.2,
            vmax=1,
            node_color=node_color,
            node_size=node_size,
        )

        if with_edges:
            nx.draw_networkx_edges(
                self.graph,
                pos=pos,
                ax=ax,
                alpha=0.2,
                arrowsize=8,
                width=0.5,
                connectionstyle="arc3,rad=0.3",
            )

        ax.text(
            center[0],
            center[1] + 1.2,
            self.name + f"\n{meta_desc}",
            horizontalalignment="center",
            fontsize=12,
        )

        if with_box:
            ax.add_collection(
                PatchCollection(
                    [
                        FancyBboxPatch(
                            center - [2, 1],  # type: ignore
                            4,
                            2.6,
                            boxstyle=BoxStyle("Round", pad=0.02),
                        )
                    ],
                    alpha=0.2,
                    color="gray",
                )
            )

        ax.axis("off")
        return pos

causal_order property

Returns the causal order of the current graph.

Note that this order is in general not unique.

Returns:

    list[str]: Causal order

edges property

Edges in the graph.

Returns:

    list[tuple]

ground_truth property

Returns the current ground truth as pandas adjacency.

Returns:

    pd.DataFrame: Adjacency matrix.

module_prefix property writable

Module prefix.

Returns:

    str: description

no_of_modules property

Number of modules in the cell.

Returns:

    int: description

nodes property

Nodes in the graph.

Returns:

    list[str]

num_edges property

Number of edges in the graph.

Returns:

    int

num_nodes property

Number of nodes in the graph.

Returns:

    int

sparsity property

Sparsity of the graph.

Returns:

    float: in [0,1]
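
In other words, the property computes num_edges / (n * (n - 1) / 2), i.e. the fraction of the n(n-1)/2 edges a DAG on n nodes could maximally have (see the class source above).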

__init__(name)

Inits Process cell class.

Parameters:

    name (str): description. Required.
Source code in causalAssembly/models_dag.py
def __init__(self, name: str):
    """Inits Process cell class.

    Args:
        name (str): _description_
    """
    self.name = name
    self.graph: nx.DiGraph = nx.DiGraph()

    self.description: str = ""  # description of the cell.

    self.__module_prefix = "M"  # M01 vs M1?
    self.modules: dict[str, nx.DiGraph] = dict()  # {'M1': nx.DiGraph, 'M2': nx.DiGraph}
    self.module_connectors: list[tuple] = list()

    self.is_eol = False
    self.random_state = None
    self.drf: dict = dict()

__repr__()

Repr method.

Returns:

    _type_: description

Source code in causalAssembly/models_dag.py
def __repr__(self):
    """Repr method.

    Returns:
        _type_: _description_
    """
    return f"ProcessCell(name={self.name})"

__str__()

Str method.

Returns:

    _type_: description

Source code in causalAssembly/models_dag.py
def __str__(self):
    """Str method.

    Returns:
        _type_: _description_
    """
    cell_description = {
        "Cell Name: ": self.name,
        "Description:": self.description if self.description else "n.a.",
        "Modules:": self.no_of_modules,
        "Nodes: ": self.num_nodes,
    }
    s = ""
    for info, info_text in cell_description.items():
        s += f"{info:<14}{info_text:>5}\n"

    return s

__verify_edges_are_allowed(m1, m2, edges)

Check whether all starting point nodes (first value in edge tuple) are allowed.

Parameters:

    m1 (str): Module1. Required.
    m2 (str): Module2. Required.
    edges (list[tuple]): Edges. Required.

Raises:

    ValueError: starting node not in M1
    ValueError: ending node not in M2

Source code in causalAssembly/models_dag.py
def __verify_edges_are_allowed(self, m1: str, m2: str, edges: list[tuple]):
    """Check whether all starting point nodes (first value in edge tuple) are allowed.

    Args:
        m1 (str): Module1
        m2 (str): Module2
        edges (list[tuple]): Edges

    Raises:
        ValueError: starting node not in M1
        ValueError: ending node not in M2
    """
    source_nodes = set([e[0] for e in edges])
    target_nodes = set([e[1] for e in edges])
    m1_nodes = set(self.modules.get(m1).nodes())  # type: ignore
    m2_nodes = set(self.modules.get(m2).nodes())  # type: ignore

    if not source_nodes.issubset(m1_nodes):
        raise ValueError(f"source nodes: {source_nodes} not include in {m1}s nodes: {m1_nodes}")
    if not target_nodes.issubset(m2_nodes):
        raise ValueError(f"target nodes: {target_nodes} not include in {m2}s nodes: {m2_nodes}")

add_module(graph, allow_in_edges=True, mark_hidden=False)

Adds module to cell graph. Module has to be an nx.DiGraph object.

Parameters:

    graph (nx.DiGraph): Graph to add to cell. Required.
    allow_in_edges (bool, optional): whether nodes in the module are allowed to have in-edges. Defaults to True.
    mark_hidden (bool | list, optional): If False all nodes' 'is_hidden' attribute is set to False. If list of node names is provided these get overwritten to True. Defaults to False.

Returns:

    str: prefix of Module created

Source code in causalAssembly/models_dag.py
def add_module(
    self,
    graph: nx.DiGraph,
    allow_in_edges: bool = True,
    mark_hidden: bool | list = False,
) -> str:
    """Adds module to cell graph. Module has to be as nx.DiGraph object.

    Args:
        graph (nx.DiGraph): Graph to add to cell.
        allow_in_edges (bool, optional):
            whether nodes in the module are allowed to
            have in-edges. Defaults to True.
        mark_hidden (bool | list, optional):
            If False all nodes' 'is_hidden' attribute is set to False.
            If list of node names is provided these get overwritten to True.
            Defaults to False.

    Returns:
        str: prefix of Module created
    """
    next_module_prefix = self.next_module_prefix()

    node_renaming_dict = {
        old_node_name: f"{self.name}_{next_module_prefix}_{old_node_name}"
        for old_node_name in graph.nodes()
    }
    self.modules[self.next_module_prefix()] = graph.copy()  # type: ignore
    graph = nx.relabel_nodes(graph, node_renaming_dict)

    if allow_in_edges:  # for later: mark nodes to not have incoming edges
        nx.set_node_attributes(graph, True, NodeAttributes.ALLOW_IN_EDGES)
    else:
        nx.set_node_attributes(graph, False, NodeAttributes.ALLOW_IN_EDGES)

    nx.set_node_attributes(
        graph, False, NodeAttributes.HIDDEN
    )  # set all non-hidden -> visible by default
    if isinstance(mark_hidden, list):
        mark_hidden_renamed = [
            f"{self.name}_{next_module_prefix}_{new_name}" for new_name in mark_hidden
        ]
        overwrite_dict = {node: {NodeAttributes.HIDDEN: True} for node in mark_hidden_renamed}
        nx.set_node_attributes(
            graph, values=overwrite_dict
        )  # only overwrite the ones specified
    self.graph = nx.compose(self.graph, graph)

    return next_module_prefix
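
A small sketch of the renaming and hiding behaviour (cell and node names are made up for illustration):

import networkx as nx

from causalAssembly.models_dag import ProcessCell

cell = ProcessCell(name="C1")
module = nx.DiGraph([("x1", "x2")])

# nodes are renamed to C1_M1_x1, C1_M1_x2; x2 is additionally marked as hidden
prefix = cell.add_module(graph=module, allow_in_edges=False, mark_hidden=["x2"])

hidden = nx.get_node_attributes(cell.graph, "is_hidden")
print(hidden[f"C1_{prefix}_x1"], hidden[f"C1_{prefix}_x2"])  # False True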

add_random_module(n_nodes=7, p=0.1)

Add random module to the cell.

Parameters:

    n_nodes (int, optional): description. Defaults to 7.
    p (float, optional): description. Defaults to 0.10.
Source code in causalAssembly/models_dag.py
def add_random_module(self, n_nodes: int = 7, p: float = 0.10):
    """Add random module to the cell.

    Args:
        n_nodes (int, optional): _description_. Defaults to 7.
        p (float, optional): _description_. Defaults to 0.10.
    """
    randomdag = self._generate_random_dag(n_nodes=n_nodes, p=p)
    self.add_module(graph=randomdag, allow_in_edges=True, mark_hidden=False)

connect_by_module(m1, m2, edges)

Connect two modules of the cell (by name, e.g. M2, M4) by a list of edges with the original node names.

Parameters:

    m1 (str): Required.
    m2 (str): Required.
    edges (list[tuple]): use the original node names before they have entered into the cell, i.e. not with Cy_Mx prefix. Required.
Source code in causalAssembly/models_dag.py
def connect_by_module(self, m1: str, m2: str, edges: list[tuple]):
    """Connect two modules.

    (by name, e.g. M2, M4) of the cell by a list
    of edges with the original node names.

    Args:
        m1: str
        m2: str
        edges: list[tuple]: use the original node names before they have entered into the cell,
            i.e. not with Cy_Mx prefix
    """
    self.__verify_edges_are_allowed(m1=m1, m2=m2, edges=edges)

    node_prefix_m1 = f"{self.name}_{m1}"
    node_prefix_m2 = f"{self.name}_{m2}"

    new_edges = [
        (f"{node_prefix_m1}_{edge[0]}", f"{node_prefix_m2}_{edge[1]}") for edge in edges
    ]

    [self.module_connectors.append(edge) for edge in new_edges]

    self.graph.add_edges_from(new_edges)
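
A minimal sketch of wiring two hand-built modules together (module contents are illustrative):

import networkx as nx

from causalAssembly.models_dag import ProcessCell

cell = ProcessCell(name="C3")
cell.add_module(graph=nx.DiGraph([("a", "b")]))  # becomes M1
cell.add_module(graph=nx.DiGraph([("u", "v")]))  # becomes M2

# edges are given with the original node names; prefixes are added internally
cell.connect_by_module(m1="M1", m2="M2", edges=[("b", "u")])

print(("C3_M1_b", "C3_M2_u") in cell.edges)  # True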

connect_by_random_edges(sparsity=0.1)

Add random edges to the graph according to the given proportion, with restrictions specified in node attributes.

Parameters:

    sparsity (float, optional): Sparsity parameter in (0,1). Defaults to 0.1.

Raises:

    NotImplementedError: when node attributes are not set.
    TypeError: when resulting graph contains cycles.

Returns:

    nx.DiGraph: DAG with new edges added.

Source code in causalAssembly/models_dag.py
def connect_by_random_edges(self, sparsity: float = 0.1) -> nx.DiGraph:
    """Add random edges to graph.

    according to proportion
    with restriction specified in node attributes.

    Args:
        sparsity (float, optional): Sparsity parameter in (0,1). Defaults to 0.1.

    Raises:
        NotImplementedError: when node attributes are not set.
        TypeError: when resulting graph contains cycles.

    Returns:
        nx.DiGraph: DAG with new edges added.
    """
    rng = self.random_state
    if rng is None:
        rng = np.random.default_rng()
    arrow_head_candidates = get_arrow_head_candidates_from_graph(
        graph=self.graph, node_attributes_to_filter=NodeAttributes.ALLOW_IN_EDGES
    )

    arrow_tail_candidates = [node for node in self.nodes if node not in arrow_head_candidates]

    potential_edges = tuples_from_cartesian_product(
        l1=arrow_tail_candidates, l2=arrow_head_candidates
    )
    num_choices = int(np.ceil(sparsity * len(potential_edges)))

    ### choose edges uniformly according to sparsity parameter
    chosen_edges = [
        potential_edges[i]
        for i in rng.choice(a=len(potential_edges), size=num_choices, replace=False)
    ]

    self.graph.add_edges_from(chosen_edges)

    if not nx.is_directed_acyclic_graph(self.graph):
        raise TypeError(
            "The randomly chosen edges induce cycles, this is not supposed to happen."
        )
    return self.graph
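
A sketch of adding random edges between modules, assuming the in-edge restrictions were set when the modules were added (the cell layout below is arbitrary):

import networkx as nx
import numpy as np

from causalAssembly.models_dag import ProcessCell

cell = ProcessCell(name="C4")
cell.random_state = np.random.default_rng(seed=42)

# first module acts as a source: its nodes do not accept in-edges
cell.add_module(graph=nx.DiGraph([("a", "b")]), allow_in_edges=False)
# second module may receive in-edges
cell.add_module(graph=nx.DiGraph([("u", "v")]), allow_in_edges=True)

dag = cell.connect_by_random_edges(sparsity=0.5)  # returns the cell's nx.DiGraph
print(dag.number_of_edges())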

get_available_attributes()

Get available attributes of the nodes in the graph.

Returns:

    _type_: description

Source code in causalAssembly/models_dag.py
def get_available_attributes(self):
    """Get available attributes of the nodes in the graph.

    Returns:
        _type_: _description_
    """
    available_attributes = set()
    for node_tuple in self.graph.nodes(data=True):
        for attribute_name in node_tuple[1].keys():
            available_attributes.add(attribute_name)

    return list(available_attributes)

input_cellgraph_directly(graph, allow_in_edges=False)

Allow inputting graphs at the cell level.

This should only be done if the graph is already available for the entire cell, otherwise add_module() is preferred.

Parameters:

    graph (nx.DiGraph): Cell graph to input. Required.
    allow_in_edges (bool, optional): Defaults to False.
Source code in causalAssembly/models_dag.py
def input_cellgraph_directly(self, graph: nx.DiGraph, allow_in_edges: bool = False):
    """Allow to input graphs on a cell-level.

    This should only be done if the graph
    is already available for the entire cell, otherwise `add_module()` is preferred.

    Args:
        graph (nx.DiGraph): Cell graph to input
        allow_in_edges (bool, optional): Defaults to False.
    """
    if allow_in_edges:  # for later: mark nodes to not have incoming edges
        nx.set_node_attributes(graph, True, NodeAttributes.ALLOW_IN_EDGES)
    else:
        nx.set_node_attributes(graph, False, NodeAttributes.ALLOW_IN_EDGES)

    node_renaming_dict = {
        old_node_name: f"{self.name}_{old_node_name}" for old_node_name in graph.nodes()
    }
    graph = nx.relabel_nodes(graph, node_renaming_dict)

    self.graph = nx.compose(self.graph, graph)
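
A sketch of inputting a complete cell graph directly (sensor names are invented for illustration):

import networkx as nx

from causalAssembly.models_dag import ProcessCell

cell = ProcessCell(name="C5")

# the whole cell graph at once; nodes get the cell prefix only, no module prefix
cell_graph = nx.DiGraph([("temp", "pressure"), ("pressure", "flow")])
cell.input_cellgraph_directly(graph=cell_graph, allow_in_edges=False)

print(cell.nodes)  # e.g. ['C5_temp', 'C5_pressure', 'C5_flow']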

next_module_prefix()

Return the next module prefix, e.g. if there are already 3 modules connected to the cell, this will return the module prefix followed by 4 (such as M4).

Returns:

    str: module_prefix

Source code in causalAssembly/models_dag.py
def next_module_prefix(self) -> str:
    """Return the next module prefix, e.g.

    if there are already 3 modules connected to the cell,
    will return module_prefix4

    Returns:
        str: module_prefix
    """
    return f"{self.__module_prefix}{self.no_of_modules + 1}"

parents(of_node)

Return parents of node in question.

Parameters:

    of_node (str): Node in question. Required.

Returns:

    list[str]: parent set.

Source code in causalAssembly/models_dag.py
def parents(self, of_node: str) -> list[str]:
    """Return parents of node in question.

    Args:
        of_node (str): Node in question.

    Returns:
        list[str]: parent set.
    """
    return list(self.graph.predecessors(of_node))

sample_from_drf(size=10, smoothed=True)

Draw from the trained DRF.

Parameters:

    size (int, optional): Number of samples to be drawn. Defaults to 10.
    smoothed (bool, optional): If set to true, marginal distributions will be sampled from smoothed bootstraps. Defaults to True.

Returns:

    pd.DataFrame: Data frame that follows the distribution implied by the ground truth.

Source code in causalAssembly/models_dag.py
def sample_from_drf(self, size=10, smoothed: bool = True) -> pd.DataFrame:
    """Draw from the trained DRF.

    Args:
        size (int, optional): Number of samples to be drawn. Defaults to 10.
        smoothed (bool, optional): If set to true, marginal distributions will
            be sampled from smoothed bootstraps. Defaults to True.

    Returns:
        pd.DataFrame: Data frame that follows the distribution implied by the ground truth.
    """
    return _sample_from_drf(prod_object=self, size=size, smoothed=smoothed)
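
A usage sketch, assuming cell is a ProcessCell whose DRF has already been fitted (the training step is not shown on this page):

# draw 500 semi-synthetic rows consistent with the cell's ground-truth DAG
df = cell.sample_from_drf(size=500, smoothed=True)
print(df.shape)  # e.g. (500, number of observed columns)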

save_drf(filename, location=None)

Writes a drf dict to file. Please provide the .pkl suffix!

Parameters:

    filename (str): name of the file to be written, e.g. examplefile.pkl. Required.
    location (str, optional): path to file in case it's not located in the current working directory. Defaults to None.
Source code in causalAssembly/models_dag.py
def save_drf(self, filename: str, location: str | Path | None = None):
    """Writes a drf dict to file. Please provide the .pkl suffix!

    Args:
        filename (str): name of the file to be written e.g. examplefile.pkl
        location (str, optional): path to file in case it's not located in
            the current working directory. Defaults to None.
    """
    if not location:
        location = Path().resolve()

    location_path = Path(location, filename)

    with open(location_path, "wb") as f:
        pickle.dump(self.drf, f)
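
A short sketch of saving and reloading the pickled DRF dict (filename is arbitrary; assumes cell is a ProcessCell with a fitted DRF):

import pickle

cell.save_drf(filename="my_cell_drf.pkl")  # writes to the current working directory

with open("my_cell_drf.pkl", "rb") as f:
    drf_dict = pickle.load(f)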

show(meta_desc='')

Plots the cell graph, giving extra weight to nodes with high in- and out-degree.

Parameters:

    meta_desc (str, optional): Defaults to "".
Source code in causalAssembly/models_dag.py
def show(
    self,
    meta_desc: str = "",
):
    """Plots the cell graph.

    by giving extra weight to nodes
    with high in- and out-degree.

    Args:
        meta_desc (str, optional): Defaults to "".

    """
    cmap = plt.get_cmap("cividis")
    fig, ax = plt.subplots()
    center: np.ndarray = np.array([0, 0])
    pos = nx.spring_layout(
        self.graph,
        center=center,
        seed=10,
        k=50,
    )

    max_in_degree = max([d for _, d in self.graph.in_degree()])
    max_out_degree = max([d for _, d in self.graph.out_degree()])

    nx.draw_networkx_nodes(
        self.graph,
        pos=pos,
        ax=ax,
        cmap=cmap,
        vmin=-0.2,
        vmax=1,
        node_color=[
            (d + 10) / (max_in_degree + 10)
            for _, d in self.graph.in_degree(self.nodes)  # type: ignore
        ],
        node_size=[
            500 * (d + 1) / (max_out_degree + 1) for _, d in self.graph.out_degree(self.nodes)
        ],  # type: ignore
    )

    nx.draw_networkx_edges(
        self.graph,
        pos=pos,
        ax=ax,
        alpha=0.2,
        arrowsize=8,
        width=0.5,
        connectionstyle="arc3,rad=0.3",
    )

    ax.text(
        center[0],
        center[1] + 1.2,
        self.name + f"\n{meta_desc}",
        horizontalalignment="center",
        fontsize=12,
    )

    ax.axis("off")

to_cpdag()

To CPDAG conversion.

Returns:

    PDAG: description

Source code in causalAssembly/models_dag.py
def to_cpdag(self) -> PDAG:
    """To CPDAG conversion.

    Returns:
        PDAG: _description_
    """
    return dag2cpdag(dag=self.graph)

ProductionLineGraph

Blueprint of a Production Line Graph.

A Production Line consists of multiple Cells; each Cell can contain multiple modules. Modules can be instantiated randomly or manually. Cellgraphs and linegraphs can be instantiated directly from nx.DiGraph objects. Similarly, edges can be drawn at random (obeying certain probability choices that can be set by the user) between cells/modules, or manually.

Besides populating a production line with cell/module-graphs one can obtain semi-synthetic data obeying the standard causal assumptions:

1. Markov Property
2. Faithfulness

This can be achieved by fitting distributional random forests to the line/cell-graphs and drawing data from these. A random number stream is initiated when calling this class. If desired, it can be overwritten manually.
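
A minimal end-to-end sketch (cell count, sizes, and probabilities are arbitrary; generating data additionally requires a fitted DRF, which is not shown here):

from causalAssembly.models_dag import ProductionLineGraph

line = ProductionLineGraph()

# three cells, each populated with a random module
for _ in range(3):
    cell = line.new_cell()
    cell.add_random_module(n_nodes=5, p=0.2)

# draw random forward edges between consecutive cells
line.connect_cells(forward_probs=[0.1, 0.05])

print(line.num_nodes, line.num_edges)
print(line.ground_truth.shape)  # adjacency matrix as a pandas DataFrame

Alternatively, ProductionLineGraph.get_ground_truth() loads the benchmark line from the paper and ProductionLineGraph.get_data() the corresponding semi-synthetic dataset.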

Source code in causalAssembly/models_dag.py
class ProductionLineGraph:
    """Blueprint of a Production Line Graph.

    A Production Line consists of multiple Cells; each Cell can contain multiple modules.
    Modules can be instantiated randomly or manually. Cellgraphs and linegraphs can be
    instantiated directly from nx.DiGraph objects. Similarly, edges can be drawn at random
    (obeying certain probability choices that can be set by the user) between cells/modules
    or manually.

    Besides populating a production line with cell/module-graphs one can obtain
    semi-synthetic data obeying the standard causal assumptions:

        1. Markov Property
        2. Faithfulness

    This can be achieved by fitting distributional random forests to the line/cell-graphs
    and drawing data from these. A random number stream is initiated when calling this class.
    If desired this can be overwritten manually.

    """

    def __init__(self):
        """Inits ProductionLineGraph.

        Raises:
            AssertionError: _description_
            TypeError: _description_
            AssertionError: _description_
            AssertionError: _description_
            AssertionError: _description_
            ValueError: _description_
            AssertionError: _description_
            AssertionError: _description_
            ValueError: _description_
            ValueError: _description_
            ValueError: _description_
            TypeError: _description_
            AssertionError: _description_
            AttributeError: _description_

        Returns:
            _type_: _description_
        """
        self._random_state = np.random.default_rng(seed=2023)
        self.cells: dict[str, ProcessCell] = dict()
        self.cell_prefix = "C"
        self.cell_connectors: list[tuple] = list()
        self.cell_connector_edges = list()
        self.cell_order = list()
        self.drf: dict = dict()
        self.interventional_drf: dict = dict()
        self.__init_mutilated_dag()

    @property
    def random_state(self):
        """Random state.

        Returns:
            _type_: _description_
        """
        return self._random_state

    @random_state.setter
    def random_state(self, r: np.random.Generator):
        if not isinstance(r, np.random.Generator):
            raise AssertionError("Specify numpy random number generator object!")
        self._random_state = r

    def __init_mutilated_dag(self):
        self.mutilated_dags = dict()

    @property
    def graph(self) -> nx.DiGraph:
        """Returns a nx.DiGraph object of the actual graph.

        The graph is only built HERE, i.e. all ProcessCells exist standalone in self.cells,
        with no connections between their nodes yet.

        The edges are stored in self.cell_connector_edges, where they are added by random methods
        or by the user (the DAG builder) directly.

        ATTENTION: you cannot manually add edges or nodes to self.graph and expect them to
        persist.

        Returns nx.DiGraph

        -------

        """
        g = nx.DiGraph()
        for cell in self.cells.values():
            g = nx.compose(g, cell.graph)

        g.add_edges_from(self.cell_connector_edges)

        if not nx.is_directed_acyclic_graph(g):
            raise TypeError(
                "There are cycles in the graph, \
                this is not supposed to happen."
            )

        return g

    @property
    def nodes(self) -> list[str]:
        """Nodes in the graph.

        Returns:
            list[str]
        """
        return list(self.graph.nodes())

    @property
    def edges(self) -> list[tuple]:
        """Edges in the graph.

        Returns:
            list[tuple]
        """
        return list(self.graph.edges())

    @property
    def num_nodes(self) -> int:
        """Number of nodes in the graph.

        Returns:
            int
        """
        return len(self.nodes)

    @property
    def num_edges(self) -> int:
        """Number of edges in the graph.

        Returns:
            int
        """
        return len(self.edges)

    @property
    def sparsity(self) -> float:
        """Sparsity of the graph.

        Returns:
            float: in [0,1]
        """
        s = self.num_nodes
        return self.num_edges / s / (s - 1) * 2

    @property
    def ground_truth(self) -> pd.DataFrame:
        """Returns the current ground truth as pandas adjacency.

        Returns:
            pd.DataFrame: Adjacency matrix.
        """
        return nx.to_pandas_adjacency(self.graph, weight=None)

    def _get_union_graph(self) -> nx.DiGraph:
        if not self.cells:
            raise AssertionError("Your pline has no cells. Within has no meaning.")
        union_graph = nx.DiGraph()
        for _, station_graph in self.cells.items():
            union_graph = nx.union(union_graph, station_graph.graph)
        return union_graph

    @property
    def within_adjacency(self) -> pd.DataFrame:
        """Returns adjacency matrix ignoring all between-cell edges.

        Returns:
            pd.DataFrame: adjacency matrix
        """
        union_graph = self._get_union_graph()
        return nx.to_pandas_adjacency(union_graph)

    @property
    def between_adjacency(self) -> pd.DataFrame:
        """Returns adjacency matrix ignoring all within-cell edges.

        Returns:
            pd.DataFrame: adjacency matrix
        """
        union_graph = self._get_union_graph()
        return nx.to_pandas_adjacency(nx.difference(self.graph, union_graph))

    @property
    def causal_order(self) -> list[str]:
        """Returns the causal order of the current graph.

        Note that this order is in general not unique.

        Returns:
            list[str]: Causal order
        """
        return list(nx.topological_sort(self.graph))

    def parents(self, of_node: str) -> list[str]:
        """Return parents of node in question.

        Args:
            of_node (str): Node in question.

        Returns:
            list[str]: parent set.
        """
        return list(self.graph.predecessors(of_node))

    def to_cpdag(self) -> PDAG:
        """Convert to CPDAG.

        Returns:
            PDAG: _description_
        """
        return dag2cpdag(dag=self.graph)

    def get_nodes_of_station(self, station_name: str) -> list:
        """Returns nodes in chosen Station.

        Args:
            station_name (str): name of station.

        Raises:
            AssertionError: if station name doesn't match pline.

        Returns:
            list: nodes in chosen station
        """
        if station_name not in self.cell_order:
            raise AssertionError("Station name not among cells.")

        return self.cells[station_name].nodes

    def __add_cell(self, cell: ProcessCell) -> ProcessCell:
        cell_names = [cell_name for cell_name in self.cells.keys()]

        if cell.is_eol and any([cell.is_eol for cell in self.cells.values()]):
            raise AssertionError(
                f"Cell: {[cell for cell in self.cells.values() if cell.is_eol]} "
                f"is already EOL Cell in ProductionLineGraph."
            )

        if cell.name not in cell_names:
            self.cells[cell.name] = cell

            return cell

        raise ValueError(f"A cell with name: {cell.name} is already in the Production Line.")

    def new_cell(self, name: str | None = None, is_eol: bool = False) -> ProcessCell:
        """Add a new cell to the production line.

        If no name is given, cell name is given by counting available cells + 1

        Args:
            name (str, optional): Defaults to None.
            is_eol (bool, optional): Whether cell is end-of-line. Defaults to False.

        Returns:
            ProcessCell
        """
        if name:
            c = ProcessCell(name=name)

        else:
            actual_no_of_cells = len(self.cells.values())
            c = ProcessCell(name=f"{self.cell_prefix}{actual_no_of_cells}")

        c.random_state = self.random_state  # type: ignore

        c.is_eol = is_eol
        self.__add_cell(cell=c)
        self.cell_order.append(c.name)
        return c

    def connect_cells(
        self,
        forward_probs: list[float] = [0.1, 0.05],
    ):
        """Randomly connects cells in a ProductionLineGraph according to a forwards logic.

        Args:
            forward_probs (list[float], optional): Array of sparsity scalars of
                dimension max_forward. Defaults to [0.1, 0.05].
        """
        # assume cells are initiated in order
        # otherwise allow to change order

        max_forward = len(forward_probs)
        cells = list(self.cells.values())
        no_of_cells = len(self.cells)

        for cell_idx, cell in enumerate(cells):
            prob_it = 0  # prob it(erator)

            for forwards in range(1, max_forward + 1):
                forward_cell_idx = cell_idx + forwards

                if forward_cell_idx < no_of_cells:
                    forward_cell = cells[forward_cell_idx]
                    chosen_edges = choose_edges_from_cells_randomly(
                        from_cell=cell,
                        to_cell=forward_cell,
                        probability=forward_probs[prob_it],
                        rng=self.random_state,
                    )

                    prob_it += 1
                    self.cell_connector_edges.extend(chosen_edges)

            if eol_cell := self.eol_cell:
                eol_cell_idx = cells.index(eol_cell)

                if cell_idx + max_forward < eol_cell_idx:
                    eol_prob = forward_probs[-1]
                    chosen_eol_edges = choose_edges_from_cells_randomly(
                        from_cell=cell,
                        to_cell=eol_cell,
                        probability=eol_prob,
                        rng=self.random_state,
                    )

                    self.cell_connector_edges.extend(chosen_eol_edges)

    def copy(self) -> ProductionLineGraph:
        """Makes a full copy of the current ProductionLineGraph object.

        Returns:
            ProductionLineGraph: copied object.
        """
        copy_graph = ProductionLineGraph()
        for station in self.cell_order:
            copy_graph.new_cell(station)
            # make sure its sorted
            sorted_graph = nx.DiGraph()
            sorted_graph.add_nodes_from(
                sorted(self.cells[station].nodes, key=lambda x: int(x.rpartition("_")[2]))
            )
            sorted_graph.add_edges_from(self.cells[station].edges)
            copy_graph.cells[station].graph = sorted_graph

        between_cell_edges = nx.difference(self.graph, copy_graph.graph).edges()
        copy_graph.connect_across_cells_manually(edges=between_cell_edges)
        return copy_graph

    def connect_across_cells_manually(self, edges: list[tuple]):
        """Add edges manually across cells.

        You need to give the full name
        Args:
            edges (list[tuple]): list of edges to add
        """
        self.cell_connector_edges.extend(edges)

    def intervene_on(self, nodes_values: dict[str, RandomSymbol | float]):
        """Specify hard or soft intervention.

        If you want to intervene
        upon more than one node provide a list of nodes to intervene on
        and a list of corresponding values to set these nodes to.
        (see example). The mutilated dag will automatically be
        stored in `mutilated_dags`.

        Args:
            nodes_values (dict[str, RandomSymbol | float]): either single real
                number or sympy.stats.RandomSymbol. If you like to intervene on
                more than one node, just provide more key-value pairs.

        Raises:
            AssertionError: If node(s) are not in the graph
        """
        if not self.drf:
            raise AssertionError("You need to train a drf first.")
        drf_replace = {}

        if not set(nodes_values.keys()).issubset(set(self.nodes)):
            raise AssertionError(
                "One or more nodes you want to intervene upon are not in the graph."
            )

        mutilated_dag = self.graph.copy()

        for node, value in nodes_values.items():
            old_incoming = self.parents(of_node=node)
            edges_to_remove = [(old, node) for old in old_incoming]
            mutilated_dag.remove_edges_from(edges_to_remove)
            drf_replace[node] = value

        self.mutilated_dags[f"do({list(nodes_values.keys())})"] = (
            mutilated_dag  # specifying the same set twice will override
        )

        self.interventional_drf[f"do({list(nodes_values.keys())})"] = drf_replace

    @property
    def interventions(self) -> list:
        """Returns all interventions performed on the original graph.

        Returns:
            list: list of intervened upon nodes in do(x) notation.
        """
        return list(self.mutilated_dags.keys())

    def interventional_amat(self, which_intervention: int | str) -> pd.DataFrame:
        """Returns the adjacency matrix of a chosen mutilated DAG.

        Args:
            which_intervention (int | str): Integer count of your chosen intervention or
                literal string.

        Raises:
            ValueError: "The intervention you provide does not exist."

        Returns:
            pd.DataFrame: Adjacency matrix.
        """
        if isinstance(which_intervention, str) and which_intervention not in self.interventions:
            raise ValueError("The intervention you provide does not exist.")

        if isinstance(which_intervention, int) and which_intervention > len(self.interventions):
            raise ValueError("The intervention you index does not exist.")

        if isinstance(which_intervention, int):
            which_intervention = self.interventions[which_intervention]

        mutilated_dag = self.mutilated_dags[which_intervention].copy()
        return nx.to_pandas_adjacency(mutilated_dag, weight=None)

    @classmethod
    def get_ground_truth(cls) -> ProductionLineGraph:
        """Loads in the ground_truth as described in the paper.

        causalAssembly: Generating Realistic Production Data for
        Benchmarking Causal Discovery
        Returns:
            ProductionLineGraph: ground_truth for cells and line.
        """
        gt_response = requests.get(DATA_GROUND_TRUTH, timeout=5)
        ground_truth = json.loads(gt_response.text)

        assembly_line = json_graph.adjacency_graph(ground_truth)

        stations = ["Station1", "Station2", "Station3", "Station4", "Station5"]
        ground_truth_line = ProductionLineGraph()

        for station in stations:
            ground_truth_line.new_cell(station)
            station_nodes = [node for node in assembly_line.nodes if node.startswith(station)]
            station_subgraph = nx.subgraph(assembly_line, station_nodes)
            # make sure its sorted
            sorted_graph = nx.DiGraph()
            sorted_graph.add_nodes_from(
                sorted(station_nodes, key=lambda x: int(x.rpartition("_")[2]))
            )
            sorted_graph.add_edges_from(station_subgraph.edges)
            ground_truth_line.cells[station].graph = sorted_graph

        between_cell_edges = nx.difference(assembly_line, ground_truth_line.graph).edges()
        ground_truth_line.connect_across_cells_manually(edges=between_cell_edges)
        return ground_truth_line

    @classmethod
    def get_data(cls) -> pd.DataFrame:
        """Load in semi-synthetic data as described in the paper.

        causalAssembly: Generating Realistic Production Data for
        Benchmarking Causal Discovery
        Returns:
            pd.DataFrame: Data from which data should be generated.
        """
        return pd.read_csv(DATA_DATASET)

    @classmethod
    def from_nx(cls, g: nx.DiGraph, cell_mapper: dict[str, list]):
        """Convert nx.DiGraph to ProductionLineGraph.

        Requires a dict mapping
        where keys are cell names and values correspond to nodes within these cells.

        Args:
            g (nx.DiGraph): graph to be converted
            cell_mapper (dict[str, list]): dict to indicate what nodes belong to which cell

        Returns:
            ProductionLineGraph (ProductionLineGraph): the graph as a ProductionLineGraph object.
        """
        if not isinstance(g, nx.DiGraph):
            raise ValueError("Graph must be of type nx.DiGraph")
        pline = ProductionLineGraph()
        if cell_mapper:
            for cellname, cols in cell_mapper.items():
                pline.new_cell(name=cellname)
                cell_graph = nx.induced_subgraph(g, cols)
                pline.cells[cellname].input_cellgraph_directly(cell_graph, allow_in_edges=True)
        relabel_dict = {}
        for cellname, cols in cell_mapper.items():
            for col in cols:
                relabel_dict[col] = cellname + "_" + col

        g_rename = nx.relabel_nodes(g, relabel_dict)
        between_cell_edges = nx.difference(g_rename, pline.graph).edges()
        pline.connect_across_cells_manually(edges=between_cell_edges)
        return pline

    @classmethod
    def load_drf(cls, filename: str, location: str | Path | None = None):
        """Loads a drf dict from a .pkl file into the workspace.

        Args:
            filename (str): name of the file e.g. examplefile.pkl
            location (str, optional): path to file in case it's not located
                in the current working directory. Defaults to None.

        Returns:
            DRF (dict): dict of trained drf objects
        """
        if not location:
            location = Path().resolve()

        location_path = Path(location, filename)

        with open(location_path, "rb") as drf:
            pickle_drf = pickle.load(drf)

        return pickle_drf

    @classmethod
    def load_pline_from_pickle(cls, filename: str, location: str | Path | None = None):
        """Load production line graph from a pickle file.

        Args:
            filename (str): _description_
            location (str | Path | None, optional): _description_. Defaults to None.

        Raises:
            TypeError: _description_

        Returns:
            _type_: _description_
        """
        if not location:
            location = Path().resolve()

        location_path = Path(location, filename)

        with open(location_path, "rb") as pline:
            pickle_line = pickle.load(pline)

        if not isinstance(pickle_line, ProductionLineGraph):
            raise TypeError("You didn't refer to a ProductionLineGraph.")

        return pickle_line

    def save_drf(self, filename: str, location: str | Path | None = None):
        """Writes a drf dict to file. Please provide the .pkl suffix!

        Args:
            filename (str): name of the file to be written e.g. examplefile.pkl
            location (str, optional): path to file in case it's not located in
                the current working directory. Defaults to None.
        """
        if not location:
            location = Path().resolve()

        location_path = Path(location, filename)

        with open(location_path, "wb") as f:
            pickle.dump(self.drf, f)

    def sample_from_drf(self, size=10, smoothed: bool = True) -> pd.DataFrame:
        """Draw from the trained DRF.

        Args:
            size (int, optional): Number of samples to be drawn. Defaults to 10.
            smoothed (bool, optional): If set to true, marginal distributions will
                be sampled from smoothed bootstraps. Defaults to True.

        Returns:
            pd.DataFrame: Data frame that follows the distribution implied by the ground truth.
        """
        return _sample_from_drf(prod_object=self, size=size, smoothed=smoothed)

    def sample_from_interventional_drf(
        self, which_intervention: str | int = 0, size=10, smoothed: bool = True
    ) -> pd.DataFrame:
        """Draw from the trained and intervened upon DRF.

        Args:
            size (int, optional): Number of samples to be drawn. Defaults to 10.
            which_intervention (str | int): Which intervention to choose from.
                Both the literal name (see the property `interventions`) and the index
                are possible. Defaults to the first intervention.
            smoothed (bool, optional): If set to true, marginal distributions will
                be sampled from smoothed bootstraps. Defaults to True.

        Returns:
            pd.DataFrame: Data frame that follows the interventional distribution
                implied by the ground truth.
        """
        return _interventional_sample_from_drf(
            prod_object=self, which_intervention=which_intervention, size=size, smoothed=smoothed
        )

    def hidden_nodes(self) -> list:
        """Returns list of nodes marked as hidden.

        Returns:
            list: of hidden nodes
        """
        return [
            node
            for node, hidden in nx.get_node_attributes(self.graph, NodeAttributes.HIDDEN).items()
            if hidden is True
        ]

    def visible_nodes(self):
        """All visible nodes in the graph.

        Returns:
            list: all nodes not marked as hidden.
        """
        return [node for node in self.nodes if node not in self.hidden_nodes()]

    @property
    def eol_cell(self) -> ProcessCell | None:
        """Returns ProcessCell.

        the EOL cell
            (if any single cell has attr .is_eol = True), otherwise returns None
        """
        for cell in self.cells.values():
            if cell.is_eol:
                return cell

    @property
    def ground_truth_visible(self) -> pd.DataFrame:
        """Generates a ground truth graph in form of a pandas adjacency matrix.

        Row and column names correspond to visible nodes.
        The following integers can occur:

        amat[i,j] = 1 indicates i -> j
        amat[i,j] = 0 indicates no edge
        amat[i,j] = amat[j,i] = 2 indicates i <-> j and there exists a common hidden confounder

        Returns:
            pd.DataFrame: amat with values in {0,1,2}.
        """
        if len(self.hidden_nodes()) == 0:
            return self.ground_truth
        else:
            mediators = self._pairs_with_hidden_mediators()
            confounders = self._pairs_with_hidden_confounders()

            # here 1 indicates that ROWS has edge to COLUMNS!
            amat = nx.to_pandas_adjacency(self.graph)
            amat_visible = amat.loc[self.visible_nodes(), self.visible_nodes()]

            for pairs in mediators:
                amat_visible.loc[pairs] = 1

            # reverse = lambda tuples: tuples[::-1]

            def reverse(tuples):
                """Simple function to reverse tuple order.

                Args:
                    tuples (tuple): tuple to reverse order

                Returns:
                    tuple: tuple in reversed order
                """
                new_tup = tuples[::-1]
                return new_tup

            for pair, _ in confounders.items():
                amat_visible.loc[pair] = 2
                amat_visible.loc[reverse(pair)] = 2

            return amat_visible

    def show(self, meta_description: list | None = None, fig_size: tuple = (15, 8)):
        """Plot full assembly line.

        Args:
            meta_description (list | None, optional): Specify additional cell info.
                Defaults to None.
            fig_size (tuple, optional): Adjust depending on number of cells.
                Defaults to (15, 8).

        Raises:
            AssertionError: Meta list entry needs to exist for each cell!
        """
        _, ax = plt.subplots(figsize=fig_size)

        pos = {}

        if meta_description is None:
            meta_description = ["" for _ in range(len(self.cells))]

        if len(meta_description) != len(self.cells):
            raise AssertionError("Meta list entry needs to exist for each cell!")

        max_in_degree = max([d for _, d in self.graph.in_degree()])
        max_out_degree = max([d for _, d in self.graph.out_degree()])

        for i, (station_name, meta_desc) in enumerate(zip(self.cell_order, meta_description)):
            pos.update(
                self.cells[station_name]._plot_cellgraph(
                    ax=ax,
                    with_edges=False,
                    with_box=True,
                    meta_desc=meta_desc,
                    center=np.array([8 * i, 0]),
                    node_color=[
                        (d + 10) / (max_in_degree + 10)
                        for _, d in self.graph.in_degree(self.get_nodes_of_station(station_name))
                    ],
                    node_size=[
                        500 * (d + 1) / (max_out_degree + 1)
                        for _, d in self.graph.out_degree(self.get_nodes_of_station(station_name))
                    ],
                )
            )

        nx.draw_networkx_edges(
            self.graph,
            pos=pos,
            ax=ax,
            alpha=0.2,
            arrowsize=8,
            width=0.5,
            connectionstyle="arc3,rad=0.3",
        )

    def __str__(self):
        """String method for ProductionLineGraph.

        Returns:
            _type_: _description_
        """
        s = "ProductionLine\n\n"
        for cell in self.cells:
            s += f"{cell}\n"
        return s

    def __getattr__(self, attrname):
        """Get a cell by its name.

        Args:
            attrname (_type_): _description_

        Raises:
            AttributeError: _description_

        Returns:
            _type_: _description_
        """
        if attrname not in self.cells.keys():
            raise AttributeError(f"{attrname} is not a valid attribute (cell name?)")
        return self.cells[attrname]

    def __getstate__(self):
        """Get current state of the ProductionLineGraph.

        Returns:
            _type_: _description_
        """
        return (self.__dict__, self.cells)

    def __setstate__(self, state):
        """Set state of the ProductionLineGraph.

        Args:
            state (_type_): _description_
        """
        self.__dict__, self.cells = state

    @classmethod
    def via_cell_number(cls, n_cells: int, cell_prefix: str = "C"):
        """Inits a ProductionLineGraph with predefined number of cells, e.g. n_cells = 3.

        Will create empty  C0, C1 and C2 as cells if no other cell_prefix is given.

        Args:
            n_cells (int): Number of cells the graph will have
            cell_prefix (str, optional): If you like other cell names pass them here.
                Defaults to "C".

        """
        pl = cls()
        pl.cell_prefix = cell_prefix

        [pl.new_cell() for _ in range(n_cells)]

        return pl

    def _pairs_with_hidden_mediators(self):
        """Return pairs of nodes with hidden mediators present.

        Args:
            graph (nx.DiGraph): DAG
            visible (list): list of visible nodes

        Returns:
            list: list of tuples with pairs of nodes with hidden mediator
        """
        TWO = 2
        any_paths = []
        visible = self.visible_nodes()
        hidden_all = self.hidden_nodes()
        confounders = [node.pop() for _, node in self._pairs_with_hidden_confounders().items()]
        hidden = [node for node in hidden_all if node not in confounders]
        for i, _ in enumerate(visible):
            for j, _ in enumerate(visible):
                for path in sorted(nx.all_simple_paths(self.graph, visible[i], visible[j])):
                    any_paths.append(path)

        pairs_with_hidden_mediators = [
            (ls[0], ls[-1])
            for ls in any_paths
            if np.all(np.isin(ls[1:-1], hidden)) and len(ls) > TWO
        ]

        return pairs_with_hidden_mediators

    def _pairs_with_hidden_confounders(self) -> dict:
        """Returns node-pairs that have a common hidden confounder.

        Returns:
            dict: Dictionary with keys equal to tuples of node-pairs
            and values their corresponding hidden confounder(s)
        """
        confounder_pairs = {}
        visible = self.visible_nodes()
        pair_order_list = list(itertools.combinations(visible, 2))

        for node1, node2 in pair_order_list:
            ancestors1 = nx.ancestors(self.graph, node1)
            ancestors2 = nx.ancestors(self.graph, node2)
            if np.all(
                np.concatenate(
                    (
                        np.isin(list(ancestors1), visible, invert=True),
                        np.isin(list(ancestors2), visible, invert=True),
                    ),
                    axis=None,
                )
            ):  # annoying way of doing this. List comparison doesn't allow elementwise eval
                confounder = ancestors1.intersection(ancestors2)
                if confounder:  # only populate if set is non-empty
                    confounder_pairs[(node1, node2)] = confounder
            else:
                direct_parents1 = set(self.graph.predecessors(node1))
                direct_parents2 = set(self.graph.predecessors(node2))
                direct_confounder = [
                    node
                    for node in list(direct_parents1.intersection(direct_parents2))
                    if node not in visible
                ]
                if direct_confounder:
                    confounder_pairs[(node1, node2)] = direct_confounder

        return confounder_pairs

between_adjacency property

Returns adjacency matrix ignoring all within-cell edges.

Returns:

Type Description
DataFrame

pd.DataFrame: adjacency matrix

causal_order property

Returns the causal order of the current graph.

Note that this order is in general not unique.

Returns:

Type Description
list[str]

list[str]: Causal order

edges property

Edges in the graph.

Returns:

Type Description
list[tuple]

list[tuple]

eol_cell property

Returns the EOL cell.

The cell with attribute .is_eol = True (if any); otherwise returns None.

graph property

Returns a nx.DiGraph object of the actual graph.

The graph is only built HERE, i.e. all ProcessCells exist standalone in self.cells, with no connections between their nodes until this property assembles them.

The between-cell edges are stored in self.cell_connector_edges, where they are added either by the random connection methods or manually by the user (the DAG builder).

ATTENTION: you cannot add edges or nodes to self.graph directly and expect them to persist; changes must go through the cells or the connector edges.

Returns nx.DiGraph


ground_truth property

Returns the current ground truth as pandas adjacency.

Returns:

Type Description
DataFrame

pd.DataFrame: Adjacency matrix.

ground_truth_visible property

Generates a ground truth graph in form of a pandas adjacency matrix.

Row and column names correspond to visible nodes. The following integers can occur:

amat[i,j] = 1 indicates i -> j
amat[i,j] = 0 indicates no edge
amat[i,j] = amat[j,i] = 2 indicates i <-> j, i.e. there exists a common hidden confounder

Returns:

Type Description
DataFrame

pd.DataFrame: amat with values in {0,1,2}.

interventions property

Returns all interventions performed on the original graph.

Returns:

Name Type Description
list list

list of intervened upon nodes in do(x) notation.

nodes property

Nodes in the graph.

Returns:

Type Description
list[str]

list[str]

num_edges property

Number of edges in the graph.

Returns:

Type Description
int

int

num_nodes property

Number of nodes in the graph.

Returns:

Type Description
int

int

random_state property writable

Random state.

Returns:

Name Type Description
Generator

np.random.Generator: the random number generator used throughout the production line.

sparsity property

Sparsity of the graph.

Returns:

Name Type Description
float float

in [0,1]

within_adjacency property

Returns adjacency matrix ignoring all between-cell edges.

Returns:

Type Description
DataFrame

pd.DataFrame: adjacency matrix

__getattr__(attrname)

Get a cell by its name.

Parameters:

Name Type Description Default
attrname _type_

description

required

Raises:

Type Description
AttributeError

description

Returns:

Name Type Description
_type_

description

Source code in causalAssembly/models_dag.py
def __getattr__(self, attrname):
    """Get a cell by its name.

    Args:
        attrname (_type_): _description_

    Raises:
        AttributeError: _description_

    Returns:
        _type_: _description_
    """
    if attrname not in self.cells.keys():
        raise AttributeError(f"{attrname} is not a valid attribute (cell name?)")
    return self.cells[attrname]
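
Example (illustrative): because __getattr__ forwards unknown attribute names to the cells dictionary, cells can be accessed attribute-style; the cell name below is made up.

from causalAssembly.models_dag import ProductionLineGraph

line = ProductionLineGraph()
line.new_cell("Station1")

# attribute-style access resolves to line.cells["Station1"]
assert line.Station1 is line.cells["Station1"]

# unknown names raise AttributeError, e.g. line.DoesNotExist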

__getstate__()

Get current state of the ProductionLineGraph.

Returns:

Name Type Description
_type_

description

Source code in causalAssembly/models_dag.py
def __getstate__(self):
    """Get current state of the ProductionLineGraph.

    Returns:
        _type_: _description_
    """
    return (self.__dict__, self.cells)

__init__()

Inits ProductionLineGraph.

Raises:

Type Description
AssertionError

description

TypeError

description

AssertionError

description

AssertionError

description

AssertionError

description

ValueError

description

AssertionError

description

AssertionError

description

ValueError

description

ValueError

description

ValueError

description

TypeError

description

AssertionError

description

AttributeError

description

Returns:

Name Type Description
_type_

description

Source code in causalAssembly/models_dag.py
def __init__(self):
    """Inits ProductionLineGraph.

    Raises:
        AssertionError: _description_
        TypeError: _description_
        AssertionError: _description_
        AssertionError: _description_
        AssertionError: _description_
        ValueError: _description_
        AssertionError: _description_
        AssertionError: _description_
        ValueError: _description_
        ValueError: _description_
        ValueError: _description_
        TypeError: _description_
        AssertionError: _description_
        AttributeError: _description_

    Returns:
        _type_: _description_
    """
    self._random_state = np.random.default_rng(seed=2023)
    self.cells: dict[str, ProcessCell] = dict()
    self.cell_prefix = "C"
    self.cell_connectors: list[tuple] = list()
    self.cell_connector_edges = list()
    self.cell_order = list()
    self.drf: dict = dict()
    self.interventional_drf: dict = dict()
    self.__init_mutilated_dag()

__setstate__(state)

Set state of the ProductionLineGraph.

Parameters:

Name Type Description Default
state _type_

description

required
Source code in causalAssembly/models_dag.py
def __setstate__(self, state):
    """Set state of the ProductionLineGraph.

    Args:
        state (_type_): _description_
    """
    self.__dict__, self.cells = state

__str__()

String method for ProductionLineGraph.

Returns:

Name Type Description
_type_

description

Source code in causalAssembly/models_dag.py
def __str__(self):
    """String method for ProductionLineGraph.

    Returns:
        _type_: _description_
    """
    s = "ProductionLine\n\n"
    for cell in self.cells:
        s += f"{cell}\n"
    return s

connect_across_cells_manually(edges)

Add edges manually across cells.

You need to give the full node names (i.e. including the cell prefix).

Parameters:

Name Type Description Default
edges list[tuple]

list of edges to add

required

Source code in causalAssembly/models_dag.py
def connect_across_cells_manually(self, edges: list[tuple]):
    """Add edges manually across cells.

    You need to give the full node names (i.e. including the cell prefix).

    Args:
        edges (list[tuple]): list of edges to add
    """
    self.cell_connector_edges.extend(edges)

connect_cells(forward_probs=[0.1, 0.05])

Randomly connects cells in a ProductionLineGraph according to a forwards logic.

Parameters:

Name Type Description Default
forward_probs list[float]

Array of sparsity scalars of dimension max_forward. Defaults to [0.1, 0.05].

[0.1, 0.05]
Source code in causalAssembly/models_dag.py
def connect_cells(
    self,
    forward_probs: list[float] = [0.1, 0.05],
):
    """Randomly connects cells in a ProductionLineGraph according to a forwards logic.

    Args:
        forward_probs (list[float], optional): Array of sparsity scalars of
            dimension max_forward. Defaults to [0.1, 0.05].
    """
    # assume cells are initiated in order
    # otherwise allow to change order

    max_forward = len(forward_probs)
    cells = list(self.cells.values())
    no_of_cells = len(self.cells)

    for cell_idx, cell in enumerate(cells):
        prob_it = 0  # prob it(erator)

        for forwards in range(1, max_forward + 1):
            forward_cell_idx = cell_idx + forwards

            if forward_cell_idx < no_of_cells:
                forward_cell = cells[forward_cell_idx]
                chosen_edges = choose_edges_from_cells_randomly(
                    from_cell=cell,
                    to_cell=forward_cell,
                    probability=forward_probs[prob_it],
                    rng=self.random_state,
                )

                prob_it += 1
                self.cell_connector_edges.extend(chosen_edges)

        if eol_cell := self.eol_cell:
            eol_cell_idx = cells.index(eol_cell)

            if cell_idx + max_forward < eol_cell_idx:
                eol_prob = forward_probs[-1]
                chosen_eol_edges = choose_edges_from_cells_randomly(
                    from_cell=cell,
                    to_cell=eol_cell,
                    probability=eol_prob,
                    rng=self.random_state,
                )

                self.cell_connector_edges.extend(chosen_eol_edges)
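
Example (illustrative sketch): connect_cells only draws edges between nodes that already exist, so the per-cell graphs below are fed in via input_cellgraph_directly (as also done in from_nx); cell and node names are made up.

import networkx as nx
from causalAssembly.models_dag import ProductionLineGraph

line = ProductionLineGraph()
for name, edge in [("C0", ("a", "b")), ("C1", ("c", "d"))]:
    cell = line.new_cell(name)
    cell.input_cellgraph_directly(nx.DiGraph([edge]), allow_in_edges=True)

# draw roughly half of all candidate C0 -> C1 edges at random
line.connect_cells(forward_probs=[0.5])
print(line.cell_connector_edges)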

copy()

Makes a full copy of the current ProductionLineGraph object.

Returns:

Name Type Description
ProductionLineGraph ProductionLineGraph

copied object.

Source code in causalAssembly/models_dag.py
def copy(self) -> ProductionLineGraph:
    """Makes a full copy of the current ProductionLineGraph object.

    Returns:
        ProductionLineGraph: copied object.
    """
    copy_graph = ProductionLineGraph()
    for station in self.cell_order:
        copy_graph.new_cell(station)
        # make sure its sorted
        sorted_graph = nx.DiGraph()
        sorted_graph.add_nodes_from(
            sorted(self.cells[station].nodes, key=lambda x: int(x.rpartition("_")[2]))
        )
        sorted_graph.add_edges_from(self.cells[station].edges)
        copy_graph.cells[station].graph = sorted_graph

    between_cell_edges = nx.difference(self.graph, copy_graph.graph).edges()
    copy_graph.connect_across_cells_manually(edges=between_cell_edges)
    return copy_graph

from_nx(g, cell_mapper) classmethod

Convert nx.DiGraph to ProductionLineGraph.

Requires a dict mapping where keys are cell names and values correspond to nodes within these cells.

Parameters:

Name Type Description Default
g DiGraph

graph to be converted

required
cell_mapper dict[str, list]

dict to indicate what nodes belong to which cell

required

Returns:

Name Type Description
ProductionLineGraph ProductionLineGraph

the graph as a ProductionLineGraph object.

Source code in causalAssembly/models_dag.py
@classmethod
def from_nx(cls, g: nx.DiGraph, cell_mapper: dict[str, list]):
    """Convert nx.DiGraph to ProductionLineGraph.

    Requires a dict mapping
    where keys are cell names and values correspond to nodes within these cells.

    Args:
        g (nx.DiGraph): graph to be converted
        cell_mapper (dict[str, list]): dict to indicate what nodes belong to which cell

    Returns:
        ProductionLineGraph (ProductionLineGraph): the graph as a ProductionLineGraph object.
    """
    if not isinstance(g, nx.DiGraph):
        raise ValueError("Graph must be of type nx.DiGraph")
    pline = ProductionLineGraph()
    if cell_mapper:
        for cellname, cols in cell_mapper.items():
            pline.new_cell(name=cellname)
            cell_graph = nx.induced_subgraph(g, cols)
            pline.cells[cellname].input_cellgraph_directly(cell_graph, allow_in_edges=True)
    relabel_dict = {}
    for cellname, cols in cell_mapper.items():
        for col in cols:
            relabel_dict[col] = cellname + "_" + col

    g_rename = nx.relabel_nodes(g, relabel_dict)
    between_cell_edges = nx.difference(g_rename, pline.graph).edges()
    pline.connect_across_cells_manually(edges=between_cell_edges)
    return pline
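
Example (illustrative): wrapping a hand-made nx.DiGraph; node and cell names are made up. After conversion, nodes are addressed by their cell-prefixed names, matching the relabeling step above.

import networkx as nx
from causalAssembly.models_dag import ProductionLineGraph

g = nx.DiGraph([("a", "b"), ("b", "c"), ("a", "c")])
cell_mapper = {"Cell0": ["a", "b"], "Cell1": ["c"]}

pline = ProductionLineGraph.from_nx(g, cell_mapper=cell_mapper)
print(pline.nodes)         # e.g. ['Cell0_a', 'Cell0_b', 'Cell1_c']
print(pline.ground_truth)  # pandas adjacency matrix of the full line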

get_data() classmethod

Load in semi-synthetic data as described in the paper.

causalAssembly: Generating Realistic Production Data for Benchmarking Causal Discovery.

Returns: pd.DataFrame: data set from which new (semi-synthetic) samples can be generated.

Source code in causalAssembly/models_dag.py
@classmethod
def get_data(cls) -> pd.DataFrame:
    """Load in semi-synthetic data as described in the paper.

    causalAssembly: Generating Realistic Production Data for
    Benchmarking Causal Discovery
    Returns:
        pd.DataFrame: Data from which data should be generated.
    """
    return pd.read_csv(DATA_DATASET)
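
Example (illustrative; DATA_DATASET typically points to a remote file, so network access may be required):

from causalAssembly.models_dag import ProductionLineGraph

df = ProductionLineGraph.get_data()
print(df.shape)
print(df.columns[:5])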

get_ground_truth() classmethod

Loads in the ground_truth as described in the paper.

causalAssembly: Generating Realistic Production Data for Benchmarking Causal Discovery.

Returns: ProductionLineGraph: ground_truth for cells and line.

Source code in causalAssembly/models_dag.py
@classmethod
def get_ground_truth(cls) -> ProductionLineGraph:
    """Loads in the ground_truth as described in the paper.

    causalAssembly: Generating Realistic Production Data for
    Benchmarking Causal Discovery
    Returns:
        ProductionLineGraph: ground_truth for cells and line.
    """
    gt_response = requests.get(DATA_GROUND_TRUTH, timeout=5)
    ground_truth = json.loads(gt_response.text)

    assembly_line = json_graph.adjacency_graph(ground_truth)

    stations = ["Station1", "Station2", "Station3", "Station4", "Station5"]
    ground_truth_line = ProductionLineGraph()

    for station in stations:
        ground_truth_line.new_cell(station)
        station_nodes = [node for node in assembly_line.nodes if node.startswith(station)]
        station_subgraph = nx.subgraph(assembly_line, station_nodes)
        # make sure its sorted
        sorted_graph = nx.DiGraph()
        sorted_graph.add_nodes_from(
            sorted(station_nodes, key=lambda x: int(x.rpartition("_")[2]))
        )
        sorted_graph.add_edges_from(station_subgraph.edges)
        ground_truth_line.cells[station].graph = sorted_graph

    between_cell_edges = nx.difference(assembly_line, ground_truth_line.graph).edges()
    ground_truth_line.connect_across_cells_manually(edges=between_cell_edges)
    return ground_truth_line
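
Example (illustrative; the ground truth is fetched from the URL behind DATA_GROUND_TRUTH, so network access is required):

from causalAssembly.models_dag import ProductionLineGraph

line = ProductionLineGraph.get_ground_truth()
print(line.cell_order)           # ['Station1', ..., 'Station5']
print(line.num_nodes, line.num_edges)
print(line.ground_truth.head())  # adjacency matrix as a pandas DataFrame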

get_nodes_of_station(station_name)

Returns nodes in chosen Station.

Parameters:

Name Type Description Default
station_name str

name of station.

required

Raises:

Type Description
AssertionError

if station name doesn't match pline.

Returns:

Name Type Description
list list

nodes in chosen station

Source code in causalAssembly/models_dag.py
def get_nodes_of_station(self, station_name: str) -> list:
    """Returns nodes in chosen Station.

    Args:
        station_name (str): name of station.

    Raises:
        AssertionError: if station name doesn't match pline.

    Returns:
        list: nodes in chosen station
    """
    if station_name not in self.cell_order:
        raise AssertionError("Station name not among cells.")

    return self.cells[station_name].nodes

hidden_nodes()

Returns list of nodes marked as hidden.

Returns:

Name Type Description
list list

of hidden nodes

Source code in causalAssembly/models_dag.py
def hidden_nodes(self) -> list:
    """Returns list of nodes marked as hidden.

    Returns:
        list: of hidden nodes
    """
    return [
        node
        for node, hidden in nx.get_node_attributes(self.graph, NodeAttributes.HIDDEN).items()
        if hidden is True
    ]

intervene_on(nodes_values)

Specify hard or soft intervention.

If you want to intervene upon more than one node, provide one key-value pair per node, i.e. the nodes to intervene on and the corresponding values to set them to (see example). The mutilated DAG will automatically be stored in mutilated_dags.

Parameters:

Name Type Description Default
nodes_values dict[str, RandomSymbol | float]

either single real number or sympy.stats.RandomSymbol. If you like to intervene on more than one node, just provide more key-value pairs.

required

Raises:

Type Description
AssertionError

If node(s) are not in the graph

Source code in causalAssembly/models_dag.py
def intervene_on(self, nodes_values: dict[str, RandomSymbol | float]):
    """Specify hard or soft intervention.

    If you want to intervene
    upon more than one node provide a list of nodes to intervene on
    and a list of corresponding values to set these nodes to.
    (see example). The mutilated dag will automatically be
    stored in `mutilated_dags`.

    Args:
        nodes_values (dict[str, RandomSymbol | float]): either single real
            number or sympy.stats.RandomSymbol. If you like to intervene on
            more than one node, just provide more key-value pairs.

    Raises:
        AssertionError: If node(s) are not in the graph
    """
    if not self.drf:
        raise AssertionError("You need to train a drf first.")
    drf_replace = {}

    if not set(nodes_values.keys()).issubset(set(self.nodes)):
        raise AssertionError(
            "One or more nodes you want to intervene upon are not in the graph."
        )

    mutilated_dag = self.graph.copy()

    for node, value in nodes_values.items():
        old_incoming = self.parents(of_node=node)
        edges_to_remove = [(old, node) for old in old_incoming]
        mutilated_dag.remove_edges_from(edges_to_remove)
        drf_replace[node] = value

    self.mutilated_dags[f"do({list(nodes_values.keys())})"] = (
        mutilated_dag  # specifying the same set twice will override
    )

    self.interventional_drf[f"do({list(nodes_values.keys())})"] = drf_replace
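
Example (hedged sketch): intervene_on requires a trained DRF; here one is assumed to have been saved earlier with save_drf and is re-attached via load_drf. File and node names are placeholders; use names from line.nodes.

from sympy.stats import Normal
from causalAssembly.models_dag import ProductionLineGraph

line = ProductionLineGraph.get_ground_truth()
line.drf = ProductionLineGraph.load_drf("drf.pkl")  # placeholder file written earlier via save_drf

# hard intervention: clamp a node to 0.5; soft intervention: assign a RandomSymbol
line.intervene_on({"Station1_placeholder_node": 0.5})
line.intervene_on({"Station2_placeholder_node": Normal("noise", 0, 1)})

print(line.interventions)           # e.g. ["do(['Station1_placeholder_node'])", ...]
amat = line.interventional_amat(0)  # adjacency matrix of the first mutilated DAG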

interventional_amat(which_intervention)

Returns the adjacency matrix of a chosen mutilated DAG.

Parameters:

Name Type Description Default
which_intervention int | str

Integer count of your chosen intervention or literal string.

required

Raises:

Type Description
ValueError

"The intervention you provide does not exist."

Returns:

Type Description
DataFrame

pd.DataFrame: Adjacency matrix.

Source code in causalAssembly/models_dag.py
def interventional_amat(self, which_intervention: int | str) -> pd.DataFrame:
    """Returns the adjacency matrix of a chosen mutilated DAG.

    Args:
        which_intervention (int | str): Integer count of your chosen intervention or
            literal string.

    Raises:
        ValueError: "The intervention you provide does not exist."

    Returns:
        pd.DataFrame: Adjacency matrix.
    """
    if isinstance(which_intervention, str) and which_intervention not in self.interventions:
        raise ValueError("The intervention you provide does not exist.")

    if isinstance(which_intervention, int) and which_intervention > len(self.interventions):
        raise ValueError("The intervention you index does not exist.")

    if isinstance(which_intervention, int):
        which_intervention = self.interventions[which_intervention]

    mutilated_dag = self.mutilated_dags[which_intervention].copy()
    return nx.to_pandas_adjacency(mutilated_dag, weight=None)

load_drf(filename, location=None) classmethod

Loads a drf dict from a .pkl file into the workspace.

Parameters:

Name Type Description Default
filename str

name of the file e.g. examplefile.pkl

required
location str

path to file in case it's not located in the current working directory. Defaults to None.

None

Returns:

Name Type Description
DRF dict

dict of trained drf objects

Source code in causalAssembly/models_dag.py
@classmethod
def load_drf(cls, filename: str, location: str | Path | None = None):
    """Loads a drf dict from a .pkl file into the workspace.

    Args:
        filename (str): name of the file e.g. examplefile.pkl
        location (str, optional): path to file in case it's not located
            in the current working directory. Defaults to None.

    Returns:
        DRF (dict): dict of trained drf objects
    """
    if not location:
        location = Path().resolve()

    location_path = Path(location, filename)

    with open(location_path, "rb") as drf:
        pickle_drf = pickle.load(drf)

    return pickle_drf

load_pline_from_pickle(filename, location=None) classmethod

Load production line graph from a pickle file.

Parameters:

Name Type Description Default
filename str

description

required
location str | Path | None

description. Defaults to None.

None

Raises:

Type Description
TypeError

description

Returns:

Name Type Description
_type_

description

Source code in causalAssembly/models_dag.py
@classmethod
def load_pline_from_pickle(cls, filename: str, location: str | Path | None = None):
    """Load production line graph from a pickle file.

    Args:
        filename (str): _description_
        location (str | Path | None, optional): _description_. Defaults to None.

    Raises:
        TypeError: _description_

    Returns:
        _type_: _description_
    """
    if not location:
        location = Path().resolve()

    location_path = Path(location, filename)

    with open(location_path, "rb") as pline:
        pickle_line = pickle.load(pline)

    if not isinstance(pickle_line, ProductionLineGraph):
        raise TypeError("You didn't refer to a ProductionLineGraph.")

    return pickle_line

new_cell(name=None, is_eol=False)

Add a new cell to the production line.

If no name is given, the cell is named automatically from the cell prefix and the current number of cells (e.g. C0 for the first cell).

Parameters:

Name Type Description Default
name str

Defaults to None.

None
is_eol bool

Whether cell is end-of-line. Defaults to False.

False

Returns:

Type Description
ProcessCell

ProcessCell

Source code in causalAssembly/models_dag.py
def new_cell(self, name: str | None = None, is_eol: bool = False) -> ProcessCell:
    """Add a new cell to the production line.

    If no name is given, the cell is named automatically from the cell prefix and the current number of cells (e.g. C0 for the first cell)

    Args:
        name (str, optional): Defaults to None.
        is_eol (bool, optional): Whether cell is end-of-line. Defaults to False.

    Returns:
        ProcessCell
    """
    if name:
        c = ProcessCell(name=name)

    else:
        actual_no_of_cells = len(self.cells.values())
        c = ProcessCell(name=f"{self.cell_prefix}{actual_no_of_cells}")

    c.random_state = self.random_state  # type: ignore

    c.is_eol = is_eol
    self.__add_cell(cell=c)
    self.cell_order.append(c.name)
    return c
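
Example (illustrative):

from causalAssembly.models_dag import ProductionLineGraph

line = ProductionLineGraph()
line.new_cell("Pre")               # named cell
line.new_cell()                    # auto-named from the prefix and cell count: 'C1'
line.new_cell("EOL", is_eol=True)  # marks the end-of-line cell

print(line.cell_order)  # ['Pre', 'C1', 'EOL']
print(line.eol_cell)    # the 'EOL' ProcessCell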

parents(of_node)

Return parents of node in question.

Parameters:

Name Type Description Default
of_node str

Node in question.

required

Returns:

Type Description
list[str]

list[str]: parent set.

Source code in causalAssembly/models_dag.py
def parents(self, of_node: str) -> list[str]:
    """Return parents of node in question.

    Args:
        of_node (str): Node in question.

    Returns:
        list[str]: parent set.
    """
    return list(self.graph.predecessors(of_node))

sample_from_drf(size=10, smoothed=True)

Draw from the trained DRF.

Parameters:

Name Type Description Default
size int

Number of samples to be drawn. Defaults to 10.

10
smoothed bool

If set to true, marginal distributions will be sampled from smoothed bootstraps. Defaults to True.

True

Returns:

Type Description
DataFrame

pd.DataFrame: Data frame that follows the distribution implied by the ground truth.

Source code in causalAssembly/models_dag.py
def sample_from_drf(self, size=10, smoothed: bool = True) -> pd.DataFrame:
    """Draw from the trained DRF.

    Args:
        size (int, optional): Number of samples to be drawn. Defaults to 10.
        smoothed (bool, optional): If set to true, marginal distributions will
            be sampled from smoothed bootstraps. Defaults to True.

    Returns:
        pd.DataFrame: Data frame that follows the distribution implied by the ground truth.
    """
    return _sample_from_drf(prod_object=self, size=size, smoothed=smoothed)
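
Example (hedged sketch, assuming a DRF dictionary trained for this graph was saved earlier; the file name is a placeholder):

from causalAssembly.models_dag import ProductionLineGraph

line = ProductionLineGraph.get_ground_truth()
line.drf = ProductionLineGraph.load_drf("drf.pkl")

df = line.sample_from_drf(size=500)                      # smoothed bootstrap by default
df_raw = line.sample_from_drf(size=500, smoothed=False)
print(df.shape)                                          # expected: 500 rows, one column per node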

sample_from_interventional_drf(which_intervention=0, size=10, smoothed=True)

Draw from the trained and intervened upon DRF.

Parameters:

Name Type Description Default
size int

Number of samples to be drawn. Defaults to 10.

10
which_intervention str | int

Which intervention to choose from. Both the literal name (see the property interventions) and the index are possible. Defaults to the first intervention.

0
smoothed bool

If set to true, marginal distributions will be sampled from smoothed bootstraps. Defaults to True.

True

Returns:

Type Description
DataFrame

pd.DataFrame: Data frame that follows the interventional distribution implied by the ground truth.

Source code in causalAssembly/models_dag.py
def sample_from_interventional_drf(
    self, which_intervention: str | int = 0, size=10, smoothed: bool = True
) -> pd.DataFrame:
    """Draw from the trained and intervened upon DRF.

    Args:
        size (int, optional): Number of samples to be drawn. Defaults to 10.
        which_intervention (str | int): Which intervention to choose from.
            Both the literal name (see the property `interventions`) and the index
            are possible. Defaults to the first intervention.
        smoothed (bool, optional): If set to true, marginal distributions will
            be sampled from smoothed bootstraps. Defaults to True.

    Returns:
        pd.DataFrame: Data frame that follows the interventional distribution
            implied by the ground truth.
    """
    return _interventional_sample_from_drf(
        prod_object=self, which_intervention=which_intervention, size=size, smoothed=smoothed
    )

save_drf(filename, location=None)

Writes a drf dict to file. Please provide the .pkl suffix!

Parameters:

Name Type Description Default
filename str

name of the file to be written e.g. examplefile.pkl

required
location str

path to file in case it's not located in the current working directory. Defaults to None.

None
Source code in causalAssembly/models_dag.py
def save_drf(self, filename: str, location: str | Path | None = None):
    """Writes a drf dict to file. Please provide the .pkl suffix!

    Args:
        filename (str): name of the file to be written e.g. examplefile.pkl
        location (str, optional): path to file in case it's not located in
            the current working directory. Defaults to None.
    """
    if not location:
        location = Path().resolve()

    location_path = Path(location, filename)

    with open(location_path, "wb") as f:
        pickle.dump(self.drf, f)
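
Example (sketch; the file name is a placeholder and line.drf must already hold trained DRF objects):

from causalAssembly.models_dag import ProductionLineGraph

line = ProductionLineGraph.get_ground_truth()
# ... train/attach DRFs so that line.drf is populated ...

line.save_drf("my_drfs.pkl")                     # writes line.drf to the current working directory
restored = ProductionLineGraph.load_drf("my_drfs.pkl")
line.drf = restored                              # reuse in a later session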

show(meta_description=None, fig_size=(15, 8))

Plot full assembly line.

Parameters:

Name Type Description Default
meta_description list | None

Specify additional cell info. Defaults to None.

None
fig_size tuple

Adjust depending on number of cells. Defaults to (15, 8).

(15, 8)

Raises:

Type Description
AssertionError

Meta list entry needs to exist for each cell!

Source code in causalAssembly/models_dag.py
def show(self, meta_description: list | None = None, fig_size: tuple = (15, 8)):
    """Plot full assembly line.

    Args:
        meta_description (list | None, optional): Specify additional cell info.
            Defaults to None.
        fig_size (tuple, optional): Adjust depending on number of cells.
            Defaults to (15, 8).

    Raises:
        AssertionError: Meta list entry needs to exist for each cell!
    """
    _, ax = plt.subplots(figsize=fig_size)

    pos = {}

    if meta_description is None:
        meta_description = ["" for _ in range(len(self.cells))]

    if len(meta_description) != len(self.cells):
        raise AssertionError("Meta list entry needs to exist for each cell!")

    max_in_degree = max([d for _, d in self.graph.in_degree()])
    max_out_degree = max([d for _, d in self.graph.out_degree()])

    for i, (station_name, meta_desc) in enumerate(zip(self.cell_order, meta_description)):
        pos.update(
            self.cells[station_name]._plot_cellgraph(
                ax=ax,
                with_edges=False,
                with_box=True,
                meta_desc=meta_desc,
                center=np.array([8 * i, 0]),
                node_color=[
                    (d + 10) / (max_in_degree + 10)
                    for _, d in self.graph.in_degree(self.get_nodes_of_station(station_name))
                ],
                node_size=[
                    500 * (d + 1) / (max_out_degree + 1)
                    for _, d in self.graph.out_degree(self.get_nodes_of_station(station_name))
                ],
            )
        )

    nx.draw_networkx_edges(
        self.graph,
        pos=pos,
        ax=ax,
        alpha=0.2,
        arrowsize=8,
        width=0.5,
        connectionstyle="arc3,rad=0.3",
    )
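
Example (illustrative; the meta descriptions are made up and must contain one entry per cell):

import matplotlib.pyplot as plt
from causalAssembly.models_dag import ProductionLineGraph

line = ProductionLineGraph.get_ground_truth()
line.show(
    meta_description=["body", "flange", "screws", "gasket", "final check"],
    fig_size=(15, 8),
)
plt.show()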

to_cpdag()

Convert to CPDAG.

Returns:

Name Type Description
PDAG PDAG

CPDAG (Markov equivalence class) of the ground-truth DAG.

Source code in causalAssembly/models_dag.py
def to_cpdag(self) -> PDAG:
    """Convert to CPDAG.

    Returns:
        PDAG: CPDAG (Markov equivalence class) of the ground-truth DAG.
    """
    return dag2cpdag(dag=self.graph)

via_cell_number(n_cells, cell_prefix='C') classmethod

Inits a ProductionLineGraph with predefined number of cells, e.g. n_cells = 3.

Will create empty C0, C1 and C2 as cells if no other cell_prefix is given.

Parameters:

Name Type Description Default
n_cells int

Number of cells the graph will have

required
cell_prefix str

If you like other cell names pass them here. Defaults to "C".

'C'
Source code in causalAssembly/models_dag.py
@classmethod
def via_cell_number(cls, n_cells: int, cell_prefix: str = "C"):
    """Inits a ProductionLineGraph with predefined number of cells, e.g. n_cells = 3.

    Will create empty  C0, C1 and C2 as cells if no other cell_prefix is given.

    Args:
        n_cells (int): Number of cells the graph will have
        cell_prefix (str, optional): If you like other cell names pass them here.
            Defaults to "C".

    """
    pl = cls()
    pl.cell_prefix = cell_prefix

    [pl.new_cell() for _ in range(n_cells)]

    return pl
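
Example (illustrative):

from causalAssembly.models_dag import ProductionLineGraph

pl = ProductionLineGraph.via_cell_number(n_cells=3)
print(pl.cell_order)  # ['C0', 'C1', 'C2']

pl_named = ProductionLineGraph.via_cell_number(n_cells=2, cell_prefix="Station")
print(pl_named.cell_order)  # ['Station0', 'Station1']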

visible_nodes()

All visible nodes in the graph.

Returns:

Name Type Description
list

list: all nodes not marked as hidden.

Source code in causalAssembly/models_dag.py
def visible_nodes(self):
    """All visible nodes in the graph.

    Returns:
        list: all nodes not marked as hidden.
    """
    return [node for node in self.nodes if node not in self.hidden_nodes()]

choose_edges_from_cells_randomly(from_cell, to_cell, probability, rng)

Choose edges between two cells randomly.

From two given cells (graphs), we take the cartesian product, ending up with from_cell.number_of_nodes x to_cell.number_of_nodes possible edges (node tuples).

From this product we draw probability x (size of the cartesian product) edges randomly.

In case this is not an integer, we ceil the value, e.g. 17.3 edges will lead to 18 edges drawn.

Parameters:

Name Type Description Default
from_cell ProcessCell

ProcessCell from where we want the edges

required
to_cell ProcessCell

ProcessCell to where we want the edges

required
probability float

between 0 and 1

required
rng Generator

Random number generator to use.

required

Returns:

Type Description
list[tuple[str, str]]

list[tuple[str, str]]: Chosen edges.

Source code in causalAssembly/models_dag.py
def choose_edges_from_cells_randomly(
    from_cell: ProcessCell,
    to_cell: ProcessCell,
    probability: float,
    rng: np.random.Generator,
) -> list[tuple[str, str]]:
    """Choose cells randomly.

    From two given cells (graphs), we take the cartesian product (end up with
    from_cell.number_of_nodes x to_cell.number_of_nodes possible edges (node tuples).

    From this product we draw probability x cartesian product number of edges randomly.

    In case we have a float number, we ceil the value,
    e.g. 17.3 edges will lead to 18 edges drawn.

    Args:
        from_cell: ProcessCell from where we want the edges
        to_cell: ProcessCell to where we want the edges
        probability: between 0 and 1
        rng (np.random.Generator): Random number generator to use.

    Returns:
        list[tuple[str, str]]: Chosen edges.
    """
    ONE = 1.0
    assert 0 <= probability <= ONE

    arrow_tail_candidates = list(from_cell.graph.nodes)
    arrow_head_candidates = get_arrow_head_candidates_from_graph(graph=to_cell.graph)

    potential_edges = tuples_from_cartesian_product(
        l1=arrow_tail_candidates, l2=arrow_head_candidates
    )

    num_to_choose = int(np.ceil(probability * len(potential_edges)))

    chosen_edges = [
        potential_edges[i]
        for i in rng.choice(a=len(potential_edges), size=num_to_choose, replace=False)
    ]

    return chosen_edges
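
Example (illustrative sketch, reusing the small two-cell line from the connect_cells example above):

import networkx as nx
import numpy as np
from causalAssembly.models_dag import ProductionLineGraph, choose_edges_from_cells_randomly

line = ProductionLineGraph()
for name, edge in [("C0", ("a", "b")), ("C1", ("c", "d"))]:
    line.new_cell(name).input_cellgraph_directly(nx.DiGraph([edge]), allow_in_edges=True)

edges = choose_edges_from_cells_randomly(
    from_cell=line.cells["C0"],
    to_cell=line.cells["C1"],
    probability=0.5,  # ceil(0.5 * 4 candidate pairs) = 2 edges
    rng=np.random.default_rng(seed=2023),
)
print(edges)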

get_arrow_head_candidates_from_graph(graph, node_attributes_to_filter=NodeAttributes.ALLOW_IN_EDGES)

Returns all arrow-head nodes (nodes an arrow may point to) as a list of candidates, from which a list of potential edge tuples can later be built.

Parameters:

Name Type Description Default
graph DiGraph

DAG

required
node_attributes_to_filter str

see NodeAttributes. Defaults to NodeAttributes.ALLOW_IN_EDGES.

ALLOW_IN_EDGES

Returns:

Type Description
list[str]

list[str]: list of nodes

Source code in causalAssembly/models_dag.py
def get_arrow_head_candidates_from_graph(
    graph: nx.DiGraph, node_attributes_to_filter: str = NodeAttributes.ALLOW_IN_EDGES
) -> list[str]:
    """Returns all arrow head (nodes where an arrow points to) nodes as list of candidates.

    To later build a list of tuples of potential edges.

    Args:
        graph (nx.DiGraph): DAG
        node_attributes_to_filter (str, optional): see NodeAttributes.
            Defaults to NodeAttributes.ALLOW_IN_EDGES.

    Returns:
        list[str]: list of nodes
    """
    arrow_head_candidates = [
        node
        for node, allowed in nx.get_node_attributes(graph, node_attributes_to_filter).items()
        if allowed is True
    ]

    nodes_without_attributes = list(
        set(graph.nodes).difference(
            set(nx.get_node_attributes(graph, node_attributes_to_filter).keys())
        )
    )

    if len(arrow_head_candidates) == 0 and len(nodes_without_attributes) == 0:
        logger.warning(
            f"None of the nodes in cell {graph} \
            are allowed to have in-edges."
        )

    arrow_head_candidates.extend(nodes_without_attributes)

    return arrow_head_candidates
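
Example (illustrative; node names are made up):

import networkx as nx
from causalAssembly.models_dag import get_arrow_head_candidates_from_graph

g = nx.DiGraph()
g.add_node("x", allow_in_edges=False)  # explicitly excluded
g.add_node("y", allow_in_edges=True)   # explicitly allowed
g.add_node("z")                        # no attribute -> treated as allowed

print(get_arrow_head_candidates_from_graph(g))  # ['y', 'z']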

Utility classes and functions related to causalAssembly.

Copyright (c) 2023 Robert Bosch GmbH

This program is free software: you can redistribute it and/or modify it under the terms of the GNU Affero General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later version. This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Affero General Public License for more details. You should have received a copy of the GNU Affero General Public License along with this program. If not, see https://www.gnu.org/licenses/.

merge_dags(dag_to_insert, target_dag, mapping, remove_in_edges_in_target_dag=False)

dag_to_insert will be connected to target_dag via the mapping dict.

Parameters:

Name Type Description Default
dag_to_insert DiGraph

DAG to insert.

required
target_dag DiGraph

DAG on which to map.

required
mapping dict

Mapping from insert to target dag e.g. {C1: D1, C5: D4} where node C1 from insert dag will be mapped to node D1 of target dag.

required
remove_in_edges_in_target_dag bool

Defaults to False.

False

Raises:

Type Description
ValueError

node does not exist in target_dag

ValueError

node does not exist in dag_to_insert

Returns:

nx.DiGraph: merged DAG

Source code in causalAssembly/dag_utils.py
def merge_dags(
    dag_to_insert: nx.DiGraph,
    target_dag: nx.DiGraph,
    mapping: dict,
    remove_in_edges_in_target_dag: bool = False,
) -> nx.DiGraph:
    """Dag_to_insert will be connected to target_tag via mapping dict.

    Args:
        dag_to_insert (nx.DiGraph): DAG to insert.
        target_dag (nx.DiGraph): DAG on which to map.
        mapping (dict): Mapping from insert to target dag e.g.
            {C1: D1, C5: D4} where node C1 from insert dag will
            be mapped to node D1 of target dag.
        remove_in_edges_in_target_dag (bool, optional): Defaults to False.

    Raises:
        ValueError: node does not exist in target_dag
        ValueError: node does not exist in dag_to_insert
    Returns:
        nx.DiGraph: merged DAG
    """
    for old_node_name, new_node_name in mapping.items():
        if new_node_name not in target_dag.nodes():
            raise ValueError(f"{new_node_name} does not exist in target_dag")
        if old_node_name not in dag_to_insert.nodes():
            raise ValueError(f"{old_node_name} does not exist in dag_to_insert")

    no_of_nodes_insert_dag = len(dag_to_insert.nodes())
    no_of_nodes_target_dag = len(target_dag.nodes())
    if no_of_nodes_insert_dag > no_of_nodes_target_dag:
        logger.warning(
            f"you are trying to merge a DAG of size={no_of_nodes_insert_dag} "
            f"into a smaller DAG of size={no_of_nodes_target_dag}. "
            f"If this is not intentional consider changing the order"
        )

    # remove all in edges on target node
    if remove_in_edges_in_target_dag:
        for node in mapping.values():
            e = target_dag.in_edges(node)
            target_dag.remove_edges_from(list(e))

    # rename nodes to glue together
    dag = nx.relabel_nodes(dag_to_insert, mapping)

    return nx.compose(dag, target_dag)
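
Example (illustrative; node names are arbitrary):

import networkx as nx
from causalAssembly.dag_utils import merge_dags

insert_dag = nx.DiGraph([("C1", "C2"), ("C2", "C5")])
target_dag = nx.DiGraph([("D1", "D2"), ("D3", "D4")])

merged = merge_dags(
    dag_to_insert=insert_dag,
    target_dag=target_dag,
    mapping={"C1": "D1", "C5": "D4"},    # C1 is glued onto D1, C5 onto D4
    remove_in_edges_in_target_dag=True,  # D4 loses its edge from D3
)
print(sorted(merged.edges))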

merge_dags_via_edges(left_dag, right_dag, edges=None, isolate_target_nodes=False)

Merges two dags via a list of edges.

Parameters:

Name Type Description Default
left_dag DiGraph

dag to merge to right_dag

required
right_dag DiGraph

dag to merge left_dag into

required
edges list[tuple]

list of edges that connect the two dags. Defaults to None.

None
isolate_target_nodes bool

if True, all incoming edges within the right_dag into the target nodes are removed, so that these nodes are influenced only by the left_dag via the given edges list. Defaults to False.

False

Raises:

Type Description
ValueError

source or target nodes are not available in left dag

ValueError

source or target nodes are not available in right dag

Source code in causalAssembly/dag_utils.py
def merge_dags_via_edges(
    left_dag: nx.DiGraph,
    right_dag: nx.DiGraph,
    edges: list[tuple] | None = None,
    isolate_target_nodes: bool = False,
):
    """Merges two dags via a list of edges.

    Args:
        left_dag (nx.DiGraph): dag to merge to right_dag
        right_dag (nx.DiGraph):  dag to merge left_dag into
        edges (list[tuple], optional): list of edges that connect the two dags.
            Defaults to None.
        isolate_target_nodes (bool, optional): if True, all incoming edges
            within the right_dag into the target nodes are removed, so that
            these nodes are influenced only by the left_dag via the given
            edges list. Defaults to False.

    Raises:
        ValueError: source or target nodes are not available in left dag
        ValueError: source or target nodes are not available in right dag

    """
    if not edges:
        edges = list()
    source_nodes = set([t[0] for t in edges])
    target_nodes = set([t[1] for t in edges])

    if not source_nodes.issubset(set(left_dag.nodes)):
        raise ValueError(
            f"At least one of the source nodes: {source_nodes} "
            f"cannot be found in the left DAGs nodes: {left_dag.nodes}"
        )

    if not target_nodes.issubset(set(right_dag.nodes)):
        raise ValueError(
            f"At least one of the target nodes: {target_nodes} "
            f"cannot be found in right DAGs nodes: {right_dag.nodes}"
        )

    if isolate_target_nodes:
        for node in target_nodes:
            # cast to list to work with value not with reference
            edges_to_target_node = list(right_dag.in_edges(node))
            right_dag.remove_edges_from(edges_to_target_node)

    merged_dag = nx.compose(left_dag, right_dag)

    merged_dag.add_edges_from(edges, **{"connector": True})

    return merged_dag
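
A minimal sketch of connecting two DAGs via an explicit edge list (node names are illustrative; import path as above):

```python
import networkx as nx

from causalAssembly.dag_utils import merge_dags_via_edges

left_dag = nx.DiGraph([("a", "b")])
right_dag = nx.DiGraph([("c", "d")])

# Connect the two DAGs by the edge b -> c; merge_dags_via_edges marks
# connecting edges with the attribute connector=True.
merged = merge_dags_via_edges(left_dag, right_dag, edges=[("b", "c")])

print(merged.edges(data="connector"))
# [('a', 'b', None), ('b', 'c', True), ('c', 'd', None)]
```

Setting `isolate_target_nodes=True` would additionally strip any edges inside `right_dag` that point into the target nodes (here `c`), so that they are driven only by `left_dag`.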

tuples_from_cartesian_product(l1, l2)

Given two lists l1 and l2 this creates the cartesian product and returns all tuples.

Parameters:

Name Type Description Default
l1 list

First list of nodes

required
l2 list

Second list of nodes (typically the edge heads).

required

Returns:

Type Description
list[tuple]

list[tuple]: list of (tail, head) tuples, typically used as edges.

Examples:

    >>> l1 = [0, 1, 2]
    >>> l2 = ['a', 'b', 'c']
    >>> tuples_from_cartesian_product(l1, l2)
    [(0, 'a'), (0, 'b'), (0, 'c'), (1, 'a'), (1, 'b'), (1, 'c'), (2, 'a'), (2, 'b'), (2, 'c')]
Source code in causalAssembly/dag_utils.py
def tuples_from_cartesian_product(l1: list, l2: list) -> list[tuple]:
    """Given two lists l1 and l2 this creates the cartesian product and returns all tuples.

    Args:
        l1 (list): First list of nodes
        l2 (list): Second list of nodes (typically the edge heads)

    Returns:
        list[tuple]: list of (tail, head) tuples, typically used as edges

    Examples::

        l1 = [0,1,2]
        l2 = ['a','b','c']
        >>> tuples_from_cartesian_product(l1,l2)
        [(0,'a'), (0,'b'), (0,'c'), (1,'a'), (1,'b'), (1,'c'), (2,'a'), (2,'b'), (2,'c')]

    """
    return [
        (tail, head)
        for tail, head in itertools.product(
            l1,
            l2,
        )
    ]
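
This helper pairs naturally with `merge_dags_via_edges` when every node of one set should point to every node of another. A short sketch (import path as above; node names are illustrative):

```python
from causalAssembly.dag_utils import tuples_from_cartesian_product

# Fully connect two node sets: each element of the first list becomes
# a tail, each element of the second list a head.
edges = tuples_from_cartesian_product(["b1", "b2"], ["c1"])
print(edges)  # [('b1', 'c1'), ('b2', 'c1')]
```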

Utility classes and functions related to causalAssembly.

Copyright (c) 2023 Robert Bosch GmbH


FCM

Class to define, intervene and sample from an FCM.

Examples:

from sympy import Eq, symbols
from sympy.stats import Normal, Uniform, Gamma

x, y, z = symbols('x,y,z')

eq_x = Eq(x, Uniform("noise", left=-1, right=1))
eq_y = Eq(y, 2 * x ** 2 + Normal("error", 0, .5))
eq_z = Eq(z, 9 * y * x * Gamma("some_name", .5, .5))

eq_list = [eq_x, eq_y, eq_z]


self = FCM(name='test', seed=2023)
self.input_fcm(eq_list)
self.draw(size=10)
Source code in causalAssembly/models_fcm.py
class FCM:
    """Class to define, intervene and sample from an FCM.

    Examples:
        ```python
        from sympy import Eq, symbols
        from sympy.stats import Normal, Uniform, Gamma

        x, y, z = symbols('x,y,z')

        eq_x = Eq(x, Uniform("noise", left=-1, right=1))
        eq_y = Eq(y, 2 * x ** 2 + Normal("error", 0, .5))
        eq_z = Eq(z, 9 * y * x * Gamma("some_name", .5, .5))

        eq_list = [eq_x, eq_y, eq_z]


        self = FCM(name='test', seed=2023)
        self.input_fcm(eq_list)
        self.draw(size=10)
        ```

    """

    def __init__(self, name: str | None = None, seed: int = 2023):
        """Inits the FCM class.

        Args:
            name (str | None, optional): Name. Defaults to None.
            seed (int, optional): Seed. Defaults to 2023.
        """
        self.name = name
        self._random_state = np.random.default_rng(seed=seed)
        self.__init_dag()
        self.last_df: pd.DataFrame
        self.__init_mutilated_dag()

    def __init_dag(self):
        self.graph = nx.DiGraph(name=self.name)

    def __init_mutilated_dag(self):
        self.mutilated_dags = dict()

    @property
    def source_nodes(self) -> list:
        """Returns source nodes in the current DAG.

        Returns:
            list: List of source nodes.
        """
        return [node for node in self.nodes if len(self.parents(of_node=node)) == 0]

    @property
    def causal_order(self) -> list[Symbol]:
        """Returns the causal order of the current graph.

        Note that this order is in general not unique. To
        ensure uniqueness, we additionally sort lexicographically.

        Returns:
            list[Symbol]: Causal order
        """
        return list(nx.lexicographical_topological_sort(self.graph, key=lambda x: str(x)))

    @property
    def nodes(self) -> list[Symbol]:
        """Nodes in the graph.

        Returns:
            list[Symbol]
        """
        return list(self.graph.nodes())

    @property
    def edges(self) -> list[tuple]:
        """Edges in the graph.

        Returns:
            list[tuple]
        """
        return list(self.graph.edges())

    @property
    def num_nodes(self) -> int:
        """Number of nodes in the graph.

        Returns:
            int
        """
        return len(self.nodes)

    @property
    def num_edges(self) -> int:
        """Number of edges in the graph.

        Returns:
            int
        """
        return len(self.edges)

    @property
    def sparsity(self) -> float:
        """Sparsity of the graph.

        Returns:
            float: in [0,1]
        """
        s = self.num_nodes
        return self.num_edges / s / (s - 1) * 2

    @property
    def ground_truth(self) -> pd.DataFrame:
        """Returns the current ground truth as pandas adjacency.

        Returns:
            pd.DataFrame: Adjacency matrix.
        """
        return nx.to_pandas_adjacency(self.graph, weight=None)

    @property
    def interventions(self) -> list[str]:
        """Returns all interventions performed on the original graph.

        Returns:
            list: list of intervened upon nodes in do(x) notation.
        """
        return list(self.mutilated_dags.keys())

    def interventional_amat(self, which_intervention: int | str) -> pd.DataFrame:
        """Returns the adjacency matrix of a chosen mutilated DAG.

        Args:
            which_intervention (int | str): Integer count of your chosen intervention or
                literal string.

        Raises:
            ValueError: "The intervention you provide does not exist."

        Returns:
            pd.DataFrame: Adjacency matrix.
        """
        if isinstance(which_intervention, str) and which_intervention not in self.interventions:
            raise ValueError("The intervention you provide does not exist.")

        if isinstance(which_intervention, int) and which_intervention >= len(self.interventions):
            raise ValueError("The intervention you index does not exist.")

        if isinstance(which_intervention, int):
            which_intervention = self.interventions[which_intervention]

        mutilated_dag = self.mutilated_dags[which_intervention].copy()
        return nx.to_pandas_adjacency(mutilated_dag, weight=None)

    def parents(self, of_node: Symbol) -> list[Symbol]:
        """Return parents of node in question.

        Args:
            of_node (Symbol): Node in question.

        Returns:
            list[Symbol]: parent set.
        """
        return list(self.graph.predecessors(of_node))

    def parents_of(self, node: Symbol, which_graph: nx.DiGraph) -> list[Symbol]:
        """Return parents of node in question for a chosen DAG.

        Args:
            node (Symbol): node whose parents to return.
            which_graph (nx.DiGraph): which graph along the interventions.

        Returns:
            list[Symbol]: list of parents.
        """
        return list(which_graph.predecessors(node))

    def causal_order_of(self, which_graph: nx.DiGraph) -> list[Symbol]:
        """Returns the causal order of the chosen graph.

        Note that this order is in general not unique. To
        ensure uniqueness, we additionally sort lexicographically.

        Returns:
            list[Symbol]: Causal order
        """
        return list(nx.lexicographical_topological_sort(which_graph, key=lambda x: str(x)))

    def source_nodes_of(self, which_graph: nx.DiGraph) -> list:
        """Returns the source nodes of a chosen graph.

        This is mainly for
        choosing different mutilated DAGs.

        Args:
            which_graph (nx.DiGraph): DAG from which source nodes should be returned.

        Returns:
            list: List of nodes.
        """
        return [
            node
            for node in which_graph.nodes
            if len(self.parents_of(node=node, which_graph=which_graph)) == 0
        ]

    def input_fcm(self, fcm: list[Eq]):
        """Automatically builds up DAG according to the FCM fed in.

        Args:
            fcm (list): list of sympy equations generated as:
                    ```python
                    x,y = symbols('x,y')
                    term_x = Eq(x, Normal('x', 0,1))
                    term_y = Eq(y, 2*x**2*Normal('noise', 0,1))
                    fcm = [term_x, term_y]
                    ```
        """
        nodes_implied = [node.lhs.free_symbols.pop() for node in fcm]
        edges_implied = []
        for term in fcm:
            if not isinstance(term.rhs, RandomSymbol):
                if term.rhs.atoms(RandomSymbol):
                    edges_implied.extend(
                        [
                            (atom, term.lhs)
                            for atom in term.rhs.free_symbols
                            if str(atom) != str(term.rhs.atoms(RandomSymbol).pop())
                        ]
                    )
                else:
                    edges_implied.extend([(atom, term.lhs) for atom in term.rhs.atoms(Symbol)])

        g = self.graph
        g.add_nodes_from(nodes_implied)
        g.add_edges_from(edges_implied)

        term_dict = {}
        for term in fcm:
            term_dict[term.lhs] = {"term": term.rhs}

        nx.set_node_attributes(g, term_dict)

    def function_of(self, node: Symbol) -> dict:
        """Returns functional assignment for node in question.

        Args:
            node (Symbol): node corresponding to lhs.

        Returns:
            dict: key is node and value rhs of functional assignment.
        """
        if node not in self.graph.nodes:
            if node in [str(node) for node in self.nodes]:
                raise AssertionError(
                    "You probably defined a string. Node has to be a symbol, check out",
                    list(self.graph.nodes),
                )
            else:
                raise AssertionError("Node has to be in the graph")

        return {node: self.graph.nodes[node]["term"]}

    def display_functions(self) -> dict:
        """Displays all functional assignments inputted into the FCM.

        Returns:
            dict: Dict with keys equal to nodes and values equal to
                functional assignments.
        """
        fcm_dict = {}
        for node in self.causal_order:
            fcm_dict[node] = self.graph.nodes[node]["term"]

        return fcm_dict

    def sample(
        self,
        size: int,
        additive_gaussian_noise: bool = False,
        snr: None | float = 1 / 2,
        source_df: None | pd.DataFrame = None,
    ) -> pd.DataFrame:
        r"""Sample from joint.

        Draw samples from the joint distribution that factorizes
        according to the DAG implied by the FCM fed in.

        To avoid unexpected/unintended behavior, avoid defining fully
        deterministic equation systems.
        If parameters in noise terms are additive and left unevaluated,
        they're set according to a chosen Signal-To-Noise (SNR) ratio.
        For convenience, you can add additive Gaussian noise to each equation.
        This will never overwrite any of the chosen noise distributions.
        You may also feed in a data frame for noise distributions (see below
        for more details).

        Args:
            size (int): Number of samples to draw.
            additive_gaussian_noise (bool, optional): This will attach additive
                Gaussian noise to all terms without a RandomSymbol that are not
                source nodes. It acts merely as a convenience option. Variance
                will then be chosen according to SNR. Defaults to False.
            snr (None | float, optional): Signal-to-noise ratio
                \\( SNR =  \\frac{\\text{Var}(\\hat{X})}{\\hat\\sigma^2}. \\).
                Defaults to 1/2.
            source_df (None | pd.DataFrame, optional): Data frame containing source node data.
                The sample size must be at least as large as the number of samples
                you'd like to draw. Defaults to None.

        Raises:
            AttributeError: if source node parameters are not given explicitly.
            ValueError: if source node sample size is too small.
            ValueError: if scale parameters are left unevaluated for non-additive terms.

        Returns:
            pd.DataFrame:  Data frame with rows of length `size` and columns equal to the
                number of nodes in the graph.
        """
        return self._sample(
            size=size,
            additive_gaussian_noise=additive_gaussian_noise,
            snr=snr,
            source_df=source_df,
            which_graph=self.graph,
        )

    def interventional_sample(
        self,
        size: int,
        which_intervention: str | int = 0,
        additive_gaussian_noise: bool = False,
        snr: None | float = 1 / 2,
        source_df: None | pd.DataFrame = None,
    ) -> pd.DataFrame:
        r"""Draw samples from the interventional distribution.

        that factorizes according to the mutilated DAG after performing one or multiple
        interventions. Otherwise the method behaves similarly to sampling from the
        non-interventional joint distribution. By default samples are drawn from the
        first intervention you performed. If you intervened upon more than one node,
        you'll have to switch to another intervention to sample from the corresponding
        interventional distribution.

        Args:
            size (int): Number of samples to draw.
            which_intervention (str | int): Which interventional distribution to draw
                from. We recommend using integer counts starting from zero. But you can
                also provide the literal string here, e.g. if you intervened on say two
                nodes `x,y` then you would need to provide here: `do([x, y])`.
            additive_gaussian_noise (bool, optional): This will attach additive Gaussian noise
                to all terms without a RandomSymbol that are not source nodes. It acts merely as
                a convenience option. Variance will then be chosen according to SNR.
                Defaults to False.
            snr (None | float, optional): Signal-to-noise ratio
                \\( SNR =  \\frac{\\text{Var}(\\hat{X})}{\\hat\\sigma^2}. \\). Defaults to 1/2.
            source_df (None | pd.DataFrame, optional): Data frame containing source node data.
                The sample size must be at least as large as the number of samples
                you'd like to draw. Defaults to None.

        Raises:
            NotImplementedError: Raised when `which_intervention` is not of correct form.

        Returns:
            pd.DataFrame: Data frame with rows of length `size` and columns equal to the
                number of nodes in the graph.
        """
        if isinstance(which_intervention, str):
            int_choice = which_intervention
        elif isinstance(which_intervention, int):
            int_choice = self.interventions[which_intervention]
        else:
            raise NotImplementedError(
                f"which_intervention has to be \
                the literal string or an integer \
                starting at count zero indicating \
                which intervention in {self.interventions} \
                to use."
            )

        return self._sample(
            size=size,
            additive_gaussian_noise=additive_gaussian_noise,
            snr=snr,
            source_df=source_df,
            which_graph=self.mutilated_dags[int_choice],
        )

    def _sample(
        self,
        size: int,
        which_graph: nx.DiGraph,
        additive_gaussian_noise: bool = False,
        snr: None | float = 1 / 2,
        source_df: None | pd.DataFrame = None,
    ) -> pd.DataFrame:
        r"""Draw samples from the joint distribution.

        that factorizes
        according to the DAG implied by the FCM fed in. To avoid
        unexpected/unintended behavior, avoid defining fully
        deterministic equation systems.
        If parameters in noise terms are additive and left unevaluated,
        they're set according to a chosen Signal-To-Noise (SNR) ratio.
        For convenience, you can add additive Gaussian noise to each equation.
        This will never overwrite any of the chosen noise distributions.
        You may also feed in a data frame for noise distributions (see below
        for more details).

        Args:
            size (int): Number of samples to draw.
            which_graph (nx.DiGraph): Which graph to sample from.
            additive_gaussian_noise (bool, optional): Attach additive Gaussian noise to all
                terms without a RandomSymbol that are not source nodes; the variance is
                chosen according to SNR. Defaults to False.
            snr (None | float, optional): Signal-to-noise ratio
                \\( SNR =  \\frac{\\text{Var}(\\hat{X})}{\\hat\\sigma^2}. \\).
                Defaults to 1/2.
            source_df (None | pd.DataFrame, optional): Data frame containing source node data.
                The sample size must be at least as large as the number of samples
                you'd like to draw. Defaults to None.

        Raises:
            AttributeError: if source node parameters are not given explicitly.
            ValueError: if source node sample size is too small.
            ValueError: if scale parameters are left unevaluated for non-additive terms.

        Returns:
            pd.DataFrame:  Data frame with rows of length `size` and columns equal to the
                number of nodes in the graph.
        """
        if source_df is not None and not self.__source_df_condition(source_df):
            raise AssertionError("Names in source_df don't match nodenames in graph.")

        df = pd.DataFrame()
        for order in self.causal_order_of(which_graph=which_graph):
            if order in self.source_nodes_of(which_graph=which_graph):
                if source_df is not None and str(order) in source_df.columns:
                    if source_df[str(order)].shape[0] < size:
                        raise ValueError(
                            "Sample size of source node data must be at least \
                            as large as the number of samples you'd like to draw."
                        )
                    df[str(order)] = source_df[str(order)].sample(
                        n=size,
                        replace=False,
                        random_state=self._random_state,
                        ignore_index=True,
                    )

                elif isinstance(which_graph.nodes[order]["term"], RandomSymbol):
                    if not self.__distribution_parameters_explicit(order, which_graph=which_graph):
                        raise AttributeError("Source node parameters need to be given explicitly.")
                    df[str(order)] = sympy_sample(
                        which_graph.nodes[order]["term"],
                        seed=self._random_state,
                        size=size,
                    )

                elif isinstance(which_graph.nodes[order]["term"], Number):
                    df[str(order)] = np.repeat(which_graph.nodes[order]["term"], repeats=size)

                else:
                    raise NotImplementedError(
                        "Source nodes need to have a fully parameterized distribution, \
                        or need to be drawn from an appropriate data frame, or fixed to \
                        a single real number."
                    )
                continue

            fcm_expr = which_graph.nodes[order]["term"]

            if fcm_expr.atoms(RandomSymbol):
                if self.__distribution_parameters_explicit(order, which_graph=which_graph):
                    df[str(fcm_expr.atoms(RandomSymbol).pop())] = sympy_sample(
                        fcm_expr.atoms(RandomSymbol).pop(),
                        size=size,
                        seed=self._random_state,
                    )
                else:
                    df[str(fcm_expr.atoms(RandomSymbol).pop())] = np.zeros(size)

            df[str(order)] = self.__eval_expression(df=df, fcm_expr=fcm_expr)

            if fcm_expr.atoms(RandomSymbol) and not self.__distribution_parameters_explicit(
                order, which_graph=which_graph
            ):
                if not fcm_expr.is_Add:
                    raise ValueError(
                        "Noise term in "
                        + str(order)
                        + "="
                        + str(fcm_expr)
                        + " not additive. Scale parameter selection via SNR \
                        makes sense only for additive noise."
                    )
                logger.warning(
                    "I'll choose the noise scale in "
                    + str(order)
                    + "="
                    + str(fcm_expr)
                    + " according to the given SNR."
                )
                noise_var = df[str(order)].var() / snr  # type: ignore
                df[str(order)] = df[str(order)] + sympy_sample(
                    fcm_expr.atoms(RandomSymbol)
                    .pop()
                    .subs(self.__unfree_symbol(fcm_expr), np.sqrt(noise_var)),  # type: ignore
                    size=size,
                    seed=self._random_state,
                )

            if additive_gaussian_noise:
                if fcm_expr.atoms(RandomSymbol):
                    logger.warning(
                        "Noise already defined in "
                        + str(order)
                        + "="
                        + str(fcm_expr)
                        + ". I won't override this."
                    )
                else:
                    noise = symbols("noise")
                    noise_var = df[str(order)].var() / snr  # type: ignore
                    df[str(noise)] = self._random_state.normal(
                        loc=0,
                        scale=np.sqrt(noise_var),  # type: ignore
                        size=size,
                    )
                    fcm_expr = which_graph.nodes[order]["term"] + noise
                    df[str(order)] = self.__eval_expression(df=df, fcm_expr=fcm_expr)

        self.last_df = df[[str(order) for order in self.causal_order]]

        return self.last_df

    def __unfree_symbol(self, fcm_expr) -> set[Symbol]:
        random_symbs = fcm_expr.atoms(Symbol).difference(fcm_expr.free_symbols)
        return {
            unfree
            for unfree in random_symbs
            if str(unfree) != str(fcm_expr.atoms(RandomSymbol).pop())
        }.pop()

    def __eval_expression(self, df: pd.DataFrame, fcm_expr: Expr) -> pd.DataFrame:
        """Eval given fcm_expression with the values in given dataframe.

        Args:
            df (pd.DataFrame): Data frame.
            fcm_expr (Expr): Sympy expression.

        Returns:
            pd.DataFrame: Data frame after eval.
        """
        correct_order = list(ordered(fcm_expr.free_symbols))  # self.__return_ordered_args(fcm_expr)
        cols = [str(col) for col in correct_order]
        evaluator = lambdify(correct_order, fcm_expr, "scipy")

        return evaluator(*[df[col] for col in cols])

    def __distribution_parameters_explicit(self, order: Symbol, which_graph: nx.DiGraph) -> bool:
        """Returns true if distribution parameters are given explicitly, not symbolically.

        Args:
            order (node): node in graph
            which_graph (nx.DiGraph): which graph to choose.

        Returns:
            bool:
        """
        return len(which_graph.nodes[order]["term"].free_symbols) == len(
            which_graph.nodes[order]["term"].atoms(Symbol)
        )

    def __source_df_condition(self, source_df: pd.DataFrame) -> bool:
        """Returns true if source_df colnames and graph nodenames agree.

        Args:
            source_df (None | pd.DataFrame): data frame containing source node data.

        Returns:
            bool: True if names agree
        """
        return {str(col) for col in source_df.columns}.issubset(
            {str(node) for node in self.source_nodes}
        )

    def intervene_on(self, nodes_values: dict[Symbol, RandomSymbol | float]):
        """Specify hard or soft intervention.

        If you want to intervene
        upon more than one node, provide a dict mapping each node
        to the value it should be set to (see example).
        The mutilated DAG will automatically be
        stored in `mutilated_dags`.

        Args:
            nodes_values (dict[Symbol, RandomSymbol | float]): either single real
                number or RandomSymbol. If you provide more than one
                intervention just provide more key-value pairs.

        Raises:
            AssertionError: If node(s) are not in the graph

        Example:
            ```python
            x,y = symbols("x,y")
            eq_x = Eq(x, Gamma("source", 1,1))
            eq_y = Eq(y, 4*x**3 + Uniform("noise", left=-0.5, right=0.5))

            example_fcm = FCM()
            example_fcm.input_fcm([eq_x,eq_y])
            # Hard intervention
            example_fcm.intervene_on(nodes_values = {y : 4})
            # Soft intervention
            example_fcm.intervene_on(nodes_values = {y : Normal("noise",0,1)})

            ```
        """
        if not set(nodes_values.keys()).issubset(set(self.nodes)):
            raise AssertionError(
                "One or more nodes you want to intervene upon are not in the graph."
            )

        mutilated_dag = self.graph.copy()

        for node, value in nodes_values.items():
            intervention = Eq(node, value)
            old_incoming = self.parents(of_node=node)
            edges_to_remove = [(old, node) for old in old_incoming]
            mutilated_dag.remove_edges_from(edges_to_remove)
            mutilated_dag.nodes[node]["term"] = intervention.rhs

        self.mutilated_dags[f"do({list(nodes_values.keys())})"] = (
            mutilated_dag  # specifying the same set twice will override
        )

    def show(self, header: str | None = None, with_nodenames: bool = True):
        """Plots the current DAG.

        Args:
            header (str | None, optional): Header for the DAG. Defaults to None.
            with_nodenames (bool, optional): Whether or not to use nodenames as
                labels in the plot. Defaults to True.

        Returns:
            plt: Plot of the DAG.
        """
        if header is None:
            header = ""
        return self._show(which_graph=self.graph, header=header, with_nodenames=with_nodenames)

    def show_mutilated_dag(self, which_intervention: str | int = 0, with_nodenames: bool = True):
        """Plot mutilated DAG.

        Args:
            which_intervention (str | int, optional): Which interventional distribution
                should be represented by a DAG. Defaults to 0.
            with_nodenames (bool, optional): Whether or not to use nodenames as
                labels in the plot. Defaults to True.

        Returns:
            plt: Plot of the mutilated DAG.
        """
        if isinstance(which_intervention, int):
            which_intervention = self.interventions[which_intervention]

        return self._show(
            which_graph=self.mutilated_dags[which_intervention],
            header=which_intervention,
            with_nodenames=with_nodenames,
        )

    def _show(self, which_graph: nx.DiGraph, header: str, with_nodenames: bool):
        """Plots the graph by giving extra weight to nodes with high in- and out-degree.

        Args:
            which_graph (nx.DiGraph): Graph to plot.
            header (str): Header text placed above the plot.
            with_nodenames (bool): Whether to use node names as labels.
        """
        cmap = plt.get_cmap("Blues")
        fig, ax = plt.subplots()
        center: np.ndarray = np.array([0, 0])
        pos = nx.spring_layout(
            which_graph,
            center=center,
            seed=10,
            k=50,
        )

        labels = {}
        for node in self.nodes:
            labels[node] = node

        max_in_degree = max([d for _, d in which_graph.in_degree()])
        max_out_degree = max([d for _, d in which_graph.out_degree()])

        nx.draw_networkx_nodes(
            which_graph,
            pos=pos,
            ax=ax,
            cmap=cmap,
            vmin=-0.2,
            vmax=1,
            node_color=[
                (d + 10) / (max_in_degree + 10) for _, d in which_graph.in_degree(self.nodes)
            ],  # type: ignore
            node_size=[
                500 * (d + 1) / (max_out_degree + 1) for _, d in which_graph.out_degree(self.nodes)
            ],  # type: ignore
        )

        if with_nodenames:
            nx.draw_networkx_labels(
                which_graph,
                pos=pos,
                labels=labels,
                font_size=8,
                font_color="w",
                alpha=0.4,
            )

        nx.draw_networkx_edges(
            which_graph,
            pos=pos,
            ax=ax,
            alpha=0.2,
            arrowsize=8,
            width=0.5,
            connectionstyle="arc3,rad=0.3",
        )

        ax.text(
            center[0],
            center[1] + 1.2,
            f"{header}",
            horizontalalignment="center",
        )

        ax.axis("off")

causal_order property

Returns the causal order of the current graph.

Note that this order is in general not unique. To ensure uniqueness, we additionally sort lexicographically.

Returns:

Type Description
list[Symbol]

list[Symbol]: Causal order

edges property

Edges in the graph.

Returns:

Type Description
list[tuple]

list[tuple]

ground_truth property

Returns the current ground truth as pandas adjacency.

Returns:

Type Description
DataFrame

pd.DataFrame: Adjacency matrix.

interventions property

Returns all interventions performed on the original graph.

Returns:

Name Type Description
list list[str]

list of intervened upon nodes in do(x) notation.

nodes property

Nodes in the graph.

Returns:

Type Description
list[Symbol]

list[Symbol]

num_edges property

Number of edges in the graph.

Returns:

Type Description
int

int

num_nodes property

Number of nodes in the graph.

Returns:

Type Description
int

int

source_nodes property

Returns source nodes in the current DAG.

Returns:

Name Type Description
list list

List of source nodes.

sparsity property

Sparsity of the graph.

Returns:

Name Type Description
float float

in [0,1]

__distribution_parameters_explicit(order, which_graph)

Returns true if distribution parameters are given explicitly, not symbolically.

Parameters:

Name Type Description Default
order node

node in graph

required
which_graph DiGraph

which graph to choose.

required

Returns:

Name Type Description
bool bool
Source code in causalAssembly/models_fcm.py
def __distribution_parameters_explicit(self, order: Symbol, which_graph: nx.DiGraph) -> bool:
    """Returns true if distribution parameters are given explicitly, not symbolically.

    Args:
        order (node): node in graph
        which_graph (nx.DiGraph): which graph to choose.

    Returns:
        bool:
    """
    return len(which_graph.nodes[order]["term"].free_symbols) == len(
        which_graph.nodes[order]["term"].atoms(Symbol)
    )

__eval_expression(df, fcm_expr)

Eval given fcm_expression with the values in given dataframe.

Parameters:

Name Type Description Default
df DataFrame

Data frame.

required
fcm_expr Expr

Sympy expression.

required

Returns:

Type Description
DataFrame

pd.DataFrame: Data frame after eval.

Source code in causalAssembly/models_fcm.py
def __eval_expression(self, df: pd.DataFrame, fcm_expr: Expr) -> pd.DataFrame:
    """Eval given fcm_expression with the values in given dataframe.

    Args:
        df (pd.DataFrame): Data frame.
        fcm_expr (Expr): Sympy expression.

    Returns:
        pd.DataFrame: Data frame after eval.
    """
    correct_order = list(ordered(fcm_expr.free_symbols))  # self.__return_ordered_args(fcm_expr)
    cols = [str(col) for col in correct_order]
    evaluator = lambdify(correct_order, fcm_expr, "scipy")

    return evaluator(*[df[col] for col in cols])

__init__(name=None, seed=2023)

Inits the FCM class.

Parameters:

Name Type Description Default
name str | None

Name. Defaults to None.

None
seed int

Seed. Defaults to 2023.

2023
Source code in causalAssembly/models_fcm.py
def __init__(self, name: str | None = None, seed: int = 2023):
    """Inits the FCM class.

    Args:
        name (str | None, optional): Name. Defaults to None.
        seed (int, optional): Seed. Defaults to 2023.
    """
    self.name = name
    self._random_state = np.random.default_rng(seed=seed)
    self.__init_dag()
    self.last_df: pd.DataFrame
    self.__init_mutilated_dag()

__source_df_condition(source_df)

Returns true if source_df colnames and graph nodenames agree.

Parameters:

Name Type Description Default
source_df None | DataFrame

data frame containing source node data.

required

Returns:

Name Type Description
bool bool

True if names agree

Source code in causalAssembly/models_fcm.py
def __source_df_condition(self, source_df: pd.DataFrame) -> bool:
    """Returns true if source_df colnames and graph nodenames agree.

    Args:
        source_df (None | pd.DataFrame): data frame containing source node data.

    Returns:
        bool: True if names agree
    """
    return {str(col) for col in source_df.columns}.issubset(
        {str(node) for node in self.source_nodes}
    )

causal_order_of(which_graph)

Returns the causal order of the chosen graph.

Note that this order is in general not unique. To ensure uniqueness, we additionally sort lexicographically.

Returns:

Type Description
list[Symbol]

list[Symbol]: Causal order

Source code in causalAssembly/models_fcm.py
def causal_order_of(self, which_graph: nx.DiGraph) -> list[Symbol]:
    """Returns the causal order of the chosen graph.

    Note that this order is in general not unique. To
    ensure uniqueness, we additionally sort lexicographically.

    Returns:
        list[Symbol]: Causal order
    """
    return list(nx.lexicographical_topological_sort(which_graph, key=lambda x: str(x)))

display_functions()

Displays all functional assignments inputted into the FCM.

Returns:

Name Type Description
dict dict

Dict with keys equal to nodes and values equal to functional assignments.

Source code in causalAssembly/models_fcm.py
def display_functions(self) -> dict:
    """Displays all functional assignments inputted into the FCM.

    Returns:
        dict: Dict with keys equal to nodes and values equal to
            functional assignments.
    """
    fcm_dict = {}
    for node in self.causal_order:
        fcm_dict[node] = self.graph.nodes[node]["term"]

    return fcm_dict
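
A short sketch (the `causalAssembly.models_fcm` import path is taken from the source note above; node and noise names are illustrative):

```python
from sympy import Eq, symbols
from sympy.stats import Normal

from causalAssembly.models_fcm import FCM

x, y = symbols("x,y")
fcm = FCM(name="toy", seed=2023)
fcm.input_fcm([Eq(x, Normal("n_x", 0, 1)), Eq(y, 3 * x + Normal("n_y", 0, 1))])

# Keys follow the causal order x -> y; values are the right-hand sides.
print(fcm.display_functions())  # {x: n_x, y: 3*x + n_y}
```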

function_of(node)

Returns functional assignment for node in question.

Parameters:

Name Type Description Default
node Symbol

node corresponding to lhs.

required

Returns:

Name Type Description
dict dict

key is node and value rhs of functional assignment.

Source code in causalAssembly/models_fcm.py
def function_of(self, node: Symbol) -> dict:
    """Returns functional assignment for node in question.

    Args:
        node (Symbol): node corresponding to lhs.

    Returns:
        dict: key is node and value rhs of functional assignment.
    """
    if node not in self.graph.nodes:
        if node in [str(node) for node in self.nodes]:
            raise AssertionError(
                "You probably defined a string. Node has to be a symbol, check out",
                list(self.graph.nodes),
            )
        else:
            raise AssertionError("Node has to be in the graph")

    return {node: self.graph.nodes[node]["term"]}

input_fcm(fcm)

Automatically builds up DAG according to the FCM fed in.

Parameters:

Name Type Description Default
fcm list

list of sympy equations generated as:

    x, y = symbols('x,y')
    term_x = Eq(x, Normal('x', 0, 1))
    term_y = Eq(y, 2*x**2*Normal('noise', 0, 1))
    fcm = [term_x, term_y]

required
Source code in causalAssembly/models_fcm.py
def input_fcm(self, fcm: list[Eq]):
    """Automatically builds up DAG according to the FCM fed in.

    Args:
        fcm (list): list of sympy equations generated as:
                ```python
                x,y = symbols('x,y')
                term_x = Eq(x, Normal('x', 0,1))
                term_y = Eq(y, 2*x**2*Normal('noise', 0,1))
                fcm = [term_x, term_y]
                ```
    """
    nodes_implied = [node.lhs.free_symbols.pop() for node in fcm]
    edges_implied = []
    for term in fcm:
        if not isinstance(term.rhs, RandomSymbol):
            if term.rhs.atoms(RandomSymbol):
                edges_implied.extend(
                    [
                        (atom, term.lhs)
                        for atom in term.rhs.free_symbols
                        if str(atom) != str(term.rhs.atoms(RandomSymbol).pop())
                    ]
                )
            else:
                edges_implied.extend([(atom, term.lhs) for atom in term.rhs.atoms(Symbol)])

    g = self.graph
    g.add_nodes_from(nodes_implied)
    g.add_edges_from(edges_implied)

    term_dict = {}
    for term in fcm:
        term_dict[term.lhs] = {"term": term.rhs}

    nx.set_node_attributes(g, term_dict)
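
A small sketch of how the DAG is inferred from the equations (assumed imports and illustrative names as in the example above):

```python
from sympy import Eq, symbols
from sympy.stats import Normal

from causalAssembly.models_fcm import FCM

x, y, z = symbols("x,y,z")
fcm = FCM(seed=2023)
fcm.input_fcm(
    [
        Eq(x, Normal("n_x", 0, 1)),
        Eq(y, x**2 + Normal("n_y", 0, 1)),
        Eq(z, x + 2 * y + Normal("n_z", 0, 1)),
    ]
)

# Edges are read off the right-hand sides (noise symbols are skipped).
print(fcm.edges)         # (x, y), (x, z), (y, z) -- order may vary
print(fcm.causal_order)  # [x, y, z]
```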

intervene_on(nodes_values)

Specify hard or soft intervention.

If you want to intervene upon more than one node, provide a dict mapping each node to the value it should be set to (see example). The mutilated DAG will automatically be stored in mutilated_dags.

Parameters:

Name Type Description Default
nodes_values dict[Symbol, RandomSymbol | float]

either a single real number or a RandomSymbol. If you provide more than one intervention, just provide more key-value pairs.

required

Raises:

Type Description
AssertionError

If node(s) are not in the graph

Example
x,y = symbols("x,y")
eq_x = Eq(x, Gamma("source", 1,1))
eq_y = Eq(y, 4*x**3 + Uniform("noise", left=-0.5, right=0.5))

example_fcm = FCM()
example_fcm.input_fcm([eq_x,eq_y])
# Hard intervention
example_fcm.intervene_on(nodes_values = {y : 4})
# Soft intervention
example_fcm.intervene_on(nodes_values = {y : Normal("noise",0,1)})

Source code in causalAssembly/models_fcm.py
def intervene_on(self, nodes_values: dict[Symbol, RandomSymbol | float]):
    """Specify hard or soft intervention.

    If you want to intervene
    upon more than one node, provide a dict mapping each node
    to the value it should be set to (see example).
    The mutilated DAG will automatically be
    stored in `mutilated_dags`.

    Args:
        nodes_values (dict[Symbol, RandomSymbol | float]): either single real
            number or RandomSymbol. If you provide more than one
            intervention just provide more key-value pairs.

    Raises:
        AssertionError: If node(s) are not in the graph

    Example:
        ```python
        x,y = symbols("x,y")
        eq_x = Eq(x, Gamma("source", 1,1))
        eq_y = Eq(y, 4*x**3 + Uniform("noise", left=-0.5, right=0.5))

        example_fcm = FCM()
        example_fcm.input_fcm([eq_x,eq_y])
        # Hard intervention
        example_fcm.intervene_on(nodes_values = {y : 4})
        # Soft intervention
        example_fcm.intervene_on(nodes_values = {y : Normal("noise",0,1)})

        ```
    """
    if not set(nodes_values.keys()).issubset(set(self.nodes)):
        raise AssertionError(
            "One or more nodes you want to intervene upon are not in the graph."
        )

    mutilated_dag = self.graph.copy()

    for node, value in nodes_values.items():
        intervention = Eq(node, value)
        old_incoming = self.parents(of_node=node)
        edges_to_remove = [(old, node) for old in old_incoming]
        mutilated_dag.remove_edges_from(edges_to_remove)
        mutilated_dag.nodes[node]["term"] = intervention.rhs

    self.mutilated_dags[f"do({list(nodes_values.keys())})"] = (
        mutilated_dag  # specifying the same set twice will override
    )

interventional_amat(which_intervention)

Returns the adjacency matrix of a chosen mutilated DAG.

Parameters:

Name Type Description Default
which_intervention int | str

Integer count of your chosen intervention or literal string.

required

Raises:

Type Description
ValueError

"The intervention you provide does not exist."

Returns:

Type Description
DataFrame

pd.DataFrame: Adjacency matrix.

Source code in causalAssembly/models_fcm.py
def interventional_amat(self, which_intervention: int | str) -> pd.DataFrame:
    """Returns the adjacency matrix of a chosen mutilated DAG.

    Args:
        which_intervention (int | str): Integer count of your chosen intervention or
            literal string.

    Raises:
        ValueError: "The intervention you provide does not exist."

    Returns:
        pd.DataFrame: Adjacency matrix.
    """
    if isinstance(which_intervention, str) and which_intervention not in self.interventions:
        raise ValueError("The intervention you provide does not exist.")

    if isinstance(which_intervention, int) and which_intervention >= len(self.interventions):
        raise ValueError("The intervention you index does not exist.")

    if isinstance(which_intervention, int):
        which_intervention = self.interventions[which_intervention]

    mutilated_dag = self.mutilated_dags[which_intervention].copy()
    return nx.to_pandas_adjacency(mutilated_dag, weight=None)
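
A sketch tying `intervene_on` and `interventional_amat` together (same assumed imports as above):

```python
from sympy import Eq, symbols
from sympy.stats import Normal

from causalAssembly.models_fcm import FCM

x, y = symbols("x,y")
fcm = FCM(seed=2023)
fcm.input_fcm([Eq(x, Normal("n_x", 0, 1)), Eq(y, 2 * x + Normal("n_y", 0, 1))])

fcm.intervene_on({y: 1.5})  # hard intervention do(y = 1.5)
print(fcm.interventions)    # ['do([y])']

# Adjacency matrix of the mutilated DAG: the edge x -> y has been cut.
print(fcm.interventional_amat(0))
```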

interventional_sample(size, which_intervention=0, additive_gaussian_noise=False, snr=1 / 2, source_df=None)

Draw samples from the interventional distribution.

that factorizes according to the mutilated DAG after performing one or multiple interventions. Otherwise the method behaves similarly to sampling from the non-interventional joint distribution. By default, samples are drawn from the first intervention you performed. If you intervened upon more than one node, you'll have to switch to another intervention to sample from the corresponding interventional distribution.

Parameters:

Name Type Description Default
size int

Number of samples to draw.

required
which_intervention str | int

Which interventional distribution to draw from. We recommend using integer counts starting from zero. But you can also provide the literal string here, e.g. if you intervened on say two nodes x,y then you would need to provide here: do([x, y]).

0
additive_gaussian_noise bool

This will attach additive Gaussian noise to all terms without a RandomSymbol that are not source nodes. It acts merely as a convenience option. Variance will then be chosen according to SNR. Defaults to False.

False
snr None | float

Signal-to-noise ratio \( SNR = \frac{\text{Var}(\hat{X})}{\hat\sigma^2}. \). Defaults to 1/2.

1 / 2
source_df None | DataFrame

Data frame containing source node data. The sample size must be at least as large as the number of samples you'd like to draw. Defaults to None.

None

Raises:

Type Description
NotImplementedError

Raised when which_intervention is not of correct form.

Returns:

Type Description
DataFrame

pd.DataFrame: Data frame with rows of length size and columns equal to the number of nodes in the graph.

Source code in causalAssembly/models_fcm.py
def interventional_sample(
    self,
    size: int,
    which_intervention: str | int = 0,
    additive_gaussian_noise: bool = False,
    snr: None | float = 1 / 2,
    source_df: None | pd.DataFrame = None,
) -> pd.DataFrame:
    r"""Draw samples from the interventional distribution.

    that factorizes according to the mutilated DAG after performing one or multiple
    interventions. Otherwise the method behaves similarly to sampling from the
    non-interventional joint distribution. By default samples are drawn from the
    first intervention you performed. If you intervened upon more than one node,
    you'll have to switch to another intervention to sample from the corresponding
    interventional distribution.

    Args:
        size (int): Number of samples to draw.
        which_intervention (str | int): Which interventional distribution to draw
            from. We recommend using integer counts starting from zero. But you can
            also provide the literal string here, e.g. if you intervened on say two
            nodes `x,y` then you would need to provide here: `do([x, y])`.
        additive_gaussian_noise (bool, optional): This will attach additive Gaussian noise
            to all terms without a RandomSymbol that are not source nodes. It acts merely as
            a convenience option. Variance will then be chosen according to SNR.
            Defaults to False.
        snr (None | float, optional): Signal-to-noise ratio
            \\( SNR =  \\frac{\\text{Var}(\\hat{X})}{\\hat\\sigma^2}. \\). Defaults to 1/2.
        source_df (None | pd.DataFrame, optional): Data frame containing source node data.
            The sample size must be at least as large as the number of samples
            you'd like to draw. Defaults to None.

    Raises:
        NotImplementedError: Raised when `which_intervention` is not of correct form.

    Returns:
        pd.DataFrame: Data frame with rows of length `size` and columns equal to the
            number of nodes in the graph.
    """
    if isinstance(which_intervention, str):
        int_choice = which_intervention
    elif isinstance(which_intervention, int):
        int_choice = self.interventions[which_intervention]
    else:
        raise NotImplementedError(
            f"which_intervention has to be \
            the literal string or an integer \
            starting at count zero indicating \
            which intervention in {self.interventions} \
            to use."
        )

    return self._sample(
        size=size,
        additive_gaussian_noise=additive_gaussian_noise,
        snr=snr,
        source_df=source_df,
        which_graph=self.mutilated_dags[int_choice],
    )
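
A minimal end-to-end sketch (assumed imports as above; node and noise names are illustrative):

```python
from sympy import Eq, symbols
from sympy.stats import Normal, Uniform

from causalAssembly.models_fcm import FCM

x, y = symbols("x,y")
fcm = FCM(seed=2023)
fcm.input_fcm(
    [
        Eq(x, Uniform("u_x", left=-1, right=1)),
        Eq(y, 4 * x + Normal("n_y", 0, 0.5)),
    ]
)

# Hard intervention do(y = 2), then sample from the mutilated DAG.
fcm.intervene_on({y: 2})
df_int = fcm.interventional_sample(size=100, which_intervention=0)
print(df_int["y"].unique())  # every draw of y equals 2 under do(y = 2)
```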

parents(of_node)

Return parents of node in question.

Parameters:

Name Type Description Default
of_node Symbol

Node in question.

required

Returns:

Type Description
list[Symbol]

list[Symbol]: parent set.

Source code in causalAssembly/models_fcm.py
def parents(self, of_node: Symbol) -> list[Symbol]:
    """Return parents of node in question.

    Args:
        of_node (Symbol): Node in question.

    Returns:
        list[Symbol]: parent set.
    """
    return list(self.graph.predecessors(of_node))
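
For instance (assumed imports as above):

```python
from sympy import Eq, symbols
from sympy.stats import Normal

from causalAssembly.models_fcm import FCM

x, y = symbols("x,y")
fcm = FCM(seed=2023)
fcm.input_fcm([Eq(x, Normal("n_x", 0, 1)), Eq(y, 2 * x + Normal("n_y", 0, 1))])

print(fcm.parents(of_node=y))  # [x]
print(fcm.parents(of_node=x))  # [] -- x is a source node
```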

parents_of(node, which_graph)

Return parents of node in question for a chosen DAG.

Parameters:

Name Type Description Default
node Symbol

node whose parents to return.

required
which_graph DiGraph

which graph along the interventions.

required

Returns:

Type Description
list[Symbol]

list[Symbol]: list of parents.

Source code in causalAssembly/models_fcm.py
def parents_of(self, node: Symbol, which_graph: nx.DiGraph) -> list[Symbol]:
    """Return parents of node in question for a chosen DAG.

    Args:
        node (Symbol): node whose parents to return.
        which_graph (nx.DiGraph): which graph along the interventions.

    Returns:
        list[Symbol]: list of parents.
    """
    return list(which_graph.predecessors(node))

sample(size, additive_gaussian_noise=False, snr=1 / 2, source_df=None)

Sample from joint.

Draw samples from the joint distribution that factorizes according to the DAG implied by the FCM fed in.

To avoid unexpected/unintended behavior, avoid defining fully deterministic equation systems. If parameters in noise terms are additive and left unevaluated, they're set according to a chosen Signal-To-Noise (SNR) ratio. For convenience, you can add additive Gaussian noise to each equation. This will never overwrite any of the chosen noise distributions. You may also feed in a data frame for noise distributions (see below for more details).

Parameters:

Name Type Description Default
size int

Number of samples to draw.

required
additive_gaussian_noise bool

This will attach additive Gaussian noise to all terms without a RandomSymbol that are not source nodes. It acts merely as a convenience option. Variance will then be chosen according to SNR. Defaults to False.

False
snr None | float

Signal-to-noise ratio \( SNR = \frac{\text{Var}(\hat{X})}{\hat\sigma^2}. \). Defaults to 1/2.

1 / 2
source_df None | DataFrame

Data frame containing source node data. The sample size must be at least as large as the number of samples you'd like to draw. Defaults to None.

None

Raises:

Type Description
AttributeError

if source node parameters are not given explicitly.

ValueError

if source node sample size is too small.

ValueError

if scale parameters are left unevaluated for non-additive terms.

Returns:

Type Description
DataFrame

pd.DataFrame: Data frame with rows of length size and columns equal to the number of nodes in the graph.

Source code in causalAssembly/models_fcm.py
def sample(
    self,
    size: int,
    additive_gaussian_noise: bool = False,
    snr: None | float = 1 / 2,
    source_df: None | pd.DataFrame = None,
) -> pd.DataFrame:
    r"""Sample from joint.

    Draw samples from the joint distribution that factorizes
    according to the DAG implied by the FCM fed in.

    To avoid unexpected/unintended behavior, avoid defining fully
    deterministic equation systems.
    If parameters in noise terms are additive and left unevaluated,
    they're set according to a chosen Signal-To-Noise (SNR) ratio.
    For convenience, you can add additive Gaussian noise to each equation.
    This will never overwrite any of the chosen noise distributions.
    You may also feed in a data frame for noise distributions (see below
    for more details).

    Args:
        size (int): Number of samples to draw.
        additive_gaussian_noise (bool, optional): This will attach additive
            Gaussian noise to all terms without a RandomSymbol that are not
            source nodes. It acts merely as a convenience option. Variance
            will then be chosen according to SNR. Defaults to False.
        snr (None | float, optional): Signal-to-noise ratio
            \\( SNR =  \\frac{\\text{Var}(\\hat{X})}{\\hat\\sigma^2}. \\).
            Defaults to 1/2.
        source_df (None | pd.DataFrame, optional): Data frame containing source node data.
            The sample size must be at least as large as the number of samples
            you'd like to draw. Defaults to None.

    Raises:
        AttributeError: if source node parameters are not given explicitly.
        ValueError: if source node sample size is too small.
        ValueError: if scale parameters are left unevaluated for non-additive terms.

    Returns:
        pd.DataFrame: Data frame with `size` rows and one column per node
            in the graph.
    """
    return self._sample(
        size=size,
        additive_gaussian_noise=additive_gaussian_noise,
        snr=snr,
        source_df=source_df,
        which_graph=self.graph,
    )
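
A minimal sampling sketch, again assuming `fcm` is an already-specified model instance from models_fcm.py (the equation setup is not shown here):

# Draw 500 observations from the implied joint distribution; unevaluated
# additive noise scales are chosen to meet the default SNR of 1/2.
df = fcm.sample(size=500, additive_gaussian_noise=True)
print(df.shape)  # (500, number of nodes in fcm.graph)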

show(header=None, with_nodenames=True)

Plots the current DAG.

Parameters:

Name Type Description Default
header str | None

Header for the DAG. Defaults to None.

None
with_nodenames bool

Whether or not to use nodenames as labels in the plot. Defaults to True.

True

Returns:

Name Type Description
plt

Plot of the DAG.

Source code in causalAssembly/models_fcm.py
def show(self, header: str | None = None, with_nodenames: bool = True):
    """Plots the current DAG.

    Args:
        header (str | None, optional): Header for the DAG. Defaults to None.
        with_nodenames (bool, optional): Whether or not to use nodenames as
            labels in the plot. Defaults to True.

    Returns:
        plt: Plot of the DAG.
    """
    if header is None:
        header = ""
    return self._show(which_graph=self.graph, header=header, with_nodenames=with_nodenames)

show_mutilated_dag(which_intervention=0, with_nodenames=True)

Plot mutilated DAG.

Parameters:

Name Type Description Default
which_intervention str | int

Which interventional distribution should be represented by a DAG. Defaults to 0.

0
with_nodenames bool

Whether or not to use nodenames as labels in the plot. Defaults to True.

True

Returns:

Name Type Description
plt

Plot of the mutilated DAG.

Source code in causalAssembly/models_fcm.py
def show_mutilated_dag(self, which_intervention: str | int = 0, with_nodenames: bool = True):
    """Plot mutilated DAG.

    Args:
        which_intervention (str | int, optional): Which interventional distribution
            should be represented by a DAG. Defaults to 0.
        with_nodenames (bool, optional): Whether or not to use nodenames as
            labels in the plot. Defaults to True.

    Returns:
        plt: Plot of the mutilated DAG.
    """
    if isinstance(which_intervention, int):
        which_intervention = self.interventions[which_intervention]

    return self._show(
        which_graph=self.mutilated_dags[which_intervention],
        header=which_intervention,
        with_nodenames=with_nodenames,
    )

source_nodes_of(which_graph)

Returns the source nodes of a chosen graph.

This is mainly for choosing different mutilated DAGs.

Parameters:

Name Type Description Default
which_graph DiGraph

DAG from which source nodes should be returned.

required

Returns:

Name Type Description
list list

List of nodes.

Source code in causalAssembly/models_fcm.py
def source_nodes_of(self, which_graph: nx.DiGraph) -> list:
    """Returns the source nodes of a chosen graph.

    This is mainly for
    choosing different mutilated DAGs.

    Args:
        which_graph (nx.DiGraph): DAG from which source nodes should be returned.

    Returns:
        list: List of nodes.
    """
    return [
        node
        for node in which_graph.nodes
        if len(self.parents_of(node=node, which_graph=which_graph)) == 0
    ]


DAGmetrics

Class to calculate performance metrics for DAGs.

Make sure that the ground truth and the estimated DAG have the same order of rows/columns. If these objects are nx.DiGraphs, make sure that graph.nodes() have the same order, or pass a nodelist to the class when initializing. The same can be done for pd.DataFrames. In case truth and est are np.ndarray objects, it is the user's responsibility to make sure that the objects are indeed comparable.

Source code in causalAssembly/metrics.py
class DAGmetrics:
    """Class to calculate performance metrics for DAGs.

    Make sure that the ground truth and the estimated DAG have the same order of
    rows/columns. If these objects are nx.DiGraphs, make sure that graph.nodes()
    have the same order, or pass a nodelist to the class when initializing. The
    same can be done for pd.DataFrames. In case `truth` and `est` are np.ndarray
    objects, it is the user's responsibility to make sure that the objects are
    indeed comparable.
    """

    def __init__(
        self,
        truth: nx.DiGraph | pd.DataFrame | np.ndarray,
        est: nx.DiGraph | pd.DataFrame | np.ndarray,
        nodelist: list[str] | None = None,
    ):
        """Inits the DAGmetrics class.

        Args:
            truth (nx.DiGraph | pd.DataFrame | np.ndarray): Ground truth DAG.
            est (nx.DiGraph | pd.DataFrame | np.ndarray): Estimated DAG.
            nodelist (list, optional): Node order used to align `truth` and `est`.
                Defaults to None.

        Raises:
            TypeError: if `truth` is not one of the permitted classes.
            TypeError: if `est` is not one of the permitted classes.
        """
        if not isinstance(truth, nx.DiGraph | pd.DataFrame | np.ndarray):
            raise TypeError("Ground truth graph has to be one of the permitted classes.")

        if not isinstance(est, nx.DiGraph | pd.DataFrame | np.ndarray):
            raise TypeError("Estimated graph has to be one of the permitted classes")

        self.truth = DAGmetrics._convert_to_numpy(truth, nodelist=nodelist)
        self.est = DAGmetrics._convert_to_numpy(est, nodelist=nodelist)

        self.metrics = None

    def _calculate_scores(self):
        """Calculate Precision, Recall and F1 and g score.

        Return:
        precision: float
            TP/(TP + FP)
        recall: float
            TP/(TP + FN)
        f1: float
            2*(recall*precision)/(recall+precision)
        gscore: float
            max(0, (TP-FP))/(TP+FN)
        """
        TWO = 2
        assert self.est.shape == self.truth.shape and self.est.shape[0] == self.est.shape[1]
        TP = np.where((self.est + self.truth) == TWO, 1, 0).sum(axis=1).sum()
        TP_FP = self.est.sum(axis=1).sum()
        FP = TP_FP - TP
        TP_FN = self.truth.sum(axis=1).sum()

        precision = TP / max(TP_FP, 1)
        recall = TP / max(TP_FN, 1)
        F1 = 2 * (recall * precision) / max((recall + precision), 1)
        gscore = max(0, (TP - FP)) / max(TP_FN, 1)

        return {"precision": precision, "recall": recall, "f1": F1, "gscore": gscore}

    def _shd(self, count_anticausal_twice: bool = True):
        """Calculate Structural Hamming Distance (SHD).

        Args:
            count_anticausal_twice (bool, optional): If edge is pointing in the wrong direction
                it's also missing in the right direction and is counted twice. Defaults to True.
        """
        dist = np.abs(self.truth - self.est)
        if count_anticausal_twice:
            return np.sum(dist)
        else:
            dist = dist + dist.transpose()
            dist[dist > 1] = 1
            return np.sum(dist) / 2

    def collect_metrics(self) -> dict[str, float | int]:
        """Collects all metrics defined in this class in a dict.

        Returns:
            dict[str, float|int]: Metrics calculated
        """
        metrics = self._calculate_scores()
        metrics["shd"] = self._shd()
        self.metrics = metrics
        return metrics

    @classmethod
    def _convert_to_numpy(
        cls,
        graph: nx.DiGraph | pd.DataFrame | np.ndarray,
        nodelist: list[str] | None = None,
    ):
        if isinstance(graph, np.ndarray):
            return copy.deepcopy(graph)
        elif isinstance(graph, pd.DataFrame):
            if nodelist:
                return copy.deepcopy(graph.reindex(nodelist)[nodelist].to_numpy())
            else:
                return copy.deepcopy(graph.to_numpy())
        elif isinstance(graph, nx.DiGraph):
            return nx.to_numpy_array(G=graph, nodelist=nodelist)

__init__(truth, est, nodelist=None)

Inits the DAGmetrics class.

Parameters:

Name Type Description Default
truth DiGraph | DataFrame | ndarray

Ground truth DAG.

required
est DiGraph | DataFrame | ndarray

Estimated DAG.

required
nodelist list

Node order used to align truth and est. Defaults to None.

None

Raises:

Type Description
TypeError

if truth is not one of the permitted classes.

TypeError

if est is not one of the permitted classes.

Source code in causalAssembly/metrics.py
def __init__(
    self,
    truth: nx.DiGraph | pd.DataFrame | np.ndarray,
    est: nx.DiGraph | pd.DataFrame | np.ndarray,
    nodelist: list[str] | None = None,
):
    """Inits the DAGmetrics class.

    Args:
        truth (nx.DiGraph | pd.DataFrame | np.ndarray): Ground truth DAG.
        est (nx.DiGraph | pd.DataFrame | np.ndarray): Estimated DAG.
        nodelist (list, optional): Node order used to align `truth` and `est`.
            Defaults to None.

    Raises:
        TypeError: if `truth` is not one of the permitted classes.
        TypeError: if `est` is not one of the permitted classes.
    """
    if not isinstance(truth, nx.DiGraph | pd.DataFrame | np.ndarray):
        raise TypeError("Ground truth graph has to be one of the permitted classes.")

    if not isinstance(est, nx.DiGraph | pd.DataFrame | np.ndarray):
        raise TypeError("Estimated graph has to be one of the permitted classes")

    self.truth = DAGmetrics._convert_to_numpy(truth, nodelist=nodelist)
    self.est = DAGmetrics._convert_to_numpy(est, nodelist=nodelist)

    self.metrics = None

collect_metrics()

Collects all metrics defined in this class in a dict.

Returns:

Type Description
dict[str, float | int]

dict[str, float|int]: Metrics calculated

Source code in causalAssembly/metrics.py
def collect_metrics(self) -> dict[str, float | int]:
    """Collects all metrics defined in this class in a dict.

    Returns:
        dict[str, float|int]: Metrics calculated
    """
    metrics = self._calculate_scores()
    metrics["shd"] = self._shd()
    self.metrics = metrics
    return metrics
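
A small self-contained usage sketch comparing two toy DAGs; the node names are made up and the import path is inferred from the source path shown above.

import networkx as nx

from causalAssembly.metrics import DAGmetrics

truth = nx.DiGraph([("a", "b"), ("b", "c")])
est = nx.DiGraph([("a", "b"), ("c", "b")])

# Pass a nodelist so both adjacency matrices share the same row/column order.
metrics = DAGmetrics(truth=truth, est=est, nodelist=["a", "b", "c"]).collect_metrics()
print(metrics)  # keys: 'precision', 'recall', 'f1', 'gscore', 'shd'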


DRF

Wrapper around the corresponding R package.

Distributional Random Forests (Cevid et al., 2020). Closely adapted from their Python wrapper.

Source code in causalAssembly/drf_fitting.py
class DRF:
    """Wrapper around the corresponding R package.

    Distributional Random Forests (Cevid et al., 2020).
    Closely adapted from their Python wrapper.
    """

    def __init__(self, **fit_params):
        """Initialize the DRF object with fit parameters."""
        self.fit_params = fit_params

    def fit(self, X: pd.DataFrame, Y: pd.DataFrame | pd.Series):
        """Fit DRF in order to estimate conditional distribution P(Y|X=x).

        Args:
            X (pd.DataFrame): Predictor variables.
            Y (pd.DataFrame | pd.Series): Response variable(s).
        """
        self.X_train = X
        self.Y_train = Y

        # Use localconverter
        with localconverter(R_CONVERTER):
            X_r = ro.conversion.py2rpy(X)
            Y_r = ro.conversion.py2rpy(Y)
            self.r_fit_object = drf_r_package.drf(X_r, Y_r, **self.fit_params)

    def produce_sample(
        self,
        newdata: pd.DataFrame,
        random_state: np.random.Generator,
        n: int = 1,
    ) -> np.ndarray:
        """Sample data from fitted drf.

        Args:
            newdata (pd.DataFrame): Data samples to predict from.
            random_state (np.random.Generator): control random state.
            n (int, optional): Number of n-samples to draw. Defaults to 1.

        Returns:
            np.ndarray: New predicted sample of Y.
        """
        with localconverter(R_CONVERTER):
            newdata_r = ro.conversion.py2rpy(newdata)
            r_output = drf_r_package.predict_drf(self.r_fit_object, newdata_r)

            # Convert back to Python
            weights = ro.conversion.rpy2py(base_r_package.as_matrix(r_output[0]))
            Y = ro.conversion.rpy2py(base_r_package.as_matrix(r_output[1]))

        if not isinstance(Y, pd.DataFrame):
            Y = pd.DataFrame(Y)

        sample = np.zeros((newdata.shape[0], Y.shape[1], n))
        for i in range(newdata.shape[0]):
            for j in range(n):
                ids = random_state.choice(range(Y.shape[0]), 1, p=weights[i, :])[0]
                sample[i, :, j] = Y.iloc[ids, :]

        return sample[:, 0, 0]

__init__(**fit_params)

Initialize the DRF object with fit parameters.

Source code in causalAssembly/drf_fitting.py
def __init__(self, **fit_params):
    """Initialize the DRF object with fit parameters."""
    self.fit_params = fit_params

fit(X, Y)

Fit DRF in order to estimate conditional distribution P(Y|X=x).

Parameters:

Name Type Description Default
X DataFrame

Predictor variables.

required
Y DataFrame | Series

Response variable(s).

required
Source code in causalAssembly/drf_fitting.py
def fit(self, X: pd.DataFrame, Y: pd.DataFrame | pd.Series):
    """Fit DRF in order to estimate conditional distribution P(Y|X=x).

    Args:
        X (pd.DataFrame): Predictor variables.
        Y (pd.DataFrame | pd.Series): Response variable(s).
    """
    self.X_train = X
    self.Y_train = Y

    # Use localconverter
    with localconverter(R_CONVERTER):
        X_r = ro.conversion.py2rpy(X)
        Y_r = ro.conversion.py2rpy(Y)
        self.r_fit_object = drf_r_package.drf(X_r, Y_r, **self.fit_params)

produce_sample(newdata, random_state, n=1)

Sample data from fitted drf.

Parameters:

Name Type Description Default
newdata DataFrame

Data samples to predict from.

required
random_state Generator

control random state.

required
n int

Number of n-samples to draw. Defaults to 1.

1

Returns:

Type Description
ndarray

np.ndarray: New predicted sample of Y.

Source code in causalAssembly/drf_fitting.py
def produce_sample(
    self,
    newdata: pd.DataFrame,
    random_state: np.random.Generator,
    n: int = 1,
) -> np.ndarray:
    """Sample data from fitted drf.

    Args:
        newdata (pd.DataFrame): Data samples to predict from.
        random_state (np.random.Generator): control random state.
        n (int, optional): Number of n-samples to draw. Defaults to 1.

    Returns:
        np.ndarray: New predicted sample of Y.
    """
    with localconverter(R_CONVERTER):
        newdata_r = ro.conversion.py2rpy(newdata)
        r_output = drf_r_package.predict_drf(self.r_fit_object, newdata_r)

        # Convert back to Python
        weights = ro.conversion.rpy2py(base_r_package.as_matrix(r_output[0]))
        Y = ro.conversion.rpy2py(base_r_package.as_matrix(r_output[1]))

    if not isinstance(Y, pd.DataFrame):
        Y = pd.DataFrame(Y)

    sample = np.zeros((newdata.shape[0], Y.shape[1], n))
    for i in range(newdata.shape[0]):
        for j in range(n):
            ids = random_state.choice(range(Y.shape[0]), 1, p=weights[i, :])[0]
            sample[i, :, j] = Y.iloc[ids, :]

    return sample[:, 0, 0]
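
A usage sketch with synthetic data; it assumes the R package drf and rpy2 are installed (the class only wraps the R implementation) and infers the import path from the source path shown above.

import numpy as np
import pandas as pd

from causalAssembly.drf_fitting import DRF

rng = np.random.default_rng(seed=42)
X = pd.DataFrame({"x1": rng.normal(size=200), "x2": rng.normal(size=200)})
Y = pd.Series(X["x1"] + 0.5 * X["x2"] + rng.normal(scale=0.1, size=200), name="y")

drf_model = DRF(min_node_size=15, num_trees=2000, splitting_rule="FourierMMD")
drf_model.fit(X, Y)

# One conditional draw of Y for each row of new predictor values.
y_draws = drf_model.produce_sample(newdata=X.head(5), random_state=rng)
print(y_draws.shape)  # (5,)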

fit_drf(graph, data)

Fit distributional random forests to the factorization implied by the current graph.

Parameters:

Name Type Description Default
graph ProductionLineGraph | ProcessCell | DAG

Graph to fit the DRF to.

required
data DataFrame

Columns of the dataframe need to match the names (and order) of the graph's nodes.

required

Raises:

Type Description
ValueError

Raises error if columns don't meet this requirement.

Returns:

Type Description
dict

dict of fitted DRFs.

Source code in causalAssembly/drf_fitting.py
def fit_drf(graph: ProductionLineGraph | ProcessCell | DAG, data: pd.DataFrame):
    """Fit distributional random forests to the factorization implied by the current graph.

    Args:
        graph (ProductionLineGraph | ProcessCell | DAG): Graph to fit the DRF to.
        data (pd.DataFrame): Columns of the dataframe need to match the names
            (and order) of the graph's nodes.

    Raises:
        ValueError: Raises error if columns don't meet this requirement.

    Returns:
        dict: dict of fitted DRFs.
    """
    tempdata = data.copy()

    if set(graph.nodes).issubset(tempdata.columns):
        tempdata = tempdata[graph.nodes]
    else:
        raise ValueError("Data columns don't match node names.")

    drf_dict = {}
    for node in graph.nodes:
        parents = graph.parents(of_node=node)
        if not parents:
            drf_dict[node] = gaussian_kde(tempdata[node].to_numpy())
        elif parents:
            # default setting as suggested in the paper
            drf_object = DRF(min_node_size=15, num_trees=2000, splitting_rule="FourierMMD")
            X = tempdata[parents]
            Y = tempdata[node]
            drf_object.fit(X, Y)
            drf_dict[node] = drf_object
        else:
            raise ValueError("Unexpected behavior in DRF. Check whether data and DAG match?")
    return drf_dict
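
A sketch of the intended call pattern; `line_graph` and `df` are placeholders for a ProductionLineGraph (or DAG/ProcessCell) and a matching pd.DataFrame built elsewhere, and the import path is inferred from the source path shown above.

from causalAssembly.drf_fitting import fit_drf

drf_dict = fit_drf(graph=line_graph, data=df)

# Source nodes map to a gaussian_kde of their marginal distribution,
# all other nodes map to a DRF fitted on their parents.
first_node = list(line_graph.nodes)[0]
print(type(drf_dict[first_node]))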


PDAG

Class for dealing with partially directed graphs.

i.e., graphs that contain both directed and undirected edges.

Source code in causalAssembly/pdag.py
class PDAG:
    """Class for dealing with partially directed graphs.

    i.e., graphs that contain both directed and undirected edges.
    """

    def __init__(
        self,
        nodes: list[str] | list[int] | set[str] | set[int] | None = None,
        dir_edges: list[tuple[str, str]]
        | list[tuple[int, int]]
        | set[tuple[str, str]]
        | set[tuple[int, int]]
        | None = None,
        undir_edges: list[tuple[str, str]]
        | list[tuple[int, int]]
        | set[tuple[str, str]]
        | set[tuple[int, int]]
        | None = None,
    ):
        """Inits the PDAG class.

        Args:
            nodes (list | None, optional): Nodes of the graph. Defaults to None.
            dir_edges (list[tuple] | None, optional): Directed edges given as
                (tail, head) tuples. Defaults to None.
            undir_edges (list[tuple] | None, optional): Undirected edges given as
                tuples. Defaults to None.
        """
        if nodes is None:
            nodes = []
        if dir_edges is None:
            dir_edges = []
        if undir_edges is None:
            undir_edges = []

        self._nodes = set(nodes)
        self._undir_edges = set()
        self._dir_edges = set()
        self._parents = defaultdict(set)
        self._children = defaultdict(set)
        self._neighbors = defaultdict(set)
        self._undirected_neighbors = defaultdict(set)

        for dir_edge in dir_edges:
            self._add_dir_edge(*dir_edge)
        for unir_edge in undir_edges:
            self._add_undir_edge(*unir_edge)

    def _add_dir_edge(self, i, j):
        self._nodes.add(i)
        self._nodes.add(j)
        self._dir_edges.add((i, j))

        self._neighbors[i].add(j)
        self._neighbors[j].add(i)

        self._children[i].add(j)
        self._parents[j].add(i)

    def _add_undir_edge(self, i, j):
        self._nodes.add(i)
        self._nodes.add(j)
        self._undir_edges.add((i, j))

        self._neighbors[i].add(j)
        self._neighbors[j].add(i)

        self._undirected_neighbors[i].add(j)
        self._undirected_neighbors[j].add(i)

    def children(self, node: str | int) -> set:
        """Gives all children of node `node`.

        Args:
            node (str): node in current PDAG.

        Returns:
            set: set of children.
        """
        if node in self._children.keys():
            return self._children[node]
        else:
            return set()

    def parents(self, node: str | int) -> set:
        """Gives all parents of node `node`.

        Args:
            node (str): node in current PDAG.

        Returns:
            set: set of parents.
        """
        if node in self._parents.keys():
            return self._parents[node]
        else:
            return set()

    def neighbors(self, node: str | int) -> set:
        """Gives all neighbors of node `node`.

        Args:
            node (str): node in current PDAG.

        Returns:
            set: set of neighbors.
        """
        if node in self._neighbors.keys():
            return self._neighbors[node]
        else:
            return set()

    def undir_neighbors(self, node: str | int) -> set:
        """Gives all undirected neighbors of node `node`.

        Args:
            node (str): node in current PDAG.

        Returns:
            set: set of undirected neighbors.
        """
        if node in self._undirected_neighbors.keys():
            return self._undirected_neighbors[node]
        else:
            return set()

    def is_adjacent(self, i: str, j: str) -> bool:
        """Return True if the graph contains an directed or undirected edge between i and j.

        Args:
            i (str): node i.
            j (str): node j.

        Returns:
            bool: True if i-j or i->j or i<-j
        """
        return any(
            (
                (j, i) in self.dir_edges or (j, i) in self.undir_edges,
                (i, j) in self.dir_edges or (i, j) in self.undir_edges,
            )
        )

    def is_clique(self, potential_clique: set) -> bool:
        """Check every pair of node X potential_clique is adjacent."""
        return all(self.is_adjacent(i, j) for i, j in combinations(potential_clique, 2))

    @classmethod
    def from_pandas_adjacency(cls, pd_amat: pd.DataFrame) -> PDAG:
        """Build PDAG from a Pandas adjacency matrix.

        Args:
            pd_amat (pd.DataFrame): input adjacency matrix.

        Returns:
            PDAG
        """
        assert pd_amat.shape[0] == pd_amat.shape[1]
        nodes = list(pd_amat.columns)

        all_connections = []
        start, end = np.where(pd_amat != 0)
        for idx, _ in enumerate(start):
            all_connections.append((pd_amat.columns[start[idx]], pd_amat.columns[end[idx]]))

        temp = [set(i) for i in all_connections]
        temp2 = [arc for arc in all_connections if temp.count(set(arc)) > 1]
        undir_edges = [tuple(item) for item in set(frozenset(item) for item in temp2)]

        dir_edges = [edge for edge in all_connections if edge not in temp2]

        return PDAG(nodes=nodes, dir_edges=dir_edges, undir_edges=undir_edges)

    def remove_edge(self, i: str, j: str):
        """Removes edge in question.

        Args:
            i (str): tail
            j (str): head

        Raises:
            AssertionError: if edge does not exist
        """
        if (i, j) not in self.dir_edges and (i, j) not in self.undir_edges:
            raise AssertionError("Edge does not exist in current PDAG")

        self._undir_edges.discard((i, j))
        self._dir_edges.discard((i, j))
        self._children[i].discard(j)
        self._parents[j].discard(i)
        self._neighbors[i].discard(j)
        self._neighbors[j].discard(i)
        self._undirected_neighbors[i].discard(j)
        self._undirected_neighbors[j].discard(i)

    def undir_to_dir_edge(self, tail: str, head: str):
        """Takes a undirected edge and turns it into a directed one.

        tail indicates the starting node of the edge and head the end node, i.e.
        tail -> head.

        Args:
            tail (str): starting node
            head (str): end node

        Raises:
            AssertionError: if edge does not exist or is not undirected.
        """
        if (tail, head) not in self.undir_edges and (
            head,
            tail,
        ) not in self.undir_edges:
            raise AssertionError("Edge seems not to be undirected or even there at all.")
        self._undir_edges.discard((tail, head))
        self._undir_edges.discard((head, tail))
        self._neighbors[tail].discard(head)
        self._neighbors[head].discard(tail)
        self._undirected_neighbors[tail].discard(head)
        self._undirected_neighbors[head].discard(tail)

        self._add_dir_edge(i=tail, j=head)

    def remove_node(self, node):
        """Remove a node from the graph."""
        self._nodes.remove(node)

        self._dir_edges = {(i, j) for i, j in self._dir_edges if node not in (i, j)}

        self._undir_edges = {(i, j) for i, j in self._undir_edges if node not in (i, j)}

        for child in self._children[node]:
            self._parents[child].remove(node)
            self._neighbors[child].remove(node)

        for parent in self._parents[node]:
            self._children[parent].remove(node)
            self._neighbors[parent].remove(node)

        for u_nbr in self._undirected_neighbors[node]:
            self._undirected_neighbors[u_nbr].remove(node)
            self._neighbors[u_nbr].remove(node)

        self._parents.pop(node, "I was never here")
        self._children.pop(node, "I was never here")
        self._neighbors.pop(node, "I was never here")
        self._undirected_neighbors.pop(node, "I was never here")

    def to_dag(self) -> nx.DiGraph:
        r"""Algorithm as described in Chickering (2002).

            1. From PDAG P create DAG G containing all directed edges from P
            2. Repeat the following: Select node v in P s.t.
                i. v has no outgoing edges (children) i.e. \\(ch(v) = \\emptyset \\)

                ii. \\(neigh(v) \\neq \\emptyset\\)
                    Then \\( pa(v) \\cup neigh(v) \\) form a clique.
                    For each v that is in a clique and is part of an undirected edge in P
                    i.e. w - v, insert a directed edge w -> v in G.
                    Remove v and all incident edges from P and continue with next node.
                    Until all nodes have been deleted from P.

        Returns:
            nx.DiGraph: DAG that belongs to the MEC implied by the PDAG
        """
        pdag = self.copy()

        dag = nx.DiGraph()
        dag.add_nodes_from(pdag.nodes)
        dag.add_edges_from(pdag.dir_edges)

        if pdag.num_undir_edges == 0:
            return dag
        else:
            while pdag.nnodes > 0:
                # find node with (1) no directed outgoing edges and
                #                (2) the set of undirected neighbors is either empty or
                #                    undirected neighbors + parents of X are a clique
                found = False
                for node in pdag.nodes:
                    children = pdag.children(node)
                    neighbors = pdag.neighbors(node)
                    # pdag._undirected_neighbors[node]
                    parents = pdag.parents(node)
                    potential_clique_members = neighbors.union(parents)

                    is_clique = pdag.is_clique(potential_clique_members)

                    if not children and (not neighbors or is_clique):
                        found = True
                        # add all edges of node as outgoing edges to dag
                        for edge in pdag.undir_edges:
                            if node in edge:
                                incident_node = set(edge) - {node}
                                dag.add_edge(*incident_node, node)  # type: ignore

                        pdag.remove_node(node)
                        break

                if not found:
                    logger.warning("PDAG not extendible: Random DAG on skeleton drawn.")

                    dag = nx.from_pandas_adjacency(self._amat_to_dag(), create_using=nx.DiGraph)

                    break

            return dag

    @property
    def adjacency_matrix(self) -> pd.DataFrame:
        """Returns adjacency matrix.

        The i,jth
        entry being one indicates that there is an edge
        from i to j. A zero indicates that there is no edge.

        Returns:
            pd.DataFrame: adjacency matrix
        """
        amat = pd.DataFrame(
            np.zeros([self.nnodes, self.nnodes]),
            index=self.nodes,
            columns=self.nodes,
        )
        for edge in self.dir_edges:
            amat.loc[edge] = 1
        for edge in self.undir_edges:
            amat.loc[edge] = amat.loc[edge[::-1]] = 1
        return amat

    def _amat_to_dag(self) -> pd.DataFrame:
        """Adjacency matrix to random DAG.

        Transform the adjacency matrix of a PDAG into the adjacency
        matrix of SOME DAG in the Markov equivalence class.

        Returns:
            pd.DataFrame: DAG, a member of the MEC.
        """
        pdag_amat = self.adjacency_matrix.to_numpy()

        p = pdag_amat.shape[0]
        ## amat to skel
        skel = pdag_amat + pdag_amat.T
        skel[np.where(skel > 1)] = 1
        ## permute skel
        permute_ord = np.random.choice(a=p, size=p, replace=False)
        skel = skel[:, permute_ord][permute_ord]

        ## skel to dag
        for i in range(1, p):
            for j in range(0, i + 1):
                if skel[i, j] == 1:
                    skel[i, j] = 0

        ## inverse permutation
        i_ord = np.sort(permute_ord)
        skel = skel[:, i_ord][i_ord]
        return pd.DataFrame(
            skel,
            index=self.adjacency_matrix.index,
            columns=self.adjacency_matrix.columns,
        )

    def vstructs(self) -> set:
        """Retrieve v-structures.

        Returns:
            set: set of all v-structures
        """
        vstructures = set()
        for node in self._nodes:
            for p1, p2 in combinations(self._parents[node], 2):
                if p1 not in self._parents[p2] and p2 not in self._parents[p1]:
                    vstructures.add((p1, node))
                    vstructures.add((p2, node))
        return vstructures

    def copy(self):
        """Return a copy of the graph."""
        return PDAG(nodes=self._nodes, dir_edges=self._dir_edges, undir_edges=self._undir_edges)  # type: ignore

    def show(self):
        """Plot PDAG."""
        graph = self.to_networkx()
        pos = nx.circular_layout(graph)
        nx.draw(graph, pos=pos, with_labels=True)

    def to_networkx(self) -> nx.MultiDiGraph:
        """Convert to networkx graph.

        Returns:
            nx.MultiDiGraph: Graph with directed and undirected edges.
        """
        nx_pdag = nx.MultiDiGraph()
        nx_pdag.add_nodes_from(self.nodes)
        nx_pdag.add_edges_from(self.dir_edges)
        for edge in self.undir_edges:
            nx_pdag.add_edge(*edge)
            nx_pdag.add_edge(*edge[::-1])

        return nx_pdag

    def _meek_mec_enumeration(self, pdag: PDAG, dag_list: list):
        """Recursion algorithm which recursively applies the following steps.

            1. Orient the first undirected edge found.
            2. Apply Meek rules.
            3. Recurse with each direction of the oriented edge.
        This corresponds to Algorithm 2 in Wienöbst et al. (2023).

        Args:
            pdag (PDAG): partially directed graph in question.
            dag_list (list): list of currently found DAGs.

        References:
            Wienöbst, Marcel, et al. "Efficient enumeration of Markov equivalent DAGs."
            Proceedings of the AAAI Conference on Artificial Intelligence.
            Vol. 37. No. 10. 2023.
        """
        g_copy = pdag.copy()
        g_copy = self._apply_meek_rules(g_copy)  # Apply Meek rules

        undir_edges = g_copy.undir_edges
        if undir_edges:
            i, j = undir_edges[0]  # Take first undirected edge

        if not g_copy.undir_edges:
            # makes sure that floating nodes are preserved
            new_member = nx.DiGraph()
            new_member.add_nodes_from(g_copy.nodes)
            new_member.add_edges_from(g_copy.dir_edges)
            dag_list.append(new_member)
            return  # Add DAG to current list

        # Recursion first orientation:
        g_copy.undir_to_dir_edge(i, j)
        self._meek_mec_enumeration(pdag=g_copy, dag_list=dag_list)
        g_copy.remove_edge(i, j)

        # Recursion second orientation
        g_copy._add_dir_edge(j, i)
        self._meek_mec_enumeration(pdag=g_copy, dag_list=dag_list)

    def to_allDAGs(self) -> list[nx.DiGraph]:
        """Recursion algorithm which recursively applies the following steps.

            1. Orient the first undirected edge found.
            2. Apply Meek rules.
            3. Recurse with each direction of the oriented edge.
        This corresponds to Algorithm 2 in Wienöbst et al. (2023).

        References:
            Wienöbst, Marcel, et al. "Efficient enumeration of Markov equivalent DAGs."
            Proceedings of the AAAI Conference on Artificial Intelligence.
            Vol. 37. No. 10. 2023.
        """
        all_dags = []
        self._meek_mec_enumeration(pdag=self, dag_list=all_dags)
        return all_dags

    # use Meek's cpdag2alldag
    def _apply_meek_rules(self, G: PDAG) -> PDAG:
        """Apply all four Meek rules to a PDAG turning it into a CPDAG.

        Args:
            G (PDAG): PDAG to complete

        Returns:
            PDAG: completed PDAG.
        """
        # Apply Meek Rules
        cpdag = G.copy()
        cpdag = rule_1(pdag=cpdag)
        cpdag = rule_2(pdag=cpdag)
        cpdag = rule_3(pdag=cpdag)
        cpdag = rule_4(pdag=cpdag)
        return cpdag

    def to_random_dag(self) -> nx.DiGraph:
        """Provides a random DAG residing in the MEC.

        Returns:
            nx.DiGraph: random DAG living in MEC
        """
        to_dag_candidate = self.copy()

        while to_dag_candidate.num_undir_edges > 0:
            chosen_edge = to_dag_candidate.undir_edges[
                np.random.choice(to_dag_candidate.num_undir_edges)
            ]
            choose_orientation = [chosen_edge, chosen_edge[::-1]]
            node_i, node_j = choose_orientation[np.random.choice(len(choose_orientation))]

            to_dag_candidate.undir_to_dir_edge(tail=node_i, head=node_j)
            to_dag_candidate = to_dag_candidate._apply_meek_rules(G=to_dag_candidate)

        return nx.from_pandas_adjacency(to_dag_candidate.adjacency_matrix, create_using=nx.DiGraph)

    @property
    def nodes(self) -> list[str] | list[int]:
        Get all nodes in the current PDAG.

        Returns:
            list: list of nodes.
        """
        return sorted(list(self._nodes))  # type: ignore

    @property
    def nnodes(self) -> int:
        """Number of nodes in current PDAG.

        Returns:
            int: Number of nodes
        """
        return len(self._nodes)

    @property
    def num_undir_edges(self) -> int:
        """Number of undirected edges in current PDAG.

        Returns:
            int: Number of undirected edges
        """
        return len(self._undir_edges)

    @property
    def num_dir_edges(self) -> int:
        """Number of directed edges in current PDAG.

        Returns:
            int: Number of directed edges
        """
        return len(self._dir_edges)

    @property
    def num_adjacencies(self) -> int:
        """Number of adjacent nodes in current PDAG.

        Returns:
            int: Number of adjacent nodes
        """
        return self.num_undir_edges + self.num_dir_edges

    @property
    def undir_edges(self) -> list[tuple]:
        """Gives all undirected edges in current PDAG.

        Returns:
            list[tuple]: List of undirected edges.
        """
        return list(self._undir_edges)

    @property
    def dir_edges(self) -> list[tuple]:
        """Gives all directed edges in current PDAG.

        Returns:
            list[tuple]: List of directed edges.
        """
        return list(self._dir_edges)

adjacency_matrix property

Returns adjacency matrix.

The i,jth entry being one indicates that there is an edge from i to j. A zero indicates that there is no edge.

Returns:

Type Description
DataFrame

pd.DataFrame: adjacency matrix

dir_edges property

Gives all directed edges in current PDAG.

Returns:

Type Description
list[tuple]

list[tuple]: List of directed edges.

nnodes property

Number of nodes in current PDAG.

Returns:

Name Type Description
int int

Number of nodes

nodes property

Get all nodes in the current PDAG.

Returns:

Name Type Description
list list[str] | list[int]

list of nodes.

num_adjacencies property

Number of adjacent nodes in current PDAG.

Returns:

Name Type Description
int int

Number of adjacent nodes

num_dir_edges property

Number of directed edges in current PDAG.

Returns:

Name Type Description
int int

Number of directed edges

num_undir_edges property

Number of undirected edges in current PDAG.

Returns:

Name Type Description
int int

Number of undirected edges

undir_edges property

Gives all undirected edges in current PDAG.

Returns:

Type Description
list[tuple]

list[tuple]: List of undirected edges.

__init__(nodes=None, dir_edges=None, undir_edges=None)

Inits the PDAG class.

Parameters:

Name Type Description Default
nodes list | None

Nodes of the graph. Defaults to None.

None
dir_edges list[tuple] | None

Directed edges given as (tail, head) tuples. Defaults to None.

None
undir_edges list[tuple] | None

Undirected edges given as tuples. Defaults to None.

None
Source code in causalAssembly/pdag.py
def __init__(
    self,
    nodes: list[str] | list[int] | set[str] | set[int] | None = None,
    dir_edges: list[tuple[str, str]]
    | list[tuple[int, int]]
    | set[tuple[str, str]]
    | set[tuple[int, int]]
    | None = None,
    undir_edges: list[tuple[str, str]]
    | list[tuple[int, int]]
    | set[tuple[str, str]]
    | set[tuple[int, int]]
    | None = None,
):
    """Inits the PDAG class.

    Args:
        nodes (list | None, optional): Nodes of the graph. Defaults to None.
        dir_edges (list[tuple] | None, optional): Directed edges given as
            (tail, head) tuples. Defaults to None.
        undir_edges (list[tuple] | None, optional): Undirected edges given as
            tuples. Defaults to None.
    """
    if nodes is None:
        nodes = []
    if dir_edges is None:
        dir_edges = []
    if undir_edges is None:
        undir_edges = []

    self._nodes = set(nodes)
    self._undir_edges = set()
    self._dir_edges = set()
    self._parents = defaultdict(set)
    self._children = defaultdict(set)
    self._neighbors = defaultdict(set)
    self._undirected_neighbors = defaultdict(set)

    for dir_edge in dir_edges:
        self._add_dir_edge(*dir_edge)
    for unir_edge in undir_edges:
        self._add_undir_edge(*unir_edge)

children(node)

Gives all children of node node.

Parameters:

Name Type Description Default
node str

node in current PDAG.

required

Returns:

Name Type Description
set set

set of children.

Source code in causalAssembly/pdag.py
def children(self, node: str | int) -> set:
    """Gives all children of node `node`.

    Args:
        node (str): node in current PDAG.

    Returns:
        set: set of children.
    """
    if node in self._children.keys():
        return self._children[node]
    else:
        return set()

copy()

Return a copy of the graph.

Source code in causalAssembly/pdag.py
def copy(self):
    """Return a copy of the graph."""
    return PDAG(nodes=self._nodes, dir_edges=self._dir_edges, undir_edges=self._undir_edges)  # type: ignore

from_pandas_adjacency(pd_amat) classmethod

Build PDAG from a Pandas adjacency matrix.

Parameters:

Name Type Description Default
pd_amat DataFrame

input adjacency matrix.

required

Returns:

Type Description
PDAG

PDAG

Source code in causalAssembly/pdag.py
@classmethod
def from_pandas_adjacency(cls, pd_amat: pd.DataFrame) -> PDAG:
    """Build PDAG from a Pandas adjacency matrix.

    Args:
        pd_amat (pd.DataFrame): input adjacency matrix.

    Returns:
        PDAG
    """
    assert pd_amat.shape[0] == pd_amat.shape[1]
    nodes = list(pd_amat.columns)

    all_connections = []
    start, end = np.where(pd_amat != 0)
    for idx, _ in enumerate(start):
        all_connections.append((pd_amat.columns[start[idx]], pd_amat.columns[end[idx]]))

    temp = [set(i) for i in all_connections]
    temp2 = [arc for arc in all_connections if temp.count(set(arc)) > 1]
    undir_edges = [tuple(item) for item in set(frozenset(item) for item in temp2)]

    dir_edges = [edge for edge in all_connections if edge not in temp2]

    return PDAG(nodes=nodes, dir_edges=dir_edges, undir_edges=undir_edges)
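
For illustration, a small adjacency matrix with one directed and one undirected edge (undirected edges are encoded symmetrically); the import path is inferred from the source path shown above.

import pandas as pd

from causalAssembly.pdag import PDAG

amat = pd.DataFrame(
    [[0, 1, 0],
     [0, 0, 1],
     [0, 1, 0]],
    index=["a", "b", "c"],
    columns=["a", "b", "c"],
)
pdag = PDAG.from_pandas_adjacency(amat)
print(pdag.dir_edges)    # [('a', 'b')]
print(pdag.undir_edges)  # [('b', 'c')] (tuple orientation is arbitrary)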

is_adjacent(i, j)

Return True if the graph contains a directed or undirected edge between i and j.

Parameters:

Name Type Description Default
i str

node i.

required
j str

node j.

required

Returns:

Name Type Description
bool bool

True if i-j or i->j or i<-j

Source code in causalAssembly/pdag.py
def is_adjacent(self, i: str, j: str) -> bool:
    """Return True if the graph contains an directed or undirected edge between i and j.

    Args:
        i (str): node i.
        j (str): node j.

    Returns:
        bool: True if i-j or i->j or i<-j
    """
    return any(
        (
            (j, i) in self.dir_edges or (j, i) in self.undir_edges,
            (i, j) in self.dir_edges or (i, j) in self.undir_edges,
        )
    )

is_clique(potential_clique)

Check that every pair of nodes in potential_clique is adjacent.

Source code in causalAssembly/pdag.py
def is_clique(self, potential_clique: set) -> bool:
    """Check every pair of node X potential_clique is adjacent."""
    return all(self.is_adjacent(i, j) for i, j in combinations(potential_clique, 2))

neighbors(node)

Gives all neighbors of node node.

Parameters:

Name Type Description Default
node str

node in current PDAG.

required

Returns:

Name Type Description
set set

set of neighbors.

Source code in causalAssembly/pdag.py
def neighbors(self, node: str | int) -> set:
    """Gives all neighbors of node `node`.

    Args:
        node (str): node in current PDAG.

    Returns:
        set: set of neighbors.
    """
    if node in self._neighbors.keys():
        return self._neighbors[node]
    else:
        return set()

parents(node)

Gives all parents of node node.

Parameters:

Name Type Description Default
node str

node in current PDAG.

required

Returns:

Name Type Description
set set

set of parents.

Source code in causalAssembly/pdag.py
def parents(self, node: str | int) -> set:
    """Gives all parents of node `node`.

    Args:
        node (str): node in current PDAG.

    Returns:
        set: set of parents.
    """
    if node in self._parents.keys():
        return self._parents[node]
    else:
        return set()

remove_edge(i, j)

Removes edge in question.

Parameters:

Name Type Description Default
i str

tail

required
j str

head

required

Raises:

Type Description
AssertionError

if edge does not exist

Source code in causalAssembly/pdag.py
def remove_edge(self, i: str, j: str):
    """Removes edge in question.

    Args:
        i (str): tail
        j (str): head

    Raises:
        AssertionError: if edge does not exist
    """
    if (i, j) not in self.dir_edges and (i, j) not in self.undir_edges:
        raise AssertionError("Edge does not exist in current PDAG")

    self._undir_edges.discard((i, j))
    self._dir_edges.discard((i, j))
    self._children[i].discard(j)
    self._parents[j].discard(i)
    self._neighbors[i].discard(j)
    self._neighbors[j].discard(i)
    self._undirected_neighbors[i].discard(j)
    self._undirected_neighbors[j].discard(i)

remove_node(node)

Remove a node from the graph.

Source code in causalAssembly/pdag.py
def remove_node(self, node):
    """Remove a node from the graph."""
    self._nodes.remove(node)

    self._dir_edges = {(i, j) for i, j in self._dir_edges if node not in (i, j)}

    self._undir_edges = {(i, j) for i, j in self._undir_edges if node not in (i, j)}

    for child in self._children[node]:
        self._parents[child].remove(node)
        self._neighbors[child].remove(node)

    for parent in self._parents[node]:
        self._children[parent].remove(node)
        self._neighbors[parent].remove(node)

    for u_nbr in self._undirected_neighbors[node]:
        self._undirected_neighbors[u_nbr].remove(node)
        self._neighbors[u_nbr].remove(node)

    self._parents.pop(node, "I was never here")
    self._children.pop(node, "I was never here")
    self._neighbors.pop(node, "I was never here")
    self._undirected_neighbors.pop(node, "I was never here")

show()

Plot PDAG.

Source code in causalAssembly/pdag.py
def show(self):
    """Plot PDAG."""
    graph = self.to_networkx()
    pos = nx.circular_layout(graph)
    nx.draw(graph, pos=pos, with_labels=True)

to_allDAGs()

Recursive algorithm which applies the following steps.

1. Orient the first undirected edge found.
2. Apply Meek rules.
3. Recurse with each direction of the oriented edge.

This corresponds to Algorithm 2 in Wienöbst et al. (2023).

References

Wienöbst, Marcel, et al. "Efficient enumeration of Markov equivalent DAGs." Proceedings of the AAAI Conference on Artificial Intelligence. Vol. 37. No. 10. 2023.

Source code in causalAssembly/pdag.py
def to_allDAGs(self) -> list[nx.DiGraph]:
    """Recursion algorithm which recursively applies the following steps.

        1. Orient the first undirected edge found.
        2. Apply Meek rules.
        3. Recurse with each direction of the oriented edge.
    This corresponds to Algorithm 2 in Wienöbst et al. (2023).

    References:
        Wienöbst, Marcel, et al. "Efficient enumeration of Markov equivalent DAGs."
        Proceedings of the AAAI Conference on Artificial Intelligence.
        Vol. 37. No. 10. 2023.
    """
    all_dags = []
    self._meek_mec_enumeration(pdag=self, dag_list=all_dags)
    return all_dags
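
A minimal enumeration sketch on an undirected chain, using the same inferred PDAG import as in the earlier example:

# The skeleton x - y - z without a collider has three Markov equivalent DAGs.
pdag = PDAG(nodes=["x", "y", "z"], undir_edges=[("x", "y"), ("y", "z")])
members = pdag.to_allDAGs()
print(len(members))  # 3
print([sorted(g.edges()) for g in members])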

to_dag()

Algorithm as described in Chickering (2002).

1. From PDAG P create DAG G containing all directed edges from P
2. Repeat the following: Select node v in P s.t.
    i. v has no outgoing edges (children) i.e. \\(ch(v) = \\emptyset \\)

    ii. \\(neigh(v) \\neq \\emptyset\\)
        Then \\( pa(v) \\cup neigh(v) \\) form a clique.
        For each v that is in a clique and is part of an undirected edge in P
        i.e. w - v, insert a directed edge w -> v in G.
        Remove v and all incident edges from P and continue with next node.
        Until all nodes have been deleted from P.

Returns:

Type Description
DiGraph

nx.DiGraph: DAG that belongs to the MEC implied by the PDAG

Source code in causalAssembly/pdag.py
def to_dag(self) -> nx.DiGraph:
    r"""Algorithm as described in Chickering (2002).

        1. From PDAG P create DAG G containing all directed edges from P
        2. Repeat the following: Select node v in P s.t.
            i. v has no outgoing edges (children) i.e. \\(ch(v) = \\emptyset \\)

            ii. \\(neigh(v) \\neq \\emptyset\\)
                Then \\( pa(v) \\cup neigh(v) \\) form a clique.
                For each v that is in a clique and is part of an undirected edge in P
                i.e. w - v, insert a directed edge w -> v in G.
                Remove v and all incident edges from P and continue with next node.
                Until all nodes have been deleted from P.

    Returns:
        nx.DiGraph: DAG that belongs to the MEC implied by the PDAG
    """
    pdag = self.copy()

    dag = nx.DiGraph()
    dag.add_nodes_from(pdag.nodes)
    dag.add_edges_from(pdag.dir_edges)

    if pdag.num_undir_edges == 0:
        return dag
    else:
        while pdag.nnodes > 0:
            # find node with (1) no directed outgoing edges and
            #                (2) the set of undirected neighbors is either empty or
            #                    undirected neighbors + parents of X are a clique
            found = False
            for node in pdag.nodes:
                children = pdag.children(node)
                neighbors = pdag.neighbors(node)
                # pdag._undirected_neighbors[node]
                parents = pdag.parents(node)
                potential_clique_members = neighbors.union(parents)

                is_clique = pdag.is_clique(potential_clique_members)

                if not children and (not neighbors or is_clique):
                    found = True
                    # add all edges of node as outgoing edges to dag
                    for edge in pdag.undir_edges:
                        if node in edge:
                            incident_node = set(edge) - {node}
                            dag.add_edge(*incident_node, node)  # type: ignore

                    pdag.remove_node(node)
                    break

            if not found:
                logger.warning("PDAG not extendible: Random DAG on skeleton drawn.")

                dag = nx.from_pandas_adjacency(self._amat_to_dag(), create_using=nx.DiGraph)

                break

        return dag
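
A minimal extension sketch, using the same inferred PDAG import as above:

pdag = PDAG(nodes=["a", "b", "c"], dir_edges=[("a", "b")], undir_edges=[("b", "c")])
dag = pdag.to_dag()
# The undirected edge b - c is oriented without creating a new v-structure.
print(sorted(dag.edges()))  # [('a', 'b'), ('b', 'c')]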

to_networkx()

Convert to networkx graph.

Returns:

Type Description
MultiDiGraph

nx.MultiDiGraph: Graph with directed and undirected edges.

Source code in causalAssembly/pdag.py
def to_networkx(self) -> nx.MultiDiGraph:
    """Convert to networkx graph.

    Returns:
        nx.MultiDiGraph: Graph with directed and undirected edges.
    """
    nx_pdag = nx.MultiDiGraph()
    nx_pdag.add_nodes_from(self.nodes)
    nx_pdag.add_edges_from(self.dir_edges)
    for edge in self.undir_edges:
        nx_pdag.add_edge(*edge)
        nx_pdag.add_edge(*edge[::-1])

    return nx_pdag

to_random_dag()

Provides a random DAG residing in the MEC.

Returns:

Type Description
DiGraph

nx.DiGraph: random DAG living in MEC

Source code in causalAssembly/pdag.py
def to_random_dag(self) -> nx.DiGraph:
    """Provides a random DAG residing in the MEC.

    Returns:
        nx.DiGraph: random DAG living in MEC
    """
    to_dag_candidate = self.copy()

    while to_dag_candidate.num_undir_edges > 0:
        chosen_edge = to_dag_candidate.undir_edges[
            np.random.choice(to_dag_candidate.num_undir_edges)
        ]
        choose_orientation = [chosen_edge, chosen_edge[::-1]]
        node_i, node_j = choose_orientation[np.random.choice(len(choose_orientation))]

        to_dag_candidate.undir_to_dir_edge(tail=node_i, head=node_j)
        to_dag_candidate = to_dag_candidate._apply_meek_rules(G=to_dag_candidate)

    return nx.from_pandas_adjacency(to_dag_candidate.adjacency_matrix, create_using=nx.DiGraph)
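
A short usage sketch: draw one member of the Markov equivalence class of a known DAG. The import path causalAssembly.dag is an assumption based on the source locations shown on this page.

import networkx as nx

from causalAssembly.dag import DAG  # assumed import path

truth = DAG(nodes=["a", "b", "c"], edges=[("a", "b"), ("b", "c")])
cpdag = truth.to_cpdag()        # chain CPDAG: both edges undirected
member = cpdag.to_random_dag()  # one of the three DAGs in the MEC

assert nx.is_directed_acyclic_graph(member)
assert member.number_of_edges() == 2  # same skeleton as the chain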

undir_neighbors(node)

Gives all undirected neighbors of node node.

Parameters:

Name Type Description Default
node str

node in current PDAG.

required

Returns:

Name Type Description
set set

set of undirected neighbors.

Source code in causalAssembly/pdag.py
def undir_neighbors(self, node: str | int) -> set:
    """Gives all undirected neighbors of node `node`.

    Args:
        node (str): node in current PDAG.

    Returns:
        set: set of undirected neighbors.
    """
    if node in self._undirected_neighbors.keys():
        return self._undirected_neighbors[node]
    else:
        return set()

undir_to_dir_edge(tail, head)

Takes an undirected edge and turns it into a directed one.

tail indicates the starting node of the edge and head the end node, i.e. tail -> head.

Parameters:

Name Type Description Default
tail str

starting node

required
head str

end node

required

Raises:

Type Description
AssertionError

if edge does not exist or is not undirected.

Source code in causalAssembly/pdag.py
def undir_to_dir_edge(self, tail: str, head: str):
    """Takes a undirected edge and turns it into a directed one.

    tail indicates the starting node of the edge and head the end node, i.e.
    tail -> head.

    Args:
        tail (str): starting node
        head (str): end node

    Raises:
        AssertionError: if edge does not exist or is not undirected.
    """
    if (tail, head) not in self.undir_edges and (
        head,
        tail,
    ) not in self.undir_edges:
        raise AssertionError("Edge seems not to be undirected or even there at all.")
    self._undir_edges.discard((tail, head))
    self._undir_edges.discard((head, tail))
    self._neighbors[tail].discard(head)
    self._neighbors[head].discard(tail)
    self._undirected_neighbors[tail].discard(head)
    self._undirected_neighbors[head].discard(tail)

    self._add_dir_edge(i=tail, j=head)
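
A brief sketch (import path assumed from the source location shown above): once an edge has been oriented it is no longer undirected, so a second call raises the documented AssertionError.

import networkx as nx

from causalAssembly.pdag import dag2cpdag  # assumed import path

cpdag = dag2cpdag(nx.DiGraph([("a", "b"), ("b", "c")]))  # a - b - c, both undirected
cpdag.undir_to_dir_edge(tail="a", head="b")  # orient a -> b

try:
    cpdag.undir_to_dir_edge(tail="a", head="b")  # a - b is no longer undirected
except AssertionError as err:
    print(err)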

vstructs()

Retrieve v-structures.

Returns:

Name Type Description
set set

set of all v-structures

Source code in causalAssembly/pdag.py
def vstructs(self) -> set:
    """Retrieve v-structures.

    Returns:
        set: set of all v-structures
    """
    vstructures = set()
    for node in self._nodes:
        for p1, p2 in combinations(self._parents[node], 2):
            if p1 not in self._parents[p2] and p2 not in self._parents[p1]:
                vstructures.add((p1, node))
                vstructures.add((p2, node))
    return vstructures

dag2cpdag(dag)

Converts a DAG into its unique CPDAG.

Parameters:

Name Type Description Default
dag DiGraph

DAG the CPDAG corresponds to.

required

Returns:

Name Type Description
PDAG PDAG

unique CPDAG

Source code in causalAssembly/pdag.py
def dag2cpdag(dag: nx.DiGraph) -> PDAG:
    """Convertes a DAG into its unique CPDAG.

    Args:
        dag (nx.DiGraph): DAG the CPDAG corresponds to.

    Returns:
        PDAG: unique CPDAG
    """
    copy_dag: nx.DiGraph = dag.copy()  # type: ignore
    # Skeleton
    skeleton = nx.to_pandas_adjacency(copy_dag.to_undirected())
    # v-Structures
    vstructures = vstructs(dag=copy_dag)

    for edge in vstructures:  # orient v-structures
        skeleton.loc[edge[::-1]] = 0

    pdag_init = PDAG.from_pandas_adjacency(skeleton)

    # Apply Meek Rules
    cpdag = rule_1(pdag=pdag_init)
    cpdag = rule_2(pdag=cpdag)
    cpdag = rule_3(pdag=cpdag)
    cpdag = rule_4(pdag=cpdag)

    return cpdag
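
A worked toy example (import path assumed from the source location shown above): the collider x -> z <- y is the only orientation information the CPDAG must keep, and Meek's first rule then forces z -> w as well, so here the CPDAG coincides with the DAG.

import networkx as nx

from causalAssembly.pdag import dag2cpdag  # assumed import path

dag = nx.DiGraph([("x", "z"), ("y", "z"), ("z", "w")])
cpdag = dag2cpdag(dag)

print(sorted(cpdag.dir_edges))  # expected: [('x', 'z'), ('y', 'z'), ('z', 'w')]
print(cpdag.undir_edges)        # expected: [] -- z -> w is forced by rule 1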

rule_1(pdag)

Meeks first rule.

Given the pattern X -> Y - Z, orient Y - Z as Y -> Z if X and Z are non-adjacent (otherwise a new v-structure would arise).

Parameters:

Name Type Description Default
pdag PDAG

PDAG before application of rule.

required

Returns:

Name Type Description
PDAG PDAG

PDAG after application of rule.

Source code in causalAssembly/pdag.py
def rule_1(pdag: PDAG) -> PDAG:
    """Meeks first rule.

    Given the following pattern X -> Y - Z. Orient Y - Z to Y -> Z
    if X and Z are non-adjacent (otherwise a new v-structure arises).

    Args:
        pdag (PDAG): PDAG before application of rule.

    Returns:
        PDAG: PDAG after application of rule.
    """
    copy_pdag = pdag.copy()
    for edge in copy_pdag.undir_edges:
        reverse_edge = edge[::-1]
        test_edges = [edge, reverse_edge]
        for tail, head in test_edges:
            orient = False
            undir_parents = copy_pdag.parents(tail)
            if undir_parents:
                for parent in undir_parents:
                    if not copy_pdag.is_adjacent(parent, head):
                        orient = True
            if orient:
                copy_pdag.undir_to_dir_edge(tail=tail, head=head)
                break
    return copy_pdag
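
A small sketch of the rule in isolation (import path assumed): starting from the chain CPDAG a - b - c and orienting a - b by hand, rule 1 has to orient b - c as b -> c, since a and c are non-adjacent and the opposite orientation would create the new collider a -> b <- c.

import networkx as nx

from causalAssembly.pdag import dag2cpdag, rule_1  # assumed import path

cpdag = dag2cpdag(nx.DiGraph([("a", "b"), ("b", "c")]))  # a - b - c, both undirected
cpdag.undir_to_dir_edge(tail="a", head="b")

oriented = rule_1(cpdag)
print(sorted(oriented.dir_edges))  # expected: [('a', 'b'), ('b', 'c')]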

rule_2(pdag)

Meeks 2nd rule.

Given the directed path X -> Y -> Z where X and Z are additionally joined by the undirected edge X - Z, orient X - Z as X -> Z; otherwise a cycle would arise.

Parameters:

Name Type Description Default
pdag PDAG

PDAG before application of rule.

required

Returns:

Name Type Description
PDAG PDAG

PDAG after application of rule.

Source code in causalAssembly/pdag.py
def rule_2(pdag: PDAG) -> PDAG:
    """Meeks 2nd rule.

    Given the following directed triple
    X -> Y -> Z where X - Z are indeed adjacent.
    Orient X - Z to X -> Z otherwise a cycle arises.

    Args:
        pdag (PDAG): PDAG before application of rule.

    Returns:
        PDAG: PDAG after application of rule.
    """
    copy_pdag = pdag.copy()
    for edge in copy_pdag.undir_edges:
        reverse_edge = edge[::-1]
        test_edges = [edge, reverse_edge]
        for tail, head in test_edges:
            orient = False
            undir_children = copy_pdag.children(tail)
            if undir_children:
                for child in undir_children:
                    if head in copy_pdag.children(child):
                        orient = True
            if orient:
                copy_pdag.undir_to_dir_edge(tail=tail, head=head)
                break
    return copy_pdag
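
Analogously, a small sketch for rule 2 (import path assumed): the CPDAG of the triangle a -> b -> c, a -> c is fully undirected; after orienting a -> b and b -> c by hand, rule 2 has to orient a - c as a -> c, since the opposite orientation would close the cycle a -> b -> c -> a.

import networkx as nx

from causalAssembly.pdag import dag2cpdag, rule_2  # assumed import path

cpdag = dag2cpdag(nx.DiGraph([("a", "b"), ("b", "c"), ("a", "c")]))  # undirected triangle
cpdag.undir_to_dir_edge(tail="a", head="b")
cpdag.undir_to_dir_edge(tail="b", head="c")

oriented = rule_2(cpdag)
print(sorted(oriented.dir_edges))  # expected: [('a', 'b'), ('a', 'c'), ('b', 'c')]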

rule_3(pdag)

Meeks third rule.

Orient X - Z to X -> Z, whenever there are two triples X - Y1 -> Z and X - Y2 -> Z such that Y1 and Y2 are non-adjacent.

Parameters:

Name Type Description Default
pdag PDAG

PDAG before application of rule.

required

Returns:

Name Type Description
PDAG PDAG

PDAG after application of rule.

Source code in causalAssembly/pdag.py
def rule_3(pdag: PDAG) -> PDAG:
    """Meeks third rule.

    Orient X - Z to X -> Z, whenever there are two triples
    X - Y1 -> Z and X - Y2 -> Z such that Y1 and Y2 are non-adjacent.

    Args:
        pdag (PDAG): PDAG before application of rule.

    Returns:
        PDAG: PDAG after application of rule.
    """
    TWO = 2
    copy_pdag = pdag.copy()
    for edge in copy_pdag.undir_edges:
        reverse_edge = edge[::-1]
        test_edges = [edge, reverse_edge]
        for tail, head in test_edges:
            # if tail - node1 -> head and tail - node2 -> head while node1 and
            # node2 are non-adjacent, then orient tail -> head
            orient = False
            if len(copy_pdag.undir_neighbors(tail)) >= TWO:
                undir_n = copy_pdag.undir_neighbors(tail)
                selection = [
                    (node1, node2)
                    for node1, node2 in combinations(undir_n, 2)
                    if not copy_pdag.is_adjacent(node1, node2)
                ]
                if selection:
                    for node1, node2 in selection:
                        # both node1 -> head and node2 -> head must hold
                        if {node1, node2} <= copy_pdag.parents(head):
                            orient = True
            if orient:
                copy_pdag.undir_to_dir_edge(tail=tail, head=head)
                break
    return copy_pdag

rule_4(pdag)

Meeks 4th rule.

Orient X - Y1 as X -> Y1 whenever there is a chain X - Y2 -> Z -> Y1 together with the undirected edge X - Z, such that Y1 and Y2 are non-adjacent.

Parameters:

Name Type Description Default
pdag PDAG

PDAG before application of rule.

required

Returns:

Name Type Description
PDAG PDAG

PDAG after application of rule.

Source code in causalAssembly/pdag.py
def rule_4(pdag: PDAG) -> PDAG:
    """Meeks 4th rule.

    Orient X - Y1 to X -> Y1, whenever there are
    two triples with X - Z and X - Y1 <- Z and X - Y2 -> Z
    such that Y1 and Y2 are non-adjacent.

    Args:
        pdag (PDAG): PDAG before application of rule.

    Returns:
        PDAG: PDAG after application of rule.
    """
    copy_pdag = pdag.copy()
    for edge in copy_pdag.undir_edges:
        reverse_edge = edge[::-1]
        test_edges = [edge, reverse_edge]
        for tail, head in test_edges:
            orient = False
            if len(copy_pdag.undir_neighbors(tail)) > 0:
                undirected_n = copy_pdag.undir_neighbors(tail)
                for undir_n in undirected_n:
                    # undir_n plays the role of Y2 in tail - Y2 -> Z -> head
                    for mid_node in copy_pdag.children(undir_n):
                        if (
                            head in copy_pdag.children(mid_node)
                            and copy_pdag.is_adjacent(tail, mid_node)
                            and not copy_pdag.is_adjacent(head, undir_n)
                        ):
                            orient = True
            if orient:
                copy_pdag.undir_to_dir_edge(tail=tail, head=head)
                break
    return copy_pdag

vstructs(dag)

Retrieve all v-structures in a DAG.

Parameters:

Name Type Description Default
dag DiGraph

DAG in question

required

Returns:

Name Type Description
set set

Set of all v-structures.

Source code in causalAssembly/pdag.py
def vstructs(dag: nx.DiGraph) -> set:
    """Retrieve all v-structures in a DAG.

    Args:
        dag (nx.DiGraph): DAG in question

    Returns:
        set: Set of all v-structures.
    """
    vstructures = set()
    for node in dag.nodes():
        for p1, p2 in combinations(list(dag.predecessors(node)), 2):  # get all parents of node
            if not dag.has_edge(p1, p2) and not dag.has_edge(p2, p1):
                vstructures.add((p1, node))
                vstructures.add((p2, node))
    return vstructures
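
A quick check on a toy graph (import path assumed from the source location shown above):

import networkx as nx

from causalAssembly.pdag import vstructs  # assumed import path

collider = nx.DiGraph([("x", "z"), ("y", "z"), ("z", "w")])
print(vstructs(dag=collider))  # expected: {('x', 'z'), ('y', 'z')} (a set, order may vary)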

DAG class.

Copyright (c) 2023 Robert Bosch GmbH

This program is free software: you can redistribute it and/or modify it under the terms of the GNU Affero General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later version. This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Affero General Public License for more details. You should have received a copy of the GNU Affero General Public License along with this program. If not, see https://www.gnu.org/licenses/.

DAG

General class for dealing with directed acyclic graphs, i.e.

graphs that are directed and must not contain any cycles.

Source code in causalAssembly/dag.py
class DAG:
    """General class for dealing with directed acyclic graph i.e.

    graphs that are directed and must not contain any cycles.
    """

    def __init__(
        self,
        nodes: list[str] | list[int] | set[int | str] | None = None,
        edges: list[tuple[str | int, str | int]] | set[tuple[str | int, str | int]] | None = None,
    ):
        """Initialize DAG.

        Args:
            nodes (list | None, optional): _description_. Defaults to None.
            edges (list[tuple] | None, optional): _description_. Defaults to None.
        """
        if nodes is None:
            nodes = []
        if edges is None:
            edges = []

        self._nodes = set(nodes)
        self._edges = set()
        self._parents = defaultdict(set)
        self._children = defaultdict(set)
        self.drf: dict = dict()
        self._random_state: np.random.Generator = np.random.default_rng(seed=2023)

        for edge in edges:
            self._add_edge(*edge)

    def _add_edge(self, i, j):
        self._nodes.add(i)
        self._nodes.add(j)
        self._edges.add((i, j))

        # Check if graph is acyclic
        if not self.is_acyclic():
            raise ValueError(
                "The edge set you provided \
                induces one or more cycles.\
                Check your input!"
            )

        self._children[i].add(j)
        self._parents[j].add(i)

    @property
    def random_state(self):
        """Random state.

        Returns:
            _type_: _description_
        """
        return self._random_state

    @random_state.setter
    def random_state(self, r: np.random.Generator):
        if not isinstance(r, np.random.Generator):
            raise AssertionError("Specify numpy random number generator object!")
        self._random_state = r

    def add_edge(self, edge: tuple[str, str]):
        """Add edge to DAG.

        Args:
            edge (tuple[str, str]): Edge to add
        """
        self._add_edge(*edge)

    def add_edges_from(self, edges: list[tuple[str, str]]):
        """Add multiple edges to DAG.

        Args:
            edges (list[tuple[str, str]]): Edges to add
        """
        for edge in edges:
            self.add_edge(edge=edge)

    def children(self, of_node: str) -> list[str]:
        """Gives all children of node `of_node`.

        Args:
            of_node (str): node in current DAG.

        Returns:
            list: of children.
        """
        if of_node in self._children.keys():
            return list(self._children[of_node])
        else:
            return []

    def parents(self, of_node: str) -> list[str]:
        """Gives all parents of node `of_node`.

        Args:
            of_node (str): node in current DAG.

        Returns:
            list: of parents.
        """
        if of_node in self._parents.keys():
            return list(self._parents[of_node])
        else:
            return []

    def induced_subgraph(self, nodes: list[str]) -> DAG:
        """Returns the induced subgraph on the nodes in `nodes`.

        Args:
            nodes (list[str]): List of nodes.

        Returns:
            DAG: Induced subgraph.
        """
        edges = [(i, j) for i, j in self.edges if i in nodes and j in nodes]
        return DAG(nodes=nodes, edges=edges)

    def is_adjacent(self, i: str, j: str) -> bool:
        """Return True if the graph contains an directed edge between i and j.

        Args:
            i (str): node i.
            j (str): node j.

        Returns:
            bool: True if i->j or i<-j
        """
        return (j, i) in self.edges or (i, j) in self.edges

    def is_clique(self, potential_clique: set) -> bool:
        """Check every pair of node X potential_clique is adjacent."""
        return all(self.is_adjacent(i, j) for i, j in combinations(potential_clique, 2))

    def is_acyclic(self) -> bool:
        """Check if the graph is acyclic.

        Returns:
            bool: True if graph is acyclic.
        """
        nx_dag = self.to_networkx()
        return nx.is_directed_acyclic_graph(nx_dag)

    @classmethod
    def from_pandas_adjacency(cls, pd_amat: pd.DataFrame) -> DAG:
        """Build DAG from a Pandas adjacency matrix.

        Args:
            pd_amat (pd.DataFrame): input adjacency matrix.

        Returns:
            DAG
        """
        assert pd_amat.shape[0] == pd_amat.shape[1]
        nodes = list(pd_amat.columns)

        all_connections = []
        start, end = np.where(pd_amat != 0)
        for idx, _ in enumerate(start):
            all_connections.append((pd_amat.columns[start[idx]], pd_amat.columns[end[idx]]))

        temp = [set(i) for i in all_connections]
        temp2 = [arc for arc in all_connections if temp.count(set(arc)) > 1]

        dir_edges = [edge for edge in all_connections if edge not in temp2]

        return DAG(nodes=nodes, edges=dir_edges)

    def remove_edge(self, i: str, j: str):
        """Removes edge in question.

        Args:
            i (str): tail
            j (str): head

        Raises:
            AssertionError: if edge does not exist
        """
        if (i, j) not in self.edges:
            raise AssertionError("Edge does not exist in current DAG")

        self._edges.discard((i, j))
        self._children[i].discard(j)
        self._parents[j].discard(i)

    def remove_node(self, node):
        """Remove a node from the graph."""
        self._nodes.remove(node)

        self._edges = {(i, j) for i, j in self._edges if node not in (i, j)}

        for child in self._children[node]:
            self._parents[child].remove(node)

        for parent in self._parents[node]:
            self._children[parent].remove(node)

        self._parents.pop(node, "I was never here")
        self._children.pop(node, "I was never here")

    @property
    def adjacency_matrix(self) -> pd.DataFrame:
        """Returns adjacency matrix.

        The i,jth entry being one indicates that there is an edge
        from i to j. A zero indicates that there is no edge.

        Returns:
            pd.DataFrame: adjacency matrix
        """
        amat = pd.DataFrame(
            np.zeros([self.num_nodes, self.num_nodes]),
            index=self.nodes,
            columns=self.nodes,
        )
        for edge in self.edges:
            amat.loc[edge] = 1
        return amat

    def vstructs(self) -> set:
        """Retrieve v-structures.

        Returns:
            set: set of all v-structures
        """
        vstructures = set()
        for node in self._nodes:
            for p1, p2 in combinations(self._parents[node], 2):
                if p1 not in self._parents[p2] and p2 not in self._parents[p1]:
                    vstructures.add((p1, node))
                    vstructures.add((p2, node))
        return vstructures

    def copy(self):
        """Return a copy of the graph."""
        return DAG(nodes=self._nodes, edges=self._edges)

    def show(self):
        """Plot DAG."""
        graph = self.to_networkx()
        pos = nx.circular_layout(graph)
        nx.draw(graph, pos=pos, with_labels=True)

    def to_networkx(self) -> nx.DiGraph:
        """Convert to networkx graph.

        Returns:
            nx.DiGraph: DAG.
        """
        nx_dag = nx.DiGraph()
        nx_dag.add_nodes_from(self.nodes)
        nx_dag.add_edges_from(self.edges)

        return nx_dag

    @property
    def nodes(self) -> list:
        """Get all nods in current DAG.

        Returns:
            list: list of nodes.
        """
        return sorted(list(self._nodes))

    @property
    def num_nodes(self) -> int:
        """Number of nodes in current DAG.

        Returns:
            int: Number of nodes
        """
        return len(self._nodes)

    @property
    def num_edges(self) -> int:
        """Number of directed edges in current DAG.

        Returns:
            int: Number of directed edges
        """
        return len(self._edges)

    @property
    def sparsity(self) -> float:
        """Sparsity of the graph.

        Returns:
            float: in [0,1]
        """
        s = self.num_nodes
        return self.num_edges / s / (s - 1) * 2

    @property
    def edges(self) -> list[tuple]:
        """Gives all directed edges in current DAG.

        Returns:
            list[tuple]: List of directed edges.
        """
        return list(self._edges)

    @property
    def causal_order(self) -> list[str]:
        """Returns the causal order of the current graph.

        Note that this order is in general not unique.

        Returns:
            list[str]: Causal order
        """
        return list(nx.lexicographical_topological_sort(self.to_networkx()))

    @property
    def max_in_degree(self) -> int:
        """Maximum in-degree of the graph.

        Returns:
            int: Maximum in-degree
        """
        return max(len(self._parents[node]) for node in self._nodes)

    @property
    def max_out_degree(self) -> int:
        """Maximum out-degree of the graph.

        Returns:
            int: Maximum out-degree
        """
        return max(len(self._children[node]) for node in self._nodes)

    @classmethod
    def from_nx(cls, nx_dag: nx.DiGraph) -> DAG:
        """Convert to DAG from nx.DiGraph.

        Args:
            nx_dag (nx.DiGraph): DAG in question.

        Raises:
            TypeError: If DAG is not nx.DiGraph

        Returns:
            DAG
        """
        if not isinstance(nx_dag, nx.DiGraph):
            raise TypeError("DAG must be of type nx.DiGraph")
        return DAG(nodes=list(nx_dag.nodes), edges=list(nx_dag.edges))

    def save_drf(self, filename: str, location: str | Path | None = None):
        """Writes a drf dict to file. Please provide the .pkl suffix!

        Args:
            filename (str): name of the file to be written e.g. examplefile.pkl
            location (str, optional): path to file in case it's not located in
                the current working directory. Defaults to None.
        """
        if location is None:
            location = Path().resolve()

        location_path = Path(location, filename)

        with open(location_path, "wb") as f:
            pickle.dump(self.drf, f)

    def sample_from_drf(self, size=10, smoothed: bool = True) -> pd.DataFrame:
        """Draw from the trained DRF.

        Args:
            size (int, optional): Number of samples to be drawn. Defaults to 10.
            smoothed (bool, optional): If set to true, marginal distributions will
                be sampled from smoothed bootstraps. Defaults to True.

        Returns:
            pd.DataFrame: Data frame that follows the distribution implied by the ground truth.
        """
        return _sample_from_drf(graph=self, size=size, smoothed=smoothed)

    def to_cpdag(self) -> PDAG:
        """Conversion to CPDAG.

        Returns:
            PDAG: _description_
        """
        return dag2cpdag(dag=self.to_networkx())

    @classmethod
    def load_drf(cls, filename: str, location: str | Path | None = None) -> dict:
        """Loads a drf dict from a .pkl file into the workspace.

        Args:
            filename (str): name of the file e.g. examplefile.pkl
            location (str, optional): path to file in case it's not located
                in the current working directory. Defaults to None.

        Returns:
            DRF (dict): dict of trained drf objects
        """
        if not location:
            location = Path().resolve()

        location_path = Path(location, filename)

        with open(location_path, "rb") as drf:
            pickle_drf = pickle.load(drf)

        return pickle_drf
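
A short usage sketch of the class; the import path causalAssembly.dag is an assumption based on the source location shown above.

from causalAssembly.dag import DAG  # assumed import path

dag = DAG(
    nodes=["a", "b", "c", "d"],
    edges=[("a", "b"), ("c", "b"), ("b", "d")],
)

print(dag.causal_order)  # ['a', 'c', 'b', 'd']
print(dag.vstructs())    # {('a', 'b'), ('c', 'b')} (a set, order may vary)
print(dag.sparsity)      # 0.5 -- 3 of the 6 possible edges
cpdag = dag.to_cpdag()   # here fully directed: the collider plus Meek's rule 1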

adjacency_matrix property

Returns adjacency matrix.

The i,jth entry being one indicates that there is an edge from i to j. A zero indicates that there is no edge.

Returns:

Type Description
DataFrame

pd.DataFrame: adjacency matrix

causal_order property

Returns the causal order of the current graph.

Note that this order is in general not unique.

Returns:

Type Description
list[str]

list[str]: Causal order

edges property

Gives all directed edges in current DAG.

Returns:

Type Description
list[tuple]

list[tuple]: List of directed edges.

max_in_degree property

Maximum in-degree of the graph.

Returns:

Name Type Description
int int

Maximum in-degree

max_out_degree property

Maximum out-degree of the graph.

Returns:

Name Type Description
int int

Maximum out-degree

nodes property

Get all nodes in current DAG.

Returns:

Name Type Description
list list

list of nodes.

num_edges property

Number of directed edges in current DAG.

Returns:

Name Type Description
int int

Number of directed edges

num_nodes property

Number of nodes in current DAG.

Returns:

Name Type Description
int int

Number of nodes

random_state property writable

Random state.

Returns:

Name Type Description
_type_

description

sparsity property

Sparsity of the graph.

Returns:

Name Type Description
float float

in [0,1]

__init__(nodes=None, edges=None)

Initialize DAG.

Parameters:

Name Type Description Default
nodes list | None

description. Defaults to None.

None
edges list[tuple] | None

description. Defaults to None.

None
Source code in causalAssembly/dag.py
def __init__(
    self,
    nodes: list[str] | list[int] | set[int | str] | None = None,
    edges: list[tuple[str | int, str | int]] | set[tuple[str | int, str | int]] | None = None,
):
    """Initialize DAG.

    Args:
        nodes (list | None, optional): _description_. Defaults to None.
        edges (list[tuple] | None, optional): _description_. Defaults to None.
    """
    if nodes is None:
        nodes = []
    if edges is None:
        edges = []

    self._nodes = set(nodes)
    self._edges = set()
    self._parents = defaultdict(set)
    self._children = defaultdict(set)
    self.drf: dict = dict()
    self._random_state: np.random.Generator = np.random.default_rng(seed=2023)

    for edge in edges:
        self._add_edge(*edge)

add_edge(edge)

Add edge to DAG.

Parameters:

Name Type Description Default
edge tuple[str, str]

Edge to add

required
Source code in causalAssembly/dag.py
def add_edge(self, edge: tuple[str, str]):
    """Add edge to DAG.

    Args:
        edge (tuple[str, str]): Edge to add
    """
    self._add_edge(*edge)

add_edges_from(edges)

Add multiple edges to DAG.

Parameters:

Name Type Description Default
edges list[tuple[str, str]]

Edges to add

required
Source code in causalAssembly/dag.py
def add_edges_from(self, edges: list[tuple[str, str]]):
    """Add multiple edges to DAG.

    Args:
        edges (list[tuple[str, str]]): Edges to add
    """
    for edge in edges:
        self.add_edge(edge=edge)

children(of_node)

Gives all children of node of_node.

Parameters:

Name Type Description Default
of_node str

node in current DAG.

required

Returns:

Name Type Description
list list[str]

of children.

Source code in causalAssembly/dag.py
def children(self, of_node: str) -> list[str]:
    """Gives all children of node `of_node`.

    Args:
        of_node (str): node in current DAG.

    Returns:
        list: of children.
    """
    if of_node in self._children.keys():
        return list(self._children[of_node])
    else:
        return []

copy()

Return a copy of the graph.

Source code in causalAssembly/dag.py
def copy(self):
    """Return a copy of the graph."""
    return DAG(nodes=self._nodes, edges=self._edges)

from_nx(nx_dag) classmethod

Convert to DAG from nx.DiGraph.

Parameters:

Name Type Description Default
nx_dag DiGraph

DAG in question.

required

Raises:

Type Description
TypeError

If DAG is not nx.DiGraph

Returns:

Type Description
DAG

DAG

Source code in causalAssembly/dag.py
@classmethod
def from_nx(cls, nx_dag: nx.DiGraph) -> DAG:
    """Convert to DAG from nx.DiGraph.

    Args:
        nx_dag (nx.DiGraph): DAG in question.

    Raises:
        TypeError: If DAG is not nx.DiGraph

    Returns:
        DAG
    """
    if not isinstance(nx_dag, nx.DiGraph):
        raise TypeError("DAG must be of type nx.DiGraph")
    return DAG(nodes=list(nx_dag.nodes), edges=list(nx_dag.edges))

from_pandas_adjacency(pd_amat) classmethod

Build DAG from a Pandas adjacency matrix.

Parameters:

Name Type Description Default
pd_amat DataFrame

input adjacency matrix.

required

Returns:

Type Description
DAG

DAG

Source code in causalAssembly/dag.py
@classmethod
def from_pandas_adjacency(cls, pd_amat: pd.DataFrame) -> DAG:
    """Build DAG from a Pandas adjacency matrix.

    Args:
        pd_amat (pd.DataFrame): input adjacency matrix.

    Returns:
        DAG
    """
    assert pd_amat.shape[0] == pd_amat.shape[1]
    nodes = list(pd_amat.columns)

    all_connections = []
    start, end = np.where(pd_amat != 0)
    for idx, _ in enumerate(start):
        all_connections.append((pd_amat.columns[start[idx]], pd_amat.columns[end[idx]]))

    temp = [set(i) for i in all_connections]
    temp2 = [arc for arc in all_connections if temp.count(set(arc)) > 1]

    dir_edges = [edge for edge in all_connections if edge not in temp2]

    return DAG(nodes=nodes, edges=dir_edges)
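
A minimal sketch (import path assumed): a one at position (i, j) is read as the directed edge i -> j.

import pandas as pd

from causalAssembly.dag import DAG  # assumed import path

amat = pd.DataFrame(
    [[0, 1, 0],
     [0, 0, 1],
     [0, 0, 0]],
    index=["a", "b", "c"],
    columns=["a", "b", "c"],
)
dag = DAG.from_pandas_adjacency(amat)
print(dag.edges)  # expected: [('a', 'b'), ('b', 'c')] (order may vary)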

induced_subgraph(nodes)

Returns the induced subgraph on the nodes in nodes.

Parameters:

Name Type Description Default
nodes list[str]

List of nodes.

required

Returns:

Name Type Description
DAG DAG

Induced subgraph.

Source code in causalAssembly/dag.py
def induced_subgraph(self, nodes: list[str]) -> DAG:
    """Returns the induced subgraph on the nodes in `nodes`.

    Args:
        nodes (list[str]): List of nodes.

    Returns:
        DAG: Induced subgraph.
    """
    edges = [(i, j) for i, j in self.edges if i in nodes and j in nodes]
    return DAG(nodes=nodes, edges=edges)

is_acyclic()

Check if the graph is acyclic.

Returns:

Name Type Description
bool bool

True if graph is acyclic.

Source code in causalAssembly/dag.py
def is_acyclic(self) -> bool:
    """Check if the graph is acyclic.

    Returns:
        bool: True if graph is acyclic.
    """
    nx_dag = self.to_networkx()
    return nx.is_directed_acyclic_graph(nx_dag)

is_adjacent(i, j)

Return True if the graph contains a directed edge between i and j.

Parameters:

Name Type Description Default
i str

node i.

required
j str

node j.

required

Returns:

Name Type Description
bool bool

True if i->j or i<-j

Source code in causalAssembly/dag.py
def is_adjacent(self, i: str, j: str) -> bool:
    """Return True if the graph contains an directed edge between i and j.

    Args:
        i (str): node i.
        j (str): node j.

    Returns:
        bool: True if i->j or i<-j
    """
    return (j, i) in self.edges or (i, j) in self.edges

is_clique(potential_clique)

Check that every pair of nodes in potential_clique is adjacent.

Source code in causalAssembly/dag.py
def is_clique(self, potential_clique: set) -> bool:
    """Check every pair of node X potential_clique is adjacent."""
    return all(self.is_adjacent(i, j) for i, j in combinations(potential_clique, 2))

load_drf(filename, location=None) classmethod

Loads a drf dict from a .pkl file into the workspace.

Parameters:

Name Type Description Default
filename str

name of the file e.g. examplefile.pkl

required
location str

path to file in case it's not located in the current working directory. Defaults to None.

None

Returns:

Name Type Description
DRF dict

dict of trained drf objects

Source code in causalAssembly/dag.py
@classmethod
def load_drf(cls, filename: str, location: str | Path | None = None) -> dict:
    """Loads a drf dict from a .pkl file into the workspace.

    Args:
        filename (str): name of the file e.g. examplefile.pkl
        location (str, optional): path to file in case it's not located
            in the current working directory. Defaults to None.

    Returns:
        DRF (dict): dict of trained drf objects
    """
    if not location:
        location = Path().resolve()

    location_path = Path(location, filename)

    with open(location_path, "rb") as drf:
        pickle_drf = pickle.load(drf)

    return pickle_drf

parents(of_node)

Gives all parents of node of_node.

Parameters:

Name Type Description Default
of_node str

node in current DAG.

required

Returns:

Name Type Description
list list[str]

of parents.

Source code in causalAssembly/dag.py
def parents(self, of_node: str) -> list[str]:
    """Gives all parents of node `of_node`.

    Args:
        of_node (str): node in current DAG.

    Returns:
        list: of parents.
    """
    if of_node in self._parents.keys():
        return list(self._parents[of_node])
    else:
        return []

remove_edge(i, j)

Removes edge in question.

Parameters:

Name Type Description Default
i str

tail

required
j str

head

required

Raises:

Type Description
AssertionError

if edge does not exist

Source code in causalAssembly/dag.py
def remove_edge(self, i: str, j: str):
    """Removes edge in question.

    Args:
        i (str): tail
        j (str): head

    Raises:
        AssertionError: if edge does not exist
    """
    if (i, j) not in self.edges:
        raise AssertionError("Edge does not exist in current DAG")

    self._edges.discard((i, j))
    self._children[i].discard(j)
    self._parents[j].discard(i)

remove_node(node)

Remove a node from the graph.

Source code in causalAssembly/dag.py
def remove_node(self, node):
    """Remove a node from the graph."""
    self._nodes.remove(node)

    self._edges = {(i, j) for i, j in self._edges if node not in (i, j)}

    for child in self._children[node]:
        self._parents[child].remove(node)

    for parent in self._parents[node]:
        self._children[parent].remove(node)

    self._parents.pop(node, "I was never here")
    self._children.pop(node, "I was never here")

sample_from_drf(size=10, smoothed=True)

Draw from the trained DRF.

Parameters:

Name Type Description Default
size int

Number of samples to be drawn. Defaults to 10.

10
smoothed bool

If set to true, marginal distributions will be sampled from smoothed bootstraps. Defaults to True.

True

Returns:

Type Description
DataFrame

pd.DataFrame: Data frame that follows the distribution implied by the ground truth.

Source code in causalAssembly/dag.py
def sample_from_drf(self, size=10, smoothed: bool = True) -> pd.DataFrame:
    """Draw from the trained DRF.

    Args:
        size (int, optional): Number of samples to be drawn. Defaults to 10.
        smoothed (bool, optional): If set to true, marginal distributions will
            be sampled from smoothed bootstraps. Defaults to True.

    Returns:
        pd.DataFrame: Data frame that follows the distribution implied by the ground truth.
    """
    return _sample_from_drf(graph=self, size=size, smoothed=smoothed)

save_drf(filename, location=None)

Writes a drf dict to file. Please provide the .pkl suffix!

Parameters:

Name Type Description Default
filename str

name of the file to be written e.g. examplefile.pkl

required
location str

path to file in case it's not located in the current working directory. Defaults to None.

None
Source code in causalAssembly/dag.py
def save_drf(self, filename: str, location: str | Path | None = None):
    """Writes a drf dict to file. Please provide the .pkl suffix!

    Args:
        filename (str): name of the file to be written e.g. examplefile.pkl
        location (str, optional): path to file in case it's not located in
            the current working directory. Defaults to None.
    """
    if location is None:
        location = Path().resolve()

    location_path = Path(location, filename)

    with open(location_path, "wb") as f:
        pickle.dump(self.drf, f)

show()

Plot DAG.

Source code in causalAssembly/dag.py
def show(self):
    """Plot DAG."""
    graph = self.to_networkx()
    pos = nx.circular_layout(graph)
    nx.draw(graph, pos=pos, with_labels=True)

to_cpdag()

Conversion to CPDAG.

Returns:

Name Type Description
PDAG PDAG

description

Source code in causalAssembly/dag.py
414
415
416
417
418
419
420
def to_cpdag(self) -> PDAG:
    """Conversion to CPDAG.

    Returns:
        PDAG: _description_
    """
    return dag2cpdag(dag=self.to_networkx())

to_networkx()

Convert to networkx graph.

Returns:

Type Description
DiGraph

nx.DiGraph: DAG.

Source code in causalAssembly/dag.py
def to_networkx(self) -> nx.DiGraph:
    """Convert to networkx graph.

    Returns:
        nx.DiGraph: DAG.
    """
    nx_dag = nx.DiGraph()
    nx_dag.add_nodes_from(self.nodes)
    nx_dag.add_edges_from(self.edges)

    return nx_dag

vstructs()

Retrieve v-structures.

Returns:

Name Type Description
set set

set of all v-structures

Source code in causalAssembly/dag.py
def vstructs(self) -> set:
    """Retrieve v-structures.

    Returns:
        set: set of all v-structures
    """
    vstructures = set()
    for node in self._nodes:
        for p1, p2 in combinations(self._parents[node], 2):
            if p1 not in self._parents[p2] and p2 not in self._parents[p1]:
                vstructures.add((p1, node))
                vstructures.add((p2, node))
    return vstructures