-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathgemini_node.py
More file actions
112 lines (94 loc) · 4.44 KB
/
gemini_node.py
File metadata and controls
112 lines (94 loc) · 4.44 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import google.generativeai as genai
import requests
from typing import Tuple, Dict, List
import logging
import json
import os
from .config import GEMINI_API_KEY, SILICONFLOW_API_KEY
from .gemini_api import GeminiAPIWrapper
from .siliconflow_api import SiliconflowAPIWrapper
from .models import AVAILABLE_MODELS
from .node import LLMGeneratorNode, NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS
# Logging configuration: INFO level, each record formatted as
# "<timestamp> - <level> - <message>".
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
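# --- Configuration loading (imported from .config) ---
# --- API wrapper layer (imported from .gemini_api / .siliconflow_api) ---
# --- Available models (imported from .models) ---
# --- ComfyUI node ---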
class LLMGeneratorNode:
    """Unified large-language-model generation node supporting several APIs.

    The node exposes a single ``generate_text`` entry point and forwards the
    request to either the Gemini or the SiliconFlow client depending on the
    selected ``api_type``.
    """

    # NOTE(review): the top of this file also imports a name
    # ``LLMGeneratorNode`` from ``.node``; this class definition shadows that
    # import. Confirm which definition is meant to be authoritative.

    def __init__(self):
        # One client per backend, built from the keys imported from .config.
        self.gemini_client = GeminiAPIWrapper(GEMINI_API_KEY)
        self.siliconflow_client = SiliconflowAPIWrapper(SILICONFLOW_API_KEY)

    @classmethod
    def INPUT_TYPES(cls):
        """Return the input specification (required and optional inputs)."""
        # Merge every model into one list so the UI can show a single dropdown.
        all_models = AVAILABLE_MODELS["gemini"] + AVAILABLE_MODELS["siliconflow"]
        return {
            "required": {
                "api_type": (list(AVAILABLE_MODELS.keys()), {"default": "gemini"}),
                "model": (all_models, {"default": AVAILABLE_MODELS["gemini"][0]}),
                "user_prompt": ("STRING", {"multiline": True, "default": "", "placeholder": "用户提示词..."}),
                "system_instruction": ("STRING", {"multiline": True, "default": "", "placeholder": "系统指令..."}),
                "temperature": ("FLOAT", {"default": 0.7, "min": 0.0, "max": 2.0, "step": 0.1}),
                "max_output_tokens": ("INT", {"default": 512, "min": 1, "max": 8192, "step": 1}),
            },
            "optional": {
                "top_p": ("FLOAT", {"default": 0.7, "min": 0.0, "max": 1.0, "step": 0.05}),
                "top_k": ("INT", {"default": 50, "min": 1, "max": 100, "step": 1}),
                "frequency_penalty": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 2.0, "step": 0.1, "tooltip":"[仅SiliconFlow可用]"}),
                "min_p": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip":"[仅SiliconFlow可用]"}),
            },
        }

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("generated_text",)
    FUNCTION = "generate_text"
    CATEGORY = "AI/LLM"

    def generate_text(self, api_type: str, model: str, user_prompt: str, system_instruction: str,
                      temperature: float, max_output_tokens: int, top_p: float = 0.7, top_k: int = 50,
                      frequency_penalty: float = 0.0, min_p: float = 0.0) -> Tuple[str, ...]:
        """Generate text with the selected API and return it as a 1-tuple.

        Args:
            api_type: Backend key; ``"gemini"`` and ``"siliconflow"`` are handled.
            model: Model name; must be listed under ``AVAILABLE_MODELS[api_type]``.
            user_prompt: The user prompt; must contain non-whitespace text.
            system_instruction: System instruction; ``None`` is treated as empty.
            temperature: Sampling temperature forwarded to the client.
            max_output_tokens: Forwarded to the client as ``max_tokens``.
            top_p: Forwarded to the client. Defaults to the value declared in
                ``INPUT_TYPES`` because the input is optional.
            top_k: Forwarded to the client; same defaulting rationale.
            frequency_penalty: Forwarded to the SiliconFlow client only.
            min_p: Forwarded to the SiliconFlow client only.

        Returns:
            A 1-tuple holding the generated text, or an error message string
            when validation fails or the client call raises.
        """
        if not user_prompt or not user_prompt.strip():
            return ("错误:用户提示词不能为空。",)
        if model not in AVAILABLE_MODELS.get(api_type, []):
            return (f"错误:模型 '{model}' 与API类型 '{api_type}' 不兼容。",)

        prompt = user_prompt.strip()
        # Guard against a missing system instruction instead of failing on
        # None.strip().
        instruction = (system_instruction or "").strip()

        logging.info("正在调用 %s API: %s", api_type.upper(), model)
        logging.info("温度: %s, 最大tokens: %s", temperature, max_output_tokens)

        try:
            if api_type == "gemini":
                generated_text = self.gemini_client.generate_content(
                    model=model,
                    prompt=prompt,
                    system_instruction=instruction,
                    temperature=temperature,
                    max_tokens=max_output_tokens,
                    top_p=top_p,
                    top_k=top_k,
                )
            elif api_type == "siliconflow":
                generated_text = self.siliconflow_client.generate_content(
                    model=model,
                    prompt=prompt,
                    system_instruction=instruction,
                    temperature=temperature,
                    max_tokens=max_output_tokens,
                    top_p=top_p,
                    top_k=top_k,
                    frequency_penalty=frequency_penalty,
                    min_p=min_p,
                )
            else:
                generated_text = f"错误:不支持的API类型 '{api_type}'。"
            return (generated_text,)
        except Exception as e:
            # Node boundary: report the failure as the node's output text
            # rather than propagating the exception to the caller.
            error_msg = f"节点执行出错: {e}"
            logging.error(error_msg)
            return (error_msg,)
# --- Node mappings ---
# NOTE(review): both names below are also imported from .node at the top of
# this file; these assignments replace the imported values. Confirm intent.

# Node identifier -> implementing class.
NODE_CLASS_MAPPINGS = dict(LLMGeneratorNode=LLMGeneratorNode)

# Node identifier -> display name.
NODE_DISPLAY_NAME_MAPPINGS = dict(LLMGeneratorNode="LLM Generator")