「CSP-S 2020」函数调用(拓扑排序+DP)

Address

LOJ3381
LuoguP7077

Solution

因为加是单点加,乘是全体乘,所以考虑计算后面的乘对前面的加的影响。

也就是说,对于某次执行 \(T_j=1\) 的操作 \(a_p+=v\),设在它之后执行的 \(T_j=2\) 的操作的 \(\prod V_j=x\)。那么计算最终答案的时候,只要把 \(a_p+=v\times x\) 即可。

对于 \(T_j=3\) 的操作,题目说保证不会出现递归(即不会直接或间接地调用本身)。因此建一张 DAG,如果函数 \(u\) 直接调用了函数 \(v\),那么连一条 \(u→v\) 的边。方便起见,再建一个点 \(m+1\),向 \(Q\)\(f_i\) 都连一条边。

接下来开始暴力,我们记一个 \(prod\),表示当前访问过的 \(T_j=2\) 节点的 \(\prod V_j\)。(重复访问就重复计算)

\(m+1\) 开始 DFS(注意出边的顺序要反过来,因为是后面的乘对前面的加的影响)。DFS 到 \(u\) 的时候,如果 \(T_j=2\)\(prod\times=V_j\),如果 \(T_j=1\)\(a_{P_j}+=V_j\times prod\),如果 \(T_j=3\),就什么都不做。

怎么优化这个暴力?

考虑对于一个点 \(u\),它连向的点分别为 \(v_1,v_2,...,v_k\)。那么 DFS 到 \(u\) 之后,设当前的 \(prod\)\(s\),接下来肯定是 DFS \(v_1\),那么执行完 DFS \(v_1\),准备 DFS \(v_2\) 的时候,\(prod\) 是多少?

预处理出 \(dp_u\) 表示从 \(u\) 开始 DFS,经过的所有 \(T_j=2\) 节点的 \(\prod V_j\),按拓扑序倒序转移即可。

那么上述的 \(prod\) 就是 \(dp_{v_1}\times s\),以此类推,准备 DFS \(v_i\) 的时候,\(prod\) 就是 \(\prod_{j=1}^{i-1}dp_{v_j}\times s\)

我们可以这样描述这个 DFS:从 \(u\) 开始,带着大小为 \(prod\) 的标记走下去,接下来,对于每个 \(v_i\),带着大小为 \(\prod_{j=1}^{i-1}dp_{v_j}\times s\) 的标记走下去。也就是说,我们不用把 \(v_1\sim v_{i-1}\) 都 DFS 一遍,就可以知道 \(v_i\) 的标记大小,它仅仅取决于所有连向它的 \(u\)

我们记 \(tag_u\) 表示点 \(u\) 的标记大小。这个 \(tag\) 有什么用呢?我们发现所有的 \(T_j\in\{1,2\}\)\(j\) 都是底层节点,没有出边,所以如果 \(T_j=1\),我们求出 \(tag_j\) 之后,直接让 \(a_{P_j}+=V_j\times tag_j\) 即可。

根据上述分析,对于一个点 \(v\),只要知道所有连向它的 \(u\)\(tag_u\),即可用形如 \(tag_v=\sum_{u→v}tag_u\times \prod_{x\in pre(u,v)}dp_x\) 的式子求出 \(tag_v\)。按照拓扑序转移即可。

注意把所有 \(m+1\) 走不到的点和边删掉。

时间复杂度 \(O(n+m+Q+\sum C_j)\)

Code

#include <bits/stdc++.h>

using namespace std;

#define ll long long

template <class t>
inline void read(t & res)
{
	char ch;
	while (ch = getchar(), !isdigit(ch));
	res = ch ^ 48;
	while (ch = getchar(), isdigit(ch))
		res = res * 10 + (ch ^ 48);
}

template <class t>
inline void print(t x)
{
	if (x > 9) print(x / 10);
	putchar(x % 10 + 48);
}

const int N = 1e5 + 15, M = 2e6 + 15, mod = 998244353;

int adj[N], nxt[M], go[M], val[N], pos[N], typ[N], n, m, q, tag[N];
int f[N], deg[N], seq[N], cnt, a[N], num;
bool vis[N];

inline void add(int &x, int y)
{
	(x += y) >= mod && (x -= mod);
}

inline void link(int x, int y)
{
	nxt[++num] = adj[x];
	adj[x] = num;
	go[num] = y;
	deg[y]++;
}

inline void dfs(int u)
{
	if (vis[u]) return;
	vis[u] = 1;
	for (int i = adj[u]; i; i = nxt[i]) dfs(go[i]);
}

inline void topsort()
{
	queue<int>q;
	int i, j;
	q.push(m + 1);
	seq[cnt = 1] = m + 1;
	while (!q.empty())
	{
		int u = q.front();
		q.pop();
		for (i = adj[u]; i; i = nxt[i])
		{
			int v = go[i];
			if (!vis[v]) continue;
			deg[v]--;
			if (!deg[v]) q.push(v), seq[++cnt] = v;
		}
	}
	for (i = cnt; i >= 1; i--)
	{
		int u = seq[i];
		for (j = adj[u]; j; j = nxt[j])
		{
			int v = go[j];
			f[u] = (ll)f[u] * f[v] % mod;
		}
	}
}

inline void solve()
{
	int i, j;
	for (i = 1; i <= cnt; i++)
	{
		int u = seq[i], pre = 1;
		for (j = adj[u]; j; j = nxt[j])
		{
			int v = go[j];
			add(tag[v], (ll)tag[u] * pre % mod);
			pre = (ll)pre * f[v] % mod;
		}
	}
}

int main()
{
	freopen("call.in", "r", stdin);
	freopen("call.out", "w", stdout);
	read(n);
	int i, j, k, x;
	for (i = 1; i <= n; i++) read(a[i]);
	read(m);
	for (i = 1; i <= m; i++)
	{
		read(typ[i]);
		f[i] = 1;
		if (typ[i] == 1) read(pos[i]), read(val[i]);
		else if (typ[i] == 2) read(val[i]), f[i] = val[i];
		else
		{
			read(k);
			for (j = 1; j <= k; j++)
			{
				read(x);
				link(i, x);
			}
		}
	}
	read(q);
	for (i = 1; i <= q; i++)
	{
		read(x);
		link(m + 1, x);
	}
	dfs(m + 1);
	for (i = 1; i <= m + 1; i++)
		for (j = adj[i]; j; j = nxt[j])
		{
			k = go[j];
			if (!vis[k] || !vis[i]) deg[k]--;
		}
	f[m + 1] = tag[m + 1] = 1;
	topsort();
	solve();
	for (i = 1; i <= n; i++) a[i] = (ll)a[i] * f[m + 1] % mod;
	for (i = 1; i <= m; i++)
		if (typ[i] == 1) add(a[pos[i]], (ll)val[i] * tag[i] % mod);
	for (i = 1; i <= n; i++)
		printf("%d ", a[i]);
	putchar('\n');
	fclose(stdin);
	fclose(stdout);
	return 0;
}
posted @ 2021-09-07 21:11  花淇淋  阅读(79)  评论(0编辑  收藏  举报