摘要
在人工智能技术快速发展的今天,将AI模型转化为用户友好的应用程序是开发者面临的重要挑战。Streamlit作为一个专为机器学习和数据科学团队设计的开源Python库,让开发者能够快速创建美观、功能强大的数据应用界面,而无需深入学习前端技术。本文将详细介绍如何使用Streamlit构建AI应用界面,涵盖从基础组件使用到高级交互设计的完整流程。通过丰富的实践示例、架构图、流程图等可视化内容,帮助中国开发者快速掌握Streamlit在AI应用开发中的应用,提升开发效率和用户体验。
思维导图:Streamlit AI应用开发知识点全景

mindmap
root((Streamlit AI应用))
基础概念
Streamlit特性
简单易用
实时交互
自动重载
响应式设计
核心组件
文本组件
数据组件
图表组件
媒体组件
核心功能
用户界面
输入控件
布局管理
主题定制
数据可视化
图表展示
地图集成
实时更新
AI模型集成
模型加载
预测展示
结果解释
高级特性
缓存机制
数据缓存
资源缓存
自定义缓存
性能优化
异步处理
批量操作
进度显示
安全特性
认证集成
权限控制
输入验证
部署发布
本地部署
开发环境
测试环境
云端部署
Streamlit Cloud
Heroku
Docker部署
实践案例
文本分析应用
图像识别应用
数据可视化应用
推荐系统应用
最佳实践
代码组织
错误处理
用户体验
性能调优
1. Streamlit与AI应用概述
1.1 为什么选择Streamlit
Streamlit是专为数据科学和机器学习团队设计的现代应用框架,具有以下显著优势:
- 简单易学:使用纯Python编写,无需HTML、CSS或JavaScript知识
- 快速开发:几行代码即可创建交互式应用界面
- 实时交互:变量自动更新,无需手动刷新
- 丰富的组件:内置文本、数据、图表、媒体等多种组件
- 自动重载:代码保存时自动刷新应用
- 响应式设计:自动适配不同设备屏幕
1.2 Streamlit应用架构
2. 环境搭建与基础应用
2.1 环境准备
首先,我们需要安装Streamlit及相关依赖:
# 创建虚拟环境
python -m venv streamlit-ai-env
source streamlit-ai-env/bin/activate # Linux/Mac
# streamlit-ai-env\Scripts\activate # Windows
# 安装核心依赖
pip install streamlit pandas numpy matplotlib seaborn scikit-learn
# 安装额外工具
pip install plotly altair pillow opencv-python
2.2 基础Streamlit应用
让我们创建一个基础的Streamlit应用来了解其工作原理:
# basic_app.py
import streamlit as st
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from datetime import datetime
import time
# 设置页面配置
st.set_page_config(
page_title="AI应用基础示例",
page_icon="🤖",
layout="wide",
initial_sidebar_state="expanded"
)
# 页面标题
st.title("🤖 Streamlit AI应用基础示例")
st.markdown("---")
# 侧边栏
st.sidebar.header("应用设置")
app_mode = st.sidebar.selectbox(
"选择应用模式",
["主页", "数据探索", "模型演示", "关于"]
)
# 主页内容
if app_mode == "主页":
st.header("欢迎使用Streamlit AI应用")
# 显示应用介绍
st.markdown("""
这是一个使用Streamlit构建的AI应用示例,展示了以下功能:
- 📊 数据可视化和探索
- 🤖 机器学习模型演示
- 🎛️ 交互式用户界面
- 📈 实时结果展示
请在左侧选择不同的应用模式来体验各项功能。
""")
# 显示应用截图或示例图片
st.image(
"https://streamlit.io/images/brand/streamlit-logo-secondary-colormark-darktext.png",
caption="Streamlit Logo",
width=300
)
# 显示一些统计信息
col1, col2, col3 = st.columns(3)
col1.metric("应用版本", "1.0.0")
col2.metric("Python版本", "3.9+")
col3.metric("最后更新", datetime.now().strftime("%Y-%m-%d"))
# 数据探索模式
elif app_mode == "数据探索":
st.header("📊 数据探索")
# 生成示例数据
@st.cache_data
def load_data():
"""加载示例数据"""
data = pd.DataFrame({
'日期': pd.date_range(start='2023-01-01', periods=100, freq='D'),
'销售额': np.random.randn(100).cumsum() + 100,
'访问量': np.random.randint(50, 200, 100),
'转化率': np.random.uniform(0.01, 0.1, 100),
'产品类别': np.random.choice(['A', 'B', 'C', 'D'], 100)
})
return data
# 加载数据
with st.spinner("正在加载数据..."):
df = load_data()
# 显示数据基本信息
st.subheader("数据概览")
st.write(f"数据形状: {df.shape}")
st.dataframe(df.head(10))
# 数据统计
st.subheader("数据统计")
st.write(df.describe())
# 数据可视化
st.subheader("数据可视化")
# 创建多个图表
tab1, tab2, tab3 = st.tabs(["时间序列", "分布图", "相关性"])
with tab1:
# 时间序列图
fig, ax = plt.subplots(figsize=(10, 4))
ax.plot(df['日期'], df['销售额'])
ax.set_title("销售额时间序列")
ax.set_xlabel("日期")
ax.set_ylabel("销售额")
st.pyplot(fig)
with tab2:
# 分布图
fig, ax = plt.subplots(figsize=(10, 4))
ax.hist(df['转化率'], bins=20, alpha=0.7)
ax.set_title("转化率分布")
ax.set_xlabel("转化率")
ax.set_ylabel("频次")
st.pyplot(fig)
with tab3:
# 相关性热力图
corr_matrix = df[['销售额', '访问量', '转化率']].corr()
fig, ax = plt.subplots(figsize=(8, 6))
sns.heatmap(corr_matrix, annot=True, cmap='coolwarm', ax=ax)
ax.set_title("变量相关性热力图")
st.pyplot(fig)
# 模型演示模式
elif app_mode == "模型演示":
st.header("🤖 机器学习模型演示")
# 选择模型类型
model_type = st.selectbox(
"选择模型类型",
["线性回归", "决策树", "随机森林", "支持向量机"]
)
# 模型参数设置
st.subheader("模型参数")
col1, col2 = st.columns(2)
with col1:
train_size = st.slider("训练集比例", 0.1, 0.9, 0.8, 0.1)
max_depth = st.slider("最大深度", 1, 20, 5)
with col2:
n_estimators = st.slider("估计器数量", 10, 200, 100, 10)
random_state = st.number_input("随机种子", 0, 1000, 42)
# 生成示例数据
@st.cache_data
def generate_regression_data(n_samples=1000):
"""生成回归示例数据"""
np.random.seed(42)
X = np.random.randn(n_samples, 3)
y = 2*X[:, 0] + 3*X[:, 1] - X[:, 2] + np.random.randn(n_samples)*0.1
return X, y
# 训练模型
if st.button("训练模型"):
with st.spinner("正在训练模型..."):
# 模拟训练过程
progress_bar = st.progress(0)
for i in range(100):
time.sleep(0.01) # 模拟训练时间
progress_bar.progress(i + 1)
# 生成示例结果
X, y = generate_regression_data()
train_size_int = int(len(X) * train_size)
st.success("模型训练完成!")
# 显示训练结果
st.subheader("训练结果")
col1, col2, col3 = st.columns(3)
col1.metric("训练样本数", train_size_int)
col2.metric("测试样本数", len(X) - train_size_int)
col3.metric("模型准确率", f"{np.random.uniform(0.85, 0.95):.2%}")
# 显示预测结果图表
st.subheader("预测结果对比")
fig, ax = plt.subplots(figsize=(10, 6))
# 生成预测值(模拟)
y_pred = y + np.random.randn(len(y)) * 0.1
ax.scatter(y[:100], y_pred[:100], alpha=0.6)
ax.plot([y.min(), y.max()], [y.min(), y.max()], 'r--', lw=2)
ax.set_xlabel("真实值")
ax.set_ylabel("预测值")
ax.set_title("预测值 vs 真实值")
st.pyplot(fig)
# 关于页面
else:
st.header("ℹ️ 关于")
st.markdown("""
### 应用信息
- **开发者**: AI应用开发团队
- **版本**: 1.0.0
- **技术栈**:
- Python 3.9+
- Streamlit
- Pandas, NumPy
- Scikit-learn
- Matplotlib, Seaborn
### 功能特性
1. **交互式界面**: 提供友好的用户交互体验
2. **数据可视化**: 支持多种图表展示方式
3. **模型演示**: 展示机器学习模型的工作原理
4. **实时更新**: 数据和结果实时更新
### 联系方式
如有任何问题或建议,请联系:
- 邮箱: ai-app@example.com
- GitHub: github.com/ai-app-demo
""")
# 显示技术栈图标
st.image(
"https://upload.wikimedia.org/wikipedia/commons/c/c3/Python-logo-notext.svg",
caption="Python",
width=100
)
# 页脚
st.markdown("---")
st.markdown("© 2025 AI应用开发团队. 保留所有权利.")
3. AI模型集成实践
3.1 文本分类应用
# text_classifier_app.py
import streamlit as st
import pandas as pd
import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.naive_bayes import MultinomialNB
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report
import matplotlib.pyplot as plt
import seaborn as sns
from wordcloud import WordCloud
import io
import base64
# 设置页面配置
st.set_page_config(
page_title="文本分类应用",
page_icon="📝",
layout="wide"
)
# 页面标题
st.title("📝 文本情感分类应用")
st.markdown("---")
# 侧边栏配置
st.sidebar.header("应用配置")
# 数据源选择
data_source = st.sidebar.radio(
"数据源",
["示例数据", "上传数据"]
)
# 模型参数
st.sidebar.subheader("模型参数")
test_size = st.sidebar.slider("测试集比例", 0.1, 0.5, 0.2, 0.1)
max_features = st.sidebar.slider("最大特征数", 1000, 10000, 5000, 1000)
random_state = st.sidebar.number_input("随机种子", 0, 100, 42)
# 加载数据函数
@st.cache_data
def load_sample_data():
"""加载示例数据"""
# 创建示例数据
texts = [
"这部电影真是太棒了,演员表演出色,剧情引人入胜",
"我觉得这部电影很无聊,浪费时间",
"画面精美,特效震撼,值得一看",
"剧情拖沓,演员表演生硬",
"这是我看过最好的电影之一,强烈推荐",
"故事情节老套,没有新意",
"音乐优美,情感丰富,触动人心",
"节奏太慢,看得昏昏欲睡",
"导演功力深厚,作品质量很高",
"台词尴尬,演技浮夸,难以入戏"
] * 100 # 扩展数据集
labels = [1, 0, 1, 0, 1, 0, 1, 0, 1, 0] * 100 # 1表示正面,0表示负面
return pd.DataFrame({
'text': texts,
'label': labels
})
@st.cache_data
def load_uploaded_data(uploaded_file):
"""加载上传的数据"""
if uploaded_file is not None:
df = pd.read_csv(uploaded_file)
return df
return None
# 训练模型函数
@st.cache_resource
def train_model(df, test_size, max_features, random_state):
"""训练文本分类模型"""
# 准备数据
X = df['text']
y = df['label']
# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=test_size, random_state=random_state
)
# 文本向量化
vectorizer = TfidfVectorizer(max_features=max_features, stop_words='english')
X_train_vec = vectorizer.fit_transform(X_train)
X_test_vec = vectorizer.transform(X_test)
# 训练模型
model = MultinomialNB()
model.fit(X_train_vec, y_train)
# 预测
y_pred = model.predict(X_test_vec)
accuracy = accuracy_score(y_test, y_pred)
return model, vectorizer, X_test, y_test, y_pred, accuracy
# 主要应用逻辑
if data_source == "示例数据":
df = load_sample_data()
st.info("正在使用示例数据集")
else:
uploaded_file = st.sidebar.file_uploader("上传CSV文件", type=["csv"])
if uploaded_file is not None:
df = load_uploaded_data(uploaded_file)
st.success("数据上传成功")
else:
st.warning("请上传CSV文件,文件应包含'text'和'label'列")
st.stop()
# 显示数据信息
st.subheader("数据概览")
st.write(f"数据集大小: {len(df)} 条记录")
st.write(f"正样本数: {sum(df['label'])}")
st.write(f"负样本数: {len(df) - sum(df['label'])}")
# 显示前几条数据
st.dataframe(df.head(10))
# 训练模型
if st.button("训练模型"):
with st.spinner("正在训练模型..."):
model, vectorizer, X_test, y_test, y_pred, accuracy = train_model(
df, test_size, max_features, random_state
)
st.success("模型训练完成!")
# 显示模型性能
st.subheader("模型性能")
col1, col2, col3 = st.columns(3)
col1.metric("准确率", f"{accuracy:.2%}")
col2.metric("测试集大小", len(X_test))
col3.metric("特征数量", max_features)
# 分类报告
st.subheader("详细分类报告")
report = classification_report(y_test, y_pred, output_dict=True)
report_df = pd.DataFrame(report).transpose()
st.dataframe(report_df)
# 混淆矩阵
st.subheader("混淆矩阵")
from sklearn.metrics import confusion_matrix
cm = confusion_matrix(y_test, y_pred)
fig, ax = plt.subplots(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=ax)
ax.set_xlabel('预测标签')
ax.set_ylabel('真实标签')
ax.set_title('混淆矩阵')
st.pyplot(fig)
# 词云图
st.subheader("词云图分析")
tab1, tab2 = st.tabs(["正面情感", "负面情感"])
with tab1:
positive_texts = df[df['label'] == 1]['text'].tolist()
if positive_texts:
positive_text = ' '.join(positive_texts)
wordcloud = WordCloud(width=800, height=400, background_color='white').generate(positive_text)
fig, ax = plt.subplots(figsize=(10, 5))
ax.imshow(wordcloud, interpolation='bilinear')
ax.axis('off')
ax.set_title('正面情感词云')
st.pyplot(fig)
with tab2:
negative_texts = df[df['label'] == 0]['text'].tolist()
if negative_texts:
negative_text = ' '.join(negative_texts)
wordcloud = WordCloud(width=800, height=400, background_color='white').generate(negative_text)
fig, ax = plt.subplots(figsize=(10, 5))
ax.imshow(wordcloud, interpolation='bilinear')
ax.axis('off')
ax.set_title('负面情感词云')
st.pyplot(fig)
# 实时预测功能
st.subheader("实时预测")
user_input = st.text_area("输入要分类的文本:", height=100)
if st.button("预测情感"):
if user_input:
# 向量化用户输入
user_vec = vectorizer.transform([user_input])
# 预测
prediction = model.predict(user_vec)[0]
probability = model.predict_proba(user_vec)[0]
# 显示结果
if prediction == 1:
st.success(f"预测结果: 正面情感 (置信度: {probability[1]:.2%})")
else:
st.error(f"预测结果: 负面情感 (置信度: {probability[0]:.2%})")
else:
st.warning("请输入要分类的文本")
# 使用说明
with st.expander("应用查看说明"):
st.markdown("""
### 应用功能说明
1. **数据源选择**:
- 示例数据: 使用内置的示例数据集
- 上传数据: 上传自己的CSV文件(需包含text和label列)
2. **模型参数**:
- 测试集比例: 控制训练集和测试集的划分比例
- 最大特征数: TF-IDF向量化的最大特征数量
- 随机种子: 确保实验结果可重现
3. **模型训练**:
- 点击"训练模型"按钮开始训练
- 训练完成后会显示性能指标和可视化结果
4. **实时预测**:
- 在训练完成后,可以输入文本进行实时情感分类
### 数据格式要求
上传的CSV文件应包含以下列:
- `text`: 文本内容
- `label`: 标签(1表示正面,0表示负面)
""")
3.2 图像识别应用
# image_classifier_app.py
import streamlit as st
import numpy as np
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt
import seaborn as sns
from io import BytesIO
import base64
import time
# 设置页面配置
st.set_page_config(
page_title="图像识别应用",
page_icon="🖼️",
layout="wide"
)
# 页面标题
st.title("🖼️ 图像识别应用")
st.markdown("---")
# 侧边栏配置
st.sidebar.header("识别设置")
# 模型选择
model_type = st.sidebar.selectbox(
"选择模型",
["CNN基础模型", "预训练ResNet", "预训练VGG", "自定义模型"]
)
# 识别参数
confidence_threshold = st.sidebar.slider("置信度阈值", 0.0, 1.0, 0.5, 0.1)
top_k = st.sidebar.slider("显示前K个结果", 1, 10, 5)
# 模拟图像分类函数
def classify_image(image, model_type):
"""模拟图像分类"""
# 模拟分类过程
time.sleep(2)
# 模拟分类结果
classes = ["猫", "狗", "鸟", "汽车", "飞机", "船只", "花卉", "水果", "人物", "建筑"]
probabilities = np.random.dirichlet(np.ones(len(classes)), size=1)[0]
# 按概率排序
sorted_indices = np.argsort(probabilities)[::-1]
results = [(classes[i], probabilities[i]) for i in sorted_indices[:top_k]]
return results
# 主要应用逻辑
tab1, tab2 = st.tabs(["图像上传", "实时摄像头"])
with tab1:
st.header("图像上传识别")
# 图像上传
uploaded_file = st.file_uploader(
"选择图像文件",
type=["jpg", "jpeg", "png", "bmp", "tiff"]
)
if uploaded_file is not None:
# 显示上传的图像
image = Image.open(uploaded_file)
st.image(image, caption="上传的图像", use_column_width=True)
# 显示图像信息
st.subheader("图像信息")
col1, col2, col3 = st.columns(3)
col1.metric("宽度", f"{image.width}px")
col2.metric("高度", f"{image.height}px")
col3.metric("格式", image.format)
# 识别按钮
if st.button("开始识别"):
with st.spinner("正在识别图像..."):
# 模拟识别过程
progress_bar = st.progress(0)
for i in range(100):
time.sleep(0.02)
progress_bar.progress(i + 1)
# 获取识别结果
results = classify_image(image, model_type)
st.success("识别完成!")
# 显示识别结果
st.subheader("识别结果")
for i, (class_name, probability) in enumerate(results, 1):
if probability >= confidence_threshold:
st.markdown(f"**{i}. {class_name}**: {probability:.2%}")
# 进度条显示置信度
st.progress(float(probability))
else:
st.markdown(f"*{i}. {class_name}*: {probability:.2%} (低于阈值)")
# 可视化结果
st.subheader("结果可视化")
fig, ax = plt.subplots(figsize=(10, 6))
class_names = [r[0] for r in results]
probabilities = [r[1] for r in results]
bars = ax.bar(range(len(class_names)), probabilities)
ax.set_xlabel("类别")
ax.set_ylabel("置信度")
ax.set_title("图像识别结果")
ax.set_xticks(range(len(class_names)))
ax.set_xticklabels(class_names, rotation=45)
# 标注置信度数值
for i, (bar, prob) in enumerate(zip(bars, probabilities)):
ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.01,
f'{prob:.2%}', ha='center', va='bottom')
st.pyplot(fig)
with tab2:
st.header("实时摄像头识别")
st.info("注意:出于安全考虑,Streamlit Web应用通常无法直接访问本地摄像头。此功能需要在本地运行环境中使用。")
st.markdown("""
### 本地运行说明
要使用摄像头功能,请在本地环境中运行此应用:
```bash
# 安装额外依赖
pip install opencv-python
# 运行应用
streamlit run image_classifier_app.py --server.enableCORS=false --server.enableXsrfProtection=false
```
然后在应用中启用摄像头访问权限。
""")
# 模型信息
st.subheader("模型信息")
model_info = {
"CNN基础模型": "简单的卷积神经网络,适合基础图像分类任务",
"预训练ResNet": "基于ImageNet预训练的ResNet模型,具有优秀的特征提取能力",
"预训练VGG": "经典的VGG网络结构,适合细粒度图像分类",
"自定义模型": "用户自定义的模型架构"
}
st.info(model_info[model_type])
# 性能比较
st.subheader("模型性能比较")
performance_data = pd.DataFrame({
'模型': ['CNN基础模型', '预训练ResNet', '预训练VGG', '自定义模型'],
'准确率': [0.75, 0.92, 0.88, 0.85],
'推理时间(ms)': [50, 120, 150, 80],
'模型大小(MB)': [15, 90, 500, 30]
})
st.dataframe(performance_data)
# 可视化性能比较
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
# 准确率比较
ax1.bar(performance_data['模型'], performance_data['准确率'])
ax1.set_title('模型准确率比较')
ax1.set_ylabel('准确率')
ax1.tick_params(axis='x', rotation=45)
# 推理时间比较
ax2.bar(performance_data['模型'], performance_data['推理时间(ms)'])
ax2.set_title('模型推理时间比较')
ax2.set_ylabel('时间(ms)')
ax2.tick_params(axis='x', rotation=45)
plt.tight_layout()
st.pyplot(fig)
4. 高级功能实现
4.1 自定义组件和主题
# advanced_features_app.py
import streamlit as st
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import plotly.express as px
import plotly.graph_objects as go
from datetime import datetime, timedelta
import json
import time
# 设置自定义主题
st.set_page_config(
page_title="高级功能演示",
page_icon="✨",
layout="wide",
initial_sidebar_state="expanded"
)
# 自定义CSS样式
st.markdown("""
<style>
.stApp {
background-color: #f0f2f6;
}
.css-1d391kg {
background-color: #ffffff;
}
.stButton>button {
background-color: #4CAF50;
color: white;
border-radius: 10px;
border: none;
padding: 10px 20px;
font-size: 16px;
transition: all 0.3s;
}
.stButton>button:hover {
background-color: #45a049;
transform: scale(1.05);
}
.success-box {
padding: 10px;
border-radius: 5px;
background-color: #d4edda;
border: 1px solid #c3e6cb;
color: #155724;
}
.warning-box {
padding: 10px;
border-radius: 5px;
background-color: #fff3cd;
border: 1px solid #ffeaa7;
color: #856404;
}
</style>
""", unsafe_allow_html=True)
# 页面标题
st.title("✨ Streamlit高级功能演示")
st.markdown("---")
# 自定义组件演示
st.header("🎨 自定义组件演示")
# 使用columns创建布局
col1, col2, col3 = st.columns(3)
with col1:
st.subheader("卡片1")
st.markdown('<div class="success-box">这是一个自定义样式的成功消息框</div>', unsafe_allow_html=True)
if st.button("按钮1"):
st.info("按钮1被点击了")
with col2:
st.subheader("卡片2")
st.markdown('<div class="warning-box">这是一个自定义样式的警告消息框</div>', unsafe_allow_html=True)
if st.button("按钮2"):
st.success("按钮2被点击了")
with col3:
st.subheader("卡片3")
st.error("这是一个错误消息")
if st.button("按钮3"):
st.warning("按钮3被点击了")
# 交互式图表
st.header("📊 交互式图表")
# 生成示例数据
@st.cache_data
def generate_time_series_data():
"""生成时间序列数据"""
dates = pd.date_range(start='2023-01-01', end='2023-12-31', freq='D')
values = np.cumsum(np.random.randn(len(dates))) + 100
return pd.DataFrame({'date': dates, 'value': values})
df = generate_time_series_data()
# 使用Plotly创建交互式图表
fig = px.line(df, x='date', y='value', title='交互式时间序列图')
fig.update_layout(
xaxis_title="日期",
yaxis_title="数值",
hovermode='x unified'
)
st.plotly_chart(fig, use_container_width=True)
# 3D图表
st.subheader("3D散点图")
np.random.seed(42)
n_points = 1000
x = np.random.randn(n_points)
y = np.random.randn(n_points)
z = np.random.randn(n_points)
colors = np.random.randn(n_points)
fig_3d = go.Figure(data=[go.Scatter3d(
x=x, y=y, z=z,
mode='markers',
marker=dict(
size=5,
color=colors,
colorscale='Viridis',
showscale=True
)
)])
fig_3d.update_layout(
title="3D散点图演示",
scene=dict(
xaxis_title="X轴",
yaxis_title="Y轴",
zaxis_title="Z轴"
)
)
st.plotly_chart(fig_3d, use_container_width=True)
# 进度和状态管理
st.header("🔄 进度和状态管理")
# 进度条演示
if st.button("开始长时间任务"):
progress_text = st.empty()
progress_bar = st.progress(0)
for i in range(100):
time.sleep(0.05) # 模拟任务执行
progress_bar.progress(i + 1)
progress_text.text(f"任务进度: {i+1}%")
st.success("任务完成!")
# 状态管理演示
st.subheader("状态管理")
if 'counter' not in st.session_state:
st.session_state.counter = 0
col1, col2, col3 = st.columns(3)
with col1:
if st.button("增加计数"):
st.session_state.counter += 1
with col2:
if st.button("减少计数"):
st.session_state.counter -= 1
with col3:
if st.button("重置计数"):
st.session_state.counter = 0
st.metric("当前计数", st.session_state.counter)
# 缓存机制演示
st.header("⚡ 缓存机制")
@st.cache_data(ttl=3600) # 缓存1小时
def expensive_computation(n):
"""模拟耗时计算"""
time.sleep(2) # 模拟耗时操作
return sum(i**2 for i in range(n))
computation_input = st.number_input("输入计算参数", 1, 10000, 1000)
if st.button("执行耗时计算"):
with st.spinner("正在执行计算..."):
result = expensive_computation(computation_input)
st.success(f"计算结果: {result}")
# 多页面导航
st.header("🧭 多页面导航")
# 使用session_state管理页面状态
if 'current_page' not in st.session_state:
st.session_state.current_page = "首页"
# 页面导航
page = st.radio(
"选择页面",
["首页", "数据分析", "模型展示", "设置"],
horizontal=True
)
st.session_state.current_page = page
# 根据选择的页面显示内容
if st.session_state.current_page == "首页":
st.subheader("欢迎来到首页")
st.write("这是应用的主页,展示核心功能和最新动态。")
elif st.session_state.current_page == "数据分析":
st.subheader("数据分析页面")
st.write("在这里可以进行各种数据分析操作。")
# 生成示例图表
data = pd.DataFrame({
'category': ['A', 'B', 'C', 'D', 'E'],
'values': np.random.randint(10, 100, 5)
})
fig, ax = plt.subplots()
ax.bar(data['category'], data['values'])
ax.set_title("示例柱状图")
st.pyplot(fig)
elif st.session_state.current_page == "模型展示":
st.subheader("模型展示页面")
st.write("展示各种机器学习模型的性能和应用。")
# 模型性能对比
models = ['逻辑回归', '决策树', '随机森林', '支持向量机']
accuracy = [0.85, 0.82, 0.88, 0.86]
fig, ax = plt.subplots()
ax.bar(models, accuracy)
ax.set_title("模型准确率对比")
ax.set_ylabel("准确率")
plt.xticks(rotation=45)
st.pyplot(fig)
else: # 设置页面
st.subheader("设置页面")
st.write("在这里可以配置应用的各种参数。")
# 应用设置
theme = st.selectbox("选择主题", ["默认", "暗色", "明亮"])
language = st.selectbox("选择语言", ["中文", "英文"])
if st.button("保存设置"):
st.success("设置已保存")
# 文件上传和下载
st.header("📁 文件操作")
# 文件上传
uploaded_file = st.file_uploader("上传文件", type=["txt", "csv", "json"])
if uploaded_file is not None:
st.success(f"文件 {uploaded_file.name} 上传成功")
# 根据文件类型处理
if uploaded_file.name.endswith('.csv'):
df = pd.read_csv(uploaded_file)
st.dataframe(df.head())
elif uploaded_file.name.endswith('.json'):
data = json.load(uploaded_file)
st.json(data)
else:
content = uploaded_file.read().decode('utf-8')
st.text_area("文件内容", content, height=200)
# 文件下载
st.subheader("文件下载")
download_data = pd.DataFrame({
'name': ['Alice', 'Bob', 'Charlie'],
'age': [25, 30, 35],
'city': ['北京', '上海', '广州']
})
@st.cache_data
def convert_df(df):
return df.to_csv(index=False).encode('utf-8')
csv = convert_df(download_data)
st.download_button(
label="下载示例CSV文件",
data=csv,
file_name='sample_data.csv',
mime='text/csv'
)
# 实时数据更新
st.header("⏰ 实时数据更新")
# 使用empty创建动态更新区域
live_data_placeholder = st.empty()
# 模拟实时数据更新
if st.button("开始实时更新"):
for i in range(10):
# 生成新数据
new_data = pd.DataFrame({
'时间': [datetime.now().strftime("%H:%M:%S")],
'数值': [np.random.randint(1, 100)]
})
# 更新显示
live_data_placeholder.dataframe(new_data)
time.sleep(1)
st.success("实时更新完成")
# 帮助和文档
st.header("📘 帮助文档")
with st.expander("应用使用说明"):
st.markdown("""
### 应用功能说明
1. **自定义组件**: 展示如何使用CSS自定义组件样式
2. **交互式图表**: 使用Plotly创建交互式数据可视化
3. **进度管理**: 显示长时间任务的进度状态
4. **缓存机制**: 利用缓存提高应用性能
5. **多页面导航**: 实现应用内的页面切换
6. **文件操作**: 支持文件上传和下载
7. **实时更新**: 演示动态数据更新功能
### 最佳实践
- 合理使用缓存避免重复计算
- 使用session_state管理用户状态
- 为长时间任务提供进度反馈
- 优化图表渲染性能
- 提供清晰的用户指引
""")
# 页脚
st.markdown("---")
st.markdown("© 2025 Streamlit高级功能演示. 保留所有权利.")
4.2 数据缓存和性能优化
# performance_optimization_app.py
import streamlit as st
import pandas as pd
import numpy as np
import time
import functools
from datetime import datetime
import hashlib
# 设置页面配置
st.set_page_config(
page_title="性能优化演示",
page_icon="⚡",
layout="wide"
)
st.title("⚡ Streamlit性能优化演示")
st.markdown("---")
# 性能测试函数
def time_it(func):
"""性能测试装饰器"""
@functools.wraps(func)
def wrapper(*args, **kwargs):
start_time = time.time()
result = func(*args, **kwargs)
end_time = time.time()
execution_time = end_time - start_time
return result, execution_time
return wrapper
# 数据生成函数
@st.cache_data
def generate_large_dataset(size=100000):
"""生成大型数据集"""
np.random.seed(42)
data = {
'id': range(size),
'value1': np.random.randn(size),
'value2': np.random.randn(size),
'category': np.random.choice(['A', 'B', 'C', 'D'], size),
'date': pd.date_range('2020-01-01', periods=size, freq='1min')
}
return pd.DataFrame(data)
@st.cache_data
def expensive_calculation(data, operation='sum'):
"""耗时计算"""
time.sleep(2) # 模拟耗时操作
if operation == 'sum':
return data['value1'].sum()
elif operation == 'mean':
return data['value1'].mean()
elif operation == 'std':
return data['value1'].std()
# 资源缓存示例
@st.cache_resource
def load_model():
"""模拟模型加载"""
time.sleep(3) # 模拟模型加载时间
return {"model": "complex_model", "version": "1.0"}
# 主要内容
tab1, tab2, tab3 = st.tabs(["缓存演示", "性能对比", "最佳实践"])
with tab1:
st.header("キャッシング演示")
# 数据集大小选择
dataset_size = st.slider("选择数据集大小", 1000, 100000, 10000, 1000)
# 生成数据(带时间测量)
st.subheader("数据生成性能")
if st.button("生成数据"):
with st.spinner("正在生成数据..."):
data, gen_time = time_it(generate_large_dataset)(dataset_size)
st.success(f"数据生成完成,耗时: {gen_time:.4f}秒")
st.info(f"数据集大小: {len(data)} 行")
# 显示数据统计
st.subheader("数据统计")
st.write(data.describe())
# 计算操作
st.subheader("计算操作性能")
operation = st.selectbox("选择计算操作", ["sum", "mean", "std"])
if st.button("执行计算"):
with st.spinner("正在执行计算..."):
data = generate_large_dataset(dataset_size)
result, calc_time = time_it(expensive_calculation)(data, operation)
st.success(f"计算完成,耗时: {calc_time:.4f}秒")
st.metric(f"{operation.upper()}结果", f"{result:.4f}")
with tab2:
st.header("性能对比测试")
# 缓存vs无缓存对比
st.subheader("缓存性能对比")
test_size = st.slider("测试数据大小", 1000, 50000, 10000, 1000)
# 无缓存版本
def no_cache_generate(size):
np.random.seed(42)
return pd.DataFrame({
'x': np.random.randn(size),
'y': np.random.randn(size)
})
# 缓存版本
@st.cache_data
def cache_generate(size):
np.random.seed(42)
return pd.DataFrame({
'x': np.random.randn(size),
'y': np.random.randn(size)
})
if st.button("运行性能对比"):
# 测试无缓存版本
with st.spinner("测试无缓存版本..."):
_, no_cache_time = time_it(no_cache_generate)(test_size)
# 测试缓存版本(第一次)
with st.spinner("测试缓存版本(首次)..."):
_, first_cache_time = time_it(cache_generate)(test_size)
# 测试缓存版本(第二次)
with st.spinner("测试缓存版本(缓存命中)..."):
_, second_cache_time = time_it(cache_generate)(test_size)
# 显示结果
st.subheader("性能测试结果")
performance_data = pd.DataFrame({
'方法': ['无缓存', '缓存(首次)', '缓存(命中)'],
'耗时(秒)': [no_cache_time, first_cache_time, second_cache_time]
})
st.dataframe(performance_data)
# 可视化对比
import matplotlib.pyplot as plt
fig, ax = plt.subplots(figsize=(10, 6))
bars = ax.bar(performance_data['方法'], performance_data['耗时(秒)'])
ax.set_ylabel('耗时(秒)')
ax.set_title('性能对比')
# 添加数值标签
for bar in bars:
height = bar.get_height()
ax.annotate(f'{height:.4f}',
xy=(bar.get_x() + bar.get_width() / 2, height),
xytext=(0, 3),
textcoords="offset points",
ha='center', va='bottom')
st.pyplot(fig)
# 计算性能提升
if second_cache_time > 0:
speedup = no_cache_time / second_cache_time
st.metric("缓存命中性能提升", f"{speedup:.1f}倍")
with tab3:
st.header("性能优化最佳实践")
st.markdown("""
### 缓存使用最佳实践
1. **@st.cache_data** - 用于缓存数据处理结果
```python
@st.cache_data
def load_and_process_data():
# 数据加载和处理逻辑
return processed_data
```
2. **@st.cache_resource** - 用于缓存全局资源
```python
@st.cache_resource
def load_model():
# 模型加载逻辑
return model
```
3. **合理设置缓存参数**
```python
@st.cache_data(ttl=3600, max_entries=100)
def expensive_function():
pass
```
### 性能优化技巧
1. **延迟加载** - 只在需要时加载数据
2. **分批处理** - 大数据集分批处理
3. **异步操作** - 使用st.spinner提供反馈
4. **内存管理** - 及时释放不需要的资源
5. **CDN使用** - 静态资源使用CDN加速
### 常见性能问题
1. **重复计算** - 未使用缓存导致重复执行
2. **大数据传输** - 传输过大的数据给前端
3. **阻塞操作** - 长时间运行的同步操作
4. **内存泄漏** - 未正确释放资源
""")
# 性能监控工具
st.subheader("性能监控")
if st.button("检查应用性能"):
import psutil
import os
# 获取当前进程信息
process = psutil.Process(os.getpid())
memory_info = process.memory_info()
col1, col2, col3 = st.columns(3)
col1.metric("内存使用", f"{memory_info.rss / 1024 / 1024:.1f} MB")
col2.metric("CPU使用率", f"{process.cpu_percent()}%")
col3.metric("线程数", process.num_threads())
st.info("定期监控应用性能有助于及时发现和解决性能瓶颈。")
# 缓存管理
st.sidebar.header("缓存管理")
if st.sidebar.button("清除所有缓存"):
st.cache_data.clear()
st.cache_resource.clear()
st.sidebar.success("缓存已清除")
st.sidebar.markdown("---")
st.sidebar.info("使用缓存可以显著提升应用性能,但也要注意内存使用情况。")
5. 部署与扩展
5.1 本地部署
# 本地部署脚本
#!/bin/bash
# streamlit_app.sh
echo "🚀 启动Streamlit应用"
# 检查虚拟环境
if [ -d "venv" ]; then
echo "激活虚拟环境"
source venv/bin/activate
else
echo "创建虚拟环境"
python -m venv venv
source venv/bin/activate
echo "安装依赖"
pip install -r requirements.txt
fi
# 启动应用
echo "启动Streamlit应用"
streamlit run app.py --server.port 8501 --server.address 0.0.0.0
echo "应用查看地址: http://localhost:8501"
5.2 云端部署配置
# deployment_config.py
import streamlit as st
import os
# Streamlit配置
STREAMLIT_CONFIG = {
# 服务器配置
"server.port": int(os.getenv("PORT", 8501)),
"server.address": "0.0.0.0",
"server.enableCORS": False,
"server.enableXsrfProtection": False,
# 浏览器配置
"browser.serverAddress": "0.0.0.0",
"browser.serverPort": int(os.getenv("PORT", 8501)),
# 主题配置
"theme.primaryColor": "#4CAF50",
"theme.backgroundColor": "#FFFFFF",
"theme.secondaryBackgroundColor": "#F0F2F6",
"theme.textColor": "#262730",
"theme.font": "sans serif"
}
# 应用配置
APP_CONFIG = {
"title": "AI应用平台",
"icon": "🤖",
"layout": "wide",
"initial_sidebar_state": "expanded"
}
# 数据库配置
DATABASE_CONFIG = {
"host": os.getenv("DB_HOST", "localhost"),
"port": int(os.getenv("DB_PORT", 5432)),
"database": os.getenv("DB_NAME", "ai_app"),
"user": os.getenv("DB_USER", "ai_user"),
"password": os.getenv("DB_PASSWORD", "ai_password")
}
# Redis配置(用于缓存)
REDIS_CONFIG = {
"host": os.getenv("REDIS_HOST", "localhost"),
"port": int(os.getenv("REDIS_PORT", 6379)),
"db": int(os.getenv("REDIS_DB", 0))
}
# 日志配置
LOGGING_CONFIG = {
"version": 1,
"disable_existing_loggers": False,
"formatters": {
"standard": {
"format": "%(asctime)s [%(levelname)s] %(name)s: %(message)s"
},
},
"handlers": {
"default": {
"level": "INFO",
"formatter": "standard",
"class": "logging.StreamHandler",
},
},
"loggers": {
"": {
"handlers": ["default"],
"level": "INFO",
"propagate": False
}
}
}
def apply_config():
"""应用配置"""
st.set_page_config(**APP_CONFIG)
# 部署检查
def deployment_check():
"""部署环境检查"""
checks = []
# 检查环境变量
required_env_vars = ["PORT"]
for var in required_env_vars:
if not os.getenv(var):
checks.append(f"⚠️ 缺少环境变量: {var}")
# 检查依赖
try:
import streamlit
checks.append("✅ Streamlit已安装")
except ImportError:
checks.append("❌ Streamlit未安装")
try:
import pandas
checks.append("✅ Pandas已安装")
except ImportError:
checks.append("❌ Pandas未安装")
return checks
# 在应用中使用配置
if __name__ == "__main__":
# 应用配置
apply_config()
st.title("部署配置检查")
# 执行部署检查
checks = deployment_check()
for check in checks:
if check.startswith("✅"):
st.success(check)
elif check.startswith("⚠️"):
st.warning(check)
else:
st.error(check)
# 显示配置信息
st.subheader("当前配置")
st.json({
"Streamlit配置": STREAMLIT_CONFIG,
"应用配置": APP_CONFIG,
"数据库配置": DATABASE_CONFIG,
"Redis配置": REDIS_CONFIG
})
5.3 Docker部署
# Dockerfile
FROM python:3.9-slim
# 设置工作目录
WORKDIR /app
# 设置环境变量
ENV PYTHONDONTWRITEBYTECODE=1
ENV PYTHONUNBUFFERED=1
# 安装系统依赖
RUN apt-get update \
&& apt-get install -y --no-install-recommends \
gcc \
g++ \
curl \
git \
&& rm -rf /var/lib/apt/lists/*
# 复制依赖文件
COPY requirements.txt .
# 安装Python依赖
RUN pip install --no-cache-dir -r requirements.txt
# 复制应用代码
COPY . .
# 创建非root用户
RUN adduser --disabled-password --gecos '' appuser \
&& chown -R appuser:appuser /app
USER appuser
# 暴露端口
EXPOSE 8501
# 健康检查
HEALTHCHECK --interval=30s --timeout=30s --start-period=5s --retries=3 \
CMD curl -f http://localhost:8501/healthz || exit 1
# 启动命令
CMD ["streamlit", "run", "app.py", "--server.port=8501", "--server.address=0.0.0.0"]
# docker-compose.yml
version: '3.8'
services:
streamlit-app:
build: .
ports:
- "8501:8501"
environment:
- PORT=8501
- DB_HOST=postgres
- DB_PORT=5432
- DB_NAME=ai_app
- DB_USER=ai_user
- DB_PASSWORD=ai_password
- REDIS_HOST=redis
- REDIS_PORT=6379
volumes:
- ./data:/app/data
- ./models:/app/models
depends_on:
- postgres
- redis
restart: unless-stopped
postgres:
image: postgres:13-alpine
environment:
POSTGRES_DB: ai_app
POSTGRES_USER: ai_user
POSTGRES_PASSWORD: ai_password
volumes:
- postgres_data:/var/lib/postgresql/data
restart: unless-stopped
redis:
image: redis:6-alpine
volumes:
- redis_data:/data
restart: unless-stopped
nginx:
image: nginx:alpine
ports:
- "80:80"
- "443:443"
volumes:
- ./nginx.conf:/etc/nginx/nginx.conf
- ./ssl:/etc/nginx/ssl
depends_on:
- streamlit-app
restart: unless-stopped
volumes:
postgres_data:
redis_data:
6. 最佳实践与注意事项
6.1 代码组织结构

6.2 性能优化建议
# performance_tips.py
import streamlit as st
import pandas as pd
import numpy as np
import time
import functools
from typing import Any, Callable
# 1. 智能缓存装饰器
def smart_cache(ttl: int = 3600, max_entries: int = 128):
"""
智能缓存装饰器,支持TTL和最大条目数
"""
def decorator(func: Callable) -> Callable:
@functools.wraps(func)
@st.cache_data(ttl=ttl, max_entries=max_entries)
def wrapper(*args, **kwargs):
return func(*args, **kwargs)
return wrapper
return decorator
# 2. 懒加载数据
class LazyDataLoader:
"""
懒加载数据管理器
"""
def __init__(self):
self._data = {}
@smart_cache(ttl=1800)
def load_data(self, data_type: str, **kwargs) -> pd.DataFrame:
"""
懒加载数据
"""
if data_type not in self._data:
with st.spinner(f"正在加载{data_type}数据..."):
# 模拟数据加载
time.sleep(1)
if data_type == "sales":
self._data[data_type] = self._generate_sales_data(**kwargs)
elif data_type == "user":
self._data[data_type] = self._generate_user_data(**kwargs)
return self._data[data_type]
def _generate_sales_data(self, days: int = 30) -> pd.DataFrame:
"""生成销售数据"""
dates = pd.date_range(end=pd.Timestamp.now(), periods=days)
data = {
'date': dates,
'sales': np.random.randint(1000, 10000, days),
'visitors': np.random.randint(100, 1000, days)
}
return pd.DataFrame(data)
def _generate_user_data(self, count: int = 1000) -> pd.DataFrame:
"""生成用户数据"""
data = {
'user_id': range(1, count + 1),
'age': np.random.randint(18, 80, count),
'gender': np.random.choice(['M', 'F'], count),
'region': np.random.choice(['北京', '上海', '广州', '深圳', '其他'], count)
}
return pd.DataFrame(data)
# 3. 异步任务处理器
class AsyncTaskHandler:
"""
异步任务处理器
"""
def __init__(self):
self.tasks = []
def add_task(self, name: str, func: Callable, *args, **kwargs):
"""
添加异步任务
"""
task = {
'name': name,
'func': func,
'args': args,
'kwargs': kwargs,
'status': 'pending',
'result': None
}
self.tasks.append(task)
return len(self.tasks) - 1
def run_task(self, task_id: int):
"""
运行任务
"""
if 0 <= task_id < len(self.tasks):
task = self.tasks[task_id]
try:
task['status'] = 'running'
task['result'] = task['func'](*task['args'], **task['kwargs'])
task['status'] = 'completed'
except Exception as e:
task['status'] = 'failed'
task['result'] = str(e)
def get_task_status(self, task_id: int) -> dict:
"""
获取任务状态
"""
if 0 <= task_id < len(self.tasks):
return self.tasks[task_id]
return None
# 4. 内存优化工具
class MemoryOptimizer:
"""
内存优化工具
"""
@staticmethod
def optimize_dataframe(df: pd.DataFrame) -> pd.DataFrame:
"""
优化DataFrame内存使用
"""
optimized_df = df.copy()
for col in optimized_df.columns:
col_type = optimized_df[col].dtype
if col_type != 'object':
c_min = optimized_df[col].min()
c_max = optimized_df[col].max()
if str(col_type)[:3] == 'int':
if c_min > np.iinfo(np.int8).min and c_max < np.iinfo(np.int8).max:
optimized_df[col] = optimized_df[col].astype(np.int8)
elif c_min > np.iinfo(np.int16).min and c_max < np.iinfo(np.int16).max:
optimized_df[col] = optimized_df[col].astype(np.int16)
elif c_min > np.iinfo(np.int32).min and c_max < np.iinfo(np.int32).max:
optimized_df[col] = optimized_df[col].astype(np.int32)
else:
if c_min > np.finfo(np.float32).min and c_max < np.finfo(np.float32).max:
optimized_df[col] = optimized_df[col].astype(np.float32)
return optimized_df
@staticmethod
def clear_unused_data():
"""
清理未使用的数据
"""
# 清理Streamlit缓存
st.cache_data.clear()
# 在实际应用中,可以添加更多清理逻辑
# 使用示例
def performance_demo():
"""
性能优化演示
"""
st.title("性能优化演示")
# 懒加载数据演示
loader = LazyDataLoader()
if st.button("加载销售数据"):
with st.spinner("正在加载数据..."):
sales_data = loader.load_data("sales", days=30)
st.success("数据加载完成")
st.dataframe(sales_data.head())
# 异步任务演示
task_handler = AsyncTaskHandler()
if st.button("运行长时间任务"):
task_id = task_handler.add_task(
"数据处理",
lambda: time.sleep(3) or "任务完成"
)
with st.spinner("任务运行中..."):
task_handler.run_task(task_id)
task_status = task_handler.get_task_status(task_id)
if task_status['status'] == 'completed':
st.success(f"任务完成: {task_status['result']}")
else:
st.error(f"任务失败: {task_status['result']}")
if __name__ == "__main__":
performance_demo()
7. 常见问题解答
7.1 部署相关问题
Q: 如何解决部署后访问速度慢的问题?
A: 部署后访问速度慢的解决方案:
- 使用CDN加速:将静态资源部署到CDN
- 启用Gzip压缩:减少数据传输量
- 优化图片资源:压缩和格式优化
- 使用缓存策略:合理配置浏览器缓存
# nginx.conf示例
server {
listen 80;
server_name your-domain.com;
# 启用Gzip压缩
gzip on;
gzip_types text/plain text/css application/json application/javascript text/xml application/xml;
# 静态资源缓存
location ~* \.(jpg|jpeg|png|gif|ico|css|js)$ {
expires 1y;
add_header Cache-Control "public, immutable";
}
# 代理到Streamlit应用
location / {
proxy_pass http://localhost:8501;
proxy_set_header Host $host;
proxy_set_header X-Real-IP $remote_addr;
}
}
Q: 如何实现多用户访问控制?
A: 多用户访问控制实现方案:
- 集成认证服务:使用OAuth、JWT等认证机制
- 会话管理:使用session_state管理用户状态
- 权限控制:基于角色的访问控制(RBAC)
# auth_example.py
import streamlit as st
from datetime import datetime, timedelta
import jwt
# 模拟用户数据库
USERS = {
"admin": {"password": "admin123", "role": "admin"},
"user": {"password": "user123", "role": "user"}
}
# JWT配置
SECRET_KEY = "your-secret-key"
ALGORITHM = "HS256"
def create_token(username: str, role: str) -> str:
"""创建JWT令牌"""
payload = {
"sub": username,
"role": role,
"exp": datetime.utcnow() + timedelta(hours=24)
}
return jwt.encode(payload, SECRET_KEY, algorithm=ALGORITHM)
def verify_token(token: str) -> dict:
"""验证JWT令牌"""
try:
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
return payload
except jwt.PyJWTError:
return None
def login_page():
"""登录页面"""
st.title("用户登录")
username = st.text_input("用户名")
password = st.text_input("密码", type="password")
if st.button("登录"):
if username in USERS and USERS[username]["password"] == password:
token = create_token(username, USERS[username]["role"])
st.session_state.token = token
st.session_state.user = username
st.session_state.role = USERS[username]["role"]
st.success("登录成功")
st.experimental_rerun()
else:
st.error("用户名或密码错误")
def protected_page():
"""受保护的页面"""
if 'token' not in st.session_state:
login_page()
return False
# 验证令牌
payload = verify_token(st.session_state.token)
if not payload:
st.error("会话已过期,请重新登录")
if st.button("重新登录"):
del st.session_state.token
st.experimental_rerun()
return False
# 显示用户信息
st.sidebar.success(f"欢迎, {payload['sub']}!")
if st.sidebar.button("退出登录"):
del st.session_state.token
st.experimental_rerun()
return True
# 在主应用中使用
def main():
if protected_page():
st.title("受保护的应用内容")
st.write(f"当前用户角色: {st.session_state.role}")
# 根据角色显示不同内容
if st.session_state.role == "admin":
st.subheader("管理员功能")
st.write("管理员专有功能")
else:
st.subheader("用户功能")
st.write("普通用户功能")
if __name__ == "__main__":
main()
7.2 性能相关问题
Q: 如何处理大数据集的展示问题?
A: 大数据集展示优化方案:
- 分页展示:使用st.dataframe的分页功能
- 虚拟滚动:只渲染可见区域的数据
- 数据采样:对大数据集进行采样展示
- 懒加载:按需加载数据
# large_data_display.py
import streamlit as st
import pandas as pd
import numpy as np
def display_large_dataset(df: pd.DataFrame, page_size: int = 100):
"""
分页展示大数据集
"""
total_rows = len(df)
total_pages = (total_rows - 1) // page_size + 1
# 页面选择
page = st.sidebar.number_input(
"选择页面",
min_value=1,
max_value=total_pages,
value=1
)
# 计算显示范围
start_idx = (page - 1) * page_size
end_idx = min(start_idx + page_size, total_rows)
# 显示数据
st.write(f"显示第 {start_idx+1} - {end_idx} 行,共 {total_rows} 行")
st.dataframe(df.iloc[start_idx:end_idx])
# 显示分页信息
st.sidebar.write(f"页面 {page}/{total_pages}")
# 使用示例
if __name__ == "__main__":
# 生成大型数据集
large_df = pd.DataFrame({
'id': range(10000),
'value': np.random.randn(10000),
'category': np.random.choice(['A', 'B', 'C'], 10000)
})
st.title("大数据集展示示例")
display_large_dataset(large_df, page_size=50)
8. 扩展阅读
8.1 学习资源推荐
8.2 相关工具和库
-
可视化库:
- Plotly: 交互式图表
- Altair: 声明式统计可视化
- Bokeh: 交互式Web可视化
-
数据处理库:
- Pandas: 数据分析和处理
- NumPy: 数值计算
- Dask: 并行计算
-
机器学习库:
- Scikit-learn: 传统机器学习
- TensorFlow/PyTorch: 深度学习
- XGBoost/LightGBM: 梯度提升
-
部署工具:
- Docker: 容器化部署
- Kubernetes: 容器编排
- Nginx: 反向代理
实施计划甘特图
交互流程时序图
总结
本文全面介绍了使用Streamlit构建AI应用界面的完整流程,从基础概念到高级应用,涵盖了实际开发中的各个方面。
关键要点回顾
-
Streamlit优势:简单易学、快速开发、实时交互、丰富的组件库
-
核心功能:文本组件、数据可视化、模型集成、用户交互
-
高级特性:缓存机制、性能优化、自定义组件、多页面导航
-
部署方案:本地部署、Docker容器化、云平台部署
-
最佳实践:代码组织、性能优化、用户体验、安全考虑
实践建议
-
从简单开始:先实现基础功能,再逐步添加复杂特性
-
重视用户体验:设计直观的界面和流畅的交互流程
-
合理使用缓存:利用缓存机制提升应用性能
-
做好错误处理:为用户提供清晰的错误提示和解决方案
-
持续优化迭代:根据用户反馈和使用数据持续改进应用
未来展望
随着AI技术的不断发展,Streamlit在AI应用开发领域将发挥更大作用:
- 组件生态:更丰富的第三方组件库
- 性能提升:更快的渲染速度和更低的资源消耗
- AI集成:更深度的AI框架集成
- 协作功能:支持团队协作开发的特性
- 移动端适配:更好的移动设备支持
通过正确应用Streamlit,AI开发者可以快速将模型转化为用户友好的应用界面,大大缩短从原型到产品的时间,提升开发效率和用户满意度。
1105

被折叠的 条评论
为什么被折叠?



