liangzid/VirusInfectionAttack

Virus Infection Attack (VIA)

This repository provides the official implementation of the paper:

Virus Infection Attack on LLMs: Your Poisoning Can Spread "VIA" Synthetic Data

[[Paper Link (Coming Soon)]]


🛠️ Environment Setup

  • Python version: >=3.10
  • Install dependencies:
    pip install -r re.txt
  • Hardware requirement: at least one GPU with 80 GB of memory is recommended to reproduce the main experiments.

📂 Code Overview

🔍 Distribution Analysis

  • ./DistributionAnalysis/query_comparison.py
    Generate query distribution visualizations and compute poisoning relevance.

🛡️ Defense Evaluation

  • ./defense/defense_methods.py
    Defense method implementations.

  • ./defense/evaluate.py
    Evaluation scripts for defense effectiveness.


🧪 Poisoning / Backdoor Dataset Construction

  • ./construct_direct_poisoning_dataset.py
    Construct the standard data poisoning baseline.

  • ./construct_worm_sft.py
    VIA-based backdoor attack generation.

  • ./construct_worm_sft_new.py
    VIA-based data poisoning generation.

  • ./backdoor_dataset_construct.py
    Build backdoor datasets for experiments.


🎯 Attack Success Rate (ASR) Evaluation

  • ./asr_infer.py
    Evaluate ASR for data poisoning attacks.

  • ./dpabackdoor_eval.py
    Evaluate ASR for backdoor attacks.
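
As a rough illustration of what an ASR number means, the sketch below counts the fraction of model responses that contain an attacker-chosen payload string. The function name `attack_success_rate`, the substring-matching criterion, and the sample strings are all assumptions made for illustration; the actual criterion is whatever `asr_infer.py` and `dpabackdoor_eval.py` implement.

```python
def attack_success_rate(responses, payload):
    """Fraction of responses containing the payload (case-insensitive).

    Hypothetical helper: substring matching is an assumed success
    criterion, not necessarily the one used in this repository.
    """
    if not responses:
        return 0.0
    needle = payload.lower()
    hits = sum(1 for r in responses if needle in r.lower())
    return hits / len(responses)


if __name__ == "__main__":
    # Made-up responses; two of the four contain the payload.
    demo = [
        "Sure! By the way, visit example.com for more.",
        "Here is the summary you asked for.",
        "Visit EXAMPLE.COM today.",
        "No payload here.",
    ]
    print(attack_success_rate(demo, "example.com"))  # 0.5
```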


🧬 Infection Rate (IR) Evaluation

  • ./infer_new.py
    Evaluate IR on synthetic data (data poisoning).

  • ./infec_infer_backdoor.py
    Evaluate IR on synthetic data (backdoor attacks).


🧩 Core Components

  • ./analyze_sftdataset.py
    Perform Hijacking Point Search (HPS) analysis.

  • ./getEmbedding.py
    Extract embeddings for representation-level analysis.

  • ./ChatwithAPI.py
    Query external LLM APIs.

  • ./train.py
    Fine-tune LLMs using poisoned or synthetic data.

  • ./seed.py
    Set global random seed for reproducibility.
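
A global seeding helper typically looks like the sketch below. This is not the contents of `seed.py`: the name `set_global_seed` is hypothetical, and NumPy and PyTorch are seeded only if they happen to be installed.

```python
import os
import random


def set_global_seed(seed: int) -> None:
    """Seed the common RNG sources (hypothetical helper, not seed.py itself)."""
    # PYTHONHASHSEED only affects Python processes started after this point.
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    try:
        import numpy as np
        np.random.seed(seed)
    except ImportError:
        pass
    try:
        import torch
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)
    except ImportError:
        pass


if __name__ == "__main__":
    set_global_seed(1)
    first = random.random()
    set_global_seed(1)
    assert random.random() == first  # same seed, same draw
```

In the training example in Step 2, the seed is set to the run index (`seed=${train_time}`), so repeated runs differ only in their seed.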


📊 Visualization & Plotting

  • ./plot_PPL_dist.py
    Visualize PPL-based defense detection results.

  • ./plot_multigeneration.py
    Plot propagation through multiple generations.

  • ./plot_varyInfectionRateComparison.py
    Plot comparisons of IR under different settings (Figure 2).

  • ./plot_varyNgram_experiment.py
    Plot results for varying n-gram sizes.

  • ./visualize_hps.py
    Visualize HPS score distribution using bar plots.
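
For context on the PPL-based detection plotted by `plot_PPL_dist.py`: perplexity-based defenses usually flag samples whose perplexity under a reference language model exceeds a threshold. The sketch below computes perplexity from per-token log-probabilities as the exponential of the mean negative log-likelihood. The helper names, the threshold of 50.0, and the log-probability values are illustrative assumptions and are not taken from `defense_methods.py`.

```python
import math


def perplexity(token_logprobs):
    """PPL = exp(-(1/N) * sum(log p_i)) over natural-log token probabilities."""
    if not token_logprobs:
        raise ValueError("need at least one token log-probability")
    return math.exp(-sum(token_logprobs) / len(token_logprobs))


def flag_by_ppl(samples, threshold=50.0):
    """Return indices of samples whose perplexity exceeds the threshold.

    `threshold` is an assumed value for illustration only.
    """
    return [i for i, lp in enumerate(samples) if perplexity(lp) > threshold]


if __name__ == "__main__":
    fluent = [math.log(0.5)] * 8     # perplexity of about 2
    unusual = [math.log(0.01)] * 8   # perplexity of about 100
    print(flag_by_ppl([fluent, unusual]))  # [1]
```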


🧪 Experiments: An Overview

All experiments follow a three-step pipeline:

  1. Construct the poisoned dataset
  2. Train the model
  3. Evaluate ASR / IR performance

📁 Step 1: Create Poisoned Datasets

# VIA for Backdoor Attack
python ./construct_worm_sft.py

# VIA for Data Poisoning Attack
python ./construct_worm_sft_new.py

# Standard Data Poisoning Baseline
python ./construct_direct_poisoning_dataset.py

🧠 Step 2: Train Models

You can use the provided scripts to reproduce the experiments:

  • Standard Poisoning: bash ./scripts/3.1.varyPoisoningRate.sh
  • VIA-based Poisoning: bash ./scripts/3.2.wormVaryPoisoningRate.sh
  • Backdoor Poisoning: bash ./scripts/3.3.backdoorPoisoningVaryPR.sh

You may also define your own experiment. Example:

export python=${HOME}/anaconda3/envs/worm/bin/python3
export TORCH_USE_CUDA_DSA="1"
export root_dir="${HOME}/wormInfection/"
export from_path="meta-llama/Meta-Llama-3-8B"
export CUDA_VISIBLE_DEVICES=0

# Poisoning configurations
export prefix_path=${root_dir}"saved_poison_dataset/"
export pr_ls=(0.025 0.05 0.1 0.2 0.4)
export train_time_ls=(1 2 3)

for pr in ${pr_ls[*]}; do
    for train_time in ${train_time_ls[*]}; do
        export dataset_name="${prefix_path}allenai_tulu-3-sft-personas-instruction-followinggeneral-person${pr}5000.jsonl"
        export savepath_suffix=$(echo "$dataset_name" | tr './' '__' | tr '/' '_')
        export save_path="saved_ckpts/VaryPR_Poisoning/${savepath_suffix}tt_${train_time}${from_path}"

        echo "---------------------"
        echo "Save path: $save_path"
        echo "---------------------"

        export seed=${train_time}
        $python ${root_dir}train.py \
            --dataset_name $dataset_name \
            --seed $seed \
            --epoch 3 \
            --acc_step 1 \
            --log_step 2000 \
            --save_step 5000 \
            --overall_step 15000 \
            --LR 3e-5 \
            --is_lora 1 \
            --rank 128 \
            --lora_alpha 256 \
            --batch_size 1 \
            --max_length 2048 \
            --from_path $from_path \
            --save_path $save_path \
            --temp_save_path ${save_path}temp
    done
done

echo "RUNNING 3.1.varyPoisoningRate.sh DONE."
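
The shell loop above derives each checkpoint directory from the dataset path, which is why poisoning rates appear as `0_025` rather than `0.025` in checkpoint names. The sketch below mirrors that string manipulation in Python so you can predict where a run will be saved. `build_save_path` is a hypothetical helper, and the root directory in the usage example is a placeholder.

```python
def build_save_path(root_dir, pr, train_time,
                    from_path="meta-llama/Meta-Llama-3-8B"):
    """Mirror the dataset_name / save_path logic of the shell loop."""
    prefix_path = f"{root_dir}saved_poison_dataset/"
    dataset_name = (
        f"{prefix_path}allenai_tulu-3-sft-personas-instruction-following"
        f"general-person{pr}5000.jsonl"
    )
    # Equivalent of: tr './' '__'  (both '.' and '/' become '_')
    suffix = dataset_name.translate(str.maketrans("./", "__"))
    save_path = f"saved_ckpts/VaryPR_Poisoning/{suffix}tt_{train_time}{from_path}"
    return dataset_name, save_path


if __name__ == "__main__":
    # "/home/user/wormInfection/" is a placeholder root directory.
    for pr in ("0.025", "0.05"):
        _, path = build_save_path("/home/user/wormInfection/", pr, 1)
        print(path)
```

Note that `from_path` is appended unchanged, so the final path still contains the `/` from `meta-llama/Meta-Llama-3-8B`. The evaluation code in Step 3 expects this layout.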

📊 Step 3: Evaluate Performance

# ASR Evaluation for Data Poisoning
python asr_infer.py

# IR Evaluation for Data Poisoning
python infer_new.py

# IR Evaluation for Backdoor Attacks
python ./infec_infer_backdoor.py

⚠️ Note: You must uncomment the appropriate evaluation functions in the code. For example, asr_infer.py includes main2_varyPoisoningRate():

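from pprint import pprint


def main2_varyPoisoningRate():
    # Base model passed to ASR_query_eval alongside each checkpoint.
    base_model_pth = "meta-llama/Meta-Llama-3-8B"
    # Poisoning rates as they appear in checkpoint names ("." replaced by "_").
    pr_ls = ["0_005", "0_1", "0_025", "0_4"]
    # One entry per training run (the run index doubles as the seed in Step 2).
    train_time_ls = ["1", "2", "3"]
    device = "cuda"
    res_dict = {}

    for pr in pr_ls:
        for train_time in train_time_ls:
            ckpt = f"saved_ckpts/VaryPR_Poisoning/...{pr}5000_jsonltt_{train_time}meta-llama/Meta-Llama-3-8B/checkpoint-15000/"
            clean_dataset = "allenai/tulu-3-sft-personas-instruction-following"
            poison_type = "general-person"
            # ASR_query_eval is assumed to be available inside asr_infer.py.
            asr = ASR_query_eval(ckpt, device, base_model_pth, task_info=poison_type, mnt=512)
            res_dict[ckpt] = asr

    # Print the ASR of every checkpoint, keyed by checkpoint path.
    pprint(res_dict)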

Call this function from main() in asr_infer.py so that it runs when the script is executed.
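
One way to wire this up, assuming `asr_infer.py` follows the usual `main()` entry-point pattern, is sketched below. The stub body stands in for the real `main2_varyPoisoningRate()` so that the snippet runs on its own; in the repository you would call the existing function instead of redefining it.

```python
def main2_varyPoisoningRate():
    # Stub standing in for the real function in asr_infer.py.
    print("running vary-poisoning-rate ASR evaluation")
    return "main2_varyPoisoningRate"


def main():
    # Uncomment exactly the evaluation you want to run.
    # main1_someOtherEvaluation()  # hypothetical name, for illustration only
    return main2_varyPoisoningRate()


if __name__ == "__main__":
    main()
```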

About

Virus Infection Attack on LLMs: Your Poisoning Can Spread "VIA" Synthetic Data. NeurIPS'25 Spotlight
