import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Model configuration
MODEL_ID = "microsoft/bitnet-b1.58-2B-4T"
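# NOTE (assumption based on the model card): this checkpoint may need a recent
# transformers build with BitNet support; unsupported versions can fail at load
# time, which would surface as a Space runtime error.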

# Initialize model and tokenizer
print("Loading model and tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
print("Model loaded successfully!")


def chat(message, history, system_prompt, max_tokens, temperature, top_p):
    """
    Generate a response using the BitNet model.

    Args:
        message: User's current message
        history: List of previous [user, assistant] message pairs
        system_prompt: System instruction for the model
        max_tokens: Maximum number of tokens to generate
        temperature: Sampling temperature
        top_p: Nucleus sampling parameter
    """
    # Build the conversation, starting with the system instruction
    messages = [{"role": "system", "content": system_prompt}]

    # Add previous turns from the conversation history
    for user_msg, assistant_msg in history:
        messages.append({"role": "user", "content": user_msg})
        messages.append({"role": "assistant", "content": assistant_msg})

    # Add the current message
    messages.append({"role": "user", "content": message})
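    # For example, after one prior exchange the list looks like:
    #   [{"role": "system", ...}, {"role": "user", ...},
    #    {"role": "assistant", ...}, {"role": "user", ...}]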

    # Apply chat template
    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )

    # Tokenize input
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # Generate a response; fall back to greedy decoding when temperature is 0
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            do_sample=temperature > 0,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Decode only the generated tokens (not the input)
    response = tokenizer.decode(
        outputs[0][inputs["input_ids"].shape[-1]:],
        skip_special_tokens=True,
    )
    return response
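

# Illustrative only (not used by the demo): a minimal sketch of the "absmean"
# ternary quantization behind BitNet b1.58's {-1, 0, 1} weights, as described
# in the paper linked below. The function name and eps value are this sketch's
# own choices, not part of the released model code.
def ternary_quantize(w: torch.Tensor, eps: float = 1e-5):
    scale = w.abs().mean() + eps  # per-tensor "absmean" scale
    w_q = (w / scale).round().clamp_(-1, 1)  # round, then clip to {-1, 0, 1}
    return w_q, scale  # w is approximated by w_q * scale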


# Create Gradio interface
with gr.Blocks(title="BitNet b1.58 2B Chat Demo") as demo:
    gr.Markdown("""
    # 🚀 BitNet b1.58 2B-4T Chat Demo

    This is a demo of Microsoft's BitNet b1.58 2B model, a 1.58-bit Large Language Model trained on 4 trillion tokens.

    **Key Features:**
    - 1.58-bit weights (ternary: {-1, 0, 1})
    - Significantly reduced memory footprint
    - Faster inference and lower energy consumption
    - Performance comparable to full-precision LLMs of similar size

    ⚠️ **Note:** This model is for research purposes. Responses may be unexpected, biased, or inaccurate.
    """)

    with gr.Row():
        with gr.Column(scale=3):
            # Pair-style chat history ([user, assistant] lists); newer Gradio
            # releases prefer type="messages", so pin the Gradio version this
            # app was built against.
            chatbot = gr.Chatbot(
                label="Conversation",
                height=500,
                show_copy_button=True,
            )
            msg = gr.Textbox(
                label="Your Message",
                placeholder="Type your message here...",
                lines=2,
            )
            with gr.Row():
                submit = gr.Button("Send", variant="primary")
                clear = gr.Button("Clear")

        with gr.Column(scale=1):
            system_prompt = gr.Textbox(
                label="System Prompt",
                value="You are a helpful AI assistant.",
                lines=3,
            )
            max_tokens = gr.Slider(
                minimum=50,
                maximum=512,
                value=256,
                step=1,
                label="Max Tokens",
            )
            temperature = gr.Slider(
                minimum=0.0,
                maximum=2.0,
                value=0.7,
                step=0.1,
                label="Temperature",
            )
            top_p = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                value=0.9,
                step=0.05,
                label="Top P",
            )
| gr.Markdown(""" | |
| ### Parameters Guide: | |
| - **Max Tokens:** Maximum length of response | |
| - **Temperature:** Higher = more creative, Lower = more focused | |
| - **Top P:** Nucleus sampling threshold | |
| """) | |

    # Event handlers
    def user_message(user_input, history):
        # Append the user's turn with an empty assistant slot; clear the textbox
        return "", history + [[user_input, None]]

    def bot_response(history, system_prompt, max_tokens, temperature, top_p):
        # Fill in the pending assistant slot created by user_message
        user_input = history[-1][0]
        bot_message = chat(
            user_input,
            history[:-1],
            system_prompt,
            max_tokens,
            temperature,
            top_p,
        )
        history[-1][1] = bot_message
        return history

    msg.submit(
        user_message,
        [msg, chatbot],
        [msg, chatbot],
        queue=False,
    ).then(
        bot_response,
        [chatbot, system_prompt, max_tokens, temperature, top_p],
        chatbot,
    )
    submit.click(
        user_message,
        [msg, chatbot],
        [msg, chatbot],
        queue=False,
    ).then(
        bot_response,
        [chatbot, system_prompt, max_tokens, temperature, top_p],
        chatbot,
    )
    clear.click(lambda: None, None, chatbot, queue=False)
| gr.Markdown(""" | |
| --- | |
| **Resources:** | |
| - [Model Card](https://huggingface.co/microsoft/bitnet-b1.58-2B-4T) | |
| - [Technical Paper](https://huggingface.co/papers/2504.12285) | |
| - [BitNet GitHub](https://github.com/microsoft/BitNet) | |
| """) | |

# Launch the app
if __name__ == "__main__":
    demo.queue()
    demo.launch()
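
# To run this demo locally (assumed setup; this Space's exact dependency pins
# are not shown here):
#   pip install gradio torch transformers
#   python app.py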