Token Usage Optimization
Token usage directly impacts the cost and performance of AI-powered multi-agent systems. This guide provides strategies for optimizing token consumption while maintaining system effectiveness.
Understanding Token Usage
Token Consumption Areas
- System Prompts: Initial agent instructions
- Conversation History: Accumulated context
- Tool Calls: Function descriptions and responses
- Agent Communication: Inter-agent messages
- Knowledge Retrieval: Retrieved documents and context
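To see where these areas add up, it helps to measure them separately. A minimal sketch using tiktoken (the sample strings are placeholders, not real prompts):
import tiktoken

enc = tiktoken.encoding_for_model("gpt-4")
areas = {
    "system_prompt": "You are a research agent. Answer concisely and cite sources.",
    "history": "User: summarize the Q3 report. Assistant: Revenue grew 12%...",
    "tool_descriptions": "web_search(query): search the web and return snippets",
}
for area, text in areas.items():
    print(f"{area}: {len(enc.encode(text))} tokens")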
Optimization Strategies
1. Smart Context Management
Implement intelligent context windowing and summarization:
from typing import List, Dict, Any, Tuple
import tiktoken
from dataclasses import dataclass

@dataclass
class TokenCounter:
    model: str = "gpt-4"

    def __post_init__(self):
        self.encoder = tiktoken.encoding_for_model(self.model)

    def count_tokens(self, text: str) -> int:
        """Count tokens in a text string"""
        return len(self.encoder.encode(text))

    def count_messages(self, messages: List[Dict[str, str]]) -> int:
        """Count tokens in a list of messages"""
        total = 0
        for message in messages:
            total += self.count_tokens(message.get("content", ""))
            total += 4  # Per-message formatting overhead (approximate)
        return total

class OptimizedContextManager:
    def __init__(self, max_tokens: int = 2000, summarization_ratio: float = 0.3):
        self.max_tokens = max_tokens
        self.summarization_ratio = summarization_ratio
        self.token_counter = TokenCounter()
        self.context_window = []

    def add_message(self, message: Dict[str, str]) -> None:
        """Add message to context with automatic optimization"""
        self.context_window.append(message)
        self._optimize_context()

    def _optimize_context(self) -> None:
        """Optimize context to stay within token limits"""
        total_tokens = self.token_counter.count_messages(self.context_window)
        if total_tokens > self.max_tokens:
            # Reduce until we are back under 80% of the budget
            target_reduction = total_tokens - int(self.max_tokens * 0.8)
            self._summarize_old_messages(target_reduction)

    def _summarize_old_messages(self, target_reduction: int) -> None:
        """Summarize older messages to reduce token count"""
        messages_to_summarize = []
        current_reduction = 0
        # Select messages to summarize (keep the last 5 messages)
        for msg in self.context_window[:-5]:
            msg_tokens = self.token_counter.count_tokens(msg["content"])
            messages_to_summarize.append(msg)
            current_reduction += msg_tokens
            if current_reduction >= target_reduction:
                break
        if messages_to_summarize:
            # Create summary (in production, use an LLM for actual summarization)
            summary = self._create_summary(messages_to_summarize)
            # Replace the summarized messages with the summary
            self.context_window = [summary] + self.context_window[len(messages_to_summarize):]

    def _create_summary(self, messages: List[Dict[str, str]]) -> Dict[str, str]:
        """Create a summary of messages"""
        # Simplified summary - in production, use an LLM
        key_points = []
        for msg in messages[-3:]:  # Last 3 messages from the batch
            key_points.append(msg["content"][:100])  # First 100 chars
        summary_content = f"Summary of {len(messages)} messages: " + "; ".join(key_points)
        return {
            "role": "system",
            "content": summary_content
        }

    def get_optimized_context(self) -> List[Dict[str, str]]:
        """Get the optimized context window"""
        return self.context_window
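A minimal usage sketch (the message contents are illustrative):
manager = OptimizedContextManager(max_tokens=500)
for turn in ["Summarize the Q3 report", "Focus on the revenue section", "Draft a follow-up email"]:
    manager.add_message({"role": "user", "content": turn})

for msg in manager.get_optimized_context():
    print(msg["role"], "->", msg["content"][:60])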
2. Prompt Compression
Compress prompts while maintaining effectiveness:
class PromptCompressor:
    def __init__(self):
        self.compression_rules = {
            # Common replacements to reduce tokens; keys are lowercase
            # because compress_prompt lowercases its input first
            "please": "",
            "could you": "",
            "i would like you to": "",
            "can you": "",
            "make sure to": "",
            "be sure to": "",
            "it is important that": "",
            "remember to": "",
        }

    def compress_prompt(self, prompt: str) -> Tuple[str, float]:
        """Compress prompt and return compressed version with compression ratio"""
        original_length = len(prompt)
        # Note: lowercasing is destructive for case-sensitive content (e.g., code)
        compressed = prompt.lower()
        # Apply compression rules
        for verbose, concise in self.compression_rules.items():
            compressed = compressed.replace(verbose, concise)
        # Remove redundant whitespace
        compressed = " ".join(compressed.split())
        # Remove filler words (carefully)
        filler_words = ["very", "really", "actually", "basically", "just"]
        for filler in filler_words:
            compressed = compressed.replace(f" {filler} ", " ")
        compression_ratio = 1 - (len(compressed) / max(original_length, 1))
        return compressed.strip(), compression_ratio

    def compress_instructions(self, instructions: str) -> str:
        """Compress agent instructions"""
        # Convert verbose instructions to concise format
        compressed_lines = []
        for line in instructions.strip().split('\n'):
            # Skip empty lines
            if not line.strip():
                continue
            # Keep bullet points as-is
            if line.strip().startswith('-'):
                compressed_lines.append(line.strip())
            else:
                compressed, _ = self.compress_prompt(line)
                compressed_lines.append(compressed)
        return '\n'.join(compressed_lines)
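For example (since compress_prompt lowercases its input, this style of compression is only safe for natural-language prose):
compressor = PromptCompressor()
verbose = "Could you please make sure to carefully review the attached logs?"
concise, ratio = compressor.compress_prompt(verbose)
print(concise)                      # carefully review the attached logs?
print(f"compression: {ratio:.0%}")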
3. Selective Tool Loading
Load only necessary tools to reduce token overhead:
class SelectiveToolLoader:
    def __init__(self):
        self.tool_registry = {}
        self.tool_descriptions = {}
        self.tool_token_costs = {}

    def register_tool(self, name: str, func: callable, description: str):
        """Register a tool with its description"""
        self.tool_registry[name] = func
        self.tool_descriptions[name] = description
        # Calculate the token cost of the tool description
        counter = TokenCounter()
        self.tool_token_costs[name] = counter.count_tokens(description)

    def get_tools_for_task(self, task_description: str,
                           token_budget: int = 500) -> Dict[str, Any]:
        """Select tools based on task and token budget"""
        # Score tools by relevance (simplified - use embeddings in production)
        tool_scores = {}
        for tool_name, description in self.tool_descriptions.items():
            tool_scores[tool_name] = self._calculate_relevance(task_description, description)
        # Select tools within the token budget, best matches first
        selected_tools = {}
        remaining_budget = token_budget
        for tool_name, score in sorted(tool_scores.items(),
                                       key=lambda x: x[1], reverse=True):
            tool_cost = self.tool_token_costs[tool_name]
            if tool_cost <= remaining_budget:
                selected_tools[tool_name] = {
                    "function": self.tool_registry[tool_name],
                    "description": self.tool_descriptions[tool_name]
                }
                remaining_budget -= tool_cost
        return selected_tools

    def _calculate_relevance(self, task: str, tool_description: str) -> float:
        """Calculate relevance score between task and tool"""
        # Simplified keyword matching - use embeddings in production
        task_words = set(task.lower().split())
        tool_words = set(tool_description.lower().split())
        common_words = task_words.intersection(tool_words)
        return len(common_words) / max(len(task_words), 1)
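A usage sketch with made-up tools; under a tight budget, only the highest-scoring tool should be loaded:
loader = SelectiveToolLoader()
loader.register_tool("web_search", lambda q: [], "search the web for a query and return result snippets")
loader.register_tool("calculator", lambda expr: 0, "evaluate an arithmetic expression and return the result")
loader.register_tool("file_reader", lambda path: "", "read a local file and return its contents")

tools = loader.get_tools_for_task("search the web for recent papers", token_budget=15)
print(list(tools.keys()))  # with this tight budget, expect just the best match, e.g. ['web_search']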
4. Response Caching
Cache responses to avoid redundant API calls:
import hashlib
import json
from datetime import datetime, timedelta
from typing import Optional

class TokenSavingCache:
    def __init__(self, ttl_hours: int = 24):
        self.cache = {}
        self.ttl = timedelta(hours=ttl_hours)
        self.hit_count = 0
        self.miss_count = 0

    def _generate_cache_key(self, prompt: str, context: List[Dict]) -> str:
        """Generate a cache key from prompt and context"""
        cache_data = {
            "prompt": prompt,
            "context": context
        }
        # Create a stable hash of the data
        data_str = json.dumps(cache_data, sort_keys=True)
        return hashlib.sha256(data_str.encode()).hexdigest()

    def get(self, prompt: str, context: List[Dict]) -> Optional[str]:
        """Get cached response if available"""
        cache_key = self._generate_cache_key(prompt, context)
        if cache_key in self.cache:
            entry = self.cache[cache_key]
            # Check if the entry is still valid
            if datetime.now() - entry["timestamp"] < self.ttl:
                self.hit_count += 1
                return entry["response"]
            else:
                # Remove the expired entry
                del self.cache[cache_key]
        self.miss_count += 1
        return None

    def set(self, prompt: str, context: List[Dict], response: str) -> None:
        """Cache a response"""
        cache_key = self._generate_cache_key(prompt, context)
        self.cache[cache_key] = {
            "response": response,
            "timestamp": datetime.now()
        }

    def get_stats(self) -> Dict[str, Any]:
        """Get cache statistics"""
        total_requests = self.hit_count + self.miss_count
        hit_rate = self.hit_count / max(total_requests, 1)
        return {
            "hit_count": self.hit_count,
            "miss_count": self.miss_count,
            "hit_rate": hit_rate,
            "cache_size": len(self.cache),
            "estimated_tokens_saved": self.hit_count * 100  # Rough estimate
        }
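A sketch of wrapping a model call with the cache; call_model here is a hypothetical stand-in for your actual LLM client:
cache = TokenSavingCache(ttl_hours=1)

def cached_completion(prompt, context, call_model):
    # call_model is a placeholder for your real LLM client call
    cached = cache.get(prompt, context)
    if cached is not None:
        return cached
    response = call_model(prompt, context)
    cache.set(prompt, context, response)
    return response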
5. Batching and Deduplication
Batch similar requests and deduplicate content:
from collections import defaultdict
import asyncio
import hashlib

class RequestBatcher:
    def __init__(self, batch_window_ms: int = 100, max_batch_size: int = 10):
        self.batch_window_ms = batch_window_ms
        self.max_batch_size = max_batch_size
        self.pending_requests = defaultdict(list)
        self.processing = False

    async def add_request(self, request_type: str, content: str) -> Any:
        """Add a request to be batched"""
        future = asyncio.get_running_loop().create_future()
        self.pending_requests[request_type].append({
            "content": content,
            "future": future
        })
        # Start processing if not already running; set the flag here so
        # concurrent callers do not schedule duplicate batch tasks
        if not self.processing:
            self.processing = True
            asyncio.create_task(self._process_batches())
        return await future

    async def _process_batches(self):
        """Process pending request batches"""
        # Wait for the batch window so concurrent requests can accumulate
        await asyncio.sleep(self.batch_window_ms / 1000)
        # Snapshot and reset pending requests; later arrivals form the next batch
        batches, self.pending_requests = self.pending_requests, defaultdict(list)
        for request_type, requests in batches.items():
            if not requests:
                continue
            # Process in batches
            for i in range(0, len(requests), self.max_batch_size):
                batch = requests[i:i + self.max_batch_size]
                # Deduplicate identical content within the batch
                unique_contents = {}
                for req in batch:
                    content_hash = hashlib.md5(req["content"].encode()).hexdigest()
                    unique_contents.setdefault(content_hash, []).append(req["future"])
                # Process each unique request once
                for content_hash, futures in unique_contents.items():
                    # Recover the original content for this hash
                    content = next(r["content"] for r in batch
                                   if hashlib.md5(r["content"].encode()).hexdigest() == content_hash)
                    # Process request (simplified)
                    result = await self._process_single_request(request_type, content)
                    # Share the result with every caller that sent the same content
                    for future in futures:
                        future.set_result(result)
        self.processing = False
        # If requests arrived while we were processing, schedule another pass
        if any(self.pending_requests.values()):
            self.processing = True
            asyncio.create_task(self._process_batches())

    async def _process_single_request(self, request_type: str, content: str) -> Any:
        """Process a single request (implement actual logic)"""
        # Simulate an API call
        await asyncio.sleep(0.1)
        return f"Processed: {content[:50]}..."
Advanced Token Optimization
1. Dynamic Model Selection
Choose appropriate models based on task complexity:
class DynamicModelSelector:
    def __init__(self):
        self.models = {
            "simple": {"name": "gpt-3.5-turbo", "cost_per_1k": 0.002, "quality": 0.7},
            "standard": {"name": "gpt-4", "cost_per_1k": 0.03, "quality": 0.9},
            "advanced": {"name": "gpt-4-turbo", "cost_per_1k": 0.01, "quality": 0.95}
        }

    def select_model(self, task_complexity: float,
                     quality_requirement: float,
                     budget_constraint: float) -> str:
        """Select optimal model based on requirements"""
        best_model = None
        best_score = -1
        for model_type, model_info in self.models.items():
            # Skip if the quality requirement is not met
            if model_info["quality"] < quality_requirement:
                continue
            # Calculate score (balance quality and cost)
            quality_score = model_info["quality"]
            cost_score = 1 / (model_info["cost_per_1k"] + 0.001)  # Inverse cost
            # Weighted score; note the unnormalized inverse-cost term dominates,
            # so this effectively picks the cheapest model that meets the quality bar
            score = (quality_score * 0.6 + cost_score * 0.4)
            # Apply budget constraint
            if model_info["cost_per_1k"] <= budget_constraint:
                score *= 1.2  # Bonus for being within budget
            if score > best_score:
                best_score = score
                best_model = model_info["name"]
        return best_model or "gpt-3.5-turbo"  # Default fallback
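For example, with a 0.9 quality requirement only the two larger models qualify, and the cheaper gpt-4-turbo wins on the cost term:
selector = DynamicModelSelector()
print(selector.select_model(task_complexity=0.8,
                            quality_requirement=0.9,
                            budget_constraint=0.02))  # gpt-4-turbo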
2. Token-Aware Chunking
Split content intelligently to minimize token usage:
class TokenAwareChunker:
    def __init__(self, max_chunk_tokens: int = 1000):
        self.max_chunk_tokens = max_chunk_tokens
        self.token_counter = TokenCounter()

    def chunk_text(self, text: str, overlap_tokens: int = 100) -> List[str]:
        """Chunk text with token awareness"""
        sentences = self._split_into_sentences(text)
        chunks = []
        current_chunk = []
        current_tokens = 0
        for sentence in sentences:
            sentence_tokens = self.token_counter.count_tokens(sentence)
            # Check whether adding this sentence exceeds the limit
            if current_tokens + sentence_tokens > self.max_chunk_tokens:
                if current_chunk:
                    chunks.append(" ".join(current_chunk))
                # Start a new chunk, carrying over some overlap
                if chunks and overlap_tokens > 0:
                    # Reuse the last few sentences of the previous chunk
                    overlap_sentences = self._get_overlap_sentences(
                        current_chunk, overlap_tokens
                    )
                    current_chunk = overlap_sentences
                    current_tokens = self.token_counter.count_tokens(
                        " ".join(overlap_sentences)
                    )
                else:
                    current_chunk = []
                    current_tokens = 0
            current_chunk.append(sentence)
            current_tokens += sentence_tokens
        # Add the final chunk
        if current_chunk:
            chunks.append(" ".join(current_chunk))
        return chunks

    def _split_into_sentences(self, text: str) -> List[str]:
        """Split text into sentences"""
        # Simple sentence splitting - use NLTK or spaCy in production
        sentences = []
        current = ""
        for char in text:
            current += char
            if char in '.!?' and len(current) > 1:
                sentences.append(current.strip())
                current = ""
        if current.strip():
            sentences.append(current.strip())
        return sentences

    def _get_overlap_sentences(self, sentences: List[str],
                               target_tokens: int) -> List[str]:
        """Collect sentences for the overlap from the end of a chunk"""
        overlap = []
        current_tokens = 0
        for sentence in reversed(sentences):
            sentence_tokens = self.token_counter.count_tokens(sentence)
            if current_tokens + sentence_tokens <= target_tokens:
                overlap.insert(0, sentence)
                current_tokens += sentence_tokens
            else:
                break
        return overlap
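A usage sketch with synthetic text (the sentence lengths are arbitrary):
chunker = TokenAwareChunker(max_chunk_tokens=120)
long_text = " ".join(f"Sentence number {i} describes one finding." for i in range(60))
chunks = chunker.chunk_text(long_text, overlap_tokens=20)
print(len(chunks), "chunks")
print(chunks[0][:80])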
3. Semantic Compression
Use semantic similarity to remove redundant information:
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

class SemanticCompressor:
    def __init__(self, similarity_threshold: float = 0.85):
        self.similarity_threshold = similarity_threshold

    def compress_messages(self, messages: List[Dict[str, str]],
                          embeddings_func: callable) -> List[Dict[str, str]]:
        """Remove semantically similar messages"""
        if len(messages) <= 1:
            return messages
        # Get embeddings for all messages
        contents = [msg["content"] for msg in messages]
        embeddings = embeddings_func(contents)
        # Calculate the pairwise similarity matrix
        similarity_matrix = cosine_similarity(embeddings)
        # Track which messages to keep
        keep_indices = {0}  # Always keep the first message
        for i in range(1, len(messages)):
            # Check similarity against all kept messages
            is_similar = False
            for j in keep_indices:
                if similarity_matrix[i][j] > self.similarity_threshold:
                    is_similar = True
                    break
            if not is_similar:
                keep_indices.add(i)
        # Return the filtered messages in their original order
        return [messages[i] for i in sorted(keep_indices)]
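compress_messages leaves the embedding model to the caller. A toy bag-of-words embedding is enough to demonstrate the behavior (use a real embedding model in practice):
def bag_of_words_embeddings(texts):
    # Toy embedding: shared-vocabulary word counts
    vocab = sorted({w for t in texts for w in t.lower().split()})
    index = {w: i for i, w in enumerate(vocab)}
    vectors = np.zeros((len(texts), len(vocab)))
    for row, text in enumerate(texts):
        for word in text.lower().split():
            vectors[row, index[word]] += 1
    return vectors

compressor = SemanticCompressor(similarity_threshold=0.9)
messages = [
    {"role": "user", "content": "The deploy failed on step three."},
    {"role": "user", "content": "The deploy failed on step three."},  # near-duplicate
    {"role": "user", "content": "Error logs show a timeout in the database."},
]
print(len(compressor.compress_messages(messages, bag_of_words_embeddings)))  # 2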
Monitoring and Analytics
Token Usage Dashboard
class TokenUsageAnalytics:
    def __init__(self):
        self.usage_data = defaultdict(lambda: {
            "input_tokens": 0,
            "output_tokens": 0,
            "total_cost": 0.0,
            "request_count": 0
        })

    def record_usage(self, agent_id: str, input_tokens: int,
                     output_tokens: int, model: str):
        """Record token usage for an agent"""
        # Model pricing per 1K tokens (simplified)
        pricing = {
            "gpt-3.5-turbo": {"input": 0.001, "output": 0.002},
            "gpt-4": {"input": 0.03, "output": 0.06}
        }
        model_pricing = pricing.get(model, pricing["gpt-3.5-turbo"])
        cost = (input_tokens * model_pricing["input"] +
                output_tokens * model_pricing["output"]) / 1000
        self.usage_data[agent_id]["input_tokens"] += input_tokens
        self.usage_data[agent_id]["output_tokens"] += output_tokens
        self.usage_data[agent_id]["total_cost"] += cost
        self.usage_data[agent_id]["request_count"] += 1

    def get_report(self) -> Dict[str, Any]:
        """Generate usage report"""
        total_input = sum(data["input_tokens"] for data in self.usage_data.values())
        total_output = sum(data["output_tokens"] for data in self.usage_data.values())
        total_cost = sum(data["total_cost"] for data in self.usage_data.values())
        total_requests = sum(data["request_count"] for data in self.usage_data.values())
        return {
            "summary": {
                "total_input_tokens": total_input,
                "total_output_tokens": total_output,
                "total_tokens": total_input + total_output,
                "total_cost": total_cost,
                "average_cost_per_request": total_cost / max(total_requests, 1)
            },
            "by_agent": dict(self.usage_data),
            "optimization_suggestions": self._generate_suggestions()
        }

    def _generate_suggestions(self) -> List[str]:
        """Generate optimization suggestions based on usage"""
        suggestions = []
        for agent_id, data in self.usage_data.items():
            avg_input = data["input_tokens"] / max(data["request_count"], 1)
            if avg_input > 2000:
                suggestions.append(
                    f"Agent {agent_id} has high average input tokens ({avg_input:.0f}). "
                    "Consider context optimization."
                )
            if data["total_cost"] > 10:
                suggestions.append(
                    f"Agent {agent_id} has high costs (${data['total_cost']:.2f}). "
                    "Consider using a lighter model for some tasks."
                )
        return suggestions
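A usage sketch with invented token counts; the researcher agent's large inputs should trigger a suggestion:
analytics = TokenUsageAnalytics()
analytics.record_usage("researcher", input_tokens=3200, output_tokens=450, model="gpt-4")
analytics.record_usage("writer", input_tokens=800, output_tokens=600, model="gpt-3.5-turbo")

report = analytics.get_report()
print(report["summary"]["total_cost"])
print(report["optimization_suggestions"])  # flags the researcher's large inputs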
Best Practices
- Set Token Budgets: Establish token budgets per agent and task.
class TokenBudgetManager:
    def __init__(self, daily_budget: int = 1_000_000):
        self.daily_budget = daily_budget
        self.used_today = 0
        self.last_reset = datetime.now()

    def can_proceed(self, estimated_tokens: int) -> bool:
        self._check_reset()
        return self.used_today + estimated_tokens <= self.daily_budget

    def consume(self, tokens: int):
        self.used_today += tokens

    def _check_reset(self):
        # Reset the counter at the start of each new day
        if datetime.now().date() > self.last_reset.date():
            self.used_today = 0
            self.last_reset = datetime.now()
- Implement Gradual Degradation: Reduce quality gracefully when approaching limits.
def get_context_size_for_budget(remaining_budget: int) -> int:
    if remaining_budget > 5000:
        return 2000  # Full context
    elif remaining_budget > 2000:
        return 1000  # Reduced context
    else:
        return 500   # Minimal context
- Regular Optimization Reviews: Analyze usage patterns to find inefficient task types.
def analyze_token_efficiency(usage_log: List[Dict]) -> Dict[str, float]:
    efficiency_metrics = {}
    for entry in usage_log:
        task_type = entry["task_type"]
        tokens_used = entry["tokens"]
        success = entry["success"]
        if task_type not in efficiency_metrics:
            efficiency_metrics[task_type] = []
        # Failed requests count as infinite cost so they drag the average up
        efficiency_metrics[task_type].append(tokens_used if success else float('inf'))
    return {
        task: np.mean(tokens)
        for task, tokens in efficiency_metrics.items()
    }
Testing Token Optimization
import pytest

def test_context_optimization():
    manager = OptimizedContextManager(max_tokens=200)
    # Add messages until optimization triggers
    for i in range(20):
        manager.add_message({
            "role": "user",
            "content": f"This is message {i} with some content"
        })
    context = manager.get_optimized_context()
    # A single summarization pass targets 80% of the budget but is not exact,
    # so verify the context stays near the limit rather than strictly under it
    token_count = manager.token_counter.count_messages(context)
    assert token_count <= manager.max_tokens * 1.1

def test_prompt_compression():
    compressor = PromptCompressor()
    verbose = "Could you please make sure to carefully analyze this data?"
    compressed, ratio = compressor.compress_prompt(verbose)
    assert len(compressed) < len(verbose)
    assert ratio > 0.2  # At least 20% compression
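The response cache is deterministic, so it is straightforward to test as well. A sketch:
def test_response_cache():
    cache = TokenSavingCache(ttl_hours=1)
    context = [{"role": "user", "content": "hello"}]
    assert cache.get("prompt", context) is None  # cold cache misses
    cache.set("prompt", context, "cached answer")
    assert cache.get("prompt", context) == "cached answer"
    stats = cache.get_stats()
    assert stats["hit_count"] == 1 and stats["miss_count"] == 1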

