Serval and Music Game
题面翻译
题目描述
给定整数 nnn 和长度为 nnn 的递增序列 sss。
定义 f(x)f(x)f(x) 为满足下列要求的整数 i(1≤i≤n)i(1\leq i\leq n)i(1≤i≤n) 的数量:
- 存在非负整数 pi,qip_i,q_ipi,qi 使得 si=pi⌊snx⌋+qi⌈snx⌉s_i=p_i\bigg\lfloor\dfrac{s_n}{x}\bigg\rfloor+q_i\bigg\lceil\dfrac{s_n}{x}\bigg\rceilsi=pi⌊xsn⌋+qi⌈xsn⌉。
你需要求出 ∑x=1snx×f(x)\sum_{x=1}^{s_n}x\times f(x)∑x=1snx×f(x) 对 998244353998244353998244353 取模后的值。
每个测试点包含 ttt 组数据。
输入格式
第一行输入一个整数 t(1≤t≤104)t(1\leq t\leq10^4)t(1≤t≤104) 表示数据组数,接下来对于每组数据:
第一行输入一个整数 n(1≤n≤106)n(1\leq n\leq10^6)n(1≤n≤106)。
接下来输入一行 nnn 个整数表示序列 s(1≤s1<s2<⋯<sn≤107)s(1\leq s_1<s_2<\cdots<s_n\leq10^7)s(1≤s1<s2<⋯<sn≤107)。
单个测试点内所有组数据对应的 nnn 之和不超过 10610^6106,对应的 sns_nsn 之和不超过 10710^7107。
输出格式
对于每组数据:
输出一行一个整数表示 ∑x=1snx×f(x)\sum_{x=1}^{s_n}x\times f(x)∑x=1snx×f(x) 对 998244353998244353998244353 取模后的值。
题目描述
Serval loves playing music games. He meets a problem when playing music games, and he leaves it for you to solve.
You are given nnn positive integers s1<s2<…<sns_1 < s_2 < \ldots < s_ns1<s2<…<sn . f(x)f(x)f(x) is defined as the number of iii ( 1≤i≤n1\leq i\leq n1≤i≤n ) that exist non-negative integers pi,qip_i, q_ipi,qi such that:
si=pi⌊snx⌋+qi⌈snx⌉ s_i=p_i\left\lfloor{s_n\over x}\right\rfloor + q_i\left\lceil{s_n\over x}\right\rceil si=pi⌊xsn⌋+qi⌈xsn⌉
Find out ∑x=1snx⋅f(x)\sum_{x=1}^{s_n} x\cdot f(x)∑x=1snx⋅f(x) modulo $998,244,353 $ .
As a reminder, ⌊x⌋\lfloor x\rfloor⌊x⌋ denotes the maximal integer that is no greater than xxx , and ⌈x⌉\lceil x\rceil⌈x⌉ denotes the minimal integer that is no less than xxx.
输入格式
Each test contains multiple test cases. The first line contains the number of test cases ttt ( 1≤t≤1041\leq t\leq 10^41≤t≤104 ). The description of the test cases follows.
The first line of each test cases contains a single integer nnn ( 1≤n≤1061\leq n\leq 10^61≤n≤106 ).
The second line of each test case contains nnn positive integers s1,s2,…,sns_1,s_2,\ldots,s_ns1,s2,…,sn ( 1≤s1<s2<…<sn≤1071\leq s_1 < s_2 < \ldots < s_n \leq 10^71≤s1<s2<…<sn≤107 ).
It is guaranteed that the sum of nnn over all test cases does not exceed 10610^6106 , and the sum of sns_nsn does not exceed 10710^7107 .
输出格式
For each test case, print a single integer in a single line — the sum of x⋅f(x)x\cdot f(x)x⋅f(x) over all possible xxx modulo 998 244 353998\,244\,353998244353 .
样例 #1
样例输入 #1
4
3
1 2 4
4
1 2 7 9
4
344208 591000 4779956 5403429
5
1633 1661 1741 2134 2221
样例输出 #1
26
158
758737625
12334970
提示
Solution
考虑每个 xxx 对答案的贡献。
分类讨论:
若 x∣snx|s_nx∣sn ,则可以枚举 sn/xs_n/xsn/x 的倍数计算满足要求的 sis_isi 的个数。
若 x∤snx\nmid s_nx∤sn ,首先有 ⌈snx⌉=⌊snx⌋+1\left\lceil\dfrac{s_n}{x}\right\rceil=\left\lfloor\dfrac{s_n}{x}\right\rfloor+1⌈xsn⌉=⌊xsn⌋+1
设 ⌊snx⌋=k\left\lfloor\dfrac{s_n}{x}\right\rfloor = k⌊xsn⌋=k ,则 si=(pi+qi)k+qis_i=(p_i+q_i)k+q_isi=(pi+qi)k+qi ,由于 pi,qi∈Np_i,q_i\in \mathbb{N}pi,qi∈N ,
则若 sis_isi 满足要求,必然有 ⌊sik⌋≥qi\left\lfloor\dfrac{s_i}{k}\right\rfloor\ge q_i⌊ksi⌋≥qi ,不难证明 qi=si mod kq_i=s_i\bmod kqi=simodk ,
此时枚举 j=⌊sik⌋j=\left\lfloor\dfrac{s_i}{k}\right\rfloorj=⌊ksi⌋ ,那么 sis_isi 满足要求的充要条件为 si∈[jk,jk+j]s_i\in [jk,jk+j]si∈[jk,jk+j] 。
可以发现当 0≤qi<k0\le q_i<k0≤qi<k ,故当 j≥kj \ge kj≥k 时,此后的 sis_isi 都会满足要求。
因此我们只需枚举 j<kj<kj<k ,剩下的统一处理(见代码)。
对于一个区间内满足要求的 sis_isi ,可以用一个桶存储出现次数并将其进行前缀和,记为 cntcntcnt 数组,
那么在 [l,r][l,r][l,r] 中的 sis_isi 的个数就是 cntr−cntl−1cnt_r - cnt_{l-1}cntr−cntl−1 。
这样我们可以 O(sn)O(\sqrt {s_n})O(sn) 求出 f(x)f(x)f(x) 的值。
考虑到 k=⌊snx⌋k=\left\lfloor\dfrac{s_n}{x}\right\rfloork=⌊xsn⌋ 可以数论分块,于是可以在 O(sn)O(s_n)O(sn) 求出答案。
Code
#include<iostream>
#include<stdio.h>
#include<string.h>
#include<algorithm>
using namespace std;
typedef long long LL;
const int N = 1e7 + 3;
const LL MOD = 998244353;
inline void read(int &x)
{
int sgn = 1; x = 0;
char ch = getchar();
while(ch != '-' && (ch < '0' || ch > '9')) ch = getchar();
if(ch == '-') sgn = -1, ch = getchar();
while(ch >= '0' && ch <= '9') x = (x << 1) + (x << 3), x += (ch ^ '0'), ch = getchar();
x *= sgn;
}
int T, cnt[N], n, s[N], m;
int main()
{
read(T);
while(T -- )
{
read(n); m = 0;
for(int i = 1; i <= n; i ++ ) read(s[i]), m = max(m, s[i]);
for(int i = 0; i <= m; i ++ ) cnt[i] = 0;
for(int i = 1; i <= n; i ++ ) cnt[s[i]] ++ ;
for(int i = 1; i <= m; i ++ ) cnt[i] += cnt[i - 1];
LL res = 0, tp = 0;
for(int i = 1; i <= m; i ++ )
{
LL sum = 0; int k = m / i;
if(m % i)
if(m % (i - 1) && m / (i - 1) == k) sum = tp;
else
{
for(int j = 1; j < k && j * k <= m; j ++ ) sum += cnt[min(j * k + j, m)] - cnt[j * k - 1];
if(1ll * k * k <= 1ll * m) sum += cnt[m] - cnt[k * k - 1];
}
else for(int j = k; j <= m; j += k) sum += cnt[j] - cnt[j - 1];
res += i * sum % MOD; res %= MOD;
tp = sum;
}
printf("%lld\n", res);
}
return 0;
}