# Source code for torchphysics.problem.domains.functionsets.harmonic_functionset

import torch
import math

from ...spaces.points import Points
from .functionset import TestFunctionSet
from ..domain1D import Interval


class HarmonicFunctionSet1D(TestFunctionSet):
    """Set of sine test functions on a 1D interval, evaluated on a fixed grid.

    Basis function ``i`` is ``sin(n_i * pi * (x - a) / (b - a))`` on the
    interval ``[a, b]``, where ``n_i`` is the ``i``-th entry of
    ``frequence_list``. Values and derivatives are precomputed on equally
    spaced interior grid points; evaluation anywhere else is not implemented
    (only "quadrature mode" is supported).

    Parameters
    ----------
    function_space :
        The function space of the test functions (forwarded to the parent
        class; its ``input_space``/``output_space`` are used for ``Points``).
    interval : Interval
        The 1D domain ``[a, b]``.
    frequence : int or list/tuple of int
        If an int ``N``, the frequencies ``1, ..., N`` are used. If a
        sequence, exactly these frequencies are used.
    samples_per_max_frequence : int, optional
        Interior grid points per unit of the largest frequency: the grid has
        ``max(frequencies) * samples_per_max_frequence`` interior points.
    """

    def __init__(self, function_space, interval: Interval, frequence,
                 samples_per_max_frequence: int = 5):
        super().__init__(function_space=function_space)
        self.interval = interval
        self.samples_max = samples_per_max_frequence
        if isinstance(frequence, (list, tuple)):
            # The highest requested frequency determines the grid resolution.
            self.basis_dim = max(frequence)
            self.frequence_list = frequence
        else:
            self.basis_dim = frequence
            self.frequence_list = torch.arange(1, frequence + 1, 1)
        # NOTE(review): one row of quadrature data is allocated per integer
        # 1..basis_dim, but only len(frequence_list) rows get a basis function.
        # For a sparse list such as [1, 3] the remaining rows stay zero —
        # confirm whether that is intended.
        grid = torch.linspace(self.interval.lower_bound(),
                              self.interval.upper_bound(),
                              self.basis_dim * self.samples_max + 2)
        # Keep interior points only; for integer frequencies every sine basis
        # function is zero at both end points.
        quad_points = grid[1:-1]
        self.quadrature_points_per_dof = \
            quad_points.repeat((self.basis_dim, 1)).unsqueeze(-1)
        # Uniform grid spacing, used as the (scalar) weight of every point.
        # Taken from the full grid so it is defined even when there is only a
        # single interior point.
        self.quadrature_weights_per_dof = grid[1] - grid[0]
        self.compute_basis_at_quadrature_points()

    def switch_quadrature_mode_on(self, set_on: bool):
        """Enable or disable quadrature mode.

        Only quadrature mode is implemented, so disabling it is rejected.

        Raises
        ------
        NotImplementedError
            If ``set_on`` is False.
        """
        if not set_on:
            # Refuse before mutating state so the set stays in the only
            # supported mode.
            raise NotImplementedError("Arbitrary evaluation not implemented!")
        self.quadrature_mode_on = set_on

    def to(self, device):
        """Move all precomputed tensors to ``device`` and return ``self``."""
        self.quadrature_points_per_dof = self.quadrature_points_per_dof.to(device)
        self.quadrature_weights_per_dof = self.quadrature_weights_per_dof.to(device)
        self.basis_at_quadrature = self.basis_at_quadrature.to(device)
        self.grad_at_quadrature = self.grad_at_quadrature.to(device)
        return self

    def compute_basis_at_quadrature_points(self):
        """Evaluate every basis function and its derivative on the grid.

        Fills ``basis_at_quadrature`` with ``sin(k (x - a))`` and
        ``grad_at_quadrature`` with its derivative ``k cos(k (x - a))``,
        where ``k = n * pi / (b - a)``. Both tensors have the shape of
        ``quadrature_points_per_dof``.
        """
        self.basis_at_quadrature = torch.zeros_like(self.quadrature_points_per_dof)
        self.grad_at_quadrature = torch.zeros_like(self.quadrature_points_per_dof)
        lower = self.interval.lower_bound()
        int_size = self.interval.upper_bound() - lower
        for i, n in enumerate(self.frequence_list):
            k = n * math.pi / int_size
            shifted = self.quadrature_points_per_dof[i] - lower
            self.basis_at_quadrature[i] = torch.sin(k * shifted)
            # d/dx sin(k (x - a)) = +k cos(k (x - a)).
            self.grad_at_quadrature[i] = k * torch.cos(k * shifted)

    def __call__(self, x=None):
        """Evaluate the test functions at the quadrature points in ``x``.

        Raises
        ------
        NotImplementedError
            If quadrature mode is off.
        """
        if not self.quadrature_mode_on:
            raise NotImplementedError
        # 1D set: read the single input variable name without mutating the
        # space's variable collection (``pop()`` would remove it if
        # ``variables`` returned the live container rather than a copy).
        input_variable_name = next(iter(self.function_space.input_space.variables))
        values = self.eval_fn_helper.apply(x[input_variable_name],
                                           self.basis_at_quadrature,
                                           self.grad_at_quadrature)
        return Points(values, self.function_space.output_space)

    def grad(self, x=None):
        """Return the precomputed derivatives at the quadrature points.

        Raises
        ------
        NotImplementedError
            If quadrature mode is off and an evaluation point ``x`` is given.
        """
        if self.quadrature_mode_on or x is None:
            return self.grad_at_quadrature
        raise NotImplementedError

    def get_quad_weights(self, n):
        """Return quadrature weights for ``n`` test-function evaluations.

        The weights are laid out like ``quadrature_points_per_dof`` (one row
        per dof) and tiled along the first axis ``n // rows`` times.
        """
        weights = self.quadrature_weights_per_dof
        if weights.dim() == 0:
            # The uniform grid gives a single scalar spacing, and a 0-d tensor
            # has no len(); broadcast it to one weight per point per dof row.
            weights = weights.expand(self.quadrature_points_per_dof.shape)
        repeats = n // len(weights)
        return weights.repeat((repeats, 1, 1))

    def get_quadrature_points(self):
        """Return the quadrature grid as ``Points`` of the input space."""
        return Points(self.quadrature_points_per_dof, self.function_space.input_space)