Spaces:
Runtime error
Runtime error
| # -*- encoding: utf-8 -*- | |
| # @Author: SWHL | |
| # @Contact: [email protected] | |
| from enum import Enum | |
| from pathlib import Path | |
| from typing import List, Union | |
| import gradio as gr | |
| import numpy as np | |
| from rapidocr import RapidOCR | |
class InferEngine(Enum):
    """Inference backends selectable in the demo.

    Each member's value is the human-readable name shown in the UI dropdown.
    """

    ort = "ONNXRuntime"
    vino = "OpenVino"
    paddle = "PaddlePaddle"
    torch = "PyTorch"
def get_ocr_engine(infer_engine: str, lang_det: str, lang_rec: str) -> RapidOCR:
    """Create a RapidOCR instance for the chosen backend and models.

    Args:
        infer_engine: Display name of the backend (one of the ``InferEngine``
            values). Any unrecognised name falls back to the ONNX flag.
        lang_det: Value forwarded as ``Global.lang_det``.
        lang_rec: Value forwarded as ``Global.lang_rec``.

    Returns:
        A ``RapidOCR`` object built with the selected ``Global.with_*`` flag
        enabled.
    """
    engine_flags = {
        InferEngine.ort.value: "with_onnx",
        InferEngine.vino.value: "with_openvino",
        InferEngine.paddle.value: "with_paddle",
        InferEngine.torch.value: "with_torch",
    }
    # Unknown engine names keep the original behaviour: use the ONNX backend.
    param_key = engine_flags.get(infer_engine, "with_onnx")

    params = {
        f"Global.{param_key}": True,
        "Global.lang_det": lang_det,
        "Global.lang_rec": lang_rec,
    }
    return RapidOCR(params=params)
def get_ocr_result(
    img: np.ndarray,
    text_score,
    box_thresh,
    unclip_ratio,
    lang_det,
    lang_rec,
    infer_engine,
    is_word: str,
    use_module: List[str],
):
    """Run OCR on ``img`` and shape the output for the Gradio components.

    Args:
        img: Image supplied by the image input component.
        text_score: Forwarded to the engine call as ``text_score``.
        box_thresh: Forwarded to the engine call as ``box_thresh``.
        unclip_ratio: Forwarded to the engine call as ``unclip_ratio``.
        lang_det: Detection model name passed to ``get_ocr_engine``.
        lang_rec: Recognition model name passed to ``get_ocr_engine``.
        infer_engine: Backend display name passed to ``get_ocr_engine``.
        is_word: Radio value; word-level boxes are requested when it
            contains ``"Yes"``.
        use_module: Selected module names; any of ``"use_det"``,
            ``"use_cls"``, ``"use_rec"``.

    Returns:
        A tuple ``(vis_img, ocr_txts, elapse)`` where ``ocr_txts`` is a list
        of ``[index, text, score]`` rows (empty when nothing was recognised
        or recognition is disabled).

    Raises:
        gr.Error: If no image was provided.
    """
    if img is None:
        # Surface a readable message in the UI instead of an engine traceback.
        raise gr.Error("Please upload or select an image first.")

    # Guard against an unset radio / checkbox group.
    return_word_box = bool(is_word) and "Yes" in is_word
    selected_modules = use_module or []
    use_det = "use_det" in selected_modules
    use_cls = "use_cls" in selected_modules
    use_rec = "use_rec" in selected_modules

    ocr_engine = get_ocr_engine(infer_engine, lang_det=lang_det, lang_rec=lang_rec)
    ocr_result = ocr_engine(
        img,
        use_det=use_det,
        use_cls=use_cls,
        use_rec=use_rec,
        text_score=text_score,
        box_thresh=box_thresh,
        unclip_ratio=unclip_ratio,
        return_word_box=return_word_box,
    )
    vis_img = ocr_result.vis()
    elapse = ocr_result.elapse

    if return_word_box:
        # The original unpacked three columns (text, score, box) and crashed
        # with ValueError when there were no word results at all.
        word_results = getattr(ocr_result, "word_results", None) or ()
        ocr_txts = [
            [i, txt, score] for i, (txt, score, *_) in enumerate(word_results)
        ]
        return vis_img, ocr_txts, elapse

    if not use_rec:
        return vis_img, [], elapse

    # txts/scores may be missing or None when no text was recognised.
    txts = getattr(ocr_result, "txts", None) or ()
    scores = getattr(ocr_result, "scores", None) or ()
    ocr_txts = [[i, txt, score] for i, (txt, score) in enumerate(zip(txts, scores))]
    return vis_img, ocr_txts, elapse
def create_examples() -> List[List[Union[str, float]]]:
    """Build the example rows offered below the demo.

    Each row is ``[image_path, text_score, box_thresh, unclip_ratio,
    lang_det, lang_rec, infer_engine, is_word]``. Rows start from a shared
    set of defaults; per-image overrides replace the detection and
    recognition model names where a different language is needed.

    Returns:
        One list per example image, in the order the images are declared.
    """
    # Positions inside a finished row (index 0 is the image path).
    lang_det_pos = 4
    lang_rec_pos = 5

    default_values = [0.5, 0.5, 1.6, "ch_mobile", "ch_mobile", "ONNXRuntime", "No"]
    image_specs = [
        ("images/ch_en_num.jpg", {}),
        (
            "images/japan.jpg",
            {lang_det_pos: "multi_mobile", lang_rec_pos: "japan_mobile"},
        ),
        (
            "images/korean.jpg",
            {lang_det_pos: "multi_mobile", lang_rec_pos: "korean_mobile"},
        ),
        ("images/air_ticket.jpg", {}),
        ("images/car_plate.jpeg", {}),
        ("images/train_ticket.jpeg", {}),
    ]

    examples = []
    for image_path, overrides in image_specs:
        row = [image_path, *default_values]
        for position, value in overrides.items():
            row[position] = value
        examples.append(row)
    return examples
# Dropdown choices: display names of every supported inference backend.
infer_engine_list = [engine.value for engine in InferEngine]

# Detection model names offered in the "Det model" dropdown.
lang_det_list = ["ch_mobile", "ch_server", "en_mobile", "en_server", "multi_mobile"]

# Recognition model names offered in the "Rec model" dropdown.
lang_rec_list = [
    "ch_mobile",
    "ch_server",
    "ch_doc_server",
    "chinese_cht",
    "en_mobile",
    "ar_mobile",
    "cyrillic_mobile",
    "devanagari_mobile",
    "japan_mobile",
    "ka_mobile",
    "korean_mobile",
    "latin_mobile",
    "ta_mobile",
    "te_mobile",
]
# Stylesheet for the demo page. The first rule was previously duplicated into
# itself ("body {font-family: body {font-family: ..."), which is invalid CSS.
custom_css = """
body {font-family: 'Helvetica Neue', Helvetica;}
.gr-button {background-color: #4CAF50; color: white; border: none; padding: 10px 20px; border-radius: 5px;}
.gr-button:hover {background-color: #45a049;}
.gr-textbox {margin-bottom: 15px;}
.example-button {background-color: #1E90FF; color: white; border: none; padding: 8px 15px; border-radius: 5px; margin: 5px;}
.example-button:hover {background-color: #FF4500;}
.tall-radio .gr-radio-item {padding: 15px 0; min-height: 50px; display: flex; align-items: center;}
.tall-radio label {font-size: 16px;}
.output-image, .input-image, .image-preview {height: 300px !important}
"""
# Build the demo page: header badges, tuning sliders, model/engine selectors,
# image input/output, the results table and the clickable examples.
# NOTE(review): nesting was reconstructed from a paste that lost indentation;
# in particular, confirm which components belong inside each gr.Row().
with gr.Blocks(
    # Pass the stylesheet itself; the string "custom_css" is not valid CSS and
    # meant none of the custom rules were ever applied.
    title="Rapid⚡OCR Demo", css=custom_css, theme=gr.themes.Soft()
) as demo:
    gr.HTML(
        """
        <h1 style='text-align: center;font-size:40px'>Rapid⚡OCRv2.1.0</h1>
        <div style="display: flex; justify-content: center; gap: 10px;">
            <a href=""><img src="https://img.shields.io/badge/Python->=3.6-aff.svg"></a>
            <a href="https://rapidai.github.io/RapidOCRDocs"><img src="https://img.shields.io/badge/Docs-link-aff.svg"></a>
            <a href=""><img src="https://img.shields.io/badge/OS-Linux%2C%20Win%2C%20Mac-pink.svg"></a>
            <a href="https://pepy.tech/project/rapidocr"><img src="https://static.pepy.tech/personalized-badge/rapidocr?period=total&units=abbreviation&left_color=grey&right_color=blue&left_text=Downloads%20rapidocr"></a>
            <a href="https://pypi.org/project/rapidocr/"><img alt="PyPI" src="https://img.shields.io/pypi/v/rapidocr"></a>
            <a href="https://github.com/RapidAI/RapidOCR"><img src="https://img.shields.io/github/stars/RapidAI/RapidOCR?color=ccf"></a>
        </div>
        """
    )

    # Row 1: numeric thresholds forwarded to the OCR call.
    with gr.Row():
        text_score = gr.Slider(
            label="text_score",
            minimum=0,
            maximum=1.0,
            value=0.5,
            step=0.1,
            info="文本识别结果是正确的置信度,值越大,显示出的识别结果更准确。存在漏检时,调低该值。取值范围:[0, 1.0],默认值为0.5",
        )
        box_thresh = gr.Slider(
            label="box_thresh",
            minimum=0,
            maximum=1.0,
            value=0.5,
            step=0.1,
            info="检测到的框是文本的概率,值越大,框中是文本的概率就越大。存在漏检时,调低该值。取值范围:[0, 1.0],默认值为0.5",
        )
        unclip_ratio = gr.Slider(
            label="unclip_ratio",
            minimum=1.5,
            maximum=2.0,
            value=1.6,
            step=0.1,
            info="控制文本检测框的大小,值越大,检测框整体越大。在出现框截断文字的情况,调大该值。取值范围:[1.5, 2.0],默认值为1.6",
        )

    # Row 2: which pipeline modules, backend and models to use.
    with gr.Row():
        use_module = gr.CheckboxGroup(
            ["use_det", "use_cls", "use_rec"],
            label="Use module (使用哪些模块)",
            value=["use_det", "use_cls", "use_rec"],
            interactive=True,
        )
        select_infer_engine = gr.Dropdown(
            choices=infer_engine_list,
            label="Infer Engine (推理引擎)",
            value="ONNXRuntime",
            interactive=True,
        )
        lang_det = gr.Dropdown(
            choices=lang_det_list,
            label="Det model (文本检测模型)",
            value=lang_det_list[0],
            interactive=True,
        )
        lang_rec = gr.Dropdown(
            choices=lang_rec_list,
            label="Rec model (文本识别模型)",
            value=lang_rec_list[0],
            interactive=True,
        )
        is_word = gr.Radio(
            ["Yes", "No"], label="Return word box (返回单字符)", value="No"
        )

    img_input = gr.Image(label="Upload or Select Image", sources="upload")
    run_btn = gr.Button("Run")

    img_output = gr.Image(label="Output Image")
    elapse = gr.Textbox(label="Elapse(s)")
    ocr_results = gr.Dataframe(
        label="OCR Txts",
        headers=["Index", "Txt", "Score"],
        datatype=["number", "str", "number"],
        show_copy_button=True,
    )

    # Order must match the positional parameters of get_ocr_result.
    ocr_inputs = [
        img_input,
        text_score,
        box_thresh,
        unclip_ratio,
        lang_det,
        lang_rec,
        select_infer_engine,
        is_word,
        use_module,
    ]
    run_btn.click(
        get_ocr_result, inputs=ocr_inputs, outputs=[img_output, ocr_results, elapse]
    )

    examples = gr.Examples(
        examples=create_examples(),
        examples_per_page=5,
        inputs=ocr_inputs,
        fn=get_ocr_result,
        outputs=[img_output, ocr_results, elapse],
        cache_examples=False,
    )
# Start the Gradio server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch(debug=True)