# 帮我修正一下这个代码,使它能精确识别出车牌
import matplotlib.pyplot as plt
import os
import cv2
import numpy as np
import pandas as pd
from typing import List, Tuple, Dict
from collections import Counter
import warnings
# Suppress every warning category for the whole process.
warnings.filterwarnings(action='ignore')
# Select a font with Chinese glyphs for the figure titles used below, and draw
# minus signs with the ASCII hyphen instead of the Unicode minus character.
plt.rcParams.update({
    "font.family": ["SimHei"],
    "axes.unicode_minus": False,
})
class LicensePlateRecognizer:
    """Licence-plate recognizer based on grid features and KNN voting.

    Pipeline: crop the plate from its annotated quadrilateral with a
    perspective warp, binarise it, split it into characters with a vertical
    projection, describe every character with a 4x4 ink-density grid plus its
    aspect ratio, and classify it against the training samples.

    Attributes:
        data_dir: Directory that contains the plate images.
        csv_path: Annotation CSV with the columns ``path``, ``x1``..``y4``
            and ``label``.
        char_templates: Maps a character to a dict holding ``'features'``
            (median feature vector), ``'samples'`` (sample count) and
            ``'sample_features'`` (all training vectors, one per row).
        knn_k: Default number of neighbours used by :meth:`knn_predict`.
    """

    # Annotation columns holding the four plate corners as (x, y) pairs.
    BBOX_COLUMNS = ('x1', 'y1', 'x2', 'y2', 'x3', 'y3', 'x4', 'y4')

    def __init__(self, data_dir: str, csv_path: str):
        self.data_dir = data_dir
        self.csv_path = csv_path
        self.char_templates = {}  # character -> template dict (see class docstring)
        self.knn_k = 5

    # ------------------------------------------------------------------
    # Data loading
    # ------------------------------------------------------------------
    def _read_annotations(self) -> pd.DataFrame:
        """Read the annotation CSV, trying strict UTF-8 before GBK.

        UTF-8 (with or without BOM) is tried first because it rejects GBK
        bytes, whereas GBK can decode UTF-8 bytes without raising and would
        silently corrupt the Chinese labels.
        """
        try:
            return pd.read_csv(self.csv_path, encoding='utf-8-sig')
        except UnicodeDecodeError:
            return pd.read_csv(self.csv_path, encoding='gbk')

    def _resolve_image_path(self, path_str: str):
        """Return the existing image file for a CSV ``path`` value, or None.

        Only the file name of ``path_str`` is used (both ``/`` and ``\\`` are
        treated as separators); it is looked up inside ``data_dir``, also
        trying the ``.jpg`` and ``.png`` extensions.
        """
        filename = os.path.basename(path_str.replace('\\', '/'))
        if not filename:
            return None
        stem = os.path.splitext(filename)[0]

        candidates = []
        if not filename.isdigit():
            # A purely numeric value is an image number, never a file name.
            candidates.append(filename)
        candidates += [f"{filename}.jpg", f"{filename}.png"]
        if stem and stem != filename:
            candidates += [f"{stem}.jpg", f"{stem}.png"]

        for name in candidates:
            full_path = os.path.join(self.data_dir, name)
            if os.path.isfile(full_path):
                return full_path
        return None

    def load_data(self) -> Dict:
        """Load the annotation CSV and keep the rows whose image exists.

        Returns:
            Dict keyed by the CSV row index; every value is a dict with
            ``'img_path'`` (str), ``'bbox'`` (eight floats: x1, y1, ..., x4,
            y4) and ``'label'`` (str, empty when the CSV cell is empty).

        Raises:
            FileNotFoundError: If ``csv_path`` does not exist.
        """
        print("正在加载数据集...")
        if not os.path.exists(self.csv_path):
            raise FileNotFoundError(f"CSV文件不存在: {self.csv_path}")

        df = self._read_annotations()
        print(f"CSV文件列名: {list(df.columns)}")
        print(f"加载了 {len(df)} 条数据记录")

        data_dict = {}
        for idx, row in df.iterrows():
            try:
                path_str = str(row['path']).strip()
                img_path = self._resolve_image_path(path_str)
                if img_path is None:
                    missing = os.path.join(
                        self.data_dir,
                        os.path.basename(path_str.replace('\\', '/')),
                    )
                    print(f"警告: 图片不存在: {missing}")
                    continue

                # A row with a missing or unparsable corner is skipped instead
                # of being replaced by a fake (0, 0) corner.
                bbox = [float(row[coord]) for coord in self.BBOX_COLUMNS]
                if not np.all(np.isfinite(bbox)):
                    raise ValueError("边界框坐标缺失")

                raw_label = row.get('label', '')
                label = '' if pd.isna(raw_label) else str(raw_label).strip()

                data_dict[idx] = {
                    'img_path': img_path,
                    'bbox': bbox,
                    'label': label,
                }
            except (KeyError, ValueError, TypeError) as e:
                print(f"处理第{idx}行数据时出错: {e}")
                continue

        print(f"成功加载 {len(data_dict)} 张有效图片")
        print("\n前3个样本信息:")
        for key, val in list(data_dict.items())[:3]:
            print(f" 样本{key}: 图片={os.path.basename(val['img_path'])}, 车牌={val['label']}, 坐标={val['bbox'][:4]}")
        return data_dict

    # ------------------------------------------------------------------
    # Plate extraction and preprocessing
    # ------------------------------------------------------------------
    @staticmethod
    def _order_corners(points: np.ndarray) -> np.ndarray:
        """Return four corners as top-left, top-right, bottom-right, bottom-left.

        The corners are first sorted by angle around their centroid so that
        consecutive entries are adjacent. The longer pair of opposite edges is
        taken as the top/bottom of the plate, and the one with the smaller
        mean y as the top edge. This assumes the plate is wider than tall and
        tilted by less than 90 degrees.
        """
        pts = np.asarray(points, dtype=np.float32).reshape(4, 2)
        centre = pts.mean(axis=0)
        angles = np.arctan2(pts[:, 1] - centre[1], pts[:, 0] - centre[0])
        # With the image y axis pointing down, ascending angles walk the
        # corners clockwise on screen, so the top edge runs left to right.
        ring = pts[np.argsort(angles, kind='stable')]
        nxt = np.roll(ring, -1, axis=0)
        lengths = np.linalg.norm(nxt - ring, axis=1)

        first = 0 if lengths[0] + lengths[2] >= lengths[1] + lengths[3] else 1
        top = min((first, first + 2), key=lambda i: ring[i][1] + nxt[i][1])
        return np.roll(ring, -top, axis=0)

    def extract_license_plate(self, img: np.ndarray, bbox: List) -> np.ndarray:
        """Crop and rectify the plate with a perspective transform.

        Args:
            img: Source image as loaded by OpenCV.
            bbox: Eight numbers ``x1, y1, ..., x4, y4`` for the plate corners,
                in any order.

        Returns:
            The rectified plate image, or ``None`` when the corners describe
            an empty area or the transform fails.
        """
        try:
            h, w = img.shape[:2]
            corners = np.array(
                [
                    [max(0, min(w - 1, bbox[i])), max(0, min(h - 1, bbox[i + 1]))]
                    for i in range(0, 8, 2)
                ],
                dtype=np.float32,
            )
            points = self._order_corners(corners)

            # Shoelace formula: a (near) zero area means the annotation is
            # degenerate and no meaningful warp exists.
            xs, ys = points[:, 0], points[:, 1]
            area = 0.5 * abs(
                float(np.dot(xs, np.roll(ys, -1)) - np.dot(ys, np.roll(xs, -1)))
            )
            if area < 1.0:
                print("提取车牌区域出错: 边界框面积为0")
                return None

            width = int(max(
                np.linalg.norm(points[0] - points[1]),
                np.linalg.norm(points[2] - points[3]),
            ))
            height = int(max(
                np.linalg.norm(points[1] - points[2]),
                np.linalg.norm(points[3] - points[0]),
            ))
            # Never ask OpenCV for a zero-sized output image.
            width = max(width, 10)
            height = max(height, 10)

            dst_points = np.array([
                [0, 0],
                [width - 1, 0],
                [width - 1, height - 1],
                [0, height - 1],
            ], dtype=np.float32)

            M = cv2.getPerspectiveTransform(points, dst_points)
            return cv2.warpPerspective(img, M, (width, height))
        except Exception as e:
            print(f"提取车牌区域出错: {e}")
            return None

    def preprocess_plate(self, plate_img: np.ndarray) -> np.ndarray:
        """Binarise a plate so that characters are white on black.

        Returns:
            The binary image, or ``None`` for an empty input or on failure.
        """
        if plate_img is None or plate_img.size == 0:
            return None
        try:
            if len(plate_img.shape) == 3:
                gray = cv2.cvtColor(plate_img, cv2.COLOR_BGR2GRAY)
            else:
                gray = plate_img

            # Characters cover less area than the background. When the bright
            # Otsu class is the minority the plate has light characters on a
            # dark background, so invert it to let the inverse threshold below
            # always turn the characters white.
            _, otsu = cv2.threshold(
                gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU
            )
            if np.count_nonzero(otsu) < otsu.size / 2:
                gray = cv2.bitwise_not(gray)

            binary = cv2.adaptiveThreshold(
                gray, 255,
                cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
                cv2.THRESH_BINARY_INV, 15, 5,
            )

            # Close small gaps inside strokes, then remove isolated specks.
            kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3))
            binary = cv2.morphologyEx(binary, cv2.MORPH_CLOSE, kernel)
            binary = cv2.morphologyEx(binary, cv2.MORPH_OPEN, kernel)
            return binary
        except Exception as e:
            print(f"预处理出错: {e}")
            return None

    # ------------------------------------------------------------------
    # Segmentation and features
    # ------------------------------------------------------------------
    def segment_characters(self, plate_binary: np.ndarray) -> List[np.ndarray]:
        """Split a binary plate into character images, left to right.

        A character is a run of columns whose ink count exceeds 5% of the
        plate height; runs of 3 columns or fewer are discarded. Each crop is
        trimmed to the rows that contain ink.

        Returns:
            List of 2-D character crops; empty on ``None`` input or failure.
        """
        if plate_binary is None:
            return []
        try:
            height, width = plate_binary.shape
            column_ink = np.sum(plate_binary > 0, axis=0)
            column_threshold = height * 0.05
            # NOTE: the row threshold scales with the full plate width.
            row_threshold = width * 0.01
            min_char_width = 3

            segments = []
            run_start = None
            # Visit one virtual blank column past the end so that a character
            # touching the right border is closed as well.
            for col in range(width + 1):
                has_ink = col < width and column_ink[col] > column_threshold
                if has_ink:
                    if run_start is None:
                        run_start = col
                    continue
                if run_start is None:
                    continue

                start, run_start = run_start, None
                if col - start <= min_char_width:
                    continue
                char_img = plate_binary[:, start:col]
                row_ink = np.sum(char_img > 0, axis=1)
                rows = np.where(row_ink > row_threshold)[0]
                if len(rows) > 0:
                    char_img = char_img[rows[0]:rows[-1] + 1, :]
                segments.append(char_img)
            return segments
        except Exception as e:
            print(f"字符分割出错: {e}")
            return []

    def extract_features(self, char_img: np.ndarray) -> np.ndarray:
        """Describe a character crop with 17 numbers.

        The first 16 values are the ink ratios of a 4x4 grid laid over the
        crop after resizing it to 20x40 (width x height). The last value is
        the width/height ratio of the crop before resizing.

        Returns:
            Feature vector of length 17; all zeros when extraction fails.
        """
        try:
            # Measure the shape before resizing; afterwards every crop has
            # the same aspect ratio and the feature would be constant.
            src_h, src_w = char_img.shape[:2]
            aspect_ratio = src_w / src_h if src_h > 0 else 1.0

            char_resized = cv2.resize(char_img, (20, 40))
            grid_h, grid_w = 4, 4
            cell_h = char_resized.shape[0] // grid_h
            cell_w = char_resized.shape[1] // grid_w
            ink = char_resized > 0

            features = [
                np.count_nonzero(
                    ink[r * cell_h:(r + 1) * cell_h, c * cell_w:(c + 1) * cell_w]
                ) / (cell_h * cell_w)
                for r in range(grid_h)
                for c in range(grid_w)
            ]
            features.append(aspect_ratio)
            return np.array(features)
        except Exception as e:
            print(f"特征提取出错: {e}")
            return np.zeros(17)  # 4x4 grid + 1 shape feature

    # ------------------------------------------------------------------
    # Training and prediction
    # ------------------------------------------------------------------
    def build_template_library(self, data_dict: Dict, debug: bool = False):
        """Build ``char_templates`` from labelled plates.

        Only plates whose number of segmented characters equals the label
        length are used; pairing segments and label characters by position on
        any other plate would store samples under the wrong character. Any
        previously built templates are discarded.

        Args:
            data_dict: Samples as returned by :meth:`load_data`.
            debug: Print per-sample diagnostics when True.
        """
        print("正在构建字符模板库...")
        self.char_templates = {}
        char_samples = {}
        char_count = 0
        mismatched = 0

        for idx, data in data_dict.items():
            try:
                img = cv2.imread(data['img_path'])
                if img is None:
                    if debug:
                        print(f" 无法读取图片: {data['img_path']}")
                    continue

                plate = self.extract_license_plate(img, data['bbox'])
                if plate is None:
                    if debug:
                        print(f" 无法提取车牌区域")
                    continue

                plate_binary = self.preprocess_plate(plate)
                if plate_binary is None:
                    if debug:
                        print(f" 预处理失败")
                    continue

                char_segments = self.segment_characters(plate_binary)
                if debug and idx < 3:
                    print(f"样本{idx}: 车牌={data['label']}, 分割字符数={len(char_segments)}")

                label = data['label']
                if not char_segments or not label:
                    continue
                if len(char_segments) != len(label):
                    mismatched += 1
                    continue

                for char_img, char_label in zip(char_segments, label):
                    features = self.extract_features(char_img)
                    char_samples.setdefault(char_label, []).append(features)
                    char_count += 1
            except Exception as e:
                if debug:
                    print(f"处理样本{idx}时出错: {e}")
                continue

        for char, feat_list in char_samples.items():
            sample_matrix = np.vstack(feat_list)
            self.char_templates[char] = {
                # The median is less sensitive to outliers than the mean.
                'features': np.median(sample_matrix, axis=0),
                'samples': len(feat_list),
                'sample_features': sample_matrix,
            }

        print(f"模板库构建完成,共 {len(self.char_templates)} 个字符类别")
        print(f"总字符样本数: {char_count}")
        if debug:
            print(f"分割数与标签长度不一致而跳过的样本数: {mismatched}")
        if len(self.char_templates) > 0:
            print("字符类别:", sorted(self.char_templates.keys()))
        else:
            print("警告: 模板库为空!")

    def knn_predict(self, features: np.ndarray, k: int = None) -> str:
        """Classify one feature vector with distance-weighted KNN.

        Neighbours are drawn from every stored training vector
        (``'sample_features'``); a template without that entry contributes its
        single median vector instead.

        Args:
            features: Vector produced by :meth:`extract_features`.
            k: Number of neighbours; defaults to ``self.knn_k``.

        Returns:
            The winning character, or ``"?"`` when no template is available.
        """
        if k is None:
            k = self.knn_k
        if not self.char_templates:
            return "?"

        labels = []
        blocks = []
        for char, template in self.char_templates.items():
            samples = template.get('sample_features')
            if samples is None or len(samples) == 0:
                samples = template['features']
            samples = np.atleast_2d(samples)
            blocks.append(samples)
            labels.extend([char] * len(samples))

        distances = np.linalg.norm(np.vstack(blocks) - features, axis=1)
        nearest = np.argsort(distances, kind='stable')[:k]

        votes = {}
        for i in nearest:
            # Closer neighbours weigh more; the epsilon avoids division by 0.
            weight = 1.0 / (float(distances[i]) + 1e-10)
            votes[labels[i]] = votes.get(labels[i], 0.0) + weight

        if not votes:
            return "?"
        return max(votes.items(), key=lambda item: item[1])[0]

    def recognize_plate(self, img_path: str, bbox: List) -> str:
        """Run the full pipeline on one image.

        Returns:
            The recognised plate text; an empty string when any stage fails.
        """
        try:
            img = cv2.imread(img_path)
            if img is None:
                return ""
            plate = self.extract_license_plate(img, bbox)
            if plate is None:
                return ""
            plate_binary = self.preprocess_plate(plate)
            if plate_binary is None:
                return ""
            char_segments = self.segment_characters(plate_binary)
            return "".join(
                self.knn_predict(self.extract_features(char_img))
                for char_img in char_segments
            )
        except Exception as e:
            print(f"识别时出错: {e}")
            return ""

    def train_test_split(self, data_dict: Dict, test_ratio: float = 0.2, seed: int = 42):
        """Shuffle the sample keys and split them into train and test lists.

        A private ``RandomState`` is used, so the split is reproducible for a
        given ``seed`` and NumPy's global random state is left untouched.

        Returns:
            Tuple ``(train_indices, test_indices)``.
        """
        indices = list(data_dict.keys())
        np.random.RandomState(seed).shuffle(indices)
        split_idx = int(len(indices) * (1 - test_ratio))
        return indices[:split_idx], indices[split_idx:]

    def evaluate(self, data_dict: Dict, test_indices: List, debug: bool = False) -> Dict:
        """Measure plate-level and character-level accuracy.

        Character accuracy compares prediction and label position by position
        and divides by the longer of the two lengths, so missing and extra
        characters both count as errors.

        Returns:
            Dict with ``total``, ``correct``, ``per_char_total``,
            ``per_char_correct``, ``plate_accuracy``, ``char_accuracy`` and
            ``details`` (one dict per evaluated sample).
        """
        print("正在评估模型性能...")
        results = {
            'total': 0,
            'correct': 0,
            'per_char_total': 0,
            'per_char_correct': 0,
            'details': [],
        }

        for i, idx in enumerate(test_indices):
            if idx not in data_dict:
                continue
            data = data_dict[idx]
            true_label = data['label']
            predicted_label = self.recognize_plate(data['img_path'], data['bbox'])
            is_correct = predicted_label == true_label

            results['total'] += 1
            if is_correct:
                results['correct'] += 1

            results['per_char_total'] += max(len(predicted_label), len(true_label))
            results['per_char_correct'] += sum(
                p == t for p, t in zip(predicted_label, true_label)
            )

            results['details'].append({
                'idx': idx,
                'true': true_label,
                'predicted': predicted_label,
                'correct': is_correct,
            })
            if debug and i < 5:
                print(f"样本{idx}: 真实={true_label}, 预测={predicted_label}, 正确={is_correct}")

        if results['total'] > 0:
            results['plate_accuracy'] = results['correct'] / results['total']
        else:
            results['plate_accuracy'] = 0
        if results['per_char_total'] > 0:
            results['char_accuracy'] = results['per_char_correct'] / results['per_char_total']
        else:
            results['char_accuracy'] = 0
        return results

    # ------------------------------------------------------------------
    # Visualisation
    # ------------------------------------------------------------------
    def visualize_sample(self, data_dict: Dict, sample_idx: int, save_path: str = None):
        """Plot every pipeline stage for one sample.

        Args:
            data_dict: Samples as returned by :meth:`load_data`.
            sample_idx: Key of the sample to display.
            save_path: When given, the figure is also written to this path.
        """
        if sample_idx not in data_dict:
            print(f"样本 {sample_idx} 不存在")
            return
        data = data_dict[sample_idx]
        img = cv2.imread(data['img_path'])
        if img is None:
            print(f"无法读取图片: {data['img_path']}")
            return

        fig, axes = plt.subplots(2, 3, figsize=(15, 8))

        # Original image
        axes[0, 0].imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
        axes[0, 0].set_title("原始图像")
        axes[0, 0].axis('off')

        # Rectified plate
        plate = self.extract_license_plate(img, data['bbox'])
        if plate is not None:
            axes[0, 1].imshow(cv2.cvtColor(plate, cv2.COLOR_BGR2RGB))
            axes[0, 1].set_title("提取的车牌")
        else:
            axes[0, 1].text(0.5, 0.5, "提取失败", ha='center', va='center')
        axes[0, 1].axis('off')

        # Binarised plate
        plate_binary = self.preprocess_plate(plate)
        if plate_binary is not None:
            axes[0, 2].imshow(plate_binary, cmap='gray')
            axes[0, 2].set_title("二值化处理")
        else:
            axes[0, 2].text(0.5, 0.5, "二值化失败", ha='center', va='center')
        axes[0, 2].axis('off')

        # Character segmentation
        char_segments = []
        if plate_binary is not None:
            char_segments = self.segment_characters(plate_binary)
            axes[1, 0].imshow(plate_binary, cmap='gray')
            axes[1, 0].set_title(f"字符分割 ({len(char_segments)}个字符)")
        else:
            axes[1, 0].text(0.5, 0.5, "无二值化图像", ha='center', va='center')
        axes[1, 0].axis('off')

        # Recognition result
        result = self.recognize_plate(data['img_path'], data['bbox'])
        axes[1, 1].text(0.1, 0.5,
                        f"真实车牌: {data['label']}\n识别结果: {result}\n是否正确: {result == data['label']}",
                        fontsize=12, transform=axes[1, 1].transAxes)
        axes[1, 1].set_title("识别结果")
        axes[1, 1].axis('off')

        # Feature vector of the first character
        if char_segments:
            features = self.extract_features(char_segments[0])
            axes[1, 2].bar(range(len(features)), features)
            axes[1, 2].set_title("字符特征示例")
            axes[1, 2].set_xlabel("特征维度")
            axes[1, 2].set_ylabel("特征值")
        else:
            axes[1, 2].text(0.5, 0.5, "无字符特征", ha='center', va='center')

        plt.suptitle(f"车牌识别可视化 - 样本 {sample_idx}")
        plt.tight_layout()
        if save_path:
            plt.savefig(save_path, dpi=150, bbox_inches='tight')
            print(f"可视化结果已保存到: {save_path}")
        plt.show()
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)
def main(data_dir: str = r"D:\CLPD\CLPD_1200", csv_path: str = r"D:\CLPD\CLPD.csv"):
    """Run the full demo: load, split, train, evaluate, visualise, interact.

    Args:
        data_dir: Directory containing the plate images.
        csv_path: Annotation CSV consumed by ``LicensePlateRecognizer``.
    """
    recognizer = LicensePlateRecognizer(data_dir, csv_path)
    print("=" * 50)
    print("车牌识别系统 - 手动实现版")
    print("=" * 50)

    # 1. Load the data
    print("\n[步骤1] 加载数据...")
    data_dict = recognizer.load_data()
    if not data_dict:
        print("错误: 没有加载到数据!")
        return

    # 2. Split into train and test sets
    print("\n[步骤2] 划分数据集...")
    train_indices, test_indices = recognizer.train_test_split(data_dict, test_ratio=0.2)
    print(f"训练集: {len(train_indices)} 个样本")
    print(f"测试集: {len(test_indices)} 个样本")

    # 3. Build the template library from the training samples only
    print("\n[步骤3] 构建模板库...")
    train_data = {idx: data_dict[idx] for idx in train_indices}
    recognizer.build_template_library(train_data, debug=True)

    # 4. Evaluate on the held-out samples
    print("\n[步骤4] 评估模型...")
    results = recognizer.evaluate(data_dict, test_indices, debug=True)
    print("\n" + "=" * 50)
    print("评估结果:")
    print(f"车牌整体准确率: {results['plate_accuracy']:.2%} ({results['correct']}/{results['total']})")
    print(f"字符级准确率: {results['char_accuracy']:.2%} ({results['per_char_correct']}/{results['per_char_total']})")
    print("=" * 50)

    # 5. Show the first misrecognised samples
    errors = [d for d in results['details'] if not d['correct']]
    if errors:
        print(f"\n错误样本 (前10个):")
        for err in errors[:10]:
            print(f" 样本{err['idx']}: 真实={err['true']}, 预测={err['predicted']}")

    # 6. Visualise one test sample
    print("\n[步骤5] 可视化示例...")
    if test_indices:
        recognizer.visualize_sample(data_dict, test_indices[0])

    # 7. Interactive testing; a closed stdin or Ctrl-C ends the loop cleanly
    print("\n[步骤6] 交互式测试")
    while True:
        try:
            choice = input("\n输入样本序号测试 (0-1199) 或 'q' 退出: ").strip()
        except (EOFError, KeyboardInterrupt):
            break
        if choice.lower() == 'q':
            break

        # Only the integer conversion is guarded, so a ValueError raised by
        # the recognizer is not misreported as invalid user input.
        try:
            idx = int(choice)
        except ValueError:
            print("请输入有效数字!")
            continue
        if idx not in data_dict:
            print("序号无效!")
            continue

        data = data_dict[idx]
        result = recognizer.recognize_plate(data['img_path'], data['bbox'])
        print(f"\n样本 {idx}:")
        print(f" 图片路径: {os.path.basename(data['img_path'])}")
        print(f" 真实车牌: {data['label']}")
        print(f" 识别结果: {result}")
        print(f" 是否正确: {result == data['label']}")
        try:
            viz = input("\n 是否可视化处理过程? (y/n): ").strip().lower()
        except (EOFError, KeyboardInterrupt):
            break
        if viz == 'y':
            recognizer.visualize_sample(data_dict, idx)

    print("\n程序结束!")
if __name__ == "__main__":
    # Seed NumPy's global generator before running the demo so that any code
    # relying on the global random state behaves the same on every run.
    np.random.seed(42)
    main()
# 最新发布