myBatis學習筆記(10)——使用攔截器實現分頁查詢
1. Page
package com.sm.model;
import java.util.List;
/**
 * Paging request and result holder.
 *
 * <p>The caller sets the page number and page size before running a query; after a paged query
 * the total record count, the total page count and the records of the current page are stored
 * here.
 *
 * @param <T> type of the records on a page
 */
public class Page<T> {
    /** Page size used when none is set explicitly. */
    public static final int DEFAULT_PAGE_SIZE = 20;
    /** Current page, 1-based; defaults to the first page. */
    protected int pageNo = 1;
    /** Number of records per page. */
    protected int pageSize = DEFAULT_PAGE_SIZE;
    /** Total record count; -1 (the default) means it still has to be queried. */
    protected long totalRecord = -1;
    /** Total page count; -1 (the default) means it still has to be computed. */
    protected int totalPage = -1;
    /** Records of the current page. */
    protected List<T> results;

    /** @return the current page number, starting at 1 */
    public int getPageNo() {
        return pageNo;
    }

    /**
     * Sets the current page number.
     *
     * @param pageNo 1-based page number; values below 1 are treated as 1
     */
    public void setPageNo(int pageNo) {
        // A page number below 1 produces a negative row offset: an invalid LIMIT in MySQL and an
        // always-empty ROWNUM window in Oracle. Clamp to the first page instead.
        this.pageNo = Math.max(1, pageNo);
    }

    /** @return the number of records per page */
    public int getPageSize() {
        return pageSize;
    }

    /**
     * Sets the number of records per page and recomputes the total page count if the total
     * record count is already known.
     *
     * @param pageSize records per page
     */
    public void setPageSize(int pageSize) {
        this.pageSize = pageSize;
        computeTotalPage();
    }

    /** @return the total record count, or -1 if it has not been queried yet */
    public long getTotalRecord() {
        return totalRecord;
    }

    /** @return the total page count, or -1 if it has not been computed yet */
    public int getTotalPage() {
        return totalPage;
    }

    /**
     * Sets the total record count and recomputes the total page count.
     *
     * @param totalRecord total number of records matched by the query
     */
    public void setTotalRecord(long totalRecord) {
        this.totalRecord = totalRecord;
        computeTotalPage();
    }

    /**
     * Recomputes {@link #totalPage} from the record count and page size. Does nothing while the
     * page size is not positive or the record count is still unknown (-1).
     */
    protected void computeTotalPage() {
        if (getPageSize() > 0 && getTotalRecord() > -1) {
            // Ceiling division: a partially filled last page still counts as a page.
            this.totalPage = (int) ((getTotalRecord() + getPageSize() - 1) / getPageSize());
        }
    }

    /** @return the records of the current page, or null if the query has not run */
    public List<T> getResults() {
        return results;
    }

    /** @param results the records of the current page */
    public void setResults(List<T> results) {
        this.results = results;
    }

    @Override
    public String toString() {
        StringBuilder builder = new StringBuilder().append("Page [pageNo=").append(pageNo).append(", pageSize=").append(pageSize)
                .append(", totalRecord=").append(totalRecord < 0 ? "null" : totalRecord).append(", totalPage=")
                .append(totalPage < 0 ? "null" : totalPage).append(", results=").append(results == null ? "null" : results).append("]");
        return builder.toString();
    }
}
2. 實現攔截器
package com.sm.model;
import java.lang.reflect.Field;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.executor.parameter.DefaultParameterHandler;
import org.apache.ibatis.executor.parameter.ParameterHandler;
import org.apache.ibatis.executor.statement.RoutingStatementHandler;
import org.apache.ibatis.executor.statement.StatementHandler;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.ParameterMapping;
import org.apache.ibatis.plugin.Interceptor;
import org.apache.ibatis.plugin.Intercepts;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.plugin.Plugin;
import org.apache.ibatis.plugin.Signature;
import org.apache.ibatis.session.ResultHandler;
import org.apache.ibatis.session.RowBounds;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * MyBatis plugin that pages a select transparently.
 *
 * <p>Two interception points cooperate:
 * <ol>
 *   <li>{@code Executor.query(..)}: if the parameter is a {@link Page}, or a {@code Map} holding a
 *       {@code Page} value, the page is bound to the current thread for the duration of the query
 *       and afterwards receives the result list.</li>
 *   <li>{@code StatementHandler.prepare(Connection)}: while a page is bound, the total record count
 *       is queried (unless the page already has a total page count) and the SQL is rewritten with
 *       a database-specific paging clause.</li>
 * </ol>
 *
 * <p>Supported databases are MySQL and Oracle. The type comes from the {@code databaseType} plugin
 * property, or is detected from the JDBC connection metadata on first use.
 *
 * <p>NOTE(review): the {@code prepare(Connection)} signature matches older MyBatis releases; newer
 * ones (3.4.0+, to be confirmed) declare {@code prepare(Connection, Integer)}, in which case the
 * signature below must be adapted.
 */
@Intercepts({ @Signature(method = "prepare", type = StatementHandler.class, args = { Connection.class }),
        @Signature(method = "query", type = Executor.class, args = { MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class }) })
public class PageInterceptor implements Interceptor {
    private static final Logger log = LoggerFactory.getLogger(PageInterceptor.class);
    /** Database type value for MySQL (paged with {@code LIMIT offset,size}). */
    public static final String MYSQL = "mysql";
    /** Database type value for Oracle (paged with {@code ROWNUM}). */
    public static final String ORACLE = "oracle";
    /** Database type; each database needs a different paging SQL. Detected lazily when null. */
    protected String databaseType;
    /** Page of the paged query currently running on this thread; null when not paging. */
    protected ThreadLocal<Page> pageThreadLocal = new ThreadLocal<Page>();

    /** @return the configured or detected database type, or null if not known yet */
    public String getDatabaseType() {
        return databaseType;
    }

    /**
     * Sets the database type explicitly.
     *
     * @param databaseType {@code "mysql"} or {@code "oracle"}, case-insensitive
     * @throws PageNotSupportException if the type is null or not supported
     */
    public void setDatabaseType(String databaseType) {
        // Reject null with the domain exception instead of a NullPointerException.
        if (databaseType == null
                || (!databaseType.equalsIgnoreCase(MYSQL) && !databaseType.equalsIgnoreCase(ORACLE))) {
            throw new PageNotSupportException("Page not support for the type of database, database type [" + databaseType + "]");
        }
        this.databaseType = databaseType;
    }

    @Override
    public Object plugin(Object target) {
        return Plugin.wrap(target, this);
    }

    /**
     * Applies the optional {@code databaseType} plugin property.
     *
     * @param properties plugin properties from the MyBatis configuration
     */
    @Override
    public void setProperties(Properties properties) {
        String databaseType = properties.getProperty("databaseType");
        if (databaseType != null) {
            setDatabaseType(databaseType);
        }
    }

    /**
     * Dispatches between the {@code StatementHandler.prepare} step (count query + SQL rewrite)
     * and the {@code Executor.query} step (Page detection + result collection).
     */
    @Override
    @SuppressWarnings({ "unchecked", "rawtypes" })
    public Object intercept(Invocation invocation) throws Throwable {
        if (invocation.getTarget() instanceof StatementHandler) {
            // prepare(): where the total is counted and the SQL is rewritten.
            Page page = pageThreadLocal.get();
            if (page == null) { // not a paged query
                return invocation.proceed();
            }
            // NOTE(review): assumes the target is the RoutingStatementHandler itself; if another
            // plugin wraps the StatementHandler first, the target is a proxy and this cast fails.
            RoutingStatementHandler handler = (RoutingStatementHandler) invocation.getTarget();
            StatementHandler delegate = (StatementHandler) ReflectUtil.getFieldValue(handler, "delegate");
            BoundSql boundSql = delegate.getBoundSql();
            Connection connection = (Connection) invocation.getArgs()[0];
            prepareAndCheckDatabaseType(connection); // determine the database type if still unknown
            if (page.getTotalPage() > -1) {
                if (log.isTraceEnabled()) {
                    log.trace("已經設置了總頁數, 不須要再查詢總數.");
                }
            } else {
                Object parameterObj = boundSql.getParameterObject();
                MappedStatement mappedStatement = (MappedStatement) ReflectUtil.getFieldValue(delegate, "mappedStatement");
                queryTotalRecord(page, parameterObj, mappedStatement, connection);
            }
            String sql = boundSql.getSql();
            String pageSql = buildPageSql(page, sql);
            if (log.isDebugEnabled()) {
                log.debug("分頁時, 生成分頁pageSql: " + pageSql);
            }
            // BoundSql exposes no setter for the SQL text, so the field is replaced reflectively.
            ReflectUtil.setFieldValue(boundSql, "sql", pageSql);
            return invocation.proceed();
        }

        // Executor.query(): where the results are collected.
        Page<?> page = findPageObject(invocation.getArgs()[1]);
        if (page == null) {
            if (log.isTraceEnabled()) {
                log.trace("沒有Page對象作為參數, 不是分頁查詢.");
            }
            return invocation.proceed();
        }
        if (log.isTraceEnabled()) {
            log.trace("檢測到分頁Page對象, 使用分頁查詢.");
        }
        // Pass the real parameter object on to the executor.
        invocation.getArgs()[1] = extractRealParameterObject(invocation.getArgs()[1]);
        pageThreadLocal.set(page);
        try {
            Object resultObj = invocation.proceed(); // Executor.query(..)
            if (resultObj instanceof List) {
                page.setResults((List) resultObj);
            }
            return resultObj;
        } finally {
            // Always unbind, otherwise a later unrelated query on this thread would be paged.
            pageThreadLocal.remove();
        }
    }

    /**
     * Finds the {@link Page} among the query parameters.
     *
     * @param parameterObj the parameter passed to {@code Executor.query}
     * @return the parameter itself if it is a Page, the first Page value if it is a Map,
     *         otherwise null
     */
    protected Page<?> findPageObject(Object parameterObj) {
        if (parameterObj instanceof Page<?>) {
            return (Page<?>) parameterObj;
        } else if (parameterObj instanceof Map) {
            for (Object val : ((Map<?, ?>) parameterObj).values()) {
                if (val instanceof Page<?>) {
                    return (Page<?>) val;
                }
            }
        }
        return null;
    }

    /**
     * Extracts the real parameter object.
     *
     * <p>When a mapper method takes several arguments they arrive wrapped in a Map. Arguments
     * named explicitly (e.g. with {@code @Param}) are left untouched because the XML refers to
     * them by key. If the map holds exactly two entries keyed {@code "0"} and {@code "1"}
     * (unnamed arguments), the XML usually refers directly to the real parameter object, so the
     * first non-Page value is returned as the root object.
     *
     * <p>NOTE(review): MyBatis may also add {@code "param1"}/{@code "param2"} keys to such maps,
     * which would make {@code size() == 2} false; confirm this branch is reachable with the
     * MyBatis version in use.
     *
     * @param parameterObj the parameter passed to {@code Executor.query}
     * @return the unwrapped parameter, or {@code parameterObj} unchanged
     */
    protected Object extractRealParameterObject(Object parameterObj) {
        if (parameterObj instanceof Map<?, ?>) {
            Map<?, ?> parameterMap = (Map<?, ?>) parameterObj;
            if (parameterMap.size() == 2) {
                boolean springMapWithNoParamName = true;
                for (Object key : parameterMap.keySet()) {
                    if (!(key instanceof String)) {
                        springMapWithNoParamName = false;
                        break;
                    }
                    String keyStr = (String) key;
                    if (!"0".equals(keyStr) && !"1".equals(keyStr)) {
                        springMapWithNoParamName = false;
                        break;
                    }
                }
                if (springMapWithNoParamName) {
                    for (Object value : parameterMap.values()) {
                        if (!(value instanceof Page<?>)) {
                            return value;
                        }
                    }
                }
            }
        }
        return parameterObj;
    }

    /**
     * Detects the database type from the connection metadata if none was configured.
     *
     * @param connection the connection the statement is being prepared on
     * @throws SQLException if the metadata cannot be read
     * @throws PageNotSupportException if the product is neither MySQL nor Oracle
     */
    protected void prepareAndCheckDatabaseType(Connection connection) throws SQLException {
        if (databaseType == null) {
            String productName = connection.getMetaData().getDatabaseProductName();
            if (log.isTraceEnabled()) {
                log.trace("Database productName: " + productName);
            }
            productName = productName.toLowerCase();
            if (productName.indexOf(MYSQL) != -1) {
                databaseType = MYSQL;
            } else if (productName.indexOf(ORACLE) != -1) {
                databaseType = ORACLE;
            } else {
                throw new PageNotSupportException("Page not support for the type of database, database product name [" + productName + "]");
            }
            if (log.isInfoEnabled()) {
                log.info("自己主動檢測到的數據庫類型為: " + databaseType);
            }
        }
    }

    /**
     * Builds the paged SQL for the current database type.
     *
     * @param page the requested page
     * @param sql the original SQL
     * @return the paged SQL, or {@code sql} unchanged for an unknown database type
     */
    protected String buildPageSql(Page<?> page, String sql) {
        if (MYSQL.equalsIgnoreCase(databaseType)) {
            return buildMysqlPageSql(page, sql);
        } else if (ORACLE.equalsIgnoreCase(databaseType)) {
            return buildOraclePageSql(page, sql);
        }
        return sql;
    }

    /**
     * Builds the MySQL paged SQL ({@code LIMIT offset,size}).
     *
     * @param page the requested page
     * @param sql the original SQL
     * @return the SQL with a LIMIT clause appended
     */
    protected String buildMysqlPageSql(Page<?> page, String sql) {
        // MySQL row offsets start at 0; never emit a negative offset (invalid LIMIT syntax).
        int offset = Math.max(0, (page.getPageNo() - 1) * page.getPageSize());
        return new StringBuilder(sql).append(" limit ").append(offset).append(",").append(page.getPageSize()).toString();
    }

    /**
     * Builds the Oracle paged SQL using ROWNUM in two nested inline views.
     *
     * @param page the requested page
     * @param sql the original SQL
     * @return the wrapped SQL selecting rows {@code offset .. offset + pageSize - 1}
     */
    protected String buildOraclePageSql(Page<?> page, String sql) {
        // ROWNUM starts at 1, so the first row of the page is offset + 1.
        int offset = Math.max(0, (page.getPageNo() - 1) * page.getPageSize()) + 1;
        StringBuilder sb = new StringBuilder(sql);
        sb.insert(0, "select u.*, rownum r from (").append(") u where rownum < ").append(offset + page.getPageSize());
        sb.insert(0, "select * from (").append(") where r >= ").append(offset);
        return sb.toString();
    }

    /**
     * Runs a count query for the statement and stores the result in {@code page}.
     *
     * @param page receives the total record count
     * @param parameterObject the real parameter object of the query
     * @param mappedStatement the statement being paged
     * @param connection the connection to run the count query on
     * @throws SQLException if the count query fails
     */
    protected void queryTotalRecord(Page<?> page, Object parameterObject, MappedStatement mappedStatement, Connection connection) throws SQLException {
        // Generate the SQL from the real parameter object, not from the Page: for dynamic SQL
        // (<if>, <where>, ...) the generated text depends on the parameter's properties.
        BoundSql boundSql = mappedStatement.getBoundSql(parameterObject);
        String sql = boundSql.getSql();
        String countSql = this.buildCountSql(sql);
        if (log.isDebugEnabled()) {
            log.debug("分頁時, 生成countSql: " + countSql);
        }
        List<ParameterMapping> parameterMappings = boundSql.getParameterMappings();
        BoundSql countBoundSql = new BoundSql(mappedStatement.getConfiguration(), countSql, parameterMappings, parameterObject);
        ParameterHandler parameterHandler = new DefaultParameterHandler(mappedStatement, parameterObject, countBoundSql);
        PreparedStatement pstmt = null;
        ResultSet rs = null;
        try {
            pstmt = connection.prepareStatement(countSql);
            parameterHandler.setParameters(pstmt);
            rs = pstmt.executeQuery();
            if (rs.next()) {
                long totalRecord = rs.getLong(1);
                page.setTotalRecord(totalRecord);
            }
        } finally {
            // Close best-effort: a failing close must not hide the count result or the original error.
            if (rs != null) {
                try {
                    rs.close();
                } catch (Exception e) {
                    if (log.isWarnEnabled()) {
                        log.warn("關閉ResultSet時異常.", e);
                    }
                }
            }
            if (pstmt != null) {
                try {
                    pstmt.close();
                } catch (Exception e) {
                    if (log.isWarnEnabled()) {
                        log.warn("關閉PreparedStatement時異常.", e);
                    }
                }
            }
        }
    }

    /**
     * Builds the SQL that counts all records matched by {@code sql}.
     *
     * @param sql the original select
     * @return a count query over the original select
     */
    protected String buildCountSql(String sql) {
        // Wrap the whole statement as a derived table instead of cutting it at the first "from".
        // Cutting breaks on upper-case FROM (indexOf returns -1), on columns or subqueries that
        // contain "from", and miscounts GROUP BY / DISTINCT queries.
        // MySQL requires an alias for derived tables; Oracle accepts one without AS.
        return "select count(*) from (" + sql + ") tmp_count";
    }

    /**
     * Reflection helpers for reading and writing private MyBatis fields.
     */
    private static class ReflectUtil {
        /**
         * Reads a field (declared on the object's class or any superclass) reflectively.
         *
         * @param obj target object
         * @param fieldName field name
         * @return the field value, or null if the field is missing or cannot be read
         */
        public static Object getFieldValue(Object obj, String fieldName) {
            Field field = ReflectUtil.getField(obj, fieldName);
            if (field == null) {
                log.warn("Field [{}] not found on {}", fieldName, obj.getClass().getName());
                return null;
            }
            field.setAccessible(true);
            try {
                return field.get(obj);
            } catch (IllegalArgumentException | IllegalAccessException e) {
                log.warn("Cannot read field [" + fieldName + "] reflectively", e);
                return null;
            }
        }

        /**
         * Looks up a declared field on the object's class, walking up the superclasses.
         *
         * @param obj target object
         * @param fieldName field name
         * @return the field, or null if no class below Object declares it
         */
        private static Field getField(Object obj, String fieldName) {
            Field field = null;
            for (Class<?> clazz = obj.getClass(); clazz != Object.class; clazz = clazz.getSuperclass()) {
                try {
                    field = clazz.getDeclaredField(fieldName);
                    break;
                } catch (NoSuchFieldException e) {
                    // Expected: the field may be declared on a superclass; keep searching and
                    // return null if no class declares it.
                }
            }
            return field;
        }

        /**
         * Writes a field (declared on the object's class or any superclass) reflectively.
         *
         * @param obj target object
         * @param fieldName field name
         * @param fieldValue new value
         */
        public static void setFieldValue(Object obj, String fieldName, String fieldValue) {
            Field field = ReflectUtil.getField(obj, fieldName);
            if (field == null) {
                log.warn("Field [{}] not found on {}", fieldName, obj.getClass().getName());
                return;
            }
            try {
                field.setAccessible(true);
                field.set(obj, fieldValue);
            } catch (IllegalArgumentException | IllegalAccessException e) {
                log.warn("Cannot write field [" + fieldName + "] reflectively", e);
            }
        }
    }

    /**
     * Thrown when paging is requested for an unsupported database type.
     */
    public static class PageNotSupportException extends RuntimeException {
        public PageNotSupportException() {
            super();
        }

        public PageNotSupportException(String message, Throwable cause) {
            super(message, cause);
        }

        public PageNotSupportException(String message) {
            super(message);
        }

        public PageNotSupportException(Throwable cause) {
            super(cause);
        }
    }
}
3. spring配置文件(mybatis已和spring整合)
<!-- MyBatis SqlSessionFactory, created through mybatis-spring's SqlSessionFactoryBean -->
<bean id="sqlSessionFactoryBean" class="org.mybatis.spring.SqlSessionFactoryBean">
<property name="dataSource" ref="dataSource"></property>
<!-- With typeAliasesPackage set, classes in this package can be referenced in mapper files by their simple name (e.g. resultType="User") -->
<property name="typeAliasesPackage" value="com.sm.model"></property>
<!-- Location of the mapper XML files -->
<property name="mapperLocations" value="classpath:resources/mapper/*.xml"></property>
<property name="plugins">
<!-- Paging interceptor; no databaseType property is set here, so the type is detected from the JDBC connection metadata -->
<bean class="com.sm.model.PageInterceptor"></bean>
</property>
</bean>
4. mapper.xml
<!-- The parameter is a Map: #{user.username} reads the "user" entry. When the Map also holds a Page value, PageInterceptor runs a count query and appends the paging clause to this SQL. -->
<select id="getUsers" resultType="User" parameterType="Map">
select * from user where username=#{user.username}
</select>
5. DAO
/**
 * Selects the users whose username equals that of the "user" entry in {@code map}.
 * If the map also contains a Page value, the query is paged by PageInterceptor, and that
 * Page receives the total record count and the result list.
 *
 * @param map query parameters; expected keys "user" (a User) and, for paging, "page" (a Page)
 * @return the matching users (only the requested page when paged)
 */
List<User> getUsers(Map map);
6. 測試
Page<User> page = new Page<>();
// Paging parameters: first page, 3 records per page.
page.setPageNo(1);
page.setPageSize(3);
// Query criteria, passed in the same Map as the Page.
User user = new User();
user.setUsername("2");
Map<String, Object> map = new HashMap<>();
map.put("user", user);
map.put("page", page);
List<User> list = userDAO.getUsers(map);
System.out.println(list);
// The interceptor has filled in totalRecord, totalPage and results on this same Page instance.
System.out.println(page);
7. 總結
上面的分頁攔截器,拷下來直接用就好了。如果想了解實現原理,可以看慕課網的視頻《通過自動回覆機器人學Mybatis—加強版》。
myBatis學習筆記(10)——使用攔截器實現分頁查詢