"""
Gradio Chatbot Interface for CGT-LLM-Beta RAG System

This application provides a web interface for the RAG chatbot with OAuth authentication.
It uses the Hugging Face Inference API with OAuth tokens for authentication.
"""

import argparse
import logging
import os
import re
import sys
import time
import traceback
from pathlib import Path
from typing import Tuple, Optional, List

import gradio as gr
import textstat
import torch

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

try:
    from bot import RAGBot, parse_args, Chunk
    BOT_AVAILABLE = True
except ImportError as e:
    logger.error(f"Failed to import bot module: {e}")
    BOT_AVAILABLE = False

    # Fallback stubs so the module can still be imported when bot.py is missing.
    class RAGBot:
        pass

    class Chunk:
        pass

    def parse_args():
        return None

try:
    from huggingface_hub import InferenceClient
    HF_INFERENCE_AVAILABLE = True
except ImportError:
    HF_INFERENCE_AVAILABLE = False
    logger.warning("huggingface_hub not available, InferenceClient will not work")

# Map short display names to full Hugging Face model IDs.
MODEL_MAP = {
    "Llama-3.2-3B-Instruct": "meta-llama/Llama-3.2-3B-Instruct",
    "Mistral-7B-Instruct-v0.2": "mistralai/Mistral-7B-Instruct-v0.2",
    "Llama-4-Scout-17B-16E-Instruct": "meta-llama/Llama-4-Scout-17B-16E-Instruct",
    "MediPhi-Instruct": "microsoft/MediPhi-Instruct",
    "MediPhi": "microsoft/MediPhi",
    "Phi-4-reasoning": "microsoft/Phi-4-reasoning",
}
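
# Note (an observation, not enforced here): the meta-llama/* repositories are
# gated on the Hugging Face Hub, so Inference API calls for those entries
# generally need a token from an account that has accepted the corresponding
# license; some of the other models may also require authentication.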

# Map UI labels to the internal education-level keys used by enhance_readability.
EDUCATION_LEVELS = {
    "Middle School": "middle_school",
    "High School": "high_school",
    "College": "college",
    "Doctoral": "doctoral",
}

# Example questions surfaced in the UI dropdown.
EXAMPLE_QUESTIONS = [
    "Can a BRCA2 variant skip a generation?",
    "Can a PMS2 variant skip a generation?",
    "Can an EPCAM/MSH2 variant skip a generation?",
    "Can an MLH1 variant skip a generation?",
    "Can an MSH2 variant skip a generation?",
    "Can an MSH6 variant skip a generation?",
    "Can I pass this MSH2 variant to my kids?",
    "Can only women carry a BRCA inherited mutation?",
    "Does GINA cover life or disability insurance?",
    "Does having a BRCA1 mutation mean I will definitely have cancer?",
    "Does having a BRCA2 mutation mean I will definitely have cancer?",
    "Does having a PMS2 mutation mean I will definitely have cancer?",
    "Does having an EPCAM/MSH2 mutation mean I will definitely have cancer?",
    "Does having an MLH1 mutation mean I will definitely have cancer?",
    "Does having an MSH2 mutation mean I will definitely have cancer?",
    "Does having an MSH6 mutation mean I will definitely have cancer?",
    "Does this BRCA1 genetic variant affect my cancer treatment?",
    "Does this BRCA2 genetic variant affect my cancer treatment?",
    "Does this EPCAM/MSH2 genetic variant affect my cancer treatment?",
    "Does this MLH1 genetic variant affect my cancer treatment?",
    "Does this MSH2 genetic variant affect my cancer treatment?",
    "Does this MSH6 genetic variant affect my cancer treatment?",
    "Does this PMS2 genetic variant affect my cancer treatment?",
    "How can I cope with this diagnosis?",
    "How can I get my kids tested?",
    "How can I help others with my condition?",
    "How might my genetic test results change over time?",
    "I don't talk to my family/parents/sister/brother. How can I share this with them?",
    "I have a BRCA pathogenic variant and I want to have children, what are my options?",
    "Is genetic testing for my family members covered by insurance?",
    "Is new research being done on my condition?",
    "Is this BRCA1 variant something I inherited?",
    "Is this BRCA2 variant something I inherited?",
    "Is this EPCAM/MSH2 variant something I inherited?",
    "Is this MLH1 variant something I inherited?",
    "Is this MSH2 variant something I inherited?",
    "Is this MSH6 variant something I inherited?",
    "Is this PMS2 variant something I inherited?",
    "My relative doesn't have insurance. What should they do?",
    "People who test positive for a genetic mutation: are they at risk of losing their health insurance?",
    "Should I contact my male and female relatives?",
    "Should my family members get tested?",
    "What are the risks and benefits of risk-reducing surgeries for Lynch syndrome?",
    "What are the recommendations for my family members if I have a BRCA1 mutation?",
    "What are the recommendations for my family members if I have a BRCA2 mutation?",
    "What are the recommendations for my family members if I have a PMS2 mutation?",
    "What are the recommendations for my family members if I have an EPCAM/MSH2 mutation?",
    "What are the recommendations for my family members if I have an MLH1 mutation?",
    "What are the recommendations for my family members if I have an MSH2 mutation?",
    "What are the recommendations for my family members if I have an MSH6 mutation?",
    "What are the surveillance and prevention steps I can take to reduce my risk of cancer or detect cancer early if I have a BRCA mutation?",
    "What are the surveillance and prevention steps I can take to reduce my risk of cancer or detect cancer early if I have an EPCAM/MSH2 mutation?",
    "What are the surveillance and prevention steps I can take to reduce my risk of cancer or detect cancer early if I have an MSH2 mutation?",
    "What does a BRCA1 genetic variant mean for me?",
    "What does a BRCA2 genetic variant mean for me?",
    "What does a PMS2 genetic variant mean for me?",
    "What does an EPCAM/MSH2 genetic variant mean for me?",
    "What does an MLH1 genetic variant mean for me?",
    "What does an MSH2 genetic variant mean for me?",
    "What does an MSH6 genetic variant mean for me?",
    "What if I feel overwhelmed?",
    "What if I want to have children and have a hereditary cancer gene? What are my reproductive options?",
    "What if a family member doesn't want to get tested?",
    "What is Lynch Syndrome?",
    "What is my cancer risk if I have BRCA1 Hereditary Breast and Ovarian Cancer syndrome?",
    "What is my cancer risk if I have BRCA2 Hereditary Breast and Ovarian Cancer syndrome?",
    "What is my cancer risk if I have MLH1 Lynch syndrome?",
    "What is my cancer risk if I have MSH2 or EPCAM-associated Lynch syndrome?",
    "What is my cancer risk if I have MSH6 Lynch syndrome?",
    "What is my cancer risk if I have PMS2 Lynch syndrome?",
    "What other resources are available to help me?",
    "What screening tests do you recommend for BRCA1 carriers?",
    "What screening tests do you recommend for BRCA2 carriers?",
    "What screening tests do you recommend for EPCAM/MSH2 carriers?",
    "What screening tests do you recommend for MLH1 carriers?",
    "What screening tests do you recommend for MSH2 carriers?",
    "What screening tests do you recommend for MSH6 carriers?",
    "What screening tests do you recommend for PMS2 carriers?",
    "What steps can I take to manage my cancer risk if I have Lynch syndrome?",
    "What types of cancers am I at risk for with a BRCA1 mutation?",
    "What types of cancers am I at risk for with a BRCA2 mutation?",
    "What types of cancers am I at risk for with a PMS2 mutation?",
    "What types of cancers am I at risk for with an EPCAM/MSH2 mutation?",
    "What types of cancers am I at risk for with an MLH1 mutation?",
    "What types of cancers am I at risk for with an MSH2 mutation?",
    "What types of cancers am I at risk for with an MSH6 mutation?",
    "Where can I find a genetic counselor?",
    "Which of my relatives are at risk?",
    "Who are my first-degree relatives?",
    "Who do my family members call to have genetic testing?",
    "Why do some families with Lynch syndrome have more cases of cancer than others?",
    "Why should I share my BRCA1 genetic results with family?",
    "Why should I share my BRCA2 genetic results with family?",
    "Why should I share my EPCAM/MSH2 genetic results with family?",
    "Why should I share my MLH1 genetic results with family?",
    "Why should I share my MSH2 genetic results with family?",
    "Why should I share my MSH6 genetic results with family?",
    "Why should I share my PMS2 genetic results with family?",
    "Why would my relatives want to know if they have this? What can they do about it?",
    "Will my insurance cover testing for my parents/brother/sister?",
    "Will this affect my health insurance?",
]


class InferenceAPIBot:
    """Wrapper that calls the Hugging Face Inference API with an OAuth token."""

    def __init__(self, bot: RAGBot):
        """Initialize with a RAGBot (used for its vector DB and settings)."""
        self.bot = bot
        self.current_model = bot.args.model
        logger.info(f"InferenceAPIBot initialized with model: {self.current_model}")

    def _get_client(self, hf_token: Optional[str] = None) -> InferenceClient:
        """Create an InferenceClient with the token (None is fine for public models)."""
        if hf_token:
            return InferenceClient(token=hf_token)
        else:
            return InferenceClient()

    @property
    def args(self):
        """Access args from the wrapped bot."""
        return self.bot.args

    def generate_answer(self, prompt: str, hf_token: Optional[str] = None, **kwargs) -> str:
        """Generate an answer using the Inference API."""
        try:
            max_tokens = kwargs.get('max_new_tokens', 512)
            temperature = kwargs.get('temperature', 0.2)
            top_p = kwargs.get('top_p', 0.9)

            client = self._get_client(hf_token)

            logger.info(f"Calling Inference API for model: {self.current_model} (token: {'provided' if hf_token else 'not provided'})")
            response = client.text_generation(
                prompt,
                model=self.current_model,
                max_new_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
                return_full_text=False,
            )
            logger.info(f"Inference API response received (length: {len(response) if response else 0})")
            return response
        except Exception as e:
            error_msg = str(e).lower()
            logger.error(f"Error calling Inference API: {e}", exc_info=True)

            # Map common failure modes to user-friendly messages.
            if 'authentication' in error_msg or 'token' in error_msg or '401' in error_msg or '403' in error_msg:
                return "⚠️ **Authentication Required**\n\nThis model requires authentication. Please log in using the Hugging Face login button in the sidebar, or ensure you have a valid HF_TOKEN set."
            elif 'not found' in error_msg or '404' in error_msg:
                return f"⚠️ **Model Not Found**\n\nThe model '{self.current_model}' could not be found. Please check the model name or try a different model."
            else:
                return f"⚠️ **Error generating answer**\n\n{str(e)}\n\nPlease check the logs for more details or try again."

    def enhance_readability(self, answer: str, target_level: str = "middle_school", hf_token: Optional[str] = None) -> Tuple[str, float]:
        """Rewrite an answer for the target reading level using the Inference API."""
        try:
            # Build level-specific rewriting instructions.
            if target_level == "middle_school":
                level_description = "middle school reading level (ages 12-14, 6th-8th grade)"
                instructions = """
- Use simpler medical terms or explain them
- Medium-length sentences
- Clear, structured explanations
- Keep important medical information accessible"""
            elif target_level == "high_school":
                level_description = "high school reading level (ages 15-18, 9th-12th grade)"
                instructions = """
- Use appropriate medical terminology with context
- Varied sentence length
- Comprehensive yet accessible explanations
- Maintain technical accuracy while ensuring clarity"""
            elif target_level == "college":
                level_description = "college reading level (undergraduate level, ages 18-22)"
                instructions = """
- Use standard medical terminology with brief explanations
- Professional and clear writing style
- Include relevant clinical context
- Maintain scientific accuracy and precision
- Appropriate for undergraduate students in health sciences"""
            elif target_level == "doctoral":
                level_description = "doctoral/professional reading level (graduate level, medical professionals)"
                instructions = """
- Use advanced medical and scientific terminology
- Include detailed clinical and research context
- Reference specific mechanisms, pathways, and evidence
- Provide comprehensive technical explanations
- Appropriate for medical professionals, researchers, and graduate students
- Include nuanced discussions of clinical implications and research findings"""
            else:
                raise ValueError(f"Unknown target_level: {target_level}")

            system_message = f"""You are a helpful medical assistant who specializes in explaining complex medical information at appropriate reading levels. Rewrite the following medical answer for {level_description}:
{instructions}
- Keep the same important information but adapt the complexity
- Provide context for technical terms
- Ensure the answer is informative yet understandable"""

            user_message = f"Please rewrite this medical answer for {level_description}:\n\n{answer}"

            combined_prompt = f"{system_message}\n\n{user_message}"
            logger.info(f"Enhancing readability for {target_level} level")

            client = self._get_client(hf_token)

            # Higher levels get a longer, slightly more exploratory rewrite.
            max_tokens = 512 if target_level in ["college", "doctoral"] else 384
            temperature = 0.4 if target_level in ["college", "doctoral"] else 0.3

            enhanced_answer = client.text_generation(
                combined_prompt,
                model=self.current_model,
                max_new_tokens=max_tokens,
                temperature=temperature,
                return_full_text=False,
            )

            cleaned = self.bot._clean_readability_answer(enhanced_answer, target_level)

            try:
                flesch_score = textstat.flesch_kincaid_grade(cleaned)
            except Exception:
                flesch_score = 0.0

            return cleaned, flesch_score
        except Exception as e:
            logger.error(f"Error enhancing readability: {e}", exc_info=True)
            return answer, 0.0
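
    # For orientation: textstat.flesch_kincaid_grade approximates a U.S. school
    # grade level (e.g. a score near 8.0 reads at roughly an eighth-grade level;
    # lower is easier). A textstat failure is reported as 0.0, i.e. "score unknown".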

    # Thin delegation methods so this wrapper can stand in for a RAGBot.
    def format_prompt(self, context_chunks: List[Chunk], question: str) -> str:
        return self.bot.format_prompt(context_chunks, question)

    def retrieve_with_scores(self, query: str, k: int) -> Tuple[List[Chunk], List[float]]:
        return self.bot.retrieve_with_scores(query, k)

    def _categorize_question(self, question: str) -> str:
        return self.bot._categorize_question(question)

    @property
    def vector_retriever(self):
        return self.bot.vector_retriever
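

# Minimal usage sketch for the wrapper (illustrative only; assumes a RAGBot
# `bot` with a populated vector DB, as constructed in _create_demo below):
#
#     api_bot = InferenceAPIBot(bot)
#     question = "What is Lynch Syndrome?"
#     chunks, scores = api_bot.retrieve_with_scores(question, 5)
#     prompt = api_bot.format_prompt(chunks, question)
#     answer = api_bot.generate_answer(prompt, hf_token=os.getenv("HF_TOKEN"))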


class GradioRAGInterface:
    """Wrapper class that integrates RAGBot with Gradio using OAuth."""

    def __init__(self, initial_bot: RAGBot):
        # Prefer the Inference API wrapper when huggingface_hub is available.
        if HF_INFERENCE_AVAILABLE:
            self.bot = InferenceAPIBot(initial_bot)
            self.use_inference_api = True
            logger.info("Using Hugging Face Inference API with OAuth")
        else:
            self.bot = initial_bot
            self.use_inference_api = False
            logger.warning("Inference API not available, falling back to local model")

        self.current_model = self.bot.args.model if hasattr(self.bot, 'args') else getattr(self.bot, 'current_model', None)
        if self.current_model is None and hasattr(self.bot, 'bot'):
            self.current_model = self.bot.bot.args.model
        self.data_dir = initial_bot.args.data_dir
        logger.info("GradioRAGInterface initialized")

    def _find_file_path(self, filename: str) -> str:
        """Find the full file path for a given filename."""
        data_path = Path(self.data_dir)

        if not data_path.exists():
            return ""

        # Return the first match found anywhere under the data directory.
        for file_path in data_path.rglob(filename):
            return str(file_path)

        return ""

    def reload_model(self, model_short_name: str) -> str:
        """Reload the model when the user selects a different one."""
        if model_short_name not in MODEL_MAP:
            return f"Error: Unknown model '{model_short_name}'"

        new_model_path = MODEL_MAP[model_short_name]

        if new_model_path == self.current_model:
            return f"Model already loaded: {model_short_name}"

        try:
            logger.info(f"Switching model from {self.current_model} to {new_model_path}")

            if self.use_inference_api:
                # With the Inference API there is nothing to load locally;
                # just point subsequent requests at the new model.
                self.bot.current_model = new_model_path
                self.current_model = new_model_path
                return f"✓ Model switched to: {model_short_name} (using Inference API)"
            else:
                self.bot.args.model = new_model_path

                # Free the old local model before loading the new one.
                if hasattr(self.bot, 'model') and self.bot.model is not None:
                    del self.bot.model
                    del self.bot.tokenizer
                    if torch.cuda.is_available():
                        torch.cuda.empty_cache()

                self.bot._load_model()
                self.current_model = new_model_path

                return f"✓ Model loaded: {model_short_name}"
        except Exception as e:
            logger.error(f"Error reloading model: {e}", exc_info=True)
            return f"✗ Error loading model: {str(e)}"

    def process_question(
        self,
        question: str,
        model_name: str,
        education_level: str,
        k: int,
        temperature: float,
        max_tokens: int,
        hf_token: Optional[str] = None
    ) -> Tuple[str, str, str, str, str]:
        """
        Process a single question and return formatted results.

        Returns:
            Tuple of (answer, flesch_score, sources, similarity_scores, question_category)
        """
        if not question or not question.strip():
            return "Please enter a question.", "N/A", "", "", ""

        try:
            start_time = time.time()
            logger.info(f"Processing question: {question[:50]}...")

            # Switch models first if the user picked a different one.
            if model_name in MODEL_MAP:
                model_path = MODEL_MAP[model_name]
                if model_path != self.current_model:
                    logger.info(f"Model changed, reloading from {self.current_model} to {model_path}")
                    reload_status = self.reload_model(model_name)
                    if reload_status.startswith("✗"):
                        return f"Error: {reload_status}", "N/A", "", "", ""
                    logger.info(f"Model reloaded in {time.time() - start_time:.1f}s")

            # Apply the per-request generation settings.
            self.bot.args.k = k
            self.bot.args.temperature = temperature
            self.bot.args.max_new_tokens = min(max_tokens, 512)

            logger.info("Categorizing question...")
            question_group = self.bot._categorize_question(question)

            logger.info("Retrieving relevant documents...")
            retrieve_start = time.time()
            context_chunks, similarity_scores = self.bot.retrieve_with_scores(question, k)
            logger.info(f"Retrieved {len(context_chunks)} chunks in {time.time() - retrieve_start:.2f}s")

            if not context_chunks:
                return (
                    "I don't have enough information to answer this question. Please try rephrasing or asking about a different topic.",
                    "N/A",
                    "No sources found",
                    "No matches found",
                    question_group
                )

            similarity_scores_str = ", ".join([f"{score:.3f}" for score in similarity_scores])

            # Format each retrieved chunk with its provenance for display.
            sources_list = []
            for i, (chunk, score) in enumerate(zip(context_chunks, similarity_scores)):
                file_path = self._find_file_path(chunk.filename)

                source_info = f"""
{'='*80}
SOURCE {i+1} | Similarity: {score:.3f}
{'='*80}
📄 File: {chunk.filename}
📍 Path: {file_path if file_path else 'File path not found (search in Data Resources directory)'}
📊 Chunk: {chunk.chunk_id + 1}/{chunk.total_chunks} (Position: {chunk.start_pos}-{chunk.end_pos})

📝 Full Chunk Text:
{chunk.text}
"""
                sources_list.append(source_info)

            sources = "\n".join(sources_list)

            gen_kwargs = {
                'max_new_tokens': min(max_tokens, 512),
                'temperature': temperature,
                'top_p': self.bot.args.top_p,
                'repetition_penalty': self.bot.args.repetition_penalty
            }

            logger.info("Generating original answer...")
            gen_start = time.time()
            prompt = self.bot.format_prompt(context_chunks, question)
            original_answer = self.bot.generate_answer(prompt, hf_token=hf_token, **gen_kwargs)
            logger.info(f"Original answer generated in {time.time() - gen_start:.1f}s")

            logger.info(f"Enhancing answer for {education_level} level...")
            enhance_start = time.time()
            if education_level in EDUCATION_LEVELS.values():
                answer, flesch_score = self.bot.enhance_readability(original_answer, target_level=education_level, hf_token=hf_token)
            else:
                answer = "Invalid education level selected."
                flesch_score = 0.0

            logger.info(f"Answer enhanced in {time.time() - enhance_start:.1f}s")
            total_time = time.time() - start_time
            logger.info(f"Total processing time: {total_time:.1f}s")

            # Strip model special tokens and tidy whitespace before display.
            cleaned_answer = answer
            special_tokens = [
                "<|end|>",
                "<|endoftext|>",
                "<|end_of_text|>",
                "<|eot_id|>",
                "<|start_header_id|>",
                "<|end_header_id|>",
                "<|assistant|>",
            ]
            for token in special_tokens:
                cleaned_answer = re.sub(re.escape(token), '', cleaned_answer, flags=re.IGNORECASE)

            # Remove any remaining <|...|> markers and collapse extra whitespace.
            cleaned_answer = re.sub(r'<\|[^|]+\|>', '', cleaned_answer)
            cleaned_answer = re.sub(r'^\*\*.*?\*\*.*?\n', '', cleaned_answer, flags=re.MULTILINE)
            cleaned_answer = re.sub(r'\n\s*\n\s*\n+', '\n\n', cleaned_answer)
            cleaned_answer = re.sub(r'^\s+|\s+$', '', cleaned_answer, flags=re.MULTILINE)
            cleaned_answer = cleaned_answer.strip()

            return (
                cleaned_answer,
                f"{flesch_score:.1f}",
                sources,
                similarity_scores_str,
                question_group
            )

        except Exception as e:
            logger.error(f"Error processing question: {e}", exc_info=True)
            return (
                f"An error occurred while processing your question: {str(e)}",
                "N/A",
                "",
                "",
                "Error"
            )
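
    # Illustrative call (values hypothetical): process_question("What is Lynch
    # Syndrome?", "Llama-3.2-3B-Instruct", "middle_school", 5, 0.2, 512) returns
    # (answer, flesch_grade_str, sources, similarity_scores_str, category), with
    # the category label coming from bot._categorize_question.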


def create_interface(initial_bot: RAGBot) -> gr.Blocks:
    """Create and configure the Gradio interface with OAuth."""

    try:
        interface = GradioRAGInterface(initial_bot)
    except Exception as e:
        logger.error(f"Failed to create GradioRAGInterface: {e}")
        with gr.Blocks(title="CGT-LLM-Beta RAG Chatbot") as demo:
            gr.Markdown(f"""
# ⚠️ Initialization Error

Failed to initialize the chatbot interface.

**Error:** {str(e)}

Please check the logs for more details.
""")
        return demo

    # Resolve the short display name of the initially loaded model.
    initial_model_short = None
    for short_name, full_path in MODEL_MAP.items():
        if full_path == initial_bot.args.model:
            initial_model_short = short_name
            break
    if initial_model_short is None:
        initial_model_short = list(MODEL_MAP.keys())[0]

    try:
        with gr.Blocks(title="CGT-LLM-Beta RAG Chatbot") as demo:
            with gr.Sidebar():
                gr.LoginButton()
                gr.Markdown("### 🔐 Authentication")
                gr.Markdown("Please log in with your Hugging Face account to use the Inference API.")

            gr.Markdown("""
# 🧬 CGT-LLM-Beta: Genetic Counseling RAG Chatbot

Ask questions about genetic counseling, cascade genetic testing, hereditary cancer syndromes, and related topics.

The chatbot uses a Retrieval-Augmented Generation (RAG) system to provide evidence-based answers from medical literature.
""")

            with gr.Row():
                with gr.Column(scale=2):
                    question_input = gr.Textbox(
                        label="Your Question",
                        placeholder="e.g., What is Lynch Syndrome? What screening is recommended for BRCA1 carriers?",
                        lines=3
                    )

                    with gr.Row():
                        model_dropdown = gr.Dropdown(
                            choices=list(MODEL_MAP.keys()),
                            value=initial_model_short,
                            label="Select Model",
                            info="Choose which LLM model to use for generating answers"
                        )

                        education_dropdown = gr.Dropdown(
                            choices=list(EDUCATION_LEVELS.keys()),
                            value=list(EDUCATION_LEVELS.keys())[0],
                            label="Education Level",
                            info="Select your education level for personalized answers"
                        )

                    with gr.Accordion("Advanced Settings", open=False):
                        k_slider = gr.Slider(
                            minimum=1,
                            maximum=10,
                            value=5,
                            step=1,
                            label="Number of document chunks to retrieve (k)"
                        )
                        temperature_slider = gr.Slider(
                            minimum=0.1,
                            maximum=1.0,
                            value=0.2,
                            step=0.1,
                            label="Temperature (lower = more focused)"
                        )
                        max_tokens_slider = gr.Slider(
                            minimum=128,
                            maximum=1024,
                            value=512,
                            step=128,
                            label="Max Tokens (lower = faster responses)"
                        )

                    submit_btn = gr.Button("Ask Question", variant="primary", size="lg")

                with gr.Column(scale=3):
                    answer_output = gr.Textbox(
                        label="Answer",
                        lines=20,
                        interactive=False,
                        elem_classes=["answer-box"]
                    )

                    with gr.Row():
                        flesch_output = gr.Textbox(
                            label="Flesch-Kincaid Grade Level",
                            value="N/A",
                            interactive=False,
                            scale=1
                        )

                        similarity_output = gr.Textbox(
                            label="Similarity Scores",
                            value="",
                            interactive=False,
                            scale=1
                        )

                        category_output = gr.Textbox(
                            label="Question Category",
                            value="",
                            interactive=False,
                            scale=1
                        )

                    sources_output = gr.Textbox(
                        label="Source Documents (with Chunk Text)",
                        lines=15,
                        interactive=False,
                        info="Shows the retrieved document chunks with full text. File paths are shown for easy access."
                    )

            gr.Markdown("### 💡 Example Questions")
            gr.Markdown(f"Select a question below to use it in the chatbot ({len(EXAMPLE_QUESTIONS)} questions - scrollable dropdown):")

            example_questions_dropdown = gr.Dropdown(
                choices=EXAMPLE_QUESTIONS,
                label="Example Questions",
                value=None,
                info="Open the dropdown and scroll through all questions. Select one to use it.",
                interactive=True,
                container=True,
                scale=1
            )

            def update_question_from_dropdown(selected_question):
                return selected_question if selected_question else ""

            example_questions_dropdown.change(
                fn=update_question_from_dropdown,
                inputs=example_questions_dropdown,
                outputs=question_input
            )

            gr.Markdown("""
---
**Note:** This chatbot provides informational answers based on medical literature.
It is not a substitute for professional medical advice, diagnosis, or treatment.
Always consult with qualified healthcare providers for medical decisions.
""")

            def process_with_education_level(question, model, education, k, temp, max_tok, request: gr.Request = None):
                # Recover the user's Hugging Face token: first from the OAuth
                # session on the request, then from an Authorization header,
                # and finally from environment variables.
                token = None

                try:
                    if request is not None:
                        if hasattr(request, 'client') and request.client is not None:
                            if hasattr(request.client, 'hf_token') and request.client.hf_token:
                                token = request.client.hf_token
                            elif hasattr(request.client, 'token') and request.client.token:
                                token = request.client.token

                        if not token and hasattr(request, 'headers') and request.headers:
                            auth_header = request.headers.get('authorization', '') or request.headers.get('Authorization', '')
                            if auth_header and auth_header.startswith('Bearer '):
                                token = auth_header[7:]
                except Exception as e:
                    logger.debug(f"Could not get token from request: {e}")

                if not token:
                    token = os.getenv("HF_TOKEN") or os.getenv("HUGGING_FACE_HUB_TOKEN")

                education_key = EDUCATION_LEVELS[education]
                return interface.process_question(question, model, education_key, k, temp, max_tok, hf_token=token)
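
            # Hedged note: on Spaces with OAuth enabled, Gradio can also inject
            # the user's token directly by declaring an `oauth_token: gr.OAuthToken`
            # parameter on the handler; the request/header/env fallbacks above
            # are a best-effort alternative when that injection is unavailable.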

            submit_btn.click(
                fn=process_with_education_level,
                inputs=[
                    question_input,
                    model_dropdown,
                    education_dropdown,
                    k_slider,
                    temperature_slider,
                    max_tokens_slider
                ],
                outputs=[
                    answer_output,
                    flesch_output,
                    sources_output,
                    similarity_output,
                    category_output
                ]
            )

            # Also submit when the user presses Enter in the question box.
            question_input.submit(
                fn=process_with_education_level,
                inputs=[
                    question_input,
                    model_dropdown,
                    education_dropdown,
                    k_slider,
                    temperature_slider,
                    max_tokens_slider
                ],
                outputs=[
                    answer_output,
                    flesch_output,
                    sources_output,
                    similarity_output,
                    category_output
                ]
            )

    except Exception as interface_error:
        logger.error(f"Error setting up Gradio interface components: {interface_error}", exc_info=True)
        error_trace = traceback.format_exc()
        with gr.Blocks(title="CGT-LLM-Beta RAG Chatbot") as demo:
            gr.Markdown(f"""
# ⚠️ Interface Setup Error

An error occurred while setting up the interface components.

**Error:** {str(interface_error)}

**Traceback:**
```
{error_trace[:1000]}...
```

Please check the logs for more details.
""")
        return demo

    logger.info("Gradio interface created successfully")
    logger.info(f"Demo type: {type(demo)}, Demo ID: {id(demo)}")
    return demo


# Detect whether we are running on Hugging Face Spaces; the Spaces runtime
# sets these environment variables (e.g. SPACE_ID="username/space-name").
IS_SPACES = (
    os.getenv("SPACE_ID") is not None or
    os.getenv("SYSTEM") == "spaces" or
    os.getenv("HF_SPACE_ID") is not None
)

demo = None
_demo_created = False


def _create_demo():
    """Create the demo - separated into a function for better error handling."""
    global _demo_created, demo
    if _demo_created and demo is not None and isinstance(demo, (gr.Blocks, gr.Interface)):
        logger.warning("Demo already created, skipping...")
        return demo

    _demo_created = True
    try:
        logger.info("=" * 80)
        logger.info("Starting demo creation...")
        logger.info(f"IS_SPACES: {IS_SPACES}")
        logger.info(f"BOT_AVAILABLE: {BOT_AVAILABLE}")

        if not BOT_AVAILABLE:
            raise ImportError("bot module is not available - cannot create demo")

        # Build default arguments without consuming the real command line.
        parser = argparse.ArgumentParser()
        parser.add_argument('--model', type=str, default='meta-llama/Llama-3.2-3B-Instruct')
        parser.add_argument('--vector-db-dir', default='./chroma_db')
        parser.add_argument('--data-dir', default='./Data Resources')
        parser.add_argument('--max-new-tokens', type=int, default=1024)
        parser.add_argument('--temperature', type=float, default=0.2)
        parser.add_argument('--top-p', type=float, default=0.9)
        parser.add_argument('--repetition-penalty', type=float, default=1.1)
        parser.add_argument('--k', type=int, default=5)
        parser.add_argument('--skip-indexing', action='store_true', default=True)
        parser.add_argument('--verbose', action='store_true', default=False)
        parser.add_argument('--seed', type=int, default=42)

        args = parser.parse_args([])
        # On Spaces we rely on the Inference API, so skip loading a local model.
        args.skip_model_loading = IS_SPACES

        logger.info("Creating RAGBot...")
        bot = RAGBot(args)

        if bot.vector_retriever is None:
            raise Exception("Vector database not available")

        collection_stats = bot.vector_retriever.get_collection_stats()
        if collection_stats.get('total_chunks', 0) == 0:
            logger.warning("Vector database is empty. The chatbot may not find relevant documents.")

        logger.info("Creating interface...")
        created_demo = create_interface(bot)
        logger.info(f"Demo created successfully: {type(created_demo)}")
        return created_demo

    except Exception as bot_error:
        logger.error(f"Error initializing: {bot_error}", exc_info=True)
        error_trace = traceback.format_exc()
        logger.error(f"Full traceback: {error_trace}")

        with gr.Blocks(title="CGT-LLM-Beta RAG Chatbot") as error_demo:
            gr.Markdown(f"""
# ⚠️ Initialization Error

The chatbot encountered an error during initialization:

**Error:** {str(bot_error)}

**Possible causes:**
- Missing vector database (chroma_db directory)
- Missing dependencies
- Configuration issues

**Error Details:**
```
{error_trace[:1000]}...
```
""")
        logger.info(f"Error demo created: {type(error_demo)}")
        return error_demo


# Create the demo at import time so the Spaces runtime can find `demo`.
if demo is None or not isinstance(demo, (gr.Blocks, gr.Interface)):
    try:
        if IS_SPACES:
            logger.info("Creating demo directly at module level for Spaces...")
        else:
            logger.info("Creating demo for local execution...")

        demo = _create_demo()

        if demo is None or not isinstance(demo, (gr.Blocks, gr.Interface)):
            raise ValueError(f"Demo creation returned invalid result: {type(demo)}")

        logger.info("Demo creation completed successfully")
    except Exception as e:
        logger.error(f"CRITICAL: Error creating demo: {e}", exc_info=True)
        error_trace = traceback.format_exc()
        logger.error(f"Full traceback: {error_trace}")
        with gr.Blocks(title="CGT-LLM-Beta RAG Chatbot") as demo:
            gr.Markdown(f"""
# Error Initializing Chatbot

A critical error occurred while initializing the chatbot.

**Error:** {str(e)}

**Traceback:**
```
{error_trace[:1500]}...
```

Please check the logs for more details.
""")
        logger.info(f"Fallback error demo created: {type(demo)}")
else:
    logger.info("Demo already exists, skipping creation")

# Final sanity checks so `demo` is always a launchable Gradio object.
if demo is None:
    logger.error("CRITICAL: Demo variable is None! Creating fallback demo.")
    with gr.Blocks(title="CGT-LLM-Beta RAG Chatbot") as demo:
        gr.Markdown("# Error: Demo was not created properly\n\nPlease check the logs for details.")
elif not isinstance(demo, (gr.Blocks, gr.Interface)):
    # Capture the offending type before `demo` is rebound by the context manager.
    bad_type = type(demo)
    logger.error(f"CRITICAL: Demo is not a valid Gradio object: {bad_type}")
    with gr.Blocks(title="CGT-LLM-Beta RAG Chatbot") as demo:
        gr.Markdown(f"# Error: Invalid demo type\n\nDemo type: {bad_type}\n\nPlease check the logs for details.")
else:
    logger.info(f"✅ Final demo check passed: demo type={type(demo)}")

if IS_SPACES:
    logger.info("Spaces mode: Demo is ready and accessible")

    print(f"DEMO_READY: {type(demo)}")
    print(f"DEMO_VALID: {isinstance(demo, (gr.Blocks, gr.Interface))}")

    # Make absolutely sure `demo` is visible in the module namespace.
    sys.modules[__name__].demo = demo
    logger.info("Demo explicitly set in module namespace")


if __name__ == "__main__":
    if not IS_SPACES:
        # Locally we launch the app ourselves; on Spaces the runtime imports
        # `demo` from this module and launches it itself.
        demo.launch()