[Luogu] P7077 函数调用

\(Link\)

Description

某数据库应用程序提供了若干函数用以维护数据。已知这些函数的功能可分为三类:

\(1.\)将数据中的指定元素加上一个值;

\(2.\)将数据中的每一个元素乘以一个相同值;

\(3.\)依次执行若干次函数调用,保证不会出现递归(即不会直接或间接地调用本身)。

在使用该数据库应用时,用户可一次性输入要调用的函数序列(一个函数可能被调用多次),在依次执行完序列中的函数后,系统中的数据被加以更新。为了计算出正确数据,小\(A\)查阅了软件的文档,了解到每个函数的具体功能信息,现在他想请你根据这些信息帮他计算出更新后的数据应该是多少。

Solution

可以考虑依次执行完操作序列后,所有数先一起被乘上一个数,有一些位置再被加上所加上的数乘上这个数的贡献。

注意到依次执行,不会出现递归等关键字眼,可以想到拓扑排序。(一定要注意这里的\(m\)才是原来拓扑排序的\(n\),因为把\(m\)个函数当作点)

可以新建一个点\(0\),作为主函数,同时将它和输入的所有调用函数连边。那么拓扑排序循环就要从\(0\sim{m}\)

前一步是很好做的,可以考虑记忆化搜索,或者先建反图,跑一遍拓扑排序。维护一个乘法标记\(mul\),对于\(1\)类函数和\(3\)类函数,它们的\(mul=1\),对于\(2\)类函数,它的\(mul=v_i\)。然后从下到上,累乘\(mul\)即可。然后所有的\(a_i\)就要乘上\(mul[0]\)

而如何处理加上的数究竟被加上了多少次呢?我们会发现,当操作序列不断执行\(\times,+,\times,+...\)时,对于一个\(+\),只有它后面的\(\times\)会对它产生乘法贡献。我们在原图上再跑一边拓扑排序,从上到下算出每个函数被执行的次数\(add\),也就是加法的次数。

对于\(u\rightarrow{v}\),设\(now=\prod\limits_{u\rightarrow{p},t_p>t_v}mul[p]\)\(t_i\)表示\(i\)被执行的顺序先后),那么\(add[v]=add[v]+add[u]*now\)\(u\)先执行\(add[u]\)次,每次都使\(add[v]\)乘上\(now\))。算完\(add\)后,处理加法即可。

(这个的确不是很难,但是谁有时间想啊,T1出题人1582.10.5~1582.10.14)

Code

#include <bits/stdc++.h>

using namespace std;

#define ll long long

const ll mod = 998244353;

queue < int > q;

int n, m, t, tot, hd[100005], nxt[1100005], to[1100005], rd[1100005], cnt[1100005];

ll a[100005];

struct node
{
	int tp, pos, sz;
	ll v, mul, add;
	vector < int > w;
}f[100005];

int read()
{
	int x = 0, fl = 1; char ch = getchar();
	while (ch < '0' || ch > '9') { if (ch == '-') fl = -1; ch = getchar();}
	while (ch >= '0' && ch <= '9') {x = (x << 1) + (x << 3) + ch - '0'; ch = getchar();}
	return x * fl;
}

void add(int x, int y)
{
	tot ++ ;
	to[tot] = y;
	nxt[tot] = hd[x];
	hd[x] = tot;
	return;
}

void topo1()
{
	for (int i = 0; i <= m; i ++ ) cnt[i] = f[i].sz;
	for (int i = 0; i <= m; i ++ ) if (!cnt[i]) q.push(i);
	while (q.size())
	{
		int x = q.front(); q.pop();
		for (int i = hd[x]; i; i = nxt[i])
		{
			int y = to[i];
			f[y].mul = f[y].mul * f[x].mul % mod;
			cnt[y] -- ;
			if (!cnt[y]) q.push(y);
		}
	}
	return;
}

void topo2()
{
	for (int i = 0; i <= m; i ++ ) cnt[i] = rd[i];
	for (int i = 0; i <= m; i ++ ) if (!cnt[i]) q.push(i);
	while (q.size())
	{
		int x = q.front(); q.pop();
		ll now = 1ll;
		for (int i = f[x].sz - 1; i >= 0; i -- )
		{
			int y = f[x].w[i];
			f[y].add = (f[y].add + f[x].add * now)% mod;
			now = now * f[y].mul % mod;
			cnt[y] -- ;
			if (!cnt[y]) q.push(y);
		}
	}
	return;
}

int main()
{
	n = read();
	for (int i = 1; i <= n; i ++ )
		a[i] = (ll)read();
	m = read();
	f[0].mul = 1ll;
	for (int i = 1; i <= m; i ++ )
	{
		f[i].tp = read();
		if (f[i].tp == 1)
		{
			f[i].pos = read();
			f[i].v = (ll)read();
			f[i].mul = 1ll;
		}
		else if (f[i].tp == 2)
		{
			f[i].v = (ll)read();
			f[i].mul = f[i].v;
		}
		else
		{
			f[i].sz = read();
			f[i].mul = 1ll;
			for (int j = 1; j <= f[i].sz; j ++ )
			{
				int x = read();
				f[i].w.push_back(x);
				add(x, i);
				rd[x] ++ ;
			}
		}
	}
	f[0].add = 1ll;
	t = read();
	while (t -- )
	{
		int x = read();
		add(x, 0);
		rd[x] ++ ;
		f[0].sz ++ ;
		f[0].w.push_back(x);
	}
	topo1(); topo2();
	for (int i = 1; i <= n; i ++ )
		a[i] = a[i] * f[0].mul % mod;
	for (int i = 1; i <= m; i ++ )
		if (f[i].tp == 1)
			a[f[i].pos] = (a[f[i].pos] + f[i].add * f[i].v % mod) % mod;
	for (int i = 1; i <= n; i ++ )
		printf("%lld ", a[i]);
	puts("");
	return 0;
}
posted @ 2020-11-14 13:08  andysj  阅读(133)  评论(0编辑  收藏  举报