java做AI和大语言模型开的三大框架和两大技能

三大框架:1.springAI 2.LangChain4j 3Jlama

两大技能:RAG(检索增强生成)和 Fine-tuning(微调)

LangChain4j应用场景

在java项目中使用LangChain4j

添加依赖

如果你使用的是 Maven 项目,在 pom.xml 里添加以下依赖:

<dependencies>
    <dependency>
        <groupId>dev.langchain4j</groupId>
        <artifactId>langchain4j-openai</artifactId>
        <version>0.23.0</version>
    </dependency>
</dependencies>

配置 OpenAI API 密钥

你需要设置 OpenAI 的 API 密钥,可以通过环境变量或者代码中直接设置,这里采用环境变量的方式。在运行程序前设置环境变量:

编写示例代码

import dev.langchain4j.chain.ConversationalRetrievalChain;
import dev.langchain4j.data.document.Document;
import dev.langchain4j.data.document.DocumentLoader;
import dev.langchain4j.data.document.splitter.CharacterTextSplitter;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.embedding.OpenAiEmbeddingModel;
import dev.langchain4j.model.input.Prompt;
import dev.langchain4j.model.input.PromptTemplate;
import dev.langchain4j.model.openai.OpenAiChatModel;
import dev.langchain4j.retriever.EmbeddingStoreRetriever;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.inmemory.InMemoryEmbeddingStore;

import java.io.File;
import java.util.List;

import static dev.langchain4j.model.openai.OpenAiModelName.GPT_3_5_TURBO;

public class LangChain4jExample {
    public static void main(String[] args) {
        // 创建 OpenAI 聊天模型
        OpenAiChatModel chatModel = OpenAiChatModel.withApiKey(System.getenv("OPENAI_API_KEY"));

        // 创建 OpenAI 嵌入模型
        EmbeddingModel embeddingModel = OpenAiEmbeddingModel.withApiKey(System.getenv("OPENAI_API_KEY"));

        // 创建嵌入存储
        EmbeddingStore<TextSegment> embeddingStore = new InMemoryEmbeddingStore<>();

        // 加载文档
        Document document = DocumentLoader.load(new File("path/to/your/document.txt"));

        // 分割文档为文本片段
        List<TextSegment> segments = new CharacterTextSplitter().split(document);

        // 为每个文本片段生成嵌入向量并存储
        for (TextSegment segment : segments) {
            embeddingStore.add(embeddingModel.embed(segment).content(), segment);
        }

        // 创建检索器
        EmbeddingStoreRetriever<TextSegment> retriever = EmbeddingStoreRetriever.from(embeddingStore, embeddingModel);

        // 创建对话式检索链
        ConversationalRetrievalChain chain = ConversationalRetrievalChain.builder()
               .chatLanguageModel(chatModel)
               .retriever(retriever)
               .promptTemplate(PromptTemplate.from("根据以下文档回答问题:{documents}\n问题:{question}"))
               .build();

        // 定义问题
        String question = "文档里提到了哪些关键信息?";

        // 执行链并获取答案
        String answer = chain.execute(question);

        // 输出答案
        System.out.println("问题:" + question);
        System.out.println("答案:" + answer);
    }
}

代码解释

  • 模型创建:通过 OpenAiChatModel 和 OpenAiEmbeddingModel 创建与 OpenAI 交互的聊天模型和嵌入模型。
  • 文档处理:加载本地文档,将其分割成文本片段,并为每个片段生成嵌入向量存储到 InMemoryEmbeddingStore 中。
  • 检索器创建:使用嵌入存储和嵌入模型创建检索器,用于根据问题检索相关的文本片段。
  • 对话式检索链:构建 ConversationalRetrievalChain,结合聊天模型、检索器和提示模板来处理问题。
  • 问题解答:执行链并传入问题,获取答案并输出。

Jlama的应用场景


Jlama为Java开发者打开了一扇通往AI世界的大门,其应用场景非常广泛:

智能客服系统:利用Jlama可以快速构建基于LLM的智能客服系统,提供24/7的自动化客户支持。

内容生成:在新闻、广告、社交媒体等领域,Jlama可以协助生成高质量的文本内容。

代码辅助:集成到IDE中,为程序员提供智能代码补全、注释生成等功能。

数据分析:在大数据处理中,利用LLM进行自然语言查询和数据解释。

教育科技:开发智能辅导系统,为学生提供个性化的学习体验。

金融科技:用于风险评估、市场分析、智能投顾等场景。
 

在Java项目中使用Jlama


将Jlama集成到Java项目中非常简单。首先,需要在Maven配置文件中添加以下依赖:

<dependency>
  <groupId>com.github.tjake</groupId>
  <artifactId>jlama-core</artifactId>
  <version>${jlama.version}</version>
</dependency>
 
<dependency>
  <groupId>com.github.tjake</groupId>
  <artifactId>jlama-native</artifactId>
  <classifier>${os.detected.name}-${os.detected.arch}</classifier>
  <version>${jlama.version}</version>
</dependency>


然后,可以使用以下代码示例来加载模型并生成文本:

import com.github.tjake.jlama.model.LlamaModel;
import com.github.tjake.jlama.sampling.SamplingParameters;
import com.github.tjake.jlama.sampling.SamplingResult;
import com.github.tjake.jlama.tokenizer.Tokenizer;

import java.io.IOException;
import java.nio.file.Paths;

public class JLamaExample {
    public static void main(String[] args) {
        try {
            // 加载模型和分词器
            String modelPath = "path/to/your/llama/model.bin"; // 替换为实际的模型文件路径
            LlamaModel model = LlamaModel.load(Paths.get(modelPath));
            Tokenizer tokenizer = model.getTokenizer();

            // 定义输入文本
            String inputText = "介绍一下中国的长城";

            // 对输入文本进行分词
            int[] inputTokens = tokenizer.encode(inputText);

            // 配置采样参数
            SamplingParameters samplingParameters = new SamplingParameters();
            samplingParameters.setTemperature(0.8f);
            samplingParameters.setTopK(40);
            samplingParameters.setTopP(0.95);
            samplingParameters.setMaxTokens(128);

            // 生成文本
            SamplingResult result = model.sample(inputTokens, samplingParameters);

            // 解码生成的文本
            String outputText = tokenizer.decode(result.getTokens());

            // 输出结果
            System.out.println("生成的文本: " + outputText);

            // 关闭模型
            model.close();
        } catch (IOException e) {
            e.printStackTrace();
        }
    }
}

代码解释

  • 模型加载:使用 LlamaModel.load 方法加载 Llama 模型文件,需要将 modelPath 替换为实际的模型文件路径。
  • 分词:通过 model.getTokenizer() 获取分词器,使用 tokenizer.encode 方法将输入文本转换为令牌数组。
  • 采样参数配置SamplingParameters 类用于配置文本生成的参数,如温度(temperature)、topKtopP 和最大生成令牌数(maxTokens)。
  • 文本生成:调用 model.sample 方法根据输入令牌和采样参数生成文本。
  • 解码:使用 tokenizer.decode 方法将生成的令牌数组转换为文本。

另一种写法 

public void sample() throws IOException {
    String model = "tjake/TinyLlama-1.1B-Chat-v1.0-Jlama-Q4";
    String workingDirectory = "./models";
    String prompt = "What is the best season to plant avocados?";
 
    File localModelPath = SafeTensorSupport.maybeDownloadModel(workingDirectory, model);
    AbstractModel m = ModelSupport.loadModel(localModelPath, DType.F32, DType.I8);
 
    if (m.promptSupport().isPresent()) {
        prompt = m.promptSupport().get().newBuilder()
                .addSystemMessage("You are a helpful chatbot who writes short responses.")
                .addUserMessage(prompt)
                .build();
    }
 
    System.out.println("Prompt: " + prompt + "\n");
    GenerateResponse r = m.generate(UUID.randomUUID(), prompt, 0.7f, 256, false, (s, f) -> System.out.print(s));
    System.out.println(r.toString());
}

 Deep Java Library (DJL)应用场景

进行图像分类 

在Java项目中使用DJL

添加 DJL 依赖
在 pom.xml 文件中添加 DJL 的依赖项。

<dependencies>
     
 
    <!-- DJL API -->
    <dependency>
        <groupId>ai.djl</groupId>
        <artifactId>api</artifactId>
        <version>0.19.0</version>
    </dependency>
 
    <!-- PyTorch Engine -->
    <dependency>
        <groupId>ai.djl.pytorch</groupId>
        <artifactId>pytorch-engine</artifactId>
        <version>0.19.0</version>
    </dependency>
 
    <!-- NDArray for tensor operations -->
    <dependency>
        <groupId>ai.djl.ndarray</groupId>
        <artifactId>ndarray</artifactId>
        <version>0.19.0</version>
    </dependency>
 
    <!-- Image processing -->
    <dependency>
        <groupId>ai.djl.basicmodelzoo</groupId>
        <artifactId>basic-model-zoo</artifactId>
        <version>0.19.0</version>
    </dependency>
 
    <!-- Native libraries for PyTorch -->
    <dependency>
        <groupId>ai.djl.pytorch</groupId>
        <artifactId>pytorch-native-auto</artifactId>
        <version>1.10.0</version>
    </dependency>
</dependencies>


3. 创建图像分类服务
创建一个服务类来处理图像分类逻辑

import ai.djl.Model;
import ai.djl.inference.Predictor;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.transform.Resize;
import ai.djl.modality.cv.transform.ToTensor;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.translate.Pipeline;
import ai.djl.translate.TranslateException;
import org.springframework.stereotype.Service;
import org.springframework.web.multipart.MultipartFile;
 
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.StandardCopyOption;
 
@Service
public class ImageClassificationService {
 
    private final Model model;
 
    public ImageClassificationService() throws IOException {
        // 定义模型的准则
        Criteria<Image, String> criteria = Criteria.builder()
                .setTypes(Image.class, String.class)
                .optApplication(ai.djl.modality.cv.Application.IMAGE_CLASSIFICATION)
                .optModelName("resnet18_v1")
                .optEngine("PyTorch")
                .optProgress(new ProgressBar())
                .build();
 
        // 加载模型
        model = criteria.loadModel();
    }
 
    public String classifyImage(MultipartFile file) throws IOException, TranslateException {
        // 将文件保存到临时位置
        Path tempFile = Files.createTempFile("image", ".jpg");
        Files.copy(file.getInputStream(), tempFile, StandardCopyOption.REPLACE_EXISTING);
 
        // 加载图像
        Image img = ImageFactory.getInstance().fromFile(tempFile);
 
        // 预处理图像
        Pipeline pipeline = new Pipeline();
        pipeline.add(new Resize(224, 224));
        pipeline.add(new ToTensor());
 
        // 创建预测器
        try (Predictor<Image, String> predictor = model.newPredictor()) {
            // 进行预测
            String result = predictor.predict(img);
 
            // 删除临时文件
            Files.delete(tempFile);
 
            return result;
        }
    }
}


解释代码:
Criteria: 用于定义模型的准则,包括模型的类型、应用领域、模型名称、引擎等。
ZooModel: 从模型库中加载模型。
Predictor: 用于进行预测。
ImageFactory: 用于加载图像。
Pipeline: 用于定义图像预处理步骤,如调整大小和转换为张量。
Translator: 用于定义输入和输出的转换逻辑

4. 创建控制器
创建一个控制器类来处理 HTTP 请求。

import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.*;
import org.springframework.web.multipart.MultipartFile;
 
import java.io.IOException;
 
@RestController
@RequestMapping("/api/classify")
public class ImageClassificationController {
 
    private final ImageClassificationService imageClassificationService;
 
    public ImageClassificationController(ImageClassificationService imageClassificationService) {
        this.imageClassificationService = imageClassificationService;
    }
 
    @PostMapping
    public ResponseEntity<String> classifyImage(@RequestParam("file") MultipartFile file) {
        try {
            String result = imageClassificationService.classifyImage(file);
            return ResponseEntity.ok(result);
        } catch (IOException | TranslateException e) {
            return ResponseEntity.badRequest().body("Error classifying image: " + e.getMessage());
        }
    }
}


5. 启动应用程序
创建一个主类来启动 Spring Boot 应用程序。

POST /api/classify
Content-Type: multipart/form-data
 
Form Data:
file: (选择一个图像文件)
如果一切正常,你应该会收到图像的分类结果。

其他值得关注的Java AI框架
训练模型框架 deeplearning4j

处理步骤

  • 数据处理
    • 定义数据模式 Schema,用于描述数据的结构。
    • 使用 TransformProcess 对数据进行转换,例如将分类标签转换为整数。
    • 使用 CSVRecordReader 读取 CSV 文件,并使用 RecordReaderDataSetIterator 创建数据集迭代器。
  • 神经网络配置
    • 使用 NeuralNetConfiguration 配置神经网络的基本参数,如随机种子、权重初始化方法和优化器。
    • 添加隐藏层和输出层,分别使用 DenseLayer 和 OutputLayer
  • 模型训练
    • 创建 MultiLayerNetwork 实例并初始化。
    • 使用 fit 方法对模型进行训练,训练多个轮次(numEpochs)。
  • 模型预测
    • 从数据集中获取测试数据,调用 output 方法进行预测,并输出预测结果。

RAG和微调的优点:

  • 检索器能够快速访问广泛和最新的外部数据;
  • 微调能够深度定制模型以适应专业领域;
  • 生成器则结合外部上下文和微调后的领域知识来生成响应。

检索增强生成通过整合信息检索系统,能够根据上下文从外部资源中检索相关信息,这些信息补充LLM的输入,使其输出基于可靠来源的事实。RAG在开放领域问答和知识密集型任务中表现良好。(Retrieval-Augmented Generation, RAG)

具体参考https://zhuanlan.zhihu.com/p/688138789

特定领域微调则通过在特定领域数据上调整LLM参数来实现模型专门化,使模型学习到相关的知识和语言模式,显著提升相关任务表现。

特定领域微调的应用场景

  • 生物医学:在科学和医学数据上进行微调,使LLM能够帮助完成回答问题、文献分析,甚至医学编码等任务。
  • 计算机科学:在代码库和文档上进行微调,帮助LLM提升代码生成、文档理解,以及开发工具支持。
  • 金融:在金融新闻、报告和分析上进行微调,提高了LLM在股票预测、风险评估和财务报告生成等方面的能力。
  • 法律:在法律文件和案例法上进行微调,使LLM能够协助进行法律研究、合同分析,甚至起草法律文件。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值