1.广播机制
- 如果两个张量大小不相同,运算时张量参数可以自动扩展为相同大小
- 广播机制需要满足的条件
- 每个张量至少有一个维度
- 满足右对齐
import torch
a = torch.rand(2,3)
b = torch.rand(3)
c = a + b
print("a:\n",a)
print("b:\n",b)
print("c:\n",c)
print("c.shape:",c.shape)
2.取整/取余运算
import torch
a = torch.rand(2,3)
a = a * 10
print("a:\n",a)
#向下取整
print("向下取整:\n",torch.floor(a))
#向上取整
print("向上取整:\n",torch.ceil(a))
#四舍五入
print("四舍五入:\n",torch.round(a))
#只取整数部分
print("只取整数部分:\n",torch.trunc(a))
#只取小数部分
print("只取小数部分:\n",torch.frac(a))
#取余
print("取余:\n",a % 2)
3.比较运算
- torch.eq():给每个数据进行等式操作,相同的返回True,不相同返回False
- torch.equal():如果两个tensor有相同的大小和元素,则返回True
- torch.ge():大于等于
- torch.gt():大于
- torch.le():小于等于
- torch.lt():小于
- torch.ne():不等于
import torch
a = torch.rand(2,3)
b = torch.rand(2,3)
print("a:\n",a)
print("b:\n",b)
print('########################')
print("eq:\n",torch.eq(a,b))
print("equal:\n",torch.equal(a,b))
print("ge:\n",torch.ge(a,b))
print("gt:\n",torch.gt(a,b))
print("le:\n",torch.le(a,b))
print("lt:\n",torch.lt(a,b))
print("ne:\n",torch.ne(a,b))
4.排序
- torch.sort(input,dim,descending,out)
- 对tensor进行排序
- 参数input:进行操作的tensor
- 参数dim:维度设置
- 参数descending:True为降序,False为升序
- 参数out:用于指定输出结果的存储位置
import torch
a = torch.tensor([[1,4,4,3,5],[2,3,1,7,5]])
print("a:\n",a)
print('########################')
print(torch.sort(a,dim=0,descending=True))
- torch.topk(input,dim,largest,sorted,out)
- 沿着指定维度返回k个最值及其索引值
- 参数k:取几个最值
- 参数largest:True为取最大值,False为取最小值
- 其他参数同上
import torch
a = torch.tensor([[1,4,4,3,5],[2,3,1,7,5]])
print("a:\n",a)
print('########################')
print(torch.topk(a,k=2,dim=1,largest=True,sorted=False))
- torch.kthvalue(input,k,dim,out)
- 沿着指定维度返回第k个最小值及其索引值
- 参数k:取第几个最值
- 其他参数同上
import torch
a = torch.tensor([[1,4,4,3,5],[2,3,1,7,5]])
print("a:\n",a)
print('########################')
print(torch.kthvalue(a,k=2,dim=0))
print(torch.kthvalue(a,k=2,dim=1))
5.判断是否有界、无界或为none
import torch
a = torch.rand(2,3)
print("a:\n",a)
print('########################')
print("判断是否有界:\n",torch.isfinite(a))
print("判断是否无界:\n",torch.isinf(a))
print("判断是否空:\n",torch.isnan(a))
6.三角函数
import torch
a = torch.zeros(2,3)
print("a:\n",a)
print('########################')
print("余弦值:\n",torch.cos(a))
print("正弦值:\n",torch.sin(a))
print("正切值:\n",torch.tan(a))
知识点为听课总结笔记,课程为B站“2025最新整合!公认B站讲解最强【PyTorch】入门到进阶教程,从环境配置到算法原理再到代码实战逐一解读,比自学效果强得多!”:2025最新整合!公认B站讲解最强【PyTorch】入门到进阶教程,从环境配置到算法原理再到代码实战逐一解读,比自学效果强得多!_哔哩哔哩_bilibili