题目描述

给定一个整数数组 nums,按要求返回一个新数组 counts。数组 counts 有该性质: counts[i] 的值是  nums[i] 右侧小于 nums[i] 的元素的数量。

示例:

输入: [5,2,6,1]
输出: [2,1,1,0] 
解释:
5 的右侧有 2 个更小的元素 (2 和 1).
2 的右侧仅有 1 个更小的元素 (1).
6 的右侧有 1 个更小的元素 (1).
1 的右侧有 0 个更小的元素.

https://leetcode-cn.com/problems/count-of-smaller-numbers-after-self/

解法1

解法1采用Merge sort。之所以采用Sort是因为第i个数右边小于nums[i]的数字在排序后都会跑到nums[i]的左边,我们只需要在排序的过程中跟踪那些从nums[i]右边跑到nums[i]左边元素的个数,就能够求得答案。

我们举例nums = [5,4,3,1],解释使用Merge sort计算计算过程。

图1 Merge sort调用栈

在执行Merge sort过程中,Sort([5, 4, 3, 1])会被拆分为{Sort([5, 4]), Sort([3, 1])},然后Sort([5, 4])又会被拆分为{Sort([5]), Sort([4])}。当发现对独立的元素排序(如Sort([5])),它自然是有序的,将不会执行任何操作直接返回。为了简化图1,我并没有绘制对于单个元素sort的调用过程。

当Sort([5])与Sort([4])返回后,将两个有序的列表[5]与[4]进行Merge。而在Merge的过程中,我们就可以“跟踪”那些从num[i]后面跑到num[i]前面的元素个数,这也正是我们要求的counts[i]。

需要注意,我们对nums直接排序会破坏nums的顺序,而题目要求我们按照nums原有的顺序输出到counts。为了不破坏nums,我们不能对nums排序,而是使用索引数组originIndices记录nums元素对应的索引,我们将nums的索引排序。

排序前
nums = [5, 4, 3, 1]
originIndices = [0, 1, 2, 3]
排序后
nums = [5, 4, 3, 1]
originIndices = [3, 2, 1, 0]

我们在merge过程中需要使用到排序前的索引,所以不能直接修改originIndices,而是将排序后的索引暂存到sortedIndices,当merge结束后再将sortedIndices写回originIndices。我们在图2绘制了merge(nums, 0, 0, 1, 1)的执行过程。counts[i]用于对应于nums[i]右侧小于nums[i]元素的数量。

图2 merge([5], [4])的过程

在merge过程中,我们使用了一个非常重要的变量rightCount。为了理解这个变量的意义,我们先忽略对索引排序这个细节。我回忆下merge sort的merge过程,merge sort是非原地排序,需要nums存放排序前的结果,sorted_nums存放排序后的结果。当nums[left]<=nums[right], 我们取nums[left++]放入sorted_nums,否则取nums[right++]。

假设left = 1,right = 3,我们发现 nums[3]、nums[4]、nums[5]都小于nums[1] ,那么我们是不是要把right=3,4,5对应的数字依次放入sorted_nums?那么这3个数是不是都小于nums[1]?所以我们记录nums[right] < nums[left]出现的次数rightCount,将它累积到counts[left] 中就是题目要求的答案。

之所以累积,而不是直接赋值counts[left] = rightCount是因为merge过程会发生多次。在下一次merge的过程,可能又有新的元素从nums[i]的后面跑到它的前面,因此我们需要累积到上一次计算的counts[i]。

merge(nums, 1, 1, 2, 3)的过程在这里就不绘制了,和上面的merge(nums, 0, 0, 1, 1)的过程类似。我们接下来描述merge(nums, 0, 1, 2, 3)的过程,即merge([4, 5], [1, 3])。

图3 merge([4, 5], [1, 3])的过程
图3描述了merge(nums, [4, 5], [1, 3])的过程,leftIdx通过计算nums[originIndices[leftIdx]]找到nums中的4,通过nums[originIndices[rightIdx]]找到nums中的1,发现4 > 1,执行rightCount++,sortedIndices[sortedIdx++] = originIndices[rightIdx++]。
接下来还是nums[originIndices[leftIdx]] > nums[originIndices[right]]成立,继续执行rightCount++,sortedIndices[sortedIdx++] = originIndices[rightIdx++]。下一步发现rightIdx已经用尽,将leftIdx指向的剩余元素写入sortedIndices中。向sortedIndices写入排序后的索引时,还需要执行counts[originIndices[leftIdx]]+=rightCount,这个rightCount = 2就表示刚刚右边小于左边的那两个数3和1。

这个过程看起来比较烧脑,其实就是merge sort过程中记录nums[i]的右边有几个数搬到了nums[i]的左边,我们就把从右边搬到左边的个数rightCount累积到counts[i]中。但是我们又不能直接搬动nums数组的元素,所以引入originIndices数组,通过改变索引的顺序来实现对nums排序。

时间复杂度O(nlogn),空间复杂度O(n)。全部代码如下。

import java.util.ArrayList;
import java.util.List;

class Solution {
    public List<Integer> countSmaller(int[] nums) {
        int[] counts = new int[nums.length];
        int[] originIndices = new int[nums.length];

        for (int i = 0; i < nums.length; i++) {
            originIndices[i] = i;
        }

        mergeSort(nums, originIndices, counts, 0, nums.length - 1);

        List<Integer> countsList = new ArrayList<>();
        for (int cnt : counts)
            countsList.add(cnt);

        return countsList;
    }

    private void mergeSort(int[] nums, int[] originIndices, int[] counts, int start, int end) {
        if (end <= start)
            return;
        int mid = (start + end) >>> 1;

        mergeSort(nums, originIndices, counts, start, mid);
        mergeSort(nums, originIndices, counts, mid + 1, end);

        merge(nums, originIndices, counts, start, mid, mid + 1, end);
    }

    private void merge(int[] nums, int[] originIndices, int[] counts, int s1, int e1, int s2, int e2) {
        int leftIdx = s1, rightIdx = s2;
        int sortedIdx = 0;
        int[] sortedIndices = new int[e1 - s1 + 1 + e2 - s2 + 1]; // leftLen + rightLen;
        int rightCount = 0;

        while (leftIdx <= e1 && rightIdx <= e2) {
            int leftOriginIdx = originIndices[leftIdx];
            int rightOriginIdx = originIndices[rightIdx];

            if (nums[leftOriginIdx] > nums[rightOriginIdx]) {
                sortedIndices[sortedIdx++] = rightOriginIdx;
                rightIdx++;
                rightCount++;
            } else {
                sortedIndices[sortedIdx++] = leftOriginIdx;
                leftIdx++;
                counts[leftOriginIdx] += rightCount;
            }
        }

        // left 或者 right用尽,上面的while循环无法处理,应当分别处理
        while (leftIdx <= e1) {
            int leftOriginIdx = originIndices[leftIdx++];
            sortedIndices[sortedIdx++] = leftOriginIdx;
            counts[leftOriginIdx] += rightCount;
        }

        while (rightIdx <= e2) {
            int rightOriginIdx = originIndices[rightIdx++];
            sortedIndices[sortedIdx++] = rightOriginIdx;
        }

        // 将排序后的indices写回
        for (sortedIdx = 0; sortedIdx < sortedIndices.length; sortedIdx++) {
            originIndices[s1 + sortedIdx] = sortedIndices[sortedIdx];
        }
    }
}

解法2

解法2使用二叉搜索树(Binary Search Tree,BST)。我们逆序遍历nums,依次向BST中插入各个元素。和一般的BST不同,我们向TreeNode中添加两个字段,leftChildCnt和dup,如图1所示。leftChildCnt表示节点左子树的数量,dup表示元素的重复次数。

图1 BST的节点结构

那么我们怎么通过BST实现寻找nums[i]右侧小于nums[i]元素的数量呢?

我们定义insert函数:TreeNode insert(TreeNode root, int[] nums, int[] counts, int i, int preSum)。这里解释下preSum,我们用它记录小于nums[i]元素的个数。

我们递归的向BST中插入节点nums[i]时,调用insert(root, nums, counts, i, 0)。首先nums[i]跟root对比:

如果nums[i]比root.val大,那么需要从root的右子树出发递归调用insert(root.right, nums, counts, i, preSum + root.leftChildCnt + root.dup)。preSum累积了之前遇到的小于nums[i]的元素数量。所以preSum = preSum + root.leftChildCnt + root.dup。因为nums[i]大于root.val,那么nums[i]一定大于root的左子树的所有元素,我们需要加上root.leftChildCnt。因为root元素可能重复出现,我们还需要累积root.dup。

如果nums[i]比root.val小,那么需要从root的左子树出发递归调用insert(root.left, nums, counts, i, preSum),还需要更新root.leftChildCnt++,来记录左子树元素的数量。

如果nums[i]与root.val相等,那么我们需要更新root.dup++,来记录重复元素的数量。

我们递归的调用insert函数,直到insert的root参数为空,此时我们走到了BST中合适的位置来插入nums[i]元素。我们将沿途中累积的preSum写入到counts[i]中,就找到了nums[i]右侧小于nums[i]元素的数量。

图2 BST更新过程

我们在图2画出了nums = [3, 2, 2, 6, 1]分别插入1、6、2、2、3的过程。我们首先插入1,因为1首次出现所以1的左孩子的数量为0,1的重复元素数量为1。首次调用insert,root参数为null,preSum参数为0。我们创建新节点1,然后写入counts[4] = preSum = 0。

接下来我们插入6。因为6比1大,走向1的右子树。此时我们发现1传递给6的root参数为null,表示我们已经到达了合适位置,创建节点6并返回。此时,preSum累积了root(val=1).leftChildCnt + root.dup = 0 + 1 = 1。我们写入counts[3] = preSum = 1。

接下来我们插入2,2比1大,走向1的右子树,到达节点6;2比6小,走向6的左子树,更新root[val=6].leftChildCnt++。此时insert的root参数为null,我们到达合适的位置,创建新节点2。在沿途中,累积了preSum = root(val=1).leftChildCnt + root.dup = 0 + 1 = 1。我们写入counts[2] = preSum = 1。

接下来我们插入2,2比1大,走向1的右子树,到达节点6;2比6小,走向6的左子树,更新root(val=6).leftChildCnt++。此时insert的root.val = 2,我们发现已经存在了2,更新root.dup++。在沿途中,累积了preSum = root(val=1).leftChildCnt + root.dup = 1 + 0 = 1。我们写入counts[1] = preSum + root(val=2).leftChildCnt = 1 + 0 = 1。对于出现这种重复元素的case,我们不能够只写入preSum,还需要加上root.leftChildCnt。否则,我们会漏掉重复元素之间小于nums[i]的元素。

图3 BST更新过程

图3的例子说明对于出现重复元素的情况下不能只统计preSum。这里我们只说明插入num[0] = 6的情况。首先6与root(val=1)对比,6比1大,转向1的右子树;到达root(val=6)时,发现nums[0] == root.val。如果我们只统计preSum,那么只会统计第一个小于6的元素数量而会落下两个6之间的2,因为这两个2的信息记录在root[val=6].leftChildCnt中。所以nums[i] == root.val时,counts[i] = preSum + root.leftChildCnt。

和Merge sort方法相比,对于nums = [n, n-1, n-2, … , 0]这种极端case的时间复杂度是O(n^2),因为逆序插入时分别需要遍历0 + 1 + 2 + … + n个BST的节点。例如nums = [6,5,4,3,2,1],如果我们最后插入6,需要遍历5,4,3,2,1才到达合适的位置。对于一般的case,这种方法能达到O(nlogn)的时间复杂度。全部代码如下所。

import java.util.ArrayList;
import java.util.List;

class Solution {
    private class TreeNode {
        int val;
        int leftChildCnt;
        int dup;
        TreeNode left;
        TreeNode right;

        TreeNode(int val) {
            this.val = val;
            this.dup = 1;
        }
    }

    public List<Integer> countSmaller(int[] nums) {
        int[] counts = new int[nums.length];
        TreeNode root = null;

        for (int i = nums.length - 1; i >= 0; i--) {
            root = insert(root, nums, counts, i, 0);
        }

        List<Integer> countList = new ArrayList<>(nums.length);
        for (int c : counts)
            countList.add(c);

        return countList;
    }

    private TreeNode insert(TreeNode root, int[] nums, int[] counts, int i, int preSum) {
        int num = nums[i];

        if (root == null) {
            counts[i] = preSum;
            return new TreeNode(num);
        }

        if (num < root.val) {
            root.leftChildCnt++;
            root.left = insert(root.left, nums, counts, i, preSum);
        } else if (num > root.val)
            root.right = insert(root.right, nums, counts, i, preSum + root.leftChildCnt + root.dup);
        else {
            root.dup++;
            // 这里需要加上自己左孩子的数量,否则第二次插入同一个元素,只会统计这个元素第一次右边小于它的数量
            // 在第1次与第2次之间又插入新的小于nums[i]的元素会漏掉
            // 例如 {6, 2, 2, 6, 1}当insert(nums[0[])的时候,只会将nums[1] = 6后面的1统计进去,而两个6之间的2就落下了
            counts[i] = preSum + root.leftChildCnt;
        }
        return root;
    }
}
pwrliang Algorithms, Array, Sort

Leave a Reply

Your email address will not be published. Required fields are marked *