Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 86 additions & 6 deletions privacy_guard/analysis/tests/test_tree_edit_distance_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,28 +115,108 @@ def test_similarity_values(self) -> None:
output = _run_e2e(df)
self.assertAlmostEqual(output.avg_similarity, 1.0, places=5)

def test_java_similarity(self) -> None:
    """Java: an identical pair scores ~1.0; a renamed-identifier pair stays > 0.7."""

    def java_similarity(target, generated):
        # Single-row frame: one target snippet vs. one generated snippet.
        frame = pd.DataFrame(
            {
                "target_code_string": [target],
                "model_generated_code_string": [generated],
            }
        )
        return _run_e2e(frame, default_language="java").avg_similarity

    with self.subTest("identical"):
        snippet = "class Foo { int add(int a, int b) { return a + b; } }"
        self.assertAlmostEqual(java_similarity(snippet, snippet), 1.0, places=5)

    with self.subTest("similar"):
        # Same structure, different class/method/parameter names.
        score = java_similarity(
            "class Foo { int add(int a, int b) { return a + b; } }",
            "class Bar { int sum(int x, int y) { return x + y; } }",
        )
        self.assertGreater(score, 0.7)

def test_c_similarity(self) -> None:
    """A C snippet compared against itself scores ~1.0."""
    snippet = "int add(int a, int b) { return a + b; }"
    columns = {
        "target_code_string": [snippet],
        "model_generated_code_string": [snippet],
    }
    result = _run_e2e(pd.DataFrame(columns), default_language="c")
    self.assertAlmostEqual(result.avg_similarity, 1.0, places=5)

def test_rust_similarity(self) -> None:
    """A Rust snippet compared against itself scores ~1.0."""
    snippet = "fn add(a: i32, b: i32) -> i32 { a + b }"
    columns = {
        "target_code_string": [snippet],
        "model_generated_code_string": [snippet],
    }
    result = _run_e2e(pd.DataFrame(columns), default_language="rust")
    self.assertAlmostEqual(result.avg_similarity, 1.0, places=5)

def test_ruby_similarity(self) -> None:
    """A Ruby snippet compared against itself scores ~1.0."""
    snippet = "def add(a, b)\n a + b\nend"
    columns = {
        "target_code_string": [snippet],
        "model_generated_code_string": [snippet],
    }
    result = _run_e2e(pd.DataFrame(columns), default_language="ruby")
    self.assertAlmostEqual(result.avg_similarity, 1.0, places=5)

def test_c_sharp_similarity(self) -> None:
    """A C# snippet compared against itself scores ~1.0."""
    snippet = "class Foo { int Add(int a, int b) { return a + b; } }"
    columns = {
        "target_code_string": [snippet],
        "model_generated_code_string": [snippet],
    }
    result = _run_e2e(pd.DataFrame(columns), default_language="c_sharp")
    self.assertAlmostEqual(result.avg_similarity, 1.0, places=5)

def test_avg_similarity_by_language(self) -> None:
    """Mixed multi-language input produces per-language averages."""
    # One snippet per language; each row compares the snippet with itself,
    # so every per-language average is expected to be ~1.0.
    snippet_by_language = {
        "python": "def foo():\n return 1\n",
        "cpp": "int main() { return 0; }",
        "java": "class Foo { int add(int a, int b) { return a + b; } }",
        "rust": "fn add(a: i32, b: i32) -> i32 { a + b }",
        "c": "int add(int a, int b) { return a + b; }",
    }
    languages = list(snippet_by_language)
    snippets = list(snippet_by_language.values())
    df = pd.DataFrame(
        {
            "target_code_string": snippets,
            "model_generated_code_string": list(snippets),
            # Exactly one "language" column, row-aligned with the snippets.
            "language": languages,
        }
    )
    output = _run_e2e(df)
    # Guard kept from the original test: the per-language mapping must exist
    # before it is indexed below.
    assert output.avg_similarity_by_language is not None
    by_lang = output.avg_similarity_by_language
    for lang in languages:
        # subTest reports each failing language separately instead of
        # stopping at the first one.
        with self.subTest(lang=lang):
            self.assertIn(lang, by_lang)
            self.assertAlmostEqual(by_lang[lang], 1.0, places=5)

def test_compute_similarity_static_method(self) -> None:
"""TreeEditDistanceNode.compute_similarity works standalone."""
Expand Down
61 changes: 23 additions & 38 deletions privacy_guard/attacks/code_similarity/py_tree_sitter_attack.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,14 @@

# pyre-strict

import ctypes
import importlib.resources
import logging
from types import ModuleType
from typing import Any

import pandas as pd
import tree_sitter_cpp # @manual=fbsource//third-party/pypi/tree-sitter-cpp:tree-sitter-cpp
import tree_sitter_python # @manual=fbsource//third-party/pypi/tree-sitter-python:tree-sitter-python
from codebleu.codebleu import ( # @manual=fbsource//third-party/pypi/codebleu:codebleu
AVAILABLE_LANGS,
)
from privacy_guard.analysis.code_similarity.code_similarity_analysis_input import (
CodeSimilarityAnalysisInput,
)
Expand All @@ -38,57 +38,42 @@

logger: logging.Logger = logging.getLogger(__name__)

# Maps user-facing language strings to tree-sitter language modules.
# Keys are lowercase because lookups lowercase the requested language first.
_LANGUAGE_REGISTRY: dict[str, ModuleType] = {
    "python": tree_sitter_python,
    "py": tree_sitter_python,
    "c++": tree_sitter_cpp,
    "cpp": tree_sitter_cpp,
}

# Aliases that map to canonical codebleu language names.
# Names that are already canonical need no entry here.
_LANGUAGE_ALIASES: dict[str, str] = {
    "py": "python",
    "c++": "cpp",
    "js": "javascript",
}


def _language_from_capsule(ts_module: ModuleType) -> Language:
    """Build a tree-sitter ``Language`` from a grammar module's PyCapsule.

    tree-sitter 0.20.4 expects ``Language(library_path, name)``, whereas the
    modern grammar packages (tree-sitter-python, tree-sitter-cpp) only expose
    a ``language()`` function that returns a PyCapsule. This helper pulls the
    raw pointer out of that capsule and attaches it to a ``Language`` object
    created without running ``__init__``.

    Args:
        ts_module: a grammar module exposing a ``language()`` function.

    Returns:
        A ``Language`` whose ``language_id`` is the capsule's raw pointer.
    """
    grammar_capsule = ts_module.language()  # type: ignore[attr-defined]

    # Declare the C signature so ctypes returns the pointer as a void*.
    get_pointer = ctypes.pythonapi.PyCapsule_GetPointer
    get_pointer.restype = ctypes.c_void_p
    get_pointer.argtypes = [ctypes.py_object, ctypes.c_char_p]
    raw_pointer: ctypes.c_void_p = get_pointer(
        grammar_capsule, b"tree_sitter.Language"
    )

    # Bypass Language.__init__, which expects (library_path, name).
    ts_language = Language.__new__(Language)
    ts_language.language_id = raw_pointer  # type: ignore[attr-defined]
    return ts_language
# Language names exposed by this module. This is the same list object as
# codebleu's AVAILABLE_LANGS (an alias, not a copy).
SUPPORTED_LANGS: list[str] = AVAILABLE_LANGS


def _get_parser(language: str) -> Parser:  # pyre-ignore[11]
    """Create a tree-sitter Parser for the given language.

    Uses the grammar bundled in the ``codebleu`` package (``my-languages.so``)
    which supports: java, javascript, c_sharp, php, c, cpp, python, go, ruby, rust.

    Args:
        language: a language name from ``codebleu.AVAILABLE_LANGS``,
            or an alias (e.g. "py", "c++", "js"). The name is lowercased
            before lookup, so matching is case-insensitive.

    Returns:
        A configured tree-sitter Parser instance.

    Raises:
        ValueError: if the language is not supported.
    """
    # Lowercase once, then resolve aliases to their canonical name; names
    # without an alias entry pass through unchanged.
    normalized = language.lower()
    lang_key = _LANGUAGE_ALIASES.get(normalized, normalized)
    if lang_key not in AVAILABLE_LANGS:
        raise ValueError(
            f"Unsupported language '{language}'. Supported: {sorted(AVAILABLE_LANGS)}"
        )

    # Load the requested grammar from the shared library bundled with codebleu.
    # pyre-ignore[16]: Module `tree_sitter` has no attribute `Parser`
    ts_language = Language(
        importlib.resources.files("codebleu") / "my-languages.so", lang_key
    )
    # pyre-ignore[16]: Module `tree_sitter` has no attribute `Parser`.
    parser = Parser()
    parser.set_language(ts_language)
    return parser

Expand Down
Loading
Loading