| """ | |
| Answer generation module using Ollama for local LLM inference. | |
| This module provides answer generation with citation support for RAG systems, | |
| optimized for technical documentation Q&A on Apple Silicon. | |
| """ | |
| import json | |
| import logging | |
| from dataclasses import dataclass | |
| from typing import List, Dict, Any, Optional, Generator, Tuple | |
| import ollama | |
| from datetime import datetime | |
| import re | |
| from pathlib import Path | |
| import sys | |
| # Import calibration framework | |
| try: | |
| from src.confidence_calibration import ConfidenceCalibrator | |
| except ImportError: | |
| # Fallback - disable calibration for deployment | |
| ConfidenceCalibrator = None | |
| logger = logging.getLogger(__name__) | |
@dataclass
class Citation:
    """Represents a citation to a source document chunk."""
    chunk_id: str
    page_number: int
    source_file: str
    relevance_score: float
    text_snippet: str


@dataclass
class GeneratedAnswer:
    """Represents a generated answer with citations."""
    answer: str
    citations: List[Citation]
    confidence_score: float
    generation_time: float
    model_used: str
    context_used: List[Dict[str, Any]]
class AnswerGenerator:
    """
    Generates answers using local LLMs via Ollama with citation support.

    Optimized for technical documentation Q&A with:
    - Streaming response support
    - Citation extraction and formatting
    - Confidence scoring
    - Fallback model support
    """

    def __init__(
        self,
        primary_model: str = "llama3.2:3b",
        fallback_model: str = "mistral:latest",
        temperature: float = 0.3,
        max_tokens: int = 1024,
        stream: bool = True,
        enable_calibration: bool = True
    ):
        """
        Initialize the answer generator.

        Args:
            primary_model: Primary Ollama model to use
            fallback_model: Fallback model for complex queries
            temperature: Generation temperature (0.0-1.0)
            max_tokens: Maximum tokens to generate
            stream: Whether to stream responses
            enable_calibration: Whether to enable confidence calibration
        """
        self.primary_model = primary_model
        self.fallback_model = fallback_model
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.stream = stream
        self.client = ollama.Client()

        # Initialize confidence calibration
        self.enable_calibration = enable_calibration
        self.calibrator = None
        if enable_calibration and ConfidenceCalibrator is not None:
            try:
                self.calibrator = ConfidenceCalibrator()
                logger.info("Confidence calibration enabled")
            except Exception as e:
                logger.warning(f"Failed to initialize calibration: {e}")
                self.enable_calibration = False
        elif enable_calibration and ConfidenceCalibrator is None:
            logger.warning("Calibration requested but ConfidenceCalibrator not available - disabling")
            self.enable_calibration = False

        # Verify models are available
        self._verify_models()
    def _verify_models(self) -> None:
        """Verify that required models are available."""
        try:
            model_list = self.client.list()
            available_models = []

            # Handle Ollama's ListResponse object
            if hasattr(model_list, 'models'):
                for model in model_list.models:
                    if hasattr(model, 'model'):
                        available_models.append(model.model)
                    elif isinstance(model, dict) and 'model' in model:
                        available_models.append(model['model'])

            if self.primary_model not in available_models:
                logger.warning(f"Primary model {self.primary_model} not found. Available models: {available_models}")
                raise ValueError(f"Model {self.primary_model} not available. Please run: ollama pull {self.primary_model}")

            if self.fallback_model not in available_models:
                logger.warning(f"Fallback model {self.fallback_model} not found in: {available_models}")
        except Exception as e:
            logger.error(f"Error verifying models: {e}")
            raise
    def _create_system_prompt(self) -> str:
        """Create system prompt for technical documentation Q&A."""
        return """You are a technical documentation assistant that provides clear, accurate answers based on the provided context.

CORE PRINCIPLES:
1. ANSWER DIRECTLY: If context contains the answer, provide it clearly and confidently
2. BE CONCISE: Keep responses focused and avoid unnecessary uncertainty language
3. CITE ACCURATELY: Use [chunk_X] citations for every fact from context

RESPONSE GUIDELINES:
- If context has sufficient information → Answer directly and confidently
- If context has partial information → Answer what's available, note what's missing briefly
- If context is irrelevant → Brief refusal: "This information isn't available in the provided documents"

CITATION FORMAT:
- Use [chunk_1], [chunk_2] etc. for all facts from context
- Example: "According to [chunk_1], RISC-V is an open-source architecture."

WHAT TO AVOID:
- Do NOT add details not in context
- Do NOT second-guess yourself if context is clear
- Do NOT use phrases like "does not contain sufficient information" when context clearly answers the question
- Do NOT be overly cautious when context is adequate

Be direct, confident, and accurate. If the context answers the question, provide that answer clearly."""
    def _format_context(self, chunks: List[Dict[str, Any]]) -> str:
        """
        Format retrieved chunks into context for the LLM.

        Args:
            chunks: List of retrieved chunks with metadata

        Returns:
            Formatted context string
        """
        context_parts = []
        for i, chunk in enumerate(chunks):
            chunk_text = chunk.get('content', chunk.get('text', ''))
            page_num = chunk.get('metadata', {}).get('page_number', 'unknown')
            source = chunk.get('metadata', {}).get('source', 'unknown')
            context_parts.append(
                f"[chunk_{i+1}] (Page {page_num} from {source}):\n{chunk_text}\n"
            )
        return "\n---\n".join(context_parts)
    def _extract_citations(self, answer: str, chunks: List[Dict[str, Any]]) -> Tuple[str, List[Citation]]:
        """
        Extract citations from the generated answer and integrate them naturally.

        Args:
            answer: Generated answer with [chunk_X] citations
            chunks: Original chunks used for context

        Returns:
            Tuple of (natural_answer, citations)
        """
        citations = []
        citation_pattern = r'\[chunk_(\d+)\]'
        cited_chunks = set()

        # Find [chunk_X] citations and collect cited chunks
        matches = re.finditer(citation_pattern, answer)
        for match in matches:
            chunk_idx = int(match.group(1)) - 1  # Convert to 0-based index
            if 0 <= chunk_idx < len(chunks):
                cited_chunks.add(chunk_idx)

        # Create Citation objects for each cited chunk
        chunk_to_source = {}
        for idx in cited_chunks:
            chunk = chunks[idx]
            citation = Citation(
                chunk_id=chunk.get('id', f'chunk_{idx}'),
                page_number=chunk.get('metadata', {}).get('page_number', 0),
                source_file=chunk.get('metadata', {}).get('source', 'unknown'),
                relevance_score=chunk.get('score', 0.0),
                text_snippet=chunk.get('content', chunk.get('text', ''))[:200] + '...'
            )
            citations.append(citation)

            # Map chunk reference to natural source name
            source_name = chunk.get('metadata', {}).get('source', 'unknown')
            if source_name != 'unknown':
                # Use just the filename without extension for natural reference
                natural_name = Path(source_name).stem.replace('-', ' ').replace('_', ' ')
                chunk_to_source[f'[chunk_{idx+1}]'] = f"the {natural_name} documentation"
            else:
                chunk_to_source[f'[chunk_{idx+1}]'] = "the documentation"

        # Replace [chunk_X] with natural references instead of removing them
        natural_answer = answer
        for chunk_ref, natural_ref in chunk_to_source.items():
            natural_answer = natural_answer.replace(chunk_ref, natural_ref)

        # Clean up any remaining unreferenced citations (fallback)
        natural_answer = re.sub(r'\[chunk_\d+\]', 'the documentation', natural_answer)

        # Clean up multiple spaces and formatting
        natural_answer = re.sub(r'\s+', ' ', natural_answer).strip()

        return natural_answer, citations
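
    # Illustrative sketch (comment only, not executed): an answer such as
    #   "According to [chunk_1], RISC-V is an open-source architecture."
    # is rewritten to
    #   "According to the riscv spec documentation, RISC-V is an open-source architecture."
    # (assuming the cited chunk's source is "riscv-spec.pdf"), and one Citation object
    # records the chunk id, page number, source file, relevance score, and a text snippet.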
    def _calculate_confidence(self, answer: str, citations: List[Citation], chunks: List[Dict[str, Any]]) -> float:
        """
        Calculate confidence score for the generated answer with improved calibration.

        Args:
            answer: Generated answer
            citations: Extracted citations
            chunks: Retrieved chunks

        Returns:
            Confidence score (0.0-1.0)
        """
        # Check if no chunks were provided first
        if not chunks:
            return 0.05  # No context = very low confidence

        # Assess context quality to determine base confidence
        scores = [chunk.get('score', 0) for chunk in chunks]
        max_relevance = max(scores) if scores else 0
        avg_relevance = sum(scores) / len(scores) if scores else 0

        # Dynamic base confidence based on context quality
        if max_relevance >= 0.8:
            confidence = 0.6   # High-quality context starts high
        elif max_relevance >= 0.6:
            confidence = 0.4   # Good context starts moderately
        elif max_relevance >= 0.4:
            confidence = 0.2   # Fair context starts low
        else:
            confidence = 0.05  # Poor context starts very low

        # Strong uncertainty and explicit refusal indicators
        strong_uncertainty_phrases = [
            "does not contain sufficient information",
            "context does not provide",
            "insufficient information",
            "cannot determine",
            "refuse to answer",
            "cannot answer",
            "does not contain relevant",
            "no relevant context",
            "missing from the provided context"
        ]

        # Weak uncertainty phrases that might be in nuanced but correct answers
        weak_uncertainty_phrases = [
            "unclear",
            "conflicting",
            "not specified",
            "questionable",
            "not contained",
            "no mention",
            "no relevant",
            "missing",
            "not explicitly"
        ]

        # Check for strong uncertainty - these should drastically reduce confidence
        if any(phrase in answer.lower() for phrase in strong_uncertainty_phrases):
            return min(0.1, confidence * 0.2)  # Max 10% for explicit refusal/uncertainty

        # Check for weak uncertainty - reduce but don't destroy confidence for good context
        weak_uncertainty_count = sum(1 for phrase in weak_uncertainty_phrases if phrase in answer.lower())
        if weak_uncertainty_count > 0:
            if max_relevance >= 0.7 and citations:
                # Good context with citations - reduce less severely
                confidence *= (0.8 ** weak_uncertainty_count)  # Moderate penalty
            else:
                # Poor context - reduce more severely
                confidence *= (0.5 ** weak_uncertainty_count)  # Strong penalty

        # If all chunks have very low relevance scores, cap confidence low
        if max_relevance < 0.4:
            return min(0.08, confidence)  # Max 8% for low relevance context
        # Factor 1: Citation quality and coverage
        if citations and chunks:
            citation_ratio = len(citations) / min(len(chunks), 3)

            # Strong boost for high-relevance citations
            relevant_chunks = [c for c in chunks if c.get('score', 0) > 0.6]
            if relevant_chunks:
                # Significant boost for citing relevant chunks
                confidence += 0.25 * citation_ratio
                # Extra boost if citing majority of relevant chunks
                if len(citations) >= len(relevant_chunks) * 0.5:
                    confidence += 0.15
            else:
                # Small boost for citations to lower-relevance chunks
                confidence += 0.1 * citation_ratio
        else:
            # No citations = reduce confidence unless it's a simple factual statement
            if max_relevance >= 0.8 and len(answer.split()) < 20:
                confidence *= 0.8  # Gentle penalty for uncited but simple answers
            else:
                confidence *= 0.6  # Stronger penalty for complex uncited answers

        # Factor 2: Relevance score reinforcement
        if citations:
            avg_citation_relevance = sum(c.relevance_score for c in citations) / len(citations)
            if avg_citation_relevance > 0.8:
                confidence += 0.2  # Strong boost for highly relevant citations
            elif avg_citation_relevance > 0.6:
                confidence += 0.1  # Moderate boost
            elif avg_citation_relevance < 0.4:
                confidence *= 0.6  # Penalty for low-relevance citations

        # Factor 3: Context utilization quality
        if chunks:
            avg_chunk_length = sum(len(chunk.get('content', chunk.get('text', ''))) for chunk in chunks) / len(chunks)
            # Boost for substantial, high-quality context
            if avg_chunk_length > 200 and max_relevance > 0.8:
                confidence += 0.1
            elif avg_chunk_length < 50:  # Very short chunks
                confidence *= 0.8

        # Factor 4: Answer characteristics
        answer_words = len(answer.split())
        if answer_words < 10:
            confidence *= 0.9  # Slight penalty for very short answers
        elif answer_words > 50 and citations:
            confidence += 0.05  # Small boost for detailed cited answers

        # Factor 5: High-quality scenario bonus
        if (max_relevance >= 0.8 and citations and
                len(citations) > 0 and
                not any(phrase in answer.lower() for phrase in strong_uncertainty_phrases)):
            # This is a high-quality response scenario
            confidence += 0.15

        raw_confidence = min(confidence, 0.95)  # Cap at 95% to maintain some uncertainty

        # Apply temperature scaling calibration if available
        if self.enable_calibration and self.calibrator and self.calibrator.is_fitted:
            try:
                calibrated_confidence = self.calibrator.calibrate_confidence(raw_confidence)
                logger.debug(f"Confidence calibrated: {raw_confidence:.3f} -> {calibrated_confidence:.3f}")
                return calibrated_confidence
            except Exception as e:
                logger.warning(f"Calibration failed, using raw confidence: {e}")

        return raw_confidence
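
    # Rough worked example (comment only): two chunks scored 0.95 and 0.89 give a 0.6 base;
    # citing the top chunk adds the citation and relevance boosts plus the high-quality
    # bonus, so the raw score saturates at the 0.95 cap. The same answer with no citations
    # over chunks all scored below 0.4 instead stays at or below the 0.08 cap.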
    def fit_calibration(self, validation_data: List[Dict[str, Any]]) -> float:
        """
        Fit temperature scaling calibration using validation data.

        Args:
            validation_data: List of dicts with 'confidence' and 'correctness' keys

        Returns:
            Optimal temperature parameter
        """
        if not self.enable_calibration or not self.calibrator:
            logger.warning("Calibration not enabled or not available")
            return 1.0

        try:
            confidences = [item['confidence'] for item in validation_data]
            correctness = [item['correctness'] for item in validation_data]
            optimal_temp = self.calibrator.fit_temperature_scaling(confidences, correctness)
            logger.info(f"Calibration fitted with temperature: {optimal_temp:.3f}")
            return optimal_temp
        except Exception as e:
            logger.error(f"Failed to fit calibration: {e}")
            return 1.0
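
    # Usage sketch (comment only; the validation values below are made up):
    #   validation_data = [
    #       {"confidence": 0.82, "correctness": 1.0},
    #       {"confidence": 0.35, "correctness": 0.0},
    #   ]
    #   generator.fit_calibration(validation_data)
    #   generator.save_calibration("calibration.json")   # persist the fitted temperature
    #   generator.load_calibration("calibration.json")   # restore it in a later session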
    def save_calibration(self, filepath: str) -> bool:
        """Save fitted calibration to file."""
        if not self.calibrator or not self.calibrator.is_fitted:
            logger.warning("No fitted calibration to save")
            return False

        try:
            calibration_data = {
                'temperature': self.calibrator.temperature,
                'is_fitted': self.calibrator.is_fitted,
                'model_info': {
                    'primary_model': self.primary_model,
                    'fallback_model': self.fallback_model
                }
            }
            with open(filepath, 'w') as f:
                json.dump(calibration_data, f, indent=2)
            logger.info(f"Calibration saved to {filepath}")
            return True
        except Exception as e:
            logger.error(f"Failed to save calibration: {e}")
            return False
    def load_calibration(self, filepath: str) -> bool:
        """Load fitted calibration from file."""
        if not self.enable_calibration or not self.calibrator:
            logger.warning("Calibration not enabled")
            return False

        try:
            with open(filepath, 'r') as f:
                calibration_data = json.load(f)
            self.calibrator.temperature = calibration_data['temperature']
            self.calibrator.is_fitted = calibration_data['is_fitted']
            logger.info(f"Calibration loaded from {filepath} (temp: {self.calibrator.temperature:.3f})")
            return True
        except Exception as e:
            logger.error(f"Failed to load calibration: {e}")
            return False
    def generate(
        self,
        query: str,
        chunks: List[Dict[str, Any]],
        use_fallback: bool = False
    ) -> GeneratedAnswer:
        """
        Generate an answer based on the query and retrieved chunks.

        Args:
            query: User's question
            chunks: Retrieved document chunks
            use_fallback: Whether to use fallback model

        Returns:
            GeneratedAnswer object with answer, citations, and metadata
        """
        start_time = datetime.now()
        model = self.fallback_model if use_fallback else self.primary_model

        # Check for no-context or very poor context situation
        if not chunks or all(len(chunk.get('content', chunk.get('text', ''))) < 20 for chunk in chunks):
            # Handle no-context situation with brief, professional refusal
            user_prompt = f"""Context: [NO RELEVANT CONTEXT FOUND]

Question: {query}

INSTRUCTION: Respond with exactly this brief message:
"This information isn't available in the provided documents."

DO NOT elaborate, explain, or add any other information."""
        else:
            # Format context from chunks
            context = self._format_context(chunks)

            # Create concise prompt for faster generation
            user_prompt = f"""Context:
{context}

Question: {query}

Instructions: Answer using only the context above. Cite with [chunk_1], [chunk_2] etc.

Answer:"""

        try:
            # Generate response
            response = self.client.chat(
                model=model,
                messages=[
                    {"role": "system", "content": self._create_system_prompt()},
                    {"role": "user", "content": user_prompt}
                ],
                options={
                    "temperature": self.temperature,
                    "num_predict": min(self.max_tokens, 300),  # Reduce max tokens for speed
                    "top_k": 40,  # Optimize sampling for speed
                    "top_p": 0.9,
                    "repeat_penalty": 1.1
                },
                stream=False  # Get complete response for processing
            )

            # Extract answer
            answer_with_citations = response['message']['content']

            # Extract and clean citations
            clean_answer, citations = self._extract_citations(answer_with_citations, chunks)

            # Calculate confidence
            confidence = self._calculate_confidence(clean_answer, citations, chunks)

            # Calculate generation time
            generation_time = (datetime.now() - start_time).total_seconds()

            return GeneratedAnswer(
                answer=clean_answer,
                citations=citations,
                confidence_score=confidence,
                generation_time=generation_time,
                model_used=model,
                context_used=chunks
            )
        except Exception as e:
            logger.error(f"Error generating answer: {e}")
            # Return a fallback response
            return GeneratedAnswer(
                answer="I apologize, but I encountered an error while generating the answer. Please try again.",
                citations=[],
                confidence_score=0.0,
                generation_time=0.0,
                model_used=model,
                context_used=chunks
            )
    def generate_stream(
        self,
        query: str,
        chunks: List[Dict[str, Any]],
        use_fallback: bool = False
    ) -> Generator[str, None, GeneratedAnswer]:
        """
        Generate an answer with streaming support.

        Args:
            query: User's question
            chunks: Retrieved document chunks
            use_fallback: Whether to use fallback model

        Yields:
            Partial answer strings

        Returns:
            Final GeneratedAnswer object
        """
        start_time = datetime.now()
        model = self.fallback_model if use_fallback else self.primary_model

        # Check for no-context or very poor context situation
        if not chunks or all(len(chunk.get('content', chunk.get('text', ''))) < 20 for chunk in chunks):
            # Handle no-context situation with brief, professional refusal
            user_prompt = f"""Context: [NO RELEVANT CONTEXT FOUND]

Question: {query}

INSTRUCTION: Respond with exactly this brief message:
"This information isn't available in the provided documents."

DO NOT elaborate, explain, or add any other information."""
        else:
            # Format context from chunks
            context = self._format_context(chunks)

            # Create concise prompt for faster generation
            user_prompt = f"""Context:
{context}

Question: {query}

Instructions: Answer using only the context above. Cite with [chunk_1], [chunk_2] etc.

Answer:"""

        try:
            # Generate streaming response
            stream = self.client.chat(
                model=model,
                messages=[
                    {"role": "system", "content": self._create_system_prompt()},
                    {"role": "user", "content": user_prompt}
                ],
                options={
                    "temperature": self.temperature,
                    "num_predict": min(self.max_tokens, 300),  # Reduce max tokens for speed
                    "top_k": 40,  # Optimize sampling for speed
                    "top_p": 0.9,
                    "repeat_penalty": 1.1
                },
                stream=True
            )

            # Collect full answer while streaming
            full_answer = ""
            for chunk in stream:
                if 'message' in chunk and 'content' in chunk['message']:
                    partial = chunk['message']['content']
                    full_answer += partial
                    yield partial

            # Process complete answer
            clean_answer, citations = self._extract_citations(full_answer, chunks)
            confidence = self._calculate_confidence(clean_answer, citations, chunks)
            generation_time = (datetime.now() - start_time).total_seconds()

            return GeneratedAnswer(
                answer=clean_answer,
                citations=citations,
                confidence_score=confidence,
                generation_time=generation_time,
                model_used=model,
                context_used=chunks
            )
        except Exception as e:
            logger.error(f"Error in streaming generation: {e}")
            yield "I apologize, but I encountered an error while generating the answer."
            return GeneratedAnswer(
                answer="Error occurred during generation.",
                citations=[],
                confidence_score=0.0,
                generation_time=0.0,
                model_used=model,
                context_used=chunks
            )
    def format_answer_with_citations(self, generated_answer: GeneratedAnswer) -> str:
        """
        Format the generated answer with citations for display.

        Args:
            generated_answer: GeneratedAnswer object

        Returns:
            Formatted string with answer and citations
        """
        formatted = f"{generated_answer.answer}\n\n"

        if generated_answer.citations:
            formatted += "**Sources:**\n"
            for i, citation in enumerate(generated_answer.citations, 1):
                formatted += f"{i}. {citation.source_file} (Page {citation.page_number})\n"

        formatted += f"\n*Confidence: {generated_answer.confidence_score:.1%} | "
        formatted += f"Model: {generated_answer.model_used} | "
        formatted += f"Time: {generated_answer.generation_time:.2f}s*"

        return formatted
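
    # Illustrative output (comment only; the numbers are placeholders):
    #   RISC-V is an open-source instruction set architecture ...
    #
    #   **Sources:**
    #   1. riscv-spec.pdf (Page 1)
    #
    #   *Confidence: 87.0% | Model: llama3.2:3b | Time: 1.42s*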
if __name__ == "__main__":
    # Example usage
    generator = AnswerGenerator()

    # Example chunks (would come from retrieval system)
    example_chunks = [
        {
            "id": "chunk_1",
            "content": "RISC-V is an open-source instruction set architecture (ISA) based on reduced instruction set computer (RISC) principles.",
            "metadata": {"page_number": 1, "source": "riscv-spec.pdf"},
            "score": 0.95
        },
        {
            "id": "chunk_2",
            "content": "The RISC-V ISA is designed to support a wide range of implementations including 32-bit, 64-bit, and 128-bit variants.",
            "metadata": {"page_number": 2, "source": "riscv-spec.pdf"},
            "score": 0.89
        }
    ]

    # Generate answer
    result = generator.generate(
        query="What is RISC-V?",
        chunks=example_chunks
    )

    # Display formatted result
    print(generator.format_answer_with_citations(result))
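
    # Streaming usage sketch (uncomment to try; reuses the example_chunks above):
    # stream = generator.generate_stream(query="What is RISC-V?", chunks=example_chunks)
    # for partial in stream:
    #     print(partial, end="", flush=True)
    # print()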