P7722 [Ynoi2007] tmpq 题解
前言
这篇题解是 zak 的做法,我实现得很丑所以常数巨大,但我尽量把我的实现讲清楚()。
做法
转化 & 暴力
首先,令
对于每一种
设
共有
由于
根号分治:小于等于根号的部分
对
对于出现次数
使用动态数组
接着考虑维护前缀和,令
对序列进行分块,考虑到只需要进行
根号分治:大于根号的部分
这个 DP 只有四个状态,所以考虑序列动态 DP。
由于出现次数
由于枚举每种
取矩阵中的哪个数呢?初始的
总时间复杂度
我错的一些细节
可能会出现
代码细节有点多,都打注释了。
#include <bits/stdc++.h>
#define ll long long
#define pii pair<int, int>
#define fi first
#define se second
#define IOS ios::sync_with_stdio(0), cin.tie(0), cout.tie(0)
using namespace std;
const int N = 2e5 + 10, B = 210, M = 5e4 + 10;
int n, m, a[N], aa[N], b[N], c[N], cnt[N];
ll ans[M];
struct QUES {
int opt, x, y;
} q[M];
int block, nb, L[B], R[B], bel[N];
// 预处理分块
// 我根号分治和序列分块共用了 block
inline void init() {
block = 1000;
for (int i = 1; i <= n; i++)
bel[i] = (i + block - 1) / block;
nb = bel[n];
for (int i = 1; i <= nb; i++)
L[i] = R[i - 1] + 1, R[i] = i * block;
R[nb] = n;
}
namespace Small {
ll dp[4], f[N], s[N];
vector<pii> vec[N];
// 转移顺序:不能同一位的 1 到 2
bool cmp(pii x, pii y) {
return x.fi != y.fi ? x.fi < y.fi : x.se > y.se;
}
// 暴力寻找一个数并添加/删除
inline void add(int id, pii x) {
auto it = lower_bound(vec[id].begin(), vec[id].end(), x, cmp);
vec[id].insert(it, x);
}
inline void del(int id, pii x) {
// 提前把贡献消除
if (x.se == 3) {
s[bel[x.fi]] -= f[x.fi];
f[x.fi] = 0;
}
auto it = lower_bound(vec[id].begin(), vec[id].end(), x, cmp);
vec[id].erase(it);
}
// 重做 w=id 的 DP
inline void work(int id) {
dp[0] = 1;
dp[1] = dp[2] = dp[3] = 0;
for (auto [x, y] : vec[id]) {
dp[y] += dp[y - 1];
s[bel[x]] -= f[x];
if (y == 3)
f[x] = dp[2];
s[bel[x]] += f[x];
}
}
// 求 f 的前缀和
inline ll query(int x) {
ll ret = 0;
for (int i = 1; i < bel[x]; i++)
ret += s[i];
for (int i = L[bel[x]]; i <= x; i++)
ret += f[i];
return ret;
}
inline void solve() {
for (int i = 1; i <= n; i++) {
if (cnt[b[a[i]]] <= block)
add(b[a[i]], {i, 1});
if (cnt[a[i]] <= block)
add(a[i], {i, 2});
if (cnt[c[a[i]]] <= block)
add(c[a[i]], {i, 3});
}
for (int i = 1; i <= n; i++)
work(i);
for (int i = 1; i <= m; i++) {
auto [opt, x, y] = q[i];
if (opt == 1) {
if (cnt[b[a[x]]] <= block)
del(b[a[x]], {x, 1}), work(b[a[x]]);
if (cnt[a[x]] <= block)
del(a[x], {x, 2}), work(a[x]);
if (cnt[c[a[x]]] <= block)
del(c[a[x]], {x, 3}), work(c[a[x]]);
a[x] = y;
if (cnt[b[a[x]]] <= block)
add(b[a[x]], {x, 1}), work(b[a[x]]);
if (cnt[a[x]] <= block)
add(a[x], {x, 2}), work(a[x]);
if (cnt[c[a[x]]] <= block)
add(c[a[x]], {x, 3}), work(c[a[x]]);
} else
ans[i] += query(x);
}
}
}
namespace Big {
// cnt[w] > B 的数集
vector<int> big;
struct matrix {
ll d[4][4];
// 因为是前缀乘,所以初始化为单位矩阵
matrix() {
memset(d, 0, sizeof d);
for (int i = 0; i < 4; i++)
d[i][i] = 1;
}
matrix(bool I) {
memset(d, 0, sizeof d);
for (int i = 0; I and i < 4; i++)
d[i][i] = 1;
}
} mat, S[N], SB[B];
// 实现太垃圾不展开过不了 QAQ
inline matrix operator*(const matrix &x, const matrix &y) {
matrix ret;
ret.d[0][0] = x.d[0][0] * y.d[0][0] + x.d[0][1] * y.d[1][0] + x.d[0][2] * y.d[2][0] + x.d[0][3] * y.d[3][0];
ret.d[0][1] = x.d[0][0] * y.d[0][1] + x.d[0][1] * y.d[1][1] + x.d[0][2] * y.d[2][1] + x.d[0][3] * y.d[3][1];
ret.d[0][2] = x.d[0][0] * y.d[0][2] + x.d[0][1] * y.d[1][2] + x.d[0][2] * y.d[2][2] + x.d[0][3] * y.d[3][2];
ret.d[0][3] = x.d[0][0] * y.d[0][3] + x.d[0][1] * y.d[1][3] + x.d[0][2] * y.d[2][3] + x.d[0][3] * y.d[3][3];
ret.d[1][0] = x.d[1][0] * y.d[0][0] + x.d[1][1] * y.d[1][0] + x.d[1][2] * y.d[2][0] + x.d[1][3] * y.d[3][0];
ret.d[1][1] = x.d[1][0] * y.d[0][1] + x.d[1][1] * y.d[1][1] + x.d[1][2] * y.d[2][1] + x.d[1][3] * y.d[3][1];
ret.d[1][2] = x.d[1][0] * y.d[0][2] + x.d[1][1] * y.d[1][2] + x.d[1][2] * y.d[2][2] + x.d[1][3] * y.d[3][2];
ret.d[1][3] = x.d[1][0] * y.d[0][3] + x.d[1][1] * y.d[1][3] + x.d[1][2] * y.d[2][3] + x.d[1][3] * y.d[3][3];
ret.d[2][0] = x.d[2][0] * y.d[0][0] + x.d[2][1] * y.d[1][0] + x.d[2][2] * y.d[2][0] + x.d[2][3] * y.d[3][0];
ret.d[2][1] = x.d[2][0] * y.d[0][1] + x.d[2][1] * y.d[1][1] + x.d[2][2] * y.d[2][1] + x.d[2][3] * y.d[3][1];
ret.d[2][2] = x.d[2][0] * y.d[0][2] + x.d[2][1] * y.d[1][2] + x.d[2][2] * y.d[2][2] + x.d[2][3] * y.d[3][2];
ret.d[2][3] = x.d[2][0] * y.d[0][3] + x.d[2][1] * y.d[1][3] + x.d[2][2] * y.d[2][3] + x.d[2][3] * y.d[3][3];
ret.d[3][0] = x.d[3][0] * y.d[0][0] + x.d[3][1] * y.d[1][0] + x.d[3][2] * y.d[2][0] + x.d[3][3] * y.d[3][0];
ret.d[3][1] = x.d[3][0] * y.d[0][1] + x.d[3][1] * y.d[1][1] + x.d[3][2] * y.d[2][1] + x.d[3][3] * y.d[3][1];
ret.d[3][2] = x.d[3][0] * y.d[0][2] + x.d[3][1] * y.d[1][2] + x.d[3][2] * y.d[2][2] + x.d[3][3] * y.d[3][2];
ret.d[3][3] = x.d[3][0] * y.d[0][3] + x.d[3][1] * y.d[1][3] + x.d[3][2] * y.d[2][3] + x.d[3][3] * y.d[3][3];
return ret;
}
// 先 O(n) 做一次 w=x 的
inline void work(int x) {
for (int i = 1; i <= n; i++) {
int bi = bel[i];
mat.d[0][1] = b[a[i]] == x;
mat.d[1][2] = a[i] == x;
mat.d[2][3] = c[a[i]] == x;
if (i == L[bi])
S[i] = mat;
else
S[i] = S[i - 1] * mat;
}
for (int i = 1; i <= nb; i++)
SB[i] = SB[i - 1] * S[R[i]];
}
// O(sqrt(n)) 更新 w=x 位于 id 处的修改
inline void update(int x, int id) {
int bid = bel[id];
for (int i = id; i <= R[bid]; i++) {
mat.d[0][1] = b[a[i]] == x;
mat.d[1][2] = a[i] == x;
mat.d[2][3] = c[a[i]] == x;
if (i == L[bid])
S[i] = mat;
else
S[i] = S[i - 1] * mat;
}
for (int i = bid; i <= nb; i++)
SB[i] = SB[i - 1] * S[R[i]];
}
// 前缀答案
inline ll query(int x) {
return (SB[bel[x] - 1] * S[x]).d[0][3];
}
inline void solve() {
for (int i = 1; i <= n; i++)
if (cnt[i] > block)
big.push_back(i);
int cc = 0;
for (int x : big) {
// 先清零
for (int i = 1; i <= nb; i++)
S[i] = SB[i] = matrix();
for (int i = 1; i <= n; i++)
a[i] = aa[i];
work(x);
for (int i = 1; i <= m; i++) {
auto [opt, id, v] = q[i];
if (q[i].opt == 1) {
// 有关联的才更新,不然复杂度不对
if (a[id] == x or b[a[id]] == x or c[a[id]] == x or v == x or b[v] == x or c[v] == x) {
a[id] = v;
update(x, id);
cc++;
} else
a[id] = v;
} else
ans[i] += query(id);
}
}
}
}
signed main() {
IOS;
cin >> n >> m;
init();
for (int i = 1; i <= n; i++) {
cin >> a[i];
aa[i] = a[i];
}
for (int i = 1; i <= n; i++)
cin >> b[i];
for (int i = 1; i <= n; i++)
cin >> c[i];
for (int i = 1; i <= n; i++) {
cnt[a[i]]++, cnt[b[a[i]]]++, cnt[c[a[i]]]++;
}
for (int i = 1; i <= m; i++) {
cin >> q[i].opt >> q[i].x;
if (q[i].opt == 1) {
cin >> q[i].y;
cnt[q[i].y]++;
cnt[b[q[i].y]]++;
cnt[c[q[i].y]]++;
}
}
Small::solve();
Big::solve();
for (int i = 1; i <= m; i++)
if (q[i].opt == 2)
cout << ans[i] << "\n";
return 0;
}