diff --git a/.github/workflows/python-check.yml b/.github/workflows/python-check.yml index bbfbc72..e13c384 100644 --- a/.github/workflows/python-check.yml +++ b/.github/workflows/python-check.yml @@ -45,15 +45,15 @@ jobs: - name: Check formatting with black working-directory: ./python - run: black --check ./src ./test + run: black --check ./src ./test ./tools - name: Sort imports with isort working-directory: ./python - run: isort --check-only --diff ./src ./test + run: isort --check-only --diff ./src ./test ./tools - name: Type checking with mypy working-directory: ./python - run: mypy --ignore-missing-imports ./src ./test + run: mypy --ignore-missing-imports ./src ./test ./tools - name: Run tests with pytest working-directory: ./python diff --git a/python/src/common/types.py b/python/src/common/types.py index 10ad999..fbd9b12 100644 --- a/python/src/common/types.py +++ b/python/src/common/types.py @@ -1,7 +1,10 @@ +from __future__ import annotations + from dataclasses import dataclass from enum import Enum from src.common.base_pool_state import BasePoolState +from src.pools.buffer.buffer_data import BufferState from src.pools.gyro.gyro_2clp_data import Gyro2CLPState from src.pools.gyro.gyro_eclp_data import GyroECLPState from src.pools.quantamm.quantamm_data import QuantAmmState @@ -28,6 +31,8 @@ class AddLiquidityInput: class AddLiquidityResult: bpt_amount_out_raw: int amounts_in_raw: list[int] + updated_pool_state: PoolState + swap_fee_amounts_scaled18: list[int] class RemoveLiquidityKind(Enum): @@ -48,6 +53,8 @@ class RemoveLiquidityInput: class RemoveLiquidityResult: bpt_amount_in_raw: int amounts_out_raw: list[int] + updated_pool_state: PoolState + swap_fee_amounts_scaled18: list[int] class SwapKind(Enum): @@ -65,7 +72,9 @@ class SwapInput: @dataclass class SwapResult: - amount_out_raw: int + amount_calculated_raw: int + updated_pool_state: PoolState | BufferState + swap_fee_amount_scaled18: int PoolState = ( diff --git a/python/src/vault/add_liquidity.py b/python/src/vault/add_liquidity.py index 155f728..375fd25 100644 --- a/python/src/vault/add_liquidity.py +++ b/python/src/vault/add_liquidity.py @@ -1,3 +1,5 @@ +from dataclasses import replace + from src.common.base_pool_math import ( compute_add_liquidity_single_token_exact_out, compute_add_liquidity_unbalanced, @@ -149,7 +151,16 @@ def add_liquidity( for i, a in enumerate(after_add_result.hook_adjusted_amounts_in_raw): amounts_in_raw[i] = a + # Create updated pool state with new balances and total supply + updated_pool_state = replace( + pool_state, + balances_live_scaled18=updated_balances_live_scaled18, + total_supply=pool_state.total_supply + bpt_amount_out, + ) + return AddLiquidityResult( bpt_amount_out_raw=bpt_amount_out, amounts_in_raw=amounts_in_raw, + updated_pool_state=updated_pool_state, + swap_fee_amounts_scaled18=swap_fee_amounts_scaled18, ) diff --git a/python/src/vault/remove_liquidity.py b/python/src/vault/remove_liquidity.py index 227f063..421ad66 100644 --- a/python/src/vault/remove_liquidity.py +++ b/python/src/vault/remove_liquidity.py @@ -1,3 +1,5 @@ +from dataclasses import replace + from src.common.base_pool_math import ( compute_proportional_amounts_out, compute_remove_liquidity_single_token_exact_in, @@ -16,6 +18,7 @@ _get_single_input_index, _require_unbalanced_liquidity_enabled, _to_raw_undo_rate_round_down, + _to_scaled_18_apply_rate_round_down, ) from src.hooks.types import HookBase, HookState @@ -127,7 +130,7 @@ def remove_liquidity( # A Pool's token balance always decreases after an exit # Computes protocol and pool creator fee which is eventually taken from pool balance - aggregate_swap_fee_amount_scaled18 = _compute_and_charge_aggregate_swap_fees( + aggregate_swap_fee_amount_raw = _compute_and_charge_aggregate_swap_fees( swap_fee_amounts_scaled18[i], pool_state.aggregate_swap_fee, pool_state.scaling_factors, @@ -135,6 +138,12 @@ def remove_liquidity( i, ) + aggregate_swap_fee_amount_scaled18 = _to_scaled_18_apply_rate_round_down( + aggregate_swap_fee_amount_raw, + pool_state.scaling_factors[i], + pool_state.token_rates[i], + ) + updated_balances_live_scaled18[i] = updated_balances_live_scaled18[i] - ( amounts_out_scaled18[i] + aggregate_swap_fee_amount_scaled18 ) @@ -163,7 +172,16 @@ def remove_liquidity( for i, a in enumerate(after_remove_result.hook_adjusted_amounts_out_raw): amounts_out_raw[i] = a + # Create updated pool state with new balances and total supply + updated_pool_state = replace( + pool_state, + balances_live_scaled18=updated_balances_live_scaled18, + total_supply=pool_state.total_supply - bpt_amount_in, + ) + return RemoveLiquidityResult( bpt_amount_in_raw=bpt_amount_in, amounts_out_raw=amounts_out_raw, + updated_pool_state=updated_pool_state, + swap_fee_amounts_scaled18=swap_fee_amounts_scaled18, ) diff --git a/python/src/vault/swap.py b/python/src/vault/swap.py index 8f690a0..da574cd 100644 --- a/python/src/vault/swap.py +++ b/python/src/vault/swap.py @@ -1,8 +1,10 @@ +from dataclasses import replace + from src.common.constants import WAD from src.common.maths import complement_fixed, mul_div_up_fixed, mul_up_fixed from src.common.pool_base import PoolBase from src.common.swap_params import SwapParams -from src.common.types import PoolState, SwapInput, SwapKind +from src.common.types import PoolState, SwapInput, SwapKind, SwapResult from src.common.utils import ( _compute_and_charge_aggregate_swap_fees, _to_raw_undo_rate_round_down, @@ -22,7 +24,7 @@ def swap( pool_class: PoolBase, hook_class: HookBase, hook_state: HookState | object | None, -) -> int: +) -> SwapResult: input_index = find_case_insensitive_index_in_list( pool_state.tokens, swap_input.token_in ) @@ -123,7 +125,7 @@ def swap( pool_state.token_rates[input_index], ) - aggregate_swap_fee_amount_scaled18 = _compute_and_charge_aggregate_swap_fees( + aggregate_swap_fee_amount_raw = _compute_and_charge_aggregate_swap_fees( total_swap_fee_amount_scaled18, pool_state.aggregate_swap_fee, pool_state.scaling_factors, @@ -131,6 +133,12 @@ def swap( input_index, ) + aggregate_swap_fee_amount_scaled18 = _to_scaled_18_apply_rate_round_down( + aggregate_swap_fee_amount_raw, + pool_state.scaling_factors[input_index], + pool_state.token_rates[input_index], + ) + # For ExactIn, we increase the tokenIn balance by `amountIn`, # and decrease the tokenOut balance by the # (`amountOut` + fees). @@ -185,7 +193,16 @@ def swap( after_swap_result.hook_adjusted_amount_calculated_raw ) - return amount_calculated_raw + # Create updated pool state with new balances + updated_pool_state = replace( + pool_state, balances_live_scaled18=updated_balances_live_scaled18 + ) + + return SwapResult( + amount_calculated_raw=amount_calculated_raw, + updated_pool_state=updated_pool_state, + swap_fee_amount_scaled18=total_swap_fee_amount_scaled18, + ) def _compute_amount_given_scaled18( diff --git a/python/src/vault/vault.py b/python/src/vault/vault.py index f028637..e1dcfe0 100644 --- a/python/src/vault/vault.py +++ b/python/src/vault/vault.py @@ -1,3 +1,4 @@ +from dataclasses import replace from typing import Dict, Optional, Type from src.common.pool_base import PoolBase @@ -8,6 +9,7 @@ RemoveLiquidityInput, RemoveLiquidityResult, SwapInput, + SwapResult, ) from src.hooks.default_hook import DefaultHook from src.hooks.exit_fee.exit_fee_hook import ExitFeeHook @@ -60,13 +62,23 @@ def swap( swap_input: SwapInput, pool_state: PoolState | BufferState, hook_state: HookState | object | None = None, - ) -> int: + ) -> SwapResult: if swap_input.amount_raw == 0: - return 0 + return SwapResult( + amount_calculated_raw=0, + updated_pool_state=pool_state, + swap_fee_amount_scaled18=0, + ) # buffer is handled separately than a "normal" pool if isinstance(pool_state, BufferState): - return erc4626_buffer_wrap_or_unwrap(swap_input, pool_state) + amount_out = erc4626_buffer_wrap_or_unwrap(swap_input, pool_state) + # Buffer state doesn't change in the same way, but we still return it + return SwapResult( + amount_calculated_raw=amount_out, + updated_pool_state=pool_state, + swap_fee_amount_scaled18=0, + ) pool_class = self._get_pool(pool_state=pool_state) hook_class = self._get_hook( hook_name=pool_state.hook_type, hook_state=hook_state diff --git a/python/test/hooks/test_after_swap.py b/python/test/hooks/test_after_swap.py index d65a0e5..b8bdc48 100644 --- a/python/test/hooks/test_after_swap.py +++ b/python/test/hooks/test_after_swap.py @@ -136,12 +136,12 @@ def test_hook_after_swap_no_fee(): ] ) custom_state_no_fee = map_custom_pool_state({**pool, "aggregateSwapFee": 0}) - test = vault.swap( + swap_result = vault.swap( swap_input=swap_input, pool_state=custom_state_no_fee, hook_state=input_hook_state, ) - assert test == 1 + assert swap_result.amount_calculated_raw == 1 def test_hook_after_swap_with_fee(): @@ -160,9 +160,9 @@ def test_hook_after_swap_with_fee(): custom_state_with_fee = map_custom_pool_state( {**pool, "aggregateSwapFee": 500000000000000000} ) - test = vault.swap( + swap_result = vault.swap( swap_input=swap_input, pool_state=custom_state_with_fee, hook_state=input_hook_state, ) - assert test == 1 + assert swap_result.amount_calculated_raw == 1 diff --git a/python/test/hooks/test_before_swap.py b/python/test/hooks/test_before_swap.py index 5d8f37c..4e78ee2 100644 --- a/python/test/hooks/test_before_swap.py +++ b/python/test/hooks/test_before_swap.py @@ -87,9 +87,9 @@ def test_before_swap(): balance_change=[1000000000000000000, 1000000000000000000] ) weighted_state = map_weighted_state(pool) - test = vault.swap( + swap_result = vault.swap( swap_input=swap_input, pool_state=weighted_state, hook_state=input_hook_state, ) - assert test == 89999999 + assert swap_result.amount_calculated_raw == 89999999 diff --git a/python/test/hooks/test_stable_surge.py b/python/test/hooks/test_stable_surge.py index 94050c9..199f76f 100644 --- a/python/test/hooks/test_stable_surge.py +++ b/python/test/hooks/test_stable_surge.py @@ -41,10 +41,10 @@ def test_below_surge_threshold_static_swap_fee_case1(): token_in=pool_state["tokens"][0], token_out=pool_state["tokens"][1], ) - output_amount = vault.swap( + swap_result = vault.swap( swap_input=swap_input, pool_state=stable_state, hook_state=hook_state ) - assert output_amount == 78522716365403684 + assert swap_result.amount_calculated_raw == 78522716365403684 def test_below_surge_threshold_static_swap_fee_case2(): @@ -54,10 +54,10 @@ def test_below_surge_threshold_static_swap_fee_case2(): token_in=pool_state["tokens"][0], token_out=pool_state["tokens"][1], ) - output_amount = vault.swap( + swap_result = vault.swap( swap_input=swap_input, pool_state=stable_state, hook_state=hook_state ) - assert output_amount == 452983383563178802 + assert swap_result.amount_calculated_raw == 452983383563178802 def test_above_surge_threshold_uses_surge_fee(): @@ -67,7 +67,7 @@ def test_above_surge_threshold_uses_surge_fee(): token_in=pool_state["tokens"][1], token_out=pool_state["tokens"][0], ) - output_amount = vault.swap( + swap_result = vault.swap( swap_input=swap_input, pool_state=stable_state, hook_state=hook_state ) - assert output_amount == 3252130027531260 + assert swap_result.amount_calculated_raw == 3252130027531260 diff --git a/python/test/test_add_liquidity.py b/python/test/test_add_liquidity.py index 01bbcbd..05327cb 100644 --- a/python/test/test_add_liquidity.py +++ b/python/test/test_add_liquidity.py @@ -3,8 +3,10 @@ transform_strings_to_ints, ) from test.utils.read_test_data import read_test_data +from test.utils.validate_balances import build_add_liquidity_deltas, validate_balances from typing import cast +from src.common.base_pool_state import BasePoolState from src.common.types import AddLiquidityInput, AddLiquidityKind, PoolState from src.vault.vault import Vault @@ -42,3 +44,25 @@ def test_add_liquidity(): assert calculated_amount.amounts_in_raw == list( map(int, add_test["inputAmountsRaw"]) ) + + # Validate updated balances + # Skip validation if hook has shouldCallComputeDynamicSwapFee enabled + should_validate = True + if isinstance(pool_state, BasePoolState): + hook = vault._get_hook( + hook_name=pool_state.hook_type, hook_state=hook_state + ) + should_validate = not hook.should_call_compute_dynamic_swap_fee + + if should_validate: + deltas = build_add_liquidity_deltas( + pool_state=pool_state, + amounts_in_raw=calculated_amount.amounts_in_raw, + swap_fee_amounts_scaled18=calculated_amount.swap_fee_amounts_scaled18, + ) + if deltas is not None: + validate_balances( + initial_pool_state=pool_state, + updated_pool_state=calculated_amount.updated_pool_state, + amount_deltas_raw=deltas, + ) diff --git a/python/test/test_custom_pool.py b/python/test/test_custom_pool.py index 40f96e6..ecad1c9 100644 --- a/python/test/test_custom_pool.py +++ b/python/test/test_custom_pool.py @@ -55,7 +55,7 @@ def test_custom_pool(): } custom_pool_state = map_custom_pool_state(pool_state) vault = Vault(custom_pool_classes={"CustomPool": CustomPool}) - calculated_amount = vault.swap( + swap_result = vault.swap( swap_input=SwapInput( amount_raw=1000000000000000000, token_in="0x7b79995e5f793A07Bc00c21412e50Ecae098E7f9", @@ -64,7 +64,7 @@ def test_custom_pool(): ), pool_state=custom_pool_state, ) - assert calculated_amount == custom_pool_state.randoms[0] + assert swap_result.amount_calculated_raw == custom_pool_state.randoms[0] class CustomPool(PoolBase): diff --git a/python/test/test_remove_liquidity.py b/python/test/test_remove_liquidity.py index b21cca8..41f2773 100644 --- a/python/test/test_remove_liquidity.py +++ b/python/test/test_remove_liquidity.py @@ -3,8 +3,13 @@ transform_strings_to_ints, ) from test.utils.read_test_data import read_test_data +from test.utils.validate_balances import ( + build_remove_liquidity_deltas, + validate_balances, +) from typing import cast +from src.common.base_pool_state import BasePoolState from src.common.types import PoolState, RemoveLiquidityInput, RemoveLiquidityKind from src.vault.vault import Vault @@ -37,3 +42,25 @@ def test_remove_liquidity(): assert calculated_amount.amounts_out_raw == list( map(int, remove_test["amountsOutRaw"]) ) + + # Validate updated balances + # Skip validation if hook has shouldCallComputeDynamicSwapFee enabled + should_validate = True + if isinstance(pool_state, BasePoolState): + hook = vault._get_hook( + hook_name=pool_state.hook_type, hook_state=hook_state + ) + should_validate = not hook.should_call_compute_dynamic_swap_fee + + if should_validate: + deltas = build_remove_liquidity_deltas( + pool_state=pool_state, + amounts_out_raw=calculated_amount.amounts_out_raw, + swap_fee_amounts_scaled18=calculated_amount.swap_fee_amounts_scaled18, + ) + if deltas is not None: + validate_balances( + initial_pool_state=pool_state, + updated_pool_state=calculated_amount.updated_pool_state, + amount_deltas_raw=deltas, + ) diff --git a/python/test/test_swaps.py b/python/test/test_swaps.py index 212c4a4..1275ff8 100644 --- a/python/test/test_swaps.py +++ b/python/test/test_swaps.py @@ -3,6 +3,7 @@ transform_strings_to_ints, ) from test.utils.read_test_data import read_test_data +from test.utils.validate_balances import build_swap_deltas, validate_balances from src.common.types import SwapInput, SwapKind from src.vault.vault import Vault @@ -12,32 +13,64 @@ def test_swaps(): vault = Vault() + for swap_test in test_data["swaps"]: if swap_test["test"] == "1-23511249-GyroECLP-Barter.json": continue - print(swap_test["test"]) - if swap_test["test"] not in test_data["pools"]: - raise ValueError(f"Pool not in test data: {swap_test['test']}") - pool = test_data["pools"][swap_test["test"]] - # note any amounts must be passed as ints not strings + + test_name = swap_test["test"] + if test_name not in test_data["pools"]: + raise ValueError(f"Pool not in test data: {test_name}") + + pool = test_data["pools"][test_name] pool_with_ints = transform_strings_to_ints(pool) pool_state, hook_state = map_pool_and_hook_state(pool_with_ints) - calculated_amount = vault.swap( - swap_input=SwapInput( - amount_raw=int(swap_test["amountRaw"]), - token_in=swap_test["tokenIn"], - token_out=swap_test["tokenOut"], - swap_kind=SwapKind(swap_test["swapKind"]), - ), + + swap_input = SwapInput( + amount_raw=int(swap_test["amountRaw"]), + token_in=swap_test["tokenIn"], + token_out=swap_test["tokenOut"], + swap_kind=SwapKind(swap_test["swapKind"]), + ) + swap_result = vault.swap( + swap_input=swap_input, pool_state=pool_state, hook_state=hook_state, ) + + # Validate amount out if pool["poolType"] == "Buffer": assert are_big_ints_within_percent( - calculated_amount, int(swap_test["outputRaw"]), 0.01 + swap_result.amount_calculated_raw, int(swap_test["outputRaw"]), 0.01 ) else: - assert calculated_amount == int(swap_test["outputRaw"]) + assert swap_result.amount_calculated_raw == int(swap_test["outputRaw"]) + + # Validate updated balances + amount_in_raw = ( + swap_input.amount_raw + if swap_input.swap_kind.value == SwapKind.GIVENIN.value + else swap_result.amount_calculated_raw + ) + amount_out_raw = ( + swap_input.amount_raw + if swap_input.swap_kind.value == SwapKind.GIVENOUT.value + else swap_result.amount_calculated_raw + ) + deltas = build_swap_deltas( + pool_state=pool_state, + token_in=swap_input.token_in, + token_out=swap_input.token_out, + amount_in_raw=amount_in_raw, + amount_out_raw=amount_out_raw, + swap_fee_amount_scaled18=swap_result.swap_fee_amount_scaled18, + ) + if deltas is not None: + validate_balances( + initial_pool_state=pool_state, + updated_pool_state=swap_result.updated_pool_state, + amount_deltas_raw=deltas, + ) def are_big_ints_within_percent(value1: int, value2: int, percent: float) -> bool: diff --git a/python/test/utils/map_pool_state.py b/python/test/utils/map_pool_state.py index 6eedadf..6d11db9 100644 --- a/python/test/utils/map_pool_state.py +++ b/python/test/utils/map_pool_state.py @@ -15,10 +15,8 @@ from src.pools.weighted.weighted_data import map_weighted_state -def map_pool_state(pool_state: dict) -> PoolState | BufferState: - if pool_state["poolType"] == "Buffer": - return map_buffer_state(pool_state) - elif pool_state["poolType"] == "GYRO": +def map_pool_state(pool_state: dict) -> PoolState: + if pool_state["poolType"] == "GYRO": return map_gyro_2clp_state(pool_state) elif pool_state["poolType"] == "GYROE": return map_gyro_eclp_state(pool_state) @@ -38,6 +36,13 @@ def map_pool_state(pool_state: dict) -> PoolState | BufferState: raise ValueError(f"Unsupported pool type: {pool_state['poolType']}") +def map_pool_and_buffer_state(pool_state: dict) -> PoolState | BufferState: + if pool_state["poolType"] == "Buffer": + return map_buffer_state(pool_state) + else: + return map_pool_state(pool_state) + + def transform_strings_to_ints(pool_with_strings): pool_with_ints = {} for key, value in pool_with_strings.items(): @@ -68,10 +73,6 @@ def map_pool_and_hook_state( """ Maps pool data to pool state and hook state (if present). - This function maps both the pool state and any associated hook state from - the raw pool dictionary. Hook data is extracted and mapped only for pool - types that support hooks (STABLE and WEIGHTED). - Args: pool: Pool dict from JSON (already converted to ints via transform_strings_to_ints) @@ -81,7 +82,7 @@ def map_pool_and_hook_state( - hook_state: Mapped HookState if hook exists and pool supports it, None otherwise """ # First, map the pool state - pool_state = map_pool_state(pool) + pool_state = map_pool_and_buffer_state(pool) # Check if pool has hook data hook_data = pool.get("hook") diff --git a/python/test/utils/read_test_data.py b/python/test/utils/read_test_data.py index 00915d2..631a335 100644 --- a/python/test/utils/read_test_data.py +++ b/python/test/utils/read_test_data.py @@ -44,7 +44,7 @@ def read_test_data(): test_data["removes"].append( { **remove, - "kind": mapRemoveKind(remove["kind"]), + "kind": map_remove_kind(remove["kind"]), "test": filename, } ) @@ -54,7 +54,7 @@ def read_test_data(): return test_data -def mapRemoveKind(kind): +def map_remove_kind(kind): if kind == "Proportional": return 0 elif kind == "SingleTokenExactIn": diff --git a/python/test/utils/validate_balances.py b/python/test/utils/validate_balances.py new file mode 100644 index 0000000..c50448c --- /dev/null +++ b/python/test/utils/validate_balances.py @@ -0,0 +1,164 @@ +from src.common.utils import ( + _compute_and_charge_aggregate_swap_fees, + _to_raw_undo_rate_round_down, + find_case_insensitive_index_in_list, +) + + +def validate_balances( + initial_pool_state, + updated_pool_state, + amount_deltas_raw: list[int], +): + """ + Validates that updated balances match expected values based on amount deltas. + + Compares for each token: + - initial_balance[i] + amount_delta[i] ≈ updated_balance[i] + + Use positive deltas for amounts added to the pool (add liquidity, swap token in). + Use negative deltas for amounts removed from the pool (remove liquidity, swap token out). + + Amounts are scaled from raw to scaled18 using scaling_factors and token_rates. + Allows tolerance of a few wei for rounding differences. + """ + # Skip validation for buffer pools (they don't have balances_live_scaled18) + if not hasattr(initial_pool_state, "balances_live_scaled18"): + return + + tolerance = 100 + + for i, token in enumerate(initial_pool_state.tokens): + initial_balance_raw = _to_raw_undo_rate_round_down( + initial_pool_state.balances_live_scaled18[i], + initial_pool_state.scaling_factors[i], + initial_pool_state.token_rates[i], + ) + + expected_balance = initial_balance_raw + amount_deltas_raw[i] + + actual_balance = _to_raw_undo_rate_round_down( + updated_pool_state.balances_live_scaled18[i], + updated_pool_state.scaling_factors[i], + updated_pool_state.token_rates[i], + ) + + diff = abs(actual_balance - expected_balance) + + if diff > tolerance: + raise AssertionError( + f"Token {token} balance mismatch:\n" + f" Expected: {expected_balance}\n" + f" Actual: {actual_balance}\n" + f" Diff: {diff} (tolerance: {tolerance})\n" + f" Initial: {initial_pool_state.balances_live_scaled18[i]}\n" + f" Delta (raw): {amount_deltas_raw[i]}\n" + ) + + +def build_add_liquidity_deltas( + pool_state, + amounts_in_raw: list[int], + swap_fee_amounts_scaled18: list[int], +) -> list[int] | None: + """ + Builds amount deltas array for an add liquidity operation. + + Returns an array of deltas where each position has: + +amount_in_raw - aggregate_swap_fee_raw + + Returns None for buffer pools (they don't have balances_live_scaled18). + """ + # Skip for buffer pools + if not hasattr(pool_state, "balances_live_scaled18"): + return None + + deltas = [] + for i in range(len(pool_state.tokens)): + aggregate_swap_fee_amount_raw = _compute_and_charge_aggregate_swap_fees( + swap_fee_amounts_scaled18[i], + pool_state.aggregate_swap_fee, + pool_state.scaling_factors, + pool_state.token_rates, + i, + ) + deltas.append(amounts_in_raw[i] - aggregate_swap_fee_amount_raw) + + return deltas + + +def build_remove_liquidity_deltas( + pool_state, + amounts_out_raw: list[int], + swap_fee_amounts_scaled18: list[int], +) -> list[int] | None: + """ + Builds amount deltas array for a remove liquidity operation. + + Returns an array of deltas where each position has: + -(amount_out_raw + aggregate_swap_fee_raw) + + Returns None for buffer pools (they don't have balances_live_scaled18). + """ + # Skip for buffer pools + if not hasattr(pool_state, "balances_live_scaled18"): + return None + + deltas = [] + for i in range(len(pool_state.tokens)): + aggregate_swap_fee_amount_raw = _compute_and_charge_aggregate_swap_fees( + swap_fee_amounts_scaled18[i], + pool_state.aggregate_swap_fee, + pool_state.scaling_factors, + pool_state.token_rates, + i, + ) + deltas.append(-(amounts_out_raw[i] + aggregate_swap_fee_amount_raw)) + + return deltas + + +def build_swap_deltas( + pool_state, + token_in: str, + token_out: str, + amount_in_raw: int, + amount_out_raw: int, + swap_fee_amount_scaled18: int, +) -> list[int] | None: + """ + Builds amount deltas array for a swap operation. + + Returns an array of deltas where: + - token_in position has +amount_in_raw (minus aggregate fee) + - token_out position has -amount_out_raw + - all other positions have 0 + + Returns None for buffer pools (they don't have balances_live_scaled18). + """ + + # Skip for buffer pools + if not hasattr(pool_state, "balances_live_scaled18"): + return None + + token_in_index = find_case_insensitive_index_in_list(pool_state.tokens, token_in) + token_out_index = find_case_insensitive_index_in_list(pool_state.tokens, token_out) + + if token_in_index == -1 or token_out_index == -1: + raise ValueError(f"Token not found in pool: {token_in} or {token_out}") + + # Calculate aggregate fee from the swap fee amount + aggregate_swap_fee_amount_raw = _compute_and_charge_aggregate_swap_fees( + swap_fee_amount_scaled18, + pool_state.aggregate_swap_fee, + pool_state.scaling_factors, + pool_state.token_rates, + token_in_index, + ) + + # Build deltas array + deltas = [0] * len(pool_state.tokens) + deltas[token_in_index] = amount_in_raw - aggregate_swap_fee_amount_raw + deltas[token_out_index] = -amount_out_raw + + return deltas diff --git a/python/tools/__init__.py b/python/tools/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/python/tools/simulator/README.md b/python/tools/simulator/README.md new file mode 100644 index 0000000..49c992b --- /dev/null +++ b/python/tools/simulator/README.md @@ -0,0 +1,274 @@ +# Balancer Pool Simulator + +A stateful security testing tool for Balancer V3 pools that enables sequential operation simulation and invariant tracking. + +## Overview + +The Pool Simulator provides a stateful wrapper around balancer-maths operations, designed specifically for security analysis and vulnerability testing. It uses a separate data directory (`simulationData`) to avoid interfering with the test suite. It supports: + +- Sequential swaps, add/remove liquidity operations +- COMMIT vs SIMULATE execution modes +- State snapshots and rollback +- Operation history tracking +- Invariant monitoring across operations + +## Directory Structure + +The simulator uses a separate configuration and data directory to avoid conflicts with the test suite: + +``` +testData/ +├── config.json # Test suite configuration +├── config_simulator.json # Simulator configuration (add pools here) +├── testData/ # Test suite data (generated from config.json) +└── simulationData/ # Simulator data (generated from config_simulator.json) +``` + +**To add pools for simulation:** +1. Edit `testData/config_simulator.json` to include your desired pools +2. Run the data generation script to populate `testData/simulationData/` +3. The simulator will automatically load from `simulationData` + +## Quick Start + +### Basic Usage + +```python +from tools.simulator import PoolSimulator, StateLoader, ExecutionMode +from src.common.types import SwapInput, SwapKind +from test.utils.read_test_data import read_test_data + +# Load pool data from simulationData +test_data = read_test_data(use_simulation_data=True) +pool_dict = test_data["pools"]["11155111-7439300-Weighted-USDC-DAI.json"] +pool_state, hook_state = StateLoader.from_pool_dict(pool_dict) + +# Or load directly from a JSON file +# pool_state, hook_state = StateLoader.from_json_file("path/to/pool.json") + +# Initialize simulator +simulator = PoolSimulator(pool_state, hook_state) + +# Execute a swap in COMMIT mode (applies state changes) +swap_input = SwapInput( + amount_raw=1000000000000000000, # 1.0 tokens (18 decimals) + swap_kind=SwapKind.GIVENIN, + token_in="0x...", + token_out="0x..." +) +result = simulator.swap(swap_input, mode=ExecutionMode.COMMIT) + +# Check current pool state +current_state = simulator.current_state +print(f"New balances: {current_state.balances_live_scaled18}") +``` + +### Execution Modes + +**COMMIT Mode** - Applies state changes permanently: +```python +result = simulator.swap(swap_input, mode=ExecutionMode.COMMIT) +# State is updated after this call +``` + +**SIMULATE Mode** - Queries without mutating state: +```python +result = simulator.swap(swap_input, mode=ExecutionMode.SIMULATE) +# State remains unchanged, but result is calculated +``` + +### Snapshots & Rollback + +Create snapshots to test exploit scenarios and rollback: + +```python +# Create snapshot before potentially malicious operations +snapshot_id = simulator.create_snapshot("before_attack") + +# Execute suspicious sequence +simulator.swap(...) +simulator.add_liquidity(...) + +# Rollback if needed +simulator.restore_snapshot(snapshot_id) + +# Or reset to initial state completely +simulator.reset() +``` + +### Operation History + +Track all operations for analysis: + +```python +history = simulator.get_history() +for op in history: + print(f"Operation {op.sequence}: {op.operation_type}") + print(f" Committed: {op.was_committed}") + print(f" Balances before: {op.balances_before}") + print(f" Balances after: {op.balances_after}") +``` + +## Example: Fuzzing Script + +Run the included fuzzing example to see invariant tracking across sequential swaps: + +```bash +python3 -m tools.simulator.fuzz_example +``` + +This script demonstrates: +- Loading pool state from testData +- Executing 10 randomized swaps (5-20% of balance) +- Tracking invariant changes after each operation +- Detecting protocol violations (MaxInRatio) + +### Sample Output + +``` +================================================================================ +Balancer Pool Fuzzing Simulator - Invariant Tracking +================================================================================ + +Selected pool: 11155111-7439300-Weighted-USDC-DAI.json +Pool type: WEIGHTED + +Initial Pool State: + Tokens: 2 + Token 0: 6,916.38 + Token 1: 6,240.66 + Initial Invariant: 6,569.8399 + +Swap #1: + Direction: Token 1 → Token 0 + Amount In: 511.37 (Token 1) + Amount Out: 0.00 (Token 0) + Balance In: 6,240.66 → 6,752.02 + Balance Out: 6,916.38 → 6,397.42 + Invariant Before: 6,569.8399 + Invariant After: 6,572.3292 + Invariant Change: 2.4893 (↑ 0.037889%) +``` + +## Security Testing Patterns + +### Pattern 1: Invariant Manipulation Detection + +```python +from src.pools.weighted.weighted_math import compute_invariant_down + +# Calculate invariant before operation +invariant_before = compute_invariant_down( + pool_state.weights, pool_state.balances_live_scaled18 +) + +# Execute suspicious operation +result = simulator.swap(..., mode=ExecutionMode.COMMIT) + +# Check invariant after +invariant_after = compute_invariant_down( + simulator.current_state.weights, + simulator.current_state.balances_live_scaled18 +) + +# Invariant should only increase due to fees +if invariant_after < invariant_before: + print("⚠️ INVARIANT DECREASED - POTENTIAL EXPLOIT") +``` + +### Pattern 2: Sandwich Attack Simulation + +```python +# Front-run: Large swap in same direction +simulator.swap(SwapInput(...), mode=ExecutionMode.COMMIT) + +# Victim transaction +snapshot = simulator.create_snapshot() +victim_result = simulator.swap(victim_swap_input, mode=ExecutionMode.COMMIT) + +# Back-run: Reverse the initial swap +simulator.swap(reverse_swap_input, mode=ExecutionMode.COMMIT) + +# Analyze profit +# ...then rollback to try other scenarios +simulator.restore_snapshot(snapshot) +``` + +### Pattern 3: Reentrancy Simulation + +```python +# Simulate nested calls by manually orchestrating state +simulator.swap(initial_swap, mode=ExecutionMode.COMMIT) + +# Before state is fully committed, simulate reentrant call +# (In practice, the protocol should prevent this) +reentrant_result = simulator.swap(reentrant_swap, mode=ExecutionMode.SIMULATE) + +# Check if reentrant call would succeed with stale state +``` + +## API Reference + +### PoolSimulator + +#### Constructor +```python +PoolSimulator( + initial_pool_state: PoolState, + initial_hook_state: Optional[HookState] = None, + config: Optional[SimulatorConfig] = None +) +``` + +#### Core Operations +- `swap(swap_input, mode=ExecutionMode.COMMIT) -> SwapResult` +- `add_liquidity(add_liquidity_input, mode=ExecutionMode.COMMIT) -> AddLiquidityResult` +- `remove_liquidity(remove_liquidity_input, mode=ExecutionMode.COMMIT) -> RemoveLiquidityResult` + +#### State Management +- `current_state: PoolState` - Get current pool state (property) +- `initial_state: PoolState` - Get initial pool state (property) +- `reset()` - Reset to initial state and clear history + +#### Snapshots +- `create_snapshot(name: Optional[str]) -> str` - Create state snapshot +- `restore_snapshot(snapshot_id: str)` - Restore from snapshot +- `delete_snapshot(snapshot_id: str)` - Delete snapshot +- `list_snapshots() -> List[str]` - List all snapshot IDs + +#### History +- `get_history() -> List[OperationRecord]` - Get operation history +- `clear_history()` - Clear operation history + +### StateLoader + +#### Methods +- `from_json_file(filepath: str) -> Tuple[PoolState, Optional[HookState]]` +- `from_pool_dict(pool_dict: dict) -> Tuple[PoolState, Optional[HookState]]` + +### SimulatorConfig + +```python +@dataclass +class SimulatorConfig: + track_history: bool = True + max_history_size: Optional[int] = None +``` + +## Limitations + +- **Buffer pools not supported** - Simulator only works with regular AMM pools (Weighted, Stable, etc.) +- **No flash loans** - Flash loan mechanics are not simulated +- **Hook simulation** - Hook state tracking is available but hooks must be implemented in balancer-maths +- **Block-based loading** - Direct blockchain state loading not yet implemented (use testData JSON files) + +## Contributing + +When adding new pool types or hooks, ensure they are compatible with the simulator by: +1. Implementing state update logic that returns new PoolState +2. Adding invariant calculation functions +3. Testing with the fuzzing script + +## License + +MIT diff --git a/python/tools/simulator/__init__.py b/python/tools/simulator/__init__.py new file mode 100644 index 0000000..7a026e5 --- /dev/null +++ b/python/tools/simulator/__init__.py @@ -0,0 +1,15 @@ +"""Pool simulator for security analysis and sequential operation testing.""" + +from tools.simulator.read_simulation_data import read_simulation_data +from tools.simulator.simulator import PoolSimulator +from tools.simulator.state_loader import StateLoader +from tools.simulator.types import ExecutionMode, OperationRecord, SimulatorConfig + +__all__ = [ + "PoolSimulator", + "StateLoader", + "ExecutionMode", + "SimulatorConfig", + "OperationRecord", + "read_simulation_data", +] diff --git a/python/tools/simulator/fuzz_example.py b/python/tools/simulator/fuzz_example.py new file mode 100644 index 0000000..d68a068 --- /dev/null +++ b/python/tools/simulator/fuzz_example.py @@ -0,0 +1,222 @@ +"""Example fuzzing script demonstrating invariant tracking across sequential swaps. + +This script: +1. Loads the first non-buffer pool from testData +2. Performs 10 randomized swaps with amounts between 5-20% of current pool balances +3. Tracks and prints invariant changes after each swap +4. Note: Protocol enforces maximum 30% swap ratio relative to balance + +Usage: + python3 -m tools.simulator.fuzz_example +""" + +import random +from test.utils.map_pool_state import transform_strings_to_ints + +from src.common.types import SwapInput, SwapKind +from src.common.utils import _to_raw_undo_rate_round_down +from src.pools.stable.stable_data import StableState +from src.pools.stable.stable_math import compute_invariant as compute_stable_invariant +from src.pools.weighted.weighted_data import WeightedState +from src.pools.weighted.weighted_math import ( + compute_invariant_down as compute_weighted_invariant, +) +from tools.simulator import ( + ExecutionMode, + PoolSimulator, + StateLoader, + read_simulation_data, +) + + +def calculate_invariant(pool_state): + """Calculate invariant based on pool type.""" + if isinstance(pool_state, WeightedState): + return compute_weighted_invariant( + pool_state.weights, pool_state.balances_live_scaled18 + ) + elif isinstance(pool_state, StableState): + return compute_stable_invariant( + pool_state.amp, pool_state.balances_live_scaled18 + ) + else: + # For other pool types, return sum of balances as a simple metric + return sum(pool_state.balances_live_scaled18) + + +def format_balance(balance: int) -> str: + """Format balance for display (convert from 18 decimals).""" + return f"{balance / 1e18:,.2f}" + + +def format_invariant(invariant: int) -> str: + """Format invariant for display.""" + return f"{invariant / 1e18:,.4f}" + + +def main(): + print("=" * 80) + print("Balancer Pool Fuzzing Simulator - Invariant Tracking") + print("=" * 80) + print() + + # Load simulation data + print("Loading simulation data...") + pools = read_simulation_data() + + # Find first non-buffer pool + pool_dict = None + pool_name = None + for name, pool in pools.items(): + if pool.get("poolType") != "Buffer": + pool_dict = pool + pool_name = name + break + + if not pool_dict: + print("ERROR: No suitable pool found in simulation data") + return + + print(f"Selected pool: {pool_name}") + print(f"Pool type: {pool_dict['poolType']}") + print(f"Pool address: {pool_dict['poolAddress']}") + print() + + # Load pool state + pool_with_ints = transform_strings_to_ints(pool_dict) + pool_state, hook_state = StateLoader.from_pool_dict(pool_with_ints) + + # Initialize simulator + simulator = PoolSimulator(pool_state, hook_state) + + # Display initial state + print("Initial Pool State:") + print(f" Tokens: {len(pool_state.tokens)}") + for i, (token, balance) in enumerate( + zip(pool_state.tokens, pool_state.balances_live_scaled18) + ): + print(f" Token {i} ({token[:8]}...): {format_balance(balance)}") + + initial_invariant = calculate_invariant(pool_state) + print(f" Initial Invariant: {format_invariant(initial_invariant)}") + print() + + # Set random seed for reproducibility + random.seed(42) + + print("-" * 80) + print("Starting Fuzz Swaps (10 iterations)") + print("-" * 80) + print() + + # Perform 10 fuzz swaps + for iteration in range(10): + current_state = simulator.current_state + + # Randomly select token pair + token_in_idx = random.randint(0, len(current_state.tokens) - 1) + token_out_idx = random.randint(0, len(current_state.tokens) - 1) + + # Ensure different tokens + while token_out_idx == token_in_idx: + token_out_idx = random.randint(0, len(current_state.tokens) - 1) + + token_in = current_state.tokens[token_in_idx] + token_out = current_state.tokens[token_out_idx] + + # Calculate swap amount (1-30% of current balance for safety) + # Note: balance_in is scaled18 with rates applied + balance_in_scaled18 = current_state.balances_live_scaled18[token_in_idx] + + # Calculate 1-30% of the scaled balance + max_amount_scaled18 = int(balance_in_scaled18 * 0.30) + min_amount_scaled18 = int(balance_in_scaled18 * 0.01) # At least 1% + + # Random amount between 1% and 30% (in scaled18) + amount_scaled18 = random.randint(min_amount_scaled18, max_amount_scaled18) + + # Convert back to raw amount (undo scaling and rate) + scaling_factor = current_state.scaling_factors[token_in_idx] + token_rate = current_state.token_rates[token_in_idx] + amount_raw = _to_raw_undo_rate_round_down( + amount_scaled18, scaling_factor, token_rate + ) + + # Create swap input + swap_input = SwapInput( + amount_raw=amount_raw, + swap_kind=SwapKind.GIVENIN, + token_in=token_in, + token_out=token_out, + ) + + # Calculate invariant before swap + invariant_before = calculate_invariant(current_state) + + # Execute swap + try: + result = simulator.swap(swap_input, mode=ExecutionMode.COMMIT) + + # Calculate invariant after swap + invariant_after = calculate_invariant(simulator.current_state) + + # Calculate invariant change + invariant_change = invariant_after - invariant_before + invariant_change_pct = (invariant_change / invariant_before) * 100 + + # Get final balances for display + balance_in_after = simulator.current_state.balances_live_scaled18[ + token_in_idx + ] + balance_out_after = simulator.current_state.balances_live_scaled18[ + token_out_idx + ] + + # Display results + print(f"Swap #{iteration + 1}:") + print(f" Direction: Token {token_in_idx} → Token {token_out_idx}") + print( + f" Amount In: {format_balance(amount_scaled18)} (Token {token_in_idx})" + ) + print( + f" Amount Out: {format_balance(result.amount_calculated_raw)} (Token {token_out_idx})" + ) + print( + f" Balance In: {format_balance(balance_in_scaled18)} → {format_balance(balance_in_after)}" + ) + print( + f" Balance Out: {format_balance(current_state.balances_live_scaled18[token_out_idx])} → {format_balance(balance_out_after)}" + ) + print(f" Invariant Before: {format_invariant(invariant_before)}") + print(f" Invariant After: {format_invariant(invariant_after)}") + print( + f" Invariant Change: {format_invariant(abs(invariant_change))} " + f"({'↓' if invariant_change < 0 else '↑'} {abs(invariant_change_pct):.6f}%)" + ) + print() + + except Exception as e: + print(f"Swap #{iteration + 1}: FAILED - {e}") + print() + + # Final summary + print("-" * 80) + print("Simulation Complete") + print("-" * 80) + final_invariant = calculate_invariant(simulator.current_state) + total_change = final_invariant - initial_invariant + total_change_pct = (total_change / initial_invariant) * 100 + + print(f"Initial Invariant: {format_invariant(initial_invariant)}") + print(f"Final Invariant: {format_invariant(final_invariant)}") + print( + f"Total Change: {format_invariant(abs(total_change))} " + f"({'↓' if total_change < 0 else '↑'} {abs(total_change_pct):.6f}%)" + ) + print() + print(f"Total swaps executed: {len(simulator.get_history())}") + print() + + +if __name__ == "__main__": + main() diff --git a/python/tools/simulator/read_simulation_data.py b/python/tools/simulator/read_simulation_data.py new file mode 100644 index 0000000..d7e5bd8 --- /dev/null +++ b/python/tools/simulator/read_simulation_data.py @@ -0,0 +1,46 @@ +"""Read simulation pool data from simulationData directory.""" + +import json +import os + + +def read_simulation_data(): + """Read pool data from simulationData directory. + + Returns: + Dictionary mapping filenames to pool dictionaries. + Format: {"filename.json": pool_dict, ...} + """ + # Define the directory containing simulation JSON files + relative_path = "../../../testData/simulationData" + absolute_path = os.path.abspath( + os.path.join(os.path.dirname(__file__), relative_path) + ) + + pools = {} + + # Check if directory exists + if not os.path.exists(absolute_path): + raise FileNotFoundError( + f"Simulation data directory not found: {absolute_path}\n" + "Run 'npm run generate:simulator' in testData/ to generate simulation data." + ) + + # Iterate over all files in the directory + for filename in os.listdir(absolute_path): + if filename.endswith(".json"): # Check if the file is a JSON file + filepath = os.path.join(absolute_path, filename) + + with open(filepath) as json_file: + data = json.load(json_file) + # Store only the pool data + if "pool" in data: + pools[filename] = data["pool"] + + if not pools: + raise ValueError( + f"No pool data found in {absolute_path}\n" + "Run 'npm run generate:simulator' in testData/ to generate simulation data." + ) + + return pools diff --git a/python/tools/simulator/simulator.py b/python/tools/simulator/simulator.py new file mode 100644 index 0000000..150bbbb --- /dev/null +++ b/python/tools/simulator/simulator.py @@ -0,0 +1,312 @@ +"""Pool simulator for security analysis and sequential operation testing. + +Provides a stateful wrapper around balancer-maths vault operations. +""" + +from __future__ import annotations + +import uuid +from copy import deepcopy +from typing import Dict, List, Optional + +from src.common.types import ( + AddLiquidityInput, + AddLiquidityResult, + BufferState, + PoolState, + RemoveLiquidityInput, + RemoveLiquidityResult, + SwapInput, + SwapResult, +) +from src.hooks.types import HookState +from src.vault.vault import Vault +from tools.simulator.types import ExecutionMode, OperationRecord, SimulatorConfig + + +class PoolSimulator: + """Stateful wrapper for pool operations enabling security analysis. + + The simulator maintains pool state across operations and supports: + - Sequential swaps, add/remove liquidity + - COMMIT mode (apply state changes) vs SIMULATE mode (query only) + - Snapshot/restore for state rollback + - Operation history tracking + """ + + def __init__( + self, + initial_pool_state: PoolState, + initial_hook_state: Optional[HookState] = None, + config: Optional[SimulatorConfig] = None, + ): + """Initialize the pool simulator. + + Args: + initial_pool_state: Starting pool state + initial_hook_state: Starting hook state (if applicable) + config: Simulator configuration + """ + self._vault = Vault() + self._initial_state = deepcopy(initial_pool_state) + self._current_state = deepcopy(initial_pool_state) + self._hook_state = deepcopy(initial_hook_state) if initial_hook_state else None + self._history: List[OperationRecord] = [] + self._snapshots: Dict[str, PoolState] = {} + self._config = config or SimulatorConfig() + self._sequence_counter = 0 + + # Core operations + + def swap( + self, swap_input: SwapInput, mode: ExecutionMode = ExecutionMode.COMMIT + ) -> SwapResult: + """Execute a swap operation. + + Args: + swap_input: Swap parameters + mode: COMMIT to apply changes, SIMULATE to query only + + Returns: + SwapResult with calculated amounts and updated state + """ + result = self._vault.swap( + swap_input=swap_input, + pool_state=self._current_state, + hook_state=self._hook_state, + ) + + if mode == ExecutionMode.COMMIT: + self._record_operation( + operation_type="swap", + input_data=swap_input, + result=result, + was_committed=True, + ) + assert not isinstance( + result.updated_pool_state, BufferState + ), "Simulator does not support buffer pools" + self._current_state = result.updated_pool_state + else: + self._record_operation( + operation_type="swap", + input_data=swap_input, + result=result, + was_committed=False, + ) + + return result + + def add_liquidity( + self, + add_liquidity_input: AddLiquidityInput, + mode: ExecutionMode = ExecutionMode.COMMIT, + ) -> AddLiquidityResult: + """Execute an add liquidity operation. + + Args: + add_liquidity_input: Add liquidity parameters + mode: COMMIT to apply changes, SIMULATE to query only + + Returns: + AddLiquidityResult with calculated amounts and updated state + """ + result = self._vault.add_liquidity( + add_liquidity_input=add_liquidity_input, + pool_state=self._current_state, + hook_state=self._hook_state, + ) + + if mode == ExecutionMode.COMMIT: + self._record_operation( + operation_type="add_liquidity", + input_data=add_liquidity_input, + result=result, + was_committed=True, + ) + self._current_state = result.updated_pool_state + else: + self._record_operation( + operation_type="add_liquidity", + input_data=add_liquidity_input, + result=result, + was_committed=False, + ) + + return result + + def remove_liquidity( + self, + remove_liquidity_input: RemoveLiquidityInput, + mode: ExecutionMode = ExecutionMode.COMMIT, + ) -> RemoveLiquidityResult: + """Execute a remove liquidity operation. + + Args: + remove_liquidity_input: Remove liquidity parameters + mode: COMMIT to apply changes, SIMULATE to query only + + Returns: + RemoveLiquidityResult with calculated amounts and updated state + """ + result = self._vault.remove_liquidity( + remove_liquidity_input=remove_liquidity_input, + pool_state=self._current_state, + hook_state=self._hook_state, + ) + + if mode == ExecutionMode.COMMIT: + self._record_operation( + operation_type="remove_liquidity", + input_data=remove_liquidity_input, + result=result, + was_committed=True, + ) + self._current_state = result.updated_pool_state + else: + self._record_operation( + operation_type="remove_liquidity", + input_data=remove_liquidity_input, + result=result, + was_committed=False, + ) + + return result + + # State access + + @property + def current_state(self) -> PoolState: + """Get a copy of the current pool state. + + Returns: + Deep copy of current pool state + """ + return deepcopy(self._current_state) + + @property + def initial_state(self) -> PoolState: + """Get a copy of the initial pool state. + + Returns: + Deep copy of initial pool state + """ + return deepcopy(self._initial_state) + + # Snapshots + + def create_snapshot(self, name: Optional[str] = None) -> str: + """Create a snapshot of the current state. + + Args: + name: Optional name for the snapshot (UUID generated if not provided) + + Returns: + Snapshot ID + """ + snapshot_id = name or str(uuid.uuid4()) + self._snapshots[snapshot_id] = deepcopy(self._current_state) + return snapshot_id + + def restore_snapshot(self, snapshot_id: str) -> None: + """Restore state from a snapshot. + + Args: + snapshot_id: ID of snapshot to restore + + Raises: + KeyError: If snapshot_id does not exist + """ + if snapshot_id not in self._snapshots: + raise KeyError(f"Snapshot '{snapshot_id}' not found") + self._current_state = deepcopy(self._snapshots[snapshot_id]) + + def delete_snapshot(self, snapshot_id: str) -> None: + """Delete a snapshot. + + Args: + snapshot_id: ID of snapshot to delete + + Raises: + KeyError: If snapshot_id does not exist + """ + del self._snapshots[snapshot_id] + + def list_snapshots(self) -> List[str]: + """List all snapshot IDs. + + Returns: + List of snapshot IDs + """ + return list(self._snapshots.keys()) + + def reset(self) -> None: + """Reset to initial state and clear history.""" + self._current_state = deepcopy(self._initial_state) + self._history.clear() + self._sequence_counter = 0 + + # History + + def get_history(self) -> List[OperationRecord]: + """Get operation history. + + Returns: + List of operation records + """ + return list(self._history) + + def clear_history(self) -> None: + """Clear operation history.""" + self._history.clear() + self._sequence_counter = 0 + + # Internal methods + + def _record_operation( + self, + operation_type: str, + input_data, + result, + was_committed: bool, + ) -> None: + """Record an operation in the history. + + Args: + operation_type: Type of operation + input_data: Operation input + result: Operation result + was_committed: Whether the operation was committed + """ + if not self._config.track_history: + return + + # Extract balances before operation + balances_before = list(self._current_state.balances_live_scaled18) + total_supply_before = self._current_state.total_supply + + # Extract balances after operation + balances_after = list(result.updated_pool_state.balances_live_scaled18) + total_supply_after = result.updated_pool_state.total_supply + + record = OperationRecord( + sequence=self._sequence_counter, + operation_type=operation_type, + input=input_data, + result=result, + was_committed=was_committed, + balances_before=balances_before, + balances_after=balances_after, + total_supply_before=total_supply_before, + total_supply_after=total_supply_after, + ) + + self._history.append(record) + self._sequence_counter += 1 + + # Enforce max history size + if ( + self._config.max_history_size is not None + and len(self._history) > self._config.max_history_size + ): + self._history.pop(0) diff --git a/python/tools/simulator/state_loader.py b/python/tools/simulator/state_loader.py new file mode 100644 index 0000000..4041e34 --- /dev/null +++ b/python/tools/simulator/state_loader.py @@ -0,0 +1,65 @@ +"""State loader for the pool simulator. + +Loads pool and hook state from testData JSON files. +""" + +from __future__ import annotations + +import json +from test.utils.map_pool_state import map_pool_and_hook_state, transform_strings_to_ints +from typing import Optional, Tuple + +from src.common.types import BufferState, PoolState +from src.hooks.types import HookState + + +class StateLoader: + """Loads pool state from testData JSON files.""" + + @staticmethod + def from_json_file(filepath: str) -> Tuple[PoolState, Optional[HookState]]: + """Load pool state from a testData JSON file. + + Args: + filepath: Path to the JSON file containing pool data + + Returns: + Tuple of (PoolState, Optional[HookState]) + + Raises: + FileNotFoundError: If the file does not exist + json.JSONDecodeError: If the file is not valid JSON + KeyError: If required keys are missing from the JSON + """ + with open(filepath) as f: + data = json.load(f) + + # Transform string amounts to integers + pool_data = transform_strings_to_ints(data) + + # Map to PoolState and HookState + pool_state, hook_state = map_pool_and_hook_state(pool_data) + assert not isinstance( + pool_state, BufferState + ), "Simulator does not support buffer pools" + return pool_state, hook_state + + @staticmethod + def from_pool_dict(pool_dict: dict) -> Tuple[PoolState, Optional[HookState]]: + """Load pool state from a dictionary (e.g., from testData["pools"]["poolName"]). + + Args: + pool_dict: Dictionary containing pool data + + Returns: + Tuple of (PoolState, Optional[HookState]) + """ + # Transform string amounts to integers + pool_data = transform_strings_to_ints(pool_dict) + + # Map to PoolState and HookState + pool_state, hook_state = map_pool_and_hook_state(pool_data) + assert not isinstance( + pool_state, BufferState + ), "Simulator does not support buffer pools" + return pool_state, hook_state diff --git a/python/tools/simulator/test_simulator.py b/python/tools/simulator/test_simulator.py new file mode 100644 index 0000000..d4c2bb8 --- /dev/null +++ b/python/tools/simulator/test_simulator.py @@ -0,0 +1,421 @@ +"""Tests for the pool simulator.""" + +from copy import deepcopy +from test.utils.map_pool_state import transform_strings_to_ints +from test.utils.read_test_data import read_test_data + +import pytest +from simulator.utils import map_state + +from src.common.types import ( + AddLiquidityInput, + AddLiquidityKind, + RemoveLiquidityInput, + RemoveLiquidityKind, + SwapInput, + SwapKind, +) +from tools.simulator import ( + ExecutionMode, + PoolSimulator, + SimulatorConfig, + StateLoader, +) + +# Load test data once for all tests +test_data = read_test_data() + + +class TestStateLoader: + """Tests for StateLoader.""" + + def test_from_pool_dict(self): + """Test loading state from a pool dictionary.""" + # Use first available pool + pool_name = list(test_data["pools"].keys())[0] + pool_dict = test_data["pools"][pool_name] + + # Skip buffer pools + if pool_dict.get("poolType") == "Buffer": + pytest.skip("Buffer pools not supported") + + pool_state, _ = StateLoader.from_pool_dict(pool_dict) + + assert pool_state is not None + assert pool_state.pool_address == pool_dict["poolAddress"] + assert pool_state.pool_type == pool_dict["poolType"] + assert len(pool_state.tokens) == len(pool_dict["tokens"]) + + +class TestPoolSimulator: + """Tests for PoolSimulator.""" + + def _get_test_pool(self): + """Get a test pool that supports all operations.""" + for pool_name, pool_dict in test_data["pools"].items(): + if pool_dict.get("poolType") != "Buffer": + return pool_name, pool_dict + raise ValueError("No suitable test pool found") + + def test_initialization(self): + """Test simulator initialization.""" + _, pool_dict = self._get_test_pool() + pool_state, hook_state = StateLoader.from_pool_dict(pool_dict) + + simulator = PoolSimulator(pool_state, hook_state) + + assert simulator.current_state.pool_address == pool_state.pool_address + assert simulator.initial_state.pool_address == pool_state.pool_address + + def test_swap_commit_mode(self): + """Test swap in COMMIT mode updates state.""" + # Find a swap test + if not test_data["swaps"]: + pytest.skip("No swap tests available") + + swap_test = test_data["swaps"][0] + pool_name = swap_test["test"] + pool_dict = test_data["pools"][pool_name] + + if pool_dict.get("poolType") == "Buffer": + pytest.skip("Buffer pools handled separately") + + pool_with_ints = transform_strings_to_ints(pool_dict) + pool_state, hook_state = map_state(pool_with_ints) + + simulator = PoolSimulator(pool_state, hook_state) + initial_balances = list(simulator.current_state.balances_live_scaled18) + + # Execute swap in COMMIT mode + swap_input = SwapInput( + amount_raw=int(swap_test["amountRaw"]), + swap_kind=SwapKind(swap_test["swapKind"]), + token_in=swap_test["tokenIn"], + token_out=swap_test["tokenOut"], + ) + + _ = simulator.swap(swap_input, mode=ExecutionMode.COMMIT) + + # State should be updated + assert simulator.current_state.balances_live_scaled18 != initial_balances + # History should contain one record + assert len(simulator.get_history()) == 1 + assert simulator.get_history()[0].was_committed is True + assert simulator.get_history()[0].operation_type == "swap" + + def test_swap_simulate_mode(self): + """Test swap in SIMULATE mode does not update state.""" + if not test_data["swaps"]: + pytest.skip("No swap tests available") + + swap_test = test_data["swaps"][0] + pool_name = swap_test["test"] + pool_dict = test_data["pools"][pool_name] + + if pool_dict.get("poolType") == "Buffer": + pytest.skip("Buffer pools handled separately") + + pool_with_ints = transform_strings_to_ints(pool_dict) + pool_state, hook_state = map_state(pool_with_ints) + + simulator = PoolSimulator(pool_state, hook_state) + initial_balances = list(simulator.current_state.balances_live_scaled18) + + # Execute swap in SIMULATE mode + swap_input = SwapInput( + amount_raw=int(swap_test["amountRaw"]), + swap_kind=SwapKind(swap_test["swapKind"]), + token_in=swap_test["tokenIn"], + token_out=swap_test["tokenOut"], + ) + + _ = simulator.swap(swap_input, mode=ExecutionMode.SIMULATE) + + # State should NOT be updated + assert simulator.current_state.balances_live_scaled18 == initial_balances + # History should contain one record + assert len(simulator.get_history()) == 1 + assert simulator.get_history()[0].was_committed is False + + def test_add_liquidity_commit_mode(self): + """Test add liquidity in COMMIT mode updates state.""" + if not test_data["adds"]: + pytest.skip("No add liquidity tests available") + + add_test = test_data["adds"][0] + pool_name = add_test["test"] + pool_dict = test_data["pools"][pool_name] + + if pool_dict.get("poolType") == "Buffer": + pytest.skip("Buffer pools do not support addLiquidity") + + pool_with_ints = transform_strings_to_ints(pool_dict) + pool_state, hook_state = map_state(pool_with_ints) + + simulator = PoolSimulator(pool_state, hook_state) + initial_total_supply = simulator.current_state.total_supply + + # Execute add liquidity + add_input = AddLiquidityInput( + pool=pool_dict["poolAddress"], + max_amounts_in_raw=list(map(int, add_test["inputAmountsRaw"])), + min_bpt_amount_out_raw=int(add_test["bptOutRaw"]), + kind=AddLiquidityKind(add_test["kind"]), + ) + + _ = simulator.add_liquidity(add_input, mode=ExecutionMode.COMMIT) + + # Total supply should increase + assert simulator.current_state.total_supply > initial_total_supply + # History should contain one record + assert len(simulator.get_history()) == 1 + assert simulator.get_history()[0].was_committed is True + assert simulator.get_history()[0].operation_type == "add_liquidity" + + def test_remove_liquidity_commit_mode(self): + """Test remove liquidity in COMMIT mode updates state.""" + if not test_data["removes"]: + pytest.skip("No remove liquidity tests available") + + remove_test = test_data["removes"][0] + pool_name = remove_test["test"] + pool_dict = test_data["pools"][pool_name] + + if pool_dict.get("poolType") == "Buffer": + pytest.skip("Buffer pools do not support removeLiquidity") + + pool_with_ints = transform_strings_to_ints(pool_dict) + pool_state, hook_state = map_state(pool_with_ints) + + simulator = PoolSimulator(pool_state, hook_state) + initial_total_supply = simulator.current_state.total_supply + + # Execute remove liquidity + remove_input = RemoveLiquidityInput( + pool=pool_dict["poolAddress"], + min_amounts_out_raw=list(map(int, remove_test["amountsOutRaw"])), + max_bpt_amount_in_raw=int(remove_test["bptInRaw"]), + kind=RemoveLiquidityKind(remove_test["kind"]), + ) + + _ = simulator.remove_liquidity(remove_input, mode=ExecutionMode.COMMIT) + + # Total supply should decrease + assert simulator.current_state.total_supply < initial_total_supply + # History should contain one record + assert len(simulator.get_history()) == 1 + assert simulator.get_history()[0].was_committed is True + assert simulator.get_history()[0].operation_type == "remove_liquidity" + + def test_snapshot_and_restore(self): + """Test snapshot and restore functionality.""" + if not test_data["swaps"]: + pytest.skip("No swap tests available") + + swap_test = test_data["swaps"][0] + pool_name = swap_test["test"] + pool_dict = test_data["pools"][pool_name] + + if pool_dict.get("poolType") == "Buffer": + pytest.skip("Buffer pools handled separately") + + pool_with_ints = transform_strings_to_ints(pool_dict) + pool_state, hook_state = map_state(pool_with_ints) + + simulator = PoolSimulator(pool_state, hook_state) + + # Create snapshot + snapshot_id = simulator.create_snapshot("before_swap") + initial_balances = deepcopy(simulator.current_state.balances_live_scaled18) + + # Execute swap + swap_input = SwapInput( + amount_raw=int(swap_test["amountRaw"]), + swap_kind=SwapKind(swap_test["swapKind"]), + token_in=swap_test["tokenIn"], + token_out=swap_test["tokenOut"], + ) + simulator.swap(swap_input, mode=ExecutionMode.COMMIT) + + # Balances should have changed + assert simulator.current_state.balances_live_scaled18 != initial_balances + + # Restore snapshot + simulator.restore_snapshot(snapshot_id) + + # Balances should be back to initial + assert simulator.current_state.balances_live_scaled18 == initial_balances + + def test_list_snapshots(self): + """Test listing snapshots.""" + _, pool_dict = self._get_test_pool() + pool_state, hook_state = StateLoader.from_pool_dict(pool_dict) + + simulator = PoolSimulator(pool_state, hook_state) + + # Initially no snapshots + assert len(simulator.list_snapshots()) == 0 + + # Create snapshots + _ = simulator.create_snapshot("snapshot1") + _ = simulator.create_snapshot("snapshot2") + + # Should have 2 snapshots + snapshots = simulator.list_snapshots() + assert len(snapshots) == 2 + assert "snapshot1" in snapshots + assert "snapshot2" in snapshots + + def test_delete_snapshot(self): + """Test deleting a snapshot.""" + _, pool_dict = self._get_test_pool() + pool_state, hook_state = StateLoader.from_pool_dict(pool_dict) + + simulator = PoolSimulator(pool_state, hook_state) + + # Create and delete snapshot + snapshot_id = simulator.create_snapshot("test") + assert snapshot_id in simulator.list_snapshots() + + simulator.delete_snapshot(snapshot_id) + assert snapshot_id not in simulator.list_snapshots() + + def test_reset(self): + """Test reset functionality.""" + if not test_data["swaps"]: + pytest.skip("No swap tests available") + + swap_test = test_data["swaps"][0] + pool_name = swap_test["test"] + pool_dict = test_data["pools"][pool_name] + + if pool_dict.get("poolType") == "Buffer": + pytest.skip("Buffer pools handled separately") + + pool_with_ints = transform_strings_to_ints(pool_dict) + pool_state, hook_state = map_state(pool_with_ints) + + simulator = PoolSimulator(pool_state, hook_state) + initial_balances = deepcopy(simulator.current_state.balances_live_scaled18) + + # Execute swap + swap_input = SwapInput( + amount_raw=int(swap_test["amountRaw"]), + swap_kind=SwapKind(swap_test["swapKind"]), + token_in=swap_test["tokenIn"], + token_out=swap_test["tokenOut"], + ) + simulator.swap(swap_input, mode=ExecutionMode.COMMIT) + + # State should have changed + assert simulator.current_state.balances_live_scaled18 != initial_balances + assert len(simulator.get_history()) > 0 + + # Reset + simulator.reset() + + # State should be back to initial + assert simulator.current_state.balances_live_scaled18 == initial_balances + assert len(simulator.get_history()) == 0 + + def test_clear_history(self): + """Test clearing history.""" + if not test_data["swaps"]: + pytest.skip("No swap tests available") + + swap_test = test_data["swaps"][0] + pool_name = swap_test["test"] + pool_dict = test_data["pools"][pool_name] + + if pool_dict.get("poolType") == "Buffer": + pytest.skip("Buffer pools handled separately") + + pool_with_ints = transform_strings_to_ints(pool_dict) + pool_state, hook_state = map_state(pool_with_ints) + + simulator = PoolSimulator(pool_state, hook_state) + + # Execute swap + swap_input = SwapInput( + amount_raw=int(swap_test["amountRaw"]), + swap_kind=SwapKind(swap_test["swapKind"]), + token_in=swap_test["tokenIn"], + token_out=swap_test["tokenOut"], + ) + simulator.swap(swap_input, mode=ExecutionMode.COMMIT) + + assert len(simulator.get_history()) > 0 + + # Clear history + simulator.clear_history() + + assert len(simulator.get_history()) == 0 + + def test_history_disabled(self): + """Test that history can be disabled.""" + pool_name, pool_dict = self._get_test_pool() + pool_state, hook_state = StateLoader.from_pool_dict(pool_dict) + + # Create simulator with history disabled + config = SimulatorConfig(track_history=False) + simulator = PoolSimulator(pool_state, hook_state, config=config) + + # Find a swap test + if not test_data["swaps"]: + pytest.skip("No swap tests available") + + swap_test = next( + (s for s in test_data["swaps"] if s["test"] == pool_name), None + ) + if swap_test is None: + pytest.skip("No swap test for this pool") + + # Execute swap + swap_input = SwapInput( + amount_raw=int(swap_test["amountRaw"]), + swap_kind=SwapKind(swap_test["swapKind"]), + token_in=swap_test["tokenIn"], + token_out=swap_test["tokenOut"], + ) + simulator.swap(swap_input, mode=ExecutionMode.COMMIT) + + # History should be empty + assert len(simulator.get_history()) == 0 + + def test_sequential_operations(self): + """Test sequential operations maintain correct state.""" + if not test_data["swaps"]: + pytest.skip("No swap tests available") + + swap_test = test_data["swaps"][0] + pool_name = swap_test["test"] + pool_dict = test_data["pools"][pool_name] + + if pool_dict.get("poolType") == "Buffer": + pytest.skip("Buffer pools handled separately") + + pool_with_ints = transform_strings_to_ints(pool_dict) + pool_state, hook_state = map_state(pool_with_ints) + + simulator = PoolSimulator(pool_state, hook_state) + + # Execute swap with smaller amount to avoid depleting pool + # Use 10% of original amount to be safe + small_amount = int(swap_test["amountRaw"]) // 10 + + # Execute multiple swaps + for i in range(3): + swap_input = SwapInput( + amount_raw=small_amount, + swap_kind=SwapKind(swap_test["swapKind"]), + token_in=swap_test["tokenIn"], + token_out=swap_test["tokenOut"], + ) + simulator.swap(swap_input, mode=ExecutionMode.COMMIT) + + # Should have 3 operations in history + assert len(simulator.get_history()) == 3 + + # Each operation should have different sequence numbers + sequences = [op.sequence for op in simulator.get_history()] + assert sequences == [0, 1, 2] diff --git a/python/tools/simulator/types.py b/python/tools/simulator/types.py new file mode 100644 index 0000000..829a494 --- /dev/null +++ b/python/tools/simulator/types.py @@ -0,0 +1,44 @@ +from __future__ import annotations + +from dataclasses import dataclass +from enum import Enum +from typing import List, Optional, Union + +from src.common.types import ( + AddLiquidityInput, + AddLiquidityResult, + RemoveLiquidityInput, + RemoveLiquidityResult, + SwapInput, + SwapResult, +) + + +class ExecutionMode(Enum): + """Execution mode for simulator operations.""" + + COMMIT = "commit" # Apply state changes + SIMULATE = "simulate" # Return result without state change + + +@dataclass +class SimulatorConfig: + """Configuration for the pool simulator.""" + + track_history: bool = True + max_history_size: Optional[int] = None + + +@dataclass +class OperationRecord: + """Record of a single operation in the simulator history.""" + + sequence: int + operation_type: str # "swap", "add_liquidity", "remove_liquidity" + input: Union[SwapInput, AddLiquidityInput, RemoveLiquidityInput] + result: Union[SwapResult, AddLiquidityResult, RemoveLiquidityResult] + was_committed: bool + balances_before: List[int] + balances_after: List[int] + total_supply_before: int + total_supply_after: int diff --git a/python/tools/simulator/utils.py b/python/tools/simulator/utils.py new file mode 100644 index 0000000..216a8a5 --- /dev/null +++ b/python/tools/simulator/utils.py @@ -0,0 +1,39 @@ +from test.utils.map_hook_state import map_hook_state +from test.utils.map_pool_state import map_pool_state + +from common.types import PoolState +from hooks.types import HookState + + +def map_state( + pool: dict, +) -> tuple[PoolState, HookState | None]: + """ + Maps pool data to pool state and hook state (if present). + + Args: + pool: Pool dict from JSON (already converted to ints via transform_strings_to_ints) + + Returns: + Tuple of (pool_state, hook_state or None) + - pool_state: Mapped PoolState + - hook_state: Mapped HookState if hook exists and pool supports it, None otherwise + """ + # First, map the pool state + pool_state = map_pool_state(pool) + + # Check if pool has hook data + hook_data = pool.get("hook") + if not hook_data: + return (pool_state, None) + + # Map the hook state using the centralized mapper + try: + hook_state = map_hook_state(hook_data, pool) + pool_state.hook_type = hook_state.hook_type + return (pool_state, hook_state) + except (KeyError, ValueError) as e: + # If hook mapping fails, raise with context about which pool failed + raise ValueError( + f"Failed to map hook state for pool {pool.get('poolAddress', 'unknown')}: {e}" + ) from e diff --git a/testData/README_SIMULATOR.md b/testData/README_SIMULATOR.md new file mode 100644 index 0000000..ab1d904 --- /dev/null +++ b/testData/README_SIMULATOR.md @@ -0,0 +1,84 @@ +# Simulator Data Configuration + +This directory contains separate configuration and data for the pool simulator to avoid interference with the test suite. + +## Files + +- **`config_simulator.json`** - Configuration file for pools to be used in simulation +- **`simulationData/`** - Directory containing generated pool data for simulation +- **`index_simulator.ts`** - Generation script for simulation data + +## Usage + +### Adding Pools for Simulation + +1. Edit `config_simulator.json` to add pool configurations: + +```json +{ + "poolTests": [ + { + "testName": "Weighted-USDC-DAI", + "chainId": "11155111", + "blockNumber": "7439300", + "poolAddress": "0x86fde41ff01b35846eb2f27868fb2938addd44c4", + "poolType": "WEIGHTED" + }, + { + "testName": "Stable-USDC-USDT", + "chainId": "1", + "blockNumber": "12345678", + "poolAddress": "0x...", + "poolType": "STABLE" + } + ] +} +``` + +2. Generate pool data using the simulator-specific script: + +```bash +cd testData +npm run generate:simulator +# Or to overwrite existing files: +npm run generate:simulator:overwrite +``` + +3. The generated JSON files will be placed in `simulationData/` + +### Loading Simulation Data in Code + +```python +from test.utils.read_test_data import read_test_data + +# Load from simulationData instead of testData +test_data = read_test_data(use_simulation_data=True) + +# Access pools +pool_dict = test_data["pools"]["11155111-7439300-Weighted-USDC-DAI.json"] +``` + +## Separation from Test Suite + +The simulator uses `config_simulator.json` and `simulationData/` to keep simulation pools separate from the test suite's `config.json` and `testData/`. This allows: + +- **Independent pool selection** - Add/remove simulation pools without affecting tests +- **Different block numbers** - Use specific blocks for security analysis +- **Isolation** - Prevent simulation data from interfering with CI/CD test runs + +## Scripts + +Two npm scripts are available: + +- **`generate:simulator`** - Generates simulation data from `config_simulator.json` to `simulationData/` (skips existing files) +- **`generate:simulator:overwrite`** - Same as above but overwrites existing files + +These scripts are separate from the test data generation scripts (`generate` and `generate:overwrite`) which use `config.json` and output to `testData/`. + +## Current Pools + +The following pools are currently configured for simulation: + +- **Weighted-USDC-DAI** (Sepolia, block 7439300) - 50/50 weighted pool for basic testing + +Add more pools by editing `config_simulator.json` and regenerating the data. diff --git a/testData/config_simulator.json b/testData/config_simulator.json new file mode 100644 index 0000000..a0a1a87 --- /dev/null +++ b/testData/config_simulator.json @@ -0,0 +1,11 @@ +{ + "poolTests": [ + { + "testName": "Weighted-USDC-DAI", + "chainId": "11155111", + "blockNumber": "7439300", + "poolAddress": "0x86fde41ff01b35846eb2f27868fb2938addd44c4", + "poolType": "WEIGHTED" + } + ] +} diff --git a/testData/index_simulator.ts b/testData/index_simulator.ts new file mode 100644 index 0000000..817e6b3 --- /dev/null +++ b/testData/index_simulator.ts @@ -0,0 +1,37 @@ +import { generatePoolTestData } from './src/generatePoolTestData'; +import type { Config } from './src/types'; + +const RPC_URL = { + 1: Bun.env.ETHEREUM_RPC_URL, + 11155111: Bun.env.SEPOLIA_RPC_URL, + 8453: Bun.env.BASE_RPC_URL, +}; + +async function generateSimulationData() { + const configFile = './config_simulator.json'; + const config = await readConfig(configFile); + const overWrite = Bun.argv[2] === 'true'; + const outputDir = './simulationData'; // Use simulationData instead of testData + + for (const poolTest of config.poolTests) { + const rpcUrl = RPC_URL[poolTest.chainId]; + if (!rpcUrl) + throw new Error(`Missing RPC env for chain: ${poolTest.chainId}`); + await generatePoolTestData( + { + ...poolTest, + rpcUrl, + }, + overWrite, + outputDir, // Pass custom output directory + ); + } +} + +async function readConfig(path: string) { + const file = Bun.file(path); + const contents = await file.json(); + return contents as Config; +} + +generateSimulationData(); diff --git a/testData/package.json b/testData/package.json index 7bf195f..ee05f3b 100644 --- a/testData/package.json +++ b/testData/package.json @@ -22,9 +22,11 @@ "scripts": { "generate": "bun index.ts false", "generate:overwrite": "bun index.ts true", + "generate:simulator": "bun index_simulator.ts false", + "generate:simulator:overwrite": "bun index_simulator.ts true", "format": "prettier --config .prettierrc 'src/**/*.ts' --write", "lint": "eslint ./src --ext .ts", "lint:fix": "eslint ./src --ext .ts --fix", "build": "tsc" } -} \ No newline at end of file +} diff --git a/testData/simulationData/11155111-7439300-Weighted-USDC-DAI.json b/testData/simulationData/11155111-7439300-Weighted-USDC-DAI.json new file mode 100644 index 0000000..ed1d79c --- /dev/null +++ b/testData/simulationData/11155111-7439300-Weighted-USDC-DAI.json @@ -0,0 +1,31 @@ +{ + "pool": { + "chainId": "11155111", + "blockNumber": "7439300", + "poolType": "WEIGHTED", + "poolAddress": "0x86fde41ff01b35846eb2f27868fb2938addd44c4", + "tokens": [ + "0x94a9D9AC8a22534E3FaCa9F4e7F2E2cf85d5E4C8", + "0xFF34B3d4Aee8ddCd6F9AFFFB6Fe49bD371b8a357" + ], + "scalingFactors": [ + "1000000000000", + "1" + ], + "weights": [ + "500000000000000000", + "500000000000000000" + ], + "swapFee": "10000000000000000", + "totalSupply": "6565147517543863649467", + "balancesLiveScaled18": [ + "6916384366000000000000", + "6240659067374271172646" + ], + "tokenRates": [ + "1000000000000000000", + "1000000000000000000" + ], + "aggregateSwapFee": "0" + } +} \ No newline at end of file diff --git a/testData/src/generatePoolTestData.ts b/testData/src/generatePoolTestData.ts index 19f7410..c9d0da5 100644 --- a/testData/src/generatePoolTestData.ts +++ b/testData/src/generatePoolTestData.ts @@ -7,8 +7,9 @@ import { getRemoveLiquiditys } from './getRemoves'; export async function generatePoolTestData( input: TestInput, overwrite = false, + outputDir = './testData', ) { - const path = `./testData/${input.chainId}-${input.blockNumber}-${input.testName}.json`; + const path = `${outputDir}/${input.chainId}-${input.blockNumber}-${input.testName}.json`; if (!overwrite) { const file = Bun.file(path); if (await file.exists()) { diff --git a/testData/src/liquidityBootstrappingPool.ts b/testData/src/liquidityBootstrappingPool.ts index 57bdecf..0599491 100644 --- a/testData/src/liquidityBootstrappingPool.ts +++ b/testData/src/liquidityBootstrappingPool.ts @@ -9,6 +9,7 @@ import { CHAINS } from '@balancer/sdk'; import { VAULT_V3, vaultExtensionAbi_V3 } from '@balancer/sdk'; import { liquidityBootstrappingAbi } from './abi/liquidityBootstrapping'; import { TransformBigintToString } from './types'; +import { vaultExplorerAbi } from './abi/vaultExplorer'; export type LBPoolImmutableData = { tokens: string[]; @@ -99,6 +100,7 @@ export class LiquidityBootstrappingPool { weights: string[]; swapFee: string; tokenRates: string[]; + aggregateSwapFee: string; currentTimestamp: string; } > @@ -116,8 +118,15 @@ export class LiquidityBootstrappingPool { args: [address], } as const; + const poolConfigCall = { + address: this.vault, + abi: vaultExplorerAbi, + functionName: 'getPoolConfig', + args: [address], + } as const; + const multicallResult = await this.client.multicall({ - contracts: [dynamicDataCall, tokenRatesCall], + contracts: [dynamicDataCall, tokenRatesCall, poolConfigCall], allowFailure: false, blockNumber, }); @@ -135,6 +144,9 @@ export class LiquidityBootstrappingPool { const tokenRates = multicallResult[1][1] as bigint[]; + const aggregateSwapFee = + multicallResult[2].aggregateSwapFeePercentage.toString(); + const { timestamp } = await this.client.getBlock({ blockNumber }); return { @@ -147,6 +159,7 @@ export class LiquidityBootstrappingPool { isPoolPaused, isPoolInRecoveryMode, isSwapEnabled, + aggregateSwapFee, currentTimestamp: timestamp.toString(), }; } diff --git a/testData/src/quantAmm.ts b/testData/src/quantAmm.ts index 9ec3e55..d1f13fa 100644 --- a/testData/src/quantAmm.ts +++ b/testData/src/quantAmm.ts @@ -21,6 +21,7 @@ export interface QuantAmmImmutableData { maxTradeSizeRatio: bigint; scalingFactors: bigint[]; swapFee: bigint; + aggregateSwapFee: bigint; } export class QuantAmmPool { @@ -64,11 +65,20 @@ export class QuantAmmPool { blockNumber, } as const; + // TODO: check if it makes more sense to move this to mutable data queries + const poolConfigCall = { + address: this.vault, + abi: vaultExplorerAbi, + functionName: 'getPoolConfig', + args: [address], + } as const; + const multicallResult = await this.client.multicall({ contracts: [ immutableDataCall, scalingFactorsCall, staticSwapFeePercentageCall, + poolConfigCall, ], allowFailure: false, blockNumber, @@ -81,6 +91,8 @@ export class QuantAmmPool { maxTradeSizeRatio: multicallResult[0].maxTradeSizeRatio.toString(), scalingFactors: multicallResult[1][0].map((sf) => sf.toString()), swapFee: multicallResult[2].toString(), + aggregateSwapFee: + multicallResult[3].aggregateSwapFeePercentage.toString(), }; } diff --git a/testData/testData/1-22524240-QuantAMM.json b/testData/testData/1-22524240-QuantAMM.json index 7e49de8..810fd02 100644 --- a/testData/testData/1-22524240-QuantAMM.json +++ b/testData/testData/1-22524240-QuantAMM.json @@ -122,6 +122,7 @@ "1000000000000" ], "swapFee": "20000000000000000", + "aggregateSwapFee": "500000000000000000", "balancesLiveScaled18": [ "900794470000000000", "1304051331499334098", diff --git a/testData/testData/11155111-8085514-LBP-BAL-DAI.json b/testData/testData/11155111-8085514-LBP-BAL-DAI.json index 50cef54..3984743 100644 --- a/testData/testData/11155111-8085514-LBP-BAL-DAI.json +++ b/testData/testData/11155111-8085514-LBP-BAL-DAI.json @@ -72,6 +72,7 @@ "isPoolPaused": false, "isPoolInRecoveryMode": false, "isSwapEnabled": true, + "aggregateSwapFee": "500000000000000000", "currentTimestamp": "1744221012" } } \ No newline at end of file