From 292c52cdb37bd66f7aa360c9c169e7db181f2231 Mon Sep 17 00:00:00 2001 From: uruwhy <58484522+uruwhy@users.noreply.github.com> Date: Mon, 9 Feb 2026 22:20:43 +0000 Subject: [PATCH] improve DNS contact unit test coverage --- app/contacts/contact_dns.py | 30 +-- tests/contacts/test_contact_dns.py | 304 ++++++++++++++++++++++++++--- 2 files changed, 290 insertions(+), 44 deletions(-) diff --git a/app/contacts/contact_dns.py b/app/contacts/contact_dns.py index d1b66369c..3ba05a441 100644 --- a/app/contacts/contact_dns.py +++ b/app/contacts/contact_dns.py @@ -82,8 +82,9 @@ def get_response_code(self): def __str__(self): return '\n'.join([ 'Qname: %s' % self.qname, + 'Is query: %s' % self.is_query(), 'Is response: %s' % self.is_response(), - 'Transaction ID: 0x%02x' % self.transaction_id, + 'Transaction ID: 0x%04x' % self.transaction_id, 'Flags: 0x%04x' % self.flags, 'Num questions: %d' % self.num_questions, 'Num answer resource records: %d' % self.num_answer_rrs, @@ -92,8 +93,8 @@ def __str__(self): 'Record type: %d' % self.record_type.value, 'Class: %d' % self.dns_class, 'Standard query: %s' % self.has_standard_query(), - 'Opcode: 0x%03x' % self.get_opcode(), - 'Response code: 0x%02x' % self.get_response_code(), + 'Opcode: 0x%04x' % self.get_opcode(), + 'Response code: 0x%04x' % self.get_response_code(), 'Recursion desired: %s' % self.recursion_desired(), 'Recursion available: %s' % self.recursion_available(), 'Truncated: %s' % self.truncated(), @@ -191,9 +192,10 @@ def get_bytes(self, byteorder='big'): + self._get_answer_bytes(byteorder=byteorder) def __str__(self): - output = [super().__str__(), 'Answers: '] + output = [super().__str__(), 'Answers:'] for answer in self.answers: - output.append(str(answer)) + answer_str_tabbed = '\n '.join(str(answer).split('\n')) + output.append(' ' + answer_str_tabbed + '\n') return '\n'.join(output) def _get_answer_bytes(self, byteorder='big'): @@ -202,22 +204,6 @@ def _get_answer_bytes(self, byteorder='big'): answer_bytes += answer.get_bytes(byteorder=byteorder) return answer_bytes - def _generate_pointer_and_qname_bytes(self, answer_qname, byteorder='big'): - lowered_answer_qname = answer_qname.lower() - lowered_requested_qname = self.qname.lower() - if lowered_answer_qname == lowered_requested_qname: - return self.standard_pointer.to_bytes(2, byteorder=byteorder) - elif lowered_answer_qname.endswith(lowered_requested_qname): - prefix = lowered_answer_qname[:-len(lowered_requested_qname)] - prefix_labels = [label for label in prefix.split('.') if label] - return self._get_qname_bytes(prefix_labels, byteorder=byteorder) \ - + self.standard_pointer.to_bytes(2, byteorder=byteorder) - elif lowered_requested_qname.endswith(lowered_answer_qname): - offset = len(lowered_requested_qname) - len(lowered_answer_qname) - return (self.standard_pointer + offset).to_bytes(2, byteorder=byteorder) - else: - return self._get_qname_bytes(answer_qname.split('.'), byteorder=byteorder) - @staticmethod def generate_response_for_query(dns_query, r_code, answers, authoritative=True, recursion_available=False, truncated=False): @@ -521,6 +507,8 @@ async def _process_payload_request(self, request_context): # Notify agent that payload is ready self.log.debug('Stored payload %s for request ID %s' % (display_name, request_context.request_id)) return self._generate_server_ready_ipv4_response(request_context.dns_request) + else: + self.log.warning('Failed to fetch file: %s' % filename) else: self.log.warning('Client did not include filename in payload request ID %s' % request_context.request_id) else: diff --git a/tests/contacts/test_contact_dns.py b/tests/contacts/test_contact_dns.py index 163a6f564..9643e71cb 100644 --- a/tests/contacts/test_contact_dns.py +++ b/tests/contacts/test_contact_dns.py @@ -3,24 +3,34 @@ import os import pytest import random +import shutil from base64 import b64decode from dns import message, rdatatype +from unittest import mock from app.contacts.contact_dns import Contact as DnsContact +from app.contacts.contact_dns import DnsPacket, DnsResponse, DnsAnswerObj, DnsRecordType, DnsResponseCodes +from app.objects.c_agent import Agent +from app.service.contact_svc import ContactService +from app.service.file_svc import FileSvc from app.utility.base_world import BaseWorld from app.utility.file_decryptor import read as decrypt_read, get_encryptor +DNS_EXFIL_DIR = '/tmp/testdnsexfil' + + @pytest.fixture(scope='session') def dns_contact_base_world(): + BaseWorld.clear_config() BaseWorld.apply_config(name='main', config={'app.contact.dns.domain': 'mycaldera.caldera', - 'app.contact.dns.socket': '0.0.0.0:53', + 'app.contact.dns.socket': '127.0.0.1:65053', 'plugins': ['sandcat', 'stockpile'], 'crypt_salt': 'BLAH', 'api_key': 'ADMIN123', 'encryption_key': 'ADMIN123', - 'exfil_dir': '/tmp'}) + 'exfil_dir': DNS_EXFIL_DIR}) BaseWorld.apply_config(name='agents', config={'sleep_max': 5, 'sleep_min': 5, 'untrusted_timer': 90, @@ -29,10 +39,14 @@ def dns_contact_base_world(): 'bootstrap_abilities': [ '43b3754c-def4-4699-a673-1d85648fda6a' ]}) + yield BaseWorld + BaseWorld.clear_config() + if os.path.exists(DNS_EXFIL_DIR): + shutil.rmtree(DNS_EXFIL_DIR) @pytest.fixture -async def dns_c2(app_svc, contact_svc, data_svc, file_svc, obfuscator): +async def dns_c2(app_svc, contact_svc, data_svc, file_svc, obfuscator, dns_contact_base_world): services = app_svc.get_services() dns_c2 = DnsContact(services) return dns_c2 @@ -95,6 +109,22 @@ async def _get_instruction_response(message_id): return _get_instruction_response +@pytest.fixture +async def get_payload_filename(random_data, get_dns_response): + async def _get_payload_filename(message_id): + qname = '%s.pf.0.1.%s.mycaldera.caldera' % (message_id, random_data) + return await get_dns_response(qname, 'txt') + return _get_payload_filename + + +@pytest.fixture +async def get_payload_data(random_data, get_dns_response): + async def _get_payload_data(message_id): + qname = '%s.pd.0.1.%s.mycaldera.caldera' % (message_id, random_data) + return await get_dns_response(qname, 'txt') + return _get_payload_data + + @pytest.fixture def get_hex_chunks(): def _get_hex_chunks(data): @@ -123,9 +153,111 @@ def _get_file_upload_data_qnames(message_id, data_hex_chunks): return _get_file_upload_data_qnames -@pytest.mark.usefixtures( - 'dns_contact_base_world' -) +@pytest.fixture +def get_payload_request_qnames(): + def _get_payload_request_qnames(message_id, data_hex_chunks): + num_chunks = len(data_hex_chunks) + return ['%s.pr.%d.%d.%s.mycaldera.caldera' % (message_id, i, num_chunks, data_hex_chunks[i]) + for i in range(0, num_chunks)] + return _get_payload_request_qnames + + +@pytest.fixture +def dns_dummy_agent(): + return Agent(paw='testpaw', sleep_min=5, sleep_max=5, watchdog=0, executors=['sh', 'proc']) + + +class TestDnsAuxiliary: + def test_generate_packets_from_bytes(self): + # Request + packet_bytes = bytes([ + 0x02, 0x83, 0x01, 0x00, 0x00, 0x01, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x06, 0x67, 0x6f, 0x6f, + 0x67, 0x6c, 0x65, 0x03, 0x63, 0x6f, 0x6d, 0x00, + 0x00, 0x01, 0x00, 0x01 + ]) + + query_packet = DnsPacket.generate_packet_from_bytes(packet_bytes) + want_str = '''Qname: google.com +Is query: True +Is response: False +Transaction ID: 0x0283 +Flags: 0x0100 +Num questions: 1 +Num answer resource records: 0 +Num auth resource records: 0 +Num additional resource records: 0 +Record type: 1 +Class: 1 +Standard query: True +Opcode: 0x0000 +Response code: 0x0000 +Recursion desired: True +Recursion available: False +Truncated: False''' + assert str(query_packet) == want_str + + # Response + packet_bytes = bytes([ + 0x02, 0x83, 0x81, 0x80, 0x00, 0x01, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x06, 0x67, 0x6f, 0x6f, + 0x67, 0x6c, 0x65, 0x03, 0x63, 0x6f, 0x6d, 0x00, 0x00, 0x01, 0x00, 0x01, 0xc0, 0x0c, 0x00, 0x01, + 0x00, 0x01, 0x00, 0x00, 0x00, 0x69, 0x00, 0x04, 0xac, 0xfd, 0x8b, 0x8a, 0xc0, 0x0c, 0x00, 0x01, + 0x00, 0x01, 0x00, 0x00, 0x00, 0x69, 0x00, 0x04, 0xac, 0xfd, 0x8b, 0x71, 0xc0, 0x0c, 0x00, 0x01, + 0x00, 0x01, 0x00, 0x00, 0x00, 0x69, 0x00, 0x04, 0xac, 0xfd, 0x8b, 0x64, 0xc0, 0x0c, 0x00, 0x01, + 0x00, 0x01, 0x00, 0x00, 0x00, 0x69, 0x00, 0x04, 0xac, 0xfd, 0x8b, 0x65, 0xc0, 0x0c, 0x00, 0x01, + 0x00, 0x01, 0x00, 0x00, 0x00, 0x69, 0x00, 0x04, 0xac, 0xfd, 0x8b, 0x66, 0xc0, 0x0c, 0x00, 0x01, + 0x00, 0x01, 0x00, 0x00, 0x00, 0x69, 0x00, 0x04, 0xac, 0xfd, 0x8b, 0x8b + ]) + resp_packet = DnsPacket.generate_packet_from_bytes(packet_bytes) + want_str = '''Qname: google.com +Is query: False +Is response: True +Transaction ID: 0x0283 +Flags: 0x8180 +Num questions: 1 +Num answer resource records: 1 +Num auth resource records: 0 +Num additional resource records: 0 +Record type: 1 +Class: 1 +Standard query: True +Opcode: 0x0000 +Response code: 0x0000 +Recursion desired: True +Recursion available: True +Truncated: False''' + assert str(resp_packet) == want_str + + dummy_answer = DnsAnswerObj(DnsRecordType.A, 0x1, 105, bytes([0xac, 0xfd, 0x8b, 0x8a])) + response = DnsResponse.generate_response_for_query(query_packet, DnsResponseCodes.SUCCESS, [dummy_answer], authoritative=False, + recursion_available=True, truncated=False) + want_str = '''Qname: google.com +Is query: False +Is response: True +Transaction ID: 0x0283 +Flags: 0x8180 +Num questions: 1 +Num answer resource records: 1 +Num auth resource records: 0 +Num additional resource records: 0 +Record type: 1 +Class: 1 +Standard query: True +Opcode: 0x0000 +Response code: 0x0000 +Recursion desired: True +Recursion available: True +Truncated: False +Answers: + Record type: 1 + Dns class: 1 + TTL: 105 + Data: acfd8b8a + Data length: 4 +''' + assert str(response) == want_str + + class TestContactDns: _RCODE_NXDOMAIN = 3 _RCODE_SUCCESS = 0 @@ -139,6 +271,11 @@ def _assert_successful_ivp4(response_msg): assert len(response_msg.answer[0]) == 1 assert response_msg.answer[0][0].rdtype == rdatatype.RdataType.A + @staticmethod + def _assert_nxdomain_response(response_msg): + assert response_msg and response_msg.rcode() == TestContactDns._RCODE_NXDOMAIN + assert len(response_msg.answer) == 0 + @staticmethod def _assert_even_ipv4(response_msg): # Last octet should be even if the server is expecting more data @@ -159,7 +296,7 @@ def test_handler_setup(self, dns_c2): async def test_non_c2_domain_message(self, get_dns_response): response_msg = await get_dns_response('notthec2domain', 'a') - assert response_msg and response_msg.rcode() == self._RCODE_NXDOMAIN + self._assert_nxdomain_response(response_msg) async def test_partial_beacon_message(self, get_dns_response, get_beacon_profile_qnames, message_id): first_qname = get_beacon_profile_qnames(message_id)[0] @@ -181,13 +318,14 @@ async def test_completed_beacon_message(self, get_dns_response, get_beacon_profi self._assert_even_ipv4(response_msg) async def test_instruction_download(self, get_dns_response, get_beacon_profile_qnames, message_id, - get_instruction_response): - # Send beacon before asking for instructions - for qname in get_beacon_profile_qnames(message_id): - await get_dns_response(qname, 'a') - - # Get instructions - response_msg = await get_instruction_response(message_id) + get_instruction_response, dns_dummy_agent): + with mock.patch.object(ContactService, 'handle_heartbeat', return_value=(dns_dummy_agent, [])): + # Send beacon before asking for instructions + for qname in get_beacon_profile_qnames(message_id): + await get_dns_response(qname, 'a') + + # Get instructions + response_msg = await get_instruction_response(message_id) assert response_msg and response_msg.rcode() == self._RCODE_SUCCESS # Make sure we only get 1 TXT record @@ -200,32 +338,129 @@ async def test_instruction_download(self, get_dns_response, get_beacon_profile_q # Last character should be , if returning complete instructions assert txt_response[-1] == ',' - beacon_resp = json.loads(b64decode(txt_response).decode('utf-8')) + beacon_resp = json.loads(b64decode(txt_response[:-1]).decode('utf-8')) assert 'paw' in beacon_resp - want = dict(paw=beacon_resp.get('paw'), + want = dict(paw='testpaw', sleep=5, watchdog=0, instructions='[]') assert want == beacon_resp + async def test_payload_download(self, get_dns_response, get_hex_chunks, get_payload_request_qnames, get_payload_filename, + get_payload_data, message_id): + dummy_payload_data = bytes([0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef]) + with mock.patch.object(FileSvc, 'get_file', return_value=('testplugin/payloads/testdownload', dummy_payload_data, 'testdownload')): + # Request payload + filename = 'testdownload' + req_metadata = dict(file=filename) + metadata_hex_chunks = get_hex_chunks(json.dumps(req_metadata).encode('utf-8')) + metadata_qnames = get_payload_request_qnames(message_id, metadata_hex_chunks) + final_index = len(metadata_qnames) - 1 + + for index, qname in enumerate(metadata_qnames): + response_msg = await get_dns_response(qname, 'a') + assert response_msg and response_msg.rcode() == self._RCODE_SUCCESS + + # Check final octet + if index == final_index: + self._assert_odd_ipv4(response_msg) + else: + self._assert_even_ipv4(response_msg) + + # Fetch payload name + response_msg = await get_payload_filename(message_id) + assert response_msg and response_msg.rcode() == self._RCODE_SUCCESS + + # Make sure we only get 1 TXT record + assert len(response_msg.answer) == 1 + assert len(response_msg.answer[0]) == 1 + answer = response_msg.answer[0][0] + assert answer.rdtype == rdatatype.RdataType.TXT + assert len(answer.strings) == 1 + txt_response = answer.strings[0].decode('utf-8') + + # Last character should be , if returning complete instructions + assert txt_response[-1] == ',' + assert filename == b64decode(txt_response[:-1]).decode('utf-8') + + # Fetch payload data + response_msg = await get_payload_data(message_id) + assert response_msg and response_msg.rcode() == self._RCODE_SUCCESS + + # Make sure we only get 1 TXT record + assert len(response_msg.answer) == 1 + assert len(response_msg.answer[0]) == 1 + answer = response_msg.answer[0][0] + assert answer.rdtype == rdatatype.RdataType.TXT + assert len(answer.strings) == 1 + txt_response = answer.strings[0].decode('utf-8') + + # Last character should be , if returning complete instructions + assert txt_response[-1] == ',' + assert dummy_payload_data == b64decode(txt_response[:-1]) + + async def test_bad_payload_download(self, get_dns_response, get_hex_chunks, get_payload_request_qnames, message_id): + # Test file service exceptions + filename = 'testdownload' + req_metadata = dict(file=filename) + metadata_hex_chunks = get_hex_chunks(json.dumps(req_metadata).encode('utf-8')) + metadata_qnames = get_payload_request_qnames(message_id, metadata_hex_chunks) + final_index = len(metadata_qnames) - 1 + + with mock.patch.object(FileSvc, 'get_file', side_effect=FileNotFoundError('Dummy error')): + for index, qname in enumerate(metadata_qnames): + response_msg = await get_dns_response(qname, 'a') + + # Check final octet + if index == final_index: + self._assert_nxdomain_response(response_msg) + else: + self._assert_even_ipv4(response_msg) + + with mock.patch.object(FileSvc, 'get_file', side_effect=Exception('Dummy error')): + for index, qname in enumerate(metadata_qnames): + response_msg = await get_dns_response(qname, 'a') + + # Check final octet + if index == final_index: + self._assert_nxdomain_response(response_msg) + else: + self._assert_even_ipv4(response_msg) + + # Test bad requests + req_metadata = [dict(), dict(a='irrelevant')] + for metadata in req_metadata: + metadata_hex_chunks = get_hex_chunks(json.dumps(metadata).encode('utf-8')) + metadata_qnames = get_payload_request_qnames(message_id, metadata_hex_chunks) + final_index = len(metadata_qnames) - 1 + for index, qname in enumerate(metadata_qnames): + response_msg = await get_dns_response(qname, 'a') + + # Check final octet + if index == final_index: + self._assert_nxdomain_response(response_msg) + else: + self._assert_even_ipv4(response_msg) + async def test_unsupported_client_request(self, get_dns_response, message_id, random_data): invalid_qname = '%s.invalid.0.1.%s.mycaldera.caldera' % (message_id, random_data) response_msg = await get_dns_response(invalid_qname, 'a') - assert response_msg and response_msg.rcode() == self._RCODE_NXDOMAIN + self._assert_nxdomain_response(response_msg) async def test_invalid_instruction_request(self, get_dns_response, message_id, random_data): invalid_qname = '%s.id.0.1.%s.mycaldera.caldera' % (message_id, random_data) response_msg = await get_dns_response(invalid_qname, 'a') # Should be TXT request - assert response_msg and response_msg.rcode() == self._RCODE_NXDOMAIN + self._assert_nxdomain_response(response_msg) async def test_file_upload(self, get_dns_response, message_id, get_hex_chunks, get_file_upload_metadata_qnames, - get_file_upload_data_qnames): + get_file_upload_data_qnames, dns_c2): + dns_c2.set_config('main', 'exfil_dir', DNS_EXFIL_DIR) paw = 'asdasd' filename = 'testupload.txt' hostname = 'testhost' directory = '%s-%s' % (hostname, paw) upload_metadata = dict(paw=paw, file=filename, directory=directory) - target_dir = '/tmp/%s' % directory + target_dir = f'{DNS_EXFIL_DIR}/{directory}' target_path = '%s/%s-%s' % (target_dir, filename, message_id) file_data = b'thiswilltakemultiplednsrequests' * 100 metadata_hex_chunks = get_hex_chunks(json.dumps(upload_metadata).encode('utf-8')) @@ -268,10 +503,33 @@ async def test_file_upload(self, get_dns_response, message_id, get_hex_chunks, g assert (not decrypt_error), 'Exception occurred when decrypting uploaded file: %s' % decrypt_error assert file_data == decrypted_upload + async def test_bad_file_upload(self, get_dns_response, message_id, get_hex_chunks, get_file_upload_metadata_qnames, + get_file_upload_data_qnames): + # Test missing info + upload_metadata = [dict(paw='test'), dict()] + for metadata in upload_metadata: + metadata_hex_chunks = get_hex_chunks(json.dumps(metadata).encode('utf-8')) + metadata_qnames = get_file_upload_metadata_qnames(message_id, metadata_hex_chunks) + final_index = len(metadata_qnames) - 1 + for index, qname in enumerate(metadata_qnames): + response_msg = await get_dns_response(qname, 'a') + + if index == final_index: + self._assert_nxdomain_response(response_msg) + else: + self._assert_successful_ivp4(response_msg) + self._assert_even_ipv4(response_msg) + + async def test_ipv6_placeholder(self, dns_c2, get_dns_response): + response_msg = await get_dns_response('test.mycaldera.caldera', rdatatype.RdataType.AAAA) + assert response_msg and response_msg.rcode() == TestContactDns._RCODE_SUCCESS + + # Make sure we got back an IPv6 address + assert len(response_msg.answer) == 1 + assert len(response_msg.answer[0]) == 1 + assert response_msg.answer[0][0].rdtype == rdatatype.RdataType.AAAA + @staticmethod def _get_decrypted_upload(filepath): encryptor = get_encryptor('BLAH', 'ADMIN123') return decrypt_read(filepath, encryptor) - - def test_unexpected_file_upload(self): - assert True