Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/modules/optimizer/symplectic_integrator.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Symplectic Integrator

```{eval-rst}
.. autoclass:: simulated_bifurcation.optimizer.SymplecticIntegrator
.. autoclass:: simulated_bifurcation.optimizer.EulerSymplecticIntegrator
:members:
```
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
numpy>=2.0.0
sympy==1.13.1
torch>=2.2.0
torch==2.8.0
tqdm==4.67.1
2 changes: 1 addition & 1 deletion src/simulated_bifurcation/optimizer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@
"""

from .environment import get_env, reset_env, set_env
from .integrator import EulerSymplecticIntegrator
from .simulated_bifurcation_engine import SimulatedBifurcationEngine
from .simulated_bifurcation_optimizer import (
ConvergenceWarning,
SimulatedBifurcationOptimizer,
)
from .stop_window import StopWindow
from .symplectic_integrator import SymplecticIntegrator
2 changes: 2 additions & 0 deletions src/simulated_bifurcation/optimizer/integrator/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .euler_symplectic_integrator import EulerSymplecticIntegrator
from .stormer_verlet_symplectic_integrator import StormerVerletSymplecticIntegrator
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from abc import ABC, abstractmethod
from typing import Callable, Tuple

import torch
from numpy import minimum

from ..core.tensor_bearer import TensorBearer
from ...core.tensor_bearer import TensorBearer


class SymplecticIntegrator(TensorBearer):
class ABCSymplecticIntegrator(ABC, TensorBearer):
"""
Simulates the evolution of spins' momentum and position following the Hamiltonian quantum mechanics equations that
drive the Simulated Bifurcation (SB) algorithm.
Expand Down Expand Up @@ -42,31 +43,6 @@ def __init__(
def init_oscillator(self, shape: Tuple[int, int]) -> torch.Tensor:
return 2.0 * torch.rand(size=shape, device=self.device, dtype=self.dtype) - 1.0

def position_update(self) -> None:
torch.add(
self.position,
self.momentum,
alpha=self.time_step,
out=self.position,
)

def momentum_update(self) -> None:
torch.add(
self.momentum,
self.position,
alpha=self.time_step * (self.get_current_pressure() - 1.0),
out=self.momentum,
)

def quadratic_momentum_update(self) -> None:
# do not use out=self.position because of side effects
self.momentum = torch.addmm(
self.momentum,
self.quadratic_tensor,
self.activation_function(self.position),
alpha=self.time_step * self.quadratic_scale_parameter,
)

def simulate_inelastic_walls(self) -> None:
self.momentum[torch.abs(self.position) > 1.0] = 0.0
torch.clip(self.position, -1.0, 1.0, out=self.position)
Expand All @@ -85,13 +61,15 @@ def get_current_pressure(self) -> float:
def integration_step(self) -> None:
if self.heat:
momentum_copy = self.momentum.clone()
self.momentum_update()
self.quadratic_momentum_update()
self.position_update()
self.integrate()
self.simulate_inelastic_walls()
if self.heat:
self.simulate_heating(momentum_copy)
self.step += 1

def sample_spins(self) -> torch.Tensor:
return torch.where(self.position >= 0.0, 1.0, -1.0)

@abstractmethod
def integrate(self):
raise NotImplementedError()
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
from typing import Callable

import torch

from .abc_symplectic_integrator import ABCSymplecticIntegrator


class EulerSymplecticIntegrator(ABCSymplecticIntegrator):
"""
Simulates the evolution of spins' momentum and position following the Hamiltonian quantum mechanics equations that
drive the Simulated Bifurcation (SB) algorithm.
"""

def __init__(
self,
n_oscillators: int,
time_step: float,
pressure_slope: float,
heat_coefficient: float,
activation_function: Callable[[torch.Tensor], torch.Tensor],
heat: bool,
quadratic_tensor: torch.Tensor,
dtype: torch.dtype,
device: torch.device,
):
super().__init__(
n_oscillators=n_oscillators,
time_step=time_step,
pressure_slope=pressure_slope,
heat_coefficient=heat_coefficient,
activation_function=activation_function,
heat=heat,
quadratic_tensor=quadratic_tensor,
dtype=dtype,
device=device,
)

def position_update(self) -> None:
torch.add(
self.position,
self.momentum,
alpha=self.time_step,
out=self.position,
)

def momentum_update(self) -> None:
torch.add(
self.momentum,
self.position,
alpha=self.time_step * (self.get_current_pressure() - 1.0),
out=self.momentum,
)

def quadratic_momentum_update(self) -> None:
# do not use out=self.position because of side effects
self.momentum = torch.addmm(
self.momentum,
self.quadratic_tensor,
self.activation_function(self.position),
alpha=self.time_step * self.quadratic_scale_parameter,
)

def integrate(self):
self.momentum_update()
self.quadratic_momentum_update()
self.position_update()
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
from typing import Callable, Tuple

import torch
from numpy import minimum

from .abc_symplectic_integrator import ABCSymplecticIntegrator


class StormerVerletSymplecticIntegrator(ABCSymplecticIntegrator):
"""
Order-2 symplectic integrator based on the Störmer-Verlet integration method.
"""

def __init__(
self,
n_oscillators: int,
time_step: float,
pressure_slope: float,
heat_coefficient: float,
activation_function: Callable[[torch.Tensor], torch.Tensor],
heat: bool,
quadratic_tensor: torch.Tensor,
dtype: torch.dtype,
device: torch.device,
):
super().__init__(
n_oscillators=n_oscillators,
time_step=time_step,
pressure_slope=pressure_slope,
heat_coefficient=heat_coefficient,
activation_function=activation_function,
heat=heat,
quadratic_tensor=quadratic_tensor,
dtype=dtype,
device=device,
)

def integrate(self):
intermediate_position = self.position.clone()
intermediate_momentum = self.momentum.clone()
n = 4
for _ in range(n):
auxiliary_momentum = torch.add(
intermediate_momentum,
intermediate_position,
alpha=self.time_step * (self.get_current_pressure() - 1.0) / (2.0 * n),
)
auxiliary_momentum = torch.addmm(
auxiliary_momentum,
self.quadratic_tensor,
self.activation_function(intermediate_position),
alpha=self.time_step * self.quadratic_scale_parameter / (2.0 * n),
)
torch.add(
intermediate_position,
intermediate_momentum,
alpha=self.time_step / n,
out=intermediate_position,
)
torch.add(
auxiliary_momentum,
intermediate_position,
alpha=self.time_step * (self.get_current_pressure() - 1.0) / (2.0 * n),
out=intermediate_momentum,
)
intermediate_momentum = torch.addmm(
intermediate_momentum,
self.quadratic_tensor,
self.activation_function(intermediate_position),
alpha=self.time_step * self.quadratic_scale_parameter / (2.0 * n),
)
self.position = intermediate_position.clone()
self.momentum = intermediate_momentum.clone()
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,13 @@
from typing import Optional, Union

import torch
from numpy import minimum
from tqdm.auto import tqdm

from ..core.tensor_bearer import TensorBearer
from .environment import ENVIRONMENT
from .integrator import EulerSymplecticIntegrator, StormerVerletSymplecticIntegrator
from .simulated_bifurcation_engine import SimulatedBifurcationEngine
from .stop_window import StopWindow
from .symplectic_integrator import SymplecticIntegrator


class ConvergenceWarning(Warning):
Expand Down Expand Up @@ -131,7 +130,7 @@ def __init_window(self, matrix: torch.Tensor, early_stopping: bool) -> None:
)

def __init_symplectic_integrator(self, matrix: torch.Tensor) -> None:
self.symplectic_integrator = SymplecticIntegrator(
self.symplectic_integrator = StormerVerletSymplecticIntegrator(
self.agents,
self.time_step,
self.pressure_slope,
Expand Down
6 changes: 3 additions & 3 deletions tests/optimizer/test_symplectic_integrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import pytest
import torch

from src.simulated_bifurcation.optimizer import SymplecticIntegrator
from src.simulated_bifurcation.optimizer import EulerSymplecticIntegrator

from ..test_utils import DEVICES, DTYPES

Expand All @@ -17,8 +17,8 @@ def init_integrator(
device: torch.device,
activation_function: Callable[[torch.Tensor], torch.Tensor],
heat: bool,
) -> SymplecticIntegrator:
symplectic_integrator = SymplecticIntegrator(
) -> EulerSymplecticIntegrator:
symplectic_integrator = EulerSymplecticIntegrator(
2,
0.1,
0.01,
Expand Down
Loading