Skip to content

Performance mismatch in gemma series model #16

Description

@frank6200db

Hi

I am trying to reproduce the BABILong benchmark results (QA1 to QA5) for the Gemma model series, but the scores I get are consistently lower than the reported metrics.

Gemma 3 12B IT (32k context): The reported score is 45.9, whereas our reproduced result is 39.6.

Gemma 2 9B IT (1k context): The reported score is 71.0, while our reproduction yielded 63.0.

I am using the standard `AutoModelForCausalLM` and `AutoTokenizer` classes from Hugging Face Transformers. Could the gap be related to the specific chat template implementation or to the attention backend (SDPA vs. FlashAttention-2)? Any insights would be appreciated. Thank you.

```python

def format_examples(default_examples):
    """Split a few-shot examples string into parallel input/answer lists.

    The string is expected to be a sequence of blocks of the form::

        <example>
        ...example text...
        Answer: <answer>
        </example>

    Args:
        default_examples: The concatenated examples string. An empty (or
            ``None``) value means "no examples".

    Returns:
        A ``(inputs, outputs)`` tuple of lists of equal length: ``inputs[i]``
        is the text of example ``i`` before its ``"\\nAnswer"`` line and
        ``outputs[i]`` is the text that follows ``"\\nAnswer: "``.

    Raises:
        ValueError: If a non-empty example chunk lacks the ``"\\n</example>"``
            terminator or an ``"\\nAnswer"`` line (raised by ``str.index``).
    """
    if not default_examples:
        return [], []

    answer_marker = "\nAnswer"
    # The answer text starts after the full "\nAnswer: " prefix (9 characters),
    # even though only "\nAnswer" is searched for.
    answer_skip = len("\nAnswer: ")

    inputs, outputs = [], []
    for chunk in default_examples.split('<example>\n'):
        if not chunk:
            # Text before the first "<example>\n" is empty for well-formed input.
            continue
        body = chunk[:chunk.index("\n</example>")]
        cut = body.index(answer_marker)
        inputs.append(body[:cut])
        outputs.append(body[cut + answer_skip:])
    return inputs, outputs

import argparse
def parse_args(argv=None):
    """Parse the command-line options of the evaluation script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) ``argparse`` reads ``sys.argv[1:]``, which is the
            behaviour of the original zero-argument call.

    Returns:
        An ``argparse.Namespace`` with the attributes ``tasks``,
        ``model_path``, ``length``, ``tp_size``, ``split_name``,
        ``kv_cache_quant`` and ``node_rank``.
    """
    parser = argparse.ArgumentParser(description="Qwen3 1M Context Inference")
    # NOTE(review): the script body visible in this file only reads
    # --model_path, --length and --node_rank; the remaining flags are parsed
    # but not consumed there.
    parser.add_argument("--tasks", type=str, default="qa0", help="task name")
    parser.add_argument("--model_path", type=str, default="/models/Qwen3-32B", help="Path to model")
    parser.add_argument("--length", type=str, default='32k', help="Max sequence length")
    parser.add_argument("--tp_size", type=int, default=8, help="Tensor Parallel size (GPUs)")
    parser.add_argument("--split_name", type=str, default="32k", help="Dataset split name")
    parser.add_argument("--kv_cache_quant", action="store_true", help="Enable 4-bit/8-bit KV cache to save memory")
    parser.add_argument("--node_rank", type=int, default=0, help="node rank")
    return parser.parse_args(argv)

if __name__ == "__main__":
    # Script entry point: loads one causal LM, runs it greedily over the
    # BABILong tasks qa1-qa5 for a single context-length split, and writes the
    # per-sample (target, output, question) rows to a CSV per task.

    # The names below are used throughout the script but were not imported in
    # the pasted snippet; they are imported locally so the script is runnable.
    import json
    from pathlib import Path

    import datasets
    import pandas as pd
    import torch
    import transformers
    from tqdm import tqdm
    from transformers import AutoModelForCausalLM, AutoTokenizer

    from babilong.prompts import DEFAULT_PROMPTS, DEFAULT_TEMPLATE, get_formatted_input

    dtype = torch.bfloat16
    args = parse_args()

    model_name = args.model_path
    model_name_ = model_name.split('/')[-1]  # last path component, used in output paths
    model_key = model_name.lower()           # model-family checks are case-insensitive
    is_gemma3 = 'gemma-3' in model_key

    split_length = args.length

    # --node_rank doubles as the CUDA device index for this process.
    dev = args.node_rank
    torch.cuda.set_device(dev)
    device = f"cuda:{dev}"

    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

    # Gemma-3 prompts are built through the model's processor, not the tokenizer.
    processor = None
    if is_gemma3:
        from transformers import AutoProcessor
        processor = AutoProcessor.from_pretrained(model_name)

    # NOTE(review): every model is loaded with FlashAttention-2. Whether this
    # backend reproduces the reference Gemma numbers is not established here;
    # comparing against 'eager'/'sdpa' is a cheap sanity check.
    model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True,
                                                 device_map=device, torch_dtype=dtype,
                                                 attn_implementation='flash_attention_2')
    model = model.eval()

    # Greedy decoding; the sampling knobs are explicitly set to None.
    generate_kwargs = {
        'max_new_tokens': 25,
        'max_length': None,
        'num_beams': 1,
        'do_sample': False,
        'temperature': None,
        'top_p': None,
        'top_k': None,
        'pad_token_id': tokenizer.pad_token_id,
    }

    # Fall back to EOS when the tokenizer defines no pad token.
    if generate_kwargs['pad_token_id'] is None:
        generate_kwargs['pad_token_id'] = tokenizer.eos_token_id

    tasks = ['qa1', 'qa2', 'qa3', 'qa4', 'qa5']
    split_names = [split_length]

    print(f'prompt template:\n{DEFAULT_TEMPLATE}')

    for task in tqdm(tasks, desc='tasks'):
        use_instruction = True
        use_examples = True
        use_post_prompt = True

        prompt_cfg = {
            'instruction': DEFAULT_PROMPTS[task]['instruction'] if use_instruction else '',
            'examples': DEFAULT_PROMPTS[task]['examples'] if use_examples else '',
            'post_prompt': DEFAULT_PROMPTS[task]['post_prompt'] if use_post_prompt else '',
            'template': DEFAULT_TEMPLATE
        }

        # e.g. "instruction_yes_examples_yes_post_prompt_yes"; encodes which
        # prompt parts are non-empty and becomes part of the output file name.
        prompt_name = [f'{k}_no' if not prompt_cfg[k] else f'{k}_yes' for k in prompt_cfg if k != 'template']
        prompt_name = '_'.join(prompt_name)

        for split_name in tqdm(split_names, desc='lengths'):
            data = datasets.load_dataset("/RMT-team/babilong-1k-samples", split_name)
            task_data = data[task]

            outfile = Path(f'/eval/babilong/{model_name_}/{task}_{split_name}_{prompt_name}.csv')
            outfile.parent.mkdir(parents=True, exist_ok=True)
            cfg_file = f'/eval/babilong/{model_name_}/{task}_{split_name}_{prompt_name}.json'
            # Context manager so the config file handle is closed deterministically.
            with open(cfg_file, 'w') as cfg_fp:
                json.dump({'prompt': prompt_cfg, 'generate_kwargs': generate_kwargs}, cfg_fp, indent=4)

            df = pd.DataFrame({'target': [], 'output': [], 'question': []})

            for sample in tqdm(task_data):
                target = sample['target']
                context = sample['input']
                question = sample['question']

                # Plain prompt string. It is kept under its own name so that no
                # branch below can confuse it with the chat-message list.
                prompt_text = get_formatted_input(context, question, prompt_cfg['examples'],
                                                  prompt_cfg['instruction'], prompt_cfg['post_prompt'],
                                                  template=DEFAULT_TEMPLATE)
                chat_messages = [{'role': 'user', 'content': prompt_text}]

                if 'qwen3' in model_key:
                    input_ids = tokenizer.apply_chat_template(chat_messages, add_generation_prompt=True,
                                                              return_tensors='pt',
                                                              enable_thinking=False).to(model.device)
                    model_inputs = {'input_ids': input_ids}

                elif is_gemma3:
                    # BUG FIX: the "text" field must be the prompt *string*. The
                    # original passed the already-wrapped list of message dicts
                    # here, so the model was prompted with that list instead of
                    # the actual task text.
                    # NOTE(review): only this branch adds a system turn; confirm
                    # that this matches the setup the reference scores used.
                    messages = [
                        {
                            "role": "system",
                            "content": [{"type": "text", "text": "You are a helpful assistant."}]
                        },
                        {
                            "role": "user",
                            "content": [{"type": "text", "text": prompt_text}]
                        }
                    ]
                    model_inputs = processor.apply_chat_template(
                        messages, add_generation_prompt=True, tokenize=True,
                        return_dict=True, return_tensors="pt"
                    ).to(device, dtype=dtype)
                    model_inputs.pop("token_type_ids", None)

                elif 'gemma-2' in model_key:
                    input_ids = tokenizer.apply_chat_template(chat_messages, add_generation_prompt=True,
                                                              return_tensors='pt').to(model.device)
                    model_inputs = {'input_ids': input_ids}

                elif 'lact' in model_key or 'untie' in model_key:
                    # These models are fed the raw prompt without a chat template.
                    model_inputs = tokenizer([prompt_text], return_tensors="pt").to(model.device)

                else:
                    # no generation prompt for pretrained models
                    input_ids = tokenizer.apply_chat_template(chat_messages, add_generation_prompt=False,
                                                              return_tensors='pt').to(model.device)
                    model_inputs = {'input_ids': input_ids}

                # Prompt length in tokens; generated ids are sliced from here on.
                sample_length = model_inputs['input_ids'].shape[1]

                if is_gemma3:
                    with torch.inference_mode():
                        generation = model.generate(**model_inputs, max_new_tokens=25, do_sample=False)
                    new_tokens = generation[0][sample_length:]
                    output = processor.decode(new_tokens, skip_special_tokens=True).strip()

                elif 'gemma-2' in model_key:
                    with torch.no_grad():
                        generation = model.generate(**model_inputs, **generate_kwargs)
                        if 'activation-beacon' in model.name_or_path and hasattr(model, 'memory'):
                            model.memory.reset()
                    new_tokens = generation[0][sample_length:]
                    # Decode like the other branches: drop special tokens and
                    # surrounding whitespace so they do not end up in the CSV.
                    output = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

                else:
                    with torch.no_grad():
                        generation = model.generate(**model_inputs, **generate_kwargs, use_cache=True)
                    new_tokens = generation[0][sample_length:]
                    output = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

                print(f'>>> {sample_length} >>>>')

                df.loc[len(df)] = [target, output, question]
                # The whole CSV is rewritten after every sample, so the file on
                # disk always holds the results collected so far.
                df.to_csv(outfile)

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions