题目背景
- 对于一给定的素数集合S=p1,p2,...,pK,考虑一个正整数集合,该集合中任一元素的质因数全部属于S。这个正整数集合包括,p1,p1×p2、p1×p1、p1×p2×p3...(还有其它)。该集合被称为S集合的“丑数集合”。注意:我们认为1不是一个丑数。
题目描述
- 你的工作是对于输入的集合S去寻找“丑数集合”中的第N个“丑数”。所有的答案可以用longint(32位整数)存储。
- 补充:丑数集合中每个数从小到大排列,每个丑数都是素数集合中的数的乘积,第N个“丑数”就是在能由素数集合中的数相乘得来的(包括它本身)第N小的数。
输入输出格式
输入格式
- 第1行:二个被空格分开的整数:K和N,1≤K≤100,1≤N≤100,000.
- 第2行:K个被空格分开的整数:集合S的元素
输出格式
输入输出样例
输入样例#1
输出样例#1
说明
- 题目翻译来自NOCOW
USACO Training Section 3.1
原题地址
分析 平衡树
- 主要做法:用数据结构维护一个丑数集合, 每次取出当前的最小丑数,乘以素数集合S中的每一个元素,然后再把这些新的丑数加入数据结构,并保证取出的丑数个数和数据结构维护的丑数个数总共不超过N个,若有超过则删去多余的最大丑数
- 然而直接用Splay做会TLE/RE,我们考虑如下优化:
- 在删去多余丑数时,一个个删除显然是不够高效的,实际上我们可以直接将当前第N小的丑数旋转到根,然后删去该节点的右子树即可
- 最后答案保证在int范围内,因此如果发现一个丑数爆int,我们就没必要加入平衡树了,进一步的我们会发现,如果某一个丑数已经大于当前第N小的丑数,那么它肯定不可能成为最后的答案,我们同样不将其加入平衡树,这便是个重要的优化
- 加入平衡树的丑数数量相当多,而实际满足条件的只有N个,因此我们显然不希望浪费太多的空间,对应的我们可以用一个栈存储那些被删除的点,并把相关的信息清空,插入丑数时再取出栈中节点重新使用即可
代码
#include <iostream>
#include <cstdio>
using namespace std;
const int S = 1 << 20;
char frd[S], *hed = frd + S;
const char *tal = hed;
inline char nxtChar()
{
if (hed == tal)
fread(frd, 1, S, stdin), hed = frd;
return *hed++;
}
inline int get()
{
char ch; int res = 0;
while (!isdigit(ch = nxtChar()));
res = ch ^ 48;
while (isdigit(ch = nxtChar()))
res = res * 10 + ch - 48;
return res;
}
typedef long long ll;
const int N = 11e4 + 5;
ll Int = 2147483647ll;
int fa[N], lc[N], rc[N], sze[N], val[N];
int T, n, k, rt, top, stk[N], a[N];
inline void Push(int x) {sze[x] = sze[lc[x]] + sze[rc[x]] + 1;}
inline void Rotate(int x)
{
int y = fa[x], z = fa[y];
int b = (lc[y] == x ? rc[x] : lc[x]);
fa[y] = x; fa[x] = z;
if (b) fa[b] = y;
if (z) (lc[z] == y ? lc[z] : rc[z]) = x;
if (lc[y] == x) rc[x] = y, lc[y] = b;
else lc[x] = y, rc[y] = b;
Push(y);
}
inline bool Which(int x) {return lc[fa[x]] == x;}
inline void Splay(int x, int tar)
{
while (fa[x] != tar)
{
if (fa[fa[x]] != tar)
Which(fa[x]) == Which(x) ? Rotate(fa[x]) : Rotate(x);
Rotate(x);
}
Push(x);
if (!tar) rt = x;
}
inline void Insert(int vi)
{
int x = rt, y = 0, dir;
while (x)
{
y = x;
if (val[x] == vi) return ;
if (val[x] > vi) x = lc[x], dir = 0;
else x = rc[x], dir = 1;
}
int w = y; if (y) ++sze[y];
while (fa[w]) ++sze[w = fa[w]];
x = top ? stk[top--] : ++T;
fa[x] = y; val[x] = vi; sze[x] = 1;
if (y) (dir ? rc[y] : lc[y]) = x;
Splay(x, 0);
}
inline void Join(int x, int y)
{
int w = y;
while (lc[w]) w = lc[w];
lc[fa[x] = w] = x; fa[rt = y] = 0;
Splay(w, 0);
}
inline void Delete(int x)
{
Splay(x, 0);stk[++top] = x;
int L = lc[x], R = rc[x];
lc[x] = rc[x] = 0;
if (!L || !R) fa[rt = L + R] = 0;
else Join(L, R);
}
inline int GetKth(int k)
{
int x = rt;
while (x)
{
if (k <= sze[lc[x]])
x = lc[x];
else
{
k -= sze[lc[x]] + 1;
if (!k) return x;
x = rc[x];
}
}
return 0;
}
inline void Print(int x)
{
if (!x) return ;
Print(lc[x]); Print(rc[x]);
stk[++top] = x;
fa[x] = lc[x] = rc[x] = 0;
}
inline int FindMin() {int x = rt; while (lc[x]) x = lc[x]; return x;}
int main()
{
k = get(); n = get(); int x, cnt = 0; ll y;
for (int i = 1; i <= k; ++i) Insert(a[i] = get());
while (++cnt <= n)
{
y = val[x = FindMin()]; Delete(x);
for (int i = 1; i <= k; ++i)
if (y * a[i] < Int) Insert(y * a[i]);
else break;
if (sze[rt] + cnt >= n)
{
Int = val[x = GetKth(n - cnt)]; Splay(x, 0);
Print(rc[x]); rc[x] = 0; Push(x);
}
}
cout << y << endl;
}