致力于写一篇通俗易懂的题解

前置芝士:树的直径,计数原理

solution

题目要我们求出所有的点集的数量,满足点集中任意两点的距离都等于这棵树的直径

由题,可以先求出树的直径 DD,用两次 dfsdfs 即可,随后要分类讨论两种情况:

  1. DD 为偶数的情况,此时我们需要找到树的中心节点(即为直径路径上为中点的点),可以证明,这里的中心节点是唯一的(证明过程详见OI Wiki)

    找到中心节点后,就可以令这个点为整颗树的顶点,将树转化为有根树,同时求出顶点到其余点的距离

    注意到,只有离顶点距离为 D/2D / 2 的点才可能在点集中,可以统计根节点的每颗子树的距离为 D/2D / 2 的点的数量,然后将每颗子树的节点数量11 后相乘即可(根据乘法原理,每颗子树可以选第 11 个离顶点距离为 D/2D / 2 的点,第二个... 和不选节点)。注意:点集中节点数不能为 11 或为空,因此我们要将答案减去一再减去所有距离为 D/2D / 2 的点的数量,即:

    ans=j=1s(ki+1)cnt1ans = \prod_{j=1}^s(k_i + 1) - cnt - 1

    ss 为根节点子树数量,cntcnt 为所有距离为 D/2D / 2 的点的数量

    还不懂?看图:

  2. DD 为奇数的情况,这时大体与上面一样,不过我们要找到中心边(即为直径路径上为中间的边),这样的边也是唯一的,随后我们建一个虚点,虚点连边中心边的两个端点

    我们将这个虚点作为树的根节点,随后按照上文的步骤操作即可,注意此时的树是一颗基环树,需要特殊处理(见代码)

    图:

时间复杂度: O(n+m)O(n + m)

code

#include<bits/stdc++.h>
#define int long long
using namespace std;

const int N = 2e5 + 5;
const int mod = 998244353;

int n , d[N] , dis[N] , c , D , si[N] , vis[N];
vector<int> g[N];

void dfs(int x , int f){
	for(auto u : g[x]){
		if(u == f) continue;
		d[u] = d[x] + 1;
		if(d[u] > d[c]) c = u;
		dfs(u , x);
	}
}

void bfs(int x){
	memset(vis , 0 , sizeof vis);
	queue<int> q;
	q.push(x);
	vis[x] = 1 , dis[x] = 0;
	
	while(!q.empty()){
		int u = q.front();
		q.pop();
		for(auto v : g[u]){
			if(vis[v]) continue;
			vis[v] = 1;
			dis[v] = dis[u] + 1;
			q.push(v);
		}
	}
}

void dfs2(int x , int f){
	if(dis[x] == D / 2) si[x] = 1;
	
	for(auto u : g[x]){
		if(u == f) continue;
		dfs2(u , x);
		si[x] += si[u];
	}
}

void solve(){
	if(D % 2 == 0){
		bfs(c);
		int p = 0 , cnt = 0 , ans = 1;
		
		for(int i = 1 ; i <= n ; i ++){
			if(d[i] == dis[i] && dis[i] == D / 2) p = i;
		}
		bfs(p); dfs2(p , 0);
		
		for(int i = 0 ; i < g[p].size() ; i ++){
			int u = g[p][i];
			ans = (ans * (si[u] + 1)) % mod;
			cnt += si[u];
		}
		cout << (ans - 1 - cnt + mod) % mod;
	}
	else{
		bfs(c);
		int p1 = 0 , p2 = 0 , ans = 1;
		
		for(int i = 1 ; i <= n ; i ++){
			if(d[i] == D / 2 && dis[i] == (D + 1) / 2) p1 = i;
			if(d[i] == (D + 1) / 2 && dis[i] == D / 2) p2 = i;
		}
		g[n + 1].push_back(p1) , g[n + 1].push_back(p2);
		D += 1;
		bfs(n + 1); dfs2(p1 , 0);
		ans = ((si[p1] - si[p2] + 1) * (si[p2] + 1)) % mod;
		
		cout << (ans - 1 - si[p1] + mod) % mod;
	}
}

signed main(){
	freopen("set.in" , "r" , stdin);
	freopen("set.out" , "w" , stdout);
	
	cin >> n;
	
	for(int i = 1 ; i < n ; i ++){
		int x , y;
		cin >> x >> y;
		g[x].push_back(y);
		g[y].push_back(x);
	}
	dfs(1 , 0); d[c] = 0;
	dfs(c , 0);
	
	D = d[c];
	solve();
	return 0;
}

希望这篇题解可以帮助到你 qwqqwq

1 条评论

  • @ 2025-7-22 21:12:57

    可以的!

    • 1