import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# Model configuration
MODEL_ID = "microsoft/bitnet-b1.58-2B-4T"
# Initialize model and tokenizer
print("Loading model and tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
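# device_map="auto" places the weights on the GPU when one is available;
# using device_map in from_pretrained requires the `accelerate` package.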
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    device_map="auto"
)
print("Model loaded successfully!")
def chat(message, history, system_prompt, max_tokens, temperature, top_p):
    """
    Generate a response using the BitNet model.

    Args:
        message: User's current message
        history: List of (user, assistant) message pairs from earlier turns
        system_prompt: System instruction for the model
        max_tokens: Maximum number of tokens to generate
        temperature: Sampling temperature
        top_p: Nucleus sampling parameter
    """
    # Build the message list, starting with the system instruction
    messages = [{"role": "system", "content": system_prompt}]

    # Add conversation history
    for user_msg, assistant_msg in history:
        messages.append({"role": "user", "content": user_msg})
        messages.append({"role": "assistant", "content": assistant_msg})

    # Add current message
    messages.append({"role": "user", "content": message})
    # Apply chat template
    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )
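    # add_generation_prompt=True appends the template's assistant cue so the
    # model continues as the assistant rather than extending the user's turn.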
    # Tokenize input
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    # Generate response
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            do_sample=temperature > 0,
            pad_token_id=tokenizer.eos_token_id
        )
    # Decode only the generated tokens (not the input)
    response = tokenizer.decode(
        outputs[0][inputs["input_ids"].shape[-1]:],
        skip_special_tokens=True
    )
    return response
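# Example of calling chat() directly (kept as a comment so generation does not
# run at import time):
#   chat("Hello!", [], "You are a helpful AI assistant.", 128, 0.7, 0.9)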
# Create Gradio interface
with gr.Blocks(title="BitNet b1.58 2B Chat Demo") as demo:
    gr.Markdown("""
    # 🚀 BitNet b1.58 2B-4T Chat Demo

    This is a demo of Microsoft's BitNet b1.58 2B model, a 1.58-bit large language model trained on 4 trillion tokens.

    **Key Features:**
    - 1.58-bit weights (ternary: {-1, 0, 1})
    - Significantly reduced memory footprint
    - Faster inference and lower energy consumption
    - Performance comparable to full-precision LLMs of similar size

    ⚠️ **Note:** This model is for research purposes. Responses may be unexpected, biased, or inaccurate.
    """)
    with gr.Row():
        with gr.Column(scale=3):
            chatbot = gr.Chatbot(
                label="Conversation",
                height=500,
                show_copy_button=True
            )
            msg = gr.Textbox(
                label="Your Message",
                placeholder="Type your message here...",
                lines=2
            )
            with gr.Row():
                submit = gr.Button("Send", variant="primary")
                clear = gr.Button("Clear")

        with gr.Column(scale=1):
            system_prompt = gr.Textbox(
                label="System Prompt",
                value="You are a helpful AI assistant.",
                lines=3
            )
            max_tokens = gr.Slider(
                minimum=50,
                maximum=512,
                value=256,
                step=1,
                label="Max Tokens"
            )
            temperature = gr.Slider(
                minimum=0.0,
                maximum=2.0,
                value=0.7,
                step=0.1,
                label="Temperature"
            )
            top_p = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                value=0.9,
                step=0.05,
                label="Top P"
            )
            gr.Markdown("""
            ### Parameters Guide:
            - **Max Tokens:** Maximum length of response
            - **Temperature:** Higher = more creative, Lower = more focused
            - **Top P:** Nucleus sampling threshold
            """)
    # Event handlers
    def user_message(user_input, history):
        return "", history + [[user_input, None]]
    def bot_response(history, system_prompt, max_tokens, temperature, top_p):
        user_input = history[-1][0]
        bot_message = chat(
            user_input,
            history[:-1],
            system_prompt,
            max_tokens,
            temperature,
            top_p
        )
        history[-1][1] = bot_message
        return history
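    # Two-step wiring: the first event echoes the user message immediately
    # (queue=False), then .then() runs generation and fills in the reply.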
    msg.submit(
        user_message,
        [msg, chatbot],
        [msg, chatbot],
        queue=False
    ).then(
        bot_response,
        [chatbot, system_prompt, max_tokens, temperature, top_p],
        chatbot
    )
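    # The Send button mirrors the Enter-to-submit wiring above.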
    submit.click(
        user_message,
        [msg, chatbot],
        [msg, chatbot],
        queue=False
    ).then(
        bot_response,
        [chatbot, system_prompt, max_tokens, temperature, top_p],
        chatbot
    )

    clear.click(lambda: None, None, chatbot, queue=False)
    gr.Markdown("""
    ---
    **Resources:**
    - [Model Card](https://huggingface.co/microsoft/bitnet-b1.58-2B-4T)
    - [Technical Paper](https://huggingface.co/papers/2504.12285)
    - [BitNet GitHub](https://github.com/microsoft/BitNet)
    """)
# Launch the app
if __name__ == "__main__":
    demo.queue()
    demo.launch()