Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
8777439
Adding progress...
Jason-Lee08 Mar 21, 2025
4d09538
demo updates
Jason-Lee08 Mar 21, 2025
789a6a4
Merge branch 'main' of github.com:jaredquincy/ember
conconchowchow Mar 24, 2025
23b79a3
cleaning up full testbench
conconchowchow Mar 24, 2025
d260e61
Merge remote-tracking branch 'upstream/main'
Jason-Lee08 Mar 31, 2025
531e0a9
adding previous diversity integration (note: need to fix embedding mo…
conconchowchow Mar 31, 2025
36f3b13
fixing+adding evaluators/embedding model/operator and updating testb…
conconchowchow Apr 2, 2025
94f2469
Merge remote-tracking branch 'upstream/main'
Jason-Lee08 Apr 2, 2025
ce74dcb
Merge branch 'main' of github.com:jaredquincy/ember
conconchowchow Apr 2, 2025
6183e28
changes to evaluators.py and cosine similarity
Jason-Lee08 Apr 2, 2025
a9d3760
Merge remote-tracking branch 'origin'
Jason-Lee08 Apr 2, 2025
130ca31
updating diversity scorer + evaluator to include embedding updates
conconchowchow Apr 2, 2025
b5f8143
updating evaluators
conconchowchow Apr 3, 2025
abf461f
Integrated capability models and examples using OpenAI extended model…
Jason-Lee08 Apr 5, 2025
4f5f099
changed testbench to use diversity operators implemented within the e…
Jason-Lee08 Apr 5, 2025
62604ba
modified eval's __init__.py file to include diversity operators, also…
Jason-Lee08 Apr 5, 2025
a25e7d3
removed todo in chat_schemas.py
Jason-Lee08 Apr 5, 2025
3dc3e08
updated testbench to use ember package + integrate jason's changes
conconchowchow Apr 5, 2025
607320f
renamed references to ember_logging
Jason-Lee08 Apr 9, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Ignore all __pycache__ directories
**/__pycache__/

# Ignore all .egg-info directories or files
*.egg-info/
2 changes: 1 addition & 1 deletion src/ember/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def initialize_ember(
model = registry.get_model("openai:gpt-4")
"""
# Import here to avoid circular imports
from ember.core.utils.logging import configure_logging
from ember.core.utils.ember_logging import configure_logging
from ember.core.config.manager import create_config_manager

# 0. Configure logging first
Expand Down
2 changes: 1 addition & 1 deletion src/ember/core/app_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from ember.core.registry.model.base.registry.model_registry import ModelRegistry
from ember.core.registry.model.base.services.usage_service import UsageService
from ember.core.registry.model.initialization import initialize_registry
from ember.core.utils.logging import configure_logging
from ember.core.utils.ember_logging import configure_logging

# Re-import for patching to work correctly
import logging
Expand Down
1 change: 1 addition & 0 deletions src/ember/core/registry/model/base/schemas/chat_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,5 +115,6 @@ class ChatResponse(BaseModel):
"""

data: str
embedding: list[float] = None
raw_output: Any = None
usage: Optional[UsageStats] = None
295 changes: 292 additions & 3 deletions src/ember/core/registry/model/providers/openai/openai_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,10 @@
"""

import logging
from typing import Any, Dict, Final, List, Optional, cast
from typing import Any, Dict, Final, List, Optional, cast, ClassVar

import openai
from pydantic import Field, field_validator
from pydantic import Field, field_validator, ConfigDict, BaseModel
from requests.exceptions import HTTPError
from tenacity import retry, stop_after_attempt, wait_exponential

Expand All @@ -84,7 +84,7 @@
ChatResponse,
ProviderParams,
)
from ember.core.registry.model.base.schemas.model_info import ModelInfo
from ember.core.registry.model.base.schemas.model_info import ModelInfo, ProviderInfo
from ember.core.registry.model.base.utils.model_registry_exceptions import (
InvalidPromptError,
ProviderAPIError,
Expand All @@ -96,6 +96,15 @@
)
from ember.plugin_system import provider

from ember.core.registry.model.providers.provider_capability import (
EmbeddingRequest,
EmbeddingResponse,
EmbeddingProviderModel,
CompletionRequest,
CompletionResponse,
TextCompletionProviderModel,
)
import os

class OpenAIProviderParams(ProviderParams):
"""OpenAI-specific provider parameters for fine-tuning API requests.
Expand Down Expand Up @@ -438,3 +447,283 @@ def forward(self, request: ChatRequest) -> ChatResponse:
message=f"API error: {str(exc)}",
cause=exc,
)

class OpenAICompletionParameters(BaseModel):
    """Parameter conversion for OpenAI, specifically text completion requests.

    Validates Ember's universal completion parameters and converts them into
    the keyword-argument form expected by OpenAI's completions API.

    Attributes:
        prompt: The text prompt to complete.
        max_tokens: Maximum number of tokens to generate.
        temperature: Controls randomness (0.0-2.0).
        stop_sequences: Sequences that signal end of generation.
    """

    # Disable Pydantic's protected namespace checks so "model_*" style
    # field names elsewhere in the hierarchy do not trigger warnings.
    model_config = ConfigDict(protected_namespaces=())

    prompt: str
    max_tokens: Optional[int] = Field(default=50)
    temperature: Optional[float] = Field(default=0.7, ge=0.0, le=2.0)
    stop_sequences: Optional[List[str]] = None

    def to_openai_kwargs(self) -> Dict[str, Any]:
        """Converting parameters to OpenAI API format.

        Returns:
            Dictionary of parameters for the OpenAI API.
        """
        converted: Dict[str, Any] = {
            "prompt": self.prompt,
            "max_tokens": self.max_tokens,
            "temperature": self.temperature,
        }

        # "stop" is sent only when at least one stop sequence was supplied.
        if self.stop_sequences:
            converted["stop"] = self.stop_sequences

        return converted

@provider("OpenAIExtended")
class OpenAIExtendedModel(TextCompletionProviderModel, EmbeddingProviderModel):
    """Extended OpenAI provider supporting chat, text completion, and embeddings.

    This class implements a provider that supports multiple model types through
    capability interfaces (TextCompletionProviderModel, EmbeddingProviderModel).

    Attributes:
        PROVIDER_NAME: Provider name for registration with the plugin system.
        CAPABILITIES: Capability flags showing supported model types.
    """

    PROVIDER_NAME: ClassVar[str] = "OpenAIExtended"
    CAPABILITIES: ClassVar[Dict[str, bool]] = {
        "chat": True,
        "completion": True,
        "embedding": True,
    }

    def create_client(self) -> Any:
        """Creating and configuring the OpenAI client.

        Retrieves the API key from the model information and configures the
        module-level OpenAI client with it.

        Returns:
            The configured OpenAI client module.

        Raises:
            ProviderAPIError: If the API key is missing or invalid.
        """
        import openai

        api_key: Optional[str] = self.model_info.get_api_key()
        if not api_key:
            raise ProviderAPIError("OpenAI API key is missing or invalid.")

        openai.api_key = api_key
        return openai

    def forward(self, request: ChatRequest) -> ChatResponse:
        """Processing a chat request (implementing BaseProviderModel).

        This method provides the standard chat functionality required by
        the BaseProviderModel interface.

        Args:
            request: Chat request to process.

        Returns:
            Chat response from the model.

        Raises:
            InvalidPromptError: If prompt is empty.
            ProviderAPIError: For unexpected errors during API calls.
        """
        # Implementation would match OpenAIModel's forward method.
        # This is a simplified placeholder.
        if not request.prompt:
            raise InvalidPromptError("OpenAI prompt cannot be empty.")

        # Implementation details would mirror the standard OpenAIModel.
        # Return placeholder.
        return ChatResponse(data="Chat implementation placeholder")

    def complete(self, request: CompletionRequest) -> CompletionResponse:
        """Processing a text completion request.

        Implements text completion capabilities using the OpenAI completions API.

        Args:
            request: Text completion request.

        Returns:
            Completion response from the model.

        Raises:
            InvalidPromptError: If prompt is empty.
            ProviderAPIError: For unexpected errors during API calls.
        """
        if not request.prompt:
            raise InvalidPromptError("OpenAI completion prompt cannot be empty.")

        logger.info(
            "OpenAI completion invoked",
            extra={
                "provider": self.PROVIDER_NAME,
                "model_name": self.model_info.name,
                "prompt_length": len(request.prompt),
            },
        )

        # Convert universal parameters to OpenAI format.
        openai_parameters = OpenAICompletionParameters(
            prompt=request.prompt,
            max_tokens=request.max_tokens,
            temperature=request.temperature,
            stop_sequences=request.stop_sequences,
        )
        openai_kwargs = openai_parameters.to_openai_kwargs()

        # Merge provider-specific overrides. provider_params is a mapping
        # (ProviderParams), not an OpenAICompletionParameters model — the
        # previous cast was misleading and calling .items() on the annotated
        # type would not type-check. It may also be absent entirely.
        provider_params = cast(Optional[Dict[str, Any]], request.provider_params) or {}
        openai_kwargs.update(
            {k: v for k, v in provider_params.items() if v is not None}
        )

        try:
            # Request timeout from parameters or default.
            timeout = openai_kwargs.pop("timeout", 30)

            # Make the API call.
            response = self.client.completions.create(
                model=self.model_info.name,
                timeout=timeout,
                **openai_kwargs,
            )

            # Extract completion text.
            text = response.choices[0].text.strip()

            # Calculate usage statistics.
            # For simplicity, we assume a usage calculator is implemented elsewhere.
            usage_stats = (
                None  # self.usage_calculator.calculate(response, self.model_info)
            )

            return CompletionResponse(
                text=text,
                raw_output=response,
                usage=usage_stats,
            )

        except Exception as exc:
            logger.exception("Unexpected error in OpenAIExtendedModel.complete()")
            raise ProviderAPIError(str(exc)) from exc

    def embed(self, request: EmbeddingRequest) -> EmbeddingResponse:
        """Generating embeddings for the input text(s).

        Implements embedding capabilities using the OpenAI embeddings API.

        Args:
            request: Embedding request with input text(s).

        Returns:
            Embedding response with vector representations. For a list input,
            `embeddings` is a list of vectors; for a single string it is one
            vector (a list of floats).

        Raises:
            InvalidPromptError: If input is empty.
            ProviderAPIError: For unexpected errors during API calls.
        """
        # Use the provided model or default to the model in model_info.
        model_name = request.model or self.model_info.name

        input_text = request.input
        if not input_text:
            raise InvalidPromptError("Input text for embeddings cannot be empty.")

        is_batch = isinstance(input_text, list)
        logger.info(
            "OpenAI embeddings invoked",
            extra={
                "provider": self.PROVIDER_NAME,
                "model_name": model_name,
                "input_type": "batch" if is_batch else "single",
            },
        )

        try:
            # Make the API call.
            response = self.client.embeddings.create(
                model=model_name,
                input=input_text,
                timeout=30,
            )

            # Extract embeddings. (The stray debug print for the batch path
            # was removed; logging above already records the input type.)
            if is_batch:
                # For batch processing: one vector per input item.
                embeddings = [item.embedding for item in response.data]
            else:
                # For single text input: one flat vector.
                embeddings = response.data[0].embedding

            # Get dimensions from the first embedding (batch) or the vector
            # itself (single input).
            if isinstance(embeddings, list) and isinstance(embeddings[0], list):
                dimensions = len(embeddings[0])
            else:
                dimensions = len(embeddings)

            # Calculate usage statistics (implementation would depend on your system).
            usage_stats = (
                None  # self.usage_calculator.calculate(response, self.model_info)
            )

            return EmbeddingResponse(
                embeddings=embeddings,
                model=model_name,
                dimensions=dimensions,
                raw_output=response,
                usage=usage_stats,
            )

        except Exception as exc:
            logger.exception("Unexpected error in OpenAIExtendedModel.embed()")
            raise ProviderAPIError(str(exc)) from exc


def create_openai_embedding_model(
    model_name: str = "text-embedding-ada-002",
) -> Optional["OpenAIExtendedModel"]:
    """Creating an OpenAIExtendedModel configured for an embedding endpoint.

    Args:
        model_name: Name of a particular embedding model endpoint as specified
            by the OpenAI API (e.g. "text-embedding-ada-002").

    Returns:
        An OpenAIExtendedModel initialized to serve model_name, or None if
        model_name is not a recognized OpenAI embedding model. (Return type
        is Optional to match this documented behavior — the previous
        annotation claimed a non-optional return.)
    """
    # All OpenAI embedding models contain "text-embedding" in their model name.
    if "text-embedding" not in model_name:
        return None

    model_info = ModelInfo(
        # The registry id mirrors the embedding model itself; the previous
        # hard-coded "openai:gpt-4o" id belonged to a chat model and would
        # collide with that model's registration.
        id=f"openai:{model_name}",
        name=model_name,
        provider=ProviderInfo(
            name="OpenAI",
            default_api_key=os.environ.get("OPENAI_API_KEY"),
            base_url="https://api.openai.com/v1",
        ),
    )

    return OpenAIExtendedModel(model_info)
Loading