FastAPI 中的 JWT 认证与访问限流实现

在现代 Web 应用程序中,保护资源和控制访问是至关重要的。本文将深入探讨如何在 FastAPI 中实现 JWT 认证,并介绍如何根据用户和 IP 实现访问限流。我们将从基础概念开始,逐步深入到实际代码实现,并提供详细的解释和示例。

背景故事

想象一下,你正在开发一个在线购物平台,用户可以浏览商品、添加到购物车并完成支付。为了保护用户的敏感信息和平台的资源,你需要实现一个安全的认证机制和访问控制。JWT(JSON Web Token)是一种广泛使用的认证方式,而 FastAPI 提供了强大的工具来实现这一目标。此外,为了防止滥用 API 并保护服务器资源,我们还需要实现访问限流机制。

FastAPI 中的 JWT 认证

1. JWT 基础概念

JWT 的定义和用途

JSON Web Token (JWT) 是一种开放标准 (RFC 7519),用于在网络应用环境间安全地将信息作为 JSON 对象传输。JWT 通常用于身份验证和信息交换,可以确保数据在传输过程中的完整性和安全性。

JWT 的结构

JWT 由三部分组成:Header(头部)、Payload(载荷)和 Signature(签名),它们由点 (.) 分隔。

  • Header:包含令牌的类型(通常是 JWT)和所使用的签名算法(如 HMAC SHA256 或 RSA)。
  • Payload:包含声明(claims),即关于实体(通常是用户)和其他数据的信息。声明分为三种类型:注册声明(Registered Claims)、公共声明(Public Claims)和私有声明(Private Claims)。
  • Signature:对前两部分进行签名,确保数据未被篡改。签名的生成方式如下:
    HMACSHA256(
      base64UrlEncode(header) + "." +
      base64UrlEncode(payload),
      secret)
    
JWT 的优势
  • 无状态认证:服务器不需要存储会话信息,减轻了服务器负载。
  • 跨域:由于 JWT 是基于令牌的认证方式,可以在不同域之间传递。
  • 信息丰富:可以在 Payload 中携带丰富的用户信息,减少数据库查询次数。

2. FastAPI 环境设置

安装 FastAPI 和 Uvicorn

首先,确保你的 Python 环境中安装了 FastAPI 和 Uvicorn。Uvicorn 是一个 ASGI 服务器,用于运行 FastAPI 应用。

pip install fastapi uvicorn
安装 JWT 相关库

我们将使用 python-jose 库来生成和验证 JWT,使用 passlib 来加密用户密码。

pip install python-jose passlib[bcrypt]

3. 生成和验证 JWT

创建用户登录的 API 端点

实现用户登录的逻辑。我们需要一个简单的用户模型和密码加密功能。

from pydantic import BaseModel
from fastapi import FastAPI, Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from jose import JWTError, jwt
from passlib.context import CryptContext
from datetime import datetime, timedelta

app = FastAPI()

# 密钥和算法
SECRET_KEY = "your_secret_key"
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30

# 密码加密
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")

# OAuth2
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")

# 示例用户数据库
fake_users_db = {
    "johndoe": {
        "username": "johndoe",
        "full_name": "John Doe",
        "email": "johndoe@example.com",
        "hashed_password": "$2b$12$EixZaYVK1fsbw1ZfbX3OXePaWxn96p36WQoeG6Lruj3vjPGga31lW",  # 密码: secret
        "disabled": False,
    }
}

# 模型类
class User(BaseModel):
    username: str
    email: str | None = None
    full_name: str | None = None
    disabled: bool | None = None

class UserInDB(User):
    hashed_password: str

def verify_password(plain_password, hashed_password):
    return pwd_context.verify(plain_password, hashed_password)

def get_password_hash(password):
    return pwd_context.hash(password)

def get_user(db, username: str):
    if username in db:
        user_dict = db[username]
        return UserInDB(**user_dict)

def authenticate_user(fake_db, username: str, password: str):
    user = get_user(fake_db, username)
    if not user:
        return False
    if not verify_password(password, user.hashed_password):
        return False
    return user

def create_access_token(data: dict, expires_delta: timedelta | None = None):
    to_encode = data.copy()
    if expires_delta:
        expire = datetime.utcnow() + expires_delta
    else:
        expire = datetime.utcnow() + timedelta(minutes=15)
    to_encode.update({"exp": expire})
    encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
    return encoded_jwt

async def get_current_user(token: str = Depends(oauth2_scheme)):
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
        username: str = payload.get("sub")
        if username is None:
            raise credentials_exception
    except JWTError:
        raise credentials_exception
    user = get_user(fake_users_db, username=username)
    if user is None:
        raise credentials_exception
    return user

async def get_current_active_user(current_user: User = Depends(get_current_user)):
    if current_user.disabled:
        raise HTTPException(status_code=400, detail="Inactive user")
    return current_user

@app.post("/token")
async def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends()):
    user = authenticate_user(fake_users_db, form_data.username, form_data.password)
    if not user:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect username or password",
            headers={"WWW-Authenticate": "Bearer"},
        )
    access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    access_token = create_access_token(
        data={"sub": user.username}, expires_delta=access_token_expires
    )
    return {"access_token": access_token, "token_type": "bearer"}
生成 JWT

在用户登录成功后,使用 python-jose 生成 JWT。注意选择和管理 Secret Key 的安全,不要将其硬编码在代码中,可以使用环境变量或配置文件。

def create_access_token(data: dict, expires_delta: timedelta | None = None):
    to_encode = data.copy()
    if expires_delta:
        expire = datetime.utcnow() + expires_delta
    else:
        expire = datetime.utcnow() + timedelta(minutes=15)
    to_encode.update({"exp": expire})
    encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
    return encoded_jwt
验证 JWT

在每个请求中解析和验证 JWT 的合法性。如果验证失败,返回特定的 HTTP 状态码和错误信息。

async def get_current_user(token: str = Depends(oauth2_scheme)):
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
        username: str = payload.get("sub")
        if username is None:
            raise credentials_exception
    except JWTError:
        raise credentials_exception
    user = get_user(fake_users_db, username=username)
    if user is None:
        raise credentials_exception
    return user

4. 保护路由

依赖注入机制

利用 FastAPI 的依赖注入来保护路由。创建一个函数用于解析和验证 JWT,然后将其应用到需要保护的路由上。

async def get_current_active_user(current_user: User = Depends(get_current_user)):
    if current_user.disabled:
        raise HTTPException(status_code=400, detail="Inactive user")
    return current_user

@app.get("/users/me")
async def read_users_me(current_user: User = Depends(get_current_active_user)):
    return current_user
验证 JWT 的依赖函数

创建一个依赖函数来验证 JWT,并根据需要处理不同角色的用户访问控制。

async def get_current_active_user(current_user: User = Depends(get_current_user)):
    if current_user.disabled:
        raise HTTPException(status_code=400, detail="Inactive user")
    return current_user

@app.get("/admin")
async def read_admin_info(current_user: User = Depends(get_current_active_user)):
    if current_user.username != "admin":
        raise HTTPException(status_code=403, detail="Forbidden")
    return {"message": "Admin information"}

5. JWT 过期和刷新机制

过期时间设置

定义 JWT 的有效期以确保安全。可以通过设置 expires_delta 参数来控制 Token 的过期时间。

access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
access_token = create_access_token(
    data={"sub": user.username}, expires_delta=access_token_expires
)
刷新 token 机制

实现 token 刷新机制,以保持用户会话。通常需要一个单独的刷新 Token 和相应的逻辑。

REFRESH_TOKEN_EXPIRE_MINUTES = 60 * 24  # 24 hours

@app.post("/refresh_token")
async def refresh_access_token(refresh_token: str = Depends(oauth2_scheme)):
    try:
        payload = jwt.decode(refresh_token, SECRET_KEY, algorithms=[ALGORITHM])
        username: str = payload.get("sub")
        if username is None:
            raise HTTPException(status_code=401, detail="Invalid refresh token")
    except JWTError:
        raise HTTPException(status_code=401, detail="Invalid refresh token")
    user = get_user(fake_users_db, username=username)
    if user is None:
        raise HTTPException(status_code=401, detail="Invalid refresh token")
    
    access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    access_token = create_access_token(
        data={"sub": user.username}, expires_delta=access_token_expires
    )
    return {"access_token": access_token, "token_type": "bearer"}

FastAPI 中的访问限流

1. 访问限流的概念

目的

访问限流的主要目的是防止滥用 API,保护服务器资源。通过限制请求速率,可以防止恶意用户或爬虫对系统的攻击。

常见策略

常见的限流策略包括:

  • 固定窗口计数:在固定的时间窗口内统计请求数量。
  • 滑动窗口计数:在滑动的时间窗口内统计请求数量。
  • 漏桶算法:以恒定速率处理请求,超出速率的请求会被丢弃或排队等待。

2. 根据 IP 的限流实现

使用 fastapi-limiter

fastapi-limiter 是一个用于 FastAPI 的限流库,可以方便地实现 IP 限流。首先安装该库。

pip install fastapi-limiter
配置 Redis

Redis 作为限流数据的存储后端。确保 Redis 已经安装并运行。

docker run --name redis -d -p 6379:6379 redis
应用限流中间件

在应用中配置限流规则(如每分钟请求数)。

from fastapi_limiter import FastAPILimiter
from fastapi_limiter.depends import RateLimiter
import redis.asyncio as redis

@app.on_event("startup")
async def startup():
    redis_client = redis.from_url("redis://localhost")
    await FastAPILimiter.init(redis_client)

@app.get("/limited-endpoint", dependencies=[Depends(RateLimiter(times=10, seconds=60))])
async def limited_endpoint():
    return {"message": "This endpoint is rate-limited"}

3. 根据用户的限流实现

用户信息限流

通过 JWT 中的用户信息进行限流。读取 JWT,提取用户信息并应用限流策略。

@app.get("/user-limited-endpoint", dependencies=[Depends(RateLimiter(times=5, seconds=60))])
async def user_limited_endpoint(current_user: User = Depends(get_current_active_user)):
    return {"message": f"User {current_user.username} is rate-limited"}

4. 限流策略的配置

不同策略的配置方法

可以为不同的用户或 IP 配置不同的限流策略。

@app.get("/admin-endpoint", dependencies=[Depends(RateLimiter(times=20, seconds=60))])
async def admin_endpoint(current_user: User = Depends(get_current_active_user)):
    if current_user.username != "admin":
        raise HTTPException(status_code=403, detail="Forbidden")
    return {"message": "Admin endpoint"}
动态调整和优先级

在运行时调整限流规则。可以通过修改配置文件或数据库中的设置来动态调整限流策略。

@app.get("/dynamic-rate-limit", dependencies=[Depends(RateLimiter(times=10, seconds=60))])
async def dynamic_rate_limit():
    return {"message": "Dynamic rate limit endpoint"}

5. 限流的监控和报警

监控限流状态

记录限流事件以便于监控。可以使用工具(如 Grafana)进行可视化监控。

from fastapi_limiter.depends import RateLimiter
from fastapi_limiter import FastAPILimiter
import redis.asyncio as redis
from prometheus_client import Counter

requests_total = Counter('http_requests_total', 'Total HTTP Requests', ['method', 'endpoint'])

@app.on_event("startup")
async def startup():
    redis_client = redis.from_url("redis://localhost")
    await FastAPILimiter.init(redis_client)

@app.middleware("http")
async def add_process_time_header(request: Request, call_next):
    start_time = time.time()
    response = await call_next(request)
    process_time = time.time() - start_time
    response.headers["X-Process-Time"] = str(process_time)
    requests_total.labels(method=request.method, endpoint=request.url.path).inc()
    return response

@app.get("/limited-endpoint", dependencies=[Depends(RateLimiter(times=10, seconds=60))])
async def limited_endpoint():
    return {"message": "This endpoint is rate-limited"}
报警机制

设置报警机制以监控异常限流事件。可以结合日志系统(如 ELK Stack)实现报警。

from prometheus_client import start_http_server, Summary

# 启动 Prometheus HTTP 服务器
start_http_server(8001)

@app.get("/limited-endpoint", dependencies=[Depends(RateLimiter(times=10, seconds=60))])
async def limited_endpoint():
    return {"message": "This endpoint is rate-limited"}

6. 最佳实践和常见问题

边缘案例处理

处理限流中的边缘案例。例如,当多个请求几乎同时到达时,如何确保限流的准确性。

@app.get("/edge-case-endpoint", dependencies=[Depends(RateLimiter(times=10, seconds=60))])
async def edge_case_endpoint():
    return {"message": "Edge case endpoint"}
性能考虑

评估限流策略对应用性能的影响。选择合适的限流算法和存储后端,确保限流不会成为性能瓶颈。

@app.get("/performance-sensitive-endpoint", dependencies=[Depends(RateLimiter(times=50, seconds=60))])
async def performance_sensitive_endpoint():
    return {"message": "Performance sensitive endpoint"}
用户反馈

提供友好的限流反馈信息。通过 HTTP 响应头传递限流信息给客户端,帮助用户理解限流情况。

@app.get("/limited-endpoint", dependencies=[Depends(RateLimiter(times=10, seconds=60))])
async def limited_endpoint():
    return {"message": "This endpoint is rate-limited"}

实际案例

为了更好地理解上述内容,我们来看一个完整的示例项目。假设我们正在开发一个简单的博客系统,需要实现用户认证和访问限流。

目录树

blog_app/
├── main.py
├── models.py
├── schemas.py
├── utils.py
└── requirements.txt

main.py

from fastapi import FastAPI, Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from jose import JWTError, jwt
from passlib.context import CryptContext
from datetime import datetime, timedelta
from pydantic import BaseModel
from fastapi_limiter import FastAPILimiter
from fastapi_limiter.depends import RateLimiter
import redis.asyncio as redis
from prometheus_client import Counter, start_http_server

app = FastAPI()

# 密钥和算法
SECRET_KEY = "your_secret_key"
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30
REFRESH_TOKEN_EXPIRE_MINUTES = 60 * 24  # 24 hours

# 密码加密
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")

# OAuth2
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")

# 示例用户数据库
fake_users_db = {
    "johndoe": {
        "username": "johndoe",
        "full_name": "John Doe",
        "email": "johndoe@example.com",
        "hashed_password": "$2b$12$EixZaYVK1fsbw1ZfbX3OXePaWxn96p36WQoeG6Lruj3vjPGga31lW",  # 密码: secret
        "disabled": False,
    },
    "admin": {
        "username": "admin",
        "full_name": "Admin User",
        "email": "admin@example.com",
        "hashed_password": "$2b$12$EixZaYVK1fsbw1ZfbX3OXePaWxn96p36WQoeG6Lruj3vjPGga31lW",  # 密码: secret
        "disabled": False,
    }
}

# 模型类
class User(BaseModel):
    username: str
    email: str | None = None
    full_name: str | None = None
    disabled: bool | None = None

class UserInDB(User):
    hashed_password: str

# 模拟文章模型
class Article(BaseModel):
    id: int
    title: str
    content: str

# 示例文章数据库
fake_articles_db = [
    {"id": 1, "title": "First Article", "content": "This is the first article."},
    {"id": 2, "title": "Second Article", "content": "This is the second article."},
]

def verify_password(plain_password, hashed_password):
    return pwd_context.verify(plain_password, hashed_password)

def get_password_hash(password):
    return pwd_context.hash(password)

def get_user(db, username: str):
    if username in db:
        user_dict = db[username]
        return UserInDB(**user_dict)

def authenticate_user(fake_db, username: str, password: str):
    user = get_user(fake_db, username)
    if not user:
        return False
    if not verify_password(password, user.hashed_password):
        return False
    return user

def create_access_token(data: dict, expires_delta: timedelta | None = None):
    to_encode = data.copy()
    if expires_delta:
        expire = datetime.utcnow() + expires_delta
    else:
        expire = datetime.utcnow() + timedelta(minutes=15)
    to_encode.update({"exp": expire})
    encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
    return encoded_jwt

async def get_current_user(token: str = Depends(oauth2_scheme)):
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
        username: str = payload.get("sub")
        if username is None:
            raise credentials_exception
    except JWTError:
        raise credentials_exception
    user = get_user(fake_users_db, username=username)
    if user is None:
        raise credentials_exception
    return user

async def get_current_active_user(current_user: User = Depends(get_current_user)):
    if current_user.disabled:
        raise HTTPException(status_code=400, detail="Inactive user")
    return current_user

@app.on_event("startup")
async def startup():
    redis_client = redis.from_url("redis://localhost")
    await FastAPILimiter.init(redis_client)
    start_http_server(8001)  # 启动 Prometheus HTTP 服务器

@app.post("/token")
async def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends()):
    user = authenticate_user(fake_users_db, form_data.username, form_data.password)
    if not user:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect username or password",
            headers={"WWW-Authenticate": "Bearer"},
        )
    access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    access_token = create_access_token(
        data={"sub": user.username}, expires_delta=access_token_expires
    )
    return {"access_token": access_token, "token_type": "bearer"}

@app.post("/refresh_token")
async def refresh_access_token(refresh_token: str = Depends(oauth2_scheme)):
    try:
        payload = jwt.decode(refresh_token, SECRET_KEY, algorithms=[ALGORITHM])
        username: str = payload.get("sub")
        if username is None:
            raise HTTPException(status_code=401, detail="Invalid refresh token")
    except JWTError:
        raise HTTPException(status_code=401, detail="Invalid refresh token")
    user = get_user(fake_users_db, username=username)
    if user is None:
        raise HTTPException(status_code=401, detail="Invalid refresh token")
    
    access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    access_token = create_access_token(
        data={"sub": user.username}, expires_delta=access_token_expires
    )
    return {"access_token": access_token, "token_type": "bearer"}

@app.get("/users/me")
async def read_users_me(current_user: User = Depends(get_current_active_user)):
    return current_user

@app.get("/admin")
async def read_admin_info(current_user: User = Depends(get_current_active_user)):
    if current_user.username != "admin":
        raise HTTPException(status_code=403, detail="Forbidden")
    return {"message": "Admin information"}

@app.get("/articles/{article_id}", dependencies=[Depends(RateLimiter(times=10, seconds=60))])
async def get_article(article_id: int):
    for article in fake_articles_db:
        if article["id"] == article_id:
            return article
    raise HTTPException(status_code=404, detail="Article not found")

@app.get("/user-limited-endpoint", dependencies=[Depends(RateLimiter(times=5, seconds=60))])
async def user_limited_endpoint(current_user: User = Depends(get_current_active_user)):
    return {"message": f"User {current_user.username} is rate-limited"}

@app.get("/admin-endpoint", dependencies=[Depends(RateLimiter(times=20, seconds=60))])
async def admin_endpoint(current_user: User = Depends(get_current_active_user)):
    if current_user.username != "admin":
        raise HTTPException(status_code=403, detail="Forbidden")
    return {"message": "Admin endpoint"}

@app.get("/dynamic-rate-limit", dependencies=[Depends(RateLimiter(times=10, seconds=60))])
async def dynamic_rate_limit():
    return {"message": "Dynamic rate limit endpoint"}

models.py

from pydantic import BaseModel

class User(BaseModel):
    username: str
    email: str | None = None
    full_name: str | None = None
    disabled: bool | None = None

class UserInDB(User):
    hashed_password: str

class Article(BaseModel):
    id: int
    title: str
    content: str

schemas.py

from pydantic import BaseModel

class Token(BaseModel):
    access_token: str
    token_type: str

utils.py

from passlib.context import CryptContext

pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")

def verify_password(plain_password, hashed_password):
    return pwd_context.verify(plain_password, hashed_password)

def get_password_hash(password):
    return pwd_context.hash(password)

requirements.txt

fastapi
uvicorn
python-jose
passlib[bcrypt]
fastapi-limiter
redis
prometheus_client

总结

通过本文,我们深入探讨了如何在 FastAPI 中实现 JWT 认证和访问限流。从 JWT 的基础概念到实际代码实现,我们逐步介绍了每个步骤,并提供了具体的示例和代码。希望这些内容能够帮助你在自己的项目中实现安全且高效的 API 认证和限流机制。

我创建了一个站点用于分享大模型和多智能体相关的内容,欢迎访问:https://llmmultiagents.com/

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值