您的位置:首页 > 其它

Educational Codeforces Round 53 (Rated for Div. 2) E. Segment Sum

2018-12-05 22:42 344 查看

应该是我做过最复杂的数位dp了
做之前已经忘了数位dp是啥了,找了一个博客复习下

然后这题不同的是求的是和,而不是个数
所以需要维护更多的元素
dp[i][j]val, count, remain 代表讨论到i位置,已经有mask(j) (0-9分别2进制表示)时他的剩下位数的总值,计数,以及前面全前导0的总值(相当于要分开计算)

#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <cmath>
#include <climits>
#include <cstring>
#include <vector>
#include <list>
#include <queue>
#include <stack>
#include <map>
#include <set>
#include <bitset>
#include <algorithm>
#include <functional>
#include <assert.h>
#include <iomanip>
#include <unordered_map>
using namespace std;
const int N = 2005;
const int INF = 0x3f3f3f3f;
const int MOD = 998244353;

typedef long long ll;

long long l, r;
int k;
int digit[20];

struct Node{
ll val;
ll cnt;
ll remain;
Node(ll a=0, ll b=0, ll c=0):val(a), cnt(b), remain(c) {}
}dp[20][1024];
ll ten[20];

Node dfs(int pos, int mask, ll num, bool lead, bool limit) {
if(pos == -1) return Node(0, 1, 0);
if(!lead && !limit && dp[pos][mask].val != - 1) return dp[pos][mask];

ll ans = 0; ll cnt = 0; ll remain = 0;
int up=limit ? digit[pos]:9;

for(int i = 0; i <= up; ++i) {
if(lead && i==0) {
Node callback = dfs(pos-1, mask, num*10, lead, limit && i==digit[pos]);
ll val = callback.val; ll count = callback.cnt; ll remain2 = callback.remain;
// printf("%d\n", count);
remain = (remain + val + remain2) % MOD;
}
else {
int tt = mask | (1 << i);
int ccnt = 0;
for(int j = 0; j < 10; ++j) {
if( (tt >> j) & 1 )
ccnt ++;
}
if(ccnt <= k) {
Node callback = dfs(pos-1, tt, num * 10 + i, lead && i==0, limit && i==digit[pos]);
ll val = callback.val; ll count = callback.cnt; ll remain2 = callback.remain;
// printf("%d\n", count);
ans = (ans + i * ten[pos] % MOD * count % MOD + val ) % MOD; cnt = (cnt + count) % MOD; remain = (remain + remain2) % MOD;
}
}
}

dp[pos][mask].val = ans; dp[pos][mask].cnt = cnt; dp[pos][mask].remain = remain;
return dp[pos][mask];
}

ll solve(ll x) {
int pos=0;
for(int i = 0; i < 20; ++i) {
for(int j = 0; j < 1024; ++j) {
dp[i][j].val = -1;
dp[i][j].remain = 0;
dp[i][j].cnt = 0;
}
}

while(x) {
digit[pos++]=x%10;
x/=10;
}
Node tt = dfs(pos-1, 0, 0, true, true);
return (tt.val + tt.remain) % MOD;
}

int main() {
ten[0] = 1;
for(int i = 1; i < 19; ++i) ten[i] = ten[i - 1] * 10 % MOD;

// for(int i = 1; i < 19; ++i) printf("%lld ", ten[i]); printf("\n");

while(~scanf("%lld %lld %d", &l, &r, &k)) {
printf("%lld\n", (solve(r) - solve(l - 1) + MOD) % MOD);
}
return 0;
}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: