1. 程式人生 > >【死磕演算法之1刷Leetcode】——找出兩個有序陣列的中位數【Median of Two Sorted Arrays】O(log(m+n))

【死磕演算法之1刷Leetcode】——找出兩個有序陣列的中位數【Median of Two Sorted Arrays】O(log(m+n))

Median of Two Sorted Arrays
題目難度:hard
題目要求
There are two sorted arrays nums1 and nums2 of size m and n respectively.
Find the median of the two sorted arrays. The overall run time complexity should be O(log (m+n)).
You may assume nums1 and nums2 cannot be both empty.

翻譯:有兩個長度分別為m和n的有序陣列,分別為nums1和nums2,找到兩個陣列的中位數,時間複雜度要求為O(log(m+n))。num1和nums2不為空。

菜雞解法:

我的做法比較複雜,但通過了。

Runtime: 104 ms, faster than 65.11% of Python3 online submissions for
Median of Two Sorted Arrays.

大概思路是,既然兩個陣列大小已知,那麼我們就可以根據兩個有序陣列有序合併之後(記為N)的大小判斷出中位數應當由第幾個數來得出。因此設定i和j逐個比較兩個陣列元素大小,並用size記錄合併陣列N中已經確定的元素個數,當size遞增到中位數的位置時,即可表示出中位數。
處理N的長度奇偶對中位數求解的影響:用taridx1和taridx2分別記錄在合併後的陣列中,計算中位數所需的索引。

class Solution:
   
    def findMedianSortedArrays(self, nums1, nums2):
        """
        :type nums1: List[int]
        :type nums2: List[int]
        :rtype: float
        """
        len1 = len(nums1)
        len2 = len(nums2)
        s = (len1 + len2)
        ifeven = True if (s % 2 == 0) else False
        size = 0 
        i = 0
        j = 0
        sum1 = 0
        sum2 = 0
        taridx1 = (s / 2 - 1) if ifeven else (s - 1) / 2 
        taridx2 = (s) / 2 if ifeven else (s - 1) / 2
        while (size < s):
            if ( i < len1 and ((nums1[i] <= nums2[j])if j < len2 else 1) ):
                if (size == taridx1):
                    sum1 = nums1[i]
                if (size == taridx2):
                    sum2 = nums1[i]
                    return (sum1 + sum2) / 2 if ifeven else (int)(sum1 + sum2) / 2
                i += 1
                size += 1
           if (j < len2 and ((nums2[j] < nums1[i]) if i<len1 else 1)):
                if (size == taridx1):
                    sum1 = nums2[j]
                if (size == taridx2):
                    sum2 = nums2[j]
                    return (sum1 + sum2) / 2 if ifeven else (int)(sum1 + sum2) / 2
                j += 1
                size += 1

其他解法

另一種相似演算法是用空間換時間:建立空列表,直接生成合並後的有序陣列再去找中位數

class Solution:
    def findMedianSortedArrays(self, nums1, nums2):
        """
        :type nums1: List[int]
        :type nums2: List[int]
        :rtype: float
        """
        listr = []
        x = len(nums1)
        y = len(nums2)
        lentotal = x + y
        i = 0
        j = 0
        while(i<x and j < y):
            if (nums1[i] <= nums2[j]):
                listr.append(nums1[i])
                i+=1
           # if (nums1[i] > nums2[j]):
            else:
                listr.append(nums2[j])
                j+=1
        if (i<len(nums1)):
            listr += nums1[i:x]
        if (j < len(nums2)):
            listr += nums2[j:y]
        if lentotal%2 ==0:
            return (listr[lentotal//2-1]+listr[lentotal//2])/2
        else: 
            return listr[(lentotal-1)//2]

討論區還有的python程式碼極短,用到了sort函式,但複雜度為O(nlog(n))不符合要求。

def findMedianSortedArrays(self, nums1, nums2):
    nums1.extend(nums2)#列表擴充套件
    nums1.sort()#sort函式內部演算法實現為timsort演算法
    half = len(nums1) // 2         
    return (nums1[half] + nums1[~half]) / 2

但以上時間複雜度都不符合要求,實現O(log(m+n))的時間複雜度需要用到二分查詢的思想。

二分查詢求兩有序陣列中第k大的元素

二分查詢的基本思想是,找到有序列表中心位置的元素,和待查元素比較,如果相等,就返回當前位置。如果不相等,折半確定新的查詢區間。

將問題泛化,用二分查詢求兩有序陣列中第k大的元素,舉一反三!

思路來自求兩遞增陣列的中位數(O(log(m + n)))你會嗎
下面是我的理解:

在這裡插入圖片描述

程式碼及註釋

class Solution:
    def findMedianSortedArrays(self, nums1, nums2):
        """
        :type nums1: List[int]
        :type nums2: List[int]
        :rtype: float
        """
        s = len(nums1)+len(nums2)
        if(s%2==0):
            return (self.binary_search(0,nums1,0,nums2,s//2)+self.binary_search(0,nums1,0,nums2,s//2+1))/2
        else:
            return self.binary_search(0,nums1,0,nums2,s//2+1)
        #定義跳出條件
    def binary_search(self,i,nums_1,j,nums_2,k):#  i為nums_1可比較元素序列最左端的索引,j為nums_2可比較元素序列最左端的索引,k為在nums_1和nums_2裡,要取第幾大的數
        len1= len(nums_1)
        len2 = len(nums_2)
        if(i >= len1):#將越界範圍判斷放到最前面,注意這裡返回的是一個數
            return nums_2[j+k-1]
        if(j >= len2):
            return nums_1[i+k-1]
        if(k==1):
            return nums_1[i] if nums_1[i]<nums_2[j] else nums_2[j]

        m = i+k//2-1 #i 和 m 分別是從nums_1中取出來k個元素的左右兩端索引值,即取出nums_1[i]...nums_1[m],注意判斷m或者n超出陣列範圍的情況
        n = j + k//2-1#i 和 n 分別是從nums_2中取出來k個元素的左右兩端索引值,即取出nums_2[j]...nums_2[n]
        p1 = nums_1[m] if m<len1 else 100000000#防止越界,這裡設定的值最好比較大
        #最開始設定的1000因為有的陣列大小超過1000,結果報錯了。
        p2 = nums_2[n] if n<len2 else 100000000
        if(p1< p2):#比較兩個陣列右端點的值
            return self.binary_search(m+1,nums_1,j,nums_2,k-k//2)#邏輯為王
        else:
            return self.binary_search(i,nums_1,n+1,nums_2,k-k//2)

java實現可參考求兩遞增陣列的中位數(O(log(m + n)))你會嗎中的程式碼。