Skip to main content

State Conflict Resolution

In multi-agent systems, state conflicts can arise when multiple agents attempt to modify shared state concurrently. This guide covers strategies for preventing and resolving these conflicts.

Understanding State Conflicts

Types of Conflicts

  1. Write-Write Conflicts: Multiple agents writing to the same state
  2. Read-Write Conflicts: Reading stale data while another agent is writing
  3. Lost Updates: Updates overwritten by concurrent operations
  4. Phantom Reads: State changes between reads
  5. Cascading Conflicts: Conflicts propagating through dependent states

Conflict Prevention Strategies

1. Pessimistic Locking

Prevent conflicts by acquiring locks before state modifications:
import threading
from contextlib import contextmanager
from typing import Dict, Any, Optional
import time

class PessimisticStateLock:
    def __init__(self, timeout: float = 30.0):
        self.locks: Dict[str, threading.RLock] = {}
        self.lock_holders: Dict[str, str] = {}
        self.timeout = timeout
        self._lock = threading.Lock()
    
    @contextmanager
    def acquire_lock(self, resource_id: str, agent_id: str):
        """Acquire a lock for a specific resource"""
        lock = self._get_or_create_lock(resource_id)
        
        acquired = lock.acquire(timeout=self.timeout)
        if not acquired:
            raise TimeoutError(f"Could not acquire lock for {resource_id}")
        
        try:
            with self._lock:
                self.lock_holders[resource_id] = agent_id
            yield
        finally:
            with self._lock:
                if resource_id in self.lock_holders:
                    del self.lock_holders[resource_id]
            lock.release()
    
    def _get_or_create_lock(self, resource_id: str) -> threading.RLock:
        """Get or create a lock for a resource"""
        with self._lock:
            if resource_id not in self.locks:
                self.locks[resource_id] = threading.RLock()
            return self.locks[resource_id]
    
    def is_locked(self, resource_id: str) -> bool:
        """Check if a resource is locked"""
        with self._lock:
            return resource_id in self.lock_holders
    
    def get_lock_holder(self, resource_id: str) -> Optional[str]:
        """Get the agent holding a lock"""
        with self._lock:
            return self.lock_holders.get(resource_id)

2. Optimistic Concurrency Control

Use version numbers to detect conflicts:
from dataclasses import dataclass
from typing import Generic, TypeVar, Optional
import uuid

T = TypeVar('T')

@dataclass
class VersionedState(Generic[T]):
    data: T
    version: int
    last_modified_by: str
    timestamp: float

class OptimisticStateManager:
    def __init__(self):
        self.states: Dict[str, VersionedState] = {}
        self._lock = threading.Lock()
    
    def read(self, key: str) -> Optional[VersionedState]:
        """Read state with version information"""
        with self._lock:
            return self.states.get(key)
    
    def write(self, key: str, data: Any, expected_version: int, 
              agent_id: str) -> bool:
        """Write state if version matches expected"""
        with self._lock:
            current_state = self.states.get(key)
            
            # First write
            if current_state is None and expected_version == -1:
                self.states[key] = VersionedState(
                    data=data,
                    version=0,
                    last_modified_by=agent_id,
                    timestamp=time.time()
                )
                return True
            
            # Version mismatch - conflict detected
            if current_state is None or current_state.version != expected_version:
                return False
            
            # Update state
            self.states[key] = VersionedState(
                data=data,
                version=current_state.version + 1,
                last_modified_by=agent_id,
                timestamp=time.time()
            )
            return True
    
    def compare_and_swap(self, key: str, old_data: Any, new_data: Any,
                        agent_id: str) -> bool:
        """Atomic compare-and-swap operation"""
        with self._lock:
            current_state = self.states.get(key)
            
            if current_state and current_state.data == old_data:
                self.states[key] = VersionedState(
                    data=new_data,
                    version=current_state.version + 1,
                    last_modified_by=agent_id,
                    timestamp=time.time()
                )
                return True
            
            return False

3. Event Sourcing

Track all state changes as events:
from enum import Enum
from dataclasses import dataclass, field
from typing import List, Callable

class EventType(Enum):
    CREATED = "created"
    UPDATED = "updated"
    DELETED = "deleted"

@dataclass
class StateEvent:
    event_id: str
    event_type: EventType
    entity_id: str
    agent_id: str
    timestamp: float
    data: Dict[str, Any]
    metadata: Dict[str, Any] = field(default_factory=dict)

class EventSourcedState:
    def __init__(self):
        self.events: List[StateEvent] = []
        self.projections: Dict[str, Any] = {}
        self.event_handlers: Dict[EventType, List[Callable]] = {
            EventType.CREATED: [],
            EventType.UPDATED: [],
            EventType.DELETED: []
        }
        self._lock = threading.Lock()
    
    def append_event(self, event: StateEvent) -> None:
        """Append an event to the event log"""
        with self._lock:
            self.events.append(event)
            self._apply_event(event)
    
    def _apply_event(self, event: StateEvent) -> None:
        """Apply event to update projections"""
        for handler in self.event_handlers[event.event_type]:
            handler(event, self.projections)
    
    def register_handler(self, event_type: EventType, 
                        handler: Callable[[StateEvent, Dict], None]) -> None:
        """Register an event handler"""
        self.event_handlers[event_type].append(handler)
    
    def get_entity_history(self, entity_id: str) -> List[StateEvent]:
        """Get all events for an entity"""
        with self._lock:
            return [e for e in self.events if e.entity_id == entity_id]
    
    def resolve_conflicts(self, entity_id: str) -> Any:
        """Resolve conflicts by replaying events"""
        history = self.get_entity_history(entity_id)
        
        # Apply custom conflict resolution logic
        if len(history) > 1:
            # Example: Last-write-wins
            return self._last_write_wins(history)
        
        return None
    
    def _last_write_wins(self, events: List[StateEvent]) -> Any:
        """Simple last-write-wins conflict resolution"""
        if not events:
            return None
        
        latest_event = max(events, key=lambda e: e.timestamp)
        return latest_event.data

Conflict Resolution Strategies

1. Conflict-free Replicated Data Types (CRDTs)

Implement CRDTs for automatic conflict resolution:
from abc import ABC, abstractmethod
from typing import Set, Dict

class CRDT(ABC):
    @abstractmethod
    def merge(self, other: 'CRDT') -> 'CRDT':
        """Merge with another CRDT instance"""
        pass

class GCounter(CRDT):
    """Grow-only counter CRDT"""
    def __init__(self, node_id: str):
        self.node_id = node_id
        self.counts: Dict[str, int] = {node_id: 0}
    
    def increment(self, value: int = 1) -> None:
        """Increment counter for this node"""
        self.counts[self.node_id] = self.counts.get(self.node_id, 0) + value
    
    def value(self) -> int:
        """Get total value across all nodes"""
        return sum(self.counts.values())
    
    def merge(self, other: 'GCounter') -> 'GCounter':
        """Merge with another GCounter"""
        merged = GCounter(self.node_id)
        
        # Take maximum count for each node
        all_nodes = set(self.counts.keys()) | set(other.counts.keys())
        for node in all_nodes:
            merged.counts[node] = max(
                self.counts.get(node, 0),
                other.counts.get(node, 0)
            )
        
        return merged

class ORSet(CRDT):
    """Observed-Remove Set CRDT"""
    def __init__(self, node_id: str):
        self.node_id = node_id
        self.elements: Dict[Any, Set[str]] = {}  # element -> set of unique tags
        self.tombstones: Dict[Any, Set[str]] = {}  # removed elements
    
    def add(self, element: Any) -> None:
        """Add an element to the set"""
        tag = f"{self.node_id}:{uuid.uuid4()}"
        if element not in self.elements:
            self.elements[element] = set()
        self.elements[element].add(tag)
    
    def remove(self, element: Any) -> None:
        """Remove an element from the set"""
        if element in self.elements:
            if element not in self.tombstones:
                self.tombstones[element] = set()
            self.tombstones[element].update(self.elements[element])
    
    def contains(self, element: Any) -> bool:
        """Check if element is in the set"""
        if element not in self.elements:
            return False
        
        element_tags = self.elements[element]
        tombstone_tags = self.tombstones.get(element, set())
        
        # Element exists if it has tags not in tombstones
        return len(element_tags - tombstone_tags) > 0
    
    def merge(self, other: 'ORSet') -> 'ORSet':
        """Merge with another ORSet"""
        merged = ORSet(self.node_id)
        
        # Merge elements
        all_elements = set(self.elements.keys()) | set(other.elements.keys())
        for element in all_elements:
            merged.elements[element] = (
                self.elements.get(element, set()) |
                other.elements.get(element, set())
            )
        
        # Merge tombstones
        all_tombstones = set(self.tombstones.keys()) | set(other.tombstones.keys())
        for element in all_tombstones:
            merged.tombstones[element] = (
                self.tombstones.get(element, set()) |
                other.tombstones.get(element, set())
            )
        
        return merged

2. Three-Way Merge

Implement three-way merge for complex state resolution:
from difflib import SequenceMatcher
from typing import Tuple, List, Optional

class ThreeWayMerger:
    def merge_states(self, base: Dict[str, Any], 
                    version_a: Dict[str, Any],
                    version_b: Dict[str, Any]) -> Tuple[Dict[str, Any], List[str]]:
        """Perform three-way merge on states"""
        merged = {}
        conflicts = []
        
        all_keys = set(base.keys()) | set(version_a.keys()) | set(version_b.keys())
        
        for key in all_keys:
            base_val = base.get(key)
            a_val = version_a.get(key)
            b_val = version_b.get(key)
            
            # No changes
            if a_val == b_val:
                if a_val is not None:
                    merged[key] = a_val
            # Only A changed
            elif base_val == b_val:
                if a_val is not None:
                    merged[key] = a_val
            # Only B changed
            elif base_val == a_val:
                if b_val is not None:
                    merged[key] = b_val
            # Both changed - conflict
            else:
                conflicts.append(key)
                # Apply conflict resolution strategy
                resolved = self._resolve_conflict(key, base_val, a_val, b_val)
                if resolved is not None:
                    merged[key] = resolved
        
        return merged, conflicts
    
    def _resolve_conflict(self, key: str, base_val: Any, 
                         a_val: Any, b_val: Any) -> Optional[Any]:
        """Resolve conflicts based on value types"""
        # Numeric values - sum changes
        if all(isinstance(v, (int, float)) for v in [base_val, a_val, b_val] if v is not None):
            if base_val is None:
                base_val = 0
            delta_a = (a_val or 0) - base_val
            delta_b = (b_val or 0) - base_val
            return base_val + delta_a + delta_b
        
        # Lists - merge unique elements
        if all(isinstance(v, list) for v in [base_val, a_val, b_val] if v is not None):
            merged_list = list(set(
                (a_val or []) + (b_val or [])
            ))
            return merged_list
        
        # Default - last write wins (could be customized)
        return b_val if b_val is not None else a_val

3. Operational Transform

For collaborative editing scenarios:
@dataclass
class Operation:
    op_type: str  # 'insert', 'delete', 'update'
    position: int
    data: Any
    agent_id: str
    timestamp: float

class OperationalTransform:
    def __init__(self):
        self.document = []
        self.operations: List[Operation] = []
        self._lock = threading.Lock()
    
    def apply_operation(self, op: Operation) -> bool:
        """Apply an operation to the document"""
        with self._lock:
            # Transform operation against concurrent operations
            transformed_op = self._transform_operation(op)
            
            if transformed_op:
                self._execute_operation(transformed_op)
                self.operations.append(transformed_op)
                return True
            
            return False
    
    def _transform_operation(self, op: Operation) -> Optional[Operation]:
        """Transform operation against concurrent operations"""
        # Find concurrent operations
        concurrent_ops = [
            o for o in self.operations
            if o.timestamp > op.timestamp - 0.1  # Within 100ms
            and o.agent_id != op.agent_id
        ]
        
        transformed = Operation(
            op_type=op.op_type,
            position=op.position,
            data=op.data,
            agent_id=op.agent_id,
            timestamp=op.timestamp
        )
        
        # Transform against each concurrent operation
        for concurrent in concurrent_ops:
            transformed = self._transform_pair(transformed, concurrent)
            if transformed is None:
                return None
        
        return transformed
    
    def _transform_pair(self, op1: Operation, op2: Operation) -> Optional[Operation]:
        """Transform op1 against op2"""
        if op1.op_type == 'insert' and op2.op_type == 'insert':
            if op1.position < op2.position:
                return op1
            elif op1.position > op2.position:
                return Operation(
                    op_type=op1.op_type,
                    position=op1.position + 1,
                    data=op1.data,
                    agent_id=op1.agent_id,
                    timestamp=op1.timestamp
                )
            else:
                # Same position - use agent_id for deterministic ordering
                if op1.agent_id < op2.agent_id:
                    return op1
                else:
                    return Operation(
                        op_type=op1.op_type,
                        position=op1.position + 1,
                        data=op1.data,
                        agent_id=op1.agent_id,
                        timestamp=op1.timestamp
                    )
        
        # Add more transformation rules as needed
        return op1
    
    def _execute_operation(self, op: Operation) -> None:
        """Execute a transformed operation"""
        if op.op_type == 'insert':
            self.document.insert(op.position, op.data)
        elif op.op_type == 'delete':
            if 0 <= op.position < len(self.document):
                del self.document[op.position]
        elif op.op_type == 'update':
            if 0 <= op.position < len(self.document):
                self.document[op.position] = op.data

Distributed State Management

1. Consensus-Based State

Use consensus algorithms for critical state:
from enum import Enum
from typing import Set

class ConsensusState(Enum):
    FOLLOWER = "follower"
    CANDIDATE = "candidate" 
    LEADER = "leader"

class RaftNode:
    def __init__(self, node_id: str, peers: Set[str]):
        self.node_id = node_id
        self.peers = peers
        self.state = ConsensusState.FOLLOWER
        self.current_term = 0
        self.voted_for = None
        self.log = []
        self.commit_index = 0
        self.leader_id = None
        
    def propose_value(self, value: Any) -> bool:
        """Propose a value to be added to the replicated log"""
        if self.state != ConsensusState.LEADER:
            return False
        
        # Simplified - in reality would involve AppendEntries RPC
        entry = {
            "term": self.current_term,
            "value": value,
            "index": len(self.log)
        }
        
        # Add to own log
        self.log.append(entry)
        
        # Replicate to followers (simplified)
        confirmations = self._replicate_to_followers(entry)
        
        # Commit if majority confirms
        if confirmations >= len(self.peers) // 2:
            self.commit_index = entry["index"]
            return True
        
        return False
    
    def _replicate_to_followers(self, entry: Dict) -> int:
        """Replicate entry to followers (simplified)"""
        # In real implementation, would send AppendEntries RPC
        # and wait for responses
        return len(self.peers) // 2 + 1  # Simplified

2. Vector Clocks

Track causality in distributed systems:
class VectorClock:
    def __init__(self, node_id: str, nodes: Set[str]):
        self.node_id = node_id
        self.clock = {node: 0 for node in nodes}
    
    def increment(self) -> Dict[str, int]:
        """Increment this node's clock"""
        self.clock[self.node_id] += 1
        return self.clock.copy()
    
    def update(self, other_clock: Dict[str, int]) -> None:
        """Update clock based on received clock"""
        for node, timestamp in other_clock.items():
            if node in self.clock:
                self.clock[node] = max(self.clock[node], timestamp)
        
        # Increment own clock
        self.increment()
    
    def happens_before(self, other: Dict[str, int]) -> bool:
        """Check if this clock happens before other"""
        return all(
            self.clock.get(node, 0) <= other.get(node, 0)
            for node in set(self.clock.keys()) | set(other.keys())
        )
    
    def concurrent_with(self, other: Dict[str, int]) -> bool:
        """Check if clocks are concurrent"""
        return (not self.happens_before(other) and
                not self._other_happens_before(other))
    
    def _other_happens_before(self, other: Dict[str, int]) -> bool:
        """Check if other happens before this"""
        return all(
            other.get(node, 0) <= self.clock.get(node, 0)
            for node in set(self.clock.keys()) | set(other.keys())
        )

Best Practices

  1. Choose the Right Strategy: Different scenarios require different approaches
    • High contention: Use pessimistic locking
    • Low contention: Use optimistic concurrency
    • Collaborative editing: Use operational transform
    • Eventually consistent: Use CRDTs
  2. Design for Failure: Always handle conflict resolution failures
    def safe_state_update(state_manager, key, update_func, max_retries=3):
        for attempt in range(max_retries):
            state = state_manager.read(key)
            if state is None:
                state = VersionedState(data={}, version=-1, 
                                     last_modified_by="", timestamp=0)
            
            new_data = update_func(state.data)
            
            if state_manager.write(key, new_data, state.version, "agent"):
                return True
            
            # Exponential backoff
            time.sleep(0.1 * (2 ** attempt))
        
        return False
    
  3. Monitor Conflicts: Track conflict rates and patterns
    class ConflictMonitor:
        def __init__(self):
            self.conflict_count = 0
            self.conflict_types = {}
            
        def record_conflict(self, conflict_type: str, details: Dict):
            self.conflict_count += 1
            if conflict_type not in self.conflict_types:
                self.conflict_types[conflict_type] = 0
            self.conflict_types[conflict_type] += 1
            
            # Log for analysis
            logger.warning(f"Conflict detected: {conflict_type}", extra=details)
    
  4. Test Concurrent Scenarios: Always test with concurrent operations
    def test_concurrent_updates():
        state_manager = OptimisticStateManager()
        
        def update_worker(worker_id, iterations):
            for i in range(iterations):
                safe_state_update(
                    state_manager,
                    "shared_counter",
                    lambda data: {**data, worker_id: i}
                )
        
        threads = []
        for i in range(5):
            t = threading.Thread(target=update_worker, args=(f"worker_{i}", 100))
            threads.append(t)
            t.start()
        
        for t in threads:
            t.join()
        
        # Verify final state
        final_state = state_manager.read("shared_counter")
        assert final_state is not None
        assert len(final_state.data) == 5
    

Conclusion

State conflict resolution is a critical aspect of building reliable multi-agent systems. By choosing appropriate strategies and implementing them correctly, you can build systems that handle concurrent operations gracefully while maintaining data consistency.