import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, models, regularizers
from tensorflow.keras.datasets import cifar100
import matplotlib.pyplot as plt
# Configure the sans-serif font fallback list and minus-sign rendering.
plt.rcParams.update({
    "font.sans-serif": ["SimHei", "WenQuanYi Micro Hei", "Heiti TC"],
    "axes.unicode_minus": False,
})
# ResNet block definition
def resnet_block(input_tensor, filters, stride=1, conv_shortcut=True, attention=None):
    """Build one bottleneck-style residual block on top of ``input_tensor``.

    The main path is 1x1 conv -> 3x3 conv -> 1x1 conv, each followed by batch
    normalization, with a ReLU after the first two. The main path is added to
    the shortcut and the sum goes through a final ReLU.

    Args:
        input_tensor: Keras tensor feeding the block.
        filters: Width of the first two convolutions. The last convolution and
            the projection shortcut both output ``4 * filters`` channels.
        stride: Stride of the first 1x1 convolution and of the projection
            shortcut.
        conv_shortcut: If true, the shortcut is a strided 1x1 convolution plus
            batch normalization; otherwise ``input_tensor`` is reused as-is.
        attention: Accepted for interface compatibility; not used here.

    Returns:
        The output tensor of the block.
    """
    channel_axis = 3 if tf.keras.backend.image_data_format() == 'channels_last' else 1

    def conv_bn(tensor, width, kernel, strides=1, padding='valid'):
        # Convolution with L2 weight decay, followed by batch normalization.
        tensor = layers.Conv2D(
            width,
            kernel,
            strides=strides,
            padding=padding,
            kernel_regularizer=regularizers.l2(1e-4),
        )(tensor)
        return layers.BatchNormalization(axis=channel_axis)(tensor)

    # The shortcut layers are created before the main path so that the order
    # in which layers are instantiated stays the same.
    residual = conv_bn(input_tensor, 4 * filters, 1, strides=stride) if conv_shortcut else input_tensor

    out = conv_bn(input_tensor, filters, 1, strides=stride)
    out = layers.Activation('relu')(out)
    out = conv_bn(out, filters, 3, padding='same')
    out = layers.Activation('relu')(out)
    out = conv_bn(out, 4 * filters, 1)

    out = layers.Add()([residual, out])
    return layers.Activation('relu')(out)
# Build the ResNet34 model
def build_resnet34(input_shape=(32, 32, 3), classes=100, attention=None):
    """Assemble the classification network used by this script.

    The network is a 7x7/stride-2 stem with max pooling, four stages of
    ``resnet_block`` (3, 4, 6 and 3 blocks with base widths 64, 128, 256 and
    512), global average pooling and a softmax classifier.

    Args:
        input_shape: Shape of one input image, without the batch dimension.
        classes: Number of output classes.
        attention: Forwarded unchanged to every ``resnet_block`` call.

    Returns:
        An uncompiled Keras ``Model``.
    """
    inputs = layers.Input(shape=input_shape)

    # Stem: pad, 7x7 strided convolution, BN, ReLU, 3x3 max pooling.
    net = layers.ZeroPadding2D(padding=(3, 3))(inputs)
    net = layers.Conv2D(
        64,
        7,
        strides=2,
        use_bias=False,
        kernel_regularizer=regularizers.l2(1e-4),
    )(net)
    net = layers.BatchNormalization()(net)
    net = layers.Activation('relu')(net)
    net = layers.MaxPooling2D(3, strides=2, padding='same')(net)

    # (base width, number of blocks, stride of the first block) per stage.
    stage_specs = ((64, 3, 1), (128, 4, 2), (256, 6, 2), (512, 3, 2))
    for width, depth, first_stride in stage_specs:
        # The first block of a stage uses a projection shortcut; the remaining
        # blocks reuse their input as the shortcut.
        net = resnet_block(net, width, stride=first_stride, conv_shortcut=True, attention=attention)
        for _ in range(depth - 1):
            net = resnet_block(net, width, conv_shortcut=False, attention=attention)

    pooled = layers.GlobalAveragePooling2D()(net)
    outputs = layers.Dense(
        classes,
        activation='softmax',
        kernel_regularizer=regularizers.l2(1e-4),
    )(pooled)
    return models.Model(inputs, outputs)
# Load the dataset
(x_train, y_train), (x_test, y_test) = cifar100.load_data()
# Scale pixels to [0, 1] as float32. Dividing the integer image arrays directly
# by 255.0 would promote them to float64 and double their memory footprint.
x_train = x_train.astype("float32") / 255.0
x_test = x_test.astype("float32") / 255.0
# Build the base model and load its pretrained weights
model = build_resnet34()
model.load_weights('base_resnet34_cifar100.h5')
print("模型权重加载成功")
# Pruning function
def prune_conv_layer(conv_layer, prune_ratio):
    """Return ``conv_layer``'s weights with its weakest output channels removed.

    Output channels are ranked by the L2 norm of their kernel slice and the
    lowest-ranked ones are dropped. Every weight array of the layer (kernel and,
    if present, bias) is sliced along its last axis with the same indices.

    Args:
        conv_layer: Object exposing ``get_weights()`` whose first array is the
            kernel with output channels on the last axis.
        prune_ratio: Fraction of output channels to remove. The number kept is
            ``floor(channels * (1 - prune_ratio))``, clamped so that at least
            one and at most all channels remain.

    Returns:
        A list of pruned weight arrays, or the (empty) weight list unchanged
        when the layer has no weights.
    """
    weights = conv_layer.get_weights()
    if not weights:
        return weights

    kernel = weights[0]
    num_channels = kernel.shape[-1]

    # The small epsilon absorbs float error such as 10 * (1 - 0.9) evaluating
    # to 0.999..., which would otherwise floor to one channel too few.
    num_channels_to_keep = int(num_channels * (1 - prune_ratio) + 1e-9)
    num_channels_to_keep = min(max(num_channels_to_keep, 1), num_channels)

    # np.linalg.norm only accepts an int or a 2-tuple for ``axis``, so the
    # per-channel L2 norm over all leading axes is computed explicitly.
    leading_axes = tuple(range(kernel.ndim - 1))
    channel_norms = np.sqrt(np.sum(np.square(kernel), axis=leading_axes))

    # Slice from an explicit start index: a negative-count slice ``[-0:]``
    # would select every channel instead of none.
    ranked = np.argsort(channel_norms, kind='stable')
    kept_indices = np.sort(ranked[num_channels - num_channels_to_keep:])

    return [w[..., kept_indices] for w in weights]
# Rebuild the model structure with pruned convolutions
def _select_channels(kernel, prune_ratio):
    """Return the sorted indices of the output channels with the largest L2 norms."""
    total = kernel.shape[-1]
    # Epsilon absorbs float error in (1 - prune_ratio); keep at least one channel.
    keep = int(total * (1 - prune_ratio) + 1e-9)
    keep = min(max(keep, 1), total)
    norms = np.sqrt(np.sum(np.square(kernel), axis=tuple(range(kernel.ndim - 1))))
    return np.sort(np.argsort(norms, kind='stable')[total - keep:])


def _is_plain_conv(layer):
    """Return True for an ungrouped ``Conv2D`` (kernel laid out as kh, kw, in, out)."""
    # Exact type match on purpose: Conv2D subclasses may use another kernel layout.
    return type(layer) is layers.Conv2D and getattr(layer, 'groups', 1) == 1


def _as_list(tensors):
    """Return ``tensors`` as a list, wrapping a single tensor."""
    return list(tensors) if isinstance(tensors, (list, tuple)) else [tensors]


def prune_model(model, prune_ratio):
    """Return a new model with low-norm convolution channels removed.

    The original graph (including skip connections) is replayed layer by
    layer. A plain ``Conv2D`` has its output channels pruned only when every
    consumer of its output, looking through channel-preserving layers, is
    another plain ``Conv2D``; those consumers then have the matching input
    channels removed, and any batch normalization in between is sliced the
    same way. Convolutions whose output reaches a merge layer, a dense layer,
    global pooling or the model output keep all their channels, so tensor
    shapes stay consistent everywhere.

    Args:
        model: A functional Keras model in which every layer is called once.
        prune_ratio: Fraction of output channels to remove from each prunable
            convolution.

    Returns:
        A new, uncompiled functional ``Model`` carrying the pruned weights.
        ``model`` itself is not modified.
    """
    # Layers that keep the channel dimension (and its ordering) intact.
    pass_through = (
        layers.BatchNormalization,
        layers.Activation,
        layers.ReLU,
        layers.MaxPooling2D,
        layers.AveragePooling2D,
        layers.ZeroPadding2D,
        layers.Dropout,
    )

    # Index the graph: which layers read each tensor. Tensors are keyed by
    # object identity; the model keeps them alive, so ids are stable.
    consumers = {}
    for layer in model.layers:
        if isinstance(layer, layers.InputLayer):
            continue
        for tensor in _as_list(layer.input):
            consumers.setdefault(id(tensor), []).append(layer)
    output_ids = {id(tensor) for tensor in model.outputs}

    def can_prune_outputs(conv):
        # Walk forward through channel-preserving layers; every path must end
        # at a plain convolution whose input channels can be sliced to match.
        pending = [conv.output]
        while pending:
            tensor = pending.pop()
            users = consumers.get(id(tensor), [])
            if id(tensor) in output_ids or not users:
                return False
            for user in users:
                if _is_plain_conv(user):
                    continue
                if isinstance(user, pass_through):
                    pending.append(user.output)
                else:
                    return False
        return True

    rebuilt = {}  # id(original tensor) -> tensor in the pruned graph
    kept = {}     # id(original tensor) -> surviving channel indices (None = all)

    for layer in model.layers:
        config = layer.get_config()

        if isinstance(layer, layers.InputLayer):
            rebuilt[id(layer.output)] = layers.InputLayer.from_config(config).output
            kept[id(layer.output)] = None
            continue

        sources = layer.input
        source_list = _as_list(sources)
        # Multi-input layers only ever receive unpruned tensors (see can_prune_outputs).
        in_idx = kept[id(source_list[0])] if len(source_list) == 1 else None

        weights = layer.get_weights()
        out_idx = None
        if _is_plain_conv(layer):
            if can_prune_outputs(layer):
                out_idx = _select_channels(weights[0], prune_ratio)
                weights = [w[..., out_idx] for w in weights]
                config['filters'] = int(out_idx.size)
            if in_idx is not None:
                # Drop the kernel slices that read channels removed upstream.
                weights[0] = weights[0][:, :, in_idx, :]
        elif isinstance(layer, pass_through):
            out_idx = in_idx
            if in_idx is not None:
                # Only batch normalization has weights here: per-channel vectors.
                weights = [w[in_idx] for w in weights]

        # Fresh layer objects, so the original model's layers are not reused.
        clone = type(layer).from_config(config)
        new_sources = [rebuilt[id(tensor)] for tensor in source_list]
        result = clone(new_sources if isinstance(sources, (list, tuple)) else new_sources[0])
        if weights:
            clone.set_weights(weights)

        rebuilt[id(layer.output)] = result
        kept[id(layer.output)] = out_idx

    new_inputs = [rebuilt[id(tensor)] for tensor in model.inputs]
    new_outputs = [rebuilt[id(tensor)] for tensor in model.outputs]
    return models.Model(
        new_inputs[0] if len(new_inputs) == 1 else new_inputs,
        new_outputs[0] if len(new_outputs) == 1 else new_outputs,
    )
# Measure model accuracy
def evaluate_model(model, x_test, y_test):
    """Return the top-1 accuracy of ``model`` on ``(x_test, y_test)``.

    Accuracy is computed from ``model.predict`` rather than ``model.evaluate``,
    so the model does not have to be compiled and the result does not depend
    on which metrics it was compiled with.

    Args:
        model: Object exposing ``predict(x, verbose=0)`` that returns per-class
            scores with classes on the last axis.
        x_test: Inputs passed to ``model.predict``.
        y_test: Integer class labels (any shape that flattens to one label per
            sample, e.g. ``(N,)`` or ``(N, 1)``) or one-hot rows ``(N, C)``.

    Returns:
        The fraction of samples whose highest-scoring class equals the label,
        as a Python float.
    """
    scores = np.asarray(model.predict(x_test, verbose=0))
    predicted = np.argmax(scores, axis=-1)

    labels = np.asarray(y_test)
    if labels.ndim > 1 and labels.shape[-1] > 1:
        # One-hot (or score-like) labels: reduce to class indices.
        labels = np.argmax(labels, axis=-1)
    labels = labels.reshape(-1)

    return float(np.mean(predicted == labels))
# Compute the acceleration ratio
def calculate_acceleration_ratio(original_model, pruned_model):
    """Return the original model's parameter count divided by the pruned model's.

    Both arguments must expose ``count_params()``. The value is a ratio of
    parameter counts, which this script reports as the acceleration ratio.
    """
    return original_model.count_params() / pruned_model.count_params()
# Accuracy of the original model
original_accuracy = evaluate_model(model, x_test, y_test)
print(f"原始模型精度: {original_accuracy:.4f}")
# Pruning ratios to evaluate
prune_ratios = [0.1, 0.2]
for ratio in prune_ratios:
    print(f"\n剪枝比例: {ratio}")
    pruned_model = prune_model(model, ratio)
    pruned_accuracy = evaluate_model(pruned_model, x_test, y_test)
    # Reported as the acceleration ratio; computed from parameter counts.
    acceleration_ratio = calculate_acceleration_ratio(model, pruned_model)
    print(f"剪枝后精度: {pruned_accuracy:.4f}")
    print(f"精度损失: {original_accuracy - pruned_accuracy:.4f}")
    print(f"加速比: {acceleration_ratio:.2f}")