diff --git a/contracts/EthAccount.cairo b/contracts/EthAccount.cairo new file mode 100644 index 00000000..5220432d --- /dev/null +++ b/contracts/EthAccount.cairo @@ -0,0 +1,146 @@ +// SPDX-License-Identifier: MIT +// OpenZeppelin Contracts for Cairo v0.6.0 (account/presets/EthAccount.cairo) + +%lang starknet +from starkware.cairo.common.cairo_builtins import HashBuiltin, SignatureBuiltin, BitwiseBuiltin +from starkware.starknet.common.syscalls import get_tx_info + +from contracts.library import Account, AccountCallArray + +// +// Constructor +// + +@constructor +func constructor{ + syscall_ptr: felt*, + pedersen_ptr: HashBuiltin*, + range_check_ptr +}(ethAddress: felt) { + Account.initializer(ethAddress); + return (); +} + +// +// Getters +// + +@view +func getEthAddress{ + syscall_ptr: felt*, + pedersen_ptr: HashBuiltin*, + range_check_ptr +} () -> (ethAddress: felt) { + let (ethAddress: felt) = Account.get_public_key(); + return (ethAddress=ethAddress); +} + +@view +func supportsInterface{ + syscall_ptr: felt*, + pedersen_ptr: HashBuiltin*, + range_check_ptr +} (interfaceId: felt) -> (success: felt) { + return Account.supports_interface(interfaceId); +} + +// +// Setters +// + +@external +func setEthAddress{ + syscall_ptr: felt*, + pedersen_ptr: HashBuiltin*, + range_check_ptr +} (newEthAddress: felt) { + Account.set_public_key(newEthAddress); + return (); +} + +// +// Business logic +// + +@view +func isValidSignature{ + syscall_ptr: felt*, + pedersen_ptr: HashBuiltin*, + bitwise_ptr: BitwiseBuiltin*, + range_check_ptr, +}( + hash: felt, + signature_len: felt, + signature: felt* +) -> (isValid: felt) { + let (isValid) = Account.is_valid_eth_signature(hash, signature_len, signature); + return (isValid=isValid); +} + +@external +func __validate__{ + syscall_ptr: felt*, + pedersen_ptr: HashBuiltin*, + bitwise_ptr: BitwiseBuiltin*, + range_check_ptr, +}( + call_array_len: felt, + call_array: AccountCallArray*, + calldata_len: felt, + calldata: felt* +) { + let (tx_info) = get_tx_info(); + Account.is_valid_eth_signature(tx_info.transaction_hash, tx_info.signature_len, tx_info.signature); + return (); +} + +@external +func __validate_declare__{ + syscall_ptr: felt*, + pedersen_ptr: HashBuiltin*, + bitwise_ptr: BitwiseBuiltin*, + range_check_ptr, +} (class_hash: felt) { + let (tx_info) = get_tx_info(); + Account.is_valid_eth_signature(tx_info.transaction_hash, tx_info.signature_len, tx_info.signature); + return (); +} + + +@external +func __validate_deploy__{ + syscall_ptr: felt*, + pedersen_ptr: HashBuiltin*, + bitwise_ptr: BitwiseBuiltin*, + range_check_ptr +} ( + class_hash: felt, + salt: felt, + ethAddress: felt +) { + let (tx_info) = get_tx_info(); + Account.is_valid_eth_signature(tx_info.transaction_hash, tx_info.signature_len, tx_info.signature); + return (); +} + +@external +func __execute__{ + syscall_ptr: felt*, + pedersen_ptr: HashBuiltin*, + ecdsa_ptr: SignatureBuiltin*, + bitwise_ptr: BitwiseBuiltin*, + range_check_ptr, +}( + call_array_len: felt, + call_array: AccountCallArray*, + calldata_len: felt, + calldata: felt* +) -> ( + response_len: felt, + response: felt* +) { + let (response_len, response) = Account.execute( + call_array_len, call_array, calldata_len, calldata + ); + return (response_len, response); +} diff --git a/contracts/library.cairo b/contracts/library.cairo new file mode 100644 index 00000000..413b09f6 --- /dev/null +++ b/contracts/library.cairo @@ -0,0 +1,263 @@ +// SPDX-License-Identifier: MIT +// OpenZeppelin Contracts for Cairo v0.6.0 (account/library.cairo) + +%lang starknet + +from starkware.cairo.common.registers import get_fp_and_pc +from starkware.cairo.common.signature import verify_ecdsa_signature +from starkware.cairo.common.cairo_builtins import HashBuiltin, SignatureBuiltin, BitwiseBuiltin +from starkware.cairo.common.alloc import alloc +from starkware.cairo.common.uint256 import Uint256 +from starkware.cairo.common.memcpy import memcpy +from starkware.cairo.common.math import split_felt +from starkware.cairo.common.math_cmp import is_le_felt +from starkware.cairo.common.bool import TRUE, FALSE +from starkware.starknet.common.syscalls import ( + call_contract, + get_caller_address, + get_contract_address, + get_tx_info +) +from starkware.cairo.common.cairo_secp.signature import ( + finalize_keccak, + verify_eth_signature_uint256 +) +//from openzeppelin.utils.constants.library import ( +// IACCOUNT_ID, +// IERC165_ID, +// TRANSACTION_VERSION +//) +const IACCOUNT_ID = 0xa66bd575; +const IERC165_ID = 0x01ffc9a7; +const TRANSACTION_VERSION = 1; + + +// +// Storage +// + +@storage_var +func Account_public_key() -> (public_key: felt) { +} + +// +// Structs +// + +struct Call { + to: felt, + selector: felt, + calldata_len: felt, + calldata: felt*, +} + +// Tmp struct introduced while we wait for Cairo +// to support passing `[AccountCall]` to __execute__ +struct AccountCallArray { + to: felt, + selector: felt, + data_offset: felt, + data_len: felt, +} + +namespace Account { + // + // Initializer + // + + func initializer{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}( + _public_key: felt + ) { + Account_public_key.write(_public_key); + return (); + } + + // + // Guards + // + + func assert_only_self{syscall_ptr: felt*}() { + let (self) = get_contract_address(); + let (caller) = get_caller_address(); + with_attr error_message("Account: caller is not this account") { + assert self = caller; + } + return (); + } + + // + // Getters + // + + func get_public_key{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}() -> ( + public_key: felt + ) { + return Account_public_key.read(); + } + + func supports_interface{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}(interface_id: felt) -> ( + success: felt + ) { + if (interface_id == IERC165_ID) { + return (success=TRUE); + } + if (interface_id == IACCOUNT_ID) { + return (success=TRUE); + } + return (success=FALSE); + } + + // + // Setters + // + + func set_public_key{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}( + new_public_key: felt + ) { + assert_only_self(); + Account_public_key.write(new_public_key); + return (); + } + + // + // Business logic + // + + func is_valid_signature{ + syscall_ptr: felt*, + pedersen_ptr: HashBuiltin*, + ecdsa_ptr: SignatureBuiltin*, + range_check_ptr, + }(hash: felt, signature_len: felt, signature: felt*) -> (is_valid: felt) { + let (_public_key) = Account_public_key.read(); + + // This interface expects a signature pointer and length to make + // no assumption about signature validation schemes. + // But this implementation does, and it expects a (sig_r, sig_s) pair. + let sig_r = signature[0]; + let sig_s = signature[1]; + + verify_ecdsa_signature( + message=hash, public_key=_public_key, signature_r=sig_r, signature_s=sig_s + ); + + return (is_valid=TRUE); + } + + func is_valid_eth_signature{ + syscall_ptr: felt*, + pedersen_ptr: HashBuiltin*, + bitwise_ptr: BitwiseBuiltin*, + range_check_ptr, + }(hash: felt, signature_len: felt, signature: felt*) -> (is_valid: felt) { + alloc_locals; + let (_public_key) = get_public_key(); + let (__fp__, _) = get_fp_and_pc(); + + // This interface expects a signature pointer and length to make + // no assumption about signature validation schemes. + // But this implementation does, and it expects a the sig_v, sig_r, + // sig_s, and hash elements. + let sig_v: felt = signature[0]; + let sig_r: Uint256 = Uint256(low=signature[1], high=signature[2]); + let sig_s: Uint256 = Uint256(low=signature[3], high=signature[4]); + let (high, low) = split_felt(hash); + let msg_hash: Uint256 = Uint256(low=low, high=high); + + let (keccak_ptr: felt*) = alloc(); + local keccak_ptr_start: felt* = keccak_ptr; + + with keccak_ptr { + verify_eth_signature_uint256( + msg_hash=msg_hash, r=sig_r, s=sig_s, v=sig_v, eth_address=_public_key + ); + } + finalize_keccak(keccak_ptr_start=keccak_ptr_start, keccak_ptr_end=keccak_ptr); + + return (is_valid=TRUE); + } + + func execute{ + syscall_ptr: felt*, + pedersen_ptr: HashBuiltin*, + ecdsa_ptr: SignatureBuiltin*, + bitwise_ptr: BitwiseBuiltin*, + range_check_ptr, + }(call_array_len: felt, call_array: AccountCallArray*, calldata_len: felt, calldata: felt*) -> ( + response_len: felt, response: felt* + ) { + alloc_locals; + + let (tx_info) = get_tx_info(); + // Disallow deprecated tx versions + with_attr error_message("Account: deprecated tx version") { + assert is_le_felt(TRANSACTION_VERSION, tx_info.version) = TRUE; + } + + // Assert not a reentrant call + let (caller) = get_caller_address(); + with_attr error_message("Account: reentrant call") { + assert caller = 0; + } + + // TMP: Convert `AccountCallArray` to 'Call'. + let (calls: Call*) = alloc(); + _from_call_array_to_call(call_array_len, call_array, calldata, calls); + let calls_len = call_array_len; + + // Execute call + let (response: felt*) = alloc(); + let (response_len) = _execute_list(calls_len, calls, response); + + return (response_len=response_len, response=response); + } + + func _execute_list{syscall_ptr: felt*}(calls_len: felt, calls: Call*, response: felt*) -> ( + response_len: felt + ) { + alloc_locals; + + // if no more calls + if (calls_len == 0) { + return (response_len=0); + } + + // do the current call + let this_call: Call = [calls]; + let res = call_contract( + contract_address=this_call.to, + function_selector=this_call.selector, + calldata_size=this_call.calldata_len, + calldata=this_call.calldata, + ); + // copy the result in response + memcpy(response, res.retdata, res.retdata_size); + // do the next calls recursively + let (response_len) = _execute_list( + calls_len - 1, calls + Call.SIZE, response + res.retdata_size + ); + return (response_len=response_len + res.retdata_size); + } + + func _from_call_array_to_call{syscall_ptr: felt*}( + call_array_len: felt, call_array: AccountCallArray*, calldata: felt*, calls: Call* + ) { + // if no more calls + if (call_array_len == 0) { + return (); + } + + // parse the current call + assert [calls] = Call( + to=[call_array].to, + selector=[call_array].selector, + calldata_len=[call_array].data_len, + calldata=calldata + [call_array].data_offset + ); + // parse the remaining calls recursively + _from_call_array_to_call( + call_array_len - 1, call_array + AccountCallArray.SIZE, calldata, calls + Call.SIZE + ); + return (); + } +} diff --git a/scripts/add_funds.py b/scripts/add_funds.py new file mode 100644 index 00000000..9bd50388 --- /dev/null +++ b/scripts/add_funds.py @@ -0,0 +1,18 @@ +from nile.common import ETH_TOKEN_ADDRESS + +async def run(nre): + accounts = await nre.get_accounts(predeployed=True) + account = accounts[0] + recipient = "0x04d3f77f305c02d158f159a91e00f4562e8697b7025559aa5f0497446c3bd5de" + + amount = [2 * 10 ** 18, 0] + print(f"transferring {amount} to {recipient} from {accounts[0].address}") + tx = await account.send(ETH_TOKEN_ADDRESS, "transfer", [recipient, *amount]) + await tx.execute() + + # eth + recipient = "0x07633f2234e6a3e71c92a757c86517953affb066cffdef1d560fbfb9036f3aa3" + + print(f"transferring {amount} to {recipient} from {accounts[0].address}") + tx = await account.send(ETH_TOKEN_ADDRESS, "transfer", [recipient, *amount]) + await tx.execute() \ No newline at end of file diff --git a/scripts/script.py b/scripts/script.py new file mode 100644 index 00000000..8acab8b9 --- /dev/null +++ b/scripts/script.py @@ -0,0 +1,5 @@ +async def run(nre): + accounts = await nre.get_accounts(predeployed=True) + account = accounts[0] + tx = await account.declare("Account", nile_account=True) + await tx.execute() \ No newline at end of file diff --git a/src/nile/cli.py b/src/nile/cli.py index 457c53bd..9e1558aa 100644 --- a/src/nile/cli.py +++ b/src/nile/cli.py @@ -17,8 +17,9 @@ from nile.core.run import run as run_command from nile.core.test import test as test_command from nile.core.types.account import get_counterfactual_address, try_get_account +from nile.core.types.eth_account import try_get_eth_account from nile.core.version import version as version_command -from nile.signer import Signer +from nile.signer import Signer, EthSigner from nile.utils import hex_address, normalize_number, shorten_address from nile.utils.get_accounts import get_accounts as get_accounts_command from nile.utils.get_accounts import ( @@ -244,6 +245,19 @@ def counterfactual_address(ctx, signer, salt): logging.info(address) +@cli.command() +@click.argument("signer", nargs=1) +@click.option("--salt", type=int, nargs=1) +@enable_stack_trace +def eth_counterfactual_address(ctx, signer, salt): + """Precompute the address of an Account contract.""" + _signer = EthSigner(normalize_number(os.environ[signer])) + address = hex_address( + get_counterfactual_address(salt, calldata=[_signer.eth_address], contract="EthAccount") + ) + logging.info(address) + + @cli.command() @click.argument("signer", nargs=1) @click.argument("address_or_alias", nargs=1) @@ -284,6 +298,46 @@ async def send( await run_transaction(tx=transaction, query_flag=query, watch_mode=watch_mode) +@cli.command() +@click.argument("signer", nargs=1) +@click.argument("address_or_alias", nargs=1) +@click.argument("method", nargs=1) +@click.argument("params", nargs=-1) +@click.option("--max_fee", type=int, nargs=1) +@network_option +@query_option +@watch_option +@enable_stack_trace +async def eth_send( + ctx, + signer, + address_or_alias, + method, + params, + network, + max_fee, + query, + watch_mode, +): + """Invoke a contract's method through an Account.""" + account = await try_get_eth_account(signer, network, watch_mode="track") + if account is not None: + print( + "Calling {} on {} with params: {}".format( + method, address_or_alias, [x for x in params] + ) + ) + + transaction = await account.send( + address_or_alias, + method, + params, + max_fee=max_fee, + ) + + await run_transaction(tx=transaction, query_flag=query, watch_mode=watch_mode) + + @cli.command() @click.argument("address_or_alias", nargs=1) @click.argument("method", nargs=1) diff --git a/src/nile/core/types/eth_account.py b/src/nile/core/types/eth_account.py new file mode 100644 index 00000000..233a2064 --- /dev/null +++ b/src/nile/core/types/eth_account.py @@ -0,0 +1,314 @@ +"""Account module.""" + +import logging +import os + +from dotenv import load_dotenv + +from nile import accounts, deployments +from nile.common import ( + NILE_ARTIFACTS_PATH, + UNIVERSAL_DEPLOYER_ADDRESS, + is_alias, + normalize_number, +) +from nile.core.types.eth_transactions import ( + DeclareTransaction, + DeployAccountTransaction, + InvokeTransaction, +) +from nile.core.types.tx_wrappers import ( + DeclareTxWrapper, + DeployAccountTxWrapper, + DeployContractTxWrapper, + InvokeTxWrapper, +) +from nile.core.types.udc_helpers import create_udc_deploy_transaction +from nile.core.types.utils import get_counterfactual_address, get_execute_calldata +from nile.signer import EthSigner, Signer +from nile.utils.get_nonce import get_nonce_without_log as get_nonce + +load_dotenv() + + +class AsyncObject(object): + """Base class for Account to allow async initialization.""" + + async def __new__(cls, *a, **kw): + """Return coroutine (not class so sync __init__ is not invoked).""" + instance = super().__new__(cls) + await instance.__init__(*a, **kw) + return instance + + async def __init__(self): + """Support Account async __init__.""" + pass + + +class EthAccount(AsyncObject): + """ + Account contract abstraction. + + Remove AsyncObject if Account.deploy decouples from initialization. + """ + + async def __init__( + self, + signer, + network, + salt=0, + max_fee=None, + predeployed_info=None, + watch_mode=None, + auto_deploy=True, + ): + """Get or deploy an Account contract for the given private key.""" + signer, alias = _get_signer_and_alias(signer, predeployed_info) + + self.signer = signer + self.alias = alias + self.network = network + + if predeployed_info is not None: + self.address = predeployed_info["address"] + self.index = predeployed_info["index"] + elif accounts.exists(self.signer.eth_address, network): + signer_data = next(accounts.load(self.signer.eth_address, network)) + self.address = signer_data["address"] + #self.address =0x07e0424e20ca2f51053e88f6a06e2dc5cbdd1937366709520ff65d247215a9d7 + self.index = signer_data["index"] + elif auto_deploy: + tx = await self.deploy(salt=salt, max_fee=max_fee) + # DeployAccountTxWrapper.execute updates account's address and index + await tx.execute(watch_mode=watch_mode) + + # We should replace this with static type checks + if hasattr(self, "address"): + assert type(self.address) == int + + async def deploy(self, salt=None, max_fee=None, abi=None): + """Deploy an Account contract for the given private key.""" + salt = 0 if salt is None else normalize_number(salt) + calldata = [self.signer.eth_address] + contract_name = "EthAccount" + predicted_address = get_counterfactual_address(salt=salt, calldata=calldata, contract="EthAccount") + + max_fee, _, calldata = await self._process_arguments(max_fee, 0, calldata) + + # Create the transaction + transaction = DeployAccountTransaction( + salt=salt, + contract_to_submit=contract_name, + predicted_address=predicted_address, + calldata=calldata, + max_fee=max_fee or 0, + network=self.network, + ) + + tx_wrapper = DeployAccountTxWrapper( + tx=transaction, + account=self, + alias=self.alias, + abi=abi, + ) + + # await _set_estimated_fee_if_none(max_fee, tx_wrapper) + return tx_wrapper + + async def send( + self, + address_or_alias, + method, + calldata, + nonce=None, + max_fee=None, + ): + """Return an InvokeTxWrapper object.""" + target_address = self._get_target_address(address_or_alias) + max_fee, nonce, calldata = await self._process_arguments( + max_fee, nonce, calldata + ) + execute_calldata = get_execute_calldata( + calls=[[target_address, method, calldata]] + ) + + # Create the transaction + transaction = InvokeTransaction( + account_address=self.address, + calldata=execute_calldata, + max_fee=max_fee or 0, + nonce=nonce, + network=self.network, + ) + + tx_wrapper = InvokeTxWrapper( + tx=transaction, + account=self, + ) + + # await _set_estimated_fee_if_none(max_fee, tx_wrapper) + return tx_wrapper + + async def declare( + self, + contract_name, + nonce=None, + max_fee=None, + alias=None, + overriding_path=None, + nile_account=False, + ): + """Return a DeclareTxWrapper for declaring a contract through an Account.""" + max_fee, nonce, _ = await self._process_arguments(max_fee, nonce) + + if nile_account: + assert overriding_path is None, "Cannot override path to Nile account." + overriding_path = NILE_ARTIFACTS_PATH + + # Create the transaction + transaction = DeclareTransaction( + account_address=self.address, + contract_to_submit=contract_name, + max_fee=max_fee or 0, + nonce=nonce, + network=self.network, + overriding_path=overriding_path, + ) + + tx_wrapper = DeclareTxWrapper( + tx=transaction, + account=self, + alias=alias, + ) + + # await _set_estimated_fee_if_none(max_fee, tx_wrapper) + return tx_wrapper + + async def deploy_contract( + self, + contract_name, + salt, + unique, + calldata, + nonce=None, + max_fee=None, + deployer_address=None, + alias=None, + overriding_path=None, + abi=None, + ): + """Deploy a contract through an Account.""" + deployer_address = normalize_number( + deployer_address or UNIVERSAL_DEPLOYER_ADDRESS + ) + max_fee, nonce, calldata = await self._process_arguments( + max_fee, nonce, calldata + ) + + # Create the transaction + transaction, predicted_address = await create_udc_deploy_transaction( + account=self, + contract_name=contract_name, + salt=salt, + unique=unique, + calldata=calldata, + deployer_address=deployer_address, + max_fee=max_fee or 0, + nonce=nonce, + overriding_path=overriding_path, + ) + + tx_wrapper = DeployContractTxWrapper( + tx=transaction, + account=self, + alias=alias, + contract_name=contract_name, + predicted_address=predicted_address, + overriding_path=overriding_path, + abi=abi, + ) + + # await _set_estimated_fee_if_none(max_fee, tx_wrapper) + return tx_wrapper + + def _get_target_address(self, address_or_alias): + if not is_alias(address_or_alias): + target_address = normalize_number(address_or_alias) + else: + target_address, _ = next( + deployments.load(address_or_alias, self.network), None + ) or (None, None) + + if type(target_address) != int: + raise Exception(f"`{address_or_alias}` alias not found in deployments.") + + return target_address + + async def _process_arguments(self, max_fee, nonce, calldata=None): + if max_fee is not None: + max_fee = int(max_fee) + + if nonce is None: + nonce = await get_nonce(self.signer.eth_address, self.network) + + if calldata is not None: + calldata = [normalize_number(x) for x in calldata] + + return max_fee, nonce, calldata + + +def _get_signer_and_alias(signer, predeployed_info): + if predeployed_info is None: + alias = signer + signer = EthSigner(normalize_number(os.environ[signer])) + else: + signer = Signer(signer) + alias = predeployed_info["alias"] + return signer, alias + + +async def _set_estimated_fee_if_none(max_fee, tx): + """Estimate max_fee for transaction if max_fee is None.""" + if max_fee is None: + logger = logging.getLogger() + current_level = logger.level + + # Avoid logging the fee estimation in CLI + logger.setLevel(logging.WARNING) + + estimated_fee = await tx.estimate_fee() + + logger.setLevel(current_level) + + tx.update_fee(estimated_fee) + + +async def try_get_eth_account( + signer, + network, + salt=None, + max_fee=None, + predeployed_info=None, + watch_mode=None, + auto_deploy=True, +): + """Avoid reverting on KeyError.""" + account = None + try: + account = await EthAccount( + signer, + network, + salt=salt, + max_fee=max_fee, + predeployed_info=predeployed_info, + watch_mode=watch_mode, + auto_deploy=auto_deploy, + ) + except KeyError: + logging.error( + f"\n❌ Cannot find {signer} in env." + "\nCheck spelling and that it exists." + "\nTry moving the .env to the root of your project." + ) + + return account diff --git a/src/nile/core/types/eth_transactions.py b/src/nile/core/types/eth_transactions.py new file mode 100644 index 00000000..67684601 --- /dev/null +++ b/src/nile/core/types/eth_transactions.py @@ -0,0 +1,305 @@ +"""Transaction module.""" + +import dataclasses +import json +import logging +import re +from abc import ABC, abstractmethod +from dataclasses import field +from typing import List + +from nile.common import ( + ABIS_DIRECTORY, + QUERY_VERSION_BASE, + TRANSACTION_VERSION, + get_chain_id, + get_class_hash, + get_contract_class, +) +from nile.core.types.utils import ( + get_declare_hash, + get_deploy_account_hash, + get_invoke_hash, +) +from nile.starknet_cli import execute_call +from nile.utils import hex_address +from nile.utils.status import status + + +@dataclasses.dataclass +class Transaction(ABC): + """ + Starknet transaction abstraction. + + Init params. + + @param account_address: The account contract from which this transaction originates. + @param max_fee: The maximal fee to be paid in Wei for the execution. + @param nonce: The nonce of the transaction. + @param network: The chain the transaction will be executed on. + @param version: The version of the transaction. + + Generated internally. + + @param hash: The hash of the transaction. + @param query_hash: The hash of the transaction with QUERY_VERSION. + @param chain_id: The id of the chain the transaction will be executed on. + """ + + account_address: int = 0 + max_fee: int = 0 + nonce: int = 0 + network: str = "localhost" + version: int = TRANSACTION_VERSION + + # Public fields not expected in construction time + tx_type: int = field(init=False) + hash: int = field(init=False, default=0) + query_hash: int = field(init=False, default=0) + chain_id: int = field(init=False) + + def __post_init__(self): + """Populate pending fields.""" + self.chain_id = get_chain_id(self.network) + self.hash = self._get_tx_hash() + self.query_hash = self._get_tx_hash(QUERY_VERSION_BASE + self.version) + + # Validate the transaction object + self._validate() + + async def execute(self, signer, watch_mode=None, **kwargs): + """Execute the transaction.""" + signature = signer.sign(message_hash=self.hash) + sig_v = signature[0] + sig_r = (signature[1], signature[2]) + sig_s = (signature[3], signature[4]) + + type_specific_args = self._get_execute_call_args() + + output = await execute_call( + self.tx_type, + self.network, + signature=[sig_v, *sig_r, *sig_s], + max_fee=self.max_fee, + query_flag=None, + **type_specific_args, + **kwargs, + ) + + match = re.search(r"Transaction hash: (0x[\da-f]{1,64})", output) + output_tx_hash = match.groups()[0] if match else None + + assert output_tx_hash == hex( + self.hash + ), "Resulting transaction hash is different than expected" + + tx_status = await status(self.hash, self.network, watch_mode) + return tx_status, output + + async def estimate_fee(self, signer, **kwargs): + """Estimate the fee of execution.""" + signature = signer.sign(message_hash=self.query_hash) + sig_v = signature[0] + sig_r = (signature[1], signature[2]) + sig_s = (signature[3], signature[4]) + + type_specific_args = self._get_execute_call_args() + + output = await execute_call( + self.tx_type, + self.network, + signature=[sig_v, *sig_r, *sig_s], + max_fee=self.max_fee, + query_flag="estimate_fee", + **type_specific_args, + **kwargs, + ) + + match = re.search(r"The estimated fee is: [\d]{1,64}", output) + output_value = ( + int(match.group(0).replace("The estimated fee is: ", "")) if match else None + ) + + logging.info(output) + return output_value + + async def simulate(self, signer, **kwargs): + """Simulate the execution.""" + signature = signer.sign(message_hash=self.query_hash) + sig_v = signature[0] + sig_r = (signature[1], signature[2]) + sig_s = (signature[3], signature[4]) + + type_specific_args = self._get_execute_call_args() + + output = await execute_call( + self.tx_type, + self.network, + signature=[sig_v, *sig_r, *sig_s], + max_fee=self.max_fee, + query_flag="simulate", + **type_specific_args, + **kwargs, + ) + + json_str = output.split("\n", 4)[4] + output_value = json.loads(json_str) + + logging.info(output) + return output_value + + def update_fee(self, max_fee): + """Update the tx from a new max_fee.""" + self.max_fee = max_fee + self.hash = self._get_tx_hash() + self.query_hash = self._get_tx_hash(QUERY_VERSION_BASE + self.version) + + # Allow chaining with execute + return self + + @abstractmethod + def _get_execute_call_args(self): + """ + Return specific arguments from transaction type. + + This method must be overridden on each specific implementation. + """ + + @abstractmethod + def _get_tx_hash(self, version): + """ + Return the tx hash for the transaction type. + + This method must be overridden on each specific implementation. + """ + + def _validate(self): + """Validate the transaction object.""" + assert self.hash > 0, "Transaction hash is empty after transaction creation!" + + +@dataclasses.dataclass +class InvokeTransaction(Transaction): + """ + Starknet invoke transaction abstraction. + + @param entry_point: The function to execute. + @param calldata: The parameters for the call. + """ + + entry_point: str = "__execute__" + calldata: List[int] = None + + def __post_init__(self): + """Populate pending fields.""" + super().__post_init__() + self.tx_type = "invoke" + + def _get_tx_hash(self, version=None): + return get_invoke_hash( + self.account_address, + self.calldata, + self.max_fee, + self.nonce, + version or self.version, + self.chain_id, + ) + + def _get_execute_call_args(self): + return { + "inputs": self.calldata, + "address": hex_address(self.account_address), + "abi": f"{ABIS_DIRECTORY}/EthAccount.json", + "method": self.entry_point, + } + + +@dataclasses.dataclass +class DeclareTransaction(Transaction): + """ + Starknet declare transaction abstraction. + + @param contract_to_submit: Contract name for declarations or deployments. + @param contract_class: Contract class required for declarations. + @param overriding_path: Utility for artifacts resolution. + """ + + contract_to_submit: str = None + contract_class: str = field(init=False) + overriding_path: str = None + + def __post_init__(self): + """Populate pending fields.""" + self.contract_class = get_contract_class( + contract_name=self.contract_to_submit, + overriding_path=self.overriding_path, + ) + self.tx_type = "declare" + super().__post_init__() + + def _get_tx_hash(self, version=None): + return get_declare_hash( + self.account_address, + self.contract_class, + self.max_fee, + self.nonce, + version or self.version, + self.chain_id, + ) + + def _get_execute_call_args(self): + return { + "contract_name": self.contract_to_submit, + "overriding_path": self.overriding_path, + "sender": hex_address(self.account_address), + } + + +@dataclasses.dataclass +class DeployAccountTransaction(Transaction): + """ + Starknet deploy_account transaction abstraction. + + @param salt: Deployed account address salt. + @param contract_to_submit: Contract name for declarations or deployments. + @param predicted_address: Counterfactual address of the account to deploy. + @param calldata: The parameters for the call. + @param overriding_path: Utility for artifacts resolution. + @param contract_class: Contract class required for declarations. + """ + + salt: int = 0 + contract_to_submit: str = None + predicted_address: int = 0 + calldata: List[int] = None + overriding_path: str = None + class_hash: int = field(init=False) + + def __post_init__(self): + """Populate pending fields.""" + self.class_hash = get_class_hash( + contract_name=self.contract_to_submit, + overriding_path=self.overriding_path, + ) + self.tx_type = "deploy_account" + super().__post_init__() + + def _get_tx_hash(self, version=None): + return get_deploy_account_hash( + self.predicted_address, + self.class_hash, + self.calldata, + self.salt, + self.max_fee, + self.nonce, + version or self.version, + self.chain_id, + ) + + def _get_execute_call_args(self): + return { + "salt": self.salt, + "contract_name": self.contract_to_submit, + "overriding_path": self.overriding_path, + "calldata": self.calldata, + } diff --git a/src/nile/signer.py b/src/nile/signer.py index 1bf5a773..803df9fe 100644 --- a/src/nile/signer.py +++ b/src/nile/signer.py @@ -9,6 +9,8 @@ get_execute_calldata, get_invoke_hash, ) +import eth_keys +from nile.utils import to_uint class Signer: @@ -73,3 +75,20 @@ def sign_invoke(self, sender, calls, nonce, max_fee, version=TRANSACTION_VERSION sig_r, sig_s = self.sign(message_hash=transaction_hash) return execute_calldata, sig_r, sig_s + + +class EthSigner: + """Utility for signing transactions for an EthAccount on Starknet.""" + + def __init__(self, private_key, network="testnet"): + p_key = private_key.to_bytes(32, byteorder="big") + self.signer = eth_keys.keys.PrivateKey(p_key) + self.eth_address = int(self.signer.public_key.to_checksum_address(), 0) + self.chain_id = get_chain_id(network) + + def sign(self, message_hash): + signature = self.signer.sign_msg_hash( + (message_hash).to_bytes(32, byteorder="big")) + sig_r = to_uint(signature.r) + sig_s = to_uint(signature.s) + return [signature.v, *sig_r, *sig_s]