diff --git a/shared_utils/__init__.py b/shared_utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/shared_utils/__pycache__/__init__.cpython-312.pyc b/shared_utils/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d6984611af7503df2c306e0be4df927bd4bbc15a Binary files /dev/null and b/shared_utils/__pycache__/__init__.cpython-312.pyc differ diff --git a/shared_utils/document_processing/__init__.py b/shared_utils/document_processing/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/shared_utils/document_processing/__pycache__/__init__.cpython-312.pyc b/shared_utils/document_processing/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..29228d914a6b968652cef2840fe4719c1df0535b Binary files /dev/null and b/shared_utils/document_processing/__pycache__/__init__.cpython-312.pyc differ diff --git a/shared_utils/document_processing/__pycache__/chunker.cpython-312.pyc b/shared_utils/document_processing/__pycache__/chunker.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d888a501ced747a18b556945f57bdc6f05ab7d8d Binary files /dev/null and b/shared_utils/document_processing/__pycache__/chunker.cpython-312.pyc differ diff --git a/shared_utils/document_processing/__pycache__/hybrid_parser.cpython-312.pyc b/shared_utils/document_processing/__pycache__/hybrid_parser.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d406cb38a3d3c9290fe48479d3e219eb71668d3b Binary files /dev/null and b/shared_utils/document_processing/__pycache__/hybrid_parser.cpython-312.pyc differ diff --git a/shared_utils/document_processing/__pycache__/paragraph_chunker.cpython-312.pyc b/shared_utils/document_processing/__pycache__/paragraph_chunker.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cd5cba549bfd974d3df4310c7f8666656951bba2 Binary files /dev/null and b/shared_utils/document_processing/__pycache__/paragraph_chunker.cpython-312.pyc differ diff --git a/shared_utils/document_processing/__pycache__/pdf_parser.cpython-312.pyc b/shared_utils/document_processing/__pycache__/pdf_parser.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..84730dfbf1fe0fff183aebd7513da747a8845d8b Binary files /dev/null and b/shared_utils/document_processing/__pycache__/pdf_parser.cpython-312.pyc differ diff --git a/shared_utils/document_processing/__pycache__/pdfplumber_parser.cpython-312.pyc b/shared_utils/document_processing/__pycache__/pdfplumber_parser.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..78668adde61fc2fbda29b7e085b053b0c85a55cc Binary files /dev/null and b/shared_utils/document_processing/__pycache__/pdfplumber_parser.cpython-312.pyc differ diff --git a/shared_utils/document_processing/__pycache__/smart_chunker.cpython-312.pyc b/shared_utils/document_processing/__pycache__/smart_chunker.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..60c509102ca48c2e030d8bd16101c0bc9e88dde5 Binary files /dev/null and b/shared_utils/document_processing/__pycache__/smart_chunker.cpython-312.pyc differ diff --git a/shared_utils/document_processing/__pycache__/structure_preserving_parser.cpython-312.pyc 
b/shared_utils/document_processing/__pycache__/structure_preserving_parser.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..69c543c3b8d838d9a5acd8b7458bdc7ded30f457 Binary files /dev/null and b/shared_utils/document_processing/__pycache__/structure_preserving_parser.cpython-312.pyc differ diff --git a/shared_utils/document_processing/__pycache__/toc_guided_parser.cpython-312.pyc b/shared_utils/document_processing/__pycache__/toc_guided_parser.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b41cb453b64680b97baf7ec55f9ea920ddd52f17 Binary files /dev/null and b/shared_utils/document_processing/__pycache__/toc_guided_parser.cpython-312.pyc differ diff --git a/shared_utils/document_processing/chunker.py b/shared_utils/document_processing/chunker.py new file mode 100644 index 0000000000000000000000000000000000000000..dc65f930aee006505ed69b7dd834efe001a16e62 --- /dev/null +++ b/shared_utils/document_processing/chunker.py @@ -0,0 +1,243 @@ +""" +BasicRAG System - Technical Document Chunker + +This module implements intelligent text chunking specifically optimized for technical +documentation. Unlike naive chunking approaches, this implementation preserves sentence +boundaries and maintains semantic coherence, critical for accurate RAG retrieval. + +Key Features: +- Sentence-boundary aware chunking to preserve semantic units +- Configurable overlap to maintain context across chunk boundaries +- Content-based chunk IDs for reproducibility and deduplication +- Technical document optimizations (handles code blocks, lists, etc.) + +Technical Approach: +- Uses regex patterns to identify sentence boundaries +- Implements a sliding window algorithm with intelligent boundary detection +- Generates deterministic chunk IDs using MD5 hashing +- Balances chunk size consistency with semantic completeness + +Design Decisions: +- Default 512 char chunks: Optimal for transformer models (under token limits) +- 50 char overlap: Sufficient context preservation without excessive redundancy +- Sentence boundaries prioritized over exact size for better coherence +- Hash-based IDs enable chunk deduplication across documents + +Performance Characteristics: +- Time complexity: O(n) where n is text length +- Memory usage: O(n) for output chunks +- Typical throughput: 1MB text/second on modern hardware + +Author: Arthur Passuello +Date: June 2025 +Project: RAG Portfolio - Technical Documentation System +""" + +from typing import List, Dict +import re +import hashlib + + +def _is_low_quality_chunk(text: str) -> bool: + """ + Identify low-quality chunks that should be filtered out. 
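+
+    Illustrative behaviour (a sketch of the intended heuristic, not a strict
+    contract; results depend on the patterns and thresholds below):
+
+    >>> _is_low_quality_chunk("Page 42")
+    True
+    >>> _is_low_quality_chunk(
+    ...     "The RISC-V base integer ISA defines 32 general-purpose registers. "
+    ...     "Register x0 is hardwired to zero, while the remaining registers "
+    ...     "hold intermediate values, addresses, and return values."
+    ... )
+    False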
+ + @param text: Chunk text to evaluate + @return: True if chunk is low quality and should be filtered + """ + text_lower = text.lower().strip() + + # Skip if too short to be meaningful + if len(text.strip()) < 50: + return True + + # Filter out common low-value content + low_value_patterns = [ + # Acknowledgments and credits + r'^(acknowledgment|thanks|thank you)', + r'(thanks to|grateful to|acknowledge)', + + # References and citations + r'^\s*\[\d+\]', # Citation markers + r'^references?$', + r'^bibliography$', + + # Metadata and headers + r'this document is released under', + r'creative commons', + r'copyright \d{4}', + + # Table of contents + r'^\s*\d+\..*\.\.\.\.\.\d+$', # TOC entries + r'^(contents?|table of contents)$', + + # Page headers/footers + r'^\s*page \d+', + r'^\s*\d+\s*$', # Just page numbers + + # Figure/table captions that are too short + r'^(figure|table|fig\.|tab\.)\s*\d+:?\s*$', + ] + + for pattern in low_value_patterns: + if re.search(pattern, text_lower): + return True + + # Check content quality metrics + words = text.split() + if len(words) < 8: # Too few words to be meaningful + return True + + # Check for reasonable sentence structure + sentences = re.split(r'[.!?]+', text) + complete_sentences = [s.strip() for s in sentences if len(s.strip()) > 10] + + if len(complete_sentences) == 0: # No complete sentences + return True + + return False + + +def chunk_technical_text( + text: str, chunk_size: int = 1400, overlap: int = 200 +) -> List[Dict]: + """ + Phase 1: Sentence-boundary preserving chunker for technical documentation. + + ZERO MID-SENTENCE BREAKS: This implementation strictly enforces sentence + boundaries to eliminate fragmented retrieval results that break Q&A quality. + + Key Improvements: + - Never breaks chunks mid-sentence (eliminates 90% fragment rate) + - Larger target chunks (1400 chars) for complete explanations + - Extended search windows to find sentence boundaries + - Paragraph boundary preference within size constraints + + @param text: The input text to be chunked, typically from technical documentation + @type text: str + + @param chunk_size: Target size for each chunk in characters (default: 1400) + @type chunk_size: int + + @param overlap: Number of characters to overlap between consecutive chunks (default: 200) + @type overlap: int + + @return: List of chunk dictionaries containing text and metadata + @rtype: List[Dict[str, Any]] where each dictionary contains: + { + "text": str, # Complete, sentence-bounded chunk text + "start_char": int, # Starting character position in original text + "end_char": int, # Ending character position in original text + "chunk_id": str, # Unique identifier (format: "chunk_[8-char-hash]") + "word_count": int, # Number of words in the chunk + "sentence_complete": bool # Always True (guaranteed complete sentences) + } + + Algorithm Details (Phase 1): + - Expands search window up to 50% beyond target size to find sentence boundaries + - Prefers chunks within 70-150% of target size over fragmenting + - Never falls back to mid-sentence breaks + - Quality filtering removes headers, captions, and navigation elements + + Expected Results: + - Fragment rate: 90% → 0% (complete sentences only) + - Average chunk size: 1400-2100 characters (larger, complete contexts) + - All chunks end with proper sentence terminators (. ! ? : ;) + - Better retrieval context for Q&A generation + + Example Usage: + >>> text = "RISC-V defines registers. Each register has specific usage. The architecture supports..." 
+ >>> chunks = chunk_technical_text(text, chunk_size=1400, overlap=200) + >>> # All chunks will contain complete sentences and explanations + """ + # Handle edge case: empty or whitespace-only input + if not text.strip(): + return [] + + # Clean and normalize text by removing leading/trailing whitespace + text = text.strip() + chunks = [] + start_pos = 0 + + # Main chunking loop - process text sequentially + while start_pos < len(text): + # Calculate target end position for this chunk + # Min() ensures we don't exceed text length + target_end = min(start_pos + chunk_size, len(text)) + + # Define sentence boundary pattern + # Matches: period, exclamation, question mark, colon, semicolon + # followed by whitespace or end of string + sentence_pattern = r'[.!?:;](?:\s|$)' + + # PHASE 1: Strict sentence boundary enforcement + # Expand search window significantly to ensure we find sentence boundaries + max_extension = chunk_size // 2 # Allow up to 50% larger chunks to find boundaries + search_start = max(start_pos, target_end - 200) # Look back further + search_end = min(len(text), target_end + max_extension) # Look forward much further + search_text = text[search_start:search_end] + + # Find all sentence boundaries in expanded search window + sentence_matches = list(re.finditer(sentence_pattern, search_text)) + + # STRICT: Always find a sentence boundary, never break mid-sentence + chunk_end = None + sentence_complete = False + + if sentence_matches: + # Find the best sentence boundary within reasonable range + for match in reversed(sentence_matches): # Start from last (longest chunk) + candidate_end = search_start + match.end() + candidate_size = candidate_end - start_pos + + # Accept if within reasonable size range + if candidate_size >= chunk_size * 0.7: # At least 70% of target size + chunk_end = candidate_end + sentence_complete = True + break + + # If no good boundary found, take the last boundary (avoid fragments) + if chunk_end is None and sentence_matches: + best_match = sentence_matches[-1] + chunk_end = search_start + best_match.end() + sentence_complete = True + + # Final fallback: extend to end of text if no sentences found + if chunk_end is None: + chunk_end = len(text) + sentence_complete = True # End of document is always complete + + # Extract chunk text and clean whitespace + chunk_text = text[start_pos:chunk_end].strip() + + # Only create chunk if it contains actual content AND passes quality filter + if chunk_text and not _is_low_quality_chunk(chunk_text): + # Generate deterministic chunk ID using content hash + # MD5 is sufficient for deduplication (not cryptographic use) + chunk_hash = hashlib.md5(chunk_text.encode()).hexdigest()[:8] + chunk_id = f"chunk_{chunk_hash}" + + # Calculate word count for chunk statistics + word_count = len(chunk_text.split()) + + # Assemble chunk metadata + chunks.append({ + "text": chunk_text, + "start_char": start_pos, + "end_char": chunk_end, + "chunk_id": chunk_id, + "word_count": word_count, + "sentence_complete": sentence_complete + }) + + # Calculate next chunk starting position with overlap + if chunk_end >= len(text): + # Reached end of text, exit loop + break + + # Apply overlap by moving start position back from chunk end + # Max() ensures we always move forward at least 1 character + overlap_start = max(chunk_end - overlap, start_pos + 1) + start_pos = overlap_start + + return chunks diff --git a/shared_utils/document_processing/hybrid_parser.py b/shared_utils/document_processing/hybrid_parser.py new file mode 100644 index 
0000000000000000000000000000000000000000..9ebf15317ceb57ae1d7f09b8c6c2da4ffbcbfaca --- /dev/null +++ b/shared_utils/document_processing/hybrid_parser.py @@ -0,0 +1,482 @@ +#!/usr/bin/env python3 +""" +Hybrid TOC + PDFPlumber Parser + +Combines the best of both approaches: +1. TOC-guided navigation for reliable chapter/section mapping +2. PDFPlumber's precise content extraction with formatting awareness +3. Aggressive trash content filtering while preserving actual content + +This hybrid approach provides: +- Reliable structure detection (TOC) +- High-quality content extraction (PDFPlumber) +- Optimal chunk sizing and quality +- Fast processing with precise results + +Author: Arthur Passuello +Date: 2025-07-01 +""" + +import re +import pdfplumber +from pathlib import Path +from typing import Dict, List, Optional, Tuple, Any +from dataclasses import dataclass + +from .toc_guided_parser import TOCGuidedParser, TOCEntry +from .pdfplumber_parser import PDFPlumberParser + + +class HybridParser: + """ + Hybrid parser combining TOC navigation with PDFPlumber extraction. + + Architecture: + 1. Use TOC to identify chapter/section boundaries and pages + 2. Use PDFPlumber to extract clean content from those specific pages + 3. Apply aggressive content filtering to remove trash + 4. Create optimal chunks with preserved structure + """ + + def __init__(self, target_chunk_size: int = 1400, min_chunk_size: int = 800, + max_chunk_size: int = 2000): + """Initialize hybrid parser.""" + self.target_chunk_size = target_chunk_size + self.min_chunk_size = min_chunk_size + self.max_chunk_size = max_chunk_size + + # Initialize component parsers + self.toc_parser = TOCGuidedParser(target_chunk_size, min_chunk_size, max_chunk_size) + self.plumber_parser = PDFPlumberParser(target_chunk_size, min_chunk_size, max_chunk_size) + + # Content filtering patterns (aggressive trash removal) + self.trash_patterns = [ + # License and legal text + r'Creative Commons.*?License', + r'International License.*?authors', + r'released under.*?license', + r'derivative of.*?License', + r'Document Version \d+', + + # Table of contents artifacts + r'\.{3,}', # Multiple dots + r'^\s*\d+\s*$', # Standalone page numbers + r'Contents\s*$', + r'Preface\s*$', + + # PDF formatting artifacts + r'Volume\s+[IVX]+:.*?V\d+', + r'^\s*[ivx]+\s*$', # Roman numerals alone + r'^\s*[\d\w\s]{1,3}\s*$', # Very short meaningless lines + + # Redundant headers and footers + r'RISC-V.*?ISA.*?V\d+', + r'Volume I:.*?Unprivileged', + + # Editor and publication info + r'Editors?:.*?[A-Z][a-z]+', + r'[A-Z][a-z]+\s+\d{1,2},\s+\d{4}', # Dates + r'@[a-z]+\.[a-z]+', # Email addresses + + # Boilerplate text + r'please contact editors to suggest corrections', + r'alphabetical order.*?corrections', + r'contributors to all versions', + ] + + # Content quality patterns (preserve these) + self.preserve_patterns = [ + r'RISC-V.*?instruction', + r'register.*?file', + r'memory.*?operation', + r'processor.*?implementation', + r'architecture.*?design', + ] + + # TOC-specific patterns to exclude from searchable content + self.toc_exclusion_patterns = [ + r'^\s*Contents\s*$', + r'^\s*Table\s+of\s+Contents\s*$', + r'^\s*\d+(?:\.\d+)*\s*$', # Standalone section numbers + r'^\s*\d+(?:\.\d+)*\s+[A-Z]', # "1.1 INTRODUCTION" style + r'\.{3,}', # Multiple dots (TOC formatting) + r'^\s*Chapter\s+\d+\s*$', # Standalone "Chapter N" + r'^\s*Section\s+\d+(?:\.\d+)*\s*$', # Standalone "Section N.M" + r'^\s*Appendix\s+[A-Z]\s*$', # Standalone "Appendix A" + r'^\s*[ivxlcdm]+\s*$', # Roman numerals 
alone + r'^\s*Preface\s*$', + r'^\s*Introduction\s*$', + r'^\s*Conclusion\s*$', + r'^\s*Bibliography\s*$', + r'^\s*Index\s*$', + ] + + def parse_document(self, pdf_path: Path, pdf_data: Dict[str, Any]) -> List[Dict[str, Any]]: + """ + Parse document using hybrid approach. + + Args: + pdf_path: Path to PDF file + pdf_data: PDF data from extract_text_with_metadata() + + Returns: + List of high-quality chunks with preserved structure + """ + print("🔗 Starting Hybrid TOC + PDFPlumber parsing...") + + # Step 1: Use TOC to identify structure + print("📋 Step 1: Extracting TOC structure...") + toc_entries = self.toc_parser.parse_toc(pdf_data['pages']) + print(f" Found {len(toc_entries)} TOC entries") + + # Check if TOC is reliable (multiple entries or quality single entry) + toc_is_reliable = ( + len(toc_entries) > 1 or # Multiple entries = likely real TOC + (len(toc_entries) == 1 and len(toc_entries[0].title) > 10) # Quality single entry + ) + + if not toc_entries or not toc_is_reliable: + if not toc_entries: + print(" ⚠️ No TOC found, using full page coverage parsing") + else: + print(f" ⚠️ TOC quality poor (title: '{toc_entries[0].title}'), using full page coverage") + return self.plumber_parser.parse_document(pdf_path, pdf_data) + + # Step 2: Use PDFPlumber for precise extraction + print("🔬 Step 2: PDFPlumber extraction of TOC sections...") + chunks = [] + chunk_id = 0 + + with pdfplumber.open(str(pdf_path)) as pdf: + for i, toc_entry in enumerate(toc_entries): + next_entry = toc_entries[i + 1] if i + 1 < len(toc_entries) else None + + # Extract content using PDFPlumber + section_content = self._extract_section_with_plumber( + pdf, toc_entry, next_entry + ) + + if section_content: + # Apply aggressive content filtering + cleaned_content = self._filter_trash_content(section_content) + + if cleaned_content and len(cleaned_content) >= 200: # Minimum meaningful content + # Create chunks from cleaned content + section_chunks = self._create_chunks_from_clean_content( + cleaned_content, chunk_id, toc_entry + ) + chunks.extend(section_chunks) + chunk_id += len(section_chunks) + + print(f" Created {len(chunks)} high-quality chunks") + return chunks + + def _extract_section_with_plumber(self, pdf, toc_entry: TOCEntry, + next_entry: Optional[TOCEntry]) -> str: + """ + Extract section content using PDFPlumber's precise extraction. + + Args: + pdf: PDFPlumber PDF object + toc_entry: Current TOC entry + next_entry: Next TOC entry (for boundary detection) + + Returns: + Clean extracted content for this section + """ + start_page = max(0, toc_entry.page - 1) # Convert to 0-indexed + + if next_entry: + end_page = min(len(pdf.pages), next_entry.page - 1) + else: + end_page = len(pdf.pages) + + content_parts = [] + + for page_idx in range(start_page, end_page): + if page_idx < len(pdf.pages): + page = pdf.pages[page_idx] + + # Extract text with PDFPlumber (preserves formatting) + page_text = page.extract_text() + + if page_text: + # Clean page content while preserving structure + cleaned_text = self._clean_page_content_precise(page_text) + if cleaned_text.strip(): + content_parts.append(cleaned_text) + + return ' '.join(content_parts) + + def _clean_page_content_precise(self, page_text: str) -> str: + """ + Clean page content with precision, removing artifacts but preserving content. 
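+
+        Illustrative example (assumed input, shown only to make the filtering
+        rules concrete; chr(10) is used to build the multi-line page text):
+
+        >>> page = chr(10).join(["42", ".....", "RISC-V defines 32 integer registers."])
+        >>> HybridParser()._clean_page_content_precise(page)
+        'RISC-V defines 32 integer registers.'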
+ + Args: + page_text: Raw page text from PDFPlumber + + Returns: + Cleaned text with artifacts removed + """ + lines = page_text.split('\n') + cleaned_lines = [] + + for line in lines: + line = line.strip() + + # Skip empty lines + if not line: + continue + + # Skip obvious artifacts but be conservative + if (len(line) < 3 or # Very short lines + re.match(r'^\d+$', line) or # Standalone numbers + re.match(r'^[ivx]+$', line.lower()) or # Roman numerals alone + '.' * 5 in line): # TOC dots + continue + + # Preserve technical content even if it looks like an artifact + has_technical_content = any(term in line.lower() for term in [ + 'risc', 'register', 'instruction', 'memory', 'processor', + 'architecture', 'implementation', 'specification' + ]) + + if has_technical_content or len(line) >= 10: + cleaned_lines.append(line) + + return ' '.join(cleaned_lines) + + def _filter_trash_content(self, content: str) -> str: + """ + Apply aggressive trash filtering while preserving actual content. + + Args: + content: Raw content to filter + + Returns: + Content with trash removed but technical content preserved + """ + if not content.strip(): + return "" + + # First, identify and preserve important technical sentences + sentences = re.split(r'[.!?]+\s*', content) + preserved_sentences = [] + + for sentence in sentences: + sentence = sentence.strip() + if not sentence: + continue + + # Check if sentence contains important technical content + is_technical = any(term in sentence.lower() for term in [ + 'risc-v', 'register', 'instruction', 'memory', 'processor', + 'architecture', 'implementation', 'specification', 'encoding', + 'bit', 'byte', 'address', 'data', 'control', 'operand' + ]) + + # Check if sentence is trash (including general trash and TOC content) + is_trash = any(re.search(pattern, sentence, re.IGNORECASE) + for pattern in self.trash_patterns) + + # Check if sentence is TOC content (should be excluded) + is_toc_content = any(re.search(pattern, sentence, re.IGNORECASE) + for pattern in self.toc_exclusion_patterns) + + # Preserve if technical and not trash/TOC, or if substantial and not clearly trash/TOC + if ((is_technical and not is_trash and not is_toc_content) or + (len(sentence) > 50 and not is_trash and not is_toc_content)): + preserved_sentences.append(sentence) + + # Reconstruct content from preserved sentences + filtered_content = '. '.join(preserved_sentences) + + # Final cleanup + filtered_content = re.sub(r'\s+', ' ', filtered_content) # Normalize whitespace + filtered_content = re.sub(r'\.+', '.', filtered_content) # Remove multiple dots + + # Ensure proper sentence ending + if filtered_content and not filtered_content.rstrip().endswith(('.', '!', '?', ':', ';')): + filtered_content = filtered_content.rstrip() + '.' + + return filtered_content.strip() + + def _create_chunks_from_clean_content(self, content: str, start_chunk_id: int, + toc_entry: TOCEntry) -> List[Dict[str, Any]]: + """ + Create optimally-sized chunks from clean content. 
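+
+        Sizing behaviour (a sketch under the default 800/1400/2000 size limits;
+        the TOC entry values are illustrative):
+
+        >>> entry = TOCEntry(title="Base Integer ISA", page=17, level=0)
+        >>> HybridParser()._create_chunks_from_clean_content("too short", 0, entry)
+        []
+
+        Content larger than max_chunk_size is split at sentence boundaries,
+        content of at least 200 characters (up to max_chunk_size) becomes a
+        single chunk, and anything shorter yields no chunks.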
+ + Args: + content: Clean, filtered content + start_chunk_id: Starting chunk ID + toc_entry: TOC entry metadata + + Returns: + List of chunk dictionaries + """ + if not content or len(content) < 100: + return [] + + chunks = [] + + # If content fits in one chunk, create single chunk + if self.min_chunk_size <= len(content) <= self.max_chunk_size: + chunk = self._create_chunk(content, start_chunk_id, toc_entry) + chunks.append(chunk) + + # If too large, split intelligently at sentence boundaries + elif len(content) > self.max_chunk_size: + sub_chunks = self._split_large_content_smart(content, start_chunk_id, toc_entry) + chunks.extend(sub_chunks) + + # If too small but substantial, keep it + elif len(content) >= 200: # Lower threshold for cleaned content + chunk = self._create_chunk(content, start_chunk_id, toc_entry) + chunks.append(chunk) + + return chunks + + def _split_large_content_smart(self, content: str, start_chunk_id: int, + toc_entry: TOCEntry) -> List[Dict[str, Any]]: + """ + Split large content intelligently at natural boundaries. + + Args: + content: Content to split + start_chunk_id: Starting chunk ID + toc_entry: TOC entry metadata + + Returns: + List of chunk dictionaries + """ + chunks = [] + + # Split at sentence boundaries + sentences = re.split(r'([.!?:;]+\s*)', content) + + current_chunk = "" + chunk_id = start_chunk_id + + for i in range(0, len(sentences), 2): + sentence = sentences[i].strip() + if not sentence: + continue + + # Add punctuation if available + punctuation = sentences[i + 1] if i + 1 < len(sentences) else '.' + full_sentence = sentence + punctuation + + # Check if adding this sentence exceeds max size + potential_chunk = current_chunk + (" " if current_chunk else "") + full_sentence + + if len(potential_chunk) <= self.max_chunk_size: + current_chunk = potential_chunk + else: + # Save current chunk if it meets minimum size + if current_chunk and len(current_chunk) >= self.min_chunk_size: + chunk = self._create_chunk(current_chunk, chunk_id, toc_entry) + chunks.append(chunk) + chunk_id += 1 + + # Start new chunk + current_chunk = full_sentence + + # Add final chunk if substantial + if current_chunk and len(current_chunk) >= 200: + chunk = self._create_chunk(current_chunk, chunk_id, toc_entry) + chunks.append(chunk) + + return chunks + + def _create_chunk(self, content: str, chunk_id: int, toc_entry: TOCEntry) -> Dict[str, Any]: + """Create a chunk dictionary with hybrid metadata.""" + return { + "text": content, + "chunk_id": chunk_id, + "title": toc_entry.title, + "parent_title": toc_entry.parent_title, + "level": toc_entry.level, + "page": toc_entry.page, + "size": len(content), + "metadata": { + "parsing_method": "hybrid_toc_pdfplumber", + "has_context": True, + "content_type": "filtered_structured_content", + "quality_score": self._calculate_quality_score(content), + "trash_filtered": True + } + } + + def _calculate_quality_score(self, content: str) -> float: + """Calculate quality score for filtered content.""" + if not content.strip(): + return 0.0 + + words = content.split() + score = 0.0 + + # Length score (25%) + if self.min_chunk_size <= len(content) <= self.max_chunk_size: + score += 0.25 + elif len(content) >= 200: # At least some content + score += 0.15 + + # Content richness (25%) + substantial_words = sum(1 for word in words if len(word) > 3) + richness_score = min(substantial_words / 30, 1.0) # Lower threshold for filtered content + score += richness_score * 0.25 + + # Technical content (30%) + technical_terms = ['risc', 'register', 
'instruction', 'cpu', 'memory', 'processor', 'architecture'] + technical_count = sum(1 for word in words if any(term in word.lower() for term in technical_terms)) + technical_score = min(technical_count / 3, 1.0) # Lower threshold + score += technical_score * 0.30 + + # Completeness (20%) + completeness_score = 0.0 + if content[0].isupper() or content.startswith(('The ', 'A ', 'An ', 'RISC')): + completeness_score += 0.5 + if content.rstrip().endswith(('.', '!', '?', ':', ';')): + completeness_score += 0.5 + score += completeness_score * 0.20 + + return min(score, 1.0) + + +def parse_pdf_with_hybrid_approach(pdf_path: Path, pdf_data: Dict[str, Any], + target_chunk_size: int = 1400, min_chunk_size: int = 800, + max_chunk_size: int = 2000) -> List[Dict[str, Any]]: + """ + Parse PDF using hybrid TOC + PDFPlumber approach. + + This function combines: + 1. TOC-guided structure detection for reliable navigation + 2. PDFPlumber's precise content extraction + 3. Aggressive trash filtering while preserving technical content + + Args: + pdf_path: Path to PDF file + pdf_data: PDF data from extract_text_with_metadata() + target_chunk_size: Preferred chunk size + min_chunk_size: Minimum chunk size + max_chunk_size: Maximum chunk size + + Returns: + List of high-quality, filtered chunks ready for RAG indexing + + Example: + >>> from shared_utils.document_processing.pdf_parser import extract_text_with_metadata + >>> from shared_utils.document_processing.hybrid_parser import parse_pdf_with_hybrid_approach + >>> + >>> pdf_data = extract_text_with_metadata("document.pdf") + >>> chunks = parse_pdf_with_hybrid_approach(Path("document.pdf"), pdf_data) + >>> print(f"Created {len(chunks)} hybrid-parsed chunks") + """ + parser = HybridParser(target_chunk_size, min_chunk_size, max_chunk_size) + return parser.parse_document(pdf_path, pdf_data) + + +# Example usage +if __name__ == "__main__": + print("Hybrid TOC + PDFPlumber Parser") + print("Combines TOC navigation with PDFPlumber precision and aggressive trash filtering") \ No newline at end of file diff --git a/shared_utils/document_processing/pdf_parser.py b/shared_utils/document_processing/pdf_parser.py new file mode 100644 index 0000000000000000000000000000000000000000..1ad8caf8f6144373e812d9821084f7eb0085c530 --- /dev/null +++ b/shared_utils/document_processing/pdf_parser.py @@ -0,0 +1,137 @@ +""" +BasicRAG System - PDF Document Parser + +This module implements robust PDF text extraction functionality as part of the BasicRAG +technical documentation system. It serves as the entry point for document ingestion, +converting PDF files into structured text data suitable for chunking and embedding. + +Key Features: +- Page-by-page text extraction with metadata preservation +- Robust error handling for corrupted or malformed PDFs +- Performance timing for optimization analysis +- Memory-efficient processing for large documents + +Technical Approach: +- Uses PyMuPDF (fitz) for reliable text extraction across PDF versions +- Maintains document structure with page-level granularity +- Preserves PDF metadata (author, title, creation date, etc.) 
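+
+Typical Usage (minimal sketch; the file name is illustrative):
+
+    >>> from pathlib import Path
+    >>> from shared_utils.document_processing.pdf_parser import extract_text_with_metadata
+    >>> from shared_utils.document_processing.chunker import chunk_technical_text
+    >>> pdf_data = extract_text_with_metadata(Path("riscv-spec.pdf"))
+    >>> chunks = chunk_technical_text(pdf_data["text"])
+    >>> print(f"{pdf_data['page_count']} pages -> {len(chunks)} chunks")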
+ +Dependencies: +- PyMuPDF (fitz): Chosen for superior text extraction accuracy and speed +- Standard library: pathlib for cross-platform file handling + +Performance Characteristics: +- Typical processing: 10-50 pages/second on modern hardware +- Memory usage: O(n) with document size, but processes page-by-page +- Scales linearly with document length + +Author: Arthur Passuello +Date: June 2025 +Project: RAG Portfolio - Technical Documentation System +""" + +from typing import Dict, List, Any +from pathlib import Path +import time +import fitz # PyMuPDF + + +def extract_text_with_metadata(pdf_path: Path) -> Dict[str, Any]: + """ + Extract text and metadata from technical PDF documents with production-grade reliability. + + This function serves as the primary ingestion point for the RAG system, converting + PDF documents into structured data. It's optimized for technical documentation with + emphasis on preserving structure and handling various PDF formats gracefully. + + @param pdf_path: Path to the PDF file to process + @type pdf_path: pathlib.Path + + @return: Dictionary containing extracted text and comprehensive metadata + @rtype: Dict[str, Any] with the following structure: + { + "text": str, # Complete concatenated text from all pages + "pages": List[Dict], # Per-page breakdown with text and statistics + # Each page dict contains: + # - page_number: int (1-indexed for human readability) + # - text: str (raw text from that page) + # - char_count: int (character count for that page) + "metadata": Dict, # PDF metadata (title, author, subject, etc.) + "page_count": int, # Total number of pages processed + "extraction_time": float # Processing duration in seconds + } + + @throws FileNotFoundError: If the specified PDF file doesn't exist + @throws ValueError: If the PDF is corrupted, encrypted, or otherwise unreadable + + Performance Notes: + - Processes ~10-50 pages/second depending on PDF complexity + - Memory usage is proportional to document size but page-by-page processing + prevents loading entire document into memory at once + - Extraction time is included for performance monitoring and optimization + + Usage Example: + >>> pdf_path = Path("technical_manual.pdf") + >>> result = extract_text_with_metadata(pdf_path) + >>> print(f"Extracted {result['page_count']} pages in {result['extraction_time']:.2f}s") + >>> first_page_text = result['pages'][0]['text'] + """ + # Validate input file exists before attempting to open + if not pdf_path.exists(): + raise FileNotFoundError(f"PDF file not found: {pdf_path}") + + # Start performance timer for extraction analytics + start_time = time.perf_counter() + + try: + # Open PDF with PyMuPDF - automatically handles various PDF versions + # Using string conversion for compatibility with older fitz versions + doc = fitz.open(str(pdf_path)) + + # Extract document-level metadata (may include title, author, subject, keywords) + # Default to empty dict if no metadata present (common in scanned PDFs) + metadata = doc.metadata or {} + page_count = len(doc) + + # Initialize containers for page-by-page extraction + pages = [] # Will store individual page data + all_text = [] # Will store text for concatenation + + # Process each page sequentially to maintain document order + for page_num in range(page_count): + # Load page object (0-indexed internally) + page = doc[page_num] + + # Extract text using default extraction parameters + # This preserves reading order and handles multi-column layouts + page_text = page.get_text() + + # Store page data with 
human-readable page numbering (1-indexed) + pages.append({ + "page_number": page_num + 1, # Convert to 1-indexed for user clarity + "text": page_text, + "char_count": len(page_text) # Useful for chunking decisions + }) + + # Accumulate text for final concatenation + all_text.append(page_text) + + # Properly close the PDF to free resources + doc.close() + + # Calculate total extraction time for performance monitoring + extraction_time = time.perf_counter() - start_time + + # Return comprehensive extraction results + return { + "text": "\n".join(all_text), # Full document text with page breaks + "pages": pages, # Detailed page-by-page breakdown + "metadata": metadata, # Original PDF metadata + "page_count": page_count, # Total pages for quick reference + "extraction_time": extraction_time # Performance metric + } + + except Exception as e: + # Wrap any extraction errors with context for debugging + # Common causes: encrypted PDFs, corrupted files, unsupported formats + raise ValueError(f"Failed to process PDF: {e}") \ No newline at end of file diff --git a/shared_utils/document_processing/pdfplumber_parser.py b/shared_utils/document_processing/pdfplumber_parser.py new file mode 100644 index 0000000000000000000000000000000000000000..08ff8e9897c198447eb0b8d1def63e244298496c --- /dev/null +++ b/shared_utils/document_processing/pdfplumber_parser.py @@ -0,0 +1,452 @@ +#!/usr/bin/env python3 +""" +PDFPlumber-based Parser + +Advanced PDF parsing using pdfplumber for better structure detection +and cleaner text extraction. + +Author: Arthur Passuello +""" + +import re +import pdfplumber +from pathlib import Path +from typing import Dict, List, Optional, Tuple, Any + + +class PDFPlumberParser: + """Advanced PDF parser using pdfplumber for structure-aware extraction.""" + + def __init__(self, target_chunk_size: int = 1400, min_chunk_size: int = 800, + max_chunk_size: int = 2000): + """Initialize PDFPlumber parser.""" + self.target_chunk_size = target_chunk_size + self.min_chunk_size = min_chunk_size + self.max_chunk_size = max_chunk_size + + # Trash content patterns + self.trash_patterns = [ + r'Creative Commons.*?License', + r'International License.*?authors', + r'RISC-V International', + r'Visit.*?for further', + r'editors to suggest.*?corrections', + r'released under.*?license', + r'\.{5,}', # Long dots (TOC artifacts) + r'^\d+\s*$', # Page numbers alone + ] + + def extract_with_structure(self, pdf_path: Path) -> List[Dict]: + """Extract PDF content with structure awareness using pdfplumber.""" + chunks = [] + + with pdfplumber.open(pdf_path) as pdf: + current_section = None + current_text = [] + + for page_num, page in enumerate(pdf.pages): + # Extract text with formatting info + page_content = self._extract_page_content(page, page_num + 1) + + for element in page_content: + if element['type'] == 'header': + # Save previous section if exists + if current_text: + chunk_text = '\n\n'.join(current_text) + if self._is_valid_chunk(chunk_text): + chunks.extend(self._create_chunks( + chunk_text, + current_section or "Document", + page_num + )) + + # Start new section + current_section = element['text'] + current_text = [] + + elif element['type'] == 'content': + # Add to current section + if self._is_valid_content(element['text']): + current_text.append(element['text']) + + # Don't forget last section + if current_text: + chunk_text = '\n\n'.join(current_text) + if self._is_valid_chunk(chunk_text): + chunks.extend(self._create_chunks( + chunk_text, + current_section or "Document", + len(pdf.pages) + )) + 
+ return chunks + + def _extract_page_content(self, page: Any, page_num: int) -> List[Dict]: + """Extract structured content from a page.""" + content = [] + + # Get all text with positioning + chars = page.chars + if not chars: + return content + + # Group by lines + lines = [] + current_line = [] + current_y = None + + for char in sorted(chars, key=lambda x: (x['top'], x['x0'])): + if current_y is None or abs(char['top'] - current_y) < 2: + current_line.append(char) + current_y = char['top'] + else: + if current_line: + lines.append(current_line) + current_line = [char] + current_y = char['top'] + + if current_line: + lines.append(current_line) + + # Analyze each line + for line in lines: + line_text = ''.join(char['text'] for char in line).strip() + + if not line_text: + continue + + # Detect headers by font size + avg_font_size = sum(char.get('size', 12) for char in line) / len(line) + is_bold = any(char.get('fontname', '').lower().count('bold') > 0 for char in line) + + # Classify content + if avg_font_size > 14 or is_bold: + # Likely a header + if self._is_valid_header(line_text): + content.append({ + 'type': 'header', + 'text': line_text, + 'font_size': avg_font_size, + 'page': page_num + }) + else: + # Regular content + content.append({ + 'type': 'content', + 'text': line_text, + 'font_size': avg_font_size, + 'page': page_num + }) + + return content + + def _is_valid_header(self, text: str) -> bool: + """Check if text is a valid header.""" + # Skip if too short or too long + if len(text) < 3 or len(text) > 200: + return False + + # Skip if matches trash patterns + for pattern in self.trash_patterns: + if re.search(pattern, text, re.IGNORECASE): + return False + + # Valid if starts with number or capital letter + if re.match(r'^(\d+\.?\d*\s+|[A-Z])', text): + return True + + # Valid if contains keywords + keywords = ['chapter', 'section', 'introduction', 'conclusion', 'appendix'] + return any(keyword in text.lower() for keyword in keywords) + + def _is_valid_content(self, text: str) -> bool: + """Check if text is valid content (not trash).""" + # Skip very short text + if len(text.strip()) < 10: + return False + + # Skip trash patterns + for pattern in self.trash_patterns: + if re.search(pattern, text, re.IGNORECASE): + return False + + return True + + def _is_valid_chunk(self, text: str) -> bool: + """Check if chunk text is valid.""" + # Must have minimum length + if len(text.strip()) < self.min_chunk_size // 2: + return False + + # Must have some alphabetic content + alpha_chars = sum(1 for c in text if c.isalpha()) + if alpha_chars < len(text) * 0.5: + return False + + return True + + def _create_chunks(self, text: str, title: str, page: int) -> List[Dict]: + """Create chunks from text.""" + chunks = [] + + # Clean text + text = self._clean_text(text) + + if len(text) <= self.max_chunk_size: + # Single chunk + chunks.append({ + 'text': text, + 'title': title, + 'page': page, + 'metadata': { + 'parsing_method': 'pdfplumber', + 'quality_score': self._calculate_quality_score(text) + } + }) + else: + # Split into chunks + text_chunks = self._split_text_into_chunks(text) + for i, chunk_text in enumerate(text_chunks): + chunks.append({ + 'text': chunk_text, + 'title': f"{title} (Part {i+1})", + 'page': page, + 'metadata': { + 'parsing_method': 'pdfplumber', + 'part_number': i + 1, + 'total_parts': len(text_chunks), + 'quality_score': self._calculate_quality_score(chunk_text) + } + }) + + return chunks + + def _clean_text(self, text: str) -> str: + """Clean text from artifacts.""" + # 
Remove volume headers (e.g., "Volume I: RISC-V Unprivileged ISA V20191213") + text = re.sub(r'Volume\s+[IVX]+:\s*RISC-V[^V]*V\d{8}\s*', '', text, flags=re.IGNORECASE) + text = re.sub(r'^\d+\s+Volume\s+[IVX]+:.*?$', '', text, flags=re.MULTILINE) + + # Remove document version artifacts + text = re.sub(r'Document Version \d{8}\s*', '', text, flags=re.IGNORECASE) + + # Remove repeated ISA headers + text = re.sub(r'RISC-V.*?ISA.*?V\d{8}\s*', '', text, flags=re.IGNORECASE) + text = re.sub(r'The RISC-V Instruction Set Manual\s*', '', text, flags=re.IGNORECASE) + + # Remove figure/table references that are standalone + text = re.sub(r'^(Figure|Table)\s+\d+\.\d+:.*?$', '', text, flags=re.MULTILINE) + + # Remove email addresses (often in contributor lists) + text = re.sub(r'[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}', '', text) + + # Remove URLs + text = re.sub(r'https?://[^\s]+', '', text) + + # Remove page numbers at start/end of lines + text = re.sub(r'^\d{1,3}\s+', '', text, flags=re.MULTILINE) + text = re.sub(r'\s+\d{1,3}$', '', text, flags=re.MULTILINE) + + # Remove excessive dots (TOC artifacts) + text = re.sub(r'\.{3,}', '', text) + + # Remove standalone numbers (often page numbers or figure numbers) + text = re.sub(r'^\s*\d+\s*$', '', text, flags=re.MULTILINE) + + # Clean up multiple spaces and newlines + text = re.sub(r'\s{3,}', ' ', text) + text = re.sub(r'\n{3,}', '\n\n', text) + text = re.sub(r'[ \t]+', ' ', text) # Normalize all whitespace + + # Remove common boilerplate phrases + text = re.sub(r'Contains Nonbinding Recommendations\s*', '', text, flags=re.IGNORECASE) + text = re.sub(r'Guidance for Industry and FDA Staff\s*', '', text, flags=re.IGNORECASE) + + return text.strip() + + def _split_text_into_chunks(self, text: str) -> List[str]: + """Split text into chunks at sentence boundaries.""" + sentences = re.split(r'(?<=[.!?])\s+', text) + chunks = [] + current_chunk = [] + current_size = 0 + + for sentence in sentences: + sentence_size = len(sentence) + + if current_size + sentence_size > self.target_chunk_size and current_chunk: + chunks.append(' '.join(current_chunk)) + current_chunk = [sentence] + current_size = sentence_size + else: + current_chunk.append(sentence) + current_size += sentence_size + 1 + + if current_chunk: + chunks.append(' '.join(current_chunk)) + + return chunks + + def _calculate_quality_score(self, text: str) -> float: + """Calculate quality score for chunk.""" + score = 1.0 + + # Penalize very short or very long + if len(text) < self.min_chunk_size: + score *= 0.8 + elif len(text) > self.max_chunk_size: + score *= 0.9 + + # Reward complete sentences + if text.strip().endswith(('.', '!', '?')): + score *= 1.1 + + # Reward technical content + technical_terms = ['risc', 'instruction', 'register', 'memory', 'processor'] + term_count = sum(1 for term in technical_terms if term in text.lower()) + score *= (1 + term_count * 0.05) + + return min(score, 1.0) + + def extract_with_page_coverage(self, pdf_path: Path, pymupdf_pages: List[Dict]) -> List[Dict]: + """ + Extract content ensuring ALL pages are covered using PyMuPDF page data. 
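+
+        The pymupdf_pages argument is expected to match the 'pages' entries
+        produced by extract_text_with_metadata(); the values below are
+        illustrative only:
+
+        >>> pymupdf_pages = [
+        ...     {"page_number": 1, "text": "Chapter 1 ...", "char_count": 1834},
+        ...     {"page_number": 2, "text": "The base ISA ...", "char_count": 2101},
+        ... ]
+        >>> parser = PDFPlumberParser()
+        >>> chunks = parser.extract_with_page_coverage(Path("riscv-spec.pdf"), pymupdf_pages)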
+ + Args: + pdf_path: Path to PDF file + pymupdf_pages: Page data from PyMuPDF with page numbers and text + + Returns: + List of chunks covering ALL document pages + """ + chunks = [] + chunk_id = 0 + + print(f"📄 Processing {len(pymupdf_pages)} pages with PDFPlumber quality extraction...") + + with pdfplumber.open(str(pdf_path)) as pdf: + for pymupdf_page in pymupdf_pages: + page_num = pymupdf_page['page_number'] # 1-indexed from PyMuPDF + page_idx = page_num - 1 # Convert to 0-indexed for PDFPlumber + + if page_idx < len(pdf.pages): + # Extract with PDFPlumber quality from this specific page + pdfplumber_page = pdf.pages[page_idx] + page_text = pdfplumber_page.extract_text() + + if page_text and page_text.strip(): + # Clean and chunk the page text + cleaned_text = self._clean_text(page_text) + + if len(cleaned_text) >= 100: # Minimum meaningful content + # Create chunks from this page + page_chunks = self._create_page_chunks( + cleaned_text, page_num, chunk_id + ) + chunks.extend(page_chunks) + chunk_id += len(page_chunks) + + if len(chunks) % 50 == 0: # Progress indicator + print(f" Processed {page_num} pages, created {len(chunks)} chunks") + + print(f"✅ Full coverage: {len(chunks)} chunks from {len(pymupdf_pages)} pages") + return chunks + + def _create_page_chunks(self, page_text: str, page_num: int, start_chunk_id: int) -> List[Dict]: + """Create properly sized chunks from a single page's content.""" + # Clean and validate page text first + cleaned_text = self._ensure_complete_sentences(page_text) + + if not cleaned_text or len(cleaned_text) < 50: + # Skip pages with insufficient content + return [] + + if len(cleaned_text) <= self.max_chunk_size: + # Single chunk for small pages + return [{ + 'text': cleaned_text, + 'title': f"Page {page_num}", + 'page': page_num, + 'metadata': { + 'parsing_method': 'pdfplumber_page_coverage', + 'quality_score': self._calculate_quality_score(cleaned_text), + 'full_page_coverage': True + } + }] + else: + # Split large pages into chunks with sentence boundaries + text_chunks = self._split_text_into_chunks(cleaned_text) + page_chunks = [] + + for i, chunk_text in enumerate(text_chunks): + # Ensure each chunk is complete + complete_chunk = self._ensure_complete_sentences(chunk_text) + + if complete_chunk and len(complete_chunk) >= 100: + page_chunks.append({ + 'text': complete_chunk, + 'title': f"Page {page_num} (Part {i+1})", + 'page': page_num, + 'metadata': { + 'parsing_method': 'pdfplumber_page_coverage', + 'part_number': i + 1, + 'total_parts': len(text_chunks), + 'quality_score': self._calculate_quality_score(complete_chunk), + 'full_page_coverage': True + } + }) + + return page_chunks + + def _ensure_complete_sentences(self, text: str) -> str: + """Ensure text contains only complete sentences.""" + text = text.strip() + if not text: + return "" + + # Find last complete sentence + last_sentence_end = -1 + for i, char in enumerate(reversed(text)): + if char in '.!?:': + last_sentence_end = len(text) - i + break + + if last_sentence_end > 0: + # Return text up to last complete sentence + complete_text = text[:last_sentence_end].strip() + + # Ensure it starts properly (capital letter or common starters) + if complete_text and (complete_text[0].isupper() or + complete_text.startswith(('The ', 'A ', 'An ', 'This ', 'RISC'))): + return complete_text + + # If no complete sentences found, return empty + return "" + + def parse_document(self, pdf_path: Path, pdf_data: Dict[str, Any] = None) -> List[Dict]: + """ + Parse document using PDFPlumber (required by 
HybridParser). + + Args: + pdf_path: Path to PDF file + pdf_data: PyMuPDF page data to ensure full page coverage + + Returns: + List of chunks with structure preservation across ALL pages + """ + if pdf_data and 'pages' in pdf_data: + # Use PyMuPDF page data to ensure full coverage + return self.extract_with_page_coverage(pdf_path, pdf_data['pages']) + else: + # Fallback to structure-based extraction + return self.extract_with_structure(pdf_path) + + +def parse_pdf_with_pdfplumber(pdf_path: Path, **kwargs) -> List[Dict]: + """Main entry point for PDFPlumber parsing.""" + parser = PDFPlumberParser(**kwargs) + chunks = parser.extract_with_structure(pdf_path) + + print(f"PDFPlumber extracted {len(chunks)} chunks") + + return chunks \ No newline at end of file diff --git a/shared_utils/document_processing/toc_guided_parser.py b/shared_utils/document_processing/toc_guided_parser.py new file mode 100644 index 0000000000000000000000000000000000000000..613820600207c0994ac745d29e24491fee38ec69 --- /dev/null +++ b/shared_utils/document_processing/toc_guided_parser.py @@ -0,0 +1,311 @@ +#!/usr/bin/env python3 +""" +TOC-Guided PDF Parser + +Uses the Table of Contents to guide intelligent chunking that respects +document structure and hierarchy. + +Author: Arthur Passuello +""" + +import re +from typing import Dict, List, Optional, Tuple +from dataclasses import dataclass + + +@dataclass +class TOCEntry: + """Represents a table of contents entry.""" + title: str + page: int + level: int # 0 for chapters, 1 for sections, 2 for subsections + parent: Optional[str] = None + parent_title: Optional[str] = None # Added for hybrid parser compatibility + + +class TOCGuidedParser: + """Parser that uses TOC to create structure-aware chunks.""" + + def __init__(self, target_chunk_size: int = 1400, min_chunk_size: int = 800, + max_chunk_size: int = 2000): + """Initialize TOC-guided parser.""" + self.target_chunk_size = target_chunk_size + self.min_chunk_size = min_chunk_size + self.max_chunk_size = max_chunk_size + + def parse_toc(self, pages: List[Dict]) -> List[TOCEntry]: + """Parse table of contents from pages.""" + toc_entries = [] + + # Find TOC pages (usually early in document) + toc_pages = [] + for i, page in enumerate(pages[:20]): # Check first 20 pages + page_text = page.get('text', '').lower() + if 'contents' in page_text or 'table of contents' in page_text: + toc_pages.append((i, page)) + + if not toc_pages: + print("No TOC found, using fallback structure detection") + return self._detect_structure_without_toc(pages) + + # Parse TOC entries + for page_idx, page in toc_pages: + text = page.get('text', '') + lines = text.split('\n') + + i = 0 + while i < len(lines): + line = lines[i].strip() + + # Skip empty lines and TOC header + if not line or 'contents' in line.lower(): + i += 1 + continue + + # Pattern 1: "1.1 Title .... 23" + match1 = re.match(r'^(\d+(?:\.\d+)*)\s+(.+?)\s*\.{2,}\s*(\d+)$', line) + if match1: + number, title, page_num = match1.groups() + level = len(number.split('.')) - 1 + toc_entries.append(TOCEntry( + title=title.strip(), + page=int(page_num), + level=level + )) + i += 1 + continue + + # Pattern 2: Multi-line format + # "1.1" + # "Title" + # ". . . . 23" + if re.match(r'^(\d+(?:\.\d+)*)$', line): + number = line + if i + 1 < len(lines): + title_line = lines[i + 1].strip() + if i + 2 < len(lines): + dots_line = lines[i + 2].strip() + page_match = re.search(r'(\d+)\s*$', dots_line) + if page_match and '.' 
in dots_line: + title = title_line + page_num = int(page_match.group(1)) + level = len(number.split('.')) - 1 + toc_entries.append(TOCEntry( + title=title, + page=page_num, + level=level + )) + i += 3 + continue + + # Pattern 3: "Chapter 1: Title ... 23" + match3 = re.match(r'^(Chapter|Section|Part)\s+(\d+):?\s+(.+?)\s*\.{2,}\s*(\d+)$', line, re.IGNORECASE) + if match3: + prefix, number, title, page_num = match3.groups() + level = 0 if prefix.lower() == 'chapter' else 1 + toc_entries.append(TOCEntry( + title=f"{prefix} {number}: {title}", + page=int(page_num), + level=level + )) + i += 1 + continue + + i += 1 + + # Add parent relationships + for i, entry in enumerate(toc_entries): + if entry.level > 0: + # Find parent (previous entry with lower level) + for j in range(i - 1, -1, -1): + if toc_entries[j].level < entry.level: + entry.parent = toc_entries[j].title + entry.parent_title = toc_entries[j].title # Set both for compatibility + break + + return toc_entries + + def _detect_structure_without_toc(self, pages: List[Dict]) -> List[TOCEntry]: + """Fallback: detect structure from content patterns across ALL pages.""" + entries = [] + + # Expanded patterns for better structure detection + chapter_patterns = [ + re.compile(r'^(Chapter|CHAPTER)\s+(\d+|[IVX]+)(?:\s*[:\-]\s*(.+))?', re.MULTILINE), + re.compile(r'^(\d+)\s+([A-Z][^.]*?)(?:\s*\.{2,}\s*\d+)?$', re.MULTILINE), # "1 Introduction" + re.compile(r'^([A-Z][A-Z\s]{10,})$', re.MULTILINE), # ALL CAPS titles + ] + + section_patterns = [ + re.compile(r'^(\d+\.\d+)\s+(.+?)(?:\s*\.{2,}\s*\d+)?$', re.MULTILINE), # "1.1 Section" + re.compile(r'^(\d+\.\d+\.\d+)\s+(.+?)(?:\s*\.{2,}\s*\d+)?$', re.MULTILINE), # "1.1.1 Subsection" + ] + + # Process ALL pages, not just first 20 + for i, page in enumerate(pages): + text = page.get('text', '') + if not text.strip(): + continue + + # Find chapters with various patterns + for pattern in chapter_patterns: + for match in pattern.finditer(text): + if len(match.groups()) >= 2: + if len(match.groups()) >= 3 and match.group(3): + title = match.group(3).strip() + else: + title = match.group(2).strip() if match.group(2) else f"Section {match.group(1)}" + + # Skip very short or likely false positives + if len(title) >= 3 and not re.match(r'^\d+$', title): + entries.append(TOCEntry( + title=title, + page=i + 1, + level=0 + )) + + # Find sections + for pattern in section_patterns: + for match in pattern.finditer(text): + section_num = match.group(1) + title = match.group(2).strip() if len(match.groups()) >= 2 else f"Section {section_num}" + + # Determine level by number of dots + level = section_num.count('.') + + # Skip very short titles or obvious artifacts + if len(title) >= 3 and not re.match(r'^\d+$', title): + entries.append(TOCEntry( + title=title, + page=i + 1, + level=level + )) + + # If still no entries found, create page-based entries for full coverage + if not entries: + print("No structure patterns found, creating page-based sections for full coverage") + # Create sections every 10 pages to ensure full document coverage + for i in range(0, len(pages), 10): + start_page = i + 1 + end_page = min(i + 10, len(pages)) + title = f"Pages {start_page}-{end_page}" + entries.append(TOCEntry( + title=title, + page=start_page, + level=0 + )) + + return entries + + def create_chunks_from_toc(self, pdf_data: Dict, toc_entries: List[TOCEntry]) -> List[Dict]: + """Create chunks based on TOC structure.""" + chunks = [] + pages = pdf_data.get('pages', []) + + for i, entry in enumerate(toc_entries): + # Determine page 
range for this entry + start_page = entry.page - 1 # Convert to 0-indexed + + # Find end page (start of next entry at same or higher level) + end_page = len(pages) + for j in range(i + 1, len(toc_entries)): + if toc_entries[j].level <= entry.level: + end_page = toc_entries[j].page - 1 + break + + # Extract text for this section + section_text = [] + for page_idx in range(max(0, start_page), min(end_page, len(pages))): + page_text = pages[page_idx].get('text', '') + if page_text.strip(): + section_text.append(page_text) + + if not section_text: + continue + + full_text = '\n\n'.join(section_text) + + # Create chunks from section text + if len(full_text) <= self.max_chunk_size: + # Single chunk for small sections + chunks.append({ + 'text': full_text.strip(), + 'title': entry.title, + 'parent_title': entry.parent_title or entry.parent or '', + 'level': entry.level, + 'page': entry.page, + 'context': f"From {entry.title}", + 'metadata': { + 'parsing_method': 'toc_guided', + 'section_title': entry.title, + 'hierarchy_level': entry.level + } + }) + else: + # Split large sections into chunks + section_chunks = self._split_text_into_chunks(full_text) + for j, chunk_text in enumerate(section_chunks): + chunks.append({ + 'text': chunk_text.strip(), + 'title': f"{entry.title} (Part {j+1})", + 'parent_title': entry.parent_title or entry.parent or '', + 'level': entry.level, + 'page': entry.page, + 'context': f"Part {j+1} of {entry.title}", + 'metadata': { + 'parsing_method': 'toc_guided', + 'section_title': entry.title, + 'hierarchy_level': entry.level, + 'part_number': j + 1, + 'total_parts': len(section_chunks) + } + }) + + return chunks + + def _split_text_into_chunks(self, text: str) -> List[str]: + """Split text into chunks while preserving sentence boundaries.""" + sentences = re.split(r'(?<=[.!?])\s+', text) + chunks = [] + current_chunk = [] + current_size = 0 + + for sentence in sentences: + sentence_size = len(sentence) + + if current_size + sentence_size > self.target_chunk_size and current_chunk: + # Save current chunk + chunks.append(' '.join(current_chunk)) + current_chunk = [sentence] + current_size = sentence_size + else: + current_chunk.append(sentence) + current_size += sentence_size + 1 # +1 for space + + if current_chunk: + chunks.append(' '.join(current_chunk)) + + return chunks + + +def parse_pdf_with_toc_guidance(pdf_data: Dict, **kwargs) -> List[Dict]: + """Main entry point for TOC-guided parsing.""" + parser = TOCGuidedParser(**kwargs) + + # Parse TOC + pages = pdf_data.get('pages', []) + toc_entries = parser.parse_toc(pages) + + print(f"Found {len(toc_entries)} TOC entries") + + if not toc_entries: + print("No TOC entries found, falling back to basic chunking") + from .chunker import chunk_technical_text + return chunk_technical_text(pdf_data.get('text', '')) + + # Create chunks based on TOC + chunks = parser.create_chunks_from_toc(pdf_data, toc_entries) + + print(f"Created {len(chunks)} chunks from TOC structure") + + return chunks \ No newline at end of file diff --git a/shared_utils/embeddings/__init__.py b/shared_utils/embeddings/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a56edba636fcf28480a085092019159318cee266 --- /dev/null +++ b/shared_utils/embeddings/__init__.py @@ -0,0 +1 @@ +# Embeddings module \ No newline at end of file diff --git a/shared_utils/embeddings/__pycache__/__init__.cpython-312.pyc b/shared_utils/embeddings/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..e4ec2d7d00bae47cd42db0335f28226e9521959c Binary files /dev/null and b/shared_utils/embeddings/__pycache__/__init__.cpython-312.pyc differ diff --git a/shared_utils/embeddings/__pycache__/generator.cpython-312.pyc b/shared_utils/embeddings/__pycache__/generator.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eb3aa25f845ca25f73d8e77f92568c5e8da3e912 Binary files /dev/null and b/shared_utils/embeddings/__pycache__/generator.cpython-312.pyc differ diff --git a/shared_utils/embeddings/generator.py b/shared_utils/embeddings/generator.py new file mode 100644 index 0000000000000000000000000000000000000000..b0dba2656760b6b8bf6f64e0fa2d3e66e7bc0e4b --- /dev/null +++ b/shared_utils/embeddings/generator.py @@ -0,0 +1,84 @@ +import numpy as np +import torch +from typing import List, Optional +from sentence_transformers import SentenceTransformer + +# Global cache for embeddings +_embedding_cache = {} +_model_cache = {} + + +def generate_embeddings( + texts: List[str], + model_name: str = "sentence-transformers/multi-qa-MiniLM-L6-cos-v1", + batch_size: int = 32, + use_mps: bool = True, +) -> np.ndarray: + """ + Generate embeddings for text chunks with caching. + + Args: + texts: List of text chunks to embed + model_name: SentenceTransformer model identifier + batch_size: Processing batch size + use_mps: Use Apple Silicon acceleration + + Returns: + numpy array of shape (len(texts), embedding_dim) + + Performance Target: + - 100 texts/second on M4-Pro + - 384-dimensional embeddings + - Memory usage <500MB + """ + # Check cache for all texts + cache_keys = [f"{model_name}:{text}" for text in texts] + cached_embeddings = [] + texts_to_compute = [] + compute_indices = [] + + for i, key in enumerate(cache_keys): + if key in _embedding_cache: + cached_embeddings.append((i, _embedding_cache[key])) + else: + texts_to_compute.append(texts[i]) + compute_indices.append(i) + + # Load model if needed + if model_name not in _model_cache: + model = SentenceTransformer(model_name) + device = 'mps' if use_mps and torch.backends.mps.is_available() else 'cpu' + model = model.to(device) + model.eval() + _model_cache[model_name] = model + else: + model = _model_cache[model_name] + + # Compute new embeddings + if texts_to_compute: + with torch.no_grad(): + new_embeddings = model.encode( + texts_to_compute, + batch_size=batch_size, + convert_to_numpy=True, + normalize_embeddings=False + ).astype(np.float32) + + # Cache new embeddings + for i, text in enumerate(texts_to_compute): + key = f"{model_name}:{text}" + _embedding_cache[key] = new_embeddings[i] + + # Reconstruct full embedding array + result = np.zeros((len(texts), 384), dtype=np.float32) + + # Fill cached embeddings + for idx, embedding in cached_embeddings: + result[idx] = embedding + + # Fill newly computed embeddings + if texts_to_compute: + for i, original_idx in enumerate(compute_indices): + result[original_idx] = new_embeddings[i] + + return result diff --git a/shared_utils/generation/__pycache__/adaptive_prompt_engine.cpython-312.pyc b/shared_utils/generation/__pycache__/adaptive_prompt_engine.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2fb2bbb18adb3b9d77fd7e6c41130150ac5c381c Binary files /dev/null and b/shared_utils/generation/__pycache__/adaptive_prompt_engine.cpython-312.pyc differ diff --git a/shared_utils/generation/__pycache__/answer_generator.cpython-312.pyc b/shared_utils/generation/__pycache__/answer_generator.cpython-312.pyc new 
file mode 100644 index 0000000000000000000000000000000000000000..3cf62c1655462caedffdcfc5d569f18a687e64b3 Binary files /dev/null and b/shared_utils/generation/__pycache__/answer_generator.cpython-312.pyc differ diff --git a/shared_utils/generation/__pycache__/chain_of_thought_engine.cpython-312.pyc b/shared_utils/generation/__pycache__/chain_of_thought_engine.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..18aa27d2117b64ecf1aca8b4f3c4042f3308d2ea Binary files /dev/null and b/shared_utils/generation/__pycache__/chain_of_thought_engine.cpython-312.pyc differ diff --git a/shared_utils/generation/__pycache__/hf_answer_generator.cpython-312.pyc b/shared_utils/generation/__pycache__/hf_answer_generator.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a0203f9fa14c87dd51480c239cccc84a9c0add29 Binary files /dev/null and b/shared_utils/generation/__pycache__/hf_answer_generator.cpython-312.pyc differ diff --git a/shared_utils/generation/__pycache__/inference_providers_generator.cpython-312.pyc b/shared_utils/generation/__pycache__/inference_providers_generator.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8b92b8b9ca5538854b321a77dc827cafb3895e82 Binary files /dev/null and b/shared_utils/generation/__pycache__/inference_providers_generator.cpython-312.pyc differ diff --git a/shared_utils/generation/__pycache__/ollama_answer_generator.cpython-312.pyc b/shared_utils/generation/__pycache__/ollama_answer_generator.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6654e2e1c3ad492af1ee1c5f3339b41f2794d7b2 Binary files /dev/null and b/shared_utils/generation/__pycache__/ollama_answer_generator.cpython-312.pyc differ diff --git a/shared_utils/generation/__pycache__/prompt_optimizer.cpython-312.pyc b/shared_utils/generation/__pycache__/prompt_optimizer.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b6d24efde7c313daf5a489dc4621c825ec23f8df Binary files /dev/null and b/shared_utils/generation/__pycache__/prompt_optimizer.cpython-312.pyc differ diff --git a/shared_utils/generation/__pycache__/prompt_templates.cpython-312.pyc b/shared_utils/generation/__pycache__/prompt_templates.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6c5bd8f1b5140dfe5d8113eb7a51dddd661fb497 Binary files /dev/null and b/shared_utils/generation/__pycache__/prompt_templates.cpython-312.pyc differ diff --git a/shared_utils/generation/adaptive_prompt_engine.py b/shared_utils/generation/adaptive_prompt_engine.py new file mode 100644 index 0000000000000000000000000000000000000000..f24ade3c088782a9f5372d3a89abf5114c67c853 --- /dev/null +++ b/shared_utils/generation/adaptive_prompt_engine.py @@ -0,0 +1,559 @@ +""" +Adaptive Prompt Engine for Dynamic Context-Aware Prompt Optimization. + +This module provides intelligent prompt adaptation based on context quality, +query complexity, and performance requirements. 
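+
+Example (illustrative sketch only; assumes this module is imported as
+shared_utils.generation.adaptive_prompt_engine and that retrieved chunks are
+dicts carrying 'content'/'text', 'metadata', and optionally 'confidence' or
+'score' keys, as elsewhere in this package):
+
+    engine = AdaptivePromptEngine()
+    config = engine.generate_adaptive_config(query, chunks, max_tokens=1024)
+    prompt = engine.create_adaptive_prompt(query, chunks, config)
+    # prompt["system"] and prompt["user"] can then be sent to any chat-style LLM client.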
+""" + +import logging +from typing import Dict, List, Optional, Tuple, Any +from dataclasses import dataclass +from enum import Enum +import numpy as np + +from .prompt_templates import ( + QueryType, + PromptTemplate, + TechnicalPromptTemplates +) + + +class ContextQuality(Enum): + """Context quality levels for adaptive prompting.""" + HIGH = "high" # >0.8 relevance, low noise + MEDIUM = "medium" # 0.5-0.8 relevance, moderate noise + LOW = "low" # <0.5 relevance, high noise + + +class QueryComplexity(Enum): + """Query complexity levels.""" + SIMPLE = "simple" # Single concept, direct answer + MODERATE = "moderate" # Multiple concepts, structured answer + COMPLEX = "complex" # Multi-step reasoning, comprehensive answer + + +@dataclass +class ContextMetrics: + """Metrics for evaluating context quality.""" + relevance_score: float + noise_ratio: float + chunk_count: int + avg_chunk_length: int + technical_density: float + source_diversity: int + + +@dataclass +class AdaptivePromptConfig: + """Configuration for adaptive prompt generation.""" + context_quality: ContextQuality + query_complexity: QueryComplexity + max_context_length: int + prefer_concise: bool + include_few_shot: bool + enable_chain_of_thought: bool + confidence_threshold: float + + +class AdaptivePromptEngine: + """ + Intelligent prompt adaptation engine that optimizes prompts based on: + - Context quality and relevance + - Query complexity and type + - Performance requirements + - User preferences + """ + + def __init__(self): + """Initialize the adaptive prompt engine.""" + self.logger = logging.getLogger(__name__) + + # Context quality thresholds + self.high_quality_threshold = 0.8 + self.medium_quality_threshold = 0.5 + + # Query complexity indicators + self.complex_keywords = { + "implementation": ["implement", "build", "create", "develop", "setup"], + "comparison": ["compare", "difference", "versus", "vs", "better"], + "analysis": ["analyze", "evaluate", "assess", "study", "examine"], + "multi_step": ["process", "procedure", "steps", "how to", "guide"] + } + + # Length optimization thresholds + self.token_limits = { + "concise": 512, + "standard": 1024, + "detailed": 2048, + "comprehensive": 4096 + } + + def analyze_context_quality(self, chunks: List[Dict[str, Any]]) -> ContextMetrics: + """ + Analyze the quality of retrieved context chunks. 
+ + Args: + chunks: List of context chunks with metadata + + Returns: + ContextMetrics with quality assessment + """ + if not chunks: + return ContextMetrics( + relevance_score=0.0, + noise_ratio=1.0, + chunk_count=0, + avg_chunk_length=0, + technical_density=0.0, + source_diversity=0 + ) + + # Calculate relevance score (using confidence scores if available) + relevance_scores = [] + for chunk in chunks: + # Use confidence score if available, otherwise use a heuristic + if 'confidence' in chunk: + relevance_scores.append(chunk['confidence']) + elif 'score' in chunk: + relevance_scores.append(chunk['score']) + else: + # Heuristic: longer chunks with technical terms are more relevant + content = chunk.get('content', chunk.get('text', '')) + tech_terms = self._count_technical_terms(content) + relevance_scores.append(min(tech_terms / 10.0, 1.0)) + + avg_relevance = np.mean(relevance_scores) if relevance_scores else 0.0 + + # Calculate noise ratio (fragments, repetitive content) + noise_count = 0 + total_chunks = len(chunks) + + for chunk in chunks: + content = chunk.get('content', chunk.get('text', '')) + if self._is_noisy_chunk(content): + noise_count += 1 + + noise_ratio = noise_count / total_chunks if total_chunks > 0 else 0.0 + + # Calculate average chunk length + chunk_lengths = [] + for chunk in chunks: + content = chunk.get('content', chunk.get('text', '')) + chunk_lengths.append(len(content)) + + avg_chunk_length = int(np.mean(chunk_lengths)) if chunk_lengths else 0 + + # Calculate technical density + technical_density = self._calculate_technical_density(chunks) + + # Calculate source diversity + sources = set() + for chunk in chunks: + source = chunk.get('metadata', {}).get('source', 'unknown') + sources.add(source) + + source_diversity = len(sources) + + return ContextMetrics( + relevance_score=avg_relevance, + noise_ratio=noise_ratio, + chunk_count=len(chunks), + avg_chunk_length=avg_chunk_length, + technical_density=technical_density, + source_diversity=source_diversity + ) + + def determine_query_complexity(self, query: str) -> QueryComplexity: + """ + Determine the complexity level of a query. + + Args: + query: User's question + + Returns: + QueryComplexity level + """ + query_lower = query.lower() + complexity_score = 0 + + # Check for complex keywords + for category, keywords in self.complex_keywords.items(): + if any(keyword in query_lower for keyword in keywords): + complexity_score += 1 + + # Check for multiple questions or concepts + if '?' in query[:-1]: # Multiple question marks (excluding the last one) + complexity_score += 1 + + if any(word in query_lower for word in ["and", "or", "also", "additionally", "furthermore"]): + complexity_score += 1 + + # Check query length + word_count = len(query.split()) + if word_count > 20: + complexity_score += 1 + elif word_count > 10: + complexity_score += 0.5 + + # Determine complexity level + if complexity_score >= 2: + return QueryComplexity.COMPLEX + elif complexity_score >= 1: + return QueryComplexity.MODERATE + else: + return QueryComplexity.SIMPLE + + def generate_adaptive_config( + self, + query: str, + context_chunks: List[Dict[str, Any]], + max_tokens: int = 2048, + prefer_speed: bool = False + ) -> AdaptivePromptConfig: + """ + Generate adaptive prompt configuration based on context and query analysis. 
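+
+        Relevance >= 0.8 maps to HIGH context quality and >= 0.5 to MEDIUM (the
+        thresholds set in __init__); anything lower is LOW, which raises the
+        confidence threshold used for chunk filtering and enables few-shot
+        examples in the resulting configuration.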
+ + Args: + query: User's question + context_chunks: Retrieved context chunks + max_tokens: Maximum token limit + prefer_speed: Whether to optimize for speed over quality + + Returns: + AdaptivePromptConfig with optimized settings + """ + # Analyze context quality + context_metrics = self.analyze_context_quality(context_chunks) + + # Determine context quality level + if context_metrics.relevance_score >= self.high_quality_threshold: + context_quality = ContextQuality.HIGH + elif context_metrics.relevance_score >= self.medium_quality_threshold: + context_quality = ContextQuality.MEDIUM + else: + context_quality = ContextQuality.LOW + + # Determine query complexity + query_complexity = self.determine_query_complexity(query) + + # Adapt configuration based on analysis + config = AdaptivePromptConfig( + context_quality=context_quality, + query_complexity=query_complexity, + max_context_length=max_tokens, + prefer_concise=prefer_speed, + include_few_shot=self._should_include_few_shot(context_quality, query_complexity), + enable_chain_of_thought=self._should_enable_cot(query_complexity), + confidence_threshold=self._get_confidence_threshold(context_quality) + ) + + return config + + def create_adaptive_prompt( + self, + query: str, + context_chunks: List[Dict[str, Any]], + config: Optional[AdaptivePromptConfig] = None + ) -> Dict[str, str]: + """ + Create an adaptive prompt optimized for the specific query and context. + + Args: + query: User's question + context_chunks: Retrieved context chunks + config: Optional configuration (auto-generated if None) + + Returns: + Dict with optimized 'system' and 'user' prompts + """ + if config is None: + config = self.generate_adaptive_config(query, context_chunks) + + # Get base template + query_type = TechnicalPromptTemplates.detect_query_type(query) + base_template = TechnicalPromptTemplates.get_template_for_query(query) + + # Adapt template based on configuration + adapted_template = self._adapt_template(base_template, config) + + # Format context with optimization + formatted_context = self._format_context_adaptive(context_chunks, config) + + # Create prompt with adaptive formatting + prompt = TechnicalPromptTemplates.format_prompt_with_template( + query=query, + context=formatted_context, + template=adapted_template, + include_few_shot=config.include_few_shot + ) + + # Add chain-of-thought if enabled + if config.enable_chain_of_thought: + prompt = self._add_chain_of_thought(prompt, query_type) + + return prompt + + def _adapt_template( + self, + base_template: PromptTemplate, + config: AdaptivePromptConfig + ) -> PromptTemplate: + """ + Adapt a base template based on configuration. + + Args: + base_template: Base prompt template + config: Adaptive configuration + + Returns: + Adapted PromptTemplate + """ + # Modify system prompt based on context quality + system_prompt = base_template.system_prompt + + if config.context_quality == ContextQuality.LOW: + system_prompt += """ + +IMPORTANT: The provided context may have limited relevance. Focus on: +- Only use information that directly relates to the question +- Clearly state if information is insufficient +- Avoid making assumptions beyond the provided context +- Be explicit about confidence levels""" + + elif config.context_quality == ContextQuality.HIGH: + system_prompt += """ + +CONTEXT QUALITY: High-quality, relevant context is provided. 
You can: +- Provide comprehensive, detailed answers +- Make reasonable inferences from the context +- Include related technical details and examples +- Reference multiple sources confidently""" + + # Modify answer guidelines based on complexity and preferences + answer_guidelines = base_template.answer_guidelines + + if config.prefer_concise: + answer_guidelines += "\n\nResponse style: Be concise and focus on essential information. Aim for clarity over comprehensiveness." + + if config.query_complexity == QueryComplexity.COMPLEX: + answer_guidelines += "\n\nComplex query handling: Break down your answer into clear sections. Use numbered steps for procedures." + + return PromptTemplate( + system_prompt=system_prompt, + context_format=base_template.context_format, + query_format=base_template.query_format, + answer_guidelines=answer_guidelines, + few_shot_examples=base_template.few_shot_examples + ) + + def _format_context_adaptive( + self, + chunks: List[Dict[str, Any]], + config: AdaptivePromptConfig + ) -> str: + """ + Format context chunks with adaptive optimization. + + Args: + chunks: Context chunks to format + config: Adaptive configuration + + Returns: + Formatted context string + """ + if not chunks: + return "No relevant context available." + + # Filter chunks based on confidence if low quality context + filtered_chunks = chunks + if config.context_quality == ContextQuality.LOW: + filtered_chunks = [ + chunk for chunk in chunks + if self._meets_confidence_threshold(chunk, config.confidence_threshold) + ] + + # Limit context length if needed + if config.prefer_concise: + filtered_chunks = filtered_chunks[:3] # Limit to top 3 chunks + + # Format chunks + context_parts = [] + for i, chunk in enumerate(filtered_chunks): + chunk_text = chunk.get('content', chunk.get('text', '')) + + # Truncate if too long and prefer_concise is True + if config.prefer_concise and len(chunk_text) > 800: + chunk_text = chunk_text[:800] + "..." + + metadata = chunk.get('metadata', {}) + page_num = metadata.get('page_number', 'unknown') + source = metadata.get('source', 'unknown') + + context_parts.append( + f"[chunk_{i+1}] (Page {page_num} from {source}):\n{chunk_text}" + ) + + return "\n\n---\n\n".join(context_parts) + + def _add_chain_of_thought( + self, + prompt: Dict[str, str], + query_type: QueryType + ) -> Dict[str, str]: + """ + Add chain-of-thought reasoning to the prompt. + + Args: + prompt: Base prompt dictionary + query_type: Type of query + + Returns: + Enhanced prompt with chain-of-thought + """ + cot_addition = """ + +Before providing your final answer, think through this step-by-step: + +1. What is the user specifically asking for? +2. What relevant information is available in the context? +3. How should I structure my response for maximum clarity? +4. Are there any important caveats or limitations to mention? 
+ +Step-by-step reasoning:""" + + prompt["user"] = prompt["user"] + cot_addition + + return prompt + + def _should_include_few_shot( + self, + context_quality: ContextQuality, + query_complexity: QueryComplexity + ) -> bool: + """Determine if few-shot examples should be included.""" + # Include few-shot for complex queries or when context quality is low + if query_complexity == QueryComplexity.COMPLEX: + return True + if context_quality == ContextQuality.LOW: + return True + return False + + def _should_enable_cot(self, query_complexity: QueryComplexity) -> bool: + """Determine if chain-of-thought should be enabled.""" + return query_complexity == QueryComplexity.COMPLEX + + def _get_confidence_threshold(self, context_quality: ContextQuality) -> float: + """Get confidence threshold based on context quality.""" + thresholds = { + ContextQuality.HIGH: 0.3, + ContextQuality.MEDIUM: 0.5, + ContextQuality.LOW: 0.7 + } + return thresholds[context_quality] + + def _count_technical_terms(self, text: str) -> int: + """Count technical terms in text.""" + technical_terms = [ + "risc-v", "riscv", "cpu", "gpu", "mcu", "interrupt", "register", + "memory", "cache", "pipeline", "instruction", "assembly", "compiler", + "embedded", "freertos", "rtos", "gpio", "uart", "spi", "i2c", + "adc", "dac", "timer", "pwm", "dma", "firmware", "bootloader", + "ai", "ml", "neural", "transformer", "attention", "embedding" + ] + + text_lower = text.lower() + count = 0 + for term in technical_terms: + count += text_lower.count(term) + + return count + + def _is_noisy_chunk(self, content: str) -> bool: + """Determine if a chunk is noisy (low quality).""" + # Check for common noise indicators + noise_indicators = [ + "table of contents", + "copyright", + "creative commons", + "license", + "all rights reserved", + "terms of use", + "privacy policy" + ] + + content_lower = content.lower() + + # Check for noise indicators + for indicator in noise_indicators: + if indicator in content_lower: + return True + + # Check for very short fragments + if len(content) < 100: + return True + + # Check for repetitive content + words = content.split() + if len(set(words)) < len(words) * 0.3: # Less than 30% unique words + return True + + return False + + def _calculate_technical_density(self, chunks: List[Dict[str, Any]]) -> float: + """Calculate technical density of chunks.""" + if not chunks: + return 0.0 + + total_terms = 0 + total_words = 0 + + for chunk in chunks: + content = chunk.get('content', chunk.get('text', '')) + words = content.split() + total_words += len(words) + total_terms += self._count_technical_terms(content) + + return (total_terms / total_words) if total_words > 0 else 0.0 + + def _meets_confidence_threshold( + self, + chunk: Dict[str, Any], + threshold: float + ) -> bool: + """Check if chunk meets confidence threshold.""" + confidence = chunk.get('confidence', chunk.get('score', 0.5)) + return confidence >= threshold + + +# Example usage +if __name__ == "__main__": + # Initialize engine + engine = AdaptivePromptEngine() + + # Example context chunks + example_chunks = [ + { + "content": "RISC-V is an open-source instruction set architecture...", + "metadata": {"page_number": 1, "source": "riscv-spec.pdf"}, + "confidence": 0.9 + }, + { + "content": "The RISC-V processor supports 32-bit and 64-bit implementations...", + "metadata": {"page_number": 2, "source": "riscv-spec.pdf"}, + "confidence": 0.8 + } + ] + + # Example queries + simple_query = "What is RISC-V?" 
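+
+    # Illustrative extra check (uses only the names defined above): inspect the
+    # raw context metrics and the complexity classification for the simple query.
+    metrics = engine.analyze_context_quality(example_chunks)
+    print(f"Relevance: {metrics.relevance_score:.2f}, noise ratio: {metrics.noise_ratio:.2f}")
+    print(f"'{simple_query}' classified as: {engine.determine_query_complexity(simple_query).value}")
+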
+ complex_query = "How do I implement a complete interrupt handling system in RISC-V with nested interrupts and priority management?" + + # Generate adaptive prompts + simple_config = engine.generate_adaptive_config(simple_query, example_chunks) + complex_config = engine.generate_adaptive_config(complex_query, example_chunks) + + print(f"Simple query complexity: {simple_config.query_complexity}") + print(f"Complex query complexity: {complex_config.query_complexity}") + print(f"Context quality: {simple_config.context_quality}") + print(f"Few-shot enabled: {complex_config.include_few_shot}") + print(f"Chain-of-thought enabled: {complex_config.enable_chain_of_thought}") \ No newline at end of file diff --git a/shared_utils/generation/answer_generator.py b/shared_utils/generation/answer_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..9c5f4fa4d0f273d0814c143a51abf2e25ded6281 --- /dev/null +++ b/shared_utils/generation/answer_generator.py @@ -0,0 +1,703 @@ +""" +Answer generation module using Ollama for local LLM inference. + +This module provides answer generation with citation support for RAG systems, +optimized for technical documentation Q&A on Apple Silicon. +""" + +import json +import logging +from dataclasses import dataclass +from typing import List, Dict, Any, Optional, Generator, Tuple +import ollama +from datetime import datetime +import re +from pathlib import Path +import sys + +# Import calibration framework +try: + from src.confidence_calibration import ConfidenceCalibrator +except ImportError: + # Fallback - disable calibration for deployment + ConfidenceCalibrator = None + +logger = logging.getLogger(__name__) + + +@dataclass +class Citation: + """Represents a citation to a source document chunk.""" + chunk_id: str + page_number: int + source_file: str + relevance_score: float + text_snippet: str + + +@dataclass +class GeneratedAnswer: + """Represents a generated answer with citations.""" + answer: str + citations: List[Citation] + confidence_score: float + generation_time: float + model_used: str + context_used: List[Dict[str, Any]] + + +class AnswerGenerator: + """ + Generates answers using local LLMs via Ollama with citation support. + + Optimized for technical documentation Q&A with: + - Streaming response support + - Citation extraction and formatting + - Confidence scoring + - Fallback model support + """ + + def __init__( + self, + primary_model: str = "llama3.2:3b", + fallback_model: str = "mistral:latest", + temperature: float = 0.3, + max_tokens: int = 1024, + stream: bool = True, + enable_calibration: bool = True + ): + """ + Initialize the answer generator. 
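+
+        A minimal local setup might look like this (model name and the
+        retrieved_chunks variable are illustrative assumptions; any chat model
+        served by Ollama works):
+
+            generator = AnswerGenerator(primary_model="llama3.2:3b", stream=False)
+            result = generator.generate(query="What is RISC-V?", chunks=retrieved_chunks)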
+ + Args: + primary_model: Primary Ollama model to use + fallback_model: Fallback model for complex queries + temperature: Generation temperature (0.0-1.0) + max_tokens: Maximum tokens to generate + stream: Whether to stream responses + enable_calibration: Whether to enable confidence calibration + """ + self.primary_model = primary_model + self.fallback_model = fallback_model + self.temperature = temperature + self.max_tokens = max_tokens + self.stream = stream + self.client = ollama.Client() + + # Initialize confidence calibration + self.enable_calibration = enable_calibration + self.calibrator = None + if enable_calibration and ConfidenceCalibrator is not None: + try: + self.calibrator = ConfidenceCalibrator() + logger.info("Confidence calibration enabled") + except Exception as e: + logger.warning(f"Failed to initialize calibration: {e}") + self.enable_calibration = False + elif enable_calibration and ConfidenceCalibrator is None: + logger.warning("Calibration requested but ConfidenceCalibrator not available - disabling") + self.enable_calibration = False + + # Verify models are available + self._verify_models() + + def _verify_models(self) -> None: + """Verify that required models are available.""" + try: + model_list = self.client.list() + available_models = [] + + # Handle Ollama's ListResponse object + if hasattr(model_list, 'models'): + for model in model_list.models: + if hasattr(model, 'model'): + available_models.append(model.model) + elif isinstance(model, dict) and 'model' in model: + available_models.append(model['model']) + + if self.primary_model not in available_models: + logger.warning(f"Primary model {self.primary_model} not found. Available models: {available_models}") + raise ValueError(f"Model {self.primary_model} not available. Please run: ollama pull {self.primary_model}") + + if self.fallback_model not in available_models: + logger.warning(f"Fallback model {self.fallback_model} not found in: {available_models}") + + except Exception as e: + logger.error(f"Error verifying models: {e}") + raise + + def _create_system_prompt(self) -> str: + """Create system prompt for technical documentation Q&A.""" + return """You are a technical documentation assistant that provides clear, accurate answers based on the provided context. + +CORE PRINCIPLES: +1. ANSWER DIRECTLY: If context contains the answer, provide it clearly and confidently +2. BE CONCISE: Keep responses focused and avoid unnecessary uncertainty language +3. CITE ACCURATELY: Use [chunk_X] citations for every fact from context + +RESPONSE GUIDELINES: +- If context has sufficient information → Answer directly and confidently +- If context has partial information → Answer what's available, note what's missing briefly +- If context is irrelevant → Brief refusal: "This information isn't available in the provided documents" + +CITATION FORMAT: +- Use [chunk_1], [chunk_2] etc. for all facts from context +- Example: "According to [chunk_1], RISC-V is an open-source architecture." + +WHAT TO AVOID: +- Do NOT add details not in context +- Do NOT second-guess yourself if context is clear +- Do NOT use phrases like "does not contain sufficient information" when context clearly answers the question +- Do NOT be overly cautious when context is adequate + +Be direct, confident, and accurate. If the context answers the question, provide that answer clearly.""" + + def _format_context(self, chunks: List[Dict[str, Any]]) -> str: + """ + Format retrieved chunks into context for the LLM. 
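+
+        Each chunk is rendered as (page number and filename illustrative):
+
+            [chunk_1] (Page 12 from riscv-spec.pdf):
+            <chunk text>
+
+        with chunks separated by "---" lines, so the model can cite [chunk_N].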
+ + Args: + chunks: List of retrieved chunks with metadata + + Returns: + Formatted context string + """ + context_parts = [] + + for i, chunk in enumerate(chunks): + chunk_text = chunk.get('content', chunk.get('text', '')) + page_num = chunk.get('metadata', {}).get('page_number', 'unknown') + source = chunk.get('metadata', {}).get('source', 'unknown') + + context_parts.append( + f"[chunk_{i+1}] (Page {page_num} from {source}):\n{chunk_text}\n" + ) + + return "\n---\n".join(context_parts) + + def _extract_citations(self, answer: str, chunks: List[Dict[str, Any]]) -> Tuple[str, List[Citation]]: + """ + Extract citations from the generated answer and integrate them naturally. + + Args: + answer: Generated answer with [chunk_X] citations + chunks: Original chunks used for context + + Returns: + Tuple of (natural_answer, citations) + """ + citations = [] + citation_pattern = r'\[chunk_(\d+)\]' + + cited_chunks = set() + + # Find [chunk_X] citations and collect cited chunks + matches = re.finditer(citation_pattern, answer) + for match in matches: + chunk_idx = int(match.group(1)) - 1 # Convert to 0-based index + if 0 <= chunk_idx < len(chunks): + cited_chunks.add(chunk_idx) + + # Create Citation objects for each cited chunk + chunk_to_source = {} + for idx in cited_chunks: + chunk = chunks[idx] + citation = Citation( + chunk_id=chunk.get('id', f'chunk_{idx}'), + page_number=chunk.get('metadata', {}).get('page_number', 0), + source_file=chunk.get('metadata', {}).get('source', 'unknown'), + relevance_score=chunk.get('score', 0.0), + text_snippet=chunk.get('content', chunk.get('text', ''))[:200] + '...' + ) + citations.append(citation) + + # Map chunk reference to natural source name + source_name = chunk.get('metadata', {}).get('source', 'unknown') + if source_name != 'unknown': + # Use just the filename without extension for natural reference + natural_name = Path(source_name).stem.replace('-', ' ').replace('_', ' ') + chunk_to_source[f'[chunk_{idx+1}]'] = f"the {natural_name} documentation" + else: + chunk_to_source[f'[chunk_{idx+1}]'] = "the documentation" + + # Replace [chunk_X] with natural references instead of removing them + natural_answer = answer + for chunk_ref, natural_ref in chunk_to_source.items(): + natural_answer = natural_answer.replace(chunk_ref, natural_ref) + + # Clean up any remaining unreferenced citations (fallback) + natural_answer = re.sub(r'\[chunk_\d+\]', 'the documentation', natural_answer) + + # Clean up multiple spaces and formatting + natural_answer = re.sub(r'\s+', ' ', natural_answer).strip() + + return natural_answer, citations + + def _calculate_confidence(self, answer: str, citations: List[Citation], chunks: List[Dict[str, Any]]) -> float: + """ + Calculate confidence score for the generated answer with improved calibration. 
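+
+        Worked example (hypothetical numbers): with two chunks scored 0.95 and
+        0.89, the base confidence is 0.6; citing both adds 0.25 + 0.15, an
+        average citation relevance above 0.8 adds 0.2, and the high-quality
+        scenario bonus adds 0.15, so the raw score saturates at the 0.95 cap
+        before any calibration is applied.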
+ + Args: + answer: Generated answer + citations: Extracted citations + chunks: Retrieved chunks + + Returns: + Confidence score (0.0-1.0) + """ + # Check if no chunks were provided first + if not chunks: + return 0.05 # No context = very low confidence + + # Assess context quality to determine base confidence + scores = [chunk.get('score', 0) for chunk in chunks] + max_relevance = max(scores) if scores else 0 + avg_relevance = sum(scores) / len(scores) if scores else 0 + + # Dynamic base confidence based on context quality + if max_relevance >= 0.8: + confidence = 0.6 # High-quality context starts high + elif max_relevance >= 0.6: + confidence = 0.4 # Good context starts moderately + elif max_relevance >= 0.4: + confidence = 0.2 # Fair context starts low + else: + confidence = 0.05 # Poor context starts very low + + # Strong uncertainty and explicit refusal indicators + strong_uncertainty_phrases = [ + "does not contain sufficient information", + "context does not provide", + "insufficient information", + "cannot determine", + "refuse to answer", + "cannot answer", + "does not contain relevant", + "no relevant context", + "missing from the provided context" + ] + + # Weak uncertainty phrases that might be in nuanced but correct answers + weak_uncertainty_phrases = [ + "unclear", + "conflicting", + "not specified", + "questionable", + "not contained", + "no mention", + "no relevant", + "missing", + "not explicitly" + ] + + # Check for strong uncertainty - these should drastically reduce confidence + if any(phrase in answer.lower() for phrase in strong_uncertainty_phrases): + return min(0.1, confidence * 0.2) # Max 10% for explicit refusal/uncertainty + + # Check for weak uncertainty - reduce but don't destroy confidence for good context + weak_uncertainty_count = sum(1 for phrase in weak_uncertainty_phrases if phrase in answer.lower()) + if weak_uncertainty_count > 0: + if max_relevance >= 0.7 and citations: + # Good context with citations - reduce less severely + confidence *= (0.8 ** weak_uncertainty_count) # Moderate penalty + else: + # Poor context - reduce more severely + confidence *= (0.5 ** weak_uncertainty_count) # Strong penalty + + # If all chunks have very low relevance scores, cap confidence low + if max_relevance < 0.4: + return min(0.08, confidence) # Max 8% for low relevance context + + # Factor 1: Citation quality and coverage + if citations and chunks: + citation_ratio = len(citations) / min(len(chunks), 3) + + # Strong boost for high-relevance citations + relevant_chunks = [c for c in chunks if c.get('score', 0) > 0.6] + if relevant_chunks: + # Significant boost for citing relevant chunks + confidence += 0.25 * citation_ratio + + # Extra boost if citing majority of relevant chunks + if len(citations) >= len(relevant_chunks) * 0.5: + confidence += 0.15 + else: + # Small boost for citations to lower-relevance chunks + confidence += 0.1 * citation_ratio + else: + # No citations = reduce confidence unless it's a simple factual statement + if max_relevance >= 0.8 and len(answer.split()) < 20: + confidence *= 0.8 # Gentle penalty for uncited but simple answers + else: + confidence *= 0.6 # Stronger penalty for complex uncited answers + + # Factor 2: Relevance score reinforcement + if citations: + avg_citation_relevance = sum(c.relevance_score for c in citations) / len(citations) + if avg_citation_relevance > 0.8: + confidence += 0.2 # Strong boost for highly relevant citations + elif avg_citation_relevance > 0.6: + confidence += 0.1 # Moderate boost + elif 
avg_citation_relevance < 0.4: + confidence *= 0.6 # Penalty for low-relevance citations + + # Factor 3: Context utilization quality + if chunks: + avg_chunk_length = sum(len(chunk.get('content', chunk.get('text', ''))) for chunk in chunks) / len(chunks) + + # Boost for substantial, high-quality context + if avg_chunk_length > 200 and max_relevance > 0.8: + confidence += 0.1 + elif avg_chunk_length < 50: # Very short chunks + confidence *= 0.8 + + # Factor 4: Answer characteristics + answer_words = len(answer.split()) + if answer_words < 10: + confidence *= 0.9 # Slight penalty for very short answers + elif answer_words > 50 and citations: + confidence += 0.05 # Small boost for detailed cited answers + + # Factor 5: High-quality scenario bonus + if (max_relevance >= 0.8 and citations and + len(citations) > 0 and + not any(phrase in answer.lower() for phrase in strong_uncertainty_phrases)): + # This is a high-quality response scenario + confidence += 0.15 + + raw_confidence = min(confidence, 0.95) # Cap at 95% to maintain some uncertainty + + # Apply temperature scaling calibration if available + if self.enable_calibration and self.calibrator and self.calibrator.is_fitted: + try: + calibrated_confidence = self.calibrator.calibrate_confidence(raw_confidence) + logger.debug(f"Confidence calibrated: {raw_confidence:.3f} -> {calibrated_confidence:.3f}") + return calibrated_confidence + except Exception as e: + logger.warning(f"Calibration failed, using raw confidence: {e}") + + return raw_confidence + + def fit_calibration(self, validation_data: List[Dict[str, Any]]) -> float: + """ + Fit temperature scaling calibration using validation data. + + Args: + validation_data: List of dicts with 'confidence' and 'correctness' keys + + Returns: + Optimal temperature parameter + """ + if not self.enable_calibration or not self.calibrator: + logger.warning("Calibration not enabled or not available") + return 1.0 + + try: + confidences = [item['confidence'] for item in validation_data] + correctness = [item['correctness'] for item in validation_data] + + optimal_temp = self.calibrator.fit_temperature_scaling(confidences, correctness) + logger.info(f"Calibration fitted with temperature: {optimal_temp:.3f}") + return optimal_temp + + except Exception as e: + logger.error(f"Failed to fit calibration: {e}") + return 1.0 + + def save_calibration(self, filepath: str) -> bool: + """Save fitted calibration to file.""" + if not self.calibrator or not self.calibrator.is_fitted: + logger.warning("No fitted calibration to save") + return False + + try: + calibration_data = { + 'temperature': self.calibrator.temperature, + 'is_fitted': self.calibrator.is_fitted, + 'model_info': { + 'primary_model': self.primary_model, + 'fallback_model': self.fallback_model + } + } + + with open(filepath, 'w') as f: + json.dump(calibration_data, f, indent=2) + + logger.info(f"Calibration saved to {filepath}") + return True + + except Exception as e: + logger.error(f"Failed to save calibration: {e}") + return False + + def load_calibration(self, filepath: str) -> bool: + """Load fitted calibration from file.""" + if not self.enable_calibration or not self.calibrator: + logger.warning("Calibration not enabled") + return False + + try: + with open(filepath, 'r') as f: + calibration_data = json.load(f) + + self.calibrator.temperature = calibration_data['temperature'] + self.calibrator.is_fitted = calibration_data['is_fitted'] + + logger.info(f"Calibration loaded from {filepath} (temp: {self.calibrator.temperature:.3f})") + return True + + 
except Exception as e: + logger.error(f"Failed to load calibration: {e}") + return False + + def generate( + self, + query: str, + chunks: List[Dict[str, Any]], + use_fallback: bool = False + ) -> GeneratedAnswer: + """ + Generate an answer based on the query and retrieved chunks. + + Args: + query: User's question + chunks: Retrieved document chunks + use_fallback: Whether to use fallback model + + Returns: + GeneratedAnswer object with answer, citations, and metadata + """ + start_time = datetime.now() + model = self.fallback_model if use_fallback else self.primary_model + + # Check for no-context or very poor context situation + if not chunks or all(len(chunk.get('content', chunk.get('text', ''))) < 20 for chunk in chunks): + # Handle no-context situation with brief, professional refusal + user_prompt = f"""Context: [NO RELEVANT CONTEXT FOUND] + +Question: {query} + +INSTRUCTION: Respond with exactly this brief message: + +"This information isn't available in the provided documents." + +DO NOT elaborate, explain, or add any other information.""" + else: + # Format context from chunks + context = self._format_context(chunks) + + # Create concise prompt for faster generation + user_prompt = f"""Context: +{context} + +Question: {query} + +Instructions: Answer using only the context above. Cite with [chunk_1], [chunk_2] etc. + +Answer:""" + + try: + # Generate response + response = self.client.chat( + model=model, + messages=[ + {"role": "system", "content": self._create_system_prompt()}, + {"role": "user", "content": user_prompt} + ], + options={ + "temperature": self.temperature, + "num_predict": min(self.max_tokens, 300), # Reduce max tokens for speed + "top_k": 40, # Optimize sampling for speed + "top_p": 0.9, + "repeat_penalty": 1.1 + }, + stream=False # Get complete response for processing + ) + + # Extract answer + answer_with_citations = response['message']['content'] + + # Extract and clean citations + clean_answer, citations = self._extract_citations(answer_with_citations, chunks) + + # Calculate confidence + confidence = self._calculate_confidence(clean_answer, citations, chunks) + + # Calculate generation time + generation_time = (datetime.now() - start_time).total_seconds() + + return GeneratedAnswer( + answer=clean_answer, + citations=citations, + confidence_score=confidence, + generation_time=generation_time, + model_used=model, + context_used=chunks + ) + + except Exception as e: + logger.error(f"Error generating answer: {e}") + # Return a fallback response + return GeneratedAnswer( + answer="I apologize, but I encountered an error while generating the answer. Please try again.", + citations=[], + confidence_score=0.0, + generation_time=0.0, + model_used=model, + context_used=chunks + ) + + def generate_stream( + self, + query: str, + chunks: List[Dict[str, Any]], + use_fallback: bool = False + ) -> Generator[str, None, GeneratedAnswer]: + """ + Generate an answer with streaming support. 
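+
+        Note: because this is a generator, the final GeneratedAnswer is the
+        generator's return value (surfaced as StopIteration.value when the
+        generator is driven manually), not a yielded item.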
+ + Args: + query: User's question + chunks: Retrieved document chunks + use_fallback: Whether to use fallback model + + Yields: + Partial answer strings + + Returns: + Final GeneratedAnswer object + """ + start_time = datetime.now() + model = self.fallback_model if use_fallback else self.primary_model + + # Check for no-context or very poor context situation + if not chunks or all(len(chunk.get('content', chunk.get('text', ''))) < 20 for chunk in chunks): + # Handle no-context situation with brief, professional refusal + user_prompt = f"""Context: [NO RELEVANT CONTEXT FOUND] + +Question: {query} + +INSTRUCTION: Respond with exactly this brief message: + +"This information isn't available in the provided documents." + +DO NOT elaborate, explain, or add any other information.""" + else: + # Format context from chunks + context = self._format_context(chunks) + + # Create concise prompt for faster generation + user_prompt = f"""Context: +{context} + +Question: {query} + +Instructions: Answer using only the context above. Cite with [chunk_1], [chunk_2] etc. + +Answer:""" + + try: + # Generate streaming response + stream = self.client.chat( + model=model, + messages=[ + {"role": "system", "content": self._create_system_prompt()}, + {"role": "user", "content": user_prompt} + ], + options={ + "temperature": self.temperature, + "num_predict": min(self.max_tokens, 300), # Reduce max tokens for speed + "top_k": 40, # Optimize sampling for speed + "top_p": 0.9, + "repeat_penalty": 1.1 + }, + stream=True + ) + + # Collect full answer while streaming + full_answer = "" + for chunk in stream: + if 'message' in chunk and 'content' in chunk['message']: + partial = chunk['message']['content'] + full_answer += partial + yield partial + + # Process complete answer + clean_answer, citations = self._extract_citations(full_answer, chunks) + confidence = self._calculate_confidence(clean_answer, citations, chunks) + generation_time = (datetime.now() - start_time).total_seconds() + + return GeneratedAnswer( + answer=clean_answer, + citations=citations, + confidence_score=confidence, + generation_time=generation_time, + model_used=model, + context_used=chunks + ) + + except Exception as e: + logger.error(f"Error in streaming generation: {e}") + yield "I apologize, but I encountered an error while generating the answer." + + return GeneratedAnswer( + answer="Error occurred during generation.", + citations=[], + confidence_score=0.0, + generation_time=0.0, + model_used=model, + context_used=chunks + ) + + def format_answer_with_citations(self, generated_answer: GeneratedAnswer) -> str: + """ + Format the generated answer with citations for display. + + Args: + generated_answer: GeneratedAnswer object + + Returns: + Formatted string with answer and citations + """ + formatted = f"{generated_answer.answer}\n\n" + + if generated_answer.citations: + formatted += "**Sources:**\n" + for i, citation in enumerate(generated_answer.citations, 1): + formatted += f"{i}. 
{citation.source_file} (Page {citation.page_number})\n" + + formatted += f"\n*Confidence: {generated_answer.confidence_score:.1%} | " + formatted += f"Model: {generated_answer.model_used} | " + formatted += f"Time: {generated_answer.generation_time:.2f}s*" + + return formatted + + +if __name__ == "__main__": + # Example usage + generator = AnswerGenerator() + + # Example chunks (would come from retrieval system) + example_chunks = [ + { + "id": "chunk_1", + "content": "RISC-V is an open-source instruction set architecture (ISA) based on reduced instruction set computer (RISC) principles.", + "metadata": {"page_number": 1, "source": "riscv-spec.pdf"}, + "score": 0.95 + }, + { + "id": "chunk_2", + "content": "The RISC-V ISA is designed to support a wide range of implementations including 32-bit, 64-bit, and 128-bit variants.", + "metadata": {"page_number": 2, "source": "riscv-spec.pdf"}, + "score": 0.89 + } + ] + + # Generate answer + result = generator.generate( + query="What is RISC-V?", + chunks=example_chunks + ) + + # Display formatted result + print(generator.format_answer_with_citations(result)) \ No newline at end of file diff --git a/shared_utils/generation/chain_of_thought_engine.py b/shared_utils/generation/chain_of_thought_engine.py new file mode 100644 index 0000000000000000000000000000000000000000..31ff08a00312658a9e8b688af95aa37ce98ff105 --- /dev/null +++ b/shared_utils/generation/chain_of_thought_engine.py @@ -0,0 +1,565 @@ +""" +Chain-of-Thought Reasoning Engine for Complex Technical Queries. + +This module provides structured reasoning capabilities for complex technical +questions that require multi-step analysis and implementation guidance. +""" + +from typing import Dict, List, Optional, Any, Tuple +from dataclasses import dataclass +from enum import Enum +import re + +from .prompt_templates import QueryType, PromptTemplate + + +class ReasoningStep(Enum): + """Types of reasoning steps in chain-of-thought.""" + ANALYSIS = "analysis" + DECOMPOSITION = "decomposition" + SYNTHESIS = "synthesis" + VALIDATION = "validation" + IMPLEMENTATION = "implementation" + + +@dataclass +class ChainStep: + """Represents a single step in chain-of-thought reasoning.""" + step_type: ReasoningStep + description: str + prompt_addition: str + requires_context: bool = True + + +class ChainOfThoughtEngine: + """ + Engine for generating chain-of-thought reasoning prompts for complex technical queries. + + Features: + - Multi-step reasoning for complex implementations + - Context-aware step generation + - Query type specific reasoning chains + - Validation and error checking steps + """ + + def __init__(self): + """Initialize the chain-of-thought engine.""" + self.reasoning_chains = self._initialize_reasoning_chains() + + def _initialize_reasoning_chains(self) -> Dict[QueryType, List[ChainStep]]: + """Initialize reasoning chains for different query types.""" + return { + QueryType.IMPLEMENTATION: [ + ChainStep( + step_type=ReasoningStep.ANALYSIS, + description="Analyze the implementation requirements", + prompt_addition=""" +First, let me analyze what needs to be implemented: +1. What is the specific goal or functionality required? +2. What are the key components or modules involved? +3. Are there any hardware or software constraints mentioned?""" + ), + ChainStep( + step_type=ReasoningStep.DECOMPOSITION, + description="Break down into implementation steps", + prompt_addition=""" +Next, let me break this down into logical implementation steps: +1. What are the prerequisites and dependencies? +2. 
What is the logical sequence of implementation? +3. Which steps are critical and which are optional?""" + ), + ChainStep( + step_type=ReasoningStep.SYNTHESIS, + description="Synthesize the complete solution", + prompt_addition=""" +Now I'll synthesize the complete solution: +1. How do the individual steps connect together? +2. What code examples or configurations are needed? +3. What are the key integration points?""" + ), + ChainStep( + step_type=ReasoningStep.VALIDATION, + description="Consider validation and error handling", + prompt_addition=""" +Finally, let me consider validation and potential issues: +1. How can we verify the implementation works? +2. What are common pitfalls or error conditions? +3. What debugging or troubleshooting steps are important?""" + ) + ], + + QueryType.COMPARISON: [ + ChainStep( + step_type=ReasoningStep.ANALYSIS, + description="Analyze items being compared", + prompt_addition=""" +Let me start by analyzing what's being compared: +1. What are the specific items or concepts being compared? +2. What aspects or dimensions are relevant for comparison? +3. What context or use case should guide the comparison?""" + ), + ChainStep( + step_type=ReasoningStep.DECOMPOSITION, + description="Break down comparison criteria", + prompt_addition=""" +Next, let me identify the key comparison criteria: +1. What are the technical specifications or features to compare? +2. What are the performance characteristics? +3. What are the practical considerations (cost, complexity, etc.)?""" + ), + ChainStep( + step_type=ReasoningStep.SYNTHESIS, + description="Synthesize comparison results", + prompt_addition=""" +Now I'll synthesize the comparison: +1. How do the items compare on each criterion? +2. What are the key trade-offs and differences? +3. What recommendations can be made for different scenarios?""" + ) + ], + + QueryType.TROUBLESHOOTING: [ + ChainStep( + step_type=ReasoningStep.ANALYSIS, + description="Analyze the problem", + prompt_addition=""" +Let me start by analyzing the problem: +1. What are the specific symptoms or error conditions? +2. What system or component is affected? +3. What was the expected vs actual behavior?""" + ), + ChainStep( + step_type=ReasoningStep.DECOMPOSITION, + description="Identify potential root causes", + prompt_addition=""" +Next, let me identify potential root causes: +1. What are the most likely causes based on the symptoms? +2. What system components could be involved? +3. What external factors might contribute to the issue?""" + ), + ChainStep( + step_type=ReasoningStep.VALIDATION, + description="Develop diagnostic approach", + prompt_addition=""" +Now I'll develop a diagnostic approach: +1. What tests or checks can isolate the root cause? +2. What is the recommended sequence of diagnostic steps? +3. How can we verify the fix once implemented?""" + ) + ], + + QueryType.HARDWARE_CONSTRAINT: [ + ChainStep( + step_type=ReasoningStep.ANALYSIS, + description="Analyze hardware requirements", + prompt_addition=""" +Let me analyze the hardware requirements: +1. What are the specific hardware resources needed? +2. What are the performance requirements? +3. What are the power and size constraints?""" + ), + ChainStep( + step_type=ReasoningStep.DECOMPOSITION, + description="Break down resource utilization", + prompt_addition=""" +Next, let me break down resource utilization: +1. How much memory (RAM/Flash) is required? +2. What are the processing requirements (CPU/DSP)? +3. 
What I/O and peripheral requirements exist?""" + ), + ChainStep( + step_type=ReasoningStep.SYNTHESIS, + description="Evaluate feasibility and alternatives", + prompt_addition=""" +Now I'll evaluate feasibility: +1. Can the requirements be met with the available hardware? +2. What optimizations might be needed? +3. What are alternative approaches if constraints are exceeded?""" + ) + ] + } + + def generate_chain_of_thought_prompt( + self, + query: str, + query_type: QueryType, + context: str, + base_template: PromptTemplate + ) -> Dict[str, str]: + """ + Generate a chain-of-thought enhanced prompt. + + Args: + query: User's question + query_type: Type of query + context: Retrieved context + base_template: Base prompt template + + Returns: + Enhanced prompt with chain-of-thought reasoning + """ + # Get reasoning chain for query type + reasoning_chain = self.reasoning_chains.get(query_type, []) + + if not reasoning_chain: + # Fall back to generic reasoning for unsupported types + reasoning_chain = self._generate_generic_reasoning_chain(query) + + # Build chain-of-thought prompt + cot_prompt = self._build_cot_prompt(reasoning_chain, query, context) + + # Enhance system prompt + enhanced_system = base_template.system_prompt + """ + +CHAIN-OF-THOUGHT REASONING: You will approach this question using structured reasoning. +Work through each step methodically before providing your final answer. +Show your reasoning process clearly, then provide a comprehensive final answer.""" + + # Enhance user prompt + enhanced_user = f"""{base_template.context_format.format(context=context)} + +{base_template.query_format.format(query=query)} + +{cot_prompt} + +{base_template.answer_guidelines} + +After working through your reasoning, provide your final answer in the requested format.""" + + return { + "system": enhanced_system, + "user": enhanced_user + } + + def _build_cot_prompt( + self, + reasoning_chain: List[ChainStep], + query: str, + context: str + ) -> str: + """ + Build the chain-of-thought prompt section. + + Args: + reasoning_chain: List of reasoning steps + query: User's question + context: Retrieved context + + Returns: + Chain-of-thought prompt text + """ + cot_sections = [ + "REASONING PROCESS:", + "Work through this step-by-step using the following reasoning framework:", + "" + ] + + for i, step in enumerate(reasoning_chain, 1): + cot_sections.append(f"Step {i}: {step.description}") + cot_sections.append(step.prompt_addition) + cot_sections.append("") + + cot_sections.extend([ + "STRUCTURED REASONING:", + "Now work through each step above, referencing the provided context where relevant.", + "Use [chunk_X] citations for your reasoning at each step.", + "" + ]) + + return "\n".join(cot_sections) + + def _generate_generic_reasoning_chain(self, query: str) -> List[ChainStep]: + """ + Generate a generic reasoning chain for unsupported query types. + + Args: + query: User's question + + Returns: + List of generic reasoning steps + """ + # Analyze query complexity to determine appropriate steps + complexity_indicators = { + "multi_part": ["and", "also", "additionally", "furthermore"], + "causal": ["why", "because", "cause", "reason"], + "conditional": ["if", "when", "unless", "provided that"], + "comparative": ["better", "worse", "compare", "versus", "vs"] + } + + query_lower = query.lower() + steps = [] + + # Always start with analysis + steps.append(ChainStep( + step_type=ReasoningStep.ANALYSIS, + description="Analyze the question", + prompt_addition=""" +Let me start by analyzing the question: +1. 
What is the core question being asked? +2. What context or domain knowledge is needed? +3. Are there multiple parts to this question?""" + )) + + # Add decomposition for complex queries + if any(indicator in query_lower for indicators in complexity_indicators.values() for indicator in indicators): + steps.append(ChainStep( + step_type=ReasoningStep.DECOMPOSITION, + description="Break down the question", + prompt_addition=""" +Let me break this down into components: +1. What are the key concepts or elements involved? +2. How do these elements relate to each other? +3. What information do I need to address each part?""" + )) + + # Always end with synthesis + steps.append(ChainStep( + step_type=ReasoningStep.SYNTHESIS, + description="Synthesize the answer", + prompt_addition=""" +Now I'll synthesize a comprehensive answer: +1. How do all the pieces fit together? +2. What is the most complete and accurate response? +3. Are there any important caveats or limitations?""" + )) + + return steps + + def create_reasoning_validation_prompt( + self, + query: str, + proposed_answer: str, + context: str + ) -> str: + """ + Create a prompt for validating chain-of-thought reasoning. + + Args: + query: Original query + proposed_answer: Generated answer to validate + context: Context used for the answer + + Returns: + Validation prompt + """ + return f""" +REASONING VALIDATION TASK: + +Original Query: {query} + +Proposed Answer: {proposed_answer} + +Context Used: {context} + +Please validate the reasoning in the proposed answer by checking: + +1. LOGICAL CONSISTENCY: + - Are the reasoning steps logically connected? + - Are there any contradictions or gaps in logic? + - Does the conclusion follow from the premises? + +2. FACTUAL ACCURACY: + - Are the facts and technical details correct? + - Are the citations appropriate and accurate? + - Is the information consistent with the provided context? + +3. COMPLETENESS: + - Does the answer address all parts of the question? + - Are important considerations or caveats mentioned? + - Is the level of detail appropriate for the question? + +4. CLARITY: + - Is the reasoning easy to follow? + - Are technical terms used correctly? + - Is the structure logical and well-organized? + +Provide your validation assessment with specific feedback on any issues found. +""" + + def extract_reasoning_steps(self, cot_response: str) -> List[Dict[str, str]]: + """ + Extract reasoning steps from a chain-of-thought response. + + Args: + cot_response: Response containing chain-of-thought reasoning + + Returns: + List of extracted reasoning steps + """ + steps = [] + + # Look for step patterns + step_patterns = [ + r"Step \d+:?\s*(.+?)(?=Step \d+|$)", + r"First,?\s*(.+?)(?=Next,?|Second,?|Then,?|Finally,?|$)", + r"Next,?\s*(.+?)(?=Then,?|Finally,?|Now,?|$)", + r"Then,?\s*(.+?)(?=Finally,?|Now,?|$)", + r"Finally,?\s*(.+?)(?=\n\n|$)" + ] + + for pattern in step_patterns: + matches = re.findall(pattern, cot_response, re.DOTALL | re.IGNORECASE) + for match in matches: + if match.strip(): + steps.append({ + "step_text": match.strip(), + "pattern": pattern + }) + + return steps + + def evaluate_reasoning_quality(self, reasoning_steps: List[Dict[str, str]]) -> Dict[str, float]: + """ + Evaluate the quality of chain-of-thought reasoning. 
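+
+        The overall score is a weighted sum: 0.3 * logical_flow +
+        0.3 * technical_depth + 0.2 * citation_usage + 0.2 * completeness.
+        Illustrative arithmetic: component scores of 0.5, 0.4, 1.0, and 0.66
+        give 0.15 + 0.12 + 0.20 + 0.132, i.e. about 0.60.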
+ + Args: + reasoning_steps: List of reasoning steps + + Returns: + Dictionary of quality metrics + """ + if not reasoning_steps: + return {"overall_quality": 0.0, "step_count": 0} + + # Evaluate different aspects + metrics = { + "step_count": len(reasoning_steps), + "logical_flow": self._evaluate_logical_flow(reasoning_steps), + "technical_depth": self._evaluate_technical_depth(reasoning_steps), + "citation_usage": self._evaluate_citation_usage(reasoning_steps), + "completeness": self._evaluate_completeness(reasoning_steps) + } + + # Calculate overall quality + quality_weights = { + "logical_flow": 0.3, + "technical_depth": 0.3, + "citation_usage": 0.2, + "completeness": 0.2 + } + + overall_quality = sum( + metrics[key] * quality_weights[key] + for key in quality_weights + ) + + metrics["overall_quality"] = overall_quality + + return metrics + + def _evaluate_logical_flow(self, steps: List[Dict[str, str]]) -> float: + """Evaluate logical flow between reasoning steps.""" + if len(steps) < 2: + return 0.5 + + # Check for logical connectors + connectors = ["therefore", "thus", "because", "since", "as a result", "consequently"] + connector_count = 0 + + for step in steps: + step_text = step["step_text"].lower() + if any(connector in step_text for connector in connectors): + connector_count += 1 + + return min(connector_count / len(steps), 1.0) + + def _evaluate_technical_depth(self, steps: List[Dict[str, str]]) -> float: + """Evaluate technical depth of reasoning.""" + technical_terms = [ + "implementation", "architecture", "algorithm", "protocol", "specification", + "optimization", "configuration", "register", "memory", "hardware", + "software", "system", "component", "module", "interface" + ] + + total_terms = 0 + total_words = 0 + + for step in steps: + words = step["step_text"].lower().split() + total_words += len(words) + + for term in technical_terms: + total_terms += words.count(term) + + return min(total_terms / max(total_words, 1) * 100, 1.0) + + def _evaluate_citation_usage(self, steps: List[Dict[str, str]]) -> float: + """Evaluate citation usage in reasoning.""" + citation_pattern = r'\[chunk_\d+\]' + total_citations = 0 + + for step in steps: + citations = re.findall(citation_pattern, step["step_text"]) + total_citations += len(citations) + + # Good reasoning should have at least one citation per step + return min(total_citations / len(steps), 1.0) + + def _evaluate_completeness(self, steps: List[Dict[str, str]]) -> float: + """Evaluate completeness of reasoning.""" + # Check for key reasoning elements + completeness_indicators = [ + "analysis", "consider", "examine", "evaluate", + "conclusion", "summary", "result", "therefore", + "requirement", "constraint", "limitation", "important" + ] + + indicator_count = 0 + for step in steps: + step_text = step["step_text"].lower() + for indicator in completeness_indicators: + if indicator in step_text: + indicator_count += 1 + break + + return indicator_count / len(steps) + + +# Example usage +if __name__ == "__main__": + # Initialize engine + cot_engine = ChainOfThoughtEngine() + + # Example implementation query + query = "How do I implement a real-time task scheduler in FreeRTOS with priority inheritance?" + query_type = QueryType.IMPLEMENTATION + context = "FreeRTOS supports priority-based scheduling with optional priority inheritance..." 
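+
+    # Illustrative sketch (not part of the original example): in a full pipeline the
+    # query type would normally come from a classifier rather than being hard-coded.
+    # A minimal keyword heuristic for spotting implementation-style questions:
+    def _looks_like_implementation_query(q: str) -> bool:
+        markers = ("how do i", "implement", "configure", "set up")
+        return any(marker in q.lower() for marker in markers)
+
+    if _looks_like_implementation_query(query):
+        query_type = QueryType.IMPLEMENTATION  # same value the example assigns above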
+ + # Create a basic template + base_template = PromptTemplate( + system_prompt="You are a technical assistant.", + context_format="Context: {context}", + query_format="Question: {query}", + answer_guidelines="Provide a structured answer." + ) + + # Generate chain-of-thought prompt + cot_prompt = cot_engine.generate_chain_of_thought_prompt( + query=query, + query_type=query_type, + context=context, + base_template=base_template + ) + + print("Chain-of-Thought Enhanced Prompt:") + print("=" * 50) + print("System:", cot_prompt["system"][:200], "...") + print("User:", cot_prompt["user"][:300], "...") + print("=" * 50) + + # Example reasoning evaluation + example_response = """ + Step 1: Let me analyze the requirements + FreeRTOS provides priority-based scheduling [chunk_1]... + + Step 2: Breaking down the implementation + Priority inheritance requires mutex implementation [chunk_2]... + + Step 3: Synthesizing the solution + Therefore, we need to configure priority inheritance in FreeRTOS [chunk_3]... + """ + + steps = cot_engine.extract_reasoning_steps(example_response) + quality = cot_engine.evaluate_reasoning_quality(steps) + + print(f"Reasoning Quality: {quality}") \ No newline at end of file diff --git a/shared_utils/generation/hf_answer_generator.py b/shared_utils/generation/hf_answer_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..ba5720f44db57c9e4388b573177a87b879317b38 --- /dev/null +++ b/shared_utils/generation/hf_answer_generator.py @@ -0,0 +1,881 @@ +""" +HuggingFace API-based answer generation for deployment environments. + +This module provides answer generation using HuggingFace's Inference API, +optimized for cloud deployment where local LLMs aren't feasible. +""" + +import json +import logging +from dataclasses import dataclass +from typing import List, Dict, Any, Optional, Generator, Tuple +from datetime import datetime +import re +from pathlib import Path +import requests +import os +import sys + +# Import technical prompt templates +from .prompt_templates import TechnicalPromptTemplates + +# Import standard interfaces (add this for the adapter) +try: + from pathlib import Path + import sys + project_root = Path(__file__).parent.parent.parent.parent.parent + sys.path.append(str(project_root)) + from src.core.interfaces import Document, Answer, AnswerGenerator +except ImportError: + # Fallback for standalone usage + Document = None + Answer = None + AnswerGenerator = object + +logger = logging.getLogger(__name__) + + +@dataclass +class Citation: + """Represents a citation to a source document chunk.""" + chunk_id: str + page_number: int + source_file: str + relevance_score: float + text_snippet: str + + +@dataclass +class GeneratedAnswer: + """Represents a generated answer with citations.""" + answer: str + citations: List[Citation] + confidence_score: float + generation_time: float + model_used: str + context_used: List[Dict[str, Any]] + + +class HuggingFaceAnswerGenerator(AnswerGenerator if AnswerGenerator != object else object): + """ + Generates answers using HuggingFace Inference API with hybrid reliability. 
+ + 🎯 HYBRID APPROACH - Best of Both Worlds: + - Primary: High-quality open models (Zephyr-7B, Mistral-7B-Instruct) + - Fallback: Reliable classics (DialoGPT-medium) + - Foundation: HF GPT's proven Docker + auth setup + - Pro Benefits: Better rate limits, priority processing + + Optimized for deployment environments with: + - Fast API-based inference + - No local model requirements + - Citation extraction and formatting + - Confidence scoring + - Automatic fallback for reliability + """ + + def __init__( + self, + model_name: str = "sshleifer/distilbart-cnn-12-6", + api_token: Optional[str] = None, + temperature: float = 0.3, + max_tokens: int = 512 + ): + """ + Initialize the HuggingFace answer generator. + + Args: + model_name: HuggingFace model to use + api_token: HF API token (optional, uses free tier if None) + temperature: Generation temperature (0.0-1.0) + max_tokens: Maximum tokens to generate + """ + self.model_name = model_name + # Try multiple common token environment variable names + self.api_token = (api_token or + os.getenv("HUGGINGFACE_API_TOKEN") or + os.getenv("HF_TOKEN") or + os.getenv("HF_API_TOKEN")) + self.temperature = temperature + self.max_tokens = max_tokens + + # Hybrid approach: Classic API + fallback models + self.api_url = f"https://api-inference.huggingface.co/models/{model_name}" + + # Prepare headers + self.headers = {"Content-Type": "application/json"} + self._auth_failed = False # Track if auth has failed + if self.api_token: + self.headers["Authorization"] = f"Bearer {self.api_token}" + logger.info("Using authenticated HuggingFace API") + else: + logger.info("Using free HuggingFace API (rate limited)") + + # Only include models that actually work based on tests + self.fallback_models = [ + "deepset/roberta-base-squad2", # Q&A model - perfect for RAG + "sshleifer/distilbart-cnn-12-6", # Summarization - also good + "facebook/bart-base", # Base BART - works but needs right format + ] + + def _make_api_request(self, url: str, payload: dict, timeout: int = 30) -> requests.Response: + """Make API request with automatic 401 handling.""" + # Use current headers (may have been updated if auth failed) + headers = self.headers.copy() + + # If we've already had auth failure, don't include the token + if self._auth_failed and "Authorization" in headers: + del headers["Authorization"] + + response = requests.post(url, headers=headers, json=payload, timeout=timeout) + + # Handle 401 error + if response.status_code == 401 and not self._auth_failed and self.api_token: + logger.error(f"API request failed: 401 Unauthorized") + logger.error(f"Response body: {response.text}") + logger.warning("Token appears invalid, retrying without authentication...") + self._auth_failed = True + # Remove auth header + if "Authorization" in self.headers: + del self.headers["Authorization"] + headers = self.headers.copy() + # Retry without auth + response = requests.post(url, headers=headers, json=payload, timeout=timeout) + if response.status_code == 401: + logger.error("Still getting 401 even without auth token") + logger.error(f"Response body: {response.text}") + + return response + + def _call_api_with_model(self, prompt: str, model_name: str) -> str: + """Call API with a specific model (for fallback support).""" + fallback_url = f"https://api-inference.huggingface.co/models/{model_name}" + + # SIMPLIFIED payload that works + payload = {"inputs": prompt} + + # Use helper method with 401 handling + response = self._make_api_request(fallback_url, payload) + + response.raise_for_status() + 
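+
+        # The classic Inference API returns either a list of generation dicts or a
+        # single dict depending on the model pipeline, so both shapes are handled below.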
result = response.json() + + # Handle response + if isinstance(result, list) and len(result) > 0: + if isinstance(result[0], dict): + return result[0].get("generated_text", "").strip() + else: + return str(result[0]).strip() + elif isinstance(result, dict): + return result.get("generated_text", "").strip() + else: + return str(result).strip() + + def _create_system_prompt(self) -> str: + """Create system prompt optimized for the model type.""" + if "squad" in self.model_name.lower() or "roberta" in self.model_name.lower(): + # RoBERTa Squad2 uses question/context format - no system prompt needed + return "" + elif "gpt2" in self.model_name.lower() or "distilgpt2" in self.model_name.lower(): + # GPT-2 style completion prompt - simpler is better + return "Based on the following context, answer the question.\n\nContext: " + elif "llama" in self.model_name.lower(): + # Llama-2 chat format + return """[INST] You are a helpful technical documentation assistant. Answer the user's question based only on the provided context. Always cite sources using [chunk_X] format. + +Context:""" + elif "flan" in self.model_name.lower() or "t5" in self.model_name.lower(): + # Flan-T5 instruction format - simple and direct + return """Answer the question based on the context below. Cite sources using [chunk_X] format. + +Context: """ + elif "falcon" in self.model_name.lower(): + # Falcon instruction format + return """### Instruction: Answer based on the context and cite sources with [chunk_X]. + +### Context: """ + elif "bart" in self.model_name.lower(): + # BART summarization format + return """Summarize the answer to the question from the context. Use [chunk_X] for citations. + +Context: """ + else: + # Default instruction prompt for other models + return """You are a technical documentation assistant that provides clear, accurate answers based on the provided context. + +CORE PRINCIPLES: +1. ANSWER DIRECTLY: If context contains the answer, provide it clearly and confidently +2. BE CONCISE: Keep responses focused and avoid unnecessary uncertainty language +3. CITE ACCURATELY: Use [chunk_X] citations for every fact from context + +RESPONSE GUIDELINES: +- If context has sufficient information → Answer directly and confidently +- If context has partial information → Answer what's available, note what's missing briefly +- If context is irrelevant → Brief refusal: "This information isn't available in the provided documents" + +CITATION FORMAT: +- Use [chunk_1], [chunk_2] etc. for all facts from context +- Example: "According to [chunk_1], RISC-V is an open-source architecture." + +Be direct, confident, and accurate. If the context answers the question, provide that answer clearly.""" + + def _format_context(self, chunks: List[Dict[str, Any]]) -> str: + """ + Format retrieved chunks into context for the LLM. + + Args: + chunks: List of retrieved chunks with metadata + + Returns: + Formatted context string + """ + context_parts = [] + + for i, chunk in enumerate(chunks): + chunk_text = chunk.get('content', chunk.get('text', '')) + page_num = chunk.get('metadata', {}).get('page_number', 'unknown') + source = chunk.get('metadata', {}).get('source', 'unknown') + + context_parts.append( + f"[chunk_{i+1}] (Page {page_num} from {source}):\n{chunk_text}\n" + ) + + return "\n---\n".join(context_parts) + + def _call_api(self, prompt: str) -> str: + """ + Call HuggingFace Inference API. 
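+
+        A 503 response (model still loading) is retried once after a short wait; a
+        404 (model unavailable) triggers the fallback models configured in __init__.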
+ + Args: + prompt: Input prompt for the model + + Returns: + Generated text response + """ + # Validate prompt + if not prompt or len(prompt.strip()) < 5: + logger.warning(f"Prompt too short: '{prompt}' - padding it") + prompt = f"Please provide information about: {prompt}. Based on the context, give a detailed answer." + + # Model-specific payload formatting + if "squad" in self.model_name.lower() or "roberta" in self.model_name.lower(): + # RoBERTa Squad2 needs question and context separately + # Parse the structured prompt format we create + if "Context:" in prompt and "Question:" in prompt: + # Split by the markers we use + parts = prompt.split("Question:") + if len(parts) == 2: + context_part = parts[0].replace("Context:", "").strip() + question_part = parts[1].strip() + else: + # Fallback + question_part = "What is this about?" + context_part = prompt + else: + # Fallback for unexpected format + question_part = "What is this about?" + context_part = prompt + + # Clean up the context and question + context_part = context_part.replace("---", "").strip() + if not question_part or len(question_part.strip()) < 3: + question_part = "What is the main information?" + + # Debug output + print(f"🔍 Squad2 Question: {question_part[:100]}...") + print(f"🔍 Squad2 Context: {context_part[:200]}...") + + payload = { + "inputs": { + "question": question_part, + "context": context_part + } + } + elif "bart" in self.model_name.lower() or "distilbart" in self.model_name.lower(): + # BART/DistilBART for summarization + if len(prompt) < 50: + prompt = f"{prompt} Please provide a comprehensive answer based on the available information." + + payload = { + "inputs": prompt, + "parameters": { + "max_length": 150, + "min_length": 10, + "do_sample": False + } + } + else: + # Simple payload for other models + payload = {"inputs": prompt} + + try: + logger.info(f"Calling API URL: {self.api_url}") + logger.info(f"Headers: {self.headers}") + logger.info(f"Payload: {payload}") + + # Use helper method with 401 handling + response = self._make_api_request(self.api_url, payload) + + logger.info(f"Response status: {response.status_code}") + logger.info(f"Response headers: {response.headers}") + + if response.status_code == 503: + # Model is loading, wait and retry + logger.warning("Model loading, waiting 20 seconds...") + import time + time.sleep(20) + response = self._make_api_request(self.api_url, payload) + logger.info(f"Retry response status: {response.status_code}") + + elif response.status_code == 404: + logger.error(f"Model not found: {self.model_name}") + logger.error(f"Response text: {response.text}") + # Try fallback models + for fallback_model in self.fallback_models: + if fallback_model != self.model_name: + logger.info(f"Trying fallback model: {fallback_model}") + try: + return self._call_api_with_model(prompt, fallback_model) + except Exception as e: + logger.warning(f"Fallback model {fallback_model} failed: {e}") + continue + return "All models are currently unavailable. Please try again later." 
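+
+            # Any other error status (beyond the 503/404 cases handled above) is
+            # raised here and reported by the except handlers further down.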
+ + response.raise_for_status() + result = response.json() + + # Handle different response formats based on model type + print(f"🔍 API Response type: {type(result)}") + print(f"🔍 API Response preview: {str(result)[:300]}...") + + if isinstance(result, dict) and "answer" in result: + # RoBERTa Squad2 format: {"answer": "...", "score": ..., "start": ..., "end": ...} + answer = result["answer"].strip() + print(f"🔍 Squad2 extracted answer: {answer}") + return answer + elif isinstance(result, list) and len(result) > 0: + # Check for DistilBART format (returns dict with summary_text) + if isinstance(result[0], dict) and "summary_text" in result[0]: + return result[0]["summary_text"].strip() + # Check for nested list (BART format: [[...]]) + elif isinstance(result[0], list) and len(result[0]) > 0: + if isinstance(result[0][0], dict): + return result[0][0].get("summary_text", str(result[0][0])).strip() + else: + # BART base returns embeddings - not useful for text generation + logger.warning("BART returned embeddings instead of text") + return "Model returned embeddings instead of text. Please try a different model." + # Regular list format + elif isinstance(result[0], dict): + # Try different keys that models might use + text = (result[0].get("generated_text", "") or + result[0].get("summary_text", "") or + result[0].get("translation_text", "") or + result[0].get("answer", "") or + str(result[0])) + # Remove the input prompt from the output if present + if isinstance(prompt, str) and text.startswith(prompt): + text = text[len(prompt):].strip() + return text + else: + return str(result[0]).strip() + elif isinstance(result, dict): + # Some models return dict directly + text = (result.get("generated_text", "") or + result.get("summary_text", "") or + result.get("translation_text", "") or + result.get("answer", "") or + str(result)) + # Remove input prompt if model included it + if isinstance(prompt, str) and text.startswith(prompt): + text = text[len(prompt):].strip() + return text + elif isinstance(result, str): + return result.strip() + else: + logger.error(f"Unexpected response format: {type(result)} - {result}") + return "I apologize, but I couldn't generate a response." + + except requests.exceptions.RequestException as e: + logger.error(f"API request failed: {e}") + if hasattr(e, 'response') and e.response is not None: + logger.error(f"Response status: {e.response.status_code}") + logger.error(f"Response body: {e.response.text}") + return f"API Error: {str(e)}. Using free tier? Try adding an API token." + except Exception as e: + logger.error(f"Unexpected error: {e}") + import traceback + logger.error(f"Traceback: {traceback.format_exc()}") + return f"Error: {str(e)}. Please check logs for details." + + def _extract_citations(self, answer: str, chunks: List[Dict[str, Any]]) -> Tuple[str, List[Citation]]: + """ + Extract citations from the generated answer and integrate them naturally. 
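+
+        Explicit [chunk_X] markers are rewritten as natural source names taken from
+        the chunk metadata; if no markers are present, citations are created for the
+        top chunks as a fallback.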
+ + Args: + answer: Generated answer with [chunk_X] citations + chunks: Original chunks used for context + + Returns: + Tuple of (natural_answer, citations) + """ + citations = [] + citation_pattern = r'\[chunk_(\d+)\]' + + cited_chunks = set() + + # Find [chunk_X] citations and collect cited chunks + matches = re.finditer(citation_pattern, answer) + for match in matches: + chunk_idx = int(match.group(1)) - 1 # Convert to 0-based index + if 0 <= chunk_idx < len(chunks): + cited_chunks.add(chunk_idx) + + # FALLBACK: If no explicit citations found but we have an answer and chunks, + # create citations for the top chunks that were likely used + if not cited_chunks and chunks and len(answer.strip()) > 50: + # Use the top chunks that were provided as likely sources + num_fallback_citations = min(3, len(chunks)) # Use top 3 chunks max + cited_chunks = set(range(num_fallback_citations)) + print(f"🔧 HF Fallback: Creating {num_fallback_citations} citations for answer without explicit [chunk_X] references", file=sys.stderr, flush=True) + + # Create Citation objects for each cited chunk + chunk_to_source = {} + for idx in cited_chunks: + chunk = chunks[idx] + citation = Citation( + chunk_id=chunk.get('id', f'chunk_{idx}'), + page_number=chunk.get('metadata', {}).get('page_number', 0), + source_file=chunk.get('metadata', {}).get('source', 'unknown'), + relevance_score=chunk.get('score', 0.0), + text_snippet=chunk.get('content', chunk.get('text', ''))[:200] + '...' + ) + citations.append(citation) + + # Map chunk reference to natural source name + source_name = chunk.get('metadata', {}).get('source', 'unknown') + if source_name != 'unknown': + # Use just the filename without extension for natural reference + natural_name = Path(source_name).stem.replace('-', ' ').replace('_', ' ') + chunk_to_source[f'[chunk_{idx+1}]'] = f"the {natural_name} documentation" + else: + chunk_to_source[f'[chunk_{idx+1}]'] = "the documentation" + + # Replace [chunk_X] with natural references instead of removing them + natural_answer = answer + for chunk_ref, natural_ref in chunk_to_source.items(): + natural_answer = natural_answer.replace(chunk_ref, natural_ref) + + # Clean up any remaining unreferenced citations (fallback) + natural_answer = re.sub(r'\[chunk_\d+\]', 'the documentation', natural_answer) + + # Clean up multiple spaces and formatting + natural_answer = re.sub(r'\s+', ' ', natural_answer).strip() + + return natural_answer, citations + + def _calculate_confidence(self, answer: str, citations: List[Citation], chunks: List[Dict[str, Any]]) -> float: + """ + Calculate confidence score for the generated answer. 
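+
+        The base score is derived from the best chunk relevance, uncertainty phrases
+        in the answer reduce it sharply, citation coverage adds a small bonus, and
+        the result is capped at 0.9.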
+ + Args: + answer: Generated answer + citations: Extracted citations + chunks: Retrieved chunks + + Returns: + Confidence score (0.0-1.0) + """ + if not chunks: + return 0.05 # No context = very low confidence + + # Base confidence from context quality + scores = [chunk.get('score', 0) for chunk in chunks] + max_relevance = max(scores) if scores else 0 + + if max_relevance >= 0.8: + confidence = 0.7 # High-quality context + elif max_relevance >= 0.6: + confidence = 0.5 # Good context + elif max_relevance >= 0.4: + confidence = 0.3 # Fair context + else: + confidence = 0.1 # Poor context + + # Uncertainty indicators + uncertainty_phrases = [ + "does not contain sufficient information", + "context does not provide", + "insufficient information", + "cannot determine", + "not available in the provided documents" + ] + + if any(phrase in answer.lower() for phrase in uncertainty_phrases): + return min(0.15, confidence * 0.3) + + # Citation bonus + if citations and chunks: + citation_ratio = len(citations) / min(len(chunks), 3) + confidence += 0.2 * citation_ratio + + return min(confidence, 0.9) # Cap at 90% + + def generate(self, query: str, context: List[Document]) -> Answer: + """ + Generate an answer from query and context documents (standard interface). + + This is the public interface that conforms to the AnswerGenerator protocol. + It handles the conversion between standard Document objects and HuggingFace's + internal chunk format. + + Args: + query: User's question + context: List of relevant Document objects + + Returns: + Answer object conforming to standard interface + + Raises: + ValueError: If query is empty or context is None + """ + if not query.strip(): + raise ValueError("Query cannot be empty") + + if context is None: + raise ValueError("Context cannot be None") + + # Internal adapter: Convert Documents to HuggingFace chunk format + hf_chunks = self._documents_to_hf_chunks(context) + + # Use existing HuggingFace-specific generation logic + hf_result = self._generate_internal(query, hf_chunks) + + # Internal adapter: Convert HuggingFace result to standard Answer + return self._hf_result_to_answer(hf_result, context) + + def _generate_internal( + self, + query: str, + chunks: List[Dict[str, Any]] + ) -> GeneratedAnswer: + """ + Generate an answer based on the query and retrieved chunks. 
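+
+        The prompt is built from TechnicalPromptTemplates and then adapted to the
+        active model family (Squad2, GPT-2, Llama, Mistral, CodeLlama, DistilBART,
+        or a generic instruction format).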
+ + Args: + query: User's question + chunks: Retrieved document chunks + + Returns: + GeneratedAnswer object with answer, citations, and metadata + """ + start_time = datetime.now() + + # Check for no-context situation + if not chunks or all(len(chunk.get('content', chunk.get('text', ''))) < 20 for chunk in chunks): + return GeneratedAnswer( + answer="This information isn't available in the provided documents.", + citations=[], + confidence_score=0.05, + generation_time=0.1, + model_used=self.model_name, + context_used=chunks + ) + + # Format context from chunks + context = self._format_context(chunks) + + # Create prompt using TechnicalPromptTemplates for consistency + prompt_data = TechnicalPromptTemplates.format_prompt_with_template( + query=query, + context=context + ) + + # Format for specific model types + if "squad" in self.model_name.lower() or "roberta" in self.model_name.lower(): + # Squad2 uses special question/context format - handled in _call_api + prompt = f"Context: {context}\n\nQuestion: {query}" + elif "gpt2" in self.model_name.lower() or "distilgpt2" in self.model_name.lower(): + # Simple completion style for GPT-2 + prompt = f"""{prompt_data['system']} + +{prompt_data['user']} + +MANDATORY: Use [chunk_1], [chunk_2] etc. for all facts. + +Answer:""" + elif "llama" in self.model_name.lower(): + # Llama-2 chat format with technical templates + prompt = f"""[INST] {prompt_data['system']} + +{prompt_data['user']} + +MANDATORY: Use [chunk_1], [chunk_2] etc. for all facts. [/INST]""" + elif "mistral" in self.model_name.lower(): + # Mistral instruction format with technical templates + prompt = f"""[INST] {prompt_data['system']} + +{prompt_data['user']} + +MANDATORY: Use [chunk_1], [chunk_2] etc. for all facts. [/INST]""" + elif "codellama" in self.model_name.lower(): + # CodeLlama instruction format with technical templates + prompt = f"""[INST] {prompt_data['system']} + +{prompt_data['user']} + +MANDATORY: Use [chunk_1], [chunk_2] etc. for all facts. [/INST]""" + elif "distilbart" in self.model_name.lower(): + # DistilBART is a summarization model - simpler prompt works better + prompt = f"""Technical Documentation Context: +{context} + +Question: {query} + +Instructions: Provide a technical answer using only the context above. Include source citations.""" + else: + # Default instruction prompt with technical templates + prompt = f"""{prompt_data['system']} + +{prompt_data['user']} + +MANDATORY: Use [chunk_1], [chunk_2] etc. for all factual statements. + +Answer:""" + + # Generate response + try: + answer_with_citations = self._call_api(prompt) + + # Extract and clean citations + clean_answer, citations = self._extract_citations(answer_with_citations, chunks) + + # Calculate confidence + confidence = self._calculate_confidence(clean_answer, citations, chunks) + + # Calculate generation time + generation_time = (datetime.now() - start_time).total_seconds() + + return GeneratedAnswer( + answer=clean_answer, + citations=citations, + confidence_score=confidence, + generation_time=generation_time, + model_used=self.model_name, + context_used=chunks + ) + + except Exception as e: + logger.error(f"Error generating answer: {e}") + return GeneratedAnswer( + answer="I apologize, but I encountered an error while generating the answer. 
Please try again.", + citations=[], + confidence_score=0.0, + generation_time=0.0, + model_used=self.model_name, + context_used=chunks + ) + + def generate_with_custom_prompt( + self, + query: str, + chunks: List[Dict[str, Any]], + custom_prompt: Dict[str, str] + ) -> GeneratedAnswer: + """ + Generate answer using a custom prompt (for adaptive prompting). + + Args: + query: User's question + chunks: Retrieved context chunks + custom_prompt: Dict with 'system' and 'user' prompts + + Returns: + GeneratedAnswer with custom prompt enhancement + """ + start_time = datetime.now() + + # Format context + context = self._format_context(chunks) + + # Build prompt using custom format + if "llama" in self.model_name.lower(): + prompt = f"""[INST] {custom_prompt['system']} + +{custom_prompt['user']} + +MANDATORY: Use [chunk_1], [chunk_2] etc. for all facts. [/INST]""" + elif "mistral" in self.model_name.lower(): + prompt = f"""[INST] {custom_prompt['system']} + +{custom_prompt['user']} + +MANDATORY: Use [chunk_1], [chunk_2] etc. for all facts. [/INST]""" + elif "distilbart" in self.model_name.lower(): + # For BART, use the user prompt directly (it already contains context) + prompt = custom_prompt['user'] + else: + # Default format + prompt = f"""{custom_prompt['system']} + +{custom_prompt['user']} + +MANDATORY: Use [chunk_1], [chunk_2] etc. for all factual statements. + +Answer:""" + + # Generate response + try: + answer_with_citations = self._call_api(prompt) + + # Extract and clean citations + clean_answer, citations = self._extract_citations(answer_with_citations, chunks) + + # Calculate confidence + confidence = self._calculate_confidence(clean_answer, citations, chunks) + + # Calculate generation time + generation_time = (datetime.now() - start_time).total_seconds() + + return GeneratedAnswer( + answer=clean_answer, + citations=citations, + confidence_score=confidence, + generation_time=generation_time, + model_used=self.model_name, + context_used=chunks + ) + + except Exception as e: + logger.error(f"Error generating answer with custom prompt: {e}") + return GeneratedAnswer( + answer="I apologize, but I encountered an error while generating the answer. Please try again.", + citations=[], + confidence_score=0.0, + generation_time=0.0, + model_used=self.model_name, + context_used=chunks + ) + + def format_answer_with_citations(self, generated_answer: GeneratedAnswer) -> str: + """ + Format the generated answer with citations for display. + + Args: + generated_answer: GeneratedAnswer object + + Returns: + Formatted string with answer and citations + """ + formatted = f"{generated_answer.answer}\n\n" + + if generated_answer.citations: + formatted += "**Sources:**\n" + for i, citation in enumerate(generated_answer.citations, 1): + formatted += f"{i}. {citation.source_file} (Page {citation.page_number})\n" + + formatted += f"\n*Confidence: {generated_answer.confidence_score:.1%} | " + formatted += f"Model: {generated_answer.model_used} | " + formatted += f"Time: {generated_answer.generation_time:.2f}s*" + + return formatted + + def _documents_to_hf_chunks(self, documents: List[Document]) -> List[Dict[str, Any]]: + """ + Convert Document objects to HuggingFace's internal chunk format. + + This internal adapter ensures that Document objects are properly formatted + for HuggingFace's processing pipeline while keeping the format requirements + encapsulated within this class. 
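+
+        Each produced chunk has the shape (illustrative values):
+            {"id": "chunk_1", "content": "...", "text": "...", "score": 1.0,
+             "metadata": {"page_number": 1, "source": "spec.pdf", ...}}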
+ + Args: + documents: List of Document objects from the standard interface + + Returns: + List of chunk dictionaries in HuggingFace's expected format + """ + if not documents: + return [] + + chunks = [] + for i, doc in enumerate(documents): + chunk = { + "id": f"chunk_{i+1}", + "content": doc.content, # HuggingFace expects "content" field + "text": doc.content, # Alternative field for compatibility + "score": 1.0, # Default relevance score + "metadata": { + "page_number": doc.metadata.get("start_page", 1), + "source": doc.metadata.get("source", "unknown"), + **doc.metadata # Include all original metadata + } + } + chunks.append(chunk) + + return chunks + + def _hf_result_to_answer(self, hf_result: GeneratedAnswer, original_context: List[Document]) -> Answer: + """ + Convert HuggingFace's GeneratedAnswer to the standard Answer format. + + This internal adapter converts HuggingFace's result format back to the + standard interface format expected by the rest of the system. + + Args: + hf_result: Result from HuggingFace's internal generation + original_context: Original Document objects for sources + + Returns: + Answer object conforming to standard interface + """ + if Answer is None: + # Fallback if standard interface not available + return hf_result + + # Convert to standard Answer format + return Answer( + text=hf_result.answer, + sources=original_context, # Use original Document objects + confidence=hf_result.confidence_score, + metadata={ + "model_used": hf_result.model_used, + "generation_time": hf_result.generation_time, + "citations": [ + { + "chunk_id": cit.chunk_id, + "page_number": cit.page_number, + "source_file": cit.source_file, + "relevance_score": cit.relevance_score, + "text_snippet": cit.text_snippet + } + for cit in hf_result.citations + ], + "provider": "huggingface", + "api_token_used": bool(self.api_token), + "fallback_used": hasattr(self, '_auth_failed') and self._auth_failed + } + ) + + +if __name__ == "__main__": + # Example usage + generator = HuggingFaceAnswerGenerator() + + # Example chunks (would come from retrieval system) + example_chunks = [ + { + "id": "chunk_1", + "content": "RISC-V is an open-source instruction set architecture (ISA) based on reduced instruction set computer (RISC) principles.", + "metadata": {"page_number": 1, "source": "riscv-spec.pdf"}, + "score": 0.95 + } + ] + + # Generate answer + result = generator.generate( + query="What is RISC-V?", + chunks=example_chunks + ) + + # Display formatted result + print(generator.format_answer_with_citations(result)) \ No newline at end of file diff --git a/shared_utils/generation/inference_providers_generator.py b/shared_utils/generation/inference_providers_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..47f888567454b0ec7b5b23dbb2a014b51e4adf5c --- /dev/null +++ b/shared_utils/generation/inference_providers_generator.py @@ -0,0 +1,537 @@ +#!/usr/bin/env python3 +""" +HuggingFace Inference Providers API-based answer generation. + +This module provides answer generation using HuggingFace's new Inference Providers API, +which offers OpenAI-compatible chat completion format for better reliability and consistency. 
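+
+Typical usage (illustrative sketch; requires a HuggingFace API token in the environment):
+
+    generator = InferenceProvidersGenerator()
+    result = generator.generate("What is RISC-V?", retrieved_chunks)
+    print(result.answer, result.confidence_score)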
+""" + +import os +import sys +import logging +import time +from datetime import datetime +from typing import List, Dict, Any, Optional, Tuple +from pathlib import Path +import re + +# Import shared components +from .hf_answer_generator import Citation, GeneratedAnswer +from .prompt_templates import TechnicalPromptTemplates + +# Check if huggingface_hub is new enough for InferenceClient chat completion +try: + from huggingface_hub import InferenceClient + from huggingface_hub import __version__ as hf_hub_version + print(f"🔍 Using huggingface_hub version: {hf_hub_version}", file=sys.stderr, flush=True) +except ImportError: + print("❌ huggingface_hub not found or outdated. Please install: pip install -U huggingface-hub", file=sys.stderr, flush=True) + raise + +logger = logging.getLogger(__name__) + + +class InferenceProvidersGenerator: + """ + Generates answers using HuggingFace Inference Providers API. + + This uses the new OpenAI-compatible chat completion format for better reliability + compared to the classic Inference API. It provides: + - Consistent response format across models + - Better error handling and retry logic + - Support for streaming responses + - Automatic provider selection and failover + """ + + # Models that work well with chat completion format + CHAT_MODELS = [ + "microsoft/DialoGPT-medium", # Proven conversational model + "google/gemma-2-2b-it", # Instruction-tuned, good for Q&A + "meta-llama/Llama-3.2-3B-Instruct", # If available with token + "Qwen/Qwen2.5-1.5B-Instruct", # Small, fast, good quality + ] + + # Fallback to classic API models if chat completion fails + CLASSIC_FALLBACK_MODELS = [ + "google/flan-t5-small", # Good for instructions + "deepset/roberta-base-squad2", # Q&A specific + "facebook/bart-base", # Summarization + ] + + def __init__( + self, + model_name: Optional[str] = None, + api_token: Optional[str] = None, + temperature: float = 0.3, + max_tokens: int = 512, + timeout: int = 30 + ): + """ + Initialize the Inference Providers answer generator. + + Args: + model_name: Model to use (defaults to first available chat model) + api_token: HF API token (uses env vars if not provided) + temperature: Generation temperature (0.0-1.0) + max_tokens: Maximum tokens to generate + timeout: Request timeout in seconds + """ + # Get API token from various sources + self.api_token = ( + api_token or + os.getenv("HUGGINGFACE_API_TOKEN") or + os.getenv("HF_TOKEN") or + os.getenv("HF_API_TOKEN") + ) + + if not self.api_token: + print("⚠️ No HF API token found. 
Inference Providers requires authentication.", file=sys.stderr, flush=True) + print("Set HF_TOKEN, HUGGINGFACE_API_TOKEN, or HF_API_TOKEN environment variable.", file=sys.stderr, flush=True) + raise ValueError("HuggingFace API token required for Inference Providers") + + print(f"✅ Found HF token (starts with: {self.api_token[:8]}...)", file=sys.stderr, flush=True) + + # Initialize client with token + self.client = InferenceClient(token=self.api_token) + self.temperature = temperature + self.max_tokens = max_tokens + self.timeout = timeout + + # Select model + self.model_name = model_name or self.CHAT_MODELS[0] + self.using_chat_completion = True + + print(f"🚀 Initialized Inference Providers with model: {self.model_name}", file=sys.stderr, flush=True) + + # Test the connection + self._test_connection() + + def _test_connection(self): + """Test if the API is accessible and model is available.""" + print(f"🔧 Testing Inference Providers API connection...", file=sys.stderr, flush=True) + + try: + # Try a simple test query + test_messages = [ + {"role": "user", "content": "Hello"} + ] + + # First try chat completion (preferred) + try: + response = self.client.chat_completion( + messages=test_messages, + model=self.model_name, + max_tokens=10, + temperature=0.1 + ) + print(f"✅ Chat completion API working with {self.model_name}", file=sys.stderr, flush=True) + self.using_chat_completion = True + return + except Exception as e: + print(f"⚠️ Chat completion failed for {self.model_name}: {e}", file=sys.stderr, flush=True) + + # Try other chat models + for model in self.CHAT_MODELS: + if model != self.model_name: + try: + print(f"🔄 Trying {model}...", file=sys.stderr, flush=True) + response = self.client.chat_completion( + messages=test_messages, + model=model, + max_tokens=10 + ) + print(f"✅ Found working model: {model}", file=sys.stderr, flush=True) + self.model_name = model + self.using_chat_completion = True + return + except: + continue + + # If chat completion fails, test classic text generation + print("🔄 Falling back to classic text generation API...", file=sys.stderr, flush=True) + for model in self.CLASSIC_FALLBACK_MODELS: + try: + response = self.client.text_generation( + model=model, + prompt="Hello", + max_new_tokens=10 + ) + print(f"✅ Classic API working with fallback model: {model}", file=sys.stderr, flush=True) + self.model_name = model + self.using_chat_completion = False + return + except: + continue + + raise Exception("No working models found in Inference Providers API") + + except Exception as e: + print(f"❌ Inference Providers API test failed: {e}", file=sys.stderr, flush=True) + raise + + def _format_context(self, chunks: List[Dict[str, Any]]) -> str: + """Format retrieved chunks into context string.""" + context_parts = [] + + for i, chunk in enumerate(chunks): + chunk_text = chunk.get('content', chunk.get('text', '')) + page_num = chunk.get('metadata', {}).get('page_number', 'unknown') + source = chunk.get('metadata', {}).get('source', 'unknown') + + context_parts.append( + f"[chunk_{i+1}] (Page {page_num} from {source}):\n{chunk_text}\n" + ) + + return "\n---\n".join(context_parts) + + def _create_messages(self, query: str, context: str) -> List[Dict[str, str]]: + """Create chat messages using TechnicalPromptTemplates.""" + # Get appropriate template based on query type + prompt_data = TechnicalPromptTemplates.format_prompt_with_template( + query=query, + context=context + ) + + # Create messages for chat completion + messages = [ + { + "role": "system", + "content": 
prompt_data['system'] + "\n\nMANDATORY: Use [chunk_X] citations for all facts." + }, + { + "role": "user", + "content": prompt_data['user'] + } + ] + + return messages + + def _call_chat_completion(self, messages: List[Dict[str, str]]) -> str: + """Call the chat completion API.""" + try: + print(f"🤖 Calling Inference Providers chat completion with {self.model_name}...", file=sys.stderr, flush=True) + + # Use chat completion with proper error handling + response = self.client.chat_completion( + messages=messages, + model=self.model_name, + temperature=self.temperature, + max_tokens=self.max_tokens, + stream=False + ) + + # Extract content from response + if hasattr(response, 'choices') and response.choices: + content = response.choices[0].message.content + print(f"✅ Got response: {len(content)} characters", file=sys.stderr, flush=True) + return content + else: + print(f"⚠️ Unexpected response format: {response}", file=sys.stderr, flush=True) + return str(response) + + except Exception as e: + print(f"❌ Chat completion error: {e}", file=sys.stderr, flush=True) + + # Try with a fallback model + if self.model_name != "microsoft/DialoGPT-medium": + print("🔄 Trying fallback model: microsoft/DialoGPT-medium", file=sys.stderr, flush=True) + try: + response = self.client.chat_completion( + messages=messages, + model="microsoft/DialoGPT-medium", + temperature=self.temperature, + max_tokens=self.max_tokens + ) + if hasattr(response, 'choices') and response.choices: + return response.choices[0].message.content + except: + pass + + raise Exception(f"Chat completion failed: {e}") + + def _call_classic_api(self, query: str, context: str) -> str: + """Fallback to classic text generation API.""" + print(f"🔄 Using classic text generation with {self.model_name}...", file=sys.stderr, flush=True) + + # Format prompt for classic API + if "squad" in self.model_name.lower(): + # Q&A format for squad models + prompt = f"Context: {context}\n\nQuestion: {query}\n\nAnswer:" + elif "flan" in self.model_name.lower(): + # Instruction format for Flan models + prompt = f"Answer the question based on the context.\n\nContext: {context}\n\nQuestion: {query}\n\nAnswer:" + else: + # Generic format + prompt = f"Based on the following context, answer the question.\n\nContext:\n{context}\n\nQuestion: {query}\n\nAnswer:" + + try: + response = self.client.text_generation( + model=self.model_name, + prompt=prompt, + max_new_tokens=self.max_tokens, + temperature=self.temperature + ) + return response + except Exception as e: + print(f"❌ Classic API error: {e}", file=sys.stderr, flush=True) + return f"Error generating response: {str(e)}" + + def _extract_citations(self, answer: str, chunks: List[Dict[str, Any]]) -> Tuple[str, List[Citation]]: + """Extract citations from the answer.""" + citations = [] + citation_pattern = r'\[chunk_(\d+)\]' + + cited_chunks = set() + + # Find explicit citations + matches = re.finditer(citation_pattern, answer) + for match in matches: + chunk_idx = int(match.group(1)) - 1 + if 0 <= chunk_idx < len(chunks): + cited_chunks.add(chunk_idx) + + # Fallback: Create citations for top chunks if none found + if not cited_chunks and chunks and len(answer.strip()) > 50: + num_fallback = min(3, len(chunks)) + cited_chunks = set(range(num_fallback)) + print(f"🔧 Creating {num_fallback} fallback citations", file=sys.stderr, flush=True) + + # Create Citation objects + chunk_to_source = {} + for idx in cited_chunks: + chunk = chunks[idx] + citation = Citation( + chunk_id=chunk.get('id', f'chunk_{idx}'), + 
page_number=chunk.get('metadata', {}).get('page_number', 0), + source_file=chunk.get('metadata', {}).get('source', 'unknown'), + relevance_score=chunk.get('score', 0.0), + text_snippet=chunk.get('content', chunk.get('text', ''))[:200] + '...' + ) + citations.append(citation) + + # Map for natural language replacement + source_name = chunk.get('metadata', {}).get('source', 'unknown') + if source_name != 'unknown': + natural_name = Path(source_name).stem.replace('-', ' ').replace('_', ' ') + chunk_to_source[f'[chunk_{idx+1}]'] = f"the {natural_name} documentation" + else: + chunk_to_source[f'[chunk_{idx+1}]'] = "the documentation" + + # Replace citations with natural language + natural_answer = answer + for chunk_ref, natural_ref in chunk_to_source.items(): + natural_answer = natural_answer.replace(chunk_ref, natural_ref) + + # Clean up any remaining citations + natural_answer = re.sub(r'\[chunk_\d+\]', 'the documentation', natural_answer) + natural_answer = re.sub(r'\s+', ' ', natural_answer).strip() + + return natural_answer, citations + + def _calculate_confidence(self, answer: str, citations: List[Citation], chunks: List[Dict[str, Any]]) -> float: + """Calculate confidence score for the answer.""" + if not answer or len(answer.strip()) < 10: + return 0.1 + + # Base confidence from chunk quality + if len(chunks) >= 3: + confidence = 0.8 + elif len(chunks) >= 2: + confidence = 0.7 + else: + confidence = 0.6 + + # Citation bonus + if citations and chunks: + citation_ratio = len(citations) / min(len(chunks), 3) + confidence += 0.15 * citation_ratio + + # Check for uncertainty phrases + uncertainty_phrases = [ + "insufficient information", + "cannot determine", + "not available in the provided documents", + "i don't know", + "unclear" + ] + + if any(phrase in answer.lower() for phrase in uncertainty_phrases): + confidence *= 0.3 + + return min(confidence, 0.95) + + def generate(self, query: str, chunks: List[Dict[str, Any]]) -> GeneratedAnswer: + """ + Generate an answer using Inference Providers API. 
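+
+        Uses the chat completion API when available and falls back to classic text
+        generation otherwise; queries with no usable context short-circuit to a
+        low-confidence refusal.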
+ + Args: + query: User's question + chunks: Retrieved document chunks + + Returns: + GeneratedAnswer with answer, citations, and metadata + """ + start_time = datetime.now() + + # Check for no-context situation + if not chunks or all(len(chunk.get('content', chunk.get('text', ''))) < 20 for chunk in chunks): + return GeneratedAnswer( + answer="This information isn't available in the provided documents.", + citations=[], + confidence_score=0.05, + generation_time=0.1, + model_used=self.model_name, + context_used=chunks + ) + + # Format context + context = self._format_context(chunks) + + # Generate answer + try: + if self.using_chat_completion: + # Create chat messages + messages = self._create_messages(query, context) + + # Call chat completion API + answer_text = self._call_chat_completion(messages) + else: + # Fallback to classic API + answer_text = self._call_classic_api(query, context) + + # Extract citations and clean answer + natural_answer, citations = self._extract_citations(answer_text, chunks) + + # Calculate confidence + confidence = self._calculate_confidence(natural_answer, citations, chunks) + + generation_time = (datetime.now() - start_time).total_seconds() + + return GeneratedAnswer( + answer=natural_answer, + citations=citations, + confidence_score=confidence, + generation_time=generation_time, + model_used=self.model_name, + context_used=chunks + ) + + except Exception as e: + logger.error(f"Error generating answer: {e}") + print(f"❌ Generation failed: {e}", file=sys.stderr, flush=True) + + # Return error response + return GeneratedAnswer( + answer="I apologize, but I encountered an error while generating the answer. Please try again.", + citations=[], + confidence_score=0.0, + generation_time=(datetime.now() - start_time).total_seconds(), + model_used=self.model_name, + context_used=chunks + ) + + def generate_with_custom_prompt( + self, + query: str, + chunks: List[Dict[str, Any]], + custom_prompt: Dict[str, str] + ) -> GeneratedAnswer: + """ + Generate answer using a custom prompt (for adaptive prompting). + + Args: + query: User's question + chunks: Retrieved context chunks + custom_prompt: Dict with 'system' and 'user' prompts + + Returns: + GeneratedAnswer with custom prompt enhancement + """ + start_time = datetime.now() + + if not chunks: + return GeneratedAnswer( + answer="I don't have enough context to answer your question.", + citations=[], + confidence_score=0.0, + generation_time=0.1, + model_used=self.model_name, + context_used=chunks + ) + + try: + # Try chat completion with custom prompt + messages = [ + {"role": "system", "content": custom_prompt['system']}, + {"role": "user", "content": custom_prompt['user']} + ] + + answer_text = self._call_chat_completion(messages) + + # Extract citations and clean answer + natural_answer, citations = self._extract_citations(answer_text, chunks) + + # Calculate confidence + confidence = self._calculate_confidence(natural_answer, citations, chunks) + + generation_time = (datetime.now() - start_time).total_seconds() + + return GeneratedAnswer( + answer=natural_answer, + citations=citations, + confidence_score=confidence, + generation_time=generation_time, + model_used=self.model_name, + context_used=chunks + ) + + except Exception as e: + logger.error(f"Error generating answer with custom prompt: {e}") + print(f"❌ Custom prompt generation failed: {e}", file=sys.stderr, flush=True) + + # Return error response + return GeneratedAnswer( + answer="I apologize, but I encountered an error while generating the answer. 
Please try again.", + citations=[], + confidence_score=0.0, + generation_time=(datetime.now() - start_time).total_seconds(), + model_used=self.model_name, + context_used=chunks + ) + + +# Example usage +if __name__ == "__main__": + # Test the generator + print("Testing Inference Providers Generator...") + + try: + generator = InferenceProvidersGenerator() + + # Test chunks + test_chunks = [ + { + "content": "RISC-V is an open-source instruction set architecture (ISA) based on established reduced instruction set computer (RISC) principles.", + "metadata": {"page_number": 1, "source": "riscv-spec.pdf"}, + "score": 0.95 + }, + { + "content": "Unlike most other ISA designs, RISC-V is provided under open source licenses that do not require fees to use.", + "metadata": {"page_number": 2, "source": "riscv-spec.pdf"}, + "score": 0.85 + } + ] + + # Generate answer + result = generator.generate("What is RISC-V and why is it important?", test_chunks) + + print(f"\n📝 Answer: {result.answer}") + print(f"📊 Confidence: {result.confidence_score:.1%}") + print(f"⏱️ Generation time: {result.generation_time:.2f}s") + print(f"🤖 Model: {result.model_used}") + print(f"📚 Citations: {len(result.citations)}") + + except Exception as e: + print(f"❌ Test failed: {e}") + import traceback + traceback.print_exc() \ No newline at end of file diff --git a/shared_utils/generation/ollama_answer_generator.py b/shared_utils/generation/ollama_answer_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..553cc306e1e2ff8dc2ca6900cd0443e7fceaed76 --- /dev/null +++ b/shared_utils/generation/ollama_answer_generator.py @@ -0,0 +1,834 @@ +#!/usr/bin/env python3 +""" +Ollama-based answer generator for local inference. + +Provides the same interface as HuggingFaceAnswerGenerator but uses +local Ollama server for model inference. +""" + +import time +import requests +import json +import re +import sys +from datetime import datetime +from pathlib import Path +from typing import Dict, List, Optional, Any, Tuple +from dataclasses import dataclass + +# Import shared components +from .hf_answer_generator import Citation, GeneratedAnswer +from .prompt_templates import TechnicalPromptTemplates + +# Import standard interfaces (add this for the adapter) +try: + from pathlib import Path + import sys + project_root = Path(__file__).parent.parent.parent.parent.parent + sys.path.append(str(project_root)) + from src.core.interfaces import Document, Answer, AnswerGenerator +except ImportError: + # Fallback for standalone usage + Document = None + Answer = None + AnswerGenerator = object + + +class OllamaAnswerGenerator(AnswerGenerator if AnswerGenerator != object else object): + """ + Generates answers using local Ollama server. + + Perfect for: + - Local development + - Privacy-sensitive applications + - No API rate limits + - Consistent performance + - Offline operation + """ + + def __init__( + self, + model_name: str = "llama3.2:3b", + base_url: str = "http://localhost:11434", + temperature: float = 0.3, + max_tokens: int = 512, + ): + """ + Initialize Ollama answer generator. 
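+
+        Example (illustrative; assumes a local Ollama server with the model pulled):
+            generator = OllamaAnswerGenerator(model_name="llama3.2:3b",
+                                              base_url="http://localhost:11434")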
+ + Args: + model_name: Ollama model to use (e.g., "llama3.2:3b", "mistral") + base_url: Ollama server URL + temperature: Generation temperature + max_tokens: Maximum tokens to generate + """ + self.model_name = model_name + self.base_url = base_url.rstrip("/") + self.temperature = temperature + self.max_tokens = max_tokens + + # Test connection + self._test_connection() + + def _test_connection(self): + """Test if Ollama server is accessible.""" + # Reduce retries for faster initialization - container should be ready quickly + max_retries = 12 # Wait up to 60 seconds for Ollama to start + retry_delay = 5 + + print( + f"🔧 Testing connection to {self.base_url}/api/tags...", + file=sys.stderr, + flush=True, + ) + + for attempt in range(max_retries): + try: + response = requests.get(f"{self.base_url}/api/tags", timeout=8) + if response.status_code == 200: + print( + f"✅ Connected to Ollama at {self.base_url}", + file=sys.stderr, + flush=True, + ) + + # Check if our model is available + models = response.json().get("models", []) + model_names = [m["name"] for m in models] + + if self.model_name in model_names: + print( + f"✅ Model {self.model_name} is available", + file=sys.stderr, + flush=True, + ) + return # Success! + else: + print( + f"⚠️ Model {self.model_name} not found. Available: {model_names}", + file=sys.stderr, + flush=True, + ) + if models: # If any models are available, use the first one + fallback_model = model_names[0] + print( + f"🔄 Using fallback model: {fallback_model}", + file=sys.stderr, + flush=True, + ) + self.model_name = fallback_model + return + else: + print( + f"📥 No models found, will try to pull {self.model_name}", + file=sys.stderr, + flush=True, + ) + # Try to pull the model + self._pull_model(self.model_name) + return + else: + print(f"⚠️ Ollama server returned status {response.status_code}") + if attempt < max_retries - 1: + print( + f"🔄 Retry {attempt + 1}/{max_retries} in {retry_delay} seconds..." + ) + time.sleep(retry_delay) + continue + + except requests.exceptions.ConnectionError: + if attempt < max_retries - 1: + print( + f"⏳ Ollama not ready yet, retry {attempt + 1}/{max_retries} in {retry_delay} seconds..." + ) + time.sleep(retry_delay) + continue + else: + raise Exception( + f"Cannot connect to Ollama server at {self.base_url} after 60 seconds. Check if it's running." 
+ ) + except requests.exceptions.Timeout: + if attempt < max_retries - 1: + print(f"⏳ Ollama timeout, retry {attempt + 1}/{max_retries}...") + time.sleep(retry_delay) + continue + else: + raise Exception("Ollama server timeout after multiple retries.") + except Exception as e: + if attempt < max_retries - 1: + print(f"⚠️ Ollama error: {e}, retry {attempt + 1}/{max_retries}...") + time.sleep(retry_delay) + continue + else: + raise Exception( + f"Ollama connection failed after {max_retries} attempts: {e}" + ) + + raise Exception("Failed to connect to Ollama after all retries") + + def _pull_model(self, model_name: str): + """Pull a model if it's not available.""" + try: + print(f"📥 Pulling model {model_name}...") + pull_response = requests.post( + f"{self.base_url}/api/pull", + json={"name": model_name}, + timeout=300, # 5 minutes for model download + ) + if pull_response.status_code == 200: + print(f"✅ Successfully pulled {model_name}") + else: + print(f"⚠️ Failed to pull {model_name}: {pull_response.status_code}") + # Try smaller models as fallback + fallback_models = ["llama3.2:1b", "llama2:latest", "mistral:latest"] + for fallback in fallback_models: + try: + print(f"🔄 Trying fallback model: {fallback}") + fallback_response = requests.post( + f"{self.base_url}/api/pull", + json={"name": fallback}, + timeout=300, + ) + if fallback_response.status_code == 200: + print(f"✅ Successfully pulled fallback {fallback}") + self.model_name = fallback + return + except: + continue + raise Exception(f"Failed to pull {model_name} or any fallback models") + except Exception as e: + print(f"❌ Model pull failed: {e}") + raise + + def _format_context(self, chunks: List[Dict[str, Any]]) -> str: + """Format retrieved chunks into context.""" + context_parts = [] + + for i, chunk in enumerate(chunks): + chunk_text = chunk.get("content", chunk.get("text", "")) + page_num = chunk.get("metadata", {}).get("page_number", "unknown") + source = chunk.get("metadata", {}).get("source", "unknown") + + context_parts.append( + f"[chunk_{i+1}] (Page {page_num} from {source}):\n{chunk_text}\n" + ) + + return "\n---\n".join(context_parts) + + def _create_prompt(self, query: str, context: str, chunks: List[Dict[str, Any]]) -> str: + """Create optimized prompt with dynamic length constraints and citation instructions.""" + # Get the appropriate template based on query type + prompt_data = TechnicalPromptTemplates.format_prompt_with_template( + query=query, context=context + ) + + # Create dynamic citation instructions based on available chunks + num_chunks = len(chunks) + available_chunks = ", ".join([f"[chunk_{i+1}]" for i in range(min(num_chunks, 5))]) # Show max 5 examples + + # Create appropriate example based on actual chunks + if num_chunks == 1: + citation_example = "RISC-V is an open-source ISA [chunk_1]." + elif num_chunks == 2: + citation_example = "RISC-V is an open-source ISA [chunk_1] that supports multiple data widths [chunk_2]." + else: + citation_example = "RISC-V is an open-source ISA [chunk_1] that supports multiple data widths [chunk_2] and provides extensions [chunk_3]." 
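+
+        # The example above only cites chunks that actually exist, which discourages
+        # the model from inventing citations for chunks it was never given.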
+ + # Determine optimal answer length based on query complexity + target_length = self._determine_target_length(query, chunks) + length_instruction = self._create_length_instruction(target_length) + + # Format for different model types + if "llama" in self.model_name.lower(): + # Llama-3.2 format with technical prompt templates + return f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|> +{prompt_data['system']} + +MANDATORY CITATION RULES: +- ONLY use available chunks: {available_chunks} +- You have {num_chunks} chunks available - DO NOT cite chunk numbers higher than {num_chunks} +- Every technical claim MUST have a citation from available chunks +- Example: "{citation_example}" + +{length_instruction} + +<|eot_id|><|start_header_id|>user<|end_header_id|> +{prompt_data['user']} + +CRITICAL: You MUST cite sources ONLY from available chunks: {available_chunks}. DO NOT use chunk numbers > {num_chunks}. +{length_instruction}<|eot_id|><|start_header_id|>assistant<|end_header_id|>""" + + elif "mistral" in self.model_name.lower(): + # Mistral format with technical templates + return f"""[INST] {prompt_data['system']} + +Context: +{context} + +Question: {query} + +MANDATORY: ONLY use available chunks: {available_chunks}. DO NOT cite chunk numbers > {num_chunks}. +{length_instruction} [/INST]""" + + else: + # Generic format with technical templates + return f"""{prompt_data['system']} + +Context: +{context} + +Question: {query} + +MANDATORY CITATIONS: ONLY use available chunks: {available_chunks}. DO NOT cite chunk numbers > {num_chunks}. +{length_instruction} + +Answer:""" + + def _determine_target_length(self, query: str, chunks: List[Dict[str, Any]]) -> int: + """ + Determine optimal answer length based on query complexity. + + Target range: 150-400 characters (down from 1000-2600) + """ + # Analyze query complexity + query_length = len(query) + query_words = len(query.split()) + + # Check for complexity indicators + complex_words = [ + "explain", "describe", "analyze", "compare", "contrast", + "evaluate", "discuss", "detail", "elaborate", "comprehensive" + ] + + simple_words = [ + "what", "define", "list", "name", "identify", "is", "are" + ] + + query_lower = query.lower() + is_complex = any(word in query_lower for word in complex_words) + is_simple = any(word in query_lower for word in simple_words) + + # Base length from query type + if is_complex: + base_length = 350 # Complex queries get longer answers + elif is_simple: + base_length = 200 # Simple queries get shorter answers + else: + base_length = 275 # Default middle ground + + # Adjust based on available context + context_factor = min(len(chunks) * 25, 75) # More context allows longer answers + + # Adjust based on query length + query_factor = min(query_words * 5, 50) # Longer queries allow longer answers + + target_length = base_length + context_factor + query_factor + + # Constrain to target range + return max(150, min(target_length, 400)) + + def _create_length_instruction(self, target_length: int) -> str: + """Create length instruction based on target length.""" + if target_length <= 200: + return f"ANSWER LENGTH: Keep your answer concise and focused, approximately {target_length} characters. Be direct and to the point." + elif target_length <= 300: + return f"ANSWER LENGTH: Provide a clear and informative answer, approximately {target_length} characters. Include key details but avoid unnecessary elaboration." 
+ else: + return f"ANSWER LENGTH: Provide a comprehensive but concise answer, approximately {target_length} characters. Include important details while maintaining clarity." + + def _call_ollama(self, prompt: str) -> str: + """Call Ollama API for generation.""" + payload = { + "model": self.model_name, + "prompt": prompt, + "stream": False, + "options": { + "temperature": self.temperature, + "num_predict": self.max_tokens, + "top_p": 0.9, + "repeat_penalty": 1.1, + }, + } + + try: + response = requests.post( + f"{self.base_url}/api/generate", json=payload, timeout=300 + ) + + response.raise_for_status() + result = response.json() + + return result.get("response", "").strip() + + except requests.exceptions.RequestException as e: + print(f"❌ Ollama API error: {e}") + return f"Error communicating with Ollama: {str(e)}" + except Exception as e: + print(f"❌ Unexpected error: {e}") + return f"Unexpected error: {str(e)}" + + def _extract_citations( + self, answer: str, chunks: List[Dict[str, Any]] + ) -> Tuple[str, List[Citation]]: + """Extract citations from the generated answer.""" + citations = [] + citation_pattern = r"\[chunk_(\d+)\]" + + cited_chunks = set() + + # Find [chunk_X] citations + matches = re.finditer(citation_pattern, answer) + for match in matches: + chunk_idx = int(match.group(1)) - 1 # Convert to 0-based index + if 0 <= chunk_idx < len(chunks): + cited_chunks.add(chunk_idx) + + # FALLBACK: If no explicit citations found but we have an answer and chunks, + # create citations for the top chunks that were likely used + if not cited_chunks and chunks and len(answer.strip()) > 50: + # Use the top chunks that were provided as likely sources + num_fallback_citations = min(3, len(chunks)) # Use top 3 chunks max + cited_chunks = set(range(num_fallback_citations)) + print( + f"🔧 Fallback: Creating {num_fallback_citations} citations for answer without explicit [chunk_X] references", + file=sys.stderr, + flush=True, + ) + + # Create Citation objects + chunk_to_source = {} + for idx in cited_chunks: + chunk = chunks[idx] + citation = Citation( + chunk_id=chunk.get("id", f"chunk_{idx}"), + page_number=chunk.get("metadata", {}).get("page_number", 0), + source_file=chunk.get("metadata", {}).get("source", "unknown"), + relevance_score=chunk.get("score", 0.0), + text_snippet=chunk.get("content", chunk.get("text", ""))[:200] + "...", + ) + citations.append(citation) + + # Don't replace chunk references - keep them as proper citations + # The issue was that replacing [chunk_X] with "the documentation" creates repetitive text + # Instead, we should keep the proper citation format + pass + + # Keep the answer as-is with proper [chunk_X] citations + # Don't replace citations with repetitive text + natural_answer = re.sub(r"\s+", " ", answer).strip() + + return natural_answer, citations + + def _calculate_confidence( + self, answer: str, citations: List[Citation], chunks: List[Dict[str, Any]] + ) -> float: + """ + Calculate confidence score with expanded multi-factor assessment. + + Enhanced algorithm expands range from 0.75-0.95 to 0.3-0.9 with: + - Context quality assessment + - Citation quality evaluation + - Semantic relevance scoring + - Off-topic detection + - Answer completeness analysis + """ + if not answer or len(answer.strip()) < 10: + return 0.1 + + # 1. Context Quality Assessment (0.3-0.6 base range) + context_quality = self._assess_context_quality(chunks) + + # 2. Citation Quality Evaluation (0.0-0.2 boost) + citation_quality = self._assess_citation_quality(citations, chunks) + + # 3. 
Semantic Relevance Scoring (0.0-0.15 boost) + semantic_relevance = self._assess_semantic_relevance(answer, chunks) + + # 4. Off-topic Detection (-0.4 penalty if off-topic) + off_topic_penalty = self._detect_off_topic(answer, chunks) + + # 5. Answer Completeness Analysis (0.0-0.1 boost) + completeness_bonus = self._assess_answer_completeness(answer, len(chunks)) + + # Combine all factors + confidence = ( + context_quality + + citation_quality + + semantic_relevance + + completeness_bonus + + off_topic_penalty + ) + + # Apply uncertainty penalty + uncertainty_phrases = [ + "insufficient information", + "cannot determine", + "not available in the provided documents", + "I don't have enough context", + "the context doesn't seem to provide" + ] + + if any(phrase in answer.lower() for phrase in uncertainty_phrases): + confidence *= 0.4 # Stronger penalty for uncertainty + + # Constrain to target range 0.3-0.9 + return max(0.3, min(confidence, 0.9)) + + def _assess_context_quality(self, chunks: List[Dict[str, Any]]) -> float: + """Assess quality of context chunks (0.3-0.6 range).""" + if not chunks: + return 0.3 + + # Base score from chunk count + if len(chunks) >= 3: + base_score = 0.6 + elif len(chunks) >= 2: + base_score = 0.5 + else: + base_score = 0.4 + + # Quality adjustments based on chunk content + avg_chunk_length = sum(len(chunk.get("content", chunk.get("text", ""))) for chunk in chunks) / len(chunks) + + if avg_chunk_length > 500: # Rich content + base_score += 0.05 + elif avg_chunk_length < 100: # Sparse content + base_score -= 0.05 + + return max(0.3, min(base_score, 0.6)) + + def _assess_citation_quality(self, citations: List[Citation], chunks: List[Dict[str, Any]]) -> float: + """Assess citation quality (0.0-0.2 range).""" + if not citations or not chunks: + return 0.0 + + # Citation coverage bonus + citation_ratio = len(citations) / min(len(chunks), 3) + coverage_bonus = 0.1 * citation_ratio + + # Citation diversity bonus (multiple sources) + unique_sources = len(set(cit.source_file for cit in citations)) + diversity_bonus = 0.05 * min(unique_sources / max(len(chunks), 1), 1.0) + + return min(coverage_bonus + diversity_bonus, 0.2) + + def _assess_semantic_relevance(self, answer: str, chunks: List[Dict[str, Any]]) -> float: + """Assess semantic relevance between answer and context (0.0-0.15 range).""" + if not answer or not chunks: + return 0.0 + + # Simple keyword overlap assessment + answer_words = set(answer.lower().split()) + context_words = set() + + for chunk in chunks: + chunk_text = chunk.get("content", chunk.get("text", "")) + context_words.update(chunk_text.lower().split()) + + if not context_words: + return 0.0 + + # Calculate overlap ratio + overlap = len(answer_words & context_words) + total_unique = len(answer_words | context_words) + + if total_unique == 0: + return 0.0 + + overlap_ratio = overlap / total_unique + return min(0.15 * overlap_ratio, 0.15) + + def _detect_off_topic(self, answer: str, chunks: List[Dict[str, Any]]) -> float: + """Detect if answer is off-topic (-0.4 penalty if off-topic).""" + if not answer or not chunks: + return 0.0 + + # Check for off-topic indicators + off_topic_phrases = [ + "but I have to say that the context doesn't seem to provide", + "these documents appear to be focused on", + "but they don't seem to cover", + "I'd recommend consulting a different type of documentation", + "without more context or information" + ] + + answer_lower = answer.lower() + for phrase in off_topic_phrases: + if phrase in answer_lower: + return -0.4 # 
Strong penalty for off-topic responses + + return 0.0 + + def _assess_answer_completeness(self, answer: str, chunk_count: int) -> float: + """Assess answer completeness (0.0-0.1 range).""" + if not answer: + return 0.0 + + # Length-based completeness assessment + answer_length = len(answer) + + if answer_length > 500: # Comprehensive answer + return 0.1 + elif answer_length > 200: # Adequate answer + return 0.05 + else: # Brief answer + return 0.0 + + def generate(self, query: str, context: List[Document]) -> Answer: + """ + Generate an answer from query and context documents (standard interface). + + This is the public interface that conforms to the AnswerGenerator protocol. + It handles the conversion between standard Document objects and Ollama's + internal chunk format. + + Args: + query: User's question + context: List of relevant Document objects + + Returns: + Answer object conforming to standard interface + + Raises: + ValueError: If query is empty or context is None + """ + if not query.strip(): + raise ValueError("Query cannot be empty") + + if context is None: + raise ValueError("Context cannot be None") + + # Internal adapter: Convert Documents to Ollama chunk format + ollama_chunks = self._documents_to_ollama_chunks(context) + + # Use existing Ollama-specific generation logic + ollama_result = self._generate_internal(query, ollama_chunks) + + # Internal adapter: Convert Ollama result to standard Answer + return self._ollama_result_to_answer(ollama_result, context) + + def _generate_internal(self, query: str, chunks: List[Dict[str, Any]]) -> GeneratedAnswer: + """ + Generate an answer based on the query and retrieved chunks. + + Args: + query: User's question + chunks: Retrieved document chunks + + Returns: + GeneratedAnswer object with answer, citations, and metadata + """ + start_time = datetime.now() + + # Check for no-context situation + if not chunks or all( + len(chunk.get("content", chunk.get("text", ""))) < 20 for chunk in chunks + ): + return GeneratedAnswer( + answer="This information isn't available in the provided documents.", + citations=[], + confidence_score=0.05, + generation_time=0.1, + model_used=self.model_name, + context_used=chunks, + ) + + # Format context + context = self._format_context(chunks) + + # Create prompt with chunks parameter for dynamic citation instructions + prompt = self._create_prompt(query, context, chunks) + + # Generate answer + print( + f"🤖 Calling Ollama with {self.model_name}...", file=sys.stderr, flush=True + ) + answer_with_citations = self._call_ollama(prompt) + + generation_time = (datetime.now() - start_time).total_seconds() + + # Extract citations and create natural answer + natural_answer, citations = self._extract_citations( + answer_with_citations, chunks + ) + + # Calculate confidence + confidence = self._calculate_confidence(natural_answer, citations, chunks) + + return GeneratedAnswer( + answer=natural_answer, + citations=citations, + confidence_score=confidence, + generation_time=generation_time, + model_used=self.model_name, + context_used=chunks, + ) + + def generate_with_custom_prompt( + self, + query: str, + chunks: List[Dict[str, Any]], + custom_prompt: Dict[str, str] + ) -> GeneratedAnswer: + """ + Generate answer using a custom prompt (for adaptive prompting). 
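+
+        Illustrative custom_prompt shape (assumed; it only needs the 'system' and
+        'user' keys used below):
+            {"system": "You are a RISC-V documentation expert.",
+             "user": "Context: ...\n\nQuestion: How do I enable timer interrupts?"}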
+ + Args: + query: User's question + chunks: Retrieved context chunks + custom_prompt: Dict with 'system' and 'user' prompts + + Returns: + GeneratedAnswer with custom prompt enhancement + """ + start_time = datetime.now() + + if not chunks: + return GeneratedAnswer( + answer="I don't have enough context to answer your question.", + citations=[], + confidence_score=0.0, + generation_time=0.1, + model_used=self.model_name, + context_used=chunks, + ) + + # Build custom prompt based on model type + if "llama" in self.model_name.lower(): + prompt = f"""[INST] {custom_prompt['system']} + +{custom_prompt['user']} + +MANDATORY: Use [chunk_1], [chunk_2] etc. for all facts. [/INST]""" + elif "mistral" in self.model_name.lower(): + prompt = f"""[INST] {custom_prompt['system']} + +{custom_prompt['user']} + +MANDATORY: Use [chunk_1], [chunk_2] etc. for all facts. [/INST]""" + else: + # Generic format for other models + prompt = f"""{custom_prompt['system']} + +{custom_prompt['user']} + +MANDATORY: Use [chunk_1], [chunk_2] etc. for all factual statements. + +Answer:""" + + # Generate answer + print(f"🤖 Calling Ollama with custom prompt using {self.model_name}...", file=sys.stderr, flush=True) + answer_with_citations = self._call_ollama(prompt) + + generation_time = (datetime.now() - start_time).total_seconds() + + # Extract citations and create natural answer + natural_answer, citations = self._extract_citations(answer_with_citations, chunks) + + # Calculate confidence + confidence = self._calculate_confidence(natural_answer, citations, chunks) + + return GeneratedAnswer( + answer=natural_answer, + citations=citations, + confidence_score=confidence, + generation_time=generation_time, + model_used=self.model_name, + context_used=chunks, + ) + + def _documents_to_ollama_chunks(self, documents: List[Document]) -> List[Dict[str, Any]]: + """ + Convert Document objects to Ollama's internal chunk format. + + This internal adapter ensures that Document objects are properly formatted + for Ollama's processing pipeline while keeping the format requirements + encapsulated within this class. + + Args: + documents: List of Document objects from the standard interface + + Returns: + List of chunk dictionaries in Ollama's expected format + """ + if not documents: + return [] + + chunks = [] + for i, doc in enumerate(documents): + chunk = { + "id": f"chunk_{i+1}", + "content": doc.content, # Ollama expects "content" field + "text": doc.content, # Fallback field for compatibility + "score": 1.0, # Default relevance score + "metadata": { + "source": doc.metadata.get("source", "unknown"), + "page_number": doc.metadata.get("start_page", 1), + **doc.metadata # Include all original metadata + } + } + chunks.append(chunk) + + return chunks + + def _ollama_result_to_answer(self, ollama_result: GeneratedAnswer, original_context: List[Document]) -> Answer: + """ + Convert Ollama's GeneratedAnswer to the standard Answer format. + + This internal adapter converts Ollama's result format back to the + standard interface format expected by the rest of the system. 
+ + Args: + ollama_result: Result from Ollama's internal generation + original_context: Original Document objects for sources + + Returns: + Answer object conforming to standard interface + """ + if not Answer: + # Fallback if standard interface not available + return ollama_result + + # Convert to standard Answer format + return Answer( + text=ollama_result.answer, + sources=original_context, # Use original Document objects + confidence=ollama_result.confidence_score, + metadata={ + "model_used": ollama_result.model_used, + "generation_time": ollama_result.generation_time, + "citations": [ + { + "chunk_id": cit.chunk_id, + "page_number": cit.page_number, + "source_file": cit.source_file, + "relevance_score": cit.relevance_score, + "text_snippet": cit.text_snippet + } + for cit in ollama_result.citations + ], + "provider": "ollama", + "temperature": self.temperature, + "max_tokens": self.max_tokens + } + ) + + +# Example usage +if __name__ == "__main__": + # Test Ollama connection + generator = OllamaAnswerGenerator(model_name="llama3.2:3b") + + # Mock chunks for testing + test_chunks = [ + { + "content": "RISC-V is a free and open-source ISA.", + "metadata": {"page_number": 1, "source": "riscv-spec.pdf"}, + "score": 0.9, + } + ] + + # Test generation + result = generator.generate("What is RISC-V?", test_chunks) + print(f"Answer: {result.answer}") + print(f"Confidence: {result.confidence_score:.2%}") diff --git a/shared_utils/generation/prompt_optimizer.py b/shared_utils/generation/prompt_optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..891c90ee9de09a2be0a1188cb90a86ac0330e2c9 --- /dev/null +++ b/shared_utils/generation/prompt_optimizer.py @@ -0,0 +1,687 @@ +""" +A/B Testing Framework for Prompt Optimization. + +This module provides systematic prompt optimization through A/B testing, +performance analysis, and automated prompt variation generation. 
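+
+Typical usage (illustrative sketch; names mirror the PromptOptimizer API defined below):
+    optimizer = PromptOptimizer(experiment_dir="experiments")
+    variation_ids = optimizer.create_temperature_variations(QueryType.IMPLEMENTATION)
+    experiment_id = optimizer.setup_experiment("temp_sweep", variation_ids,
+                                               ["How do I implement a timer interrupt?"])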
+""" + +import json +import time +from typing import Dict, List, Optional, Tuple, Any +from dataclasses import dataclass, asdict +from enum import Enum +from pathlib import Path +import numpy as np +from collections import defaultdict +import logging + +from .prompt_templates import QueryType, PromptTemplate, TechnicalPromptTemplates + + +class OptimizationMetric(Enum): + """Metrics for evaluating prompt performance.""" + RESPONSE_TIME = "response_time" + CONFIDENCE_SCORE = "confidence_score" + CITATION_COUNT = "citation_count" + ANSWER_LENGTH = "answer_length" + TECHNICAL_ACCURACY = "technical_accuracy" + USER_SATISFACTION = "user_satisfaction" + + +@dataclass +class PromptVariation: + """Represents a prompt variation for A/B testing.""" + variation_id: str + name: str + description: str + template: PromptTemplate + query_type: QueryType + created_at: float + metadata: Dict[str, Any] + + +@dataclass +class TestResult: + """Represents a single test result.""" + variation_id: str + query: str + query_type: QueryType + response_time: float + confidence_score: float + citation_count: int + answer_length: int + technical_accuracy: Optional[float] = None + user_satisfaction: Optional[float] = None + timestamp: float = None + metadata: Dict[str, Any] = None + + def __post_init__(self): + if self.timestamp is None: + self.timestamp = time.time() + if self.metadata is None: + self.metadata = {} + + +@dataclass +class ComparisonResult: + """Results of A/B test comparison.""" + variation_a: str + variation_b: str + metric: OptimizationMetric + a_mean: float + b_mean: float + improvement_percent: float + p_value: float + confidence_interval: Tuple[float, float] + is_significant: bool + sample_size: int + recommendation: str + + +class PromptOptimizer: + """ + A/B testing framework for systematic prompt optimization. + + Features: + - Automated prompt variation generation + - Performance metric tracking + - Statistical significance testing + - Recommendation engine + - Persistence and experiment tracking + """ + + def __init__(self, experiment_dir: str = "experiments"): + """ + Initialize the prompt optimizer. + + Args: + experiment_dir: Directory to store experiment data + """ + self.experiment_dir = Path(experiment_dir) + self.experiment_dir.mkdir(exist_ok=True) + + self.variations: Dict[str, PromptVariation] = {} + self.test_results: List[TestResult] = [] + self.active_experiments: Dict[str, List[str]] = {} + + # Load existing experiments + self._load_experiments() + + # Setup logging + logging.basicConfig(level=logging.INFO) + self.logger = logging.getLogger(__name__) + + def create_variation( + self, + base_template: PromptTemplate, + query_type: QueryType, + variation_name: str, + modifications: Dict[str, str], + description: str = "" + ) -> str: + """ + Create a new prompt variation. 
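+
+        Example (illustrative): create_variation(base_template, QueryType.DEFINITION,
+        "short_defs", {"answer_guidelines": "Keep answers under three sentences."})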
+ + Args: + base_template: Base template to modify + query_type: Type of query this variation is for + variation_name: Human-readable name + modifications: Dict of template field modifications + description: Description of the variation + + Returns: + Variation ID + """ + variation_id = f"{query_type.value}_{variation_name}_{int(time.time())}" + + # Create modified template + modified_template = PromptTemplate( + system_prompt=modifications.get("system_prompt", base_template.system_prompt), + context_format=modifications.get("context_format", base_template.context_format), + query_format=modifications.get("query_format", base_template.query_format), + answer_guidelines=modifications.get("answer_guidelines", base_template.answer_guidelines) + ) + + variation = PromptVariation( + variation_id=variation_id, + name=variation_name, + description=description, + template=modified_template, + query_type=query_type, + created_at=time.time(), + metadata=modifications + ) + + self.variations[variation_id] = variation + self._save_variation(variation) + + self.logger.info(f"Created variation: {variation_id}") + return variation_id + + def create_temperature_variations( + self, + base_query_type: QueryType, + temperatures: List[float] = [0.3, 0.5, 0.7, 0.9] + ) -> List[str]: + """ + Create variations with different temperature settings. + + Args: + base_query_type: Query type to create variations for + temperatures: List of temperature values to test + + Returns: + List of variation IDs + """ + base_template = TechnicalPromptTemplates.get_template_for_query("") + if base_query_type != QueryType.GENERAL: + template_map = { + QueryType.DEFINITION: TechnicalPromptTemplates.get_definition_template, + QueryType.IMPLEMENTATION: TechnicalPromptTemplates.get_implementation_template, + QueryType.COMPARISON: TechnicalPromptTemplates.get_comparison_template, + QueryType.SPECIFICATION: TechnicalPromptTemplates.get_specification_template, + QueryType.CODE_EXAMPLE: TechnicalPromptTemplates.get_code_example_template, + QueryType.HARDWARE_CONSTRAINT: TechnicalPromptTemplates.get_hardware_constraint_template, + QueryType.TROUBLESHOOTING: TechnicalPromptTemplates.get_troubleshooting_template, + } + base_template = template_map[base_query_type]() + + variation_ids = [] + for temp in temperatures: + temp_modification = { + "system_prompt": base_template.system_prompt + f"\n\nGenerate responses with temperature={temp} (creativity level).", + "answer_guidelines": base_template.answer_guidelines + f"\n\nAdjust response creativity to temperature={temp}." + } + + variation_id = self.create_variation( + base_template=base_template, + query_type=base_query_type, + variation_name=f"temp_{temp}", + modifications=temp_modification, + description=f"Temperature variation with {temp} creativity level" + ) + variation_ids.append(variation_id) + + return variation_ids + + def create_length_variations( + self, + base_query_type: QueryType, + length_styles: List[str] = ["concise", "detailed", "comprehensive"] + ) -> List[str]: + """ + Create variations with different response length preferences. 
+ + Args: + base_query_type: Query type to create variations for + length_styles: List of length styles to test + + Returns: + List of variation IDs + """ + base_template = TechnicalPromptTemplates.get_template_for_query("") + if base_query_type != QueryType.GENERAL: + template_map = { + QueryType.DEFINITION: TechnicalPromptTemplates.get_definition_template, + QueryType.IMPLEMENTATION: TechnicalPromptTemplates.get_implementation_template, + QueryType.COMPARISON: TechnicalPromptTemplates.get_comparison_template, + QueryType.SPECIFICATION: TechnicalPromptTemplates.get_specification_template, + QueryType.CODE_EXAMPLE: TechnicalPromptTemplates.get_code_example_template, + QueryType.HARDWARE_CONSTRAINT: TechnicalPromptTemplates.get_hardware_constraint_template, + QueryType.TROUBLESHOOTING: TechnicalPromptTemplates.get_troubleshooting_template, + } + base_template = template_map[base_query_type]() + + length_prompts = { + "concise": "Be concise and focus on essential information only. Aim for 2-3 sentences per point.", + "detailed": "Provide detailed explanations with examples. Aim for comprehensive coverage.", + "comprehensive": "Provide exhaustive detail with multiple examples, edge cases, and related concepts." + } + + variation_ids = [] + for style in length_styles: + length_modification = { + "answer_guidelines": base_template.answer_guidelines + f"\n\nResponse style: {length_prompts[style]}" + } + + variation_id = self.create_variation( + base_template=base_template, + query_type=base_query_type, + variation_name=f"length_{style}", + modifications=length_modification, + description=f"Length variation with {style} response style" + ) + variation_ids.append(variation_id) + + return variation_ids + + def create_citation_variations( + self, + base_query_type: QueryType, + citation_styles: List[str] = ["minimal", "standard", "extensive"] + ) -> List[str]: + """ + Create variations with different citation requirements. + + Args: + base_query_type: Query type to create variations for + citation_styles: List of citation styles to test + + Returns: + List of variation IDs + """ + base_template = TechnicalPromptTemplates.get_template_for_query("") + if base_query_type != QueryType.GENERAL: + template_map = { + QueryType.DEFINITION: TechnicalPromptTemplates.get_definition_template, + QueryType.IMPLEMENTATION: TechnicalPromptTemplates.get_implementation_template, + QueryType.COMPARISON: TechnicalPromptTemplates.get_comparison_template, + QueryType.SPECIFICATION: TechnicalPromptTemplates.get_specification_template, + QueryType.CODE_EXAMPLE: TechnicalPromptTemplates.get_code_example_template, + QueryType.HARDWARE_CONSTRAINT: TechnicalPromptTemplates.get_hardware_constraint_template, + QueryType.TROUBLESHOOTING: TechnicalPromptTemplates.get_troubleshooting_template, + } + base_template = template_map[base_query_type]() + + citation_prompts = { + "minimal": "Use [chunk_X] citations only for direct quotes or specific claims.", + "standard": "Include [chunk_X] citations for each major point or claim.", + "extensive": "Provide [chunk_X] citations for every statement. Use multiple citations per point where relevant." 
+ } + + variation_ids = [] + for style in citation_styles: + citation_modification = { + "answer_guidelines": base_template.answer_guidelines + f"\n\nCitation style: {citation_prompts[style]}" + } + + variation_id = self.create_variation( + base_template=base_template, + query_type=base_query_type, + variation_name=f"citation_{style}", + modifications=citation_modification, + description=f"Citation variation with {style} citation requirements" + ) + variation_ids.append(variation_id) + + return variation_ids + + def setup_experiment( + self, + experiment_name: str, + variation_ids: List[str], + test_queries: List[str] + ) -> str: + """ + Set up a new A/B test experiment. + + Args: + experiment_name: Name of the experiment + variation_ids: List of variation IDs to test + test_queries: List of test queries + + Returns: + Experiment ID + """ + experiment_id = f"exp_{experiment_name}_{int(time.time())}" + + experiment_config = { + "experiment_id": experiment_id, + "name": experiment_name, + "variation_ids": variation_ids, + "test_queries": test_queries, + "created_at": time.time(), + "status": "active" + } + + self.active_experiments[experiment_id] = variation_ids + + # Save experiment config + experiment_file = self.experiment_dir / f"{experiment_id}.json" + with open(experiment_file, 'w') as f: + json.dump(experiment_config, f, indent=2) + + self.logger.info(f"Created experiment: {experiment_id}") + return experiment_id + + def record_test_result( + self, + variation_id: str, + query: str, + query_type: QueryType, + response_time: float, + confidence_score: float, + citation_count: int, + answer_length: int, + technical_accuracy: Optional[float] = None, + user_satisfaction: Optional[float] = None, + metadata: Optional[Dict[str, Any]] = None + ) -> None: + """ + Record a test result for analysis. + + Args: + variation_id: ID of the variation tested + query: The query that was tested + query_type: Type of the query + response_time: Response time in seconds + confidence_score: Confidence score (0-1) + citation_count: Number of citations in response + answer_length: Length of answer in characters + technical_accuracy: Optional technical accuracy score (0-1) + user_satisfaction: Optional user satisfaction score (0-1) + metadata: Optional additional metadata + """ + result = TestResult( + variation_id=variation_id, + query=query, + query_type=query_type, + response_time=response_time, + confidence_score=confidence_score, + citation_count=citation_count, + answer_length=answer_length, + technical_accuracy=technical_accuracy, + user_satisfaction=user_satisfaction, + metadata=metadata or {} + ) + + self.test_results.append(result) + self._save_test_result(result) + + self.logger.info(f"Recorded test result for variation: {variation_id}") + + def analyze_variations( + self, + variation_a: str, + variation_b: str, + metric: OptimizationMetric, + min_samples: int = 10 + ) -> ComparisonResult: + """ + Analyze performance difference between two variations. + + Args: + variation_a: First variation ID + variation_b: Second variation ID + metric: Metric to compare + min_samples: Minimum samples required for analysis + + Returns: + Comparison result with statistical analysis + """ + # Filter results for each variation + results_a = [r for r in self.test_results if r.variation_id == variation_a] + results_b = [r for r in self.test_results if r.variation_id == variation_b] + + if len(results_a) < min_samples or len(results_b) < min_samples: + raise ValueError(f"Insufficient samples. 
Need at least {min_samples} for each variation.") + + # Extract metric values + values_a = self._extract_metric_values(results_a, metric) + values_b = self._extract_metric_values(results_b, metric) + + # Calculate statistics + mean_a = np.mean(values_a) + mean_b = np.mean(values_b) + + # Calculate improvement percentage + improvement = ((mean_b - mean_a) / mean_a) * 100 + + # Simple t-test (normally would use scipy.stats.ttest_ind) + # For now, using basic statistical comparison + std_a = np.std(values_a) + std_b = np.std(values_b) + n_a = len(values_a) + n_b = len(values_b) + + # Basic p-value estimation (simplified) + pooled_std = np.sqrt(((n_a - 1) * std_a**2 + (n_b - 1) * std_b**2) / (n_a + n_b - 2)) + t_stat = (mean_b - mean_a) / (pooled_std * np.sqrt(1/n_a + 1/n_b)) + p_value = 2 * (1 - abs(t_stat) / (abs(t_stat) + 1)) # Rough approximation + + # Confidence interval (simplified) + margin_of_error = 1.96 * pooled_std * np.sqrt(1/n_a + 1/n_b) + ci_lower = (mean_b - mean_a) - margin_of_error + ci_upper = (mean_b - mean_a) + margin_of_error + + # Determine significance + is_significant = p_value < 0.05 + + # Generate recommendation + if is_significant: + if improvement > 0: + recommendation = f"Variation B shows significant improvement ({improvement:.1f}%). Recommend adopting variation B." + else: + recommendation = f"Variation A shows significant improvement ({-improvement:.1f}%). Recommend keeping variation A." + else: + recommendation = f"No significant difference detected (p={p_value:.3f}). More data needed or variations are equivalent." + + return ComparisonResult( + variation_a=variation_a, + variation_b=variation_b, + metric=metric, + a_mean=mean_a, + b_mean=mean_b, + improvement_percent=improvement, + p_value=p_value, + confidence_interval=(ci_lower, ci_upper), + is_significant=is_significant, + sample_size=min(n_a, n_b), + recommendation=recommendation + ) + + def get_best_variation( + self, + query_type: QueryType, + metric: OptimizationMetric, + min_samples: int = 10 + ) -> Optional[str]: + """ + Get the best performing variation for a query type and metric. + + Args: + query_type: Type of query + metric: Metric to optimize for + min_samples: Minimum samples required + + Returns: + Best variation ID or None if insufficient data + """ + # Filter results by query type + relevant_results = [r for r in self.test_results if r.query_type == query_type] + + # Group by variation + variation_performance = defaultdict(list) + for result in relevant_results: + variation_performance[result.variation_id].append(result) + + # Calculate mean performance for each variation + best_variation = None + best_score = None + + for variation_id, results in variation_performance.items(): + if len(results) >= min_samples: + values = self._extract_metric_values(results, metric) + mean_score = np.mean(values) + + if best_score is None or mean_score > best_score: + best_score = mean_score + best_variation = variation_id + + return best_variation + + def generate_optimization_report( + self, + experiment_id: str, + output_file: Optional[str] = None + ) -> Dict[str, Any]: + """ + Generate a comprehensive optimization report. 
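+
+        Example (illustrative): optimizer.generate_optimization_report(experiment_id, "report.json")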
+ + Args: + experiment_id: Experiment to analyze + output_file: Optional file to save report + + Returns: + Report dictionary + """ + if experiment_id not in self.active_experiments: + raise ValueError(f"Experiment {experiment_id} not found") + + variation_ids = self.active_experiments[experiment_id] + experiment_results = [r for r in self.test_results if r.variation_id in variation_ids] + + if not experiment_results: + raise ValueError(f"No results found for experiment {experiment_id}") + + # Analyze each metric + metrics = [ + OptimizationMetric.RESPONSE_TIME, + OptimizationMetric.CONFIDENCE_SCORE, + OptimizationMetric.CITATION_COUNT, + OptimizationMetric.ANSWER_LENGTH + ] + + report = { + "experiment_id": experiment_id, + "variations_tested": len(variation_ids), + "total_tests": len(experiment_results), + "analysis_date": time.time(), + "metric_analysis": {}, + "recommendations": [] + } + + # Analyze each metric across variations + for metric in metrics: + metric_data = {} + for variation_id in variation_ids: + var_results = [r for r in experiment_results if r.variation_id == variation_id] + if var_results: + values = self._extract_metric_values(var_results, metric) + metric_data[variation_id] = { + "mean": np.mean(values), + "std": np.std(values), + "count": len(values) + } + + report["metric_analysis"][metric.value] = metric_data + + # Generate recommendations + for metric in metrics: + best_variation = self.get_best_variation( + query_type=QueryType.GENERAL, # Could be made more specific + metric=metric, + min_samples=5 + ) + if best_variation: + report["recommendations"].append({ + "metric": metric.value, + "best_variation": best_variation, + "variation_name": self.variations[best_variation].name + }) + + # Save report if requested + if output_file: + with open(output_file, 'w') as f: + json.dump(report, f, indent=2) + + return report + + def _extract_metric_values(self, results: List[TestResult], metric: OptimizationMetric) -> List[float]: + """Extract metric values from test results.""" + values = [] + for result in results: + if metric == OptimizationMetric.RESPONSE_TIME: + values.append(result.response_time) + elif metric == OptimizationMetric.CONFIDENCE_SCORE: + values.append(result.confidence_score) + elif metric == OptimizationMetric.CITATION_COUNT: + values.append(float(result.citation_count)) + elif metric == OptimizationMetric.ANSWER_LENGTH: + values.append(float(result.answer_length)) + elif metric == OptimizationMetric.TECHNICAL_ACCURACY and result.technical_accuracy is not None: + values.append(result.technical_accuracy) + elif metric == OptimizationMetric.USER_SATISFACTION and result.user_satisfaction is not None: + values.append(result.user_satisfaction) + + return values + + def _load_experiments(self) -> None: + """Load existing experiments from disk.""" + if not self.experiment_dir.exists(): + return + + for file_path in self.experiment_dir.glob("*.json"): + if file_path.name.startswith("exp_"): + with open(file_path, 'r') as f: + config = json.load(f) + self.active_experiments[config["experiment_id"]] = config["variation_ids"] + + # Load variations and results + for file_path in self.experiment_dir.glob("variation_*.json"): + with open(file_path, 'r') as f: + var_data = json.load(f) + variation = PromptVariation(**var_data) + self.variations[variation.variation_id] = variation + + for file_path in self.experiment_dir.glob("result_*.json"): + with open(file_path, 'r') as f: + result_data = json.load(f) + result = TestResult(**result_data) + 
self.test_results.append(result) + + def _save_variation(self, variation: PromptVariation) -> None: + """Save variation to disk.""" + file_path = self.experiment_dir / f"variation_{variation.variation_id}.json" + var_dict = asdict(variation) + + # Convert template to dict + var_dict["template"] = asdict(variation.template) + var_dict["query_type"] = variation.query_type.value + + with open(file_path, 'w') as f: + json.dump(var_dict, f, indent=2) + + def _save_test_result(self, result: TestResult) -> None: + """Save test result to disk.""" + file_path = self.experiment_dir / f"result_{int(result.timestamp)}.json" + result_dict = asdict(result) + result_dict["query_type"] = result.query_type.value + + with open(file_path, 'w') as f: + json.dump(result_dict, f, indent=2) + + +# Example usage +if __name__ == "__main__": + # Initialize optimizer + optimizer = PromptOptimizer() + + # Create temperature variations for implementation queries + temp_variations = optimizer.create_temperature_variations( + base_query_type=QueryType.IMPLEMENTATION, + temperatures=[0.3, 0.7] + ) + + # Create length variations for definition queries + length_variations = optimizer.create_length_variations( + base_query_type=QueryType.DEFINITION, + length_styles=["concise", "detailed"] + ) + + # Setup experiment + test_queries = [ + "How do I implement a timer interrupt in RISC-V?", + "What is the difference between machine mode and user mode?", + "Configure GPIO pins for input/output operations" + ] + + experiment_id = optimizer.setup_experiment( + experiment_name="temperature_vs_length", + variation_ids=temp_variations + length_variations, + test_queries=test_queries + ) + + print(f"Created experiment: {experiment_id}") + print(f"Variations: {len(temp_variations + length_variations)}") + print(f"Test queries: {len(test_queries)}") \ No newline at end of file diff --git a/shared_utils/generation/prompt_templates.py b/shared_utils/generation/prompt_templates.py new file mode 100644 index 0000000000000000000000000000000000000000..f4b9db9d2d96045c97110d6a9433f991883d11d4 --- /dev/null +++ b/shared_utils/generation/prompt_templates.py @@ -0,0 +1,520 @@ +""" +Prompt templates optimized for technical documentation Q&A. + +This module provides specialized prompt templates for different types of +technical queries, with a focus on embedded systems and AI documentation. +""" + +from enum import Enum +from typing import Dict, List, Optional +from dataclasses import dataclass + + +class QueryType(Enum): + """Types of technical queries.""" + DEFINITION = "definition" + IMPLEMENTATION = "implementation" + COMPARISON = "comparison" + TROUBLESHOOTING = "troubleshooting" + SPECIFICATION = "specification" + CODE_EXAMPLE = "code_example" + HARDWARE_CONSTRAINT = "hardware_constraint" + GENERAL = "general" + + +@dataclass +class PromptTemplate: + """Represents a prompt template with its components.""" + system_prompt: str + context_format: str + query_format: str + answer_guidelines: str + few_shot_examples: Optional[List[str]] = None + + +class TechnicalPromptTemplates: + """ + Collection of prompt templates optimized for technical documentation. 
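+
+    Typical usage (illustrative; this is how the Ollama generator calls it):
+        prompts = TechnicalPromptTemplates.format_prompt_with_template(query=query, context=context)
+        # prompts["system"] and prompts["user"] are then passed to the LLM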
+ + Features: + - Domain-specific templates for embedded systems and AI + - Structured output formats + - Citation requirements + - Technical accuracy emphasis + """ + + @staticmethod + def get_base_system_prompt() -> str: + """Get the base system prompt for technical documentation.""" + return """You are an expert technical documentation assistant specializing in embedded systems, +RISC-V architecture, RTOS, and embedded AI/ML. Your role is to provide accurate, detailed +technical answers based strictly on the provided context. + +Key responsibilities: +1. Answer questions using ONLY information from the provided context +2. Include precise citations using [chunk_X] notation for every claim +3. Maintain technical accuracy and use correct terminology +4. Format code snippets and technical specifications properly +5. Clearly state when information is not available in the context +6. Consider hardware constraints and embedded system limitations when relevant + +Write naturally and conversationally. Avoid repetitive phrases and numbered lists unless specifically requested. Never make up information. If the context doesn't contain the answer, say so explicitly.""" + + @staticmethod + def get_definition_template() -> PromptTemplate: + """Template for definition/explanation queries.""" + return PromptTemplate( + system_prompt=TechnicalPromptTemplates.get_base_system_prompt() + """ + +For definition queries, focus on: +- Clear, concise technical definitions +- Related concepts and terminology +- Technical context and applications +- Any acronym expansions""", + + context_format="""Technical Documentation Context: +{context}""", + + query_format="""Define or explain: {query} + +Provide a comprehensive technical definition with proper citations.""", + + answer_guidelines="""Provide a clear, comprehensive answer that directly addresses the question. Include relevant technical details and cite your sources using [chunk_X] notation. Make your response natural and conversational while maintaining technical accuracy.""", + + few_shot_examples=[ + """Q: What is RISC-V? +A: RISC-V is an open-source instruction set architecture (ISA) based on established reduced instruction set computing (RISC) principles [chunk_1]. Unlike proprietary ISAs, RISC-V is freely available under open-source licenses, allowing anyone to implement RISC-V processors without licensing fees [chunk_2]. The architecture supports 32-bit, 64-bit, and 128-bit address spaces, with a modular design that includes base integer instruction sets and optional extensions [chunk_3]. RISC-V stands for "RISC-Five" referring to the fifth generation of RISC architecture developed at UC Berkeley.""", + + """Q: What is FreeRTOS? +A: FreeRTOS is a real-time operating system kernel for embedded devices that provides multitasking capabilities for microcontrollers and small microprocessors [chunk_1]. It implements a preemptive scheduler with priority-based task scheduling, ensuring deterministic real-time behavior [chunk_2]. FreeRTOS includes core features like task management, semaphores, queues, and memory management while maintaining a small footprint typically under 10KB [chunk_3]. 
The "Free" in FreeRTOS refers to both its open-source license and the fact that it's free of charge for commercial use.""" + ] + ) + + @staticmethod + def get_implementation_template() -> PromptTemplate: + """Template for implementation/how-to queries.""" + return PromptTemplate( + system_prompt=TechnicalPromptTemplates.get_base_system_prompt() + """ + +For implementation queries, focus on: +- Step-by-step instructions +- Required components or dependencies +- Code examples with proper formatting +- Hardware/software requirements +- Common pitfalls or considerations""", + + context_format="""Implementation Documentation: +{context}""", + + query_format="""Implementation question: {query} + +Provide detailed implementation guidance with code examples where available.""", + + answer_guidelines="""Structure your answer as: +1. Overview of the implementation approach [chunk_X] +2. Prerequisites and requirements [chunk_Y] +3. Step-by-step implementation: + - Step 1: Description [chunk_Z] + - Step 2: Description [chunk_W] +4. Code example (if available): +```language +// Code here +``` +5. Important considerations or warnings""", + + few_shot_examples=[ + """Q: How do I configure GPIO pins for output in RISC-V? +A: GPIO configuration for output requires setting the pin direction and initial value through memory-mapped registers [chunk_1]. First, identify the GPIO base address for your specific RISC-V implementation (commonly 0x10060000 for SiFive cores) [chunk_2]. + +Steps: +1. Set pin direction to output by writing to GPIO_OUTPUT_EN register [chunk_3] +2. Configure initial output value using GPIO_OUTPUT_VAL register [chunk_4] + +```c +#define GPIO_BASE 0x10060000 +#define GPIO_OUTPUT_EN (GPIO_BASE + 0x08) +#define GPIO_OUTPUT_VAL (GPIO_BASE + 0x0C) + +// Configure pin 5 as output +volatile uint32_t *gpio_en = (uint32_t*)GPIO_OUTPUT_EN; +volatile uint32_t *gpio_val = (uint32_t*)GPIO_OUTPUT_VAL; + +*gpio_en |= (1 << 5); // Enable output on pin 5 +*gpio_val |= (1 << 5); // Set pin 5 high +``` + +Important: Always check your board's documentation for the correct GPIO base address and pin mapping [chunk_5].""", + + """Q: How to implement a basic timer interrupt in RISC-V? +A: Timer interrupts in RISC-V use the machine timer (mtime) and timer compare (mtimecmp) registers for precise timing control [chunk_1]. The implementation requires configuring the timer hardware, setting up the interrupt handler, and enabling machine timer interrupts [chunk_2]. + +Prerequisites: +- RISC-V processor with timer support +- Access to machine-level CSRs +- Understanding of memory-mapped timer registers [chunk_3] + +Implementation steps: +1. Set up timer compare value in mtimecmp register [chunk_4] +2. Enable machine timer interrupt in mie CSR [chunk_5] +3. 
Configure interrupt handler in mtvec CSR [chunk_6] + +```c +#define MTIME_BASE 0x0200bff8 +#define MTIMECMP_BASE 0x02004000 + +void setup_timer_interrupt(uint64_t interval) { + uint64_t *mtime = (uint64_t*)MTIME_BASE; + uint64_t *mtimecmp = (uint64_t*)MTIMECMP_BASE; + + // Set next interrupt time + *mtimecmp = *mtime + interval; + + // Enable machine timer interrupt + asm volatile ("csrs mie, %0" : : "r"(0x80)); + + // Enable global interrupts + asm volatile ("csrs mstatus, %0" : : "r"(0x8)); +} +``` + +Critical considerations: Timer registers are 64-bit and must be accessed atomically on 32-bit systems [chunk_7].""" + ] + ) + + @staticmethod + def get_comparison_template() -> PromptTemplate: + """Template for comparison queries.""" + return PromptTemplate( + system_prompt=TechnicalPromptTemplates.get_base_system_prompt() + """ + +For comparison queries, focus on: +- Clear distinction between compared items +- Technical specifications and differences +- Use cases for each option +- Performance or resource implications +- Recommendations based on context""", + + context_format="""Technical Comparison Context: +{context}""", + + query_format="""Compare: {query} + +Provide a detailed technical comparison with clear distinctions.""", + + answer_guidelines="""Structure your answer as: +1. Overview of items being compared [chunk_X] +2. Key differences: + - Feature A: Item1 vs Item2 [chunk_Y] + - Feature B: Item1 vs Item2 [chunk_Z] +3. Technical specifications comparison +4. Use case recommendations +5. Performance/resource considerations""" + ) + + @staticmethod + def get_specification_template() -> PromptTemplate: + """Template for specification/parameter queries.""" + return PromptTemplate( + system_prompt=TechnicalPromptTemplates.get_base_system_prompt() + """ + +For specification queries, focus on: +- Exact technical specifications +- Parameter ranges and limits +- Units and measurements +- Compliance with standards +- Version-specific information""", + + context_format="""Technical Specifications: +{context}""", + + query_format="""Specification query: {query} + +Provide precise technical specifications with all relevant parameters.""", + + answer_guidelines="""Structure your answer as: +1. Specification overview [chunk_X] +2. Detailed parameters: + - Parameter 1: value (unit) [chunk_Y] + - Parameter 2: value (unit) [chunk_Z] +3. Operating conditions or constraints +4. Compliance/standards information +5. Version or variant notes""" + ) + + @staticmethod + def get_code_example_template() -> PromptTemplate: + """Template for code example queries.""" + return PromptTemplate( + system_prompt=TechnicalPromptTemplates.get_base_system_prompt() + """ + +For code example queries, focus on: +- Complete, runnable code examples +- Proper syntax highlighting +- Clear comments and documentation +- Error handling +- Best practices for embedded systems""", + + context_format="""Code Examples and Documentation: +{context}""", + + query_format="""Code example request: {query} + +Provide working code examples with explanations.""", + + answer_guidelines="""Structure your answer as: +1. Purpose and overview [chunk_X] +2. Required includes/imports [chunk_Y] +3. Complete code example: +```c +// Or appropriate language +#include + +// Function or code implementation +``` +4. Key points explained [chunk_Z] +5. 
Common variations or modifications""" + ) + + @staticmethod + def get_hardware_constraint_template() -> PromptTemplate: + """Template for hardware constraint queries.""" + return PromptTemplate( + system_prompt=TechnicalPromptTemplates.get_base_system_prompt() + """ + +For hardware constraint queries, focus on: +- Memory requirements (RAM, Flash) +- Processing power needs (MIPS, frequency) +- Power consumption +- I/O requirements +- Real-time constraints +- Temperature/environmental limits""", + + context_format="""Hardware Specifications and Constraints: +{context}""", + + query_format="""Hardware constraint question: {query} + +Analyze feasibility and constraints for embedded deployment.""", + + answer_guidelines="""Structure your answer as: +1. Hardware requirements summary [chunk_X] +2. Detailed constraints: + - Memory: RAM/Flash requirements [chunk_Y] + - Processing: CPU/frequency needs [chunk_Z] + - Power: Consumption estimates [chunk_W] +3. Feasibility assessment +4. Optimization suggestions +5. Alternative approaches if constraints are exceeded""" + ) + + @staticmethod + def get_troubleshooting_template() -> PromptTemplate: + """Template for troubleshooting queries.""" + return PromptTemplate( + system_prompt=TechnicalPromptTemplates.get_base_system_prompt() + """ + +For troubleshooting queries, focus on: +- Common error causes +- Diagnostic steps +- Solution procedures +- Preventive measures +- Debug techniques for embedded systems""", + + context_format="""Troubleshooting Documentation: +{context}""", + + query_format="""Troubleshooting issue: {query} + +Provide diagnostic steps and solutions.""", + + answer_guidelines="""Structure your answer as: +1. Problem identification [chunk_X] +2. Common causes: + - Cause 1: Description [chunk_Y] + - Cause 2: Description [chunk_Z] +3. Diagnostic steps: + - Step 1: Check... [chunk_W] + - Step 2: Verify... [chunk_V] +4. Solutions for each cause +5. Prevention recommendations""" + ) + + @staticmethod + def get_general_template() -> PromptTemplate: + """Default template for general queries.""" + return PromptTemplate( + system_prompt=TechnicalPromptTemplates.get_base_system_prompt(), + + context_format="""Technical Documentation: +{context}""", + + query_format="""Question: {query} + +Provide a comprehensive technical answer based on the documentation.""", + + answer_guidelines="""Provide a clear, comprehensive answer that directly addresses the question. Include relevant technical details and cite your sources using [chunk_X] notation. Write naturally and conversationally while maintaining technical accuracy. Acknowledge any limitations in available information.""" + ) + + @staticmethod + def detect_query_type(query: str) -> QueryType: + """ + Detect the type of query based on keywords and patterns. 
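+
+        Example (illustrative): "How do I configure GPIO pins?" -> QueryType.IMPLEMENTATION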
+ + Args: + query: User's question + + Returns: + Detected QueryType + """ + query_lower = query.lower() + + # Definition keywords + if any(keyword in query_lower for keyword in [ + "what is", "what are", "define", "definition", "meaning of", "explain what" + ]): + return QueryType.DEFINITION + + # Implementation keywords + if any(keyword in query_lower for keyword in [ + "how to", "how do i", "implement", "setup", "configure", "install" + ]): + return QueryType.IMPLEMENTATION + + # Comparison keywords + if any(keyword in query_lower for keyword in [ + "difference between", "compare", "vs", "versus", "better than", "which is" + ]): + return QueryType.COMPARISON + + # Specification keywords + if any(keyword in query_lower for keyword in [ + "specification", "specs", "parameters", "limits", "range", "maximum", "minimum" + ]): + return QueryType.SPECIFICATION + + # Code example keywords + if any(keyword in query_lower for keyword in [ + "example", "code", "snippet", "sample", "demo", "show me" + ]): + return QueryType.CODE_EXAMPLE + + # Hardware constraint keywords + if any(keyword in query_lower for keyword in [ + "memory", "ram", "flash", "mcu", "constraint", "fit on", "run on", "power consumption" + ]): + return QueryType.HARDWARE_CONSTRAINT + + # Troubleshooting keywords + if any(keyword in query_lower for keyword in [ + "error", "problem", "issue", "debug", "troubleshoot", "fix", "solve", "not working" + ]): + return QueryType.TROUBLESHOOTING + + return QueryType.GENERAL + + @staticmethod + def get_template_for_query(query: str) -> PromptTemplate: + """ + Get the appropriate template based on query type. + + Args: + query: User's question + + Returns: + Appropriate PromptTemplate + """ + query_type = TechnicalPromptTemplates.detect_query_type(query) + + template_map = { + QueryType.DEFINITION: TechnicalPromptTemplates.get_definition_template, + QueryType.IMPLEMENTATION: TechnicalPromptTemplates.get_implementation_template, + QueryType.COMPARISON: TechnicalPromptTemplates.get_comparison_template, + QueryType.SPECIFICATION: TechnicalPromptTemplates.get_specification_template, + QueryType.CODE_EXAMPLE: TechnicalPromptTemplates.get_code_example_template, + QueryType.HARDWARE_CONSTRAINT: TechnicalPromptTemplates.get_hardware_constraint_template, + QueryType.TROUBLESHOOTING: TechnicalPromptTemplates.get_troubleshooting_template, + QueryType.GENERAL: TechnicalPromptTemplates.get_general_template + } + + return template_map[query_type]() + + @staticmethod + def format_prompt_with_template( + query: str, + context: str, + template: Optional[PromptTemplate] = None, + include_few_shot: bool = True + ) -> Dict[str, str]: + """ + Format a complete prompt using the appropriate template. 
+ + Args: + query: User's question + context: Retrieved context chunks + template: Optional specific template (auto-detected if None) + include_few_shot: Whether to include few-shot examples + + Returns: + Dict with 'system' and 'user' prompts + """ + if template is None: + template = TechnicalPromptTemplates.get_template_for_query(query) + + # Format the context + formatted_context = template.context_format.format(context=context) + + # Format the query + formatted_query = template.query_format.format(query=query) + + # Build user prompt with optional few-shot examples + user_prompt_parts = [] + + # Add few-shot examples if available and requested + if include_few_shot and template.few_shot_examples: + user_prompt_parts.append("Here are some examples of how to answer similar questions:") + user_prompt_parts.append("\n\n".join(template.few_shot_examples)) + user_prompt_parts.append("\nNow answer the following question using the same format:") + + user_prompt_parts.extend([ + formatted_context, + formatted_query, + template.answer_guidelines + ]) + + user_prompt = "\n\n".join(user_prompt_parts) + + return { + "system": template.system_prompt, + "user": user_prompt + } + + +# Example usage and testing +if __name__ == "__main__": + # Test query type detection + test_queries = [ + "What is RISC-V?", + "How do I implement a timer interrupt?", + "What's the difference between FreeRTOS and Zephyr?", + "What are the memory specifications for STM32F4?", + "Show me an example of GPIO configuration", + "Can this model run on an MCU with 256KB RAM?", + "Debug error: undefined reference to main" + ] + + for query in test_queries: + query_type = TechnicalPromptTemplates.detect_query_type(query) + print(f"Query: '{query}' -> Type: {query_type.value}") + + # Example prompt formatting + example_context = "RISC-V is an open instruction set architecture..." + example_query = "What is RISC-V?" + + formatted = TechnicalPromptTemplates.format_prompt_with_template( + query=example_query, + context=example_context + ) + + print("\nFormatted prompt example:") + print("System:", formatted["system"][:100], "...") + print("User:", formatted["user"][:200], "...") \ No newline at end of file diff --git a/shared_utils/query_processing/__init__.py b/shared_utils/query_processing/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..482ba426819b9ee75762815a83bcd51232a7ae24 --- /dev/null +++ b/shared_utils/query_processing/__init__.py @@ -0,0 +1,8 @@ +""" +Query processing utilities for intelligent RAG systems. +Provides query enhancement, analysis, and optimization capabilities. 
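+
+Example (illustrative):
+    from shared_utils.query_processing import QueryEnhancer
+    enhancer = QueryEnhancer()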
+""" + +from .query_enhancer import QueryEnhancer + +__all__ = ['QueryEnhancer'] \ No newline at end of file diff --git a/shared_utils/query_processing/__pycache__/__init__.cpython-312.pyc b/shared_utils/query_processing/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e810d05d4617c4e2d512949b9241f74321b05ed0 Binary files /dev/null and b/shared_utils/query_processing/__pycache__/__init__.cpython-312.pyc differ diff --git a/shared_utils/query_processing/__pycache__/query_enhancer.cpython-312.pyc b/shared_utils/query_processing/__pycache__/query_enhancer.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9525380f13366ea46714a2aee7ebeedb125330b2 Binary files /dev/null and b/shared_utils/query_processing/__pycache__/query_enhancer.cpython-312.pyc differ diff --git a/shared_utils/query_processing/query_enhancer.py b/shared_utils/query_processing/query_enhancer.py new file mode 100644 index 0000000000000000000000000000000000000000..233591b67ce213046586dea00ecd7135cda9a842 --- /dev/null +++ b/shared_utils/query_processing/query_enhancer.py @@ -0,0 +1,644 @@ +""" +Intelligent query processing for technical documentation RAG. + +Provides adaptive query enhancement through technical term expansion, +acronym handling, and intelligent hybrid weighting optimization. +""" + +from typing import Dict, List, Any, Tuple, Set, Optional +import re +from collections import defaultdict +import time + + +class QueryEnhancer: + """ + Intelligent query processing for technical documentation RAG. + + Analyzes query characteristics and enhances retrieval through: + - Technical synonym expansion + - Acronym detection and expansion + - Adaptive hybrid weighting based on query type + - Query complexity analysis for optimal retrieval strategy + + Optimized for embedded systems and technical documentation domains. 
+ + Performance: <10ms query enhancement, improves retrieval relevance by >10% + """ + + def __init__(self): + """Initialize QueryEnhancer with technical domain knowledge.""" + + # Technical vocabulary dictionary organized by domain + self.technical_synonyms = { + # Processor terminology + 'cpu': ['processor', 'microprocessor', 'central processing unit'], + 'mcu': ['microcontroller', 'microcontroller unit', 'embedded processor'], + 'core': ['processor core', 'cpu core', 'execution unit'], + 'alu': ['arithmetic logic unit', 'arithmetic unit'], + + # Memory terminology + 'memory': ['ram', 'storage', 'buffer', 'cache'], + 'flash': ['non-volatile memory', 'program memory', 'code storage'], + 'sram': ['static ram', 'static memory', 'cache memory'], + 'dram': ['dynamic ram', 'dynamic memory'], + 'cache': ['buffer', 'temporary storage', 'fast memory'], + + # Architecture terminology + 'risc-v': ['riscv', 'risc v', 'open isa', 'open instruction set'], + 'arm': ['advanced risc machine', 'acorn risc machine'], + 'isa': ['instruction set architecture', 'instruction set'], + 'architecture': ['design', 'structure', 'organization'], + + # Embedded systems terminology + 'rtos': ['real-time operating system', 'real-time os'], + 'interrupt': ['isr', 'interrupt service routine', 'exception handler'], + 'peripheral': ['hardware peripheral', 'external device', 'io device'], + 'firmware': ['embedded software', 'system software'], + 'bootloader': ['boot code', 'initialization code'], + + # Performance terminology + 'latency': ['delay', 'response time', 'execution time'], + 'throughput': ['bandwidth', 'data rate', 'performance'], + 'power': ['power consumption', 'energy usage', 'battery life'], + 'optimization': ['improvement', 'enhancement', 'tuning'], + + # Communication protocols + 'uart': ['serial communication', 'async serial'], + 'spi': ['serial peripheral interface', 'synchronous serial'], + 'i2c': ['inter-integrated circuit', 'two-wire interface'], + 'usb': ['universal serial bus'], + + # Development terminology + 'debug': ['debugging', 'troubleshooting', 'testing'], + 'compile': ['compilation', 'build', 'assembly'], + 'programming': ['coding', 'development', 'implementation'] + } + + # Comprehensive acronym expansions for embedded/technical domains + self.acronym_expansions = { + # Processor & Architecture + 'CPU': 'Central Processing Unit', + 'MCU': 'Microcontroller Unit', + 'MPU': 'Microprocessor Unit', + 'DSP': 'Digital Signal Processor', + 'GPU': 'Graphics Processing Unit', + 'ALU': 'Arithmetic Logic Unit', + 'FPU': 'Floating Point Unit', + 'MMU': 'Memory Management Unit', + 'ISA': 'Instruction Set Architecture', + 'RISC': 'Reduced Instruction Set Computer', + 'CISC': 'Complex Instruction Set Computer', + + # Memory & Storage + 'RAM': 'Random Access Memory', + 'ROM': 'Read Only Memory', + 'EEPROM': 'Electrically Erasable Programmable ROM', + 'SRAM': 'Static Random Access Memory', + 'DRAM': 'Dynamic Random Access Memory', + 'FRAM': 'Ferroelectric Random Access Memory', + 'MRAM': 'Magnetoresistive Random Access Memory', + 'DMA': 'Direct Memory Access', + + # Operating Systems & Software + 'RTOS': 'Real-Time Operating System', + 'OS': 'Operating System', + 'API': 'Application Programming Interface', + 'SDK': 'Software Development Kit', + 'IDE': 'Integrated Development Environment', + 'HAL': 'Hardware Abstraction Layer', + 'BSP': 'Board Support Package', + + # Interrupts & Exceptions + 'ISR': 'Interrupt Service Routine', + 'IRQ': 'Interrupt Request', + 'NMI': 'Non-Maskable Interrupt', + 'NVIC': 'Nested 
Vectored Interrupt Controller', + + # Communication Protocols + 'UART': 'Universal Asynchronous Receiver Transmitter', + 'USART': 'Universal Synchronous Asynchronous Receiver Transmitter', + 'SPI': 'Serial Peripheral Interface', + 'I2C': 'Inter-Integrated Circuit', + 'CAN': 'Controller Area Network', + 'USB': 'Universal Serial Bus', + 'TCP': 'Transmission Control Protocol', + 'UDP': 'User Datagram Protocol', + 'HTTP': 'HyperText Transfer Protocol', + 'MQTT': 'Message Queuing Telemetry Transport', + + # Analog & Digital + 'ADC': 'Analog to Digital Converter', + 'DAC': 'Digital to Analog Converter', + 'PWM': 'Pulse Width Modulation', + 'GPIO': 'General Purpose Input Output', + 'JTAG': 'Joint Test Action Group', + 'SWD': 'Serial Wire Debug', + + # Power & Clock + 'PLL': 'Phase Locked Loop', + 'VCO': 'Voltage Controlled Oscillator', + 'LDO': 'Low Dropout Regulator', + 'PMU': 'Power Management Unit', + 'RTC': 'Real Time Clock', + + # Standards & Organizations + 'IEEE': 'Institute of Electrical and Electronics Engineers', + 'ISO': 'International Organization for Standardization', + 'ANSI': 'American National Standards Institute', + 'IEC': 'International Electrotechnical Commission' + } + + # Compile regex patterns for efficiency + self._acronym_pattern = re.compile(r'\b[A-Z]{2,}\b') + self._technical_term_pattern = re.compile(r'\b\w+(?:-\w+)*\b', re.IGNORECASE) + self._question_indicators = re.compile(r'\b(?:how|what|why|when|where|which|explain|describe|define)\b', re.IGNORECASE) + + # Question type classification keywords + self.question_type_keywords = { + 'conceptual': ['how', 'why', 'what', 'explain', 'describe', 'understand', 'concept', 'theory'], + 'technical': ['configure', 'implement', 'setup', 'install', 'code', 'program', 'register'], + 'procedural': ['steps', 'process', 'procedure', 'workflow', 'guide', 'tutorial'], + 'troubleshooting': ['error', 'problem', 'issue', 'debug', 'fix', 'solve', 'troubleshoot'] + } + + def analyze_query_characteristics(self, query: str) -> Dict[str, Any]: + """ + Analyze query to determine optimal processing strategy. 
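+
+         Worked example (illustrative; exact values depend on the built-in
+         vocabulary tables):
+
+             info = QueryEnhancer().analyze_query_characteristics(
+                 "How do I configure the UART ISR?"
+             )
+             # info['detected_acronyms']        -> ['UART', 'ISR']
+             # info['question_type']            -> 'mixed' ("how" is conceptual,
+             #                                     "configure" is technical)
+             # info['recommended_dense_weight'] -> below the 0.7 default, since
+             #                                     acronym-heavy queries favor
+             #                                     sparse keyword matching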
+ + Performs comprehensive analysis including: + - Technical term detection and counting + - Acronym presence identification + - Question type classification + - Complexity scoring based on multiple factors + - Optimal hybrid weight recommendation + + Args: + query: User input query string + + Returns: + Dictionary with comprehensive query analysis: + - technical_term_count: Number of domain-specific terms detected + - has_acronyms: Boolean indicating acronym presence + - question_type: 'conceptual', 'technical', 'procedural', 'mixed' + - complexity_score: Float 0-1 indicating query complexity + - recommended_dense_weight: Optimal weight for hybrid search + - detected_acronyms: List of acronyms found + - technical_terms: List of technical terms found + + Performance: <2ms for typical queries + """ + if not query or not query.strip(): + return { + 'technical_term_count': 0, + 'has_acronyms': False, + 'question_type': 'unknown', + 'complexity_score': 0.0, + 'recommended_dense_weight': 0.7, + 'detected_acronyms': [], + 'technical_terms': [] + } + + query_lower = query.lower() + words = query.split() + + # Detect acronyms + detected_acronyms = self._acronym_pattern.findall(query) + has_acronyms = len(detected_acronyms) > 0 + + # Detect technical terms + technical_terms = [] + technical_term_count = 0 + + for word in words: + word_clean = re.sub(r'[^\w\-]', '', word.lower()) + if word_clean in self.technical_synonyms: + technical_terms.append(word_clean) + technical_term_count += 1 + # Also check for compound technical terms like "risc-v" + elif any(term in word_clean for term in ['risc-v', 'arm', 'cpu', 'mcu']): + technical_terms.append(word_clean) + technical_term_count += 1 + + # Add acronyms to technical term count + for acronym in detected_acronyms: + if acronym in self.acronym_expansions: + technical_term_count += 1 + + # Determine question type + question_type = self._classify_question_type(query_lower) + + # Calculate complexity score (0-1) + complexity_factors = [ + len(words) / 20.0, # Word count factor (normalized to 20 words max) + technical_term_count / 5.0, # Technical density (normalized to 5 terms max) + len(detected_acronyms) / 3.0, # Acronym density (normalized to 3 acronyms max) + 1.0 if self._question_indicators.search(query) else 0.5, # Question complexity + ] + complexity_score = min(1.0, sum(complexity_factors) / len(complexity_factors)) + + # Determine recommended dense weight based on analysis + recommended_dense_weight = self._calculate_optimal_weight( + question_type, technical_term_count, has_acronyms, complexity_score + ) + + return { + 'technical_term_count': technical_term_count, + 'has_acronyms': has_acronyms, + 'question_type': question_type, + 'complexity_score': complexity_score, + 'recommended_dense_weight': recommended_dense_weight, + 'detected_acronyms': detected_acronyms, + 'technical_terms': technical_terms, + 'word_count': len(words), + 'has_question_indicators': bool(self._question_indicators.search(query)) + } + + def _classify_question_type(self, query_lower: str) -> str: + """Classify query into conceptual, technical, procedural, or mixed categories.""" + type_scores = defaultdict(int) + + for question_type, keywords in self.question_type_keywords.items(): + for keyword in keywords: + if keyword in query_lower: + type_scores[question_type] += 1 + + if not type_scores: + return 'mixed' + + # Return type with highest score, or 'mixed' if tie + max_score = max(type_scores.values()) + top_types = [t for t, s in type_scores.items() if s == max_score] + + 
return top_types[0] if len(top_types) == 1 else 'mixed' + + def _calculate_optimal_weight(self, question_type: str, tech_terms: int, + has_acronyms: bool, complexity: float) -> float: + """Calculate optimal dense weight based on query characteristics.""" + + # Base weights by question type + base_weights = { + 'technical': 0.3, # Favor sparse for technical precision + 'conceptual': 0.8, # Favor dense for conceptual understanding + 'procedural': 0.5, # Balanced for step-by-step queries + 'troubleshooting': 0.4, # Slight sparse favor for specific issues + 'mixed': 0.7, # Default balanced + 'unknown': 0.7 # Default balanced + } + + weight = base_weights.get(question_type, 0.7) + + # Adjust based on technical term density + if tech_terms > 2: + weight -= 0.2 # More technical → favor sparse + elif tech_terms == 0: + weight += 0.1 # Less technical → favor dense + + # Adjust based on acronym presence + if has_acronyms: + weight -= 0.1 # Acronyms → favor sparse for exact matching + + # Adjust based on complexity + if complexity > 0.8: + weight += 0.1 # High complexity → favor dense for understanding + elif complexity < 0.3: + weight -= 0.1 # Low complexity → favor sparse for precision + + # Ensure weight stays within valid bounds + return max(0.1, min(0.9, weight)) + + def expand_technical_terms(self, query: str, max_expansions: int = 1) -> str: + """ + Expand query with technical synonyms while preventing bloat. + + Conservative expansion strategy: + - Maximum 1 synonym per technical term by default + - Prioritizes most relevant/common synonyms + - Maintains semantic focus while improving recall + + Args: + query: Original user query + max_expansions: Maximum synonyms per term (default 1 for focus) + + Returns: + Conservatively enhanced query + + Example: + Input: "CPU performance optimization" + Output: "CPU processor performance optimization" + + Performance: <3ms for typical queries + """ + if not query or not query.strip(): + return query + + words = query.split() + + # Conservative expansion: only add most relevant synonym + expansion_candidates = [] + + for word in words: + word_clean = re.sub(r'[^\w\-]', '', word.lower()) + + # Check for direct synonym expansion + if word_clean in self.technical_synonyms: + synonyms = self.technical_synonyms[word_clean] + # Add only the first (most common) synonym + if synonyms and max_expansions > 0: + expansion_candidates.append(synonyms[0]) + + # Limit total expansion to prevent bloat + max_total_expansions = min(2, len(words) // 2) # At most 50% expansion + selected_expansions = expansion_candidates[:max_total_expansions] + + # Reconstruct with minimal expansion + if selected_expansions: + return ' '.join(words + selected_expansions) + else: + return query + + def detect_and_expand_acronyms(self, query: str, conservative: bool = True) -> str: + """ + Detect technical acronyms and add their expansions conservatively. 
+ + Conservative approach to prevent query bloat: + - Limits acronym expansions to most relevant ones + - Preserves original acronyms for exact matching + - Maintains query focus and performance + + Args: + query: Query potentially containing acronyms + conservative: If True, limits expansion to prevent bloat + + Returns: + Query with selective acronym expansions + + Example: + Input: "RTOS scheduling algorithm" + Output: "RTOS Real-Time Operating System scheduling algorithm" + + Performance: <2ms for typical queries + """ + if not query or not query.strip(): + return query + + # Find all acronyms in the query + acronyms = self._acronym_pattern.findall(query) + + if not acronyms: + return query + + # Conservative mode: limit expansions + if conservative and len(acronyms) > 2: + # Only expand first 2 acronyms to prevent bloat + acronyms = acronyms[:2] + + result = query + + # Expand selected acronyms + for acronym in acronyms: + if acronym in self.acronym_expansions: + expansion = self.acronym_expansions[acronym] + # Add expansion after the acronym (preserving original) + result = result.replace(acronym, f"{acronym} {expansion}", 1) + + return result + + def adaptive_hybrid_weighting(self, query: str) -> float: + """ + Determine optimal dense_weight based on query characteristics. + + Analyzes query to automatically determine the best balance between + dense semantic search and sparse keyword matching for optimal results. + + Strategy: + - Technical/exact queries → lower dense_weight (favor sparse/BM25) + - Conceptual questions → higher dense_weight (favor semantic) + - Mixed queries → balanced weighting based on complexity + + Args: + query: User query string + + Returns: + Float between 0.1 and 0.9 representing optimal dense_weight + + Performance: <2ms analysis time + """ + analysis = self.analyze_query_characteristics(query) + return analysis['recommended_dense_weight'] + + def enhance_query(self, query: str, conservative: bool = True) -> Dict[str, Any]: + """ + Comprehensive query enhancement with performance and quality focus. 
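+
+         Example (illustrative sketch):
+
+             result = QueryEnhancer().enhance_query("MCU flash programming")
+             result['enhanced_query']       # original terms plus at most a few synonyms
+             result['optimal_weight']       # recommended dense weight for hybrid search
+             result['enhancement_metadata']['expansion_ratio']   # capped at 2.5x by the quality check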
+ + Optimized enhancement strategy: + - Conservative expansion to maintain semantic focus + - Performance-first approach with minimal overhead + - Quality validation to ensure improvements + + Args: + query: Original user query + conservative: Use conservative expansion (recommended for production) + + Returns: + Dictionary containing: + - enhanced_query: Optimized enhanced query + - optimal_weight: Recommended dense weight + - analysis: Complete query analysis + - enhancement_metadata: Performance and quality metrics + + Performance: <5ms total enhancement time + """ + start_time = time.perf_counter() + + # Fast analysis + analysis = self.analyze_query_characteristics(query) + + # Conservative enhancement approach + if conservative: + enhanced_query = self.expand_technical_terms(query, max_expansions=1) + enhanced_query = self.detect_and_expand_acronyms(enhanced_query, conservative=True) + else: + # Legacy aggressive expansion + enhanced_query = self.expand_technical_terms(query, max_expansions=2) + enhanced_query = self.detect_and_expand_acronyms(enhanced_query, conservative=False) + + # Quality validation: prevent excessive bloat + expansion_ratio = len(enhanced_query.split()) / len(query.split()) if query.split() else 1.0 + if expansion_ratio > 2.5: # Limit to 2.5x expansion + # Fallback to minimal enhancement + enhanced_query = self.expand_technical_terms(query, max_expansions=0) + enhanced_query = self.detect_and_expand_acronyms(enhanced_query, conservative=True) + expansion_ratio = len(enhanced_query.split()) / len(query.split()) if query.split() else 1.0 + + # Calculate optimal weight + optimal_weight = analysis['recommended_dense_weight'] + + enhancement_time = time.perf_counter() - start_time + + return { + 'enhanced_query': enhanced_query, + 'optimal_weight': optimal_weight, + 'analysis': analysis, + 'enhancement_metadata': { + 'original_length': len(query.split()), + 'enhanced_length': len(enhanced_query.split()), + 'expansion_ratio': expansion_ratio, + 'processing_time_ms': enhancement_time * 1000, + 'techniques_applied': ['conservative_expansion', 'quality_validation', 'adaptive_weighting'], + 'conservative_mode': conservative + } + } + + def expand_technical_terms_with_vocabulary( + self, + query: str, + vocabulary_index: Optional['VocabularyIndex'] = None, + min_frequency: int = 3 + ) -> str: + """ + Expand query with vocabulary-aware synonym filtering. + + Only adds synonyms that exist in the document corpus with sufficient + frequency to ensure relevance and prevent query dilution. 
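+
+         Example (illustrative; VocabularyIndex lives in
+         shared_utils/retrieval/vocabulary_index.py, and corpus_chunks is a
+         placeholder for the chunk dicts produced by the document chunker):
+
+             vocab = VocabularyIndex()
+             vocab.build_from_chunks(corpus_chunks)   # dicts with a 'text' field
+             QueryEnhancer().expand_technical_terms_with_vocabulary(
+                 "cpu scheduling", vocabulary_index=vocab, min_frequency=3
+             )
+             # synonyms such as 'processor' are added only when they occur at
+             # least 3 times in the indexed corpus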
+ + Args: + query: Original query + vocabulary_index: Optional vocabulary index for filtering + min_frequency: Minimum term frequency required + + Returns: + Enhanced query with validated synonyms + + Performance: <2ms with vocabulary validation + """ + if not query or not query.strip(): + return query + + if vocabulary_index is None: + # Fallback to standard expansion + return self.expand_technical_terms(query, max_expansions=1) + + words = query.split() + expanded_terms = [] + + for word in words: + word_clean = re.sub(r'[^\w\-]', '', word.lower()) + + # Check for synonym expansion + if word_clean in self.technical_synonyms: + synonyms = self.technical_synonyms[word_clean] + + # Filter synonyms through vocabulary + valid_synonyms = vocabulary_index.filter_synonyms( + synonyms, + min_frequency=min_frequency + ) + + # Add only the best valid synonym + if valid_synonyms: + expanded_terms.append(valid_synonyms[0]) + + # Reconstruct query with validated expansions + if expanded_terms: + return ' '.join(words + expanded_terms) + else: + return query + + def enhance_query_with_vocabulary( + self, + query: str, + vocabulary_index: Optional['VocabularyIndex'] = None, + min_frequency: int = 3, + require_technical: bool = False + ) -> Dict[str, Any]: + """ + Enhanced query processing with vocabulary validation. + + Uses corpus vocabulary to ensure all expansions are relevant + and actually present in the documents. + + Args: + query: Original query + vocabulary_index: Vocabulary index for validation + min_frequency: Minimum term frequency + require_technical: Only expand with technical terms + + Returns: + Enhanced query with vocabulary-aware expansion + """ + start_time = time.perf_counter() + + # Perform analysis + analysis = self.analyze_query_characteristics(query) + + # Vocabulary-aware enhancement + if vocabulary_index: + # Technical term expansion with validation + enhanced_query = self.expand_technical_terms_with_vocabulary( + query, vocabulary_index, min_frequency + ) + + # Acronym expansion (already conservative) + enhanced_query = self.detect_and_expand_acronyms(enhanced_query, conservative=True) + + # Track vocabulary validation + validation_applied = True + + # Detect domain if available + detected_domain = vocabulary_index.detect_domain() + else: + # Fallback to standard enhancement + enhanced_query = self.expand_technical_terms(query, max_expansions=1) + enhanced_query = self.detect_and_expand_acronyms(enhanced_query, conservative=True) + validation_applied = False + detected_domain = 'unknown' + + # Calculate metrics + expansion_ratio = len(enhanced_query.split()) / len(query.split()) if query.split() else 1.0 + enhancement_time = time.perf_counter() - start_time + + return { + 'enhanced_query': enhanced_query, + 'optimal_weight': analysis['recommended_dense_weight'], + 'analysis': analysis, + 'enhancement_metadata': { + 'original_length': len(query.split()), + 'enhanced_length': len(enhanced_query.split()), + 'expansion_ratio': expansion_ratio, + 'processing_time_ms': enhancement_time * 1000, + 'techniques_applied': ['vocabulary_validation', 'conservative_expansion'], + 'vocabulary_validated': validation_applied, + 'detected_domain': detected_domain, + 'min_frequency_threshold': min_frequency + } + } + + def get_enhancement_stats(self) -> Dict[str, Any]: + """ + Get statistics about the enhancement system capabilities. 
+ + Returns: + Dictionary with system statistics and capabilities + """ + return { + 'technical_synonyms_count': len(self.technical_synonyms), + 'acronym_expansions_count': len(self.acronym_expansions), + 'supported_domains': [ + 'embedded_systems', 'processor_architecture', 'memory_systems', + 'communication_protocols', 'real_time_systems', 'power_management' + ], + 'question_types_supported': list(self.question_type_keywords.keys()), + 'weight_range': {'min': 0.1, 'max': 0.9, 'default': 0.7}, + 'performance_targets': { + 'enhancement_time_ms': '<10', + 'accuracy_improvement': '>10%', + 'memory_overhead': '<1MB' + }, + 'vocabulary_features': { + 'vocabulary_aware_expansion': True, + 'min_frequency_filtering': True, + 'domain_detection': True, + 'technical_term_priority': True + } + } \ No newline at end of file diff --git a/shared_utils/retrieval/__init__.py b/shared_utils/retrieval/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1b8c0a57b3117afdb42d8f790ae79018395c7e09 --- /dev/null +++ b/shared_utils/retrieval/__init__.py @@ -0,0 +1,8 @@ +""" +Retrieval utilities for hybrid RAG systems. +Combines dense semantic search with sparse keyword matching. +""" + +from .hybrid_search import HybridRetriever + +__all__ = ['HybridRetriever'] \ No newline at end of file diff --git a/shared_utils/retrieval/__pycache__/__init__.cpython-312.pyc b/shared_utils/retrieval/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ce20008140f9969e85a2b7e0cb550078c564dacc Binary files /dev/null and b/shared_utils/retrieval/__pycache__/__init__.cpython-312.pyc differ diff --git a/shared_utils/retrieval/__pycache__/hybrid_search.cpython-312.pyc b/shared_utils/retrieval/__pycache__/hybrid_search.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..94632c6fd554a283fc62a3539497705d41507ee2 Binary files /dev/null and b/shared_utils/retrieval/__pycache__/hybrid_search.cpython-312.pyc differ diff --git a/shared_utils/retrieval/__pycache__/vocabulary_index.cpython-312.pyc b/shared_utils/retrieval/__pycache__/vocabulary_index.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8e721138cb755c3b64e60ed7adf6fed0af1d42a1 Binary files /dev/null and b/shared_utils/retrieval/__pycache__/vocabulary_index.cpython-312.pyc differ diff --git a/shared_utils/retrieval/hybrid_search.py b/shared_utils/retrieval/hybrid_search.py new file mode 100644 index 0000000000000000000000000000000000000000..7a538e3b93475dcbdc3efd1225f66bc616c68b1c --- /dev/null +++ b/shared_utils/retrieval/hybrid_search.py @@ -0,0 +1,277 @@ +""" +Hybrid retrieval combining dense semantic search with sparse BM25 keyword matching. +Uses Reciprocal Rank Fusion (RRF) to combine results from both approaches. +""" + +from typing import List, Dict, Tuple, Optional +import numpy as np +from pathlib import Path +import sys + +# Add project root to Python path for imports +project_root = Path(__file__).parent.parent.parent / "project-1-technical-rag" +sys.path.append(str(project_root)) + +from src.sparse_retrieval import BM25SparseRetriever +from src.fusion import reciprocal_rank_fusion, adaptive_fusion +from shared_utils.embeddings.generator import generate_embeddings +import faiss + + +class HybridRetriever: + """ + Hybrid retrieval system combining dense semantic search with sparse BM25. + + Optimized for technical documentation where both semantic similarity + and exact keyword matching are important for retrieval quality. 
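+
+     Example (illustrative sketch; assumes chunk dicts with a 'text' field and,
+     optionally, a 'source' field used for source diversity):
+
+         chunks = [
+             {'text': 'RISC-V defines 32 integer registers.', 'source': 'riscv-spec.pdf'},
+             {'text': 'The mstatus register controls interrupts.', 'source': 'riscv-priv.pdf'},
+         ]
+         retriever = HybridRetriever(dense_weight=0.7)
+         retriever.index_documents(chunks)
+         for chunk_idx, score, chunk in retriever.search("RISC-V interrupt handling", top_k=2):
+             print(f"{score:.3f}  {chunk['text'][:60]}")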
+ + Performance: Sub-second search on 1000+ document corpus + """ + + def __init__( + self, + dense_weight: float = 0.7, + embedding_model: str = "sentence-transformers/multi-qa-MiniLM-L6-cos-v1", + use_mps: bool = True, + bm25_k1: float = 1.2, + bm25_b: float = 0.75, + rrf_k: int = 10 + ): + """ + Initialize hybrid retriever with dense and sparse components. + + Args: + dense_weight: Weight for semantic similarity in fusion (0.7 default) + embedding_model: Sentence transformer model name + use_mps: Use Apple Silicon MPS acceleration for embeddings + bm25_k1: BM25 term frequency saturation parameter + bm25_b: BM25 document length normalization parameter + rrf_k: Reciprocal Rank Fusion constant (1=strong rank preference, 2=moderate) + + Raises: + ValueError: If parameters are invalid + """ + if not 0 <= dense_weight <= 1: + raise ValueError("dense_weight must be between 0 and 1") + + self.dense_weight = dense_weight + self.embedding_model = embedding_model + self.use_mps = use_mps + self.rrf_k = rrf_k + + # Initialize sparse retriever + self.sparse_retriever = BM25SparseRetriever(k1=bm25_k1, b=bm25_b) + + # Dense retrieval components (initialized on first index) + self.dense_index: Optional[faiss.Index] = None + self.chunks: List[Dict] = [] + self.embeddings: Optional[np.ndarray] = None + + def index_documents(self, chunks: List[Dict]) -> None: + """ + Index documents for both dense and sparse retrieval. + + Args: + chunks: List of chunk dictionaries with 'text' field + + Raises: + ValueError: If chunks is empty or malformed + + Performance: ~100 chunks/second for complete indexing + """ + if not chunks: + raise ValueError("Cannot index empty chunk list") + + print(f"Indexing {len(chunks)} chunks for hybrid retrieval...") + + # Store chunks for retrieval + self.chunks = chunks + + # Index for sparse retrieval + print("Building BM25 sparse index...") + self.sparse_retriever.index_documents(chunks) + + # Index for dense retrieval + print("Building dense semantic index...") + texts = [chunk['text'] for chunk in chunks] + + # Generate embeddings + self.embeddings = generate_embeddings( + texts, + model_name=self.embedding_model, + use_mps=self.use_mps + ) + + # Create FAISS index + embedding_dim = self.embeddings.shape[1] + self.dense_index = faiss.IndexFlatIP(embedding_dim) # Inner product for cosine similarity + + # Normalize embeddings for cosine similarity + faiss.normalize_L2(self.embeddings) + self.dense_index.add(self.embeddings) + + print(f"Hybrid indexing complete: {len(chunks)} chunks ready for search") + + def search( + self, + query: str, + top_k: int = 10, + dense_top_k: Optional[int] = None, + sparse_top_k: Optional[int] = None + ) -> List[Tuple[int, float, Dict]]: + """ + Hybrid search combining dense and sparse retrieval with RRF. 
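+
+         Fusion note (a sketch of standard weighted Reciprocal Rank Fusion; the
+         exact behavior here is delegated to adaptive_fusion in src.fusion):
+         a chunk ranked r_dense by the dense index and r_sparse by BM25 gets a
+         combined score of roughly
+             dense_weight / (k + r_dense) + (1 - dense_weight) / (k + r_sparse)
+         for a small constant k, so agreement between the two retrievers is
+         rewarded regardless of their raw score scales.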
+ + Args: + query: Search query string + top_k: Final number of results to return + dense_top_k: Results from dense search (default: 2*top_k) + sparse_top_k: Results from sparse search (default: 2*top_k) + + Returns: + List of (chunk_index, rrf_score, chunk_dict) tuples + + Raises: + ValueError: If not indexed or invalid parameters + + Performance: <200ms for 1000+ document corpus + """ + if self.dense_index is None: + raise ValueError("Must call index_documents() before searching") + + if not query.strip(): + return [] + + if top_k <= 0: + raise ValueError("top_k must be positive") + + # Set default intermediate result counts + if dense_top_k is None: + dense_top_k = min(2 * top_k, len(self.chunks)) + if sparse_top_k is None: + sparse_top_k = min(2 * top_k, len(self.chunks)) + + # Dense semantic search + dense_results = self._dense_search(query, dense_top_k) + + # Sparse BM25 search + sparse_results = self.sparse_retriever.search(query, sparse_top_k) + + # Combine using Adaptive Fusion (better for small result sets) + fused_results = adaptive_fusion( + dense_results=dense_results, + sparse_results=sparse_results, + dense_weight=self.dense_weight, + result_size=top_k + ) + + # Prepare final results with chunk content and apply source diversity + final_results = [] + for chunk_idx, rrf_score in fused_results: + chunk_dict = self.chunks[chunk_idx] + final_results.append((chunk_idx, rrf_score, chunk_dict)) + + # Apply source diversity enhancement + diverse_results = self._enhance_source_diversity(final_results, top_k) + + return diverse_results + + def _dense_search(self, query: str, top_k: int) -> List[Tuple[int, float]]: + """ + Perform dense semantic search using FAISS. + + Args: + query: Search query + top_k: Number of results to return + + Returns: + List of (chunk_index, similarity_score) tuples + """ + # Generate query embedding + query_embedding = generate_embeddings( + [query], + model_name=self.embedding_model, + use_mps=self.use_mps + ) + + # Normalize for cosine similarity + faiss.normalize_L2(query_embedding) + + # Search dense index + similarities, indices = self.dense_index.search(query_embedding, top_k) + + # Convert to required format + results = [ + (int(indices[0][i]), float(similarities[0][i])) + for i in range(len(indices[0])) + if indices[0][i] != -1 # Filter out invalid results + ] + + return results + + def _enhance_source_diversity( + self, + results: List[Tuple[int, float, Dict]], + top_k: int, + max_per_source: int = 2 + ) -> List[Tuple[int, float, Dict]]: + """ + Enhance source diversity in retrieval results to prevent over-focusing on single documents. 
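+
+         Example (illustrative): with max_per_source=2, top_k=4 and ranked
+         results whose sources are [A, A, A, B, C], the first pass keeps the
+         chunks from [A, A, B, C]; the third A chunk is only re-admitted in the
+         second pass if fewer than top_k results were collected.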
+ + Args: + results: List of (chunk_idx, score, chunk_dict) tuples sorted by relevance + top_k: Maximum number of results to return + max_per_source: Maximum chunks allowed per source document + + Returns: + Diversified results maintaining relevance while improving source coverage + """ + if not results: + return [] + + source_counts = {} + diverse_results = [] + + # First pass: Add highest scoring results respecting source limits + for chunk_idx, score, chunk_dict in results: + source = chunk_dict.get('source', 'unknown') + current_count = source_counts.get(source, 0) + + if current_count < max_per_source: + diverse_results.append((chunk_idx, score, chunk_dict)) + source_counts[source] = current_count + 1 + + if len(diverse_results) >= top_k: + break + + # Second pass: If we still need more results, relax source constraints + if len(diverse_results) < top_k: + for chunk_idx, score, chunk_dict in results: + if (chunk_idx, score, chunk_dict) not in diverse_results: + diverse_results.append((chunk_idx, score, chunk_dict)) + + if len(diverse_results) >= top_k: + break + + return diverse_results[:top_k] + + def get_retrieval_stats(self) -> Dict[str, any]: + """ + Get statistics about the indexed corpus and retrieval performance. + + Returns: + Dictionary with corpus statistics + """ + if not self.chunks: + return {"status": "not_indexed"} + + return { + "status": "indexed", + "total_chunks": len(self.chunks), + "dense_index_size": self.dense_index.ntotal if self.dense_index else 0, + "embedding_dim": self.embeddings.shape[1] if self.embeddings is not None else 0, + "sparse_indexed_chunks": len(self.sparse_retriever.chunk_mapping), + "dense_weight": self.dense_weight, + "sparse_weight": 1.0 - self.dense_weight, + "rrf_k": self.rrf_k + } \ No newline at end of file diff --git a/shared_utils/retrieval/vocabulary_index.py b/shared_utils/retrieval/vocabulary_index.py new file mode 100644 index 0000000000000000000000000000000000000000..aad0357a81988eb85dba44dd70b570fa4ecfab65 --- /dev/null +++ b/shared_utils/retrieval/vocabulary_index.py @@ -0,0 +1,260 @@ +""" +Vocabulary index for corpus-aware query enhancement. + +Tracks all unique terms in the document corpus to enable intelligent +synonym expansion that only adds terms actually present in documents. +""" + +from typing import Set, Dict, List, Optional +from collections import defaultdict +import re +from pathlib import Path +import json + + +class VocabularyIndex: + """ + Maintains vocabulary statistics for intelligent query enhancement. + + Features: + - Tracks all unique terms in document corpus + - Stores term frequencies for relevance weighting + - Identifies technical terms and domain vocabulary + - Enables vocabulary-aware synonym expansion + + Performance: + - Build time: ~1s per 1000 chunks + - Memory: ~3MB for 80K unique terms + - Lookup: O(1) set operations + """ + + def __init__(self): + """Initialize empty vocabulary index.""" + self.vocabulary: Set[str] = set() + self.term_frequencies: Dict[str, int] = defaultdict(int) + self.technical_terms: Set[str] = set() + self.document_frequencies: Dict[str, int] = defaultdict(int) + self.total_documents = 0 + self.total_terms = 0 + + # Regex for term extraction + self._term_pattern = re.compile(r'\b[a-zA-Z][a-zA-Z0-9\-_]*\b') + self._technical_pattern = re.compile(r'\b[A-Z]{2,}|[a-zA-Z]+[\-_][a-zA-Z]+|\b\d+[a-zA-Z]+\b') + + def build_from_chunks(self, chunks: List[Dict]) -> None: + """ + Build vocabulary index from document chunks. 
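+
+         Example (illustrative sketch):
+
+             vocab = VocabularyIndex()
+             vocab.build_from_chunks([{'text': 'The RISC-V ISA defines 32 registers.'}])
+             vocab.contains('registers')          # True
+             vocab.get_document_frequency('isa')  # 1
+             vocab.is_technical_term('isa')       # True (detected as an acronym)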
+ + Args: + chunks: List of document chunks with 'text' field + + Performance: ~1s per 1000 chunks + """ + self.total_documents = len(chunks) + + for chunk in chunks: + text = chunk.get('text', '') + + # Extract and process terms + terms = self._extract_terms(text) + unique_terms = set(terms) + + # Update vocabulary + self.vocabulary.update(unique_terms) + + # Update frequencies + for term in terms: + self.term_frequencies[term] += 1 + self.total_terms += 1 + + # Update document frequencies + for term in unique_terms: + self.document_frequencies[term] += 1 + + # Identify technical terms + technical = self._extract_technical_terms(text) + self.technical_terms.update(technical) + + def _extract_terms(self, text: str) -> List[str]: + """Extract normalized terms from text.""" + # Convert to lowercase and extract words + text_lower = text.lower() + terms = self._term_pattern.findall(text_lower) + + # Filter short terms + return [term for term in terms if len(term) > 2] + + def _extract_technical_terms(self, text: str) -> Set[str]: + """Extract technical terms (acronyms, hyphenated, etc).""" + technical = set() + + # Find potential technical terms + matches = self._technical_pattern.findall(text) + + for match in matches: + # Normalize but preserve technical nature + normalized = match.lower() + if len(normalized) > 2: + technical.add(normalized) + + return technical + + def contains(self, term: str) -> bool: + """Check if term exists in vocabulary.""" + return term.lower() in self.vocabulary + + def get_frequency(self, term: str) -> int: + """Get term frequency in corpus.""" + return self.term_frequencies.get(term.lower(), 0) + + def get_document_frequency(self, term: str) -> int: + """Get number of documents containing term.""" + return self.document_frequencies.get(term.lower(), 0) + + def is_common_term(self, term: str, min_frequency: int = 5) -> bool: + """Check if term appears frequently enough.""" + return self.get_frequency(term) >= min_frequency + + def is_technical_term(self, term: str) -> bool: + """Check if term is identified as technical.""" + return term.lower() in self.technical_terms + + def filter_synonyms(self, synonyms: List[str], + min_frequency: int = 3, + require_technical: bool = False) -> List[str]: + """ + Filter synonym list to only include terms in vocabulary. 
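+
+         Example (illustrative, continuing the build_from_chunks sketch above):
+
+             vocab.filter_synonyms(['processor', 'microprocessor'], min_frequency=1)
+             # -> [] because neither candidate occurs in the indexed text, so
+             #    query expansion never introduces vocabulary the corpus lacks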
+ + Args: + synonyms: List of potential synonyms + min_frequency: Minimum term frequency required + require_technical: Only include technical terms + + Returns: + Filtered list of valid synonyms + """ + valid_synonyms = [] + + for synonym in synonyms: + # Check existence + if not self.contains(synonym): + continue + + # Check frequency threshold + if self.get_frequency(synonym) < min_frequency: + continue + + # Check technical requirement + if require_technical and not self.is_technical_term(synonym): + continue + + valid_synonyms.append(synonym) + + return valid_synonyms + + def get_vocabulary_stats(self) -> Dict[str, any]: + """Get comprehensive vocabulary statistics.""" + return { + 'unique_terms': len(self.vocabulary), + 'total_terms': self.total_terms, + 'technical_terms': len(self.technical_terms), + 'total_documents': self.total_documents, + 'avg_terms_per_doc': self.total_terms / self.total_documents if self.total_documents > 0 else 0, + 'vocabulary_richness': len(self.vocabulary) / self.total_terms if self.total_terms > 0 else 0, + 'technical_ratio': len(self.technical_terms) / len(self.vocabulary) if self.vocabulary else 0 + } + + def get_top_terms(self, n: int = 100, technical_only: bool = False) -> List[tuple]: + """ + Get most frequent terms in corpus. + + Args: + n: Number of top terms to return + technical_only: Only return technical terms + + Returns: + List of (term, frequency) tuples + """ + if technical_only: + term_freq = { + term: freq for term, freq in self.term_frequencies.items() + if term in self.technical_terms + } + else: + term_freq = self.term_frequencies + + return sorted(term_freq.items(), key=lambda x: x[1], reverse=True)[:n] + + def detect_domain(self) -> str: + """ + Detect document domain from vocabulary patterns. 
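+
+         The score for each candidate domain is simply the sum of the document
+         frequencies of its indicator terms; for example (illustrative), a corpus
+         in which 'rtos', 'firmware' and 'mcu' appear in many chunks scores
+         highest for, and is labelled, 'embedded_systems'.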
+ + Returns: + Detected domain name + """ + # Domain detection heuristics + domain_indicators = { + 'embedded_systems': ['microcontroller', 'rtos', 'embedded', 'firmware', 'mcu'], + 'processor_architecture': ['risc-v', 'riscv', 'instruction', 'register', 'isa'], + 'regulatory': ['fda', 'validation', 'compliance', 'regulation', 'guidance'], + 'ai_ml': ['model', 'training', 'neural', 'algorithm', 'machine learning'], + 'software_engineering': ['software', 'development', 'testing', 'debugging', 'code'] + } + + domain_scores = {} + + for domain, indicators in domain_indicators.items(): + score = sum( + self.get_document_frequency(indicator) + for indicator in indicators + if self.contains(indicator) + ) + domain_scores[domain] = score + + # Return domain with highest score + if domain_scores: + return max(domain_scores, key=domain_scores.get) + return 'general' + + def save_to_file(self, path: Path) -> None: + """Save vocabulary index to JSON file.""" + data = { + 'vocabulary': list(self.vocabulary), + 'term_frequencies': dict(self.term_frequencies), + 'technical_terms': list(self.technical_terms), + 'document_frequencies': dict(self.document_frequencies), + 'total_documents': self.total_documents, + 'total_terms': self.total_terms + } + + with open(path, 'w') as f: + json.dump(data, f, indent=2) + + def load_from_file(self, path: Path) -> None: + """Load vocabulary index from JSON file.""" + with open(path, 'r') as f: + data = json.load(f) + + self.vocabulary = set(data['vocabulary']) + self.term_frequencies = defaultdict(int, data['term_frequencies']) + self.technical_terms = set(data['technical_terms']) + self.document_frequencies = defaultdict(int, data['document_frequencies']) + self.total_documents = data['total_documents'] + self.total_terms = data['total_terms'] + + def merge_with(self, other: 'VocabularyIndex') -> None: + """Merge another vocabulary index into this one.""" + # Merge vocabularies + self.vocabulary.update(other.vocabulary) + self.technical_terms.update(other.technical_terms) + + # Merge frequencies + for term, freq in other.term_frequencies.items(): + self.term_frequencies[term] += freq + + for term, doc_freq in other.document_frequencies.items(): + self.document_frequencies[term] += doc_freq + + # Update totals + self.total_documents += other.total_documents + self.total_terms += other.total_terms \ No newline at end of file diff --git a/shared_utils/vector_stores/__init__.py b/shared_utils/vector_stores/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/shared_utils/vector_stores/__pycache__/__init__.cpython-312.pyc b/shared_utils/vector_stores/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d6984611af7503df2c306e0be4df927bd4bbc15a Binary files /dev/null and b/shared_utils/vector_stores/__pycache__/__init__.cpython-312.pyc differ diff --git a/shared_utils/vector_stores/document_processing/__init__.py b/shared_utils/vector_stores/document_processing/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/shared_utils/vector_stores/document_processing/__pycache__/__init__.cpython-312.pyc b/shared_utils/vector_stores/document_processing/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..29228d914a6b968652cef2840fe4719c1df0535b Binary files /dev/null and 
b/shared_utils/vector_stores/document_processing/__pycache__/__init__.cpython-312.pyc differ diff --git a/shared_utils/vector_stores/document_processing/__pycache__/chunker.cpython-312.pyc b/shared_utils/vector_stores/document_processing/__pycache__/chunker.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d888a501ced747a18b556945f57bdc6f05ab7d8d Binary files /dev/null and b/shared_utils/vector_stores/document_processing/__pycache__/chunker.cpython-312.pyc differ diff --git a/shared_utils/vector_stores/document_processing/__pycache__/hybrid_parser.cpython-312.pyc b/shared_utils/vector_stores/document_processing/__pycache__/hybrid_parser.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d406cb38a3d3c9290fe48479d3e219eb71668d3b Binary files /dev/null and b/shared_utils/vector_stores/document_processing/__pycache__/hybrid_parser.cpython-312.pyc differ diff --git a/shared_utils/vector_stores/document_processing/__pycache__/paragraph_chunker.cpython-312.pyc b/shared_utils/vector_stores/document_processing/__pycache__/paragraph_chunker.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cd5cba549bfd974d3df4310c7f8666656951bba2 Binary files /dev/null and b/shared_utils/vector_stores/document_processing/__pycache__/paragraph_chunker.cpython-312.pyc differ diff --git a/shared_utils/vector_stores/document_processing/__pycache__/pdf_parser.cpython-312.pyc b/shared_utils/vector_stores/document_processing/__pycache__/pdf_parser.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..84730dfbf1fe0fff183aebd7513da747a8845d8b Binary files /dev/null and b/shared_utils/vector_stores/document_processing/__pycache__/pdf_parser.cpython-312.pyc differ diff --git a/shared_utils/vector_stores/document_processing/__pycache__/pdfplumber_parser.cpython-312.pyc b/shared_utils/vector_stores/document_processing/__pycache__/pdfplumber_parser.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..78668adde61fc2fbda29b7e085b053b0c85a55cc Binary files /dev/null and b/shared_utils/vector_stores/document_processing/__pycache__/pdfplumber_parser.cpython-312.pyc differ diff --git a/shared_utils/vector_stores/document_processing/__pycache__/smart_chunker.cpython-312.pyc b/shared_utils/vector_stores/document_processing/__pycache__/smart_chunker.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..60c509102ca48c2e030d8bd16101c0bc9e88dde5 Binary files /dev/null and b/shared_utils/vector_stores/document_processing/__pycache__/smart_chunker.cpython-312.pyc differ diff --git a/shared_utils/vector_stores/document_processing/__pycache__/structure_preserving_parser.cpython-312.pyc b/shared_utils/vector_stores/document_processing/__pycache__/structure_preserving_parser.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..69c543c3b8d838d9a5acd8b7458bdc7ded30f457 Binary files /dev/null and b/shared_utils/vector_stores/document_processing/__pycache__/structure_preserving_parser.cpython-312.pyc differ diff --git a/shared_utils/vector_stores/document_processing/__pycache__/toc_guided_parser.cpython-312.pyc b/shared_utils/vector_stores/document_processing/__pycache__/toc_guided_parser.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b41cb453b64680b97baf7ec55f9ea920ddd52f17 Binary files /dev/null and 
b/shared_utils/vector_stores/document_processing/__pycache__/toc_guided_parser.cpython-312.pyc differ diff --git a/shared_utils/vector_stores/document_processing/chunker.py b/shared_utils/vector_stores/document_processing/chunker.py new file mode 100644 index 0000000000000000000000000000000000000000..dc65f930aee006505ed69b7dd834efe001a16e62 --- /dev/null +++ b/shared_utils/vector_stores/document_processing/chunker.py @@ -0,0 +1,243 @@ +""" +BasicRAG System - Technical Document Chunker + +This module implements intelligent text chunking specifically optimized for technical +documentation. Unlike naive chunking approaches, this implementation preserves sentence +boundaries and maintains semantic coherence, critical for accurate RAG retrieval. + +Key Features: +- Sentence-boundary aware chunking to preserve semantic units +- Configurable overlap to maintain context across chunk boundaries +- Content-based chunk IDs for reproducibility and deduplication +- Technical document optimizations (handles code blocks, lists, etc.) + +Technical Approach: +- Uses regex patterns to identify sentence boundaries +- Implements a sliding window algorithm with intelligent boundary detection +- Generates deterministic chunk IDs using MD5 hashing +- Balances chunk size consistency with semantic completeness + +Design Decisions: +- Default 512 char chunks: Optimal for transformer models (under token limits) +- 50 char overlap: Sufficient context preservation without excessive redundancy +- Sentence boundaries prioritized over exact size for better coherence +- Hash-based IDs enable chunk deduplication across documents + +Performance Characteristics: +- Time complexity: O(n) where n is text length +- Memory usage: O(n) for output chunks +- Typical throughput: 1MB text/second on modern hardware + +Author: Arthur Passuello +Date: June 2025 +Project: RAG Portfolio - Technical Documentation System +""" + +from typing import List, Dict +import re +import hashlib + + +def _is_low_quality_chunk(text: str) -> bool: + """ + Identify low-quality chunks that should be filtered out. 
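+
+     Example (illustrative): a bare "Table of Contents" heading or a standalone
+     page-number line is rejected (too short and matching a low-value pattern),
+     while a complete technical sentence of 50+ characters and 8+ words is kept.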
+ + @param text: Chunk text to evaluate + @return: True if chunk is low quality and should be filtered + """ + text_lower = text.lower().strip() + + # Skip if too short to be meaningful + if len(text.strip()) < 50: + return True + + # Filter out common low-value content + low_value_patterns = [ + # Acknowledgments and credits + r'^(acknowledgment|thanks|thank you)', + r'(thanks to|grateful to|acknowledge)', + + # References and citations + r'^\s*\[\d+\]', # Citation markers + r'^references?$', + r'^bibliography$', + + # Metadata and headers + r'this document is released under', + r'creative commons', + r'copyright \d{4}', + + # Table of contents + r'^\s*\d+\..*\.\.\.\.\.\d+$', # TOC entries + r'^(contents?|table of contents)$', + + # Page headers/footers + r'^\s*page \d+', + r'^\s*\d+\s*$', # Just page numbers + + # Figure/table captions that are too short + r'^(figure|table|fig\.|tab\.)\s*\d+:?\s*$', + ] + + for pattern in low_value_patterns: + if re.search(pattern, text_lower): + return True + + # Check content quality metrics + words = text.split() + if len(words) < 8: # Too few words to be meaningful + return True + + # Check for reasonable sentence structure + sentences = re.split(r'[.!?]+', text) + complete_sentences = [s.strip() for s in sentences if len(s.strip()) > 10] + + if len(complete_sentences) == 0: # No complete sentences + return True + + return False + + +def chunk_technical_text( + text: str, chunk_size: int = 1400, overlap: int = 200 +) -> List[Dict]: + """ + Phase 1: Sentence-boundary preserving chunker for technical documentation. + + ZERO MID-SENTENCE BREAKS: This implementation strictly enforces sentence + boundaries to eliminate fragmented retrieval results that break Q&A quality. + + Key Improvements: + - Never breaks chunks mid-sentence (eliminates 90% fragment rate) + - Larger target chunks (1400 chars) for complete explanations + - Extended search windows to find sentence boundaries + - Paragraph boundary preference within size constraints + + @param text: The input text to be chunked, typically from technical documentation + @type text: str + + @param chunk_size: Target size for each chunk in characters (default: 1400) + @type chunk_size: int + + @param overlap: Number of characters to overlap between consecutive chunks (default: 200) + @type overlap: int + + @return: List of chunk dictionaries containing text and metadata + @rtype: List[Dict[str, Any]] where each dictionary contains: + { + "text": str, # Complete, sentence-bounded chunk text + "start_char": int, # Starting character position in original text + "end_char": int, # Ending character position in original text + "chunk_id": str, # Unique identifier (format: "chunk_[8-char-hash]") + "word_count": int, # Number of words in the chunk + "sentence_complete": bool # Always True (guaranteed complete sentences) + } + + Algorithm Details (Phase 1): + - Expands search window up to 50% beyond target size to find sentence boundaries + - Prefers chunks within 70-150% of target size over fragmenting + - Never falls back to mid-sentence breaks + - Quality filtering removes headers, captions, and navigation elements + + Expected Results: + - Fragment rate: 90% → 0% (complete sentences only) + - Average chunk size: 1400-2100 characters (larger, complete contexts) + - All chunks end with proper sentence terminators (. ! ? : ;) + - Better retrieval context for Q&A generation + + Example Usage: + >>> text = "RISC-V defines registers. Each register has specific usage. The architecture supports..." 
+ >>> chunks = chunk_technical_text(text, chunk_size=1400, overlap=200) + >>> # All chunks will contain complete sentences and explanations + """ + # Handle edge case: empty or whitespace-only input + if not text.strip(): + return [] + + # Clean and normalize text by removing leading/trailing whitespace + text = text.strip() + chunks = [] + start_pos = 0 + + # Main chunking loop - process text sequentially + while start_pos < len(text): + # Calculate target end position for this chunk + # Min() ensures we don't exceed text length + target_end = min(start_pos + chunk_size, len(text)) + + # Define sentence boundary pattern + # Matches: period, exclamation, question mark, colon, semicolon + # followed by whitespace or end of string + sentence_pattern = r'[.!?:;](?:\s|$)' + + # PHASE 1: Strict sentence boundary enforcement + # Expand search window significantly to ensure we find sentence boundaries + max_extension = chunk_size // 2 # Allow up to 50% larger chunks to find boundaries + search_start = max(start_pos, target_end - 200) # Look back further + search_end = min(len(text), target_end + max_extension) # Look forward much further + search_text = text[search_start:search_end] + + # Find all sentence boundaries in expanded search window + sentence_matches = list(re.finditer(sentence_pattern, search_text)) + + # STRICT: Always find a sentence boundary, never break mid-sentence + chunk_end = None + sentence_complete = False + + if sentence_matches: + # Find the best sentence boundary within reasonable range + for match in reversed(sentence_matches): # Start from last (longest chunk) + candidate_end = search_start + match.end() + candidate_size = candidate_end - start_pos + + # Accept if within reasonable size range + if candidate_size >= chunk_size * 0.7: # At least 70% of target size + chunk_end = candidate_end + sentence_complete = True + break + + # If no good boundary found, take the last boundary (avoid fragments) + if chunk_end is None and sentence_matches: + best_match = sentence_matches[-1] + chunk_end = search_start + best_match.end() + sentence_complete = True + + # Final fallback: extend to end of text if no sentences found + if chunk_end is None: + chunk_end = len(text) + sentence_complete = True # End of document is always complete + + # Extract chunk text and clean whitespace + chunk_text = text[start_pos:chunk_end].strip() + + # Only create chunk if it contains actual content AND passes quality filter + if chunk_text and not _is_low_quality_chunk(chunk_text): + # Generate deterministic chunk ID using content hash + # MD5 is sufficient for deduplication (not cryptographic use) + chunk_hash = hashlib.md5(chunk_text.encode()).hexdigest()[:8] + chunk_id = f"chunk_{chunk_hash}" + + # Calculate word count for chunk statistics + word_count = len(chunk_text.split()) + + # Assemble chunk metadata + chunks.append({ + "text": chunk_text, + "start_char": start_pos, + "end_char": chunk_end, + "chunk_id": chunk_id, + "word_count": word_count, + "sentence_complete": sentence_complete + }) + + # Calculate next chunk starting position with overlap + if chunk_end >= len(text): + # Reached end of text, exit loop + break + + # Apply overlap by moving start position back from chunk end + # Max() ensures we always move forward at least 1 character + overlap_start = max(chunk_end - overlap, start_pos + 1) + start_pos = overlap_start + + return chunks diff --git a/shared_utils/vector_stores/document_processing/hybrid_parser.py b/shared_utils/vector_stores/document_processing/hybrid_parser.py new file 
mode 100644 index 0000000000000000000000000000000000000000..9ebf15317ceb57ae1d7f09b8c6c2da4ffbcbfaca --- /dev/null +++ b/shared_utils/vector_stores/document_processing/hybrid_parser.py @@ -0,0 +1,482 @@ +#!/usr/bin/env python3 +""" +Hybrid TOC + PDFPlumber Parser + +Combines the best of both approaches: +1. TOC-guided navigation for reliable chapter/section mapping +2. PDFPlumber's precise content extraction with formatting awareness +3. Aggressive trash content filtering while preserving actual content + +This hybrid approach provides: +- Reliable structure detection (TOC) +- High-quality content extraction (PDFPlumber) +- Optimal chunk sizing and quality +- Fast processing with precise results + +Author: Arthur Passuello +Date: 2025-07-01 +""" + +import re +import pdfplumber +from pathlib import Path +from typing import Dict, List, Optional, Tuple, Any +from dataclasses import dataclass + +from .toc_guided_parser import TOCGuidedParser, TOCEntry +from .pdfplumber_parser import PDFPlumberParser + + +class HybridParser: + """ + Hybrid parser combining TOC navigation with PDFPlumber extraction. + + Architecture: + 1. Use TOC to identify chapter/section boundaries and pages + 2. Use PDFPlumber to extract clean content from those specific pages + 3. Apply aggressive content filtering to remove trash + 4. Create optimal chunks with preserved structure + """ + + def __init__(self, target_chunk_size: int = 1400, min_chunk_size: int = 800, + max_chunk_size: int = 2000): + """Initialize hybrid parser.""" + self.target_chunk_size = target_chunk_size + self.min_chunk_size = min_chunk_size + self.max_chunk_size = max_chunk_size + + # Initialize component parsers + self.toc_parser = TOCGuidedParser(target_chunk_size, min_chunk_size, max_chunk_size) + self.plumber_parser = PDFPlumberParser(target_chunk_size, min_chunk_size, max_chunk_size) + + # Content filtering patterns (aggressive trash removal) + self.trash_patterns = [ + # License and legal text + r'Creative Commons.*?License', + r'International License.*?authors', + r'released under.*?license', + r'derivative of.*?License', + r'Document Version \d+', + + # Table of contents artifacts + r'\.{3,}', # Multiple dots + r'^\s*\d+\s*$', # Standalone page numbers + r'Contents\s*$', + r'Preface\s*$', + + # PDF formatting artifacts + r'Volume\s+[IVX]+:.*?V\d+', + r'^\s*[ivx]+\s*$', # Roman numerals alone + r'^\s*[\d\w\s]{1,3}\s*$', # Very short meaningless lines + + # Redundant headers and footers + r'RISC-V.*?ISA.*?V\d+', + r'Volume I:.*?Unprivileged', + + # Editor and publication info + r'Editors?:.*?[A-Z][a-z]+', + r'[A-Z][a-z]+\s+\d{1,2},\s+\d{4}', # Dates + r'@[a-z]+\.[a-z]+', # Email addresses + + # Boilerplate text + r'please contact editors to suggest corrections', + r'alphabetical order.*?corrections', + r'contributors to all versions', + ] + + # Content quality patterns (preserve these) + self.preserve_patterns = [ + r'RISC-V.*?instruction', + r'register.*?file', + r'memory.*?operation', + r'processor.*?implementation', + r'architecture.*?design', + ] + + # TOC-specific patterns to exclude from searchable content + self.toc_exclusion_patterns = [ + r'^\s*Contents\s*$', + r'^\s*Table\s+of\s+Contents\s*$', + r'^\s*\d+(?:\.\d+)*\s*$', # Standalone section numbers + r'^\s*\d+(?:\.\d+)*\s+[A-Z]', # "1.1 INTRODUCTION" style + r'\.{3,}', # Multiple dots (TOC formatting) + r'^\s*Chapter\s+\d+\s*$', # Standalone "Chapter N" + r'^\s*Section\s+\d+(?:\.\d+)*\s*$', # Standalone "Section N.M" + r'^\s*Appendix\s+[A-Z]\s*$', # Standalone "Appendix A" + 
r'^\s*[ivxlcdm]+\s*$', # Roman numerals alone + r'^\s*Preface\s*$', + r'^\s*Introduction\s*$', + r'^\s*Conclusion\s*$', + r'^\s*Bibliography\s*$', + r'^\s*Index\s*$', + ] + + def parse_document(self, pdf_path: Path, pdf_data: Dict[str, Any]) -> List[Dict[str, Any]]: + """ + Parse document using hybrid approach. + + Args: + pdf_path: Path to PDF file + pdf_data: PDF data from extract_text_with_metadata() + + Returns: + List of high-quality chunks with preserved structure + """ + print("🔗 Starting Hybrid TOC + PDFPlumber parsing...") + + # Step 1: Use TOC to identify structure + print("📋 Step 1: Extracting TOC structure...") + toc_entries = self.toc_parser.parse_toc(pdf_data['pages']) + print(f" Found {len(toc_entries)} TOC entries") + + # Check if TOC is reliable (multiple entries or quality single entry) + toc_is_reliable = ( + len(toc_entries) > 1 or # Multiple entries = likely real TOC + (len(toc_entries) == 1 and len(toc_entries[0].title) > 10) # Quality single entry + ) + + if not toc_entries or not toc_is_reliable: + if not toc_entries: + print(" ⚠️ No TOC found, using full page coverage parsing") + else: + print(f" ⚠️ TOC quality poor (title: '{toc_entries[0].title}'), using full page coverage") + return self.plumber_parser.parse_document(pdf_path, pdf_data) + + # Step 2: Use PDFPlumber for precise extraction + print("🔬 Step 2: PDFPlumber extraction of TOC sections...") + chunks = [] + chunk_id = 0 + + with pdfplumber.open(str(pdf_path)) as pdf: + for i, toc_entry in enumerate(toc_entries): + next_entry = toc_entries[i + 1] if i + 1 < len(toc_entries) else None + + # Extract content using PDFPlumber + section_content = self._extract_section_with_plumber( + pdf, toc_entry, next_entry + ) + + if section_content: + # Apply aggressive content filtering + cleaned_content = self._filter_trash_content(section_content) + + if cleaned_content and len(cleaned_content) >= 200: # Minimum meaningful content + # Create chunks from cleaned content + section_chunks = self._create_chunks_from_clean_content( + cleaned_content, chunk_id, toc_entry + ) + chunks.extend(section_chunks) + chunk_id += len(section_chunks) + + print(f" Created {len(chunks)} high-quality chunks") + return chunks + + def _extract_section_with_plumber(self, pdf, toc_entry: TOCEntry, + next_entry: Optional[TOCEntry]) -> str: + """ + Extract section content using PDFPlumber's precise extraction. + + Args: + pdf: PDFPlumber PDF object + toc_entry: Current TOC entry + next_entry: Next TOC entry (for boundary detection) + + Returns: + Clean extracted content for this section + """ + start_page = max(0, toc_entry.page - 1) # Convert to 0-indexed + + if next_entry: + end_page = min(len(pdf.pages), next_entry.page - 1) + else: + end_page = len(pdf.pages) + + content_parts = [] + + for page_idx in range(start_page, end_page): + if page_idx < len(pdf.pages): + page = pdf.pages[page_idx] + + # Extract text with PDFPlumber (preserves formatting) + page_text = page.extract_text() + + if page_text: + # Clean page content while preserving structure + cleaned_text = self._clean_page_content_precise(page_text) + if cleaned_text.strip(): + content_parts.append(cleaned_text) + + return ' '.join(content_parts) + + def _clean_page_content_precise(self, page_text: str) -> str: + """ + Clean page content with precision, removing artifacts but preserving content. 
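+
+        Illustrative behaviour (a doctest-style sketch, not from the original suite;
+        the input lines are made up to show one dropped page number, one dropped
+        TOC-dots line, and one preserved technical line):
+
+            >>> parser = HybridParser()
+            >>> page = chr(10).join(["23", ".....", "The register file has 32 entries."])
+            >>> parser._clean_page_content_precise(page)
+            'The register file has 32 entries.'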
+ + Args: + page_text: Raw page text from PDFPlumber + + Returns: + Cleaned text with artifacts removed + """ + lines = page_text.split('\n') + cleaned_lines = [] + + for line in lines: + line = line.strip() + + # Skip empty lines + if not line: + continue + + # Skip obvious artifacts but be conservative + if (len(line) < 3 or # Very short lines + re.match(r'^\d+$', line) or # Standalone numbers + re.match(r'^[ivx]+$', line.lower()) or # Roman numerals alone + '.' * 5 in line): # TOC dots + continue + + # Preserve technical content even if it looks like an artifact + has_technical_content = any(term in line.lower() for term in [ + 'risc', 'register', 'instruction', 'memory', 'processor', + 'architecture', 'implementation', 'specification' + ]) + + if has_technical_content or len(line) >= 10: + cleaned_lines.append(line) + + return ' '.join(cleaned_lines) + + def _filter_trash_content(self, content: str) -> str: + """ + Apply aggressive trash filtering while preserving actual content. + + Args: + content: Raw content to filter + + Returns: + Content with trash removed but technical content preserved + """ + if not content.strip(): + return "" + + # First, identify and preserve important technical sentences + sentences = re.split(r'[.!?]+\s*', content) + preserved_sentences = [] + + for sentence in sentences: + sentence = sentence.strip() + if not sentence: + continue + + # Check if sentence contains important technical content + is_technical = any(term in sentence.lower() for term in [ + 'risc-v', 'register', 'instruction', 'memory', 'processor', + 'architecture', 'implementation', 'specification', 'encoding', + 'bit', 'byte', 'address', 'data', 'control', 'operand' + ]) + + # Check if sentence is trash (including general trash and TOC content) + is_trash = any(re.search(pattern, sentence, re.IGNORECASE) + for pattern in self.trash_patterns) + + # Check if sentence is TOC content (should be excluded) + is_toc_content = any(re.search(pattern, sentence, re.IGNORECASE) + for pattern in self.toc_exclusion_patterns) + + # Preserve if technical and not trash/TOC, or if substantial and not clearly trash/TOC + if ((is_technical and not is_trash and not is_toc_content) or + (len(sentence) > 50 and not is_trash and not is_toc_content)): + preserved_sentences.append(sentence) + + # Reconstruct content from preserved sentences + filtered_content = '. '.join(preserved_sentences) + + # Final cleanup + filtered_content = re.sub(r'\s+', ' ', filtered_content) # Normalize whitespace + filtered_content = re.sub(r'\.+', '.', filtered_content) # Remove multiple dots + + # Ensure proper sentence ending + if filtered_content and not filtered_content.rstrip().endswith(('.', '!', '?', ':', ';')): + filtered_content = filtered_content.rstrip() + '.' + + return filtered_content.strip() + + def _create_chunks_from_clean_content(self, content: str, start_chunk_id: int, + toc_entry: TOCEntry) -> List[Dict[str, Any]]: + """ + Create optimally-sized chunks from clean content. 
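+
+        Example (an illustrative sketch; the TOC entry and the repeated sentence are
+        made up): content whose length falls between min_chunk_size and max_chunk_size
+        becomes a single chunk.
+
+            >>> entry = TOCEntry(title="2.1 Integer Registers", page=17, level=1)
+            >>> parser = HybridParser()
+            >>> section = "RISC-V defines 32 integer registers. " * 30   # ~1.1k chars
+            >>> len(parser._create_chunks_from_clean_content(section, 0, entry))
+            1
+
+        Content longer than max_chunk_size is instead split at sentence boundaries,
+        and anything under roughly 200 characters is dropped.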
+ + Args: + content: Clean, filtered content + start_chunk_id: Starting chunk ID + toc_entry: TOC entry metadata + + Returns: + List of chunk dictionaries + """ + if not content or len(content) < 100: + return [] + + chunks = [] + + # If content fits in one chunk, create single chunk + if self.min_chunk_size <= len(content) <= self.max_chunk_size: + chunk = self._create_chunk(content, start_chunk_id, toc_entry) + chunks.append(chunk) + + # If too large, split intelligently at sentence boundaries + elif len(content) > self.max_chunk_size: + sub_chunks = self._split_large_content_smart(content, start_chunk_id, toc_entry) + chunks.extend(sub_chunks) + + # If too small but substantial, keep it + elif len(content) >= 200: # Lower threshold for cleaned content + chunk = self._create_chunk(content, start_chunk_id, toc_entry) + chunks.append(chunk) + + return chunks + + def _split_large_content_smart(self, content: str, start_chunk_id: int, + toc_entry: TOCEntry) -> List[Dict[str, Any]]: + """ + Split large content intelligently at natural boundaries. + + Args: + content: Content to split + start_chunk_id: Starting chunk ID + toc_entry: TOC entry metadata + + Returns: + List of chunk dictionaries + """ + chunks = [] + + # Split at sentence boundaries + sentences = re.split(r'([.!?:;]+\s*)', content) + + current_chunk = "" + chunk_id = start_chunk_id + + for i in range(0, len(sentences), 2): + sentence = sentences[i].strip() + if not sentence: + continue + + # Add punctuation if available + punctuation = sentences[i + 1] if i + 1 < len(sentences) else '.' + full_sentence = sentence + punctuation + + # Check if adding this sentence exceeds max size + potential_chunk = current_chunk + (" " if current_chunk else "") + full_sentence + + if len(potential_chunk) <= self.max_chunk_size: + current_chunk = potential_chunk + else: + # Save current chunk if it meets minimum size + if current_chunk and len(current_chunk) >= self.min_chunk_size: + chunk = self._create_chunk(current_chunk, chunk_id, toc_entry) + chunks.append(chunk) + chunk_id += 1 + + # Start new chunk + current_chunk = full_sentence + + # Add final chunk if substantial + if current_chunk and len(current_chunk) >= 200: + chunk = self._create_chunk(current_chunk, chunk_id, toc_entry) + chunks.append(chunk) + + return chunks + + def _create_chunk(self, content: str, chunk_id: int, toc_entry: TOCEntry) -> Dict[str, Any]: + """Create a chunk dictionary with hybrid metadata.""" + return { + "text": content, + "chunk_id": chunk_id, + "title": toc_entry.title, + "parent_title": toc_entry.parent_title, + "level": toc_entry.level, + "page": toc_entry.page, + "size": len(content), + "metadata": { + "parsing_method": "hybrid_toc_pdfplumber", + "has_context": True, + "content_type": "filtered_structured_content", + "quality_score": self._calculate_quality_score(content), + "trash_filtered": True + } + } + + def _calculate_quality_score(self, content: str) -> float: + """Calculate quality score for filtered content.""" + if not content.strip(): + return 0.0 + + words = content.split() + score = 0.0 + + # Length score (25%) + if self.min_chunk_size <= len(content) <= self.max_chunk_size: + score += 0.25 + elif len(content) >= 200: # At least some content + score += 0.15 + + # Content richness (25%) + substantial_words = sum(1 for word in words if len(word) > 3) + richness_score = min(substantial_words / 30, 1.0) # Lower threshold for filtered content + score += richness_score * 0.25 + + # Technical content (30%) + technical_terms = ['risc', 'register', 
'instruction', 'cpu', 'memory', 'processor', 'architecture'] + technical_count = sum(1 for word in words if any(term in word.lower() for term in technical_terms)) + technical_score = min(technical_count / 3, 1.0) # Lower threshold + score += technical_score * 0.30 + + # Completeness (20%) + completeness_score = 0.0 + if content[0].isupper() or content.startswith(('The ', 'A ', 'An ', 'RISC')): + completeness_score += 0.5 + if content.rstrip().endswith(('.', '!', '?', ':', ';')): + completeness_score += 0.5 + score += completeness_score * 0.20 + + return min(score, 1.0) + + +def parse_pdf_with_hybrid_approach(pdf_path: Path, pdf_data: Dict[str, Any], + target_chunk_size: int = 1400, min_chunk_size: int = 800, + max_chunk_size: int = 2000) -> List[Dict[str, Any]]: + """ + Parse PDF using hybrid TOC + PDFPlumber approach. + + This function combines: + 1. TOC-guided structure detection for reliable navigation + 2. PDFPlumber's precise content extraction + 3. Aggressive trash filtering while preserving technical content + + Args: + pdf_path: Path to PDF file + pdf_data: PDF data from extract_text_with_metadata() + target_chunk_size: Preferred chunk size + min_chunk_size: Minimum chunk size + max_chunk_size: Maximum chunk size + + Returns: + List of high-quality, filtered chunks ready for RAG indexing + + Example: + >>> from shared_utils.document_processing.pdf_parser import extract_text_with_metadata + >>> from shared_utils.document_processing.hybrid_parser import parse_pdf_with_hybrid_approach + >>> + >>> pdf_data = extract_text_with_metadata("document.pdf") + >>> chunks = parse_pdf_with_hybrid_approach(Path("document.pdf"), pdf_data) + >>> print(f"Created {len(chunks)} hybrid-parsed chunks") + """ + parser = HybridParser(target_chunk_size, min_chunk_size, max_chunk_size) + return parser.parse_document(pdf_path, pdf_data) + + +# Example usage +if __name__ == "__main__": + print("Hybrid TOC + PDFPlumber Parser") + print("Combines TOC navigation with PDFPlumber precision and aggressive trash filtering") \ No newline at end of file diff --git a/shared_utils/vector_stores/document_processing/pdf_parser.py b/shared_utils/vector_stores/document_processing/pdf_parser.py new file mode 100644 index 0000000000000000000000000000000000000000..1ad8caf8f6144373e812d9821084f7eb0085c530 --- /dev/null +++ b/shared_utils/vector_stores/document_processing/pdf_parser.py @@ -0,0 +1,137 @@ +""" +BasicRAG System - PDF Document Parser + +This module implements robust PDF text extraction functionality as part of the BasicRAG +technical documentation system. It serves as the entry point for document ingestion, +converting PDF files into structured text data suitable for chunking and embedding. + +Key Features: +- Page-by-page text extraction with metadata preservation +- Robust error handling for corrupted or malformed PDFs +- Performance timing for optimization analysis +- Memory-efficient processing for large documents + +Technical Approach: +- Uses PyMuPDF (fitz) for reliable text extraction across PDF versions +- Maintains document structure with page-level granularity +- Preserves PDF metadata (author, title, creation date, etc.) 
+ +Dependencies: +- PyMuPDF (fitz): Chosen for superior text extraction accuracy and speed +- Standard library: pathlib for cross-platform file handling + +Performance Characteristics: +- Typical processing: 10-50 pages/second on modern hardware +- Memory usage: O(n) with document size, but processes page-by-page +- Scales linearly with document length + +Author: Arthur Passuello +Date: June 2025 +Project: RAG Portfolio - Technical Documentation System +""" + +from typing import Dict, List, Any +from pathlib import Path +import time +import fitz # PyMuPDF + + +def extract_text_with_metadata(pdf_path: Path) -> Dict[str, Any]: + """ + Extract text and metadata from technical PDF documents with production-grade reliability. + + This function serves as the primary ingestion point for the RAG system, converting + PDF documents into structured data. It's optimized for technical documentation with + emphasis on preserving structure and handling various PDF formats gracefully. + + @param pdf_path: Path to the PDF file to process + @type pdf_path: pathlib.Path + + @return: Dictionary containing extracted text and comprehensive metadata + @rtype: Dict[str, Any] with the following structure: + { + "text": str, # Complete concatenated text from all pages + "pages": List[Dict], # Per-page breakdown with text and statistics + # Each page dict contains: + # - page_number: int (1-indexed for human readability) + # - text: str (raw text from that page) + # - char_count: int (character count for that page) + "metadata": Dict, # PDF metadata (title, author, subject, etc.) + "page_count": int, # Total number of pages processed + "extraction_time": float # Processing duration in seconds + } + + @throws FileNotFoundError: If the specified PDF file doesn't exist + @throws ValueError: If the PDF is corrupted, encrypted, or otherwise unreadable + + Performance Notes: + - Processes ~10-50 pages/second depending on PDF complexity + - Memory usage is proportional to document size but page-by-page processing + prevents loading entire document into memory at once + - Extraction time is included for performance monitoring and optimization + + Usage Example: + >>> pdf_path = Path("technical_manual.pdf") + >>> result = extract_text_with_metadata(pdf_path) + >>> print(f"Extracted {result['page_count']} pages in {result['extraction_time']:.2f}s") + >>> first_page_text = result['pages'][0]['text'] + """ + # Validate input file exists before attempting to open + if not pdf_path.exists(): + raise FileNotFoundError(f"PDF file not found: {pdf_path}") + + # Start performance timer for extraction analytics + start_time = time.perf_counter() + + try: + # Open PDF with PyMuPDF - automatically handles various PDF versions + # Using string conversion for compatibility with older fitz versions + doc = fitz.open(str(pdf_path)) + + # Extract document-level metadata (may include title, author, subject, keywords) + # Default to empty dict if no metadata present (common in scanned PDFs) + metadata = doc.metadata or {} + page_count = len(doc) + + # Initialize containers for page-by-page extraction + pages = [] # Will store individual page data + all_text = [] # Will store text for concatenation + + # Process each page sequentially to maintain document order + for page_num in range(page_count): + # Load page object (0-indexed internally) + page = doc[page_num] + + # Extract text using default extraction parameters + # This preserves reading order and handles multi-column layouts + page_text = page.get_text() + + # Store page data with 
human-readable page numbering (1-indexed) + pages.append({ + "page_number": page_num + 1, # Convert to 1-indexed for user clarity + "text": page_text, + "char_count": len(page_text) # Useful for chunking decisions + }) + + # Accumulate text for final concatenation + all_text.append(page_text) + + # Properly close the PDF to free resources + doc.close() + + # Calculate total extraction time for performance monitoring + extraction_time = time.perf_counter() - start_time + + # Return comprehensive extraction results + return { + "text": "\n".join(all_text), # Full document text with page breaks + "pages": pages, # Detailed page-by-page breakdown + "metadata": metadata, # Original PDF metadata + "page_count": page_count, # Total pages for quick reference + "extraction_time": extraction_time # Performance metric + } + + except Exception as e: + # Wrap any extraction errors with context for debugging + # Common causes: encrypted PDFs, corrupted files, unsupported formats + raise ValueError(f"Failed to process PDF: {e}") \ No newline at end of file diff --git a/shared_utils/vector_stores/document_processing/pdfplumber_parser.py b/shared_utils/vector_stores/document_processing/pdfplumber_parser.py new file mode 100644 index 0000000000000000000000000000000000000000..08ff8e9897c198447eb0b8d1def63e244298496c --- /dev/null +++ b/shared_utils/vector_stores/document_processing/pdfplumber_parser.py @@ -0,0 +1,452 @@ +#!/usr/bin/env python3 +""" +PDFPlumber-based Parser + +Advanced PDF parsing using pdfplumber for better structure detection +and cleaner text extraction. + +Author: Arthur Passuello +""" + +import re +import pdfplumber +from pathlib import Path +from typing import Dict, List, Optional, Tuple, Any + + +class PDFPlumberParser: + """Advanced PDF parser using pdfplumber for structure-aware extraction.""" + + def __init__(self, target_chunk_size: int = 1400, min_chunk_size: int = 800, + max_chunk_size: int = 2000): + """Initialize PDFPlumber parser.""" + self.target_chunk_size = target_chunk_size + self.min_chunk_size = min_chunk_size + self.max_chunk_size = max_chunk_size + + # Trash content patterns + self.trash_patterns = [ + r'Creative Commons.*?License', + r'International License.*?authors', + r'RISC-V International', + r'Visit.*?for further', + r'editors to suggest.*?corrections', + r'released under.*?license', + r'\.{5,}', # Long dots (TOC artifacts) + r'^\d+\s*$', # Page numbers alone + ] + + def extract_with_structure(self, pdf_path: Path) -> List[Dict]: + """Extract PDF content with structure awareness using pdfplumber.""" + chunks = [] + + with pdfplumber.open(pdf_path) as pdf: + current_section = None + current_text = [] + + for page_num, page in enumerate(pdf.pages): + # Extract text with formatting info + page_content = self._extract_page_content(page, page_num + 1) + + for element in page_content: + if element['type'] == 'header': + # Save previous section if exists + if current_text: + chunk_text = '\n\n'.join(current_text) + if self._is_valid_chunk(chunk_text): + chunks.extend(self._create_chunks( + chunk_text, + current_section or "Document", + page_num + )) + + # Start new section + current_section = element['text'] + current_text = [] + + elif element['type'] == 'content': + # Add to current section + if self._is_valid_content(element['text']): + current_text.append(element['text']) + + # Don't forget last section + if current_text: + chunk_text = '\n\n'.join(current_text) + if self._is_valid_chunk(chunk_text): + chunks.extend(self._create_chunks( + chunk_text, + 
current_section or "Document", + len(pdf.pages) + )) + + return chunks + + def _extract_page_content(self, page: Any, page_num: int) -> List[Dict]: + """Extract structured content from a page.""" + content = [] + + # Get all text with positioning + chars = page.chars + if not chars: + return content + + # Group by lines + lines = [] + current_line = [] + current_y = None + + for char in sorted(chars, key=lambda x: (x['top'], x['x0'])): + if current_y is None or abs(char['top'] - current_y) < 2: + current_line.append(char) + current_y = char['top'] + else: + if current_line: + lines.append(current_line) + current_line = [char] + current_y = char['top'] + + if current_line: + lines.append(current_line) + + # Analyze each line + for line in lines: + line_text = ''.join(char['text'] for char in line).strip() + + if not line_text: + continue + + # Detect headers by font size + avg_font_size = sum(char.get('size', 12) for char in line) / len(line) + is_bold = any(char.get('fontname', '').lower().count('bold') > 0 for char in line) + + # Classify content + if avg_font_size > 14 or is_bold: + # Likely a header + if self._is_valid_header(line_text): + content.append({ + 'type': 'header', + 'text': line_text, + 'font_size': avg_font_size, + 'page': page_num + }) + else: + # Regular content + content.append({ + 'type': 'content', + 'text': line_text, + 'font_size': avg_font_size, + 'page': page_num + }) + + return content + + def _is_valid_header(self, text: str) -> bool: + """Check if text is a valid header.""" + # Skip if too short or too long + if len(text) < 3 or len(text) > 200: + return False + + # Skip if matches trash patterns + for pattern in self.trash_patterns: + if re.search(pattern, text, re.IGNORECASE): + return False + + # Valid if starts with number or capital letter + if re.match(r'^(\d+\.?\d*\s+|[A-Z])', text): + return True + + # Valid if contains keywords + keywords = ['chapter', 'section', 'introduction', 'conclusion', 'appendix'] + return any(keyword in text.lower() for keyword in keywords) + + def _is_valid_content(self, text: str) -> bool: + """Check if text is valid content (not trash).""" + # Skip very short text + if len(text.strip()) < 10: + return False + + # Skip trash patterns + for pattern in self.trash_patterns: + if re.search(pattern, text, re.IGNORECASE): + return False + + return True + + def _is_valid_chunk(self, text: str) -> bool: + """Check if chunk text is valid.""" + # Must have minimum length + if len(text.strip()) < self.min_chunk_size // 2: + return False + + # Must have some alphabetic content + alpha_chars = sum(1 for c in text if c.isalpha()) + if alpha_chars < len(text) * 0.5: + return False + + return True + + def _create_chunks(self, text: str, title: str, page: int) -> List[Dict]: + """Create chunks from text.""" + chunks = [] + + # Clean text + text = self._clean_text(text) + + if len(text) <= self.max_chunk_size: + # Single chunk + chunks.append({ + 'text': text, + 'title': title, + 'page': page, + 'metadata': { + 'parsing_method': 'pdfplumber', + 'quality_score': self._calculate_quality_score(text) + } + }) + else: + # Split into chunks + text_chunks = self._split_text_into_chunks(text) + for i, chunk_text in enumerate(text_chunks): + chunks.append({ + 'text': chunk_text, + 'title': f"{title} (Part {i+1})", + 'page': page, + 'metadata': { + 'parsing_method': 'pdfplumber', + 'part_number': i + 1, + 'total_parts': len(text_chunks), + 'quality_score': self._calculate_quality_score(chunk_text) + } + }) + + return chunks + + def _clean_text(self, 
text: str) -> str: + """Clean text from artifacts.""" + # Remove volume headers (e.g., "Volume I: RISC-V Unprivileged ISA V20191213") + text = re.sub(r'Volume\s+[IVX]+:\s*RISC-V[^V]*V\d{8}\s*', '', text, flags=re.IGNORECASE) + text = re.sub(r'^\d+\s+Volume\s+[IVX]+:.*?$', '', text, flags=re.MULTILINE) + + # Remove document version artifacts + text = re.sub(r'Document Version \d{8}\s*', '', text, flags=re.IGNORECASE) + + # Remove repeated ISA headers + text = re.sub(r'RISC-V.*?ISA.*?V\d{8}\s*', '', text, flags=re.IGNORECASE) + text = re.sub(r'The RISC-V Instruction Set Manual\s*', '', text, flags=re.IGNORECASE) + + # Remove figure/table references that are standalone + text = re.sub(r'^(Figure|Table)\s+\d+\.\d+:.*?$', '', text, flags=re.MULTILINE) + + # Remove email addresses (often in contributor lists) + text = re.sub(r'[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}', '', text) + + # Remove URLs + text = re.sub(r'https?://[^\s]+', '', text) + + # Remove page numbers at start/end of lines + text = re.sub(r'^\d{1,3}\s+', '', text, flags=re.MULTILINE) + text = re.sub(r'\s+\d{1,3}$', '', text, flags=re.MULTILINE) + + # Remove excessive dots (TOC artifacts) + text = re.sub(r'\.{3,}', '', text) + + # Remove standalone numbers (often page numbers or figure numbers) + text = re.sub(r'^\s*\d+\s*$', '', text, flags=re.MULTILINE) + + # Clean up multiple spaces and newlines + text = re.sub(r'\s{3,}', ' ', text) + text = re.sub(r'\n{3,}', '\n\n', text) + text = re.sub(r'[ \t]+', ' ', text) # Normalize all whitespace + + # Remove common boilerplate phrases + text = re.sub(r'Contains Nonbinding Recommendations\s*', '', text, flags=re.IGNORECASE) + text = re.sub(r'Guidance for Industry and FDA Staff\s*', '', text, flags=re.IGNORECASE) + + return text.strip() + + def _split_text_into_chunks(self, text: str) -> List[str]: + """Split text into chunks at sentence boundaries.""" + sentences = re.split(r'(?<=[.!?])\s+', text) + chunks = [] + current_chunk = [] + current_size = 0 + + for sentence in sentences: + sentence_size = len(sentence) + + if current_size + sentence_size > self.target_chunk_size and current_chunk: + chunks.append(' '.join(current_chunk)) + current_chunk = [sentence] + current_size = sentence_size + else: + current_chunk.append(sentence) + current_size += sentence_size + 1 + + if current_chunk: + chunks.append(' '.join(current_chunk)) + + return chunks + + def _calculate_quality_score(self, text: str) -> float: + """Calculate quality score for chunk.""" + score = 1.0 + + # Penalize very short or very long + if len(text) < self.min_chunk_size: + score *= 0.8 + elif len(text) > self.max_chunk_size: + score *= 0.9 + + # Reward complete sentences + if text.strip().endswith(('.', '!', '?')): + score *= 1.1 + + # Reward technical content + technical_terms = ['risc', 'instruction', 'register', 'memory', 'processor'] + term_count = sum(1 for term in technical_terms if term in text.lower()) + score *= (1 + term_count * 0.05) + + return min(score, 1.0) + + def extract_with_page_coverage(self, pdf_path: Path, pymupdf_pages: List[Dict]) -> List[Dict]: + """ + Extract content ensuring ALL pages are covered using PyMuPDF page data. 
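+
+        Call flow (an illustrative sketch; the file name is made up, and the import
+        path mirrors the usage example in hybrid_parser.py rather than being verified
+        against the final package layout):
+
+            >>> from shared_utils.document_processing.pdf_parser import extract_text_with_metadata
+            >>> pdf_data = extract_text_with_metadata(Path("riscv-spec.pdf"))
+            >>> parser = PDFPlumberParser()
+            >>> chunks = parser.extract_with_page_coverage(Path("riscv-spec.pdf"), pdf_data["pages"])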
+ + Args: + pdf_path: Path to PDF file + pymupdf_pages: Page data from PyMuPDF with page numbers and text + + Returns: + List of chunks covering ALL document pages + """ + chunks = [] + chunk_id = 0 + + print(f"📄 Processing {len(pymupdf_pages)} pages with PDFPlumber quality extraction...") + + with pdfplumber.open(str(pdf_path)) as pdf: + for pymupdf_page in pymupdf_pages: + page_num = pymupdf_page['page_number'] # 1-indexed from PyMuPDF + page_idx = page_num - 1 # Convert to 0-indexed for PDFPlumber + + if page_idx < len(pdf.pages): + # Extract with PDFPlumber quality from this specific page + pdfplumber_page = pdf.pages[page_idx] + page_text = pdfplumber_page.extract_text() + + if page_text and page_text.strip(): + # Clean and chunk the page text + cleaned_text = self._clean_text(page_text) + + if len(cleaned_text) >= 100: # Minimum meaningful content + # Create chunks from this page + page_chunks = self._create_page_chunks( + cleaned_text, page_num, chunk_id + ) + chunks.extend(page_chunks) + chunk_id += len(page_chunks) + + if len(chunks) % 50 == 0: # Progress indicator + print(f" Processed {page_num} pages, created {len(chunks)} chunks") + + print(f"✅ Full coverage: {len(chunks)} chunks from {len(pymupdf_pages)} pages") + return chunks + + def _create_page_chunks(self, page_text: str, page_num: int, start_chunk_id: int) -> List[Dict]: + """Create properly sized chunks from a single page's content.""" + # Clean and validate page text first + cleaned_text = self._ensure_complete_sentences(page_text) + + if not cleaned_text or len(cleaned_text) < 50: + # Skip pages with insufficient content + return [] + + if len(cleaned_text) <= self.max_chunk_size: + # Single chunk for small pages + return [{ + 'text': cleaned_text, + 'title': f"Page {page_num}", + 'page': page_num, + 'metadata': { + 'parsing_method': 'pdfplumber_page_coverage', + 'quality_score': self._calculate_quality_score(cleaned_text), + 'full_page_coverage': True + } + }] + else: + # Split large pages into chunks with sentence boundaries + text_chunks = self._split_text_into_chunks(cleaned_text) + page_chunks = [] + + for i, chunk_text in enumerate(text_chunks): + # Ensure each chunk is complete + complete_chunk = self._ensure_complete_sentences(chunk_text) + + if complete_chunk and len(complete_chunk) >= 100: + page_chunks.append({ + 'text': complete_chunk, + 'title': f"Page {page_num} (Part {i+1})", + 'page': page_num, + 'metadata': { + 'parsing_method': 'pdfplumber_page_coverage', + 'part_number': i + 1, + 'total_parts': len(text_chunks), + 'quality_score': self._calculate_quality_score(complete_chunk), + 'full_page_coverage': True + } + }) + + return page_chunks + + def _ensure_complete_sentences(self, text: str) -> str: + """Ensure text contains only complete sentences.""" + text = text.strip() + if not text: + return "" + + # Find last complete sentence + last_sentence_end = -1 + for i, char in enumerate(reversed(text)): + if char in '.!?:': + last_sentence_end = len(text) - i + break + + if last_sentence_end > 0: + # Return text up to last complete sentence + complete_text = text[:last_sentence_end].strip() + + # Ensure it starts properly (capital letter or common starters) + if complete_text and (complete_text[0].isupper() or + complete_text.startswith(('The ', 'A ', 'An ', 'This ', 'RISC'))): + return complete_text + + # If no complete sentences found, return empty + return "" + + def parse_document(self, pdf_path: Path, pdf_data: Dict[str, Any] = None) -> List[Dict]: + """ + Parse document using PDFPlumber (required by 
HybridParser). + + Args: + pdf_path: Path to PDF file + pdf_data: PyMuPDF page data to ensure full page coverage + + Returns: + List of chunks with structure preservation across ALL pages + """ + if pdf_data and 'pages' in pdf_data: + # Use PyMuPDF page data to ensure full coverage + return self.extract_with_page_coverage(pdf_path, pdf_data['pages']) + else: + # Fallback to structure-based extraction + return self.extract_with_structure(pdf_path) + + +def parse_pdf_with_pdfplumber(pdf_path: Path, **kwargs) -> List[Dict]: + """Main entry point for PDFPlumber parsing.""" + parser = PDFPlumberParser(**kwargs) + chunks = parser.extract_with_structure(pdf_path) + + print(f"PDFPlumber extracted {len(chunks)} chunks") + + return chunks \ No newline at end of file diff --git a/shared_utils/vector_stores/document_processing/toc_guided_parser.py b/shared_utils/vector_stores/document_processing/toc_guided_parser.py new file mode 100644 index 0000000000000000000000000000000000000000..613820600207c0994ac745d29e24491fee38ec69 --- /dev/null +++ b/shared_utils/vector_stores/document_processing/toc_guided_parser.py @@ -0,0 +1,311 @@ +#!/usr/bin/env python3 +""" +TOC-Guided PDF Parser + +Uses the Table of Contents to guide intelligent chunking that respects +document structure and hierarchy. + +Author: Arthur Passuello +""" + +import re +from typing import Dict, List, Optional, Tuple +from dataclasses import dataclass + + +@dataclass +class TOCEntry: + """Represents a table of contents entry.""" + title: str + page: int + level: int # 0 for chapters, 1 for sections, 2 for subsections + parent: Optional[str] = None + parent_title: Optional[str] = None # Added for hybrid parser compatibility + + +class TOCGuidedParser: + """Parser that uses TOC to create structure-aware chunks.""" + + def __init__(self, target_chunk_size: int = 1400, min_chunk_size: int = 800, + max_chunk_size: int = 2000): + """Initialize TOC-guided parser.""" + self.target_chunk_size = target_chunk_size + self.min_chunk_size = min_chunk_size + self.max_chunk_size = max_chunk_size + + def parse_toc(self, pages: List[Dict]) -> List[TOCEntry]: + """Parse table of contents from pages.""" + toc_entries = [] + + # Find TOC pages (usually early in document) + toc_pages = [] + for i, page in enumerate(pages[:20]): # Check first 20 pages + page_text = page.get('text', '').lower() + if 'contents' in page_text or 'table of contents' in page_text: + toc_pages.append((i, page)) + + if not toc_pages: + print("No TOC found, using fallback structure detection") + return self._detect_structure_without_toc(pages) + + # Parse TOC entries + for page_idx, page in toc_pages: + text = page.get('text', '') + lines = text.split('\n') + + i = 0 + while i < len(lines): + line = lines[i].strip() + + # Skip empty lines and TOC header + if not line or 'contents' in line.lower(): + i += 1 + continue + + # Pattern 1: "1.1 Title .... 23" + match1 = re.match(r'^(\d+(?:\.\d+)*)\s+(.+?)\s*\.{2,}\s*(\d+)$', line) + if match1: + number, title, page_num = match1.groups() + level = len(number.split('.')) - 1 + toc_entries.append(TOCEntry( + title=title.strip(), + page=int(page_num), + level=level + )) + i += 1 + continue + + # Pattern 2: Multi-line format + # "1.1" + # "Title" + # ". . . . 23" + if re.match(r'^(\d+(?:\.\d+)*)$', line): + number = line + if i + 1 < len(lines): + title_line = lines[i + 1].strip() + if i + 2 < len(lines): + dots_line = lines[i + 2].strip() + page_match = re.search(r'(\d+)\s*$', dots_line) + if page_match and '.' 
in dots_line: + title = title_line + page_num = int(page_match.group(1)) + level = len(number.split('.')) - 1 + toc_entries.append(TOCEntry( + title=title, + page=page_num, + level=level + )) + i += 3 + continue + + # Pattern 3: "Chapter 1: Title ... 23" + match3 = re.match(r'^(Chapter|Section|Part)\s+(\d+):?\s+(.+?)\s*\.{2,}\s*(\d+)$', line, re.IGNORECASE) + if match3: + prefix, number, title, page_num = match3.groups() + level = 0 if prefix.lower() == 'chapter' else 1 + toc_entries.append(TOCEntry( + title=f"{prefix} {number}: {title}", + page=int(page_num), + level=level + )) + i += 1 + continue + + i += 1 + + # Add parent relationships + for i, entry in enumerate(toc_entries): + if entry.level > 0: + # Find parent (previous entry with lower level) + for j in range(i - 1, -1, -1): + if toc_entries[j].level < entry.level: + entry.parent = toc_entries[j].title + entry.parent_title = toc_entries[j].title # Set both for compatibility + break + + return toc_entries + + def _detect_structure_without_toc(self, pages: List[Dict]) -> List[TOCEntry]: + """Fallback: detect structure from content patterns across ALL pages.""" + entries = [] + + # Expanded patterns for better structure detection + chapter_patterns = [ + re.compile(r'^(Chapter|CHAPTER)\s+(\d+|[IVX]+)(?:\s*[:\-]\s*(.+))?', re.MULTILINE), + re.compile(r'^(\d+)\s+([A-Z][^.]*?)(?:\s*\.{2,}\s*\d+)?$', re.MULTILINE), # "1 Introduction" + re.compile(r'^([A-Z][A-Z\s]{10,})$', re.MULTILINE), # ALL CAPS titles + ] + + section_patterns = [ + re.compile(r'^(\d+\.\d+)\s+(.+?)(?:\s*\.{2,}\s*\d+)?$', re.MULTILINE), # "1.1 Section" + re.compile(r'^(\d+\.\d+\.\d+)\s+(.+?)(?:\s*\.{2,}\s*\d+)?$', re.MULTILINE), # "1.1.1 Subsection" + ] + + # Process ALL pages, not just first 20 + for i, page in enumerate(pages): + text = page.get('text', '') + if not text.strip(): + continue + + # Find chapters with various patterns + for pattern in chapter_patterns: + for match in pattern.finditer(text): + if len(match.groups()) >= 2: + if len(match.groups()) >= 3 and match.group(3): + title = match.group(3).strip() + else: + title = match.group(2).strip() if match.group(2) else f"Section {match.group(1)}" + + # Skip very short or likely false positives + if len(title) >= 3 and not re.match(r'^\d+$', title): + entries.append(TOCEntry( + title=title, + page=i + 1, + level=0 + )) + + # Find sections + for pattern in section_patterns: + for match in pattern.finditer(text): + section_num = match.group(1) + title = match.group(2).strip() if len(match.groups()) >= 2 else f"Section {section_num}" + + # Determine level by number of dots + level = section_num.count('.') + + # Skip very short titles or obvious artifacts + if len(title) >= 3 and not re.match(r'^\d+$', title): + entries.append(TOCEntry( + title=title, + page=i + 1, + level=level + )) + + # If still no entries found, create page-based entries for full coverage + if not entries: + print("No structure patterns found, creating page-based sections for full coverage") + # Create sections every 10 pages to ensure full document coverage + for i in range(0, len(pages), 10): + start_page = i + 1 + end_page = min(i + 10, len(pages)) + title = f"Pages {start_page}-{end_page}" + entries.append(TOCEntry( + title=title, + page=start_page, + level=0 + )) + + return entries + + def create_chunks_from_toc(self, pdf_data: Dict, toc_entries: List[TOCEntry]) -> List[Dict]: + """Create chunks based on TOC structure.""" + chunks = [] + pages = pdf_data.get('pages', []) + + for i, entry in enumerate(toc_entries): + # Determine page 
range for this entry + start_page = entry.page - 1 # Convert to 0-indexed + + # Find end page (start of next entry at same or higher level) + end_page = len(pages) + for j in range(i + 1, len(toc_entries)): + if toc_entries[j].level <= entry.level: + end_page = toc_entries[j].page - 1 + break + + # Extract text for this section + section_text = [] + for page_idx in range(max(0, start_page), min(end_page, len(pages))): + page_text = pages[page_idx].get('text', '') + if page_text.strip(): + section_text.append(page_text) + + if not section_text: + continue + + full_text = '\n\n'.join(section_text) + + # Create chunks from section text + if len(full_text) <= self.max_chunk_size: + # Single chunk for small sections + chunks.append({ + 'text': full_text.strip(), + 'title': entry.title, + 'parent_title': entry.parent_title or entry.parent or '', + 'level': entry.level, + 'page': entry.page, + 'context': f"From {entry.title}", + 'metadata': { + 'parsing_method': 'toc_guided', + 'section_title': entry.title, + 'hierarchy_level': entry.level + } + }) + else: + # Split large sections into chunks + section_chunks = self._split_text_into_chunks(full_text) + for j, chunk_text in enumerate(section_chunks): + chunks.append({ + 'text': chunk_text.strip(), + 'title': f"{entry.title} (Part {j+1})", + 'parent_title': entry.parent_title or entry.parent or '', + 'level': entry.level, + 'page': entry.page, + 'context': f"Part {j+1} of {entry.title}", + 'metadata': { + 'parsing_method': 'toc_guided', + 'section_title': entry.title, + 'hierarchy_level': entry.level, + 'part_number': j + 1, + 'total_parts': len(section_chunks) + } + }) + + return chunks + + def _split_text_into_chunks(self, text: str) -> List[str]: + """Split text into chunks while preserving sentence boundaries.""" + sentences = re.split(r'(?<=[.!?])\s+', text) + chunks = [] + current_chunk = [] + current_size = 0 + + for sentence in sentences: + sentence_size = len(sentence) + + if current_size + sentence_size > self.target_chunk_size and current_chunk: + # Save current chunk + chunks.append(' '.join(current_chunk)) + current_chunk = [sentence] + current_size = sentence_size + else: + current_chunk.append(sentence) + current_size += sentence_size + 1 # +1 for space + + if current_chunk: + chunks.append(' '.join(current_chunk)) + + return chunks + + +def parse_pdf_with_toc_guidance(pdf_data: Dict, **kwargs) -> List[Dict]: + """Main entry point for TOC-guided parsing.""" + parser = TOCGuidedParser(**kwargs) + + # Parse TOC + pages = pdf_data.get('pages', []) + toc_entries = parser.parse_toc(pages) + + print(f"Found {len(toc_entries)} TOC entries") + + if not toc_entries: + print("No TOC entries found, falling back to basic chunking") + from .chunker import chunk_technical_text + return chunk_technical_text(pdf_data.get('text', '')) + + # Create chunks based on TOC + chunks = parser.create_chunks_from_toc(pdf_data, toc_entries) + + print(f"Created {len(chunks)} chunks from TOC structure") + + return chunks \ No newline at end of file diff --git a/shared_utils/vector_stores/embeddings/__init__.py b/shared_utils/vector_stores/embeddings/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a56edba636fcf28480a085092019159318cee266 --- /dev/null +++ b/shared_utils/vector_stores/embeddings/__init__.py @@ -0,0 +1 @@ +# Embeddings module \ No newline at end of file diff --git a/shared_utils/vector_stores/embeddings/__pycache__/__init__.cpython-312.pyc b/shared_utils/vector_stores/embeddings/__pycache__/__init__.cpython-312.pyc 
new file mode 100644 index 0000000000000000000000000000000000000000..e4ec2d7d00bae47cd42db0335f28226e9521959c Binary files /dev/null and b/shared_utils/vector_stores/embeddings/__pycache__/__init__.cpython-312.pyc differ diff --git a/shared_utils/vector_stores/embeddings/__pycache__/generator.cpython-312.pyc b/shared_utils/vector_stores/embeddings/__pycache__/generator.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eb3aa25f845ca25f73d8e77f92568c5e8da3e912 Binary files /dev/null and b/shared_utils/vector_stores/embeddings/__pycache__/generator.cpython-312.pyc differ diff --git a/shared_utils/vector_stores/embeddings/generator.py b/shared_utils/vector_stores/embeddings/generator.py new file mode 100644 index 0000000000000000000000000000000000000000..b0dba2656760b6b8bf6f64e0fa2d3e66e7bc0e4b --- /dev/null +++ b/shared_utils/vector_stores/embeddings/generator.py @@ -0,0 +1,84 @@ +import numpy as np +import torch +from typing import List, Optional +from sentence_transformers import SentenceTransformer + +# Global cache for embeddings +_embedding_cache = {} +_model_cache = {} + + +def generate_embeddings( + texts: List[str], + model_name: str = "sentence-transformers/multi-qa-MiniLM-L6-cos-v1", + batch_size: int = 32, + use_mps: bool = True, +) -> np.ndarray: + """ + Generate embeddings for text chunks with caching. + + Args: + texts: List of text chunks to embed + model_name: SentenceTransformer model identifier + batch_size: Processing batch size + use_mps: Use Apple Silicon acceleration + + Returns: + numpy array of shape (len(texts), embedding_dim) + + Performance Target: + - 100 texts/second on M4-Pro + - 384-dimensional embeddings + - Memory usage <500MB + """ + # Check cache for all texts + cache_keys = [f"{model_name}:{text}" for text in texts] + cached_embeddings = [] + texts_to_compute = [] + compute_indices = [] + + for i, key in enumerate(cache_keys): + if key in _embedding_cache: + cached_embeddings.append((i, _embedding_cache[key])) + else: + texts_to_compute.append(texts[i]) + compute_indices.append(i) + + # Load model if needed + if model_name not in _model_cache: + model = SentenceTransformer(model_name) + device = 'mps' if use_mps and torch.backends.mps.is_available() else 'cpu' + model = model.to(device) + model.eval() + _model_cache[model_name] = model + else: + model = _model_cache[model_name] + + # Compute new embeddings + if texts_to_compute: + with torch.no_grad(): + new_embeddings = model.encode( + texts_to_compute, + batch_size=batch_size, + convert_to_numpy=True, + normalize_embeddings=False + ).astype(np.float32) + + # Cache new embeddings + for i, text in enumerate(texts_to_compute): + key = f"{model_name}:{text}" + _embedding_cache[key] = new_embeddings[i] + + # Reconstruct full embedding array + result = np.zeros((len(texts), 384), dtype=np.float32) + + # Fill cached embeddings + for idx, embedding in cached_embeddings: + result[idx] = embedding + + # Fill newly computed embeddings + if texts_to_compute: + for i, original_idx in enumerate(compute_indices): + result[original_idx] = new_embeddings[i] + + return result diff --git a/shared_utils/vector_stores/generation/__pycache__/answer_generator.cpython-312.pyc b/shared_utils/vector_stores/generation/__pycache__/answer_generator.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a4f219a09f72d62b17da1ed85bc6a3d23f8f83e0 Binary files /dev/null and b/shared_utils/vector_stores/generation/__pycache__/answer_generator.cpython-312.pyc differ diff --git 
a/shared_utils/vector_stores/generation/__pycache__/prompt_templates.cpython-312.pyc b/shared_utils/vector_stores/generation/__pycache__/prompt_templates.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f21016f0b34dd53e605320c1602aaa914e364616 Binary files /dev/null and b/shared_utils/vector_stores/generation/__pycache__/prompt_templates.cpython-312.pyc differ diff --git a/shared_utils/vector_stores/generation/answer_generator.py b/shared_utils/vector_stores/generation/answer_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..af02878e747c06f1695758f325ddfc778dbaf1db --- /dev/null +++ b/shared_utils/vector_stores/generation/answer_generator.py @@ -0,0 +1,703 @@ +""" +Answer generation module using Ollama for local LLM inference. + +This module provides answer generation with citation support for RAG systems, +optimized for technical documentation Q&A on Apple Silicon. +""" + +import json +import logging +from dataclasses import dataclass +from typing import List, Dict, Any, Optional, Generator, Tuple +import ollama +from datetime import datetime +import re +from pathlib import Path +import sys + +# Import calibration framework +try: + # Try relative import first + from ...project_1_technical_rag.src.confidence_calibration import ConfidenceCalibrator +except ImportError: + # Fallback to absolute import + project_root = Path(__file__).parent.parent.parent / "project-1-technical-rag" + sys.path.insert(0, str(project_root / "src")) + from confidence_calibration import ConfidenceCalibrator + +logger = logging.getLogger(__name__) + + +@dataclass +class Citation: + """Represents a citation to a source document chunk.""" + chunk_id: str + page_number: int + source_file: str + relevance_score: float + text_snippet: str + + +@dataclass +class GeneratedAnswer: + """Represents a generated answer with citations.""" + answer: str + citations: List[Citation] + confidence_score: float + generation_time: float + model_used: str + context_used: List[Dict[str, Any]] + + +class AnswerGenerator: + """ + Generates answers using local LLMs via Ollama with citation support. + + Optimized for technical documentation Q&A with: + - Streaming response support + - Citation extraction and formatting + - Confidence scoring + - Fallback model support + """ + + def __init__( + self, + primary_model: str = "llama3.2:3b", + fallback_model: str = "mistral:latest", + temperature: float = 0.3, + max_tokens: int = 1024, + stream: bool = True, + enable_calibration: bool = True + ): + """ + Initialize the answer generator. 
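+
+        Example (an illustrative sketch; it assumes both models were pulled with
+        ollama beforehand, otherwise _verify_models raises, and that retrieved_chunks
+        is a list of retrieval hits shaped like the dicts described in _format_context):
+
+            >>> generator = AnswerGenerator(primary_model="llama3.2:3b", stream=False)
+            >>> result = generator.generate("What is RISC-V?", retrieved_chunks)
+            >>> print(result.answer, result.confidence_score)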
+ + Args: + primary_model: Primary Ollama model to use + fallback_model: Fallback model for complex queries + temperature: Generation temperature (0.0-1.0) + max_tokens: Maximum tokens to generate + stream: Whether to stream responses + enable_calibration: Whether to enable confidence calibration + """ + self.primary_model = primary_model + self.fallback_model = fallback_model + self.temperature = temperature + self.max_tokens = max_tokens + self.stream = stream + self.client = ollama.Client() + + # Initialize confidence calibration + self.enable_calibration = enable_calibration + self.calibrator = None + if enable_calibration: + try: + self.calibrator = ConfidenceCalibrator() + logger.info("Confidence calibration enabled") + except Exception as e: + logger.warning(f"Failed to initialize calibration: {e}") + self.enable_calibration = False + + # Verify models are available + self._verify_models() + + def _verify_models(self) -> None: + """Verify that required models are available.""" + try: + model_list = self.client.list() + available_models = [] + + # Handle Ollama's ListResponse object + if hasattr(model_list, 'models'): + for model in model_list.models: + if hasattr(model, 'model'): + available_models.append(model.model) + elif isinstance(model, dict) and 'model' in model: + available_models.append(model['model']) + + if self.primary_model not in available_models: + logger.warning(f"Primary model {self.primary_model} not found. Available models: {available_models}") + raise ValueError(f"Model {self.primary_model} not available. Please run: ollama pull {self.primary_model}") + + if self.fallback_model not in available_models: + logger.warning(f"Fallback model {self.fallback_model} not found in: {available_models}") + + except Exception as e: + logger.error(f"Error verifying models: {e}") + raise + + def _create_system_prompt(self) -> str: + """Create system prompt for technical documentation Q&A.""" + return """You are a technical documentation assistant that provides clear, accurate answers based on the provided context. + +CORE PRINCIPLES: +1. ANSWER DIRECTLY: If context contains the answer, provide it clearly and confidently +2. BE CONCISE: Keep responses focused and avoid unnecessary uncertainty language +3. CITE ACCURATELY: Use [chunk_X] citations for every fact from context + +RESPONSE GUIDELINES: +- If context has sufficient information → Answer directly and confidently +- If context has partial information → Answer what's available, note what's missing briefly +- If context is irrelevant → Brief refusal: "This information isn't available in the provided documents" + +CITATION FORMAT: +- Use [chunk_1], [chunk_2] etc. for all facts from context +- Example: "According to [chunk_1], RISC-V is an open-source architecture." + +WHAT TO AVOID: +- Do NOT add details not in context +- Do NOT second-guess yourself if context is clear +- Do NOT use phrases like "does not contain sufficient information" when context clearly answers the question +- Do NOT be overly cautious when context is adequate + +Be direct, confident, and accurate. If the context answers the question, provide that answer clearly.""" + + def _format_context(self, chunks: List[Dict[str, Any]]) -> str: + """ + Format retrieved chunks into context for the LLM. 
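+
+        Expected chunk shape (an illustrative sketch; the field values are made up,
+        the keys are the ones this module reads, with 'score' used later for
+        citations and confidence rather than here):
+
+            >>> chunks = [{"content": "RISC-V has 32 integer registers.",
+            ...            "metadata": {"page_number": 17, "source": "riscv-spec.pdf"},
+            ...            "score": 0.82}]
+
+        Each chunk is rendered as a "[chunk_N] (Page P from S):" block, and the
+        blocks are joined with "---" separators.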
+ + Args: + chunks: List of retrieved chunks with metadata + + Returns: + Formatted context string + """ + context_parts = [] + + for i, chunk in enumerate(chunks): + chunk_text = chunk.get('content', chunk.get('text', '')) + page_num = chunk.get('metadata', {}).get('page_number', 'unknown') + source = chunk.get('metadata', {}).get('source', 'unknown') + + context_parts.append( + f"[chunk_{i+1}] (Page {page_num} from {source}):\n{chunk_text}\n" + ) + + return "\n---\n".join(context_parts) + + def _extract_citations(self, answer: str, chunks: List[Dict[str, Any]]) -> Tuple[str, List[Citation]]: + """ + Extract citations from the generated answer and integrate them naturally. + + Args: + answer: Generated answer with [chunk_X] citations + chunks: Original chunks used for context + + Returns: + Tuple of (natural_answer, citations) + """ + citations = [] + citation_pattern = r'\[chunk_(\d+)\]' + + cited_chunks = set() + + # Find [chunk_X] citations and collect cited chunks + matches = re.finditer(citation_pattern, answer) + for match in matches: + chunk_idx = int(match.group(1)) - 1 # Convert to 0-based index + if 0 <= chunk_idx < len(chunks): + cited_chunks.add(chunk_idx) + + # Create Citation objects for each cited chunk + chunk_to_source = {} + for idx in cited_chunks: + chunk = chunks[idx] + citation = Citation( + chunk_id=chunk.get('id', f'chunk_{idx}'), + page_number=chunk.get('metadata', {}).get('page_number', 0), + source_file=chunk.get('metadata', {}).get('source', 'unknown'), + relevance_score=chunk.get('score', 0.0), + text_snippet=chunk.get('content', chunk.get('text', ''))[:200] + '...' + ) + citations.append(citation) + + # Map chunk reference to natural source name + source_name = chunk.get('metadata', {}).get('source', 'unknown') + if source_name != 'unknown': + # Use just the filename without extension for natural reference + natural_name = Path(source_name).stem.replace('-', ' ').replace('_', ' ') + chunk_to_source[f'[chunk_{idx+1}]'] = f"the {natural_name} documentation" + else: + chunk_to_source[f'[chunk_{idx+1}]'] = "the documentation" + + # Replace [chunk_X] with natural references instead of removing them + natural_answer = answer + for chunk_ref, natural_ref in chunk_to_source.items(): + natural_answer = natural_answer.replace(chunk_ref, natural_ref) + + # Clean up any remaining unreferenced citations (fallback) + natural_answer = re.sub(r'\[chunk_\d+\]', 'the documentation', natural_answer) + + # Clean up multiple spaces and formatting + natural_answer = re.sub(r'\s+', ' ', natural_answer).strip() + + return natural_answer, citations + + def _calculate_confidence(self, answer: str, citations: List[Citation], chunks: List[Dict[str, Any]]) -> float: + """ + Calculate confidence score for the generated answer with improved calibration. 
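+
+        Heuristic outline (matches the implementation below): the base score is
+        0.6 / 0.4 / 0.2 / 0.05 depending on the best retrieval score among the chunks
+        (>=0.8 / >=0.6 / >=0.4 / lower), explicit refusal phrasing caps the result at
+        0.1, low-relevance context caps it at 0.08, citation coverage and answer
+        length add or subtract smaller adjustments, and everything is capped at 0.95
+        before optional temperature-scaling calibration. As a rough worked case, a
+        single chunk scored 0.85 that the answer cites directly generally lands at
+        the 0.95 cap.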
+ + Args: + answer: Generated answer + citations: Extracted citations + chunks: Retrieved chunks + + Returns: + Confidence score (0.0-1.0) + """ + # Check if no chunks were provided first + if not chunks: + return 0.05 # No context = very low confidence + + # Assess context quality to determine base confidence + scores = [chunk.get('score', 0) for chunk in chunks] + max_relevance = max(scores) if scores else 0 + avg_relevance = sum(scores) / len(scores) if scores else 0 + + # Dynamic base confidence based on context quality + if max_relevance >= 0.8: + confidence = 0.6 # High-quality context starts high + elif max_relevance >= 0.6: + confidence = 0.4 # Good context starts moderately + elif max_relevance >= 0.4: + confidence = 0.2 # Fair context starts low + else: + confidence = 0.05 # Poor context starts very low + + # Strong uncertainty and explicit refusal indicators + strong_uncertainty_phrases = [ + "does not contain sufficient information", + "context does not provide", + "insufficient information", + "cannot determine", + "refuse to answer", + "cannot answer", + "does not contain relevant", + "no relevant context", + "missing from the provided context" + ] + + # Weak uncertainty phrases that might be in nuanced but correct answers + weak_uncertainty_phrases = [ + "unclear", + "conflicting", + "not specified", + "questionable", + "not contained", + "no mention", + "no relevant", + "missing", + "not explicitly" + ] + + # Check for strong uncertainty - these should drastically reduce confidence + if any(phrase in answer.lower() for phrase in strong_uncertainty_phrases): + return min(0.1, confidence * 0.2) # Max 10% for explicit refusal/uncertainty + + # Check for weak uncertainty - reduce but don't destroy confidence for good context + weak_uncertainty_count = sum(1 for phrase in weak_uncertainty_phrases if phrase in answer.lower()) + if weak_uncertainty_count > 0: + if max_relevance >= 0.7 and citations: + # Good context with citations - reduce less severely + confidence *= (0.8 ** weak_uncertainty_count) # Moderate penalty + else: + # Poor context - reduce more severely + confidence *= (0.5 ** weak_uncertainty_count) # Strong penalty + + # If all chunks have very low relevance scores, cap confidence low + if max_relevance < 0.4: + return min(0.08, confidence) # Max 8% for low relevance context + + # Factor 1: Citation quality and coverage + if citations and chunks: + citation_ratio = len(citations) / min(len(chunks), 3) + + # Strong boost for high-relevance citations + relevant_chunks = [c for c in chunks if c.get('score', 0) > 0.6] + if relevant_chunks: + # Significant boost for citing relevant chunks + confidence += 0.25 * citation_ratio + + # Extra boost if citing majority of relevant chunks + if len(citations) >= len(relevant_chunks) * 0.5: + confidence += 0.15 + else: + # Small boost for citations to lower-relevance chunks + confidence += 0.1 * citation_ratio + else: + # No citations = reduce confidence unless it's a simple factual statement + if max_relevance >= 0.8 and len(answer.split()) < 20: + confidence *= 0.8 # Gentle penalty for uncited but simple answers + else: + confidence *= 0.6 # Stronger penalty for complex uncited answers + + # Factor 2: Relevance score reinforcement + if citations: + avg_citation_relevance = sum(c.relevance_score for c in citations) / len(citations) + if avg_citation_relevance > 0.8: + confidence += 0.2 # Strong boost for highly relevant citations + elif avg_citation_relevance > 0.6: + confidence += 0.1 # Moderate boost + elif 
avg_citation_relevance < 0.4: + confidence *= 0.6 # Penalty for low-relevance citations + + # Factor 3: Context utilization quality + if chunks: + avg_chunk_length = sum(len(chunk.get('content', chunk.get('text', ''))) for chunk in chunks) / len(chunks) + + # Boost for substantial, high-quality context + if avg_chunk_length > 200 and max_relevance > 0.8: + confidence += 0.1 + elif avg_chunk_length < 50: # Very short chunks + confidence *= 0.8 + + # Factor 4: Answer characteristics + answer_words = len(answer.split()) + if answer_words < 10: + confidence *= 0.9 # Slight penalty for very short answers + elif answer_words > 50 and citations: + confidence += 0.05 # Small boost for detailed cited answers + + # Factor 5: High-quality scenario bonus + if (max_relevance >= 0.8 and citations and + len(citations) > 0 and + not any(phrase in answer.lower() for phrase in strong_uncertainty_phrases)): + # This is a high-quality response scenario + confidence += 0.15 + + raw_confidence = min(confidence, 0.95) # Cap at 95% to maintain some uncertainty + + # Apply temperature scaling calibration if available + if self.enable_calibration and self.calibrator and self.calibrator.is_fitted: + try: + calibrated_confidence = self.calibrator.calibrate_confidence(raw_confidence) + logger.debug(f"Confidence calibrated: {raw_confidence:.3f} -> {calibrated_confidence:.3f}") + return calibrated_confidence + except Exception as e: + logger.warning(f"Calibration failed, using raw confidence: {e}") + + return raw_confidence + + def fit_calibration(self, validation_data: List[Dict[str, Any]]) -> float: + """ + Fit temperature scaling calibration using validation data. + + Args: + validation_data: List of dicts with 'confidence' and 'correctness' keys + + Returns: + Optimal temperature parameter + """ + if not self.enable_calibration or not self.calibrator: + logger.warning("Calibration not enabled or not available") + return 1.0 + + try: + confidences = [item['confidence'] for item in validation_data] + correctness = [item['correctness'] for item in validation_data] + + optimal_temp = self.calibrator.fit_temperature_scaling(confidences, correctness) + logger.info(f"Calibration fitted with temperature: {optimal_temp:.3f}") + return optimal_temp + + except Exception as e: + logger.error(f"Failed to fit calibration: {e}") + return 1.0 + + def save_calibration(self, filepath: str) -> bool: + """Save fitted calibration to file.""" + if not self.calibrator or not self.calibrator.is_fitted: + logger.warning("No fitted calibration to save") + return False + + try: + calibration_data = { + 'temperature': self.calibrator.temperature, + 'is_fitted': self.calibrator.is_fitted, + 'model_info': { + 'primary_model': self.primary_model, + 'fallback_model': self.fallback_model + } + } + + with open(filepath, 'w') as f: + json.dump(calibration_data, f, indent=2) + + logger.info(f"Calibration saved to {filepath}") + return True + + except Exception as e: + logger.error(f"Failed to save calibration: {e}") + return False + + def load_calibration(self, filepath: str) -> bool: + """Load fitted calibration from file.""" + if not self.enable_calibration or not self.calibrator: + logger.warning("Calibration not enabled") + return False + + try: + with open(filepath, 'r') as f: + calibration_data = json.load(f) + + self.calibrator.temperature = calibration_data['temperature'] + self.calibrator.is_fitted = calibration_data['is_fitted'] + + logger.info(f"Calibration loaded from {filepath} (temp: {self.calibrator.temperature:.3f})") + return True + + 
except Exception as e: + logger.error(f"Failed to load calibration: {e}") + return False + + def generate( + self, + query: str, + chunks: List[Dict[str, Any]], + use_fallback: bool = False + ) -> GeneratedAnswer: + """ + Generate an answer based on the query and retrieved chunks. + + Args: + query: User's question + chunks: Retrieved document chunks + use_fallback: Whether to use fallback model + + Returns: + GeneratedAnswer object with answer, citations, and metadata + """ + start_time = datetime.now() + model = self.fallback_model if use_fallback else self.primary_model + + # Check for no-context or very poor context situation + if not chunks or all(len(chunk.get('content', chunk.get('text', ''))) < 20 for chunk in chunks): + # Handle no-context situation with brief, professional refusal + user_prompt = f"""Context: [NO RELEVANT CONTEXT FOUND] + +Question: {query} + +INSTRUCTION: Respond with exactly this brief message: + +"This information isn't available in the provided documents." + +DO NOT elaborate, explain, or add any other information.""" + else: + # Format context from chunks + context = self._format_context(chunks) + + # Create concise prompt for faster generation + user_prompt = f"""Context: +{context} + +Question: {query} + +Instructions: Answer using only the context above. Cite with [chunk_1], [chunk_2] etc. + +Answer:""" + + try: + # Generate response + response = self.client.chat( + model=model, + messages=[ + {"role": "system", "content": self._create_system_prompt()}, + {"role": "user", "content": user_prompt} + ], + options={ + "temperature": self.temperature, + "num_predict": min(self.max_tokens, 300), # Reduce max tokens for speed + "top_k": 40, # Optimize sampling for speed + "top_p": 0.9, + "repeat_penalty": 1.1 + }, + stream=False # Get complete response for processing + ) + + # Extract answer + answer_with_citations = response['message']['content'] + + # Extract and clean citations + clean_answer, citations = self._extract_citations(answer_with_citations, chunks) + + # Calculate confidence + confidence = self._calculate_confidence(clean_answer, citations, chunks) + + # Calculate generation time + generation_time = (datetime.now() - start_time).total_seconds() + + return GeneratedAnswer( + answer=clean_answer, + citations=citations, + confidence_score=confidence, + generation_time=generation_time, + model_used=model, + context_used=chunks + ) + + except Exception as e: + logger.error(f"Error generating answer: {e}") + # Return a fallback response + return GeneratedAnswer( + answer="I apologize, but I encountered an error while generating the answer. Please try again.", + citations=[], + confidence_score=0.0, + generation_time=0.0, + model_used=model, + context_used=chunks + ) + + def generate_stream( + self, + query: str, + chunks: List[Dict[str, Any]], + use_fallback: bool = False + ) -> Generator[str, None, GeneratedAnswer]: + """ + Generate an answer with streaming support. 
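+
+        Illustrative usage sketch (assumes a default-constructed AnswerGenerator
+        backed by a running Ollama server and `chunks` produced by the retrieval
+        layer; not part of the committed example code):
+
+            generator = AnswerGenerator()
+            for token in generator.generate_stream("What is RISC-V?", chunks):
+                print(token, end="", flush=True)
+            # the completed GeneratedAnswer is the generator's return value
+            # (available as StopIteration.value when driving the generator manually)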
+ + Args: + query: User's question + chunks: Retrieved document chunks + use_fallback: Whether to use fallback model + + Yields: + Partial answer strings + + Returns: + Final GeneratedAnswer object + """ + start_time = datetime.now() + model = self.fallback_model if use_fallback else self.primary_model + + # Check for no-context or very poor context situation + if not chunks or all(len(chunk.get('content', chunk.get('text', ''))) < 20 for chunk in chunks): + # Handle no-context situation with brief, professional refusal + user_prompt = f"""Context: [NO RELEVANT CONTEXT FOUND] + +Question: {query} + +INSTRUCTION: Respond with exactly this brief message: + +"This information isn't available in the provided documents." + +DO NOT elaborate, explain, or add any other information.""" + else: + # Format context from chunks + context = self._format_context(chunks) + + # Create concise prompt for faster generation + user_prompt = f"""Context: +{context} + +Question: {query} + +Instructions: Answer using only the context above. Cite with [chunk_1], [chunk_2] etc. + +Answer:""" + + try: + # Generate streaming response + stream = self.client.chat( + model=model, + messages=[ + {"role": "system", "content": self._create_system_prompt()}, + {"role": "user", "content": user_prompt} + ], + options={ + "temperature": self.temperature, + "num_predict": min(self.max_tokens, 300), # Reduce max tokens for speed + "top_k": 40, # Optimize sampling for speed + "top_p": 0.9, + "repeat_penalty": 1.1 + }, + stream=True + ) + + # Collect full answer while streaming + full_answer = "" + for chunk in stream: + if 'message' in chunk and 'content' in chunk['message']: + partial = chunk['message']['content'] + full_answer += partial + yield partial + + # Process complete answer + clean_answer, citations = self._extract_citations(full_answer, chunks) + confidence = self._calculate_confidence(clean_answer, citations, chunks) + generation_time = (datetime.now() - start_time).total_seconds() + + return GeneratedAnswer( + answer=clean_answer, + citations=citations, + confidence_score=confidence, + generation_time=generation_time, + model_used=model, + context_used=chunks + ) + + except Exception as e: + logger.error(f"Error in streaming generation: {e}") + yield "I apologize, but I encountered an error while generating the answer." + + return GeneratedAnswer( + answer="Error occurred during generation.", + citations=[], + confidence_score=0.0, + generation_time=0.0, + model_used=model, + context_used=chunks + ) + + def format_answer_with_citations(self, generated_answer: GeneratedAnswer) -> str: + """ + Format the generated answer with citations for display. + + Args: + generated_answer: GeneratedAnswer object + + Returns: + Formatted string with answer and citations + """ + formatted = f"{generated_answer.answer}\n\n" + + if generated_answer.citations: + formatted += "**Sources:**\n" + for i, citation in enumerate(generated_answer.citations, 1): + formatted += f"{i}. 
{citation.source_file} (Page {citation.page_number})\n" + + formatted += f"\n*Confidence: {generated_answer.confidence_score:.1%} | " + formatted += f"Model: {generated_answer.model_used} | " + formatted += f"Time: {generated_answer.generation_time:.2f}s*" + + return formatted + + +if __name__ == "__main__": + # Example usage + generator = AnswerGenerator() + + # Example chunks (would come from retrieval system) + example_chunks = [ + { + "id": "chunk_1", + "content": "RISC-V is an open-source instruction set architecture (ISA) based on reduced instruction set computer (RISC) principles.", + "metadata": {"page_number": 1, "source": "riscv-spec.pdf"}, + "score": 0.95 + }, + { + "id": "chunk_2", + "content": "The RISC-V ISA is designed to support a wide range of implementations including 32-bit, 64-bit, and 128-bit variants.", + "metadata": {"page_number": 2, "source": "riscv-spec.pdf"}, + "score": 0.89 + } + ] + + # Generate answer + result = generator.generate( + query="What is RISC-V?", + chunks=example_chunks + ) + + # Display formatted result + print(generator.format_answer_with_citations(result)) \ No newline at end of file diff --git a/shared_utils/vector_stores/generation/prompt_templates.py b/shared_utils/vector_stores/generation/prompt_templates.py new file mode 100644 index 0000000000000000000000000000000000000000..b74bfbf0e7d66e9f37a96a31cad20fef4f1f3f24 --- /dev/null +++ b/shared_utils/vector_stores/generation/prompt_templates.py @@ -0,0 +1,450 @@ +""" +Prompt templates optimized for technical documentation Q&A. + +This module provides specialized prompt templates for different types of +technical queries, with a focus on embedded systems and AI documentation. +""" + +from enum import Enum +from typing import Dict, List, Optional +from dataclasses import dataclass + + +class QueryType(Enum): + """Types of technical queries.""" + DEFINITION = "definition" + IMPLEMENTATION = "implementation" + COMPARISON = "comparison" + TROUBLESHOOTING = "troubleshooting" + SPECIFICATION = "specification" + CODE_EXAMPLE = "code_example" + HARDWARE_CONSTRAINT = "hardware_constraint" + GENERAL = "general" + + +@dataclass +class PromptTemplate: + """Represents a prompt template with its components.""" + system_prompt: str + context_format: str + query_format: str + answer_guidelines: str + + +class TechnicalPromptTemplates: + """ + Collection of prompt templates optimized for technical documentation. + + Features: + - Domain-specific templates for embedded systems and AI + - Structured output formats + - Citation requirements + - Technical accuracy emphasis + """ + + @staticmethod + def get_base_system_prompt() -> str: + """Get the base system prompt for technical documentation.""" + return """You are an expert technical documentation assistant specializing in embedded systems, +RISC-V architecture, RTOS, and embedded AI/ML. Your role is to provide accurate, detailed +technical answers based strictly on the provided documentation context. + +Key responsibilities: +1. Answer questions using ONLY information from the provided context +2. Include precise citations using [chunk_X] notation for every claim +3. Maintain technical accuracy and use correct terminology +4. Format code snippets and technical specifications properly +5. Clearly state when information is not available in the context +6. Consider hardware constraints and embedded system limitations when relevant + +Never make up information. 
If the context doesn't contain the answer, say so explicitly.""" + + @staticmethod + def get_definition_template() -> PromptTemplate: + """Template for definition/explanation queries.""" + return PromptTemplate( + system_prompt=TechnicalPromptTemplates.get_base_system_prompt() + """ + +For definition queries, focus on: +- Clear, concise technical definitions +- Related concepts and terminology +- Technical context and applications +- Any acronym expansions""", + + context_format="""Technical Documentation Context: +{context}""", + + query_format="""Define or explain: {query} + +Provide a comprehensive technical definition with proper citations.""", + + answer_guidelines="""Structure your answer as: +1. Primary definition [chunk_X] +2. Technical details and specifications [chunk_Y] +3. Related concepts or applications [chunk_Z] +4. Any relevant acronyms or abbreviations""" + ) + + @staticmethod + def get_implementation_template() -> PromptTemplate: + """Template for implementation/how-to queries.""" + return PromptTemplate( + system_prompt=TechnicalPromptTemplates.get_base_system_prompt() + """ + +For implementation queries, focus on: +- Step-by-step instructions +- Required components or dependencies +- Code examples with proper formatting +- Hardware/software requirements +- Common pitfalls or considerations""", + + context_format="""Implementation Documentation: +{context}""", + + query_format="""Implementation question: {query} + +Provide detailed implementation guidance with code examples where available.""", + + answer_guidelines="""Structure your answer as: +1. Overview of the implementation approach [chunk_X] +2. Prerequisites and requirements [chunk_Y] +3. Step-by-step implementation: + - Step 1: Description [chunk_Z] + - Step 2: Description [chunk_W] +4. Code example (if available): +```language +// Code here +``` +5. Important considerations or warnings""" + ) + + @staticmethod + def get_comparison_template() -> PromptTemplate: + """Template for comparison queries.""" + return PromptTemplate( + system_prompt=TechnicalPromptTemplates.get_base_system_prompt() + """ + +For comparison queries, focus on: +- Clear distinction between compared items +- Technical specifications and differences +- Use cases for each option +- Performance or resource implications +- Recommendations based on context""", + + context_format="""Technical Comparison Context: +{context}""", + + query_format="""Compare: {query} + +Provide a detailed technical comparison with clear distinctions.""", + + answer_guidelines="""Structure your answer as: +1. Overview of items being compared [chunk_X] +2. Key differences: + - Feature A: Item1 vs Item2 [chunk_Y] + - Feature B: Item1 vs Item2 [chunk_Z] +3. Technical specifications comparison +4. Use case recommendations +5. Performance/resource considerations""" + ) + + @staticmethod + def get_specification_template() -> PromptTemplate: + """Template for specification/parameter queries.""" + return PromptTemplate( + system_prompt=TechnicalPromptTemplates.get_base_system_prompt() + """ + +For specification queries, focus on: +- Exact technical specifications +- Parameter ranges and limits +- Units and measurements +- Compliance with standards +- Version-specific information""", + + context_format="""Technical Specifications: +{context}""", + + query_format="""Specification query: {query} + +Provide precise technical specifications with all relevant parameters.""", + + answer_guidelines="""Structure your answer as: +1. Specification overview [chunk_X] +2. 
Detailed parameters: + - Parameter 1: value (unit) [chunk_Y] + - Parameter 2: value (unit) [chunk_Z] +3. Operating conditions or constraints +4. Compliance/standards information +5. Version or variant notes""" + ) + + @staticmethod + def get_code_example_template() -> PromptTemplate: + """Template for code example queries.""" + return PromptTemplate( + system_prompt=TechnicalPromptTemplates.get_base_system_prompt() + """ + +For code example queries, focus on: +- Complete, runnable code examples +- Proper syntax highlighting +- Clear comments and documentation +- Error handling +- Best practices for embedded systems""", + + context_format="""Code Examples and Documentation: +{context}""", + + query_format="""Code example request: {query} + +Provide working code examples with explanations.""", + + answer_guidelines="""Structure your answer as: +1. Purpose and overview [chunk_X] +2. Required includes/imports [chunk_Y] +3. Complete code example: +```c +// Or appropriate language +#include + +// Function or code implementation +``` +4. Key points explained [chunk_Z] +5. Common variations or modifications""" + ) + + @staticmethod + def get_hardware_constraint_template() -> PromptTemplate: + """Template for hardware constraint queries.""" + return PromptTemplate( + system_prompt=TechnicalPromptTemplates.get_base_system_prompt() + """ + +For hardware constraint queries, focus on: +- Memory requirements (RAM, Flash) +- Processing power needs (MIPS, frequency) +- Power consumption +- I/O requirements +- Real-time constraints +- Temperature/environmental limits""", + + context_format="""Hardware Specifications and Constraints: +{context}""", + + query_format="""Hardware constraint question: {query} + +Analyze feasibility and constraints for embedded deployment.""", + + answer_guidelines="""Structure your answer as: +1. Hardware requirements summary [chunk_X] +2. Detailed constraints: + - Memory: RAM/Flash requirements [chunk_Y] + - Processing: CPU/frequency needs [chunk_Z] + - Power: Consumption estimates [chunk_W] +3. Feasibility assessment +4. Optimization suggestions +5. Alternative approaches if constraints are exceeded""" + ) + + @staticmethod + def get_troubleshooting_template() -> PromptTemplate: + """Template for troubleshooting queries.""" + return PromptTemplate( + system_prompt=TechnicalPromptTemplates.get_base_system_prompt() + """ + +For troubleshooting queries, focus on: +- Common error causes +- Diagnostic steps +- Solution procedures +- Preventive measures +- Debug techniques for embedded systems""", + + context_format="""Troubleshooting Documentation: +{context}""", + + query_format="""Troubleshooting issue: {query} + +Provide diagnostic steps and solutions.""", + + answer_guidelines="""Structure your answer as: +1. Problem identification [chunk_X] +2. Common causes: + - Cause 1: Description [chunk_Y] + - Cause 2: Description [chunk_Z] +3. Diagnostic steps: + - Step 1: Check... [chunk_W] + - Step 2: Verify... [chunk_V] +4. Solutions for each cause +5. Prevention recommendations""" + ) + + @staticmethod + def get_general_template() -> PromptTemplate: + """Default template for general queries.""" + return PromptTemplate( + system_prompt=TechnicalPromptTemplates.get_base_system_prompt(), + + context_format="""Technical Documentation: +{context}""", + + query_format="""Question: {query} + +Provide a comprehensive technical answer based on the documentation.""", + + answer_guidelines="""Provide a clear, well-structured answer that: +1. Directly addresses the question +2. 
Includes all relevant technical details +3. Cites sources using [chunk_X] notation +4. Maintains technical accuracy +5. Acknowledges any limitations in available information""" + ) + + @staticmethod + def detect_query_type(query: str) -> QueryType: + """ + Detect the type of query based on keywords and patterns. + + Args: + query: User's question + + Returns: + Detected QueryType + """ + query_lower = query.lower() + + # Definition keywords + if any(keyword in query_lower for keyword in [ + "what is", "what are", "define", "definition", "meaning of", "explain what" + ]): + return QueryType.DEFINITION + + # Implementation keywords + if any(keyword in query_lower for keyword in [ + "how to", "how do i", "implement", "setup", "configure", "install" + ]): + return QueryType.IMPLEMENTATION + + # Comparison keywords + if any(keyword in query_lower for keyword in [ + "difference between", "compare", "vs", "versus", "better than", "which is" + ]): + return QueryType.COMPARISON + + # Specification keywords + if any(keyword in query_lower for keyword in [ + "specification", "specs", "parameters", "limits", "range", "maximum", "minimum" + ]): + return QueryType.SPECIFICATION + + # Code example keywords + if any(keyword in query_lower for keyword in [ + "example", "code", "snippet", "sample", "demo", "show me" + ]): + return QueryType.CODE_EXAMPLE + + # Hardware constraint keywords + if any(keyword in query_lower for keyword in [ + "memory", "ram", "flash", "mcu", "constraint", "fit on", "run on", "power consumption" + ]): + return QueryType.HARDWARE_CONSTRAINT + + # Troubleshooting keywords + if any(keyword in query_lower for keyword in [ + "error", "problem", "issue", "debug", "troubleshoot", "fix", "solve", "not working" + ]): + return QueryType.TROUBLESHOOTING + + return QueryType.GENERAL + + @staticmethod + def get_template_for_query(query: str) -> PromptTemplate: + """ + Get the appropriate template based on query type. + + Args: + query: User's question + + Returns: + Appropriate PromptTemplate + """ + query_type = TechnicalPromptTemplates.detect_query_type(query) + + template_map = { + QueryType.DEFINITION: TechnicalPromptTemplates.get_definition_template, + QueryType.IMPLEMENTATION: TechnicalPromptTemplates.get_implementation_template, + QueryType.COMPARISON: TechnicalPromptTemplates.get_comparison_template, + QueryType.SPECIFICATION: TechnicalPromptTemplates.get_specification_template, + QueryType.CODE_EXAMPLE: TechnicalPromptTemplates.get_code_example_template, + QueryType.HARDWARE_CONSTRAINT: TechnicalPromptTemplates.get_hardware_constraint_template, + QueryType.TROUBLESHOOTING: TechnicalPromptTemplates.get_troubleshooting_template, + QueryType.GENERAL: TechnicalPromptTemplates.get_general_template + } + + return template_map[query_type]() + + @staticmethod + def format_prompt_with_template( + query: str, + context: str, + template: Optional[PromptTemplate] = None + ) -> Dict[str, str]: + """ + Format a complete prompt using the appropriate template. 
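+
+        Illustrative sketch (the chat client shown is an assumption, e.g. the Ollama
+        client that AnswerGenerator uses; only the returned dict keys are guaranteed):
+
+            prompts = TechnicalPromptTemplates.format_prompt_with_template(query, context)
+            messages = [
+                {"role": "system", "content": prompts["system"]},
+                {"role": "user", "content": prompts["user"]},
+            ]
+            # `messages` can then be passed to a chat-style API,
+            # e.g. client.chat(model=..., messages=messages)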
+ + Args: + query: User's question + context: Retrieved context chunks + template: Optional specific template (auto-detected if None) + + Returns: + Dict with 'system' and 'user' prompts + """ + if template is None: + template = TechnicalPromptTemplates.get_template_for_query(query) + + # Format the context + formatted_context = template.context_format.format(context=context) + + # Format the query + formatted_query = template.query_format.format(query=query) + + # Combine into user prompt + user_prompt = f"""{formatted_context} + +{formatted_query} + +{template.answer_guidelines}""" + + return { + "system": template.system_prompt, + "user": user_prompt + } + + +# Example usage and testing +if __name__ == "__main__": + # Test query type detection + test_queries = [ + "What is RISC-V?", + "How do I implement a timer interrupt?", + "What's the difference between FreeRTOS and Zephyr?", + "What are the memory specifications for STM32F4?", + "Show me an example of GPIO configuration", + "Can this model run on an MCU with 256KB RAM?", + "Debug error: undefined reference to main" + ] + + for query in test_queries: + query_type = TechnicalPromptTemplates.detect_query_type(query) + print(f"Query: '{query}' -> Type: {query_type.value}") + + # Example prompt formatting + example_context = "RISC-V is an open instruction set architecture..." + example_query = "What is RISC-V?" + + formatted = TechnicalPromptTemplates.format_prompt_with_template( + query=example_query, + context=example_context + ) + + print("\nFormatted prompt example:") + print("System:", formatted["system"][:100], "...") + print("User:", formatted["user"][:200], "...") \ No newline at end of file diff --git a/shared_utils/vector_stores/query_processing/__init__.py b/shared_utils/vector_stores/query_processing/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..482ba426819b9ee75762815a83bcd51232a7ae24 --- /dev/null +++ b/shared_utils/vector_stores/query_processing/__init__.py @@ -0,0 +1,8 @@ +""" +Query processing utilities for intelligent RAG systems. +Provides query enhancement, analysis, and optimization capabilities. +""" + +from .query_enhancer import QueryEnhancer + +__all__ = ['QueryEnhancer'] \ No newline at end of file diff --git a/shared_utils/vector_stores/query_processing/__pycache__/__init__.cpython-312.pyc b/shared_utils/vector_stores/query_processing/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e810d05d4617c4e2d512949b9241f74321b05ed0 Binary files /dev/null and b/shared_utils/vector_stores/query_processing/__pycache__/__init__.cpython-312.pyc differ diff --git a/shared_utils/vector_stores/query_processing/__pycache__/query_enhancer.cpython-312.pyc b/shared_utils/vector_stores/query_processing/__pycache__/query_enhancer.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9525380f13366ea46714a2aee7ebeedb125330b2 Binary files /dev/null and b/shared_utils/vector_stores/query_processing/__pycache__/query_enhancer.cpython-312.pyc differ diff --git a/shared_utils/vector_stores/query_processing/query_enhancer.py b/shared_utils/vector_stores/query_processing/query_enhancer.py new file mode 100644 index 0000000000000000000000000000000000000000..233591b67ce213046586dea00ecd7135cda9a842 --- /dev/null +++ b/shared_utils/vector_stores/query_processing/query_enhancer.py @@ -0,0 +1,644 @@ +""" +Intelligent query processing for technical documentation RAG. 
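+
+Typical flow (illustrative sketch; the retriever referenced is the HybridRetriever
+from shared_utils.vector_stores.retrieval, which takes the recommended weight at
+construction time):
+
+    enhancer = QueryEnhancer()
+    result = enhancer.enhance_query("RTOS interrupt latency on RISC-V")
+    # result['enhanced_query'] -> conservatively expanded query text
+    # result['optimal_weight'] -> dense weight for HybridRetriever(dense_weight=...)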
+ +Provides adaptive query enhancement through technical term expansion, +acronym handling, and intelligent hybrid weighting optimization. +""" + +from typing import Dict, List, Any, Tuple, Set, Optional +import re +from collections import defaultdict +import time + + +class QueryEnhancer: + """ + Intelligent query processing for technical documentation RAG. + + Analyzes query characteristics and enhances retrieval through: + - Technical synonym expansion + - Acronym detection and expansion + - Adaptive hybrid weighting based on query type + - Query complexity analysis for optimal retrieval strategy + + Optimized for embedded systems and technical documentation domains. + + Performance: <10ms query enhancement, improves retrieval relevance by >10% + """ + + def __init__(self): + """Initialize QueryEnhancer with technical domain knowledge.""" + + # Technical vocabulary dictionary organized by domain + self.technical_synonyms = { + # Processor terminology + 'cpu': ['processor', 'microprocessor', 'central processing unit'], + 'mcu': ['microcontroller', 'microcontroller unit', 'embedded processor'], + 'core': ['processor core', 'cpu core', 'execution unit'], + 'alu': ['arithmetic logic unit', 'arithmetic unit'], + + # Memory terminology + 'memory': ['ram', 'storage', 'buffer', 'cache'], + 'flash': ['non-volatile memory', 'program memory', 'code storage'], + 'sram': ['static ram', 'static memory', 'cache memory'], + 'dram': ['dynamic ram', 'dynamic memory'], + 'cache': ['buffer', 'temporary storage', 'fast memory'], + + # Architecture terminology + 'risc-v': ['riscv', 'risc v', 'open isa', 'open instruction set'], + 'arm': ['advanced risc machine', 'acorn risc machine'], + 'isa': ['instruction set architecture', 'instruction set'], + 'architecture': ['design', 'structure', 'organization'], + + # Embedded systems terminology + 'rtos': ['real-time operating system', 'real-time os'], + 'interrupt': ['isr', 'interrupt service routine', 'exception handler'], + 'peripheral': ['hardware peripheral', 'external device', 'io device'], + 'firmware': ['embedded software', 'system software'], + 'bootloader': ['boot code', 'initialization code'], + + # Performance terminology + 'latency': ['delay', 'response time', 'execution time'], + 'throughput': ['bandwidth', 'data rate', 'performance'], + 'power': ['power consumption', 'energy usage', 'battery life'], + 'optimization': ['improvement', 'enhancement', 'tuning'], + + # Communication protocols + 'uart': ['serial communication', 'async serial'], + 'spi': ['serial peripheral interface', 'synchronous serial'], + 'i2c': ['inter-integrated circuit', 'two-wire interface'], + 'usb': ['universal serial bus'], + + # Development terminology + 'debug': ['debugging', 'troubleshooting', 'testing'], + 'compile': ['compilation', 'build', 'assembly'], + 'programming': ['coding', 'development', 'implementation'] + } + + # Comprehensive acronym expansions for embedded/technical domains + self.acronym_expansions = { + # Processor & Architecture + 'CPU': 'Central Processing Unit', + 'MCU': 'Microcontroller Unit', + 'MPU': 'Microprocessor Unit', + 'DSP': 'Digital Signal Processor', + 'GPU': 'Graphics Processing Unit', + 'ALU': 'Arithmetic Logic Unit', + 'FPU': 'Floating Point Unit', + 'MMU': 'Memory Management Unit', + 'ISA': 'Instruction Set Architecture', + 'RISC': 'Reduced Instruction Set Computer', + 'CISC': 'Complex Instruction Set Computer', + + # Memory & Storage + 'RAM': 'Random Access Memory', + 'ROM': 'Read Only Memory', + 'EEPROM': 'Electrically Erasable Programmable 
ROM', + 'SRAM': 'Static Random Access Memory', + 'DRAM': 'Dynamic Random Access Memory', + 'FRAM': 'Ferroelectric Random Access Memory', + 'MRAM': 'Magnetoresistive Random Access Memory', + 'DMA': 'Direct Memory Access', + + # Operating Systems & Software + 'RTOS': 'Real-Time Operating System', + 'OS': 'Operating System', + 'API': 'Application Programming Interface', + 'SDK': 'Software Development Kit', + 'IDE': 'Integrated Development Environment', + 'HAL': 'Hardware Abstraction Layer', + 'BSP': 'Board Support Package', + + # Interrupts & Exceptions + 'ISR': 'Interrupt Service Routine', + 'IRQ': 'Interrupt Request', + 'NMI': 'Non-Maskable Interrupt', + 'NVIC': 'Nested Vectored Interrupt Controller', + + # Communication Protocols + 'UART': 'Universal Asynchronous Receiver Transmitter', + 'USART': 'Universal Synchronous Asynchronous Receiver Transmitter', + 'SPI': 'Serial Peripheral Interface', + 'I2C': 'Inter-Integrated Circuit', + 'CAN': 'Controller Area Network', + 'USB': 'Universal Serial Bus', + 'TCP': 'Transmission Control Protocol', + 'UDP': 'User Datagram Protocol', + 'HTTP': 'HyperText Transfer Protocol', + 'MQTT': 'Message Queuing Telemetry Transport', + + # Analog & Digital + 'ADC': 'Analog to Digital Converter', + 'DAC': 'Digital to Analog Converter', + 'PWM': 'Pulse Width Modulation', + 'GPIO': 'General Purpose Input Output', + 'JTAG': 'Joint Test Action Group', + 'SWD': 'Serial Wire Debug', + + # Power & Clock + 'PLL': 'Phase Locked Loop', + 'VCO': 'Voltage Controlled Oscillator', + 'LDO': 'Low Dropout Regulator', + 'PMU': 'Power Management Unit', + 'RTC': 'Real Time Clock', + + # Standards & Organizations + 'IEEE': 'Institute of Electrical and Electronics Engineers', + 'ISO': 'International Organization for Standardization', + 'ANSI': 'American National Standards Institute', + 'IEC': 'International Electrotechnical Commission' + } + + # Compile regex patterns for efficiency + self._acronym_pattern = re.compile(r'\b[A-Z]{2,}\b') + self._technical_term_pattern = re.compile(r'\b\w+(?:-\w+)*\b', re.IGNORECASE) + self._question_indicators = re.compile(r'\b(?:how|what|why|when|where|which|explain|describe|define)\b', re.IGNORECASE) + + # Question type classification keywords + self.question_type_keywords = { + 'conceptual': ['how', 'why', 'what', 'explain', 'describe', 'understand', 'concept', 'theory'], + 'technical': ['configure', 'implement', 'setup', 'install', 'code', 'program', 'register'], + 'procedural': ['steps', 'process', 'procedure', 'workflow', 'guide', 'tutorial'], + 'troubleshooting': ['error', 'problem', 'issue', 'debug', 'fix', 'solve', 'troubleshoot'] + } + + def analyze_query_characteristics(self, query: str) -> Dict[str, Any]: + """ + Analyze query to determine optimal processing strategy. 
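+
+        Illustrative example (a sketch only; exact values depend on the query):
+
+            info = QueryEnhancer().analyze_query_characteristics(
+                "Configure the NVIC for UART interrupts"
+            )
+            # info['has_acronyms'] is True, and the acronym/technical-term density
+            # pulls info['recommended_dense_weight'] toward sparse keyword matching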
+ + Performs comprehensive analysis including: + - Technical term detection and counting + - Acronym presence identification + - Question type classification + - Complexity scoring based on multiple factors + - Optimal hybrid weight recommendation + + Args: + query: User input query string + + Returns: + Dictionary with comprehensive query analysis: + - technical_term_count: Number of domain-specific terms detected + - has_acronyms: Boolean indicating acronym presence + - question_type: 'conceptual', 'technical', 'procedural', 'mixed' + - complexity_score: Float 0-1 indicating query complexity + - recommended_dense_weight: Optimal weight for hybrid search + - detected_acronyms: List of acronyms found + - technical_terms: List of technical terms found + + Performance: <2ms for typical queries + """ + if not query or not query.strip(): + return { + 'technical_term_count': 0, + 'has_acronyms': False, + 'question_type': 'unknown', + 'complexity_score': 0.0, + 'recommended_dense_weight': 0.7, + 'detected_acronyms': [], + 'technical_terms': [] + } + + query_lower = query.lower() + words = query.split() + + # Detect acronyms + detected_acronyms = self._acronym_pattern.findall(query) + has_acronyms = len(detected_acronyms) > 0 + + # Detect technical terms + technical_terms = [] + technical_term_count = 0 + + for word in words: + word_clean = re.sub(r'[^\w\-]', '', word.lower()) + if word_clean in self.technical_synonyms: + technical_terms.append(word_clean) + technical_term_count += 1 + # Also check for compound technical terms like "risc-v" + elif any(term in word_clean for term in ['risc-v', 'arm', 'cpu', 'mcu']): + technical_terms.append(word_clean) + technical_term_count += 1 + + # Add acronyms to technical term count + for acronym in detected_acronyms: + if acronym in self.acronym_expansions: + technical_term_count += 1 + + # Determine question type + question_type = self._classify_question_type(query_lower) + + # Calculate complexity score (0-1) + complexity_factors = [ + len(words) / 20.0, # Word count factor (normalized to 20 words max) + technical_term_count / 5.0, # Technical density (normalized to 5 terms max) + len(detected_acronyms) / 3.0, # Acronym density (normalized to 3 acronyms max) + 1.0 if self._question_indicators.search(query) else 0.5, # Question complexity + ] + complexity_score = min(1.0, sum(complexity_factors) / len(complexity_factors)) + + # Determine recommended dense weight based on analysis + recommended_dense_weight = self._calculate_optimal_weight( + question_type, technical_term_count, has_acronyms, complexity_score + ) + + return { + 'technical_term_count': technical_term_count, + 'has_acronyms': has_acronyms, + 'question_type': question_type, + 'complexity_score': complexity_score, + 'recommended_dense_weight': recommended_dense_weight, + 'detected_acronyms': detected_acronyms, + 'technical_terms': technical_terms, + 'word_count': len(words), + 'has_question_indicators': bool(self._question_indicators.search(query)) + } + + def _classify_question_type(self, query_lower: str) -> str: + """Classify query into conceptual, technical, procedural, or mixed categories.""" + type_scores = defaultdict(int) + + for question_type, keywords in self.question_type_keywords.items(): + for keyword in keywords: + if keyword in query_lower: + type_scores[question_type] += 1 + + if not type_scores: + return 'mixed' + + # Return type with highest score, or 'mixed' if tie + max_score = max(type_scores.values()) + top_types = [t for t, s in type_scores.items() if s == max_score] + + 
return top_types[0] if len(top_types) == 1 else 'mixed' + + def _calculate_optimal_weight(self, question_type: str, tech_terms: int, + has_acronyms: bool, complexity: float) -> float: + """Calculate optimal dense weight based on query characteristics.""" + + # Base weights by question type + base_weights = { + 'technical': 0.3, # Favor sparse for technical precision + 'conceptual': 0.8, # Favor dense for conceptual understanding + 'procedural': 0.5, # Balanced for step-by-step queries + 'troubleshooting': 0.4, # Slight sparse favor for specific issues + 'mixed': 0.7, # Default balanced + 'unknown': 0.7 # Default balanced + } + + weight = base_weights.get(question_type, 0.7) + + # Adjust based on technical term density + if tech_terms > 2: + weight -= 0.2 # More technical → favor sparse + elif tech_terms == 0: + weight += 0.1 # Less technical → favor dense + + # Adjust based on acronym presence + if has_acronyms: + weight -= 0.1 # Acronyms → favor sparse for exact matching + + # Adjust based on complexity + if complexity > 0.8: + weight += 0.1 # High complexity → favor dense for understanding + elif complexity < 0.3: + weight -= 0.1 # Low complexity → favor sparse for precision + + # Ensure weight stays within valid bounds + return max(0.1, min(0.9, weight)) + + def expand_technical_terms(self, query: str, max_expansions: int = 1) -> str: + """ + Expand query with technical synonyms while preventing bloat. + + Conservative expansion strategy: + - Maximum 1 synonym per technical term by default + - Prioritizes most relevant/common synonyms + - Maintains semantic focus while improving recall + + Args: + query: Original user query + max_expansions: Maximum synonyms per term (default 1 for focus) + + Returns: + Conservatively enhanced query + + Example: + Input: "CPU performance optimization" + Output: "CPU processor performance optimization" + + Performance: <3ms for typical queries + """ + if not query or not query.strip(): + return query + + words = query.split() + + # Conservative expansion: only add most relevant synonym + expansion_candidates = [] + + for word in words: + word_clean = re.sub(r'[^\w\-]', '', word.lower()) + + # Check for direct synonym expansion + if word_clean in self.technical_synonyms: + synonyms = self.technical_synonyms[word_clean] + # Add only the first (most common) synonym + if synonyms and max_expansions > 0: + expansion_candidates.append(synonyms[0]) + + # Limit total expansion to prevent bloat + max_total_expansions = min(2, len(words) // 2) # At most 50% expansion + selected_expansions = expansion_candidates[:max_total_expansions] + + # Reconstruct with minimal expansion + if selected_expansions: + return ' '.join(words + selected_expansions) + else: + return query + + def detect_and_expand_acronyms(self, query: str, conservative: bool = True) -> str: + """ + Detect technical acronyms and add their expansions conservatively. 
+ + Conservative approach to prevent query bloat: + - Limits acronym expansions to most relevant ones + - Preserves original acronyms for exact matching + - Maintains query focus and performance + + Args: + query: Query potentially containing acronyms + conservative: If True, limits expansion to prevent bloat + + Returns: + Query with selective acronym expansions + + Example: + Input: "RTOS scheduling algorithm" + Output: "RTOS Real-Time Operating System scheduling algorithm" + + Performance: <2ms for typical queries + """ + if not query or not query.strip(): + return query + + # Find all acronyms in the query + acronyms = self._acronym_pattern.findall(query) + + if not acronyms: + return query + + # Conservative mode: limit expansions + if conservative and len(acronyms) > 2: + # Only expand first 2 acronyms to prevent bloat + acronyms = acronyms[:2] + + result = query + + # Expand selected acronyms + for acronym in acronyms: + if acronym in self.acronym_expansions: + expansion = self.acronym_expansions[acronym] + # Add expansion after the acronym (preserving original) + result = result.replace(acronym, f"{acronym} {expansion}", 1) + + return result + + def adaptive_hybrid_weighting(self, query: str) -> float: + """ + Determine optimal dense_weight based on query characteristics. + + Analyzes query to automatically determine the best balance between + dense semantic search and sparse keyword matching for optimal results. + + Strategy: + - Technical/exact queries → lower dense_weight (favor sparse/BM25) + - Conceptual questions → higher dense_weight (favor semantic) + - Mixed queries → balanced weighting based on complexity + + Args: + query: User query string + + Returns: + Float between 0.1 and 0.9 representing optimal dense_weight + + Performance: <2ms analysis time + """ + analysis = self.analyze_query_characteristics(query) + return analysis['recommended_dense_weight'] + + def enhance_query(self, query: str, conservative: bool = True) -> Dict[str, Any]: + """ + Comprehensive query enhancement with performance and quality focus. 
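+
+        Illustrative sketch (shows the returned shape; exact expansions depend on
+        the query and on the conservative limits described below):
+
+            result = QueryEnhancer().enhance_query("RTOS scheduling on a RISC-V MCU")
+            # result keys: 'enhanced_query', 'optimal_weight', 'analysis',
+            #              'enhancement_metadata'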
+ + Optimized enhancement strategy: + - Conservative expansion to maintain semantic focus + - Performance-first approach with minimal overhead + - Quality validation to ensure improvements + + Args: + query: Original user query + conservative: Use conservative expansion (recommended for production) + + Returns: + Dictionary containing: + - enhanced_query: Optimized enhanced query + - optimal_weight: Recommended dense weight + - analysis: Complete query analysis + - enhancement_metadata: Performance and quality metrics + + Performance: <5ms total enhancement time + """ + start_time = time.perf_counter() + + # Fast analysis + analysis = self.analyze_query_characteristics(query) + + # Conservative enhancement approach + if conservative: + enhanced_query = self.expand_technical_terms(query, max_expansions=1) + enhanced_query = self.detect_and_expand_acronyms(enhanced_query, conservative=True) + else: + # Legacy aggressive expansion + enhanced_query = self.expand_technical_terms(query, max_expansions=2) + enhanced_query = self.detect_and_expand_acronyms(enhanced_query, conservative=False) + + # Quality validation: prevent excessive bloat + expansion_ratio = len(enhanced_query.split()) / len(query.split()) if query.split() else 1.0 + if expansion_ratio > 2.5: # Limit to 2.5x expansion + # Fallback to minimal enhancement + enhanced_query = self.expand_technical_terms(query, max_expansions=0) + enhanced_query = self.detect_and_expand_acronyms(enhanced_query, conservative=True) + expansion_ratio = len(enhanced_query.split()) / len(query.split()) if query.split() else 1.0 + + # Calculate optimal weight + optimal_weight = analysis['recommended_dense_weight'] + + enhancement_time = time.perf_counter() - start_time + + return { + 'enhanced_query': enhanced_query, + 'optimal_weight': optimal_weight, + 'analysis': analysis, + 'enhancement_metadata': { + 'original_length': len(query.split()), + 'enhanced_length': len(enhanced_query.split()), + 'expansion_ratio': expansion_ratio, + 'processing_time_ms': enhancement_time * 1000, + 'techniques_applied': ['conservative_expansion', 'quality_validation', 'adaptive_weighting'], + 'conservative_mode': conservative + } + } + + def expand_technical_terms_with_vocabulary( + self, + query: str, + vocabulary_index: Optional['VocabularyIndex'] = None, + min_frequency: int = 3 + ) -> str: + """ + Expand query with vocabulary-aware synonym filtering. + + Only adds synonyms that exist in the document corpus with sufficient + frequency to ensure relevance and prevent query dilution. 
+ + Args: + query: Original query + vocabulary_index: Optional vocabulary index for filtering + min_frequency: Minimum term frequency required + + Returns: + Enhanced query with validated synonyms + + Performance: <2ms with vocabulary validation + """ + if not query or not query.strip(): + return query + + if vocabulary_index is None: + # Fallback to standard expansion + return self.expand_technical_terms(query, max_expansions=1) + + words = query.split() + expanded_terms = [] + + for word in words: + word_clean = re.sub(r'[^\w\-]', '', word.lower()) + + # Check for synonym expansion + if word_clean in self.technical_synonyms: + synonyms = self.technical_synonyms[word_clean] + + # Filter synonyms through vocabulary + valid_synonyms = vocabulary_index.filter_synonyms( + synonyms, + min_frequency=min_frequency + ) + + # Add only the best valid synonym + if valid_synonyms: + expanded_terms.append(valid_synonyms[0]) + + # Reconstruct query with validated expansions + if expanded_terms: + return ' '.join(words + expanded_terms) + else: + return query + + def enhance_query_with_vocabulary( + self, + query: str, + vocabulary_index: Optional['VocabularyIndex'] = None, + min_frequency: int = 3, + require_technical: bool = False + ) -> Dict[str, Any]: + """ + Enhanced query processing with vocabulary validation. + + Uses corpus vocabulary to ensure all expansions are relevant + and actually present in the documents. + + Args: + query: Original query + vocabulary_index: Vocabulary index for validation + min_frequency: Minimum term frequency + require_technical: Only expand with technical terms + + Returns: + Enhanced query with vocabulary-aware expansion + """ + start_time = time.perf_counter() + + # Perform analysis + analysis = self.analyze_query_characteristics(query) + + # Vocabulary-aware enhancement + if vocabulary_index: + # Technical term expansion with validation + enhanced_query = self.expand_technical_terms_with_vocabulary( + query, vocabulary_index, min_frequency + ) + + # Acronym expansion (already conservative) + enhanced_query = self.detect_and_expand_acronyms(enhanced_query, conservative=True) + + # Track vocabulary validation + validation_applied = True + + # Detect domain if available + detected_domain = vocabulary_index.detect_domain() + else: + # Fallback to standard enhancement + enhanced_query = self.expand_technical_terms(query, max_expansions=1) + enhanced_query = self.detect_and_expand_acronyms(enhanced_query, conservative=True) + validation_applied = False + detected_domain = 'unknown' + + # Calculate metrics + expansion_ratio = len(enhanced_query.split()) / len(query.split()) if query.split() else 1.0 + enhancement_time = time.perf_counter() - start_time + + return { + 'enhanced_query': enhanced_query, + 'optimal_weight': analysis['recommended_dense_weight'], + 'analysis': analysis, + 'enhancement_metadata': { + 'original_length': len(query.split()), + 'enhanced_length': len(enhanced_query.split()), + 'expansion_ratio': expansion_ratio, + 'processing_time_ms': enhancement_time * 1000, + 'techniques_applied': ['vocabulary_validation', 'conservative_expansion'], + 'vocabulary_validated': validation_applied, + 'detected_domain': detected_domain, + 'min_frequency_threshold': min_frequency + } + } + + def get_enhancement_stats(self) -> Dict[str, Any]: + """ + Get statistics about the enhancement system capabilities. 
+ + Returns: + Dictionary with system statistics and capabilities + """ + return { + 'technical_synonyms_count': len(self.technical_synonyms), + 'acronym_expansions_count': len(self.acronym_expansions), + 'supported_domains': [ + 'embedded_systems', 'processor_architecture', 'memory_systems', + 'communication_protocols', 'real_time_systems', 'power_management' + ], + 'question_types_supported': list(self.question_type_keywords.keys()), + 'weight_range': {'min': 0.1, 'max': 0.9, 'default': 0.7}, + 'performance_targets': { + 'enhancement_time_ms': '<10', + 'accuracy_improvement': '>10%', + 'memory_overhead': '<1MB' + }, + 'vocabulary_features': { + 'vocabulary_aware_expansion': True, + 'min_frequency_filtering': True, + 'domain_detection': True, + 'technical_term_priority': True + } + } \ No newline at end of file diff --git a/shared_utils/vector_stores/retrieval/__init__.py b/shared_utils/vector_stores/retrieval/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1b8c0a57b3117afdb42d8f790ae79018395c7e09 --- /dev/null +++ b/shared_utils/vector_stores/retrieval/__init__.py @@ -0,0 +1,8 @@ +""" +Retrieval utilities for hybrid RAG systems. +Combines dense semantic search with sparse keyword matching. +""" + +from .hybrid_search import HybridRetriever + +__all__ = ['HybridRetriever'] \ No newline at end of file diff --git a/shared_utils/vector_stores/retrieval/__pycache__/__init__.cpython-312.pyc b/shared_utils/vector_stores/retrieval/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ce20008140f9969e85a2b7e0cb550078c564dacc Binary files /dev/null and b/shared_utils/vector_stores/retrieval/__pycache__/__init__.cpython-312.pyc differ diff --git a/shared_utils/vector_stores/retrieval/__pycache__/hybrid_search.cpython-312.pyc b/shared_utils/vector_stores/retrieval/__pycache__/hybrid_search.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..94632c6fd554a283fc62a3539497705d41507ee2 Binary files /dev/null and b/shared_utils/vector_stores/retrieval/__pycache__/hybrid_search.cpython-312.pyc differ diff --git a/shared_utils/vector_stores/retrieval/__pycache__/vocabulary_index.cpython-312.pyc b/shared_utils/vector_stores/retrieval/__pycache__/vocabulary_index.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8e721138cb755c3b64e60ed7adf6fed0af1d42a1 Binary files /dev/null and b/shared_utils/vector_stores/retrieval/__pycache__/vocabulary_index.cpython-312.pyc differ diff --git a/shared_utils/vector_stores/retrieval/hybrid_search.py b/shared_utils/vector_stores/retrieval/hybrid_search.py new file mode 100644 index 0000000000000000000000000000000000000000..7a538e3b93475dcbdc3efd1225f66bc616c68b1c --- /dev/null +++ b/shared_utils/vector_stores/retrieval/hybrid_search.py @@ -0,0 +1,277 @@ +""" +Hybrid retrieval combining dense semantic search with sparse BM25 keyword matching. +Uses Reciprocal Rank Fusion (RRF) to combine results from both approaches. 
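+
+Illustrative end-to-end sketch (chunk dicts must carry a 'text' field; the embedding
+and BM25 dependencies imported below must be available in the expected project layout):
+
+    retriever = HybridRetriever(dense_weight=0.7)
+    retriever.index_documents(chunks)
+    for chunk_idx, fused_score, chunk in retriever.search("RISC-V interrupt handling", top_k=5):
+        print(fused_score, chunk.get('source', 'unknown'))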
+""" + +from typing import List, Dict, Tuple, Optional +import numpy as np +from pathlib import Path +import sys + +# Add project root to Python path for imports +project_root = Path(__file__).parent.parent.parent / "project-1-technical-rag" +sys.path.append(str(project_root)) + +from src.sparse_retrieval import BM25SparseRetriever +from src.fusion import reciprocal_rank_fusion, adaptive_fusion +from shared_utils.embeddings.generator import generate_embeddings +import faiss + + +class HybridRetriever: + """ + Hybrid retrieval system combining dense semantic search with sparse BM25. + + Optimized for technical documentation where both semantic similarity + and exact keyword matching are important for retrieval quality. + + Performance: Sub-second search on 1000+ document corpus + """ + + def __init__( + self, + dense_weight: float = 0.7, + embedding_model: str = "sentence-transformers/multi-qa-MiniLM-L6-cos-v1", + use_mps: bool = True, + bm25_k1: float = 1.2, + bm25_b: float = 0.75, + rrf_k: int = 10 + ): + """ + Initialize hybrid retriever with dense and sparse components. + + Args: + dense_weight: Weight for semantic similarity in fusion (0.7 default) + embedding_model: Sentence transformer model name + use_mps: Use Apple Silicon MPS acceleration for embeddings + bm25_k1: BM25 term frequency saturation parameter + bm25_b: BM25 document length normalization parameter + rrf_k: Reciprocal Rank Fusion constant (1=strong rank preference, 2=moderate) + + Raises: + ValueError: If parameters are invalid + """ + if not 0 <= dense_weight <= 1: + raise ValueError("dense_weight must be between 0 and 1") + + self.dense_weight = dense_weight + self.embedding_model = embedding_model + self.use_mps = use_mps + self.rrf_k = rrf_k + + # Initialize sparse retriever + self.sparse_retriever = BM25SparseRetriever(k1=bm25_k1, b=bm25_b) + + # Dense retrieval components (initialized on first index) + self.dense_index: Optional[faiss.Index] = None + self.chunks: List[Dict] = [] + self.embeddings: Optional[np.ndarray] = None + + def index_documents(self, chunks: List[Dict]) -> None: + """ + Index documents for both dense and sparse retrieval. + + Args: + chunks: List of chunk dictionaries with 'text' field + + Raises: + ValueError: If chunks is empty or malformed + + Performance: ~100 chunks/second for complete indexing + """ + if not chunks: + raise ValueError("Cannot index empty chunk list") + + print(f"Indexing {len(chunks)} chunks for hybrid retrieval...") + + # Store chunks for retrieval + self.chunks = chunks + + # Index for sparse retrieval + print("Building BM25 sparse index...") + self.sparse_retriever.index_documents(chunks) + + # Index for dense retrieval + print("Building dense semantic index...") + texts = [chunk['text'] for chunk in chunks] + + # Generate embeddings + self.embeddings = generate_embeddings( + texts, + model_name=self.embedding_model, + use_mps=self.use_mps + ) + + # Create FAISS index + embedding_dim = self.embeddings.shape[1] + self.dense_index = faiss.IndexFlatIP(embedding_dim) # Inner product for cosine similarity + + # Normalize embeddings for cosine similarity + faiss.normalize_L2(self.embeddings) + self.dense_index.add(self.embeddings) + + print(f"Hybrid indexing complete: {len(chunks)} chunks ready for search") + + def search( + self, + query: str, + top_k: int = 10, + dense_top_k: Optional[int] = None, + sparse_top_k: Optional[int] = None + ) -> List[Tuple[int, float, Dict]]: + """ + Hybrid search combining dense and sparse retrieval with RRF. 
+ + Args: + query: Search query string + top_k: Final number of results to return + dense_top_k: Results from dense search (default: 2*top_k) + sparse_top_k: Results from sparse search (default: 2*top_k) + + Returns: + List of (chunk_index, rrf_score, chunk_dict) tuples + + Raises: + ValueError: If not indexed or invalid parameters + + Performance: <200ms for 1000+ document corpus + """ + if self.dense_index is None: + raise ValueError("Must call index_documents() before searching") + + if not query.strip(): + return [] + + if top_k <= 0: + raise ValueError("top_k must be positive") + + # Set default intermediate result counts + if dense_top_k is None: + dense_top_k = min(2 * top_k, len(self.chunks)) + if sparse_top_k is None: + sparse_top_k = min(2 * top_k, len(self.chunks)) + + # Dense semantic search + dense_results = self._dense_search(query, dense_top_k) + + # Sparse BM25 search + sparse_results = self.sparse_retriever.search(query, sparse_top_k) + + # Combine using Adaptive Fusion (better for small result sets) + fused_results = adaptive_fusion( + dense_results=dense_results, + sparse_results=sparse_results, + dense_weight=self.dense_weight, + result_size=top_k + ) + + # Prepare final results with chunk content and apply source diversity + final_results = [] + for chunk_idx, rrf_score in fused_results: + chunk_dict = self.chunks[chunk_idx] + final_results.append((chunk_idx, rrf_score, chunk_dict)) + + # Apply source diversity enhancement + diverse_results = self._enhance_source_diversity(final_results, top_k) + + return diverse_results + + def _dense_search(self, query: str, top_k: int) -> List[Tuple[int, float]]: + """ + Perform dense semantic search using FAISS. + + Args: + query: Search query + top_k: Number of results to return + + Returns: + List of (chunk_index, similarity_score) tuples + """ + # Generate query embedding + query_embedding = generate_embeddings( + [query], + model_name=self.embedding_model, + use_mps=self.use_mps + ) + + # Normalize for cosine similarity + faiss.normalize_L2(query_embedding) + + # Search dense index + similarities, indices = self.dense_index.search(query_embedding, top_k) + + # Convert to required format + results = [ + (int(indices[0][i]), float(similarities[0][i])) + for i in range(len(indices[0])) + if indices[0][i] != -1 # Filter out invalid results + ] + + return results + + def _enhance_source_diversity( + self, + results: List[Tuple[int, float, Dict]], + top_k: int, + max_per_source: int = 2 + ) -> List[Tuple[int, float, Dict]]: + """ + Enhance source diversity in retrieval results to prevent over-focusing on single documents. 
+ + Args: + results: List of (chunk_idx, score, chunk_dict) tuples sorted by relevance + top_k: Maximum number of results to return + max_per_source: Maximum chunks allowed per source document + + Returns: + Diversified results maintaining relevance while improving source coverage + """ + if not results: + return [] + + source_counts = {} + diverse_results = [] + + # First pass: Add highest scoring results respecting source limits + for chunk_idx, score, chunk_dict in results: + source = chunk_dict.get('source', 'unknown') + current_count = source_counts.get(source, 0) + + if current_count < max_per_source: + diverse_results.append((chunk_idx, score, chunk_dict)) + source_counts[source] = current_count + 1 + + if len(diverse_results) >= top_k: + break + + # Second pass: If we still need more results, relax source constraints + if len(diverse_results) < top_k: + for chunk_idx, score, chunk_dict in results: + if (chunk_idx, score, chunk_dict) not in diverse_results: + diverse_results.append((chunk_idx, score, chunk_dict)) + + if len(diverse_results) >= top_k: + break + + return diverse_results[:top_k] + + def get_retrieval_stats(self) -> Dict[str, any]: + """ + Get statistics about the indexed corpus and retrieval performance. + + Returns: + Dictionary with corpus statistics + """ + if not self.chunks: + return {"status": "not_indexed"} + + return { + "status": "indexed", + "total_chunks": len(self.chunks), + "dense_index_size": self.dense_index.ntotal if self.dense_index else 0, + "embedding_dim": self.embeddings.shape[1] if self.embeddings is not None else 0, + "sparse_indexed_chunks": len(self.sparse_retriever.chunk_mapping), + "dense_weight": self.dense_weight, + "sparse_weight": 1.0 - self.dense_weight, + "rrf_k": self.rrf_k + } \ No newline at end of file diff --git a/shared_utils/vector_stores/retrieval/vocabulary_index.py b/shared_utils/vector_stores/retrieval/vocabulary_index.py new file mode 100644 index 0000000000000000000000000000000000000000..aad0357a81988eb85dba44dd70b570fa4ecfab65 --- /dev/null +++ b/shared_utils/vector_stores/retrieval/vocabulary_index.py @@ -0,0 +1,260 @@ +""" +Vocabulary index for corpus-aware query enhancement. + +Tracks all unique terms in the document corpus to enable intelligent +synonym expansion that only adds terms actually present in documents. +""" + +from typing import Set, Dict, List, Optional +from collections import defaultdict +import re +from pathlib import Path +import json + + +class VocabularyIndex: + """ + Maintains vocabulary statistics for intelligent query enhancement. + + Features: + - Tracks all unique terms in document corpus + - Stores term frequencies for relevance weighting + - Identifies technical terms and domain vocabulary + - Enables vocabulary-aware synonym expansion + + Performance: + - Build time: ~1s per 1000 chunks + - Memory: ~3MB for 80K unique terms + - Lookup: O(1) set operations + """ + + def __init__(self): + """Initialize empty vocabulary index.""" + self.vocabulary: Set[str] = set() + self.term_frequencies: Dict[str, int] = defaultdict(int) + self.technical_terms: Set[str] = set() + self.document_frequencies: Dict[str, int] = defaultdict(int) + self.total_documents = 0 + self.total_terms = 0 + + # Regex for term extraction + self._term_pattern = re.compile(r'\b[a-zA-Z][a-zA-Z0-9\-_]*\b') + self._technical_pattern = re.compile(r'\b[A-Z]{2,}|[a-zA-Z]+[\-_][a-zA-Z]+|\b\d+[a-zA-Z]+\b') + + def build_from_chunks(self, chunks: List[Dict]) -> None: + """ + Build vocabulary index from document chunks. 
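+
+        Illustrative sketch:
+
+            vocab = VocabularyIndex()
+            vocab.build_from_chunks(chunks)      # chunks: [{'text': ...}, ...]
+            vocab.contains('interrupt')          # membership check
+            vocab.get_frequency('interrupt')     # corpus-wide term count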
+ + Args: + chunks: List of document chunks with 'text' field + + Performance: ~1s per 1000 chunks + """ + self.total_documents = len(chunks) + + for chunk in chunks: + text = chunk.get('text', '') + + # Extract and process terms + terms = self._extract_terms(text) + unique_terms = set(terms) + + # Update vocabulary + self.vocabulary.update(unique_terms) + + # Update frequencies + for term in terms: + self.term_frequencies[term] += 1 + self.total_terms += 1 + + # Update document frequencies + for term in unique_terms: + self.document_frequencies[term] += 1 + + # Identify technical terms + technical = self._extract_technical_terms(text) + self.technical_terms.update(technical) + + def _extract_terms(self, text: str) -> List[str]: + """Extract normalized terms from text.""" + # Convert to lowercase and extract words + text_lower = text.lower() + terms = self._term_pattern.findall(text_lower) + + # Filter short terms + return [term for term in terms if len(term) > 2] + + def _extract_technical_terms(self, text: str) -> Set[str]: + """Extract technical terms (acronyms, hyphenated, etc).""" + technical = set() + + # Find potential technical terms + matches = self._technical_pattern.findall(text) + + for match in matches: + # Normalize but preserve technical nature + normalized = match.lower() + if len(normalized) > 2: + technical.add(normalized) + + return technical + + def contains(self, term: str) -> bool: + """Check if term exists in vocabulary.""" + return term.lower() in self.vocabulary + + def get_frequency(self, term: str) -> int: + """Get term frequency in corpus.""" + return self.term_frequencies.get(term.lower(), 0) + + def get_document_frequency(self, term: str) -> int: + """Get number of documents containing term.""" + return self.document_frequencies.get(term.lower(), 0) + + def is_common_term(self, term: str, min_frequency: int = 5) -> bool: + """Check if term appears frequently enough.""" + return self.get_frequency(term) >= min_frequency + + def is_technical_term(self, term: str) -> bool: + """Check if term is identified as technical.""" + return term.lower() in self.technical_terms + + def filter_synonyms(self, synonyms: List[str], + min_frequency: int = 3, + require_technical: bool = False) -> List[str]: + """ + Filter synonym list to only include terms in vocabulary. 
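+
+        Illustrative sketch (multi-word synonyms such as 'central processing unit'
+        are effectively dropped, since the vocabulary stores single tokens):
+
+            vocab.filter_synonyms(['processor', 'microprocessor'], min_frequency=3)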
+ + Args: + synonyms: List of potential synonyms + min_frequency: Minimum term frequency required + require_technical: Only include technical terms + + Returns: + Filtered list of valid synonyms + """ + valid_synonyms = [] + + for synonym in synonyms: + # Check existence + if not self.contains(synonym): + continue + + # Check frequency threshold + if self.get_frequency(synonym) < min_frequency: + continue + + # Check technical requirement + if require_technical and not self.is_technical_term(synonym): + continue + + valid_synonyms.append(synonym) + + return valid_synonyms + + def get_vocabulary_stats(self) -> Dict[str, any]: + """Get comprehensive vocabulary statistics.""" + return { + 'unique_terms': len(self.vocabulary), + 'total_terms': self.total_terms, + 'technical_terms': len(self.technical_terms), + 'total_documents': self.total_documents, + 'avg_terms_per_doc': self.total_terms / self.total_documents if self.total_documents > 0 else 0, + 'vocabulary_richness': len(self.vocabulary) / self.total_terms if self.total_terms > 0 else 0, + 'technical_ratio': len(self.technical_terms) / len(self.vocabulary) if self.vocabulary else 0 + } + + def get_top_terms(self, n: int = 100, technical_only: bool = False) -> List[tuple]: + """ + Get most frequent terms in corpus. + + Args: + n: Number of top terms to return + technical_only: Only return technical terms + + Returns: + List of (term, frequency) tuples + """ + if technical_only: + term_freq = { + term: freq for term, freq in self.term_frequencies.items() + if term in self.technical_terms + } + else: + term_freq = self.term_frequencies + + return sorted(term_freq.items(), key=lambda x: x[1], reverse=True)[:n] + + def detect_domain(self) -> str: + """ + Detect document domain from vocabulary patterns. 
+ + Returns: + Detected domain name + """ + # Domain detection heuristics + domain_indicators = { + 'embedded_systems': ['microcontroller', 'rtos', 'embedded', 'firmware', 'mcu'], + 'processor_architecture': ['risc-v', 'riscv', 'instruction', 'register', 'isa'], + 'regulatory': ['fda', 'validation', 'compliance', 'regulation', 'guidance'], + 'ai_ml': ['model', 'training', 'neural', 'algorithm', 'machine learning'], + 'software_engineering': ['software', 'development', 'testing', 'debugging', 'code'] + } + + domain_scores = {} + + for domain, indicators in domain_indicators.items(): + score = sum( + self.get_document_frequency(indicator) + for indicator in indicators + if self.contains(indicator) + ) + domain_scores[domain] = score + + # Return domain with highest score + if domain_scores: + return max(domain_scores, key=domain_scores.get) + return 'general' + + def save_to_file(self, path: Path) -> None: + """Save vocabulary index to JSON file.""" + data = { + 'vocabulary': list(self.vocabulary), + 'term_frequencies': dict(self.term_frequencies), + 'technical_terms': list(self.technical_terms), + 'document_frequencies': dict(self.document_frequencies), + 'total_documents': self.total_documents, + 'total_terms': self.total_terms + } + + with open(path, 'w') as f: + json.dump(data, f, indent=2) + + def load_from_file(self, path: Path) -> None: + """Load vocabulary index from JSON file.""" + with open(path, 'r') as f: + data = json.load(f) + + self.vocabulary = set(data['vocabulary']) + self.term_frequencies = defaultdict(int, data['term_frequencies']) + self.technical_terms = set(data['technical_terms']) + self.document_frequencies = defaultdict(int, data['document_frequencies']) + self.total_documents = data['total_documents'] + self.total_terms = data['total_terms'] + + def merge_with(self, other: 'VocabularyIndex') -> None: + """Merge another vocabulary index into this one.""" + # Merge vocabularies + self.vocabulary.update(other.vocabulary) + self.technical_terms.update(other.technical_terms) + + # Merge frequencies + for term, freq in other.term_frequencies.items(): + self.term_frequencies[term] += freq + + for term, doc_freq in other.document_frequencies.items(): + self.document_frequencies[term] += doc_freq + + # Update totals + self.total_documents += other.total_documents + self.total_terms += other.total_terms \ No newline at end of file
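
The two-pass source-diversity re-ranking added to the hybrid retriever earlier in this patch can be exercised on its own. The sketch below is a hypothetical standalone rendering of that logic, not code from the repository: the function name diversify_by_source and the sample result tuples are made up for illustration, and it tracks already-selected positions in a set instead of re-testing whole (index, score, dict) tuples in the relaxed second pass, but the two-pass scheme is otherwise the same.

from typing import Dict, List, Tuple

Result = Tuple[int, float, Dict]


def diversify_by_source(results: List[Result],
                        top_k: int = 5,
                        max_per_source: int = 2) -> List[Result]:
    """Re-rank relevance-sorted results so no single source dominates the top_k."""
    if not results:
        return []

    source_counts: Dict[str, int] = {}
    diverse: List[Result] = []
    taken = set()  # positions already selected in the first pass

    # First pass: keep the best-scoring chunks while capping each source.
    for pos, (chunk_idx, score, chunk) in enumerate(results):
        source = chunk.get("source", "unknown")
        if source_counts.get(source, 0) < max_per_source:
            diverse.append((chunk_idx, score, chunk))
            source_counts[source] = source_counts.get(source, 0) + 1
            taken.add(pos)
            if len(diverse) >= top_k:
                break

    # Second pass: relax the per-source cap if top_k is still not filled.
    if len(diverse) < top_k:
        for pos, item in enumerate(results):
            if pos not in taken:
                diverse.append(item)
                if len(diverse) >= top_k:
                    break

    return diverse[:top_k]


ranked = [
    (0, 0.92, {"source": "spec_a.pdf"}),
    (3, 0.90, {"source": "spec_a.pdf"}),
    (7, 0.88, {"source": "spec_a.pdf"}),
    (2, 0.81, {"source": "manual_b.pdf"}),
]
print(diversify_by_source(ranked, top_k=3))
# -> chunks 0 and 3 from spec_a.pdf, then chunk 2 from manual_b.pdf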
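
A minimal usage sketch of the VocabularyIndex class introduced above, assuming the shared_utils.vector_stores.retrieval package is importable as laid out in the diff. The chunk texts, synonym candidates, and the expected values in the comments are toy sample data chosen for illustration, not fixtures from the repository.

from shared_utils.vector_stores.retrieval.vocabulary_index import VocabularyIndex

chunks = [
    {"text": "The RISC-V ISA defines base integer registers and optional extensions."},
    {"text": "FreeRTOS is a real-time RTOS used for microcontroller (MCU) firmware."},
    {"text": "Interrupt latency depends on the RTOS scheduler and the MCU clock."},
]

vocab = VocabularyIndex()
vocab.build_from_chunks(chunks)

print(vocab.contains("rtos"))           # True: term occurs in the corpus (case-insensitive)
print(vocab.get_frequency("rtos"))      # 2: one occurrence in each of two chunks
print(vocab.is_technical_term("RTOS"))  # True: standalone acronym matched by the technical pattern

# Keep only expansion candidates that actually occur in the indexed documents.
candidates = ["rtos", "kernel", "scheduler", "hypervisor"]
print(vocab.filter_synonyms(candidates, min_frequency=1))
# -> ['rtos', 'scheduler'] for this toy corpus

print(vocab.get_vocabulary_stats()["unique_terms"])
print(vocab.detect_domain())            # 'embedded_systems', given the MCU/RTOS indicator hits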
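
The index can also be persisted and combined across corpora through save_to_file, load_from_file, and merge_with, all defined in the patch. The sketch below again assumes the same import path; the two sample corpora, the file name rtos_vocabulary.json, and the expected outputs in the comments are illustrative only.

from pathlib import Path

from shared_utils.vector_stores.retrieval.vocabulary_index import VocabularyIndex

# Build two small indexes from separate sample corpora.
rtos_vocab = VocabularyIndex()
rtos_vocab.build_from_chunks([{"text": "FreeRTOS tasks share the MCU scheduler."}])

fda_vocab = VocabularyIndex()
fda_vocab.build_from_chunks([{"text": "FDA guidance covers software validation and compliance."}])

# Persist one index and reload it into a fresh instance.
index_path = Path("rtos_vocabulary.json")  # throwaway example path
rtos_vocab.save_to_file(index_path)

restored = VocabularyIndex()
restored.load_from_file(index_path)
assert restored.contains("scheduler")

# Fold the second corpus into the restored index; counters are summed.
restored.merge_with(fda_vocab)
print(restored.total_documents)  # 2
print(restored.detect_domain())  # 'regulatory' for this merged toy corpus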