diff --git a/app/contacts/contact_tcp.py b/app/contacts/contact_tcp.py index 76a350438..18302a070 100644 --- a/app/contacts/contact_tcp.py +++ b/app/contacts/contact_tcp.py @@ -6,7 +6,7 @@ from typing import Tuple from app.utility.base_world import BaseWorld -from plugins.manx.app.c_session import Session +from app.contacts.utility.c_tcp_session import TCPSession class Contact(BaseWorld): @@ -14,6 +14,7 @@ class Contact(BaseWorld): def __init__(self, services): self.name = 'tcp' self.description = 'Accept beacons through a raw TCP socket' + self.services = services self.log = self.create_logger('contact_tcp') self.contact_svc = services.get('contact_svc') self.tcp_handler = TcpSessionHandler(services, self.log) @@ -93,18 +94,14 @@ def __init__(self, services, log): self.sessions = [] async def refresh(self): - index = 0 - - while index < len(self.sessions): - session = self.sessions[index] - + refreshed_sessions = [] + for session in self.sessions: try: - session.writer.write(str.encode(' ')) + session.write_bytes(str.encode(' ')) + refreshed_sessions.append(session) except socket.error: self.log.debug('Error occurred when refreshing session %s. Removing from session pool.', session.id) - del self.sessions[index] - else: - index += 1 + self.sessions = refreshed_sessions async def accept(self, reader, writer): self.log.debug('Accepting connection.') @@ -116,19 +113,30 @@ async def accept(self, reader, writer): profile['executors'] = [e for e in profile['executors'].split(',') if e] profile['contact'] = 'tcp' agent, _ = await self.services.get('contact_svc').handle_heartbeat(**profile) - new_session = Session(id=self.generate_number(size=6), paw=agent.paw, reader=reader, writer=writer) + new_session = TCPSession(id=self.generate_number(size=6), paw=agent.paw, reader=reader, writer=writer) self.sessions.append(new_session) await self.send(new_session.id, agent.paw, timeout=5) async def send(self, session_id: int, cmd: str, timeout: int = 60) -> Tuple[int, str, str, str]: try: - session = next(i for i in self.sessions if i.id == int(session_id)) - session.writer.write(str.encode(' ')) + try: + session = next(i for i in self.sessions if i.id == int(session_id)) + except StopIteration: + msg = f'Could not find session with ID {session_id}' + self.log.error(msg) + return 1, '~$ ', msg, '' + + session.write_bytes(str.encode(' ')) time.sleep(0.01) - session.writer.write(str.encode('%s\n' % cmd)) - response = await self._attempt_connection(session_id, session.reader, timeout=timeout) - response = json.loads(response) - return response['status'], response['pwd'], response['response'], response.get('agent_reported_time', '') + session.write_bytes(str.encode('%s\n' % cmd)) + response = await self._attempt_connection(session, timeout=timeout) + if response: + response = json.loads(response) + return response.get('status', 1), response.get('pwd', '~$ '), response.get('response', 'No response provided'), response.get('agent_reported_time', '') + else: + msg = f'Failed to read data from session {session.id}' + self.log.error(msg) + return 1, '~$ ', msg, '' except Exception as e: self.log.exception(e) return 1, '~$ ', str(e), '' @@ -138,17 +146,17 @@ async def _handshake(reader): profile_bites = (await reader.readline()).strip() return json.loads(profile_bites) - async def _attempt_connection(self, session_id, reader, timeout): + async def _attempt_connection(self, session, timeout): buffer = 4096 data = b'' time.sleep(0.1) # initial wait for fast operations. while True: try: - part = await reader.read(buffer) + part = await session.read_bytes(buffer) data += part if len(part) < buffer: break except Exception as err: - self.log.error("Timeout reached for session %s", session_id) + self.log.error("Timeout reached for session %s", session.id) return json.dumps(dict(status=1, pwd='~$ ', response=str(err))) return str(data, 'utf-8') diff --git a/app/contacts/utility/c_tcp_session.py b/app/contacts/utility/c_tcp_session.py new file mode 100644 index 000000000..528585fcb --- /dev/null +++ b/app/contacts/utility/c_tcp_session.py @@ -0,0 +1,32 @@ +from app.utility.base_object import BaseObject + + +class TCPSession(BaseObject): + + @property + def unique(self): + return self.hash('%s' % self.paw) + + def __init__(self, id, paw, reader, writer): + super().__init__() + self.id = id + self.paw = paw + self._reader = reader + self._writer = writer + + def store(self, ram): + existing = self.retrieve(ram['sessions'], self.unique) + if not existing: + ram['sessions'].append(self) + return self.retrieve(ram['sessions'], self.unique) + return existing + + def write_bytes(self, input): + """Wrapper for self._writer.write""" + + return self._writer.write(input) + + def read_bytes(self, buffer): + """Wrapper for self._reader.read""" + + return self._reader.read(buffer) diff --git a/tests/contacts/test_contact_tcp.py b/tests/contacts/test_contact_tcp.py index 70af5ba38..ceaaeca57 100644 --- a/tests/contacts/test_contact_tcp.py +++ b/tests/contacts/test_contact_tcp.py @@ -1,19 +1,43 @@ import logging import socket from unittest import mock +import pytest + +from app.service.contact_svc import ContactService +from app.utility.base_world import BaseWorld from app.contacts.contact_tcp import TcpSessionHandler +from app.contacts.contact_tcp import Contact +from app.contacts.utility.c_tcp_session import TCPSession +from app.objects.secondclass.c_instruction import Instruction logger = logging.getLogger(__name__) +@pytest.fixture +def tcp_c2(app_svc, contact_svc, data_svc, obfuscator): + services = app_svc.get_services() + tcp_contact_svc = Contact(services=services) + return tcp_contact_svc + + +class _MockReader: + async def read(self, n=-1): + return b'MockContent' + + +class _MockWriter: + def write(self, data): + pass + + class TestTcpSessionHandler: def test_refresh_with_socket_errors(self, event_loop): handler = TcpSessionHandler(services=None, log=logger) session_with_socket_error = mock.Mock() - session_with_socket_error.writer.write.side_effect = socket.error() + session_with_socket_error.write_bytes.side_effect = socket.error() handler.sessions = [ session_with_socket_error, @@ -35,3 +59,82 @@ def test_refresh_without_socket_errors(self, event_loop): event_loop.run_until_complete(handler.refresh()) assert len(handler.sessions) == 3 + + async def test_attempt_connection(self, tcp_c2): + MockSession = TCPSession(id=123456, paw='testpaw', reader=_MockReader(), writer=_MockWriter()) + assert "MockContent" == await tcp_c2.tcp_handler._attempt_connection(MockSession, timeout=1) + + async def test_accept(self, tcp_c2): + dummy_profile = { + 'architecture': 'amd64', + 'exe_name': 'splunkd', + 'executors': 'sh', + 'host': 'Caldera', + 'location': './splunkd', + 'pid': 10057, + 'platform': 'linux', + 'ppid': 9752, + 'server': '0.0.0.0:7010', + 'username': 'caldera' + } + with mock.patch.object(TcpSessionHandler, '_handshake', return_value=(dummy_profile)): + await tcp_c2.tcp_handler.accept(reader=_MockReader(), writer=_MockWriter()) + assert len(tcp_c2.tcp_handler.sessions) == 1 + + async def test_accept_err(self, tcp_c2): + with mock.patch.object(TcpSessionHandler, '_handshake', side_effect=Exception('mock exception')): + await tcp_c2.tcp_handler.accept(reader=_MockReader(), writer=_MockWriter()) + assert len(tcp_c2.tcp_handler.sessions) == 0 + + async def test_send_no_session(self, tcp_c2): + status, pwd, response, agent_time = await tcp_c2.tcp_handler.send(session_id=999999, cmd='whoami', timeout=1) + assert status == 1 + assert 'Could not find session with ID 999999' == response + assert pwd == '~$ ' + assert agent_time == '' + + async def test_send_with_session_err(self, tcp_c2): + mock_session = TCPSession(id=123456, paw='testpaw', reader=_MockReader(), writer=_MockWriter()) + tcp_c2.tcp_handler.sessions.append(mock_session) + with mock.patch.object(TcpSessionHandler, '_attempt_connection', side_effect=Exception('Test exception')): + status, pwd, response, agent_time = await tcp_c2.tcp_handler.send(session_id=123456, cmd='whoami', timeout=1) + assert status == 1 + assert 'Test exception' == response + assert pwd == '~$ ' + assert agent_time == '' + + async def test_send_with_session_no_response(self, tcp_c2): + mock_session = TCPSession(id=123456, paw='testpaw', reader=_MockReader(), writer=_MockWriter()) + tcp_c2.tcp_handler.sessions.append(mock_session) + with mock.patch.object(TcpSessionHandler, '_attempt_connection', return_value=''): + status, pwd, response, agent_time = await tcp_c2.tcp_handler.send(session_id=123456, cmd='whoami', timeout=1) + assert status == 1 + assert 'Failed to read data from session 123456' == response + assert pwd == '~$ ' + assert agent_time == '' + + +class TestContact: + def test_tcp_contact(self, event_loop, tcp_c2): + BaseWorld.set_config('main', 'app.contact.tcp', '127.0.0.1:57012') + dummy_instruction = Instruction( + id='123', + sleep=5, + command='whoami', + executor='sh', + timeout=60, + payloads=[], + uploads=[], + deadman=False, + delete_payload=True + ) + tcp_c2.tcp_handler.sessions.append(TCPSession( + id=1, + paw='dummy_paw', + reader=_MockReader(), + writer=_MockWriter() + )) + event_loop.run_until_complete(tcp_c2.start()) + with mock.patch.object(ContactService, 'handle_heartbeat', return_value=('dummy_paw', [dummy_instruction])): + event_loop.run_until_complete(tcp_c2.handle_sessions()) + assert len(tcp_c2.tcp_handler.sessions) == 1