- Learn the concept of word embeddings
- Train word embeddings with the Skip-Gram model (with negative sampling)
- Learn to use PyTorch Dataset and DataLoader
- Learn to define a PyTorch model
- Learn common modules in torch.nn
  - Embedding
- Learn common PyTorch operations
  - bmm
  - logsigmoid
- Save and load a PyTorch model
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as tud
from collections import Counter
import numpy as np
import random
import math,os
from tqdm import tqdm,trange
import pandas as pd
import scipy
import sklearn
from sklearn.metrics.pairwise import cosine_similarity
USE_CUDA = torch.cuda.is_available()
# Fix random seeds so the run is reproducible
random.seed(1000)
np.random.seed(1000)
torch.manual_seed(1000)
if USE_CUDA:
    torch.cuda.manual_seed_all(1000)
K = 100                  # number of negative samples drawn per positive (context) word
C = 3                    # context window: C words on each side of the centre word
NUM_EPOCHS = 2
MAX_VOCAB_SIZE = 30000
BATCH_SIZE = 128
LEARNING_RATE = 0.1
EMBEDDING_SIZE = 100

DATA_PATH = r'./data/demo10_pytorch_skip-Gram'
TRAIN_DATA = DATA_PATH + os.sep + 'text8.train.txt'
TEST_DATA = DATA_PATH + os.sep + 'text8.test.txt'
with open(TRAIN_DATA, 'r', encoding='utf-8') as fin:
    text = fin.read()

text = text.split()
# Keep the MAX_VOCAB_SIZE-1 most frequent words; everything else will map to <unk>
vocab = dict(Counter(text).most_common(MAX_VOCAB_SIZE - 1))
vocab['<unk>'] = len(text) - np.sum(list(vocab.values()))

idx_to_word = [word for word in vocab.keys()]
word_to_idx = {word: i for i, word in enumerate(idx_to_word)}

word_counts = np.array([count for count in vocab.values()], dtype=np.float32)
word_freqs = word_counts / np.sum(word_counts)
# Raise the unigram distribution to the 3/4 power, as in the word2vec paper,
# so rare words are sampled a bit more often as negatives
word_freqs = word_freqs ** (3. / 4.)
word_freqs = word_freqs / np.sum(word_freqs)

VOCAB_SIZE = len(idx_to_word)
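As a quick sanity check we can draw a few indices from this 3/4-powered distribution with torch.multinomial, which is how the dataset below will pick its negative samples (the exact words printed will vary from run to run):

sample_idx = torch.multinomial(torch.Tensor(word_freqs), 5, replacement=True)
print([idx_to_word[i] for i in sample_idx.tolist()])   # mostly frequent words such as 'the', 'of'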
class WordEmbeddingDataset(tud.Dataset):
    def __init__(self, text, word_to_idx, idx_to_word, word_freqs, word_counts):
        super(WordEmbeddingDataset, self).__init__()
        self.word_to_idx = word_to_idx
        self.idx_to_word = idx_to_word
        self.word_freqs = torch.Tensor(word_freqs)
        self.word_counts = torch.Tensor(word_counts)
        # Words outside the vocabulary map to the index of <unk> (the last entry)
        self.text_encoded = [word_to_idx.get(t, VOCAB_SIZE - 1) for t in text]
        self.text_encoded = torch.LongTensor(self.text_encoded)

    def __len__(self):
        return len(self.text_encoded)

    def __getitem__(self, idx):
        center_word = self.text_encoded[idx]
        # C words on each side of the centre word; indices wrap around at the ends of the corpus
        pos_indices = list(range(idx - C, idx)) + list(range(idx + 1, idx + C + 1))
        pos_indices = [i % len(self.text_encoded) for i in pos_indices]
        pos_words = self.text_encoded[pos_indices]
        # K negative samples per positive word, drawn from the 3/4-powered unigram distribution
        # (a negative sample may occasionally coincide with a true context word; we ignore that here)
        neg_words = torch.multinomial(self.word_freqs, K * pos_words.shape[0], replacement=True)
        return center_word, pos_words, neg_words
dataset = WordEmbeddingDataset(text, word_to_idx, idx_to_word, word_freqs, word_counts)
dataloader = tud.DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
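Before training it is worth pulling one batch from the DataLoader and checking the shapes; assuming a full batch, the centre words come out as [BATCH_SIZE], the context words as [BATCH_SIZE, 2*C] and the negatives as [BATCH_SIZE, 2*C*K]:

center_batch, pos_batch, neg_batch = next(iter(dataloader))
print(center_batch.shape)   # torch.Size([128])
print(pos_batch.shape)      # torch.Size([128, 6])    2*C context words per centre word
print(neg_batch.shape)      # torch.Size([128, 600])  K negatives per context word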
class EmbeddingModel(nn.Module):
    def __init__(self, vocab_size, embed_size):
        super(EmbeddingModel, self).__init__()
        self.vocab_size = vocab_size
        self.embed_size = embed_size
        initrange = 0.5 / self.embed_size
        # Two embedding tables: in_embed for centre words, out_embed for context/negative words
        self.out_embed = nn.Embedding(self.vocab_size, self.embed_size, sparse=False)
        self.out_embed.weight.data.uniform_(-initrange, initrange)
        self.in_embed = nn.Embedding(self.vocab_size, self.embed_size, sparse=False)
        self.in_embed.weight.data.uniform_(-initrange, initrange)

    def forward(self, input_labels, pos_labels, neg_labels):
        # input_labels: [batch_size], pos_labels: [batch_size, 2*C], neg_labels: [batch_size, 2*C*K]
        input_embedding = self.in_embed(input_labels)       # [batch_size, embed_size]
        pos_embedding = self.out_embed(pos_labels)          # [batch_size, 2*C, embed_size]
        neg_embedding = self.out_embed(neg_labels)          # [batch_size, 2*C*K, embed_size]
        pos_dot = torch.bmm(pos_embedding, input_embedding.unsqueeze(2)).squeeze(2)   # [batch_size, 2*C]
        neg_dot = torch.bmm(neg_embedding, -input_embedding.unsqueeze(2)).squeeze(2)  # [batch_size, 2*C*K]
        log_pos = F.logsigmoid(pos_dot).sum(1)
        log_neg = F.logsigmoid(neg_dot).sum(1)
        loss = log_pos + log_neg
        # Return the negative log-likelihood per example; the training loop takes the mean
        return -loss

    def input_embedding(self):
        return self.in_embed.weight.data.cpu().numpy()
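The forward pass uses the two operations named in the goals above: torch.bmm batch-multiplies the [batch, 2*C, embed] context embeddings with the [batch, embed, 1] centre embedding to give one dot product per context word, and F.logsigmoid turns those dot products into the log terms of the negative-sampling objective. A tiny standalone shape check with a toy batch of 2 and random embeddings:

B = 2
ctx = torch.randn(B, 2 * C, EMBEDDING_SIZE)            # stand-in for pos_embedding
ctr = torch.randn(B, EMBEDDING_SIZE)                   # stand-in for input_embedding
dots = torch.bmm(ctx, ctr.unsqueeze(2)).squeeze(2)     # [2, 6]: one dot product per context word
print(F.logsigmoid(dots).sum(1).shape)                 # torch.Size([2]): one loss term per example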
model = EmbeddingModel(VOCAB_SIZE, EMBEDDING_SIZE)
if USE_CUDA:
    print('use cuda......')
    model = model.cuda()

optimizer = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE)
for epoch in trange(NUM_EPOCHS):
    for i, (input_labels, pos_labels, neg_labels) in enumerate(dataloader):
        input_labels = input_labels.long()
        pos_labels = pos_labels.long()
        neg_labels = neg_labels.long()
        if USE_CUDA:
            input_labels = input_labels.cuda()
            pos_labels = pos_labels.cuda()
            neg_labels = neg_labels.cuda()

        optimizer.zero_grad()
        loss = model(input_labels, pos_labels, neg_labels).mean()
        loss.backward()
        optimizer.step()

        if i % 100 == 0:
            print("epoch", epoch, i, loss.item())
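Finally, to cover the last goal on the list, one way to save the trained model and inspect the embeddings is to write out the state_dict and look up nearest neighbours with the cosine_similarity helper imported above; the file name and the find_nearest helper below are just one possible choice:

embedding_weights = model.input_embedding()
torch.save(model.state_dict(), DATA_PATH + os.sep + 'embedding-{}.th'.format(EMBEDDING_SIZE))
# model.load_state_dict(torch.load(DATA_PATH + os.sep + 'embedding-{}.th'.format(EMBEDDING_SIZE)))

def find_nearest(word, k=10):
    # Return the k words whose embeddings have the highest cosine similarity to the given word
    index = word_to_idx[word]
    sims = cosine_similarity(embedding_weights[index].reshape(1, -1), embedding_weights)[0]
    return [idx_to_word[i] for i in sims.argsort()[::-1][1:k + 1]]

print(find_nearest('two'))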