diff --git a/src/xagent/core/model/chat/basic/__init__.py b/src/xagent/core/model/chat/basic/__init__.py
index d3eef90e1..3f86b3e85 100644
--- a/src/xagent/core/model/chat/basic/__init__.py
+++ b/src/xagent/core/model/chat/basic/__init__.py
@@ -4,6 +4,7 @@
 from .claude import ClaudeLLM
 from .deepseek import DeepSeekLLM
 from .gemini import GeminiLLM
+from .litellm import LiteLLM
 from .openai import OpenAILLM
 from .zhipu import ZhipuLLM
 
@@ -15,5 +16,6 @@
     "ZhipuLLM",
     "GeminiLLM",
     "ClaudeLLM",
+    "LiteLLM",
     "create_base_llm",
 ]
diff --git a/src/xagent/core/model/chat/basic/adapter.py b/src/xagent/core/model/chat/basic/adapter.py
index 35f68e985..618086ac4 100644
--- a/src/xagent/core/model/chat/basic/adapter.py
+++ b/src/xagent/core/model/chat/basic/adapter.py
@@ -9,6 +9,7 @@
 from .claude import ClaudeLLM
 from .deepseek import DeepSeekLLM
 from .gemini import GeminiLLM
+from .litellm import LiteLLM
 from .openai import OpenAILLM
 from .xinference import XinferenceLLM
 from .zhipu import ZhipuLLM
@@ -86,6 +87,16 @@ def create_base_llm(model: ModelConfig) -> BaseLLM:
             timeout=model.timeout,
             abilities=model.abilities,
         )
+    elif provider == "litellm":
+        llm = LiteLLM(
+            model_name=model.model_name,
+            api_key=model.api_key,
+            api_base=model.base_url,
+            default_temperature=model.default_temperature,
+            default_max_tokens=model.default_max_tokens,
+            timeout=model.timeout,
+            abilities=model.abilities,
+        )
     elif provider == "xinference":
         llm = XinferenceLLM(
             model_name=model.model_name,
diff --git a/src/xagent/core/model/chat/basic/litellm.py b/src/xagent/core/model/chat/basic/litellm.py
new file mode 100644
index 000000000..dd9a7e0ef
--- /dev/null
+++ b/src/xagent/core/model/chat/basic/litellm.py
@@ -0,0 +1,295 @@
+import logging
+from contextlib import contextmanager
+from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union
+
+from ..exceptions import LLMRetryableError, LLMTimeoutError
+from ..timeout_config import TimeoutConfig
+from ..token_context import add_token_usage
+from ..types import ChunkType, StreamChunk
+from .base import BaseLLM
+
+logger = logging.getLogger(__name__)
+
+
+@contextmanager
+def _translate_litellm_errors() -> Iterator[None]:
+    """Re-raise LiteLLM transport failures as this package's LLM exceptions.
+
+    Timeouts become ``LLMTimeoutError``; rate-limit, connection,
+    service-unavailable and internal-server errors become
+    ``LLMRetryableError``. Any other exception propagates unchanged.
+    """
+    # Imported at call time, like in chat/stream_chat, not at module import.
+    import litellm
+
+    try:
+        yield
+    except litellm.Timeout as e:
+        raise LLMTimeoutError(str(e)) from e
+    except (
+        litellm.RateLimitError,
+        litellm.APIConnectionError,
+        litellm.ServiceUnavailableError,
+        litellm.InternalServerError,
+    ) as e:
+        raise LLMRetryableError(str(e)) from e
+
+
+class LiteLLM(BaseLLM):
+    """
+    LiteLLM client providing access to 100+ LLM providers through a unified interface.
+    Uses provider-prefixed model names (e.g. openai/gpt-4o, anthropic/claude-sonnet-4-6).
+    """
+
+    def __init__(
+        self,
+        model_name: str = "openai/gpt-4o-mini",
+        api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
+        default_temperature: Optional[float] = None,
+        default_max_tokens: Optional[int] = None,
+        timeout: float = 180.0,
+        abilities: Optional[List[str]] = None,
+        timeout_config: Optional[TimeoutConfig] = None,
+    ):
+        """Store the model name, credentials and per-request defaults.
+
+        Args:
+            model_name: Model identifier passed to LiteLLM as ``model``.
+            api_key: Sent as ``api_key`` with each request when set.
+            api_base: Sent as ``api_base`` with each request when set.
+            default_temperature: Used when a call passes no ``temperature``.
+            default_max_tokens: Used when a call passes no ``max_tokens``.
+            timeout: Sent as ``timeout`` with each request.
+            abilities: Advertised abilities; a missing or empty list falls
+                back to ``["chat", "tool_calling"]``.
+            timeout_config: Stored on the instance; a default
+                ``TimeoutConfig`` is created when omitted.
+        """
+        self._model_name = model_name
+        self._api_key = api_key
+        self._api_base = api_base
+        self.default_temperature = default_temperature
+        self.default_max_tokens = default_max_tokens
+        self.timeout = timeout
+        self.timeout_config = timeout_config or TimeoutConfig()
+        self._abilities = abilities or ["chat", "tool_calling"]
+
+    @property
+    def model_name(self) -> str:
+        return self._model_name
+
+    @property
+    def abilities(self) -> List[str]:
+        return self._abilities
+
+    @property
+    def supports_thinking_mode(self) -> bool:
+        return False
+
+    def _build_completion_params(
+        self,
+        messages: List[Dict[str, str]],
+        temperature: Optional[float],
+        max_tokens: Optional[int],
+        tools: Optional[List[Dict[str, Any]]],
+        tool_choice: Optional[Union[str, Dict[str, Any]]],
+        response_format: Optional[Dict[str, Any]],
+        extra: Dict[str, Any],
+    ) -> Dict[str, Any]:
+        """Build the keyword arguments for ``litellm.acompletion``.
+
+        ``extra`` is merged after the base keys and before the explicit
+        arguments, so it can override the former but not the latter.
+        """
+        params: Dict[str, Any] = {
+            "model": self._model_name,
+            "messages": self._sanitize_unicode_content(messages),
+            "drop_params": True,
+            "timeout": self.timeout,
+            **extra,
+        }
+
+        # Per-call values win over the instance defaults; None means unset.
+        if max_tokens is None:
+            max_tokens = self.default_max_tokens
+        if max_tokens is not None:
+            params["max_tokens"] = max_tokens
+
+        if temperature is None:
+            temperature = self.default_temperature
+        if temperature is not None:
+            params["temperature"] = temperature
+
+        if tools:
+            params["tools"] = tools
+        if tool_choice:
+            params["tool_choice"] = tool_choice
+        if response_format:
+            params["response_format"] = response_format
+
+        if self._api_key:
+            params["api_key"] = self._api_key
+        if self._api_base:
+            params["api_base"] = self._api_base
+
+        return params
+
+    @staticmethod
+    def _tool_call_fragment(tc: Any) -> Dict[str, Any]:
+        """Convert one streamed tool-call delta into a plain dict.
+
+        Missing attributes are tolerated: ``index`` falls back to 0, ``id``
+        and the function name to ``None``, and the arguments to ``""``.
+        """
+        index = getattr(tc, "index", None)
+        function = getattr(tc, "function", None)
+        return {
+            "index": index if index is not None else 0,
+            "id": getattr(tc, "id", None),
+            "type": "function",
+            "function": {
+                "name": getattr(function, "name", None),
+                # Always a string: a ``None`` fragment becomes "".
+                "arguments": getattr(function, "arguments", None) or "",
+            },
+        }
+
+    async def chat(
+        self,
+        messages: List[Dict[str, str]],
+        temperature: Optional[float] = None,
+        max_tokens: Optional[int] = None,
+        tools: Optional[List[Dict[str, Any]]] = None,
+        tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
+        response_format: Optional[Dict[str, Any]] = None,
+        thinking: Optional[Dict[str, Any]] = None,
+        output_config: Optional[Dict[str, Any]] = None,
+        **kwargs: Any,
+    ) -> Any:
+        """Perform a chat completion via LiteLLM.
+
+        ``thinking`` and ``output_config`` are accepted but not forwarded to
+        LiteLLM; any other keyword argument is passed through.
+
+        Returns:
+            The message text (``""`` when the content is empty), or a dict
+            with ``type="tool_call"``, ``tool_calls`` and ``raw`` when the
+            model requested tool calls.
+
+        Raises:
+            LLMTimeoutError: LiteLLM reported a timeout.
+            LLMRetryableError: LiteLLM reported a rate-limit, connection,
+                service-unavailable or internal-server error, or the
+                response contained no choices.
+        """
+        import litellm
+
+        completion_params = self._build_completion_params(
+            messages,
+            temperature,
+            max_tokens,
+            tools,
+            tool_choice,
+            response_format,
+            kwargs,
+        )
+
+        with _translate_litellm_errors():
+            response = await litellm.acompletion(**completion_params)
+
+        if not response.choices:
+            raise LLMRetryableError("LiteLLM returned an empty response (no choices).")
+        message = response.choices[0].message
+
+        usage = getattr(response, "usage", None)
+        if usage:
+            add_token_usage(
+                input_tokens=getattr(usage, "prompt_tokens", 0) or 0,
+                output_tokens=getattr(usage, "completion_tokens", 0) or 0,
+            )
+
+        tool_calls = getattr(message, "tool_calls", None)
+        if tool_calls:
+            return {
+                "type": "tool_call",
+                "tool_calls": [
+                    {
+                        "id": tc.id,
+                        "type": "function",
+                        "function": {
+                            "name": tc.function.name,
+                            "arguments": tc.function.arguments,
+                        },
+                    }
+                    for tc in tool_calls
+                ],
+                "raw": response.model_dump()
+                if hasattr(response, "model_dump")
+                else str(response),
+            }
+
+        return message.content or ""
+
+    async def stream_chat(
+        self,
+        messages: List[Dict[str, str]],
+        temperature: Optional[float] = None,
+        max_tokens: Optional[int] = None,
+        tools: Optional[List[Dict[str, Any]]] = None,
+        tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
+        response_format: Optional[Dict[str, Any]] = None,
+        **kwargs: Any,
+    ) -> AsyncIterator[StreamChunk]:
+        """Stream a chat completion via LiteLLM.
+
+        Yields:
+            A ``TOKEN`` chunk for each non-empty content delta and a
+            ``TOOL_CALL`` chunk for each delta carrying tool-call fragments.
+            Chunks without choices are skipped.
+
+        Raises:
+            LLMTimeoutError: LiteLLM reported a timeout, when opening the
+                stream or while it was being consumed.
+            LLMRetryableError: LiteLLM reported a rate-limit, connection,
+                service-unavailable or internal-server error, when opening
+                the stream or while it was being consumed.
+        """
+        import litellm
+
+        completion_params = self._build_completion_params(
+            messages,
+            temperature,
+            max_tokens,
+            tools,
+            tool_choice,
+            response_format,
+            # ``stream`` comes first so caller kwargs keep their precedence.
+            {"stream": True, **kwargs},
+        )
+
+        # The guard covers iteration as well as the initial call, so errors
+        # raised while the stream is consumed are translated too.
+        with _translate_litellm_errors():
+            response = await litellm.acompletion(**completion_params)
+
+            async for chunk in response:
+                if not chunk.choices:
+                    continue
+                delta = chunk.choices[0].delta
+
+                content = getattr(delta, "content", None)
+                if content:
+                    yield StreamChunk(
+                        type=ChunkType.TOKEN, content=content, delta=content
+                    )
+
+                fragments = getattr(delta, "tool_calls", None)
+                if fragments:
+                    yield StreamChunk(
+                        type=ChunkType.TOOL_CALL,
+                        tool_calls=[self._tool_call_fragment(tc) for tc in fragments],
+                        raw=chunk.model_dump()
+                        if hasattr(chunk, "model_dump")
+                        else str(chunk),
+                    )
diff --git a/tests/core/model/chat/basic/test_litellm.py b/tests/core/model/chat/basic/test_litellm.py
new file mode 100644
index 000000000..af69e1065
--- /dev/null
+++ b/tests/core/model/chat/basic/test_litellm.py
@@ -0,0 +1,249 @@
+"""Test cases for LiteLLM chat model implementation."""
+
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import pytest
+
+from xagent.core.model.chat.basic.litellm import LiteLLM
+from xagent.core.model.chat.exceptions import LLMRetryableError, LLMTimeoutError
+
+
+def _mock_response(content="Hello", prompt_tokens=10, completion_tokens=5):
+    usage = MagicMock()
+    usage.prompt_tokens = prompt_tokens
+    usage.completion_tokens = completion_tokens
+    choice = MagicMock()
+    choice.message.content = content
+    choice.message.tool_calls = None
+    resp = MagicMock()
+    resp.choices = [choice]
+    resp.usage = usage
+    return resp
+
+
+def _mock_tool_response(name="get_weather", arguments='{"city": "Paris"}'):
+    tc = MagicMock()
+    tc.id = "call_123"
+    tc.function.name = name
+    tc.function.arguments = arguments
+    choice = MagicMock()
+    choice.message.content = None
+    choice.message.tool_calls = [tc]
+    resp = MagicMock()
+    resp.choices = [choice]
+    resp.usage = MagicMock(prompt_tokens=20, completion_tokens=10)
+    return resp
+
+
+class TestLiteLLMInit:
+    def test_default_model(self):
+        llm = LiteLLM()
+        assert llm.model_name == "openai/gpt-4o-mini"
+
+    def test_custom_model(self):
+        llm = LiteLLM(model_name="anthropic/claude-sonnet-4-6")
+        assert llm.model_name == "anthropic/claude-sonnet-4-6"
+
+    def test_abilities_default(self):
+        llm = LiteLLM()
+        assert "chat" in llm.abilities
+        assert "tool_calling" in llm.abilities
+
+    def test_abilities_custom(self):
+        llm = LiteLLM(abilities=["chat", "vision"])
+        assert llm.abilities == ["chat", "vision"]
+
+    def test_api_key_stored(self):
+        llm = LiteLLM(api_key="sk-test")
+        assert llm._api_key == "sk-test"
+
+    def test_api_base_stored(self):
+        llm = LiteLLM(api_base="http://localhost:4000")
+        assert llm._api_base == "http://localhost:4000"
+
+    def test_supports_thinking_mode_false(self):
+        llm = LiteLLM()
+        assert llm.supports_thinking_mode is False
+
+
+class TestLiteLLMChat:
+    @pytest.mark.asyncio
+    async def test_basic_chat(self):
+        llm = LiteLLM(model_name="openai/gpt-4o")
+        resp = _mock_response("The answer is 4.")
+        with patch(
+            "litellm.acompletion", new_callable=AsyncMock, return_value=resp
+        ) as mock:
+            result = await llm.chat([{"role": "user", "content": "What is 2+2?"}])
+            assert result == "The answer is 4."
+            call_kwargs = mock.call_args.kwargs
+            assert call_kwargs["model"] == "openai/gpt-4o"
+            assert call_kwargs["drop_params"] is True
+
+    @pytest.mark.asyncio
+    async def test_api_key_forwarded(self):
+        llm = LiteLLM(api_key="sk-test")
+        with patch(
+            "litellm.acompletion", new_callable=AsyncMock, return_value=_mock_response()
+        ) as mock:
+            await llm.chat([{"role": "user", "content": "test"}])
+            assert mock.call_args.kwargs["api_key"] == "sk-test"
+
+    @pytest.mark.asyncio
+    async def test_api_key_omitted_when_none(self):
+        llm = LiteLLM()
+        with patch(
+            "litellm.acompletion", new_callable=AsyncMock, return_value=_mock_response()
+        ) as mock:
+            await llm.chat([{"role": "user", "content": "test"}])
+            assert "api_key" not in mock.call_args.kwargs
+
+    @pytest.mark.asyncio
+    async def test_api_base_forwarded(self):
+        llm = LiteLLM(api_base="http://proxy:4000")
+        with patch(
+            "litellm.acompletion", new_callable=AsyncMock, return_value=_mock_response()
+        ) as mock:
+            await llm.chat([{"role": "user", "content": "test"}])
+            assert mock.call_args.kwargs["api_base"] == "http://proxy:4000"
+
+    @pytest.mark.asyncio
+    async def test_temperature_forwarded(self):
+        llm = LiteLLM()
+        with patch(
+            "litellm.acompletion", new_callable=AsyncMock, return_value=_mock_response()
+        ) as mock:
+            await llm.chat([{"role": "user", "content": "test"}], temperature=0.5)
+            assert mock.call_args.kwargs["temperature"] == 0.5
+
+    @pytest.mark.asyncio
+    async def test_default_temperature_used(self):
+        llm = LiteLLM(default_temperature=0.3)
+        with patch(
+            "litellm.acompletion", new_callable=AsyncMock, return_value=_mock_response()
+        ) as mock:
+            await llm.chat([{"role": "user", "content": "test"}])
+            assert mock.call_args.kwargs["temperature"] == 0.3
+
+    @pytest.mark.asyncio
+    async def test_null_content_returns_empty(self):
+        resp = _mock_response(content=None)
+        llm = LiteLLM()
+        with patch("litellm.acompletion", new_callable=AsyncMock, return_value=resp):
+            result = await llm.chat([{"role": "user", "content": "test"}])
+            assert result == ""
+
+
+class TestLiteLLMToolCalling:
+    @pytest.mark.asyncio
+    async def test_tool_call_returned(self):
+        llm = LiteLLM()
+        with patch(
+            "litellm.acompletion",
+            new_callable=AsyncMock,
+            return_value=_mock_tool_response(),
+        ):
+            result = await llm.chat(
+                [{"role": "user", "content": "Weather?"}],
+                tools=[{"type": "function", "function": {"name": "get_weather"}}],
+            )
+            assert result["type"] == "tool_call"
+            assert result["tool_calls"][0]["function"]["name"] == "get_weather"
+
+
+class TestLiteLLMErrors:
+    @pytest.mark.asyncio
+    async def test_timeout_raises_llm_timeout_error(self):
+        import litellm as _litellm
+
+        llm = LiteLLM()
+        with patch(
+            "litellm.acompletion",
+            new_callable=AsyncMock,
+            side_effect=_litellm.Timeout(
+                message="Request timed out", model="gpt-4o", llm_provider="openai"
+            ),
+        ):
+            with pytest.raises(LLMTimeoutError):
+                await llm.chat([{"role": "user", "content": "test"}])
+
+    @pytest.mark.asyncio
+    async def test_rate_limit_raises_retryable_error(self):
+        import litellm as _litellm
+
+        llm = LiteLLM()
+        with patch(
+            "litellm.acompletion",
+            new_callable=AsyncMock,
+            side_effect=_litellm.RateLimitError(
+                message="429", llm_provider="openai", model="gpt-4o"
+            ),
+        ):
+            with pytest.raises(LLMRetryableError):
+                await llm.chat([{"role": "user", "content": "test"}])
+
+    @pytest.mark.asyncio
+    async def test_connection_error_raises_retryable_error(self):
+        import litellm as _litellm
+
+        llm = LiteLLM()
+        with patch(
+            "litellm.acompletion",
+            new_callable=AsyncMock,
+            side_effect=_litellm.APIConnectionError(
+                message="Connection failed", llm_provider="openai", model="gpt-4o"
+            ),
+        ):
+            with pytest.raises(LLMRetryableError):
+                await llm.chat([{"role": "user", "content": "test"}])
+
+    @pytest.mark.asyncio
+    async def test_auth_error_propagates(self):
+        import litellm as _litellm
+
+        llm = LiteLLM()
+        with patch(
+            "litellm.acompletion",
+            new_callable=AsyncMock,
+            side_effect=_litellm.AuthenticationError(
+                message="Invalid key", llm_provider="openai", model="gpt-4o"
+            ),
+        ):
+            with pytest.raises(_litellm.AuthenticationError):
+                await llm.chat([{"role": "user", "content": "test"}])
+
+
+class TestLiteLLMStreaming:
+    @pytest.mark.asyncio
+    async def test_mid_stream_timeout_raises_llm_timeout_error(self):
+        import litellm as _litellm
+
+        async def _failing_stream():
+            raise _litellm.Timeout(
+                message="Request timed out", model="gpt-4o", llm_provider="openai"
+            )
+            yield  # unreachable; makes this function an async generator
+
+        llm = LiteLLM()
+        with patch(
+            "litellm.acompletion",
+            new_callable=AsyncMock,
+            return_value=_failing_stream(),
+        ):
+            with pytest.raises(LLMTimeoutError):
+                async for _ in llm.stream_chat([{"role": "user", "content": "test"}]):
+                    pass
+
+
+class TestLiteLLMFactory:
+    def test_adapter_creates_litellm(self):
+        from xagent.core.model import ChatModelConfig
+        from xagent.core.model.chat.basic.adapter import create_base_llm
+
+        config = ChatModelConfig(
+            id="test-litellm",
+            model_name="anthropic/claude-sonnet-4-6",
+            model_provider="litellm",
+        )
+        llm = create_base_llm(config)
+        assert llm is not None