diff --git a/llsd/base.py b/llsd/base.py index e7204ca..2e8c2ca 100644 --- a/llsd/base.py +++ b/llsd/base.py @@ -410,14 +410,24 @@ def _reset(self, something): # string is so large that the overhead of copying it into a # BytesIO is significant, advise caller to pass a stream instead. self._stream = io.BytesIO(something) - elif something.seekable(): - # 'something' is already a seekable stream, use directly - self._stream = something + elif isinstance(something, io.IOBase): + # 'something' is a proper IO stream - must be seekable for parsing + if something.seekable(): + self._stream = something + else: + raise LLSDParseError( + "Cannot parse LLSD from non-seekable stream." + ) else: - # 'something' isn't seekable, wrap in BufferedReader - # (let BufferedReader handle the problem of passing an - # inappropriate object) - self._stream = io.BufferedReader(something) + # Invalid input type - raise a clear error + # This catches MagicMock and other non-stream objects that might + # have read/seek attributes but aren't actual IO streams + raise LLSDParseError( + "Cannot parse LLSD from {0}. " + "Expected bytes or a seekable io.IOBase object.".format( + type(something).__name__ + ) + ) def starts_with(self, pattern): """ diff --git a/tests/llsd_test.py b/tests/llsd_test.py index 073a974..f78577f 100644 --- a/tests/llsd_test.py +++ b/tests/llsd_test.py @@ -1977,3 +1977,59 @@ def test_uuid_map_key(self): self.assertEqual(llsd.format_notation(llsdmap), b"{'00000000-0000-0000-0000-000000000000':'uuid'}") +class InvalidInputTypes(unittest.TestCase): + ''' + Tests for handling invalid input types that should raise LLSDParseError + instead of hanging or consuming infinite memory. + ''' + + @unittest.skipIf(PY2, "MagicMock requires Python 3") + def test_parse_magicmock_raises_error(self): + ''' + Parsing a MagicMock object should raise LLSDParseError, not hang. + This is a regression test for a bug where llsd.parse() would go into + an infinite loop when passed a MagicMock (e.g., from an improperly + mocked requests.Response.content). + ''' + from unittest.mock import MagicMock + mock = MagicMock() + with self.assertRaises(llsd.LLSDParseError) as context: + llsd.parse(mock) + self.assertIn('MagicMock', str(context.exception)) + + def test_parse_string_raises_error(self): + ''' + Parsing a string (not bytes) should raise LLSDParseError. + Only applies to Python 3 where str and bytes are distinct. + ''' + with self.assertRaises(llsd.LLSDParseError) as context: + llsd.parse(b'not bytes'.decode('ascii')) + self.assertIn('unicode' if PY2 else 'str', str(context.exception)) + + def test_parse_none_raises_error(self): + ''' + Parsing None should raise LLSDParseError. + ''' + with self.assertRaises(llsd.LLSDParseError) as context: + llsd.parse(None) + self.assertIn('NoneType', str(context.exception)) + + def test_parse_int_raises_error(self): + ''' + Parsing an integer should raise LLSDParseError. + ''' + with self.assertRaises(llsd.LLSDParseError) as context: + llsd.parse(42) + self.assertIn('int', str(context.exception)) + + def test_parse_non_seekable_stream_raises_error(self): + ''' + Parsing a non-seekable stream should raise LLSDParseError. + ''' + stream = io.BytesIO() + stream.seekable = lambda: False + with self.assertRaises(llsd.LLSDParseError) as context: + llsd.parse(stream) + self.assertIn('non-seekable', str(context.exception)) + +