【python】getattr 和 setattr的用法,以及通过它们来动态的定义pytorch模块

1. 关于getattr 和 setatt 用法

在 Python 中,getattr 和 setattr 是两个非常有用的内置函数,它们允许我们以动态的方式访问和修改对象的属性。这些函数为我们提供了在运行时通过字符串来操作对象属性的强大能力,可以使代码更加灵活和可扩展。本文将详细介绍这两个函数的作用、用法以及实际应用场景。

a.getattr 函数

作用

getattr 用于获取对象的属性值。通过给定对象和属性名,getattr 可以在运行时访问指定的属性。

语法

getattr(object, name[, default])

object:要访问的对象。
name:属性名,必须是字符串。
default(可选):如果指定的属性不存在,返回该默认值。如果未提供,getattr 会抛出 AttributeError。
返回值
返回对象指定属性的值。如果属性不存在且未提供默认值,将抛出 AttributeError。
示例

class MyClass:
    def __init__(self):
        self.x = 10

obj = MyClass()

# 获取属性 x
value = getattr(obj, 'x')
print(value)  # 输出: 10

# 如果属性不存在,使用默认值
value = getattr(obj, 'y', "default_value")
print(value)  # 输出: default_value

在上述代码中,我们首先通过 getattr 获取了 obj 对象的 x 属性。如果 x 属性不存在,我们还提供了一个默认值 default_value,这样可以避免抛出错误。

b.getattr 函数

作用
setattr 用于为对象的属性赋值。如果指定的属性不存在,setattr 会动态地创建该属性并赋值。

语法

setattr(object, name, value)

object:要操作的对象。
name:属性名,必须是字符串。
value:要赋的值。
返回值
setattr 没有返回值,但会直接修改对象的属性。
示例

class MyClass:
    pass

obj = MyClass()

# 动态设置属性 x
setattr(obj, 'x', 42)
print(obj.x)  # 输出: 42

# 修改已存在的属性
setattr(obj, 'x', 100)
print(obj.x)  # 输出: 100

在这个例子中,我们使用 setattr 动态地为 obj 对象添加了一个属性 x,并赋值为 42。随后,我们又修改了该属性的值为 100。

2. 利用getattr 和 setatt 定义动态的pytorch模块,参考如下代码:

class Slot_cross_attention(nn.Module):
    def __init__(self,
        latent_dim,
        context_dim=None,      
        cross_heads = 8, 
        dim_head = 64, 
        dropout = 0, 
        attn_type = 'transformer', 
        more_dropout = 0.1, 
        net_depth=1    #单层cross  attention中执行的次数
        ):
        super().__init__()    
        
        # 左右对称的两个attention
        self.net_depth = net_depth  # 保存 net_depth 为类属性
        for i in range(1, net_depth + 1):
            attention_blocks_l = nn.ModuleList([  
                Attention_sim(query_dim=latent_dim, context_dim=context_dim, heads=cross_heads,
                          dim_head=dim_head, dropout=dropout, attn_type=attn_type,
                          more_dropout=more_dropout),
                nn.LayerNorm(latent_dim),
                ThinFeedForward(dim=latent_dim),
                nn.LayerNorm(latent_dim),
            ])

            attention_blocks_r = nn.ModuleList([ 
                Attention_sim(query_dim=latent_dim, context_dim=context_dim, heads=cross_heads,
                          dim_head=dim_head, dropout=dropout, attn_type=attn_type,
                          more_dropout=more_dropout),
                nn.LayerNorm(latent_dim),
                ThinFeedForward(dim=latent_dim),
                nn.LayerNorm(latent_dim),
            ])

            setattr(self, f"attention_blocks_l_{i}", attention_blocks_l)
            setattr(self, f"attention_blocks_r_{i}", attention_blocks_r)

            self.fusion_linear=nn.Linear(latent_dim*2, latent_dim, bias=False)
   
    def forward(self,x_l,x_r): 

        for i in range(1, self.net_depth + 1):  # 遍历生成的 attention_blocks
            blocks_l = getattr(self, f"attention_blocks_l_{i}")
            blocks_r = getattr(self, f"attention_blocks_r_{i}")

            attention_l, norm_l1, feed_forward_l,norm_l2 = blocks_l  # 解包 blocks_l
            x1 = attention_l(x_l, x_r)                            # 按顺序调用
            x1 = norm_l1(x1)
            x1 = feed_forward_l(x1)
            x1 = norm_l2(x1)

            attention_r, norm_r1, feed_forward_r,norm_r2 = blocks_r  # 解包 blocks_r
            x2 = attention_r(x_r,x_l)                              # 按顺序调用
            x2 = norm_r1(x2)
            x2 = feed_forward_r(x2)
            x2 = norm_r2(x2)

            x = torch.cat((x1,x2),dim=1)
            x = self.fusion_linear(x)
            
        return x
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值