# Hugging Face Spaces page status at capture time: Build error
import base64
import io
from typing import Optional

import cv2
import gdown
import numpy as np
from PIL import Image
from fastapi import FastAPI, UploadFile, File, Form
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse

from detectron2.config import get_cfg
from detectron2.engine import DefaultPredictor
from detectron2.projects.point_rend import add_pointrend_config
# -------------------------------
# FastAPI setup
# -------------------------------
app = FastAPI(title="Rooftop Segmentation API")

# Accept browser requests from any origin.
# allow_credentials is False because FastAPI's CORS documentation states that
# allow_origins=["*"] cannot be combined with allow_credentials=True; with that
# combination Starlette echoes back whatever Origin the caller sent, which
# grants credentialed access to every site. None of the endpoints in this
# module read cookies or auth headers, so credentials are not needed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
# -------------------------------
# Available epsilons
# -------------------------------
# Suggested polygon-simplification factors. Each is multiplied by the contour
# perimeter to get the cv2.approxPolyDP tolerance (see the post-processing
# functions); larger values give coarser polygons. predict() does not restrict
# epsilon to this list.
EPSILONS = [0.01, 0.005, 0.004, 0.003, 0.001]


# NOTE(review): the route path is assumed -- confirm it matches what clients call.
@app.get("/epsilons")
def get_epsilons():
    """Return the suggested epsilon values as ``{"epsilons": [...]}``.

    A copy of the list is returned so that in-process callers cannot mutate
    the module-level ``EPSILONS`` constant.
    """
    return {"epsilons": list(EPSILONS)}
# -------------------------------
# Detectron2 model setup
# -------------------------------
# Base PointRend config shared by both models.
# NOTE(review): this path is relative to the current working directory, so the
# server must be started from the directory containing the detectron2 checkout.
POINTREND_CFG_PATH = (
    "detectron2/projects/PointRend/configs/InstanceSegmentation/"
    "pointrend_rcnn_R_50_FPN_3x_coco.yaml"
)


def _build_pointrend_predictor(
    weights_path: str,
    num_classes: int,
    score_thresh: float = 0.5,
    device: str = "cpu",
) -> DefaultPredictor:
    """Build a Detectron2 ``DefaultPredictor`` for a PointRend instance-segmentation model.

    Args:
        weights_path: Path to the trained ``.pth`` weights file.
        num_classes: Number of object classes the weights were trained on.
        score_thresh: Detections scoring below this are discarded at test time.
        device: Torch device string to run inference on.
    """
    cfg = get_cfg()
    add_pointrend_config(cfg)
    cfg.merge_from_file(POINTREND_CFG_PATH)
    cfg.MODEL.ROI_HEADS.NUM_CLASSES = num_classes
    # PointRend's point head must be configured with the same class count.
    cfg.MODEL.POINT_HEAD.NUM_CLASSES = num_classes
    cfg.MODEL.WEIGHTS = weights_path
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = score_thresh
    cfg.MODEL.DEVICE = device
    return DefaultPredictor(cfg)


def setup_model_rect(weights_path: str, score_thresh: float = 0.5, device: str = "cpu") -> DefaultPredictor:
    """Build the predictor for the rectangular-rooftop model (2 classes)."""
    return _build_pointrend_predictor(
        weights_path, num_classes=2, score_thresh=score_thresh, device=device
    )


def setup_model_irregular(weights_path: str, score_thresh: float = 0.5, device: str = "cpu") -> DefaultPredictor:
    """Build the predictor for the irregular-flat-rooftop model (1 class)."""
    return _build_pointrend_predictor(
        weights_path, num_classes=1, score_thresh=score_thresh, device=device
    )
# Both predictors are built once, when the module is imported; the request
# handler reuses these module-level instances.
_WEIGHTS_DIR = "/app"
predictor_rect = setup_model_rect(f"{_WEIGHTS_DIR}/model_rect_final.pth")
predictor_irregular_flat = setup_model_irregular(f"{_WEIGHTS_DIR}/model_irregular_flat.pth")
# -------------------------------
# Post-processing functions
# -------------------------------
def _largest_simplified_contour(mask: np.ndarray, epsilon: float) -> Optional[np.ndarray]:
    """Return the largest external contour of *mask*, simplified with ``cv2.approxPolyDP``.

    Args:
        mask: Single-channel mask, expected binary (bool or 0/1). It is scaled
            by 255 and cast to uint8 before contour tracing.
        epsilon: Simplification tolerance as a fraction of the contour's
            perimeter (the tolerance used is ``epsilon * arcLength``).

    Returns:
        The simplified contour as returned by ``cv2.approxPolyDP``, or None if
        the mask contains no contour.

    Raises:
        ValueError: If *epsilon* is negative or not finite. Checked up front
            so callers get a clear error instead of an OpenCV assertion.
    """
    if not np.isfinite(epsilon) or epsilon < 0:
        raise ValueError(f"epsilon must be a finite number >= 0, got {epsilon!r}")
    mask_uint8 = (mask * 255).astype(np.uint8)
    contours, _ = cv2.findContours(mask_uint8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not contours:
        return None
    largest = max(contours, key=cv2.contourArea)
    tolerance = epsilon * cv2.arcLength(largest, True)
    return cv2.approxPolyDP(largest, tolerance, True)


def postprocess_rect(mask: np.ndarray, epsilon: float) -> Optional[np.ndarray]:
    """Simplify the largest region of *mask* and return it re-rasterised as a mask.

    Returns:
        A uint8 mask with the same shape as *mask*: the simplified polygon
        filled with 255 on a 0 background. Returns None if *mask* has no
        contour.

    Raises:
        ValueError: If *epsilon* is negative or not finite.
    """
    approx = _largest_simplified_contour(mask, epsilon)
    if approx is None:
        return None
    simplified = np.zeros(np.shape(mask), dtype=np.uint8)
    cv2.fillPoly(simplified, [approx], 255)
    return simplified


def postprocess_irregular(mask: np.ndarray, epsilon: float) -> Optional[np.ndarray]:
    """Simplify the largest region of *mask* and return its vertices.

    Returns:
        An ``(N, 2)`` array of ``(x, y)`` pixel coordinates, or None if *mask*
        has no contour.

    Raises:
        ValueError: If *epsilon* is negative or not finite.
    """
    approx = _largest_simplified_contour(mask, epsilon)
    if approx is None:
        return None
    return approx.reshape(-1, 2)


def mask_to_polygon(mask: Optional[np.ndarray]) -> Optional[np.ndarray]:
    """Trace the largest external contour of *mask* and return its vertices.

    Any non-zero pixel counts as foreground. The mask is binarised to uint8
    first, so bool masks are accepted as well as 0/1 or 0/255 uint8 masks.

    Returns:
        An ``(N, 2)`` array of ``(x, y)`` pixel coordinates, or None if *mask*
        is None or contains no contour.
    """
    if mask is None:
        return None
    binary = (np.asarray(mask) != 0).astype(np.uint8)
    contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not contours:
        return None
    largest = max(contours, key=cv2.contourArea)
    return largest.reshape(-1, 2)
def im_to_b64_png(im: np.ndarray) -> str:
    """Encode *im* as a PNG and return the PNG bytes as a base64 ASCII string.

    Raises:
        ValueError: If OpenCV reports that PNG encoding failed. Previously the
            failure flag was ignored, so a failed encode was silently returned
            as an empty or invalid image.
    """
    ok, buffer = cv2.imencode(".png", im)
    if not ok:
        raise ValueError("cv2.imencode failed to encode the image as PNG")
    return base64.b64encode(buffer).decode("ascii")
def overlay_polygon(
    im: np.ndarray,
    polygon: Optional[np.ndarray],
    color: tuple = (0, 0, 255),
    thickness: int = 2,
) -> np.ndarray:
    """Return a copy of *im* with *polygon* drawn as a closed outline.

    Args:
        im: Image to draw on; it is copied, never modified.
        polygon: ``(N, 2)`` array of ``(x, y)`` points, or None. When None or
            empty, the copy is returned without drawing anything.
        color: Line color in the image's channel order; the default
            ``(0, 0, 255)`` is red for BGR images.
        thickness: Line thickness in pixels.
    """
    overlay = im.copy()
    if polygon is not None and len(polygon) > 0:
        # cv2.polylines requires int32 point coordinates.
        cv2.polylines(overlay, [polygon.astype(np.int32)], True, color, thickness)
    return overlay
# -------------------------------
# API endpoints
# -------------------------------
# The handlers must be registered on `app`; without a route decorator FastAPI
# never exposes them.
# NOTE(review): the route path is assumed -- confirm it matches what clients call.
@app.get("/")
def root():
    """Return a static message confirming the API process is running."""
    return {"message": "Rooftop Segmentation API is running!"}
# NOTE(review): the route path is assumed -- confirm it matches what clients call.
@app.post("/predict")
async def predict(
    file: UploadFile = File(...),
    rooftop_type: str = Form(...),  # "rectangular" or "irregular"
    epsilon: float = Form(0.004),
):
    """Segment the highest-scoring rooftop in an uploaded image and return its outline.

    Args:
        file: Uploaded image in any format PIL can decode.
        rooftop_type: ``"rectangular"`` or ``"irregular"`` (case-insensitive,
            surrounding whitespace ignored). Selects the model and the
            post-processing applied to its mask.
        epsilon: Polygon simplification factor; it is multiplied by the contour
            perimeter to get the ``cv2.approxPolyDP`` tolerance. Must be finite
            and >= 0.

    Returns:
        A dict with these keys:

        - ``polygon``: list of ``[x, y]`` pixel points, or None.
        - ``image``: base64 PNG of the input with the polygon drawn on it, or
          None when nothing was detected.
        - ``model_used``, ``rooftop_type`` (echoed as sent) and ``epsilon``.

        Invalid input returns a 400 JSON error response instead.
    """
    contents = await file.read()
    try:
        im_pil = Image.open(io.BytesIO(contents)).convert("RGB")
    except Exception as e:
        return JSONResponse(status_code=400, content={"error": "Invalid image", "detail": str(e)})
    im = np.array(im_pil)[:, :, ::-1].copy()  # RGB -> BGR

    # Choose the predictor and the mask -> polygon conversion for the rooftop type.
    kind = rooftop_type.strip().lower()
    if kind == "rectangular":
        predictor = predictor_rect
        model_used = "model_rect_final.pth"

        def to_polygon(mask):
            # Fill the simplified outline into a mask, then trace it back to points.
            # NOTE(review): re-tracing with CHAIN_APPROX_SIMPLE can yield more
            # vertices than approxPolyDP produced (slanted edges become pixel
            # steps) -- confirm whether clients expect the simplified vertices.
            simplified = postprocess_rect(mask, epsilon)
            # postprocess_rect returns None for an empty mask; findContours(None) would raise.
            return None if simplified is None else mask_to_polygon(simplified)
    elif kind == "irregular":
        predictor = predictor_irregular_flat
        model_used = "model_irregular_flat.pth"

        def to_polygon(mask):
            return postprocess_irregular(mask, epsilon)
    else:
        return JSONResponse(status_code=400, content={"error": "Invalid rooftop_type. Choose 'rectangular' or 'irregular'."})

    # Negative or non-finite epsilon would make cv2.approxPolyDP fail with a 500.
    if not np.isfinite(epsilon) or epsilon < 0:
        return JSONResponse(status_code=400, content={"error": "Invalid epsilon. Must be a finite number >= 0."})

    # Model inference is blocking CPU work; run it off the event loop so that
    # other requests are not stalled while it runs.
    outputs = await run_in_threadpool(predictor, im)
    instances = outputs["instances"].to("cpu")
    if len(instances) == 0:
        return {"polygon": None, "image": None, "model_used": model_used, "rooftop_type": rooftop_type, "epsilon": epsilon}

    # Keep only the highest-scoring detection.
    idx = int(instances.scores.argmax().item())
    raw_mask = instances.pred_masks[idx].numpy().astype(np.uint8)

    polygon = to_polygon(raw_mask)

    overlay = overlay_polygon(im, polygon)
    img_b64 = im_to_b64_png(overlay)
    return {
        "polygon": polygon.tolist() if polygon is not None else None,
        "image": img_b64,
        "model_used": model_used,
        "rooftop_type": rooftop_type,
        "epsilon": epsilon,
    }