from __future__ import annotations import argparse import csv from pathlib import Path import matplotlib.pyplot as plt import numpy as np from matplotlib import font_manager from matplotlib.lines import Line2D OVERSEAS_CITIES = {"迪拜", "法兰克福", "雅加达", "开普敦"} OVERSEAS_COLOR = "#4285F4" DOMESTIC_COLOR = "#EA4335" BASE_FONT_SIZE = 20 CELL_FONT_SIZE = BASE_FONT_SIZE - 5 def leading_group_size(labels: list[str], group: set[str]) -> int: size = 0 for label in labels: if label not in group: break size += 1 return size def configure_fonts() -> None: """Prefer common CJK fonts so Chinese city names render correctly.""" preferred_fonts = [ "Microsoft YaHei", "SimHei", "Noto Sans CJK SC", "Source Han Sans SC", "PingFang SC", "WenQuanYi Micro Hei", "Arial Unicode MS", ] installed_fonts = {font.name for font in font_manager.fontManager.ttflist} for font in preferred_fonts: if font in installed_fonts: plt.rcParams["font.sans-serif"] = [font] break plt.rcParams["axes.unicode_minus"] = False plt.rcParams.update( { "font.size": BASE_FONT_SIZE, "axes.titlesize": BASE_FONT_SIZE + 4, "axes.labelsize": BASE_FONT_SIZE + 2, "xtick.labelsize": BASE_FONT_SIZE, "ytick.labelsize": BASE_FONT_SIZE, "figure.dpi": 150, } ) def read_loss_matrix(csv_path: Path) -> tuple[list[str], list[str], np.ma.MaskedArray]: with csv_path.open("r", encoding="utf-8-sig", newline="") as file: reader = csv.reader(file) header = next(reader) receivers = header[1:] senders: list[str] = [] values: list[list[float]] = [] mask: list[list[bool]] = [] for row in reader: senders.append(row[0]) value_row: list[float] = [] mask_row: list[bool] = [] for item in row[1:]: if item.strip().lower() in {"n/a", "na", ""}: value_row.append(0.0) mask_row.append(True) else: value_row.append(float(item)) mask_row.append(False) values.append(value_row) mask.append(mask_row) return senders, receivers, np.ma.array(values, mask=mask) def plot_heatmap(csv_path: Path, output_path: Path) -> None: configure_fonts() senders, receivers, matrix = read_loss_matrix(csv_path) fig_width = max(8.2, 0.78 * len(receivers) + 2.6) fig_height = max(6.2, 0.68 * len(senders) + 2.4) fig, ax = plt.subplots(figsize=(fig_width, fig_height), constrained_layout=True) cmap = plt.get_cmap("YlOrRd").copy() cmap.set_bad(color="#eeeeee") vmax = max(0.1, float(matrix.max())) image = ax.imshow(matrix, cmap=cmap, vmin=0, vmax=vmax) ax.set_xticks(np.arange(len(receivers)), labels=receivers) ax.set_yticks(np.arange(len(senders)), labels=senders) ax.set_xlabel("接收端") ax.set_ylabel("发送端") # ax.set_title("公网链路平均丢包率") ax.xaxis.tick_top() ax.xaxis.set_label_position("top") ax.tick_params(axis="x", top=True, labeltop=True, bottom=False, labelbottom=False, pad=6) plt.setp(ax.get_xticklabels(), rotation=-35, ha="right", rotation_mode="anchor") for label in ax.get_xticklabels(): label.set_color(OVERSEAS_COLOR if label.get_text() in OVERSEAS_CITIES else DOMESTIC_COLOR) for label in ax.get_yticklabels(): label.set_color(OVERSEAS_COLOR if label.get_text() in OVERSEAS_CITIES else DOMESTIC_COLOR) legend_handles = [ Line2D([0], [0], marker="s", color="none", markerfacecolor=OVERSEAS_COLOR, markeredgecolor=OVERSEAS_COLOR, markersize=9, label="海外"), Line2D([0], [0], marker="s", color="none", markerfacecolor=DOMESTIC_COLOR, markeredgecolor=DOMESTIC_COLOR, markersize=9, label="国内"), ] ax.legend( handles=legend_handles, loc="lower center", bbox_to_anchor=(0.5, 1.2), ncol=2, frameon=False, columnspacing=1.4, handletextpad=0.4, ) for row_index in range(len(senders)): for col_index in range(len(receivers)): if matrix.mask[row_index, col_index]: text = "-" color = "#777777" else: value = float(matrix[row_index, col_index]) text = "0" if value == 0 else f"{value:.2f}".rstrip("0").rstrip(".") color = "white" if value > 0.55 * vmax else "#222222" ax.text(col_index, row_index, text, ha="center", va="center", color=color, fontsize=CELL_FONT_SIZE) sender_split = leading_group_size(senders, OVERSEAS_CITIES) receiver_split = leading_group_size(receivers, OVERSEAS_CITIES) if 0 < sender_split < len(senders) and 0 < receiver_split < len(receivers): ax.axhline(sender_split - 0.5, color="#303030", linewidth=2.2) ax.axvline(receiver_split - 0.5, color="#303030", linewidth=2.2) ax.set_xticks(np.arange(len(receivers) + 1) - 0.5, minor=True) ax.set_yticks(np.arange(len(senders) + 1) - 0.5, minor=True) ax.grid(which="minor", color="white", linewidth=1.2) ax.tick_params(which="minor", bottom=False, left=False) colorbar = fig.colorbar(image, ax=ax, shrink=0.88) colorbar.set_label("平均丢包率", fontsize=BASE_FONT_SIZE) colorbar.ax.tick_params(labelsize=BASE_FONT_SIZE) output_path.parent.mkdir(parents=True, exist_ok=True) fig.savefig(output_path, bbox_inches="tight") plt.close(fig) def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser(description="Plot average public-link loss as a heatmap.") parser.add_argument("--input", type=Path, default=Path("scripts/loss_avg.csv"), help="Path to loss CSV.") parser.add_argument( "--output", type=Path, default=Path("figures/loss_avg_heatmap.pdf"), help="Output figure path, such as figures/loss_avg_heatmap.pdf or .png.", ) return parser.parse_args() def main() -> None: args = parse_args() plot_heatmap(args.input, args.output) if __name__ == "__main__": main()