basiliskan committed on
Commit
7f44b94
·
verified ·
1 Parent(s): 2a30887

Create handler.py

Browse files
Files changed (1) hide show
  1. handler.py +94 -0
handler.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Any
2
+ from PIL import Image
3
+ import base64
4
+ import io
5
+ import os
6
+ import torch
7
+
8
class EndpointHandler:
    """Custom endpoint handler that runs DocLayout-YOLO on a single image.

    The handler is constructed once with the repository path, then invoked
    per request with a payload dictionary. Each call returns the detected
    layout regions as plain JSON-serialisable dictionaries.
    """

    def __init__(self, path=""):
        """Load the model checkpoint stored under ``path`` and select a device.

        Args:
            path: Directory that contains
                ``doclayout_yolo_docstructbench_imgsz1024.pt``.
        """
        # Imported here rather than at module level, so importing this module
        # does not by itself require doclayout_yolo.
        from doclayout_yolo import YOLOv10

        # Load model from repo path
        model_path = os.path.join(path, "doclayout_yolo_docstructbench_imgsz1024.pt")
        self.model = YOLOv10(model_path)

        # Class id -> human readable label; ids missing from this table are
        # reported as "class_<id>" by _format_detections.
        self.id_to_names = {
            0: 'title',
            1: 'plain_text',
            2: 'abandon',
            3: 'figure',
            4: 'figure_caption',
            5: 'table',
            6: 'table_caption',
            7: 'table_footnote',
            8: 'isolate_formula',
            9: 'formula_caption',
        }

        # Use the GPU when torch reports one, otherwise fall back to the CPU.
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Process an image and return layout detections.

        Args:
            data: Dictionary with:
                - "inputs": base64 encoded image string (optionally with a
                  data-URL prefix), raw encoded image bytes, or an image
                  object that the model accepts directly (e.g. a PIL Image).
                - "parameters" (optional): {
                      "confidence": float (default 0.2),
                      "iou_threshold": float (default 0.45)
                  }

        Returns:
            List of detections sorted by descending score. Each entry has a
            "label", a "score" rounded to 4 decimals and a "box" with
            "x1", "y1", "x2", "y2" rounded to 2 decimals.

        Raises:
            ValueError: If "inputs" is missing or None, or if a string input
                is not valid base64 (``binascii.Error`` is a ``ValueError``).
        """
        image = data.get("inputs")
        if image is None:
            raise ValueError('Request payload is missing the "inputs" image')

        # "parameters" may be absent or explicitly null in the JSON payload;
        # both cases fall back to the defaults.
        params = data.get("parameters") or {}
        conf_threshold = params.get("confidence")
        if conf_threshold is None:
            conf_threshold = 0.2
        iou_threshold = params.get("iou_threshold")
        if iou_threshold is None:
            iou_threshold = 0.45

        image = self._load_image(image)

        # Run inference; a single image goes in, so only the first result is used.
        results = self.model.predict(
            image,
            imgsz=1024,
            conf=conf_threshold,
            iou=iou_threshold,
            device=self.device,
        )[0]

        return self._format_detections(results.boxes)

    @staticmethod
    def _load_image(image):
        """Decode base64 strings / raw bytes into a PIL image; pass others through."""
        if isinstance(image, str):
            # Remove data URL prefix if present ("data:image/png;base64,....")
            if "base64," in image:
                image = image.split("base64,", 1)[1]
            image = base64.b64decode(image)
        if isinstance(image, (bytes, bytearray)):
            return Image.open(io.BytesIO(image))
        return image

    def _format_detections(self, boxes) -> List[Dict[str, Any]]:
        """Convert the model's boxes container into sorted detection dicts."""
        if boxes is None:
            return []

        # Convert each tensor once instead of calling .item() per element,
        # which avoids one host transfer per value when running on a GPU.
        coords = boxes.xyxy.tolist()
        scores = boxes.conf.tolist()
        class_ids = boxes.cls.tolist()

        detections = []
        for (x1, y1, x2, y2), score, raw_cls in zip(coords, scores, class_ids):
            cls_id = int(raw_cls)
            detections.append({
                "label": self.id_to_names.get(cls_id, f"class_{cls_id}"),
                "score": round(float(score), 4),
                "box": {
                    "x1": round(float(x1), 2),
                    "y1": round(float(y1), 2),
                    "x2": round(float(x2), 2),
                    "y2": round(float(y2), 2),
                },
            })

        # Sort by confidence score, highest first (stable for equal scores).
        detections.sort(key=lambda det: det["score"], reverse=True)
        return detections