# File: oai/oai/providers/openai.py
"""
OpenAI provider for GPT models.
This provider connects to OpenAI's API for accessing GPT-4, GPT-3.5, and other OpenAI models.
"""
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union
from openai import OpenAI, AsyncOpenAI
from openai.types.chat import ChatCompletion, ChatCompletionChunk
from oai.constants import OPENAI_BASE_URL
from oai.providers.base import (
AIProvider,
ChatMessage,
ChatResponse,
ChatResponseChoice,
ModelInfo,
ProviderCapabilities,
StreamChunk,
ToolCall,
ToolFunction,
UsageStats,
)
from oai.utils.logging import get_logger
logger = get_logger()
# Model aliases for convenience
MODEL_ALIASES = {
"gpt-4": "gpt-4-turbo",
"gpt-4-turbo": "gpt-4-turbo-2024-04-09",
"gpt-4o": "gpt-4o-2024-11-20",
"gpt-4o-mini": "gpt-4o-mini-2024-07-18",
"gpt-3.5": "gpt-3.5-turbo",
"gpt-3.5-turbo": "gpt-3.5-turbo-0125",
"o1": "o1-2024-12-17",
"o1-mini": "o1-mini-2024-09-12",
"o1-preview": "o1-preview-2024-09-12",
}
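# Note: alias resolution below is a single, non-transitive dict lookup:
# "gpt-4" resolves to "gpt-4-turbo" as-is, not onward to the dated
# "gpt-4-turbo-2024-04-09" snapshot.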
class OpenAIProvider(AIProvider):
"""
OpenAI API provider.
Provides access to GPT-4, GPT-3.5, o1, and other OpenAI models.
"""
def __init__(
self,
api_key: str,
base_url: Optional[str] = None,
app_name: str = "oAI",
app_url: str = "",
**kwargs: Any,
):
"""
Initialize OpenAI provider.
Args:
api_key: OpenAI API key
base_url: Optional custom base URL
app_name: Application name (for headers)
app_url: Application URL (for headers)
**kwargs: Additional arguments
"""
super().__init__(api_key, base_url or OPENAI_BASE_URL)
self.client = OpenAI(api_key=api_key, base_url=self.base_url)
self.async_client = AsyncOpenAI(api_key=api_key, base_url=self.base_url)
self._models_cache: Optional[List[ModelInfo]] = None
@property
def name(self) -> str:
"""Get provider name."""
return "OpenAI"
@property
def capabilities(self) -> ProviderCapabilities:
"""Get provider capabilities."""
return ProviderCapabilities(
streaming=True,
tools=True,
images=True,
            online=True,  # web search is handled app-side (DuckDuckGo/Google), not by the OpenAI API itself
max_context=128000,
)
def list_models(self, filter_text_only: bool = True) -> List[ModelInfo]:
"""
List available OpenAI models.
Args:
filter_text_only: Whether to filter for text models only
Returns:
List of ModelInfo objects
"""
if self._models_cache:
return self._models_cache
try:
models_response = self.client.models.list()
models = []
for model in models_response.data:
                # Coarse filter for chat models; this substring match can
                # also admit non-chat variants (e.g. -instruct, -realtime)
                if "gpt" in model.id or "o1" in model.id:
models.append(self._parse_model(model))
# Sort by name
models.sort(key=lambda m: m.name)
self._models_cache = models
logger.info(f"Loaded {len(models)} OpenAI models")
return models
except Exception as e:
logger.error(f"Failed to list OpenAI models: {e}")
return self._get_fallback_models()
def _get_fallback_models(self) -> List[ModelInfo]:
"""
Get fallback list of common OpenAI models.
Returns:
List of common models
"""
return [
ModelInfo(
id="gpt-4o",
name="GPT-4o",
description="Most capable GPT-4 model",
context_length=128000,
pricing={"input": 5.0, "output": 15.0},
supported_parameters=["temperature", "max_tokens", "stream", "tools"],
input_modalities=["text", "image"],
),
ModelInfo(
id="gpt-4o-mini",
name="GPT-4o Mini",
description="Affordable and fast GPT-4 class model",
context_length=128000,
pricing={"input": 0.15, "output": 0.6},
supported_parameters=["temperature", "max_tokens", "stream", "tools"],
input_modalities=["text", "image"],
),
ModelInfo(
id="gpt-4-turbo",
name="GPT-4 Turbo",
description="GPT-4 Turbo with vision",
context_length=128000,
pricing={"input": 10.0, "output": 30.0},
supported_parameters=["temperature", "max_tokens", "stream", "tools"],
input_modalities=["text", "image"],
),
ModelInfo(
id="gpt-3.5-turbo",
name="GPT-3.5 Turbo",
description="Fast and affordable model",
context_length=16384,
pricing={"input": 0.5, "output": 1.5},
supported_parameters=["temperature", "max_tokens", "stream", "tools"],
),
ModelInfo(
id="o1",
name="o1",
description="Advanced reasoning model",
context_length=200000,
pricing={"input": 15.0, "output": 60.0},
supported_parameters=["max_tokens"],
),
ModelInfo(
id="o1-mini",
name="o1-mini",
description="Fast reasoning model",
context_length=128000,
pricing={"input": 3.0, "output": 12.0},
supported_parameters=["max_tokens"],
),
]
def _parse_model(self, model: Any) -> ModelInfo:
"""
Parse OpenAI model into ModelInfo.
Args:
model: OpenAI model object
Returns:
ModelInfo object
"""
model_id = model.id
# Determine context length
context_length = 8192 # Default
if "gpt-4o" in model_id or "gpt-4-turbo" in model_id:
context_length = 128000
elif "gpt-4" in model_id:
context_length = 8192
elif "gpt-3.5-turbo" in model_id:
context_length = 16384
elif "o1" in model_id:
context_length = 128000
        # Determine pricing (approximate USD per 1M tokens)
pricing = {}
if "gpt-4o-mini" in model_id:
pricing = {"input": 0.15, "output": 0.6}
elif "gpt-4o" in model_id:
pricing = {"input": 5.0, "output": 15.0}
elif "gpt-4-turbo" in model_id:
pricing = {"input": 10.0, "output": 30.0}
elif "gpt-4" in model_id:
pricing = {"input": 30.0, "output": 60.0}
elif "gpt-3.5" in model_id:
pricing = {"input": 0.5, "output": 1.5}
elif "o1" in model_id and "mini" not in model_id:
pricing = {"input": 15.0, "output": 60.0}
elif "o1-mini" in model_id:
pricing = {"input": 3.0, "output": 12.0}
return ModelInfo(
id=model_id,
name=model_id,
description="",
context_length=context_length,
pricing=pricing,
supported_parameters=["temperature", "max_tokens", "stream", "tools"],
)
def get_model(self, model_id: str) -> Optional[ModelInfo]:
"""
Get information about a specific model.
Args:
model_id: Model identifier
Returns:
ModelInfo or None
"""
# Resolve alias
resolved_id = MODEL_ALIASES.get(model_id, model_id)
models = self.list_models()
for model in models:
if model.id == resolved_id or model.id == model_id:
return model
# Try to fetch directly
try:
model = self.client.models.retrieve(resolved_id)
return self._parse_model(model)
except Exception:
return None
def chat(
self,
model: str,
messages: List[ChatMessage],
stream: bool = False,
max_tokens: Optional[int] = None,
temperature: Optional[float] = None,
tools: Optional[List[Dict[str, Any]]] = None,
tool_choice: Optional[str] = None,
**kwargs: Any,
) -> Union[ChatResponse, Iterator[StreamChunk]]:
"""
Send chat completion request to OpenAI.
Args:
model: Model ID
messages: Chat messages
stream: Whether to stream response
max_tokens: Maximum tokens
temperature: Sampling temperature
tools: Tool definitions
tool_choice: Tool selection mode
**kwargs: Additional parameters
Returns:
ChatResponse or Iterator[StreamChunk]
"""
# Resolve model alias
model_id = MODEL_ALIASES.get(model, model)
# Convert messages to OpenAI format
openai_messages = []
for msg in messages:
message_dict = {"role": msg.role, "content": msg.content or ""}
if msg.tool_calls:
message_dict["tool_calls"] = [
{
"id": tc.id,
"type": tc.type,
"function": {
"name": tc.function.name,
"arguments": tc.function.arguments,
},
}
for tc in msg.tool_calls
]
if msg.tool_call_id:
message_dict["tool_call_id"] = msg.tool_call_id
openai_messages.append(message_dict)
# Build request parameters
params: Dict[str, Any] = {
"model": model_id,
"messages": openai_messages,
"stream": stream,
}
        # Add optional parameters
        if max_tokens is not None:
            if "o1" in model_id:
                # o1 models reject max_tokens; they expect max_completion_tokens
                params["max_completion_tokens"] = max_tokens
            else:
                params["max_tokens"] = max_tokens
        if temperature is not None and "o1" not in model_id:
            # o1 models don't support temperature
            params["temperature"] = temperature
if tools:
params["tools"] = tools
if tool_choice:
params["tool_choice"] = tool_choice
logger.debug(f"OpenAI request: model={model_id}, messages={len(openai_messages)}")
try:
if stream:
return self._stream_chat(params)
else:
return self._sync_chat(params)
except Exception as e:
logger.error(f"OpenAI request failed: {e}")
return ChatResponse(
id="error",
choices=[
ChatResponseChoice(
index=0,
message=ChatMessage(role="assistant", content=f"Error: {str(e)}"),
finish_reason="error",
)
],
)
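    # Streaming sketch: with stream=True, chat() returns an iterator of
    # StreamChunk objects rather than a ChatResponse:
    #
    #   for chunk in provider.chat("gpt-4o-mini", messages, stream=True):
    #       print(chunk.delta_content or "", end="", flush=True)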
    def _sync_chat(self, params: Dict[str, Any]) -> ChatResponse:
        """
        Send synchronous chat request.

        Args:
            params: Request parameters

        Returns:
            ChatResponse
        """
        completion: ChatCompletion = self.client.chat.completions.create(**params)
        # Conversion is shared with the async path via _convert_completion
        return self._convert_completion(completion)
def _stream_chat(self, params: Dict[str, Any]) -> Iterator[StreamChunk]:
"""
Stream chat response from OpenAI.
Args:
params: Request parameters
Yields:
StreamChunk objects
"""
        stream = self.client.chat.completions.create(**params)
        for chunk in stream:
            chunk_data: ChatCompletionChunk = chunk
            # Usage arrives only on the final chunk, and only when the
            # request sets stream_options={"include_usage": True}; that
            # final chunk has an empty choices list, so extract usage
            # before the choices guard.
            usage = None
            if getattr(chunk_data, "usage", None):
                usage = UsageStats(
                    prompt_tokens=chunk_data.usage.prompt_tokens,
                    completion_tokens=chunk_data.usage.completion_tokens,
                    total_tokens=chunk_data.usage.total_tokens,
                )
            if not chunk_data.choices:
                if usage:
                    yield StreamChunk(
                        id=chunk_data.id,
                        delta_content=None,
                        finish_reason=None,
                        usage=usage,
                    )
                continue
            choice = chunk_data.choices[0]
            delta = choice.delta
            yield StreamChunk(
                id=chunk_data.id,
                delta_content=delta.content or None,
                finish_reason=choice.finish_reason,
                usage=usage,
            )
async def chat_async(
self,
model: str,
messages: List[ChatMessage],
stream: bool = False,
max_tokens: Optional[int] = None,
temperature: Optional[float] = None,
tools: Optional[List[Dict[str, Any]]] = None,
tool_choice: Optional[str] = None,
**kwargs: Any,
) -> Union[ChatResponse, AsyncIterator[StreamChunk]]:
"""
Send async chat request to OpenAI.
Args:
model: Model ID
messages: Chat messages
stream: Whether to stream
max_tokens: Max tokens
temperature: Temperature
tools: Tool definitions
tool_choice: Tool choice
**kwargs: Additional args
Returns:
ChatResponse or AsyncIterator[StreamChunk]
"""
# Resolve model alias
model_id = MODEL_ALIASES.get(model, model)
        # Convert messages (ChatMessage.to_dict() is assumed to carry
        # tool_calls / tool_call_id when present, matching the sync path)
        openai_messages = [msg.to_dict() for msg in messages]
# Build params
params: Dict[str, Any] = {
"model": model_id,
"messages": openai_messages,
"stream": stream,
}
        if max_tokens is not None:
            if "o1" in model_id:
                # o1 models reject max_tokens; they expect max_completion_tokens
                params["max_completion_tokens"] = max_tokens
            else:
                params["max_tokens"] = max_tokens
        if temperature is not None and "o1" not in model_id:
            params["temperature"] = temperature
if tools:
params["tools"] = tools
if tool_choice:
params["tool_choice"] = tool_choice
if stream:
return self._stream_chat_async(params)
else:
completion = await self.async_client.chat.completions.create(**params)
            # Convert via the shared completion-conversion helper
return self._convert_completion(completion)
async def _stream_chat_async(self, params: Dict[str, Any]) -> AsyncIterator[StreamChunk]:
"""
Stream async chat response.
Args:
params: Request parameters
Yields:
StreamChunk objects
"""
        stream = await self.async_client.chat.completions.create(**params)
        async for chunk in stream:
            # Mirror the sync path: usage rides on the final, choice-less
            # chunk when stream_options={"include_usage": True} is set.
            usage = None
            if getattr(chunk, "usage", None):
                usage = UsageStats(
                    prompt_tokens=chunk.usage.prompt_tokens,
                    completion_tokens=chunk.usage.completion_tokens,
                    total_tokens=chunk.usage.total_tokens,
                )
            if not chunk.choices:
                if usage:
                    yield StreamChunk(id=chunk.id, delta_content=None, finish_reason=None, usage=usage)
                continue
            choice = chunk.choices[0]
            yield StreamChunk(
                id=chunk.id,
                delta_content=choice.delta.content,
                finish_reason=choice.finish_reason,
                usage=usage,
            )
def _convert_completion(self, completion: ChatCompletion) -> ChatResponse:
"""Helper to convert OpenAI completion to ChatResponse."""
choices = []
for choice in completion.choices:
tool_calls = None
if choice.message.tool_calls:
tool_calls = [
ToolCall(
id=tc.id,
type=tc.type,
function=ToolFunction(
name=tc.function.name,
arguments=tc.function.arguments,
),
)
for tc in choice.message.tool_calls
]
choices.append(
ChatResponseChoice(
index=choice.index,
message=ChatMessage(
role=choice.message.role,
content=choice.message.content,
tool_calls=tool_calls,
),
finish_reason=choice.finish_reason,
)
)
usage = None
if completion.usage:
usage = UsageStats(
prompt_tokens=completion.usage.prompt_tokens,
completion_tokens=completion.usage.completion_tokens,
total_tokens=completion.usage.total_tokens,
)
return ChatResponse(
id=completion.id,
choices=choices,
usage=usage,
model=completion.model,
created=completion.created,
)
def get_credits(self) -> Optional[Dict[str, Any]]:
"""
Get account credits.
Returns:
            None (OpenAI does not expose a public credits API)
"""
return None
def clear_cache(self) -> None:
"""Clear model cache."""
self._models_cache = None
def get_raw_models(self) -> List[Dict[str, Any]]:
"""
Get raw model data as dictionaries.
Returns:
List of model dictionaries
"""
models = self.list_models()
return [
{
"id": model.id,
"name": model.name,
"description": model.description,
"context_length": model.context_length,
"pricing": model.pricing,
}
for model in models
]
def get_raw_model(self, model_id: str) -> Optional[Dict[str, Any]]:
"""
Get raw model data for a specific model.
Args:
model_id: Model identifier
Returns:
Model dictionary or None
"""
model = self.get_model(model_id)
if model:
return {
"id": model.id,
"name": model.name,
"description": model.description,
"context_length": model.context_length,
"pricing": model.pricing,
}
return None