diff --git a/docs/modules/optimizer/symplectic_integrator.md b/docs/modules/optimizer/symplectic_integrator.md index c466304a..2b3ee80d 100644 --- a/docs/modules/optimizer/symplectic_integrator.md +++ b/docs/modules/optimizer/symplectic_integrator.md @@ -1,6 +1,6 @@ # Symplectic Integrator ```{eval-rst} -.. autoclass:: simulated_bifurcation.optimizer.SymplecticIntegrator +.. autoclass:: simulated_bifurcation.optimizer.EulerSymplecticIntegrator :members: ``` diff --git a/requirements.txt b/requirements.txt index bd277241..84f5a320 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ numpy>=2.0.0 sympy==1.13.1 -torch>=2.2.0 +torch==2.8.0 tqdm==4.67.1 diff --git a/src/simulated_bifurcation/optimizer/__init__.py b/src/simulated_bifurcation/optimizer/__init__.py index d3e4be92..5c9b20a9 100644 --- a/src/simulated_bifurcation/optimizer/__init__.py +++ b/src/simulated_bifurcation/optimizer/__init__.py @@ -15,10 +15,10 @@ """ from .environment import get_env, reset_env, set_env +from .integrator import EulerSymplecticIntegrator from .simulated_bifurcation_engine import SimulatedBifurcationEngine from .simulated_bifurcation_optimizer import ( ConvergenceWarning, SimulatedBifurcationOptimizer, ) from .stop_window import StopWindow -from .symplectic_integrator import SymplecticIntegrator diff --git a/src/simulated_bifurcation/optimizer/integrator/__init__.py b/src/simulated_bifurcation/optimizer/integrator/__init__.py new file mode 100644 index 00000000..dad05fe8 --- /dev/null +++ b/src/simulated_bifurcation/optimizer/integrator/__init__.py @@ -0,0 +1,2 @@ +from .euler_symplectic_integrator import EulerSymplecticIntegrator +from .stormer_verlet_symplectic_integrator import StormerVerletSymplecticIntegrator diff --git a/src/simulated_bifurcation/optimizer/symplectic_integrator.py b/src/simulated_bifurcation/optimizer/integrator/abc_symplectic_integrator.py similarity index 71% rename from src/simulated_bifurcation/optimizer/symplectic_integrator.py rename to src/simulated_bifurcation/optimizer/integrator/abc_symplectic_integrator.py index 92433d22..5e686085 100644 --- a/src/simulated_bifurcation/optimizer/symplectic_integrator.py +++ b/src/simulated_bifurcation/optimizer/integrator/abc_symplectic_integrator.py @@ -1,12 +1,13 @@ +from abc import ABC, abstractmethod from typing import Callable, Tuple import torch from numpy import minimum -from ..core.tensor_bearer import TensorBearer +from ...core.tensor_bearer import TensorBearer -class SymplecticIntegrator(TensorBearer): +class ABCSymplecticIntegrator(ABC, TensorBearer): """ Simulates the evolution of spins' momentum and position following the Hamiltonian quantum mechanics equations that drive the Simulated Bifurcation (SB) algorithm. @@ -42,31 +43,6 @@ def __init__( def init_oscillator(self, shape: Tuple[int, int]) -> torch.Tensor: return 2.0 * torch.rand(size=shape, device=self.device, dtype=self.dtype) - 1.0 - def position_update(self) -> None: - torch.add( - self.position, - self.momentum, - alpha=self.time_step, - out=self.position, - ) - - def momentum_update(self) -> None: - torch.add( - self.momentum, - self.position, - alpha=self.time_step * (self.get_current_pressure() - 1.0), - out=self.momentum, - ) - - def quadratic_momentum_update(self) -> None: - # do not use out=self.position because of side effects - self.momentum = torch.addmm( - self.momentum, - self.quadratic_tensor, - self.activation_function(self.position), - alpha=self.time_step * self.quadratic_scale_parameter, - ) - def simulate_inelastic_walls(self) -> None: self.momentum[torch.abs(self.position) > 1.0] = 0.0 torch.clip(self.position, -1.0, 1.0, out=self.position) @@ -85,9 +61,7 @@ def get_current_pressure(self) -> float: def integration_step(self) -> None: if self.heat: momentum_copy = self.momentum.clone() - self.momentum_update() - self.quadratic_momentum_update() - self.position_update() + self.integrate() self.simulate_inelastic_walls() if self.heat: self.simulate_heating(momentum_copy) @@ -95,3 +69,7 @@ def integration_step(self) -> None: def sample_spins(self) -> torch.Tensor: return torch.where(self.position >= 0.0, 1.0, -1.0) + + @abstractmethod + def integrate(self): + raise NotImplementedError() diff --git a/src/simulated_bifurcation/optimizer/integrator/euler_symplectic_integrator.py b/src/simulated_bifurcation/optimizer/integrator/euler_symplectic_integrator.py new file mode 100644 index 00000000..cc721183 --- /dev/null +++ b/src/simulated_bifurcation/optimizer/integrator/euler_symplectic_integrator.py @@ -0,0 +1,66 @@ +from typing import Callable + +import torch + +from .abc_symplectic_integrator import ABCSymplecticIntegrator + + +class EulerSymplecticIntegrator(ABCSymplecticIntegrator): + """ + Simulates the evolution of spins' momentum and position following the Hamiltonian quantum mechanics equations that + drive the Simulated Bifurcation (SB) algorithm. + """ + + def __init__( + self, + n_oscillators: int, + time_step: float, + pressure_slope: float, + heat_coefficient: float, + activation_function: Callable[[torch.Tensor], torch.Tensor], + heat: bool, + quadratic_tensor: torch.Tensor, + dtype: torch.dtype, + device: torch.device, + ): + super().__init__( + n_oscillators=n_oscillators, + time_step=time_step, + pressure_slope=pressure_slope, + heat_coefficient=heat_coefficient, + activation_function=activation_function, + heat=heat, + quadratic_tensor=quadratic_tensor, + dtype=dtype, + device=device, + ) + + def position_update(self) -> None: + torch.add( + self.position, + self.momentum, + alpha=self.time_step, + out=self.position, + ) + + def momentum_update(self) -> None: + torch.add( + self.momentum, + self.position, + alpha=self.time_step * (self.get_current_pressure() - 1.0), + out=self.momentum, + ) + + def quadratic_momentum_update(self) -> None: + # do not use out=self.position because of side effects + self.momentum = torch.addmm( + self.momentum, + self.quadratic_tensor, + self.activation_function(self.position), + alpha=self.time_step * self.quadratic_scale_parameter, + ) + + def integrate(self): + self.momentum_update() + self.quadratic_momentum_update() + self.position_update() diff --git a/src/simulated_bifurcation/optimizer/integrator/stormer_verlet_symplectic_integrator.py b/src/simulated_bifurcation/optimizer/integrator/stormer_verlet_symplectic_integrator.py new file mode 100644 index 00000000..e91c0a24 --- /dev/null +++ b/src/simulated_bifurcation/optimizer/integrator/stormer_verlet_symplectic_integrator.py @@ -0,0 +1,73 @@ +from typing import Callable, Tuple + +import torch +from numpy import minimum + +from .abc_symplectic_integrator import ABCSymplecticIntegrator + + +class StormerVerletSymplecticIntegrator(ABCSymplecticIntegrator): + """ + Order-2 symplectic integrator based on the Störmer-Verlet integration method. + """ + + def __init__( + self, + n_oscillators: int, + time_step: float, + pressure_slope: float, + heat_coefficient: float, + activation_function: Callable[[torch.Tensor], torch.Tensor], + heat: bool, + quadratic_tensor: torch.Tensor, + dtype: torch.dtype, + device: torch.device, + ): + super().__init__( + n_oscillators=n_oscillators, + time_step=time_step, + pressure_slope=pressure_slope, + heat_coefficient=heat_coefficient, + activation_function=activation_function, + heat=heat, + quadratic_tensor=quadratic_tensor, + dtype=dtype, + device=device, + ) + + def integrate(self): + intermediate_position = self.position.clone() + intermediate_momentum = self.momentum.clone() + n = 4 + for _ in range(n): + auxiliary_momentum = torch.add( + intermediate_momentum, + intermediate_position, + alpha=self.time_step * (self.get_current_pressure() - 1.0) / (2.0 * n), + ) + auxiliary_momentum = torch.addmm( + auxiliary_momentum, + self.quadratic_tensor, + self.activation_function(intermediate_position), + alpha=self.time_step * self.quadratic_scale_parameter / (2.0 * n), + ) + torch.add( + intermediate_position, + intermediate_momentum, + alpha=self.time_step / n, + out=intermediate_position, + ) + torch.add( + auxiliary_momentum, + intermediate_position, + alpha=self.time_step * (self.get_current_pressure() - 1.0) / (2.0 * n), + out=intermediate_momentum, + ) + intermediate_momentum = torch.addmm( + intermediate_momentum, + self.quadratic_tensor, + self.activation_function(intermediate_position), + alpha=self.time_step * self.quadratic_scale_parameter / (2.0 * n), + ) + self.position = intermediate_position.clone() + self.momentum = intermediate_momentum.clone() diff --git a/src/simulated_bifurcation/optimizer/simulated_bifurcation_optimizer.py b/src/simulated_bifurcation/optimizer/simulated_bifurcation_optimizer.py index 2cd24ceb..a4a84a8b 100644 --- a/src/simulated_bifurcation/optimizer/simulated_bifurcation_optimizer.py +++ b/src/simulated_bifurcation/optimizer/simulated_bifurcation_optimizer.py @@ -3,14 +3,13 @@ from typing import Optional, Union import torch -from numpy import minimum from tqdm.auto import tqdm from ..core.tensor_bearer import TensorBearer from .environment import ENVIRONMENT +from .integrator import EulerSymplecticIntegrator, StormerVerletSymplecticIntegrator from .simulated_bifurcation_engine import SimulatedBifurcationEngine from .stop_window import StopWindow -from .symplectic_integrator import SymplecticIntegrator class ConvergenceWarning(Warning): @@ -131,7 +130,7 @@ def __init_window(self, matrix: torch.Tensor, early_stopping: bool) -> None: ) def __init_symplectic_integrator(self, matrix: torch.Tensor) -> None: - self.symplectic_integrator = SymplecticIntegrator( + self.symplectic_integrator = StormerVerletSymplecticIntegrator( self.agents, self.time_step, self.pressure_slope, diff --git a/tests/optimizer/test_symplectic_integrator.py b/tests/optimizer/test_symplectic_integrator.py index 53afba90..10eca04b 100644 --- a/tests/optimizer/test_symplectic_integrator.py +++ b/tests/optimizer/test_symplectic_integrator.py @@ -3,7 +3,7 @@ import pytest import torch -from src.simulated_bifurcation.optimizer import SymplecticIntegrator +from src.simulated_bifurcation.optimizer import EulerSymplecticIntegrator from ..test_utils import DEVICES, DTYPES @@ -17,8 +17,8 @@ def init_integrator( device: torch.device, activation_function: Callable[[torch.Tensor], torch.Tensor], heat: bool, -) -> SymplecticIntegrator: - symplectic_integrator = SymplecticIntegrator( +) -> EulerSymplecticIntegrator: + symplectic_integrator = EulerSymplecticIntegrator( 2, 0.1, 0.01,