数位dp

“聚小而大。”

定义

一种以位数为转移条件的 $dp$ ,一般是统计一个区间内满足一类条件的计数 $dp$ 。一般的初始转移有二维 $dp_{i,j}$ 表示有 $i$ 位且末位为 $j$ 的统计个数。一般满足差分性质(与树状数组类似),即 $[l,r]=[1,r]-[1,l-1]$ 的性质。

写法

迭代式

即使用 for 循环来遍历整个 dp 数组。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
inline int dp(int x)
{
if(!x) return 0;
vector<int>Num; //存储limit的每一位
while(x) Num.push_back(x%10),x/=10;
ll res=0; //存储答案
int last=/*init*/; //限制条件s
for(re int i=Num.size()-1;i>=0;--i)
{
int x=Num[i];
for(re int j=0;j<x;++j) if(/*some condition*/) res+=dp[i+1][j];
if(/*some condition*/) last=x;
else break;
if(!i) ++res;
} //以limit=119547为例,这部分计算的是[100000,119547]的值
for(re int i=Num.size()-2;i>=0;--i)
for(re int j=0;j<=9;++j)
res+=dp[i+1][j];
//同上,这部分计算[1,99999]的总值
return res;
}

记忆化

使用dfs遍历所有的情况

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
ll dfs(int x/*,bool con1,bool con2...bool conn*/)	//一些限制条件,x表示位数
{
if(!x) return /*something*/;
ll &res=dp[x]/*[some][conditions]...[]*/;
if(~res) return res; //该位置已达到
res=0;
for(re int i=0;i<=cntN;++i) //limit范围内
{
if(/*some conditions*/) /*something*/;
else if(/*some conditions*/) /*something*/;
else /*something*/;
res+=dfs(x-1/*,...,...,...,*/); //查找下一位
}
return res;
}

初始化

将所有应算的都算出来, $dp$ 的过程只是在查找 $limit$ 范围之内的。

1
2
3
4
5
6
7
8
inline void init()
{
for(re int i=0;i<=9;++i) dp[1][i]=1; //处理[0,9]
for(re int i=1;i<MAXN;++i)
for(re int j=0;j<=9;++j)
/*for(re int l=?;?;?)*/
/*some conditions*/dp[i][j]=dp[i-1][/*some conditions*/];
}

例题

Luogu P2657 windy数

解释

数位 $dp$ 模板题,用 $dp_{i,j}$ 表示位数为 $i$ ,最高位是 $j$ 的计数,递推出在 $limit$ 内满足的个数,对于相邻的两位,只要其相差不到 $2$ 都可以相加。最终答案便是 dp(r)-dp(l-1)

AC Code

递推写法。

查看代码
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
#include<iostream>
#include<cstdio>
#include<cmath>
#include<cstring>
#include<cstdlib>
#include<algorithm>
#include<ctime>
#include<iomanip>
#include<queue>
#include<stack>
#include<map>
#include<vector>
#define gh() getchar()
#define re register
#define underMax(x,y) ((x)>(y)?(x):(y))
#define underMin(x,y) ((x)<(y)?(x):(y))
typedef long long ll;
using namespace std;
template<class T>
inline void underRead(T &x)
{
x=0;
char ch=gh(),t=0;
while(ch<'0'||ch>'9') t|=ch=='-',ch=gh();
while(ch>='0'&&ch<='9') x=(x<<3)+(x<<1)+(ch^48),ch=gh();
if(t) x=-x;
}
int dp[11][11],l,r;
inline int underAbs(int x)
{
return (x>0?x:-x);
}
inline void underInit()
{
for(re int i=0;i<=9;++i) dp[1][i]=1;
for(re int i=2;i<=10;++i)
for(re int j=0;j<=9;++j)
for(re int l=0;l<=9;++l)
if(underAbs(l-j)>=2)
dp[i][l]+=dp[i-1][j];
}
inline int underDp(int N)
{
if(!N) return 0;
vector<int>Num;
Num.clear();
int n=N,res=0;
while(n) Num.push_back(n%10),n/=10;
int last=-2;
for(re int i=Num.size()-1;i>=0;--i)
{
int x=Num[i];
for(re int j=i==Num.size()-1;j<x;++j)
if(underAbs(j-last)>=2) res+=dp[i+1][j];
if(underAbs(x-last)>=2) last=x;
else break;
if(!i) ++res;
}
for(re int i=Num.size()-2;i>=0;--i)
for(re int j=1;j<=9;++j)
res+=dp[i+1][j];
return res;
}
int main()
{
// freopen("digit-dp.in","r",stdin);
// freopen("digit-dp.out","w",stdout);
underRead(l),underRead(r);
underInit();
printf("%d",underDp(r)-underDp(l-1));
return 0;
}
/*
25 50
*/

Loj #10166 数字游戏

解释

与上一题差不了多少。预处理 $dp_{i,j,k}$ 表示 $i$ 位最高位为 $j$ 且模 $k$ 的计数即可。

AC Code

同样是迭代式写法。

查看代码
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
#include<iostream>
#include<cstdio>
#include<cmath>
#include<cstring>
#include<cstdlib>
#include<algorithm>
#include<ctime>
#include<iomanip>
#include<queue>
#include<stack>
#include<map>
#include<vector>
#define gh() getchar()
#define re register
#define underMax(x,y) ((x)>(y)?(x):(y))
#define underMin(x,y) ((x)<(y)?(x):(y))
typedef long long ll;
using namespace std;
const int MAXN=21;
const int MAXMOD=101;
template<class T>
inline void underRead(T &x)
{
x=0;
char ch=gh(),t=0;
while(ch<'0'||ch>'9') t|=ch=='-',ch=gh();
while(ch>='0'&&ch<='9') x=(x<<3)+(x<<1)+(ch^48),ch=gh();
if(t) x=-x;
}
int l,r,mod;
ll dp[MAXN][10][MAXMOD];
inline int underMod(int x)
{
return (x%mod+mod)%mod;
}
inline void underInit()
{
memset(dp,0,sizeof(dp));
for(re int i=0;i<=9;++i) dp[1][i][i%mod]=1;
for(re int i=2;i<MAXN;++i)
for(re int j=0;j<=9;++j)
for(re int k=0;k<=9;++k)
for(re int l=0;l<mod;++l)
dp[i][j][l]+=dp[i-1][k][underMod(l-j)];
}
inline ll underDp(int N)
{
if(!N) return 1;
vector<int>Num;
while(N) Num.push_back(N%10),N/=10;
ll res=0;
int last=0;
for(re int i=Num.size()-1;i>=0;--i)
{
int x=Num[i];
for(re int j=0;j<x;++j) res+=dp[i+1][j][underMod(-last)];
last=underMod(last+x);
if(!i&&!last) ++res;
}
return res;
}
int main()
{
// freopen("digit-dp.in","r",stdin);
// freopen("digit-dp.out","w",stdout);
while(cin>>l>>r>>mod)
{
underInit();
printf("%lld\n",underDp(r)-underDp(l-1));
}
return 0;
}
/*
1 19 9
*/

Luogu P6669 组合数问题

解释

其实这道题用记忆化要好些一些,一个五维数组记录位数,$n$ 数是否达到上限, $m$ 数是否达到上限,$n$和 $m$ 是否相同过以及 $n$ 是否小于过 $m$ 的情况。这里需要很多组合数的前置知识(比如卢卡斯定理)以优化。这里不多解释,有兴趣可以在洛谷上看题解。

AC Code

查看代码
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
#include<iostream>
#include<cstdio>
#include<cmath>
#include<cstring>
#include<cstdlib>
#include<algorithm>
#include<ctime>
#include<iomanip>
#include<queue>
#include<stack>
#include<map>
#include<vector>
#define gh() getchar()
#define re register
#define underMax(x,y) ((x)>(y)?(x):(y))
#define underMin(x,y) ((x)<(y)?(x):(y))
typedef long long ll;
using namespace std;
const int MAXN=61;
const int MOD=1e9+7;
template<class T>
inline void underRead(T &x)
{
x=0;
char ch=gh(),t=0;
while(ch<'0'||ch>'9') t|=ch=='-',ch=gh();
while(ch>='0'&&ch<='9') x=(x<<3)+(x<<1)+(ch^48),ch=gh();
if(t) x=-x;
}
ll n,m,dp[MAXN][2][2][2][2],N[MAXN],M[MAXN],k;
int T;
ll underDfs(int x,bool qn,bool qm,bool dif,bool sml)
{
if(!x) return sml;
ll &res=dp[x][qn][qm][dif][sml];
if(~res) return res;
res=0;
int cntN=qn?k-1:N[x],cntM=qm?k-1:M[x];
for(re int i=0;i<=cntN;++i)
for(re int j=0;(j<=i||dif)&&j<=cntM;++j)
res=(res+underDfs(x-1,qn|(i<cntN),qm|(j<cntM),dif|(i!=j),sml|(i<j)))%MOD;
return res;
}
int main()
{
// freopen("digit-dp.in","r",stdin);
// freopen("digit-dp.out","w",stdout);
underRead(T),underRead(k);
while(T--)
{
underRead(n),underRead(m);
memset(dp,-1,sizeof(dp));
ll Max=underMax(n,m),size=0;
while(Max) Max/=k,++size;
for(re int i=1;i<=size;++i) N[i]=n%k,n/=k;
for(re int i=1;i<=size;++i) M[i]=m%k,m/=k;
printf("%lld\n",underDfs(size,0,0,0,0));
}
return 0;
}
/*
1 2
3 3
*/