Add Homework2.
This commit is contained in:
329
hw2/code/visualize.py
Normal file
329
hw2/code/visualize.py
Normal file
@@ -0,0 +1,329 @@
|
||||
# ========================================================
|
||||
# Media and Cognition
|
||||
# Homework 2 Convolutional Neural Network
|
||||
# visualize.py - Visualization
|
||||
# Student ID:
|
||||
# Name:
|
||||
# Tsinghua University
|
||||
# (C) Copyright 2024
|
||||
# ========================================================
|
||||
|
||||
import argparse
|
||||
import copy
|
||||
import os
|
||||
import string
|
||||
|
||||
import cv2
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchvision.transforms as transforms
|
||||
from sklearn.manifold import TSNE
|
||||
from torch.autograd import Variable
|
||||
from torchvision.datasets import ImageFolder
|
||||
from torchvision.utils import make_grid
|
||||
|
||||
from datasets import get_data_loader
|
||||
from networks import Classifier, ConvBlock
|
||||
|
||||
|
||||
class ConvFilterVisualization:
    """Visualize a convolution filter by optimizing an input image.

    Starting from random noise, the input image is updated with Adam so that
    the mean activation of one selected filter is maximized. The image is
    then enlarged and optimized again for several rounds.

    Args:
        model: indexable container of layers (e.g. ``model[layer_idx]``);
            the selected element must expose a ``conv`` sub-module.
        save_dir: directory under which ``layer_<k>/filter_<j>.jpg`` files
            are written.
    """

    def __init__(self, model, save_dir):
        self.model = model
        self.model.eval()
        self.save_dir = save_dir
        # Activation of the hooked filter, refreshed on every forward pass.
        self.conv_output = None

    def hook_layer(self, layer_idx, filter_idx):
        """Register a forward hook capturing one filter's output map.

        Args:
            layer_idx: index into ``self.model`` of the block to hook.
            filter_idx: output channel of that block's ``conv`` to capture.
        """

        def hook_function(module, input, output):
            # Keep only the selected filter's map for the first batch item.
            self.conv_output = output[0, filter_idx]

        self.hook = self.model[layer_idx].conv.register_forward_hook(hook_function)

    @staticmethod
    def _image_to_input(image):
        """Convert an HxWx3 image in [0, 255] to a 1x3xHxW tensor in [-1, 1].

        The axes are permuted rather than reshaped: a plain ``view`` would
        reinterpret the channel-last memory layout as channel-first and
        scramble the picture.
        """
        scaled = np.clip((image / 255 - 0.5) * 2, -1, 1)
        return torch.from_numpy(scaled).permute(2, 0, 1).unsqueeze(0).contiguous()

    def visualize(
        self,
        conv_layer_indices,
        layer_idx,
        filter_idx,
        opt_steps,
        upscaling_steps=4,
        upscaling_factor=1.2,
        blur=None,
    ):
        """Generate, save and return the image that excites one filter most.

        Args:
            conv_layer_indices: positions of the conv blocks inside the model.
            layer_idx: which entry of ``conv_layer_indices`` to visualize.
            filter_idx: output channel of the selected conv layer.
            opt_steps: number of Adam steps per upscaling round.
            upscaling_steps: number of optimize-then-enlarge rounds (>= 1).
            upscaling_factor: size multiplier applied after each round.
            blur: optional kernel size for a box blur of the final image.

        Returns:
            The final HxWx3 image divided by 255 (values are not clipped).

        Raises:
            ValueError: if ``upscaling_steps`` is smaller than 1.
        """
        if upscaling_steps < 1:
            raise ValueError("upscaling_steps must be at least 1")

        self.hook_layer(conv_layer_indices[layer_idx], filter_idx)
        try:
            im_size = 32
            x = torch.rand(1, 3, im_size, im_size) * 2 - 1
            image = None
            for _ in range(upscaling_steps):
                # Fresh leaf tensor so Adam can optimize the pixels directly.
                x = x.detach().requires_grad_(True)
                optimizer = torch.optim.Adam([x], lr=0.1, weight_decay=1e-6)
                for _ in range(opt_steps):
                    optimizer.zero_grad()
                    self.model(x)  # the hook stores the filter activation
                    loss = -self.conv_output.mean()  # maximize the activation
                    loss.backward()
                    optimizer.step()

                # Back to an HxWx3 image in [0, 255].
                image = (
                    255 * (x * 0.5 + 0.5).squeeze(0).permute(1, 2, 0).detach().numpy()
                )
                im_size = int(upscaling_factor * im_size)  # new image size
                upscaled = cv2.resize(
                    image, (im_size, im_size), interpolation=cv2.INTER_CUBIC
                )
                x = self._image_to_input(upscaled)

            # Only the image of the last round is kept, so blur it once here.
            if blur is not None:
                image = cv2.blur(image, (blur, blur))

            save_dir = os.path.join(self.save_dir, "layer_%d" % layer_idx)
            os.makedirs(save_dir, exist_ok=True)
            cv2.imwrite(
                os.path.join(save_dir, "filter_%d.jpg" % filter_idx),
                np.clip(image, 0, 255),
            )
        finally:
            # Always detach the hook, even if optimization or saving fails.
            self.hook.remove()
        return image / 255
|
||||
|
||||
|
||||
class ConvFeatureVisualization:
    """Plot and save the feature maps produced by one layer for one image.

    Args:
        model: indexable container of layers (e.g. ``model[layer_idx]``);
            the selected element must expose a ``relu`` sub-module.
        save_dir: directory under which ``layer_<k>/`` results are written.
    """

    def __init__(self, model, save_dir):
        self.model = model
        self.model.eval()
        self.save_dir = save_dir
        # Output of the hooked layer for the first batch item.
        self.conv_output = None

    def hook_layer(self, layer_idx):
        """Register a forward hook capturing the selected block's relu output.

        Args:
            layer_idx: index into ``self.model`` of the block to hook.
        """

        def hook_function(module, input, output):
            # Keep all channels of the first batch item.
            self.conv_output = output[0]

        self.hook = self.model[layer_idx].relu.register_forward_hook(hook_function)

    @staticmethod
    def _grid_rows(num_channels, columns):
        """Return the number of grid rows needed for ``num_channels`` cells."""
        # Ceiling division; at least one row so the figure is never empty.
        return max(1, -(-num_channels // columns))

    def visualize(self, conv_layer_indices, layer_idx, image):
        """Run ``image`` through the model and save every feature map.

        Each channel is min-max normalized, resized to 32x32, written to
        ``channel_<i>.jpg`` and drawn into a grid saved as
        ``feature_map.jpg`` inside ``<save_dir>/layer_<layer_idx>``.

        Args:
            conv_layer_indices: positions of the conv blocks inside the model.
            layer_idx: which entry of ``conv_layer_indices`` to visualize.
            image: input batch fed to the model; only the first item's
                features are captured by the hook.
        """
        self.hook_layer(conv_layer_indices[layer_idx])
        try:
            self.model(image)
        finally:
            # The activation is captured; detach the hook even on failure.
            self.hook.remove()

        save_dir = os.path.join(self.save_dir, "layer_%d" % layer_idx)
        os.makedirs(save_dir, exist_ok=True)

        feature_maps = self.conv_output.detach().numpy()
        num_channels = feature_maps.shape[0]
        w = 16
        h = self._grid_rows(num_channels, w)
        # squeeze=False keeps `axes` two-dimensional even for a single row.
        fig, axes = plt.subplots(h, w, figsize=(w / 1.6, h), squeeze=False)
        plt.suptitle("output feature map of layer %d" % layer_idx)
        for ax in axes.flat:
            ax.axis("off")  # also blanks grid cells without a channel

        for i in range(num_channels):
            x = feature_maps[i]
            low, high = x.min(), x.max()
            # Min-max normalize; a constant map becomes all zeros.
            x = (x - low) / (high - low) if high > low else x - low
            x = cv2.resize(x, (32, 32), interpolation=cv2.INTER_CUBIC)
            ax = axes[i // w, i % w]
            ax.imshow(x, cmap="rainbow")
            ax.set_title(str(i), fontsize="small")
            cv2.imwrite(os.path.join(save_dir, "channel_%d.jpg" % i), 255 * x)

        plt.tight_layout()
        plt.savefig(os.path.join(save_dir, "feature_map.jpg"), dpi=200)
        plt.show()
        print(
            "Results are saved as {}".format(os.path.join(save_dir, "feature_map.jpg"))
        )
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Command-line entry point: load a trained checkpoint and produce one of
    # four visualizations (conv filters, feature maps, t-SNE, STN effect).
    parser = argparse.ArgumentParser()
    # set configurations of the visualization process
    parser.add_argument("--path", type=str, default="data", help="path to data file")
    parser.add_argument(
        "--epoch", type=int, default=15, help="epoch of checkpoint you want to load"
    )
    parser.add_argument(
        "--ckpt_path", type=str, default="ckpt", help="path to load checkpoints"
    )
    parser.add_argument(
        "--type",
        type=str,
        default="filter",
        choices=["filter", "feature", "tsne", "stn"],
        help="type of visualized data, can be filter, feature, tsne and stn",
    )
    parser.add_argument(
        "--layer_idx",
        type=int,
        default=0,
        help="index of convolutional layer for visualizing filter and feature",
    )
    parser.add_argument(
        "--image_idx",
        type=int,
        default=128,
        help="index of images for visualizing feature",
    )
    parser.add_argument(
        "--save_dir",
        type=str,
        default="visualized/",
        help="directory to save visualization results",
    )

    opt = parser.parse_args()
    os.makedirs(opt.save_dir, exist_ok=True)

    ckpt_file = os.path.join(opt.ckpt_path, "ckpt_epoch_%d.pth" % opt.epoch)
    print("[Info] loading checkpoint from %s ..." % ckpt_file)
    # Everything below runs on the CPU (tensors are converted with .numpy()),
    # so map the checkpoint to CPU even if it was saved from a GPU.
    checkpoint = torch.load(ckpt_file, map_location="cpu")
    configs = checkpoint["configs"]
    model = Classifier(
        configs["in_channels"],
        configs["num_classes"],
        configs["use_batch_norm"],
        configs["use_stn"],
        configs["dropout_prob"],
    )
    model.load_state_dict(checkpoint["model_state"])
    model.eval()
    stn = model.stn
    conv_net = model.conv_net
    fc_net = model.fc_net

    if opt.type == "filter":
        filter_dir = os.path.join(opt.save_dir, "filter")
        os.makedirs(filter_dir, exist_ok=True)

        # Positions and filter counts of the ConvBlocks inside conv_net.
        conv_layer_indices = []
        filter_nums = []
        for i, m in enumerate(conv_net.children()):
            if isinstance(m, ConvBlock):
                conv_layer_indices.append(i)
                filter_nums.append(m.conv.out_channels)

        visual = ConvFilterVisualization(conv_net, filter_dir)
        num_filters = filter_nums[opt.layer_idx]
        w = 16
        # Ceiling division with at least one row, so layers whose filter
        # count is below or not a multiple of 16 still fit in the grid.
        h = max(1, -(-num_filters // w))
        # squeeze=False keeps `axes` two-dimensional even for a single row.
        fig, axes = plt.subplots(h, w, figsize=(w / 1.6, h), squeeze=False)
        plt.suptitle("conv filters of layer %d" % opt.layer_idx)
        for ax in axes.flat:
            ax.axis("off")  # also blanks grid cells without a filter
        for i in range(num_filters):
            x = visual.visualize(conv_layer_indices, opt.layer_idx, i, 30, blur=None)
            ax = axes[i // w, i % w]
            # Only the first channel of the generated image is displayed.
            ax.imshow(x[:, :, 0], cmap="rainbow")
            ax.set_title(str(i), fontsize="small")
        plt.tight_layout()
        result_file = os.path.join(
            opt.save_dir, "filter", "filter_layer_%d.jpg" % opt.layer_idx
        )
        plt.savefig(result_file, dpi=200)
        plt.show()
        print("Results are saved as {}".format(result_file))

    elif opt.type == "feature":
        feature_dir = os.path.join(opt.save_dir, "feature")
        os.makedirs(feature_dir, exist_ok=True)

        conv_layer_indices = [
            i for i, m in enumerate(conv_net.children()) if isinstance(m, ConvBlock)
        ]

        visual = ConvFeatureVisualization(conv_net, feature_dir)
        transform = transforms.Compose(
            [
                transforms.Resize((32, 32)),
                transforms.ToTensor(),
                transforms.Normalize(0.5, 0.5),
            ]
        )

        dataset = ImageFolder(os.path.join(opt.path, "train"), transform=transform)
        image, _ = dataset[opt.image_idx]
        # Undo the normalization and save the input image for reference.
        image_out = 255 * (image / 2 + 0.5).permute(1, 2, 0).detach().numpy()
        # Swap the channel order before handing the array to cv2.imwrite.
        image_out = cv2.cvtColor(image_out, cv2.COLOR_BGR2RGB)
        cv2.imwrite(os.path.join(feature_dir, "image.jpg"), image_out)
        visual.visualize(conv_layer_indices, opt.layer_idx, image.unsqueeze(0))

    elif opt.type == "tsne":
        tsne_dir = os.path.join(opt.save_dir, "tsne")
        os.makedirs(tsne_dir, exist_ok=True)

        data_loader = get_data_loader(
            opt.path, "train", image_size=(32, 32), batch_size=8
        )
        max_points = 800  # only the first 800 samples are embedded
        labels = []
        features = []
        collected = 0
        with torch.no_grad():
            for x, y in data_loader:
                x, y = x.float(), y.long()
                x = stn(x)
                x = conv_net(x)
                x = x.contiguous().view(x.shape[0], -1)
                # Features are taken after the first two modules of fc_net.
                x = fc_net[0](x)
                x = fc_net[1](x)
                features.append(x)
                labels.append(y)
                collected += x.shape[0]
                if collected >= max_points:
                    # Later batches would be discarded anyway; stop early.
                    break
        features = torch.cat(features, dim=0)
        labels = torch.cat(labels, dim=0)
        Y = TSNE(
            n_components=2, init="pca", random_state=0, learning_rate="auto"
        ).fit_transform(features[:max_points].numpy())
        labels = labels[:max_points].numpy()

        # The last 26 characters of ascii_letters are the uppercase letters.
        letters = list(string.ascii_letters[-26:])
        # Rescale both embedding axes to [0, 1].
        Y = (Y - Y.min(0)) / (Y.max(0) - Y.min(0))
        for point, label in zip(Y, labels):
            c = plt.cm.rainbow(float(label) / 26)
            plt.text(point[0], point[1], s=letters[label], color=c)
        plt.savefig(os.path.join(tsne_dir, "tsne.jpg"), dpi=300)
        plt.show()
        print("Results are saved as {}".format(os.path.join(tsne_dir, "tsne.jpg")))

    else:
        stn_dir = os.path.join(opt.save_dir, "stn")
        os.makedirs(stn_dir, exist_ok=True)

        data_loader = get_data_loader(
            opt.path, "train", image_size=(32, 32), batch_size=16
        )
        with torch.no_grad():
            # A single batch is enough to compare input and STN output.
            x, _ = next(iter(data_loader))
            x_transformed = stn(x)

        # Map from [-1, 1] back to [0, 1] and tile each batch into one image.
        img_original = make_grid((x + 1) / 2).cpu().numpy().transpose(1, 2, 0)
        img_transformed = (
            make_grid((x_transformed + 1) / 2).cpu().numpy().transpose(1, 2, 0)
        )
        fig, axes = plt.subplots(2, 1, figsize=(6, 4))
        plt.suptitle("The Effect of the Spatial Transformer Network")
        axes[0].imshow(img_original)
        axes[0].set_title("original")
        axes[0].axis("off")
        axes[1].imshow(img_transformed)
        axes[1].set_title("transformed")
        axes[1].axis("off")
        plt.tight_layout()
        result_file = os.path.join(opt.save_dir, "stn", "stn.jpg")
        plt.savefig(result_file, dpi=200)
        plt.show()
        print("Results are saved as {}".format(result_file))
|
||||
Reference in New Issue
Block a user