T6 保卫王国

倍增 + 树形dp

solution

题意简化:

给出 mm 次询问,每次强制选 / 不选 两个点,求满足此条件下树上最小权独立集

关于最小权独立集以及其求法这里不赘述,右转这里

part1 暴力部分

首先可以想到强制选点和不选点可以将点权设置为负无穷和正无穷实现,最后统计时再减去影响加上原来权值即可

每次询问重新 dp 一次,求出答案

时间复杂度:O(nm)O(nm)

此做法可用动态 dp 优化,但是作者太菜了不会qwq

part2 正解

考虑如下状态 dpu,v,q1,q2dp_{u , v , q1 , q2} 表示点 u,vu , v 分别选 / 不选的树上最小权独立集大小,但是这样显然会超。考虑倍增优化

首先求出 fu,q,q{0,1}f_{u , q} , q \in \{0,1\} 表示以 uu 为根子树内的最小代价,并且 uu 选 / 不选

显然有如下转移

$$\begin{split} f_{u , 0} &= \sum_{v \in subtree(u)} {f_{v,1}} \\ f_{u , 1} &= \sum_{v \in subtree(u)} \min({f_{v,0}} , f_{v,1}) + a_u \end{split}$$

随后设状态 dpu,k,q1,q2dp_{u , k , q1 , q2} , q1,q2{0,1}q1 , q2 \in \{0,1\} 表示uu2k2^k 级祖先 (设其为 ancanc) 为根子树内的最小代价,并且 uu 选 / 不选 , ancanc 选 / 不选

借鉴求 lca 时的思路,我们可以枚举中间点的状态来进行转移 (设 fafauu2k12^{k-1} 级祖先)

$$dp_{u , k , q1 , q2} = \min \begin{cases} dp_{i,k-1,q1,0} + dp_{fa,k-1,0,q2} - f_{fa,0} \\ dp_{i,k-1,q1,1} + dp_{fa,k-1,1,q2} - f_{fa,1} \end{cases}$$

解释一下式子,因为 fafauu 之间无影响,所以去重可以直接减去

特别注意 dp 的初始化,考虑 ff 数组是如何统计贡献 / 转移的

随后思路就清楚了,对于每个询问我们可以根据 dp 来处理答案,具体地设询问为 (u,a,v,b)(u , a , v , b)

首先考虑跳 (u,v)(u , v),有两种情况

  • uulcalca , 此时直接将 vv 跳到 uu,边条边合并贡献vvlcalca 同理
  • 如果不是,将 (u,v)(u , v) 跳到 lcalca 儿子处,再在 lcalca 处合并贡献

然后直接将 lcalca 跳到根节点处即可(因为需要统计整棵树)

考虑如何合并贡献,合并有两种

  • 向上跳的合并,设当前节点为 uu,此时直接考虑枚举 u 的状态与当前要跳到的节点的状态 (设其为 ancanc),并将状态对应的子树答案记录,设 reskres_{k} 表示 uu 节点选 / 不选,以 uu 节点为根子树答案。tmpktmp_{k} 表示 ancanc 节点选 / 不选,以 ancanc 节点为根子树答案。k{0,1}k \in \{0,1\},跳之后更新 resres 值为 tmptmp,合并时更新 tmptmp 值,有
$$\begin{split} tmp_1 &= \min(dp_{u , k , 0 , 1} + res_2 - f_{u , 0} , dp_{u , k , 1 , 1} + res_1 - f_{u , 1}) \\ tmp_2 &= \min(dp_{u , k , 0 , 0} + res_2 - f_{u , 0} , dp_{u , k , 1 , 0} + res_1 - f_{u , 1}) \end{split}$$
  • lcalca 处的合并,此时用 ff 数组直接更新 res1,res2res_1 , res_2 的值即可,考虑 ff 数组是如何统计贡献,减去原来的贡献,替换上新的贡献(具体看代码)

时间复杂度:O((n+q)logn)O((n + q)\log n)

代码常数较大,但是可以通过(用矩阵转移似乎可以更优?)

code

#include<bits/stdc++.h>
#define ll long long
#define qwq ios::sync_with_stdio(false);cin.tie(0);cout.tie(0);
using namespace std;

const int N = 1e5 + 5;
const ll inf = 1e18;

string type;
int n , m , w[N];
ll f[N][2] , dp[N][20][2][2] , fa[N][20] , dis[N] , lg[N];

vector<int> g[N];

void dfs1(int u , int fh){
	f[u][1] = w[u];
	
	for(auto v : g[u]){
		if(v == fh) continue;
		dfs1(v , u);
		
		f[u][0] += f[v][1];
		f[u][1] += min(f[v][0] , f[v][1]);
	}
}

void dfs2(int u , int fh){
	fa[u][0] = fh , dis[u] = dis[fh] + 1;
	
	dp[u][0][0][0] = inf;
	dp[u][0][0][1] = f[fh][1] - min(f[u][0] , f[u][1]) + f[u][0];
	dp[u][0][1][0] = f[fh][0];
	dp[u][0][1][1] = f[fh][1] - min(f[u][0] , f[u][1]) + f[u][1];
	
	for(int k = 1 ; k <= lg[dis[u]] ; k ++){
		fa[u][k] = fa[fa[u][k - 1]][k - 1];
		
		for(int a = 0 ; a < 2 ; a ++){
			for(int b = 0 ; b < 2 ; b ++){
				int r = fa[u][k - 1];
				dp[u][k][a][b] = min(dp[u][k - 1][a][0] + dp[r][k - 1][0][b] - f[r][0] , 
								     dp[u][k - 1][a][1] + dp[r][k - 1][1][b] - f[r][1]);
			}
		}
	}
	for(auto v : g[u]){
		if(v != fh) dfs2(v , u);
	}
}

ll query(int x , int a , int y , int b){
	if(dis[x] < dis[y]) swap(x , y) , swap(a , b);
	
	ll res[5] = {0 , 0 , 0 , 0 , 0} , tmp[5] = {0 , 0 , 0 , 0 , 0};
	res[1] = (a ? f[x][1] : inf);
	res[2] = (a ? inf : f[x][0]);
	res[3] = (b ? f[y][1] : inf);
	res[4] = (b ? inf : f[y][0]);
	
	for(int k = lg[dis[x]] ; k >= 0 ; k --){
		if(dis[fa[x][k]] >= dis[y]){
			tmp[1] = min(dp[x][k][0][1] + res[2] - f[x][0] , dp[x][k][1][1] + res[1] - f[x][1]);
			tmp[2] = min(dp[x][k][0][0] + res[2] - f[x][0] , dp[x][k][1][0] + res[1] - f[x][1]);
			res[1] = tmp[1] , res[2] = tmp[2] , x = fa[x][k];
		}
	}
	if(x == y){
		if(b) res[2] = inf;
		else res[1] = inf;
	}
	else{
		for(int k = lg[dis[x]] ; k >= 0 ; k --){
			if(fa[x][k] != fa[y][k]){
				tmp[1] = min(dp[x][k][0][1] + res[2] - f[x][0] , dp[x][k][1][1] + res[1] - f[x][1]);
				tmp[2] = min(dp[x][k][0][0] + res[2] - f[x][0] , dp[x][k][1][0] + res[1] - f[x][1]);
				
				tmp[3] = min(dp[y][k][0][1] + res[4] - f[y][0] , dp[y][k][1][1] + res[3] - f[y][1]);
				tmp[4] = min(dp[y][k][0][0] + res[4] - f[y][0] , dp[y][k][1][0] + res[3] - f[y][1]);
				
				res[1] = tmp[1] , res[2] = tmp[2] , res[3] = tmp[3] , res[4] = tmp[4];
				x = fa[x][k] , y = fa[y][k];
			}
		}
		int r = fa[x][0];
		tmp[1] = f[r][1] - min(f[x][0] , f[x][1]) - min(f[y][0] , f[y][1]) + min(res[1] , res[2]) + min(res[3] , res[4]);
		tmp[2] = f[r][0] - f[x][1] - f[y][1] + res[1] + res[3];
		res[1] = tmp[1] , res[2] = tmp[2] , x = r;
	}
	
	for(int k = lg[dis[x]] ; k >= 0 ; k --){
		if(dis[fa[x][k]] >= dis[1]){
			tmp[1] = min(dp[x][k][0][1] + res[2] - f[x][0] , dp[x][k][1][1] + res[1] - f[x][1]);
			tmp[2] = min(dp[x][k][0][0] + res[2] - f[x][0] , dp[x][k][1][0] + res[1] - f[x][1]);
			res[1] = tmp[1] , res[2] = tmp[2] , x = fa[x][k];
		}
	}
	return min(res[1] , res[2]);
}

int main(){               qwq
	cin >> n >> m >> type;
	
	for(int i = 1 ; i <= n ; i ++){
		cin >> w[i];
	}
	for(int i = 1 ; i < n ; i ++){
		int u , v;
		cin >> u >> v;
		g[u].push_back(v);
		g[v].push_back(u);
	}
	
	for(int i = 1 ; i <= n ; i ++){
		lg[i] = lg[i >> 1] + 1;
	}
	
	dfs1(1 , 0);
	dfs2(1 , 0);
	
	for(int i = 1 ; i <= m ; i ++){
		int x , a , y , b;
		cin >> x >> a >> y >> b;
		
		ll tmp = query(x , a , y , b);
		cout << (tmp >= inf ? -1 : tmp) << '\n';
	}
	return 0;
}