自注意力机制学习有感
- 观看B站博主的讲解视频,并跟着他的PyTorch代码用MindSpore实现自注意力机制:
- UP主讲得很好,推荐入门自注意力机制。
import mindspore as ms
import mindspore.nn as nn
from mindspore import Parameter
from mindspore import context
# Configure the MindSpore execution context at import time: select the
# 'Ascend' device target and set max_device_memory to '1GB'.
# NOTE(review): this hard-codes the backend; presumably the script fails on
# machines without an Ascend device -- confirm before running elsewhere.
context.set_context(device_target='Ascend',max_device_memory='1GB')
class SelfAttention(nn.Cell):
def __init__(self, dim):
super(SelfAttention, self).__init__()
wq_data = [[1.0, 0], [1., 1.]]
wk_data = [[0., 1.], [1., 1.]]
wv_data = [[0., 1., 1.], [1., 0., 0.]]
self.q = nn.Dense(in_channels=dim, out_channels=2, has_bias=False)
self.q.weight.set_data(ms.Tensor(wq_data).T)
print("wq value:", self.q.weight.value())
self.k = nn.Dense(in_channels = dim, out_channels=2, has_bias=False)
self.k.weight.set_data(