直接用kmp算法
class Solution {
public:
int strStr(string haystack, string needle) {
return kmp(haystack,needle);
}
int kmp(std::string &text,std::string &pattern){
int n = text.size();
int m = pattern.size();
if(m == 0)
return 0;
std::vector<int> next;
next.reserve(m);
getNext(pattern,m,next);
int j = -1;
for(int i = 0;i < n;i++){
while(j != -1 && text[i] != pattern[j+1]){
j = next[j];
}
if(text[i] == pattern[j+1]){
j++;
}
if(j == m-1)
return i - j;
}
return -1;
}
void getNext(std::string &pattern,int len,std::vector<int> &next){
next[0] = -1;
int j = -1;
for(int i = 1;i < len;i++){
while(j != -1 && pattern[i] != pattern[j+1]){
j = next[j];
}
if(pattern[i] == pattern[j+1])
j++;
next[i] = j;
}
}
};
或者另外一种写法,用的是部分匹配值(Partial Match)表
class Solution {
public:
int strStr(string haystack, string needle) {
return KMP(haystack,needle);
}
void getPM(std::string &pattern,int len,std::vector<int> &pm){
pm[0] = 0;
int j = 0;// j表示前缀的末尾元素的索引位置,同时它也是最长公共前后缀的长度
//i表示后缀的末尾元素的索引位置
for(int i = 1;i < len;i++){
while(j > 0 && pattern[i] != pattern[j])
j = pm[j-1];
if(pattern[j] == pattern[i]) j++;
pm[i] = j;
}
}
int KMP(std::string &text,std::string& pattern){
int m = pattern.size();
if(m == 0) return 0;
std::vector<int> pm;
pm.resize(m);
getPM(pattern,m,pm);
int n = text.size();
int j = 0;
for(int i = 0;i < n;i++){
while(j > 0 && text[i] != pattern[j])
j = pm[j-1];
if(text[i] == pattern[j]) j++;
if(j == m)
return i-j+1;
}
return -1;
}
};
c语言实现:
int kmp(char* text_str,char* pattern_str);
int strStr(char* haystack, char* needle) {
return kmp(haystack,needle);
}
//计算部分匹配值(Partial Match)表
void get_pm(char* pattern_str,int pm[]){
int len = strlen(pattern_str);
pm[0] = 0;
int j = 0;//j表示前缀的末尾元素的位置索引,同时j也是当前最长公共前后缀的长度
//i表示后缀的末尾元素的位置索引
for(int i = 1;i <len;i++){
while(j > 0 && pattern_str[i] != pattern_str[j])
j = pm[j-1];
if(pattern_str[i] == pattern_str[j])
j++;
pm[i] = j;
}
}
int kmp(char* text_str,char* pattern_str){
if(!text_str || !pattern_str) return -1;
int m = strlen(pattern_str);
if(m == 0) return 0;//空模式串约定返回0
int *pm = (int*) malloc(m*sizeof(int));
get_pm(pattern_str,pm);
int n = strlen(text_str);
int j = 0;//模式串的当前匹配位置
for(int i = 0;i < n;i++){
while(j > 0 && text_str[i] != pattern_str[j])
j = pm[j-1];
if(text_str[i] == pattern_str[j])
j++;
if(j == m){
free(pm);
return i - m + 1;
}
}
free(pm);
return -1;
}