T1 奇怪的魔法

solution

我们要对 nn 个集合维护两种操作:合并与查询

带权并查集来维护集合内点的个数 ss 和某个点到根节点的距离 dd

  1. find 函数:在找父节点的同时压缩路径,更新当前节点 dx=dx+dfaxd_x = d_x + d_{fa_x}
  2. 合并:将根节点合并,更新节点。具体地,设集合的根节点分别为 uvu,vfau=vfa_u = vdu=svd_u = s_vsv=sv+sus_v = s_v + s_u
  3. 查询:节点是否在同一集合中,若在则输出距离

code

#include<bits/stdc++.h>
#define qwq ios::sync_with_stdio(false);cin.tie(0);cout.tie(0);
using namespace std;

const int N = 3e4 + 5;

int fa[N] , s[N] , d[N];

int t , x , y;
char op;

int find(int x){
	if(x != fa[x]){
		int t = fa[x];
		fa[x] = find(fa[x]);
		d[x] += d[t] , s[x] = s[fa[x]];
	}
	return fa[x];
}

void uni(int x , int y){
	int u = find(x) , v = find(y);
	if(u != v){
		fa[u] = v;
		d[u] = d[v] + s[v];
		s[u] += s[v] , s[v] = s[u];
	}
}

int main(){                qwq
	cin >> t;
	
	for(int i = 1 ; i <= N - 5 ; i ++){
		fa[i] = i , s[i] = 1;
	}
	
	while(t --){
		cin >> op >> x >> y;
		if(op == 'M'){
			uni(x , y);
		}
		else{
			int u = find(x) , v = find(y);
			if(u == v) cout << abs(d[x] - d[y]) - 1 << '\n';
			else cout << -1 << '\n';
		}
	}
	return 0;
}

T2 不安分的魔法学院

solution

我们用扩展域并查集来维护每个法师的信息

设某一个法师编号为 pp,我们构造两个虚点

  1. p+np + n 表示 pp 克制的法师编号
  2. p+2np + 2n 表示克制 pp 的法师编号

依据题意模拟,具体地:

  1. d=1d = 1,如果 xx 是克制 yy 的,或者 xx 是被 yy 克制的(x=y+2nx = y + 2nx=y+nx = y + n)则为假话,否则合并
  2. d=2d = 2,如果 xxyy ,或者 xx 是被 yy 克制的(x=yx = yx=y+nx = y + n)则为假话。否则合并,注意这里合并时要合并 xxy+2ny + 2nx+nx + nyyx+2nx + 2ny+ny + n (可以想想实际含义是什么)
  3. 若为假话统计 ansans

code

#include<bits/stdc++.h>
using namespace std;

const int N = 2e5 + 5;

int n , k , ans;
int fa[N];

int find(int x){
	if(x == fa[x]) return x;
	return fa[x] = find(fa[x]);
}

void uni(int x , int y){
	int u = find(x) , v = find(y);
	if(u != v) fa[u] = v;
}

int main(){
	cin >> n >> k;
	
	for(int i = 1 ; i <= n * 3 ; i ++){
		fa[i] = i;
	}
	for(int i = 1 ; i <= k ; i ++){
		int d , x , y;
		cin >> d >> x >> y;
		
		if(x > n || y > n){
			ans ++;
			continue;
		}
		if(d == 1){
			if(find(x) == find(y + n) || find(x) == find(y + 2 * n)) ans ++;
			else{
				uni(x , y) , uni(x + n , y + n) , uni(x + 2 * n , y + 2 * n);
			}
		}
		if(d == 2){
			if(find(x) == find(y) || find(x) == find(y + n)) ans ++;
			else{
				uni(x , y + 2 * n) , uni(x + n , y) , uni(x + 2 * n , y + n);
			}
		}
	}
	cout << ans;
	return 0;
}

T3 最短路径的数量

solution

考虑在 bfs 中 dp,用 bfs 求边权为 11 的图的最短路

dpxdp_x11 ~ xx 的最短路数量disxdis_x11 ~ xx 的最短路长度

显然,只有在第一次遍历到节点 xx 或当前是最短路时(disx=disu+1dis_x = dis_u + 1)才能更新 dpxdp_x,状态转移方程为:

dpx=dpx+dpudp_x = dp_x + dp_u

注意每个点只能入队一次,但可以被遍历多次,所以当第一次遍历到节点 xx 时入队

code

#include<bits/stdc++.h>
using namespace std;

const int N = 2e5 + 5;
const int mod = 1e9 + 7;

int n , m , dp[N] , dis[N];
vector<int> g[N];

void bfs(int s){
	queue<int> q;
	q.push(s);
	
	while(!q.empty()){
		int u = q.front();
		q.pop();
		for(int i = 0 ; i < g[u].size() ; i ++){
			int v = g[u][i];
			
			if(dis[v] == 0 || dis[v] == dis[u] + 1){
				if(!dis[v]){
					dis[v] = dis[u] + 1;
					q.push(v);
				}
				dp[v] = (dp[v] + dp[u]) % mod;
			}
		}
	}
}

int main(){
	cin >> n >> m;
	
	for(int i = 1 ; i <= m ; i ++){
		int x , y;
		cin >> x >> y;
		g[x].push_back(y);
		g[y].push_back(x);
	}
	dis[1] = 1;
	dp[1] = 1;
	bfs(1);
	
	cout << dp[n] % mod;
	return 0;
}

T5 非递减彩色路径

solution

看数据范围 N,M2×105N,M≤2×10^5 只能用 O(nlogn)O(n\log n) 的算法实现,首先考虑 dp ,但是 dp 要求只能在 DAG 上实现,所以要先把图转化为 DAG

DAG转化宝典:

  1. 缩点,在此题中可以把权值相同且相邻的点给缩为一个点,用并查集实现
  2. 去边,观察性质可知,在图中若一条有向边连接了一个权值大的点和一个权值小的点,这样的边肯定不会在路径中,对答案无影响,因此可以去掉

随后将并查集合并后的点按照权值排序,设 dpxdp_x11 ~ xx 的最长非递减彩色路径的长度,状态转移方程为:

dpx=max(dpx,dpu+1)dp_x = max(dp_x , dp_u + 1)

按照权值排序从小到大进行转移即可

时间复杂度:O(m+nlogn)O(m + n\log n)

code

#include<bits/stdc++.h>
#define qwq ios::sync_with_stdio(false);cin.tie(0);cout.tie(0);
using namespace std;

const int N = 2e5 + 5;

struct edge{
	int u , v;
}e[N];

int n , m;
int a[N] , dp[N] , fa[N];
vector<int> g[N] , d;

int find(int x){
	if(x == fa[x]) return x;
	return fa[x] = find(fa[x]);
}

void uni(int x , int y){
	int u = find(x) , v = find(y);
	if(u != v) fa[u] = v;
}

bool cmp(const int &lhs , const int &rhs){
	return a[lhs] < a[rhs];
}

int main(){              qwq
	cin >> n >> m;
	
	for(int i = 1 ; i <= n ; i ++){
		cin >> a[i];
		fa[i] = i;
	}
	for(int i = 1 ; i <= m ; i ++){
		int x , y;
		cin >> x >> y;
		e[i] = {x , y};
		if(a[x] == a[y]) uni(x , y);
	}
	
	for(int i = 1 ; i <= m ; i ++){
		int u = find(e[i].u) , v = find(e[i].v);
		if(u == v) continue;
		
		if(a[u] > a[v]) g[v].push_back(u);
		if(a[u] < a[v]) g[u].push_back(v);
	}
	for(int i = 1 ; i <= n ; i ++){
		if(i == find(i)) d.push_back(i);
	}
	sort(d.begin() , d.end() , cmp);
	memset(dp , -1 , sizeof dp);
	dp[find(1)] = 1;
	
	for(auto u : d){
		if(dp[u] == -1) continue;
		
		for(auto v : g[u]){
			if(a[u] == a[v]) dp[v] = max(dp[u] , dp[v]);
			if(a[u] < a[v]) dp[v] = max(dp[v] , dp[u] + 1);
		}
	}
	cout << max(dp[find(n)] , 0);
	return 0;
}