Skip to content

NOIP 2024 T3 笔记

为方便记述,认为 x0,x!=1,x1=1

由于作者是笨比,所以这是不拆脚手架的笔记。

一个自尊的建筑师不会在盖好的房子里留下脚手架。——高斯

1. k=1

首先,可以推得同一节点周围所有相邻边在新树上一定是一条连续的链,因为假设遍历到了边 (u,v),要么就直接遍历完 u 的所有相邻边,此时在新树上显然是链;要么就往后遍历与 v 相邻的边以及后面的边,由于树的结构特性,最后一定会回溯到 (u,v),此时只能继续遍历与 u 相邻的边,最终仍会形成链。

再容易发现,一个节点所形成的新树上的链形态如何,不影响其他链的形态,因为每个点相邻的边中,第一条被遍历的边总不变,即新树上链的起点总不变,故可以独立考虑。对于一个度数为 du 的节点,链的形态数就是一个排列数,但第一条边已经固定,所以是 (du1)!。最后显然是乘法原理,所以答案为

(di1)!

2. k1

容易发现,一棵同样的新树能够被原树上某条链上的任意一条边为起始边构建。

  • 证明:对于一个节点的所有相邻边,只有第一个被遍历的和最后一个被遍历的有可能在不改变新树形态的前提下,成为该点相邻边中第一条被遍历的边。显然这样的边对于每个节点最多只有两条且连续,最终在原树上呈链形态。

下文中原树上的链均指从叶子到叶子的链。

观察钦定一条链后生成多少树?由上面的证明可以反推,原树上的链确定了,链上每个点相邻的边中,第一个和最后一个被遍历的边就确定了,套用 k=1 的结论,设 i 为钦定链上的所有点,j 为其他点,答案为

(di2)!×(dj1)!

等价于

(di1)!×(dj1)!(di1)

分子固定,可以 O(n) 处理出来,假设是 all,每条原树上的链都不会生成重复的树,一条链若要合法,它至少应该包含一条关键边,故总答案为

all×i 是链上的点(di1)1

乘法分配律

all×i 是链上的点(di1)1

求和号及以后这坨东西即可,令这坨东西为 ans。将一条链拆成从“转折点”出发的两条路径,然后显然可以同理用乘法分配律拆,树形 DP 维护。对于一个节点 u,将它的每个子节点 vv 子树内所有叶子的结果和,从左到右相加合并至 u,并统计以 u 为“转折点”的链的答案。再乘上 (du1)1

具体地,设 fu,0 表示从 uu 子树内的所有叶子的路径中,不经过 u 子树内任何一条关键边时,

i 是路径上的点(di1)1

的结果,fu,1 的定义与 fu,0 相反,即经过任意一条关键边。

初始时,对于每个叶子节点 ufu,01

ret 统计已经合并到 u 的子树的路径,经过点 u 走向 v 子树内的路径,且至少包含一条关键边,的答案。

对于 u 的每个子节点 v

  • (u,v) 是关键边,fu,0 不转移。

    • fu,1fu,1+fv,0+fv,1

    • retret+(fu,0+fu,1)×(fv,0+fv,1)

  • (u,v) 不是关键边:

    • fu,0fu,0+fv,0

    • fu,1fu,1+fv,1

    • retret+fu,0×fv,1+fu,1×(fv,0+fv,1)

遍历完子节点后,根据状态定义,fu,0,fu,1,ret 均乘上 (du1)1ansans+ret 即可。

小细节:树形 DP 显然不能把叶子当作 DP 的根节点,否则难以统计,故还要特判 n=2 的时候答案为 1

然后就做完了。

这题真是太厉害了/kel。

cpp
#include <bits/stdc++.h>
#define int long long
#define IOS ios::sync_with_stdio(0), cin.tie(0), cout.tie(0)
using namespace std;
const int N = 1e5 + 10, mod = 1e9 + 7;
int c, T;
int n, k, dg[N], e[N], f[N][2], ans;
struct point {
    int v, id;
};
vector<point> g[N];
int pow(int x, int y) {
    int ret = 1;
    for (; y; y >>= 1, x = x * x % mod)
        if (y & 1)
            ret = ret * x % mod;
    return ret;
}
int inv(int x) {
    if (x <= 0)
        return 1;
    return pow(x, mod - 2);
}
void dfs(int u, int fa) {
    int ret = 0;
    for (auto [v, id] : g[u]) {
        if (v == fa)
            continue;
        dfs(v, u);
        if (e[id]) {
            ret = (ret + (f[u][0] + f[u][1]) * (f[v][0] + f[v][1]) % mod) % mod;
            f[u][1] = (f[u][1] + f[v][0] + f[v][1]) % mod;
        } else {
            ret = (ret + f[u][0] * f[v][1] % mod + f[u][1] * (f[v][0] + f[v][1]) % mod) % mod;
            f[u][0] = (f[u][0] + f[v][0]) % mod;
            f[u][1] = (f[u][1] + f[v][1]) % mod;
        }
    }
    ans = (ans + ret * inv(dg[u] - 1) % mod) % mod;
    if (dg[u] == 1)
        f[u][0] = 1;
    f[u][0] = f[u][0] * inv(dg[u] - 1) % mod;
    f[u][1] = f[u][1] * inv(dg[u] - 1) % mod;
}
void clear() {
    ans = 0;
    for (int i = 1; i <= n; i++) {
        g[i].clear();
        dg[i] = e[i] = f[i][0] = f[i][1] = 0;
    }
}
void solve() {
    cin >> n >> k;
    for (int i = 1, u, v; i < n; i++) {
        cin >> u >> v;
        g[u].push_back({v, i});
        g[v].push_back({u, i});
        dg[u]++, dg[v]++;
    }
    for (int i = 1, x; i <= k; i++) {
        cin >> x;
        e[x] = 1;
    }
    if (n == 2) {
        cout << "1\n";
        clear();
        return;
    }
    for (int i = 1; i <= n; i++)
        if (dg[i] > 1) {
            dfs(i, 0);
            break;
        }
    for (int i = 1; i <= n; i++)
        for (int j = 1; j < dg[i]; j++)
            ans = ans * j % mod;
    cout << ans << "\n";
    clear();
}
signed main() {
    IOS;
    cin >> c >> T;
    while (T--)
        solve();
    return 0;
}
最近更新