Calling Local LLMs and Embedding Models from Python

The app uses Python to call the Ollama API for model inference, Streamlit to build the web front end, and an embedding model to process the uploaded documents.
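
The code below reads its settings from a `config.yaml` next to the script; `AppConfig` requires at least the keys `ollama`, `embed_model`, and `vector_db_path`. A minimal sketch of such a file — the URL and model name are placeholders to adjust for your local setup:

```yaml
# Minimal config.yaml sketch; values are examples only.
ollama:
  base_url: "http://localhost:11434"   # default local Ollama endpoint
embed_model: "nomic-embed-text"        # any embedding model already pulled into Ollama
vector_db_path: "./chroma_db"          # directory for the ChromaDB persistent store
```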

```python
import os
import time  # used for retry back-off in VectorDBManager.initialize
import asyncio
import chromadb
import aiohttp
import requests
import re
import yaml
from typing import List, Optional, Dict, AsyncGenerator, Tuple
from datetime import datetime
from uuid import uuid4
from io import BytesIO
import openai
import streamlit as st
from streamlit.runtime.uploaded_file_manager import UploadedFile
from PyPDF2 import PdfReader
from docx import Document
import chardet


# --------------------------
# Configuration management
# --------------------------
class AppConfig:
    def __init__(self):
        self._config = None
        self.load_time = None

    @property
    def config(self) -> dict:
        """带缓存机制的配置加载"""
        if not self._config or (datetime.now() - self.load_time).seconds > 300:
            self._load_config()
        return self._config

    def _load_config(self):
        """安全加载配置文件"""
        try:
            config_path = os.path.join(os.path.dirname(__file__), "config.yaml")
            if not os.path.exists(config_path):
                raise FileNotFoundError(f"配置文件不存在: {config_path}")

            with open(config_path) as f:
                config = yaml.safe_load(f)

            required_keys = {"ollama", "embed_model", "vector_db_path"}
            if missing := required_keys - config.keys():
                raise KeyError(f"缺少必要配置项: {missing}")

            self._config = config
            self.load_time = datetime.now()
        except Exception as e:
            st.error(f"配置加载失败: {str(e)}")
            st.stop()


config = AppConfig()


# --------------------------
# Vector database management
# --------------------------
class VectorDBManager:
    def __init__(self, path: str):
        self.path = path
        self.client = None
        self.collection = None

    def initialize(self):
        """带重试机制的数据库初始化"""
        for retry in range(3):
            try:
                self.client = chromadb.PersistentClient(path=self.path)
                self.collection = self.client.get_or_create_collection(
                    name="chat_history",
                    metadata={"hnsw:space": "cosine"}
                )
                return
            except Exception as e:
                if retry == 2:
                    raise RuntimeError(f"向量数据库初始化失败: {str(e)}")
                time.sleep(2 ** retry)


try:
    db_manager = VectorDBManager(config.config["vector_db_path"])
    db_manager.initialize()
except Exception as e:
    st.error(str(e))
    st.stop()


# --------------------------
# Embedding service
# --------------------------
class EmbeddingService:
    def __init__(self, base_url: str, model: str):
        self.base_url = base_url
        self.model = model
        self.cache = {}
        self.lock = asyncio.Lock()

    async def get_embedding(self, text: str) -> Optional[List[float]]:
        """带缓存和重试机制的嵌入获取"""
        sanitized_text = self._sanitize(text)

        # Cache lookup
        if sanitized_text in self.cache:
            return self.cache[sanitized_text]

        async with self.lock:
            for attempt in range(3):
                try:
                    async with aiohttp.ClientSession() as session:
                        async with session.post(
                                f"{self.base_url}/api/embeddings",
                                json={"model": self.model, "prompt": sanitized_text},
                                timeout=aiohttp.ClientTimeout(total=15)
                        ) as resp:
                            resp.raise_for_status()
                            data = await resp.json()
                            embedding = data.get("embedding")
                            if embedding:
                                self.cache[sanitized_text] = embedding
                                return embedding
                except Exception as e:
                    if attempt == 2:
                        st.error(f"嵌入获取失败: {str(e)}")
                    await asyncio.sleep(1.5 ** attempt)
        return None

    @staticmethod
    def _sanitize(text: str) -> str:
        """输入文本安全处理"""
        return re.sub(r'[<>{}`]', '', text).strip()[:1000]


# --------------------------
# Document processing pipeline
# --------------------------
class DocumentProcessor:
    @staticmethod
    async def process_files(files: List[UploadedFile], embed_service: EmbeddingService):
        """并行处理多个文件"""
        semaphore = asyncio.Semaphore(3)  # 控制并发数

        async def process_file(file: UploadedFile):
            async with semaphore:
                try:
                    content = await DocumentParser.parse(file)
                    for chunk in TextSplitter.split(content):
                        if embedding := await embed_service.get_embedding(chunk):
                            db_manager.collection.add(
                                ids=[f"doc_{uuid4()}"],
                                embeddings=[embedding],
                                documents=[chunk],
                                metadatas=[{
                                    "source": file.name,
                                    "type": "uploaded_doc",
                                    "timestamp": datetime.now().isoformat()
                                }]
                            )
                    st.toast(f"✅ 成功处理: {file.name}")
                except Exception as e:
                    st.error(f"❌ 处理失败 {file.name}: {str(e)}")

        await asyncio.gather(*[process_file(f) for f in files])


class DocumentParser:
    @staticmethod
    async def parse(file: UploadedFile) -> str:
        """异步解析文档内容"""
        loop = asyncio.get_event_loop()
        return await loop.run_in_executor(None, DocumentParser._sync_parse, file)

    @staticmethod
    def _sync_parse(file: UploadedFile) -> str:
        """同步解析实现"""
        content_bytes = file.getvalue()
        encodings = ['utf-8', 'gbk', 'iso-8859-1']

        try:
            if file.type == "application/pdf":
                return "\n".join(p.extract_text() for p in PdfReader(BytesIO(content_bytes)).pages)
            if file.type == "text/plain":
                for enc in encodings:
                    try:
                        return content_bytes.decode(enc)
                    except UnicodeDecodeError:
                        continue
            if file.type == "application/vnd.openxmlformats-officedocument.wordprocessingml.document":
                return "\n".join(p.text for p in Document(BytesIO(content_bytes)).paragraphs)
        except Exception as e:
            raise ValueError(f"文档解析失败: {str(e)}")
        raise ValueError("不支持的文档格式")


class TextSplitter:
    @staticmethod
    def split(text: str, chunk_size=500, overlap=100) -> List[str]:
        """智能文本分块"""
        sentences = re.split(r'(?<=[。!?])', text)
        chunks = []
        current_chunk = []
        current_length = 0

        for sentence in sentences:
            sentence = sentence.strip()
            if not sentence:
                continue
            slen = len(sentence)

            if current_length + slen > chunk_size:
                chunks.append("".join(current_chunk))
                current_chunk = current_chunk[-int(overlap / 20):]  # keep a few trailing sentences as overlap
                current_length = sum(len(s) for s in current_chunk)

            current_chunk.append(sentence)
            current_length += slen

        if current_chunk:
            chunks.append("".join(current_chunk))

        return chunks


# --------------------------
# LLM interaction
# --------------------------
class ChatService:
    def __init__(self, base_url: str):
        self.base_url = base_url
        self.client = openai.AsyncOpenAI(base_url=f"{base_url}/v1", api_key="no-key-required")

    async def stream_response(self, messages: List[Dict], model: str, temperature: float) -> AsyncGenerator[str, None]:
        """Stream the model response chunk by chunk."""
        try:
            stream = await self.client.chat.completions.create(
                model=model,
                messages=messages,
                temperature=temperature,
                stream=True
            )

            full_response = ""
            async for chunk in stream:
                if content := chunk.choices[0].delta.content:
                    full_response += content
                    yield content

            # Save the assistant reply; the user message is appended by the caller in main()
            st.session_state.history.append(
                {"role": "assistant", "content": full_response}
            )

        except Exception as e:
            yield f"⚠️ 生成错误: {str(e)}"


# --------------------------
# UI components
# --------------------------
class UIComponents:
    @staticmethod
    def setup_page():
        st.set_page_config(
            page_title="智能研究助手",
            page_icon="🧠",
            layout="wide",
            initial_sidebar_state="expanded"
        )
        st.markdown("""
        <style>
        [data-testid="stSidebar"] {
            background: #f5f7fb !important;
        }
        .stChatFloatingInputContainer {
            bottom: 20px;
            padding: 1rem;
            background: white;
            box-shadow: 0 4px 12px rgba(0,0,0,0.1);
            border-radius: 12px;
        }
        </style>
        """, unsafe_allow_html=True)

    @staticmethod
    def model_status(service_ok: bool):
        status_color = "#4CAF50" if service_ok else "#FF5252"
        st.sidebar.markdown(
            f'<div style="padding: 8px; background: {status_color}; color: white; border-radius: 4px;">'
            f'服务状态: {"正常" if service_ok else "异常"}'
            '</div>',
            unsafe_allow_html=True
        )

    @staticmethod
    def chat_input_area():
        with st.container():
            cols = st.columns([0.85, 0.15])
            with cols[0]:
                prompt = st.text_input(label="", placeholder="输入您的问题...", key="input",
                                       label_visibility="collapsed")
            with cols[1]:
                if st.button("发送", use_container_width=True):
                    return prompt
        return None


# --------------------------
# Main application logic
# --------------------------
async def fetch_models_from_ollama(base_url: str) -> Tuple[Dict[str, str], Dict[str, str]]:
    """从Ollama API获取模型列表,并区分嵌入模型和聊天模型"""
    try:
        response = requests.get(f"{base_url}/v1/models")
        response.raise_for_status()
        data = response.json()
        # 假设响应结构为 {"data": [{"id": "model_id", "object": "model", "created": ..., "owned_by": ...}, ...]}
        models = data.get('data', [])
        embed_models = {model['id']: model.get('description', model['id']) for model in models if
                        "embed" in model['id']}
        chat_models = {model['id']: model.get('description', model['id']) for model in models if
                       "embed" not in model['id']}
        if not embed_models:
            st.warning("API 响应中未找到任何嵌入模型,请检查 API 文档或服务器配置。")
        if not chat_models:
            st.warning("API 响应中未找到任何聊天模型,请检查 API 文档或服务器配置。")
        return embed_models, chat_models
    except requests.exceptions.HTTPError as http_err:
        st.error(f"HTTP 错误: {http_err}")
    except requests.exceptions.RequestException as req_err:
        st.error(f"请求错误: {req_err}")
    except Exception as e:
        st.error(f"无法获取模型列表: {str(e)}")
    return {}, {}


async def main():
    # Initialize the UI
    UIComponents.setup_page()

    # Initialize service components
    ollama_config = config.config["ollama"]
    embed_base_url = ollama_config["base_url"]
    embed_model = config.config["embed_model"]
    chat_base_url = ollama_config["base_url"]  # assume the chat model uses the same base_url

    # Fetch the model list
    embed_models, chat_models = await fetch_models_from_ollama(embed_base_url)

    # Initialize session state
    if "history" not in st.session_state:
        st.session_state.update({
            "history": [],
            "current_embed_model": embed_model if embed_model in embed_models else next(
                iter(embed_models)) if embed_models else None,
            "current_chat_model": next(iter(chat_models)) if chat_models else None,
            "temperature": 0.7,
            "processing": False,
            "uploaded_files": []
        })

    # Sidebar components
    with st.sidebar:
        st.title("设置")
        UIComponents.model_status(True)

        # File upload
        uploaded_files = st.file_uploader(
            "上传研究文档",
            type=["pdf", "docx", "txt"],
            accept_multiple_files=True,
            key="file_uploader"
        )
        if uploaded_files and st.button("开始处理"):
            st.session_state.uploaded_files = uploaded_files
            with st.spinner('正在处理文档...'):
                embed_service = EmbeddingService(embed_base_url, st.session_state.current_embed_model)
                await DocumentProcessor.process_files(list(uploaded_files), embed_service)

        # Model parameters
        st.slider("温度参数", 0.0, 1.0, st.session_state.temperature, key="temp_slider",
                  on_change=lambda: st.session_state.__setitem__("temperature", st.session_state.temp_slider))

        # Model selection
        if chat_models:
            selected_chat_model = st.selectbox("选择聊天模型", options=list(chat_models.keys()), index=0,
                                               key="chat_model_selector")
            if selected_chat_model != st.session_state.current_chat_model:
                st.session_state.current_chat_model = selected_chat_model
        else:
            st.warning("没有可用的聊天模型,请检查Ollama服务器配置。")

    # Main interface
    st.title("🧠 智能研究助手")

    # Display conversation history
    for msg in st.session_state.history:
        with st.chat_message(msg["role"], avatar=(
        "https://img.alicdn.com/tfs/TB1oYRYwUT1gK0jSZFhXXaAtVXa-16-16.png" if msg[
                                                                                  "role"] == "user" else "https://img.alicdn.com/tfs/TB1ZLrwuET1gK0jSZSyXXXtlpXa-16-16.png")):
            st.markdown(msg["content"])

    # Handle user input
    if prompt := UIComponents.chat_input_area():
        if st.session_state.current_chat_model:
            st.session_state.history.append({"role": "user", "content": prompt})

            # Show the user message
            with st.chat_message("user", avatar="https://img.alicdn.com/tfs/TB1oYRYwUT1gK0jSZFhXXaAtVXa-16-16.png"):
                st.markdown(prompt)

            # Generate the response
            async def generate_and_show():
                chat_service = ChatService(chat_base_url)
                with st.chat_message("assistant",
                                     avatar="https://img.alicdn.com/tfs/TB1ZLrwuET1gK0jSZSyXXXtlpXa-16-16.png") as assistant_msg:
                    placeholder = st.empty()
                    full_response = ""

                    try:
                        async for chunk in chat_service.stream_response(
                                messages=st.session_state.history,
                                model=st.session_state.current_chat_model,
                                temperature=st.session_state.temperature
                        ):
                            full_response += chunk
                            placeholder.markdown(full_response + "▌")

                        placeholder.markdown(full_response)
                    except Exception as e:
                        placeholder.error(f"生成失败: {str(e)}")

            await generate_and_show()
        else:
            st.warning("请选择一个有效的聊天模型。")


if __name__ == "__main__":
    asyncio.run(main())
```

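With `config.yaml` pointing at a running Ollama instance (see the deployment notes below), the script is started like any other Streamlit app. The file name `app.py` is just a placeholder for whatever the code above is saved as:

```bash
streamlit run app.py
```
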
### How to deploy the Ollama model locally

#### Deploying Ollama with Docker

For anyone who wants to get a local large-model service up and running quickly, Ollama offers a straightforward way to deploy it on a local machine. For Windows users in particular, running it in a Docker container is the recommended approach.

#### Prerequisites

Make sure Docker is installed and that the Windows Subsystem for Linux (WSL) is configured; on the Windows platform this is essential[^2].

#### Pulling the image

Several image variants are available to suit different hardware. If the target machine has no GPU support, pull the CPU build of the official image:

```bash
docker pull ollama/ollama:0.3.7-rc6
```

This command downloads the image at the specified tag (`0.3.7-rc6`) for the CPU architecture to your machine.

#### Starting the service

Once the image is available, create a container from it and map its API port to the host so the interface can be reached (the Ollama server inside the container listens on port 11434 by default):

```bash
docker run -d --name my_ollama_service -p 11434:11434 ollama/ollama:0.3.7-rc6
```

In the command above, `-d` runs the container in the background, `--name` sets the container's name, the final argument selects the image to run, and the `-p` option maps the port the Ollama API listens on inside the container to the corresponding port on the host.

As long as the host name resolves and no firewall blocks the port, the running service can then be called directly at http://localhost:11434.
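
After the container is up and before launching the Streamlit app, it is worth confirming that the endpoint configured in `config.yaml` is reachable. A minimal check, hitting the same `/v1/models` route that `fetch_models_from_ollama` queries above (the base URL is an example and should match your `ollama.base_url`):

```python
import requests

# Quick connectivity check against a local Ollama instance.
base_url = "http://localhost:11434"  # example; must match ollama.base_url in config.yaml

resp = requests.get(f"{base_url}/v1/models", timeout=10)
resp.raise_for_status()

# Print the IDs of all models the server exposes (chat and embedding models alike).
print([m["id"] for m in resp.json().get("data", [])])
```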