您的位置:首页 > 其它

「BZOJ3509」「CodeChef」 COUNTARI

2017-12-07 20:55 183 查看

Description

给定一个长度为 N 的数组 A,求有多少对 i,j,k (1≤i<j<k≤N) 满足 Ak−Aj=Aj−Ai 。

Input

第一行一个整数 N 。

接下来一行 N 个数Ai 。

Output

一行一个整数。

Sample Input

10
3 5 3 6 3 4 10 4 5 2


Sample Output

9


HINT

N≤105,Ai≤30000

题解

以下记 W=max{A1,…,AN} 。

这道题的 O(NW) 做法非常显然,即维护每个数左边和右边每种数字出现的次数。但这也是这道题的瓶颈所在。因为很显然这个算法是很难(或者不可能)继续优化下去的,所以很可能会卡在这里(如果你之前没做过类似的题目)。

这样,我们就考虑从一个看起来时间复杂度更坏的算法入手。

由于三个数构成等差数列,所以 2Aj=Ai+Ak 。我们可以对于每一个数维护左边和右边每种数字出现的次数,这个可以做到 O(N) 。然后统计方案数可以用卷积来实现,用 FFT 可以做到 O(Wlog2W) ,于是总复杂度为 O(NWlog2W) 。很明显是更差的。但是这个算法就少了很多局限性。

考虑我们卷积的过程,设多项式 f(x) 表示下标在区间 [L1,R1] 的生成函数(xk 的系数表示数字 k 出现的次数); g(x) 表示下标在区间 [L2,R2] (L1≤R1<L2≤R2) 的生成函数(xk 的系数表示数字 k 出现的次数);卷积 (f∗g)(x) 中 x2k 的系数即为首项下标在 [L1,R1] , 末项下标在 [L2,R2] ,中项为 k 的项数为 3 等差数列数目。很明显,如果我们设 L1=1,R2=N ,我们做一次卷积可以求出首项下标在 [1,R1] , 末项下标在 [L2,N] ,中项下标在区间 (R1,L2) 的等差数列个数。

这样就很明显可以分块来做。

我们把整个区间分成 K 块,枚举每一个数为中项,我们讨论下列三种情况:

1.首项和末项都在块内,可以用刚开始的做法,但是如果块内元素比较小,可以枚举首项下标,这样单块复杂度 O((NK)2) 。

2.首项和末项有一个在块内,我们可以枚举在块内的那一项,同样可以做到单块 O((NK)2) 。

3.首项和末项都不在块内,那么我们就需要用卷积了。一次卷积即可求出块内所有元素为中项的方案数。单块复杂度 O(Wlog2W) 。

那么总的复杂度就是 O(N2K+KWlog2W) 。

由均值不等式, K=NWlogW√ 时复杂度最低,为 O(NWlog2W−−−−−−−√) 。

但是事实上,由于常数等原因,块的大小需要调大大约 10 倍,约 2000 左右时最快。由于此题卡常严重,需要手写复数类。

My Code

/**************************************************************
Problem: 3509
User: infinityedge
Language: C++
Result: Accepted
Time:34664 ms
Memory:5784 kb
****************************************************************/

#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <cmath>
#include <complex>

#define MAXN 65536
#define pi acos(-1)
using namespace std;
typedef long long ll;
struct E{
long double real, imag;
E(long double real = 0, long double imag = 0) : real(real), imag(imag) { }
inline friend E operator + (E &a, E &b)
{ return E(a.real + b.real, a.imag + b.imag); }
inline friend E operator - (E &a, E &b)
{ return E(a.real - b.real, a.imag - b.imag); }
inline friend E operator * (E &a, E &b)
{ return E(a.real * b.real - a.imag * b.imag , a.imag * b.real + a.real * b.imag); }
inline friend void swap(E &a, E &b)
{ E c = a; a = b; b = c; }
};

E a[MAXN + 1], b[MAXN + 1];

void bit_reverse(int n, E* r){
for(int i = 0, j = 0; i < n; i ++){
if(i > j) swap(r[i], r[j]);
for(int l = n >> 1; (j ^= l) < l; l >>= 1);
}
}

void fft(int n, E* r, int f){
bit_reverse(n, r);
for(int i = 2; i <= n; i <<= 1){
int m = i >> 1;
for(int j = 0; j < n; j += i){
E w(1, 0), wn(cos(2 * pi / i), f * sin(2 * pi / i));
for(int k = 0; k < m; k ++){
E z = r[j + m + k] * w;
r[j + m + k] = r[j + k] - z;
r[j + k] = r[j + k] + z;
w = w * wn;
}
}
}
if(f == -1){
E ww = E(1.0 / n, 0);
for(int i = 0; i < n; i ++) r[i] = r[i] * ww;
}
}

int n, k, m;
int d[100005], pos[100005], l[1005], r[1005];
ll ans;
int vis[30005];
int tmpl[MAXN], tmpr[MAXN];
void solve(int x){
for(int i = l[x]; i <= r[x]; i ++){
tmpr[d[i]]++;
}
for(int i = l[x]; i <= r[x]; i ++){
for(int j = i + 1; j <= r[x]; j ++){
int dk = d[i] + d[i] - d[j];
ans += tmpl[dk];
}
tmpl[d[i]]++;
}
for(int i = l[x]; i <= r[x]; i ++){
tmpr[d[i]] = tmpl[d[i]] = 0;
}
}
int N = 1;
void solsub(int x){

for(int i = l[x]; i <= r[x]; i ++){
for(int j = i + 1; j <= r[x]; j ++){
int dk = d[i] + d[i] - d[j];
if(dk >= 0) ans += tmpl[dk];
dk = d[j] + d[j] - d[i];
if(dk >= 0) ans += tmpr[dk];
}
}
if(x == 1 || x == m) return;
for(int i = 0; i <= N; i ++){
a[i] = b[i] = E(0, 0);
}
for(int i = 0; i <= N; i ++){
a[i] = E(tmpl[i], 0);
b[i] = E(tmpr[i], 0);
}
fft(N, a, 1); fft(N, b, 1);
for(int i = 0; i <= N; i ++){
a[i] = a[i] * b[i];
}
fft(N, a, -1);
for(int i = l[x]; i <= r[x]; i ++){
ans = ans + (ll)(a[2 * d[i]].real + 0.1);
}
}
void solve2(){
int mx = 0;
for(int i = 1; i <= n; i ++){
mx = max(d[i], mx);
}
mx = mx * 2 + 1;

while(N < mx) N = N << 1;
for(int i = 1; i <= n; i ++){
tmpr[d[i]] ++;
}
for(int i = 1; i <= m; i ++){
for(int j = l[i]; j <= r[i]; j ++){
tmpr[d[j]] --;
}
solsub(i);
for(int j = l[i]; j <= r[i]; j ++){
tmpl[d[j]] ++;
}
}
}

int main(){
scanf("%d", &n); k = 1823;
if(n < 1823) k = 1823;
for(int i = 1; i <= n; i ++){
scanf("%d", &d[i]);
}
for(int i = 1; i <= n; i ++){
pos[i] = (n - 1) / k + 1;
}
m = pos
;
for(int i = 1; i <= m; i ++){
l[i] = (i - 1) * k + 1;
r[i] = i * k;
}
r[m] = n;
for(int i = 1; i <= m; i ++){
solve(i);
}
solve2();
printf("%lld\n", ans);
return 0;
}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签:  FFT 分块 生成函数