-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathllm_node.py
More file actions
87 lines (81 loc) · 3.91 KB
/
llm_node.py
File metadata and controls
87 lines (81 loc) · 3.91 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
from typing import Tuple
import logging
from .config import GEMINI_API_KEY, SILICONFLOW_API_KEY
from .gemini_api import GeminiAPIWrapper
from .siliconflow_api import SiliconflowAPIWrapper
from .models import AVAILABLE_MODELS
class LLMGeneratorNode:
    """
    Unified large-language-model text-generation node with multiple API backends.

    Each call is routed by ``api_type`` to either the Gemini or the SiliconFlow
    API wrapper. Validation failures and exceptions raised during generation
    are reported as the node's text output instead of being raised.
    """

    def __init__(self):
        # Both backend clients are created up front, whichever backend a given
        # call ends up selecting.
        self.gemini_client = GeminiAPIWrapper(GEMINI_API_KEY)
        self.siliconflow_client = SiliconflowAPIWrapper(SILICONFLOW_API_KEY)

    @classmethod
    def INPUT_TYPES(cls):
        """
        Describe the node's input widgets.

        Returns:
            A dict with "required" and "optional" sections mapping each input
            name to a ``(type_or_choices, options)`` tuple.
        """
        # Offer the models of every provider listed in AVAILABLE_MODELS (the
        # same keys that populate the api_type choices), de-duplicated while
        # keeping first-seen order.
        all_models = list(dict.fromkeys(
            name for models in AVAILABLE_MODELS.values() for name in models
        ))
        return {
            "required": {
                "api_type": (list(AVAILABLE_MODELS.keys()), {"default": "gemini"}),
                "model": (all_models, {"default": AVAILABLE_MODELS["gemini"][0]}),
                "user_prompt": ("STRING", {"multiline": True, "default": "", "placeholder": "用户提示词..."}),
                "system_instruction": ("STRING", {"multiline": True, "default": "", "placeholder": "系统指令..."}),
                "temperature": ("FLOAT", {"default": 0.7, "min": 0.0, "max": 2.0, "step": 0.1}),
                "max_output_tokens": ("INT", {"default": 512, "min": 1, "max": 8192, "step": 1}),
            },
            "optional": {
                "top_p": ("FLOAT", {"default": 0.7, "min": 0.0, "max": 1.0, "step": 0.05}),
                "top_k": ("INT", {"default": 50, "min": 1, "max": 100, "step": 1}),
                "frequency_penalty": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 2.0, "step": 0.1, "tooltip": "[仅SiliconFlow可用]"}),
                "min_p": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "[仅SiliconFlow可用]"}),
            },
        }

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("generated_text",)
    FUNCTION = "generate_text"
    CATEGORY = "AI/LLM"

    def generate_text(self, api_type: str, model: str, user_prompt: str, system_instruction: str,
                      temperature: float, max_output_tokens: int, top_p: float = 0.7,
                      top_k: int = 50, frequency_penalty: float = 0.0,
                      min_p: float = 0.0) -> Tuple[str, ...]:
        """
        Generate text with the selected backend and model.

        The optional sampling parameters default to the same values as their
        widgets in ``INPUT_TYPES``. Those inputs are declared optional, so a
        caller may omit them.

        Args:
            api_type: Backend key; "gemini" and "siliconflow" are dispatched.
            model: Model name; must appear in ``AVAILABLE_MODELS[api_type]``.
            user_prompt: User prompt; surrounding whitespace is stripped.
            system_instruction: System instruction; stripped, may be empty/None.
            temperature: Sampling temperature.
            max_output_tokens: Token limit, forwarded as ``max_tokens``.
            top_p: Nucleus-sampling parameter, forwarded to both backends.
            top_k: Top-k sampling parameter, forwarded to both backends.
            frequency_penalty: Forwarded to the SiliconFlow backend only.
            min_p: Forwarded to the SiliconFlow backend only.

        Returns:
            A one-element tuple containing the generated text, or a
            human-readable error message when validation or the API call fails.
        """
        if not user_prompt or not user_prompt.strip():
            return ("错误:用户提示词不能为空。",)
        if model not in AVAILABLE_MODELS.get(api_type, []):
            return (f"错误:模型 '{model}' 与API类型 '{api_type}' 不兼容。",)

        logging.info(f"正在调用 {api_type.upper()} API: {model}")
        logging.info(f"温度: {temperature}, 最大tokens: {max_output_tokens}")

        try:
            # Keyword arguments accepted by every backend wrapper.
            common_kwargs = dict(
                model=model,
                prompt=user_prompt.strip(),
                system_instruction=(system_instruction or "").strip(),
                temperature=temperature,
                max_tokens=max_output_tokens,
                top_p=top_p,
                top_k=top_k,
            )
            if api_type == "gemini":
                generated_text = self.gemini_client.generate_content(**common_kwargs)
            elif api_type == "siliconflow":
                # SiliconFlow additionally accepts frequency_penalty and min_p.
                generated_text = self.siliconflow_client.generate_content(
                    **common_kwargs,
                    frequency_penalty=frequency_penalty,
                    min_p=min_p,
                )
            else:
                generated_text = f"错误:不支持的API类型 '{api_type}'。"
            return (generated_text,)
        except Exception as e:
            error_msg = f"节点执行出错: {str(e)}"
            # logging.exception also records the traceback, which the message alone loses.
            logging.exception(error_msg)
            return (error_msg,)
# Node registration tables keyed by the node identifier "LLMGeneratorNode":
# one maps to the implementing class, the other to its display name.
# (Presumably read by the ComfyUI custom-node loader; confirm against the
# package's __init__.)
NODE_CLASS_MAPPINGS = dict(LLMGeneratorNode=LLMGeneratorNode)
NODE_DISPLAY_NAME_MAPPINGS = dict(LLMGeneratorNode="LLM Generator")