pytorchOCR之DBnet(多类别文本检测1)
pytorchOCR之DBnet(多类别文本检测1)
摘要
该博客介绍了如何使用PyTorch实现DBnet模型进行多类别文本检测,包括中英文和手写印刷文字的分类。模型结构通过在原始DBnet基础上增加一个分类支路来实现。此外,还详细讲解了分类损失函数的计算方法。实验结果显示,这种方法在文本分类.
pytorchOCR之DBnet(多类别文本检测1)
代码
以上链接无法访问,请访问
https://download.youkuaiyun.com/download/weixin_54626591/92423947
动机
主要是为了对不同文本类型,既要检测文本位置,又要分类文本类型,比如中英文检测分类,手写印刷检测分类。
实现方式
1 . 模型结构
这里对DBnet结构多增加了一个分类的支路,如下:
self.classhead = nn.Sequential(
nn.Conv2d(inner_channels, inner_channels // 4, 3, padding=1, bias=bias),
nn.BatchNorm2d(inner_channels // 4),
nn.ReLU(inplace=True),
nn.ConvTranspose2d(inner_channels // 4, inner_channels // 4, 2, 2),
nn.BatchNorm2d(inner_channels // 4),
nn.ReLU(inplace=True),
nn.ConvTranspose2d(inner_channels // 4, n_classes, 2, 2)
)
当然也可以直接在文本分割支路最后直接加上n_classes,这样可以共用更多卷积,但是实验下来没有分开效果好。
2. loss
这里多增加一个分类的loss如下:
class MulClassLoss(nn.Module):
def __init__(self, ):
super(MulClassLoss, self).__init__()
def forward(self,pre_score,gt_score,n_class):
gt_score = gt_score.reshape(-1)
index = gt_score>0
if index.sum()>0:
pre_score = pre_score.permute(0,2,3,1).reshape(-1,n_class)
gt_score = gt_score[index]
gt_score = gt_score - 1
pre_score = pre_score[index]
class_loss = F.cross_entropy(pre_score,gt_score.long(), ignore_index=-1)
else:
class_loss = torch.tensor(0.0).cuda()
return class_loss
这里我们选取文本像素进行分类loss计算。
效果



底下评论
问题一:
这个有训练好的模型文件可以给我试下吗?
作者回答:
没有模型了
问题二:
gt_score = gt_score - 1
这里的-1是什么意思,另外为什么要设置ignore_index=-1
作者回答:
GitHub里回复了,值得一提的是,这种方式只能解决文字不重叠的情况下的分类,第二种 网页链接 可以解决文字重叠的情况
其他人回答:
感谢您的分享!求教如何改善文字重叠的情况?
问题三:
请问类别是在哪里设置的呢?我训练出来只有一个类别,但标注是含有类别的
作者回答:
网页链接 里n_class
问题四:
计算分类这里有点问题,水平方向没问题,如果是斜体文字分类取值就错了,非常容易误判为第0分类
classes = classes[ymin:ymax + 1, xmin:xmax + 1]
np.argmax(np.bincount(classes.reshape(-1).astype(np.int32)))
作者回答:
是的,倾斜的这样做是有问题的 还是得取文本像素
其他人提问:
请问斜体文字的分类如何改善呀?
提问者回答
看我博文
问题五:
大佬,我想问一下,分类不均衡的问题有什么好的解决办法吗,我尝试了几种数据增强的方法还是不能提升预测效果。
作者回答:
没实际试过具体的项目 照理来说copy paste应该还可以吧
问题六:
大佬,问下为什么你的 loss 里面 gt_score > 0 才参与多类别交叉熵运算的目的是什么呢
作者回答:
因为文本的标签是从1开始的 0是背景像素的标签
其他人提问:
我跑出来的多分类的mask 的分割效果不太好,符合预期吗?如果进入 classhead 的 fuse detach 一下会不会好一些
作者回答:
效果不好有多种情况吧 ,不太好说,但是这两个分支尽可能解耦效果照理会好一些
问题七:
大佬,文本检测+分类,主干网络用mobilenet_v3_small,ACC一直是0.5左右,上不去怎么回事啊
作者回答:
你先用大一点的,比如resnet看看是不是会高一些, 排除代码问题,有可能小网络拟合能力比较差
其他人提问:
请问解决这个问题了吗,我用resnet 100个epoch了也一直是0.5左右
问题八:
您好,我用自己的数据集进行训练,骨干网络调用的是resnet50,mobilnet也用了,但是出现三个问题,1不保存训练后参数,2,loss在1至3左右不收敛3,acc一直在0.5附近……请问这个怎么调整,怎么调用预训练模型?希望您能解答。非常感谢
5358

被折叠的 条评论
为什么被折叠?



