"""
|
|
OpenAI provider for GPT models.
|
|
|
|
This provider connects to OpenAI's API for accessing GPT-4, GPT-3.5, and other OpenAI models.
|
|
"""
|
|
|
|
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union
|
|
|
|
from openai import OpenAI, AsyncOpenAI
|
|
from openai.types.chat import ChatCompletion, ChatCompletionChunk
|
|
|
|
from oai.constants import OPENAI_BASE_URL
|
|
from oai.providers.base import (
|
|
AIProvider,
|
|
ChatMessage,
|
|
ChatResponse,
|
|
ChatResponseChoice,
|
|
ModelInfo,
|
|
ProviderCapabilities,
|
|
StreamChunk,
|
|
ToolCall,
|
|
ToolFunction,
|
|
UsageStats,
|
|
)
|
|
from oai.utils.logging import get_logger
|
|
|
|
logger = get_logger()
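
# A minimal usage sketch (illustrative only; assumes a valid OpenAI API key
# and the surrounding oai package):
#
#     provider = OpenAIProvider(api_key="sk-...")
#     reply = provider.chat(
#         model="gpt-4o",
#         messages=[ChatMessage(role="user", content="Hello!")],
#     )
#     print(reply.choices[0].message.content)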

# Model aliases for convenience
MODEL_ALIASES = {
    "gpt-4": "gpt-4-turbo",
    "gpt-4-turbo": "gpt-4-turbo-2024-04-09",
    "gpt-4o": "gpt-4o-2024-11-20",
    "gpt-4o-mini": "gpt-4o-mini-2024-07-18",
    "gpt-3.5": "gpt-3.5-turbo",
    "gpt-3.5-turbo": "gpt-3.5-turbo-0125",
    "o1": "o1-2024-12-17",
    "o1-mini": "o1-mini-2024-09-12",
    "o1-preview": "o1-preview-2024-09-12",
}
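
# Alias resolution is a plain dict lookup with pass-through for unknown IDs:
# MODEL_ALIASES.get("gpt-4o", "gpt-4o") -> "gpt-4o-2024-11-20", while an
# unlisted ID such as a fine-tuned model name is sent to the API unchanged.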


class OpenAIProvider(AIProvider):
    """
    OpenAI API provider.

    Provides access to GPT-4, GPT-3.5, o1, and other OpenAI models.
    """

    def __init__(
        self,
        api_key: str,
        base_url: Optional[str] = None,
        app_name: str = "oAI",
        app_url: str = "",
        **kwargs: Any,
    ):
        """
        Initialize OpenAI provider.

        Args:
            api_key: OpenAI API key
            base_url: Optional custom base URL
            app_name: Application name (accepted for interface parity with
                other providers; not sent as a header by this provider)
            app_url: Application URL (accepted for interface parity with
                other providers; not sent as a header by this provider)
            **kwargs: Additional arguments
        """
        super().__init__(api_key, base_url or OPENAI_BASE_URL)
        self.client = OpenAI(api_key=api_key, base_url=self.base_url)
        self.async_client = AsyncOpenAI(api_key=api_key, base_url=self.base_url)
        self._models_cache: Optional[List[ModelInfo]] = None

    @property
    def name(self) -> str:
        """Get provider name."""
        return "OpenAI"

    @property
    def capabilities(self) -> ProviderCapabilities:
        """Get provider capabilities."""
        return ProviderCapabilities(
            streaming=True,
            tools=True,
            images=True,
            online=True,  # Web search via DuckDuckGo/Google
            max_context=128000,
        )

    def list_models(self, filter_text_only: bool = True) -> List[ModelInfo]:
        """
        List available OpenAI models.

        Args:
            filter_text_only: Whether to filter for text models only
                (currently unused; the chat-model filter below always applies)

        Returns:
            List of ModelInfo objects
        """
        if self._models_cache:
            return self._models_cache

        try:
            models_response = self.client.models.list()
            models = []

            for model in models_response.data:
                # Filter for chat models
                if "gpt" in model.id or "o1" in model.id:
                    models.append(self._parse_model(model))

            # Sort by name
            models.sort(key=lambda m: m.name)
            self._models_cache = models

            logger.info(f"Loaded {len(models)} OpenAI models")
            return models

        except Exception as e:
            logger.error(f"Failed to list OpenAI models: {e}")
            return self._get_fallback_models()

    def _get_fallback_models(self) -> List[ModelInfo]:
        """
        Get fallback list of common OpenAI models.

        Returns:
            List of common models
        """
        # Pricing values are approximate USD per 1M tokens.
        return [
            ModelInfo(
                id="gpt-4o",
                name="GPT-4o",
                description="Most capable GPT-4 model",
                context_length=128000,
                pricing={"input": 5.0, "output": 15.0},
                supported_parameters=["temperature", "max_tokens", "stream", "tools"],
                input_modalities=["text", "image"],
            ),
            ModelInfo(
                id="gpt-4o-mini",
                name="GPT-4o Mini",
                description="Affordable and fast GPT-4 class model",
                context_length=128000,
                pricing={"input": 0.15, "output": 0.6},
                supported_parameters=["temperature", "max_tokens", "stream", "tools"],
                input_modalities=["text", "image"],
            ),
            ModelInfo(
                id="gpt-4-turbo",
                name="GPT-4 Turbo",
                description="GPT-4 Turbo with vision",
                context_length=128000,
                pricing={"input": 10.0, "output": 30.0},
                supported_parameters=["temperature", "max_tokens", "stream", "tools"],
                input_modalities=["text", "image"],
            ),
            ModelInfo(
                id="gpt-3.5-turbo",
                name="GPT-3.5 Turbo",
                description="Fast and affordable model",
                context_length=16384,
                pricing={"input": 0.5, "output": 1.5},
                supported_parameters=["temperature", "max_tokens", "stream", "tools"],
            ),
            ModelInfo(
                id="o1",
                name="o1",
                description="Advanced reasoning model",
                context_length=200000,
                pricing={"input": 15.0, "output": 60.0},
                supported_parameters=["max_tokens"],
            ),
            ModelInfo(
                id="o1-mini",
                name="o1-mini",
                description="Fast reasoning model",
                context_length=128000,
                pricing={"input": 3.0, "output": 12.0},
                supported_parameters=["max_tokens"],
            ),
        ]

    def _parse_model(self, model: Any) -> ModelInfo:
        """
        Parse OpenAI model into ModelInfo.

        Args:
            model: OpenAI model object

        Returns:
            ModelInfo object
        """
        model_id = model.id

        # Determine context length
        context_length = 8192  # Default
        if "gpt-4o" in model_id or "gpt-4-turbo" in model_id:
            context_length = 128000
        elif "gpt-4" in model_id:
            context_length = 8192
        elif "gpt-3.5-turbo" in model_id:
            context_length = 16384
        elif "o1" in model_id:
            context_length = 128000

        # Determine approximate pricing (USD per 1M tokens)
        pricing = {}
        if "gpt-4o-mini" in model_id:
            pricing = {"input": 0.15, "output": 0.6}
        elif "gpt-4o" in model_id:
            pricing = {"input": 5.0, "output": 15.0}
        elif "gpt-4-turbo" in model_id:
            pricing = {"input": 10.0, "output": 30.0}
        elif "gpt-4" in model_id:
            pricing = {"input": 30.0, "output": 60.0}
        elif "gpt-3.5" in model_id:
            pricing = {"input": 0.5, "output": 1.5}
        elif "o1-mini" in model_id:
            pricing = {"input": 3.0, "output": 12.0}
        elif "o1" in model_id:
            pricing = {"input": 15.0, "output": 60.0}

        return ModelInfo(
            id=model_id,
            name=model_id,
            description="",
            context_length=context_length,
            pricing=pricing,
            supported_parameters=["temperature", "max_tokens", "stream", "tools"],
        )

    def get_model(self, model_id: str) -> Optional[ModelInfo]:
        """
        Get information about a specific model.

        Args:
            model_id: Model identifier

        Returns:
            ModelInfo or None
        """
        # Resolve alias
        resolved_id = MODEL_ALIASES.get(model_id, model_id)

        models = self.list_models()
        for model in models:
            if model.id == resolved_id or model.id == model_id:
                return model

        # Try to fetch directly
        try:
            model = self.client.models.retrieve(resolved_id)
            return self._parse_model(model)
        except Exception:
            return None

    def chat(
        self,
        model: str,
        messages: List[ChatMessage],
        stream: bool = False,
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        tools: Optional[List[Dict[str, Any]]] = None,
        tool_choice: Optional[str] = None,
        **kwargs: Any,
    ) -> Union[ChatResponse, Iterator[StreamChunk]]:
        """
        Send chat completion request to OpenAI.

        Args:
            model: Model ID
            messages: Chat messages
            stream: Whether to stream response
            max_tokens: Maximum tokens
            temperature: Sampling temperature
            tools: Tool definitions
            tool_choice: Tool selection mode
            **kwargs: Additional parameters passed through to the API

        Returns:
            ChatResponse or Iterator[StreamChunk]
        """
        # Resolve model alias
        model_id = MODEL_ALIASES.get(model, model)

        # Convert messages to OpenAI format
        openai_messages = []
        for msg in messages:
            message_dict = {"role": msg.role, "content": msg.content or ""}

            if msg.tool_calls:
                message_dict["tool_calls"] = [
                    {
                        "id": tc.id,
                        "type": tc.type,
                        "function": {
                            "name": tc.function.name,
                            "arguments": tc.function.arguments,
                        },
                    }
                    for tc in msg.tool_calls
                ]

            if msg.tool_call_id:
                message_dict["tool_call_id"] = msg.tool_call_id

            openai_messages.append(message_dict)

        # Build request parameters
        params: Dict[str, Any] = {
            "model": model_id,
            "messages": openai_messages,
            "stream": stream,
        }

        # Add optional parameters
        if max_tokens is not None:
            if "o1" in model_id:
                # o1 models reject max_tokens; they use max_completion_tokens
                params["max_completion_tokens"] = max_tokens
            else:
                params["max_tokens"] = max_tokens
        if temperature is not None and "o1" not in model_id:
            # o1 models don't support temperature
            params["temperature"] = temperature
        if tools:
            params["tools"] = tools
        if tool_choice:
            params["tool_choice"] = tool_choice
        if kwargs:
            # Forward extra provider-specific parameters (e.g. stream_options)
            params.update(kwargs)

        logger.debug(f"OpenAI request: model={model_id}, messages={len(openai_messages)}")

        try:
            if stream:
                return self._stream_chat(params)
            else:
                return self._sync_chat(params)

        except Exception as e:
            logger.error(f"OpenAI request failed: {e}")
            return ChatResponse(
                id="error",
                choices=[
                    ChatResponseChoice(
                        index=0,
                        message=ChatMessage(role="assistant", content=f"Error: {str(e)}"),
                        finish_reason="error",
                    )
                ],
            )
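
    # Tool-calling sketch (illustrative; the dict layout follows OpenAI's
    # function-tool schema, and get_weather is a hypothetical tool name):
    #
    #     tools = [{
    #         "type": "function",
    #         "function": {
    #             "name": "get_weather",
    #             "description": "Look up current weather for a city",
    #             "parameters": {
    #                 "type": "object",
    #                 "properties": {"city": {"type": "string"}},
    #                 "required": ["city"],
    #             },
    #         },
    #     }]
    #     response = provider.chat(model="gpt-4o", messages=msgs, tools=tools)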

    def _sync_chat(self, params: Dict[str, Any]) -> ChatResponse:
        """
        Send synchronous chat request.

        Args:
            params: Request parameters

        Returns:
            ChatResponse
        """
        completion: ChatCompletion = self.client.chat.completions.create(**params)

        # Convert to our format
        choices = []
        for choice in completion.choices:
            # Convert tool calls if present
            tool_calls = None
            if choice.message.tool_calls:
                tool_calls = [
                    ToolCall(
                        id=tc.id,
                        type=tc.type,
                        function=ToolFunction(
                            name=tc.function.name,
                            arguments=tc.function.arguments,
                        ),
                    )
                    for tc in choice.message.tool_calls
                ]

            choices.append(
                ChatResponseChoice(
                    index=choice.index,
                    message=ChatMessage(
                        role=choice.message.role,
                        content=choice.message.content,
                        tool_calls=tool_calls,
                    ),
                    finish_reason=choice.finish_reason,
                )
            )

        # Convert usage
        usage = None
        if completion.usage:
            usage = UsageStats(
                prompt_tokens=completion.usage.prompt_tokens,
                completion_tokens=completion.usage.completion_tokens,
                total_tokens=completion.usage.total_tokens,
            )

        return ChatResponse(
            id=completion.id,
            choices=choices,
            usage=usage,
            model=completion.model,
            created=completion.created,
        )

    def _stream_chat(self, params: Dict[str, Any]) -> Iterator[StreamChunk]:
        """
        Stream chat response from OpenAI.

        Args:
            params: Request parameters

        Yields:
            StreamChunk objects
        """
        stream = self.client.chat.completions.create(**params)

        for chunk in stream:
            chunk_data: ChatCompletionChunk = chunk

            # Extract usage first: the usage-bearing final chunk has an empty
            # choices list, so checking choices first would drop it.
            usage = None
            if getattr(chunk_data, "usage", None):
                usage = UsageStats(
                    prompt_tokens=chunk_data.usage.prompt_tokens,
                    completion_tokens=chunk_data.usage.completion_tokens,
                    total_tokens=chunk_data.usage.total_tokens,
                )

            if not chunk_data.choices:
                if usage:
                    # Surface the usage-only final chunk
                    yield StreamChunk(
                        id=chunk_data.id,
                        delta_content=None,
                        finish_reason=None,
                        usage=usage,
                    )
                continue

            choice = chunk_data.choices[0]
            delta = choice.delta

            # Extract content
            content = delta.content if delta.content else None

            # Extract finish reason
            finish_reason = choice.finish_reason

            yield StreamChunk(
                id=chunk_data.id,
                delta_content=content,
                finish_reason=finish_reason,
                usage=usage,
            )

    async def chat_async(
        self,
        model: str,
        messages: List[ChatMessage],
        stream: bool = False,
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        tools: Optional[List[Dict[str, Any]]] = None,
        tool_choice: Optional[str] = None,
        **kwargs: Any,
    ) -> Union[ChatResponse, AsyncIterator[StreamChunk]]:
        """
        Send async chat request to OpenAI.

        Args:
            model: Model ID
            messages: Chat messages
            stream: Whether to stream
            max_tokens: Max tokens
            temperature: Temperature
            tools: Tool definitions
            tool_choice: Tool choice
            **kwargs: Additional parameters passed through to the API

        Returns:
            ChatResponse or AsyncIterator[StreamChunk]
        """
        # Resolve model alias
        model_id = MODEL_ALIASES.get(model, model)

        # Convert messages
        openai_messages = [msg.to_dict() for msg in messages]

        # Build params
        params: Dict[str, Any] = {
            "model": model_id,
            "messages": openai_messages,
            "stream": stream,
        }

        if max_tokens is not None:
            if "o1" in model_id:
                # o1 models reject max_tokens; they use max_completion_tokens
                params["max_completion_tokens"] = max_tokens
            else:
                params["max_tokens"] = max_tokens
        if temperature is not None and "o1" not in model_id:
            # o1 models don't support temperature
            params["temperature"] = temperature
        if tools:
            params["tools"] = tools
        if tool_choice:
            params["tool_choice"] = tool_choice
        if kwargs:
            params.update(kwargs)

        if stream:
            return self._stream_chat_async(params)
        else:
            completion = await self.async_client.chat.completions.create(**params)
            # Convert to ChatResponse (shared with the sync path)
            return self._convert_completion(completion)

    async def _stream_chat_async(self, params: Dict[str, Any]) -> AsyncIterator[StreamChunk]:
        """
        Stream async chat response.

        Args:
            params: Request parameters

        Yields:
            StreamChunk objects
        """
        stream = await self.async_client.chat.completions.create(**params)

        async for chunk in stream:
            if not chunk.choices:
                continue

            choice = chunk.choices[0]
            delta = choice.delta

            yield StreamChunk(
                id=chunk.id,
                delta_content=delta.content,
                finish_reason=choice.finish_reason,
            )
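
    # Async usage sketch (illustrative):
    #
    #     provider = OpenAIProvider(api_key="sk-...")
    #     stream = await provider.chat_async(
    #         model="gpt-4o-mini",
    #         messages=[ChatMessage(role="user", content="Hi!")],
    #         stream=True,
    #     )
    #     async for chunk in stream:
    #         print(chunk.delta_content or "", end="")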

    def _convert_completion(self, completion: ChatCompletion) -> ChatResponse:
        """Helper to convert OpenAI completion to ChatResponse."""
        choices = []
        for choice in completion.choices:
            tool_calls = None
            if choice.message.tool_calls:
                tool_calls = [
                    ToolCall(
                        id=tc.id,
                        type=tc.type,
                        function=ToolFunction(
                            name=tc.function.name,
                            arguments=tc.function.arguments,
                        ),
                    )
                    for tc in choice.message.tool_calls
                ]

            choices.append(
                ChatResponseChoice(
                    index=choice.index,
                    message=ChatMessage(
                        role=choice.message.role,
                        content=choice.message.content,
                        tool_calls=tool_calls,
                    ),
                    finish_reason=choice.finish_reason,
                )
            )

        usage = None
        if completion.usage:
            usage = UsageStats(
                prompt_tokens=completion.usage.prompt_tokens,
                completion_tokens=completion.usage.completion_tokens,
                total_tokens=completion.usage.total_tokens,
            )

        return ChatResponse(
            id=completion.id,
            choices=choices,
            usage=usage,
            model=completion.model,
            created=completion.created,
        )

    def get_credits(self) -> Optional[Dict[str, Any]]:
        """
        Get account credits.

        Returns:
            None (OpenAI doesn't provide a credits API)
        """
        return None

    def clear_cache(self) -> None:
        """Clear model cache."""
        self._models_cache = None

    def get_raw_models(self) -> List[Dict[str, Any]]:
        """
        Get raw model data as dictionaries.

        Returns:
            List of model dictionaries
        """
        models = self.list_models()
        return [
            {
                "id": model.id,
                "name": model.name,
                "description": model.description,
                "context_length": model.context_length,
                "pricing": model.pricing,
            }
            for model in models
        ]

    def get_raw_model(self, model_id: str) -> Optional[Dict[str, Any]]:
        """
        Get raw model data for a specific model.

        Args:
            model_id: Model identifier

        Returns:
            Model dictionary or None
        """
        model = self.get_model(model_id)
        if model:
            return {
                "id": model.id,
                "name": model.name,
                "description": model.description,
                "context_length": model.context_length,
                "pricing": model.pricing,
            }
        return None