import gradio as gr import torch from transformers import AutoModelForCausalLM, AutoTokenizer # Model configuration MODEL_ID = "microsoft/bitnet-b1.58-2B-4T" # Initialize model and tokenizer print("Loading model and tokenizer...") tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) model = AutoModelForCausalLM.from_pretrained( MODEL_ID, torch_dtype=torch.bfloat16, device_map="auto" ) print("Model loaded successfully!") def chat(message, history, system_prompt, max_tokens, temperature, top_p): """ Generate a response using the BitNet model. Args: message: User's current message history: List of previous messages in the conversation system_prompt: System instruction for the model max_tokens: Maximum number of tokens to generate temperature: Sampling temperature top_p: Nucleus sampling parameter """ # Build conversation history messages = [{"role": "system", "content": system_prompt}] # Add conversation history for user_msg, assistant_msg in history: messages.append({"role": "user", "content": user_msg}) messages.append({"role": "assistant", "content": assistant_msg}) # Add current message messages.append({"role": "user", "content": message}) # Apply chat template prompt = tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=True ) # Tokenize input inputs = tokenizer(prompt, return_tensors="pt").to(model.device) # Generate response with torch.no_grad(): outputs = model.generate( **inputs, max_new_tokens=max_tokens, temperature=temperature, top_p=top_p, do_sample=True if temperature > 0 else False, pad_token_id=tokenizer.eos_token_id ) # Decode only the generated tokens (not the input) response = tokenizer.decode( outputs[0][inputs['input_ids'].shape[-1]:], skip_special_tokens=True ) return response # Create Gradio interface with gr.Blocks(title="BitNet b1.58 2B Chat Demo") as demo: gr.Markdown(""" # 🚀 BitNet b1.58 2B-4T Chat Demo This is a demo of Microsoft's BitNet b1.58 2B model - a 1.58-bit Large Language Model trained on 4 trillion tokens. **Key Features:** - 1.58-bit weights (ternary: {-1, 0, 1}) - Significantly reduced memory footprint - Faster inference and lower energy consumption - Performance comparable to full-precision LLMs of similar size ⚠️ **Note:** This model is for research purposes. Responses may be unexpected, biased, or inaccurate. """) with gr.Row(): with gr.Column(scale=3): chatbot = gr.Chatbot( label="Conversation", height=500, show_copy_button=True ) msg = gr.Textbox( label="Your Message", placeholder="Type your message here...", lines=2 ) with gr.Row(): submit = gr.Button("Send", variant="primary") clear = gr.Button("Clear") with gr.Column(scale=1): system_prompt = gr.Textbox( label="System Prompt", value="You are a helpful AI assistant.", lines=3 ) max_tokens = gr.Slider( minimum=50, maximum=512, value=256, step=1, label="Max Tokens" ) temperature = gr.Slider( minimum=0.0, maximum=2.0, value=0.7, step=0.1, label="Temperature" ) top_p = gr.Slider( minimum=0.0, maximum=1.0, value=0.9, step=0.05, label="Top P" ) gr.Markdown(""" ### Parameters Guide: - **Max Tokens:** Maximum length of response - **Temperature:** Higher = more creative, Lower = more focused - **Top P:** Nucleus sampling threshold """) # Event handlers def user_message(user_input, history): return "", history + [[user_input, None]] def bot_response(history, system_prompt, max_tokens, temperature, top_p): user_input = history[-1][0] bot_message = chat( user_input, history[:-1], system_prompt, max_tokens, temperature, top_p ) history[-1][1] = bot_message return history msg.submit( user_message, [msg, chatbot], [msg, chatbot], queue=False ).then( bot_response, [chatbot, system_prompt, max_tokens, temperature, top_p], chatbot ) submit.click( user_message, [msg, chatbot], [msg, chatbot], queue=False ).then( bot_response, [chatbot, system_prompt, max_tokens, temperature, top_p], chatbot ) clear.click(lambda: None, None, chatbot, queue=False) gr.Markdown(""" --- **Resources:** - [Model Card](https://huggingface.co/microsoft/bitnet-b1.58-2B-4T) - [Technical Paper](https://huggingface.co/papers/2504.12285) - [BitNet GitHub](https://github.com/microsoft/BitNet) """) # Launch the app if __name__ == "__main__": demo.queue() demo.launch()