1. 程式人生 > >利用Mybatis攔截器實現分頁查詢

利用Mybatis攔截器實現分頁查詢

手寫Mybatis攔截器

版本 Spring Boot 2.0.3.RELEASE

Mybatis自定義攔截器

如果有閱讀過我之前一篇部落格 Hibernate 重新整理上下文 的朋友應該還記得 Hibernate 的上下文中可以新增自定義的事件監聽器。當初是為了解決一個類似於二段提交的的問題,後面我利用 Hibernate 自帶的上下文事件監聽器算是比較優雅的處理了。所以當時就想看看 Mybatis 這邊有沒有什麼類似的方式處理,於是就有了這篇文章。

我看可以來先看看Mybatis 官網上對攔截器的介紹。Mybatis 官網對攔截器稱呼為外掛(plugins)官網的介紹也比較簡單,關鍵就是一個小 demo 如下

// ExamplePlugin.java
@Intercepts({@Signature(
  type= Executor.class,
  method = "update",
  args = {MappedStatement.class,Object.class})})
public class ExamplePlugin implements Interceptor {
  public Object intercept(Invocation invocation) throws Throwable {
    return invocation.proceed();
  }
  public
Object plugin(Object target) { return Plugin.wrap(target, this); } // 可以通過 Properties 獲取到你想要的一些配置資訊 public void setProperties(Properties properties) { } }
<!-- mybatis-config.xml -->
<plugins>
  <plugin interceptor="org.mybatis.example.ExamplePlugin">
    <property name
="someProperty" value="100"/>
</plugin> </plugins>

Spring Boot 自定義 Mybatis 攔截器

我們可以根據官網上的介紹來自己寫一個簡單的 Mybatis 攔截器,我寫的簡易程式碼如下。在攔截上直接宣告@Component即可註冊

@Intercepts({
        @Signature(
                type = Executor.class,
                method = "update", args = {MappedStatement.class, Object.class}
        )
})
@Component
public class MyIntertceptor implements Interceptor {
    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        System.out.println("進入攔截器");
        return invocation.proceed();
    }

    @Override
    public Object plugin(Object target) {
        return Plugin.wrap(target, this);
    }

    @Override
    public void setProperties(Properties properties) {
    }
}

可以看出核心就是實現 public Object intercept(Invocation invocation) throws Throwable這麼一個方法。效果就當執行 update 相關操作(insert ,update 語句)時會觸發執行,打印出進入攔截器

MyBatis 允許你在已對映語句執行過程中的某一點進行攔截呼叫。預設情況下,MyBatis 允許使用外掛來攔截的方法呼叫包括:

Executor (update, query, flushStatements, commit, rollback, getTransaction, close, isClosed) ParameterHandler (getParameterObject, setParameters) ResultSetHandler (handleResultSets, handleOutputParameters) StatementHandler (prepare, parameterize, batch, update, query)

我這裡是使用了 Executor 執行時處理的攔截器,有對應著上面幾種情況時的處理。

實現簡易的分頁查詢

設計思路

  • 呼叫形式
  • 資料庫方言
  • 攔截器邏輯

呼叫方法

有使用過Mybatis分頁外掛 PageHelper的應該都知道是先呼叫一個靜態方法,對下條sql語句進行攔截,在new 一個分頁物件時自動處理。

在PageHelper中是利用了ThreadLocal 本地執行緒變數副本來處理的,當執行那個方法時往ThreadLocal設定一個分頁引數值,所以它每次只對下一條SQL語句有效。所以這裡我也準備這麼做。在new 分頁物件時remove掉ThreadLocal中的變數值 程式碼如下

public class PageResult <T>{
    private long total;

    private List<T> data;

    public PageResult(List<T> data) {
        this.data = data;
        PageInterceptor.PageParm pageParm = PageInterceptor.PARM_THREAD_LOCAL.get();
        if(pageParm != null){
            total = pageParm.totalSize;
            PageInterceptor.PARM_THREAD_LOCAL.remove();
        }
    }

    public long getTotal() {
        return total;
    }

    public List<T> getData() {
        return data;
    }
}

@Intercepts({
        @Signature(
                type = Executor.class,method = "query",
                args = {MappedStatement.class,Object.class,RowBounds.class,ResultHandler.class}
        )
})
@Component
public class PageInterceptor implements Interceptor {
	...
	
	static final ThreadLocal<PageParm> PARM_THREAD_LOCAL = new ThreadLocal<>();

    static class PageParm{
        // 分頁開始位置
        int offset;
        // 分頁數量
        int limit;
        // 總數
        long totalSize;
    }

    /**
     * 開始分頁
     * @param pageNum 當前頁碼 從0開始
     * @param pageSize 每頁長度
     */
    public static void startPage(int pageNum,int pageSize){
        int offset = pageNum * pageSize;
        int limit = pageSize;
        PageParm pageParm = new PageParm();
        pageParm.offset = offset;
        pageParm.limit = limit;
        PARM_THREAD_LOCAL.set(pageParm);
    }
}

資料庫方言問題 構建分頁SQL

我這裡用了一個策略模式,定義好一個方言介面,不同的資料使用不同的方言實現,在注入時生宣告,目前我只有一個MySQL所以也不算完全的策略模式。一個分頁是需要兩條語句的,一個是count 一個是 limit。

public interface Dialect {

    /**
     * 獲取countSQL語句
     * @param targetSql
     * @return
     */
    default String getCountSql(String targetSql){
        return String.format("select count(1) from (%s) tmp_count",targetSql);
    }

    String getLimitSql(String targetSql, int offset, int limit);
}
@Component //我這裡直接指定了,當然最好是使用 @bean 這樣把它new出來更好一些
public class MysqlDialect implements Dialect {

    private static final String PATTERN = "%s limit %s, %s";

    private static final String PATTERN_FIRST = "%s limit %s";

    @Override
    public String getLimitSql(String targetSql, int offset, int limit) {
        if (offset == 0) {
            return String.format(PATTERN_FIRST, targetSql, limit);
        }

        return String.format(PATTERN, targetSql, offset, limit);
    }
}

攔截器核心邏輯

在貼出程式碼之前,我想先感謝一下 buzheng同學,因為這裡面的攔截器核心邏輯有很大一部分就是參考他寫的Mybatis分頁中攔截器的實現。

@Override
public Object intercept(Invocation invocation) throws Throwable {
    final Object[] args = invocation.getArgs();
    PageParm pageParm = PARM_THREAD_LOCAL.get();
    //判斷是否需要進分頁
    if(pageParm != null){
        final MappedStatement ms = (MappedStatement)args[MAPPED_STATEMENT_INDEX];
        Object param = args[PARAMETER_INDEX];
        BoundSql boundSql = ms.getBoundSql(param);
        // 獲取總數
        pageParm.totalSize = queryTotal(ms,boundSql);
        // 重新設定SQL語句對映
        args[MAPPED_STATEMENT_INDEX] = copyPageableMappedStatement(ms,boundSql);
    }
    Object proceed = invocation.proceed();
    return proceed;
}

獲取資料的總數量 -> count

/**
 * 查詢總記錄數 基本上屬於直接抄的
 * @param mappedStatement
 * @param boundSql
 * @return
 * @throws SQLException
 */
private long queryTotal(MappedStatement mappedStatement, BoundSql boundSql) throws SQLException {

    Connection connection = null;
    PreparedStatement countStmt = null;
    ResultSet rs = null;
    try {

        connection = mappedStatement.getConfiguration().getEnvironment().getDataSource().getConnection();

        String countSql = this.dialect.getCountSql(boundSql.getSql());

        countStmt = connection.prepareStatement(countSql);
        BoundSql countBoundSql = new BoundSql(mappedStatement.getConfiguration(), countSql,
                boundSql.getParameterMappings(), boundSql.getParameterObject());

        setParameters(countStmt, mappedStatement, countBoundSql, boundSql.getParameterObject());

        rs = countStmt.executeQuery();
        long totalCount = 0;
        if (rs.next()) {
            totalCount = rs.getLong(1);
        }

        return totalCount;
    } catch (SQLException e) {
        logger.error("查詢總記錄數出錯", e);
        throw e;
    } finally {
        if (rs != null) {
            try {
                rs.close();
            } catch (SQLException e) {
                logger.error("exception happens when doing: ResultSet.close()", e);
            }
        }

        if (countStmt != null) {
            try {
                countStmt.close();
            } catch (SQLException e) {
                logger.error("exception happens when doing: PreparedStatement.close()", e);
            }
        }

        if (connection != null) {
            try {
                connection.close();
            } catch (SQLException e) {
                logger.error("exception happens when doing: Connection.close()", e);
            }
        }
    }
}
/**
 * 對SQL引數(?)設值
 *
 * @param ps
 * @param mappedStatement
 * @param boundSql
 * @param parameterObject
 * @throws SQLException
 */
private void setParameters(PreparedStatement ps, MappedStatement mappedStatement, BoundSql boundSql,
                           Object parameterObject) throws SQLException {
    ParameterHandler parameterHandler = new DefaultParameterHandler(mappedStatement, parameterObject, boundSql);
    parameterHandler.setParameters(ps);
}

利用方言介面替換原始的SQL語句

private MappedStatement copyPageableMappedStatement(MappedStatement ms, BoundSql boundSql) {
    PageParm pageParm = PARM_THREAD_LOCAL.get();
    String pageSql = dialect.getLimitSql(boundSql.getSql(),pageParm.offset,pageParm.limit);
    SqlSource source = new StaticSqlSource(ms.getConfiguration(),pageSql,boundSql.getParameterMappings());
    return copyFromMappedStatement(ms,source);
}

/**
 * 利用新生成的SQL語句去替換原來的MappedStatement
 * @param ms
 * @param newSqlSource
 * @return
 */
private MappedStatement copyFromMappedStatement(MappedStatement ms,SqlSource newSqlSource) {
    MappedStatement.Builder builder = new MappedStatement.Builder(ms.getConfiguration(),ms.getId(),newSqlSource,ms.getSqlCommandType());

    builder.resource(ms.getResource());
    builder.fetchSize(ms.getFetchSize());
    builder.statementType(ms.getStatementType());
    builder.keyGenerator(ms.getKeyGenerator());
    if(ms.getKeyProperties() != null && ms.getKeyProperties().length !=0){
        StringBuffer keyProperties = new StringBuffer();
        for(String keyProperty : ms.getKeyProperties()){
            keyProperties.append(keyProperty).append(",");
        }
        keyProperties.delete(keyProperties.length()-1, keyProperties.length());
        builder.keyProperty(keyProperties.toString());
    }

    //setStatementTimeout()
    builder.timeout(ms.getTimeout());

    //setStatementResultMap()
    builder.parameterMap(ms.getParameterMap());

    //setStatementResultMap()
    builder.resultMaps(ms.getResultMaps());
    builder.resultSetType(ms.getResultSetType());

    //setStatementCache()
    builder.cache(ms.getCache());
    builder.flushCacheRequired(ms.isFlushCacheRequired());
    builder.useCache(ms.isUseCache());

    return builder.build();
}

這樣在執行了分頁查詢的時候,會額外執行一條count語句,並且把原來的SQL換成帶有limit的語句最終查詢的結果就如下

@GetMapping("/all")
public Object all(){
    PageInterceptor.startPage(1,2);
    List<Model> all = dao.findAll();
    PageResult<Model> modelPageResult = new PageResult<>(all);
    return modelPageResult;
}

{  
   total:3,
   data:-   [  
      -      {  
         id:"2",
         name:null,
         code:"123"
      }
   ]
}

我的程式碼已經放在了github上歡迎大家隨時star

github 地址https://github.com/newShiJ/Mybatis-Pageable