【翻译】TextClassification介绍(三)

TextClassification介绍(三)

一、说明

这是一个关于介绍 TextClassification API 的系列文章,总共分三篇,本文是最后一篇。上一篇在此:【翻译】TextClassification介绍(二)

原文作者:Mark Allison
阅读时间: 5 分钟
原文链接:https://blog.stylingandroid.com/textclassification-part-3/

二、正文

在 API 26 (奥利奥)中安卓引入了一个新的文字功能系统: TextClassification 。这个系统将会在 API 28 ( π )中进一步改进完善。在本次简短的系列中,我们主要会探讨它是一个什么样的系统,如何使用它,以及如何为它添加一些自定义行为。

上一篇文章中我们开始研究自定义 TextClassifier 的实现,并研究了如何实现我们自己的“文本选择建议”。在本系列的最后一篇文章中,我们将会实现相应的 classifyText() 方法,并应用到我们自定义的 TextClassifier 文本分类器的实现中。

这里的 classifyText() 方法的实现实际上非常简单直接。我们检查选择的文本是否与请求中所选择的相匹配,如果相匹配则返回一个 TextClassification 的实例,这个我们稍后再详细了解。如果选择的文本与我们的正则表达式不匹配,那么我们会将结果进行回滚(也就是 TextClassifier 的默认实现):

override fun classifyText(request: TextClassification.Request): TextClassification {
    //line 36
    return if (regex.matches(request.subSequence())) {
        factory.buildTextClassification(
                //line 38
                request.subSequence().toString(),
                listOf(TextClassifier.TYPE_URL to 1.0f),
                listOf(factory.buildRemoteAction(
                        context,
                        R.drawable.ic_stylingandroid,
                        stylingAndroid,
                        contentDescription,
                        stylingAndroidUri
                ))
        )
    } else {
        fallback.classifyText(request)
    }
}

private fun TextClassification.Request.subSequence() =
        text.subSequence(startIndex, endIndex)

这里有一个 TextClassification.Request 的扩展函数,它返回文本中已选择的子文本 subSequence ,我们正是使用它来匹配正则表达式(第 36 行)。然后,我们构建一个 TextClassification 对象,它需要一些参数,包括匹配到的字符串(第 38 行),分类类型与其各自的可信度分数列表(第 39 行),以及一列 RemoteAction 实例对象,每个 RemoteAction 对象对应一个分类类型条目(第 39-45 行)。

这里的 buildRemoteAction() 方法返回每个 RemoteAction 实例:

override fun buildRemoteAction(
        context: Context,
        drawableId: Int,
        title: String,
        contentDescription: String,
        uri: String
): RemoteAction {
    return RemoteAction(
            Icon.createWithResource(context, drawableId),
            title,
            contentDescription,
            PendingIntent.getActivity(
                    context,
                    0,
                    Intent(Intent.ACTION_VIEW, Uri.parse(uri)),
                    0
            )
    )
}

这里 RemoteAction 的构造函数需要四个参数:一个 Icon 对象,它将显示为操作按钮中的一部分;一个要显示的文字;一个表示可访问的内容描述;最后还需一个 PendingIntent 对象,它表示用户在点击操作按钮时将会执行的操作。在我们的示例代码中,我们使用了一个 Styling Android 标志图形作为图标,使用了 “Styling Android” 作为标题,以及简单的一些内容描述,再加一个能登录浏览器打开 URL 链接 “https://blog.stylingandroid.com” 的 PendingIntent 对象。

buildTextClassification() 函数使用 TextClassification.Builder 实例来创建 TextClassification 对象:

override fun buildTextClassification(
        text: String,
        entityTypes: List<Pair<String, Float>>,
        actions: List<RemoteAction>
): TextClassification {
    return TextClassification.Builder()
            .run {
                setText(text)
                entityTypes.forEach { setEntityType(it.first, it.second) }
                actions.forEach { addAction(it) }
                build()
            }
}

这里需要传递三个参数:匹配到的文本,一系列以类型及其对应的可信度范围作为 Pair 的列表,以及与每个类型/置信度组合所对应的 RemoteActions 列表。在这里示例中,这些参数分别是被选定的子字符串,包含一对 TextClassifier.TYPE_URL 及其可信度为 1.0f 组成的 Pair 的列表,以及一列包含我们刚刚说过的 RemoteAction 类单例的列表。

TextClassifier 已经完成了,剩下的事情就是将它连接到我们的 TextView 文本控件中:

class MainActivity : AppCompatActivity() {

    private val emailText = "dummy@email.com"
    private val urlText = "https://blog.stylingandroid.com"
    private val hybridText = "Email: $emailText"

    private lateinit var textClassificationManager: TextClassificationManager

    override fun onCreate(savedInstanceState: Bundle?) {
        super.onCreate(savedInstanceState)
        setContentView(R.layout.activity_main)

        textClassificationManager = getSystemService(Context.TEXT_CLASSIFICATION_SERVICE) as TextClassificationManager

        classifier()

        text1.textClassifier = StylingAndroidTextClassifier(this, textClassificationManager.textClassifier)
    }

    private fun classifier() = async(CommonPool) {
        val textClassifier = textClassificationManager.textClassifier
        val emailClassification = textClassifier.classifyText(emailText, 0, emailText.length, LocaleList.getDefault())
        println(emailClassification)
        val urlClassification = textClassifier.classifyText(urlText, 0, urlText.length, LocaleList.getDefault())
        println(urlClassification)
        val suggestions = textClassifier.suggestSelection(hybridText, 10, 11, LocaleList.getDefault())
        println(suggestions)
    }
}

一行代码就可以实现我们要做的一切。我们创建了一个自定义 TextClassifier 对象,并将其作为参数传递给默认的 TextClassifier 对象的构造函数,并将 TextViewtextClassifier 属性设置为我们的自定义实例。

我们目前得到的文本处理行为是:如果用户长时间按下的文字是无法识别的类型,那就只有简单的复制,粘贴和选择所有的选项功能;如果他们长时间按下系统默认的文本分类器所支持的任何类型,那么他们会得到我们在第一篇文章中所看到的所有相同的行为和动作;但是,如果他们长按 “Styling Android” 字符串(或者与正则表达式相匹配的类似字符串),那么将获得我们所自定义的“流行广告”操作,这个操作将启动浏览器并登录加载链接: https://blog.stylingandroid.com

custom.gif

TextClassifier 文本分类器中还有一个机制,用来识别具体的文本类型,然后使用文本生成链接,不过在本系列文章中我们不会去研究它,因为我们这里所看到的技术已经涵盖了 TextViewWebView 内置的一些主要功能。如果看到这里有人对此感兴趣,那么请告诉我,我将会单独写一篇文章作详细说明。

三、总结

这篇文章的源代码可以在这里找到: https://github.com/StylingAndroid/TextClassification/tree/Part3

© 2018 , Mark Allison 。保留所有版权。

我的博客地址: http://liuqingwen.me ,欢迎关注我的微信公众号:
IT自学不成才

<think>好的,用户之前询问了如何部署RNN-T模型,现在又问“介绍一下rnn t”,看起来他们可能想更深入地了解RNN-T的基础知识。首先,我需要确认用户的需求是否已经从部署转向了模型本身的介绍。用户可能对RNN-T的结构、工作原理还不清楚,所以需要补充这方面的知识。 接下来,我要回忆RNN-T的核心组成,比如编码、预测和联合网络,以及它们如何协作处理序列任务。同时,用户可能对RNN-T的优势和应用场景感兴趣,比如在语音识别中的流式处理能力。另外,需要注意区分RNN-T与其他模型如CTC和Seq2Seq的不同点,以突出其特点。 用户之前的问题涉及部署,所以在介绍模型时可能需要简要提及与部署相关的内容,比如实时性和计算需求,但重点应放在模型结构上。要确保语言通俗易懂,结构清晰,分点说明,可能还需要举例说明应用场景,帮助用户更好地理解。 需要检查是否有技术细节需要展开,比如对齐机制、训练过程等,但不宜过于深入,以免让用户感到困惑。同时,注意使用中文术语,避免过多专业术语,必要时进行解释。最后,确认回答是否符合之前的系统级指令,如数学表达式的格式和结构要求。</think>以下是关于 **RNN-T(RNN Transducer,循环神经网络传感)**的详细介绍,涵盖其核心原理、结构特点和应用场景: --- ### **一、RNN-T 的基本概念** RNN-T 是一种专为序列到序列任务(如语音识别、手写识别)设计的端到端模型,最早由 Alex Graves 在 2012 年提出。其核心目标是解决传统 CTC(Connectionist Temporal Classification)模型的局限性,**允许输出序列与输入序列长度动态对齐**,并支持流式处理。 #### **关键特性** 1. **无需强制对齐**:输入(如音频帧)和输出(如文字)无需严格对齐标签。 2. **流式友好**:可逐帧处理输入,适合实时语音识别。 3. **联合建模**:通过联合网络(Joint Network)融合声学与语言模型信息。 --- ### **二、RNN-T 的核心结构** RNN-T 由部分组成:**编码(Encoder)、预测(Predictor)和联合网络(Joint Network)**。 #### **1. 编码(Encoder)** - **功能**:将输入序列(如音频特征)编码为高层表示。 - **常见结构**:LSTM、GRU 或 Transformer 的变体(如 Conformer)。 - **输入**:序列 $X = (x_1, x_2, ..., x_T)$(如梅尔频谱帧)。 - **输出**:隐状态序列 $H = (h_1, h_2, ..., h_T)$。 #### **2. 预测(Predictor)** - **功能**:建模输出序列的历史依赖(类似语言模型)。 - **输入**:已生成的历史标签 $y_{<u}$(如已识别的文字)。 - **输出**:隐状态 $p_u$(表示当前上下文)。 - **结构**:通常为单层 LSTM 或 RNN。 #### **3. 联合网络(Joint Network)** - **功能**:融合编码和预测的输出,生成当前时刻的标签概率分布。 - **数学形式**: $$ z_{t,u} = \text{ReLU}(W_h h_t + W_p p_u + b) \\ P(y_{t,u} | X, y_{<u}) = \text{Softmax}(W_z z_{t,u}) $$ - $h_t$:编码在时间步 $t$ 的输出。 - $p_u$:预测在输出步 $u$ 的输出。 - $y_{t,u}$:在输入时间步 $t$ 和输出步 $u$ 的候选标签。 --- ### **、RNN-T 的工作原理** #### **动态对齐机制** - 输入与输出通过一个**二维网格**对齐,允许模型自由选择对齐路径(见下图)。 - **路径搜索**:通常使用 Beam Search 或贪心算法选择最优路径。 <div align=center> <img src="https://miro.medium.com/v2/resize:fit:720/format:webp/1*Y8p9aYqgXZzqyN7N8t2kOg.png" width="400" alt="RNN-T 对齐网格"> </div> #### **训练目标** - **损失函数**:负对数似然,最大化正确路径的概率: $$ \mathcal{L} = -\log P(Y^* | X) $$ 其中 $Y^*$ 是真实标签序列。 --- ### **四、RNN-T 与其他模型的对比** | **模型** | **对齐方式** | **流式支持** | **输出依赖** | **典型应用** | |----------------|---------------------|--------------|-----------------------|---------------------| | **CTC** | 单调对齐(强制) | 是 | 输出间独立 | 简单语音识别 | | **Seq2Seq** | 注意力机制 | 否 | 自回归依赖 | 机翻译 | | **RNN-T** | 动态对齐(自由路径) | 是 | 依赖历史输出 | 实时语音识别 | --- ### **五、RNN-T 的优势与挑战** #### **优势** 1. **流式处理能力**:适合实时场景(如语音助手)。 2. **端到端训练**:无需对齐数据,简化流程。 3. **语言模型融合**:预测隐式建模语言信息。 #### **挑战** 1. **训练复杂度高**:需处理二维动态规划(计算成本高)。 2. **内存消耗大**:联合网络的矩阵运算对设备资源要求高。 3. **长序列建模**:预测的历史依赖可能限制长文本生成能力。 --- ### **六、实际应用场景** 1. **语音识别(ASR)** - 如 Google 的 Live Transcribe、Apple Siri 的实时转录。 - 支持低延迟流式输出(<300ms)。 2. **手写识别** - 将笔触坐标序列转换为文字。 3. **视频动作标注** - 对视频帧序列生成动作描述文本。 --- ### **七、改进与变体** 1. **Hybrid RNN-T**:结合 CTC 损失加速训练收敛。 2. **Transformer Transducer**:用 Transformer 替换 RNN,提升长序列建模能力。 3. **Emformer RNN-T**:基于内存的高效 Transformer,优化流式处理。 --- ### **八、代码示例(简化版 PyTorch 实现)** ```python class RNNT(nn.Module): def __init__(self, input_dim, vocab_size): super().__init__() # 编码:BiLSTM self.encoder = nn.LSTM(input_dim, 256, bidirectional=True) # 预测:LSTM self.predictor = nn.LSTM(128, 128) # 联合网络 self.joint_net = nn.Sequential( nn.Linear(512, 256), # 编码输出(256*2) + 预测输出(128) nn.ReLU(), nn.Linear(256, vocab_size) ) def forward(self, x, y): # 编码处理输入序列 h_enc, _ = self.encoder(x) # h_enc: [T, B, 512] # 预测处理历史标签 p, _ = self.predictor(y) # p: [U, B, 128] # 联合网络计算所有(t,u)组合的logits logits = [] for t in range(h_enc.shape[0]): for u in range(p.shape[0]): z = torch.cat([h_enc[t], p[u]], dim=-1) logits.append(self.joint_net(z)) return torch.stack(logits) ``` --- ### **九、学习资源推荐** 1. **论文**: - [《Sequence Transduction with Recurrent Neural Networks》](https://arxiv.org/abs/1211.3711)(Alex Graves, 2012) - [《Transformer Transducer》](https://arxiv.org/abs/1910.12977)(Google, 2019) 2. **工具库**: - NVIDIA NeMo:提供预训练 RNN-T 模型和训练工具。 - ESPnet:支持 RNN-T 的语音识别全流程框架。 --- 通过上述设计,RNN-T 在流式语音识别任务中实现了高精度与低延迟的平衡,成为工业界主流的端到端语音识别方案之一。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值