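"""Plot average public-link packet loss between cities as an annotated heatmap.

Reads a loss matrix from CSV (header row = receivers, first column = senders)
and writes the figure to the given output path (e.g. PDF or PNG).
"""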
from __future__ import annotations
|
|
|
|
import argparse
|
|
import csv
|
|
from pathlib import Path
|
|
|
|
import matplotlib.pyplot as plt
|
|
import numpy as np
|
|
from matplotlib import font_manager
|
|
from matplotlib.lines import Line2D
|
|
|
|
|
|
# City labels rendered in the "overseas" colour; any other label counts as domestic.
OVERSEAS_CITIES = {
    "迪拜",
    "法兰克福",
    "雅加达",
    "开普敦",
}

# Tick-label and legend colours for the two regions.
OVERSEAS_COLOR = "#4285F4"
DOMESTIC_COLOR = "#EA4335"
|
|
|
|
|
|
def leading_group_size(labels: list[str], group: set[str]) -> int:
    """Return the length of the longest prefix of *labels* whose items are all in *group*.

    Counting stops at the first label outside *group*; members of *group*
    appearing after that point are not counted.
    """
    for position, label in enumerate(labels):
        if label not in group:
            return position
    # Every label belongs to the group (or there are none).
    return len(labels)
|
|
|
|
|
|
def configure_fonts() -> None:
    """Configure the global Matplotlib rcParams used by the heatmap.

    Puts the first installed font from a list of common CJK-capable fonts at
    the front of ``font.sans-serif`` so Chinese city names render correctly,
    disables the Unicode minus sign, and sets font sizes and figure DPI.

    Emits a ``UserWarning`` when none of the preferred fonts is installed,
    because Chinese labels will then typically render as empty boxes.

    Side effect: mutates the process-wide ``matplotlib.pyplot.rcParams``.
    """
    import warnings

    preferred_fonts = [
        "Microsoft YaHei",
        "SimHei",
        "Noto Sans CJK SC",
        "Source Han Sans SC",
        "PingFang SC",
        "WenQuanYi Micro Hei",
        "Arial Unicode MS",
    ]
    installed_fonts = {entry.name for entry in font_manager.fontManager.ttflist}
    chosen = next((name for name in preferred_fonts if name in installed_fonts), None)

    if chosen is None:
        warnings.warn(
            "No CJK font found among " + ", ".join(preferred_fonts)
            + "; Chinese labels may not render correctly.",
            stacklevel=2,
        )
    else:
        # Prepend instead of replacing, so the previously configured
        # sans-serif candidates stay in the list at lower priority.
        others = [name for name in plt.rcParams["font.sans-serif"] if name != chosen]
        plt.rcParams["font.sans-serif"] = [chosen, *others]

    # Use the ASCII hyphen-minus for negative numbers instead of U+2212.
    plt.rcParams["axes.unicode_minus"] = False
    plt.rcParams.update(
        {
            "font.size": 15,
            "axes.titlesize": 19,
            "axes.labelsize": 17,
            "xtick.labelsize": 15,
            "ytick.labelsize": 15,
            "figure.dpi": 150,
        }
    )
|
|
|
|
|
|
def read_loss_matrix(csv_path: Path) -> tuple[list[str], list[str], np.ma.MaskedArray]:
    """Load a sender x receiver loss matrix from a CSV file.

    Layout: the header row's first cell is ignored and its remaining cells are
    receiver labels; every following row is a sender label followed by one
    value per receiver. The file is read as UTF-8 with an optional BOM.

    A cell is treated as missing (masked) when, after stripping whitespace and
    lower-casing, it is "n/a", "na" or empty, or when it parses to NaN/inf.
    Completely blank rows (e.g. a trailing empty line) are skipped. Labels are
    stripped of surrounding whitespace.

    Returns:
        ``(senders, receivers, matrix)`` where ``matrix`` is a masked float
        array of shape ``(len(senders), len(receivers))``; missing cells are
        masked and hold 0.0 in the underlying data.

    Raises:
        ValueError: if the file is empty, a data row has the wrong number of
            cells, or a cell is neither a missing marker nor a number.
    """
    missing_tokens = {"n/a", "na", ""}

    with csv_path.open("r", encoding="utf-8-sig", newline="") as file:
        reader = csv.reader(file)
        header = next(reader, None)
        if header is None:
            raise ValueError(f"{csv_path}: file is empty, expected a header row")
        receivers = [name.strip() for name in header[1:]]
        expected_cells = len(receivers) + 1

        senders: list[str] = []
        values: list[list[float]] = []
        mask: list[list[bool]] = []

        for row in reader:
            # csv.reader yields [] for blank lines; indexing row[0] would fail.
            if not any(cell.strip() for cell in row):
                continue
            if len(row) != expected_cells:
                raise ValueError(
                    f"{csv_path}, line {reader.line_num}: expected {expected_cells} "
                    f"cells, got {len(row)}"
                )
            senders.append(row[0].strip())

            value_row: list[float] = []
            mask_row: list[bool] = []
            for item in row[1:]:
                if item.strip().lower() in missing_tokens:
                    value_row.append(0.0)
                    mask_row.append(True)
                    continue
                try:
                    value = float(item)
                except ValueError as err:
                    raise ValueError(
                        f"{csv_path}, line {reader.line_num}: invalid value {item!r}"
                    ) from err
                # NaN/inf would otherwise be plotted and printed as "nan"/"inf".
                if np.isfinite(value):
                    value_row.append(value)
                    mask_row.append(False)
                else:
                    value_row.append(0.0)
                    mask_row.append(True)
            values.append(value_row)
            mask.append(mask_row)

    # Explicit reshape keeps the result 2-D even when there are no data rows.
    data = np.array(values, dtype=float).reshape(len(senders), len(receivers))
    missing = np.array(mask, dtype=bool).reshape(data.shape)
    return senders, receivers, np.ma.array(data, mask=missing)
|
|
|
|
|
|
def plot_heatmap(csv_path: Path, output_path: Path) -> None:
    """Render the loss matrix in *csv_path* as an annotated heatmap and save it.

    Receivers are laid out along the top x-axis and senders along the y-axis.
    Tick labels are coloured by region (overseas vs domestic) and every cell is
    annotated with its value, or "-" when the value is missing. When both axes
    start with a run of overseas cities followed by domestic ones, dark lines
    separate the two regions. The figure is written to *output_path*; its
    parent directories are created if needed.
    """
    configure_fonts()
    senders, receivers, matrix = read_loss_matrix(csv_path)

    # Grow the canvas with the number of cells, but never below a minimum size.
    fig_width = max(8.2, 0.78 * len(receivers) + 2.6)
    fig_height = max(6.2, 0.68 * len(senders) + 2.4)
    fig, ax = plt.subplots(figsize=(fig_width, fig_height), constrained_layout=True)
    try:
        cmap = plt.get_cmap("YlOrRd").copy()
        cmap.set_bad(color="#eeeeee")  # colour used for masked (missing) cells
        # matrix.max() is the masked constant when every cell is missing.
        data_max = float(matrix.max()) if matrix.count() else 0.0
        vmax = max(0.1, data_max)
        image = ax.imshow(matrix, cmap=cmap, vmin=0, vmax=vmax)

        _style_axes(ax, senders, receivers)
        _add_region_legend(ax)
        _annotate_cells(ax, matrix, vmax)
        _draw_region_split(ax, senders, receivers)
        _draw_cell_grid(ax, len(senders), len(receivers))

        colorbar = fig.colorbar(image, ax=ax, shrink=0.88)
        colorbar.set_label("平均丢包率", fontsize=14)
        colorbar.ax.tick_params(labelsize=12)

        output_path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(output_path, bbox_inches="tight")
    finally:
        # Release the figure even if drawing or saving fails.
        plt.close(fig)


def _region_color(label: str) -> str:
    """Return the overseas or domestic colour for a city label."""
    return OVERSEAS_COLOR if label in OVERSEAS_CITIES else DOMESTIC_COLOR


def _style_axes(ax: plt.Axes, senders: list[str], receivers: list[str]) -> None:
    """Set tick labels, put the receiver axis on top, and colour labels by region."""
    ax.set_xticks(np.arange(len(receivers)), labels=receivers)
    ax.set_yticks(np.arange(len(senders)), labels=senders)
    ax.set_xlabel("接收端")
    ax.set_ylabel("发送端")
    # Title is currently disabled: ax.set_title("公网链路平均丢包率")
    ax.xaxis.tick_top()
    ax.xaxis.set_label_position("top")
    ax.tick_params(axis="x", top=True, labeltop=True, bottom=False, labelbottom=False, pad=6)
    plt.setp(ax.get_xticklabels(), rotation=-35, ha="right", rotation_mode="anchor")
    for label in (*ax.get_xticklabels(), *ax.get_yticklabels()):
        label.set_color(_region_color(label.get_text()))


def _add_region_legend(ax: plt.Axes) -> None:
    """Add a two-entry (overseas / domestic) colour legend above the axes."""
    handles = [
        Line2D(
            [0],
            [0],
            marker="s",
            color="none",
            markerfacecolor=color,
            markeredgecolor=color,
            markersize=9,
            label=label,
        )
        for color, label in ((OVERSEAS_COLOR, "海外"), (DOMESTIC_COLOR, "国内"))
    ]
    ax.legend(
        handles=handles,
        loc="lower center",
        bbox_to_anchor=(0.5, 1.15),
        ncol=2,
        frameon=False,
        columnspacing=1.4,
        handletextpad=0.4,
    )


def _format_cell(value: float) -> str:
    """Format a value with at most two decimals, dropping trailing zeros."""
    if value == 0:
        return "0"
    return f"{value:.2f}".rstrip("0").rstrip(".")


def _annotate_cells(ax: plt.Axes, matrix: np.ma.MaskedArray, vmax: float) -> None:
    """Write each cell's value at its centre, or "-" for missing cells."""
    # getmaskarray always returns a full boolean array, even when the mask is nomask.
    missing = np.ma.getmaskarray(matrix)
    for (row, col), is_missing in np.ndenumerate(missing):
        if is_missing:
            text, color = "-", "#777777"
        else:
            value = float(matrix.data[row, col])
            text = _format_cell(value)
            # Light text on dark (high-loss) cells keeps the number readable.
            color = "white" if value > 0.55 * vmax else "#222222"
        ax.text(col, row, text, ha="center", va="center", color=color, fontsize=11)


def _draw_region_split(ax: plt.Axes, senders: list[str], receivers: list[str]) -> None:
    """Draw separator lines between the leading overseas block and the rest.

    Lines are drawn only when both axes begin with at least one overseas city
    and also contain at least one label after that leading run.
    """
    sender_split = leading_group_size(senders, OVERSEAS_CITIES)
    receiver_split = leading_group_size(receivers, OVERSEAS_CITIES)
    if 0 < sender_split < len(senders) and 0 < receiver_split < len(receivers):
        ax.axhline(sender_split - 0.5, color="#303030", linewidth=2.2)
        ax.axvline(receiver_split - 0.5, color="#303030", linewidth=2.2)


def _draw_cell_grid(ax: plt.Axes, n_rows: int, n_cols: int) -> None:
    """Draw thin white lines between cells using minor ticks at cell edges."""
    ax.set_xticks(np.arange(n_cols + 1) - 0.5, minor=True)
    ax.set_yticks(np.arange(n_rows + 1) - 0.5, minor=True)
    ax.grid(which="minor", color="white", linewidth=1.2)
    # xaxis.tick_top() enabled top ticks for both major and minor ticks, so the
    # minor tick marks must be hidden on the top edge too, not just bottom/left.
    ax.tick_params(which="minor", top=False, bottom=False, left=False, right=False)
|
|
|
|
|
|
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the command-line options for the heatmap script.

    Args:
        argv: Arguments to parse, excluding the program name. ``None`` (the
            default) parses ``sys.argv[1:]``, as before.

    Returns:
        Namespace with ``input`` (CSV path) and ``output`` (figure path),
        both as ``pathlib.Path``.
    """
    parser = argparse.ArgumentParser(description="Plot average public-link loss as a heatmap.")
    parser.add_argument(
        "--input",
        type=Path,
        default=Path("scripts/loss_avg.csv"),
        help="Path to loss CSV.",
    )
    parser.add_argument(
        "--output",
        type=Path,
        default=Path("figures/loss_avg_heatmap.pdf"),
        help="Output figure path, such as figures/loss_avg_heatmap.pdf or .png.",
    )
    return parser.parse_args(argv)
|
|
|
|
|
|
def main() -> None:
    """Command-line entry point: read the options and render the heatmap."""
    options = parse_args()
    plot_heatmap(csv_path=options.input, output_path=options.output)
|
|
|
|
|
|
# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
|