PyTorch: torch.Tensor.view ------ reshape

This post walks through PyTorch's Tensor.view and view_as methods: how to change a Tensor's shape without changing its number of elements, with examples showing what view does in practice.

torch.Tensor.view (Python method, in torch.Tensor)

Purpose: returns a new view of the input torch.Tensor with a different shape (size). The returned Tensor shares the input's underlying data and must have the same elements and the same number of elements; only the shape may differ.

In other words, view plays the role of reshape, and its arguments are the target shape. One difference from Tensor.reshape is worth knowing: view never copies data, so it only works when the requested shape is compatible with the tensor's memory layout (e.g. the tensor is contiguous); reshape returns a view when it can and silently falls back to a copy when it cannot.

Example:

>>> x = torch.randn(4, 4)
>>> x.size()
torch.Size([4, 4])
>>> y = x.view(16)
>>> y.size()
torch.Size([16])
>>> z = x.view(-1, 8)  # the size -1 is inferred from other dimensions
>>> z.size()
torch.Size([2, 8])
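
To illustrate the contiguity caveat above, a minimal sketch continuing the same REPL session: transposing x shares its data but makes the layout non-contiguous, so view fails while reshape copies.

>>> xt = x.t()                 # transpose: same data, non-contiguous layout
>>> xt.is_contiguous()
False
>>> xt.view(16)                # view cannot reinterpret this layout
Traceback (most recent call last):
  ...
RuntimeError: view size is not compatible with input tensor's size and stride ...
>>> xt.reshape(16).size()      # reshape falls back to a copy
torch.Size([16])
>>> xt.contiguous().view(16).size()  # or make the data contiguous, then view
torch.Size([16])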

view_as:

        tensor_1.view_as(tensor_2): reshapes tensor_1 to the same shape as tensor_2; it is equivalent to tensor_1.view(tensor_2.size()).
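
A minimal REPL sketch in the same style as above (any two shapes with the same element count work; here 2*8 = 4*4 = 16):

>>> a = torch.randn(2, 8)
>>> b = torch.zeros(4, 4)
>>> a.view_as(b).size()
torch.Size([4, 4])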

 

A reader comment (07-29) asked what the following code is doing. It is an excerpt; names such as QuantKey, GroupShape, QUANT_OPS, RESHAPE_OP, ATTN_OP, empty_bf16, empty_fp32 and pm are defined elsewhere in the quoted codebase:

class AttentionStaticQuantPattern:

    def __init__(
        self,
        layer_name: str,
        num_heads: int,
        head_size: int,
        quant_dtype: torch.dtype,
        symmetric=True,
    ):
        self.layer_name = layer_name
        self.num_heads = num_heads
        self.head_size = head_size
        self.quant_dtype = quant_dtype
        self.quant_key = QuantKey(dtype=quant_dtype,
                                  static=True,
                                  group_shape=GroupShape.PER_TENSOR,
                                  symmetric=symmetric)
        assert self.quant_key in QUANT_OPS, \
            f"unsupported quantization scheme {self.quant_key}"
        self.QUANT_OP = QUANT_OPS[self.quant_key]

    def empty_quant(self, *args, **kwargs):
        kwargs = {'dtype': self.quant_dtype, 'device': "cuda", **kwargs}
        return torch.empty(*args, **kwargs)

    def register_if_supported(self, pm_pass: PatternMatcherPass,
                              layer: Attention):
        if layer.impl.fused_output_quant_supported(
                self.quant_dtype, self.quant_key.static,
                self.quant_key.group_shape):
            self._register(pm_pass)

    def _register(self, pm_pass: PatternMatcherPass):

        def pattern(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                    output_attn: torch.Tensor, output_quant: torch.Tensor,
                    scale: torch.Tensor):
            # unfused form: attention into a plain buffer,
            # then a separate static-quant op
            view_7 = RESHAPE_OP(output_attn,
                                [-1, self.num_heads, self.head_size])
            at1 = auto_functionalized(ATTN_OP,
                                      query=q,
                                      key=k,
                                      value=v,
                                      output=view_7,
                                      layer_name=self.layer_name,
                                      output_scale=None)
            attn_out_view = RESHAPE_OP(at1[1],
                                       [-1, self.num_heads * self.head_size])
            at2 = auto_functionalized(self.QUANT_OP,
                                      result=output_quant,
                                      input=attn_out_view,
                                      scale=scale)
            return at2[1]

        def replacement(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                        output_attn: torch.Tensor, output_quant: torch.Tensor,
                        scale: torch.Tensor):
            # fused form: attention writes the quantized output directly,
            # with the scale passed as output_scale
            view_7 = RESHAPE_OP(output_quant,
                                [-1, self.num_heads, self.head_size])
            at1 = auto_functionalized(ATTN_OP,
                                      query=q,
                                      key=k,
                                      value=v,
                                      output=view_7,
                                      layer_name=self.layer_name,
                                      output_scale=scale)
            return RESHAPE_OP(at1[1], [-1, self.num_heads * self.head_size])

        # Need custom fake mode, otherwise tracing happens with real tensors.
        # That would not work for the unified_attention custom op.
        with unset_fake_temporarily(), FakeTensorMode():
            inputs = [
                empty_bf16(5, self.num_heads, self.head_size),  # q
                empty_bf16(5, self.num_heads, self.head_size),  # k
                empty_bf16(5, self.num_heads, self.head_size),  # v
                empty_bf16(5, self.num_heads * self.head_size),  # attn_output
                self.empty_quant(5, self.num_heads * self.head_size),  # quant_output
                empty_fp32(1, 1)  # scale
            ]

            def wrap_trace_fn(process_fx, trace_fn):

                def wrapped(*args, **kwargs):
                    return process_fx(trace_fn(*args, **kwargs))

                return wrapped

            def fx_view_to_reshape(gm: torch.fx.GraphModule):
                from torch._inductor.fx_passes.post_grad import view_to_reshape
                view_to_reshape(gm)
                return gm

            pm.register_replacement(
                pattern, replacement, inputs,
                wrap_trace_fn(fx_view_to_reshape, pm.fwd_only), pm_pass)
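
Judging only from the excerpt: the class registers a torch.fx pattern-match rewrite. pattern describes "run attention into a buffer, reshape the output, then statically quantize it"; replacement describes "run attention writing directly into the quantized output buffer, passing the quantization scale as output_scale", i.e. the output quantization is fused into the attention op whenever layer.impl.fused_output_quant_supported(...) reports that the backend supports it. The fx_view_to_reshape wrapper post-processes the traced pattern with Inductor's view_to_reshape pass, turning Tensor.view nodes into reshape so the pattern matches the canonical graph form, which is exactly the view/reshape equivalence this post covers.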