import hashlib
import shutil
from tkinter import Image
import numpy as np
import requests
from bs4 import BeautifulSoup
import os
import csv
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision import transforms, io
import torch.utils.data as data
from torch.utils.tensorboard import SummaryWriter
import time
import copy
import torchvision.models
from pathlib import Path
import random
import logging
from typing import List, Tuple, Dict, Optional
from tqdm import tqdm
from flower import csv_path
# Logging configuration: INFO level, written both to a log file and to the console.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler("../flower_classifier.log"),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)

# Fix the random seeds so results are reproducible.
torch.manual_seed(42)
np.random.seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(42)

# Flower species to classify (extend as needed). Names are Chinese:
# rose, tulip, lily, carnation, sunflower, peony, chrysanthemum, lavender.
flower_species = ["玫瑰", "郁金香", "百合", "康乃馨", "向日葵", "牡丹","菊花","薰衣草"]

# Output directories, all created under the user's Desktop.
desktop_path = Path.home() / "Desktop"
log_dir = desktop_path / "flower_logs"                 # TensorBoard logs
image_folder = desktop_path / "raw_flowers"            # raw downloads
processed_folder = desktop_path / "processed_flowers"  # rejected files (invalid/duplicates)
class_folders = {species: desktop_path / f"{species}_flowers" for species in flower_species}

# Make sure every directory exists.
for dir_path in [log_dir, image_folder, processed_folder] + list(class_folders.values()):
    dir_path.mkdir(parents=True, exist_ok=True)

# Image-source URL per species (Bing search-result pages here).
# NOTE(review): these are HTML search pages, not APIs — scraping them is
# brittle and each species really needs a dedicated, stable image source.
species_urls = {
    "玫瑰": "https://cn.bing.com/search?q=%e7%8e%ab%e7%91%b0%e5%9b%be%e7%89%87&qs=LT&pq=%e7%8e%ab%e7%91%b0%e5%9b%be%e7%89%87&sc=12-4&cvid=C98F6F4569DF4AFBBA0461E14F01169D&FORM=QBRE&sp=1&lq=玫瑰",
    "郁金香": "https://cn.bing.com/search?q=%E9%83%81%E9%87%91%E9%A6%99%E5%9B%BE%E7%89%87&qs=n&form=QBRE&sp=-1&lq=0&pq=%E9%83%81%E9%87%91%E9%A6%99%E5%9B%BE%E7%89%87&sc=12-5&sk=&cvid=A617B61EB9C24713B03FB189188D3B6D=郁金香",
    "百合": "https://cn.bing.com/search?q=%E7%99%BE%E5%90%88%E5%9B%BE%E7%89%87&qs=n&form=QBRE&sp=-1&lq=0&pq=%E7%99%BE%E5%90%88%E5%9B%BE%E7%89%87&sc=12-4&sk=&cvid=65856B267CDA40EC96C4FF83D83E560E=百合",
    "康乃馨": "https://cn.bing.com/images/search?q=%e5%ba%b7%e4%b9%83%e9%a6%a8%e5%9b%be%e7%89%87&form=IACFSM&first=1=康乃馨",
    "向日葵": "https://cn.bing.com/images/search?q=%E5%90%91%E6%97%A5%E8%91%B5%E5%9B%BE%E7%89%87&qs=n&form=QBIR&sp=-1&lq=0&pq=%E5%90%91%E6%97%A5%E8%91%B5%E5%9B%BE%E7%89%87&sc=10-5&cvid=1C1D49325CD84F8788FB6B314F6E01A6&first=1=向日葵",
    "牡丹": "https://cn.bing.com/images/search?q=%E7%89%A1%E4%B8%B9%E5%9B%BE%E7%89%87&qs=n&form=QBIR&sp=-1&lq=0&pq=%E7%89%A1%E4%B8%B9%E5%9B%BE%E7%89%87&sc=10-4&cvid=AC8E61E440BB4133AD716C7D56657A18&first=1=牡丹",
    "菊花": "https://cn.bing.com/images/search?q=%E8%8F%8A%E8%8A%B1%E5%9B%BE%E7%89%87&qs=n&form=QBIR&sp=-1&lq=0&pq=%E8%8F%8A%E8%8A%B1%E5%9B%BE%E7%89%87&sc=10-4&cvid=7BFACEDD29434E70A021E052A2F4FD17&first=1=菊花",
    "薰衣草": "https://cn.bing.com/images/search?q=%E8%96%B0%E8%A1%A3%E8%8D%89%E5%9B%BE%E7%89%87&qs=n&form=QBIR&sp=-1&lq=0&pq=%E8%96%B0%E8%A1%A3%E8%8D%89%E5%9B%BE%E7%89%87&sc=10-5&cvid=B0C56AC233904D848003886AC009A43E&first=1=薰衣草"
}

# Maximum number of images to download per species.
images_per_species = 100
# 1. 图片下载模块 - 支持多品种图片下载
def download_species_images(species_urls: Dict[str, str], save_dir: Path,
                            max_images_per_species: int = 100) -> Dict[str, List[str]]:
    """Download images for several flower species.

    Args:
        species_urls: mapping of species name -> search-result page URL.
        save_dir: directory the raw images are written into.
        max_images_per_species: cap on downloads per species.

    Returns:
        Mapping of species name -> list of saved image file paths; a species
        whose link scrape fails maps to an empty list.
    """
    from urllib.parse import urlparse

    all_downloaded = {species: [] for species in species_urls}
    for species, url in species_urls.items():
        logger.info(f"开始下载 {species} 的图片...")
        # Browser-like headers, defined *before* the try so the download loop
        # below can always reference them even if the scrape partially fails.
        headers = {
            "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36",
            "Accept": "image/webp,image/apng,image/*,*/*;q=0.8",
            "Referer": url
        }
        image_links = []
        try:
            response = requests.get(url, headers=headers, timeout=15)
            response.raise_for_status()
            soup = BeautifulSoup(response.text, 'html.parser')
            # NOTE(review): 'main_img' is a Baidu-Images CSS class, but the
            # configured URLs point at Bing — this selector may match nothing;
            # verify against the actual page markup.
            image_tags = soup.find_all('img', class_='main_img')
            image_links = [tag.get('src') or tag.get('data-src') for tag in image_tags
                           if tag.get('src') or tag.get('data-src')]
            image_links = [link for link in image_links
                           if link.startswith(('http:', 'https:'))][:max_images_per_species]
        except Exception as e:
            logger.error(f"获取 {species} 图片链接失败: {str(e)}")
            continue

        downloaded = []
        for idx, link in enumerate(tqdm(image_links, desc=f"下载 {species} 图片")):
            try:
                # Handle possible relative links (defensive; the filter above
                # should already have removed them).
                if not link.startswith(('http:', 'https:')):
                    link = 'https:' + link if link.startswith('//') else url + link
                img_response = requests.get(link, headers=headers, stream=True, timeout=15)
                img_response.raise_for_status()
                # Take the extension from the URL *path* only, so query strings
                # (e.g. "...jpg?w=100") cannot leak into the filename.
                img_ext = os.path.splitext(urlparse(link).path)[1] or ".jpg"
                img_name = save_dir / f"{species}_{idx:04d}{img_ext}"
                with open(img_name, 'wb') as f:
                    for chunk in img_response.iter_content(chunk_size=1024 * 1024):
                        f.write(chunk)
                downloaded.append(str(img_name))
            except Exception as e:
                logger.warning(f"下载 {species} 图片 {link} 失败: {str(e)}")
                continue
        all_downloaded[species] = downloaded
        logger.info(f"成功下载 {species} 图片: {len(downloaded)} 张")
    return all_downloaded
# 2. 图片清洗模块 - 去重和分类保存
def clean_and_classify_images(downloaded_images: Dict[str, List[str]],
                              min_size: Tuple[int, int] = (200, 200),
                              similarity_threshold: int = 5) -> Dict[str, List[str]]:
    """Validate and deduplicate images, then sort keepers into species folders.

    Args:
        downloaded_images: species -> list of raw image file paths.
        min_size: minimum (width, height) an image must have to be kept.
        similarity_threshold: maximum Hamming distance between perceptual
            hashes for two images to be treated as duplicates.

    Returns:
        species -> list of kept image paths (now inside the species folder).
    """
    cleaned_images = {species: [] for species in downloaded_images}
    all_images = []
    # Flatten to one list so deduplication is global across all species.
    for species, images in downloaded_images.items():
        all_images.extend([(img, species) for img in images])

    # BUG FIX: the reject folders must exist before shutil.move, otherwise the
    # first move creates a *file* named "invalid"/"duplicates" and every
    # subsequent move fails.
    invalid_folder = processed_folder / "invalid"
    duplicates_folder = processed_folder / "duplicates"
    invalid_folder.mkdir(parents=True, exist_ok=True)
    duplicates_folder.mkdir(parents=True, exist_ok=True)

    phash_dict = {}
    total_duplicates = 0
    for img_path, species in tqdm(all_images, desc="清洗和分类图片"):
        if not is_valid_image(img_path, min_size):
            shutil.move(img_path, invalid_folder / os.path.basename(img_path))
            continue
        phash = calculate_phash(img_path)
        if not phash:
            shutil.move(img_path, invalid_folder / os.path.basename(img_path))
            continue
        # O(n^2) scan against every previously kept hash; acceptable for the
        # few hundred images this pipeline handles.
        is_duplicate = False
        for existing_phash, existing_path in phash_dict.items():
            hamming_distance = sum(a != b for a, b in zip(phash, existing_phash))
            if hamming_distance <= similarity_threshold:
                shutil.move(img_path, duplicates_folder / os.path.basename(img_path))
                total_duplicates += 1
                is_duplicate = True
                break
        if not is_duplicate:
            # Keep the image: move it into its species folder and record it.
            species_folder = class_folders[species]
            species_folder.mkdir(exist_ok=True)
            new_path = species_folder / os.path.basename(img_path)
            shutil.move(img_path, new_path)
            cleaned_images[species].append(str(new_path))
            phash_dict[phash] = img_path
    logger.info(f"清洗完成: 移除 {total_duplicates} 张重复图片和无效图片")
    return cleaned_images
def is_valid_image(image_path: str, min_size: Tuple[int, int]) -> bool:
    """Return True iff *image_path* is a readable JPEG/PNG of at least *min_size*.

    Any I/O or decoding failure yields False — the caller treats the file as
    invalid and moves it aside, so this never raises.
    """
    try:
        # BUG FIX: the module-level "from tkinter import Image" is the wrong
        # Image class (it has no .open); use Pillow locally instead.
        from PIL import Image
        # verify() must run on a freshly opened file and leaves the object
        # unusable afterwards, so verify first, then reopen to inspect.
        with Image.open(image_path) as img:
            img.verify()
        with Image.open(image_path) as img:
            width, height = img.size
            if width < min_size[0] or height < min_size[1]:
                return False
            # Pillow reports JPEG files as "JPEG" — "JPG" never occurs.
            if img.format not in ("JPEG", "PNG"):
                return False
        return True
    except Exception:
        return False
def calculate_phash(image_path: str, size: Tuple[int, int] = (32, 32)) -> Optional[str]:
    """Compute a 64-bit perceptual hash (pHash) of an image.

    Returns the hash as a 64-character '0'/'1' string so callers can compare
    images by Hamming distance; returns None on any failure.
    """
    try:
        # Pillow, not the tkinter Image imported at module level.
        from PIL import Image
        # BUG FIX: np.fft.irfft2(np.fft.rfft2(x)) is just the identity
        # round-trip, not a DCT — use a genuine 2-D DCT.
        from scipy.fft import dctn

        img = Image.open(image_path).convert('L').resize(size, Image.LANCZOS)
        img_array = np.asarray(img, dtype=np.float64)
        dct = dctn(img_array, norm='ortho')
        dct_small = dct[:8, :8]       # low-frequency 8x8 corner -> 64 bits
        mean = np.mean(dct_small)
        # BUG FIX: return the raw bit string. The old MD5 hexdigest destroyed
        # locality, making the caller's Hamming-distance test meaningless.
        return ''.join('1' if coeff > mean else '0' for coeff in dct_small.flatten())
    except Exception:
        return None
# 3. 生成多类别标注CSV
def generate_multiclass_annotations(cleaned_images: Dict[str, List[str]]) -> Optional[Path]:
    """Write one (image file name, species label) CSV row per cleaned image.

    Returns the CSV path on success; None when there is nothing to annotate
    or the file cannot be written.
    """
    if not any(cleaned_images.values()):
        logger.error("没有有效图片用于标注")
        return None
    rows = [
        [os.path.basename(path), species]
        for species, paths in cleaned_images.items()
        for path in paths
    ]
    csv_path = desktop_path / "flowers_multiclass_dataset.csv"
    try:
        with open(csv_path, 'w', newline='', encoding='utf-8') as csvfile:
            writer = csv.writer(csvfile)
            writer.writerow(["image_path", "label"])
            writer.writerows(rows)
        logger.info(f"多类别标注CSV已保存至: {csv_path}")
        return csv_path
    except Exception as e:
        logger.error(f"生成标注CSV失败: {str(e)}")
        return None
# 4. 多类别数据集类
class MulticlassFlowerDataset(Dataset):
    """Multi-class flower dataset backed by an annotation CSV.

    Each CSV row is (image file name, species label); the image itself lives
    in the per-species folder given by *class_folders*.
    """

    def __init__(self, csv_path: Path, class_folders: Dict[str, Path], transform=None):
        # data_info: list of (absolute image path, label string)
        self.data_info = []
        self.class_folders = class_folders
        self.transform = transform
        # PERF FIX: the label -> index mapping is fixed by the folder set, so
        # build it once here instead of rebuilding it on every __getitem__.
        self.label_to_idx = {label: i for i, label in enumerate(sorted(class_folders.keys()))}
        if not os.path.exists(csv_path):
            logger.error(f"CSV文件不存在: {csv_path}")
            return
        with open(csv_path, 'r', encoding='utf-8') as csvfile:
            reader = csv.reader(csvfile)
            next(reader)  # skip the header row
            for row in reader:
                if len(row) >= 2:
                    img_name, label = row[0], row[1]
                    if label in class_folders:
                        full_path = class_folders[label] / img_name
                        if os.path.exists(full_path):
                            self.data_info.append((str(full_path), label))
                        else:
                            logger.warning(f"图片不存在: {full_path}")
        logger.info(f"加载 {len(self.data_info)} 个多类别样本")

    def __len__(self):
        return len(self.data_info)

    def __getitem__(self, index):
        img_path, label = self.data_info[index]
        try:
            img = io.read_image(img_path)
            if img.shape[0] == 1:
                # Grayscale: replicate the single channel to 3 channels.
                img = img.repeat(3, 1, 1)
            img = img.float() / 255.0
            if self.transform:
                img = self.transform(img)
            return img, torch.tensor(self.label_to_idx[label])
        except Exception as e:
            logger.error(f"处理图片 {img_path} 失败: {e}")
            # Fallback keeps the DataLoader alive; class index 0 is arbitrary.
            return torch.zeros(3, 224, 224), torch.tensor(0)
# 5. 多类别分类模型
class FlowerClassifier(nn.Module):
    """Multi-class flower classifier: an ImageNet ResNet backbone whose final
    fully-connected layer is replaced to match the requested class count."""

    def __init__(self, num_classes: int, model_name: str = "resnet50", pretrained: bool = True):
        super(FlowerClassifier, self).__init__()
        if model_name == "resnet50":
            backbone_fn = torchvision.models.resnet50
            weights = torchvision.models.ResNet50_Weights.IMAGENET1K_V1 if pretrained else None
        elif model_name == "resnet18":
            backbone_fn = torchvision.models.resnet18
            weights = torchvision.models.ResNet18_Weights.IMAGENET1K_V1 if pretrained else None
        else:
            raise ValueError(f"不支持的模型: {model_name}")
        self.model = backbone_fn(weights=weights)
        # Swap the ImageNet head for one sized to our classes.
        in_features = self.model.fc.in_features
        self.model.fc = nn.Linear(in_features, num_classes)

    def forward(self, x):
        return self.model(x)
# 6. 训练多类别模型
def train_multiclass_model(model, dataloaders, criterion, optimizer, scheduler,
                           num_epochs=20, device='cpu'):
    """Train a multi-class classifier, keeping the best-validation-accuracy weights.

    Args:
        model: network to train (updated in place and returned).
        dataloaders: {'train': DataLoader, 'val': DataLoader}.
        criterion: loss function.
        optimizer: optimizer over the model parameters.
        scheduler: LR scheduler or None. ReduceLROnPlateau is stepped with
            the epoch's validation loss; other schedulers are stepped plainly.
        num_epochs: number of training epochs.
        device: device to move batches to.

    Returns:
        The model with the best validation-accuracy weights loaded.
    """
    since = time.time()
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    writer = SummaryWriter(log_dir=str(log_dir))
    for epoch in range(num_epochs):
        print(f'Epoch {epoch + 1}/{num_epochs}')
        print('-' * 10)
        val_loss = None
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()
            else:
                model.eval()
            running_loss = 0.0
            running_corrects = 0
            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)
                optimizer.zero_grad()
                # Gradients only during the training phase.
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)
            epoch_loss = running_loss / len(dataloaders[phase].dataset)
            epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)
            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')
            writer.add_scalar(f'Loss/{phase}', epoch_loss, epoch)
            writer.add_scalar(f'Accuracy/{phase}', epoch_acc, epoch)
            if phase == 'val':
                val_loss = epoch_loss
                if epoch_acc > best_acc:
                    best_acc = epoch_acc
                    best_model_wts = copy.deepcopy(model.state_dict())
        if scheduler:
            # BUG FIX: ReduceLROnPlateau.step() requires the monitored metric;
            # calling it with no argument raises TypeError. main() passes a
            # ReduceLROnPlateau, so the original crashed after epoch 1.
            if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                scheduler.step(val_loss)
            else:
                scheduler.step()
    writer.close()
    time_elapsed = time.time() - since
    print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
    print(f'Best val Acc: {best_acc:4f}')
    model.load_state_dict(best_model_wts)
    return model
# 主函数:多类别花卉分类全流程
def main():
    """End-to-end pipeline: download -> clean -> annotate -> train -> evaluate -> predict."""
    logger.info("开始多类别花卉分类模型训练流程...")
    # 1. Download images for every species.
    logger.info("步骤1: 下载各品种花卉图片...")
    downloaded_images = download_species_images(species_urls, image_folder, images_per_species)
    total_downloaded = sum(len(imgs) for imgs in downloaded_images.values())
    if total_downloaded == 0:
        logger.error("没有下载到任何图片,程序退出")
        return
    logger.info(f"总共下载 {total_downloaded} 张图片,开始清洗...")
    # 2. Clean the images and sort them into per-species folders.
    cleaned_images = clean_and_classify_images(downloaded_images)
    total_cleaned = sum(len(imgs) for imgs in cleaned_images.values())
    if total_cleaned == 0:
        logger.error("没有有效图片,程序退出")
        return
    logger.info(f"清洗后保留 {total_cleaned} 张有效图片,开始生成标注...")
    # 3. Generate the multi-class annotation CSV.
    csv_path = generate_multiclass_annotations(cleaned_images)
    if not csv_path:
        logger.error("标注生成失败,程序退出")
        return
    # 4. Preprocessing / augmentation pipelines.
    logger.info("步骤4: 数据预处理...")
    # BUG FIX: the dataset already yields float tensors in [0, 1]
    # (io.read_image + /255), so ToTensor() must NOT appear here — it only
    # accepts PIL images/ndarrays and would raise on every sample, silently
    # turning the whole dataset into the zero-tensor fallback.
    train_transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.RandomCrop((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    val_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    # 5. Build the dataset.
    logger.info("步骤5: 创建多类别数据集...")
    dataset = MulticlassFlowerDataset(csv_path, class_folders, transform=train_transform)
    import sys
    if len(dataset) < 20:
        # Too few samples to split meaningfully; exit with error code 1.
        logger.error(f"数据集样本不足,仅有 {len(dataset)} 个样本")
        sys.exit(1)
    print("开始训练模型...")
    # 6. Train/validation split (80/20).
    logger.info("步骤6: 划分训练集和验证集...")
    train_size = int(0.8 * len(dataset))
    val_size = len(dataset) - train_size
    train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
    # BUG FIX: both subsets from random_split share ONE dataset object, so
    # assigning val_transform to it would also strip augmentation from the
    # training subset. Give the validation split its own dataset instance.
    val_base = MulticlassFlowerDataset(csv_path, class_folders, transform=val_transform)
    val_dataset = data.Subset(val_base, val_dataset.indices)
    # 7. Data loaders.
    logger.info("步骤7: 创建数据加载器...")
    batch_size = 32 if torch.cuda.is_available() else 8  # smaller batches on CPU
    # NOTE(review): num_workers>0 without an `if __name__ == "__main__"` guard
    # breaks under the spawn start method (Windows/macOS) — confirm the entry point.
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)
    dataloaders = {
        'train': train_loader,
        'val': val_loader
    }
    # 8. Model.
    logger.info("步骤8: 初始化多类别分类模型...")
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    logger.info(f"使用设备: {device}")
    num_classes = len(flower_species)
    model = FlowerClassifier(num_classes=num_classes, model_name="resnet50", pretrained=True).to(device)
    # 9. Loss, optimizer and LR schedule.
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    # NOTE(review): `verbose` is deprecated in newer torch releases — confirm
    # against the pinned torch version.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3, verbose=True)
    # 10. Train.
    logger.info("步骤9: 开始训练多类别分类模型...")
    model = train_multiclass_model(
        model, dataloaders, criterion, optimizer, scheduler,
        num_epochs=25, device=device
    )
    # 11. Save the trained weights.
    logger.info("步骤10: 保存训练好的模型...")
    model_save_path = desktop_path / "multiclass_flower_classifier.pth"
    torch.save(model.state_dict(), str(model_save_path))
    logger.info(f"多类别花卉分类模型已保存至: {model_save_path}")
    # 12. Evaluate.
    logger.info("步骤11: 评估模型性能...")

    def evaluate_model(model, dataloader, device, class_names):
        """Evaluate on *dataloader*; returns loss, accuracy and a confusion matrix."""
        model.eval()
        criterion = nn.CrossEntropyLoss()
        running_loss = 0.0
        running_corrects = 0
        confusion_matrix = torch.zeros(len(class_names), len(class_names), dtype=torch.int64)
        with torch.no_grad():
            for inputs, labels in dataloader:
                inputs = inputs.to(device)
                labels = labels.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                running_loss += loss.item() * inputs.size(0)
                _, preds = torch.max(outputs, 1)
                running_corrects += torch.sum(preds == labels.data)
                # Accumulate the confusion matrix (rows: true, cols: predicted).
                for t, p in zip(labels.view(-1), preds.view(-1)):
                    confusion_matrix[t.long(), p.long()] += 1
        epoch_loss = running_loss / len(dataloader.dataset)
        epoch_acc = running_corrects.double() / len(dataloader.dataset)
        # Per-class report.
        print(f'\n测试 Loss: {epoch_loss:.4f}')
        print(f'测试 Accuracy: {epoch_acc:.4f}')
        print("\n分类详细结果:")
        for i, class_name in enumerate(class_names):
            correct = confusion_matrix[i, i].item()
            total = confusion_matrix[i, :].sum().item()
            if total > 0:
                accuracy = correct / total
                print(f"{class_name}: 正确 {correct}/{total} ({accuracy:.2%})")
        return {
            'loss': epoch_loss,
            'accuracy': epoch_acc,
            'confusion_matrix': confusion_matrix,
            'class_names': class_names
        }

    evaluation = evaluate_model(model, val_loader, device, flower_species)
    # 13. Confusion-matrix visualisation.
    logger.info("步骤12: 可视化混淆矩阵...")

    def plot_confusion_matrix(cm, class_names, title='混淆矩阵', cmap='Blues'):
        """Draw a heatmap confusion matrix and return the matplotlib figure."""
        import matplotlib.pyplot as plt
        import seaborn as sns
        fig, ax = plt.subplots(figsize=(10, 8))
        sns.heatmap(cm, annot=True, fmt='d', cmap=cmap, ax=ax)
        ax.set_title(title)
        ax.set_xlabel('预测标签')
        ax.set_ylabel('真实标签')
        # Rotate x labels so long class names stay readable.
        ax.set_xticklabels(class_names, rotation=45, ha='right')
        ax.set_yticklabels(class_names)
        plt.tight_layout()
        return fig

    fig = plot_confusion_matrix(
        evaluation['confusion_matrix'].numpy(),
        evaluation['class_names']
    )
    confusion_matrix_path = desktop_path / "confusion_matrix.png"
    fig.savefig(str(confusion_matrix_path))
    logger.info(f"混淆矩阵已保存至: {confusion_matrix_path}")
    # 14. Prediction helper.
    logger.info("步骤13: 实现模型预测功能...")

    def predict_flower(model, image_path, class_names, device='cpu'):
        """Predict the species of one image file; returns a result dict or None."""
        # BUG FIX: Pillow's Image, not the tkinter Image imported at module level.
        from PIL import Image
        model = model.to(device)
        model.eval()
        # Here the input IS a PIL image, so ToTensor() is correct.
        transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])
        try:
            img = Image.open(image_path).convert("RGB")
            img_tensor = transform(img).unsqueeze(0).to(device)
            with torch.no_grad():
                outputs = model(img_tensor)
                probs = torch.nn.functional.softmax(outputs, dim=1)
                conf, pred = torch.max(probs, 1)
            return {
                "predicted_class": class_names[pred.item()],
                "confidence": conf.item(),
                "probabilities": {class_names[i]: probs[0, i].item() for i in range(len(class_names))}
            }
        except Exception as e:
            logger.error(f"预测失败: {str(e)}")
            return None

    # 15. Smoke-test the prediction helper on one random validation sample.
    logger.info("步骤14: 测试模型预测功能...")
    if val_dataset and len(val_dataset) > 0:
        index = random.randint(0, len(val_dataset) - 1)
        img_path, true_label = val_dataset.dataset.data_info[index]
        logger.info(f"测试预测: 使用图片 {img_path}")
        prediction = predict_flower(model, img_path, flower_species, device)
        if prediction:
            logger.info(f"真实标签: {true_label}")
            logger.info(f"预测结果: {prediction['predicted_class']}, 置信度: {prediction['confidence']:.4f}")
            logger.info("各类别概率:")
            for class_name, prob in prediction['probabilities'].items():
                logger.info(f" {class_name}: {prob:.4f}")
    else:
        logger.warning("验证集为空,无法测试预测功能")
    logger.info("多类别花卉分类模型训练和评估流程完成!")