1. 程式人生 > >在mybatis執行SQL語句之前進行攔擊處理

在mybatis執行SQL語句之前進行攔擊處理

比較適用於在分頁時候進行攔截。對分頁的SQL語句通過封裝處理,處理成不同的分頁sql。

實用性比較強。

import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.List;
import java.util.Properties;

import org.apache.ibatis.executor.parameter.ParameterHandler;
import org.apache.ibatis.executor.statement.RoutingStatementHandler;
import org.apache.ibatis.executor.statement.StatementHandler;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.ParameterMapping;
import org.apache.ibatis.plugin.Interceptor;
import org.apache.ibatis.plugin.Intercepts;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.plugin.Plugin;
import org.apache.ibatis.plugin.Signature;
import org.apache.ibatis.scripting.defaults.DefaultParameterHandler;

import com.yidao.utils.Page;
import com.yidao.utils.ReflectHelper;

/** 
 * 
 * 分頁攔截器,用於攔截需要進行分頁查詢的操作,然後對其進行分頁處理。 
 * 利用攔截器實現Mybatis分頁的原理: 
 * 要利用JDBC對資料庫進行操作就必須要有一個對應的Statement物件,Mybatis在執行Sql語句前就會產生一個包含Sql語句的Statement物件,而且對應的Sql語句 
 * 是在Statement之前產生的,所以我們就可以在它生成Statement之前對用來生成Statement的Sql語句下手。在Mybatis中Statement語句是通過RoutingStatementHandler物件的 
 * prepare方法生成的。所以利用攔截器實現Mybatis分頁的一個思路就是攔截StatementHandler介面的prepare方法,然後在攔截器方法中把Sql語句改成對應的分頁查詢Sql語句,之後再呼叫 
 * StatementHandler物件的prepare方法,即呼叫invocation.proceed()。 
 * 對於分頁而言,在攔截器裡面我們還需要做的一個操作就是統計滿足當前條件的記錄一共有多少,這是通過獲取到了原始的Sql語句後,把它改為對應的統計語句再利用Mybatis封裝好的引數和設 
 * 置引數的功能把Sql語句中的引數進行替換,之後再執行查詢記錄數的Sql語句進行總記錄數的統計。 
 * 
 */  
@Intercepts({@Signature(type=StatementHandler.class,method="prepare",args={Connection.class})})
public class PageInterceptor implements Interceptor {
	private String dialect = ""; //資料庫方言  
    private String pageSqlId = ""; //mapper.xml中需要攔截的ID(正則匹配)  
      
    public Object intercept(Invocation invocation) throws Throwable {
    	//對於StatementHandler其實只有兩個實現類,一個是RoutingStatementHandler,另一個是抽象類BaseStatementHandler,  
        //BaseStatementHandler有三個子類,分別是SimpleStatementHandler,PreparedStatementHandler和CallableStatementHandler,  
        //SimpleStatementHandler是用於處理Statement的,PreparedStatementHandler是處理PreparedStatement的,而CallableStatementHandler是  
        //處理CallableStatement的。Mybatis在進行Sql語句處理的時候都是建立的RoutingStatementHandler,而在RoutingStatementHandler裡面擁有一個  
        //StatementHandler型別的delegate屬性,RoutingStatementHandler會依據Statement的不同建立對應的BaseStatementHandler,即SimpleStatementHandler、  
        //PreparedStatementHandler或CallableStatementHandler,在RoutingStatementHandler裡面所有StatementHandler介面方法的實現都是呼叫的delegate對應的方法。  
        //我們在PageInterceptor類上已經用@Signature標記了該Interceptor只攔截StatementHandler介面的prepare方法,又因為Mybatis只有在建立RoutingStatementHandler的時候  
        //是通過Interceptor的plugin方法進行包裹的,所以我們這裡攔截到的目標物件肯定是RoutingStatementHandler物件。
        if(invocation.getTarget() instanceof RoutingStatementHandler){  
            RoutingStatementHandler statementHandler = (RoutingStatementHandler)invocation.getTarget();  
            StatementHandler delegate = (StatementHandler) ReflectHelper.getFieldValue(statementHandler, "delegate");  
            BoundSql boundSql = delegate.getBoundSql();
            Object obj = boundSql.getParameterObject();
            if (obj instanceof Page<?>) {  
                Page<?> page = (Page<?>) obj;  
                //通過反射獲取delegate父類BaseStatementHandler的mappedStatement屬性  
                MappedStatement mappedStatement = (MappedStatement)ReflectHelper.getFieldValue(delegate, "mappedStatement");  
                //攔截到的prepare方法引數是一個Connection物件  
                Connection connection = (Connection)invocation.getArgs()[0];  
                //獲取當前要執行的Sql語句,也就是我們直接在Mapper對映語句中寫的Sql語句  
                String sql = boundSql.getSql();  
                //給當前的page引數物件設定總記錄數  
                this.setTotalRecord(page,  
                       mappedStatement, connection);  
                //獲取分頁Sql語句  
                String pageSql = this.getPageSql(page, sql);  
                //利用反射設定當前BoundSql對應的sql屬性為我們建立好的分頁Sql語句  
                ReflectHelper.setFieldValue(boundSql, "sql", pageSql);  
            } 
        }  
        return invocation.proceed();  
    }
    
    /** 
     * 給當前的引數物件page設定總記錄數 
     * 
     * @param page Mapper對映語句對應的引數物件 
     * @param mappedStatement Mapper對映語句 
     * @param connection 當前的資料庫連線 
     */  
    private void setTotalRecord(Page<?> page,  
           MappedStatement mappedStatement, Connection connection) {  
       //獲取對應的BoundSql,這個BoundSql其實跟我們利用StatementHandler獲取到的BoundSql是同一個物件。  
       //delegate裡面的boundSql也是通過mappedStatement.getBoundSql(paramObj)方法獲取到的。  
       BoundSql boundSql = mappedStatement.getBoundSql(page);  
       //獲取到我們自己寫在Mapper對映語句中對應的Sql語句  
       String sql = boundSql.getSql();  
       //通過查詢Sql語句獲取到對應的計算總記錄數的sql語句  
       String countSql = this.getCountSql(sql);  
       //通過BoundSql獲取對應的引數對映  
       List<ParameterMapping> parameterMappings = boundSql.getParameterMappings();  
       //利用Configuration、查詢記錄數的Sql語句countSql、引數對映關係parameterMappings和引數物件page建立查詢記錄數對應的BoundSql物件。  
       BoundSql countBoundSql = new BoundSql(mappedStatement.getConfiguration(), countSql, parameterMappings, page);  
       //通過mappedStatement、引數物件page和BoundSql物件countBoundSql建立一個用於設定引數的ParameterHandler物件  
       ParameterHandler parameterHandler = new DefaultParameterHandler(mappedStatement, page, countBoundSql);  
       //通過connection建立一個countSql對應的PreparedStatement物件。  
       PreparedStatement pstmt = null;  
       ResultSet rs = null;  
       try {  
           pstmt = connection.prepareStatement(countSql);  
           //通過parameterHandler給PreparedStatement物件設定引數  
           parameterHandler.setParameters(pstmt);  
           //之後就是執行獲取總記錄數的Sql語句和獲取結果了。  
           rs = pstmt.executeQuery();  
           if (rs.next()) {  
              int totalRecord = rs.getInt(1);  
              //給當前的引數page物件設定總記錄數  
              page.setTotalRecord(totalRecord);  
           }  
       } catch (SQLException e) {  
           e.printStackTrace();  
       } finally {  
           try {  
              if (rs != null)  
                  rs.close();  
               if (pstmt != null)  
                  pstmt.close();  
           } catch (SQLException e) {  
              e.printStackTrace();  
           }  
       }  
    }  
    
    /** 
     * 根據原Sql語句獲取對應的查詢總記錄數的Sql語句 
     * @param sql 
     * @return 
     */  
    private String getCountSql(String sql) {  
       int index = sql.indexOf("from");  
       return "select count(*) " + sql.substring(index);  
    }  
    
    /** 
     * 根據page物件獲取對應的分頁查詢Sql語句,這裡只做了兩種資料庫型別,Mysql和Oracle 
     * 其它的資料庫都 沒有進行分頁 
     * 
     * @param page 分頁物件 
     * @param sql 原sql語句 
     * @return 
     */  
    private String getPageSql(Page<?> page, String sql) {  
       StringBuffer sqlBuffer = new StringBuffer(sql);  
       if ("mysql".equalsIgnoreCase(dialect)) {  
           return getMysqlPageSql(page, sqlBuffer);  
       } else if ("oracle".equalsIgnoreCase(dialect)) {  
           return getOraclePageSql(page, sqlBuffer);  
       }  
       return sqlBuffer.toString();  
    }  
    
    /** 
    * 獲取Mysql資料庫的分頁查詢語句 
    * @param page 分頁物件 
    * @param sqlBuffer 包含原sql語句的StringBuffer物件 
    * @return Mysql資料庫分頁語句 
    */  
   private String getMysqlPageSql(Page<?> page, StringBuffer sqlBuffer) {  
      //計算第一條記錄的位置,Mysql中記錄的位置是從0開始的。  
//	   System.out.println("page:"+page.getPage()+"-------"+page.getRows());
      int offset = (page.getPage() - 1) * page.getRows();  
      sqlBuffer.append(" limit ").append(offset).append(",").append(page.getRows());  
      return sqlBuffer.toString();  
   }  
    
   /** 
    * 獲取Oracle資料庫的分頁查詢語句 
    * @param page 分頁物件 
    * @param sqlBuffer 包含原sql語句的StringBuffer物件 
    * @return Oracle資料庫的分頁查詢語句 
    */  
   private String getOraclePageSql(Page<?> page, StringBuffer sqlBuffer) {  
      //計算第一條記錄的位置,Oracle分頁是通過rownum進行的,而rownum是從1開始的  
      int offset = (page.getPage() - 1) * page.getRows() + 1;  
      sqlBuffer.insert(0, "select u.*, rownum r from (").append(") u where rownum < ").append(offset + page.getRows());  
      sqlBuffer.insert(0, "select * from (").append(") where r >= ").append(offset);  
      //上面的Sql語句拼接之後大概是這個樣子:  
      //select * from (select u.*, rownum r from (select * from t_user) u where rownum < 31) where r >= 16  
      return sqlBuffer.toString();  
   }  
   
      
    /** 
     * 攔截器對應的封裝原始物件的方法 
     */        
    public Object plugin(Object arg0) {  
        // TODO Auto-generated method stub  
    	if (arg0 instanceof StatementHandler) {  
            return Plugin.wrap(arg0, this);  
        } else {  
            return arg0;  
        } 
    }  
  
    /** 
     * 設定註冊攔截器時設定的屬性 
     */ 
    public void setProperties(Properties p) {
    	
    }

	public String getDialect() {
		return dialect;
	}

	public void setDialect(String dialect) {
		this.dialect = dialect;
	}

	public String getPageSqlId() {
		return pageSqlId;
	}

	public void setPageSqlId(String pageSqlId) {
		this.pageSqlId = pageSqlId;
	}
    
}

xml配置:

<!-- MyBatis 介面程式設計配置  -->
	<bean class="org.mybatis.spring.mapper.MapperScannerConfigurer">
	    <!-- basePackage指定要掃描的包,在此包之下的對映器都會被搜尋到,可指定多個包,包與包之間用逗號或分號分隔-->
	    <property name="basePackage" value="com.yidao.mybatis.dao" />
	    <property name="sqlSessionFactoryBeanName" value="sqlSessionFactory" />
	</bean>
	
	<!-- MyBatis 分頁攔截器-->
	<bean id="paginationInterceptor" class="com.mybatis.interceptor.PageInterceptor">
	    <property name="dialect" value="mysql"/> 
	    <!-- 攔截Mapper.xml檔案中,id包含query字元的語句 --> 
        <property name="pageSqlId" value=".*query$"/>
    </bean> 

Page類

package com.yidao.utils;


/**自己看看,需要什麼欄位加什麼欄位吧*/
public class Page {
	
	private Integer rows;
	
	private Integer page = 1;
	
	private Integer totalRecord;

	public Integer getRows() {
		return rows;
	}

	public void setRows(Integer rows) {
		this.rows = rows;
	}

	public Integer getPage() {
		return page;
	}

	public void setPage(Integer page) {
		this.page = page;
	}

	public Integer getTotalRecord() {
		return totalRecord;
	}

	public void setTotalRecord(Integer totalRecord) {
		this.totalRecord = totalRecord;
	}
	
}

ReflectHelper類

package com.yidao.utils;

import java.lang.reflect.Field;

import org.apache.commons.lang3.reflect.FieldUtils;

public class ReflectHelper {
	
	public static Object getFieldValue(Object obj , String fieldName ){
		
		if(obj == null){
			return null ;
		}
		
		Field targetField = getTargetField(obj.getClass(), fieldName);
		
		try {
			return FieldUtils.readField(targetField, obj, true ) ;
		} catch (IllegalAccessException e) {
			e.printStackTrace();
		} 
		return null ;
	}
	
	public static Field getTargetField(Class<?> targetClass, String fieldName) {
		Field field = null;

		try {
			if (targetClass == null) {
				return field;
			}

			if (Object.class.equals(targetClass)) {
				return field;
			}

			field = FieldUtils.getDeclaredField(targetClass, fieldName, true);
			if (field == null) {
				field = getTargetField(targetClass.getSuperclass(), fieldName);
			}
		} catch (Exception e) {
		}

		return field;
	}
	
	public static void setFieldValue(Object obj , String fieldName , Object value ){
		if(null == obj){return;}
		Field targetField = getTargetField(obj.getClass(), fieldName);	
		try {
			 FieldUtils.writeField(targetField, obj, value) ;
		} catch (IllegalAccessException e) {
			e.printStackTrace();
		} 
	} 
}