「CSP-S 2020」函数调用(拓扑排序+DP)
Address
Solution
因为加是单点加,乘是全体乘,所以考虑计算后面的乘对前面的加的影响。
也就是说,对于某次执行 \(T_j=1\) 的操作 \(a_p+=v\),设在它之后执行的 \(T_j=2\) 的操作的 \(\prod V_j=x\)。那么计算最终答案的时候,只要把 \(a_p+=v\times x\) 即可。
对于 \(T_j=3\) 的操作,题目说保证不会出现递归(即不会直接或间接地调用本身)。因此建一张 DAG,如果函数 \(u\) 直接调用了函数 \(v\),那么连一条 \(u→v\) 的边。方便起见,再建一个点 \(m+1\),向 \(Q\) 个 \(f_i\) 都连一条边。
接下来开始暴力,我们记一个 \(prod\),表示当前访问过的 \(T_j=2\) 节点的 \(\prod V_j\)。(重复访问就重复计算)
从 \(m+1\) 开始 DFS(注意出边的顺序要反过来,因为是后面的乘对前面的加的影响)。DFS 到 \(u\) 的时候,如果 \(T_j=2\),\(prod\times=V_j\),如果 \(T_j=1\),\(a_{P_j}+=V_j\times prod\),如果 \(T_j=3\),就什么都不做。
怎么优化这个暴力?
考虑对于一个点 \(u\),它连向的点分别为 \(v_1,v_2,...,v_k\)。那么 DFS 到 \(u\) 之后,设当前的 \(prod\) 为 \(s\),接下来肯定是 DFS \(v_1\),那么执行完 DFS \(v_1\),准备 DFS \(v_2\) 的时候,\(prod\) 是多少?
预处理出 \(dp_u\) 表示从 \(u\) 开始 DFS,经过的所有 \(T_j=2\) 节点的 \(\prod V_j\),按拓扑序倒序转移即可。
那么上述的 \(prod\) 就是 \(dp_{v_1}\times s\),以此类推,准备 DFS \(v_i\) 的时候,\(prod\) 就是 \(\prod_{j=1}^{i-1}dp_{v_j}\times s\)。
我们可以这样描述这个 DFS:从 \(u\) 开始,带着大小为 \(prod\) 的标记走下去,接下来,对于每个 \(v_i\),带着大小为 \(\prod_{j=1}^{i-1}dp_{v_j}\times s\) 的标记走下去。也就是说,我们不用把 \(v_1\sim v_{i-1}\) 都 DFS 一遍,就可以知道 \(v_i\) 的标记大小,它仅仅取决于所有连向它的 \(u\)。
我们记 \(tag_u\) 表示点 \(u\) 的标记大小。这个 \(tag\) 有什么用呢?我们发现所有的 \(T_j\in\{1,2\}\) 的 \(j\) 都是底层节点,没有出边,所以如果 \(T_j=1\),我们求出 \(tag_j\) 之后,直接让 \(a_{P_j}+=V_j\times tag_j\) 即可。
根据上述分析,对于一个点 \(v\),只要知道所有连向它的 \(u\) 的 \(tag_u\),即可用形如 \(tag_v=\sum_{u→v}tag_u\times \prod_{x\in pre(u,v)}dp_x\) 的式子求出 \(tag_v\)。按照拓扑序转移即可。
注意把所有 \(m+1\) 走不到的点和边删掉。
时间复杂度 \(O(n+m+Q+\sum C_j)\)。
Code
#include <bits/stdc++.h>
using namespace std;
#define ll long long
template <class t>
inline void read(t & res)
{
char ch;
while (ch = getchar(), !isdigit(ch));
res = ch ^ 48;
while (ch = getchar(), isdigit(ch))
res = res * 10 + (ch ^ 48);
}
template <class t>
inline void print(t x)
{
if (x > 9) print(x / 10);
putchar(x % 10 + 48);
}
const int N = 1e5 + 15, M = 2e6 + 15, mod = 998244353;
int adj[N], nxt[M], go[M], val[N], pos[N], typ[N], n, m, q, tag[N];
int f[N], deg[N], seq[N], cnt, a[N], num;
bool vis[N];
inline void add(int &x, int y)
{
(x += y) >= mod && (x -= mod);
}
inline void link(int x, int y)
{
nxt[++num] = adj[x];
adj[x] = num;
go[num] = y;
deg[y]++;
}
inline void dfs(int u)
{
if (vis[u]) return;
vis[u] = 1;
for (int i = adj[u]; i; i = nxt[i]) dfs(go[i]);
}
inline void topsort()
{
queue<int>q;
int i, j;
q.push(m + 1);
seq[cnt = 1] = m + 1;
while (!q.empty())
{
int u = q.front();
q.pop();
for (i = adj[u]; i; i = nxt[i])
{
int v = go[i];
if (!vis[v]) continue;
deg[v]--;
if (!deg[v]) q.push(v), seq[++cnt] = v;
}
}
for (i = cnt; i >= 1; i--)
{
int u = seq[i];
for (j = adj[u]; j; j = nxt[j])
{
int v = go[j];
f[u] = (ll)f[u] * f[v] % mod;
}
}
}
inline void solve()
{
int i, j;
for (i = 1; i <= cnt; i++)
{
int u = seq[i], pre = 1;
for (j = adj[u]; j; j = nxt[j])
{
int v = go[j];
add(tag[v], (ll)tag[u] * pre % mod);
pre = (ll)pre * f[v] % mod;
}
}
}
int main()
{
freopen("call.in", "r", stdin);
freopen("call.out", "w", stdout);
read(n);
int i, j, k, x;
for (i = 1; i <= n; i++) read(a[i]);
read(m);
for (i = 1; i <= m; i++)
{
read(typ[i]);
f[i] = 1;
if (typ[i] == 1) read(pos[i]), read(val[i]);
else if (typ[i] == 2) read(val[i]), f[i] = val[i];
else
{
read(k);
for (j = 1; j <= k; j++)
{
read(x);
link(i, x);
}
}
}
read(q);
for (i = 1; i <= q; i++)
{
read(x);
link(m + 1, x);
}
dfs(m + 1);
for (i = 1; i <= m + 1; i++)
for (j = adj[i]; j; j = nxt[j])
{
k = go[j];
if (!vis[k] || !vis[i]) deg[k]--;
}
f[m + 1] = tag[m + 1] = 1;
topsort();
solve();
for (i = 1; i <= n; i++) a[i] = (ll)a[i] * f[m + 1] % mod;
for (i = 1; i <= m; i++)
if (typ[i] == 1) add(a[pos[i]], (ll)val[i] * tag[i] % mod);
for (i = 1; i <= n; i++)
printf("%d ", a[i]);
putchar('\n');
fclose(stdin);
fclose(stdout);
return 0;
}