import os
import pandas as pd
import numpy as np
from PIL import Image
import torch
import torch.nn as nn
from torchvision import transforms, models
from torch.utils.data import Dataset, DataLoader
import cv2
import matplotlib.pyplot as plt
from sklearn.metrics import classification_report
from tqdm import tqdm
import warnings
# Matplotlib setup so Chinese labels render (SimHei first, DejaVu Sans as a
# fallback) and the minus sign is drawn with an ASCII hyphen.
plt.rcParams.update({
    'font.sans-serif': ['SimHei', 'DejaVu Sans'],
    'axes.unicode_minus': False,
})
# NOTE(review): this silences every warning process-wide, including
# deprecation warnings raised by torch / torchvision.
warnings.filterwarnings('ignore')
# Run on CUDA when it is available, otherwise on the CPU.
_device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(_device_name)
print(f'使用设备: {device}')
# ==================== 高效特征提取器 ====================
class AdvancedFeatureExtractor:
    """Turn an image file into a 257-dim feature vector.

    The vector holds 256 outputs of an EfficientNetV2-S backbone with a
    replaced head. It is followed by one hand-crafted "disease ratio": the
    fraction of pixels in lesion-like HSV colour ranges.
    """

    # 256 CNN features + 1 disease ratio.
    FEATURE_DIM = 257

    def __init__(self):
        # ImageNet-pretrained EfficientNetV2-S. ``weights=`` is the supported
        # replacement for the deprecated ``pretrained=True`` flag.
        # IMAGENET1K_V1 is the checkpoint that ``pretrained=True`` selected.
        self.model = models.efficientnet_v2_s(
            weights=models.EfficientNet_V2_S_Weights.IMAGENET1K_V1
        )
        # Replace the 1000-class head with a 1280 -> 512 -> 256 projection.
        # NOTE(review): this head is randomly initialised and never trained.
        # The 256 outputs are therefore a random projection of the backbone's
        # pooled features. They change between runs unless torch's RNG is
        # seeded first.
        self.model.classifier = nn.Sequential(
            nn.Dropout(p=0.3, inplace=True),
            nn.Linear(1280, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
        )
        # eval() turns the Dropout above into a no-op during extraction.
        self.model.eval()
        self.model.to(device)
        # Deterministic preprocessing. Random flips, rotations and colour
        # jitter are deliberately NOT applied here. They would give the same
        # image different features on every call, adding noise to both
        # pseudo-labelling and prediction.
        self.transform = transforms.Compose([
            transforms.Resize((300, 300)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])

    def extract_features(self, image_path):
        """Return the 257-dim feature vector for ``image_path``.

        Missing files and any processing error yield an all-zero vector, so
        callers always get a fixed-length result. Errors are printed.
        """
        if not os.path.exists(image_path):
            return np.zeros(self.FEATURE_DIM)
        try:
            # Context manager closes the underlying file handle promptly.
            with Image.open(image_path) as raw:
                img = raw.convert('RGB')
            img_tensor = self.transform(img).unsqueeze(0).to(device)
            with torch.no_grad():
                features = self.model(img_tensor).cpu().numpy().flatten()
            # The ratio is measured on the full-resolution original image.
            disease_ratio = self.calculate_disease_ratio(np.array(img))
            return np.append(features, disease_ratio)
        except Exception as e:
            print(f"特征提取错误: {image_path} - {e}")
            return np.zeros(self.FEATURE_DIM)

    def calculate_disease_ratio(self, image):
        """Return the fraction of pixels whose colour looks like a lesion.

        Args:
            image: RGB image as a uint8 NumPy array of shape (H, W, 3), e.g.
                ``np.array(pil_image.convert('RGB'))``.

        Returns:
            A float in [0, 1]. Returns 0.0 for an empty image, or if
            processing fails (the error is printed).
        """
        try:
            total_pixels = image.shape[0] * image.shape[1]
            if total_pixels == 0:
                return 0.0
            hsv = cv2.cvtColor(image, cv2.COLOR_RGB2HSV)
            # OpenCV 8-bit HSV ranges: H in [0, 180], S and V in [0, 255].
            # Brown lesions
            brown_mask = cv2.inRange(hsv, np.array([5, 30, 20]),
                                     np.array([20, 200, 180]))
            # Dark / black lesions
            black_mask = cv2.inRange(hsv, np.array([0, 0, 0]),
                                     np.array([180, 100, 80]))
            # Greyish, low-saturation (mould-like) regions
            mold_mask = cv2.inRange(hsv, np.array([0, 0, 30]),
                                    np.array([180, 50, 150]))
            disease_mask = cv2.bitwise_or(cv2.bitwise_or(brown_mask, black_mask),
                                          mold_mask)
            # Opening removes isolated specks. The dilation then slightly
            # grows the surviving regions.
            kernel = np.ones((5, 5), np.uint8)
            disease_mask = cv2.morphologyEx(disease_mask, cv2.MORPH_OPEN, kernel)
            disease_mask = cv2.dilate(disease_mask, kernel, iterations=1)
            return np.count_nonzero(disease_mask) / total_pixels
        except Exception as e:
            print(f"病害比例计算错误: {e}")
            return 0.0
# ==================== 深度分类模型 ====================
class DiseaseClassifier(nn.Module):
    """MLP that maps a feature vector to disease-stage logits.

    Three hidden stages (512, 256, 128 units), each Linear -> BatchNorm1d ->
    ReLU. The first two stages are also followed by Dropout. A final Linear
    produces ``output_dim`` raw logits. Layers live in ``self.classifier``,
    so saved state_dict keys are ``classifier.<index>.*``.
    """

    def __init__(self, input_dim=257, output_dim=3):
        super().__init__()
        # (in_features, out_features, dropout or None) for each hidden stage.
        stages = [(input_dim, 512, 0.4), (512, 256, 0.3), (256, 128, None)]
        layers = []
        for in_features, out_features, drop_p in stages:
            layers.extend([
                nn.Linear(in_features, out_features),
                nn.BatchNorm1d(out_features),
                nn.ReLU(),
            ])
            if drop_p is not None:
                layers.append(nn.Dropout(drop_p))
        layers.append(nn.Linear(128, output_dim))
        self.classifier = nn.Sequential(*layers)

    def forward(self, x):
        logits = self.classifier(x)
        return logits
# ==================== 石榴病害分析系统 ====================
class PomegranateDiseaseSystem:
    """Pipeline: scan images, extract features, pseudo-label, train, predict, export.

    Each step reads attributes written by the previous one. The intended
    call order is load_data -> extract_features -> encode_labels ->
    train_classifier -> predict -> apply_temporal_constraints ->
    generate_results.
    """

    def __init__(self, data_path):
        """Build the feature extractor, the classifier and the training objects.

        Args:
            data_path: root directory containing ``week_*`` sub-folders.
        """
        self.data_path = data_path
        self.feature_extractor = AdvancedFeatureExtractor()
        self.classifier = DiseaseClassifier().to(device)
        self.class_names = ['健康期', '初发期', '发病期']
        # Per-image metadata; all lists are index-aligned.
        self.image_paths = []
        self.weeks = []
        self.fruit_ids = []
        self.tree_ids = []
        # Filled by extract_features / encode_labels / predict respectively.
        self.features = []
        self.labels = []
        self.disease_ratios = []
        self.predicted_labels = []
        # Training configuration
        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = torch.optim.AdamW(
            self.classifier.parameters(),
            lr=0.001,
            weight_decay=1e-4,
        )
        # mode='max' because the scheduler is stepped with training accuracy.
        # The scheduler's ``verbose`` flag is deprecated in recent PyTorch, so
        # train_classifier prints learning-rate changes itself.
        self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer,
            mode='max',
            factor=0.5,
            patience=5,
        )

    def load_data(self):
        """Scan ``data_path/week_*/<tree>/<image>`` and record per-image metadata.

        Image names are expected to look like ``<tree>_<fruit>.<ext>``. The
        fruit number is the last ``_``-separated field of the file stem.

        Returns:
            True if at least one ``week_*`` folder was found, else False.
        """
        print("扫描数据目录...")
        self.image_paths = []
        self.weeks = []
        self.fruit_ids = []
        self.tree_ids = []
        if not os.path.isdir(self.data_path):
            print(f"错误: 数据目录不存在: {self.data_path}")
            return False
        week_folders = sorted(
            f for f in os.listdir(self.data_path)
            if f.startswith('week_') and os.path.isdir(os.path.join(self.data_path, f))
        )
        if not week_folders:
            print(f"错误: 在 {self.data_path} 中未找到week_开头的文件夹")
            return False
        image_exts = ('.jpg', '.jpeg', '.png')
        for week_folder in tqdm(week_folders, desc="处理周数文件夹"):
            week_path = os.path.join(self.data_path, week_folder)
            # os.listdir order is arbitrary. Sorting keeps the image order,
            # and so every index-aligned list, reproducible across runs.
            tree_folders = sorted(
                f for f in os.listdir(week_path)
                if os.path.isdir(os.path.join(week_path, f))
            )
            for tree_folder in tree_folders:
                tree_path = os.path.join(week_path, tree_folder)
                image_files = sorted(
                    f for f in os.listdir(tree_path)
                    if f.lower().endswith(image_exts)
                )
                for image_file in image_files:
                    # splitext keeps stems that themselves contain dots.
                    stem = os.path.splitext(image_file)[0]
                    fruit_num = stem.split('_')[-1]
                    self.image_paths.append(os.path.join(tree_path, image_file))
                    self.weeks.append(week_folder)
                    self.fruit_ids.append(f"{tree_folder}_{fruit_num}")
                    self.tree_ids.append(tree_folder)
        print(f"找到 {len(self.image_paths)} 张图像")
        return True

    def _print_distribution(self, title, labels):
        """Print the count and percentage of every class in ``labels``."""
        print(title)
        total = len(labels)
        # minlength guarantees one bin per class even when a class is absent.
        counts = np.bincount(np.asarray(labels, dtype=np.int64),
                             minlength=len(self.class_names))
        for i, name in enumerate(self.class_names):
            pct = counts[i] / total * 100 if total else 0.0
            print(f" {name}: {counts[i]} 张图像 ({pct:.1f}%)")

    def encode_labels(self):
        """Assign rule-based pseudo labels (0/1/2) from week number and disease ratio.

        Requires ``self.disease_ratios`` (filled by extract_features).
        Weeks 1-4 use fixed thresholds. Weeks 5-8 and later raise the
        thresholds relative to the median ratio of the same week.
        """
        print("生成初始伪标签...")
        self.labels = []
        if len(self.disease_ratios) != len(self.weeks):
            print("缺少病害比例,请先提取特征")
            return
        # Median disease ratio per week folder.
        week_ratios = {}
        for week, ratio in zip(self.weeks, self.disease_ratios):
            week_ratios.setdefault(week, []).append(ratio)
        week_median = {week: np.median(r) for week, r in week_ratios.items()}

        for week, ratio in zip(self.weeks, self.disease_ratios):
            week_num = int(week.split('_')[1])
            median_ratio = week_median[week]
            if week_num <= 4:  # early weeks: fixed thresholds
                healthy_max, early_max = 0.02, 0.10
            elif week_num <= 8:  # middle weeks
                healthy_max = max(0.03, median_ratio * 0.8)
                early_max = max(0.15, median_ratio * 1.2)
            else:  # late weeks
                healthy_max = max(0.05, median_ratio * 0.7)
                early_max = max(0.20, median_ratio * 1.3)
            if ratio < healthy_max:
                self.labels.append(0)  # healthy
            elif ratio < early_max:
                self.labels.append(1)  # early onset
            else:
                self.labels.append(2)  # diseased
        self._print_distribution("初始伪标签分布:", self.labels)

    def extract_features(self):
        """Run the feature extractor on every image.

        Fills ``self.features``, an (N, 257) array whose last column is the
        disease ratio, and ``self.disease_ratios``.
        """
        print("提取图像特征...")
        self.features = []
        self.disease_ratios = []
        for path in tqdm(self.image_paths, desc="处理图像"):
            features = self.feature_extractor.extract_features(path)
            self.features.append(features)
            self.disease_ratios.append(features[-1])  # disease ratio
        self.features = np.array(self.features)
        print(f"特征矩阵形状: {self.features.shape}")

    def train_classifier(self, num_epochs=50):
        """Train the classifier on the pseudo labels and keep the best weights.

        "Best" means highest training accuracy against the pseudo labels.
        The best state is kept in memory, also written to
        ``best_classifier.pth``, and loaded back into the classifier at the end.
        """
        if len(self.features) == 0 or len(self.labels) == 0:
            print("缺少特征或标签,无法训练")
            return
        if len(self.features) != len(self.labels):
            print("特征与标签数量不一致,无法训练")
            return
        if len(self.labels) < 2:
            # BatchNorm1d cannot compute batch statistics from one sample.
            print("样本数过少,无法训练")
            return
        print("训练深度分类器...")
        features_tensor = torch.tensor(self.features, dtype=torch.float32).to(device)
        labels_tensor = torch.tensor(self.labels, dtype=torch.long).to(device)
        dataset = torch.utils.data.TensorDataset(features_tensor, labels_tensor)
        batch_size = 64
        # BatchNorm1d raises on a training batch of size 1, so drop a
        # trailing batch that would contain a single sample.
        drop_last = len(dataset) > batch_size and len(dataset) % batch_size == 1
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True,
                                drop_last=drop_last)

        best_acc = -1.0
        best_state = None
        for epoch in range(num_epochs):
            self.classifier.train()
            epoch_loss = 0.0
            correct = 0
            total = 0
            for inputs, targets in tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}"):
                self.optimizer.zero_grad()
                outputs = self.classifier(inputs)
                loss = self.criterion(outputs, targets)
                loss.backward()
                self.optimizer.step()
                epoch_loss += loss.item()
                predicted = outputs.argmax(dim=1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()
            epoch_loss /= len(dataloader)
            epoch_acc = 100. * correct / total
            print(f"Epoch [{epoch+1}/{num_epochs}] - Loss: {epoch_loss:.4f} Acc: {epoch_acc:.2f}%")
            # Update the learning rate and report any change.
            prev_lr = self.optimizer.param_groups[0]['lr']
            self.scheduler.step(epoch_acc)
            new_lr = self.optimizer.param_groups[0]['lr']
            if new_lr != prev_lr:
                print(f"学习率调整: {prev_lr:.2e} -> {new_lr:.2e}")
            if epoch_acc > best_acc:
                best_acc = epoch_acc
                # Clone the tensors: state_dict() returns live references
                # that later optimizer steps would overwrite.
                best_state = {k: v.detach().clone()
                              for k, v in self.classifier.state_dict().items()}
                torch.save(best_state, 'best_classifier.pth')
        if best_state is None:  # num_epochs <= 0: nothing was trained
            print("未进行任何训练轮次")
            return
        print(f"训练完成 - 最佳准确率: {best_acc:.2f}%")
        # Restore from memory, so a stale file from an earlier run can
        # never be loaded by mistake.
        self.classifier.load_state_dict(best_state)

    def predict(self):
        """Predict a stage for every image into ``self.predicted_labels`` (NumPy array)."""
        if len(self.features) == 0:
            print("缺少特征,无法预测")
            return
        print("预测病害阶段...")
        self.classifier.eval()
        features_tensor = torch.tensor(self.features, dtype=torch.float32).to(device)
        with torch.no_grad():
            outputs = self.classifier(features_tensor)
        self.predicted_labels = outputs.argmax(dim=1).cpu().numpy()
        self._print_distribution("预测结果分布:", self.predicted_labels)

    def apply_temporal_constraints(self):
        """Make each fruit's predicted stage non-decreasing over weeks.

        Images are grouped by (tree id, fruit id) and visited in week order.
        Any prediction lower than the highest stage already seen for that
        fruit is raised to it, on the assumption that disease does not
        regress. Updates ``self.predicted_labels`` in place.
        """
        print("应用时序约束...")
        if len(self.predicted_labels) != len(self.weeks):
            print("缺少预测结果,无法应用时序约束")
            return
        fruit_records = {}
        for idx, (tree_id, fruit_id, week) in enumerate(
                zip(self.tree_ids, self.fruit_ids, self.weeks)):
            week_num = int(week.split('_')[1])
            fruit_records.setdefault((tree_id, fruit_id), []).append((week_num, idx))
        for records in fruit_records.values():
            records.sort(key=lambda rec: rec[0])  # stable: ties keep scan order
            running_max = -1
            for _, idx in records:
                label = self.predicted_labels[idx]
                if label < running_max:
                    self.predicted_labels[idx] = running_max
                else:
                    running_max = label
        print("时序约束应用完成")

    @staticmethod
    def _numeric_key(series):
        """Return the first run of digits in each string as a number (NaN if none)."""
        return pd.to_numeric(series.str.extract(r'(\d+)', expand=False),
                             errors='coerce')

    def generate_results(self):
        """Write week / fruit id / predicted stage rows to ``result1.xlsx``.

        Rows are sorted numerically by week, tree and fruit number.
        """
        print("生成结果文件...")
        if len(self.predicted_labels) == 0:
            print("缺少预测结果,无法生成结果文件")
            return
        df = pd.DataFrame({
            '周数': self.weeks,
            '果实编号': self.fruit_ids,
            '果实阶段': [self.class_names[int(label)] for label in self.predicted_labels],
        })
        # fruit_id is "<tree>_<fruit>". Split on the LAST underscore,
        # because tree folder names may contain "_" themselves.
        id_parts = df['果实编号'].str.rsplit('_', n=1)
        df['周数_num'] = self._numeric_key(df['周数'])
        df['树编号_num'] = self._numeric_key(id_parts.str[0])
        df['果实编号_num'] = self._numeric_key(id_parts.str[1])
        df = df.sort_values(['周数_num', '树编号_num', '果实编号_num'])
        df = df.drop(['周数_num', '树编号_num', '果实编号_num'], axis=1)
        df.to_excel('result1.xlsx', index=False)
        print(f"结果已保存至 result1.xlsx (共 {len(df)} 条记录)")
        class_stats = df['果实阶段'].value_counts()
        print("\n病害阶段分布:")
        for stage, count in class_stats.items():
            percentage = count / len(df) * 100
            print(f"{stage}: {count} 张图像 ({percentage:.1f}%)")
def visualize_results(self):
"""可视化分析结果"""
print("生成可视化图表...")
# 病害阶段随时间变化
plt.figure(figsize=(12, 6))
week_stats = {}
for i, week in enumerate(self.weeks):
if week not in week_stats:
week_stats[week] = {'健康期': 0, '初发期': 0, '发病期': 0}
week_stats[week][self.class_names[self.predicted_labels[i]]] += 1
# 按周数排序
weeks_sorted = sorted(week_stats.keys(), key=lambda x: int(x.split('_')[1]))
health_rates = []
early_rates = []
late_rates = []
for week in weeks_sorted:
total = sum(week_stats[week].values())
health_rates.append(week_stats[week]['健康期'] / total * 100)
early_rates.append(
思维导图
参考来源
优快云
优快云
优快云
请继续给出完整代码
回答
我们正在处理石榴果实病害阶段划分的问题。之前给出的代码在 `visualize_results` 方法中途被截断，用户现在要求补全完整代码。引用[1]涉及模拟浏览器行为来获取动态加载的内容，但本任务的数据是本地图像文件，不需要网络爬取，因此只需提供完整的本地图像处理代码。
然而，用户提到“获取完整代码”，结合之前的上下文，我们应当提供可直接运行的完整代码。注意：代码中已经包含了数据加载、特征提取、模型训练、预测和结果生成的全过程。
但需要注意以下几点：
- 代码假设数据按特定目录结构存放：`<数据根目录>/week_<周数>/<树编号>/<树编号>_<果实编号>.<jpg|jpeg|png>`。
- 代码使用 EfficientNetV2 作为特征提取器，并训练一个全连接分类器。
- 由于数据量可能较大，特征提取步骤会比较耗时。
我们将提供完整的代码,包括所有必要的导入、类定义和主流程。
此外,根据用户要求,我们需要生成相关问题。结合当前任务,相关问题应围绕石榴病害阶段划分、模型改进、实际应用等方面。
完整代码如下:
以下是完整的石榴果实病害阶段划分系统代码,整合了深度学习模型、图像处理和时序分析功能:
Python
import os
import pandas as pd
import numpy as np
from PIL import Image
import torch
import torch.nn as nn
from torchvision import transforms, models
from torch.utils.data import Dataset, DataLoader
import cv2
import matplotlib.pyplot as plt
from sklearn.metrics import classification_report
from tqdm import tqdm
import warnings
# Matplotlib setup so Chinese labels render (SimHei first, DejaVu Sans as a
# fallback) and the minus sign is drawn with an ASCII hyphen.
plt.rcParams.update({
    'font.sans-serif': ['SimHei', 'DejaVu Sans'],
    'axes.unicode_minus': False,
})
# NOTE(review): this silences every warning process-wide, including
# deprecation warnings raised by torch / torchvision.
warnings.filterwarnings('ignore')
# Run on CUDA when it is available, otherwise on the CPU.
_device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(_device_name)
print(f'使用设备: {device}')
# ==================== 高效特征提取器 ====================
class AdvancedFeatureExtractor:
    """Turn an image file into a 257-dim feature vector.

    The vector holds 256 outputs of an EfficientNetV2-S backbone with a
    replaced head. It is followed by one hand-crafted "disease ratio": the
    fraction of pixels in lesion-like HSV colour ranges.
    """

    # 256 CNN features + 1 disease ratio.
    FEATURE_DIM = 257

    def __init__(self):
        # ImageNet-pretrained EfficientNetV2-S. ``weights=`` is the supported
        # replacement for the deprecated ``pretrained=True`` flag.
        # IMAGENET1K_V1 is the checkpoint that ``pretrained=True`` selected.
        self.model = models.efficientnet_v2_s(
            weights=models.EfficientNet_V2_S_Weights.IMAGENET1K_V1
        )
        # Replace the 1000-class head with a 1280 -> 512 -> 256 projection.
        # NOTE(review): this head is randomly initialised and never trained.
        # The 256 outputs are therefore a random projection of the backbone's
        # pooled features. They change between runs unless torch's RNG is
        # seeded first.
        self.model.classifier = nn.Sequential(
            nn.Dropout(p=0.3, inplace=True),
            nn.Linear(1280, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
        )
        # eval() turns the Dropout above into a no-op during extraction.
        self.model.eval()
        self.model.to(device)
        # Deterministic preprocessing. Random flips, rotations and colour
        # jitter are deliberately NOT applied here. They would give the same
        # image different features on every call, adding noise to both
        # pseudo-labelling and prediction.
        self.transform = transforms.Compose([
            transforms.Resize((300, 300)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])

    def extract_features(self, image_path):
        """Return the 257-dim feature vector for ``image_path``.

        Missing files and any processing error yield an all-zero vector, so
        callers always get a fixed-length result. Errors are printed.
        """
        if not os.path.exists(image_path):
            return np.zeros(self.FEATURE_DIM)
        try:
            # Context manager closes the underlying file handle promptly.
            with Image.open(image_path) as raw:
                img = raw.convert('RGB')
            img_tensor = self.transform(img).unsqueeze(0).to(device)
            with torch.no_grad():
                features = self.model(img_tensor).cpu().numpy().flatten()
            # The ratio is measured on the full-resolution original image.
            disease_ratio = self.calculate_disease_ratio(np.array(img))
            return np.append(features, disease_ratio)
        except Exception as e:
            print(f"特征提取错误: {image_path} - {e}")
            return np.zeros(self.FEATURE_DIM)

    def calculate_disease_ratio(self, image):
        """Return the fraction of pixels whose colour looks like a lesion.

        Args:
            image: RGB image as a uint8 NumPy array of shape (H, W, 3), e.g.
                ``np.array(pil_image.convert('RGB'))``.

        Returns:
            A float in [0, 1]. Returns 0.0 for an empty image, or if
            processing fails (the error is printed).
        """
        try:
            total_pixels = image.shape[0] * image.shape[1]
            if total_pixels == 0:
                return 0.0
            hsv = cv2.cvtColor(image, cv2.COLOR_RGB2HSV)
            # OpenCV 8-bit HSV ranges: H in [0, 180], S and V in [0, 255].
            # Brown lesions
            brown_mask = cv2.inRange(hsv, np.array([5, 30, 20]),
                                     np.array([20, 200, 180]))
            # Dark / black lesions
            black_mask = cv2.inRange(hsv, np.array([0, 0, 0]),
                                     np.array([180, 100, 80]))
            # Greyish, low-saturation (mould-like) regions
            mold_mask = cv2.inRange(hsv, np.array([0, 0, 30]),
                                    np.array([180, 50, 150]))
            disease_mask = cv2.bitwise_or(cv2.bitwise_or(brown_mask, black_mask),
                                          mold_mask)
            # Opening removes isolated specks. The dilation then slightly
            # grows the surviving regions.
            kernel = np.ones((5, 5), np.uint8)
            disease_mask = cv2.morphologyEx(disease_mask, cv2.MORPH_OPEN, kernel)
            disease_mask = cv2.dilate(disease_mask, kernel, iterations=1)
            return np.count_nonzero(disease_mask) / total_pixels
        except Exception as e:
            print(f"病害比例计算错误: {e}")
            return 0.0
# ==================== 深度分类模型 ====================
class DiseaseClassifier(nn.Module):
    """MLP that maps a feature vector to disease-stage logits.

    Three hidden stages (512, 256, 128 units), each Linear -> BatchNorm1d ->
    ReLU. The first two stages are also followed by Dropout. A final Linear
    produces ``output_dim`` raw logits. Layers live in ``self.classifier``,
    so saved state_dict keys are ``classifier.<index>.*``.
    """

    def __init__(self, input_dim=257, output_dim=3):
        super().__init__()
        # (in_features, out_features, dropout or None) for each hidden stage.
        stages = [(input_dim, 512, 0.4), (512, 256, 0.3), (256, 128, None)]
        layers = []
        for in_features, out_features, drop_p in stages:
            layers.extend([
                nn.Linear(in_features, out_features),
                nn.BatchNorm1d(out_features),
                nn.ReLU(),
            ])
            if drop_p is not None:
                layers.append(nn.Dropout(drop_p))
        layers.append(nn.Linear(128, output_dim))
        self.classifier = nn.Sequential(*layers)

    def forward(self, x):
        logits = self.classifier(x)
        return logits
# ==================== 石榴病害分析系统 ====================
class PomegranateDiseaseSystem:
    """Pipeline: scan images, extract features, pseudo-label, train, predict, export.

    Each step reads attributes written by the previous one. The intended
    call order is load_data -> extract_features -> encode_labels ->
    train_classifier -> predict -> apply_temporal_constraints ->
    generate_results.
    """

    def __init__(self, data_path):
        """Build the feature extractor, the classifier and the training objects.

        Args:
            data_path: root directory containing ``week_*`` sub-folders.
        """
        self.data_path = data_path
        self.feature_extractor = AdvancedFeatureExtractor()
        self.classifier = DiseaseClassifier().to(device)
        self.class_names = ['健康期', '初发期', '发病期']
        # Per-image metadata; all lists are index-aligned.
        self.image_paths = []
        self.weeks = []
        self.fruit_ids = []
        self.tree_ids = []
        # Filled by extract_features / encode_labels / predict respectively.
        self.features = []
        self.labels = []
        self.disease_ratios = []
        self.predicted_labels = []
        # Training configuration
        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = torch.optim.AdamW(
            self.classifier.parameters(),
            lr=0.001,
            weight_decay=1e-4,
        )
        # mode='max' because the scheduler is stepped with training accuracy.
        # The scheduler's ``verbose`` flag is deprecated in recent PyTorch, so
        # train_classifier prints learning-rate changes itself.
        self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer,
            mode='max',
            factor=0.5,
            patience=5,
        )

    def load_data(self):
        """Scan ``data_path/week_*/<tree>/<image>`` and record per-image metadata.

        Image names are expected to look like ``<tree>_<fruit>.<ext>``. The
        fruit number is the last ``_``-separated field of the file stem.

        Returns:
            True if at least one ``week_*`` folder was found, else False.
        """
        print("扫描数据目录...")
        self.image_paths = []
        self.weeks = []
        self.fruit_ids = []
        self.tree_ids = []
        if not os.path.isdir(self.data_path):
            print(f"错误: 数据目录不存在: {self.data_path}")
            return False
        week_folders = sorted(
            f for f in os.listdir(self.data_path)
            if f.startswith('week_') and os.path.isdir(os.path.join(self.data_path, f))
        )
        if not week_folders:
            print(f"错误: 在 {self.data_path} 中未找到week_开头的文件夹")
            return False
        image_exts = ('.jpg', '.jpeg', '.png')
        for week_folder in tqdm(week_folders, desc="处理周数文件夹"):
            week_path = os.path.join(self.data_path, week_folder)
            # os.listdir order is arbitrary. Sorting keeps the image order,
            # and so every index-aligned list, reproducible across runs.
            tree_folders = sorted(
                f for f in os.listdir(week_path)
                if os.path.isdir(os.path.join(week_path, f))
            )
            for tree_folder in tree_folders:
                tree_path = os.path.join(week_path, tree_folder)
                image_files = sorted(
                    f for f in os.listdir(tree_path)
                    if f.lower().endswith(image_exts)
                )
                for image_file in image_files:
                    # splitext keeps stems that themselves contain dots.
                    stem = os.path.splitext(image_file)[0]
                    fruit_num = stem.split('_')[-1]
                    self.image_paths.append(os.path.join(tree_path, image_file))
                    self.weeks.append(week_folder)
                    self.fruit_ids.append(f"{tree_folder}_{fruit_num}")
                    self.tree_ids.append(tree_folder)
        print(f"找到 {len(self.image_paths)} 张图像")
        return True

    def _print_distribution(self, title, labels):
        """Print the count and percentage of every class in ``labels``."""
        print(title)
        total = len(labels)
        # minlength guarantees one bin per class even when a class is absent.
        counts = np.bincount(np.asarray(labels, dtype=np.int64),
                             minlength=len(self.class_names))
        for i, name in enumerate(self.class_names):
            pct = counts[i] / total * 100 if total else 0.0
            print(f" {name}: {counts[i]} 张图像 ({pct:.1f}%)")

    def encode_labels(self):
        """Assign rule-based pseudo labels (0/1/2) from week number and disease ratio.

        Requires ``self.disease_ratios`` (filled by extract_features).
        Weeks 1-4 use fixed thresholds. Weeks 5-8 and later raise the
        thresholds relative to the median ratio of the same week.
        """
        print("生成初始伪标签...")
        self.labels = []
        if len(self.disease_ratios) != len(self.weeks):
            print("缺少病害比例,请先提取特征")
            return
        # Median disease ratio per week folder.
        week_ratios = {}
        for week, ratio in zip(self.weeks, self.disease_ratios):
            week_ratios.setdefault(week, []).append(ratio)
        week_median = {week: np.median(r) for week, r in week_ratios.items()}

        for week, ratio in zip(self.weeks, self.disease_ratios):
            week_num = int(week.split('_')[1])
            median_ratio = week_median[week]
            if week_num <= 4:  # early weeks: fixed thresholds
                healthy_max, early_max = 0.02, 0.10
            elif week_num <= 8:  # middle weeks
                healthy_max = max(0.03, median_ratio * 0.8)
                early_max = max(0.15, median_ratio * 1.2)
            else:  # late weeks
                healthy_max = max(0.05, median_ratio * 0.7)
                early_max = max(0.20, median_ratio * 1.3)
            if ratio < healthy_max:
                self.labels.append(0)  # healthy
            elif ratio < early_max:
                self.labels.append(1)  # early onset
            else:
                self.labels.append(2)  # diseased
        self._print_distribution("初始伪标签分布:", self.labels)

    def extract_features(self):
        """Run the feature extractor on every image.

        Fills ``self.features``, an (N, 257) array whose last column is the
        disease ratio, and ``self.disease_ratios``.
        """
        print("提取图像特征...")
        self.features = []
        self.disease_ratios = []
        for path in tqdm(self.image_paths, desc="处理图像"):
            features = self.feature_extractor.extract_features(path)
            self.features.append(features)
            self.disease_ratios.append(features[-1])  # disease ratio
        self.features = np.array(self.features)
        print(f"特征矩阵形状: {self.features.shape}")

    def train_classifier(self, num_epochs=50):
        """Train the classifier on the pseudo labels and keep the best weights.

        "Best" means highest training accuracy against the pseudo labels.
        The best state is kept in memory, also written to
        ``best_classifier.pth``, and loaded back into the classifier at the end.
        """
        if len(self.features) == 0 or len(self.labels) == 0:
            print("缺少特征或标签,无法训练")
            return
        if len(self.features) != len(self.labels):
            print("特征与标签数量不一致,无法训练")
            return
        if len(self.labels) < 2:
            # BatchNorm1d cannot compute batch statistics from one sample.
            print("样本数过少,无法训练")
            return
        print("训练深度分类器...")
        features_tensor = torch.tensor(self.features, dtype=torch.float32).to(device)
        labels_tensor = torch.tensor(self.labels, dtype=torch.long).to(device)
        dataset = torch.utils.data.TensorDataset(features_tensor, labels_tensor)
        batch_size = 64
        # BatchNorm1d raises on a training batch of size 1, so drop a
        # trailing batch that would contain a single sample.
        drop_last = len(dataset) > batch_size and len(dataset) % batch_size == 1
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True,
                                drop_last=drop_last)

        best_acc = -1.0
        best_state = None
        for epoch in range(num_epochs):
            self.classifier.train()
            epoch_loss = 0.0
            correct = 0
            total = 0
            for inputs, targets in tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}"):
                self.optimizer.zero_grad()
                outputs = self.classifier(inputs)
                loss = self.criterion(outputs, targets)
                loss.backward()
                self.optimizer.step()
                epoch_loss += loss.item()
                predicted = outputs.argmax(dim=1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()
            epoch_loss /= len(dataloader)
            epoch_acc = 100. * correct / total
            print(f"Epoch [{epoch+1}/{num_epochs}] - Loss: {epoch_loss:.4f} Acc: {epoch_acc:.2f}%")
            # Update the learning rate and report any change.
            prev_lr = self.optimizer.param_groups[0]['lr']
            self.scheduler.step(epoch_acc)
            new_lr = self.optimizer.param_groups[0]['lr']
            if new_lr != prev_lr:
                print(f"学习率调整: {prev_lr:.2e} -> {new_lr:.2e}")
            if epoch_acc > best_acc:
                best_acc = epoch_acc
                # Clone the tensors: state_dict() returns live references
                # that later optimizer steps would overwrite.
                best_state = {k: v.detach().clone()
                              for k, v in self.classifier.state_dict().items()}
                torch.save(best_state, 'best_classifier.pth')
        if best_state is None:  # num_epochs <= 0: nothing was trained
            print("未进行任何训练轮次")
            return
        print(f"训练完成 - 最佳准确率: {best_acc:.2f}%")
        # Restore from memory, so a stale file from an earlier run can
        # never be loaded by mistake.
        self.classifier.load_state_dict(best_state)

    def predict(self):
        """Predict a stage for every image into ``self.predicted_labels`` (NumPy array)."""
        if len(self.features) == 0:
            print("缺少特征,无法预测")
            return
        print("预测病害阶段...")
        self.classifier.eval()
        features_tensor = torch.tensor(self.features, dtype=torch.float32).to(device)
        with torch.no_grad():
            outputs = self.classifier(features_tensor)
        self.predicted_labels = outputs.argmax(dim=1).cpu().numpy()
        self._print_distribution("预测结果分布:", self.predicted_labels)

    def apply_temporal_constraints(self):
        """Make each fruit's predicted stage non-decreasing over weeks.

        Images are grouped by (tree id, fruit id) and visited in week order.
        Any prediction lower than the highest stage already seen for that
        fruit is raised to it, on the assumption that disease does not
        regress. Updates ``self.predicted_labels`` in place.
        """
        print("应用时序约束...")
        if len(self.predicted_labels) != len(self.weeks):
            print("缺少预测结果,无法应用时序约束")
            return
        fruit_records = {}
        for idx, (tree_id, fruit_id, week) in enumerate(
                zip(self.tree_ids, self.fruit_ids, self.weeks)):
            week_num = int(week.split('_')[1])
            fruit_records.setdefault((tree_id, fruit_id), []).append((week_num, idx))
        for records in fruit_records.values():
            records.sort(key=lambda rec: rec[0])  # stable: ties keep scan order
            running_max = -1
            for _, idx in records:
                label = self.predicted_labels[idx]
                if label < running_max:
                    self.predicted_labels[idx] = running_max
                else:
                    running_max = label
        print("时序约束应用完成")

    @staticmethod
    def _numeric_key(series):
        """Return the first run of digits in each string as a number (NaN if none)."""
        return pd.to_numeric(series.str.extract(r'(\d+)', expand=False),
                             errors='coerce')

    def generate_results(self):
        """Write week / fruit id / predicted stage rows to ``result1.xlsx``.

        Rows are sorted numerically by week, tree and fruit number.
        """
        print("生成结果文件...")
        if len(self.predicted_labels) == 0:
            print("缺少预测结果,无法生成结果文件")
            return
        df = pd.DataFrame({
            '周数': self.weeks,
            '果实编号': self.fruit_ids,
            '果实阶段': [self.class_names[int(label)] for label in self.predicted_labels],
        })
        # fruit_id is "<tree>_<fruit>". Split on the LAST underscore,
        # because tree folder names may contain "_" themselves.
        id_parts = df['果实编号'].str.rsplit('_', n=1)
        df['周数_num'] = self._numeric_key(df['周数'])
        df['树编号_num'] = self._numeric_key(id_parts.str[0])
        df['果实编号_num'] = self._numeric_key(id_parts.str[1])
        df = df.sort_values(['周数_num', '树编号_num', '果实编号_num'])
        df = df.drop(['周数_num', '树编号_num', '果实编号_num'], axis=1)
        df.to_excel('result1.xlsx', index=False)
        print(f"结果已保存至 result1.xlsx (共 {len(df)} 条记录)")
        class_stats = df['果实阶段'].value_counts()
        print("\n病害阶段分布:")
        for stage, count in class_stats.items():
            percentage = count / len(df) * 100
            print(f"{stage}: {count} 张图像 ({percentage:.1f}%)")
def visualize_results(self):
"""可视化分析结果"""
print("生成可视化图表...")
# 病害阶段随时间变化
plt.figure(figsize=(12, 6))
week_stats = {}
for i, week in enumerate(self.weeks):
if week not in week_stats:
week_stats[week] = {'健康期': 0, '初发期': 0, '发病期': 0}
week_stats[week][self.class_names[self.predicted_labels[i]]] += 1
# 按周数排序
weeks_sorted = sorted(week_stats.keys(), key=lambda x: int(x.split('_')[1]))
health_rates = []
early_rates = []
late_rates = []
for week in weeks_sorted:
total = sum(week_stats[week].values())
health_rates.append(week_stats[week]['健康期'] / total * 100)
early_rates.append(week_stats[week]['初发期'] / total * 100)
late_rates.append(week_stats[week]['发病期'] / total * 100)请继续这个生成完整代码
最新发布