TA update.

This commit is contained in:
unlockable
2024-04-09 00:34:23 +08:00
parent 1a180f3c89
commit bdb985ddb3

View File

@@ -22,6 +22,9 @@ def test_data_loader():
from datasets import get_data_loader
train_loader = get_data_loader("data", "train", (32, 32), 8, 0, True)
num_classes = len(train_loader.dataset.classes)
assert num_classes == 26, f"Expected 26 classes, got {num_classes}."
images, labels = next(iter(train_loader))
# print labels
print(" ".join(chr(65 + x) for x in labels))
@@ -70,5 +73,3 @@ if __name__ == "__main__":
args = parser.parse_args()
main(args.unit)
main(args.unit)