- 题解
数据传输
- @ 2025-10-24 19:56:47
#include<bits/stdc++.h>
using namespace std;
using LL = long long;

// Capacity limits: N nodes (with slack), M directed arcs (two per undirected edge),
// K binary-lifting levels (jumps of 2^0 .. 2^(K-1) edges).
constexpr int N = 200010;
constexpr int M = N * 2;
constexpr int K = 18;
constexpr LL INF = 1e18;

// Input sizes: n nodes, Q queries, m = number of DP states per node
// (the third input value; the transition builder handles m = 1, 2, 3).
int n, Q, m;
// w[v]: cost paid when node v is used.
int w[N];
// Adjacency lists as linked arcs: h[v] = first arc of v (-1 = none),
// ne[i] = next arc in the same list, e[i] = target of arc i, idx = arcs used.
int h[N], e[M], ne[M], idx;
// fa[v][k]: 2^k-th ancestor of v (0 above the root); dep[v]: depth (root = 1);
// g[v]: smallest weight among v's neighbours.
int fa[N][K], dep[N], g[N];
// f[v][k][x][y]: cheapest cost of climbing 2^k edges from v, starting in state x
// at v and arriving in state y at the 2^k-th ancestor.
LL f[N][K][3][3];
// Prepend the directed arc a -> b to a's adjacency list.
// h[a] points at the newest arc, ne[] chains to older arcs, e[] holds the target.
void add(int a, int b)
{
    e[idx] = b;
    ne[idx] = h[a];
    h[a] = idx;
    ++ idx;
}
// Root the tree at u and fill depths and binary-lifting ancestors.
//
// u      : root of the traversal; dep[u] is set to `depth`.
// father : neighbour of u that must not be treated as a child (-1 for the real root).
// depth  : depth assigned to u; every child gets its parent's depth + 1.
//
// For every other reachable node j, sets dep[j], fa[j][0] (its parent) and
// fa[j][k] = fa[fa[j][k-1]][k-1]. fa[] of u itself is left untouched (stays 0 for the root).
//
// An explicit (node, parent) stack replaces the original recursion. On a
// path-shaped tree with ~2e5 nodes, recursion would nest ~2e5 call frames and can
// overflow the default stack. Each node is expanded only after its parent has
// been discovered, so fa[] of all its ancestors is already final when fa[j][k]
// is computed.
void dfs_fa(int u, int father, int depth)
{
    std::vector<std::pair<int, int>> stk;  // (node, its parent)
    dep[u] = depth;
    stk.push_back(std::make_pair(u, father));
    while (!stk.empty())
    {
        int v = stk.back().first, p = stk.back().second;
        stk.pop_back();
        for (int i = h[v]; ~i; i = ne[i])
        {
            int j = e[i];
            if (j == p) continue;
            dep[j] = dep[v] + 1;
            fa[j][0] = v;
            for (int k = 1; k < K; k ++ )
                fa[j][k] = fa[fa[j][k - 1]][k - 1];
            stk.push_back(std::make_pair(j, v));
        }
    }
}
// Lowest common ancestor of a and b, using the fa[][] / dep[] tables from dfs_fa.
int lca(int a, int b)
{
    // Make a the deeper endpoint.
    if (dep[a] < dep[b]) swap(a, b);

    // Lift a onto b's level. dep[0] is 0, so a jump past the root is never taken.
    for (int k = K - 1; k >= 0; -- k)
    {
        int up = fa[a][k];
        if (dep[up] >= dep[b]) a = up;
    }
    if (a == b) return a;

    // Lift both while their 2^k-th ancestors still differ. Afterwards they are
    // distinct children of the answer.
    for (int k = K - 1; k >= 0; -- k)
    {
        if (fa[a][k] == fa[b][k]) continue;
        a = fa[a][k];
        b = fa[b][k];
    }
    return fa[a][0];
}
// Fill the binary-lifting transition tables f[j][k][x][y] for every node j below u.
//
// State of a node = how many path edges separate it from the most recently
// paid-for node (0 = the node itself was paid for). Only states 0..m-1 are used.
// f[j][0] is the one-step table from j to its parent v:
//   * any state -> 0 : pay w[v]   (use the parent)
//   * 0 -> 1, 1 -> 2 : free       (skip the parent)
//   * m == 3, 2 -> 2 : pay g[j], the smallest neighbour weight of j
//                      (lets the route detour through a neighbour of j)
// f[j][k] is f[j][k-1] followed by f[fa[j][k-1]][k-1], composed in the (min, +) sense.
// Entries never written keep the 0x3f3f... sentinel that main stores in f before this call.
//
// u / father : traversal root and the neighbour to exclude from its children.
// Requires fa[][] from dfs_fa. Uses an explicit (node, parent) stack instead of
// recursion, so a path-shaped tree cannot overflow the call stack. A node's
// tables are built when it is discovered from its parent, after all of its
// ancestors' tables, which is the order the doubling step needs.
void dfs_f(int u, int father)
{
    std::vector<std::pair<int, int>> stk;  // (node, its parent)
    stk.push_back(std::make_pair(u, father));
    while (!stk.empty())
    {
        int v = stk.back().first, p = stk.back().second;
        stk.pop_back();
        for (int i = h[v]; ~i; i = ne[i])
        {
            int j = e[i];
            if (j == p) continue;
            auto F = f[j][0];
            if (m == 1)
            {
                F[0][0] = w[v];
            }
            else if (m == 2)
            {
                F[0][0] = F[1][0] = w[v];
                F[0][1] = 0;
            }
            else if (m == 3)
            {
                F[0][0] = F[1][0] = F[2][0] = w[v];
                F[0][1] = F[1][2] = 0;
                F[2][2] = g[j];
            }
            for (int k = 1; k < K; k ++ )
                for (int x = 0; x < m; x ++ )
                    for (int y = 0; y < m; y ++ )
                        for (int z = 0; z < m; z ++ )
                            f[j][k][x][y] = min(f[j][k][x][y], f[j][k - 1][x][z] + f[fa[j][k - 1]][k - 1][z][y]);
            stk.push_back(std::make_pair(j, v));
        }
    }
}
void calc(int a, int p, LL r[])
{
r[0] = w[a], r[1] = r[2] = INF;
LL nr[3];
for (int k = K - 1; k >= 0; k -- )
if (dep[fa[a][k]] >= dep[p])
{
memset(nr, 0x3f, sizeof nr);
for (int x = 0; x < m; x ++ )
for (int y = 0; y < m; y ++ )
nr[y] = min(nr[y], r[x] + f[a][k][x][y]);
memcpy(r, nr, sizeof nr);
a = fa[a][k];
}
}
// Minimum total cost of a route from a to b. Each endpoint is climbed to
// p = lca(a, b) and the two halves are joined at p.
LL query(int a, int b)
{
    const int p = lca(a, b);
    LL fromA[3], fromB[3];
    calc(a, p, fromA);
    calc(b, p, fromB);

    // Both halves paid for p when they end in state 0, so count w[p] once.
    LL best = fromA[0] + fromB[0] - w[p];
    for (int x = 0; x < m; ++ x)
        for (int y = 0; y < m; ++ y)
        {
            // Two state-2 halves are joined through p's cheapest neighbour, g[p].
            const LL bridge = (x == 2 && y == 2) ? g[p] : 0;
            best = min(best, fromA[x] + fromB[y] + bridge);
        }
    return best;
}
// Read the tree and node weights, build the lifting tables, then answer each query.
int main()
{
    scanf("%d%d%d", &n, &Q, &m);
    for (int i = 1; i <= n; ++ i)
        scanf("%d", &w[i]);

    // Empty adjacency lists (-1 terminates a list). g[] starts at a large
    // sentinel before taking neighbour minima.
    memset(h, -1, sizeof h);
    memset(g, 0x3f, sizeof g);

    for (int rest = n - 1; rest > 0; -- rest)
    {
        int a, b;
        scanf("%d%d", &a, &b);
        add(a, b);
        add(b, a);
        g[a] = min(g[a], w[b]);
        g[b] = min(g[b], w[a]);
    }

    dfs_fa(1, -1, 1);
    // Every transition starts as "impossible"; dfs_f only writes the allowed ones.
    memset(f, 0x3f, sizeof f);
    dfs_f(1, -1);

    while (Q -- )
    {
        int a, b;
        scanf("%d%d", &a, &b);
        printf("%lld\n", query(a, b));
    }
    return 0;
}
1 条评论
-
邓天润 LV 8 @ 2025-10-24 19:59:41 附我的极丑代码
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;

const int N = 2e5 + 5;
// "Unreachable" sentinel; assumes every real answer stays below it.
const LL INF = 1e15 + 9586;

int n, T, k, head[N], ver[N * 2], nxt[N * 2], tot = 0;
int f[N][22];  // f[v][j]: 2^j-th ancestor of v (0 above the root)
bool vis[N];

// Dense (min, +) matrix with n rows and m columns (at most 3x3).
struct matrix {
    int n, m;
    LL c[3][3];
    void _init(int x, int y) { n = x, m = y; }
};

// base_h[v][j]: (min, +) product of the per-node matrices (generate_h) of the
// 2^j nodes starting at v and going upward (v, parent of v, ...).
matrix base_2[N][21], base_3[N][21];
LL S[N];                  // S[v]: sum of w on the path root..v
LL w[N], depth[N], p[N];  // p[v]: min weight among v's children (patched for the LCA per query)

void add(int x, int y) {
    ver[++tot] = y, nxt[tot] = head[x], head[x] = tot;
}

// (min, +) matrix product: d[i][j] = min over k1 of a[i][k1] + b[k1][j].
matrix operator*(const matrix &a, const matrix &b) {
    matrix d;
    d._init(a.n, b.m);
    for (int i = 0; i < a.n; i++) {
        for (int j = 0; j < b.m; j++) {
            d.c[i][j] = INF;
            for (int k1 = 0; k1 < a.m; k1++)
                d.c[i][j] = min(d.c[i][j], a.c[i][k1] + b.c[k1][j]);
        }
    }
    return d;
}

// Per-node transition matrix for k = 2.
matrix generate_2(int x) {
    matrix tmp;
    tmp._init(2, 2);
    tmp.c[0][0] = w[x], tmp.c[0][1] = 0;
    tmp.c[1][0] = w[x], tmp.c[1][1] = INF;
    return tmp;
}

// Per-node transition matrix for k = 3; c[1][1] = p[x] uses x's cheapest child.
matrix generate_3(int x) {
    matrix tmp;
    tmp._init(3, 3);
    tmp.c[0][0] = w[x], tmp.c[0][1] = 0,    tmp.c[0][2] = INF;
    tmp.c[1][0] = w[x], tmp.c[1][1] = p[x], tmp.c[1][2] = 0;
    tmp.c[2][0] = w[x], tmp.c[2][1] = INF,  tmp.c[2][2] = INF;
    return tmp;
}

// p[v] = min(p[v], weight of each child of v), rooted at x with parent fa.
// Uses an explicit stack instead of recursion so a long chain cannot overflow the call stack.
void dfs(int x, int fa) {
    vector<pair<int, int>> stk;  // (node, its parent)
    stk.push_back(make_pair(x, fa));
    while (!stk.empty()) {
        int u = stk.back().first, par = stk.back().second;
        stk.pop_back();
        for (int i = head[u]; i; i = nxt[i]) {
            int y = ver[i];
            if (y == par) continue;
            p[u] = min(p[u], w[y]);
            stk.push_back(make_pair(y, u));
        }
    }
}

// BFS from node 1: fills parents, depths, prefix sums S and the doubling tables.
// BFS order guarantees every ancestor's tables exist before a node uses them.
void bfs() {
    queue<int> q;
    q.push(1);
    S[1] = w[1];
    vis[1] = 1;
    depth[1] = 1;
    base_2[1][0] = generate_2(1);
    base_3[1][0] = generate_3(1);
    while (!q.empty()) {
        int u = q.front();
        q.pop();
        for (int i = head[u]; i; i = nxt[i]) {
            int v = ver[i];
            if (vis[v]) continue;
            vis[v] = 1;
            f[v][0] = u;
            S[v] = S[u] + w[v];
            depth[v] = depth[u] + 1;
            q.push(v);
            base_2[v][0] = generate_2(v);
            base_3[v][0] = generate_3(v);
            for (int j = 1; j <= 20; j++) {
                f[v][j] = f[f[v][j - 1]][j - 1];
                base_2[v][j] = base_2[v][j - 1] * base_2[f[v][j - 1]][j - 1];
                base_3[v][j] = base_3[v][j - 1] * base_3[f[v][j - 1]][j - 1];
            }
        }
    }
}

int lca(int x, int y) {
    if (depth[x] > depth[y]) swap(x, y);
    for (int i = 20; i >= 0; i--)
        if (depth[f[y][i]] >= depth[x]) y = f[y][i];
    if (x == y) return x;
    for (int i = 20; i >= 0; i--)
        if (f[x][i] != f[y][i]) x = f[x][i], y = f[y][i];
    return f[x][0];
}

// Product of the per-node matrices from x up to its ancestor y, both inclusive
// (h = 2 or 3 selects the table). y must be x itself or an ancestor of x.
matrix get(int x, int y, int h) {
    matrix tmp;
    bool flag = false;
    for (int i = 20; i >= 0; i--) {
        if (depth[f[x][i]] >= depth[y]) {
            if (!flag) {
                if (h == 2) tmp = base_2[x][i];
                else tmp = base_3[x][i];
                flag = true;
            } else {
                if (h == 2) tmp = tmp * base_2[x][i];
                else tmp = tmp * base_3[x][i];
            }
            x = f[x][i];
        }
    }
    // x == y here; append y's own matrix.
    if (!flag) {
        if (h == 2) tmp = base_2[x][0];
        else tmp = base_3[x][0];
    } else {
        if (h == 2) tmp = tmp * base_2[x][0];
        else tmp = tmp * base_3[x][0];
    }
    return tmp;
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(0), cout.tie(0);
    cin >> n >> T >> k;
    for (int i = 1; i <= n; i++) {
        cin >> w[i];
        p[i] = INF;
    }
    int a, b;
    for (int i = 1; i < n; i++) {
        cin >> a >> b;
        add(a, b), add(b, a);
    }
    dfs(1, 0);
    bfs();
    // '\n' instead of endl: endl flushes the stream once per query.
    while (T--) {
        int s, t;
        cin >> s >> t;
        int l = lca(s, t);
        if (k == 1) {
            cout << S[s] + S[t] - 2ll * S[l] + w[l] << '\n';
        } else if (k == 2) {
            matrix ox, oy;
            ox._init(1, 2);
            ox.c[0][0] = w[s], ox.c[0][1] = INF;
            if (s != l) ox = ox * get(f[s][0], l, 2);
            oy._init(1, 2);
            oy.c[0][0] = w[t], oy.c[0][1] = INF;
            if (t != l) oy = oy * get(f[t][0], l, 2);
            LL res = INF;
            res = min(res, ox.c[0][0] + oy.c[0][0] - w[l]);
            res = min(res, ox.c[0][0] + oy.c[0][1]);
            res = min(res, ox.c[0][1] + oy.c[0][0]);
            res = min(res, ox.c[0][1] + oy.c[0][1]);
            cout << res << '\n';
        } else {
            // At the LCA the parent is also an off-path neighbour: patch p[l] and
            // l's matrix for this query only, then restore both afterwards.
            LL pb = p[l];
            if (l != 1) p[l] = min(p[l], w[f[l][0]]);
            base_3[l][0] = generate_3(l);
            matrix ox, oy;
            ox._init(1, 3), oy._init(1, 3);
            ox.c[0][0] = w[s], ox.c[0][1] = INF, ox.c[0][2] = INF;
            oy.c[0][0] = w[t], oy.c[0][1] = INF, oy.c[0][2] = INF;
            if (s != l) ox = ox * get(f[s][0], l, 3);
            if (t != l) oy = oy * get(f[t][0], l, 3);
            LL res = INF;
            res = min(res, ox.c[0][0] + oy.c[0][0] - w[l]);
            res = min(res, ox.c[0][0] + oy.c[0][1]);
            res = min(res, ox.c[0][0] + oy.c[0][2]);
            res = min(res, ox.c[0][1] + oy.c[0][0]);
            res = min(res, ox.c[0][1] + oy.c[0][1]);
            res = min(res, ox.c[0][1] + oy.c[0][2]);
            res = min(res, ox.c[0][2] + oy.c[0][0]);
            res = min(res, ox.c[0][2] + oy.c[0][1]);
            cout << res << '\n';
            p[l] = pb;
            base_3[l][0] = generate_3(l);
        }
    }
    return 0;
}
/* 5 100 2 1 2 3 4 5 1 2 2 3 1 4 4 5 3 4 */
- 1