import gradio as gr
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration, TextIteratorStreamer
from transformers.image_utils import load_image
from threading import Thread
import time
import torch
import spaces
import cv2
import numpy as np
from PIL import Image
from models import (
    get_model_list,
    get_model_info,
    DEFAULT_GENERATION_PARAMS,
    get_preset_list,
    get_preset_params,
    get_preset_description
)


def progress_bar_html(label: str) -> str:
    """
    Returns an HTML snippet for a thin progress bar with a label.
    The progress bar is styled as a dark animated bar.
    """
    return f'''
<div style="display: flex; align-items: center;">
    <span style="margin-right: 10px; font-size: 14px;">{label}</span>
    <div style="width: 110px; height: 5px; background-color: #333; border-radius: 2px; overflow: hidden;">
        <div style="width: 100%; height: 100%; background-color: #0f0f0f; animation: loading 1.5s linear infinite;"></div>
    </div>
</div>
<style>
@keyframes loading {{
    0% {{ transform: translateX(-100%); }}
    100% {{ transform: translateX(100%); }}
}}
</style>
    '''


def downsample_video(video_path):
    """
    Downsamples the video to 10 evenly spaced frames.
    Each frame is converted to a PIL Image along with its timestamp.
    """
    vidcap = cv2.VideoCapture(video_path)
    total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
    fps = vidcap.get(cv2.CAP_PROP_FPS)
    frames = []
    if total_frames <= 0 or fps <= 0:
        vidcap.release()
        return frames
    # Sample 10 evenly spaced frames.
    frame_indices = np.linspace(0, total_frames - 1, 10, dtype=int)
    for i in frame_indices:
        vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
        success, image = vidcap.read()
        if success:
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            pil_image = Image.fromarray(image)
            timestamp = round(i / fps, 2)
            frames.append((pil_image, timestamp))
    vidcap.release()
    return frames


# Initial model will be loaded when the first request comes in
processor = None
model = None
current_model_name = None


def load_model(model_name):
    """
    Loads the model and processor based on the model name.
    Returns the model and processor.
    """
    global processor, model, current_model_name

    # If the model is already loaded, return it
    if model is not None and current_model_name == model_name:
        return model, processor

    # Get model info
    model_info = get_model_info(model_name)
    MODEL_ID = model_info["id"]

    # Set dtype based on model info
    dtype = getattr(torch, model_info["dtype"])

    # Load processor and model
    processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        MODEL_ID,
        trust_remote_code=True,
        torch_dtype=dtype
    ).to(model_info["device"]).eval()

    # Update current model name
    current_model_name = model_name

    return model, processor


@spaces.GPU
def model_inference(input_dict, history, model_name,
                    temperature=DEFAULT_GENERATION_PARAMS["temperature"],
                    top_p=DEFAULT_GENERATION_PARAMS["top_p"],
                    top_k=DEFAULT_GENERATION_PARAMS["top_k"],
                    max_new_tokens=DEFAULT_GENERATION_PARAMS["max_new_tokens"],
                    do_sample=DEFAULT_GENERATION_PARAMS["do_sample"],
                    num_beams=DEFAULT_GENERATION_PARAMS["num_beams"],
                    early_stopping=DEFAULT_GENERATION_PARAMS["early_stopping"],
                    length_penalty=DEFAULT_GENERATION_PARAMS["length_penalty"],
                    no_repeat_ngram_size=DEFAULT_GENERATION_PARAMS["no_repeat_ngram_size"],
                    repetition_penalty=DEFAULT_GENERATION_PARAMS["repetition_penalty"]):
    # Load the selected model
    model, processor = load_model(model_name)

    text = input_dict["text"]
    files = input_dict["files"]

    if text.strip().lower().startswith("@video-infer"):
        # Remove the tag from the query.
        text = text[len("@video-infer"):].strip()
        if not files:
            raise gr.Error("Please upload a video file along with your @video-infer query.")
        # Assume the first file is a video.
        video_path = files[0]
        frames = downsample_video(video_path)
        if not frames:
            raise gr.Error("Could not process video.")
        # Build messages: start with the text prompt.
        messages = [
            {
                "role": "user",
                "content": [{"type": "text", "text": text}]
            }
        ]
        # Append each frame with a timestamp label.
        for image, timestamp in frames:
            messages[0]["content"].append({"type": "text", "text": f"Frame {timestamp}:"})
            messages[0]["content"].append({"type": "image", "image": image})
        # Collect only the images from the frames.
        video_images = [image for image, _ in frames]
        # Prepare the prompt.
        prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        inputs = processor(
            text=[prompt],
            images=video_images,
            return_tensors="pt",
            padding=True,
        ).to("cuda")
        # Set up streaming generation.
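        # The streamer receives decoded tokens as model.generate() produces them on a
        # background thread, so the chat UI can display partial output while decoding runs.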
        streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
        generation_kwargs = dict(
            inputs,
            streamer=streamer,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            do_sample=do_sample,
            num_beams=num_beams,
            early_stopping=early_stopping,
            length_penalty=length_penalty,
            no_repeat_ngram_size=no_repeat_ngram_size,
            repetition_penalty=repetition_penalty
        )
        thread = Thread(target=model.generate, kwargs=generation_kwargs)
        thread.start()

        buffer = ""
        yield progress_bar_html(f"Processing video with {model_name}")
        for new_text in streamer:
            buffer += new_text
            time.sleep(0.01)
            yield buffer
        return

    if len(files) > 1:
        images = [load_image(image) for image in files]
    elif len(files) == 1:
        images = [load_image(files[0])]
    else:
        images = []

    if text == "" and not images:
        raise gr.Error("Please input a query and optionally image(s).")
    if text == "" and images:
        raise gr.Error("Please input a text query along with the image(s).")

    messages = [
        {
            "role": "user",
            "content": [
                *[{"type": "image", "image": image} for image in images],
                {"type": "text", "text": text},
            ],
        }
    ]

    prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = processor(
        text=[prompt],
        images=images if images else None,
        return_tensors="pt",
        padding=True,
    ).to("cuda")

    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(
        inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        do_sample=do_sample,
        num_beams=num_beams,
        early_stopping=early_stopping,
        length_penalty=length_penalty,
        no_repeat_ngram_size=no_repeat_ngram_size,
        repetition_penalty=repetition_penalty
    )
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    buffer = ""
    yield progress_bar_html(f"Processing with {model_name}")
    for new_text in streamer:
        buffer += new_text
        time.sleep(0.01)
        yield buffer


examples = [
    [{"text": "Describe the Image?", "files": ["example_images/document.jpg"]}],
    [{"text": "@video-infer Explain the content of the Advertisement", "files": ["example_images/videoplayback.mp4"]}],
    [{"text": "@video-infer Explain the content of the video in detail", "files": ["example_images/breakfast.mp4"]}],
    [{"text": "@video-infer Explain the content of the video.", "files": ["example_images/sky.mp4"]}],
]


def create_interface():
    # Get the list of available models and presets
    model_options = get_model_list()
    preset_options = get_preset_list()

    def apply_preset(preset_name):
        """Helper function to apply parameter presets"""
        params = get_preset_params(preset_name)
        return [
            params["temperature"],
            params["top_p"],
            params["top_k"],
            params["max_new_tokens"],
            params["do_sample"],
            params["num_beams"],
            params["early_stopping"],
            params["length_penalty"],
            params["no_repeat_ngram_size"],
            params["repetition_penalty"],
            get_preset_description(preset_name)
        ]

    with gr.Blocks() as demo:
        gr.Markdown("# **Qwen2.5 Series (add `@video-infer` for video understanding)**")

        with gr.Accordion("Model Settings", open=True):
            with gr.Row():
                model_dropdown = gr.Dropdown(
                    choices=model_options,
                    value=model_options[0],
                    label="Select Model"
                )

            with gr.Row():
                preset_dropdown = gr.Dropdown(
                    choices=preset_options,
                    value="Default",
                    label="Parameter Preset"
                )
                preset_description = gr.Textbox(
                    value=get_preset_description("Default"),
                    label="Preset Description",
                    interactive=False
                )

            # Button to apply the selected preset
            preset_button = gr.Button("Apply Preset")
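            # The sliders and checkboxes below map one-to-one onto the keyword
            # arguments passed to model.generate() in model_inference().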
            with gr.Row():
                temperature = gr.Slider(
                    minimum=0.0,
                    maximum=2.0,
                    value=DEFAULT_GENERATION_PARAMS["temperature"],
                    step=0.1,
                    label="Temperature",
                    info="Higher values produce more diverse outputs"
                )
                top_p = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=DEFAULT_GENERATION_PARAMS["top_p"],
                    step=0.05,
                    label="Top P",
                    info="Nucleus sampling: sample only from the smallest token set whose cumulative probability reaches P"
                )

            with gr.Row():
                top_k = gr.Slider(
                    minimum=1,
                    maximum=100,
                    value=DEFAULT_GENERATION_PARAMS["top_k"],
                    step=1,
                    label="Top K",
                    info="Limit sampling to the K most likely tokens"
                )
                max_tokens = gr.Slider(
                    minimum=64,
                    maximum=2048,
                    value=DEFAULT_GENERATION_PARAMS["max_new_tokens"],
                    step=64,
                    label="Max New Tokens",
                    info="Maximum number of tokens to generate"
                )

            with gr.Row():
                do_sample = gr.Checkbox(
                    value=DEFAULT_GENERATION_PARAMS["do_sample"],
                    label="Do Sample",
                    info="When enabled, uses sampling; when disabled, uses greedy decoding"
                )
                num_beams = gr.Slider(
                    minimum=1,
                    maximum=10,
                    value=DEFAULT_GENERATION_PARAMS["num_beams"],
                    step=1,
                    label="Beam Size",
                    info="Number of beams for beam search (1 = no beam search)"
                )

        with gr.Accordion("Advanced Parameters", open=False):
            with gr.Row():
                repetition_penalty = gr.Slider(
                    minimum=0.1,
                    maximum=2.0,
                    value=DEFAULT_GENERATION_PARAMS["repetition_penalty"],
                    step=0.1,
                    label="Repetition Penalty",
                    info="Penalize repetition (1.0 = no penalty, > 1.0 = penalty)"
                )
                length_penalty = gr.Slider(
                    minimum=-2.0,
                    maximum=2.0,
                    value=DEFAULT_GENERATION_PARAMS["length_penalty"],
                    step=0.1,
                    label="Length Penalty",
                    info="<1 favors shorter, >1 favors longer generations"
                )

            with gr.Row():
                no_repeat_ngram_size = gr.Slider(
                    minimum=0,
                    maximum=10,
                    value=DEFAULT_GENERATION_PARAMS["no_repeat_ngram_size"],
                    step=1,
                    label="No Repeat NGram Size",
                    info="Size of n-grams that can't be repeated (0 = no constraint)"
                )
                early_stopping = gr.Checkbox(
                    value=DEFAULT_GENERATION_PARAMS["early_stopping"],
                    label="Early Stopping",
                    info="Stop beam search once enough complete candidates are found"
                )

        # Connect preset button with parameter controls
        preset_button.click(
            fn=apply_preset,
            inputs=[preset_dropdown],
            outputs=[
                temperature, top_p, top_k, max_tokens, do_sample, num_beams,
                early_stopping, length_penalty, no_repeat_ngram_size,
                repetition_penalty, preset_description
            ]
        )

        # Update description when preset is selected
        preset_dropdown.change(
            fn=lambda x: get_preset_description(x),
            inputs=[preset_dropdown],
            outputs=[preset_description]
        )

        chatbot = gr.ChatInterface(
            fn=model_inference,
            additional_inputs=[
                model_dropdown, temperature, top_p, top_k, max_tokens, do_sample,
                num_beams, early_stopping, length_penalty, no_repeat_ngram_size,
                repetition_penalty
            ],
            examples=examples,
            fill_height=True,
            textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image", "video"], file_count="multiple"),
            stop_btn="Stop",
            multimodal=True,
            cache_examples=False,
            type="messages",
        )

    return demo


demo = create_interface()
demo.launch(debug=True, mcp_server=True)
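# Local usage sketch (assumption: this file is saved as app.py next to the companion
# models.py module that provides get_model_list, get_model_info,
# DEFAULT_GENERATION_PARAMS and the preset helpers imported above):
#   pip install gradio spaces transformers torch opencv-python numpy pillow
#   python app.py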