import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0" # 设定使用的 GPU
import tensorflow as tf
from dataset import generate_data
import numpy as np
from model import enhancednet
# 检查 TensorFlow 是否可以识别 GPU
gpus = tf.config.list_physical_devices('GPU')
if gpus:
try:
# 限制 TensorFlow 只使用第0号 GPU
tf.config.set_visible_devices(gpus[0], 'GPU')
tf.config.experimental.set_memory_growth(gpus[0], True)
print(f"Using GPU: {gpus[0].name}")
except RuntimeError as e:
print(e)
else:
print("No GPU available, using CPU.")
# 列出所有物理设备
print("All physical devices:")
for device in tf.config.list_physical_devices():
print(device)
print("\nAvailable GPU devices:")
for gpu in tf.config.list_physical_devices('GPU'):
print(gpu)
# 设置参数
image_rows = 128
image_cols = 256
filename = 'detached_data.mat'
# 生成和准备数据
train_data, train1_data, label_data = generate_data(filename)
train_data = np.array(train_data, dtype=float).reshape(-1, image_rows, image_cols, 1)
train1_data = np.array(train1_data, dtype=float).reshape(-1, image_rows, image_cols, 1)
# 确保数据类型为 float32
train_data = train_data.astype('float32')
train1_data = train1_data.astype('float32')
# 打印数据形状以确认
print("Train data shape:", train_data.shape)
print("Train1 data shape:", train1_data.shape)
# 创建模型
model = enhancednet()
# 编译模型(假设 enhancednet 已经包含编译逻辑,可根据需要调整)
model.compile(optimizer='adam',
loss='mean_squared_error', # 示例损失函数,依据具体任务调整
metrics=['accuracy'])
# 训练模型,并指定设备
print("\nStarting training on GPU...")
history = model.fit(train_data, train1_data,
batch_size=32,
epochs=100,
verbose=2,
shuffle=True,
validation_split=0.1)
# 保存模型
model.save('enhanced_model.h5')
print("\nModel saved to 'enhanced_model.h5'")
102005
Python结合TensorFlow进行模型训练
最新推荐文章于 2025-12-04 22:56:41 发布
部署运行你感兴趣的模型镜像
您可能感兴趣的与本文相关的镜像
TensorFlow-v2.15
TensorFlow
TensorFlow 是由Google Brain 团队开发的开源机器学习框架,广泛应用于深度学习研究和生产环境。 它提供了一个灵活的平台,用于构建和训练各种机器学习模型
275

被折叠的 条评论
为什么被折叠?



