Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 75 additions & 0 deletions dingo/model/llm/base_litellm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
from typing import List

from dingo.config.input_args import EvaluatorLLMArgs
from dingo.model.llm.base_openai import BaseOpenAI
from dingo.utils.exception import ExceedMaxTokens


class BaseLiteLLM(BaseOpenAI):
"""Base class for LLM evaluators that route through LiteLLM.

Provides access to 100+ providers (Anthropic, Gemini, Bedrock, Cohere,
Mistral, Groq, etc.) via a single unified interface. Inherit from this
class instead of BaseOpenAI when you want provider flexibility.

Model string examples:
- "anthropic/claude-3-5-sonnet-20241022"
- "gemini/gemini-1.5-pro"
- "bedrock/anthropic.claude-3-sonnet-20240229-v1:0"
- "groq/llama3-8b-8192"
- "gpt-4o" (defaults to OpenAI, same as BaseOpenAI)

Configuration (via EvaluatorLLMArgs):
- model: required, provider-prefixed model string
- key: optional, API key (overrides provider env var)
- api_url: optional, custom base URL (e.g. LiteLLM proxy URL)
- Any extra field is forwarded to litellm.completion() as a kwarg.

Requires: pip install "dingo-python[litellm]"
"""

# Class-level default config; create_client() raises ValueError while its
# `model` field is still empty, so it must be populated before use.
dynamic_config: EvaluatorLLMArgs = EvaluatorLLMArgs()

@classmethod
def create_client(cls):
    """Validate the LLM config and mark the class as ready to call LiteLLM.

    LiteLLM is used through module-level functions, so no client object is
    built; ``cls.client`` is simply set to ``True`` as an "initialised" flag.

    Raises:
        ValueError: If ``cls.dynamic_config.model`` is empty.
        ImportError: If the optional ``litellm`` package is not installed.
    """
    configured_model = cls.dynamic_config.model
    if not configured_model:
        raise ValueError("model cannot be empty in llm config.")

    # Probe the optional dependency up front so a missing install fails here
    # with an actionable hint instead of later, in the middle of a request.
    try:
        import litellm  # noqa: F401
    except ImportError as import_error:
        install_hint = (
            "litellm is not installed. Run: pip install 'dingo-python[litellm]'"
        )
        raise ImportError(install_hint) from import_error

    cls.client = True
Comment on lines +34 to +44

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The BaseLiteLLM class overrides create_client but completely omits the initialization of the embedding client (cls.embedding_client and cls.embedding_model). This breaks compatibility with RAG-related evaluators that inherit from BaseLiteLLM or use embedding_config in EvaluatorLLMArgs. To ensure BaseLiteLLM is a true drop-in replacement for BaseOpenAI, we should replicate the embedding initialization logic from BaseOpenAI.create_client.

    @classmethod
    def create_client(cls):
        """Validate the LLM config and optionally set up an embedding client.

        Sets ``cls.client`` to ``True`` as an initialisation flag (LiteLLM needs
        no client object). When ``embedding_config`` is present, an OpenAI
        client is built for it and stored on ``cls.embedding_client`` and
        ``cls.embedding_model``; this is intended to replicate the embedding
        setup of ``BaseOpenAI.create_client``.

        Raises:
            ValueError: If the model is empty, or if ``embedding_config`` lacks
                ``api_url`` or ``model``.
            ImportError: If the optional ``litellm`` package is not installed.
        """
        if not cls.dynamic_config.model:
            raise ValueError("model cannot be empty in llm config.")

        try:
            import litellm  # noqa: F401
        except ImportError as import_error:
            install_hint = (
                "litellm is not installed. Run: pip install 'dingo-python[litellm]'"
            )
            raise ImportError(install_hint) from import_error

        # Initialisation flag only; there is no real client object to hold.
        cls.client = True

        embedding_settings = cls.dynamic_config.embedding_config
        if not embedding_settings:
            # Nothing embedding-related was configured.
            return

        from openai import OpenAI
        from dingo.config.input_args import EmbeddingConfigArgs

        # The config may arrive as a plain dict or as an already-built object.
        if isinstance(embedding_settings, dict):
            embedding_settings = EmbeddingConfigArgs(**embedding_settings)

        if not embedding_settings.api_url:
            raise ValueError("embedding_config must provide api_url")
        if not embedding_settings.model:
            raise ValueError("embedding_config must provide model")

        # The embedding client is independent of the chat-completion path.
        embedding_api_key = embedding_settings.key or 'dummy-key'
        cls.embedding_client = OpenAI(
            api_key=embedding_api_key,
            base_url=embedding_settings.api_url,
        )
        cls.embedding_model = {
            'model_name': embedding_settings.model,
            'client': cls.embedding_client,
        }


@classmethod
def send_messages(cls, messages: List) -> str:
    """Send a chat conversation through LiteLLM and return the reply text.

    The model name, API key, base URL and any extra config fields are read
    from ``cls.dynamic_config``. ``drop_params=True`` is set as a default
    (LiteLLM then drops parameters a provider does not support); because the
    extra fields are spread after it, the config can still override it.

    Args:
        messages: Chat messages, passed to ``litellm.completion`` unchanged.

    Returns:
        The first choice's message content as a string, or ``""`` when the
        provider returned no content (``None``).

    Raises:
        ValueError: If the response contains no choices.
        ExceedMaxTokens: If the first choice stopped with reason ``"length"``.
    """
    import litellm

    model_name = cls.dynamic_config.model or ""
    extra_params = cls.dynamic_config.model_extra or {}
    cls.validate_config(extra_params)

    call_kwargs: dict = {
        "drop_params": True,
        **extra_params,
    }
    # Only forward credentials/endpoint when set, so LiteLLM can fall back to
    # its own provider defaults otherwise.
    if cls.dynamic_config.api_url:
        call_kwargs["api_base"] = cls.dynamic_config.api_url
    if cls.dynamic_config.key:
        call_kwargs["api_key"] = cls.dynamic_config.key

    response = litellm.completion(
        model=model_name,
        messages=messages,
        **call_kwargs,
    )

    # An empty/None choices list would otherwise surface as a bare
    # IndexError/TypeError with no hint about the cause.
    choices = getattr(response, "choices", None)
    if not choices:
        raise ValueError("LiteLLM returned an empty response choices list.")

    first_choice = choices[0]
    if first_choice.finish_reason == "length":
        raise ExceedMaxTokens(
            f"Exceed max tokens: {extra_params.get('max_tokens', 4000)}"
        )

    # str(None) would yield the literal text "None"; report "no content" as an
    # empty string instead.
    content = first_choice.message.content
    return "" if content is None else str(content)
Comment on lines +69 to +75

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Defensive checks should be added to handle cases where response.choices is empty or None, which can happen with certain API providers or due to content filtering. Additionally, if message.content is None, calling str(None) will return the literal string 'None', which will cause JSON parsing to fail in process_response. We should return an empty string "" instead.

Suggested change
finish_reason = response.choices[0].finish_reason # type: ignore[union-attr]
if finish_reason == "length":
raise ExceedMaxTokens(
f"Exceed max tokens: {extra_params.get('max_tokens', 4000)}"
)
return str(response.choices[0].message.content) # type: ignore[union-attr]
if not response.choices:
raise ValueError("LiteLLM returned an empty response choices list.")
choice = response.choices[0]
finish_reason = choice.finish_reason # type: ignore[union-attr]
if finish_reason == "length":
raise ExceedMaxTokens(
f"Exceed max tokens: {extra_params.get('max_tokens', 4000)}"
)
content = choice.message.content # type: ignore[union-attr]
return str(content) if content is not None else ""

4 changes: 3 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,13 @@ def _read_requirements(path):

agent_requirements = _read_requirements("./requirements/agent.txt")
hhem_requirements = _read_requirements("./requirements/hhem_integration.txt")
# LiteLLM is an optional backend, so its version pin lives in an extra.
litellm_requirements = ["litellm>=1.80.0,<1.87.0"]

# Optional dependency groups, installable as `pip install "<package>[<extra>]"`.
# Each key appears exactly once: a duplicated 'all' key would be silently
# overwritten by the later entry, hiding the stale definition.
extras_require = {
    'agent': agent_requirements,
    'hhem': hhem_requirements,
    'litellm': litellm_requirements,
    'all': hhem_requirements + agent_requirements + litellm_requirements,
}


Expand Down
134 changes: 134 additions & 0 deletions test/scripts/model/llm/test_litellm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
"""Unit tests for BaseLiteLLM provider."""
import sys
from types import SimpleNamespace
from unittest import mock

import pytest

from dingo.config.input_args import EvaluatorLLMArgs
from dingo.model.llm.base_litellm import BaseLiteLLM
from dingo.utils.exception import ExceedMaxTokens


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

def _make_provider(**cfg_kwargs) -> type:
    """Build a throwaway BaseLiteLLM subclass with its own dynamic_config.

    Each call creates a new subclass and a new EvaluatorLLMArgs built from
    ``cfg_kwargs``, so config set for one test is not shared with another.
    """
    isolated_config = EvaluatorLLMArgs(**cfg_kwargs)

    class _Provider(BaseLiteLLM):
        prompt = "Evaluate: "
        dynamic_config = isolated_config

    return _Provider


def _stub_response(content='{"score": 1, "reason": "ok"}', finish_reason="stop"):
    """Return a minimal stand-in for a completion response.

    The object exposes only what the code under test reads:
    ``choices[0].finish_reason`` and ``choices[0].message.content``.
    """
    message = SimpleNamespace(content=content)
    only_choice = SimpleNamespace(finish_reason=finish_reason, message=message)
    return SimpleNamespace(choices=[only_choice])


# ---------------------------------------------------------------------------
# create_client
# ---------------------------------------------------------------------------

class TestCreateClient:
    """create_client(): config validation and optional-dependency handling."""

    def test_raises_without_model(self):
        P = _make_provider()
        with pytest.raises(ValueError, match="model cannot be empty"):
            P.create_client()

    def test_sets_sentinel_on_success(self):
        P = _make_provider(model="anthropic/claude-haiku-4-5")
        P.create_client()
        assert P.client is True

    def test_raises_import_error_when_litellm_missing(self, monkeypatch):
        P = _make_provider(model="anthropic/claude-haiku-4-5")
        # A None entry in sys.modules makes `import litellm` raise
        # ModuleNotFoundError (an ImportError), simulating a missing install.
        monkeypatch.setitem(sys.modules, "litellm", None)  # type: ignore[arg-type]
        # Require the specific ImportError carrying the install hint; also
        # accepting TypeError would let an unrelated failure pass this test.
        with pytest.raises(ImportError, match="litellm is not installed"):
            P.create_client()


# ---------------------------------------------------------------------------
# send_messages
# ---------------------------------------------------------------------------

class TestSendMessages:
    """send_messages(): which keyword arguments reach litellm.completion."""

    @staticmethod
    def _run(provider, messages):
        """Call send_messages with litellm.completion patched; return the mock."""
        with mock.patch("litellm.completion", return_value=_stub_response()) as patched:
            provider.send_messages(messages)
        return patched

    def test_dispatches_to_litellm(self):
        provider = _make_provider(model="anthropic/claude-haiku-4-5")
        conversation = [{"role": "user", "content": "hello"}]
        completion = self._run(provider, conversation)
        completion.assert_called_once()
        sent = completion.call_args.kwargs
        assert sent["model"] == "anthropic/claude-haiku-4-5"
        assert sent["messages"] == conversation

    def test_drop_params_always_true(self):
        provider = _make_provider(model="gpt-4o")
        completion = self._run(provider, [{"role": "user", "content": "hi"}])
        assert completion.call_args.kwargs.get("drop_params") is True

    def test_api_key_forwarded(self):
        provider = _make_provider(model="gpt-4o", key="sk-test-123")
        completion = self._run(provider, [{"role": "user", "content": "hi"}])
        assert completion.call_args.kwargs.get("api_key") == "sk-test-123"

    def test_api_base_forwarded_when_url_set(self):
        provider = _make_provider(model="gpt-4o", api_url="https://my-proxy.example.com/v1")
        completion = self._run(provider, [{"role": "user", "content": "hi"}])
        assert completion.call_args.kwargs.get("api_base") == "https://my-proxy.example.com/v1"

    def test_no_api_key_when_key_not_set(self):
        provider = _make_provider(model="anthropic/claude-haiku-4-5")
        completion = self._run(provider, [{"role": "user", "content": "hi"}])
        assert "api_key" not in completion.call_args.kwargs

    def test_no_api_base_when_url_not_set(self):
        provider = _make_provider(model="gpt-4o")
        completion = self._run(provider, [{"role": "user", "content": "hi"}])
        assert "api_base" not in completion.call_args.kwargs

    def test_extra_params_forwarded(self):
        provider = _make_provider(model="gpt-4o", temperature=0.3, max_tokens=500)
        completion = self._run(provider, [{"role": "user", "content": "hi"}])
        sent = completion.call_args.kwargs
        assert sent.get("temperature") == 0.3
        assert sent.get("max_tokens") == 500

    def test_raises_on_length_finish_reason(self):
        provider = _make_provider(model="gpt-4o", max_tokens=10)
        truncated = _stub_response(finish_reason="length")
        # finish_reason "length" must surface as ExceedMaxTokens.
        with mock.patch("litellm.completion", return_value=truncated), \
                pytest.raises(ExceedMaxTokens):
            provider.send_messages([{"role": "user", "content": "hi"}])


# ---------------------------------------------------------------------------
# process_response (inherited from BaseOpenAI)
# ---------------------------------------------------------------------------

class TestProcessResponse:
    """process_response(): mapping a JSON score payload to a result."""

    def test_parses_good_score(self):
        import json
        payload = {"score": 1, "reason": "looks fine"}
        outcome = BaseLiteLLM.process_response(json.dumps(payload))
        # Score 1 is expected to produce a non-flagged, "good" result.
        assert outcome.status is False
        assert outcome.label == ["QUALITY_GOOD"]

    def test_parses_bad_score(self):
        import json
        payload = {"score": 0, "reason": "bad content"}
        outcome = BaseLiteLLM.process_response(json.dumps(payload))
        # Score 0 is expected to flag the item, with the class name in its label.
        assert outcome.status is True
        first_label = outcome.label[0]
        assert "BaseLiteLLM" in first_label
Loading