Initial commit
This commit is contained in:
@@ -0,0 +1,50 @@
|
||||
"""
|
||||
tools/__init__.py — Tool registry factory.
|
||||
|
||||
Call build_registry() to get a ToolRegistry populated with all
|
||||
production tools. The agent loop calls this at startup.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
def build_registry(include_mock: bool = False, is_admin: bool = True):
    """
    Build and return a ToolRegistry with all production tools registered.

    Args:
        include_mock: If True, also register EchoTool and ConfirmTool (for testing).
        is_admin: If True, also register BashTool. Shell access is admin-only,
            so non-admin sessions never receive the bash tool.

    Returns:
        A ToolRegistry populated with all production tools (plus mock tools
        when requested).
    """
    from ..agent.tool_registry import ToolRegistry

    registry = ToolRegistry()

    # Production tools — each imported lazily to avoid errors if optional
    # dependencies are missing during development
    from .brain_tool import BrainTool
    from .caldav_tool import CalDAVTool
    from .email_tool import EmailTool
    from .filesystem_tool import FilesystemTool
    from .image_gen_tool import ImageGenTool
    from .pushover_tool import PushoverTool
    from .telegram_tool import TelegramTool
    from .web_tool import WebTool
    from .whitelist_tool import WhitelistTool

    # BashTool grants arbitrary command execution — registered only for admins.
    if is_admin:
        from .bash_tool import BashTool
        registry.register(BashTool())
    registry.register(BrainTool())
    registry.register(CalDAVTool())
    registry.register(EmailTool())
    registry.register(FilesystemTool())
    registry.register(ImageGenTool())
    registry.register(WebTool())
    registry.register(PushoverTool())
    registry.register(TelegramTool())
    registry.register(WhitelistTool())

    if include_mock:
        from .mock import ConfirmTool, EchoTool
        registry.register(EchoTool())
        registry.register(ConfirmTool())

    return registry
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,71 @@
|
||||
"""
|
||||
tools/base.py — BaseTool abstract class.
|
||||
|
||||
All tools inherit from this. The tool registry discovers them and builds
|
||||
the schema list sent to the AI provider on every agent call.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
class ToolResult:
    """Uniform result object produced by every tool's execute() call."""
    success: bool
    data: dict | list | str | None = None
    error: str | None = None

    def to_dict(self) -> dict:
        """Serialise to a plain dict for the AI provider / audit log."""
        if not self.success:
            # Failures always carry the error; data rides along only when present.
            payload: dict = {"success": False, "error": self.error}
            if self.data is not None:
                payload["data"] = self.data
            return payload
        return {"success": True, "data": self.data}
|
||||
|
||||
|
||||
class BaseTool(ABC):
    """
    Abstract base class every aide tool derives from.

    Required class attributes on subclasses:
        name         — identifier used in the tool schema and audit log
        description  — the text presented to the AI model
        input_schema — JSON Schema for the parameters (Anthropic-native format)

    Optional class-level overrides:
        requires_confirmation      — default False
        allowed_in_scheduled_tasks — default True
    """

    name: str
    description: str
    input_schema: dict

    requires_confirmation: bool = False
    allowed_in_scheduled_tasks: bool = True

    @abstractmethod
    async def execute(self, **kwargs) -> ToolResult:
        """
        Execute the tool. Implementations never raise — they always return a
        ToolResult. The dispatcher still guards against unexpected exceptions
        as a last-resort safety net.
        """

    def get_schema(self) -> dict:
        """Return the tool schema in aide-internal / Anthropic-native format."""
        return {
            "name": self.name,
            "description": self.description,
            "input_schema": self.input_schema,
        }

    def confirmation_description(self, **kwargs) -> str:
        """
        Human-readable summary of the pending action, shown to the user when
        confirmation is required. Subclasses may override for nicer messages.
        """
        rendered = (f"{key}={value!r}" for key, value in kwargs.items())
        return f"{self.name}({', '.join(rendered)})"
|
||||
@@ -0,0 +1,119 @@
|
||||
"""
|
||||
tools/bash_tool.py — Sandboxed bash command execution.
|
||||
|
||||
Runs shell commands in a working directory that must be within the
|
||||
filesystem whitelist. Captures stdout, stderr, and exit code.
|
||||
|
||||
Requires confirmation in interactive sessions. For scheduled tasks and
|
||||
agents, declare "bash" in allowed_tools to enable it without confirmation.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
|
||||
from ..context_vars import current_user as _current_user_var
|
||||
from ..security import SecurityError, assert_path_allowed
|
||||
from .base import BaseTool, ToolResult
|
||||
|
||||
DEFAULT_TIMEOUT = 30 # seconds
|
||||
MAX_TIMEOUT = 120 # seconds
|
||||
MAX_OUTPUT_BYTES = 50_000 # 50 KB per stream
|
||||
|
||||
|
||||
class BashTool(BaseTool):
    """
    Run shell commands in a working directory constrained by the filesystem
    whitelist.

    Defence-in-depth: even if this tool is somehow invoked for a non-admin
    session, execute() re-checks the current user's admin flag before running
    anything. Output is truncated to MAX_OUTPUT_BYTES per stream.
    """

    name = "bash"
    description = (
        "Execute a shell command in a sandboxed working directory. "
        "working_directory must be within the filesystem whitelist. "
        "Returns stdout, stderr, and exit_code. "
        "Use for running scripts, CLI tools, and automation tasks."
    )
    input_schema = {
        "type": "object",
        "properties": {
            "command": {
                "type": "string",
                "description": "Shell command to execute (run via /bin/bash -c)",
            },
            "working_directory": {
                "type": "string",
                "description": (
                    "Absolute path to run the command in. "
                    "Must be within the filesystem whitelist."
                ),
            },
            "timeout": {
                "type": "integer",
                "description": (
                    f"Timeout in seconds (default {DEFAULT_TIMEOUT}, max {MAX_TIMEOUT}). "
                    "The command is killed if it exceeds this limit."
                ),
            },
        },
        "required": ["command", "working_directory"],
    }
    requires_confirmation = True
    allowed_in_scheduled_tasks = True

    async def execute(
        self,
        command: str,
        working_directory: str,
        timeout: int = DEFAULT_TIMEOUT,
        **kwargs,
    ) -> ToolResult:
        """
        Run `command` with cwd `working_directory`.

        Returns:
            ToolResult whose data holds stdout, stderr, and exit_code;
            success is True only for a zero exit code.
        """
        # Defence-in-depth: reject non-admin users regardless of how the tool was invoked
        _user = _current_user_var.get()
        if _user is not None and not _user.is_admin:
            return ToolResult(success=False, error="Bash tool requires administrator privileges.")

        # Validate working directory against filesystem whitelist
        try:
            safe_cwd = await assert_path_allowed(working_directory)
        except SecurityError as e:
            return ToolResult(success=False, error=str(e))

        if not safe_cwd.is_dir():
            return ToolResult(
                success=False,
                error=f"Working directory does not exist: {working_directory}",
            )

        # Clamp to [1, MAX_TIMEOUT] so a bogus value can neither disable the
        # timeout nor exceed the hard ceiling.
        timeout = max(1, min(int(timeout), MAX_TIMEOUT))

        try:
            # Bug fix: the schema promises the command is "run via /bin/bash -c",
            # but create_subprocess_shell uses the platform default shell
            # (/bin/sh). Invoke bash explicitly so actual behavior matches the
            # documented contract (bashisms in commands now work reliably).
            proc = await asyncio.create_subprocess_exec(
                "/bin/bash",
                "-c",
                command,
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE,
                cwd=str(safe_cwd),
            )
            try:
                stdout_bytes, stderr_bytes = await asyncio.wait_for(
                    proc.communicate(), timeout=timeout
                )
            except asyncio.TimeoutError:
                proc.kill()
                await proc.communicate()  # reap the killed process
                return ToolResult(
                    success=False,
                    error=f"Command timed out after {timeout}s",
                )

            # Truncate each stream before decoding; replace undecodable bytes.
            stdout = stdout_bytes[:MAX_OUTPUT_BYTES].decode("utf-8", errors="replace")
            stderr = stderr_bytes[:MAX_OUTPUT_BYTES].decode("utf-8", errors="replace")
            exit_code = proc.returncode

            return ToolResult(
                success=exit_code == 0,
                data={"stdout": stdout, "stderr": stderr, "exit_code": exit_code},
                error=f"Exit code {exit_code}: {(stderr or stdout)[:500]}" if exit_code != 0 else None,
            )

        except Exception as e:
            return ToolResult(success=False, error=f"Failed to run command: {e}")

    def confirmation_description(
        self, command: str = "", working_directory: str = "", **kwargs
    ) -> str:
        """One-line summary shown to the user before they confirm execution."""
        return f"Run shell command in {working_directory}:\n{command}"
|
||||
@@ -0,0 +1,98 @@
|
||||
"""
|
||||
tools/bound_filesystem_tool.py — Filesystem tool pre-scoped to a single directory.
|
||||
|
||||
Used by email handling agents to read/write their memory and reasoning files.
|
||||
No filesystem_whitelist lookup — containment is enforced internally via realpath check.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
|
||||
from .base import BaseTool, ToolResult
|
||||
|
||||
|
||||
class BoundFilesystemTool(BaseTool):
    """
    Filesystem access confined to a single base directory.

    Unlike FilesystemTool there is no whitelist lookup — containment is
    enforced internally by realpath-resolving every path and checking it
    stays under the base folder.
    """

    name = "filesystem"
    requires_confirmation = False
    allowed_in_scheduled_tasks = True

    def __init__(self, base_path: str) -> None:
        # Canonicalise once so all later containment checks compare realpaths.
        self._base = os.path.realpath(base_path)
        self.description = (
            f"Read and write files inside your data folder ({self._base}). "
            "Operations: read_file, write_file, append_file, list_directory. "
            "All paths are relative to your data folder."
        )
        self.input_schema = {
            "type": "object",
            "properties": {
                "operation": {
                    "type": "string",
                    "enum": ["read_file", "write_file", "append_file", "list_directory"],
                    "description": "The operation to perform",
                },
                "path": {
                    "type": "string",
                    "description": (
                        "File or directory path, relative to your data folder "
                        f"(e.g. 'memory_work.md'). Absolute paths are also accepted "
                        f"if they start with {self._base}."
                    ),
                },
                "content": {
                    "type": "string",
                    "description": "Content to write or append (for write_file / append_file)",
                },
            },
            "required": ["operation", "path"],
        }

    def _resolve(self, path: str) -> str | None:
        """Canonicalise *path*; return None if it escapes the base folder."""
        candidate = path if os.path.isabs(path) else os.path.join(self._base, path)
        real = os.path.realpath(candidate)
        contained = real == self._base or real.startswith(self._base + os.sep)
        return real if contained else None

    async def execute(self, operation: str = "", path: str = "", content: str = "", **_) -> ToolResult:
        """Validate containment, then perform the requested operation."""
        resolved = self._resolve(path)
        if resolved is None:
            return ToolResult(success=False, error=f"Path '{path}' is outside the allowed folder.")

        try:
            if operation == "read_file":
                if not os.path.isfile(resolved):
                    return ToolResult(success=False, error=f"File not found: {path}")
                with open(resolved, encoding="utf-8") as fh:
                    body = fh.read()
                return ToolResult(success=True, data={"path": path, "content": body, "size": len(body)})

            if operation == "write_file":
                # Create intermediate directories so relative sub-paths work.
                os.makedirs(os.path.dirname(resolved), exist_ok=True)
                with open(resolved, "w", encoding="utf-8") as fh:
                    fh.write(content)
                return ToolResult(success=True, data={"path": path, "bytes_written": len(content.encode())})

            if operation == "append_file":
                os.makedirs(os.path.dirname(resolved), exist_ok=True)
                with open(resolved, "a", encoding="utf-8") as fh:
                    fh.write(content)
                return ToolResult(success=True, data={"path": path, "bytes_appended": len(content.encode())})

            if operation == "list_directory":
                # Non-directory targets fall back to listing the base folder.
                folder = resolved if os.path.isdir(resolved) else self._base
                listing = []
                for entry in sorted(os.listdir(folder)):
                    full_path = os.path.join(folder, entry)
                    listing.append({
                        "name": entry,
                        "type": "dir" if os.path.isdir(full_path) else "file",
                        "size": os.path.getsize(full_path) if os.path.isfile(full_path) else None,
                    })
                return ToolResult(success=True, data={"path": path, "entries": listing})

            return ToolResult(success=False, error=f"Unknown operation: {operation!r}")

        except OSError as e:
            return ToolResult(success=False, error=f"Filesystem error: {e}")
|
||||
@@ -0,0 +1,129 @@
|
||||
"""
|
||||
tools/brain_tool.py — 2nd Brain tool for Jarvis.
|
||||
|
||||
Gives Jarvis (and any agent with brain access) four operations:
|
||||
capture — save a thought to the brain
|
||||
search — retrieve thoughts by semantic similarity
|
||||
browse — list recent thoughts (optional type filter)
|
||||
stats — database statistics
|
||||
|
||||
Capture is the only write operation and requires no confirmation (it's
|
||||
non-destructive and the user expects it to be instant).
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from ..context_vars import current_user as _current_user_var
|
||||
from .base import BaseTool, ToolResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BrainTool(BaseTool):
    """
    Tool interface to the 2nd Brain knowledge base.

    Supports capture (write), plus the read operations search, browse, and
    stats. All operations are namespaced to the current user when one is
    present in the request context.
    """

    name = "brain"
    description = (
        "Access the 2nd Brain knowledge base. "
        "Operations: capture (save a thought), search (semantic search by meaning), "
        "browse (recent thoughts), stats (database overview). "
        "Use 'capture' to save anything worth remembering. "
        "Use 'search' to find relevant past thoughts by meaning, not just keywords."
    )
    input_schema = {
        "type": "object",
        "properties": {
            "operation": {
                "type": "string",
                "enum": ["capture", "search", "browse", "stats"],
                "description": "Operation to perform.",
            },
            "content": {
                "type": "string",
                "description": "Text to capture (required for 'capture').",
            },
            "query": {
                "type": "string",
                "description": "Search query (required for 'search').",
            },
            "threshold": {
                "type": "number",
                "description": "Similarity threshold 0-1 for 'search' (default 0.7). Lower = broader results.",
                "default": 0.7,
            },
            "limit": {
                "type": "integer",
                "description": "Max results for 'search' or 'browse' (default 10).",
                "default": 10,
            },
            "type_filter": {
                "type": "string",
                "description": "Filter 'browse' by thought type: insight, person_note, task, reference, idea, other.",
            },
        },
        "required": ["operation"],
    }

    requires_confirmation = False
    allowed_in_scheduled_tasks = True

    async def execute(self, **kwargs) -> ToolResult:
        """Route to the requested brain operation, scoped to the current user."""
        operation = kwargs.get("operation")
        # Resolve current user for brain namespace scoping (3-G)
        _user = _current_user_var.get()
        user_id = _user.id if _user else None

        try:
            # Bail out early if the brain database was never configured.
            from ..brain.database import get_pool
            if get_pool() is None:
                return ToolResult(
                    success=False,
                    error="Brain DB is not available. Check BRAIN_DB_URL in .env.",
                )

            if operation == "capture":
                return await self._capture(kwargs.get("content", ""), user_id=user_id)
            if operation == "search":
                return await self._search(
                    kwargs.get("query", ""),
                    float(kwargs.get("threshold", 0.7)),
                    int(kwargs.get("limit", 10)),
                    user_id=user_id,
                )
            if operation == "browse":
                return await self._browse(
                    int(kwargs.get("limit", 10)),
                    kwargs.get("type_filter"),
                    user_id=user_id,
                )
            if operation == "stats":
                return await self._stats(user_id=user_id)
            return ToolResult(success=False, error=f"Unknown operation: {operation}")

        except Exception as e:
            logger.error("BrainTool error (%s): %s", operation, e)
            return ToolResult(success=False, error=str(e))

    async def _capture(self, content: str, user_id: str | None = None) -> ToolResult:
        """Save one thought; rejects empty/whitespace-only content."""
        if not content.strip():
            return ToolResult(success=False, error="content is required for capture")
        from ..brain.ingest import ingest_thought
        outcome = await ingest_thought(content, user_id=user_id)
        return ToolResult(success=True, data=outcome)

    async def _search(self, query: str, threshold: float, limit: int, user_id: str | None = None) -> ToolResult:
        """Semantic similarity search; rejects empty queries."""
        if not query.strip():
            return ToolResult(success=False, error="query is required for search")
        from ..brain.search import semantic_search
        hits = await semantic_search(query, threshold=threshold, limit=limit, user_id=user_id)
        return ToolResult(success=True, data={"results": hits, "count": len(hits)})

    async def _browse(self, limit: int, type_filter: str | None, user_id: str | None = None) -> ToolResult:
        """List recent thoughts, optionally filtered by thought type."""
        from ..brain.database import browse_thoughts
        hits = await browse_thoughts(limit=limit, type_filter=type_filter, user_id=user_id)
        return ToolResult(success=True, data={"results": hits, "count": len(hits)})

    async def _stats(self, user_id: str | None = None) -> ToolResult:
        """Return database statistics for the user's namespace."""
        from ..brain.database import get_stats
        stats = await get_stats(user_id=user_id)
        return ToolResult(success=True, data=stats)
|
||||
@@ -0,0 +1,477 @@
|
||||
"""
|
||||
tools/caldav_tool.py — CalDAV calendar access (Mailcow / SOGo).
|
||||
|
||||
Credential keys (set via /settings):
|
||||
mailcow_host — e.g. mail.yourdomain.com
|
||||
mailcow_username — e.g. you@yourdomain.com
|
||||
mailcow_password — account or app password
|
||||
caldav_calendar_name — optional display-name filter; if omitted uses first calendar found
|
||||
|
||||
Uses principal discovery (/SOGo/dav/<username>/) to find calendars automatically.
|
||||
No hardcoded URL path — works regardless of internal calendar slug.
|
||||
|
||||
All datetimes are stored as UTC internally. Display times are converted
|
||||
to the configured timezone (Europe/Oslo by default).
|
||||
|
||||
create/update/delete require user confirmation.
|
||||
Max events returned: 100.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import traceback
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any
|
||||
|
||||
import caldav
|
||||
import vobject
|
||||
from dateutil import parser as dateutil_parser
|
||||
|
||||
from ..config import settings
|
||||
from ..context_vars import current_user
|
||||
from ..database import credential_store
|
||||
from .base import BaseTool, ToolResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def _get_caldav_config(user_id: str | None = None) -> dict:
    """
    Resolve CalDAV settings with a two-layer lookup: per-user settings first,
    then the global credential_store as fallback.

    user_settings keys: caldav_url, caldav_username, caldav_password, caldav_calendar_name
    credential_store keys: mailcow_host, mailcow_username, mailcow_password, caldav_calendar_name

    Returns a dict with url, username, password, calendar_name (any may be None).
    """
    if user_id:
        from ..database import user_settings_store
        url = await user_settings_store.get(user_id, "caldav_url")
        if url:
            # A per-user URL takes precedence over everything global.
            return {
                "url": url,
                "username": await user_settings_store.get(user_id, "caldav_username"),
                "password": await user_settings_store.get(user_id, "caldav_password"),
                "calendar_name": await user_settings_store.get(user_id, "caldav_calendar_name"),
            }

    # No per-user URL configured — fall back to the global credential_store.
    host = await credential_store.get("mailcow_host")
    return {
        "url": f"https://{host}/SOGo/dav/" if host else None,
        "username": await credential_store.get("mailcow_username"),
        "password": await credential_store.get("mailcow_password"),
        "calendar_name": await credential_store.get("caldav_calendar_name"),
    }
|
||||
|
||||
MAX_EVENTS = 100
|
||||
|
||||
|
||||
class CalDAVTool(BaseTool):
|
||||
name = "caldav"
|
||||
description = (
|
||||
"Manage calendar events via CalDAV (Mailcow/SOGo). "
|
||||
"Operations: list_events, get_event, create_event, update_event, delete_event. "
|
||||
"create_event, update_event, and delete_event require user confirmation. "
|
||||
"All dates should be in ISO8601 format (e.g. '2026-02-17T14:00:00+01:00')."
|
||||
)
|
||||
input_schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"operation": {
|
||||
"type": "string",
|
||||
"enum": ["list_events", "get_event", "create_event", "update_event", "delete_event"],
|
||||
},
|
||||
"start_date": {
|
||||
"type": "string",
|
||||
"description": "Start of date range for list_events (ISO8601)",
|
||||
},
|
||||
"end_date": {
|
||||
"type": "string",
|
||||
"description": "End of date range for list_events (ISO8601). Default: 30 days after start",
|
||||
},
|
||||
"event_id": {
|
||||
"type": "string",
|
||||
"description": "Event UID for get/update/delete operations",
|
||||
},
|
||||
"summary": {
|
||||
"type": "string",
|
||||
"description": "Event title for create/update",
|
||||
},
|
||||
"start": {
|
||||
"type": "string",
|
||||
"description": "Event start datetime (ISO8601) for create/update",
|
||||
},
|
||||
"end": {
|
||||
"type": "string",
|
||||
"description": "Event end datetime (ISO8601) for create/update",
|
||||
},
|
||||
"description": {
|
||||
"type": "string",
|
||||
"description": "Event description/notes for create/update",
|
||||
},
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "Event location for create/update",
|
||||
},
|
||||
},
|
||||
"required": ["operation"],
|
||||
}
|
||||
requires_confirmation = False # checked per-operation in execute()
|
||||
allowed_in_scheduled_tasks = True
|
||||
|
||||
async def _get_client(self) -> tuple[caldav.DAVClient, caldav.Calendar]:
|
||||
"""Return (client, calendar). Raises RuntimeError if credentials missing."""
|
||||
# Resolve current user from context (may be None for scheduled/agent runs)
|
||||
user = current_user.get()
|
||||
user_id = user.id if user else None
|
||||
|
||||
cfg = await _get_caldav_config(user_id=user_id)
|
||||
url = cfg.get("url")
|
||||
username = cfg.get("username")
|
||||
password = cfg.get("password")
|
||||
calendar_name = cfg.get("calendar_name") or ""
|
||||
|
||||
if not url or not username or not password:
|
||||
raise RuntimeError(
|
||||
"CalDAV credentials not configured. "
|
||||
"Set them in Settings → My Settings → CalDAV, or ask the admin to configure global CalDAV."
|
||||
)
|
||||
|
||||
# Build principal URL: if the stored URL is already the full principal URL use it directly;
|
||||
# otherwise append the SOGo-style path (backward compat with old mailcow_host keys).
|
||||
if "/SOGo/dav/" in url or url.rstrip("/").endswith(username):
|
||||
principal_url = url.rstrip("/") + "/"
|
||||
else:
|
||||
principal_url = f"{url.rstrip('/')}/SOGo/dav/{username}/"
|
||||
if calendar_name:
|
||||
logger.info("[caldav] Connecting — principal_url=%s username=%s calendar_filter=%r",
|
||||
principal_url, username, calendar_name)
|
||||
else:
|
||||
logger.info("[caldav] Connecting — principal_url=%s username=%s "
|
||||
"calendar_filter=(none set — will use first found; "
|
||||
"set 'caldav_calendar_name' credential to pick a specific one)",
|
||||
principal_url, username)
|
||||
|
||||
client = caldav.DAVClient(url=principal_url, username=username, password=password)
|
||||
|
||||
logger.debug("[caldav] Fetching principal…")
|
||||
principal = client.principal()
|
||||
logger.debug("[caldav] Principal URL: %s", principal.url)
|
||||
|
||||
logger.debug("[caldav] Discovering calendars…")
|
||||
calendars = principal.calendars()
|
||||
if not calendars:
|
||||
logger.error("[caldav] No calendars found for %s", username)
|
||||
raise RuntimeError("No calendars found for this account")
|
||||
|
||||
logger.info("[caldav] Found %d calendar(s): %s",
|
||||
len(calendars),
|
||||
", ".join(f"{c.name!r} ({c.url})" for c in calendars))
|
||||
|
||||
if calendar_name:
|
||||
needle = calendar_name.lower()
|
||||
# Exact match first, then substring fallback
|
||||
match = next((c for c in calendars if (c.name or "").lower() == needle), None)
|
||||
if match is None:
|
||||
match = next((c for c in calendars if needle in (c.name or "").lower()), None)
|
||||
if match is None:
|
||||
names = ", ".join(c.name or "?" for c in calendars)
|
||||
logger.error("[caldav] Calendar %r not found. Available: %s", calendar_name, names)
|
||||
raise RuntimeError(
|
||||
f"Calendar '{calendar_name}' not found. Available: {names}"
|
||||
)
|
||||
logger.info("[caldav] Using calendar %r url=%s", match.name, match.url)
|
||||
return client, match
|
||||
|
||||
chosen = calendars[0]
|
||||
logger.info("[caldav] Using first calendar: %r url=%s", chosen.name, chosen.url)
|
||||
return client, chosen
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
operation: str,
|
||||
start_date: str = "",
|
||||
end_date: str = "",
|
||||
event_id: str = "",
|
||||
summary: str = "",
|
||||
start: str = "",
|
||||
end: str = "",
|
||||
description: str = "",
|
||||
location: str = "",
|
||||
**kwargs,
|
||||
) -> ToolResult:
|
||||
logger.info("[caldav] execute operation=%s summary=%r start=%r end=%r event_id=%r",
|
||||
operation, summary, start, end, event_id)
|
||||
|
||||
try:
|
||||
_, calendar = await self._get_client()
|
||||
except RuntimeError as e:
|
||||
logger.error("[caldav] Connection/credential error: %s", e)
|
||||
return ToolResult(success=False, error=str(e))
|
||||
except Exception as e:
|
||||
logger.error("[caldav] Unexpected connection error: %s\n%s", e, traceback.format_exc())
|
||||
return ToolResult(success=False, error=f"CalDAV connection error: {e}")
|
||||
|
||||
if operation == "list_events":
|
||||
return self._list_events(calendar, start_date, end_date)
|
||||
if operation == "get_event":
|
||||
if not event_id:
|
||||
return ToolResult(success=False, error="event_id is required for get_event")
|
||||
return self._get_event(calendar, event_id)
|
||||
if operation == "create_event":
|
||||
if not (summary and start and end):
|
||||
return ToolResult(success=False, error="summary, start, and end are required for create_event")
|
||||
return self._create_event(calendar, summary, start, end, description, location)
|
||||
if operation == "update_event":
|
||||
if not event_id:
|
||||
return ToolResult(success=False, error="event_id is required for update_event")
|
||||
return self._update_event(calendar, event_id, summary, start, end, description, location)
|
||||
if operation == "delete_event":
|
||||
if not event_id:
|
||||
return ToolResult(success=False, error="event_id is required for delete_event")
|
||||
return self._delete_event(calendar, event_id)
|
||||
|
||||
logger.warning("[caldav] Unknown operation: %r", operation)
|
||||
return ToolResult(success=False, error=f"Unknown operation: {operation!r}")
|
||||
|
||||
# ── Read operations ───────────────────────────────────────────────────────
|
||||
|
||||
def _list_events(self, calendar: caldav.Calendar, start_date: str, end_date: str) -> ToolResult:
|
||||
try:
|
||||
if start_date:
|
||||
start_dt = dateutil_parser.parse(start_date).replace(tzinfo=timezone.utc)
|
||||
else:
|
||||
start_dt = datetime.now(timezone.utc)
|
||||
|
||||
if end_date:
|
||||
end_dt = dateutil_parser.parse(end_date).replace(tzinfo=timezone.utc)
|
||||
else:
|
||||
end_dt = start_dt + timedelta(days=30)
|
||||
|
||||
logger.info("[caldav] list_events range=%s → %s calendar_url=%s",
|
||||
start_dt.isoformat(), end_dt.isoformat(), calendar.url)
|
||||
|
||||
events = calendar.date_search(start=start_dt, end=end_dt, expand=True)
|
||||
events = events[:MAX_EVENTS]
|
||||
|
||||
logger.info("[caldav] list_events returned %d event(s)", len(events))
|
||||
|
||||
result = []
|
||||
for event in events:
|
||||
summary = _get_property(event, "summary", "No title")
|
||||
ev_start = _get_dt_str(event, "dtstart")
|
||||
ev_end = _get_dt_str(event, "dtend")
|
||||
logger.debug("[caldav] event: %r start=%s end=%s", summary, ev_start, ev_end)
|
||||
result.append({
|
||||
"id": _get_uid(event),
|
||||
"summary": summary,
|
||||
"start": ev_start,
|
||||
"end": ev_end,
|
||||
"location": _get_property(event, "location", ""),
|
||||
"description_preview": _get_property(event, "description", "")[:100],
|
||||
})
|
||||
|
||||
return ToolResult(
|
||||
success=True,
|
||||
data={"events": result, "count": len(result)},
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("[caldav] list_events failed: %s\n%s", e, traceback.format_exc())
|
||||
return ToolResult(success=False, error=f"CalDAV list error: {e}")
|
||||
|
||||
def _get_event(self, calendar: caldav.Calendar, event_id: str) -> ToolResult:
|
||||
logger.info("[caldav] get_event event_id=%s", event_id)
|
||||
try:
|
||||
event = _find_event(calendar, event_id)
|
||||
if event is None:
|
||||
logger.warning("[caldav] get_event: event not found: %s", event_id)
|
||||
return ToolResult(success=False, error=f"Event not found: {event_id}")
|
||||
|
||||
data = {
|
||||
"id": _get_uid(event),
|
||||
"summary": _get_property(event, "summary", ""),
|
||||
"start": _get_dt_str(event, "dtstart"),
|
||||
"end": _get_dt_str(event, "dtend"),
|
||||
"location": _get_property(event, "location", ""),
|
||||
"description": _get_property(event, "description", ""),
|
||||
}
|
||||
logger.info("[caldav] get_event found: %r start=%s", data["summary"], data["start"])
|
||||
return ToolResult(success=True, data=data)
|
||||
except Exception as e:
|
||||
logger.error("[caldav] get_event failed: %s\n%s", e, traceback.format_exc())
|
||||
return ToolResult(success=False, error=f"CalDAV get error: {e}")
|
||||
|
||||
# ── Write operations ──────────────────────────────────────────────────────
|
||||
|
||||
def _create_event(
|
||||
self,
|
||||
calendar: caldav.Calendar,
|
||||
summary: str,
|
||||
start: str,
|
||||
end: str,
|
||||
description: str,
|
||||
location: str,
|
||||
) -> ToolResult:
|
||||
import uuid
|
||||
logger.info("[caldav] create_event summary=%r start=%s end=%s location=%r calendar_url=%s",
|
||||
summary, start, end, location, calendar.url)
|
||||
try:
|
||||
start_dt = dateutil_parser.parse(start)
|
||||
end_dt = dateutil_parser.parse(end)
|
||||
logger.debug("[caldav] create_event parsed start_dt=%s end_dt=%s", start_dt, end_dt)
|
||||
|
||||
uid = str(uuid.uuid4())
|
||||
ical = _build_ical(uid, summary, start_dt, end_dt, description, location)
|
||||
logger.debug("[caldav] create_event ical payload:\n%s", ical)
|
||||
|
||||
calendar.add_event(ical)
|
||||
logger.info("[caldav] create_event success uid=%s", uid)
|
||||
|
||||
return ToolResult(
|
||||
success=True,
|
||||
data={"created": True, "uid": uid, "summary": summary},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("[caldav] create_event failed: %s\n%s", e, traceback.format_exc())
|
||||
return ToolResult(success=False, error=f"CalDAV create error: {e}")
|
||||
|
||||
def _update_event(
    self,
    calendar: caldav.Calendar,
    event_id: str,
    summary: str,
    start: str,
    end: str,
    description: str,
    location: str,
) -> ToolResult:
    """Update selected fields of an existing event.

    Empty-string arguments leave the corresponding field untouched, so the
    caller only supplies what should change.
    """
    logger.info(
        "[caldav] update_event event_id=%s summary=%r start=%s end=%s",
        event_id, summary, start, end,
    )
    try:
        event = _find_event(calendar, event_id)
        if event is None:
            logger.warning("[caldav] update_event: event not found: %s", event_id)
            return ToolResult(success=False, error=f"Event not found: {event_id}")

        vevent = event.vobject_instance.vevent

        def set_text(prop: str, value: str) -> None:
            # Overwrite the property if it already exists, otherwise add it.
            if hasattr(vevent, prop):
                getattr(vevent, prop).value = value
            else:
                vevent.add(prop).value = value

        if summary:
            vevent.summary.value = summary
        if start:
            vevent.dtstart.value = dateutil_parser.parse(start)
        if end:
            vevent.dtend.value = dateutil_parser.parse(end)
        if description:
            set_text("description", description)
        if location:
            set_text("location", location)

        event.save()
        logger.info("[caldav] update_event success uid=%s", event_id)
        return ToolResult(success=True, data={"updated": True, "uid": event_id})

    except Exception as e:
        logger.error("[caldav] update_event failed: %s\n%s", e, traceback.format_exc())
        return ToolResult(success=False, error=f"CalDAV update error: {e}")
|
||||
|
||||
def _delete_event(self, calendar: caldav.Calendar, event_id: str) -> ToolResult:
    """Permanently remove the event with the given UID from *calendar*."""
    logger.info("[caldav] delete_event event_id=%s", event_id)
    try:
        target = _find_event(calendar, event_id)
        if target is None:
            logger.warning("[caldav] delete_event: event not found: %s", event_id)
            return ToolResult(success=False, error=f"Event not found: {event_id}")
        target.delete()
        logger.info("[caldav] delete_event success uid=%s", event_id)
        return ToolResult(success=True, data={"deleted": True, "uid": event_id})
    except Exception as e:
        logger.error("[caldav] delete_event failed: %s\n%s", e, traceback.format_exc())
        return ToolResult(success=False, error=f"CalDAV delete error: {e}")
|
||||
|
||||
def confirmation_description(self, operation: str = "", summary: str = "", event_id: str = "", **kwargs) -> str:
    """Human-readable one-liner shown to the user before a write operation runs."""
    builders = {
        "create_event": lambda: f"Create calendar event: '{summary}' at {kwargs.get('start', '')}",
        "update_event": lambda: f"Update calendar event: {event_id}" + (f" → '{summary}'" if summary else ""),
        "delete_event": lambda: f"Permanently delete calendar event: {event_id}",
    }
    build = builders.get(operation)
    if build is not None:
        return build()
    # Unknown operation — fall back to whatever identifier we have.
    return f"{operation}: {event_id or summary}"
|
||||
|
||||
|
||||
# ── Helpers ───────────────────────────────────────────────────────────────────
|
||||
|
||||
def _get_uid(event: caldav.Event) -> str:
    """Return the event's UID, or "" when the underlying vobject is unreadable."""
    # Same lookup rule as every other text property — delegate instead of
    # duplicating the try/except.
    return _get_property(event, "uid", "")


def _get_property(event: caldav.Event, prop: str, default: str = "") -> str:
    """Return a VEVENT text property as str, or *default* when absent/unreadable."""
    try:
        return str(getattr(event.vobject_instance.vevent, prop).value)
    except Exception:
        return default


def _get_dt_str(event: caldav.Event, prop: str) -> str:
    """Return a VEVENT date property as text ("" when absent).

    DTSTART/DTEND may hold a datetime (timed event) or a plain date (all-day
    event); only datetimes get .isoformat(), everything else is str()-ified.
    """
    try:
        val = getattr(event.vobject_instance.vevent, prop).value
        if isinstance(val, datetime):
            return val.isoformat()
        return str(val)
    except Exception:
        return ""


def _find_event(calendar: caldav.Calendar, uid: str) -> caldav.Event | None:
    """Find an event by UID. Returns None if not found.

    Fast path: ask the server directly via Calendar.event_by_uid (a single
    REPORT query). Fallback: linear scan over calendar.events(), which the
    original code used exclusively despite its "direct URL lookup" comment —
    some older/non-compliant servers still need the scan.
    """
    try:
        event = calendar.event_by_uid(uid)
        if event is not None:
            return event
    except Exception:
        # Not found, server rejected the REPORT, or the method is unavailable
        # on this calendar object — fall back to scanning.
        pass
    try:
        for event in calendar.events():
            if _get_uid(event) == uid:
                return event
    except Exception:
        pass
    return None
|
||||
|
||||
|
||||
def _build_ical(
    uid: str,
    summary: str,
    start: datetime,
    end: datetime,
    description: str,
    location: str,
) -> str:
    """Render a minimal RFC 5545 VCALENDAR containing one VEVENT.

    Returns CRLF-joined iCalendar text. Times are written as floating local
    times tagged with a fixed TZID.

    Fix: SUMMARY/LOCATION were previously emitted unescaped and DESCRIPTION
    only had newlines escaped — a comma, semicolon or backslash in any of
    them produced invalid iCalendar. All TEXT values now get full RFC 5545
    §3.3.11 escaping.
    """

    def esc(text: str) -> str:
        # Backslash must be escaped first, then structural characters and
        # newlines (CRLF collapses to a single \n escape).
        return (
            text.replace("\\", "\\\\")
            .replace(";", "\\;")
            .replace(",", "\\,")
            .replace("\r\n", "\\n")
            .replace("\n", "\\n")
        )

    now = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
    start_str = start.strftime("%Y%m%dT%H%M%S")
    end_str = end.strftime("%Y%m%dT%H%M%S")
    tz = "Europe/Oslo"  # NOTE(review): hard-coded zone; presumably matches the deployment — confirm

    lines = [
        "BEGIN:VCALENDAR",
        "VERSION:2.0",
        "PRODID:-//aide//aide//EN",
        "BEGIN:VEVENT",
        f"UID:{uid}",
        f"DTSTAMP:{now}",
        f"DTSTART;TZID={tz}:{start_str}",
        f"DTEND;TZID={tz}:{end_str}",
        f"SUMMARY:{esc(summary)}",
    ]
    if description:
        lines.append(f"DESCRIPTION:{esc(description)}")
    if location:
        lines.append(f"LOCATION:{esc(location)}")
    lines += ["END:VEVENT", "END:VCALENDAR"]
    return "\r\n".join(lines)
|
||||
@@ -0,0 +1,383 @@
|
||||
"""
|
||||
tools/email_handling_tool.py — Read/organise email tool for handling accounts.
|
||||
|
||||
NOT in the global tool registry.
|
||||
Instantiated at dispatch time with decrypted account credentials.
|
||||
Passed as extra_tools to agent.run() with force_only_extra_tools=True.
|
||||
|
||||
Deliberately excludes: send, reply, forward, create draft, delete email,
and any SMTP operation. (move_email does internally copy the message and
then expunge the source folder — that is the only place expunge occurs.)
|
||||
|
||||
Uses imaplib (stdlib) via asyncio.to_thread — avoids aioimaplib auth issues
|
||||
with some Dovecot configurations.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import email as email_lib
|
||||
import imaplib
|
||||
import logging
|
||||
import re
|
||||
|
||||
from .base import BaseTool, ToolResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EmailHandlingTool(BaseTool):
    """Read-and-organise IMAP tool bound to a single handling account.

    Instances are created per dispatch with decrypted credentials; the schema
    below is what the agent sees. Sending/deleting is intentionally absent.
    """

    name = "email_handling"
    # Placeholder only — __init__ rebuilds the description with the concrete
    # folder restriction for the bound account.
    description = "Read, organise and manage emails within the configured folders."

    input_schema = {
        "type": "object",
        "properties": {
            "operation": {
                "type": "string",
                "enum": [
                    "list_emails",
                    "read_email",
                    "mark_email",
                    "move_email",
                    "list_folders",
                    "create_folder",
                ],
                "description": "The operation to perform",
            },
            "folder": {"type": "string", "description": "IMAP folder name (default: INBOX)"},
            "uid": {"type": "string", "description": "Email UID for read_email, mark_email, move_email"},
            "limit": {"type": "integer", "description": "Max emails to return for list_emails (default: 20, max: 100)"},
            "unread_only": {"type": "boolean", "description": "Only list unread emails (for list_emails)"},
            "search": {"type": "string", "description": "IMAP SEARCH criteria string (for list_emails)"},
            "flag": {
                "type": "string",
                "enum": ["read", "unread", "flagged", "unflagged", "spam"],
                "description": "Flag action for mark_email",
            },
            "source_folder": {"type": "string", "description": "Source folder for move_email"},
            "target_folder": {"type": "string", "description": "Target folder for move_email or parent for create_folder"},
            "name": {"type": "string", "description": "New folder name for create_folder"},
        },
        "required": ["operation"],
    }

    requires_confirmation = False       # read/organise only — nothing destructive
    allowed_in_scheduled_tasks = True   # safe for unattended runs
|
||||
|
||||
def __init__(self, account: dict) -> None:
    """Bind the tool to one account dict with decrypted imap_host/port/username/password.

    The folder allow-list is derived from account["monitored_folders"]:
      * missing/None   -> no restriction
      * non-empty list -> exactly those folders (plus subfolders)
      * empty list     -> no restriction
      * anything else  -> conservative fallback to INBOX only
    """
    self._host = account["imap_host"]
    self._port = int(account.get("imap_port") or 993)
    self._username = account["imap_username"]
    self._password = account["imap_password"]

    raw = account.get("monitored_folders")
    if raw is None:
        self._allowed_folders: list[str] | None = None  # None = no folder restriction
    elif isinstance(raw, list):
        self._allowed_folders = raw or None
    else:
        self._allowed_folders = ["INBOX"]

    if self._allowed_folders:
        listing = ", ".join(repr(f) for f in self._allowed_folders)
        restriction = (
            f"You may ONLY access these folders (and their subfolders): {listing}. "
            "Any attempt to read from or move to a folder outside this list will be rejected."
        )
    else:
        restriction = "You may access all folders."

    self.description = (
        f"Read, organise and manage emails. "
        f"Operations: list_emails, read_email, mark_email, move_email, list_folders, create_folder. "
        f"Cannot send, delete or permanently expunge emails. "
        f"{restriction}"
    )
|
||||
|
||||
def _check_folder(self, folder: str) -> str | None:
    """Return an error message if *folder* is outside the allow-list, else None.

    A folder is allowed when it matches an allow-list entry exactly or is a
    subfolder of one ("Entry/…").
    """
    allowed = self._allowed_folders
    if allowed is None:
        return None  # unrestricted account
    if any(folder == a or folder.startswith(a.rstrip("/") + "/") for a in allowed):
        return None
    listing = ", ".join(repr(f) for f in allowed)
    return f"Folder {folder!r} is outside the allowed folders for this account: {listing}"
|
||||
|
||||
def _open(self) -> imaplib.IMAP4_SSL:
    """Open and authenticate a synchronous SSL IMAP connection.

    The caller is responsible for logging out (see _close); every _sync_*
    worker wraps its use in try/finally for that reason.
    """
    conn = imaplib.IMAP4_SSL(self._host, self._port)
    conn.login(self._username, self._password)
    return conn
|
||||
|
||||
async def execute(self, operation: str = "", **kwargs) -> ToolResult:
|
||||
try:
|
||||
if operation == "list_emails":
|
||||
return await self._list_emails(**kwargs)
|
||||
elif operation == "read_email":
|
||||
return await self._read_email(**kwargs)
|
||||
elif operation == "mark_email":
|
||||
return await self._mark_email(**kwargs)
|
||||
elif operation == "move_email":
|
||||
return await self._move_email(**kwargs)
|
||||
elif operation == "list_folders":
|
||||
return await self._list_folders()
|
||||
elif operation == "create_folder":
|
||||
return await self._create_folder(**kwargs)
|
||||
else:
|
||||
return ToolResult(success=False, error=f"Unknown operation: {operation!r}")
|
||||
except Exception as e:
|
||||
logger.error("[email_handling] %s error: %s", operation, e)
|
||||
return ToolResult(success=False, error=str(e))
|
||||
|
||||
# ── Operations (run blocking imaplib calls in a thread) ───────────────────
|
||||
|
||||
async def _list_emails(
    self,
    folder: str = "INBOX",
    limit: int = 20,
    unread_only: bool = False,
    search: str = "",
    **_,
) -> ToolResult:
    """List message headers in *folder*, newest first.

    Fix: *limit* is now clamped to 1..100. Previously only the upper bound
    was enforced, so limit=0 made the worker slice `nums[-0:]` — i.e. it
    returned the ENTIRE mailbox — and negative values produced nonsense
    slices.
    """
    err = self._check_folder(folder)
    if err:
        return ToolResult(success=False, error=err)
    limit = max(1, min(int(limit), 100))
    # imaplib is blocking — run the worker off the event loop.
    return await asyncio.to_thread(
        self._sync_list_emails, folder, limit, unread_only, search
    )
|
||||
|
||||
def _sync_list_emails(
    self, folder: str, limit: int, unread_only: bool, search: str
) -> ToolResult:
    """Blocking worker for list_emails; runs inside asyncio.to_thread."""
    conn = self._open()
    try:
        conn.select(folder, readonly=True)
        # Explicit search criteria win; otherwise fall back to UNSEEN/ALL.
        criteria = search or ("UNSEEN" if unread_only else "ALL")
        typ, data = conn.search(None, criteria)
        if typ != "OK" or not data or not data[0]:
            return ToolResult(success=True, data={"emails": [], "count": 0, "folder": folder})

        # Keep the last *limit* sequence numbers, most recent first.
        seqs = data[0].split()[-limit:][::-1]

        results = []
        for seq in seqs:
            typ2, msg_data = conn.fetch(
                seq, "(FLAGS BODY.PEEK[HEADER.FIELDS (FROM TO SUBJECT DATE)])"
            )
            if typ2 != "OK" or not msg_data or not msg_data[0]:
                continue
            first = msg_data[0]
            flags_str = str(first[0]) if isinstance(first, tuple) else str(first)
            header_bytes = first[1] if isinstance(first, tuple) else b""
            headers = email_lib.message_from_bytes(header_bytes)
            results.append({
                # NOTE(review): these are IMAP sequence numbers, not true UIDs —
                # consistent with the other operations in this tool, but racy
                # if the mailbox changes between calls; confirm before renaming.
                "uid": seq.decode() if isinstance(seq, bytes) else str(seq),
                "from": headers.get("From", ""),
                "to": headers.get("To", ""),
                "subject": headers.get("Subject", ""),
                "date": headers.get("Date", ""),
                "unread": "\\Seen" not in flags_str,
            })

        return ToolResult(success=True, data={"emails": results, "count": len(results), "folder": folder})
    finally:
        _close(conn)
|
||||
|
||||
async def _read_email(self, uid: str = "", folder: str = "INBOX", **_) -> ToolResult:
    """Fetch one full message (headers + body) from an allowed folder."""
    if not uid:
        return ToolResult(success=False, error="uid is required")
    folder_error = self._check_folder(folder)
    if folder_error:
        return ToolResult(success=False, error=folder_error)
    # imaplib is blocking — run the worker off the event loop.
    return await asyncio.to_thread(self._sync_read_email, uid, folder)
|
||||
|
||||
def _sync_read_email(self, uid: str, folder: str) -> ToolResult:
    """Blocking worker for read_email; BODY.PEEK keeps the message unread."""
    conn = self._open()
    try:
        conn.select(folder, readonly=True)
        typ, data = conn.fetch(uid, "(FLAGS BODY.PEEK[])")
        if typ != "OK" or not data or not data[0]:
            return ToolResult(success=False, error=f"Cannot fetch message uid={uid}")

        first = data[0]
        if isinstance(first, tuple):
            flags_str, raw = str(first[0]), first[1]
        else:
            flags_str, raw = str(first), b""
        message = email_lib.message_from_bytes(raw)

        return ToolResult(success=True, data={
            "uid": uid,
            "folder": folder,
            "from": message.get("From", ""),
            "to": message.get("To", ""),
            "cc": message.get("Cc", ""),
            "subject": message.get("Subject", ""),
            "date": message.get("Date", ""),
            "unread": "\\Seen" not in flags_str,
            # Hard cap so a huge body cannot flood the agent context.
            "body": _extract_body(message)[:6000],
        })
    finally:
        _close(conn)
|
||||
|
||||
async def _mark_email(
    self, uid: str = "", folder: str = "INBOX", flag: str = "read", **_
) -> ToolResult:
    """Translate a friendly flag name into an IMAP STORE action and apply it."""
    if not uid:
        return ToolResult(success=False, error="uid is required")
    err = self._check_folder(folder)
    if err:
        return ToolResult(success=False, error=err)

    # friendly name -> (STORE action, IMAP flag). "spam" uses the common
    # (non-standard) Junk keyword.
    flag_map = {
        "read": ("+FLAGS", "\\Seen"),
        "unread": ("-FLAGS", "\\Seen"),
        "flagged": ("+FLAGS", "\\Flagged"),
        "unflagged": ("-FLAGS", "\\Flagged"),
        "spam": ("+FLAGS", "Junk"),
    }
    try:
        flag_op = flag_map[flag]
    except KeyError:
        return ToolResult(success=False, error=f"Unknown flag: {flag!r}")

    return await asyncio.to_thread(self._sync_mark_email, uid, folder, flag_op)
|
||||
|
||||
def _sync_mark_email(self, uid: str, folder: str, flag_op: tuple) -> ToolResult:
    """Blocking worker for mark_email.

    *flag_op* is ("+FLAGS"|"-FLAGS", imap_flag).

    Fix: the success payload used to report the STORE action (e.g. "+FLAGS")
    under the "flag" key, which misrepresented what was set. It now reports
    the actual IMAP flag there and carries the action separately.
    """
    action, imap_flag = flag_op
    M = self._open()
    try:
        M.select(folder)
        typ, _ = M.store(uid, action, imap_flag)
        if typ != "OK":
            return ToolResult(success=False, error=f"Failed to mark email uid={uid}")
        return ToolResult(success=True, data={"uid": uid, "flag": imap_flag, "action": action})
    finally:
        _close(M)
|
||||
|
||||
async def _move_email(
    self,
    uid: str = "",
    source_folder: str = "INBOX",
    target_folder: str = "",
    **_,
) -> ToolResult:
    """Validate arguments and folder permissions, then move in a worker thread."""
    if not uid:
        return ToolResult(success=False, error="uid is required")
    if not target_folder:
        return ToolResult(success=False, error="target_folder is required")
    # Both ends of the move must be inside the allow-list.
    for label, candidate in (("source_folder", source_folder), ("target_folder", target_folder)):
        err = self._check_folder(candidate)
        if err:
            return ToolResult(success=False, error=f"{label}: {err}")
    return await asyncio.to_thread(self._sync_move_email, uid, source_folder, target_folder)
|
||||
|
||||
def _sync_move_email(self, uid: str, source_folder: str, target_folder: str) -> ToolResult:
    """Blocking worker for move_email: COPY, then mark \\Deleted and EXPUNGE.

    NOTE(review): EXPUNGE removes every \\Deleted message in the source
    folder (not just this one) and renumbers the remaining messages —
    confirm that is acceptable before batching moves.
    """
    conn = self._open()
    try:
        conn.select(source_folder)
        typ, _ = conn.copy(uid, target_folder)
        if typ != "OK":
            return ToolResult(success=False, error=f"Failed to copy to {target_folder!r}")
        conn.store(uid, "+FLAGS", "\\Deleted")
        conn.expunge()
        return ToolResult(success=True, data={"uid": uid, "moved_to": target_folder})
    finally:
        _close(conn)
|
||||
|
||||
async def _list_folders(self, **_) -> ToolResult:
    """Enumerate visible folders (filtered to the allow-list) in a worker thread."""
    return await asyncio.to_thread(self._sync_list_folders)
|
||||
|
||||
def _sync_list_folders(self) -> ToolResult:
    """Blocking worker for list_folders: parse LIST replies, apply the allow-list."""
    conn = self._open()
    try:
        typ, data = conn.list()
        if typ != "OK":
            return ToolResult(success=False, error="Failed to list folders")

        names = []
        for raw_line in data:
            if not raw_line:
                continue
            text = raw_line.decode() if isinstance(raw_line, bytes) else str(raw_line)
            # The folder name is the last token of the LIST reply: quoted when
            # it contains spaces, otherwise a bare word.
            m = re.search(r'"([^"]+)"\s*$', text) or re.search(r'(\S+)\s*$', text)
            if not m:
                continue
            name = m.group(1)
            if name and name.upper() != "NIL":
                names.append(name)

        if self._allowed_folders is not None:
            allowed = self._allowed_folders
            names = [
                n for n in names
                if any(n == a or n.startswith(a.rstrip("/") + "/") for a in allowed)
            ]

        return ToolResult(success=True, data={"folders": names, "count": len(names)})
    finally:
        _close(conn)
|
||||
|
||||
async def _create_folder(self, name: str = "", target_folder: str = "", **_) -> ToolResult:
    """Create a folder (optionally under *target_folder*) after an allow-list check."""
    if not name:
        return ToolResult(success=False, error="name is required")
    # NOTE(review): assumes "/" is the IMAP hierarchy delimiter — confirm per server.
    full_name = f"{target_folder}/{name}" if target_folder else name
    err = self._check_folder(full_name)
    if err:
        return ToolResult(success=False, error=err)
    return await asyncio.to_thread(self._sync_create_folder, full_name)
|
||||
|
||||
def _sync_create_folder(self, full_name: str) -> ToolResult:
    """Blocking worker for create_folder."""
    conn = self._open()
    try:
        status, _ = conn.create(full_name)
        if status != "OK":
            return ToolResult(success=False, error=f"Failed to create folder {full_name!r}")
        return ToolResult(success=True, data={"folder": full_name, "created": True})
    finally:
        _close(conn)
|
||||
|
||||
|
||||
# ── Helpers ───────────────────────────────────────────────────────────────────
|
||||
|
||||
def _close(M: imaplib.IMAP4_SSL) -> None:
    """Best-effort LOGOUT; teardown errors are deliberately swallowed.

    Used in the finally-clause of every _sync_* worker, where a failure to
    log out must not mask the operation's real result.
    """
    try:
        M.logout()
    except Exception:
        pass
|
||||
|
||||
|
||||
def _extract_body(msg: email_lib.message.Message) -> str:
|
||||
if msg.is_multipart():
|
||||
for part in msg.walk():
|
||||
if part.get_content_type() == "text/plain":
|
||||
payload = part.get_payload(decode=True)
|
||||
return payload.decode("utf-8", errors="replace") if payload else ""
|
||||
for part in msg.walk():
|
||||
if part.get_content_type() == "text/html":
|
||||
payload = part.get_payload(decode=True)
|
||||
html = payload.decode("utf-8", errors="replace") if payload else ""
|
||||
return re.sub(r"<[^>]+>", "", html).strip()
|
||||
else:
|
||||
payload = msg.get_payload(decode=True)
|
||||
return payload.decode("utf-8", errors="replace") if payload else ""
|
||||
return ""
|
||||
@@ -0,0 +1,397 @@
|
||||
"""
|
||||
tools/email_tool.py — IMAP email reading + SMTP sending.
|
||||
|
||||
Read operations: list_emails, read_email — no confirmation required.
|
||||
Send operation: send_email — whitelisted recipients only, requires confirmation.
|
||||
|
||||
Prompt injection guard: all email body text is sanitised before returning to agent.
|
||||
Max body length: 10,000 characters (truncated with notice).
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import email as email_lib
|
||||
import smtplib
|
||||
import ssl
|
||||
from email.mime.multipart import MIMEMultipart
|
||||
from email.mime.text import MIMEText
|
||||
from email.utils import formatdate, make_msgid, parseaddr
|
||||
|
||||
import imapclient
|
||||
from bs4 import BeautifulSoup
|
||||
|
||||
from ..database import credential_store
|
||||
from ..security import SecurityError, assert_email_rate_limit, assert_recipient_allowed, sanitize_external_content
|
||||
from ..security_screening import get_content_limit, is_option_enabled
|
||||
from .base import BaseTool, ToolResult
|
||||
|
||||
# Content limits applied when returning email data to the agent.
MAX_BODY_CHARS = 10_000             # legacy fallback when truncation option disabled
_DEFAULT_MAX_EMAIL_CHARS = 6_000    # default body limit when truncation option enabled
_DEFAULT_MAX_SUBJECT_CHARS = 200    # default subject limit when truncation option enabled
MAX_LIST_EMAILS = 50                # hard cap for list_emails
|
||||
|
||||
|
||||
class EmailTool(BaseTool):
    """IMAP read + whitelist-gated SMTP send against the configured Mailcow account."""

    name = "email"
    description = (
        "Read and send emails via IMAP/SMTP (Mailcow). "
        "Operations: list_emails (list inbox), read_email (read full message), "
        "send_email (send to one or more whitelisted recipients — requires confirmation), "
        "list_whitelist (return all approved recipient addresses). "
        "Email bodies are sanitised before being returned."
    )

    input_schema = {
        "type": "object",
        "properties": {
            "operation": {
                "type": "string",
                "enum": ["list_emails", "read_email", "send_email", "list_whitelist"],
                "description": "The email operation to perform. list_whitelist returns all approved recipient addresses.",
            },
            "folder": {"type": "string", "description": "IMAP folder (default: INBOX)"},
            "limit": {"type": "integer", "description": f"Max emails to list (default 20, max {MAX_LIST_EMAILS})"},
            "unread_only": {"type": "boolean", "description": "Only list unread emails (default false)"},
            "email_id": {"type": "string", "description": "Email UID for read_email"},
            "to": {
                "anyOf": [
                    {"type": "string"},
                    {"type": "array", "items": {"type": "string"}},
                ],
                "description": "Recipient address or list of addresses for send_email (all must be whitelisted)",
            },
            "subject": {"type": "string", "description": "Email subject for send_email"},
            "body": {"type": "string", "description": "Email body text (plain text) for send_email"},
            "html_body": {
                "type": "string",
                "description": "Full HTML email body for send_email. If provided, used as the HTML part instead of the plain-text fallback wrapper. Include complete <html>...</html> with inline <style>.",
            },
            "reply_to_id": {"type": "string", "description": "Email UID to reply to (sets In-Reply-To header)"},
        },
        "required": ["operation"],
    }

    # Only send_email actually needs confirmation; execute() handles that,
    # so the class-level default stays False.
    requires_confirmation = False
    allowed_in_scheduled_tasks = True
|
||||
|
||||
async def _load_credentials(self) -> tuple[str, str, str, str, int]:
    """Resolve (imap_host, smtp_host, username, password, smtp_port).

    mailcow_host is the shared default; dedicated IMAP/SMTP hosts and the
    SMTP port are optional overrides. Missing mandatory credentials raise
    via credential_store.require.
    """
    default_host = await credential_store.require("mailcow_host")
    username = await credential_store.require("mailcow_username")
    password = await credential_store.require("mailcow_password")
    imap_host = await credential_store.get("mailcow_imap_host") or default_host
    smtp_host = await credential_store.get("mailcow_smtp_host") or default_host
    smtp_port = int(await credential_store.get("mailcow_smtp_port") or "465")
    return imap_host, smtp_host, username, password, smtp_port
|
||||
|
||||
async def execute(
    self,
    operation: str,
    folder: str = "INBOX",
    limit: int = 20,
    unread_only: bool = False,
    email_id: str = "",
    to=None,
    subject: str = "",
    body: str = "",
    html_body: str = "",
    reply_to_id: str = "",
    **kwargs,
) -> ToolResult:
    """Dispatch one email operation to its handler.

    Fix: *limit* is clamped to 1..MAX_LIST_EMAILS. Previously only the upper
    bound was enforced, so limit=0 silently returned nothing and a negative
    limit produced a nonsense slice downstream.
    """
    if operation == "list_emails":
        return await self._list_emails(folder, max(1, min(limit, MAX_LIST_EMAILS)), unread_only)
    if operation == "read_email":
        if not email_id:
            return ToolResult(success=False, error="email_id is required for read_email")
        return await self._read_email(folder, email_id)
    if operation == "list_whitelist":
        return await self._list_whitelist()
    if operation == "send_email":
        # Normalise `to` (str | list[str] | None) into a clean list[str].
        if isinstance(to, list):
            recipients = [r.strip() for r in to if r.strip()]
        elif isinstance(to, str) and to.strip():
            recipients = [to.strip()]
        else:
            recipients = []
        if not (recipients and subject and (body or html_body)):
            return ToolResult(success=False, error="to, subject, and body (or html_body) are required for send_email")
        return await self._send_email(recipients, subject, body, html_body, reply_to_id)

    return ToolResult(success=False, error=f"Unknown operation: {operation!r}")
|
||||
|
||||
# ── IMAP ──────────────────────────────────────────────────────────────────
|
||||
|
||||
async def _list_emails(self, folder: str, limit: int, unread_only: bool) -> ToolResult:
    """List newest-first message summaries from *folder* via IMAP."""
    try:
        imap_host, _, username, password, _ = await self._load_credentials()
    except RuntimeError as e:
        return ToolResult(success=False, error=str(e))

    try:
        with imapclient.IMAPClient(imap_host, ssl=True, port=993) as client:
            client.login(username, password)
            client.select_folder(folder, readonly=True)

            found = client.search(["UNSEEN"] if unread_only else ["ALL"])
            # Newest first, capped at *limit*.
            found = list(reversed(found))[:limit]
            if not found:
                return ToolResult(success=True, data={"emails": [], "count": 0})

            fetched = client.fetch(found, ["ENVELOPE", "FLAGS", "RFC822.SIZE"])
            summaries = []
            for msg_uid, attrs in fetched.items():
                envelope = attrs.get(b"ENVELOPE")
                if not envelope:
                    continue
                summaries.append({
                    "id": str(msg_uid),
                    "from": _format_address(envelope.from_) if envelope.from_ else "",
                    "subject": _decode_header(envelope.subject),
                    # NOTE(review): ENVELOPE date isn't always reliable, so the
                    # ordering above is approximate.
                    "date": str(envelope.date) if envelope.date else "",
                    "unread": b"\\Seen" not in (attrs.get(b"FLAGS") or []),
                    "size_bytes": attrs.get(b"RFC822.SIZE", 0),
                })

            return ToolResult(
                success=True,
                data={"emails": summaries, "count": len(summaries)},
            )

    except Exception as e:
        return ToolResult(success=False, error=f"IMAP error: {e}")
|
||||
|
||||
async def _read_email(self, folder: str, email_id: str) -> ToolResult:
    """Fetch one message by UID, truncate per configured limits, sanitise, return."""
    try:
        imap_host, _, username, password, _ = await self._load_credentials()
    except RuntimeError as e:
        return ToolResult(success=False, error=str(e))

    try:
        uid = int(email_id)
    except ValueError:
        return ToolResult(success=False, error=f"Invalid email_id: {email_id!r}")

    try:
        with imapclient.IMAPClient(imap_host, ssl=True, port=993) as client:
            client.login(username, password)
            client.select_folder(folder, readonly=True)

            fetched = client.fetch([uid], ["RFC822"])
            if not fetched or uid not in fetched:
                return ToolResult(success=False, error=f"Email {email_id} not found")

            message = email_lib.message_from_bytes(fetched[uid][b"RFC822"])
            subject = _decode_header(message.get("Subject", ""))
            body_text = _extract_email_body(message)

            # Truncate BEFORE sanitising so the sanitiser sees the final text.
            truncated = False
            if await is_option_enabled("system:security_truncation_enabled"):
                max_body = await get_content_limit("system:security_max_email_chars", _DEFAULT_MAX_EMAIL_CHARS)
                if len(body_text) > max_body:
                    body_text = body_text[:max_body]
                    truncated = True
                max_subj = await get_content_limit("system:security_max_subject_chars", _DEFAULT_MAX_SUBJECT_CHARS)
                if len(subject) > max_subj:
                    subject = subject[:max_subj] + " [subject truncated]"
            elif len(body_text) > MAX_BODY_CHARS:
                # Legacy fixed cap when the truncation option is disabled.
                body_text = body_text[:MAX_BODY_CHARS]
                truncated = True

            # Prompt-injection guard — body and subject are external content.
            body_text = await sanitize_external_content(body_text, source="email")
            subject = await sanitize_external_content(subject, source="email_subject")

            return ToolResult(
                success=True,
                data={
                    "id": email_id,
                    "from": message.get("From", ""),
                    "subject": subject,
                    "date": message.get("Date", ""),
                    "message_id": message.get("Message-ID", ""),
                    "body": body_text,
                    "truncated": truncated,
                },
            )

    except Exception as e:
        return ToolResult(success=False, error=f"IMAP error: {e}")
|
||||
|
||||
# ── SMTP ──────────────────────────────────────────────────────────────────
|
||||
|
||||
async def _list_whitelist(self) -> ToolResult:
    """Return every approved recipient address from the whitelist store."""
    from ..database import email_whitelist_store  # local import, matching original

    entries = await email_whitelist_store.list()
    addresses = [entry["email"] for entry in entries]
    return ToolResult(
        success=True,
        data={"recipients": addresses, "count": len(entries)},
    )
|
||||
|
||||
async def _send_email(
    self,
    to: list[str],
    subject: str,
    body: str,
    html_body: str = "",
    reply_to_id: str = "",
) -> ToolResult:
    """Send a multipart (plain + HTML) email to whitelisted recipients.

    Every recipient must pass the whitelist and rate-limit checks first.

    Fix: the plain-text fallback HTML wrapper used to interpolate *body*
    unescaped into the <pre> element, so a body containing '<', '>' or '&'
    could break the markup or inject HTML; it is now escaped.
    """
    import html  # stdlib; local import keeps the module's import list unchanged

    # Security: enforce whitelist + rate limit for every recipient.
    try:
        for addr in to:
            await assert_recipient_allowed(addr)
            await assert_email_rate_limit(addr)
    except SecurityError as e:
        return ToolResult(success=False, error=str(e))

    try:
        _, smtp_host, username, password, smtp_port = await self._load_credentials()
    except RuntimeError as e:
        return ToolResult(success=False, error=str(e))

    # Build the MIME message: plain part plus an HTML alternative.
    msg = MIMEMultipart("alternative")
    msg["From"] = username
    msg["To"] = ", ".join(to)
    msg["Subject"] = subject
    msg["Date"] = formatdate(localtime=True)
    msg["Message-ID"] = make_msgid()

    if reply_to_id:
        msg["In-Reply-To"] = reply_to_id
        msg["References"] = reply_to_id

    msg.attach(MIMEText(body, "plain", "utf-8"))
    if not html_body:
        # Fallback wrapper around the plain text — escaped (see docstring).
        html_body = f"<html><body><pre style='font-family:sans-serif'>{html.escape(body)}</pre></body></html>"
    msg.attach(MIMEText(html_body, "html", "utf-8"))

    try:
        if smtp_port == 465:
            # Implicit TLS.
            context = ssl.create_default_context()
            with smtplib.SMTP_SSL(smtp_host, smtp_port, context=context, timeout=10) as smtp:
                smtp.login(username, password)
                smtp.sendmail(username, to, msg.as_bytes())
        else:
            # Plain connection upgraded via STARTTLS.
            with smtplib.SMTP(smtp_host, smtp_port, timeout=10) as smtp:
                smtp.ehlo()
                smtp.starttls()
                smtp.login(username, password)
                smtp.sendmail(username, to, msg.as_bytes())

        return ToolResult(
            success=True,
            data={"sent": True, "to": to, "subject": subject},
        )

    except smtplib.SMTPAuthenticationError:
        return ToolResult(success=False, error="SMTP authentication failed. Check mailcow_password.")
    except smtplib.SMTPException as e:
        return ToolResult(success=False, error=f"SMTP error: {e}")
    except Exception as e:
        return ToolResult(success=False, error=f"Send error: {e}")
|
||||
|
||||
def confirmation_description(self, to=None, subject: str = "", body: str = "", **kwargs) -> str:
    """Summary text shown to the user before send_email is confirmed.

    Only the first 200 characters of the body are previewed.
    """
    recipients = ", ".join(to) if isinstance(to, list) else (to or "")
    preview = body[:200]
    return f"Send email to {recipients}\nSubject: {subject}\n\n{preview}..."
|
||||
|
||||
|
||||
# ── Helpers ───────────────────────────────────────────────────────────────────
|
||||
|
||||
def _decode_header(value) -> str:
|
||||
"""Decode IMAP header value (may be bytes or string)."""
|
||||
if value is None:
|
||||
return ""
|
||||
if isinstance(value, bytes):
|
||||
try:
|
||||
return value.decode("utf-8", errors="replace")
|
||||
except Exception:
|
||||
return str(value)
|
||||
return str(value)
|
||||
|
||||
|
||||
def _format_address(addresses) -> str:
    """Render the first IMAP ENVELOPE address as 'Name <mailbox@host>'.

    Falls back gracefully when name or host components are missing.
    """
    if not addresses:
        return ""
    first = addresses[0]
    name = _decode_header(first.name) if first.name else ""
    mailbox = _decode_header(first.mailbox) if first.mailbox else ""
    host = _decode_header(first.host) if first.host else ""
    if host:
        email_addr = f"{mailbox}@{host}"
    else:
        email_addr = mailbox
    if name:
        return f"{name} <{email_addr}>"
    return email_addr
|
||||
|
||||
|
||||
def _extract_email_body(msg: email_lib.message.Message) -> str:
    """Extract plain text from email, stripping HTML if needed."""
    # Collect every text part; plain text is preferred over HTML below.
    plain_parts = []
    html_parts = []

    if msg.is_multipart():
        for part in msg.walk():
            ct = part.get_content_type()
            # Fall back to UTF-8 when the part declares no charset.
            charset = part.get_content_charset() or "utf-8"
            if ct == "text/plain":
                try:
                    plain_parts.append(part.get_payload(decode=True).decode(charset, errors="replace"))
                except Exception:
                    # Best-effort: skip parts that cannot be decoded
                    # (e.g. get_payload returned None).
                    pass
            elif ct == "text/html":
                try:
                    html_parts.append(part.get_payload(decode=True).decode(charset, errors="replace"))
                except Exception:
                    pass
    else:
        # Single-part message: decode the whole payload directly.
        ct = msg.get_content_type()
        charset = msg.get_content_charset() or "utf-8"
        payload = msg.get_payload(decode=True) or b""
        text = payload.decode(charset, errors="replace")
        if ct == "text/html":
            html_parts.append(text)
        else:
            plain_parts.append(text)

    # Prefer the plain-text alternative when one exists.
    if plain_parts:
        return "\n".join(plain_parts)

    if html_parts:
        # HTML only: strip markup and keep line structure.
        soup = BeautifulSoup("\n".join(html_parts), "html.parser")
        return soup.get_text(separator="\n")

    return ""
|
||||
@@ -0,0 +1,268 @@
|
||||
"""
|
||||
tools/filesystem_tool.py — Sandboxed filesystem access.
|
||||
|
||||
All paths are validated against FILESYSTEM_SANDBOX_DIRS before any operation.
|
||||
Symlinks are resolved before checking (prevents traversal attacks).
|
||||
Binary files are rejected — text only.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
from ..context_vars import current_user as _current_user_var
|
||||
from ..security import SecurityError, assert_path_allowed, sanitize_external_content
|
||||
from ..security_screening import get_content_limit, is_option_enabled
|
||||
from .base import BaseTool, ToolResult
|
||||
|
||||
MAX_FILE_SIZE = 500 * 1024 # 500 KB (file size gate — still enforced regardless of truncation option)
|
||||
_DEFAULT_MAX_FILE_CHARS = 20_000 # default char limit when truncation option is enabled
|
||||
MAX_DIR_ENTRIES = 200
|
||||
|
||||
# Image extensions and their MIME types — returned as base64 for vision-capable models
|
||||
_IMAGE_MEDIA_TYPES: dict[str, str] = {
|
||||
".jpg": "image/jpeg", ".jpeg": "image/jpeg",
|
||||
".png": "image/png", ".gif": "image/gif",
|
||||
".webp": "image/webp", ".bmp": "image/bmp",
|
||||
".tiff": "image/tiff", ".tif": "image/tiff",
|
||||
}
|
||||
|
||||
|
||||
class FilesystemTool(BaseTool):
    """Sandboxed filesystem access for agents.

    Every path is validated via ``assert_path_allowed`` before any operation,
    and non-admin users are additionally confined to their personal folder
    under ``system:users_base_folder``. Text reads are size- and char-limited;
    known image extensions are returned as base64 for vision-capable models.
    """

    name = "filesystem"
    description = (
        "Read and write files in the owner's designated directories. "
        "Operations: read_file, list_directory, write_file, delete_file, "
        "create_directory, delete_directory. "
        "Image files (jpg, png, gif, webp, etc.) are returned as visual content "
        "that vision-capable models can analyse. "
        "write_file, delete_file, and delete_directory require user confirmation."
    )
    input_schema = {
        "type": "object",
        "properties": {
            "operation": {
                "type": "string",
                "enum": ["read_file", "list_directory", "write_file", "delete_file",
                         "create_directory", "delete_directory"],
                "description": "The filesystem operation to perform",
            },
            "path": {
                "type": "string",
                "description": "Absolute or relative path to the file or directory",
            },
            "content": {
                "type": "string",
                "description": "File content for write_file operations",
            },
        },
        "required": ["operation", "path"],
    }
    requires_confirmation = False  # set dynamically per operation by the agent loop
    allowed_in_scheduled_tasks = True

    # Operations that require confirmation
    _CONFIRM_OPS = {"write_file", "delete_file", "delete_directory"}

    async def execute(
        self,
        operation: str,
        path: str,
        content: str = "",
        **kwargs,
    ) -> ToolResult:
        """Validate *path* against the sandbox, then dispatch to the handler.

        Returns a failed ToolResult (never raises) for all expected error
        modes: unknown operation, sandbox violation, per-user folder escape.
        """
        # Check operation
        if operation not in ("read_file", "list_directory", "write_file", "delete_file",
                             "create_directory", "delete_directory"):
            return ToolResult(success=False, error=f"Unknown operation: {operation!r}")

        # Sandbox check
        try:
            safe_path = await assert_path_allowed(path)
        except SecurityError as e:
            return ToolResult(success=False, error=str(e))

        # Per-user folder restriction: non-admins may only access their personal folder
        _user = _current_user_var.get()
        if _user is not None and not _user.is_admin:
            from ..database import credential_store
            base = await credential_store.get("system:users_base_folder")
            if not base:
                return ToolResult(success=False, error="Filesystem access is not available for your account.")
            user_folder = Path(base.rstrip("/")) / _user.username
            try:
                resolved = safe_path.resolve()
                user_folder_resolved = user_folder.resolve()
                # Allow the folder itself or anything strictly inside it.
                if not str(resolved).startswith(str(user_folder_resolved) + "/") and resolved != user_folder_resolved:
                    return ToolResult(success=False, error="Access denied: path is outside your personal folder.")
            except Exception:
                # Fail closed if the path cannot be resolved at all.
                return ToolResult(success=False, error="Access denied: could not verify path.")

        # Dispatch
        if operation == "read_file":
            return await self._read_file(safe_path)
        elif operation == "list_directory":
            return self._list_directory(safe_path)
        elif operation == "write_file":
            return self._write_file(safe_path, content)
        elif operation == "delete_file":
            return self._delete_file(safe_path)
        elif operation == "create_directory":
            return self._create_directory(safe_path)
        elif operation == "delete_directory":
            return self._delete_directory(safe_path)

    async def _read_file(self, path: Path) -> ToolResult:
        """Read a text or image file, enforcing size/char limits and sanitising."""
        if not path.exists():
            return ToolResult(success=False, error=f"File not found: {path}")
        if not path.is_file():
            return ToolResult(success=False, error=f"Not a file: {path}")

        # Hard size gate — enforced regardless of the truncation option.
        size = path.stat().st_size
        if size > MAX_FILE_SIZE:
            return ToolResult(
                success=False,
                error=f"File too large ({size:,} bytes). Max: {MAX_FILE_SIZE:,} bytes.",
            )

        # Image files: return as base64 for vision-capable models
        media_type = _IMAGE_MEDIA_TYPES.get(path.suffix.lower())
        if media_type:
            try:
                import base64
                image_data = base64.b64encode(path.read_bytes()).decode("ascii")
            except PermissionError:
                return ToolResult(success=False, error=f"Permission denied: {path}")
            return ToolResult(
                success=True,
                data={
                    "path": str(path),
                    "is_image": True,
                    "media_type": media_type,
                    "image_data": image_data,
                    "size_bytes": size,
                },
            )

        try:
            text = path.read_text(encoding="utf-8")
        except UnicodeDecodeError:
            # Non-UTF-8 content is treated as binary and rejected.
            return ToolResult(success=False, error="Binary files are not supported. Text files only.")
        except PermissionError:
            return ToolResult(success=False, error=f"Permission denied: {path}")

        # Apply configurable char limit — truncate rather than reject
        truncated = False
        if await is_option_enabled("system:security_truncation_enabled"):
            max_chars = await get_content_limit("system:security_max_file_chars", _DEFAULT_MAX_FILE_CHARS)
            total_chars = len(text)
            if total_chars > max_chars:
                text = text[:max_chars]
                text += f"\n\n[File truncated: showing first {max_chars:,} of {total_chars:,} total chars]"
                truncated = True

        # Sanitise before returning to agent (prompt-injection defence)
        text = await sanitize_external_content(text, source="file")

        return ToolResult(
            success=True,
            data={"path": str(path), "content": text, "size_bytes": size, "truncated": truncated},
        )

    def _list_directory(self, path: Path) -> ToolResult:
        """List directory entries, capped at MAX_DIR_ENTRIES."""
        if not path.exists():
            return ToolResult(success=False, error=f"Directory not found: {path}")
        if not path.is_dir():
            return ToolResult(success=False, error=f"Not a directory: {path}")

        try:
            children = sorted(path.iterdir())
            entries = []
            for entry in children[:MAX_DIR_ENTRIES]:
                stat = entry.stat()
                entries.append({
                    "name": entry.name,
                    "type": "directory" if entry.is_dir() else "file",
                    "size_bytes": stat.st_size if entry.is_file() else None,
                })
            # Bug fix: report the number of entries actually omitted. The old
            # code reported the loop index (always MAX_DIR_ENTRIES, i.e. the
            # number of entries *shown*) as the truncated count.
            omitted = len(children) - MAX_DIR_ENTRIES
            if omitted > 0:
                entries.append({"name": f"... ({omitted} entries truncated)", "type": "truncated"})
        except PermissionError:
            return ToolResult(success=False, error=f"Permission denied: {path}")

        return ToolResult(
            success=True,
            data={"path": str(path), "entries": entries, "count": len(entries)},
        )

    def _write_file(self, path: Path, content: str) -> ToolResult:
        """Write text (or a base64 data-URL as binary) to *path*, creating parents."""
        try:
            path.parent.mkdir(parents=True, exist_ok=True)
            # Support base64 data URLs for binary files (e.g. images from image_gen tool)
            if content.startswith("data:") and ";base64," in content:
                import base64 as _b64
                try:
                    _, b64_data = content.split(",", 1)
                    raw_bytes = _b64.b64decode(b64_data)
                    path.write_bytes(raw_bytes)
                    return ToolResult(success=True, data={"path": str(path), "bytes_written": len(raw_bytes)})
                except Exception as e:
                    return ToolResult(success=False, error=f"Failed to decode base64 data: {e}")
            path.write_text(content, encoding="utf-8")
        except PermissionError:
            return ToolResult(success=False, error=f"Permission denied: {path}")
        except Exception as e:
            return ToolResult(success=False, error=f"Write failed: {e}")

        return ToolResult(
            success=True,
            data={"path": str(path), "bytes_written": len(content.encode())},
        )

    def _delete_file(self, path: Path) -> ToolResult:
        """Delete a single file (directories are rejected)."""
        if not path.exists():
            return ToolResult(success=False, error=f"File not found: {path}")
        if not path.is_file():
            return ToolResult(success=False, error="Can only delete files, not directories.")

        try:
            path.unlink()
        except PermissionError:
            return ToolResult(success=False, error=f"Permission denied: {path}")
        except Exception as e:
            return ToolResult(success=False, error=f"Delete failed: {e}")

        return ToolResult(success=True, data={"deleted": str(path)})

    def _create_directory(self, path: Path) -> ToolResult:
        """Create a directory (and any missing parents); idempotent."""
        try:
            path.mkdir(parents=True, exist_ok=True)
        except PermissionError:
            return ToolResult(success=False, error=f"Permission denied: {path}")
        except Exception as e:
            return ToolResult(success=False, error=f"Create directory failed: {e}")
        return ToolResult(success=True, data={"created": str(path)})

    def _delete_directory(self, path: Path) -> ToolResult:
        """Recursively delete a directory and all of its contents."""
        if not path.exists():
            return ToolResult(success=False, error=f"Directory not found: {path}")
        if not path.is_dir():
            return ToolResult(success=False, error=f"Not a directory: {path}")
        try:
            shutil.rmtree(path)
        except PermissionError:
            return ToolResult(success=False, error=f"Permission denied: {path}")
        except Exception as e:
            return ToolResult(success=False, error=f"Delete directory failed: {e}")
        return ToolResult(success=True, data={"deleted": str(path)})

    def confirmation_description(self, operation: str = "", path: str = "", **kwargs) -> str:
        """Human-readable summary shown to the user for confirmed operations."""
        if operation == "delete_file":
            return f"Permanently delete file: {path}"
        if operation == "delete_directory":
            return f"Permanently delete directory and all its contents: {path}"
        if operation == "write_file":
            content_preview = kwargs.get("content", "")[:80]
            return f"Write to file: {path}\nContent preview: {content_preview}..."
        return f"{operation}: {path}"
|
||||
@@ -0,0 +1,157 @@
|
||||
"""
|
||||
tools/image_gen_tool.py — AI image generation tool.
|
||||
|
||||
Calls an image-generation model (via OpenRouter by default) and returns the
|
||||
result. If save_path is given the image is written to disk immediately so the
|
||||
model doesn't need to handle large base64 blobs.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import logging
|
||||
|
||||
from .base import BaseTool, ToolResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Default model — override per call or via credential system:default_image_gen_model
|
||||
_DEFAULT_MODEL = "openrouter:openai/gpt-5-image"
|
||||
|
||||
|
||||
class ImageGenTool(BaseTool):
    """Generate an image from a text prompt via a configurable provider model.

    The model is resolved from the tool argument, then the
    ``system:default_image_gen_model`` credential, then a hardcoded default.
    When ``save_path`` is given the image is written to disk (sandbox-checked)
    and only metadata is returned; otherwise the raw base64 data is returned.
    """

    name = "image_gen"
    description = (
        "Generate an image from a text prompt using an AI image-generation model. "
        "If save_path is provided the image is saved to that path and only the path "
        "is returned (no base64 blob in context). "
        "If save_path is omitted the raw base64 image data is returned so you can "
        "inspect it or pass it to another tool."
    )
    input_schema = {
        "type": "object",
        "properties": {
            "prompt": {
                "type": "string",
                "description": "Detailed description of the image to generate",
            },
            "save_path": {
                "type": "string",
                "description": (
                    "Optional absolute file path to save the image to "
                    "(e.g. /data/users/rune/stewie.png). "
                    "Recommended — avoids returning a large base64 blob."
                ),
            },
            "model": {
                "type": "string",
                "description": (
                    "Optional image-generation model ID "
                    "(e.g. openrouter:openai/gpt-5-image, "
                    "openrouter:google/gemini-2.0-flash-exp:free). "
                    "Defaults to the system default image model."
                ),
            },
        },
        "required": ["prompt"],
    }
    requires_confirmation = False
    allowed_in_scheduled_tasks = True

    async def execute(
        self,
        prompt: str,
        save_path: str = "",
        model: str = "",
        **kwargs,
    ) -> ToolResult:
        """Generate one image for *prompt*; save to *save_path* or return base64."""
        # Resolve model: tool arg → credential override → hardcoded default
        if not model:
            from ..database import credential_store
            model = (await credential_store.get("system:default_image_gen_model")) or _DEFAULT_MODEL

        # Resolve provider + bare model id
        from ..context_vars import current_user as _cu
        user_id = _cu.get().id if _cu.get() else None
        try:
            from ..providers.registry import get_provider_for_model
            provider, bare_model = await get_provider_for_model(model, user_id=user_id)
        except Exception as e:
            return ToolResult(success=False, error=f"Could not resolve image model '{model}': {e}")

        # Call the model with a simple user message containing the prompt
        try:
            response = await provider.chat_async(
                messages=[{"role": "user", "content": prompt}],
                tools=None,
                system="",
                model=bare_model,
                max_tokens=1024,
            )
        except Exception as e:
            logger.error("[image_gen] Provider call failed: %s", e)
            return ToolResult(success=False, error=f"Image generation failed: {e}")

        if not response.images:
            # Surface the model's text reply so the agent can diagnose the refusal.
            msg = response.text or "(no images returned)"
            logger.warning("[image_gen] No images in response. text=%r", msg)
            return ToolResult(
                success=False,
                error=f"Model did not return any images. Response: {msg[:300]}",
            )

        # Use the first image (most models return exactly one)
        data_url = response.images[0]

        # Parse the data URL: data:<media_type>;base64,<data>
        # If parsing fails we keep the raw payload and assume PNG.
        media_type = "image/png"
        img_b64 = data_url
        if data_url.startswith("data:"):
            try:
                header, img_b64 = data_url.split(",", 1)
                media_type = header.split(":")[1].split(";")[0]
            except Exception:
                pass

        # NOTE(review): b64decode is not guarded here — a malformed payload
        # from the provider would raise out of execute(); confirm the agent
        # loop catches tool exceptions.
        img_bytes = base64.b64decode(img_b64)

        # Save to path if requested
        if save_path:
            from pathlib import Path
            from ..security import assert_path_allowed, SecurityError
            try:
                safe_path = await assert_path_allowed(save_path)
            except SecurityError as e:
                return ToolResult(success=False, error=str(e))
            try:
                safe_path.parent.mkdir(parents=True, exist_ok=True)
                safe_path.write_bytes(img_bytes)
            except PermissionError:
                return ToolResult(success=False, error=f"Permission denied: {safe_path}")
            except Exception as e:
                return ToolResult(success=False, error=f"Save failed: {e}")

            logger.info("[image_gen] Saved image to %s (%d bytes)", safe_path, len(img_bytes))
            return ToolResult(
                success=True,
                data={
                    "saved_to": str(safe_path),
                    "size_bytes": len(img_bytes),
                    "media_type": media_type,
                    # NOTE(review): this concatenates the *full* model string
                    # (which may already contain bare_model) with bare_model —
                    # looks like it was meant to be provider/bare; confirm.
                    "model": f"{model}/{bare_model}".strip("/"),
                    "prompt": prompt,
                },
            )

        # Return base64 data (no save_path given)
        logger.info("[image_gen] Returning image data (%d bytes, %s)", len(img_bytes), media_type)
        return ToolResult(
            success=True,
            data={
                "is_image": True,
                "image_data": img_b64,
                "media_type": media_type,
                "size_bytes": len(img_bytes),
                "model": f"{model}/{bare_model}".strip("/"),
                "prompt": prompt,
            },
        )
|
||||
@@ -0,0 +1,47 @@
|
||||
"""
|
||||
tools/mcp_proxy_tool.py — Dynamic BaseTool wrapper for external MCP server tools.
|
||||
|
||||
One McpProxyTool instance is created per tool discovered from each MCP server.
|
||||
The tool name is namespaced as mcp__{server_slug}__{tool_name} to avoid collisions.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
|
||||
from .base import BaseTool, ToolResult
|
||||
|
||||
|
||||
def _slugify(name: str) -> str:
|
||||
"""Convert a server name to a safe identifier component."""
|
||||
return re.sub(r"[^a-z0-9]+", "_", name.lower()).strip("_")
|
||||
|
||||
|
||||
class McpProxyTool(BaseTool):
    """Expose one remote MCP-server tool as a local BaseTool instance."""

    requires_confirmation = False
    allowed_in_scheduled_tasks = True

    def __init__(
        self,
        server_id: str,
        server_name: str,
        server: dict,
        tool_name: str,
        description: str,
        input_schema: dict,
    ) -> None:
        # Namespace the tool name so same-named tools on different servers
        # cannot collide in the registry.
        self.name = f"mcp__{_slugify(server_name)}__{tool_name}"
        self.description = f"[{server_name}] {description}"
        if input_schema:
            self.input_schema = input_schema
        else:
            # Minimal valid schema when the server declares none.
            self.input_schema = {
                "type": "object", "properties": {}, "required": []
            }
        self._server_id = server_id
        self._server_display_name = server_name  # human-readable name for UI
        self._server = server  # full server dict with decrypted secrets
        self._remote_tool_name = tool_name  # original name on the MCP server

    async def execute(self, **kwargs) -> ToolResult:
        """Forward the call to the remote MCP server and return its result."""
        from ..mcp_client.manager import call_tool
        from ..mcp_client.store import get_server

        # Re-read the server record so rotated credentials take effect
        # without a restart; fall back to the startup snapshot.
        fresh = await get_server(self._server_id, include_secrets=True)
        target = fresh or self._server
        return await call_tool(target, self._remote_tool_name, kwargs)
|
||||
@@ -0,0 +1,61 @@
|
||||
"""
|
||||
tools/mock.py — Mock tools for testing the agent loop without real integrations.
|
||||
|
||||
EchoTool — returns its input unchanged, no side effects
|
||||
ConfirmTool — always requires confirmation; logs that confirmation was received
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from .base import BaseTool, ToolResult
|
||||
|
||||
|
||||
class EchoTool(BaseTool):
    """Side-effect-free test tool: returns its input verbatim."""

    name = "echo"
    description = (
        "Returns whatever text you pass in. Useful for testing. "
        "Use this when you need to verify tool calling works."
    )
    input_schema = {
        "type": "object",
        "properties": {
            "message": {
                "type": "string",
                "description": "The text to echo back",
            }
        },
        "required": ["message"],
    }
    requires_confirmation = False
    allowed_in_scheduled_tasks = True

    async def execute(self, message: str = "", **kwargs) -> ToolResult:
        """Echo *message* back inside a successful ToolResult."""
        payload = {"echo": message}
        return ToolResult(success=True, data=payload)
|
||||
|
||||
|
||||
class ConfirmTool(BaseTool):
    """Test tool whose only purpose is to exercise the confirmation flow."""

    name = "confirm_action"
    description = (
        "A test tool that always requires user confirmation before proceeding. "
        "Use to verify the confirmation flow works end-to-end."
    )
    input_schema = {
        "type": "object",
        "properties": {
            "action": {
                "type": "string",
                "description": "Description of the action requiring confirmation",
            }
        },
        "required": ["action"],
    }
    requires_confirmation = True
    allowed_in_scheduled_tasks = False

    async def execute(self, action: str = "", **kwargs) -> ToolResult:
        """Report that the (already confirmed) action went through."""
        result_payload = {"confirmed": True, "action": action}
        return ToolResult(success=True, data=result_payload)

    def confirmation_description(self, action: str = "", **kwargs) -> str:
        """One-line prompt shown to the user before execution."""
        return f"Perform action: {action}"
|
||||
@@ -0,0 +1,131 @@
|
||||
"""
|
||||
tools/pushover_tool.py — Pushover push notifications.
|
||||
|
||||
Sends to exactly one hard-coded user key (defined in security.py).
|
||||
Normal priority (0) and below: no confirmation required.
|
||||
Emergency priority (2): always requires confirmation.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import httpx
|
||||
|
||||
from ..database import credential_store
|
||||
from .base import BaseTool, ToolResult
|
||||
|
||||
PUSHOVER_API_URL = "https://api.pushover.net/1/messages.json"
|
||||
|
||||
|
||||
class PushoverTool(BaseTool):
    """Send a push notification to the owner's phone via the Pushover API.

    Credentials (``pushover_app_token`` / ``pushover_user_key``) come from the
    credential store, so every send goes to the one configured user key.
    Priority 2 (emergency) adds retry/expire and is meant to be gated by a
    user confirmation in the agent loop.
    """

    name = "pushover"
    description = (
        "Send a push notification to the owner's phone via Pushover. "
        "Use for alerts, reminders, and status updates. "
        "Priority: -2 (silent), -1 (quiet), 0 (normal), 1 (high), 2 (emergency — requires confirmation)."
    )
    input_schema = {
        "type": "object",
        "properties": {
            "title": {
                "type": "string",
                "description": "Notification title (short, shown in bold)",
            },
            "message": {
                "type": "string",
                "description": "Notification body text",
            },
            "priority": {
                "type": "integer",
                "enum": [-2, -1, 0, 1, 2],
                "description": "Priority level (-2 silent to 2 emergency). Default: 0",
            },
            "url": {
                "type": "string",
                "description": "Optional URL to attach to the notification",
            },
            "url_title": {
                "type": "string",
                "description": "Display text for the attached URL",
            },
        },
        "required": ["title", "message"],
    }
    requires_confirmation = False  # overridden dynamically for priority=2
    allowed_in_scheduled_tasks = True

    async def execute(
        self,
        title: str,
        message: str,
        priority: int = 0,
        url: str = "",
        url_title: str = "",
        **kwargs,
    ) -> ToolResult:
        """Post one notification to the Pushover messages endpoint.

        Returns a failed ToolResult on invalid priority, missing credentials,
        API-level errors, or transport errors — never raises.
        """
        # Validate priority
        if priority not in (-2, -1, 0, 1, 2):
            return ToolResult(success=False, error=f"Invalid priority: {priority}")

        # Emergency always requires confirmation — enforced here as defence-in-depth
        # (the agent loop also checks requires_confirmation before calling execute)
        if priority == 2:
            # The agent loop should have asked for confirmation already.
            # If we got here, it was approved.
            pass

        # Load credentials
        try:
            app_token = await credential_store.require("pushover_app_token")
            user_key = await credential_store.require("pushover_user_key")
        except RuntimeError as e:
            return ToolResult(success=False, error=str(e))

        # Field lengths are clamped to Pushover's documented limits.
        payload: dict = {
            "token": app_token,
            "user": user_key,
            "title": title[:250],
            "message": message[:1024],
            "priority": priority,
        }

        if priority == 2:
            # Emergency: retry every 30s for 1 hour until acknowledged
            payload["retry"] = 30
            payload["expire"] = 3600

        if url:
            payload["url"] = url[:512]
        if url_title:
            payload["url_title"] = url_title[:100]

        try:
            async with httpx.AsyncClient(timeout=10) as client:
                resp = await client.post(PUSHOVER_API_URL, data=payload)
                resp.raise_for_status()
                data = resp.json()

            # Pushover signals success with status == 1 even on HTTP 200.
            if data.get("status") != 1:
                return ToolResult(
                    success=False,
                    error=f"Pushover API error: {data.get('errors', 'unknown')}",
                )

            return ToolResult(
                success=True,
                data={"sent": True, "request_id": data.get("request", "")},
            )

        except httpx.HTTPStatusError as e:
            return ToolResult(success=False, error=f"Pushover HTTP error: {e.response.status_code}")
        except Exception as e:
            return ToolResult(success=False, error=f"Pushover error: {e}")

    def confirmation_description(self, title: str = "", message: str = "", priority: int = 0, **kwargs) -> str:
        """Summary shown to the user when confirming an emergency notification."""
        return f"Send emergency Pushover notification: '{title}' — {message[:100]}"

    def get_schema(self) -> dict:
        """Override to make requires_confirmation dynamic for priority 2."""
        # NOTE(review): this override currently returns the parent schema
        # unchanged — nothing here is actually dynamic, despite the comment
        # below. Confirm whether the agent loop inspects call arguments for
        # priority==2, or whether this override was left unfinished.
        schema = super().get_schema()
        # The tool itself handles emergency confirmation — marked as requiring it always
        # so the agent loop treats it consistently. For non-emergency, the agent loop
        # still calls execute(), which works fine without confirmation.
        return schema
|
||||
@@ -0,0 +1,104 @@
|
||||
"""
|
||||
tools/subagent_tool.py — Sub-agent creation and synchronous execution.
|
||||
|
||||
Only available when the parent agent has can_create_subagents=True.
|
||||
Creates a child agent in the DB (with parent_agent_id set) and runs it
|
||||
synchronously, returning the result and token counts.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from .base import BaseTool, ToolResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SubagentTool(BaseTool):
    """Create a child agent in the DB and run it synchronously to completion.

    Only registered for parents with can_create_subagents=True (per the module
    docstring). The child is linked to its parent via parent_agent_id.
    """

    name = "subagent"
    description = (
        "Create and run a sub-agent synchronously. "
        "Use this to delegate a well-defined sub-task to a focused agent. "
        "The sub-agent runs to completion and returns its result. "
        "Operation: create_and_run(name, prompt, model=None)"
    )
    input_schema = {
        "type": "object",
        "properties": {
            "name": {
                "type": "string",
                "description": "Short name for the sub-agent",
            },
            "prompt": {
                "type": "string",
                "description": "The task prompt for the sub-agent",
            },
            "model": {
                "type": "string",
                "description": "Model override for the sub-agent (optional)",
            },
        },
        "required": ["name", "prompt"],
    }
    requires_confirmation = False
    allowed_in_scheduled_tasks = False

    def __init__(self, parent_agent_id: str, parent_model: str) -> None:
        # Parent identity is bound at construction so the child can be linked
        # and can inherit the parent's model when no override is given.
        self._parent_agent_id = parent_agent_id
        self._parent_model = parent_model

    async def execute(
        self,
        name: str,
        prompt: str,
        model: str = "",
        **kwargs,
    ) -> ToolResult:
        """Create the sub-agent, run it, wait, and return its final run record.

        Success mirrors the run's status; token counts are passed through so
        the parent can account for the child's usage.
        """
        from ..agents import tasks as agent_store
        from ..agents.runner import agent_runner

        # Model override falls back to the parent's model.
        resolved_model = model or self._parent_model

        try:
            # Create sub-agent in DB
            sub = agent_store.create_agent(
                name=name,
                prompt=prompt,
                model=resolved_model,
                parent_agent_id=self._parent_agent_id,
                created_by=self._parent_agent_id,
                enabled=True,
            )
            sub_id = sub["id"]

            # Run it now and wait for completion
            run = await agent_runner.run_agent_now(sub_id)
            run_id = run["id"]

            # Wait for the asyncio task to finish (run_agent_now creates the task)
            # NOTE(review): this reaches into the runner's private _running
            # dict; consider a public wait API on agent_runner.
            import asyncio
            task = agent_runner._running.get(run_id)
            if task:
                await task

            # Fetch final run record
            final = agent_store.get_run(run_id)
            if not final:
                return ToolResult(success=False, error="Sub-agent run record not found")

            return ToolResult(
                success=final["status"] == "success",
                data={
                    "run_id": run_id,
                    "agent_id": sub_id,
                    "status": final["status"],
                    "result": final.get("result") or "",
                    "input_tokens": final.get("input_tokens", 0),
                    "output_tokens": final.get("output_tokens", 0),
                },
                error=final.get("error") if final["status"] != "success" else None,
            )

        except Exception as e:
            logger.error(f"[subagent] Error running sub-agent: {e}")
            return ToolResult(success=False, error=f"Sub-agent error: {e}")
|
||||
@@ -0,0 +1,123 @@
|
||||
"""
|
||||
tools/telegram_tool.py — Outbound Telegram messages for agents.
|
||||
|
||||
Sends to whitelisted chat IDs only. No confirmation required.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import httpx
|
||||
|
||||
from ..database import credential_store
|
||||
from ..telegram.triggers import is_allowed
|
||||
from .base import BaseTool, ToolResult
|
||||
|
||||
_API = "https://api.telegram.org/bot{token}/sendMessage"
|
||||
|
||||
|
||||
class TelegramTool(BaseTool):
    """Send an outbound Telegram message to a whitelisted chat ID.

    The bot token is resolved from the global credential store first, then
    from the current user's per-user settings. The target chat_id must pass
    the whitelist check (user scope, then global scope) before any send.
    """

    name = "telegram"
    description = (
        "Send a Telegram message to a whitelisted chat ID. "
        "Use for notifications, alerts, and replies to Telegram users. "
        "chat_id must be in the Telegram whitelist."
    )
    input_schema = {
        "type": "object",
        "properties": {
            "chat_id": {
                "type": "string",
                "description": "Telegram chat ID to send the message to",
            },
            "message": {
                "type": "string",
                "description": "Message text (plain text, max 4096 characters)",
            },
        },
        "required": ["chat_id", "message"],
    }
    requires_confirmation = False
    allowed_in_scheduled_tasks = True

    async def execute(self, chat_id: str, message: str, **kwargs) -> ToolResult:
        """Send *message* to *chat_id* after token and whitelist resolution.

        Returns a failed ToolResult (never raises) when the token is missing,
        the chat_id is not whitelisted, or the API/transport call fails.
        """
        # Resolve current user for per-user token and whitelist check
        from ..context_vars import current_user as _current_user
        user = _current_user.get(None)

        # Try global token first, then per-user token
        token = await credential_store.get("telegram:bot_token")
        if not token and user:
            from ..database import user_settings_store
            token = await user_settings_store.get(user.id, "telegram_bot_token")
        if not token:
            return ToolResult(
                success=False,
                error="Telegram is not configured. Add telegram:bot_token in Settings → Credentials.",
            )

        # Security: chat_id must be whitelisted (check user scope, then global)
        allowed = await is_allowed(chat_id, user_id=user.id if user else "GLOBAL")
        if not allowed and user:
            allowed = await is_allowed(chat_id, user_id="GLOBAL")
        if not allowed:
            # Return whitelisted IDs so the model can retry with a real one
            from ..telegram.triggers import list_whitelist
            global_ids = [r["chat_id"] for r in await list_whitelist("GLOBAL")]
            user_ids = [r["chat_id"] for r in await list_whitelist(user.id)] if user else []
            all_ids = list(dict.fromkeys(global_ids + user_ids))  # deduplicate, preserve order
            hint = f" Whitelisted chat_ids: {', '.join(all_ids)}." if all_ids else " No chat IDs are whitelisted yet — add one in Settings → Telegram."
            return ToolResult(
                success=False,
                error=f"chat_id '{chat_id}' is not in the Telegram whitelist.{hint}",
            )

        try:
            async with httpx.AsyncClient(timeout=10) as http:
                # Telegram caps message text at 4096 chars; clamp rather than fail.
                resp = await http.post(
                    _API.format(token=token),
                    json={"chat_id": chat_id, "text": message[:4096]},
                )
                resp.raise_for_status()
                data = resp.json()

            # The Bot API can return HTTP 200 with ok=false.
            if not data.get("ok"):
                return ToolResult(success=False, error=f"Telegram API error: {data}")

            return ToolResult(success=True, data={"sent": True, "chat_id": chat_id})

        except httpx.HTTPStatusError as e:
            return ToolResult(success=False, error=f"Telegram HTTP error: {e.response.status_code}")
        except Exception as e:
            return ToolResult(success=False, error=f"Telegram error: {e}")
|
||||
|
||||
|
||||
class BoundTelegramTool(TelegramTool):
    """TelegramTool with a pre-configured chat_id — used by email handling accounts.

    The model only supplies the message text; the destination chat_id is
    fixed when the account is configured, so the model cannot send to
    arbitrary chats. An optional reply-keyword hint is appended to every
    outgoing message.
    """

    def __init__(self, chat_id: str, reply_keyword: str | None = None) -> None:
        self._bound_chat_id = chat_id
        self._reply_keyword = reply_keyword
        # Appended to every outgoing message when a reply keyword exists.
        self._reply_hint = (
            f"\n\n💬 Reply: /{reply_keyword} <your message>" if reply_keyword else ""
        )
        keyword_note = f" (/{reply_keyword})" if reply_keyword else ""
        # Instance attributes shadow the class-level description/schema so
        # the model sees a message-only interface.
        self.description = (
            f"Send a Telegram message to the configured chat (chat_id {chat_id}). "
            "Only supply the message text — the destination is fixed. "
            f"A reply hint will be appended automatically{keyword_note}."
        )
        self.input_schema = {
            "type": "object",
            "properties": {
                "message": {
                    "type": "string",
                    "description": "Message text (plain text, max 4096 characters)",
                },
            },
            "required": ["message"],
        }

    async def execute(self, message: str, **kwargs) -> ToolResult:  # type: ignore[override]
        """Send *message* (plus the reply hint, if any) to the bound chat."""
        if self._reply_hint:
            message = message + self._reply_hint
        return await super().execute(chat_id=self._bound_chat_id, message=message, **kwargs)
|
||||
@@ -0,0 +1,279 @@
|
||||
"""
|
||||
tools/web_tool.py — Tiered web access.
|
||||
|
||||
Tier 1: Domains in WEB_TIER1_WHITELIST — always allowed.
|
||||
Tier 2: Any other domain — allowed only when web_tier2_enabled is True
|
||||
in the current execution context (set by the agent loop when the
|
||||
user explicitly requests external web research) or when running
|
||||
as a scheduled task that declared web access.
|
||||
|
||||
DuckDuckGo search uses the HTML endpoint (no API key required).
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from urllib.parse import quote_plus, urlparse
|
||||
|
||||
import httpx
|
||||
from bs4 import BeautifulSoup
|
||||
|
||||
from ..context_vars import current_task_id, web_tier2_enabled
|
||||
from ..security import SecurityError, assert_domain_tier1, sanitize_external_content
|
||||
from ..security_screening import get_content_limit, is_option_enabled
|
||||
from .base import BaseTool, ToolResult
|
||||
|
||||
MAX_RESPONSE_BYTES = 50 * 1024 # 50 KB (legacy fallback when truncation option disabled)
|
||||
_DEFAULT_MAX_WEB_CHARS = 20_000 # default when truncation option is enabled
|
||||
REQUEST_TIMEOUT = 10 # seconds
|
||||
MAX_SEARCH_RESULTS = 10
|
||||
|
||||
_HEADERS = {
|
||||
"User-Agent": (
|
||||
"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) "
|
||||
"AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36"
|
||||
),
|
||||
"Accept-Language": "en-US,en;q=0.9,nb;q=0.8",
|
||||
}
|
||||
|
||||
|
||||
class WebTool(BaseTool):
    """Tiered web access: page fetching and DuckDuckGo search.

    Tier 1 domains (those passing ``assert_domain_tier1``) are always
    reachable. Any other domain requires either the ``web_tier2_enabled``
    context flag (user-initiated web research) or an active scheduled task
    (``current_task_id`` set).
    """

    name = "web"
    description = (
        "Fetch web pages and search the web. "
        "Operations: fetch_page (retrieve and extract text from a URL), "
        "search (DuckDuckGo search, returns titles, URLs and snippets). "
        "Commonly used sites (Wikipedia, yr.no, etc.) are always available. "
        "Other sites require the user to have initiated a web research task."
    )
    input_schema = {
        "type": "object",
        "properties": {
            "operation": {
                "type": "string",
                "enum": ["fetch_page", "search"],
                "description": "fetch_page retrieves a URL; search queries DuckDuckGo",
            },
            "url": {
                "type": "string",
                "description": "URL to fetch (required for fetch_page)",
            },
            "query": {
                "type": "string",
                "description": "Search query (required for search)",
            },
            "num_results": {
                "type": "integer",
                "description": f"Max search results to return (default 5, max {MAX_SEARCH_RESULTS})",
            },
        },
        "required": ["operation"],
    }
    requires_confirmation = False
    allowed_in_scheduled_tasks = True

    async def execute(
        self,
        operation: str,
        url: str = "",
        query: str = "",
        num_results: int = 5,
        **kwargs,
    ) -> ToolResult:
        """Route to fetch_page or search, validating required arguments."""
        if operation == "fetch_page":
            if not url:
                return ToolResult(success=False, error="url is required for fetch_page")
            return await self._fetch_page(url)

        if operation == "search":
            if not query:
                return ToolResult(success=False, error="query is required for search")
            # Clamp to [1, MAX_SEARCH_RESULTS]: a zero or negative
            # num_results previously produced an empty result set silently.
            return await self._search(query, max(1, min(num_results, MAX_SEARCH_RESULTS)))

        return ToolResult(success=False, error=f"Unknown operation: {operation!r}")

    # ── Tier check ────────────────────────────────────────────────────────────

    async def _check_tier(self, url: str) -> ToolResult | None:
        """
        Returns a ToolResult(success=False) if access is denied, None if allowed.
        Tier 1 is always allowed. Tier 2 requires context flag or scheduled task.
        """
        if await assert_domain_tier1(url):
            return None  # Tier 1 — always allowed

        # Tier 2 — check execution context
        task_id = current_task_id.get()
        tier2 = web_tier2_enabled.get()

        if task_id is not None:
            # Scheduled tasks that declared web access can use Tier 2
            return None

        if tier2:
            return None  # User explicitly initiated web research

        parsed = urlparse(url)
        return ToolResult(
            success=False,
            error=(
                f"Domain '{parsed.hostname}' is not in the Tier 1 whitelist. "
                "To access it, ask me to search the web or fetch a specific external page — "
                "I'll enable Tier 2 access for your request."
            ),
        )

    # ── fetch_page ────────────────────────────────────────────────────────────

    async def _fetch_page(self, url: str) -> ToolResult:
        """Fetch *url*, extract readable text, truncate and sanitize it.

        Returns a ToolResult with url/content/status_code on success, or a
        descriptive error on tier denial, timeout, HTTP error, or non-text
        content.
        """
        denied = await self._check_tier(url)
        if denied:
            return denied

        try:
            async with httpx.AsyncClient(
                timeout=REQUEST_TIMEOUT,
                follow_redirects=True,
                headers=_HEADERS,
            ) as client:
                resp = await client.get(url)
                resp.raise_for_status()

                content_type = resp.headers.get("content-type", "")
                if "text" not in content_type and "html" not in content_type:
                    return ToolResult(
                        success=False,
                        error=f"Non-text content type: {content_type}. Only text/HTML pages are supported.",
                    )

                # Read more than the limit, truncate after parsing — HTML
                # markup overhead means the extracted text is much shorter.
                raw = resp.content[:MAX_RESPONSE_BYTES * 2]
                text = _extract_text(raw)

                if await is_option_enabled("system:security_truncation_enabled"):
                    max_chars = await get_content_limit("system:security_max_web_chars", _DEFAULT_MAX_WEB_CHARS)
                    if len(text) > max_chars:
                        text = text[:max_chars]
                        text += f"\n\n[Content truncated at {max_chars:,} chars]"
                elif len(text.encode()) > MAX_RESPONSE_BYTES:
                    # Truncate on encoded BYTES (the guard above measures
                    # bytes too); errors="ignore" drops any multi-byte
                    # character cut in half at the boundary.
                    text = text.encode()[:MAX_RESPONSE_BYTES].decode("utf-8", errors="ignore")
                    text += f"\n\n[Content truncated at {MAX_RESPONSE_BYTES // 1024} KB]"

                # Strip prompt-injection / unsafe content before returning
                # external text to the model.
                text = await sanitize_external_content(text, source="web")

                return ToolResult(
                    success=True,
                    data={
                        "url": str(resp.url),
                        "content": text,
                        "status_code": resp.status_code,
                    },
                )

        except httpx.TimeoutException:
            return ToolResult(success=False, error=f"Request timed out after {REQUEST_TIMEOUT}s: {url}")
        except httpx.HTTPStatusError as e:
            return ToolResult(success=False, error=f"HTTP {e.response.status_code}: {url}")
        except Exception as e:
            return ToolResult(success=False, error=f"Fetch error: {e}")

    # ── search ────────────────────────────────────────────────────────────────

    async def _search(self, query: str, num_results: int) -> ToolResult:
        """Search DuckDuckGo's HTML endpoint and return parsed results.

        POSTs the query first; falls back to a GET request if the POST
        response yields no parseable results.
        """
        # DuckDuckGo is Tier 1 — always allowed, no tier check needed
        ddg_url = f"https://html.duckduckgo.com/html/?q={quote_plus(query)}"

        try:
            async with httpx.AsyncClient(
                timeout=REQUEST_TIMEOUT,
                follow_redirects=True,
                headers={**_HEADERS, "Accept": "text/html"},
            ) as client:
                resp = await client.post(
                    "https://html.duckduckgo.com/html/",
                    data={"q": query, "b": "", "kl": ""},
                    headers=_HEADERS,
                )
                resp.raise_for_status()

                results = _parse_ddg_results(resp.text, num_results)

                if not results:
                    # Fallback: try GET (follow redirects here too, matching
                    # the POST path above).
                    async with httpx.AsyncClient(
                        timeout=REQUEST_TIMEOUT,
                        follow_redirects=True,
                        headers=_HEADERS,
                    ) as fallback_client:
                        resp = await fallback_client.get(ddg_url)
                        results = _parse_ddg_results(resp.text, num_results)

                # Sanitise snippets before they reach the model
                for r in results:
                    r["snippet"] = await sanitize_external_content(r["snippet"], source="web")

                return ToolResult(
                    success=True,
                    data={"query": query, "results": results, "count": len(results)},
                )

        except httpx.TimeoutException:
            return ToolResult(success=False, error=f"Search timed out after {REQUEST_TIMEOUT}s")
        except Exception as e:
            return ToolResult(success=False, error=f"Search error: {e}")
|
||||
|
||||
|
||||
# ── HTML helpers ──────────────────────────────────────────────────────────────
|
||||
|
||||
def _extract_text(raw: bytes) -> str:
    """Parse *raw* HTML and return clean, readable plain text."""
    soup = BeautifulSoup(raw, "html.parser")

    # Drop elements that carry boilerplate or code rather than content.
    for noise in soup(["script", "style", "nav", "footer", "header", "aside", "noscript"]):
        noise.decompose()

    # Extract text with newline separators, then drop blank/whitespace lines.
    stripped = (line.strip() for line in soup.get_text(separator="\n").splitlines())
    text = "\n".join(line for line in stripped if line)

    # Squash runs of 3+ newlines down to a single blank line.
    return re.sub(r"\n{3,}", "\n\n", text)
|
||||
|
||||
|
||||
def _parse_ddg_results(html: str, limit: int) -> list[dict]:
    """Parse a DuckDuckGo HTML results page.

    Args:
        html: Raw HTML of the DDG results page.
        limit: Maximum number of results to return.

    Returns:
        A list of ``{"title", "url", "snippet"}`` dicts, at most *limit* long.
    """
    # Hoisted out of the per-result loop (previously re-imported on every
    # redirect-wrapped URL).
    from urllib.parse import parse_qs, unquote

    soup = BeautifulSoup(html, "html.parser")
    results: list[dict] = []

    for result in soup.select(".result__body, .result"):
        if len(results) >= limit:
            break

        title_el = result.select_one(".result__title, .result__a")
        url_el = result.select_one(".result__url, a.result__a")
        snippet_el = result.select_one(".result__snippet")

        title = title_el.get_text(strip=True) if title_el else ""
        url = ""
        if url_el:
            href = url_el.get("href", "")
            # DDG wraps outbound URLs in a redirect — the real target is in
            # the 'uddg' query parameter.
            if "uddg=" in href:
                qs = parse_qs(urlparse(href).query)
                url = unquote(qs.get("uddg", [""])[0])
            elif href.startswith("http"):
                url = href
            else:
                url = url_el.get_text(strip=True)

        snippet = snippet_el.get_text(strip=True) if snippet_el else ""

        # Keep only results with a title and at least one of url/snippet.
        if title and (url or snippet):
            results.append({
                "title": title,
                "url": url,
                "snippet": snippet,
            })

    return results
|
||||
@@ -0,0 +1,58 @@
|
||||
"""
|
||||
tools/whitelist_tool.py — Web domain whitelist management for agents (async).
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from ..database import web_whitelist_store
|
||||
from .base import BaseTool, ToolResult
|
||||
|
||||
|
||||
class WhitelistTool(BaseTool):
    """Agent-facing management of the Tier-1 web domain whitelist."""

    name = "whitelist"
    description = (
        "Manage the Tier-1 web domain whitelist. "
        "Use this when the user asks to add, remove, or list whitelisted websites. "
        "Operations: list_domains, add_domain, remove_domain."
    )
    input_schema = {
        "type": "object",
        "properties": {
            "operation": {
                "type": "string",
                "enum": ["list_domains", "add_domain", "remove_domain"],
                "description": "Operation to perform",
            },
            "domain": {
                "type": "string",
                "description": "Domain to add or remove (e.g. 'nrk.no'). Required for add_domain and remove_domain.",
            },
            "note": {
                "type": "string",
                "description": "Optional description of why this domain is whitelisted.",
            },
        },
        "required": ["operation"],
    }
    requires_confirmation = False
    allowed_in_scheduled_tasks = True

    async def execute(self, operation: str, domain: str = "", note: str = "", **kwargs) -> ToolResult:
        """Dispatch *operation* against the web whitelist store."""
        if operation == "list_domains":
            entries = await web_whitelist_store.list()
            return ToolResult(success=True, data={"domains": entries, "count": len(entries)})

        if operation in ("add_domain", "remove_domain"):
            # Both mutating operations need a domain argument.
            if not domain:
                return ToolResult(success=False, error=f"domain is required for {operation}")
            target = domain.strip()

            if operation == "add_domain":
                await web_whitelist_store.add(target, note)
                return ToolResult(success=True, data={"added": target, "note": note})

            if await web_whitelist_store.remove(target):
                return ToolResult(success=True, data={"removed": target})
            return ToolResult(success=False, error=f"Domain '{domain}' not found in whitelist")

        return ToolResult(success=False, error=f"Unknown operation: {operation}")
|
||||
Reference in New Issue
Block a user