撸一遍代码，加深理解！有些部分还没有搞懂，希望大家不吝赐教～
源码地址：https://github.com/dav/word2vec/blob/master/src/word2vec.c
// Copyright 2013 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>

#include <pthread.h>
#define MAX_STRING 100 // Buffer size for a single word and for file-name strings
#define EXP_TABLE_SIZE 1000 // NOTE(review): unused in this excerpt; presumably the number of entries in expTable
#define MAX_EXP 6 // NOTE(review): unused in this excerpt; presumably the input range covered by expTable
#define MAX_SENTENCE_LENGTH 1000 // Maximum sentence length; longer sentences are split (author's note; usage not shown here)
#define MAX_CODE_LENGTH 40 // Maximum Huffman code length
const int vocab_hash_size = 30000000; // Maximum 30 * 0.7 = 21M words in the vocabulary
typedef float real; // Precision of float numbers
struct vocab_word { // One vocabulary entry, including its node information in the Huffman tree
long long cn; // Word frequency (occurrence count)
int *point; // Path from the root to this word's node (author's note; not used in this excerpt)
char *word, *code, codelen; // The word itself, its Huffman code, and the code length
};
char train_file[MAX_STRING], output_file[MAX_STRING];
char save_vocab_file[MAX_STRING], read_vocab_file[MAX_STRING];
struct vocab_word *vocab; // Vocabulary array: vocab_size entries in use, room for vocab_max_size
int binary = 0, cbow = 1, debug_mode = 2, window = 5, min_count = 5, num_threads = 12, min_reduce = 1; // Default settings; NOTE(review): presumably overridden by command-line options outside this excerpt
int *vocab_hash; // Indexed by GetWordHash slot; each slot holds a vocab index, or -1 when empty
long long vocab_max_size = 1000, vocab_size = 0, layer1_size = 100;
long long train_words = 0, word_count_actual = 0, iter = 5, file_size = 0, classes = 0;
real alpha = 0.025, starting_alpha, sample = 1e-3;
real *syn0, *syn1, *syn1neg, *expTable;
clock_t start; // clock_t is declared in <time.h>
int hs = 0, negative = 5;
const int table_size = 1e8; // Number of slots in the negative-sampling table built by InitUnigramTable
int *table; // Negative-sampling table: table_size entries, each a vocab index
// Builds the negative-sampling table.
// Allocates `table` with table_size int slots and fills it with vocab indices.
// Word i occupies roughly a cn_i^0.75 / sum_j(cn_j^0.75) fraction of the slots.
// Raising counts to the 0.75 power flattens the distribution, so rare words are
// drawn more often than their raw frequency alone would give.
// Note: table_size is a fixed constant (1e8), not the number of tokens in the corpus.
// Exits the process if the table cannot be allocated.
// NOTE(review): assumes vocab_size >= 1 and that the counts in
// vocab[0..vocab_size-1] are final when this runs. Confirm against the caller.
void InitUnigramTable() {
  int a, i;
  double train_words_pow = 0;
  double d1, power = 0.75;
  table = (int *)malloc(table_size * sizeof(int));
  if (table == NULL) {
    fprintf(stderr, "Memory allocation failed\n");
    exit(1);
  }
  // Normalizer: sum of cn^0.75 over the whole vocabulary.
  for (a = 0; a < vocab_size; a++) train_words_pow += pow(vocab[a].cn, power);
  i = 0;
  // d1 = cumulative share of the table owned by words 0..i.
  d1 = pow(vocab[i].cn, power) / train_words_pow;
  for (a = 0; a < table_size; a++) {
    // Fill the slot first, then decide whether to move on. A slot that crosses
    // word i's cumulative boundary therefore still goes to word i; filling after
    // the check would hand it to word i + 1 instead. Because this is an `if`,
    // not a `while`, i advances at most one word per slot, so every word reached
    // gets at least one slot.
    table[a] = i;
    // Never advance past the last word. The previous version incremented i to
    // vocab_size, read vocab[vocab_size].cn (an entry beyond the live vocabulary)
    // and only then clamped i back. The resulting table is the same; only the
    // stray read is gone.
    if (a / (double)table_size > d1 && i < vocab_size - 1) {
      i++;
      d1 += pow(vocab[i].cn, power) / train_words_pow;
    }
  }
}
// Reads a single word from a file, assuming space + tab + EOL to be word boundaries.
// Stores the next token from `fin` into `word` (a buffer of at least MAX_STRING
// bytes) as a NUL-terminated string:
//   - '\r' (ASCII 13) is ignored everywhere, so CRLF files read like LF files.
//   - Spaces and tabs before the token are skipped. A space, tab or '\n' after
//     at least one character ends the token, so "New York" is read as two
//     words: "New" and "York".
//   - A '\n' that ends a token is pushed back with ungetc(), so the next call
//     sees it first.
//   - A '\n' seen before any character yields the word "</s>".
//   - Longer tokens are truncated to MAX_STRING - 2 characters.
//   - At end of file (or on a read error) the characters read so far are kept.
//     ReadWordIndex, for example, checks feof() afterwards to detect this.
void ReadWord(char *word, FILE *fin) {
  int a = 0, ch;
  while (!feof(fin)) {
    ch = fgetc(fin);
    // fgetc() returns EOF (-1) at end of file or on error. The previous loop
    // stored that value as a character ((char)-1, i.e. 0xFF) before feof()
    // stopped it. On a read error feof() stays false, so it could also spin forever.
    if (ch == EOF) break;
    if (ch == 13) continue;  // Skip '\r'
    if ((ch == ' ') || (ch == '\t') || (ch == '\n')) {
      if (a > 0) {
        // Delimiter after a token: the token is complete. Push a '\n' back so
        // the next call reports the end of the line.
        if (ch == '\n') ungetc(ch, fin);
        break;
      }
      if (ch == '\n') {
        // Newline before any character: emit the sentence marker.
        strcpy(word, (char *)"</s>");
        return;
      } else continue;  // Leading space/tab: skip it
    }
    word[a] = ch;
    a++;
    if (a >= MAX_STRING - 1) a--;  // Truncate too long words: keep overwriting the last slot
  }
  word[a] = 0;
}
// Returns hash value of a word.
// Maps `word` to a slot in [0, vocab_hash_size) with a base-257 polynomial hash
// over its characters. SearchVocab and AddWordToVocab use this slot as the
// starting point for linear probing in vocab_hash.
// Each char is added as-is, so where char is signed, bytes >= 0x80 contribute
// sign-extended values. This is kept deliberately, because changing it would
// move every word to a different slot.
int GetWordHash(char *word) {
  unsigned long long hash = 0;
  const char *p;
  // Walk to the terminator instead of re-evaluating strlen() in the loop
  // condition, which rescanned the string on every iteration (O(n^2)).
  for (p = word; *p != '\0'; p++) hash = hash * 257 + *p;
  return hash % vocab_hash_size;
}
// Returns position of a word in the vocabulary; if the word is not found, returns -1.
// Starts at the word's hash slot and probes vocab_hash linearly (open addressing).
// An empty slot (-1) ends the probe: the word is absent. Any other slot holds a
// vocab index, whose stored string is compared against `word`.
// The probe only fails to terminate if every slot is occupied.
int SearchVocab(char *word) {
  unsigned int slot = GetWordHash(word);
  int idx;
  for (;;) {
    idx = vocab_hash[slot];
    if (idx == -1) return -1;
    if (strcmp(vocab[idx].word, word) == 0) return idx;
    slot = (slot + 1) % vocab_hash_size;  // Collision: try the next slot, wrapping around
  }
}
// Reads a word and returns its index in the vocabulary.
// Reads the next token from `fin` with ReadWord and looks it up with SearchVocab.
// Returns -1 when the token is not in the vocabulary, or when end-of-file was
// hit while reading. In the second case a final token that is not followed by
// a delimiter is discarded.
int ReadWordIndex(FILE *fin) {
  char token[MAX_STRING];
  ReadWord(token, fin);
  return feof(fin) ? -1 : SearchVocab(token);
}
// Adds a word to the vocabulary.
// Appends a copy of `word` (truncated to MAX_STRING - 1 characters) to vocab
// with a count of 0. Grows vocab when it nears vocab_max_size, records the new
// index in vocab_hash using linear probing, and returns that index.
// NOTE(review): the count starts at 0; presumably the caller sets or increments
// cn after adding. That code is not visible here.
// Exits the process if memory cannot be allocated.
int AddWordToVocab(char *word) {
  unsigned int hash;
  size_t length = strlen(word);
  char *copy;
  struct vocab_word *grown;
  // The previous version clamped only the allocation size and then strcpy'd the
  // whole word, overflowing the buffer for inputs of MAX_STRING or more characters.
  if (length > MAX_STRING - 1) length = MAX_STRING - 1;
  copy = (char *)calloc(length + 1, sizeof(char));
  if (copy == NULL) {
    fprintf(stderr, "Memory allocation failed\n");
    exit(1);
  }
  memcpy(copy, word, length);  // calloc zero-filled the buffer, so copy is NUL-terminated
  vocab[vocab_size].word = copy;  // Appended at the end of vocab
  vocab[vocab_size].cn = 0;
  vocab_size++;
  // Grow by 1000 entries once two or fewer unused entries remain. realloc goes
  // into a temporary so a failure is reported instead of replacing vocab with NULL.
  if (vocab_size + 2 >= vocab_max_size) {
    vocab_max_size += 1000;
    grown = (struct vocab_word *)realloc(vocab, vocab_max_size * sizeof(struct vocab_word));
    if (grown == NULL) {
      fprintf(stderr, "Memory allocation failed\n");
      exit(1);
    }
    vocab = grown;
  }
  // Hash the stored copy rather than the caller's string, so vocab_hash stays
  // consistent with vocab[].word even when the word was truncated.
  hash = GetWordHash(copy);
  while (vocab_hash[hash] != -1) hash = (hash + 1) % vocab_hash_size;
  vocab_hash[hash] = vocab_size - 1;
  return vocab_size - 1;  // Index of the newly added word in vocab
}
// Used later for sorting by word counts
int VocabCompare(const void *a, const void *b) { //比较词频
return ((struct vocab_word *)b)->cn - ((struct vocab_word *)a)->cn;
}
// Sorts the vocabulary by frequency using word counts
void SortVocab() {
//按词频排序,并忽略掉词频太低的单词
int a, size;
unsigned int hash;
// Sort the vocabulary and keep </s> at the first position
qsort(&vocab[1], vocab_size - 1, sizeof(str

（原文到此被截断：SortVocab 及之后的代码未包含在本文中。）



