1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101
| #include <bits/stdc++.h> using namespace std; #define int long long const int N=400010; const int mod=998244353; int fi[N*2],ne[N*2],la[N*2],a[N*2],c[N*2]; int i,j,k,n,m,x,y,t; int d[N],f[N][21],s[N]; int v[N],g[N],g1[N],q1[N],q2[N],q3[N],q4[N],p1[N],p2[N]; struct data{int x,y;}w[N]; int read(){ int x=0,f=1; char ch=getchar(); while (ch<'0'||ch>'9'){f=ch=='-'?-f:f;ch=getchar();} while (ch>='0'&&ch<='9'){x=x*10+ch-48;ch=getchar();} return x*f; } void add(int x,int y,int t){ k++;a[k]=y;c[k]=t; if (fi[x]==0)fi[x]=k;else ne[la[x]]=k; la[x]=k; } void dfs(int x,int fa){ for (int i=1;i<21;i++)f[x][i]=f[f[x][i-1]][i-1]; for (int i=fi[x];i;i=ne[i]) if (a[i]!=fa){ f[a[i]][0]=x;d[a[i]]=d[x]+1; dfs(a[i],x); } } void getv(int x,int fa){ for (int i=fi[x];i;i=ne[i]) if (a[i]!=fa){ getv(a[i],x); s[x]=(s[x]+s[a[i]])%mod; } } void dp1(int x,int fa){ g[x]=s[x];p1[x]=1ll*s[x]*s[x]%mod;q1[x]=0; for (int i=fi[x];i;i=ne[i]) if (a[i]!=fa){ dp1(a[i],x); q4[x]=(q4[x]+1ll*q1[a[i]]*c[i]%mod)%mod; q1[x]=(q1[x]+1ll*q1[a[i]]*c[i]%mod)%mod; q1[x]=(q1[x]+1ll*g[a[i]]*c[i]%mod*g[x]%mod)%mod; p1[x]=(p1[x]+1ll*p1[a[i]]*c[i]%mod)%mod; g[x]=(g[x]+1ll*g[a[i]]*c[i]%mod)%mod; } } void dp2(int x,int fa,int p){ p2[x]=(1ll*p2[f[x][0]]*p%mod+1ll*(p1[f[x][0]]-1ll*p1[x]*p%mod+mod)*p%mod)%mod; q2[x]=(1ll*(1ll*g1[f[x][0]]*p%mod+1ll*(g[f[x][0]]-1ll*g[x]*p%mod+mod)%mod*p%mod)*g[x]%mod)%mod; g1[x]=(1ll*g1[f[x][0]]*p%mod+1ll*(g[f[x][0]]-1ll*g[x]*p%mod+mod)%mod*p%mod)%mod;
q3[x]=1ll*q3[f[x][0]]%mod; q3[x]=(q3[x]+1ll*g1[f[x][0]]*(g[f[x][0]]-1ll*g[x]*p%mod+mod)%mod)%mod; q3[x]=(q3[x]+q1[f[x][0]]-q1[x]*p%mod+mod)%mod; q3[x]=(q3[x]-g[x]*(g[f[x][0]]-g[x]*p%mod+mod)%mod*p%mod+mod)%mod;
q3[x]=q3[x]*p%mod; for (int i=fi[x];i;i=ne[i])if (a[i]!=fa)dp2(a[i],x,c[i]); } main(){ n=read(); for (i=1;i<=n;i++)w[i].x=read(),w[i].y=read(); for (i=1;i<n;i++){ x=read();y=read();t=read(); add(x,y,t);add(y,x,t); } d[1]=1; dfs(1,-1); for (i=1;i<=n;i++){ s[i]=(s[i]+w[i].x)%mod; y=w[i].y;x=i; for (j=20;j>=0;j--) if (y>=(1<<j)){ y-=1<<j; x=f[x][j]; } s[f[x][0]]=(s[f[x][0]]+mod-w[i].x)%mod; } getv(1,-1); dp1(1,-1); dp2(1,0,0);
scanf("%lld",&m); while (m--){ scanf("%lld",&x);
printf("%lld\n",(2ll*((q1[x]+q2[x])%mod+q3[x])%mod+(p1[x]+p2[x])%mod)%mod); } return 0; }
|