您的位置:首页 > 其它

HDU-4471 Homework 矩阵运算上的优化

2013-04-28 19:53 495 查看
题意:给定一个函数定义如下:





对于q个点满足:



给定f[1]-f
的数值,然后存在q个特殊的点,其于前面的关联的项数特殊,系数特殊,当然位置也特殊。现在要求f
的值。

解法:如果题目中没有强调q个特殊点的话,那么可以使用矩阵快速幂搞出来。鉴于只有最多100个特殊点,我们可以选择分段进行处理,对每一个空隙进行一次矩阵快速运算,然后对于特殊点单独做一次。这里又有一个地方要特别注意:那就是q个点中有位置大于n的点。

当然仅仅是一般的矩阵快速幂这题的复杂度将达到O(q*log(n)*L^3),结合多组数据这样会超时,一个优化就是使用一个列向量去依次乘以若干个矩阵,那么每一次相乘的复杂度就变成了L^2,那么最后的复杂度就变成了O(q*log(n)*L^2)。

代码如下:

#include <cstdlib>
#include <cstring>
#include <cstdio>
#include <iostream>
#include <algorithm>
using namespace std;

const int MOD = int(1e9)+7;
const int MAXN = 105;
int N, M, Q;

struct Matrix {
int r, c;
int a[MAXN][MAXN];
void init(int rr, int cc) {
r = rr, c = cc;
memset(a, 0, sizeof (a));
}
void show() {
for (int i = 1; i <= r; ++i) {
for (int j = 1; j <= c; ++j) {
printf("%d ", a[i][j]);
}
}
puts("");
}
};

Matrix operator * (const Matrix & x, const Matrix & y) {
Matrix ret;
ret.init(x.r, y.c);
//    printf("__%d %d %d\n", x.r, x.c, y.r);
for (int k = 1; k <= x.c; ++k) {
for (int i = 1; i <= ret.r; ++i) {
if (!x.a[i][k]) continue;
for (int j = 1; j <= ret.c; ++j) {
if (!y.a[k][j]) continue;
ret.a[i][j] = (1LL*x.a[i][k]*y.a[k][j]+ret.a[i][j])%MOD;
}
}
}
return ret;
}

Matrix s, pw[35], c, ci[105];
int t, xi[105], ti[105], pos[105];

bool cmp(int a, int b) {
return xi[a] < xi[b];
}

void getpw() {
pw[0] = c;
for (int i = 1; (1 << i) <= N; ++i) {
pw[i] = pw[i-1] * pw[i-1];
}
}

void cal(int b) {
for (int i = 0; i < 31; ++i) {
if (b & (1 << i)) {
s = pw[i] * s;
}
}
}

void AC() {
int L = t;
for (int i = 1; i <= Q; ++i) {
L = max(L, ti[i]);
} // 得到最长的线性关系
s.r = L, s.c = 1;
c.r = c.c = L;

for (int i = 2; i <= L; ++i) {
c.a[i][i-1] = 1;
}
for (int i = 1; i <= Q; ++i) {
ci[i].r = ci[i].c = L;
for (int j = 2; j <= L; ++j) {
ci[i].a[j][j-1] = 1;
}
}
getpw();
sort(pos+1, pos+1+Q, cmp);
int last = M;
for (int i = 1; i <= Q; ++i) {
int p = pos[i];
if (xi[p] > N || xi[p] <= last) continue;
cal(xi[p]-last-1);
s = ci[p] * s;
last = xi[p];
}
cal(N-last);
printf("%d\n", s.a[1][1]);
}

int main() {
int ca = 0;
while (scanf("%d %d %d", &N, &M, &Q) != EOF) {
memset(s.a, 0, sizeof (s.a));
for (int i = M; i >= 1; --i) {
scanf("%d", &s.a[i][1]);
}
scanf("%d", &t);
memset(c.a, 0, sizeof (c.a));
for (int i = 1; i <= t; ++i) {
scanf("%d", &c.a[1][i]);
}
for (int i = 1; i <= Q; ++i) {
pos[i] = i;
scanf("%d %d", &xi[i], &ti[i]);
memset(ci[i].a, 0, sizeof (ci[i].a));
for (int j = 1; j <= ti[i]; ++j) {
scanf("%d", &ci[i].a[1][j]);
}
}
printf("Case %d: ", ++ca);
AC();
}
return 0;
}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: