
3915. Maximum Sum of Alternating Subsequence With Distance at Least K

Description

You are given an integer array nums of length n and an integer k.

Pick a subsequence with indices \(0 \leq i_1 < i_2 < \cdots < i_m < n\) such that:

  • For every \(1 \leq t < m\), \(i_{t+1} - i_t \geq k\).
  • The selected values form a strictly alternating sequence. In other words, either:
    • \(\text{nums}[i_1] < \text{nums}[i_2] > \text{nums}[i_3] < \cdots\), or
    • \(\text{nums}[i_1] > \text{nums}[i_2] < \text{nums}[i_3] > \cdots\)

A subsequence of length 1 is also considered strictly alternating. The score of a valid subsequence is the sum of its selected values.

Return an integer denoting the maximum possible score of a valid subsequence.
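As an illustration of the two conditions, the following minimal Python sketch checks a candidate index list and, for tiny arrays only, finds the best score by trying every subset. The helper names `is_valid` and `brute_force` are hypothetical, and the code assumes k >= 1, as the constraints guarantee.

```python
from itertools import combinations


def is_valid(nums: list[int], k: int, idx: list[int]) -> bool:
    """Hypothetical helper: check the distance and strict-alternation rules."""
    if not idx:
        return False
    if idx[0] < 0 or idx[-1] >= len(nums):
        return False
    # Since k >= 1, a gap of at least k also forces strictly increasing indices.
    if any(b - a < k for a, b in zip(idx, idx[1:])):
        return False
    vals = [nums[i] for i in idx]
    # +1 for a rise, -1 for a fall, 0 for equal neighbours.
    signs = [(b > a) - (b < a) for a, b in zip(vals, vals[1:])]
    if 0 in signs:
        return False  # equal neighbours are never strictly alternating
    # Strict alternation: the direction flips at every step.
    return all(s != t for s, t in zip(signs, signs[1:]))


def brute_force(nums: list[int], k: int) -> int:
    """Hypothetical reference: exponential search over all index subsets."""
    best = 0
    n = len(nums)
    for m in range(1, n + 1):
        for idx in combinations(range(n), m):
            if is_valid(nums, k, list(idx)):
                best = max(best, sum(nums[i] for i in idx))
    return best
```

For example, `brute_force([3, 5, 4, 2, 4], 1)` returns 14, matching Example 2. The search is exponential in n, so it is only useful as a reference on very small inputs.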

 

Example 1:

Input: nums = [5,4,2], k = 2

Output: 7

Explanation:

An optimal choice is indices [0, 2], which gives values [5, 2].

  • The distance condition holds because 2 - 0 = 2 >= k.
  • The values are strictly alternating because 5 > 2.

The score is 5 + 2 = 7.

Example 2:

Input: nums = [3,5,4,2,4], k = 1

Output: 14

Explanation:

An optimal choice is indices [0, 1, 3, 4], which gives values [3, 5, 2, 4].

  • The distance condition holds because each pair of consecutive chosen indices differs by at least k = 1.
  • The values are strictly alternating since 3 < 5 > 2 < 4.

The score is 3 + 5 + 2 + 4 = 14.

Example 3:

Input: nums = [5], k = 1

Output: 5

Explanation:

The only valid subsequence is [5]. A subsequence with 1 element is always strictly alternating, so the score is 5.

 

Constraints:

  • 1 <= n == nums.length <= 10^5
  • 1 <= nums[i] <= 10^5
  • 1 <= k <= n

Solutions

Solution 1: Dynamic Programming + Binary Indexed Tree

State Definition

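Let \(f[i][0]\) denote the maximum sum of a valid subsequence ending at index \(i\) in which the last element is a valley (its predecessor, if any, is larger, so the next element must be larger), and let \(f[i][1]\) denote the maximum sum where the last element is a peak (its predecessor, if any, is smaller, so the next element must be smaller). A single-element subsequence counts as both, so \(f[i][0] \geq \text{nums}[i]\) and \(f[i][1] \geq \text{nums}[i]\).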

Transitions

When transitioning, we enumerate a predecessor index \(j\) satisfying \(j \leq i - k\):

  • State \(f[i][0]\) (valley): transitions from \(f[j][1]\), requiring \(\text{nums}[j] > \text{nums}[i]\), i.e., query the maximum \(f[\cdot][1]\) over the value range \((\text{nums}[i],\ +\infty)\):
\[f[i][0] = \text{nums}[i] + \max\!\left(0,\ \max_{\substack{j \leq i-k \\ \text{nums}[j] > \text{nums}[i]}} f[j][1]\right)\]
  • State \(f[i][1]\) (peak): transitions from \(f[j][0]\), requiring \(\text{nums}[j] < \text{nums}[i]\), i.e., query the maximum \(f[\cdot][0]\) over the value range \([1,\ \text{nums}[i]-1]\):
\[f[i][1] = \text{nums}[i] + \max\!\left(0,\ \max_{\substack{j \leq i-k \\ \text{nums}[j] < \text{nums}[i]}} f[j][0]\right)\]

The final answer is \(\max_{0 \leq i < n}\max(f[i][0],\ f[i][1])\).
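Evaluated literally, the recurrences give a simple \(O(n^2)\) algorithm that scans every predecessor \(j \leq i - k\). The Python sketch below (the function name `max_alternating_sum_quadratic` is hypothetical) does exactly that; it is too slow for the largest inputs but is a convenient reference for checking an optimized version on small arrays.

```python
def max_alternating_sum_quadratic(nums: list[int], k: int) -> int:
    """Hypothetical reference: evaluate f[i][0] (valley) and f[i][1] (peak) directly."""
    n = len(nums)
    f = [[0, 0] for _ in range(n)]
    for i in range(n):
        best_peak = 0    # max f[j][1] over j <= i - k with nums[j] > nums[i]
        best_valley = 0  # max f[j][0] over j <= i - k with nums[j] < nums[i]
        for j in range(i - k + 1):  # empty when i < k
            if nums[j] > nums[i]:
                best_peak = max(best_peak, f[j][1])
            elif nums[j] < nums[i]:
                best_valley = max(best_valley, f[j][0])
        f[i][0] = nums[i] + best_peak    # nums[i] ends as a valley
        f[i][1] = nums[i] + best_valley  # nums[i] ends as a peak
    return max(max(row) for row in f)
```

On Example 2 (nums = [3, 5, 4, 2, 4], k = 1) the last index gets \(f[4][1] = 4 + f[3][0] = 4 + 10 = 14\), which is the answer.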

Optimization

The transitions involve dynamic prefix/suffix maximum queries over a value domain, which can be maintained efficiently with two Binary Indexed Trees (BITs):

  • BIT \(\text{bit}_0\): indexed by value, maintains the prefix maximum of \(f[\cdot][0]\), used to query cases where \(\text{nums}[j] < \text{nums}[i]\).
  • BIT \(\text{bit}_1\): indexed by \(M + 1 - \text{val}\) (reversed, where \(M = \max(\text{nums})\)), maintains the prefix maximum of \(f[\cdot][1]\), equivalent to a suffix maximum over the value domain, used to query cases where \(\text{nums}[j] > \text{nums}[i]\).
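The two BITs differ only in how a value is mapped to a tree position. Below is a minimal sketch, assuming values already lie in \([1, M]\) and that stored scores are positive, so 0 can stand for "no predecessor"; the class name `MaxBIT` and the sample values and scores are hypothetical.

```python
class MaxBIT:
    """Hypothetical Fenwick tree over positions 1..size answering prefix maxima."""

    def __init__(self, size: int) -> None:
        self.size = size
        self.tree = [0] * (size + 1)  # 0 means "nothing inserted yet"

    def update(self, pos: int, val: int) -> None:
        # Raise every node whose range covers pos.
        while pos <= self.size:
            self.tree[pos] = max(self.tree[pos], val)
            pos += pos & -pos

    def query(self, pos: int) -> int:
        # Maximum over positions 1..pos (0 when pos <= 0).
        res = 0
        while pos > 0:
            res = max(res, self.tree[pos])
            pos -= pos & -pos
        return res


M = 6
bit0 = MaxBIT(M)  # indexed by value: a prefix answers "nums[j] < v"
bit1 = MaxBIT(M)  # indexed by M + 1 - value: a prefix answers "nums[j] > v"

# Hypothetical (value, score) pairs standing in for inserted states.
for value, score in [(2, 10), (5, 7), (4, 3)]:
    bit0.update(value, score)
    bit1.update(M + 1 - value, score)

v = 4
smaller = bit0.query(v - 1)  # best score among inserted values < 4 (value 2)
larger = bit1.query(M - v)   # best score among inserted values > 4 (value 5)
```

Reversing the index turns the suffix range \((v, M]\) into the prefix \([1, M - v]\), so one prefix-max structure serves both kinds of query.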

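To ensure that only indices \(j \leq i - k\) participate in transitions, we use a sliding pointer: just before computing \(f[i][\cdot]\), the states of index \(i - k\) (when \(i \geq k\)) are inserted into the BITs, so at that moment the trees contain exactly the indices \(0, 1, \ldots, i - k\).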
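The sliding-pointer schedule can be checked on its own. This hypothetical helper, `visible_predecessors`, records for each \(i\) which indices have already been inserted when \(f[i]\) is computed, using the order described above (insert \(i - k\) first, then compute \(f[i]\)); every snapshot should be exactly the indices 0 through \(i - k\).

```python
def visible_predecessors(n: int, k: int) -> list[list[int]]:
    """Hypothetical helper: indices present in the BITs when f[i] is computed."""
    inserted: list[int] = []
    snapshots: list[list[int]] = []
    for i in range(n):
        if i - k >= 0:
            inserted.append(i - k)  # index i - k becomes eligible at step i
        snapshots.append(list(inserted))  # state seen while computing f[i]
    return snapshots
```

For n = 5 and k = 2 the snapshots are `[[], [], [0], [0, 1], [0, 1, 2]]`: each index is inserted exactly once, which is what keeps the total work at \(O(n \log M)\).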

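The time complexity is \(O(n \log M)\) and the space complexity is \(O(n + M)\), where \(n\) is the length of the array and \(M = \max(\text{nums})\). Coordinate-compressing the values first, as several of the implementations below do, replaces \(M\) with the number of distinct values, which is at most \(n\).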
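Only the relative order of the values matters to the comparisons, which is why the Python, C++ and Go implementations below first map each value to its rank among the distinct values. A minimal sketch of that step using the standard `bisect` module (the function name `compress` is hypothetical):

```python
from bisect import bisect_left


def compress(nums: list[int]) -> tuple[list[int], int]:
    """Hypothetical helper: map each value to its 1-based rank among distinct values."""
    distinct = sorted(set(nums))
    ranks = [bisect_left(distinct, x) + 1 for x in nums]
    return ranks, len(distinct)  # every rank lies in 1..len(distinct)
```

With ranks in place of raw values, both BITs need only `len(distinct) + 1` slots, and "larger value" and "smaller value" comparisons are unchanged because ranking preserves order.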

Python3
class FenwickTree:
    """Fenwick tree (BIT) over positions 1..n that maintains prefix maxima."""

    def __init__(self, n: int):
        self.n = n
        self.tree = [0] * (n + 1)

    def update(self, index: int, val: int) -> None:
        while index <= self.n:
            self.tree[index] = max(self.tree[index], val)
            index += index & (-index)  # move to the next node covering index

    def preMax(self, pos: int) -> int:
        # Maximum over positions 1..pos (0 if pos < 1)
        ans = 0
        while pos >= 1:
            ans = max(ans, self.tree[pos])
            pos -= pos & (-pos)
        return ans


class Solution:
    def maxAlternatingSum(self, nums: list[int], k: int) -> int:
        stl = sorted(set(nums))  # distinct values of nums, ascending
        rank = {
            v: i + 1 for i, v in enumerate(stl)
        }  # map each value of nums to its 1-based position in stl
        fwt0 = FenwickTree(len(stl))  # ascending order: best dp[j][0] by value
        fwt1 = FenwickTree(len(stl))  # descending order: best dp[j][1] by value

        n = len(nums)
        dp = [[0, 0] for _ in range(n)]  # dp[i][0]: valley, dp[i][1]: peak
        res = nums[0]
        for i in range(n):
            dp[i][0] = dp[i][1] = nums[i]
            if i >= k:
                indx = rank[nums[i]]  # position of nums[i] in stl
                dp[i][1] = max(
                    dp[i][1], fwt0.preMax(indx - 1) + nums[i]
                )  # positions 1..indx-1 hold the values smaller than nums[i]
                dp[i][0] = max(
                    dp[i][0], fwt1.preMax(len(stl) - indx) + nums[i]
                )  # reversed positions 1..len(stl)-indx hold the values larger than nums[i]

            # Index i-k+1 becomes a legal predecessor for index i+1, so insert it now.
            if i - k + 1 >= 0:
                indx = rank[nums[i - k + 1]]
                fwt0.update(indx, dp[i - k + 1][0])  # insert into the ascending tree
                fwt1.update(
                    len(stl) - indx + 1, dp[i - k + 1][1]
                )  # insert into the descending (reversed) tree

            res = max(res, dp[i][0], dp[i][1])  # update the answer

        return res
Java
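import java.util.Arrays;

class Solution {
    public long maxAlternatingSum(int[] nums, int k) {
        long maxSum = 0;
        int n = nums.length;
        int m = Arrays.stream(nums).max().getAsInt();
        // dp[i][0]: nums[i] ends the subsequence as a peak (previous value smaller);
        // dp[i][1]: nums[i] ends it as a valley (previous value larger).
        // The roles of 0/1 are swapped relative to f in the text above.
        long[][] dp = new long[n][2];
        // sts[t] is a max segment tree over values 0..m holding dp[j][t] for inserted j.
        SegmentTree[] sts = new SegmentTree[2];
        for (int j = 0; j < 2; j++) {
            sts[j] = new SegmentTree(m + 1);
        }
        for (int i = 0; i < n; i++) {
            if (i >= k) {
                // Index i - k is now far enough from i to be a predecessor.
                // update() overwrites the leaf; this is safe because a later index with
                // the same value has a superset of predecessors, so its dp is never smaller.
                sts[0].update(nums[i - k], dp[i - k][0]);
                sts[1].update(nums[i - k], dp[i - k][1]);
            }
            // Peak: extend the best valley with a smaller value (0 if none).
            dp[i][0] = sts[1].getMax(0, nums[i] - 1) + nums[i];
            // Valley: extend the best peak with a larger value (0 if none).
            dp[i][1] = sts[0].getMax(nums[i] + 1, m) + nums[i];
            maxSum = Math.max(maxSum, Math.max(dp[i][0], dp[i][1]));
        }
        return maxSum;
    }
}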

class SegmentTree {
    private int n;
    private long[] tree;

    public SegmentTree(int n) {
        this.n = n;
        this.tree = new long[n * 4];
    }

    public long getMax(int start, int end) {
        return getMax(start, end, 0, 0, n - 1);
    }

    public void update(int index, long value) {
        update(index, value, 0, 0, n - 1);
    }

    private long getMax(int rangeStart, int rangeEnd, int treeIndex, int treeStart, int treeEnd) {
        if (rangeStart > rangeEnd) {
            return 0;
        }
        if (rangeStart == treeStart && rangeEnd == treeEnd) {
            return tree[treeIndex];
        }
        int mid = treeStart + (treeEnd - treeStart) / 2;
        if (rangeEnd <= mid) {
            return getMax(rangeStart, rangeEnd, treeIndex * 2 + 1, treeStart, mid);
        } else if (rangeStart > mid) {
            return getMax(rangeStart, rangeEnd, treeIndex * 2 + 2, mid + 1, treeEnd);
        } else {
            return Math.max(getMax(rangeStart, mid, treeIndex * 2 + 1, treeStart, mid), getMax(mid + 1, rangeEnd, treeIndex * 2 + 2, mid + 1, treeEnd));
        }
    }

    private void update(int rangeIndex, long value, int treeIndex, int start, int end) {
        if (start == end) {
            tree[treeIndex] = value;
            return;
        }
        int mid = start + (end - start) / 2;
        if (rangeIndex <= mid) {
            update(rangeIndex, value, treeIndex * 2 + 1, start, mid);
        } else {
            update(rangeIndex, value, treeIndex * 2 + 2, mid + 1, end);
        }
        tree[treeIndex] = Math.max(tree[treeIndex * 2 + 1], tree[treeIndex * 2 + 2]);
    }
}
C++
#include <algorithm>
#include <map>
#include <vector>
using namespace std;

class Solution {
public:
    long long maxAlternatingSum(vector<int>& nums, int K) {
        int n = nums.size();

        // Coordinate compression: idx[i] is the 1-based rank of nums[i].
        // std::vector replaces variable-length arrays, which are a GCC extension.
        vector<int> idx(n);
        map<int, int> mp;
        for (int x : nums) mp[x] = 1;
        int m = 0;
        for (auto &p : mp) p.second = ++m;
        for (int i = 0; i < n; i++) idx[i] = mp[nums[i]];

        const long long INF = 1e18;
        // tree[0]: prefix max of f[.][0] by rank; tree[1]: prefix max of f[.][1] by reversed rank.
        vector<vector<long long>> tree(2, vector<long long>(m + 1, -INF));

        auto lb = [](int x) { return x & (-x); };

        auto update = [&](vector<long long> &t, int pos, long long val) {
            for (; pos <= m; pos += lb(pos)) t[pos] = max(t[pos], val);
        };

        auto query = [&](const vector<long long> &t, int pos) {
            long long ret = -INF;
            for (; pos; pos -= lb(pos)) ret = max(ret, t[pos]);
            return ret;
        };

        long long ans = 0;
        // f is 1-indexed: f[i] describes nums[i - 1]; f[i][0] = valley, f[i][1] = peak.
        vector<vector<long long>> f(n + 1, vector<long long>(2, -INF));
        for (int i = 1, j = 1; i <= n; i++) {
            // Sliding pointer: insert every j with i - j >= K.
            while (i - j >= K) {
                update(tree[0], idx[j - 1], f[j][0]);
                update(tree[1], m + 1 - idx[j - 1], f[j][1]);
                j++;
            }
            // Valley: best peak with a larger value (reversed ranks 1..m - idx).
            f[i][0] = max(0LL, query(tree[1], m - idx[i - 1])) + nums[i - 1];
            // Peak: best valley with a smaller value (ranks 1..idx - 1).
            f[i][1] = max(0LL, query(tree[0], idx[i - 1] - 1)) + nums[i - 1];
            ans = max({ans, f[i][0], f[i][1]});
        }
        return ans;
    }
};
Go
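import (
    "slices"
    "sort"
)

// fenwick is a BIT over positions 1..len(f)-1 that maintains prefix maxima.
type fenwick []int64

func (f fenwick) update(i int, val int64) {
    for ; i < len(f); i += i & -i {
        f[i] = max(f[i], val)
    }
}

// preMax returns the maximum over positions 1..i (0 if i <= 0).
func (f fenwick) preMax(i int) (res int64) {
    for ; i > 0; i &= i - 1 {
        res = max(res, f[i])
    }
    return
}

// Note: nums is overwritten in place with the 0-based rank of each value.
// The built-in max and the slices package require Go 1.21 or later.
func maxAlternatingSum(nums []int, k int) (ans int64) {
    sorted := slices.Clone(nums)
    slices.Sort(sorted)
    sorted = slices.Compact(sorted) // distinct values, ascending

    n := len(nums)
    fInc := make([]int64, n) // last step was a rise: nums[i] is a peak
    fDec := make([]int64, n) // last step was a fall: nums[i] is a valley

    m := len(sorted)
    inc := make(fenwick, m+1) // fInc by reversed rank m-j: a prefix covers larger values
    dec := make(fenwick, m+1) // fDec by rank j+1: a prefix covers smaller values

    for i, x := range nums {
        if i >= k {
            j := nums[i-k] // already replaced by its rank
            inc.update(m-j, fInc[i-k])
            dec.update(j+1, fDec[i-k])
        }

        j := sort.SearchInts(sorted, x)
        nums[i] = j

        fInc[i] = dec.preMax(j) + int64(x)     // ranks 0..j-1: smaller values
        fDec[i] = inc.preMax(m-1-j) + int64(x) // ranks j+1..m-1: larger values
        ans = max(ans, fInc[i], fDec[i])
    }

    return
}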
