-
Notifications
You must be signed in to change notification settings - Fork 74
feat: add LiteLLM as AI gateway provider #440
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,75 @@ | ||||||||||||||||||||||||||||||||||||||||
| from typing import List | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| from dingo.config.input_args import EvaluatorLLMArgs | ||||||||||||||||||||||||||||||||||||||||
| from dingo.model.llm.base_openai import BaseOpenAI | ||||||||||||||||||||||||||||||||||||||||
| from dingo.utils.exception import ExceedMaxTokens | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| class BaseLiteLLM(BaseOpenAI): | ||||||||||||||||||||||||||||||||||||||||
| """Base class for LLM evaluators that route through LiteLLM. | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| Provides access to 100+ providers (Anthropic, Gemini, Bedrock, Cohere, | ||||||||||||||||||||||||||||||||||||||||
| Mistral, Groq, etc.) via a single unified interface. Inherit from this | ||||||||||||||||||||||||||||||||||||||||
| class instead of BaseOpenAI when you want provider flexibility. | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| Model string examples: | ||||||||||||||||||||||||||||||||||||||||
| - "anthropic/claude-3-5-sonnet-20241022" | ||||||||||||||||||||||||||||||||||||||||
| - "gemini/gemini-1.5-pro" | ||||||||||||||||||||||||||||||||||||||||
| - "bedrock/anthropic.claude-3-sonnet-20240229-v1:0" | ||||||||||||||||||||||||||||||||||||||||
| - "groq/llama3-8b-8192" | ||||||||||||||||||||||||||||||||||||||||
| - "gpt-4o" (defaults to OpenAI, same as BaseOpenAI) | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| Configuration (via EvaluatorLLMArgs): | ||||||||||||||||||||||||||||||||||||||||
| - model: required, provider-prefixed model string | ||||||||||||||||||||||||||||||||||||||||
| - key: optional, API key (overrides provider env var) | ||||||||||||||||||||||||||||||||||||||||
| - api_url: optional, custom base URL (e.g. LiteLLM proxy URL) | ||||||||||||||||||||||||||||||||||||||||
| - Any extra field is forwarded to litellm.completion() as a kwarg. | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| Requires: pip install "dingo-python[litellm]" | ||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| dynamic_config: EvaluatorLLMArgs = EvaluatorLLMArgs() | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| @classmethod | ||||||||||||||||||||||||||||||||||||||||
| def create_client(cls): | ||||||||||||||||||||||||||||||||||||||||
| if not cls.dynamic_config.model: | ||||||||||||||||||||||||||||||||||||||||
| raise ValueError("model cannot be empty in llm config.") | ||||||||||||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||||||||||||
| import litellm # noqa: F401 | ||||||||||||||||||||||||||||||||||||||||
| except ImportError as exc: | ||||||||||||||||||||||||||||||||||||||||
| raise ImportError( | ||||||||||||||||||||||||||||||||||||||||
| "litellm is not installed. Run: pip install 'dingo-python[litellm]'" | ||||||||||||||||||||||||||||||||||||||||
| ) from exc | ||||||||||||||||||||||||||||||||||||||||
| # Use cls.client as an initialisation sentinel (no real client object needed). | ||||||||||||||||||||||||||||||||||||||||
| cls.client = True | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| @classmethod | ||||||||||||||||||||||||||||||||||||||||
| def send_messages(cls, messages: List) -> str: | ||||||||||||||||||||||||||||||||||||||||
| import litellm | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| model_name = cls.dynamic_config.model or "" | ||||||||||||||||||||||||||||||||||||||||
| extra_params = cls.dynamic_config.model_extra or {} | ||||||||||||||||||||||||||||||||||||||||
| cls.validate_config(extra_params) | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| call_kwargs: dict = { | ||||||||||||||||||||||||||||||||||||||||
| "drop_params": True, | ||||||||||||||||||||||||||||||||||||||||
| **extra_params, | ||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||
| if cls.dynamic_config.api_url: | ||||||||||||||||||||||||||||||||||||||||
| call_kwargs["api_base"] = cls.dynamic_config.api_url | ||||||||||||||||||||||||||||||||||||||||
| if cls.dynamic_config.key: | ||||||||||||||||||||||||||||||||||||||||
| call_kwargs["api_key"] = cls.dynamic_config.key | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| response = litellm.completion( | ||||||||||||||||||||||||||||||||||||||||
| model=model_name, | ||||||||||||||||||||||||||||||||||||||||
| messages=messages, | ||||||||||||||||||||||||||||||||||||||||
| **call_kwargs, | ||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| finish_reason = response.choices[0].finish_reason # type: ignore[union-attr] | ||||||||||||||||||||||||||||||||||||||||
| if finish_reason == "length": | ||||||||||||||||||||||||||||||||||||||||
| raise ExceedMaxTokens( | ||||||||||||||||||||||||||||||||||||||||
| f"Exceed max tokens: {extra_params.get('max_tokens', 4000)}" | ||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| return str(response.choices[0].message.content) # type: ignore[union-attr] | ||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+69
to
+75
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Defensive checks should be added to handle cases where
Suggested change
|
||||||||||||||||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,134 @@ | ||
| """Unit tests for BaseLiteLLM provider.""" | ||
| import sys | ||
| from types import SimpleNamespace | ||
| from unittest import mock | ||
|
|
||
| import pytest | ||
|
|
||
| from dingo.config.input_args import EvaluatorLLMArgs | ||
| from dingo.model.llm.base_litellm import BaseLiteLLM | ||
| from dingo.utils.exception import ExceedMaxTokens | ||
|
|
||
|
|
||
| # --------------------------------------------------------------------------- | ||
| # Helpers | ||
| # --------------------------------------------------------------------------- | ||
|
|
||
| def _make_provider(**cfg_kwargs) -> type: | ||
| """Return a fresh BaseLiteLLM subclass with an isolated dynamic_config.""" | ||
|
|
||
| class _Provider(BaseLiteLLM): | ||
| prompt = "Evaluate: " | ||
| dynamic_config = EvaluatorLLMArgs(**cfg_kwargs) | ||
|
|
||
| return _Provider | ||
|
|
||
|
|
||
| def _stub_response(content='{"score": 1, "reason": "ok"}', finish_reason="stop"): | ||
| choice = SimpleNamespace( | ||
| finish_reason=finish_reason, | ||
| message=SimpleNamespace(content=content), | ||
| ) | ||
| return SimpleNamespace(choices=[choice]) | ||
|
|
||
|
|
||
| # --------------------------------------------------------------------------- | ||
| # create_client | ||
| # --------------------------------------------------------------------------- | ||
|
|
||
| class TestCreateClient: | ||
| def test_raises_without_model(self): | ||
| P = _make_provider() | ||
| with pytest.raises(ValueError, match="model cannot be empty"): | ||
| P.create_client() | ||
|
|
||
| def test_sets_sentinel_on_success(self): | ||
| P = _make_provider(model="anthropic/claude-haiku-4-5") | ||
| P.create_client() | ||
| assert P.client is True | ||
|
|
||
| def test_raises_import_error_when_litellm_missing(self, monkeypatch): | ||
| P = _make_provider(model="anthropic/claude-haiku-4-5") | ||
| monkeypatch.setitem(sys.modules, "litellm", None) # type: ignore[arg-type] | ||
| with pytest.raises((ImportError, TypeError)): | ||
| P.create_client() | ||
|
|
||
|
|
||
| # --------------------------------------------------------------------------- | ||
| # send_messages | ||
| # --------------------------------------------------------------------------- | ||
|
|
||
| class TestSendMessages: | ||
| def test_dispatches_to_litellm(self): | ||
| P = _make_provider(model="anthropic/claude-haiku-4-5") | ||
| msgs = [{"role": "user", "content": "hello"}] | ||
| with mock.patch("litellm.completion", return_value=_stub_response()) as m: | ||
| P.send_messages(msgs) | ||
| m.assert_called_once() | ||
| assert m.call_args.kwargs["model"] == "anthropic/claude-haiku-4-5" | ||
| assert m.call_args.kwargs["messages"] == msgs | ||
|
|
||
| def test_drop_params_always_true(self): | ||
| P = _make_provider(model="gpt-4o") | ||
| with mock.patch("litellm.completion", return_value=_stub_response()) as m: | ||
| P.send_messages([{"role": "user", "content": "hi"}]) | ||
| assert m.call_args.kwargs.get("drop_params") is True | ||
|
|
||
| def test_api_key_forwarded(self): | ||
| P = _make_provider(model="gpt-4o", key="sk-test-123") | ||
| with mock.patch("litellm.completion", return_value=_stub_response()) as m: | ||
| P.send_messages([{"role": "user", "content": "hi"}]) | ||
| assert m.call_args.kwargs.get("api_key") == "sk-test-123" | ||
|
|
||
| def test_api_base_forwarded_when_url_set(self): | ||
| P = _make_provider(model="gpt-4o", api_url="https://my-proxy.example.com/v1") | ||
| with mock.patch("litellm.completion", return_value=_stub_response()) as m: | ||
| P.send_messages([{"role": "user", "content": "hi"}]) | ||
| assert m.call_args.kwargs.get("api_base") == "https://my-proxy.example.com/v1" | ||
|
|
||
| def test_no_api_key_when_key_not_set(self): | ||
| P = _make_provider(model="anthropic/claude-haiku-4-5") | ||
| with mock.patch("litellm.completion", return_value=_stub_response()) as m: | ||
| P.send_messages([{"role": "user", "content": "hi"}]) | ||
| assert "api_key" not in m.call_args.kwargs | ||
|
|
||
| def test_no_api_base_when_url_not_set(self): | ||
| P = _make_provider(model="gpt-4o") | ||
| with mock.patch("litellm.completion", return_value=_stub_response()) as m: | ||
| P.send_messages([{"role": "user", "content": "hi"}]) | ||
| assert "api_base" not in m.call_args.kwargs | ||
|
|
||
| def test_extra_params_forwarded(self): | ||
| P = _make_provider(model="gpt-4o", temperature=0.3, max_tokens=500) | ||
| with mock.patch("litellm.completion", return_value=_stub_response()) as m: | ||
| P.send_messages([{"role": "user", "content": "hi"}]) | ||
| kw = m.call_args.kwargs | ||
| assert kw.get("temperature") == 0.3 | ||
| assert kw.get("max_tokens") == 500 | ||
|
|
||
| def test_raises_on_length_finish_reason(self): | ||
| P = _make_provider(model="gpt-4o", max_tokens=10) | ||
| length_resp = _stub_response(finish_reason="length") | ||
| with mock.patch("litellm.completion", return_value=length_resp): | ||
| with pytest.raises(ExceedMaxTokens): | ||
| P.send_messages([{"role": "user", "content": "hi"}]) | ||
|
|
||
|
|
||
| # --------------------------------------------------------------------------- | ||
| # process_response (inherited from BaseOpenAI) | ||
| # --------------------------------------------------------------------------- | ||
|
|
||
| class TestProcessResponse: | ||
| def test_parses_good_score(self): | ||
| import json | ||
| response = json.dumps({"score": 1, "reason": "looks fine"}) | ||
| result = BaseLiteLLM.process_response(response) | ||
| assert result.status is False | ||
| assert result.label == ["QUALITY_GOOD"] | ||
|
|
||
| def test_parses_bad_score(self): | ||
| import json | ||
| response = json.dumps({"score": 0, "reason": "bad content"}) | ||
| result = BaseLiteLLM.process_response(response) | ||
| assert result.status is True | ||
| assert "BaseLiteLLM" in result.label[0] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The
BaseLiteLLMclass overridescreate_clientbut completely omits the initialization of the embedding client (cls.embedding_clientandcls.embedding_model). This breaks compatibility with RAG-related evaluators that inherit fromBaseLiteLLMor useembedding_configinEvaluatorLLMArgs. To ensureBaseLiteLLMis a true drop-in replacement forBaseOpenAI, we should replicate the embedding initialization logic fromBaseOpenAI.create_client.