356 lines
13 KiB
Python
356 lines
13 KiB
Python
import json
import math
import re
import uuid
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from typing import Dict, List, Optional, Any
|
|
|
|
class SlotType(Enum):
    """Kind of value a slot collects.

    Each value string is what the ``slot_type`` database column holds; the
    enum is built from it with ``SlotType(row['slot_type'])``.
    """

    # Kept as the stripped answer string.
    TEXT = "text"
    # Converted to float before validations run.
    NUMBER = "number"
    # No built-in parsing; rely on configured validations (e.g. regex).
    EMAIL = "email"
    PHONE = "phone"
    DATE = "date"
    # Converted to True/False from yes/no style answers.
    BOOLEAN = "boolean"
    CHOICE = "choice"
|
|
|
|
class ValidationType(Enum):
    """Kind of rule a ``SlotValidation`` applies to an extracted slot value.

    Values mirror the ``validation_type`` database column.
    """

    # Rule keys: 'pattern'.
    REGEX = "regex"
    # Rule keys: 'min', 'max' (numeric bounds, both optional).
    RANGE = "range"
    # Rule keys: 'min', 'max' (bounds on len(str(value)), both optional).
    LENGTH = "length"
    # Rule keys: 'choices' (list of accepted values).
    CHOICES = "choices"
    # NOTE(review): SlotManager registers no checker for CUSTOM, so such
    # validations currently always pass.
    CUSTOM = "custom"
|
|
|
|
class QuestionType(Enum):
    """Situation in which a ``SlotQuestion`` is asked.

    Values mirror the ``question_type`` database column.
    """

    # Default question used when first asking for a slot.
    INITIAL = "initial"
    CLARIFICATION = "clarification"
    # Asked again after the answer failed validation.
    VALIDATION_ERROR = "validation_error"
    CONFIRMATION = "confirmation"
|
|
|
|
@dataclass
class SlotValidation:
    """One validation rule attached to a slot."""

    # Which checker to run.
    validation_type: ValidationType
    # Parameters for the checker, e.g. {'pattern': ...} or {'min': .., 'max': ..}.
    validation_rule: Dict[str, Any]
    # Message reported to the user when the check fails.
    error_message: str
|
|
|
|
@dataclass
class SlotQuestion:
    """Question text used to prompt for a slot in a given situation."""

    # When this question applies (initial ask, after a validation error, ...).
    question_type: QuestionType
    # Main prompt shown to the user.
    question_text: str
    # Optional extra sentence appended after question_text (space-separated).
    follow_up_text: Optional[str] = field(default=None)
    # Optional key -> expected value map that the caller's context must match.
    context_conditions: Optional[Dict[str, Any]] = field(default=None)
|
|
|
|
@dataclass
class SlotDefinition:
    """Configuration of a single slot within a conversation flow."""

    # Key under which the answer is stored in ConversationSession.collected_slots.
    slot_name: str
    # Drives type-specific parsing of the raw answer.
    slot_type: SlotType
    # Only required slots are asked for / counted toward completion.
    is_required: bool = field(default=True)
    # Lower numbers are asked first (definitions are sorted ascending).
    priority: int = field(default=1)
    # Name of a slot that must be collected before this one applies.
    depends_on_slot: Optional[str] = field(default=None)
    # If set, the parent slot's collected value must be one of these.
    dependency_value: Optional[List[str]] = field(default=None)
    validations: List[SlotValidation] = field(default_factory=list)
    questions: List[SlotQuestion] = field(default_factory=list)
|
|
|
|
@dataclass
class ConversationSession:
    """Mutable state of one user's progress through a flow."""

    session_id: str
    user_id: str
    flow_id: str
    # slot_name -> extracted value (str, float or bool).
    collected_slots: Dict[str, Any] = field(default_factory=dict)
    # Slot currently being asked for, if any.
    current_slot: Optional[str] = field(default=None)
    # Free-form data available for question context conditions.
    context_data: Dict[str, Any] = field(default_factory=dict)
    # Lifecycle marker; the example flow sets it to "completed" at the end.
    session_status: str = field(default="active")
|
|
|
|
class SlotManager:
    """Drive slot filling for database-configured conversation flows.

    Loads a flow's slot definitions (with validations and questions) from the
    database, picks the next slot to ask for, converts and validates answers,
    chooses question text, and persists session state.
    """

    # Answers accepted for BOOLEAN slots. Also used to match boolean answers
    # against string-configured dependency values.
    _TRUE_WORDS = ('yes', 'y', 'true', '1')
    _FALSE_WORDS = ('no', 'n', 'false', '0')

    def __init__(self, db_connection):
        """Bind the manager to ``db_connection``.

        Args:
            db_connection: Object exposing ``execute(query, params)``. The
                result of ``execute`` must support ``fetchall()``, and rows
                must be indexable by column name.
        """
        self.db = db_connection
        # NOTE(review): not read or written anywhere in this class; presumably
        # intended as a per-flow configuration cache.
        self.flows = {}
        # Dispatch table from validation type to checker. ValidationType.CUSTOM
        # has no entry, so _run_validation treats it as passing.
        self.validators = {
            ValidationType.REGEX: self._validate_regex,
            ValidationType.LENGTH: self._validate_length,
            ValidationType.RANGE: self._validate_range,
            ValidationType.CHOICES: self._validate_choices,
        }

    def load_flow_configuration(self, flow_name: str) -> List[SlotDefinition]:
        """Load slot configuration from database.

        Validations and questions are both LEFT JOINed onto each slot, so a
        slot with V validations and Q questions comes back as V*Q rows. Entries
        are therefore de-duplicated while grouping; otherwise each validation
        would be repeated once per question (and vice versa).

        Args:
            flow_name: ``conversation_flows.name`` of an active flow.

        Returns:
            Slot definitions sorted by ascending ``priority``; equal priorities
            keep the query's ``slot_name`` order. Empty when no active flow
            with that name has slots.
        """
        query = """
            SELECT
                sd.slot_name, sd.slot_type, sd.is_required, sd.priority,
                sd.depends_on_slot, sd.dependency_value,
                sv.validation_type, sv.validation_rule, sv.error_message,
                sq.question_type, sq.question_text, sq.follow_up_text, sq.context_conditions
            FROM conversation_flows cf
            JOIN slot_definitions sd ON cf.id = sd.flow_id
            LEFT JOIN slot_validations sv ON sd.id = sv.slot_id
            LEFT JOIN slot_questions sq ON sd.id = sq.slot_id
            WHERE cf.name = %s AND cf.is_active = true
            ORDER BY sd.priority, sd.slot_name
        """

        results = self.db.execute(query, (flow_name,)).fetchall()

        # Group the joined rows by slot; dicts preserve first-seen order.
        slots_data: Dict[str, Dict[str, Any]] = {}
        for row in results:
            slot_name = row['slot_name']
            slot = slots_data.get(slot_name)
            if slot is None:
                slot = {
                    'definition': {
                        'slot_name': row['slot_name'],
                        'slot_type': SlotType(row['slot_type']),
                        'is_required': row['is_required'],
                        'priority': row['priority'],
                        'depends_on_slot': row['depends_on_slot'],
                        'dependency_value': row['dependency_value'],
                    },
                    'validations': [],
                    'questions': [],
                }
                slots_data[slot_name] = slot

            if row['validation_type']:
                validation = SlotValidation(
                    validation_type=ValidationType(row['validation_type']),
                    validation_rule=self._decode_json(row['validation_rule']),
                    error_message=row['error_message'],
                )
                # Drop the copies produced by the question join.
                if validation not in slot['validations']:
                    slot['validations'].append(validation)

            if row['question_type']:
                question = SlotQuestion(
                    question_type=QuestionType(row['question_type']),
                    question_text=row['question_text'],
                    follow_up_text=row['follow_up_text'],
                    context_conditions=self._decode_json(row['context_conditions']),
                )
                # Drop the copies produced by the validation join.
                if question not in slot['questions']:
                    slot['questions'].append(question)

        slot_definitions = [
            SlotDefinition(
                **slot['definition'],
                validations=slot['validations'],
                questions=slot['questions'],
            )
            for slot in slots_data.values()
        ]

        # sorted() is stable, so ties keep the SQL ordering.
        return sorted(slot_definitions, key=lambda definition: definition.priority)

    @staticmethod
    def _decode_json(value: Any) -> Any:
        """Parse ``value`` as JSON when it is a JSON string; otherwise return it unchanged.

        Some database drivers return JSON columns as text rather than Python
        objects, while the validators and context checks expect dicts.
        Non-JSON strings are returned as-is.
        """
        if isinstance(value, str):
            try:
                return json.loads(value)
            except ValueError:
                return value
        return value

    def get_next_slot_to_collect(self, session: ConversationSession,
                                 slot_definitions: List[SlotDefinition]) -> Optional[SlotDefinition]:
        """Return the first slot in ``slot_definitions`` that still needs an answer.

        A slot is pending when it is required, not yet collected, and its
        dependency (if any) is satisfied. Optional slots are never returned.

        Returns:
            The first pending slot, or None when nothing is pending.
        """
        for slot_def in slot_definitions:
            if self._is_pending(session, slot_def):
                return slot_def
        return None

    def _is_pending(self, session: ConversationSession, slot_def: SlotDefinition) -> bool:
        """Tell whether ``slot_def`` is required, uncollected, and currently applicable."""
        if not slot_def.is_required:
            return False
        if slot_def.slot_name in session.collected_slots:
            return False
        return self._dependency_satisfied(session, slot_def)

    def _dependency_satisfied(self, session: ConversationSession,
                              slot_def: SlotDefinition) -> bool:
        """Tell whether the slot's ``depends_on_slot`` condition currently holds.

        The rules are:
        - no dependency: True;
        - parent slot not collected: False;
        - parent collected and no ``dependency_value``: True;
        - otherwise: the parent's value must match one of ``dependency_value``.
        """
        parent = slot_def.depends_on_slot
        if not parent:
            return True
        if parent not in session.collected_slots:
            return False
        allowed = slot_def.dependency_value
        if not allowed:
            return True
        return self._dependency_value_matches(session.collected_slots[parent], allowed)

    def _dependency_value_matches(self, value: Any, allowed: List[str]) -> bool:
        """Match a collected value against configured dependency values.

        Exact membership is tried first. ``validate_slot_value`` stores BOOLEAN
        answers as bool and NUMBER answers as float, while dependency values
        are configured as strings. Non-string values are therefore also
        compared by textual form:
        - True matches any of _TRUE_WORDS;
        - 2.0 matches "2" or "2.0".
        """
        if value in allowed:
            return True
        if isinstance(value, str):
            return False
        allowed_text = {str(item).strip().lower() for item in allowed}
        if isinstance(value, bool):
            words = self._TRUE_WORDS if value else self._FALSE_WORDS
            return not allowed_text.isdisjoint(words)
        candidates = {str(value).lower()}
        if isinstance(value, float) and value.is_integer():
            candidates.add(str(int(value)))
        return not allowed_text.isdisjoint(candidates)

    def validate_slot_value(self, slot_def: SlotDefinition,
                            user_input: str) -> tuple[bool, Any, List[str]]:
        """Convert ``user_input`` to the slot's type and run the slot's validations.

        Conversion by slot type:
        - NUMBER answers become finite floats;
        - BOOLEAN answers become bool (see _TRUE_WORDS / _FALSE_WORDS);
        - every other type keeps the stripped string, which must be non-empty.

        Returns:
            ``(is_valid, value, errors)``. If conversion fails, ``value`` is the
            raw ``user_input`` and ``errors`` holds a single message. Otherwise
            ``value`` is the converted value and ``errors`` lists the message of
            every failing validation.
        """
        errors: List[str] = []
        extracted_value: Any = user_input.strip()

        if slot_def.slot_type == SlotType.NUMBER:
            try:
                extracted_value = float(extracted_value)
            except ValueError:
                extracted_value = None
            # float() also accepts "nan"/"inf". Neither is a usable answer, and
            # json.dumps would write them as non-standard NaN/Infinity tokens.
            if extracted_value is None or not math.isfinite(extracted_value):
                errors.append("Please enter a valid number")
                return False, user_input, errors

        elif slot_def.slot_type == SlotType.BOOLEAN:
            lower_input = extracted_value.lower()
            if lower_input in self._TRUE_WORDS:
                extracted_value = True
            elif lower_input in self._FALSE_WORDS:
                extracted_value = False
            else:
                errors.append("Please answer with yes or no")
                return False, user_input, errors

        elif not extracted_value:
            # A blank answer is not a value; don't store an empty string.
            errors.append("Please provide a value")
            return False, user_input, errors

        for validation in slot_def.validations:
            is_valid, error_msg = self._run_validation(validation, extracted_value)
            if not is_valid:
                errors.append(error_msg)

        return not errors, extracted_value, errors

    def _run_validation(self, validation: SlotValidation, value: Any) -> tuple[bool, str]:
        """Run one validation rule against ``value``.

        Returns:
            ``(is_valid, message)``. ``message`` is "" on success and the
            rule's ``error_message`` on failure. Types without a registered
            checker pass. Any exception raised by a checker counts as a
            failure, and its text is returned as the message.
        """
        validator = self.validators.get(validation.validation_type)
        if not validator:
            return True, ""

        try:
            is_valid = validator(value, validation.validation_rule)
            return is_valid, validation.error_message if not is_valid else ""
        except Exception as e:
            return False, f"Validation error: {str(e)}"

    def _validate_regex(self, value: str, rule: Dict[str, Any]) -> bool:
        """Check ``str(value)`` against ``rule['pattern']``.

        Uses ``re.match``, which anchors only at the start. Patterns must end
        with ``$`` to reject trailing characters.
        """
        pattern = rule.get('pattern', '')
        return bool(re.match(pattern, str(value)))

    def _validate_length(self, value: str, rule: Dict[str, Any]) -> bool:
        """Check ``rule['min'] <= len(str(value)) <= rule['max']``; both bounds are optional."""
        length = len(str(value))
        min_len = rule.get('min', 0)
        max_len = rule.get('max', float('inf'))
        return min_len <= length <= max_len

    def _validate_range(self, value: float, rule: Dict[str, Any]) -> bool:
        """Check ``rule['min'] <= value <= rule['max']``; both bounds are optional."""
        min_val = rule.get('min', float('-inf'))
        max_val = rule.get('max', float('inf'))
        return min_val <= value <= max_val

    def _validate_choices(self, value: Any, rule: Dict[str, Any]) -> bool:
        """Check that ``value`` is one of ``rule['choices']`` (exact, case-sensitive)."""
        choices = rule.get('choices', [])
        return value in choices

    def get_question_for_slot(self, slot_def: SlotDefinition,
                              question_type: QuestionType = QuestionType.INITIAL,
                              context: Optional[Dict[str, Any]] = None) -> str:
        """Return the prompt text for ``slot_def``.

        Returns the first question of ``question_type`` whose context
        conditions (if any) all hold in ``context``, with its follow-up text
        appended. A question with conditions is skipped when no context is
        supplied, because its conditions cannot hold. Falls back to a generic
        prompt built from the slot name.
        """
        available_context = context or {}
        for question in slot_def.questions:
            if question.question_type != question_type:
                continue
            if question.context_conditions and not self._check_context_conditions(
                    question.context_conditions, available_context):
                continue

            question_text = question.question_text
            if question.follow_up_text:
                question_text += f" {question.follow_up_text}"
            return question_text

        return f"Could you please provide your {slot_def.slot_name.replace('_', ' ')}?"

    def _check_context_conditions(self, conditions: Dict[str, Any],
                                  context: Dict[str, Any]) -> bool:
        """Tell whether every ``conditions`` key is present in ``context`` with an equal value."""
        return all(
            key in context and context[key] == expected_value
            for key, expected_value in conditions.items()
        )

    def save_conversation_session(self, session: ConversationSession):
        """Upsert ``session`` into ``conversation_sessions``, keyed by ``session_id``.

        ``collected_slots`` and ``context_data`` are stored as JSON text. On
        conflict, every column except ``user_id``/``flow_id`` is overwritten
        and ``last_interaction`` is bumped.

        NOTE(review): no commit is issued here. Confirm that the connection
        autocommits or that callers commit.
        """
        query = """
            INSERT INTO conversation_sessions
            (id, user_id, flow_id, session_status, current_slot, collected_slots, context_data)
            VALUES (%s, %s, %s, %s, %s, %s, %s)
            ON CONFLICT (id) DO UPDATE SET
                session_status = EXCLUDED.session_status,
                current_slot = EXCLUDED.current_slot,
                collected_slots = EXCLUDED.collected_slots,
                context_data = EXCLUDED.context_data,
                last_interaction = CURRENT_TIMESTAMP
        """

        self.db.execute(query, (
            session.session_id,
            session.user_id,
            session.flow_id,
            session.session_status,
            session.current_slot,
            json.dumps(session.collected_slots),
            json.dumps(session.context_data),
        ))

    def is_conversation_complete(self, session: ConversationSession,
                                 slot_definitions: List[SlotDefinition]) -> bool:
        """Tell whether no required, applicable slot is still missing.

        This uses the same "pending" rule as ``get_next_slot_to_collect``, so
        the two methods always agree: the conversation is complete exactly
        when there is no next slot.
        """
        return not any(self._is_pending(session, slot_def) for slot_def in slot_definitions)
|
|
|
|
# Example usage
|
|
def example_conversation_flow(db_connection=None):
    """Run an interactive console demo of the slot-filling loop.

    Asks for each pending slot, validates the answer, saves the session after
    every turn, and marks the session "completed" at the end.

    Args:
        db_connection: Connection handed to ``SlotManager``. It defaults to
            ``None`` so existing zero-argument calls keep working, but the demo
            needs a real connection: with ``None``, the first database call
            raises.
    """
    slot_manager = SlotManager(db_connection)

    # Load flow configuration
    flow_name = "customer_onboarding"
    slot_definitions = slot_manager.load_flow_configuration(flow_name)

    # NOTE(review): flow_id is set to the flow *name*. Confirm whether
    # conversation_sessions.flow_id expects conversation_flows.id instead.
    session = ConversationSession(
        session_id=str(uuid.uuid4()),
        user_id="user123",
        flow_id=flow_name,
    )

    while not slot_manager.is_conversation_complete(session, slot_definitions):
        next_slot = slot_manager.get_next_slot_to_collect(session, slot_definitions)
        if not next_slot:
            break

        # current_slot is persisted with the session; keep it in sync with the
        # slot actually being asked.
        session.current_slot = next_slot.slot_name

        question = slot_manager.get_question_for_slot(
            next_slot, context=session.context_data
        )
        print(f"Bot: {question}")

        # Simulate user input
        user_input = input("User: ")

        is_valid, extracted_value, errors = slot_manager.validate_slot_value(next_slot, user_input)

        if is_valid:
            session.collected_slots[next_slot.slot_name] = extracted_value
            print(f"Bot: Great! I've got your {next_slot.slot_name}.")
        else:
            error_question = slot_manager.get_question_for_slot(
                next_slot, QuestionType.VALIDATION_ERROR, session.context_data
            )
            print(f"Bot: {error_question}")
            print(f"Errors: {', '.join(errors)}")

        # Persist progress after every turn.
        slot_manager.save_conversation_session(session)

    print("Bot: Thanks! I have all the information I need.")
    session.current_slot = None
    session.session_status = "completed"
    slot_manager.save_conversation_session(session)
|