# 这个是我现在的代码,我应该怎么修改?我传入的本来就是灰度图,以.tiff结尾
import os
import re
import glob
import tensorflow as tf
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
import matplotlib as mpl
from sklearn.model_selection import train_test_split
import imageio
import sys
from skimage.transform import resize
from skimage.filters import gaussian, threshold_otsu
from skimage.feature import canny
from skimage.measure import regionprops, label
import traceback
from tensorflow.keras import layers, models
from tensorflow.keras.optimizers import Adam
from pathlib import Path
from tensorflow.keras.losses import MeanSquaredError
from tensorflow.keras.metrics import MeanAbsoluteError
# =============== Configuration ===============================
# Root directory that holds one sub-folder per wavelength.
BASE_DIR = "F:/2025.7.2wavelengthtiff"
# First and last wavelength of the scan, and the spacing between folders.
START_WAVELENGTH = 788.555
END_WAVELENGTH = 788.556
STEP = 5e-05
# Number of samples per batch.
BATCH_SIZE = 8
# Target (height, width) every image is resized to.
IMAGE_SIZE = (256, 256)
# Fraction of the samples held out for testing.
TEST_SIZE = 0.2
# Seed passed to the train/test split.
RANDOM_SEED = 42
# Where the trained model is written (HDF5 file under the user's Documents).
MODEL_SAVE_PATH = Path.home().joinpath("Documents", "wavelength_model.h5")
# ================================================================
# Configure a CJK-capable font so the Chinese plot labels can be rendered.
try:
    mpl.rcParams['font.sans-serif'] = ['SimHei']  # use the SimHei font family
    mpl.rcParams['axes.unicode_minus'] = False  # avoid a broken minus sign glyph
    print("已设置中文字体支持")
except Exception:
    # Best-effort only: a missing font must not stop the script, but
    # KeyboardInterrupt/SystemExit are no longer swallowed here.
    print("警告:无法设置中文字体,图表可能无法正确显示中文")
def generate_folder_names(start, end, step):
    """Return the wavelength folder names from ``start`` to ``end``.

    Names are the wavelengths ``start + i * step`` formatted with five
    decimals. The number of steps is ``(end - start) / step`` truncated
    toward zero, plus one for the starting folder.

    Args:
        start: First wavelength.
        end: Last wavelength (included when it lies on the step grid).
        step: Spacing between consecutive folders; must be non-zero.

    Returns:
        A list of folder-name strings (possibly empty).

    Raises:
        ZeroDivisionError: If ``step`` is zero.
    """
    ratio = (end - start) / step
    # Floating-point division can land a hair below the exact integer
    # (e.g. 0.3 / 0.1 == 2.9999999999999996), which would silently drop the
    # last folder when truncated. Nudge the quotient away from zero by a
    # small tolerance before truncating.
    tolerance = 1e-6 if ratio >= 0 else -1e-6
    num_folders = int(ratio + tolerance) + 1
    return [f"{start + i * step:.5f}" for i in range(num_folders)]
def find_best_match_file(folder_path, target_wavelength):
    """Pick the TIFF file in ``folder_path`` whose name encodes the wavelength
    closest to ``target_wavelength``.

    A file qualifies when its base name contains a run of digits/dots
    followed by an underscore and that run parses as a float. Files matching
    ``*.tiff`` are considered before ``*.tif``; on a tie the earlier file wins.

    Returns:
        The path of the closest file, or ``None`` when the folder has no TIFF
        file with a parsable wavelength.
    """
    wavelength_prefix = re.compile(r'\s*([\d.]+)_')

    def distance_to_target(path):
        # Distance between the wavelength in the file name and the target,
        # or None when the name carries no usable number.
        hit = wavelength_prefix.search(os.path.basename(path))
        if hit is None:
            return None
        try:
            return abs(float(hit.group(1)) - target_wavelength)
        except ValueError:
            return None

    paths = [
        path
        for pattern in ("*.tiff", "*.tif")
        for path in glob.glob(os.path.join(folder_path, pattern))
    ]
    scored = [(distance_to_target(path), path) for path in paths]
    usable = [
        (distance, path)
        for distance, path in scored
        if distance is not None and distance < float('inf')
    ]
    if not usable:
        return None
    # min() keeps the first of equally close candidates.
    return min(usable, key=lambda pair: pair[0])[1]
def extract_shape_features(binary_image):
    """Summarise the connected regions of a binary mask.

    Every labelled region contributes its area, its perimeter and a
    circularity value ``4 * pi * area / perimeter**2`` (0 when the perimeter
    is not positive). The three values are averaged over all regions.

    Returns:
        A length-3 array ``[mean area, mean perimeter, mean circularity]``;
        all zeros when the mask contains no region.
    """
    regions = regionprops(label(binary_image))
    if not regions:
        # No region found: fall back to an all-zero feature vector.
        return np.zeros(3)

    def circularity(region):
        if region.perimeter > 0:
            return 4 * np.pi * (region.area / (region.perimeter ** 2))
        return 0

    per_region = [
        [region.area, region.perimeter, circularity(region)]
        for region in regions
    ]
    return np.array(per_region).mean(axis=0)
def load_and_preprocess_image(file_path):
    """Load a grayscale TIFF and build the 3-channel model input.

    The image is read as stored, reduced to a single 2-D plane if needed,
    scaled to roughly [0, 1], resized to ``IMAGE_SIZE`` and then expanded
    into three channels: the resized intensity image, a thresholded mask
    of the blurred image, and a Canny edge map of the blurred image.

    Args:
        file_path: Path of the image file to read.

    Returns:
        A tuple ``(processed, shape_features)`` where ``processed`` is a
        float32 array of shape ``(IMAGE_SIZE[0], IMAGE_SIZE[1], 3)`` and
        ``shape_features`` is a float32 array of length 3. If anything
        fails, a message is printed and all-zero arrays of the same shapes
        are returned instead.
    """
    try:
        # Read the file as stored. The previous ``as_gray=True`` keyword is
        # an option of imageio's Pillow-based readers; passing it for a TIFF
        # file makes the read fail, and the handler below then replaced every
        # sample with a blank image. The input is already grayscale, so no
        # conversion flag is needed.
        raw = np.asarray(imageio.imread(file_path))
        source_dtype = raw.dtype

        if raw.ndim == 3:
            if raw.shape[-1] in (3, 4):
                # Colour (optionally with alpha): average the colour planes.
                raw = raw[..., :3].mean(axis=-1)
            else:
                # NOTE(review): presumably a multi-page stack; only the first
                # page is used - confirm this matches the acquisition format.
                raw = raw[0]
        elif raw.ndim != 2:
            raise ValueError(f"unsupported image shape {raw.shape}")

        image = raw.astype(np.float32)
        if np.issubdtype(source_dtype, np.integer):
            # Scale by the full range of the stored integer type so that
            # 8-bit and 16-bit files both end up in [0, 1] (8-bit files are
            # still divided by 255 exactly as before).
            image /= float(np.iinfo(source_dtype).max)
        elif image.size and float(image.max()) > 1.0:
            # Floating-point data above 1 keeps the original /255 convention.
            image /= 255.0

        # Resize to the network input size.
        image = resize(image, (IMAGE_SIZE[0], IMAGE_SIZE[1]), anti_aliasing=True)

        # Emphasise the light spots: blur, then threshold slightly below the
        # Otsu level so that dimmer spots are kept in the mask.
        blurred = gaussian(image, sigma=1)
        thresh = threshold_otsu(blurred)
        binary = blurred > thresh * 0.8

        # Edge map of the blurred image.
        edges = canny(blurred, sigma=1)

        # Region statistics of the mask.
        shape_features = np.asarray(extract_shape_features(binary), dtype=np.float32)

        # Stack intensity, mask and edges as channels; cast explicitly so the
        # result always matches the float32 dtype of the fallback arrays.
        processed = np.stack([image, binary, edges], axis=-1).astype(np.float32)
        return processed, shape_features
    except Exception as e:
        # Deliberate best-effort: report the problem and substitute blanks so
        # one unreadable file does not abort a whole run.
        print(f"图像加载失败: {e}, 使用空白图像代替")
        return (
            np.zeros((IMAGE_SIZE[0], IMAGE_SIZE[1], 3), dtype=np.float32),
            np.zeros(3, dtype=np.float32),
        )
def create_tiff_dataset(file_paths):
    """Build a batched ``tf.data`` pipeline from a list of image paths.

    Each path is loaded through ``load_and_preprocess_image`` (wrapped in
    ``tf.py_function``), yielding an ``(image, features)`` pair. The pairs
    are grouped into batches of ``BATCH_SIZE`` and prefetched.
    """
    image_shape = (IMAGE_SIZE[0], IMAGE_SIZE[1], 3)  # three stacked channels
    feature_shape = (3,)  # shape-feature vector

    def _numpy_loader(path_tensor):
        # Runs eagerly inside py_function: decode the path and load the file.
        return load_and_preprocess_image(path_tensor.numpy().decode('utf-8'))

    def _graph_loader(path_tensor):
        image, features = tf.py_function(
            func=_numpy_loader,
            inp=[path_tensor],
            Tout=[tf.float32, tf.float32]
        )
        # py_function loses static shapes, so declare them explicitly.
        image.set_shape(image_shape)
        features.set_shape(feature_shape)
        return image, features

    return (
        tf.data.Dataset.from_tensor_slices(file_paths)
        .map(_graph_loader, num_parallel_calls=tf.data.AUTOTUNE)
        .batch(BATCH_SIZE)
        .prefetch(tf.data.AUTOTUNE)
    )
def load_and_prepare_data():
    """Collect the image files and build the training and test datasets.

    One file is picked per wavelength folder under ``BASE_DIR``; the folder's
    wavelength becomes the label. Labels are min-max normalised, the samples
    are split with ``train_test_split``, and each split is turned into a
    dataset whose elements are ``((images, features), labels)`` batches.

    Returns:
        ``(train_dataset, test_dataset, valid_files, min_wavelength,
        wavelength_range)``. The last two values are needed to map a
        normalised prediction back to a wavelength.

    Raises:
        ValueError: If no folder names are generated or no usable file is
            found.
    """
    # Generate every expected folder name.
    folder_names = generate_folder_names(START_WAVELENGTH, END_WAVELENGTH, STEP)
    print(f"\n生成的文件夹数量: {len(folder_names)}")
    if not folder_names:
        # Guard the [0] / [-1] accesses below against an empty list.
        raise ValueError("未找到任何有效文件,请检查路径和文件夹名称")
    print(f"起始文件夹: {folder_names[0]}")
    print(f"结束文件夹: {folder_names[-1]}")

    # Collect one file per existing folder.
    valid_files = []
    wavelengths = []
    print("\n扫描文件夹并匹配文件...")
    for folder_name in tqdm(folder_names, desc="处理文件夹"):
        folder_path = os.path.join(BASE_DIR, folder_name)
        if not os.path.isdir(folder_path):
            continue
        try:
            target_wavelength = float(folder_name)
        except ValueError:
            continue
        file_path = find_best_match_file(folder_path, target_wavelength)
        if file_path:
            valid_files.append(file_path)
            wavelengths.append(target_wavelength)
    print(f"\n找到的有效文件: {len(valid_files)}/{len(folder_names)}")
    if not valid_files:
        raise ValueError("未找到任何有效文件,请检查路径和文件夹名称")

    # Min-max normalise the wavelength labels.
    wavelengths = np.array(wavelengths)
    min_wavelength = np.min(wavelengths)
    max_wavelength = np.max(wavelengths)
    wavelength_range = max_wavelength - min_wavelength
    # With a single distinct wavelength the range is 0; divide by 1 instead
    # so the labels become 0 rather than NaN.
    divisor = wavelength_range if wavelength_range > 0 else 1.0
    wavelengths_normalized = ((wavelengths - min_wavelength) / divisor).astype(np.float32)
    print(f"波长范围: {min_wavelength:.6f} 到 {max_wavelength:.6f}, 范围大小: {wavelength_range:.6f}")

    # Split into training and test samples.
    train_files, test_files, train_wavelengths, test_wavelengths = train_test_split(
        valid_files, wavelengths_normalized, test_size=TEST_SIZE, random_state=RANDOM_SEED
    )
    print(f"训练集大小: {len(train_files)}")
    print(f"测试集大小: {len(test_files)}")

    # Image/feature datasets (already batched by create_tiff_dataset).
    train_dataset = create_tiff_dataset(train_files)
    test_dataset = create_tiff_dataset(test_files)

    # The label datasets must be batched with the same batch size; zipping an
    # unbatched label stream with batched inputs would pair one scalar label
    # with a whole batch of images.
    train_labels = tf.data.Dataset.from_tensor_slices(train_wavelengths).batch(BATCH_SIZE)
    test_labels = tf.data.Dataset.from_tensor_slices(test_wavelengths).batch(BATCH_SIZE)

    # Combine inputs and labels into ((images, features), labels) elements.
    train_dataset = tf.data.Dataset.zip((train_dataset, train_labels))
    test_dataset = tf.data.Dataset.zip((test_dataset, test_labels))
    return train_dataset, test_dataset, valid_files, min_wavelength, wavelength_range
def build_spot_detection_model(input_shape, feature_shape):
    """Assemble and compile the two-input regression network.

    The image input is split into its three channels, each channel passes
    through its own convolution stage, and the results are concatenated and
    fed through further convolution stages ending in global average pooling.
    The shape-feature input goes through a small dense stage. Both paths are
    concatenated and reduced to a single sigmoid output.

    Args:
        input_shape: Shape of one image sample.
        feature_shape: Shape of one shape-feature vector.

    Returns:
        A compiled ``tf.keras.Model`` taking ``[image, features]``.
    """
    def conv_bn(tensor, filters):
        # 3x3 same-padded ReLU convolution followed by batch normalisation.
        tensor = layers.Conv2D(filters, (3, 3), activation='relu', padding='same')(tensor)
        return layers.BatchNormalization()(tensor)

    def conv_bn_pool(tensor, filters):
        # conv_bn plus 2x2 max pooling.
        return layers.MaxPooling2D((2, 2))(conv_bn(tensor, filters))

    image_in = tf.keras.Input(shape=input_shape, name='input_image')
    shape_in = tf.keras.Input(shape=feature_shape, name='input_features')

    # Lambda layers are used for the channel slicing.
    intensity_plane = layers.Lambda(lambda x: x[..., 0:1])(image_in)
    mask_plane = layers.Lambda(lambda x: x[..., 1:2])(image_in)
    edge_plane = layers.Lambda(lambda x: x[..., 2:3])(image_in)

    # One independent stage per channel: intensity, mask, edges.
    branches = [
        conv_bn_pool(plane, 32)
        for plane in (intensity_plane, mask_plane, edge_plane)
    ]
    trunk = layers.concatenate(branches)

    # Shared feature extractor.
    for width in (64, 128):
        trunk = conv_bn_pool(trunk, width)
    trunk = layers.GlobalAveragePooling2D()(conv_bn(trunk, 256))

    # Shape-feature path.
    shape_path = layers.Dense(64, activation='relu')(shape_in)
    shape_path = layers.Dropout(0.5)(shape_path)

    # Fuse both paths and regress.
    head = layers.Concatenate()([trunk, shape_path])
    for units, drop_rate in ((512, 0.5), (256, 0.3)):
        head = layers.Dense(units, activation='relu')(head)
        head = layers.Dropout(drop_rate)(head)
    prediction = layers.Dense(1, activation='sigmoid')(head)

    model = tf.keras.Model(inputs=[image_in, shape_in], outputs=prediction)
    model.compile(
        optimizer=Adam(learning_rate=0.0001),
        loss='mean_squared_error',
        metrics=['mae']
    )
    return model
def train_and_evaluate_model(train_dataset, test_dataset, input_shape, feature_shape, wavelength_range):
    """Build, train, evaluate and save the model.

    The test dataset doubles as the validation data during training. After
    training, the test MAE is printed both in normalised units and scaled by
    ``wavelength_range``, the loss/MAE curves are written to a PNG, and the
    model is saved to ``MODEL_SAVE_PATH``.

    Returns:
        The trained model.
    """
    model = build_spot_detection_model(input_shape, feature_shape)
    model.summary()

    # Callbacks, all driven by the validation loss.
    early_stopping = tf.keras.callbacks.EarlyStopping(
        monitor='val_loss',
        patience=20,
        min_delta=1e-6,
        restore_best_weights=True
    )
    checkpoint = tf.keras.callbacks.ModelCheckpoint(
        str(MODEL_SAVE_PATH),  # the callback is given a plain string path
        monitor='val_loss',
        save_best_only=True
    )
    lr_on_plateau = tf.keras.callbacks.ReduceLROnPlateau(
        monitor='val_loss',
        factor=0.5,
        patience=5,
        min_lr=1e-7
    )

    history = model.fit(
        train_dataset,
        validation_data=test_dataset,
        epochs=200,
        callbacks=[early_stopping, checkpoint, lr_on_plateau],
        verbose=2
    )

    # Evaluate on the test data and convert the MAE back to wavelength units.
    print("\n评估测试集性能...")
    _, test_mae_normalized = model.evaluate(test_dataset, verbose=0)
    test_mae = test_mae_normalized * wavelength_range
    print(f"测试集MAE (归一化值): {test_mae_normalized:.6f}")
    print(f"测试集MAE (原始波长单位): {test_mae:.8f} 纳米")

    # Plot the training history: one panel for the loss, one for the MAE.
    panels = (
        ('loss', 'val_loss', '训练损失', '验证损失', '损失变化', '损失'),
        ('mae', 'val_mae', '训练MAE', '验证MAE', 'MAE变化', 'MAE'),
    )
    plt.figure(figsize=(12, 6))
    for position, (train_key, val_key, train_label, val_label, title, y_label) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.plot(history.history[train_key], label=train_label)
        plt.plot(history.history[val_key], label=val_label)
        plt.title(title)
        plt.xlabel('Epoch')
        plt.ylabel(y_label)
        plt.legend()
    plt.tight_layout()
    plt.savefig('f:/phD/代码/training_history.png')
    print("训练历史图已保存为 'training_history.png'")

    # Save the final model explicitly.
    model.save(MODEL_SAVE_PATH)
    return model
def predict_test_image(model, test_image_path, min_wavelength, wavelength_range):
    """Predict the wavelength for one image file and save a preview figure.

    The image is preprocessed with ``load_and_preprocess_image``, passed to
    the model as a batch of one, and the normalised output is mapped back
    with ``prediction * wavelength_range + min_wavelength``.

    Returns:
        The predicted wavelength.
    """
    channels, shape_vector = load_and_preprocess_image(test_image_path)

    # Add the leading batch axis expected by the model.
    image = channels[np.newaxis, ...]
    features = shape_vector[np.newaxis, ...]

    normalized_output = model.predict([image, features], verbose=0)[0][0]
    predicted_wavelength = normalized_output * wavelength_range + min_wavelength

    # Show the intensity channel next to the thresholded channel.
    previews = (
        (0, f"原始光场强度分布"),
        (1, f"增强光点特征"),
    )
    plt.figure(figsize=(12, 6))
    for channel_index, caption in previews:
        plt.subplot(1, 2, channel_index + 1)
        plt.imshow(image[0, :, :, channel_index], cmap='gray')
        plt.title(caption)
        plt.axis('off')
    plt.suptitle(f"预测波长: {predicted_wavelength:.6f} 纳米", fontsize=16)

    # Write the figure to disk.
    result_path = "f:/phD/代码/prediction_result.png"
    plt.savefig(result_path)
    print(f"\n预测结果已保存为 '{result_path}'")
    return predicted_wavelength
def validate_data_loading(file_paths, num_samples=3):
    """Preprocess the first few files and save a grid of their channels.

    For up to ``num_samples`` files, the three channels returned by
    ``load_and_preprocess_image`` are drawn side by side (one row per file)
    and the value ranges of the first two channels are printed. The grid is
    written to a PNG file.
    """
    print("\n验证数据加载...")
    plt.figure(figsize=(15, 10))
    sample_count = min(num_samples, len(file_paths))
    for i in range(sample_count):
        file_path = file_paths[i]
        image, features = load_and_preprocess_image(file_path)

        # Columns: intensity image, thresholded mask, edge map.
        captions = (
            f"原始图像 {i+1}",
            f"增强光点特征 {i+1}",
            f"边缘检测 {i+1}",
        )
        for column, caption in enumerate(captions):
            plt.subplot(num_samples, 3, i*3 + column + 1)
            plt.imshow(image[..., column], cmap='gray')
            plt.title(caption)
            plt.axis('off')

        intensity = image[..., 0]
        mask = image[..., 1]
        print(f"图像 {i+1}: {file_path}")
        print(f"形状: {image.shape}, 原始值范围: {np.min(intensity):.2f}-{np.max(intensity):.2f}")
        print(f"增强值范围: {np.min(mask):.2f}-{np.max(mask):.2f}")
    plt.tight_layout()
    plt.savefig('f:/phD/代码/data_validation.png')
    print("数据验证图已保存为 'data_validation.png'")
def main():
    """Run the whole pipeline: load data, train, self-test, then user test.

    Steps:
        1. Build the datasets and save a preview of a few preprocessed files.
        2. Train and evaluate the model.
        3. Predict one sample taken from the test dataset and report the
           error against its true wavelength.
        4. Ask the user for an image path and predict its wavelength.

    Failures in steps 1 and 2 end the run; failures in steps 3 and 4 are
    reported and do not raise.
    """
    print(f"TensorFlow 版本: {tf.__version__}")

    # 1. Load the data.
    try:
        (train_dataset, test_dataset, all_files,
         min_wavelength, wavelength_range) = load_and_prepare_data()
        print(f"最小波长: {min_wavelength:.6f}, 波长范围: {wavelength_range:.6f}")
    except Exception as e:
        print(f"数据加载失败: {str(e)}")
        return

    # Visual sanity check of the preprocessing.
    validate_data_loading(all_files[:3])

    # Determine the model input shapes from one training batch.
    try:
        # Dataset elements are ((images, features), labels); the previous
        # flat two-name unpacking bound the input pair to ``images`` and
        # always ended up in the fallback below.
        for (images, features), _ in train_dataset.take(1):
            input_shape = images.shape[1:]
            feature_shape = features.shape[1:]
        print(f"输入形状: {input_shape}")
        print(f"特征形状: {feature_shape}")
    except Exception as e:
        print(f"获取输入形状失败: {str(e)}")
        input_shape = (IMAGE_SIZE[0], IMAGE_SIZE[1], 3)  # three channels
        feature_shape = (3,)  # shape-feature vector
        print(f"使用默认形状: {input_shape}, {feature_shape}")

    # 2. Train the model.
    print("\n开始训练模型...")
    try:
        model = train_and_evaluate_model(
            train_dataset, test_dataset, input_shape, feature_shape, wavelength_range
        )
    except Exception as e:
        print(f"模型训练失败: {str(e)}")
        traceback.print_exc()
        return

    # 3. Self-test on the first sample of the first test batch.
    print("\n从测试集中随机选择一张图片进行预测...")
    try:
        # Same element structure as above: ((images, features), labels).
        for (test_images, _test_features), test_labels in test_dataset.take(1):
            if test_images.shape[0] > 0:
                test_image = test_images[0].numpy()
                # The labels may arrive as a scalar or as a batch vector.
                labels_np = test_labels.numpy()
                if labels_np.ndim == 0:
                    true_wavelength_normalized = labels_np.item()
                else:
                    true_wavelength_normalized = labels_np[0]
                # Map the normalised label back to a wavelength.
                true_wavelength = true_wavelength_normalized * wavelength_range + min_wavelength
                # Write the intensity channel to disk so that the prediction
                # goes through the same file-based path as user images.
                test_image_path = "f:/phD/代码/test_image.tiff"
                imageio.imwrite(test_image_path, (test_image[..., 0] * 255).astype(np.uint8))
                predicted_wavelength = predict_test_image(
                    model, test_image_path, min_wavelength, wavelength_range
                )
                print(f"真实波长: {true_wavelength:.6f} 纳米")
                print(f"预测波长: {predicted_wavelength:.6f} 纳米")
                print(f"绝对误差: {abs(predicted_wavelength-true_wavelength):.8f} 纳米")
                print(f"相对误差: {abs(predicted_wavelength-true_wavelength)/wavelength_range*100:.4f}%")
            else:
                print("错误:测试批次中没有样本")
    except Exception as e:
        print(f"测试失败: {str(e)}")
        traceback.print_exc()

    # 4. Prediction on a user-supplied image.
    print("\n您可以使用自己的图片进行测试:")
    try:
        # Reload the saved model; if that fails, keep the in-memory model.
        model = tf.keras.models.load_model(str(MODEL_SAVE_PATH))
    except Exception as e:
        print(f"模型加载失败: {str(e)},继续使用内存中的模型")
    try:
        image_path = input("请输入您要测试的图片路径(例如:'test_image.tiff'):")
    except EOFError:
        # No interactive input available.
        image_path = ""
    # Tolerate surrounding whitespace and quotes copied along with the path.
    image_path = image_path.strip().strip('\'"')
    if image_path:
        try:
            # predict_test_image needs the normalisation constants to map
            # the output back to a wavelength; they were missing before.
            predicted = predict_test_image(model, image_path, min_wavelength, wavelength_range)
            print(f"预测波长: {predicted:.6f} 纳米")
        except Exception as e:
            print(f"测试失败: {str(e)}")
            traceback.print_exc()
    print("\n程序执行完成。")
if __name__ == "__main__":
    # Lower the TensorFlow C++ log verbosity.
    # NOTE(review): tensorflow is already imported at the top of this file by
    # the time this line runs, so the setting may come too late to affect its
    # start-up logging - consider moving it above the tensorflow import.
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
    # The former "import, else pip install" fallback was removed: imageio and
    # scikit-image are imported unconditionally at the top of the file, so a
    # missing package raises there and the fallback could never be reached.
    main()
# 最新发布