时间限制 空间限制
1000 ms 65536 KB

# 题目描述

河边有 nn 个石头,第 ii 个石头位置为 ii,价值为 aia_i。现在从这 nn 个石头中选出 mm 个石头,要求这 mm 个石头中任意两个石头之间的距离不小于 kk,最大化这 mm 个石头的价值总和。

# 输入格式

第一行三个整数 n,m,kn, m, k,表示石头数量,要选的石头数量,距离限制。1n3×105,1m,kn1 \le n \le 3 \times 10^5, 1 \le m, k \le n

第二行 nn 个整数 aia_i,表示每个石头的价值。ai109|a_i| \le 10^9

数据保证至少有一个选石头的方案。

# 输出格式

一行一个整数,表示答案。

# 输入样例

5 2 2
-2 4 6 7 3

# 输出样例

11

# 题解:wqs 二分 + 动态规划

首先考虑暴力 dp\text {dp},不难设计出 dp[i][j]dp[i][j] 表示考虑前 ii 个石头,满足距离限制的条件下选出 jj 个石头的最大价值,但是很显然时空双爆。

如果去掉第二维的数量限制,我们就能得到一个线性时间复杂度的一维 dp\text{dp} 方程:

dp[i]=max(dp[i1],dp[ik]+a[i])dp[i] = \max (dp[i - 1], dp[i - k] + a[i])

于是我们自然而然地想到了用 wqs\text {wqs} 二分来优化这个过程。下面证明答案具有凹凸性,即:

g(x)g(x1)g(x+1)g(x)g(x) - g(x - 1) \ge g(x + 1) - g(x)

x[1,m1]\forall \, x \in [1, m - 1] 恒成立。

显然对于一个合法的选取方式,从选出的石头中任意选择几个构成的子集,也一定合法。考虑 g(x1)g(x - 1)g(x+1)g(x + 1) 的这 2x2x 个石头,总能找到一种划分方案分出 22xx 个石头,其中我们总能挑出一组 xx 个石头,使得 g(x1)+g(x+1)2g(x)g(x - 1) + g(x + 1) \le 2 g(x) 成立,上凸性成立。

参考代码:

#include <stdio.h>
#include <math.h>
#include <stdbool.h>
typedef long long ll;
int n, m, k, a[300010];
ll dp[300010], cnt[300010];
static inline void Init()
{
    for (int i = 1; i <= n; ++ i)
        dp[i] = cnt[i] = 0;
    return;
}
static inline bool Check(ll mid)
{
    Init();
    for (int i = 1; i <= n; ++ i)
    {
        int last = fmax(0, i - k);
        dp[i] = dp[i - 1];
        cnt[i] = cnt[i - 1];
        if (dp[last] + a[i] - mid > dp[i])
        {
            dp[i] = dp[last] + a[i] - mid;
            cnt[i] = cnt[last] + 1;
        }
        else if (dp[last] + a[i] - mid == dp[i])
        {
            if (cnt[i] > cnt[last] + 1)
                cnt[i] = cnt[last] + 1;
        }
    }
    return m >= cnt[n];
}
int main()
{
    // freopen("J.in", "r", stdin);
    // freopen("J.out", "w", stdout);
    scanf("%d%d%d", &n, &m, &k);
    for (int i = 1; i <= n; ++ i)
        scanf("%d", &a[i]);
    
    ll l = -1e9, r = 1e9;
    while (l < r)
    {
        ll mid = (l + r) >> 1;
        if (Check(mid))
            r = mid;
        else
            l = mid + 1;
    }
    Check(l);
    printf("%lld\n", dp[n] + m * l);
    return 0;
}