diff --git a/changelog.md b/changelog.md index c51d4e9c..6574be20 100644 --- a/changelog.md +++ b/changelog.md @@ -6,6 +6,7 @@ Features * Let the `--dsn` argument accept literal DSNs as well as aliases. * Accept `--character-set` as an alias for `--charset` at the CLI. * Add SSL/TLS version to `status` output. +* Add prompt format string for SSL/TLS version of the connection. Bug Fixes diff --git a/mycli/main.py b/mycli/main.py index a3899fe0..a41162d9 100755 --- a/mycli/main.py +++ b/mycli/main.py @@ -66,6 +66,7 @@ from mycli.packages.prompt_utils import confirm, confirm_destructive_query from mycli.packages.special.favoritequeries import FavoriteQueries from mycli.packages.special.main import ArgType +from mycli.packages.special.utils import get_ssl_version from mycli.packages.sqlresult import SQLResult from mycli.packages.tabular_output import sql_format from mycli.packages.toolkit.history import FileHistoryWithTimestamp @@ -1478,6 +1479,13 @@ def get_prompt(self, string: str) -> str: string = string.replace("\\K", sqlexecute.socket or str(sqlexecute.port)) string = string.replace("\\A", self.dsn_alias or "(none)") string = string.replace("\\_", " ") + # jump through hoops for the test environment and for efficiency + if hasattr(sqlexecute, 'conn') and sqlexecute.conn is not None: + if '\\T' in string: + with sqlexecute.conn.cursor() as cur: + string = string.replace('\\T', get_ssl_version(cur) or '(none)') + else: + string = string.replace('\\T', '(none)') return string def run_query( diff --git a/mycli/myclirc b/mycli/myclirc index 44494409..142897b7 100644 --- a/mycli/myclirc +++ b/mycli/myclirc @@ -113,6 +113,7 @@ wider_completion_menu = False # * \J - full connection socket path # * \k - connection socket basename OR the port # * \K - full connection socket path OR the port +# * \T - connection SSL/TLS version # * \t - database vendor (Percona, MySQL, MariaDB, TiDB) # * \u - username # * \A - DSN alias diff --git a/mycli/packages/special/utils.py b/mycli/packages/special/utils.py index 98b1e99d..c5b7cd6e 100644 --- a/mycli/packages/special/utils.py +++ b/mycli/packages/special/utils.py @@ -2,11 +2,12 @@ import os import subprocess +import pymysql from pymysql.cursors import Cursor logger = logging.getLogger(__name__) -CACHED_SSL_VERSION: dict[int, str | None] = {} +CACHED_SSL_VERSION: dict[tuple, str | None] = {} def handle_cd_command(arg: str) -> tuple[bool, str | None]: @@ -56,18 +57,24 @@ def format_uptime(uptime_in_seconds: str) -> str: def get_ssl_version(cur: Cursor) -> str | None: - if cur.connection.thread_id() in CACHED_SSL_VERSION: - return CACHED_SSL_VERSION[cur.connection.thread_id()] or None + cache_key = (id(cur.connection), cur.connection.thread_id()) + + if cache_key in CACHED_SSL_VERSION: + return CACHED_SSL_VERSION[cache_key] or None query = 'SHOW STATUS LIKE "Ssl_version"' logger.debug(query) - cur.execute(query) ssl_version = None - if one := cur.fetchone(): - CACHED_SSL_VERSION[cur.connection.thread_id()] = one[1] - ssl_version = one[1] or None - else: - CACHED_SSL_VERSION[cur.connection.thread_id()] = '' + + try: + cur.execute(query) + if one := cur.fetchone(): + CACHED_SSL_VERSION[cache_key] = one[1] + ssl_version = one[1] or None + else: + CACHED_SSL_VERSION[cache_key] = '' + except pymysql.err.OperationalError: + pass return ssl_version diff --git a/test/myclirc b/test/myclirc index 27f90bf7..6d58422b 100644 --- a/test/myclirc +++ b/test/myclirc @@ -111,6 +111,7 @@ wider_completion_menu = False # * \J - full connection socket path # * \k - connection socket basename OR the port # * \K - full connection socket path OR the port +# * \T - connection SSL/TLS version # * \t - database vendor (Percona, MySQL, MariaDB, TiDB) # * \u - username # * \A - DSN alias