一道有趣的DP题

题目大意

给你一个树,每个点有个点权,每条边有个概率出现或不出现

求询问点所在的连通块的和的平方的期望值。

n,q100000n,q \leq 100000

题解

好棒的一道题啊!!

考虑每次和的平方拆开

(f[i])2=jS,kSf[j]f[k](\sum f[i])^2=\sum_{j \in S,k \in S }f[j]*f[k]

Pj,kP_{j,k}表示j与k和i都连通的概率,QjQ_j表示j与i连通的概率

那么ans=Qjf[i]2+2Pj,k(f[j]f[k])ans=\sum Q_j*f[i]^2+2*\sum P_{j,k}*(f[j]*f[k])

每次正着DP一遍,反着DP一遍即可

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
#include <bits/stdc++.h>
using namespace std;
#define int long long
const int N=400010;
const int mod=998244353;
int fi[N*2],ne[N*2],la[N*2],a[N*2],c[N*2];
int i,j,k,n,m,x,y,t;
int d[N],f[N][21],s[N];
int v[N],g[N],g1[N],q1[N],q2[N],q3[N],q4[N],p1[N],p2[N];
struct data{int x,y;}w[N];
int read(){
int x=0,f=1;
char ch=getchar();
while (ch<'0'||ch>'9'){f=ch=='-'?-f:f;ch=getchar();}
while (ch>='0'&&ch<='9'){x=x*10+ch-48;ch=getchar();}
return x*f;
}
void add(int x,int y,int t){
k++;a[k]=y;c[k]=t;
if (fi[x]==0)fi[x]=k;else ne[la[x]]=k;
la[x]=k;
}
void dfs(int x,int fa){
for (int i=1;i<21;i++)f[x][i]=f[f[x][i-1]][i-1];
for (int i=fi[x];i;i=ne[i])
if (a[i]!=fa){
f[a[i]][0]=x;d[a[i]]=d[x]+1;
dfs(a[i],x);
}
}
void getv(int x,int fa){
for (int i=fi[x];i;i=ne[i])
if (a[i]!=fa){
getv(a[i],x);
s[x]=(s[x]+s[a[i]])%mod;
}
}
void dp1(int x,int fa){
g[x]=s[x];p1[x]=1ll*s[x]*s[x]%mod;q1[x]=0;
for (int i=fi[x];i;i=ne[i])
if (a[i]!=fa){
dp1(a[i],x);
q4[x]=(q4[x]+1ll*q1[a[i]]*c[i]%mod)%mod;
q1[x]=(q1[x]+1ll*q1[a[i]]*c[i]%mod)%mod;
q1[x]=(q1[x]+1ll*g[a[i]]*c[i]%mod*g[x]%mod)%mod;
p1[x]=(p1[x]+1ll*p1[a[i]]*c[i]%mod)%mod;
g[x]=(g[x]+1ll*g[a[i]]*c[i]%mod)%mod;
}
}
void dp2(int x,int fa,int p){
p2[x]=(1ll*p2[f[x][0]]*p%mod+1ll*(p1[f[x][0]]-1ll*p1[x]*p%mod+mod)*p%mod)%mod;
q2[x]=(1ll*(1ll*g1[f[x][0]]*p%mod+1ll*(g[f[x][0]]-1ll*g[x]*p%mod+mod)%mod*p%mod)*g[x]%mod)%mod;
g1[x]=(1ll*g1[f[x][0]]*p%mod+1ll*(g[f[x][0]]-1ll*g[x]*p%mod+mod)%mod*p%mod)%mod;
// q3[x]=(1ll*q3[f[x][0]]*p%mod+1ll*g1[f[x][0]]*(g[f[x][0]]-1ll*g[x]*p%mod+mod)%mod*p%mod)%mod;
q3[x]=1ll*q3[f[x][0]]%mod;//case 3
q3[x]=(q3[x]+1ll*g1[f[x][0]]*(g[f[x][0]]-1ll*g[x]*p%mod+mod)%mod)%mod;//case2
q3[x]=(q3[x]+q1[f[x][0]]-q1[x]*p%mod+mod)%mod;
q3[x]=(q3[x]-g[x]*(g[f[x][0]]-g[x]*p%mod+mod)%mod*p%mod+mod)%mod;
// printf("%lld %lld\n",q1[f[x][0]],(q1[x]*p%mod+g[x]*(g[f[x][0]]-g[x]*p%mod)%mod*p%mod)%mod);
q3[x]=q3[x]*p%mod;
for (int i=fi[x];i;i=ne[i])if (a[i]!=fa)dp2(a[i],x,c[i]);
}
main(){
n=read();
for (i=1;i<=n;i++)w[i].x=read(),w[i].y=read();
for (i=1;i<n;i++){
x=read();y=read();t=read();
add(x,y,t);add(y,x,t);
}
d[1]=1;
dfs(1,-1);
for (i=1;i<=n;i++){
s[i]=(s[i]+w[i].x)%mod;
y=w[i].y;x=i;
for (j=20;j>=0;j--)
if (y>=(1<<j)){
y-=1<<j;
x=f[x][j];
}
s[f[x][0]]=(s[f[x][0]]+mod-w[i].x)%mod;
}
getv(1,-1);
dp1(1,-1);
dp2(1,0,0);
// printf("s:");for (i=1;i<=n;i++)printf("%lld ",s[i]);printf("\n");
// printf("g:");for (i=1;i<=n;i++)printf("%lld ",g[i]);printf("\n");
// printf("q1:");for (i=1;i<=n;i++)printf("%lld ",q1[i]);printf("\n");
// printf("p1:");for (i=1;i<=n;i++)printf("%lld ",p1[i]);printf("\n");
// printf("g1:");for (i=1;i<=n;i++)printf("%lld ",g1[i]);printf("\n");
// printf("q2:");for (i=1;i<=n;i++)printf("%lld ",q2[i]);printf("\n");
// printf("p2:");for (i=1;i<=n;i++)printf("%lld ",p2[i]);printf("\n");
// printf("q3:");for (i=1;i<=n;i++)printf("%lld ",q3[i]);printf("\n");
scanf("%lld",&m);
while (m--){
scanf("%lld",&x);
// dp1(x,-1);
// printf("%lld\n",(2ll*q1[x]%mod+p1[x])%mod);
printf("%lld\n",(2ll*((q1[x]+q2[x])%mod+q3[x])%mod+(p1[x]+p2[x])%mod)%mod);
}
return 0;
}
,