题解:AT_abc355_g [ABC355G] Baseball

最后更新于 2025-08-03 09:26:52
作者
分类 题解

来写一个比二分队列还要短的双老哥做法。

前半部分快速讲一讲:设 $dp_{i,j}$ 表示前 $i$ 个位置放了 $j$ 个点的方案数,则有

$$dp_{i,j}=\min_kdp_{k,j-1}+w(k,i)$$

$w$ 为中间点的距离乘以权值之和。显然 $w$ 满足四边形不等式,从而 $dp_{n,j}$ 关于 $j$ 是凸的,可以使用 wqs 二分套决策单调性解决。问题转化为这个 dp:

$$f_i=\min_j f_j+w(j,i)$$

我们可以对其使用 简易版 LARSCH 算法。具体地,考虑分治,调用 $\operatorname{solve}(l,r)$ 时已经满足:

  • $0\sim l$ 的 $f$ 值和决策点都算对了。
  • $r$ 的 $f$ 值已经从 $0\sim l$ 转移过来了。

设 $mid=(l+r)/2$,那么我们需要做的事情是:

  • 把 $l$ 的决策点到 $r$ 的决策点之间的点转移到 $mid$。
  • 递归 $\operatorname{solve}(l,mid)$。
  • 把 $l+1$ 到 $mid$ 之间的点转移到 $r$。
  • 递归 $\operatorname{solve}(mid+1,r)$。

然后我们就把所有 $f$ 值算对了。正确性是显然的,时间复杂度是多少呢?可以发现,对于分治的每一层,我们都相当于把所有点遍历了一遍,所以复杂度是 $T(n)=2T(n/2)+O(n)=O(n\log n)$ 的。实现出来极短,而且可以处理在线的转移。乘上 wqs 二分,总复杂度为 $O(n\log n\log V)$。

#include <bits/stdc++.h>

using namespace std;

typedef long long ll;

const int MAXN = 5e4 + 10;

int n, m; ll s[MAXN], si[MAXN];

inline 
ll w(int l, int r) {
	if (l > r) return 0;
	if (l == 1 && r == n) return 1e18;
	if (l == 1) return (r + 1) * s[r] - si[r];
	if (r == n) return (si[n] - si[l - 1]) - (l - 1) * (s[n] - s[l - 1]);
	int mid = l + r >> 1;
	return (si[mid] - si[l - 1]) - (l - 1) * (s[mid] - s[l - 1])
		+ (r + 1) * (s[r] - s[mid]) - (si[r] - si[mid]);
}

ll dp[MAXN], X; int cnt[MAXN], p[MAXN];

inline 
void check(int i, int j) {
	ll x = dp[j] + w(j + 1, i - 1) + X;
	if (x < dp[i]) dp[i] = x, cnt[i] = cnt[j] + 1, p[i] = j;
	else if (x == dp[i] && cnt[j] + 1 < cnt[i]) cnt[i] = cnt[j] + 1, p[i] = j;
}

void solve(int l, int r) {
	if (r - l == 1) return ; int mid = l + r >> 1;
	for (int i = p[l]; i <= p[r]; i++) check(mid, i);
	solve(l, mid);
	for (int i = l + 1; i <= mid; i++) check(r, i);
	solve(mid, r);
}

inline 
bool check() {
	for (int i = 1; i <= n + 1; i++) dp[i] = 1e18, cnt[i] = p[i] = 0;
	check(n + 1, 0), solve(0, n + 1);
	return cnt[n + 1] <= m;
}

ll l, r, ans;

int main() {
	scanf("%d%d", &n, &m), m++;
	for (int i = 1; i <= n; i++) scanf("%lld", &s[i]);
	for (int i = 1; i <= n; i++) si[i] = si[i - 1] + i * s[i];
	for (int i = 1; i <= n; i++) s[i] += s[i - 1];
	for (l = 0, r = 1e10; l <= r; ) {
		X = l + r >> 1;
		if (check()) r = X - 1, ans = X;
		else l = X + 1;
	}
	X = ans, check(), printf("%lld", dp[n + 1] - X * m);
}