LG P7077 函数调用
Description
函数是各种编程语言中一项重要的概念,借助函数,我们总可以将复杂的任务分解成一个个相对简单的子任务,直到细化为十分简单的基础操作,从而使代码的组织更加严密、更加有条理。然而,过多的函数调用也会导致额外的开销,影响程序的运行效率。
某数据库应用程序提供了若干函数用以维护数据。已知这些函数的功能可分为三类:
- 将数据中的指定元素加上一个值;
- 将数据中的每一个元素乘以一个相同值;
- 依次执行若干次函数调用,保证不会出现递归(即不会直接或间接地调用本身)。
在使用该数据库应用时,用户可一次性输入要调用的函数序列(一个函数可能被调用多次),在依次执行完序列中的函数后,系统中的数据被加以更新。某一天,小 A 在应用该数据库程序处理数据时遇到了困难:由于频繁而低效的函数调用,系统在执行操作时进入了无响应的状态,他只好强制结束了数据库程序。为了计算出正确数据,小 A 查阅了软件的文档,了解到每个函数的具体功能信息,现在他想请你根据这些信息帮他计算出更新后的数据应该是多少。
Solution
显然先进行的加操作会受到之后的乘操作的影响
所以先将每一个操作(包括类型3)的乘法系数算出来
考虑每个加操作的系数就是其之后的所有乘操作乘积,类型3中加操作也会受到内部其之后的乘操作的影响
先对每个类型3建图,由它自己的编号连向子函数的编号,在DAG上DP就可以求出类型3的系数
接着从后向前扫描可以求出所有单个加操作的系数
最后考虑类型3中加操作
在刚刚的DAG上拓扑排序,按次序更新每个操作的所有子函数系数,遍历边时需要逆序
最后对所有加操作更新a数列
#include<iostream> #include<cstdio> #include<queue> using namespace std; int n,m,t[100005],p[100005],tot,head[100005],du[100005],Q,f[100005]; long long a[100005],add[100005],mul[100005],s=1,cnt[100005]; bool vst[100005]; const int mod=998244353; struct Edge { int to,nxt; }edge[1000005]; queue<int>q; inline int read() { int f=1,w=0; char ch=0; while(ch<'0'||ch>'9') { if(ch=='-') f=-1; ch=getchar(); } while(ch>='0'&&ch<='9') { w=(w<<1)+(w<<3)+ch-'0'; ch=getchar(); } return f*w; } long long dfs(int k) { if(vst[k]) return mul[k]; vst[k]=true; for(int i=head[k];i;i=edge[i].nxt) { int v=edge[i].to; (mul[k]*=dfs(v))%=mod; } return mul[k]; } void topo() { for(int i=1;i<=m;i++) if(!du[i]) q.push(i); while(q.size()) { int u=q.front(); q.pop(); for(int i=head[u];i;i=edge[i].nxt) { int v=edge[i].to; (cnt[v]+=cnt[u])%=mod; (cnt[u]*=mul[v])%=mod; --du[v]; if(!du[v]) q.push(v); } } } int main() { n=read(); for(int i=1;i<=n;i++) a[i]=read(); m=read(); for(int i=1;i<=m;i++) { t[i]=read(),mul[i]=1; if(t[i]==1) p[i]=read(),add[i]=read(); else if(t[i]==2) mul[i]=read(); else { int c=read(); for(int j=1;j<=c;j++) { int g=read(); edge[++tot]=(Edge){g,head[i]},head[i]=tot; ++du[g]; } } } for(int i=1;i<=m;i++) { if(!vst[i]) dfs(i); } Q=read(); for(int i=1;i<=Q;i++) f[i]=read(); for(int i=Q;i;i--) { (cnt[f[i]]+=s)%=mod; (s*=mul[f[i]])%=mod; } for(int i=1;i<=n;i++) (a[i]*=s)%=mod; topo(); for(int i=1;i<=m;i++) if(t[i]==1) (a[p[i]]+=add[i]*cnt[i]%mod)%=mod; for(int i=1;i<=n;i++) printf("%lld ",a[i]); return 0; }