From 1928401299efa5cbcc6eb13ea0de02b61552d016 Mon Sep 17 00:00:00 2001 From: Matthieu Meeus Date: Wed, 6 May 2026 06:41:53 -0700 Subject: [PATCH] Expanding tree edit similarity to different languages (#128) Summary: Extend tree-edit similarity (`PyTreeSitterAttack`) to support more coding languages, in particular the ones already covered by CodeBLEU (Python, C, C++, Java, Rust, JavaScript, Go, Ruby, PHP, C#). Previously only Python and C++ were supported. This diff introduces: - Unified and extended grammar backend in `py_tree_sitter_attack.py`: Replaces the standalone tree-sitter-python and tree-sitter-cpp packages with the codebleu package's bundled `my-languages.so`. This provides a single grammar library covering all 10 languages, is consistent with the `CodeBleuAttack` module and simplifies `_get_parser()` to a 3-line function. Verified that the codebleu-bundled Python grammar produces identical trees (zero edit distance) to the previous implementation. - No changes to the analysis layer: `TreeEditDistanceNode` is already language-agnostic, it operates purely on zss Node trees. Reviewed By: mgrange1998 Differential Revision: D102700637 --- .../tests/test_tree_edit_distance_node.py | 92 +++- .../code_similarity/py_tree_sitter_attack.py | 61 +-- .../tests/test_py_tree_sitter_attack.py | 399 ++++++++++++++++++ 3 files changed, 508 insertions(+), 44 deletions(-) diff --git a/privacy_guard/analysis/tests/test_tree_edit_distance_node.py b/privacy_guard/analysis/tests/test_tree_edit_distance_node.py index ae4dc1d..24bb010 100644 --- a/privacy_guard/analysis/tests/test_tree_edit_distance_node.py +++ b/privacy_guard/analysis/tests/test_tree_edit_distance_node.py @@ -115,28 +115,108 @@ def test_similarity_values(self) -> None: output = _run_e2e(df) self.assertAlmostEqual(output.avg_similarity, 1.0, places=5) + def test_java_similarity(self) -> None: + """Identical Java code yields ~1.0; structurally similar code is high.""" + with self.subTest("identical"): + code = "class Foo { int add(int a, int b) { return a + b; } }" + df = pd.DataFrame( + { + "target_code_string": [code], + "model_generated_code_string": [code], + } + ) + output = _run_e2e(df, default_language="java") + self.assertAlmostEqual(output.avg_similarity, 1.0, places=5) + + with self.subTest("similar"): + df = pd.DataFrame( + { + "target_code_string": [ + "class Foo { int add(int a, int b) { return a + b; } }" + ], + "model_generated_code_string": [ + "class Bar { int sum(int x, int y) { return x + y; } }" + ], + } + ) + output = _run_e2e(df, default_language="java") + self.assertGreater(output.avg_similarity, 0.7) + + def test_c_similarity(self) -> None: + """Identical C code yields ~1.0.""" + code = "int add(int a, int b) { return a + b; }" + df = pd.DataFrame( + { + "target_code_string": [code], + "model_generated_code_string": [code], + } + ) + output = _run_e2e(df, default_language="c") + self.assertAlmostEqual(output.avg_similarity, 1.0, places=5) + + def test_rust_similarity(self) -> None: + """Identical Rust code yields ~1.0.""" + code = "fn add(a: i32, b: i32) -> i32 { a + b }" + df = pd.DataFrame( + { + "target_code_string": [code], + "model_generated_code_string": [code], + } + ) + output = _run_e2e(df, default_language="rust") + self.assertAlmostEqual(output.avg_similarity, 1.0, places=5) + + def test_ruby_similarity(self) -> None: + """Identical Ruby code yields ~1.0.""" + code = "def add(a, b)\n a + b\nend" + df = pd.DataFrame( + { + "target_code_string": [code], + "model_generated_code_string": [code], + } + ) + output = _run_e2e(df, default_language="ruby") + self.assertAlmostEqual(output.avg_similarity, 1.0, places=5) + + def test_c_sharp_similarity(self) -> None: + """Identical C# code yields ~1.0.""" + code = "class Foo { int Add(int a, int b) { return a + b; } }" + df = pd.DataFrame( + { + "target_code_string": [code], + "model_generated_code_string": [code], + } + ) + output = _run_e2e(df, default_language="c_sharp") + self.assertAlmostEqual(output.avg_similarity, 1.0, places=5) + def test_avg_similarity_by_language(self) -> None: - """Mixed Python+C++ input produces per-language averages.""" + """Mixed multi-language input produces per-language averages.""" df = pd.DataFrame( { "target_code_string": [ "def foo():\n return 1\n", "int main() { return 0; }", + "class Foo { int add(int a, int b) { return a + b; } }", + "fn add(a: i32, b: i32) -> i32 { a + b }", + "int add(int a, int b) { return a + b; }", ], "model_generated_code_string": [ "def foo():\n return 1\n", "int main() { return 0; }", + "class Foo { int add(int a, int b) { return a + b; } }", + "fn add(a: i32, b: i32) -> i32 { a + b }", + "int add(int a, int b) { return a + b; }", ], - "language": ["python", "cpp"], + "language": ["python", "cpp", "java", "rust", "c"], } ) output = _run_e2e(df) assert output.avg_similarity_by_language is not None by_lang = output.avg_similarity_by_language - self.assertIn("python", by_lang) - self.assertIn("cpp", by_lang) - self.assertAlmostEqual(by_lang["python"], 1.0, places=5) - self.assertAlmostEqual(by_lang["cpp"], 1.0, places=5) + for lang in ["python", "cpp", "java", "rust", "c"]: + self.assertIn(lang, by_lang) + self.assertAlmostEqual(by_lang[lang], 1.0, places=5) def test_compute_similarity_static_method(self) -> None: """TreeEditDistanceNode.compute_similarity works standalone.""" diff --git a/privacy_guard/attacks/code_similarity/py_tree_sitter_attack.py b/privacy_guard/attacks/code_similarity/py_tree_sitter_attack.py index 2df8b28..91ca591 100644 --- a/privacy_guard/attacks/code_similarity/py_tree_sitter_attack.py +++ b/privacy_guard/attacks/code_similarity/py_tree_sitter_attack.py @@ -14,14 +14,14 @@ # pyre-strict -import ctypes +import importlib.resources import logging -from types import ModuleType from typing import Any import pandas as pd -import tree_sitter_cpp # @manual=fbsource//third-party/pypi/tree-sitter-cpp:tree-sitter-cpp -import tree_sitter_python # @manual=fbsource//third-party/pypi/tree-sitter-python:tree-sitter-python +from codebleu.codebleu import ( # @manual=fbsource//third-party/pypi/codebleu:codebleu + AVAILABLE_LANGS, +) from privacy_guard.analysis.code_similarity.code_similarity_analysis_input import ( CodeSimilarityAnalysisInput, ) @@ -38,39 +38,25 @@ logger: logging.Logger = logging.getLogger(__name__) -# Maps user-facing language strings to tree-sitter language modules. -_LANGUAGE_REGISTRY: dict[str, ModuleType] = { - "python": tree_sitter_python, - "py": tree_sitter_python, - "c++": tree_sitter_cpp, - "cpp": tree_sitter_cpp, +# Aliases that map to canonical codebleu language names. +_LANGUAGE_ALIASES: dict[str, str] = { + "py": "python", + "c++": "cpp", + "js": "javascript", } - -def _language_from_capsule(ts_module: ModuleType) -> Language: - """Create a tree-sitter Language from a language module's capsule. - - tree-sitter 0.20.4 expects ``Language(library_path, name)`` but the - modern language packages (tree-sitter-python, tree-sitter-cpp) expose - a ``language()`` function returning a PyCapsule. We extract the raw - pointer from the capsule and construct a Language-compatible object. - """ - capsule = ts_module.language() # type: ignore[attr-defined] - ctypes.pythonapi.PyCapsule_GetPointer.restype = ctypes.c_void_p - ctypes.pythonapi.PyCapsule_GetPointer.argtypes = [ctypes.py_object, ctypes.c_char_p] - language_id: ctypes.c_void_p = ctypes.pythonapi.PyCapsule_GetPointer( - capsule, b"tree_sitter.Language" - ) - lang = Language.__new__(Language) - lang.language_id = language_id # type: ignore[attr-defined] - return lang +SUPPORTED_LANGS: list[str] = AVAILABLE_LANGS def _get_parser(language: str) -> Parser: # pyre-ignore[11] """Create a tree-sitter Parser for the given language. + Uses the grammar bundled in the ``codebleu`` package (``my-languages.so``) + which supports: java, javascript, c_sharp, php, c, cpp, python, go, ruby, rust. + Args: - language: a key in _LANGUAGE_REGISTRY (e.g. "python", "cpp") + language: a language name from ``codebleu.AVAILABLE_LANGS``, + or an alias (e.g. "py", "c++", "js"). Returns: A configured tree-sitter Parser instance. @@ -78,17 +64,16 @@ def _get_parser(language: str) -> Parser: # pyre-ignore[11] Raises: ValueError: if the language is not supported. """ - lang_key = language.lower() - ts_module = _LANGUAGE_REGISTRY.get(lang_key) - if ts_module is None: + lang_key = _LANGUAGE_ALIASES.get(language.lower(), language.lower()) + if lang_key not in AVAILABLE_LANGS: raise ValueError( - f"Unsupported language '{language}'. " - f"Supported: {sorted(_LANGUAGE_REGISTRY.keys())}" + f"Unsupported language '{language}'. Supported: {sorted(AVAILABLE_LANGS)}" ) - - ts_language = _language_from_capsule(ts_module) - parser = Parser() # pyre-ignore[16] - # pyre-ignore[16]: Module `tree_sitter` has no attribute `Parser` + ts_language = Language( + importlib.resources.files("codebleu") / "my-languages.so", lang_key + ) + # pyre-ignore[16]: Module `tree_sitter` has no attribute `Parser`. + parser = Parser() parser.set_language(ts_language) return parser diff --git a/privacy_guard/attacks/tests/test_py_tree_sitter_attack.py b/privacy_guard/attacks/tests/test_py_tree_sitter_attack.py index de72fbe..eeee6c4 100644 --- a/privacy_guard/attacks/tests/test_py_tree_sitter_attack.py +++ b/privacy_guard/attacks/tests/test_py_tree_sitter_attack.py @@ -20,6 +20,9 @@ from privacy_guard.analysis.code_similarity.code_similarity_analysis_input import ( CodeSimilarityAnalysisInput, ) +from privacy_guard.analysis.code_similarity.tree_edit_distance_node import ( + TreeEditDistanceNode, +) from privacy_guard.attacks.code_similarity.py_tree_sitter_attack import ( PyTreeSitterAttack, ) @@ -67,6 +70,53 @@ def test_run_attack_and_languages(self) -> None: self.assertEqual(gen_df["target_parse_status"].iloc[0], "success") self.assertEqual(gen_df["generated_parse_status"].iloc[0], "success") + with self.subTest("c"): + df = pd.DataFrame( + { + "target_code_string": ["int add(int a, int b) { return a + b; }"], + "model_generated_code_string": [ + "int sub(int a, int b) { return a - b; }" + ], + } + ) + attack = PyTreeSitterAttack(data=df, default_language="c") + result = attack.run_attack() + gen_df = result.generation_df + self.assertEqual(gen_df["target_parse_status"].iloc[0], "success") + self.assertEqual(gen_df["generated_parse_status"].iloc[0], "success") + + with self.subTest("java"): + df = pd.DataFrame( + { + "target_code_string": [ + "class Foo { int add(int a, int b) { return a + b; } }" + ], + "model_generated_code_string": [ + "class Bar { int sub(int a, int b) { return a - b; } }" + ], + } + ) + attack = PyTreeSitterAttack(data=df, default_language="java") + result = attack.run_attack() + gen_df = result.generation_df + self.assertEqual(gen_df["target_parse_status"].iloc[0], "success") + self.assertEqual(gen_df["generated_parse_status"].iloc[0], "success") + + with self.subTest("rust"): + df = pd.DataFrame( + { + "target_code_string": ["fn add(a: i32, b: i32) -> i32 { a + b }"], + "model_generated_code_string": [ + "fn sub(a: i32, b: i32) -> i32 { a - b }" + ], + } + ) + attack = PyTreeSitterAttack(data=df, default_language="rust") + result = attack.run_attack() + gen_df = result.generation_df + self.assertEqual(gen_df["target_parse_status"].iloc[0], "success") + self.assertEqual(gen_df["generated_parse_status"].iloc[0], "success") + with self.subTest("malformed_code_partial_parse"): df = pd.DataFrame( { @@ -159,8 +209,357 @@ def test_parse_code_static_method(self) -> None: self.assertEqual(status, "success") self.assertEqual(node.label, "translation_unit") + with self.subTest("c_parse"): + node, status = PyTreeSitterAttack.parse_code( + "int add(int a, int b) { return a + b; }", language="c" + ) + self.assertEqual(status, "success") + self.assertEqual(node.label, "translation_unit") + + with self.subTest("java_parse"): + node, status = PyTreeSitterAttack.parse_code( + "class Foo { int add(int a, int b) { return a + b; } }", + language="java", + ) + self.assertEqual(status, "success") + self.assertEqual(node.label, "program") + + with self.subTest("rust_parse"): + node, status = PyTreeSitterAttack.parse_code( + "fn add(a: i32, b: i32) -> i32 { a + b }", language="rust" + ) + self.assertEqual(status, "success") + self.assertEqual(node.label, "source_file") + + with self.subTest("javascript_parse"): + node, status = PyTreeSitterAttack.parse_code( + "function add(a, b) { return a + b; }", language="javascript" + ) + self.assertEqual(status, "success") + self.assertEqual(node.label, "program") + + with self.subTest("go_parse"): + node, status = PyTreeSitterAttack.parse_code( + "package main\nfunc add(a int, b int) int { return a + b }", + language="go", + ) + self.assertEqual(status, "success") + self.assertEqual(node.label, "source_file") + + with self.subTest("ruby_parse"): + node, status = PyTreeSitterAttack.parse_code( + "def add(a, b)\n a + b\nend", language="ruby" + ) + self.assertEqual(status, "success") + self.assertEqual(node.label, "program") + + with self.subTest("php_parse"): + node, status = PyTreeSitterAttack.parse_code( + "", + language="php", + ) + self.assertEqual(status, "success") + self.assertEqual(node.label, "program") + + with self.subTest("c_sharp_parse"): + node, status = PyTreeSitterAttack.parse_code( + "class Foo { int Add(int a, int b) { return a + b; } }", + language="c_sharp", + ) + self.assertEqual(status, "success") + self.assertEqual(node.label, "compilation_unit") + with self.subTest("malformed_returns_partial"): node, status = PyTreeSitterAttack.parse_code("def foo(:", language="python") self.assertEqual(status, "partial") # A partial AST is still returned self.assertEqual(node.label, "module") + + +def _sim(code1: str, code2: str, lang: str) -> tuple[float, str, str]: + ast1, s1 = PyTreeSitterAttack.parse_code(code1, language=lang) + ast2, s2 = PyTreeSitterAttack.parse_code(code2, language=lang) + return TreeEditDistanceNode.compute_similarity(ast1, ast2), s1, s2 + + +class TestTreeEditSimilarityComprehensive(unittest.TestCase): + """End-to-end parse + similarity tests across languages and difficulty levels.""" + + def _check_lang( + self, lang: str, pairs: list[tuple[str, str, str, float, float]] + ) -> None: + for code1, code2, desc, lo, hi in pairs: + with self.subTest(f"{lang}_{desc}"): + sim, s1, s2 = _sim(code1, code2, lang) + self.assertEqual(s1, "success") + self.assertEqual(s2, "success") + self.assertGreaterEqual(sim, lo) + self.assertLessEqual(sim, hi) + + def test_python_similarity_levels(self) -> None: + self._check_lang( + "python", + [ + ("x = 1", "x = 1", "trivial_identical", 0.99, 1.01), + ( + "def add(a, b):\n return a + b\n", + "def add(a, b):\n return a + b\n", + "identical_func", + 0.99, + 1.01, + ), + ( + "def add(a, b):\n return a + b\n", + "def sum(x, y):\n return x + y\n", + "same_struct_diff_names", + 0.9, + 1.01, + ), + ( + "def foo(x):\n return x * 2\n", + "def foo(x):\n if x > 0:\n return x\n return -x\n", + "different_body", + 0.2, + 0.8, + ), + ( + "def f(lst):\n out = []\n for x in lst:\n out.append(x*2)\n return out\n", + "def f(lst):\n return [x*2 for x in lst]\n", + "loop_vs_comprehension", + 0.1, + 0.7, + ), + ( + "class Foo:\n def bar(self):\n return 1\n", + "def bar():\n return 1\n", + "class_vs_func", + 0.1, + 0.7, + ), + ( + "def f(x):\n if x > 0:\n for i in range(x):\n if i % 2 == 0:\n print(i)\n", + "def f(x):\n if x > 0:\n for i in range(x):\n if i % 2 == 0:\n print(i)\n", + "nested_identical", + 0.99, + 1.01, + ), + ( + "def f():\n try:\n x = 1/0\n except ZeroDivisionError:\n x = 0\n return x\n", + "def f():\n try:\n x = 1/0\n except ZeroDivisionError:\n x = 0\n return x\n", + "try_except_identical", + 0.99, + 1.01, + ), + ( + '@staticmethod\ndef f(x):\n """doc"""\n return x\n', + "def f(x):\n return x\n", + "decorator_docstring_vs_plain", + 0.3, + 0.9, + ), + ( + "import os\nfor f in os.listdir('.'):\n print(f)\n", + "class Config:\n DEBUG = True\n PORT = 8080\n", + "completely_different", + 0.0, + 0.4, + ), + ], + ) + + def test_c_similarity_levels(self) -> None: + self._check_lang( + "c", + [ + ("int x = 1;", "int x = 1;", "trivial_identical", 0.99, 1.01), + ( + "int add(int a, int b) { return a + b; }", + "int add(int a, int b) { return a + b; }", + "identical_func", + 0.99, + 1.01, + ), + ( + "int add(int a, int b) { return a + b; }", + "int sum(int x, int y) { return x + y; }", + "same_struct_diff_names", + 0.9, + 1.01, + ), + ( + "int f(int x) { return x * 2; }", + "int f(int x) { if (x > 0) return x; return -x; }", + "different_body", + 0.2, + 0.8, + ), + ( + "int fib(int n) { if(n<=1) return n; return fib(n-1)+fib(n-2); }", + "int fib(int n) { if(n<=1) return n; return fib(n-1)+fib(n-2); }", + "recursive_identical", + 0.99, + 1.01, + ), + ( + "void swap(int* a, int* b) { int t=*a; *a=*b; *b=t; }", + "void swap(int* a, int* b) { int t=*a; *a=*b; *b=t; }", + "swap_identical", + 0.99, + 1.01, + ), + ( + "int max(int a, int b) { return a > b ? a : b; }", + "int min(int a, int b) { return a < b ? a : b; }", + "max_vs_min", + 0.8, + 1.01, + ), + ( + '#include \nint main() { printf("hello"); return 0; }', + "typedef struct { int x; int y; } Vec2; float dot(Vec2 a, Vec2 b) { return a.x*b.x+a.y*b.y; }", + "completely_different", + 0.0, + 0.5, + ), + ], + ) + + def test_java_similarity_levels(self) -> None: + self._check_lang( + "java", + [ + ( + "class A { int add(int a, int b) { return a + b; } }", + "class A { int add(int a, int b) { return a + b; } }", + "identical", + 0.99, + 1.01, + ), + ( + "class A { int add(int a, int b) { return a + b; } }", + "class B { int sum(int x, int y) { return x + y; } }", + "same_struct_diff_names", + 0.9, + 1.01, + ), + ( + "class A { int f(int x) { return x * 2; } }", + "class A { int f(int x) { if (x > 0) return x; return -x; } }", + "different_body", + 0.3, + 0.85, + ), + ( + "class A { int fib(int n) { if(n<=1) return n; return fib(n-1)+fib(n-2); } }", + "class A { int fib(int n) { if(n<=1) return n; return fib(n-1)+fib(n-2); } }", + "recursive_identical", + 0.99, + 1.01, + ), + ( + "class A { int max(int a, int b) { return a > b ? a : b; } }", + "class A { int min(int a, int b) { return a < b ? a : b; } }", + "max_vs_min", + 0.8, + 1.01, + ), + ( + 'import java.util.*;\nclass Main { public static void main(String[] args) { System.out.println("hello"); } }', + "class Config { static final boolean DEBUG = true; static final int PORT = 8080; }", + "completely_different", + 0.0, + 0.5, + ), + ], + ) + + def test_rust_similarity_levels(self) -> None: + self._check_lang( + "rust", + [ + ( + "fn add(a: i32, b: i32) -> i32 { a + b }", + "fn add(a: i32, b: i32) -> i32 { a + b }", + "identical", + 0.99, + 1.01, + ), + ( + "fn add(a: i32, b: i32) -> i32 { a + b }", + "fn sum(x: i32, y: i32) -> i32 { x + y }", + "same_struct_diff_names", + 0.9, + 1.01, + ), + ( + "fn f(x: i32) -> i32 { x * 2 }", + "fn f(x: i32) -> i32 { if x > 0 { x } else { -x } }", + "different_body", + 0.2, + 0.8, + ), + ( + "fn fib(n: u32) -> u32 { if n <= 1 { n } else { fib(n-1) + fib(n-2) } }", + "fn fib(n: u32) -> u32 { if n <= 1 { n } else { fib(n-1) + fib(n-2) } }", + "recursive_identical", + 0.99, + 1.01, + ), + ( + "fn max(a: i32, b: i32) -> i32 { if a > b { a } else { b } }", + "fn min(a: i32, b: i32) -> i32 { if a < b { a } else { b } }", + "max_vs_min", + 0.8, + 1.01, + ), + ( + 'use std::io;\nfn main() { println!("hello"); }', + "struct Config { debug: bool, port: u16 }", + "completely_different", + 0.0, + 0.5, + ), + ], + ) + + def test_cpp_similarity_levels(self) -> None: + self._check_lang( + "cpp", + [ + ( + "int add(int a, int b) { return a + b; }", + "int add(int a, int b) { return a + b; }", + "identical", + 0.99, + 1.01, + ), + ( + "int add(int a, int b) { return a + b; }", + "int sum(int x, int y) { return x + y; }", + "same_struct_diff_names", + 0.9, + 1.01, + ), + ( + "int fib(int n) { if(n<=1) return n; return fib(n-1)+fib(n-2); }", + "int fib(int n) { if(n<=1) return n; return fib(n-1)+fib(n-2); }", + "recursive_identical", + 0.99, + 1.01, + ), + ( + "int max(int a, int b) { return a > b ? a : b; }", + "int min(int a, int b) { return a < b ? a : b; }", + "max_vs_min", + 0.8, + 1.01, + ), + ( + '#include \nint main() { std::cout << "hello"; return 0; }', + "struct Config { bool debug = true; int port = 8080; };", + "completely_different", + 0.0, + 0.5, + ), + ], + )