The code snippets below are modules I either developed from scratch during my internship or re-implemented from earlier papers.
Implementation of the CGN training scheme; the original GitHub repository is linked below.
Because my training scheme and code structure differ from the original, I pulled the logic out into standalone functions; in the original, data loading and training are implemented inside the model class.
https://github.com/KAIST-vilab/OC-CSE
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
import os
from util import pyutils
from module.dataloader import get_dataloader
from module.model import get_model
from module.optimizer import get_optimizer, get_scheduler
def model_freezing(model):
for param in model.parameters():
param.requires_grad = False
return model
def model_defreezing(model):
for param in model.parameters():
param.requires_grad = True
return model
def _work(model1, model2, args):
'''
model1, model2 must be pretrained
'''
    # model to be trained
    model1.to(args.device)
    model1.train()
    model1 = model_defreezing(model1)  # the trainable model must keep requires_grad=True
    # model that provides the guidance; kept frozen
model2.to(args.device)
model2.eval()
model2 = model_freezing(model2)
train_loader = get_dataloader(args) # load dataset
    args.max_iters = args.epochs * len(train_loader)  # set max_iters before building the optimizer/scheduler
optimizer = get_optimizer(args, model1, args.max_iters)
scheduler = get_scheduler(args, optimizer)
## printing utility
avg_meter = pyutils.AverageMeter('loss', 'cls_loss1', 'cls_loss2')
timer = pyutils.Timer("Session started: ")
for ep in range(args.epochs):
loss_weight = args.cc_init + args.cc_slope * ep
        loader_iter = iter(train_loader)  # build the iterator once per epoch
        for iteration in range(len(train_loader)):
            pack = next(loader_iter)  # next(iter(train_loader)) here would restart the loader every step
img = pack['img'].to(args.device)
label = pack['label'].to(args.device)
label_all = label.clone().to(args.device)
label_remain = label.clone().to(args.device)
label_mask = torch.zeros((label.shape)).to(args.device)
for i in range(label.shape[0]):
label_idx = torch.nonzero(label[i], as_tuple=False)
rand_idx = torch.randint(0, len(label_idx), (1,))
target = label_idx[rand_idx][0]
label_remain[i, target] = 0
label_mask[i, target] = 1
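            # After this loop, label_mask marks the single randomly chosen class to be erased
            # for each sample, while label_remain keeps the remaining ground-truth classes that
            # the frozen guidance model should still predict from the erased image.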
pred1 = model1(img)
            cam1 = model1.forward_cam(img).to(args.device)   # (b, c, h, w) class activation maps
            mask = cam1[label_mask == 1, :, :].unsqueeze(1)  # (b, 1, h, w): CAM of the class chosen for erasing
# mask = F.interpolate(mask, size=img.size()[2:], mode='bilinear', align_corners=False)
mask = F.interpolate(mask, size=(args.model2_crop_size, args.model2_crop_size),
mode='bilinear', align_corners=False)
mask = F.relu(mask)
mask = mask / (torch.max(mask) + 1e-5)
img = F.interpolate(img, size=(args.model2_crop_size, args.model2_crop_size),
mode='bilinear', align_corners=False)
# import pdb
# pdb.set_trace()
masked_img = img * (1 - mask)
pred2 = model2(masked_img)
cls_loss1 = F.binary_cross_entropy_with_logits(pred1, label_all)
cls_loss2 = loss_weight * F.binary_cross_entropy_with_logits(pred2, label_remain)
# cls_loss1 = F.multilabel_soft_margin_loss(pred1, label_all)
# cls_loss2 = loss_weight * F.multilabel_soft_margin_loss(pred2, label_remain)
total_loss = cls_loss1 + cls_loss2
optimizer.zero_grad()
total_loss.backward()
optimizer.step()
avg_meter.add({'loss': total_loss.item(), 'cls_loss1': cls_loss1.item(), 'cls_loss2': cls_loss2.item()})
if iteration % 100 == 0: # optimizer.global_step-1
timer.update_progress((ep * len(train_loader) + iteration + 1) / args.max_iters)
print('Iter:%5d/%5d' % (ep * len(train_loader) + iteration, args.max_iters),
'Total Loss:%.4f' % (avg_meter.pop('loss')),
'Loss1:%.4f' % (avg_meter.pop('cls_loss1')),
'Loss2:%.4f' % (avg_meter.pop('cls_loss2')),
'Rem:%s' % (timer.get_est_remain()),
'lr: %.4f' % (optimizer.param_groups[0]['lr']),
flush=True)
timer.reset_stage()
if scheduler is not None: scheduler.step()
def run_cgensemble(args):
model1, model2 = get_model(args)
_work(model1, model2, args)
torch.cuda.empty_cache()
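For reference, a minimal sketch of how this entry point could be invoked from a training script. Only the argument names the code above actually reads (device, epochs, cc_init, cc_slope, model2_crop_size) are taken from the snippet; the default values and the launcher itself are hypothetical.
# Hypothetical launcher for the training scheme above; defaults are placeholders.
import argparse

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--device', default='cuda')
    parser.add_argument('--epochs', type=int, default=10)
    parser.add_argument('--cc_init', type=float, default=0.3)    # initial weight of the guidance (erased-image) loss
    parser.add_argument('--cc_slope', type=float, default=0.05)  # per-epoch increase of that weight
    parser.add_argument('--model2_crop_size', type=int, default=384)
    args = parser.parse_args()
    # get_model / get_dataloader / get_optimizer consume further fields of args not shown here
    run_cgensemble(args)

if __name__ == '__main__':
    main()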
Random Activation Cropout
A layer, written to work with both ResNet and ViT, that crops out random regions of intermediate activation maps and attention maps. Its performance turned out to be underwhelming, so it was discarded.
import torch
import torch.nn as nn
import random
import math
class ActivationCropOut(nn.Module):
def __init__(self, prob=0.5, training=True, attention_model=False):
super(ActivationCropOut, self).__init__()
# self.do = random.random()
self.prob = prob
self.training = training
self.attention_model = attention_model
def forward(self, x):
do = random.random()
if do > self.prob or self.training is False:
return x
if self.attention_model:
class_token = x[:,0]
x = x[:,1:]
b,pos,hid = x.shape
width = int(math.sqrt(pos))
x = x.reshape(b, width, width, hid)
x = x.permute(0,3,1,2) # (b, hid, width, width)
b,c,h,w = x.shape
## ACL
rand_h = torch.randint(0, h, (1,))
rand_w = torch.randint(0, w, (1,))
min_h, max_h = rand_h, min(rand_h + 64, h)
min_w, max_w = rand_w, min(rand_w + 64, w)
        mask = torch.ones((h, w), dtype=torch.float, device=x.device)  # build the crop mask on the input's device
mask[min_h:max_h,min_w:max_w] = 0.
x = x * mask
if self.attention_model:
x = x.reshape(x.shape[0], x.shape[1], x.shape[2]**2)
x = x.permute(0,2,1)
x = torch.cat([class_token.unsqueeze(dim=1), x], dim=1)
return x
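A minimal usage sketch with illustrative shapes only.
# Illustrative only: ActivationCropOut on a CNN feature map and on ViT tokens.
feat = torch.randn(2, 256, 96, 96)                # (b, c, h, w) intermediate activation
crop_cnn = ActivationCropOut(prob=0.5, training=True, attention_model=False)
feat = crop_cnn(feat)                             # a random window (up to 64x64) is zeroed out

tokens = torch.randn(2, 1 + 24 * 24, 768)         # (b, 1 + num_patches, hidden), class token first
crop_vit = ActivationCropOut(prob=0.5, training=True, attention_model=True)
tokens = crop_vit(tokens)                         # same cropping on the reshaped patch grid, class token re-attached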
ADL: Attention-based Dropout Layer
Modified so that it works with both ResNet and ViT.
The original paper and GitHub repository are below.
https://github.com/junsukchoe/ADL
https://arxiv.org/abs/1908.10028
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
import math
class AttentionbasedDropout(nn.Module):
'''
extension version of ADL for attention based model
'''
def __init__(self, thres=0.90, training=True, attention_model=False):
super(AttentionbasedDropout, self).__init__()
# self.drop_or_importance = (random.random() > 0.5)
self.training = training
self.attention_model = attention_model
self.thres = thres
def forward(self, x):
if self.training is False:
return x
if self.attention_model: # reshape to 3D tensor
class_token = x[:,0]
x = x[:,1:]
b,pos,hid = x.shape
width = int(math.sqrt(pos))
x = x.reshape(b, width, width, hid)
x = x.permute(0,3,1,2) # (b, hid, width, width)
drop_or_importance = random.random()
b,c,h,w = x.shape
# Importance Map
attention = torch.mean(x, dim=1, keepdim=True) # [b,1,h,w]
importance_map = torch.sigmoid(attention) # [b,1,h,w]
# drop mask
        max_val = torch.amax(attention, dim=(1,2,3), keepdim=True)  # [b,1,1,1]
        thres_val = max_val * self.thres  # [b,1,1,1]
        drop_mask = (attention < thres_val).to(torch.float32)  # [b,1,h,w]
x = x * drop_mask if drop_or_importance < 0.75 else x * importance_map # (b,c,h,w)
if self.attention_model:
x = x.reshape(x.shape[0], x.shape[1], x.shape[2]**2) # (b, hid, width**2)
x = x.permute(0,2,1) # (b, width**2, hid)
x = torch.cat([class_token.unsqueeze(dim=1), x], dim=1) # (b, width**2 + 1, hid)
return x
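A minimal usage sketch for the extended ADL, again with illustrative shapes; it can sit after a ResNet stage (4-D feature map) or inside a ViT block (token sequence with a leading class token).
# Illustrative only: AttentionbasedDropout on a CNN feature map and on ViT tokens.
adl_cnn = AttentionbasedDropout(thres=0.90, training=True, attention_model=False)
feat = torch.randn(2, 512, 24, 24)
feat = adl_cnn(feat)      # multiplied by either the drop mask or the sigmoid importance map

adl_vit = AttentionbasedDropout(thres=0.90, training=True, attention_model=True)
tokens = torch.randn(2, 1 + 24 * 24, 768)
tokens = adl_vit(tokens)  # class token detached, patch grid processed, then re-attached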
A cross-entropy loss that applies a mask so that only object regions contribute weight to the loss.
import torch
import torch.nn as nn
import torch.nn.functional as F

class MaskingCELoss(nn.Module):
def __init__(self, ignore_index=255, mask_val=0.001):
super(MaskingCELoss, self).__init__()
self.ignore_index = ignore_index
self.mask_val = mask_val
def forward(self, inputs, targets):
device = inputs.device
b,c,h,w = inputs.shape
inputs = inputs.permute(0,2,3,1).reshape(-1, c)
targets = targets.reshape(-1)
rows = torch.arange(0,len(targets)).to(device, non_blocking=True)
logs = F.log_softmax(inputs, dim=-1)
# clearing ignore target
logs_mask = torch.ones_like(logs)
logs_mask[targets==self.ignore_index, :] = 0
targets_mask = torch.ones_like(targets)
targets_mask[targets==self.ignore_index] = 0
# clearing ignore target
logs = logs * logs_mask
targets = targets * targets_mask
# getting log likelihood
out = logs[rows, targets]
mask = torch.zeros_like(out)
mask[targets!=0] = self.mask_val
out = out * mask
return -out.sum()/len(out)
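A small sanity-check sketch with dummy predictions and targets: 255 marks ignored pixels, and label 0 (background) gets zero weight through the mask while every other label is weighted by mask_val.
# Illustrative only: MaskingCELoss on dummy segmentation logits.
criterion = MaskingCELoss(ignore_index=255, mask_val=0.001)
logits = torch.randn(2, 21, 8, 8)           # (b, num_classes, h, w)
targets = torch.randint(0, 21, (2, 8, 8))   # pixel-wise class indices
targets[:, 0, 0] = 255                      # mark a few pixels as ignored
loss = criterion(logits, targets)
print(loss.item())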
Since I needed to modify ViT, this is my modified version based on the torchvision implementation.
A method was added to compute Grad-CAM along the way, and there are traces of the ADL and ACL experiments.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import vit_b_16, ViT_B_16_Weights
import numpy as np
import math
from collections import OrderedDict
from functools import partial
from typing import Any, Callable, Dict, List, NamedTuple, Optional
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
import os
from module.vit_cam import reshape_transform
from network.attentionbased_dropout import AttentionbasedDropout
class MLP(torch.nn.Sequential):
"""This block implements the multi-layer perceptron (MLP) module.
Args:
in_channels (int): Number of channels of the input
hidden_channels (List[int]): List of the hidden channel dimensions
norm_layer (Callable[..., torch.nn.Module], optional): Norm layer that will be stacked on top of the linear layer. If ``None`` this layer won't be used. Default: ``None``
activation_layer (Callable[..., torch.nn.Module], optional): Activation function which will be stacked on top of the normalization layer (if not None), otherwise on top of the linear layer. If ``None`` this layer won't be used. Default: ``torch.nn.ReLU``
inplace (bool): Parameter for the activation layer, which can optionally do the operation in-place. Default ``True``
bias (bool): Whether to use bias in the linear layer. Default ``True``
dropout (float): The probability for the dropout layer. Default: 0.0
"""
def __init__(
self,
in_channels: int,
hidden_channels: List[int],
norm_layer: Optional[Callable[..., torch.nn.Module]] = None,
activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU,
inplace: Optional[bool] = True,
bias: bool = True,
dropout: float = 0.0,
):
# The addition of `norm_layer` is inspired from the implementation of TorchMultimodal:
# https://github.com/facebookresearch/multimodal/blob/5dec8a/torchmultimodal/modules/layers/mlp.py
params = {} if inplace is None else {"inplace": inplace}
layers = []
in_dim = in_channels
for hidden_dim in hidden_channels[:-1]:
layers.append(torch.nn.Linear(in_dim, hidden_dim, bias=bias))
if norm_layer is not None:
layers.append(norm_layer(hidden_dim))
layers.append(activation_layer(**params))
layers.append(torch.nn.Dropout(dropout, **params))
in_dim = hidden_dim
layers.append(torch.nn.Linear(in_dim, hidden_channels[-1], bias=bias))
layers.append(torch.nn.Dropout(dropout, **params))
super().__init__(*layers)
# _log_api_usage_once(self)
class MLPBlock(MLP):
"""Transformer MLP block."""
_version = 2
def __init__(self, in_dim: int, mlp_dim: int, dropout: float):
super().__init__(in_dim, [mlp_dim, in_dim], activation_layer=nn.GELU, inplace=None, dropout=dropout)
for m in self.modules():
if isinstance(m, nn.Linear):
nn.init.xavier_uniform_(m.weight)
if m.bias is not None:
nn.init.normal_(m.bias, std=1e-6)
def _load_from_state_dict(
self,
state_dict,
prefix,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
):
version = local_metadata.get("version", None)
if version is None or version < 2:
# Replacing legacy MLPBlock with MLP. See https://github.com/pytorch/vision/pull/6053
for i in range(2):
for type in ["weight", "bias"]:
old_key = f"{prefix}linear_{i+1}.{type}"
new_key = f"{prefix}{3*i}.{type}"
if old_key in state_dict:
state_dict[new_key] = state_dict.pop(old_key)
super()._load_from_state_dict(
state_dict,
prefix,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
)
class EncoderBlock(nn.Module):
"""Transformer encoder block."""
def __init__(
self,
num_heads: int,
hidden_dim: int,
mlp_dim: int,
dropout: float,
attention_dropout: float,
norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
):
super().__init__()
self.num_heads = num_heads
# Attention block
self.ln_1 = norm_layer(hidden_dim)
self.self_attention = nn.MultiheadAttention(hidden_dim, num_heads, dropout=attention_dropout, batch_first=True)
self.dropout = nn.Dropout(dropout)
# MLP block
self.ln_2 = norm_layer(hidden_dim)
self.mlp = MLPBlock(hidden_dim, mlp_dim, dropout)
def forward(self, input: torch.Tensor):
torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}")
x = self.ln_1(input)
x, _ = self.self_attention(x, x, x, need_weights=False)
x = self.dropout(x)
x = x + input
y = self.ln_2(x)
y = self.mlp(y)
return x + y
class EncoderBlock_ADL(nn.Module):
"""Transformer encoder block."""
def __init__(
self,
num_heads: int,
hidden_dim: int,
mlp_dim: int,
dropout: float,
attention_dropout: float,
norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
):
super().__init__()
self.num_heads = num_heads
# Attention block
self.ln_1 = norm_layer(hidden_dim)
self.self_attention = nn.MultiheadAttention(hidden_dim, num_heads, dropout=attention_dropout, batch_first=True)
self.dropout = nn.Dropout(dropout)
# MLP block
self.ln_2 = norm_layer(hidden_dim)
self.mlp = MLPBlock(hidden_dim, mlp_dim, dropout)
# activation crop
self.attbased_drop = AttentionbasedDropout(training=True, attention_model=True)
def forward(self, input: torch.Tensor):
torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}")
x = self.ln_1(input)
x, _ = self.self_attention(x, x, x, need_weights=False)
x = self.dropout(x)
x = x + input
y = self.ln_2(x)
y = self.mlp(y)
out = self.attbased_drop(x + y)
return out
class Encoder(nn.Module):
"""Transformer Model Encoder for sequence to sequence translation."""
def __init__(
self,
seq_length: int,
num_layers: int,
num_heads: int,
hidden_dim: int,
mlp_dim: int,
dropout: float,
attention_dropout: float,
norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
network_type: str = 'vit'
):
super().__init__()
# Note that batch_size is on the first dim because
# we have batch_first=True in nn.MultiAttention() by default
self.network_type = network_type
self.pos_embedding = nn.Parameter(torch.empty(1, seq_length, hidden_dim).normal_(std=0.02)) # from BERT
self.dropout = nn.Dropout(dropout)
layers: OrderedDict[str, nn.Module] = OrderedDict()
if 'vit_adl' in self.network_type:
for i in range(num_layers):
                if i in (8, 9, 10, 11):  # apply ADL only in the last four encoder blocks
layers[f"encoder_layer_{i}"] = EncoderBlock_ADL(
num_heads,
hidden_dim,
mlp_dim,
dropout,
attention_dropout,
norm_layer,
)
else:
layers[f"encoder_layer_{i}"] = EncoderBlock(
num_heads,
hidden_dim,
mlp_dim,
dropout,
attention_dropout,
norm_layer,
)
else:
for i in range(num_layers):
layers[f"encoder_layer_{i}"] = EncoderBlock(
num_heads,
hidden_dim,
mlp_dim,
dropout,
attention_dropout,
norm_layer,
)
self.layers = nn.Sequential(layers)
self.ln = norm_layer(hidden_dim)
def forward(self, input: torch.Tensor):
torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}")
input = input + self.pos_embedding
return self.ln(self.layers(self.dropout(input)))
class VisionTransformer(nn.Module):
"""Vision Transformer as per https://arxiv.org/abs/2010.11929."""
def __init__(
self,
image_size: int,
patch_size: int,
num_layers: int,
num_heads: int,
hidden_dim: int,
mlp_dim: int,
dropout: float = 0.0,
attention_dropout: float = 0.0,
num_classes: int = 1000,
representation_size: Optional[int] = None,
norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
use_cuda: bool =True,
network_type: str = 'vit'
):
super().__init__()
# _log_api_usage_once(self)
torch._assert(image_size % patch_size == 0, "Input shape indivisible by patch size!")
self.image_size = image_size
self.patch_size = patch_size
self.hidden_dim = hidden_dim
self.mlp_dim = mlp_dim
self.attention_dropout = attention_dropout
self.dropout = dropout
self.num_classes = num_classes
self.representation_size = representation_size
self.norm_layer = norm_layer
self.use_cuda = use_cuda
self.network_type = network_type
self.conv_proj = nn.Conv2d(
in_channels=3, out_channels=hidden_dim, kernel_size=patch_size, stride=patch_size
)
seq_length = (image_size // patch_size) ** 2
# Add a class token
self.class_token = nn.Parameter(torch.zeros(1, 1, hidden_dim))
seq_length += 1
        self.encoder = Encoder(
            seq_length,
            num_layers,
            num_heads,
            hidden_dim,
            mlp_dim,
            dropout,
            attention_dropout,
            norm_layer,
            network_type=network_type,  # propagate so 'vit_adl' actually builds the ADL encoder blocks
        )
self.seq_length = seq_length
heads_layers: OrderedDict[str, nn.Module] = OrderedDict()
if representation_size is None:
heads_layers["head"] = nn.Linear(hidden_dim, num_classes)
nn.init.xavier_uniform_(heads_layers["head"].weight)
else:
heads_layers["pre_logits"] = nn.Linear(hidden_dim, representation_size)
heads_layers["act"] = nn.Tanh()
heads_layers["head"] = nn.Linear(representation_size, num_classes)
self.heads = nn.Sequential(heads_layers)
if isinstance(self.conv_proj, nn.Conv2d):
# Init the patchify stem
fan_in = self.conv_proj.in_channels * self.conv_proj.kernel_size[0] * self.conv_proj.kernel_size[1]
nn.init.trunc_normal_(self.conv_proj.weight, std=math.sqrt(1 / fan_in))
if self.conv_proj.bias is not None:
nn.init.zeros_(self.conv_proj.bias)
elif self.conv_proj.conv_last is not None and isinstance(self.conv_proj.conv_last, nn.Conv2d):
# Init the last 1x1 conv of the conv stem
nn.init.normal_(
self.conv_proj.conv_last.weight, mean=0.0, std=math.sqrt(2.0 / self.conv_proj.conv_last.out_channels)
)
if self.conv_proj.conv_last.bias is not None:
nn.init.zeros_(self.conv_proj.conv_last.bias)
if hasattr(self.heads, "pre_logits") and isinstance(self.heads.pre_logits, nn.Linear):
fan_in = self.heads.pre_logits.in_features
nn.init.trunc_normal_(self.heads.pre_logits.weight, std=math.sqrt(1 / fan_in))
nn.init.zeros_(self.heads.pre_logits.bias)
if isinstance(self.heads.head, nn.Linear):
torch.nn.init.xavier_uniform_(self.heads.head.weight)
nn.init.zeros_(self.heads.head.bias)
# nn.init.zeros_(self.heads.head.weight)
# nn.init.zeros_(self.heads.head.bias)
self.not_training = [self.conv_proj]
def _process_input(self, x: torch.Tensor) -> torch.Tensor:
n, c, h, w = x.shape
p = self.patch_size
torch._assert(h == self.image_size, f"Wrong image height! Expected {self.image_size} but got {h}!")
torch._assert(w == self.image_size, f"Wrong image width! Expected {self.image_size} but got {w}!")
n_h = h // p
n_w = w // p
# (n, c, h, w) -> (n, hidden_dim, n_h, n_w)
x = self.conv_proj(x)
# (n, hidden_dim, n_h, n_w) -> (n, hidden_dim, (n_h * n_w))
x = x.reshape(n, self.hidden_dim, n_h * n_w)
# (n, hidden_dim, (n_h * n_w)) -> (n, (n_h * n_w), hidden_dim)
# The self attention layer expects inputs in the format (N, S, E)
# where S is the source sequence length, N is the batch size, E is the
# embedding dimension
x = x.permute(0, 2, 1)
return x
def forward(self, x: torch.Tensor):
# Reshape and permute the input tensor
x = self._process_input(x)
n = x.shape[0]
# Expand the class token to the full batch
batch_class_token = self.class_token.expand(n, -1, -1)
x = torch.cat([batch_class_token, x], dim=1)
x = self.encoder(x)
# Classifier "token" as used by standard language architectures
x = x[:, 0]
x = self.heads(x)
return x
def forward_cam(self, img):
target_layers = [self.encoder.layers.encoder_layer_11.ln_1]
# Construct the CAM object once, and then re-use it on many images:
cam = GradCAM(model=self, target_layers=target_layers, use_cuda=self.use_cuda,
reshape_transform=reshape_transform)
grayscale_cam_list_allbatch = []
for i in range(img.shape[0]): # for batch size
grayscale_cam_list = []
for idx in range(self.num_classes):
targets = [ClassifierOutputTarget(idx)]
# You can also pass aug_smooth=True and eigen_smooth=True, to apply smoothing.
grayscale_cam = cam(input_tensor=img[i].unsqueeze(dim=0), targets=targets,\
eigen_smooth=True, aug_smooth=True) # numpy.ndarray (1, 384, 384)
# In this example grayscale_cam has only one image in the batch:
grayscale_cam = grayscale_cam[0, :] # (384, 384)
grayscale_cam_list.append(grayscale_cam)
grayscale_cam_stacked = np.stack(grayscale_cam_list, axis=0) # (C, 384, 384)
grayscale_cam_stacked = torch.from_numpy(grayscale_cam_stacked).unsqueeze(dim=0) # (1, C, 384, 384)
grayscale_cam_list_allbatch.append(grayscale_cam_stacked)
grayscale_cam_stacked_allbatch = torch.cat(grayscale_cam_list_allbatch, dim=0) # (b,c,384,384)
return grayscale_cam_stacked_allbatch
def train(self, mode=True):
super().train(mode)
for layer in self.not_training:
if isinstance(layer, torch.nn.Conv2d):
layer.weight.requires_grad = False
elif isinstance(layer, torch.nn.Module):
try:
for c in layer.children():
c.weight.requires_grad = False
if c.bias is not None:
c.bias.requires_grad = False
except:
pass
def trainable_parameters(self):
self.backbone_param = list(self.parameters())[:-2]
self.newly_added_param = list(self.parameters())[-2:]
return (self.backbone_param, self.newly_added_param)
def vit(args):
    if args.network_type in ('vit', 'vit_cam',
                             'cgn', 'cgn2', 'cgn3',
                             'vit_cam_cgn', 'vit_cam_cgn2', 'vit_cam_cgn3',
                             'vit_ivr', 'vit_cam_ivr'):
model = VisionTransformer(
image_size=384,
patch_size=16,
num_layers=12,
num_heads=12,
hidden_dim=768,
mlp_dim=3072,
num_classes=args.num_classes,
use_cuda=args.use_cuda,
network_type = 'vit'
)
vit_statedict = torch.load(os.path.join(args.weight_root, 'vit_b_16_swag-9ac1b537.pth'))
vit_statedict.pop('heads.head.weight')
vit_statedict.pop('heads.head.bias')
model.load_state_dict(vit_statedict, strict=False)
    elif args.network_type in ('vit_adl', 'vit_cam_adl',
                               'cgn_adl', 'cgn2_adl', 'cgn3_adl',
                               'vit_cam_cgn_adl', 'vit_cam_cgn2_adl', 'vit_cam_cgn3_adl'):
model = VisionTransformer(
image_size=384,
patch_size=16,
num_layers=12,
num_heads=12,
hidden_dim=768,
mlp_dim=3072,
num_classes=args.num_classes,
use_cuda=args.use_cuda,
network_type = 'vit_adl'
)
vit_statedict = torch.load(os.path.join(args.weight_root, 'vit_b_16_swag-9ac1b537.pth'))
vit_statedict.pop('heads.head.weight')
vit_statedict.pop('heads.head.bias')
model.load_state_dict(vit_statedict, strict=False)
else:
raise NotImplementedError("No model in ViT!!")
return model
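A sketch of how the factory could be used; the args fields mirror what vit() reads above, the SWAG checkpoint is assumed to exist under weight_root, and everything else is a placeholder.
# Illustrative only: building the modified ViT and running a forward pass.
from types import SimpleNamespace

args = SimpleNamespace(network_type='vit_adl', num_classes=20,
                       use_cuda=True, weight_root='./weights')   # expects vit_b_16_swag-9ac1b537.pth here
model = vit(args).cuda().train()

x = torch.randn(2, 3, 384, 384).cuda()
logits = model(x)                 # (2, 20) multi-label logits
# cams = model.forward_cam(x)     # (2, 20, 384, 384); slow: one Grad-CAM pass per class and per image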
import numpy as np
import torch
import torch.nn.functional as F
import os
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam.ablation_layer import AblationLayerVit
import time
from tqdm import tqdm
from module.model import get_model
from module.dataloader import get_dataloader
from util import torchutils, imutils
def reshape_transform(tensor, crop_size=384):
height=crop_size // 16
width=crop_size // 16
# [1, 577, 768] -> [1, 24, 24, 768]
result = tensor[:, 1:, :].reshape(tensor.size(0), height, width, tensor.size(2))
# Bring the channels to the first dimension,
# like in CNNs.
result = result.transpose(2, 3).transpose(1, 2)
return result
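# Quick shape check (illustrative only): the token sequence, minus the class token, becomes a
# CNN-style map that pytorch_grad_cam can handle, e.g. with crop_size=384 an input of shape
# (1, 577, 768) comes back as (1, 768, 24, 24).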
def vit_grad_cam(model, data_loader, args):
# with torch.no_grad():
    # target_layers = [model.encoder.layers.encoder_layer_11.ln_1]  # for the torchvision-style ViT above
    target_layers = [model.blocks[11].norm1]  # for a timm-style ViT
st = time.time()
for iteration, pack in tqdm(enumerate(data_loader)):
img_name = pack['name'][0]
label = pack['label'][0] # one hot encoded
valid_cat = torch.nonzero(label)[:, 0] # nonzero label index for all batch
input_tensor = pack['img'] # Create an input tensor image for your model..
size = pack['size']
strided_size = imutils.get_strided_size(size, 4)
grayscale_cam_low_res_list = []
grayscale_cam_high_res_list = []
        # Construct the CAM object once per image and re-use it for every target class:
        cam = GradCAM(model=model, target_layers=target_layers, use_cuda=args.use_cuda,
                      reshape_transform=reshape_transform)
        for idx in range(len(valid_cat)):
            targets = [ClassifierOutputTarget(valid_cat[idx])]
# You can also pass aug_smooth=True and eigen_smooth=True, to apply smoothing.
            grayscale_cam = cam(input_tensor=input_tensor, targets=targets,
                                eigen_smooth=args.eigen_smooth, aug_smooth=args.aug_smooth)  # numpy.ndarray (1, 384, 384)
            # grayscale_cam has only one image in the batch:
            grayscale_cam = grayscale_cam[0, :]  # (384, 384)
grayscale_cam_low_res = np.asarray(F.interpolate(
torch.from_numpy(grayscale_cam).unsqueeze(dim=0).unsqueeze(dim=0), size=strided_size,
mode='bilinear').squeeze(dim=0)) # strided size
grayscale_cam_low_res_list.append(grayscale_cam_low_res)
grayscale_cam_high_res = np.asarray(F.interpolate(
torch.from_numpy(grayscale_cam).unsqueeze(dim=0).unsqueeze(dim=0), size=(size[0], size[1]),
mode='bilinear').squeeze(dim=0)) # to original size
grayscale_cam_high_res_list.append(grayscale_cam_high_res)
grayscale_cam_low_res_stacked = np.concatenate(grayscale_cam_low_res_list, axis=0)
grayscale_cam_high_res_stacked = np.concatenate(grayscale_cam_high_res_list, axis=0)
# print("vitCAM size")
# print(grayscale_cam_low_res_stacked.shape)
# print(grayscale_cam_high_res_stacked.shape)
np.save(os.path.join(args.pred_dir, img_name + '.npy'),
{"keys": valid_cat, "cam": grayscale_cam_low_res_stacked, "high_res": grayscale_cam_high_res_stacked})