Spaces:
Build error
Build error
| import numpy.typing as npt | |
| import time | |
| # from .sam import build_sam, SamPredictor | |
| # from .sam_hq import build_sam as build_sam_hq, SamPredictor as SamHqPredictor | |
| # from .mobile_sam import ( | |
| # build_sam_vit_t as build_mobile_sam, | |
| # SamPredictor as MobileSamPredictor, | |
| # ) | |
| # from .per_sam import train, PerSAM | |
| # from .configs import DEVICE | |
| from app.sam import build_sam, SamPredictor | |
| from app.sam_hq import build_sam as build_sam_hq, SamPredictor as SamHqPredictor | |
| from app.mobile_sam import ( | |
| build_sam_vit_t as build_mobile_sam, | |
| SamPredictor as MobileSamPredictor, | |
| ) | |
| from app.per_sam import train, PerSAM | |
| from app.configs import DEVICE | |
def build_sam_predictor(checkpoint: str | None = None):
    """Create a ``SamPredictor`` around a SAM model placed on ``DEVICE``.

    Args:
        checkpoint: Forwarded unchanged to ``build_sam``. How ``None`` is
            handled is decided by ``build_sam`` (presumably "no pretrained
            weights" — confirm there).

    Returns:
        A ``SamPredictor`` wrapping the model returned by ``.to(DEVICE)``.
    """
    model_on_device = build_sam(checkpoint).to(DEVICE)
    return SamPredictor(model_on_device)
def build_sam_hq_predictor(checkpoint: str | None = None):
    """Create a ``SamHqPredictor`` around a SAM-HQ model placed on ``DEVICE``.

    Args:
        checkpoint: Forwarded unchanged to ``build_sam_hq``. How ``None`` is
            handled is decided by that builder (TODO confirm).

    Returns:
        A ``SamHqPredictor`` wrapping the model returned by ``.to(DEVICE)``.
    """
    model_on_device = build_sam_hq(checkpoint).to(DEVICE)
    return SamHqPredictor(model_on_device)
def build_mobile_sam_predictor(checkpoint: str | None = None):
    """Create a ``MobileSamPredictor`` around a MobileSAM (ViT-T) model on ``DEVICE``.

    Args:
        checkpoint: Forwarded unchanged to ``build_mobile_sam`` (an alias of
            ``app.mobile_sam.build_sam_vit_t``). How ``None`` is handled is
            decided by that builder (TODO confirm).

    Returns:
        A ``MobileSamPredictor`` wrapping the model returned by ``.to(DEVICE)``.
    """
    model_on_device = build_mobile_sam(checkpoint).to(DEVICE)
    return MobileSamPredictor(model_on_device)
def get_multi_label_predictor(
    sam: MobileSamPredictor,
    image: npt.NDArray,
    mask: npt.NDArray,
) -> PerSAM:
    """Fit a PerSAM model from a single reference image/mask pair.

    The pair is passed to ``train`` as one-element lists. The wall-clock
    duration of that call is printed to stdout.

    Args:
        sam: Predictor handed to both ``train`` and ``PerSAM``.
        image: Reference image array.
        mask: Reference mask array paired with ``image``.

    Returns:
        A ``PerSAM`` built from ``sam``, the learned target features and the
        learned weights. The positional values ``10, 0.4, 0.2`` are
        hyperparameters whose meaning is defined by ``PerSAM`` (TODO document
        them there).
    """
    t0 = time.perf_counter()
    learned_weights, target_features = train(sam, [image], [mask])
    elapsed = time.perf_counter() - t0
    print(f"training time {elapsed}")
    return PerSAM(sam, target_features, 10, 0.4, 0.2, learned_weights)
if __name__ == "__main__":
    import argparse

    import numpy as np
    from PIL import Image
    from torchvision.transforms import InterpolationMode
    from torchvision.transforms.functional import resize
    from app.transforms import ResizeLongestSide

    # Smoke test / benchmark: fit a PerSAM predictor on one image/mask pair,
    # then time a prediction on the same image. The defaults are the original
    # hard-coded development paths; pass flags to run on other data.
    parser = argparse.ArgumentParser(
        description="Train a PerSAM predictor on one image/mask pair and time prediction."
    )
    parser.add_argument(
        "--image",
        default="/Users/dillonlaird/code/instance_labeler/seals-labeled/img2.png",
        help="Image used for both training and prediction.",
    )
    parser.add_argument(
        "--mask",
        default="/Users/dillonlaird/code/instance_labeler/seals-labeled/img2.seal.10.png",
        help="Mask image paired with --image.",
    )
    parser.add_argument(
        "--checkpoint",
        default="/Users/dillonlaird/code/instance_labeler/mobile_sam.pth",
        help="MobileSAM checkpoint path.",
    )
    args = parser.parse_args()

    T = ResizeLongestSide(1024)

    # convert() returns a new, fully loaded image, so the file can be closed
    # right away.
    with Image.open(args.image) as img:
        image = img.convert("RGB")
    # PIL reports size as (width, height), so the height is passed first.
    target_size = T.get_preprocess_shape(image.size[1], image.size[0], T.target_length)
    image_np = np.array(resize(image, target_size))

    # Convert to grayscale before resizing, and use nearest-neighbour sampling
    # so the mask keeps its original pixel values. torchvision's default
    # bilinear resize would blend values along the mask boundary.
    with Image.open(args.mask) as img:
        mask = img.convert("L")
    target_size = T.get_preprocess_shape(mask.size[1], mask.size[0], T.target_length)
    mask_np = np.array(
        resize(mask, target_size, interpolation=InterpolationMode.NEAREST)
    )

    model = build_mobile_sam_predictor(args.checkpoint)

    # get_multi_label_predictor already prints the training time. This outer
    # measurement also covers PerSAM construction, so it gets its own label.
    start = time.perf_counter()
    per_sam_model = get_multi_label_predictor(model, image_np, mask_np)
    print(f"predictor build time {time.perf_counter() - start}")

    start = time.perf_counter()
    masks, bboxes, _ = per_sam_model(image_np)
    print(f"prediction time {time.perf_counter() - start}")