PyTorch:torch.clamp()用法详解

torch.clamp函数用于限制输入Tensor的值在指定范围内,例如[min, max],确保输出始终在这个区间内。该函数常用于数值计算的边界处理。在提供的示例中,创建了一个3x3的Tensor a,通过torch.clamp限制其元素值在3到6之间,输出的b保持了相同的形状并应用了限幅操作。
该文章已生成可运行项目,

函数定义

torch.clamp(input, min, max, out=None)

作用:限幅。将input的值限制在[min, max]之间,并返回结果。out (Tensor, optional) – 输出张量,一般用不到该参数。

out参数的理解:很多torch函数有out参数,这主要是因为torch没有tf.cast()这类的类型转换函数,也少有dtype参数指定输出类型,所以需要事先建立一个输出Tensor为LongTensor、IntTensor等等,再由out导入。

参考:https://www.cnblogs.com/hellcat/p/8445372.html

举例说明

import torch

a = torch.arange(9).reshape(3, 3)   # 创建3*3的tensor
b = torch.clamp(a, 3, 6)     # 对a的值进行限幅,限制在[3, 6]
print('a:', a)
print('shape of a:', a.shape)
print('b:', b)
print('shape of b:', b.shape)


'''   输出结果   '''
a: tensor([[0, 1, 2],
           [3, 4, 5],
           [6, 7, 8]])
shape of a: torch.Size([3, 3])

b: tensor([[3, 3, 3],
           [3, 4, 5],
           [6, 6, 6]])
shape of b: torch.Size([3, 3])

 

本文章已经生成可运行项目
怎么在下面循环里面把I_0换成不同的图片,代码:for e in range(config.retrain_epoch): with tqdm(trainLoader, dynamic_ncols=True) as tqdmtrainLoader: for i, (images, labels) in enumerate(tqdmtrainLoader): # train snr = config.SNRs - CHDDIM_config.large_snr x_0 = images.to(config.device) I_0 = images.to(config.device) feature, _ = encoder(x_0) feature_I, _ = encoder(I_0) y_0 = feature y_I = feature_I y_main, pwr_main, h_main = pass_channel.forward(y_0, snr) # normalize y_interf, pwr_interf, h_interf = pass_channel.forward(y_I, snr) sigma_square = 1.0 / (2 * 10 ** (snr / 10)) # y = y / math.sqrt(1 + sigma_square) # 这里可能改一下 y_main_in = y_main / math.sqrt(1 + sigma_square) y_interf_in = y_interf / math.sqrt(1 + sigma_square) y_cat = y_main_in + y_interf_in h_cat = torch.cat([h_main, h_interf], dim =1) feature_hat = sampler(y_cat, snr, snr + CHDDIM_config.large_snr, h_cat, config.channel_type) feature_hat = feature_hat * torch.sqrt(pwr_main) x_0_hat = decoder(feature_hat) # mse1=torch.nn.MSEloss()() if config.loss_function == "MSE": loss = torch.nn.MSELoss()(x_0, x_0_hat) elif config.loss_function == "MSSSIM": loss = CalcuSSIM(x_0, x_0_hat.clamp(0., 1.)).mean() else: raise ValueError optimizer_decoder.zero_grad() loss.backward() optimizer_decoder.step() # optimizer_encoder.step() if config.loss_function == "MSE": mse = torch.nn.MSELoss()(x_0 * 255., x_0_hat.clamp(0., 1.) * 255) psnr = 10 * math.log10(255. * 255. / mse.item()) matric = psnr elif config.loss_function == "MSSSIM": msssim = 1 - CalcuSSIM(x_0, x_0_hat.clamp(0., 1.)).mean().item() matric = msssim tqdmtrainLoader.set_postfix(ordered_dict={ "dataset": config.dataset, "state": "train_decoder" + config.loss_function, "noise_schedule":CHDDIM_config.noise_schedule, "channel": config.channel_type, "CBR:": feature.numel() / 2 / x_0.numel(), "SNR": snr, "matric": matric, "T_max":CHDDIM_config.t_max }) if (e + 1) % config.retrain_save_model_freq == 0: torch.save(decoder.state_dict(), config.re_decoder_path)
08-27
评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

地球被支点撬走啦

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值