【2023.07.24-7.25】pytorch basics:autograd, optimization, save and load model

文章介绍了PyTorch中的autograd机制,包括正向传播和反向传播过程,以及backward函数的使用。同时,讨论了优化过程,如超参数、优化器(如SGD)的使用,以及训练和验证循环的实现。此外,还提到了模型权重的保存和加载方法,以及model.train()和model.eval()的作用。
部署运行你感兴趣的模型镜像

1.autograd

1)计算图是一个DAG有向无环图。

正向传播:autograd包含两个步骤,首先计算tensor结果,维护操作的.grad_fn。

反向传播:显示的调用backward函数后,包含三个步骤:首先用.grad_fn计算梯度值;然后所有requires_grad属性为True的结点的梯度(即结点的.grad属性)被累积更新;最终利用链式法则,反向传播。

2)backward(gradient,retain_graph)

(1)gradient:

对于标量结点,参数可选,此时参数被默认为torch.Tensor([1])。

对于非标量结点,则必需要gradient。且gradient尺寸需要和输出尺寸相同,

import torch
z = torch.matmul(x, w)+b
loss = torch.nn.functional.binary_cross_entropy_with_logits(z, y)
loss.backward()

(2)retain_graph:一个计算图只能调用一次backward函数。如果需要对同一张图多次调用,则需要backward(retain_graph=True).并且多次调用的梯度会累计。

3)取消gradient tracking

为了固定网络参数或加速计算省去后向传播,所以会不记录部分tensor梯度。

通过以下两种方式设置。用z_det=z.detach()定义的tensor或with torch.no_grad() 语句被包裹的结点。

z=torch.matmul(x,a)+b
z_det=z.detach()
with torch.no_grad():
    z_det2=torch.matmul(x,a)+b

4)查看梯度和grad_fn

(1)结点梯度:print(x.grad)

(2)梯度函数:如果是直接创建tensor,则.grad_fn=None.如果是y=x+x,则.grad_fn=AddBackward

2.optimization

1) hyperparam

number of epochs:迭代模型的次数

batch size:每次迭代需要的样本数

learning rate:每个epoch更新模型的程度。值越小,模型学习速度越慢。值越大,模型容易出现不可预测的行为。

2)optimization loop

training loop:在训练集上迭代,目标是使模型收敛

validation loop:在验证集上迭代,验证模型效果。

3)优化器

优化算法被封装在优化器对象中。在初始化优化器时,我们把需要优化的模型参数、超参数作为初始化时的参数传给优化器。

optimizer=torch.optim.SGD(model.parameters(),lr=learning_rate)

在每个批次之间,为了防止模型梯度叠加,使用optimizer.zero_grad()清零梯度

在使用loss.backward()更新梯度后,使用optim.step()使用梯度更新模型参数

4)training loop和validation loop的实现

def train_loop(dataloader,model,optimizer,loss_fn)
    model.train()
    size=len(dataloader.dataset)
    for batch,(X,y) in enumerate(dataloader):
        pred=model(X)
        loss=loss_fn(pred,y)
        loss.backward()
        optim.step()
        optim.zero_grad()

    if batch%100==0:
        loss,current=loss.item(),(batch+1)*len(X)
        print(f"train loss:{loss:>7f} [{current position:>5d}/{size:5d}]"))
#每个loss是在每个batch上计算,test_loop中loss需要求和因为需要求多个loss的平均,但train_loop是每100个batch才打印
#因为loss_fn的返回值是一个标量,pred才是和输出相同维度的。


def test_loop(dataloader,model,loss_fn)
    model.eval()
    test_loss,correct=0,0 #因为test_loop函数被多次调用,且两个变量累加。所以需要清零。
    size=len(dataloader.dataset)
    num_batches=len(dataloader)
    with torch.no_grad(): #节约资源不用再计算梯度
        for (X,y) in dataloader:
            pred=model(X)
            loss=loss_fn(pred,y)
            test_loss+=loss.item()
            correct+=(pred.argmax(1)==y).type(torch.float).sum().item()
            #一个样本的预测标签为每一行中的最大值,即pred.argmax(1),列为种类、行为为一个batch中的样本。
    correct/=size
    test_loss/=num_batches #在多个batch上的平均
    print(f"correctness:"crrect/batch_size,"train loss:"loss)

epochs=10
loss_fn=nn.CrossEntropyLoss()
for i in range(epochs):
    print(f"Epoch"i+1"-------------")
    train_loop(training_dataloader,model,optimizer,loss_fn)
    test_loop(testing_dataloader,model,loss_fn)
print("done")

3. save and load model

1)只是用于推理:只保存模型权重

#save model weights
torch.save(model.state_dict(),'model_weight.pth')
#load model weights
model.load_state_dict(torch.load('model_weights.pth'))
#pytorch load save函数可以保存任何文件,保存的时候只保存了state_dict,则取出的时候也调用load_state_dict


2)需要继续训练模型:保存整个模型

state = {
    'epoch': epoch,
    'state_dict': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    ...
}
#save model including optimizer,epoch,score
torch.save(state,filepath)

#load model including optimizer,epoch,score
model=torch.load(state,filepath)

------------

1.model.tran() model.eval()作用

由于之前model = NeuralNetwork(),因此model是自定义类NeuralNetwork的实例。而NeuralNetwork是nn.Module的子类。nn.Module中包含属性.train();.eval()

表示模型在不同的状态下。在model.train()下模型会根据不同批的结果,更新梯度。但在model.eval()下不会,会关闭BN,dropout。具体解释dropout参照deep learning 7.12:deeplearningbook.org/contents/regularization.html, 解释BN参照deep learning 8.1.3:deeplearningbook.org/contents/optimization.html

2.model(X)

等价于调用__call__(),该函数调用model.forward(X)+hooks。

3.optimizer.zero_grad()

为什么清零模型梯度需要通过调用optimizer的属性呢?当把模型所有参数都在定义优化器时,给了同一个优化器,则optimizer.zero_grad()效果等于model.zero_grad().除非一个模型使用了多个优化器优化。

4.dataloader类

将dataset和决定如何从数据集中获取样本的sampler结合在一起。

.num_workers定义cpu加载数据需要子进程的数量。会尽可能均匀的将batch_size个样本平均分配给num_workers。num_workers数量越大,对内存消耗越大,模型能使用的内存越少;num_workers数量太小时,Gpu必须等待cpu加载数据。结合IO、内存使用情况。所以num_worker数量需要结合batch_size数量选择,最好和训练速度匹配。当num_workers数量为0时,只由主进程加载数据。

.dataset:获取定义dataloader的数据集。

5.为什么要设计batch和epoch

epoch数量是神经网络在整个数据集上训练的次数。

batch是把整个数据集分块的块数。batch的存在是由于一次把整个数据集太大,不能整个输入给模型。

7 tensor.type(dtype)

tensor的数据类型的强制转换。torch.float,torch.double等

您可能感兴趣的与本文相关的镜像

PyTorch 2.5

PyTorch 2.5

PyTorch
Cuda

PyTorch 是一个开源的 Python 机器学习库,基于 Torch 库,底层由 C++ 实现,应用于人工智能领域,如计算机视觉和自然语言处理

<think> 我们面对的是WebSocket连接时出现的400 Bad Request错误。结合引用内容和之前的分析,我们可以从以下几个方面解决: 1. **请求头格式问题**:引用[1]和[2]都提到了请求格式不正确可能导致400错误。在WebSocket连接中,请求头必须符合HTTP协议规范,且每个头部字段的格式应为`Key: Value`,并以列表形式传递(而不是字典)。此外,需要检查是否有非法字符或格式错误。 2. **URL编码问题**:引用[4]和[5]指出,URL中如果包含非ASCII字符(如中文)或特殊字符,需要进行URL编码。虽然我们的URL中没有明显的中文,但可能包含其他需要编码的字符。我们可以检查URL是否需要编码。 3. **参数传递方式**:引用[3]提到,当使用POST请求时,将`data`参数换成`json`参数可以解决400错误。在WebSocket连接中,我们通过请求头传递参数,因此需要确保参数传递的方式正确。 4. **认证信息有效性**:检查JWT token是否过期(根据日志,token的有效期在2025年,所以暂时未过期)以及格式是否正确。 5. **服务器配置问题**:引用[2]提到服务器返回400错误可能是服务器配置问题,比如服务器不支持WebSocket协议或路由配置错误。 ### 解决方案步骤: #### 1. 修正请求头格式 在`websocket.WebSocketApp`中,`header`参数需要是字符串列表,每个字符串是`"Key: Value"`格式。确保每个头部字段的键和值之间有一个空格(例如`"client-id: SN90230307-1764038084155"`)。 #### 2. 对URL进行编码 虽然URL中可能没有非ASCII字符,但为了安全,我们可以对URL路径部分进行编码(使用`urllib.parse.quote`,但注意不要对协议、主机和端口部分编码)。 #### 3. 检查认证信息 确保token没有换行符或其他非法字符,并且符合JWT格式(三个点分隔的部分)。 #### 4. 使用wscat工具测试 使用wscat命令行工具测试连接,以确定是客户端代码问题还是服务器问题: ```bash wscat -c "ws://81.71.143.13:8083/mqtt" -H "username: test" -H "password: eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJzdWIiOiJoZy11YXYtY2xvdWQtYXBpIiwibmJmIjoxNzY0MDM4MDg3LCJpc3MiOiJESkkiLCJhY2wiOnsic3ViIjpbInJvYm90L1NOOTAyMzAzMDcvZHJjL2Rvd24iXSwicHViIjpbInJvYm90L1NOOTAyMzAzMDcvZHJjL3VwIl19LCJleHAiOjE3NjQwNDE2ODcsImlhdCI6MTc2NDAzODA4N30.ZVrpEySAJz18mqIktGN3ilZS8-lVAzdNm64uZLa1z6w" -H "client-id: SN90230307-1764038084155" ``` #### 5. 修改Python代码 根据以上分析,修改Python代码: ```python #!/usr/bin/env python3 import websocket import json import logging import time import urllib.parse # 配置日志 logging.basicConfig( level=logging.DEBUG, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' ) logger = logging.getLogger("WebSocketTest") websocket.enableTrace(True) # 启用详细追踪 def test_websocket_connection(): # 构建URL,确保路径部分正确编码(如果包含特殊字符) base_url = 'ws://81.71.143.13:8083/mqtt' # 对路径部分编码(如果路径中有特殊字符) # 由于路径是/mqtt,没有特殊字符,所以不需要编码,但如果有参数则不同 # 如果有查询参数,则需要对查询参数进行编码,但这里没有查询参数 config = { 'address': base_url, 'username': 'test', 'password': 'eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJzdWIiOiJoZy11YXYtY2xvdWQtYXBpIiwibmJmIjoxNzY0MDM4MDg3LCJpc3MiOiJESkkiLCJhY2wiOnsic3ViIjpbInJvYm90L1NOOTAyMzAzMDcvZHJjL2Rvd24iXSwicHViIjpbInJvYm90L1NOOTAyMzAzMDcvZHJjL3VwIl19LCJleHAiOjE3NjQwNDE2ODcsImlhdCI6MTc2NDAzODA4N30.ZVrpEySAJz18mqIktGN3ilZS8-lVAzdNm64uZLa1z6w', 'client_id': 'SN90230307-1764038084155' } # 构建正确的header列表 headers = [ f"username: {config['username']}", f"password: {config['password']}", f"client-id: {config['client_id']}" ] def on_message(ws, message): logger.info(f"收到消息: {message}") def on_error(ws, error): logger.error(f"WebSocket错误: {error}") # 获取更多错误详情 if hasattr(error, 'status_code'): logger.error(f"HTTP状态码: {error.status_code}") if hasattr(error, 'headers'): logger.error(f"响应头: {error.headers}") def on_close(ws, close_status_code, close_msg): logger.info(f"连接关闭: {close_status_code} - {close_msg}") def on_open(ws): logger.info("连接成功建立") # 订阅主题 subscribe_msg = { "method": "subscribe", "data": { "topics": ["robot/SN90230307/drc/down"] } } ws.send(json.dumps(subscribe_msg)) logger.info("已发送订阅请求") logger.info(f"尝试连接WebSocket: {config['address']}") try: ws = websocket.WebSocketApp( config['address'], header=headers, # 这里传递的是列表 on_open=on_open, on_message=on_message, on_error=on_error, on_close=on_close ) # 设置更长的超时时间 ws.run_forever( ping_interval=30, ping_timeout=10, reconnect=5 # 自动重连 ) except Exception as e: logger.exception("发生未捕获的异常") if __name__ == "__main__": test_websocket_connection() ``` ### 其他可能的原因和尝试: 1. **服务器端路由配置**:确保服务器端的WebSocket端点正确配置为`/mqtt`。根据引用[2],服务器返回400可能是由于请求的路径不存在或不允许WebSocket升级。 2. **使用wss协议**:如果服务器要求安全连接,应使用`wss`而不是`ws`。尝试将地址改为`wss://81.71.143.13:8083/mqtt`。 3. **检查服务器防火墙和端口**:确保端口8083在服务器上是开放的,并且防火墙允许通过。 4. **服务器日志**:查看服务器日志,因为400错误的具体原因通常会在服务器日志中记录。 ### 总结步骤: 1. 使用修正后的Python代码(确保请求头格式为列表)。 2. 尝试使用wscat命令行测试,以区分是代码问题还是服务器问题。 3. 如果wscat同样失败,则检查服务器配置和网络环境;如果wscat成功,则检查代码中其他部分。 4. 尝试使用wss协议。 5. 检查服务器日志获取更详细的错误信息。 §§相关问题§§ 1. 如何使用Python的websocket-client库实现WebSocket通信? 2. 如何诊断和修复WebSocket连接中的400 Bad Request错误? 3. 在WebSocket连接中如何正确传递认证信息? 4. 如何配置Nginx代理WebSocket连接? 5. WebSocket连接中常见的错误代码有哪些?如何解决? 希望这些步骤能帮助你解决问题。 </think># 解决WebSocket连接400 Bad Request错误的完整指南 ## 问题分析 您在连接WebSocket服务`ws://81.71.143.13:8083/mqtt`时遇到HTTP 400 Bad Request错误,这通常表示客户端发送的请求在服务器端被认为是无效的。根据您的日志和引用内容,可能的原因包括: 1. **请求头格式不正确** - 使用字典格式而非列表格式(引用[1]) 2. **URL编码问题** - 特殊字符未正确编码(引用[4][5]) 3. **认证信息无效** - JWT token格式或内容问题 4. **服务器配置问题** - 路由或代理设置错误(引用[2]) 5. **参数传递方式错误** - 使用data而非json传递参数(引用[3]) ## 完整解决方案 ### 1. 修正请求头格式(关键修复) ```python # 错误方式:使用字典 headers = { "username": "test", "password": "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9...", "client-id": "SN90230307-1764038084155" } # 正确方式:使用列表 headers = [ "username: test", # 注意冒号后有空格 "password: eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9...", "client-id: SN90230307-1764038084155" # 注意连字符改为下划线 ] ``` ### 2. 确保URL正确编码 ```python from urllib.parse import quote # 对特殊字符进行编码 topic = quote("robot/SN90230307/drc/down", safe='') url = f"ws://81.71.143.13:8083/mqtt?topic={topic}" ``` ### 3. 验证和刷新认证token ```python import jwt def validate_token(token): try: # 解析JWT token decoded = jwt.decode(token, options={"verify_signature": False}) print(f"Token有效期至: {decoded['exp']}") return True except jwt.ExpiredSignatureError: print("Token已过期") return False except Exception as e: print(f"Token验证失败: {str(e)}") return False # 使用前验证token if not validate_token(config['password']): print("需要刷新token") ``` ### 4. 使用wscat命令行工具测试 ```bash # 安装wscat npm install -g wscat # 测试连接 wscat -c "ws://81.71.143.13:8083/mqtt" \ -H "username: test" \ -H "password: eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9..." \ -H "client-id: SN90230307-1764038084155" ``` ### 5. 完整修正后的Python代码 ```python #!/usr/bin/env python3 import websocket import json import logging import time import urllib.parse import jwt # 配置日志 logging.basicConfig( level=logging.DEBUG, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' ) logger = logging.getLogger("WebSocketTest") websocket.enableTrace(True) def validate_token(token): try: decoded = jwt.decode(token, options={"verify_signature": False}) logger.info(f"Token有效期至: {decoded['exp']}") return True except Exception as e: logger.error(f"Token验证失败: {str(e)}") return False def test_websocket_connection(): config = { 'address': 'ws://81.71.143.13:8083/mqtt', 'username': 'test', 'password': 'eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJzdWIiOiJoZy11YXYtY2xvdWQtYXBpIiwibmJmIjoxNzY0MDM4MDg3LCJpc3MiOiJESkkiLCJhY2wiOnsic3ViIjpbInJvYm90L1NOOTAyMzAzMDcvZHJjL2Rvd24iXSwicHViIjpbInJvYm90L1NOOTAyMzAzMDcvZHJjL3VwIl19LCJleHAiOjE3NjQwNDE2ODcsImlhdCI6MTc2NDAzODA4N30.ZVrpEySAJz18mqIktGN3ilZS8-lVAzdNm64uZLa1z6w', 'client_id': 'SN90230307-1764038084155' } # 验证token if not validate_token(config['password']): logger.error("认证token无效,请刷新token") return # 构建正确的header列表 headers = [ f"username: {config['username']}", f"password: {config['password']}", f"client-id: {config['client_id']}" ] def on_message(ws, message): logger.info(f"收到消息: {message}") # 处理消息逻辑 def on_error(ws, error): if hasattr(error, 'status_code'): logger.error(f"HTTP状态码: {error.status_code}") if hasattr(error, 'headers'): logger.error(f"响应头: {error.headers}") logger.error(f"WebSocket错误: {error}") # 尝试自动恢复连接 time.sleep(5) logger.info("尝试重新连接...") connect_websocket() def on_close(ws, close_status_code, close_msg): logger.info(f"连接关闭: {close_status_code} - {close_msg}") def on_open(ws): logger.info("连接成功建立") # 订阅主题(URL编码) topic = urllib.parse.quote("robot/SN90230307/drc/down", safe='') subscribe_msg = { "method": "subscribe", "data": { "topics": [topic] } } ws.send(json.dumps(subscribe_msg)) logger.info(f"已发送订阅请求: {topic}") def connect_websocket(): logger.info(f"尝试连接WebSocket: {config['address']}") ws = websocket.WebSocketApp( config['address'], header=headers, on_open=on_open, on_message=on_message, on_error=on_error, on_close=on_close ) # 设置超时和重连参数 ws.run_forever( ping_interval=30, ping_timeout=10, reconnect=5 ) connect_websocket() if __name__ == "__main__": test_websocket_connection() ``` ### 6. 服务器端配置检查 如果上述方法无效,可能是服务器端配置问题: 1. **Nginx配置**: ```nginx location /mqtt { proxy_pass http://backend; proxy_http_version 1.1; proxy_set_header Upgrade $http_upgrade; proxy_set_header Connection "upgrade"; proxy_set_header Host $host; proxy_set_header X-Real-IP $remote_addr; } ``` 2. **WebSocket路由处理**(引用[2]): ```python # Flask示例 @app.route('/mqtt') def handle_mqtt(): if request.headers.get('Upgrade', '').lower() == 'websocket': # WebSocket处理逻辑 return handle_websocket() else: abort(400, description="WebSocket connection required") ``` ## 调试技巧 1. **使用HTTP客户端测试**: ```python import requests response = requests.get( "http://81.71.143.13:8083/mqtt", headers={ "Upgrade": "websocket", "Connection": "Upgrade", "Sec-WebSocket-Version": "13", "Sec-WebSocket-Key": "test123" } ) print(response.status_code, response.headers) ``` 2. **网络诊断命令**: ```bash # 测试端口连通性 telnet 81.71.143.13 8083 # 跟踪路由 traceroute 81.71.143.13 # 检查本地防火墙 sudo ufw status ``` ## 总结 解决WebSocket 400 Bad Request错误的关键步骤: 1. 确保使用正确的请求头格式(列表而非字典) 2. 验证URL特殊字符正确编码 3. 检查认证token有效性 4. 使用wscat命令行工具验证连接 5. 检查服务器端配置(Nginx/应用程序) 6. 实现自动重连机制增强鲁棒性 如果问题仍未解决,建议检查服务器日志获取更详细的错误信息,通常400错误会包含服务器拒绝请求的具体原因[^1][^2][^3]。
评论 2
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值