diff --git a/changelog.md b/changelog.md index c51d4e9c..70ac5ff4 100644 --- a/changelog.md +++ b/changelog.md @@ -6,6 +6,7 @@ Features * Let the `--dsn` argument accept literal DSNs as well as aliases. * Accept `--character-set` as an alias for `--charset` at the CLI. * Add SSL/TLS version to `status` output. +* Let `--keepalive-ticks` be set per-connection, as a CLI option or DSN parameter. Bug Fixes diff --git a/mycli/main.py b/mycli/main.py index a3899fe0..577e1535 100755 --- a/mycli/main.py +++ b/mycli/main.py @@ -163,6 +163,7 @@ def __init__( self.toolbar_error_message: str | None = None self.prompt_app: PromptSession | None = None self._keepalive_counter = 0 + self.keepalive_ticks: int | None = 0 # self.cnf_files is a class variable that stores the list of mysql # config files to read in at launch. @@ -544,6 +545,7 @@ def connect( unbuffered: bool | None = None, use_keyring: bool | None = None, reset_keyring: bool | None = None, + keepalive_ticks: int | None = None, ) -> None: cnf = { "database": None, @@ -572,6 +574,7 @@ def connect( port = port or cnf["port"] ssl_config: dict[str, Any] = ssl or {} user_connection_config = self.config_without_package_defaults.get('connection', {}) + self.keepalive_ticks = keepalive_ticks int_port = port and int(port) if not int_port: @@ -1004,10 +1007,12 @@ def keepalive_hook(_context): Example at https://github.com/prompt-toolkit/python-prompt-toolkit/blob/main/examples/prompts/inputhook.py """ - if self.default_keepalive_ticks < 1: + if self.keepalive_ticks is None: + return + if self.keepalive_ticks < 1: return self._keepalive_counter += 1 - if self._keepalive_counter > self.default_keepalive_ticks: + if self._keepalive_counter > self.keepalive_ticks: self._keepalive_counter = 0 self.logger.debug('keepalive ping') try: @@ -1018,7 +1023,7 @@ def keepalive_hook(_context): self.logger.debug('keepalive ping error %r', e) def one_iteration(text: str | None = None) -> None: - inputhook = keepalive_hook if self.default_keepalive_ticks >= 1 else None + inputhook = keepalive_hook if self.keepalive_ticks and self.keepalive_ticks >= 1 else None if text is None: try: assert self.prompt_app is not None @@ -1729,6 +1734,11 @@ def get_last_query(self) -> str | None: default=None, help='Store and retrieve passwords from the system keyring: true/false/reset.', ) +@click.option( + '--keepalive-ticks', + type=int, + help='Send regular keepalive pings to the connection, roughly every seconds.', +) @click.option("--checkup", is_flag=True, help="Run a checkup on your config file.") @click.pass_context def cli( @@ -1784,6 +1794,7 @@ def cli( throttle: float, use_keyring_cli_opt: str | None, checkup: bool, + keepalive_ticks: int | None, ) -> None: """A MySQL terminal client with auto-completion and syntax highlighting. @@ -1993,6 +2004,11 @@ def get_password_from_file(password_file: str | None) -> str | None: if params := dsn_params.get('ssl_verify_server_cert'): ssl_verify_server_cert = ssl_verify_server_cert or (params[0].lower() == 'true') ssl_enable = True + if params := dsn_params.get('keepalive_ticks'): + if keepalive_ticks is None: + keepalive_ticks = int(params[0]) + + keepalive_ticks = keepalive_ticks if keepalive_ticks is not None else mycli.default_keepalive_ticks ssl_mode = ssl_mode or mycli.ssl_mode # cli option or config option @@ -2168,6 +2184,7 @@ def get_password_from_file(password_file: str | None) -> str | None: character_set=character_set, use_keyring=use_keyring, reset_keyring=reset_keyring, + keepalive_ticks=keepalive_ticks, ) if combined_init_cmd: diff --git a/test/test_main.py b/test/test_main.py index 1415f598..ce005340 100644 --- a/test/test_main.py +++ b/test/test_main.py @@ -754,6 +754,9 @@ class MockMyCli: config = { "main": {}, "alias_dsn": {}, + "connection": { + "default_keepalive_ticks": 0, + }, } def __init__(self, **args): @@ -763,6 +766,7 @@ def __init__(self, **args): self.redirect_formatter = Formatter() self.ssl_mode = "auto" self.my_cnf = {"client": {}, "mysqld": {}} + self.default_keepalive_ticks = 0 def connect(self, **args): MockMyCli.connect_args = args @@ -820,6 +824,9 @@ def run_query(self, query, new_line=True): MockMyCli.config = { "main": {}, "alias_dsn": {"test": "mysql://alias_dsn_user:alias_dsn_passwd@alias_dsn_host:4/alias_dsn_database"}, + "connection": { + "default_keepalive_ticks": 0, + }, } MockMyCli.connect_args = None @@ -838,6 +845,9 @@ def run_query(self, query, new_line=True): MockMyCli.config = { "main": {}, "alias_dsn": {"test": "mysql://alias_dsn_user:alias_dsn_passwd@alias_dsn_host:4/alias_dsn_database"}, + "connection": { + "default_keepalive_ticks": 0, + }, } MockMyCli.connect_args = None @@ -893,6 +903,22 @@ def run_query(self, query, new_line=True): and MockMyCli.connect_args["ssl"]["enable"] is True ) + MockMyCli.connect_args = None + MockMyCli.config = { + "main": {}, + "alias_dsn": {}, + "connection": { + "default_keepalive_ticks": 0, + }, + } + + # keepalive_ticks as a query parameter + result = runner.invoke(mycli.main.cli, args=["mysql://dsn_user:dsn_passwd@dsn_host:6/dsn_database?keepalive_ticks=30"]) + assert result.exit_code == 0, result.output + " " + str(result.exception) + assert MockMyCli.connect_args["keepalive_ticks"] == 30 + + MockMyCli.connect_args = None + # When a user uses a DSN with query parameters, and used command line # arguments, use the command line arguments. result = runner.invoke( @@ -946,6 +972,9 @@ class MockMyCli: config = { "main": {}, "alias_dsn": {}, + "connection": { + "default_keepalive_ticks": 0, + }, } def __init__(self, **args): @@ -955,6 +984,7 @@ def __init__(self, **args): self.redirect_formatter = Formatter() self.ssl_mode = "auto" self.my_cnf = {"client": {}, "mysqld": {}} + self.default_keepalive_ticks = 0 def connect(self, **args): MockMyCli.connect_args = args