torchphysics.utils.plotting package

Different plotting functions:

  • plot_functions implements functions to show the output of the neural network or values derived from it (derivatives, …).

  • animation implements the same concepts as the plot functions, but as animations.

  • scatter_points is meant to show a batch of the training points created by a sampler.

Submodules

torchphysics.utils.plotting.animation module

This module contains functions for animating the output of a neural network.

torchphysics.utils.plotting.animation.animate(model, ani_function, ani_sampler, ani_speed=50, angle=[30, 30], ani_type='')

Main function for animations.

Parameters:
  • model (torchphysics.models.Model) – The Model/neural network that should be used in the plot.

  • ani_function (Callable) – A function that specifies the part of the model that should be animated. It has the same form as the plot function.

  • ani_sampler (torchphysics.samplers.AnimationSampler) – A sampler that creates the points that should be used for the animation.

  • angle (list, optional) – The view angle for 3D plots. The default angle is [30, 30].

  • ani_type (str, optional) –

    Specifies how the output should be animated. If no type is given, the method tries to choose a suitable way to show the data. Implemented types are:

    • 'line' for line animations, with a 1D-domain and 1D-output

    • 'surface_2D' for surface animations, with a 2D-domain

    • 'quiver_2D' for quiver/vector field animations, with a 2D-domain

    • 'contour_surface' for contour/colormap animations, with a 2D-domain

Returns:

  • plt.figure – The figure handle of the created plot

  • animation.FuncAnimation – The object that handles the animation

Notes

This method only creates a simple animation and is not optimized for complex domains. It should only be used to get a rough understanding of the trained neural network.
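animate returns a figure together with a matplotlib FuncAnimation. As a rough illustration of that pattern for the 'line' case, here is a minimal sketch using only NumPy and matplotlib. The toy model toy_model and the helper animate_line are hypothetical, and treating ani_speed as the delay between frames in milliseconds (the interval argument of FuncAnimation) is an assumption, not something stated in this documentation.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend, so the sketch also runs without a display
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import animation


def toy_model(x, t):
    """Hypothetical stand-in for a trained network: u(x, t) = exp(-t) * sin(pi * x)."""
    return np.exp(-t) * np.sin(np.pi * x)


def animate_line(model, x, times, ani_speed=50):
    """Animate a 1D output over a 1D domain, one frame per time value."""
    fig, ax = plt.subplots()
    # Evaluate all frames up front so the y-limits stay fixed during the animation.
    all_values = np.stack([model(x, t) for t in times])
    ax.set_xlim(x.min(), x.max())
    ax.set_ylim(all_values.min(), all_values.max())
    (line,) = ax.plot(x, all_values[0])

    def update(frame):
        # Replace the curve with the values of the current time step.
        line.set_ydata(all_values[frame])
        ax.set_title(f"t = {times[frame]:.2f}")
        return (line,)

    # Assumption: ani_speed is used as the frame delay in milliseconds.
    ani = animation.FuncAnimation(fig, update, frames=len(times), interval=ani_speed)
    return fig, ani


x = np.linspace(0.0, 1.0, 100)
times = np.linspace(0.0, 2.0, 20)
fig, ani = animate_line(toy_model, x, times)
```

Evaluating every frame before building the animation keeps the axis limits stable, which matters more for a quick visual check than memory use does.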

torchphysics.utils.plotting.animation.animation_contour_2D(outputs, ani_sampler, animation_points, domain_points, angle, ani_speed)

Handles colormap animations in 2D.

torchphysics.utils.plotting.animation.animation_line(outputs, ani_sampler, animation_points, domain_points, angle, ani_speed)

Handles 1D animations; the inputs are the same as for animate().

torchphysics.utils.plotting.animation.animation_quiver_2D(outputs, ani_sampler, animation_points, domain_points, angle, ani_speed)

Handles quiver animations in 2D.

torchphysics.utils.plotting.animation.animation_surface2D(outputs, ani_sampler, animation_points, domain_points, angle, ani_speed)

Handles 2D animations; the inputs are the same as for animate().

torchphysics.utils.plotting.plot_functions module

This module contains functions for plotting the outputs of neural networks.

class torchphysics.utils.plotting.plot_functions.Plotter(plot_function, point_sampler, angle=[30, 30], log_interval=None, plot_type='', **kwargs)

Bases: object

Object to collect plotting properties.

Parameters:
  • plot_function (callable) –

    A function that specifies the part of the model that should be plotted. It can be of the same form as the condition functions. E.g. if the solution name is ‘u’, we can use

      def plot_func(u):
          return u[:, 0]

    to plot the first entry of ‘u’. For the derivative we could write:

      from torchphysics.utils import grad

      def plot_func(u, x):
          return grad(u, x)

  • point_sampler (torchphysics.samplers.PlotSampler) – A Sampler that creates the points that should be used for the plot.

  • angle (list, optional) – The view angle for surface plots. The default angle is [30, 30].

  • log_interval (int, optional) – If the plotter is used during the training of a model, a plot is saved every log_interval steps.

  • plot_type (str, optional) –

    Specifies how the output should be plotted. If no type is given, the method tries to choose a suitable way to show the data. Implemented types are:

    • 'line' for plots in 1D

    • 'surface_2D' for surface plots, with a 2D-domain

    • 'curve' for a curve in 3D, with a 1D-domain

    • 'quiver_2D' for quiver/vector field plots, with a 2D-domain

    • 'contour_surface' for contour/colormaps, with a 2D-domain

  • kwargs – Additional arguments to specify different parameters/behaviour of the plot. See https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.html for possible arguments of each underlying object.
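Because plot_function has the same form as the condition functions, its argument names decide which variables it receives. The following sketch shows one way such name-based matching can work with Python's inspect module. The helper call_with_matching_names and the dictionary of named arrays are hypothetical illustrations, not the torchphysics implementation.

```python
import inspect

import numpy as np


def call_with_matching_names(plot_function, named_values):
    """Call plot_function with the entries whose names match its parameters.

    Hypothetical helper: named_values maps variable names (e.g. 'x', 'u')
    to arrays. It is not part of the torchphysics API.
    """
    parameter_names = inspect.signature(plot_function).parameters
    kwargs = {name: named_values[name] for name in parameter_names}
    return plot_function(**kwargs)


def plot_func(u):
    # first output component of the solution 'u'
    return u[:, 0]


values = {
    "x": np.array([[0.0], [0.5], [1.0]]),
    "u": np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]),
}
print(call_with_matching_names(plot_func, values))  # [1. 3. 5.]
```

Matching by name means a plot function only declares the variables it needs; unused entries such as 'x' above are simply not passed.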

plot(model)

Creates the plot of the model.

Parameters:

model (torchphysics.models.Model) – The Model/neural network that should be used in the plot.

Returns:

The figure handle of the created plot

Return type:

plt.figure
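With log_interval set, the plotter is called periodically during training. A minimal sketch of that control flow, assuming a plot is produced at every step divisible by log_interval (including step 0). The class LoggingPlotter and the function train are hypothetical stand-ins; the real Plotter.plot only takes the model.

```python
class LoggingPlotter:
    """Hypothetical stand-in for a plotter with a log_interval."""

    def __init__(self, log_interval):
        self.log_interval = log_interval
        self.saved_steps = []

    def plot(self, model, step):
        # A real plotter would create and save a matplotlib figure here;
        # the sketch only records the step at which a plot would be saved.
        self.saved_steps.append(step)


def train(model, plotter, n_steps):
    for step in range(n_steps):
        # ... one optimization step of the model would go here ...
        if plotter.log_interval is not None and step % plotter.log_interval == 0:
            plotter.plot(model, step)


plotter = LoggingPlotter(log_interval=100)
train(model=None, plotter=plotter, n_steps=350)
print(plotter.saved_steps)  # [0, 100, 200, 300]
```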

torchphysics.utils.plotting.plot_functions.contour_2D(output, domain_points, point_sampler, angle, **kwargs)

Handles colormap/contour plots w.r.t. a two dimensional variable.
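A colormap over a two-dimensional variable can be drawn directly from scattered sampler points. This sketch uses matplotlib's tricontourf, which triangulates the points itself; whether contour_2D uses the same call is not stated here, and the random points and the output formula are placeholders for a sampler and a trained network.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(0)
# Scattered 2D domain points, as a sampler might return them (here: the unit square).
points = rng.uniform(0.0, 1.0, size=(500, 2))
# Placeholder scalar network output at those points.
output = np.sin(np.pi * points[:, 0]) * np.cos(np.pi * points[:, 1])

fig, ax = plt.subplots()
# tricontourf triangulates the scattered points, so no regular grid is needed.
contour = ax.tricontourf(points[:, 0], points[:, 1], output, levels=20, cmap="viridis")
fig.colorbar(contour, ax=ax)
ax.set_xlabel("x1")
ax.set_ylabel("x2")
```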

torchphysics.utils.plotting.plot_functions.curve3D(output, domain_points, point_sampler, angle, **kwargs)

Handles curve plots where the output is 2D and the domain is 1D.
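For a curve plot, the 1D domain variable becomes one axis of a 3D plot and the two output components become the other two. A minimal matplotlib sketch with a placeholder output; mapping angle=[30, 30] to view_init(elev=30, azim=30) is an assumption about what the two angle entries mean.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend
import matplotlib.pyplot as plt
import numpy as np

t = np.linspace(0.0, 2.0 * np.pi, 200)  # 1D domain
# Placeholder 2D output (u1, u2) along the domain.
output = np.stack([np.cos(t), np.sin(t)], axis=1)

fig = plt.figure()
ax = fig.add_subplot(projection="3d")
# The domain variable is one axis, the two output components are the other two.
ax.plot(t, output[:, 0], output[:, 1])
ax.view_init(elev=30, azim=30)  # assumed analogue of angle=[30, 30]
ax.set_xlabel("t")
ax.set_ylabel("u1")
ax.set_zlabel("u2")
```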

torchphysics.utils.plotting.plot_functions.line_plot(output, domain_points, point_sampler, angle, **kwargs)

Handles line plots w.r.t. a one-dimensional variable.

torchphysics.utils.plotting.plot_functions.plot(model, plot_function, point_sampler, angle=[30, 30], plot_type='', device='cpu', **kwargs)

Main function for plotting.

Parameters:
  • model (torchphysics.models.Model) – The Model/neural network that should be used in the plot.

  • plot_function (callable) –

    A function that specifies the part of the model that should be plotted. It has the same form as the condition functions. E.g. if the solution name is ‘u’, we can use

      def plot_func(u):
          return u[:, 0]

    to plot the first entry of ‘u’. For the derivative we could write:

      from torchphysics.utils import grad

      def plot_func(u, x):
          return grad(u, x)

  • point_sampler (torchphysics.samplers.PlotSampler) – A Sampler that creates the points that should be used for the plot.

  • angle (list, optional) – The view angle for 3D plots. The default angle is [30, 30].

  • plot_type (str, optional) –

    Specifies how the output should be plotted. If no type is given, the method tries to choose a suitable way to show the data. Implemented types are:

    • 'line' for plots in 1D

    • 'surface_2D' for surface plots, with a 2D-domain

    • 'curve' for a curve in 3D, with a 1D-domain

    • 'quiver_2D' for quiver/vector-field plots, with a 2D-domain

    • 'contour_surface' for contour/colormaps, with a 2D-domain

  • kwargs – Additional arguments to specify different parameters/behaviour of the plot. See https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.html for possible arguments of each underlying object.
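When plot_type is left empty, a type has to be inferred from the data. The rule below is a hypothetical sketch built only from the domain/output combinations listed above; in particular, it assumes that a scalar output on a 2D-domain defaults to 'surface_2D' rather than 'contour_surface', which this documentation does not specify.

```python
def choose_plot_type(domain_dim, output_dim):
    """Hypothetical rule for picking a plot_type when none is given.

    Based only on the documented combinations; the real defaults may differ.
    """
    if domain_dim == 1 and output_dim == 1:
        return "line"
    if domain_dim == 1 and output_dim == 2:
        return "curve"
    if domain_dim == 2 and output_dim == 1:
        return "surface_2D"  # assumption: 'contour_surface' would also fit
    if domain_dim == 2 and output_dim == 2:
        return "quiver_2D"
    raise NotImplementedError(
        f"No plot type for a {domain_dim}D domain with {output_dim}D output"
    )


print(choose_plot_type(2, 2))  # quiver_2D
```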

Returns:

The figure handle of the created plot

Return type:

plt.figure

Notes

The function works as follows: create points with the sampler -> evaluate the model -> evaluate the plot function -> create the plot with matplotlib.pyplot.

The function is only meant to give a quick overview of the trained neural network. In general, the method is not optimized for complex domains.
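The four steps in the Notes can be sketched with NumPy and matplotlib for a 1D line plot. The sampler, model, and plot function here are hypothetical placeholders (points on [0, 1], the function x**2, and the first output component), not torchphysics objects.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend
import matplotlib.pyplot as plt
import numpy as np


def sample_points(n):
    """Step 1: create points (placeholder for a PlotSampler on [0, 1])."""
    return np.linspace(0.0, 1.0, n).reshape(-1, 1)


def model(x):
    """Step 2: evaluate the model (placeholder network computing x**2)."""
    return x**2


def plot_function(u):
    """Step 3: select the quantity to show, here the first output component."""
    return u[:, 0]


x = sample_points(50)
u = model(x)
values = plot_function(u)

# Step 4: create the plot with matplotlib.pyplot.
fig, ax = plt.subplots()
ax.plot(x[:, 0], values)
ax.set_xlabel("x")
ax.set_ylabel("u")
```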

torchphysics.utils.plotting.plot_functions.quiver2D(output, domain_points, point_sampler, angle, **kwargs)

Handles quiver/vector field plots w.r.t. a two dimensional variable.
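A quiver plot needs a 2D-domain and a 2D output, with one arrow per point. This matplotlib sketch uses a placeholder rotation field on a regular grid and colors the arrows by their length; that coloring is a design choice of the sketch, not necessarily what quiver2D does.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend
import matplotlib.pyplot as plt
import numpy as np

# Regular 2D grid of domain points (a sampler could also return scattered points).
x1, x2 = np.meshgrid(np.linspace(-1.0, 1.0, 15), np.linspace(-1.0, 1.0, 15))
points = np.stack([x1.ravel(), x2.ravel()], axis=1)
# Placeholder 2D vector output: the rotation field (-x2, x1).
output = np.stack([-points[:, 1], points[:, 0]], axis=1)

fig, ax = plt.subplots()
# Color the arrows by vector length so the magnitude stays visible.
magnitude = np.linalg.norm(output, axis=1)
arrows = ax.quiver(points[:, 0], points[:, 1], output[:, 0], output[:, 1], magnitude)
fig.colorbar(arrows, ax=ax)
ax.set_aspect("equal")
```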

torchphysics.utils.plotting.plot_functions.surface2D(output, domain_points, point_sampler, angle, **kwargs)

Handles surface plots w.r.t. a two-dimensional variable.
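Similarly, a surface over a two-dimensional variable can be built from scattered points with matplotlib's plot_trisurf. The points and the Gaussian-shaped output below are placeholders, and plot_trisurf is chosen for the sketch rather than confirmed by this documentation as the call surface2D makes.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(1)
points = rng.uniform(-1.0, 1.0, size=(400, 2))
# Placeholder scalar output at the scattered points.
output = np.exp(-(points[:, 0] ** 2 + points[:, 1] ** 2))

fig = plt.figure()
ax = fig.add_subplot(projection="3d")
# plot_trisurf triangulates the scattered points into a surface.
surface = ax.plot_trisurf(points[:, 0], points[:, 1], output, cmap="viridis")
ax.view_init(elev=30, azim=30)  # assumed analogue of angle=[30, 30]
```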

torchphysics.utils.plotting.scatter_points module

Function to show an example of the points created by a sampler.

torchphysics.utils.plotting.scatter_points.scatter(subspace, *samplers)

Shows one batch of the points used in training. If the sampler is static, the shown points are exactly the training points. Otherwise, the points may vary, depending on the sampler.

Parameters:
  • subspace (torchphysics.problem.Space) – The (sub-)space of which the points should be plotted. Only plotting for dimensions <= 3 is possible.

  • *samplers (torchphysics.problem.Samplers) – The different samplers whose points should be plotted. The points of each sampler are plotted in the order the samplers were passed in.

Returns:

fig – The figure handle of the plot.

Return type:

matplotlib.pyplot.figure
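The difference between static and non-static samplers can be illustrated with two hypothetical sampler classes: one returns the same batch on every call, the other draws a new batch each time. The helper scatter_batches follows the idea of scatter by plotting one batch per sampler in the order they are passed; it is a simplified stand-in (no subspace argument), not the torchphysics function.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend
import matplotlib.pyplot as plt
import numpy as np


class StaticSampler:
    """Hypothetical sampler that draws its points once and always returns them."""

    def __init__(self, n, seed=0):
        self._points = np.random.default_rng(seed).uniform(0.0, 1.0, size=(n, 2))

    def sample_points(self):
        return self._points


class RandomSampler:
    """Hypothetical sampler that draws a new batch on every call."""

    def __init__(self, n, seed=0):
        self.n = n
        self._rng = np.random.default_rng(seed)

    def sample_points(self):
        return self._rng.uniform(0.0, 1.0, size=(self.n, 2))


def scatter_batches(*samplers):
    """Plot one batch per sampler, in the order the samplers were passed."""
    fig, ax = plt.subplots()
    for i, sampler in enumerate(samplers):
        points = sampler.sample_points()
        ax.scatter(points[:, 0], points[:, 1], s=5, label=f"sampler {i}")
    ax.legend()
    return fig


static_sampler, random_sampler = StaticSampler(100), RandomSampler(100)
fig = scatter_batches(static_sampler, random_sampler)
```

For the static sampler the plotted batch is exactly what training would see; for the random one it is only a representative draw.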