diff --git a/src/charm.py b/src/charm.py index 0f7c7a5..b94bef8 100755 --- a/src/charm.py +++ b/src/charm.py @@ -98,7 +98,7 @@ def __init__(self, *args): handler.setFormatter(logging.Formatter("{name}:{message}", style="{")) # Watcher mode: lightweight Raft witness, no PostgreSQL - self._init_watcher_mode() + self.watcher_requirer = WatcherRequirerHandler(self) # Set tracing_endpoint for @trace_charm decorator compatibility self.tracing_endpoint = None @@ -119,12 +119,6 @@ def __init__(self, *args): else: self.refresh.next_unit_allowed_to_refresh = True - def _init_watcher_mode(self): - """Initialize the charm in watcher mode (lightweight Raft witness).""" - self.watcher_requirer = WatcherRequirerHandler(self) - # Watcher mode delegates all event handling to WatcherRequirerHandler. - # We still observe leader_elected to persist the role in peer data. - def _post_snap_refresh(self, refresh: charm_refresh.Machines): """Start PostgreSQL, check if this app and unit are healthy, and allow next unit to refresh. diff --git a/src/raft_controller.py b/src/raft_controller.py index aae3c06..2048ed8 100644 --- a/src/raft_controller.py +++ b/src/raft_controller.py @@ -187,15 +187,13 @@ def configure( return True def check_watcher_connection( - self, member_address: str, raft_password: str, partner_addrs: list[str], port: int + self, member_address: str, raft_password: str, partner_addrs: list[str] ) -> None: """Verify that the watcher has joined the Raft cluster.""" if not partner_addrs: logger.debug("Check connection early exit: No partners provided") return - watcher_addr = f"{member_address}:{port}" - # Get the status of the raft cluster. syncobj_util = TcpUtility(password=raft_password, timeout=3) @@ -203,7 +201,9 @@ def check_watcher_connection( try: for attempt in Retrying(stop=stop_after_attempt(10), wait=wait_fixed(2)): with attempt: - if not (raft_status := syncobj_util.executeCommand(watcher_addr, ["status"])): + if not ( + raft_status := syncobj_util.executeCommand(member_address, ["status"]) + ): raise Exception("Raft watcher no status") logger.debug(f"Observer raft: {raft_status}") for key in raft_status: @@ -282,17 +282,15 @@ def restart(self) -> bool: return False def cleanup_raft_cluster( - self, member_address: str, raft_password: str, partner_addrs: list[str], port: int + self, member_address: str, raft_password: str, partner_addrs: list[str] ) -> bool: """Cleanup RAFT members not belonging to the current cluster or not a related watcher.""" # Get Raft cluster status to find all members try: - watcher_addr = f"{member_address}:{port}" - # Get the status of the raft cluster. syncobj_util = TcpUtility(password=raft_password, timeout=3) - for raft_host in [watcher_addr, *[f"{addr}:{RAFT_PORT}" for addr in partner_addrs]]: + for raft_host in [member_address, *[f"{addr}:{RAFT_PORT}" for addr in partner_addrs]]: if raft_status := syncobj_util.executeCommand(raft_host, ["status"]): # Find all partner nodes in the Raft cluster # Keys look like: partner_node_status_server_10.131.50.142:2222 diff --git a/src/relations/watcher_requirer.py b/src/relations/watcher_requirer.py index add8c56..104a55b 100644 --- a/src/relations/watcher_requirer.py +++ b/src/relations/watcher_requirer.py @@ -281,14 +281,15 @@ def _on_leader_elected(self, _) -> None: def _update_unit_address_if_changed(self) -> None: """Update unit-address in relation data if IP has changed, for ALL relations.""" - if not (new_address := self.unit_ip): + if not (new_address := self.unit_ip) or not self.charm.unit.is_leader(): return + current_address = self.charm.app_peer_data.get("unit-address") + address_changed = current_address != new_address + unit_az = os.environ.get("JUJU_AVAILABILITY_ZONE") for relation in self.model.relations.get(WATCHER_RELATION, []): - current_address = relation.data[self.charm.unit].get("unit-address") current_az = relation.data[self.charm.app].get("unit-az") - address_changed = current_address != new_address az_changed = bool(unit_az and current_az != unit_az) if not address_changed and not az_changed: @@ -310,6 +311,7 @@ def _update_unit_address_if_changed(self) -> None: and (partner_addrs := self._get_raft_partner_addrs(relation)) ): port = self._get_port_for_relation(relation.id) + watcher_addr = f"{new_address}:{port}" raft_controller = RaftController(self.charm, f"rel{relation.id}") changed = raft_controller.configure( port, @@ -324,11 +326,10 @@ def _update_unit_address_if_changed(self) -> None: ) raft_controller.restart() raft_controller.check_watcher_connection( - new_address, raft_password, partner_addrs, port + watcher_addr, raft_password, partner_addrs ) - raft_controller.cleanup_raft_cluster( - new_address, raft_password, partner_addrs, port - ) + raft_controller.cleanup_raft_cluster(watcher_addr, raft_password, partner_addrs) + self.charm.app_peer_data["unit-address"] = new_address def _on_update_status(self, event: UpdateStatusEvent) -> None: """Handle update status event in watcher mode.""" @@ -474,8 +475,10 @@ def _on_watcher_relation_changed( # Get or assign a port for this relation port = self._get_port_for_relation(relation.id) + watcher_addr = f"{self.unit_ip}:{port}" raft_controller = RaftController(self.charm, f"rel{relation.id}") + raft_controller.cleanup_raft_cluster(watcher_addr, raft_password, partner_addrs) if self._is_disabled(relation) or not self._should_watcher_vote(partner_addrs): logger.debug("Disabling the watcher") raft_controller.remove_service() @@ -493,7 +496,7 @@ def _on_watcher_relation_changed( ) raft_controller.restart() raft_controller.check_watcher_connection( - unit_ip, raft_password, partner_addrs, port + watcher_addr, raft_password, partner_addrs ) relation.data[self.charm.unit]["unit-address"] = unit_ip diff --git a/tests/integration/ha_tests/test_stereo_mode.py b/tests/integration/ha_tests/test_stereo_mode.py index d2f1120..f24f986 100644 --- a/tests/integration/ha_tests/test_stereo_mode.py +++ b/tests/integration/ha_tests/test_stereo_mode.py @@ -116,7 +116,7 @@ async def verify_raft_cluster_health( # Check Raft status using the password syncobj_util = TcpUtility(password=password, timeout=3) status = syncobj_util.executeCommand(self_addr, ["status"]) - logger.info(f"Raft status on {unit.name}: {status}...") + logger.info(f"Raft status on {unit.name}: {status}") # Verify quorum assert status["has_quorum"] is True, f"Unit {unit.name} does not have Raft quorum" diff --git a/tests/unit/test_raft_controller.py b/tests/unit/test_raft_controller.py index e2245e5..543c7a2 100644 --- a/tests/unit/test_raft_controller.py +++ b/tests/unit/test_raft_controller.py @@ -102,14 +102,14 @@ def test_check_watcher_connection(controller: RaftController): patch("raft_controller.stop_after_attempt", return_value=stop_after_delay(0)), ): # No partners - controller.check_watcher_connection("1.1.1.1", "testpass", [], 2223) + controller.check_watcher_connection("1.1.1.1:2223", "testpass", []) assert not _tcputility.called # Can't get watcher status _tcputility.return_value.executeCommand.side_effect = [{}] - controller.check_watcher_connection("1.1.1.1", "testpass", ["2.2.2.2", "3.3.3.3"], 2223) + controller.check_watcher_connection("1.1.1.1:2223", "testpass", ["2.2.2.2", "3.3.3.3"]) _tcputility.assert_called_once_with(password="testpass", timeout=3) _tcputility.return_value.executeCommand.assert_called_once_with("1.1.1.1:2223", ["status"]) @@ -124,7 +124,7 @@ def test_check_watcher_connection(controller: RaftController): } _tcputility.return_value.executeCommand.side_effect = [raft_status] - controller.check_watcher_connection("1.1.1.1", "testpass", ["2.2.2.2", "3.3.3.3"], 2223) + controller.check_watcher_connection("1.1.1.1:2223", "testpass", ["2.2.2.2", "3.3.3.3"]) _tcputility.assert_called_once_with(password="testpass", timeout=3) _tcputility.return_value.executeCommand.assert_called_once_with("1.1.1.1:2223", ["status"]) @@ -139,7 +139,7 @@ def test_check_watcher_connection(controller: RaftController): } _tcputility.return_value.executeCommand.side_effect = [raft_status, Exception, Exception] - controller.check_watcher_connection("1.1.1.1", "testpass", ["2.2.2.2", "3.3.3.3"], 2223) + controller.check_watcher_connection("1.1.1.1:2223", "testpass", ["2.2.2.2", "3.3.3.3"]) _tcputility.assert_called_once_with(password="testpass", timeout=3) assert _tcputility.return_value.executeCommand.call_count == 3 @@ -157,7 +157,7 @@ def test_check_watcher_connection(controller: RaftController): } _tcputility.return_value.executeCommand.side_effect = [raft_status, Exception, {1: 2}] - controller.check_watcher_connection("1.1.1.1", "testpass", ["2.2.2.2", "3.3.3.3"], 2223) + controller.check_watcher_connection("1.1.1.1:2223", "testpass", ["2.2.2.2", "3.3.3.3"]) _tcputility.assert_called_once_with(password="testpass", timeout=3) assert _tcputility.return_value.executeCommand.call_count == 3