Multi-User Session Handling
Managing multiple concurrent user sessions is crucial for production multi-agent systems. This guide covers strategies for isolating user contexts, managing resources, and ensuring security.
Core Concepts
Session Isolation Requirements
- Data Isolation: Each user’s data must be completely isolated
- Resource Isolation: Prevent resource exhaustion by one user
- Context Isolation: Maintain separate conversation contexts
- Security Isolation: Prevent cross-session data leakage
- Performance Isolation: One user shouldn’t impact others
Session Management Architecture
1. Session Manager Implementation
import uuid
from datetime import datetime, timedelta
from typing import Dict, Any, Optional, List
import threading
from contextlib import contextmanager
class UserSession:
def __init__(self, session_id: str, user_id: str, metadata: Dict[str, Any] = None):
self.session_id = session_id
self.user_id = user_id
self.created_at = datetime.now()
self.last_activity = datetime.now()
self.metadata = metadata or {}
self.context = []
self.agents = {}
self.resources = {}
self.is_active = True
self._lock = threading.RLock()
def update_activity(self):
"""Update last activity timestamp"""
with self._lock:
self.last_activity = datetime.now()
def add_context(self, message: Dict[str, Any]):
"""Add message to session context"""
with self._lock:
self.context.append({
**message,
"timestamp": datetime.now()
})
def get_context(self, last_n: Optional[int] = None) -> List[Dict[str, Any]]:
"""Get session context"""
with self._lock:
if last_n:
return self.context[-last_n:]
return self.context.copy()
class MultiUserSessionManager:
def __init__(self, max_sessions_per_user: int = 5,
session_timeout_minutes: int = 30):
self.sessions: Dict[str, UserSession] = {}
self.user_sessions: Dict[str, List[str]] = {}
self.max_sessions_per_user = max_sessions_per_user
self.session_timeout = timedelta(minutes=session_timeout_minutes)
self._lock = threading.RLock()
self._cleanup_thread = None
self._start_cleanup_thread()
def create_session(self, user_id: str, metadata: Dict[str, Any] = None) -> str:
"""Create a new session for a user"""
with self._lock:
# Check session limit
if user_id in self.user_sessions:
if len(self.user_sessions[user_id]) >= self.max_sessions_per_user:
# Remove oldest session
oldest_session_id = self._get_oldest_session(user_id)
self.end_session(oldest_session_id)
# Create new session
session_id = str(uuid.uuid4())
session = UserSession(session_id, user_id, metadata)
self.sessions[session_id] = session
if user_id not in self.user_sessions:
self.user_sessions[user_id] = []
self.user_sessions[user_id].append(session_id)
return session_id
@contextmanager
def get_session(self, session_id: str):
"""Get session with automatic activity update"""
session = self._get_session(session_id)
if not session:
raise ValueError(f"Session {session_id} not found")
session.update_activity()
yield session
def _get_session(self, session_id: str) -> Optional[UserSession]:
"""Get session by ID"""
with self._lock:
return self.sessions.get(session_id)
def end_session(self, session_id: str):
"""End a session and cleanup resources"""
with self._lock:
session = self.sessions.get(session_id)
if not session:
return
# Cleanup session resources
self._cleanup_session_resources(session)
# Remove from tracking
del self.sessions[session_id]
if session.user_id in self.user_sessions:
self.user_sessions[session.user_id].remove(session_id)
if not self.user_sessions[session.user_id]:
del self.user_sessions[session.user_id]
def _cleanup_session_resources(self, session: UserSession):
"""Cleanup resources associated with a session"""
# Cleanup agents
for agent_id, agent in session.agents.items():
if hasattr(agent, 'cleanup'):
agent.cleanup()
# Clear context to free memory
session.context.clear()
# Mark as inactive
session.is_active = False
def _get_oldest_session(self, user_id: str) -> Optional[str]:
"""Get the oldest session for a user"""
if user_id not in self.user_sessions:
return None
oldest_session_id = None
oldest_time = datetime.now()
for session_id in self.user_sessions[user_id]:
session = self.sessions.get(session_id)
if session and session.created_at < oldest_time:
oldest_time = session.created_at
oldest_session_id = session_id
return oldest_session_id
def _cleanup_expired_sessions(self):
"""Remove expired sessions"""
with self._lock:
current_time = datetime.now()
expired_sessions = []
for session_id, session in self.sessions.items():
if current_time - session.last_activity > self.session_timeout:
expired_sessions.append(session_id)
for session_id in expired_sessions:
self.end_session(session_id)
def _start_cleanup_thread(self):
"""Start background cleanup thread"""
import time
def cleanup_loop():
while True:
time.sleep(60) # Check every minute
self._cleanup_expired_sessions()
self._cleanup_thread = threading.Thread(target=cleanup_loop, daemon=True)
self._cleanup_thread.start()
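A minimal usage sketch of the session manager above, assuming UserSession and MultiUserSessionManager are defined in (or imported from) the same module; the user ID, metadata, and limits are illustrative.

```python
# Usage sketch: create, use, and end a session (values are illustrative).
manager = MultiUserSessionManager(max_sessions_per_user=3, session_timeout_minutes=15)

session_id = manager.create_session("user_42", metadata={"plan": "pro"})

with manager.get_session(session_id) as session:
    session.add_context({"role": "user", "content": "Summarize my last report"})
    recent = session.get_context(last_n=5)

manager.end_session(session_id)
```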
2. Agent Pool Management
Manage agent instances across sessions efficiently:
from queue import Queue, Empty
from dataclasses import dataclass
import time
@dataclass
class AgentPoolConfig:
agent_type: str
min_instances: int = 1
max_instances: int = 10
idle_timeout_seconds: int = 300
class PooledAgent:
def __init__(self, agent_id: str, agent_instance: Any):
self.agent_id = agent_id
self.agent_instance = agent_instance
self.last_used = time.time()
self.in_use = False
self.session_id = None
def acquire(self, session_id: str):
"""Acquire agent for a session"""
self.in_use = True
self.session_id = session_id
self.last_used = time.time()
def release(self):
"""Release agent back to pool"""
self.in_use = False
self.session_id = None
self.last_used = time.time()
# Reset agent state
if hasattr(self.agent_instance, 'reset'):
self.agent_instance.reset()
class MultiUserAgentPool:
def __init__(self):
self.pools: Dict[str, Dict[str, PooledAgent]] = {}
self.pool_configs: Dict[str, AgentPoolConfig] = {}
self.available_agents: Dict[str, Queue] = {}
self._lock = threading.RLock()
def configure_pool(self, config: AgentPoolConfig):
"""Configure an agent pool"""
with self._lock:
self.pool_configs[config.agent_type] = config
if config.agent_type not in self.pools:
self.pools[config.agent_type] = {}
self.available_agents[config.agent_type] = Queue()
# Create minimum instances
self._ensure_minimum_instances(config.agent_type)
def acquire_agent(self, agent_type: str, session_id: str,
timeout: float = 30) -> PooledAgent:
"""Acquire an agent for a session"""
if agent_type not in self.pool_configs:
raise ValueError(f"Unknown agent type: {agent_type}")
# Try to get available agent
try:
agent = self.available_agents[agent_type].get(timeout=timeout)
agent.acquire(session_id)
return agent
except Empty:
# Create new agent if under limit
with self._lock:
if len(self.pools[agent_type]) < self.pool_configs[agent_type].max_instances:
agent = self._create_agent(agent_type)
agent.acquire(session_id)
return agent
raise TimeoutError(f"No available agents of type {agent_type}")
def release_agent(self, agent: PooledAgent):
"""Release agent back to pool"""
agent.release()
# Return to available queue
for agent_type, pool in self.pools.items():
if agent.agent_id in pool:
self.available_agents[agent_type].put(agent)
break
    def _create_agent(self, agent_type: str) -> PooledAgent:
        """Create a new agent instance"""
        # Import once here so both branches below can construct an Agent
        from praisonaiagents import Agent

        agent_id = f"{agent_type}_{uuid.uuid4().hex[:8]}"
        # Create agent based on type (simplified)
        if agent_type == "research":
            agent_instance = Agent(
                name=f"Research_{agent_id}",
                role="Research Assistant",
                goal="Assist with research tasks"
            )
        else:
            # Default agent
            agent_instance = Agent(
                name=f"Agent_{agent_id}",
                role="Assistant",
                goal="Assist users"
            )
pooled_agent = PooledAgent(agent_id, agent_instance)
self.pools[agent_type][agent_id] = pooled_agent
return pooled_agent
def _ensure_minimum_instances(self, agent_type: str):
"""Ensure minimum number of instances exist"""
config = self.pool_configs[agent_type]
current_count = len(self.pools[agent_type])
for _ in range(config.min_instances - current_count):
agent = self._create_agent(agent_type)
self.available_agents[agent_type].put(agent)
def cleanup_idle_agents(self):
"""Remove agents that have been idle too long"""
with self._lock:
current_time = time.time()
for agent_type, pool in self.pools.items():
config = self.pool_configs[agent_type]
agents_to_remove = []
for agent_id, agent in pool.items():
if (not agent.in_use and
current_time - agent.last_used > config.idle_timeout_seconds and
len(pool) > config.min_instances):
agents_to_remove.append(agent_id)
for agent_id in agents_to_remove:
del pool[agent_id]
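A minimal usage sketch of the pool, assuming the classes above are in scope and praisonaiagents is installed (the pool creates real Agent instances); the pool sizes and the session ID are illustrative.

```python
# Usage sketch: configure a pool, borrow an agent for one session, return it.
pool = MultiUserAgentPool()
pool.configure_pool(AgentPoolConfig(agent_type="research", min_instances=1, max_instances=5))

agent = pool.acquire_agent("research", session_id=session_id, timeout=10)
try:
    # ... run session-scoped work through agent.agent_instance here ...
    pass
finally:
    pool.release_agent(agent)  # always return the agent, even on errors
```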
3. Resource Quota Management
Implement per-user resource quotas:
from enum import Enum
from collections import defaultdict
from typing import Tuple
import time
class ResourceType(Enum):
API_CALLS = "api_calls"
TOKENS = "tokens"
STORAGE_MB = "storage_mb"
COMPUTE_SECONDS = "compute_seconds"
@dataclass
class ResourceQuota:
resource_type: ResourceType
limit: float
period_seconds: int = 3600 # Default 1 hour
class UserResourceManager:
def __init__(self):
self.quotas: Dict[str, Dict[ResourceType, ResourceQuota]] = {}
self.usage: Dict[str, Dict[ResourceType, List[Tuple[float, float]]]] = defaultdict(
lambda: defaultdict(list)
)
self._lock = threading.RLock()
def set_user_quota(self, user_id: str, quotas: List[ResourceQuota]):
"""Set resource quotas for a user"""
with self._lock:
if user_id not in self.quotas:
self.quotas[user_id] = {}
for quota in quotas:
self.quotas[user_id][quota.resource_type] = quota
def check_quota(self, user_id: str, resource_type: ResourceType,
amount: float) -> Tuple[bool, Optional[str]]:
"""Check if user has quota for resource"""
with self._lock:
if user_id not in self.quotas:
return True, None # No quota set
if resource_type not in self.quotas[user_id]:
return True, None # No quota for this resource
quota = self.quotas[user_id][resource_type]
current_usage = self._get_usage_in_period(user_id, resource_type, quota.period_seconds)
if current_usage + amount > quota.limit:
return False, f"Quota exceeded for {resource_type.value}: {current_usage + amount:.2f}/{quota.limit}"
return True, None
def consume_resource(self, user_id: str, resource_type: ResourceType, amount: float):
"""Consume resource from user's quota"""
allowed, error = self.check_quota(user_id, resource_type, amount)
if not allowed:
raise ValueError(error)
with self._lock:
self.usage[user_id][resource_type].append((time.time(), amount))
# Cleanup old entries
self._cleanup_old_usage(user_id, resource_type)
def _get_usage_in_period(self, user_id: str, resource_type: ResourceType,
period_seconds: int) -> float:
"""Get usage in the specified period"""
current_time = time.time()
cutoff_time = current_time - period_seconds
usage_list = self.usage[user_id][resource_type]
return sum(
amount for timestamp, amount in usage_list
if timestamp > cutoff_time
)
def _cleanup_old_usage(self, user_id: str, resource_type: ResourceType):
"""Remove usage entries older than the quota period"""
if user_id not in self.quotas or resource_type not in self.quotas[user_id]:
return
quota = self.quotas[user_id][resource_type]
current_time = time.time()
cutoff_time = current_time - quota.period_seconds
self.usage[user_id][resource_type] = [
(timestamp, amount) for timestamp, amount in self.usage[user_id][resource_type]
if timestamp > cutoff_time
]
def get_usage_report(self, user_id: str) -> Dict[str, Any]:
"""Get usage report for a user"""
with self._lock:
report = {}
for resource_type, quota in self.quotas.get(user_id, {}).items():
usage = self._get_usage_in_period(user_id, resource_type, quota.period_seconds)
report[resource_type.value] = {
"used": usage,
"limit": quota.limit,
"percentage": (usage / quota.limit * 100) if quota.limit > 0 else 0,
"period_seconds": quota.period_seconds
}
return report
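A minimal usage sketch of the quota manager above; the quota values are illustrative.

```python
# Usage sketch: set quotas, check before consuming, and inspect usage.
resources = UserResourceManager()
resources.set_user_quota("user_42", [
    ResourceQuota(ResourceType.API_CALLS, limit=1000, period_seconds=3600),
    ResourceQuota(ResourceType.TOKENS, limit=50_000, period_seconds=3600),
])

allowed, reason = resources.check_quota("user_42", ResourceType.TOKENS, 1200)
if allowed:
    resources.consume_resource("user_42", ResourceType.TOKENS, 1200)
else:
    print(reason)

print(resources.get_usage_report("user_42"))
```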
4. Session Security
Implement security measures for multi-user environments:
import secrets
import hashlib
from cryptography.fernet import Fernet
class SessionSecurity:
def __init__(self):
self.session_tokens: Dict[str, str] = {}
self.encryption_keys: Dict[str, bytes] = {}
self._lock = threading.RLock()
def generate_session_token(self, session_id: str) -> str:
"""Generate secure session token"""
with self._lock:
token = secrets.token_urlsafe(32)
# Store hashed token
token_hash = hashlib.sha256(token.encode()).hexdigest()
self.session_tokens[session_id] = token_hash
return token
    def validate_session_token(self, session_id: str, token: str) -> bool:
        """Validate session token using a constant-time comparison"""
        with self._lock:
            if session_id not in self.session_tokens:
                return False
            token_hash = hashlib.sha256(token.encode()).hexdigest()
            # compare_digest avoids leaking information through timing differences
            return secrets.compare_digest(self.session_tokens[session_id], token_hash)
def get_session_encryptor(self, session_id: str) -> Fernet:
"""Get encryptor for session data"""
with self._lock:
if session_id not in self.encryption_keys:
# Generate new key for session
key = Fernet.generate_key()
self.encryption_keys[session_id] = key
return Fernet(self.encryption_keys[session_id])
def encrypt_session_data(self, session_id: str, data: str) -> bytes:
"""Encrypt data for a session"""
encryptor = self.get_session_encryptor(session_id)
return encryptor.encrypt(data.encode())
def decrypt_session_data(self, session_id: str, encrypted_data: bytes) -> str:
"""Decrypt session data"""
encryptor = self.get_session_encryptor(session_id)
return encryptor.decrypt(encrypted_data).decode()
def cleanup_session_security(self, session_id: str):
"""Cleanup security data for a session"""
with self._lock:
if session_id in self.session_tokens:
del self.session_tokens[session_id]
if session_id in self.encryption_keys:
del self.encryption_keys[session_id]
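A minimal usage sketch of the security helper above, assuming the cryptography package is installed and session_id comes from the session manager earlier in this guide.

```python
# Usage sketch: issue a token, validate it, and encrypt per-session data.
security = SessionSecurity()

token = security.generate_session_token(session_id)        # return this to the client
assert security.validate_session_token(session_id, token)  # verify on each request

ciphertext = security.encrypt_session_data(session_id, "sensitive note")
plaintext = security.decrypt_session_data(session_id, ciphertext)

security.cleanup_session_security(session_id)  # drop token and key when the session ends
```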
Advanced Session Handling
1. Session Persistence
Store and restore sessions:
import json
from pathlib import Path
class SessionPersistence:
def __init__(self, storage_path: str = "./sessions"):
self.storage_path = Path(storage_path)
self.storage_path.mkdir(exist_ok=True)
def save_session(self, session: UserSession):
"""Save session to disk"""
session_data = {
"session_id": session.session_id,
"user_id": session.user_id,
"created_at": session.created_at.isoformat(),
"last_activity": session.last_activity.isoformat(),
"metadata": session.metadata,
"context": session.context
}
session_file = self.storage_path / f"{session.session_id}.json"
        with open(session_file, 'w') as f:
            # default=str serializes the datetime timestamps added by add_context
            json.dump(session_data, f, indent=2, default=str)
def load_session(self, session_id: str) -> Optional[UserSession]:
"""Load session from disk"""
session_file = self.storage_path / f"{session_id}.json"
if not session_file.exists():
return None
with open(session_file, 'r') as f:
data = json.load(f)
session = UserSession(
session_id=data["session_id"],
user_id=data["user_id"],
metadata=data["metadata"]
)
session.created_at = datetime.fromisoformat(data["created_at"])
session.last_activity = datetime.fromisoformat(data["last_activity"])
session.context = data["context"]
return session
def delete_session(self, session_id: str):
"""Delete session from disk"""
session_file = self.storage_path / f"{session_id}.json"
if session_file.exists():
session_file.unlink()
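A minimal usage sketch of the persistence layer, assuming the session manager and session_id from earlier in this guide; the storage path is illustrative. Note that context timestamps are written back as strings.

```python
# Usage sketch: save a live session, restore it later, delete it when done.
persistence = SessionPersistence(storage_path="./sessions")

with manager.get_session(session_id) as session:
    persistence.save_session(session)

restored = persistence.load_session(session_id)
if restored:
    print(restored.user_id, len(restored.context))

persistence.delete_session(session_id)
```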
2. Session Load Balancing
Distribute sessions across multiple workers:
from typing import List
class SessionLoadBalancer:
def __init__(self, workers: List[str]):
self.workers = workers
self.session_assignments: Dict[str, str] = {}
self.worker_load: Dict[str, int] = {worker: 0 for worker in workers}
self._lock = threading.RLock()
def assign_session(self, session_id: str) -> str:
"""Assign session to a worker"""
with self._lock:
# Use least loaded worker
worker = min(self.worker_load.items(), key=lambda x: x[1])[0]
self.session_assignments[session_id] = worker
self.worker_load[worker] += 1
return worker
def get_worker(self, session_id: str) -> Optional[str]:
"""Get worker for a session"""
with self._lock:
return self.session_assignments.get(session_id)
def release_session(self, session_id: str):
"""Release session from worker"""
with self._lock:
worker = self.session_assignments.get(session_id)
if worker:
del self.session_assignments[session_id]
self.worker_load[worker] = max(0, self.worker_load[worker] - 1)
def rebalance(self):
"""Rebalance sessions across workers"""
with self._lock:
# Calculate target load per worker
total_sessions = len(self.session_assignments)
target_load = total_sessions // len(self.workers)
# Identify overloaded and underloaded workers
overloaded = []
underloaded = []
for worker, load in self.worker_load.items():
if load > target_load + 1:
overloaded.append((worker, load - target_load))
elif load < target_load:
underloaded.append((worker, target_load - load))
# Reassign sessions
for worker, excess in overloaded:
sessions_to_move = [
sid for sid, w in self.session_assignments.items()
if w == worker
][:excess]
for session_id in sessions_to_move:
if underloaded:
target_worker, capacity = underloaded[0]
self.session_assignments[session_id] = target_worker
self.worker_load[worker] -= 1
self.worker_load[target_worker] += 1
if capacity <= 1:
underloaded.pop(0)
else:
underloaded[0] = (target_worker, capacity - 1)
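A minimal usage sketch of the load balancer above; the worker names are illustrative.

```python
# Usage sketch: assign a session to the least-loaded worker, release it later.
balancer = SessionLoadBalancer(workers=["worker-a", "worker-b", "worker-c"])

worker = balancer.assign_session(session_id)  # route this session's requests here
print(f"Session routed to {worker}")

# ... when the session ends ...
balancer.release_session(session_id)
balancer.rebalance()  # optionally even out load after a burst of assignments
```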
3. Session Monitoring
Monitor session health and performance:
@dataclass
class SessionMetrics:
session_id: str
user_id: str
duration_seconds: float
message_count: int
api_calls: int
tokens_used: int
error_count: int
last_error: Optional[str] = None
class SessionMonitor:
def __init__(self):
self.metrics: Dict[str, SessionMetrics] = {}
self.alerts: List[Dict[str, Any]] = []
self._lock = threading.RLock()
def track_session(self, session: UserSession) -> SessionMetrics:
"""Track metrics for a session"""
with self._lock:
if session.session_id not in self.metrics:
self.metrics[session.session_id] = SessionMetrics(
session_id=session.session_id,
user_id=session.user_id,
duration_seconds=0,
message_count=0,
api_calls=0,
tokens_used=0,
error_count=0
)
metrics = self.metrics[session.session_id]
# Update duration
duration = (datetime.now() - session.created_at).total_seconds()
metrics.duration_seconds = duration
# Update message count
metrics.message_count = len(session.context)
return metrics
def record_api_call(self, session_id: str, tokens: int):
"""Record an API call for a session"""
with self._lock:
if session_id in self.metrics:
self.metrics[session_id].api_calls += 1
self.metrics[session_id].tokens_used += tokens
def record_error(self, session_id: str, error: str):
"""Record an error for a session"""
with self._lock:
if session_id in self.metrics:
self.metrics[session_id].error_count += 1
self.metrics[session_id].last_error = error
# Generate alert if error rate is high
metrics = self.metrics[session_id]
if metrics.error_count > 5:
self.alerts.append({
"type": "high_error_rate",
"session_id": session_id,
"error_count": metrics.error_count,
"timestamp": datetime.now()
})
def get_session_health(self, session_id: str) -> Dict[str, Any]:
"""Get health status of a session"""
with self._lock:
if session_id not in self.metrics:
return {"status": "unknown"}
metrics = self.metrics[session_id]
# Calculate health score
error_rate = metrics.error_count / max(metrics.api_calls, 1)
avg_response_time = metrics.duration_seconds / max(metrics.message_count, 1)
health_score = 100
if error_rate > 0.1:
health_score -= 30
if avg_response_time > 5:
health_score -= 20
if metrics.tokens_used > 10000:
health_score -= 10
return {
"status": "healthy" if health_score > 70 else "unhealthy",
"score": health_score,
"metrics": {
"error_rate": error_rate,
"avg_response_time": avg_response_time,
"total_tokens": metrics.tokens_used
}
}
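A minimal usage sketch of the monitor above, reusing the session manager from earlier; the token count and error message are illustrative.

```python
# Usage sketch: track a session, record activity, and check its health.
monitor = SessionMonitor()

with manager.get_session(session_id) as session:
    monitor.track_session(session)

monitor.record_api_call(session_id, tokens=350)
monitor.record_error(session_id, "LLM provider timeout")

print(monitor.get_session_health(session_id))
```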
Best Practices
- Implement Session Timeouts: Always set reasonable timeouts.

def check_session_timeout(session: UserSession, timeout_minutes: int = 30) -> bool:
    idle_time = datetime.now() - session.last_activity
    return idle_time.total_seconds() > timeout_minutes * 60
- Use Session Middleware: Implement middleware for common operations.

class SessionMiddleware:
    def __init__(self, session_manager: MultiUserSessionManager):
        self.session_manager = session_manager

    async def __call__(self, request, call_next):
        session_id = request.headers.get("X-Session-ID")
        if not session_id:
            return {"error": "No session ID provided"}
        try:
            with self.session_manager.get_session(session_id) as session:
                request.state.session = session
                response = await call_next(request)
                return response
        except ValueError:
            return {"error": "Invalid session"}
- Implement Rate Limiting: Protect against abuse (see the usage sketch after this list).

from functools import wraps

def rate_limit(max_calls: int = 100, period_seconds: int = 60):
    def decorator(func):
        call_times = defaultdict(list)

        @wraps(func)
        def wrapper(session_id: str, *args, **kwargs):
            current_time = time.time()
            cutoff_time = current_time - period_seconds
            # Clean old calls
            call_times[session_id] = [
                t for t in call_times[session_id] if t > cutoff_time
            ]
            # Check rate limit
            if len(call_times[session_id]) >= max_calls:
                raise Exception("Rate limit exceeded")
            call_times[session_id].append(current_time)
            return func(session_id, *args, **kwargs)
        return wrapper
    return decorator
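A minimal usage sketch of the rate_limit decorator above; handle_user_message is a hypothetical handler and the limits are illustrative.

```python
# Usage sketch: limit how often a session-scoped handler can be called.
@rate_limit(max_calls=5, period_seconds=60)
def handle_user_message(session_id: str, message: str) -> str:
    # Hypothetical handler: route the message to an agent and return its reply
    return f"processed: {message}"

reply = handle_user_message(session_id, "Hello")
```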
Testing Multi-User Sessions
import pytest
from concurrent.futures import ThreadPoolExecutor
def test_concurrent_sessions():
manager = MultiUserSessionManager()
# Create multiple sessions concurrently
with ThreadPoolExecutor(max_workers=10) as executor:
futures = []
for i in range(10):
user_id = f"user_{i % 3}" # 3 users
future = executor.submit(manager.create_session, user_id)
futures.append(future)
session_ids = [f.result() for f in futures]
# Verify all sessions created
assert len(session_ids) == 10
assert len(set(session_ids)) == 10 # All unique
# Verify session limits enforced
assert len(manager.user_sessions["user_0"]) <= manager.max_sessions_per_user
def test_resource_quotas():
resource_manager = UserResourceManager()
# Set quota
resource_manager.set_user_quota("user1", [
ResourceQuota(ResourceType.API_CALLS, limit=100, period_seconds=60)
])
# Consume resources
for _ in range(100):
resource_manager.consume_resource("user1", ResourceType.API_CALLS, 1)
# Verify quota enforcement
allowed, error = resource_manager.check_quota("user1", ResourceType.API_CALLS, 1)
assert not allowed
assert "Quota exceeded" in error
def test_session_isolation():
manager = MultiUserSessionManager()
# Create sessions for different users
session1 = manager.create_session("user1")
session2 = manager.create_session("user2")
# Add context to sessions
with manager.get_session(session1) as s1:
s1.add_context({"content": "User 1 message"})
with manager.get_session(session2) as s2:
s2.add_context({"content": "User 2 message"})
# Verify isolation
with manager.get_session(session1) as s1:
context1 = s1.get_context()
assert len(context1) == 1
assert context1[0]["content"] == "User 1 message"
with manager.get_session(session2) as s2:
context2 = s2.get_context()
assert len(context2) == 1
assert context2[0]["content"] == "User 2 message"