# Source file: bachelor-thesis/scripts/plot_thpt.py
# Repository snapshot: 2026-05-21 01:56:53 +08:00 (164 lines, 5.2 KiB, Python)
#
# Repository viewer notice: this file contains Unicode characters that might be
# confused with other characters. If that is intentional, the warning can be
# safely ignored; the viewer's "Escape" button reveals the characters.
from __future__ import annotations
import argparse
import csv
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import font_manager
# Base text size; every other size below is expressed relative to it.
BASE_FONT_SIZE = 25
# Applied to "axes.titlesize" in configure_fonts().
TITLE_FONT_SIZE = BASE_FONT_SIZE + 1
# Applied to "axes.labelsize" in configure_fonts().
AXIS_LABEL_FONT_SIZE = BASE_FONT_SIZE
# Applied to both "xtick.labelsize" and "ytick.labelsize".
TICK_FONT_SIZE = BASE_FONT_SIZE - 2
# Applied to "legend.fontsize".
LEGEND_FONT_SIZE = BASE_FONT_SIZE - 2
# Size of the value annotations drawn above the bars in plot_speedup().
BAR_LABEL_FONT_SIZE = BASE_FONT_SIZE - 4
# Largest loss value kept by filter_results(); loss is a fraction (0.02 is
# rendered as "2%" by format_loss_label()).
MAX_LOSS = 0.02
def configure_fonts() -> None:
    """Configure Matplotlib's global font family and text sizes for the figures.

    The first family from a fixed preference list that is present in
    Matplotlib's font manager becomes the only ``font.sans-serif`` entry.
    When none of them is installed, that setting is left as it was.
    The shared text sizes and a figure DPI of 150 are applied in every case.
    """
    available = {entry.name for entry in font_manager.fontManager.ttflist}
    candidates = (
        "Microsoft YaHei",
        "SimHei",
        "Noto Sans CJK SC",
        "Source Han Sans SC",
        "PingFang SC",
        "WenQuanYi Micro Hei",
        "Arial Unicode MS",
    )

    # Earlier entries win: pick the first candidate that is actually installed.
    chosen = next((name for name in candidates if name in available), None)
    if chosen is not None:
        plt.rcParams["font.sans-serif"] = [chosen]

    plt.rcParams.update(
        {
            "axes.unicode_minus": False,
            "font.size": BASE_FONT_SIZE,
            "axes.titlesize": TITLE_FONT_SIZE,
            "axes.labelsize": AXIS_LABEL_FONT_SIZE,
            "xtick.labelsize": TICK_FONT_SIZE,
            "ytick.labelsize": TICK_FONT_SIZE,
            "legend.fontsize": LEGEND_FONT_SIZE,
            "figure.dpi": 150,
        }
    )
def read_throughput(csv_path: Path) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Load the throughput measurements from a CSV file.

    The file must have a header row containing the columns ``loss``,
    ``cubic_thpt_mbps`` and ``fec_thpt_mbps``; every value is parsed with
    ``float``. The file is decoded as ``utf-8-sig``, so a leading byte-order
    mark is accepted.

    Args:
        csv_path: Location of the CSV file.

    Returns:
        Three arrays ``(losses, cubic, fec)`` with one entry per data row,
        in file order.

    Raises:
        KeyError: A required column is absent from the header.
        ValueError: A cell cannot be parsed as a float.
    """
    fields = ("loss", "cubic_thpt_mbps", "fec_thpt_mbps")
    columns: dict[str, list[float]] = {name: [] for name in fields}

    with csv_path.open("r", encoding="utf-8-sig", newline="") as handle:
        for record in csv.DictReader(handle):
            for name in fields:
                columns[name].append(float(record[name]))

    loss_values, cubic_values, fec_values = (np.array(columns[name]) for name in fields)
    return loss_values, cubic_values, fec_values
def format_loss_label(loss: float) -> str:
    """Render a fractional loss rate as a tick label.

    Zero is shown as a bare ``"0"``; any other value is scaled to a
    percentage and formatted with ``:g`` (no trailing zeros), e.g.
    ``0.005`` becomes ``"0.5%"``.
    """
    return "0" if loss == 0 else f"{loss * 100:g}%"
def filter_results(
    losses: np.ndarray,
    cubic: np.ndarray,
    fec: np.ndarray,
    max_loss: float | None = None,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Keep only the measurements whose loss rate does not exceed a cut-off.

    Args:
        losses: Loss rates, one per measurement.
        cubic: Baseline throughput values, aligned with ``losses``.
        fec: Throughput values of the compared method, aligned with ``losses``.
        max_loss: Inclusive upper bound on the loss rate. ``None`` (the
            default) selects the module-level ``MAX_LOSS``, which matches the
            previous fixed behaviour.

    Returns:
        The three input arrays restricted to the entries where
        ``losses <= max_loss``, in their original order.
    """
    threshold = MAX_LOSS if max_loss is None else max_loss
    # One boolean mask is applied to all three arrays so they stay aligned.
    keep = losses <= threshold
    return losses[keep], cubic[keep], fec[keep]
def plot_absolute_throughput(losses: np.ndarray, cubic: np.ndarray, fec: np.ndarray, output_path: Path) -> None:
    """Save a line chart of both throughput series against the loss rate.

    The loss values are used directly as x positions, and each one gets a
    tick labelled by ``format_loss_label``. The parent directory of
    ``output_path`` is created if needed, and the figure is closed after
    it has been written.

    Args:
        losses: Loss rates used as x coordinates.
        cubic: Throughput of the first series, aligned with ``losses``.
        fec: Throughput of the second series, aligned with ``losses``.
        output_path: Destination file for the rendered figure.
    """
    figure, axes = plt.subplots(figsize=(5.6, 5.0), constrained_layout=True)

    # (values, marker, legend label, colour) for each curve, in draw order.
    series = (
        (cubic, "o", "直接转发", '#4285F4'),
        (fec, "s", "本文方法", '#EA4335'),
    )
    for values, marker, legend_label, colour in series:
        axes.plot(
            losses,
            values,
            marker=marker,
            linewidth=2.4,
            markersize=7,
            label=legend_label,
            color=colour,
        )

    axes.set_xlabel("丢包率")
    axes.set_ylabel("吞吐量Mbps")

    tick_labels = [format_loss_label(value) for value in losses]
    axes.set_xticks(losses, tick_labels)
    # Slant the tick labels downwards to the right, anchored at their left edge.
    plt.setp(axes.get_xticklabels(), rotation=-35, ha="left", rotation_mode="anchor")

    # Horizontal grid lines are slightly more prominent than vertical ones.
    for axis_name, opacity in (("y", 0.35), ("x", 0.22)):
        axes.grid(axis=axis_name, linestyle="--", alpha=opacity)
    axes.legend(frameon=False)

    output_path.parent.mkdir(parents=True, exist_ok=True)
    figure.savefig(output_path, bbox_inches="tight")
    plt.close(figure)
def plot_speedup(losses: np.ndarray, cubic: np.ndarray, fec: np.ndarray, output_path: Path) -> None:
    """Save a bar chart of the ``fec / cubic`` throughput ratio per loss rate.

    Bars are placed at evenly spaced category positions (not at the numeric
    loss values) and labelled with ``format_loss_label``. A dashed horizontal
    line marks a ratio of 1.0, and each bar is annotated with its value.
    The parent directory of ``output_path`` is created if needed, and the
    figure is closed after it has been written.

    Entries whose baseline throughput is zero have no defined ratio; they are
    left without a bar or annotation instead of producing an infinite bar.

    Args:
        losses: Loss rates, one per bar.
        cubic: Baseline throughput (the denominator), aligned with ``losses``.
        fec: Compared throughput (the numerator), aligned with ``losses``.
        output_path: Destination file for the rendered figure.
    """
    baseline = np.asarray(cubic, dtype=float)
    improved = np.asarray(fec, dtype=float)

    # Divide only where the baseline is non-zero; the remaining slots keep
    # their NaN fill so no division-by-zero inf/NaN reaches the axis limits.
    speedup = np.full_like(baseline, np.nan)
    np.divide(improved, baseline, out=speedup, where=baseline != 0)

    labels = [format_loss_label(loss) for loss in losses]
    category_x = np.arange(len(labels))

    fig, ax = plt.subplots(figsize=(5.6, 5.0), constrained_layout=True)
    bars = ax.bar(category_x, speedup, color="#4285F4", width=0.62)
    ax.axhline(1.0, color="#444444", linewidth=1.2, linestyle="--")
    ax.set_xlabel("丢包率")
    ax.set_ylabel("相对吞吐提升")
    ax.set_xticks(category_x, labels)
    ax.grid(axis="y", linestyle="--", alpha=0.35)
    # Extra headroom so the value annotations above the bars are not clipped.
    ax.set_ymargin(0.1)

    for bar, value in zip(bars, speedup):
        if not np.isfinite(value):
            # Undefined ratio: there is no meaningful number to print.
            continue
        ax.text(
            bar.get_x() + bar.get_width() / 2,
            bar.get_height(),
            f"{value:.1f}x",
            ha="center",
            va="bottom",
            fontsize=BAR_LABEL_FONT_SIZE,
        )

    output_path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(output_path, bbox_inches="tight")
    plt.close(fig)
def plot_throughput(csv_path: Path, absolute_output_path: Path, speedup_output_path: Path) -> None:
    """Produce both throughput figures from one CSV file.

    Fonts are configured first. The measurements are then read and filtered
    by loss rate, and the same filtered data feeds both the
    absolute-throughput chart and the speedup chart.

    Args:
        csv_path: Input CSV with the throughput measurements.
        absolute_output_path: Destination of the absolute throughput figure.
        speedup_output_path: Destination of the relative speedup figure.
    """
    configure_fonts()
    measurements = filter_results(*read_throughput(csv_path))
    plot_absolute_throughput(*measurements, absolute_output_path)
    plot_speedup(*measurements, speedup_output_path)
def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the script.

    Returns:
        A namespace with ``input``, ``absolute_output`` and ``speedup_output``,
        each a ``Path``. Every option has a default, so none is required.
    """
    parser = argparse.ArgumentParser(description="Plot throughput comparison under different loss rates.")
    parser.add_argument("--input", type=Path, default=Path("scripts/thpt.csv"), help="Path to throughput CSV.")

    # (flag, default location, help text) for the two figure outputs.
    output_options = (
        (
            "--absolute-output",
            "figures/thpt_absolute.pdf",
            "Output path for the absolute throughput figure.",
        ),
        (
            "--speedup-output",
            "figures/thpt_speedup.pdf",
            "Output path for the relative speedup figure.",
        ),
    )
    for flag, default_location, help_text in output_options:
        parser.add_argument(flag, type=Path, default=Path(default_location), help=help_text)

    return parser.parse_args()
def main() -> None:
    """Script entry point: parse the CLI options and render both figures."""
    options = parse_args()
    plot_throughput(
        csv_path=options.input,
        absolute_output_path=options.absolute_output,
        speedup_output_path=options.speedup_output,
    )
# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()