使用paddleocr2.7识别表格

python3.8   -使用canda

pip install paddleocr==2.7.0.3
pip install paddleocr opencv-python numpy matplotlib pandas pillow openpyxl

pip install -r requirements.txt 

paddleocr>=2.6.0
paddlepaddle>=2.4.0
opencv-python>=4.5.0
numpy>=1.21.0
matplotlib>=3.5.0
pandas>=1.3.0
Pillow>=8.3.0  
pip install fastapi uvicorn python-multipart
pip install pillow pandas openpyxl

=======================

代码如下:

import cv2
import numpy as np
import pandas as pd
from paddleocr import PaddleOCR
import os
from collections import defaultdict

class SmartTableOCR:
    def __init__(self, use_gpu=False):
        # 指定模型路径
        det_model_dir = r"D:\aAi\ocr38\ch_PP-OCRv4_det_infer"
        rec_model_dir = r"D:\aAi\ocr38\ch_PP-OCRv4_rec_infer"  # 如果需要也指定识别模型
        cls_model_dir = r"D:\aAi\ocr38\ch_ppocr_mobile_v2.0_cls_infer"  # 如果需要也指定分类模型
        
        # 检查模型路径是否存在
        if not os.path.exists(det_model_dir):
            print(f"警告: 检测模型路径不存在: {det_model_dir}")
            print("将使用在线下载的模型")
            det_model_dir = None
        
        self.ocr = PaddleOCR(
            use_angle_cls=True, 
            lang='ch', 
            use_gpu=use_gpu, 
            show_log=False,
            det_model_dir=det_model_dir,
            rec_model_dir=rec_model_dir,  # 可以设置为None让自动下载
            cls_model_dir=cls_model_dir   # 可以设置为None让自动下载
        )
        print("PaddleOCR 初始化完成")
    
    def smart_table_recognition(self, image_path):
        """智能表格识别"""
        print("正在进行智能表格识别...")
        
        # 检查图片是否存在
        if not os.path.exists(image_path):
            print(f"错误: 图片不存在 - {image_path}")
            return None
        
        # OCR识别
        try:
            result = self.ocr.ocr(image_path, cls=True)
        except Exception as e:
            print(f"OCR识别失败: {e}")
            return None
        
        if not result or not result[0]:
            print("未识别到文字")
            return None
        
        # 提取文字框信息
        text_boxes = self._extract_text_boxes(result[0])
        
        if not text_boxes:
            print("未提取到文字框信息")
            return None
        
        # 检测行列结构
        rows, cols = self._detect_table_structure(text_boxes)
        
        # 构建表格数据
        table_data = self._build_table_data(text_boxes, rows, cols)
        
        return table_data
    
    def _extract_text_boxes(self, ocr_result):
        """提取文字框信息"""
        text_boxes = []
        for line in ocr_result:
            points = line[0]
            text = line[1][0]
            confidence = line[1][1]
            
            x_coords = [p[0] for p in points]
            y_coords = [p[1] for p in points]
            
            text_boxes.append({
                'text': text,
                'confidence': confidence,
                'bbox': {
                    'x_min': min(x_coords), 'x_max': max(x_coords),
                    'y_min': min(y_coords), 'y_max': max(y_coords)
                },
                'center_x': sum(x_coords) / 4,
                'center_y': sum(y_coords) / 4,
                'width': max(x_coords) - min(x_coords),
                'height': max(y_coords) - min(y_coords)
            })
        
        print(f"提取到 {len(text_boxes)} 个文字框")
        return text_boxes
    
    def _detect_table_structure(self, text_boxes, row_threshold=0.03, col_threshold=0.05):
        """检测表格的行列结构"""
        if not text_boxes:
            return [], []
        
        # 获取图片尺寸(通过文字框的最大坐标)
        img_height = max(box['bbox']['y_max'] for box in text_boxes)
        img_width = max(box['bbox']['x_max'] for box in text_boxes)
        
        print(f"图片尺寸估计: {img_width} x {img_height}")
        
        # 动态阈值
        row_threshold_px = img_height * row_threshold
        col_threshold_px = img_width * col_threshold
        
        print(f"行阈值: {row_threshold_px:.1f}px, 列阈值: {col_threshold_px:.1f}px")
        
        # 按行分组
        text_boxes.sort(key=lambda x: x['center_y'])
        rows = []
        current_row = [text_boxes[0]]
        
        for i in range(1, len(text_boxes)):
            current_box = text_boxes[i]
            last_box = current_row[-1]
            
            y_diff = abs(current_box['center_y'] - last_box['center_y'])
            
            if y_diff < row_threshold_px:
                current_row.append(current_box)
            else:
                # 对当前行按x坐标排序
                current_row.sort(key=lambda x: x['center_x'])
                rows.append(current_row)
                current_row = [current_box]
        
        if current_row:
            current_row.sort(key=lambda x: x['center_x'])
            rows.append(current_row)
        
        print(f"检测到 {len(rows)} 行")
        
        # 检测列位置
        all_centers_x = [box['center_x'] for box in text_boxes]
        if all_centers_x:
            cols = self._cluster_values(all_centers_x, col_threshold_px)
            print(f"检测到 {len(cols)} 列")
        else:
            cols = []
            print("未检测到列结构")
        
        return rows, cols
    
    def _cluster_values(self, values, threshold):
        """对值进行聚类"""
        if not values:
            return []
        
        values.sort()
        clusters = []
        current_cluster = [values[0]]
        
        for i in range(1, len(values)):
            if values[i] - current_cluster[-1] <= threshold:
                current_cluster.append(values[i])
            else:
                clusters.append(sum(current_cluster) / len(current_cluster))
                current_cluster = [values[i]]
        
        if current_cluster:
            clusters.append(sum(current_cluster) / len(current_cluster))
        
        return sorted(clusters)
    
    def _build_table_data(self, text_boxes, rows, cols):
        """构建表格数据"""
        if not rows:
            return []
        
        # 创建空的表格结构
        table = [['' for _ in range(len(cols))] for _ in range(len(rows))]
        
        # 统计每个单元格的文字数量
        cell_counts = [[0 for _ in range(len(cols))] for _ in range(len(rows))]
        
        # 将文字填入对应的单元格
        for i, row_boxes in enumerate(rows):
            for box in row_boxes:
                col_idx = self._find_column_index(box['center_x'], cols)
                if col_idx is not None and col_idx < len(cols):
                    # 如果单元格已有内容,用空格分隔
                    if table[i][col_idx]:
                        table[i][col_idx] += ' ' + box['text']
                    else:
                        table[i][col_idx] = box['text']
                    cell_counts[i][col_idx] += 1
        
        # 处理可能的重叠问题(一个单元格有多个文字框)
        self._refine_table_data(table, cell_counts)
        
        return table
    
    def _find_column_index(self, x_pos, cols):
        """找到文字对应的列索引"""
        if not cols:
            return 0
        
        min_distance = float('inf')
        best_index = 0
        
        for i, col_x in enumerate(cols):
            distance = abs(x_pos - col_x)
            if distance < min_distance:
                min_distance = distance
                best_index = i
        
        return best_index
    
    def _refine_table_data(self, table, cell_counts):
        """精炼表格数据,处理重叠的文字"""
        for i in range(len(table)):
            for j in range(len(table[i])):
                if cell_counts[i][j] > 1:
                    # 如果一个单元格有多个文字,可能需要特殊处理
                    # 这里只是简单记录,实际可以根据需要调整
                    pass
    
    def analyze_table_structure(self, image_path):
        """分析表格结构并返回详细信息"""
        result = self.ocr.ocr(image_path, cls=True)
        
        if not result or not result[0]:
            return None
        
        text_boxes = self._extract_text_boxes(result[0])
        rows, cols = self._detect_table_structure(text_boxes)
        
        analysis = {
            'text_boxes_count': len(text_boxes),
            'rows_count': len(rows),
            'columns_count': len(cols),
            'rows': [
                {
                    'row_index': i,
                    'boxes_count': len(row),
                    'boxes': [box['text'] for box in row]
                }
                for i, row in enumerate(rows)
            ],
            'columns': cols
        }
        
        return analysis
    
    def export_table_results(self, table_data, image_path):
        """导出表格结果"""
        if not table_data:
            print("无表格数据可导出")
            return
        
        output_dir = "output"
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)
        
        base_name = os.path.splitext(os.path.basename(image_path))[0]
        
        # 保存为Excel
        excel_path = os.path.join(output_dir, f"{base_name}_smart_table.xlsx")
        df = pd.DataFrame(table_data)
        df.to_excel(excel_path, index=False, header=False)
        print(f"智能表格已保存到: {excel_path}")
        
        # 保存为CSV
        csv_path = os.path.join(output_dir, f"{base_name}_smart_table.csv")
        df.to_csv(csv_path, index=False, header=False, encoding='utf-8-sig')
        print(f"CSV表格已保存到: {csv_path}")
        
        # 打印表格预览
        print("\n表格预览:")
        print("-" * 50)
        for i, row in enumerate(table_data):
            print(f"行 {i+1}: {row}")
        
        return excel_path, csv_path
    
    def visualize_detection(self, image_path, output_dir="output"):
        """可视化检测结果"""
        if not os.path.exists(image_path):
            print(f"图片不存在: {image_path}")
            return
        
        # OCR识别
        result = self.ocr.ocr(image_path, cls=True)
        
        if not result or not result[0]:
            print("未识别到文字")
            return
        
        # 读取图片
        img = cv2.imread(image_path)
        if img is None:
            print("无法读取图片")
            return
        
        # 绘制文字框
        for line in result[0]:
            points = np.array(line[0], dtype=np.int32)
            text = line[1][0]
            confidence = line[1][1]
            
            # 绘制边界框
            cv2.polylines(img, [points], True, (0, 255, 0), 2)
            
            # 绘制文字
            center_x = int(np.mean(points[:, 0]))
            center_y = int(np.mean(points[:, 1]))
            cv2.putText(img, f"{text}({confidence:.2f})", (center_x-50, center_y), 
                       cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 1)
        
        # 保存结果
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)
        
        base_name = os.path.splitext(os.path.basename(image_path))[0]
        output_path = os.path.join(output_dir, f"{base_name}_detection.jpg")
        cv2.imwrite(output_path, img)
        print(f"检测可视化已保存到: {output_path}")

def main():
    """主函数"""
    image_path = r"d:\Desktop\tinified\test.png"
    
    if not os.path.exists(image_path):
        print(f"图片不存在: {image_path}")
        return
    
    # 使用智能表格识别
    print("初始化智能表格识别器...")
    smart_ocr = SmartTableOCR(use_gpu=False)
    
    print("开始表格识别...")
    table_data = smart_ocr.smart_table_recognition(image_path)
    
    if table_data:
        print(f"\n识别成功! 表格尺寸: {len(table_data)}行 x {len(table_data[0])}列")
        smart_ocr.export_table_results(table_data, image_path)
        
        # 可视化检测结果
        print("\n生成可视化结果...")
        smart_ocr.visualize_detection(image_path)
        
        # 分析表格结构
        print("\n表格结构分析:")
        analysis = smart_ocr.analyze_table_structure(image_path)
        if analysis:
            print(f"文字框数量: {analysis['text_boxes_count']}")
            print(f"行数: {analysis['rows_count']}")
            print(f"列数: {analysis['columns_count']}")
            
    else:
        print("表格识别失败")

if __name__ == "__main__":
    main()

评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值