# bachelor-thesis/scripts/plot_loss_heatmap.py
# Snapshot: 2026-05-20 14:33:40 +08:00 (174 lines, 6.0 KiB, Python)
from __future__ import annotations
import argparse
import csv
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import font_manager
from matplotlib.lines import Line2D
# Tick-label colours: blue marks overseas nodes, red marks domestic ones.
DOMESTIC_COLOR = "#EA4335"
OVERSEAS_COLOR = "#4285F4"
# Node labels shown as overseas (blue); any label not listed here is
# treated as domestic (red).
OVERSEAS_CITIES = set(("迪拜", "法兰克福", "雅加达", "开普敦"))
def leading_group_size(labels: list[str], group: set[str]) -> int:
    """Count how many labels at the start of ``labels`` belong to ``group``.

    Counting stops at the first label outside ``group``; members appearing
    after that point are ignored.

    Args:
        labels: Axis labels in display order.
        group: Labels that make up the leading group.

    Returns:
        Length of the longest prefix of ``labels`` whose items are all in
        ``group``: 0 when ``labels`` is empty or its first item is not a
        member, ``len(labels)`` when every item is a member.
    """
    for position, label in enumerate(labels):
        if label not in group:
            return position
    return len(labels)
def configure_fonts() -> None:
    """Configure Matplotlib so Chinese city names and labels render correctly.

    Picks the first installed font from a list of common CJK fonts and puts
    it at the front of ``font.sans-serif``. The previously configured fonts
    stay behind it as fallbacks rather than being discarded, so they remain
    available for glyphs the CJK font may lack. If none of the preferred
    fonts is installed, a warning is emitted instead of silently continuing
    (Chinese labels would otherwise render as empty boxes).

    Also disables the Unicode minus sign (a common workaround when CJK fonts
    are used) and sets the font sizes and DPI used for the figure.
    """
    import warnings

    preferred_fonts = [
        "Microsoft YaHei",
        "SimHei",
        "Noto Sans CJK SC",
        "Source Han Sans SC",
        "PingFang SC",
        "WenQuanYi Micro Hei",
        "Arial Unicode MS",
    ]
    installed_fonts = {font.name for font in font_manager.fontManager.ttflist}
    chosen = next((name for name in preferred_fonts if name in installed_fonts), None)
    if chosen is None:
        warnings.warn(
            "None of the preferred CJK fonts is installed; Chinese labels may "
            "render as empty boxes.",
            stacklevel=2,
        )
    else:
        # Prepend instead of replacing so the existing list remains a fallback.
        fallbacks = [name for name in plt.rcParams["font.sans-serif"] if name != chosen]
        plt.rcParams["font.sans-serif"] = [chosen, *fallbacks]
    # Render minus signs as ASCII hyphens.
    plt.rcParams["axes.unicode_minus"] = False
    plt.rcParams.update(
        {
            "font.size": 15,
            "axes.titlesize": 19,
            "axes.labelsize": 17,
            "xtick.labelsize": 15,
            "ytick.labelsize": 15,
            "figure.dpi": 150,
        }
    )
def read_loss_matrix(csv_path: Path) -> tuple[list[str], list[str], np.ma.MaskedArray]:
    """Load a sender-by-receiver loss matrix from a CSV file.

    Expected layout:

    - The first row is a header. Its first cell is a corner label (ignored)
      and its remaining cells are receiver names.
    - Every later row starts with a sender name followed by exactly one
      value per receiver.

    Labels and values are stripped of surrounding whitespace. Empty cells
    and the tokens ``N/A``, ``NA``, ``-`` and ``nan`` (case-insensitive) mark
    a missing measurement. Blank lines are skipped.

    Args:
        csv_path: Path to a UTF-8 CSV file (a leading BOM is accepted).

    Returns:
        ``(senders, receivers, matrix)``. ``matrix`` has shape
        ``(len(senders), len(receivers))``. Missing cells are masked and
        hold 0.0 as their underlying data. The mask is a full boolean array,
        so ``matrix.mask[row, col]`` can be indexed directly.

    Raises:
        ValueError: If the file has no header row, a data row does not have
            exactly one value per receiver, or a cell is neither a number
            nor a missing-value token.
    """
    missing_tokens = {"n/a", "na", "", "-", "nan"}
    senders: list[str] = []
    values: list[list[float]] = []
    mask: list[list[bool]] = []
    with csv_path.open("r", encoding="utf-8-sig", newline="") as file:
        reader = csv.reader(file)
        header = next(reader, None)
        if header is None:
            raise ValueError(f"{csv_path}: empty file, expected a header row")
        receivers = [name.strip() for name in header[1:]]
        for row in reader:
            # csv yields [] for blank lines (e.g. trailing newlines); skip
            # them instead of failing on row[0].
            if not any(cell.strip() for cell in row):
                continue
            cells = row[1:]
            if len(cells) != len(receivers):
                raise ValueError(
                    f"{csv_path}, line {reader.line_num}: expected "
                    f"{len(receivers)} values for sender {row[0].strip()!r}, "
                    f"got {len(cells)}"
                )
            value_row: list[float] = []
            mask_row: list[bool] = []
            for item in cells:
                token = item.strip()
                if token.lower() in missing_tokens:
                    # Placeholder data under the mask; never displayed.
                    value_row.append(0.0)
                    mask_row.append(True)
                    continue
                try:
                    value_row.append(float(token))
                except ValueError as err:
                    raise ValueError(
                        f"{csv_path}, line {reader.line_num}: {item!r} is not a number"
                    ) from err
                mask_row.append(False)
            senders.append(row[0].strip())
            values.append(value_row)
            mask.append(mask_row)
    # Explicit reshape keeps the result 2-D even when there are no data rows.
    shape = (len(senders), len(receivers))
    data = np.array(values, dtype=float).reshape(shape)
    mask_array = np.array(mask, dtype=bool).reshape(shape)
    return senders, receivers, np.ma.array(data, mask=mask_array)
def _color_tick_labels(ax: plt.Axes) -> None:
    """Colour x and y tick labels blue for overseas cities, red otherwise."""
    for label in [*ax.get_xticklabels(), *ax.get_yticklabels()]:
        label.set_color(OVERSEAS_COLOR if label.get_text() in OVERSEAS_CITIES else DOMESTIC_COLOR)


def _add_region_legend(ax: plt.Axes) -> None:
    """Add a two-entry overseas/domestic colour key centred above the axes."""
    legend_handles = [
        Line2D([0], [0], marker="s", color="none", markerfacecolor=OVERSEAS_COLOR, markeredgecolor=OVERSEAS_COLOR, markersize=9, label="海外"),
        Line2D([0], [0], marker="s", color="none", markerfacecolor=DOMESTIC_COLOR, markeredgecolor=DOMESTIC_COLOR, markersize=9, label="国内"),
    ]
    ax.legend(
        handles=legend_handles,
        loc="lower center",
        bbox_to_anchor=(0.5, 1.15),
        ncol=2,
        frameon=False,
        columnspacing=1.4,
        handletextpad=0.4,
    )


def _annotate_cells(ax: plt.Axes, matrix: np.ma.MaskedArray, vmax: float) -> None:
    """Write each cell's value at its centre, or "-" for a masked cell.

    Values use at most two decimals with trailing zeros trimmed. Text is
    white on cells whose value exceeds 55% of ``vmax`` and dark otherwise.
    """
    # getmaskarray always returns a full boolean array, even for nomask.
    missing = np.ma.getmaskarray(matrix)
    for row_index, col_index in np.ndindex(matrix.shape):
        if missing[row_index, col_index]:
            text, color = "-", "#777777"
        else:
            value = float(matrix.data[row_index, col_index])
            text = "0" if value == 0 else f"{value:.2f}".rstrip("0").rstrip(".")
            color = "white" if value > 0.55 * vmax else "#222222"
        ax.text(col_index, row_index, text, ha="center", va="center", color=color, fontsize=11)


def _draw_group_separators(ax: plt.Axes, senders: list[str], receivers: list[str]) -> None:
    """Draw thick lines after the leading overseas block on both axes.

    Lines are drawn only when both axes start with at least one overseas
    label and also contain at least one non-overseas label.
    """
    sender_split = leading_group_size(senders, OVERSEAS_CITIES)
    receiver_split = leading_group_size(receivers, OVERSEAS_CITIES)
    if 0 < sender_split < len(senders) and 0 < receiver_split < len(receivers):
        ax.axhline(sender_split - 0.5, color="#303030", linewidth=2.2)
        ax.axvline(receiver_split - 0.5, color="#303030", linewidth=2.2)


def _draw_cell_grid(ax: plt.Axes, n_rows: int, n_cols: int) -> None:
    """Outline every cell with white grid lines placed on minor ticks."""
    ax.set_xticks(np.arange(n_cols + 1) - 0.5, minor=True)
    ax.set_yticks(np.arange(n_rows + 1) - 0.5, minor=True)
    ax.grid(which="minor", color="white", linewidth=1.2)
    # xaxis.tick_top() enabled minor ticks on the top edge as well, so hide
    # minor tick marks on all four sides; only the grid should be visible.
    ax.tick_params(which="minor", top=False, bottom=False, left=False, right=False)


def plot_heatmap(csv_path: Path, output_path: Path) -> None:
    """Render the loss matrix stored in ``csv_path`` as an annotated heatmap.

    Layout and styling:

    - Rows are senders and columns are receivers, with receiver labels on
      the top edge.
    - Tick labels are coloured by region, and every cell shows its value
      ("-" when missing).
    - A thick line separates a leading block of overseas nodes from the
      rest when both axes start with one.

    The figure is saved to ``output_path``; Matplotlib picks the format from
    the suffix. Parent directories are created as needed. The figure is
    always closed, even if drawing or saving fails.

    Args:
        csv_path: Loss CSV in the layout read by ``read_loss_matrix``.
        output_path: Destination file, e.g. ``.pdf`` or ``.png``.
    """
    configure_fonts()
    senders, receivers, matrix = read_loss_matrix(csv_path)
    # Grow the figure with the number of cells, but never below a base size.
    fig_width = max(8.2, 0.78 * len(receivers) + 2.6)
    fig_height = max(6.2, 0.68 * len(senders) + 2.4)
    fig, ax = plt.subplots(figsize=(fig_width, fig_height), constrained_layout=True)
    try:
        cmap = plt.get_cmap("YlOrRd").copy()
        cmap.set_bad(color="#eeeeee")  # colour used for masked (missing) cells
        # Floor the colour range at 0.1. An all-masked matrix also falls back
        # to 0.1, without the masked-to-nan conversion warning.
        data_max = matrix.max()
        vmax = 0.1 if data_max is np.ma.masked else max(0.1, float(data_max))
        image = ax.imshow(matrix, cmap=cmap, vmin=0, vmax=vmax)

        ax.set_xticks(np.arange(len(receivers)), labels=receivers)
        ax.set_yticks(np.arange(len(senders)), labels=senders)
        ax.set_xlabel("接收端")
        ax.set_ylabel("发送端")
        # The axes title ("公网链路平均丢包率") is currently disabled.
        ax.xaxis.tick_top()
        ax.xaxis.set_label_position("top")
        ax.tick_params(axis="x", top=True, labeltop=True, bottom=False, labelbottom=False, pad=6)
        plt.setp(ax.get_xticklabels(), rotation=-35, ha="right", rotation_mode="anchor")

        _color_tick_labels(ax)
        _add_region_legend(ax)
        _annotate_cells(ax, matrix, vmax)
        _draw_group_separators(ax, senders, receivers)
        _draw_cell_grid(ax, len(senders), len(receivers))

        colorbar = fig.colorbar(image, ax=ax, shrink=0.88)
        colorbar.set_label("平均丢包率", fontsize=14)
        colorbar.ax.tick_params(labelsize=12)

        output_path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(output_path, bbox_inches="tight")
    finally:
        plt.close(fig)
def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the heatmap script.

    Returns:
        Namespace with ``input`` (loss CSV path) and ``output`` (figure
        path), both converted to ``Path``. The defaults are relative to the
        current working directory.
    """
    cli = argparse.ArgumentParser(
        description="Plot average public-link loss as a heatmap.",
    )
    cli.add_argument(
        "--input",
        type=Path,
        default=Path("scripts/loss_avg.csv"),
        help="Path to loss CSV.",
    )
    cli.add_argument(
        "--output",
        type=Path,
        default=Path("figures/loss_avg_heatmap.pdf"),
        help="Output figure path, such as figures/loss_avg_heatmap.pdf or .png.",
    )
    namespace = cli.parse_args()
    return namespace
def main() -> None:
    """Command-line entry point: read the options, then render the heatmap."""
    options = parse_args()
    plot_heatmap(csv_path=options.input, output_path=options.output)


if __name__ == "__main__":
    main()