Commit
·
705fc3e
1
Parent(s):
938934c
upgrade graph
Browse files
app.py
CHANGED
|
@@ -45,7 +45,7 @@ def generate_chart(data: Union[Dict, List[Dict], pd.DataFrame],
|
|
| 45 |
x_label: str = None,
|
| 46 |
y_label: str = None) -> str:
|
| 47 |
"""
|
| 48 |
-
Generate a chart from data and return it as
|
| 49 |
|
| 50 |
Args:
|
| 51 |
data: The data to plot (can be a list of dicts or a pandas DataFrame)
|
|
@@ -57,7 +57,7 @@ def generate_chart(data: Union[Dict, List[Dict], pd.DataFrame],
|
|
| 57 |
y_label: Label for y-axis
|
| 58 |
|
| 59 |
Returns:
|
| 60 |
-
|
| 61 |
"""
|
| 62 |
try:
|
| 63 |
# Convert data to DataFrame if it's a list of dicts
|
|
@@ -72,6 +72,7 @@ def generate_chart(data: Union[Dict, List[Dict], pd.DataFrame],
|
|
| 72 |
return "Error: Data must be a dictionary, list of dictionaries, or pandas DataFrame"
|
| 73 |
|
| 74 |
# Generate the appropriate chart type
|
|
|
|
| 75 |
if chart_type == 'bar':
|
| 76 |
fig = px.bar(df, x=x, y=y, title=title)
|
| 77 |
elif chart_type == 'line':
|
|
@@ -90,11 +91,24 @@ def generate_chart(data: Union[Dict, List[Dict], pd.DataFrame],
|
|
| 90 |
xaxis_title=x_label or x,
|
| 91 |
yaxis_title=y_label or (y if y != x else ''),
|
| 92 |
title=title or f"{chart_type.capitalize()} Chart of {x} vs {y}" if y else f"{chart_type.capitalize()} Chart of {x}",
|
| 93 |
-
template="plotly_white"
|
|
|
|
|
|
|
| 94 |
)
|
| 95 |
|
| 96 |
-
#
|
| 97 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 98 |
|
| 99 |
except Exception as e:
|
| 100 |
error_msg = f"Error generating chart: {str(e)}"
|
|
@@ -664,15 +678,22 @@ def create_ui():
|
|
| 664 |
if not env_ok:
|
| 665 |
gr.Warning("⚠️ " + env_message)
|
| 666 |
|
| 667 |
-
#
|
| 668 |
-
|
| 669 |
-
|
| 670 |
-
|
| 671 |
-
|
| 672 |
-
|
| 673 |
-
|
| 674 |
-
|
| 675 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 676 |
|
| 677 |
# Input area
|
| 678 |
with gr.Row():
|
|
|
|
| 45 |
x_label: str = None,
|
| 46 |
y_label: str = None) -> str:
|
| 47 |
"""
|
| 48 |
+
Generate a chart from data and return it as a base64 encoded image.
|
| 49 |
|
| 50 |
Args:
|
| 51 |
data: The data to plot (can be a list of dicts or a pandas DataFrame)
|
|
|
|
| 57 |
y_label: Label for y-axis
|
| 58 |
|
| 59 |
Returns:
|
| 60 |
+
Markdown string with embedded image
|
| 61 |
"""
|
| 62 |
try:
|
| 63 |
# Convert data to DataFrame if it's a list of dicts
|
|
|
|
| 72 |
return "Error: Data must be a dictionary, list of dictionaries, or pandas DataFrame"
|
| 73 |
|
| 74 |
# Generate the appropriate chart type
|
| 75 |
+
fig = None
|
| 76 |
if chart_type == 'bar':
|
| 77 |
fig = px.bar(df, x=x, y=y, title=title)
|
| 78 |
elif chart_type == 'line':
|
|
|
|
| 91 |
xaxis_title=x_label or x,
|
| 92 |
yaxis_title=y_label or (y if y != x else ''),
|
| 93 |
title=title or f"{chart_type.capitalize()} Chart of {x} vs {y}" if y else f"{chart_type.capitalize()} Chart of {x}",
|
| 94 |
+
template="plotly_white",
|
| 95 |
+
margin=dict(l=20, r=20, t=40, b=20),
|
| 96 |
+
height=400
|
| 97 |
)
|
| 98 |
|
| 99 |
+
# Save the figure to a temporary file
|
| 100 |
+
temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png')
|
| 101 |
+
fig.write_image(temp_file.name, format='png', engine='kaleido')
|
| 102 |
+
|
| 103 |
+
# Read the image file and encode as base64
|
| 104 |
+
with open(temp_file.name, 'rb') as img_file:
|
| 105 |
+
img_base64 = base64.b64encode(img_file.read()).decode('utf-8')
|
| 106 |
+
|
| 107 |
+
# Clean up the temporary file
|
| 108 |
+
os.unlink(temp_file.name)
|
| 109 |
+
|
| 110 |
+
# Return as markdown image
|
| 111 |
+
return f'<img src="data:image/png;base64,{img_base64}" style="max-width:100%;"/>'
|
| 112 |
|
| 113 |
except Exception as e:
|
| 114 |
error_msg = f"Error generating chart: {str(e)}"
|
|
|
|
| 678 |
if not env_ok:
|
| 679 |
gr.Warning("⚠️ " + env_message)
|
| 680 |
|
| 681 |
+
# Create the chat interface
|
| 682 |
+
with gr.Row():
|
| 683 |
+
chatbot = gr.Chatbot(
|
| 684 |
+
[],
|
| 685 |
+
elem_id="chatbot",
|
| 686 |
+
bubble_full_width=False,
|
| 687 |
+
avatar_images=(
|
| 688 |
+
None,
|
| 689 |
+
(os.path.join(os.path.dirname(__file__), "logo.png")),
|
| 690 |
+
),
|
| 691 |
+
height=600,
|
| 692 |
+
render_markdown=True, # Enable markdown rendering
|
| 693 |
+
show_label=False,
|
| 694 |
+
show_share_button=False,
|
| 695 |
+
likeable=False
|
| 696 |
+
)
|
| 697 |
|
| 698 |
# Input area
|
| 699 |
with gr.Row():
|