AcWing 3508.最长公共子串
题目描述
做法:字符串哈希+二分
字符串哈希
基于进制转换的哈希算法。它的基本思路是把字符串看作一个进制为 p(比如 p = 131 或 p = 13331)的整数,然后用这个整数作为哈希值。具体地,假设字符串 s 共有 n 个字符,第 i 个字符的 ASCII 码为 a i a_i ai,则 s 的哈希值为:
h a s h ( s ) = ∑ i = 0 n − 1 a i ⋅ p i \displaystyle hash(s) = \sum\limits_{i=0}^{n-1} a_i \cdot p^i hash(s)=i=0∑n−1ai⋅pi
这个哈希函数的过程可以用快速幂算法来计算,复杂度为 O(n)。但是为了提高求哈希值的效率,我们可以使用预处理的方式,在 O(n) 的时间内计算出字符串 s 中所有前缀的哈希值,并保存在一个数组 h 中:
h i = ∑ j = 1 i a j ⋅ p j − 1 h_i = \sum\limits_{j=1}^{i} a_j \cdot p^{j-1} hi=j=1∑iaj⋅pj−1
然后,如果我们想快速计算 [l,r] 区间的哈希值,只需要计算:
h a s h ( l , r ) = h r − h l − 1 ⋅ p r − l + 1 hash(l,r) = h_r - h_{l-1} \cdot p^{r-l+1} hash(l,r)=hr−hl−1⋅pr−l+1
代码实现
/*
1.假设有长度为x的最长公共子串,则必有长度小于x的公共子串,没有大于x的公共子串,所有具有二段性,想到二分
2.那么如何判断子串是否相同呢?
(1)暴力:首先可以想到,暴力枚举在a串中每一个长度为mid的子串,然后insert到哈希表中,
再暴力枚举b串中每一个长度为mid的子串,再count一下是否存在,存在返回true,但是这样会超时
(2)字符串哈希 这里我们想到可以使用字符串哈希的方法降低时间复杂度,将a串中每一个长度为mid的子串,
通过哈希函数得到整数,放到哈希表中,再count一下b串每一个长度为mid的子串
*/
#include<iostream>
#include<algorithm>
#include<cstring>
#include<unordered_set>
using namespace std;
const int N = 20010, P = 131;
typedef unsigned long long ULL; // 定义 ULL 类型
ULL h[N], p[N]; // h 存储哈希值,p 存储进制的幂
char str[N];
int n, m;
ULL get(int l, int r) // 计算区间 [l,r] 的哈希值
{
return h[r] - h[l - 1] * p[r - l + 1];
}
bool check(int mid) // 检查长度为 mid 的子串是否相同
{
unordered_set<ULL> hash; // 定义哈希表
// 把从 s1 中所有长度为 mid 的子串的哈希值插入哈希表
for (int i = 1; i + mid - 1 <= n; i++)
hash.insert(get(i, i + mid - 1));
// 枚举 s2 中所有长度为 mid 的子串,判断是否在哈希表中出现过
for (int i = n + 1; i + mid - 1 <= n + m; i++)
if (hash.count(get(i, i + mid - 1)))
return true;
return false;
}
int main()
{
cin>>(str+1);
n=strlen(str+1);
cin>>(str+n+1);
m=strlen(str+n+1);
// 初始化进制的幂和哈希值
p[0] = 1;
for (int i = 1; i <= n + m; i++)
{
p[i] = p[i - 1] * P;
char c = str[i];
// 数字转换成不同的字符,避免影响哈希函数的求值
if (isdigit(c))
{
if (i <= n) c = '#'; // 对字符串 1 中的数字转换成 #
else c = '$'; // 对字符串 2 中的数字转换成 $
}
h[i] = h[i - 1] * P + c; // 计算哈希值
}
// 二分查找最长的公共子串
int l = 0, r = min(n, m);
while (l < r)
{
int mid = l + r + 1 >> 1;
if (check(mid)) l = mid;
else r = mid - 1;
}
cout << l << endl; // 输出最长的公共子串的长度
return 0;
}