
[PyTorch] A Record of Developing & Implementing New PyTorch Modules

LiDARian 2023. 2. 15. 16:22

The code snippets below are modules I developed from scratch during my internship or re-implemented from earlier papers.


Implementation of the CGN training scheme. The original GitHub repository and paper are linked below.

My training scheme and code structure differ from the original, so I extracted the logic into standalone functions; in the original repository, data loading and training are implemented inside the model class itself.

https://github.com/KAIST-vilab/OC-CSE

 


 

import torch
import torch.nn as nn
import torch.nn.functional as F
import random
import os

from util import pyutils
from module.dataloader import get_dataloader
from module.model import get_model
from module.optimizer import get_optimizer, get_scheduler

def model_freezing(model):
    for param in model.parameters():
        param.requires_grad = False
    return model


def model_defreezing(model):
    for param in model.parameters():
        param.requires_grad = True
    return model


def _work(model1, model2, args):
    '''
    model1, model2 must be pretrained
    '''
    # model1: the network being trained
    model1.to(args.device)
    model1.train()
    model1 = model_defreezing(model1)  # ensure all of model1's parameters receive gradients

    # model2: provides class-specific guidance and stays frozen
    model2.to(args.device)
    model2.eval()
    model2 = model_freezing(model2)

    train_loader = get_dataloader(args) # load dataset
    args.max_iters = args.epochs * len(train_loader)  # max_iters must be set before the optimizer/scheduler are created
    optimizer = get_optimizer(args, model1, args.max_iters)
    scheduler = get_scheduler(args, optimizer)

    ## printing utility
    avg_meter = pyutils.AverageMeter('loss', 'cls_loss1', 'cls_loss2')
    timer = pyutils.Timer("Session started: ")
    
    for ep in range(args.epochs):
        
        loss_weight = args.cc_init + args.cc_slope * ep
        
        loader_iter = iter(train_loader)  # fresh iterator each epoch
        for iteration in range(len(train_loader)):
            
            pack = next(loader_iter)  # re-creating the iterator every step would always return the first batch
            img = pack['img'].to(args.device)
            label = pack['label'].to(args.device)
            
            label_all = label.clone().to(args.device)
            label_remain = label.clone().to(args.device)
            label_mask = torch.zeros((label.shape)).to(args.device)
            for i in range(label.shape[0]):
                label_idx = torch.nonzero(label[i], as_tuple=False)
                rand_idx = torch.randint(0, len(label_idx), (1,))
                target = label_idx[rand_idx][0]
                label_remain[i, target] = 0
                label_mask[i, target] = 1

            
            pred1 = model1(img)
            cam1 = model1.forward_cam(img).to(args.device) # (b,c,h,w)
            
            mask = cam1[label_mask == 1, :, :].unsqueeze(1) # (b,1,h,w): CAM of the randomly selected class for each image
            # mask = F.interpolate(mask, size=img.size()[2:], mode='bilinear', align_corners=False)
 
            mask = F.interpolate(mask, size=(args.model2_crop_size, args.model2_crop_size),
                                 mode='bilinear', align_corners=False)
            mask = F.relu(mask)
            mask = mask / (torch.max(mask) + 1e-5)
            img = F.interpolate(img, size=(args.model2_crop_size, args.model2_crop_size),
                                 mode='bilinear', align_corners=False)
            
            masked_img = img * (1 - mask)
            
            pred2 = model2(masked_img)
            
            cls_loss1 = F.binary_cross_entropy_with_logits(pred1, label_all)
            cls_loss2 = loss_weight * F.binary_cross_entropy_with_logits(pred2, label_remain)
            # cls_loss1 = F.multilabel_soft_margin_loss(pred1, label_all)
            # cls_loss2 = loss_weight * F.multilabel_soft_margin_loss(pred2, label_remain)
            
            total_loss = cls_loss1 + cls_loss2
            
            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()
            
            avg_meter.add({'loss': total_loss.item(), 'cls_loss1': cls_loss1.item(), 'cls_loss2': cls_loss2.item()})
            if iteration % 100 == 0: # optimizer.global_step-1
                timer.update_progress((ep * len(train_loader) + iteration + 1) / args.max_iters)

                print('Iter:%5d/%5d' % (ep * len(train_loader) + iteration, args.max_iters),
                    'Total Loss:%.4f' % (avg_meter.pop('loss')),
                    'Loss1:%.4f' % (avg_meter.pop('cls_loss1')),
                    'Loss2:%.4f' % (avg_meter.pop('cls_loss2')),
                    'Rem:%s' % (timer.get_est_remain()),
                    'lr: %.4f' % (optimizer.param_groups[0]['lr']), 
                    flush=True)
            timer.reset_stage()
        
        if scheduler is not None: scheduler.step()
    

def run_cgensemble(args):
    model1, model2 = get_model(args)
    _work(model1, model2, args)
    torch.cuda.empty_cache()
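
For reference, here is a minimal sketch of how run_cgensemble might be driven, assuming an argparse-style args object. Only the fields read directly in _work above (epochs, cc_init, cc_slope, model2_crop_size, device) are shown; the project's own get_dataloader, get_model, and get_optimizer will expect additional fields, and the default values here are placeholders, not the original settings.

import argparse
import torch

def build_args():
    # Hypothetical argument builder; defaults are placeholders, not the original hyperparameters.
    parser = argparse.ArgumentParser()
    parser.add_argument('--epochs', type=int, default=10)
    parser.add_argument('--cc_init', type=float, default=0.3)    # initial weight on cls_loss2
    parser.add_argument('--cc_slope', type=float, default=0.05)  # per-epoch increase of that weight
    parser.add_argument('--model2_crop_size', type=int, default=224)
    args = parser.parse_args([])
    args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
    return args

if __name__ == '__main__':
    run_cgensemble(build_args())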

Random Activation Cropout

This layer works with both ResNet and ViT; it crops out (zeroes) a random window of an intermediate activation map or attention map. The performance turned out worse than expected, so I discarded it.

import torch
import torch.nn as nn
import random
import math

class ActivationCropOut(nn.Module):
    def __init__(self, prob=0.5, training=True, attention_model=False):
        super(ActivationCropOut, self).__init__()
        # self.do = random.random()
        self.prob = prob
        self.training = training
        self.attention_model = attention_model

    def forward(self, x):
        do = random.random()
        if do > self.prob or self.training is False:
            return x
        if self.attention_model:
            class_token = x[:,0]
            x = x[:,1:]
            b,pos,hid = x.shape
            width = int(math.sqrt(pos))
            x = x.reshape(b, width, width, hid)
            x = x.permute(0,3,1,2) # (b, hid, width, width)
            
        b,c,h,w = x.shape
        ## ACL
        rand_h = torch.randint(0, h, (1,))
        rand_w = torch.randint(0, w, (1,))
        min_h, max_h = rand_h, min(rand_h + 64, h)
        min_w, max_w = rand_w, min(rand_w + 64, w)
        mask = torch.ones((h, w), dtype=torch.float, device=x.device)  # build the mask on the input's device
        mask[min_h:max_h, min_w:max_w] = 0.
        
        x = x * mask

        if self.attention_model:
            x = x.reshape(x.shape[0], x.shape[1], x.shape[2]**2)
            x = x.permute(0,2,1)
            x = torch.cat([class_token.unsqueeze(dim=1), x], dim=1)
            
        return x
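
A quick smoke test I would use to check the layer on both a CNN feature map and a ViT token sequence; the shapes below are illustrative assumptions (the 64-pixel crop window is hardcoded in the layer above).

import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# CNN-style feature map: (b, c, h, w)
layer_cnn = ActivationCropOut(prob=0.5, training=True, attention_model=False)
feat = torch.randn(2, 256, 48, 48, device=device)
print(layer_cnn(feat).shape)        # torch.Size([2, 256, 48, 48]); a random window may be zeroed

# ViT-style token sequence: (b, 1 + num_patches, hidden_dim), class token first
layer_vit = ActivationCropOut(prob=0.5, training=True, attention_model=True)
tokens = torch.randn(2, 1 + 24 * 24, 768, device=device)
print(layer_vit(tokens).shape)      # torch.Size([2, 577, 768])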

ADL: Attention-based Dropout Layer

I modified it to work with both ResNet and ViT.

The original paper and GitHub repository are linked below.

https://github.com/junsukchoe/ADL

 


https://arxiv.org/abs/1908.10028

 


import torch
import torch.nn as nn
import torch.nn.functional as F
import random
import math

class AttentionbasedDropout(nn.Module):
    '''
    extension version of ADL for attention based model
    '''
    def __init__(self, thres=0.90, training=True, attention_model=False):
        super(AttentionbasedDropout, self).__init__()
        # self.drop_or_importance = (random.random() > 0.5)
        self.training = training
        self.attention_model = attention_model
        self.thres = thres

    def forward(self, x):
        if self.training is False:
            return x
        
        if self.attention_model: # reshape to 3D tensor
            class_token = x[:,0]
            x = x[:,1:]
            b,pos,hid = x.shape
            width = int(math.sqrt(pos))
            x = x.reshape(b, width, width, hid)
            x = x.permute(0,3,1,2) # (b, hid, width, width)
            
        drop_or_importance = random.random()
        b,c,h,w = x.shape
        
        # Importance Map
        attention = torch.mean(x, dim=1, keepdim=True) # [b,1,h,w]
        importance_map = torch.sigmoid(attention) # [b,1,h,w]
        
        # drop mask
        max_val = torch.amax(attention, dim=(1,2,3), keepdim=True) # [b,1,1,1]
        thres_val = max_val * self.thres # [b,1,1,1]
        drop_mask = (attention < thres_val).to(torch.float32) # [b,1,h,w]
        
        x = x * drop_mask if drop_or_importance < 0.75 else x * importance_map # (b,c,h,w): drop with prob 0.75, else reweight by the importance map

        if self.attention_model:
            x = x.reshape(x.shape[0], x.shape[1], x.shape[2]**2) # (b, hid, width**2)
            x = x.permute(0,2,1) # (b, width**2, hid)
            x = torch.cat([class_token.unsqueeze(dim=1), x], dim=1) # (b, width**2 + 1, hid)
            
        return x
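
A similar smoke test for the ADL layer; the shapes are illustrative assumptions. With attention_model=True the class token is split off, the patch tokens are reshaped into a square map, dropped or reweighted, and then reassembled.

import torch

# ViT-style token tensor: (b, 1 + num_patches, hidden_dim)
adl_vit = AttentionbasedDropout(thres=0.90, training=True, attention_model=True)
tokens = torch.randn(4, 1 + 24 * 24, 768)
print(adl_vit(tokens).shape)        # torch.Size([4, 577, 768])

# CNN-style feature map: (b, c, h, w) is handled directly
adl_cnn = AttentionbasedDropout(thres=0.90, training=True, attention_model=False)
print(adl_cnn(torch.randn(4, 512, 28, 28)).shape)   # torch.Size([4, 512, 28, 28])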

 

A cross-entropy loss that applies a mask so that only object (non-background) regions contribute to the loss.

import torch
import torch.nn as nn
import torch.nn.functional as F

class MaskingCELoss(nn.Module):
    def __init__(self, ignore_index=255, mask_val=0.001):
        super(MaskingCELoss, self).__init__()
        self.ignore_index = ignore_index
        self.mask_val = mask_val
        
    def forward(self, inputs, targets):
        device = inputs.device
        b,c,h,w = inputs.shape
                
        inputs = inputs.permute(0,2,3,1).reshape(-1, c)
        targets = targets.reshape(-1)
        rows = torch.arange(0,len(targets)).to(device, non_blocking=True)
        logs = F.log_softmax(inputs, dim=-1)
        
        # clearing ignore target
        logs_mask = torch.ones_like(logs)
        logs_mask[targets==self.ignore_index, :] = 0
        targets_mask = torch.ones_like(targets)
        targets_mask[targets==self.ignore_index] = 0
                
        # clearing ignore target
        logs = logs * logs_mask
        targets = targets * targets_mask

        # getting log likelihood
        out = logs[rows, targets]
        
        mask = torch.zeros_like(out)
        mask[targets!=0] = self.mask_val
        
        out = out * mask
        
        return -out.sum()/len(out)
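
A short check of the loss on dummy segmentation logits; the class count, spatial size, and the 255 ignore value are illustrative assumptions.

import torch

criterion = MaskingCELoss(ignore_index=255, mask_val=0.001)
logits = torch.randn(2, 21, 64, 64)               # (b, num_classes, h, w)
targets = torch.randint(0, 21, (2, 64, 64))       # (b, h, w) class indices, 0 = background
targets[:, :4, :4] = 255                          # a few ignored pixels
loss = criterion(logits, targets)                 # only non-background, non-ignored pixels contribute
print(loss.item())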

 

I needed a modified ViT, so the code below is a version adapted from the torchvision implementation.

A forward_cam function for extracting Grad-CAM has been added, and the ADL and ACL experiments described above have left their traces in the code.

import torch
import torch.nn as nn
import torch.nn.functional as F

from torchvision.models import vit_b_16, ViT_B_16_Weights
import numpy as np
import math
from collections import OrderedDict
from functools import partial
from typing import Any, Callable, Dict, List, NamedTuple, Optional
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
import os

from module.vit_cam import reshape_transform

from network.attentionbased_dropout import AttentionbasedDropout

class MLP(torch.nn.Sequential):
    """This block implements the multi-layer perceptron (MLP) module.
    Args:
        in_channels (int): Number of channels of the input
        hidden_channels (List[int]): List of the hidden channel dimensions
        norm_layer (Callable[..., torch.nn.Module], optional): Norm layer that will be stacked on top of the linear layer. If ``None`` this layer won't be used. Default: ``None``
        activation_layer (Callable[..., torch.nn.Module], optional): Activation function which will be stacked on top of the normalization layer (if not None), otherwise on top of the linear layer. If ``None`` this layer won't be used. Default: ``torch.nn.ReLU``
        inplace (bool): Parameter for the activation layer, which can optionally do the operation in-place. Default ``True``
        bias (bool): Whether to use bias in the linear layer. Default ``True``
        dropout (float): The probability for the dropout layer. Default: 0.0
    """

    def __init__(
        self,
        in_channels: int,
        hidden_channels: List[int],
        norm_layer: Optional[Callable[..., torch.nn.Module]] = None,
        activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU,
        inplace: Optional[bool] = True,
        bias: bool = True,
        dropout: float = 0.0,
    ):
        # The addition of `norm_layer` is inspired from the implementation of TorchMultimodal:
        # https://github.com/facebookresearch/multimodal/blob/5dec8a/torchmultimodal/modules/layers/mlp.py
        params = {} if inplace is None else {"inplace": inplace}

        layers = []
        in_dim = in_channels
        for hidden_dim in hidden_channels[:-1]:
            layers.append(torch.nn.Linear(in_dim, hidden_dim, bias=bias))
            if norm_layer is not None:
                layers.append(norm_layer(hidden_dim))
            layers.append(activation_layer(**params))
            layers.append(torch.nn.Dropout(dropout, **params))
            in_dim = hidden_dim

        layers.append(torch.nn.Linear(in_dim, hidden_channels[-1], bias=bias))
        layers.append(torch.nn.Dropout(dropout, **params))

        super().__init__(*layers)
        # _log_api_usage_once(self)


class MLPBlock(MLP):
    """Transformer MLP block."""

    _version = 2

    def __init__(self, in_dim: int, mlp_dim: int, dropout: float):
        super().__init__(in_dim, [mlp_dim, in_dim], activation_layer=nn.GELU, inplace=None, dropout=dropout)

        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.normal_(m.bias, std=1e-6)

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        version = local_metadata.get("version", None)

        if version is None or version < 2:
            # Replacing legacy MLPBlock with MLP. See https://github.com/pytorch/vision/pull/6053
            for i in range(2):
                for type in ["weight", "bias"]:
                    old_key = f"{prefix}linear_{i+1}.{type}"
                    new_key = f"{prefix}{3*i}.{type}"
                    if old_key in state_dict:
                        state_dict[new_key] = state_dict.pop(old_key)

        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )


class EncoderBlock(nn.Module):
    """Transformer encoder block."""

    def __init__(
        self,
        num_heads: int,
        hidden_dim: int,
        mlp_dim: int,
        dropout: float,
        attention_dropout: float,
        norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
    ):
        super().__init__()
        self.num_heads = num_heads

        # Attention block
        self.ln_1 = norm_layer(hidden_dim)
        self.self_attention = nn.MultiheadAttention(hidden_dim, num_heads, dropout=attention_dropout, batch_first=True)
        self.dropout = nn.Dropout(dropout)

        # MLP block
        self.ln_2 = norm_layer(hidden_dim)
        self.mlp = MLPBlock(hidden_dim, mlp_dim, dropout)

    def forward(self, input: torch.Tensor):
        torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}")
        x = self.ln_1(input)
        x, _ = self.self_attention(x, x, x, need_weights=False)
        x = self.dropout(x)
        x = x + input

        y = self.ln_2(x)
        y = self.mlp(y)
        return x + y

    
class EncoderBlock_ADL(nn.Module):
    """Transformer encoder block."""

    def __init__(
        self,
        num_heads: int,
        hidden_dim: int,
        mlp_dim: int,
        dropout: float,
        attention_dropout: float,
        norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
    ):
        super().__init__()
        self.num_heads = num_heads

        # Attention block
        self.ln_1 = norm_layer(hidden_dim)
        self.self_attention = nn.MultiheadAttention(hidden_dim, num_heads, dropout=attention_dropout, batch_first=True)
        self.dropout = nn.Dropout(dropout)

        # MLP block
        self.ln_2 = norm_layer(hidden_dim)
        self.mlp = MLPBlock(hidden_dim, mlp_dim, dropout)
        
        # activation crop
        self.attbased_drop = AttentionbasedDropout(training=True, attention_model=True)

    def forward(self, input: torch.Tensor):
        torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}")
        x = self.ln_1(input)
        x, _ = self.self_attention(x, x, x, need_weights=False)
        x = self.dropout(x)
        x = x + input

        y = self.ln_2(x)
        y = self.mlp(y)
        
        out = self.attbased_drop(x + y)
        return out
    

class Encoder(nn.Module):
    """Transformer Model Encoder for sequence to sequence translation."""

    def __init__(
        self,
        seq_length: int,
        num_layers: int,
        num_heads: int,
        hidden_dim: int,
        mlp_dim: int,
        dropout: float,
        attention_dropout: float,
        norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
        network_type: str = 'vit'
    ):
        super().__init__()
        # Note that batch_size is on the first dim because
        # we have batch_first=True in nn.MultiAttention() by default
        self.network_type = network_type
        self.pos_embedding = nn.Parameter(torch.empty(1, seq_length, hidden_dim).normal_(std=0.02))  # from BERT
        self.dropout = nn.Dropout(dropout)
        layers: OrderedDict[str, nn.Module] = OrderedDict()
        

        if 'vit_adl' in self.network_type:
            for i in range(num_layers):
                if i in (8, 9, 10, 11):  # use the ADL-augmented block for the last four encoder layers
                    layers[f"encoder_layer_{i}"] = EncoderBlock_ADL(
                        num_heads,
                        hidden_dim,
                        mlp_dim,
                        dropout,
                        attention_dropout,
                        norm_layer,
                    )
                else:
                    layers[f"encoder_layer_{i}"] = EncoderBlock(
                        num_heads,
                        hidden_dim,
                        mlp_dim,
                        dropout,
                        attention_dropout,
                        norm_layer,
                    )
        else:
            for i in range(num_layers):
                layers[f"encoder_layer_{i}"] = EncoderBlock(
                    num_heads,
                    hidden_dim,
                    mlp_dim,
                    dropout,
                    attention_dropout,
                    norm_layer,
                )
        self.layers = nn.Sequential(layers)
        self.ln = norm_layer(hidden_dim)

    def forward(self, input: torch.Tensor):
        torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}")
        input = input + self.pos_embedding
        return self.ln(self.layers(self.dropout(input)))


class VisionTransformer(nn.Module):
    """Vision Transformer as per https://arxiv.org/abs/2010.11929."""

    def __init__(
        self,
        image_size: int,
        patch_size: int,
        num_layers: int,
        num_heads: int,
        hidden_dim: int,
        mlp_dim: int,
        dropout: float = 0.0,
        attention_dropout: float = 0.0,
        num_classes: int = 1000,
        representation_size: Optional[int] = None,
        norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
        use_cuda: bool = True,
        network_type: str = 'vit'
    ):
        super().__init__()
        # _log_api_usage_once(self)
        torch._assert(image_size % patch_size == 0, "Input shape indivisible by patch size!")
        self.image_size = image_size
        self.patch_size = patch_size
        self.hidden_dim = hidden_dim
        self.mlp_dim = mlp_dim
        self.attention_dropout = attention_dropout
        self.dropout = dropout
        self.num_classes = num_classes
        self.representation_size = representation_size
        self.norm_layer = norm_layer
        self.use_cuda = use_cuda
        self.network_type = network_type

        self.conv_proj = nn.Conv2d(
            in_channels=3, out_channels=hidden_dim, kernel_size=patch_size, stride=patch_size
        )

        seq_length = (image_size // patch_size) ** 2

        # Add a class token
        self.class_token = nn.Parameter(torch.zeros(1, 1, hidden_dim))
        seq_length += 1

        self.encoder = Encoder(
            seq_length,
            num_layers,
            num_heads,
            hidden_dim,
            mlp_dim,
            dropout,
            attention_dropout,
            norm_layer,
            network_type=network_type,  # propagate so the 'vit_adl' variant actually builds EncoderBlock_ADL layers
        )
        self.seq_length = seq_length

        heads_layers: OrderedDict[str, nn.Module] = OrderedDict()
        if representation_size is None:
            heads_layers["head"] = nn.Linear(hidden_dim, num_classes)
            nn.init.xavier_uniform_(heads_layers["head"].weight)
        else:
            heads_layers["pre_logits"] = nn.Linear(hidden_dim, representation_size)
            heads_layers["act"] = nn.Tanh()
            heads_layers["head"] = nn.Linear(representation_size, num_classes)

        self.heads = nn.Sequential(heads_layers)

        if isinstance(self.conv_proj, nn.Conv2d):
            # Init the patchify stem
            fan_in = self.conv_proj.in_channels * self.conv_proj.kernel_size[0] * self.conv_proj.kernel_size[1]
            nn.init.trunc_normal_(self.conv_proj.weight, std=math.sqrt(1 / fan_in))
            if self.conv_proj.bias is not None:
                nn.init.zeros_(self.conv_proj.bias)
        elif self.conv_proj.conv_last is not None and isinstance(self.conv_proj.conv_last, nn.Conv2d):
            # Init the last 1x1 conv of the conv stem
            nn.init.normal_(
                self.conv_proj.conv_last.weight, mean=0.0, std=math.sqrt(2.0 / self.conv_proj.conv_last.out_channels)
            )
            if self.conv_proj.conv_last.bias is not None:
                nn.init.zeros_(self.conv_proj.conv_last.bias)

        if hasattr(self.heads, "pre_logits") and isinstance(self.heads.pre_logits, nn.Linear):
            fan_in = self.heads.pre_logits.in_features
            nn.init.trunc_normal_(self.heads.pre_logits.weight, std=math.sqrt(1 / fan_in))
            nn.init.zeros_(self.heads.pre_logits.bias)

        if isinstance(self.heads.head, nn.Linear):
            torch.nn.init.xavier_uniform_(self.heads.head.weight)
            nn.init.zeros_(self.heads.head.bias)
            # nn.init.zeros_(self.heads.head.weight)
            # nn.init.zeros_(self.heads.head.bias)
            
        self.not_training = [self.conv_proj]

    def _process_input(self, x: torch.Tensor) -> torch.Tensor:
        n, c, h, w = x.shape
        p = self.patch_size
        torch._assert(h == self.image_size, f"Wrong image height! Expected {self.image_size} but got {h}!")
        torch._assert(w == self.image_size, f"Wrong image width! Expected {self.image_size} but got {w}!")
        n_h = h // p
        n_w = w // p

        # (n, c, h, w) -> (n, hidden_dim, n_h, n_w)
        x = self.conv_proj(x)
        # (n, hidden_dim, n_h, n_w) -> (n, hidden_dim, (n_h * n_w))
        x = x.reshape(n, self.hidden_dim, n_h * n_w)

        # (n, hidden_dim, (n_h * n_w)) -> (n, (n_h * n_w), hidden_dim)
        # The self attention layer expects inputs in the format (N, S, E)
        # where S is the source sequence length, N is the batch size, E is the
        # embedding dimension
        x = x.permute(0, 2, 1)

        return x

    def forward(self, x: torch.Tensor):
        # Reshape and permute the input tensor
        x = self._process_input(x)
        n = x.shape[0]

        # Expand the class token to the full batch
        batch_class_token = self.class_token.expand(n, -1, -1)
        x = torch.cat([batch_class_token, x], dim=1)

        x = self.encoder(x)

        # Classifier "token" as used by standard language architectures
        x = x[:, 0]

        x = self.heads(x)

        return x

    def forward_cam(self, img):
        
        target_layers = [self.encoder.layers.encoder_layer_11.ln_1]
        # Construct the CAM object once, and then re-use it on many images:
        cam = GradCAM(model=self, target_layers=target_layers, use_cuda=self.use_cuda, 
                    reshape_transform=reshape_transform)
        
        grayscale_cam_list_allbatch = []
        for i in range(img.shape[0]): # for batch size
            grayscale_cam_list = []
            for idx in range(self.num_classes):
                targets = [ClassifierOutputTarget(idx)]
                    
                # You can also pass aug_smooth=True and eigen_smooth=True, to apply smoothing.
                grayscale_cam = cam(input_tensor=img[i].unsqueeze(dim=0), targets=targets,\
                    eigen_smooth=True, aug_smooth=True) # numpy.ndarray (1, 384, 384)

                # In this example grayscale_cam has only one image in the batch:
                grayscale_cam = grayscale_cam[0, :] # (384, 384)
                grayscale_cam_list.append(grayscale_cam)
            
            grayscale_cam_stacked = np.stack(grayscale_cam_list, axis=0) # (C, 384, 384)
            grayscale_cam_stacked = torch.from_numpy(grayscale_cam_stacked).unsqueeze(dim=0) # (1, C, 384, 384)

            grayscale_cam_list_allbatch.append(grayscale_cam_stacked)
        grayscale_cam_stacked_allbatch = torch.cat(grayscale_cam_list_allbatch, dim=0) # (b,c,384,384)
        return grayscale_cam_stacked_allbatch
    
    
    def train(self, mode=True):
        super().train(mode)

        for layer in self.not_training:
            if isinstance(layer, torch.nn.Conv2d):
                layer.weight.requires_grad = False

            elif isinstance(layer, torch.nn.Module):
                try:
                    for c in layer.children():
                        c.weight.requires_grad = False
                        if c.bias is not None:
                            c.bias.requires_grad = False
                except:
                    pass
    
    def trainable_parameters(self):
        
        self.backbone_param = list(self.parameters())[:-2]
        self.newly_added_param = list(self.parameters())[-2:]

        return (self.backbone_param, self.newly_added_param)
    

def vit(args):
    if args.network_type in ('vit', 'vit_cam',
                             'cgn', 'cgn2', 'cgn3',
                             'vit_cam_cgn', 'vit_cam_cgn2', 'vit_cam_cgn3',
                             'vit_ivr', 'vit_cam_ivr'):
        model = VisionTransformer(
            image_size=384,
            patch_size=16,
            num_layers=12,
            num_heads=12,
            hidden_dim=768,
            mlp_dim=3072,
            num_classes=args.num_classes,
            use_cuda=args.use_cuda,
            network_type = 'vit'
        )
        vit_statedict = torch.load(os.path.join(args.weight_root, 'vit_b_16_swag-9ac1b537.pth'))
        vit_statedict.pop('heads.head.weight')
        vit_statedict.pop('heads.head.bias')
        model.load_state_dict(vit_statedict, strict=False)
        
    elif args.network_type in ('vit_adl', 'vit_cam_adl',
                               'cgn_adl', 'cgn2_adl', 'cgn3_adl',
                               'vit_cam_cgn_adl', 'vit_cam_cgn2_adl', 'vit_cam_cgn3_adl'):
        model = VisionTransformer(
            image_size=384,
            patch_size=16,
            num_layers=12,
            num_heads=12,
            hidden_dim=768,
            mlp_dim=3072,
            num_classes=args.num_classes,
            use_cuda=args.use_cuda,
            network_type = 'vit_adl'
        )
        vit_statedict = torch.load(os.path.join(args.weight_root, 'vit_b_16_swag-9ac1b537.pth'))
        vit_statedict.pop('heads.head.weight')
        vit_statedict.pop('heads.head.bias')
        model.load_state_dict(vit_statedict, strict=False)
    else:
        raise NotImplementedError("No model in ViT!!")
    return model
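
A hedged example of building the modified ViT through vit(). The args fields mirror what the function above reads, while SimpleNamespace, the class count, and the weight directory are assumptions for illustration; the weight file is the one referenced above and is assumed to exist locally.

import torch
from types import SimpleNamespace

args = SimpleNamespace(
    network_type='vit_adl',                 # variant with EncoderBlock_ADL in layers 8-11
    num_classes=20,                         # e.g. a 20-class multi-label head (assumption)
    use_cuda=torch.cuda.is_available(),
    weight_root='./weights',                # directory assumed to contain vit_b_16_swag-9ac1b537.pth
)
model = vit(args)
logits = model(torch.randn(1, 3, 384, 384)) # the model asserts 384x384 inputs
print(logits.shape)                         # torch.Size([1, 20])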


# module/vit_cam.py: Grad-CAM utilities (reshape_transform, vit_grad_cam) imported by the ViT code above.
import numpy as np
import torch
import torch.nn.functional as F
import os
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam.ablation_layer import AblationLayerVit
import time
from tqdm import tqdm

from module.model import get_model
from module.dataloader import get_dataloader
from util import torchutils, imutils


def reshape_transform(tensor, crop_size=384):
    height=crop_size // 16
    width=crop_size // 16
    # [1, 577, 768] -> [1, 24, 24, 768]
    result = tensor[:, 1:, :].reshape(tensor.size(0), height, width, tensor.size(2))

    # Bring the channels to the first dimension,
    # like in CNNs.
    result = result.transpose(2, 3).transpose(1, 2)
    return result


def vit_grad_cam(model, data_loader, args):
    # with torch.no_grad():
    # target_layers = [model.encoder.layers.encoder_layer_11.ln_1] # for the torchvision-style ViT defined above
    target_layers = [model.blocks[11].norm1] # for a timm-style ViT (blocks[i].norm1 naming)
    
    
    st = time.time()
    for iteration, pack in tqdm(enumerate(data_loader)):
        img_name = pack['name'][0]
        label = pack['label'][0] # one hot encoded
        valid_cat = torch.nonzero(label)[:, 0] # nonzero label index for all batch
        
        input_tensor = pack['img']  # Create an input tensor image for your model..
        size = pack['size']
        
        strided_size = imutils.get_strided_size(size, 4)
        
        grayscale_cam_low_res_list = []
        grayscale_cam_high_res_list = []
        for idx in range(len(valid_cat)):
            targets = [ClassifierOutputTarget(valid_cat[idx])]
            
            # Construct the CAM object once, and then re-use it on many images:
            cam = GradCAM(model=model, target_layers=target_layers, use_cuda=args.use_cuda, 
                        reshape_transform=reshape_transform)
        
            # You can also pass aug_smooth=True and eigen_smooth=True, to apply smoothing.
            grayscale_cam = cam(input_tensor=input_tensor, targets=targets,\
                                eigen_smooth=args.eigen_smooth, aug_smooth=args.aug_smooth) # numpy.ndarray (1, 384, 384)

            # In this example grayscale_cam has only one image in the batch:
            grayscale_cam = grayscale_cam[0, :] # (384, 384)
            
            grayscale_cam_low_res = np.asarray(F.interpolate(
                torch.from_numpy(grayscale_cam).unsqueeze(dim=0).unsqueeze(dim=0), size=strided_size,
            mode='bilinear').squeeze(dim=0)) # strided size
            grayscale_cam_low_res_list.append(grayscale_cam_low_res)
            
            grayscale_cam_high_res = np.asarray(F.interpolate(
                torch.from_numpy(grayscale_cam).unsqueeze(dim=0).unsqueeze(dim=0), size=(size[0], size[1]),
            mode='bilinear').squeeze(dim=0)) # to original size
            grayscale_cam_high_res_list.append(grayscale_cam_high_res)
        
        grayscale_cam_low_res_stacked = np.concatenate(grayscale_cam_low_res_list, axis=0)
        grayscale_cam_high_res_stacked = np.concatenate(grayscale_cam_high_res_list, axis=0)
        
        # print("vitCAM size")
        # print(grayscale_cam_low_res_stacked.shape)
        # print(grayscale_cam_high_res_stacked.shape)
        
        np.save(os.path.join(args.pred_dir, img_name + '.npy'), 
                {"keys": valid_cat, "cam": grayscale_cam_low_res_stacked, "high_res": grayscale_cam_high_res_stacked})