diff --git a/openhands-sdk/openhands/sdk/security/_shell_ast.py b/openhands-sdk/openhands/sdk/security/_shell_ast.py
new file mode 100644
index 0000000000..35a59f812e
--- /dev/null
+++ b/openhands-sdk/openhands/sdk/security/_shell_ast.py
@@ -0,0 +1,327 @@
+"""Private tree-sitter-bash command views for security analyzers."""
+
+from __future__ import annotations
+
+import posixpath
+from collections.abc import Iterator
+from dataclasses import dataclass, field
+
+from tree_sitter import Node
+
+from openhands.sdk.security.shell_parser import ParseResult, parse
+
+
+# Shell metacharacters (quoting, escaping, expansion, globbing, grouping,
+# redirection, control operators); a word containing any is treated as opaque.
+_OPAQUE_WORD_CHARS = frozenset("\"'`\\$*?[]{}()<>|&;!~")
+# Named children of a ``command`` node that are never reported as argument words.
+_COMMAND_CHILD_SKIP_TYPES = frozenset(
+    {
+        "command_name",
+        "comment",
+        "file_redirect",
+        "heredoc_redirect",
+        "herestring_redirect",
+        "redirected_statement",
+        "variable_assignment",
+    }
+)
+
+
+@dataclass(frozen=True, slots=True)
+class ShellProgram:
+    """Source text paired with its UTF-8 bytes and its parse result.
+
+    ``parse_result`` is excluded from ``repr`` and from equality comparison.
+    """
+
+    source: str
+    source_bytes: bytes
+    parse_result: ParseResult = field(repr=False, compare=False)
+
+
+@dataclass(frozen=True, slots=True)
+class ShellWord:
+    """Text and node type of a command name, argument, or assignment node.
+
+    ``opaque`` is true when ``text`` must not be read as a literal value, for
+    example because it is empty or contains quoting or expansion syntax.
+    """
+
+    text: str
+    node_type: str
+    opaque: bool
+    node: Node = field(repr=False, compare=False)
+
+
+@dataclass(frozen=True, slots=True)
+class ShellCommand:
+    """View of one tree-sitter ``command`` node.
+
+    ``name`` is ``None`` when the node has no ``command_name`` child.
+    ``assignments`` holds the variable assignments found before the name;
+    ``words`` holds the named children after it, except comments, redirects,
+    and later assignments. ``has_error`` is set when the subtree contains a
+    parse error or a missing node.
+    """
+
+    name: ShellWord | None
+    words: tuple[ShellWord, ...]
+    assignments: tuple[ShellWord, ...]
+    node: Node = field(repr=False, compare=False)
+    has_error: bool = False
+
+
+@dataclass(frozen=True, slots=True)
+class ShellPipeline:
+    """View of one tree-sitter ``pipeline`` node.
+
+    ``commands`` holds the stages that resolved to a plain command; ``complete``
+    is true only when every stage did and at least one command was found.
+    ``has_error`` is set when the subtree has a parse error or a missing node.
+    """
+
+    commands: tuple[ShellCommand, ...]
+    complete: bool
+    node: Node = field(repr=False, compare=False)
+    has_error: bool = False
+
+
+def parse_shell_program(source: str) -> ShellProgram:
+    """Parse ``source`` and return its private shell syntax view."""
+    return view_shell_program(source, parse(source))
+
+
+def view_shell_program(source: str, parse_result: ParseResult) -> ShellProgram:
+    """Create a shell syntax view from a matching parse result.
+
+    The parse tree must span the UTF-8 byte length of ``source``. If a caller
+    passes a same-length source different from the parsed text, byte slicing will
+    reflect that caller error; this helper intentionally does not keep a second
+    copy of the parsed bytes to detect it.
+
+    Raises:
+        ValueError: If the root node does not end at the encoded source length.
+    """
+    source_bytes = source.encode()
+    if parse_result.tree.root_node.end_byte != len(source_bytes):
+        raise ValueError("parse result does not match source byte length")
+    return ShellProgram(
+        source=source,
+        source_bytes=source_bytes,
+        parse_result=parse_result,
+    )
+
+
+def node_text(program: ShellProgram, node: Node) -> str:
+    """Return ``node`` text using tree-sitter byte offsets."""
+    return program.source_bytes[node.start_byte : node.end_byte].decode()
+
+
+def iter_commands(program: ShellProgram) -> Iterator[ShellCommand]:
+    """Yield real ``command`` nodes from the parsed shell syntax tree."""
+    for node in _iter_nodes(program.parse_result.tree.root_node):
+        if node.type == "command":
+            yield _view_command(program, node)
+
+
+def iter_pipelines(program: ShellProgram) -> Iterator[ShellPipeline]:
+    """Yield pipeline views for tree-sitter ``pipeline`` nodes."""
+    for node in _iter_nodes(program.parse_result.tree.root_node):
+        if node.type == "pipeline":
+            yield _view_pipeline(program, node)
+
+
+def command_basename(command: ShellCommand) -> str | None:
+    """Return the POSIX basename for a non-opaque command name."""
+    if command.name is None or command.name.opaque:
+        return None
+    return posixpath.basename(command.name.text)
+
+
+def split_short_flags(word: ShellWord) -> frozenset[str]:
+    """Split a non-opaque short flag word into individual flag characters."""
+    if word.opaque:
+        return frozenset()
+
+    text = word.text
+    if len(text) <= 1 or not text.startswith("-") or text.startswith("--"):
+        return frozenset()
+    return frozenset(text[1:])
+
+
+def is_long_flag(word: ShellWord, name: str) -> bool:
+    """Return whether ``word`` is exactly ``--<name>``."""
+    return not word.opaque and word.text == f"--{name}"
+
+
+def split_key_value_word(word: ShellWord) -> tuple[str, str] | None:
+    """Split a non-opaque ``KEY=VALUE`` word."""
+    if word.opaque:
+        return None
+
+    key, separator, value = word.text.partition("=")
+    if not separator or not key:
+        return None
+    return key, value
+
+
+def _view_pipeline(program: ShellProgram, node: Node) -> ShellPipeline:
+    """Build a ``ShellPipeline`` from a ``pipeline`` node.
+
+    Stages that are not plain (optionally redirected) commands are left out of
+    ``commands`` and mark the pipeline as incomplete.
+    """
+    commands: list[ShellCommand] = []
+    complete = True
+    for child in node.named_children:
+        command_node = _unwrap_redirected_command(child)
+        if command_node is None:
+            complete = False
+            continue
+        commands.append(_view_command(program, command_node))
+
+    return ShellPipeline(
+        commands=tuple(commands),
+        complete=complete and bool(commands),
+        node=node,
+        has_error=_has_parse_uncertainty(node),
+    )
+
+
+def _view_command(program: ShellProgram, node: Node) -> ShellCommand:
+    """Build a ``ShellCommand`` from a ``command`` node."""
+    name: ShellWord | None = None
+    words: list[ShellWord] = []
+    assignments: list[ShellWord] = []
+
+    for child in node.named_children:
+        if child.type == "command_name":
+            name = _command_name_word(program, child)
+            continue
+
+        if name is None:
+            # Before the name, only variable assignments are recorded.
+            if child.type == "variable_assignment":
+                assignments.append(_shell_word(program, child))
+            continue
+
+        # The substring test also skips redirect types missing from the skip set.
+        if child.type in _COMMAND_CHILD_SKIP_TYPES or "redirect" in child.type:
+            continue
+
+        words.append(_shell_word(program, child))
+
+    return ShellCommand(
+        name=name,
+        words=tuple(words),
+        assignments=tuple(assignments),
+        node=node,
+        has_error=_has_parse_uncertainty(node),
+    )
+
+
+def _command_name_word(program: ShellProgram, node: Node) -> ShellWord:
+    """Wrap a ``command_name`` node; opaque unless it is one plain ``word``."""
+    text = node_text(program, node)
+    named_children = node.named_children
+    opaque = (
+        len(named_children) != 1
+        or named_children[0].type != "word"
+        or _text_has_opaque_syntax(text)
+    )
+    return ShellWord(
+        text=text,
+        node_type=node.type,
+        opaque=opaque,
+        node=node,
+    )
+
+
+def _shell_word(program: ShellProgram, node: Node) -> ShellWord:
+    """Wrap an argument or assignment node as a ``ShellWord``."""
+    text = node_text(program, node)
+    return ShellWord(
+        text=text,
+        node_type=node.type,
+        opaque=_is_opaque_word_node(node, text),
+        node=node,
+    )
+
+
+def _is_opaque_word_node(node: Node, text: str) -> bool:
+    """Return whether ``text`` cannot be read as the node's literal value."""
+    if _text_has_opaque_syntax(text):
+        return True
+
+    if node.type == "word":
+        return False
+
+    if node.type == "variable_assignment":
+        # Transparent only when built purely from a name and plain words.
+        return any(
+            child.type not in {"variable_name", "word"} for child in node.named_children
+        )
+
+    return True
+
+
+def _text_has_opaque_syntax(text: str) -> bool:
+    """Return whether ``text`` is empty or has whitespace or metacharacters."""
+    return not text or any(
+        character.isspace() or character in _OPAQUE_WORD_CHARS for character in text
+    )
+
+
+def _has_parse_uncertainty(node: Node) -> bool:
+    """Return whether ``node`` has a parse error or a missing descendant."""
+    return node.has_error or _has_missing_descendant(node)
+
+
+def _has_missing_descendant(node: Node) -> bool:
+    """Return whether ``node`` or any of its descendants is a missing node.
+
+    Uses an explicit stack so deeply nested input cannot raise RecursionError.
+    """
+    stack = [node]
+    while stack:
+        current = stack.pop()
+        if current.is_missing:
+            return True
+        stack.extend(current.children)
+    return False
+
+
+def _unwrap_redirected_command(node: Node) -> Node | None:
+    """Return the ``command`` inside ``redirected_statement`` wrappers, if any.
+
+    Returns ``None`` when a wrapper does not hold exactly one ``command`` child
+    or when the innermost node is not a ``command``.
+    """
+    current = node
+    while current.type == "redirected_statement":
+        command_children = [
+            child for child in current.named_children if child.type == "command"
+        ]
+        if len(command_children) != 1:
+            return None
+        current = command_children[0]
+
+    if current.type == "command":
+        return current
+    return None
+
+
+def _iter_nodes(node: Node) -> Iterator[Node]:
+    """Yield ``node`` and its named descendants in pre-order.
+
+    Unnamed nodes are not yielded but are still descended into. Traversal uses
+    an explicit stack so deeply nested input cannot raise RecursionError.
+    """
+    stack = [node]
+    while stack:
+        current = stack.pop()
+        if current.is_named:
+            yield current
+        # Reversed so the leftmost child is popped (and yielded) first.
+        stack.extend(reversed(current.children))
diff --git a/tests/sdk/security/test_shell_ast.py b/tests/sdk/security/test_shell_ast.py
new file mode 100644
index 0000000000..a3998fa6df
--- /dev/null
+++ b/tests/sdk/security/test_shell_ast.py
@@ -0,0 +1,282 @@
+"""Tests for private shell AST command-view helpers."""
+
+from __future__ import annotations
+
+from collections.abc import Iterator
+
+import pytest
+
+from openhands.sdk.security._shell_ast import (
+    ShellCommand,
+    ShellPipeline,
+    ShellProgram,
+    ShellWord,
+    command_basename,
+    is_long_flag,
+    iter_commands,
+    iter_pipelines,
+    node_text,
+    parse_shell_program,
+    split_key_value_word,
+    split_short_flags,
+    view_shell_program,
+)
+from openhands.sdk.security.shell_parser import parse
+
+
+def _commands(source: str) -> tuple[ShellCommand, ...]:
+    return tuple(iter_commands(parse_shell_program(source)))
+
+
+def _pipelines(source: str) -> tuple[ShellPipeline, ...]:
+    return tuple(iter_pipelines(parse_shell_program(source)))
+
+
+def _basenames(source: str) -> tuple[str | None, ...]:
+    return tuple(command_basename(command) for command in _commands(source))
+
+
+def _first_word(source: str) -> ShellWord:
+    command = _commands(source)[0]
+    assert command.words
+    return command.words[0]
+
+
+def _has_missing_node(program: ShellProgram) -> bool:
+    def visit() -> Iterator[bool]:
+        stack = [program.parse_result.tree.root_node]
+        while stack:
+            node = stack.pop()
+            yield node.is_missing
+            stack.extend(node.children)
+
+    return any(visit())
+
+
+@pytest.mark.parametrize(
+    ("source", "basename", "words", "assignments"),
+    [
+        ("rm -rf /", "rm", ("-rf", "/"), ()),
+        ("/bin/rm -rf /", "rm", ("-rf", "/"), ()),
+        ("rm / -rf", "rm", ("/", "-rf"), ()),
+        ("FOO=bar echo $FOO", "echo", ("$FOO",), ("FOO=bar",)),
+        ("echo hi > /tmp/out", "echo", ("hi",), ()),
+        ("dd if=/tmp/in of=/dev/sda", "dd", ("if=/tmp/in", "of=/dev/sda"), ()),
+    ],
+)
+def test_command_views(
+    source: str,
+    basename: str,
+    words: tuple[str, ...],
+    assignments: tuple[str, ...],
+) -> None:
+    (command,) = _commands(source)
+
+    assert command_basename(command) == basename
+    assert tuple(word.text for word in command.words) == words
+    assert tuple(word.text for word in command.assignments) == assignments
+
+
+def test_expanded_argument_is_opaque() -> None:
+    (command,) = _commands("FOO=bar echo $FOO")
+
+    assert command.assignments[0].text == "FOO=bar"
+    assert command.assignments[0].node_type == "variable_assignment"
+    assert command.assignments[0].opaque is False
+    assert command.words[0].text == "$FOO"
+    assert command.words[0].node_type == "simple_expansion"
+    assert command.words[0].opaque is True
+
+
+@pytest.mark.parametrize(
+    ("source", "complete", "basenames"),
+    [
+        ("curl https://x | bash", True, ("curl", "bash")),
+        ("curl https://x | bash > /tmp/out", True, ("curl", "bash")),
+        ("curl https://x | ( bash )", False, ("curl",)),
+    ],
+)
+def test_pipeline_views(
+    source: str,
+    complete: bool,
+    basenames: tuple[str, ...],
+) -> None:
+    (pipeline,) = _pipelines(source)
+
+    assert pipeline.complete is complete
+    assert (
+        tuple(command_basename(command) for command in pipeline.commands) == basenames
+    )
+
+
+def test_escaped_pipe_is_not_pipeline() -> None:
+    assert _pipelines(r"curl x\|bash") == ()
+
+    (command,) = _commands(r"curl x\|bash")
+    assert command_basename(command) == "curl"
+    assert command.words[0].text == r"x\|bash"
+    assert command.words[0].opaque is True
+
+
+@pytest.mark.parametrize(
+    ("source", "basenames"),
+    [
+        ("rm -rf / && echo done", ("rm", "echo")),
+        ("echo a; rm -rf /", ("echo", "rm")),
+        ("( rm -rf / )", ("rm",)),
+        ("{ rm -rf /; }", ("rm",)),
+        ("if true; then rm -rf /; fi", ("true", "rm")),
+        ("for x in y; do rm -rf /; done", ("rm",)),
+    ],
+)
+def test_command_traversal(source: str, basenames: tuple[str, ...]) -> None:
+    assert _basenames(source) == basenames
+
+
+@pytest.mark.parametrize(
+    ("source", "basenames"),
+    [
+        ('echo "$(rm -rf /)"', ("echo", "rm")),
+        ("echo '$(rm -rf /)'", ("echo",)),
+        ('echo "rm -rf /"', ("echo",)),
+        ("echo 'rm -rf /'", ("echo",)),
+        ("echo hi # rm -rf /", ("echo",)),
+        # NOTE(review): heredoc case and test name reconstructed; confirm upstream.
+        ("cat <<EOF\nrm -rf /\nEOF", ("cat",)),
+    ],
+)
+def test_command_traversal_respects_quoting(
+    source: str,
+    basenames: tuple[str, ...],
+) -> None:
+    assert _basenames(source) == basenames
+
+
+@pytest.mark.parametrize(
+    ("source", "basename"),
+    [
+        ("rm -rf /", "rm"),
+        ("/bin/rm -rf /", "rm"),
+        ("./script arg", "script"),
+        ("python3.12 -V", "python3.12"),
+    ],
+)
+def test_plain_command_names_are_not_opaque(source: str, basename: str) -> None:
+    (command,) = _commands(source)
+    assert command.name is not None
+    assert command.name.opaque is False
+    assert command_basename(command) == basename
+
+
+@pytest.mark.parametrize(
+    "source",
+    [
+        'r"m" -rf /',
+        "r''m -rf /",
+        "'rm' -rf /",
+        "$(echo rm) -rf /",
+        "`echo rm` -rf /",
+        r"$'\x72m' -rf /",
+        r"$'\162\155' -rf /",
+        "$CMD -rf /",
+        "${CMD} -rf /",
+        "rm${IFS}-rf${IFS}/",
+        r"r\m -rf /",
+        "r* -rf /",
+    ],
+)
+def test_opaque_command_names(source: str) -> None:
+    command = _commands(source)[0]
+
+    assert command.name is not None
+    assert command.name.opaque is True
+    assert command_basename(command) is None
+
+
+def test_command_substitution_name_still_exposes_nested_command() -> None:
+    assert _basenames("$(echo rm) -rf /") == (None, "echo")
+
+
+def test_parse_error_is_preserved_on_program() -> None:
+    program = parse_shell_program('echo "unterminated')
+
+    assert program.parse_result.has_error is True
+    assert _basenames(program.source) == ("echo",)
+
+
+def test_missing_node_is_preserved_on_program() -> None:
+    program = parse_shell_program("[[ ]]")
+
+    assert program.parse_result.has_error is True
+    assert _has_missing_node(program) is True
+
+
+def test_missing_descendant_marks_command_has_error() -> None:
+    commands = _commands("echo $( )")
+
+    assert tuple(command.has_error for command in commands) == (True, True)
+
+
+@pytest.mark.parametrize(
+    ("source", "flags"),
+    [
+        ("rm -rf /", frozenset({"r", "f"})),
+        ("rm -r /", frozenset({"r"})),
+        ("rm -f /", frozenset({"f"})),
+        ("rm --force /", frozenset()),
+        ('rm "-rf" /', frozenset()),
+    ],
+)
+def test_split_short_flags(source: str, flags: frozenset[str]) -> None:
+    assert split_short_flags(_first_word(source)) == flags
+
+
+@pytest.mark.parametrize(
+    ("source", "name", "matches"),
+    [
+        ("rm --force /", "force", True),
+        ("rm --recursive /", "recursive", True),
+        ("rm --force /", "recursive", False),
+        ('rm "$FLAG" /', "force", False),
+    ],
+)
+def test_is_long_flag(source: str, name: str, matches: bool) -> None:
+    assert is_long_flag(_first_word(source), name) is matches
+
+
+@pytest.mark.parametrize(
+    ("source", "key_value"),
+    [
+        ("dd of=/dev/sda", ("of", "/dev/sda")),
+        ("dd if=/tmp/in", ("if", "/tmp/in")),
+        ("dd of=$TARGET", None),
+    ],
+)
+def test_split_key_value_word(
+    source: str,
+    key_value: tuple[str, str] | None,
+) -> None:
+    assert split_key_value_word(_first_word(source)) == key_value
+
+
+def test_view_shell_program_rejects_byte_length_mismatch() -> None:
+    parse_result = parse("echo hello")
+
+    with pytest.raises(ValueError):
+        view_shell_program("echo", parse_result)
+
+
+def test_node_text_uses_byte_offsets() -> None:
+    program = parse_shell_program("echo héllo")
+    command = _commands(program.source)[0]
+
+    assert node_text(program, command.words[0].node) == "héllo"
+
+
+def test_tree_sitter_objects_are_excluded_from_repr_and_equality() -> None:
+    first = parse_shell_program("echo hi")
+    second = parse_shell_program("echo hi")
+
+    assert first == second
+    assert "parse_result" not in repr(first)
+    assert _commands(first.source) == _commands(second.source)