diff --git a/scripts/prepare_hidden_states.py b/scripts/prepare_hidden_states.py
index a1d45fe1..876af2c4 100644
--- a/scripts/prepare_hidden_states.py
+++ b/scripts/prepare_hidden_states.py
@@ -138,8 +138,18 @@ def parse_args():
         help="Number of files per subdirectory.",
     )
 
+    # vlm related args
+    vlm_group = parser.add_argument_group("vlm")
+    vlm_group.add_argument(
+        "--min-pixels", type=int, default=50176
+    )  # 64*28*28 for qwen2.5-vl
+    vlm_group.add_argument(
+        "--max-pixels", type=int, default=802816
+    )  # 1024*28*28 for qwen2.5-vl
+
     sglang_group = parser.add_argument_group("sglang")
     SGLangBackendArgs.add_args(sglang_group)
+
     return parser.parse_args()
 
 
@@ -187,7 +197,11 @@ def build_target_model(
         target_model.set_aux_hidden_states_layers(args.aux_hidden_states_layers)
 
     if args.is_vlm:
-        processor = AutoProcessor.from_pretrained(args.target_model_path)
+        processor = AutoProcessor.from_pretrained(
+            args.target_model_path,
+            min_pixels=args.min_pixels,
+            max_pixels=args.max_pixels,
+        )
     else:
         processor = None
 
@@ -583,6 +597,8 @@ def main():
         args.target_model_path, trust_remote_code=True
     )
     cache_params_string = f"{args.data_path}-{args.max_length}-{args.chat_template}-{args.target_model_path}-{args.num_samples}-{args.is_preformatted}"
+    if args.is_vlm:
+        cache_params_string = f'{cache_params_string}-{args.min_pixels}-{args.max_pixels}'
     cache_key = hashlib.md5(cache_params_string.encode()).hexdigest()
 
     # Preprocess on complete, un-sharded dataset
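
For context on the new defaults: below is a rough, standalone sketch of how `--min-pixels` / `--max-pixels` bound the per-image patch budget for a Qwen2.5-VL-style processor (50176 = 64*28*28, 802816 = 1024*28*28). The resize arithmetic is an illustrative approximation, not the `transformers` implementation.

```python
# Rough sketch (not the transformers implementation): how min_pixels/max_pixels
# bound the number of 28x28 vision patches per image for a Qwen2.5-VL-style
# processor. The defaults added above correspond to 64 and 1024 patches.
import math

PATCH = 28  # qwen2.5-vl patch edge length in pixels


def approx_patch_budget(width: int, height: int,
                        min_pixels: int = 64 * PATCH * PATCH,
                        max_pixels: int = 1024 * PATCH * PATCH) -> int:
    """Estimate the patch count after rescaling the image so its area
    lands inside [min_pixels, max_pixels]."""
    pixels = width * height
    if pixels > max_pixels:
        # Shrink; round each side down to a multiple of 28 to stay under the cap.
        scale = math.sqrt(max_pixels / pixels)
        new_w = max(PATCH, math.floor(width * scale / PATCH) * PATCH)
        new_h = max(PATCH, math.floor(height * scale / PATCH) * PATCH)
    elif pixels < min_pixels:
        # Grow; round each side up to a multiple of 28 to reach the floor.
        scale = math.sqrt(min_pixels / pixels)
        new_w = math.ceil(width * scale / PATCH) * PATCH
        new_h = math.ceil(height * scale / PATCH) * PATCH
    else:
        new_w = max(PATCH, round(width / PATCH) * PATCH)
        new_h = max(PATCH, round(height / PATCH) * PATCH)
    return (new_w // PATCH) * (new_h // PATCH)


print(approx_patch_budget(1920, 1080))  # 1008 -> capped near the 1024-patch max
print(approx_patch_budget(100, 100))    # 64   -> lifted to the 64-patch min
```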
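
And a minimal illustration of why the pixel bounds are folded into the cache key in the last hunk: hidden states prepared under different image budgets must not share a cache entry. The data path, chat template, and model id below are placeholders.

```python
# Standalone illustration of the cache-key change: appending the pixel bounds
# to the keyed string gives VLM runs with different image budgets distinct
# cache entries. All concrete values here are placeholders.
import hashlib


def cache_key(data_path, max_length, chat_template, target_model_path,
              num_samples, is_preformatted, is_vlm=False,
              min_pixels=None, max_pixels=None):
    s = f"{data_path}-{max_length}-{chat_template}-{target_model_path}-{num_samples}-{is_preformatted}"
    if is_vlm:
        s = f"{s}-{min_pixels}-{max_pixels}"
    return hashlib.md5(s.encode()).hexdigest()


base = dict(data_path="data/train.jsonl", max_length=4096,
            chat_template="qwen2-vl", target_model_path="Qwen/Qwen2.5-VL-7B-Instruct",
            num_samples=100000, is_preformatted=False, is_vlm=True)

# Same dataset and model, different pixel budgets -> different cache keys,
# so hidden states are never silently reused across image resolutions.
print(cache_key(**base, min_pixels=64 * 28 * 28, max_pixels=1024 * 28 * 28))
print(cache_key(**base, min_pixels=64 * 28 * 28, max_pixels=1280 * 28 * 28))
```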