diff --git a/cosmpy/aerial/client/__init__.py b/cosmpy/aerial/client/__init__.py index edbf6716..374a5203 100644 --- a/cosmpy/aerial/client/__init__.py +++ b/cosmpy/aerial/client/__init__.py @@ -53,6 +53,8 @@ from cosmpy.aerial.config import NetworkConfig from cosmpy.aerial.exceptions import NotFoundError, QueryTimeoutError from cosmpy.aerial.gas import GasStrategy, SimulationGasStrategy +from cosmpy.aerial.query_client import wrap_query_client +from cosmpy.aerial.query_context import ResponseQueryContext from cosmpy.aerial.tx import Transaction, TxState from cosmpy.aerial.tx_helpers import MessageLog, SubmittedTx, TxResponse, safe_decode from cosmpy.aerial.types import Account, Block, NodeInfo @@ -148,6 +150,11 @@ def __init__( cfg.validate() self._network_config = cfg self._gas_strategy: GasStrategy = SimulationGasStrategy(self) + self._init_clients() + + def _init_clients(self): + """Initialize transport-specific module clients.""" + cfg = self.network_config parsed_url = parse_url(cfg.url) @@ -162,27 +169,27 @@ def __init__( else: grpc_client = grpc.insecure_channel(parsed_url.host_and_port) - self.wasm = CosmWasmGrpcClient(grpc_client) - self.auth = AuthGrpcClient(grpc_client) - self.txs = TxGrpcClient(grpc_client) - self.bank = BankGrpcClient(grpc_client) - self.staking = StakingGrpcClient(grpc_client) - self.distribution = DistributionGrpcClient(grpc_client) - self.params = QueryParamsGrpcClient(grpc_client) - self.consensus = QueryConsensusGrpcClient(grpc_client) - self.tendermint = TendermintQueryGrpcClient(grpc_client) + self.wasm = wrap_query_client(CosmWasmGrpcClient(grpc_client)) + self.auth = wrap_query_client(AuthGrpcClient(grpc_client)) + self.txs = wrap_query_client(TxGrpcClient(grpc_client)) + self.bank = wrap_query_client(BankGrpcClient(grpc_client)) + self.staking = wrap_query_client(StakingGrpcClient(grpc_client)) + self.distribution = wrap_query_client(DistributionGrpcClient(grpc_client)) + self.params = wrap_query_client(QueryParamsGrpcClient(grpc_client)) + self.consensus = wrap_query_client(QueryConsensusGrpcClient(grpc_client)) + self.tendermint = wrap_query_client(TendermintQueryGrpcClient(grpc_client)) else: rest_client = RestClient(parsed_url.rest_url) - self.wasm = CosmWasmRestClient(rest_client) # type: ignore - self.auth = AuthRestClient(rest_client) # type: ignore - self.txs = TxRestClient(rest_client) # type: ignore - self.bank = BankRestClient(rest_client) # type: ignore - self.staking = StakingRestClient(rest_client) # type: ignore - self.distribution = DistributionRestClient(rest_client) # type: ignore - self.params = ParamsRestClient(rest_client) # type: ignore - self.consensus = ConsensusRestClient(rest_client) # type: ignore - self.tendermint = TendermintRestClient(rest_client) # type: ignore + self.wasm = wrap_query_client(CosmWasmRestClient(rest_client)) # type: ignore + self.auth = wrap_query_client(AuthRestClient(rest_client)) # type: ignore + self.txs = wrap_query_client(TxRestClient(rest_client)) # type: ignore + self.bank = wrap_query_client(BankRestClient(rest_client)) # type: ignore + self.staking = wrap_query_client(StakingRestClient(rest_client)) # type: ignore + self.distribution = wrap_query_client(DistributionRestClient(rest_client)) # type: ignore + self.params = wrap_query_client(ParamsRestClient(rest_client)) # type: ignore + self.consensus = wrap_query_client(ConsensusRestClient(rest_client)) # type: ignore + self.tendermint = wrap_query_client(TendermintRestClient(rest_client)) # type: ignore @property def network_config(self) -> NetworkConfig: @@ -211,15 +218,18 @@ def gas_strategy(self, strategy: GasStrategy): raise RuntimeError("Invalid strategy must implement GasStrategy interface") self._gas_strategy = strategy - def query_account(self, address: Address) -> Account: + def query_account( + self, address: Address, ctx: Optional[ResponseQueryContext] = None + ) -> Account: """Query account. :param address: address + :param ctx: Optional QueryContext :raises RuntimeError: Unexpected account type returned from query :return: account details """ request = QueryAccountRequest(address=str(address)) - response = self.auth.Account(request) + response = self.auth.Account(request, ctx=ctx) account = BaseAccount() if not response.account.Is(BaseAccount.DESCRIPTOR): @@ -232,25 +242,33 @@ def query_account(self, address: Address) -> Account: sequence=account.sequence, ) - def query_params(self, subspace: str, key: str) -> Any: + def query_params( + self, + subspace: str, + key: str, + ctx: Optional[ResponseQueryContext] = None, + ) -> Any: """Query Prams. :param subspace: subspace :param key: key + :param ctx: Optional QueryContext :return: Query params """ req = QueryParamsRequest(subspace=subspace, key=key) - resp = self.params.Params(req) + resp = self.params.Params(req, ctx=ctx) return json.loads(resp.param.value) - def query_node_info(self) -> NodeInfo: + def query_node_info(self, ctx: Optional[ResponseQueryContext] = None) -> NodeInfo: """ Query basic Tendermint / node information (moniker, chain-id, version, etc.). + :param ctx: Optional QueryContext + :return: NodeInfo. """ request = GetNodeInfoRequest() - response = self.tendermint.GetNodeInfo(request) + response = self.tendermint.GetNodeInfo(request, ctx=ctx) cosmos_sdk_version = Version( response.application_version.cosmos_sdk_version.lstrip("v") @@ -264,18 +282,25 @@ def query_node_info(self) -> NodeInfo: app_version=app_version, ) - def query_consensus_params(self) -> Any: + def query_consensus_params(self, ctx: Optional[ResponseQueryContext] = None) -> Any: """Query consensus params. + :param ctx: Optional QueryContext :return: Query consensus params """ req = QueryParamsRequest() - resp = self.consensus.Params(req) + resp = self.consensus.Params(req, ctx=ctx) return resp - def query_bank_balance(self, address: Address, denom: Optional[str] = None) -> int: + def query_bank_balance( + self, + address: Address, + denom: Optional[str] = None, + ctx: Optional[ResponseQueryContext] = None, + ) -> int: """Query bank balance. + :param ctx: Optional QueryContext :param address: address :param denom: denom, defaults to None :return: bank balance @@ -287,19 +312,22 @@ def query_bank_balance(self, address: Address, denom: Optional[str] = None) -> i denom=denom, ) - resp = self.bank.Balance(req) + resp = self.bank.Balance(req, ctx=ctx) assert resp.balance.denom == denom # sanity check return int(resp.balance.amount) - def query_bank_all_balances(self, address: Address) -> List[Coin]: + def query_bank_all_balances( + self, address: Address, ctx: Optional[ResponseQueryContext] = None + ) -> List[Coin]: """Query bank all balances. + :param ctx: Optional QueryContext :param address: address :return: bank all balances """ req = QueryAllBalancesRequest(address=str(address)) - resp = self.bank.AllBalances(req) + resp = self.bank.AllBalances(req, ctx=ctx) return [Coin(amount=coin.amount, denom=coin.denom) for coin in resp.balances] @@ -340,11 +368,14 @@ def send_tokens( ) def query_validators( - self, status: Optional[ValidatorStatus] = None + self, + status: Optional[ValidatorStatus] = None, + ctx: Optional[ResponseQueryContext] = None, ) -> List[Validator]: """Query validators. :param status: validator status, defaults to None + :param ctx: Optional QueryContext :return: List of validators """ filtered_status = status or ValidatorStatus.BONDED @@ -353,7 +384,7 @@ def query_validators( if filtered_status != ValidatorStatus.UNSPECIFIED: req.status = filtered_status.value - resp = self.staking.Validators(req) + resp = self.staking.Validators(req, ctx=ctx) validators: List[Validator] = [] for validator in resp.validators: @@ -367,10 +398,13 @@ def query_validators( ) return validators - def query_staking_summary(self, address: Address) -> StakingSummary: + def query_staking_summary( + self, address: Address, ctx: Optional[ResponseQueryContext] = None + ) -> StakingSummary: """Query staking summary. :param address: address + :param ctx: Optional QueryContext :return: staking summary """ current_positions: List[StakingPosition] = [] @@ -378,14 +412,14 @@ def query_staking_summary(self, address: Address) -> StakingSummary: req = QueryDelegatorDelegationsRequest(delegator_addr=str(address)) for resp in get_paginated( - req, self.staking.DelegatorDelegations, per_page_limit=1 + req, self.staking.DelegatorDelegations, per_page_limit=1, ctx=ctx ): for item in resp.delegation_responses: req = QueryDelegationRewardsRequest( delegator_address=str(address), validator_address=str(item.delegation.validator_address), ) - rewards_resp = self.distribution.DelegationRewards(req) + rewards_resp = self.distribution.DelegationRewards(req, ctx=ctx) stake_reward_dec = Decimal(0) stake_reward = 0 @@ -407,7 +441,9 @@ def query_staking_summary(self, address: Address) -> StakingSummary: unbonding_summary: Dict[str, int] = {} req = QueryDelegatorUnbondingDelegationsRequest(delegator_addr=str(address)) - for resp in get_paginated(req, self.staking.DelegatorUnbondingDelegations): + for resp in get_paginated( + req, self.staking.DelegatorUnbondingDelegations, ctx=ctx + ): for item in resp.unbonding_responses: validator = str(item.validator_address) total_unbonding = unbonding_summary.get(validator, 0) @@ -611,12 +647,14 @@ def wait_for_query_tx( tx_hash: str, timeout: Optional[timedelta] = None, poll_period: Optional[timedelta] = None, + ctx: Optional[ResponseQueryContext] = None, ) -> TxResponse: """Wait for query transaction. :param tx_hash: transaction hash :param timeout: timeout, defaults to None :param poll_period: poll_period, defaults to None + :param ctx: Optional QueryContext :raises QueryTimeoutError: timeout @@ -636,7 +674,7 @@ def wait_for_query_tx( start = datetime.now() while True: try: - return self.query_tx(tx_hash) + return self.query_tx(tx_hash, ctx=ctx) except NotFoundError: pass @@ -646,17 +684,20 @@ def wait_for_query_tx( time.sleep(poll_period.total_seconds()) - def query_tx(self, tx_hash: str) -> TxResponse: + def query_tx( + self, tx_hash: str, ctx: Optional[ResponseQueryContext] = None + ) -> TxResponse: """query transaction. :param tx_hash: transaction hash + :param ctx: Optional QueryContext :raises NotFoundError: Tx details not found :raises grpc.RpcError: RPC connection issue :return: query response """ req = GetTxRequest(hash=tx_hash) try: - resp = self.txs.GetTx(req) + resp = self.txs.GetTx(req, ctx=ctx) except grpc.RpcError as e: details = e.details() if "not found" in details: @@ -744,35 +785,39 @@ def broadcast_tx(self, tx: Transaction) -> SubmittedTx: return SubmittedTx(self, tx_digest) - def query_latest_block(self) -> Block: + def query_latest_block(self, ctx: Optional[ResponseQueryContext] = None) -> Block: """Query the latest block. + :param ctx: Optional QueryContext :return: latest block """ req = GetLatestBlockRequest() - resp = self.tendermint.GetLatestBlock(req) + resp = self.tendermint.GetLatestBlock(req, ctx=ctx) return Block.from_proto(resp.block) - def query_block(self, height: int) -> Block: + def query_block( + self, height: int, ctx: Optional[ResponseQueryContext] = None + ) -> Block: """Query the block. :param height: block height + :param ctx: Optional QueryContext :return: block """ req = GetBlockByHeightRequest(height=height) - resp = self.tendermint.GetBlockByHeight(req) + resp = self.tendermint.GetBlockByHeight(req, ctx=ctx) return Block.from_proto(resp.block) - def query_height(self) -> int: + def query_height(self, ctx: Optional[ResponseQueryContext] = None) -> int: """Query the latest block height. :return: latest block height """ - return self.query_latest_block().height + return self.query_latest_block(ctx=ctx).height - def query_chain_id(self) -> str: + def query_chain_id(self, ctx: Optional[ResponseQueryContext] = None) -> str: """Query the chain id. - + :param ctx: Optional QueryContext :return: chain id """ - return self.query_latest_block().chain_id + return self.query_latest_block(ctx=ctx).chain_id diff --git a/cosmpy/aerial/client/utils.py b/cosmpy/aerial/client/utils.py index 09434d5d..03e7f4c0 100644 --- a/cosmpy/aerial/client/utils.py +++ b/cosmpy/aerial/client/utils.py @@ -170,6 +170,7 @@ def get_paginated( request_method: Callable, pages_limit: int = 0, per_page_limit: Optional[int] = DEFAULT_PER_PAGE_LIMIT, + ctx: Optional[Any] = None, ) -> List[Any]: """ Get pages for specific request. @@ -178,6 +179,7 @@ def get_paginated( :param request_method: function to perform request :param pages_limit: max number of pages to return. default - 0 unlimited :param per_page_limit: Optional int: amount of records per one page. default is None, determined by server + :param ctx: optional query context :return: List of responses """ @@ -189,7 +191,11 @@ def get_paginated( request.CopyFrom(initial_request) request.pagination.CopyFrom(pagination) - resp = request_method(request) + resp = ( + request_method(request, ctx=ctx) + if ctx is not None + else request_method(request) + ) pages.append(resp) diff --git a/cosmpy/aerial/grpc/__init__.py b/cosmpy/aerial/grpc/__init__.py new file mode 100644 index 00000000..caa9dbfd --- /dev/null +++ b/cosmpy/aerial/grpc/__init__.py @@ -0,0 +1,20 @@ +# -*- coding: utf-8 -*- +# ------------------------------------------------------------------------------ +# +# Copyright 2018-2021 Fetch.AI Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# ------------------------------------------------------------------------------ + +"""gRPC helpers for aerial clients.""" diff --git a/cosmpy/aerial/grpc/rpc_wrapper.py b/cosmpy/aerial/grpc/rpc_wrapper.py new file mode 100644 index 00000000..cac2703d --- /dev/null +++ b/cosmpy/aerial/grpc/rpc_wrapper.py @@ -0,0 +1,97 @@ +# -*- coding: utf-8 -*- +# ------------------------------------------------------------------------------ +# +# Copyright 2018-2021 Fetch.AI Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# ------------------------------------------------------------------------------ + +"""gRPC RPC method wrappers for request-scoped query context.""" + +from typing import Any, Optional, Protocol, Tuple + +from cosmpy.aerial.query_context import ResponseQueryContext +from cosmpy.common.rest_client import COSMOS_BLOCK_HEIGHT_HEADER + + +def _request_height(ctx: Optional[ResponseQueryContext]) -> Optional[int]: + """Get requested query height from a query context.""" + return getattr(ctx, "request_height", None) if ctx is not None else None + + +def _metadata_height(metadata: Any) -> Optional[int]: + """Extract Cosmos block height from gRPC metadata.""" + for key, value in metadata or []: + if key.lower() == COSMOS_BLOCK_HEIGHT_HEADER: + return int(value) + return None + + +class RpcCallable(Protocol): + """Callable gRPC method with grpc-stub with_call support.""" + + def __call__(self, request: Any, *, metadata: Any = None, **kwargs: Any) -> Any: + """Call RPC method.""" + + def with_call( + self, request: Any, *, metadata: Any = None, **kwargs: Any + ) -> Tuple[Any, Any]: + """Call RPC method and return response with call metadata.""" + + +class RpcMethodWrapper: + """Wrap one gRPC RPC method to support query context.""" + + def __init__(self, rpc: RpcCallable): + """ + Init RPC method wrapper. + + :param rpc: gRPC RPC method + """ + self._rpc = rpc + + def __call__( + self, + request: Any, + *, + ctx: Optional[ResponseQueryContext] = None, + metadata=None, + **kwargs, + ): + """ + Call wrapped RPC method. + + :param request: RPC request + :param ctx: optional query context + :param metadata: optional gRPC metadata + :param kwargs: additional gRPC call arguments + :return: RPC response + """ + request_height = _request_height(ctx) + metadata = list(metadata or []) + + if request_height is not None: + metadata.append((COSMOS_BLOCK_HEIGHT_HEADER, str(request_height))) + + if ctx is None: + return self._rpc(request, metadata=metadata or None, **kwargs) + + response, call = self._rpc.with_call( + request, metadata=metadata or None, **kwargs + ) + response_height = _metadata_height(call.trailing_metadata()) + if response_height is None: + response_height = _metadata_height(call.initial_metadata()) + ctx.response_height = response_height + return response diff --git a/cosmpy/aerial/grpc/stub_wrapper.py b/cosmpy/aerial/grpc/stub_wrapper.py new file mode 100644 index 00000000..d638c53e --- /dev/null +++ b/cosmpy/aerial/grpc/stub_wrapper.py @@ -0,0 +1,41 @@ +# -*- coding: utf-8 -*- +# ------------------------------------------------------------------------------ +# +# Copyright 2018-2021 Fetch.AI Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# ------------------------------------------------------------------------------ + +"""gRPC stub wrapper for request-scoped query context.""" + +from typing import Any + +from cosmpy.aerial.grpc.rpc_wrapper import RpcMethodWrapper + + +class StubWrapper: + """Wrap a generated gRPC query stub to support query context.""" + + def __init__(self, stub: Any): + """ + Init gRPC stub wrapper. + + :param stub: generated gRPC stub + """ + self._stub = stub + + def __getattr__(self, name: str) -> Any: + """Forward non-callable attributes and wrap RPC methods.""" + attr = getattr(self._stub, name) + return RpcMethodWrapper(attr) if callable(attr) else attr diff --git a/cosmpy/aerial/query_client.py b/cosmpy/aerial/query_client.py new file mode 100644 index 00000000..e3bcfa9f --- /dev/null +++ b/cosmpy/aerial/query_client.py @@ -0,0 +1,104 @@ +# -*- coding: utf-8 -*- +# ------------------------------------------------------------------------------ +# +# Copyright 2018-2021 Fetch.AI Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# ------------------------------------------------------------------------------ + +"""Query client wrappers for request-scoped query context.""" + +from contextlib import nullcontext +from typing import Any, Optional + +from cosmpy.aerial.grpc.stub_wrapper import StubWrapper +from cosmpy.aerial.query_context import ResponseQueryContext +from cosmpy.common.rest_client import RestClient + + +def is_query_grpc_stub(stub: Any) -> bool: + """Return true if a generated gRPC stub represents a Cosmos query service.""" + return str(getattr(stub.__class__, "__module__", "")).endswith(".query_pb2_grpc") + + +def wrap_query_client(client: Any) -> Any: + """Wrap supported query clients with request-scoped query context support.""" + if is_query_grpc_stub(client): + return StubWrapper(client) + if _find_rest_client(client) is not None: + return RestQueryClientWrapper(client) + return NoopQueryClientWrapper(client) + + +def _find_rest_client(client: Any) -> Optional[RestClient]: + """Find a RestClient held by a generated REST module client.""" + for value in vars(client).values(): + if isinstance(value, RestClient): + return value + return None + + +class RestQueryClientWrapper: + """Wrap a REST module client to support query context without per-method code.""" + + def __init__(self, client: Any): + """ + Init REST query client wrapper. + + :param client: REST module client + """ + self._client = client + self._rest_client = _find_rest_client(client) + + def __getattr__(self, name: str) -> Any: + """Forward non-callable attributes and wrap client methods.""" + attr = getattr(self._client, name) + if not callable(attr): + return attr + + def call_with_ctx(*args, ctx: Optional[ResponseQueryContext] = None, **kwargs): + ctx_manager = ( + self._rest_client.query_context(ctx) + if self._rest_client is not None and ctx is not None + else nullcontext() + ) + with ctx_manager: + return attr(*args, **kwargs) + + return call_with_ctx + + +class NoopQueryClientWrapper: + """Wrap clients so a ctx keyword does not break non-query calls.""" + + def __init__(self, client: Any): + """ + Init no-op query client wrapper. + + :param client: client instance + """ + self._client = client + + def __getattr__(self, name: str) -> Any: + """Forward non-callable attributes and ignore ctx for client methods.""" + attr = getattr(self._client, name) + if not callable(attr): + return attr + + def call_ignoring_ctx( + *args, ctx: Optional[ResponseQueryContext] = None, **kwargs + ): + return attr(*args, **kwargs) + + return call_ignoring_ctx diff --git a/cosmpy/aerial/query_context.py b/cosmpy/aerial/query_context.py new file mode 100644 index 00000000..f67692de --- /dev/null +++ b/cosmpy/aerial/query_context.py @@ -0,0 +1,37 @@ +# -*- coding: utf-8 -*- +# ------------------------------------------------------------------------------ +# +# Copyright 2018-2021 Fetch.AI Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# ------------------------------------------------------------------------------ + +"""Request-scoped query context helpers.""" + +from dataclasses import dataclass, field +from typing import Optional + + +@dataclass +class ResponseQueryContext: + """Response context populated with the block height used for a query.""" + + response_height: Optional[int] = field(init=False, default=None) + + +@dataclass +class RequestQueryContext(ResponseQueryContext): + """Request context for pinning a query to a block height.""" + + request_height: int diff --git a/cosmpy/common/rest_client.py b/cosmpy/common/rest_client.py index 0391bcd8..6a742237 100644 --- a/cosmpy/common/rest_client.py +++ b/cosmpy/common/rest_client.py @@ -20,18 +20,28 @@ """Implementation of REST api client.""" import base64 import json -from typing import List, Optional +from contextlib import contextmanager +from contextvars import ContextVar, Token +from typing import Any, Dict, List, Optional from urllib.parse import urlencode import requests from google.protobuf.json_format import MessageToDict from google.protobuf.message import Message +from cosmpy.aerial.query_context import ResponseQueryContext + + +COSMOS_BLOCK_HEIGHT_HEADER = "x-cosmos-block-height" + class RestClient: """REST api client.""" - def __init__(self, rest_address: str): + def __init__( + self, + rest_address: str, + ): """ Create REST api client. @@ -39,12 +49,40 @@ def __init__(self, rest_address: str): """ self._session = requests.session() self.rest_address = rest_address + self._query_ctx: ContextVar[Optional[ResponseQueryContext]] = ContextVar( + "rest_query_context", default=None + ) + + @contextmanager + def query_context(self, ctx: Optional[ResponseQueryContext]): + """Temporarily set the current query context.""" + token: Optional[Token] = None + try: + token = self._query_ctx.set(ctx) + yield + finally: + if token is not None: + self._query_ctx.reset(token) + + @staticmethod + def _request_height(ctx: Optional[ResponseQueryContext]) -> Optional[int]: + """Get requested query height from a query context.""" + return getattr(ctx, "request_height", None) if ctx is not None else None + + @staticmethod + def _response_height(response: requests.Response) -> Optional[int]: + """Get response query height from response headers.""" + height = response.headers.get(COSMOS_BLOCK_HEIGHT_HEADER) + if height is None: + height = response.headers.get(f"grpc-metadata-{COSMOS_BLOCK_HEIGHT_HEADER}") + return int(height) if height is not None else None def get( self, url_base_path: str, request: Optional[Message] = None, used_params: Optional[List[str]] = None, + ctx: Optional[ResponseQueryContext] = None, ) -> bytes: """ Send a GET request. @@ -52,6 +90,7 @@ def get( :param url_base_path: URL base path :param request: Protobuf coded request :param used_params: Parameters to be removed from request after converting it to dict + :param ctx: optional query context :raises RuntimeError: if response code is not 200 @@ -61,11 +100,21 @@ def get( url_base_path=url_base_path, request=request, used_params=used_params ) - response = self._session.get(url=url) + ctx = ctx or self._query_ctx.get() + request_height = self._request_height(ctx) + request_kwargs: Dict[str, Any] = {"url": url} + if request_height is not None: + request_kwargs["headers"] = { + COSMOS_BLOCK_HEIGHT_HEADER: str(request_height) + } + + response = self._session.get(**request_kwargs) if response.status_code != 200: raise RuntimeError( f"Error when sending a GET request.\n Response: {response.status_code}, {str(response.content)})" ) + if ctx is not None: + ctx.response_height = self._response_height(response) return response.content def _make_url( diff --git a/pyproject.toml b/pyproject.toml index 25b08812..fd41be06 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -94,13 +94,14 @@ pytest = "*" pytest-rerunfailures = "*" [tool.mypy] -python_version = 3.10 +python_version = "3.10" strict_optional = true exclude = "whitelist.py" [[tool.mypy.overrides]] module = [ "cosmpy.protos.*", + "cosmpy.gogoproto.*", ] ignore_errors = true follow_imports = "skip" @@ -116,6 +117,7 @@ module = [ "docker.*", "pytest.*", "click.*", + "tomli.*", "certifi.*", "mbedtls.*", "jsonschema.*", diff --git a/tests/helpers.py b/tests/helpers.py index 8b228308..050b15db 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -23,6 +23,7 @@ from google.protobuf.descriptor import Descriptor +from cosmpy.aerial.query_context import ResponseQueryContext from cosmpy.common.rest_client import RestClient @@ -47,6 +48,7 @@ def get( url_base_path: str, request: Optional[Descriptor] = None, used_params: Optional[List[str]] = None, + ctx: Optional[ResponseQueryContext] = None, ) -> bytes: """ Handle GET request. @@ -54,6 +56,7 @@ def get( :param url_base_path: url base path :param request: optional request descriptor instance :param used_params: optional list of params name used in path + :param ctx: optional query context :return: bytes """ diff --git a/tests/unit/test_aerial/test_client.py b/tests/unit/test_aerial/test_client.py index bd6c674f..66ab4ce6 100644 --- a/tests/unit/test_aerial/test_client.py +++ b/tests/unit/test_aerial/test_client.py @@ -22,6 +22,7 @@ import datetime +from types import SimpleNamespace from google.protobuf.timestamp_pb2 import Timestamp @@ -32,6 +33,10 @@ LedgerClient, ) from cosmpy.aerial.config import NetworkConfig +from cosmpy.aerial.grpc.rpc_wrapper import RpcMethodWrapper +from cosmpy.aerial.query_client import is_query_grpc_stub +from cosmpy.aerial.query_context import RequestQueryContext, ResponseQueryContext +from cosmpy.common.rest_client import COSMOS_BLOCK_HEIGHT_HEADER from cosmpy.protos.cosmos.base.abci.v1beta1.abci_pb2 import TxResponse as PbTxResponse from cosmpy.protos.tendermint.types.block_pb2 import Block as PbBlock from cosmpy.protos.tendermint.types.types_pb2 import Data, Header @@ -61,6 +66,102 @@ def test_ledger_client_timeouts(): assert client._query_timeout_secs == timeout # pylint: disable=protected-access +def test_ledger_client_query_context_is_optional(): + """Test ledger client query context is per call and optional.""" + cfg = NetworkConfig( + chain_id="test-chain", + fee_minimum_gas_price=1, + fee_denomination="atest", + staking_denomination="atest", + url="rest+http://localhost:1317", + ) + + client = LedgerClient(cfg) + original_bank = client.bank + + assert client.network_config == cfg + assert ( + client._query_interval_secs == DEFAULT_QUERY_INTERVAL_SECS + ) # pylint: disable=protected-access + assert ( + client._query_timeout_secs == DEFAULT_QUERY_TIMEOUT_SECS + ) # pylint: disable=protected-access + assert client.bank is original_bank + + +def test_rpc_method_wrapper_merges_metadata_and_reads_response_height(): + """Test gRPC RPC wrapper handles request and response query height.""" + + class Rpc: + """Fake gRPC RPC method.""" + + def with_call(self, request, metadata=None, **kwargs): + self.request = request + self.metadata = metadata + self.kwargs = kwargs + call = SimpleNamespace( + trailing_metadata=lambda: ((COSMOS_BLOCK_HEIGHT_HEADER, "456"),), + initial_metadata=lambda: (), + ) + return "response", call + + rpc = Rpc() + ctx = RequestQueryContext(request_height=123) + + response = RpcMethodWrapper(rpc)( + "request", + ctx=ctx, + metadata=(("existing", "value"),), + timeout=1, + ) + + assert response == "response" + assert rpc.request == "request" + assert rpc.metadata == [ + ("existing", "value"), + (COSMOS_BLOCK_HEIGHT_HEADER, "123"), + ] + assert rpc.kwargs == {"timeout": 1} + assert ctx.response_height == 456 + + +def test_rpc_method_wrapper_reads_latest_response_height(): + """Test gRPC RPC wrapper can read response height without request height.""" + + class Rpc: + """Fake gRPC RPC method.""" + + @staticmethod + def with_call(request, metadata=None, **kwargs): + call = SimpleNamespace( + trailing_metadata=lambda: (), + initial_metadata=lambda: ((COSMOS_BLOCK_HEIGHT_HEADER, "789"),), + ) + return "response", call + + ctx = ResponseQueryContext() + response = RpcMethodWrapper(Rpc())("request", ctx=ctx) + + assert response == "response" + assert ctx.response_height == 789 + + +def test_grpc_query_stub_detection(): + """Test only generated query gRPC stubs are detected as query stubs.""" + + class QueryStub: + """Fake generated query stub.""" + + class TxServiceStub: + """Fake generated tx service stub.""" + + QueryStub.__module__ = "cosmpy.protos.cosmos.bank.v1beta1.query_pb2_grpc" + TxServiceStub.__module__ = "cosmpy.protos.cosmos.tx.v1beta1.service_pb2_grpc" + + assert is_query_grpc_stub(QueryStub()) + assert not is_query_grpc_stub(TxServiceStub()) + + def test_parsing_tx_response(): """Test parsing tx response.""" txhash = "hash" diff --git a/tests/unit/test_common/test_rest_client.py b/tests/unit/test_common/test_rest_client.py index b819b903..d0ba5e52 100644 --- a/tests/unit/test_common/test_rest_client.py +++ b/tests/unit/test_common/test_rest_client.py @@ -24,7 +24,8 @@ from requests import Response, Session -from cosmpy.common.rest_client import RestClient +from cosmpy.aerial.query_context import RequestQueryContext, ResponseQueryContext +from cosmpy.common.rest_client import COSMOS_BLOCK_HEIGHT_HEADER, RestClient class QueryRestClientTestCase(TestCase): @@ -97,6 +98,59 @@ def test_get_error(self, messageToDict_mock, session_mock): messageToDict_mock.assert_called_once_with(request) self.assertTrue("Error when sending a GET request" in str(context.exception)) + @staticmethod + @patch("requests.session", spec=Session) + def test_get_with_query_context(session_mock): + """ + Test get method sends and reads height with query context. + + :param session_mock: mock + """ + rest_address = "some url" + client = RestClient(rest_address) + + request_url_path = "/my/weird/url/path" + resp = Mock(spec=Response) + resp.status_code = 200 + resp.content = "dfdffdss".encode(encoding="utf8") + resp.headers = {COSMOS_BLOCK_HEIGHT_HEADER: "456"} + + session_mock.return_value.get.return_value = resp + ctx = RequestQueryContext(request_height=123) + client.get(request_url_path, ctx=ctx) + + session_mock.return_value.get.assert_called_once_with( + url=f"{rest_address}{request_url_path}", + headers={COSMOS_BLOCK_HEIGHT_HEADER: "123"}, + ) + assert ctx.response_height == 456 + + @staticmethod + @patch("requests.session", spec=Session) + def test_get_with_response_query_context(session_mock): + """ + Test get method reads height with response-only query context. + + :param session_mock: mock + """ + rest_address = "some url" + client = RestClient(rest_address) + + request_url_path = "/my/weird/url/path" + resp = Mock(spec=Response) + resp.status_code = 200 + resp.content = "dfdffdss".encode(encoding="utf8") + resp.headers = {COSMOS_BLOCK_HEIGHT_HEADER: "456"} + + session_mock.return_value.get.return_value = resp + ctx = ResponseQueryContext() + client.get(request_url_path, ctx=ctx) + + session_mock.return_value.get.assert_called_once_with( + url=f"{rest_address}{request_url_path}" + ) + assert ctx.response_height == 456 + @staticmethod @patch("requests.session", spec=Session) @patch("cosmpy.common.rest_client.MessageToDict")