记一次使用Java调用本地BERT模型,做文本内容实体提取,运行环境中不需要Python

背景

为什么使用Java加载模型?

在生产环境中没有任何必要使用Python在单独起一个服务提供服务接口,那会增加每次服务调用的时间,造成用户不好的体验。

同时为了减少部署的工作量,与其他业务功能都使用Java提供统一的服务接口,会减少很多的工作量,维护成本也相对减少。

环境说明

Java版本:11

操作系统:Windows

利用Python将模型本地化

下载模型

前往https://huggingface.co/中查找自己需要的开源模型,复制模型标识,比如如下图所示:

img

将模型标识替换掉下方代码中的mode_id内容,target_dir是输出目录,这里指定一个目录即可,目录不存在的话会自动创建目录。

import os, sys

model_id = "uer/roberta-base-finetuned-cluener2020-chinese"
target_dir = r"E:\Work\BERT\models\roberta-base-finetuned-cluener2020-chinese"
os.makedirs(target_dir, exist_ok=True)

# 常见需要的文件(可能有些模型文件名不同)
files = ["config.json", "vocab.txt", "tokenizer.json", "special_tokens_map.json", "pytorch_model.bin"]

api = HfApi()
for fname in files:
    try:
        print("Downloading", fname)
        path = hf_hub_download(repo_id=model_id, filename=fname, cache_dir=target_dir, local_dir=target_dir)
        print("Saved:", path)
    except Exception as e:
        print("Failed to download", fname, ":", e)
print("done")

验证是否下载成功

打开输出目录,查看文件是否有下载完成。至少需要包含pytorch_model.binconfig.jsonvocab.txt等以下文件。

img

生成tokenizer.json文件

有些模型是没有tokenizer.json文件的,就像我们现在所用的这个模型。但我们后续使用Java去加载这个模型时是需要用到tokenizer.json文件。下面是使用Python去根据下载的模型生成tokenizer.json文件代码:

from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("E:\\Work\\BERT\\models\\roberta-base-finetuned-cluener2020-chinese")
tokenizer.save_pretrained("E:\\Work\\BERT\\models\\roberta-base-finetuned-cluener2020-chinese", legacy_format=False)

Python加载模型测试

通过以下代码加载刚刚下载到本地的模型。指定目录即可,保证目录中模型存在会自动读取模型文件的。

from transformers import BertTokenizerFast, BertForTokenClassification
import torch

model_dir = "E:\\Work\BERT\\models\\roberta-base-finetuned-cluener2020-chinese"
tokenizer = BertTokenizerFast.from_pretrained(model_dir)
model = BertForTokenClassification.from_pretrained(model_dir)

text = "程序员范宁在北京大学的燕园看了中国男篮的一场比赛。"
tokens = list(text)
inputs = tokenizer(tokens, return_tensors="pt", is_split_into_words=True)

with torch.no_grad():
    outputs = model(**inputs)
predictions = torch.argmax(outputs.logits, dim=2)

id2label = model.config.id2label
print([id2label[i.item()] for i in predictions[0]])

输出内容

上面的示例代码输出结果如下所示:

['O', 'B-position', 'I-position', 'I-position', 'B-name', 'I-name', 'O', 'B-organization', 'I-organization', 'I-organization', 'I-address', 'O', 'I-address', 'I-address', 'O', 'O', 'O', 'O', 'O', 'I-organization', 'O', 'O', 'O', 'O', 'O', 'O', 'O']

这个输出是一个典型的 命名实体识别(NER)任务的标签序列,使用的是 BIO 标注格式(有时也叫 IOB 格式)。

B-XXX:表示一个实体的开始(Begin),XXX 是实体类型(如 name、position、organization、address 等)。

I-XXX:表示该词属于 XXX 类型实体的中间或结尾部分(Inside),且前面已经有同类型的 B 或 I。

O:表示“Outside”,即不属于任何命名实体。

导出pt模型

常见的模型文件:

格式主要框架是否含结构是否跨语言是否可读典型文件名
.ptPyTorch✅(TorchScript) ❌(state_dict)✅(TorchScript) ❌(pickle)model.pt, traced_model.pt
.binPyTorch (HF)pytorch_model.bin
.h5TensorFlow/Keras❌(限 TF)⚠️(需工具)tf_model.h5
.msgpackFlax/JAX✅(数据)⚠️(二进制)flax_model.msgpack

如果你需要将模型转换为.pt标准格式模型用于Java服务,下面是用来将.bin模型导出为.pt模型的Python代码:

from transformers import BertTokenizerFast, BertForTokenClassification
import torch

model_dir = "E:\\Work\\BERT\\models\\roberta-base-finetuned-cluener2020-chinese"
tokenizer = BertTokenizerFast.from_pretrained(model_dir)
model = BertForTokenClassification.from_pretrained(model_dir)
model.eval()

# 示例输入(必须和实际输入格式一致)
text = "程序员范宁在北京大学的燕园看了中国男篮的一场比赛。"
tokens = list(text)
inputs = tokenizer(tokens, return_tensors="pt", is_split_into_words=True)

# 导出为 TorchScript
traced_model = torch.jit.trace(
    model,
    (inputs["input_ids"], inputs["attention_mask"]),
    strict=False
)
traced_model.save("roberta-cluener-traced.pt")

Java加载模型

引入工程依赖

以下为pom.xml文件的核心片段内容:

    <dependencyManagement>
        <dependencies>
            <dependency>
                <groupId>ai.djl</groupId>
                <artifactId>bom</artifactId>
                <version>0.34.0</version>
                <type>pom</type>
                <scope>import</scope>
            </dependency>
        </dependencies>
    </dependencyManagement>

    <dependencies>
        <!-- 系统依赖-->
        <dependency>
            <groupId>cn.tworice</groupId>
            <artifactId>tworice-system</artifactId>
        </dependency>
             
        <!-- 单元测试 -->
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-test</artifactId>
        </dependency>

        <!-- DJL API -->
        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>api</artifactId>
        </dependency>

        <!-- DJL PyTorch engine -->
        <dependency>
            <groupId>ai.djl.pytorch</groupId>
            <artifactId>pytorch-engine</artifactId>
        </dependency>

        <!-- PyTorch native CPU binding -->
        <dependency>
            <groupId>ai.djl.pytorch</groupId>
            <artifactId>pytorch-native-cpu</artifactId>
            <version>2.7.1</version> <!-- 与 BOM 配合的 native 版本(可从 BOM 确认) -->
            <classifier>win-x86_64</classifier>
        </dependency>

        <!-- HuggingFace tokenizers helper(用于加载 tokenizer) -->
        <dependency>
            <groupId>ai.djl.huggingface</groupId>
            <artifactId>tokenizers</artifactId>
        </dependency>
    </dependencies>

加载模型

初始化标签映射

在实例初始化块中对标签映射关系进行初始化,这些标签对应关系一般在模型文件夹下的config.json文件中。config.json文件示例如下图所示:

img

将该内容转换成Java中的Map存储,可以编写一个自动化内容,也可以手动转一下,我这里手动转了一下,核心代码:

private final Map<Integer, String> ID2LABEL = new HashMap<>();

{
    // 标签映射
    ID2LABEL.put(0, "O");
    ID2LABEL.put(1, "B-address");
    ID2LABEL.put(2, "I-address");
    ID2LABEL.put(3, "B-book");
    ID2LABEL.put(4, "I-book");
    ID2LABEL.put(5, "B-company");
    // 这里其他类似内容省略.......
}

加载模型

先加载模型配置文件,这里就用到了上文中生成的tokenizer.json文件,将tokenizer.json文件所在目录替换掉代码中的目录,之后替换掉代码中的pt文件绝对路径。

@PostConstruct
public void init() throws IOException, MalformedModelException {
    tokenizer = HuggingFaceTokenizer.builder()
            .optTokenizerPath(Paths.get("E:\\Work\\BERT\\models\\roberta-base-finetuned-cluener2020-chinese"))
            .optAddSpecialTokens(true)
            .build();

    model = Model.newInstance("ner");
    model.load(Paths.get("E:\\Work\\BERT\\models\\roberta-base-finetuned-cluener2020-chinese\\roberta-cluener-traced.pt"));
}

完整代码

下面是Java加载模型服务提供类的完整代码:

package cn.tworice.djl;

import ai.djl.MalformedModelException;
import ai.djl.Model;
import ai.djl.inference.Predictor;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.translate.NoopTranslator;
import ai.djl.translate.TranslateException;
import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
import ai.djl.huggingface.tokenizers.Encoding;
import org.springframework.stereotype.Service;

import javax.annotation.PostConstruct;
import java.io.IOException;
import java.nio.file.Paths;
import java.util.*;

@Service
public class NerService {

    private Model model;
    private HuggingFaceTokenizer tokenizer;
    private final Map<Integer, String> ID2LABEL = new HashMap<>();

    {
        // CLUEner2020 标签映射(请根据你的模型 config 确认)
        ID2LABEL.put(0, "O");
        ID2LABEL.put(1, "B-address");
        ID2LABEL.put(2, "I-address");
        ID2LABEL.put(3, "B-book");
        ID2LABEL.put(4, "I-book");
        ID2LABEL.put(5, "B-company");
        // 这里其他类似内容省略.......
    }

    @PostConstruct
    public void init() throws IOException, MalformedModelException {
        tokenizer = HuggingFaceTokenizer.builder()
                .optTokenizerPath(Paths.get("E:\\Work\\BERT\\models\\roberta-base-finetuned-cluener2020-chinese"))
                .optAddSpecialTokens(true)
                .build();

        model = Model.newInstance("ner");
        model.load(Paths.get("E:\\Work\\BERT\\models\\roberta-base-finetuned-cluener2020-chinese\\roberta-cluener-traced.pt"));
    }

    /**
     * 预测并打印每个 token 及其 label,同时返回实体列表
     */
    public List<NerEntity> predict(String text) throws TranslateException {
        try (NDManager manager = NDManager.newBaseManager()) {
            Encoding encoding = tokenizer.encode(text);
            long[] inputIds = encoding.getIds();
            long[] attentionMask = encoding.getAttentionMask();
            String[] tokens = encoding.getTokens(); // 实际分词结果

            // === 打印输入 ===
            System.out.println(">>> 输入文本: " + text);
            System.out.println(">>> 分词结果 (含 [CLS]/[SEP]): " + Arrays.toString(tokens));

            // 转为 NDArray
            NDArray inputIdsArr = manager.create(new Shape(1, inputIds.length), DataType.INT64);
            inputIdsArr.set(inputIds);
            NDArray attentionMaskArr = manager.create(new Shape(1, attentionMask.length), DataType.INT64);
            attentionMaskArr.set(attentionMask);

            // 推理
            try (Predictor<NDList, NDList> predictor = model.newPredictor(new NoopTranslator())) {
                NDList inputs = new NDList(inputIdsArr, attentionMaskArr);
                NDList outputs = predictor.predict(inputs);
                NDArray logits = outputs.singletonOrThrow(); // [1, seq_len, num_labels]
                NDArray predictions = logits.argMax(2);      // [1, seq_len]
                long[] predIds = predictions.toLongArray();  // length = seq_len

                // === 构建 token -> label 映射(跳过 [CLS] 和 [SEP])===
                System.out.println("\n>>> Token 与 Label 对应关系:");
                List<String> tokenLabels = new ArrayList<>();
                // tokens[0] = [CLS], tokens[tokens.length-1] = [SEP]
                for (int i = 1; i < tokens.length - 1; i++) {
                    String token = tokens[i];
                    String label = ID2LABEL.getOrDefault((int) predIds[i], "O");
                    tokenLabels.add(label);
                    System.out.printf("  %-12s -> %s%n", token, label);
                }
                System.out.println(); // 空行分隔

                // === 提取实体(使用 tokenLabels 转为 long[])===
                long[] labelIds = tokenLabels.stream()
                        .mapToLong(label -> {
                            for (Map.Entry<Integer, String> entry : ID2LABEL.entrySet()) {
                                if (entry.getValue().equals(label)) return entry.getKey();
                            }
                            return 0L;
                        })
                        .toArray();

                // 注意:这里 tokens[1:-1] 对应原始分词,但中文通常按字,可直接用于 decode
                String[] contentTokens = Arrays.copyOfRange(tokens, 1, tokens.length - 1);
                List<NerEntity> entities = decodeEntities(contentTokens, labelIds);

                return entities;
            }
        }
    }

    private List<NerEntity> decodeEntities(String[] tokens, long[] labels) {
        List<NerEntity> entities = new ArrayList<>();
        StringBuilder currentEntity = new StringBuilder();
        String currentType = null;
        int start = -1;

        for (int i = 0; i < tokens.length && i < labels.length; i++) {
            String token = tokens[i];
            String label = ID2LABEL.getOrDefault((int) labels[i], "O");

            if (label.startsWith("B-")) {
                if (currentType != null) {
                    entities.add(new NerEntity(currentEntity.toString(), currentType, start, i));
                }
                currentEntity = new StringBuilder(token);
                currentType = label.substring(2);
                start = i;
            } else if (label.startsWith("I-") && currentType != null && label.substring(2).equals(currentType)) {
                currentEntity.append(token);
            } else {
                if (currentType != null) {
                    entities.add(new NerEntity(currentEntity.toString(), currentType, start, i));
                    currentType = null;
                    currentEntity.setLength(0);
                }
            }
        }

        if (currentType != null) {
            entities.add(new NerEntity(currentEntity.toString(), currentType, start, tokens.length));
        }

        return entities;
    }

    public static class NerEntity {
        public String entity;
        public String type;
        public int start;
        public int end;

        public NerEntity(String entity, String type, int start, int end) {
            this.entity = entity;
            this.type = type;
            this.start = start;
            this.end = end;
        }

        @Override
        public String toString() {
            return String.format("{'entity': '%s', 'type': '%s', 'start': %d, 'end': %d}", entity, type, start, end);
        }
    }
}

测试使用

利用单元测试,传入一段文字查看输出结果。

@SpringBootTest
public class BertTest {

    @Autowired
    private NerService nerService;

    @Test
    void testNerPrediction() throws Exception {
        String text = "2025年10月1日,程序员范宁在中国北京看了中国男篮的一场比赛。";
        List<NerService.NerEntity> entities = nerService.predict(text);

        System.out.println("输入文本: " + text);
        System.out.println("识别实体:");
        for (var entity : entities) {
            System.out.println(entity);
        }
    }
}

结果输出:

>>> 输入文本: 程序员范宁在中国北京看了中国男篮的一场比赛。
>>> 分词结果 ([CLS]/[SEP]): [[CLS],,,,,,,,,,,,,,,,,,,,,,, [SEP]]
[W1022 18:57:51.000000000 LegacyTypeDispatch.h:79] Warning: AutoNonVariableTypeMode is deprecated and will be removed in 1.10 release. For kernel implementations please use AutoDispatchBelowADInplaceOrView instead, If you are looking for a user facing API to enable running your inference-only workload, please use c10::InferenceMode. Using AutoDispatchBelowADInplaceOrView in user code is under risk of producing silent wrong result in some edge cases. See Note [AutoDispatchBelowAutograd] for more details. (function operator ())

>>> TokenLabel 对应关系:-> B-position
  序            -> I-position
  员            -> I-position
  范            -> B-name
  宁            -> I-name
  在            -> O-> B-address
  国            -> I-address
  北            -> I-address
  京            -> I-address
  看            -> O-> O-> O-> O-> O-> I-organization
  的            -> O-> O-> O-> O-> O-> O

输入文本: 程序员范宁在中国北京看了中国男篮的一场比赛。
识别实体:
{'entity': '程序员', 'type': 'position', 'start': 0, 'end': 3}
{'entity': '范宁', 'type': 'name', 'start': 3, 'end': 5}
{'entity': '中国北京', 'type': 'address', 'start': 6, 'end': 10}
<think>好的,用户现在希望将模型下载到本地使用,会一点Python代码,但模型训练,只关心如何调用接口。我需要推荐适合的开源模型,并指导他们如何使用这些模型。 首先,用户之前已经询问过现成的无需编码的解决方案,现在他们更进了一步,想要自己下载模型本地,并且愿意写一些Python代码,但需要涉及训练部分。这说明用户可能有一定的技术背景,但可能对深度学习框架或模型部署太熟悉。需要确保推荐的模型使用简单,文档齐全,并且有清晰的API接口。 接下来,我需要考虑哪些开源模型适合这样的场景。Hugging Face的Transformers库是一个很好的选择,因为它提供了大量预训练模型,并且有简单的pipelines接口,用户只需几行代码就能调用。例如,sentence-transformers库特别适合计算文本相似度,这对于简历和岗位描述的匹配非常有用。 另外,SpaCy也是一个用户友好的库,虽然需要先下载模型,但它的API非常直观,处理NLP任务也很方便。用户可能需要处理简历中的实体识别,比如提取技能、经验等,SpaCy的预训练模型可以胜任。 对于需要本地部署的中文模型,可能要考虑像BERT这样的模型,但Hugging Face已经支持了许多中文模型,比如bert-base-chinese。需要确保推荐的模型支持中文,并且有足够的使用示例。 用户提到的模型下载到本地,可能需要指导他们使用from_pretrained方法,并指定本地路径。此外,如果用户担心网络问题,可以提示他们先下载模型文件,然后离线加载。 另外,需要注意用户的数据隐私问题,尤其是处理简历这种敏感信息。本地部署的模型可以避免数据上传到外部服务器,这点可能需要强调,以符合用户的安全需求。 用户可能需要示例代码来快速上手,所以每个推荐的模型都应该附带一个简单的代码示例,展示如何加载模型并进行推理。例如,使用Hugging Face的pipeline来生成文本相似度分数,或者用sentence-transformers计算嵌入向量之间的余弦相似度。 还要考虑模型的大小和硬件需求。如果用户的本地机器没有GPU,可能需要推荐一些轻量级的模型,比如蒸馏版的模型(如DistilBERT),这样在CPU上也能较快运行。 最后,应该提醒用户注意模型输入输出的格式,例如是否需要提前将PDF简历转换为文本,以及如何处理长文本的分块输入。此外,可能还需要建议他们进行一些后处理,比如如何根据相似度分数制定阈值来判断匹配程度。 总结下来,推荐的模型需要满足以下几点:易于安装和调用、支持中文、有良好的文档和社区支持、适合本地部署,并且提供清晰的代码示例。同时,要考虑到用户的技术水平,避免涉及复杂的模型训练或调整步骤,保持接口调用的简洁性。</think>针对您希望**本地部署模型**并**通过Python调用接口**的需求,以下是经过筛选的推荐方案,所有模型均支持中文且无需训练即可直接使用: --- ### 一、首选推荐:Hugging Face模型库 #### 1. **Sentence-BERT(文本相似度计算)** - **功能**:直接计算简历文本与岗位描述的相似度 - **安装**: ```bash pip install sentence-transformers ``` - **使用代码**: ```python from sentence_transformers import SentenceTransformer, util # 下载模型本地(首次运行需联网) model = SentenceTransformer('paraphrase-multilingual-MiniLM-L12-v2') # 本地加载(后续使用) # model = SentenceTransformer('./local_model_path/') resume_text = "精通Python,5年数据分析经验..." job_desc = "招聘数据分析师,要求3年以上Python经验..." # 生成向量 embeddings = model.encode([resume_text, job_desc]) # 计算相似度 similarity = util.cos_sim(embeddings[0], embeddings[1]) print(f"匹配度:{similarity.item():.2f}") # 输出示例:匹配度:0.86 ``` #### 2. **BERT关键词提取** - **功能**:从简历中提取技能/经验关键词 - **安装**: ```bash pip install transformers ``` - **使用代码**: ```python from transformers import pipeline # 自动下载模型本地 ner_pipeline = pipeline("ner", model="bert-base-chinese") text = "熟练掌握JavaPython,有AWS云平台开发经验" results = ner_pipeline(text) # 提取技能关键词 skills = [result['word'] for result in results if result['entity'] == 'SKILL'] print(skills) # 输出示例:['Java', 'Python', 'AWS'] ``` --- ### 二、轻量化方案:SpaCy工业级模型 #### 1. **中文文本处理** - **安装**: ```bash pip install spacy python -m spacy download zh_core_web_sm # 下载中文模型 ``` - **使用代码**: ```python import spacy # 加载本地模型 nlp = spacy.load("zh_core_web_sm") doc = nlp("负责过千万级用户系统的架构设计,精通微服务") # 提取关键信息 print([ent.text for ent in doc.ents]) # 输出示例:['千万级用户系统', '微服务'] ``` --- ### 三、专业推荐:FastText分类模型 #### 1. **岗位类型匹配** - **预训练模型**:[中文词向量](https://fasttext.cc/docs/en/crawl-vectors.html) - **使用代码**: ```python import fasttext # 下载模型本地(需手动下载cc.zh.300.bin) model = fasttext.load_model("cc.zh.300.bin") # 计算岗位匹配度 def match_job(resume, job_title): return model.get_sentence_vector(resume).dot(model.get_sentence_vector(job_title)) print(match_job("大数据开发经验", "数据工程师")) # 输出示例:0.78 ``` --- ### 四、模型本地化部署技巧 1. **永久保存模型**: ```python from transformers import AutoModel, AutoTokenizer model_name = "bert-base-chinese" model = AutoModel.from_pretrained(model_name) tokenizer = AutoTokenizer.from_pretrained(model_name) # 保存到本地文件夹 model.save_pretrained("./my_local_model/") tokenizer.save_pretrained("./my_local_model/") ``` 2. **离线加载**: ```python model = AutoModel.from_pretrained("./my_local_model/") ``` --- ### 五、避坑指南 1. **模型选择原则**: - 短文本匹配:优先选`sentence-transformers` - 长文本分析:用`BERT`系列 - 实时性要求高:用`FastText` 2. **硬件建议**: - CPU运行:选择带"mini","tiny","distil"前缀的模型(如`distilbert-base-chinese`) - GPU加速:选择`bert-large`系列 如果需要具体场景的完整代码模板(例如构建简历匹配系统),可以告诉我您的: 1. 输入数据格式(PDF/文本) 2. 期望输出形式(匹配分数/详细报告) 3. 硬件配置(有无GPU) 我将为您定制专属代码框架。
评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

二饭

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值