pytorch中的深拷贝与浅拷贝

文章介绍了Python中浅拷贝和深拷贝的区别,通过torch库的示例,展示了如何通过不同的方法实现浅拷贝和深拷贝,以及它们在内存地址和数据同步上的影响。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

在 Python 中,一切皆为对象,python中有可变对象与不可变对象,Python 中不存在值传递,python中所有变量都是引用,也可以认为是传址。每个变量名都是标签,他们不直接储存数据。可以使用id函数来查看变量id,id相同的变量指代同一内容。

不可变对象:字符串、元组、数字(int、float)
可变对象:数组、字典、集合

可变对象可以被修改,对可变变量的copy有浅拷贝和深拷贝,函数传参如果传的是可变变量,那么传入的参数是浅拷贝,改变该参数,外部参数值也会变化。如果传入参数是不可变对象,函数传参仍是引用,但是在函数内部修改时,会重新创建一个,这会产生深拷贝的效果。

深拷贝:赋值前后两个变量内存地址不同,一个改变不会影响另外一个。

浅拷贝:赋值前后两个变量会共享data的内存地址,数据会同步变化。

python代码验证如下:

import copy
import numpy as np
import torch
import torch.nn as nn
def sum_tem(a, ch=0):
    # 函数传入为引用,浅拷贝
    if ch == 0:
        b = a  # 浅拷贝
        b += 1  # 浅拷贝,就地修改运算会改变原来的值
    elif ch == 1:
        b = a + 0  # 深拷贝,右边进行运算
        b += 1
    elif ch == 2:
        b = a
        b = b + 0  # 深拷贝,右边进行运算
        b += 1
    elif ch == 3:
        b = np.array(a)  # 深拷贝
        # b = a.numpy()  # 浅拷贝,.numpy()是浅拷贝
        # b = torch.from_numpy(b)  # 浅拷贝
        b = torch.Tensor(b)  # 浅拷贝,.numpy()是浅拷贝
        b += 1
    elif ch == 4:
        b = a.clone()  # 深拷贝,torch.clone()
        b += 1
    elif ch == 5:
        b = a.view(a.shape)  # 浅拷贝,.view()
        b += 1
    elif ch == 6:
        b = a.permute(1, 0)  # 浅拷贝,.permute()
        b += 1
    elif ch == 7:
        b = copy.copy(a)  # 浅拷贝
        b += 1
    elif ch == 8:
        b = copy.deepcopy(a)  # 深拷贝
        b += 1
    elif ch == 9:
        b = a[:, :]  # 浅拷贝
        b += 1
    elif ch == 10:
        b = a.contiguous()  # 浅拷贝
        b += 1
    elif ch == 11:
        b = a.detach()  # 浅拷贝
        b += 1
    elif ch == 12:
        b = a.transpose(0, 1)  # 浅拷贝
        b += 1
    elif ch == 13:
        b = a.contiguous()  # 如果地址连续,使用contiguous是浅拷贝
        b += 1
    elif ch == 14:
        c = a.transpose(0, 1)
        b = c.contiguous()  # 如果进行其他操作后,地址不连续的话,使用contiguous是深拷贝
        b += 1
    elif ch == 15:
        c = a.contiguous()  # 如果地址连续,使用contiguous是浅拷贝
        b = c.transpose(0, 1)
        b += 1
    elif ch == 16:
        b = a.detach()  # .detach是浅拷贝,新的tensor会脱离计算图,不会牵扯梯度计算。
        b += 1
    else:
        pass
    return b
for i in range(17):
    a = torch.Tensor([[1, 3], [1, 1]])
    b = sum_tem(a, ch=i)
    if_deep_copy = not sum(sum(a)) == sum(sum(b))
    print(sum(a), sum(b), "ch={}是否是深拷贝:".format(i), if_deep_copy)

首先,函数传入为浅拷贝,在函数中改变参数值,也会影响函数外对应的值。其次,b += 1为原地内存操作,会改变b内存位置的数据。

如果操作是深拷贝,那么对b操作对a不会造成影响。

如果经过运算后sum(sum(a)) == sum(sum(b))为True,那么说明对a造成了影响,是浅拷贝,否则为深拷贝。

结论:

下面情况是深拷贝

  • 等号右方进行运算,b = a + 0 和 b = b + 0 都算
  • 转成 numpy 数组函数 np.array()
  • copy.deepcopy(a)
  • torch.clone() 或 .clone()
  • a.transpose(0, 1).contiguous()

下面情况时浅拷贝

  • 函数传参
  • 直接赋值 b = a
  • 就地运算符 b += 1
  • 转成 numpy 数组函数 .numpy()
  • numpy 数组转成 pytorch 张量 torch.from_numpy() 和 torch.Tensor()
  • 张量维度变换函数 .view() 、.permute()、.transpose()
  • .detach()

contiguous() 返回一个内存连续的有相同数据的tensor,如果原tensor内存连续,则返回原tensor,为浅拷贝,如果不连续,则为深拷贝。 一般与transpose,permute,view等方法搭配使用。

对 pytorch 张量维度的操作函数

transpose、permute、reshape、view

  • transpose、permute函数作用是交换张量的维度,在维度变换操作后,tensor在内存中不再是连续存储的;transpose 只能一次转换两个维度,permute 可以一次转换多个维度;permute 和 transpose 只能在已有的维度之间转换,可以实现转置操作。

  • reshape 和 view函数作用是重新设置维度,包括拆分一个维度成多个连续的,或者组合多个相邻的维度成一个;可以创建新维度。

  • view操作要求tensor的内存连续存储,如果使用了上面transpose、permute函数操作了张量,导致内存不连续,就需要contiguous来返回一个contiguous copy 使内存连续,才能进行view;torch.reshape() 不需要内存连续,其大致相当于 tensor.contiguous().view()。

参考链接:
pytorch中reshape()、view()、permute()、transpose()总结

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值