Implementing the Skip-Gram Model from Scratch in PyTorch

I recently implemented the skip-gram model in PyTorch, adapting code from a GitHub notebook (see the reference at the end). Here is my implementation:

  • Download the dataset
import nltk
nltk.download('twitter_samples')
from nltk.corpus import twitter_samples
  • Add labels
label = "neg"
neg_dataset = [(label, instance) for instance in twitter_samples.strings('negative_tweets.json')]

label = "pos"
pos_dataset = [(label, instance) for instance in twitter_samples.strings('positive_tweets.json')]

print(neg_dataset[:2])
print(pos_dataset[:2])
  • Split the dataset
training_data = neg_dataset[:4000] + pos_dataset[:4000]
testing_data = neg_dataset[4000:] + pos_dataset[4000:]


####### relatively small dataset, can be used during the development phase #######
dev_training_data = neg_dataset[:100] + pos_dataset[:100]
dev_testing_data = neg_dataset[100:200] + pos_dataset[100:200]
  • Save the splits to files
import pickle
# serialize the splits with pickle
pickle.dump(training_data, open( "training_data.pkl", "wb" ) )
pickle.dump(testing_data, open( "testing_data.pkl", "wb" ) )

####### relatively small dataset, can be used during the development phase #######
pickle.dump(dev_training_data, open( "dev_training_data.pkl", "wb" ) )
pickle.dump(dev_testing_data, open( "dev_testing_data.pkl", "wb" ) )
  • Load the datasets
training_data = pickle.load(open("training_data.pkl","rb"))
testing_data = pickle.load(open("testing_data.pkl","rb"))
  • Data preprocessing
from nltk.tokenize import TweetTokenizer
tknzr = TweetTokenizer()
train_data1 = [tknzr.tokenize(item[1]) for item in training_data]  # tokenize the tweet text, dropping the label
testing_data1 = [tknzr.tokenize(item[1]) for item in testing_data]
from collections import Counter

min_word_freq = 3
def get_words_map_and_words_count(train_sentences):
    words_freq = Counter()
    for sen in train_sentences:  # go through every sentence and update the word-frequency counter
        words_freq.update(sen)
    words_freq.update(['<unk>'])  # make sure '<unk>' has an entry in the counter
    words = [w for w in words_freq.keys() if words_freq[w] > min_word_freq]  # keep only words above the minimum frequency
    words_map = {v: k for k, v in enumerate(words, 1)}  # vocabulary: word -> id, ids starting at 1
    words_map['<unk>'] = 0
    words.append('<unk>')
    # word counts ordered by word id (id 0 is '<unk>'), so they line up with the negative-sampling table later
    words_count = [words_freq[w] for w, idx in sorted(words_map.items(), key=lambda kv: kv[1])]
    return words_map, words_count, words
words_map, words_count, words = get_words_map_and_words_count(train_data1)
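A quick, purely optional look at what the helper returns; the printed word and its count will of course depend on the downloaded corpus:

# optional sanity check of the vocabulary built above
print('vocabulary size:', len(words_map))
print('id of <unk>:', words_map['<unk>'])   # 0 by construction
some_word = words[0]                        # any in-vocabulary word
print(some_word, '-> id', words_map[some_word], ', count', words_count[words_map[some_word]])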
def get_clean_train(train_sentences, words):
    words_set = set(words)  # set membership is much faster than scanning a list for every token
    res = []
    for i, sen in enumerate(train_sentences):
        if i % 1000 == 0: print('the {}th sentence: {}'.format(i, sen))
        res.append([ch if ch in words_set else '<unk>' for ch in sen])
    return res
train_clean = get_clean_train(train_data1, words)
from collections import defaultdict

def get_center_context_pairs(train_sentences, words_map, window_size):
    pairs, pair = [], defaultdict(list)
    max_dis = window_size // 2
    # for each sentence
    for i, sen in enumerate(train_sentences):
        if i % 1000 == 0: print('the {}th sentence'.format(i))
        if len(sen) < 5: continue
        # for each center word
        for j in range(0, len(sen)):
            center_id = words_map.get(sen[j], words_map['<unk>'])
            # collect the context words inside the window around the center word
            start = max(0, j - max_dis)
            end = min(len(sen), j + max_dis + 1)
            for idx in range(start, j):  # left half of the window
                pair[center_id].append(words_map.get(sen[idx], words_map['<unk>']))
            for idx in range(j + 1, end):  # right half of the window
                pair[center_id].append(words_map.get(sen[idx], words_map['<unk>']))
            pairs.append(pair)
            pair = defaultdict(list)
    return pairs
Config = {
    'learning_rate':0.001,
    'embedding_dim':50,
    'window_size':5,
    'batch_size':64,
    'epoch':6,
    'weight_decay':0.5,
    'neg_num':12
}
train_pairs = get_center_context_pairs(train_clean, words_map, Config['window_size'])
reverse_words_map = {v: k for k, v in words_map.items()}
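To make the structure of train_pairs concrete, here is a small illustrative run on a made-up token list (the toy tokens are my own; anything missing from words_map simply maps to '<unk>'):

# illustrative only: one toy "sentence" of already-tokenized words
toy_pairs = get_center_context_pairs([['i', 'love', 'this', 'movie', '!']], words_map, Config['window_size'])
# each element maps one center-word id to the ids of the context words inside its window
for p in toy_pairs[:3]:
    for center_id, context_ids in p.items():
        print(reverse_words_map[center_id], '->', [reverse_words_map[c] for c in context_ids])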
  • Model implementation
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

class SkipGramModel(nn.Module):
    def __init__(self, device, vocab_size, embedd_dim, neg_num=0, word_count=[]):
        super(SkipGramModel, self).__init__()
        self.device = device
        self.neg_num = neg_num
        self.embeddings = nn.Embedding(vocab_size, embedd_dim)
        if self.neg_num > 0:
            self.table = create_sample_table(word_count)
        
    def forward(self, centers, context):
        '''
            centers: LongTensor of shape (batch_size,) with the center-word ids
            context: LongTensor of shape (batch_size,) with the context-word ids
        '''
        batch_size = len(centers)
        u_embeddings = self.embeddings(centers).view(batch_size, 1, -1)  # batch_size, 1, embedd_dim
        v_embeddings = self.embeddings(context).view(batch_size, 1, -1)  # batch_size, 1, embedd_dim
#         print('size check for u and v embedding: ', u_embeddings.size(), v_embeddings.size())
        score = torch.bmm(u_embeddings, v_embeddings.transpose(1, 2)).squeeze()  # batch_size,
        loss = F.logsigmoid(score).squeeze() # batch_size,
        #print('size check for loss: ', loss.size())
        if self.neg_num > 0:
            neg_contexts = torch.LongTensor(np.random.choice(self.table, size=(batch_size, self.neg_num))).to(self.device)
            
            neg_v_embeddings = self.embeddings(neg_contexts)  # batch_size, neg_num, embedd_dim
            neg_score = torch.bmm(u_embeddings, neg_v_embeddings.transpose(1, 2).neg()).squeeze()  # batch_size, neg_num
            neg_score = F.logsigmoid(neg_score).squeeze()  # batch_size, neg_num
            neg_score = torch.sum(neg_score, dim=1) # batch_size, 
            #print('size check for neg_score: ', neg_score.size())
            assert loss.size() == neg_score.size()
            loss += neg_score
        return -1 * loss.sum()
    
    def get_embeddings(self):
        return self.embeddings.weight.data
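Before wiring the model to the real data, a tiny smoke test can confirm that the forward pass returns a single scalar loss. The vocabulary size and ids below are made up for illustration, and neg_num=0 keeps it independent of create_sample_table, which is only defined further below:

# smoke test on CPU with a tiny made-up vocabulary, no negative sampling
test_model = SkipGramModel(device=torch.device('cpu'), vocab_size=100, embedd_dim=8, neg_num=0)
test_centers = torch.randint(0, 100, (4,))   # 4 random center-word ids
test_contexts = torch.randint(0, 100, (4,))  # 4 random context-word ids
print(test_model(test_centers, test_contexts))  # one scalar: the negative log-sigmoid score summed over the batch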
  • Word embedding preprocessing
def get_all_pairs(train_pairs):
    all_data = []
    center_idx, context_idx = [], []
    for pair in train_pairs:
        for k, v in pair.items():
            center_idx = k
            for x in v:
                context_idx.append(x)
        for con in context_idx:
            all_data.append((center_idx, con))
        center_idx, context_idx = [], []
    return all_data

all_data = get_all_pairs(train_pairs)
all_data = list(set(all_data))
print(all_data[:10])
print(len(all_data))
def get_batch_pairs(batch_size, all_data):
    batch_center, batch_context = [], []
    batch_index = np.random.choice(len(all_data), batch_size)
    for idx in batch_index:
        t_center, t_context = all_data[idx]
        batch_center.append(t_center)
        batch_context.append(t_context)
    return torch.LongTensor(batch_center), torch.LongTensor(batch_context)
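A quick shape check on one sampled batch (optional; it only reuses the objects defined above):

# sample one batch and check its shapes
centers, contexts = get_batch_pairs(Config['batch_size'], all_data)
print(centers.size(), contexts.size())  # both torch.Size([64]) with the config above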
  • Build the word embedding model
TABLE_SIZE = 1e8  # size of the unigram sampling table; shrink it if memory is tight
def create_sample_table(words_count):
    table = []
    freq = np.power(np.array(words_count), 0.75)
    print('freq: ', freq)
    sum_freq = sum(freq)
    ratio = freq / sum_freq
    count = np.round(ratio * TABLE_SIZE)
    print('count: ', count)
    for word_id, c in enumerate(count): # smoothed count for each word id
        table += [word_id] * int(c)
    return np.array(table)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = SkipGramModel(device=device, vocab_size=len(words), embedd_dim=Config['embedding_dim'], neg_num=Config['neg_num'], word_count=words_count)
model = model.to(device)  # keep the embedding table on the same device as the sampled negatives
  • Train the word embedding model
import torch.optim as optim
optimizer = optim.Adam(model.parameters(), lr=Config['learning_rate'], weight_decay=Config['weight_decay'])
from tqdm import tqdm
iter_time = len(all_data) // Config['batch_size']  # batches are sampled from all_data, so size the epoch accordingly
loss_items = []
batch_size = Config['batch_size']
for e in tqdm(range(Config['epoch'])):
    for i in range(iter_time):
        center, context = get_batch_pairs(batch_size, all_data)
        center, context = center.to(device), context.to(device)  # keep the batch on the same device as the model
#         print('size check: ', center.size(), context.size())

        loss = model(center, context)
        if i % 500 == 0: print(reverse_words_map[int(center[0])], reverse_words_map[int(context[0])])
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        loss_items.append(loss.item()) 
        if i % 1000 == 0: 
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.10f}'.format(e, i * batch_size, iter_time * batch_size,
                                                                             100. * i / iter_time, loss.item()))
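After training, the learned vectors can be inspected. The helper below is my own small sketch (not part of the referenced notebook): it ranks the nearest neighbors of a query word by cosine similarity over model.get_embeddings(); 'happy' is just an example query, and the output depends entirely on how the run went.

def nearest_neighbors(word, k=5):
    # cosine similarity between one word vector and the whole embedding table
    emb = F.normalize(model.get_embeddings().cpu(), dim=1)  # vocab_size, embedding_dim, unit-length rows
    word_id = words_map.get(word, words_map['<unk>'])
    sims = emb @ emb[word_id]                               # vocab_size,
    top = torch.topk(sims, k + 1).indices.tolist()          # +1 because the word itself ranks first
    return [reverse_words_map[i] for i in top if i != word_id][:k]

print(nearest_neighbors('happy'))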

References

[1] A simple word2vec based on Skip-Gram (基于 Skip-Gram 实现一个简单的 word2vec). https://github.com/created-Bi/SkipGram_NegSampling_Pytorch/blob/main/Lesson9-word2vec.ipynb
