diff --git a/dingo/model/llm/base_litellm.py b/dingo/model/llm/base_litellm.py new file mode 100644 index 00000000..258cfec7 --- /dev/null +++ b/dingo/model/llm/base_litellm.py @@ -0,0 +1,75 @@ +from typing import List + +from dingo.config.input_args import EvaluatorLLMArgs +from dingo.model.llm.base_openai import BaseOpenAI +from dingo.utils.exception import ExceedMaxTokens + + +class BaseLiteLLM(BaseOpenAI): + """Base class for LLM evaluators that route through LiteLLM. + + Provides access to 100+ providers (Anthropic, Gemini, Bedrock, Cohere, + Mistral, Groq, etc.) via a single unified interface. Inherit from this + class instead of BaseOpenAI when you want provider flexibility. + + Model string examples: + - "anthropic/claude-3-5-sonnet-20241022" + - "gemini/gemini-1.5-pro" + - "bedrock/anthropic.claude-3-sonnet-20240229-v1:0" + - "groq/llama3-8b-8192" + - "gpt-4o" (defaults to OpenAI, same as BaseOpenAI) + + Configuration (via EvaluatorLLMArgs): + - model: required, provider-prefixed model string + - key: optional, API key (overrides provider env var) + - api_url: optional, custom base URL (e.g. LiteLLM proxy URL) + - Any extra field is forwarded to litellm.completion() as a kwarg. + + Requires: pip install "dingo-python[litellm]" + """ + + dynamic_config: EvaluatorLLMArgs = EvaluatorLLMArgs() + + @classmethod + def create_client(cls): + if not cls.dynamic_config.model: + raise ValueError("model cannot be empty in llm config.") + try: + import litellm # noqa: F401 + except ImportError as exc: + raise ImportError( + "litellm is not installed. Run: pip install 'dingo-python[litellm]'" + ) from exc + # Use cls.client as an initialisation sentinel (no real client object needed). + cls.client = True + + @classmethod + def send_messages(cls, messages: List) -> str: + import litellm + + model_name = cls.dynamic_config.model or "" + extra_params = cls.dynamic_config.model_extra or {} + cls.validate_config(extra_params) + + call_kwargs: dict = { + "drop_params": True, + **extra_params, + } + if cls.dynamic_config.api_url: + call_kwargs["api_base"] = cls.dynamic_config.api_url + if cls.dynamic_config.key: + call_kwargs["api_key"] = cls.dynamic_config.key + + response = litellm.completion( + model=model_name, + messages=messages, + **call_kwargs, + ) + + finish_reason = response.choices[0].finish_reason # type: ignore[union-attr] + if finish_reason == "length": + raise ExceedMaxTokens( + f"Exceed max tokens: {extra_params.get('max_tokens', 4000)}" + ) + + return str(response.choices[0].message.content) # type: ignore[union-attr] diff --git a/setup.py b/setup.py index 357285fc..a813ba7e 100644 --- a/setup.py +++ b/setup.py @@ -13,11 +13,13 @@ def _read_requirements(path): agent_requirements = _read_requirements("./requirements/agent.txt") hhem_requirements = _read_requirements("./requirements/hhem_integration.txt") +litellm_requirements = ["litellm>=1.80.0,<1.87.0"] extras_require = { 'agent': agent_requirements, 'hhem': hhem_requirements, - 'all': hhem_requirements + agent_requirements, + 'litellm': litellm_requirements, + 'all': hhem_requirements + agent_requirements + litellm_requirements, } diff --git a/test/scripts/model/llm/test_litellm.py b/test/scripts/model/llm/test_litellm.py new file mode 100644 index 00000000..45bf1f96 --- /dev/null +++ b/test/scripts/model/llm/test_litellm.py @@ -0,0 +1,134 @@ +"""Unit tests for BaseLiteLLM provider.""" +import sys +from types import SimpleNamespace +from unittest import mock + +import pytest + +from dingo.config.input_args import EvaluatorLLMArgs +from dingo.model.llm.base_litellm import BaseLiteLLM +from dingo.utils.exception import ExceedMaxTokens + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _make_provider(**cfg_kwargs) -> type: + """Return a fresh BaseLiteLLM subclass with an isolated dynamic_config.""" + + class _Provider(BaseLiteLLM): + prompt = "Evaluate: " + dynamic_config = EvaluatorLLMArgs(**cfg_kwargs) + + return _Provider + + +def _stub_response(content='{"score": 1, "reason": "ok"}', finish_reason="stop"): + choice = SimpleNamespace( + finish_reason=finish_reason, + message=SimpleNamespace(content=content), + ) + return SimpleNamespace(choices=[choice]) + + +# --------------------------------------------------------------------------- +# create_client +# --------------------------------------------------------------------------- + +class TestCreateClient: + def test_raises_without_model(self): + P = _make_provider() + with pytest.raises(ValueError, match="model cannot be empty"): + P.create_client() + + def test_sets_sentinel_on_success(self): + P = _make_provider(model="anthropic/claude-haiku-4-5") + P.create_client() + assert P.client is True + + def test_raises_import_error_when_litellm_missing(self, monkeypatch): + P = _make_provider(model="anthropic/claude-haiku-4-5") + monkeypatch.setitem(sys.modules, "litellm", None) # type: ignore[arg-type] + with pytest.raises((ImportError, TypeError)): + P.create_client() + + +# --------------------------------------------------------------------------- +# send_messages +# --------------------------------------------------------------------------- + +class TestSendMessages: + def test_dispatches_to_litellm(self): + P = _make_provider(model="anthropic/claude-haiku-4-5") + msgs = [{"role": "user", "content": "hello"}] + with mock.patch("litellm.completion", return_value=_stub_response()) as m: + P.send_messages(msgs) + m.assert_called_once() + assert m.call_args.kwargs["model"] == "anthropic/claude-haiku-4-5" + assert m.call_args.kwargs["messages"] == msgs + + def test_drop_params_always_true(self): + P = _make_provider(model="gpt-4o") + with mock.patch("litellm.completion", return_value=_stub_response()) as m: + P.send_messages([{"role": "user", "content": "hi"}]) + assert m.call_args.kwargs.get("drop_params") is True + + def test_api_key_forwarded(self): + P = _make_provider(model="gpt-4o", key="sk-test-123") + with mock.patch("litellm.completion", return_value=_stub_response()) as m: + P.send_messages([{"role": "user", "content": "hi"}]) + assert m.call_args.kwargs.get("api_key") == "sk-test-123" + + def test_api_base_forwarded_when_url_set(self): + P = _make_provider(model="gpt-4o", api_url="https://my-proxy.example.com/v1") + with mock.patch("litellm.completion", return_value=_stub_response()) as m: + P.send_messages([{"role": "user", "content": "hi"}]) + assert m.call_args.kwargs.get("api_base") == "https://my-proxy.example.com/v1" + + def test_no_api_key_when_key_not_set(self): + P = _make_provider(model="anthropic/claude-haiku-4-5") + with mock.patch("litellm.completion", return_value=_stub_response()) as m: + P.send_messages([{"role": "user", "content": "hi"}]) + assert "api_key" not in m.call_args.kwargs + + def test_no_api_base_when_url_not_set(self): + P = _make_provider(model="gpt-4o") + with mock.patch("litellm.completion", return_value=_stub_response()) as m: + P.send_messages([{"role": "user", "content": "hi"}]) + assert "api_base" not in m.call_args.kwargs + + def test_extra_params_forwarded(self): + P = _make_provider(model="gpt-4o", temperature=0.3, max_tokens=500) + with mock.patch("litellm.completion", return_value=_stub_response()) as m: + P.send_messages([{"role": "user", "content": "hi"}]) + kw = m.call_args.kwargs + assert kw.get("temperature") == 0.3 + assert kw.get("max_tokens") == 500 + + def test_raises_on_length_finish_reason(self): + P = _make_provider(model="gpt-4o", max_tokens=10) + length_resp = _stub_response(finish_reason="length") + with mock.patch("litellm.completion", return_value=length_resp): + with pytest.raises(ExceedMaxTokens): + P.send_messages([{"role": "user", "content": "hi"}]) + + +# --------------------------------------------------------------------------- +# process_response (inherited from BaseOpenAI) +# --------------------------------------------------------------------------- + +class TestProcessResponse: + def test_parses_good_score(self): + import json + response = json.dumps({"score": 1, "reason": "looks fine"}) + result = BaseLiteLLM.process_response(response) + assert result.status is False + assert result.label == ["QUALITY_GOOD"] + + def test_parses_bad_score(self): + import json + response = json.dumps({"score": 0, "reason": "bad content"}) + result = BaseLiteLLM.process_response(response) + assert result.status is True + assert "BaseLiteLLM" in result.label[0]