import torch


def epxm1_x(x):
    """Compute (exp(x) - 1) / x element-wise without dividing by zero.

    Parameters:
        x: input tensor

    Returns:
        Tensor shaped like x. For |x| < 1e-2 the truncated Taylor series
        1 + x/2 + x**2/6 is returned. Elsewhere the direct quotient is used,
        with non-finite values (e.g. overflow for large positive x) replaced
        by 0.
    """
    result = torch.special.expm1(x) / x
    # Replace NaN or inf values with 0.
    result = torch.where(torch.isfinite(result), result, torch.zeros_like(result))
    # Near zero the quotient is 0/0 or loses precision; use the series.
    mask = torch.abs(x) < 1e-2
    result = torch.where(mask, 1 + x / 2. + x**2 / 6., result)
    return result


def epxm1mx_x2(x):
    """Compute (exp(x) - 1 - x) / x**2 element-wise without dividing by zero.

    Parameters:
        x: input tensor

    Returns:
        Tensor shaped like x. For x**2 < 1e-2 (|x| < 0.1) the series
        1/2 + x/6 + x**2/24 + x**3/120 is returned. Elsewhere the direct
        quotient is used, with non-finite values replaced by 0.
    """
    result = (torch.special.expm1(x) - x) / x**2
    # Replace NaN or inf values with 0.
    result = torch.where(torch.isfinite(result), result, torch.zeros_like(result))
    mask = torch.abs(x**2) < 1e-2
    result = torch.where(mask, 1 / 2. + x / 6 + x**2 / 24 + x**3 / 120, result)
    return result


def expm1mxmhx2_x3(x):
    """Compute (exp(x) - 1 - x - x**2/2) / x**3 element-wise without dividing by zero.

    Parameters:
        x: input tensor

    Returns:
        Tensor shaped like x. For |x**3| < 1e-2 the series
        1/6 + x/24 + x**2/120 + x**3/720 + x**4/5040 is returned. Elsewhere
        the direct quotient is used, with non-finite values replaced by 0.
    """
    result = (torch.special.expm1(x) - x - x**2 / 2) / x**3
    # Replace NaN or inf values with 0.
    result = torch.where(torch.isfinite(result), result, torch.zeros_like(result))
    mask = torch.abs(x**3) < 1e-2
    result = torch.where(
        mask,
        1 / 6 + x / 24 + x**2 / 120 + x**3 / 720 + x**4 / 5040,
        result,
    )
    return result


def exp_1mcosh_GD(gamma_t, delta):
    """
    Compute e^(-Γt) * (1 - cosh(Γt√Δ)) / ((Γt)**2 Δ)

    For Δ < 0 the hyperbolic cosine becomes cos(Γt√|Δ|).

    Parameters:
        gamma_t: Γ*t term (tensor)
        delta: Δ term (tensor)

    Returns:
        Result of the computation with numerical stability handling. When
        (Γt)**2 |Δ| < 5e-2 the series
        e^(-Γt) * (-1/2 - (Γt)**2 Δ / 24 - (Γt)**4 Δ**2 / 720) is used.
    """
    is_positive = delta > 0
    sqrt_abs_delta = torch.sqrt(torch.abs(delta))
    gamma_t_sqrt_delta = gamma_t * sqrt_abs_delta

    # Δ > 0: e^(-Γt) (1 - cosh(Γt√Δ)), with the exponents combined so that
    # cosh(Γt√Δ) is never formed on its own (it would overflow for large Γt).
    numerator_pos = torch.exp(-gamma_t) - (
        torch.exp(gamma_t * (sqrt_abs_delta - 1))
        + torch.exp(gamma_t * (-sqrt_abs_delta - 1))
    ) / 2
    # Δ <= 0: cosh of an imaginary argument is a cosine.
    numerator_neg = torch.exp(-gamma_t) * (1 - torch.cos(gamma_t * sqrt_abs_delta))
    numerator = torch.where(is_positive, numerator_pos, numerator_neg)
    result = numerator / (delta * gamma_t**2)

    # Handle NaN/inf cases (e.g. 0/0 when delta or gamma_t is 0).
    result = torch.where(torch.isfinite(result), result, torch.zeros_like(result))

    # Small (Γt)**2 Δ: the difference 1 - cosh cancels; use the series.
    mask = torch.abs(gamma_t_sqrt_delta**2) < 5e-2
    taylor = (
        -0.5 - gamma_t**2 / 24 * delta - gamma_t**4 / 720 * delta**2
    ) * torch.exp(-gamma_t)
    result = torch.where(mask, taylor, result)
    return result


def exp_sinh_GsqrtD(gamma_t, delta):
    """
    Compute e^(-Γt) * sinh(Γt√Δ) / (Γt√Δ)

    For Δ <= 0 this becomes e^(-Γt) * sin(Γt√|Δ|) / (Γt√|Δ|), evaluated with
    the normalised torch.special.sinc.

    Parameters:
        gamma_t: Γ*t term (tensor)
        delta: Δ term (tensor)

    Returns:
        Result of the computation with numerical stability handling.
    """
    is_positive = delta > 0
    sqrt_abs_delta = torch.sqrt(torch.abs(delta))
    gamma_t_sqrt_delta = gamma_t * sqrt_abs_delta

    # Δ > 0 branch; exponents are combined as in exp_1mcosh_GD.
    numerator_pos = (
        torch.exp(gamma_t * (sqrt_abs_delta - 1))
        - torch.exp(gamma_t * (-sqrt_abs_delta - 1))
    ) / 2
    result_pos = numerator_pos / gamma_t_sqrt_delta
    result_pos = torch.where(torch.isfinite(result_pos), result_pos, torch.zeros_like(result_pos))

    # Taylor expansion for small Γt√Δ: sinh(y)/y = 1 + y**2/6 + y**4/120.
    mask = torch.abs(gamma_t_sqrt_delta) < 1e-2
    taylor = (
        1 + gamma_t**2 / 6 * delta + gamma_t**4 / 120 * delta**2
    ) * torch.exp(-gamma_t)
    result_pos = torch.where(mask, taylor, result_pos)

    # Δ <= 0: sinc(z) = sin(πz)/(πz), so divide the argument by π.
    result_neg = torch.exp(-gamma_t) * torch.special.sinc(gamma_t_sqrt_delta / torch.pi)

    result = torch.where(is_positive, result_pos, result_neg)
    return result


def exp_cosh(gamma_t, delta):
    """
    Compute e^(-Γt) * cosh(Γt√Δ)

    Parameters:
        gamma_t: Γ*t term (tensor)
        delta: Δ term (tensor)

    Returns:
        e^(-Γt) - (Γt)**2 Δ * exp_1mcosh_GD(Γt, Δ), which is algebraically
        e^(-Γt) cosh(Γt√Δ) but inherits the stability handling of
        exp_1mcosh_GD.
    """
    # e^(-Γt) * (1 - cosh(Γt√Δ)) / ((Γt)**2 Δ)
    exp_1mcosh_GD_result = exp_1mcosh_GD(gamma_t, delta)
    result = torch.exp(-gamma_t) - gamma_t**2 * delta * exp_1mcosh_GD_result
    return result


def exp_sinh_sqrtD(gamma_t, delta):
    """
    Compute e^(-Γt) * sinh(Γt√Δ) / √Δ

    Parameters:
        gamma_t: Γ*t term (tensor)
        delta: Δ term (tensor)

    Returns:
        Γt * exp_sinh_GsqrtD(Γt, Δ), which stays finite as Δ -> 0.
    """
    # e^(-Γt) * sinh(Γt√Δ) / (Γt√Δ)
    exp_sinh_GsqrtD_result = exp_sinh_GsqrtD(gamma_t, delta)
    result = gamma_t * exp_sinh_GsqrtD_result
    return result


def zeta1(gamma_t, delta):
    """
    Compute ζ1 = 1 - (1 - φ) / (Γt(1 - Δ)/4), where
    φ = e^(-Γt/2) * (cosh(Γt√Δ/2) + sinh(Γt√Δ/2)/√Δ).

    Parameters:
        gamma_t: Γ*t term (tensor)
        delta: Δ term (tensor)

    Returns:
        ζ1 with numerical stability handling. When |Γt(1 - Δ)/4| < 5e-3 a
        second-order expansion in that quantity is used. Its coefficients
        are built from epxm1_x, epxm1mx_x2 and expm1mxmhx2_x3 evaluated at
        -Γt.
    """
    half_gamma_t = gamma_t / 2
    exp_cosh_term = exp_cosh(half_gamma_t, delta)
    exp_sinh_term = exp_sinh_sqrtD(half_gamma_t, delta)

    numerator = 1 - (exp_cosh_term + exp_sinh_term)
    denominator = gamma_t * (1 - delta) / 4
    result = 1 - numerator / denominator
    # Replace NaN/inf (e.g. 0/0 when the denominator vanishes) with 0.
    result = torch.where(torch.isfinite(result), result, torch.zeros_like(result))

    # Series in `denominator` for the region where the direct quotient
    # cancels catastrophically.
    mask = torch.abs(denominator) < 5e-3
    term1 = epxm1_x(-gamma_t)
    term2 = epxm1mx_x2(-gamma_t)
    term3 = expm1mxmhx2_x3(-gamma_t)
    taylor = (
        term1
        + (1 / 2. + term1 - 3 * term2) * denominator
        + (-1 / 6. + term1 / 2 - 4 * term2 + 10 * term3) * denominator**2
    )
    result = torch.where(mask, taylor, result)
    return result


def exp_cosh_minus_terms(gamma_t, delta):
    """
    Compute E^(-tΓ) * (Cosh[tΓ] - 1 - (Cosh[tΓ√Δ] - 1)/Δ) / (tΓ(1 - Δ))

    Parameters:
        gamma_t: Γ*t term (tensor)
        delta: Δ term (tensor)

    Returns:
        Result of the computation with numerical stability handling. When
        |Γt(1 - Δ)| < 1e-1 a first-order expansion around Δ = 1 is used.
    """
    ones_gamma = torch.ones_like(gamma_t)
    ones_delta = torch.ones_like(delta)
    exp_term = torch.exp(-gamma_t)

    # E^(-tΓ) (Cosh[tΓ] - 1)
    exp_cosh_term = exp_cosh(gamma_t, ones_gamma) - exp_term
    # E^(-tΓ) (Cosh[tΓ√Δ] - 1)/Δ, via exp_1mcosh_GD = e^(-Γt)(1 - cosh(Γt√Δ))/((Γt)**2 Δ)
    exp_cosh_delta_term = -gamma_t**2 * exp_1mcosh_GD(gamma_t, delta)

    numerator = exp_cosh_term - exp_cosh_delta_term
    denominator = gamma_t * (1 - delta)
    result = numerator / denominator
    # Replace NaN/inf (0/0 at Δ == 1 or Γt == 0) with 0.
    result = torch.where(torch.isfinite(result), result, torch.zeros_like(result))

    # Expansion in Γt(1 - Δ); all coefficient functions are evaluated at Δ = 1.
    mask = torch.abs(denominator) < 1e-1
    exp_1mcosh_GD_term = exp_1mcosh_GD(gamma_t, ones_delta)
    exp_sinh_GsqrtD_term = exp_sinh_GsqrtD(gamma_t, ones_delta)
    taylor = (
        gamma_t * exp_1mcosh_GD_term
        + 0.5 * gamma_t * exp_sinh_GsqrtD_term
        - denominator / 4 * (
            0.5 * exp_cosh(gamma_t, ones_delta)
            - 4 * exp_1mcosh_GD_term
            - 5 / 2 * exp_sinh_GsqrtD_term
        )
    )
    result = torch.where(mask, taylor, result)
    return result


def zeta2(gamma_t, delta):
    """Compute ζ2 = e^(-Γt/2) * sinh(Γt√Δ/2) / (Γt√Δ/2)."""
    half_gamma_t = gamma_t / 2
    return exp_sinh_GsqrtD(half_gamma_t, delta)


def sig11(gamma_t, delta):
    """
    Compute 1 - e^(-Γt) + e^(-Γt)(1 - cosh(Γt√Δ))/Δ + e^(-Γt) sinh(Γt√Δ)/√Δ.

    StochasticHarmonicOscillator.dynamics uses D**2 * sig11 / 2 as the
    variance of v(t).
    """
    return (
        1
        - torch.exp(-gamma_t)
        + gamma_t**2 * exp_1mcosh_GD(gamma_t, delta)
        + exp_sinh_sqrtD(gamma_t, delta)
    )


def Zcoefs(gamma_t, delta):
    """
    Compute three coefficients and their total amplitude from ζ1, ζ2 and sig11.

    Parameters:
        gamma_t: Γ*t term (tensor)
        delta: Δ term (tensor)

    Returns:
        tuple (Z1, Z2, Z3, amplitude) where
        amplitude = sqrt(1 - ζ1 + Γt(Δ - 1)(ζ1 - 1)**2 / 8), Z1, Z2 and Z3
        are the unit-normalised coefficients scaled back by `amplitude`, and
        Z3 is whatever remains of the unit norm after Z1 and Z2, floored at 0.

    NOTE(review): amplitude is 0 at gamma_t == 0 (ζ1 -> 1 there), so the
    division below yields NaN outputs at that point.
    """
    Zeta1 = zeta1(gamma_t, delta)
    Zeta2 = zeta2(gamma_t, delta)

    sq_total = 1 - Zeta1 + gamma_t * (delta - 1) * (Zeta1 - 1)**2 / 8
    amplitude = torch.sqrt(sq_total)

    Zcoef1 = (gamma_t**0.5 * Zeta2 / 2**0.5) / amplitude
    Zcoef2 = Zcoef1 * gamma_t * (
        -2 * exp_1mcosh_GD(gamma_t, delta) / sig11(gamma_t, delta)
    ) ** 0.5
    # Remaining share of the unit norm; clamp guards against rounding below 0.
    Zcoef3 = torch.sqrt(torch.clamp(1 - Zcoef1**2 - Zcoef2**2, min=0))

    return Zcoef1 * amplitude, Zcoef2 * amplitude, Zcoef3 * amplitude, amplitude


def Zcoefs_asymp(gamma_t, delta):
    """Return (exp(-2 A_t) - 1) / (-2 A_t) with A_t = Γt(1 - Δ)/4."""
    A_t = (gamma_t * (1 - delta)) / 4
    return epxm1_x(-2 * A_t)


class StochasticHarmonicOscillator:
    """
    Simulates a stochastic harmonic oscillator governed by the equations:

        dy(t) = q(t) dt
        dq(t) = -Γ A y(t) dt + Γ C dt + Γ D dw(t) - Γ q(t) dt

    Also define v(t) = q(t) / √Γ, which is numerically more stable.

    Where:
        y(t)  - Position variable
        q(t)  - Velocity variable
        Γ     - Damping coefficient
        A     - Harmonic potential strength
        C     - Constant force term
        D     - Noise amplitude
        dw(t) - Wiener process (Brownian motion)
    """

    def __init__(self, Gamma, A, C, D):
        self.Gamma = Gamma
        self.A = A
        self.C = C
        self.D = D
        # Discriminant of λ**2 + Γλ + ΓA = 0 divided by Γ**2:
        # Δ > 0 gives real roots (cosh/sinh), Δ < 0 complex roots (cos/sin).
        self.Delta = 1 - 4 * A / Gamma

    def sig11(self, gamma_t, delta):
        """Velocity-variance factor; same as the module-level sig11."""
        return sig11(gamma_t, delta)

    def sig22(self, gamma_t, delta):
        """Position-variance factor: Var(y(t)) = D**2 * t * sig22(Γt, Δ)."""
        return 1 - zeta1(2 * gamma_t, delta) + 2 * gamma_t * exp_1mcosh_GD(gamma_t, delta)

    def dynamics(self, y0, v0, t):
        """
        Sample the position and scaled velocity at time t.

        Parameters:
            y0 (torch.Tensor): Initial position. Its shape is the batch shape,
                and its dtype/device are used for all intermediate tensors.
            v0 (torch.Tensor or None): Initial velocity v(0) = q(0) / √Γ.
                If None, it is drawn element-wise as N(0, D**2 / 2).
            t (float or torch.Tensor): Elapsed time. Must broadcast to
                y0's shape.

        Returns:
            tuple: (y(t), v(t)), each with the shape of y0, drawn from a
            bivariate normal with the exact conditional mean and covariance.
        """
        # Promote the scalar parameters to tensors on y0's dtype/device.
        dummyzero = y0.new_zeros(1)
        Delta = self.Delta + dummyzero
        Gamma_hat = self.Gamma * t + dummyzero  # dimensionless time Γt
        A = self.A + dummyzero
        C = self.C + dummyzero
        D = self.D + dummyzero
        Gamma = self.Gamma + dummyzero
        sqrt_Gamma = Gamma ** 0.5

        zeta_1 = zeta1(Gamma_hat, Delta)
        zeta_2 = zeta2(Gamma_hat, Delta)
        EE = 1 - Gamma_hat * zeta_2

        if v0 is None:
            v0 = torch.randn_like(y0) * D / 2 ** 0.5

        # Conditional means.
        drift = (1 - zeta_1) * (C * t - A * t * y0) + zeta_2 * sqrt_Gamma * v0 * t
        y_mean = drift + y0
        v_mean = (1 - EE) * (C - A * y0) / sqrt_Gamma + (EE - A * t * (1 - zeta_1)) * v0

        # Conditional covariances.
        cov_yy = D**2 * t * self.sig22(Gamma_hat, Delta)
        cov_vv = D**2 * self.sig11(Gamma_hat, Delta) / 2
        cov_yv = (zeta_2 * Gamma_hat * D) ** 2 / 2 / sqrt_Gamma

        # Closed-form 2x2 Cholesky factor. `tol` floors both diagonal
        # variances so scale_tril stays strictly positive-definite.
        batch_shape = y0.shape
        tol = 1e-8
        cov_yy = torch.clamp(cov_yy, min=tol)
        sd_yy = torch.sqrt(cov_yy)

        scale_tril = torch.zeros(*batch_shape, 2, 2, device=y0.device, dtype=y0.dtype)
        scale_tril[..., 0, 0] = sd_yy
        scale_tril[..., 1, 0] = cov_yv / sd_yy
        scale_tril[..., 1, 1] = torch.clamp(cov_vv - cov_yv**2 / cov_yy, min=tol) ** 0.5

        mean = torch.zeros(*batch_shape, 2, device=y0.device, dtype=y0.dtype)
        mean[..., 0] = y_mean
        mean[..., 1] = v_mean

        new_yv = torch.distributions.MultivariateNormal(
            loc=mean, scale_tril=scale_tril
        ).sample()
        return new_yv[..., 0], new_yv[..., 1]