跳转至

2736. 最大和查询

题目描述

给你两个长度为 n 、下标从 0 开始的整数数组 nums1nums2 ,另给你一个下标从 1 开始的二维数组 queries ,其中 queries[i] = [xi, yi]

对于第 i 个查询,在所有满足 nums1[j] >= xinums2[j] >= yi 的下标 j (0 <= j < n) 中,找出 nums1[j] + nums2[j]最大值 ,如果不存在满足条件的 j 则返回 -1

返回数组 answer其中 answer[i] 是第 i 个查询的答案。

 

示例 1:

输入:nums1 = [4,3,1,2], nums2 = [2,4,9,5], queries = [[4,1],[1,3],[2,5]]
输出:[6,10,7]
解释:
对于第 1 个查询:xi = 4 且 yi = 1 ,可以选择下标 j = 0 ,此时 nums1[j] >= 4 且 nums2[j] >= 1 。nums1[j] + nums2[j] 等于 6 ,可以证明 6 是可以获得的最大值。
对于第 2 个查询:xi = 1 且 yi = 3 ,可以选择下标 j = 2 ,此时 nums1[j] >= 1 且 nums2[j] >= 3 。nums1[j] + nums2[j] 等于 10 ,可以证明 10 是可以获得的最大值。
对于第 3 个查询:xi = 2 且 yi = 5 ,可以选择下标 j = 3 ,此时 nums1[j] >= 2 且 nums2[j] >= 5 。nums1[j] + nums2[j] 等于 7 ,可以证明 7 是可以获得的最大值。
因此,我们返回 [6,10,7] 。

示例 2:

输入:nums1 = [3,2,5], nums2 = [2,3,4], queries = [[4,4],[3,2],[1,1]]
输出:[9,9,9]
解释:对于这个示例,我们可以选择下标 j = 2 ,该下标可以满足每个查询的限制。

示例 3:

输入:nums1 = [2,1], nums2 = [2,3], queries = [[3,3]]
输出:[-1]
解释:示例中的查询 xi = 3 且 yi = 3 。对于每个下标 j ,都只满足 nums1[j] < xi 或者 nums2[j] < yi 。因此,不存在答案。 

 

提示:

  • nums1.length == nums2.length 
  • n == nums1.length 
  • 1 <= n <= 105
  • 1 <= nums1[i], nums2[i] <= 109 
  • 1 <= queries.length <= 105
  • queries[i].length == 2
  • xi == queries[i][1]
  • yi == queries[i][2]
  • 1 <= xi, yi <= 109

解法

方法一:树状数组

本题属于二维偏序问题。

二维偏序是这样一类问题:给定若干个点对 \((a_1, b_1)\), \((a_2, b_2)\), \(\cdots\), \((a_n, b_n)\),并定义某种偏序关系,现在给定点 \((a_i, b_i)\),求满足偏序关系的点对 \((a_j, b_j)\) 中的数量/最值。即:

\[ \left(a_{j}, b_{j}\right) \prec\left(a_{i}, b_{i}\right) \stackrel{\text { def }}{=} a_{j} \lesseqgtr a_{i} \text { and } b_{j} \lesseqgtr b_{i} \]

二维偏序的一般解决方法是排序一维,用数据结构处理第二维(这种数据结构一般是树状数组)。

对于本题,我们可以创建一个数组 \(nums\),其中 \(nums[i]=(nums_1[i], nums_2[i])\),然后对 \(nums\) 按照 \(nums_1\) 从大到小的顺序排序,将查询 \(queries\) 也按照 \(x\) 从大到小的顺序排序。

接下来,遍历每个查询 \(queries[i] = (x, y)\),对于当前查询,我们循环将 \(nums\) 中所有大于等于 \(x\) 的元素的 \(nums_2\) 的值插入到树状数组中,树状数组维护的是离散化后的 \(nums_2\) 的区间中 \(nums_1 + nums_2\) 的最大值。那么我们只需要在树状数组中查询大于等于离散化后的 \(y\) 区间对应的最大值即可。注意,由于树状数组维护的是前缀最大值,所以我们在实现上,可以将 \(nums_2\) 反序插入到树状数组中。

时间复杂度 \(O((n + m) \times \log n + m \times \log m)\),空间复杂度 \(O(n + m)\)。其中 \(n\) 是数组 \(nums\) 的长度,而 \(m\) 是数组 \(queries\) 的长度。

相似题目:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
class BinaryIndexedTree:
    __slots__ = ["n", "c"]

    def __init__(self, n: int):
        self.n = n
        self.c = [-1] * (n + 1)

    def update(self, x: int, v: int):
        while x <= self.n:
            self.c[x] = max(self.c[x], v)
            x += x & -x

    def query(self, x: int) -> int:
        mx = -1
        while x:
            mx = max(mx, self.c[x])
            x -= x & -x
        return mx


class Solution:
    def maximumSumQueries(
        self, nums1: List[int], nums2: List[int], queries: List[List[int]]
    ) -> List[int]:
        nums = sorted(zip(nums1, nums2), key=lambda x: -x[0])
        nums2.sort()
        n, m = len(nums1), len(queries)
        ans = [-1] * m
        j = 0
        tree = BinaryIndexedTree(n)
        for i in sorted(range(m), key=lambda i: -queries[i][0]):
            x, y = queries[i]
            while j < n and nums[j][0] >= x:
                k = n - bisect_left(nums2, nums[j][1])
                tree.update(k, nums[j][0] + nums[j][1])
                j += 1
            k = n - bisect_left(nums2, y)
            ans[i] = tree.query(k)
        return ans
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
class BinaryIndexedTree {
    private int n;
    private int[] c;

    public BinaryIndexedTree(int n) {
        this.n = n;
        c = new int[n + 1];
        Arrays.fill(c, -1);
    }

    public void update(int x, int v) {
        while (x <= n) {
            c[x] = Math.max(c[x], v);
            x += x & -x;
        }
    }

    public int query(int x) {
        int mx = -1;
        while (x > 0) {
            mx = Math.max(mx, c[x]);
            x -= x & -x;
        }
        return mx;
    }
}

class Solution {
    public int[] maximumSumQueries(int[] nums1, int[] nums2, int[][] queries) {
        int n = nums1.length;
        int[][] nums = new int[n][0];
        for (int i = 0; i < n; ++i) {
            nums[i] = new int[] {nums1[i], nums2[i]};
        }
        Arrays.sort(nums, (a, b) -> b[0] - a[0]);
        Arrays.sort(nums2);
        int m = queries.length;
        Integer[] idx = new Integer[m];
        for (int i = 0; i < m; ++i) {
            idx[i] = i;
        }
        Arrays.sort(idx, (i, j) -> queries[j][0] - queries[i][0]);
        int[] ans = new int[m];
        int j = 0;
        BinaryIndexedTree tree = new BinaryIndexedTree(n);
        for (int i : idx) {
            int x = queries[i][0], y = queries[i][1];
            for (; j < n && nums[j][0] >= x; ++j) {
                int k = n - Arrays.binarySearch(nums2, nums[j][1]);
                tree.update(k, nums[j][0] + nums[j][1]);
            }
            int p = Arrays.binarySearch(nums2, y);
            int k = p >= 0 ? n - p : n + p + 1;
            ans[i] = tree.query(k);
        }
        return ans;
    }
}
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
class BinaryIndexedTree {
private:
    int n;
    vector<int> c;

public:
    BinaryIndexedTree(int n) {
        this->n = n;
        c.resize(n + 1, -1);
    }

    void update(int x, int v) {
        while (x <= n) {
            c[x] = max(c[x], v);
            x += x & -x;
        }
    }

    int query(int x) {
        int mx = -1;
        while (x > 0) {
            mx = max(mx, c[x]);
            x -= x & -x;
        }
        return mx;
    }
};

class Solution {
public:
    vector<int> maximumSumQueries(vector<int>& nums1, vector<int>& nums2, vector<vector<int>>& queries) {
        vector<pair<int, int>> nums;
        int n = nums1.size(), m = queries.size();
        for (int i = 0; i < n; ++i) {
            nums.emplace_back(-nums1[i], nums2[i]);
        }
        sort(nums.begin(), nums.end());
        sort(nums2.begin(), nums2.end());
        vector<int> idx(m);
        iota(idx.begin(), idx.end(), 0);
        sort(idx.begin(), idx.end(), [&](int i, int j) { return queries[j][0] < queries[i][0]; });
        vector<int> ans(m);
        int j = 0;
        BinaryIndexedTree tree(n);
        for (int i : idx) {
            int x = queries[i][0], y = queries[i][1];
            for (; j < n && -nums[j].first >= x; ++j) {
                int k = nums2.end() - lower_bound(nums2.begin(), nums2.end(), nums[j].second);
                tree.update(k, -nums[j].first + nums[j].second);
            }
            int k = nums2.end() - lower_bound(nums2.begin(), nums2.end(), y);
            ans[i] = tree.query(k);
        }
        return ans;
    }
};
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
type BinaryIndexedTree struct {
    n int
    c []int
}

func NewBinaryIndexedTree(n int) BinaryIndexedTree {
    c := make([]int, n+1)
    for i := range c {
        c[i] = -1
    }
    return BinaryIndexedTree{n: n, c: c}
}

func (bit *BinaryIndexedTree) update(x, v int) {
    for x <= bit.n {
        bit.c[x] = max(bit.c[x], v)
        x += x & -x
    }
}

func (bit *BinaryIndexedTree) query(x int) int {
    mx := -1
    for x > 0 {
        mx = max(mx, bit.c[x])
        x -= x & -x
    }
    return mx
}

func maximumSumQueries(nums1 []int, nums2 []int, queries [][]int) []int {
    n, m := len(nums1), len(queries)
    nums := make([][2]int, n)
    for i := range nums {
        nums[i] = [2]int{nums1[i], nums2[i]}
    }
    sort.Slice(nums, func(i, j int) bool { return nums[j][0] < nums[i][0] })
    sort.Ints(nums2)
    idx := make([]int, m)
    for i := range idx {
        idx[i] = i
    }
    sort.Slice(idx, func(i, j int) bool { return queries[idx[j]][0] < queries[idx[i]][0] })
    tree := NewBinaryIndexedTree(n)
    ans := make([]int, m)
    j := 0
    for _, i := range idx {
        x, y := queries[i][0], queries[i][1]
        for ; j < n && nums[j][0] >= x; j++ {
            k := n - sort.SearchInts(nums2, nums[j][1])
            tree.update(k, nums[j][0]+nums[j][1])
        }
        k := n - sort.SearchInts(nums2, y)
        ans[i] = tree.query(k)
    }
    return ans
}
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
class BinaryIndexedTree {
    private n: number;
    private c: number[];

    constructor(n: number) {
        this.n = n;
        this.c = Array(n + 1).fill(-1);
    }

    update(x: number, v: number): void {
        while (x <= this.n) {
            this.c[x] = Math.max(this.c[x], v);
            x += x & -x;
        }
    }

    query(x: number): number {
        let mx = -1;
        while (x > 0) {
            mx = Math.max(mx, this.c[x]);
            x -= x & -x;
        }
        return mx;
    }
}

function maximumSumQueries(nums1: number[], nums2: number[], queries: number[][]): number[] {
    const n = nums1.length;
    const m = queries.length;
    const nums: [number, number][] = [];
    for (let i = 0; i < n; ++i) {
        nums.push([nums1[i], nums2[i]]);
    }
    nums.sort((a, b) => b[0] - a[0]);
    nums2.sort((a, b) => a - b);
    const idx: number[] = Array(m)
        .fill(0)
        .map((_, i) => i);
    idx.sort((i, j) => queries[j][0] - queries[i][0]);
    const ans: number[] = Array(m).fill(0);
    let j = 0;
    const search = (x: number) => {
        let [l, r] = [0, n];
        while (l < r) {
            const mid = (l + r) >> 1;
            if (nums2[mid] >= x) {
                r = mid;
            } else {
                l = mid + 1;
            }
        }
        return l;
    };
    const tree = new BinaryIndexedTree(n);
    for (const i of idx) {
        const [x, y] = queries[i];
        for (; j < n && nums[j][0] >= x; ++j) {
            const k = n - search(nums[j][1]);
            tree.update(k, nums[j][0] + nums[j][1]);
        }
        const k = n - search(y);
        ans[i] = tree.query(k);
    }
    return ans;
}

方法二:排序 + 单调栈 + 二分查找

首先,将所有查询按 \(x\) 阈值降序排好,把全部数对按 \(\textit{nums1}[i]\) 降序处理。 迭代处理到第 \(j\) 个查询时,将任何满足 \(\textit{nums1}[i] \geq x_j\) 的数对加入单调栈。

单调栈的数对排序规则是:按 \(\textit{nums2}[i]\) 升序、\(\textit{nums1}[i] + \textit{nums2}[i]\) 降序。

如此保证栈中每个数对的 \(\textit{nums2}[i]\) 更大时, \(\textit{nums1}[i] + \textit{nums2}[i]\) 却更小,隔绝无效候选数对。

对于每个查询 \(query_j\),靠二分查找在栈中找到第一个 \(\textit{nums2}[i] \geq y_j\) 的数对,其对应的 \(\textit{nums1}[i] + \textit{nums2}[i]\) 即为答案。

复杂度解析

\(n\) 是数组 \(nums1\) 的长度,\(m\) 是数组 \(queries\) 的长度。

  • 时间复杂度:\(O((n + m) \times \log n + m \times \log m)\)
  • 空间复杂度:\(O(n + m)\)
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
class Solution:
    def maximumSumQueries(
        self, nums1: list[int], nums2: list[int], queries: list[list[int]]
    ) -> list[int]:
        max_values = [-1] * len(queries)

        queries = [(query[0], query[1], idx) for idx, query in enumerate(queries)]
        # Process queries by descending x threshold and y threshold.
        queries.sort(key=lambda x: (-x[0], -x[1]))

        tuples: list[tuple[int, int]] = []  # Format: (num 1, num 2).
        for num_1, num_2 in zip(nums1, nums2):
            tuples.append((num_1, num_2))

        # Process queries by descending num 1 and num 2.
        # Sort by ascending num 1 and num 2 to pop from the back.
        tuples.sort(key=lambda x: (x[0], x[1]))

        stack: list[tuple[int, int]] = []  # Format: (num 2, sum).

        for query_1, query_2, query_idx in queries:
            while tuples and tuples[-1][0] >= query_1:  # Tuple's num 1 >= x threshold.
                num_1, num_2 = tuples.pop(-1)
                nums_sum = num_1 + num_2

                while stack and stack[-1][0] < num_2 and stack[-1][1] <= nums_sum:
                    stack.pop(-1)  # Stack top isn't better than popped tuple.

                insertion_idx = bisect_left(stack, (num_2, nums_sum))

                if insertion_idx == len(stack):
                    stack.insert(insertion_idx, (num_2, nums_sum))

                elif stack[insertion_idx][1] < nums_sum:
                    stack.insert(insertion_idx, (num_2, nums_sum))

            search_idx = bisect_left(stack, (query_2, 0))
            if search_idx < len(stack):
                max_values[query_idx] = stack[search_idx][1]

        return max_values
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
class Solution {
public:
    vector<int> maximumSumQueries(vector<int>& nums1, vector<int>& nums2, vector<vector<int>>& queries) {
        vector<int> maxValues(queries.size(), -1);

        vector<vector<int>> queriesIndices;
        for (int idx = 0; idx < queries.size(); idx++)
            queriesIndices.push_back({queries[idx][0], queries[idx][1], idx});

        // Process queries by descending x threshold and y threshold.
        // Sort ascendingly and later pop from the back.
        sort(queriesIndices.begin(), queriesIndices.end());

        vector<pair<int, int>> numsPairs; // Format: {num 1, num 2}.
        for (int idx = 0; idx < nums2.size(); idx++)
            numsPairs.push_back({nums1[idx], nums2[idx]});

        // Process queries by descending num 1 and num 2.
        // Sort by ascending num 1 and num 2 to pop from the back.
        sort(numsPairs.begin(), numsPairs.end());

        deque<pair<int, int>> stack; // Format: {num 2, sum}.

        while (!queriesIndices.empty()) {
            int queryOne = queriesIndices.back()[0];
            int queryTwo = queriesIndices.back()[1];
            int queryIdx = queriesIndices.back()[2];
            queriesIndices.pop_back();

            // Pair's num 1 >= x threshold.
            while (!numsPairs.empty() && numsPairs.back().first >= queryOne) {
                auto [numOne, numTwo] = numsPairs.back();
                numsPairs.pop_back();
                int numsSum = numOne + numTwo;

                while (!stack.empty() and stack.back().first < numTwo and stack.back().second <= numsSum)
                    stack.pop_back(); // Stack top isn't better than popped pair.

                pair<int, int> targetPair = {numTwo, numsSum};
                int insertion_idx = lower_bound(stack.begin(), stack.end(), targetPair) - stack.begin();

                if (insertion_idx == stack.size())
                    stack.insert(stack.begin() + insertion_idx, targetPair);

                else if (stack[insertion_idx].second < numsSum)
                    stack.insert(stack.begin() + insertion_idx, targetPair);
            }

            pair<int, int> queryNumTwoPair = {queryTwo, 0};

            int search_idx = lower_bound(stack.begin(), stack.end(), queryNumTwoPair) - stack.begin();
            if (search_idx < stack.size())
                maxValues[queryIdx] = stack[search_idx].second;
        }

        return maxValues;
    }
};

评论