Using batch normalization in TensorFlow (slim)

I have been learning slim recently. Two things make slim pleasant to use: building networks is easy, and many pretrained models are available for download.

But while tuning slim's ResNet, I found that the training set reached high accuracy (e.g. 90%) while test accuracy stayed extremely low (e.g. 0% or 1%). That gap is clearly not something plain underfitting or overfitting would produce.

It turned out the problem was in batch normalization.
Batch normalization is very convenient in slim: you don't have to explicitly add a batch-norm layer after every convolution. It is enough to add slim.batch_norm to slim's arg_scope, like this:

```python
batch_norm_params = {
    'decay': batch_norm_decay,
    'epsilon': batch_norm_epsilon,
    'scale': batch_norm_scale,
    'updates_collections': tf.GraphKeys.UPDATE_OPS,
    'is_training': is_training,
}

with slim.arg_scope(
    [slim.conv2d],
    weights_regularizer=slim.l2_regularizer(weight_decay),
    weights_initializer=slim.variance_scaling_initializer(),
    activation_fn=tf.nn.relu,
    normalizer_fn=slim.batch_norm,
    normalizer_params=batch_norm_params):
  with slim.arg_scope([slim.batch_norm], **batch_norm_params):
    ...
    ...
```
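The `decay` entry above controls an exponential moving average over the batch statistics. As a minimal plain-Python sketch (with hypothetical values, not slim's actual implementation) of how `moving_mean` evolves:

```python
# Minimal sketch (hypothetical values) of the exponential moving average
# that slim.batch_norm keeps for moving_mean; moving_variance is analogous.
def ema_update(moving, batch_stat, decay=0.997):
    # moving <- decay * moving + (1 - decay) * batch_stat
    return decay * moving + (1.0 - decay) * batch_stat

moving_mean = 0.0  # the moving mean starts at zero
for batch_mean in [1.0, 1.2, 0.8, 1.1]:
    moving_mean = ema_update(moving_mean, batch_mean)

# With a large decay and only a few updates, moving_mean is still
# close to its initial value, far from the true batch statistics.
```

Note that with a high decay the moving statistics need many training steps to converge toward the true statistics, so they must actually be updated at every step.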

Back to the point. The catch is this: at test time, setting is_training to False made test accuracy collapse, and only setting it to True gave high accuracy. That "works", but it is obviously wrong!
The fix:
At test time batch_norm uses a fixed mean and var, which are moving averages of the batch statistics accumulated during training. Building train_op directly never runs the ops that update those moving averages, so the model has no usable mean and var. The correct approach is to update moving_mean and moving_var at every training step:

```python
optimizer = tf.train.MomentumOptimizer(lr, momentum=FLAGS.momentum,
                                       name='MOMENTUM')
# Force the moving-average update ops to run before each training step:
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies([tf.group(*update_ops)]):
    # train_op = slim.learning.create_train_op(total_loss, optimizer, global_step)
    train_op = optimizer.minimize(total_loss, global_step=global_step)
```

With this in place, setting is_training to False at test time yields a normal test accuracy.
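To make the failure mode concrete, here is a minimal NumPy sketch (not TensorFlow's implementation, and omitting the learned scale/shift) contrasting the two modes of batch norm:

```python
import numpy as np

def batch_norm(x, moving_mean, moving_var, is_training, decay=0.9, eps=1e-5):
    """Toy batch norm: batch statistics in training, moving statistics at test."""
    if is_training:
        mean, var = x.mean(), x.var()
        # These two updates are what the UPDATE_OPS dependency forces to run:
        moving_mean = decay * moving_mean + (1 - decay) * mean
        moving_var = decay * moving_var + (1 - decay) * var
    else:
        mean, var = moving_mean, moving_var  # fixed statistics at test time
    y = (x - mean) / np.sqrt(var + eps)
    return y, moving_mean, moving_var

x = np.array([10.0, 12.0, 8.0, 10.0])

# If the moving statistics are never updated (stuck at mean=0, var=1),
# test-time outputs are wildly off-distribution:
y_stale, _, _ = batch_norm(x, 0.0, 1.0, is_training=False)

# After enough training updates, test-time outputs match training behavior:
mm, mv = 0.0, 1.0
for _ in range(200):
    _, mm, mv = batch_norm(x, mm, mv, is_training=True)
y_ok, _, _ = batch_norm(x, mm, mv, is_training=False)
```

The stale-statistics case is exactly what happens when train_op is built without the UPDATE_OPS dependency: the network's later layers were trained on roughly zero-centered activations, but at test time they receive badly shifted ones.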

If this still isn't clear, see https://github.com/soloice/mnist-bn/blob/master/mnist_bn.py for complete code,
or this one: https://github.com/tensorflow/models/blob/master/slim/train_image_classifier.py

Other APIs, such as tf.contrib.layers.batch_norm, should behave much the same way, though I haven't used them myself. If you run into problems there, these two links may have the answer:
1. http://ruishu.io/2016/12/27/batchnorm/
2. https://github.com/tensorflow/tensorflow/issues/1122#issuecomment-280325584

References:
[1] https://github.com/tensorflow/tensorflow/issues/1122#issuecomment-280325584
[2] http://ruishu.io/2016/12/27/batchnorm/
[3] https://github.com/soloice/mnist-bn/blob/master/mnist_bn.py
[4] https://github.com/tensorflow/models/blob/master/slim/train_image_classifier.py

Implementing CIFAR-100 animal recognition in Python with TensorFlow or PyTorch is a common computer-vision task. The basic steps for each framework:

### CIFAR-100 recognition with TensorFlow

1. **Install the required libraries**:

```bash
pip install tensorflow tensorflow-datasets
```

2. **Import the libraries**:

```python
import tensorflow as tf
from tensorflow.keras import layers, models
import tensorflow_datasets as tfds
```

3. **Load the dataset**:

```python
(ds_train, ds_test), ds_info = tfds.load(
    'cifar100',
    split=['train', 'test'],
    shuffle_files=True,
    as_supervised=True,
    with_info=True,
)
```

4. **Preprocess the data**:

```python
def normalize_img(image, label):
    return tf.cast(image, tf.float32) / 255.0, label

ds_train = ds_train.map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds_train = ds_train.cache()
ds_train = ds_train.shuffle(ds_info.splits['train'].num_examples)
ds_train = ds_train.batch(128)
ds_train = ds_train.prefetch(tf.data.AUTOTUNE)

ds_test = ds_test.map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds_test = ds_test.batch(128)
ds_test = ds_test.cache()
ds_test = ds_test.prefetch(tf.data.AUTOTUNE)
```

5. **Build the model**:

```python
model = models.Sequential([
    layers.Conv2D(32, (3, 3), activation='relu', input_shape=(32, 32, 3)),
    layers.MaxPooling2D((2, 2)),
    layers.Conv2D(64, (3, 3), activation='relu'),
    layers.MaxPooling2D((2, 2)),
    layers.Flatten(),
    layers.Dense(128, activation='relu'),
    layers.Dense(100, activation='softmax')
])
```

6. **Compile and train the model**:

```python
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

model.fit(
    ds_train,
    epochs=10,
    validation_data=ds_test
)
```

7. **Evaluate the model**:

```python
model.evaluate(ds_test)
```

### CIFAR-100 recognition with PyTorch

1. **Install the required libraries**:

```bash
pip install torch torchvision
```

2. **Import the libraries**:

```python
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
```

3. **Preprocess the data**:

```python
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

trainset = torchvision.datasets.CIFAR100(root='./data', train=True,
                                         download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR100(root='./data', train=False,
                                        download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=100,
                                         shuffle=False, num_workers=2)
```

4. **Build the model**:

```python
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, 3)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(32, 64, 3)
        self.fc1 = nn.Linear(64 * 6 * 6, 128)
        self.fc2 = nn.Linear(128, 100)

    def forward(self, x):
        x = self.pool(nn.functional.relu(self.conv1(x)))
        x = self.pool(nn.functional.relu(self.conv2(x)))
        x = x.view(-1, 64 * 6 * 6)
        x = nn.functional.relu(self.fc1(x))
        x = self.fc2(x)
        return x

net = Net()
```

5. **Define the loss function and optimizer**:

```python
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.001)
```

6. **Train the model**:

```python
for epoch in range(10):  # train for 10 epochs
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        if i % 100 == 99:  # print the loss every 100 mini-batches
            print(f'Epoch [{epoch + 1}], Batch [{i + 1}], '
                  f'Loss: {running_loss / 100:.4f}')
            running_loss = 0.0
print('Training finished')
```

7. **Evaluate the model**:

```python
correct = 0
total = 0
with torch.no_grad():
    for data in testloader:
        images, labels = data
        outputs = net(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
print(f'Accuracy: {100 * correct / total:.2f}%')
```

With these steps you can implement CIFAR-100 animal recognition in Python using either TensorFlow or PyTorch.
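A quick sanity check on the `64 * 6 * 6` flatten size used by the PyTorch model's `fc1`: it follows from the standard convolution/pooling output-size arithmetic, sketched below.

```python
def conv_out(n, kernel, stride=1, pad=0):
    # Standard conv/pool output-size formula: floor((n + 2p - k) / s) + 1
    return (n + 2 * pad - kernel) // stride + 1

n = 32                 # CIFAR images are 32x32
n = conv_out(n, 3)     # conv1, 3x3 no padding -> 30
n = conv_out(n, 2, 2)  # 2x2 max pool, stride 2 -> 15
n = conv_out(n, 3)     # conv2, 3x3 no padding -> 13
n = conv_out(n, 2, 2)  # 2x2 max pool, stride 2 -> 6
flat_features = 64 * n * n  # 64 channels * 6 * 6
```

Running this gives `flat_features == 2304`, matching `nn.Linear(64 * 6 * 6, 128)`.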