知识点:1. 先进行重链剖分(切分出重边),还要记得求出每个结点所在重链的链顶 top。
2. 用线段树维护按 dfs 序排列的 val 数组:dfs 时优先走重儿子,所以一条重链上的结点在 dfs 序中是连续的(同理,一棵子树在 dfs 序中也是一段连续区间)。
3. 求值时就在树上往上跳:同一条链上的信息可以用线段树直接区间求出,故只需沿着 top 跳即可。
There was one bug in the code, marked "False 1" in a comment next to the corrected line.
// Heavy-light decomposition (树链剖分) over a rooted tree, backed by a lazy
// segment tree supporting range add / range sum modulo `mod`.
//
// Input : n m rt mod, then n node values val[1..n], then n-1 undirected edges,
//         then m operations:
//           1 x y z  add z to every node on the path x -> y
//           2 x y    print the sum of the values on the path x -> y (mod `mod`)
//           3 x z    add z to every node in the subtree of x
//           4 x      print the sum of the values in the subtree of x (mod `mod`)
//
// Facts the path/subtree operations rely on:
//   * dfs2 visits the heavy son first, so every heavy chain occupies the
//     contiguous dfs range [rk[top[x]], rk[x]];
//   * every subtree occupies the contiguous dfs range [rk[x], rk[x] + siz[x] - 1].
#include <algorithm>
#include <cstdio>

const int M = 200002;  // capacity: node ids must be < M

long long n, m, rt, mod;
long long tot = 0;           // last dfs position handed out by dfs2
long long val[M];            // initial value of each node
long long head[M], cnt = 0;  // adjacency-list heads / number of stored edges

struct edge {
    long long to;
    long long nxt;
} e[M * 2];  // each undirected edge is stored twice

// Append the directed edge x -> y to x's adjacency list.
void add(long long x, long long y) {
    e[++cnt].nxt = head[x];
    e[cnt].to = y;
    head[x] = cnt;
}

long long siz[M];  // subtree size
long long son[M];  // heavy son (0 when the node has no children)
long long rk[M];   // dfs position of a node
long long tp[M];   // tp[pos] = initial value of the node at dfs position pos
long long fa[M];   // parent (0 for the root)
long long dep[M];  // depth (the root gets the `depth` passed to dfs1)
long long top[M];  // top of the heavy chain containing the node
long long stk[M];  // explicit DFS stack for dfs1/dfs2 (each node pushed once)
long long ord[M];  // nodes in the order dfs1 visits them (parents before children)

long long tre[M * 4], lazy[M * 4];  // segment-tree sums and pending adds

// Build the segment tree over dfs positions [l, r]; leaf l holds tp[l].
void build(long long t, long long l, long long r) {
    if (l == r) {
        tre[t] = tp[l] % mod;
        return;
    }
    long long mid = (l + r) / 2;
    build(t * 2, l, mid);
    build(t * 2 + 1, mid + 1, r);
    tre[t] = (tre[t * 2] + tre[t * 2 + 1]) % mod;
}

// Apply "add x to every position in [l, r]" to node t: update its sum and tag.
void up(long long t, long long l, long long r, long long x) {
    lazy[t] = (lazy[t] + x) % mod;
    tre[t] = (tre[t] + (r - l + 1) * x) % mod;
}

// Push node t's pending add (covering [l, r]) down to its two children.
void pushdown(long long l, long long r, long long t) {
    if (lazy[t] == 0) return;  // nothing pending
    long long mid = (l + r) / 2;
    up(t * 2, l, mid, lazy[t]);
    up(t * 2 + 1, mid + 1, r, lazy[t]);
    lazy[t] = 0;
}

// Sum (mod `mod`) of positions [ql, qr] within node t covering [l, r].
long long query(long long t, long long l, long long r, long long ql, long long qr) {
    if (ql <= l && r <= qr) return tre[t] % mod;
    pushdown(l, r, t);
    long long mid = (l + r) / 2;
    long long ans = 0;
    if (ql <= mid) ans += query(t * 2, l, mid, ql, qr);
    if (mid < qr) ans += query(t * 2 + 1, mid + 1, r, ql, qr);
    return ans % mod;
}

// Add k (already reduced mod `mod`) to positions [ql, qr] within node t covering [l, r].
void updata(long long t, long long l, long long r, long long ql, long long qr, long long k) {
    if (ql <= l && r <= qr) {
        up(t, l, r, k);
        return;
    }
    pushdown(l, r, t);
    long long mid = (l + r) / 2;
    if (ql <= mid) updata(t * 2, l, mid, ql, qr, k);
    if (mid < qr) updata(t * 2 + 1, mid + 1, r, ql, qr, k);
    tre[t] = (tre[t * 2] + tre[t * 2 + 1]) % mod;
}

// First pass: fill fa, dep, siz and son for the tree rooted at x (parent fat).
// Uses an explicit stack: a path-shaped tree would otherwise recurse about n
// levels deep, which can overflow a default-sized call stack.
void dfs1(long long x, long long depth, long long fat) {
    long long sp = 0, len = 0;
    fa[x] = fat;
    dep[x] = depth;
    stk[sp++] = x;
    while (sp > 0) {
        long long u = stk[--sp];
        ord[len++] = u;
        siz[u] = 1;
        for (long long i = head[u]; i; i = e[i].nxt) {
            long long v = e[i].to;
            if (v != fa[u]) {
                fa[v] = u;
                dep[v] = dep[u] + 1;
                stk[sp++] = v;
            }
        }
    }
    // Every node appears in `ord` after its parent, so a reverse sweep has
    // completed each subtree before folding it into the parent (ord[0] is x).
    for (long long k = len - 1; k > 0; --k) siz[fa[ord[k]]] += siz[ord[k]];
    // Heavy son: the first child (in adjacency order) with the largest subtree.
    for (long long k = 0; k < len; ++k) {
        long long u = ord[k];
        long long maxn = -1;
        for (long long i = head[u]; i; i = e[i].nxt) {
            long long v = e[i].to;
            if (v != fa[u] && siz[v] > maxn) {
                son[u] = v;
                maxn = siz[v];
            }
        }
    }
}

// Add k to every node on the path l -> r.
void qupdata(long long l, long long r, long long k) {
    k %= mod;
    long long fx = l, fy = r;
    while (top[fx] != top[fy]) {
        // False 1 (fixed): compare the depths of the chain tops,
        // dep[top[fx]] < dep[top[fy]], not dep[fx] < dep[fy].
        if (dep[top[fx]] < dep[top[fy]]) std::swap(fx, fy);
        updata(1, 1, n, rk[top[fx]], rk[fx], k);
        fx = fa[top[fx]];
    }
    if (dep[fx] > dep[fy]) std::swap(fx, fy);
    updata(1, 1, n, rk[fx], rk[fy], k);
}

// Sum (mod `mod`) of the values on the path l -> r.
long long qquery(long long l, long long r) {
    long long ans = 0;
    long long fx = l, fy = r;
    while (top[fx] != top[fy]) {
        if (dep[top[fx]] < dep[top[fy]]) std::swap(fx, fy);
        ans = (ans + query(1, 1, n, rk[top[fx]], rk[fx])) % mod;
        fx = fa[top[fx]];
    }
    if (dep[fx] > dep[fy]) std::swap(fx, fy);
    ans = (ans + query(1, 1, n, rk[fx], rk[fy])) % mod;
    return ans;
}

// Add k to every node in the subtree of x.
void rupdata(long long x, long long k) {
    k %= mod;
    updata(1, 1, n, rk[x], rk[x] + siz[x] - 1, k);
}

// Sum (mod `mod`) of the values in the subtree of x.
long long rquery(long long x) {
    return query(1, 1, n, rk[x], rk[x] + siz[x] - 1) % mod;
}

// Second pass: assign dfs positions (heavy son first) and chain tops, starting
// from x whose chain top is t. Iterative for the same reason as dfs1.
void dfs2(long long x, long long t) {
    long long sp = 0;
    top[x] = t;
    stk[sp++] = x;
    while (sp > 0) {
        long long u = stk[--sp];
        rk[u] = ++tot;
        tp[tot] = val[u];
        // Light children start their own chains. They are pushed before the
        // heavy son so the heavy son is popped next and gets position rk[u] + 1.
        for (long long i = head[u]; i; i = e[i].nxt) {
            long long v = e[i].to;
            if (v != fa[u] && v != son[u]) {
                top[v] = v;
                stk[sp++] = v;
            }
        }
        if (son[u]) {
            top[son[u]] = top[u];
            stk[sp++] = son[u];
        }
    }
}

int main() {
    // n < 1 would make build(1, 1, 0) recurse forever.
    if (scanf("%lld%lld%lld%lld", &n, &m, &rt, &mod) != 4 || n < 1) return 0;
    for (long long i = 1; i <= n; i++) scanf("%lld", &val[i]);
    long long x, y, z, a;
    for (long long i = 1; i < n; i++) {
        scanf("%lld%lld", &x, &y);
        add(x, y);
        add(y, x);
    }
    dfs1(rt, 1, 0);
    dfs2(rt, rt);
    build(1, 1, n);
    while (m-- > 0 && scanf("%lld", &a) == 1) {
        if (a == 1) {
            scanf("%lld%lld%lld", &x, &y, &z);
            qupdata(x, y, z);
        } else if (a == 2) {
            scanf("%lld%lld", &x, &y);
            printf("%lld\n", qquery(x, y) % mod);
        } else if (a == 3) {
            scanf("%lld%lld", &x, &y);
            rupdata(x, y);
        } else if (a == 4) {
            scanf("%lld", &x);
            printf("%lld\n", rquery(x) % mod);
        }
    }
    return 0;
}