Token Usage Optimization

Token usage directly impacts the cost and performance of AI-powered multi-agent systems. This guide provides strategies for optimizing token consumption while maintaining system effectiveness.

Understanding Token Usage

Token Consumption Areas

  1. System Prompts: Initial agent instructions
  2. Conversation History: Accumulated context
  3. Tool Calls: Function descriptions and responses
  4. Agent Communication: Inter-agent messages
  5. Knowledge Retrieval: Retrieved documents and context
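
Each of these areas can be measured independently. A minimal sketch using tiktoken (the strings below are placeholders standing in for real agent inputs):
import tiktoken

encoder = tiktoken.encoding_for_model("gpt-4")

# Placeholder strings standing in for real agent inputs
areas = {
    "system_prompt": "You are a research agent. Answer concisely.",
    "conversation_history": "user: summarize the report\nassistant: ...",
    "tool_descriptions": '{"name": "search", "description": "Search the web"}',
}

for area, text in areas.items():
    print(f"{area}: {len(encoder.encode(text))} tokens")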

Optimization Strategies

1. Smart Context Management

Implement intelligent context windowing and summarization:
from typing import List, Dict, Any, Tuple
import tiktoken
from dataclasses import dataclass

@dataclass
class TokenCounter:
    model: str = "gpt-4"
    
    def __post_init__(self):
        self.encoder = tiktoken.encoding_for_model(self.model)
    
    def count_tokens(self, text: str) -> int:
        """Count tokens in a text string"""
        return len(self.encoder.encode(text))
    
    def count_messages(self, messages: List[Dict[str, str]]) -> int:
        """Count tokens in a list of messages"""
        total = 0
        for message in messages:
            total += self.count_tokens(message.get("content", ""))
            total += 4  # Approximate per-message formatting overhead (role markers, separators)
        return total

class OptimizedContextManager:
    def __init__(self, max_tokens: int = 2000, summarization_ratio: float = 0.3):
        self.max_tokens = max_tokens
        self.summarization_ratio = summarization_ratio
        self.token_counter = TokenCounter()
        self.context_window = []
        
    def add_message(self, message: Dict[str, str]) -> None:
        """Add message to context with automatic optimization"""
        self.context_window.append(message)
        self._optimize_context()
    
    def _optimize_context(self) -> None:
        """Optimize context to stay within token limits"""
        total_tokens = self.token_counter.count_messages(self.context_window)
        
        while total_tokens > self.max_tokens:
            # Target 80% of the budget to leave headroom for new messages
            target_reduction = int(total_tokens - self.max_tokens * 0.8)
            self._summarize_old_messages(target_reduction)
            
            new_total = self.token_counter.count_messages(self.context_window)
            if new_total >= total_tokens:
                break  # Nothing left to summarize; avoid looping forever
            total_tokens = new_total
    
    def _summarize_old_messages(self, target_reduction: int) -> None:
        """Summarize older messages to reduce token count"""
        messages_to_summarize = []
        current_reduction = 0
        
        # Select messages to summarize (keep recent ones)
        for msg in self.context_window[:-5]:  # Always keep the 5 most recent messages
            msg_tokens = self.token_counter.count_tokens(msg["content"])
            messages_to_summarize.append(msg)
            current_reduction += msg_tokens
            
            if current_reduction >= target_reduction:
                break
        
        if messages_to_summarize:
            # Create summary (in production, use LLM for actual summarization)
            summary = self._create_summary(messages_to_summarize)
            
            # Replace messages with summary
            self.context_window = [summary] + self.context_window[len(messages_to_summarize):]
    
    def _create_summary(self, messages: List[Dict[str, str]]) -> Dict[str, str]:
        """Create a summary of messages"""
        # Simplified summary - in production, use LLM
        key_points = []
        for msg in messages[-3:]:  # Last 3 messages from batch
            content = msg["content"][:100]  # First 100 chars
            key_points.append(content)
        
        summary_content = f"Summary of {len(messages)} messages: " + "; ".join(key_points)
        
        return {
            "role": "system",
            "content": summary_content
        }
    
    def get_optimized_context(self) -> List[Dict[str, str]]:
        """Get the optimized context window"""
        return self.context_window
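
A short usage sketch (recall that summaries here come from the simplified _create_summary placeholder, not an LLM):
manager = OptimizedContextManager(max_tokens=500)

for turn in ["What is the deadline?", "Friday.", "And the budget?"]:
    manager.add_message({"role": "user", "content": turn})

for msg in manager.get_optimized_context():
    print(msg["role"], "-", msg["content"][:60])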

2. Prompt Compression

Compress prompts while maintaining effectiveness:
import re

class PromptCompressor:
    def __init__(self):
        self.compression_rules = {
            # Common replacements to reduce tokens
            "please": "",
            "could you": "",
            "I would like you to": "",
            "can you": "",
            "make sure to": "",
            "be sure to": "",
            "it is important that": "",
            "remember to": "",
        }
        
    def compress_prompt(self, prompt: str) -> Tuple[str, float]:
        """Compress prompt and return compressed version with compression ratio"""
        original_length = len(prompt)
        compressed = prompt
        
        # Apply compression rules case-insensitively, on whole words only,
        # preserving the original casing of the surviving text
        for verbose, concise in self.compression_rules.items():
            compressed = re.sub(rf"\b{re.escape(verbose)}\b", concise,
                                compressed, flags=re.IGNORECASE)
        
        # Remove redundant whitespace
        compressed = " ".join(compressed.split())
        
        # Remove filler words (carefully)
        filler_words = ["very", "really", "actually", "basically", "just"]
        for filler in filler_words:
            compressed = re.sub(rf"\b{filler}\b ?", "", compressed,
                                flags=re.IGNORECASE)
        
        compression_ratio = 1 - (len(compressed) / max(original_length, 1))
        
        return compressed.strip(), compression_ratio
    
    def compress_instructions(self, instructions: str) -> str:
        """Compress agent instructions"""
        # Convert verbose instructions to concise format
        lines = instructions.strip().split('\n')
        compressed_lines = []
        
        for line in lines:
            # Skip empty lines
            if not line.strip():
                continue
            
            # Keep bullet points as-is; they are already terse
            if line.strip().startswith('-'):
                compressed_lines.append(line.strip())
            else:
                compressed, _ = self.compress_prompt(line)
                compressed_lines.append(compressed)
        
        return '\n'.join(compressed_lines)
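
Example usage, assuming the class above:
compressor = PromptCompressor()

verbose = "Could you please make sure to carefully analyze this data?"
concise, ratio = compressor.compress_prompt(verbose)

print(concise)                # e.g. "carefully analyze this data?"
print(f"saved {ratio:.0%}")   # roughly half the characters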

3. Selective Tool Loading

Load only necessary tools to reduce token overhead:
from typing import Callable

class SelectiveToolLoader:
    def __init__(self):
        self.tool_registry = {}
        self.tool_descriptions = {}
        self.tool_token_costs = {}
        
    def register_tool(self, name: str, func: Callable, description: str):
        """Register a tool with its description"""
        self.tool_registry[name] = func
        self.tool_descriptions[name] = description
        
        # Calculate token cost of tool description
        counter = TokenCounter()
        self.tool_token_costs[name] = counter.count_tokens(description)
    
    def get_tools_for_task(self, task_description: str, 
                          token_budget: int = 500) -> Dict[str, Any]:
        """Select tools based on task and token budget"""
        # Score tools by relevance (simplified - use embeddings in production)
        tool_scores = {}
        
        for tool_name, description in self.tool_descriptions.items():
            score = self._calculate_relevance(task_description, description)
            tool_scores[tool_name] = score
        
        # Select tools within token budget
        selected_tools = {}
        remaining_budget = token_budget
        
        for tool_name, score in sorted(tool_scores.items(), 
                                     key=lambda x: x[1], reverse=True):
            tool_cost = self.tool_token_costs[tool_name]
            
            if tool_cost <= remaining_budget:
                selected_tools[tool_name] = {
                    "function": self.tool_registry[tool_name],
                    "description": self.tool_descriptions[tool_name]
                }
                remaining_budget -= tool_cost
        
        return selected_tools
    
    def _calculate_relevance(self, task: str, tool_description: str) -> float:
        """Calculate relevance score between task and tool"""
        # Simplified keyword matching - use embeddings in production
        task_words = set(task.lower().split())
        tool_words = set(tool_description.lower().split())
        
        common_words = task_words.intersection(tool_words)
        return len(common_words) / max(len(task_words), 1)
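
A usage sketch with two hypothetical tools (the lambdas stand in for real implementations):
loader = SelectiveToolLoader()
loader.register_tool("web_search", lambda q: ..., "Search the web for a query")
loader.register_tool("calculator", lambda e: ..., "Evaluate arithmetic expressions")

tools = loader.get_tools_for_task("search for recent AI papers", token_budget=50)
print(list(tools))  # web_search scores highest; both fit this budget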

4. Response Caching

Cache responses to avoid redundant API calls:
import hashlib
import json
from datetime import datetime, timedelta
from typing import Optional

class TokenSavingCache:
    def __init__(self, ttl_hours: int = 24):
        self.cache = {}
        self.ttl = timedelta(hours=ttl_hours)
        self.hit_count = 0
        self.miss_count = 0
        
    def _generate_cache_key(self, prompt: str, context: List[Dict]) -> str:
        """Generate a cache key from prompt and context"""
        cache_data = {
            "prompt": prompt,
            "context": context
        }
        
        # Create hash of the data
        data_str = json.dumps(cache_data, sort_keys=True)
        return hashlib.sha256(data_str.encode()).hexdigest()
    
    def get(self, prompt: str, context: List[Dict]) -> Optional[str]:
        """Get cached response if available"""
        cache_key = self._generate_cache_key(prompt, context)
        
        if cache_key in self.cache:
            entry = self.cache[cache_key]
            
            # Check if entry is still valid
            if datetime.now() - entry["timestamp"] < self.ttl:
                self.hit_count += 1
                return entry["response"]
            else:
                # Remove expired entry
                del self.cache[cache_key]
        
        self.miss_count += 1
        return None
    
    def set(self, prompt: str, context: List[Dict], response: str) -> None:
        """Cache a response"""
        cache_key = self._generate_cache_key(prompt, context)
        
        self.cache[cache_key] = {
            "response": response,
            "timestamp": datetime.now()
        }
    
    def get_stats(self) -> Dict[str, Any]:
        """Get cache statistics"""
        total_requests = self.hit_count + self.miss_count
        hit_rate = self.hit_count / max(total_requests, 1)
        
        return {
            "hit_count": self.hit_count,
            "miss_count": self.miss_count,
            "hit_rate": hit_rate,
            "cache_size": len(self.cache),
            "estimated_tokens_saved": self.hit_count * 100  # Rough estimate
        }
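
Typical usage wraps the LLM call; call_llm below is a hypothetical stand-in for whatever client function actually issues the request:
cache = TokenSavingCache(ttl_hours=1)

context = [{"role": "system", "content": "You are a helpful agent."}]
prompt = "Summarize the quarterly report."

response = cache.get(prompt, context)
if response is None:
    response = call_llm(prompt, context)  # hypothetical LLM client call
    cache.set(prompt, context, response)

print(cache.get_stats())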

5. Batching and Deduplication

Batch similar requests and deduplicate content:
from collections import defaultdict
import asyncio

class RequestBatcher:
    def __init__(self, batch_window_ms: int = 100, max_batch_size: int = 10):
        self.batch_window_ms = batch_window_ms
        self.max_batch_size = max_batch_size
        self.pending_requests = defaultdict(list)
        self.processing = False
        
    async def add_request(self, request_type: str, content: str) -> Any:
        """Add a request to be batched"""
        future = asyncio.get_running_loop().create_future()
        
        self.pending_requests[request_type].append({
            "content": content,
            "future": future
        })
        
        # Start a batch processor if one is not already scheduled. Setting
        # the flag here, before the task runs, stops concurrent callers
        # from spawning duplicate processors
        if not self.processing:
            self.processing = True
            asyncio.create_task(self._process_batches())
        
        return await future
    
    async def _process_batches(self):
        """Process pending request batches"""
        # Wait for the batch window so concurrent requests can accumulate
        await asyncio.sleep(self.batch_window_ms / 1000)
        
        # Snapshot and reset pending requests so requests that arrive
        # while this batch is in flight go into the next batch
        batches = self.pending_requests
        self.pending_requests = defaultdict(list)
        self.processing = False
        
        for request_type, requests in batches.items():
            if not requests:
                continue
            
            # Process in batches
            for i in range(0, len(requests), self.max_batch_size):
                batch = requests[i:i + self.max_batch_size]
                
                # Deduplicate identical payloads (md5 is fine here: it is
                # only a grouping key, not a security boundary)
                unique_contents = {}
                for req in batch:
                    content_hash = hashlib.md5(req["content"].encode()).hexdigest()
                    if content_hash not in unique_contents:
                        unique_contents[content_hash] = []
                    unique_contents[content_hash].append(req["future"])
                
                # Process unique requests
                for content_hash, futures in unique_contents.items():
                    # Get original content
                    content = next(r["content"] for r in batch 
                                 if hashlib.md5(r["content"].encode()).hexdigest() == content_hash)
                    
                    # Process request (simplified)
                    result = await self._process_single_request(request_type, content)
                    
                    # Resolve every future that shares this content
                    for future in futures:
                        future.set_result(result)
    
    async def _process_single_request(self, request_type: str, content: str) -> Any:
        """Process a single request (implement actual logic)"""
        # Simulate API call
        await asyncio.sleep(0.1)
        return f"Processed: {content[:50]}..."

Advanced Token Optimization

1. Dynamic Model Selection

Choose appropriate models based on task complexity:
class DynamicModelSelector:
    def __init__(self):
        # Illustrative model names and per-1K-token prices; substitute
        # current models and rates for your provider
        self.models = {
            "simple": {"name": "gpt-3.5-turbo", "cost_per_1k": 0.002, "quality": 0.7},
            "standard": {"name": "gpt-4", "cost_per_1k": 0.03, "quality": 0.9},
            "advanced": {"name": "gpt-4-turbo", "cost_per_1k": 0.01, "quality": 0.95}
        }
        
    def select_model(self, task_complexity: float, 
                    quality_requirement: float,
                    budget_constraint: float) -> str:
        """Select optimal model based on requirements"""
        best_model = None
        best_score = -1
        
        for model_type, model_info in self.models.items():
            # Skip if quality requirement not met
            if model_info["quality"] < quality_requirement:
                continue
            
            # Score on a common 0-1 scale so raw cost cannot swamp quality
            quality_score = model_info["quality"]
            cheapest = min(m["cost_per_1k"] for m in self.models.values())
            cost_score = cheapest / model_info["cost_per_1k"]  # 1.0 = cheapest
            
            # Weighted score favoring quality
            score = quality_score * 0.6 + cost_score * 0.4
            
            # Apply budget constraint
            if model_info["cost_per_1k"] <= budget_constraint:
                score *= 1.2  # Bonus for being within budget
            
            if score > best_score:
                best_score = score
                best_model = model_info["name"]
        
        return best_model or "gpt-3.5-turbo"  # Default fallback
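
For example, with the illustrative models above, a simple task with a modest quality requirement resolves to the cheapest qualifying model:
selector = DynamicModelSelector()

model = selector.select_model(
    task_complexity=0.3,
    quality_requirement=0.6,
    budget_constraint=0.01,
)
print(model)  # "gpt-3.5-turbo": meets the quality bar at the lowest cost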

2. Token-Aware Chunking

Split content intelligently to minimize token usage:
class TokenAwareChunker:
    def __init__(self, max_chunk_tokens: int = 1000):
        self.max_chunk_tokens = max_chunk_tokens
        self.token_counter = TokenCounter()
        
    def chunk_text(self, text: str, overlap_tokens: int = 100) -> List[str]:
        """Chunk text with token awareness"""
        sentences = self._split_into_sentences(text)
        chunks = []
        current_chunk = []
        current_tokens = 0
        
        for sentence in sentences:
            sentence_tokens = self.token_counter.count_tokens(sentence)
            
            # Flush before exceeding the limit. Note: a single sentence
            # longer than max_chunk_tokens still becomes its own oversized
            # chunk; split such sentences upstream if that matters
            if current_tokens + sentence_tokens > self.max_chunk_tokens:
                if current_chunk:
                    chunks.append(" ".join(current_chunk))
                
                # Start new chunk with overlap
                if chunks and overlap_tokens > 0:
                    # Add last few sentences from previous chunk
                    overlap_sentences = self._get_overlap_sentences(
                        current_chunk, overlap_tokens
                    )
                    current_chunk = overlap_sentences
                    current_tokens = self.token_counter.count_tokens(
                        " ".join(overlap_sentences)
                    )
                else:
                    current_chunk = []
                    current_tokens = 0
            
            current_chunk.append(sentence)
            current_tokens += sentence_tokens
        
        # Add final chunk
        if current_chunk:
            chunks.append(" ".join(current_chunk))
        
        return chunks
    
    def _split_into_sentences(self, text: str) -> List[str]:
        """Split text into sentences"""
        # Simple sentence splitting - use NLTK or spaCy in production
        sentences = []
        current = ""
        
        for char in text:
            current += char
            if char in '.!?' and len(current) > 1:
                sentences.append(current.strip())
                current = ""
        
        if current:
            sentences.append(current.strip())
        
        return sentences
    
    def _get_overlap_sentences(self, sentences: List[str], 
                              target_tokens: int) -> List[str]:
        """Get sentences for overlap from end of chunk"""
        overlap = []
        current_tokens = 0
        
        for sentence in reversed(sentences):
            sentence_tokens = self.token_counter.count_tokens(sentence)
            if current_tokens + sentence_tokens <= target_tokens:
                overlap.insert(0, sentence)
                current_tokens += sentence_tokens
            else:
                break
        
        return overlap
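
Example usage on repetitive text:
chunker = TokenAwareChunker(max_chunk_tokens=50)

text = "Token budgets shape retrieval. Overlap preserves context. " * 20
for i, chunk in enumerate(chunker.chunk_text(text, overlap_tokens=10)):
    print(f"chunk {i}: {chunker.token_counter.count_tokens(chunk)} tokens")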

3. Semantic Compression

Use semantic similarity to remove redundant information:
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
from typing import Callable

class SemanticCompressor:
    def __init__(self, similarity_threshold: float = 0.85):
        self.similarity_threshold = similarity_threshold
        
    def compress_messages(self, messages: List[Dict[str, str]],
                          embeddings_func: Callable) -> List[Dict[str, str]]:
        """Remove semantically similar messages"""
        if len(messages) <= 1:
            return messages
        
        # Get embeddings for all messages
        contents = [msg["content"] for msg in messages]
        embeddings = embeddings_func(contents)
        
        # Calculate similarity matrix
        similarity_matrix = cosine_similarity(embeddings)
        
        # Keep track of messages to keep
        keep_indices = {0}  # Always keep the first message
        
        for i in range(1, len(messages)):
            # Check similarity with all kept messages
            is_similar = False
            
            for j in keep_indices:
                if similarity_matrix[i][j] > self.similarity_threshold:
                    is_similar = True
                    break
            
            if not is_similar:
                keep_indices.add(i)
        
        # Return filtered messages
        return [messages[i] for i in sorted(keep_indices)]
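
Any embedding function that returns a 2-D array works; this sketch assumes the sentence-transformers package is installed:
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("all-MiniLM-L6-v2")
compressor = SemanticCompressor(similarity_threshold=0.85)

messages = [
    {"role": "user", "content": "The deploy failed with a timeout."},
    {"role": "user", "content": "Deployment failed because of a timeout."},
    {"role": "user", "content": "Rollback completed successfully."},
]
kept = compressor.compress_messages(messages, model.encode)
print(len(kept))  # likely 2: the near-duplicate report is dropped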

Monitoring and Analytics

Token Usage Dashboard

Track per-agent token consumption and surface optimization suggestions:
class TokenUsageAnalytics:
    def __init__(self):
        self.usage_data = defaultdict(lambda: {
            "input_tokens": 0,
            "output_tokens": 0,
            "total_cost": 0.0,
            "request_count": 0
        })
        
    def record_usage(self, agent_id: str, input_tokens: int, 
                    output_tokens: int, model: str):
        """Record token usage for an agent"""
        # Model pricing (simplified)
        pricing = {
            "gpt-3.5-turbo": {"input": 0.001, "output": 0.002},
            "gpt-4": {"input": 0.03, "output": 0.06}
        }
        
        model_pricing = pricing.get(model, pricing["gpt-3.5-turbo"])
        
        cost = (input_tokens * model_pricing["input"] + 
                output_tokens * model_pricing["output"]) / 1000
        
        self.usage_data[agent_id]["input_tokens"] += input_tokens
        self.usage_data[agent_id]["output_tokens"] += output_tokens
        self.usage_data[agent_id]["total_cost"] += cost
        self.usage_data[agent_id]["request_count"] += 1
    
    def get_report(self) -> Dict[str, Any]:
        """Generate usage report"""
        total_input = sum(data["input_tokens"] for data in self.usage_data.values())
        total_output = sum(data["output_tokens"] for data in self.usage_data.values())
        total_cost = sum(data["total_cost"] for data in self.usage_data.values())
        
        return {
            "summary": {
                "total_input_tokens": total_input,
                "total_output_tokens": total_output,
                "total_tokens": total_input + total_output,
                "total_cost": total_cost,
                "average_cost_per_request": total_cost / max(sum(
                    data["request_count"] for data in self.usage_data.values()
                ), 1)
            },
            "by_agent": dict(self.usage_data),
            "optimization_suggestions": self._generate_suggestions()
        }
    
    def _generate_suggestions(self) -> List[str]:
        """Generate optimization suggestions based on usage"""
        suggestions = []
        
        for agent_id, data in self.usage_data.items():
            avg_input = data["input_tokens"] / max(data["request_count"], 1)
            
            if avg_input > 2000:
                suggestions.append(
                    f"Agent {agent_id} has high average input tokens ({avg_input:.0f}). "
                    "Consider context optimization."
                )
            
            if data["total_cost"] > 10:
                suggestions.append(
                    f"Agent {agent_id} has high costs (${data['total_cost']:.2f}). "
                    "Consider using a lighter model for some tasks."
                )
        
        return suggestions
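
Recording a couple of calls and printing the report:
analytics = TokenUsageAnalytics()

analytics.record_usage("researcher", input_tokens=3200, output_tokens=400, model="gpt-4")
analytics.record_usage("writer", input_tokens=800, output_tokens=600, model="gpt-3.5-turbo")

report = analytics.get_report()
print(f"${report['summary']['total_cost']:.2f}")
for tip in report["optimization_suggestions"]:
    print(tip)  # e.g. a note that "researcher" averages >2000 input tokens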

Best Practices

  1. Set Token Budgets: Establish token budgets per agent and task
    class TokenBudgetManager:
        def __init__(self, daily_budget: int = 1_000_000):
            self.daily_budget = daily_budget
            self.used_today = 0
            self.last_reset = datetime.now()
            
        def can_proceed(self, estimated_tokens: int) -> bool:
            self._check_reset()
            return self.used_today + estimated_tokens <= self.daily_budget
            
        def consume(self, tokens: int):
            self.used_today += tokens
            
        def _check_reset(self):
            if datetime.now().date() > self.last_reset.date():
                self.used_today = 0
                self.last_reset = datetime.now()
    
  2. Implement Gradual Degradation: Reduce quality gracefully when approaching limits
    def get_context_size_for_budget(remaining_budget: int) -> int:
        if remaining_budget > 5000:
            return 2000  # Full context
        elif remaining_budget > 2000:
            return 1000  # Reduced context
        else:
            return 500   # Minimal context
    
  3. Regular Optimization Reviews: Analyze usage patterns
    def analyze_token_efficiency(usage_log: List[Dict]) -> Dict[str, float]:
        efficiency_metrics = {}
        
        for entry in usage_log:
            task_type = entry["task_type"]
            tokens_used = entry["tokens"]
            success = entry["success"]
            
            if task_type not in efficiency_metrics:
                efficiency_metrics[task_type] = []
            
            # A failed task counts as infinite token cost, flagging that
            # task type as inefficient in the averages below
            efficiency_metrics[task_type].append(tokens_used if success else float('inf'))
        
        return {
            task: np.mean(tokens) for task, tokens in efficiency_metrics.items()
        }
    

Testing Token Optimization

Unit tests help verify that optimizations actually keep usage within bounds:
import pytest

def test_context_optimization():
    # The budget must comfortably exceed the 5 retained messages plus the
    # summary message, or the assertion becomes flaky
    manager = OptimizedContextManager(max_tokens=200)
    
    # Add messages until optimization triggers
    for i in range(20):
        manager.add_message({
            "role": "user",
            "content": f"This is message {i} with some content"
        })
    
    context = manager.get_optimized_context()
    
    # Verify context is within limits
    token_count = manager.token_counter.count_messages(context)
    assert token_count <= manager.max_tokens

def test_prompt_compression():
    compressor = PromptCompressor()
    
    verbose = "Could you please make sure to carefully analyze this data?"
    compressed, ratio = compressor.compress_prompt(verbose)
    
    assert len(compressed) < len(verbose)
    assert ratio > 0.2  # At least 20% compression

Conclusion

Effective token optimization requires a multi-faceted approach combining smart context management, caching, batching, and continuous monitoring. By implementing these strategies, you can significantly reduce costs while maintaining system performance.