<think>我们首先需要理解用户的问题:用户希望了解PyTorch中创建张量的方法(特别是`scores.new_tensor(0)`),以及`torch.where`和`gather`函数的使用方法。
根据引用内容,我们可以参考:
引用[2]中提到了`torch.where`和`torch.masked_fill`等操作。
引用[1]和引用[4]中提到了`gather`和`scatter`函数。
下面我们分三部分来回答:
1. `scores.new_tensor(0)` 创建张量
2. `torch.where` 函数的使用方法
3. `gather` 函数的使用方法
注意:用户还提到了`scatter`,但问题中只要求了`gather`,不过由于`gather`和`scatter`是互逆操作,我们可以简单提及,但重点在`gather`。
第一部分:`scores.new_tensor(0)`
在PyTorch中,`new_tensor`是一个方法,它根据已有的张量(这里是`scores`)创建一个新的张量,新张量的数据类型和设备(CPU/GPU)与原有张量相同。
例如,`scores.new_tensor(0)`会创建一个标量张量,其值为0,并且这个新张量与`scores`在同一个设备上,具有相同的数据类型。
第二部分:`torch.where`
引用[2]中提到了`torch.where`,它的作用是根据条件选择两个张量中的元素。函数签名如下:
torch.where(condition, x, y)
其中condition是一个布尔张量,当condition的某个位置为True时,输出张量在该位置取x的对应位置的值,否则取y的对应位置的值。
第三部分:`gather`
引用[1]和引用[4]中提到了`gather`,它的作用是根据索引从输入张量中提取指定位置的元素。函数签名如下:
torch.gather(input, dim, index, out=None, sparse_grad=False)
它沿着指定的维度dim,根据索引index从输入张量中提取元素。索引张量index的形状必须与输出张量相同,且每个位置的值表示在dim维度上的索引。
我们将通过代码示例详细说明。
注意:在回答中,我们按照要求使用LaTeX格式的数学表达式(如张量形状等),并正确引用。</think>### PyTorch张量创建与操作详解
#### 1. `scores.new_tensor(0)` 创建张量
`new_tensor()`方法基于现有张量创建新张量,保留原始张量的设备类型和数据类型属性:
```python
import torch
# 原始张量(假设在GPU上)
scores = torch.tensor([1.0, 2.0], device='cuda', dtype=torch.float32)
# 创建新张量(继承设备类型和数据类型)
new_tensor = scores.new_tensor(0) # 标量0
print(new_tensor) # tensor(0., device='cuda:0')
print(new_tensor.device) # cuda:0
print(new_tensor.dtype) # torch.float32
```
**关键特性**:
- 继承原始张量的设备(CPU/GPU)和数据类型
- 输入可以是标量、列表、NumPy数组等
- 等价操作:`torch.tensor(0, device=scores.device, dtype=scores.dtype)`
#### 2. `torch.where()` 条件选择
`torch.where()`根据条件选择两个张量的元素,语法为:
```python
torch.where(condition, x, y) # 当condition为True时选x,否则选y
```
**示例:阈值处理**(引用[2]):
```python
t = torch.tensor([[3, 7], [2, 5], [9, 1]])
# 大于5的元素保留,否则置0
result = torch.where(t > 5, t, torch.zeros_like(t))
print(result)
# tensor([[0, 7],
# [0, 0],
# [9, 0]])
```
**高级应用**:
```python
# 创建掩码
mask = torch.tensor([[True, False], [False, True]])
a = torch.tensor([[1, 2], [3, 4]])
b = torch.tensor([[10, 20], [30, 40]])
# 根据掩码选择元素
combined = torch.where(mask, a, b)
print(combined) # tensor([[ 1, 20], [30, 4]])
```
#### 3. `torch.gather()` 索引收集
`gather()`根据索引在指定维度收集元素,语法为:
```python
output = torch.gather(input, dim, index)
```
**基本用法**(引用[1][4]):
```python
data = torch.tensor([[1, 2], [3, 4], [5, 6]]) # 3x2
index = torch.tensor([[0, 0], [1, 0], [0, 1]]) # 3x2
# 沿dim=1收集
result = torch.gather(data, 1, index)
print(result)
# tensor([[1, 1], # 取第0行: [1,2]中的[0,0]位置
# [4, 3], # 取第1行: [3,4]中的[1,0]位置
# [5, 6]]) # 取第2行: [5,6]中的[0,1]位置
```
**实际应用:one-hot编码**(引用[4]):
```python
labels = torch.tensor([2, 0, 1]) # 3个样本的类别标签
one_hot = torch.zeros(3, 3) # 3x3零矩阵
# 使用gather生成one-hot
one_hot.scatter_(1, labels.unsqueeze(1), 1)
print(one_hot)
# tensor([[0., 0., 1.],
# [1., 0., 0.],
# [0., 1., 0.]])
```
#### 4. 相关函数对比
| 函数 | 用途 | 特点 |
|------|------|------|
| `new_tensor()` | 创建继承属性的新张量 | 保留设备/数据类型[^1] |
| `torch.where()` | 条件选择元素 | 类似三元运算符[^2] |
| `torch.gather()` | 按索引收集数据 | 需指定维度[^1][^4] |
| `torch.scatter()` | 按索引写入数据 | `gather`的逆操作[^1][^4] |
> 引用说明:`new_tensor`继承设备类型和数据类型属性[^1];`where`实现条件选择[^2];`gather`和`scatter`是索引操作的互补函数[^1][^4]。
---