CodeChef POLY Polynomials 题解
Polynomials 题解
题意:
给定 n 个形如\(yi(x) = a_0 + a_1x + a_2x^2 + a_3x^3\)的函数以及 q 个询问。每个询问给定整数 t,你 需要求出使得 yi(t) 最小化的函数 yi。
题解:
由于李超线段树只可以维护函数间最多只有一个交点的情况,由于这里的三次函数可能会有多个交点,所以这里就没办法直接使用李超树了。
我们可以分治来解决,先处理好左半边函数在每一个区间内取得最小值的函数,和右半边函数在每一个区间内取得最小值的函数,可以存在一个vector中,记录下多项式和左右端点,然后将两个部分合并起来。
假设将两个多项式\(p_1\)和\(p_2\)合并,要求出在一段区间\([l,r]\)内哪一个取值最小。由于多个交点的函数不好处理,我们可以将两个函数切成几个部分使得每一个部分最多只有一个交点。这也是程序中get_sp_points函数的作用。
比如:
函数: \(y=1+x+x^2+x^3\)和\(y=2*x+1\)有三个交点。
这时候我们只需要按照x=-1,和x=0切两刀,就可以将函数分成三个最多一个交点的部分。可以发现这样的点最多只需要2个。
当函数间最多一个交点的时候就非常好做了。在交点的地方将两个函数分成两个部分,左边和右边一定是某一个函数在那一整段都是最小的,所以我们就可以确定下来\([l,x]\)和\([x,r]\)的最优函数,其中x为交点,可以二分来算(由于两个函数最多一个交点)。最后从前到后扫一遍,看每一个位置最优函数是什么并算出结果。时间复杂度为\(O(n\times \log n)\),可以在一开始将函数随机排列,防止合并之后段数太多。
Code:
ACRush:
#ifdef _MSC_VER
#define _CRT_SECURE_NO_WARNINGS
#endif
#include <vector>
#include <list>
#include <map>
#include <set>
#include <deque>
#include <queue>
#include <stack>
#include <bitset>
#include <algorithm>
#include <functional>
#include <numeric>
#include <utility>
#include <sstream>
#include <iostream>
#include <iomanip>
#include <cstdio>
#include <cmath>
#include <cstdlib>
#include <cctype>
#include <string>
#include <cstring>
#include <ctime>
#include <cassert>
#include <string.h>
#include <unordered_set>
#include <unordered_map>
using namespace std;
typedef long long int64;
typedef unsigned long long uint64;
#define two(X) (1<<(X))
#define twoL(X) (((int64)(1))<<(X))
#define contain(S,X) (((S)&two(X))!=0)
#define containL(S,X) (((S)&twoL(X))!=0)
const double pi=acos(-1.0);
const double eps=1e-11;
template<class T> inline void ckmin(T &a,T b){if(b<a) a=b;}
template<class T> inline void ckmax(T &a,T b){if(b>a) a=b;}
template<class T> inline T sqr(T x){return x*x;}
typedef pair<int,int> ipair;
#define SIZE(A) ((int)A.size())
#define LENGTH(A) ((int)A.length())
#define MP(A,B) make_pair(A,B)
#define PB(X) push_back(X)
#define FOR(i,a,b) for(int i=(a);i<(b);++i)
#define REP(i,a) for(int i=0;i<(a);++i)
#define ALL(A) A.begin(),A.end()
const int maxn=(1<<20);
const int range=100000;
struct Poly
{
int a0,a1,a2,a3;
};
struct Segment
{
Segment() {}
Segment(Poly* p,int s,int t) : p(p),s(s),t(t) {}
Poly* p;
int s;
int t;
};
inline int64 eval(const Poly& p,int64 x)
{
return ((x*p.a3+p.a2)*x+p.a1)*x+p.a0;
}
int n;
Poly a[maxn];
int64 ret[range+1];
void append_one_seg(Poly* p1,Poly* p2,int s,int t,vector<Segment>& ret)
{
int64 fs1=eval(*p1,s);
int64 ft1=eval(*p1,t);
int64 fs2=eval(*p2,s);
int64 ft2=eval(*p2,t);
if (fs1<=fs2 && ft1<=ft2) { ret.emplace_back(p1,s,t); return; }
if (fs2<=fs1 && ft2<=ft1) { ret.emplace_back(p2,s,t); return; }
int low=s,high=t;
for (;low+1<high;)
{
int m=(low+high)/2;
int64 e1=eval(*p1,m);
int64 e2=eval(*p2,m);
if (fs1<=fs2 && e1<=e2 || fs1>=fs2 && e1>=e2)
low=m;
else
high=m;
}
ret.emplace_back((fs1<=fs2)?p1:p2,s,low);
ret.emplace_back((ft1<=ft2)?p1:p2,high,t);
}
vector<int> get_sp_points(const Poly& p1,const Poly& p2)
{
int a=3*(p1.a3-p2.a3);
int b=2*(p1.a2-p2.a2);
int c=p1.a1-p2.a1;
if (a==0)
if (b==0) return {};
else return {(int)floor(-(double)(c)/(double)(b))};
else
{
int64 d=(int64)b*(int64)b-4*(int64)a*(int64)c;
if (d<0) return {};
if (d==0) return {(int)floor(-(double)(b)/(2.0*a))};
double sqrt_d=sqrt((double)(d));
double x1=(-(double)(b)-sqrt_d)/(2.0*a);
double x2=(-(double)(b)+sqrt_d)/(2.0*a);
return {(int)floor(x1),(int)floor(x2)};
}
}
void append(Poly* p1,Poly* p2,int s,int t,vector<Segment>& ret)
{
vector<int> sp=get_sp_points(*p1,*p2);
int last=s;
while (last<=t)
{
int next=t;
for (int p:sp) if (p>=last) ckmin(next,p);
append_one_seg(p1,p2,last,next,ret);
last=next+1;
}
}
vector<Segment> solve(int n,Poly a[])
{
if (n==1) return {Segment(a,0,range)};
vector<Segment> r1=solve(n/2,a);
vector<Segment> r2=solve(n-n/2,a+n/2);
int nextp=0;
int p1=0,p2=0;
vector<Segment> ret;
for (;nextp<=range;)
{
for (;r1[p1].t<nextp;++p1);
for (;r2[p2].t<nextp;++p2);
int w=min(r1[p1].t,r2[p2].t);
append(r1[p1].p,r2[p2].p,nextp,w,ret);
nextp=w+1;
}
int new_size=0;
REP(i,SIZE(ret))
if (new_size>0 && ret[new_size-1].p==ret[i].p)
ret[new_size-1].t=ret[i].t;
else
ret[new_size++]=ret[i];
ret.resize(new_size);
return ret;
}
void process()
{
REP(i,n) swap(a[i],a[i+rand()%(n-i)]);
vector<Segment> r=solve(n,a);
//for (const auto& seg:r) printf("(%d %d %d %d) %d %d\n",seg.p->a3,seg.p->a2,seg.p->a1,seg.p->a0,seg.s,seg.t);
for (const Segment& seg:r) FOR(x,seg.s,seg.t+1) ret[x]=eval(*seg.p,x);
}
void validate()
{
REP(x,range+1)
{
int64 e=(1LL<<62);
REP(i,n) ckmin(e,eval(a[i],x));
if (e!=ret[x])
{
printf("ERROR %d %lld %lld\n",x,e,ret[x]);
exit(0);
}
}
printf("PASSED\n");
}
#ifdef _MSC_VER
int random()
{
int v1=rand()2768;
int v2=rand()2768;
return (v1<<15)|v2;
}
#endif
int main()
{
#ifdef _MSC_VER
freopen("input.txt","r",stdin);
#endif
std::ios::sync_with_stdio(false);
int testcase;
for (cin>>testcase;testcase>0;testcase--)
{
cin>>n;
REP(i,n) cin>>a[i].a0>>a[i].a1>>a[i].a2>>a[i].a3;
process();
int nq;
for (cin>>nq;nq>0;--nq)
{
int x;
cin>>x;
printf("%lld\n",ret[x]);
}
}
/*
REP(seed,1000000)
{
//if (seed!=23) continue;
srand(seed);
printf("seed = %d\n",seed);
n=rand()0+1;
int w=1000000;//10;
REP(i,n)
{
a[i].a3=rand()%min(1000,w);
a[i].a2=random()%min(w,100000);
a[i].a1=random()%min(w,100000);
a[i].a0=random()%min(w,100000);
}
process();
validate();
}
*/
return 0;
}
解法2: 李超线段树
上面说没法用李超线段树的原因就是函数可能有\(\geq 2\)个交点。
其实,当x特别大的时候,函数间的交点最多就只有一个。
求交点其实就是解一个方程:
-
一元一次,\(x+c=0\),方程总共就一个根。
-
一元二次方程\(x^2+bx+c=0\) ,假设两个根为\(u,v\)且\(u\geq v> \sqrt {\max (|b|,|c|)}\),则\(0=(x-v)*(x-u)\),\(b=-(u+v)\),\(c=u*v\)。显然u*v>|c|,所以\(u\times v\neq c\),也就是说最多只有一个根\(> \sqrt{\max (|b|,|c|)}\)。
-
一元三次方程\(x^3+ax^2+bx+c=0\),假设有三个根为\(u,v,w\),其中\(u\geq v>\sqrt{\max(|b|,|c|)}+5\geq w\),则\(0=(x-u)\times (x-v)\times (x-w)\),b=\(uv+uw+vw\),c=\(-uvw\),由于\(u\times v>c\),所以\(|w|<1\),\(uv-u-v\leq b\),也就是\((u-1)\times (v-1)\leq b+1\) ,显然不成立。更不可能u,v,w同时\(> \sqrt{\max(|b|,|c|)}+5\),不然\(c=-uvw\)就不可能成立。
所以我们只需要处理\(t\leq \sqrt {100000}+10\)的答案,剩下的李超线段树就可以维护了。
时间复杂度为:\(O(n\times \sqrt{\max(|b_i|,|c_i|)})\)。
解法3:分治
我们同样先处理出前\(\sqrt {\max (|b_i|,|c+i|)}+5\)个,然后并不使用李超树,而是分治。
首先我们要将多项式排序:按照每一个多项式的\((a_3,a_2,a_1,a_0)\)的字典序从大到小排。若当前处理的多项式中i号多项式在mid处为所有函数的最小值,则在\([L,mid]\)中不可能有编号\(>i\)的多项式成为最小值,在\([mid,R]\)中也不可能有编号\(<i\)的成为最小值。
比如这张图:红色代表最小值的函数,由于之前的排序所以在i左边的所有函数在>mid的时候都不可能比i小(他们增长的比i快)
同理在<mid的时候所有编号>i的下降的都比i慢,所以不可能在<mid的时候比i小。
code:
// ayy
// ' lamo
// SUBLIME HAX
#include <bits/stdc++.h>
using namespace std;
template<class T,class U>
ostream &operator<<(ostream &os,const pair<T,U> &x) {
return os<<"("<<x.first<<","<<x.second<<")";
}
namespace dbg_ns {
template<typename C>
struct is_iterable {
template<class T> static long check(...);
template<class T> static char check(int,
typename T::const_iterator = C().end());
enum {
value = sizeof(check<C>(0)) == sizeof(char),
neg_value = sizeof(check<C>(0)) != sizeof(char)
};
};
template<class T> ostream &_out_str(ostream &os,const T &x) {
return os<<'"'<<x<<'"';
}
template<class T> ostream &_dbg2_5(ostream &,const T &);
template<bool B,typename T=void> using eit=typename enable_if<B,T>::type;
template<class T>
inline ostream &_dbg3(ostream &os,eit<is_iterable<T>::neg_value,const T> &x) {
return os<<x;
}
template<class T>
inline ostream &_dbg3(ostream &os,eit<is_iterable<T>::value,const T> &V) {
os<<"{";
bool ff=0;
for(const auto &E:V) _dbg2_5(ff?os<<",":os,E), ff=1;
return os<<"}";
}
template<>
inline ostream &_dbg3<string>(ostream &os,const string &x) {
return _out_str(os,x);
}
template<>
inline ostream &_dbg3<const char *>(ostream &os,const char *const &x) {
return _out_str(os,x);
}
template<class T> inline ostream &_dbg2_5(ostream &os,const T &x) {
return _dbg3<T>(os,x);
}
template<class T,typename... Args> ostream &_dbg2(ostream &os,vector<string>::iterator nm,const T &x,Args&&... args);
inline ostream &_dbg2(ostream &os,vector<string>::iterator) { return os; }
template<typename... Args>
inline ostream &_dbg2(ostream &os,vector<string>::iterator nm,const char *x,Args&&... args) {
return _dbg2(_dbg3<const char *>(os<<" ",x),nm+1,args...);
}
template<class T,typename... Args>
inline ostream &_dbg2(ostream &os,vector<string>::iterator nm,const T &x,Args&&... args) {
return _dbg2(_dbg3<T>(os<<" "<<*nm<<"=",x),nm+1,args...);
}
vector<string> split(string s) {
vector<string> Z;
string z="";
s+=',';
int dep=0;
for(char c:s) {
if(c==',' && !dep) Z.push_back(z),z="";
else z+=c;
if(c=='(') ++dep;
if(c==')') --dep;
}
return Z;
}
template<typename... Args> inline ostream &_dbg1(int ln,const string &nm,Args&&... args) {
auto nms=split(nm);
return _dbg2(cerr<<"L"<<ln<<":",nms.begin(),args...)<<endl;
}
}
#define dbg(...) dbg_ns::_dbg1(__LINE__,#__VA_ARGS__,__VA_ARGS__)
#define sz(x) (int(x.size()))
#define eprintf(...) fprintf(stderr,__VA_ARGS__)
#define fi first
#define se second
#define pb push_back
// END SUBLIME HAX
// #include <bits/extc++.h>
// using namespace __gnu_pbds;
typedef unsigned long long ull;
typedef long long ll;
typedef long double ld; //CARE
typedef complex<ld> pt;
const ld eps=(ld)1e-8;
const ld tau=2*(ld)acosl(-1);
const int inf=1e9+99;
const ll linf=2e18+999;
const int P=1e9+7;
const int N=100<<10;
struct Cubic {
int a,b,c,d;
void read() {
scanf("%d%d%d%d",&d,&c,&b,&a);
}
bool operator<(const Cubic &o) const {
if(a!=o.a) return a<o.a;
if(b!=o.b) return b<o.b;
if(c!=o.c) return c<o.c;
if(d!=o.d) return d<o.d;
return 0;
}
inline ll eval(ll t) const {
ll s=1LL*t*t;
return a*s*t + b*s + c*t + d;
}
} polys[N];
const int C=450;
// const int C=1;
int qs[N],ts[N];
ll zz[N];
void div_conq(int tl,int tr,int cl,int cr) {
if(tl==tr) return;
assert(cl<cr);
int tm=(tl+tr)>>1;
int t=ts[tm];
int cm=cl;
zz[t]=linf;
ll cur;
for(int ci=cl;ci<cr;ci++) if((cur=polys[ci].eval(t)) < zz[t]) {
zz[t]=cur;
cm=ci;
}
div_conq(tl,tm,cl,cm+1);
div_conq(tm+1,tr,cm,cr);
}
void _m() {
int n; scanf("%d",&n);
for(int i=0;i<n;i++) polys[i].read();
sort(polys,polys+n);
reverse(polys,polys+n);
int q; scanf("%d",&q);
for(int i=0;i<q;i++) scanf("%d",qs+i), ts[i]=qs[i];
sort(ts,ts+q);
for(int t=0;t<C;t++) {
zz[t]=linf;
for(int i=0;i<n;i++) zz[t]=min(zz[t],polys[i].eval(t));
}
int lef=0;
for(;lef<q && ts[lef]<C;) ++lef;
div_conq(lef,q,0,n);
for(int i=0;i<q;i++) printf("%lld\n",zz[qs[i]]);
}
int32_t main() {
int T; scanf("%d",&T); for(;T--;) _m();
}
解法4:
和解法3排序类似。但按照\((a_3,a_2,a_1,a_0)\)升序排序(也就是越在后面上升的速度越快,下降速度越快)。
再用线段树维护每一个位置最小的多项式,然后从前往后加入多项式,每次二分到最大的当前多项式最优的位置x,将\([sqrtX,x]\)全部都换成当前多项式(sqrtX为线段树维护区间的左端点)。由于当前多项式下降的比前面所有的多项式都要快,所以从x往前一定是当前的多项式最优。
code:
#include <cstdio>
#include <cassert>
#include <cmath>
#include <cstring>
#include <iostream>
#include <algorithm>
#include <vector>
#include <map>
#include <set>
#include <functional>
#include <stack>
#include <queue>
#include <unordered_map>
#include <tuple>
#define getchar getchar_unlocked
#define putchar putchar_unlocked
#define _rep(_1, _2, _3, _4, name, ...) name
#define rep2(i, n) rep3(i, 0, n)
#define rep3(i, a, b) rep4(i, a, b, 1)
#define rep4(i, a, b, c) for (int i = int(a); i < int(b); i += int(c))
#define rep(...) _rep(__VA_ARGS__, rep4, rep3, rep2, _)(__VA_ARGS__)
using namespace std;
using i64 = long long;
using u8 = unsigned char;
using u32 = unsigned;
using u64 = unsigned long long;
using u128 = __uint128_t;
using f80 = long double;
int get_int() {
int c, n;
while ((c = getchar()) < '0');
n = c - '0';
while ((c = getchar()) >= '0') n = n * 10 + (c - '0');
return n;
}
struct Poly {
Poly() {}
Poly(int a0, int a1, int a2, int a3) {
a[0] = a0, a[1] = a1, a[2] = a2, a[3] = a3;
}
i64 evaluated_at(int x) const {
u64 ret = 0;
for (int i = 3; i >= 0; --i) ret = ret * x + a[i];
return ret;
}
bool operator < (const Poly& rhs) const {
for (int i = 3; i >= 0; --i) {
if (a[i] < rhs.a[i]) return true;
else if (a[i] > rhs.a[i]) return false;
}
return false;
}
bool operator == (const Poly& rhs) const {
rep(i, 4) if (a[i] != rhs.a[i]) return false;
return true;
}
void print() const {
printf("%dx^3 + %dx^2 + %dx + %d\n", a[3], a[2], a[1], a[0]);
}
int a[4];
};
struct Node {
Node() {}
Node(int id) : id(id) {}
int result() const { return id; };
void delay(int v) { id = v; }
void propagate(Node& left, Node& right) {
if (id >= 0) left.id = id, right.id = id, id = -1;
}
int id;
};
template <typename Node>
class SegmentTree {
public:
SegmentTree(const int N) : N(N), nodes(2 * N) {
for (int i = 0; i < N; ++i) nodes[N + i] = Node(0);
for (int i = N - 1; i > 0; --i) nodes[i] = Node(0);
}
void update(int l, int r, int v) {
l += N; r += N;
propagate(l); propagate(r - 1);
for (; l < r; l >>= 1, r >>= 1) {
if (l & 1) nodes[l++].delay(v);
if (r & 1) nodes[--r].delay(v);
}
}
Node query(int l) {
propagate(l += N);
return nodes[l];
}
private:
void propagate(int k) {
for (int s = __lg(k); s >= 1; --s) {
int i = k >> s;
nodes[i].propagate(nodes[2 * i + 0], nodes[2 * i + 1]);
}
}
void merge(int k) {
nodes[k] = nodes[2 * k + 0].merge(nodes[2 * k + 1]);
}
int N;
vector<Node> nodes;
};
void solve() {
const int X_MAX = 1e5;
int T = get_int();
rep(_, T) {
int N = get_int();
vector<Poly> polys(N);
rep(i, N) {
int a0 = get_int(), a1 = get_int(), a2 = get_int(), a3 = get_int();
polys[i] = Poly(a0, a1, a2, a3);
}
sort(polys.begin(), polys.end());
N = unique(polys.begin(), polys.end()) - polys.begin();
polys.resize(N);
const int Q = get_int();
vector<int> xs(Q);
rep(i, Q) {
xs[i] = get_int(); assert(xs[i] <= X_MAX);
}
const int sqrtX = 333;
vector<i64> bests(X_MAX + 1, 4e18);
auto seg = SegmentTree<Node>(X_MAX + 1);
rep(i, N) {
auto& f = polys[i];
rep(x, sqrtX) bests[x] = min(bests[x], f.evaluated_at(x));
if (f.evaluated_at(sqrtX) >= bests[sqrtX]) continue;
bests[sqrtX] = f.evaluated_at(sqrtX);
int pid = seg.query(X_MAX).result();
assert(pid >= 0);
if (f.evaluated_at(X_MAX) <= polys[pid].evaluated_at(X_MAX)) {
seg.update(sqrtX + 1, X_MAX + 1, i);
} else {
int x1 = sqrtX + 1, x2 = X_MAX;
while (x2 - x1 > 1) {
int xm = (x1 + x2) >> 1;
int id = seg.query(xm).result();
if (f.evaluated_at(xm) < polys[id].evaluated_at(xm)) {
x1 = xm;
} else {
x2 = xm;
}
}
seg.update(sqrtX + 1, x2, i);
}
}
rep(x, sqrtX + 1, X_MAX + 1) {
int pid = seg.query(x).result();
bests[x] = min(bests[x], polys[pid].evaluated_at(x));
}
rep(i, Q) printf("%lld\n", bests[xs[i]]);
}
}
int main() {
clock_t beg = clock();
solve();
clock_t end = clock();
fprintf(stderr, "%.3f sec\n", double(end - beg) / CLOCKS_PER_SEC);
return 0;
}