# clause_tagger.py - Enhanced clause tagging with InLegalBERT
import csv
import os
from typing import Any, Dict, List

import numpy as np
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity


class ClauseTagger:
    """Tag document chunks with the reference clauses they most resemble.

    Reference clauses are read from a CSV file (or a small built-in list when
    the file is missing or unreadable), embedded once in ``initialize`` and
    then compared with chunk embeddings by cosine similarity.
    """

    # A chunk/clause pair is reported only when its cosine similarity is
    # strictly greater than this value.
    SIMILARITY_THRESHOLD = 0.65

    # Chunks longer than this are truncated (with a trailing '...') in the
    # 'matched_text' field of a match.
    MATCH_PREVIEW_CHARS = 200

    # Maximum number of matches returned by each public tagging method.
    MAX_MATCHES = 20
    MAX_MATCHES_PRECOMPUTED = 15

    # Each keyword found in the chunk adds 0.05 to the confidence.
    LEGAL_KEYWORDS = (
        'shall', 'hereby', 'whereas', 'therefore', 'notwithstanding',
        'party', 'agreement', 'contract', 'clause', 'provision',
    )

    # Each indicator found in the chunk adds 0.08 to the confidence.
    CLAUSE_INDICATORS = (
        'penalty', 'fine', 'terminate', 'liable', 'payment',
        'confidential', 'breach', 'default', 'obligation',
    )

    def __init__(self):
        # Embedding model; stays None until initialize() loads one.
        self.embedding_model = None
        # Reference clause dicts; each gains an 'embedding' key in initialize().
        self.clause_reference = []

    async def initialize(self):
        """Load the embedding model and pre-embed the clause references.

        Does nothing when a model is already loaded, so it is safe to call
        more than once. InLegalBERT is tried first and all-MiniLM-L6-v2 is
        the fallback; when both fail the tagger is left without a model and
        the tagging methods return empty lists.
        """
        if self.embedding_model is not None:
            return

        print("🧠 Loading embedding model for clause tagging...")

        # Set cache directory explicitly for HF Spaces
        cache_folder = "/tmp/sentence_transformers_cache"
        os.makedirs(cache_folder, exist_ok=True)

        try:
            # Use a legal-domain model with explicit cache directory
            self.embedding_model = SentenceTransformer(
                'law-ai/InLegalBERT',
                cache_folder=cache_folder
            )
            print("✅ InLegalBERT embedding model loaded")
        except Exception as e:
            print(f"⚠️ Failed to load InLegalBERT: {e}")
            # Fallback to a general model
            try:
                self.embedding_model = SentenceTransformer(
                    'all-MiniLM-L6-v2',
                    cache_folder=cache_folder
                )
                print("✅ Fallback embedding model loaded (all-MiniLM-L6-v2)")
            except Exception as fallback_error:
                print(f"❌ Failed to load fallback model: {fallback_error}")
                self.embedding_model = None
                return

        # Load clause references
        self.clause_reference = self._load_clause_reference()
        if not self.clause_reference:
            return

        # Pre-embed clause references so tagging only has to embed chunks
        clause_texts = [clause['text'] for clause in self.clause_reference]
        try:
            clause_embeddings = self.embedding_model.encode(clause_texts)
            for clause, embedding in zip(self.clause_reference, clause_embeddings):
                clause['embedding'] = embedding
            print(f"📋 Loaded and embedded {len(self.clause_reference)} clause references")
        except Exception as e:
            # Without embeddings the references are unusable, so drop them.
            print(f"⚠️ Failed to embed clause references: {e}")
            self.clause_reference = []

    def _load_clause_reference(self) -> List[Dict[str, Any]]:
        """Load clause reference rows from the CSV file.

        Returns:
            One dict per usable row with the keys 'id', 'type', 'text' and
            'category'. Rows without any text are skipped because there is
            nothing to embed for them. The built-in defaults are returned
            when the file is missing or cannot be read.
        """
        # NOTE(review): the spelling "refrence" is kept as-is; presumably it
        # matches the name of the file on disk - confirm before renaming.
        clause_file = "clause_refrence.csv"

        if not os.path.exists(clause_file):
            print(f"⚠️ Clause reference file not found: {clause_file}")
            # Create minimal clause references if file doesn't exist
            return self._create_default_clauses()

        clauses = []
        try:
            # newline='' lets the csv module handle line endings inside
            # quoted fields itself, as the csv documentation requires.
            with open(clause_file, 'r', encoding='utf-8', newline='') as f:
                for row in csv.DictReader(f):
                    # DictReader fills the missing columns of a short row
                    # with None, so a .get() default is never applied;
                    # coalesce explicitly instead.
                    text = row.get('text') or ''
                    if not text.strip():
                        continue
                    clauses.append({
                        'id': row.get('id') or '',
                        'type': row.get('type') or '',
                        'text': text,
                        'category': row.get('category') or 'general',
                    })
        except Exception as e:
            print(f"❌ Error loading clause reference: {e}")
            return self._create_default_clauses()

        return clauses

    def _create_default_clauses(self) -> List[Dict[str, Any]]:
        """Create default clause references if CSV file is not available"""
        return [
            {
                'id': 'penalty_1',
                'type': 'penalty_clause',
                'text': 'penalty for breach of contract terms and conditions',
                'category': 'penalty'
            },
            {
                'id': 'termination_1',
                'type': 'termination_clause',
                'text': 'termination of agreement upon breach or default',
                'category': 'termination'
            },
            {
                'id': 'liability_1',
                'type': 'liability_clause',
                'text': 'liability for damages and compensation obligations',
                'category': 'liability'
            },
            {
                'id': 'payment_1',
                'type': 'payment_clause',
                'text': 'payment terms conditions and default provisions',
                'category': 'payment'
            },
            {
                'id': 'confidentiality_1',
                'type': 'confidentiality_clause',
                'text': 'confidentiality and non-disclosure obligations',
                'category': 'confidentiality'
            }
        ]

    def _match_chunk(
        self,
        chunk_idx: int,
        chunk_text: str,
        chunk_embedding: Any,
    ) -> List[Dict[str, Any]]:
        """Return a match dict for every reference clause similar to one chunk.

        Args:
            chunk_idx: Position of the chunk in the caller's input list.
            chunk_text: Raw text of the chunk.
            chunk_embedding: Embedding vector of the chunk.

        Returns:
            One dict per reference clause whose cosine similarity with the
            chunk exceeds ``SIMILARITY_THRESHOLD``, in reference order.
        """
        if len(chunk_text) > self.MATCH_PREVIEW_CHARS:
            preview = chunk_text[:self.MATCH_PREVIEW_CHARS] + '...'
        else:
            preview = chunk_text

        matches = []
        for clause in self.clause_reference:
            if 'embedding' not in clause:
                continue

            # Convert the NumPy scalar to a built-in float up front so every
            # numeric field of the match is JSON-serializable.
            similarity = float(
                cosine_similarity([chunk_embedding], [clause['embedding']])[0][0]
            )

            # Only include matches above threshold
            if similarity <= self.SIMILARITY_THRESHOLD:
                continue

            matches.append({
                'clause_id': clause['id'],
                'clause_type': clause['type'],
                'clause_category': clause['category'],
                'matched_text': preview,
                'similarity_score': similarity,
                'chunk_index': chunk_idx,
                'reference_text': clause['text'],
                'confidence': self._calculate_clause_confidence(similarity, chunk_text)
            })
        return matches

    async def tag_clauses(self, chunks: List[str]) -> List[Dict[str, Any]]:
        """Tag clauses in document chunks - GENERATES NEW EMBEDDINGS

        Args:
            chunks: Text chunks to embed and compare with the references.

        Returns:
            Up to 20 match dicts sorted by descending similarity score, or
            an empty list when the tagger is not initialized or tagging
            fails.
        """
        if not self.clause_reference or self.embedding_model is None:
            print("⚠️ No clause references or embedding model available")
            return []

        print(f"🏷️ Tagging clauses in {len(chunks)} chunks...")

        try:
            # Embed all chunks
            chunk_embeddings = self.embedding_model.encode(chunks)

            tagged_clauses = []
            for chunk_idx, (chunk, chunk_embedding) in enumerate(
                zip(chunks, chunk_embeddings)
            ):
                tagged_clauses.extend(
                    self._match_chunk(chunk_idx, chunk, chunk_embedding)
                )

            # Sort by similarity score and return top matches
            tagged_clauses.sort(key=lambda x: x['similarity_score'], reverse=True)
            return tagged_clauses[:self.MAX_MATCHES]
        except Exception as e:
            print(f"❌ Clause tagging failed: {e}")
            return []

    async def tag_clauses_with_embeddings(self, chunk_data: List[Dict]) -> List[Dict[str, Any]]:
        """Tag clauses using pre-computed embeddings - OPTIMIZED VERSION

        Args:
            chunk_data: One dict per chunk with a "text" entry and an
                "embedding" entry. Chunks whose embedding is missing or
                None are skipped.

        Returns:
            Up to 15 match dicts sorted by descending similarity score and
            then confidence, or an empty list when the tagger is not
            initialized or tagging fails.
        """
        if not self.clause_reference or self.embedding_model is None:
            print("⚠️ No clause references or embedding model available")
            return []

        print(f"🏷️ Tagging clauses using pre-computed embeddings for {len(chunk_data)} chunks...")

        try:
            tagged_clauses = []
            for chunk_idx, chunk_info in enumerate(chunk_data):
                # A missing key is treated like an explicit None so that one
                # incomplete chunk does not discard the matches of the rest.
                chunk_embedding = chunk_info.get("embedding")
                if chunk_embedding is None:
                    continue

                # Find best matching clauses using pre-computed embedding
                tagged_clauses.extend(
                    self._match_chunk(
                        chunk_idx,
                        chunk_info.get("text") or "",
                        chunk_embedding,
                    )
                )

            # Sort by similarity score and confidence
            tagged_clauses.sort(
                key=lambda x: (x['similarity_score'], x['confidence']),
                reverse=True,
            )
            return tagged_clauses[:self.MAX_MATCHES_PRECOMPUTED]
        except Exception as e:
            print(f"❌ Optimized clause tagging failed: {e}")
            return []

    def _calculate_clause_confidence(self, similarity: float, text: str) -> float:
        """Calculate confidence score for clause matching

        Starts from the similarity and adds 0.05 for every legal keyword and
        0.08 for every clause indicator that occurs (as a substring,
        case-insensitively) in ``text``.

        Returns:
            The boosted score capped at 1.0, always as a built-in ``float``
            even when ``similarity`` is a NumPy scalar.
        """
        # Base confidence from similarity; float() strips any NumPy scalar
        # type so the result stays JSON-serializable.
        confidence = float(similarity)
        text_lower = text.lower()

        # Boost confidence for legal keywords
        for keyword in self.LEGAL_KEYWORDS:
            if keyword in text_lower:
                confidence += 0.05

        # Boost for specific clause indicators
        for indicator in self.CLAUSE_INDICATORS:
            if indicator in text_lower:
                confidence += 0.08

        return min(confidence, 1.0)