diff --git a/dataflow/operators/core_text/generate/prompt_templated_qa_generator.py b/dataflow/operators/core_text/generate/prompt_templated_qa_generator.py index 571093b..00b156c 100644 --- a/dataflow/operators/core_text/generate/prompt_templated_qa_generator.py +++ b/dataflow/operators/core_text/generate/prompt_templated_qa_generator.py @@ -1,25 +1,28 @@ import pandas as pd +from typing import List + from dataflow.utils.registry import OPERATOR_REGISTRY from dataflow import get_logger - from dataflow.utils.storage import FileStorage, DataFlowStorage -from dataflow.core import OperatorABC -from dataflow.core import LLMServingABC +from dataflow.core import OperatorABC, LLMServingABC from dataflow.serving.local_model_vlm_serving import LocalModelVLMServing_vllm +from dataflow.serving.api_vlm_serving_openai import APIVLMServing_openai from dataflow.prompts.prompt_template import NamedPlaceholderPromptTemplate +# 提取判断是否为 API Serving 的辅助函数 +def is_api_serving(serving): + return isinstance(serving, APIVLMServing_openai) + + @OPERATOR_REGISTRY.register() class PromptTemplatedQAGenerator(OperatorABC): """ PromptTemplatedQAGenerator: 1) 从 DataFrame 读取若干字段(由 input_keys 指定) 2) 使用 prompt_template.build_prompt(...) 生成纯文本 prompt - 3) 将该 prompt 与 image/video 一起输入多模态模型,生成答案 - - 其中 prompt_template 需要实现: - build_prompt(self, need_fields: set[str], **kwargs) -> str + 3) 将该 prompt 输入大语言模型,生成纯文本答案 """ def __init__( @@ -35,47 +38,28 @@ def __init__( if self.prompt_template is None: raise ValueError( - "prompt_template cannot be None for PromptTemplatedVQAGenerator." + "prompt_template cannot be None for PromptTemplatedQAGenerator." ) @staticmethod def get_desc(lang: str = "zh"): if lang == "zh": return ( - "PromptTemplatedQAGenerator:先用模板填充文本 prompt,再" - "进行问答的算子。\n" - "JSONL/DataFrame 中包含若干字段(例如 descriptions、type 等)," - "通过 input_keys 将 DataFrame 列映射到模板字段,由 prompt_template 生成最终的文本 Prompt。" + "基于模板的纯文本问答算子 (PromptTemplatedQAGenerator)。\n" + "JSONL/DataFrame 中包含若干字段,通过 input_keys 将列映射到模板字段,\n" + "由 prompt_template 动态生成纯文本 Prompt,进行批量问答。\n\n" + "特点:\n" + " - 支持动态组装复杂的纯文本 Prompt\n" + " - 统一支持 API 和本地 Local 模型部署模式\n" + " - 全局 Batch 处理,极简代码结构\n" ) else: return ( - "PromptTemplatedQAGenerator: a QA operator that first builds " - "text prompts from a prompt template and multiple input fields, then " - "performs QA." - ) - - def _prepare_batch_inputs(self, prompts): - - prompt_list = [] - - for p in prompts: - raw_prompt = [ - {"role": "system", "content": self.system_prompt}, - { - "role": "user", - "content": [ - {"type": "text", "text": p}, - ], - }, - ] - prompt = self.serving.processor.apply_chat_template( - raw_prompt, tokenize=False, add_generation_prompt=True + "PromptTemplatedQAGenerator: a pure text QA operator that builds " + "text prompts from a template and multiple input fields, then " + "performs QA inference." ) - prompt_list.append(prompt) - - return prompt_list - def run( self, storage: DataFlowStorage, @@ -87,66 +71,88 @@ def run( - storage: DataFlowStorage - output_answer_key: 输出答案列名 - **input_keys: 模板字段名 -> DataFrame 列名 - 例如: - descriptions="descriptions", type="type" - - 逻辑: - 1. 从 DataFrame 每行抽取 input_keys 对应列,形成 key_dict - 2. 用 prompt_template.build_prompt(need_fields, **key_dict) 得到文本 prompt + 例如:descriptions="descriptions_col", type="type_col" """ - if output_answer_key is None: - raise ValueError("output_answer_key must be provided.") + if not output_answer_key: + raise ValueError("'output_answer_key' must be provided.") if len(input_keys) == 0: raise ValueError( - "PromptTemplatedVQAGenerator requires at least one input key " + "PromptTemplatedQAGenerator requires at least one input key " "to fill the prompt template (e.g., descriptions='descriptions')." ) self.logger.info("Running PromptTemplatedQAGenerator...") - self.output_answer_key = output_answer_key - dataframe = storage.read("dataframe") - self.logger.info(f"Loading, number of rows: {len(dataframe)}") + # 1. 加载 DataFrame + dataframe: pd.DataFrame = storage.read("dataframe") + self.logger.info(f"Loaded dataframe with {len(dataframe)} rows") + + use_api_mode = is_api_serving(self.serving) + if use_api_mode: + self.logger.info("Using API serving mode") + else: + self.logger.info("Using local serving mode") + # 2. 动态生成 Prompt 文本并组装标准对话结构 need_fields = set(input_keys.keys()) - prompt_column = [] + conversations_list = [] for idx, row in dataframe.iterrows(): key_dict = {} for key in need_fields: col_name = input_keys[key] # 模板字段名 -> DataFrame 列名 - key_dict[key] = row[col_name] + # 安全获取值,防止 NaN 导致字符串格式化异常 + val = row.get(col_name) + key_dict[key] = val if pd.notna(val) else "" + prompt_text = self.prompt_template.build_prompt(need_fields, **key_dict) - prompt_column.append(prompt_text) + + # 统一组装为基类所需的消息格式 + conversations_list.append([{"role": "user", "content": prompt_text}]) self.logger.info( - f"Using prompt_template to build prompts with fields {need_fields}, " - f"prepared {len(prompt_column)} prompts." + f"Built {len(conversations_list)} prompts using fields: {need_fields}" ) - prompt_list = self._prepare_batch_inputs(prompt_column) - - outputs = self.serving.generate_from_input( + # 3. 统一调用基类接口进行纯文本推理 (无需传入 image_list/video_list) + outputs = self.serving.generate_from_input_messages( + conversations=conversations_list, system_prompt=self.system_prompt, - user_inputs=prompt_list ) - dataframe[self.output_answer_key] = outputs + # 4. 保存结果 + dataframe[output_answer_key] = outputs output_file = storage.write(dataframe) self.logger.info(f"Results saved to {output_file}") - return output_answer_key + return [output_answer_key] +# ========================================== +# 测试用例 (Main Block) +# ========================================== if __name__ == "__main__": - model = LocalModelVLMServing_vllm( - hf_model_name_or_path="Qwen/Qwen2.5-VL-7B-Instruct", - vllm_tensor_parallel_size=1, - vllm_temperature=0.7, - vllm_top_p=0.9, - vllm_max_tokens=512, + + # 使用 API 模式测试 + model = APIVLMServing_openai( + api_url="http://172.96.141.132:3001/v1", + key_name_of_api_key="DF_API_KEY", + model_name="gpt-5-nano-2025-08-07", + image_io=None, + send_request_stream=False, + max_workers=10, + timeout=1800 ) + + # 如需测试 Local 模型,请解开注释 (VLM 模型同样能处理纯文本) + # model = LocalModelVLMServing_vllm( + # hf_model_name_or_path="Qwen/Qwen2.5-VL-7B-Instruct", + # vllm_tensor_parallel_size=1, + # vllm_temperature=0.7, + # vllm_top_p=0.9, + # vllm_max_tokens=512, + # ) TEMPLATE = ( "Descriptions:\n" @@ -164,18 +170,18 @@ def run( prompt_template=prompt_template, ) - # Prepare input + # 准备输入数据 storage = FileStorage( first_entry_file_name="./dataflow/example/text_to_text/prompt_templated_qa.jsonl", cache_path="./cache_prompted_qa", file_name_prefix="prompt_templated_qa", cache_type="jsonl", ) - storage.step() # Load the data + storage.step() # 加载数据 generator.run( storage=storage, output_answer_key="answer", descriptions="descriptions", type="type", - ) + ) \ No newline at end of file diff --git a/dataflow/operators/core_text/generate/prompted_qa_generator.py b/dataflow/operators/core_text/generate/prompted_qa_generator.py index 37bb6fe..de4825d 100644 --- a/dataflow/operators/core_text/generate/prompted_qa_generator.py +++ b/dataflow/operators/core_text/generate/prompted_qa_generator.py @@ -1,17 +1,25 @@ import pandas as pd +from typing import List + from dataflow.utils.registry import OPERATOR_REGISTRY from dataflow import get_logger - from dataflow.utils.storage import FileStorage, DataFlowStorage -from dataflow.core import OperatorABC -from dataflow.core import LLMServingABC + +from dataflow.core import OperatorABC, LLMServingABC from dataflow.serving.local_model_vlm_serving import LocalModelVLMServing_vllm - +from dataflow.serving.api_vlm_serving_openai import APIVLMServing_openai + + +# 提取判断是否为 API Serving 的辅助函数 +def is_api_serving(serving): + return isinstance(serving, APIVLMServing_openai) + + @OPERATOR_REGISTRY.register() class PromptedQAGenerator(OperatorABC): - ''' + """ PromptedQAGenerator read prompt and generate answers. - ''' + """ def __init__(self, serving: LLMServingABC, system_prompt: str = "You are a helpful assistant."): @@ -21,73 +29,86 @@ def __init__(self, @staticmethod def get_desc(lang: str = "zh"): - return "读取 prompt 生成答案" if lang == "zh" else "Read prompt to generate answers." - - def _prepare_batch_inputs(self, prompts): - """ - Construct batched prompts. - """ - prompt_list = [] - - for p in prompts: - raw_prompt = [ - {"role": "system", "content": self.system_prompt}, - { - "role": "user", - "content": [ - {"type": "text", "text": p}, - ], - }, - ] - - prompt = self.serving.processor.apply_chat_template( - raw_prompt, tokenize=False, add_generation_prompt=True + if lang == "zh": + return ( + "基础文本问答算子 (PromptedQAGenerator)。\n" + "直接读取指定列作为 prompt,生成纯文本答案。\n\n" + "特点:\n" + " - 极简纯文本问答\n" + " - 统一支持 API 和本地 Local 模型部署模式\n" + " - 全局 Batch 处理,极简代码结构\n" ) - - prompt_list.append(prompt) - - return prompt_list - + else: + return "Read prompt to generate answers." + def run(self, storage: DataFlowStorage, input_prompt_key: str = "prompt", output_answer_key: str = "answer", ): - if output_answer_key is None: - raise ValueError("At least one of output_answer_key must be provided.") - - self.logger.info("Running PromptedQA...") - - self.output_answer_key = output_answer_key - - # Load the raw dataframe from the input file - dataframe = storage.read('dataframe') - self.logger.info(f"Loading, number of rows: {len(dataframe)}") - - prompt_column = dataframe.get(input_prompt_key, pd.Series([])).tolist() - prompt_list = self._prepare_batch_inputs(prompt_column) - - outputs = self.serving.generate_from_input( + if not output_answer_key: + raise ValueError("'output_answer_key' must be provided.") + + self.logger.info("Running PromptedQAGenerator...") + + # 1. 加载 DataFrame + dataframe: pd.DataFrame = storage.read('dataframe') + self.logger.info(f"Loaded dataframe with {len(dataframe)} rows") + + use_api_mode = is_api_serving(self.serving) + if use_api_mode: + self.logger.info("Using API serving mode") + else: + self.logger.info("Using local serving mode") + + # 2. 提取并清洗 Prompt 数据 + prompt_column = dataframe.get(input_prompt_key, pd.Series([None] * len(dataframe))).tolist() + + # 组装为基类所需的消息格式,同时处理可能存在的 NaN 空值 + conversations_list = [] + for p in prompt_column: + safe_prompt = str(p) if pd.notna(p) else "" + conversations_list.append([{"role": "user", "content": safe_prompt}]) + + # 3. 统一调用基类接口进行推理 + outputs = self.serving.generate_from_input_messages( + conversations=conversations_list, system_prompt=self.system_prompt, - user_inputs=prompt_list, ) - dataframe[self.output_answer_key] = outputs + # 4. 保存结果 + dataframe[output_answer_key] = outputs output_file = storage.write(dataframe) self.logger.info(f"Results saved to {output_file}") - return output_answer_key + return [output_answer_key] + +# ========================================== +# 测试用例 (Main Block) +# ========================================== if __name__ == "__main__": - # Initialize model - model = LocalModelVLMServing_vllm( - hf_model_name_or_path="/data0/happykeyan/Models/Qwen2.5-VL-3B-Instruct", - vllm_tensor_parallel_size=1, - vllm_temperature=0.7, - vllm_top_p=0.9, - vllm_max_tokens=512, + + # 使用 API 模式测试 + model = APIVLMServing_openai( + api_url="http://172.96.141.132:3001/v1", + key_name_of_api_key="DF_API_KEY", + model_name="gpt-5-nano-2025-08-07", + image_io=None, + send_request_stream=False, + max_workers=10, + timeout=1800 ) + # 如需使用本地模型,请解开注释 + # model = LocalModelVLMServing_vllm( + # hf_model_name_or_path="/data0/happykeyan/Models/Qwen2.5-VL-3B-Instruct", + # vllm_tensor_parallel_size=1, + # vllm_temperature=0.7, + # vllm_top_p=0.9, + # vllm_max_tokens=512, + # ) + generator = PromptedQAGenerator( serving=model, system_prompt="You are a helpful assistant. Return the value of the math expression in the user prompt.", @@ -106,4 +127,4 @@ def run(self, storage=storage, input_prompt_key="prompt", output_answer_key="answer", - ) + ) \ No newline at end of file diff --git a/dataflow/operators/core_vision/__init__.py b/dataflow/operators/core_vision/__init__.py index 20efbad..86a3826 100644 --- a/dataflow/operators/core_vision/__init__.py +++ b/dataflow/operators/core_vision/__init__.py @@ -14,9 +14,9 @@ from .generate.video_caption_generator import VideoToCaptionGenerator from .generate.video_merged_caption_generator import VideoMergedCaptionGenerator from .generate.video_cotqa_generator import VideoCOTQAGenerator - from .generate.multirole_videoqa_generator import MultiroleVideoQAInitialGenerator, MultiroleVideoQAMultiAgentGenerator, MultiroleVideoQAFinalGenerator from .generate.batch_vqa_generator import BatchVQAGenerator from .generate.vlm_bbox_generator import VLMBBoxGenerator + from .generate.visual_reasoning_generator import VisualReasoningGenerator # === Filter === from .filter.video_clip_filter import VideoClipFilter diff --git a/dataflow/operators/core_vision/generate/multirole_videoqa_generator.py b/dataflow/operators/core_vision/generate/multirole_videoqa_generator.py deleted file mode 100644 index 5980d49..0000000 --- a/dataflow/operators/core_vision/generate/multirole_videoqa_generator.py +++ /dev/null @@ -1,309 +0,0 @@ -import os -import json -import pandas as pd -import re -from typing import List, Dict, Any, Union - -from dataflow.core.Operator import OperatorABC -from dataflow.utils.registry import OPERATOR_REGISTRY -from dataflow import get_logger -from dataflow.utils.storage import DataFlowStorage -from dataflow.core import VLMServingABC - -from dataflow.serving.api_vlm_serving_openai import APIVLMServing_openai - -# 引入提示词模板 -from dataflow.prompts.video import ( - MultiroleQAInitialQAGenerationPrompt, - MultiroleQACallExpertAgentsPrompt, - MultiroleQAProfile4ExpertAgents, - MultiroleQAMasterAgentRevisionPrompt, - MultiroleQADIYFinalQASynthesisPrompt, - MultiroleQAClassificationPrompt -) - -# ----------------------------------------------------------------------------- -# 辅助函数与基类 (消除重复代码,统一调用规范) -# ----------------------------------------------------------------------------- - -def is_api_serving(serving): - return isinstance(serving, APIVLMServing_openai) - -class MultiroleVideoQABase(OperatorABC): - """ - 多智能体视频问答算子基类,提供统一的视频信息提取和模型调用接口。 - """ - def __init__(self, llm_serving: VLMServingABC): - self.logger = get_logger() - self.llm_serving = llm_serving - - def _extract_video_info(self, v_input: Dict[str, Any]) -> tuple[Dict[str, Any], List[str]]: - """ - 提取视频 Meta 和 Clips 文本信息,并将所有有效图片路径展平为一个 List, - 移除了极度消耗内存的 PIL.Image 预加载逻辑。 - """ - v_content = { - "Meta": v_input.get("Meta", ""), - "Clips": [] - } - flat_image_paths = [] - - for clip in v_input.get("Clips", []): - processed_clip = { - "Audio_Text": clip.get("Audio_Text", ""), - "Description": clip.get("Description", "") - } - - paths = clip.get("Frames_Images", []) - if isinstance(paths, str): - paths = [paths] - - # 过滤并收集有效的图片路径 - valid_paths = [p for p in paths if isinstance(p, str) and p.strip()] - flat_image_paths.extend(valid_paths) - - processed_clip["Frames_Images"] = valid_paths - v_content["Clips"].append(processed_clip) - - return v_content, flat_image_paths - - def _generate_answer(self, prompt_text: str, image_paths: List[str]) -> str: - """ - 统一的模型调用接口。自动处理 API/Local 模式和 占位符。 - 代替了原来臃肿且有逻辑缺陷的 Callvlm 类。 - """ - use_api_mode = is_api_serving(self.llm_serving) - - if use_api_mode: - content = prompt_text - else: - img_tokens = "" * len(image_paths) - content = f"{img_tokens}\n{prompt_text}" if img_tokens else prompt_text - - conversation = [{"role": "user", "content": content}] - - outputs = self.llm_serving.generate_from_input_messages( - conversations=[conversation], - image_list=[image_paths] if image_paths else None, - system_prompt="" # 保持与原逻辑一致,不使用系统提示词 - ) - - if outputs and len(outputs) > 0: - return str(outputs[0]).strip() - return "" - - -# ----------------------------------------------------------------------------- -# Operator 1: Initial QA Generator (阶段一:初始问答生成) -# ----------------------------------------------------------------------------- - -@OPERATOR_REGISTRY.register() -class MultiroleVideoQAInitialGenerator(MultiroleVideoQABase): - def __init__(self, llm_serving: VLMServingABC): - super().__init__(llm_serving) - self.initial_gen_prompt = MultiroleQAInitialQAGenerationPrompt() - - def run( - self, - storage: DataFlowStorage, - input_meta_key: str = "Meta", - input_clips_key: str = "Clips", - output_key: str = "QA" - ): - df: pd.DataFrame = storage.read("dataframe") - - if input_meta_key not in df.columns or input_clips_key not in df.columns: - raise ValueError(f"Columns '{input_meta_key}' or '{input_clips_key}' not found.") - - if output_key not in df.columns: - df[output_key] = None - - self.logger.info(f"[InitialGenerator] Start processing {len(df)} videos...") - - for idx, row in df.iterrows(): - # 跳过已处理的数据 - if row.get(output_key) and isinstance(row.get(output_key), list) and len(row[output_key]) > 0: - continue - - clips_val = row.get(input_clips_key, []) - if not isinstance(clips_val, list): - self.logger.warning(f"Row {idx}: 'Clips' is not a list. Skipping.") - df.at[idx, output_key] = [] - continue - - v_input = {"Meta": row.get(input_meta_key, ""), "Clips": clips_val} - - try: - v_content, all_image_paths = self._extract_video_info(v_input) - prompt_s1 = self.initial_gen_prompt.build_prompt(v_content) - - initial_qa_str = self._generate_answer(prompt_s1, all_image_paths) - df.at[idx, output_key] = initial_qa_str - - except Exception as e: - self.logger.error(f"Error processing row {idx}: {str(e)}") - df.at[idx, output_key] = [] - - storage.write(df) - return [output_key] - - -# ----------------------------------------------------------------------------- -# Operator 2: Multi Agent Generator (阶段二:多智能体专家迭代) -# ----------------------------------------------------------------------------- - -@OPERATOR_REGISTRY.register() -class MultiroleVideoQAMultiAgentGenerator(MultiroleVideoQABase): - def __init__(self, llm_serving: VLMServingABC, max_iterations: int = 3): - super().__init__(llm_serving) - self.max_iterations = max_iterations - self.call_expert_prompt = MultiroleQACallExpertAgentsPrompt() - self.expert_profile_prompt = MultiroleQAProfile4ExpertAgents() - self.master_revision_prompt = MultiroleQAMasterAgentRevisionPrompt() - - def experts(self, call_for_experts_response: str) -> List[Dict[str, str]]: - experts_list: List[Dict[str, str]] = [] - json_matches = re.findall(r'\{.*?\}', call_for_experts_response, re.DOTALL) - - for json_str in json_matches: - try: - expert_data = json.loads(json_str.strip()) - role = expert_data.get("Expert_Role", "").strip('<> ').strip() - subtask = expert_data.get("Subtask", "").strip('<> ').strip() - - if role and subtask: - experts_list.append({"role": role, "subtask": subtask}) - except (json.JSONDecodeError, AttributeError): - continue - - return experts_list - - def run( - self, - storage: DataFlowStorage, - input_meta_key: str = "Meta", - input_clips_key: str = "Clips", - output_key: str = "QA" - ): - df: pd.DataFrame = storage.read("dataframe") - self.logger.info(f"[MultiAgentGenerator] Start processing {len(df)} videos...") - - for idx, row in df.iterrows(): - clips_val = row.get(input_clips_key, []) - init_qa = row.get(output_key, "") - - if not isinstance(clips_val, list): - continue - - v_input = {"Meta": row.get(input_meta_key, ""), "Clips": clips_val} - - try: - v_content, all_image_paths = self._extract_video_info(v_input) - - qa_history = [init_qa] - current_qa_pool_str = str(init_qa) - expert_history = [] - - for i in range(self.max_iterations): - self.logger.info(f"Row {idx} - Iteration {i + 1}: Check for Experts") - prompt_s2 = self.call_expert_prompt.build_prompt(v_content, current_qa_pool_str, expert_history) - call_for_experts_response = self._generate_answer(prompt_s2, all_image_paths) - - if "NO_EXPERTS" in call_for_experts_response: - self.logger.info("Master Agent decided to end iteration.") - break - - experts_list = self.experts(call_for_experts_response) - expert_history.extend(experts_list) - - for expert in experts_list: - prompt_s3 = self.expert_profile_prompt.build_prompt(expert["role"], v_content, expert["subtask"]) - expert_qa_str = self._generate_answer(prompt_s3, all_image_paths) - - prompt_s4 = self.master_revision_prompt.build_prompt(v_content, expert_qa_str, current_qa_pool_str) - revised_qa_str = self._generate_answer(prompt_s4, all_image_paths) - - current_qa_pool_str += f"\n{revised_qa_str}" - qa_history.append(revised_qa_str) - - df.at[idx, output_key] = qa_history - - except Exception as e: - self.logger.error(f"Error processing row {idx}: {str(e)}") - - storage.write(df) - return [output_key] - - -# ----------------------------------------------------------------------------- -# Operator 3: Final Generator (阶段三:最终合成与分类) -# ----------------------------------------------------------------------------- - -@OPERATOR_REGISTRY.register() -class MultiroleVideoQAFinalGenerator(MultiroleVideoQABase): - def __init__(self, llm_serving: VLMServingABC): - super().__init__(llm_serving) - self.final_synthesis_prompt = MultiroleQADIYFinalQASynthesisPrompt() - self.classification_prompt = MultiroleQAClassificationPrompt() - - def extract(self, final_qa_json_str: str) -> Union[List[Dict[str, Any]], str]: - JSON_ARRAY_REGEX = re.compile(r"(\[.*\])", re.DOTALL) - match = JSON_ARRAY_REGEX.search(final_qa_json_str) - - if not match: - self.logger.warning("Failed to find JSON array structure.") - return final_qa_json_str - - try: - qa_list = json.loads(match.group(1)) - if not isinstance(qa_list, list): - raise TypeError("Parsed result is not a list.") - return qa_list - except Exception as e: - self.logger.warning(f"Failed to parse extracted JSON block: {e}") - return final_qa_json_str - - def run( - self, - storage: DataFlowStorage, - input_meta_key: str = "Meta", - input_clips_key: str = "Clips", - output_key: str = "QA" - ): - df: pd.DataFrame = storage.read("dataframe") - self.logger.info(f"[FinalGenerator] Start processing {len(df)} videos...") - - for idx, row in df.iterrows(): - clips_val = row.get(input_clips_key, []) - qa_history = row.get(output_key, []) - - if not isinstance(clips_val, list): - continue - - v_input = {"Meta": row.get(input_meta_key, ""), "Clips": clips_val} - - try: - v_content, all_image_paths = self._extract_video_info(v_input) - - # Step 5: Final QA Synthesis - self.logger.info(f"Row {idx} - Step 5: Final QA Synthesis") - prompt_s5 = self.final_synthesis_prompt.build_prompt(qa_history) - synthesized_qa_str = self._generate_answer(prompt_s5, all_image_paths) - - # Step 6: Question Classification - self.logger.info(f"Row {idx} - Step 6: Question Classification") - prompt_s6 = self.classification_prompt.build_prompt(synthesized_qa_str) - final_qa_json_str = self._generate_answer(prompt_s6, all_image_paths) - - # Extract and Save - qa_list = self.extract(final_qa_json_str) - df.at[idx, output_key] = qa_list - - except Exception as e: - self.logger.error(f"Error processing row {idx}: {str(e)}") - df.at[idx, output_key] = [] - - output_file = storage.write(df) - self.logger.info(f"All processing done. Results saved to {output_file}") - - return [output_key] diff --git a/dataflow/operators/core_vision/generate/personalized_qa_generator.py b/dataflow/operators/core_vision/generate/personalized_qa_generator.py index e4dbe66..705ec7a 100644 --- a/dataflow/operators/core_vision/generate/personalized_qa_generator.py +++ b/dataflow/operators/core_vision/generate/personalized_qa_generator.py @@ -29,6 +29,10 @@ class PersQAGenerator(OperatorABC): Personalized QA generator. """ + """ + Personalized QA generator. + """ + def __init__(self, llm_serving: LLMServingABC): self.logger = get_logger() self.serving = llm_serving @@ -229,30 +233,30 @@ def run( if __name__ == "__main__": - model = APIVLMServing_openai( - api_url="http://172.96.141.132:3001/v1", # Any API platform compatible with OpenAI format - key_name_of_api_key="DF_API_KEY", # Set the API key for the corresponding platform in the environment variable or line 4 - model_name="gpt-5-nano-2025-08-07", - image_io=None, - send_request_stream=False, - max_workers=10, - timeout=1800 - ) - - # model = LocalModelVLMServing_vllm( - # hf_model_name_or_path="Qwen/Qwen2.5-VL-3B-Instruct", - # vllm_tensor_parallel_size=1, - # vllm_temperature=0.7, - # vllm_top_p=0.9, - # vllm_max_tokens=512, + # model = APIVLMServing_openai( + # api_url="https://dashscope.aliyuncs.com/compatible-mode/v1", + # key_name_of_api_key="DF_API_KEY", + # model_name="qwen3-vl-8b-instruct", + # image_io=None, + # send_request_stream=False, + # max_workers=10, + # timeout=1800 # ) + model = LocalModelVLMServing_vllm( + hf_model_name_or_path="Qwen/Qwen2.5-VL-3B-Instruct", + vllm_tensor_parallel_size=1, + vllm_temperature=0.7, + vllm_top_p=0.9, + vllm_max_tokens=512, + ) + generator = PersQAGenerator( llm_serving=model ) storage = FileStorage( - first_entry_file_name="./dataflow/example/image_to_text_pipeline/sample_data.json", + first_entry_file_name="./dataflow/example/test_data/image_data.json", cache_path="./cache_local", file_name_prefix="pers_qa", cache_type="json", diff --git a/dataflow/operators/core_vision/generate/visual_reasoning_generator.py b/dataflow/operators/core_vision/generate/visual_reasoning_generator.py new file mode 100644 index 0000000..5bc3232 --- /dev/null +++ b/dataflow/operators/core_vision/generate/visual_reasoning_generator.py @@ -0,0 +1,185 @@ +import pandas as pd +from typing import Optional, List, Dict, Any + +from dataflow import get_logger +from dataflow.utils.registry import OPERATOR_REGISTRY +from dataflow.utils.storage import FileStorage, DataFlowStorage +from dataflow.core import OperatorABC, LLMServingABC + +from dataflow.serving.local_model_vlm_serving import LocalModelVLMServing_vllm +from dataflow.serving.api_vlm_serving_openai import APIVLMServing_openai +from dataflow.prompts.image import MCTReasoningPrompt + + +# 提取判断是否为 API Serving 的辅助函数 +def is_api_serving(serving): + return isinstance(serving, APIVLMServing_openai) + + +@OPERATOR_REGISTRY.register() +class VisualReasoningGenerator(OperatorABC): + """ + [Generate] 调用 VLM 生成推理链。 + 支持 Fallback:如果 input_existing_chains_key 中已有数据,则直接使用,不进行生成。 + """ + def __init__(self, serving: LLMServingABC, prompt_type: str = "web_grounding"): + self.serving = serving + self.prompt_type = prompt_type + self.prompt_generator = MCTReasoningPrompt() + self.system_prompt = self._get_sys_prompt() + self.logger = get_logger() + + def _get_sys_prompt(self): + prompts = self.prompt_generator.build_prompt() + if self.prompt_type not in prompts: + self.logger.warning(f"Prompt type '{self.prompt_type}' not found. Using fallback system prompt.") + return "You are a helpful assistant capable of deep visual reasoning." + return prompts[self.prompt_type] + + @staticmethod + def get_desc(lang: str = "zh"): + if lang == "zh": + return ( + "视觉推理生成算子 (VisualReasoningGenerator)。\n" + "调用 VLM 生成带 的视觉推理链 (MCT)。\n\n" + "特点:\n" + " - 支持 Fallback 机制,断点续传跳过已生成的行\n" + " - 统一支持 API 和本地 Local 模型部署模式\n" + " - 全局 Batch 处理未命中缓存的数据,保证最高吞吐量\n" + ) + else: + return "Generates visual reasoning chains using VLM with fallback support." + + def run( + self, + storage: DataFlowStorage, + input_question_key: str, + input_image_key: str, + output_key: str, + input_existing_chains_key: Optional[str] = None + ): + if not output_key: + raise ValueError("'output_key' must be provided.") + + self.logger.info("Running VisualReasoningGenerator...") + df: pd.DataFrame = storage.read("dataframe") + + use_api_mode = is_api_serving(self.serving) + if use_api_mode: + self.logger.info("Using API serving mode") + else: + self.logger.info("Using local serving mode") + + # 初始化最终结果列表 (用 None 占位,长度与 DataFrame 相同) + final_results = [None] * len(df) + + # 1. 过滤与展平阶段 (Filter & Flatten Data) + flat_conversations = [] + flat_images = [] + indices_to_generate = [] # 记录真正需要跑大模型的行索引 + + for idx, row in df.iterrows(): + # --- 处理 Fallback (断点续传缓存) --- + existing = row.get(input_existing_chains_key) if input_existing_chains_key else None + if existing and isinstance(existing, list) and len(existing) > 0: + final_results[idx] = existing + continue + + # --- 提取正常数据 --- + q = row.get(input_question_key, "") + img_path = row.get(input_image_key) + + if not isinstance(q, str) or not q.strip(): + final_results[idx] = [] + continue + + # 清洗图片路径 + if isinstance(img_path, str): + img_path = [img_path] + elif not img_path: + img_path = [] + + valid_img_paths = [p for p in img_path if p and isinstance(p, str)] + + # 构造输入 Content + if use_api_mode: + content = q + else: + img_tokens = "" * len(valid_img_paths) + content = f"{img_tokens}\n{q}" if img_tokens else q + + flat_conversations.append([{"role": "user", "content": content}]) + flat_images.append(valid_img_paths) + indices_to_generate.append(idx) + + # 2. 批量推理阶段 (Batch Inference) + if flat_conversations: + self.logger.info(f"Generating reasoning chains for {len(flat_conversations)} samples " + f"({len(df) - len(flat_conversations)} skipped due to Fallback or empty input)...") + + outputs = self.serving.generate_from_input_messages( + conversations=flat_conversations, + image_list=flat_images, + system_prompt=self.system_prompt + ) + + # 3. 数据重组回填阶段 (Reconstruct Data) + for df_idx, out_text in zip(indices_to_generate, outputs): + final_results[df_idx] = [out_text] if out_text else [] + + # 扫尾:把跳过大模型且没有缓存的 None 替换为空列表 + final_results = [res if res is not None else [] for res in final_results] + + # 写入结果 + df[output_key] = final_results + output_file = storage.write(df) + self.logger.info(f"Results saved to {output_file}") + + return [output_key] + + +# ========================================== +# 测试用例 (Main Block) +# ========================================== +if __name__ == "__main__": + + # 使用 API 模式测试 + model = APIVLMServing_openai( + api_url="http://172.96.141.132:3001/v1", + key_name_of_api_key="DF_API_KEY", + model_name="gpt-5-nano-2025-08-07", + image_io=None, + send_request_stream=False, + max_workers=10, + timeout=1800 + ) + + # 如需使用本地模型,请解开注释 + # model = LocalModelVLMServing_vllm( + # hf_model_name_or_path="Qwen/Qwen2.5-VL-3B-Instruct", + # vllm_tensor_parallel_size=1, + # vllm_temperature=0.7, + # vllm_top_p=0.9, + # vllm_max_tokens=512, + # ) + + generator = VisualReasoningGenerator( + serving=model, + prompt_type="web_grounding" + ) + + storage = FileStorage( + first_entry_file_name="./dataflow/example/image_to_text_pipeline/reasoning_sample.jsonl", + cache_path="./cache_reasoning", + file_name_prefix="visual_reasoning", + cache_type="jsonl", + ) + storage.step() + + generator.run( + storage=storage, + input_question_key="question", + input_image_key="image", + output_key="reasoning_chain", + input_existing_chains_key="cached_reasoning" # 测试时可以在 jsonl 里留几行带这个字段的数据看看效果 + ) \ No newline at end of file diff --git a/dataflow/operators/core_vision/refine/visual_dependency_refiner.py b/dataflow/operators/core_vision/refine/visual_dependency_refiner.py index a2e70f8..46bb494 100644 --- a/dataflow/operators/core_vision/refine/visual_dependency_refiner.py +++ b/dataflow/operators/core_vision/refine/visual_dependency_refiner.py @@ -5,19 +5,27 @@ from dataflow import get_logger from dataflow.utils.registry import OPERATOR_REGISTRY -from dataflow.utils.storage import DataFlowStorage +from dataflow.utils.storage import FileStorage, DataFlowStorage from dataflow.core import OperatorABC, LLMServingABC -from qwen_vl_utils import process_vision_info + +from dataflow.serving.local_model_vlm_serving import LocalModelVLMServing_vllm +from dataflow.serving.api_vlm_serving_openai import APIVLMServing_openai + + +# 提取判断是否为 API Serving 的辅助函数 +def is_api_serving(serving): + return isinstance(serving, APIVLMServing_openai) def shuffle_options_logic(qa_item: Dict[str, Any], add_none_option: bool = False) -> Tuple[str, str]: + """混淆选项逻辑 (保持原版业务逻辑不变)""" options = qa_item.get("options", {}) correct_letter = qa_item.get("answer") correct_text = options.get(correct_letter) items = list(options.items()) if not items or not correct_text: - return qa_item["question"], correct_letter + return qa_item.get("question", ""), correct_letter texts = [v for k, v in items] random.shuffle(texts) @@ -25,7 +33,7 @@ def shuffle_options_logic(qa_item: Dict[str, Any], add_none_option: bool = False new_labels = ["A", "B", "C", "D", "E", "F"] new_answer_letter = None - q_lines = [qa_item["question_title"]] + q_lines = [qa_item.get("question_title", "")] current_idx = 0 for i, txt in enumerate(texts): @@ -41,8 +49,12 @@ def shuffle_options_logic(qa_item: Dict[str, Any], add_none_option: bool = False return "\n".join(q_lines), new_answer_letter + def extract_letter_only(model_out: str) -> Optional[str]: - if not model_out: return None + """提取大模型回复中的选项字母""" + if not model_out: + return None + model_out = str(model_out) m = re.search(r"\b([A-Fa-f])\b", model_out) if m: return m.group(1).upper() m2 = re.search(r"(?:answer|option)\s*[::]\s*([A-Fa-f])", model_out, re.I) @@ -72,122 +84,196 @@ def __init__( self.pass_visual_min = pass_visual_min self.pass_textual_max = pass_textual_max self.add_none = add_none_above_visual + self.system_prompt = "You are a helpful assistant." self.logger = get_logger() @staticmethod def get_desc(lang: str = "zh"): - return ( - "视觉依赖性校验算子 (VisualDependencyRefiner)。\n" - "通过多次旋转选项并进行 有图/无图 对比测试,筛选出必须依赖视觉信息才能回答的高质量 MCQ。" - ) if lang == "zh" else "Visual Dependency Refiner: Filters MCQs requiring visual info via rotation checks." + if lang == "zh": + return ( + "视觉依赖性校验算子 (VisualDependencyRefiner)。\n" + "通过多次旋转选项并进行 有图/无图 对比测试,筛选出必须依赖视觉信息才能回答的高质量 MCQ。\n\n" + "特点:\n" + " - 双盲精度测试:$V_{acc}$ 与 $T_{acc}$ 联合校验\n" + " - 全局并行批处理:成百上千倍提升过滤吞吐量\n" + " - 统一 API 与本地模型接口,自动管理多模态 Token\n" + ) + else: + return "Visual Dependency Refiner: Filters MCQs requiring visual info via global rotation checks." def run(self, storage: DataFlowStorage, input_list_key: str, input_image_key: str, output_key: str): + if not output_key: + raise ValueError("'output_key' must be provided.") + self.logger.info(f"Running VisualDependencyRefiner on {input_list_key}...") - df = storage.read("dataframe") + df: pd.DataFrame = storage.read("dataframe") - filtered_results = [] + use_api_mode = is_api_serving(self.serving) - for idx, row in df.iterrows(): + # ========================================================= + # 1. 全局展平阶段 (Global Flattening) + # ========================================================= + vis_conversations, vis_images, vis_mappings = [], [], [] + txt_conversations, txt_mappings = [], [] + + for row_idx, row in df.iterrows(): qa_list = row.get(input_list_key, []) image_path = row.get(input_image_key) - + + # 清洗图片路径 + if isinstance(image_path, str): + image_path = [image_path] + elif not image_path: + image_path = [] + if not qa_list or not isinstance(qa_list, list) or not image_path: - filtered_results.append([]) continue - - kept_qas = [] - - # 遍历该图生成的每一道题 - for qa_item in qa_list: - - # --- 分离 Batch --- - # 我们不再把 VQA 和 QA 混在一起发,而是攒成两个独立的 Batch - visual_prompts = [] - visual_images = [] - visual_answers = [] # 记录对应的正确答案 - - text_prompts = [] - text_answers = [] - - # 准备数据 (rotate_num 次) + + for qa_idx, qa_item in enumerate(qa_list): + # 对每一道题,生成 rotate_num 次 有图 & 无图 变体 for _ in range(self.rotate_num): - # 1. Visual Case - q_v, ans_v = shuffle_options_logic(qa_item, add_none_option=self.add_none) - raw_v = [ - {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": [ - {"type": "image", "image": image_path}, - {"type": "text", "text": self.inst_template.format(q_v)} - ]} - ] - img_inp, _ = process_vision_info(raw_v) - p_v = self.serving.processor.apply_chat_template(raw_v, tokenize=False, add_generation_prompt=True) - - # Qwen 防御性 Patch - if "<|image_pad|>" not in p_v and "" not in p_v: - p_v = "<|vision_start|><|image_pad|><|vision_end|>" + p_v - visual_prompts.append(p_v) - visual_images.append(img_inp) - visual_answers.append(ans_v) + # --- 1. Visual Case (有图分支) --- + q_v, ans_v = shuffle_options_logic(qa_item, add_none_option=self.add_none) + prompt_v = self.inst_template.format(q_v) - # 2. Text-Only Case + if use_api_mode: + content_v = prompt_v + else: + img_tokens = "" * len(image_path) + content_v = f"{img_tokens}\n{prompt_v}" if img_tokens else prompt_v + + vis_conversations.append([{"role": "user", "content": content_v}]) + vis_images.append(image_path) + vis_mappings.append({"row_idx": row_idx, "qa_idx": qa_idx, "expected": ans_v}) + + # --- 2. Text-Only Case (纯文本无图分支) --- q_t, ans_t = shuffle_options_logic(qa_item, add_none_option=False) - # 纯文本请求不需要 System Prompt,或者保持一致 - raw_t = [ - {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": [{"type": "text", "text": self.inst_template.format(q_t)}]} - ] - # 纯文本不用 process_vision_info - p_t = self.serving.processor.apply_chat_template(raw_t, tokenize=False, add_generation_prompt=True) + prompt_t = self.inst_template.format(q_t) - text_prompts.append(p_t) - text_answers.append(ans_t) - - if not visual_prompts: - continue + txt_conversations.append([{"role": "user", "content": prompt_t}]) + txt_mappings.append({"row_idx": row_idx, "qa_idx": qa_idx, "expected": ans_t}) - # --- 分别调用 --- - - # 1. Visual Batch (image_inputs != None) - # 触发 Server 的“多模态模式”分支 - vis_outputs = self.serving.generate_from_input( - user_inputs=visual_prompts, - image_inputs=visual_images - ) - - # 2. Text Batch (image_inputs == None) - # 触发 Server 的“纯文本模式”分支 - txt_outputs = self.serving.generate_from_input( - user_inputs=text_prompts, - image_inputs=None # 显式传 None - ) + # ========================================================= + # 2. 全局双轨推理阶段 (Parallel Batch Inference) + # ========================================================= + vis_outputs = [] + if vis_conversations: + self.logger.info(f"Running VISUAL batch inference for {len(vis_conversations)} items...") + vis_outputs = self.serving.generate_from_input_messages( + conversations=vis_conversations, + image_list=vis_images, + system_prompt=self.system_prompt + ) + + txt_outputs = [] + if txt_conversations: + self.logger.info(f"Running TEXT-ONLY batch inference for {len(txt_conversations)} items...") + txt_outputs = self.serving.generate_from_input_messages( + conversations=txt_conversations, + # 显式传入 None,触发模型纯文本分支 + image_list=None, + system_prompt=self.system_prompt + ) + + # ========================================================= + # 3. 计分与回填过滤阶段 (Scoring & Filtering) + # ========================================================= + # 计分板格式: (row_idx, qa_idx) -> {"v_correct": 0, "t_correct": 0} + qa_stats = {} + + # 统计 Visual 得分 + for mapping, out_text in zip(vis_mappings, vis_outputs): + key = (mapping["row_idx"], mapping["qa_idx"]) + if key not in qa_stats: + qa_stats[key] = {"v_correct": 0, "t_correct": 0} - # --- 统计结果 --- - v_correct = 0 - l_correct = 0 + pred = extract_letter_only(out_text) + if pred == mapping["expected"]: + qa_stats[key]["v_correct"] += 1 + + # 统计 Text 得分 + for mapping, out_text in zip(txt_mappings, txt_outputs): + key = (mapping["row_idx"], mapping["qa_idx"]) + pred = extract_letter_only(out_text) + if pred == mapping["expected"]: + qa_stats[key]["t_correct"] += 1 + + # 最终回填 + filtered_results = [[] for _ in range(len(df))] + + for row_idx, row in df.iterrows(): + qa_list = row.get(input_list_key, []) + if not isinstance(qa_list, list): + continue - for i in range(self.rotate_num): - # 提取 Visual 结果 - pred_v = extract_letter_only(vis_outputs[i]) - if pred_v == visual_answers[i]: - v_correct += 1 - - # 提取 Text 结果 - pred_t = extract_letter_only(txt_outputs[i]) - if pred_t == text_answers[i]: - l_correct += 1 + for qa_idx, qa_item in enumerate(qa_list): + stats = qa_stats.get((row_idx, qa_idx)) + if not stats: + continue - v_acc = v_correct / self.rotate_num - l_acc = l_correct / self.rotate_num + # 计算通过率 + v_acc = stats["v_correct"] / self.rotate_num + t_acc = stats["t_correct"] / self.rotate_num - if v_acc >= self.pass_visual_min and l_acc <= self.pass_textual_max: - qa_item["stats"] = {"v_acc": v_acc, "t_acc": l_acc} - kept_qas.append(qa_item) - - filtered_results.append(kept_qas) - + # 核心筛选逻辑:必须看图才能做对,且盲猜做不对 + if v_acc >= self.pass_visual_min and t_acc <= self.pass_textual_max: + qa_item["stats"] = {"v_acc": v_acc, "t_acc": t_acc} + filtered_results[row_idx].append(qa_item) + df[output_key] = filtered_results - storage.write(df) + output_file = storage.write(df) + self.logger.info(f"Refinement complete. Results saved to {output_file}") + return [output_key] - \ No newline at end of file + + +# ========================================== +# 测试用例 (Main Block) +# ========================================== +if __name__ == "__main__": + + # API 模式 + model = APIVLMServing_openai( + api_url="http://172.96.141.132:3001/v1", + key_name_of_api_key="DF_API_KEY", + model_name="gpt-5-nano-2025-08-07", + image_io=None, + send_request_stream=False, + max_workers=10, + timeout=1800 + ) + + # 本地模型 + # model = LocalModelVLMServing_vllm( + # hf_model_name_or_path="Qwen/Qwen2.5-VL-3B-Instruct", + # vllm_tensor_parallel_size=1, + # vllm_temperature=0.7, + # vllm_top_p=0.9, + # vllm_max_tokens=512, + # ) + + # 模板需与用户的占位符匹配 (这里通过 .format(q_v) 直接传参,所以用 {0}) + refiner = VisualDependencyRefiner( + serving=model, + instruction_template="Please answer the following multiple-choice question.\n{0}", + rotate_num=4, + pass_visual_min=0.75, # 有图准确率需 >= 75% + pass_textual_max=0.30, # 无图瞎猜准确率需 <= 30% + add_none_above_visual=True + ) + + storage = FileStorage( + first_entry_file_name="./dataflow/example/image_to_text_pipeline/dependency_sample.jsonl", + cache_path="./cache_dependency", + file_name_prefix="visual_dependency", + cache_type="jsonl", + ) + storage.step() + + refiner.run( + storage=storage, + input_list_key="mcq_candidates", + input_image_key="image", + output_key="high_quality_mcqs" + ) \ No newline at end of file diff --git a/dataflow/operators/core_vision/refine/visual_grounding_refiner.py b/dataflow/operators/core_vision/refine/visual_grounding_refiner.py index e53f3c2..d97fcf5 100644 --- a/dataflow/operators/core_vision/refine/visual_grounding_refiner.py +++ b/dataflow/operators/core_vision/refine/visual_grounding_refiner.py @@ -1,8 +1,18 @@ +import pandas as pd +from typing import List + from dataflow.utils.registry import OPERATOR_REGISTRY -from dataflow.utils.storage import DataFlowStorage +from dataflow.utils.storage import FileStorage, DataFlowStorage from dataflow.core import OperatorABC, LLMServingABC from dataflow import get_logger -from qwen_vl_utils import process_vision_info + +from dataflow.serving.local_model_vlm_serving import LocalModelVLMServing_vllm +from dataflow.serving.api_vlm_serving_openai import APIVLMServing_openai + + +# 提取判断是否为 API Serving 的辅助函数 +def is_api_serving(serving): + return isinstance(serving, APIVLMServing_openai) @OPERATOR_REGISTRY.register() @@ -32,7 +42,8 @@ def get_desc(lang: str = "zh"): " - output_key: 过滤后保留的文本列表\n" "功能特点:\n" " - 自动构造 'Yes/No' 判别问题\n" - " - 批量并行推理,高效过滤幻觉 (Hallucination)\n" + " - 全局 Batch 展平处理,支持超大规模并发过滤\n" + " - 统一支持 API 和本地 Local 模型部署模式\n" ) else: return ( @@ -46,61 +57,91 @@ def get_desc(lang: str = "zh"): " - output_key: The filtered list of texts\n" "Features:\n" " - Automatically constructs 'Yes/No' verification questions\n" - " - Batch parallel inference for efficient hallucination filtering\n" + " - Global batch flattening for massive concurrent filtering\n" + " - Unifies support for API and Local model deployment modes\n" ) def run(self, storage: DataFlowStorage, input_list_key: str, input_image_key: str, output_key: str): + if not output_key: + raise ValueError("'output_key' must be provided.") + self.logger.info(f"Running VisualGroundingRefiner on {input_list_key}...") - df = storage.read("dataframe") - - refined_results = [] + df: pd.DataFrame = storage.read("dataframe") + use_api_mode = is_api_serving(self.serving) + if use_api_mode: + self.logger.info("Using API serving mode") + else: + self.logger.info("Using local serving mode") + + # --------------------------------------------------------- + # 1. 展平数据阶段 (Flatten Data) + # 将 N 张图片和对应 M 个待验证文本展平为一维请求列表 + # --------------------------------------------------------- + flat_conversations = [] + flat_images = [] + row_mappings = [] # 记录这道 prompt 属于哪一行以及它的原始文本:{"row_idx": int, "item": str} + for idx, row in df.iterrows(): items = row.get(input_list_key, []) image_path = row.get(input_image_key) - if not items or not isinstance(items, list) or not image_path: - refined_results.append([]) + # 清洗图片路径 + if isinstance(image_path, str): + image_path = [image_path] + elif not image_path: + image_path = [] + + if not isinstance(items, list) or not items or not image_path: continue - # 1. 构造 Batch Prompts - batch_prompts = [] - batch_images = [] - + # 为每一张图的每一个待验证文本构造对话 for item in items: - text_prompt = self.template.format(text=item) - raw_prompt = [ - {"role": "system", "content": self.system_prompt}, - { - "role": "user", - "content": [ - {"type": "text", "text": text_prompt}, - {"type": "image", "image": image_path} - ] - } - ] - image_inputs, _ = process_vision_info(raw_prompt) - final_prompt = self.serving.processor.apply_chat_template( - raw_prompt, tokenize=False, add_generation_prompt=True - ) - batch_prompts.append(final_prompt) - batch_images.append(image_inputs) - - if not batch_prompts: - refined_results.append([]) - continue + # 安全校验,防止传入非字符串导致 format 报错 + if not isinstance(item, str): + item = str(item) + + prompt_text = self.template.format(text=item) + + if use_api_mode: + content = prompt_text + else: + img_tokens = "" * len(image_path) + content = f"{img_tokens}\n{prompt_text}" if img_tokens else prompt_text + + flat_conversations.append([{"role": "user", "content": content}]) + flat_images.append(image_path) + row_mappings.append({"row_idx": idx, "item": item}) - # 2. 批量推理 - flags = self.serving.generate_from_input( - system_prompt=self.system_prompt, - user_inputs=batch_prompts, - image_inputs=batch_images + # --------------------------------------------------------- + # 2. 批量推理阶段 (Batch Inference) + # --------------------------------------------------------- + if flat_conversations: + self.logger.info(f"Verifying {len(flat_conversations)} items globally...") + flat_outputs = self.serving.generate_from_input_messages( + conversations=flat_conversations, + image_list=flat_images, + system_prompt=self.system_prompt ) + else: + flat_outputs = [] + + # --------------------------------------------------------- + # 3. 重组解析阶段 (Unflatten & Parse Data) + # --------------------------------------------------------- + # 初始化一个与 df 等长的空列表字典 + refined_results = [[] for _ in range(len(df))] + + for mapping, out_text in zip(row_mappings, flat_outputs): + idx = mapping["row_idx"] + original_item = mapping["item"] - # 3. 过滤逻辑 (保留 'yes') - kept = [item for item, flag in zip(items, flags) if "yes" in flag.lower()] - refined_results.append(kept) + # 过滤逻辑 (模型回复中包含 'yes' 视为验证通过,保留该文本) + if out_text and "yes" in str(out_text).lower(): + refined_results[idx].append(original_item) df[output_key] = refined_results - storage.write(df) + output_file = storage.write(df) + self.logger.info(f"Results saved to {output_file}") + return [output_key] diff --git a/dataflow/prompts/video.py b/dataflow/prompts/video.py index e210857..ac90dc8 100644 --- a/dataflow/prompts/video.py +++ b/dataflow/prompts/video.py @@ -99,352 +99,3 @@ def build_prompt(self, **kwargs) -> str: except Exception as e: # If formatting fails, return the original template return self.prompt_template - - -# ------------------------------------------------------------------------------------------ - - -class MultiroleQAInitialQAGenerationPrompt: - ''' - The prompt for the Master Agent to generate the initial set of QA pairs - based on the full video advertisement synopsis. - ''' - def __init__(self): - pass - - def build_prompt( - self, - processed_v_data: Dict[str, Any] - ) -> str: - """ - Generate the detailed initial QA generation prompt using the processed - video data structure (output of _serialize_v_input). - - Args: - processed_v_data (Dict[str, Any]): 经过处理后的视频数据字典, - 包含 "Meta" (str) 和 "Clips" (List[Dict])。 - - Returns: - str: 格式化后的提示词。 - """ - - clips = processed_v_data.get("Clips", []) - meta_str = processed_v_data.get("Meta", "No metadata provided.") - - scene_nums = len(clips) - - # 提取并格式化 Clips 的详细信息 - clip_details_parts = [] - full_voiceover_parts = [] - - for idx, clip in enumerate(clips): - # 提取文本信息 - audio_text = clip.get('Audio_Text', 'N/A') - description = clip.get('Description', 'N/A') - - # 统计图片数量 (Frames_Images 此时是 Image 对象列表) - frame_count = len(clip.get('Frames_Images', [])) - - clip_details_parts.append(f"\nClip {idx+1}:") - # 按照您的要求,循环输出 Description, Frames, Audio/Text - # {scene_descriptions} 对应 Description - # {scene_frames_image} 对应 Frames: [...] - # {voiceover} 对应 Audio/Text - clip_details_parts.append(f" Description: {description}") - clip_details_parts.append(f" Frames: [Contains {frame_count} Image Frames]") - clip_details_parts.append(f" Audio/Text: {audio_text}") - - # 收集完整的语音内容 - if audio_text and audio_text != 'N/A': - full_voiceover_parts.append(audio_text) - - scene_details_block = "\n".join(clip_details_parts) - full_voiceover = " ".join(full_voiceover_parts).strip() - - # --- 提示词模板 --- - prompt = f""" - You are tasked with generating questions from advertisement descriptions. Use the description alone to craft the questions and avoid making assumptions. The correct answer should blend seamlessly with the wrong ones in terms of length and complexity. Do not use direct quotes, and keep terminology simple. Questions should relate only to the content of the advertisement and avoid any external or behind-the-scenes details. - - This advertisement contains the following {scene_nums} scenes: - - {scene_details_block} - - Voiceover: {full_voiceover} - - The advertisement synopsis is based on the following general metadata: - Full Video Meta: {meta_str} - - Create No more than ten questions based on this synopsis. Use these aspects as a reference when asking questions: - 1. Theme and core message (compulsory) - 2. Conveyance method for Theme, Brand, and Product features (if applicable) - 3. Specific visual elements (object, person, scene, event, etc.) and their relation to the theme (No more than three questions) - 4. Specific detail’s connection to the overall theme (compulsory) - 5. Target audience characteristics (if applicable) - 6. Emotional impact and tactics used - 7. Storyline and narration (if applicable) - 8. Metaphors or humorous techniques (if present) - 9. Logical arguments, factual claims, or expert opinions (if present) - 10. Characters and their relevance to the theme and audience (if present) - 11. Creativity and the overall impression of the ad. (if applicable) - - For each question, provide only one correct answer. The answers must be unique, and unbiased. Print each correct answer exactly as 'Correct answer: [full answer]'. - """ - return prompt.strip() - -class MultiroleQACallExpertAgentsPrompt: - ''' - The prompt for the Master Agent to determine the need for new expert agents, - describe their role, and assign a specific sub-task based on the current - QA pool and video context. - ''' - def __init__(self): - pass - - def build_prompt( - self, - initial_prompt: str, - current_annotation: str, - expert_history: List[Dict[str, str]] - ) -> str: - """ - Generates the prompt for the Master Agent to call for new expert agents. - - Args: - initial_prompt (str): 初始 QA 生成时使用的完整上下文(包括视频信息、Meta等)。 - current_annotation (str): 当前已有的 QA 标注列表(包括初始和已修订的)。 - expert_history (List[Dict[str, str]]): 历史已招募专家的描述和角色。 - e.g., [{"role": "Marketing Expert", "description": "You are a XXX..."}] - - Returns: - str: 格式化后的提示词。 - """ - - # 格式化已招募的专家历史 - history_str = "" - if expert_history: - history_list = [f"Expert Role: {exp['role']}\nDescription: {exp['subtask']}" - for exp in expert_history] - history_str = "\n".join(history_list) - else: - history_str = "None." - - prompt = f""" - {initial_prompt} - {current_annotation} - Now, you can create and collaborate with multiple experts to improve your generated question-answer pairs. Therefore, please describe in as much detail as possible the different skills and focuses you need from multiple experts individually. We will provide each expert with the same information and query. However, please note that each profession has its own specialization, so you can assign each expert to just one sub-task to ensure a more refined response. We will relay their responses to you in turn, allowing you to reorganize them into a better generation. Please note that the description should be narrated in the second person, for example: You are a XXX. - These are the descriptions of the experts you have created before for this task: - {history_str} - Therefore, please remember you should not repeatedly create the same experts as described above. Now, you can give the description for a new expert (Please note that only be one, do not give multiple at one time): - Please follow this exact output format: - [RECRUIT] or [TERMINATE] - If [TERMINATE], stop here and only output "NO_EXPERTS". - If [RECRUIT], provide the expert details in JSON format: - {{ - "Expert_Role": "", - "Expert_Description": "", - "Subtask": "" - }} - """ - return prompt.strip() - -class MultiroleQAProfile4ExpertAgents: - ''' - A collection of detailed profiles for various Expert Agents. - This class is primarily used in Step 3 to initialize the Expert Agent's persona - before assigning a subtask. - ''' - def __init__(self): - # 内部可以存储所有已定义的专家档案,方便查找 - self._profiles = { - "Conservation Psychologist": self._get_conservation_psychologist_profile() - # 可以添加更多专家,如 Marketing Expert, Visual Design Expert 等 - } - - def _get_conservation_psychologist_profile(self) -> str: - """ - 返回保护心理学家的详细描述,用于 Expert Agent 的初始化。 - """ - return ( - "You are a conservation psychologist, specializing in understanding and promoting the psychological and " - "emotional connection people have with nature and wildlife. Your expertise includes analyzing how visual " - "and textual messaging in media can influence individuals’ attitudes and behaviors towards conservation and " - "environmental protection. Your focus lies in interpreting the emotional responses elicited by multimedia " - "content and identifying the aspects of an advertisement that enhance the viewers’ sense of urgency or empathy " - "towards the subject. You provide insights on the psychological impact of specific scenes, colors, narratives, " - "and the use of statistics or facts in fostering a sense of environmental stewardship and activism. Your role " - "is to evaluate the effectiveness of the environmental messages conveyed and suggest ways to strengthen the " - "emotional appeal and call to action within the advertisement." - ) - - def build_prompt( - self, - expert_role: str, - v_context: str, - subtask: str, - master_agent_invitation: str = "Please generate new, professional QA pairs that fulfill the specific subtask assigned to you, based on your expertise." - ) -> str: - """ - Generates the full prompt for a specific Expert Agent (Step 3). - - Args: - expert_role (str): 专家的角色名称,如 "Conservation Psychologist"。 - v_context (str): 序列化后的 V (视频广告多模态序列) 上下文。 - subtask (str): Master Agent 分配给该专家的具体子任务。 - master_agent_invitation (str): Master Agent 的邀请(通用指令)。 - - Returns: - str: 格式化后的提示词。 - """ - - # 1. 根据角色获取配置文件 - expert_profile = self._profiles.get(expert_role, f"You are a skilled {expert_role} expert.") - - # 2. 构建最终的专家 Prompt - prompt = f""" - {expert_profile} - VIDEO ADVERTISEMENT CONTEXT: - {v_context} - Your Specific Subtask: {subtask} - Master Agent Invitation: {master_agent_invitation} - Please generate a high-quality Question-Answer pair that leverage your specialized expertise to address the subtask. Your output must be a valid JSON list of objects, where each object has "Question" and "Answer" keys. - """ - return prompt.strip() - - -class MultiroleQAMasterAgentRevisionPrompt: - ''' - The prompt for the Master Agent to revise the current QA pool using the - newly generated QA from an Expert Agent. - ''' - def __init__(self): - pass - - def build_prompt( - self, - current_qa_list_str: str, - expert_profile: str, - expert_qa_list_str: str - ) -> str: - """ - Generates the prompt for the Master Agent to refine the QA set. - - Args: - current_qa_list_str (str): 当前已有的 QA 标注列表(包括初始和之前修订的)。 - expert_profile (str): 被邀请的专家角色及其描述。 - expert_qa_list_str (str): 专家智能体新生成的 QA 标注。 - - Returns: - str: 格式化后的提示词。 - """ - - prompt = f""" - Current QA Annotations or Initial Annotations are: - {current_qa_list_str} - You invite an expert whose description is: - {expert_profile} - QA Annotations Generated by the last Expert Agent are: - {expert_qa_list_str} - Now you can refine your question-answer pairs with his generation to create more professional and challenging question-answer pairs. Keep in mind that his generation may not be perfect, so critically decide whether to accept some parts of his response or stick with your original one. - CRITICAL REVISION GUIDELINES: - 1. Integration: Combine the Current QA Annotations and the Expert's QA into a single, cohesive, high-quality list. - 2. Quality Check: Eliminate any QA pairs that are redundant, factually incorrect based on the video context (provided in the initial prompt history, which is implicit here), or poorly phrased. - 3. Enhancement: Utilize the expert's insights to improve the complexity, depth, and relevance of the existing questions, especially those related to the expert's specialty. - 4. Format: Your output must be a valid JSON list of dictionaries, representing the complete, revised set of QA pairs. - Revised Question-Answer Pairs (Output only the JSON list): - """ - return prompt.strip() - -class MultiroleQADIYFinalQASynthesisPrompt: - ''' - The prompt for the Master Agent to synthesize the final, definitive QA list - from all iterative annotations. - ''' - def __init__(self): - pass - - def build_prompt( - self, - qa_history_list_str: List[str] - ) -> str: - """ - Generates the prompt for the Master Agent to synthesize the final QA set. - - Args: - qa_history_list_str (List[str]): 多次迭代中所有生成的 QA 列表(以字符串形式)。 - - Returns: - str: 格式化后的提示词。 - """ - - # 将历史 QA 列表连接成一个块,供 LLM 处理 - history_block = "\n" + "\n--- ITERATION BREAK ---\n".join(qa_history_list_str) + "\n" - - prompt = f""" - You are the Master Agent responsible for the final synthesis and quality check of the generated Question-Answer pairs for the video advertisement. You have received several batches of QA pairs from multiple iterations, including initial generation and revisions by various expert agents. - - Your task is to review, consolidate, and finalize these QA pairs based on the following rules: - - 1. De-duplication: Eliminate all exact or semantic duplicate QA pairs. - 2. Quality Selection: Select only the highest quality, most professional, and most challenging QA pairs from the pool(about 10 QA pairs). - 3. Consistency: Ensure the Questions relate only to the content of the advertisement (Context implicit in the history) and that the Answers are concise, unique, and unbiased. - 4. Format Adherence: The final output must strictly be a JSON list of dictionaries, containing only "Question" (Q) and "Answer" (A) keys. - - The complete history of QA pairs generated through all iterations is provided below: - All Iterated QA Pairs are: - {history_block} - Final Selected QA Pairs (Output only the JSON list): - """ - return prompt.strip() - -class MultiroleQAClassificationPrompt: - ''' - The prompt for the Master Agent to classify a single QA pair - into one or two of the five predefined categories. - ''' - def __init__(self): - pass - - def build_prompt( - self, - Q_A: str - ) -> str: - """ - Generates the classification prompt for a single QA pair. - - Args: - question (str): 需要分类的问题。 - answer (str): 对应问题的答案。 - - Returns: - str: 格式化后的提示词。 - """ - - prompt = f""" - Classify each of the question-answer pair in the question-answer pairs into one of the following categories: - - Type 1: The question-answer pair that focuses on the visual concepts, such as video details, characters in videos, a certain object, a certain scene, slogans presented in video, events, plot, and their interaction. - Type 2: The question-answer pair that focuses on the emotional content by ad videos and assesses the potential psychological impact of these emotions. - Type 3: The question-answer pair that focuses on the brand value, goal, theme, underlying message, or central idea** that the ad explores and conveys. - Type 4: The question-answer pair that focuses on persuasion strategies that ad videos convey their core messages. These messages may not be directly articulated but could instead engage viewers through humor and visual rhetoric. (e.g., Any questions about the symbols, metaphors, humor, exaggeration, and any questions that focus on the Logical arguments, factual claims, Statistical charts, or expert opinions. (Any question about presenting factual information and logical arguments to demonstrate the product's benefits and value) - Type 5: The question-answer pairs focus on the engagement, call for action, target audience, the characteristics of the target demographic, and who will be engaged. - - If this question could belong to multiple categories, please choose the most relevant two. - - Question and Answer pairs arse as follows: - {Q_A} - - Your output Labels should be just one or two of VU, ER, TE, PS, AM, and nothing else. (e.g., VU or AM) - Output Format Specification (JSON Array and do not contain other word): - [ - {{ - "Label": "