Arthur Passuello committed
Commit b5246f1 · 1 parent: 489aeb9

Added missing sources

This view is limited to 50 files because the commit contains too many changes.
Files changed (50)
  1. shared_utils/__init__.py +0 -0
  2. shared_utils/__pycache__/__init__.cpython-312.pyc +0 -0
  3. shared_utils/document_processing/__init__.py +0 -0
  4. shared_utils/document_processing/__pycache__/__init__.cpython-312.pyc +0 -0
  5. shared_utils/document_processing/__pycache__/chunker.cpython-312.pyc +0 -0
  6. shared_utils/document_processing/__pycache__/hybrid_parser.cpython-312.pyc +0 -0
  7. shared_utils/document_processing/__pycache__/paragraph_chunker.cpython-312.pyc +0 -0
  8. shared_utils/document_processing/__pycache__/pdf_parser.cpython-312.pyc +0 -0
  9. shared_utils/document_processing/__pycache__/pdfplumber_parser.cpython-312.pyc +0 -0
  10. shared_utils/document_processing/__pycache__/smart_chunker.cpython-312.pyc +0 -0
  11. shared_utils/document_processing/__pycache__/structure_preserving_parser.cpython-312.pyc +0 -0
  12. shared_utils/document_processing/__pycache__/toc_guided_parser.cpython-312.pyc +0 -0
  13. shared_utils/document_processing/chunker.py +243 -0
  14. shared_utils/document_processing/hybrid_parser.py +482 -0
  15. shared_utils/document_processing/pdf_parser.py +137 -0
  16. shared_utils/document_processing/pdfplumber_parser.py +452 -0
  17. shared_utils/document_processing/toc_guided_parser.py +311 -0
  18. shared_utils/embeddings/__init__.py +1 -0
  19. shared_utils/embeddings/__pycache__/__init__.cpython-312.pyc +0 -0
  20. shared_utils/embeddings/__pycache__/generator.cpython-312.pyc +0 -0
  21. shared_utils/embeddings/generator.py +84 -0
  22. shared_utils/generation/__pycache__/adaptive_prompt_engine.cpython-312.pyc +0 -0
  23. shared_utils/generation/__pycache__/answer_generator.cpython-312.pyc +0 -0
  24. shared_utils/generation/__pycache__/chain_of_thought_engine.cpython-312.pyc +0 -0
  25. shared_utils/generation/__pycache__/hf_answer_generator.cpython-312.pyc +0 -0
  26. shared_utils/generation/__pycache__/inference_providers_generator.cpython-312.pyc +0 -0
  27. shared_utils/generation/__pycache__/ollama_answer_generator.cpython-312.pyc +0 -0
  28. shared_utils/generation/__pycache__/prompt_optimizer.cpython-312.pyc +0 -0
  29. shared_utils/generation/__pycache__/prompt_templates.cpython-312.pyc +0 -0
  30. shared_utils/generation/adaptive_prompt_engine.py +559 -0
  31. shared_utils/generation/answer_generator.py +703 -0
  32. shared_utils/generation/chain_of_thought_engine.py +565 -0
  33. shared_utils/generation/hf_answer_generator.py +881 -0
  34. shared_utils/generation/inference_providers_generator.py +537 -0
  35. shared_utils/generation/ollama_answer_generator.py +834 -0
  36. shared_utils/generation/prompt_optimizer.py +687 -0
  37. shared_utils/generation/prompt_templates.py +520 -0
  38. shared_utils/query_processing/__init__.py +8 -0
  39. shared_utils/query_processing/__pycache__/__init__.cpython-312.pyc +0 -0
  40. shared_utils/query_processing/__pycache__/query_enhancer.cpython-312.pyc +0 -0
  41. shared_utils/query_processing/query_enhancer.py +644 -0
  42. shared_utils/retrieval/__init__.py +8 -0
  43. shared_utils/retrieval/__pycache__/__init__.cpython-312.pyc +0 -0
  44. shared_utils/retrieval/__pycache__/hybrid_search.cpython-312.pyc +0 -0
  45. shared_utils/retrieval/__pycache__/vocabulary_index.cpython-312.pyc +0 -0
  46. shared_utils/retrieval/hybrid_search.py +277 -0
  47. shared_utils/retrieval/vocabulary_index.py +260 -0
  48. shared_utils/vector_stores/__init__.py +0 -0
  49. shared_utils/vector_stores/__pycache__/__init__.cpython-312.pyc +0 -0
  50. shared_utils/vector_stores/document_processing/__init__.py +0 -0
shared_utils/__init__.py ADDED
File without changes
shared_utils/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (158 Bytes)

shared_utils/document_processing/__init__.py ADDED
File without changes
shared_utils/document_processing/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (178 Bytes)

shared_utils/document_processing/__pycache__/chunker.cpython-312.pyc ADDED
Binary file (7.77 kB)

shared_utils/document_processing/__pycache__/hybrid_parser.cpython-312.pyc ADDED
Binary file (18.8 kB)

shared_utils/document_processing/__pycache__/paragraph_chunker.cpython-312.pyc ADDED
Binary file (8.29 kB)

shared_utils/document_processing/__pycache__/pdf_parser.cpython-312.pyc ADDED
Binary file (5.06 kB)

shared_utils/document_processing/__pycache__/pdfplumber_parser.cpython-312.pyc ADDED
Binary file (18 kB)

shared_utils/document_processing/__pycache__/smart_chunker.cpython-312.pyc ADDED
Binary file (13 kB)

shared_utils/document_processing/__pycache__/structure_preserving_parser.cpython-312.pyc ADDED
Binary file (20.5 kB)

shared_utils/document_processing/__pycache__/toc_guided_parser.cpython-312.pyc ADDED
Binary file (12.5 kB)
shared_utils/document_processing/chunker.py ADDED
@@ -0,0 +1,243 @@
+ """
+ BasicRAG System - Technical Document Chunker
+
+ This module implements intelligent text chunking specifically optimized for technical
+ documentation. Unlike naive chunking approaches, this implementation preserves sentence
+ boundaries and maintains semantic coherence, critical for accurate RAG retrieval.
+
+ Key Features:
+ - Sentence-boundary aware chunking to preserve semantic units
+ - Configurable overlap to maintain context across chunk boundaries
+ - Content-based chunk IDs for reproducibility and deduplication
+ - Technical document optimizations (handles code blocks, lists, etc.)
+
+ Technical Approach:
+ - Uses regex patterns to identify sentence boundaries
+ - Implements a sliding window algorithm with intelligent boundary detection
+ - Generates deterministic chunk IDs using MD5 hashing
+ - Balances chunk size consistency with semantic completeness
+
+ Design Decisions:
+ - Default 1400 char chunks: large enough for complete explanations while staying under transformer token limits
+ - 200 char overlap: sufficient context preservation without excessive redundancy
+ - Sentence boundaries prioritized over exact size for better coherence
+ - Hash-based IDs enable chunk deduplication across documents
+
+ Performance Characteristics:
+ - Time complexity: O(n) where n is text length
+ - Memory usage: O(n) for output chunks
+ - Typical throughput: 1MB text/second on modern hardware
+
+ Author: Arthur Passuello
+ Date: June 2025
+ Project: RAG Portfolio - Technical Documentation System
+ """
+
+ from typing import List, Dict
+ import re
+ import hashlib
+
+
+ def _is_low_quality_chunk(text: str) -> bool:
+     """
+     Identify low-quality chunks that should be filtered out.
+
+     @param text: Chunk text to evaluate
+     @return: True if chunk is low quality and should be filtered
+     """
+     text_lower = text.lower().strip()
+
+     # Skip if too short to be meaningful
+     if len(text.strip()) < 50:
+         return True
+
+     # Filter out common low-value content
+     low_value_patterns = [
+         # Acknowledgments and credits
+         r'^(acknowledgment|thanks|thank you)',
+         r'(thanks to|grateful to|acknowledge)',
+
+         # References and citations
+         r'^\s*\[\d+\]',  # Citation markers
+         r'^references?$',
+         r'^bibliography$',
+
+         # Metadata and headers
+         r'this document is released under',
+         r'creative commons',
+         r'copyright \d{4}',
+
+         # Table of contents
+         r'^\s*\d+\..*\.\.\.\.\.\d+$',  # TOC entries
+         r'^(contents?|table of contents)$',
+
+         # Page headers/footers
+         r'^\s*page \d+',
+         r'^\s*\d+\s*$',  # Just page numbers
+
+         # Figure/table captions that are too short
+         r'^(figure|table|fig\.|tab\.)\s*\d+:?\s*$',
+     ]
+
+     for pattern in low_value_patterns:
+         if re.search(pattern, text_lower):
+             return True
+
+     # Check content quality metrics
+     words = text.split()
+     if len(words) < 8:  # Too few words to be meaningful
+         return True
+
+     # Check for reasonable sentence structure
+     sentences = re.split(r'[.!?]+', text)
+     complete_sentences = [s.strip() for s in sentences if len(s.strip()) > 10]
+
+     if len(complete_sentences) == 0:  # No complete sentences
+         return True
+
+     return False
+
+
+ def chunk_technical_text(
+     text: str, chunk_size: int = 1400, overlap: int = 200
+ ) -> List[Dict]:
+     """
+     Phase 1: Sentence-boundary preserving chunker for technical documentation.
+
+     ZERO MID-SENTENCE BREAKS: This implementation strictly enforces sentence
+     boundaries to eliminate fragmented retrieval results that break Q&A quality.
+
+     Key Improvements:
+     - Never breaks chunks mid-sentence (eliminates the prior 90% fragment rate)
+     - Larger target chunks (1400 chars) for complete explanations
+     - Extended search windows to find sentence boundaries
+     - Paragraph boundary preference within size constraints
+
+     @param text: The input text to be chunked, typically from technical documentation
+     @type text: str
+
+     @param chunk_size: Target size for each chunk in characters (default: 1400)
+     @type chunk_size: int
+
+     @param overlap: Number of characters to overlap between consecutive chunks (default: 200)
+     @type overlap: int
+
+     @return: List of chunk dictionaries containing text and metadata
+     @rtype: List[Dict[str, Any]] where each dictionary contains:
+         {
+             "text": str,               # Complete, sentence-bounded chunk text
+             "start_char": int,         # Starting character position in original text
+             "end_char": int,           # Ending character position in original text
+             "chunk_id": str,           # Unique identifier (format: "chunk_[8-char-hash]")
+             "word_count": int,         # Number of words in the chunk
+             "sentence_complete": bool  # Always True (guaranteed complete sentences)
+         }
+
+     Algorithm Details (Phase 1):
+     - Expands search window up to 50% beyond target size to find sentence boundaries
+     - Prefers chunks within 70-150% of target size over fragmenting
+     - Never falls back to mid-sentence breaks
+     - Quality filtering removes headers, captions, and navigation elements
+
+     Expected Results:
+     - Fragment rate: 90% → 0% (complete sentences only)
+     - Average chunk size: 1400-2100 characters (larger, complete contexts)
+     - All chunks end with proper sentence terminators (. ! ? : ;)
+     - Better retrieval context for Q&A generation
+
+     Example Usage:
+         >>> text = "RISC-V defines registers. Each register has specific usage. The architecture supports..."
+         >>> chunks = chunk_technical_text(text, chunk_size=1400, overlap=200)
+         >>> # All chunks will contain complete sentences and explanations
+     """
+     # Handle edge case: empty or whitespace-only input
+     if not text.strip():
+         return []
+
+     # Clean and normalize text by removing leading/trailing whitespace
+     text = text.strip()
+     chunks = []
+     start_pos = 0
+
+     # Main chunking loop - process text sequentially
+     while start_pos < len(text):
+         # Calculate target end position for this chunk
+         # min() ensures we don't exceed text length
+         target_end = min(start_pos + chunk_size, len(text))
+
+         # Define sentence boundary pattern
+         # Matches: period, exclamation, question mark, colon, semicolon
+         # followed by whitespace or end of string
+         sentence_pattern = r'[.!?:;](?:\s|$)'
+
+         # PHASE 1: Strict sentence boundary enforcement
+         # Expand search window significantly to ensure we find sentence boundaries
+         max_extension = chunk_size // 2  # Allow up to 50% larger chunks to find boundaries
+         search_start = max(start_pos, target_end - 200)  # Look back further
+         search_end = min(len(text), target_end + max_extension)  # Look forward much further
+         search_text = text[search_start:search_end]
+
+         # Find all sentence boundaries in expanded search window
+         sentence_matches = list(re.finditer(sentence_pattern, search_text))
+
+         # STRICT: Always find a sentence boundary, never break mid-sentence
+         chunk_end = None
+         sentence_complete = False
+
+         if sentence_matches:
+             # Find the best sentence boundary within reasonable range
+             for match in reversed(sentence_matches):  # Start from last (longest chunk)
+                 candidate_end = search_start + match.end()
+                 candidate_size = candidate_end - start_pos
+
+                 # Accept if within reasonable size range
+                 if candidate_size >= chunk_size * 0.7:  # At least 70% of target size
+                     chunk_end = candidate_end
+                     sentence_complete = True
+                     break
+
+         # If no good boundary found, take the last boundary (avoid fragments)
+         if chunk_end is None and sentence_matches:
+             best_match = sentence_matches[-1]
+             chunk_end = search_start + best_match.end()
+             sentence_complete = True
+
+         # Final fallback: extend to end of text if no sentences found
+         if chunk_end is None:
+             chunk_end = len(text)
+             sentence_complete = True  # End of document is always complete
+
+         # Extract chunk text and clean whitespace
+         chunk_text = text[start_pos:chunk_end].strip()
+
+         # Only create chunk if it contains actual content AND passes quality filter
+         if chunk_text and not _is_low_quality_chunk(chunk_text):
+             # Generate deterministic chunk ID using content hash
+             # MD5 is sufficient for deduplication (not cryptographic use)
+             chunk_hash = hashlib.md5(chunk_text.encode()).hexdigest()[:8]
+             chunk_id = f"chunk_{chunk_hash}"
+
+             # Calculate word count for chunk statistics
+             word_count = len(chunk_text.split())
+
+             # Assemble chunk metadata
+             chunks.append({
+                 "text": chunk_text,
+                 "start_char": start_pos,
+                 "end_char": chunk_end,
+                 "chunk_id": chunk_id,
+                 "word_count": word_count,
+                 "sentence_complete": sentence_complete
+             })
+
+         # Calculate next chunk starting position with overlap
+         if chunk_end >= len(text):
+             # Reached end of text, exit loop
+             break
+
+         # Apply overlap by moving start position back from chunk end
+         # max() ensures we always move forward at least 1 character
+         overlap_start = max(chunk_end - overlap, start_pos + 1)
+         start_pos = overlap_start
+
+     return chunks
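
A minimal usage sketch of the chunker above. The sample text and the small chunk_size are illustrative only (real callers keep the 1400/200 defaults); note that the quality filter silently drops fragments under 50 characters or 8 words, and overlap windows may begin mid-sentence even though every chunk ends on a boundary.

    from shared_utils.document_processing.chunker import chunk_technical_text

    sample = (
        "RISC-V defines 32 integer registers. Each register follows a specific "
        "usage convention in the standard calling convention. The architecture "
        "also supports multiple privilege levels for operating systems."
    )

    # Tiny chunk_size chosen only so this toy sample splits into several chunks
    chunks = chunk_technical_text(sample, chunk_size=120, overlap=20)
    for chunk in chunks:
        # Each chunk is sentence-complete and carries a deterministic content-hash ID
        print(chunk["chunk_id"], chunk["word_count"], chunk["text"][:40])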
shared_utils/document_processing/hybrid_parser.py ADDED
@@ -0,0 +1,482 @@
+ #!/usr/bin/env python3
+ """
+ Hybrid TOC + PDFPlumber Parser
+
+ Combines the best of both approaches:
+ 1. TOC-guided navigation for reliable chapter/section mapping
+ 2. PDFPlumber's precise content extraction with formatting awareness
+ 3. Aggressive trash content filtering while preserving actual content
+
+ This hybrid approach provides:
+ - Reliable structure detection (TOC)
+ - High-quality content extraction (PDFPlumber)
+ - Optimal chunk sizing and quality
+ - Fast processing with precise results
+
+ Author: Arthur Passuello
+ Date: 2025-07-01
+ """
+
+ import re
+ import pdfplumber
+ from pathlib import Path
+ from typing import Dict, List, Optional, Tuple, Any
+ from dataclasses import dataclass
+
+ from .toc_guided_parser import TOCGuidedParser, TOCEntry
+ from .pdfplumber_parser import PDFPlumberParser
+
+
+ class HybridParser:
+     """
+     Hybrid parser combining TOC navigation with PDFPlumber extraction.
+
+     Architecture:
+     1. Use TOC to identify chapter/section boundaries and pages
+     2. Use PDFPlumber to extract clean content from those specific pages
+     3. Apply aggressive content filtering to remove trash
+     4. Create optimal chunks with preserved structure
+     """
+
+     def __init__(self, target_chunk_size: int = 1400, min_chunk_size: int = 800,
+                  max_chunk_size: int = 2000):
+         """Initialize hybrid parser."""
+         self.target_chunk_size = target_chunk_size
+         self.min_chunk_size = min_chunk_size
+         self.max_chunk_size = max_chunk_size
+
+         # Initialize component parsers
+         self.toc_parser = TOCGuidedParser(target_chunk_size, min_chunk_size, max_chunk_size)
+         self.plumber_parser = PDFPlumberParser(target_chunk_size, min_chunk_size, max_chunk_size)
+
+         # Content filtering patterns (aggressive trash removal)
+         self.trash_patterns = [
+             # License and legal text
+             r'Creative Commons.*?License',
+             r'International License.*?authors',
+             r'released under.*?license',
+             r'derivative of.*?License',
+             r'Document Version \d+',
+
+             # Table of contents artifacts
+             r'\.{3,}',       # Multiple dots
+             r'^\s*\d+\s*$',  # Standalone page numbers
+             r'Contents\s*$',
+             r'Preface\s*$',
+
+             # PDF formatting artifacts
+             r'Volume\s+[IVX]+:.*?V\d+',
+             r'^\s*[ivx]+\s*$',          # Roman numerals alone
+             r'^\s*[\d\w\s]{1,3}\s*$',   # Very short meaningless lines
+
+             # Redundant headers and footers
+             r'RISC-V.*?ISA.*?V\d+',
+             r'Volume I:.*?Unprivileged',
+
+             # Editor and publication info
+             r'Editors?:.*?[A-Z][a-z]+',
+             r'[A-Z][a-z]+\s+\d{1,2},\s+\d{4}',  # Dates
+             r'@[a-z]+\.[a-z]+',                 # Email addresses
+
+             # Boilerplate text
+             r'please contact editors to suggest corrections',
+             r'alphabetical order.*?corrections',
+             r'contributors to all versions',
+         ]
+
+         # Content quality patterns (preserve these)
+         self.preserve_patterns = [
+             r'RISC-V.*?instruction',
+             r'register.*?file',
+             r'memory.*?operation',
+             r'processor.*?implementation',
+             r'architecture.*?design',
+         ]
+
+         # TOC-specific patterns to exclude from searchable content
+         self.toc_exclusion_patterns = [
+             r'^\s*Contents\s*$',
+             r'^\s*Table\s+of\s+Contents\s*$',
+             r'^\s*\d+(?:\.\d+)*\s*$',    # Standalone section numbers
+             r'^\s*\d+(?:\.\d+)*\s+[A-Z]',  # "1.1 INTRODUCTION" style
+             r'\.{3,}',                   # Multiple dots (TOC formatting)
+             r'^\s*Chapter\s+\d+\s*$',    # Standalone "Chapter N"
+             r'^\s*Section\s+\d+(?:\.\d+)*\s*$',  # Standalone "Section N.M"
+             r'^\s*Appendix\s+[A-Z]\s*$',  # Standalone "Appendix A"
+             r'^\s*[ivxlcdm]+\s*$',       # Roman numerals alone
+             r'^\s*Preface\s*$',
+             r'^\s*Introduction\s*$',
+             r'^\s*Conclusion\s*$',
+             r'^\s*Bibliography\s*$',
+             r'^\s*Index\s*$',
+         ]
+
+     def parse_document(self, pdf_path: Path, pdf_data: Dict[str, Any]) -> List[Dict[str, Any]]:
+         """
+         Parse document using hybrid approach.
+
+         Args:
+             pdf_path: Path to PDF file
+             pdf_data: PDF data from extract_text_with_metadata()
+
+         Returns:
+             List of high-quality chunks with preserved structure
+         """
+         print("🔗 Starting Hybrid TOC + PDFPlumber parsing...")
+
+         # Step 1: Use TOC to identify structure
+         print("📋 Step 1: Extracting TOC structure...")
+         toc_entries = self.toc_parser.parse_toc(pdf_data['pages'])
+         print(f"   Found {len(toc_entries)} TOC entries")
+
+         # Check if TOC is reliable (multiple entries or quality single entry)
+         toc_is_reliable = (
+             len(toc_entries) > 1 or  # Multiple entries = likely real TOC
+             (len(toc_entries) == 1 and len(toc_entries[0].title) > 10)  # Quality single entry
+         )
+
+         if not toc_entries or not toc_is_reliable:
+             if not toc_entries:
+                 print("   ⚠️ No TOC found, using full page coverage parsing")
+             else:
+                 print(f"   ⚠️ TOC quality poor (title: '{toc_entries[0].title}'), using full page coverage")
+             return self.plumber_parser.parse_document(pdf_path, pdf_data)
+
+         # Step 2: Use PDFPlumber for precise extraction
+         print("🔬 Step 2: PDFPlumber extraction of TOC sections...")
+         chunks = []
+         chunk_id = 0
+
+         with pdfplumber.open(str(pdf_path)) as pdf:
+             for i, toc_entry in enumerate(toc_entries):
+                 next_entry = toc_entries[i + 1] if i + 1 < len(toc_entries) else None
+
+                 # Extract content using PDFPlumber
+                 section_content = self._extract_section_with_plumber(
+                     pdf, toc_entry, next_entry
+                 )
+
+                 if section_content:
+                     # Apply aggressive content filtering
+                     cleaned_content = self._filter_trash_content(section_content)
+
+                     if cleaned_content and len(cleaned_content) >= 200:  # Minimum meaningful content
+                         # Create chunks from cleaned content
+                         section_chunks = self._create_chunks_from_clean_content(
+                             cleaned_content, chunk_id, toc_entry
+                         )
+                         chunks.extend(section_chunks)
+                         chunk_id += len(section_chunks)
+
+         print(f"   Created {len(chunks)} high-quality chunks")
+         return chunks
+
+     def _extract_section_with_plumber(self, pdf, toc_entry: TOCEntry,
+                                       next_entry: Optional[TOCEntry]) -> str:
+         """
+         Extract section content using PDFPlumber's precise extraction.
+
+         Args:
+             pdf: PDFPlumber PDF object
+             toc_entry: Current TOC entry
+             next_entry: Next TOC entry (for boundary detection)
+
+         Returns:
+             Clean extracted content for this section
+         """
+         start_page = max(0, toc_entry.page - 1)  # Convert to 0-indexed
+
+         if next_entry:
+             end_page = min(len(pdf.pages), next_entry.page - 1)
+         else:
+             end_page = len(pdf.pages)
+
+         content_parts = []
+
+         for page_idx in range(start_page, end_page):
+             if page_idx < len(pdf.pages):
+                 page = pdf.pages[page_idx]
+
+                 # Extract text with PDFPlumber (preserves formatting)
+                 page_text = page.extract_text()
+
+                 if page_text:
+                     # Clean page content while preserving structure
+                     cleaned_text = self._clean_page_content_precise(page_text)
+                     if cleaned_text.strip():
+                         content_parts.append(cleaned_text)
+
+         return ' '.join(content_parts)
+
+     def _clean_page_content_precise(self, page_text: str) -> str:
+         """
+         Clean page content with precision, removing artifacts but preserving content.
+
+         Args:
+             page_text: Raw page text from PDFPlumber
+
+         Returns:
+             Cleaned text with artifacts removed
+         """
+         lines = page_text.split('\n')
+         cleaned_lines = []
+
+         for line in lines:
+             line = line.strip()
+
+             # Skip empty lines
+             if not line:
+                 continue
+
+             # Skip obvious artifacts but be conservative
+             if (len(line) < 3 or  # Very short lines
+                     re.match(r'^\d+$', line) or  # Standalone numbers
+                     re.match(r'^[ivx]+$', line.lower()) or  # Roman numerals alone
+                     '.' * 5 in line):  # TOC dots
+                 continue
+
+             # Preserve technical content even if it looks like an artifact
+             has_technical_content = any(term in line.lower() for term in [
+                 'risc', 'register', 'instruction', 'memory', 'processor',
+                 'architecture', 'implementation', 'specification'
+             ])
+
+             if has_technical_content or len(line) >= 10:
+                 cleaned_lines.append(line)
+
+         return ' '.join(cleaned_lines)
+
+     def _filter_trash_content(self, content: str) -> str:
+         """
+         Apply aggressive trash filtering while preserving actual content.
+
+         Args:
+             content: Raw content to filter
+
+         Returns:
+             Content with trash removed but technical content preserved
+         """
+         if not content.strip():
+             return ""
+
+         # First, identify and preserve important technical sentences
+         sentences = re.split(r'[.!?]+\s*', content)
+         preserved_sentences = []
+
+         for sentence in sentences:
+             sentence = sentence.strip()
+             if not sentence:
+                 continue
+
+             # Check if sentence contains important technical content
+             is_technical = any(term in sentence.lower() for term in [
+                 'risc-v', 'register', 'instruction', 'memory', 'processor',
+                 'architecture', 'implementation', 'specification', 'encoding',
+                 'bit', 'byte', 'address', 'data', 'control', 'operand'
+             ])
+
+             # Check if sentence is trash (including general trash and TOC content)
+             is_trash = any(re.search(pattern, sentence, re.IGNORECASE)
+                            for pattern in self.trash_patterns)
+
+             # Check if sentence is TOC content (should be excluded)
+             is_toc_content = any(re.search(pattern, sentence, re.IGNORECASE)
+                                  for pattern in self.toc_exclusion_patterns)
+
+             # Preserve if technical and not trash/TOC, or if substantial and not clearly trash/TOC
+             if ((is_technical and not is_trash and not is_toc_content) or
+                     (len(sentence) > 50 and not is_trash and not is_toc_content)):
+                 preserved_sentences.append(sentence)
+
+         # Reconstruct content from preserved sentences
+         filtered_content = '. '.join(preserved_sentences)
+
+         # Final cleanup
+         filtered_content = re.sub(r'\s+', ' ', filtered_content)  # Normalize whitespace
+         filtered_content = re.sub(r'\.+', '.', filtered_content)  # Remove multiple dots
+
+         # Ensure proper sentence ending
+         if filtered_content and not filtered_content.rstrip().endswith(('.', '!', '?', ':', ';')):
+             filtered_content = filtered_content.rstrip() + '.'
+
+         return filtered_content.strip()
+
+     def _create_chunks_from_clean_content(self, content: str, start_chunk_id: int,
+                                           toc_entry: TOCEntry) -> List[Dict[str, Any]]:
+         """
+         Create optimally-sized chunks from clean content.
+
+         Args:
+             content: Clean, filtered content
+             start_chunk_id: Starting chunk ID
+             toc_entry: TOC entry metadata
+
+         Returns:
+             List of chunk dictionaries
+         """
+         if not content or len(content) < 100:
+             return []
+
+         chunks = []
+
+         # If content fits in one chunk, create single chunk
+         if self.min_chunk_size <= len(content) <= self.max_chunk_size:
+             chunk = self._create_chunk(content, start_chunk_id, toc_entry)
+             chunks.append(chunk)
+
+         # If too large, split intelligently at sentence boundaries
+         elif len(content) > self.max_chunk_size:
+             sub_chunks = self._split_large_content_smart(content, start_chunk_id, toc_entry)
+             chunks.extend(sub_chunks)
+
+         # If too small but substantial, keep it
+         elif len(content) >= 200:  # Lower threshold for cleaned content
+             chunk = self._create_chunk(content, start_chunk_id, toc_entry)
+             chunks.append(chunk)
+
+         return chunks
+
+     def _split_large_content_smart(self, content: str, start_chunk_id: int,
+                                    toc_entry: TOCEntry) -> List[Dict[str, Any]]:
+         """
+         Split large content intelligently at natural boundaries.
+
+         Args:
+             content: Content to split
+             start_chunk_id: Starting chunk ID
+             toc_entry: TOC entry metadata
+
+         Returns:
+             List of chunk dictionaries
+         """
+         chunks = []
+
+         # Split at sentence boundaries
+         sentences = re.split(r'([.!?:;]+\s*)', content)
+
+         current_chunk = ""
+         chunk_id = start_chunk_id
+
+         for i in range(0, len(sentences), 2):
+             sentence = sentences[i].strip()
+             if not sentence:
+                 continue
+
+             # Add punctuation if available
+             punctuation = sentences[i + 1] if i + 1 < len(sentences) else '.'
+             full_sentence = sentence + punctuation
+
+             # Check if adding this sentence exceeds max size
+             potential_chunk = current_chunk + (" " if current_chunk else "") + full_sentence
+
+             if len(potential_chunk) <= self.max_chunk_size:
+                 current_chunk = potential_chunk
+             else:
+                 # Save current chunk if it meets minimum size
+                 if current_chunk and len(current_chunk) >= self.min_chunk_size:
+                     chunk = self._create_chunk(current_chunk, chunk_id, toc_entry)
+                     chunks.append(chunk)
+                     chunk_id += 1
+
+                 # Start new chunk
+                 current_chunk = full_sentence
+
+         # Add final chunk if substantial
+         if current_chunk and len(current_chunk) >= 200:
+             chunk = self._create_chunk(current_chunk, chunk_id, toc_entry)
+             chunks.append(chunk)
+
+         return chunks
+
+     def _create_chunk(self, content: str, chunk_id: int, toc_entry: TOCEntry) -> Dict[str, Any]:
+         """Create a chunk dictionary with hybrid metadata."""
+         return {
+             "text": content,
+             "chunk_id": chunk_id,
+             "title": toc_entry.title,
+             "parent_title": toc_entry.parent_title,
+             "level": toc_entry.level,
+             "page": toc_entry.page,
+             "size": len(content),
+             "metadata": {
+                 "parsing_method": "hybrid_toc_pdfplumber",
+                 "has_context": True,
+                 "content_type": "filtered_structured_content",
+                 "quality_score": self._calculate_quality_score(content),
+                 "trash_filtered": True
+             }
+         }
+
+     def _calculate_quality_score(self, content: str) -> float:
+         """Calculate quality score for filtered content."""
+         if not content.strip():
+             return 0.0
+
+         words = content.split()
+         score = 0.0
+
+         # Length score (25%)
+         if self.min_chunk_size <= len(content) <= self.max_chunk_size:
+             score += 0.25
+         elif len(content) >= 200:  # At least some content
+             score += 0.15
+
+         # Content richness (25%)
+         substantial_words = sum(1 for word in words if len(word) > 3)
+         richness_score = min(substantial_words / 30, 1.0)  # Lower threshold for filtered content
+         score += richness_score * 0.25
+
+         # Technical content (30%)
+         technical_terms = ['risc', 'register', 'instruction', 'cpu', 'memory', 'processor', 'architecture']
+         technical_count = sum(1 for word in words if any(term in word.lower() for term in technical_terms))
+         technical_score = min(technical_count / 3, 1.0)  # Lower threshold
+         score += technical_score * 0.30
+
+         # Completeness (20%)
+         completeness_score = 0.0
+         if content[0].isupper() or content.startswith(('The ', 'A ', 'An ', 'RISC')):
+             completeness_score += 0.5
+         if content.rstrip().endswith(('.', '!', '?', ':', ';')):
+             completeness_score += 0.5
+         score += completeness_score * 0.20
+
+         return min(score, 1.0)
+
+
+ def parse_pdf_with_hybrid_approach(pdf_path: Path, pdf_data: Dict[str, Any],
+                                    target_chunk_size: int = 1400, min_chunk_size: int = 800,
+                                    max_chunk_size: int = 2000) -> List[Dict[str, Any]]:
+     """
+     Parse PDF using hybrid TOC + PDFPlumber approach.
+
+     This function combines:
+     1. TOC-guided structure detection for reliable navigation
+     2. PDFPlumber's precise content extraction
+     3. Aggressive trash filtering while preserving technical content
+
+     Args:
+         pdf_path: Path to PDF file
+         pdf_data: PDF data from extract_text_with_metadata()
+         target_chunk_size: Preferred chunk size
+         min_chunk_size: Minimum chunk size
+         max_chunk_size: Maximum chunk size
+
+     Returns:
+         List of high-quality, filtered chunks ready for RAG indexing
+
+     Example:
+         >>> from shared_utils.document_processing.pdf_parser import extract_text_with_metadata
+         >>> from shared_utils.document_processing.hybrid_parser import parse_pdf_with_hybrid_approach
+         >>>
+         >>> pdf_data = extract_text_with_metadata(Path("document.pdf"))
+         >>> chunks = parse_pdf_with_hybrid_approach(Path("document.pdf"), pdf_data)
+         >>> print(f"Created {len(chunks)} hybrid-parsed chunks")
+     """
+     parser = HybridParser(target_chunk_size, min_chunk_size, max_chunk_size)
+     return parser.parse_document(pdf_path, pdf_data)
+
+
+ # Example usage
+ if __name__ == "__main__":
+     print("Hybrid TOC + PDFPlumber Parser")
+     print("Combines TOC navigation with PDFPlumber precision and aggressive trash filtering")
shared_utils/document_processing/pdf_parser.py ADDED
@@ -0,0 +1,137 @@
+ """
+ BasicRAG System - PDF Document Parser
+
+ This module implements robust PDF text extraction functionality as part of the BasicRAG
+ technical documentation system. It serves as the entry point for document ingestion,
+ converting PDF files into structured text data suitable for chunking and embedding.
+
+ Key Features:
+ - Page-by-page text extraction with metadata preservation
+ - Robust error handling for corrupted or malformed PDFs
+ - Performance timing for optimization analysis
+ - Memory-efficient processing for large documents
+
+ Technical Approach:
+ - Uses PyMuPDF (fitz) for reliable text extraction across PDF versions
+ - Maintains document structure with page-level granularity
+ - Preserves PDF metadata (author, title, creation date, etc.)
+
+ Dependencies:
+ - PyMuPDF (fitz): Chosen for superior text extraction accuracy and speed
+ - Standard library: pathlib for cross-platform file handling
+
+ Performance Characteristics:
+ - Typical processing: 10-50 pages/second on modern hardware
+ - Memory usage: O(n) with document size, but processes page-by-page
+ - Scales linearly with document length
+
+ Author: Arthur Passuello
+ Date: June 2025
+ Project: RAG Portfolio - Technical Documentation System
+ """
+
+ from typing import Dict, List, Any
+ from pathlib import Path
+ import time
+ import fitz  # PyMuPDF
+
+
+ def extract_text_with_metadata(pdf_path: Path) -> Dict[str, Any]:
+     """
+     Extract text and metadata from technical PDF documents with production-grade reliability.
+
+     This function serves as the primary ingestion point for the RAG system, converting
+     PDF documents into structured data. It's optimized for technical documentation with
+     emphasis on preserving structure and handling various PDF formats gracefully.
+
+     @param pdf_path: Path to the PDF file to process
+     @type pdf_path: pathlib.Path
+
+     @return: Dictionary containing extracted text and comprehensive metadata
+     @rtype: Dict[str, Any] with the following structure:
+         {
+             "text": str,               # Complete concatenated text from all pages
+             "pages": List[Dict],       # Per-page breakdown with text and statistics
+                                        # Each page dict contains:
+                                        #   - page_number: int (1-indexed for human readability)
+                                        #   - text: str (raw text from that page)
+                                        #   - char_count: int (character count for that page)
+             "metadata": Dict,          # PDF metadata (title, author, subject, etc.)
+             "page_count": int,         # Total number of pages processed
+             "extraction_time": float   # Processing duration in seconds
+         }
+
+     @throws FileNotFoundError: If the specified PDF file doesn't exist
+     @throws ValueError: If the PDF is corrupted, encrypted, or otherwise unreadable
+
+     Performance Notes:
+     - Processes ~10-50 pages/second depending on PDF complexity
+     - Memory usage is proportional to document size but page-by-page processing
+       prevents loading the entire document into memory at once
+     - Extraction time is included for performance monitoring and optimization
+
+     Usage Example:
+         >>> pdf_path = Path("technical_manual.pdf")
+         >>> result = extract_text_with_metadata(pdf_path)
+         >>> print(f"Extracted {result['page_count']} pages in {result['extraction_time']:.2f}s")
+         >>> first_page_text = result['pages'][0]['text']
+     """
+     # Validate input file exists before attempting to open
+     if not pdf_path.exists():
+         raise FileNotFoundError(f"PDF file not found: {pdf_path}")
+
+     # Start performance timer for extraction analytics
+     start_time = time.perf_counter()
+
+     try:
+         # Open PDF with PyMuPDF - automatically handles various PDF versions
+         # Using string conversion for compatibility with older fitz versions
+         doc = fitz.open(str(pdf_path))
+
+         # Extract document-level metadata (may include title, author, subject, keywords)
+         # Default to empty dict if no metadata present (common in scanned PDFs)
+         metadata = doc.metadata or {}
+         page_count = len(doc)
+
+         # Initialize containers for page-by-page extraction
+         pages = []     # Will store individual page data
+         all_text = []  # Will store text for concatenation
+
+         # Process each page sequentially to maintain document order
+         for page_num in range(page_count):
+             # Load page object (0-indexed internally)
+             page = doc[page_num]
+
+             # Extract text using default extraction parameters
+             # This preserves reading order and handles multi-column layouts
+             page_text = page.get_text()
+
+             # Store page data with human-readable page numbering (1-indexed)
+             pages.append({
+                 "page_number": page_num + 1,  # Convert to 1-indexed for user clarity
+                 "text": page_text,
+                 "char_count": len(page_text)  # Useful for chunking decisions
+             })
+
+             # Accumulate text for final concatenation
+             all_text.append(page_text)
+
+         # Properly close the PDF to free resources
+         doc.close()
+
+         # Calculate total extraction time for performance monitoring
+         extraction_time = time.perf_counter() - start_time
+
+         # Return comprehensive extraction results
+         return {
+             "text": "\n".join(all_text),         # Full document text with page breaks
+             "pages": pages,                      # Detailed page-by-page breakdown
+             "metadata": metadata,                # Original PDF metadata
+             "page_count": page_count,            # Total pages for quick reference
+             "extraction_time": extraction_time   # Performance metric
+         }
+
+     except Exception as e:
+         # Wrap any extraction errors with context for debugging
+         # Common causes: encrypted PDFs, corrupted files, unsupported formats
+         raise ValueError(f"Failed to process PDF: {e}")
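
Together with the chunker committed above, this extractor forms the front of the ingestion pipeline. A minimal sketch of that wiring (the file name is hypothetical):

    from pathlib import Path

    from shared_utils.document_processing.pdf_parser import extract_text_with_metadata
    from shared_utils.document_processing.chunker import chunk_technical_text

    result = extract_text_with_metadata(Path("technical_manual.pdf"))  # hypothetical input
    print(f"Extracted {result['page_count']} pages in {result['extraction_time']:.2f}s")

    # Per-page char_count can guide chunking decisions; here we chunk the full text
    chunks = chunk_technical_text(result["text"])
    print(f"{len(chunks)} sentence-complete chunks ready for embedding")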
shared_utils/document_processing/pdfplumber_parser.py ADDED
@@ -0,0 +1,452 @@
+ #!/usr/bin/env python3
+ """
+ PDFPlumber-based Parser
+
+ Advanced PDF parsing using pdfplumber for better structure detection
+ and cleaner text extraction.
+
+ Author: Arthur Passuello
+ """
+
+ import re
+ import pdfplumber
+ from pathlib import Path
+ from typing import Dict, List, Optional, Tuple, Any
+
+
+ class PDFPlumberParser:
+     """Advanced PDF parser using pdfplumber for structure-aware extraction."""
+
+     def __init__(self, target_chunk_size: int = 1400, min_chunk_size: int = 800,
+                  max_chunk_size: int = 2000):
+         """Initialize PDFPlumber parser."""
+         self.target_chunk_size = target_chunk_size
+         self.min_chunk_size = min_chunk_size
+         self.max_chunk_size = max_chunk_size
+
+         # Trash content patterns
+         self.trash_patterns = [
+             r'Creative Commons.*?License',
+             r'International License.*?authors',
+             r'RISC-V International',
+             r'Visit.*?for further',
+             r'editors to suggest.*?corrections',
+             r'released under.*?license',
+             r'\.{5,}',    # Long dots (TOC artifacts)
+             r'^\d+\s*$',  # Page numbers alone
+         ]
+
+     def extract_with_structure(self, pdf_path: Path) -> List[Dict]:
+         """Extract PDF content with structure awareness using pdfplumber."""
+         chunks = []
+
+         with pdfplumber.open(pdf_path) as pdf:
+             current_section = None
+             current_text = []
+
+             for page_num, page in enumerate(pdf.pages):
+                 # Extract text with formatting info
+                 page_content = self._extract_page_content(page, page_num + 1)
+
+                 for element in page_content:
+                     if element['type'] == 'header':
+                         # Save previous section if exists
+                         if current_text:
+                             chunk_text = '\n\n'.join(current_text)
+                             if self._is_valid_chunk(chunk_text):
+                                 chunks.extend(self._create_chunks(
+                                     chunk_text,
+                                     current_section or "Document",
+                                     page_num
+                                 ))
+
+                         # Start new section
+                         current_section = element['text']
+                         current_text = []
+
+                     elif element['type'] == 'content':
+                         # Add to current section
+                         if self._is_valid_content(element['text']):
+                             current_text.append(element['text'])
+
+             # Don't forget last section
+             if current_text:
+                 chunk_text = '\n\n'.join(current_text)
+                 if self._is_valid_chunk(chunk_text):
+                     chunks.extend(self._create_chunks(
+                         chunk_text,
+                         current_section or "Document",
+                         len(pdf.pages)
+                     ))
+
+         return chunks
+
+     def _extract_page_content(self, page: Any, page_num: int) -> List[Dict]:
+         """Extract structured content from a page."""
+         content = []
+
+         # Get all text with positioning
+         chars = page.chars
+         if not chars:
+             return content
+
+         # Group by lines
+         lines = []
+         current_line = []
+         current_y = None
+
+         for char in sorted(chars, key=lambda x: (x['top'], x['x0'])):
+             if current_y is None or abs(char['top'] - current_y) < 2:
+                 current_line.append(char)
+                 current_y = char['top']
+             else:
+                 if current_line:
+                     lines.append(current_line)
+                 current_line = [char]
+                 current_y = char['top']
+
+         if current_line:
+             lines.append(current_line)
+
+         # Analyze each line
+         for line in lines:
+             line_text = ''.join(char['text'] for char in line).strip()
+
+             if not line_text:
+                 continue
+
+             # Detect headers by font size
+             avg_font_size = sum(char.get('size', 12) for char in line) / len(line)
+             is_bold = any(char.get('fontname', '').lower().count('bold') > 0 for char in line)
+
+             # Classify content
+             if avg_font_size > 14 or is_bold:
+                 # Likely a header
+                 if self._is_valid_header(line_text):
+                     content.append({
+                         'type': 'header',
+                         'text': line_text,
+                         'font_size': avg_font_size,
+                         'page': page_num
+                     })
+             else:
+                 # Regular content
+                 content.append({
+                     'type': 'content',
+                     'text': line_text,
+                     'font_size': avg_font_size,
+                     'page': page_num
+                 })
+
+         return content
+
+     def _is_valid_header(self, text: str) -> bool:
+         """Check if text is a valid header."""
+         # Skip if too short or too long
+         if len(text) < 3 or len(text) > 200:
+             return False
+
+         # Skip if matches trash patterns
+         for pattern in self.trash_patterns:
+             if re.search(pattern, text, re.IGNORECASE):
+                 return False
+
+         # Valid if starts with number or capital letter
+         if re.match(r'^(\d+\.?\d*\s+|[A-Z])', text):
+             return True
+
+         # Valid if contains keywords
+         keywords = ['chapter', 'section', 'introduction', 'conclusion', 'appendix']
+         return any(keyword in text.lower() for keyword in keywords)
+
+     def _is_valid_content(self, text: str) -> bool:
+         """Check if text is valid content (not trash)."""
+         # Skip very short text
+         if len(text.strip()) < 10:
+             return False
+
+         # Skip trash patterns
+         for pattern in self.trash_patterns:
+             if re.search(pattern, text, re.IGNORECASE):
+                 return False
+
+         return True
+
+     def _is_valid_chunk(self, text: str) -> bool:
+         """Check if chunk text is valid."""
+         # Must have minimum length
+         if len(text.strip()) < self.min_chunk_size // 2:
+             return False
+
+         # Must have some alphabetic content
+         alpha_chars = sum(1 for c in text if c.isalpha())
+         if alpha_chars < len(text) * 0.5:
+             return False
+
+         return True
+
+     def _create_chunks(self, text: str, title: str, page: int) -> List[Dict]:
+         """Create chunks from text."""
+         chunks = []
+
+         # Clean text
+         text = self._clean_text(text)
+
+         if len(text) <= self.max_chunk_size:
+             # Single chunk
+             chunks.append({
+                 'text': text,
+                 'title': title,
+                 'page': page,
+                 'metadata': {
+                     'parsing_method': 'pdfplumber',
+                     'quality_score': self._calculate_quality_score(text)
+                 }
+             })
+         else:
+             # Split into chunks
+             text_chunks = self._split_text_into_chunks(text)
+             for i, chunk_text in enumerate(text_chunks):
+                 chunks.append({
+                     'text': chunk_text,
+                     'title': f"{title} (Part {i+1})",
+                     'page': page,
+                     'metadata': {
+                         'parsing_method': 'pdfplumber',
+                         'part_number': i + 1,
+                         'total_parts': len(text_chunks),
+                         'quality_score': self._calculate_quality_score(chunk_text)
+                     }
+                 })
+
+         return chunks
+
+     def _clean_text(self, text: str) -> str:
+         """Clean text from artifacts."""
+         # Remove volume headers (e.g., "Volume I: RISC-V Unprivileged ISA V20191213")
+         text = re.sub(r'Volume\s+[IVX]+:\s*RISC-V[^V]*V\d{8}\s*', '', text, flags=re.IGNORECASE)
+         text = re.sub(r'^\d+\s+Volume\s+[IVX]+:.*?$', '', text, flags=re.MULTILINE)
+
+         # Remove document version artifacts
+         text = re.sub(r'Document Version \d{8}\s*', '', text, flags=re.IGNORECASE)
+
+         # Remove repeated ISA headers
+         text = re.sub(r'RISC-V.*?ISA.*?V\d{8}\s*', '', text, flags=re.IGNORECASE)
+         text = re.sub(r'The RISC-V Instruction Set Manual\s*', '', text, flags=re.IGNORECASE)
+
+         # Remove figure/table references that are standalone
+         text = re.sub(r'^(Figure|Table)\s+\d+\.\d+:.*?$', '', text, flags=re.MULTILINE)
+
+         # Remove email addresses (often in contributor lists)
+         text = re.sub(r'[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}', '', text)
+
+         # Remove URLs
+         text = re.sub(r'https?://[^\s]+', '', text)
+
+         # Remove page numbers at start/end of lines
+         text = re.sub(r'^\d{1,3}\s+', '', text, flags=re.MULTILINE)
+         text = re.sub(r'\s+\d{1,3}$', '', text, flags=re.MULTILINE)
+
+         # Remove excessive dots (TOC artifacts)
+         text = re.sub(r'\.{3,}', '', text)
+
+         # Remove standalone numbers (often page numbers or figure numbers)
+         text = re.sub(r'^\s*\d+\s*$', '', text, flags=re.MULTILINE)
+
+         # Clean up multiple spaces and newlines
+         text = re.sub(r'\s{3,}', ' ', text)
+         text = re.sub(r'\n{3,}', '\n\n', text)
+         text = re.sub(r'[ \t]+', ' ', text)  # Normalize all whitespace
+
+         # Remove common boilerplate phrases
+         text = re.sub(r'Contains Nonbinding Recommendations\s*', '', text, flags=re.IGNORECASE)
+         text = re.sub(r'Guidance for Industry and FDA Staff\s*', '', text, flags=re.IGNORECASE)
+
+         return text.strip()
+
+     def _split_text_into_chunks(self, text: str) -> List[str]:
+         """Split text into chunks at sentence boundaries."""
+         sentences = re.split(r'(?<=[.!?])\s+', text)
+         chunks = []
+         current_chunk = []
+         current_size = 0
+
+         for sentence in sentences:
+             sentence_size = len(sentence)
+
+             if current_size + sentence_size > self.target_chunk_size and current_chunk:
+                 chunks.append(' '.join(current_chunk))
+                 current_chunk = [sentence]
+                 current_size = sentence_size
+             else:
+                 current_chunk.append(sentence)
+                 current_size += sentence_size + 1
+
+         if current_chunk:
+             chunks.append(' '.join(current_chunk))
+
+         return chunks
+
+     def _calculate_quality_score(self, text: str) -> float:
+         """Calculate quality score for chunk."""
+         score = 1.0
+
+         # Penalize very short or very long
+         if len(text) < self.min_chunk_size:
+             score *= 0.8
+         elif len(text) > self.max_chunk_size:
+             score *= 0.9
+
+         # Reward complete sentences
+         if text.strip().endswith(('.', '!', '?')):
+             score *= 1.1
+
+         # Reward technical content
+         technical_terms = ['risc', 'instruction', 'register', 'memory', 'processor']
+         term_count = sum(1 for term in technical_terms if term in text.lower())
+         score *= (1 + term_count * 0.05)
+
+         return min(score, 1.0)
+
+     def extract_with_page_coverage(self, pdf_path: Path, pymupdf_pages: List[Dict]) -> List[Dict]:
+         """
+         Extract content ensuring ALL pages are covered using PyMuPDF page data.
+
+         Args:
+             pdf_path: Path to PDF file
+             pymupdf_pages: Page data from PyMuPDF with page numbers and text
+
+         Returns:
+             List of chunks covering ALL document pages
+         """
+         chunks = []
+         chunk_id = 0
+
+         print(f"📄 Processing {len(pymupdf_pages)} pages with PDFPlumber quality extraction...")
+
+         with pdfplumber.open(str(pdf_path)) as pdf:
+             for pymupdf_page in pymupdf_pages:
+                 page_num = pymupdf_page['page_number']  # 1-indexed from PyMuPDF
+                 page_idx = page_num - 1                 # Convert to 0-indexed for PDFPlumber
+
+                 if page_idx < len(pdf.pages):
+                     # Extract with PDFPlumber quality from this specific page
+                     pdfplumber_page = pdf.pages[page_idx]
+                     page_text = pdfplumber_page.extract_text()
+
+                     if page_text and page_text.strip():
+                         # Clean and chunk the page text
+                         cleaned_text = self._clean_text(page_text)
+
+                         if len(cleaned_text) >= 100:  # Minimum meaningful content
+                             # Create chunks from this page
+                             page_chunks = self._create_page_chunks(
+                                 cleaned_text, page_num, chunk_id
+                             )
+                             chunks.extend(page_chunks)
+                             chunk_id += len(page_chunks)
+
+                             if len(chunks) % 50 == 0:  # Progress indicator
+                                 print(f"   Processed {page_num} pages, created {len(chunks)} chunks")
+
+         print(f"✅ Full coverage: {len(chunks)} chunks from {len(pymupdf_pages)} pages")
+         return chunks
+
+     def _create_page_chunks(self, page_text: str, page_num: int, start_chunk_id: int) -> List[Dict]:
+         """Create properly sized chunks from a single page's content."""
+         # Clean and validate page text first
+         cleaned_text = self._ensure_complete_sentences(page_text)
+
+         if not cleaned_text or len(cleaned_text) < 50:
+             # Skip pages with insufficient content
+             return []
+
+         if len(cleaned_text) <= self.max_chunk_size:
+             # Single chunk for small pages
+             return [{
+                 'text': cleaned_text,
+                 'title': f"Page {page_num}",
+                 'page': page_num,
+                 'metadata': {
+                     'parsing_method': 'pdfplumber_page_coverage',
+                     'quality_score': self._calculate_quality_score(cleaned_text),
+                     'full_page_coverage': True
+                 }
+             }]
+         else:
+             # Split large pages into chunks with sentence boundaries
+             text_chunks = self._split_text_into_chunks(cleaned_text)
+             page_chunks = []
+
+             for i, chunk_text in enumerate(text_chunks):
+                 # Ensure each chunk is complete
+                 complete_chunk = self._ensure_complete_sentences(chunk_text)
+
+                 if complete_chunk and len(complete_chunk) >= 100:
+                     page_chunks.append({
+                         'text': complete_chunk,
+                         'title': f"Page {page_num} (Part {i+1})",
+                         'page': page_num,
+                         'metadata': {
+                             'parsing_method': 'pdfplumber_page_coverage',
+                             'part_number': i + 1,
+                             'total_parts': len(text_chunks),
+                             'quality_score': self._calculate_quality_score(complete_chunk),
+                             'full_page_coverage': True
+                         }
+                     })
+
+             return page_chunks
+
+     def _ensure_complete_sentences(self, text: str) -> str:
+         """Ensure text contains only complete sentences."""
+         text = text.strip()
+         if not text:
+             return ""
+
+         # Find last complete sentence
+         last_sentence_end = -1
+         for i, char in enumerate(reversed(text)):
+             if char in '.!?:':
+                 last_sentence_end = len(text) - i
+                 break
+
+         if last_sentence_end > 0:
+             # Return text up to last complete sentence
+             complete_text = text[:last_sentence_end].strip()
+
+             # Ensure it starts properly (capital letter or common starters)
+             if complete_text and (complete_text[0].isupper() or
+                                   complete_text.startswith(('The ', 'A ', 'An ', 'This ', 'RISC'))):
+                 return complete_text
+
+         # If no complete sentences found, return empty
+         return ""
+
+     def parse_document(self, pdf_path: Path, pdf_data: Optional[Dict[str, Any]] = None) -> List[Dict]:
+         """
+         Parse document using PDFPlumber (required by HybridParser).
+
+         Args:
+             pdf_path: Path to PDF file
+             pdf_data: PyMuPDF page data to ensure full page coverage
+
+         Returns:
+             List of chunks with structure preservation across ALL pages
+         """
+         if pdf_data and 'pages' in pdf_data:
+             # Use PyMuPDF page data to ensure full coverage
+             return self.extract_with_page_coverage(pdf_path, pdf_data['pages'])
+         else:
+             # Fallback to structure-based extraction
+             return self.extract_with_structure(pdf_path)
+
+
+ def parse_pdf_with_pdfplumber(pdf_path: Path, **kwargs) -> List[Dict]:
+     """Main entry point for PDFPlumber parsing."""
+     parser = PDFPlumberParser(**kwargs)
+     chunks = parser.extract_with_structure(pdf_path)
+
+     print(f"PDFPlumber extracted {len(chunks)} chunks")
+
+     return chunks
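
A sketch contrasting the parser's two modes above; both paths run through parse_document, and the file name is a placeholder.

    from pathlib import Path

    from shared_utils.document_processing.pdf_parser import extract_text_with_metadata
    from shared_utils.document_processing.pdfplumber_parser import PDFPlumberParser

    parser = PDFPlumberParser(target_chunk_size=1400, min_chunk_size=800, max_chunk_size=2000)
    pdf_path = Path("document.pdf")  # placeholder

    # Without PyMuPDF data: font-size/boldness heuristics detect headers and group sections
    structure_chunks = parser.parse_document(pdf_path)

    # With PyMuPDF data: every page is visited, so no content can be skipped
    pdf_data = extract_text_with_metadata(pdf_path)
    coverage_chunks = parser.parse_document(pdf_path, pdf_data)
    print(len(structure_chunks), len(coverage_chunks))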
shared_utils/document_processing/toc_guided_parser.py ADDED
@@ -0,0 +1,311 @@
+ #!/usr/bin/env python3
+ """
+ TOC-Guided PDF Parser
+
+ Uses the Table of Contents to guide intelligent chunking that respects
+ document structure and hierarchy.
+
+ Author: Arthur Passuello
+ """
+
+ import re
+ from typing import Dict, List, Optional, Tuple
+ from dataclasses import dataclass
+
+
+ @dataclass
+ class TOCEntry:
+     """Represents a table of contents entry."""
+     title: str
+     page: int
+     level: int  # 0 for chapters, 1 for sections, 2 for subsections
+     parent: Optional[str] = None
+     parent_title: Optional[str] = None  # Added for hybrid parser compatibility
+
+
+ class TOCGuidedParser:
+     """Parser that uses TOC to create structure-aware chunks."""
+
+     def __init__(self, target_chunk_size: int = 1400, min_chunk_size: int = 800,
+                  max_chunk_size: int = 2000):
+         """Initialize TOC-guided parser."""
+         self.target_chunk_size = target_chunk_size
+         self.min_chunk_size = min_chunk_size
+         self.max_chunk_size = max_chunk_size
+
+     def parse_toc(self, pages: List[Dict]) -> List[TOCEntry]:
+         """Parse table of contents from pages."""
+         toc_entries = []
+
+         # Find TOC pages (usually early in document)
+         toc_pages = []
+         for i, page in enumerate(pages[:20]):  # Check first 20 pages
+             page_text = page.get('text', '').lower()
+             if 'contents' in page_text or 'table of contents' in page_text:
+                 toc_pages.append((i, page))
+
+         if not toc_pages:
+             print("No TOC found, using fallback structure detection")
+             return self._detect_structure_without_toc(pages)
+
+         # Parse TOC entries
+         for page_idx, page in toc_pages:
+             text = page.get('text', '')
+             lines = text.split('\n')
+
+             i = 0
+             while i < len(lines):
+                 line = lines[i].strip()
+
+                 # Skip empty lines and TOC header
+                 if not line or 'contents' in line.lower():
+                     i += 1
+                     continue
+
+                 # Pattern 1: "1.1 Title .... 23"
+                 match1 = re.match(r'^(\d+(?:\.\d+)*)\s+(.+?)\s*\.{2,}\s*(\d+)$', line)
+                 if match1:
+                     number, title, page_num = match1.groups()
+                     level = len(number.split('.')) - 1
+                     toc_entries.append(TOCEntry(
+                         title=title.strip(),
+                         page=int(page_num),
+                         level=level
+                     ))
+                     i += 1
+                     continue
+
+                 # Pattern 2: Multi-line format
+                 #   "1.1"
+                 #   "Title"
+                 #   ". . . . 23"
+                 if re.match(r'^(\d+(?:\.\d+)*)$', line):
+                     number = line
+                     if i + 1 < len(lines):
+                         title_line = lines[i + 1].strip()
+                         if i + 2 < len(lines):
+                             dots_line = lines[i + 2].strip()
+                             page_match = re.search(r'(\d+)\s*$', dots_line)
+                             if page_match and '.' in dots_line:
+                                 title = title_line
+                                 page_num = int(page_match.group(1))
+                                 level = len(number.split('.')) - 1
+                                 toc_entries.append(TOCEntry(
+                                     title=title,
+                                     page=page_num,
+                                     level=level
+                                 ))
+                                 i += 3
+                                 continue
+
+                 # Pattern 3: "Chapter 1: Title ... 23"
+                 match3 = re.match(r'^(Chapter|Section|Part)\s+(\d+):?\s+(.+?)\s*\.{2,}\s*(\d+)$', line, re.IGNORECASE)
+                 if match3:
+                     prefix, number, title, page_num = match3.groups()
+                     level = 0 if prefix.lower() == 'chapter' else 1
+                     toc_entries.append(TOCEntry(
+                         title=f"{prefix} {number}: {title}",
+                         page=int(page_num),
+                         level=level
+                     ))
+                     i += 1
+                     continue
+
+                 i += 1
+
+         # Add parent relationships
+         for i, entry in enumerate(toc_entries):
+             if entry.level > 0:
+                 # Find parent (previous entry with lower level)
+                 for j in range(i - 1, -1, -1):
+                     if toc_entries[j].level < entry.level:
+                         entry.parent = toc_entries[j].title
+                         entry.parent_title = toc_entries[j].title  # Set both for compatibility
+                         break
+
+         return toc_entries
+
+     def _detect_structure_without_toc(self, pages: List[Dict]) -> List[TOCEntry]:
+         """Fallback: detect structure from content patterns across ALL pages."""
+         entries = []
+
+         # Expanded patterns for better structure detection
+         chapter_patterns = [
+             re.compile(r'^(Chapter|CHAPTER)\s+(\d+|[IVX]+)(?:\s*[:\-]\s*(.+))?', re.MULTILINE),
+             re.compile(r'^(\d+)\s+([A-Z][^.]*?)(?:\s*\.{2,}\s*\d+)?$', re.MULTILINE),  # "1 Introduction"
+             re.compile(r'^([A-Z][A-Z\s]{10,})$', re.MULTILINE),  # ALL CAPS titles
+         ]
+
+         section_patterns = [
+             re.compile(r'^(\d+\.\d+)\s+(.+?)(?:\s*\.{2,}\s*\d+)?$', re.MULTILINE),  # "1.1 Section"
+             re.compile(r'^(\d+\.\d+\.\d+)\s+(.+?)(?:\s*\.{2,}\s*\d+)?$', re.MULTILINE),  # "1.1.1 Subsection"
142
+ ]
143
+
144
+ # Process ALL pages, not just first 20
145
+ for i, page in enumerate(pages):
146
+ text = page.get('text', '')
147
+ if not text.strip():
148
+ continue
149
+
150
+ # Find chapters with various patterns
151
+ for pattern in chapter_patterns:
152
+ for match in pattern.finditer(text):
153
+ groups = match.groups()
154
+ if len(groups) >= 3 and groups[2]:
155
+ title = groups[2].strip()
156
+ else:
157
+ title = groups[1].strip() if len(groups) >= 2 and groups[1] else groups[0].strip()
158
+
159
+ # Skip very short or likely false positives (the one-group ALL CAPS pattern was previously dropped by a groups() >= 2 guard and could never match)
160
+ if len(title) >= 3 and not re.match(r'^\d+$', title):
161
+ entries.append(TOCEntry(
162
+ title=title,
163
+ page=i + 1,
164
+ level=0
165
+ ))
166
+
167
+ # Find sections
168
+ for pattern in section_patterns:
169
+ for match in pattern.finditer(text):
170
+ section_num = match.group(1)
171
+ title = match.group(2).strip() if len(match.groups()) >= 2 else f"Section {section_num}"
172
+
173
+ # Determine level by number of dots
174
+ level = section_num.count('.')
175
+
176
+ # Skip very short titles or obvious artifacts
177
+ if len(title) >= 3 and not re.match(r'^\d+$', title):
178
+ entries.append(TOCEntry(
179
+ title=title,
180
+ page=i + 1,
181
+ level=level
182
+ ))
183
+
184
+ # If still no entries found, create page-based entries for full coverage
185
+ if not entries:
186
+ print("No structure patterns found, creating page-based sections for full coverage")
187
+ # Create sections every 10 pages to ensure full document coverage
188
+ for i in range(0, len(pages), 10):
189
+ start_page = i + 1
190
+ end_page = min(i + 10, len(pages))
191
+ title = f"Pages {start_page}-{end_page}"
192
+ entries.append(TOCEntry(
193
+ title=title,
194
+ page=start_page,
195
+ level=0
196
+ ))
197
+
198
+ return entries
199
+
200
+ def create_chunks_from_toc(self, pdf_data: Dict, toc_entries: List[TOCEntry]) -> List[Dict]:
201
+ """Create chunks based on TOC structure."""
202
+ chunks = []
203
+ pages = pdf_data.get('pages', [])
204
+
205
+ for i, entry in enumerate(toc_entries):
206
+ # Determine page range for this entry
207
+ start_page = entry.page - 1 # Convert to 0-indexed
208
+
209
+ # Find end page (start of next entry at same or higher level)
210
+ end_page = len(pages)
211
+ for j in range(i + 1, len(toc_entries)):
212
+ if toc_entries[j].level <= entry.level:
213
+ end_page = toc_entries[j].page - 1
214
+ break
215
+
216
+ # Extract text for this section
217
+ section_text = []
218
+ for page_idx in range(max(0, start_page), min(end_page, len(pages))):
219
+ page_text = pages[page_idx].get('text', '')
220
+ if page_text.strip():
221
+ section_text.append(page_text)
222
+
223
+ if not section_text:
224
+ continue
225
+
226
+ full_text = '\n\n'.join(section_text)
227
+
228
+ # Create chunks from section text
229
+ if len(full_text) <= self.max_chunk_size:
230
+ # Single chunk for small sections
231
+ chunks.append({
232
+ 'text': full_text.strip(),
233
+ 'title': entry.title,
234
+ 'parent_title': entry.parent_title or entry.parent or '',
235
+ 'level': entry.level,
236
+ 'page': entry.page,
237
+ 'context': f"From {entry.title}",
238
+ 'metadata': {
239
+ 'parsing_method': 'toc_guided',
240
+ 'section_title': entry.title,
241
+ 'hierarchy_level': entry.level
242
+ }
243
+ })
244
+ else:
245
+ # Split large sections into chunks
246
+ section_chunks = self._split_text_into_chunks(full_text)
247
+ for j, chunk_text in enumerate(section_chunks):
248
+ chunks.append({
249
+ 'text': chunk_text.strip(),
250
+ 'title': f"{entry.title} (Part {j+1})",
251
+ 'parent_title': entry.parent_title or entry.parent or '',
252
+ 'level': entry.level,
253
+ 'page': entry.page,
254
+ 'context': f"Part {j+1} of {entry.title}",
255
+ 'metadata': {
256
+ 'parsing_method': 'toc_guided',
257
+ 'section_title': entry.title,
258
+ 'hierarchy_level': entry.level,
259
+ 'part_number': j + 1,
260
+ 'total_parts': len(section_chunks)
261
+ }
262
+ })
263
+
264
+ return chunks
265
+
266
+ def _split_text_into_chunks(self, text: str) -> List[str]:
267
+ """Split text into chunks while preserving sentence boundaries."""
268
+ sentences = re.split(r'(?<=[.!?])\s+', text)
269
+ chunks = []
270
+ current_chunk = []
271
+ current_size = 0
272
+
273
+ for sentence in sentences:
274
+ sentence_size = len(sentence)
275
+
276
+ if current_size + sentence_size > self.target_chunk_size and current_chunk:
277
+ # Save current chunk
278
+ chunks.append(' '.join(current_chunk))
279
+ current_chunk = [sentence]
280
+ current_size = sentence_size
281
+ else:
282
+ current_chunk.append(sentence)
283
+ current_size += sentence_size + 1 # +1 for space
284
+
285
+ if current_chunk:
286
+ chunks.append(' '.join(current_chunk))
287
+
288
+ return chunks
289
+
290
+
291
+ def parse_pdf_with_toc_guidance(pdf_data: Dict, **kwargs) -> List[Dict]:
292
+ """Main entry point for TOC-guided parsing."""
293
+ parser = TOCGuidedParser(**kwargs)
294
+
295
+ # Parse TOC
296
+ pages = pdf_data.get('pages', [])
297
+ toc_entries = parser.parse_toc(pages)
298
+
299
+ print(f"Found {len(toc_entries)} TOC entries")
300
+
301
+ if not toc_entries:
302
+ print("No TOC entries found, falling back to basic chunking")
303
+ from .chunker import chunk_technical_text
304
+ return chunk_technical_text(pdf_data.get('text', ''))
305
+
306
+ # Create chunks based on TOC
307
+ chunks = parser.create_chunks_from_toc(pdf_data, toc_entries)
308
+
309
+ print(f"Created {len(chunks)} chunks from TOC structure")
310
+
311
+ return chunks
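A short sketch of the TOC-guided entry point; the pdf_data shape (a 'pages' list of dicts with a 'text' key, plus a flat 'text' fallback) is inferred from how parse_toc and the fallback chunker read it:

    from shared_utils.document_processing.toc_guided_parser import parse_pdf_with_toc_guidance

    pdf_data = {
        'text': 'Contents\n1.1 Overview .......... 3\n...',
        'pages': [
            {'text': 'Contents\n1.1 Overview .......... 3'},
            {'text': ''},
            {'text': '1.1 Overview\nRISC-V is an open instruction set architecture...'},
        ],
    }
    chunks = parse_pdf_with_toc_guidance(pdf_data, target_chunk_size=1400)
    for chunk in chunks:
        print(chunk['title'], 'page', chunk['page'], '->', len(chunk['text']), 'chars')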
shared_utils/embeddings/__init__.py ADDED
@@ -0,0 +1 @@
1
+ # Embeddings module
shared_utils/embeddings/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (169 Bytes).
 
shared_utils/embeddings/__pycache__/generator.cpython-312.pyc ADDED
Binary file (3.02 kB).
 
shared_utils/embeddings/generator.py ADDED
@@ -0,0 +1,84 @@
1
+ import numpy as np
2
+ import torch
3
+ from typing import List, Optional
4
+ from sentence_transformers import SentenceTransformer
5
+
6
+ # Global cache for embeddings
7
+ _embedding_cache = {}
8
+ _model_cache = {}
9
+
10
+
11
+ def generate_embeddings(
12
+ texts: List[str],
13
+ model_name: str = "sentence-transformers/multi-qa-MiniLM-L6-cos-v1",
14
+ batch_size: int = 32,
15
+ use_mps: bool = True,
16
+ ) -> np.ndarray:
17
+ """
18
+ Generate embeddings for text chunks with caching.
19
+
20
+ Args:
21
+ texts: List of text chunks to embed
22
+ model_name: SentenceTransformer model identifier
23
+ batch_size: Processing batch size
24
+ use_mps: Use Apple Silicon acceleration
25
+
26
+ Returns:
27
+ numpy array of shape (len(texts), embedding_dim)
28
+
29
+ Performance Target:
30
+ - 100 texts/second on M4-Pro
31
+ - 384-dimensional embeddings
32
+ - Memory usage <500MB
33
+ """
34
+ # Check cache for all texts
35
+ cache_keys = [f"{model_name}:{text}" for text in texts]
36
+ cached_embeddings = []
37
+ texts_to_compute = []
38
+ compute_indices = []
39
+
40
+ for i, key in enumerate(cache_keys):
41
+ if key in _embedding_cache:
42
+ cached_embeddings.append((i, _embedding_cache[key]))
43
+ else:
44
+ texts_to_compute.append(texts[i])
45
+ compute_indices.append(i)
46
+
47
+ # Load model if needed
48
+ if model_name not in _model_cache:
49
+ model = SentenceTransformer(model_name)
50
+ device = 'mps' if use_mps and torch.backends.mps.is_available() else 'cpu'
51
+ model = model.to(device)
52
+ model.eval()
53
+ _model_cache[model_name] = model
54
+ else:
55
+ model = _model_cache[model_name]
56
+
57
+ # Compute new embeddings
58
+ if texts_to_compute:
59
+ with torch.no_grad():
60
+ new_embeddings = model.encode(
61
+ texts_to_compute,
62
+ batch_size=batch_size,
63
+ convert_to_numpy=True,
64
+ normalize_embeddings=False
65
+ ).astype(np.float32)
66
+
67
+ # Cache new embeddings
68
+ for i, text in enumerate(texts_to_compute):
69
+ key = f"{model_name}:{text}"
70
+ _embedding_cache[key] = new_embeddings[i]
71
+
72
+ # Reconstruct full embedding array
73
+ result = np.zeros((len(texts), model.get_sentence_embedding_dimension()), dtype=np.float32)  # infer dim from the model instead of hardcoding 384
74
+
75
+ # Fill cached embeddings
76
+ for idx, embedding in cached_embeddings:
77
+ result[idx] = embedding
78
+
79
+ # Fill newly computed embeddings
80
+ if texts_to_compute:
81
+ for i, original_idx in enumerate(compute_indices):
82
+ result[original_idx] = new_embeddings[i]
83
+
84
+ return result
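A usage sketch showing the module-level cache at work; the second call is served entirely from _embedding_cache, so only the first call touches the model:

    from shared_utils.embeddings.generator import generate_embeddings

    texts = ["RISC-V is an open ISA.", "GPIO pins are configured via memory-mapped registers."]
    first = generate_embeddings(texts)    # encodes and caches both texts
    second = generate_embeddings(texts)   # cache hits only, no model call
    print(first.shape)                    # (2, 384) for the default MiniLM model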
shared_utils/generation/__pycache__/adaptive_prompt_engine.cpython-312.pyc ADDED
Binary file (19.9 kB).
 
shared_utils/generation/__pycache__/answer_generator.cpython-312.pyc ADDED
Binary file (27.1 kB).
 
shared_utils/generation/__pycache__/chain_of_thought_engine.cpython-312.pyc ADDED
Binary file (20.9 kB).
 
shared_utils/generation/__pycache__/hf_answer_generator.cpython-312.pyc ADDED
Binary file (35.8 kB).
 
shared_utils/generation/__pycache__/inference_providers_generator.cpython-312.pyc ADDED
Binary file (22.2 kB).
 
shared_utils/generation/__pycache__/ollama_answer_generator.cpython-312.pyc ADDED
Binary file (32 kB).
 
shared_utils/generation/__pycache__/prompt_optimizer.cpython-312.pyc ADDED
Binary file (28.1 kB).
 
shared_utils/generation/__pycache__/prompt_templates.cpython-312.pyc ADDED
Binary file (21.6 kB).
 
shared_utils/generation/adaptive_prompt_engine.py ADDED
@@ -0,0 +1,559 @@
1
+ """
2
+ Adaptive Prompt Engine for Dynamic Context-Aware Prompt Optimization.
3
+
4
+ This module provides intelligent prompt adaptation based on context quality,
5
+ query complexity, and performance requirements.
6
+ """
7
+
8
+ import logging
9
+ from typing import Dict, List, Optional, Tuple, Any
10
+ from dataclasses import dataclass
11
+ from enum import Enum
12
+ import numpy as np
13
+
14
+ from .prompt_templates import (
15
+ QueryType,
16
+ PromptTemplate,
17
+ TechnicalPromptTemplates
18
+ )
19
+
20
+
21
+ class ContextQuality(Enum):
22
+ """Context quality levels for adaptive prompting."""
23
+ HIGH = "high" # >0.8 relevance, low noise
24
+ MEDIUM = "medium" # 0.5-0.8 relevance, moderate noise
25
+ LOW = "low" # <0.5 relevance, high noise
26
+
27
+
28
+ class QueryComplexity(Enum):
29
+ """Query complexity levels."""
30
+ SIMPLE = "simple" # Single concept, direct answer
31
+ MODERATE = "moderate" # Multiple concepts, structured answer
32
+ COMPLEX = "complex" # Multi-step reasoning, comprehensive answer
33
+
34
+
35
+ @dataclass
36
+ class ContextMetrics:
37
+ """Metrics for evaluating context quality."""
38
+ relevance_score: float
39
+ noise_ratio: float
40
+ chunk_count: int
41
+ avg_chunk_length: int
42
+ technical_density: float
43
+ source_diversity: int
44
+
45
+
46
+ @dataclass
47
+ class AdaptivePromptConfig:
48
+ """Configuration for adaptive prompt generation."""
49
+ context_quality: ContextQuality
50
+ query_complexity: QueryComplexity
51
+ max_context_length: int
52
+ prefer_concise: bool
53
+ include_few_shot: bool
54
+ enable_chain_of_thought: bool
55
+ confidence_threshold: float
56
+
57
+
58
+ class AdaptivePromptEngine:
59
+ """
60
+ Intelligent prompt adaptation engine that optimizes prompts based on:
61
+ - Context quality and relevance
62
+ - Query complexity and type
63
+ - Performance requirements
64
+ - User preferences
65
+ """
66
+
67
+ def __init__(self):
68
+ """Initialize the adaptive prompt engine."""
69
+ self.logger = logging.getLogger(__name__)
70
+
71
+ # Context quality thresholds
72
+ self.high_quality_threshold = 0.8
73
+ self.medium_quality_threshold = 0.5
74
+
75
+ # Query complexity indicators
76
+ self.complex_keywords = {
77
+ "implementation": ["implement", "build", "create", "develop", "setup"],
78
+ "comparison": ["compare", "difference", "versus", "vs", "better"],
79
+ "analysis": ["analyze", "evaluate", "assess", "study", "examine"],
80
+ "multi_step": ["process", "procedure", "steps", "how to", "guide"]
81
+ }
82
+
83
+ # Length optimization thresholds
84
+ self.token_limits = {
85
+ "concise": 512,
86
+ "standard": 1024,
87
+ "detailed": 2048,
88
+ "comprehensive": 4096
89
+ }
90
+
91
+ def analyze_context_quality(self, chunks: List[Dict[str, Any]]) -> ContextMetrics:
92
+ """
93
+ Analyze the quality of retrieved context chunks.
94
+
95
+ Args:
96
+ chunks: List of context chunks with metadata
97
+
98
+ Returns:
99
+ ContextMetrics with quality assessment
100
+ """
101
+ if not chunks:
102
+ return ContextMetrics(
103
+ relevance_score=0.0,
104
+ noise_ratio=1.0,
105
+ chunk_count=0,
106
+ avg_chunk_length=0,
107
+ technical_density=0.0,
108
+ source_diversity=0
109
+ )
110
+
111
+ # Calculate relevance score (using confidence scores if available)
112
+ relevance_scores = []
113
+ for chunk in chunks:
114
+ # Use confidence score if available, otherwise use a heuristic
115
+ if 'confidence' in chunk:
116
+ relevance_scores.append(chunk['confidence'])
117
+ elif 'score' in chunk:
118
+ relevance_scores.append(chunk['score'])
119
+ else:
120
+ # Heuristic: longer chunks with technical terms are more relevant
121
+ content = chunk.get('content', chunk.get('text', ''))
122
+ tech_terms = self._count_technical_terms(content)
123
+ relevance_scores.append(min(tech_terms / 10.0, 1.0))
124
+
125
+ avg_relevance = np.mean(relevance_scores) if relevance_scores else 0.0
126
+
127
+ # Calculate noise ratio (fragments, repetitive content)
128
+ noise_count = 0
129
+ total_chunks = len(chunks)
130
+
131
+ for chunk in chunks:
132
+ content = chunk.get('content', chunk.get('text', ''))
133
+ if self._is_noisy_chunk(content):
134
+ noise_count += 1
135
+
136
+ noise_ratio = noise_count / total_chunks if total_chunks > 0 else 0.0
137
+
138
+ # Calculate average chunk length
139
+ chunk_lengths = []
140
+ for chunk in chunks:
141
+ content = chunk.get('content', chunk.get('text', ''))
142
+ chunk_lengths.append(len(content))
143
+
144
+ avg_chunk_length = int(np.mean(chunk_lengths)) if chunk_lengths else 0
145
+
146
+ # Calculate technical density
147
+ technical_density = self._calculate_technical_density(chunks)
148
+
149
+ # Calculate source diversity
150
+ sources = set()
151
+ for chunk in chunks:
152
+ source = chunk.get('metadata', {}).get('source', 'unknown')
153
+ sources.add(source)
154
+
155
+ source_diversity = len(sources)
156
+
157
+ return ContextMetrics(
158
+ relevance_score=avg_relevance,
159
+ noise_ratio=noise_ratio,
160
+ chunk_count=len(chunks),
161
+ avg_chunk_length=avg_chunk_length,
162
+ technical_density=technical_density,
163
+ source_diversity=source_diversity
164
+ )
165
+
166
+ def determine_query_complexity(self, query: str) -> QueryComplexity:
167
+ """
168
+ Determine the complexity level of a query.
169
+
170
+ Args:
171
+ query: User's question
172
+
173
+ Returns:
174
+ QueryComplexity level
175
+ """
176
+ query_lower = query.lower()
177
+ complexity_score = 0
178
+
179
+ # Check for complex keywords
180
+ for category, keywords in self.complex_keywords.items():
181
+ if any(keyword in query_lower for keyword in keywords):
182
+ complexity_score += 1
183
+
184
+ # Check for multiple questions or concepts
185
+ if '?' in query[:-1]:  # a '?' before the final character signals multiple questions in one query
186
+ complexity_score += 1
187
+
188
+ if any(word in query_lower for word in ["and", "or", "also", "additionally", "furthermore"]):
189
+ complexity_score += 1
190
+
191
+ # Check query length
192
+ word_count = len(query.split())
193
+ if word_count > 20:
194
+ complexity_score += 1
195
+ elif word_count > 10:
196
+ complexity_score += 0.5
197
+
198
+ # Determine complexity level
199
+ if complexity_score >= 2:
200
+ return QueryComplexity.COMPLEX
201
+ elif complexity_score >= 1:
202
+ return QueryComplexity.MODERATE
203
+ else:
204
+ return QueryComplexity.SIMPLE
205
+
206
+ def generate_adaptive_config(
207
+ self,
208
+ query: str,
209
+ context_chunks: List[Dict[str, Any]],
210
+ max_tokens: int = 2048,
211
+ prefer_speed: bool = False
212
+ ) -> AdaptivePromptConfig:
213
+ """
214
+ Generate adaptive prompt configuration based on context and query analysis.
215
+
216
+ Args:
217
+ query: User's question
218
+ context_chunks: Retrieved context chunks
219
+ max_tokens: Maximum token limit
220
+ prefer_speed: Whether to optimize for speed over quality
221
+
222
+ Returns:
223
+ AdaptivePromptConfig with optimized settings
224
+ """
225
+ # Analyze context quality
226
+ context_metrics = self.analyze_context_quality(context_chunks)
227
+
228
+ # Determine context quality level
229
+ if context_metrics.relevance_score >= self.high_quality_threshold:
230
+ context_quality = ContextQuality.HIGH
231
+ elif context_metrics.relevance_score >= self.medium_quality_threshold:
232
+ context_quality = ContextQuality.MEDIUM
233
+ else:
234
+ context_quality = ContextQuality.LOW
235
+
236
+ # Determine query complexity
237
+ query_complexity = self.determine_query_complexity(query)
238
+
239
+ # Adapt configuration based on analysis
240
+ config = AdaptivePromptConfig(
241
+ context_quality=context_quality,
242
+ query_complexity=query_complexity,
243
+ max_context_length=max_tokens,
244
+ prefer_concise=prefer_speed,
245
+ include_few_shot=self._should_include_few_shot(context_quality, query_complexity),
246
+ enable_chain_of_thought=self._should_enable_cot(query_complexity),
247
+ confidence_threshold=self._get_confidence_threshold(context_quality)
248
+ )
249
+
250
+ return config
251
+
252
+ def create_adaptive_prompt(
253
+ self,
254
+ query: str,
255
+ context_chunks: List[Dict[str, Any]],
256
+ config: Optional[AdaptivePromptConfig] = None
257
+ ) -> Dict[str, str]:
258
+ """
259
+ Create an adaptive prompt optimized for the specific query and context.
260
+
261
+ Args:
262
+ query: User's question
263
+ context_chunks: Retrieved context chunks
264
+ config: Optional configuration (auto-generated if None)
265
+
266
+ Returns:
267
+ Dict with optimized 'system' and 'user' prompts
268
+ """
269
+ if config is None:
270
+ config = self.generate_adaptive_config(query, context_chunks)
271
+
272
+ # Get base template
273
+ query_type = TechnicalPromptTemplates.detect_query_type(query)
274
+ base_template = TechnicalPromptTemplates.get_template_for_query(query)
275
+
276
+ # Adapt template based on configuration
277
+ adapted_template = self._adapt_template(base_template, config)
278
+
279
+ # Format context with optimization
280
+ formatted_context = self._format_context_adaptive(context_chunks, config)
281
+
282
+ # Create prompt with adaptive formatting
283
+ prompt = TechnicalPromptTemplates.format_prompt_with_template(
284
+ query=query,
285
+ context=formatted_context,
286
+ template=adapted_template,
287
+ include_few_shot=config.include_few_shot
288
+ )
289
+
290
+ # Add chain-of-thought if enabled
291
+ if config.enable_chain_of_thought:
292
+ prompt = self._add_chain_of_thought(prompt, query_type)
293
+
294
+ return prompt
295
+
296
+ def _adapt_template(
297
+ self,
298
+ base_template: PromptTemplate,
299
+ config: AdaptivePromptConfig
300
+ ) -> PromptTemplate:
301
+ """
302
+ Adapt a base template based on configuration.
303
+
304
+ Args:
305
+ base_template: Base prompt template
306
+ config: Adaptive configuration
307
+
308
+ Returns:
309
+ Adapted PromptTemplate
310
+ """
311
+ # Modify system prompt based on context quality
312
+ system_prompt = base_template.system_prompt
313
+
314
+ if config.context_quality == ContextQuality.LOW:
315
+ system_prompt += """
316
+
317
+ IMPORTANT: The provided context may have limited relevance. Focus on:
318
+ - Only use information that directly relates to the question
319
+ - Clearly state if information is insufficient
320
+ - Avoid making assumptions beyond the provided context
321
+ - Be explicit about confidence levels"""
322
+
323
+ elif config.context_quality == ContextQuality.HIGH:
324
+ system_prompt += """
325
+
326
+ CONTEXT QUALITY: High-quality, relevant context is provided. You can:
327
+ - Provide comprehensive, detailed answers
328
+ - Make reasonable inferences from the context
329
+ - Include related technical details and examples
330
+ - Reference multiple sources confidently"""
331
+
332
+ # Modify answer guidelines based on complexity and preferences
333
+ answer_guidelines = base_template.answer_guidelines
334
+
335
+ if config.prefer_concise:
336
+ answer_guidelines += "\n\nResponse style: Be concise and focus on essential information. Aim for clarity over comprehensiveness."
337
+
338
+ if config.query_complexity == QueryComplexity.COMPLEX:
339
+ answer_guidelines += "\n\nComplex query handling: Break down your answer into clear sections. Use numbered steps for procedures."
340
+
341
+ return PromptTemplate(
342
+ system_prompt=system_prompt,
343
+ context_format=base_template.context_format,
344
+ query_format=base_template.query_format,
345
+ answer_guidelines=answer_guidelines,
346
+ few_shot_examples=base_template.few_shot_examples
347
+ )
348
+
349
+ def _format_context_adaptive(
350
+ self,
351
+ chunks: List[Dict[str, Any]],
352
+ config: AdaptivePromptConfig
353
+ ) -> str:
354
+ """
355
+ Format context chunks with adaptive optimization.
356
+
357
+ Args:
358
+ chunks: Context chunks to format
359
+ config: Adaptive configuration
360
+
361
+ Returns:
362
+ Formatted context string
363
+ """
364
+ if not chunks:
365
+ return "No relevant context available."
366
+
367
+ # Filter chunks based on confidence if low quality context
368
+ filtered_chunks = chunks
369
+ if config.context_quality == ContextQuality.LOW:
370
+ filtered_chunks = [
371
+ chunk for chunk in chunks
372
+ if self._meets_confidence_threshold(chunk, config.confidence_threshold)
373
+ ]
374
+
375
+ # Limit context length if needed
376
+ if config.prefer_concise:
377
+ filtered_chunks = filtered_chunks[:3] # Limit to top 3 chunks
378
+
379
+ # Format chunks
380
+ context_parts = []
381
+ for i, chunk in enumerate(filtered_chunks):
382
+ chunk_text = chunk.get('content', chunk.get('text', ''))
383
+
384
+ # Truncate if too long and prefer_concise is True
385
+ if config.prefer_concise and len(chunk_text) > 800:
386
+ chunk_text = chunk_text[:800] + "..."
387
+
388
+ metadata = chunk.get('metadata', {})
389
+ page_num = metadata.get('page_number', 'unknown')
390
+ source = metadata.get('source', 'unknown')
391
+
392
+ context_parts.append(
393
+ f"[chunk_{i+1}] (Page {page_num} from {source}):\n{chunk_text}"
394
+ )
395
+
396
+ return "\n\n---\n\n".join(context_parts)
397
+
398
+ def _add_chain_of_thought(
399
+ self,
400
+ prompt: Dict[str, str],
401
+ query_type: QueryType
402
+ ) -> Dict[str, str]:
403
+ """
404
+ Add chain-of-thought reasoning to the prompt.
405
+
406
+ Args:
407
+ prompt: Base prompt dictionary
408
+ query_type: Type of query
409
+
410
+ Returns:
411
+ Enhanced prompt with chain-of-thought
412
+ """
413
+ cot_addition = """
414
+
415
+ Before providing your final answer, think through this step-by-step:
416
+
417
+ 1. What is the user specifically asking for?
418
+ 2. What relevant information is available in the context?
419
+ 3. How should I structure my response for maximum clarity?
420
+ 4. Are there any important caveats or limitations to mention?
421
+
422
+ Step-by-step reasoning:"""
423
+
424
+ prompt["user"] = prompt["user"] + cot_addition
425
+
426
+ return prompt
427
+
428
+ def _should_include_few_shot(
429
+ self,
430
+ context_quality: ContextQuality,
431
+ query_complexity: QueryComplexity
432
+ ) -> bool:
433
+ """Determine if few-shot examples should be included."""
434
+ # Include few-shot for complex queries or when context quality is low
435
+ if query_complexity == QueryComplexity.COMPLEX:
436
+ return True
437
+ if context_quality == ContextQuality.LOW:
438
+ return True
439
+ return False
440
+
441
+ def _should_enable_cot(self, query_complexity: QueryComplexity) -> bool:
442
+ """Determine if chain-of-thought should be enabled."""
443
+ return query_complexity == QueryComplexity.COMPLEX
444
+
445
+ def _get_confidence_threshold(self, context_quality: ContextQuality) -> float:
446
+ """Get confidence threshold based on context quality."""
447
+ thresholds = {
448
+ ContextQuality.HIGH: 0.3,
449
+ ContextQuality.MEDIUM: 0.5,
450
+ ContextQuality.LOW: 0.7
451
+ }
452
+ return thresholds[context_quality]
453
+
454
+ def _count_technical_terms(self, text: str) -> int:
455
+ """Count technical terms in text."""
456
+ technical_terms = [
457
+ "risc-v", "riscv", "cpu", "gpu", "mcu", "interrupt", "register",
458
+ "memory", "cache", "pipeline", "instruction", "assembly", "compiler",
459
+ "embedded", "freertos", "rtos", "gpio", "uart", "spi", "i2c",
460
+ "adc", "dac", "timer", "pwm", "dma", "firmware", "bootloader",
461
+ "ai", "ml", "neural", "transformer", "attention", "embedding"
462
+ ]
463
+
464
+ text_lower = text.lower()
465
+ count = 0
466
+ for term in technical_terms:
467
+ count += text_lower.count(term)
468
+
469
+ return count
470
+
471
+ def _is_noisy_chunk(self, content: str) -> bool:
472
+ """Determine if a chunk is noisy (low quality)."""
473
+ # Check for common noise indicators
474
+ noise_indicators = [
475
+ "table of contents",
476
+ "copyright",
477
+ "creative commons",
478
+ "license",
479
+ "all rights reserved",
480
+ "terms of use",
481
+ "privacy policy"
482
+ ]
483
+
484
+ content_lower = content.lower()
485
+
486
+ # Check for noise indicators
487
+ for indicator in noise_indicators:
488
+ if indicator in content_lower:
489
+ return True
490
+
491
+ # Check for very short fragments
492
+ if len(content) < 100:
493
+ return True
494
+
495
+ # Check for repetitive content
496
+ words = content.split()
497
+ if len(set(words)) < len(words) * 0.3: # Less than 30% unique words
498
+ return True
499
+
500
+ return False
501
+
502
+ def _calculate_technical_density(self, chunks: List[Dict[str, Any]]) -> float:
503
+ """Calculate technical density of chunks."""
504
+ if not chunks:
505
+ return 0.0
506
+
507
+ total_terms = 0
508
+ total_words = 0
509
+
510
+ for chunk in chunks:
511
+ content = chunk.get('content', chunk.get('text', ''))
512
+ words = content.split()
513
+ total_words += len(words)
514
+ total_terms += self._count_technical_terms(content)
515
+
516
+ return (total_terms / total_words) if total_words > 0 else 0.0
517
+
518
+ def _meets_confidence_threshold(
519
+ self,
520
+ chunk: Dict[str, Any],
521
+ threshold: float
522
+ ) -> bool:
523
+ """Check if chunk meets confidence threshold."""
524
+ confidence = chunk.get('confidence', chunk.get('score', 0.5))
525
+ return confidence >= threshold
526
+
527
+
528
+ # Example usage
529
+ if __name__ == "__main__":
530
+ # Initialize engine
531
+ engine = AdaptivePromptEngine()
532
+
533
+ # Example context chunks
534
+ example_chunks = [
535
+ {
536
+ "content": "RISC-V is an open-source instruction set architecture...",
537
+ "metadata": {"page_number": 1, "source": "riscv-spec.pdf"},
538
+ "confidence": 0.9
539
+ },
540
+ {
541
+ "content": "The RISC-V processor supports 32-bit and 64-bit implementations...",
542
+ "metadata": {"page_number": 2, "source": "riscv-spec.pdf"},
543
+ "confidence": 0.8
544
+ }
545
+ ]
546
+
547
+ # Example queries
548
+ simple_query = "What is RISC-V?"
549
+ complex_query = "How do I implement a complete interrupt handling system in RISC-V with nested interrupts and priority management?"
550
+
551
+ # Generate adaptive prompts
552
+ simple_config = engine.generate_adaptive_config(simple_query, example_chunks)
553
+ complex_config = engine.generate_adaptive_config(complex_query, example_chunks)
554
+
555
+ print(f"Simple query complexity: {simple_config.query_complexity}")
556
+ print(f"Complex query complexity: {complex_config.query_complexity}")
557
+ print(f"Context quality: {simple_config.context_quality}")
558
+ print(f"Few-shot enabled: {complex_config.include_few_shot}")
559
+ print(f"Chain-of-thought enabled: {complex_config.enable_chain_of_thought}")
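The built-in example stops at config generation; this sketch carries the same data through to an actual prompt (the 'system'/'user' keys follow the create_adaptive_prompt docstring):

    from shared_utils.generation.adaptive_prompt_engine import AdaptivePromptEngine

    engine = AdaptivePromptEngine()
    chunks = [{
        "content": "RISC-V is an open-source instruction set architecture...",
        "metadata": {"page_number": 1, "source": "riscv-spec.pdf"},
        "confidence": 0.9,
    }]
    prompt = engine.create_adaptive_prompt("What is RISC-V?", chunks)
    print(prompt["system"][:120])
    print(prompt["user"][:120])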
shared_utils/generation/answer_generator.py ADDED
@@ -0,0 +1,703 @@
1
+ """
2
+ Answer generation module using Ollama for local LLM inference.
3
+
4
+ This module provides answer generation with citation support for RAG systems,
5
+ optimized for technical documentation Q&A on Apple Silicon.
6
+ """
7
+
8
+ import json
9
+ import logging
10
+ from dataclasses import dataclass
11
+ from typing import List, Dict, Any, Optional, Generator, Tuple
12
+ import ollama
13
+ from datetime import datetime
14
+ import re
15
+ from pathlib import Path
16
+ import sys
17
+
18
+ # Import calibration framework
19
+ try:
20
+ from src.confidence_calibration import ConfidenceCalibrator
21
+ except ImportError:
22
+ # Fallback - disable calibration for deployment
23
+ ConfidenceCalibrator = None
24
+
25
+ logger = logging.getLogger(__name__)
26
+
27
+
28
+ @dataclass
29
+ class Citation:
30
+ """Represents a citation to a source document chunk."""
31
+ chunk_id: str
32
+ page_number: int
33
+ source_file: str
34
+ relevance_score: float
35
+ text_snippet: str
36
+
37
+
38
+ @dataclass
39
+ class GeneratedAnswer:
40
+ """Represents a generated answer with citations."""
41
+ answer: str
42
+ citations: List[Citation]
43
+ confidence_score: float
44
+ generation_time: float
45
+ model_used: str
46
+ context_used: List[Dict[str, Any]]
47
+
48
+
49
+ class AnswerGenerator:
50
+ """
51
+ Generates answers using local LLMs via Ollama with citation support.
52
+
53
+ Optimized for technical documentation Q&A with:
54
+ - Streaming response support
55
+ - Citation extraction and formatting
56
+ - Confidence scoring
57
+ - Fallback model support
58
+ """
59
+
60
+ def __init__(
61
+ self,
62
+ primary_model: str = "llama3.2:3b",
63
+ fallback_model: str = "mistral:latest",
64
+ temperature: float = 0.3,
65
+ max_tokens: int = 1024,
66
+ stream: bool = True,
67
+ enable_calibration: bool = True
68
+ ):
69
+ """
70
+ Initialize the answer generator.
71
+
72
+ Args:
73
+ primary_model: Primary Ollama model to use
74
+ fallback_model: Fallback model for complex queries
75
+ temperature: Generation temperature (0.0-1.0)
76
+ max_tokens: Maximum tokens to generate
77
+ stream: Whether to stream responses
78
+ enable_calibration: Whether to enable confidence calibration
79
+ """
80
+ self.primary_model = primary_model
81
+ self.fallback_model = fallback_model
82
+ self.temperature = temperature
83
+ self.max_tokens = max_tokens
84
+ self.stream = stream
85
+ self.client = ollama.Client()
86
+
87
+ # Initialize confidence calibration
88
+ self.enable_calibration = enable_calibration
89
+ self.calibrator = None
90
+ if enable_calibration and ConfidenceCalibrator is not None:
91
+ try:
92
+ self.calibrator = ConfidenceCalibrator()
93
+ logger.info("Confidence calibration enabled")
94
+ except Exception as e:
95
+ logger.warning(f"Failed to initialize calibration: {e}")
96
+ self.enable_calibration = False
97
+ elif enable_calibration and ConfidenceCalibrator is None:
98
+ logger.warning("Calibration requested but ConfidenceCalibrator not available - disabling")
99
+ self.enable_calibration = False
100
+
101
+ # Verify models are available
102
+ self._verify_models()
103
+
104
+ def _verify_models(self) -> None:
105
+ """Verify that required models are available."""
106
+ try:
107
+ model_list = self.client.list()
108
+ available_models = []
109
+
110
+ # Handle Ollama's ListResponse object
111
+ if hasattr(model_list, 'models'):
112
+ for model in model_list.models:
113
+ if hasattr(model, 'model'):
114
+ available_models.append(model.model)
115
+ elif isinstance(model, dict) and 'model' in model:
116
+ available_models.append(model['model'])
117
+
118
+ if self.primary_model not in available_models:
119
+ logger.warning(f"Primary model {self.primary_model} not found. Available models: {available_models}")
120
+ raise ValueError(f"Model {self.primary_model} not available. Please run: ollama pull {self.primary_model}")
121
+
122
+ if self.fallback_model not in available_models:
123
+ logger.warning(f"Fallback model {self.fallback_model} not found in: {available_models}")
124
+
125
+ except Exception as e:
126
+ logger.error(f"Error verifying models: {e}")
127
+ raise
128
+
129
+ def _create_system_prompt(self) -> str:
130
+ """Create system prompt for technical documentation Q&A."""
131
+ return """You are a technical documentation assistant that provides clear, accurate answers based on the provided context.
132
+
133
+ CORE PRINCIPLES:
134
+ 1. ANSWER DIRECTLY: If context contains the answer, provide it clearly and confidently
135
+ 2. BE CONCISE: Keep responses focused and avoid unnecessary uncertainty language
136
+ 3. CITE ACCURATELY: Use [chunk_X] citations for every fact from context
137
+
138
+ RESPONSE GUIDELINES:
139
+ - If context has sufficient information → Answer directly and confidently
140
+ - If context has partial information → Answer what's available, note what's missing briefly
141
+ - If context is irrelevant → Brief refusal: "This information isn't available in the provided documents"
142
+
143
+ CITATION FORMAT:
144
+ - Use [chunk_1], [chunk_2] etc. for all facts from context
145
+ - Example: "According to [chunk_1], RISC-V is an open-source architecture."
146
+
147
+ WHAT TO AVOID:
148
+ - Do NOT add details not in context
149
+ - Do NOT second-guess yourself if context is clear
150
+ - Do NOT use phrases like "does not contain sufficient information" when context clearly answers the question
151
+ - Do NOT be overly cautious when context is adequate
152
+
153
+ Be direct, confident, and accurate. If the context answers the question, provide that answer clearly."""
154
+
155
+ def _format_context(self, chunks: List[Dict[str, Any]]) -> str:
156
+ """
157
+ Format retrieved chunks into context for the LLM.
158
+
159
+ Args:
160
+ chunks: List of retrieved chunks with metadata
161
+
162
+ Returns:
163
+ Formatted context string
164
+ """
165
+ context_parts = []
166
+
167
+ for i, chunk in enumerate(chunks):
168
+ chunk_text = chunk.get('content', chunk.get('text', ''))
169
+ page_num = chunk.get('metadata', {}).get('page_number', 'unknown')
170
+ source = chunk.get('metadata', {}).get('source', 'unknown')
171
+
172
+ context_parts.append(
173
+ f"[chunk_{i+1}] (Page {page_num} from {source}):\n{chunk_text}\n"
174
+ )
175
+
176
+ return "\n---\n".join(context_parts)
177
+
178
+ def _extract_citations(self, answer: str, chunks: List[Dict[str, Any]]) -> Tuple[str, List[Citation]]:
179
+ """
180
+ Extract citations from the generated answer and integrate them naturally.
181
+
182
+ Args:
183
+ answer: Generated answer with [chunk_X] citations
184
+ chunks: Original chunks used for context
185
+
186
+ Returns:
187
+ Tuple of (natural_answer, citations)
188
+ """
189
+ citations = []
190
+ citation_pattern = r'\[chunk_(\d+)\]'
191
+
192
+ cited_chunks = set()
193
+
194
+ # Find [chunk_X] citations and collect cited chunks
195
+ matches = re.finditer(citation_pattern, answer)
196
+ for match in matches:
197
+ chunk_idx = int(match.group(1)) - 1 # Convert to 0-based index
198
+ if 0 <= chunk_idx < len(chunks):
199
+ cited_chunks.add(chunk_idx)
200
+
201
+ # Create Citation objects for each cited chunk
202
+ chunk_to_source = {}
203
+ for idx in cited_chunks:
204
+ chunk = chunks[idx]
205
+ citation = Citation(
206
+ chunk_id=chunk.get('id', f'chunk_{idx}'),
207
+ page_number=chunk.get('metadata', {}).get('page_number', 0),
208
+ source_file=chunk.get('metadata', {}).get('source', 'unknown'),
209
+ relevance_score=chunk.get('score', 0.0),
210
+ text_snippet=chunk.get('content', chunk.get('text', ''))[:200] + '...'
211
+ )
212
+ citations.append(citation)
213
+
214
+ # Map chunk reference to natural source name
215
+ source_name = chunk.get('metadata', {}).get('source', 'unknown')
216
+ if source_name != 'unknown':
217
+ # Use just the filename without extension for natural reference
218
+ natural_name = Path(source_name).stem.replace('-', ' ').replace('_', ' ')
219
+ chunk_to_source[f'[chunk_{idx+1}]'] = f"the {natural_name} documentation"
220
+ else:
221
+ chunk_to_source[f'[chunk_{idx+1}]'] = "the documentation"
222
+
223
+ # Replace [chunk_X] with natural references instead of removing them
224
+ natural_answer = answer
225
+ for chunk_ref, natural_ref in chunk_to_source.items():
226
+ natural_answer = natural_answer.replace(chunk_ref, natural_ref)
227
+
228
+ # Clean up any remaining unreferenced citations (fallback)
229
+ natural_answer = re.sub(r'\[chunk_\d+\]', 'the documentation', natural_answer)
230
+
231
+ # Clean up multiple spaces and formatting
232
+ natural_answer = re.sub(r'\s+', ' ', natural_answer).strip()
233
+
234
+ return natural_answer, citations
235
+
236
+ def _calculate_confidence(self, answer: str, citations: List[Citation], chunks: List[Dict[str, Any]]) -> float:
237
+ """
238
+ Calculate confidence score for the generated answer with improved calibration.
239
+
240
+ Args:
241
+ answer: Generated answer
242
+ citations: Extracted citations
243
+ chunks: Retrieved chunks
244
+
245
+ Returns:
246
+ Confidence score (0.0-1.0)
247
+ """
248
+ # Check if no chunks were provided first
249
+ if not chunks:
250
+ return 0.05 # No context = very low confidence
251
+
252
+ # Assess context quality to determine base confidence
253
+ scores = [chunk.get('score', 0) for chunk in chunks]
254
+ max_relevance = max(scores) if scores else 0
255
+ avg_relevance = sum(scores) / len(scores) if scores else 0
256
+
257
+ # Dynamic base confidence based on context quality
258
+ if max_relevance >= 0.8:
259
+ confidence = 0.6 # High-quality context starts high
260
+ elif max_relevance >= 0.6:
261
+ confidence = 0.4 # Good context starts moderately
262
+ elif max_relevance >= 0.4:
263
+ confidence = 0.2 # Fair context starts low
264
+ else:
265
+ confidence = 0.05 # Poor context starts very low
266
+
267
+ # Strong uncertainty and explicit refusal indicators
268
+ strong_uncertainty_phrases = [
269
+ "does not contain sufficient information",
270
+ "context does not provide",
271
+ "insufficient information",
272
+ "cannot determine",
273
+ "refuse to answer",
274
+ "cannot answer",
275
+ "does not contain relevant",
276
+ "no relevant context",
277
+ "missing from the provided context"
278
+ ]
279
+
280
+ # Weak uncertainty phrases that might be in nuanced but correct answers
281
+ weak_uncertainty_phrases = [
282
+ "unclear",
283
+ "conflicting",
284
+ "not specified",
285
+ "questionable",
286
+ "not contained",
287
+ "no mention",
288
+ "no relevant",
289
+ "missing",
290
+ "not explicitly"
291
+ ]
292
+
293
+ # Check for strong uncertainty - these should drastically reduce confidence
294
+ if any(phrase in answer.lower() for phrase in strong_uncertainty_phrases):
295
+ return min(0.1, confidence * 0.2) # Max 10% for explicit refusal/uncertainty
296
+
297
+ # Check for weak uncertainty - reduce but don't destroy confidence for good context
298
+ weak_uncertainty_count = sum(1 for phrase in weak_uncertainty_phrases if phrase in answer.lower())
299
+ if weak_uncertainty_count > 0:
300
+ if max_relevance >= 0.7 and citations:
301
+ # Good context with citations - reduce less severely
302
+ confidence *= (0.8 ** weak_uncertainty_count) # Moderate penalty
303
+ else:
304
+ # Poor context - reduce more severely
305
+ confidence *= (0.5 ** weak_uncertainty_count) # Strong penalty
306
+
307
+ # If all chunks have very low relevance scores, cap confidence low
308
+ if max_relevance < 0.4:
309
+ return min(0.08, confidence) # Max 8% for low relevance context
310
+
311
+ # Factor 1: Citation quality and coverage
312
+ if citations and chunks:
313
+ citation_ratio = len(citations) / min(len(chunks), 3)
314
+
315
+ # Strong boost for high-relevance citations
316
+ relevant_chunks = [c for c in chunks if c.get('score', 0) > 0.6]
317
+ if relevant_chunks:
318
+ # Significant boost for citing relevant chunks
319
+ confidence += 0.25 * citation_ratio
320
+
321
+ # Extra boost if citing majority of relevant chunks
322
+ if len(citations) >= len(relevant_chunks) * 0.5:
323
+ confidence += 0.15
324
+ else:
325
+ # Small boost for citations to lower-relevance chunks
326
+ confidence += 0.1 * citation_ratio
327
+ else:
328
+ # No citations = reduce confidence unless it's a simple factual statement
329
+ if max_relevance >= 0.8 and len(answer.split()) < 20:
330
+ confidence *= 0.8 # Gentle penalty for uncited but simple answers
331
+ else:
332
+ confidence *= 0.6 # Stronger penalty for complex uncited answers
333
+
334
+ # Factor 2: Relevance score reinforcement
335
+ if citations:
336
+ avg_citation_relevance = sum(c.relevance_score for c in citations) / len(citations)
337
+ if avg_citation_relevance > 0.8:
338
+ confidence += 0.2 # Strong boost for highly relevant citations
339
+ elif avg_citation_relevance > 0.6:
340
+ confidence += 0.1 # Moderate boost
341
+ elif avg_citation_relevance < 0.4:
342
+ confidence *= 0.6 # Penalty for low-relevance citations
343
+
344
+ # Factor 3: Context utilization quality
345
+ if chunks:
346
+ avg_chunk_length = sum(len(chunk.get('content', chunk.get('text', ''))) for chunk in chunks) / len(chunks)
347
+
348
+ # Boost for substantial, high-quality context
349
+ if avg_chunk_length > 200 and max_relevance > 0.8:
350
+ confidence += 0.1
351
+ elif avg_chunk_length < 50: # Very short chunks
352
+ confidence *= 0.8
353
+
354
+ # Factor 4: Answer characteristics
355
+ answer_words = len(answer.split())
356
+ if answer_words < 10:
357
+ confidence *= 0.9 # Slight penalty for very short answers
358
+ elif answer_words > 50 and citations:
359
+ confidence += 0.05 # Small boost for detailed cited answers
360
+
361
+ # Factor 5: High-quality scenario bonus
362
+ if (max_relevance >= 0.8 and citations and
363
+ len(citations) > 0 and
364
+ not any(phrase in answer.lower() for phrase in strong_uncertainty_phrases)):
365
+ # This is a high-quality response scenario
366
+ confidence += 0.15
367
+
368
+ raw_confidence = min(confidence, 0.95) # Cap at 95% to maintain some uncertainty
369
+
370
+ # Apply temperature scaling calibration if available
371
+ if self.enable_calibration and self.calibrator and self.calibrator.is_fitted:
372
+ try:
373
+ calibrated_confidence = self.calibrator.calibrate_confidence(raw_confidence)
374
+ logger.debug(f"Confidence calibrated: {raw_confidence:.3f} -> {calibrated_confidence:.3f}")
375
+ return calibrated_confidence
376
+ except Exception as e:
377
+ logger.warning(f"Calibration failed, using raw confidence: {e}")
378
+
379
+ return raw_confidence
380
+
381
+ def fit_calibration(self, validation_data: List[Dict[str, Any]]) -> float:
382
+ """
383
+ Fit temperature scaling calibration using validation data.
384
+
385
+ Args:
386
+ validation_data: List of dicts with 'confidence' and 'correctness' keys
387
+
388
+ Returns:
389
+ Optimal temperature parameter
390
+ """
391
+ if not self.enable_calibration or not self.calibrator:
392
+ logger.warning("Calibration not enabled or not available")
393
+ return 1.0
394
+
395
+ try:
396
+ confidences = [item['confidence'] for item in validation_data]
397
+ correctness = [item['correctness'] for item in validation_data]
398
+
399
+ optimal_temp = self.calibrator.fit_temperature_scaling(confidences, correctness)
400
+ logger.info(f"Calibration fitted with temperature: {optimal_temp:.3f}")
401
+ return optimal_temp
402
+
403
+ except Exception as e:
404
+ logger.error(f"Failed to fit calibration: {e}")
405
+ return 1.0
406
+
407
+ def save_calibration(self, filepath: str) -> bool:
408
+ """Save fitted calibration to file."""
409
+ if not self.calibrator or not self.calibrator.is_fitted:
410
+ logger.warning("No fitted calibration to save")
411
+ return False
412
+
413
+ try:
414
+ calibration_data = {
415
+ 'temperature': self.calibrator.temperature,
416
+ 'is_fitted': self.calibrator.is_fitted,
417
+ 'model_info': {
418
+ 'primary_model': self.primary_model,
419
+ 'fallback_model': self.fallback_model
420
+ }
421
+ }
422
+
423
+ with open(filepath, 'w') as f:
424
+ json.dump(calibration_data, f, indent=2)
425
+
426
+ logger.info(f"Calibration saved to {filepath}")
427
+ return True
428
+
429
+ except Exception as e:
430
+ logger.error(f"Failed to save calibration: {e}")
431
+ return False
432
+
433
+ def load_calibration(self, filepath: str) -> bool:
434
+ """Load fitted calibration from file."""
435
+ if not self.enable_calibration or not self.calibrator:
436
+ logger.warning("Calibration not enabled")
437
+ return False
438
+
439
+ try:
440
+ with open(filepath, 'r') as f:
441
+ calibration_data = json.load(f)
442
+
443
+ self.calibrator.temperature = calibration_data['temperature']
444
+ self.calibrator.is_fitted = calibration_data['is_fitted']
445
+
446
+ logger.info(f"Calibration loaded from {filepath} (temp: {self.calibrator.temperature:.3f})")
447
+ return True
448
+
449
+ except Exception as e:
450
+ logger.error(f"Failed to load calibration: {e}")
451
+ return False
452
+
453
+ def generate(
454
+ self,
455
+ query: str,
456
+ chunks: List[Dict[str, Any]],
457
+ use_fallback: bool = False
458
+ ) -> GeneratedAnswer:
459
+ """
460
+ Generate an answer based on the query and retrieved chunks.
461
+
462
+ Args:
463
+ query: User's question
464
+ chunks: Retrieved document chunks
465
+ use_fallback: Whether to use fallback model
466
+
467
+ Returns:
468
+ GeneratedAnswer object with answer, citations, and metadata
469
+ """
470
+ start_time = datetime.now()
471
+ model = self.fallback_model if use_fallback else self.primary_model
472
+
473
+ # Check for no-context or very poor context situation
474
+ if not chunks or all(len(chunk.get('content', chunk.get('text', ''))) < 20 for chunk in chunks):
475
+ # Handle no-context situation with brief, professional refusal
476
+ user_prompt = f"""Context: [NO RELEVANT CONTEXT FOUND]
477
+
478
+ Question: {query}
479
+
480
+ INSTRUCTION: Respond with exactly this brief message:
481
+
482
+ "This information isn't available in the provided documents."
483
+
484
+ DO NOT elaborate, explain, or add any other information."""
485
+ else:
486
+ # Format context from chunks
487
+ context = self._format_context(chunks)
488
+
489
+ # Create concise prompt for faster generation
490
+ user_prompt = f"""Context:
491
+ {context}
492
+
493
+ Question: {query}
494
+
495
+ Instructions: Answer using only the context above. Cite with [chunk_1], [chunk_2] etc.
496
+
497
+ Answer:"""
498
+
499
+ try:
500
+ # Generate response
501
+ response = self.client.chat(
502
+ model=model,
503
+ messages=[
504
+ {"role": "system", "content": self._create_system_prompt()},
505
+ {"role": "user", "content": user_prompt}
506
+ ],
507
+ options={
508
+ "temperature": self.temperature,
509
+ "num_predict": min(self.max_tokens, 300), # Reduce max tokens for speed
510
+ "top_k": 40, # Optimize sampling for speed
511
+ "top_p": 0.9,
512
+ "repeat_penalty": 1.1
513
+ },
514
+ stream=False # Get complete response for processing
515
+ )
516
+
517
+ # Extract answer
518
+ answer_with_citations = response['message']['content']
519
+
520
+ # Extract and clean citations
521
+ clean_answer, citations = self._extract_citations(answer_with_citations, chunks)
522
+
523
+ # Calculate confidence
524
+ confidence = self._calculate_confidence(clean_answer, citations, chunks)
525
+
526
+ # Calculate generation time
527
+ generation_time = (datetime.now() - start_time).total_seconds()
528
+
529
+ return GeneratedAnswer(
530
+ answer=clean_answer,
531
+ citations=citations,
532
+ confidence_score=confidence,
533
+ generation_time=generation_time,
534
+ model_used=model,
535
+ context_used=chunks
536
+ )
537
+
538
+ except Exception as e:
539
+ logger.error(f"Error generating answer: {e}")
540
+ # Return a fallback response
541
+ return GeneratedAnswer(
542
+ answer="I apologize, but I encountered an error while generating the answer. Please try again.",
543
+ citations=[],
544
+ confidence_score=0.0,
545
+ generation_time=0.0,
546
+ model_used=model,
547
+ context_used=chunks
548
+ )
549
+
550
+ def generate_stream(
551
+ self,
552
+ query: str,
553
+ chunks: List[Dict[str, Any]],
554
+ use_fallback: bool = False
555
+ ) -> Generator[str, None, GeneratedAnswer]:
556
+ """
557
+ Generate an answer with streaming support.
558
+
559
+ Args:
560
+ query: User's question
561
+ chunks: Retrieved document chunks
562
+ use_fallback: Whether to use fallback model
563
+
564
+ Yields:
565
+ Partial answer strings
566
+
567
+ Returns:
568
+ Final GeneratedAnswer object
569
+ """
570
+ start_time = datetime.now()
571
+ model = self.fallback_model if use_fallback else self.primary_model
572
+
573
+ # Check for no-context or very poor context situation
574
+ if not chunks or all(len(chunk.get('content', chunk.get('text', ''))) < 20 for chunk in chunks):
575
+ # Handle no-context situation with brief, professional refusal
576
+ user_prompt = f"""Context: [NO RELEVANT CONTEXT FOUND]
577
+
578
+ Question: {query}
579
+
580
+ INSTRUCTION: Respond with exactly this brief message:
581
+
582
+ "This information isn't available in the provided documents."
583
+
584
+ DO NOT elaborate, explain, or add any other information."""
585
+ else:
586
+ # Format context from chunks
587
+ context = self._format_context(chunks)
588
+
589
+ # Create concise prompt for faster generation
590
+ user_prompt = f"""Context:
591
+ {context}
592
+
593
+ Question: {query}
594
+
595
+ Instructions: Answer using only the context above. Cite with [chunk_1], [chunk_2] etc.
596
+
597
+ Answer:"""
598
+
599
+ try:
600
+ # Generate streaming response
601
+ stream = self.client.chat(
602
+ model=model,
603
+ messages=[
604
+ {"role": "system", "content": self._create_system_prompt()},
605
+ {"role": "user", "content": user_prompt}
606
+ ],
607
+ options={
608
+ "temperature": self.temperature,
609
+ "num_predict": min(self.max_tokens, 300), # Reduce max tokens for speed
610
+ "top_k": 40, # Optimize sampling for speed
611
+ "top_p": 0.9,
612
+ "repeat_penalty": 1.1
613
+ },
614
+ stream=True
615
+ )
616
+
617
+ # Collect full answer while streaming
618
+ full_answer = ""
619
+ for chunk in stream:
620
+ if 'message' in chunk and 'content' in chunk['message']:
621
+ partial = chunk['message']['content']
622
+ full_answer += partial
623
+ yield partial
624
+
625
+ # Process complete answer
626
+ clean_answer, citations = self._extract_citations(full_answer, chunks)
627
+ confidence = self._calculate_confidence(clean_answer, citations, chunks)
628
+ generation_time = (datetime.now() - start_time).total_seconds()
629
+
630
+ return GeneratedAnswer(
631
+ answer=clean_answer,
632
+ citations=citations,
633
+ confidence_score=confidence,
634
+ generation_time=generation_time,
635
+ model_used=model,
636
+ context_used=chunks
637
+ )
638
+
639
+ except Exception as e:
640
+ logger.error(f"Error in streaming generation: {e}")
641
+ yield "I apologize, but I encountered an error while generating the answer."
642
+
643
+ return GeneratedAnswer(
644
+ answer="Error occurred during generation.",
645
+ citations=[],
646
+ confidence_score=0.0,
647
+ generation_time=0.0,
648
+ model_used=model,
649
+ context_used=chunks
650
+ )
651
+
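A minimal consumption sketch (names taken from the __main__ example below; Python delivers a generator's return value via StopIteration.value):

    stream = generator.generate_stream("What is RISC-V?", example_chunks)
    try:
        while True:
            print(next(stream), end="", flush=True)
    except StopIteration as stop:
        final = stop.value  # the GeneratedAnswer with citations and confidence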
652
+ def format_answer_with_citations(self, generated_answer: GeneratedAnswer) -> str:
653
+ """
654
+ Format the generated answer with citations for display.
655
+
656
+ Args:
657
+ generated_answer: GeneratedAnswer object
658
+
659
+ Returns:
660
+ Formatted string with answer and citations
661
+ """
662
+ formatted = f"{generated_answer.answer}\n\n"
663
+
664
+ if generated_answer.citations:
665
+ formatted += "**Sources:**\n"
666
+ for i, citation in enumerate(generated_answer.citations, 1):
667
+ formatted += f"{i}. {citation.source_file} (Page {citation.page_number})\n"
668
+
669
+ formatted += f"\n*Confidence: {generated_answer.confidence_score:.1%} | "
670
+ formatted += f"Model: {generated_answer.model_used} | "
671
+ formatted += f"Time: {generated_answer.generation_time:.2f}s*"
672
+
673
+ return formatted
674
+
675
+
676
+ if __name__ == "__main__":
677
+ # Example usage
678
+ generator = AnswerGenerator()
679
+
680
+ # Example chunks (would come from retrieval system)
681
+ example_chunks = [
682
+ {
683
+ "id": "chunk_1",
684
+ "content": "RISC-V is an open-source instruction set architecture (ISA) based on reduced instruction set computer (RISC) principles.",
685
+ "metadata": {"page_number": 1, "source": "riscv-spec.pdf"},
686
+ "score": 0.95
687
+ },
688
+ {
689
+ "id": "chunk_2",
690
+ "content": "The RISC-V ISA is designed to support a wide range of implementations including 32-bit, 64-bit, and 128-bit variants.",
691
+ "metadata": {"page_number": 2, "source": "riscv-spec.pdf"},
692
+ "score": 0.89
693
+ }
694
+ ]
695
+
696
+ # Generate answer
697
+ result = generator.generate(
698
+ query="What is RISC-V?",
699
+ chunks=example_chunks
700
+ )
701
+
702
+ # Display formatted result
703
+ print(generator.format_answer_with_citations(result))
shared_utils/generation/chain_of_thought_engine.py ADDED
@@ -0,0 +1,565 @@
1
+ """
2
+ Chain-of-Thought Reasoning Engine for Complex Technical Queries.
3
+
4
+ This module provides structured reasoning capabilities for complex technical
5
+ questions that require multi-step analysis and implementation guidance.
6
+ """
7
+
8
+ from typing import Dict, List, Optional, Any, Tuple
9
+ from dataclasses import dataclass
10
+ from enum import Enum
11
+ import re
12
+
13
+ from .prompt_templates import QueryType, PromptTemplate
14
+
15
+
16
+ class ReasoningStep(Enum):
17
+ """Types of reasoning steps in chain-of-thought."""
18
+ ANALYSIS = "analysis"
19
+ DECOMPOSITION = "decomposition"
20
+ SYNTHESIS = "synthesis"
21
+ VALIDATION = "validation"
22
+ IMPLEMENTATION = "implementation"
23
+
24
+
25
+ @dataclass
26
+ class ChainStep:
27
+ """Represents a single step in chain-of-thought reasoning."""
28
+ step_type: ReasoningStep
29
+ description: str
30
+ prompt_addition: str
31
+ requires_context: bool = True
32
+
33
+
34
+ class ChainOfThoughtEngine:
35
+ """
36
+ Engine for generating chain-of-thought reasoning prompts for complex technical queries.
37
+
38
+ Features:
39
+ - Multi-step reasoning for complex implementations
40
+ - Context-aware step generation
41
+ - Query type specific reasoning chains
42
+ - Validation and error checking steps
43
+ """
44
+
45
+ def __init__(self):
46
+ """Initialize the chain-of-thought engine."""
47
+ self.reasoning_chains = self._initialize_reasoning_chains()
48
+
49
+ def _initialize_reasoning_chains(self) -> Dict[QueryType, List[ChainStep]]:
50
+ """Initialize reasoning chains for different query types."""
51
+ return {
52
+ QueryType.IMPLEMENTATION: [
53
+ ChainStep(
54
+ step_type=ReasoningStep.ANALYSIS,
55
+ description="Analyze the implementation requirements",
56
+ prompt_addition="""
57
+ First, let me analyze what needs to be implemented:
58
+ 1. What is the specific goal or functionality required?
59
+ 2. What are the key components or modules involved?
60
+ 3. Are there any hardware or software constraints mentioned?"""
61
+ ),
62
+ ChainStep(
63
+ step_type=ReasoningStep.DECOMPOSITION,
64
+ description="Break down into implementation steps",
65
+ prompt_addition="""
66
+ Next, let me break this down into logical implementation steps:
67
+ 1. What are the prerequisites and dependencies?
68
+ 2. What is the logical sequence of implementation?
69
+ 3. Which steps are critical and which are optional?"""
70
+ ),
71
+ ChainStep(
72
+ step_type=ReasoningStep.SYNTHESIS,
73
+ description="Synthesize the complete solution",
74
+ prompt_addition="""
75
+ Now I'll synthesize the complete solution:
76
+ 1. How do the individual steps connect together?
77
+ 2. What code examples or configurations are needed?
78
+ 3. What are the key integration points?"""
79
+ ),
80
+ ChainStep(
81
+ step_type=ReasoningStep.VALIDATION,
82
+ description="Consider validation and error handling",
83
+ prompt_addition="""
84
+ Finally, let me consider validation and potential issues:
85
+ 1. How can we verify the implementation works?
86
+ 2. What are common pitfalls or error conditions?
87
+ 3. What debugging or troubleshooting steps are important?"""
88
+ )
89
+ ],
90
+
91
+ QueryType.COMPARISON: [
92
+ ChainStep(
93
+ step_type=ReasoningStep.ANALYSIS,
94
+ description="Analyze items being compared",
95
+ prompt_addition="""
96
+ Let me start by analyzing what's being compared:
97
+ 1. What are the specific items or concepts being compared?
98
+ 2. What aspects or dimensions are relevant for comparison?
99
+ 3. What context or use case should guide the comparison?"""
100
+ ),
101
+ ChainStep(
102
+ step_type=ReasoningStep.DECOMPOSITION,
103
+ description="Break down comparison criteria",
104
+ prompt_addition="""
105
+ Next, let me identify the key comparison criteria:
106
+ 1. What are the technical specifications or features to compare?
107
+ 2. What are the performance characteristics?
108
+ 3. What are the practical considerations (cost, complexity, etc.)?"""
109
+ ),
110
+ ChainStep(
111
+ step_type=ReasoningStep.SYNTHESIS,
112
+ description="Synthesize comparison results",
113
+ prompt_addition="""
114
+ Now I'll synthesize the comparison:
115
+ 1. How do the items compare on each criterion?
116
+ 2. What are the key trade-offs and differences?
117
+ 3. What recommendations can be made for different scenarios?"""
118
+ )
119
+ ],
120
+
121
+ QueryType.TROUBLESHOOTING: [
122
+ ChainStep(
123
+ step_type=ReasoningStep.ANALYSIS,
124
+ description="Analyze the problem",
125
+ prompt_addition="""
126
+ Let me start by analyzing the problem:
127
+ 1. What are the specific symptoms or error conditions?
128
+ 2. What system or component is affected?
129
+ 3. What was the expected vs actual behavior?"""
130
+ ),
131
+ ChainStep(
132
+ step_type=ReasoningStep.DECOMPOSITION,
133
+ description="Identify potential root causes",
134
+ prompt_addition="""
135
+ Next, let me identify potential root causes:
136
+ 1. What are the most likely causes based on the symptoms?
137
+ 2. What system components could be involved?
138
+ 3. What external factors might contribute to the issue?"""
139
+ ),
140
+ ChainStep(
141
+ step_type=ReasoningStep.VALIDATION,
142
+ description="Develop diagnostic approach",
143
+ prompt_addition="""
144
+ Now I'll develop a diagnostic approach:
145
+ 1. What tests or checks can isolate the root cause?
146
+ 2. What is the recommended sequence of diagnostic steps?
147
+ 3. How can we verify the fix once implemented?"""
148
+ )
149
+ ],
150
+
151
+ QueryType.HARDWARE_CONSTRAINT: [
152
+ ChainStep(
153
+ step_type=ReasoningStep.ANALYSIS,
154
+ description="Analyze hardware requirements",
155
+ prompt_addition="""
156
+ Let me analyze the hardware requirements:
157
+ 1. What are the specific hardware resources needed?
158
+ 2. What are the performance requirements?
159
+ 3. What are the power and size constraints?"""
160
+ ),
161
+ ChainStep(
162
+ step_type=ReasoningStep.DECOMPOSITION,
163
+ description="Break down resource utilization",
164
+ prompt_addition="""
165
+ Next, let me break down resource utilization:
166
+ 1. How much memory (RAM/Flash) is required?
167
+ 2. What are the processing requirements (CPU/DSP)?
168
+ 3. What I/O and peripheral requirements exist?"""
169
+ ),
170
+ ChainStep(
171
+ step_type=ReasoningStep.SYNTHESIS,
172
+ description="Evaluate feasibility and alternatives",
173
+ prompt_addition="""
174
+ Now I'll evaluate feasibility:
175
+ 1. Can the requirements be met with the available hardware?
176
+ 2. What optimizations might be needed?
177
+ 3. What are alternative approaches if constraints are exceeded?"""
178
+ )
179
+ ]
180
+ }
181
+
182
+ def generate_chain_of_thought_prompt(
183
+ self,
184
+ query: str,
185
+ query_type: QueryType,
186
+ context: str,
187
+ base_template: PromptTemplate
188
+ ) -> Dict[str, str]:
189
+ """
190
+ Generate a chain-of-thought enhanced prompt.
191
+
192
+ Args:
193
+ query: User's question
194
+ query_type: Type of query
195
+ context: Retrieved context
196
+ base_template: Base prompt template
197
+
198
+ Returns:
199
+ Enhanced prompt with chain-of-thought reasoning
200
+ """
201
+ # Get reasoning chain for query type
202
+ reasoning_chain = self.reasoning_chains.get(query_type, [])
203
+
204
+ if not reasoning_chain:
205
+ # Fall back to generic reasoning for unsupported types
206
+ reasoning_chain = self._generate_generic_reasoning_chain(query)
207
+
208
+ # Build chain-of-thought prompt
209
+ cot_prompt = self._build_cot_prompt(reasoning_chain, query, context)
210
+
211
+ # Enhance system prompt
212
+ enhanced_system = base_template.system_prompt + """
213
+
214
+ CHAIN-OF-THOUGHT REASONING: You will approach this question using structured reasoning.
215
+ Work through each step methodically before providing your final answer.
216
+ Show your reasoning process clearly, then provide a comprehensive final answer."""
217
+
218
+ # Enhance user prompt
219
+ enhanced_user = f"""{base_template.context_format.format(context=context)}
220
+
221
+ {base_template.query_format.format(query=query)}
222
+
223
+ {cot_prompt}
224
+
225
+ {base_template.answer_guidelines}
226
+
227
+ After working through your reasoning, provide your final answer in the requested format."""
228
+
229
+ return {
230
+ "system": enhanced_system,
231
+ "user": enhanced_user
232
+ }
233
+
234
+ def _build_cot_prompt(
235
+ self,
236
+ reasoning_chain: List[ChainStep],
237
+ query: str,
238
+ context: str
239
+ ) -> str:
240
+ """
241
+ Build the chain-of-thought prompt section.
242
+
243
+ Args:
244
+ reasoning_chain: List of reasoning steps
245
+ query: User's question
246
+ context: Retrieved context
247
+
248
+ Returns:
249
+ Chain-of-thought prompt text
250
+ """
251
+ cot_sections = [
252
+ "REASONING PROCESS:",
253
+ "Work through this step-by-step using the following reasoning framework:",
254
+ ""
255
+ ]
256
+
257
+ for i, step in enumerate(reasoning_chain, 1):
258
+ cot_sections.append(f"Step {i}: {step.description}")
259
+ cot_sections.append(step.prompt_addition)
260
+ cot_sections.append("")
261
+
262
+ cot_sections.extend([
263
+ "STRUCTURED REASONING:",
264
+ "Now work through each step above, referencing the provided context where relevant.",
265
+ "Use [chunk_X] citations for your reasoning at each step.",
266
+ ""
267
+ ])
268
+
269
+ return "\n".join(cot_sections)
270
+
271
+ def _generate_generic_reasoning_chain(self, query: str) -> List[ChainStep]:
272
+ """
273
+ Generate a generic reasoning chain for unsupported query types.
274
+
275
+ Args:
276
+ query: User's question
277
+
278
+ Returns:
279
+ List of generic reasoning steps
280
+ """
281
+ # Analyze query complexity to determine appropriate steps
282
+ complexity_indicators = {
283
+ "multi_part": ["and", "also", "additionally", "furthermore"],
284
+ "causal": ["why", "because", "cause", "reason"],
285
+ "conditional": ["if", "when", "unless", "provided that"],
286
+ "comparative": ["better", "worse", "compare", "versus", "vs"]
287
+ }
288
+
289
+ query_lower = query.lower()
290
+ steps = []
291
+
292
+ # Always start with analysis
293
+ steps.append(ChainStep(
294
+ step_type=ReasoningStep.ANALYSIS,
295
+ description="Analyze the question",
296
+ prompt_addition="""
297
+ Let me start by analyzing the question:
298
+ 1. What is the core question being asked?
299
+ 2. What context or domain knowledge is needed?
300
+ 3. Are there multiple parts to this question?"""
301
+ ))
302
+
303
+ # Add decomposition for complex queries
304
+ if any(indicator in query_lower for indicators in complexity_indicators.values() for indicator in indicators):
305
+ steps.append(ChainStep(
306
+ step_type=ReasoningStep.DECOMPOSITION,
307
+ description="Break down the question",
308
+ prompt_addition="""
309
+ Let me break this down into components:
310
+ 1. What are the key concepts or elements involved?
311
+ 2. How do these elements relate to each other?
312
+ 3. What information do I need to address each part?"""
313
+ ))
314
+
315
+ # Always end with synthesis
316
+ steps.append(ChainStep(
317
+ step_type=ReasoningStep.SYNTHESIS,
318
+ description="Synthesize the answer",
319
+ prompt_addition="""
320
+ Now I'll synthesize a comprehensive answer:
321
+ 1. How do all the pieces fit together?
322
+ 2. What is the most complete and accurate response?
323
+ 3. Are there any important caveats or limitations?"""
324
+ ))
325
+
326
+ return steps
327
+
328
+ def create_reasoning_validation_prompt(
329
+ self,
330
+ query: str,
331
+ proposed_answer: str,
332
+ context: str
333
+ ) -> str:
334
+ """
335
+ Create a prompt for validating chain-of-thought reasoning.
336
+
337
+ Args:
338
+ query: Original query
339
+ proposed_answer: Generated answer to validate
340
+ context: Context used for the answer
341
+
342
+ Returns:
343
+ Validation prompt
344
+ """
345
+ return f"""
346
+ REASONING VALIDATION TASK:
347
+
348
+ Original Query: {query}
349
+
350
+ Proposed Answer: {proposed_answer}
351
+
352
+ Context Used: {context}
353
+
354
+ Please validate the reasoning in the proposed answer by checking:
355
+
356
+ 1. LOGICAL CONSISTENCY:
357
+ - Are the reasoning steps logically connected?
358
+ - Are there any contradictions or gaps in logic?
359
+ - Does the conclusion follow from the premises?
360
+
361
+ 2. FACTUAL ACCURACY:
362
+ - Are the facts and technical details correct?
363
+ - Are the citations appropriate and accurate?
364
+ - Is the information consistent with the provided context?
365
+
366
+ 3. COMPLETENESS:
367
+ - Does the answer address all parts of the question?
368
+ - Are important considerations or caveats mentioned?
369
+ - Is the level of detail appropriate for the question?
370
+
371
+ 4. CLARITY:
372
+ - Is the reasoning easy to follow?
373
+ - Are technical terms used correctly?
374
+ - Is the structure logical and well-organized?
375
+
376
+ Provide your validation assessment with specific feedback on any issues found.
377
+ """
378
+
379
+ def extract_reasoning_steps(self, cot_response: str) -> List[Dict[str, str]]:
380
+ """
381
+ Extract reasoning steps from a chain-of-thought response.
382
+
383
+ Args:
384
+ cot_response: Response containing chain-of-thought reasoning
385
+
386
+ Returns:
387
+ List of extracted reasoning steps
388
+ """
389
+ steps = []
390
+
391
+ # Look for step patterns
392
+ step_patterns = [
393
+ r"Step \d+:?\s*(.+?)(?=Step \d+|$)",
394
+ r"First,?\s*(.+?)(?=Next,?|Second,?|Then,?|Finally,?|$)",
395
+ r"Next,?\s*(.+?)(?=Then,?|Finally,?|Now,?|$)",
396
+ r"Then,?\s*(.+?)(?=Finally,?|Now,?|$)",
397
+ r"Finally,?\s*(.+?)(?=\n\n|$)"
398
+ ]
399
+
400
+ for pattern in step_patterns:
401
+ matches = re.findall(pattern, cot_response, re.DOTALL | re.IGNORECASE)
402
+ for match in matches:
403
+ if match.strip():
404
+ steps.append({
405
+ "step_text": match.strip(),
406
+ "pattern": pattern
407
+ })
408
+
409
+ return steps
410
+
411
+ def evaluate_reasoning_quality(self, reasoning_steps: List[Dict[str, str]]) -> Dict[str, float]:
412
+ """
413
+ Evaluate the quality of chain-of-thought reasoning.
414
+
415
+ Args:
416
+ reasoning_steps: List of reasoning steps
417
+
418
+ Returns:
419
+ Dictionary of quality metrics
420
+ """
421
+ if not reasoning_steps:
422
+ return {"overall_quality": 0.0, "step_count": 0}
423
+
424
+ # Evaluate different aspects
425
+ metrics = {
426
+ "step_count": len(reasoning_steps),
427
+ "logical_flow": self._evaluate_logical_flow(reasoning_steps),
428
+ "technical_depth": self._evaluate_technical_depth(reasoning_steps),
429
+ "citation_usage": self._evaluate_citation_usage(reasoning_steps),
430
+ "completeness": self._evaluate_completeness(reasoning_steps)
431
+ }
432
+
433
+ # Calculate overall quality
434
+ quality_weights = {
435
+ "logical_flow": 0.3,
436
+ "technical_depth": 0.3,
437
+ "citation_usage": 0.2,
438
+ "completeness": 0.2
439
+ }
440
+
441
+ overall_quality = sum(
442
+ metrics[key] * quality_weights[key]
443
+ for key in quality_weights
444
+ )
445
+
446
+ metrics["overall_quality"] = overall_quality
447
+
448
+ return metrics
449
+
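For concreteness, a hypothetical worked example of the weighted sum above:

    # logical_flow=0.5, technical_depth=0.4, citation_usage=1.0, completeness=0.8
    # overall_quality = 0.3*0.5 + 0.3*0.4 + 0.2*1.0 + 0.2*0.8
    #                 = 0.15 + 0.12 + 0.20 + 0.16 = 0.63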
450
+ def _evaluate_logical_flow(self, steps: List[Dict[str, str]]) -> float:
451
+ """Evaluate logical flow between reasoning steps."""
452
+ if len(steps) < 2:
453
+ return 0.5
454
+
455
+ # Check for logical connectors
456
+ connectors = ["therefore", "thus", "because", "since", "as a result", "consequently"]
457
+ connector_count = 0
458
+
459
+ for step in steps:
460
+ step_text = step["step_text"].lower()
461
+ if any(connector in step_text for connector in connectors):
462
+ connector_count += 1
463
+
464
+ return min(connector_count / len(steps), 1.0)
465
+
466
+ def _evaluate_technical_depth(self, steps: List[Dict[str, str]]) -> float:
467
+ """Evaluate technical depth of reasoning."""
468
+ technical_terms = [
469
+ "implementation", "architecture", "algorithm", "protocol", "specification",
470
+ "optimization", "configuration", "register", "memory", "hardware",
471
+ "software", "system", "component", "module", "interface"
472
+ ]
473
+
474
+ total_terms = 0
475
+ total_words = 0
476
+
477
+ for step in steps:
478
+ words = step["step_text"].lower().split()
479
+ total_words += len(words)
480
+
481
+ for term in technical_terms:
482
+ total_terms += words.count(term)
483
+
484
+ return min(total_terms / max(total_words, 1) * 100, 1.0)
485
+
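As a quick worked example of this density metric: three technical terms in a 60-word step give 3/60 × 100 = 5.0, capped to 1.0, so the score saturates at one term per hundred words and most technical answers hit the ceiling.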
486
+ def _evaluate_citation_usage(self, steps: List[Dict[str, str]]) -> float:
487
+ """Evaluate citation usage in reasoning."""
488
+ citation_pattern = r'\[chunk_\d+\]'
489
+ total_citations = 0
490
+
491
+ for step in steps:
492
+ citations = re.findall(citation_pattern, step["step_text"])
493
+ total_citations += len(citations)
494
+
495
+ # Good reasoning should have at least one citation per step
496
+ return min(total_citations / len(steps), 1.0)
497
+
498
+ def _evaluate_completeness(self, steps: List[Dict[str, str]]) -> float:
499
+ """Evaluate completeness of reasoning."""
500
+ # Check for key reasoning elements
501
+ completeness_indicators = [
502
+ "analysis", "consider", "examine", "evaluate",
503
+ "conclusion", "summary", "result", "therefore",
504
+ "requirement", "constraint", "limitation", "important"
505
+ ]
506
+
507
+ indicator_count = 0
508
+ for step in steps:
509
+ step_text = step["step_text"].lower()
510
+ for indicator in completeness_indicators:
511
+ if indicator in step_text:
512
+ indicator_count += 1
513
+ break
514
+
515
+ return indicator_count / len(steps)
516
+
517
+
518
+ # Example usage
519
+ if __name__ == "__main__":
520
+ # Initialize engine
521
+ cot_engine = ChainOfThoughtEngine()
522
+
523
+ # Example implementation query
524
+ query = "How do I implement a real-time task scheduler in FreeRTOS with priority inheritance?"
525
+ query_type = QueryType.IMPLEMENTATION
526
+ context = "FreeRTOS supports priority-based scheduling with optional priority inheritance..."
527
+
528
+ # Create a basic template
529
+ base_template = PromptTemplate(
530
+ system_prompt="You are a technical assistant.",
531
+ context_format="Context: {context}",
532
+ query_format="Question: {query}",
533
+ answer_guidelines="Provide a structured answer."
534
+ )
535
+
536
+ # Generate chain-of-thought prompt
537
+ cot_prompt = cot_engine.generate_chain_of_thought_prompt(
538
+ query=query,
539
+ query_type=query_type,
540
+ context=context,
541
+ base_template=base_template
542
+ )
543
+
544
+ print("Chain-of-Thought Enhanced Prompt:")
545
+ print("=" * 50)
546
+ print("System:", cot_prompt["system"][:200], "...")
547
+ print("User:", cot_prompt["user"][:300], "...")
548
+ print("=" * 50)
549
+
550
+ # Example reasoning evaluation
551
+ example_response = """
552
+ Step 1: Let me analyze the requirements
553
+ FreeRTOS provides priority-based scheduling [chunk_1]...
554
+
555
+ Step 2: Breaking down the implementation
556
+ Priority inheritance requires mutex implementation [chunk_2]...
557
+
558
+ Step 3: Synthesizing the solution
559
+ Therefore, we need to configure priority inheritance in FreeRTOS [chunk_3]...
560
+ """
561
+
562
+ steps = cot_engine.extract_reasoning_steps(example_response)
563
+ quality = cot_engine.evaluate_reasoning_quality(steps)
564
+
565
+ print(f"Reasoning Quality: {quality}")
shared_utils/generation/hf_answer_generator.py ADDED
@@ -0,0 +1,881 @@
1
+ """
2
+ HuggingFace API-based answer generation for deployment environments.
3
+
4
+ This module provides answer generation using HuggingFace's Inference API,
5
+ optimized for cloud deployment where local LLMs aren't feasible.
6
+ """
7
+
8
+ import json
9
+ import logging
10
+ from dataclasses import dataclass
11
+ from typing import List, Dict, Any, Optional, Generator, Tuple
12
+ from datetime import datetime
13
+ import re
14
+ from pathlib import Path
15
+ import requests
16
+ import os
17
+ import sys
18
+
19
+ # Import technical prompt templates
20
+ from .prompt_templates import TechnicalPromptTemplates
21
+
22
+ # Import standard interfaces (add this for the adapter)
23
+ try:
24
+ from pathlib import Path
25
+ import sys
26
+ project_root = Path(__file__).parent.parent.parent.parent.parent
27
+ sys.path.append(str(project_root))
28
+ from src.core.interfaces import Document, Answer, AnswerGenerator
29
+ except ImportError:
30
+ # Fallback for standalone usage
31
+ Document = None
32
+ Answer = None
33
+ AnswerGenerator = object
34
+
35
+ logger = logging.getLogger(__name__)
36
+
37
+
38
+ @dataclass
39
+ class Citation:
40
+ """Represents a citation to a source document chunk."""
41
+ chunk_id: str
42
+ page_number: int
43
+ source_file: str
44
+ relevance_score: float
45
+ text_snippet: str
46
+
47
+
48
+ @dataclass
49
+ class GeneratedAnswer:
50
+ """Represents a generated answer with citations."""
51
+ answer: str
52
+ citations: List[Citation]
53
+ confidence_score: float
54
+ generation_time: float
55
+ model_used: str
56
+ context_used: List[Dict[str, Any]]
57
+
58
+
59
+ class HuggingFaceAnswerGenerator(AnswerGenerator if AnswerGenerator != object else object):
60
+ """
61
+ Generates answers using HuggingFace Inference API with hybrid reliability.
62
+
63
+ 🎯 HYBRID APPROACH - Best of Both Worlds:
64
+ - Primary: High-quality open models (Zephyr-7B, Mistral-7B-Instruct)
65
+ - Fallback: Reliable classics (DialoGPT-medium)
66
+ - Foundation: HF GPT's proven Docker + auth setup
67
+ - Pro Benefits: Better rate limits, priority processing
68
+
69
+ Optimized for deployment environments with:
70
+ - Fast API-based inference
71
+ - No local model requirements
72
+ - Citation extraction and formatting
73
+ - Confidence scoring
74
+ - Automatic fallback for reliability
75
+ """
76
+
77
+ def __init__(
78
+ self,
79
+ model_name: str = "sshleifer/distilbart-cnn-12-6",
80
+ api_token: Optional[str] = None,
81
+ temperature: float = 0.3,
82
+ max_tokens: int = 512
83
+ ):
84
+ """
85
+ Initialize the HuggingFace answer generator.
86
+
87
+ Args:
88
+ model_name: HuggingFace model to use
89
+ api_token: HF API token (optional, uses free tier if None)
90
+ temperature: Generation temperature (0.0-1.0)
91
+ max_tokens: Maximum tokens to generate
92
+ """
93
+ self.model_name = model_name
94
+ # Try multiple common token environment variable names
95
+ self.api_token = (api_token or
96
+ os.getenv("HUGGINGFACE_API_TOKEN") or
97
+ os.getenv("HF_TOKEN") or
98
+ os.getenv("HF_API_TOKEN"))
99
+ self.temperature = temperature
100
+ self.max_tokens = max_tokens
101
+
102
+ # Hybrid approach: Classic API + fallback models
103
+ self.api_url = f"https://api-inference.huggingface.co/models/{model_name}"
104
+
105
+ # Prepare headers
106
+ self.headers = {"Content-Type": "application/json"}
107
+ self._auth_failed = False # Track if auth has failed
108
+ if self.api_token:
109
+ self.headers["Authorization"] = f"Bearer {self.api_token}"
110
+ logger.info("Using authenticated HuggingFace API")
111
+ else:
112
+ logger.info("Using free HuggingFace API (rate limited)")
113
+
114
+ # Only include models that actually work based on tests
115
+ self.fallback_models = [
116
+ "deepset/roberta-base-squad2", # Q&A model - perfect for RAG
117
+ "sshleifer/distilbart-cnn-12-6", # Summarization - also good
118
+ "facebook/bart-base", # Base BART - works but needs right format
119
+ ]
120
+
121
+ def _make_api_request(self, url: str, payload: dict, timeout: int = 30) -> requests.Response:
122
+ """Make API request with automatic 401 handling."""
123
+ # Use current headers (may have been updated if auth failed)
124
+ headers = self.headers.copy()
125
+
126
+ # If we've already had auth failure, don't include the token
127
+ if self._auth_failed and "Authorization" in headers:
128
+ del headers["Authorization"]
129
+
130
+ response = requests.post(url, headers=headers, json=payload, timeout=timeout)
131
+
132
+ # Handle 401 error
133
+ if response.status_code == 401 and not self._auth_failed and self.api_token:
134
+ logger.error(f"API request failed: 401 Unauthorized")
135
+ logger.error(f"Response body: {response.text}")
136
+ logger.warning("Token appears invalid, retrying without authentication...")
137
+ self._auth_failed = True
138
+ # Remove auth header
139
+ if "Authorization" in self.headers:
140
+ del self.headers["Authorization"]
141
+ headers = self.headers.copy()
142
+ # Retry without auth
143
+ response = requests.post(url, headers=headers, json=payload, timeout=timeout)
144
+ if response.status_code == 401:
145
+ logger.error("Still getting 401 even without auth token")
146
+ logger.error(f"Response body: {response.text}")
147
+
148
+ return response
149
+
150
+ def _call_api_with_model(self, prompt: str, model_name: str) -> str:
151
+ """Call API with a specific model (for fallback support)."""
152
+ fallback_url = f"https://api-inference.huggingface.co/models/{model_name}"
153
+
154
+ # SIMPLIFIED payload that works
155
+ payload = {"inputs": prompt}
156
+
157
+ # Use helper method with 401 handling
158
+ response = self._make_api_request(fallback_url, payload)
159
+
160
+ response.raise_for_status()
161
+ result = response.json()
162
+
163
+ # Handle response
164
+ if isinstance(result, list) and len(result) > 0:
165
+ if isinstance(result[0], dict):
166
+ return result[0].get("generated_text", "").strip()
167
+ else:
168
+ return str(result[0]).strip()
169
+ elif isinstance(result, dict):
170
+ return result.get("generated_text", "").strip()
171
+ else:
172
+ return str(result).strip()
173
+
174
+ def _create_system_prompt(self) -> str:
175
+ """Create system prompt optimized for the model type."""
176
+ if "squad" in self.model_name.lower() or "roberta" in self.model_name.lower():
177
+ # RoBERTa Squad2 uses question/context format - no system prompt needed
178
+ return ""
179
+ elif "gpt2" in self.model_name.lower() or "distilgpt2" in self.model_name.lower():
180
+ # GPT-2 style completion prompt - simpler is better
181
+ return "Based on the following context, answer the question.\n\nContext: "
182
+ elif "llama" in self.model_name.lower():
183
+ # Llama-2 chat format
184
+ return """<s>[INST] You are a helpful technical documentation assistant. Answer the user's question based only on the provided context. Always cite sources using [chunk_X] format.
185
+
186
+ Context:"""
187
+ elif "flan" in self.model_name.lower() or "t5" in self.model_name.lower():
188
+ # Flan-T5 instruction format - simple and direct
189
+ return """Answer the question based on the context below. Cite sources using [chunk_X] format.
190
+
191
+ Context: """
192
+ elif "falcon" in self.model_name.lower():
193
+ # Falcon instruction format
194
+ return """### Instruction: Answer based on the context and cite sources with [chunk_X].
195
+
196
+ ### Context: """
197
+ elif "bart" in self.model_name.lower():
198
+ # BART summarization format
199
+ return """Summarize the answer to the question from the context. Use [chunk_X] for citations.
200
+
201
+ Context: """
202
+ else:
203
+ # Default instruction prompt for other models
204
+ return """You are a technical documentation assistant that provides clear, accurate answers based on the provided context.
205
+
206
+ CORE PRINCIPLES:
207
+ 1. ANSWER DIRECTLY: If context contains the answer, provide it clearly and confidently
208
+ 2. BE CONCISE: Keep responses focused and avoid unnecessary uncertainty language
209
+ 3. CITE ACCURATELY: Use [chunk_X] citations for every fact from context
210
+
211
+ RESPONSE GUIDELINES:
212
+ - If context has sufficient information → Answer directly and confidently
213
+ - If context has partial information → Answer what's available, note what's missing briefly
214
+ - If context is irrelevant → Brief refusal: "This information isn't available in the provided documents"
215
+
216
+ CITATION FORMAT:
217
+ - Use [chunk_1], [chunk_2] etc. for all facts from context
218
+ - Example: "According to [chunk_1], RISC-V is an open-source architecture."
219
+
220
+ Be direct, confident, and accurate. If the context answers the question, provide that answer clearly."""
221
+
222
+ def _format_context(self, chunks: List[Dict[str, Any]]) -> str:
223
+ """
224
+ Format retrieved chunks into context for the LLM.
225
+
226
+ Args:
227
+ chunks: List of retrieved chunks with metadata
228
+
229
+ Returns:
230
+ Formatted context string
231
+ """
232
+ context_parts = []
233
+
234
+ for i, chunk in enumerate(chunks):
235
+ chunk_text = chunk.get('content', chunk.get('text', ''))
236
+ page_num = chunk.get('metadata', {}).get('page_number', 'unknown')
237
+ source = chunk.get('metadata', {}).get('source', 'unknown')
238
+
239
+ context_parts.append(
240
+ f"[chunk_{i+1}] (Page {page_num} from {source}):\n{chunk_text}\n"
241
+ )
242
+
243
+ return "\n---\n".join(context_parts)
244
+
245
+ def _call_api(self, prompt: str) -> str:
246
+ """
247
+ Call HuggingFace Inference API.
248
+
249
+ Args:
250
+ prompt: Input prompt for the model
251
+
252
+ Returns:
253
+ Generated text response
254
+ """
255
+ # Validate prompt
256
+ if not prompt or len(prompt.strip()) < 5:
257
+ logger.warning(f"Prompt too short: '{prompt}' - padding it")
258
+ prompt = f"Please provide information about: {prompt}. Based on the context, give a detailed answer."
259
+
260
+ # Model-specific payload formatting
261
+ if "squad" in self.model_name.lower() or "roberta" in self.model_name.lower():
262
+ # RoBERTa Squad2 needs question and context separately
263
+ # Parse the structured prompt format we create
264
+ if "Context:" in prompt and "Question:" in prompt:
265
+ # Split by the markers we use
266
+ parts = prompt.split("Question:")
267
+ if len(parts) == 2:
268
+ context_part = parts[0].replace("Context:", "").strip()
269
+ question_part = parts[1].strip()
270
+ else:
271
+ # Fallback
272
+ question_part = "What is this about?"
273
+ context_part = prompt
274
+ else:
275
+ # Fallback for unexpected format
276
+ question_part = "What is this about?"
277
+ context_part = prompt
278
+
279
+ # Clean up the context and question
280
+ context_part = context_part.replace("---", "").strip()
281
+ if not question_part or len(question_part.strip()) < 3:
282
+ question_part = "What is the main information?"
283
+
284
+ # Debug output
285
+ print(f"🔍 Squad2 Question: {question_part[:100]}...")
286
+ print(f"🔍 Squad2 Context: {context_part[:200]}...")
287
+
288
+ payload = {
289
+ "inputs": {
290
+ "question": question_part,
291
+ "context": context_part
292
+ }
293
+ }
294
+ elif "bart" in self.model_name.lower() or "distilbart" in self.model_name.lower():
295
+ # BART/DistilBART for summarization
296
+ if len(prompt) < 50:
297
+ prompt = f"{prompt} Please provide a comprehensive answer based on the available information."
298
+
299
+ payload = {
300
+ "inputs": prompt,
301
+ "parameters": {
302
+ "max_length": 150,
303
+ "min_length": 10,
304
+ "do_sample": False
305
+ }
306
+ }
307
+ else:
308
+ # Simple payload for other models
309
+ payload = {"inputs": prompt}
310
+
311
+ try:
312
+ logger.info(f"Calling API URL: {self.api_url}")
313
+ logger.info(f"Headers: {self.headers}")
314
+ logger.info(f"Payload: {payload}")
315
+
316
+ # Use helper method with 401 handling
317
+ response = self._make_api_request(self.api_url, payload)
318
+
319
+ logger.info(f"Response status: {response.status_code}")
320
+ logger.info(f"Response headers: {response.headers}")
321
+
322
+ if response.status_code == 503:
323
+ # Model is loading, wait and retry
324
+ logger.warning("Model loading, waiting 20 seconds...")
325
+ import time
326
+ time.sleep(20)
327
+ response = self._make_api_request(self.api_url, payload)
328
+ logger.info(f"Retry response status: {response.status_code}")
329
+
330
+ elif response.status_code == 404:
331
+ logger.error(f"Model not found: {self.model_name}")
332
+ logger.error(f"Response text: {response.text}")
333
+ # Try fallback models
334
+ for fallback_model in self.fallback_models:
335
+ if fallback_model != self.model_name:
336
+ logger.info(f"Trying fallback model: {fallback_model}")
337
+ try:
338
+ return self._call_api_with_model(prompt, fallback_model)
339
+ except Exception as e:
340
+ logger.warning(f"Fallback model {fallback_model} failed: {e}")
341
+ continue
342
+ return "All models are currently unavailable. Please try again later."
343
+
344
+ response.raise_for_status()
345
+ result = response.json()
346
+
347
+ # Handle different response formats based on model type
348
+ print(f"🔍 API Response type: {type(result)}")
349
+ print(f"🔍 API Response preview: {str(result)[:300]}...")
350
+
351
+ if isinstance(result, dict) and "answer" in result:
352
+ # RoBERTa Squad2 format: {"answer": "...", "score": ..., "start": ..., "end": ...}
353
+ answer = result["answer"].strip()
354
+ print(f"🔍 Squad2 extracted answer: {answer}")
355
+ return answer
356
+ elif isinstance(result, list) and len(result) > 0:
357
+ # Check for DistilBART format (returns dict with summary_text)
358
+ if isinstance(result[0], dict) and "summary_text" in result[0]:
359
+ return result[0]["summary_text"].strip()
360
+ # Check for nested list (BART format: [[...]])
361
+ elif isinstance(result[0], list) and len(result[0]) > 0:
362
+ if isinstance(result[0][0], dict):
363
+ return result[0][0].get("summary_text", str(result[0][0])).strip()
364
+ else:
365
+ # BART base returns embeddings - not useful for text generation
366
+ logger.warning("BART returned embeddings instead of text")
367
+ return "Model returned embeddings instead of text. Please try a different model."
368
+ # Regular list format
369
+ elif isinstance(result[0], dict):
370
+ # Try different keys that models might use
371
+ text = (result[0].get("generated_text", "") or
372
+ result[0].get("summary_text", "") or
373
+ result[0].get("translation_text", "") or
374
+ result[0].get("answer", "") or
375
+ str(result[0]))
376
+ # Remove the input prompt from the output if present
377
+ if isinstance(prompt, str) and text.startswith(prompt):
378
+ text = text[len(prompt):].strip()
379
+ return text
380
+ else:
381
+ return str(result[0]).strip()
382
+ elif isinstance(result, dict):
383
+ # Some models return dict directly
384
+ text = (result.get("generated_text", "") or
385
+ result.get("summary_text", "") or
386
+ result.get("translation_text", "") or
387
+ result.get("answer", "") or
388
+ str(result))
389
+ # Remove input prompt if model included it
390
+ if isinstance(prompt, str) and text.startswith(prompt):
391
+ text = text[len(prompt):].strip()
392
+ return text
393
+ elif isinstance(result, str):
394
+ return result.strip()
395
+ else:
396
+ logger.error(f"Unexpected response format: {type(result)} - {result}")
397
+ return "I apologize, but I couldn't generate a response."
398
+
399
+ except requests.exceptions.RequestException as e:
400
+ logger.error(f"API request failed: {e}")
401
+ if hasattr(e, 'response') and e.response is not None:
402
+ logger.error(f"Response status: {e.response.status_code}")
403
+ logger.error(f"Response body: {e.response.text}")
404
+ return f"API Error: {str(e)}. Using free tier? Try adding an API token."
405
+ except Exception as e:
406
+ logger.error(f"Unexpected error: {e}")
407
+ import traceback
408
+ logger.error(f"Traceback: {traceback.format_exc()}")
409
+ return f"Error: {str(e)}. Please check logs for details."
410
+
411
+ def _extract_citations(self, answer: str, chunks: List[Dict[str, Any]]) -> Tuple[str, List[Citation]]:
412
+ """
413
+ Extract citations from the generated answer and integrate them naturally.
414
+
415
+ Args:
416
+ answer: Generated answer with [chunk_X] citations
417
+ chunks: Original chunks used for context
418
+
419
+ Returns:
420
+ Tuple of (natural_answer, citations)
421
+ """
422
+ citations = []
423
+ citation_pattern = r'\[chunk_(\d+)\]'
424
+
425
+ cited_chunks = set()
426
+
427
+ # Find [chunk_X] citations and collect cited chunks
428
+ matches = re.finditer(citation_pattern, answer)
429
+ for match in matches:
430
+ chunk_idx = int(match.group(1)) - 1 # Convert to 0-based index
431
+ if 0 <= chunk_idx < len(chunks):
432
+ cited_chunks.add(chunk_idx)
433
+
434
+ # FALLBACK: If no explicit citations found but we have an answer and chunks,
435
+ # create citations for the top chunks that were likely used
436
+ if not cited_chunks and chunks and len(answer.strip()) > 50:
437
+ # Use the top chunks that were provided as likely sources
438
+ num_fallback_citations = min(3, len(chunks)) # Use top 3 chunks max
439
+ cited_chunks = set(range(num_fallback_citations))
440
+ print(f"🔧 HF Fallback: Creating {num_fallback_citations} citations for answer without explicit [chunk_X] references", file=sys.stderr, flush=True)
441
+
442
+ # Create Citation objects for each cited chunk
443
+ chunk_to_source = {}
444
+ for idx in cited_chunks:
445
+ chunk = chunks[idx]
446
+ citation = Citation(
447
+ chunk_id=chunk.get('id', f'chunk_{idx}'),
448
+ page_number=chunk.get('metadata', {}).get('page_number', 0),
449
+ source_file=chunk.get('metadata', {}).get('source', 'unknown'),
450
+ relevance_score=chunk.get('score', 0.0),
451
+ text_snippet=chunk.get('content', chunk.get('text', ''))[:200] + '...'
452
+ )
453
+ citations.append(citation)
454
+
455
+ # Map chunk reference to natural source name
456
+ source_name = chunk.get('metadata', {}).get('source', 'unknown')
457
+ if source_name != 'unknown':
458
+ # Use just the filename without extension for natural reference
459
+ natural_name = Path(source_name).stem.replace('-', ' ').replace('_', ' ')
460
+ chunk_to_source[f'[chunk_{idx+1}]'] = f"the {natural_name} documentation"
461
+ else:
462
+ chunk_to_source[f'[chunk_{idx+1}]'] = "the documentation"
463
+
464
+ # Replace [chunk_X] with natural references instead of removing them
465
+ natural_answer = answer
466
+ for chunk_ref, natural_ref in chunk_to_source.items():
467
+ natural_answer = natural_answer.replace(chunk_ref, natural_ref)
468
+
469
+ # Clean up any remaining unreferenced citations (fallback)
470
+ natural_answer = re.sub(r'\[chunk_\d+\]', 'the documentation', natural_answer)
471
+
472
+ # Clean up multiple spaces and formatting
473
+ natural_answer = re.sub(r'\s+', ' ', natural_answer).strip()
474
+
475
+ return natural_answer, citations
476
+
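To illustrate the rewrite with hypothetical values: if chunks[0] comes from riscv-spec.pdf, Path('riscv-spec.pdf').stem yields 'riscv-spec', so:

    # in:  "According to [chunk_1], RISC-V is an open-source ISA."
    # out: "According to the riscv spec documentation, RISC-V is an open-source ISA."
    # plus one Citation(chunk_id='chunk_1', source_file='riscv-spec.pdf', ...)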
477
+ def _calculate_confidence(self, answer: str, citations: List[Citation], chunks: List[Dict[str, Any]]) -> float:
478
+ """
479
+ Calculate confidence score for the generated answer.
480
+
481
+ Args:
482
+ answer: Generated answer
483
+ citations: Extracted citations
484
+ chunks: Retrieved chunks
485
+
486
+ Returns:
487
+ Confidence score (0.0-1.0)
488
+ """
489
+ if not chunks:
490
+ return 0.05 # No context = very low confidence
491
+
492
+ # Base confidence from context quality
493
+ scores = [chunk.get('score', 0) for chunk in chunks]
494
+ max_relevance = max(scores) if scores else 0
495
+
496
+ if max_relevance >= 0.8:
497
+ confidence = 0.7 # High-quality context
498
+ elif max_relevance >= 0.6:
499
+ confidence = 0.5 # Good context
500
+ elif max_relevance >= 0.4:
501
+ confidence = 0.3 # Fair context
502
+ else:
503
+ confidence = 0.1 # Poor context
504
+
505
+ # Uncertainty indicators
506
+ uncertainty_phrases = [
507
+ "does not contain sufficient information",
508
+ "context does not provide",
509
+ "insufficient information",
510
+ "cannot determine",
511
+ "not available in the provided documents"
512
+ ]
513
+
514
+ if any(phrase in answer.lower() for phrase in uncertainty_phrases):
515
+ return min(0.15, confidence * 0.3)
516
+
517
+ # Citation bonus
518
+ if citations and chunks:
519
+ citation_ratio = len(citations) / min(len(chunks), 3)
520
+ confidence += 0.2 * citation_ratio
521
+
522
+ return min(confidence, 0.9) # Cap at 90%
523
+
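A hypothetical worked example of this scoring: a top chunk score of 0.85 sets the base at 0.7; with no uncertainty phrases and 2 citations against 3 chunks, the bonus is 0.2 × (2/3) ≈ 0.13, giving min(0.83, 0.9) ≈ 0.83.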
524
+ def generate(self, query: str, context: List[Document]) -> Answer:
525
+ """
526
+ Generate an answer from query and context documents (standard interface).
527
+
528
+ This is the public interface that conforms to the AnswerGenerator protocol.
529
+ It handles the conversion between standard Document objects and HuggingFace's
530
+ internal chunk format.
531
+
532
+ Args:
533
+ query: User's question
534
+ context: List of relevant Document objects
535
+
536
+ Returns:
537
+ Answer object conforming to standard interface
538
+
539
+ Raises:
540
+ ValueError: If query is empty or context is None
541
+ """
542
+ if not query.strip():
543
+ raise ValueError("Query cannot be empty")
544
+
545
+ if context is None:
546
+ raise ValueError("Context cannot be None")
547
+
548
+ # Internal adapter: Convert Documents to HuggingFace chunk format
549
+ hf_chunks = self._documents_to_hf_chunks(context)
550
+
551
+ # Use existing HuggingFace-specific generation logic
552
+ hf_result = self._generate_internal(query, hf_chunks)
553
+
554
+ # Internal adapter: Convert HuggingFace result to standard Answer
555
+ return self._hf_result_to_answer(hf_result, context)
556
+
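A minimal sketch of this standard interface, assuming src.core.interfaces resolves and that Document accepts content and metadata keyword arguments (its constructor is not shown in this diff):

    from src.core.interfaces import Document

    doc = Document(content="RISC-V is an open ISA.",
                   metadata={"source": "riscv-spec.pdf", "start_page": 1})
    answer = generator.generate("What is RISC-V?", [doc])
    print(answer.text, answer.confidence)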
557
+ def _generate_internal(
558
+ self,
559
+ query: str,
560
+ chunks: List[Dict[str, Any]]
561
+ ) -> GeneratedAnswer:
562
+ """
563
+ Generate an answer based on the query and retrieved chunks.
564
+
565
+ Args:
566
+ query: User's question
567
+ chunks: Retrieved document chunks
568
+
569
+ Returns:
570
+ GeneratedAnswer object with answer, citations, and metadata
571
+ """
572
+ start_time = datetime.now()
573
+
574
+ # Check for no-context situation
575
+ if not chunks or all(len(chunk.get('content', chunk.get('text', ''))) < 20 for chunk in chunks):
576
+ return GeneratedAnswer(
577
+ answer="This information isn't available in the provided documents.",
578
+ citations=[],
579
+ confidence_score=0.05,
580
+ generation_time=0.1,
581
+ model_used=self.model_name,
582
+ context_used=chunks
583
+ )
584
+
585
+ # Format context from chunks
586
+ context = self._format_context(chunks)
587
+
588
+ # Create prompt using TechnicalPromptTemplates for consistency
589
+ prompt_data = TechnicalPromptTemplates.format_prompt_with_template(
590
+ query=query,
591
+ context=context
592
+ )
593
+
594
+ # Format for specific model types
595
+ if "squad" in self.model_name.lower() or "roberta" in self.model_name.lower():
596
+ # Squad2 uses special question/context format - handled in _call_api
597
+ prompt = f"Context: {context}\n\nQuestion: {query}"
598
+ elif "gpt2" in self.model_name.lower() or "distilgpt2" in self.model_name.lower():
599
+ # Completion-style prompt for GPT-2: full system + user text with a citation reminder
600
+ prompt = f"""{prompt_data['system']}
601
+
602
+ {prompt_data['user']}
603
+
604
+ MANDATORY: Use [chunk_1], [chunk_2] etc. for all facts.
605
+
606
+ Answer:"""
607
+ elif "llama" in self.model_name.lower():
608
+ # Llama-2 chat format with technical templates
609
+ prompt = f"""[INST] {prompt_data['system']}
610
+
611
+ {prompt_data['user']}
612
+
613
+ MANDATORY: Use [chunk_1], [chunk_2] etc. for all facts. [/INST]"""
614
+ elif "mistral" in self.model_name.lower():
615
+ # Mistral instruction format with technical templates
616
+ prompt = f"""[INST] {prompt_data['system']}
617
+
618
+ {prompt_data['user']}
619
+
620
+ MANDATORY: Use [chunk_1], [chunk_2] etc. for all facts. [/INST]"""
621
+ elif "codellama" in self.model_name.lower():
622
+ # CodeLlama instruction format with technical templates
623
+ prompt = f"""[INST] {prompt_data['system']}
624
+
625
+ {prompt_data['user']}
626
+
627
+ MANDATORY: Use [chunk_1], [chunk_2] etc. for all facts. [/INST]"""
628
+ elif "distilbart" in self.model_name.lower():
629
+ # DistilBART is a summarization model - simpler prompt works better
630
+ prompt = f"""Technical Documentation Context:
631
+ {context}
632
+
633
+ Question: {query}
634
+
635
+ Instructions: Provide a technical answer using only the context above. Include source citations."""
636
+ else:
637
+ # Default instruction prompt with technical templates
638
+ prompt = f"""{prompt_data['system']}
639
+
640
+ {prompt_data['user']}
641
+
642
+ MANDATORY: Use [chunk_1], [chunk_2] etc. for all factual statements.
643
+
644
+ Answer:"""
645
+
646
+ # Generate response
647
+ try:
648
+ answer_with_citations = self._call_api(prompt)
649
+
650
+ # Extract and clean citations
651
+ clean_answer, citations = self._extract_citations(answer_with_citations, chunks)
652
+
653
+ # Calculate confidence
654
+ confidence = self._calculate_confidence(clean_answer, citations, chunks)
655
+
656
+ # Calculate generation time
657
+ generation_time = (datetime.now() - start_time).total_seconds()
658
+
659
+ return GeneratedAnswer(
660
+ answer=clean_answer,
661
+ citations=citations,
662
+ confidence_score=confidence,
663
+ generation_time=generation_time,
664
+ model_used=self.model_name,
665
+ context_used=chunks
666
+ )
667
+
668
+ except Exception as e:
669
+ logger.error(f"Error generating answer: {e}")
670
+ return GeneratedAnswer(
671
+ answer="I apologize, but I encountered an error while generating the answer. Please try again.",
672
+ citations=[],
673
+ confidence_score=0.0,
674
+ generation_time=0.0,
675
+ model_used=self.model_name,
676
+ context_used=chunks
677
+ )
678
+
679
+ def generate_with_custom_prompt(
680
+ self,
681
+ query: str,
682
+ chunks: List[Dict[str, Any]],
683
+ custom_prompt: Dict[str, str]
684
+ ) -> GeneratedAnswer:
685
+ """
686
+ Generate answer using a custom prompt (for adaptive prompting).
687
+
688
+ Args:
689
+ query: User's question
690
+ chunks: Retrieved context chunks
691
+ custom_prompt: Dict with 'system' and 'user' prompts
692
+
693
+ Returns:
694
+ GeneratedAnswer with custom prompt enhancement
695
+ """
696
+ start_time = datetime.now()
697
+
698
+ # Format context
699
+ context = self._format_context(chunks)
700
+
701
+ # Build prompt using custom format
702
+ if "llama" in self.model_name.lower():
703
+ prompt = f"""[INST] {custom_prompt['system']}
704
+
705
+ {custom_prompt['user']}
706
+
707
+ MANDATORY: Use [chunk_1], [chunk_2] etc. for all facts. [/INST]"""
708
+ elif "mistral" in self.model_name.lower():
709
+ prompt = f"""[INST] {custom_prompt['system']}
710
+
711
+ {custom_prompt['user']}
712
+
713
+ MANDATORY: Use [chunk_1], [chunk_2] etc. for all facts. [/INST]"""
714
+ elif "distilbart" in self.model_name.lower():
715
+ # For BART, use the user prompt directly (it already contains context)
716
+ prompt = custom_prompt['user']
717
+ else:
718
+ # Default format
719
+ prompt = f"""{custom_prompt['system']}
720
+
721
+ {custom_prompt['user']}
722
+
723
+ MANDATORY: Use [chunk_1], [chunk_2] etc. for all factual statements.
724
+
725
+ Answer:"""
726
+
727
+ # Generate response
728
+ try:
729
+ answer_with_citations = self._call_api(prompt)
730
+
731
+ # Extract and clean citations
732
+ clean_answer, citations = self._extract_citations(answer_with_citations, chunks)
733
+
734
+ # Calculate confidence
735
+ confidence = self._calculate_confidence(clean_answer, citations, chunks)
736
+
737
+ # Calculate generation time
738
+ generation_time = (datetime.now() - start_time).total_seconds()
739
+
740
+ return GeneratedAnswer(
741
+ answer=clean_answer,
742
+ citations=citations,
743
+ confidence_score=confidence,
744
+ generation_time=generation_time,
745
+ model_used=self.model_name,
746
+ context_used=chunks
747
+ )
748
+
749
+ except Exception as e:
750
+ logger.error(f"Error generating answer with custom prompt: {e}")
751
+ return GeneratedAnswer(
752
+ answer="I apologize, but I encountered an error while generating the answer. Please try again.",
753
+ citations=[],
754
+ confidence_score=0.0,
755
+ generation_time=0.0,
756
+ model_used=self.model_name,
757
+ context_used=chunks
758
+ )
759
+
760
+ def format_answer_with_citations(self, generated_answer: GeneratedAnswer) -> str:
761
+ """
762
+ Format the generated answer with citations for display.
763
+
764
+ Args:
765
+ generated_answer: GeneratedAnswer object
766
+
767
+ Returns:
768
+ Formatted string with answer and citations
769
+ """
770
+ formatted = f"{generated_answer.answer}\n\n"
771
+
772
+ if generated_answer.citations:
773
+ formatted += "**Sources:**\n"
774
+ for i, citation in enumerate(generated_answer.citations, 1):
775
+ formatted += f"{i}. {citation.source_file} (Page {citation.page_number})\n"
776
+
777
+ formatted += f"\n*Confidence: {generated_answer.confidence_score:.1%} | "
778
+ formatted += f"Model: {generated_answer.model_used} | "
779
+ formatted += f"Time: {generated_answer.generation_time:.2f}s*"
780
+
781
+ return formatted
782
+
783
+ def _documents_to_hf_chunks(self, documents: List[Document]) -> List[Dict[str, Any]]:
784
+ """
785
+ Convert Document objects to HuggingFace's internal chunk format.
786
+
787
+ This internal adapter ensures that Document objects are properly formatted
788
+ for HuggingFace's processing pipeline while keeping the format requirements
789
+ encapsulated within this class.
790
+
791
+ Args:
792
+ documents: List of Document objects from the standard interface
793
+
794
+ Returns:
795
+ List of chunk dictionaries in HuggingFace's expected format
796
+ """
797
+ if not documents:
798
+ return []
799
+
800
+ chunks = []
801
+ for i, doc in enumerate(documents):
802
+ chunk = {
803
+ "id": f"chunk_{i+1}",
804
+ "content": doc.content, # HuggingFace expects "content" field
805
+ "text": doc.content, # Alternative field for compatibility
806
+ "score": 1.0, # Default relevance score
807
+ "metadata": {
808
+ "page_number": doc.metadata.get("start_page", 1),
809
+ "source": doc.metadata.get("source", "unknown"),
810
+ **doc.metadata # Include all original metadata
811
+ }
812
+ }
813
+ chunks.append(chunk)
814
+
815
+ return chunks
816
+
817
+ def _hf_result_to_answer(self, hf_result: GeneratedAnswer, original_context: List[Document]) -> Answer:
818
+ """
819
+ Convert HuggingFace's GeneratedAnswer to the standard Answer format.
820
+
821
+ This internal adapter converts HuggingFace's result format back to the
822
+ standard interface format expected by the rest of the system.
823
+
824
+ Args:
825
+ hf_result: Result from HuggingFace's internal generation
826
+ original_context: Original Document objects for sources
827
+
828
+ Returns:
829
+ Answer object conforming to standard interface
830
+ """
831
+ if Answer is None:
832
+ # Fallback if standard interface not available
833
+ return hf_result
834
+
835
+ # Convert to standard Answer format
836
+ return Answer(
837
+ text=hf_result.answer,
838
+ sources=original_context, # Use original Document objects
839
+ confidence=hf_result.confidence_score,
840
+ metadata={
841
+ "model_used": hf_result.model_used,
842
+ "generation_time": hf_result.generation_time,
843
+ "citations": [
844
+ {
845
+ "chunk_id": cit.chunk_id,
846
+ "page_number": cit.page_number,
847
+ "source_file": cit.source_file,
848
+ "relevance_score": cit.relevance_score,
849
+ "text_snippet": cit.text_snippet
850
+ }
851
+ for cit in hf_result.citations
852
+ ],
853
+ "provider": "huggingface",
854
+ "api_token_used": bool(self.api_token),
855
+ "fallback_used": hasattr(self, '_auth_failed') and self._auth_failed
856
+ }
857
+ )
858
+
859
+
860
+ if __name__ == "__main__":
861
+ # Example usage
862
+ generator = HuggingFaceAnswerGenerator()
863
+
864
+ # Example chunks (would come from retrieval system)
865
+ example_chunks = [
866
+ {
867
+ "id": "chunk_1",
868
+ "content": "RISC-V is an open-source instruction set architecture (ISA) based on reduced instruction set computer (RISC) principles.",
869
+ "metadata": {"page_number": 1, "source": "riscv-spec.pdf"},
870
+ "score": 0.95
871
+ }
872
+ ]
873
+
874
+ # Generate answer
875
+ result = generator._generate_internal(  # generate() expects Document objects; these example dicts use the internal chunk format
876
+ query="What is RISC-V?",
877
+ chunks=example_chunks
878
+ )
879
+
880
+ # Display formatted result
881
+ print(generator.format_answer_with_citations(result))
shared_utils/generation/inference_providers_generator.py ADDED
@@ -0,0 +1,537 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ HuggingFace Inference Providers API-based answer generation.
4
+
5
+ This module provides answer generation using HuggingFace's new Inference Providers API,
6
+ which offers OpenAI-compatible chat completion format for better reliability and consistency.
7
+ """
8
+
9
+ import os
10
+ import sys
11
+ import logging
12
+ import time
13
+ from datetime import datetime
14
+ from typing import List, Dict, Any, Optional, Tuple
15
+ from pathlib import Path
16
+ import re
17
+
18
+ # Import shared components
19
+ from .hf_answer_generator import Citation, GeneratedAnswer
20
+ from .prompt_templates import TechnicalPromptTemplates
21
+
22
+ # Check if huggingface_hub is new enough for InferenceClient chat completion
23
+ try:
24
+ from huggingface_hub import InferenceClient
25
+ from huggingface_hub import __version__ as hf_hub_version
26
+ print(f"🔍 Using huggingface_hub version: {hf_hub_version}", file=sys.stderr, flush=True)
27
+ except ImportError:
28
+ print("❌ huggingface_hub not found or outdated. Please install: pip install -U huggingface-hub", file=sys.stderr, flush=True)
29
+ raise
30
+
31
+ logger = logging.getLogger(__name__)
32
+
33
+
34
+ class InferenceProvidersGenerator:
35
+ """
36
+ Generates answers using HuggingFace Inference Providers API.
37
+
38
+ This uses the new OpenAI-compatible chat completion format for better reliability
39
+ compared to the classic Inference API. It provides:
40
+ - Consistent response format across models
41
+ - Better error handling and retry logic
42
+ - Support for streaming responses
43
+ - Automatic provider selection and failover
44
+ """
45
+
46
+ # Models that work well with chat completion format
47
+ CHAT_MODELS = [
48
+ "microsoft/DialoGPT-medium", # Proven conversational model
49
+ "google/gemma-2-2b-it", # Instruction-tuned, good for Q&A
50
+ "meta-llama/Llama-3.2-3B-Instruct", # If available with token
51
+ "Qwen/Qwen2.5-1.5B-Instruct", # Small, fast, good quality
52
+ ]
53
+
54
+ # Fallback to classic API models if chat completion fails
55
+ CLASSIC_FALLBACK_MODELS = [
56
+ "google/flan-t5-small", # Good for instructions
57
+ "deepset/roberta-base-squad2", # Q&A specific
58
+ "facebook/bart-base", # Summarization
59
+ ]
60
+
61
+ def __init__(
62
+ self,
63
+ model_name: Optional[str] = None,
64
+ api_token: Optional[str] = None,
65
+ temperature: float = 0.3,
66
+ max_tokens: int = 512,
67
+ timeout: int = 30
68
+ ):
69
+ """
70
+ Initialize the Inference Providers answer generator.
71
+
72
+ Args:
73
+ model_name: Model to use (defaults to first available chat model)
74
+ api_token: HF API token (uses env vars if not provided)
75
+ temperature: Generation temperature (0.0-1.0)
76
+ max_tokens: Maximum tokens to generate
77
+ timeout: Request timeout in seconds
78
+ """
79
+ # Get API token from various sources
80
+ self.api_token = (
81
+ api_token or
82
+ os.getenv("HUGGINGFACE_API_TOKEN") or
83
+ os.getenv("HF_TOKEN") or
84
+ os.getenv("HF_API_TOKEN")
85
+ )
86
+
87
+ if not self.api_token:
88
+ print("⚠️ No HF API token found. Inference Providers requires authentication.", file=sys.stderr, flush=True)
89
+ print("Set HF_TOKEN, HUGGINGFACE_API_TOKEN, or HF_API_TOKEN environment variable.", file=sys.stderr, flush=True)
90
+ raise ValueError("HuggingFace API token required for Inference Providers")
91
+
92
+ print(f"✅ Found HF token (starts with: {self.api_token[:8]}...)", file=sys.stderr, flush=True)
93
+
94
+ # Initialize client with token
95
+ self.client = InferenceClient(token=self.api_token)
96
+ self.temperature = temperature
97
+ self.max_tokens = max_tokens
98
+ self.timeout = timeout
99
+
100
+ # Select model
101
+ self.model_name = model_name or self.CHAT_MODELS[0]
102
+ self.using_chat_completion = True
103
+
104
+ print(f"🚀 Initialized Inference Providers with model: {self.model_name}", file=sys.stderr, flush=True)
105
+
106
+ # Test the connection
107
+ self._test_connection()
108
+
109
+ def _test_connection(self):
110
+ """Test if the API is accessible and model is available."""
111
+ print(f"🔧 Testing Inference Providers API connection...", file=sys.stderr, flush=True)
112
+
113
+ try:
114
+ # Try a simple test query
115
+ test_messages = [
116
+ {"role": "user", "content": "Hello"}
117
+ ]
118
+
119
+ # First try chat completion (preferred)
120
+ try:
121
+ response = self.client.chat_completion(
122
+ messages=test_messages,
123
+ model=self.model_name,
124
+ max_tokens=10,
125
+ temperature=0.1
126
+ )
127
+ print(f"✅ Chat completion API working with {self.model_name}", file=sys.stderr, flush=True)
128
+ self.using_chat_completion = True
129
+ return
130
+ except Exception as e:
131
+ print(f"��️ Chat completion failed for {self.model_name}: {e}", file=sys.stderr, flush=True)
132
+
133
+ # Try other chat models
134
+ for model in self.CHAT_MODELS:
135
+ if model != self.model_name:
136
+ try:
137
+ print(f"🔄 Trying {model}...", file=sys.stderr, flush=True)
138
+ response = self.client.chat_completion(
139
+ messages=test_messages,
140
+ model=model,
141
+ max_tokens=10
142
+ )
143
+ print(f"✅ Found working model: {model}", file=sys.stderr, flush=True)
144
+ self.model_name = model
145
+ self.using_chat_completion = True
146
+ return
147
+ except:
148
+ continue
149
+
150
+ # If chat completion fails, test classic text generation
151
+ print("🔄 Falling back to classic text generation API...", file=sys.stderr, flush=True)
152
+ for model in self.CLASSIC_FALLBACK_MODELS:
153
+ try:
154
+ response = self.client.text_generation(
155
+ model=model,
156
+ prompt="Hello",
157
+ max_new_tokens=10
158
+ )
159
+ print(f"✅ Classic API working with fallback model: {model}", file=sys.stderr, flush=True)
160
+ self.model_name = model
161
+ self.using_chat_completion = False
162
+ return
163
+ except:
164
+ continue
165
+
166
+ raise Exception("No working models found in Inference Providers API")
167
+
168
+ except Exception as e:
169
+ print(f"❌ Inference Providers API test failed: {e}", file=sys.stderr, flush=True)
170
+ raise
171
+
172
+ def _format_context(self, chunks: List[Dict[str, Any]]) -> str:
173
+ """Format retrieved chunks into context string."""
174
+ context_parts = []
175
+
176
+ for i, chunk in enumerate(chunks):
177
+ chunk_text = chunk.get('content', chunk.get('text', ''))
178
+ page_num = chunk.get('metadata', {}).get('page_number', 'unknown')
179
+ source = chunk.get('metadata', {}).get('source', 'unknown')
180
+
181
+ context_parts.append(
182
+ f"[chunk_{i+1}] (Page {page_num} from {source}):\n{chunk_text}\n"
183
+ )
184
+
185
+ return "\n---\n".join(context_parts)
186
+
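+ # Illustrative sketch: a single chunk from page 1 of riscv-spec.pdf renders as
+ # "[chunk_1] (Page 1 from riscv-spec.pdf):\nRISC-V is an open-source ISA...\n",
+ # and multiple chunks are joined with "\n---\n" separators.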
187
+ def _create_messages(self, query: str, context: str) -> List[Dict[str, str]]:
188
+ """Create chat messages using TechnicalPromptTemplates."""
189
+ # Get appropriate template based on query type
190
+ prompt_data = TechnicalPromptTemplates.format_prompt_with_template(
191
+ query=query,
192
+ context=context
193
+ )
194
+
195
+ # Create messages for chat completion
196
+ messages = [
197
+ {
198
+ "role": "system",
199
+ "content": prompt_data['system'] + "\n\nMANDATORY: Use [chunk_X] citations for all facts."
200
+ },
201
+ {
202
+ "role": "user",
203
+ "content": prompt_data['user']
204
+ }
205
+ ]
206
+
207
+ return messages
208
+
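+ # Illustrative sketch: for query "What is RISC-V?" the returned list resembles
+ # [{"role": "system", "content": "<template system prompt>\n\nMANDATORY: Use [chunk_X] citations for all facts."},
+ # {"role": "user", "content": "<template user prompt embedding the formatted context>"}]
+ # where the template text is supplied by TechnicalPromptTemplates.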
209
+ def _call_chat_completion(self, messages: List[Dict[str, str]]) -> str:
210
+ """Call the chat completion API."""
211
+ try:
212
+ print(f"🤖 Calling Inference Providers chat completion with {self.model_name}...", file=sys.stderr, flush=True)
213
+
214
+ # Use chat completion with proper error handling
215
+ response = self.client.chat_completion(
216
+ messages=messages,
217
+ model=self.model_name,
218
+ temperature=self.temperature,
219
+ max_tokens=self.max_tokens,
220
+ stream=False
221
+ )
222
+
223
+ # Extract content from response
224
+ if hasattr(response, 'choices') and response.choices:
225
+ content = response.choices[0].message.content
226
+ print(f"✅ Got response: {len(content)} characters", file=sys.stderr, flush=True)
227
+ return content
228
+ else:
229
+ print(f"⚠️ Unexpected response format: {response}", file=sys.stderr, flush=True)
230
+ return str(response)
231
+
232
+ except Exception as e:
233
+ print(f"❌ Chat completion error: {e}", file=sys.stderr, flush=True)
234
+
235
+ # Try with a fallback model
236
+ if self.model_name != "microsoft/DialoGPT-medium":
237
+ print("🔄 Trying fallback model: microsoft/DialoGPT-medium", file=sys.stderr, flush=True)
238
+ try:
239
+ response = self.client.chat_completion(
240
+ messages=messages,
241
+ model="microsoft/DialoGPT-medium",
242
+ temperature=self.temperature,
243
+ max_tokens=self.max_tokens
244
+ )
245
+ if hasattr(response, 'choices') and response.choices:
246
+ return response.choices[0].message.content
247
+ except:
248
+ pass
249
+
250
+ raise Exception(f"Chat completion failed: {e}")
251
+
252
+ def _call_classic_api(self, query: str, context: str) -> str:
253
+ """Fallback to classic text generation API."""
254
+ print(f"🔄 Using classic text generation with {self.model_name}...", file=sys.stderr, flush=True)
255
+
256
+ # Format prompt for classic API
257
+ if "squad" in self.model_name.lower():
258
+ # Q&A format for squad models
259
+ prompt = f"Context: {context}\n\nQuestion: {query}\n\nAnswer:"
260
+ elif "flan" in self.model_name.lower():
261
+ # Instruction format for Flan models
262
+ prompt = f"Answer the question based on the context.\n\nContext: {context}\n\nQuestion: {query}\n\nAnswer:"
263
+ else:
264
+ # Generic format
265
+ prompt = f"Based on the following context, answer the question.\n\nContext:\n{context}\n\nQuestion: {query}\n\nAnswer:"
266
+
267
+ try:
268
+ response = self.client.text_generation(
269
+ model=self.model_name,
270
+ prompt=prompt,
271
+ max_new_tokens=self.max_tokens,
272
+ temperature=self.temperature
273
+ )
274
+ return response
275
+ except Exception as e:
276
+ print(f"❌ Classic API error: {e}", file=sys.stderr, flush=True)
277
+ return f"Error generating response: {str(e)}"
278
+
279
+ def _extract_citations(self, answer: str, chunks: List[Dict[str, Any]]) -> Tuple[str, List[Citation]]:
280
+ """Extract citations from the answer."""
281
+ citations = []
282
+ citation_pattern = r'\[chunk_(\d+)\]'
283
+
284
+ cited_chunks = set()
285
+
286
+ # Find explicit citations
287
+ matches = re.finditer(citation_pattern, answer)
288
+ for match in matches:
289
+ chunk_idx = int(match.group(1)) - 1
290
+ if 0 <= chunk_idx < len(chunks):
291
+ cited_chunks.add(chunk_idx)
292
+
293
+ # Fallback: Create citations for top chunks if none found
294
+ if not cited_chunks and chunks and len(answer.strip()) > 50:
295
+ num_fallback = min(3, len(chunks))
296
+ cited_chunks = set(range(num_fallback))
297
+ print(f"🔧 Creating {num_fallback} fallback citations", file=sys.stderr, flush=True)
298
+
299
+ # Create Citation objects
300
+ chunk_to_source = {}
301
+ for idx in cited_chunks:
302
+ chunk = chunks[idx]
303
+ citation = Citation(
304
+ chunk_id=chunk.get('id', f'chunk_{idx}'),
305
+ page_number=chunk.get('metadata', {}).get('page_number', 0),
306
+ source_file=chunk.get('metadata', {}).get('source', 'unknown'),
307
+ relevance_score=chunk.get('score', 0.0),
308
+ text_snippet=chunk.get('content', chunk.get('text', ''))[:200] + '...'
309
+ )
310
+ citations.append(citation)
311
+
312
+ # Map for natural language replacement
313
+ source_name = chunk.get('metadata', {}).get('source', 'unknown')
314
+ if source_name != 'unknown':
315
+ natural_name = Path(source_name).stem.replace('-', ' ').replace('_', ' ')
316
+ chunk_to_source[f'[chunk_{idx+1}]'] = f"the {natural_name} documentation"
317
+ else:
318
+ chunk_to_source[f'[chunk_{idx+1}]'] = "the documentation"
319
+
320
+ # Replace citations with natural language
321
+ natural_answer = answer
322
+ for chunk_ref, natural_ref in chunk_to_source.items():
323
+ natural_answer = natural_answer.replace(chunk_ref, natural_ref)
324
+
325
+ # Clean up any remaining citations
326
+ natural_answer = re.sub(r'\[chunk_\d+\]', 'the documentation', natural_answer)
327
+ natural_answer = re.sub(r'\s+', ' ', natural_answer).strip()
328
+
329
+ return natural_answer, citations
330
+
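+ # Illustrative sketch: given answer "RISC-V is open [chunk_1]." and one chunk from
+ # riscv-spec.pdf, this returns ("RISC-V is open the riscv spec documentation.",
+ # [Citation(chunk_id="chunk_1", ...)]), since the stem "riscv-spec" is rewritten
+ # to the natural phrase "the riscv spec documentation".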
331
+ def _calculate_confidence(self, answer: str, citations: List[Citation], chunks: List[Dict[str, Any]]) -> float:
332
+ """Calculate confidence score for the answer."""
333
+ if not answer or len(answer.strip()) < 10:
334
+ return 0.1
335
+
336
+ # Base confidence from chunk quality
337
+ if len(chunks) >= 3:
338
+ confidence = 0.8
339
+ elif len(chunks) >= 2:
340
+ confidence = 0.7
341
+ else:
342
+ confidence = 0.6
343
+
344
+ # Citation bonus
345
+ if citations and chunks:
346
+ citation_ratio = len(citations) / min(len(chunks), 3)
347
+ confidence += 0.15 * citation_ratio
348
+
349
+ # Check for uncertainty phrases
350
+ uncertainty_phrases = [
351
+ "insufficient information",
352
+ "cannot determine",
353
+ "not available in the provided documents",
354
+ "i don't know",
355
+ "unclear"
356
+ ]
357
+
358
+ if any(phrase in answer.lower() for phrase in uncertainty_phrases):
359
+ confidence *= 0.3
360
+
361
+ return min(confidence, 0.95)
362
+
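+ # Worked example: 3 chunks -> base 0.8; 2 citations -> +0.15 * (2/3) = +0.10;
+ # no uncertainty phrases -> confidence 0.90, under the 0.95 cap.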
363
+ def generate(self, query: str, chunks: List[Dict[str, Any]]) -> GeneratedAnswer:
364
+ """
365
+ Generate an answer using Inference Providers API.
366
+
367
+ Args:
368
+ query: User's question
369
+ chunks: Retrieved document chunks
370
+
371
+ Returns:
372
+ GeneratedAnswer with answer, citations, and metadata
373
+ """
374
+ start_time = datetime.now()
375
+
376
+ # Check for no-context situation
377
+ if not chunks or all(len(chunk.get('content', chunk.get('text', ''))) < 20 for chunk in chunks):
378
+ return GeneratedAnswer(
379
+ answer="This information isn't available in the provided documents.",
380
+ citations=[],
381
+ confidence_score=0.05,
382
+ generation_time=0.1,
383
+ model_used=self.model_name,
384
+ context_used=chunks
385
+ )
386
+
387
+ # Format context
388
+ context = self._format_context(chunks)
389
+
390
+ # Generate answer
391
+ try:
392
+ if self.using_chat_completion:
393
+ # Create chat messages
394
+ messages = self._create_messages(query, context)
395
+
396
+ # Call chat completion API
397
+ answer_text = self._call_chat_completion(messages)
398
+ else:
399
+ # Fallback to classic API
400
+ answer_text = self._call_classic_api(query, context)
401
+
402
+ # Extract citations and clean answer
403
+ natural_answer, citations = self._extract_citations(answer_text, chunks)
404
+
405
+ # Calculate confidence
406
+ confidence = self._calculate_confidence(natural_answer, citations, chunks)
407
+
408
+ generation_time = (datetime.now() - start_time).total_seconds()
409
+
410
+ return GeneratedAnswer(
411
+ answer=natural_answer,
412
+ citations=citations,
413
+ confidence_score=confidence,
414
+ generation_time=generation_time,
415
+ model_used=self.model_name,
416
+ context_used=chunks
417
+ )
418
+
419
+ except Exception as e:
420
+ logger.error(f"Error generating answer: {e}")
421
+ print(f"❌ Generation failed: {e}", file=sys.stderr, flush=True)
422
+
423
+ # Return error response
424
+ return GeneratedAnswer(
425
+ answer="I apologize, but I encountered an error while generating the answer. Please try again.",
426
+ citations=[],
427
+ confidence_score=0.0,
428
+ generation_time=(datetime.now() - start_time).total_seconds(),
429
+ model_used=self.model_name,
430
+ context_used=chunks
431
+ )
432
+
433
+ def generate_with_custom_prompt(
434
+ self,
435
+ query: str,
436
+ chunks: List[Dict[str, Any]],
437
+ custom_prompt: Dict[str, str]
438
+ ) -> GeneratedAnswer:
439
+ """
440
+ Generate answer using a custom prompt (for adaptive prompting).
441
+
442
+ Args:
443
+ query: User's question
444
+ chunks: Retrieved context chunks
445
+ custom_prompt: Dict with 'system' and 'user' prompts
446
+
447
+ Returns:
448
+ GeneratedAnswer with custom prompt enhancement
449
+ """
450
+ start_time = datetime.now()
451
+
452
+ if not chunks:
453
+ return GeneratedAnswer(
454
+ answer="I don't have enough context to answer your question.",
455
+ citations=[],
456
+ confidence_score=0.0,
457
+ generation_time=0.1,
458
+ model_used=self.model_name,
459
+ context_used=chunks
460
+ )
461
+
462
+ try:
463
+ # Try chat completion with custom prompt
464
+ messages = [
465
+ {"role": "system", "content": custom_prompt['system']},
466
+ {"role": "user", "content": custom_prompt['user']}
467
+ ]
468
+
469
+ answer_text = self._call_chat_completion(messages)
470
+
471
+ # Extract citations and clean answer
472
+ natural_answer, citations = self._extract_citations(answer_text, chunks)
473
+
474
+ # Calculate confidence
475
+ confidence = self._calculate_confidence(natural_answer, citations, chunks)
476
+
477
+ generation_time = (datetime.now() - start_time).total_seconds()
478
+
479
+ return GeneratedAnswer(
480
+ answer=natural_answer,
481
+ citations=citations,
482
+ confidence_score=confidence,
483
+ generation_time=generation_time,
484
+ model_used=self.model_name,
485
+ context_used=chunks
486
+ )
487
+
488
+ except Exception as e:
489
+ logger.error(f"Error generating answer with custom prompt: {e}")
490
+ print(f"❌ Custom prompt generation failed: {e}", file=sys.stderr, flush=True)
491
+
492
+ # Return error response
493
+ return GeneratedAnswer(
494
+ answer="I apologize, but I encountered an error while generating the answer. Please try again.",
495
+ citations=[],
496
+ confidence_score=0.0,
497
+ generation_time=(datetime.now() - start_time).total_seconds(),
498
+ model_used=self.model_name,
499
+ context_used=chunks
500
+ )
501
+
502
+
503
+ # Example usage
504
+ if __name__ == "__main__":
505
+ # Test the generator
506
+ print("Testing Inference Providers Generator...")
507
+
508
+ try:
509
+ generator = InferenceProvidersGenerator()
510
+
511
+ # Test chunks
512
+ test_chunks = [
513
+ {
514
+ "content": "RISC-V is an open-source instruction set architecture (ISA) based on established reduced instruction set computer (RISC) principles.",
515
+ "metadata": {"page_number": 1, "source": "riscv-spec.pdf"},
516
+ "score": 0.95
517
+ },
518
+ {
519
+ "content": "Unlike most other ISA designs, RISC-V is provided under open source licenses that do not require fees to use.",
520
+ "metadata": {"page_number": 2, "source": "riscv-spec.pdf"},
521
+ "score": 0.85
522
+ }
523
+ ]
524
+
525
+ # Generate answer
526
+ result = generator.generate("What is RISC-V and why is it important?", test_chunks)
527
+
528
+ print(f"\n📝 Answer: {result.answer}")
529
+ print(f"📊 Confidence: {result.confidence_score:.1%}")
530
+ print(f"⏱️ Generation time: {result.generation_time:.2f}s")
531
+ print(f"🤖 Model: {result.model_used}")
532
+ print(f"📚 Citations: {len(result.citations)}")
533
+
534
+ except Exception as e:
535
+ print(f"❌ Test failed: {e}")
536
+ import traceback
537
+ traceback.print_exc()
shared_utils/generation/ollama_answer_generator.py ADDED
@@ -0,0 +1,834 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Ollama-based answer generator for local inference.
4
+
5
+ Provides the same interface as HuggingFaceAnswerGenerator but uses
6
+ local Ollama server for model inference.
7
+ """
8
+
9
+ import time
10
+ import requests
11
+ import json
12
+ import re
13
+ import sys
14
+ from datetime import datetime
15
+ from pathlib import Path
16
+ from typing import Dict, List, Optional, Any, Tuple
17
+ from dataclasses import dataclass
18
+
19
+ # Import shared components
20
+ from .hf_answer_generator import Citation, GeneratedAnswer
21
+ from .prompt_templates import TechnicalPromptTemplates
22
+
23
+ # Import standard interfaces (add this for the adapter)
24
+ try:
25
+ from pathlib import Path
26
+ import sys
27
+ project_root = Path(__file__).parent.parent.parent.parent.parent
28
+ sys.path.append(str(project_root))
29
+ from src.core.interfaces import Document, Answer, AnswerGenerator
30
+ except ImportError:
31
+ # Fallback for standalone usage
32
+ Document = None
33
+ Answer = None
34
+ AnswerGenerator = object
35
+
36
+
37
+ class OllamaAnswerGenerator(AnswerGenerator if AnswerGenerator != object else object):
38
+ """
39
+ Generates answers using local Ollama server.
40
+
41
+ Perfect for:
42
+ - Local development
43
+ - Privacy-sensitive applications
44
+ - No API rate limits
45
+ - Consistent performance
46
+ - Offline operation
47
+ """
48
+
49
+ def __init__(
50
+ self,
51
+ model_name: str = "llama3.2:3b",
52
+ base_url: str = "http://localhost:11434",
53
+ temperature: float = 0.3,
54
+ max_tokens: int = 512,
55
+ ):
56
+ """
57
+ Initialize Ollama answer generator.
58
+
59
+ Args:
60
+ model_name: Ollama model to use (e.g., "llama3.2:3b", "mistral")
61
+ base_url: Ollama server URL
62
+ temperature: Generation temperature
63
+ max_tokens: Maximum tokens to generate
64
+ """
65
+ self.model_name = model_name
66
+ self.base_url = base_url.rstrip("/")
67
+ self.temperature = temperature
68
+ self.max_tokens = max_tokens
69
+
70
+ # Test connection
71
+ self._test_connection()
72
+
73
+ def _test_connection(self):
74
+ """Test if Ollama server is accessible."""
75
+ # Reduce retries for faster initialization - container should be ready quickly
76
+ max_retries = 12 # Wait up to 60 seconds for Ollama to start
77
+ retry_delay = 5
78
+
79
+ print(
80
+ f"🔧 Testing connection to {self.base_url}/api/tags...",
81
+ file=sys.stderr,
82
+ flush=True,
83
+ )
84
+
85
+ for attempt in range(max_retries):
86
+ try:
87
+ response = requests.get(f"{self.base_url}/api/tags", timeout=8)
88
+ if response.status_code == 200:
89
+ print(
90
+ f"✅ Connected to Ollama at {self.base_url}",
91
+ file=sys.stderr,
92
+ flush=True,
93
+ )
94
+
95
+ # Check if our model is available
96
+ models = response.json().get("models", [])
97
+ model_names = [m["name"] for m in models]
98
+
99
+ if self.model_name in model_names:
100
+ print(
101
+ f"✅ Model {self.model_name} is available",
102
+ file=sys.stderr,
103
+ flush=True,
104
+ )
105
+ return # Success!
106
+ else:
107
+ print(
108
+ f"⚠️ Model {self.model_name} not found. Available: {model_names}",
109
+ file=sys.stderr,
110
+ flush=True,
111
+ )
112
+ if models: # If any models are available, use the first one
113
+ fallback_model = model_names[0]
114
+ print(
115
+ f"🔄 Using fallback model: {fallback_model}",
116
+ file=sys.stderr,
117
+ flush=True,
118
+ )
119
+ self.model_name = fallback_model
120
+ return
121
+ else:
122
+ print(
123
+ f"📥 No models found, will try to pull {self.model_name}",
124
+ file=sys.stderr,
125
+ flush=True,
126
+ )
127
+ # Try to pull the model
128
+ self._pull_model(self.model_name)
129
+ return
130
+ else:
131
+ print(f"⚠️ Ollama server returned status {response.status_code}")
132
+ if attempt < max_retries - 1:
133
+ print(
134
+ f"🔄 Retry {attempt + 1}/{max_retries} in {retry_delay} seconds..."
135
+ )
136
+ time.sleep(retry_delay)
137
+ continue
138
+
139
+ except requests.exceptions.ConnectionError:
140
+ if attempt < max_retries - 1:
141
+ print(
142
+ f"⏳ Ollama not ready yet, retry {attempt + 1}/{max_retries} in {retry_delay} seconds..."
143
+ )
144
+ time.sleep(retry_delay)
145
+ continue
146
+ else:
147
+ raise Exception(
148
+ f"Cannot connect to Ollama server at {self.base_url} after 60 seconds. Check if it's running."
149
+ )
150
+ except requests.exceptions.Timeout:
151
+ if attempt < max_retries - 1:
152
+ print(f"⏳ Ollama timeout, retry {attempt + 1}/{max_retries}...")
153
+ time.sleep(retry_delay)
154
+ continue
155
+ else:
156
+ raise Exception("Ollama server timeout after multiple retries.")
157
+ except Exception as e:
158
+ if attempt < max_retries - 1:
159
+ print(f"⚠️ Ollama error: {e}, retry {attempt + 1}/{max_retries}...")
160
+ time.sleep(retry_delay)
161
+ continue
162
+ else:
163
+ raise Exception(
164
+ f"Ollama connection failed after {max_retries} attempts: {e}"
165
+ )
166
+
167
+ raise Exception("Failed to connect to Ollama after all retries")
168
+
169
+ def _pull_model(self, model_name: str):
170
+ """Pull a model if it's not available."""
171
+ try:
172
+ print(f"📥 Pulling model {model_name}...")
173
+ pull_response = requests.post(
174
+ f"{self.base_url}/api/pull",
175
+ json={"name": model_name},
176
+ timeout=300, # 5 minutes for model download
177
+ )
178
+ if pull_response.status_code == 200:
179
+ print(f"✅ Successfully pulled {model_name}")
180
+ else:
181
+ print(f"⚠️ Failed to pull {model_name}: {pull_response.status_code}")
182
+ # Try smaller models as fallback
183
+ fallback_models = ["llama3.2:1b", "llama2:latest", "mistral:latest"]
184
+ for fallback in fallback_models:
185
+ try:
186
+ print(f"🔄 Trying fallback model: {fallback}")
187
+ fallback_response = requests.post(
188
+ f"{self.base_url}/api/pull",
189
+ json={"name": fallback},
190
+ timeout=300,
191
+ )
192
+ if fallback_response.status_code == 200:
193
+ print(f"✅ Successfully pulled fallback {fallback}")
194
+ self.model_name = fallback
195
+ return
196
+ except Exception:
197
+ continue
198
+ raise Exception(f"Failed to pull {model_name} or any fallback models")
199
+ except Exception as e:
200
+ print(f"❌ Model pull failed: {e}")
201
+ raise
202
+
203
+ def _format_context(self, chunks: List[Dict[str, Any]]) -> str:
204
+ """Format retrieved chunks into context."""
205
+ context_parts = []
206
+
207
+ for i, chunk in enumerate(chunks):
208
+ chunk_text = chunk.get("content", chunk.get("text", ""))
209
+ page_num = chunk.get("metadata", {}).get("page_number", "unknown")
210
+ source = chunk.get("metadata", {}).get("source", "unknown")
211
+
212
+ context_parts.append(
213
+ f"[chunk_{i+1}] (Page {page_num} from {source}):\n{chunk_text}\n"
214
+ )
215
+
216
+ return "\n---\n".join(context_parts)
217
+
218
+ def _create_prompt(self, query: str, context: str, chunks: List[Dict[str, Any]]) -> str:
219
+ """Create optimized prompt with dynamic length constraints and citation instructions."""
220
+ # Get the appropriate template based on query type
221
+ prompt_data = TechnicalPromptTemplates.format_prompt_with_template(
222
+ query=query, context=context
223
+ )
224
+
225
+ # Create dynamic citation instructions based on available chunks
226
+ num_chunks = len(chunks)
227
+ available_chunks = ", ".join([f"[chunk_{i+1}]" for i in range(min(num_chunks, 5))]) # Show max 5 examples
228
+
229
+ # Create appropriate example based on actual chunks
230
+ if num_chunks == 1:
231
+ citation_example = "RISC-V is an open-source ISA [chunk_1]."
232
+ elif num_chunks == 2:
233
+ citation_example = "RISC-V is an open-source ISA [chunk_1] that supports multiple data widths [chunk_2]."
234
+ else:
235
+ citation_example = "RISC-V is an open-source ISA [chunk_1] that supports multiple data widths [chunk_2] and provides extensions [chunk_3]."
236
+
237
+ # Determine optimal answer length based on query complexity
238
+ target_length = self._determine_target_length(query, chunks)
239
+ length_instruction = self._create_length_instruction(target_length)
240
+
241
+ # Format for different model types
242
+ if "llama" in self.model_name.lower():
243
+ # Llama-3.2 format with technical prompt templates
244
+ return f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
245
+ {prompt_data['system']}
246
+
247
+ MANDATORY CITATION RULES:
248
+ - ONLY use available chunks: {available_chunks}
249
+ - You have {num_chunks} chunks available - DO NOT cite chunk numbers higher than {num_chunks}
250
+ - Every technical claim MUST have a citation from available chunks
251
+ - Example: "{citation_example}"
252
+
253
+ {length_instruction}
254
+
255
+ <|eot_id|><|start_header_id|>user<|end_header_id|>
256
+ {prompt_data['user']}
257
+
258
+ CRITICAL: You MUST cite sources ONLY from available chunks: {available_chunks}. DO NOT use chunk numbers > {num_chunks}.
259
+ {length_instruction}<|eot_id|><|start_header_id|>assistant<|end_header_id|>"""
260
+
261
+ elif "mistral" in self.model_name.lower():
262
+ # Mistral format with technical templates
263
+ return f"""[INST] {prompt_data['system']}
264
+
265
+ Context:
266
+ {context}
267
+
268
+ Question: {query}
269
+
270
+ MANDATORY: ONLY use available chunks: {available_chunks}. DO NOT cite chunk numbers > {num_chunks}.
271
+ {length_instruction} [/INST]"""
272
+
273
+ else:
274
+ # Generic format with technical templates
275
+ return f"""{prompt_data['system']}
276
+
277
+ Context:
278
+ {context}
279
+
280
+ Question: {query}
281
+
282
+ MANDATORY CITATIONS: ONLY use available chunks: {available_chunks}. DO NOT cite chunk numbers > {num_chunks}.
283
+ {length_instruction}
284
+
285
+ Answer:"""
286
+
287
+ def _determine_target_length(self, query: str, chunks: List[Dict[str, Any]]) -> int:
288
+ """
289
+ Determine optimal answer length based on query complexity.
290
+
291
+ Target range: 150-400 characters (down from 1000-2600)
292
+ """
293
+ # Analyze query complexity
294
+ query_length = len(query)
295
+ query_words = len(query.split())
296
+
297
+ # Check for complexity indicators
298
+ complex_words = [
299
+ "explain", "describe", "analyze", "compare", "contrast",
300
+ "evaluate", "discuss", "detail", "elaborate", "comprehensive"
301
+ ]
302
+
303
+ simple_words = [
304
+ "what", "define", "list", "name", "identify", "is", "are"
305
+ ]
306
+
307
+ query_lower = query.lower()
308
+ is_complex = any(word in query_lower for word in complex_words)
309
+ is_simple = any(word in query_lower for word in simple_words)
310
+
311
+ # Base length from query type
312
+ if is_complex:
313
+ base_length = 350 # Complex queries get longer answers
314
+ elif is_simple:
315
+ base_length = 200 # Simple queries get shorter answers
316
+ else:
317
+ base_length = 275 # Default middle ground
318
+
319
+ # Adjust based on available context
320
+ context_factor = min(len(chunks) * 25, 75) # More context allows longer answers
321
+
322
+ # Adjust based on query length
323
+ query_factor = min(query_words * 5, 50) # Longer queries allow longer answers
324
+
325
+ target_length = base_length + context_factor + query_factor
326
+
327
+ # Constrain to target range
328
+ return max(150, min(target_length, 400))
329
+
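+ # Worked example: "Explain the RISC-V privilege levels" matches "explain"
+ # (complex, base 350); 3 chunks add 75 and 5 query words add 25, totalling 450,
+ # which is clamped to the 400-character ceiling.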
330
+ def _create_length_instruction(self, target_length: int) -> str:
331
+ """Create length instruction based on target length."""
332
+ if target_length <= 200:
333
+ return f"ANSWER LENGTH: Keep your answer concise and focused, approximately {target_length} characters. Be direct and to the point."
334
+ elif target_length <= 300:
335
+ return f"ANSWER LENGTH: Provide a clear and informative answer, approximately {target_length} characters. Include key details but avoid unnecessary elaboration."
336
+ else:
337
+ return f"ANSWER LENGTH: Provide a comprehensive but concise answer, approximately {target_length} characters. Include important details while maintaining clarity."
338
+
339
+ def _call_ollama(self, prompt: str) -> str:
340
+ """Call Ollama API for generation."""
341
+ payload = {
342
+ "model": self.model_name,
343
+ "prompt": prompt,
344
+ "stream": False,
345
+ "options": {
346
+ "temperature": self.temperature,
347
+ "num_predict": self.max_tokens,
348
+ "top_p": 0.9,
349
+ "repeat_penalty": 1.1,
350
+ },
351
+ }
352
+
353
+ try:
354
+ response = requests.post(
355
+ f"{self.base_url}/api/generate", json=payload, timeout=300
356
+ )
357
+
358
+ response.raise_for_status()
359
+ result = response.json()
360
+
361
+ return result.get("response", "").strip()
362
+
363
+ except requests.exceptions.RequestException as e:
364
+ print(f"❌ Ollama API error: {e}")
365
+ return f"Error communicating with Ollama: {str(e)}"
366
+ except Exception as e:
367
+ print(f"❌ Unexpected error: {e}")
368
+ return f"Unexpected error: {str(e)}"
369
+
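+ # The payload above follows Ollama's /api/generate request schema, e.g.:
+ # {"model": "llama3.2:3b", "prompt": "...", "stream": false,
+ # "options": {"temperature": 0.3, "num_predict": 512, "top_p": 0.9, "repeat_penalty": 1.1}}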
370
+ def _extract_citations(
371
+ self, answer: str, chunks: List[Dict[str, Any]]
372
+ ) -> Tuple[str, List[Citation]]:
373
+ """Extract citations from the generated answer."""
374
+ citations = []
375
+ citation_pattern = r"\[chunk_(\d+)\]"
376
+
377
+ cited_chunks = set()
378
+
379
+ # Find [chunk_X] citations
380
+ matches = re.finditer(citation_pattern, answer)
381
+ for match in matches:
382
+ chunk_idx = int(match.group(1)) - 1 # Convert to 0-based index
383
+ if 0 <= chunk_idx < len(chunks):
384
+ cited_chunks.add(chunk_idx)
385
+
386
+ # FALLBACK: If no explicit citations found but we have an answer and chunks,
387
+ # create citations for the top chunks that were likely used
388
+ if not cited_chunks and chunks and len(answer.strip()) > 50:
389
+ # Use the top chunks that were provided as likely sources
390
+ num_fallback_citations = min(3, len(chunks)) # Use top 3 chunks max
391
+ cited_chunks = set(range(num_fallback_citations))
392
+ print(
393
+ f"🔧 Fallback: Creating {num_fallback_citations} citations for answer without explicit [chunk_X] references",
394
+ file=sys.stderr,
395
+ flush=True,
396
+ )
397
+
398
+ # Create Citation objects
400
+ for idx in cited_chunks:
401
+ chunk = chunks[idx]
402
+ citation = Citation(
403
+ chunk_id=chunk.get("id", f"chunk_{idx}"),
404
+ page_number=chunk.get("metadata", {}).get("page_number", 0),
405
+ source_file=chunk.get("metadata", {}).get("source", "unknown"),
406
+ relevance_score=chunk.get("score", 0.0),
407
+ text_snippet=chunk.get("content", chunk.get("text", ""))[:200] + "...",
408
+ )
409
+ citations.append(citation)
410
+
416
+ # Keep the answer as-is with proper [chunk_X] citations
417
+ # Don't replace citations with repetitive text
418
+ natural_answer = re.sub(r"\s+", " ", answer).strip()
419
+
420
+ return natural_answer, citations
421
+
422
+ def _calculate_confidence(
423
+ self, answer: str, citations: List[Citation], chunks: List[Dict[str, Any]]
424
+ ) -> float:
425
+ """
426
+ Calculate confidence score with expanded multi-factor assessment.
427
+
428
+ Enhanced algorithm expands range from 0.75-0.95 to 0.3-0.9 with:
429
+ - Context quality assessment
430
+ - Citation quality evaluation
431
+ - Semantic relevance scoring
432
+ - Off-topic detection
433
+ - Answer completeness analysis
434
+ """
435
+ if not answer or len(answer.strip()) < 10:
436
+ return 0.1
437
+
438
+ # 1. Context Quality Assessment (0.3-0.6 base range)
439
+ context_quality = self._assess_context_quality(chunks)
440
+
441
+ # 2. Citation Quality Evaluation (0.0-0.2 boost)
442
+ citation_quality = self._assess_citation_quality(citations, chunks)
443
+
444
+ # 3. Semantic Relevance Scoring (0.0-0.15 boost)
445
+ semantic_relevance = self._assess_semantic_relevance(answer, chunks)
446
+
447
+ # 4. Off-topic Detection (-0.4 penalty if off-topic)
448
+ off_topic_penalty = self._detect_off_topic(answer, chunks)
449
+
450
+ # 5. Answer Completeness Analysis (0.0-0.1 boost)
451
+ completeness_bonus = self._assess_answer_completeness(answer, len(chunks))
452
+
453
+ # Combine all factors
454
+ confidence = (
455
+ context_quality +
456
+ citation_quality +
457
+ semantic_relevance +
458
+ completeness_bonus +
459
+ off_topic_penalty
460
+ )
461
+
462
+ # Apply uncertainty penalty
463
+ uncertainty_phrases = [
464
+ "insufficient information",
465
+ "cannot determine",
466
+ "not available in the provided documents",
467
+ "I don't have enough context",
468
+ "the context doesn't seem to provide"
469
+ ]
470
+
471
+ if any(phrase in answer.lower() for phrase in uncertainty_phrases):
472
+ confidence *= 0.4 # Stronger penalty for uncertainty
473
+
474
+ # Constrain to target range 0.3-0.9
475
+ return max(0.3, min(confidence, 0.9))
476
+
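+ # Worked example: context_quality 0.6 (3 rich chunks) + citation_quality 0.10
+ # (full coverage) + semantic_relevance 0.05 + completeness 0.05 and no off-topic
+ # penalty -> 0.80, inside the 0.3-0.9 band.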
477
+ def _assess_context_quality(self, chunks: List[Dict[str, Any]]) -> float:
478
+ """Assess quality of context chunks (0.3-0.6 range)."""
479
+ if not chunks:
480
+ return 0.3
481
+
482
+ # Base score from chunk count
483
+ if len(chunks) >= 3:
484
+ base_score = 0.6
485
+ elif len(chunks) >= 2:
486
+ base_score = 0.5
487
+ else:
488
+ base_score = 0.4
489
+
490
+ # Quality adjustments based on chunk content
491
+ avg_chunk_length = sum(len(chunk.get("content", chunk.get("text", ""))) for chunk in chunks) / len(chunks)
492
+
493
+ if avg_chunk_length > 500: # Rich content
494
+ base_score += 0.05
495
+ elif avg_chunk_length < 100: # Sparse content
496
+ base_score -= 0.05
497
+
498
+ return max(0.3, min(base_score, 0.6))
499
+
500
+ def _assess_citation_quality(self, citations: List[Citation], chunks: List[Dict[str, Any]]) -> float:
501
+ """Assess citation quality (0.0-0.2 range)."""
502
+ if not citations or not chunks:
503
+ return 0.0
504
+
505
+ # Citation coverage bonus
506
+ citation_ratio = len(citations) / min(len(chunks), 3)
507
+ coverage_bonus = 0.1 * citation_ratio
508
+
509
+ # Citation diversity bonus (multiple sources)
510
+ unique_sources = len(set(cit.source_file for cit in citations))
511
+ diversity_bonus = 0.05 * min(unique_sources / max(len(chunks), 1), 1.0)
512
+
513
+ return min(coverage_bonus + diversity_bonus, 0.2)
514
+
515
+ def _assess_semantic_relevance(self, answer: str, chunks: List[Dict[str, Any]]) -> float:
516
+ """Assess semantic relevance between answer and context (0.0-0.15 range)."""
517
+ if not answer or not chunks:
518
+ return 0.0
519
+
520
+ # Simple keyword overlap assessment
521
+ answer_words = set(answer.lower().split())
522
+ context_words = set()
523
+
524
+ for chunk in chunks:
525
+ chunk_text = chunk.get("content", chunk.get("text", ""))
526
+ context_words.update(chunk_text.lower().split())
527
+
528
+ if not context_words:
529
+ return 0.0
530
+
531
+ # Calculate overlap ratio
532
+ overlap = len(answer_words & context_words)
533
+ total_unique = len(answer_words | context_words)
534
+
535
+ if total_unique == 0:
536
+ return 0.0
537
+
538
+ overlap_ratio = overlap / total_unique
539
+ return min(0.15 * overlap_ratio, 0.15)
540
+
541
+ def _detect_off_topic(self, answer: str, chunks: List[Dict[str, Any]]) -> float:
542
+ """Detect if answer is off-topic (-0.4 penalty if off-topic)."""
543
+ if not answer or not chunks:
544
+ return 0.0
545
+
546
+ # Check for off-topic indicators
547
+ off_topic_phrases = [
548
+ "but I have to say that the context doesn't seem to provide",
549
+ "these documents appear to be focused on",
550
+ "but they don't seem to cover",
551
+ "I'd recommend consulting a different type of documentation",
552
+ "without more context or information"
553
+ ]
554
+
555
+ answer_lower = answer.lower()
556
+ for phrase in off_topic_phrases:
557
+ if phrase in answer_lower:
558
+ return -0.4 # Strong penalty for off-topic responses
559
+
560
+ return 0.0
561
+
562
+ def _assess_answer_completeness(self, answer: str, chunk_count: int) -> float:
563
+ """Assess answer completeness (0.0-0.1 range)."""
564
+ if not answer:
565
+ return 0.0
566
+
567
+ # Length-based completeness assessment
568
+ answer_length = len(answer)
569
+
570
+ if answer_length > 500: # Comprehensive answer
571
+ return 0.1
572
+ elif answer_length > 200: # Adequate answer
573
+ return 0.05
574
+ else: # Brief answer
575
+ return 0.0
576
+
577
+ def generate(self, query: str, context: List[Document]) -> Answer:
578
+ """
579
+ Generate an answer from query and context documents (standard interface).
580
+
581
+ This is the public interface that conforms to the AnswerGenerator protocol.
582
+ It handles the conversion between standard Document objects and Ollama's
583
+ internal chunk format.
584
+
585
+ Args:
586
+ query: User's question
587
+ context: List of relevant Document objects
588
+
589
+ Returns:
590
+ Answer object conforming to standard interface
591
+
592
+ Raises:
593
+ ValueError: If query is empty or context is None
594
+ """
595
+ if not query.strip():
596
+ raise ValueError("Query cannot be empty")
597
+
598
+ if context is None:
599
+ raise ValueError("Context cannot be None")
600
+
601
+ # Internal adapter: Convert Documents to Ollama chunk format
602
+ ollama_chunks = self._documents_to_ollama_chunks(context)
603
+
604
+ # Use existing Ollama-specific generation logic
605
+ ollama_result = self._generate_internal(query, ollama_chunks)
606
+
607
+ # Internal adapter: Convert Ollama result to standard Answer
608
+ return self._ollama_result_to_answer(ollama_result, context)
609
+
610
+ def _generate_internal(self, query: str, chunks: List[Dict[str, Any]]) -> GeneratedAnswer:
611
+ """
612
+ Generate an answer based on the query and retrieved chunks.
613
+
614
+ Args:
615
+ query: User's question
616
+ chunks: Retrieved document chunks
617
+
618
+ Returns:
619
+ GeneratedAnswer object with answer, citations, and metadata
620
+ """
621
+ start_time = datetime.now()
622
+
623
+ # Check for no-context situation
624
+ if not chunks or all(
625
+ len(chunk.get("content", chunk.get("text", ""))) < 20 for chunk in chunks
626
+ ):
627
+ return GeneratedAnswer(
628
+ answer="This information isn't available in the provided documents.",
629
+ citations=[],
630
+ confidence_score=0.05,
631
+ generation_time=0.1,
632
+ model_used=self.model_name,
633
+ context_used=chunks,
634
+ )
635
+
636
+ # Format context
637
+ context = self._format_context(chunks)
638
+
639
+ # Create prompt with chunks parameter for dynamic citation instructions
640
+ prompt = self._create_prompt(query, context, chunks)
641
+
642
+ # Generate answer
643
+ print(
644
+ f"🤖 Calling Ollama with {self.model_name}...", file=sys.stderr, flush=True
645
+ )
646
+ answer_with_citations = self._call_ollama(prompt)
647
+
648
+ generation_time = (datetime.now() - start_time).total_seconds()
649
+
650
+ # Extract citations and create natural answer
651
+ natural_answer, citations = self._extract_citations(
652
+ answer_with_citations, chunks
653
+ )
654
+
655
+ # Calculate confidence
656
+ confidence = self._calculate_confidence(natural_answer, citations, chunks)
657
+
658
+ return GeneratedAnswer(
659
+ answer=natural_answer,
660
+ citations=citations,
661
+ confidence_score=confidence,
662
+ generation_time=generation_time,
663
+ model_used=self.model_name,
664
+ context_used=chunks,
665
+ )
666
+
667
+ def generate_with_custom_prompt(
668
+ self,
669
+ query: str,
670
+ chunks: List[Dict[str, Any]],
671
+ custom_prompt: Dict[str, str]
672
+ ) -> GeneratedAnswer:
673
+ """
674
+ Generate answer using a custom prompt (for adaptive prompting).
675
+
676
+ Args:
677
+ query: User's question
678
+ chunks: Retrieved context chunks
679
+ custom_prompt: Dict with 'system' and 'user' prompts
680
+
681
+ Returns:
682
+ GeneratedAnswer with custom prompt enhancement
683
+ """
684
+ start_time = datetime.now()
685
+
686
+ if not chunks:
687
+ return GeneratedAnswer(
688
+ answer="I don't have enough context to answer your question.",
689
+ citations=[],
690
+ confidence_score=0.0,
691
+ generation_time=0.1,
692
+ model_used=self.model_name,
693
+ context_used=chunks,
694
+ )
695
+
696
+ # Build custom prompt based on model type
697
+ if "llama" in self.model_name.lower():
698
+ prompt = f"""[INST] {custom_prompt['system']}
699
+
700
+ {custom_prompt['user']}
701
+
702
+ MANDATORY: Use [chunk_1], [chunk_2] etc. for all facts. [/INST]"""
703
+ elif "mistral" in self.model_name.lower():
704
+ prompt = f"""[INST] {custom_prompt['system']}
705
+
706
+ {custom_prompt['user']}
707
+
708
+ MANDATORY: Use [chunk_1], [chunk_2] etc. for all facts. [/INST]"""
709
+ else:
710
+ # Generic format for other models
711
+ prompt = f"""{custom_prompt['system']}
712
+
713
+ {custom_prompt['user']}
714
+
715
+ MANDATORY: Use [chunk_1], [chunk_2] etc. for all factual statements.
716
+
717
+ Answer:"""
718
+
719
+ # Generate answer
720
+ print(f"🤖 Calling Ollama with custom prompt using {self.model_name}...", file=sys.stderr, flush=True)
721
+ answer_with_citations = self._call_ollama(prompt)
722
+
723
+ generation_time = (datetime.now() - start_time).total_seconds()
724
+
725
+ # Extract citations and create natural answer
726
+ natural_answer, citations = self._extract_citations(answer_with_citations, chunks)
727
+
728
+ # Calculate confidence
729
+ confidence = self._calculate_confidence(natural_answer, citations, chunks)
730
+
731
+ return GeneratedAnswer(
732
+ answer=natural_answer,
733
+ citations=citations,
734
+ confidence_score=confidence,
735
+ generation_time=generation_time,
736
+ model_used=self.model_name,
737
+ context_used=chunks,
738
+ )
739
+
740
+ def _documents_to_ollama_chunks(self, documents: List[Document]) -> List[Dict[str, Any]]:
741
+ """
742
+ Convert Document objects to Ollama's internal chunk format.
743
+
744
+ This internal adapter ensures that Document objects are properly formatted
745
+ for Ollama's processing pipeline while keeping the format requirements
746
+ encapsulated within this class.
747
+
748
+ Args:
749
+ documents: List of Document objects from the standard interface
750
+
751
+ Returns:
752
+ List of chunk dictionaries in Ollama's expected format
753
+ """
754
+ if not documents:
755
+ return []
756
+
757
+ chunks = []
758
+ for i, doc in enumerate(documents):
759
+ chunk = {
760
+ "id": f"chunk_{i+1}",
761
+ "content": doc.content, # Ollama expects "content" field
762
+ "text": doc.content, # Fallback field for compatibility
763
+ "score": 1.0, # Default relevance score
764
+ "metadata": {
765
+ "source": doc.metadata.get("source", "unknown"),
766
+ "page_number": doc.metadata.get("start_page", 1),
767
+ **doc.metadata # Include all original metadata
768
+ }
769
+ }
770
+ chunks.append(chunk)
771
+
772
+ return chunks
773
+
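+ # Illustrative sketch (hedged, assuming the same Document interface as above):
+ # Document(content="RISC-V is open.", metadata={"source": "spec.pdf", "start_page": 3})
+ # becomes {"id": "chunk_1", "content": "RISC-V is open.", "text": "RISC-V is open.",
+ # "score": 1.0, "metadata": {"source": "spec.pdf", "page_number": 3, "start_page": 3}}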
774
+ def _ollama_result_to_answer(self, ollama_result: GeneratedAnswer, original_context: List[Document]) -> Answer:
775
+ """
776
+ Convert Ollama's GeneratedAnswer to the standard Answer format.
777
+
778
+ This internal adapter converts Ollama's result format back to the
779
+ standard interface format expected by the rest of the system.
780
+
781
+ Args:
782
+ ollama_result: Result from Ollama's internal generation
783
+ original_context: Original Document objects for sources
784
+
785
+ Returns:
786
+ Answer object conforming to standard interface
787
+ """
788
+ if not Answer:
789
+ # Fallback if standard interface not available
790
+ return ollama_result
791
+
792
+ # Convert to standard Answer format
793
+ return Answer(
794
+ text=ollama_result.answer,
795
+ sources=original_context, # Use original Document objects
796
+ confidence=ollama_result.confidence_score,
797
+ metadata={
798
+ "model_used": ollama_result.model_used,
799
+ "generation_time": ollama_result.generation_time,
800
+ "citations": [
801
+ {
802
+ "chunk_id": cit.chunk_id,
803
+ "page_number": cit.page_number,
804
+ "source_file": cit.source_file,
805
+ "relevance_score": cit.relevance_score,
806
+ "text_snippet": cit.text_snippet
807
+ }
808
+ for cit in ollama_result.citations
809
+ ],
810
+ "provider": "ollama",
811
+ "temperature": self.temperature,
812
+ "max_tokens": self.max_tokens
813
+ }
814
+ )
815
+
816
+
817
+ # Example usage
818
+ if __name__ == "__main__":
819
+ # Test Ollama connection
820
+ generator = OllamaAnswerGenerator(model_name="llama3.2:3b")
821
+
822
+ # Mock chunks for testing
823
+ test_chunks = [
824
+ {
825
+ "content": "RISC-V is a free and open-source ISA.",
826
+ "metadata": {"page_number": 1, "source": "riscv-spec.pdf"},
827
+ "score": 0.9,
828
+ }
829
+ ]
830
+
831
+ # Test generation
832
+ result = generator.generate("What is RISC-V?", test_chunks)
833
+ print(f"Answer: {result.answer}")
834
+ print(f"Confidence: {result.confidence_score:.2%}")
shared_utils/generation/prompt_optimizer.py ADDED
@@ -0,0 +1,687 @@
1
+ """
2
+ A/B Testing Framework for Prompt Optimization.
3
+
4
+ This module provides systematic prompt optimization through A/B testing,
5
+ performance analysis, and automated prompt variation generation.
6
+ """
7
+
8
+ import json
9
+ import time
10
+ from typing import Dict, List, Optional, Tuple, Any
11
+ from dataclasses import dataclass, asdict
12
+ from enum import Enum
13
+ from pathlib import Path
14
+ import numpy as np
15
+ from collections import defaultdict
16
+ import logging
17
+
18
+ from .prompt_templates import QueryType, PromptTemplate, TechnicalPromptTemplates
19
+
20
+
21
+ class OptimizationMetric(Enum):
22
+ """Metrics for evaluating prompt performance."""
23
+ RESPONSE_TIME = "response_time"
24
+ CONFIDENCE_SCORE = "confidence_score"
25
+ CITATION_COUNT = "citation_count"
26
+ ANSWER_LENGTH = "answer_length"
27
+ TECHNICAL_ACCURACY = "technical_accuracy"
28
+ USER_SATISFACTION = "user_satisfaction"
29
+
30
+
31
+ @dataclass
32
+ class PromptVariation:
33
+ """Represents a prompt variation for A/B testing."""
34
+ variation_id: str
35
+ name: str
36
+ description: str
37
+ template: PromptTemplate
38
+ query_type: QueryType
39
+ created_at: float
40
+ metadata: Dict[str, Any]
41
+
42
+
43
+ @dataclass
44
+ class TestResult:
45
+ """Represents a single test result."""
46
+ variation_id: str
47
+ query: str
48
+ query_type: QueryType
49
+ response_time: float
50
+ confidence_score: float
51
+ citation_count: int
52
+ answer_length: int
53
+ technical_accuracy: Optional[float] = None
54
+ user_satisfaction: Optional[float] = None
55
+ timestamp: float = None
56
+ metadata: Dict[str, Any] = None
57
+
58
+ def __post_init__(self):
59
+ if self.timestamp is None:
60
+ self.timestamp = time.time()
61
+ if self.metadata is None:
62
+ self.metadata = {}
63
+
64
+
65
+ @dataclass
66
+ class ComparisonResult:
67
+ """Results of A/B test comparison."""
68
+ variation_a: str
69
+ variation_b: str
70
+ metric: OptimizationMetric
71
+ a_mean: float
72
+ b_mean: float
73
+ improvement_percent: float
74
+ p_value: float
75
+ confidence_interval: Tuple[float, float]
76
+ is_significant: bool
77
+ sample_size: int
78
+ recommendation: str
79
+
80
+
81
+ class PromptOptimizer:
82
+ """
83
+ A/B testing framework for systematic prompt optimization.
84
+
85
+ Features:
86
+ - Automated prompt variation generation
87
+ - Performance metric tracking
88
+ - Statistical significance testing
89
+ - Recommendation engine
90
+ - Persistence and experiment tracking
91
+ """
92
+
93
+ def __init__(self, experiment_dir: str = "experiments"):
94
+ """
95
+ Initialize the prompt optimizer.
96
+
97
+ Args:
98
+ experiment_dir: Directory to store experiment data
99
+ """
100
+ self.experiment_dir = Path(experiment_dir)
101
+ self.experiment_dir.mkdir(exist_ok=True)
102
+
103
+ self.variations: Dict[str, PromptVariation] = {}
104
+ self.test_results: List[TestResult] = []
105
+ self.active_experiments: Dict[str, List[str]] = {}
106
+
107
+ # Load existing experiments
108
+ self._load_experiments()
109
+
110
+ # Setup logging
111
+ logging.basicConfig(level=logging.INFO)
112
+ self.logger = logging.getLogger(__name__)
113
+
114
+ def create_variation(
115
+ self,
116
+ base_template: PromptTemplate,
117
+ query_type: QueryType,
118
+ variation_name: str,
119
+ modifications: Dict[str, str],
120
+ description: str = ""
121
+ ) -> str:
122
+ """
123
+ Create a new prompt variation.
124
+
125
+ Args:
126
+ base_template: Base template to modify
127
+ query_type: Type of query this variation is for
128
+ variation_name: Human-readable name
129
+ modifications: Dict of template field modifications
130
+ description: Description of the variation
131
+
132
+ Returns:
133
+ Variation ID
134
+ """
135
+ variation_id = f"{query_type.value}_{variation_name}_{int(time.time())}"
136
+
137
+ # Create modified template
138
+ modified_template = PromptTemplate(
139
+ system_prompt=modifications.get("system_prompt", base_template.system_prompt),
140
+ context_format=modifications.get("context_format", base_template.context_format),
141
+ query_format=modifications.get("query_format", base_template.query_format),
142
+ answer_guidelines=modifications.get("answer_guidelines", base_template.answer_guidelines)
143
+ )
144
+
145
+ variation = PromptVariation(
146
+ variation_id=variation_id,
147
+ name=variation_name,
148
+ description=description,
149
+ template=modified_template,
150
+ query_type=query_type,
151
+ created_at=time.time(),
152
+ metadata=modifications
153
+ )
154
+
155
+ self.variations[variation_id] = variation
156
+ self._save_variation(variation)
157
+
158
+ self.logger.info(f"Created variation: {variation_id}")
159
+ return variation_id
160
+
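+ # Usage sketch (names taken from this module; values are illustrative):
+ # optimizer = PromptOptimizer(experiment_dir="experiments")
+ # base = TechnicalPromptTemplates.get_definition_template()
+ # vid = optimizer.create_variation(
+ # base_template=base,
+ # query_type=QueryType.DEFINITION,
+ # variation_name="terse_system",
+ # modifications={"system_prompt": "You are a terse RISC-V documentation expert."},
+ # description="Shorter, more direct system prompt",
+ # )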
161
+ def create_temperature_variations(
162
+ self,
163
+ base_query_type: QueryType,
164
+ temperatures: List[float] = [0.3, 0.5, 0.7, 0.9]
165
+ ) -> List[str]:
166
+ """
167
+ Create variations with different temperature settings.
168
+
169
+ Args:
170
+ base_query_type: Query type to create variations for
171
+ temperatures: List of temperature values to test
172
+
173
+ Returns:
174
+ List of variation IDs
175
+ """
176
+ base_template = TechnicalPromptTemplates.get_template_for_query("")
177
+ if base_query_type != QueryType.GENERAL:
178
+ template_map = {
179
+ QueryType.DEFINITION: TechnicalPromptTemplates.get_definition_template,
180
+ QueryType.IMPLEMENTATION: TechnicalPromptTemplates.get_implementation_template,
181
+ QueryType.COMPARISON: TechnicalPromptTemplates.get_comparison_template,
182
+ QueryType.SPECIFICATION: TechnicalPromptTemplates.get_specification_template,
183
+ QueryType.CODE_EXAMPLE: TechnicalPromptTemplates.get_code_example_template,
184
+ QueryType.HARDWARE_CONSTRAINT: TechnicalPromptTemplates.get_hardware_constraint_template,
185
+ QueryType.TROUBLESHOOTING: TechnicalPromptTemplates.get_troubleshooting_template,
186
+ }
187
+ base_template = template_map[base_query_type]()
188
+
189
+ variation_ids = []
190
+ for temp in temperatures:
191
+ temp_modification = {
192
+ "system_prompt": base_template.system_prompt + f"\n\nGenerate responses with temperature={temp} (creativity level).",
193
+ "answer_guidelines": base_template.answer_guidelines + f"\n\nAdjust response creativity to temperature={temp}."
194
+ }
195
+
196
+ variation_id = self.create_variation(
197
+ base_template=base_template,
198
+ query_type=base_query_type,
199
+ variation_name=f"temp_{temp}",
200
+ modifications=temp_modification,
201
+ description=f"Temperature variation with {temp} creativity level"
202
+ )
203
+ variation_ids.append(variation_id)
204
+
205
+ return variation_ids
206
+
207
+ def create_length_variations(
208
+ self,
209
+ base_query_type: QueryType,
210
+ length_styles: List[str] = ["concise", "detailed", "comprehensive"]
211
+ ) -> List[str]:
212
+ """
213
+ Create variations with different response length preferences.
214
+
215
+ Args:
216
+ base_query_type: Query type to create variations for
217
+ length_styles: List of length styles to test
218
+
219
+ Returns:
220
+ List of variation IDs
221
+ """
222
+ base_template = TechnicalPromptTemplates.get_general_template()
223
+ if base_query_type != QueryType.GENERAL:
224
+ template_map = {
225
+ QueryType.DEFINITION: TechnicalPromptTemplates.get_definition_template,
226
+ QueryType.IMPLEMENTATION: TechnicalPromptTemplates.get_implementation_template,
227
+ QueryType.COMPARISON: TechnicalPromptTemplates.get_comparison_template,
228
+ QueryType.SPECIFICATION: TechnicalPromptTemplates.get_specification_template,
229
+ QueryType.CODE_EXAMPLE: TechnicalPromptTemplates.get_code_example_template,
230
+ QueryType.HARDWARE_CONSTRAINT: TechnicalPromptTemplates.get_hardware_constraint_template,
231
+ QueryType.TROUBLESHOOTING: TechnicalPromptTemplates.get_troubleshooting_template,
232
+ }
233
+ base_template = template_map[base_query_type]()
234
+
235
+ length_prompts = {
236
+ "concise": "Be concise and focus on essential information only. Aim for 2-3 sentences per point.",
237
+ "detailed": "Provide detailed explanations with examples. Aim for comprehensive coverage.",
238
+ "comprehensive": "Provide exhaustive detail with multiple examples, edge cases, and related concepts."
239
+ }
240
+
241
+ variation_ids = []
242
+ for style in length_styles:
243
+ length_modification = {
244
+ "answer_guidelines": base_template.answer_guidelines + f"\n\nResponse style: {length_prompts[style]}"
245
+ }
246
+
247
+ variation_id = self.create_variation(
248
+ base_template=base_template,
249
+ query_type=base_query_type,
250
+ variation_name=f"length_{style}",
251
+ modifications=length_modification,
252
+ description=f"Length variation with {style} response style"
253
+ )
254
+ variation_ids.append(variation_id)
255
+
256
+ return variation_ids
257
+
258
+ def create_citation_variations(
259
+ self,
260
+ base_query_type: QueryType,
261
+ citation_styles: List[str] = ["minimal", "standard", "extensive"]
262
+ ) -> List[str]:
263
+ """
264
+ Create variations with different citation requirements.
265
+
266
+ Args:
267
+ base_query_type: Query type to create variations for
268
+ citation_styles: List of citation styles to test
269
+
270
+ Returns:
271
+ List of variation IDs
272
+ """
273
+ base_template = TechnicalPromptTemplates.get_general_template()
274
+ if base_query_type != QueryType.GENERAL:
275
+ template_map = {
276
+ QueryType.DEFINITION: TechnicalPromptTemplates.get_definition_template,
277
+ QueryType.IMPLEMENTATION: TechnicalPromptTemplates.get_implementation_template,
278
+ QueryType.COMPARISON: TechnicalPromptTemplates.get_comparison_template,
279
+ QueryType.SPECIFICATION: TechnicalPromptTemplates.get_specification_template,
280
+ QueryType.CODE_EXAMPLE: TechnicalPromptTemplates.get_code_example_template,
281
+ QueryType.HARDWARE_CONSTRAINT: TechnicalPromptTemplates.get_hardware_constraint_template,
282
+ QueryType.TROUBLESHOOTING: TechnicalPromptTemplates.get_troubleshooting_template,
283
+ }
284
+ base_template = template_map[base_query_type]()
285
+
286
+ citation_prompts = {
287
+ "minimal": "Use [chunk_X] citations only for direct quotes or specific claims.",
288
+ "standard": "Include [chunk_X] citations for each major point or claim.",
289
+ "extensive": "Provide [chunk_X] citations for every statement. Use multiple citations per point where relevant."
290
+ }
291
+
292
+ variation_ids = []
293
+ for style in citation_styles:
294
+ citation_modification = {
295
+ "answer_guidelines": base_template.answer_guidelines + f"\n\nCitation style: {citation_prompts[style]}"
296
+ }
297
+
298
+ variation_id = self.create_variation(
299
+ base_template=base_template,
300
+ query_type=base_query_type,
301
+ variation_name=f"citation_{style}",
302
+ modifications=citation_modification,
303
+ description=f"Citation variation with {style} citation requirements"
304
+ )
305
+ variation_ids.append(variation_id)
306
+
307
+ return variation_ids
308
+
309
+ def setup_experiment(
310
+ self,
311
+ experiment_name: str,
312
+ variation_ids: List[str],
313
+ test_queries: List[str]
314
+ ) -> str:
315
+ """
316
+ Set up a new A/B test experiment.
317
+
318
+ Args:
319
+ experiment_name: Name of the experiment
320
+ variation_ids: List of variation IDs to test
321
+ test_queries: List of test queries
322
+
323
+ Returns:
324
+ Experiment ID
325
+ """
326
+ experiment_id = f"exp_{experiment_name}_{int(time.time())}"
327
+
328
+ experiment_config = {
329
+ "experiment_id": experiment_id,
330
+ "name": experiment_name,
331
+ "variation_ids": variation_ids,
332
+ "test_queries": test_queries,
333
+ "created_at": time.time(),
334
+ "status": "active"
335
+ }
336
+
337
+ self.active_experiments[experiment_id] = variation_ids
338
+
339
+ # Save experiment config
340
+ experiment_file = self.experiment_dir / f"{experiment_id}.json"
341
+ with open(experiment_file, 'w') as f:
342
+ json.dump(experiment_config, f, indent=2)
343
+
344
+ self.logger.info(f"Created experiment: {experiment_id}")
345
+ return experiment_id
346
+
347
+ def record_test_result(
348
+ self,
349
+ variation_id: str,
350
+ query: str,
351
+ query_type: QueryType,
352
+ response_time: float,
353
+ confidence_score: float,
354
+ citation_count: int,
355
+ answer_length: int,
356
+ technical_accuracy: Optional[float] = None,
357
+ user_satisfaction: Optional[float] = None,
358
+ metadata: Optional[Dict[str, Any]] = None
359
+ ) -> None:
360
+ """
361
+ Record a test result for analysis.
362
+
363
+ Args:
364
+ variation_id: ID of the variation tested
365
+ query: The query that was tested
366
+ query_type: Type of the query
367
+ response_time: Response time in seconds
368
+ confidence_score: Confidence score (0-1)
369
+ citation_count: Number of citations in response
370
+ answer_length: Length of answer in characters
371
+ technical_accuracy: Optional technical accuracy score (0-1)
372
+ user_satisfaction: Optional user satisfaction score (0-1)
373
+ metadata: Optional additional metadata
374
+ """
375
+ result = TestResult(
376
+ variation_id=variation_id,
377
+ query=query,
378
+ query_type=query_type,
379
+ response_time=response_time,
380
+ confidence_score=confidence_score,
381
+ citation_count=citation_count,
382
+ answer_length=answer_length,
383
+ technical_accuracy=technical_accuracy,
384
+ user_satisfaction=user_satisfaction,
385
+ metadata=metadata or {}
386
+ )
387
+
388
+ self.test_results.append(result)
389
+ self._save_test_result(result)
390
+
391
+ self.logger.info(f"Recorded test result for variation: {variation_id}")
392
+
393
+ def analyze_variations(
394
+ self,
395
+ variation_a: str,
396
+ variation_b: str,
397
+ metric: OptimizationMetric,
398
+ min_samples: int = 10
399
+ ) -> ComparisonResult:
400
+ """
401
+ Analyze performance difference between two variations.
402
+
403
+ Args:
404
+ variation_a: First variation ID
405
+ variation_b: Second variation ID
406
+ metric: Metric to compare
407
+ min_samples: Minimum samples required for analysis
408
+
409
+ Returns:
410
+ Comparison result with statistical analysis
411
+ """
412
+ # Filter results for each variation
413
+ results_a = [r for r in self.test_results if r.variation_id == variation_a]
414
+ results_b = [r for r in self.test_results if r.variation_id == variation_b]
415
+
416
+ if len(results_a) < min_samples or len(results_b) < min_samples:
417
+ raise ValueError(f"Insufficient samples. Need at least {min_samples} for each variation.")
418
+
419
+ # Extract metric values
420
+ values_a = self._extract_metric_values(results_a, metric)
421
+ values_b = self._extract_metric_values(results_b, metric)
422
+
423
+ # Calculate statistics
424
+ mean_a = np.mean(values_a)
425
+ mean_b = np.mean(values_b)
426
+
427
+ # Calculate improvement percentage
428
+ improvement = ((mean_b - mean_a) / mean_a) * 100 if mean_a != 0 else 0.0  # guard against a zero baseline mean
429
+
430
+ # Simple t-test (normally would use scipy.stats.ttest_ind)
431
+ # For now, using basic statistical comparison
432
+ std_a = np.std(values_a)
433
+ std_b = np.std(values_b)
434
+ n_a = len(values_a)
435
+ n_b = len(values_b)
436
+
437
+ # Basic p-value estimation (simplified)
438
+ pooled_std = np.sqrt(((n_a - 1) * std_a**2 + (n_b - 1) * std_b**2) / (n_a + n_b - 2))
439
+ t_stat = (mean_b - mean_a) / (pooled_std * np.sqrt(1/n_a + 1/n_b))
440
+ p_value = 2 * (1 - abs(t_stat) / (abs(t_stat) + 1))  # Rough monotone proxy, not a calibrated p-value; see the scipy sketch after this method
441
+
442
+ # Confidence interval (simplified)
443
+ margin_of_error = 1.96 * pooled_std * np.sqrt(1/n_a + 1/n_b)
444
+ ci_lower = (mean_b - mean_a) - margin_of_error
445
+ ci_upper = (mean_b - mean_a) + margin_of_error
446
+
447
+ # Determine significance
448
+ is_significant = p_value < 0.05
449
+
450
+ # Generate recommendation
451
+ if is_significant:
452
+ if improvement > 0:
453
+ recommendation = f"Variation B shows significant improvement ({improvement:.1f}%). Recommend adopting variation B."
454
+ else:
455
+ recommendation = f"Variation A shows significant improvement ({-improvement:.1f}%). Recommend keeping variation A."
456
+ else:
457
+ recommendation = f"No significant difference detected (p={p_value:.3f}). More data needed or variations are equivalent."
458
+
459
+ return ComparisonResult(
460
+ variation_a=variation_a,
461
+ variation_b=variation_b,
462
+ metric=metric,
463
+ a_mean=mean_a,
464
+ b_mean=mean_b,
465
+ improvement_percent=improvement,
466
+ p_value=p_value,
467
+ confidence_interval=(ci_lower, ci_upper),
468
+ is_significant=is_significant,
469
+ sample_size=min(n_a, n_b),
470
+ recommendation=recommendation
471
+ )
472
+
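+ # Illustrative helper, assuming scipy is installed (it is not imported
+ # elsewhere in this module); Welch's t-test replaces the rough p-value
+ # proxy above with a calibrated one and drops the equal-variance assumption.
+ def _welch_p_value(self, values_a: List[float], values_b: List[float]) -> float:
+ """Sketch: exact p-value via scipy's Welch t-test."""
+ from scipy import stats  # local import keeps scipy an optional dependency
+ _, p_value = stats.ttest_ind(values_a, values_b, equal_var=False)
+ return float(p_value)
+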
473
+ def get_best_variation(
474
+ self,
475
+ query_type: QueryType,
476
+ metric: OptimizationMetric,
477
+ min_samples: int = 10
478
+ ) -> Optional[str]:
479
+ """
480
+ Get the best performing variation for a query type and metric.
481
+
482
+ Args:
483
+ query_type: Type of query
484
+ metric: Metric to optimize for
485
+ min_samples: Minimum samples required
486
+
487
+ Returns:
488
+ Best variation ID or None if insufficient data
489
+ """
490
+ # Filter results by query type
491
+ relevant_results = [r for r in self.test_results if r.query_type == query_type]
492
+
493
+ # Group by variation
494
+ variation_performance = defaultdict(list)
495
+ for result in relevant_results:
496
+ variation_performance[result.variation_id].append(result)
497
+
498
+ # Calculate mean performance for each variation
499
+ best_variation = None
500
+ best_score = None
501
+
502
+ for variation_id, results in variation_performance.items():
503
+ if len(results) >= min_samples:
504
+ values = self._extract_metric_values(results, metric)
505
+ mean_score = np.mean(values)
506
+
507
+ # Response time is lower-is-better; the other metrics are higher-is-better
+ if metric == OptimizationMetric.RESPONSE_TIME:
+ is_better = best_score is None or mean_score < best_score
+ else:
+ is_better = best_score is None or mean_score > best_score
+ if is_better:
508
+ best_score = mean_score
509
+ best_variation = variation_id
510
+
511
+ return best_variation
512
+
513
+ def generate_optimization_report(
514
+ self,
515
+ experiment_id: str,
516
+ output_file: Optional[str] = None
517
+ ) -> Dict[str, Any]:
518
+ """
519
+ Generate a comprehensive optimization report.
520
+
521
+ Args:
522
+ experiment_id: Experiment to analyze
523
+ output_file: Optional file to save report
524
+
525
+ Returns:
526
+ Report dictionary
527
+ """
528
+ if experiment_id not in self.active_experiments:
529
+ raise ValueError(f"Experiment {experiment_id} not found")
530
+
531
+ variation_ids = self.active_experiments[experiment_id]
532
+ experiment_results = [r for r in self.test_results if r.variation_id in variation_ids]
533
+
534
+ if not experiment_results:
535
+ raise ValueError(f"No results found for experiment {experiment_id}")
536
+
537
+ # Analyze each metric
538
+ metrics = [
539
+ OptimizationMetric.RESPONSE_TIME,
540
+ OptimizationMetric.CONFIDENCE_SCORE,
541
+ OptimizationMetric.CITATION_COUNT,
542
+ OptimizationMetric.ANSWER_LENGTH
543
+ ]
544
+
545
+ report = {
546
+ "experiment_id": experiment_id,
547
+ "variations_tested": len(variation_ids),
548
+ "total_tests": len(experiment_results),
549
+ "analysis_date": time.time(),
550
+ "metric_analysis": {},
551
+ "recommendations": []
552
+ }
553
+
554
+ # Analyze each metric across variations
555
+ for metric in metrics:
556
+ metric_data = {}
557
+ for variation_id in variation_ids:
558
+ var_results = [r for r in experiment_results if r.variation_id == variation_id]
559
+ if var_results:
560
+ values = self._extract_metric_values(var_results, metric)
561
+ metric_data[variation_id] = {
562
+ "mean": np.mean(values),
563
+ "std": np.std(values),
564
+ "count": len(values)
565
+ }
566
+
567
+ report["metric_analysis"][metric.value] = metric_data
568
+
569
+ # Generate recommendations
570
+ for metric in metrics:
571
+ best_variation = self.get_best_variation(
572
+ query_type=QueryType.GENERAL, # Could be made more specific
573
+ metric=metric,
574
+ min_samples=5
575
+ )
576
+ if best_variation:
577
+ report["recommendations"].append({
578
+ "metric": metric.value,
579
+ "best_variation": best_variation,
580
+ "variation_name": self.variations[best_variation].name
581
+ })
582
+
583
+ # Save report if requested
584
+ if output_file:
585
+ with open(output_file, 'w') as f:
586
+ json.dump(report, f, indent=2)
587
+
588
+ return report
589
+
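+ # Illustrative call (the output path is hypothetical):
+ # report = optimizer.generate_optimization_report(experiment_id, output_file="experiments/report.json")
+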
590
+ def _extract_metric_values(self, results: List[TestResult], metric: OptimizationMetric) -> List[float]:
591
+ """Extract metric values from test results."""
592
+ values = []
593
+ for result in results:
594
+ if metric == OptimizationMetric.RESPONSE_TIME:
595
+ values.append(result.response_time)
596
+ elif metric == OptimizationMetric.CONFIDENCE_SCORE:
597
+ values.append(result.confidence_score)
598
+ elif metric == OptimizationMetric.CITATION_COUNT:
599
+ values.append(float(result.citation_count))
600
+ elif metric == OptimizationMetric.ANSWER_LENGTH:
601
+ values.append(float(result.answer_length))
602
+ elif metric == OptimizationMetric.TECHNICAL_ACCURACY and result.technical_accuracy is not None:
603
+ values.append(result.technical_accuracy)
604
+ elif metric == OptimizationMetric.USER_SATISFACTION and result.user_satisfaction is not None:
605
+ values.append(result.user_satisfaction)
606
+
607
+ return values
608
+
609
+ def _load_experiments(self) -> None:
610
+ """Load existing experiments from disk."""
611
+ if not self.experiment_dir.exists():
612
+ return
613
+
614
+ for file_path in self.experiment_dir.glob("*.json"):
615
+ if file_path.name.startswith("exp_"):
616
+ with open(file_path, 'r') as f:
617
+ config = json.load(f)
618
+ self.active_experiments[config["experiment_id"]] = config["variation_ids"]
619
+
620
+ # Load variations and results
621
+ for file_path in self.experiment_dir.glob("variation_*.json"):
622
+ with open(file_path, 'r') as f:
623
+ var_data = json.load(f)
624
+ # Rebuild nested objects that _save_variation serialized to plain JSON
+ var_data["template"] = PromptTemplate(**var_data["template"])
+ var_data["query_type"] = QueryType(var_data["query_type"])
+ variation = PromptVariation(**var_data)
625
+ self.variations[variation.variation_id] = variation
626
+
627
+ for file_path in self.experiment_dir.glob("result_*.json"):
628
+ with open(file_path, 'r') as f:
629
+ result_data = json.load(f)
630
+ # Rebuild the enum that _save_test_result serialized as a string
+ result_data["query_type"] = QueryType(result_data["query_type"])
+ result = TestResult(**result_data)
631
+ self.test_results.append(result)
632
+
633
+ def _save_variation(self, variation: PromptVariation) -> None:
634
+ """Save variation to disk."""
635
+ file_path = self.experiment_dir / f"variation_{variation.variation_id}.json"
636
+ var_dict = asdict(variation)
637
+
638
+ # Convert template to dict
639
+ var_dict["template"] = asdict(variation.template)
640
+ var_dict["query_type"] = variation.query_type.value
641
+
642
+ with open(file_path, 'w') as f:
643
+ json.dump(var_dict, f, indent=2)
644
+
645
+ def _save_test_result(self, result: TestResult) -> None:
646
+ """Save test result to disk."""
647
+ file_path = self.experiment_dir / f"result_{int(result.timestamp)}.json"
648
+ result_dict = asdict(result)
649
+ result_dict["query_type"] = result.query_type.value
650
+
651
+ with open(file_path, 'w') as f:
652
+ json.dump(result_dict, f, indent=2)
653
+
654
+
655
+ # Example usage
656
+ if __name__ == "__main__":
657
+ # Initialize optimizer
658
+ optimizer = PromptOptimizer()
659
+
660
+ # Create temperature variations for implementation queries
661
+ temp_variations = optimizer.create_temperature_variations(
662
+ base_query_type=QueryType.IMPLEMENTATION,
663
+ temperatures=[0.3, 0.7]
664
+ )
665
+
666
+ # Create length variations for definition queries
667
+ length_variations = optimizer.create_length_variations(
668
+ base_query_type=QueryType.DEFINITION,
669
+ length_styles=["concise", "detailed"]
670
+ )
671
+
672
+ # Setup experiment
673
+ test_queries = [
674
+ "How do I implement a timer interrupt in RISC-V?",
675
+ "What is the difference between machine mode and user mode?",
676
+ "Configure GPIO pins for input/output operations"
677
+ ]
678
+
679
+ experiment_id = optimizer.setup_experiment(
680
+ experiment_name="temperature_vs_length",
681
+ variation_ids=temp_variations + length_variations,
682
+ test_queries=test_queries
683
+ )
684
+
685
+ print(f"Created experiment: {experiment_id}")
686
+ print(f"Variations: {len(temp_variations + length_variations)}")
687
+ print(f"Test queries: {len(test_queries)}")
shared_utils/generation/prompt_templates.py ADDED
@@ -0,0 +1,520 @@
1
+ """
2
+ Prompt templates optimized for technical documentation Q&A.
3
+
4
+ This module provides specialized prompt templates for different types of
5
+ technical queries, with a focus on embedded systems and AI documentation.
6
+ """
7
+
8
+ from enum import Enum
9
+ from typing import Dict, List, Optional
10
+ from dataclasses import dataclass
11
+
12
+
13
+ class QueryType(Enum):
14
+ """Types of technical queries."""
15
+ DEFINITION = "definition"
16
+ IMPLEMENTATION = "implementation"
17
+ COMPARISON = "comparison"
18
+ TROUBLESHOOTING = "troubleshooting"
19
+ SPECIFICATION = "specification"
20
+ CODE_EXAMPLE = "code_example"
21
+ HARDWARE_CONSTRAINT = "hardware_constraint"
22
+ GENERAL = "general"
23
+
24
+
25
+ @dataclass
26
+ class PromptTemplate:
27
+ """Represents a prompt template with its components."""
28
+ system_prompt: str
29
+ context_format: str
30
+ query_format: str
31
+ answer_guidelines: str
32
+ few_shot_examples: Optional[List[str]] = None
33
+
34
+
35
+ class TechnicalPromptTemplates:
36
+ """
37
+ Collection of prompt templates optimized for technical documentation.
38
+
39
+ Features:
40
+ - Domain-specific templates for embedded systems and AI
41
+ - Structured output formats
42
+ - Citation requirements
43
+ - Technical accuracy emphasis
44
+ """
45
+
46
+ @staticmethod
47
+ def get_base_system_prompt() -> str:
48
+ """Get the base system prompt for technical documentation."""
49
+ return """You are an expert technical documentation assistant specializing in embedded systems,
50
+ RISC-V architecture, RTOS, and embedded AI/ML. Your role is to provide accurate, detailed
51
+ technical answers based strictly on the provided context.
52
+
53
+ Key responsibilities:
54
+ 1. Answer questions using ONLY information from the provided context
55
+ 2. Include precise citations using [chunk_X] notation for every claim
56
+ 3. Maintain technical accuracy and use correct terminology
57
+ 4. Format code snippets and technical specifications properly
58
+ 5. Clearly state when information is not available in the context
59
+ 6. Consider hardware constraints and embedded system limitations when relevant
60
+
61
+ Write naturally and conversationally. Avoid repetitive phrases and numbered lists unless specifically requested. Never make up information. If the context doesn't contain the answer, say so explicitly."""
62
+
63
+ @staticmethod
64
+ def get_definition_template() -> PromptTemplate:
65
+ """Template for definition/explanation queries."""
66
+ return PromptTemplate(
67
+ system_prompt=TechnicalPromptTemplates.get_base_system_prompt() + """
68
+
69
+ For definition queries, focus on:
70
+ - Clear, concise technical definitions
71
+ - Related concepts and terminology
72
+ - Technical context and applications
73
+ - Any acronym expansions""",
74
+
75
+ context_format="""Technical Documentation Context:
76
+ {context}""",
77
+
78
+ query_format="""Define or explain: {query}
79
+
80
+ Provide a comprehensive technical definition with proper citations.""",
81
+
82
+ answer_guidelines="""Provide a clear, comprehensive answer that directly addresses the question. Include relevant technical details and cite your sources using [chunk_X] notation. Make your response natural and conversational while maintaining technical accuracy.""",
83
+
84
+ few_shot_examples=[
85
+ """Q: What is RISC-V?
86
+ A: RISC-V is an open-source instruction set architecture (ISA) based on established reduced instruction set computing (RISC) principles [chunk_1]. Unlike proprietary ISAs, RISC-V is freely available under open-source licenses, allowing anyone to implement RISC-V processors without licensing fees [chunk_2]. The architecture supports 32-bit, 64-bit, and 128-bit address spaces, with a modular design that includes base integer instruction sets and optional extensions [chunk_3]. RISC-V stands for "RISC-Five" referring to the fifth generation of RISC architecture developed at UC Berkeley.""",
87
+
88
+ """Q: What is FreeRTOS?
89
+ A: FreeRTOS is a real-time operating system kernel for embedded devices that provides multitasking capabilities for microcontrollers and small microprocessors [chunk_1]. It implements a preemptive scheduler with priority-based task scheduling, ensuring deterministic real-time behavior [chunk_2]. FreeRTOS includes core features like task management, semaphores, queues, and memory management while maintaining a small footprint typically under 10KB [chunk_3]. The "Free" in FreeRTOS refers to both its open-source license and the fact that it's free of charge for commercial use."""
90
+ ]
91
+ )
92
+
93
+ @staticmethod
94
+ def get_implementation_template() -> PromptTemplate:
95
+ """Template for implementation/how-to queries."""
96
+ return PromptTemplate(
97
+ system_prompt=TechnicalPromptTemplates.get_base_system_prompt() + """
98
+
99
+ For implementation queries, focus on:
100
+ - Step-by-step instructions
101
+ - Required components or dependencies
102
+ - Code examples with proper formatting
103
+ - Hardware/software requirements
104
+ - Common pitfalls or considerations""",
105
+
106
+ context_format="""Implementation Documentation:
107
+ {context}""",
108
+
109
+ query_format="""Implementation question: {query}
110
+
111
+ Provide detailed implementation guidance with code examples where available.""",
112
+
113
+ answer_guidelines="""Structure your answer as:
114
+ 1. Overview of the implementation approach [chunk_X]
115
+ 2. Prerequisites and requirements [chunk_Y]
116
+ 3. Step-by-step implementation:
117
+ - Step 1: Description [chunk_Z]
118
+ - Step 2: Description [chunk_W]
119
+ 4. Code example (if available):
120
+ ```language
121
+ // Code here
122
+ ```
123
+ 5. Important considerations or warnings""",
124
+
125
+ few_shot_examples=[
126
+ """Q: How do I configure GPIO pins for output in RISC-V?
127
+ A: GPIO configuration for output requires setting the pin direction and initial value through memory-mapped registers [chunk_1]. First, identify the GPIO base address for your specific RISC-V implementation (commonly 0x10060000 for SiFive cores) [chunk_2].
128
+
129
+ Steps:
130
+ 1. Set pin direction to output by writing to GPIO_OUTPUT_EN register [chunk_3]
131
+ 2. Configure initial output value using GPIO_OUTPUT_VAL register [chunk_4]
132
+
133
+ ```c
134
+ #define GPIO_BASE 0x10060000
135
+ #define GPIO_OUTPUT_EN (GPIO_BASE + 0x08)
136
+ #define GPIO_OUTPUT_VAL (GPIO_BASE + 0x0C)
137
+
138
+ // Configure pin 5 as output
139
+ volatile uint32_t *gpio_en = (uint32_t*)GPIO_OUTPUT_EN;
140
+ volatile uint32_t *gpio_val = (uint32_t*)GPIO_OUTPUT_VAL;
141
+
142
+ *gpio_en |= (1 << 5); // Enable output on pin 5
143
+ *gpio_val |= (1 << 5); // Set pin 5 high
144
+ ```
145
+
146
+ Important: Always check your board's documentation for the correct GPIO base address and pin mapping [chunk_5].""",
147
+
148
+ """Q: How to implement a basic timer interrupt in RISC-V?
149
+ A: Timer interrupts in RISC-V use the machine timer (mtime) and timer compare (mtimecmp) registers for precise timing control [chunk_1]. The implementation requires configuring the timer hardware, setting up the interrupt handler, and enabling machine timer interrupts [chunk_2].
150
+
151
+ Prerequisites:
152
+ - RISC-V processor with timer support
153
+ - Access to machine-level CSRs
154
+ - Understanding of memory-mapped timer registers [chunk_3]
155
+
156
+ Implementation steps:
157
+ 1. Set up timer compare value in mtimecmp register [chunk_4]
158
+ 2. Enable machine timer interrupt in mie CSR [chunk_5]
159
+ 3. Configure interrupt handler in mtvec CSR [chunk_6]
160
+
161
+ ```c
162
+ #define MTIME_BASE 0x0200bff8
163
+ #define MTIMECMP_BASE 0x02004000
164
+
165
+ void setup_timer_interrupt(uint64_t interval) {
166
+ uint64_t *mtime = (uint64_t*)MTIME_BASE;
167
+ uint64_t *mtimecmp = (uint64_t*)MTIMECMP_BASE;
168
+
169
+ // Set next interrupt time
170
+ *mtimecmp = *mtime + interval;
171
+
172
+ // Enable machine timer interrupt
173
+ asm volatile ("csrs mie, %0" : : "r"(0x80));
174
+
175
+ // Enable global interrupts
176
+ asm volatile ("csrs mstatus, %0" : : "r"(0x8));
177
+ }
178
+ ```
179
+
180
+ Critical considerations: Timer registers are 64-bit and must be accessed atomically on 32-bit systems [chunk_7]."""
181
+ ]
182
+ )
183
+
184
+ @staticmethod
185
+ def get_comparison_template() -> PromptTemplate:
186
+ """Template for comparison queries."""
187
+ return PromptTemplate(
188
+ system_prompt=TechnicalPromptTemplates.get_base_system_prompt() + """
189
+
190
+ For comparison queries, focus on:
191
+ - Clear distinction between compared items
192
+ - Technical specifications and differences
193
+ - Use cases for each option
194
+ - Performance or resource implications
195
+ - Recommendations based on context""",
196
+
197
+ context_format="""Technical Comparison Context:
198
+ {context}""",
199
+
200
+ query_format="""Compare: {query}
201
+
202
+ Provide a detailed technical comparison with clear distinctions.""",
203
+
204
+ answer_guidelines="""Structure your answer as:
205
+ 1. Overview of items being compared [chunk_X]
206
+ 2. Key differences:
207
+ - Feature A: Item1 vs Item2 [chunk_Y]
208
+ - Feature B: Item1 vs Item2 [chunk_Z]
209
+ 3. Technical specifications comparison
210
+ 4. Use case recommendations
211
+ 5. Performance/resource considerations"""
212
+ )
213
+
214
+ @staticmethod
215
+ def get_specification_template() -> PromptTemplate:
216
+ """Template for specification/parameter queries."""
217
+ return PromptTemplate(
218
+ system_prompt=TechnicalPromptTemplates.get_base_system_prompt() + """
219
+
220
+ For specification queries, focus on:
221
+ - Exact technical specifications
222
+ - Parameter ranges and limits
223
+ - Units and measurements
224
+ - Compliance with standards
225
+ - Version-specific information""",
226
+
227
+ context_format="""Technical Specifications:
228
+ {context}""",
229
+
230
+ query_format="""Specification query: {query}
231
+
232
+ Provide precise technical specifications with all relevant parameters.""",
233
+
234
+ answer_guidelines="""Structure your answer as:
235
+ 1. Specification overview [chunk_X]
236
+ 2. Detailed parameters:
237
+ - Parameter 1: value (unit) [chunk_Y]
238
+ - Parameter 2: value (unit) [chunk_Z]
239
+ 3. Operating conditions or constraints
240
+ 4. Compliance/standards information
241
+ 5. Version or variant notes"""
242
+ )
243
+
244
+ @staticmethod
245
+ def get_code_example_template() -> PromptTemplate:
246
+ """Template for code example queries."""
247
+ return PromptTemplate(
248
+ system_prompt=TechnicalPromptTemplates.get_base_system_prompt() + """
249
+
250
+ For code example queries, focus on:
251
+ - Complete, runnable code examples
252
+ - Proper syntax highlighting
253
+ - Clear comments and documentation
254
+ - Error handling
255
+ - Best practices for embedded systems""",
256
+
257
+ context_format="""Code Examples and Documentation:
258
+ {context}""",
259
+
260
+ query_format="""Code example request: {query}
261
+
262
+ Provide working code examples with explanations.""",
263
+
264
+ answer_guidelines="""Structure your answer as:
265
+ 1. Purpose and overview [chunk_X]
266
+ 2. Required includes/imports [chunk_Y]
267
+ 3. Complete code example:
268
+ ```c
269
+ // Or appropriate language
270
+ #include <necessary_headers.h>
271
+
272
+ // Function or code implementation
273
+ ```
274
+ 4. Key points explained [chunk_Z]
275
+ 5. Common variations or modifications"""
276
+ )
277
+
278
+ @staticmethod
279
+ def get_hardware_constraint_template() -> PromptTemplate:
280
+ """Template for hardware constraint queries."""
281
+ return PromptTemplate(
282
+ system_prompt=TechnicalPromptTemplates.get_base_system_prompt() + """
283
+
284
+ For hardware constraint queries, focus on:
285
+ - Memory requirements (RAM, Flash)
286
+ - Processing power needs (MIPS, frequency)
287
+ - Power consumption
288
+ - I/O requirements
289
+ - Real-time constraints
290
+ - Temperature/environmental limits""",
291
+
292
+ context_format="""Hardware Specifications and Constraints:
293
+ {context}""",
294
+
295
+ query_format="""Hardware constraint question: {query}
296
+
297
+ Analyze feasibility and constraints for embedded deployment.""",
298
+
299
+ answer_guidelines="""Structure your answer as:
300
+ 1. Hardware requirements summary [chunk_X]
301
+ 2. Detailed constraints:
302
+ - Memory: RAM/Flash requirements [chunk_Y]
303
+ - Processing: CPU/frequency needs [chunk_Z]
304
+ - Power: Consumption estimates [chunk_W]
305
+ 3. Feasibility assessment
306
+ 4. Optimization suggestions
307
+ 5. Alternative approaches if constraints are exceeded"""
308
+ )
309
+
310
+ @staticmethod
311
+ def get_troubleshooting_template() -> PromptTemplate:
312
+ """Template for troubleshooting queries."""
313
+ return PromptTemplate(
314
+ system_prompt=TechnicalPromptTemplates.get_base_system_prompt() + """
315
+
316
+ For troubleshooting queries, focus on:
317
+ - Common error causes
318
+ - Diagnostic steps
319
+ - Solution procedures
320
+ - Preventive measures
321
+ - Debug techniques for embedded systems""",
322
+
323
+ context_format="""Troubleshooting Documentation:
324
+ {context}""",
325
+
326
+ query_format="""Troubleshooting issue: {query}
327
+
328
+ Provide diagnostic steps and solutions.""",
329
+
330
+ answer_guidelines="""Structure your answer as:
331
+ 1. Problem identification [chunk_X]
332
+ 2. Common causes:
333
+ - Cause 1: Description [chunk_Y]
334
+ - Cause 2: Description [chunk_Z]
335
+ 3. Diagnostic steps:
336
+ - Step 1: Check... [chunk_W]
337
+ - Step 2: Verify... [chunk_V]
338
+ 4. Solutions for each cause
339
+ 5. Prevention recommendations"""
340
+ )
341
+
342
+ @staticmethod
343
+ def get_general_template() -> PromptTemplate:
344
+ """Default template for general queries."""
345
+ return PromptTemplate(
346
+ system_prompt=TechnicalPromptTemplates.get_base_system_prompt(),
347
+
348
+ context_format="""Technical Documentation:
349
+ {context}""",
350
+
351
+ query_format="""Question: {query}
352
+
353
+ Provide a comprehensive technical answer based on the documentation.""",
354
+
355
+ answer_guidelines="""Provide a clear, comprehensive answer that directly addresses the question. Include relevant technical details and cite your sources using [chunk_X] notation. Write naturally and conversationally while maintaining technical accuracy. Acknowledge any limitations in available information."""
356
+ )
357
+
358
+ @staticmethod
359
+ def detect_query_type(query: str) -> QueryType:
360
+ """
361
+ Detect the type of query based on keywords and patterns.
362
+
363
+ Args:
364
+ query: User's question
365
+
366
+ Returns:
367
+ Detected QueryType
368
+ """
369
+ query_lower = query.lower()
370
+
371
+ # Definition keywords
372
+ if any(keyword in query_lower for keyword in [
373
+ "what is", "what are", "define", "definition", "meaning of", "explain what"
374
+ ]):
375
+ return QueryType.DEFINITION
376
+
377
+ # Implementation keywords
378
+ if any(keyword in query_lower for keyword in [
379
+ "how to", "how do i", "implement", "setup", "configure", "install"
380
+ ]):
381
+ return QueryType.IMPLEMENTATION
382
+
383
+ # Comparison keywords
384
+ if any(keyword in query_lower for keyword in [
385
+ "difference between", "compare", "vs", "versus", "better than", "which is"
386
+ ]):
387
+ return QueryType.COMPARISON
388
+
389
+ # Specification keywords
390
+ if any(keyword in query_lower for keyword in [
391
+ "specification", "specs", "parameters", "limits", "range", "maximum", "minimum"
392
+ ]):
393
+ return QueryType.SPECIFICATION
394
+
395
+ # Code example keywords
396
+ if any(keyword in query_lower for keyword in [
397
+ "example", "code", "snippet", "sample", "demo", "show me"
398
+ ]):
399
+ return QueryType.CODE_EXAMPLE
400
+
401
+ # Hardware constraint keywords
402
+ if any(keyword in query_lower for keyword in [
403
+ "memory", "ram", "flash", "mcu", "constraint", "fit on", "run on", "power consumption"
404
+ ]):
405
+ return QueryType.HARDWARE_CONSTRAINT
406
+
407
+ # Troubleshooting keywords
408
+ if any(keyword in query_lower for keyword in [
409
+ "error", "problem", "issue", "debug", "troubleshoot", "fix", "solve", "not working"
410
+ ]):
411
+ return QueryType.TROUBLESHOOTING
412
+
413
+ return QueryType.GENERAL
414
+
415
+ @staticmethod
416
+ def get_template_for_query(query: str) -> PromptTemplate:
417
+ """
418
+ Get the appropriate template based on query type.
419
+
420
+ Args:
421
+ query: User's question
422
+
423
+ Returns:
424
+ Appropriate PromptTemplate
425
+ """
426
+ query_type = TechnicalPromptTemplates.detect_query_type(query)
427
+
428
+ template_map = {
429
+ QueryType.DEFINITION: TechnicalPromptTemplates.get_definition_template,
430
+ QueryType.IMPLEMENTATION: TechnicalPromptTemplates.get_implementation_template,
431
+ QueryType.COMPARISON: TechnicalPromptTemplates.get_comparison_template,
432
+ QueryType.SPECIFICATION: TechnicalPromptTemplates.get_specification_template,
433
+ QueryType.CODE_EXAMPLE: TechnicalPromptTemplates.get_code_example_template,
434
+ QueryType.HARDWARE_CONSTRAINT: TechnicalPromptTemplates.get_hardware_constraint_template,
435
+ QueryType.TROUBLESHOOTING: TechnicalPromptTemplates.get_troubleshooting_template,
436
+ QueryType.GENERAL: TechnicalPromptTemplates.get_general_template
437
+ }
438
+
439
+ return template_map[query_type]()
440
+
441
+ @staticmethod
442
+ def format_prompt_with_template(
443
+ query: str,
444
+ context: str,
445
+ template: Optional[PromptTemplate] = None,
446
+ include_few_shot: bool = True
447
+ ) -> Dict[str, str]:
448
+ """
449
+ Format a complete prompt using the appropriate template.
450
+
451
+ Args:
452
+ query: User's question
453
+ context: Retrieved context chunks
454
+ template: Optional specific template (auto-detected if None)
455
+ include_few_shot: Whether to include few-shot examples
456
+
457
+ Returns:
458
+ Dict with 'system' and 'user' prompts
459
+ """
460
+ if template is None:
461
+ template = TechnicalPromptTemplates.get_template_for_query(query)
462
+
463
+ # Format the context
464
+ formatted_context = template.context_format.format(context=context)
465
+
466
+ # Format the query
467
+ formatted_query = template.query_format.format(query=query)
468
+
469
+ # Build user prompt with optional few-shot examples
470
+ user_prompt_parts = []
471
+
472
+ # Add few-shot examples if available and requested
473
+ if include_few_shot and template.few_shot_examples:
474
+ user_prompt_parts.append("Here are some examples of how to answer similar questions:")
475
+ user_prompt_parts.append("\n\n".join(template.few_shot_examples))
476
+ user_prompt_parts.append("\nNow answer the following question using the same format:")
477
+
478
+ user_prompt_parts.extend([
479
+ formatted_context,
480
+ formatted_query,
481
+ template.answer_guidelines
482
+ ])
483
+
484
+ user_prompt = "\n\n".join(user_prompt_parts)
485
+
486
+ return {
487
+ "system": template.system_prompt,
488
+ "user": user_prompt
489
+ }
490
+
491
+
492
+ # Example usage and testing
493
+ if __name__ == "__main__":
494
+ # Test query type detection
495
+ test_queries = [
496
+ "What is RISC-V?",
497
+ "How do I implement a timer interrupt?",
498
+ "What's the difference between FreeRTOS and Zephyr?",
499
+ "What are the memory specifications for STM32F4?",
500
+ "Show me an example of GPIO configuration",
501
+ "Can this model run on an MCU with 256KB RAM?",
502
+ "Debug error: undefined reference to main"
503
+ ]
504
+
505
+ for query in test_queries:
506
+ query_type = TechnicalPromptTemplates.detect_query_type(query)
507
+ print(f"Query: '{query}' -> Type: {query_type.value}")
508
+
509
+ # Example prompt formatting
510
+ example_context = "RISC-V is an open instruction set architecture..."
511
+ example_query = "What is RISC-V?"
512
+
513
+ formatted = TechnicalPromptTemplates.format_prompt_with_template(
514
+ query=example_query,
515
+ context=example_context
516
+ )
517
+
518
+ print("\nFormatted prompt example:")
519
+ print("System:", formatted["system"][:100], "...")
520
+ print("User:", formatted["user"][:200], "...")
shared_utils/query_processing/__init__.py ADDED
@@ -0,0 +1,8 @@
1
+ """
2
+ Query processing utilities for intelligent RAG systems.
3
+ Provides query enhancement, analysis, and optimization capabilities.
4
+ """
5
+
6
+ from .query_enhancer import QueryEnhancer
7
+
8
+ __all__ = ['QueryEnhancer']
shared_utils/query_processing/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (406 Bytes).
 
shared_utils/query_processing/__pycache__/query_enhancer.cpython-312.pyc ADDED
Binary file (24.3 kB).
 
shared_utils/query_processing/query_enhancer.py ADDED
@@ -0,0 +1,644 @@
1
+ """
2
+ Intelligent query processing for technical documentation RAG.
3
+
4
+ Provides adaptive query enhancement through technical term expansion,
5
+ acronym handling, and intelligent hybrid weighting optimization.
6
+ """
7
+
8
+ from typing import Dict, List, Any, Tuple, Set, Optional
9
+ import re
10
+ from collections import defaultdict
11
+ import time
12
+
13
+
14
+ class QueryEnhancer:
15
+ """
16
+ Intelligent query processing for technical documentation RAG.
17
+
18
+ Analyzes query characteristics and enhances retrieval through:
19
+ - Technical synonym expansion
20
+ - Acronym detection and expansion
21
+ - Adaptive hybrid weighting based on query type
22
+ - Query complexity analysis for optimal retrieval strategy
23
+
24
+ Optimized for embedded systems and technical documentation domains.
25
+
26
+ Performance: <10ms query enhancement, improves retrieval relevance by >10%
27
+ """
28
+
29
+ def __init__(self):
30
+ """Initialize QueryEnhancer with technical domain knowledge."""
31
+
32
+ # Technical vocabulary dictionary organized by domain
33
+ self.technical_synonyms = {
34
+ # Processor terminology
35
+ 'cpu': ['processor', 'microprocessor', 'central processing unit'],
36
+ 'mcu': ['microcontroller', 'microcontroller unit', 'embedded processor'],
37
+ 'core': ['processor core', 'cpu core', 'execution unit'],
38
+ 'alu': ['arithmetic logic unit', 'arithmetic unit'],
39
+
40
+ # Memory terminology
41
+ 'memory': ['ram', 'storage', 'buffer', 'cache'],
42
+ 'flash': ['non-volatile memory', 'program memory', 'code storage'],
43
+ 'sram': ['static ram', 'static memory', 'cache memory'],
44
+ 'dram': ['dynamic ram', 'dynamic memory'],
45
+ 'cache': ['buffer', 'temporary storage', 'fast memory'],
46
+
47
+ # Architecture terminology
48
+ 'risc-v': ['riscv', 'risc v', 'open isa', 'open instruction set'],
49
+ 'arm': ['advanced risc machine', 'acorn risc machine'],
50
+ 'isa': ['instruction set architecture', 'instruction set'],
51
+ 'architecture': ['design', 'structure', 'organization'],
52
+
53
+ # Embedded systems terminology
54
+ 'rtos': ['real-time operating system', 'real-time os'],
55
+ 'interrupt': ['isr', 'interrupt service routine', 'exception handler'],
56
+ 'peripheral': ['hardware peripheral', 'external device', 'io device'],
57
+ 'firmware': ['embedded software', 'system software'],
58
+ 'bootloader': ['boot code', 'initialization code'],
59
+
60
+ # Performance terminology
61
+ 'latency': ['delay', 'response time', 'execution time'],
62
+ 'throughput': ['bandwidth', 'data rate', 'performance'],
63
+ 'power': ['power consumption', 'energy usage', 'battery life'],
64
+ 'optimization': ['improvement', 'enhancement', 'tuning'],
65
+
66
+ # Communication protocols
67
+ 'uart': ['serial communication', 'async serial'],
68
+ 'spi': ['serial peripheral interface', 'synchronous serial'],
69
+ 'i2c': ['inter-integrated circuit', 'two-wire interface'],
70
+ 'usb': ['universal serial bus'],
71
+
72
+ # Development terminology
73
+ 'debug': ['debugging', 'troubleshooting', 'testing'],
74
+ 'compile': ['compilation', 'build', 'assembly'],
75
+ 'programming': ['coding', 'development', 'implementation']
76
+ }
77
+
78
+ # Comprehensive acronym expansions for embedded/technical domains
79
+ self.acronym_expansions = {
80
+ # Processor & Architecture
81
+ 'CPU': 'Central Processing Unit',
82
+ 'MCU': 'Microcontroller Unit',
83
+ 'MPU': 'Microprocessor Unit',
84
+ 'DSP': 'Digital Signal Processor',
85
+ 'GPU': 'Graphics Processing Unit',
86
+ 'ALU': 'Arithmetic Logic Unit',
87
+ 'FPU': 'Floating Point Unit',
88
+ 'MMU': 'Memory Management Unit',
89
+ 'ISA': 'Instruction Set Architecture',
90
+ 'RISC': 'Reduced Instruction Set Computer',
91
+ 'CISC': 'Complex Instruction Set Computer',
92
+
93
+ # Memory & Storage
94
+ 'RAM': 'Random Access Memory',
95
+ 'ROM': 'Read Only Memory',
96
+ 'EEPROM': 'Electrically Erasable Programmable ROM',
97
+ 'SRAM': 'Static Random Access Memory',
98
+ 'DRAM': 'Dynamic Random Access Memory',
99
+ 'FRAM': 'Ferroelectric Random Access Memory',
100
+ 'MRAM': 'Magnetoresistive Random Access Memory',
101
+ 'DMA': 'Direct Memory Access',
102
+
103
+ # Operating Systems & Software
104
+ 'RTOS': 'Real-Time Operating System',
105
+ 'OS': 'Operating System',
106
+ 'API': 'Application Programming Interface',
107
+ 'SDK': 'Software Development Kit',
108
+ 'IDE': 'Integrated Development Environment',
109
+ 'HAL': 'Hardware Abstraction Layer',
110
+ 'BSP': 'Board Support Package',
111
+
112
+ # Interrupts & Exceptions
113
+ 'ISR': 'Interrupt Service Routine',
114
+ 'IRQ': 'Interrupt Request',
115
+ 'NMI': 'Non-Maskable Interrupt',
116
+ 'NVIC': 'Nested Vectored Interrupt Controller',
117
+
118
+ # Communication Protocols
119
+ 'UART': 'Universal Asynchronous Receiver Transmitter',
120
+ 'USART': 'Universal Synchronous Asynchronous Receiver Transmitter',
121
+ 'SPI': 'Serial Peripheral Interface',
122
+ 'I2C': 'Inter-Integrated Circuit',
123
+ 'CAN': 'Controller Area Network',
124
+ 'USB': 'Universal Serial Bus',
125
+ 'TCP': 'Transmission Control Protocol',
126
+ 'UDP': 'User Datagram Protocol',
127
+ 'HTTP': 'HyperText Transfer Protocol',
128
+ 'MQTT': 'Message Queuing Telemetry Transport',
129
+
130
+ # Analog & Digital
131
+ 'ADC': 'Analog to Digital Converter',
132
+ 'DAC': 'Digital to Analog Converter',
133
+ 'PWM': 'Pulse Width Modulation',
134
+ 'GPIO': 'General Purpose Input Output',
135
+ 'JTAG': 'Joint Test Action Group',
136
+ 'SWD': 'Serial Wire Debug',
137
+
138
+ # Power & Clock
139
+ 'PLL': 'Phase Locked Loop',
140
+ 'VCO': 'Voltage Controlled Oscillator',
141
+ 'LDO': 'Low Dropout Regulator',
142
+ 'PMU': 'Power Management Unit',
143
+ 'RTC': 'Real Time Clock',
144
+
145
+ # Standards & Organizations
146
+ 'IEEE': 'Institute of Electrical and Electronics Engineers',
147
+ 'ISO': 'International Organization for Standardization',
148
+ 'ANSI': 'American National Standards Institute',
149
+ 'IEC': 'International Electrotechnical Commission'
150
+ }
151
+
152
+ # Compile regex patterns for efficiency
153
+ self._acronym_pattern = re.compile(r'\b[A-Z]{2,}\b')
154
+ self._technical_term_pattern = re.compile(r'\b\w+(?:-\w+)*\b', re.IGNORECASE)
155
+ self._question_indicators = re.compile(r'\b(?:how|what|why|when|where|which|explain|describe|define)\b', re.IGNORECASE)
156
+
157
+ # Question type classification keywords
158
+ self.question_type_keywords = {
159
+ 'conceptual': ['how', 'why', 'what', 'explain', 'describe', 'understand', 'concept', 'theory'],
160
+ 'technical': ['configure', 'implement', 'setup', 'install', 'code', 'program', 'register'],
161
+ 'procedural': ['steps', 'process', 'procedure', 'workflow', 'guide', 'tutorial'],
162
+ 'troubleshooting': ['error', 'problem', 'issue', 'debug', 'fix', 'solve', 'troubleshoot']
163
+ }
164
+
165
+ def analyze_query_characteristics(self, query: str) -> Dict[str, Any]:
166
+ """
167
+ Analyze query to determine optimal processing strategy.
168
+
169
+ Performs comprehensive analysis including:
170
+ - Technical term detection and counting
171
+ - Acronym presence identification
172
+ - Question type classification
173
+ - Complexity scoring based on multiple factors
174
+ - Optimal hybrid weight recommendation
175
+
176
+ Args:
177
+ query: User input query string
178
+
179
+ Returns:
180
+ Dictionary with comprehensive query analysis:
181
+ - technical_term_count: Number of domain-specific terms detected
182
+ - has_acronyms: Boolean indicating acronym presence
183
+ - question_type: 'conceptual', 'technical', 'procedural', 'mixed'
184
+ - complexity_score: Float 0-1 indicating query complexity
185
+ - recommended_dense_weight: Optimal weight for hybrid search
186
+ - detected_acronyms: List of acronyms found
187
+ - technical_terms: List of technical terms found
188
+
189
+ Performance: <2ms for typical queries
190
+ """
191
+ if not query or not query.strip():
192
+ return {
193
+ 'technical_term_count': 0,
194
+ 'has_acronyms': False,
195
+ 'question_type': 'unknown',
196
+ 'complexity_score': 0.0,
197
+ 'recommended_dense_weight': 0.7,
198
+ 'detected_acronyms': [],
199
+ 'technical_terms': []
200
+ }
201
+
202
+ query_lower = query.lower()
203
+ words = query.split()
204
+
205
+ # Detect acronyms
206
+ detected_acronyms = self._acronym_pattern.findall(query)
207
+ has_acronyms = len(detected_acronyms) > 0
208
+
209
+ # Detect technical terms
210
+ technical_terms = []
211
+ technical_term_count = 0
212
+
213
+ for word in words:
214
+ word_clean = re.sub(r'[^\w\-]', '', word.lower())
215
+ if word_clean in self.technical_synonyms:
216
+ technical_terms.append(word_clean)
217
+ technical_term_count += 1
218
+ # Also check for compound technical terms like "risc-v"
219
+ elif any(term in word_clean for term in ['risc-v', 'arm', 'cpu', 'mcu']):
220
+ technical_terms.append(word_clean)
221
+ technical_term_count += 1
222
+
223
+ # Add acronyms to technical term count
224
+ for acronym in detected_acronyms:
225
+ if acronym in self.acronym_expansions:
226
+ technical_term_count += 1
227
+
228
+ # Determine question type
229
+ question_type = self._classify_question_type(query_lower)
230
+
231
+ # Calculate complexity score (0-1)
232
+ complexity_factors = [
233
+ len(words) / 20.0, # Word count factor (normalized to 20 words max)
234
+ technical_term_count / 5.0, # Technical density (normalized to 5 terms max)
235
+ len(detected_acronyms) / 3.0, # Acronym density (normalized to 3 acronyms max)
236
+ 1.0 if self._question_indicators.search(query) else 0.5, # Question complexity
237
+ ]
238
+ complexity_score = min(1.0, sum(complexity_factors) / len(complexity_factors))
239
+
240
+ # Determine recommended dense weight based on analysis
241
+ recommended_dense_weight = self._calculate_optimal_weight(
242
+ question_type, technical_term_count, has_acronyms, complexity_score
243
+ )
244
+
245
+ return {
246
+ 'technical_term_count': technical_term_count,
247
+ 'has_acronyms': has_acronyms,
248
+ 'question_type': question_type,
249
+ 'complexity_score': complexity_score,
250
+ 'recommended_dense_weight': recommended_dense_weight,
251
+ 'detected_acronyms': detected_acronyms,
252
+ 'technical_terms': technical_terms,
253
+ 'word_count': len(words),
254
+ 'has_question_indicators': bool(self._question_indicators.search(query))
255
+ }
256
+
257
+ def _classify_question_type(self, query_lower: str) -> str:
258
+ """Classify query into conceptual, technical, procedural, or mixed categories."""
259
+ type_scores = defaultdict(int)
260
+
261
+ for question_type, keywords in self.question_type_keywords.items():
262
+ for keyword in keywords:
263
+ if keyword in query_lower:
264
+ type_scores[question_type] += 1
265
+
266
+ if not type_scores:
267
+ return 'mixed'
268
+
269
+ # Return type with highest score, or 'mixed' if tie
270
+ max_score = max(type_scores.values())
271
+ top_types = [t for t, s in type_scores.items() if s == max_score]
272
+
273
+ return top_types[0] if len(top_types) == 1 else 'mixed'
274
+
275
+ def _calculate_optimal_weight(self, question_type: str, tech_terms: int,
276
+ has_acronyms: bool, complexity: float) -> float:
277
+ """Calculate optimal dense weight based on query characteristics."""
278
+
279
+ # Base weights by question type
280
+ base_weights = {
281
+ 'technical': 0.3, # Favor sparse for technical precision
282
+ 'conceptual': 0.8, # Favor dense for conceptual understanding
283
+ 'procedural': 0.5, # Balanced for step-by-step queries
284
+ 'troubleshooting': 0.4, # Slight sparse favor for specific issues
285
+ 'mixed': 0.7, # Default balanced
286
+ 'unknown': 0.7 # Default balanced
287
+ }
288
+
289
+ weight = base_weights.get(question_type, 0.7)
290
+
291
+ # Adjust based on technical term density
292
+ if tech_terms > 2:
293
+ weight -= 0.2 # More technical → favor sparse
294
+ elif tech_terms == 0:
295
+ weight += 0.1 # Less technical → favor dense
296
+
297
+ # Adjust based on acronym presence
298
+ if has_acronyms:
299
+ weight -= 0.1 # Acronyms → favor sparse for exact matching
300
+
301
+ # Adjust based on complexity
302
+ if complexity > 0.8:
303
+ weight += 0.1 # High complexity → favor dense for understanding
304
+ elif complexity < 0.3:
305
+ weight -= 0.1 # Low complexity → favor sparse for precision
306
+
307
+ # Ensure weight stays within valid bounds
308
+ return max(0.1, min(0.9, weight))
309
+
310
+ def expand_technical_terms(self, query: str, max_expansions: int = 1) -> str:
311
+ """
312
+ Expand query with technical synonyms while preventing bloat.
313
+
314
+ Conservative expansion strategy:
315
+ - Maximum 1 synonym per technical term by default
316
+ - Prioritizes most relevant/common synonyms
317
+ - Maintains semantic focus while improving recall
318
+
319
+ Args:
320
+ query: Original user query
321
+ max_expansions: Maximum synonyms per term (default 1 for focus)
322
+
323
+ Returns:
324
+ Conservatively enhanced query
325
+
326
+ Example:
327
+ Input: "CPU performance optimization"
328
+ Output: "CPU processor performance optimization"
329
+
330
+ Performance: <3ms for typical queries
331
+ """
332
+ if not query or not query.strip():
333
+ return query
334
+
335
+ words = query.split()
336
+
337
+ # Conservative expansion: only add most relevant synonym
338
+ expansion_candidates = []
339
+
340
+ for word in words:
341
+ word_clean = re.sub(r'[^\w\-]', '', word.lower())
342
+
343
+ # Check for direct synonym expansion
344
+ if word_clean in self.technical_synonyms:
345
+ synonyms = self.technical_synonyms[word_clean]
346
+ # Add only the first (most common) synonym
347
+ if synonyms and max_expansions > 0:
348
+ expansion_candidates.append(synonyms[0])
349
+
350
+ # Limit total expansion to prevent bloat
351
+ max_total_expansions = min(2, len(words) // 2) # At most 50% expansion
352
+ selected_expansions = expansion_candidates[:max_total_expansions]
353
+
354
+ # Reconstruct with minimal expansion
355
+ if selected_expansions:
356
+ return ' '.join(words + selected_expansions)
357
+ else:
358
+ return query
359
+
360
+ def detect_and_expand_acronyms(self, query: str, conservative: bool = True) -> str:
361
+ """
362
+ Detect technical acronyms and add their expansions conservatively.
363
+
364
+ Conservative approach to prevent query bloat:
365
+ - Limits acronym expansions to most relevant ones
366
+ - Preserves original acronyms for exact matching
367
+ - Maintains query focus and performance
368
+
369
+ Args:
370
+ query: Query potentially containing acronyms
371
+ conservative: If True, limits expansion to prevent bloat
372
+
373
+ Returns:
374
+ Query with selective acronym expansions
375
+
376
+ Example:
377
+ Input: "RTOS scheduling algorithm"
378
+ Output: "RTOS Real-Time Operating System scheduling algorithm"
379
+
380
+ Performance: <2ms for typical queries
381
+ """
382
+ if not query or not query.strip():
383
+ return query
384
+
385
+ # Find all acronyms in the query
386
+ acronyms = self._acronym_pattern.findall(query)
387
+
388
+ if not acronyms:
389
+ return query
390
+
391
+ # Conservative mode: limit expansions
392
+ if conservative and len(acronyms) > 2:
393
+ # Only expand first 2 acronyms to prevent bloat
394
+ acronyms = acronyms[:2]
395
+
396
+ result = query
397
+
398
+ # Expand selected acronyms
399
+ for acronym in acronyms:
400
+ if acronym in self.acronym_expansions:
401
+ expansion = self.acronym_expansions[acronym]
402
+ # Add expansion after the acronym (preserving original)
403
+ result = result.replace(acronym, f"{acronym} {expansion}", 1)
404
+
405
+ return result
406
+
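A minimal sketch of the same pass in isolation, with an assumed one-entry expansion table and an assumed pattern shape for the acronym regex; note that the original acronym token is kept so sparse retrieval can still match it exactly:

import re

acronym_expansions = {'RTOS': 'Real-Time Operating System'}  # illustrative
acronym_pattern = re.compile(r'\b[A-Z]{2,}\b')  # assumed pattern shape

def expand_acronyms(query: str, conservative: bool = True) -> str:
    acronyms = acronym_pattern.findall(query)
    if conservative:
        acronyms = acronyms[:2]  # cap to the first two acronyms
    for acronym in acronyms:
        if acronym in acronym_expansions:
            # keep the original token, insert the long form right after it
            query = query.replace(acronym, f'{acronym} {acronym_expansions[acronym]}', 1)
    return query

print(expand_acronyms('RTOS scheduling algorithm'))
# -> RTOS Real-Time Operating System scheduling algorithm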
407
+ def adaptive_hybrid_weighting(self, query: str) -> float:
408
+ """
409
+ Determine optimal dense_weight based on query characteristics.
410
+
411
+ Analyzes query to automatically determine the best balance between
412
+ dense semantic search and sparse keyword matching for optimal results.
413
+
414
+ Strategy:
415
+ - Technical/exact queries → lower dense_weight (favor sparse/BM25)
416
+ - Conceptual questions → higher dense_weight (favor semantic)
417
+ - Mixed queries → balanced weighting based on complexity
418
+
419
+ Args:
420
+ query: User query string
421
+
422
+ Returns:
423
+ Float between 0.1 and 0.9 representing optimal dense_weight
424
+
425
+ Performance: <2ms analysis time
426
+ """
427
+ analysis = self.analyze_query_characteristics(query)
428
+ return analysis['recommended_dense_weight']
429
+
430
+ def enhance_query(self, query: str, conservative: bool = True) -> Dict[str, Any]:
431
+ """
432
+ Comprehensive query enhancement with performance and quality focus.
433
+
434
+ Optimized enhancement strategy:
435
+ - Conservative expansion to maintain semantic focus
436
+ - Performance-first approach with minimal overhead
437
+ - Quality validation to ensure improvements
438
+
439
+ Args:
440
+ query: Original user query
441
+ conservative: Use conservative expansion (recommended for production)
442
+
443
+ Returns:
444
+ Dictionary containing:
445
+ - enhanced_query: Optimized enhanced query
446
+ - optimal_weight: Recommended dense weight
447
+ - analysis: Complete query analysis
448
+ - enhancement_metadata: Performance and quality metrics
449
+
450
+ Performance: <5ms total enhancement time
451
+ """
452
+ start_time = time.perf_counter()
453
+
454
+ # Fast analysis
455
+ analysis = self.analyze_query_characteristics(query)
456
+
457
+ # Conservative enhancement approach
458
+ if conservative:
459
+ enhanced_query = self.expand_technical_terms(query, max_expansions=1)
460
+ enhanced_query = self.detect_and_expand_acronyms(enhanced_query, conservative=True)
461
+ else:
462
+ # Legacy aggressive expansion
463
+ enhanced_query = self.expand_technical_terms(query, max_expansions=2)
464
+ enhanced_query = self.detect_and_expand_acronyms(enhanced_query, conservative=False)
465
+
466
+ # Quality validation: prevent excessive bloat
467
+ expansion_ratio = len(enhanced_query.split()) / len(query.split()) if query.split() else 1.0
468
+ if expansion_ratio > 2.5: # Limit to 2.5x expansion
469
+ # Fallback to minimal enhancement
470
+ enhanced_query = self.expand_technical_terms(query, max_expansions=0)
471
+ enhanced_query = self.detect_and_expand_acronyms(enhanced_query, conservative=True)
472
+ expansion_ratio = len(enhanced_query.split()) / len(query.split()) if query.split() else 1.0
473
+
474
+ # Calculate optimal weight
475
+ optimal_weight = analysis['recommended_dense_weight']
476
+
477
+ enhancement_time = time.perf_counter() - start_time
478
+
479
+ return {
480
+ 'enhanced_query': enhanced_query,
481
+ 'optimal_weight': optimal_weight,
482
+ 'analysis': analysis,
483
+ 'enhancement_metadata': {
484
+ 'original_length': len(query.split()),
485
+ 'enhanced_length': len(enhanced_query.split()),
486
+ 'expansion_ratio': expansion_ratio,
487
+ 'processing_time_ms': enhancement_time * 1000,
488
+ 'techniques_applied': ['conservative_expansion', 'quality_validation', 'adaptive_weighting'],
489
+ 'conservative_mode': conservative
490
+ }
491
+ }
492
+
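Assuming an instance of the enhancer class (the constructor name is hypothetical here), a call and the returned structure look like:

enhancer = QueryEnhancer()  # hypothetical instantiation
result = enhancer.enhance_query('RTOS latency jitter')
print(result['enhanced_query'])                           # expanded query string
print(result['optimal_weight'])                           # dense weight in [0.1, 0.9]
print(result['enhancement_metadata']['expansion_ratio'])  # kept <= 2.5 by the fallback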
493
+ def expand_technical_terms_with_vocabulary(
494
+ self,
495
+ query: str,
496
+ vocabulary_index: Optional['VocabularyIndex'] = None,
497
+ min_frequency: int = 3
498
+ ) -> str:
499
+ """
500
+ Expand query with vocabulary-aware synonym filtering.
501
+
502
+ Only adds synonyms that exist in the document corpus with sufficient
503
+ frequency to ensure relevance and prevent query dilution.
504
+
505
+ Args:
506
+ query: Original query
507
+ vocabulary_index: Optional vocabulary index for filtering
508
+ min_frequency: Minimum term frequency required
509
+
510
+ Returns:
511
+ Enhanced query with validated synonyms
512
+
513
+ Performance: <2ms with vocabulary validation
514
+ """
515
+ if not query or not query.strip():
516
+ return query
517
+
518
+ if vocabulary_index is None:
519
+ # Fallback to standard expansion
520
+ return self.expand_technical_terms(query, max_expansions=1)
521
+
522
+ words = query.split()
523
+ expanded_terms = []
524
+
525
+ for word in words:
526
+ word_clean = re.sub(r'[^\w\-]', '', word.lower())
527
+
528
+ # Check for synonym expansion
529
+ if word_clean in self.technical_synonyms:
530
+ synonyms = self.technical_synonyms[word_clean]
531
+
532
+ # Filter synonyms through vocabulary
533
+ valid_synonyms = vocabulary_index.filter_synonyms(
534
+ synonyms,
535
+ min_frequency=min_frequency
536
+ )
537
+
538
+ # Add only the best valid synonym
539
+ if valid_synonyms:
540
+ expanded_terms.append(valid_synonyms[0])
541
+
542
+ # Reconstruct query with validated expansions
543
+ if expanded_terms:
544
+ return ' '.join(words + expanded_terms)
545
+ else:
546
+ return query
547
+
548
+ def enhance_query_with_vocabulary(
549
+ self,
550
+ query: str,
551
+ vocabulary_index: Optional['VocabularyIndex'] = None,
552
+ min_frequency: int = 3,
553
+ require_technical: bool = False
554
+ ) -> Dict[str, Any]:
555
+ """
556
+ Enhanced query processing with vocabulary validation.
557
+
558
+ Uses corpus vocabulary to ensure all expansions are relevant
559
+ and actually present in the documents.
560
+
561
+ Args:
562
+ query: Original query
563
+ vocabulary_index: Vocabulary index for validation
564
+ min_frequency: Minimum term frequency
565
+ require_technical: Only expand with technical terms
566
+
567
+ Returns:
568
+ Enhanced query with vocabulary-aware expansion
569
+ """
570
+ start_time = time.perf_counter()
571
+
572
+ # Perform analysis
573
+ analysis = self.analyze_query_characteristics(query)
574
+
575
+ # Vocabulary-aware enhancement
576
+ if vocabulary_index:
577
+ # Technical term expansion with validation
578
+ enhanced_query = self.expand_technical_terms_with_vocabulary(
579
+ query, vocabulary_index, min_frequency
580
+ )
581
+
582
+ # Acronym expansion (already conservative)
583
+ enhanced_query = self.detect_and_expand_acronyms(enhanced_query, conservative=True)
584
+
585
+ # Track vocabulary validation
586
+ validation_applied = True
587
+
588
+ # Detect domain if available
589
+ detected_domain = vocabulary_index.detect_domain()
590
+ else:
591
+ # Fallback to standard enhancement
592
+ enhanced_query = self.expand_technical_terms(query, max_expansions=1)
593
+ enhanced_query = self.detect_and_expand_acronyms(enhanced_query, conservative=True)
594
+ validation_applied = False
595
+ detected_domain = 'unknown'
596
+
597
+ # Calculate metrics
598
+ expansion_ratio = len(enhanced_query.split()) / len(query.split()) if query.split() else 1.0
599
+ enhancement_time = time.perf_counter() - start_time
600
+
601
+ return {
602
+ 'enhanced_query': enhanced_query,
603
+ 'optimal_weight': analysis['recommended_dense_weight'],
604
+ 'analysis': analysis,
605
+ 'enhancement_metadata': {
606
+ 'original_length': len(query.split()),
607
+ 'enhanced_length': len(enhanced_query.split()),
608
+ 'expansion_ratio': expansion_ratio,
609
+ 'processing_time_ms': enhancement_time * 1000,
610
+ 'techniques_applied': ['vocabulary_validation', 'conservative_expansion'],
611
+ 'vocabulary_validated': validation_applied,
612
+ 'detected_domain': detected_domain,
613
+ 'min_frequency_threshold': min_frequency
614
+ }
615
+ }
616
+
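Wiring this to a corpus vocabulary, reusing the hypothetical enhancer instance from the example above (chunk text invented; assumes the enhancer's synonym table covers 'cpu'):

vocab = VocabularyIndex()
vocab.build_from_chunks([{'text': 'The processor pipeline stalls on cache misses.'}])
result = enhancer.enhance_query_with_vocabulary('CPU stalls', vocab, min_frequency=1)
print(result['enhancement_metadata']['vocabulary_validated'])  # True
print(result['enhancement_metadata']['detected_domain'])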
617
+ def get_enhancement_stats(self) -> Dict[str, Any]:
618
+ """
619
+ Get statistics about the enhancement system capabilities.
620
+
621
+ Returns:
622
+ Dictionary with system statistics and capabilities
623
+ """
624
+ return {
625
+ 'technical_synonyms_count': len(self.technical_synonyms),
626
+ 'acronym_expansions_count': len(self.acronym_expansions),
627
+ 'supported_domains': [
628
+ 'embedded_systems', 'processor_architecture', 'memory_systems',
629
+ 'communication_protocols', 'real_time_systems', 'power_management'
630
+ ],
631
+ 'question_types_supported': list(self.question_type_keywords.keys()),
632
+ 'weight_range': {'min': 0.1, 'max': 0.9, 'default': 0.7},
633
+ 'performance_targets': {
634
+ 'enhancement_time_ms': '<10',
635
+ 'accuracy_improvement': '>10%',
636
+ 'memory_overhead': '<1MB'
637
+ },
638
+ 'vocabulary_features': {
639
+ 'vocabulary_aware_expansion': True,
640
+ 'min_frequency_filtering': True,
641
+ 'domain_detection': True,
642
+ 'technical_term_priority': True
643
+ }
644
+ }
shared_utils/retrieval/__init__.py ADDED
@@ -0,0 +1,8 @@
1
+ """
2
+ Retrieval utilities for hybrid RAG systems.
3
+ Combines dense semantic search with sparse keyword matching.
4
+ """
5
+
6
+ from .hybrid_search import HybridRetriever
7
+
8
+ __all__ = ['HybridRetriever']
shared_utils/retrieval/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (380 Bytes).
 
shared_utils/retrieval/__pycache__/hybrid_search.cpython-312.pyc ADDED
Binary file (10.5 kB).
 
shared_utils/retrieval/__pycache__/vocabulary_index.cpython-312.pyc ADDED
Binary file (12.6 kB).
 
shared_utils/retrieval/hybrid_search.py ADDED
@@ -0,0 +1,277 @@
1
+ """
2
+ Hybrid retrieval combining dense semantic search with sparse BM25 keyword matching.
3
+ Uses Reciprocal Rank Fusion (RRF) to combine results from both approaches.
4
+ """
5
+
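For reference, classic RRF scores each document as the sum over retrievers of 1 / (k + rank). The fusion actually used below is imported from src.fusion, whose implementation is not shown here; this standalone sketch only illustrates the scheme that the rrf_k parameter refers to:

from collections import defaultdict

def rrf_sketch(rankings, k=10):
    # rankings: list of ranked result lists, each [(doc_id, score), ...] best-first
    fused = defaultdict(float)
    for results in rankings:
        for rank, (doc_id, _score) in enumerate(results, start=1):
            fused[doc_id] += 1.0 / (k + rank)
    return sorted(fused.items(), key=lambda item: item[1], reverse=True)

dense = [(0, 0.92), (1, 0.85)]
sparse = [(1, 7.1), (2, 5.4)]
print(rrf_sketch([dense, sparse]))
# doc 1 ranks first: it appears near the top of both lists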
6
+ from typing import List, Dict, Tuple, Optional, Any
7
+ import numpy as np
8
+ from pathlib import Path
9
+ import sys
10
+
11
+ # Add project root to Python path for imports
12
+ project_root = Path(__file__).parent.parent.parent / "project-1-technical-rag"
13
+ sys.path.append(str(project_root))
14
+
15
+ from src.sparse_retrieval import BM25SparseRetriever
16
+ from src.fusion import reciprocal_rank_fusion, adaptive_fusion
17
+ from shared_utils.embeddings.generator import generate_embeddings
18
+ import faiss
19
+
20
+
21
+ class HybridRetriever:
22
+ """
23
+ Hybrid retrieval system combining dense semantic search with sparse BM25.
24
+
25
+ Optimized for technical documentation where both semantic similarity
26
+ and exact keyword matching are important for retrieval quality.
27
+
28
+ Performance: Sub-second search on 1000+ document corpus
29
+ """
30
+
31
+ def __init__(
32
+ self,
33
+ dense_weight: float = 0.7,
34
+ embedding_model: str = "sentence-transformers/multi-qa-MiniLM-L6-cos-v1",
35
+ use_mps: bool = True,
36
+ bm25_k1: float = 1.2,
37
+ bm25_b: float = 0.75,
38
+ rrf_k: int = 10
39
+ ):
40
+ """
41
+ Initialize hybrid retriever with dense and sparse components.
42
+
43
+ Args:
44
+ dense_weight: Weight for semantic similarity in fusion (0.7 default)
45
+ embedding_model: Sentence transformer model name
46
+ use_mps: Use Apple Silicon MPS acceleration for embeddings
47
+ bm25_k1: BM25 term frequency saturation parameter
48
+ bm25_b: BM25 document length normalization parameter
49
+ rrf_k: Reciprocal Rank Fusion constant; smaller values give stronger preference to top-ranked results (default 10)
50
+
51
+ Raises:
52
+ ValueError: If parameters are invalid
53
+ """
54
+ if not 0 <= dense_weight <= 1:
55
+ raise ValueError("dense_weight must be between 0 and 1")
56
+
57
+ self.dense_weight = dense_weight
58
+ self.embedding_model = embedding_model
59
+ self.use_mps = use_mps
60
+ self.rrf_k = rrf_k
61
+
62
+ # Initialize sparse retriever
63
+ self.sparse_retriever = BM25SparseRetriever(k1=bm25_k1, b=bm25_b)
64
+
65
+ # Dense retrieval components (initialized on first index)
66
+ self.dense_index: Optional[faiss.Index] = None
67
+ self.chunks: List[Dict] = []
68
+ self.embeddings: Optional[np.ndarray] = None
69
+
70
+ def index_documents(self, chunks: List[Dict]) -> None:
71
+ """
72
+ Index documents for both dense and sparse retrieval.
73
+
74
+ Args:
75
+ chunks: List of chunk dictionaries with 'text' field
76
+
77
+ Raises:
78
+ ValueError: If chunks is empty or malformed
79
+
80
+ Performance: ~100 chunks/second for complete indexing
81
+ """
82
+ if not chunks:
83
+ raise ValueError("Cannot index empty chunk list")
84
+
85
+ print(f"Indexing {len(chunks)} chunks for hybrid retrieval...")
86
+
87
+ # Store chunks for retrieval
88
+ self.chunks = chunks
89
+
90
+ # Index for sparse retrieval
91
+ print("Building BM25 sparse index...")
92
+ self.sparse_retriever.index_documents(chunks)
93
+
94
+ # Index for dense retrieval
95
+ print("Building dense semantic index...")
96
+ texts = [chunk['text'] for chunk in chunks]
97
+
98
+ # Generate embeddings
99
+ self.embeddings = generate_embeddings(
100
+ texts,
101
+ model_name=self.embedding_model,
102
+ use_mps=self.use_mps
103
+ )
104
+
105
+ # Create FAISS index
106
+ embedding_dim = self.embeddings.shape[1]
107
+ self.dense_index = faiss.IndexFlatIP(embedding_dim) # Inner product for cosine similarity
108
+
109
+ # Normalize embeddings for cosine similarity
110
+ faiss.normalize_L2(self.embeddings)
111
+ self.dense_index.add(self.embeddings)
112
+
113
+ print(f"Hybrid indexing complete: {len(chunks)} chunks ready for search")
114
+
115
+ def search(
116
+ self,
117
+ query: str,
118
+ top_k: int = 10,
119
+ dense_top_k: Optional[int] = None,
120
+ sparse_top_k: Optional[int] = None
121
+ ) -> List[Tuple[int, float, Dict]]:
122
+ """
123
+ Hybrid search combining dense and sparse retrieval via weighted adaptive fusion.
124
+
125
+ Args:
126
+ query: Search query string
127
+ top_k: Final number of results to return
128
+ dense_top_k: Results from dense search (default: 2*top_k)
129
+ sparse_top_k: Results from sparse search (default: 2*top_k)
130
+
131
+ Returns:
132
+ List of (chunk_index, fusion_score, chunk_dict) tuples
133
+
134
+ Raises:
135
+ ValueError: If not indexed or invalid parameters
136
+
137
+ Performance: <200ms for 1000+ document corpus
138
+ """
139
+ if self.dense_index is None:
140
+ raise ValueError("Must call index_documents() before searching")
141
+
142
+ if not query.strip():
143
+ return []
144
+
145
+ if top_k <= 0:
146
+ raise ValueError("top_k must be positive")
147
+
148
+ # Set default intermediate result counts
149
+ if dense_top_k is None:
150
+ dense_top_k = min(2 * top_k, len(self.chunks))
151
+ if sparse_top_k is None:
152
+ sparse_top_k = min(2 * top_k, len(self.chunks))
153
+
154
+ # Dense semantic search
155
+ dense_results = self._dense_search(query, dense_top_k)
156
+
157
+ # Sparse BM25 search
158
+ sparse_results = self.sparse_retriever.search(query, sparse_top_k)
159
+
160
+ # Combine using Adaptive Fusion (better for small result sets)
161
+ fused_results = adaptive_fusion(
162
+ dense_results=dense_results,
163
+ sparse_results=sparse_results,
164
+ dense_weight=self.dense_weight,
165
+ result_size=top_k
166
+ )
167
+
168
+ # Prepare final results with chunk content and apply source diversity
169
+ final_results = []
170
+ for chunk_idx, fusion_score in fused_results:
171
+ chunk_dict = self.chunks[chunk_idx]
172
+ final_results.append((chunk_idx, fusion_score, chunk_dict))
173
+
174
+ # Apply source diversity enhancement
175
+ diverse_results = self._enhance_source_diversity(final_results, top_k)
176
+
177
+ return diverse_results
178
+
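A hypothetical end-to-end run (chunk contents invented for illustration; assumes the src.* dependencies imported above are available on the path):

retriever = HybridRetriever(dense_weight=0.7)
retriever.index_documents([
    {'text': 'RISC-V defines 32 integer registers.', 'source': 'isa_spec.pdf'},
    {'text': 'The RTOS scheduler uses priority queues.', 'source': 'rtos_guide.pdf'},
])
for idx, score, chunk in retriever.search('RISC-V registers', top_k=2):
    print(idx, round(score, 3), chunk['source'])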
179
+ def _dense_search(self, query: str, top_k: int) -> List[Tuple[int, float]]:
180
+ """
181
+ Perform dense semantic search using FAISS.
182
+
183
+ Args:
184
+ query: Search query
185
+ top_k: Number of results to return
186
+
187
+ Returns:
188
+ List of (chunk_index, similarity_score) tuples
189
+ """
190
+ # Generate query embedding
191
+ query_embedding = generate_embeddings(
192
+ [query],
193
+ model_name=self.embedding_model,
194
+ use_mps=self.use_mps
195
+ )
196
+
197
+ # Normalize for cosine similarity
198
+ faiss.normalize_L2(query_embedding)
199
+
200
+ # Search dense index
201
+ similarities, indices = self.dense_index.search(query_embedding, top_k)
202
+
203
+ # Convert to required format
204
+ results = [
205
+ (int(indices[0][i]), float(similarities[0][i]))
206
+ for i in range(len(indices[0]))
207
+ if indices[0][i] != -1 # Filter out invalid results
208
+ ]
209
+
210
+ return results
211
+
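A quick numeric check of why the index is IndexFlatIP over L2-normalized vectors: for unit vectors, the inner product equals cosine similarity:

import numpy as np

a = np.array([3.0, 4.0]); a /= np.linalg.norm(a)   # unit vector
b = np.array([4.0, 3.0]); b /= np.linalg.norm(b)   # unit vector
print(float(a @ b))  # 0.96, the cosine of the angle between the raw vectors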
212
+ def _enhance_source_diversity(
213
+ self,
214
+ results: List[Tuple[int, float, Dict]],
215
+ top_k: int,
216
+ max_per_source: int = 2
217
+ ) -> List[Tuple[int, float, Dict]]:
218
+ """
219
+ Enhance source diversity in retrieval results to prevent over-focusing on single documents.
220
+
221
+ Args:
222
+ results: List of (chunk_idx, score, chunk_dict) tuples sorted by relevance
223
+ top_k: Maximum number of results to return
224
+ max_per_source: Maximum chunks allowed per source document
225
+
226
+ Returns:
227
+ Diversified results maintaining relevance while improving source coverage
228
+ """
229
+ if not results:
230
+ return []
231
+
232
+ source_counts = {}
233
+ diverse_results = []
234
+
235
+ # First pass: Add highest scoring results respecting source limits
236
+ for chunk_idx, score, chunk_dict in results:
237
+ source = chunk_dict.get('source', 'unknown')
238
+ current_count = source_counts.get(source, 0)
239
+
240
+ if current_count < max_per_source:
241
+ diverse_results.append((chunk_idx, score, chunk_dict))
242
+ source_counts[source] = current_count + 1
243
+
244
+ if len(diverse_results) >= top_k:
245
+ break
246
+
247
+ # Second pass: If we still need more results, relax source constraints
248
+ if len(diverse_results) < top_k:
249
+ for chunk_idx, score, chunk_dict in results:
250
+ if all(chunk_idx != idx for idx, _, _ in diverse_results):  # compare by index, not full tuple
251
+ diverse_results.append((chunk_idx, score, chunk_dict))
252
+
253
+ if len(diverse_results) >= top_k:
254
+ break
255
+
256
+ return diverse_results[:top_k]
257
+
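An illustration of the two-pass rule with assumed sources A, B, C, max_per_source=2, and top_k=4:

# ranked input by source:  A, A, A, B, A, C
# first pass keeps:        A, A, B, C   (third and fourth A deferred by the per-source cap)
# second pass runs only when the cap left fewer than top_k results,
# and then refills from the remaining hits in relevance order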
258
+ def get_retrieval_stats(self) -> Dict[str, Any]:
259
+ """
260
+ Get statistics about the indexed corpus and retrieval performance.
261
+
262
+ Returns:
263
+ Dictionary with corpus statistics
264
+ """
265
+ if not self.chunks:
266
+ return {"status": "not_indexed"}
267
+
268
+ return {
269
+ "status": "indexed",
270
+ "total_chunks": len(self.chunks),
271
+ "dense_index_size": self.dense_index.ntotal if self.dense_index else 0,
272
+ "embedding_dim": self.embeddings.shape[1] if self.embeddings is not None else 0,
273
+ "sparse_indexed_chunks": len(self.sparse_retriever.chunk_mapping),
274
+ "dense_weight": self.dense_weight,
275
+ "sparse_weight": 1.0 - self.dense_weight,
276
+ "rrf_k": self.rrf_k
277
+ }
shared_utils/retrieval/vocabulary_index.py ADDED
@@ -0,0 +1,260 @@
1
+ """
2
+ Vocabulary index for corpus-aware query enhancement.
3
+
4
+ Tracks all unique terms in the document corpus to enable intelligent
5
+ synonym expansion that only adds terms actually present in documents.
6
+ """
7
+
8
+ from typing import Set, Dict, List, Optional, Tuple, Any
9
+ from collections import defaultdict
10
+ import re
11
+ from pathlib import Path
12
+ import json
13
+
14
+
15
+ class VocabularyIndex:
16
+ """
17
+ Maintains vocabulary statistics for intelligent query enhancement.
18
+
19
+ Features:
20
+ - Tracks all unique terms in document corpus
21
+ - Stores term frequencies for relevance weighting
22
+ - Identifies technical terms and domain vocabulary
23
+ - Enables vocabulary-aware synonym expansion
24
+
25
+ Performance:
26
+ - Build time: ~1s per 1000 chunks
27
+ - Memory: ~3MB for 80K unique terms
28
+ - Lookup: O(1) set operations
29
+ """
30
+
31
+ def __init__(self):
32
+ """Initialize empty vocabulary index."""
33
+ self.vocabulary: Set[str] = set()
34
+ self.term_frequencies: Dict[str, int] = defaultdict(int)
35
+ self.technical_terms: Set[str] = set()
36
+ self.document_frequencies: Dict[str, int] = defaultdict(int)
37
+ self.total_documents = 0
38
+ self.total_terms = 0
39
+
40
+ # Regex for term extraction
41
+ self._term_pattern = re.compile(r'\b[a-zA-Z][a-zA-Z0-9\-_]*\b')
42
+ self._technical_pattern = re.compile(r'\b[A-Z]{2,}|[a-zA-Z]+[\-_][a-zA-Z]+|\b\d+[a-zA-Z]+\b')
43
+
44
+ def build_from_chunks(self, chunks: List[Dict]) -> None:
45
+ """
46
+ Build vocabulary index from document chunks.
47
+
48
+ Args:
49
+ chunks: List of document chunks with 'text' field
50
+
51
+ Performance: ~1s per 1000 chunks
52
+ """
53
+ self.total_documents = len(chunks)
54
+
55
+ for chunk in chunks:
56
+ text = chunk.get('text', '')
57
+
58
+ # Extract and process terms
59
+ terms = self._extract_terms(text)
60
+ unique_terms = set(terms)
61
+
62
+ # Update vocabulary
63
+ self.vocabulary.update(unique_terms)
64
+
65
+ # Update frequencies
66
+ for term in terms:
67
+ self.term_frequencies[term] += 1
68
+ self.total_terms += 1
69
+
70
+ # Update document frequencies
71
+ for term in unique_terms:
72
+ self.document_frequencies[term] += 1
73
+
74
+ # Identify technical terms
75
+ technical = self._extract_technical_terms(text)
76
+ self.technical_terms.update(technical)
77
+
78
+ def _extract_terms(self, text: str) -> List[str]:
79
+ """Extract normalized terms from text."""
80
+ # Convert to lowercase and extract words
81
+ text_lower = text.lower()
82
+ terms = self._term_pattern.findall(text_lower)
83
+
84
+ # Filter short terms
85
+ return [term for term in terms if len(term) > 2]
86
+
87
+ def _extract_technical_terms(self, text: str) -> Set[str]:
88
+ """Extract technical terms (acronyms, hyphenated, etc)."""
89
+ technical = set()
90
+
91
+ # Find potential technical terms
92
+ matches = self._technical_pattern.findall(text)
93
+
94
+ for match in matches:
95
+ # Normalize but preserve technical nature
96
+ normalized = match.lower()
97
+ if len(normalized) > 2:
98
+ technical.add(normalized)
99
+
100
+ return technical
101
+
102
+ def contains(self, term: str) -> bool:
103
+ """Check if term exists in vocabulary."""
104
+ return term.lower() in self.vocabulary
105
+
106
+ def get_frequency(self, term: str) -> int:
107
+ """Get term frequency in corpus."""
108
+ return self.term_frequencies.get(term.lower(), 0)
109
+
110
+ def get_document_frequency(self, term: str) -> int:
111
+ """Get number of documents containing term."""
112
+ return self.document_frequencies.get(term.lower(), 0)
113
+
114
+ def is_common_term(self, term: str, min_frequency: int = 5) -> bool:
115
+ """Check if term appears frequently enough."""
116
+ return self.get_frequency(term) >= min_frequency
117
+
118
+ def is_technical_term(self, term: str) -> bool:
119
+ """Check if term is identified as technical."""
120
+ return term.lower() in self.technical_terms
121
+
122
+ def filter_synonyms(self, synonyms: List[str],
123
+ min_frequency: int = 3,
124
+ require_technical: bool = False) -> List[str]:
125
+ """
126
+ Filter synonym list to only include terms in vocabulary.
127
+
128
+ Args:
129
+ synonyms: List of potential synonyms
130
+ min_frequency: Minimum term frequency required
131
+ require_technical: Only include technical terms
132
+
133
+ Returns:
134
+ Filtered list of valid synonyms
135
+ """
136
+ valid_synonyms = []
137
+
138
+ for synonym in synonyms:
139
+ # Check existence
140
+ if not self.contains(synonym):
141
+ continue
142
+
143
+ # Check frequency threshold
144
+ if self.get_frequency(synonym) < min_frequency:
145
+ continue
146
+
147
+ # Check technical requirement
148
+ if require_technical and not self.is_technical_term(synonym):
149
+ continue
150
+
151
+ valid_synonyms.append(synonym)
152
+
153
+ return valid_synonyms
154
+
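A small end-to-end sketch of building the index and filtering candidate synonyms (chunk texts invented for illustration):

vocab = VocabularyIndex()
vocab.build_from_chunks([
    {'text': 'The processor core executes RISC-V instructions.'},
    {'text': 'Each processor register is 32 bits wide.'},
])
print(vocab.contains('processor'))                          # True
print(vocab.get_frequency('processor'))                     # 2
print(vocab.filter_synonyms(['processor', 'gpu'], min_frequency=1))
# -> ['processor']  ('gpu' never occurs in this corpus)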
155
+ def get_vocabulary_stats(self) -> Dict[str, Any]:
156
+ """Get comprehensive vocabulary statistics."""
157
+ return {
158
+ 'unique_terms': len(self.vocabulary),
159
+ 'total_terms': self.total_terms,
160
+ 'technical_terms': len(self.technical_terms),
161
+ 'total_documents': self.total_documents,
162
+ 'avg_terms_per_doc': self.total_terms / self.total_documents if self.total_documents > 0 else 0,
163
+ 'vocabulary_richness': len(self.vocabulary) / self.total_terms if self.total_terms > 0 else 0,
164
+ 'technical_ratio': len(self.technical_terms) / len(self.vocabulary) if self.vocabulary else 0
165
+ }
166
+
167
+ def get_top_terms(self, n: int = 100, technical_only: bool = False) -> List[Tuple[str, int]]:
168
+ """
169
+ Get most frequent terms in corpus.
170
+
171
+ Args:
172
+ n: Number of top terms to return
173
+ technical_only: Only return technical terms
174
+
175
+ Returns:
176
+ List of (term, frequency) tuples
177
+ """
178
+ if technical_only:
179
+ term_freq = {
180
+ term: freq for term, freq in self.term_frequencies.items()
181
+ if term in self.technical_terms
182
+ }
183
+ else:
184
+ term_freq = self.term_frequencies
185
+
186
+ return sorted(term_freq.items(), key=lambda x: x[1], reverse=True)[:n]
187
+
188
+ def detect_domain(self) -> str:
189
+ """
190
+ Detect document domain from vocabulary patterns.
191
+
192
+ Returns:
193
+ Detected domain name
194
+ """
195
+ # Domain detection heuristics
196
+ domain_indicators = {
197
+ 'embedded_systems': ['microcontroller', 'rtos', 'embedded', 'firmware', 'mcu'],
198
+ 'processor_architecture': ['risc-v', 'riscv', 'instruction', 'register', 'isa'],
199
+ 'regulatory': ['fda', 'validation', 'compliance', 'regulation', 'guidance'],
200
+ 'ai_ml': ['model', 'training', 'neural', 'algorithm', 'machine', 'learning'],  # single tokens only; multi-word phrases never match the vocabulary
201
+ 'software_engineering': ['software', 'development', 'testing', 'debugging', 'code']
202
+ }
203
+
204
+ domain_scores = {}
205
+
206
+ for domain, indicators in domain_indicators.items():
207
+ score = sum(
208
+ self.get_document_frequency(indicator)
209
+ for indicator in indicators
210
+ if self.contains(indicator)
211
+ )
212
+ domain_scores[domain] = score
213
+
214
+ # Return domain with highest score
215
+ if domain_scores:
216
+ return max(domain_scores, key=domain_scores.get)
217
+ return 'general'
218
+
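Continuing the sketch above, the corpus votes land on the processor domain:

print(vocab.detect_domain())
# -> 'processor_architecture'  ('risc-v' and 'register' both hit that indicator list)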
219
+ def save_to_file(self, path: Path) -> None:
220
+ """Save vocabulary index to JSON file."""
221
+ data = {
222
+ 'vocabulary': list(self.vocabulary),
223
+ 'term_frequencies': dict(self.term_frequencies),
224
+ 'technical_terms': list(self.technical_terms),
225
+ 'document_frequencies': dict(self.document_frequencies),
226
+ 'total_documents': self.total_documents,
227
+ 'total_terms': self.total_terms
228
+ }
229
+
230
+ with open(path, 'w') as f:
231
+ json.dump(data, f, indent=2)
232
+
233
+ def load_from_file(self, path: Path) -> None:
234
+ """Load vocabulary index from JSON file."""
235
+ with open(path, 'r') as f:
236
+ data = json.load(f)
237
+
238
+ self.vocabulary = set(data['vocabulary'])
239
+ self.term_frequencies = defaultdict(int, data['term_frequencies'])
240
+ self.technical_terms = set(data['technical_terms'])
241
+ self.document_frequencies = defaultdict(int, data['document_frequencies'])
242
+ self.total_documents = data['total_documents']
243
+ self.total_terms = data['total_terms']
244
+
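A round-trip persistence check, continuing the same sketch (the file name is illustrative):

from pathlib import Path

vocab.save_to_file(Path('vocab_index.json'))
restored = VocabularyIndex()
restored.load_from_file(Path('vocab_index.json'))
assert restored.get_frequency('processor') == vocab.get_frequency('processor')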
245
+ def merge_with(self, other: 'VocabularyIndex') -> None:
246
+ """Merge another vocabulary index into this one."""
247
+ # Merge vocabularies
248
+ self.vocabulary.update(other.vocabulary)
249
+ self.technical_terms.update(other.technical_terms)
250
+
251
+ # Merge frequencies
252
+ for term, freq in other.term_frequencies.items():
253
+ self.term_frequencies[term] += freq
254
+
255
+ for term, doc_freq in other.document_frequencies.items():
256
+ self.document_frequencies[term] += doc_freq
257
+
258
+ # Update totals
259
+ self.total_documents += other.total_documents
260
+ self.total_terms += other.total_terms
shared_utils/vector_stores/__init__.py ADDED
File without changes
shared_utils/vector_stores/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (158 Bytes).
 
shared_utils/vector_stores/document_processing/__init__.py ADDED
File without changes