NOIP 2024 T3 笔记
为方便记述,认为
由于作者是笨比,所以这是不拆脚手架的笔记。
一个自尊的建筑师不会在盖好的房子里留下脚手架。——高斯
1.
首先,可以推得同一节点周围所有相邻边在新树上一定是一条连续的链,因为假设遍历到了边
再容易发现,一个节点所形成的新树上的链形态如何,不影响其他链的形态,因为每个点相邻的边中,第一条被遍历的边总不变,即新树上链的起点总不变,故可以独立考虑。对于一个度数为
2.
容易发现,一棵同样的新树能够被原树上某条链上的任意一条边为起始边构建。
- 证明:对于一个节点的所有相邻边,只有第一个被遍历的和最后一个被遍历的有可能在不改变新树形态的前提下,成为该点相邻边中第一条被遍历的边。显然这样的边对于每个节点最多只有两条且连续,最终在原树上呈链形态。
下文中原树上的链均指从叶子到叶子的链。
观察钦定一条链后生成多少树?由上面的证明可以反推,原树上的链确定了,链上每个点相邻的边中,第一个和最后一个被遍历的边就确定了,套用
等价于
分子固定,可以
乘法分配律
求和号及以后这坨东西即可,令这坨东西为
具体地,设
的结果,
初始时,对于每个叶子节点
用
对于
当
是关键边, 不转移。 。 。
当
不是关键边: 。 。 。
遍历完子节点后,根据状态定义,
小细节:树形 DP 显然不能把叶子当作 DP 的根节点,否则难以统计,故还要特判
然后就做完了。
这题真是太厉害了/kel。
#include <bits/stdc++.h>
#define int long long
#define IOS ios::sync_with_stdio(0), cin.tie(0), cout.tie(0)
using namespace std;
const int N = 1e5 + 10, mod = 1e9 + 7;
int c, T;
int n, k, dg[N], e[N], f[N][2], ans;
struct point {
int v, id;
};
vector<point> g[N];
int pow(int x, int y) {
int ret = 1;
for (; y; y >>= 1, x = x * x % mod)
if (y & 1)
ret = ret * x % mod;
return ret;
}
int inv(int x) {
if (x <= 0)
return 1;
return pow(x, mod - 2);
}
void dfs(int u, int fa) {
int ret = 0;
for (auto [v, id] : g[u]) {
if (v == fa)
continue;
dfs(v, u);
if (e[id]) {
ret = (ret + (f[u][0] + f[u][1]) * (f[v][0] + f[v][1]) % mod) % mod;
f[u][1] = (f[u][1] + f[v][0] + f[v][1]) % mod;
} else {
ret = (ret + f[u][0] * f[v][1] % mod + f[u][1] * (f[v][0] + f[v][1]) % mod) % mod;
f[u][0] = (f[u][0] + f[v][0]) % mod;
f[u][1] = (f[u][1] + f[v][1]) % mod;
}
}
ans = (ans + ret * inv(dg[u] - 1) % mod) % mod;
if (dg[u] == 1)
f[u][0] = 1;
f[u][0] = f[u][0] * inv(dg[u] - 1) % mod;
f[u][1] = f[u][1] * inv(dg[u] - 1) % mod;
}
void clear() {
ans = 0;
for (int i = 1; i <= n; i++) {
g[i].clear();
dg[i] = e[i] = f[i][0] = f[i][1] = 0;
}
}
void solve() {
cin >> n >> k;
for (int i = 1, u, v; i < n; i++) {
cin >> u >> v;
g[u].push_back({v, i});
g[v].push_back({u, i});
dg[u]++, dg[v]++;
}
for (int i = 1, x; i <= k; i++) {
cin >> x;
e[x] = 1;
}
if (n == 2) {
cout << "1\n";
clear();
return;
}
for (int i = 1; i <= n; i++)
if (dg[i] > 1) {
dfs(i, 0);
break;
}
for (int i = 1; i <= n; i++)
for (int j = 1; j < dg[i]; j++)
ans = ans * j % mod;
cout << ans << "\n";
clear();
}
signed main() {
IOS;
cin >> c >> T;
while (T--)
solve();
return 0;
}