文章目录
了解FastAPI程序结构
编写一个简单的FastAPI程序需要五个小步骤,先看一个完整例子
from fastapi import FastAPI
app = FastAPI()
@app.get("/")
def root():
return {"message": "Hello World"}
第一步,导入FastAPI
from fastapi import FastAPI
第二步,创建一个app实例
app = FastAPI()
第三步,编写一个 路径操作装饰器
@app.get("/")
需要注意的两点是:
● 你可以将get操作方法更改成@app.post()、@app.put()、@app.delete()等方法
● 你可以更改相应的路径(“/”)为自己想要的,例如我更改为(“/hello_word/”)
第四步,编写一个路径操作函数,例如下面代码中的root函数。它位于路径操作装饰器下方(见上方例子)
def root():
return {“message”: “Hello World”}
这个函数的返回值可以是
dict,list,单独的值,比如str,int,或者是Pydantic模型
第五步、运行开发服务器uvicorn main:app --reload即可访问api链接。
也可以安装uvicorn,启动
pip install uvicorn
import uvicorn
from fastapi import FastAPI
app=FastAPI()
if __name__ == '__main__':
uvicorn.run(app)
- 例如我在终端运行uvicorn main:app --reload之后,在浏览器输入127.0.0.1:8000,出现"message": "Hello World"这句话。
- 在这里可以自己指定要运行的服务器ip和端口号。
- 例如:uvicorn main:app --host 127.0.0.1 --port 8001 --reload表示指定本地电脑为服务器,端口号为8001。下面所有的代码演示都默认这个本机ip地址和8001端口号。
符案例
from fastapi import FastAPI
import uvicorn
from service.knowledge_service import router as kno_router
from service.session_service import router as session_router
app = FastAPI()
# 声明多应用
app.include_router(kno_router)
app.include_router(session_router)
@app.get("/")
def root():
return {"message": "Hello World"}
if __name__ == '__main__':
uvicorn.run(app="main:app", host="127.0.0.1", port=8000, reload=True)
声明路径参数
from fastapi import FastAPI
app = FastAPI()
@app.get("/items/{item_id}")
def read_item(item_id):
return {"item_id": item_id}
声明路径参数的类型
from fastapi import FastAPI
app = FastAPI()
@app.get1("/items/{item_id}")
async def read_item1(item_id: int):
return {"item_id": item_id}
@app.get2("/items/{item_name}")
def read_item2(item_name: str):
return {"item_id": item_name}
get请求查询参数
from fastapi import FastAPI
app = FastAPI()
@app.get("/files/")
def add(num1: int=2, num2: int=8):
return {"num1 + num2 = ": num1 + num2}
请求体
请求体是客户端发送到您的API的数据。 响应体是您的API发送给客户端的数据。
API几乎总是必须发送一个响应体,但是客户端并不需要一直发送请求体。
定义请求体,需要使用 Pydantic 模型。注意以下几点
不能通过GET请求发送请求体
发送请求体数据,必须使用以下几种方法之一:POST(最常见)、PUT、DELETE、PATCH
如何实现请求体
第一步,从pydantic中导入BaseModel
from pydantic import BaseModel
第二步,创建请求体数据模型
声明请求体数据模型为一个类,且该类继承 BaseModel。所有的属性都用标准Python类。和查询参数一样:数据类型的属性如果不是必须的话,可以拥有一个默认值或者是可选None。否则,该属性就是必须的。
from pydantic import BaseModel
class Item(BaseModel):
name: str
description: str = None
price: float
tax: float = None
所以访问链接的时候传入的请求体可以是下面两种
第一种
{
"name": "Foo",
"description": "An optional description",
"price": 45.2,
"tax": 3.5
}
第二种
{
"name": "Foo",
"price": 45.2
}
第三步、将模型定义为参数
将上面定义的模型添加到你的路径操作中,就和定义Path和Query参数一样的方式:
from fastapi import FastAPI
from pydantic import BaseModel
class Item(BaseModel):
name: str
description: str = None
price: float
tax: float = None
app = FastAPI()
@app.post("/items/")
async def create_item(item: Item):
return item
【注】
(1)
class Item(BaseModel):
name: str
description: str | None = None
price: float
tax: float | None = None
@app.put("/items/{item_id}")
async def update_item(item_id: int, item: Annotated[Item, Body(embed=True)]):
results = {"item_id": item_id, "item": item}
return results
async def update_item(item_id: int, item: Annotated[Item, Body(embed=True)]):
item: Annotated[Item, Body(embed=True)
item: Annotated[Item, Body(embed=True)]
● 这部分定义了请求体中的数据。它使用了 Annotated 和 Body 来为 item 参数指定详细的信息。
● Annotated[Item, Body(embed=True)]:
○ Item 是一个 Pydantic 模型类,表示请求体的数据结构。Pydantic 会验证请求体中的数据是否符合 Item 类的要求,自动进行类型检查。
○ Item 类的字段包括:
■ name(字符串类型,必填)
■ description(可选的字符串)
■ price(浮动类型,必填)
■ tax(可选的浮动类型)
○ Body(embed=True) 是 FastAPI 提供的一个功能,表示请求体数据应该“嵌套”在一个对象内,而不是直接以属性的形式存在。这意味着,客户端请求的 JSON 数据结构应该像这样:
{
“item”: {
“name”: “item_name”,
“description”: “item_description”,
“price”: 99.99,
“tax”: 5.0
}
}
(2)
from pydantic import BaseModel, Field
description: str | None = Field(
default=None, title="The description of the item", max_length=300
)
= Field(…)
● 这里使用了 Pydantic 的 Field 函数来为 description 字段提供更多的元数据和验证规则。
● Field 是 Pydantic 用来定义字段的函数,可以用来设置字段的默认值、验证条件、描述信息等。
跨域配置
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
app = FastAPI()
# 在这里换上前端的源请求
origins = [
"http://localhost.tiangolo.com",
"https://localhost.tiangolo.com",
"http://localhost",
"http://localhost:8080",
]
app.add_middleware(
CORSMiddleware,
allow_origins=origins,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
@app.get("/")
async def main():
return {"message": "Hello World"}
数据库连接
1.配置
from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
from contextlib import contextmanager
# MySQL所在主机名
HOSTNAME = "127.0.0.1"
# MySQL监听的端口号,默认3306
PORT = 3306
# 连接MySQL的用户名,自己设置
USERNAME = "root"
# 连接MySQL的密码,自己设置
PASSWORD = "root"
# MySQL上创建的数据库名称
DATABASE = "fadiantest"
SQLALCHEMY_DATABASE_URL = f"mysql+pymysql://{USERNAME}:{PASSWORD}@{HOSTNAME}:{PORT}/{DATABASE}?charset=utf8mb4"
engine = create_engine(
SQLALCHEMY_DATABASE_URL, pool_pre_ping=True
)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
@contextmanager
def get_db():
db = SessionLocal()
try:
yield db
finally:
db.close()
Base = declarative_base()
2.定义实体
from sqlalchemy import Column, String, Integer, Float ,DateTime
from config.mysql_config import Base
class PowerPlant(Base):
'''发电量'''
__tablename__ = 'power_plant'
id = Column(Integer, primary_key=True, autoincrement=True)
power_plant_date = Column(String(128))
month_power_generation = Column(Float)
unit = Column(String(128))
3.应用
from config.mysql_config import get_db
from pojo.entities import PowerPlant
@app.get("/hi")
async def say_hello():
with get_db() as db: # 确保会话被正确管理
powerPlant = PowerPlant(power_plant_date="2024-11-11", month_power_generation=12.21, unit="机组")
db.add(powerPlant)
db.commit()
return {"message": f"添加成功"}
CRUD
from fastapi import FastAPI,Query
import uvicorn
from sqlalchemy.orm import Session
from sqlalchemy import and_
app = FastAPI()
@app.get("/")
async def root():
return {"message": "Hello World"}
@app.get("/hello/{name}")
async def say_hello(name: str):
return {"message": f"Hello {name}"}
from config.mysql_config import get_db
from models.entities import PowerPlant
from pojo.entities import PowerPlantPo
@app.get("/hi")
async def say_hello():
with get_db() as db: # 确保会话被正确管理
powerPlant = PowerPlant(power_plant_date="2024-11-11", month_power_generation=12.21, unit="机组")
db.add(powerPlant)
db.commit()
return {"message": f"添加成功"}
@app.get("/getAll")
async def get_all_power_plants():
with get_db() as db:
powerPlants = db.query(PowerPlant).all() # 查询所有记录
return [{"id": plant.id,
"power_plant_date": plant.power_plant_date,
"month_power_generation": plant.month_power_generation,
"unit": plant.unit} for plant in powerPlants]
@app.put("/update_power_plant")
async def update_power_plant(powerPlant: PowerPlantPo):
with get_db() as db:
db_power_plant = db.query(PowerPlant).filter(PowerPlant.id == powerPlant.id).first()
if db_power_plant:
db_power_plant.power_plant_date = powerPlant.power_plant_date
db_power_plant.month_power_generation = powerPlant.month_power_generation
db_power_plant.unit = powerPlant.unit
db.commit() # 提交事务
db.refresh(db_power_plant) # 确保数据已经被提交并刷新到数据库
print(f"Updated power plant: {db_power_plant}") # 打印确认是否更新
return {"message": "更新成功"}
return {"message": "未找到该电厂信息"}
@app.post("/addpow")
async def add_power_plant(powerPlant: PowerPlantPo):
with get_db() as db:
powerPlant = PowerPlant(power_plant_date=powerPlant.power_plant_date,month_power_generation=powerPlant.month_power_generation,unit=powerPlant.unit)
db.add(powerPlant) # 将实体添加到会话
db.commit() # 提交事务
return {"message": "添加成功"}
@app.delete("/deletepower/{plant_id}")
async def delete_power_plant(plant_id: int):
with get_db() as db:
powerPlant = db.query(PowerPlant).filter(PowerPlant.id == plant_id).first()
if powerPlant:
db.delete(powerPlant) # 删除数据
db.commit() # 提交事务
return {"message": "删除成功"}
return {"message": "未找到该电厂信息"}
@app.get("/powerplants")
async def get_power_plants(
page: int = 1, # 默认第一页
size: int = 10, # 默认每页10条
name: str = Query(None), # 过滤条件:电厂名称
location: str = Query(None) # 过滤条件:电厂位置
):
offset = (page - 1) * size # 计算偏移量
with get_db() as db:
# 构造查询过滤条件
query_filters = []
if name:
query_filters.append(PowerPlant.unit.like(f"%{name}%"))
if location:
query_filters.append(PowerPlant.power_plant_date.like(f"%{location}%"))
# 使用and_将多个过滤条件组合
query = db.query(PowerPlant).filter(and_(*query_filters)) if query_filters else db.query(PowerPlant)
# 获取总数
total_count = query.count()
# 获取分页后的数据
power_plants = query.offset(offset).limit(size).all()
# 计算总页数
total_pages = (total_count + size - 1) // size
return {
"total_count": total_count,
"total_pages": total_pages,
"page": page,
"size": size,
"power_plants": power_plants
}
if __name__ == '__main__':
uvicorn.run(app)
``
## 多个文件对应flask的蓝图
### 注册多个文件
```python
from fastapi import APIRouter
from config.mysql_config import get_db
from models.entities import PowerPlant
router = APIRouter(
prefix="/user",
tags=["user"],
responses={404: {"description": "Not found"}}
)
@router.get('/getAll',tags=["users"])
def read_pow():
with get_db() as db:
powerPlants = db.query(PowerPlant).all() # 查询所有记录
return [{"id": plant.id,
"power_plant_date": plant.power_plant_date,
"month_power_generation": plant.month_power_generation,
"unit": plant.unit} for plant in powerPlants]
在main.py中声明
from service.user_service import router as user_router
app = FastAPI()
app.include_router(user_router)