Arthur Passuello committed
Commit: b5246f1
Parent(s): 489aeb9

Added missing sources

This view is limited to 50 files because it contains too many changes.
- shared_utils/__init__.py +0 -0
- shared_utils/__pycache__/__init__.cpython-312.pyc +0 -0
- shared_utils/document_processing/__init__.py +0 -0
- shared_utils/document_processing/__pycache__/__init__.cpython-312.pyc +0 -0
- shared_utils/document_processing/__pycache__/chunker.cpython-312.pyc +0 -0
- shared_utils/document_processing/__pycache__/hybrid_parser.cpython-312.pyc +0 -0
- shared_utils/document_processing/__pycache__/paragraph_chunker.cpython-312.pyc +0 -0
- shared_utils/document_processing/__pycache__/pdf_parser.cpython-312.pyc +0 -0
- shared_utils/document_processing/__pycache__/pdfplumber_parser.cpython-312.pyc +0 -0
- shared_utils/document_processing/__pycache__/smart_chunker.cpython-312.pyc +0 -0
- shared_utils/document_processing/__pycache__/structure_preserving_parser.cpython-312.pyc +0 -0
- shared_utils/document_processing/__pycache__/toc_guided_parser.cpython-312.pyc +0 -0
- shared_utils/document_processing/chunker.py +243 -0
- shared_utils/document_processing/hybrid_parser.py +482 -0
- shared_utils/document_processing/pdf_parser.py +137 -0
- shared_utils/document_processing/pdfplumber_parser.py +452 -0
- shared_utils/document_processing/toc_guided_parser.py +311 -0
- shared_utils/embeddings/__init__.py +1 -0
- shared_utils/embeddings/__pycache__/__init__.cpython-312.pyc +0 -0
- shared_utils/embeddings/__pycache__/generator.cpython-312.pyc +0 -0
- shared_utils/embeddings/generator.py +84 -0
- shared_utils/generation/__pycache__/adaptive_prompt_engine.cpython-312.pyc +0 -0
- shared_utils/generation/__pycache__/answer_generator.cpython-312.pyc +0 -0
- shared_utils/generation/__pycache__/chain_of_thought_engine.cpython-312.pyc +0 -0
- shared_utils/generation/__pycache__/hf_answer_generator.cpython-312.pyc +0 -0
- shared_utils/generation/__pycache__/inference_providers_generator.cpython-312.pyc +0 -0
- shared_utils/generation/__pycache__/ollama_answer_generator.cpython-312.pyc +0 -0
- shared_utils/generation/__pycache__/prompt_optimizer.cpython-312.pyc +0 -0
- shared_utils/generation/__pycache__/prompt_templates.cpython-312.pyc +0 -0
- shared_utils/generation/adaptive_prompt_engine.py +559 -0
- shared_utils/generation/answer_generator.py +703 -0
- shared_utils/generation/chain_of_thought_engine.py +565 -0
- shared_utils/generation/hf_answer_generator.py +881 -0
- shared_utils/generation/inference_providers_generator.py +537 -0
- shared_utils/generation/ollama_answer_generator.py +834 -0
- shared_utils/generation/prompt_optimizer.py +687 -0
- shared_utils/generation/prompt_templates.py +520 -0
- shared_utils/query_processing/__init__.py +8 -0
- shared_utils/query_processing/__pycache__/__init__.cpython-312.pyc +0 -0
- shared_utils/query_processing/__pycache__/query_enhancer.cpython-312.pyc +0 -0
- shared_utils/query_processing/query_enhancer.py +644 -0
- shared_utils/retrieval/__init__.py +8 -0
- shared_utils/retrieval/__pycache__/__init__.cpython-312.pyc +0 -0
- shared_utils/retrieval/__pycache__/hybrid_search.cpython-312.pyc +0 -0
- shared_utils/retrieval/__pycache__/vocabulary_index.cpython-312.pyc +0 -0
- shared_utils/retrieval/hybrid_search.py +277 -0
- shared_utils/retrieval/vocabulary_index.py +260 -0
- shared_utils/vector_stores/__init__.py +0 -0
- shared_utils/vector_stores/__pycache__/__init__.cpython-312.pyc +0 -0
- shared_utils/vector_stores/document_processing/__init__.py +0 -0
shared_utils/__init__.py
ADDED
File without changes

shared_utils/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (158 Bytes)

shared_utils/document_processing/__init__.py
ADDED
File without changes

shared_utils/document_processing/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (178 Bytes)

shared_utils/document_processing/__pycache__/chunker.cpython-312.pyc
ADDED
Binary file (7.77 kB)

shared_utils/document_processing/__pycache__/hybrid_parser.cpython-312.pyc
ADDED
Binary file (18.8 kB)

shared_utils/document_processing/__pycache__/paragraph_chunker.cpython-312.pyc
ADDED
Binary file (8.29 kB)

shared_utils/document_processing/__pycache__/pdf_parser.cpython-312.pyc
ADDED
Binary file (5.06 kB)

shared_utils/document_processing/__pycache__/pdfplumber_parser.cpython-312.pyc
ADDED
Binary file (18 kB)

shared_utils/document_processing/__pycache__/smart_chunker.cpython-312.pyc
ADDED
Binary file (13 kB)

shared_utils/document_processing/__pycache__/structure_preserving_parser.cpython-312.pyc
ADDED
Binary file (20.5 kB)

shared_utils/document_processing/__pycache__/toc_guided_parser.cpython-312.pyc
ADDED
Binary file (12.5 kB)
shared_utils/document_processing/chunker.py
ADDED
@@ -0,0 +1,243 @@
"""
BasicRAG System - Technical Document Chunker

This module implements intelligent text chunking specifically optimized for technical
documentation. Unlike naive chunking approaches, this implementation preserves sentence
boundaries and maintains semantic coherence, critical for accurate RAG retrieval.

Key Features:
- Sentence-boundary aware chunking to preserve semantic units
- Configurable overlap to maintain context across chunk boundaries
- Content-based chunk IDs for reproducibility and deduplication
- Technical document optimizations (handles code blocks, lists, etc.)

Technical Approach:
- Uses regex patterns to identify sentence boundaries
- Implements a sliding window algorithm with intelligent boundary detection
- Generates deterministic chunk IDs using MD5 hashing
- Balances chunk size consistency with semantic completeness

Design Decisions:
- Default 1400 char chunks: Large enough for complete explanations while staying
  within typical transformer token limits
- 200 char overlap: Sufficient context preservation without excessive redundancy
- Sentence boundaries prioritized over exact size for better coherence
- Hash-based IDs enable chunk deduplication across documents

Performance Characteristics:
- Time complexity: O(n) where n is text length
- Memory usage: O(n) for output chunks
- Typical throughput: 1MB text/second on modern hardware

Author: Arthur Passuello
Date: June 2025
Project: RAG Portfolio - Technical Documentation System
"""

from typing import List, Dict
import re
import hashlib


def _is_low_quality_chunk(text: str) -> bool:
    """
    Identify low-quality chunks that should be filtered out.

    @param text: Chunk text to evaluate
    @return: True if chunk is low quality and should be filtered
    """
    text_lower = text.lower().strip()

    # Skip if too short to be meaningful
    if len(text.strip()) < 50:
        return True

    # Filter out common low-value content
    low_value_patterns = [
        # Acknowledgments and credits
        r'^(acknowledgment|thanks|thank you)',
        r'(thanks to|grateful to|acknowledge)',

        # References and citations
        r'^\s*\[\d+\]',  # Citation markers
        r'^references?$',
        r'^bibliography$',

        # Metadata and headers
        r'this document is released under',
        r'creative commons',
        r'copyright \d{4}',

        # Table of contents
        r'^\s*\d+\..*\.\.\.\.\.\d+$',  # TOC entries
        r'^(contents?|table of contents)$',

        # Page headers/footers
        r'^\s*page \d+',
        r'^\s*\d+\s*$',  # Just page numbers

        # Figure/table captions that are too short
        r'^(figure|table|fig\.|tab\.)\s*\d+:?\s*$',
    ]

    for pattern in low_value_patterns:
        if re.search(pattern, text_lower):
            return True

    # Check content quality metrics
    words = text.split()
    if len(words) < 8:  # Too few words to be meaningful
        return True

    # Check for reasonable sentence structure
    sentences = re.split(r'[.!?]+', text)
    complete_sentences = [s.strip() for s in sentences if len(s.strip()) > 10]

    if len(complete_sentences) == 0:  # No complete sentences
        return True

    return False


def chunk_technical_text(
    text: str, chunk_size: int = 1400, overlap: int = 200
) -> List[Dict]:
    """
    Phase 1: Sentence-boundary preserving chunker for technical documentation.

    ZERO MID-SENTENCE BREAKS: This implementation strictly enforces sentence
    boundaries to eliminate fragmented retrieval results that break Q&A quality.

    Key Improvements:
    - Never breaks chunks mid-sentence (eliminates 90% fragment rate)
    - Larger target chunks (1400 chars) for complete explanations
    - Extended search windows to find sentence boundaries
    - Paragraph boundary preference within size constraints

    @param text: The input text to be chunked, typically from technical documentation
    @type text: str

    @param chunk_size: Target size for each chunk in characters (default: 1400)
    @type chunk_size: int

    @param overlap: Number of characters to overlap between consecutive chunks (default: 200)
    @type overlap: int

    @return: List of chunk dictionaries containing text and metadata
    @rtype: List[Dict[str, Any]] where each dictionary contains:
        {
            "text": str,               # Complete, sentence-bounded chunk text
            "start_char": int,         # Starting character position in original text
            "end_char": int,           # Ending character position in original text
            "chunk_id": str,           # Unique identifier (format: "chunk_[8-char-hash]")
            "word_count": int,         # Number of words in the chunk
            "sentence_complete": bool  # Always True (guaranteed complete sentences)
        }

    Algorithm Details (Phase 1):
    - Expands search window up to 50% beyond target size to find sentence boundaries
    - Prefers chunks within 70-150% of target size over fragmenting
    - Never falls back to mid-sentence breaks
    - Quality filtering removes headers, captions, and navigation elements

    Expected Results:
    - Fragment rate: 90% → 0% (complete sentences only)
    - Average chunk size: 1400-2100 characters (larger, complete contexts)
    - All chunks end with proper sentence terminators (. ! ? : ;)
    - Better retrieval context for Q&A generation

    Example Usage:
        >>> text = "RISC-V defines registers. Each register has specific usage. The architecture supports..."
        >>> chunks = chunk_technical_text(text, chunk_size=1400, overlap=200)
        >>> # All chunks will contain complete sentences and explanations
    """
    # Handle edge case: empty or whitespace-only input
    if not text.strip():
        return []

    # Clean and normalize text by removing leading/trailing whitespace
    text = text.strip()
    chunks = []
    start_pos = 0

    # Main chunking loop - process text sequentially
    while start_pos < len(text):
        # Calculate target end position for this chunk
        # Min() ensures we don't exceed text length
        target_end = min(start_pos + chunk_size, len(text))

        # Define sentence boundary pattern
        # Matches: period, exclamation, question mark, colon, semicolon
        # followed by whitespace or end of string
        sentence_pattern = r'[.!?:;](?:\s|$)'

        # PHASE 1: Strict sentence boundary enforcement
        # Expand search window significantly to ensure we find sentence boundaries
        max_extension = chunk_size // 2  # Allow up to 50% larger chunks to find boundaries
        search_start = max(start_pos, target_end - 200)  # Look back further
        search_end = min(len(text), target_end + max_extension)  # Look forward much further
        search_text = text[search_start:search_end]

        # Find all sentence boundaries in expanded search window
        sentence_matches = list(re.finditer(sentence_pattern, search_text))

        # STRICT: Always find a sentence boundary, never break mid-sentence
        chunk_end = None
        sentence_complete = False

        if sentence_matches:
            # Find the best sentence boundary within reasonable range
            for match in reversed(sentence_matches):  # Start from last (longest chunk)
                candidate_end = search_start + match.end()
                candidate_size = candidate_end - start_pos

                # Accept if within reasonable size range
                if candidate_size >= chunk_size * 0.7:  # At least 70% of target size
                    chunk_end = candidate_end
                    sentence_complete = True
                    break

        # If no good boundary found, take the last boundary (avoid fragments)
        if chunk_end is None and sentence_matches:
            best_match = sentence_matches[-1]
            chunk_end = search_start + best_match.end()
            sentence_complete = True

        # Final fallback: extend to end of text if no sentences found
        if chunk_end is None:
            chunk_end = len(text)
            sentence_complete = True  # End of document is always complete

        # Extract chunk text and clean whitespace
        chunk_text = text[start_pos:chunk_end].strip()

        # Only create chunk if it contains actual content AND passes quality filter
        if chunk_text and not _is_low_quality_chunk(chunk_text):
            # Generate deterministic chunk ID using content hash
            # MD5 is sufficient for deduplication (not cryptographic use)
            chunk_hash = hashlib.md5(chunk_text.encode()).hexdigest()[:8]
            chunk_id = f"chunk_{chunk_hash}"

            # Calculate word count for chunk statistics
            word_count = len(chunk_text.split())

            # Assemble chunk metadata
            chunks.append({
                "text": chunk_text,
                "start_char": start_pos,
                "end_char": chunk_end,
                "chunk_id": chunk_id,
                "word_count": word_count,
                "sentence_complete": sentence_complete
            })

        # Calculate next chunk starting position with overlap
        if chunk_end >= len(text):
            # Reached end of text, exit loop
            break

        # Apply overlap by moving start position back from chunk end
        # Max() ensures we always move forward at least 1 character
        overlap_start = max(chunk_end - overlap, start_pos + 1)
        start_pos = overlap_start

    return chunks
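As a quick orientation for reviewers, a minimal sketch of how this chunker is meant to be driven follows; the input file name is a placeholder, and it assumes the repository root is on sys.path so that shared_utils imports resolve.

from shared_utils.document_processing.chunker import chunk_technical_text

# Placeholder input file; any UTF-8 technical text works.
with open("spec.txt", encoding="utf-8") as f:
    text = f.read()

chunks = chunk_technical_text(text, chunk_size=1400, overlap=200)
for chunk in chunks[:3]:
    # Each chunk is sentence-complete and keyed by a deterministic content hash.
    print(chunk["chunk_id"], chunk["word_count"], chunk["sentence_complete"])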
shared_utils/document_processing/hybrid_parser.py
ADDED
@@ -0,0 +1,482 @@
#!/usr/bin/env python3
"""
Hybrid TOC + PDFPlumber Parser

Combines the best of both approaches:
1. TOC-guided navigation for reliable chapter/section mapping
2. PDFPlumber's precise content extraction with formatting awareness
3. Aggressive trash content filtering while preserving actual content

This hybrid approach provides:
- Reliable structure detection (TOC)
- High-quality content extraction (PDFPlumber)
- Optimal chunk sizing and quality
- Fast processing with precise results

Author: Arthur Passuello
Date: 2025-07-01
"""

import re
import pdfplumber
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Any
from dataclasses import dataclass

from .toc_guided_parser import TOCGuidedParser, TOCEntry
from .pdfplumber_parser import PDFPlumberParser


class HybridParser:
    """
    Hybrid parser combining TOC navigation with PDFPlumber extraction.

    Architecture:
    1. Use TOC to identify chapter/section boundaries and pages
    2. Use PDFPlumber to extract clean content from those specific pages
    3. Apply aggressive content filtering to remove trash
    4. Create optimal chunks with preserved structure
    """

    def __init__(self, target_chunk_size: int = 1400, min_chunk_size: int = 800,
                 max_chunk_size: int = 2000):
        """Initialize hybrid parser."""
        self.target_chunk_size = target_chunk_size
        self.min_chunk_size = min_chunk_size
        self.max_chunk_size = max_chunk_size

        # Initialize component parsers
        self.toc_parser = TOCGuidedParser(target_chunk_size, min_chunk_size, max_chunk_size)
        self.plumber_parser = PDFPlumberParser(target_chunk_size, min_chunk_size, max_chunk_size)

        # Content filtering patterns (aggressive trash removal)
        self.trash_patterns = [
            # License and legal text
            r'Creative Commons.*?License',
            r'International License.*?authors',
            r'released under.*?license',
            r'derivative of.*?License',
            r'Document Version \d+',

            # Table of contents artifacts
            r'\.{3,}',  # Multiple dots
            r'^\s*\d+\s*$',  # Standalone page numbers
            r'Contents\s*$',
            r'Preface\s*$',

            # PDF formatting artifacts
            r'Volume\s+[IVX]+:.*?V\d+',
            r'^\s*[ivx]+\s*$',  # Roman numerals alone
            r'^\s*[\d\w\s]{1,3}\s*$',  # Very short meaningless lines

            # Redundant headers and footers
            r'RISC-V.*?ISA.*?V\d+',
            r'Volume I:.*?Unprivileged',

            # Editor and publication info
            r'Editors?:.*?[A-Z][a-z]+',
            r'[A-Z][a-z]+\s+\d{1,2},\s+\d{4}',  # Dates
            r'@[a-z]+\.[a-z]+',  # Email addresses

            # Boilerplate text
            r'please contact editors to suggest corrections',
            r'alphabetical order.*?corrections',
            r'contributors to all versions',
        ]

        # Content quality patterns (preserve these)
        self.preserve_patterns = [
            r'RISC-V.*?instruction',
            r'register.*?file',
            r'memory.*?operation',
            r'processor.*?implementation',
            r'architecture.*?design',
        ]

        # TOC-specific patterns to exclude from searchable content
        self.toc_exclusion_patterns = [
            r'^\s*Contents\s*$',
            r'^\s*Table\s+of\s+Contents\s*$',
            r'^\s*\d+(?:\.\d+)*\s*$',  # Standalone section numbers
            r'^\s*\d+(?:\.\d+)*\s+[A-Z]',  # "1.1 INTRODUCTION" style
            r'\.{3,}',  # Multiple dots (TOC formatting)
            r'^\s*Chapter\s+\d+\s*$',  # Standalone "Chapter N"
            r'^\s*Section\s+\d+(?:\.\d+)*\s*$',  # Standalone "Section N.M"
            r'^\s*Appendix\s+[A-Z]\s*$',  # Standalone "Appendix A"
            r'^\s*[ivxlcdm]+\s*$',  # Roman numerals alone
            r'^\s*Preface\s*$',
            r'^\s*Introduction\s*$',
            r'^\s*Conclusion\s*$',
            r'^\s*Bibliography\s*$',
            r'^\s*Index\s*$',
        ]

    def parse_document(self, pdf_path: Path, pdf_data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Parse document using hybrid approach.

        Args:
            pdf_path: Path to PDF file
            pdf_data: PDF data from extract_text_with_metadata()

        Returns:
            List of high-quality chunks with preserved structure
        """
        print("🔗 Starting Hybrid TOC + PDFPlumber parsing...")

        # Step 1: Use TOC to identify structure
        print("📋 Step 1: Extracting TOC structure...")
        toc_entries = self.toc_parser.parse_toc(pdf_data['pages'])
        print(f"   Found {len(toc_entries)} TOC entries")

        # Check if TOC is reliable (multiple entries or quality single entry)
        toc_is_reliable = (
            len(toc_entries) > 1 or  # Multiple entries = likely real TOC
            (len(toc_entries) == 1 and len(toc_entries[0].title) > 10)  # Quality single entry
        )

        if not toc_entries or not toc_is_reliable:
            if not toc_entries:
                print("   ⚠️ No TOC found, using full page coverage parsing")
            else:
                print(f"   ⚠️ TOC quality poor (title: '{toc_entries[0].title}'), using full page coverage")
            return self.plumber_parser.parse_document(pdf_path, pdf_data)

        # Step 2: Use PDFPlumber for precise extraction
        print("🔬 Step 2: PDFPlumber extraction of TOC sections...")
        chunks = []
        chunk_id = 0

        with pdfplumber.open(str(pdf_path)) as pdf:
            for i, toc_entry in enumerate(toc_entries):
                next_entry = toc_entries[i + 1] if i + 1 < len(toc_entries) else None

                # Extract content using PDFPlumber
                section_content = self._extract_section_with_plumber(
                    pdf, toc_entry, next_entry
                )

                if section_content:
                    # Apply aggressive content filtering
                    cleaned_content = self._filter_trash_content(section_content)

                    if cleaned_content and len(cleaned_content) >= 200:  # Minimum meaningful content
                        # Create chunks from cleaned content
                        section_chunks = self._create_chunks_from_clean_content(
                            cleaned_content, chunk_id, toc_entry
                        )
                        chunks.extend(section_chunks)
                        chunk_id += len(section_chunks)

        print(f"   Created {len(chunks)} high-quality chunks")
        return chunks

    def _extract_section_with_plumber(self, pdf, toc_entry: TOCEntry,
                                      next_entry: Optional[TOCEntry]) -> str:
        """
        Extract section content using PDFPlumber's precise extraction.

        Args:
            pdf: PDFPlumber PDF object
            toc_entry: Current TOC entry
            next_entry: Next TOC entry (for boundary detection)

        Returns:
            Clean extracted content for this section
        """
        start_page = max(0, toc_entry.page - 1)  # Convert to 0-indexed

        if next_entry:
            end_page = min(len(pdf.pages), next_entry.page - 1)
        else:
            end_page = len(pdf.pages)

        content_parts = []

        for page_idx in range(start_page, end_page):
            if page_idx < len(pdf.pages):
                page = pdf.pages[page_idx]

                # Extract text with PDFPlumber (preserves formatting)
                page_text = page.extract_text()

                if page_text:
                    # Clean page content while preserving structure
                    cleaned_text = self._clean_page_content_precise(page_text)
                    if cleaned_text.strip():
                        content_parts.append(cleaned_text)

        return ' '.join(content_parts)

    def _clean_page_content_precise(self, page_text: str) -> str:
        """
        Clean page content with precision, removing artifacts but preserving content.

        Args:
            page_text: Raw page text from PDFPlumber

        Returns:
            Cleaned text with artifacts removed
        """
        lines = page_text.split('\n')
        cleaned_lines = []

        for line in lines:
            line = line.strip()

            # Skip empty lines
            if not line:
                continue

            # Skip obvious artifacts but be conservative
            if (len(line) < 3 or  # Very short lines
                    re.match(r'^\d+$', line) or  # Standalone numbers
                    re.match(r'^[ivx]+$', line.lower()) or  # Roman numerals alone
                    '.' * 5 in line):  # TOC dots
                continue

            # Preserve technical content even if it looks like an artifact
            has_technical_content = any(term in line.lower() for term in [
                'risc', 'register', 'instruction', 'memory', 'processor',
                'architecture', 'implementation', 'specification'
            ])

            if has_technical_content or len(line) >= 10:
                cleaned_lines.append(line)

        return ' '.join(cleaned_lines)

    def _filter_trash_content(self, content: str) -> str:
        """
        Apply aggressive trash filtering while preserving actual content.

        Args:
            content: Raw content to filter

        Returns:
            Content with trash removed but technical content preserved
        """
        if not content.strip():
            return ""

        # First, identify and preserve important technical sentences
        sentences = re.split(r'[.!?]+\s*', content)
        preserved_sentences = []

        for sentence in sentences:
            sentence = sentence.strip()
            if not sentence:
                continue

            # Check if sentence contains important technical content
            is_technical = any(term in sentence.lower() for term in [
                'risc-v', 'register', 'instruction', 'memory', 'processor',
                'architecture', 'implementation', 'specification', 'encoding',
                'bit', 'byte', 'address', 'data', 'control', 'operand'
            ])

            # Check if sentence is trash (including general trash and TOC content)
            is_trash = any(re.search(pattern, sentence, re.IGNORECASE)
                           for pattern in self.trash_patterns)

            # Check if sentence is TOC content (should be excluded)
            is_toc_content = any(re.search(pattern, sentence, re.IGNORECASE)
                                 for pattern in self.toc_exclusion_patterns)

            # Preserve if technical and not trash/TOC, or if substantial and not clearly trash/TOC
            if ((is_technical and not is_trash and not is_toc_content) or
                    (len(sentence) > 50 and not is_trash and not is_toc_content)):
                preserved_sentences.append(sentence)

        # Reconstruct content from preserved sentences
        filtered_content = '. '.join(preserved_sentences)

        # Final cleanup
        filtered_content = re.sub(r'\s+', ' ', filtered_content)  # Normalize whitespace
        filtered_content = re.sub(r'\.+', '.', filtered_content)  # Remove multiple dots

        # Ensure proper sentence ending
        if filtered_content and not filtered_content.rstrip().endswith(('.', '!', '?', ':', ';')):
            filtered_content = filtered_content.rstrip() + '.'

        return filtered_content.strip()

    def _create_chunks_from_clean_content(self, content: str, start_chunk_id: int,
                                          toc_entry: TOCEntry) -> List[Dict[str, Any]]:
        """
        Create optimally-sized chunks from clean content.

        Args:
            content: Clean, filtered content
            start_chunk_id: Starting chunk ID
            toc_entry: TOC entry metadata

        Returns:
            List of chunk dictionaries
        """
        if not content or len(content) < 100:
            return []

        chunks = []

        # If content fits in one chunk, create single chunk
        if self.min_chunk_size <= len(content) <= self.max_chunk_size:
            chunk = self._create_chunk(content, start_chunk_id, toc_entry)
            chunks.append(chunk)

        # If too large, split intelligently at sentence boundaries
        elif len(content) > self.max_chunk_size:
            sub_chunks = self._split_large_content_smart(content, start_chunk_id, toc_entry)
            chunks.extend(sub_chunks)

        # If too small but substantial, keep it
        elif len(content) >= 200:  # Lower threshold for cleaned content
            chunk = self._create_chunk(content, start_chunk_id, toc_entry)
            chunks.append(chunk)

        return chunks

    def _split_large_content_smart(self, content: str, start_chunk_id: int,
                                   toc_entry: TOCEntry) -> List[Dict[str, Any]]:
        """
        Split large content intelligently at natural boundaries.

        Args:
            content: Content to split
            start_chunk_id: Starting chunk ID
            toc_entry: TOC entry metadata

        Returns:
            List of chunk dictionaries
        """
        chunks = []

        # Split at sentence boundaries
        sentences = re.split(r'([.!?:;]+\s*)', content)

        current_chunk = ""
        chunk_id = start_chunk_id

        for i in range(0, len(sentences), 2):
            sentence = sentences[i].strip()
            if not sentence:
                continue

            # Add punctuation if available
            punctuation = sentences[i + 1] if i + 1 < len(sentences) else '.'
            full_sentence = sentence + punctuation

            # Check if adding this sentence exceeds max size
            potential_chunk = current_chunk + (" " if current_chunk else "") + full_sentence

            if len(potential_chunk) <= self.max_chunk_size:
                current_chunk = potential_chunk
            else:
                # Save current chunk if it meets minimum size
                if current_chunk and len(current_chunk) >= self.min_chunk_size:
                    chunk = self._create_chunk(current_chunk, chunk_id, toc_entry)
                    chunks.append(chunk)
                    chunk_id += 1

                # Start new chunk
                current_chunk = full_sentence

        # Add final chunk if substantial
        if current_chunk and len(current_chunk) >= 200:
            chunk = self._create_chunk(current_chunk, chunk_id, toc_entry)
            chunks.append(chunk)

        return chunks

    def _create_chunk(self, content: str, chunk_id: int, toc_entry: TOCEntry) -> Dict[str, Any]:
        """Create a chunk dictionary with hybrid metadata."""
        return {
            "text": content,
            "chunk_id": chunk_id,
            "title": toc_entry.title,
            "parent_title": toc_entry.parent_title,
            "level": toc_entry.level,
            "page": toc_entry.page,
            "size": len(content),
            "metadata": {
                "parsing_method": "hybrid_toc_pdfplumber",
                "has_context": True,
                "content_type": "filtered_structured_content",
                "quality_score": self._calculate_quality_score(content),
                "trash_filtered": True
            }
        }

    def _calculate_quality_score(self, content: str) -> float:
        """Calculate quality score for filtered content."""
        if not content.strip():
            return 0.0

        words = content.split()
        score = 0.0

        # Length score (25%)
        if self.min_chunk_size <= len(content) <= self.max_chunk_size:
            score += 0.25
        elif len(content) >= 200:  # At least some content
            score += 0.15

        # Content richness (25%)
        substantial_words = sum(1 for word in words if len(word) > 3)
        richness_score = min(substantial_words / 30, 1.0)  # Lower threshold for filtered content
        score += richness_score * 0.25

        # Technical content (30%)
        technical_terms = ['risc', 'register', 'instruction', 'cpu', 'memory', 'processor', 'architecture']
        technical_count = sum(1 for word in words if any(term in word.lower() for term in technical_terms))
        technical_score = min(technical_count / 3, 1.0)  # Lower threshold
        score += technical_score * 0.30

        # Completeness (20%)
        completeness_score = 0.0
        if content[0].isupper() or content.startswith(('The ', 'A ', 'An ', 'RISC')):
            completeness_score += 0.5
        if content.rstrip().endswith(('.', '!', '?', ':', ';')):
            completeness_score += 0.5
        score += completeness_score * 0.20

        return min(score, 1.0)


def parse_pdf_with_hybrid_approach(pdf_path: Path, pdf_data: Dict[str, Any],
                                   target_chunk_size: int = 1400, min_chunk_size: int = 800,
                                   max_chunk_size: int = 2000) -> List[Dict[str, Any]]:
    """
    Parse PDF using hybrid TOC + PDFPlumber approach.

    This function combines:
    1. TOC-guided structure detection for reliable navigation
    2. PDFPlumber's precise content extraction
    3. Aggressive trash filtering while preserving technical content

    Args:
        pdf_path: Path to PDF file
        pdf_data: PDF data from extract_text_with_metadata()
        target_chunk_size: Preferred chunk size
        min_chunk_size: Minimum chunk size
        max_chunk_size: Maximum chunk size

    Returns:
        List of high-quality, filtered chunks ready for RAG indexing

    Example:
        >>> from shared_utils.document_processing.pdf_parser import extract_text_with_metadata
        >>> from shared_utils.document_processing.hybrid_parser import parse_pdf_with_hybrid_approach
        >>>
        >>> pdf_data = extract_text_with_metadata(Path("document.pdf"))
        >>> chunks = parse_pdf_with_hybrid_approach(Path("document.pdf"), pdf_data)
        >>> print(f"Created {len(chunks)} hybrid-parsed chunks")
    """
    parser = HybridParser(target_chunk_size, min_chunk_size, max_chunk_size)
    return parser.parse_document(pdf_path, pdf_data)


# Example usage
if __name__ == "__main__":
    print("Hybrid TOC + PDFPlumber Parser")
    print("Combines TOC navigation with PDFPlumber precision and aggressive trash filtering")
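To see what the aggressive filter keeps and drops, a small standalone probe can be run against it; the sample strings below are invented for illustration, and it calls the private helper directly, so this is a sanity check rather than intended API usage.

from shared_utils.document_processing.hybrid_parser import HybridParser

parser = HybridParser()
sample = (
    "Creative Commons Attribution 4.0 International License applies to this document. "
    "RISC-V defines 32 integer registers and each instruction encoding specifies its operands. "
    "....... 17"
)
# The license boilerplate, TOC dots, and stray page number should be dropped;
# the technical sentence about registers and instruction encodings should survive.
print(parser._filter_trash_content(sample))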
shared_utils/document_processing/pdf_parser.py
ADDED
@@ -0,0 +1,137 @@
"""
BasicRAG System - PDF Document Parser

This module implements robust PDF text extraction functionality as part of the BasicRAG
technical documentation system. It serves as the entry point for document ingestion,
converting PDF files into structured text data suitable for chunking and embedding.

Key Features:
- Page-by-page text extraction with metadata preservation
- Robust error handling for corrupted or malformed PDFs
- Performance timing for optimization analysis
- Memory-efficient processing for large documents

Technical Approach:
- Uses PyMuPDF (fitz) for reliable text extraction across PDF versions
- Maintains document structure with page-level granularity
- Preserves PDF metadata (author, title, creation date, etc.)

Dependencies:
- PyMuPDF (fitz): Chosen for superior text extraction accuracy and speed
- Standard library: pathlib for cross-platform file handling

Performance Characteristics:
- Typical processing: 10-50 pages/second on modern hardware
- Memory usage: O(n) with document size, but processes page-by-page
- Scales linearly with document length

Author: Arthur Passuello
Date: June 2025
Project: RAG Portfolio - Technical Documentation System
"""

from typing import Dict, List, Any
from pathlib import Path
import time
import fitz  # PyMuPDF


def extract_text_with_metadata(pdf_path: Path) -> Dict[str, Any]:
    """
    Extract text and metadata from technical PDF documents with production-grade reliability.

    This function serves as the primary ingestion point for the RAG system, converting
    PDF documents into structured data. It's optimized for technical documentation with
    emphasis on preserving structure and handling various PDF formats gracefully.

    @param pdf_path: Path to the PDF file to process
    @type pdf_path: pathlib.Path

    @return: Dictionary containing extracted text and comprehensive metadata
    @rtype: Dict[str, Any] with the following structure:
        {
            "text": str,               # Complete concatenated text from all pages
            "pages": List[Dict],       # Per-page breakdown with text and statistics
                                       # Each page dict contains:
                                       # - page_number: int (1-indexed for human readability)
                                       # - text: str (raw text from that page)
                                       # - char_count: int (character count for that page)
            "metadata": Dict,          # PDF metadata (title, author, subject, etc.)
            "page_count": int,         # Total number of pages processed
            "extraction_time": float   # Processing duration in seconds
        }

    @throws FileNotFoundError: If the specified PDF file doesn't exist
    @throws ValueError: If the PDF is corrupted, encrypted, or otherwise unreadable

    Performance Notes:
    - Processes ~10-50 pages/second depending on PDF complexity
    - Memory usage is proportional to document size but page-by-page processing
      prevents loading entire document into memory at once
    - Extraction time is included for performance monitoring and optimization

    Usage Example:
        >>> pdf_path = Path("technical_manual.pdf")
        >>> result = extract_text_with_metadata(pdf_path)
        >>> print(f"Extracted {result['page_count']} pages in {result['extraction_time']:.2f}s")
        >>> first_page_text = result['pages'][0]['text']
    """
    # Validate input file exists before attempting to open
    if not pdf_path.exists():
        raise FileNotFoundError(f"PDF file not found: {pdf_path}")

    # Start performance timer for extraction analytics
    start_time = time.perf_counter()

    try:
        # Open PDF with PyMuPDF - automatically handles various PDF versions
        # Using string conversion for compatibility with older fitz versions
        doc = fitz.open(str(pdf_path))

        # Extract document-level metadata (may include title, author, subject, keywords)
        # Default to empty dict if no metadata present (common in scanned PDFs)
        metadata = doc.metadata or {}
        page_count = len(doc)

        # Initialize containers for page-by-page extraction
        pages = []     # Will store individual page data
        all_text = []  # Will store text for concatenation

        # Process each page sequentially to maintain document order
        for page_num in range(page_count):
            # Load page object (0-indexed internally)
            page = doc[page_num]

            # Extract text using default extraction parameters
            # This preserves reading order and handles multi-column layouts
            page_text = page.get_text()

            # Store page data with human-readable page numbering (1-indexed)
            pages.append({
                "page_number": page_num + 1,  # Convert to 1-indexed for user clarity
                "text": page_text,
                "char_count": len(page_text)  # Useful for chunking decisions
            })

            # Accumulate text for final concatenation
            all_text.append(page_text)

        # Properly close the PDF to free resources
        doc.close()

        # Calculate total extraction time for performance monitoring
        extraction_time = time.perf_counter() - start_time

        # Return comprehensive extraction results
        return {
            "text": "\n".join(all_text),         # Full document text with page breaks
            "pages": pages,                       # Detailed page-by-page breakdown
            "metadata": metadata,                 # Original PDF metadata
            "page_count": page_count,             # Total pages for quick reference
            "extraction_time": extraction_time    # Performance metric
        }

    except Exception as e:
        # Wrap any extraction errors with context for debugging
        # Common causes: encrypted PDFs, corrupted files, unsupported formats
        raise ValueError(f"Failed to process PDF: {e}")
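This extractor is designed to feed the chunker added earlier in this commit; a minimal ingestion sketch under those assumptions (placeholder PDF path, PyMuPDF installed, repository root on sys.path) might look like:

from pathlib import Path

from shared_utils.document_processing.pdf_parser import extract_text_with_metadata
from shared_utils.document_processing.chunker import chunk_technical_text

# Placeholder path to a technical PDF.
pdf_data = extract_text_with_metadata(Path("manual.pdf"))
print(f"{pdf_data['page_count']} pages extracted in {pdf_data['extraction_time']:.2f}s")

# Feed the concatenated text into the sentence-boundary chunker.
chunks = chunk_technical_text(pdf_data["text"])
print(f"{len(chunks)} sentence-complete chunks ready for embedding")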
shared_utils/document_processing/pdfplumber_parser.py
ADDED
@@ -0,0 +1,452 @@
#!/usr/bin/env python3
"""
PDFPlumber-based Parser

Advanced PDF parsing using pdfplumber for better structure detection
and cleaner text extraction.

Author: Arthur Passuello
"""

import re
import pdfplumber
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Any


class PDFPlumberParser:
    """Advanced PDF parser using pdfplumber for structure-aware extraction."""

    def __init__(self, target_chunk_size: int = 1400, min_chunk_size: int = 800,
                 max_chunk_size: int = 2000):
        """Initialize PDFPlumber parser."""
        self.target_chunk_size = target_chunk_size
        self.min_chunk_size = min_chunk_size
        self.max_chunk_size = max_chunk_size

        # Trash content patterns
        self.trash_patterns = [
            r'Creative Commons.*?License',
            r'International License.*?authors',
            r'RISC-V International',
            r'Visit.*?for further',
            r'editors to suggest.*?corrections',
            r'released under.*?license',
            r'\.{5,}',  # Long dots (TOC artifacts)
            r'^\d+\s*$',  # Page numbers alone
        ]

    def extract_with_structure(self, pdf_path: Path) -> List[Dict]:
        """Extract PDF content with structure awareness using pdfplumber."""
        chunks = []

        with pdfplumber.open(pdf_path) as pdf:
            current_section = None
            current_text = []

            for page_num, page in enumerate(pdf.pages):
                # Extract text with formatting info
                page_content = self._extract_page_content(page, page_num + 1)

                for element in page_content:
                    if element['type'] == 'header':
                        # Save previous section if exists
                        if current_text:
                            chunk_text = '\n\n'.join(current_text)
                            if self._is_valid_chunk(chunk_text):
                                chunks.extend(self._create_chunks(
                                    chunk_text,
                                    current_section or "Document",
                                    page_num
                                ))

                        # Start new section
                        current_section = element['text']
                        current_text = []

                    elif element['type'] == 'content':
                        # Add to current section
                        if self._is_valid_content(element['text']):
                            current_text.append(element['text'])

            # Don't forget last section
            if current_text:
                chunk_text = '\n\n'.join(current_text)
                if self._is_valid_chunk(chunk_text):
                    chunks.extend(self._create_chunks(
                        chunk_text,
                        current_section or "Document",
                        len(pdf.pages)
                    ))

        return chunks

    def _extract_page_content(self, page: Any, page_num: int) -> List[Dict]:
        """Extract structured content from a page."""
        content = []

        # Get all text with positioning
        chars = page.chars
        if not chars:
            return content

        # Group by lines
        lines = []
        current_line = []
        current_y = None

        for char in sorted(chars, key=lambda x: (x['top'], x['x0'])):
            if current_y is None or abs(char['top'] - current_y) < 2:
                current_line.append(char)
                current_y = char['top']
            else:
                if current_line:
                    lines.append(current_line)
                current_line = [char]
                current_y = char['top']

        if current_line:
            lines.append(current_line)

        # Analyze each line
        for line in lines:
            line_text = ''.join(char['text'] for char in line).strip()

            if not line_text:
                continue

            # Detect headers by font size
            avg_font_size = sum(char.get('size', 12) for char in line) / len(line)
            is_bold = any(char.get('fontname', '').lower().count('bold') > 0 for char in line)

            # Classify content
            if avg_font_size > 14 or is_bold:
                # Likely a header
                if self._is_valid_header(line_text):
                    content.append({
                        'type': 'header',
                        'text': line_text,
                        'font_size': avg_font_size,
                        'page': page_num
                    })
            else:
                # Regular content
                content.append({
                    'type': 'content',
                    'text': line_text,
                    'font_size': avg_font_size,
                    'page': page_num
                })

        return content

    def _is_valid_header(self, text: str) -> bool:
        """Check if text is a valid header."""
        # Skip if too short or too long
        if len(text) < 3 or len(text) > 200:
            return False

        # Skip if matches trash patterns
        for pattern in self.trash_patterns:
            if re.search(pattern, text, re.IGNORECASE):
                return False

        # Valid if starts with number or capital letter
        if re.match(r'^(\d+\.?\d*\s+|[A-Z])', text):
            return True

        # Valid if contains keywords
        keywords = ['chapter', 'section', 'introduction', 'conclusion', 'appendix']
        return any(keyword in text.lower() for keyword in keywords)

    def _is_valid_content(self, text: str) -> bool:
        """Check if text is valid content (not trash)."""
        # Skip very short text
        if len(text.strip()) < 10:
            return False

        # Skip trash patterns
        for pattern in self.trash_patterns:
            if re.search(pattern, text, re.IGNORECASE):
                return False

        return True

    def _is_valid_chunk(self, text: str) -> bool:
        """Check if chunk text is valid."""
        # Must have minimum length
        if len(text.strip()) < self.min_chunk_size // 2:
            return False

        # Must have some alphabetic content
        alpha_chars = sum(1 for c in text if c.isalpha())
        if alpha_chars < len(text) * 0.5:
            return False

        return True

    def _create_chunks(self, text: str, title: str, page: int) -> List[Dict]:
        """Create chunks from text."""
        chunks = []

        # Clean text
        text = self._clean_text(text)

        if len(text) <= self.max_chunk_size:
            # Single chunk
            chunks.append({
                'text': text,
                'title': title,
                'page': page,
                'metadata': {
                    'parsing_method': 'pdfplumber',
                    'quality_score': self._calculate_quality_score(text)
                }
            })
        else:
            # Split into chunks
            text_chunks = self._split_text_into_chunks(text)
            for i, chunk_text in enumerate(text_chunks):
                chunks.append({
                    'text': chunk_text,
                    'title': f"{title} (Part {i+1})",
                    'page': page,
                    'metadata': {
                        'parsing_method': 'pdfplumber',
                        'part_number': i + 1,
                        'total_parts': len(text_chunks),
                        'quality_score': self._calculate_quality_score(chunk_text)
                    }
                })

        return chunks

    def _clean_text(self, text: str) -> str:
        """Clean text from artifacts."""
        # Remove volume headers (e.g., "Volume I: RISC-V Unprivileged ISA V20191213")
        text = re.sub(r'Volume\s+[IVX]+:\s*RISC-V[^V]*V\d{8}\s*', '', text, flags=re.IGNORECASE)
        text = re.sub(r'^\d+\s+Volume\s+[IVX]+:.*?$', '', text, flags=re.MULTILINE)

        # Remove document version artifacts
        text = re.sub(r'Document Version \d{8}\s*', '', text, flags=re.IGNORECASE)

        # Remove repeated ISA headers
        text = re.sub(r'RISC-V.*?ISA.*?V\d{8}\s*', '', text, flags=re.IGNORECASE)
        text = re.sub(r'The RISC-V Instruction Set Manual\s*', '', text, flags=re.IGNORECASE)

        # Remove figure/table references that are standalone
        text = re.sub(r'^(Figure|Table)\s+\d+\.\d+:.*?$', '', text, flags=re.MULTILINE)

        # Remove email addresses (often in contributor lists)
        text = re.sub(r'[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}', '', text)

        # Remove URLs
        text = re.sub(r'https?://[^\s]+', '', text)

        # Remove page numbers at start/end of lines
        text = re.sub(r'^\d{1,3}\s+', '', text, flags=re.MULTILINE)
        text = re.sub(r'\s+\d{1,3}$', '', text, flags=re.MULTILINE)

        # Remove excessive dots (TOC artifacts)
        text = re.sub(r'\.{3,}', '', text)

        # Remove standalone numbers (often page numbers or figure numbers)
        text = re.sub(r'^\s*\d+\s*$', '', text, flags=re.MULTILINE)

        # Clean up multiple spaces and newlines
        text = re.sub(r'\s{3,}', ' ', text)
        text = re.sub(r'\n{3,}', '\n\n', text)
        text = re.sub(r'[ \t]+', ' ', text)  # Normalize all whitespace

        # Remove common boilerplate phrases
|
| 262 |
+
text = re.sub(r'Contains Nonbinding Recommendations\s*', '', text, flags=re.IGNORECASE)
|
| 263 |
+
text = re.sub(r'Guidance for Industry and FDA Staff\s*', '', text, flags=re.IGNORECASE)
|
| 264 |
+
|
| 265 |
+
return text.strip()
|
| 266 |
+
|
| 267 |
+
def _split_text_into_chunks(self, text: str) -> List[str]:
|
| 268 |
+
"""Split text into chunks at sentence boundaries."""
|
| 269 |
+
sentences = re.split(r'(?<=[.!?])\s+', text)
|
| 270 |
+
chunks = []
|
| 271 |
+
current_chunk = []
|
| 272 |
+
current_size = 0
|
| 273 |
+
|
| 274 |
+
for sentence in sentences:
|
| 275 |
+
sentence_size = len(sentence)
|
| 276 |
+
|
| 277 |
+
if current_size + sentence_size > self.target_chunk_size and current_chunk:
|
| 278 |
+
chunks.append(' '.join(current_chunk))
|
| 279 |
+
current_chunk = [sentence]
|
| 280 |
+
current_size = sentence_size
|
| 281 |
+
else:
|
| 282 |
+
current_chunk.append(sentence)
|
| 283 |
+
current_size += sentence_size + 1
|
| 284 |
+
|
| 285 |
+
if current_chunk:
|
| 286 |
+
chunks.append(' '.join(current_chunk))
|
| 287 |
+
|
| 288 |
+
return chunks
|
| 289 |
+
|
| 290 |
+
def _calculate_quality_score(self, text: str) -> float:
|
| 291 |
+
"""Calculate quality score for chunk."""
|
| 292 |
+
score = 1.0
|
| 293 |
+
|
| 294 |
+
# Penalize very short or very long
|
| 295 |
+
if len(text) < self.min_chunk_size:
|
| 296 |
+
score *= 0.8
|
| 297 |
+
elif len(text) > self.max_chunk_size:
|
| 298 |
+
score *= 0.9
|
| 299 |
+
|
| 300 |
+
# Reward complete sentences
|
| 301 |
+
if text.strip().endswith(('.', '!', '?')):
|
| 302 |
+
score *= 1.1
|
| 303 |
+
|
| 304 |
+
# Reward technical content
|
| 305 |
+
technical_terms = ['risc', 'instruction', 'register', 'memory', 'processor']
|
| 306 |
+
term_count = sum(1 for term in technical_terms if term in text.lower())
|
| 307 |
+
score *= (1 + term_count * 0.05)
|
| 308 |
+
|
| 309 |
+
return min(score, 1.0)
|
| 310 |
+
|
| 311 |
+
def extract_with_page_coverage(self, pdf_path: Path, pymupdf_pages: List[Dict]) -> List[Dict]:
|
| 312 |
+
"""
|
| 313 |
+
Extract content ensuring ALL pages are covered using PyMuPDF page data.
|
| 314 |
+
|
| 315 |
+
Args:
|
| 316 |
+
pdf_path: Path to PDF file
|
| 317 |
+
pymupdf_pages: Page data from PyMuPDF with page numbers and text
|
| 318 |
+
|
| 319 |
+
Returns:
|
| 320 |
+
List of chunks covering ALL document pages
|
| 321 |
+
"""
|
| 322 |
+
chunks = []
|
| 323 |
+
chunk_id = 0
|
| 324 |
+
|
| 325 |
+
print(f"📄 Processing {len(pymupdf_pages)} pages with PDFPlumber quality extraction...")
|
| 326 |
+
|
| 327 |
+
with pdfplumber.open(str(pdf_path)) as pdf:
|
| 328 |
+
for pymupdf_page in pymupdf_pages:
|
| 329 |
+
page_num = pymupdf_page['page_number'] # 1-indexed from PyMuPDF
|
| 330 |
+
page_idx = page_num - 1 # Convert to 0-indexed for PDFPlumber
|
| 331 |
+
|
| 332 |
+
if page_idx < len(pdf.pages):
|
| 333 |
+
# Extract with PDFPlumber quality from this specific page
|
| 334 |
+
pdfplumber_page = pdf.pages[page_idx]
|
| 335 |
+
page_text = pdfplumber_page.extract_text()
|
| 336 |
+
|
| 337 |
+
if page_text and page_text.strip():
|
| 338 |
+
# Clean and chunk the page text
|
| 339 |
+
cleaned_text = self._clean_text(page_text)
|
| 340 |
+
|
| 341 |
+
if len(cleaned_text) >= 100: # Minimum meaningful content
|
| 342 |
+
# Create chunks from this page
|
| 343 |
+
page_chunks = self._create_page_chunks(
|
| 344 |
+
cleaned_text, page_num, chunk_id
|
| 345 |
+
)
|
| 346 |
+
chunks.extend(page_chunks)
|
| 347 |
+
chunk_id += len(page_chunks)
|
| 348 |
+
|
| 349 |
+
if len(chunks) % 50 == 0: # Progress indicator
|
| 350 |
+
print(f" Processed {page_num} pages, created {len(chunks)} chunks")
|
| 351 |
+
|
| 352 |
+
print(f"✅ Full coverage: {len(chunks)} chunks from {len(pymupdf_pages)} pages")
|
| 353 |
+
return chunks
|
| 354 |
+
|
| 355 |
+
def _create_page_chunks(self, page_text: str, page_num: int, start_chunk_id: int) -> List[Dict]:
|
| 356 |
+
"""Create properly sized chunks from a single page's content."""
|
| 357 |
+
# Clean and validate page text first
|
| 358 |
+
cleaned_text = self._ensure_complete_sentences(page_text)
|
| 359 |
+
|
| 360 |
+
if not cleaned_text or len(cleaned_text) < 50:
|
| 361 |
+
# Skip pages with insufficient content
|
| 362 |
+
return []
|
| 363 |
+
|
| 364 |
+
if len(cleaned_text) <= self.max_chunk_size:
|
| 365 |
+
# Single chunk for small pages
|
| 366 |
+
return [{
|
| 367 |
+
'text': cleaned_text,
|
| 368 |
+
'title': f"Page {page_num}",
|
| 369 |
+
'page': page_num,
|
| 370 |
+
'metadata': {
|
| 371 |
+
'parsing_method': 'pdfplumber_page_coverage',
|
| 372 |
+
'quality_score': self._calculate_quality_score(cleaned_text),
|
| 373 |
+
'full_page_coverage': True
|
| 374 |
+
}
|
| 375 |
+
}]
|
| 376 |
+
else:
|
| 377 |
+
# Split large pages into chunks with sentence boundaries
|
| 378 |
+
text_chunks = self._split_text_into_chunks(cleaned_text)
|
| 379 |
+
page_chunks = []
|
| 380 |
+
|
| 381 |
+
for i, chunk_text in enumerate(text_chunks):
|
| 382 |
+
# Ensure each chunk is complete
|
| 383 |
+
complete_chunk = self._ensure_complete_sentences(chunk_text)
|
| 384 |
+
|
| 385 |
+
if complete_chunk and len(complete_chunk) >= 100:
|
| 386 |
+
page_chunks.append({
|
| 387 |
+
'text': complete_chunk,
|
| 388 |
+
'title': f"Page {page_num} (Part {i+1})",
|
| 389 |
+
'page': page_num,
|
| 390 |
+
'metadata': {
|
| 391 |
+
'parsing_method': 'pdfplumber_page_coverage',
|
| 392 |
+
'part_number': i + 1,
|
| 393 |
+
'total_parts': len(text_chunks),
|
| 394 |
+
'quality_score': self._calculate_quality_score(complete_chunk),
|
| 395 |
+
'full_page_coverage': True
|
| 396 |
+
}
|
| 397 |
+
})
|
| 398 |
+
|
| 399 |
+
return page_chunks
|
| 400 |
+
|
| 401 |
+
def _ensure_complete_sentences(self, text: str) -> str:
|
| 402 |
+
"""Ensure text contains only complete sentences."""
|
| 403 |
+
text = text.strip()
|
| 404 |
+
if not text:
|
| 405 |
+
return ""
|
| 406 |
+
|
| 407 |
+
# Find last complete sentence
|
| 408 |
+
last_sentence_end = -1
|
| 409 |
+
for i, char in enumerate(reversed(text)):
|
| 410 |
+
if char in '.!?:':
|
| 411 |
+
last_sentence_end = len(text) - i
|
| 412 |
+
break
|
| 413 |
+
|
| 414 |
+
if last_sentence_end > 0:
|
| 415 |
+
# Return text up to last complete sentence
|
| 416 |
+
complete_text = text[:last_sentence_end].strip()
|
| 417 |
+
|
| 418 |
+
# Ensure it starts properly (capital letter or common starters)
|
| 419 |
+
if complete_text and (complete_text[0].isupper() or
|
| 420 |
+
complete_text.startswith(('The ', 'A ', 'An ', 'This ', 'RISC'))):
|
| 421 |
+
return complete_text
|
| 422 |
+
|
| 423 |
+
# If no complete sentences found, return empty
|
| 424 |
+
return ""
|
| 425 |
+
|
| 426 |
+
def parse_document(self, pdf_path: Path, pdf_data: Dict[str, Any] = None) -> List[Dict]:
|
| 427 |
+
"""
|
| 428 |
+
Parse document using PDFPlumber (required by HybridParser).
|
| 429 |
+
|
| 430 |
+
Args:
|
| 431 |
+
pdf_path: Path to PDF file
|
| 432 |
+
pdf_data: PyMuPDF page data to ensure full page coverage
|
| 433 |
+
|
| 434 |
+
Returns:
|
| 435 |
+
List of chunks with structure preservation across ALL pages
|
| 436 |
+
"""
|
| 437 |
+
if pdf_data and 'pages' in pdf_data:
|
| 438 |
+
# Use PyMuPDF page data to ensure full coverage
|
| 439 |
+
return self.extract_with_page_coverage(pdf_path, pdf_data['pages'])
|
| 440 |
+
else:
|
| 441 |
+
# Fallback to structure-based extraction
|
| 442 |
+
return self.extract_with_structure(pdf_path)
|
| 443 |
+
|
| 444 |
+
|
| 445 |
+
def parse_pdf_with_pdfplumber(pdf_path: Path, **kwargs) -> List[Dict]:
|
| 446 |
+
"""Main entry point for PDFPlumber parsing."""
|
| 447 |
+
parser = PDFPlumberParser(**kwargs)
|
| 448 |
+
chunks = parser.extract_with_structure(pdf_path)
|
| 449 |
+
|
| 450 |
+
print(f"PDFPlumber extracted {len(chunks)} chunks")
|
| 451 |
+
|
| 452 |
+
return chunks
|
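A minimal usage sketch of the entry point above (not part of the committed file). It assumes the repository root is on sys.path so the module is importable, and that a local PDF such as "riscv-spec.pdf" exists; both are illustrative assumptions, and any keyword arguments are simply forwarded to PDFPlumberParser.

    from pathlib import Path

    from shared_utils.document_processing.pdfplumber_parser import parse_pdf_with_pdfplumber

    # Hypothetical input document; any text-based PDF works here.
    chunks = parse_pdf_with_pdfplumber(Path("riscv-spec.pdf"))

    # Each chunk is a dict with 'text', 'title', 'page', and 'metadata' keys.
    for chunk in chunks[:3]:
        print(chunk['page'], chunk['title'], round(chunk['metadata']['quality_score'], 2))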
shared_utils/document_processing/toc_guided_parser.py
ADDED
@@ -0,0 +1,311 @@
#!/usr/bin/env python3
"""
TOC-Guided PDF Parser

Uses the Table of Contents to guide intelligent chunking that respects
document structure and hierarchy.

Author: Arthur Passuello
"""

import re
from typing import Dict, List, Optional, Tuple
from dataclasses import dataclass


@dataclass
class TOCEntry:
    """Represents a table of contents entry."""
    title: str
    page: int
    level: int  # 0 for chapters, 1 for sections, 2 for subsections
    parent: Optional[str] = None
    parent_title: Optional[str] = None  # Added for hybrid parser compatibility


class TOCGuidedParser:
    """Parser that uses TOC to create structure-aware chunks."""

    def __init__(self, target_chunk_size: int = 1400, min_chunk_size: int = 800,
                 max_chunk_size: int = 2000):
        """Initialize TOC-guided parser."""
        self.target_chunk_size = target_chunk_size
        self.min_chunk_size = min_chunk_size
        self.max_chunk_size = max_chunk_size

    def parse_toc(self, pages: List[Dict]) -> List[TOCEntry]:
        """Parse table of contents from pages."""
        toc_entries = []

        # Find TOC pages (usually early in document)
        toc_pages = []
        for i, page in enumerate(pages[:20]):  # Check first 20 pages
            page_text = page.get('text', '').lower()
            if 'contents' in page_text or 'table of contents' in page_text:
                toc_pages.append((i, page))

        if not toc_pages:
            print("No TOC found, using fallback structure detection")
            return self._detect_structure_without_toc(pages)

        # Parse TOC entries
        for page_idx, page in toc_pages:
            text = page.get('text', '')
            lines = text.split('\n')

            i = 0
            while i < len(lines):
                line = lines[i].strip()

                # Skip empty lines and TOC header
                if not line or 'contents' in line.lower():
                    i += 1
                    continue

                # Pattern 1: "1.1 Title .... 23"
                match1 = re.match(r'^(\d+(?:\.\d+)*)\s+(.+?)\s*\.{2,}\s*(\d+)$', line)
                if match1:
                    number, title, page_num = match1.groups()
                    level = len(number.split('.')) - 1
                    toc_entries.append(TOCEntry(
                        title=title.strip(),
                        page=int(page_num),
                        level=level
                    ))
                    i += 1
                    continue

                # Pattern 2: Multi-line format
                # "1.1"
                # "Title"
                # ". . . . 23"
                if re.match(r'^(\d+(?:\.\d+)*)$', line):
                    number = line
                    if i + 1 < len(lines):
                        title_line = lines[i + 1].strip()
                        if i + 2 < len(lines):
                            dots_line = lines[i + 2].strip()
                            page_match = re.search(r'(\d+)\s*$', dots_line)
                            if page_match and '.' in dots_line:
                                title = title_line
                                page_num = int(page_match.group(1))
                                level = len(number.split('.')) - 1
                                toc_entries.append(TOCEntry(
                                    title=title,
                                    page=page_num,
                                    level=level
                                ))
                                i += 3
                                continue

                # Pattern 3: "Chapter 1: Title ... 23"
                match3 = re.match(r'^(Chapter|Section|Part)\s+(\d+):?\s+(.+?)\s*\.{2,}\s*(\d+)$', line, re.IGNORECASE)
                if match3:
                    prefix, number, title, page_num = match3.groups()
                    level = 0 if prefix.lower() == 'chapter' else 1
                    toc_entries.append(TOCEntry(
                        title=f"{prefix} {number}: {title}",
                        page=int(page_num),
                        level=level
                    ))
                    i += 1
                    continue

                i += 1

        # Add parent relationships
        for i, entry in enumerate(toc_entries):
            if entry.level > 0:
                # Find parent (previous entry with lower level)
                for j in range(i - 1, -1, -1):
                    if toc_entries[j].level < entry.level:
                        entry.parent = toc_entries[j].title
                        entry.parent_title = toc_entries[j].title  # Set both for compatibility
                        break

        return toc_entries

    def _detect_structure_without_toc(self, pages: List[Dict]) -> List[TOCEntry]:
        """Fallback: detect structure from content patterns across ALL pages."""
        entries = []

        # Expanded patterns for better structure detection
        chapter_patterns = [
            re.compile(r'^(Chapter|CHAPTER)\s+(\d+|[IVX]+)(?:\s*[:\-]\s*(.+))?', re.MULTILINE),
            re.compile(r'^(\d+)\s+([A-Z][^.]*?)(?:\s*\.{2,}\s*\d+)?$', re.MULTILINE),  # "1 Introduction"
            re.compile(r'^([A-Z][A-Z\s]{10,})$', re.MULTILINE),  # ALL CAPS titles
        ]

        section_patterns = [
            re.compile(r'^(\d+\.\d+)\s+(.+?)(?:\s*\.{2,}\s*\d+)?$', re.MULTILINE),  # "1.1 Section"
            re.compile(r'^(\d+\.\d+\.\d+)\s+(.+?)(?:\s*\.{2,}\s*\d+)?$', re.MULTILINE),  # "1.1.1 Subsection"
        ]

        # Process ALL pages, not just first 20
        for i, page in enumerate(pages):
            text = page.get('text', '')
            if not text.strip():
                continue

            # Find chapters with various patterns
            for pattern in chapter_patterns:
                for match in pattern.finditer(text):
                    if len(match.groups()) >= 2:
                        if len(match.groups()) >= 3 and match.group(3):
                            title = match.group(3).strip()
                        else:
                            title = match.group(2).strip() if match.group(2) else f"Section {match.group(1)}"

                        # Skip very short or likely false positives
                        if len(title) >= 3 and not re.match(r'^\d+$', title):
                            entries.append(TOCEntry(
                                title=title,
                                page=i + 1,
                                level=0
                            ))

            # Find sections
            for pattern in section_patterns:
                for match in pattern.finditer(text):
                    section_num = match.group(1)
                    title = match.group(2).strip() if len(match.groups()) >= 2 else f"Section {section_num}"

                    # Determine level by number of dots
                    level = section_num.count('.')

                    # Skip very short titles or obvious artifacts
                    if len(title) >= 3 and not re.match(r'^\d+$', title):
                        entries.append(TOCEntry(
                            title=title,
                            page=i + 1,
                            level=level
                        ))

        # If still no entries found, create page-based entries for full coverage
        if not entries:
            print("No structure patterns found, creating page-based sections for full coverage")
            # Create sections every 10 pages to ensure full document coverage
            for i in range(0, len(pages), 10):
                start_page = i + 1
                end_page = min(i + 10, len(pages))
                title = f"Pages {start_page}-{end_page}"
                entries.append(TOCEntry(
                    title=title,
                    page=start_page,
                    level=0
                ))

        return entries

    def create_chunks_from_toc(self, pdf_data: Dict, toc_entries: List[TOCEntry]) -> List[Dict]:
        """Create chunks based on TOC structure."""
        chunks = []
        pages = pdf_data.get('pages', [])

        for i, entry in enumerate(toc_entries):
            # Determine page range for this entry
            start_page = entry.page - 1  # Convert to 0-indexed

            # Find end page (start of next entry at same or higher level)
            end_page = len(pages)
            for j in range(i + 1, len(toc_entries)):
                if toc_entries[j].level <= entry.level:
                    end_page = toc_entries[j].page - 1
                    break

            # Extract text for this section
            section_text = []
            for page_idx in range(max(0, start_page), min(end_page, len(pages))):
                page_text = pages[page_idx].get('text', '')
                if page_text.strip():
                    section_text.append(page_text)

            if not section_text:
                continue

            full_text = '\n\n'.join(section_text)

            # Create chunks from section text
            if len(full_text) <= self.max_chunk_size:
                # Single chunk for small sections
                chunks.append({
                    'text': full_text.strip(),
                    'title': entry.title,
                    'parent_title': entry.parent_title or entry.parent or '',
                    'level': entry.level,
                    'page': entry.page,
                    'context': f"From {entry.title}",
                    'metadata': {
                        'parsing_method': 'toc_guided',
                        'section_title': entry.title,
                        'hierarchy_level': entry.level
                    }
                })
            else:
                # Split large sections into chunks
                section_chunks = self._split_text_into_chunks(full_text)
                for j, chunk_text in enumerate(section_chunks):
                    chunks.append({
                        'text': chunk_text.strip(),
                        'title': f"{entry.title} (Part {j+1})",
                        'parent_title': entry.parent_title or entry.parent or '',
                        'level': entry.level,
                        'page': entry.page,
                        'context': f"Part {j+1} of {entry.title}",
                        'metadata': {
                            'parsing_method': 'toc_guided',
                            'section_title': entry.title,
                            'hierarchy_level': entry.level,
                            'part_number': j + 1,
                            'total_parts': len(section_chunks)
                        }
                    })

        return chunks

    def _split_text_into_chunks(self, text: str) -> List[str]:
        """Split text into chunks while preserving sentence boundaries."""
        sentences = re.split(r'(?<=[.!?])\s+', text)
        chunks = []
        current_chunk = []
        current_size = 0

        for sentence in sentences:
            sentence_size = len(sentence)

            if current_size + sentence_size > self.target_chunk_size and current_chunk:
                # Save current chunk
                chunks.append(' '.join(current_chunk))
                current_chunk = [sentence]
                current_size = sentence_size
            else:
                current_chunk.append(sentence)
                current_size += sentence_size + 1  # +1 for space

        if current_chunk:
            chunks.append(' '.join(current_chunk))

        return chunks


def parse_pdf_with_toc_guidance(pdf_data: Dict, **kwargs) -> List[Dict]:
    """Main entry point for TOC-guided parsing."""
    parser = TOCGuidedParser(**kwargs)

    # Parse TOC
    pages = pdf_data.get('pages', [])
    toc_entries = parser.parse_toc(pages)

    print(f"Found {len(toc_entries)} TOC entries")

    if not toc_entries:
        print("No TOC entries found, falling back to basic chunking")
        from .chunker import chunk_technical_text
        return chunk_technical_text(pdf_data.get('text', ''))

    # Create chunks based on TOC
    chunks = parser.create_chunks_from_toc(pdf_data, toc_entries)

    print(f"Created {len(chunks)} chunks from TOC structure")

    return chunks
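A minimal usage sketch of parse_pdf_with_toc_guidance (illustrative, not part of the file). It assumes pdf_data follows the shape the parser reads: a 'pages' list of dicts with a 'text' field, plus a flat 'text' fallback for the basic-chunking path; the sample page contents are invented for demonstration.

    from shared_utils.document_processing.toc_guided_parser import parse_pdf_with_toc_guidance

    # Illustrative stand-in for PyMuPDF output: one TOC page and one content page.
    pdf_data = {
        "pages": [
            {"text": "Contents\n1.1 Introduction .... 2"},
            {"text": "1.1 Introduction\nRISC-V is an open instruction set architecture."},
        ],
        "text": "1.1 Introduction\nRISC-V is an open instruction set architecture.",
    }

    chunks = parse_pdf_with_toc_guidance(pdf_data, target_chunk_size=1400)
    for chunk in chunks:
        print(chunk['title'], chunk['metadata']['hierarchy_level'])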
shared_utils/embeddings/__init__.py
ADDED
@@ -0,0 +1 @@
# Embeddings module
shared_utils/embeddings/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (169 Bytes)
shared_utils/embeddings/__pycache__/generator.cpython-312.pyc
ADDED
Binary file (3.02 kB)
shared_utils/embeddings/generator.py
ADDED
@@ -0,0 +1,84 @@
import numpy as np
import torch
from typing import List, Optional
from sentence_transformers import SentenceTransformer

# Global cache for embeddings
_embedding_cache = {}
_model_cache = {}


def generate_embeddings(
    texts: List[str],
    model_name: str = "sentence-transformers/multi-qa-MiniLM-L6-cos-v1",
    batch_size: int = 32,
    use_mps: bool = True,
) -> np.ndarray:
    """
    Generate embeddings for text chunks with caching.

    Args:
        texts: List of text chunks to embed
        model_name: SentenceTransformer model identifier
        batch_size: Processing batch size
        use_mps: Use Apple Silicon acceleration

    Returns:
        numpy array of shape (len(texts), embedding_dim)

    Performance Target:
        - 100 texts/second on M4-Pro
        - 384-dimensional embeddings
        - Memory usage <500MB
    """
    # Check cache for all texts
    cache_keys = [f"{model_name}:{text}" for text in texts]
    cached_embeddings = []
    texts_to_compute = []
    compute_indices = []

    for i, key in enumerate(cache_keys):
        if key in _embedding_cache:
            cached_embeddings.append((i, _embedding_cache[key]))
        else:
            texts_to_compute.append(texts[i])
            compute_indices.append(i)

    # Load model if needed
    if model_name not in _model_cache:
        model = SentenceTransformer(model_name)
        device = 'mps' if use_mps and torch.backends.mps.is_available() else 'cpu'
        model = model.to(device)
        model.eval()
        _model_cache[model_name] = model
    else:
        model = _model_cache[model_name]

    # Compute new embeddings
    if texts_to_compute:
        with torch.no_grad():
            new_embeddings = model.encode(
                texts_to_compute,
                batch_size=batch_size,
                convert_to_numpy=True,
                normalize_embeddings=False
            ).astype(np.float32)

        # Cache new embeddings
        for i, text in enumerate(texts_to_compute):
            key = f"{model_name}:{text}"
            _embedding_cache[key] = new_embeddings[i]

    # Reconstruct full embedding array
    result = np.zeros((len(texts), 384), dtype=np.float32)

    # Fill cached embeddings
    for idx, embedding in cached_embeddings:
        result[idx] = embedding

    # Fill newly computed embeddings
    if texts_to_compute:
        for i, original_idx in enumerate(compute_indices):
            result[original_idx] = new_embeddings[i]

    return result
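A brief usage sketch for generate_embeddings (illustrative only). It assumes sentence-transformers and torch are installed and that the default multi-qa-MiniLM model can be downloaded on first use; the example texts are made up.

    from shared_utils.embeddings.generator import generate_embeddings

    texts = [
        "RISC-V defines 32 general-purpose integer registers.",
        "FreeRTOS tasks share a single address space on most MCUs.",
    ]

    embeddings = generate_embeddings(texts, batch_size=16, use_mps=False)
    print(embeddings.shape)  # (2, 384) with the default MiniLM model
    print(embeddings.dtype)  # float32

    # A second call with the same texts is served from the module-level cache.
    cached = generate_embeddings(texts, use_mps=False)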
shared_utils/generation/__pycache__/adaptive_prompt_engine.cpython-312.pyc
ADDED
Binary file (19.9 kB)
shared_utils/generation/__pycache__/answer_generator.cpython-312.pyc
ADDED
Binary file (27.1 kB)
shared_utils/generation/__pycache__/chain_of_thought_engine.cpython-312.pyc
ADDED
Binary file (20.9 kB)
shared_utils/generation/__pycache__/hf_answer_generator.cpython-312.pyc
ADDED
Binary file (35.8 kB)
shared_utils/generation/__pycache__/inference_providers_generator.cpython-312.pyc
ADDED
Binary file (22.2 kB)
shared_utils/generation/__pycache__/ollama_answer_generator.cpython-312.pyc
ADDED
Binary file (32 kB)
shared_utils/generation/__pycache__/prompt_optimizer.cpython-312.pyc
ADDED
Binary file (28.1 kB)
shared_utils/generation/__pycache__/prompt_templates.cpython-312.pyc
ADDED
Binary file (21.6 kB)
shared_utils/generation/adaptive_prompt_engine.py
ADDED
@@ -0,0 +1,559 @@
"""
Adaptive Prompt Engine for Dynamic Context-Aware Prompt Optimization.

This module provides intelligent prompt adaptation based on context quality,
query complexity, and performance requirements.
"""

import logging
from typing import Dict, List, Optional, Tuple, Any
from dataclasses import dataclass
from enum import Enum
import numpy as np

from .prompt_templates import (
    QueryType,
    PromptTemplate,
    TechnicalPromptTemplates
)


class ContextQuality(Enum):
    """Context quality levels for adaptive prompting."""
    HIGH = "high"        # >0.8 relevance, low noise
    MEDIUM = "medium"    # 0.5-0.8 relevance, moderate noise
    LOW = "low"          # <0.5 relevance, high noise


class QueryComplexity(Enum):
    """Query complexity levels."""
    SIMPLE = "simple"      # Single concept, direct answer
    MODERATE = "moderate"  # Multiple concepts, structured answer
    COMPLEX = "complex"    # Multi-step reasoning, comprehensive answer


@dataclass
class ContextMetrics:
    """Metrics for evaluating context quality."""
    relevance_score: float
    noise_ratio: float
    chunk_count: int
    avg_chunk_length: int
    technical_density: float
    source_diversity: int


@dataclass
class AdaptivePromptConfig:
    """Configuration for adaptive prompt generation."""
    context_quality: ContextQuality
    query_complexity: QueryComplexity
    max_context_length: int
    prefer_concise: bool
    include_few_shot: bool
    enable_chain_of_thought: bool
    confidence_threshold: float


class AdaptivePromptEngine:
    """
    Intelligent prompt adaptation engine that optimizes prompts based on:
    - Context quality and relevance
    - Query complexity and type
    - Performance requirements
    - User preferences
    """

    def __init__(self):
        """Initialize the adaptive prompt engine."""
        self.logger = logging.getLogger(__name__)

        # Context quality thresholds
        self.high_quality_threshold = 0.8
        self.medium_quality_threshold = 0.5

        # Query complexity indicators
        self.complex_keywords = {
            "implementation": ["implement", "build", "create", "develop", "setup"],
            "comparison": ["compare", "difference", "versus", "vs", "better"],
            "analysis": ["analyze", "evaluate", "assess", "study", "examine"],
            "multi_step": ["process", "procedure", "steps", "how to", "guide"]
        }

        # Length optimization thresholds
        self.token_limits = {
            "concise": 512,
            "standard": 1024,
            "detailed": 2048,
            "comprehensive": 4096
        }

    def analyze_context_quality(self, chunks: List[Dict[str, Any]]) -> ContextMetrics:
        """
        Analyze the quality of retrieved context chunks.

        Args:
            chunks: List of context chunks with metadata

        Returns:
            ContextMetrics with quality assessment
        """
        if not chunks:
            return ContextMetrics(
                relevance_score=0.0,
                noise_ratio=1.0,
                chunk_count=0,
                avg_chunk_length=0,
                technical_density=0.0,
                source_diversity=0
            )

        # Calculate relevance score (using confidence scores if available)
        relevance_scores = []
        for chunk in chunks:
            # Use confidence score if available, otherwise use a heuristic
            if 'confidence' in chunk:
                relevance_scores.append(chunk['confidence'])
            elif 'score' in chunk:
                relevance_scores.append(chunk['score'])
            else:
                # Heuristic: longer chunks with technical terms are more relevant
                content = chunk.get('content', chunk.get('text', ''))
                tech_terms = self._count_technical_terms(content)
                relevance_scores.append(min(tech_terms / 10.0, 1.0))

        avg_relevance = np.mean(relevance_scores) if relevance_scores else 0.0

        # Calculate noise ratio (fragments, repetitive content)
        noise_count = 0
        total_chunks = len(chunks)

        for chunk in chunks:
            content = chunk.get('content', chunk.get('text', ''))
            if self._is_noisy_chunk(content):
                noise_count += 1

        noise_ratio = noise_count / total_chunks if total_chunks > 0 else 0.0

        # Calculate average chunk length
        chunk_lengths = []
        for chunk in chunks:
            content = chunk.get('content', chunk.get('text', ''))
            chunk_lengths.append(len(content))

        avg_chunk_length = int(np.mean(chunk_lengths)) if chunk_lengths else 0

        # Calculate technical density
        technical_density = self._calculate_technical_density(chunks)

        # Calculate source diversity
        sources = set()
        for chunk in chunks:
            source = chunk.get('metadata', {}).get('source', 'unknown')
            sources.add(source)

        source_diversity = len(sources)

        return ContextMetrics(
            relevance_score=avg_relevance,
            noise_ratio=noise_ratio,
            chunk_count=len(chunks),
            avg_chunk_length=avg_chunk_length,
            technical_density=technical_density,
            source_diversity=source_diversity
        )

    def determine_query_complexity(self, query: str) -> QueryComplexity:
        """
        Determine the complexity level of a query.

        Args:
            query: User's question

        Returns:
            QueryComplexity level
        """
        query_lower = query.lower()
        complexity_score = 0

        # Check for complex keywords
        for category, keywords in self.complex_keywords.items():
            if any(keyword in query_lower for keyword in keywords):
                complexity_score += 1

        # Check for multiple questions or concepts
        if '?' in query[:-1]:  # Multiple question marks (excluding the last one)
            complexity_score += 1

        if any(word in query_lower for word in ["and", "or", "also", "additionally", "furthermore"]):
            complexity_score += 1

        # Check query length
        word_count = len(query.split())
        if word_count > 20:
            complexity_score += 1
        elif word_count > 10:
            complexity_score += 0.5

        # Determine complexity level
        if complexity_score >= 2:
            return QueryComplexity.COMPLEX
        elif complexity_score >= 1:
            return QueryComplexity.MODERATE
        else:
            return QueryComplexity.SIMPLE

    def generate_adaptive_config(
        self,
        query: str,
        context_chunks: List[Dict[str, Any]],
        max_tokens: int = 2048,
        prefer_speed: bool = False
    ) -> AdaptivePromptConfig:
        """
        Generate adaptive prompt configuration based on context and query analysis.

        Args:
            query: User's question
            context_chunks: Retrieved context chunks
            max_tokens: Maximum token limit
            prefer_speed: Whether to optimize for speed over quality

        Returns:
            AdaptivePromptConfig with optimized settings
        """
        # Analyze context quality
        context_metrics = self.analyze_context_quality(context_chunks)

        # Determine context quality level
        if context_metrics.relevance_score >= self.high_quality_threshold:
            context_quality = ContextQuality.HIGH
        elif context_metrics.relevance_score >= self.medium_quality_threshold:
            context_quality = ContextQuality.MEDIUM
        else:
            context_quality = ContextQuality.LOW

        # Determine query complexity
        query_complexity = self.determine_query_complexity(query)

        # Adapt configuration based on analysis
        config = AdaptivePromptConfig(
            context_quality=context_quality,
            query_complexity=query_complexity,
            max_context_length=max_tokens,
            prefer_concise=prefer_speed,
            include_few_shot=self._should_include_few_shot(context_quality, query_complexity),
            enable_chain_of_thought=self._should_enable_cot(query_complexity),
            confidence_threshold=self._get_confidence_threshold(context_quality)
        )

        return config

    def create_adaptive_prompt(
        self,
        query: str,
        context_chunks: List[Dict[str, Any]],
        config: Optional[AdaptivePromptConfig] = None
    ) -> Dict[str, str]:
        """
        Create an adaptive prompt optimized for the specific query and context.

        Args:
            query: User's question
            context_chunks: Retrieved context chunks
            config: Optional configuration (auto-generated if None)

        Returns:
            Dict with optimized 'system' and 'user' prompts
        """
        if config is None:
            config = self.generate_adaptive_config(query, context_chunks)

        # Get base template
        query_type = TechnicalPromptTemplates.detect_query_type(query)
        base_template = TechnicalPromptTemplates.get_template_for_query(query)

        # Adapt template based on configuration
        adapted_template = self._adapt_template(base_template, config)

        # Format context with optimization
        formatted_context = self._format_context_adaptive(context_chunks, config)

        # Create prompt with adaptive formatting
        prompt = TechnicalPromptTemplates.format_prompt_with_template(
            query=query,
            context=formatted_context,
            template=adapted_template,
            include_few_shot=config.include_few_shot
        )

        # Add chain-of-thought if enabled
        if config.enable_chain_of_thought:
            prompt = self._add_chain_of_thought(prompt, query_type)

        return prompt

    def _adapt_template(
        self,
        base_template: PromptTemplate,
        config: AdaptivePromptConfig
    ) -> PromptTemplate:
        """
        Adapt a base template based on configuration.

        Args:
            base_template: Base prompt template
            config: Adaptive configuration

        Returns:
            Adapted PromptTemplate
        """
        # Modify system prompt based on context quality
        system_prompt = base_template.system_prompt

        if config.context_quality == ContextQuality.LOW:
            system_prompt += """

IMPORTANT: The provided context may have limited relevance. Focus on:
- Only use information that directly relates to the question
- Clearly state if information is insufficient
- Avoid making assumptions beyond the provided context
- Be explicit about confidence levels"""

        elif config.context_quality == ContextQuality.HIGH:
            system_prompt += """

CONTEXT QUALITY: High-quality, relevant context is provided. You can:
- Provide comprehensive, detailed answers
- Make reasonable inferences from the context
- Include related technical details and examples
- Reference multiple sources confidently"""

        # Modify answer guidelines based on complexity and preferences
        answer_guidelines = base_template.answer_guidelines

        if config.prefer_concise:
            answer_guidelines += "\n\nResponse style: Be concise and focus on essential information. Aim for clarity over comprehensiveness."

        if config.query_complexity == QueryComplexity.COMPLEX:
            answer_guidelines += "\n\nComplex query handling: Break down your answer into clear sections. Use numbered steps for procedures."

        return PromptTemplate(
            system_prompt=system_prompt,
            context_format=base_template.context_format,
            query_format=base_template.query_format,
            answer_guidelines=answer_guidelines,
            few_shot_examples=base_template.few_shot_examples
        )

    def _format_context_adaptive(
        self,
        chunks: List[Dict[str, Any]],
        config: AdaptivePromptConfig
    ) -> str:
        """
        Format context chunks with adaptive optimization.

        Args:
            chunks: Context chunks to format
            config: Adaptive configuration

        Returns:
            Formatted context string
        """
        if not chunks:
            return "No relevant context available."

        # Filter chunks based on confidence if low quality context
        filtered_chunks = chunks
        if config.context_quality == ContextQuality.LOW:
            filtered_chunks = [
                chunk for chunk in chunks
                if self._meets_confidence_threshold(chunk, config.confidence_threshold)
            ]

        # Limit context length if needed
        if config.prefer_concise:
            filtered_chunks = filtered_chunks[:3]  # Limit to top 3 chunks

        # Format chunks
        context_parts = []
        for i, chunk in enumerate(filtered_chunks):
            chunk_text = chunk.get('content', chunk.get('text', ''))

            # Truncate if too long and prefer_concise is True
            if config.prefer_concise and len(chunk_text) > 800:
                chunk_text = chunk_text[:800] + "..."

            metadata = chunk.get('metadata', {})
            page_num = metadata.get('page_number', 'unknown')
            source = metadata.get('source', 'unknown')

            context_parts.append(
                f"[chunk_{i+1}] (Page {page_num} from {source}):\n{chunk_text}"
            )

        return "\n\n---\n\n".join(context_parts)

    def _add_chain_of_thought(
        self,
        prompt: Dict[str, str],
        query_type: QueryType
    ) -> Dict[str, str]:
        """
        Add chain-of-thought reasoning to the prompt.

        Args:
            prompt: Base prompt dictionary
            query_type: Type of query

        Returns:
            Enhanced prompt with chain-of-thought
        """
        cot_addition = """

Before providing your final answer, think through this step-by-step:

1. What is the user specifically asking for?
2. What relevant information is available in the context?
3. How should I structure my response for maximum clarity?
4. Are there any important caveats or limitations to mention?

Step-by-step reasoning:"""

        prompt["user"] = prompt["user"] + cot_addition

        return prompt

    def _should_include_few_shot(
        self,
        context_quality: ContextQuality,
        query_complexity: QueryComplexity
    ) -> bool:
        """Determine if few-shot examples should be included."""
        # Include few-shot for complex queries or when context quality is low
        if query_complexity == QueryComplexity.COMPLEX:
            return True
        if context_quality == ContextQuality.LOW:
            return True
        return False

    def _should_enable_cot(self, query_complexity: QueryComplexity) -> bool:
        """Determine if chain-of-thought should be enabled."""
        return query_complexity == QueryComplexity.COMPLEX

    def _get_confidence_threshold(self, context_quality: ContextQuality) -> float:
        """Get confidence threshold based on context quality."""
        thresholds = {
            ContextQuality.HIGH: 0.3,
            ContextQuality.MEDIUM: 0.5,
            ContextQuality.LOW: 0.7
        }
        return thresholds[context_quality]

    def _count_technical_terms(self, text: str) -> int:
        """Count technical terms in text."""
        technical_terms = [
            "risc-v", "riscv", "cpu", "gpu", "mcu", "interrupt", "register",
            "memory", "cache", "pipeline", "instruction", "assembly", "compiler",
            "embedded", "freertos", "rtos", "gpio", "uart", "spi", "i2c",
            "adc", "dac", "timer", "pwm", "dma", "firmware", "bootloader",
            "ai", "ml", "neural", "transformer", "attention", "embedding"
        ]

        text_lower = text.lower()
        count = 0
        for term in technical_terms:
            count += text_lower.count(term)

        return count

    def _is_noisy_chunk(self, content: str) -> bool:
        """Determine if a chunk is noisy (low quality)."""
        # Check for common noise indicators
        noise_indicators = [
            "table of contents",
            "copyright",
            "creative commons",
            "license",
            "all rights reserved",
            "terms of use",
            "privacy policy"
        ]

        content_lower = content.lower()

        # Check for noise indicators
        for indicator in noise_indicators:
            if indicator in content_lower:
                return True

        # Check for very short fragments
        if len(content) < 100:
            return True

        # Check for repetitive content
        words = content.split()
        if len(set(words)) < len(words) * 0.3:  # Less than 30% unique words
            return True

        return False

    def _calculate_technical_density(self, chunks: List[Dict[str, Any]]) -> float:
        """Calculate technical density of chunks."""
        if not chunks:
            return 0.0

        total_terms = 0
        total_words = 0

        for chunk in chunks:
            content = chunk.get('content', chunk.get('text', ''))
            words = content.split()
            total_words += len(words)
            total_terms += self._count_technical_terms(content)

        return (total_terms / total_words) if total_words > 0 else 0.0

    def _meets_confidence_threshold(
        self,
        chunk: Dict[str, Any],
        threshold: float
    ) -> bool:
        """Check if chunk meets confidence threshold."""
        confidence = chunk.get('confidence', chunk.get('score', 0.5))
        return confidence >= threshold


# Example usage
if __name__ == "__main__":
    # Initialize engine
    engine = AdaptivePromptEngine()

    # Example context chunks
    example_chunks = [
        {
            "content": "RISC-V is an open-source instruction set architecture...",
            "metadata": {"page_number": 1, "source": "riscv-spec.pdf"},
            "confidence": 0.9
        },
        {
            "content": "The RISC-V processor supports 32-bit and 64-bit implementations...",
            "metadata": {"page_number": 2, "source": "riscv-spec.pdf"},
            "confidence": 0.8
        }
    ]

    # Example queries
    simple_query = "What is RISC-V?"
    complex_query = "How do I implement a complete interrupt handling system in RISC-V with nested interrupts and priority management?"

    # Generate adaptive prompts
    simple_config = engine.generate_adaptive_config(simple_query, example_chunks)
    complex_config = engine.generate_adaptive_config(complex_query, example_chunks)

    print(f"Simple query complexity: {simple_config.query_complexity}")
    print(f"Complex query complexity: {complex_config.query_complexity}")
    print(f"Context quality: {simple_config.context_quality}")
    print(f"Few-shot enabled: {complex_config.include_few_shot}")
    print(f"Chain-of-thought enabled: {complex_config.enable_chain_of_thought}")
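A short sketch of the full prompt path through create_adaptive_prompt, complementing the __main__ block above, which only exercises generate_adaptive_config. It is illustrative, assumes the package root is on sys.path, and relies on the companion prompt_templates module being importable, since the engine delegates template selection and formatting to TechnicalPromptTemplates; the chunk content and query are invented for the example.

    from shared_utils.generation.adaptive_prompt_engine import AdaptivePromptEngine

    engine = AdaptivePromptEngine()

    # Illustrative context chunk, shaped like the example_chunks in the file above.
    chunks = [
        {
            "content": "The mtime and mtimecmp registers drive the RISC-V machine timer interrupt.",
            "metadata": {"page_number": 42, "source": "riscv-priv-spec.pdf"},
            "confidence": 0.85,
        }
    ]

    prompt = engine.create_adaptive_prompt(
        query="How do I configure a periodic timer interrupt on RISC-V?",
        context_chunks=chunks,
    )

    # The result is a dict with 'system' and 'user' entries ready for an LLM call.
    print(prompt["system"][:200])
    print(prompt["user"][:200])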
shared_utils/generation/answer_generator.py
ADDED
@@ -0,0 +1,703 @@
| 1 |
+
"""
|
| 2 |
+
Answer generation module using Ollama for local LLM inference.
|
| 3 |
+
|
| 4 |
+
This module provides answer generation with citation support for RAG systems,
|
| 5 |
+
optimized for technical documentation Q&A on Apple Silicon.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import json
|
| 9 |
+
import logging
|
| 10 |
+
from dataclasses import dataclass
|
| 11 |
+
from typing import List, Dict, Any, Optional, Generator, Tuple
|
| 12 |
+
import ollama
|
| 13 |
+
from datetime import datetime
|
| 14 |
+
import re
|
| 15 |
+
from pathlib import Path
|
| 16 |
+
import sys
|
| 17 |
+
|
| 18 |
+
# Import calibration framework
|
| 19 |
+
try:
|
| 20 |
+
from src.confidence_calibration import ConfidenceCalibrator
|
| 21 |
+
except ImportError:
|
| 22 |
+
# Fallback - disable calibration for deployment
|
| 23 |
+
ConfidenceCalibrator = None
|
| 24 |
+
|
| 25 |
+
logger = logging.getLogger(__name__)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
@dataclass
|
| 29 |
+
class Citation:
|
| 30 |
+
"""Represents a citation to a source document chunk."""
|
| 31 |
+
chunk_id: str
|
| 32 |
+
page_number: int
|
| 33 |
+
source_file: str
|
| 34 |
+
relevance_score: float
|
| 35 |
+
text_snippet: str
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
@dataclass
|
| 39 |
+
class GeneratedAnswer:
|
| 40 |
+
"""Represents a generated answer with citations."""
|
| 41 |
+
answer: str
|
| 42 |
+
citations: List[Citation]
|
| 43 |
+
confidence_score: float
|
| 44 |
+
generation_time: float
|
| 45 |
+
model_used: str
|
| 46 |
+
context_used: List[Dict[str, Any]]
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class AnswerGenerator:
|
| 50 |
+
"""
|
| 51 |
+
Generates answers using local LLMs via Ollama with citation support.
|
| 52 |
+
|
| 53 |
+
Optimized for technical documentation Q&A with:
|
| 54 |
+
- Streaming response support
|
| 55 |
+
- Citation extraction and formatting
|
| 56 |
+
- Confidence scoring
|
| 57 |
+
- Fallback model support
|
| 58 |
+
"""
|
| 59 |
+
|
| 60 |
+
def __init__(
|
| 61 |
+
self,
|
| 62 |
+
primary_model: str = "llama3.2:3b",
|
| 63 |
+
fallback_model: str = "mistral:latest",
|
| 64 |
+
temperature: float = 0.3,
|
| 65 |
+
max_tokens: int = 1024,
|
| 66 |
+
stream: bool = True,
|
| 67 |
+
enable_calibration: bool = True
|
| 68 |
+
):
|
| 69 |
+
"""
|
| 70 |
+
Initialize the answer generator.
|
| 71 |
+
|
| 72 |
+
Args:
|
| 73 |
+
primary_model: Primary Ollama model to use
|
| 74 |
+
fallback_model: Fallback model for complex queries
|
| 75 |
+
temperature: Generation temperature (0.0-1.0)
|
| 76 |
+
max_tokens: Maximum tokens to generate
|
| 77 |
+
stream: Whether to stream responses
|
| 78 |
+
enable_calibration: Whether to enable confidence calibration
|
| 79 |
+
"""
|
| 80 |
+
self.primary_model = primary_model
|
| 81 |
+
self.fallback_model = fallback_model
|
| 82 |
+
self.temperature = temperature
|
| 83 |
+
self.max_tokens = max_tokens
|
| 84 |
+
self.stream = stream
|
| 85 |
+
self.client = ollama.Client()
|
| 86 |
+
|
| 87 |
+
# Initialize confidence calibration
|
| 88 |
+
self.enable_calibration = enable_calibration
|
| 89 |
+
self.calibrator = None
|
| 90 |
+
if enable_calibration and ConfidenceCalibrator is not None:
|
| 91 |
+
try:
|
| 92 |
+
self.calibrator = ConfidenceCalibrator()
|
| 93 |
+
logger.info("Confidence calibration enabled")
|
| 94 |
+
except Exception as e:
|
| 95 |
+
logger.warning(f"Failed to initialize calibration: {e}")
|
| 96 |
+
self.enable_calibration = False
|
| 97 |
+
elif enable_calibration and ConfidenceCalibrator is None:
|
| 98 |
+
logger.warning("Calibration requested but ConfidenceCalibrator not available - disabling")
|
| 99 |
+
self.enable_calibration = False
|
| 100 |
+
|
| 101 |
+
# Verify models are available
|
| 102 |
+
self._verify_models()
|
| 103 |
+
|
| 104 |
+
def _verify_models(self) -> None:
|
| 105 |
+
"""Verify that required models are available."""
|
| 106 |
+
try:
|
| 107 |
+
model_list = self.client.list()
|
| 108 |
+
available_models = []
|
| 109 |
+
|
| 110 |
+
# Handle Ollama's ListResponse object
|
| 111 |
+
if hasattr(model_list, 'models'):
|
| 112 |
+
for model in model_list.models:
|
| 113 |
+
if hasattr(model, 'model'):
|
| 114 |
+
available_models.append(model.model)
|
| 115 |
+
elif isinstance(model, dict) and 'model' in model:
|
| 116 |
+
available_models.append(model['model'])
|
| 117 |
+
|
| 118 |
+
if self.primary_model not in available_models:
|
| 119 |
+
logger.warning(f"Primary model {self.primary_model} not found. Available models: {available_models}")
|
| 120 |
+
raise ValueError(f"Model {self.primary_model} not available. Please run: ollama pull {self.primary_model}")
|
| 121 |
+
|
| 122 |
+
if self.fallback_model not in available_models:
|
| 123 |
+
logger.warning(f"Fallback model {self.fallback_model} not found in: {available_models}")
|
| 124 |
+
|
| 125 |
+
except Exception as e:
|
| 126 |
+
logger.error(f"Error verifying models: {e}")
|
| 127 |
+
raise
|
| 128 |
+
|
| 129 |
+
def _create_system_prompt(self) -> str:
|
| 130 |
+
"""Create system prompt for technical documentation Q&A."""
|
| 131 |
+
return """You are a technical documentation assistant that provides clear, accurate answers based on the provided context.
|
| 132 |
+
|
| 133 |
+
CORE PRINCIPLES:
|
| 134 |
+
1. ANSWER DIRECTLY: If context contains the answer, provide it clearly and confidently
|
| 135 |
+
2. BE CONCISE: Keep responses focused and avoid unnecessary uncertainty language
|
| 136 |
+
3. CITE ACCURATELY: Use [chunk_X] citations for every fact from context
|
| 137 |
+
|
| 138 |
+
RESPONSE GUIDELINES:
|
| 139 |
+
- If context has sufficient information → Answer directly and confidently
|
| 140 |
+
- If context has partial information → Answer what's available, note what's missing briefly
|
| 141 |
+
- If context is irrelevant → Brief refusal: "This information isn't available in the provided documents"
|
| 142 |
+
|
| 143 |
+
CITATION FORMAT:
|
| 144 |
+
- Use [chunk_1], [chunk_2] etc. for all facts from context
|
| 145 |
+
- Example: "According to [chunk_1], RISC-V is an open-source architecture."
|
| 146 |
+
|
| 147 |
+
WHAT TO AVOID:
|
| 148 |
+
- Do NOT add details not in context
|
| 149 |
+
- Do NOT second-guess yourself if context is clear
|
| 150 |
+
- Do NOT use phrases like "does not contain sufficient information" when context clearly answers the question
|
| 151 |
+
- Do NOT be overly cautious when context is adequate
|
| 152 |
+
|
| 153 |
+
Be direct, confident, and accurate. If the context answers the question, provide that answer clearly."""
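The system prompt above is paired with a per-query user prompt and sent through the Ollama chat API in generate() further down. Stripped to its essentials, the call looks like the sketch below; it assumes the ollama Python client, a locally running Ollama server, and that llama3.2:3b has already been pulled:

import ollama

client = ollama.Client()
response = client.chat(
    model="llama3.2:3b",
    messages=[
        {"role": "system", "content": "You are a technical documentation assistant..."},
        {"role": "user", "content": "Context:\n[chunk_1] (Page 1 from riscv-spec.pdf): ...\n\nQuestion: What is RISC-V?\n\nAnswer:"},
    ],
    options={"temperature": 0.3, "num_predict": 300},
    stream=False,
)
print(response['message']['content'])  # raw answer, still containing [chunk_X] markers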
|
| 154 |
+
|
| 155 |
+
def _format_context(self, chunks: List[Dict[str, Any]]) -> str:
|
| 156 |
+
"""
|
| 157 |
+
Format retrieved chunks into context for the LLM.
|
| 158 |
+
|
| 159 |
+
Args:
|
| 160 |
+
chunks: List of retrieved chunks with metadata
|
| 161 |
+
|
| 162 |
+
Returns:
|
| 163 |
+
Formatted context string
|
| 164 |
+
"""
|
| 165 |
+
context_parts = []
|
| 166 |
+
|
| 167 |
+
for i, chunk in enumerate(chunks):
|
| 168 |
+
chunk_text = chunk.get('content', chunk.get('text', ''))
|
| 169 |
+
page_num = chunk.get('metadata', {}).get('page_number', 'unknown')
|
| 170 |
+
source = chunk.get('metadata', {}).get('source', 'unknown')
|
| 171 |
+
|
| 172 |
+
context_parts.append(
|
| 173 |
+
f"[chunk_{i+1}] (Page {page_num} from {source}):\n{chunk_text}\n"
|
| 174 |
+
)
|
| 175 |
+
|
| 176 |
+
return "\n---\n".join(context_parts)
|
| 177 |
+
|
| 178 |
+
def _extract_citations(self, answer: str, chunks: List[Dict[str, Any]]) -> Tuple[str, List[Citation]]:
|
| 179 |
+
"""
|
| 180 |
+
Extract citations from the generated answer and integrate them naturally.
|
| 181 |
+
|
| 182 |
+
Args:
|
| 183 |
+
answer: Generated answer with [chunk_X] citations
|
| 184 |
+
chunks: Original chunks used for context
|
| 185 |
+
|
| 186 |
+
Returns:
|
| 187 |
+
Tuple of (natural_answer, citations)
|
| 188 |
+
"""
|
| 189 |
+
citations = []
|
| 190 |
+
citation_pattern = r'\[chunk_(\d+)\]'
|
| 191 |
+
|
| 192 |
+
cited_chunks = set()
|
| 193 |
+
|
| 194 |
+
# Find [chunk_X] citations and collect cited chunks
|
| 195 |
+
matches = re.finditer(citation_pattern, answer)
|
| 196 |
+
for match in matches:
|
| 197 |
+
chunk_idx = int(match.group(1)) - 1 # Convert to 0-based index
|
| 198 |
+
if 0 <= chunk_idx < len(chunks):
|
| 199 |
+
cited_chunks.add(chunk_idx)
|
| 200 |
+
|
| 201 |
+
# Create Citation objects for each cited chunk
|
| 202 |
+
chunk_to_source = {}
|
| 203 |
+
for idx in cited_chunks:
|
| 204 |
+
chunk = chunks[idx]
|
| 205 |
+
citation = Citation(
|
| 206 |
+
chunk_id=chunk.get('id', f'chunk_{idx}'),
|
| 207 |
+
page_number=chunk.get('metadata', {}).get('page_number', 0),
|
| 208 |
+
source_file=chunk.get('metadata', {}).get('source', 'unknown'),
|
| 209 |
+
relevance_score=chunk.get('score', 0.0),
|
| 210 |
+
text_snippet=chunk.get('content', chunk.get('text', ''))[:200] + '...'
|
| 211 |
+
)
|
| 212 |
+
citations.append(citation)
|
| 213 |
+
|
| 214 |
+
# Map chunk reference to natural source name
|
| 215 |
+
source_name = chunk.get('metadata', {}).get('source', 'unknown')
|
| 216 |
+
if source_name != 'unknown':
|
| 217 |
+
# Use just the filename without extension for natural reference
|
| 218 |
+
natural_name = Path(source_name).stem.replace('-', ' ').replace('_', ' ')
|
| 219 |
+
chunk_to_source[f'[chunk_{idx+1}]'] = f"the {natural_name} documentation"
|
| 220 |
+
else:
|
| 221 |
+
chunk_to_source[f'[chunk_{idx+1}]'] = "the documentation"
|
| 222 |
+
|
| 223 |
+
# Replace [chunk_X] with natural references instead of removing them
|
| 224 |
+
natural_answer = answer
|
| 225 |
+
for chunk_ref, natural_ref in chunk_to_source.items():
|
| 226 |
+
natural_answer = natural_answer.replace(chunk_ref, natural_ref)
|
| 227 |
+
|
| 228 |
+
# Clean up any remaining unreferenced citations (fallback)
|
| 229 |
+
natural_answer = re.sub(r'\[chunk_\d+\]', 'the documentation', natural_answer)
|
| 230 |
+
|
| 231 |
+
# Clean up multiple spaces and formatting
|
| 232 |
+
natural_answer = re.sub(r'\s+', ' ', natural_answer).strip()
|
| 233 |
+
|
| 234 |
+
return natural_answer, citations
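For anyone who wants to test the rewriting idea in isolation, here is a compressed standalone sketch of the same [chunk_N]-to-natural-name substitution, using re.sub with a callback instead of the per-chunk replace loop above (example data only):

import re
from pathlib import Path

answer = "According to [chunk_1], RISC-V is an open-source ISA [chunk_2]."
sources = {1: "riscv-spec.pdf", 2: "riscv-spec.pdf"}  # chunk index -> source file

def naturalize(match: re.Match) -> str:
    source = sources.get(int(match.group(1)))
    if source is None:
        return "the documentation"  # fallback for unreferenced chunks
    stem = Path(source).stem.replace('-', ' ').replace('_', ' ')
    return f"the {stem} documentation"

natural = re.sub(r'\[chunk_(\d+)\]', naturalize, answer)
natural = re.sub(r'\s+', ' ', natural).strip()
print(natural)
# According to the riscv spec documentation, RISC-V is an open-source ISA the riscv spec documentation.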
|
| 235 |
+
|
| 236 |
+
def _calculate_confidence(self, answer: str, citations: List[Citation], chunks: List[Dict[str, Any]]) -> float:
|
| 237 |
+
"""
|
| 238 |
+
Calculate confidence score for the generated answer with improved calibration.
|
| 239 |
+
|
| 240 |
+
Args:
|
| 241 |
+
answer: Generated answer
|
| 242 |
+
citations: Extracted citations
|
| 243 |
+
chunks: Retrieved chunks
|
| 244 |
+
|
| 245 |
+
Returns:
|
| 246 |
+
Confidence score (0.0-1.0)
|
| 247 |
+
"""
|
| 248 |
+
# Check if no chunks were provided first
|
| 249 |
+
if not chunks:
|
| 250 |
+
return 0.05 # No context = very low confidence
|
| 251 |
+
|
| 252 |
+
# Assess context quality to determine base confidence
|
| 253 |
+
scores = [chunk.get('score', 0) for chunk in chunks]
|
| 254 |
+
max_relevance = max(scores) if scores else 0
|
| 255 |
+
avg_relevance = sum(scores) / len(scores) if scores else 0
|
| 256 |
+
|
| 257 |
+
# Dynamic base confidence based on context quality
|
| 258 |
+
if max_relevance >= 0.8:
|
| 259 |
+
confidence = 0.6 # High-quality context starts high
|
| 260 |
+
elif max_relevance >= 0.6:
|
| 261 |
+
confidence = 0.4 # Good context starts moderately
|
| 262 |
+
elif max_relevance >= 0.4:
|
| 263 |
+
confidence = 0.2 # Fair context starts low
|
| 264 |
+
else:
|
| 265 |
+
confidence = 0.05 # Poor context starts very low
|
| 266 |
+
|
| 267 |
+
# Strong uncertainty and explicit refusal indicators
|
| 268 |
+
strong_uncertainty_phrases = [
|
| 269 |
+
"does not contain sufficient information",
|
| 270 |
+
"context does not provide",
|
| 271 |
+
"insufficient information",
|
| 272 |
+
"cannot determine",
|
| 273 |
+
"refuse to answer",
|
| 274 |
+
"cannot answer",
|
| 275 |
+
"does not contain relevant",
|
| 276 |
+
"no relevant context",
|
| 277 |
+
"missing from the provided context"
|
| 278 |
+
]
|
| 279 |
+
|
| 280 |
+
# Weak uncertainty phrases that might be in nuanced but correct answers
|
| 281 |
+
weak_uncertainty_phrases = [
|
| 282 |
+
"unclear",
|
| 283 |
+
"conflicting",
|
| 284 |
+
"not specified",
|
| 285 |
+
"questionable",
|
| 286 |
+
"not contained",
|
| 287 |
+
"no mention",
|
| 288 |
+
"no relevant",
|
| 289 |
+
"missing",
|
| 290 |
+
"not explicitly"
|
| 291 |
+
]
|
| 292 |
+
|
| 293 |
+
# Check for strong uncertainty - these should drastically reduce confidence
|
| 294 |
+
if any(phrase in answer.lower() for phrase in strong_uncertainty_phrases):
|
| 295 |
+
return min(0.1, confidence * 0.2) # Max 10% for explicit refusal/uncertainty
|
| 296 |
+
|
| 297 |
+
# Check for weak uncertainty - reduce but don't destroy confidence for good context
|
| 298 |
+
weak_uncertainty_count = sum(1 for phrase in weak_uncertainty_phrases if phrase in answer.lower())
|
| 299 |
+
if weak_uncertainty_count > 0:
|
| 300 |
+
if max_relevance >= 0.7 and citations:
|
| 301 |
+
# Good context with citations - reduce less severely
|
| 302 |
+
confidence *= (0.8 ** weak_uncertainty_count) # Moderate penalty
|
| 303 |
+
else:
|
| 304 |
+
# Poor context - reduce more severely
|
| 305 |
+
confidence *= (0.5 ** weak_uncertainty_count) # Strong penalty
|
| 306 |
+
|
| 307 |
+
# If all chunks have very low relevance scores, cap confidence low
|
| 308 |
+
if max_relevance < 0.4:
|
| 309 |
+
return min(0.08, confidence) # Max 8% for low relevance context
|
| 310 |
+
|
| 311 |
+
# Factor 1: Citation quality and coverage
|
| 312 |
+
if citations and chunks:
|
| 313 |
+
citation_ratio = len(citations) / min(len(chunks), 3)
|
| 314 |
+
|
| 315 |
+
# Strong boost for high-relevance citations
|
| 316 |
+
relevant_chunks = [c for c in chunks if c.get('score', 0) > 0.6]
|
| 317 |
+
if relevant_chunks:
|
| 318 |
+
# Significant boost for citing relevant chunks
|
| 319 |
+
confidence += 0.25 * citation_ratio
|
| 320 |
+
|
| 321 |
+
# Extra boost if citing majority of relevant chunks
|
| 322 |
+
if len(citations) >= len(relevant_chunks) * 0.5:
|
| 323 |
+
confidence += 0.15
|
| 324 |
+
else:
|
| 325 |
+
# Small boost for citations to lower-relevance chunks
|
| 326 |
+
confidence += 0.1 * citation_ratio
|
| 327 |
+
else:
|
| 328 |
+
# No citations = reduce confidence unless it's a simple factual statement
|
| 329 |
+
if max_relevance >= 0.8 and len(answer.split()) < 20:
|
| 330 |
+
confidence *= 0.8 # Gentle penalty for uncited but simple answers
|
| 331 |
+
else:
|
| 332 |
+
confidence *= 0.6 # Stronger penalty for complex uncited answers
|
| 333 |
+
|
| 334 |
+
# Factor 2: Relevance score reinforcement
|
| 335 |
+
if citations:
|
| 336 |
+
avg_citation_relevance = sum(c.relevance_score for c in citations) / len(citations)
|
| 337 |
+
if avg_citation_relevance > 0.8:
|
| 338 |
+
confidence += 0.2 # Strong boost for highly relevant citations
|
| 339 |
+
elif avg_citation_relevance > 0.6:
|
| 340 |
+
confidence += 0.1 # Moderate boost
|
| 341 |
+
elif avg_citation_relevance < 0.4:
|
| 342 |
+
confidence *= 0.6 # Penalty for low-relevance citations
|
| 343 |
+
|
| 344 |
+
# Factor 3: Context utilization quality
|
| 345 |
+
if chunks:
|
| 346 |
+
avg_chunk_length = sum(len(chunk.get('content', chunk.get('text', ''))) for chunk in chunks) / len(chunks)
|
| 347 |
+
|
| 348 |
+
# Boost for substantial, high-quality context
|
| 349 |
+
if avg_chunk_length > 200 and max_relevance > 0.8:
|
| 350 |
+
confidence += 0.1
|
| 351 |
+
elif avg_chunk_length < 50: # Very short chunks
|
| 352 |
+
confidence *= 0.8
|
| 353 |
+
|
| 354 |
+
# Factor 4: Answer characteristics
|
| 355 |
+
answer_words = len(answer.split())
|
| 356 |
+
if answer_words < 10:
|
| 357 |
+
confidence *= 0.9 # Slight penalty for very short answers
|
| 358 |
+
elif answer_words > 50 and citations:
|
| 359 |
+
confidence += 0.05 # Small boost for detailed cited answers
|
| 360 |
+
|
| 361 |
+
# Factor 5: High-quality scenario bonus
|
| 362 |
+
if (max_relevance >= 0.8 and citations and
|
| 363 |
+
len(citations) > 0 and
|
| 364 |
+
not any(phrase in answer.lower() for phrase in strong_uncertainty_phrases)):
|
| 365 |
+
# This is a high-quality response scenario
|
| 366 |
+
confidence += 0.15
|
| 367 |
+
|
| 368 |
+
raw_confidence = min(confidence, 0.95) # Cap at 95% to maintain some uncertainty
|
| 369 |
+
|
| 370 |
+
# Apply temperature scaling calibration if available
|
| 371 |
+
if self.enable_calibration and self.calibrator and self.calibrator.is_fitted:
|
| 372 |
+
try:
|
| 373 |
+
calibrated_confidence = self.calibrator.calibrate_confidence(raw_confidence)
|
| 374 |
+
logger.debug(f"Confidence calibrated: {raw_confidence:.3f} -> {calibrated_confidence:.3f}")
|
| 375 |
+
return calibrated_confidence
|
| 376 |
+
except Exception as e:
|
| 377 |
+
logger.warning(f"Calibration failed, using raw confidence: {e}")
|
| 378 |
+
|
| 379 |
+
return raw_confidence
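To make the heuristic concrete, here is a worked example of the branches above for a well-supported, fully cited answer (numbers are illustrative; chunk text is assumed shorter than 200 characters, so the context-utilization boost does not fire):

base = 0.6              # max relevance 0.95 >= 0.8
boost_citations = 0.25  # citation_ratio = 2 / min(2, 3) = 1.0 and relevant chunks exist
boost_coverage = 0.15   # 2 citations >= 0.5 * 2 relevant chunks
boost_relevance = 0.20  # average citation relevance 0.92 > 0.8
boost_length = 0.05     # answer longer than 50 words and cited
boost_quality = 0.15    # high-relevance context, cited, no refusal phrases
raw_confidence = min(base + boost_citations + boost_coverage
                     + boost_relevance + boost_length + boost_quality, 0.95)
print(raw_confidence)   # 0.95 -- the cap preserves some residual uncertainty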
|
| 380 |
+
|
| 381 |
+
def fit_calibration(self, validation_data: List[Dict[str, Any]]) -> float:
|
| 382 |
+
"""
|
| 383 |
+
Fit temperature scaling calibration using validation data.
|
| 384 |
+
|
| 385 |
+
Args:
|
| 386 |
+
validation_data: List of dicts with 'confidence' and 'correctness' keys
|
| 387 |
+
|
| 388 |
+
Returns:
|
| 389 |
+
Optimal temperature parameter
|
| 390 |
+
"""
|
| 391 |
+
if not self.enable_calibration or not self.calibrator:
|
| 392 |
+
logger.warning("Calibration not enabled or not available")
|
| 393 |
+
return 1.0
|
| 394 |
+
|
| 395 |
+
try:
|
| 396 |
+
confidences = [item['confidence'] for item in validation_data]
|
| 397 |
+
correctness = [item['correctness'] for item in validation_data]
|
| 398 |
+
|
| 399 |
+
optimal_temp = self.calibrator.fit_temperature_scaling(confidences, correctness)
|
| 400 |
+
logger.info(f"Calibration fitted with temperature: {optimal_temp:.3f}")
|
| 401 |
+
return optimal_temp
|
| 402 |
+
|
| 403 |
+
except Exception as e:
|
| 404 |
+
logger.error(f"Failed to fit calibration: {e}")
|
| 405 |
+
return 1.0
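ConfidenceCalibrator itself lives in src/confidence_calibration.py and is not part of this commit. As a rough sketch of what temperature scaling typically does (an assumption about its internals, not a description of them), a single temperature parameter rescales confidences in logit space:

import math

def calibrate(confidence: float, temperature: float) -> float:
    # Sketch only; the real ConfidenceCalibrator may differ in detail.
    confidence = min(max(confidence, 1e-6), 1 - 1e-6)   # keep the logit finite
    logit = math.log(confidence / (1 - confidence))
    return 1 / (1 + math.exp(-logit / temperature))

# A temperature above 1 softens overconfident raw scores:
print(round(calibrate(0.95, 2.0), 3))  # ~0.813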
|
| 406 |
+
|
| 407 |
+
def save_calibration(self, filepath: str) -> bool:
|
| 408 |
+
"""Save fitted calibration to file."""
|
| 409 |
+
if not self.calibrator or not self.calibrator.is_fitted:
|
| 410 |
+
logger.warning("No fitted calibration to save")
|
| 411 |
+
return False
|
| 412 |
+
|
| 413 |
+
try:
|
| 414 |
+
calibration_data = {
|
| 415 |
+
'temperature': self.calibrator.temperature,
|
| 416 |
+
'is_fitted': self.calibrator.is_fitted,
|
| 417 |
+
'model_info': {
|
| 418 |
+
'primary_model': self.primary_model,
|
| 419 |
+
'fallback_model': self.fallback_model
|
| 420 |
+
}
|
| 421 |
+
}
|
| 422 |
+
|
| 423 |
+
with open(filepath, 'w') as f:
|
| 424 |
+
json.dump(calibration_data, f, indent=2)
|
| 425 |
+
|
| 426 |
+
logger.info(f"Calibration saved to {filepath}")
|
| 427 |
+
return True
|
| 428 |
+
|
| 429 |
+
except Exception as e:
|
| 430 |
+
logger.error(f"Failed to save calibration: {e}")
|
| 431 |
+
return False
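For reference, the file written above is plain JSON with this shape (the temperature value is illustrative; the model names are the constructor defaults):

{
  "temperature": 1.8,
  "is_fitted": true,
  "model_info": {
    "primary_model": "llama3.2:3b",
    "fallback_model": "mistral:latest"
  }
}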
|
| 432 |
+
|
| 433 |
+
def load_calibration(self, filepath: str) -> bool:
|
| 434 |
+
"""Load fitted calibration from file."""
|
| 435 |
+
if not self.enable_calibration or not self.calibrator:
|
| 436 |
+
logger.warning("Calibration not enabled")
|
| 437 |
+
return False
|
| 438 |
+
|
| 439 |
+
try:
|
| 440 |
+
with open(filepath, 'r') as f:
|
| 441 |
+
calibration_data = json.load(f)
|
| 442 |
+
|
| 443 |
+
self.calibrator.temperature = calibration_data['temperature']
|
| 444 |
+
self.calibrator.is_fitted = calibration_data['is_fitted']
|
| 445 |
+
|
| 446 |
+
logger.info(f"Calibration loaded from {filepath} (temp: {self.calibrator.temperature:.3f})")
|
| 447 |
+
return True
|
| 448 |
+
|
| 449 |
+
except Exception as e:
|
| 450 |
+
logger.error(f"Failed to load calibration: {e}")
|
| 451 |
+
return False
|
| 452 |
+
|
| 453 |
+
def generate(
|
| 454 |
+
self,
|
| 455 |
+
query: str,
|
| 456 |
+
chunks: List[Dict[str, Any]],
|
| 457 |
+
use_fallback: bool = False
|
| 458 |
+
) -> GeneratedAnswer:
|
| 459 |
+
"""
|
| 460 |
+
Generate an answer based on the query and retrieved chunks.
|
| 461 |
+
|
| 462 |
+
Args:
|
| 463 |
+
query: User's question
|
| 464 |
+
chunks: Retrieved document chunks
|
| 465 |
+
use_fallback: Whether to use fallback model
|
| 466 |
+
|
| 467 |
+
Returns:
|
| 468 |
+
GeneratedAnswer object with answer, citations, and metadata
|
| 469 |
+
"""
|
| 470 |
+
start_time = datetime.now()
|
| 471 |
+
model = self.fallback_model if use_fallback else self.primary_model
|
| 472 |
+
|
| 473 |
+
# Check for no-context or very poor context situation
|
| 474 |
+
if not chunks or all(len(chunk.get('content', chunk.get('text', ''))) < 20 for chunk in chunks):
|
| 475 |
+
# Handle no-context situation with brief, professional refusal
|
| 476 |
+
user_prompt = f"""Context: [NO RELEVANT CONTEXT FOUND]
|
| 477 |
+
|
| 478 |
+
Question: {query}
|
| 479 |
+
|
| 480 |
+
INSTRUCTION: Respond with exactly this brief message:
|
| 481 |
+
|
| 482 |
+
"This information isn't available in the provided documents."
|
| 483 |
+
|
| 484 |
+
DO NOT elaborate, explain, or add any other information."""
|
| 485 |
+
else:
|
| 486 |
+
# Format context from chunks
|
| 487 |
+
context = self._format_context(chunks)
|
| 488 |
+
|
| 489 |
+
# Create concise prompt for faster generation
|
| 490 |
+
user_prompt = f"""Context:
|
| 491 |
+
{context}
|
| 492 |
+
|
| 493 |
+
Question: {query}
|
| 494 |
+
|
| 495 |
+
Instructions: Answer using only the context above. Cite with [chunk_1], [chunk_2] etc.
|
| 496 |
+
|
| 497 |
+
Answer:"""
|
| 498 |
+
|
| 499 |
+
try:
|
| 500 |
+
# Generate response
|
| 501 |
+
response = self.client.chat(
|
| 502 |
+
model=model,
|
| 503 |
+
messages=[
|
| 504 |
+
{"role": "system", "content": self._create_system_prompt()},
|
| 505 |
+
{"role": "user", "content": user_prompt}
|
| 506 |
+
],
|
| 507 |
+
options={
|
| 508 |
+
"temperature": self.temperature,
|
| 509 |
+
"num_predict": min(self.max_tokens, 300), # Reduce max tokens for speed
|
| 510 |
+
"top_k": 40, # Optimize sampling for speed
|
| 511 |
+
"top_p": 0.9,
|
| 512 |
+
"repeat_penalty": 1.1
|
| 513 |
+
},
|
| 514 |
+
stream=False # Get complete response for processing
|
| 515 |
+
)
|
| 516 |
+
|
| 517 |
+
# Extract answer
|
| 518 |
+
answer_with_citations = response['message']['content']
|
| 519 |
+
|
| 520 |
+
# Extract and clean citations
|
| 521 |
+
clean_answer, citations = self._extract_citations(answer_with_citations, chunks)
|
| 522 |
+
|
| 523 |
+
# Calculate confidence
|
| 524 |
+
confidence = self._calculate_confidence(clean_answer, citations, chunks)
|
| 525 |
+
|
| 526 |
+
# Calculate generation time
|
| 527 |
+
generation_time = (datetime.now() - start_time).total_seconds()
|
| 528 |
+
|
| 529 |
+
return GeneratedAnswer(
|
| 530 |
+
answer=clean_answer,
|
| 531 |
+
citations=citations,
|
| 532 |
+
confidence_score=confidence,
|
| 533 |
+
generation_time=generation_time,
|
| 534 |
+
model_used=model,
|
| 535 |
+
context_used=chunks
|
| 536 |
+
)
|
| 537 |
+
|
| 538 |
+
except Exception as e:
|
| 539 |
+
logger.error(f"Error generating answer: {e}")
|
| 540 |
+
# Return a fallback response
|
| 541 |
+
return GeneratedAnswer(
|
| 542 |
+
answer="I apologize, but I encountered an error while generating the answer. Please try again.",
|
| 543 |
+
citations=[],
|
| 544 |
+
confidence_score=0.0,
|
| 545 |
+
generation_time=0.0,
|
| 546 |
+
model_used=model,
|
| 547 |
+
context_used=chunks
|
| 548 |
+
)
|
| 549 |
+
|
| 550 |
+
def generate_stream(
|
| 551 |
+
self,
|
| 552 |
+
query: str,
|
| 553 |
+
chunks: List[Dict[str, Any]],
|
| 554 |
+
use_fallback: bool = False
|
| 555 |
+
) -> Generator[str, None, GeneratedAnswer]:
|
| 556 |
+
"""
|
| 557 |
+
Generate an answer with streaming support.
|
| 558 |
+
|
| 559 |
+
Args:
|
| 560 |
+
query: User's question
|
| 561 |
+
chunks: Retrieved document chunks
|
| 562 |
+
use_fallback: Whether to use fallback model
|
| 563 |
+
|
| 564 |
+
Yields:
|
| 565 |
+
Partial answer strings
|
| 566 |
+
|
| 567 |
+
Returns:
|
| 568 |
+
Final GeneratedAnswer object
|
| 569 |
+
"""
|
| 570 |
+
start_time = datetime.now()
|
| 571 |
+
model = self.fallback_model if use_fallback else self.primary_model
|
| 572 |
+
|
| 573 |
+
# Check for no-context or very poor context situation
|
| 574 |
+
if not chunks or all(len(chunk.get('content', chunk.get('text', ''))) < 20 for chunk in chunks):
|
| 575 |
+
# Handle no-context situation with brief, professional refusal
|
| 576 |
+
user_prompt = f"""Context: [NO RELEVANT CONTEXT FOUND]
|
| 577 |
+
|
| 578 |
+
Question: {query}
|
| 579 |
+
|
| 580 |
+
INSTRUCTION: Respond with exactly this brief message:
|
| 581 |
+
|
| 582 |
+
"This information isn't available in the provided documents."
|
| 583 |
+
|
| 584 |
+
DO NOT elaborate, explain, or add any other information."""
|
| 585 |
+
else:
|
| 586 |
+
# Format context from chunks
|
| 587 |
+
context = self._format_context(chunks)
|
| 588 |
+
|
| 589 |
+
# Create concise prompt for faster generation
|
| 590 |
+
user_prompt = f"""Context:
|
| 591 |
+
{context}
|
| 592 |
+
|
| 593 |
+
Question: {query}
|
| 594 |
+
|
| 595 |
+
Instructions: Answer using only the context above. Cite with [chunk_1], [chunk_2] etc.
|
| 596 |
+
|
| 597 |
+
Answer:"""
|
| 598 |
+
|
| 599 |
+
try:
|
| 600 |
+
# Generate streaming response
|
| 601 |
+
stream = self.client.chat(
|
| 602 |
+
model=model,
|
| 603 |
+
messages=[
|
| 604 |
+
{"role": "system", "content": self._create_system_prompt()},
|
| 605 |
+
{"role": "user", "content": user_prompt}
|
| 606 |
+
],
|
| 607 |
+
options={
|
| 608 |
+
"temperature": self.temperature,
|
| 609 |
+
"num_predict": min(self.max_tokens, 300), # Reduce max tokens for speed
|
| 610 |
+
"top_k": 40, # Optimize sampling for speed
|
| 611 |
+
"top_p": 0.9,
|
| 612 |
+
"repeat_penalty": 1.1
|
| 613 |
+
},
|
| 614 |
+
stream=True
|
| 615 |
+
)
|
| 616 |
+
|
| 617 |
+
# Collect full answer while streaming
|
| 618 |
+
full_answer = ""
|
| 619 |
+
for chunk in stream:
|
| 620 |
+
if 'message' in chunk and 'content' in chunk['message']:
|
| 621 |
+
partial = chunk['message']['content']
|
| 622 |
+
full_answer += partial
|
| 623 |
+
yield partial
|
| 624 |
+
|
| 625 |
+
# Process complete answer
|
| 626 |
+
clean_answer, citations = self._extract_citations(full_answer, chunks)
|
| 627 |
+
confidence = self._calculate_confidence(clean_answer, citations, chunks)
|
| 628 |
+
generation_time = (datetime.now() - start_time).total_seconds()
|
| 629 |
+
|
| 630 |
+
return GeneratedAnswer(
|
| 631 |
+
answer=clean_answer,
|
| 632 |
+
citations=citations,
|
| 633 |
+
confidence_score=confidence,
|
| 634 |
+
generation_time=generation_time,
|
| 635 |
+
model_used=model,
|
| 636 |
+
context_used=chunks
|
| 637 |
+
)
|
| 638 |
+
|
| 639 |
+
except Exception as e:
|
| 640 |
+
logger.error(f"Error in streaming generation: {e}")
|
| 641 |
+
yield "I apologize, but I encountered an error while generating the answer."
|
| 642 |
+
|
| 643 |
+
return GeneratedAnswer(
|
| 644 |
+
answer="Error occurred during generation.",
|
| 645 |
+
citations=[],
|
| 646 |
+
confidence_score=0.0,
|
| 647 |
+
generation_time=0.0,
|
| 648 |
+
model_used=model,
|
| 649 |
+
context_used=chunks
|
| 650 |
+
)
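One non-obvious detail of the streaming path: generate_stream both yields partial text and returns the final GeneratedAnswer, so a caller that wants the metadata has to capture the generator's return value from StopIteration. A sketch (assumes a running Ollama server with the configured models pulled, and example_chunks as in the __main__ block at the end of this file):

generator = AnswerGenerator()
stream = generator.generate_stream("What is RISC-V?", example_chunks)

final_answer = None
try:
    while True:
        print(next(stream), end="", flush=True)   # partial tokens as they arrive
except StopIteration as stop:
    final_answer = stop.value                      # the returned GeneratedAnswer

if final_answer is not None:
    print(f"\n\nConfidence: {final_answer.confidence_score:.1%}")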
|
| 651 |
+
|
| 652 |
+
def format_answer_with_citations(self, generated_answer: GeneratedAnswer) -> str:
|
| 653 |
+
"""
|
| 654 |
+
Format the generated answer with citations for display.
|
| 655 |
+
|
| 656 |
+
Args:
|
| 657 |
+
generated_answer: GeneratedAnswer object
|
| 658 |
+
|
| 659 |
+
Returns:
|
| 660 |
+
Formatted string with answer and citations
|
| 661 |
+
"""
|
| 662 |
+
formatted = f"{generated_answer.answer}\n\n"
|
| 663 |
+
|
| 664 |
+
if generated_answer.citations:
|
| 665 |
+
formatted += "**Sources:**\n"
|
| 666 |
+
for i, citation in enumerate(generated_answer.citations, 1):
|
| 667 |
+
formatted += f"{i}. {citation.source_file} (Page {citation.page_number})\n"
|
| 668 |
+
|
| 669 |
+
formatted += f"\n*Confidence: {generated_answer.confidence_score:.1%} | "
|
| 670 |
+
formatted += f"Model: {generated_answer.model_used} | "
|
| 671 |
+
formatted += f"Time: {generated_answer.generation_time:.2f}s*"
|
| 672 |
+
|
| 673 |
+
return formatted
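For a typical two-citation answer, the string assembled above renders roughly as follows (values illustrative):

RISC-V is an open-source instruction set architecture based on RISC principles.

**Sources:**
1. riscv-spec.pdf (Page 1)
2. riscv-spec.pdf (Page 2)

*Confidence: 85.0% | Model: llama3.2:3b | Time: 1.42s*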
|
| 674 |
+
|
| 675 |
+
|
| 676 |
+
if __name__ == "__main__":
|
| 677 |
+
# Example usage
|
| 678 |
+
generator = AnswerGenerator()
|
| 679 |
+
|
| 680 |
+
# Example chunks (would come from retrieval system)
|
| 681 |
+
example_chunks = [
|
| 682 |
+
{
|
| 683 |
+
"id": "chunk_1",
|
| 684 |
+
"content": "RISC-V is an open-source instruction set architecture (ISA) based on reduced instruction set computer (RISC) principles.",
|
| 685 |
+
"metadata": {"page_number": 1, "source": "riscv-spec.pdf"},
|
| 686 |
+
"score": 0.95
|
| 687 |
+
},
|
| 688 |
+
{
|
| 689 |
+
"id": "chunk_2",
|
| 690 |
+
"content": "The RISC-V ISA is designed to support a wide range of implementations including 32-bit, 64-bit, and 128-bit variants.",
|
| 691 |
+
"metadata": {"page_number": 2, "source": "riscv-spec.pdf"},
|
| 692 |
+
"score": 0.89
|
| 693 |
+
}
|
| 694 |
+
]
|
| 695 |
+
|
| 696 |
+
# Generate answer
|
| 697 |
+
result = generator.generate(
|
| 698 |
+
query="What is RISC-V?",
|
| 699 |
+
chunks=example_chunks
|
| 700 |
+
)
|
| 701 |
+
|
| 702 |
+
# Display formatted result
|
| 703 |
+
print(generator.format_answer_with_citations(result))
|
shared_utils/generation/chain_of_thought_engine.py
ADDED
|
@@ -0,0 +1,565 @@
| 1 |
+
"""
|
| 2 |
+
Chain-of-Thought Reasoning Engine for Complex Technical Queries.
|
| 3 |
+
|
| 4 |
+
This module provides structured reasoning capabilities for complex technical
|
| 5 |
+
questions that require multi-step analysis and implementation guidance.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from typing import Dict, List, Optional, Any, Tuple
|
| 9 |
+
from dataclasses import dataclass
|
| 10 |
+
from enum import Enum
|
| 11 |
+
import re
|
| 12 |
+
|
| 13 |
+
from .prompt_templates import QueryType, PromptTemplate
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class ReasoningStep(Enum):
|
| 17 |
+
"""Types of reasoning steps in chain-of-thought."""
|
| 18 |
+
ANALYSIS = "analysis"
|
| 19 |
+
DECOMPOSITION = "decomposition"
|
| 20 |
+
SYNTHESIS = "synthesis"
|
| 21 |
+
VALIDATION = "validation"
|
| 22 |
+
IMPLEMENTATION = "implementation"
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
@dataclass
|
| 26 |
+
class ChainStep:
|
| 27 |
+
"""Represents a single step in chain-of-thought reasoning."""
|
| 28 |
+
step_type: ReasoningStep
|
| 29 |
+
description: str
|
| 30 |
+
prompt_addition: str
|
| 31 |
+
requires_context: bool = True
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class ChainOfThoughtEngine:
|
| 35 |
+
"""
|
| 36 |
+
Engine for generating chain-of-thought reasoning prompts for complex technical queries.
|
| 37 |
+
|
| 38 |
+
Features:
|
| 39 |
+
- Multi-step reasoning for complex implementations
|
| 40 |
+
- Context-aware step generation
|
| 41 |
+
- Query type specific reasoning chains
|
| 42 |
+
- Validation and error checking steps
|
| 43 |
+
"""
|
| 44 |
+
|
| 45 |
+
def __init__(self):
|
| 46 |
+
"""Initialize the chain-of-thought engine."""
|
| 47 |
+
self.reasoning_chains = self._initialize_reasoning_chains()
|
| 48 |
+
|
| 49 |
+
def _initialize_reasoning_chains(self) -> Dict[QueryType, List[ChainStep]]:
|
| 50 |
+
"""Initialize reasoning chains for different query types."""
|
| 51 |
+
return {
|
| 52 |
+
QueryType.IMPLEMENTATION: [
|
| 53 |
+
ChainStep(
|
| 54 |
+
step_type=ReasoningStep.ANALYSIS,
|
| 55 |
+
description="Analyze the implementation requirements",
|
| 56 |
+
prompt_addition="""
|
| 57 |
+
First, let me analyze what needs to be implemented:
|
| 58 |
+
1. What is the specific goal or functionality required?
|
| 59 |
+
2. What are the key components or modules involved?
|
| 60 |
+
3. Are there any hardware or software constraints mentioned?"""
|
| 61 |
+
),
|
| 62 |
+
ChainStep(
|
| 63 |
+
step_type=ReasoningStep.DECOMPOSITION,
|
| 64 |
+
description="Break down into implementation steps",
|
| 65 |
+
prompt_addition="""
|
| 66 |
+
Next, let me break this down into logical implementation steps:
|
| 67 |
+
1. What are the prerequisites and dependencies?
|
| 68 |
+
2. What is the logical sequence of implementation?
|
| 69 |
+
3. Which steps are critical and which are optional?"""
|
| 70 |
+
),
|
| 71 |
+
ChainStep(
|
| 72 |
+
step_type=ReasoningStep.SYNTHESIS,
|
| 73 |
+
description="Synthesize the complete solution",
|
| 74 |
+
prompt_addition="""
|
| 75 |
+
Now I'll synthesize the complete solution:
|
| 76 |
+
1. How do the individual steps connect together?
|
| 77 |
+
2. What code examples or configurations are needed?
|
| 78 |
+
3. What are the key integration points?"""
|
| 79 |
+
),
|
| 80 |
+
ChainStep(
|
| 81 |
+
step_type=ReasoningStep.VALIDATION,
|
| 82 |
+
description="Consider validation and error handling",
|
| 83 |
+
prompt_addition="""
|
| 84 |
+
Finally, let me consider validation and potential issues:
|
| 85 |
+
1. How can we verify the implementation works?
|
| 86 |
+
2. What are common pitfalls or error conditions?
|
| 87 |
+
3. What debugging or troubleshooting steps are important?"""
|
| 88 |
+
)
|
| 89 |
+
],
|
| 90 |
+
|
| 91 |
+
QueryType.COMPARISON: [
|
| 92 |
+
ChainStep(
|
| 93 |
+
step_type=ReasoningStep.ANALYSIS,
|
| 94 |
+
description="Analyze items being compared",
|
| 95 |
+
prompt_addition="""
|
| 96 |
+
Let me start by analyzing what's being compared:
|
| 97 |
+
1. What are the specific items or concepts being compared?
|
| 98 |
+
2. What aspects or dimensions are relevant for comparison?
|
| 99 |
+
3. What context or use case should guide the comparison?"""
|
| 100 |
+
),
|
| 101 |
+
ChainStep(
|
| 102 |
+
step_type=ReasoningStep.DECOMPOSITION,
|
| 103 |
+
description="Break down comparison criteria",
|
| 104 |
+
prompt_addition="""
|
| 105 |
+
Next, let me identify the key comparison criteria:
|
| 106 |
+
1. What are the technical specifications or features to compare?
|
| 107 |
+
2. What are the performance characteristics?
|
| 108 |
+
3. What are the practical considerations (cost, complexity, etc.)?"""
|
| 109 |
+
),
|
| 110 |
+
ChainStep(
|
| 111 |
+
step_type=ReasoningStep.SYNTHESIS,
|
| 112 |
+
description="Synthesize comparison results",
|
| 113 |
+
prompt_addition="""
|
| 114 |
+
Now I'll synthesize the comparison:
|
| 115 |
+
1. How do the items compare on each criterion?
|
| 116 |
+
2. What are the key trade-offs and differences?
|
| 117 |
+
3. What recommendations can be made for different scenarios?"""
|
| 118 |
+
)
|
| 119 |
+
],
|
| 120 |
+
|
| 121 |
+
QueryType.TROUBLESHOOTING: [
|
| 122 |
+
ChainStep(
|
| 123 |
+
step_type=ReasoningStep.ANALYSIS,
|
| 124 |
+
description="Analyze the problem",
|
| 125 |
+
prompt_addition="""
|
| 126 |
+
Let me start by analyzing the problem:
|
| 127 |
+
1. What are the specific symptoms or error conditions?
|
| 128 |
+
2. What system or component is affected?
|
| 129 |
+
3. What was the expected vs actual behavior?"""
|
| 130 |
+
),
|
| 131 |
+
ChainStep(
|
| 132 |
+
step_type=ReasoningStep.DECOMPOSITION,
|
| 133 |
+
description="Identify potential root causes",
|
| 134 |
+
prompt_addition="""
|
| 135 |
+
Next, let me identify potential root causes:
|
| 136 |
+
1. What are the most likely causes based on the symptoms?
|
| 137 |
+
2. What system components could be involved?
|
| 138 |
+
3. What external factors might contribute to the issue?"""
|
| 139 |
+
),
|
| 140 |
+
ChainStep(
|
| 141 |
+
step_type=ReasoningStep.VALIDATION,
|
| 142 |
+
description="Develop diagnostic approach",
|
| 143 |
+
prompt_addition="""
|
| 144 |
+
Now I'll develop a diagnostic approach:
|
| 145 |
+
1. What tests or checks can isolate the root cause?
|
| 146 |
+
2. What is the recommended sequence of diagnostic steps?
|
| 147 |
+
3. How can we verify the fix once implemented?"""
|
| 148 |
+
)
|
| 149 |
+
],
|
| 150 |
+
|
| 151 |
+
QueryType.HARDWARE_CONSTRAINT: [
|
| 152 |
+
ChainStep(
|
| 153 |
+
step_type=ReasoningStep.ANALYSIS,
|
| 154 |
+
description="Analyze hardware requirements",
|
| 155 |
+
prompt_addition="""
|
| 156 |
+
Let me analyze the hardware requirements:
|
| 157 |
+
1. What are the specific hardware resources needed?
|
| 158 |
+
2. What are the performance requirements?
|
| 159 |
+
3. What are the power and size constraints?"""
|
| 160 |
+
),
|
| 161 |
+
ChainStep(
|
| 162 |
+
step_type=ReasoningStep.DECOMPOSITION,
|
| 163 |
+
description="Break down resource utilization",
|
| 164 |
+
prompt_addition="""
|
| 165 |
+
Next, let me break down resource utilization:
|
| 166 |
+
1. How much memory (RAM/Flash) is required?
|
| 167 |
+
2. What are the processing requirements (CPU/DSP)?
|
| 168 |
+
3. What I/O and peripheral requirements exist?"""
|
| 169 |
+
),
|
| 170 |
+
ChainStep(
|
| 171 |
+
step_type=ReasoningStep.SYNTHESIS,
|
| 172 |
+
description="Evaluate feasibility and alternatives",
|
| 173 |
+
prompt_addition="""
|
| 174 |
+
Now I'll evaluate feasibility:
|
| 175 |
+
1. Can the requirements be met with the available hardware?
|
| 176 |
+
2. What optimizations might be needed?
|
| 177 |
+
3. What are alternative approaches if constraints are exceeded?"""
|
| 178 |
+
)
|
| 179 |
+
]
|
| 180 |
+
}
|
| 181 |
+
|
| 182 |
+
def generate_chain_of_thought_prompt(
|
| 183 |
+
self,
|
| 184 |
+
query: str,
|
| 185 |
+
query_type: QueryType,
|
| 186 |
+
context: str,
|
| 187 |
+
base_template: PromptTemplate
|
| 188 |
+
) -> Dict[str, str]:
|
| 189 |
+
"""
|
| 190 |
+
Generate a chain-of-thought enhanced prompt.
|
| 191 |
+
|
| 192 |
+
Args:
|
| 193 |
+
query: User's question
|
| 194 |
+
query_type: Type of query
|
| 195 |
+
context: Retrieved context
|
| 196 |
+
base_template: Base prompt template
|
| 197 |
+
|
| 198 |
+
Returns:
|
| 199 |
+
Enhanced prompt with chain-of-thought reasoning
|
| 200 |
+
"""
|
| 201 |
+
# Get reasoning chain for query type
|
| 202 |
+
reasoning_chain = self.reasoning_chains.get(query_type, [])
|
| 203 |
+
|
| 204 |
+
if not reasoning_chain:
|
| 205 |
+
# Fall back to generic reasoning for unsupported types
|
| 206 |
+
reasoning_chain = self._generate_generic_reasoning_chain(query)
|
| 207 |
+
|
| 208 |
+
# Build chain-of-thought prompt
|
| 209 |
+
cot_prompt = self._build_cot_prompt(reasoning_chain, query, context)
|
| 210 |
+
|
| 211 |
+
# Enhance system prompt
|
| 212 |
+
enhanced_system = base_template.system_prompt + """
|
| 213 |
+
|
| 214 |
+
CHAIN-OF-THOUGHT REASONING: You will approach this question using structured reasoning.
|
| 215 |
+
Work through each step methodically before providing your final answer.
|
| 216 |
+
Show your reasoning process clearly, then provide a comprehensive final answer."""
|
| 217 |
+
|
| 218 |
+
# Enhance user prompt
|
| 219 |
+
enhanced_user = f"""{base_template.context_format.format(context=context)}
|
| 220 |
+
|
| 221 |
+
{base_template.query_format.format(query=query)}
|
| 222 |
+
|
| 223 |
+
{cot_prompt}
|
| 224 |
+
|
| 225 |
+
{base_template.answer_guidelines}
|
| 226 |
+
|
| 227 |
+
After working through your reasoning, provide your final answer in the requested format."""
|
| 228 |
+
|
| 229 |
+
return {
|
| 230 |
+
"system": enhanced_system,
|
| 231 |
+
"user": enhanced_user
|
| 232 |
+
}
|
| 233 |
+
|
| 234 |
+
def _build_cot_prompt(
|
| 235 |
+
self,
|
| 236 |
+
reasoning_chain: List[ChainStep],
|
| 237 |
+
query: str,
|
| 238 |
+
context: str
|
| 239 |
+
) -> str:
|
| 240 |
+
"""
|
| 241 |
+
Build the chain-of-thought prompt section.
|
| 242 |
+
|
| 243 |
+
Args:
|
| 244 |
+
reasoning_chain: List of reasoning steps
|
| 245 |
+
query: User's question
|
| 246 |
+
context: Retrieved context
|
| 247 |
+
|
| 248 |
+
Returns:
|
| 249 |
+
Chain-of-thought prompt text
|
| 250 |
+
"""
|
| 251 |
+
cot_sections = [
|
| 252 |
+
"REASONING PROCESS:",
|
| 253 |
+
"Work through this step-by-step using the following reasoning framework:",
|
| 254 |
+
""
|
| 255 |
+
]
|
| 256 |
+
|
| 257 |
+
for i, step in enumerate(reasoning_chain, 1):
|
| 258 |
+
cot_sections.append(f"Step {i}: {step.description}")
|
| 259 |
+
cot_sections.append(step.prompt_addition)
|
| 260 |
+
cot_sections.append("")
|
| 261 |
+
|
| 262 |
+
cot_sections.extend([
|
| 263 |
+
"STRUCTURED REASONING:",
|
| 264 |
+
"Now work through each step above, referencing the provided context where relevant.",
|
| 265 |
+
"Use [chunk_X] citations for your reasoning at each step.",
|
| 266 |
+
""
|
| 267 |
+
])
|
| 268 |
+
|
| 269 |
+
return "\n".join(cot_sections)
|
| 270 |
+
|
| 271 |
+
def _generate_generic_reasoning_chain(self, query: str) -> List[ChainStep]:
|
| 272 |
+
"""
|
| 273 |
+
Generate a generic reasoning chain for unsupported query types.
|
| 274 |
+
|
| 275 |
+
Args:
|
| 276 |
+
query: User's question
|
| 277 |
+
|
| 278 |
+
Returns:
|
| 279 |
+
List of generic reasoning steps
|
| 280 |
+
"""
|
| 281 |
+
# Analyze query complexity to determine appropriate steps
|
| 282 |
+
complexity_indicators = {
|
| 283 |
+
"multi_part": ["and", "also", "additionally", "furthermore"],
|
| 284 |
+
"causal": ["why", "because", "cause", "reason"],
|
| 285 |
+
"conditional": ["if", "when", "unless", "provided that"],
|
| 286 |
+
"comparative": ["better", "worse", "compare", "versus", "vs"]
|
| 287 |
+
}
|
| 288 |
+
|
| 289 |
+
query_lower = query.lower()
|
| 290 |
+
steps = []
|
| 291 |
+
|
| 292 |
+
# Always start with analysis
|
| 293 |
+
steps.append(ChainStep(
|
| 294 |
+
step_type=ReasoningStep.ANALYSIS,
|
| 295 |
+
description="Analyze the question",
|
| 296 |
+
prompt_addition="""
|
| 297 |
+
Let me start by analyzing the question:
|
| 298 |
+
1. What is the core question being asked?
|
| 299 |
+
2. What context or domain knowledge is needed?
|
| 300 |
+
3. Are there multiple parts to this question?"""
|
| 301 |
+
))
|
| 302 |
+
|
| 303 |
+
# Add decomposition for complex queries
|
| 304 |
+
if any(indicator in query_lower for indicators in complexity_indicators.values() for indicator in indicators):
|
| 305 |
+
steps.append(ChainStep(
|
| 306 |
+
step_type=ReasoningStep.DECOMPOSITION,
|
| 307 |
+
description="Break down the question",
|
| 308 |
+
prompt_addition="""
|
| 309 |
+
Let me break this down into components:
|
| 310 |
+
1. What are the key concepts or elements involved?
|
| 311 |
+
2. How do these elements relate to each other?
|
| 312 |
+
3. What information do I need to address each part?"""
|
| 313 |
+
))
|
| 314 |
+
|
| 315 |
+
# Always end with synthesis
|
| 316 |
+
steps.append(ChainStep(
|
| 317 |
+
step_type=ReasoningStep.SYNTHESIS,
|
| 318 |
+
description="Synthesize the answer",
|
| 319 |
+
prompt_addition="""
|
| 320 |
+
Now I'll synthesize a comprehensive answer:
|
| 321 |
+
1. How do all the pieces fit together?
|
| 322 |
+
2. What is the most complete and accurate response?
|
| 323 |
+
3. Are there any important caveats or limitations?"""
|
| 324 |
+
))
|
| 325 |
+
|
| 326 |
+
return steps
|
| 327 |
+
|
| 328 |
+
def create_reasoning_validation_prompt(
|
| 329 |
+
self,
|
| 330 |
+
query: str,
|
| 331 |
+
proposed_answer: str,
|
| 332 |
+
context: str
|
| 333 |
+
) -> str:
|
| 334 |
+
"""
|
| 335 |
+
Create a prompt for validating chain-of-thought reasoning.
|
| 336 |
+
|
| 337 |
+
Args:
|
| 338 |
+
query: Original query
|
| 339 |
+
proposed_answer: Generated answer to validate
|
| 340 |
+
context: Context used for the answer
|
| 341 |
+
|
| 342 |
+
Returns:
|
| 343 |
+
Validation prompt
|
| 344 |
+
"""
|
| 345 |
+
return f"""
|
| 346 |
+
REASONING VALIDATION TASK:
|
| 347 |
+
|
| 348 |
+
Original Query: {query}
|
| 349 |
+
|
| 350 |
+
Proposed Answer: {proposed_answer}
|
| 351 |
+
|
| 352 |
+
Context Used: {context}
|
| 353 |
+
|
| 354 |
+
Please validate the reasoning in the proposed answer by checking:
|
| 355 |
+
|
| 356 |
+
1. LOGICAL CONSISTENCY:
|
| 357 |
+
- Are the reasoning steps logically connected?
|
| 358 |
+
- Are there any contradictions or gaps in logic?
|
| 359 |
+
- Does the conclusion follow from the premises?
|
| 360 |
+
|
| 361 |
+
2. FACTUAL ACCURACY:
|
| 362 |
+
- Are the facts and technical details correct?
|
| 363 |
+
- Are the citations appropriate and accurate?
|
| 364 |
+
- Is the information consistent with the provided context?
|
| 365 |
+
|
| 366 |
+
3. COMPLETENESS:
|
| 367 |
+
- Does the answer address all parts of the question?
|
| 368 |
+
- Are important considerations or caveats mentioned?
|
| 369 |
+
- Is the level of detail appropriate for the question?
|
| 370 |
+
|
| 371 |
+
4. CLARITY:
|
| 372 |
+
- Is the reasoning easy to follow?
|
| 373 |
+
- Are technical terms used correctly?
|
| 374 |
+
- Is the structure logical and well-organized?
|
| 375 |
+
|
| 376 |
+
Provide your validation assessment with specific feedback on any issues found.
|
| 377 |
+
"""
|
| 378 |
+
|
| 379 |
+
def extract_reasoning_steps(self, cot_response: str) -> List[Dict[str, str]]:
|
| 380 |
+
"""
|
| 381 |
+
Extract reasoning steps from a chain-of-thought response.
|
| 382 |
+
|
| 383 |
+
Args:
|
| 384 |
+
cot_response: Response containing chain-of-thought reasoning
|
| 385 |
+
|
| 386 |
+
Returns:
|
| 387 |
+
List of extracted reasoning steps
|
| 388 |
+
"""
|
| 389 |
+
steps = []
|
| 390 |
+
|
| 391 |
+
# Look for step patterns
|
| 392 |
+
step_patterns = [
|
| 393 |
+
r"Step \d+:?\s*(.+?)(?=Step \d+|$)",
|
| 394 |
+
r"First,?\s*(.+?)(?=Next,?|Second,?|Then,?|Finally,?|$)",
|
| 395 |
+
r"Next,?\s*(.+?)(?=Then,?|Finally,?|Now,?|$)",
|
| 396 |
+
r"Then,?\s*(.+?)(?=Finally,?|Now,?|$)",
|
| 397 |
+
r"Finally,?\s*(.+?)(?=\n\n|$)"
|
| 398 |
+
]
|
| 399 |
+
|
| 400 |
+
for pattern in step_patterns:
|
| 401 |
+
matches = re.findall(pattern, cot_response, re.DOTALL | re.IGNORECASE)
|
| 402 |
+
for match in matches:
|
| 403 |
+
if match.strip():
|
| 404 |
+
steps.append({
|
| 405 |
+
"step_text": match.strip(),
|
| 406 |
+
"pattern": pattern
|
| 407 |
+
})
|
| 408 |
+
|
| 409 |
+
return steps
|
| 410 |
+
|
| 411 |
+
def evaluate_reasoning_quality(self, reasoning_steps: List[Dict[str, str]]) -> Dict[str, float]:
|
| 412 |
+
"""
|
| 413 |
+
Evaluate the quality of chain-of-thought reasoning.
|
| 414 |
+
|
| 415 |
+
Args:
|
| 416 |
+
reasoning_steps: List of reasoning steps
|
| 417 |
+
|
| 418 |
+
Returns:
|
| 419 |
+
Dictionary of quality metrics
|
| 420 |
+
"""
|
| 421 |
+
if not reasoning_steps:
|
| 422 |
+
return {"overall_quality": 0.0, "step_count": 0}
|
| 423 |
+
|
| 424 |
+
# Evaluate different aspects
|
| 425 |
+
metrics = {
|
| 426 |
+
"step_count": len(reasoning_steps),
|
| 427 |
+
"logical_flow": self._evaluate_logical_flow(reasoning_steps),
|
| 428 |
+
"technical_depth": self._evaluate_technical_depth(reasoning_steps),
|
| 429 |
+
"citation_usage": self._evaluate_citation_usage(reasoning_steps),
|
| 430 |
+
"completeness": self._evaluate_completeness(reasoning_steps)
|
| 431 |
+
}
|
| 432 |
+
|
| 433 |
+
# Calculate overall quality
|
| 434 |
+
quality_weights = {
|
| 435 |
+
"logical_flow": 0.3,
|
| 436 |
+
"technical_depth": 0.3,
|
| 437 |
+
"citation_usage": 0.2,
|
| 438 |
+
"completeness": 0.2
|
| 439 |
+
}
|
| 440 |
+
|
| 441 |
+
overall_quality = sum(
|
| 442 |
+
metrics[key] * quality_weights[key]
|
| 443 |
+
for key in quality_weights
|
| 444 |
+
)
|
| 445 |
+
|
| 446 |
+
metrics["overall_quality"] = overall_quality
|
| 447 |
+
|
| 448 |
+
return metrics
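A worked example of the weighted overall score computed above (metric values are illustrative):

metrics = {"logical_flow": 0.67, "technical_depth": 0.40,
           "citation_usage": 1.00, "completeness": 0.67}
weights = {"logical_flow": 0.3, "technical_depth": 0.3,
           "citation_usage": 0.2, "completeness": 0.2}
overall = sum(metrics[k] * weights[k] for k in weights)
print(round(overall, 3))  # 0.3*0.67 + 0.3*0.40 + 0.2*1.00 + 0.2*0.67 = 0.655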
|
| 449 |
+
|
| 450 |
+
def _evaluate_logical_flow(self, steps: List[Dict[str, str]]) -> float:
|
| 451 |
+
"""Evaluate logical flow between reasoning steps."""
|
| 452 |
+
if len(steps) < 2:
|
| 453 |
+
return 0.5
|
| 454 |
+
|
| 455 |
+
        # Check for logical connectors
        connectors = ["therefore", "thus", "because", "since", "as a result", "consequently"]
        connector_count = 0

        for step in steps:
            step_text = step["step_text"].lower()
            if any(connector in step_text for connector in connectors):
                connector_count += 1

        return min(connector_count / len(steps), 1.0)

    def _evaluate_technical_depth(self, steps: List[Dict[str, str]]) -> float:
        """Evaluate technical depth of reasoning."""
        technical_terms = [
            "implementation", "architecture", "algorithm", "protocol", "specification",
            "optimization", "configuration", "register", "memory", "hardware",
            "software", "system", "component", "module", "interface"
        ]

        total_terms = 0
        total_words = 0

        for step in steps:
            words = step["step_text"].lower().split()
            total_words += len(words)

            for term in technical_terms:
                total_terms += words.count(term)

        return min(total_terms / max(total_words, 1) * 100, 1.0)

    def _evaluate_citation_usage(self, steps: List[Dict[str, str]]) -> float:
        """Evaluate citation usage in reasoning."""
        citation_pattern = r'\[chunk_\d+\]'
        total_citations = 0

        for step in steps:
            citations = re.findall(citation_pattern, step["step_text"])
            total_citations += len(citations)

        # Good reasoning should have at least one citation per step
        return min(total_citations / len(steps), 1.0)

    def _evaluate_completeness(self, steps: List[Dict[str, str]]) -> float:
        """Evaluate completeness of reasoning."""
        # Check for key reasoning elements
        completeness_indicators = [
            "analysis", "consider", "examine", "evaluate",
            "conclusion", "summary", "result", "therefore",
            "requirement", "constraint", "limitation", "important"
        ]

        indicator_count = 0
        for step in steps:
            step_text = step["step_text"].lower()
            for indicator in completeness_indicators:
                if indicator in step_text:
                    indicator_count += 1
                    break

        return indicator_count / len(steps)


# Example usage
if __name__ == "__main__":
    # Initialize engine
    cot_engine = ChainOfThoughtEngine()

    # Example implementation query
    query = "How do I implement a real-time task scheduler in FreeRTOS with priority inheritance?"
    query_type = QueryType.IMPLEMENTATION
    context = "FreeRTOS supports priority-based scheduling with optional priority inheritance..."

    # Create a basic template
    base_template = PromptTemplate(
        system_prompt="You are a technical assistant.",
        context_format="Context: {context}",
        query_format="Question: {query}",
        answer_guidelines="Provide a structured answer."
    )

    # Generate chain-of-thought prompt
    cot_prompt = cot_engine.generate_chain_of_thought_prompt(
        query=query,
        query_type=query_type,
        context=context,
        base_template=base_template
    )

    print("Chain-of-Thought Enhanced Prompt:")
    print("=" * 50)
    print("System:", cot_prompt["system"][:200], "...")
    print("User:", cot_prompt["user"][:300], "...")
    print("=" * 50)

    # Example reasoning evaluation
    example_response = """
    Step 1: Let me analyze the requirements
    FreeRTOS provides priority-based scheduling [chunk_1]...

    Step 2: Breaking down the implementation
    Priority inheritance requires mutex implementation [chunk_2]...

    Step 3: Synthesizing the solution
    Therefore, we need to configure priority inheritance in FreeRTOS [chunk_3]...
    """

    steps = cot_engine.extract_reasoning_steps(example_response)
    quality = cot_engine.evaluate_reasoning_quality(steps)

    print(f"Reasoning Quality: {quality}")
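For quick reference, here is a minimal editorial check of the citation-usage heuristic above. It is not part of the commit; it assumes the module is importable as shared_utils.generation.chain_of_thought_engine and relies only on the {"step_text": ...} dict shape used throughout this class.

# Editorial sketch (not in the original commit): exercise _evaluate_citation_usage.
from shared_utils.generation.chain_of_thought_engine import ChainOfThoughtEngine  # assumed import path

engine = ChainOfThoughtEngine()
sample_steps = [
    {"step_text": "FreeRTOS provides priority-based scheduling [chunk_1]."},
    {"step_text": "Priority inheritance requires a mutex [chunk_2] [chunk_3]."},
    {"step_text": "Therefore the scheduler must be configured accordingly."},
]

# Three citations across three steps -> min(3 / 3, 1.0) == 1.0
assert engine._evaluate_citation_usage(sample_steps) == 1.0
# No citations at all -> 0.0
assert engine._evaluate_citation_usage([{"step_text": "No sources cited."}]) == 0.0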
shared_utils/generation/hf_answer_generator.py
ADDED
|
@@ -0,0 +1,881 @@
"""
HuggingFace API-based answer generation for deployment environments.

This module provides answer generation using HuggingFace's Inference API,
optimized for cloud deployment where local LLMs aren't feasible.
"""

import json
import logging
from dataclasses import dataclass
from typing import List, Dict, Any, Optional, Generator, Tuple
from datetime import datetime
import re
from pathlib import Path
import requests
import os
import sys

# Import technical prompt templates
from .prompt_templates import TechnicalPromptTemplates

# Import standard interfaces (add this for the adapter)
try:
    from pathlib import Path
    import sys
    project_root = Path(__file__).parent.parent.parent.parent.parent
    sys.path.append(str(project_root))
    from src.core.interfaces import Document, Answer, AnswerGenerator
except ImportError:
    # Fallback for standalone usage
    Document = None
    Answer = None
    AnswerGenerator = object

logger = logging.getLogger(__name__)


@dataclass
class Citation:
    """Represents a citation to a source document chunk."""
    chunk_id: str
    page_number: int
    source_file: str
    relevance_score: float
    text_snippet: str


@dataclass
class GeneratedAnswer:
    """Represents a generated answer with citations."""
    answer: str
    citations: List[Citation]
    confidence_score: float
    generation_time: float
    model_used: str
    context_used: List[Dict[str, Any]]

|
| 60 |
+
"""
|
| 61 |
+
Generates answers using HuggingFace Inference API with hybrid reliability.
|
| 62 |
+
|
| 63 |
+
🎯 HYBRID APPROACH - Best of Both Worlds:
|
| 64 |
+
- Primary: High-quality open models (Zephyr-7B, Mistral-7B-Instruct)
|
| 65 |
+
- Fallback: Reliable classics (DialoGPT-medium)
|
| 66 |
+
- Foundation: HF GPT's proven Docker + auth setup
|
| 67 |
+
- Pro Benefits: Better rate limits, priority processing
|
| 68 |
+
|
| 69 |
+
Optimized for deployment environments with:
|
| 70 |
+
- Fast API-based inference
|
| 71 |
+
- No local model requirements
|
| 72 |
+
- Citation extraction and formatting
|
| 73 |
+
- Confidence scoring
|
| 74 |
+
- Automatic fallback for reliability
|
| 75 |
+
"""
|
| 76 |
+
|
| 77 |
+
def __init__(
|
| 78 |
+
self,
|
| 79 |
+
model_name: str = "sshleifer/distilbart-cnn-12-6",
|
| 80 |
+
api_token: Optional[str] = None,
|
| 81 |
+
temperature: float = 0.3,
|
| 82 |
+
max_tokens: int = 512
|
| 83 |
+
):
|
| 84 |
+
"""
|
| 85 |
+
Initialize the HuggingFace answer generator.
|
| 86 |
+
|
| 87 |
+
Args:
|
| 88 |
+
model_name: HuggingFace model to use
|
| 89 |
+
api_token: HF API token (optional, uses free tier if None)
|
| 90 |
+
temperature: Generation temperature (0.0-1.0)
|
| 91 |
+
max_tokens: Maximum tokens to generate
|
| 92 |
+
"""
|
| 93 |
+
self.model_name = model_name
|
| 94 |
+
# Try multiple common token environment variable names
|
| 95 |
+
self.api_token = (api_token or
|
| 96 |
+
os.getenv("HUGGINGFACE_API_TOKEN") or
|
| 97 |
+
os.getenv("HF_TOKEN") or
|
| 98 |
+
os.getenv("HF_API_TOKEN"))
|
| 99 |
+
self.temperature = temperature
|
| 100 |
+
self.max_tokens = max_tokens
|
| 101 |
+
|
| 102 |
+
# Hybrid approach: Classic API + fallback models
|
| 103 |
+
self.api_url = f"https://api-inference.huggingface.co/models/{model_name}"
|
| 104 |
+
|
| 105 |
+
# Prepare headers
|
| 106 |
+
self.headers = {"Content-Type": "application/json"}
|
| 107 |
+
self._auth_failed = False # Track if auth has failed
|
| 108 |
+
if self.api_token:
|
| 109 |
+
self.headers["Authorization"] = f"Bearer {self.api_token}"
|
| 110 |
+
logger.info("Using authenticated HuggingFace API")
|
| 111 |
+
else:
|
| 112 |
+
logger.info("Using free HuggingFace API (rate limited)")
|
| 113 |
+
|
| 114 |
+
# Only include models that actually work based on tests
|
| 115 |
+
self.fallback_models = [
|
| 116 |
+
"deepset/roberta-base-squad2", # Q&A model - perfect for RAG
|
| 117 |
+
"sshleifer/distilbart-cnn-12-6", # Summarization - also good
|
| 118 |
+
"facebook/bart-base", # Base BART - works but needs right format
|
| 119 |
+
]
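For orientation, a minimal construction sketch follows (editorial, not part of the commit). The keyword arguments mirror the signature above, the model name is simply the class default, and the import path is assumed.

# Editorial sketch: constructing the generator with an optional token.
import os
from shared_utils.generation.hf_answer_generator import HuggingFaceAnswerGenerator  # assumed import path

generator = HuggingFaceAnswerGenerator(
    model_name="sshleifer/distilbart-cnn-12-6",  # class default, shown explicitly
    api_token=os.getenv("HF_TOKEN"),             # None falls back to the rate-limited free tier
    temperature=0.3,
    max_tokens=512,
)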
|
| 120 |
+
|
| 121 |
+
def _make_api_request(self, url: str, payload: dict, timeout: int = 30) -> requests.Response:
|
| 122 |
+
"""Make API request with automatic 401 handling."""
|
| 123 |
+
# Use current headers (may have been updated if auth failed)
|
| 124 |
+
headers = self.headers.copy()
|
| 125 |
+
|
| 126 |
+
# If we've already had auth failure, don't include the token
|
| 127 |
+
if self._auth_failed and "Authorization" in headers:
|
| 128 |
+
del headers["Authorization"]
|
| 129 |
+
|
| 130 |
+
response = requests.post(url, headers=headers, json=payload, timeout=timeout)
|
| 131 |
+
|
| 132 |
+
# Handle 401 error
|
| 133 |
+
if response.status_code == 401 and not self._auth_failed and self.api_token:
|
| 134 |
+
logger.error(f"API request failed: 401 Unauthorized")
|
| 135 |
+
logger.error(f"Response body: {response.text}")
|
| 136 |
+
logger.warning("Token appears invalid, retrying without authentication...")
|
| 137 |
+
self._auth_failed = True
|
| 138 |
+
# Remove auth header
|
| 139 |
+
if "Authorization" in self.headers:
|
| 140 |
+
del self.headers["Authorization"]
|
| 141 |
+
headers = self.headers.copy()
|
| 142 |
+
# Retry without auth
|
| 143 |
+
response = requests.post(url, headers=headers, json=payload, timeout=timeout)
|
| 144 |
+
if response.status_code == 401:
|
| 145 |
+
logger.error("Still getting 401 even without auth token")
|
| 146 |
+
logger.error(f"Response body: {response.text}")
|
| 147 |
+
|
| 148 |
+
return response
|
| 149 |
+
|
| 150 |
+
def _call_api_with_model(self, prompt: str, model_name: str) -> str:
|
| 151 |
+
"""Call API with a specific model (for fallback support)."""
|
| 152 |
+
fallback_url = f"https://api-inference.huggingface.co/models/{model_name}"
|
| 153 |
+
|
| 154 |
+
# SIMPLIFIED payload that works
|
| 155 |
+
payload = {"inputs": prompt}
|
| 156 |
+
|
| 157 |
+
# Use helper method with 401 handling
|
| 158 |
+
response = self._make_api_request(fallback_url, payload)
|
| 159 |
+
|
| 160 |
+
response.raise_for_status()
|
| 161 |
+
result = response.json()
|
| 162 |
+
|
| 163 |
+
# Handle response
|
| 164 |
+
if isinstance(result, list) and len(result) > 0:
|
| 165 |
+
if isinstance(result[0], dict):
|
| 166 |
+
return result[0].get("generated_text", "").strip()
|
| 167 |
+
else:
|
| 168 |
+
return str(result[0]).strip()
|
| 169 |
+
elif isinstance(result, dict):
|
| 170 |
+
return result.get("generated_text", "").strip()
|
| 171 |
+
else:
|
| 172 |
+
return str(result).strip()
|
| 173 |
+
|
| 174 |
+
def _create_system_prompt(self) -> str:
|
| 175 |
+
"""Create system prompt optimized for the model type."""
|
| 176 |
+
if "squad" in self.model_name.lower() or "roberta" in self.model_name.lower():
|
| 177 |
+
# RoBERTa Squad2 uses question/context format - no system prompt needed
|
| 178 |
+
return ""
|
| 179 |
+
elif "gpt2" in self.model_name.lower() or "distilgpt2" in self.model_name.lower():
|
| 180 |
+
# GPT-2 style completion prompt - simpler is better
|
| 181 |
+
return "Based on the following context, answer the question.\n\nContext: "
|
| 182 |
+
elif "llama" in self.model_name.lower():
|
| 183 |
+
# Llama-2 chat format
|
| 184 |
+
return """<s>[INST] You are a helpful technical documentation assistant. Answer the user's question based only on the provided context. Always cite sources using [chunk_X] format.
|
| 185 |
+
|
| 186 |
+
Context:"""
|
| 187 |
+
elif "flan" in self.model_name.lower() or "t5" in self.model_name.lower():
|
| 188 |
+
# Flan-T5 instruction format - simple and direct
|
| 189 |
+
return """Answer the question based on the context below. Cite sources using [chunk_X] format.
|
| 190 |
+
|
| 191 |
+
Context: """
|
| 192 |
+
elif "falcon" in self.model_name.lower():
|
| 193 |
+
# Falcon instruction format
|
| 194 |
+
return """### Instruction: Answer based on the context and cite sources with [chunk_X].
|
| 195 |
+
|
| 196 |
+
### Context: """
|
| 197 |
+
elif "bart" in self.model_name.lower():
|
| 198 |
+
# BART summarization format
|
| 199 |
+
return """Summarize the answer to the question from the context. Use [chunk_X] for citations.
|
| 200 |
+
|
| 201 |
+
Context: """
|
| 202 |
+
else:
|
| 203 |
+
# Default instruction prompt for other models
|
| 204 |
+
return """You are a technical documentation assistant that provides clear, accurate answers based on the provided context.
|
| 205 |
+
|
| 206 |
+
CORE PRINCIPLES:
|
| 207 |
+
1. ANSWER DIRECTLY: If context contains the answer, provide it clearly and confidently
|
| 208 |
+
2. BE CONCISE: Keep responses focused and avoid unnecessary uncertainty language
|
| 209 |
+
3. CITE ACCURATELY: Use [chunk_X] citations for every fact from context
|
| 210 |
+
|
| 211 |
+
RESPONSE GUIDELINES:
|
| 212 |
+
- If context has sufficient information → Answer directly and confidently
|
| 213 |
+
- If context has partial information → Answer what's available, note what's missing briefly
|
| 214 |
+
- If context is irrelevant → Brief refusal: "This information isn't available in the provided documents"
|
| 215 |
+
|
| 216 |
+
CITATION FORMAT:
|
| 217 |
+
- Use [chunk_1], [chunk_2] etc. for all facts from context
|
| 218 |
+
- Example: "According to [chunk_1], RISC-V is an open-source architecture."
|
| 219 |
+
|
| 220 |
+
Be direct, confident, and accurate. If the context answers the question, provide that answer clearly."""
|
| 221 |
+
|
| 222 |
+
def _format_context(self, chunks: List[Dict[str, Any]]) -> str:
|
| 223 |
+
"""
|
| 224 |
+
Format retrieved chunks into context for the LLM.
|
| 225 |
+
|
| 226 |
+
Args:
|
| 227 |
+
chunks: List of retrieved chunks with metadata
|
| 228 |
+
|
| 229 |
+
Returns:
|
| 230 |
+
Formatted context string
|
| 231 |
+
"""
|
| 232 |
+
context_parts = []
|
| 233 |
+
|
| 234 |
+
for i, chunk in enumerate(chunks):
|
| 235 |
+
chunk_text = chunk.get('content', chunk.get('text', ''))
|
| 236 |
+
page_num = chunk.get('metadata', {}).get('page_number', 'unknown')
|
| 237 |
+
source = chunk.get('metadata', {}).get('source', 'unknown')
|
| 238 |
+
|
| 239 |
+
context_parts.append(
|
| 240 |
+
f"[chunk_{i+1}] (Page {page_num} from {source}):\n{chunk_text}\n"
|
| 241 |
+
)
|
| 242 |
+
|
| 243 |
+
return "\n---\n".join(context_parts)
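To make the context layout concrete, here is a small editorial example of what _format_context returns for two hypothetical chunks; the chunk contents and file names are illustrative only.

# Editorial sketch: illustrative chunks and the resulting context string.
generator = HuggingFaceAnswerGenerator()  # constructed as in the sketch after __init__
chunks = [
    {"content": "RV32I contains 40 unique instructions.",
     "metadata": {"page_number": 12, "source": "riscv-spec.pdf"}},
    {"content": "CSRs are accessed via the Zicsr extension.",
     "metadata": {"page_number": 3, "source": "riscv-priv-spec.pdf"}},
]
print(generator._format_context(chunks))
# [chunk_1] (Page 12 from riscv-spec.pdf):
# RV32I contains 40 unique instructions.
#
# ---
# [chunk_2] (Page 3 from riscv-priv-spec.pdf):
# CSRs are accessed via the Zicsr extension.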
|
| 244 |
+
|
| 245 |
+
def _call_api(self, prompt: str) -> str:
|
| 246 |
+
"""
|
| 247 |
+
Call HuggingFace Inference API.
|
| 248 |
+
|
| 249 |
+
Args:
|
| 250 |
+
prompt: Input prompt for the model
|
| 251 |
+
|
| 252 |
+
Returns:
|
| 253 |
+
Generated text response
|
| 254 |
+
"""
|
| 255 |
+
# Validate prompt
|
| 256 |
+
if not prompt or len(prompt.strip()) < 5:
|
| 257 |
+
logger.warning(f"Prompt too short: '{prompt}' - padding it")
|
| 258 |
+
prompt = f"Please provide information about: {prompt}. Based on the context, give a detailed answer."
|
| 259 |
+
|
| 260 |
+
# Model-specific payload formatting
|
| 261 |
+
if "squad" in self.model_name.lower() or "roberta" in self.model_name.lower():
|
| 262 |
+
# RoBERTa Squad2 needs question and context separately
|
| 263 |
+
# Parse the structured prompt format we create
|
| 264 |
+
if "Context:" in prompt and "Question:" in prompt:
|
| 265 |
+
# Split by the markers we use
|
| 266 |
+
parts = prompt.split("Question:")
|
| 267 |
+
if len(parts) == 2:
|
| 268 |
+
context_part = parts[0].replace("Context:", "").strip()
|
| 269 |
+
question_part = parts[1].strip()
|
| 270 |
+
else:
|
| 271 |
+
# Fallback
|
| 272 |
+
question_part = "What is this about?"
|
| 273 |
+
context_part = prompt
|
| 274 |
+
else:
|
| 275 |
+
# Fallback for unexpected format
|
| 276 |
+
question_part = "What is this about?"
|
| 277 |
+
context_part = prompt
|
| 278 |
+
|
| 279 |
+
# Clean up the context and question
|
| 280 |
+
context_part = context_part.replace("---", "").strip()
|
| 281 |
+
if not question_part or len(question_part.strip()) < 3:
|
| 282 |
+
question_part = "What is the main information?"
|
| 283 |
+
|
| 284 |
+
# Debug output
|
| 285 |
+
print(f"🔍 Squad2 Question: {question_part[:100]}...")
|
| 286 |
+
print(f"🔍 Squad2 Context: {context_part[:200]}...")
|
| 287 |
+
|
| 288 |
+
payload = {
|
| 289 |
+
"inputs": {
|
| 290 |
+
"question": question_part,
|
| 291 |
+
"context": context_part
|
| 292 |
+
}
|
| 293 |
+
}
|
| 294 |
+
elif "bart" in self.model_name.lower() or "distilbart" in self.model_name.lower():
|
| 295 |
+
# BART/DistilBART for summarization
|
| 296 |
+
if len(prompt) < 50:
|
| 297 |
+
prompt = f"{prompt} Please provide a comprehensive answer based on the available information."
|
| 298 |
+
|
| 299 |
+
payload = {
|
| 300 |
+
"inputs": prompt,
|
| 301 |
+
"parameters": {
|
| 302 |
+
"max_length": 150,
|
| 303 |
+
"min_length": 10,
|
| 304 |
+
"do_sample": False
|
| 305 |
+
}
|
| 306 |
+
}
|
| 307 |
+
else:
|
| 308 |
+
# Simple payload for other models
|
| 309 |
+
payload = {"inputs": prompt}
|
| 310 |
+
|
| 311 |
+
try:
|
| 312 |
+
logger.info(f"Calling API URL: {self.api_url}")
|
| 313 |
+
logger.info(f"Headers: {self.headers}")
|
| 314 |
+
logger.info(f"Payload: {payload}")
|
| 315 |
+
|
| 316 |
+
# Use helper method with 401 handling
|
| 317 |
+
response = self._make_api_request(self.api_url, payload)
|
| 318 |
+
|
| 319 |
+
logger.info(f"Response status: {response.status_code}")
|
| 320 |
+
logger.info(f"Response headers: {response.headers}")
|
| 321 |
+
|
| 322 |
+
if response.status_code == 503:
|
| 323 |
+
# Model is loading, wait and retry
|
| 324 |
+
logger.warning("Model loading, waiting 20 seconds...")
|
| 325 |
+
import time
|
| 326 |
+
time.sleep(20)
|
| 327 |
+
response = self._make_api_request(self.api_url, payload)
|
| 328 |
+
logger.info(f"Retry response status: {response.status_code}")
|
| 329 |
+
|
| 330 |
+
elif response.status_code == 404:
|
| 331 |
+
logger.error(f"Model not found: {self.model_name}")
|
| 332 |
+
logger.error(f"Response text: {response.text}")
|
| 333 |
+
# Try fallback models
|
| 334 |
+
for fallback_model in self.fallback_models:
|
| 335 |
+
if fallback_model != self.model_name:
|
| 336 |
+
logger.info(f"Trying fallback model: {fallback_model}")
|
| 337 |
+
try:
|
| 338 |
+
return self._call_api_with_model(prompt, fallback_model)
|
| 339 |
+
except Exception as e:
|
| 340 |
+
logger.warning(f"Fallback model {fallback_model} failed: {e}")
|
| 341 |
+
continue
|
| 342 |
+
return "All models are currently unavailable. Please try again later."
|
| 343 |
+
|
| 344 |
+
response.raise_for_status()
|
| 345 |
+
result = response.json()
|
| 346 |
+
|
| 347 |
+
# Handle different response formats based on model type
|
| 348 |
+
print(f"🔍 API Response type: {type(result)}")
|
| 349 |
+
print(f"🔍 API Response preview: {str(result)[:300]}...")
|
| 350 |
+
|
| 351 |
+
if isinstance(result, dict) and "answer" in result:
|
| 352 |
+
# RoBERTa Squad2 format: {"answer": "...", "score": ..., "start": ..., "end": ...}
|
| 353 |
+
answer = result["answer"].strip()
|
| 354 |
+
print(f"🔍 Squad2 extracted answer: {answer}")
|
| 355 |
+
return answer
|
| 356 |
+
elif isinstance(result, list) and len(result) > 0:
|
| 357 |
+
# Check for DistilBART format (returns dict with summary_text)
|
| 358 |
+
if isinstance(result[0], dict) and "summary_text" in result[0]:
|
| 359 |
+
return result[0]["summary_text"].strip()
|
| 360 |
+
# Check for nested list (BART format: [[...]])
|
| 361 |
+
elif isinstance(result[0], list) and len(result[0]) > 0:
|
| 362 |
+
if isinstance(result[0][0], dict):
|
| 363 |
+
return result[0][0].get("summary_text", str(result[0][0])).strip()
|
| 364 |
+
else:
|
| 365 |
+
# BART base returns embeddings - not useful for text generation
|
| 366 |
+
logger.warning("BART returned embeddings instead of text")
|
| 367 |
+
return "Model returned embeddings instead of text. Please try a different model."
|
| 368 |
+
# Regular list format
|
| 369 |
+
elif isinstance(result[0], dict):
|
| 370 |
+
# Try different keys that models might use
|
| 371 |
+
text = (result[0].get("generated_text", "") or
|
| 372 |
+
result[0].get("summary_text", "") or
|
| 373 |
+
result[0].get("translation_text", "") or
|
| 374 |
+
result[0].get("answer", "") or
|
| 375 |
+
str(result[0]))
|
| 376 |
+
# Remove the input prompt from the output if present
|
| 377 |
+
if isinstance(prompt, str) and text.startswith(prompt):
|
| 378 |
+
text = text[len(prompt):].strip()
|
| 379 |
+
return text
|
| 380 |
+
else:
|
| 381 |
+
return str(result[0]).strip()
|
| 382 |
+
elif isinstance(result, dict):
|
| 383 |
+
# Some models return dict directly
|
| 384 |
+
text = (result.get("generated_text", "") or
|
| 385 |
+
result.get("summary_text", "") or
|
| 386 |
+
result.get("translation_text", "") or
|
| 387 |
+
result.get("answer", "") or
|
| 388 |
+
str(result))
|
| 389 |
+
# Remove input prompt if model included it
|
| 390 |
+
if isinstance(prompt, str) and text.startswith(prompt):
|
| 391 |
+
text = text[len(prompt):].strip()
|
| 392 |
+
return text
|
| 393 |
+
elif isinstance(result, str):
|
| 394 |
+
return result.strip()
|
| 395 |
+
else:
|
| 396 |
+
logger.error(f"Unexpected response format: {type(result)} - {result}")
|
| 397 |
+
return "I apologize, but I couldn't generate a response."
|
| 398 |
+
|
| 399 |
+
except requests.exceptions.RequestException as e:
|
| 400 |
+
logger.error(f"API request failed: {e}")
|
| 401 |
+
if hasattr(e, 'response') and e.response is not None:
|
| 402 |
+
logger.error(f"Response status: {e.response.status_code}")
|
| 403 |
+
logger.error(f"Response body: {e.response.text}")
|
| 404 |
+
return f"API Error: {str(e)}. Using free tier? Try adding an API token."
|
| 405 |
+
except Exception as e:
|
| 406 |
+
logger.error(f"Unexpected error: {e}")
|
| 407 |
+
import traceback
|
| 408 |
+
logger.error(f"Traceback: {traceback.format_exc()}")
|
| 409 |
+
return f"Error: {str(e)}. Please check logs for details."
|
| 410 |
+
|
| 411 |
+
def _extract_citations(self, answer: str, chunks: List[Dict[str, Any]]) -> Tuple[str, List[Citation]]:
|
| 412 |
+
"""
|
| 413 |
+
Extract citations from the generated answer and integrate them naturally.
|
| 414 |
+
|
| 415 |
+
Args:
|
| 416 |
+
answer: Generated answer with [chunk_X] citations
|
| 417 |
+
chunks: Original chunks used for context
|
| 418 |
+
|
| 419 |
+
Returns:
|
| 420 |
+
Tuple of (natural_answer, citations)
|
| 421 |
+
"""
|
| 422 |
+
citations = []
|
| 423 |
+
citation_pattern = r'\[chunk_(\d+)\]'
|
| 424 |
+
|
| 425 |
+
cited_chunks = set()
|
| 426 |
+
|
| 427 |
+
# Find [chunk_X] citations and collect cited chunks
|
| 428 |
+
matches = re.finditer(citation_pattern, answer)
|
| 429 |
+
for match in matches:
|
| 430 |
+
chunk_idx = int(match.group(1)) - 1 # Convert to 0-based index
|
| 431 |
+
if 0 <= chunk_idx < len(chunks):
|
| 432 |
+
cited_chunks.add(chunk_idx)
|
| 433 |
+
|
| 434 |
+
# FALLBACK: If no explicit citations found but we have an answer and chunks,
|
| 435 |
+
# create citations for the top chunks that were likely used
|
| 436 |
+
if not cited_chunks and chunks and len(answer.strip()) > 50:
|
| 437 |
+
# Use the top chunks that were provided as likely sources
|
| 438 |
+
num_fallback_citations = min(3, len(chunks)) # Use top 3 chunks max
|
| 439 |
+
cited_chunks = set(range(num_fallback_citations))
|
| 440 |
+
print(f"🔧 HF Fallback: Creating {num_fallback_citations} citations for answer without explicit [chunk_X] references", file=sys.stderr, flush=True)
|
| 441 |
+
|
| 442 |
+
# Create Citation objects for each cited chunk
|
| 443 |
+
chunk_to_source = {}
|
| 444 |
+
for idx in cited_chunks:
|
| 445 |
+
chunk = chunks[idx]
|
| 446 |
+
citation = Citation(
|
| 447 |
+
chunk_id=chunk.get('id', f'chunk_{idx}'),
|
| 448 |
+
page_number=chunk.get('metadata', {}).get('page_number', 0),
|
| 449 |
+
source_file=chunk.get('metadata', {}).get('source', 'unknown'),
|
| 450 |
+
relevance_score=chunk.get('score', 0.0),
|
| 451 |
+
text_snippet=chunk.get('content', chunk.get('text', ''))[:200] + '...'
|
| 452 |
+
)
|
| 453 |
+
citations.append(citation)
|
| 454 |
+
|
| 455 |
+
# Map chunk reference to natural source name
|
| 456 |
+
source_name = chunk.get('metadata', {}).get('source', 'unknown')
|
| 457 |
+
if source_name != 'unknown':
|
| 458 |
+
# Use just the filename without extension for natural reference
|
| 459 |
+
natural_name = Path(source_name).stem.replace('-', ' ').replace('_', ' ')
|
| 460 |
+
chunk_to_source[f'[chunk_{idx+1}]'] = f"the {natural_name} documentation"
|
| 461 |
+
else:
|
| 462 |
+
chunk_to_source[f'[chunk_{idx+1}]'] = "the documentation"
|
| 463 |
+
|
| 464 |
+
# Replace [chunk_X] with natural references instead of removing them
|
| 465 |
+
natural_answer = answer
|
| 466 |
+
for chunk_ref, natural_ref in chunk_to_source.items():
|
| 467 |
+
natural_answer = natural_answer.replace(chunk_ref, natural_ref)
|
| 468 |
+
|
| 469 |
+
# Clean up any remaining unreferenced citations (fallback)
|
| 470 |
+
natural_answer = re.sub(r'\[chunk_\d+\]', 'the documentation', natural_answer)
|
| 471 |
+
|
| 472 |
+
# Clean up multiple spaces and formatting
|
| 473 |
+
natural_answer = re.sub(r'\s+', ' ', natural_answer).strip()
|
| 474 |
+
|
| 475 |
+
return natural_answer, citations
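An editorial illustration of the citation rewrite follows; the answer text, chunk content, and file name are hypothetical, and only the transformation itself comes from the code above.

# Editorial sketch: an explicit citation resolved to a natural name, an unknown one cleaned up.
answer = "RISC-V defines 32 integer registers [chunk_1]. See also [chunk_9]."
chunks = [{"id": "chunk_1",
           "content": "RISC-V has 32 integer registers x0-x31.",
           "metadata": {"page_number": 7, "source": "riscv-card.pdf"},
           "score": 0.91}]
natural, cites = generator._extract_citations(answer, chunks)  # generator constructed as above
# natural -> "RISC-V defines 32 integer registers the riscv card documentation. See also the documentation."
# cites   -> [Citation(chunk_id="chunk_1", page_number=7, source_file="riscv-card.pdf", relevance_score=0.91, ...)]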
|
| 476 |
+
|
| 477 |
+
def _calculate_confidence(self, answer: str, citations: List[Citation], chunks: List[Dict[str, Any]]) -> float:
|
| 478 |
+
"""
|
| 479 |
+
Calculate confidence score for the generated answer.
|
| 480 |
+
|
| 481 |
+
Args:
|
| 482 |
+
answer: Generated answer
|
| 483 |
+
citations: Extracted citations
|
| 484 |
+
chunks: Retrieved chunks
|
| 485 |
+
|
| 486 |
+
Returns:
|
| 487 |
+
Confidence score (0.0-1.0)
|
| 488 |
+
"""
|
| 489 |
+
if not chunks:
|
| 490 |
+
return 0.05 # No context = very low confidence
|
| 491 |
+
|
| 492 |
+
# Base confidence from context quality
|
| 493 |
+
scores = [chunk.get('score', 0) for chunk in chunks]
|
| 494 |
+
max_relevance = max(scores) if scores else 0
|
| 495 |
+
|
| 496 |
+
if max_relevance >= 0.8:
|
| 497 |
+
confidence = 0.7 # High-quality context
|
| 498 |
+
elif max_relevance >= 0.6:
|
| 499 |
+
confidence = 0.5 # Good context
|
| 500 |
+
elif max_relevance >= 0.4:
|
| 501 |
+
confidence = 0.3 # Fair context
|
| 502 |
+
else:
|
| 503 |
+
confidence = 0.1 # Poor context
|
| 504 |
+
|
| 505 |
+
# Uncertainty indicators
|
| 506 |
+
uncertainty_phrases = [
|
| 507 |
+
"does not contain sufficient information",
|
| 508 |
+
"context does not provide",
|
| 509 |
+
"insufficient information",
|
| 510 |
+
"cannot determine",
|
| 511 |
+
"not available in the provided documents"
|
| 512 |
+
]
|
| 513 |
+
|
| 514 |
+
if any(phrase in answer.lower() for phrase in uncertainty_phrases):
|
| 515 |
+
return min(0.15, confidence * 0.3)
|
| 516 |
+
|
| 517 |
+
# Citation bonus
|
| 518 |
+
if citations and chunks:
|
| 519 |
+
citation_ratio = len(citations) / min(len(chunks), 3)
|
| 520 |
+
confidence += 0.2 * citation_ratio
|
| 521 |
+
|
| 522 |
+
return min(confidence, 0.9) # Cap at 90%
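A worked example with hypothetical numbers may help; it simply traces the branches implemented above.

# Worked example (hypothetical values):
#   chunk scores [0.91, 0.42]            -> max relevance 0.91 >= 0.8 -> base confidence 0.7
#   answer contains no uncertainty phrase, so the early low-confidence return is skipped
#   2 citations, min(len(chunks), 3) = 2 -> citation_ratio = 1.0 -> bonus 0.2 * 1.0 = 0.2
#   0.7 + 0.2 = 0.9, and the final cap keeps the result at 0.9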
|
| 523 |
+
|
| 524 |
+
def generate(self, query: str, context: List[Document]) -> Answer:
|
| 525 |
+
"""
|
| 526 |
+
Generate an answer from query and context documents (standard interface).
|
| 527 |
+
|
| 528 |
+
This is the public interface that conforms to the AnswerGenerator protocol.
|
| 529 |
+
It handles the conversion between standard Document objects and HuggingFace's
|
| 530 |
+
internal chunk format.
|
| 531 |
+
|
| 532 |
+
Args:
|
| 533 |
+
query: User's question
|
| 534 |
+
context: List of relevant Document objects
|
| 535 |
+
|
| 536 |
+
Returns:
|
| 537 |
+
Answer object conforming to standard interface
|
| 538 |
+
|
| 539 |
+
Raises:
|
| 540 |
+
ValueError: If query is empty or context is None
|
| 541 |
+
"""
|
| 542 |
+
if not query.strip():
|
| 543 |
+
raise ValueError("Query cannot be empty")
|
| 544 |
+
|
| 545 |
+
if context is None:
|
| 546 |
+
raise ValueError("Context cannot be None")
|
| 547 |
+
|
| 548 |
+
# Internal adapter: Convert Documents to HuggingFace chunk format
|
| 549 |
+
hf_chunks = self._documents_to_hf_chunks(context)
|
| 550 |
+
|
| 551 |
+
# Use existing HuggingFace-specific generation logic
|
| 552 |
+
hf_result = self._generate_internal(query, hf_chunks)
|
| 553 |
+
|
| 554 |
+
# Internal adapter: Convert HuggingFace result to standard Answer
|
| 555 |
+
return self._hf_result_to_answer(hf_result, context)
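For completeness, an editorial sketch of calling this standard-interface entry point; it assumes Document from src.core.interfaces can be built from content and metadata and exposes the fields used by _documents_to_hf_chunks(), and that Answer exposes text and confidence as set in _hf_result_to_answer().

# Editorial sketch: the Document-based path (assumed Document/Answer shapes).
docs = [
    Document(content="RISC-V is an open-source instruction set architecture.",
             metadata={"source": "riscv-spec.pdf", "start_page": 1}),
]
answer = generator.generate("What is RISC-V?", docs)  # generator constructed as above
print(answer.text)
print(f"confidence: {answer.confidence:.2f}")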
|
| 556 |
+
|
| 557 |
+
def _generate_internal(
|
| 558 |
+
self,
|
| 559 |
+
query: str,
|
| 560 |
+
chunks: List[Dict[str, Any]]
|
| 561 |
+
) -> GeneratedAnswer:
|
| 562 |
+
"""
|
| 563 |
+
Generate an answer based on the query and retrieved chunks.
|
| 564 |
+
|
| 565 |
+
Args:
|
| 566 |
+
query: User's question
|
| 567 |
+
chunks: Retrieved document chunks
|
| 568 |
+
|
| 569 |
+
Returns:
|
| 570 |
+
GeneratedAnswer object with answer, citations, and metadata
|
| 571 |
+
"""
|
| 572 |
+
start_time = datetime.now()
|
| 573 |
+
|
| 574 |
+
# Check for no-context situation
|
| 575 |
+
if not chunks or all(len(chunk.get('content', chunk.get('text', ''))) < 20 for chunk in chunks):
|
| 576 |
+
return GeneratedAnswer(
|
| 577 |
+
answer="This information isn't available in the provided documents.",
|
| 578 |
+
citations=[],
|
| 579 |
+
confidence_score=0.05,
|
| 580 |
+
generation_time=0.1,
|
| 581 |
+
model_used=self.model_name,
|
| 582 |
+
context_used=chunks
|
| 583 |
+
)
|
| 584 |
+
|
| 585 |
+
# Format context from chunks
|
| 586 |
+
context = self._format_context(chunks)
|
| 587 |
+
|
| 588 |
+
# Create prompt using TechnicalPromptTemplates for consistency
|
| 589 |
+
prompt_data = TechnicalPromptTemplates.format_prompt_with_template(
|
| 590 |
+
query=query,
|
| 591 |
+
context=context
|
| 592 |
+
)
|
| 593 |
+
|
| 594 |
+
# Format for specific model types
|
| 595 |
+
if "squad" in self.model_name.lower() or "roberta" in self.model_name.lower():
|
| 596 |
+
# Squad2 uses special question/context format - handled in _call_api
|
| 597 |
+
prompt = f"Context: {context}\n\nQuestion: {query}"
|
| 598 |
+
elif "gpt2" in self.model_name.lower() or "distilgpt2" in self.model_name.lower():
|
| 599 |
+
# Simple completion style for GPT-2
|
| 600 |
+
prompt = f"""{prompt_data['system']}
|
| 601 |
+
|
| 602 |
+
{prompt_data['user']}
|
| 603 |
+
|
| 604 |
+
MANDATORY: Use [chunk_1], [chunk_2] etc. for all facts.
|
| 605 |
+
|
| 606 |
+
Answer:"""
|
| 607 |
+
elif "llama" in self.model_name.lower():
|
| 608 |
+
# Llama-2 chat format with technical templates
|
| 609 |
+
prompt = f"""[INST] {prompt_data['system']}
|
| 610 |
+
|
| 611 |
+
{prompt_data['user']}
|
| 612 |
+
|
| 613 |
+
MANDATORY: Use [chunk_1], [chunk_2] etc. for all facts. [/INST]"""
|
| 614 |
+
elif "mistral" in self.model_name.lower():
|
| 615 |
+
# Mistral instruction format with technical templates
|
| 616 |
+
prompt = f"""[INST] {prompt_data['system']}
|
| 617 |
+
|
| 618 |
+
{prompt_data['user']}
|
| 619 |
+
|
| 620 |
+
MANDATORY: Use [chunk_1], [chunk_2] etc. for all facts. [/INST]"""
|
| 621 |
+
elif "codellama" in self.model_name.lower():
|
| 622 |
+
# CodeLlama instruction format with technical templates
|
| 623 |
+
prompt = f"""[INST] {prompt_data['system']}
|
| 624 |
+
|
| 625 |
+
{prompt_data['user']}
|
| 626 |
+
|
| 627 |
+
MANDATORY: Use [chunk_1], [chunk_2] etc. for all facts. [/INST]"""
|
| 628 |
+
elif "distilbart" in self.model_name.lower():
|
| 629 |
+
# DistilBART is a summarization model - simpler prompt works better
|
| 630 |
+
prompt = f"""Technical Documentation Context:
|
| 631 |
+
{context}
|
| 632 |
+
|
| 633 |
+
Question: {query}
|
| 634 |
+
|
| 635 |
+
Instructions: Provide a technical answer using only the context above. Include source citations."""
|
| 636 |
+
else:
|
| 637 |
+
# Default instruction prompt with technical templates
|
| 638 |
+
prompt = f"""{prompt_data['system']}
|
| 639 |
+
|
| 640 |
+
{prompt_data['user']}
|
| 641 |
+
|
| 642 |
+
MANDATORY: Use [chunk_1], [chunk_2] etc. for all factual statements.
|
| 643 |
+
|
| 644 |
+
Answer:"""
|
| 645 |
+
|
| 646 |
+
# Generate response
|
| 647 |
+
try:
|
| 648 |
+
answer_with_citations = self._call_api(prompt)
|
| 649 |
+
|
| 650 |
+
# Extract and clean citations
|
| 651 |
+
clean_answer, citations = self._extract_citations(answer_with_citations, chunks)
|
| 652 |
+
|
| 653 |
+
# Calculate confidence
|
| 654 |
+
confidence = self._calculate_confidence(clean_answer, citations, chunks)
|
| 655 |
+
|
| 656 |
+
# Calculate generation time
|
| 657 |
+
generation_time = (datetime.now() - start_time).total_seconds()
|
| 658 |
+
|
| 659 |
+
return GeneratedAnswer(
|
| 660 |
+
answer=clean_answer,
|
| 661 |
+
citations=citations,
|
| 662 |
+
confidence_score=confidence,
|
| 663 |
+
generation_time=generation_time,
|
| 664 |
+
model_used=self.model_name,
|
| 665 |
+
context_used=chunks
|
| 666 |
+
)
|
| 667 |
+
|
| 668 |
+
except Exception as e:
|
| 669 |
+
logger.error(f"Error generating answer: {e}")
|
| 670 |
+
return GeneratedAnswer(
|
| 671 |
+
answer="I apologize, but I encountered an error while generating the answer. Please try again.",
|
| 672 |
+
citations=[],
|
| 673 |
+
confidence_score=0.0,
|
| 674 |
+
generation_time=0.0,
|
| 675 |
+
model_used=self.model_name,
|
| 676 |
+
context_used=chunks
|
| 677 |
+
)
|
| 678 |
+
|
| 679 |
+
def generate_with_custom_prompt(
|
| 680 |
+
self,
|
| 681 |
+
query: str,
|
| 682 |
+
chunks: List[Dict[str, Any]],
|
| 683 |
+
custom_prompt: Dict[str, str]
|
| 684 |
+
) -> GeneratedAnswer:
|
| 685 |
+
"""
|
| 686 |
+
Generate answer using a custom prompt (for adaptive prompting).
|
| 687 |
+
|
| 688 |
+
Args:
|
| 689 |
+
query: User's question
|
| 690 |
+
chunks: Retrieved context chunks
|
| 691 |
+
custom_prompt: Dict with 'system' and 'user' prompts
|
| 692 |
+
|
| 693 |
+
Returns:
|
| 694 |
+
GeneratedAnswer with custom prompt enhancement
|
| 695 |
+
"""
|
| 696 |
+
start_time = datetime.now()
|
| 697 |
+
|
| 698 |
+
# Format context
|
| 699 |
+
context = self._format_context(chunks)
|
| 700 |
+
|
| 701 |
+
# Build prompt using custom format
|
| 702 |
+
if "llama" in self.model_name.lower():
|
| 703 |
+
prompt = f"""[INST] {custom_prompt['system']}
|
| 704 |
+
|
| 705 |
+
{custom_prompt['user']}
|
| 706 |
+
|
| 707 |
+
MANDATORY: Use [chunk_1], [chunk_2] etc. for all facts. [/INST]"""
|
| 708 |
+
elif "mistral" in self.model_name.lower():
|
| 709 |
+
prompt = f"""[INST] {custom_prompt['system']}
|
| 710 |
+
|
| 711 |
+
{custom_prompt['user']}
|
| 712 |
+
|
| 713 |
+
MANDATORY: Use [chunk_1], [chunk_2] etc. for all facts. [/INST]"""
|
| 714 |
+
elif "distilbart" in self.model_name.lower():
|
| 715 |
+
# For BART, use the user prompt directly (it already contains context)
|
| 716 |
+
prompt = custom_prompt['user']
|
| 717 |
+
else:
|
| 718 |
+
# Default format
|
| 719 |
+
prompt = f"""{custom_prompt['system']}
|
| 720 |
+
|
| 721 |
+
{custom_prompt['user']}
|
| 722 |
+
|
| 723 |
+
MANDATORY: Use [chunk_1], [chunk_2] etc. for all factual statements.
|
| 724 |
+
|
| 725 |
+
Answer:"""
|
| 726 |
+
|
| 727 |
+
# Generate response
|
| 728 |
+
try:
|
| 729 |
+
answer_with_citations = self._call_api(prompt)
|
| 730 |
+
|
| 731 |
+
# Extract and clean citations
|
| 732 |
+
clean_answer, citations = self._extract_citations(answer_with_citations, chunks)
|
| 733 |
+
|
| 734 |
+
# Calculate confidence
|
| 735 |
+
confidence = self._calculate_confidence(clean_answer, citations, chunks)
|
| 736 |
+
|
| 737 |
+
# Calculate generation time
|
| 738 |
+
generation_time = (datetime.now() - start_time).total_seconds()
|
| 739 |
+
|
| 740 |
+
return GeneratedAnswer(
|
| 741 |
+
answer=clean_answer,
|
| 742 |
+
citations=citations,
|
| 743 |
+
confidence_score=confidence,
|
| 744 |
+
generation_time=generation_time,
|
| 745 |
+
model_used=self.model_name,
|
| 746 |
+
context_used=chunks
|
| 747 |
+
)
|
| 748 |
+
|
| 749 |
+
except Exception as e:
|
| 750 |
+
logger.error(f"Error generating answer with custom prompt: {e}")
|
| 751 |
+
return GeneratedAnswer(
|
| 752 |
+
answer="I apologize, but I encountered an error while generating the answer. Please try again.",
|
| 753 |
+
citations=[],
|
| 754 |
+
confidence_score=0.0,
|
| 755 |
+
generation_time=0.0,
|
| 756 |
+
model_used=self.model_name,
|
| 757 |
+
context_used=chunks
|
| 758 |
+
)
|
| 759 |
+
|
| 760 |
+
def format_answer_with_citations(self, generated_answer: GeneratedAnswer) -> str:
|
| 761 |
+
"""
|
| 762 |
+
Format the generated answer with citations for display.
|
| 763 |
+
|
| 764 |
+
Args:
|
| 765 |
+
generated_answer: GeneratedAnswer object
|
| 766 |
+
|
| 767 |
+
Returns:
|
| 768 |
+
Formatted string with answer and citations
|
| 769 |
+
"""
|
| 770 |
+
formatted = f"{generated_answer.answer}\n\n"
|
| 771 |
+
|
| 772 |
+
if generated_answer.citations:
|
| 773 |
+
formatted += "**Sources:**\n"
|
| 774 |
+
for i, citation in enumerate(generated_answer.citations, 1):
|
| 775 |
+
formatted += f"{i}. {citation.source_file} (Page {citation.page_number})\n"
|
| 776 |
+
|
| 777 |
+
formatted += f"\n*Confidence: {generated_answer.confidence_score:.1%} | "
|
| 778 |
+
formatted += f"Model: {generated_answer.model_used} | "
|
| 779 |
+
formatted += f"Time: {generated_answer.generation_time:.2f}s*"
|
| 780 |
+
|
| 781 |
+
return formatted
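An editorial example of the rendered string, with hypothetical values, shows the layout produced above.

# Example rendering (illustrative values):
#
# RISC-V defines 32 integer registers, as described in the riscv card documentation.
#
# **Sources:**
# 1. riscv-card.pdf (Page 7)
#
# *Confidence: 90.0% | Model: sshleifer/distilbart-cnn-12-6 | Time: 1.42s*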
|
| 782 |
+
|
| 783 |
+
def _documents_to_hf_chunks(self, documents: List[Document]) -> List[Dict[str, Any]]:
|
| 784 |
+
"""
|
| 785 |
+
Convert Document objects to HuggingFace's internal chunk format.
|
| 786 |
+
|
| 787 |
+
This internal adapter ensures that Document objects are properly formatted
|
| 788 |
+
for HuggingFace's processing pipeline while keeping the format requirements
|
| 789 |
+
encapsulated within this class.
|
| 790 |
+
|
| 791 |
+
Args:
|
| 792 |
+
documents: List of Document objects from the standard interface
|
| 793 |
+
|
| 794 |
+
Returns:
|
| 795 |
+
List of chunk dictionaries in HuggingFace's expected format
|
| 796 |
+
"""
|
| 797 |
+
if not documents:
|
| 798 |
+
return []
|
| 799 |
+
|
| 800 |
+
chunks = []
|
| 801 |
+
for i, doc in enumerate(documents):
|
| 802 |
+
chunk = {
|
| 803 |
+
"id": f"chunk_{i+1}",
|
| 804 |
+
"content": doc.content, # HuggingFace expects "content" field
|
| 805 |
+
"text": doc.content, # Alternative field for compatibility
|
| 806 |
+
"score": 1.0, # Default relevance score
|
| 807 |
+
"metadata": {
|
| 808 |
+
"page_number": doc.metadata.get("start_page", 1),
|
| 809 |
+
"source": doc.metadata.get("source", "unknown"),
|
| 810 |
+
**doc.metadata # Include all original metadata
|
| 811 |
+
}
|
| 812 |
+
}
|
| 813 |
+
chunks.append(chunk)
|
| 814 |
+
|
| 815 |
+
return chunks
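To make the adapter output concrete, the chunk produced for a single hypothetical Document looks like this; the values are illustrative, the structure follows the code above.

# Shape of one produced chunk (illustrative values):
# {
#     "id": "chunk_1",
#     "content": "RISC-V is an open-source instruction set architecture.",
#     "text": "RISC-V is an open-source instruction set architecture.",
#     "score": 1.0,
#     "metadata": {
#         "page_number": 1,            # taken from metadata["start_page"], default 1
#         "source": "riscv-spec.pdf",  # taken from metadata["source"], default "unknown"
#         "start_page": 1,             # the original metadata keys are merged in unchanged
#     },
# }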
|
| 816 |
+
|
| 817 |
+
def _hf_result_to_answer(self, hf_result: GeneratedAnswer, original_context: List[Document]) -> Answer:
|
| 818 |
+
"""
|
| 819 |
+
Convert HuggingFace's GeneratedAnswer to the standard Answer format.
|
| 820 |
+
|
| 821 |
+
This internal adapter converts HuggingFace's result format back to the
|
| 822 |
+
standard interface format expected by the rest of the system.
|
| 823 |
+
|
| 824 |
+
Args:
|
| 825 |
+
hf_result: Result from HuggingFace's internal generation
|
| 826 |
+
original_context: Original Document objects for sources
|
| 827 |
+
|
| 828 |
+
Returns:
|
| 829 |
+
Answer object conforming to standard interface
|
| 830 |
+
"""
|
| 831 |
+
if Answer is None:
|
| 832 |
+
# Fallback if standard interface not available
|
| 833 |
+
return hf_result
|
| 834 |
+
|
| 835 |
+
# Convert to standard Answer format
|
| 836 |
+
return Answer(
|
| 837 |
+
text=hf_result.answer,
|
| 838 |
+
sources=original_context, # Use original Document objects
|
| 839 |
+
confidence=hf_result.confidence_score,
|
| 840 |
+
metadata={
|
| 841 |
+
"model_used": hf_result.model_used,
|
| 842 |
+
"generation_time": hf_result.generation_time,
|
| 843 |
+
"citations": [
|
| 844 |
+
{
|
| 845 |
+
"chunk_id": cit.chunk_id,
|
| 846 |
+
"page_number": cit.page_number,
|
| 847 |
+
"source_file": cit.source_file,
|
| 848 |
+
"relevance_score": cit.relevance_score,
|
| 849 |
+
"text_snippet": cit.text_snippet
|
| 850 |
+
}
|
| 851 |
+
for cit in hf_result.citations
|
| 852 |
+
],
|
| 853 |
+
"provider": "huggingface",
|
| 854 |
+
"api_token_used": bool(self.api_token),
|
| 855 |
+
"fallback_used": hasattr(self, '_auth_failed') and self._auth_failed
|
| 856 |
+
}
|
| 857 |
+
)
|
| 858 |
+
|
| 859 |
+
|
| 860 |
+
if __name__ == "__main__":
    # Example usage
    generator = HuggingFaceAnswerGenerator()

    # Example chunks (would come from retrieval system)
    example_chunks = [
        {
            "id": "chunk_1",
            "content": "RISC-V is an open-source instruction set architecture (ISA) based on reduced instruction set computer (RISC) principles.",
            "metadata": {"page_number": 1, "source": "riscv-spec.pdf"},
            "score": 0.95
        }
    ]

    # Generate answer via the chunk-based internal path, which accepts chunk dicts
    # and returns a GeneratedAnswer as expected by format_answer_with_citations()
    result = generator._generate_internal(
        query="What is RISC-V?",
        chunks=example_chunks
    )

    # Display formatted result
    print(generator.format_answer_with_citations(result))
shared_utils/generation/inference_providers_generator.py
ADDED
|
@@ -0,0 +1,537 @@
#!/usr/bin/env python3
"""
HuggingFace Inference Providers API-based answer generation.

This module provides answer generation using HuggingFace's new Inference Providers API,
which offers OpenAI-compatible chat completion format for better reliability and consistency.
"""

import os
import sys
import logging
import time
from datetime import datetime
from typing import List, Dict, Any, Optional, Tuple
from pathlib import Path
import re

# Import shared components
from .hf_answer_generator import Citation, GeneratedAnswer
from .prompt_templates import TechnicalPromptTemplates

# Check if huggingface_hub is new enough for InferenceClient chat completion
try:
    from huggingface_hub import InferenceClient
    from huggingface_hub import __version__ as hf_hub_version
    print(f"🔍 Using huggingface_hub version: {hf_hub_version}", file=sys.stderr, flush=True)
except ImportError:
    print("❌ huggingface_hub not found or outdated. Please install: pip install -U huggingface-hub", file=sys.stderr, flush=True)
    raise

logger = logging.getLogger(__name__)


class InferenceProvidersGenerator:
    """
    Generates answers using HuggingFace Inference Providers API.

    This uses the new OpenAI-compatible chat completion format for better reliability
    compared to the classic Inference API. It provides:
    - Consistent response format across models
    - Better error handling and retry logic
    - Support for streaming responses
    - Automatic provider selection and failover
    """

    # Models that work well with chat completion format
    CHAT_MODELS = [
        "microsoft/DialoGPT-medium",          # Proven conversational model
        "google/gemma-2-2b-it",               # Instruction-tuned, good for Q&A
        "meta-llama/Llama-3.2-3B-Instruct",   # If available with token
        "Qwen/Qwen2.5-1.5B-Instruct",         # Small, fast, good quality
    ]

    # Fallback to classic API models if chat completion fails
    CLASSIC_FALLBACK_MODELS = [
        "google/flan-t5-small",               # Good for instructions
        "deepset/roberta-base-squad2",        # Q&A specific
        "facebook/bart-base",                 # Summarization
    ]

|
| 62 |
+
self,
|
| 63 |
+
model_name: Optional[str] = None,
|
| 64 |
+
api_token: Optional[str] = None,
|
| 65 |
+
temperature: float = 0.3,
|
| 66 |
+
max_tokens: int = 512,
|
| 67 |
+
timeout: int = 30
|
| 68 |
+
):
|
| 69 |
+
"""
|
| 70 |
+
Initialize the Inference Providers answer generator.
|
| 71 |
+
|
| 72 |
+
Args:
|
| 73 |
+
model_name: Model to use (defaults to first available chat model)
|
| 74 |
+
api_token: HF API token (uses env vars if not provided)
|
| 75 |
+
temperature: Generation temperature (0.0-1.0)
|
| 76 |
+
max_tokens: Maximum tokens to generate
|
| 77 |
+
timeout: Request timeout in seconds
|
| 78 |
+
"""
|
| 79 |
+
# Get API token from various sources
|
| 80 |
+
self.api_token = (
|
| 81 |
+
api_token or
|
| 82 |
+
os.getenv("HUGGINGFACE_API_TOKEN") or
|
| 83 |
+
os.getenv("HF_TOKEN") or
|
| 84 |
+
os.getenv("HF_API_TOKEN")
|
| 85 |
+
)
|
| 86 |
+
|
| 87 |
+
if not self.api_token:
|
| 88 |
+
print("⚠️ No HF API token found. Inference Providers requires authentication.", file=sys.stderr, flush=True)
|
| 89 |
+
print("Set HF_TOKEN, HUGGINGFACE_API_TOKEN, or HF_API_TOKEN environment variable.", file=sys.stderr, flush=True)
|
| 90 |
+
raise ValueError("HuggingFace API token required for Inference Providers")
|
| 91 |
+
|
| 92 |
+
print(f"✅ Found HF token (starts with: {self.api_token[:8]}...)", file=sys.stderr, flush=True)
|
| 93 |
+
|
| 94 |
+
# Initialize client with token
|
| 95 |
+
self.client = InferenceClient(token=self.api_token)
|
| 96 |
+
self.temperature = temperature
|
| 97 |
+
self.max_tokens = max_tokens
|
| 98 |
+
self.timeout = timeout
|
| 99 |
+
|
| 100 |
+
# Select model
|
| 101 |
+
self.model_name = model_name or self.CHAT_MODELS[0]
|
| 102 |
+
self.using_chat_completion = True
|
| 103 |
+
|
| 104 |
+
print(f"🚀 Initialized Inference Providers with model: {self.model_name}", file=sys.stderr, flush=True)
|
| 105 |
+
|
| 106 |
+
# Test the connection
|
| 107 |
+
self._test_connection()
|
| 108 |
+
|
| 109 |
+
def _test_connection(self):
|
| 110 |
+
"""Test if the API is accessible and model is available."""
|
| 111 |
+
print(f"🔧 Testing Inference Providers API connection...", file=sys.stderr, flush=True)
|
| 112 |
+
|
| 113 |
+
try:
|
| 114 |
+
# Try a simple test query
|
| 115 |
+
test_messages = [
|
| 116 |
+
{"role": "user", "content": "Hello"}
|
| 117 |
+
]
|
| 118 |
+
|
| 119 |
+
# First try chat completion (preferred)
|
| 120 |
+
try:
|
| 121 |
+
response = self.client.chat_completion(
|
| 122 |
+
messages=test_messages,
|
| 123 |
+
model=self.model_name,
|
| 124 |
+
max_tokens=10,
|
| 125 |
+
temperature=0.1
|
| 126 |
+
)
|
| 127 |
+
print(f"✅ Chat completion API working with {self.model_name}", file=sys.stderr, flush=True)
|
| 128 |
+
self.using_chat_completion = True
|
| 129 |
+
return
|
| 130 |
+
except Exception as e:
|
| 131 |
+
                print(f"⚠️ Chat completion failed for {self.model_name}: {e}", file=sys.stderr, flush=True)
|
| 132 |
+
|
| 133 |
+
# Try other chat models
|
| 134 |
+
for model in self.CHAT_MODELS:
|
| 135 |
+
if model != self.model_name:
|
| 136 |
+
try:
|
| 137 |
+
print(f"🔄 Trying {model}...", file=sys.stderr, flush=True)
|
| 138 |
+
response = self.client.chat_completion(
|
| 139 |
+
messages=test_messages,
|
| 140 |
+
model=model,
|
| 141 |
+
max_tokens=10
|
| 142 |
+
)
|
| 143 |
+
print(f"✅ Found working model: {model}", file=sys.stderr, flush=True)
|
| 144 |
+
self.model_name = model
|
| 145 |
+
self.using_chat_completion = True
|
| 146 |
+
return
|
| 147 |
+
except:
|
| 148 |
+
continue
|
| 149 |
+
|
| 150 |
+
# If chat completion fails, test classic text generation
|
| 151 |
+
print("🔄 Falling back to classic text generation API...", file=sys.stderr, flush=True)
|
| 152 |
+
for model in self.CLASSIC_FALLBACK_MODELS:
|
| 153 |
+
try:
|
| 154 |
+
response = self.client.text_generation(
|
| 155 |
+
model=model,
|
| 156 |
+
prompt="Hello",
|
| 157 |
+
max_new_tokens=10
|
| 158 |
+
)
|
| 159 |
+
print(f"✅ Classic API working with fallback model: {model}", file=sys.stderr, flush=True)
|
| 160 |
+
self.model_name = model
|
| 161 |
+
self.using_chat_completion = False
|
| 162 |
+
return
|
| 163 |
+
except:
|
| 164 |
+
continue
|
| 165 |
+
|
| 166 |
+
raise Exception("No working models found in Inference Providers API")
|
| 167 |
+
|
| 168 |
+
except Exception as e:
|
| 169 |
+
print(f"❌ Inference Providers API test failed: {e}", file=sys.stderr, flush=True)
|
| 170 |
+
raise
|
| 171 |
+
|
| 172 |
+
def _format_context(self, chunks: List[Dict[str, Any]]) -> str:
|
| 173 |
+
"""Format retrieved chunks into context string."""
|
| 174 |
+
context_parts = []
|
| 175 |
+
|
| 176 |
+
for i, chunk in enumerate(chunks):
|
| 177 |
+
chunk_text = chunk.get('content', chunk.get('text', ''))
|
| 178 |
+
page_num = chunk.get('metadata', {}).get('page_number', 'unknown')
|
| 179 |
+
source = chunk.get('metadata', {}).get('source', 'unknown')
|
| 180 |
+
|
| 181 |
+
context_parts.append(
|
| 182 |
+
f"[chunk_{i+1}] (Page {page_num} from {source}):\n{chunk_text}\n"
|
| 183 |
+
)
|
| 184 |
+
|
| 185 |
+
return "\n---\n".join(context_parts)
|
| 186 |
+
|
| 187 |
+
def _create_messages(self, query: str, context: str) -> List[Dict[str, str]]:
|
| 188 |
+
"""Create chat messages using TechnicalPromptTemplates."""
|
| 189 |
+
# Get appropriate template based on query type
|
| 190 |
+
prompt_data = TechnicalPromptTemplates.format_prompt_with_template(
|
| 191 |
+
query=query,
|
| 192 |
+
context=context
|
| 193 |
+
)
|
| 194 |
+
|
| 195 |
+
# Create messages for chat completion
|
| 196 |
+
messages = [
|
| 197 |
+
{
|
| 198 |
+
"role": "system",
|
| 199 |
+
"content": prompt_data['system'] + "\n\nMANDATORY: Use [chunk_X] citations for all facts."
|
| 200 |
+
},
|
| 201 |
+
{
|
| 202 |
+
"role": "user",
|
| 203 |
+
"content": prompt_data['user']
|
| 204 |
+
}
|
| 205 |
+
]
|
| 206 |
+
|
| 207 |
+
return messages
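For reference, the message list handed to the chat completion endpoint has the shape below; the placeholder wording stands in for whatever TechnicalPromptTemplates produces, which is defined elsewhere.

# Resulting structure (placeholders, since the exact text comes from TechnicalPromptTemplates):
# [
#     {"role": "system",
#      "content": "<template system prompt>\n\nMANDATORY: Use [chunk_X] citations for all facts."},
#     {"role": "user",
#      "content": "<template user prompt embedding the formatted context and the query>"},
# ]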
|
| 208 |
+
|
| 209 |
+
def _call_chat_completion(self, messages: List[Dict[str, str]]) -> str:
|
| 210 |
+
"""Call the chat completion API."""
|
| 211 |
+
try:
|
| 212 |
+
print(f"🤖 Calling Inference Providers chat completion with {self.model_name}...", file=sys.stderr, flush=True)
|
| 213 |
+
|
| 214 |
+
# Use chat completion with proper error handling
|
| 215 |
+
response = self.client.chat_completion(
|
| 216 |
+
messages=messages,
|
| 217 |
+
model=self.model_name,
|
| 218 |
+
temperature=self.temperature,
|
| 219 |
+
max_tokens=self.max_tokens,
|
| 220 |
+
stream=False
|
| 221 |
+
)
|
| 222 |
+
|
| 223 |
+
# Extract content from response
|
| 224 |
+
if hasattr(response, 'choices') and response.choices:
|
| 225 |
+
content = response.choices[0].message.content
|
| 226 |
+
print(f"✅ Got response: {len(content)} characters", file=sys.stderr, flush=True)
|
| 227 |
+
return content
|
| 228 |
+
else:
|
| 229 |
+
print(f"⚠️ Unexpected response format: {response}", file=sys.stderr, flush=True)
|
| 230 |
+
return str(response)
|
| 231 |
+
|
| 232 |
+
except Exception as e:
|
| 233 |
+
print(f"❌ Chat completion error: {e}", file=sys.stderr, flush=True)
|
| 234 |
+
|
| 235 |
+
# Try with a fallback model
|
| 236 |
+
if self.model_name != "microsoft/DialoGPT-medium":
|
| 237 |
+
print("🔄 Trying fallback model: microsoft/DialoGPT-medium", file=sys.stderr, flush=True)
|
| 238 |
+
try:
|
| 239 |
+
response = self.client.chat_completion(
|
| 240 |
+
messages=messages,
|
| 241 |
+
model="microsoft/DialoGPT-medium",
|
| 242 |
+
temperature=self.temperature,
|
| 243 |
+
max_tokens=self.max_tokens
|
| 244 |
+
)
|
| 245 |
+
if hasattr(response, 'choices') and response.choices:
|
| 246 |
+
return response.choices[0].message.content
|
| 247 |
+
except Exception:
|
| 248 |
+
pass
|
| 249 |
+
|
| 250 |
+
raise Exception(f"Chat completion failed: {e}")
|
| 251 |
+
|
| 252 |
+
def _call_classic_api(self, query: str, context: str) -> str:
|
| 253 |
+
"""Fallback to classic text generation API."""
|
| 254 |
+
print(f"🔄 Using classic text generation with {self.model_name}...", file=sys.stderr, flush=True)
|
| 255 |
+
|
| 256 |
+
# Format prompt for classic API
|
| 257 |
+
if "squad" in self.model_name.lower():
|
| 258 |
+
# Q&A format for squad models
|
| 259 |
+
prompt = f"Context: {context}\n\nQuestion: {query}\n\nAnswer:"
|
| 260 |
+
elif "flan" in self.model_name.lower():
|
| 261 |
+
# Instruction format for Flan models
|
| 262 |
+
prompt = f"Answer the question based on the context.\n\nContext: {context}\n\nQuestion: {query}\n\nAnswer:"
|
| 263 |
+
else:
|
| 264 |
+
# Generic format
|
| 265 |
+
prompt = f"Based on the following context, answer the question.\n\nContext:\n{context}\n\nQuestion: {query}\n\nAnswer:"
|
| 266 |
+
|
| 267 |
+
try:
|
| 268 |
+
response = self.client.text_generation(
|
| 269 |
+
model=self.model_name,
|
| 270 |
+
prompt=prompt,
|
| 271 |
+
max_new_tokens=self.max_tokens,
|
| 272 |
+
temperature=self.temperature
|
| 273 |
+
)
|
| 274 |
+
return response
|
| 275 |
+
except Exception as e:
|
| 276 |
+
print(f"❌ Classic API error: {e}", file=sys.stderr, flush=True)
|
| 277 |
+
return f"Error generating response: {str(e)}"
|
| 278 |
+
|
| 279 |
+
def _extract_citations(self, answer: str, chunks: List[Dict[str, Any]]) -> Tuple[str, List[Citation]]:
|
| 280 |
+
"""Extract citations from the answer."""
|
| 281 |
+
citations = []
|
| 282 |
+
citation_pattern = r'\[chunk_(\d+)\]'
|
| 283 |
+
|
| 284 |
+
cited_chunks = set()
|
| 285 |
+
|
| 286 |
+
# Find explicit citations
|
| 287 |
+
matches = re.finditer(citation_pattern, answer)
|
| 288 |
+
for match in matches:
|
| 289 |
+
chunk_idx = int(match.group(1)) - 1
|
| 290 |
+
if 0 <= chunk_idx < len(chunks):
|
| 291 |
+
cited_chunks.add(chunk_idx)
|
| 292 |
+
|
| 293 |
+
# Fallback: Create citations for top chunks if none found
|
| 294 |
+
if not cited_chunks and chunks and len(answer.strip()) > 50:
|
| 295 |
+
num_fallback = min(3, len(chunks))
|
| 296 |
+
cited_chunks = set(range(num_fallback))
|
| 297 |
+
print(f"🔧 Creating {num_fallback} fallback citations", file=sys.stderr, flush=True)
|
| 298 |
+
|
| 299 |
+
# Create Citation objects
|
| 300 |
+
chunk_to_source = {}
|
| 301 |
+
for idx in cited_chunks:
|
| 302 |
+
chunk = chunks[idx]
|
| 303 |
+
citation = Citation(
|
| 304 |
+
chunk_id=chunk.get('id', f'chunk_{idx}'),
|
| 305 |
+
page_number=chunk.get('metadata', {}).get('page_number', 0),
|
| 306 |
+
source_file=chunk.get('metadata', {}).get('source', 'unknown'),
|
| 307 |
+
relevance_score=chunk.get('score', 0.0),
|
| 308 |
+
text_snippet=chunk.get('content', chunk.get('text', ''))[:200] + '...'
|
| 309 |
+
)
|
| 310 |
+
citations.append(citation)
|
| 311 |
+
|
| 312 |
+
# Map for natural language replacement
|
| 313 |
+
source_name = chunk.get('metadata', {}).get('source', 'unknown')
|
| 314 |
+
if source_name != 'unknown':
|
| 315 |
+
natural_name = Path(source_name).stem.replace('-', ' ').replace('_', ' ')
|
| 316 |
+
chunk_to_source[f'[chunk_{idx+1}]'] = f"the {natural_name} documentation"
|
| 317 |
+
else:
|
| 318 |
+
chunk_to_source[f'[chunk_{idx+1}]'] = "the documentation"
|
| 319 |
+
|
| 320 |
+
# Replace citations with natural language
|
| 321 |
+
natural_answer = answer
|
| 322 |
+
for chunk_ref, natural_ref in chunk_to_source.items():
|
| 323 |
+
natural_answer = natural_answer.replace(chunk_ref, natural_ref)
|
| 324 |
+
|
| 325 |
+
# Clean up any remaining citations
|
| 326 |
+
natural_answer = re.sub(r'\[chunk_\d+\]', 'the documentation', natural_answer)
|
| 327 |
+
natural_answer = re.sub(r'\s+', ' ', natural_answer).strip()
|
| 328 |
+
|
| 329 |
+
return natural_answer, citations
|
| 330 |
+
|
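As a standalone illustration of the rewrite performed above (the sample answer text and source name are hypothetical), the regex pass and the natural-language substitution behave roughly like this:

import re
from pathlib import Path

answer = "RISC-V defines 32 integer registers [chunk_1] and optional extensions [chunk_2]."
source = "riscv-spec.pdf"  # hypothetical metadata['source'] shared by both chunks

cited = {int(m.group(1)) for m in re.finditer(r"\[chunk_(\d+)\]", answer)}  # -> {1, 2}
natural_name = Path(source).stem.replace("-", " ").replace("_", " ")        # -> "riscv spec"
for idx in cited:
    answer = answer.replace(f"[chunk_{idx}]", f"the {natural_name} documentation")
# answer now reads "...registers the riscv spec documentation and optional extensions the riscv spec documentation."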
| 331 |
+
def _calculate_confidence(self, answer: str, citations: List[Citation], chunks: List[Dict[str, Any]]) -> float:
|
| 332 |
+
"""Calculate confidence score for the answer."""
|
| 333 |
+
if not answer or len(answer.strip()) < 10:
|
| 334 |
+
return 0.1
|
| 335 |
+
|
| 336 |
+
# Base confidence from chunk quality
|
| 337 |
+
if len(chunks) >= 3:
|
| 338 |
+
confidence = 0.8
|
| 339 |
+
elif len(chunks) >= 2:
|
| 340 |
+
confidence = 0.7
|
| 341 |
+
else:
|
| 342 |
+
confidence = 0.6
|
| 343 |
+
|
| 344 |
+
# Citation bonus
|
| 345 |
+
if citations and chunks:
|
| 346 |
+
citation_ratio = len(citations) / min(len(chunks), 3)
|
| 347 |
+
confidence += 0.15 * citation_ratio
|
| 348 |
+
|
| 349 |
+
# Check for uncertainty phrases
|
| 350 |
+
uncertainty_phrases = [
|
| 351 |
+
"insufficient information",
|
| 352 |
+
"cannot determine",
|
| 353 |
+
"not available in the provided documents",
|
| 354 |
+
"i don't know",
|
| 355 |
+
"unclear"
|
| 356 |
+
]
|
| 357 |
+
|
| 358 |
+
if any(phrase in answer.lower() for phrase in uncertainty_phrases):
|
| 359 |
+
confidence *= 0.3
|
| 360 |
+
|
| 361 |
+
return min(confidence, 0.95)
|
| 362 |
+
|
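A worked example of the scoring above, with hypothetical inputs:

# 3 chunks retrieved                 -> base confidence 0.8
# 2 citations, min(len(chunks), 3)=3 -> bonus 0.15 * (2/3) = 0.10
# no uncertainty phrases             -> no penalty
# result: min(0.8 + 0.10, 0.95) = 0.90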
| 363 |
+
def generate(self, query: str, chunks: List[Dict[str, Any]]) -> GeneratedAnswer:
|
| 364 |
+
"""
|
| 365 |
+
Generate an answer using Inference Providers API.
|
| 366 |
+
|
| 367 |
+
Args:
|
| 368 |
+
query: User's question
|
| 369 |
+
chunks: Retrieved document chunks
|
| 370 |
+
|
| 371 |
+
Returns:
|
| 372 |
+
GeneratedAnswer with answer, citations, and metadata
|
| 373 |
+
"""
|
| 374 |
+
start_time = datetime.now()
|
| 375 |
+
|
| 376 |
+
# Check for no-context situation
|
| 377 |
+
if not chunks or all(len(chunk.get('content', chunk.get('text', ''))) < 20 for chunk in chunks):
|
| 378 |
+
return GeneratedAnswer(
|
| 379 |
+
answer="This information isn't available in the provided documents.",
|
| 380 |
+
citations=[],
|
| 381 |
+
confidence_score=0.05,
|
| 382 |
+
generation_time=0.1,
|
| 383 |
+
model_used=self.model_name,
|
| 384 |
+
context_used=chunks
|
| 385 |
+
)
|
| 386 |
+
|
| 387 |
+
# Format context
|
| 388 |
+
context = self._format_context(chunks)
|
| 389 |
+
|
| 390 |
+
# Generate answer
|
| 391 |
+
try:
|
| 392 |
+
if self.using_chat_completion:
|
| 393 |
+
# Create chat messages
|
| 394 |
+
messages = self._create_messages(query, context)
|
| 395 |
+
|
| 396 |
+
# Call chat completion API
|
| 397 |
+
answer_text = self._call_chat_completion(messages)
|
| 398 |
+
else:
|
| 399 |
+
# Fallback to classic API
|
| 400 |
+
answer_text = self._call_classic_api(query, context)
|
| 401 |
+
|
| 402 |
+
# Extract citations and clean answer
|
| 403 |
+
natural_answer, citations = self._extract_citations(answer_text, chunks)
|
| 404 |
+
|
| 405 |
+
# Calculate confidence
|
| 406 |
+
confidence = self._calculate_confidence(natural_answer, citations, chunks)
|
| 407 |
+
|
| 408 |
+
generation_time = (datetime.now() - start_time).total_seconds()
|
| 409 |
+
|
| 410 |
+
return GeneratedAnswer(
|
| 411 |
+
answer=natural_answer,
|
| 412 |
+
citations=citations,
|
| 413 |
+
confidence_score=confidence,
|
| 414 |
+
generation_time=generation_time,
|
| 415 |
+
model_used=self.model_name,
|
| 416 |
+
context_used=chunks
|
| 417 |
+
)
|
| 418 |
+
|
| 419 |
+
except Exception as e:
|
| 420 |
+
logger.error(f"Error generating answer: {e}")
|
| 421 |
+
print(f"❌ Generation failed: {e}", file=sys.stderr, flush=True)
|
| 422 |
+
|
| 423 |
+
# Return error response
|
| 424 |
+
return GeneratedAnswer(
|
| 425 |
+
answer="I apologize, but I encountered an error while generating the answer. Please try again.",
|
| 426 |
+
citations=[],
|
| 427 |
+
confidence_score=0.0,
|
| 428 |
+
generation_time=(datetime.now() - start_time).total_seconds(),
|
| 429 |
+
model_used=self.model_name,
|
| 430 |
+
context_used=chunks
|
| 431 |
+
)
|
| 432 |
+
|
| 433 |
+
def generate_with_custom_prompt(
|
| 434 |
+
self,
|
| 435 |
+
query: str,
|
| 436 |
+
chunks: List[Dict[str, Any]],
|
| 437 |
+
custom_prompt: Dict[str, str]
|
| 438 |
+
) -> GeneratedAnswer:
|
| 439 |
+
"""
|
| 440 |
+
Generate answer using a custom prompt (for adaptive prompting).
|
| 441 |
+
|
| 442 |
+
Args:
|
| 443 |
+
query: User's question
|
| 444 |
+
chunks: Retrieved context chunks
|
| 445 |
+
custom_prompt: Dict with 'system' and 'user' prompts
|
| 446 |
+
|
| 447 |
+
Returns:
|
| 448 |
+
GeneratedAnswer with custom prompt enhancement
|
| 449 |
+
"""
|
| 450 |
+
start_time = datetime.now()
|
| 451 |
+
|
| 452 |
+
if not chunks:
|
| 453 |
+
return GeneratedAnswer(
|
| 454 |
+
answer="I don't have enough context to answer your question.",
|
| 455 |
+
citations=[],
|
| 456 |
+
confidence_score=0.0,
|
| 457 |
+
generation_time=0.1,
|
| 458 |
+
model_used=self.model_name,
|
| 459 |
+
context_used=chunks
|
| 460 |
+
)
|
| 461 |
+
|
| 462 |
+
try:
|
| 463 |
+
# Try chat completion with custom prompt
|
| 464 |
+
messages = [
|
| 465 |
+
{"role": "system", "content": custom_prompt['system']},
|
| 466 |
+
{"role": "user", "content": custom_prompt['user']}
|
| 467 |
+
]
|
| 468 |
+
|
| 469 |
+
answer_text = self._call_chat_completion(messages)
|
| 470 |
+
|
| 471 |
+
# Extract citations and clean answer
|
| 472 |
+
natural_answer, citations = self._extract_citations(answer_text, chunks)
|
| 473 |
+
|
| 474 |
+
# Calculate confidence
|
| 475 |
+
confidence = self._calculate_confidence(natural_answer, citations, chunks)
|
| 476 |
+
|
| 477 |
+
generation_time = (datetime.now() - start_time).total_seconds()
|
| 478 |
+
|
| 479 |
+
return GeneratedAnswer(
|
| 480 |
+
answer=natural_answer,
|
| 481 |
+
citations=citations,
|
| 482 |
+
confidence_score=confidence,
|
| 483 |
+
generation_time=generation_time,
|
| 484 |
+
model_used=self.model_name,
|
| 485 |
+
context_used=chunks
|
| 486 |
+
)
|
| 487 |
+
|
| 488 |
+
except Exception as e:
|
| 489 |
+
logger.error(f"Error generating answer with custom prompt: {e}")
|
| 490 |
+
print(f"❌ Custom prompt generation failed: {e}", file=sys.stderr, flush=True)
|
| 491 |
+
|
| 492 |
+
# Return error response
|
| 493 |
+
return GeneratedAnswer(
|
| 494 |
+
answer="I apologize, but I encountered an error while generating the answer. Please try again.",
|
| 495 |
+
citations=[],
|
| 496 |
+
confidence_score=0.0,
|
| 497 |
+
generation_time=(datetime.now() - start_time).total_seconds(),
|
| 498 |
+
model_used=self.model_name,
|
| 499 |
+
context_used=chunks
|
| 500 |
+
)
|
| 501 |
+
|
| 502 |
+
|
| 503 |
+
# Example usage
|
| 504 |
+
if __name__ == "__main__":
|
| 505 |
+
# Test the generator
|
| 506 |
+
print("Testing Inference Providers Generator...")
|
| 507 |
+
|
| 508 |
+
try:
|
| 509 |
+
generator = InferenceProvidersGenerator()
|
| 510 |
+
|
| 511 |
+
# Test chunks
|
| 512 |
+
test_chunks = [
|
| 513 |
+
{
|
| 514 |
+
"content": "RISC-V is an open-source instruction set architecture (ISA) based on established reduced instruction set computer (RISC) principles.",
|
| 515 |
+
"metadata": {"page_number": 1, "source": "riscv-spec.pdf"},
|
| 516 |
+
"score": 0.95
|
| 517 |
+
},
|
| 518 |
+
{
|
| 519 |
+
"content": "Unlike most other ISA designs, RISC-V is provided under open source licenses that do not require fees to use.",
|
| 520 |
+
"metadata": {"page_number": 2, "source": "riscv-spec.pdf"},
|
| 521 |
+
"score": 0.85
|
| 522 |
+
}
|
| 523 |
+
]
|
| 524 |
+
|
| 525 |
+
# Generate answer
|
| 526 |
+
result = generator.generate("What is RISC-V and why is it important?", test_chunks)
|
| 527 |
+
|
| 528 |
+
print(f"\n📝 Answer: {result.answer}")
|
| 529 |
+
print(f"📊 Confidence: {result.confidence_score:.1%}")
|
| 530 |
+
print(f"⏱️ Generation time: {result.generation_time:.2f}s")
|
| 531 |
+
print(f"🤖 Model: {result.model_used}")
|
| 532 |
+
print(f"📚 Citations: {len(result.citations)}")
|
| 533 |
+
|
| 534 |
+
except Exception as e:
|
| 535 |
+
print(f"❌ Test failed: {e}")
|
| 536 |
+
import traceback
|
| 537 |
+
traceback.print_exc()
|
shared_utils/generation/ollama_answer_generator.py
ADDED
|
@@ -0,0 +1,834 @@
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Ollama-based answer generator for local inference.
|
| 4 |
+
|
| 5 |
+
Provides the same interface as HuggingFaceAnswerGenerator but uses
|
| 6 |
+
local Ollama server for model inference.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import time
|
| 10 |
+
import requests
|
| 11 |
+
import json
|
| 12 |
+
import re
|
| 13 |
+
import sys
|
| 14 |
+
from datetime import datetime
|
| 15 |
+
from pathlib import Path
|
| 16 |
+
from typing import Dict, List, Optional, Any, Tuple
|
| 17 |
+
from dataclasses import dataclass
|
| 18 |
+
|
| 19 |
+
# Import shared components
|
| 20 |
+
from .hf_answer_generator import Citation, GeneratedAnswer
|
| 21 |
+
from .prompt_templates import TechnicalPromptTemplates
|
| 22 |
+
|
| 23 |
+
# Import standard interfaces (add this for the adapter)
|
| 24 |
+
try:
|
| 25 |
+
from pathlib import Path
|
| 26 |
+
import sys
|
| 27 |
+
project_root = Path(__file__).parent.parent.parent.parent.parent
|
| 28 |
+
sys.path.append(str(project_root))
|
| 29 |
+
from src.core.interfaces import Document, Answer, AnswerGenerator
|
| 30 |
+
except ImportError:
|
| 31 |
+
# Fallback for standalone usage
|
| 32 |
+
Document = None
|
| 33 |
+
Answer = None
|
| 34 |
+
AnswerGenerator = object
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class OllamaAnswerGenerator(AnswerGenerator if AnswerGenerator != object else object):
|
| 38 |
+
"""
|
| 39 |
+
Generates answers using local Ollama server.
|
| 40 |
+
|
| 41 |
+
Perfect for:
|
| 42 |
+
- Local development
|
| 43 |
+
- Privacy-sensitive applications
|
| 44 |
+
- No API rate limits
|
| 45 |
+
- Consistent performance
|
| 46 |
+
- Offline operation
|
| 47 |
+
"""
|
| 48 |
+
|
| 49 |
+
def __init__(
|
| 50 |
+
self,
|
| 51 |
+
model_name: str = "llama3.2:3b",
|
| 52 |
+
base_url: str = "http://localhost:11434",
|
| 53 |
+
temperature: float = 0.3,
|
| 54 |
+
max_tokens: int = 512,
|
| 55 |
+
):
|
| 56 |
+
"""
|
| 57 |
+
Initialize Ollama answer generator.
|
| 58 |
+
|
| 59 |
+
Args:
|
| 60 |
+
model_name: Ollama model to use (e.g., "llama3.2:3b", "mistral")
|
| 61 |
+
base_url: Ollama server URL
|
| 62 |
+
temperature: Generation temperature
|
| 63 |
+
max_tokens: Maximum tokens to generate
|
| 64 |
+
"""
|
| 65 |
+
self.model_name = model_name
|
| 66 |
+
self.base_url = base_url.rstrip("/")
|
| 67 |
+
self.temperature = temperature
|
| 68 |
+
self.max_tokens = max_tokens
|
| 69 |
+
|
| 70 |
+
# Test connection
|
| 71 |
+
self._test_connection()
|
| 72 |
+
|
| 73 |
+
def _test_connection(self):
|
| 74 |
+
"""Test if Ollama server is accessible."""
|
| 75 |
+
# Reduce retries for faster initialization - container should be ready quickly
|
| 76 |
+
max_retries = 12 # Wait up to 60 seconds for Ollama to start
|
| 77 |
+
retry_delay = 5
|
| 78 |
+
|
| 79 |
+
print(
|
| 80 |
+
f"🔧 Testing connection to {self.base_url}/api/tags...",
|
| 81 |
+
file=sys.stderr,
|
| 82 |
+
flush=True,
|
| 83 |
+
)
|
| 84 |
+
|
| 85 |
+
for attempt in range(max_retries):
|
| 86 |
+
try:
|
| 87 |
+
response = requests.get(f"{self.base_url}/api/tags", timeout=8)
|
| 88 |
+
if response.status_code == 200:
|
| 89 |
+
print(
|
| 90 |
+
f"✅ Connected to Ollama at {self.base_url}",
|
| 91 |
+
file=sys.stderr,
|
| 92 |
+
flush=True,
|
| 93 |
+
)
|
| 94 |
+
|
| 95 |
+
# Check if our model is available
|
| 96 |
+
models = response.json().get("models", [])
|
| 97 |
+
model_names = [m["name"] for m in models]
|
| 98 |
+
|
| 99 |
+
if self.model_name in model_names:
|
| 100 |
+
print(
|
| 101 |
+
f"✅ Model {self.model_name} is available",
|
| 102 |
+
file=sys.stderr,
|
| 103 |
+
flush=True,
|
| 104 |
+
)
|
| 105 |
+
return # Success!
|
| 106 |
+
else:
|
| 107 |
+
print(
|
| 108 |
+
f"⚠️ Model {self.model_name} not found. Available: {model_names}",
|
| 109 |
+
file=sys.stderr,
|
| 110 |
+
flush=True,
|
| 111 |
+
)
|
| 112 |
+
if models: # If any models are available, use the first one
|
| 113 |
+
fallback_model = model_names[0]
|
| 114 |
+
print(
|
| 115 |
+
f"🔄 Using fallback model: {fallback_model}",
|
| 116 |
+
file=sys.stderr,
|
| 117 |
+
flush=True,
|
| 118 |
+
)
|
| 119 |
+
self.model_name = fallback_model
|
| 120 |
+
return
|
| 121 |
+
else:
|
| 122 |
+
print(
|
| 123 |
+
f"📥 No models found, will try to pull {self.model_name}",
|
| 124 |
+
file=sys.stderr,
|
| 125 |
+
flush=True,
|
| 126 |
+
)
|
| 127 |
+
# Try to pull the model
|
| 128 |
+
self._pull_model(self.model_name)
|
| 129 |
+
return
|
| 130 |
+
else:
|
| 131 |
+
print(f"⚠️ Ollama server returned status {response.status_code}")
|
| 132 |
+
if attempt < max_retries - 1:
|
| 133 |
+
print(
|
| 134 |
+
f"🔄 Retry {attempt + 1}/{max_retries} in {retry_delay} seconds..."
|
| 135 |
+
)
|
| 136 |
+
time.sleep(retry_delay)
|
| 137 |
+
continue
|
| 138 |
+
|
| 139 |
+
except requests.exceptions.ConnectionError:
|
| 140 |
+
if attempt < max_retries - 1:
|
| 141 |
+
print(
|
| 142 |
+
f"⏳ Ollama not ready yet, retry {attempt + 1}/{max_retries} in {retry_delay} seconds..."
|
| 143 |
+
)
|
| 144 |
+
time.sleep(retry_delay)
|
| 145 |
+
continue
|
| 146 |
+
else:
|
| 147 |
+
raise Exception(
|
| 148 |
+
f"Cannot connect to Ollama server at {self.base_url} after 60 seconds. Check if it's running."
|
| 149 |
+
)
|
| 150 |
+
except requests.exceptions.Timeout:
|
| 151 |
+
if attempt < max_retries - 1:
|
| 152 |
+
print(f"⏳ Ollama timeout, retry {attempt + 1}/{max_retries}...")
|
| 153 |
+
time.sleep(retry_delay)
|
| 154 |
+
continue
|
| 155 |
+
else:
|
| 156 |
+
raise Exception("Ollama server timeout after multiple retries.")
|
| 157 |
+
except Exception as e:
|
| 158 |
+
if attempt < max_retries - 1:
|
| 159 |
+
print(f"⚠️ Ollama error: {e}, retry {attempt + 1}/{max_retries}...")
|
| 160 |
+
time.sleep(retry_delay)
|
| 161 |
+
continue
|
| 162 |
+
else:
|
| 163 |
+
raise Exception(
|
| 164 |
+
f"Ollama connection failed after {max_retries} attempts: {e}"
|
| 165 |
+
)
|
| 166 |
+
|
| 167 |
+
raise Exception("Failed to connect to Ollama after all retries")
|
| 168 |
+
|
| 169 |
+
def _pull_model(self, model_name: str):
|
| 170 |
+
"""Pull a model if it's not available."""
|
| 171 |
+
try:
|
| 172 |
+
print(f"📥 Pulling model {model_name}...")
|
| 173 |
+
pull_response = requests.post(
|
| 174 |
+
f"{self.base_url}/api/pull",
|
| 175 |
+
json={"name": model_name},
|
| 176 |
+
timeout=300, # 5 minutes for model download
|
| 177 |
+
)
|
| 178 |
+
if pull_response.status_code == 200:
|
| 179 |
+
print(f"✅ Successfully pulled {model_name}")
|
| 180 |
+
else:
|
| 181 |
+
print(f"⚠️ Failed to pull {model_name}: {pull_response.status_code}")
|
| 182 |
+
# Try smaller models as fallback
|
| 183 |
+
fallback_models = ["llama3.2:1b", "llama2:latest", "mistral:latest"]
|
| 184 |
+
for fallback in fallback_models:
|
| 185 |
+
try:
|
| 186 |
+
print(f"🔄 Trying fallback model: {fallback}")
|
| 187 |
+
fallback_response = requests.post(
|
| 188 |
+
f"{self.base_url}/api/pull",
|
| 189 |
+
json={"name": fallback},
|
| 190 |
+
timeout=300,
|
| 191 |
+
)
|
| 192 |
+
if fallback_response.status_code == 200:
|
| 193 |
+
print(f"✅ Successfully pulled fallback {fallback}")
|
| 194 |
+
self.model_name = fallback
|
| 195 |
+
return
|
| 196 |
+
except Exception:
|
| 197 |
+
continue
|
| 198 |
+
raise Exception(f"Failed to pull {model_name} or any fallback models")
|
| 199 |
+
except Exception as e:
|
| 200 |
+
print(f"❌ Model pull failed: {e}")
|
| 201 |
+
raise
|
| 202 |
+
|
| 203 |
+
def _format_context(self, chunks: List[Dict[str, Any]]) -> str:
|
| 204 |
+
"""Format retrieved chunks into context."""
|
| 205 |
+
context_parts = []
|
| 206 |
+
|
| 207 |
+
for i, chunk in enumerate(chunks):
|
| 208 |
+
chunk_text = chunk.get("content", chunk.get("text", ""))
|
| 209 |
+
page_num = chunk.get("metadata", {}).get("page_number", "unknown")
|
| 210 |
+
source = chunk.get("metadata", {}).get("source", "unknown")
|
| 211 |
+
|
| 212 |
+
context_parts.append(
|
| 213 |
+
f"[chunk_{i+1}] (Page {page_num} from {source}):\n{chunk_text}\n"
|
| 214 |
+
)
|
| 215 |
+
|
| 216 |
+
return "\n---\n".join(context_parts)
|
| 217 |
+
|
| 218 |
+
def _create_prompt(self, query: str, context: str, chunks: List[Dict[str, Any]]) -> str:
|
| 219 |
+
"""Create optimized prompt with dynamic length constraints and citation instructions."""
|
| 220 |
+
# Get the appropriate template based on query type
|
| 221 |
+
prompt_data = TechnicalPromptTemplates.format_prompt_with_template(
|
| 222 |
+
query=query, context=context
|
| 223 |
+
)
|
| 224 |
+
|
| 225 |
+
# Create dynamic citation instructions based on available chunks
|
| 226 |
+
num_chunks = len(chunks)
|
| 227 |
+
available_chunks = ", ".join([f"[chunk_{i+1}]" for i in range(min(num_chunks, 5))]) # Show max 5 examples
|
| 228 |
+
|
| 229 |
+
# Create appropriate example based on actual chunks
|
| 230 |
+
if num_chunks == 1:
|
| 231 |
+
citation_example = "RISC-V is an open-source ISA [chunk_1]."
|
| 232 |
+
elif num_chunks == 2:
|
| 233 |
+
citation_example = "RISC-V is an open-source ISA [chunk_1] that supports multiple data widths [chunk_2]."
|
| 234 |
+
else:
|
| 235 |
+
citation_example = "RISC-V is an open-source ISA [chunk_1] that supports multiple data widths [chunk_2] and provides extensions [chunk_3]."
|
| 236 |
+
|
| 237 |
+
# Determine optimal answer length based on query complexity
|
| 238 |
+
target_length = self._determine_target_length(query, chunks)
|
| 239 |
+
length_instruction = self._create_length_instruction(target_length)
|
| 240 |
+
|
| 241 |
+
# Format for different model types
|
| 242 |
+
if "llama" in self.model_name.lower():
|
| 243 |
+
# Llama-3.2 format with technical prompt templates
|
| 244 |
+
return f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
|
| 245 |
+
{prompt_data['system']}
|
| 246 |
+
|
| 247 |
+
MANDATORY CITATION RULES:
|
| 248 |
+
- ONLY use available chunks: {available_chunks}
|
| 249 |
+
- You have {num_chunks} chunks available - DO NOT cite chunk numbers higher than {num_chunks}
|
| 250 |
+
- Every technical claim MUST have a citation from available chunks
|
| 251 |
+
- Example: "{citation_example}"
|
| 252 |
+
|
| 253 |
+
{length_instruction}
|
| 254 |
+
|
| 255 |
+
<|eot_id|><|start_header_id|>user<|end_header_id|>
|
| 256 |
+
{prompt_data['user']}
|
| 257 |
+
|
| 258 |
+
CRITICAL: You MUST cite sources ONLY from available chunks: {available_chunks}. DO NOT use chunk numbers > {num_chunks}.
|
| 259 |
+
{length_instruction}<|eot_id|><|start_header_id|>assistant<|end_header_id|>"""
|
| 260 |
+
|
| 261 |
+
elif "mistral" in self.model_name.lower():
|
| 262 |
+
# Mistral format with technical templates
|
| 263 |
+
return f"""[INST] {prompt_data['system']}
|
| 264 |
+
|
| 265 |
+
Context:
|
| 266 |
+
{context}
|
| 267 |
+
|
| 268 |
+
Question: {query}
|
| 269 |
+
|
| 270 |
+
MANDATORY: ONLY use available chunks: {available_chunks}. DO NOT cite chunk numbers > {num_chunks}.
|
| 271 |
+
{length_instruction} [/INST]"""
|
| 272 |
+
|
| 273 |
+
else:
|
| 274 |
+
# Generic format with technical templates
|
| 275 |
+
return f"""{prompt_data['system']}
|
| 276 |
+
|
| 277 |
+
Context:
|
| 278 |
+
{context}
|
| 279 |
+
|
| 280 |
+
Question: {query}
|
| 281 |
+
|
| 282 |
+
MANDATORY CITATIONS: ONLY use available chunks: {available_chunks}. DO NOT cite chunk numbers > {num_chunks}.
|
| 283 |
+
{length_instruction}
|
| 284 |
+
|
| 285 |
+
Answer:"""
|
| 286 |
+
|
| 287 |
+
def _determine_target_length(self, query: str, chunks: List[Dict[str, Any]]) -> int:
|
| 288 |
+
"""
|
| 289 |
+
Determine optimal answer length based on query complexity.
|
| 290 |
+
|
| 291 |
+
Target range: 150-400 characters (down from 1000-2600)
|
| 292 |
+
"""
|
| 293 |
+
# Analyze query complexity
|
| 294 |
+
query_length = len(query)
|
| 295 |
+
query_words = len(query.split())
|
| 296 |
+
|
| 297 |
+
# Check for complexity indicators
|
| 298 |
+
complex_words = [
|
| 299 |
+
"explain", "describe", "analyze", "compare", "contrast",
|
| 300 |
+
"evaluate", "discuss", "detail", "elaborate", "comprehensive"
|
| 301 |
+
]
|
| 302 |
+
|
| 303 |
+
simple_words = [
|
| 304 |
+
"what", "define", "list", "name", "identify", "is", "are"
|
| 305 |
+
]
|
| 306 |
+
|
| 307 |
+
query_lower = query.lower()
|
| 308 |
+
is_complex = any(word in query_lower for word in complex_words)
|
| 309 |
+
is_simple = any(word in query_lower for word in simple_words)
|
| 310 |
+
|
| 311 |
+
# Base length from query type
|
| 312 |
+
if is_complex:
|
| 313 |
+
base_length = 350 # Complex queries get longer answers
|
| 314 |
+
elif is_simple:
|
| 315 |
+
base_length = 200 # Simple queries get shorter answers
|
| 316 |
+
else:
|
| 317 |
+
base_length = 275 # Default middle ground
|
| 318 |
+
|
| 319 |
+
# Adjust based on available context
|
| 320 |
+
context_factor = min(len(chunks) * 25, 75) # More context allows longer answers
|
| 321 |
+
|
| 322 |
+
# Adjust based on query length
|
| 323 |
+
query_factor = min(query_words * 5, 50) # Longer queries allow longer answers
|
| 324 |
+
|
| 325 |
+
target_length = base_length + context_factor + query_factor
|
| 326 |
+
|
| 327 |
+
# Constrain to target range
|
| 328 |
+
return max(150, min(target_length, 400))
|
| 329 |
+
|
| 330 |
+
def _create_length_instruction(self, target_length: int) -> str:
|
| 331 |
+
"""Create length instruction based on target length."""
|
| 332 |
+
if target_length <= 200:
|
| 333 |
+
return f"ANSWER LENGTH: Keep your answer concise and focused, approximately {target_length} characters. Be direct and to the point."
|
| 334 |
+
elif target_length <= 300:
|
| 335 |
+
return f"ANSWER LENGTH: Provide a clear and informative answer, approximately {target_length} characters. Include key details but avoid unnecessary elaboration."
|
| 336 |
+
else:
|
| 337 |
+
return f"ANSWER LENGTH: Provide a comprehensive but concise answer, approximately {target_length} characters. Include important details while maintaining clarity."
|
| 338 |
+
|
| 339 |
+
def _call_ollama(self, prompt: str) -> str:
|
| 340 |
+
"""Call Ollama API for generation."""
|
| 341 |
+
payload = {
|
| 342 |
+
"model": self.model_name,
|
| 343 |
+
"prompt": prompt,
|
| 344 |
+
"stream": False,
|
| 345 |
+
"options": {
|
| 346 |
+
"temperature": self.temperature,
|
| 347 |
+
"num_predict": self.max_tokens,
|
| 348 |
+
"top_p": 0.9,
|
| 349 |
+
"repeat_penalty": 1.1,
|
| 350 |
+
},
|
| 351 |
+
}
|
| 352 |
+
|
| 353 |
+
try:
|
| 354 |
+
response = requests.post(
|
| 355 |
+
f"{self.base_url}/api/generate", json=payload, timeout=300
|
| 356 |
+
)
|
| 357 |
+
|
| 358 |
+
response.raise_for_status()
|
| 359 |
+
result = response.json()
|
| 360 |
+
|
| 361 |
+
return result.get("response", "").strip()
|
| 362 |
+
|
| 363 |
+
except requests.exceptions.RequestException as e:
|
| 364 |
+
print(f"❌ Ollama API error: {e}")
|
| 365 |
+
return f"Error communicating with Ollama: {str(e)}"
|
| 366 |
+
except Exception as e:
|
| 367 |
+
print(f"❌ Unexpected error: {e}")
|
| 368 |
+
return f"Unexpected error: {str(e)}"
|
| 369 |
+
|
| 370 |
+
def _extract_citations(
|
| 371 |
+
self, answer: str, chunks: List[Dict[str, Any]]
|
| 372 |
+
) -> Tuple[str, List[Citation]]:
|
| 373 |
+
"""Extract citations from the generated answer."""
|
| 374 |
+
citations = []
|
| 375 |
+
citation_pattern = r"\[chunk_(\d+)\]"
|
| 376 |
+
|
| 377 |
+
cited_chunks = set()
|
| 378 |
+
|
| 379 |
+
# Find [chunk_X] citations
|
| 380 |
+
matches = re.finditer(citation_pattern, answer)
|
| 381 |
+
for match in matches:
|
| 382 |
+
chunk_idx = int(match.group(1)) - 1 # Convert to 0-based index
|
| 383 |
+
if 0 <= chunk_idx < len(chunks):
|
| 384 |
+
cited_chunks.add(chunk_idx)
|
| 385 |
+
|
| 386 |
+
# FALLBACK: If no explicit citations found but we have an answer and chunks,
|
| 387 |
+
# create citations for the top chunks that were likely used
|
| 388 |
+
if not cited_chunks and chunks and len(answer.strip()) > 50:
|
| 389 |
+
# Use the top chunks that were provided as likely sources
|
| 390 |
+
num_fallback_citations = min(3, len(chunks)) # Use top 3 chunks max
|
| 391 |
+
cited_chunks = set(range(num_fallback_citations))
|
| 392 |
+
print(
|
| 393 |
+
f"🔧 Fallback: Creating {num_fallback_citations} citations for answer without explicit [chunk_X] references",
|
| 394 |
+
file=sys.stderr,
|
| 395 |
+
flush=True,
|
| 396 |
+
)
|
| 397 |
+
|
| 398 |
+
# Create Citation objects
|
| 399 |
+
chunk_to_source = {}
|
| 400 |
+
for idx in cited_chunks:
|
| 401 |
+
chunk = chunks[idx]
|
| 402 |
+
citation = Citation(
|
| 403 |
+
chunk_id=chunk.get("id", f"chunk_{idx}"),
|
| 404 |
+
page_number=chunk.get("metadata", {}).get("page_number", 0),
|
| 405 |
+
source_file=chunk.get("metadata", {}).get("source", "unknown"),
|
| 406 |
+
relevance_score=chunk.get("score", 0.0),
|
| 407 |
+
text_snippet=chunk.get("content", chunk.get("text", ""))[:200] + "...",
|
| 408 |
+
)
|
| 409 |
+
citations.append(citation)
|
| 410 |
+
|
| 411 |
+
# Don't replace chunk references - keep them as proper citations
|
| 412 |
+
# The issue was that replacing [chunk_X] with "the documentation" creates repetitive text
|
| 413 |
+
# Instead, we should keep the proper citation format
|
| 414 |
+
pass
|
| 415 |
+
|
| 416 |
+
# Keep the answer as-is with proper [chunk_X] citations
|
| 417 |
+
# Don't replace citations with repetitive text
|
| 418 |
+
natural_answer = re.sub(r"\s+", " ", answer).strip()
|
| 419 |
+
|
| 420 |
+
return natural_answer, citations
|
| 421 |
+
|
| 422 |
+
def _calculate_confidence(
|
| 423 |
+
self, answer: str, citations: List[Citation], chunks: List[Dict[str, Any]]
|
| 424 |
+
) -> float:
|
| 425 |
+
"""
|
| 426 |
+
Calculate confidence score with expanded multi-factor assessment.
|
| 427 |
+
|
| 428 |
+
Enhanced algorithm expands range from 0.75-0.95 to 0.3-0.9 with:
|
| 429 |
+
- Context quality assessment
|
| 430 |
+
- Citation quality evaluation
|
| 431 |
+
- Semantic relevance scoring
|
| 432 |
+
- Off-topic detection
|
| 433 |
+
- Answer completeness analysis
|
| 434 |
+
"""
|
| 435 |
+
if not answer or len(answer.strip()) < 10:
|
| 436 |
+
return 0.1
|
| 437 |
+
|
| 438 |
+
# 1. Context Quality Assessment (0.3-0.6 base range)
|
| 439 |
+
context_quality = self._assess_context_quality(chunks)
|
| 440 |
+
|
| 441 |
+
# 2. Citation Quality Evaluation (0.0-0.2 boost)
|
| 442 |
+
citation_quality = self._assess_citation_quality(citations, chunks)
|
| 443 |
+
|
| 444 |
+
# 3. Semantic Relevance Scoring (0.0-0.15 boost)
|
| 445 |
+
semantic_relevance = self._assess_semantic_relevance(answer, chunks)
|
| 446 |
+
|
| 447 |
+
# 4. Off-topic Detection (-0.4 penalty if off-topic)
|
| 448 |
+
off_topic_penalty = self._detect_off_topic(answer, chunks)
|
| 449 |
+
|
| 450 |
+
# 5. Answer Completeness Analysis (0.0-0.1 boost)
|
| 451 |
+
completeness_bonus = self._assess_answer_completeness(answer, len(chunks))
|
| 452 |
+
|
| 453 |
+
# Combine all factors
|
| 454 |
+
confidence = (
|
| 455 |
+
context_quality +
|
| 456 |
+
citation_quality +
|
| 457 |
+
semantic_relevance +
|
| 458 |
+
completeness_bonus +
|
| 459 |
+
off_topic_penalty
|
| 460 |
+
)
|
| 461 |
+
|
| 462 |
+
# Apply uncertainty penalty
|
| 463 |
+
uncertainty_phrases = [
|
| 464 |
+
"insufficient information",
|
| 465 |
+
"cannot determine",
|
| 466 |
+
"not available in the provided documents",
|
| 467 |
+
"I don't have enough context",
|
| 468 |
+
"the context doesn't seem to provide"
|
| 469 |
+
]
|
| 470 |
+
|
| 471 |
+
if any(phrase in answer.lower() for phrase in uncertainty_phrases):
|
| 472 |
+
confidence *= 0.4 # Stronger penalty for uncertainty
|
| 473 |
+
|
| 474 |
+
# Constrain to target range 0.3-0.9
|
| 475 |
+
return max(0.3, min(confidence, 0.9))
|
| 476 |
+
|
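A worked example of the multi-factor combination above, with hypothetical component scores:

# context_quality    = 0.60  (3 chunks of moderate length)
# citation_quality   = 0.15  (3 citations covering 3 chunks from 3 distinct files: 0.10 + 0.05)
# semantic_relevance = 0.05  (modest keyword overlap)
# completeness_bonus = 0.05  (answer between 200 and 500 characters)
# off_topic_penalty  = 0.00
# confidence = 0.60 + 0.15 + 0.05 + 0.05 + 0.00 = 0.85, already inside the [0.3, 0.9] clamp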
| 477 |
+
def _assess_context_quality(self, chunks: List[Dict[str, Any]]) -> float:
|
| 478 |
+
"""Assess quality of context chunks (0.3-0.6 range)."""
|
| 479 |
+
if not chunks:
|
| 480 |
+
return 0.3
|
| 481 |
+
|
| 482 |
+
# Base score from chunk count
|
| 483 |
+
if len(chunks) >= 3:
|
| 484 |
+
base_score = 0.6
|
| 485 |
+
elif len(chunks) >= 2:
|
| 486 |
+
base_score = 0.5
|
| 487 |
+
else:
|
| 488 |
+
base_score = 0.4
|
| 489 |
+
|
| 490 |
+
# Quality adjustments based on chunk content
|
| 491 |
+
avg_chunk_length = sum(len(chunk.get("content", chunk.get("text", ""))) for chunk in chunks) / len(chunks)
|
| 492 |
+
|
| 493 |
+
if avg_chunk_length > 500: # Rich content
|
| 494 |
+
base_score += 0.05
|
| 495 |
+
elif avg_chunk_length < 100: # Sparse content
|
| 496 |
+
base_score -= 0.05
|
| 497 |
+
|
| 498 |
+
return max(0.3, min(base_score, 0.6))
|
| 499 |
+
|
| 500 |
+
def _assess_citation_quality(self, citations: List[Citation], chunks: List[Dict[str, Any]]) -> float:
|
| 501 |
+
"""Assess citation quality (0.0-0.2 range)."""
|
| 502 |
+
if not citations or not chunks:
|
| 503 |
+
return 0.0
|
| 504 |
+
|
| 505 |
+
# Citation coverage bonus
|
| 506 |
+
citation_ratio = len(citations) / min(len(chunks), 3)
|
| 507 |
+
coverage_bonus = 0.1 * citation_ratio
|
| 508 |
+
|
| 509 |
+
# Citation diversity bonus (multiple sources)
|
| 510 |
+
unique_sources = len(set(cit.source_file for cit in citations))
|
| 511 |
+
diversity_bonus = 0.05 * min(unique_sources / max(len(chunks), 1), 1.0)
|
| 512 |
+
|
| 513 |
+
return min(coverage_bonus + diversity_bonus, 0.2)
|
| 514 |
+
|
| 515 |
+
def _assess_semantic_relevance(self, answer: str, chunks: List[Dict[str, Any]]) -> float:
|
| 516 |
+
"""Assess semantic relevance between answer and context (0.0-0.15 range)."""
|
| 517 |
+
if not answer or not chunks:
|
| 518 |
+
return 0.0
|
| 519 |
+
|
| 520 |
+
# Simple keyword overlap assessment
|
| 521 |
+
answer_words = set(answer.lower().split())
|
| 522 |
+
context_words = set()
|
| 523 |
+
|
| 524 |
+
for chunk in chunks:
|
| 525 |
+
chunk_text = chunk.get("content", chunk.get("text", ""))
|
| 526 |
+
context_words.update(chunk_text.lower().split())
|
| 527 |
+
|
| 528 |
+
if not context_words:
|
| 529 |
+
return 0.0
|
| 530 |
+
|
| 531 |
+
# Calculate overlap ratio
|
| 532 |
+
overlap = len(answer_words & context_words)
|
| 533 |
+
total_unique = len(answer_words | context_words)
|
| 534 |
+
|
| 535 |
+
if total_unique == 0:
|
| 536 |
+
return 0.0
|
| 537 |
+
|
| 538 |
+
overlap_ratio = overlap / total_unique
|
| 539 |
+
return min(0.15 * overlap_ratio, 0.15)
|
| 540 |
+
|
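A worked example of the overlap ratio above, with hypothetical token sets:

# answer_words  = {"risc-v", "is", "an", "open", "isa"}          (5 tokens)
# context_words = {"risc-v", "is", "a", "free", "open", "isa"}   (6 tokens)
# overlap = 4, union = 7 -> ratio = 4/7 ≈ 0.57 -> boost = min(0.15 * 0.57, 0.15) ≈ 0.086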
| 541 |
+
def _detect_off_topic(self, answer: str, chunks: List[Dict[str, Any]]) -> float:
|
| 542 |
+
"""Detect if answer is off-topic (-0.4 penalty if off-topic)."""
|
| 543 |
+
if not answer or not chunks:
|
| 544 |
+
return 0.0
|
| 545 |
+
|
| 546 |
+
# Check for off-topic indicators
|
| 547 |
+
off_topic_phrases = [
|
| 548 |
+
"but I have to say that the context doesn't seem to provide",
|
| 549 |
+
"these documents appear to be focused on",
|
| 550 |
+
"but they don't seem to cover",
|
| 551 |
+
"I'd recommend consulting a different type of documentation",
|
| 552 |
+
"without more context or information"
|
| 553 |
+
]
|
| 554 |
+
|
| 555 |
+
answer_lower = answer.lower()
|
| 556 |
+
for phrase in off_topic_phrases:
|
| 557 |
+
if phrase in answer_lower:
|
| 558 |
+
return -0.4 # Strong penalty for off-topic responses
|
| 559 |
+
|
| 560 |
+
return 0.0
|
| 561 |
+
|
| 562 |
+
def _assess_answer_completeness(self, answer: str, chunk_count: int) -> float:
|
| 563 |
+
"""Assess answer completeness (0.0-0.1 range)."""
|
| 564 |
+
if not answer:
|
| 565 |
+
return 0.0
|
| 566 |
+
|
| 567 |
+
# Length-based completeness assessment
|
| 568 |
+
answer_length = len(answer)
|
| 569 |
+
|
| 570 |
+
if answer_length > 500: # Comprehensive answer
|
| 571 |
+
return 0.1
|
| 572 |
+
elif answer_length > 200: # Adequate answer
|
| 573 |
+
return 0.05
|
| 574 |
+
else: # Brief answer
|
| 575 |
+
return 0.0
|
| 576 |
+
|
| 577 |
+
def generate(self, query: str, context: List[Document]) -> Answer:
|
| 578 |
+
"""
|
| 579 |
+
Generate an answer from query and context documents (standard interface).
|
| 580 |
+
|
| 581 |
+
This is the public interface that conforms to the AnswerGenerator protocol.
|
| 582 |
+
It handles the conversion between standard Document objects and Ollama's
|
| 583 |
+
internal chunk format.
|
| 584 |
+
|
| 585 |
+
Args:
|
| 586 |
+
query: User's question
|
| 587 |
+
context: List of relevant Document objects
|
| 588 |
+
|
| 589 |
+
Returns:
|
| 590 |
+
Answer object conforming to standard interface
|
| 591 |
+
|
| 592 |
+
Raises:
|
| 593 |
+
ValueError: If query is empty or context is None
|
| 594 |
+
"""
|
| 595 |
+
if not query.strip():
|
| 596 |
+
raise ValueError("Query cannot be empty")
|
| 597 |
+
|
| 598 |
+
if context is None:
|
| 599 |
+
raise ValueError("Context cannot be None")
|
| 600 |
+
|
| 601 |
+
# Internal adapter: Convert Documents to Ollama chunk format
|
| 602 |
+
ollama_chunks = self._documents_to_ollama_chunks(context)
|
| 603 |
+
|
| 604 |
+
# Use existing Ollama-specific generation logic
|
| 605 |
+
ollama_result = self._generate_internal(query, ollama_chunks)
|
| 606 |
+
|
| 607 |
+
# Internal adapter: Convert Ollama result to standard Answer
|
| 608 |
+
return self._ollama_result_to_answer(ollama_result, context)
|
| 609 |
+
|
| 610 |
+
def _generate_internal(self, query: str, chunks: List[Dict[str, Any]]) -> GeneratedAnswer:
|
| 611 |
+
"""
|
| 612 |
+
Generate an answer based on the query and retrieved chunks.
|
| 613 |
+
|
| 614 |
+
Args:
|
| 615 |
+
query: User's question
|
| 616 |
+
chunks: Retrieved document chunks
|
| 617 |
+
|
| 618 |
+
Returns:
|
| 619 |
+
GeneratedAnswer object with answer, citations, and metadata
|
| 620 |
+
"""
|
| 621 |
+
start_time = datetime.now()
|
| 622 |
+
|
| 623 |
+
# Check for no-context situation
|
| 624 |
+
if not chunks or all(
|
| 625 |
+
len(chunk.get("content", chunk.get("text", ""))) < 20 for chunk in chunks
|
| 626 |
+
):
|
| 627 |
+
return GeneratedAnswer(
|
| 628 |
+
answer="This information isn't available in the provided documents.",
|
| 629 |
+
citations=[],
|
| 630 |
+
confidence_score=0.05,
|
| 631 |
+
generation_time=0.1,
|
| 632 |
+
model_used=self.model_name,
|
| 633 |
+
context_used=chunks,
|
| 634 |
+
)
|
| 635 |
+
|
| 636 |
+
# Format context
|
| 637 |
+
context = self._format_context(chunks)
|
| 638 |
+
|
| 639 |
+
# Create prompt with chunks parameter for dynamic citation instructions
|
| 640 |
+
prompt = self._create_prompt(query, context, chunks)
|
| 641 |
+
|
| 642 |
+
# Generate answer
|
| 643 |
+
print(
|
| 644 |
+
f"🤖 Calling Ollama with {self.model_name}...", file=sys.stderr, flush=True
|
| 645 |
+
)
|
| 646 |
+
answer_with_citations = self._call_ollama(prompt)
|
| 647 |
+
|
| 648 |
+
generation_time = (datetime.now() - start_time).total_seconds()
|
| 649 |
+
|
| 650 |
+
# Extract citations and create natural answer
|
| 651 |
+
natural_answer, citations = self._extract_citations(
|
| 652 |
+
answer_with_citations, chunks
|
| 653 |
+
)
|
| 654 |
+
|
| 655 |
+
# Calculate confidence
|
| 656 |
+
confidence = self._calculate_confidence(natural_answer, citations, chunks)
|
| 657 |
+
|
| 658 |
+
return GeneratedAnswer(
|
| 659 |
+
answer=natural_answer,
|
| 660 |
+
citations=citations,
|
| 661 |
+
confidence_score=confidence,
|
| 662 |
+
generation_time=generation_time,
|
| 663 |
+
model_used=self.model_name,
|
| 664 |
+
context_used=chunks,
|
| 665 |
+
)
|
| 666 |
+
|
| 667 |
+
def generate_with_custom_prompt(
|
| 668 |
+
self,
|
| 669 |
+
query: str,
|
| 670 |
+
chunks: List[Dict[str, Any]],
|
| 671 |
+
custom_prompt: Dict[str, str]
|
| 672 |
+
) -> GeneratedAnswer:
|
| 673 |
+
"""
|
| 674 |
+
Generate answer using a custom prompt (for adaptive prompting).
|
| 675 |
+
|
| 676 |
+
Args:
|
| 677 |
+
query: User's question
|
| 678 |
+
chunks: Retrieved context chunks
|
| 679 |
+
custom_prompt: Dict with 'system' and 'user' prompts
|
| 680 |
+
|
| 681 |
+
Returns:
|
| 682 |
+
GeneratedAnswer with custom prompt enhancement
|
| 683 |
+
"""
|
| 684 |
+
start_time = datetime.now()
|
| 685 |
+
|
| 686 |
+
if not chunks:
|
| 687 |
+
return GeneratedAnswer(
|
| 688 |
+
answer="I don't have enough context to answer your question.",
|
| 689 |
+
citations=[],
|
| 690 |
+
confidence_score=0.0,
|
| 691 |
+
generation_time=0.1,
|
| 692 |
+
model_used=self.model_name,
|
| 693 |
+
context_used=chunks,
|
| 694 |
+
)
|
| 695 |
+
|
| 696 |
+
# Build custom prompt based on model type
|
| 697 |
+
if "llama" in self.model_name.lower():
|
| 698 |
+
prompt = f"""[INST] {custom_prompt['system']}
|
| 699 |
+
|
| 700 |
+
{custom_prompt['user']}
|
| 701 |
+
|
| 702 |
+
MANDATORY: Use [chunk_1], [chunk_2] etc. for all facts. [/INST]"""
|
| 703 |
+
elif "mistral" in self.model_name.lower():
|
| 704 |
+
prompt = f"""[INST] {custom_prompt['system']}
|
| 705 |
+
|
| 706 |
+
{custom_prompt['user']}
|
| 707 |
+
|
| 708 |
+
MANDATORY: Use [chunk_1], [chunk_2] etc. for all facts. [/INST]"""
|
| 709 |
+
else:
|
| 710 |
+
# Generic format for other models
|
| 711 |
+
prompt = f"""{custom_prompt['system']}
|
| 712 |
+
|
| 713 |
+
{custom_prompt['user']}
|
| 714 |
+
|
| 715 |
+
MANDATORY: Use [chunk_1], [chunk_2] etc. for all factual statements.
|
| 716 |
+
|
| 717 |
+
Answer:"""
|
| 718 |
+
|
| 719 |
+
# Generate answer
|
| 720 |
+
print(f"🤖 Calling Ollama with custom prompt using {self.model_name}...", file=sys.stderr, flush=True)
|
| 721 |
+
answer_with_citations = self._call_ollama(prompt)
|
| 722 |
+
|
| 723 |
+
generation_time = (datetime.now() - start_time).total_seconds()
|
| 724 |
+
|
| 725 |
+
# Extract citations and create natural answer
|
| 726 |
+
natural_answer, citations = self._extract_citations(answer_with_citations, chunks)
|
| 727 |
+
|
| 728 |
+
# Calculate confidence
|
| 729 |
+
confidence = self._calculate_confidence(natural_answer, citations, chunks)
|
| 730 |
+
|
| 731 |
+
return GeneratedAnswer(
|
| 732 |
+
answer=natural_answer,
|
| 733 |
+
citations=citations,
|
| 734 |
+
confidence_score=confidence,
|
| 735 |
+
generation_time=generation_time,
|
| 736 |
+
model_used=self.model_name,
|
| 737 |
+
context_used=chunks,
|
| 738 |
+
)
|
| 739 |
+
|
| 740 |
+
def _documents_to_ollama_chunks(self, documents: List[Document]) -> List[Dict[str, Any]]:
|
| 741 |
+
"""
|
| 742 |
+
Convert Document objects to Ollama's internal chunk format.
|
| 743 |
+
|
| 744 |
+
This internal adapter ensures that Document objects are properly formatted
|
| 745 |
+
for Ollama's processing pipeline while keeping the format requirements
|
| 746 |
+
encapsulated within this class.
|
| 747 |
+
|
| 748 |
+
Args:
|
| 749 |
+
documents: List of Document objects from the standard interface
|
| 750 |
+
|
| 751 |
+
Returns:
|
| 752 |
+
List of chunk dictionaries in Ollama's expected format
|
| 753 |
+
"""
|
| 754 |
+
if not documents:
|
| 755 |
+
return []
|
| 756 |
+
|
| 757 |
+
chunks = []
|
| 758 |
+
for i, doc in enumerate(documents):
|
| 759 |
+
chunk = {
|
| 760 |
+
"id": f"chunk_{i+1}",
|
| 761 |
+
"content": doc.content, # Ollama expects "content" field
|
| 762 |
+
"text": doc.content, # Fallback field for compatibility
|
| 763 |
+
"score": 1.0, # Default relevance score
|
| 764 |
+
"metadata": {
|
| 765 |
+
"source": doc.metadata.get("source", "unknown"),
|
| 766 |
+
"page_number": doc.metadata.get("start_page", 1),
|
| 767 |
+
**doc.metadata # Include all original metadata
|
| 768 |
+
}
|
| 769 |
+
}
|
| 770 |
+
chunks.append(chunk)
|
| 771 |
+
|
| 772 |
+
return chunks
|
| 773 |
+
|
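For reference, a sketch of the mapping performed above for one hypothetical Document (field values are illustrative):

# Document(content="RV32I provides 32 integer registers.",
#          metadata={"source": "riscv-spec.pdf", "start_page": 12})
# becomes
# {"id": "chunk_1",
#  "content": "RV32I provides 32 integer registers.",
#  "text":    "RV32I provides 32 integer registers.",
#  "score": 1.0,
#  "metadata": {"source": "riscv-spec.pdf", "page_number": 12, "start_page": 12}}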
| 774 |
+
def _ollama_result_to_answer(self, ollama_result: GeneratedAnswer, original_context: List[Document]) -> Answer:
|
| 775 |
+
"""
|
| 776 |
+
Convert Ollama's GeneratedAnswer to the standard Answer format.
|
| 777 |
+
|
| 778 |
+
This internal adapter converts Ollama's result format back to the
|
| 779 |
+
standard interface format expected by the rest of the system.
|
| 780 |
+
|
| 781 |
+
Args:
|
| 782 |
+
ollama_result: Result from Ollama's internal generation
|
| 783 |
+
original_context: Original Document objects for sources
|
| 784 |
+
|
| 785 |
+
Returns:
|
| 786 |
+
Answer object conforming to standard interface
|
| 787 |
+
"""
|
| 788 |
+
if not Answer:
|
| 789 |
+
# Fallback if standard interface not available
|
| 790 |
+
return ollama_result
|
| 791 |
+
|
| 792 |
+
# Convert to standard Answer format
|
| 793 |
+
return Answer(
|
| 794 |
+
text=ollama_result.answer,
|
| 795 |
+
sources=original_context, # Use original Document objects
|
| 796 |
+
confidence=ollama_result.confidence_score,
|
| 797 |
+
metadata={
|
| 798 |
+
"model_used": ollama_result.model_used,
|
| 799 |
+
"generation_time": ollama_result.generation_time,
|
| 800 |
+
"citations": [
|
| 801 |
+
{
|
| 802 |
+
"chunk_id": cit.chunk_id,
|
| 803 |
+
"page_number": cit.page_number,
|
| 804 |
+
"source_file": cit.source_file,
|
| 805 |
+
"relevance_score": cit.relevance_score,
|
| 806 |
+
"text_snippet": cit.text_snippet
|
| 807 |
+
}
|
| 808 |
+
for cit in ollama_result.citations
|
| 809 |
+
],
|
| 810 |
+
"provider": "ollama",
|
| 811 |
+
"temperature": self.temperature,
|
| 812 |
+
"max_tokens": self.max_tokens
|
| 813 |
+
}
|
| 814 |
+
)
|
| 815 |
+
|
| 816 |
+
|
| 817 |
+
# Example usage
|
| 818 |
+
if __name__ == "__main__":
|
| 819 |
+
# Test Ollama connection
|
| 820 |
+
generator = OllamaAnswerGenerator(model_name="llama3.2:3b")
|
| 821 |
+
|
| 822 |
+
# Mock chunks for testing
|
| 823 |
+
test_chunks = [
|
| 824 |
+
{
|
| 825 |
+
"content": "RISC-V is a free and open-source ISA.",
|
| 826 |
+
"metadata": {"page_number": 1, "source": "riscv-spec.pdf"},
|
| 827 |
+
"score": 0.9,
|
| 828 |
+
}
|
| 829 |
+
]
|
| 830 |
+
|
| 831 |
+
# Test generation
|
| 832 |
+
result = generator._generate_internal("What is RISC-V?", test_chunks)  # dict-style chunks use the internal path; generate() now expects Document objects
|
| 833 |
+
print(f"Answer: {result.answer}")
|
| 834 |
+
print(f"Confidence: {result.confidence_score:.2%}")
|
shared_utils/generation/prompt_optimizer.py
ADDED
|
@@ -0,0 +1,687 @@
|
"""
A/B Testing Framework for Prompt Optimization.

This module provides systematic prompt optimization through A/B testing,
performance analysis, and automated prompt variation generation.
"""

import json
import time
from typing import Dict, List, Optional, Tuple, Any
from dataclasses import dataclass, asdict
from enum import Enum
from pathlib import Path
import numpy as np
from collections import defaultdict
import logging

from .prompt_templates import QueryType, PromptTemplate, TechnicalPromptTemplates


class OptimizationMetric(Enum):
    """Metrics for evaluating prompt performance."""
    RESPONSE_TIME = "response_time"
    CONFIDENCE_SCORE = "confidence_score"
    CITATION_COUNT = "citation_count"
    ANSWER_LENGTH = "answer_length"
    TECHNICAL_ACCURACY = "technical_accuracy"
    USER_SATISFACTION = "user_satisfaction"


@dataclass
class PromptVariation:
    """Represents a prompt variation for A/B testing."""
    variation_id: str
    name: str
    description: str
    template: PromptTemplate
    query_type: QueryType
    created_at: float
    metadata: Dict[str, Any]


@dataclass
class TestResult:
    """Represents a single test result."""
    variation_id: str
    query: str
    query_type: QueryType
    response_time: float
    confidence_score: float
    citation_count: int
    answer_length: int
    technical_accuracy: Optional[float] = None
    user_satisfaction: Optional[float] = None
    timestamp: Optional[float] = None
    metadata: Optional[Dict[str, Any]] = None

    def __post_init__(self):
        if self.timestamp is None:
            self.timestamp = time.time()
        if self.metadata is None:
            self.metadata = {}


@dataclass
class ComparisonResult:
    """Results of A/B test comparison."""
    variation_a: str
    variation_b: str
    metric: OptimizationMetric
    a_mean: float
    b_mean: float
    improvement_percent: float
    p_value: float
    confidence_interval: Tuple[float, float]
    is_significant: bool
    sample_size: int
    recommendation: str


class PromptOptimizer:
    """
    A/B testing framework for systematic prompt optimization.

    Features:
    - Automated prompt variation generation
    - Performance metric tracking
    - Statistical significance testing
    - Recommendation engine
    - Persistence and experiment tracking
    """

    def __init__(self, experiment_dir: str = "experiments"):
        """
        Initialize the prompt optimizer.

        Args:
            experiment_dir: Directory to store experiment data
        """
        self.experiment_dir = Path(experiment_dir)
        self.experiment_dir.mkdir(exist_ok=True)

        self.variations: Dict[str, PromptVariation] = {}
        self.test_results: List[TestResult] = []
        self.active_experiments: Dict[str, List[str]] = {}

        # Load existing experiments
        self._load_experiments()

        # Setup logging
        logging.basicConfig(level=logging.INFO)
        self.logger = logging.getLogger(__name__)

    def create_variation(
        self,
        base_template: PromptTemplate,
        query_type: QueryType,
        variation_name: str,
        modifications: Dict[str, str],
        description: str = ""
    ) -> str:
        """
        Create a new prompt variation.

        Args:
            base_template: Base template to modify
            query_type: Type of query this variation is for
            variation_name: Human-readable name
            modifications: Dict of template field modifications
            description: Description of the variation

        Returns:
            Variation ID
        """
        variation_id = f"{query_type.value}_{variation_name}_{int(time.time())}"

        # Create modified template
        modified_template = PromptTemplate(
            system_prompt=modifications.get("system_prompt", base_template.system_prompt),
            context_format=modifications.get("context_format", base_template.context_format),
            query_format=modifications.get("query_format", base_template.query_format),
            answer_guidelines=modifications.get("answer_guidelines", base_template.answer_guidelines)
        )

        variation = PromptVariation(
            variation_id=variation_id,
            name=variation_name,
            description=description,
            template=modified_template,
            query_type=query_type,
            created_at=time.time(),
            metadata=modifications
        )

        self.variations[variation_id] = variation
        self._save_variation(variation)

        self.logger.info(f"Created variation: {variation_id}")
        return variation_id

    def create_temperature_variations(
        self,
        base_query_type: QueryType,
        temperatures: List[float] = [0.3, 0.5, 0.7, 0.9]
    ) -> List[str]:
        """
        Create variations with different temperature settings.

        Args:
            base_query_type: Query type to create variations for
            temperatures: List of temperature values to test

        Returns:
            List of variation IDs
        """
        base_template = TechnicalPromptTemplates.get_template_for_query("")
        if base_query_type != QueryType.GENERAL:
            template_map = {
                QueryType.DEFINITION: TechnicalPromptTemplates.get_definition_template,
                QueryType.IMPLEMENTATION: TechnicalPromptTemplates.get_implementation_template,
                QueryType.COMPARISON: TechnicalPromptTemplates.get_comparison_template,
                QueryType.SPECIFICATION: TechnicalPromptTemplates.get_specification_template,
                QueryType.CODE_EXAMPLE: TechnicalPromptTemplates.get_code_example_template,
                QueryType.HARDWARE_CONSTRAINT: TechnicalPromptTemplates.get_hardware_constraint_template,
                QueryType.TROUBLESHOOTING: TechnicalPromptTemplates.get_troubleshooting_template,
            }
            base_template = template_map[base_query_type]()

        variation_ids = []
        for temp in temperatures:
            temp_modification = {
                "system_prompt": base_template.system_prompt + f"\n\nGenerate responses with temperature={temp} (creativity level).",
                "answer_guidelines": base_template.answer_guidelines + f"\n\nAdjust response creativity to temperature={temp}."
            }

            variation_id = self.create_variation(
                base_template=base_template,
                query_type=base_query_type,
                variation_name=f"temp_{temp}",
                modifications=temp_modification,
                description=f"Temperature variation with {temp} creativity level"
            )
            variation_ids.append(variation_id)

        return variation_ids

    def create_length_variations(
        self,
        base_query_type: QueryType,
        length_styles: List[str] = ["concise", "detailed", "comprehensive"]
    ) -> List[str]:
        """
        Create variations with different response length preferences.

        Args:
            base_query_type: Query type to create variations for
            length_styles: List of length styles to test

        Returns:
            List of variation IDs
        """
        base_template = TechnicalPromptTemplates.get_template_for_query("")
        if base_query_type != QueryType.GENERAL:
            template_map = {
                QueryType.DEFINITION: TechnicalPromptTemplates.get_definition_template,
                QueryType.IMPLEMENTATION: TechnicalPromptTemplates.get_implementation_template,
                QueryType.COMPARISON: TechnicalPromptTemplates.get_comparison_template,
                QueryType.SPECIFICATION: TechnicalPromptTemplates.get_specification_template,
                QueryType.CODE_EXAMPLE: TechnicalPromptTemplates.get_code_example_template,
                QueryType.HARDWARE_CONSTRAINT: TechnicalPromptTemplates.get_hardware_constraint_template,
                QueryType.TROUBLESHOOTING: TechnicalPromptTemplates.get_troubleshooting_template,
            }
            base_template = template_map[base_query_type]()

        length_prompts = {
            "concise": "Be concise and focus on essential information only. Aim for 2-3 sentences per point.",
            "detailed": "Provide detailed explanations with examples. Aim for comprehensive coverage.",
            "comprehensive": "Provide exhaustive detail with multiple examples, edge cases, and related concepts."
        }

        variation_ids = []
        for style in length_styles:
            length_modification = {
                "answer_guidelines": base_template.answer_guidelines + f"\n\nResponse style: {length_prompts[style]}"
            }

            variation_id = self.create_variation(
                base_template=base_template,
                query_type=base_query_type,
                variation_name=f"length_{style}",
                modifications=length_modification,
                description=f"Length variation with {style} response style"
            )
            variation_ids.append(variation_id)

        return variation_ids

    def create_citation_variations(
        self,
        base_query_type: QueryType,
        citation_styles: List[str] = ["minimal", "standard", "extensive"]
    ) -> List[str]:
        """
        Create variations with different citation requirements.

        Args:
            base_query_type: Query type to create variations for
            citation_styles: List of citation styles to test

        Returns:
            List of variation IDs
        """
        base_template = TechnicalPromptTemplates.get_template_for_query("")
        if base_query_type != QueryType.GENERAL:
            template_map = {
                QueryType.DEFINITION: TechnicalPromptTemplates.get_definition_template,
                QueryType.IMPLEMENTATION: TechnicalPromptTemplates.get_implementation_template,
                QueryType.COMPARISON: TechnicalPromptTemplates.get_comparison_template,
                QueryType.SPECIFICATION: TechnicalPromptTemplates.get_specification_template,
                QueryType.CODE_EXAMPLE: TechnicalPromptTemplates.get_code_example_template,
                QueryType.HARDWARE_CONSTRAINT: TechnicalPromptTemplates.get_hardware_constraint_template,
                QueryType.TROUBLESHOOTING: TechnicalPromptTemplates.get_troubleshooting_template,
            }
            base_template = template_map[base_query_type]()

        citation_prompts = {
            "minimal": "Use [chunk_X] citations only for direct quotes or specific claims.",
            "standard": "Include [chunk_X] citations for each major point or claim.",
            "extensive": "Provide [chunk_X] citations for every statement. Use multiple citations per point where relevant."
        }

        variation_ids = []
        for style in citation_styles:
            citation_modification = {
                "answer_guidelines": base_template.answer_guidelines + f"\n\nCitation style: {citation_prompts[style]}"
            }

            variation_id = self.create_variation(
                base_template=base_template,
                query_type=base_query_type,
                variation_name=f"citation_{style}",
                modifications=citation_modification,
                description=f"Citation variation with {style} citation requirements"
            )
            variation_ids.append(variation_id)

        return variation_ids

    def setup_experiment(
        self,
        experiment_name: str,
        variation_ids: List[str],
        test_queries: List[str]
    ) -> str:
        """
        Set up a new A/B test experiment.

        Args:
            experiment_name: Name of the experiment
            variation_ids: List of variation IDs to test
            test_queries: List of test queries

        Returns:
            Experiment ID
        """
        experiment_id = f"exp_{experiment_name}_{int(time.time())}"

        experiment_config = {
            "experiment_id": experiment_id,
            "name": experiment_name,
            "variation_ids": variation_ids,
            "test_queries": test_queries,
            "created_at": time.time(),
            "status": "active"
        }

        self.active_experiments[experiment_id] = variation_ids

        # Save experiment config
        experiment_file = self.experiment_dir / f"{experiment_id}.json"
        with open(experiment_file, 'w') as f:
            json.dump(experiment_config, f, indent=2)

        self.logger.info(f"Created experiment: {experiment_id}")
        return experiment_id

    def record_test_result(
        self,
        variation_id: str,
        query: str,
        query_type: QueryType,
        response_time: float,
        confidence_score: float,
        citation_count: int,
        answer_length: int,
        technical_accuracy: Optional[float] = None,
        user_satisfaction: Optional[float] = None,
        metadata: Optional[Dict[str, Any]] = None
    ) -> None:
        """
        Record a test result for analysis.

        Args:
            variation_id: ID of the variation tested
            query: The query that was tested
            query_type: Type of the query
            response_time: Response time in seconds
            confidence_score: Confidence score (0-1)
            citation_count: Number of citations in response
            answer_length: Length of answer in characters
            technical_accuracy: Optional technical accuracy score (0-1)
            user_satisfaction: Optional user satisfaction score (0-1)
            metadata: Optional additional metadata
        """
        result = TestResult(
            variation_id=variation_id,
            query=query,
            query_type=query_type,
            response_time=response_time,
            confidence_score=confidence_score,
            citation_count=citation_count,
            answer_length=answer_length,
            technical_accuracy=technical_accuracy,
            user_satisfaction=user_satisfaction,
            metadata=metadata or {}
        )

        self.test_results.append(result)
        self._save_test_result(result)

        self.logger.info(f"Recorded test result for variation: {variation_id}")

    def analyze_variations(
        self,
        variation_a: str,
        variation_b: str,
        metric: OptimizationMetric,
        min_samples: int = 10
    ) -> ComparisonResult:
        """
        Analyze performance difference between two variations.

        Args:
            variation_a: First variation ID
            variation_b: Second variation ID
            metric: Metric to compare
            min_samples: Minimum samples required for analysis

        Returns:
            Comparison result with statistical analysis
        """
        # Filter results for each variation
        results_a = [r for r in self.test_results if r.variation_id == variation_a]
        results_b = [r for r in self.test_results if r.variation_id == variation_b]

        if len(results_a) < min_samples or len(results_b) < min_samples:
            raise ValueError(f"Insufficient samples. Need at least {min_samples} for each variation.")

        # Extract metric values
        values_a = self._extract_metric_values(results_a, metric)
        values_b = self._extract_metric_values(results_b, metric)

        # Calculate statistics
        mean_a = np.mean(values_a)
        mean_b = np.mean(values_b)

        # Calculate improvement percentage
        improvement = ((mean_b - mean_a) / mean_a) * 100

        # Simple t-test (normally would use scipy.stats.ttest_ind)
        # For now, using basic statistical comparison
        std_a = np.std(values_a)
        std_b = np.std(values_b)
        n_a = len(values_a)
        n_b = len(values_b)

        # Basic p-value estimation (simplified)
        pooled_std = np.sqrt(((n_a - 1) * std_a**2 + (n_b - 1) * std_b**2) / (n_a + n_b - 2))
        t_stat = (mean_b - mean_a) / (pooled_std * np.sqrt(1/n_a + 1/n_b))
        p_value = 2 * (1 - abs(t_stat) / (abs(t_stat) + 1))  # Rough approximation, not a true t-distribution p-value

        # Confidence interval (simplified)
        margin_of_error = 1.96 * pooled_std * np.sqrt(1/n_a + 1/n_b)
        ci_lower = (mean_b - mean_a) - margin_of_error
        ci_upper = (mean_b - mean_a) + margin_of_error

        # Determine significance
        is_significant = p_value < 0.05

        # Generate recommendation
        if is_significant:
            if improvement > 0:
                recommendation = f"Variation B shows significant improvement ({improvement:.1f}%). Recommend adopting variation B."
            else:
                recommendation = f"Variation A shows significant improvement ({-improvement:.1f}%). Recommend keeping variation A."
        else:
            recommendation = f"No significant difference detected (p={p_value:.3f}). More data needed or variations are equivalent."

        return ComparisonResult(
            variation_a=variation_a,
            variation_b=variation_b,
            metric=metric,
            a_mean=mean_a,
            b_mean=mean_b,
            improvement_percent=improvement,
            p_value=p_value,
            confidence_interval=(ci_lower, ci_upper),
            is_significant=is_significant,
            sample_size=min(n_a, n_b),
            recommendation=recommendation
        )

    def get_best_variation(
        self,
        query_type: QueryType,
        metric: OptimizationMetric,
        min_samples: int = 10
    ) -> Optional[str]:
        """
        Get the best performing variation for a query type and metric.

        Args:
            query_type: Type of query
            metric: Metric to optimize for
            min_samples: Minimum samples required

        Returns:
            Best variation ID or None if insufficient data
        """
        # Filter results by query type
        relevant_results = [r for r in self.test_results if r.query_type == query_type]

        # Group by variation
        variation_performance = defaultdict(list)
        for result in relevant_results:
            variation_performance[result.variation_id].append(result)

        # Calculate mean performance for each variation
        best_variation = None
        best_score = None

        for variation_id, results in variation_performance.items():
            if len(results) >= min_samples:
                values = self._extract_metric_values(results, metric)
                mean_score = np.mean(values)

                if best_score is None or mean_score > best_score:
                    best_score = mean_score
                    best_variation = variation_id

        return best_variation

    def generate_optimization_report(
        self,
        experiment_id: str,
        output_file: Optional[str] = None
    ) -> Dict[str, Any]:
        """
        Generate a comprehensive optimization report.

        Args:
            experiment_id: Experiment to analyze
            output_file: Optional file to save report

        Returns:
            Report dictionary
        """
        if experiment_id not in self.active_experiments:
            raise ValueError(f"Experiment {experiment_id} not found")

        variation_ids = self.active_experiments[experiment_id]
        experiment_results = [r for r in self.test_results if r.variation_id in variation_ids]

        if not experiment_results:
            raise ValueError(f"No results found for experiment {experiment_id}")

        # Analyze each metric
        metrics = [
            OptimizationMetric.RESPONSE_TIME,
            OptimizationMetric.CONFIDENCE_SCORE,
            OptimizationMetric.CITATION_COUNT,
            OptimizationMetric.ANSWER_LENGTH
        ]

        report = {
            "experiment_id": experiment_id,
            "variations_tested": len(variation_ids),
            "total_tests": len(experiment_results),
            "analysis_date": time.time(),
            "metric_analysis": {},
            "recommendations": []
        }

        # Analyze each metric across variations
        for metric in metrics:
            metric_data = {}
            for variation_id in variation_ids:
                var_results = [r for r in experiment_results if r.variation_id == variation_id]
                if var_results:
                    values = self._extract_metric_values(var_results, metric)
                    metric_data[variation_id] = {
                        "mean": np.mean(values),
                        "std": np.std(values),
                        "count": len(values)
                    }

            report["metric_analysis"][metric.value] = metric_data

        # Generate recommendations
        for metric in metrics:
            best_variation = self.get_best_variation(
                query_type=QueryType.GENERAL,  # Could be made more specific
                metric=metric,
                min_samples=5
            )
            if best_variation:
                report["recommendations"].append({
                    "metric": metric.value,
                    "best_variation": best_variation,
                    "variation_name": self.variations[best_variation].name
                })

        # Save report if requested
        if output_file:
            with open(output_file, 'w') as f:
                json.dump(report, f, indent=2)

        return report

    def _extract_metric_values(self, results: List[TestResult], metric: OptimizationMetric) -> List[float]:
        """Extract metric values from test results."""
        values = []
        for result in results:
            if metric == OptimizationMetric.RESPONSE_TIME:
                values.append(result.response_time)
            elif metric == OptimizationMetric.CONFIDENCE_SCORE:
                values.append(result.confidence_score)
            elif metric == OptimizationMetric.CITATION_COUNT:
                values.append(float(result.citation_count))
            elif metric == OptimizationMetric.ANSWER_LENGTH:
                values.append(float(result.answer_length))
            elif metric == OptimizationMetric.TECHNICAL_ACCURACY and result.technical_accuracy is not None:
                values.append(result.technical_accuracy)
            elif metric == OptimizationMetric.USER_SATISFACTION and result.user_satisfaction is not None:
                values.append(result.user_satisfaction)

        return values

    def _load_experiments(self) -> None:
        """Load existing experiments from disk."""
        if not self.experiment_dir.exists():
            return

        for file_path in self.experiment_dir.glob("*.json"):
            if file_path.name.startswith("exp_"):
                with open(file_path, 'r') as f:
                    config = json.load(f)
                    self.active_experiments[config["experiment_id"]] = config["variation_ids"]

        # Load variations and results
        for file_path in self.experiment_dir.glob("variation_*.json"):
            with open(file_path, 'r') as f:
                var_data = json.load(f)
                variation = PromptVariation(**var_data)
                self.variations[variation.variation_id] = variation

        for file_path in self.experiment_dir.glob("result_*.json"):
            with open(file_path, 'r') as f:
                result_data = json.load(f)
                result = TestResult(**result_data)
                self.test_results.append(result)

    def _save_variation(self, variation: PromptVariation) -> None:
        """Save variation to disk."""
        file_path = self.experiment_dir / f"variation_{variation.variation_id}.json"
        var_dict = asdict(variation)

        # Convert template to dict
        var_dict["template"] = asdict(variation.template)
        var_dict["query_type"] = variation.query_type.value

        with open(file_path, 'w') as f:
            json.dump(var_dict, f, indent=2)

    def _save_test_result(self, result: TestResult) -> None:
        """Save test result to disk."""
        file_path = self.experiment_dir / f"result_{int(result.timestamp)}.json"
        result_dict = asdict(result)
        result_dict["query_type"] = result.query_type.value

        with open(file_path, 'w') as f:
            json.dump(result_dict, f, indent=2)


# Example usage
if __name__ == "__main__":
    # Initialize optimizer
    optimizer = PromptOptimizer()

    # Create temperature variations for implementation queries
    temp_variations = optimizer.create_temperature_variations(
        base_query_type=QueryType.IMPLEMENTATION,
        temperatures=[0.3, 0.7]
    )

    # Create length variations for definition queries
    length_variations = optimizer.create_length_variations(
        base_query_type=QueryType.DEFINITION,
        length_styles=["concise", "detailed"]
    )

    # Setup experiment
    test_queries = [
        "How do I implement a timer interrupt in RISC-V?",
        "What is the difference between machine mode and user mode?",
        "Configure GPIO pins for input/output operations"
    ]

    experiment_id = optimizer.setup_experiment(
        experiment_name="temperature_vs_length",
        variation_ids=temp_variations + length_variations,
        test_queries=test_queries
    )

    print(f"Created experiment: {experiment_id}")
    print(f"Variations: {len(temp_variations + length_variations)}")
    print(f"Test queries: {len(test_queries)}")
shared_utils/generation/prompt_templates.py
ADDED
|
@@ -0,0 +1,520 @@
"""
Prompt templates optimized for technical documentation Q&A.

This module provides specialized prompt templates for different types of
technical queries, with a focus on embedded systems and AI documentation.
"""

from enum import Enum
from typing import Dict, List, Optional
from dataclasses import dataclass


class QueryType(Enum):
    """Types of technical queries."""
    DEFINITION = "definition"
    IMPLEMENTATION = "implementation"
    COMPARISON = "comparison"
    TROUBLESHOOTING = "troubleshooting"
    SPECIFICATION = "specification"
    CODE_EXAMPLE = "code_example"
    HARDWARE_CONSTRAINT = "hardware_constraint"
    GENERAL = "general"


@dataclass
class PromptTemplate:
    """Represents a prompt template with its components."""
    system_prompt: str
    context_format: str
    query_format: str
    answer_guidelines: str
    few_shot_examples: Optional[List[str]] = None


class TechnicalPromptTemplates:
    """
    Collection of prompt templates optimized for technical documentation.

    Features:
    - Domain-specific templates for embedded systems and AI
    - Structured output formats
    - Citation requirements
    - Technical accuracy emphasis
    """

    @staticmethod
    def get_base_system_prompt() -> str:
        """Get the base system prompt for technical documentation."""
        return """You are an expert technical documentation assistant specializing in embedded systems,
RISC-V architecture, RTOS, and embedded AI/ML. Your role is to provide accurate, detailed
technical answers based strictly on the provided context.

Key responsibilities:
1. Answer questions using ONLY information from the provided context
2. Include precise citations using [chunk_X] notation for every claim
3. Maintain technical accuracy and use correct terminology
4. Format code snippets and technical specifications properly
5. Clearly state when information is not available in the context
6. Consider hardware constraints and embedded system limitations when relevant

Write naturally and conversationally. Avoid repetitive phrases and numbered lists unless specifically requested. Never make up information. If the context doesn't contain the answer, say so explicitly."""

    @staticmethod
    def get_definition_template() -> PromptTemplate:
        """Template for definition/explanation queries."""
        return PromptTemplate(
            system_prompt=TechnicalPromptTemplates.get_base_system_prompt() + """

For definition queries, focus on:
- Clear, concise technical definitions
- Related concepts and terminology
- Technical context and applications
- Any acronym expansions""",

            context_format="""Technical Documentation Context:
{context}""",

            query_format="""Define or explain: {query}

Provide a comprehensive technical definition with proper citations.""",

            answer_guidelines="""Provide a clear, comprehensive answer that directly addresses the question. Include relevant technical details and cite your sources using [chunk_X] notation. Make your response natural and conversational while maintaining technical accuracy.""",

            few_shot_examples=[
                """Q: What is RISC-V?
A: RISC-V is an open-source instruction set architecture (ISA) based on established reduced instruction set computing (RISC) principles [chunk_1]. Unlike proprietary ISAs, RISC-V is freely available under open-source licenses, allowing anyone to implement RISC-V processors without licensing fees [chunk_2]. The architecture supports 32-bit, 64-bit, and 128-bit address spaces, with a modular design that includes base integer instruction sets and optional extensions [chunk_3]. RISC-V stands for "RISC-Five" referring to the fifth generation of RISC architecture developed at UC Berkeley.""",

                """Q: What is FreeRTOS?
A: FreeRTOS is a real-time operating system kernel for embedded devices that provides multitasking capabilities for microcontrollers and small microprocessors [chunk_1]. It implements a preemptive scheduler with priority-based task scheduling, ensuring deterministic real-time behavior [chunk_2]. FreeRTOS includes core features like task management, semaphores, queues, and memory management while maintaining a small footprint typically under 10KB [chunk_3]. The "Free" in FreeRTOS refers to both its open-source license and the fact that it's free of charge for commercial use."""
            ]
        )

    @staticmethod
    def get_implementation_template() -> PromptTemplate:
        """Template for implementation/how-to queries."""
        return PromptTemplate(
            system_prompt=TechnicalPromptTemplates.get_base_system_prompt() + """

For implementation queries, focus on:
- Step-by-step instructions
- Required components or dependencies
- Code examples with proper formatting
- Hardware/software requirements
- Common pitfalls or considerations""",

            context_format="""Implementation Documentation:
{context}""",

            query_format="""Implementation question: {query}

Provide detailed implementation guidance with code examples where available.""",

            answer_guidelines="""Structure your answer as:
1. Overview of the implementation approach [chunk_X]
2. Prerequisites and requirements [chunk_Y]
3. Step-by-step implementation:
   - Step 1: Description [chunk_Z]
   - Step 2: Description [chunk_W]
4. Code example (if available):
   ```language
   // Code here
   ```
5. Important considerations or warnings""",

            few_shot_examples=[
                """Q: How do I configure GPIO pins for output in RISC-V?
A: GPIO configuration for output requires setting the pin direction and initial value through memory-mapped registers [chunk_1]. First, identify the GPIO base address for your specific RISC-V implementation (commonly 0x10060000 for SiFive cores) [chunk_2].

Steps:
1. Set pin direction to output by writing to GPIO_OUTPUT_EN register [chunk_3]
2. Configure initial output value using GPIO_OUTPUT_VAL register [chunk_4]

```c
#define GPIO_BASE 0x10060000
#define GPIO_OUTPUT_EN (GPIO_BASE + 0x08)
#define GPIO_OUTPUT_VAL (GPIO_BASE + 0x0C)

// Configure pin 5 as output
volatile uint32_t *gpio_en = (uint32_t*)GPIO_OUTPUT_EN;
volatile uint32_t *gpio_val = (uint32_t*)GPIO_OUTPUT_VAL;

*gpio_en |= (1 << 5);   // Enable output on pin 5
*gpio_val |= (1 << 5);  // Set pin 5 high
```

Important: Always check your board's documentation for the correct GPIO base address and pin mapping [chunk_5].""",

                """Q: How to implement a basic timer interrupt in RISC-V?
A: Timer interrupts in RISC-V use the machine timer (mtime) and timer compare (mtimecmp) registers for precise timing control [chunk_1]. The implementation requires configuring the timer hardware, setting up the interrupt handler, and enabling machine timer interrupts [chunk_2].

Prerequisites:
- RISC-V processor with timer support
- Access to machine-level CSRs
- Understanding of memory-mapped timer registers [chunk_3]

Implementation steps:
1. Set up timer compare value in mtimecmp register [chunk_4]
2. Enable machine timer interrupt in mie CSR [chunk_5]
3. Configure interrupt handler in mtvec CSR [chunk_6]

```c
#define MTIME_BASE 0x0200bff8
#define MTIMECMP_BASE 0x02004000

void setup_timer_interrupt(uint64_t interval) {
    uint64_t *mtime = (uint64_t*)MTIME_BASE;
    uint64_t *mtimecmp = (uint64_t*)MTIMECMP_BASE;

    // Set next interrupt time
    *mtimecmp = *mtime + interval;

    // Enable machine timer interrupt
    asm volatile ("csrs mie, %0" : : "r"(0x80));

    // Enable global interrupts
    asm volatile ("csrs mstatus, %0" : : "r"(0x8));
}
```

Critical considerations: Timer registers are 64-bit and must be accessed atomically on 32-bit systems [chunk_7]."""
            ]
        )

    @staticmethod
    def get_comparison_template() -> PromptTemplate:
        """Template for comparison queries."""
        return PromptTemplate(
            system_prompt=TechnicalPromptTemplates.get_base_system_prompt() + """

For comparison queries, focus on:
- Clear distinction between compared items
- Technical specifications and differences
- Use cases for each option
- Performance or resource implications
- Recommendations based on context""",

            context_format="""Technical Comparison Context:
{context}""",

            query_format="""Compare: {query}

Provide a detailed technical comparison with clear distinctions.""",

            answer_guidelines="""Structure your answer as:
1. Overview of items being compared [chunk_X]
2. Key differences:
   - Feature A: Item1 vs Item2 [chunk_Y]
   - Feature B: Item1 vs Item2 [chunk_Z]
3. Technical specifications comparison
4. Use case recommendations
5. Performance/resource considerations"""
        )

    @staticmethod
    def get_specification_template() -> PromptTemplate:
        """Template for specification/parameter queries."""
        return PromptTemplate(
            system_prompt=TechnicalPromptTemplates.get_base_system_prompt() + """

For specification queries, focus on:
- Exact technical specifications
- Parameter ranges and limits
- Units and measurements
- Compliance with standards
- Version-specific information""",

            context_format="""Technical Specifications:
{context}""",

            query_format="""Specification query: {query}

Provide precise technical specifications with all relevant parameters.""",

            answer_guidelines="""Structure your answer as:
1. Specification overview [chunk_X]
2. Detailed parameters:
   - Parameter 1: value (unit) [chunk_Y]
   - Parameter 2: value (unit) [chunk_Z]
3. Operating conditions or constraints
4. Compliance/standards information
5. Version or variant notes"""
        )

    @staticmethod
    def get_code_example_template() -> PromptTemplate:
        """Template for code example queries."""
        return PromptTemplate(
            system_prompt=TechnicalPromptTemplates.get_base_system_prompt() + """

For code example queries, focus on:
- Complete, runnable code examples
- Proper syntax highlighting
- Clear comments and documentation
- Error handling
- Best practices for embedded systems""",

            context_format="""Code Examples and Documentation:
{context}""",

            query_format="""Code example request: {query}

Provide working code examples with explanations.""",

            answer_guidelines="""Structure your answer as:
1. Purpose and overview [chunk_X]
2. Required includes/imports [chunk_Y]
3. Complete code example:
   ```c
   // Or appropriate language
   #include <necessary_headers.h>

   // Function or code implementation
   ```
4. Key points explained [chunk_Z]
5. Common variations or modifications"""
        )

    @staticmethod
    def get_hardware_constraint_template() -> PromptTemplate:
        """Template for hardware constraint queries."""
        return PromptTemplate(
            system_prompt=TechnicalPromptTemplates.get_base_system_prompt() + """

For hardware constraint queries, focus on:
- Memory requirements (RAM, Flash)
- Processing power needs (MIPS, frequency)
- Power consumption
- I/O requirements
- Real-time constraints
- Temperature/environmental limits""",

            context_format="""Hardware Specifications and Constraints:
{context}""",

            query_format="""Hardware constraint question: {query}

Analyze feasibility and constraints for embedded deployment.""",

            answer_guidelines="""Structure your answer as:
1. Hardware requirements summary [chunk_X]
2. Detailed constraints:
   - Memory: RAM/Flash requirements [chunk_Y]
   - Processing: CPU/frequency needs [chunk_Z]
   - Power: Consumption estimates [chunk_W]
3. Feasibility assessment
4. Optimization suggestions
5. Alternative approaches if constraints are exceeded"""
        )

    @staticmethod
    def get_troubleshooting_template() -> PromptTemplate:
        """Template for troubleshooting queries."""
        return PromptTemplate(
            system_prompt=TechnicalPromptTemplates.get_base_system_prompt() + """

For troubleshooting queries, focus on:
- Common error causes
- Diagnostic steps
- Solution procedures
- Preventive measures
- Debug techniques for embedded systems""",

            context_format="""Troubleshooting Documentation:
{context}""",

            query_format="""Troubleshooting issue: {query}

Provide diagnostic steps and solutions.""",

            answer_guidelines="""Structure your answer as:
1. Problem identification [chunk_X]
2. Common causes:
   - Cause 1: Description [chunk_Y]
   - Cause 2: Description [chunk_Z]
3. Diagnostic steps:
   - Step 1: Check... [chunk_W]
   - Step 2: Verify... [chunk_V]
4. Solutions for each cause
5. Prevention recommendations"""
        )

    @staticmethod
    def get_general_template() -> PromptTemplate:
        """Default template for general queries."""
        return PromptTemplate(
            system_prompt=TechnicalPromptTemplates.get_base_system_prompt(),

            context_format="""Technical Documentation:
{context}""",

            query_format="""Question: {query}

Provide a comprehensive technical answer based on the documentation.""",

            answer_guidelines="""Provide a clear, comprehensive answer that directly addresses the question. Include relevant technical details and cite your sources using [chunk_X] notation. Write naturally and conversationally while maintaining technical accuracy. Acknowledge any limitations in available information."""
        )

    @staticmethod
    def detect_query_type(query: str) -> QueryType:
        """
        Detect the type of query based on keywords and patterns.

        Args:
            query: User's question

        Returns:
            Detected QueryType
        """
        query_lower = query.lower()

        # Definition keywords
        if any(keyword in query_lower for keyword in [
            "what is", "what are", "define", "definition", "meaning of", "explain what"
        ]):
            return QueryType.DEFINITION

        # Implementation keywords
        if any(keyword in query_lower for keyword in [
            "how to", "how do i", "implement", "setup", "configure", "install"
        ]):
            return QueryType.IMPLEMENTATION

        # Comparison keywords
        if any(keyword in query_lower for keyword in [
            "difference between", "compare", "vs", "versus", "better than", "which is"
        ]):
            return QueryType.COMPARISON

        # Specification keywords
        if any(keyword in query_lower for keyword in [
            "specification", "specs", "parameters", "limits", "range", "maximum", "minimum"
        ]):
            return QueryType.SPECIFICATION

        # Code example keywords
        if any(keyword in query_lower for keyword in [
            "example", "code", "snippet", "sample", "demo", "show me"
        ]):
            return QueryType.CODE_EXAMPLE

        # Hardware constraint keywords
        if any(keyword in query_lower for keyword in [
            "memory", "ram", "flash", "mcu", "constraint", "fit on", "run on", "power consumption"
        ]):
            return QueryType.HARDWARE_CONSTRAINT

        # Troubleshooting keywords
        if any(keyword in query_lower for keyword in [
            "error", "problem", "issue", "debug", "troubleshoot", "fix", "solve", "not working"
        ]):
            return QueryType.TROUBLESHOOTING

        return QueryType.GENERAL

    @staticmethod
    def get_template_for_query(query: str) -> PromptTemplate:
        """
        Get the appropriate template based on query type.

        Args:
            query: User's question

        Returns:
            Appropriate PromptTemplate
        """
        query_type = TechnicalPromptTemplates.detect_query_type(query)

        template_map = {
            QueryType.DEFINITION: TechnicalPromptTemplates.get_definition_template,
            QueryType.IMPLEMENTATION: TechnicalPromptTemplates.get_implementation_template,
            QueryType.COMPARISON: TechnicalPromptTemplates.get_comparison_template,
            QueryType.SPECIFICATION: TechnicalPromptTemplates.get_specification_template,
            QueryType.CODE_EXAMPLE: TechnicalPromptTemplates.get_code_example_template,
            QueryType.HARDWARE_CONSTRAINT: TechnicalPromptTemplates.get_hardware_constraint_template,
            QueryType.TROUBLESHOOTING: TechnicalPromptTemplates.get_troubleshooting_template,
            QueryType.GENERAL: TechnicalPromptTemplates.get_general_template
        }

        return template_map[query_type]()

    @staticmethod
    def format_prompt_with_template(
        query: str,
        context: str,
        template: Optional[PromptTemplate] = None,
        include_few_shot: bool = True
    ) -> Dict[str, str]:
        """
        Format a complete prompt using the appropriate template.

        Args:
            query: User's question
            context: Retrieved context chunks
            template: Optional specific template (auto-detected if None)
            include_few_shot: Whether to include few-shot examples

        Returns:
            Dict with 'system' and 'user' prompts
        """
        if template is None:
            template = TechnicalPromptTemplates.get_template_for_query(query)

        # Format the context
        formatted_context = template.context_format.format(context=context)

        # Format the query
        formatted_query = template.query_format.format(query=query)

        # Build user prompt with optional few-shot examples
        user_prompt_parts = []

        # Add few-shot examples if available and requested
        if include_few_shot and template.few_shot_examples:
            user_prompt_parts.append("Here are some examples of how to answer similar questions:")
            user_prompt_parts.append("\n\n".join(template.few_shot_examples))
            user_prompt_parts.append("\nNow answer the following question using the same format:")

        user_prompt_parts.extend([
            formatted_context,
            formatted_query,
            template.answer_guidelines
        ])

        user_prompt = "\n\n".join(user_prompt_parts)

        return {
            "system": template.system_prompt,
            "user": user_prompt
        }


# Example usage and testing
if __name__ == "__main__":
    # Test query type detection
    test_queries = [
        "What is RISC-V?",
        "How do I implement a timer interrupt?",
        "What's the difference between FreeRTOS and Zephyr?",
        "What are the memory specifications for STM32F4?",
        "Show me an example of GPIO configuration",
        "Can this model run on an MCU with 256KB RAM?",
        "Debug error: undefined reference to main"
    ]

    for query in test_queries:
        query_type = TechnicalPromptTemplates.detect_query_type(query)
        print(f"Query: '{query}' -> Type: {query_type.value}")

    # Example prompt formatting
    example_context = "RISC-V is an open instruction set architecture..."
    example_query = "What is RISC-V?"

    formatted = TechnicalPromptTemplates.format_prompt_with_template(
        query=example_query,
        context=example_context
    )

    print("\nFormatted prompt example:")
    print("System:", formatted["system"][:100], "...")
    print("User:", formatted["user"][:200], "...")
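
Editor's note: the dict returned by format_prompt_with_template maps directly onto a system/user chat exchange. A minimal wiring sketch, assuming a downstream chat-style client that takes a list of role/content messages (that message shape is an assumption about the generator, not something defined in this file):

# Hypothetical wiring sketch; the `messages` list format is assumed, not part of this commit.
from shared_utils.generation.prompt_templates import TechnicalPromptTemplates

prompts = TechnicalPromptTemplates.format_prompt_with_template(
    query="How do I configure a RISC-V timer interrupt?",
    context="[chunk_1] The machine timer uses mtime/mtimecmp registers...",
)
messages = [
    {"role": "system", "content": prompts["system"]},
    {"role": "user", "content": prompts["user"]},
]
# `messages` can then be handed to whichever chat-completion backend the
# answer generators under shared_utils/generation are configured to use.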
shared_utils/query_processing/__init__.py
ADDED
|
@@ -0,0 +1,8 @@
"""
Query processing utilities for intelligent RAG systems.
Provides query enhancement, analysis, and optimization capabilities.
"""

from .query_enhancer import QueryEnhancer

__all__ = ['QueryEnhancer']
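
Editor's note: because the package __init__ re-exports QueryEnhancer, callers can import it from the package root. A one-line usage sketch (assuming shared_utils is on the Python path):

from shared_utils.query_processing import QueryEnhancer  # re-exported above
enhancer = QueryEnhancer()  # builds the technical synonym/acronym tables at construction time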
shared_utils/query_processing/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (406 Bytes).
shared_utils/query_processing/__pycache__/query_enhancer.cpython-312.pyc
ADDED
Binary file (24.3 kB).
shared_utils/query_processing/query_enhancer.py
ADDED
|
@@ -0,0 +1,644 @@
| 1 |
+
"""
|
| 2 |
+
Intelligent query processing for technical documentation RAG.
|
| 3 |
+
|
| 4 |
+
Provides adaptive query enhancement through technical term expansion,
|
| 5 |
+
acronym handling, and intelligent hybrid weighting optimization.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from typing import Dict, List, Any, Tuple, Set, Optional
|
| 9 |
+
import re
|
| 10 |
+
from collections import defaultdict
|
| 11 |
+
import time
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class QueryEnhancer:
|
| 15 |
+
"""
|
| 16 |
+
Intelligent query processing for technical documentation RAG.
|
| 17 |
+
|
| 18 |
+
Analyzes query characteristics and enhances retrieval through:
|
| 19 |
+
- Technical synonym expansion
|
| 20 |
+
- Acronym detection and expansion
|
| 21 |
+
- Adaptive hybrid weighting based on query type
|
| 22 |
+
- Query complexity analysis for optimal retrieval strategy
|
| 23 |
+
|
| 24 |
+
Optimized for embedded systems and technical documentation domains.
|
| 25 |
+
|
| 26 |
+
Performance: <10ms query enhancement, improves retrieval relevance by >10%
|
| 27 |
+
"""
|
| 28 |
+
|
| 29 |
+
def __init__(self):
|
| 30 |
+
"""Initialize QueryEnhancer with technical domain knowledge."""
|
| 31 |
+
|
| 32 |
+
# Technical vocabulary dictionary organized by domain
|
| 33 |
+
self.technical_synonyms = {
|
| 34 |
+
# Processor terminology
|
| 35 |
+
'cpu': ['processor', 'microprocessor', 'central processing unit'],
|
| 36 |
+
'mcu': ['microcontroller', 'microcontroller unit', 'embedded processor'],
|
| 37 |
+
'core': ['processor core', 'cpu core', 'execution unit'],
|
| 38 |
+
'alu': ['arithmetic logic unit', 'arithmetic unit'],
|
| 39 |
+
|
| 40 |
+
# Memory terminology
|
| 41 |
+
'memory': ['ram', 'storage', 'buffer', 'cache'],
|
| 42 |
+
'flash': ['non-volatile memory', 'program memory', 'code storage'],
|
| 43 |
+
'sram': ['static ram', 'static memory', 'cache memory'],
|
| 44 |
+
'dram': ['dynamic ram', 'dynamic memory'],
|
| 45 |
+
'cache': ['buffer', 'temporary storage', 'fast memory'],
|
| 46 |
+
|
| 47 |
+
# Architecture terminology
|
| 48 |
+
'risc-v': ['riscv', 'risc v', 'open isa', 'open instruction set'],
|
| 49 |
+
'arm': ['advanced risc machine', 'acorn risc machine'],
|
| 50 |
+
'isa': ['instruction set architecture', 'instruction set'],
|
| 51 |
+
'architecture': ['design', 'structure', 'organization'],
|
| 52 |
+
|
| 53 |
+
# Embedded systems terminology
|
| 54 |
+
'rtos': ['real-time operating system', 'real-time os'],
|
| 55 |
+
'interrupt': ['isr', 'interrupt service routine', 'exception handler'],
|
| 56 |
+
'peripheral': ['hardware peripheral', 'external device', 'io device'],
|
| 57 |
+
'firmware': ['embedded software', 'system software'],
|
| 58 |
+
'bootloader': ['boot code', 'initialization code'],
|
| 59 |
+
|
| 60 |
+
# Performance terminology
|
| 61 |
+
'latency': ['delay', 'response time', 'execution time'],
|
| 62 |
+
'throughput': ['bandwidth', 'data rate', 'performance'],
|
| 63 |
+
'power': ['power consumption', 'energy usage', 'battery life'],
|
| 64 |
+
'optimization': ['improvement', 'enhancement', 'tuning'],
|
| 65 |
+
|
| 66 |
+
# Communication protocols
|
| 67 |
+
'uart': ['serial communication', 'async serial'],
|
| 68 |
+
'spi': ['serial peripheral interface', 'synchronous serial'],
|
| 69 |
+
'i2c': ['inter-integrated circuit', 'two-wire interface'],
|
| 70 |
+
'usb': ['universal serial bus'],
|
| 71 |
+
|
| 72 |
+
# Development terminology
|
| 73 |
+
'debug': ['debugging', 'troubleshooting', 'testing'],
|
| 74 |
+
'compile': ['compilation', 'build', 'assembly'],
|
| 75 |
+
'programming': ['coding', 'development', 'implementation']
|
| 76 |
+
}
|
| 77 |
+
|
| 78 |
+
# Comprehensive acronym expansions for embedded/technical domains
|
| 79 |
+
self.acronym_expansions = {
|
| 80 |
+
# Processor & Architecture
|
| 81 |
+
'CPU': 'Central Processing Unit',
|
| 82 |
+
'MCU': 'Microcontroller Unit',
|
| 83 |
+
'MPU': 'Microprocessor Unit',
|
| 84 |
+
'DSP': 'Digital Signal Processor',
|
| 85 |
+
'GPU': 'Graphics Processing Unit',
|
| 86 |
+
'ALU': 'Arithmetic Logic Unit',
|
| 87 |
+
'FPU': 'Floating Point Unit',
|
| 88 |
+
'MMU': 'Memory Management Unit',
|
| 89 |
+
'ISA': 'Instruction Set Architecture',
|
| 90 |
+
'RISC': 'Reduced Instruction Set Computer',
|
| 91 |
+
'CISC': 'Complex Instruction Set Computer',
|
| 92 |
+
|
| 93 |
+
# Memory & Storage
|
| 94 |
+
'RAM': 'Random Access Memory',
|
| 95 |
+
'ROM': 'Read Only Memory',
|
| 96 |
+
'EEPROM': 'Electrically Erasable Programmable ROM',
|
| 97 |
+
'SRAM': 'Static Random Access Memory',
|
| 98 |
+
'DRAM': 'Dynamic Random Access Memory',
|
| 99 |
+
'FRAM': 'Ferroelectric Random Access Memory',
|
| 100 |
+
'MRAM': 'Magnetoresistive Random Access Memory',
|
| 101 |
+
'DMA': 'Direct Memory Access',
|
| 102 |
+
|
| 103 |
+
# Operating Systems & Software
|
| 104 |
+
'RTOS': 'Real-Time Operating System',
|
| 105 |
+
'OS': 'Operating System',
|
| 106 |
+
'API': 'Application Programming Interface',
|
| 107 |
+
'SDK': 'Software Development Kit',
|
| 108 |
+
'IDE': 'Integrated Development Environment',
|
| 109 |
+
'HAL': 'Hardware Abstraction Layer',
|
| 110 |
+
'BSP': 'Board Support Package',
|
| 111 |
+
|
| 112 |
+
# Interrupts & Exceptions
|
| 113 |
+
'ISR': 'Interrupt Service Routine',
|
| 114 |
+
'IRQ': 'Interrupt Request',
|
| 115 |
+
'NMI': 'Non-Maskable Interrupt',
|
| 116 |
+
'NVIC': 'Nested Vectored Interrupt Controller',
|
| 117 |
+
|
| 118 |
+
# Communication Protocols
|
| 119 |
+
'UART': 'Universal Asynchronous Receiver Transmitter',
|
| 120 |
+
'USART': 'Universal Synchronous Asynchronous Receiver Transmitter',
|
| 121 |
+
'SPI': 'Serial Peripheral Interface',
|
| 122 |
+
'I2C': 'Inter-Integrated Circuit',
|
| 123 |
+
'CAN': 'Controller Area Network',
|
| 124 |
+
'USB': 'Universal Serial Bus',
|
| 125 |
+
'TCP': 'Transmission Control Protocol',
|
| 126 |
+
'UDP': 'User Datagram Protocol',
|
| 127 |
+
'HTTP': 'HyperText Transfer Protocol',
|
| 128 |
+
'MQTT': 'Message Queuing Telemetry Transport',
|
| 129 |
+
|
| 130 |
+
# Analog & Digital
|
| 131 |
+
'ADC': 'Analog to Digital Converter',
|
| 132 |
+
'DAC': 'Digital to Analog Converter',
|
| 133 |
+
'PWM': 'Pulse Width Modulation',
|
| 134 |
+
'GPIO': 'General Purpose Input Output',
|
| 135 |
+
'JTAG': 'Joint Test Action Group',
|
| 136 |
+
'SWD': 'Serial Wire Debug',
|
| 137 |
+
|
| 138 |
+
# Power & Clock
|
| 139 |
+
'PLL': 'Phase Locked Loop',
|
| 140 |
+
'VCO': 'Voltage Controlled Oscillator',
|
| 141 |
+
'LDO': 'Low Dropout Regulator',
|
| 142 |
+
'PMU': 'Power Management Unit',
|
| 143 |
+
'RTC': 'Real Time Clock',
|
| 144 |
+
|
| 145 |
+
# Standards & Organizations
|
| 146 |
+
'IEEE': 'Institute of Electrical and Electronics Engineers',
|
| 147 |
+
'ISO': 'International Organization for Standardization',
|
| 148 |
+
'ANSI': 'American National Standards Institute',
|
| 149 |
+
'IEC': 'International Electrotechnical Commission'
|
| 150 |
+
}
|
| 151 |
+
|
| 152 |
+
# Compile regex patterns for efficiency
|
| 153 |
+
self._acronym_pattern = re.compile(r'\b[A-Z]{2,}\b')
|
| 154 |
+
self._technical_term_pattern = re.compile(r'\b\w+(?:-\w+)*\b', re.IGNORECASE)
|
| 155 |
+
self._question_indicators = re.compile(r'\b(?:how|what|why|when|where|which|explain|describe|define)\b', re.IGNORECASE)
|
| 156 |
+
|
| 157 |
+
# Question type classification keywords
|
| 158 |
+
self.question_type_keywords = {
|
| 159 |
+
'conceptual': ['how', 'why', 'what', 'explain', 'describe', 'understand', 'concept', 'theory'],
|
| 160 |
+
'technical': ['configure', 'implement', 'setup', 'install', 'code', 'program', 'register'],
|
| 161 |
+
'procedural': ['steps', 'process', 'procedure', 'workflow', 'guide', 'tutorial'],
|
| 162 |
+
'troubleshooting': ['error', 'problem', 'issue', 'debug', 'fix', 'solve', 'troubleshoot']
|
| 163 |
+
}
|
| 164 |
+
|
| 165 |
+
def analyze_query_characteristics(self, query: str) -> Dict[str, Any]:
|
| 166 |
+
"""
|
| 167 |
+
Analyze query to determine optimal processing strategy.
|
| 168 |
+
|
| 169 |
+
Performs comprehensive analysis including:
|
| 170 |
+
- Technical term detection and counting
|
| 171 |
+
- Acronym presence identification
|
| 172 |
+
- Question type classification
|
| 173 |
+
- Complexity scoring based on multiple factors
|
| 174 |
+
- Optimal hybrid weight recommendation
|
| 175 |
+
|
| 176 |
+
Args:
|
| 177 |
+
query: User input query string
|
| 178 |
+
|
| 179 |
+
Returns:
|
| 180 |
+
Dictionary with comprehensive query analysis:
|
| 181 |
+
- technical_term_count: Number of domain-specific terms detected
|
| 182 |
+
- has_acronyms: Boolean indicating acronym presence
|
| 183 |
+
- question_type: 'conceptual', 'technical', 'procedural', 'mixed'
|
| 184 |
+
- complexity_score: Float 0-1 indicating query complexity
|
| 185 |
+
- recommended_dense_weight: Optimal weight for hybrid search
|
| 186 |
+
- detected_acronyms: List of acronyms found
|
| 187 |
+
- technical_terms: List of technical terms found
|
| 188 |
+
|
| 189 |
+
Performance: <2ms for typical queries
|
| 190 |
+
"""
|
| 191 |
+
if not query or not query.strip():
|
| 192 |
+
return {
|
| 193 |
+
'technical_term_count': 0,
|
| 194 |
+
'has_acronyms': False,
|
| 195 |
+
'question_type': 'unknown',
|
| 196 |
+
'complexity_score': 0.0,
|
| 197 |
+
'recommended_dense_weight': 0.7,
|
| 198 |
+
'detected_acronyms': [],
|
| 199 |
+
'technical_terms': []
|
| 200 |
+
}
|
| 201 |
+
|
| 202 |
+
query_lower = query.lower()
|
| 203 |
+
words = query.split()
|
| 204 |
+
|
| 205 |
+
# Detect acronyms
|
| 206 |
+
detected_acronyms = self._acronym_pattern.findall(query)
|
| 207 |
+
has_acronyms = len(detected_acronyms) > 0
|
| 208 |
+
|
| 209 |
+
# Detect technical terms
|
| 210 |
+
technical_terms = []
|
| 211 |
+
technical_term_count = 0
|
| 212 |
+
|
| 213 |
+
for word in words:
|
| 214 |
+
word_clean = re.sub(r'[^\w\-]', '', word.lower())
|
| 215 |
+
if word_clean in self.technical_synonyms:
|
| 216 |
+
technical_terms.append(word_clean)
|
| 217 |
+
technical_term_count += 1
|
| 218 |
+
# Also check for compound technical terms like "risc-v"
|
| 219 |
+
elif any(term in word_clean for term in ['risc-v', 'arm', 'cpu', 'mcu']):
|
| 220 |
+
technical_terms.append(word_clean)
|
| 221 |
+
technical_term_count += 1
|
| 222 |
+
|
| 223 |
+
# Add acronyms to technical term count
|
| 224 |
+
for acronym in detected_acronyms:
|
| 225 |
+
if acronym in self.acronym_expansions:
|
| 226 |
+
technical_term_count += 1
|
| 227 |
+
|
| 228 |
+
# Determine question type
|
| 229 |
+
question_type = self._classify_question_type(query_lower)
|
| 230 |
+
|
| 231 |
+
# Calculate complexity score (0-1)
|
| 232 |
+
complexity_factors = [
|
| 233 |
+
len(words) / 20.0, # Word count factor (normalized to 20 words max)
|
| 234 |
+
technical_term_count / 5.0, # Technical density (normalized to 5 terms max)
|
| 235 |
+
len(detected_acronyms) / 3.0, # Acronym density (normalized to 3 acronyms max)
|
| 236 |
+
1.0 if self._question_indicators.search(query) else 0.5, # Question complexity
|
| 237 |
+
]
|
| 238 |
+
complexity_score = min(1.0, sum(complexity_factors) / len(complexity_factors))
|
| 239 |
+
|
| 240 |
+
# Determine recommended dense weight based on analysis
|
| 241 |
+
recommended_dense_weight = self._calculate_optimal_weight(
|
| 242 |
+
question_type, technical_term_count, has_acronyms, complexity_score
|
| 243 |
+
)
|
| 244 |
+
|
| 245 |
+
return {
|
| 246 |
+
'technical_term_count': technical_term_count,
|
| 247 |
+
'has_acronyms': has_acronyms,
|
| 248 |
+
'question_type': question_type,
|
| 249 |
+
'complexity_score': complexity_score,
|
| 250 |
+
'recommended_dense_weight': recommended_dense_weight,
|
| 251 |
+
'detected_acronyms': detected_acronyms,
|
| 252 |
+
'technical_terms': technical_terms,
|
| 253 |
+
'word_count': len(words),
|
| 254 |
+
'has_question_indicators': bool(self._question_indicators.search(query))
|
| 255 |
+
}
|
| 256 |
+
|
| 257 |
+
def _classify_question_type(self, query_lower: str) -> str:
|
| 258 |
+
"""Classify query into conceptual, technical, procedural, or mixed categories."""
|
| 259 |
+
type_scores = defaultdict(int)
|
| 260 |
+
|
| 261 |
+
for question_type, keywords in self.question_type_keywords.items():
|
| 262 |
+
for keyword in keywords:
|
| 263 |
+
if keyword in query_lower:
|
| 264 |
+
type_scores[question_type] += 1
|
| 265 |
+
|
| 266 |
+
if not type_scores:
|
| 267 |
+
return 'mixed'
|
| 268 |
+
|
| 269 |
+
# Return type with highest score, or 'mixed' if tie
|
| 270 |
+
max_score = max(type_scores.values())
|
| 271 |
+
top_types = [t for t, s in type_scores.items() if s == max_score]
|
| 272 |
+
|
| 273 |
+
return top_types[0] if len(top_types) == 1 else 'mixed'
|
| 274 |
+
|
| 275 |
+
def _calculate_optimal_weight(self, question_type: str, tech_terms: int,
|
| 276 |
+
has_acronyms: bool, complexity: float) -> float:
|
| 277 |
+
"""Calculate optimal dense weight based on query characteristics."""
|
| 278 |
+
|
| 279 |
+
# Base weights by question type
|
| 280 |
+
base_weights = {
|
| 281 |
+
'technical': 0.3, # Favor sparse for technical precision
|
| 282 |
+
'conceptual': 0.8, # Favor dense for conceptual understanding
|
| 283 |
+
'procedural': 0.5, # Balanced for step-by-step queries
|
| 284 |
+
'troubleshooting': 0.4, # Slight sparse favor for specific issues
|
| 285 |
+
'mixed': 0.7, # Default balanced
|
| 286 |
+
'unknown': 0.7 # Default balanced
|
| 287 |
+
}
|
| 288 |
+
|
| 289 |
+
weight = base_weights.get(question_type, 0.7)
|
| 290 |
+
|
| 291 |
+
# Adjust based on technical term density
|
| 292 |
+
if tech_terms > 2:
|
| 293 |
+
weight -= 0.2 # More technical → favor sparse
|
| 294 |
+
elif tech_terms == 0:
|
| 295 |
+
weight += 0.1 # Less technical → favor dense
|
| 296 |
+
|
| 297 |
+
# Adjust based on acronym presence
|
| 298 |
+
if has_acronyms:
|
| 299 |
+
weight -= 0.1 # Acronyms → favor sparse for exact matching
|
| 300 |
+
|
| 301 |
+
# Adjust based on complexity
|
| 302 |
+
if complexity > 0.8:
|
| 303 |
+
weight += 0.1 # High complexity → favor dense for understanding
|
| 304 |
+
elif complexity < 0.3:
|
| 305 |
+
weight -= 0.1 # Low complexity → favor sparse for precision
|
| 306 |
+
|
| 307 |
+
# Ensure weight stays within valid bounds
|
| 308 |
+
return max(0.1, min(0.9, weight))
|
| 309 |
+
|
| 310 |
+
def expand_technical_terms(self, query: str, max_expansions: int = 1) -> str:
|
| 311 |
+
"""
|
| 312 |
+
Expand query with technical synonyms while preventing bloat.
|
| 313 |
+
|
| 314 |
+
Conservative expansion strategy:
|
| 315 |
+
- Maximum 1 synonym per technical term by default
|
| 316 |
+
- Prioritizes most relevant/common synonyms
|
| 317 |
+
- Maintains semantic focus while improving recall
|
| 318 |
+
|
| 319 |
+
Args:
|
| 320 |
+
query: Original user query
|
| 321 |
+
max_expansions: Maximum synonyms per term (default 1 for focus)
|
| 322 |
+
|
| 323 |
+
Returns:
|
| 324 |
+
Conservatively enhanced query
|
| 325 |
+
|
| 326 |
+
Example:
|
| 327 |
+
Input: "CPU performance optimization"
|
| 328 |
+
Output: "CPU processor performance optimization"
|
| 329 |
+
|
| 330 |
+
Performance: <3ms for typical queries
|
| 331 |
+
"""
|
| 332 |
+
if not query or not query.strip():
|
| 333 |
+
return query
|
| 334 |
+
|
| 335 |
+
words = query.split()
|
| 336 |
+
|
| 337 |
+
# Conservative expansion: only add most relevant synonym
|
| 338 |
+
expansion_candidates = []
|
| 339 |
+
|
| 340 |
+
for word in words:
|
| 341 |
+
word_clean = re.sub(r'[^\w\-]', '', word.lower())
|
| 342 |
+
|
| 343 |
+
# Check for direct synonym expansion
|
| 344 |
+
if word_clean in self.technical_synonyms:
|
| 345 |
+
synonyms = self.technical_synonyms[word_clean]
|
| 346 |
+
# Add only the first (most common) synonym
|
| 347 |
+
if synonyms and max_expansions > 0:
|
| 348 |
+
expansion_candidates.append(synonyms[0])
|
| 349 |
+
|
| 350 |
+
# Limit total expansion to prevent bloat
|
| 351 |
+
max_total_expansions = min(2, len(words) // 2) # At most 50% expansion
|
| 352 |
+
selected_expansions = expansion_candidates[:max_total_expansions]
|
| 353 |
+
|
| 354 |
+
# Reconstruct with minimal expansion
|
| 355 |
+
if selected_expansions:
|
| 356 |
+
return ' '.join(words + selected_expansions)
|
| 357 |
+
else:
|
| 358 |
+
return query
|
| 359 |
+
|
| 360 |
+
def detect_and_expand_acronyms(self, query: str, conservative: bool = True) -> str:
|
| 361 |
+
"""
|
| 362 |
+
Detect technical acronyms and add their expansions conservatively.
|
| 363 |
+
|
| 364 |
+
Conservative approach to prevent query bloat:
|
| 365 |
+
- Limits acronym expansions to most relevant ones
|
| 366 |
+
- Preserves original acronyms for exact matching
|
| 367 |
+
- Maintains query focus and performance
|
| 368 |
+
|
| 369 |
+
Args:
|
| 370 |
+
query: Query potentially containing acronyms
|
| 371 |
+
conservative: If True, limits expansion to prevent bloat
|
| 372 |
+
|
| 373 |
+
Returns:
|
| 374 |
+
Query with selective acronym expansions
|
| 375 |
+
|
| 376 |
+
Example:
|
| 377 |
+
Input: "RTOS scheduling algorithm"
|
| 378 |
+
Output: "RTOS Real-Time Operating System scheduling algorithm"
|
| 379 |
+
|
| 380 |
+
Performance: <2ms for typical queries
|
| 381 |
+
"""
|
| 382 |
+
if not query or not query.strip():
|
| 383 |
+
return query
|
| 384 |
+
|
| 385 |
+
# Find all acronyms in the query
|
| 386 |
+
acronyms = self._acronym_pattern.findall(query)
|
| 387 |
+
|
| 388 |
+
if not acronyms:
|
| 389 |
+
return query
|
| 390 |
+
|
| 391 |
+
# Conservative mode: limit expansions
|
| 392 |
+
if conservative and len(acronyms) > 2:
|
| 393 |
+
# Only expand first 2 acronyms to prevent bloat
|
| 394 |
+
acronyms = acronyms[:2]
|
| 395 |
+
|
| 396 |
+
result = query
|
| 397 |
+
|
| 398 |
+
# Expand selected acronyms
|
| 399 |
+
for acronym in acronyms:
|
| 400 |
+
if acronym in self.acronym_expansions:
|
| 401 |
+
expansion = self.acronym_expansions[acronym]
|
| 402 |
+
# Add expansion after the acronym (preserving original)
|
| 403 |
+
result = result.replace(acronym, f"{acronym} {expansion}", 1)
|
| 404 |
+
|
| 405 |
+
return result
|
| 406 |
+
|
| 407 |
+
def adaptive_hybrid_weighting(self, query: str) -> float:
|
| 408 |
+
"""
|
| 409 |
+
Determine optimal dense_weight based on query characteristics.
|
| 410 |
+
|
| 411 |
+
Analyzes query to automatically determine the best balance between
|
| 412 |
+
dense semantic search and sparse keyword matching for optimal results.
|
| 413 |
+
|
| 414 |
+
Strategy:
|
| 415 |
+
- Technical/exact queries → lower dense_weight (favor sparse/BM25)
|
| 416 |
+
- Conceptual questions → higher dense_weight (favor semantic)
|
| 417 |
+
- Mixed queries → balanced weighting based on complexity
|
| 418 |
+
|
| 419 |
+
Args:
|
| 420 |
+
query: User query string
|
| 421 |
+
|
| 422 |
+
Returns:
|
| 423 |
+
Float between 0.1 and 0.9 representing optimal dense_weight
|
| 424 |
+
|
| 425 |
+
Performance: <2ms analysis time
|
| 426 |
+
"""
|
| 427 |
+
analysis = self.analyze_query_characteristics(query)
|
| 428 |
+
return analysis['recommended_dense_weight']
|
| 429 |
+
|
| 430 |
+
def enhance_query(self, query: str, conservative: bool = True) -> Dict[str, Any]:
|
| 431 |
+
"""
|
| 432 |
+
Comprehensive query enhancement with performance and quality focus.
|
| 433 |
+
|
| 434 |
+
Optimized enhancement strategy:
|
| 435 |
+
- Conservative expansion to maintain semantic focus
|
| 436 |
+
- Performance-first approach with minimal overhead
|
| 437 |
+
- Quality validation to ensure improvements
|
| 438 |
+
|
| 439 |
+
Args:
|
| 440 |
+
query: Original user query
|
| 441 |
+
conservative: Use conservative expansion (recommended for production)
|
| 442 |
+
|
| 443 |
+
Returns:
|
| 444 |
+
Dictionary containing:
|
| 445 |
+
- enhanced_query: Optimized enhanced query
|
| 446 |
+
- optimal_weight: Recommended dense weight
|
| 447 |
+
- analysis: Complete query analysis
|
| 448 |
+
- enhancement_metadata: Performance and quality metrics
|
| 449 |
+
|
| 450 |
+
Performance: <5ms total enhancement time
|
| 451 |
+
"""
|
| 452 |
+
start_time = time.perf_counter()
|
| 453 |
+
|
| 454 |
+
# Fast analysis
|
| 455 |
+
analysis = self.analyze_query_characteristics(query)
|
| 456 |
+
|
| 457 |
+
# Conservative enhancement approach
|
| 458 |
+
if conservative:
|
| 459 |
+
enhanced_query = self.expand_technical_terms(query, max_expansions=1)
|
| 460 |
+
enhanced_query = self.detect_and_expand_acronyms(enhanced_query, conservative=True)
|
| 461 |
+
else:
|
| 462 |
+
# Legacy aggressive expansion
|
| 463 |
+
enhanced_query = self.expand_technical_terms(query, max_expansions=2)
|
| 464 |
+
enhanced_query = self.detect_and_expand_acronyms(enhanced_query, conservative=False)
|
| 465 |
+
|
| 466 |
+
# Quality validation: prevent excessive bloat
|
| 467 |
+
expansion_ratio = len(enhanced_query.split()) / len(query.split()) if query.split() else 1.0
|
| 468 |
+
if expansion_ratio > 2.5: # Limit to 2.5x expansion
|
| 469 |
+
# Fallback to minimal enhancement
|
| 470 |
+
enhanced_query = self.expand_technical_terms(query, max_expansions=0)
|
| 471 |
+
enhanced_query = self.detect_and_expand_acronyms(enhanced_query, conservative=True)
|
| 472 |
+
expansion_ratio = len(enhanced_query.split()) / len(query.split()) if query.split() else 1.0
|
| 473 |
+
|
| 474 |
+
# Calculate optimal weight
|
| 475 |
+
optimal_weight = analysis['recommended_dense_weight']
|
| 476 |
+
|
| 477 |
+
enhancement_time = time.perf_counter() - start_time
|
| 478 |
+
|
| 479 |
+
return {
|
| 480 |
+
'enhanced_query': enhanced_query,
|
| 481 |
+
'optimal_weight': optimal_weight,
|
| 482 |
+
'analysis': analysis,
|
| 483 |
+
'enhancement_metadata': {
|
| 484 |
+
'original_length': len(query.split()),
|
| 485 |
+
'enhanced_length': len(enhanced_query.split()),
|
| 486 |
+
'expansion_ratio': expansion_ratio,
|
| 487 |
+
'processing_time_ms': enhancement_time * 1000,
|
| 488 |
+
'techniques_applied': ['conservative_expansion', 'quality_validation', 'adaptive_weighting'],
|
| 489 |
+
'conservative_mode': conservative
|
| 490 |
+
}
|
| 491 |
+
}
|
| 492 |
+
|
| 493 |
+
def expand_technical_terms_with_vocabulary(
|
| 494 |
+
self,
|
| 495 |
+
query: str,
|
| 496 |
+
vocabulary_index: Optional['VocabularyIndex'] = None,
|
| 497 |
+
min_frequency: int = 3
|
| 498 |
+
) -> str:
|
| 499 |
+
"""
|
| 500 |
+
Expand query with vocabulary-aware synonym filtering.
|
| 501 |
+
|
| 502 |
+
Only adds synonyms that exist in the document corpus with sufficient
|
| 503 |
+
frequency to ensure relevance and prevent query dilution.
|
| 504 |
+
|
| 505 |
+
Args:
|
| 506 |
+
query: Original query
|
| 507 |
+
vocabulary_index: Optional vocabulary index for filtering
|
| 508 |
+
min_frequency: Minimum term frequency required
|
| 509 |
+
|
| 510 |
+
Returns:
|
| 511 |
+
Enhanced query with validated synonyms
|
| 512 |
+
|
| 513 |
+
Performance: <2ms with vocabulary validation
|
| 514 |
+
"""
|
| 515 |
+
if not query or not query.strip():
|
| 516 |
+
return query
|
| 517 |
+
|
| 518 |
+
if vocabulary_index is None:
|
| 519 |
+
# Fallback to standard expansion
|
| 520 |
+
return self.expand_technical_terms(query, max_expansions=1)
|
| 521 |
+
|
| 522 |
+
words = query.split()
|
| 523 |
+
expanded_terms = []
|
| 524 |
+
|
| 525 |
+
for word in words:
|
| 526 |
+
word_clean = re.sub(r'[^\w\-]', '', word.lower())
|
| 527 |
+
|
| 528 |
+
# Check for synonym expansion
|
| 529 |
+
if word_clean in self.technical_synonyms:
|
| 530 |
+
synonyms = self.technical_synonyms[word_clean]
|
| 531 |
+
|
| 532 |
+
# Filter synonyms through vocabulary
|
| 533 |
+
valid_synonyms = vocabulary_index.filter_synonyms(
|
| 534 |
+
synonyms,
|
| 535 |
+
min_frequency=min_frequency
|
| 536 |
+
)
|
| 537 |
+
|
| 538 |
+
# Add only the best valid synonym
|
| 539 |
+
if valid_synonyms:
|
| 540 |
+
expanded_terms.append(valid_synonyms[0])
|
| 541 |
+
|
| 542 |
+
# Reconstruct query with validated expansions
|
| 543 |
+
if expanded_terms:
|
| 544 |
+
return ' '.join(words + expanded_terms)
|
| 545 |
+
else:
|
| 546 |
+
return query
|
| 547 |
+
|
| 548 |
+
def enhance_query_with_vocabulary(
|
| 549 |
+
self,
|
| 550 |
+
query: str,
|
| 551 |
+
vocabulary_index: Optional['VocabularyIndex'] = None,
|
| 552 |
+
min_frequency: int = 3,
|
| 553 |
+
require_technical: bool = False
|
| 554 |
+
) -> Dict[str, Any]:
|
| 555 |
+
"""
|
| 556 |
+
Enhanced query processing with vocabulary validation.
|
| 557 |
+
|
| 558 |
+
Uses corpus vocabulary to ensure all expansions are relevant
|
| 559 |
+
and actually present in the documents.
|
| 560 |
+
|
| 561 |
+
Args:
|
| 562 |
+
query: Original query
|
| 563 |
+
vocabulary_index: Vocabulary index for validation
|
| 564 |
+
min_frequency: Minimum term frequency
|
| 565 |
+
require_technical: Only expand with technical terms
|
| 566 |
+
|
| 567 |
+
Returns:
|
| 568 |
+
Enhanced query with vocabulary-aware expansion
|
| 569 |
+
"""
|
| 570 |
+
start_time = time.perf_counter()
|
| 571 |
+
|
| 572 |
+
# Perform analysis
|
| 573 |
+
analysis = self.analyze_query_characteristics(query)
|
| 574 |
+
|
| 575 |
+
# Vocabulary-aware enhancement
|
| 576 |
+
if vocabulary_index:
|
| 577 |
+
# Technical term expansion with validation
|
| 578 |
+
enhanced_query = self.expand_technical_terms_with_vocabulary(
|
| 579 |
+
query, vocabulary_index, min_frequency
|
| 580 |
+
)
|
| 581 |
+
|
| 582 |
+
# Acronym expansion (already conservative)
|
| 583 |
+
enhanced_query = self.detect_and_expand_acronyms(enhanced_query, conservative=True)
|
| 584 |
+
|
| 585 |
+
# Track vocabulary validation
|
| 586 |
+
validation_applied = True
|
| 587 |
+
|
| 588 |
+
# Detect domain if available
|
| 589 |
+
detected_domain = vocabulary_index.detect_domain()
|
| 590 |
+
else:
|
| 591 |
+
# Fallback to standard enhancement
|
| 592 |
+
enhanced_query = self.expand_technical_terms(query, max_expansions=1)
|
| 593 |
+
enhanced_query = self.detect_and_expand_acronyms(enhanced_query, conservative=True)
|
| 594 |
+
validation_applied = False
|
| 595 |
+
detected_domain = 'unknown'
|
| 596 |
+
|
| 597 |
+
# Calculate metrics
|
| 598 |
+
expansion_ratio = len(enhanced_query.split()) / len(query.split()) if query.split() else 1.0
|
| 599 |
+
enhancement_time = time.perf_counter() - start_time
|
| 600 |
+
|
| 601 |
+
return {
|
| 602 |
+
'enhanced_query': enhanced_query,
|
| 603 |
+
'optimal_weight': analysis['recommended_dense_weight'],
|
| 604 |
+
'analysis': analysis,
|
| 605 |
+
'enhancement_metadata': {
|
| 606 |
+
'original_length': len(query.split()),
|
| 607 |
+
'enhanced_length': len(enhanced_query.split()),
|
| 608 |
+
'expansion_ratio': expansion_ratio,
|
| 609 |
+
'processing_time_ms': enhancement_time * 1000,
|
| 610 |
+
'techniques_applied': ['vocabulary_validation', 'conservative_expansion'],
|
| 611 |
+
'vocabulary_validated': validation_applied,
|
| 612 |
+
'detected_domain': detected_domain,
|
| 613 |
+
'min_frequency_threshold': min_frequency
|
| 614 |
+
}
|
| 615 |
+
}
|
| 616 |
+
|
| 617 |
+
def get_enhancement_stats(self) -> Dict[str, Any]:
|
| 618 |
+
"""
|
| 619 |
+
Get statistics about the enhancement system capabilities.
|
| 620 |
+
|
| 621 |
+
Returns:
|
| 622 |
+
Dictionary with system statistics and capabilities
|
| 623 |
+
"""
|
| 624 |
+
return {
|
| 625 |
+
'technical_synonyms_count': len(self.technical_synonyms),
|
| 626 |
+
'acronym_expansions_count': len(self.acronym_expansions),
|
| 627 |
+
'supported_domains': [
|
| 628 |
+
'embedded_systems', 'processor_architecture', 'memory_systems',
|
| 629 |
+
'communication_protocols', 'real_time_systems', 'power_management'
|
| 630 |
+
],
|
| 631 |
+
'question_types_supported': list(self.question_type_keywords.keys()),
|
| 632 |
+
'weight_range': {'min': 0.1, 'max': 0.9, 'default': 0.7},
|
| 633 |
+
'performance_targets': {
|
| 634 |
+
'enhancement_time_ms': '<10',
|
| 635 |
+
'accuracy_improvement': '>10%',
|
| 636 |
+
'memory_overhead': '<1MB'
|
| 637 |
+
},
|
| 638 |
+
'vocabulary_features': {
|
| 639 |
+
'vocabulary_aware_expansion': True,
|
| 640 |
+
'min_frequency_filtering': True,
|
| 641 |
+
'domain_detection': True,
|
| 642 |
+
'technical_term_priority': True
|
| 643 |
+
}
|
| 644 |
+
}
|
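For orientation, a minimal usage sketch of the class above (assumed usage, not part of this commit; the module path simply mirrors the file location shown here, and the query string is illustrative):

# Sketch only: analyze and enhance a raw user query, then read back the adaptive dense weight.
from shared_utils.query_processing.query_enhancer import QueryEnhancer

enhancer = QueryEnhancer()
result = enhancer.enhance_query("How do I configure RTOS interrupt latency?")

print(result['enhanced_query'])             # original query plus a few conservative expansions
print(result['optimal_weight'])             # dense weight clamped to [0.1, 0.9]
print(result['analysis']['question_type'])  # e.g. 'conceptual', 'technical', 'procedural', 'mixed'

As a concrete instance of the weighting logic: a query classified as 'technical' starts at 0.3, loses 0.2 for containing more than two technical terms and another 0.1 for containing acronyms, and is then clamped to the 0.1 floor, strongly favoring sparse BM25 matching over dense semantic search.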
shared_utils/retrieval/__init__.py
ADDED
@@ -0,0 +1,8 @@
"""
Retrieval utilities for hybrid RAG systems.
Combines dense semantic search with sparse keyword matching.
"""

from .hybrid_search import HybridRetriever

__all__ = ['HybridRetriever']
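Since the package `__init__.py` re-exports HybridRetriever, downstream code can presumably import it at package level rather than from the module path (trivial sketch, assuming shared_utils is on the Python path):

from shared_utils.retrieval import HybridRetriever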
shared_utils/retrieval/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (380 Bytes).

shared_utils/retrieval/__pycache__/hybrid_search.cpython-312.pyc
ADDED
Binary file (10.5 kB).

shared_utils/retrieval/__pycache__/vocabulary_index.cpython-312.pyc
ADDED
Binary file (12.6 kB).
shared_utils/retrieval/hybrid_search.py
ADDED
@@ -0,0 +1,277 @@
"""
Hybrid retrieval combining dense semantic search with sparse BM25 keyword matching.
Uses Reciprocal Rank Fusion (RRF) to combine results from both approaches.
"""

from typing import List, Dict, Tuple, Optional, Any
import numpy as np
from pathlib import Path
import sys

# Add project root to Python path for imports
project_root = Path(__file__).parent.parent.parent / "project-1-technical-rag"
sys.path.append(str(project_root))

from src.sparse_retrieval import BM25SparseRetriever
from src.fusion import reciprocal_rank_fusion, adaptive_fusion
from shared_utils.embeddings.generator import generate_embeddings
import faiss


class HybridRetriever:
    """
    Hybrid retrieval system combining dense semantic search with sparse BM25.

    Optimized for technical documentation where both semantic similarity
    and exact keyword matching are important for retrieval quality.

    Performance: Sub-second search on 1000+ document corpus
    """

    def __init__(
        self,
        dense_weight: float = 0.7,
        embedding_model: str = "sentence-transformers/multi-qa-MiniLM-L6-cos-v1",
        use_mps: bool = True,
        bm25_k1: float = 1.2,
        bm25_b: float = 0.75,
        rrf_k: int = 10
    ):
        """
        Initialize hybrid retriever with dense and sparse components.

        Args:
            dense_weight: Weight for semantic similarity in fusion (0.7 default)
            embedding_model: Sentence transformer model name
            use_mps: Use Apple Silicon MPS acceleration for embeddings
            bm25_k1: BM25 term frequency saturation parameter
            bm25_b: BM25 document length normalization parameter
            rrf_k: Reciprocal Rank Fusion constant (1=strong rank preference, 2=moderate)

        Raises:
            ValueError: If parameters are invalid
        """
        if not 0 <= dense_weight <= 1:
            raise ValueError("dense_weight must be between 0 and 1")

        self.dense_weight = dense_weight
        self.embedding_model = embedding_model
        self.use_mps = use_mps
        self.rrf_k = rrf_k

        # Initialize sparse retriever
        self.sparse_retriever = BM25SparseRetriever(k1=bm25_k1, b=bm25_b)

        # Dense retrieval components (initialized on first index)
        self.dense_index: Optional[faiss.Index] = None
        self.chunks: List[Dict] = []
        self.embeddings: Optional[np.ndarray] = None

    def index_documents(self, chunks: List[Dict]) -> None:
        """
        Index documents for both dense and sparse retrieval.

        Args:
            chunks: List of chunk dictionaries with 'text' field

        Raises:
            ValueError: If chunks is empty or malformed

        Performance: ~100 chunks/second for complete indexing
        """
        if not chunks:
            raise ValueError("Cannot index empty chunk list")

        print(f"Indexing {len(chunks)} chunks for hybrid retrieval...")

        # Store chunks for retrieval
        self.chunks = chunks

        # Index for sparse retrieval
        print("Building BM25 sparse index...")
        self.sparse_retriever.index_documents(chunks)

        # Index for dense retrieval
        print("Building dense semantic index...")
        texts = [chunk['text'] for chunk in chunks]

        # Generate embeddings
        self.embeddings = generate_embeddings(
            texts,
            model_name=self.embedding_model,
            use_mps=self.use_mps
        )

        # Create FAISS index
        embedding_dim = self.embeddings.shape[1]
        self.dense_index = faiss.IndexFlatIP(embedding_dim)  # Inner product for cosine similarity

        # Normalize embeddings for cosine similarity
        faiss.normalize_L2(self.embeddings)
        self.dense_index.add(self.embeddings)

        print(f"Hybrid indexing complete: {len(chunks)} chunks ready for search")

    def search(
        self,
        query: str,
        top_k: int = 10,
        dense_top_k: Optional[int] = None,
        sparse_top_k: Optional[int] = None
    ) -> List[Tuple[int, float, Dict]]:
        """
        Hybrid search combining dense and sparse retrieval with RRF.

        Args:
            query: Search query string
            top_k: Final number of results to return
            dense_top_k: Results from dense search (default: 2*top_k)
            sparse_top_k: Results from sparse search (default: 2*top_k)

        Returns:
            List of (chunk_index, rrf_score, chunk_dict) tuples

        Raises:
            ValueError: If not indexed or invalid parameters

        Performance: <200ms for 1000+ document corpus
        """
        if self.dense_index is None:
            raise ValueError("Must call index_documents() before searching")

        if not query.strip():
            return []

        if top_k <= 0:
            raise ValueError("top_k must be positive")

        # Set default intermediate result counts
        if dense_top_k is None:
            dense_top_k = min(2 * top_k, len(self.chunks))
        if sparse_top_k is None:
            sparse_top_k = min(2 * top_k, len(self.chunks))

        # Dense semantic search
        dense_results = self._dense_search(query, dense_top_k)

        # Sparse BM25 search
        sparse_results = self.sparse_retriever.search(query, sparse_top_k)

        # Combine using Adaptive Fusion (better for small result sets)
        fused_results = adaptive_fusion(
            dense_results=dense_results,
            sparse_results=sparse_results,
            dense_weight=self.dense_weight,
            result_size=top_k
        )

        # Prepare final results with chunk content and apply source diversity
        final_results = []
        for chunk_idx, rrf_score in fused_results:
            chunk_dict = self.chunks[chunk_idx]
            final_results.append((chunk_idx, rrf_score, chunk_dict))

        # Apply source diversity enhancement
        diverse_results = self._enhance_source_diversity(final_results, top_k)

        return diverse_results

    def _dense_search(self, query: str, top_k: int) -> List[Tuple[int, float]]:
        """
        Perform dense semantic search using FAISS.

        Args:
            query: Search query
            top_k: Number of results to return

        Returns:
            List of (chunk_index, similarity_score) tuples
        """
        # Generate query embedding
        query_embedding = generate_embeddings(
            [query],
            model_name=self.embedding_model,
            use_mps=self.use_mps
        )

        # Normalize for cosine similarity
        faiss.normalize_L2(query_embedding)

        # Search dense index
        similarities, indices = self.dense_index.search(query_embedding, top_k)

        # Convert to required format
        results = [
            (int(indices[0][i]), float(similarities[0][i]))
            for i in range(len(indices[0]))
            if indices[0][i] != -1  # Filter out invalid results
        ]

        return results

    def _enhance_source_diversity(
        self,
        results: List[Tuple[int, float, Dict]],
        top_k: int,
        max_per_source: int = 2
    ) -> List[Tuple[int, float, Dict]]:
        """
        Enhance source diversity in retrieval results to prevent over-focusing on single documents.

        Args:
            results: List of (chunk_idx, score, chunk_dict) tuples sorted by relevance
            top_k: Maximum number of results to return
            max_per_source: Maximum chunks allowed per source document

        Returns:
            Diversified results maintaining relevance while improving source coverage
        """
        if not results:
            return []

        source_counts = {}
        diverse_results = []

        # First pass: Add highest scoring results respecting source limits
        for chunk_idx, score, chunk_dict in results:
            source = chunk_dict.get('source', 'unknown')
            current_count = source_counts.get(source, 0)

            if current_count < max_per_source:
                diverse_results.append((chunk_idx, score, chunk_dict))
                source_counts[source] = current_count + 1

                if len(diverse_results) >= top_k:
                    break

        # Second pass: If we still need more results, relax source constraints
        if len(diverse_results) < top_k:
            for chunk_idx, score, chunk_dict in results:
                if (chunk_idx, score, chunk_dict) not in diverse_results:
                    diverse_results.append((chunk_idx, score, chunk_dict))

                    if len(diverse_results) >= top_k:
                        break

        return diverse_results[:top_k]

    def get_retrieval_stats(self) -> Dict[str, Any]:
        """
        Get statistics about the indexed corpus and retrieval performance.

        Returns:
            Dictionary with corpus statistics
        """
        if not self.chunks:
            return {"status": "not_indexed"}

        return {
            "status": "indexed",
            "total_chunks": len(self.chunks),
            "dense_index_size": self.dense_index.ntotal if self.dense_index else 0,
            "embedding_dim": self.embeddings.shape[1] if self.embeddings is not None else 0,
            "sparse_indexed_chunks": len(self.sparse_retriever.chunk_mapping),
            "dense_weight": self.dense_weight,
            "sparse_weight": 1.0 - self.dense_weight,
            "rrf_k": self.rrf_k
        }
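A hedged end-to-end sketch of the retriever above. It assumes faiss, the project-local src.sparse_retrieval and src.fusion modules, and the shared embedding generator all resolve at import time; the two chunk dictionaries are illustrative only:

from shared_utils.retrieval.hybrid_search import HybridRetriever

# Each chunk needs at least a 'text' field; 'source' feeds the diversity filter.
chunks = [
    {'text': 'The NVIC prioritises nested interrupt requests...', 'source': 'ref_manual.pdf'},
    {'text': 'BM25 ranks documents by term frequency and length...', 'source': 'ir_notes.pdf'},
]

retriever = HybridRetriever(dense_weight=0.7)
retriever.index_documents(chunks)

for chunk_idx, score, chunk in retriever.search("NVIC interrupt priority", top_k=2):
    print(chunk_idx, round(score, 3), chunk['source'])

In practice the dense_weight passed here could come straight from QueryEnhancer.adaptive_hybrid_weighting(query), which is exactly the value the adaptive weighting in query_enhancer.py recommends.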
shared_utils/retrieval/vocabulary_index.py
ADDED
@@ -0,0 +1,260 @@
"""
Vocabulary index for corpus-aware query enhancement.

Tracks all unique terms in the document corpus to enable intelligent
synonym expansion that only adds terms actually present in documents.
"""

from typing import Set, Dict, List, Optional, Any
from collections import defaultdict
import re
from pathlib import Path
import json


class VocabularyIndex:
    """
    Maintains vocabulary statistics for intelligent query enhancement.

    Features:
    - Tracks all unique terms in document corpus
    - Stores term frequencies for relevance weighting
    - Identifies technical terms and domain vocabulary
    - Enables vocabulary-aware synonym expansion

    Performance:
    - Build time: ~1s per 1000 chunks
    - Memory: ~3MB for 80K unique terms
    - Lookup: O(1) set operations
    """

    def __init__(self):
        """Initialize empty vocabulary index."""
        self.vocabulary: Set[str] = set()
        self.term_frequencies: Dict[str, int] = defaultdict(int)
        self.technical_terms: Set[str] = set()
        self.document_frequencies: Dict[str, int] = defaultdict(int)
        self.total_documents = 0
        self.total_terms = 0

        # Regex for term extraction
        self._term_pattern = re.compile(r'\b[a-zA-Z][a-zA-Z0-9\-_]*\b')
        self._technical_pattern = re.compile(r'\b[A-Z]{2,}|[a-zA-Z]+[\-_][a-zA-Z]+|\b\d+[a-zA-Z]+\b')

    def build_from_chunks(self, chunks: List[Dict]) -> None:
        """
        Build vocabulary index from document chunks.

        Args:
            chunks: List of document chunks with 'text' field

        Performance: ~1s per 1000 chunks
        """
        self.total_documents = len(chunks)

        for chunk in chunks:
            text = chunk.get('text', '')

            # Extract and process terms
            terms = self._extract_terms(text)
            unique_terms = set(terms)

            # Update vocabulary
            self.vocabulary.update(unique_terms)

            # Update frequencies
            for term in terms:
                self.term_frequencies[term] += 1
                self.total_terms += 1

            # Update document frequencies
            for term in unique_terms:
                self.document_frequencies[term] += 1

            # Identify technical terms
            technical = self._extract_technical_terms(text)
            self.technical_terms.update(technical)

    def _extract_terms(self, text: str) -> List[str]:
        """Extract normalized terms from text."""
        # Convert to lowercase and extract words
        text_lower = text.lower()
        terms = self._term_pattern.findall(text_lower)

        # Filter short terms
        return [term for term in terms if len(term) > 2]

    def _extract_technical_terms(self, text: str) -> Set[str]:
        """Extract technical terms (acronyms, hyphenated, etc)."""
        technical = set()

        # Find potential technical terms
        matches = self._technical_pattern.findall(text)

        for match in matches:
            # Normalize but preserve technical nature
            normalized = match.lower()
            if len(normalized) > 2:
                technical.add(normalized)

        return technical

    def contains(self, term: str) -> bool:
        """Check if term exists in vocabulary."""
        return term.lower() in self.vocabulary

    def get_frequency(self, term: str) -> int:
        """Get term frequency in corpus."""
        return self.term_frequencies.get(term.lower(), 0)

    def get_document_frequency(self, term: str) -> int:
        """Get number of documents containing term."""
        return self.document_frequencies.get(term.lower(), 0)

    def is_common_term(self, term: str, min_frequency: int = 5) -> bool:
        """Check if term appears frequently enough."""
        return self.get_frequency(term) >= min_frequency

    def is_technical_term(self, term: str) -> bool:
        """Check if term is identified as technical."""
        return term.lower() in self.technical_terms

    def filter_synonyms(self, synonyms: List[str],
                        min_frequency: int = 3,
                        require_technical: bool = False) -> List[str]:
        """
        Filter synonym list to only include terms in vocabulary.

        Args:
            synonyms: List of potential synonyms
            min_frequency: Minimum term frequency required
            require_technical: Only include technical terms

        Returns:
            Filtered list of valid synonyms
        """
        valid_synonyms = []

        for synonym in synonyms:
            # Check existence
            if not self.contains(synonym):
                continue

            # Check frequency threshold
            if self.get_frequency(synonym) < min_frequency:
                continue

            # Check technical requirement
            if require_technical and not self.is_technical_term(synonym):
                continue

            valid_synonyms.append(synonym)

        return valid_synonyms

    def get_vocabulary_stats(self) -> Dict[str, Any]:
        """Get comprehensive vocabulary statistics."""
        return {
            'unique_terms': len(self.vocabulary),
            'total_terms': self.total_terms,
            'technical_terms': len(self.technical_terms),
            'total_documents': self.total_documents,
            'avg_terms_per_doc': self.total_terms / self.total_documents if self.total_documents > 0 else 0,
            'vocabulary_richness': len(self.vocabulary) / self.total_terms if self.total_terms > 0 else 0,
            'technical_ratio': len(self.technical_terms) / len(self.vocabulary) if self.vocabulary else 0
        }

    def get_top_terms(self, n: int = 100, technical_only: bool = False) -> List[tuple]:
        """
        Get most frequent terms in corpus.

        Args:
            n: Number of top terms to return
            technical_only: Only return technical terms

        Returns:
            List of (term, frequency) tuples
        """
        if technical_only:
            term_freq = {
                term: freq for term, freq in self.term_frequencies.items()
                if term in self.technical_terms
            }
        else:
            term_freq = self.term_frequencies

        return sorted(term_freq.items(), key=lambda x: x[1], reverse=True)[:n]

    def detect_domain(self) -> str:
        """
        Detect document domain from vocabulary patterns.

        Returns:
            Detected domain name
        """
        # Domain detection heuristics
        domain_indicators = {
            'embedded_systems': ['microcontroller', 'rtos', 'embedded', 'firmware', 'mcu'],
            'processor_architecture': ['risc-v', 'riscv', 'instruction', 'register', 'isa'],
            'regulatory': ['fda', 'validation', 'compliance', 'regulation', 'guidance'],
            'ai_ml': ['model', 'training', 'neural', 'algorithm', 'machine learning'],
            'software_engineering': ['software', 'development', 'testing', 'debugging', 'code']
        }

        domain_scores = {}

        for domain, indicators in domain_indicators.items():
            score = sum(
                self.get_document_frequency(indicator)
                for indicator in indicators
                if self.contains(indicator)
            )
            domain_scores[domain] = score

        # Return domain with highest score
        if domain_scores:
            return max(domain_scores, key=domain_scores.get)
        return 'general'

    def save_to_file(self, path: Path) -> None:
        """Save vocabulary index to JSON file."""
        data = {
            'vocabulary': list(self.vocabulary),
            'term_frequencies': dict(self.term_frequencies),
            'technical_terms': list(self.technical_terms),
            'document_frequencies': dict(self.document_frequencies),
            'total_documents': self.total_documents,
            'total_terms': self.total_terms
        }

        with open(path, 'w') as f:
            json.dump(data, f, indent=2)

    def load_from_file(self, path: Path) -> None:
        """Load vocabulary index from JSON file."""
        with open(path, 'r') as f:
            data = json.load(f)

        self.vocabulary = set(data['vocabulary'])
        self.term_frequencies = defaultdict(int, data['term_frequencies'])
        self.technical_terms = set(data['technical_terms'])
        self.document_frequencies = defaultdict(int, data['document_frequencies'])
        self.total_documents = data['total_documents']
        self.total_terms = data['total_terms']

    def merge_with(self, other: 'VocabularyIndex') -> None:
        """Merge another vocabulary index into this one."""
        # Merge vocabularies
        self.vocabulary.update(other.vocabulary)
        self.technical_terms.update(other.technical_terms)

        # Merge frequencies
        for term, freq in other.term_frequencies.items():
            self.term_frequencies[term] += freq

        for term, doc_freq in other.document_frequencies.items():
            self.document_frequencies[term] += doc_freq

        # Update totals
        self.total_documents += other.total_documents
        self.total_terms += other.total_terms
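A sketch of how the vocabulary index is meant to plug into the enhancer (assumed wiring inferred from the two classes above; the single chunk and min_frequency=1 are illustrative only):

from shared_utils.query_processing.query_enhancer import QueryEnhancer
from shared_utils.retrieval.vocabulary_index import VocabularyIndex

chunks = [{'text': 'RISC-V processor cores expose a machine-mode ISA with CSR registers.'}]

# Build corpus vocabulary: term/document frequencies plus detected technical terms.
vocab = VocabularyIndex()
vocab.build_from_chunks(chunks)

enhancer = QueryEnhancer()
result = enhancer.enhance_query_with_vocabulary(
    "CPU ISA extensions",
    vocabulary_index=vocab,
    min_frequency=1
)
print(result['enhanced_query'])
print(result['enhancement_metadata']['detected_domain'])

Only synonyms that actually occur in the indexed chunks survive filter_synonyms, which is what keeps vocabulary-validated expansion from diluting the query with terms the corpus never uses.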
shared_utils/vector_stores/__init__.py
ADDED
Empty file.

shared_utils/vector_stores/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (158 Bytes).

shared_utils/vector_stores/document_processing/__init__.py
ADDED
Empty file.