"""Merge partial result files and plot the per-algorithm value distributions.

The script scans a directory for ``partial_result_*.json`` files, concatenates
the lists stored under each known algorithm key, draws one KDE curve per
algorithm and writes summary statistics next to the plot.
"""

import argparse
import json
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd  # needed to hand long-form data to seaborn
import seaborn as sns

# Keys looked up in every partial result file; each is expected to hold a list.
_ALGORITHMS = ("Gzip", "Tokenizer", "AC_M1")

# Values at or above this cutoff are dropped from the plot only (values above
# 2.0 are usually a handful of outliers). Summary statistics still use them.
_PLOT_OUTLIER_CUTOFF = 2.0


def _merge_partial_results(output_dir):
    """Merge all ``partial_result_*.json`` files found in *output_dir*.

    A file is merged only if it can be read, parses as a JSON object, and every
    algorithm key it contains holds a list. Anything else is reported with a
    warning and skipped as a whole, so a broken file never contributes a
    partial set of keys.

    Args:
        output_dir: Existing directory to scan (not recursive).

    Returns:
        A ``(merged, files_merged)`` tuple: ``merged`` maps each algorithm name
        to the concatenated list of values, ``files_merged`` is the number of
        files that were merged successfully.
    """
    merged = {algo: [] for algo in _ALGORITHMS}
    files_merged = 0

    # Sorted so the merge order does not depend on the directory listing order.
    for filename in sorted(os.listdir(output_dir)):
        if not (filename.startswith("partial_result_") and filename.endswith(".json")):
            continue

        file_path = os.path.join(output_dir, filename)
        try:
            with open(file_path, "r", encoding="utf-8") as f:
                data = json.load(f)
            if not isinstance(data, dict):
                raise ValueError("top-level JSON value is not an object")

            # Stage first; commit only once the whole file has been validated.
            staged = {}
            for algo in _ALGORITHMS:
                if algo not in data:
                    continue
                values = data[algo]
                if not isinstance(values, list):
                    raise ValueError(f"value for {algo!r} is not a list")
                staged[algo] = values
        except (OSError, ValueError) as e:
            # Best effort: JSON and decoding errors are ValueError subclasses.
            print(f"⚠️ Error reading {filename}: {e}")
            continue

        for algo, values in staged.items():
            merged[algo].extend(values)
        files_merged += 1

    return merged, files_merged


def _build_plot_inputs(final_results):
    """Turn merged results into plot records and per-algorithm statistics.

    Args:
        final_results: Mapping of algorithm name to a list of numeric values.

    Returns:
        A ``(plot_records, stats_summary)`` tuple. ``plot_records`` is a list
        of one-row dicts (long form) holding only values below the plot
        cutoff. ``stats_summary`` maps each algorithm that has data to its
        mean, median and count, computed over all of its values.
    """
    plot_records = []
    stats_summary = {}

    for algo, vals in final_results.items():
        if not vals:
            continue

        # Statistics cover every value, including those hidden from the plot.
        stats_summary[algo] = {
            "mean": float(np.mean(vals)),
            "median": float(np.median(vals)),
            "count": len(vals),
        }
        plot_records.extend(
            {"Algorithm": algo, "Normalized Edit Distance": v}
            for v in vals
            if v < _PLOT_OUTLIER_CUTOFF
        )

    return plot_records, stats_summary


def _save_density_plot(df, output_img):
    """Draw one KDE curve per algorithm from *df* and save it to *output_img*.

    Args:
        df: Long-form DataFrame with the columns ``"Algorithm"`` and
            ``"Normalized Edit Distance"``.
        output_img: Path of the image file to write.
    """
    fig = plt.figure(figsize=(12, 7))
    try:
        sns.set_style("whitegrid")
        sns.kdeplot(
            data=df,
            x="Normalized Edit Distance",
            hue="Algorithm",
            fill=True,
            common_norm=False,  # normalise each algorithm's density on its own
            palette="tab10",
            alpha=0.5,
            linewidth=2,
        )
        plt.title("Compression Stability Analysis (Impact of 10% Perturbation)")
        plt.xlabel("Normalized Levenshtein Distance (Lower = More Stable)")
        plt.ylabel("Density")
        plt.xlim(0, 1.2)  # focus on the 0-1.2 range
        plt.savefig(output_img, dpi=300)
    finally:
        # Release the figure even if drawing or saving fails.
        plt.close(fig)


def main():
    """Merge partial results from ``--output_dir``, then plot and summarise.

    Writes ``stability_parallel_fixed.png`` and ``final_stats_summary.json``
    into the same directory. Returns early, after printing a message, when the
    directory is missing or when no value is left to plot.
    """
    parser = argparse.ArgumentParser()
    # The default corresponds to the output path of the preceding analysis run.
    parser.add_argument(
        "--output_dir",
        type=str,
        default="analysis_output_parallel",
        help="Directory with partial .json results",
    )
    args = parser.parse_args()

    print(f"📂 Reading results from: {args.output_dir}")

    # isdir rather than exists: a regular file at this path cannot be listed.
    if not os.path.isdir(args.output_dir):
        print(f"❌ Error: Directory {args.output_dir} does not exist.")
        return

    # 1. Merge the partial results.
    final_results, files_found = _merge_partial_results(args.output_dir)
    print(f"✅ Merged data from {files_found} files.")

    # 2. Prepare the plotting data.
    plot_records, stats_summary = _build_plot_inputs(final_results)
    if not plot_records:
        print("❌ No valid data collected to plot.")
        return

    # seaborn is given column names, so the records go into a DataFrame.
    df = pd.DataFrame(plot_records)
    print(f"📊 Plotting {len(df)} data points...")

    # 3. Plot.
    output_img = os.path.join(args.output_dir, "stability_parallel_fixed.png")
    _save_density_plot(df, output_img)
    print(f"🖼️ Plot saved to: {output_img}")

    # 4. Save the statistics.
    stats_file = os.path.join(args.output_dir, "final_stats_summary.json")
    with open(stats_file, "w", encoding="utf-8") as f:
        json.dump(stats_summary, f, indent=2)
    print(f"📄 Stats saved to: {stats_file}")

    # Print a short summary.
    print("\n=== Summary Stats ===")
    for algo, stat in stats_summary.items():
        print(f"{algo}: Mean={stat['mean']:.4f}, Count={stat['count']}")


if __name__ == "__main__":
    main()