torchphysics.utils.plotting package
Different plotting modules:
- plot_functions – functions that show the output of the neural network or values derived from it (derivatives, …).
- animation – the same concepts as the plot functions, but as animations.
- scatter_points – functions that show a batch of the training points created by a sampler.
Submodules
torchphysics.utils.plotting.animation module
This file contains different functions for animating the output of the neural network
- torchphysics.utils.plotting.animation.animate(model, ani_function, ani_sampler, ani_speed=50, angle=[30, 30], ani_type='')
Main function for animations.
- Parameters:
model (torchphysics.models.Model) – The Model/neural network that should be used in the plot.
ani_function (Callable) – A function that specifies the part of the model that should be animated. It has the same form as the plot function.
ani_sampler (torchphysics.samplers.AnimationSampler) – A Sampler that creates the points that should be used for the animation.
angle (list, optional) – The view angle for 3D plots. The default angle is [30, 30].
ani_type (str, optional) –
Specifies how the output should be animated. If no type is given, the method tries to choose a suitable way to show the data. Implemented types are:
'line' for line animations, with a 1D-domain and a 1D output
'surface_2D' for surface animations, with a 2D-domain
'quiver_2D' for quiver/vector field animations, with a 2D-domain
'contour_surface' for contour/colormaps, with a 2D-domain
- Returns:
plt.figure – The figure handle of the created plot
animation.FuncAnimation – The function that handles the animation
Notes
This method only creates a simple animation and is not optimized for complex domains. It should only be used to get a rough understanding of the trained neural network.
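Conceptually, an animation evaluates the model once per value of the animation variable (for example time) on a fixed set of domain points, and each evaluation becomes one frame. The stdlib-only sketch below illustrates that idea; `build_frames` and the toy model are hypothetical and not part of torchphysics, whose `animate` works on torch tensors and matplotlib objects instead.

```python
from typing import Callable, List, Sequence, Tuple


def build_frames(
    model: Callable[[float, float], float],
    domain_points: Sequence[float],
    animation_points: Sequence[float],
) -> List[Tuple[float, List[float]]]:
    """Hypothetical helper: one frame per value of the animation variable."""
    frames = []
    for t in animation_points:
        # The domain points stay fixed; only the animation variable changes.
        values = [model(x, t) for x in domain_points]
        frames.append((t, values))
    return frames


# Toy "model" u(x, t) = x - t, standing in for a trained network.
frames = build_frames(lambda x, t: x - t, [0.0, 0.5, 1.0], [0.0, 1.0])
print(frames)  # [(0.0, [0.0, 0.5, 1.0]), (1.0, [-1.0, -0.5, 0.0])]
```

With matplotlib, frames like these are typically consumed by `matplotlib.animation.FuncAnimation`, which calls an update function once per frame.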
- torchphysics.utils.plotting.animation.animation_contour_2D(outputs, ani_sampler, animation_points, domain_points, angle, ani_speed)
Handles colormap animations in 2D
- torchphysics.utils.plotting.animation.animation_line(outputs, ani_sampler, animation_points, domain_points, angle, ani_speed)
Handles 1D animations. The inputs are the same as for animate().
torchphysics.utils.plotting.plot_functions module
This file contains different functions for plotting outputs of neural networks
- class torchphysics.utils.plotting.plot_functions.Plotter(plot_function, point_sampler, angle=[30, 30], log_interval=None, plot_type='', **kwargs)
Bases: object
Object to collect plotting properties.
- Parameters:
plot_function (callable) –
A function that specifies the part of the model that should be plotted. It can be of the same form as the condition functions. E.g. if the solution name is 'u', we can use

    def plot_func(u):
        return u[:, 0]

to plot the first entry of 'u'. For the derivative we could write:

    def plot_func(u, x):
        return grad(u, x)

point_sampler (torchphysics.samplers.PlotSampler) – A Sampler that creates the points that should be used for the plot.
angle (list, optional) – The view angle for surface plots. The default angle is [30, 30].
log_interval (int, optional) – Plots will be saved every log_interval steps if the plotter is used during the training of a model.
plot_type (str, optional) –
Specifies how the output should be plotted. If no type is given, the method tries to choose a suitable way to show the data. Implemented types are:
'line' for plots in 1D
'surface_2D' for surface plots, with a 2D-domain
'curve' for a curve in 3D, with a 1D-domain
'quiver_2D' for quiver/vector field plots, with a 2D-domain
'contour_surface' for contour/colormaps, with a 2D-domain
kwargs – Additional arguments to specify different parameters/behaviour of the plot. See https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.html for possible arguments of each underlying object.
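The log_interval parameter means that a plot is produced every log_interval training steps. A minimal sketch of that schedule follows, assuming steps are counted from 0 and that log_interval=None disables periodic plotting (both are assumptions); `plot_steps` is a hypothetical helper, not torchphysics code.

```python
from typing import List, Optional


def plot_steps(num_steps: int, log_interval: Optional[int]) -> List[int]:
    """Hypothetical: the training steps at which a plot would be saved."""
    if log_interval is None:
        # Assumed behaviour: without an interval, no plots are triggered.
        return []
    if log_interval <= 0:
        raise ValueError("log_interval must be a positive integer")
    return [step for step in range(num_steps) if step % log_interval == 0]


print(plot_steps(10, 4))  # [0, 4, 8]
```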
- torchphysics.utils.plotting.plot_functions.contour_2D(output, domain_points, point_sampler, angle, **kwargs)
Handles colormap/contour plots w.r.t. a two dimensional variable.
- torchphysics.utils.plotting.plot_functions.curve3D(output, domain_points, point_sampler, angle, **kwargs)
Handles curve plots where the output is 2D and the domain is 1D.
- torchphysics.utils.plotting.plot_functions.line_plot(output, domain_points, point_sampler, angle, **kwargs)
Handles line plots w.r.t. a one dimensional variable.
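The handlers differ by the dimension of the domain and of the output: a 1D domain with a 1D output gives a line plot, a 1D domain with a 2D output gives a 3D curve, and a 2D domain gives colormaps, surfaces or vector fields. One plausible rule for choosing a plot type when none is given is sketched below; it is derived only from the documented types, `guess_plot_type` is hypothetical, and the actual selection logic in torchphysics may differ.

```python
def guess_plot_type(domain_dim: int, output_dim: int) -> str:
    """Hypothetical mapping from (domain dim, output dim) to a plot type."""
    if domain_dim == 1 and output_dim == 1:
        return "line"             # graph of a scalar function on an interval
    if domain_dim == 1 and output_dim == 2:
        return "curve"            # 3D curve through (t, u_1(t), u_2(t))
    if domain_dim == 2 and output_dim == 1:
        return "contour_surface"  # colormap; "surface_2D" would fit as well
    if domain_dim == 2 and output_dim == 2:
        return "quiver_2D"        # vector field on a 2D domain
    raise NotImplementedError(
        f"no plot type for domain_dim={domain_dim}, output_dim={output_dim}"
    )


print(guess_plot_type(1, 2))  # curve
```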
- torchphysics.utils.plotting.plot_functions.plot(model, plot_function, point_sampler, angle=[30, 30], plot_type='', device='cpu', **kwargs)
Main function for plotting
- Parameters:
model (torchphysics.models.Model) – The Model/neural network that should be used in the plot.
plot_function (callable) –
A function that specifies the part of the model that should be plotted. It has the same form as the condition functions. E.g. if the solution name is 'u', we can use

    def plot_func(u):
        return u[:, 0]

to plot the first entry of 'u'. For the derivative we could write:

    def plot_func(u, x):
        return grad(u, x)

point_sampler (torchphysics.samplers.PlotSampler) – A Sampler that creates the points that should be used for the plot.
angle (list, optional) – The view angle for 3D plots. The default angle is [30, 30].
plot_type (str, optional) –
Specifies how the output should be plotted. If no type is given, the method tries to choose a suitable way to show the data. Implemented types are:
'line' for plots in 1D
'surface_2D' for surface plots, with a 2D-domain
'curve' for a curve in 3D, with a 1D-domain
'quiver_2D' for quiver/vector-field plots, with a 2D-domain
'contour_surface' for contour/colormaps, with a 2D-domain
kwargs – Additional arguments to specify different parameters/behaviour of the plot. See https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.html for possible arguments of each underlying object.
- Returns:
The figure handle of the created plot
- Return type:
plt.figure
Notes
This function creates points with the sampler -> evaluates the model -> evaluates the plot function -> creates the plot with matplotlib.pyplot.
The function is only meant to give a quick overview of the trained neural network. In general, the method is not optimized for complex domains.
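The pipeline from these Notes (sample points, evaluate the model, evaluate the plot function, plot) can be sketched without dependencies. Here the plot function receives its arguments by parameter name, which is what the plot_func(u) and plot_func(u, x) examples suggest; this name matching, the dictionary layout and `evaluate_plot_function` are assumptions for illustration, not the torchphysics implementation, and the final matplotlib step is left out.

```python
import inspect
from typing import Callable, Dict, List


def evaluate_plot_function(
    model: Callable[[float], float],
    plot_function: Callable[..., float],
    points: List[float],
) -> List[float]:
    """Hypothetical sketch: points -> evaluate model -> evaluate plot function."""
    # Names of the arguments the plot function asks for, e.g. ['u'] or ['u', 'x'].
    wanted = list(inspect.signature(plot_function).parameters)
    values = []
    for x in points:
        # Everything the plot function may request, keyed by variable name.
        available: Dict[str, float] = {"x": x, "u": model(x)}
        kwargs = {name: available[name] for name in wanted}
        values.append(plot_function(**kwargs))
    return values


def model(x: float) -> float:
    # Stands in for a trained network with input 'x' and output 'u'.
    return x ** 2


def plot_u(u):
    return u


def plot_u_plus_x(u, x):
    return u + x


print(evaluate_plot_function(model, plot_u, [0.0, 1.0, 2.0]))  # [0.0, 1.0, 4.0]
print(evaluate_plot_function(model, plot_u_plus_x, [1.0, 2.0]))  # [2.0, 6.0]
```

Matching by name means the order of the parameters in the plot function does not matter, only their names.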
torchphysics.utils.plotting.scatter_points module
Function to show an example of the created points of the sampler.
- torchphysics.utils.plotting.scatter_points.scatter(subspace, *samplers)
Shows one batch of the points used in training. If the sampler is static, the shown points are exactly the points used for training. Otherwise the points may vary, depending on the sampler.
- Parameters:
subspace (torchphysics.problem.Space) – The (sub-)space of which the points should be plotted. Only plotting for dimensions <= 3 is possible.
*samplers (torchphysics.problem.Samplers) – The different samplers for which the points should be plotted. The plot for each sampler is created in the order the samplers were passed in.
- Returns:
fig – The figure handle of the plot.
- Return type:
matplotlib.pyplot.figure
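scatter restricts plotting to subspaces of dimension at most 3 and draws one group of points per sampler, in the order the samplers were passed. The sketch below shows that bookkeeping with plain Python data: each sampler batch is replaced by a list of coordinate dictionaries, each variable is assumed to be one-dimensional, and `collect_scatter_data` is a hypothetical helper, not part of torchphysics.

```python
from typing import Dict, List, Sequence


def collect_scatter_data(
    variables: Sequence[str], *batches: List[Dict[str, float]]
) -> List[List[List[float]]]:
    """Hypothetical: per batch, one coordinate list per plotted variable."""
    if len(variables) > 3:
        raise ValueError("only subspaces with dimension <= 3 can be plotted")
    groups = []
    for batch in batches:  # keep the order in which the batches were passed
        groups.append([[point[v] for point in batch] for v in variables])
    return groups


inner = [{"x": 0.1, "y": 0.2, "t": 0.0}, {"x": 0.5, "y": 0.7, "t": 1.0}]
boundary = [{"x": 0.0, "y": 1.0, "t": 0.5}]
data = collect_scatter_data(["x", "y"], inner, boundary)
print(data)  # [[[0.1, 0.5], [0.2, 0.7]], [[0.0], [1.0]]]
```

Each group could then be drawn with its own call such as ax.scatter(*coords), so that every sampler gets its own color in the figure.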