题意:给你一棵树, 根结点为1, q组操作, 每组操作有两种, 一种是对一个结点的所有子树结点的值全部+1, 另一种是查询一个结点的子树结点上值%m的余数为素数的个数。
思路:对于第一个操作, 我们可以想到用dfs序给树重新标号, 使得一个结点的子树结点为相邻的一条线段, 这样,就可以很容易的用线段树进行处理了。 对于第二个操作, 为了维护一个区间内的值, 我们可以用bitset作为结点信息。 我们可以开一个m位的bitset, 对于每个位, 1表示这个数在此区间中, 最后用素数表和答案交一下就行了。
对于加一个数这个操作, 因为我们只需要m的余数, 我们可以考虑用位运算循环移位。 S << x 表示把集合S左移x位,相当于把每个数加了x, S >> m - x, 表示右移m - x位, 这样, 两者并, 就等价于第一个操作, 注意, bitset里的二进制位和整数是一样的, 高位在左, 低位在右。
细节参见代码:
- #include<cstdio>
- #include<cstring>
- #include<algorithm>
- #include<iostream>
- #include<string>
- #include<vector>
- #include<stack>
- #include<bitset>
- #include<cstdlib>
- #include<cmath>
- #include<set>
- #include<list>
- #include<deque>
- #include<map>
- #include<queue>
- #define Max(a,b) ((a)>(b)?(a):(b))
- #define Min(a,b) ((a)<(b)?(a):(b))
- using namespace std;
- typedef long long ll;
- typedef long double ld;
- #define M 1010
- typedef bitset<M> bt;
- const ld eps = 1e-9, PI = 3.1415926535897932384626433832795;
- const int mod = 1000000000 + 7;
- const int INF = int(1e9);
- // & 0x7FFFFFFF
- const int seed = 131;
- const ll INF64 = ll(1e18);
- const int maxn = 1e5 + 10;
- int T,n,m,cnt=0,addv[maxn<<2],id[maxn],last[maxn],a[maxn],b[maxn],vis[M];
- bt sum[maxn<<2], res;
- void PushUp(int o) {
- sum[o] = sum[o<<1] | sum[o<<1|1];
- }
- void add(int& a, int b) {
- a += b;
- if(a >= m) a -= m;
- }
- void change(bt &x, int y) {
- x = (x << y) | (x >> m - y);
- }
- void pushdown(int o) {
- if(addv[o]) {
- add(addv[o<<1], addv[o]);
- add(addv[o<<1|1], addv[o]);
- change(sum[o<<1], addv[o]);
- change(sum[o<<1|1], addv[o]);
- addv[o] = 0;
- }
- }
- void build(int l, int r, int o) {
- int m = (l + r) >> 1;
- addv[o] = 0;
- sum[o].reset();
- if(l == r) {
- sum[o][b[l]] = 1;
- return ;
- }
- build(l, m, o<<1);
- build(m+1, r, o<<1|1);
- PushUp(o);
- }
- void update(int L, int R, int v, int l, int r, int o) {
- int m = (l + r) >> 1;
- if(L <= l && r <= R) {
- add(addv[o], v);
- change(sum[o], v);
- return ;
- }
- pushdown(o);
- if(L <= m) update(L, R, v, l, m, o<<1);
- if(m < R) update(L, R, v, m+1, r, o<<1|1);
- PushUp(o);
- }
- bt query(int L, int R, int l, int r, int o) {
- int m = (l + r) >> 1;
- if(L <= l && r <= R) {
- return sum[o];
- }
- pushdown(o);
- bt ans;
- ans.reset();
- if(L <= m) ans |= query(L, R, l, m, o<<1);
- if(m < R) ans |= query(L, R, m+1, r, o<<1|1);
- return ans;
- }
- vector<int> g[maxn];
- void dfs(int u, int fa) {
- int len = g[u].size();
- id[u] = ++cnt;
- b[cnt] = a[u];
- for(int i=0;i<len;i++) {
- int v = g[u][i];
- if(v != fa) {
- dfs(v, u);
- }
- }
- last[u] = cnt;
- }
- void init() {
- for(int i=2;i<m;i++) if(!vis[i]) {
- res[i] = 1;
- for(int j=i*i;j<m;j+=i) vis[j] = 1;
- }
- }
- int u, v, ii, x, q;
- int main() {
- scanf("%d%d",&n,&m);
- init();
- for(int i=1;i<=n;i++) scanf("%d",&a[i]), a[i] %= m;
- for(int i=1;i<n;i++) {
- scanf("%d%d",&u,&v);
- g[u].push_back(v);
- g[v].push_back(u);
- }
- dfs(1, 1);
- build(1, n, 1);
- scanf("%d",&q);
- while(q--) {
- scanf("%d%d",&ii,&v);
- if(ii == 1) {
- scanf("%d",&x);
- update(id[v], last[v], x%m, 1, n, 1);
- }
- else {
- bt cur = query(id[v], last[v], 1, n, 1);
- printf("%d\n",(res & cur).count());
- }
- }
- return 0;
- }