【深度学习可视化系列】—— 特征图可视化(支持Vit系列模型的特征图可视化,包含使用Tensorboard对可视化结果进行保存)

本文介绍了如何在PyTorch中使用Vit系列模型进行特征图可视化,并演示了如何通过Tensorboard保存这些可视化结果,包括使用`FeatureExtractor`类提取特征并进行预处理,最后展示了地表裂缝图像的可视化实例。
Qwen3-32B-Chat 私有部署镜像 | RTX4090D 24G 显存 CUDA12.4 优化版

本镜像基于 RTX 4090D 24GB 显存 + CUDA 12.4 + 驱动 550.90.07 深度优化,内置完整运行环境与 Qwen3-32B 模型依赖,开箱即用。

【深度学习可视化系列】—— 特征图可视化(支持Vit系列模型的特征图可视化,包含使用Tensorboard对可视化结果进行保存)

import sys
import os
import torch
import cv2
import timm
import numpy as np 
import torch.nn as nn
import albumentations as A
from albumentations.pytorch import ToTensorV2
from model.MitUnet import  MitUnet
from collections import OrderedDict
from typing import Dict, Iterable, Callable
from torch import nn, Tensor
from PIL import Image
from pprint import pprint


# --------------------------------------------------------------------------------------------------------------------------
# 构建模型特征图提取模型,输入参数为模型、以及需提取特征图层的key名称,该名称可通过model.named_modules()或model.named_children()获取
# --------------------------------------------------------------------------------------------------------------------------
class FeatureExtractor(nn.Module):
    def __init__(self, model: nn.Module, layers: Iterable[str]):
        super().__init__()
        # assert layers is not None
        self.model = model
        self.layers = layers
        self._features = OrderedDict({layer: torch.empty(0) for layer in layers})
        self.hook = []

        for layer_id in layers:
            layer = dict([*self.model.named_modules()])[layer_id]
            self.hook = layer.register_forward_hook(self.hook_func(layer_id))
            # self.hook.append(self.layer_id)

    def hook_func(self, layer_id: str) -> Callable:
        def fn(_, __, output):
            # print("_____{}".format(output.dim()))   
            if output.dim() == 3:
                output = self.reshape_transform(in_tensor=output) 
            self._features[layer_id] = output
        return fn

    def forward(self, x: Tensor) -> Dict[str, Tensor]:
        _ = self.model(x)
        self.remove()
        return self._features
    
    def remove(self):
        # for hook in self.hook:
        self.hook.remove()

    def reshape_transform(self, in_tensor):
        result = in_tensor.reshape(in_tensor.size(0),
            int(np.sqrt(in_tensor.size(1))), int(np.sqrt(in_tensor.size(1))), in_tensor.size(2))

        result = result.transpose(2, 3).transpose(1, 2)
        return result
    
    
# --------------------------------------------------------------------------------------------------------------------------
# 构建模型,并进行特征提取
# --------------------------------------------------------------------------------------------------------------------------
img_mask_size = 256
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
model = UNet(....)
# map_location={'cuda:0': 'cpu'}
state_dict = torch.load('./state_dict/model.pth')
model.load_state_dict(state_dict['model'])
print('网络设置完毕 :成功载入了训练完毕的权重。')
model.to(device=device)
transformer = A.Compose([
    A.Resize(img_mask_size, img_mask_size),
    A.Normalize(
        mean=(0.5835, 0.5820, 0.5841),
        std=(0.1149, 0.1111, 0.1064),
        max_pixel_value=255.0
    ),
    ToTensorV2()
])
return_layers = ["encoder.norm1"]
e_model = FeatureExtractor(model=model, layers=return_layers)
image_file = ".\images"
image_file_path = os.path.join(image_file, str("15") + (".jpg"))
img = Image.open(image_file_path)
img_width, img_height = img.size
image_np = np.array(img)
augmented = transformer(image=image_np)
augmented_img = augmented['image'].to(device)  
# 由于模型中存在BN层,其不允许推理的batchsize小于2,所以生成一个和原始影像相同大小尺度的虚拟图像使得batchsize=2。
virual_image = torch.randn(size=(3, img_mask_size, img_mask_size), dtype=torch.float32).to(device=device)
augmented_img = torch.stack([augmented_img, virual_image], dim=0)
print(augmented_img.shape)
output = e_model(augmented_img)
for keys, values in output.items():
    output[keys] = values[0].unsqueeze(0) 
pprint({keys : torch.sigmoid(values[0]).detach().shape for keys, values in output.items()})


# --------------------------------------------------------------------------------------------------------------------------
# 使用tensorboard保存特征图可视化结果
# --------------------------------------------------------------------------------------------------------------------------
from torchvision.utils import make_grid
from torch.utils.tensorboard.writer import SummaryWriter

writer = SummaryWriter("runs/test")
for keys, values in output.items():
    values = torch.sigmoid(values[0]).cpu().detach().numpy()
    imgs_ = np.empty(shape=(values.shape[0], 3, values.shape[1], values.shape[2])) 
    for index, batch_img in enumerate(values):
        imgs_[index] =  cv2.applyColorMap(np.uint8(batch_img * 255), cv2.COLORMAP_JET).transpose(2, 0, 1)
    imgs_grid = make_grid(torch.from_numpy(imgs_), nrow=5, padding=2, pad_value=0)
    cv2.namedWindow("imgs_grid", cv2.WINDOW_FULLSCREEN)
    cv2.imshow("imgs_grid", imgs_grid.permute(1, 2, 0).numpy())
    cv2.waitKey()
	cv2.destroyAllWindows()
    
    writer.add_images(keys + "_TEST", imgs_, 0, dataformats="NCHW")
writer.close()

可视化结果如下(以地表裂缝图像为例):
请添加图片描述
​ 地裂缝图像以及分割结果
请添加图片描述

​ 裂缝提取模型部分特征图可视化结果

您可能感兴趣的与本文相关的镜像

Qwen3-32B-Chat 私有部署镜像 | RTX4090D 24G 显存 CUDA12.4 优化版

Qwen3-32B-Chat 私有部署镜像 | RTX4090D 24G 显存 CUDA12.4 优化版

Qwen
文本生成
Qwen3

本镜像基于 RTX 4090D 24GB 显存 + CUDA 12.4 + 驱动 550.90.07 深度优化,内置完整运行环境与 Qwen3-32B 模型依赖,开箱即用。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

卖报的大地主

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值