学习hash刷的第一道题,简单记录一下hash的思路。
传送门:题目
题意:
给定n个字符串,比较任意两个字符串,如果两个字符串在同一个位置,有且只有一个不同,结果+1,最后输出结果。
题解:
如果我们不用hash,简单思路就是比较任意两个字符串(时间复杂度:
n∗(n−1)2
n
∗
(
n
−
1
)
2
, 从左往右扫一遍 (时间复杂度:len)
如果只有一个不同,结果+1。可想,时间复杂度是很感人的。
hash的核心就是把一个无限长字符串映射为一个大整数,如果大整数不同,我们就可以认为两个字符串是不同的。
当然,这个大整数不是随便造的,要不然我们就无法找到只有一个字母不同的两个字符串了。
题目中说,只有64个字符,所以我们不妨把这64个字符映射为数字[1-64],但是这样是无法保证它的唯一性的,比如说字母a映射1,字母b映射2,那么字符串”ab”映射1+2,字符串”ba”映射2+1,是不是,所以我们需要改进一下。
不难想到,我们在组合不是简单相加,而是乘以对应的权重,简单的设想:我们可以把它所在的列作为它的权重,那么”ab”=
a∗1+b∗2=5
a
∗
1
+
b
∗
2
=
5
而”ba”=
b∗1+a∗2=4
b
∗
1
+
a
∗
2
=
4
,好了,我们现在看似解决了重复性问题。但其实没有完全解决。
不难想到,权重为5的还有”ca”。现在一个字符串有两部分组成,一个是字母对应的映射值,一个是所在位置的权重,究其重复缘由,就是这两部分的值范围太小了,我们需要增大一些。
下面的prime数组对应的字母对应的映射值,hash_value对应的是字符串对应的hash值,读者自己看就好了。
处理的时候有个技巧,就是把他们都开成unsigned long long类型,这样他们在越界的时候会自动取模,而不是取到负数。
我们现在得到了每一个字符串的哈希值,现在只需要对比任意两个字符串的哈希值(时间复杂度:
n∗(n−1)2
n
∗
(
n
−
1
)
2
,然后比较这两个哈希值-相同位置对应字母的哈希值(时间复杂度:1),如果得到了两个相同的值,那么结果加1.
这就是hash的核心,把原来需要逐位比较单个字母,现在只需要对比hash值就行了。
AC代码:
#include <iostream>
#include <cstring>
#include <cstdio>
#include <algorithm>
#define INF 0x3f3f3f3f
using namespace std;
const int maxn = 30010;
unsigned long long hash_str[maxn], hash_value[210][65], prime[210], temp[maxn];
char mmap[maxn][210];
int solve(char c) {//把字母映射到[1-64]数字
if (c >= 'a' && c <= 'z')
return c - 96;//[1-26]
if (c >= 'A' && c <= 'Z')
return c - 38;//[27-52]
if (c >= '0' && c <= '9')
return c + 5;//[53-62]
if (c == '_')
return 63;//63
return 64;//c=='@' 64
}
int main(void) {
ios::sync_with_stdio(false);
int n, len, laji;
long long ans = 0;
cin >> n >> len >> laji;
prime[0] = 1;
for (int i = 1; i <= len; i++)//这里利用usigned long long自然溢出就好了,不用取模了,取模的话会超时
prime[i] = prime[i - 1] * 65 * 131;//把每一个行映射一个值,我这里第一开始开了个64,WA了.
for (int i = 1; i <= len; i++)
for (int j = 1; j <= 64; j++)
hash_value[i][j] = prime[i - 1] * j * 131;//找到任意一个字母在任意一个位置所对应的随机数
for (int i = 1; i <= n; i++)
cin >> (mmap[i] + 1); //读入字符矩阵
for (int i = 1; i <= n; i++)
for (int j = 1; j <= len; j++)
hash_str[i] += hash_value[j][solve(mmap[i][j])];//找到每一个字符串对应的hash值
for (int j = 1; j <= len; j++) { //遍历每一列
for (int i = 1; i <= n; i++) //遍历每一行
temp[i] = hash_str[i] - hash_value[j][solve(mmap[i][j])];
sort(temp + 1, temp + 1 + n);
int sum = 1;
for (int i = 1; i <= n; i++)
if (temp[i] != temp[i + 1])
ans += sum * (sum - 1) / 2, sum = 1;
else
sum++;
}
cout << ans << endl;
return 0;
}