1. 程式人生 > >718. Maximum Length of Repeated Subarray 字尾陣列解最長公共子串 O(n log^2 n)時間複雜度

718. Maximum Length of Repeated Subarray 字尾陣列解最長公共子串 O(n log^2 n)時間複雜度

題意

  • 找最長公共子串

思路

  • 用dp的方法很容易在O(n^2)解決問題,這裡主要討論用字尾陣列的思路解決這個問題
  • 字尾數組裡有兩個經典的概念或者稱為資料結構,就是字尾陣列SA,以及高度陣列LCP
  • SA陣列的定義是:將原串S所有的字尾按字典序排序後,定義rank(i)為字尾S[i…]的排名,SA陣列是rank陣列的逆對映,即SA(rank(i)) = i
  • LCP陣列的定義是:LCP(i)是字尾S[SA(i)…]和字尾S[SA(i+1)…]的最長公共字首長度
  • 這裡就不討論這兩個陣列的求解演算法了,我們使用比較簡單的倍增法求解SA陣列,複雜度是O(n log^2 n)的,有了SA陣列,求解LCP陣列是O(n)的
  • 有了LCP陣列後,我們先來思考另一個問題,一個數組裡兩個不同的子串的最長公共子串是多長呢?答案是max(LCP),也就是LCP數組裡的最大值。原因的話反證一下很簡單,這裡簡單說明一下,主要是考慮陣列其它子串和以i開頭的子串的最長公共子串是多長,容易證明能達到最長的只能是以SA(rank(i)-1)或SA(rank(i)+1)開頭的子串,那麼這個結果都儲存在LCP裡了,所以遍歷一遍LCP就能找到最大值
  • 利用上述結論,我們很容易解決新的問題了。可以把兩個陣列拼在一起,並在拼接處加一個特殊的int,是在兩個數組裡都沒有出現的
  • 求出LCP陣列後,我們只要找i和SA(i+1)不在同一個字串的LCP的最大值即可

實現

class Solution {
public:
    //size of rank and sa are n+1
    //size of lcp is n, definition of lcp[i] is max common prefix of s[sa[i]...] and s[sa[i+1]...]
    //input s of getSa and getLcp can be string as well
    vector<int> rank, sa, lcp;
    void getSa(const vector<int>& s,
vector<int>& rank, vector<int>& sa){ int n = s.size(); vector<int> tmp(n+1); for (int i = 0; i < n; i++){ sa.push_back(i); rank.push_back(s[i]); } sa.push_back(n); rank.push_back(-1); for (int k = 1; k <= n; k <<= 1){ auto cmp = [&](int x, int y){ if (rank[x] != rank[y]) return rank[x] < rank[y]; auto tx = x + k > n ? -1 : rank[x + k]; auto ty = y + k > n ? -1 : rank[y + k]; return tx < ty; }; sort(sa.begin(), sa.end(), cmp); tmp[sa[0]] = 0; for (int i = 1; i <= n; i++){ tmp[sa[i]] = tmp[sa[i-1]]; if (cmp(sa[i-1], sa[i])){ tmp[sa[i]]++; } } for (int i = 0; i <= n; i++) rank[i] = tmp[i]; } } void getLcp(const vector<int>& s, const vector<int>& rank, const vector<int>& sa, vector<int>& lcp){ int n = s.size(); lcp.insert(lcp.begin(), n, 0); for (int i = 0, h = 0; i < n; i++){ if (h > 0) h--; int k = rank[i]; int j = sa[k-1]; while (max(j, i) + h < n && s[j+h] == s[i+h]){ h++; } lcp[k-1] = h; } } int findLength(vector<int>& A, vector<int>& B) { int n = A.size(), m = B.size(); A.push_back(101); A.insert(A.end(), B.begin(), B.end()); getSa(A, rank, sa); getLcp(A, rank, sa, lcp); int ans = 0; for (int i = 0; i <= n + m; i++){ if (sa[i] < n && sa[i+1] > n || sa[i] > n && sa[i+1] < n){ ans = max(ans, lcp[i]); } } return ans; } };