Skip to content
Open
6 changes: 3 additions & 3 deletions .github/workflows/python-check.yml
Original file line number Diff line number Diff line change
Expand Up @@ -45,15 +45,15 @@ jobs:

- name: Check formatting with black
working-directory: ./python
run: black --check ./src ./test
run: black --check ./src ./test ./tools

- name: Sort imports with isort
working-directory: ./python
run: isort --check-only --diff ./src ./test
run: isort --check-only --diff ./src ./test ./tools

- name: Type checking with mypy
working-directory: ./python
run: mypy --ignore-missing-imports ./src ./test
run: mypy --ignore-missing-imports ./src ./test ./tools

- name: Run tests with pytest
working-directory: ./python
Expand Down
11 changes: 10 additions & 1 deletion python/src/common/types.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from __future__ import annotations

from dataclasses import dataclass
from enum import Enum

from src.common.base_pool_state import BasePoolState
from src.pools.buffer.buffer_data import BufferState
from src.pools.gyro.gyro_2clp_data import Gyro2CLPState
from src.pools.gyro.gyro_eclp_data import GyroECLPState
from src.pools.quantamm.quantamm_data import QuantAmmState
Expand All @@ -28,6 +31,8 @@ class AddLiquidityInput:
class AddLiquidityResult:
bpt_amount_out_raw: int
amounts_in_raw: list[int]
updated_pool_state: PoolState
swap_fee_amounts_scaled18: list[int]


class RemoveLiquidityKind(Enum):
Expand All @@ -48,6 +53,8 @@ class RemoveLiquidityInput:
class RemoveLiquidityResult:
bpt_amount_in_raw: int
amounts_out_raw: list[int]
updated_pool_state: PoolState
swap_fee_amounts_scaled18: list[int]


class SwapKind(Enum):
Expand All @@ -65,7 +72,9 @@ class SwapInput:

@dataclass
class SwapResult:
amount_out_raw: int
amount_calculated_raw: int
updated_pool_state: PoolState | BufferState
swap_fee_amount_scaled18: int


PoolState = (
Expand Down
11 changes: 11 additions & 0 deletions python/src/vault/add_liquidity.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from dataclasses import replace

from src.common.base_pool_math import (
compute_add_liquidity_single_token_exact_out,
compute_add_liquidity_unbalanced,
Expand Down Expand Up @@ -149,7 +151,16 @@ def add_liquidity(
for i, a in enumerate(after_add_result.hook_adjusted_amounts_in_raw):
amounts_in_raw[i] = a

# Create updated pool state with new balances and total supply
updated_pool_state = replace(
pool_state,
balances_live_scaled18=updated_balances_live_scaled18,
total_supply=pool_state.total_supply + bpt_amount_out,
)

return AddLiquidityResult(
bpt_amount_out_raw=bpt_amount_out,
amounts_in_raw=amounts_in_raw,
updated_pool_state=updated_pool_state,
swap_fee_amounts_scaled18=swap_fee_amounts_scaled18,
)
20 changes: 19 additions & 1 deletion python/src/vault/remove_liquidity.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from dataclasses import replace

from src.common.base_pool_math import (
compute_proportional_amounts_out,
compute_remove_liquidity_single_token_exact_in,
Expand All @@ -16,6 +18,7 @@
_get_single_input_index,
_require_unbalanced_liquidity_enabled,
_to_raw_undo_rate_round_down,
_to_scaled_18_apply_rate_round_down,
)
from src.hooks.types import HookBase, HookState

Expand Down Expand Up @@ -127,14 +130,20 @@ def remove_liquidity(

# A Pool's token balance always decreases after an exit
# Computes protocol and pool creator fee which is eventually taken from pool balance
aggregate_swap_fee_amount_scaled18 = _compute_and_charge_aggregate_swap_fees(
aggregate_swap_fee_amount_raw = _compute_and_charge_aggregate_swap_fees(
swap_fee_amounts_scaled18[i],
pool_state.aggregate_swap_fee,
pool_state.scaling_factors,
pool_state.token_rates,
i,
)

aggregate_swap_fee_amount_scaled18 = _to_scaled_18_apply_rate_round_down(
aggregate_swap_fee_amount_raw,
pool_state.scaling_factors[i],
pool_state.token_rates[i],
)

updated_balances_live_scaled18[i] = updated_balances_live_scaled18[i] - (
amounts_out_scaled18[i] + aggregate_swap_fee_amount_scaled18
)
Expand Down Expand Up @@ -163,7 +172,16 @@ def remove_liquidity(
for i, a in enumerate(after_remove_result.hook_adjusted_amounts_out_raw):
amounts_out_raw[i] = a

# Create updated pool state with new balances and total supply
updated_pool_state = replace(
pool_state,
balances_live_scaled18=updated_balances_live_scaled18,
total_supply=pool_state.total_supply - bpt_amount_in,
)

return RemoveLiquidityResult(
bpt_amount_in_raw=bpt_amount_in,
amounts_out_raw=amounts_out_raw,
updated_pool_state=updated_pool_state,
swap_fee_amounts_scaled18=swap_fee_amounts_scaled18,
)
25 changes: 21 additions & 4 deletions python/src/vault/swap.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from dataclasses import replace

from src.common.constants import WAD
from src.common.maths import complement_fixed, mul_div_up_fixed, mul_up_fixed
from src.common.pool_base import PoolBase
from src.common.swap_params import SwapParams
from src.common.types import PoolState, SwapInput, SwapKind
from src.common.types import PoolState, SwapInput, SwapKind, SwapResult
from src.common.utils import (
_compute_and_charge_aggregate_swap_fees,
_to_raw_undo_rate_round_down,
Expand All @@ -22,7 +24,7 @@ def swap(
pool_class: PoolBase,
hook_class: HookBase,
hook_state: HookState | object | None,
) -> int:
) -> SwapResult:
input_index = find_case_insensitive_index_in_list(
pool_state.tokens, swap_input.token_in
)
Expand Down Expand Up @@ -123,14 +125,20 @@ def swap(
pool_state.token_rates[input_index],
)

aggregate_swap_fee_amount_scaled18 = _compute_and_charge_aggregate_swap_fees(
aggregate_swap_fee_amount_raw = _compute_and_charge_aggregate_swap_fees(
total_swap_fee_amount_scaled18,
pool_state.aggregate_swap_fee,
pool_state.scaling_factors,
pool_state.token_rates,
input_index,
)

aggregate_swap_fee_amount_scaled18 = _to_scaled_18_apply_rate_round_down(
aggregate_swap_fee_amount_raw,
pool_state.scaling_factors[input_index],
pool_state.token_rates[input_index],
)

# For ExactIn, we increase the tokenIn balance by `amountIn`,
# and decrease the tokenOut balance by the
# (`amountOut` + fees).
Expand Down Expand Up @@ -185,7 +193,16 @@ def swap(
after_swap_result.hook_adjusted_amount_calculated_raw
)

return amount_calculated_raw
# Create updated pool state with new balances
updated_pool_state = replace(
pool_state, balances_live_scaled18=updated_balances_live_scaled18
)

return SwapResult(
amount_calculated_raw=amount_calculated_raw,
updated_pool_state=updated_pool_state,
swap_fee_amount_scaled18=total_swap_fee_amount_scaled18,
)


def _compute_amount_given_scaled18(
Expand Down
18 changes: 15 additions & 3 deletions python/src/vault/vault.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from dataclasses import replace
from typing import Dict, Optional, Type

from src.common.pool_base import PoolBase
Expand All @@ -8,6 +9,7 @@
RemoveLiquidityInput,
RemoveLiquidityResult,
SwapInput,
SwapResult,
)
from src.hooks.default_hook import DefaultHook
from src.hooks.exit_fee.exit_fee_hook import ExitFeeHook
Expand Down Expand Up @@ -60,13 +62,23 @@ def swap(
swap_input: SwapInput,
pool_state: PoolState | BufferState,
hook_state: HookState | object | None = None,
) -> int:
) -> SwapResult:
if swap_input.amount_raw == 0:
return 0
return SwapResult(
amount_calculated_raw=0,
updated_pool_state=pool_state,
swap_fee_amount_scaled18=0,
)

# buffer is handled separately than a "normal" pool
if isinstance(pool_state, BufferState):
return erc4626_buffer_wrap_or_unwrap(swap_input, pool_state)
amount_out = erc4626_buffer_wrap_or_unwrap(swap_input, pool_state)
# Buffer state doesn't change in the same way, but we still return it
return SwapResult(
amount_calculated_raw=amount_out,
updated_pool_state=pool_state,
swap_fee_amount_scaled18=0,
)
pool_class = self._get_pool(pool_state=pool_state)
hook_class = self._get_hook(
hook_name=pool_state.hook_type, hook_state=hook_state
Expand Down
8 changes: 4 additions & 4 deletions python/test/hooks/test_after_swap.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,12 +136,12 @@ def test_hook_after_swap_no_fee():
]
)
custom_state_no_fee = map_custom_pool_state({**pool, "aggregateSwapFee": 0})
test = vault.swap(
swap_result = vault.swap(
swap_input=swap_input,
pool_state=custom_state_no_fee,
hook_state=input_hook_state,
)
assert test == 1
assert swap_result.amount_calculated_raw == 1


def test_hook_after_swap_with_fee():
Expand All @@ -160,9 +160,9 @@ def test_hook_after_swap_with_fee():
custom_state_with_fee = map_custom_pool_state(
{**pool, "aggregateSwapFee": 500000000000000000}
)
test = vault.swap(
swap_result = vault.swap(
swap_input=swap_input,
pool_state=custom_state_with_fee,
hook_state=input_hook_state,
)
assert test == 1
assert swap_result.amount_calculated_raw == 1
4 changes: 2 additions & 2 deletions python/test/hooks/test_before_swap.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,9 @@ def test_before_swap():
balance_change=[1000000000000000000, 1000000000000000000]
)
weighted_state = map_weighted_state(pool)
test = vault.swap(
swap_result = vault.swap(
swap_input=swap_input,
pool_state=weighted_state,
hook_state=input_hook_state,
)
assert test == 89999999
assert swap_result.amount_calculated_raw == 89999999
12 changes: 6 additions & 6 deletions python/test/hooks/test_stable_surge.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,10 @@ def test_below_surge_threshold_static_swap_fee_case1():
token_in=pool_state["tokens"][0],
token_out=pool_state["tokens"][1],
)
output_amount = vault.swap(
swap_result = vault.swap(
swap_input=swap_input, pool_state=stable_state, hook_state=hook_state
)
assert output_amount == 78522716365403684
assert swap_result.amount_calculated_raw == 78522716365403684


def test_below_surge_threshold_static_swap_fee_case2():
Expand All @@ -54,10 +54,10 @@ def test_below_surge_threshold_static_swap_fee_case2():
token_in=pool_state["tokens"][0],
token_out=pool_state["tokens"][1],
)
output_amount = vault.swap(
swap_result = vault.swap(
swap_input=swap_input, pool_state=stable_state, hook_state=hook_state
)
assert output_amount == 452983383563178802
assert swap_result.amount_calculated_raw == 452983383563178802


def test_above_surge_threshold_uses_surge_fee():
Expand All @@ -67,7 +67,7 @@ def test_above_surge_threshold_uses_surge_fee():
token_in=pool_state["tokens"][1],
token_out=pool_state["tokens"][0],
)
output_amount = vault.swap(
swap_result = vault.swap(
swap_input=swap_input, pool_state=stable_state, hook_state=hook_state
)
assert output_amount == 3252130027531260
assert swap_result.amount_calculated_raw == 3252130027531260
24 changes: 24 additions & 0 deletions python/test/test_add_liquidity.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
transform_strings_to_ints,
)
from test.utils.read_test_data import read_test_data
from test.utils.validate_balances import build_add_liquidity_deltas, validate_balances
from typing import cast

from src.common.base_pool_state import BasePoolState
from src.common.types import AddLiquidityInput, AddLiquidityKind, PoolState
from src.vault.vault import Vault

Expand Down Expand Up @@ -42,3 +44,25 @@ def test_add_liquidity():
assert calculated_amount.amounts_in_raw == list(
map(int, add_test["inputAmountsRaw"])
)

# Validate updated balances
# Skip validation if hook has shouldCallComputeDynamicSwapFee enabled
should_validate = True
if isinstance(pool_state, BasePoolState):
hook = vault._get_hook(
hook_name=pool_state.hook_type, hook_state=hook_state
)
should_validate = not hook.should_call_compute_dynamic_swap_fee

if should_validate:
deltas = build_add_liquidity_deltas(
pool_state=pool_state,
amounts_in_raw=calculated_amount.amounts_in_raw,
swap_fee_amounts_scaled18=calculated_amount.swap_fee_amounts_scaled18,
)
if deltas is not None:
validate_balances(
initial_pool_state=pool_state,
updated_pool_state=calculated_amount.updated_pool_state,
amount_deltas_raw=deltas,
)
4 changes: 2 additions & 2 deletions python/test/test_custom_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def test_custom_pool():
}
custom_pool_state = map_custom_pool_state(pool_state)
vault = Vault(custom_pool_classes={"CustomPool": CustomPool})
calculated_amount = vault.swap(
swap_result = vault.swap(
swap_input=SwapInput(
amount_raw=1000000000000000000,
token_in="0x7b79995e5f793A07Bc00c21412e50Ecae098E7f9",
Expand All @@ -64,7 +64,7 @@ def test_custom_pool():
),
pool_state=custom_pool_state,
)
assert calculated_amount == custom_pool_state.randoms[0]
assert swap_result.amount_calculated_raw == custom_pool_state.randoms[0]


class CustomPool(PoolBase):
Expand Down
Loading
Loading