#include<bits/stdc++.h>

using namespace std;

// 64-bit type used for every accumulated path cost.
using LL = long long;

constexpr int N = 200010;   // vertex capacity (with slack)
constexpr int M = N * 2;    // every undirected edge is stored twice
constexpr int K = 18;       // binary-lifting levels 0..17 cover any depth < 2^18
constexpr LL INF = 1e18;    // "not reachable" marker for DP row vectors

// Input: vertex count, query count, hop limit (1..3), vertex weights.
int n, Q, m;
int w[N];

// Adjacency list: h[u] = first edge id (-1 = none), e[] = target, ne[] = next id.
int h[N], e[M], ne[M], idx;

int fa[N][K];   // fa[u][k]: 2^k-th ancestor of u (0 once past the root)
int dep[N];     // depth of u, root has depth 1
int g[N];       // g[u]: smallest weight among u's neighbours

// f[u][k]: (min,+) transition matrix for climbing 2^k edges from u to fa[u][k];
// entry [x][y] maps DP state x at u to state y at the ancestor.
LL f[N][K][3][3];

void add(int a, int b)
{
    e[idx] = b, ne[idx] = h[a], h[a] = idx ++ ;
}

// Fill dep[] and the binary-lifting table fa[][] for every vertex in the subtree
// of u. `father` is u's parent (-1 for the root) and `depth` becomes dep[u];
// fa[u] itself is not touched.
//
// Uses an explicit stack instead of recursion: the recursion depth equals the
// tree height, and a path-shaped tree with ~2e5 vertices can exhaust a default
// 8 MB call stack.
void dfs_fa(int u, int father, int depth)
{
    struct Frame { int v, parent, d; };
    std::vector<Frame> stk;
    stk.push_back({u, father, depth});

    while (!stk.empty())
    {
        const Frame cur = stk.back();
        stk.pop_back();
        dep[cur.v] = cur.d;

        for (int i = h[cur.v]; ~i; i = ne[i])
        {
            int j = e[i];
            if (j == cur.parent) continue;

            // Every ancestor of j got its fa[] row when it was pushed,
            // so these lookups read final values.
            fa[j][0] = cur.v;
            for (int k = 1; k < K; k ++ )
                fa[j][k] = fa[fa[j][k - 1]][k - 1];

            stk.push_back({j, cur.v, cur.d + 1});
        }
    }
}

// Lowest common ancestor of a and b by binary lifting over dep[] / fa[][].
int lca(int a, int b)
{
    int lo = a, hi = b;                        // lo ends up as the deeper endpoint
    if (dep[lo] < dep[hi]) std::swap(lo, hi);

    // Raise the deeper endpoint to the other's depth.
    for (int k = K - 1; k >= 0; --k)
        if (dep[fa[lo][k]] >= dep[hi]) lo = fa[lo][k];
    if (lo == hi) return lo;

    // Raise both until they sit just below their common ancestor.
    for (int k = K - 1; k >= 0; --k)
    {
        const int x = fa[lo][k];
        const int y = fa[hi][k];
        if (x != y)
        {
            lo = x;
            hi = y;
        }
    }
    return fa[lo][0];
}

// Build the jump matrices f[j][k] for every vertex j strictly below u.
//
// DP state at a vertex = number of edges back to the most recent relay
// (0 = the vertex itself is a relay). f[j][0] is the (min,+) transition from
// child j to its parent: becoming a relay costs the parent's weight, and stepping
// without a relay moves state x to x + 1. For m == 3 there is also a detour
// through j's lightest neighbour (g[j]), which keeps the state at 2.
// f[j][k] chains two 2^(k-1) jumps.
//
// Preconditions: dfs_fa has filled fa[][], and f has been pre-filled with a huge
// value (main uses memset 0x3f) so unset transitions act as "impossible".
//
// Iterative for the same reason as dfs_fa: recursion depth equals tree height.
void dfs_f(int u, int father)
{
    std::vector<std::pair<int, int>> stk;   // (vertex, its parent)
    stk.push_back({u, father});

    while (!stk.empty())
    {
        const int cur = stk.back().first;
        const int par = stk.back().second;
        stk.pop_back();

        for (int i = h[cur]; ~i; i = ne[i])
        {
            int j = e[i];
            if (j == par) continue;

            auto F = f[j][0];
            if (m == 1)
            {
                // Every vertex must be a relay.
                F[0][0] = w[cur];
            }
            else if (m == 2)
            {
                F[0][0] = F[1][0] = w[cur];
                F[0][1] = 0;
            }
            else if (m == 3)
            {
                F[0][0] = F[1][0] = F[2][0] = w[cur];
                F[0][1] = F[1][2] = 0;
                F[2][2] = g[j];
            }

            // Ancestors of j had their f rows built when they were pushed.
            for (int k = 1; k < K; k ++ )
                for (int x = 0; x < m; x ++ )
                    for (int y = 0; y < m; y ++ )
                        for (int z = 0; z < m; z ++ )
                            f[j][k][x][y] = std::min(f[j][k][x][y], f[j][k - 1][x][z] + f[fa[j][k - 1]][k - 1][z][y]);

            stk.push_back({j, cur});
        }
    }
}

void calc(int a, int p, LL r[])
{
    r[0] = w[a], r[1] = r[2] = INF;
    LL nr[3];

    for (int k = K - 1; k >= 0; k -- )
        if (dep[fa[a][k]] >= dep[p])
        {
            memset(nr, 0x3f, sizeof nr);
            for (int x = 0; x < m; x ++ )
                for (int y = 0; y < m; y ++ )
                    nr[y] = min(nr[y], r[x] + f[a][k][x][y]);
            memcpy(r, nr, sizeof nr);

            a = fa[a][k];
        }
}

// Cheapest relay chain between a and b: lift both ends to their LCA p and join
// the two DP state vectors there.
LL query(int a, int b)
{
    const int p = lca(a, b);
    LL up_a[3], up_b[3];
    calc(a, p, up_a);
    calc(b, p, up_b);

    // Both halves made p a relay, so its weight was paid twice.
    LL best = up_a[0] + up_b[0] - w[p];

    for (int x = 0; x < m; ++x)
        for (int y = 0; y < m; ++y)
        {
            // x == y == 2 only occurs when m == 3. The two last relays are then
            // 4 edges apart, so a relay on p's lightest neighbour (g[p]) is needed.
            const LL extra = (x == 2 && y == 2) ? static_cast<LL>(g[p]) : 0;
            best = std::min(best, up_a[x] + up_b[y] + extra);
        }
    return best;
}

// Read the tree, build the lifting tables, then answer each query on its own line.
int main()
{
    scanf("%d%d%d", &n, &Q, &m);
    for (int v = 1; v <= n; ++v) scanf("%d", &w[v]);

    // All adjacency lists start empty; g[] starts "huge" so any neighbour beats it.
    memset(h, -1, sizeof h);
    memset(g, 0x3f, sizeof g);
    for (int edge = 1; edge < n; ++edge)
    {
        int a, b;
        scanf("%d%d", &a, &b);
        add(a, b);
        add(b, a);
        // g[x] tracks the lightest neighbour of x.
        if (w[b] < g[a]) g[a] = w[b];
        if (w[a] < g[b]) g[b] = w[a];
    }

    dfs_fa(1, -1, 1);            // root at vertex 1, depth 1
    memset(f, 0x3f, sizeof f);   // every transition starts out impossible
    dfs_f(1, -1);

    for (; Q; --Q)
    {
        int s, t;
        scanf("%d%d", &s, &t);
        printf("%lld\n", query(s, t));
    }

    return 0;
}

1 条评论

  • @ 2025-10-24 19:59:41

    附我的极丑代码

    #include <bits/stdc++.h>
    using namespace std;
    
    typedef long long LL;

    const int N = 200005;          // vertex capacity (with slack)
    const LL INF = 1e15 + 9586;    // "impossible" cost; presumably above any real path cost

    // n vertices, T queries, hop limit k.
    int n, T, k;

    // Edge storage, ids from 1 (0 ends a list): head[u] -> ver[] / nxt[].
    int head[N];
    int ver[N * 2];
    int nxt[N * 2];
    int tot = 0;

    int f[N][22];    // f[u][j]: 2^j-th ancestor of u (0 once past the root); j <= 20 is used
    bool vis[N];     // BFS visited flags
    
    // Dense matrix of at most 3x3 used with the (min,+) product.
    // n = active rows, m = active columns; entries outside that block are ignored.
    struct matrix {
        int n, m;
        LL c[3][3];

        // Set the active dimensions only; c is left as-is.
        void _init(int rows, int cols)
        {
            n = rows;
            m = cols;
        }
    };
    
    
    // base_h[v][j]: (min,+) product of the per-vertex transition matrices of the
    // 2^j vertices starting at v and going upward (h = hop limit 2 or 3).
    matrix base_2[N][21];
    matrix base_3[N][21];

    LL S[N];        // S[v]: sum of w on the root -> v path (answers k == 1)
    LL w[N];        // vertex weights
    LL depth[N];    // depth, root = 1
    LL p[N];        // p[v]: lightest child weight of v (INF when v has no child)
    
    // Prepend directed edge x -> y; ids start at 1 so that 0 terminates a list.
    void add(int x, int y)
    {
        ++tot;
        ver[tot] = y;
        nxt[tot] = head[x];
        head[x] = tot;
    }
    // (min,+) product: res[i][j] = min(INF, min over t of a[i][t] + b[t][j]).
    // Only the a.n x b.m block of the result is written.
    matrix operator *(const matrix &a, const matrix &b)
    {
        matrix res;
        res._init(a.n, b.m);
        for (int row = 0; row < a.n; ++row)
            for (int col = 0; col < b.m; ++col)
            {
                LL best = INF;
                for (int t = 0; t < a.m; ++t)
                {
                    const LL cand = a.c[row][t] + b.c[t][col];
                    if (cand < best) best = cand;
                }
                res.c[row][col] = best;
            }
        return res;
    }
    
    // Transition for stepping onto vertex x when k == 2 (row = state before x,
    // column = state at x). Column 0: x becomes a relay for w[x].
    // 0 -> 1: skip x for free. 1 -> 1 is impossible.
    matrix generate_2(int x)
    {
        matrix m2;
        m2._init(2, 2);
        const LL take = w[x];
        m2.c[0][0] = take;
        m2.c[1][0] = take;
        m2.c[0][1] = 0;
        m2.c[1][1] = INF;
        return m2;
    }
    // Transition for stepping onto vertex x when k == 3 (row = state before x,
    // column = state at x).
    //   column 0       : x becomes a relay, cost w[x], from any state
    //   0 -> 1, 1 -> 2 : skip x for free
    //   1 -> 1         : relay on a side neighbour of x, cost p[x]
    //   anything else  : impossible
    matrix generate_3(int x)
    {
        matrix m3;
        m3._init(3, 3);
        for (int from = 0; from < 3; ++from)
        {
            m3.c[from][0] = w[x];
            m3.c[from][1] = INF;
            m3.c[from][2] = INF;
        }
        m3.c[0][1] = 0;
        m3.c[1][1] = p[x];
        m3.c[1][2] = 0;
        return m3;
    }
    // For every vertex v in the subtree of x, lower p[v] to the lightest weight
    // among v's children. fa is x's parent (0 for the root).
    // Uses an explicit stack: recursion depth would equal the tree height, and
    // on a path of ~2e5 vertices that can exhaust the call stack.
    void dfs(int x, int fa)
    {
        std::vector<std::pair<int, int> > stk;   // (vertex, parent)
        stk.push_back(std::make_pair(x, fa));
        while (!stk.empty())
        {
            const int u = stk.back().first;
            const int par = stk.back().second;
            stk.pop_back();
            for (int i = head[u]; i; i = nxt[i])
            {
                const int y = ver[i];
                if (y == par) continue;
                p[u] = std::min(p[u], w[y]);
                stk.push_back(std::make_pair(y, u));
            }
        }
    }
    // Breadth-first walk from root 1. It fills vis, depth, the prefix sums S,
    // the ancestor table f and the lifted products base_2 / base_3.
    // A vertex's rows are built when it is discovered, after all of its
    // ancestors, so every lifted entry it reads is already final.
    void bfs()
    {
        std::vector<int> order(1, 1);   // FIFO queue; order[pos..] not yet expanded
        vis[1] = true;
        depth[1] = 1;
        S[1] = w[1];
        base_2[1][0] = generate_2(1);
        base_3[1][0] = generate_3(1);

        for (size_t pos = 0; pos < order.size(); ++pos)
        {
            const int u = order[pos];
            for (int eid = head[u]; eid; eid = nxt[eid])
            {
                const int v = ver[eid];
                if (vis[v]) continue;
                vis[v] = true;
                order.push_back(v);

                f[v][0] = u;
                depth[v] = depth[u] + 1;
                S[v] = S[u] + w[v];
                base_2[v][0] = generate_2(v);
                base_3[v][0] = generate_3(v);
                for (int j = 1; j <= 20; ++j)
                {
                    const int mid = f[v][j - 1];
                    f[v][j] = f[mid][j - 1];
                    base_2[v][j] = base_2[v][j - 1] * base_2[mid][j - 1];
                    base_3[v][j] = base_3[v][j - 1] * base_3[mid][j - 1];
                }
            }
        }
    }
    // Lowest common ancestor using the 21-level ancestor table f.
    int lca(int x, int y)
    {
        if (depth[y] < depth[x]) std::swap(x, y);   // y is now the deeper vertex
        for (int i = 20; i >= 0; --i)
        {
            const int up = f[y][i];
            if (depth[up] >= depth[x]) y = up;
        }
        if (x == y) return x;
        for (int i = 20; i >= 0; --i)
        {
            if (f[x][i] == f[y][i]) continue;
            x = f[x][i];
            y = f[y][i];
        }
        return f[x][0];
    }
    // Product of the per-vertex transition matrices from x up to y, both
    // inclusive, in bottom-up order. h == 2 selects base_2; any other value
    // selects base_3. Intended for y being x or an ancestor of x.
    matrix get(int x, int y, int h)
    {
        matrix (*table)[21] = (h == 2) ? base_2 : base_3;
        matrix acc;
        bool started = false;

        // Greedily take the largest lifted segments that stay at or below y.
        for (int i = 20; i >= 0; --i)
        {
            if (depth[f[x][i]] < depth[y]) continue;
            acc = started ? acc * table[x][i] : table[x][i];
            started = true;
            x = f[x][i];
        }

        // x has reached y; append y's own matrix.
        return started ? acc * table[x][0] : table[x][0];
    }
    int main(){
    	ios::sync_with_stdio(false);
    	cin.tie(0),cout.tie(0);
    	cin>>n>>T>>k;
    	for(int i=1;i<=n;i++){
    		cin>>w[i];
    		p[i]=INF;
    	}
    	int a,b;
    	for(int i=1;i<n;i++){
    		cin>>a>>b;
    		add(a,b),add(b,a);
    	}
    	dfs(1,0);
    	bfs();
    //	cout<<"zzzz"<<endl;
    //	cout<<base_2[2][1].c[0][0]<<" "<<base_2[2][1].c[0][1]<<endl;
    //	cout<<base_2[2][1].c[1][0]<<" "<<base_2[2][1].c[1][1]<<endl;
    //	cout<<"zzzz"<<endl;
    	while(T--){
    		int s,t;
    		cin>>s>>t;
    		int l=lca(s,t);
    		if(k==1){
    			cout<<S[s]+S[t]-2ll*S[l]+w[l]<<endl;
    		}else if(k==2){
    			matrix ox,oy;
    			ox._init(1,2);
    			ox.c[0][0]=w[s],ox.c[0][1]=INF;
    			if(s!=l) ox=ox*get(f[s][0],l,2);
    			oy._init(1,2);
    			oy.c[0][0]=w[t],oy.c[0][1]=INF;
    			if(t!=l) oy=oy*get(f[t][0],l,2);
    			LL res=INF;
    			res=min(res,ox.c[0][0]+oy.c[0][0]-w[l]);
    			res=min(res,ox.c[0][0]+oy.c[0][1]);
    			res=min(res,ox.c[0][1]+oy.c[0][0]);
    			res=min(res,ox.c[0][1]+oy.c[0][1]);
    			cout<<res<<endl;
    		}else{
    			LL pb=p[l];
    			if(l!=1) p[l]=min(p[l],w[f[l][0]]);
    			base_3[l][0]=generate_3(l);
    			matrix ox,oy;
    			ox._init(1,3),oy._init(1,3);
    			ox.c[0][0]=w[s],ox.c[0][1]=INF,ox.c[0][2]=INF;
    			oy.c[0][0]=w[t],oy.c[0][1]=INF,oy.c[0][2]=INF;
    			
    			
    			if(s!=l) ox=ox*get(f[s][0],l,3);
    			if(t!=l) oy=oy*get(f[t][0],l,3);
    			LL res=INF;
    			res=min(res,ox.c[0][0]+oy.c[0][0]-w[l]);
    			res=min(res,ox.c[0][0]+oy.c[0][1]);
    			res=min(res,ox.c[0][0]+oy.c[0][2]);
    			res=min(res,ox.c[0][1]+oy.c[0][0]);
    			res=min(res,ox.c[0][1]+oy.c[0][1]);
    			res=min(res,ox.c[0][1]+oy.c[0][2]);
    			res=min(res,ox.c[0][2]+oy.c[0][0]);
    			res=min(res,ox.c[0][2]+oy.c[0][1]);
    			cout<<res<<endl;
    			p[l]=pb;
    			base_3[l][0]=generate_3(l);
    		}
    	}
    	return 0;
    }
    /*
    5 100 2
    1 2 3 4 5
    1 2
    2 3
    1 4
    4 5
    3 4
    */
    
    • 1