TA update.
This commit is contained in:
@@ -22,6 +22,9 @@ def test_data_loader():
|
||||
from datasets import get_data_loader
|
||||
|
||||
train_loader = get_data_loader("data", "train", (32, 32), 8, 0, True)
|
||||
num_classes = len(train_loader.dataset.classes)
|
||||
assert num_classes == 26, f"Expected 26 classes, got {num_classes}."
|
||||
|
||||
images, labels = next(iter(train_loader))
|
||||
# print labels
|
||||
print(" ".join(chr(65 + x) for x in labels))
|
||||
@@ -70,5 +73,3 @@ if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
|
||||
main(args.unit)
|
||||
|
||||
main(args.unit)
|
||||
|
||||
Reference in New Issue
Block a user