Source code for orka.memory.kafka_logger

# OrKa: Orchestrator Kit Agents
# Copyright © 2025 Marco Somma
#
# This file is part of OrKa – https://github.com/marcosomma/orka-resoning
#
# Licensed under the Apache License, Version 2.0 (Apache 2.0).
# You may not use this file for commercial purposes without explicit permission.
#
# Full license: https://www.apache.org/licenses/LICENSE-2.0
# For commercial use, contact: marcosomma.work@gmail.com
#
# Required attribution: OrKa by Marco Somma – https://github.com/marcosomma/orka-resoning

"""
Kafka Memory Logger Implementation
=================================

This file contains the hybrid KafkaMemoryLogger implementation that uses
Kafka topics for event streaming and Redis for memory operations.
This provides the best of both worlds: Kafka's event streaming capabilities
with Redis's fast memory operations.
"""

import json
import logging
import os
from datetime import UTC, datetime, timedelta
from typing import Any, Dict, List, Optional, Union

import redis

from .base_logger import BaseMemoryLogger

logger = logging.getLogger(__name__)


class KafkaMemoryLogger(BaseMemoryLogger):
    """
    A hybrid memory logger that uses Kafka for event streaming and Redis for memory operations.

    This implementation combines:
    - Kafka topics for persistent event streaming and audit trails
    - Redis for fast memory operations (hset, hget, sadd, etc.) and fork/join coordination

    This approach provides both the scalability of Kafka and the performance of Redis.
    """
    def __init__(
        self,
        bootstrap_servers: str = "localhost:9092",
        redis_url: Optional[str] = None,
        stream_key: str = "orka:memory",
        debug_keep_previous_outputs: bool = False,
        decay_config: Optional[Dict[str, Any]] = None,
        enable_hnsw: bool = True,
        vector_params: Optional[Dict[str, Any]] = None,
    ) -> None:
        """
        Initialize the hybrid Kafka + RedisStack memory logger.

        Args:
            bootstrap_servers: Kafka bootstrap servers. Defaults to "localhost:9092".
            redis_url: RedisStack connection URL. Defaults to environment variable REDIS_URL.
            stream_key: Key for the memory stream. Defaults to "orka:memory".
            debug_keep_previous_outputs: If True, keeps previous_outputs in log files for debugging.
            decay_config: Configuration for memory decay functionality.
            enable_hnsw: Enable HNSW vector indexing in RedisStack backend.
            vector_params: HNSW configuration parameters.
        """
        super().__init__(stream_key, debug_keep_previous_outputs, decay_config)

        # Kafka setup
        self.bootstrap_servers = bootstrap_servers
        self.main_topic = "orka-memory-events"
        self.producer = None
        self.consumer = None

        # ✅ CRITICAL: Use RedisStack for memory operations instead of basic Redis
        self.redis_url = redis_url or os.getenv("REDIS_URL", "redis://localhost:6380/0")

        # Initialize Redis client variables
        self.redis_client = None
        self._redis_memory_logger = None

        # Create RedisStack logger for enhanced memory operations
        try:
            from .redisstack_logger import RedisStackMemoryLogger

            self._redis_memory_logger = RedisStackMemoryLogger(
                redis_url=self.redis_url,
                stream_key=stream_key,
                debug_keep_previous_outputs=debug_keep_previous_outputs,
                decay_config=decay_config,
                enable_hnsw=enable_hnsw,
                vector_params=vector_params,
            )
            # Ensure enhanced index is ready
            self._redis_memory_logger.ensure_index()
            logger.info("✅ Kafka backend using RedisStack for memory operations")
        except ImportError:
            # Fallback to basic Redis
            self.redis_client = redis.from_url(self.redis_url)
            self._redis_memory_logger = None
            logger.warning("⚠️ RedisStack not available, using basic Redis for memory operations")
        except Exception as e:
            # If RedisStack creation fails for any other reason, fall back to basic Redis
            logger.warning(
                f"⚠️ RedisStack initialization failed ({e}), using basic Redis for memory operations",
            )
            self._redis_memory_logger = None
            # Initialize basic Redis client as fallback
            self.redis_client = redis.from_url(self.redis_url)
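
    # Editorial construction sketch (not part of the library source): the values
    # below are illustrative defaults only — point bootstrap_servers and redis_url
    # at your own Kafka broker and RedisStack instance.
    #
    #   memory_logger = KafkaMemoryLogger(
    #       bootstrap_servers="localhost:9092",
    #       redis_url="redis://localhost:6380/0",
    #       enable_hnsw=True,
    #   )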
    @property
    def redis(self) -> redis.Redis:
        """Return Redis client - prefer RedisStack client if available."""
        if self._redis_memory_logger:
            return self._redis_memory_logger.redis
        return self.redis_client

    def _store_in_redis(self, event: dict, **kwargs):
        """Store event using RedisStack logger if available."""
        if self._redis_memory_logger:
            # ✅ Use RedisStack logger for enhanced storage
            self._redis_memory_logger.log(
                agent_id=event["agent_id"],
                event_type=event["event_type"],
                payload=event["payload"],
                step=kwargs.get("step"),
                run_id=kwargs.get("run_id"),
                fork_group=kwargs.get("fork_group"),
                parent=kwargs.get("parent"),
                previous_outputs=kwargs.get("previous_outputs"),
                agent_decay_config=kwargs.get("agent_decay_config"),
            )
        else:
            # Fallback to basic Redis streams
            try:
                # Prepare the Redis entry
                redis_entry = {
                    "agent_id": event["agent_id"],
                    "event_type": event["event_type"],
                    "timestamp": event.get("timestamp"),
                    "run_id": kwargs.get("run_id", "default"),
                    "step": str(kwargs.get("step", -1)),
                    "payload": json.dumps(event["payload"]),
                }

                # Add decay metadata if available
                if hasattr(self, "decay_config") and self.decay_config:
                    decay_metadata = self._generate_decay_metadata(event)
                    redis_entry.update(decay_metadata)

                # Write to Redis stream
                self.redis_client.xadd(self.stream_key, redis_entry)
                logger.debug(f"Stored event in basic Redis stream: {self.stream_key}")

            except Exception as e:
                logger.error(f"Failed to store event in basic Redis: {e}")
    def log(
        self,
        agent_id: str,
        event_type: str,
        payload: Dict[str, Any],
        step: Optional[int] = None,
        run_id: Optional[str] = None,
        fork_group: Optional[str] = None,
        parent: Optional[str] = None,
        previous_outputs: Optional[Dict[str, Any]] = None,
        agent_decay_config: Optional[Dict[str, Any]] = None,
    ) -> None:
        """
        Log an event to both Kafka (for streaming) and Redis (for memory operations).

        This hybrid approach ensures events are durably stored in Kafka while also
        being available in Redis for fast memory operations and coordination.
        """
        # Sanitize payload
        safe_payload = self._sanitize_for_json(payload)

        # Handle decay configuration
        decay_metadata = {}
        if self.decay_config.get("enabled", False):
            # Temporarily merge agent-specific decay config
            old_config = self.decay_config
            try:
                if agent_decay_config:
                    # Create temporary merged config
                    merged_config = {**self.decay_config}
                    merged_config.update(agent_decay_config)
                    self.decay_config = merged_config

                # Calculate importance score and memory type
                importance_score = self._calculate_importance_score(
                    agent_id,
                    event_type,
                    safe_payload,
                )
                memory_type = self._classify_memory_type(
                    event_type,
                    importance_score,
                    self._classify_memory_category(event_type, agent_id, safe_payload),
                )
                memory_category = self._classify_memory_category(event_type, agent_id, safe_payload)

                # Calculate expiration time
                current_time = datetime.now(UTC)
                if memory_type == "short_term":
                    # Check agent-level config first, then fall back to global config
                    expire_hours = self.decay_config.get(
                        "short_term_hours",
                    ) or self.decay_config.get("default_short_term_hours", 1.0)
                    expire_time = current_time + timedelta(hours=expire_hours)
                else:  # long_term
                    # Check agent-level config first, then fall back to global config
                    expire_hours = self.decay_config.get(
                        "long_term_hours",
                    ) or self.decay_config.get("default_long_term_hours", 24.0)
                    expire_time = current_time + timedelta(hours=expire_hours)

                decay_metadata = {
                    "orka_importance_score": str(importance_score),
                    "orka_memory_type": memory_type,
                    "orka_memory_category": memory_category,
                    "orka_expire_time": expire_time.isoformat(),
                    "orka_created_time": current_time.isoformat(),
                }
            finally:
                # Restore original config
                self.decay_config = old_config

        # Create event record with decay metadata
        event = {
            "agent_id": agent_id,
            "event_type": event_type,
            "payload": safe_payload,
            "step": step,
            "run_id": run_id,
            "fork_group": fork_group,
            "parent": parent,
            "timestamp": datetime.now(UTC).isoformat(),
        }

        # Add decay metadata to the event
        event.update(decay_metadata)

        # CRITICAL: Add event to local memory buffer for file operations
        # This ensures events are included in the JSON trace files
        self.memory.append(event)

        # Store in Redis for memory operations (similar to RedisMemoryLogger)
        self._store_in_redis(
            event,
            step=step,
            run_id=run_id,
            fork_group=fork_group,
            parent=parent,
            previous_outputs=previous_outputs,
            agent_decay_config=agent_decay_config,
        )
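
    # Editorial logging sketch (illustrative names): a single log() call appends
    # the event to the local buffer used for JSON trace files and stores it in
    # Redis via _store_in_redis; decay metadata is attached only when decay is
    # enabled in decay_config.
    #
    #   memory_logger.log(
    #       agent_id="classifier_1",
    #       event_type="agent.completed",
    #       payload={"answer": "yes"},
    #       run_id="run-42",
    #       step=3,
    #   )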
    def _send_to_kafka(self, event: dict, run_id: Optional[str], agent_id: str):
        """Send event to Kafka for streaming."""
        try:
            message_key = f"{run_id}:{agent_id}" if run_id else agent_id

            # Use schema serialization if available
            if self.use_schema_registry and self.serializer:
                try:
                    # Use confluent-kafka with schema serialization
                    from confluent_kafka.serialization import MessageField, SerializationContext

                    serialized_value = self.serializer(
                        event,
                        SerializationContext(self.main_topic, MessageField.VALUE),
                    )

                    self.producer.produce(
                        topic=self.main_topic,
                        key=message_key,
                        value=serialized_value,
                    )

                    if self.synchronous_send:
                        self.producer.flush()

                    logger.debug(
                        f"Sent event to Kafka with schema: {agent_id}:{event['event_type']}",
                    )
                except Exception as schema_error:
                    logger.warning(
                        f"Schema serialization failed: {schema_error}, using JSON fallback",
                    )
                    # Fall back to JSON serialization
                    self._send_json_message(message_key, event)
            else:
                # Use JSON serialization
                self._send_json_message(message_key, event)

        except Exception as e:
            logger.error(f"Failed to send event to Kafka: {e}")
            # Event is still stored in Redis, so we can continue

    def _send_json_message(self, message_key: str, event: dict):
        """Send message using JSON serialization (fallback)."""
        # Handle different producer types
        if hasattr(self.producer, "produce"):
            # confluent-kafka
            self.producer.produce(
                topic=self.main_topic,
                key=message_key,
                value=json.dumps(event).encode("utf-8"),
            )
            if self.synchronous_send:
                self.producer.flush()
        else:
            # kafka-python
            future = self.producer.send(
                topic=self.main_topic,
                key=message_key,
                value=event,
            )
            if self.synchronous_send:
                future.get(timeout=10)

        logger.debug(f"Sent event to Kafka with JSON: {event['agent_id']}:{event['event_type']}")

    def _sanitize_payload(self, payload: Dict[str, Any]) -> Dict[str, Any]:
        """Sanitize payload to ensure JSON serialization."""
        if not isinstance(payload, dict):
            return {"value": str(payload)}

        sanitized = {}
        for key, value in payload.items():
            try:
                json.dumps(value)  # Test if serializable
                sanitized[key] = value
            except (TypeError, ValueError):
                sanitized[key] = str(value)
        return sanitized
    def tail(self, count: int = 10) -> List[Dict[str, Any]]:
        """Retrieve recent events from memory buffer."""
        return self.memory[-count:] if self.memory else []
    # Redis operations - delegate to actual Redis client

    def hset(self, name: str, key: str, value: Union[str, bytes, int, float]) -> int:
        """Set a hash field using Redis."""
        return self.redis_client.hset(name, key, value)

    def hget(self, name: str, key: str) -> Optional[str]:
        """Get a hash field using Redis."""
        result = self.redis_client.hget(name, key)
        return result.decode() if result else None

    def hkeys(self, name: str) -> List[str]:
        """Get hash keys using Redis."""
        return [key.decode() for key in self.redis_client.hkeys(name)]

    def hdel(self, name: str, *keys: str) -> int:
        """Delete hash fields using Redis."""
        return self.redis_client.hdel(name, *keys)

    def smembers(self, name: str) -> List[str]:
        """Get set members using Redis."""
        return [member.decode() for member in self.redis_client.smembers(name)]

    def sadd(self, name: str, *values: str) -> int:
        """Add to set using Redis."""
        return self.redis_client.sadd(name, *values)

    def srem(self, name: str, *values: str) -> int:
        """Remove from set using Redis."""
        return self.redis_client.srem(name, *values)

    def get(self, key: str) -> Optional[str]:
        """Get a value using Redis."""
        result = self.redis_client.get(key)
        return result.decode() if result else None

    def set(self, key: str, value: Union[str, bytes, int, float]) -> bool:
        """Set a value using Redis."""
        return self.redis_client.set(key, value)

    def delete(self, *keys: str) -> int:
        """Delete keys using Redis."""
        return self.redis_client.delete(*keys)
    # 🎯 NEW: Enhanced memory operations - delegate to RedisStack logger
    def search_memories(
        self,
        query: str,
        num_results: int = 10,
        trace_id: Optional[str] = None,
        node_id: Optional[str] = None,
        memory_type: Optional[str] = None,
        min_importance: Optional[float] = None,
        log_type: str = "memory",
        namespace: Optional[str] = None,
    ) -> List[Dict[str, Any]]:
        """Search memories using RedisStack logger if available, otherwise return empty list."""
        logger.debug(
            f"🔍 KafkaMemoryLogger.search_memories: _redis_memory_logger={self._redis_memory_logger is not None}, namespace='{namespace}'",
        )
        if self._redis_memory_logger and hasattr(self._redis_memory_logger, "search_memories"):
            logger.debug(f"🔍 Delegating to RedisStackMemoryLogger with namespace='{namespace}'")
            results = self._redis_memory_logger.search_memories(
                query=query,
                num_results=num_results,
                trace_id=trace_id,
                node_id=node_id,
                memory_type=memory_type,
                min_importance=min_importance,
                log_type=log_type,
                namespace=namespace,
            )
            logger.debug(f"🔍 RedisStack search returned {len(results)} results")
            return results
        else:
            logger.warning("RedisStack not available for memory search, returning empty results")
            return []
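
    # Editorial search sketch (illustrative query and namespace): semantic search
    # is only available when the RedisStack logger was created in __init__;
    # otherwise this method returns an empty list.
    #
    #   hits = memory_logger.search_memories(
    #       query="previous classification results",
    #       num_results=5,
    #       namespace="conversations",
    #   )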
    def log_memory(
        self,
        content: str,
        node_id: str,
        trace_id: str,
        metadata: Optional[Dict[str, Any]] = None,
        importance_score: float = 1.0,
        memory_type: str = "short_term",
        expiry_hours: Optional[float] = None,
    ) -> str:
        """Log memory using RedisStack logger if available."""
        if self._redis_memory_logger and hasattr(self._redis_memory_logger, "log_memory"):
            return self._redis_memory_logger.log_memory(
                content=content,
                node_id=node_id,
                trace_id=trace_id,
                metadata=metadata,
                importance_score=importance_score,
                memory_type=memory_type,
                expiry_hours=expiry_hours,
            )
        else:
            logger.warning("RedisStack not available for memory logging")
            return f"fallback_memory_{datetime.now(UTC).timestamp()}"
    def ensure_index(self) -> bool:
        """Ensure memory index exists using RedisStack logger if available."""
        if self._redis_memory_logger and hasattr(self._redis_memory_logger, "ensure_index"):
            return self._redis_memory_logger.ensure_index()
        return False
    def close(self) -> None:
        """Close both Kafka producer and Redis connection."""
        # Close Kafka producer
        if self.producer:
            try:
                if hasattr(self.producer, "close"):
                    # kafka-python
                    self.producer.close()
                elif hasattr(self.producer, "flush"):
                    # confluent-kafka
                    self.producer.flush()
                logger.info("Kafka producer closed")
            except Exception as e:
                logger.error(f"Error closing Kafka producer: {e}")

        # Close Redis connection
        try:
            if self._redis_memory_logger:
                # Close RedisStack logger if available
                if hasattr(self._redis_memory_logger, "close"):
                    self._redis_memory_logger.close()
                elif hasattr(self._redis_memory_logger, "client") and hasattr(
                    self._redis_memory_logger.client,
                    "close",
                ):
                    self._redis_memory_logger.client.close()
                logger.info("RedisStack memory logger closed")
            elif self.redis_client:
                # Close basic Redis client
                self.redis_client.close()
                logger.info("Redis connection closed")
        except Exception as e:
            logger.error(f"Error closing Redis connection: {e}")
    def __del__(self):
        """Cleanup on object deletion."""
        self.close()
    def cleanup_expired_memories(self, dry_run: bool = False) -> Dict[str, Any]:
        """
        Clean up expired memory entries using a Redis-based approach.

        This delegates to Redis for cleanup while also cleaning the in-memory buffer.
        """
        try:
            # Import Redis memory logger for cleanup logic
            from .redis_logger import RedisMemoryLogger

            # Create a temporary Redis logger to reuse cleanup logic
            temp_redis_logger = RedisMemoryLogger(
                redis_url=self.redis_url,
                stream_key=self.stream_key,
                decay_config=self.decay_config,
            )

            # Use Redis cleanup logic
            stats = temp_redis_logger.cleanup_expired_memories(dry_run=dry_run)
            stats["backend"] = "kafka+redis"

            # Also clean up in-memory buffer if decay is enabled and not dry run
            if not dry_run and self.decay_config.get("enabled", False):
                current_time = datetime.now(UTC)
                expired_indices = []

                for i, entry in enumerate(self.memory):
                    expire_time_str = entry.get("orka_expire_time")
                    if expire_time_str:
                        try:
                            expire_time = datetime.fromisoformat(expire_time_str)
                            if current_time > expire_time:
                                expired_indices.append(i)
                        except (ValueError, TypeError):
                            continue

                # Remove expired entries from memory buffer
                for i in reversed(expired_indices):
                    del self.memory[i]

                logger.info(f"Cleaned up {len(expired_indices)} expired entries from memory buffer")

            return stats

        except Exception as e:
            logger.error(f"Error during hybrid memory cleanup: {e}")
            return {
                "error": str(e),
                "backend": "kafka+redis",
                "timestamp": datetime.now(UTC).isoformat(),
                "deleted_count": 0,
            }
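
    # Editorial cleanup sketch: dry_run=True reports what would be removed
    # without deleting anything; the keys read below are illustrative and may
    # vary with the Redis backend's report format.
    #
    #   report = memory_logger.cleanup_expired_memories(dry_run=True)
    #   print(report.get("backend"), report.get("deleted_count", 0))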
    def get_memory_stats(self) -> Dict[str, Any]:
        """
        Get memory usage statistics from both Redis backend and local memory buffer.
        """
        try:
            current_time = datetime.now(UTC)

            stats = {
                "timestamp": current_time.isoformat(),
                "backend": "kafka+redis",
                "kafka_topic": self.main_topic,
                "memory_buffer_size": len(self.memory),
                "decay_enabled": self.decay_config.get("enabled", False),
                "total_streams": 0,
                "total_entries": 0,
                "entries_by_type": {},
                "entries_by_memory_type": {"short_term": 0, "long_term": 0, "unknown": 0},
                "entries_by_category": {"stored": 0, "log": 0, "unknown": 0},
                "expired_entries": 0,
                "streams_detail": [],
            }

            # Use the actual Redis client from the Kafka backend
            redis_client = self.redis

            # Get all stream keys that match OrKa patterns
            stream_patterns = [
                "orka:memory:*",  # All OrKa memory streams (this is what Kafka backend creates)
                self.stream_key,  # Base stream key
                f"{self.stream_key}:*",  # Namespace-specific streams
            ]

            processed_streams = set()

            for pattern in stream_patterns:
                try:
                    stream_keys = redis_client.keys(pattern)
                    for stream_key in stream_keys:
                        stream_key_str = (
                            stream_key.decode() if isinstance(stream_key, bytes) else stream_key
                        )

                        if stream_key_str in processed_streams:
                            continue
                        processed_streams.add(stream_key_str)

                        try:
                            # Check if it's actually a stream
                            key_type = redis_client.type(stream_key)
                            if key_type != b"stream" and key_type != "stream":
                                continue

                            # Get stream info and entries
                            stream_info = redis_client.xinfo_stream(stream_key)
                            entries = redis_client.xrange(stream_key)

                            stream_stats = {
                                "stream": stream_key_str,
                                "length": stream_info.get("length", 0),
                                "entries_by_type": {},
                                "entries_by_memory_type": {
                                    "short_term": 0,
                                    "long_term": 0,
                                    "unknown": 0,
                                },
                                "entries_by_category": {
                                    "stored": 0,
                                    "log": 0,
                                    "unknown": 0,
                                },
                                "expired_entries": 0,
                                "active_entries": 0,
                            }

                            stats["total_streams"] += 1

                            for entry_id, entry_data in entries:
                                # Check if expired first
                                is_expired = False
                                expire_time_field = entry_data.get(
                                    b"orka_expire_time",
                                ) or entry_data.get("orka_expire_time")
                                if expire_time_field:
                                    try:
                                        expire_time_str = (
                                            expire_time_field.decode()
                                            if isinstance(expire_time_field, bytes)
                                            else expire_time_field
                                        )
                                        expire_time = datetime.fromisoformat(expire_time_str)
                                        if current_time > expire_time:
                                            is_expired = True
                                            stream_stats["expired_entries"] += 1
                                            stats["expired_entries"] += 1
                                    except (ValueError, TypeError):
                                        pass  # Skip invalid dates

                                # Only count non-expired entries in the main statistics
                                if not is_expired:
                                    stream_stats["active_entries"] += 1
                                    stats["total_entries"] += 1

                                    # Count by event type
                                    event_type_field = entry_data.get(
                                        b"event_type",
                                    ) or entry_data.get("event_type")
                                    event_type = "unknown"
                                    if event_type_field:
                                        event_type = (
                                            event_type_field.decode()
                                            if isinstance(event_type_field, bytes)
                                            else event_type_field
                                        )

                                    stream_stats["entries_by_type"][event_type] = (
                                        stream_stats["entries_by_type"].get(event_type, 0) + 1
                                    )
                                    stats["entries_by_type"][event_type] = (
                                        stats["entries_by_type"].get(event_type, 0) + 1
                                    )

                                    # Count by memory category
                                    memory_category_field = entry_data.get(
                                        b"orka_memory_category",
                                    ) or entry_data.get("orka_memory_category")
                                    memory_category = "unknown"
                                    if memory_category_field:
                                        memory_category = (
                                            memory_category_field.decode()
                                            if isinstance(memory_category_field, bytes)
                                            else memory_category_field
                                        )

                                    if memory_category in stream_stats["entries_by_category"]:
                                        stream_stats["entries_by_category"][memory_category] += 1
                                        stats["entries_by_category"][memory_category] += 1
                                    else:
                                        stream_stats["entries_by_category"]["unknown"] += 1
                                        stats["entries_by_category"]["unknown"] += 1

                                    # Count by memory type ONLY for non-log entries
                                    if memory_category != "log":
                                        memory_type_field = entry_data.get(
                                            b"orka_memory_type",
                                        ) or entry_data.get("orka_memory_type")
                                        memory_type = "unknown"
                                        if memory_type_field:
                                            memory_type = (
                                                memory_type_field.decode()
                                                if isinstance(memory_type_field, bytes)
                                                else memory_type_field
                                            )

                                        if memory_type in stream_stats["entries_by_memory_type"]:
                                            stream_stats["entries_by_memory_type"][memory_type] += 1
                                            stats["entries_by_memory_type"][memory_type] += 1
                                        else:
                                            stream_stats["entries_by_memory_type"]["unknown"] += 1
                                            stats["entries_by_memory_type"]["unknown"] += 1

                            stats["streams_detail"].append(stream_stats)

                        except Exception as e:
                            logger.error(f"Error getting stats for stream {stream_key_str}: {e}")

                except Exception as e:
                    logger.error(f"Error getting keys for pattern {pattern}: {e}")

            # Add decay configuration info
            if self.decay_config.get("enabled", False):
                stats["decay_config"] = {
                    "short_term_hours": self.decay_config["default_short_term_hours"],
                    "long_term_hours": self.decay_config["default_long_term_hours"],
                    "check_interval_minutes": self.decay_config["check_interval_minutes"],
                    "last_decay_check": (
                        self._last_decay_check.isoformat()
                        if hasattr(self, "_last_decay_check") and self._last_decay_check
                        else None
                    ),
                }

            # Enhance stats with local memory buffer analysis
            # This provides more accurate decay metadata since local buffer has proper field names
            local_stats = {
                "entries_by_memory_type": {"short_term": 0, "long_term": 0, "unknown": 0},
                "entries_by_category": {"stored": 0, "log": 0, "unknown": 0},
            }

            for entry in self.memory:
                # Memory type distribution
                memory_type = entry.get("orka_memory_type", "unknown")
                if memory_type in local_stats["entries_by_memory_type"]:
                    local_stats["entries_by_memory_type"][memory_type] += 1
                else:
                    local_stats["entries_by_memory_type"]["unknown"] += 1

                # Memory category distribution
                memory_category = entry.get("orka_memory_category", "unknown")
                if memory_category in local_stats["entries_by_category"]:
                    local_stats["entries_by_category"][memory_category] += 1
                else:
                    local_stats["entries_by_category"]["unknown"] += 1

            # If local buffer has meaningful data, use it to enhance Redis stats
            if len(self.memory) > 0:
                # Combine Redis stats with local buffer insights
                if local_stats["entries_by_memory_type"]["unknown"] < len(self.memory):
                    # Local buffer has better memory type data
                    stats["entries_by_memory_type"] = local_stats["entries_by_memory_type"]
                if local_stats["entries_by_category"]["unknown"] < len(self.memory):
                    # Local buffer has better category data
                    stats["entries_by_category"] = local_stats["entries_by_category"]

                # Add local buffer specific metrics
                stats["local_buffer_insights"] = {
                    "total_entries": len(self.memory),
                    "entries_with_decay_metadata": sum(
                        1
                        for entry in self.memory
                        if entry.get("orka_memory_type") and entry.get("orka_memory_category")
                    ),
                    "memory_types": local_stats["entries_by_memory_type"],
                    "categories": local_stats["entries_by_category"],
                }

            return stats

        except Exception as e:
            logger.error(f"Error getting hybrid memory statistics: {e}")
            return {
                "error": str(e),
                "backend": "kafka+redis",
                "timestamp": datetime.now(UTC).isoformat(),
            }
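

# ---------------------------------------------------------------------------
# Editorial usage sketch (not part of the library source). It assumes a Redis
# or RedisStack instance is reachable at redis_url and a Kafka broker at
# bootstrap_servers; adjust both for your environment. Only methods defined in
# this module are used.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    memory_logger = KafkaMemoryLogger(
        bootstrap_servers="localhost:9092",
        redis_url="redis://localhost:6380/0",
    )
    try:
        # Log one event: it is buffered locally and written to Redis.
        memory_logger.log(
            agent_id="router_1",
            event_type="agent.completed",
            payload={"result": "ok"},
            run_id="demo-run",
            step=1,
        )
        # Inspect the most recent buffered events and the reported backend.
        print(memory_logger.tail(5))
        print(memory_logger.get_memory_stats().get("backend"))
    finally:
        memory_logger.close()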