洛谷P3914 - 染色计数

 OI / 洛谷
被浏览

CSP-S 前要多做树上问题 (主要是因为去年 NOIP D1T3、D2T1、D2T3 全是树上问题,然后爆炸了)

这题一看就是树形DP,状态也是一眼就能想到的。

fi,jf_{i,j} 表示 ii 这个点染成 jj 的方案数,枚举 sonison_i 的颜色来转移。

最后答案即为 i=1mf1,i\sum_{i=1}^{m}f_{1,i}

核心代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
void dfs(int x, int fa) {
for (int i = head[x]; i; i = edge[i].nxt) {
int y = edge[i].to;
if (y == fa) continue;
dfs(y, x);
}
for (int c : col[x]) {
for (int i = head[x]; i; i = edge[i].nxt) {
int y = edge[i].to;
if (y == fa) continue;
int sum = 0;
for (int j : col[y])
if (j ^ c) Add(sum, f[y][j]);
f[x][c] = 1LL * f[x][c] * sum % P;
}
}
}

统计答案的片段就不放了。

但是冷静分析,这东西是 O(n3)O(n^3) 的,在 n=5000n=5000 的数据范围下显然过不掉。

通过观察转移方程我们不难发现某个儿子对当前节点某个颜色的贡献完全可以用该儿子的总贡献减去该儿子对这个颜色的贡献来得到。

这样就能优化成 O(n2)O(n^2) 了,貌似能过这道题了?

但是,等等,MLE了?!这道毒瘤题还卡你内存!这都 9102 年了居然还有只开 125M 的题?!其实只是自己菜没看空间限制

所以之前用 vector 存可用颜色的方法就行不通了,只能老老实实从 11mm 枚举。

效率大大降低了,时间换空间的典型案例

完整代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
#include <cstdio>
#include <iostream>
using namespace std;
char buf[1 << 14], *p1 = buf, *p2 = buf;
inline char gc() {
return p1 == p2 && (p2 = (p1 = buf) + fread(buf, 1, 1 << 14, stdin), p1 == p2) ? EOF : *p1++;
}
inline int read() {
int x = 0, f = 1;
char c = gc();
for (; !isdigit(c); c = gc())
if (c == '-') f = -1;
for (; isdigit(c); c = gc()) x = x * 10 + c - '0';
return x * f;
}
const int N = 5005;
const int P = 1e9 + 7;
int nedge, head[N];
struct Edge {
int to, nxt;
} edge[N << 1];
inline void add(int x, int y) {
edge[++nedge].to = y;
edge[nedge].nxt = head[x];
head[x] = nedge;
}
inline void Add(int &x, int y) {
x = x + y >= P ? x + y - P : x + y;
}
inline int Dec(int x, int y) {
x -= y;
if (x < 0) x += P;
return x;
}
int n, m, f[N][N];
void dfs(int x, int fa) {
for (int i = head[x]; i; i = edge[i].nxt) {
int y = edge[i].to;
if (y == fa) continue;
dfs(y, x);
int sum = 0;
for (int c = 1; c <= m; c++)
if (f[y][c]) Add(sum, f[y][c]);
for (int c = 1; c <= m; c++)
if (f[x][c]) f[x][c] = 1LL * f[x][c] * Dec(sum, f[y][c]) % P;
}
}
int main() {
n = read(), m = read();
for (int i = 1; i <= n; i++)
for (int num = read(); num; num--) f[i][read()] = 1;
for (int i = 1; i < n; i++) {
int x = read(), y = read();
add(x, y), add(y, x);
}
dfs(1, 0);
int ans = 0;
for (int c = 1; c <= m; c++) Add(ans, f[1][c]);
printf("%d", ans);
}