# unified-analysis-for-legal-docs / document_processor.py
import time
import asyncio
import numpy as np
from concurrent.futures import ThreadPoolExecutor
from typing import List, Dict, Any, Tuple
from chunker import DocumentChunker
from summarizer import DocumentSummarizer
from risk_detector import RiskDetector
from clause_tagger import ClauseTagger
from models import ChunkInput, SummarizeBatchInput
def clean_numpy(obj):
"""Recursively convert NumPy types to native Python types"""
if isinstance(obj, np.generic):
return obj.item()
elif isinstance(obj, np.ndarray):
return obj.tolist()
elif isinstance(obj, dict):
return {k: clean_numpy(v) for k, v in obj.items()}
elif isinstance(obj, list):
return [clean_numpy(v) for v in obj]
else:
return obj
class DocumentProcessor:
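    """Orchestrates the legal-document pipeline.

    Chunks the text by tokens, generates one shared set of chunk embeddings,
    then runs summarization, risk detection, and clause tagging in parallel.
    Per-document results are cached in memory, keyed by doc_id.
    """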
def __init__(self):
self.chunker = None
self.summarizer = None
self.risk_detector = None
self.clause_tagger = None
self.cache = {}
self.executor = ThreadPoolExecutor(max_workers=3)
async def initialize(self):
"""Initialize all components"""
print(" Initializing Document Processor...")
self.chunker = DocumentChunker()
self.summarizer = DocumentSummarizer()
self.risk_detector = RiskDetector()
self.clause_tagger = ClauseTagger()
# Initialize models in parallel for faster startup
init_tasks = [
self.summarizer.initialize(),
self.clause_tagger.initialize()
]
await asyncio.gather(*init_tasks)
print(" Document Processor initialized")
async def process_document(self, text: str, doc_id: str) -> Tuple[Dict[str, Any], List[Dict]]:
"""Process document with optimized single embedding generation"""
# Check cache first
if doc_id in self.cache:
print(f" Using cached result for doc_id: {doc_id}")
return self.cache[doc_id]
print(f" Processing new document: {doc_id}")
start_time = time.time()
try:
# Step 1: Chunk the document
chunks = self.chunker.chunk_by_tokens(text, max_tokens=1600, stride=50)
print(f" Created {len(chunks)} chunks in {time.time() - start_time:.2f}s")
# Step 2: Generate embeddings
print(f" Generating embeddings for {len(chunks)} chunks...")
embedding_start = time.time()
if self.clause_tagger.embedding_model:
chunk_embeddings = self.clause_tagger.embedding_model.encode(chunks)
embedding_time = time.time() - embedding_start
print(f" Generated embeddings in {embedding_time:.2f}s")
# Convert embeddings to lists to avoid NumPy serialization issues
chunk_data = [
{"text": chunk, "embedding": embedding.tolist()}
for chunk, embedding in zip(chunks, chunk_embeddings)
]
else:
chunk_data = [{"text": chunk, "embedding": None} for chunk in chunks]
embedding_time = 0
print(" No embedding model available")
# Step 3: Run analysis tasks in parallel
tasks = []
# Task 1: Summarization (async)
summary_task = asyncio.create_task(
self.summarizer.batch_summarize(chunks)
)
tasks.append(('summary', summary_task))
# Task 2: Risk detection (CPU-bound)
            risk_task = asyncio.get_running_loop().run_in_executor(
self.executor,
self.risk_detector.detect_risks,
chunks
)
tasks.append(('risks', risk_task))
# Task 3: Clause tagging (async, uses embeddings)
if self.clause_tagger.embedding_model and chunk_data[0]["embedding"] is not None:
clause_task = asyncio.create_task(
self.clause_tagger.tag_clauses_with_embeddings(chunk_data)
)
tasks.append(('clauses', clause_task))
print(f" Starting {len(tasks)} parallel analysis tasks...")
# Wait for all tasks
results = {}
for task_name, task in tasks:
try:
print(f" Waiting for {task_name} analysis...")
results[task_name] = await task
print(f" {task_name} completed")
except Exception as e:
print(f" {task_name} analysis failed: {e}")
# Fallback results
if task_name == 'summary':
results[task_name] = {"actual_summary": "Summary generation failed", "short_summary": "Summary failed"}
elif task_name == 'risks':
results[task_name] = []
elif task_name == 'clauses':
results[task_name] = []
# Step 4: Combine results
processing_time = time.time() - start_time
result = {
"summary": results.get('summary', {"actual_summary": "Summary not available", "short_summary": "Summary not available"}),
"risky_terms": results.get('risks', []),
"key_clauses": results.get('clauses', []),
"chunk_count": len(chunks),
"processing_time": f"{processing_time:.2f}s",
"embedding_time": f"{embedding_time:.2f}s",
"embeddings_generated": len(chunk_embeddings) if 'chunk_embeddings' in locals() else 0,
"doc_id": doc_id,
"parallel_tasks_completed": len([r for r in results.values() if r])
}
# Step 5: Clean NumPy data before caching/returning
cleaned_result = clean_numpy(result)
cleaned_chunk_data = clean_numpy(chunk_data)
# Cache results
self.cache[doc_id] = (cleaned_result, cleaned_chunk_data)
print(f"πŸŽ‰ Document processing completed in {processing_time:.2f}s")
return cleaned_result, cleaned_chunk_data
except Exception as e:
error_time = time.time() - start_time
print(f"❌ Document processing failed after {error_time:.2f}s: {e}")
error_result = {
"error": str(e),
"summary": {"actual_summary": "Processing failed", "short_summary": "Processing failed"},
"risky_terms": [],
"key_clauses": [],
"chunk_count": 0,
"processing_time": f"{error_time:.2f}s",
"doc_id": doc_id
}
return clean_numpy(error_result), []
def chunk_text(self, data: ChunkInput) -> Dict[str, Any]:
"""Standalone chunking endpoint"""
start = time.time()
try:
chunks = self.chunker.chunk_by_tokens(data.text, data.max_tokens, data.stride)
return {
"chunks": chunks,
"chunk_count": len(chunks),
"time_taken": f"{time.time() - start:.2f}s",
"status": "success"
}
except Exception as e:
return {
"error": str(e),
"chunks": [],
"chunk_count": 0,
"time_taken": f"{time.time() - start:.2f}s",
"status": "failed"
}
def summarize_batch(self, data: SummarizeBatchInput) -> Dict[str, Any]:
"""Standalone batch summarization endpoint"""
start = time.time()
try:
result = self.summarizer.summarize_texts_sync(data.texts, data.max_length, data.min_length)
result["time_taken"] = f"{time.time() - start:.2f}s"
result["status"] = "success"
return clean_numpy(result)
except Exception as e:
return {
"error": str(e),
"summary": "Summarization failed",
"time_taken": f"{time.time() - start:.2f}s",
"status": "failed"
}
def get_cache_stats(self) -> Dict[str, Any]:
"""Get cache statistics for monitoring"""
return {
"cached_documents": len(self.cache),
"cache_keys": list(self.cache.keys())
}
def clear_cache(self) -> Dict[str, str]:
"""Clear the document cache"""
cleared_count = len(self.cache)
self.cache.clear()
return {
"message": f"Cleared {cleared_count} cached documents",
"status": "success"
}
def __del__(self):
"""Cleanup thread pool on destruction"""
if hasattr(self, 'executor'):
            # wait=False: avoid blocking garbage collection or interpreter shutdown on in-flight tasks
            self.executor.shutdown(wait=False)
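# Minimal usage sketch, not part of the service wiring: it assumes the sibling
# modules (chunker, summarizer, risk_detector, clause_tagger, models) are
# importable and that their model weights can be loaded in this environment.
# The sample text and doc_id below are illustrative only.
if __name__ == "__main__":
    async def _demo():
        processor = DocumentProcessor()
        await processor.initialize()
        sample = (
            "This Agreement may be terminated by either party upon thirty (30) days "
            "written notice. The Supplier shall indemnify the Client against all claims."
        )
        result, chunks_with_embeddings = await processor.process_document(sample, doc_id="demo-001")
        print(result["chunk_count"], result["processing_time"], len(chunks_with_embeddings))

    asyncio.run(_demo())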