Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion langfun/core/modalities/mime.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,18 @@ def is_compatible(
def _is_compatible(self, mime_types: Iterable[str]):
if not mime_types:
return False
return self.mime_type in mime_types
if self.mime_type in mime_types:
return True
# Fallback: if text/plain is accepted, check if content is valid UTF-8.
# This handles files with misdetected MIME types (e.g., .ts detected as
# video/mp2t) that are actually text.
if 'text/plain' in mime_types:
try:
self.to_bytes().decode('utf-8')
return True
except Exception: # pylint: disable=broad-exception-caught
pass
return False

def make_compatible(
self,
Expand All @@ -175,6 +186,10 @@ def make_compatible(
f'MIME type {self.mime_type!r} cannot be converted to supported '
f'types: {mime_types!r}.'
)
# If compatibility was achieved via the UTF-8 text fallback (not exact MIME
# match), wrap content as text/plain so the LLM receives the correct type.
if self.mime_type not in mime_types and 'text/plain' in mime_types:
return Custom(mime='text/plain', content=self.to_bytes())
return self._make_compatible(mime_types)

def _make_compatible(
Expand Down
58 changes: 58 additions & 0 deletions langfun/core/modalities/mime_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,5 +211,63 @@ def test_binary_mime_type_raises_error(self):
content.to_text()


class TextCompatibilityTest(unittest.TestCase):

def test_exact_mime_match(self):
content = mime.Custom('text/plain', b'hello')
self.assertTrue(content.is_compatible('text/plain'))
result = content.make_compatible('text/plain')
self.assertIs(result, content)

def test_text_content_with_misdetected_mime_type(self):
ts_code = b'import type { Foo } from "bar";\nexport const x = 1;\n'
content = mime.Custom('video/mp2t', ts_code)
self.assertTrue(content.is_compatible(['text/plain', 'image/png']))
result = content.make_compatible(['text/plain', 'image/png'])
self.assertIsInstance(result, mime.Custom)
self.assertEqual(result.mime_type, 'text/plain')
self.assertEqual(result.to_bytes(), ts_code)

def test_text_mime_not_in_supported_list(self):
content = mime.Custom('text/x-typescript', b'const x: number = 1;\n')
self.assertTrue(content.is_compatible(['text/plain']))
result = content.make_compatible(['text/plain'])
self.assertEqual(result.mime_type, 'text/plain')

def test_application_text_compatible_with_text_plain(self):
content = mime.Custom('application/x-yaml', b'key: value\n')
self.assertTrue(content.is_compatible(['text/plain']))
result = content.make_compatible(['text/plain'])
self.assertEqual(result.mime_type, 'text/plain')

def test_binary_content_not_compatible_with_text_plain(self):
binary_data = bytes(range(256))
content = mime.Custom('application/octet-stream', binary_data)
self.assertFalse(content.is_compatible(['text/plain']))
with self.assertRaises(lf.ModalityError):
content.make_compatible(['text/plain'])

def test_no_fallback_without_text_plain_in_targets(self):
ts_code = b'const x = 1;\n'
content = mime.Custom('video/mp2t', ts_code)
self.assertFalse(content.is_compatible(['image/png', 'audio/wav']))
with self.assertRaises(lf.ModalityError):
content.make_compatible(['image/png', 'audio/wav'])

def test_make_compatible_preserves_content(self):
original = b'# Markdown\n\nHello world\n'
content = mime.Custom('text/markdown', original)
result = content.make_compatible(['text/plain'])
self.assertEqual(result.to_bytes(), original)
self.assertEqual(result.to_text(), '# Markdown\n\nHello world\n')

def test_unicode_content_compatible_with_text_plain(self):
unicode_bytes = 'こんにちは世界 🎉'.encode('utf-8')
content = mime.Custom('video/mp2t', unicode_bytes)
self.assertTrue(content.is_compatible(['text/plain']))
result = content.make_compatible(['text/plain'])
self.assertEqual(result.to_text(), 'こんにちは世界 🎉')


if __name__ == '__main__':
unittest.main()
Loading