1. 代码整体结构
这段代码实现了一个图像分类的任务,主要包含以下几个模块:
- 数据处理:加载和预处理图像数据。
- 半监督学习:在训练过程中加入无标签数据,并根据模型预测的置信度来选择哪些无标签样本加入训练。
- 模型定义:使用卷积神经网络(CNN)模型进行图像分类。
- 训练与评估:训练模型并在验证集上评估其性能。
2. 数据预处理与加载
2.1 seed_everything:设置随机种子
def seed_everything(seed):
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
random.seed(seed)
np.random.seed(seed)
os.environ['PYTHONHASHSEED'] = str(seed)
设置随机种子是为了确保结果的可重现性。此函数会为torch、numpy和python的随机数生成器设置种子。
2.2 food_Dataset类:自定义数据集类
该类继承自torch.utils.data.Dataset,用于处理和加载图像数据。
2.2.1 __init__方法
def __init__(self, path, mode="train"):
self.mode = mode
if mode == "semi":
self.X = self.read_file(path)
else:
self.X, self.Y = self.read_file(path)
self.Y = torch.LongTensor(self.Y) #标签转为长整形\
path:数据所在路径。mode:数据集模式(训练集train,验证集val,半监督学习semi)。- 如果是半监督模式,加载无标签数据
X;否则,加载有标签数据X和Y(标签)。
2.2.2 read_file方法
def read_file(self, path):
if self.mode == "semi":
file_list = os.listdir(path)
xi = np.zeros((len(file_list), HW, HW, 3), dtype=np.uint8)
# 列出文件夹下所有文件名字
for j, img_name in enumerate(file_list):
img_path = os.path.join(path, img_name)
img = Image.open(img_path)
img = img.resize((HW, HW))
xi[j, ...] = img
print("读到了%d个数据" % len(xi))
return xi
else:
for i in tqdm(range(11)):
file_dir = path + "/%02d" % i
file_list = os.listdir(file_dir)
xi = np.zeros((len(file_list), HW, HW, 3), dtype=np.uint8)
yi = np.zeros(len(file_list), dtype=np.uint8)
# 列出文件夹下所有文件名字
for j, img_name in enumerate(file_list):
img_path = os.path.join(file_dir, img_name)
img = Image.open(img_path)
img = img.resize((HW, HW))
xi[j, ...] = img
yi[j] = i
if i == 0:
X = xi
Y = yi
else:
X = np.concatenate((X, xi), axis=0)
Y = np.concatenate((Y, yi), axis=0)
print("读到了%d个数据" % len(Y))
return X, Y
-
1.
if self.mode == "semi":这是判断数据集模式的条件语句。根据
self.mode的值,它决定是否执行半监督学习模式的处理。mode可以是train、val或semi,具体取决于在创建数据集实例时的设置。 self.mode == "semi":表示使用半监督学习模式,这时只需要加载未标记的数据。else:表示加载有标记的数据,通常是训练集或验证集。os.listdir(path):返回指定路径path中所有文件的列表。path是未标记图像的文件夹路径。- 创建一个大小为
(len(file_list), HW, HW, 3)的零矩阵,其中HW是图像的宽和高(224),3是RGB三个通道,dtype=np.uint8表示图像的像素值为 0-255 的整数。 enumerate(file_list):对文件列表中的每个图像文件名进行迭代,img_name是文件名,j是文件的索引。img_path = os.path.join(path, img_name):根据文件名img_name和文件夹路径path拼接出图像的完整路径。img = Image.open(img_path):用PIL.Image.open()打开图像文件。img = img.resize((HW, HW)):调整图像的大小为(HW, HW),即 224x224。xi[j, ...] = img:将处理后的图像数据赋值给xi数组的第j行。- 输出已读取的图像数量。
- 返回读取并调整尺寸后的图像数据
xi。这是半监督模式下的数据,不包括标签。 range(11):假设数据集有 11 个类别(0 到 10)。tqdm(range(11)):tqdm是一个用于显示进度条的库,这样可以在加载数据时实时显示进度。file_dir = path + "/%02d" % i:根据i构造出每个类别对应的文件夹路径,路径格式为"path/00","path/01", ...,"path/10"。- 列出每个类别文件夹中的所有图像文件。
xi:用来存储当前类别的所有图像数据,形状为(len(file_list), HW, HW, 3),表示每张图像的尺寸为224x224,3 通道的图像。yi:用来存储当前类别的标签,形状为(len(file_list),),每个标签都对应一个类别索引i。enumerate(file_list):对每个图像文件进行遍历。img_path = os.path.join(file_dir, img_name):构造图像文件的完整路径。img = Image.open(img_path):用PIL.Image.open()打开图像文件。img = img.resize((HW, HW)):调整图像大小为224x224。xi[j, ...] = img:将处理后的图像数据存入xi数组的第j行。yi[j] = i:为每个图像分配标签i,表示该图像属于第i类。- 如果
i == 0,说明是第一个类别,直接将xi和yi赋给X和Y。 - 如果
i != 0,说明是后续类别,将当前类别的xi和yi与前面类别的数据进行拼接,np.concatenate()用于按axis=0方向拼接数据,即将图像数据和标签数据扩展到全体数据。 - 输出已经读取的所有数据的数量(包括所有类别的数据)。
- 返回拼接后的图像数据
X和标签数据Y。这是训练模式或验证模式下的数据。 - 如果是半监督模式,代码读取无标签图像数据。
- 如果是有标签模式,代码按类别读取图片,并返回图像数据
X和标签Y。
2.2.3 __getitem__方法
def __getitem__(self, item):
if self.mode == "semi":
return self.transform(self.X[item]), self.X[item]
else:
return self.transform(self.X[item]), self.Y[item]
返回处理后的图像数据(通过transform)以及相应的标签或图像数据。
2.3 数据增强与转换
使用torchvision.transforms进行数据增强和转换:
train_transform = transforms.Compose([
transforms.ToPILImage(),
transforms.RandomResizedCrop(224),
transforms.RandomRotation(50),
transforms.ToTensor()
])
RandomResizedCrop(224):随机裁剪并缩放为224x224。RandomRotation(50):随机旋转图像,最大旋转角度为50度。
验证集的转换仅将图像转换为Tensor。
3. 模型定义
3.1 自定义CNN模型(myModel)
class myModel(nn.Module):
def __init__(self, num_class):
super(myModel, self).__init__()
#3 *224 *224 -> 512*7*7 -> 拉直 -》全连接分类
self.conv1 = nn.Conv2d(3, 64, 3, 1, 1) # 64*224*224
self.bn1 = nn.BatchNorm2d(64)
self.relu = nn.ReLU()
self.pool1 = nn.MaxPool2d(2) #64*112*112
self.layer1 = nn.Sequential(
nn.Conv2d(64, 128,
3, 1, 1), # 128*112*112
nn.BatchNorm2d(128),
nn.ReLU(),
nn.MaxPool2d(2) #128*56*56
)
self.layer2 = nn.Sequential(
nn.Conv2d(128, 256, 3, 1, 1),
nn.BatchNorm2d(256),
nn.ReLU(),
nn.MaxPool2d(2) #256*28*28
)
self.layer3 = nn.Sequential(
nn.Conv2d(256, 512, 3, 1, 1),
nn.BatchNorm2d(512),
nn.ReLU(),
nn.MaxPool2d(2) #512*14*14
)
self.pool2 = nn.MaxPool2d(2) #512*7*7
self.fc1 = nn.Linear(25088, 1000) #25088->1000
self.relu2 = nn.ReLU()
self.fc2 = nn.Linear(1000, num_class) #1000-11
该模型由多个卷积层、池化层和全连接层组成。最后通过一个fc2层输出类别概率。
3.2 initialize_model函数
model, _ = initialize_model("vgg", 11, use_pretrained=True)
这里使用initialize_model函数加载VGG模型,并将其初始化为11个类别的输出。use_pretrained=True表示使用预训练权重。
4. 半监督学习
4.1 semiDataset类:半监督数据集
class semiDataset(Dataset):
def __init__(self, no_label_loder, model, device, thres=0.99):
x, y = self.get_label(no_label_loder, model, device, thres)
if x == []:
self.flag = False
else:
self.flag = True
self.X = np.array(x)
self.Y = torch.LongTensor(y)
self.transform = train_transform
semiDataset类用于处理无标签数据,并根据模型的预测置信度来筛选出可用于训练的数据。thres是一个阈值,用于决定哪些无标签样本是“可靠的”:
- 在训练过程中,模型会根据其对无标签数据的预测置信度(通过Softmax输出概率)选择高于阈值的样本作为伪标签数据进行训练。
4.2 get_label函数
def get_label(self, no_label_loder, model, device, thres):
model = model.to(device)
pred_prob = []
labels = []
x = []
y = []
soft = nn.Softmax()
with torch.no_grad():
for bat_x, _ in no_label_loder:
bat_x = bat_x.to(device)
pred = model(bat_x)
pred_soft = soft(pred)
pred_max, pred_value = pred_soft.max(1)
pred_prob.extend(pred_max.cpu().numpy().tolist())
labels.extend(pred_value.cpu().numpy().tolist())
for index, prob in enumerate(pred_prob):
if prob > thres:
x.append(no_label_loder.dataset[index][1]) #调用到原始的getitem
y.append(labels[index])
return x, y
- 计算每个无标签样本的预测置信度(通过Softmax得到概率)。
- 如果预测概率大于阈值
thres,则该样本被认为是“可靠的”,并添加到训练集中。
5. 训练与评估
5.1 train_val函数
def train_val(model, train_loader, val_loader, no_label_loader,
device, epochs, optimizer, loss, thres, save_path):
该函数用于训练模型,并根据验证集的准确率保存最优模型。
- 每训练一个epoch,都会计算训练集和验证集上的损失和准确率。
- 每当验证集的准确率提高时,保存模型。
- 每隔3个epoch,如果验证集的准确率超过阈值,则开始使用无标签数据进行半监督训练。
6. 总结
通过以上代码实现,我们构建了一个半监督学习框架,能够利用无标签数据提高图像分类模型的性能。主要思想是:利用有标签数据进行初步训练,然后通过预测无标签数据并选取置信度较高的数据来扩展训练集,从而提高模型的泛化能力

672

被折叠的 条评论
为什么被折叠?



