Skip to main content

Multi-User Session Handling

Managing multiple concurrent user sessions is crucial for production multi-agent systems. This guide covers strategies for isolating user contexts, managing resources, and ensuring security.

Core Concepts

Session Isolation Requirements

  1. Data Isolation: Each user’s data must be completely isolated
  2. Resource Isolation: Prevent resource exhaustion by one user
  3. Context Isolation: Maintain separate conversation contexts
  4. Security Isolation: Prevent cross-session data leakage
  5. Performance Isolation: One user shouldn’t impact others

Session Management Architecture

1. Session Manager Implementation

import uuid
from datetime import datetime, timedelta
from typing import Dict, Any, Optional, List
import threading
from contextlib import contextmanager

class UserSession:
    def __init__(self, session_id: str, user_id: str, metadata: Dict[str, Any] = None):
        self.session_id = session_id
        self.user_id = user_id
        self.created_at = datetime.now()
        self.last_activity = datetime.now()
        self.metadata = metadata or {}
        self.context = []
        self.agents = {}
        self.resources = {}
        self.is_active = True
        self._lock = threading.RLock()
    
    def update_activity(self):
        """Update last activity timestamp"""
        with self._lock:
            self.last_activity = datetime.now()
    
    def add_context(self, message: Dict[str, Any]):
        """Add message to session context"""
        with self._lock:
            self.context.append({
                **message,
                "timestamp": datetime.now()
            })
    
    def get_context(self, last_n: Optional[int] = None) -> List[Dict[str, Any]]:
        """Get session context"""
        with self._lock:
            if last_n:
                return self.context[-last_n:]
            return self.context.copy()

class MultiUserSessionManager:
    def __init__(self, max_sessions_per_user: int = 5, 
                 session_timeout_minutes: int = 30):
        self.sessions: Dict[str, UserSession] = {}
        self.user_sessions: Dict[str, List[str]] = {}
        self.max_sessions_per_user = max_sessions_per_user
        self.session_timeout = timedelta(minutes=session_timeout_minutes)
        self._lock = threading.RLock()
        self._cleanup_thread = None
        self._start_cleanup_thread()
    
    def create_session(self, user_id: str, metadata: Dict[str, Any] = None) -> str:
        """Create a new session for a user"""
        with self._lock:
            # Check session limit
            if user_id in self.user_sessions:
                if len(self.user_sessions[user_id]) >= self.max_sessions_per_user:
                    # Remove oldest session
                    oldest_session_id = self._get_oldest_session(user_id)
                    self.end_session(oldest_session_id)
            
            # Create new session
            session_id = str(uuid.uuid4())
            session = UserSession(session_id, user_id, metadata)
            
            self.sessions[session_id] = session
            
            if user_id not in self.user_sessions:
                self.user_sessions[user_id] = []
            self.user_sessions[user_id].append(session_id)
            
            return session_id
    
    @contextmanager
    def get_session(self, session_id: str):
        """Get session with automatic activity update"""
        session = self._get_session(session_id)
        if not session:
            raise ValueError(f"Session {session_id} not found")
        
        session.update_activity()
        yield session
    
    def _get_session(self, session_id: str) -> Optional[UserSession]:
        """Get session by ID"""
        with self._lock:
            return self.sessions.get(session_id)
    
    def end_session(self, session_id: str):
        """End a session and cleanup resources"""
        with self._lock:
            session = self.sessions.get(session_id)
            if not session:
                return
            
            # Cleanup session resources
            self._cleanup_session_resources(session)
            
            # Remove from tracking
            del self.sessions[session_id]
            
            if session.user_id in self.user_sessions:
                self.user_sessions[session.user_id].remove(session_id)
                if not self.user_sessions[session.user_id]:
                    del self.user_sessions[session.user_id]
    
    def _cleanup_session_resources(self, session: UserSession):
        """Cleanup resources associated with a session"""
        # Cleanup agents
        for agent_id, agent in session.agents.items():
            if hasattr(agent, 'cleanup'):
                agent.cleanup()
        
        # Clear context to free memory
        session.context.clear()
        
        # Mark as inactive
        session.is_active = False
    
    def _get_oldest_session(self, user_id: str) -> Optional[str]:
        """Get the oldest session for a user"""
        if user_id not in self.user_sessions:
            return None
        
        oldest_session_id = None
        oldest_time = datetime.now()
        
        for session_id in self.user_sessions[user_id]:
            session = self.sessions.get(session_id)
            if session and session.created_at < oldest_time:
                oldest_time = session.created_at
                oldest_session_id = session_id
        
        return oldest_session_id
    
    def _cleanup_expired_sessions(self):
        """Remove expired sessions"""
        with self._lock:
            current_time = datetime.now()
            expired_sessions = []
            
            for session_id, session in self.sessions.items():
                if current_time - session.last_activity > self.session_timeout:
                    expired_sessions.append(session_id)
            
            for session_id in expired_sessions:
                self.end_session(session_id)
    
    def _start_cleanup_thread(self):
        """Start background cleanup thread"""
        import time
        
        def cleanup_loop():
            while True:
                time.sleep(60)  # Check every minute
                self._cleanup_expired_sessions()
        
        self._cleanup_thread = threading.Thread(target=cleanup_loop, daemon=True)
        self._cleanup_thread.start()

2. Agent Pool Management

Manage agent instances across sessions efficiently:
from queue import Queue, Empty
from dataclasses import dataclass
import time

@dataclass
class AgentPoolConfig:
    agent_type: str
    min_instances: int = 1
    max_instances: int = 10
    idle_timeout_seconds: int = 300

class PooledAgent:
    def __init__(self, agent_id: str, agent_instance: Any):
        self.agent_id = agent_id
        self.agent_instance = agent_instance
        self.last_used = time.time()
        self.in_use = False
        self.session_id = None
    
    def acquire(self, session_id: str):
        """Acquire agent for a session"""
        self.in_use = True
        self.session_id = session_id
        self.last_used = time.time()
    
    def release(self):
        """Release agent back to pool"""
        self.in_use = False
        self.session_id = None
        self.last_used = time.time()
        
        # Reset agent state
        if hasattr(self.agent_instance, 'reset'):
            self.agent_instance.reset()

class MultiUserAgentPool:
    def __init__(self):
        self.pools: Dict[str, Dict[str, PooledAgent]] = {}
        self.pool_configs: Dict[str, AgentPoolConfig] = {}
        self.available_agents: Dict[str, Queue] = {}
        self._lock = threading.RLock()
    
    def configure_pool(self, config: AgentPoolConfig):
        """Configure an agent pool"""
        with self._lock:
            self.pool_configs[config.agent_type] = config
            
            if config.agent_type not in self.pools:
                self.pools[config.agent_type] = {}
                self.available_agents[config.agent_type] = Queue()
            
            # Create minimum instances
            self._ensure_minimum_instances(config.agent_type)
    
    def acquire_agent(self, agent_type: str, session_id: str, 
                     timeout: float = 30) -> PooledAgent:
        """Acquire an agent for a session"""
        if agent_type not in self.pool_configs:
            raise ValueError(f"Unknown agent type: {agent_type}")
        
        # Try to get available agent
        try:
            agent = self.available_agents[agent_type].get(timeout=timeout)
            agent.acquire(session_id)
            return agent
        except Empty:
            # Create new agent if under limit
            with self._lock:
                if len(self.pools[agent_type]) < self.pool_configs[agent_type].max_instances:
                    agent = self._create_agent(agent_type)
                    agent.acquire(session_id)
                    return agent
            
            raise TimeoutError(f"No available agents of type {agent_type}")
    
    def release_agent(self, agent: PooledAgent):
        """Release agent back to pool"""
        agent.release()
        
        # Return to available queue
        for agent_type, pool in self.pools.items():
            if agent.agent_id in pool:
                self.available_agents[agent_type].put(agent)
                break
    
    def _create_agent(self, agent_type: str) -> PooledAgent:
        """Create a new agent instance"""
        agent_id = f"{agent_type}_{uuid.uuid4().hex[:8]}"
        
        # Create agent based on type (simplified)
        if agent_type == "research":
            from praisonaiagents import Agent
            agent_instance = Agent(
                name=f"Research_{agent_id}",
                role="Research Assistant",
                goal="Assist with research tasks"
            )
        else:
            # Default agent
            agent_instance = Agent(
                name=f"Agent_{agent_id}",
                role="Assistant",
                goal="Assist users"
            )
        
        pooled_agent = PooledAgent(agent_id, agent_instance)
        self.pools[agent_type][agent_id] = pooled_agent
        
        return pooled_agent
    
    def _ensure_minimum_instances(self, agent_type: str):
        """Ensure minimum number of instances exist"""
        config = self.pool_configs[agent_type]
        current_count = len(self.pools[agent_type])
        
        for _ in range(config.min_instances - current_count):
            agent = self._create_agent(agent_type)
            self.available_agents[agent_type].put(agent)
    
    def cleanup_idle_agents(self):
        """Remove agents that have been idle too long"""
        with self._lock:
            current_time = time.time()
            
            for agent_type, pool in self.pools.items():
                config = self.pool_configs[agent_type]
                agents_to_remove = []
                
                for agent_id, agent in pool.items():
                    if (not agent.in_use and 
                        current_time - agent.last_used > config.idle_timeout_seconds and
                        len(pool) > config.min_instances):
                        agents_to_remove.append(agent_id)
                
                for agent_id in agents_to_remove:
                    del pool[agent_id]

3. Resource Quota Management

Implement per-user resource quotas:
from enum import Enum
from collections import defaultdict
import asyncio

class ResourceType(Enum):
    API_CALLS = "api_calls"
    TOKENS = "tokens"
    STORAGE_MB = "storage_mb"
    COMPUTE_SECONDS = "compute_seconds"

@dataclass
class ResourceQuota:
    resource_type: ResourceType
    limit: float
    period_seconds: int = 3600  # Default 1 hour
    
class UserResourceManager:
    def __init__(self):
        self.quotas: Dict[str, Dict[ResourceType, ResourceQuota]] = {}
        self.usage: Dict[str, Dict[ResourceType, List[Tuple[float, float]]]] = defaultdict(
            lambda: defaultdict(list)
        )
        self._lock = threading.RLock()
    
    def set_user_quota(self, user_id: str, quotas: List[ResourceQuota]):
        """Set resource quotas for a user"""
        with self._lock:
            if user_id not in self.quotas:
                self.quotas[user_id] = {}
            
            for quota in quotas:
                self.quotas[user_id][quota.resource_type] = quota
    
    def check_quota(self, user_id: str, resource_type: ResourceType, 
                   amount: float) -> Tuple[bool, Optional[str]]:
        """Check if user has quota for resource"""
        with self._lock:
            if user_id not in self.quotas:
                return True, None  # No quota set
            
            if resource_type not in self.quotas[user_id]:
                return True, None  # No quota for this resource
            
            quota = self.quotas[user_id][resource_type]
            current_usage = self._get_usage_in_period(user_id, resource_type, quota.period_seconds)
            
            if current_usage + amount > quota.limit:
                return False, f"Quota exceeded for {resource_type.value}: {current_usage + amount:.2f}/{quota.limit}"
            
            return True, None
    
    def consume_resource(self, user_id: str, resource_type: ResourceType, amount: float):
        """Consume resource from user's quota"""
        allowed, error = self.check_quota(user_id, resource_type, amount)
        
        if not allowed:
            raise ValueError(error)
        
        with self._lock:
            self.usage[user_id][resource_type].append((time.time(), amount))
            
            # Cleanup old entries
            self._cleanup_old_usage(user_id, resource_type)
    
    def _get_usage_in_period(self, user_id: str, resource_type: ResourceType, 
                            period_seconds: int) -> float:
        """Get usage in the specified period"""
        current_time = time.time()
        cutoff_time = current_time - period_seconds
        
        usage_list = self.usage[user_id][resource_type]
        
        return sum(
            amount for timestamp, amount in usage_list
            if timestamp > cutoff_time
        )
    
    def _cleanup_old_usage(self, user_id: str, resource_type: ResourceType):
        """Remove usage entries older than the quota period"""
        if user_id not in self.quotas or resource_type not in self.quotas[user_id]:
            return
        
        quota = self.quotas[user_id][resource_type]
        current_time = time.time()
        cutoff_time = current_time - quota.period_seconds
        
        self.usage[user_id][resource_type] = [
            (timestamp, amount) for timestamp, amount in self.usage[user_id][resource_type]
            if timestamp > cutoff_time
        ]
    
    def get_usage_report(self, user_id: str) -> Dict[str, Any]:
        """Get usage report for a user"""
        with self._lock:
            report = {}
            
            for resource_type, quota in self.quotas.get(user_id, {}).items():
                usage = self._get_usage_in_period(user_id, resource_type, quota.period_seconds)
                
                report[resource_type.value] = {
                    "used": usage,
                    "limit": quota.limit,
                    "percentage": (usage / quota.limit * 100) if quota.limit > 0 else 0,
                    "period_seconds": quota.period_seconds
                }
            
            return report

4. Session Security

Implement security measures for multi-user environments:
import secrets
import hashlib
from cryptography.fernet import Fernet

class SessionSecurity:
    def __init__(self):
        self.session_tokens: Dict[str, str] = {}
        self.encryption_keys: Dict[str, bytes] = {}
        self._lock = threading.RLock()
    
    def generate_session_token(self, session_id: str) -> str:
        """Generate secure session token"""
        with self._lock:
            token = secrets.token_urlsafe(32)
            
            # Store hashed token
            token_hash = hashlib.sha256(token.encode()).hexdigest()
            self.session_tokens[session_id] = token_hash
            
            return token
    
    def validate_session_token(self, session_id: str, token: str) -> bool:
        """Validate session token"""
        with self._lock:
            if session_id not in self.session_tokens:
                return False
            
            token_hash = hashlib.sha256(token.encode()).hexdigest()
            return self.session_tokens[session_id] == token_hash
    
    def get_session_encryptor(self, session_id: str) -> Fernet:
        """Get encryptor for session data"""
        with self._lock:
            if session_id not in self.encryption_keys:
                # Generate new key for session
                key = Fernet.generate_key()
                self.encryption_keys[session_id] = key
            
            return Fernet(self.encryption_keys[session_id])
    
    def encrypt_session_data(self, session_id: str, data: str) -> bytes:
        """Encrypt data for a session"""
        encryptor = self.get_session_encryptor(session_id)
        return encryptor.encrypt(data.encode())
    
    def decrypt_session_data(self, session_id: str, encrypted_data: bytes) -> str:
        """Decrypt session data"""
        encryptor = self.get_session_encryptor(session_id)
        return encryptor.decrypt(encrypted_data).decode()
    
    def cleanup_session_security(self, session_id: str):
        """Cleanup security data for a session"""
        with self._lock:
            if session_id in self.session_tokens:
                del self.session_tokens[session_id]
            
            if session_id in self.encryption_keys:
                del self.encryption_keys[session_id]

Advanced Session Handling

1. Session Persistence

Store and restore sessions:
import json
import pickle
from pathlib import Path

class SessionPersistence:
    def __init__(self, storage_path: str = "./sessions"):
        self.storage_path = Path(storage_path)
        self.storage_path.mkdir(exist_ok=True)
    
    def save_session(self, session: UserSession):
        """Save session to disk"""
        session_data = {
            "session_id": session.session_id,
            "user_id": session.user_id,
            "created_at": session.created_at.isoformat(),
            "last_activity": session.last_activity.isoformat(),
            "metadata": session.metadata,
            "context": session.context
        }
        
        session_file = self.storage_path / f"{session.session_id}.json"
        
        with open(session_file, 'w') as f:
            json.dump(session_data, f, indent=2)
    
    def load_session(self, session_id: str) -> Optional[UserSession]:
        """Load session from disk"""
        session_file = self.storage_path / f"{session_id}.json"
        
        if not session_file.exists():
            return None
        
        with open(session_file, 'r') as f:
            data = json.load(f)
        
        session = UserSession(
            session_id=data["session_id"],
            user_id=data["user_id"],
            metadata=data["metadata"]
        )
        
        session.created_at = datetime.fromisoformat(data["created_at"])
        session.last_activity = datetime.fromisoformat(data["last_activity"])
        session.context = data["context"]
        
        return session
    
    def delete_session(self, session_id: str):
        """Delete session from disk"""
        session_file = self.storage_path / f"{session_id}.json"
        
        if session_file.exists():
            session_file.unlink()

2. Session Load Balancing

Distribute sessions across multiple workers:
from typing import List
import random

class SessionLoadBalancer:
    def __init__(self, workers: List[str]):
        self.workers = workers
        self.session_assignments: Dict[str, str] = {}
        self.worker_load: Dict[str, int] = {worker: 0 for worker in workers}
        self._lock = threading.RLock()
    
    def assign_session(self, session_id: str) -> str:
        """Assign session to a worker"""
        with self._lock:
            # Use least loaded worker
            worker = min(self.worker_load.items(), key=lambda x: x[1])[0]
            
            self.session_assignments[session_id] = worker
            self.worker_load[worker] += 1
            
            return worker
    
    def get_worker(self, session_id: str) -> Optional[str]:
        """Get worker for a session"""
        with self._lock:
            return self.session_assignments.get(session_id)
    
    def release_session(self, session_id: str):
        """Release session from worker"""
        with self._lock:
            worker = self.session_assignments.get(session_id)
            
            if worker:
                del self.session_assignments[session_id]
                self.worker_load[worker] = max(0, self.worker_load[worker] - 1)
    
    def rebalance(self):
        """Rebalance sessions across workers"""
        with self._lock:
            # Calculate target load per worker
            total_sessions = len(self.session_assignments)
            target_load = total_sessions // len(self.workers)
            
            # Identify overloaded and underloaded workers
            overloaded = []
            underloaded = []
            
            for worker, load in self.worker_load.items():
                if load > target_load + 1:
                    overloaded.append((worker, load - target_load))
                elif load < target_load:
                    underloaded.append((worker, target_load - load))
            
            # Reassign sessions
            for worker, excess in overloaded:
                sessions_to_move = [
                    sid for sid, w in self.session_assignments.items()
                    if w == worker
                ][:excess]
                
                for session_id in sessions_to_move:
                    if underloaded:
                        target_worker, capacity = underloaded[0]
                        
                        self.session_assignments[session_id] = target_worker
                        self.worker_load[worker] -= 1
                        self.worker_load[target_worker] += 1
                        
                        if capacity <= 1:
                            underloaded.pop(0)
                        else:
                            underloaded[0] = (target_worker, capacity - 1)

3. Session Monitoring

Monitor session health and performance:
@dataclass
class SessionMetrics:
    session_id: str
    user_id: str
    duration_seconds: float
    message_count: int
    api_calls: int
    tokens_used: int
    error_count: int
    last_error: Optional[str] = None

class SessionMonitor:
    def __init__(self):
        self.metrics: Dict[str, SessionMetrics] = {}
        self.alerts: List[Dict[str, Any]] = []
        self._lock = threading.RLock()
    
    def track_session(self, session: UserSession) -> SessionMetrics:
        """Track metrics for a session"""
        with self._lock:
            if session.session_id not in self.metrics:
                self.metrics[session.session_id] = SessionMetrics(
                    session_id=session.session_id,
                    user_id=session.user_id,
                    duration_seconds=0,
                    message_count=0,
                    api_calls=0,
                    tokens_used=0,
                    error_count=0
                )
            
            metrics = self.metrics[session.session_id]
            
            # Update duration
            duration = (datetime.now() - session.created_at).total_seconds()
            metrics.duration_seconds = duration
            
            # Update message count
            metrics.message_count = len(session.context)
            
            return metrics
    
    def record_api_call(self, session_id: str, tokens: int):
        """Record an API call for a session"""
        with self._lock:
            if session_id in self.metrics:
                self.metrics[session_id].api_calls += 1
                self.metrics[session_id].tokens_used += tokens
    
    def record_error(self, session_id: str, error: str):
        """Record an error for a session"""
        with self._lock:
            if session_id in self.metrics:
                self.metrics[session_id].error_count += 1
                self.metrics[session_id].last_error = error
                
                # Generate alert if error rate is high
                metrics = self.metrics[session_id]
                if metrics.error_count > 5:
                    self.alerts.append({
                        "type": "high_error_rate",
                        "session_id": session_id,
                        "error_count": metrics.error_count,
                        "timestamp": datetime.now()
                    })
    
    def get_session_health(self, session_id: str) -> Dict[str, Any]:
        """Get health status of a session"""
        with self._lock:
            if session_id not in self.metrics:
                return {"status": "unknown"}
            
            metrics = self.metrics[session_id]
            
            # Calculate health score
            error_rate = metrics.error_count / max(metrics.api_calls, 1)
            avg_response_time = metrics.duration_seconds / max(metrics.message_count, 1)
            
            health_score = 100
            
            if error_rate > 0.1:
                health_score -= 30
            if avg_response_time > 5:
                health_score -= 20
            if metrics.tokens_used > 10000:
                health_score -= 10
            
            return {
                "status": "healthy" if health_score > 70 else "unhealthy",
                "score": health_score,
                "metrics": {
                    "error_rate": error_rate,
                    "avg_response_time": avg_response_time,
                    "total_tokens": metrics.tokens_used
                }
            }

Best Practices

  1. Implement Session Timeouts: Always set reasonable timeouts
    def check_session_timeout(session: UserSession, timeout_minutes: int = 30) -> bool:
        idle_time = datetime.now() - session.last_activity
        return idle_time.total_seconds() > timeout_minutes * 60
    
  2. Use Session Middleware: Implement middleware for common operations
    class SessionMiddleware:
        def __init__(self, session_manager: MultiUserSessionManager):
            self.session_manager = session_manager
        
        async def __call__(self, request, call_next):
            session_id = request.headers.get("X-Session-ID")
            
            if not session_id:
                return {"error": "No session ID provided"}
            
            try:
                with self.session_manager.get_session(session_id) as session:
                    request.state.session = session
                    response = await call_next(request)
                    return response
            except ValueError:
                return {"error": "Invalid session"}
    
  3. Implement Rate Limiting: Protect against abuse
    from functools import wraps
    
    def rate_limit(max_calls: int = 100, period_seconds: int = 60):
        def decorator(func):
            call_times = defaultdict(list)
            
            @wraps(func)
            def wrapper(session_id: str, *args, **kwargs):
                current_time = time.time()
                cutoff_time = current_time - period_seconds
                
                # Clean old calls
                call_times[session_id] = [
                    t for t in call_times[session_id] if t > cutoff_time
                ]
                
                # Check rate limit
                if len(call_times[session_id]) >= max_calls:
                    raise Exception("Rate limit exceeded")
                
                call_times[session_id].append(current_time)
                return func(session_id, *args, **kwargs)
            
            return wrapper
        return decorator
    

Testing Multi-User Sessions

import pytest
import asyncio
from concurrent.futures import ThreadPoolExecutor

def test_concurrent_sessions():
    manager = MultiUserSessionManager()
    
    # Create multiple sessions concurrently
    with ThreadPoolExecutor(max_workers=10) as executor:
        futures = []
        
        for i in range(10):
            user_id = f"user_{i % 3}"  # 3 users
            future = executor.submit(manager.create_session, user_id)
            futures.append(future)
        
        session_ids = [f.result() for f in futures]
    
    # Verify all sessions created
    assert len(session_ids) == 10
    assert len(set(session_ids)) == 10  # All unique
    
    # Verify session limits enforced
    assert len(manager.user_sessions["user_0"]) <= manager.max_sessions_per_user

def test_resource_quotas():
    resource_manager = UserResourceManager()
    
    # Set quota
    resource_manager.set_user_quota("user1", [
        ResourceQuota(ResourceType.API_CALLS, limit=100, period_seconds=60)
    ])
    
    # Consume resources
    for _ in range(100):
        resource_manager.consume_resource("user1", ResourceType.API_CALLS, 1)
    
    # Verify quota enforcement
    allowed, error = resource_manager.check_quota("user1", ResourceType.API_CALLS, 1)
    assert not allowed
    assert "Quota exceeded" in error

async def test_session_isolation():
    manager = MultiUserSessionManager()
    
    # Create sessions for different users
    session1 = manager.create_session("user1")
    session2 = manager.create_session("user2")
    
    # Add context to sessions
    with manager.get_session(session1) as s1:
        s1.add_context({"content": "User 1 message"})
    
    with manager.get_session(session2) as s2:
        s2.add_context({"content": "User 2 message"})
    
    # Verify isolation
    with manager.get_session(session1) as s1:
        context1 = s1.get_context()
        assert len(context1) == 1
        assert context1[0]["content"] == "User 1 message"
    
    with manager.get_session(session2) as s2:
        context2 = s2.get_context()
        assert len(context2) == 1
        assert context2[0]["content"] == "User 2 message"

Conclusion

Effective multi-user session handling is essential for production multi-agent systems. By implementing proper session isolation, resource management, and security measures, you can build scalable systems that serve multiple users efficiently and securely.