python3.8 - 使用 conda 创建虚拟环境
pip install paddleocr==2.7.0.3
pip install paddleocr opencv-python numpy matplotlib pandas pillow openpyxl
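pip install -r requirements.txt
requirements.txt 内容如下(代码使用 PaddleOCR 2.x 接口 use_gpu / show_log / cls=True,3.x 已移除这些参数,故限制 <3.0):
paddleocr>=2.6.0,<3.0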
paddlepaddle>=2.4.0
opencv-python>=4.5.0
numpy>=1.21.0
matplotlib>=3.5.0
pandas>=1.3.0
Pillow>=8.3.0
pip install fastapi uvicorn python-multipart
pip install pillow pandas openpyxl
=======================
代码如下:
import cv2
import numpy as np
import pandas as pd
from paddleocr import PaddleOCR
import os
from collections import defaultdict
class SmartTableOCR:
    """Recognize a table in an image with PaddleOCR (2.x API) and export it.

    Pipeline: run OCR -> convert each detected line into a text-box dict ->
    group boxes into rows by vertical center -> cluster horizontal centers
    into column positions -> place every box into the nearest column of its row.
    """

    # Local PP-OCR inference model directories used when the caller does not
    # pass its own. Directories that do not exist are replaced by None.
    DEFAULT_DET_MODEL_DIR = r"D:\aAi\ocr38\ch_PP-OCRv4_det_infer"
    DEFAULT_REC_MODEL_DIR = r"D:\aAi\ocr38\ch_PP-OCRv4_rec_infer"
    DEFAULT_CLS_MODEL_DIR = r"D:\aAi\ocr38\ch_ppocr_mobile_v2.0_cls_infer"

    def __init__(self, use_gpu=False, det_model_dir=DEFAULT_DET_MODEL_DIR,
                 rec_model_dir=DEFAULT_REC_MODEL_DIR,
                 cls_model_dir=DEFAULT_CLS_MODEL_DIR):
        """Create the PaddleOCR engine.

        Args:
            use_gpu: Forwarded to ``PaddleOCR``.
            det_model_dir: Local text-detection model directory.
            rec_model_dir: Local text-recognition model directory.
            cls_model_dir: Local angle-classification model directory.
                For all three, a missing path (or None) is passed to PaddleOCR
                as None so it picks its own default (downloaded) model.
        """
        # Validate all three directories the same way so a missing rec/cls
        # directory falls back to the default model just like detection does.
        det_model_dir = self._resolve_model_dir(det_model_dir, "检测")
        rec_model_dir = self._resolve_model_dir(rec_model_dir, "识别")
        cls_model_dir = self._resolve_model_dir(cls_model_dir, "分类")
        self.ocr = PaddleOCR(
            use_angle_cls=True,
            lang='ch',
            use_gpu=use_gpu,
            show_log=False,
            det_model_dir=det_model_dir,
            rec_model_dir=rec_model_dir,
            cls_model_dir=cls_model_dir,
        )
        print("PaddleOCR 初始化完成")

    @staticmethod
    def _resolve_model_dir(model_dir, label):
        """Return *model_dir* if it exists on disk, otherwise None (with a warning)."""
        if model_dir and not os.path.exists(model_dir):
            print(f"警告: {label}模型路径不存在: {model_dir}")
            print("将使用在线下载的模型")
            return None
        return model_dir

    def _run_ocr(self, image_path):
        """Run OCR on *image_path* and return the first page's result lines.

        Returns None (after printing a message) when the file does not exist,
        when the OCR call raises, or when nothing was recognized.
        """
        if not os.path.exists(image_path):
            print(f"错误: 图片不存在 - {image_path}")
            return None
        try:
            result = self.ocr.ocr(image_path, cls=True)
        except Exception as e:  # boundary: PaddleOCR may raise many error types
            print(f"OCR识别失败: {e}")
            return None
        if not result or not result[0]:
            print("未识别到文字")
            return None
        return result[0]

    def smart_table_recognition(self, image_path):
        """Recognize the table in *image_path*.

        Returns:
            A list of rows, each a list of cell strings (one per detected
            column), or None if OCR failed or found no text.
        """
        print("正在进行智能表格识别...")
        ocr_lines = self._run_ocr(image_path)
        if ocr_lines is None:
            return None
        text_boxes = self._extract_text_boxes(ocr_lines)
        if not text_boxes:
            print("未提取到文字框信息")
            return None
        rows, cols = self._detect_table_structure(text_boxes)
        return self._build_table_data(text_boxes, rows, cols)

    def _extract_text_boxes(self, ocr_result):
        """Convert OCR result lines into text-box dicts.

        Each line is read as ``line[0]`` = polygon points ``[[x, y], ...]`` and
        ``line[1]`` = ``(text, confidence)``. The returned dicts hold the text,
        confidence, axis-aligned bbox, polygon center, width and height.
        """
        text_boxes = []
        for line in ocr_result:
            points = line[0]
            if not points:
                continue  # nothing to measure; min()/max() would fail
            text, confidence = line[1][0], line[1][1]
            xs = [p[0] for p in points]
            ys = [p[1] for p in points]
            x_min, x_max = min(xs), max(xs)
            y_min, y_max = min(ys), max(ys)
            text_boxes.append({
                'text': text,
                'confidence': confidence,
                'bbox': {
                    'x_min': x_min, 'x_max': x_max,
                    'y_min': y_min, 'y_max': y_max,
                },
                # Average over however many points the polygon has, not a fixed 4.
                'center_x': sum(xs) / len(xs),
                'center_y': sum(ys) / len(ys),
                'width': x_max - x_min,
                'height': y_max - y_min,
            })
        print(f"提取到 {len(text_boxes)} 个文字框")
        return text_boxes

    def _detect_table_structure(self, text_boxes, row_threshold=0.03, col_threshold=0.05):
        """Group text boxes into rows and estimate column x-positions.

        Args:
            text_boxes: Dicts from ``_extract_text_boxes`` (not modified).
            row_threshold: Row gap as a fraction of the estimated image height.
            col_threshold: Column gap as a fraction of the estimated image width.

        Returns:
            ``(rows, cols)``: rows top-to-bottom, each a list of boxes sorted
            left-to-right; cols a sorted list of column center x-values.
        """
        if not text_boxes:
            return [], []
        # The real image size is unknown here; approximate it by the extent of the text.
        img_height = max(box['bbox']['y_max'] for box in text_boxes)
        img_width = max(box['bbox']['x_max'] for box in text_boxes)
        print(f"图片尺寸估计: {img_width} x {img_height}")
        row_threshold_px = img_height * row_threshold
        col_threshold_px = img_width * col_threshold
        print(f"行阈值: {row_threshold_px:.1f}px, 列阈值: {col_threshold_px:.1f}px")

        def by_x(box):
            return box['center_x']

        rows = []
        current_row = []
        row_y_sum = 0.0
        for box in sorted(text_boxes, key=lambda b: b['center_y']):
            # Compare against the row's mean y rather than its last box, so a
            # staircase of small steps cannot chain several rows into one.
            if current_row and abs(box['center_y'] - row_y_sum / len(current_row)) >= row_threshold_px:
                rows.append(sorted(current_row, key=by_x))
                current_row, row_y_sum = [], 0.0
            current_row.append(box)
            row_y_sum += box['center_y']
        if current_row:
            rows.append(sorted(current_row, key=by_x))
        print(f"检测到 {len(rows)} 行")

        cols = self._cluster_values([box['center_x'] for box in text_boxes], col_threshold_px)
        print(f"检测到 {len(cols)} 列")
        return rows, cols

    def _cluster_values(self, values, threshold):
        """Cluster 1-D values and return the sorted cluster means.

        Values are scanned in ascending order; a value starts a new cluster
        when it is more than *threshold* above the previous value. The input
        list is not modified.
        """
        if not values:
            return []
        ordered = sorted(values)
        clusters = []
        current_cluster = [ordered[0]]
        for value in ordered[1:]:
            if value - current_cluster[-1] <= threshold:
                current_cluster.append(value)
            else:
                clusters.append(sum(current_cluster) / len(current_cluster))
                current_cluster = [value]
        clusters.append(sum(current_cluster) / len(current_cluster))
        # Means of consecutive groups of sorted values are already ascending.
        return clusters

    def _build_table_data(self, text_boxes, rows, cols):
        """Place each row's boxes into a ``len(rows) x len(cols)`` grid of strings.

        Boxes that land in the same cell are joined with a single space in
        left-to-right order. *text_boxes* is accepted for interface
        compatibility; the grid is built from *rows*.
        """
        if not rows:
            return []
        n_cols = len(cols)
        table = [[''] * n_cols for _ in rows]
        cell_counts = [[0] * n_cols for _ in rows]
        for i, row_boxes in enumerate(rows):
            for box in row_boxes:
                col_idx = self._find_column_index(box['center_x'], cols)
                if col_idx >= n_cols:
                    continue  # only possible when no columns were detected
                cell = table[i][col_idx]
                table[i][col_idx] = f"{cell} {box['text']}" if cell else box['text']
                cell_counts[i][col_idx] += 1
        self._refine_table_data(table, cell_counts)
        return table

    def _find_column_index(self, x_pos, cols):
        """Return the index of the column center closest to *x_pos* (0 if there are none).

        Ties resolve to the lowest index.
        """
        if not cols:
            return 0
        return min(range(len(cols)), key=lambda i: abs(x_pos - cols[i]))

    def _refine_table_data(self, table, cell_counts):
        """Hook for post-processing cells that received several text boxes.

        Currently a no-op placeholder: cells with ``cell_counts[i][j] > 1``
        are left as the space-joined text produced by ``_build_table_data``.
        """
        for i, row_counts in enumerate(cell_counts):
            for j, count in enumerate(row_counts):
                if count > 1:
                    # TODO: merge/reorder multi-box cells if needed.
                    pass

    def analyze_table_structure(self, image_path):
        """Return a summary of the detected structure, or None if OCR found nothing.

        The dict has ``text_boxes_count``, ``rows_count``, ``columns_count``,
        per-row ``rows`` details and the column x-positions under ``columns``.
        """
        ocr_lines = self._run_ocr(image_path)
        if ocr_lines is None:
            return None
        text_boxes = self._extract_text_boxes(ocr_lines)
        rows, cols = self._detect_table_structure(text_boxes)
        return {
            'text_boxes_count': len(text_boxes),
            'rows_count': len(rows),
            'columns_count': len(cols),
            'rows': [
                {
                    'row_index': i,
                    'boxes_count': len(row),
                    'boxes': [box['text'] for box in row],
                }
                for i, row in enumerate(rows)
            ],
            'columns': cols,
        }

    def export_table_results(self, table_data, image_path, output_dir="output"):
        """Write *table_data* to ``<output_dir>/<image stem>_smart_table.{xlsx,csv}``.

        Returns:
            ``(excel_path, csv_path)``, or None when there is no data.
        Note: ``DataFrame.to_excel`` needs an Excel engine such as openpyxl.
        """
        if not table_data:
            print("无表格数据可导出")
            return None
        os.makedirs(output_dir, exist_ok=True)
        base_name = os.path.splitext(os.path.basename(image_path))[0]
        df = pd.DataFrame(table_data)

        excel_path = os.path.join(output_dir, f"{base_name}_smart_table.xlsx")
        df.to_excel(excel_path, index=False, header=False)
        print(f"智能表格已保存到: {excel_path}")

        csv_path = os.path.join(output_dir, f"{base_name}_smart_table.csv")
        # utf-8-sig adds a BOM so Excel opens Chinese text correctly.
        df.to_csv(csv_path, index=False, header=False, encoding='utf-8-sig')
        print(f"CSV表格已保存到: {csv_path}")

        print("\n表格预览:")
        print("-" * 50)
        for i, row in enumerate(table_data, start=1):
            print(f"行 {i}: {row}")
        return excel_path, csv_path

    def visualize_detection(self, image_path, output_dir="output"):
        """Draw OCR boxes and ``text(confidence)`` labels and save ``<stem>_detection.jpg``.

        Returns the saved path, or None if the image is missing/unreadable or
        OCR found nothing.
        NOTE: ``cv2.putText`` Hershey fonts only cover ASCII, so Chinese text
        is not rendered legibly on the image.
        """
        if not os.path.exists(image_path):
            print(f"图片不存在: {image_path}")
            return None
        ocr_lines = self._run_ocr(image_path)
        if ocr_lines is None:
            return None

        # cv2.imread cannot open non-ASCII paths on Windows; decode from raw bytes.
        raw = np.fromfile(image_path, dtype=np.uint8)
        img = cv2.imdecode(raw, cv2.IMREAD_COLOR) if raw.size else None
        if img is None:
            print("无法读取图片")
            return None

        for line in ocr_lines:
            points = np.array(line[0], dtype=np.int32)
            text, confidence = line[1][0], line[1][1]
            cv2.polylines(img, [points], True, (0, 255, 0), 2)
            center_x = int(np.mean(points[:, 0]))
            center_y = int(np.mean(points[:, 1]))
            cv2.putText(img, f"{text}({confidence:.2f})", (center_x - 50, center_y),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 1)

        os.makedirs(output_dir, exist_ok=True)
        base_name = os.path.splitext(os.path.basename(image_path))[0]
        output_path = os.path.join(output_dir, f"{base_name}_detection.jpg")
        # Encode then write via numpy for the same non-ASCII-path reason as above.
        ok, encoded = cv2.imencode('.jpg', img)
        if not ok:
            print(f"无法保存检测可视化: {output_path}")
            return None
        encoded.tofile(output_path)
        print(f"检测可视化已保存到: {output_path}")
        return output_path
def main(image_path=r"d:\Desktop\tinified\test.png"):
    """Run recognition, export, visualization and structure analysis on one image.

    Args:
        image_path: Image to process; defaults to the original sample path.

    Returns:
        The recognized table (list of rows), or None when the image is
        missing or recognition fails.
    """
    if not os.path.exists(image_path):
        print(f"图片不存在: {image_path}")
        return None

    print("初始化智能表格识别器...")
    smart_ocr = SmartTableOCR(use_gpu=False)

    print("开始表格识别...")
    table_data = smart_ocr.smart_table_recognition(image_path)
    if not table_data:
        print("表格识别失败")
        return None

    print(f"\n识别成功! 表格尺寸: {len(table_data)}行 x {len(table_data[0])}列")
    smart_ocr.export_table_results(table_data, image_path)

    print("\n生成可视化结果...")
    smart_ocr.visualize_detection(image_path)

    print("\n表格结构分析:")
    analysis = smart_ocr.analyze_table_structure(image_path)
    if analysis:
        print(f"文字框数量: {analysis['text_boxes_count']}")
        print(f"行数: {analysis['rows_count']}")
        print(f"列数: {analysis['columns_count']}")
    return table_data


if __name__ == "__main__":
    import sys

    # An optional first command-line argument overrides the default image path.
    if len(sys.argv) > 1:
        main(sys.argv[1])
    else:
        main()
539

被折叠的 条评论
为什么被折叠?



