半监督学习图像分类任务

1. 代码整体结构

这段代码实现了一个图像分类的任务,主要包含以下几个模块:

  • 数据处理:加载和预处理图像数据。
  • 半监督学习:在训练过程中加入无标签数据,并根据模型预测的置信度来选择哪些无标签样本加入训练。
  • 模型定义:使用卷积神经网络(CNN)模型进行图像分类。
  • 训练与评估:训练模型并在验证集上评估其性能。

2. 数据预处理与加载

2.1 seed_everything:设置随机种子

def seed_everything(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    random.seed(seed)
    np.random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)

设置随机种子是为了确保结果的可重现性。此函数会为torchnumpypython的随机数生成器设置种子。

2.2 food_Dataset类:自定义数据集类

该类继承自torch.utils.data.Dataset,用于处理和加载图像数据。

2.2.1 __init__方法
    def __init__(self, path, mode="train"):
        self.mode = mode
        if mode == "semi":
            self.X = self.read_file(path)
        else:
            self.X, self.Y = self.read_file(path)
            self.Y = torch.LongTensor(self.Y)  #标签转为长整形\
  • path:数据所在路径。
  • mode:数据集模式(训练集train,验证集val,半监督学习semi)。
  • 如果是半监督模式,加载无标签数据X;否则,加载有标签数据XY(标签)。
2.2.2 read_file方法
    def read_file(self, path):
        if self.mode == "semi":
            file_list = os.listdir(path)
            xi = np.zeros((len(file_list), HW, HW, 3), dtype=np.uint8)
            # 列出文件夹下所有文件名字
            for j, img_name in enumerate(file_list):
                img_path = os.path.join(path, img_name)
                img = Image.open(img_path)
                img = img.resize((HW, HW))
                xi[j, ...] = img
            print("读到了%d个数据" % len(xi))
            return xi
        else:
            for i in tqdm(range(11)):
                file_dir = path + "/%02d" % i
                file_list = os.listdir(file_dir)

                xi = np.zeros((len(file_list), HW, HW, 3), dtype=np.uint8)
                yi = np.zeros(len(file_list), dtype=np.uint8)

                # 列出文件夹下所有文件名字
                for j, img_name in enumerate(file_list):
                    img_path = os.path.join(file_dir, img_name)
                    img = Image.open(img_path)
                    img = img.resize((HW, HW))
                    xi[j, ...] = img
                    yi[j] = i

                if i == 0:
                    X = xi
                    Y = yi
                else:
                    X = np.concatenate((X, xi), axis=0)
                    Y = np.concatenate((Y, yi), axis=0)
            print("读到了%d个数据" % len(Y))
            return X, Y

  • 1. if self.mode == "semi":

    这是判断数据集模式的条件语句。根据 self.mode 的值,它决定是否执行半监督学习模式的处理。mode 可以是 trainvalsemi,具体取决于在创建数据集实例时的设置。

  • self.mode == "semi":表示使用半监督学习模式,这时只需要加载未标记的数据。
  • else:表示加载有标记的数据,通常是训练集或验证集。
  • os.listdir(path):返回指定路径 path 中所有文件的列表。path 是未标记图像的文件夹路径。
  • 创建一个大小为 (len(file_list), HW, HW, 3) 的零矩阵,其中 HW 是图像的宽和高(224),3 是RGB三个通道,dtype=np.uint8 表示图像的像素值为 0-255 的整数。
  • enumerate(file_list):对文件列表中的每个图像文件名进行迭代,img_name 是文件名,j 是文件的索引。
  • img_path = os.path.join(path, img_name):根据文件名 img_name 和文件夹路径 path 拼接出图像的完整路径。
  • img = Image.open(img_path):用 PIL.Image.open() 打开图像文件。
  • img = img.resize((HW, HW)):调整图像的大小为 (HW, HW),即 224x224。
  • xi[j, ...] = img:将处理后的图像数据赋值给 xi 数组的第 j 行。
  • 输出已读取的图像数量。
  • 返回读取并调整尺寸后的图像数据 xi。这是半监督模式下的数据,不包括标签。
  • range(11):假设数据集有 11 个类别(0 到 10)。
  • tqdm(range(11))tqdm 是一个用于显示进度条的库,这样可以在加载数据时实时显示进度。
  • file_dir = path + "/%02d" % i:根据 i 构造出每个类别对应的文件夹路径,路径格式为 "path/00", "path/01", ..., "path/10"
  • 列出每个类别文件夹中的所有图像文件。
  • xi:用来存储当前类别的所有图像数据,形状为 (len(file_list), HW, HW, 3),表示每张图像的尺寸为 224x224,3 通道的图像。
  • yi:用来存储当前类别的标签,形状为 (len(file_list),),每个标签都对应一个类别索引 i
  • enumerate(file_list):对每个图像文件进行遍历。
  • img_path = os.path.join(file_dir, img_name):构造图像文件的完整路径。
  • img = Image.open(img_path):用 PIL.Image.open() 打开图像文件。
  • img = img.resize((HW, HW)):调整图像大小为 224x224
  • xi[j, ...] = img:将处理后的图像数据存入 xi 数组的第 j 行。
  • yi[j] = i:为每个图像分配标签 i,表示该图像属于第 i 类。
  • 如果 i == 0,说明是第一个类别,直接将 xiyi 赋给 XY
  • 如果 i != 0,说明是后续类别,将当前类别的 xiyi 与前面类别的数据进行拼接,np.concatenate() 用于按 axis=0 方向拼接数据,即将图像数据和标签数据扩展到全体数据。
  • 输出已经读取的所有数据的数量(包括所有类别的数据)。
  • 返回拼接后的图像数据 X 和标签数据 Y。这是训练模式或验证模式下的数据。
  • 如果是半监督模式,代码读取无标签图像数据。
  • 如果是有标签模式,代码按类别读取图片,并返回图像数据X和标签Y
2.2.3 __getitem__方法
    def __getitem__(self, item):
        if self.mode == "semi":
            return self.transform(self.X[item]), self.X[item]
        else:
            return self.transform(self.X[item]), self.Y[item]

返回处理后的图像数据(通过transform)以及相应的标签或图像数据。


2.3 数据增强与转换

使用torchvision.transforms进行数据增强和转换:

train_transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.RandomResizedCrop(224),
    transforms.RandomRotation(50),
    transforms.ToTensor()
])
  • RandomResizedCrop(224):随机裁剪并缩放为224x224。
  • RandomRotation(50):随机旋转图像,最大旋转角度为50度。

验证集的转换仅将图像转换为Tensor。


3. 模型定义

3.1 自定义CNN模型(myModel

class myModel(nn.Module):
    def __init__(self, num_class):
        super(myModel, self).__init__()
        #3 *224 *224  -> 512*7*7 -> 拉直 -》全连接分类
        self.conv1 = nn.Conv2d(3, 64, 3, 1, 1)    # 64*224*224
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU()
        self.pool1 = nn.MaxPool2d(2)   #64*112*112


        self.layer1 = nn.Sequential(
            nn.Conv2d(64, 128,
                      3, 1, 1),    # 128*112*112
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(2)   #128*56*56
        )
        self.layer2 = nn.Sequential(
            nn.Conv2d(128, 256, 3, 1, 1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(2)   #256*28*28
        )
        self.layer3 = nn.Sequential(
            nn.Conv2d(256, 512, 3, 1, 1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.MaxPool2d(2)   #512*14*14
        )

        self.pool2 = nn.MaxPool2d(2)    #512*7*7
        self.fc1 = nn.Linear(25088, 1000)   #25088->1000
        self.relu2 = nn.ReLU()
        self.fc2 = nn.Linear(1000, num_class)  #1000-11

该模型由多个卷积层、池化层和全连接层组成。最后通过一个fc2层输出类别概率。

3.2 initialize_model函数

model, _ = initialize_model("vgg", 11, use_pretrained=True)

这里使用initialize_model函数加载VGG模型,并将其初始化为11个类别的输出。use_pretrained=True表示使用预训练权重。


4. 半监督学习

4.1 semiDataset类:半监督数据集

class semiDataset(Dataset):
    def __init__(self, no_label_loder, model, device, thres=0.99):
        x, y = self.get_label(no_label_loder, model, device, thres)
        if x == []:
            self.flag = False

        else:
            self.flag = True
            self.X = np.array(x)
            self.Y = torch.LongTensor(y)
            self.transform = train_transform

semiDataset类用于处理无标签数据,并根据模型的预测置信度来筛选出可用于训练的数据。thres是一个阈值,用于决定哪些无标签样本是“可靠的”:

  • 在训练过程中,模型会根据其对无标签数据的预测置信度(通过Softmax输出概率)选择高于阈值的样本作为伪标签数据进行训练。

4.2 get_label函数

    def get_label(self, no_label_loder, model, device, thres):
        model = model.to(device)
        pred_prob = []
        labels = []
        x = []
        y = []
        soft = nn.Softmax()
        with torch.no_grad():
            for bat_x, _ in no_label_loder:
                bat_x = bat_x.to(device)
                pred = model(bat_x)
                pred_soft = soft(pred)
                pred_max, pred_value = pred_soft.max(1)
                pred_prob.extend(pred_max.cpu().numpy().tolist())
                labels.extend(pred_value.cpu().numpy().tolist())

        for index, prob in enumerate(pred_prob):
            if prob > thres:
                x.append(no_label_loder.dataset[index][1])   #调用到原始的getitem
                y.append(labels[index])
        return x, y
  • 计算每个无标签样本的预测置信度(通过Softmax得到概率)。
  • 如果预测概率大于阈值thres,则该样本被认为是“可靠的”,并添加到训练集中。

5. 训练与评估

5.1 train_val函数

def train_val(model, train_loader, val_loader, no_label_loader,
 device, epochs, optimizer, loss, thres, save_path):

该函数用于训练模型,并根据验证集的准确率保存最优模型。

  • 每训练一个epoch,都会计算训练集和验证集上的损失和准确率。
  • 每当验证集的准确率提高时,保存模型。
  • 每隔3个epoch,如果验证集的准确率超过阈值,则开始使用无标签数据进行半监督训练。

6. 总结

通过以上代码实现,我们构建了一个半监督学习框架,能够利用无标签数据提高图像分类模型的性能。主要思想是:利用有标签数据进行初步训练,然后通过预测无标签数据并选取置信度较高的数据来扩展训练集,从而提高模型的泛化能力

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值