Advanced Tool Development
This guide covers advanced techniques for developing custom tools that extend the capabilities of PraisonAI agents. Learn how to create robust, reusable, and performant tools.
Overview
Tools are the bridge between AI agents and the external world. Advanced tool development involves:
- Complex parameter handling
- Async operations
- Error recovery
- State management
- Performance optimization
- Security considerations
Tool Architecture
Basic Tool Structure
Copy
import asyncio
import hashlib
import ipaddress
import json
import logging
import os
import random
import re
import socket
import time
from abc import ABC, abstractmethod
from collections import defaultdict
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union
from urllib.parse import urlparse
@dataclass
class ToolResult:
    """Uniform envelope returned by every tool invocation.

    Attributes:
        success: True when the tool ran to completion.
        data: Payload produced by the tool (None on failure).
        error: Human-readable failure reason, or None on success.
        metadata: Optional extra context (timings, cache flags, ...).
    """
    success: bool
    data: Any
    error: Optional[str] = None
    metadata: Optional[Dict[str, Any]] = None
class BaseTool(ABC):
    """Abstract base class shared by every tool.

    Concrete subclasses must implement execute(), validate_params()
    and _get_parameters().
    """

    def __init__(self, name: str, description: str):
        self.name = name
        self.description = description
        # Per-tool logger, namespaced under "tool." for easy filtering.
        self.logger = logging.getLogger(f"tool.{name}")
        # Simple usage counters; updated by subclasses/wrappers as needed.
        self._metrics = {"calls": 0, "errors": 0, "total_duration": 0}

    @abstractmethod
    async def execute(self, **kwargs) -> ToolResult:
        """Run the tool and return a ToolResult."""

    @abstractmethod
    def validate_params(self, **kwargs) -> bool:
        """Return True when the supplied parameters are acceptable."""

    def get_schema(self) -> Dict[str, Any]:
        """Describe the tool: name, description and parameter schema."""
        return {
            "name": self.name,
            "description": self.description,
            "parameters": self._get_parameters(),
        }

    @abstractmethod
    def _get_parameters(self) -> Dict[str, Any]:
        """Return the JSON-schema description of the tool's parameters."""
Advanced Tool Patterns
1. Stateful Tools
Copy
class StatefulDatabaseTool(BaseTool):
    """Tool that maintains database connection state.

    Keeps an asyncpg connection pool alive across calls and caches
    query results in memory (bounded FIFO, see _MAX_CACHE_ENTRIES).
    """

    # Cap the result cache so long-running agents don't leak memory
    # (the original cache grew without bound and never evicted).
    _MAX_CACHE_ENTRIES = 256

    def __init__(self, connection_string: str):
        super().__init__(
            name="database_query",
            description="Execute database queries with connection pooling"
        )
        self.connection_string = connection_string
        self._connection_pool = None
        self._cache = {}
        # Guards lazy pool creation when several coroutines race.
        self._lock = asyncio.Lock()

    async def initialize(self):
        """Create the connection pool on first use (double-checked under the lock)."""
        if not self._connection_pool:
            async with self._lock:
                if not self._connection_pool:
                    import asyncpg  # local import: optional dependency
                    self._connection_pool = await asyncpg.create_pool(
                        self.connection_string,
                        min_size=5,
                        max_size=20,
                        command_timeout=60
                    )

    async def execute(self, query: str, parameters: List[Any] = None) -> ToolResult:
        """Execute *query* (optionally parameterized) and return its rows.

        Results are served from the in-memory cache when the same
        query/parameters pair has been seen before.
        """
        try:
            await self.initialize()

            cache_key = f"{query}:{str(parameters)}"
            if cache_key in self._cache:
                self.logger.info("Cache hit")
                return ToolResult(
                    success=True,
                    data=self._cache[cache_key],
                    metadata={"cached": True}
                )

            async with self._connection_pool.acquire() as connection:
                if parameters:
                    result = await connection.fetch(query, *parameters)
                else:
                    result = await connection.fetch(query)

            rows = [dict(row) for row in result]
            # BUGFIX: bound the cache. Evict the oldest entry once the
            # cap is reached (dict preserves insertion order, so the
            # first key is the oldest).
            if len(self._cache) >= self._MAX_CACHE_ENTRIES:
                self._cache.pop(next(iter(self._cache)))
            self._cache[cache_key] = rows
            return ToolResult(
                success=True,
                data=rows,
                metadata={"rows": len(result)}
            )
        except Exception as e:
            self.logger.error(f"Database error: {str(e)}")
            return ToolResult(
                success=False,
                data=None,
                error=str(e)
            )

    async def cleanup(self):
        """Close the pool and release all connections."""
        if self._connection_pool:
            await self._connection_pool.close()

    def validate_params(self, query: str, **kwargs) -> bool:
        """Reject obviously destructive statements that lack a WHERE clause.

        NOTE: this is a coarse heuristic, not real SQL-injection
        protection — always use parameterized queries as well.
        """
        dangerous_keywords = ["DROP", "DELETE", "TRUNCATE", "ALTER"]
        query_upper = query.upper()
        for keyword in dangerous_keywords:
            # BUGFIX: word boundaries avoid false positives such as a
            # column named "DROPDOWN" containing the substring "DROP".
            if re.search(rf"\b{keyword}\b", query_upper) and "WHERE" not in query_upper:
                return False
        return True

    def _get_parameters(self) -> Dict[str, Any]:
        """JSON schema for the tool's parameters."""
        return {
            "type": "object",
            "properties": {
                "query": {
                    "type": "string",
                    "description": "SQL query to execute"
                },
                "parameters": {
                    "type": "array",
                    "description": "Query parameters",
                    "items": {"type": ["string", "number", "boolean"]}
                }
            },
            "required": ["query"]
        }
2. Streaming Tools
Copy
class StreamingDataTool(BaseTool):
    """Tool that consumes line-delimited JSON streams in real time."""

    # SECURITY: safe, named transforms that callers may request by name.
    # The original eval()'d the caller-supplied process_function string,
    # which is arbitrary remote code execution for an agent-facing tool.
    _TRANSFORMS = {
        "identity": lambda item: item,
        "keys": lambda item: sorted(item) if isinstance(item, dict) else item,
    }

    def __init__(self):
        super().__init__(
            name="stream_processor",
            description="Process data streams in real-time"
        )
        self._active_streams = {}

    async def execute(self,
                      stream_url: str,
                      process_function: Optional[str] = None,
                      max_items: int = 100) -> ToolResult:
        """Open *stream_url* and register an async generator over its items.

        process_function, when given, must be the NAME of a transform
        registered in _TRANSFORMS — never executable code.
        """
        stream_id = hashlib.md5(stream_url.encode()).hexdigest()

        # Resolve the transform by name; reject unknown names up front.
        if process_function is not None and process_function not in self._TRANSFORMS:
            return ToolResult(
                success=False,
                data=None,
                error=f"Unknown process_function: {process_function}"
            )
        transform = self._TRANSFORMS[process_function or "identity"]

        async def stream_generator():
            """Yield up to max_items parsed (and transformed) stream items."""
            async with aiohttp.ClientSession() as session:
                async with session.get(stream_url) as response:
                    count = 0
                    async for line in response.content:
                        if count >= max_items:
                            break
                        try:
                            data = json.loads(line)
                            yield transform(data)
                            count += 1
                        except Exception as e:
                            self.logger.error(f"Stream processing error: {e}")
                            continue

        # Store the (not yet started) generator for later read_stream calls.
        self._active_streams[stream_id] = stream_generator()
        return ToolResult(
            success=True,
            data={"stream_id": stream_id},
            metadata={"type": "stream", "max_items": max_items}
        )

    async def read_stream(self, stream_id: str, batch_size: int = 10):
        """Pull up to batch_size items from a previously opened stream."""
        if stream_id not in self._active_streams:
            return ToolResult(
                success=False,
                data=None,
                error="Stream not found"
            )

        items = []
        stream = self._active_streams[stream_id]
        try:
            for _ in range(batch_size):
                item = await anext(stream, None)  # requires Python 3.10+
                if item is None:
                    break
                items.append(item)
            return ToolResult(
                success=True,
                data=items,
                metadata={"count": len(items)}
            )
        except Exception as e:
            return ToolResult(
                success=False,
                data=None,
                error=str(e)
            )
3. Composite Tools
Copy
class CompositeAnalysisTool(BaseTool):
    """Runs several sub-tools as a sequential analysis pipeline."""

    def __init__(self, sub_tools: List[BaseTool]):
        super().__init__(
            name="composite_analyzer",
            description="Perform multi-stage analysis"
        )
        # Index sub-tools by name for O(1) stage lookup.
        self.sub_tools = {tool.name: tool for tool in sub_tools}
        self._pipeline = []

    def add_stage(self, tool_name: str, params: Dict[str, Any]):
        """Append a stage (tool name + parameters) to the pipeline."""
        if tool_name not in self.sub_tools:
            raise ValueError(f"Tool {tool_name} not found")
        self._pipeline.append({"tool": tool_name, "params": params})

    async def execute(self, initial_data: Any) -> ToolResult:
        """Feed initial_data through each stage, chaining outputs to inputs."""
        current_data = initial_data
        completed = []

        for stage in self._pipeline:
            stage_name = stage["tool"]
            stage_tool = self.sub_tools[stage_name]

            # Each stage receives the previous stage's output as input_data.
            stage_params = dict(stage["params"], input_data=current_data)
            outcome = await stage_tool.execute(**stage_params)

            if not outcome.success:
                # Abort the pipeline, reporting what had finished so far.
                return ToolResult(
                    success=False,
                    data=completed,
                    error=f"Stage {stage_name} failed: {outcome.error}"
                )

            current_data = outcome.data
            completed.append({"stage": stage_name, "result": outcome.data})

        return ToolResult(
            success=True,
            data={"final_result": current_data, "stages": completed},
            metadata={"stages_completed": len(completed)}
        )
Tool Integration Patterns
1. REST API Integration
Copy
class RESTAPITool(BaseTool):
    """Generic REST API integration tool with retries and rate limiting."""

    def __init__(self,
                 base_url: str,
                 auth_type: str = "bearer",
                 rate_limit: int = 100):
        super().__init__(
            name="rest_api",
            description="Interact with REST APIs"
        )
        self.base_url = base_url
        self.auth_type = auth_type
        self.rate_limit = rate_limit
        self._rate_limiter = RateLimiter(rate_limit)
        self._session = None

    async def _get_session(self) -> aiohttp.ClientSession:
        """Lazily create the shared retrying client session."""
        if not self._session:
            timeout = aiohttp.ClientTimeout(total=30)
            connector = aiohttp.TCPConnector(limit=50)
            retry_options = ExponentialRetry(
                attempts=3,
                start_timeout=1,
                max_timeout=10
            )
            self._session = RetryClient(
                connector=connector,
                timeout=timeout,
                retry_options=retry_options
            )
        return self._session

    async def close(self):
        """Release the HTTP session and its sockets.

        BUGFIX: the session was created lazily but never closed, leaking
        the connector when the tool was discarded. Call this on shutdown.
        """
        if self._session:
            await self._session.close()
            self._session = None

    async def execute(self,
                      method: str,
                      endpoint: str,
                      data: Optional[Dict] = None,
                      headers: Optional[Dict] = None) -> ToolResult:
        """Issue an HTTP request and wrap the response in a ToolResult.

        Args:
            method: HTTP verb ("GET", "POST", ...), case-insensitive.
            endpoint: Path appended to base_url.
            data: Optional JSON request body.
            headers: Optional extra headers (override the defaults).
        """
        # Respect the configured request budget before touching the wire.
        await self._rate_limiter.acquire()

        url = f"{self.base_url}/{endpoint.lstrip('/')}"
        session = await self._get_session()

        request_headers = {
            "Content-Type": "application/json",
            "User-Agent": "PraisonAI-Tool/1.0"
        }
        if headers:
            request_headers.update(headers)

        # Attach bearer auth from the environment unless the caller
        # already supplied an Authorization header.
        if self.auth_type == "bearer" and "Authorization" not in request_headers:
            token = os.getenv("API_TOKEN")
            if token:
                request_headers["Authorization"] = f"Bearer {token}"

        try:
            async with session.request(
                method=method.upper(),
                url=url,
                json=data,
                headers=request_headers
            ) as response:
                # BUGFIX: APIs often return HTML/plain-text error pages;
                # fall back to the raw body instead of crashing the
                # whole call on a JSON parse error.
                try:
                    response_data = await response.json()
                except Exception:
                    response_data = await response.text()

                if response.status >= 400:
                    return ToolResult(
                        success=False,
                        data=response_data,
                        error=f"API error: {response.status}",
                        metadata={"status_code": response.status}
                    )
                return ToolResult(
                    success=True,
                    data=response_data,
                    metadata={
                        "status_code": response.status,
                        "headers": dict(response.headers)
                    }
                )
        except Exception as e:
            self.logger.error(f"API request failed: {str(e)}")
            return ToolResult(
                success=False,
                data=None,
                error=str(e)
            )
2. WebSocket Integration
Copy
class WebSocketTool(BaseTool):
    """Real-time WebSocket communication tool."""

    def __init__(self, ws_url: str):
        super().__init__(
            name="websocket",
            description="Real-time WebSocket communication"
        )
        self.ws_url = ws_url
        self._connections = {}
        self._message_handlers = {}

    async def connect(self, connection_id: str) -> ToolResult:
        """Open a WebSocket connection and start its background listener."""
        session = None
        try:
            session = aiohttp.ClientSession()
            ws = await session.ws_connect(self.ws_url)
            self._connections[connection_id] = {
                "websocket": ws,
                "session": session,
                "active": True
            }
            # Listener runs until the socket closes or errors.
            asyncio.create_task(
                self._listen_messages(connection_id, ws)
            )
            return ToolResult(
                success=True,
                data={"connection_id": connection_id},
                metadata={"status": "connected"}
            )
        except Exception as e:
            # BUGFIX: close the session when ws_connect fails, otherwise
            # the freshly created ClientSession (and connector) leaks.
            if session is not None:
                await session.close()
            return ToolResult(
                success=False,
                data=None,
                error=str(e)
            )

    async def disconnect(self, connection_id: str):
        """Close a connection and release its session.

        BUGFIX: _listen_messages awaited self.disconnect() in its
        finally block, but this method did not exist — every listener
        exit raised AttributeError.
        """
        conn = self._connections.pop(connection_id, None)
        if conn:
            conn["active"] = False
            try:
                await conn["websocket"].close()
            finally:
                await conn["session"].close()

    async def _listen_messages(self, connection_id: str, ws):
        """Dispatch incoming TEXT frames to the registered handler."""
        try:
            async for msg in ws:
                if msg.type == aiohttp.WSMsgType.TEXT:
                    if connection_id in self._message_handlers:
                        await self._message_handlers[connection_id](msg.data)
                elif msg.type == aiohttp.WSMsgType.ERROR:
                    self.logger.error(f"WebSocket error: {ws.exception()}")
                    break
        except Exception as e:
            self.logger.error(f"Message listener error: {e}")
        finally:
            await self.disconnect(connection_id)

    async def send_message(self,
                           connection_id: str,
                           message: Union[str, Dict]) -> ToolResult:
        """Send a text (or JSON-encoded dict) message over a connection."""
        if connection_id not in self._connections:
            return ToolResult(
                success=False,
                data=None,
                error="Connection not found"
            )
        try:
            ws = self._connections[connection_id]["websocket"]
            if isinstance(message, dict):
                message = json.dumps(message)
            await ws.send_str(message)
            return ToolResult(
                success=True,
                data={"sent": message},
                metadata={"timestamp": time.time()}
            )
        except Exception as e:
            return ToolResult(
                success=False,
                data=None,
                error=str(e)
            )
Performance Optimization
1. Caching Strategies
Copy
class CachedTool(BaseTool):
    """Decorator-style tool that caches another tool's successful results."""

    def __init__(self,
                 wrapped_tool: BaseTool,
                 cache_backend: str = "memory",
                 ttl: int = 3600):
        super().__init__(
            name=f"cached_{wrapped_tool.name}",
            description=f"Cached version of {wrapped_tool.name}"
        )
        self.wrapped_tool = wrapped_tool
        self.ttl = ttl
        # Pick the cache implementation; fail fast on unknown backends.
        if cache_backend == "memory":
            self.cache = InMemoryCache()
        elif cache_backend == "redis":
            self.cache = RedisCache()
        else:
            raise ValueError(f"Unknown cache backend: {cache_backend}")

    def _generate_cache_key(self, **kwargs) -> str:
        """Build a deterministic key from the call's keyword arguments."""
        # Sort parameters so equal calls always map to the same key.
        sorted_params = sorted(kwargs.items())
        param_str = json.dumps(sorted_params, sort_keys=True)
        return hashlib.sha256(param_str.encode()).hexdigest()

    async def execute(self, **kwargs) -> ToolResult:
        """Serve from cache when possible, otherwise delegate and store."""
        cache_key = self._generate_cache_key(**kwargs)

        # BUGFIX: compare against None rather than truthiness — the old
        # `if cached_result:` treated cached falsy payloads ([], {}, 0,
        # "") as misses and re-executed the wrapped tool every call.
        cached_result = await self.cache.get(cache_key)
        if cached_result is not None:
            self.logger.info("Cache hit")
            return ToolResult(
                success=True,
                data=cached_result,
                metadata={"cached": True}
            )

        result = await self.wrapped_tool.execute(**kwargs)

        # Only successful results are worth replaying later.
        if result.success:
            await self.cache.set(
                cache_key,
                result.data,
                ttl=self.ttl
            )
        return result
class InMemoryCache:
    """Minimal process-local cache where every entry carries a TTL."""

    def __init__(self):
        self._cache = {}
        self._expiry = {}

    async def get(self, key: str) -> Optional[Any]:
        """Return the live value for *key*, or None if absent or expired."""
        if key not in self._cache:
            return None
        if time.time() >= self._expiry.get(key, 0):
            # Entry outlived its TTL: drop it lazily on access.
            del self._cache[key]
            del self._expiry[key]
            return None
        return self._cache[key]

    async def set(self, key: str, value: Any, ttl: int):
        """Store *value* under *key* for *ttl* seconds."""
        self._cache[key] = value
        self._expiry[key] = time.time() + ttl
2. Parallel Execution
Copy
class ParallelExecutorTool(BaseTool):
    """Fan out several tool operations concurrently under a semaphore."""

    def __init__(self,
                 max_concurrent: int = 10,
                 tools: Optional[Dict[str, BaseTool]] = None):
        super().__init__(
            name="parallel_executor",
            description="Execute multiple operations concurrently"
        )
        self.max_concurrent = max_concurrent
        # Name -> tool mapping consulted by _load_tool. New optional
        # parameter; the default keeps the original signature working.
        self._tools = dict(tools) if tools else {}
        self._semaphore = asyncio.Semaphore(max_concurrent)

    def register_tool(self, tool: BaseTool):
        """Expose *tool* to operations under its own name."""
        self._tools[tool.name] = tool

    def _load_tool(self, tool_name: str) -> Optional[BaseTool]:
        """Resolve a tool by name.

        BUGFIX: execute() called this method but it was never defined,
        so every run died with AttributeError.
        """
        return self._tools.get(tool_name)

    async def execute(self,
                      operations: List[Dict[str, Any]]) -> ToolResult:
        """Run each {"tool": name, "params": {...}} operation in parallel."""

        async def execute_with_semaphore(operation):
            # The semaphore caps how many operations run at once.
            async with self._semaphore:
                tool_name = operation["tool"]
                params = operation.get("params", {})
                tool = self._load_tool(tool_name)
                if not tool:
                    return {
                        "operation": operation,
                        "error": f"Tool {tool_name} not found"
                    }
                try:
                    result = await tool.execute(**params)
                    return {
                        "operation": operation,
                        "result": result.data,
                        "success": result.success
                    }
                except Exception as e:
                    return {
                        "operation": operation,
                        "error": str(e)
                    }

        tasks = [
            execute_with_semaphore(op)
            for op in operations
        ]
        results = await asyncio.gather(*tasks, return_exceptions=True)

        # BUGFIX: count raised exceptions (non-dict results) as failures
        # so that successful + failed always equals the total.
        successful = [r for r in results if isinstance(r, dict) and r.get("success")]
        failed = [r for r in results if not (isinstance(r, dict) and r.get("success"))]

        return ToolResult(
            success=len(failed) == 0,
            data={
                "results": results,
                "summary": {
                    "total": len(operations),
                    "successful": len(successful),
                    "failed": len(failed)
                }
            },
            metadata={"execution_time": time.time()}
        )
Security Considerations
1. Input Validation
Copy
class SecureTool(BaseTool):
    """Demonstrates input validation for common attack surfaces."""

    def __init__(self):
        super().__init__(
            name="secure_tool",
            description="Security-focused tool implementation"
        )
        # Dispatch table: input kind -> validator callable.
        self.validators = {
            "sql": self._validate_sql,
            "path": self._validate_path,
            "url": self._validate_url,
            "script": self._validate_script
        }

    def _validate_sql(self, query: str) -> bool:
        """Reject queries matching well-known SQL-injection patterns.

        NOTE: a denylist is defense-in-depth only — parameterized
        queries remain the real protection.
        """
        dangerous_patterns = [
            r";\s*DROP",
            r";\s*DELETE",
            r";\s*UPDATE",
            r";\s*INSERT",
            r"--",
            r"/\*.*\*/",
            r"UNION\s+SELECT",
            r"OR\s+1\s*=\s*1"
        ]
        for pattern in dangerous_patterns:
            if re.search(pattern, query, re.IGNORECASE):
                return False
        return True

    def _validate_path(self, path: str) -> bool:
        """Allow only paths that resolve inside an allow-listed directory.

        BUGFIX: the original rejected every absolute path up front, yet
        its allow-list contained only absolute directories — so a
        legitimate path under /tmp, /data or /uploads could never pass.
        """
        if ".." in path:
            return False
        allowed_dirs = ["/tmp", "/data", "/uploads"]
        resolved_path = os.path.abspath(path)
        # Require the directory itself or a child of it; the os.sep
        # suffix stops "/tmpfoo" from matching the "/tmp" prefix.
        return any(
            resolved_path == allowed_dir
            or resolved_path.startswith(allowed_dir + os.sep)
            for allowed_dir in allowed_dirs
        )

    def _validate_url(self, url: str) -> bool:
        """Accept only public http(s) URLs (basic SSRF protection)."""
        try:
            parsed = urlparse(url)
            if parsed.scheme not in ["http", "https"]:
                return False
            blocked_domains = ["localhost", "127.0.0.1", "0.0.0.0"]
            if parsed.hostname in blocked_domains:
                return False
            try:
                ip = socket.gethostbyname(parsed.hostname)
                if ipaddress.ip_address(ip).is_private:
                    return False
            except (socket.gaierror, OSError, ValueError):
                # BUGFIX: fail closed. The old bare except/pass let any
                # unresolvable host bypass the private-IP check.
                return False
            return True
        except Exception:
            return False

    def _validate_script(self, script: str) -> bool:
        """Reject scripts that reference dangerous builtins/modules.

        BUGFIX: plain substring matching flagged innocent words (e.g.
        "photos" contains "os"); match whole identifiers instead.
        """
        dangerous_imports = [
            "os", "subprocess", "eval", "exec",
            "__import__", "compile", "globals"
        ]
        for dangerous in dangerous_imports:
            if re.search(rf"\b{re.escape(dangerous)}\b", script):
                return False
        return True
2. Sandboxed Execution
Copy
class SandboxedCodeTool(BaseTool):
    """Execute code in a restricted environment with a hard timeout."""

    def __init__(self):
        super().__init__(
            name="sandboxed_code",
            description="Execute code safely in a sandbox"
        )
        self.sandbox_config = {
            "max_execution_time": 5,  # seconds (enforced via wait_for)
            "max_memory": 100 * 1024 * 1024,  # 100MB (advisory; not enforced here)
            "allowed_modules": ["math", "json", "datetime"]
        }

    async def execute(self, code: str, language: str = "python") -> ToolResult:
        """Dispatch on *language*; only Python is supported today."""
        if language == "python":
            return await self._execute_python(code)
        return ToolResult(
            success=False,
            data=None,
            error=f"Unsupported language: {language}"
        )

    async def _run_restricted(self, code_obj, globals_dict):
        """Run *code_obj* off the event loop and return its local namespace.

        BUGFIX: _execute_python awaited self._run_in_process, which was
        never defined, so every call raised AttributeError.
        NOTE(review): compiled code objects are not picklable, so this
        runs in a worker thread; true process isolation would need to
        ship the source text to a subprocess instead — confirm intent.
        """
        def runner():
            local_ns = {}
            exec(code_obj, globals_dict, local_ns)
            return local_ns

        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, runner)

    async def _execute_python(self, code: str) -> ToolResult:
        """Compile *code* with RestrictedPython and run it under a timeout."""
        # Expose only a minimal set of builtins to the sandboxed code.
        restricted_globals = {
            "__builtins__": {
                "len": len,
                "range": range,
                "str": str,
                "int": int,
                "float": float,
                "list": list,
                "dict": dict,
                "print": print
            }
        }
        # Pre-import the allow-listed modules into the sandbox namespace.
        for module in self.sandbox_config["allowed_modules"]:
            restricted_globals[module] = __import__(module)

        try:
            # RestrictedPython rejects attribute tricks, dunder access, etc.
            from RestrictedPython import compile_restricted
            compiled = compile_restricted(code, '<string>', 'exec')
            if getattr(compiled, "errors", None):
                return ToolResult(
                    success=False,
                    data=None,
                    error=f"Compilation errors: {compiled.errors}"
                )
            # Newer RestrictedPython returns a code object directly;
            # older versions return a result object with a .code field.
            code_obj = getattr(compiled, "code", compiled)
            result = await asyncio.wait_for(
                self._run_restricted(code_obj, restricted_globals),
                timeout=self.sandbox_config["max_execution_time"]
            )
            return ToolResult(
                success=True,
                data=result,
                metadata={"sandboxed": True}
            )
        except asyncio.TimeoutError:
            return ToolResult(
                success=False,
                data=None,
                error="Code execution timed out"
            )
        except Exception as e:
            return ToolResult(
                success=False,
                data=None,
                error=f"Execution error: {str(e)}"
            )
Testing Tools
1. Tool Testing Framework
Copy
class ToolTestCase:
    """Reusable test harness for BaseTool implementations.

    Subclasses must override _get_valid_params() and
    _get_invalid_params() to describe their tool's inputs.
    """

    def __init__(self, tool: "BaseTool"):
        self.tool = tool
        self.test_results = []

    def _get_valid_params(self) -> Dict[str, Any]:
        """Return one parameter set the tool should accept.

        BUGFIX: the test methods called this hook but it did not exist,
        so harnesses died with AttributeError instead of a clear
        "please override" message.
        """
        raise NotImplementedError("Subclasses must provide valid params")

    def _get_invalid_params(self) -> List[Dict[str, Any]]:
        """Return parameter sets the tool should reject."""
        raise NotImplementedError("Subclasses must provide invalid params")

    async def test_basic_functionality(self):
        """Check schema shape, parameter validation and a happy-path run."""
        schema = self.tool.get_schema()
        assert "name" in schema
        assert "parameters" in schema

        valid_params = self._get_valid_params()
        assert self.tool.validate_params(**valid_params)

        result = await self.tool.execute(**valid_params)
        assert isinstance(result, ToolResult)
        return True

    async def test_error_handling(self):
        """Every invalid parameter set must fail with an error message."""
        invalid_params = self._get_invalid_params()
        for params in invalid_params:
            result = await self.tool.execute(**params)
            assert not result.success
            assert result.error is not None
        return True

    async def test_performance(self, iterations: int = 100):
        """Time repeated executions and assert latency budgets."""
        import statistics
        execution_times = []
        valid_params = self._get_valid_params()

        for _ in range(iterations):
            start = time.time()
            await self.tool.execute(**valid_params)
            execution_times.append(time.time() - start)

        stats = {
            "mean": statistics.mean(execution_times),
            "median": statistics.median(execution_times),
            "stdev": statistics.stdev(execution_times) if len(execution_times) > 1 else 0,
            "min": min(execution_times),
            "max": max(execution_times)
        }
        # Latency budget: average under 1s, worst case under 5s.
        assert stats["mean"] < 1.0
        assert stats["max"] < 5.0
        return stats
2. Mock Tool for Testing
Copy
class MockTool(BaseTool):
    """Configurable stand-in tool for exercising agent integrations."""

    def __init__(self,
                 name: str = "mock_tool",
                 response_data: Any = None,
                 response_delay: float = 0.1,
                 fail_rate: float = 0.0):
        super().__init__(name=name, description="Mock tool for testing")
        self.response_data = response_data or {"status": "success"}
        self.response_delay = response_delay
        self.fail_rate = fail_rate
        # Every invocation is recorded here for later assertions.
        self.call_history = []

    async def execute(self, **kwargs) -> ToolResult:
        """Record the call, wait, then succeed or fail per fail_rate."""
        self.call_history.append({
            "timestamp": time.time(),
            "params": kwargs
        })

        # Simulate network/processing latency.
        await asyncio.sleep(self.response_delay)

        # Randomly fail at the configured rate.
        if random.random() < self.fail_rate:
            return ToolResult(
                success=False,
                data=None,
                error="Simulated failure"
            )
        return ToolResult(
            success=True,
            data=self.response_data,
            metadata={"mock": True}
        )

    def assert_called_with(self, **expected_params):
        """Succeed if any recorded call included all *expected_params*."""
        for recorded in self.call_history:
            actual = recorded["params"]
            if all(actual.get(key) == value
                   for key, value in expected_params.items()):
                return True
        raise AssertionError(
            f"Tool not called with params: {expected_params}"
        )
Tool Documentation
1. Auto-Documentation
Copy
class DocumentedTool(BaseTool):
    """Tool that can emit its own markdown documentation.

    NOTE(review): generate_docs() was reconstructed from a garbled
    source listing (the original f-string was truncated mid-template);
    verify the markdown layout against the intended output.
    """

    def __init__(self):
        super().__init__(
            name="documented_tool",
            description="Example of well-documented tool"
        )
        self._examples = []

    def add_example(self,
                    params: Dict[str, Any],
                    expected_result: Any,
                    description: str):
        """Record a usage example for the generated documentation."""
        self._examples.append({
            "params": params,
            "expected_result": expected_result,
            "description": description
        })

    def generate_docs(self) -> str:
        """Render name, description, parameter schema and examples as markdown."""
        docs = (
            f"# {self.name}\n\n"
            f"## Description\n{self.description}\n\n"
            f"## Parameters\n```json\n"
            f"{json.dumps(self._get_parameters(), indent=2)}\n```\n\n"
            f"## Examples\n"
        )
        for i, example in enumerate(self._examples, 1):
            docs += (
                f"\n### Example {i}: {example['description']}\n\n"
                f"Input:\n```json\n"
                f"{json.dumps(example['params'], indent=2)}\n```\n\n"
                f"Output:\n```json\n"
                f"{json.dumps(example['expected_result'], indent=2)}\n```\n"
            )
        return docs
Copy
## Best Practices
### 1. Tool Design Principles
#### Single Responsibility
Each tool should do one thing well. Avoid creating "Swiss Army knife" tools.
**Good - focused tool:**
```python
class EmailSenderTool(BaseTool):
    """Focused tool whose single responsibility is sending email."""

    async def execute(self, to: str, subject: str, body: str):
        pass  # Only handles email sending
Copy
class CommunicationTool(BaseTool):
    """Anti-example: a 'Swiss Army knife' tool multiplexing many channels."""

    async def execute(self, type: str, **kwargs):
        # One entry point juggling unrelated behaviors — avoid this shape.
        if type == "email":
            pass  # send email
        elif type == "sms":
            pass  # send sms
        elif type == "slack":
            pass  # send slack
Idempotency
Tools should be idempotent when possible — multiple calls with the same parameters should produce the same result.
Copy
class IdempotentTool(BaseTool):
    """Create-if-absent tool: repeated calls with the same id are no-ops."""

    async def execute(self, resource_id: str, **kwargs):
        # Return the existing resource untouched when it is already there.
        existing = await self.check_exists(resource_id)
        if existing:
            return ToolResult(
                success=True,
                data=existing,
                metadata={"created": False}
            )

        created = await self.create_resource(resource_id, **kwargs)
        return ToolResult(
            success=True,
            data=created,
            metadata={"created": True}
        )
Graceful Degradation
Tools should handle failures gracefully and provide meaningful feedback.
Copy
class ResilientTool(BaseTool):
    """Degrades gracefully: primary service, then secondary, then cache."""

    async def execute(self, **kwargs):
        try:
            # Preferred path.
            return await self.primary_method(**kwargs)
        except PrimaryMethodError:
            try:
                # First fallback: the alternate implementation.
                return await self.secondary_method(**kwargs)
            except SecondaryMethodError:
                # Last resort: stale-but-useful cached data, flagged as
                # a failure so callers know the live service was down.
                return ToolResult(
                    success=False,
                    data=self.get_cached_result(),
                    error="Service unavailable, returning cached data"
                )
2. Tool Registry Pattern
Copy
class ToolRegistry:
    """Process-wide singleton registry mapping names and aliases to tools."""

    _instance = None

    def __new__(cls):
        # Classic singleton: the first construction wins, later calls
        # reuse the same instance (and its registered tools).
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance.tools = {}
            cls._instance.categories = defaultdict(list)
        return cls._instance

    def register(self,
                 tool: "BaseTool",
                 category: str = "general",
                 aliases: List[str] = None):
        """Register *tool* under its name, a category, and optional aliases."""
        self.tools[tool.name] = tool
        self.categories[category].append(tool.name)
        # Aliases point at the same tool object as the primary name.
        if aliases:
            for alias in aliases:
                self.tools[alias] = tool

    def get_tool(self, name: str) -> Optional["BaseTool"]:
        """Look a tool up by primary name or alias."""
        return self.tools.get(name)

    def get_category(self, category: str) -> List["BaseTool"]:
        """Return every tool registered under *category*."""
        return [
            self.tools[name]
            for name in self.categories.get(category, [])
        ]

    def search_tools(self, query: str) -> List["BaseTool"]:
        """Case-insensitive substring search over names and descriptions.

        BUGFIX: aliased tools appear under several keys in self.tools,
        so the original returned the same tool multiple times;
        de-duplicate by identity so each tool is listed once.
        """
        results = []
        seen = set()
        query_lower = query.lower()
        for tool in self.tools.values():
            if id(tool) in seen:
                continue
            seen.add(id(tool))
            if (query_lower in tool.name.lower() or
                    query_lower in tool.description.lower()):
                results.append(tool)
        return results
# Usage: register a concrete tool once at startup; agents can then
# resolve it by primary name or by either alias.
registry = ToolRegistry()
registry.register(
    DatabaseTool(),
    category="data",
    aliases=["db", "sql"]
)
Deployment and Distribution
1. Tool Packaging
Copy
# setup.py for tool package
# Packaging manifest: exposes each tool through the "praisonai.tools"
# entry-point group so the host framework can discover it at runtime.
from setuptools import setup, find_packages
setup(
# Distribution name and version published to the package index.
name="praisonai-tools-advanced",
version="1.0.0",
packages=find_packages(),
# Runtime dependencies: agent framework, async HTTP, Postgres driver.
install_requires=[
"praisonaiagents>=1.0.0",
"aiohttp>=3.8.0",
"asyncpg>=0.27.0",
],
# Each entry maps a tool name to "module.path:ClassName".
entry_points={
"praisonai.tools": [
"database=my_tools.database:DatabaseTool",
"api=my_tools.api:APITool",
"websocket=my_tools.websocket:WebSocketTool",
]
}
)
2. Tool Configuration
Copy
# tools_config.yaml
# Per-tool configuration. ${VAR} placeholders are expected to be
# resolved from the environment at load time.
# NOTE: nesting restored — the flattened original was not valid YAML.
tools:
  database:
    class: DatabaseTool
    config:
      connection_string: ${DATABASE_URL}
      pool_size: 20
      cache_ttl: 3600
  api:
    class: RESTAPITool
    config:
      base_url: https://api.example.com
      rate_limit: 100
      timeout: 30
  websocket:
    class: WebSocketTool
    config:
      url: wss://stream.example.com
      reconnect_attempts: 3
Next Steps
- Explore Tool Examples
- Review Security Best Practices
- Learn about Tool Testing
- Check Performance Optimization

