程式人生 > Spring手動掃描包路徑及容器外獲取Bean例項

Spring手動掃描包路徑及容器外獲取Bean例項

最近做的專案有一個需求,希望開放指定包下的Controller給其他應用呼叫,但需要驗證其許可。
解決方案:定義一個Filter,在init初始化方法內掃描指定包下的所有Controller,生成開放URL集合;在doFilter方法內對請求引數進行校驗(比對以加鹽MD5生成的token)。

方案用到了兩個工具類。第一個是HttpServletRequest的包裝類,主要是為了解決RequestBody無法重複讀取的問題:正常情況下HttpServletRequest的輸入流只能被讀取一次。

/**
 * 自定義的Http請求封裝類,解決RequestBody只能讀取一次問題
 *
 * @create 2017-12-11 17:18
 */
public class HttpRequestTwiceReadingWrapper extends HttpServletRequestWrapper { private byte[] requestBody = null; public HttpRequestTwiceReadingWrapper(HttpServletRequest request) { super(request); //快取請求body try { requestBody = StreamUtils.copyToByteArray(request.getInputStream()); } catch
(IOException e) { e.printStackTrace(); } } /** * 重寫 getInputStream() */ @Override public ServletInputStream getInputStream() throws IOException { if (requestBody == null) { requestBody = new byte[0]; } InputStream bis = new
ByteArrayInputStream(requestBody); return new ServletInputStream() { @Override public boolean isFinished() { return false; } @Override public boolean isReady() { return true; } @Override public void setReadListener(ReadListener readListener) { } @Override public int read() throws IOException { return bis.read(); } }; } /** * 重寫 getReader() */ @Override public BufferedReader getReader() throws IOException { return new BufferedReader(new InputStreamReader(getInputStream())); } }

第二個工具類是Spring外獲取容器內的Bean的類。

/**
 * Spring container utility that lets objects living outside the container look up beans in it.
 * <p>
 * The container hands over its {@link BeanFactory} through {@link BeanFactoryAware}. The factory
 * is stored in a static field, so the static lookup methods below can be called from anywhere.
 *
 * @create 2017-12-12 14:19
 */
@Component
public class SpringBeanInstanceAccessor implements BeanFactoryAware {

    // @Autowired cannot inject static fields, so the factory is captured via BeanFactoryAware.
    private static BeanFactory factory;

    @Override
    public void setBeanFactory(BeanFactory beanFactory) throws BeansException {
        factory = beanFactory;
    }

    /**
     * Returns the bean with the given name, checked against the given type.
     *
     * @param beanName name of the bean
     * @param clazz    type the bean must match
     * @param <T>      bean type
     * @return the bean instance, already typed as {@code T}
     * @throws IllegalStateException if the container has not injected its BeanFactory yet
     */
    public static <T> T getBean(String beanName, Class<T> clazz) {
        return requireFactory().getBean(beanName, clazz);
    }

    /**
     * Returns the bean of the given type.
     *
     * @param clazz type of the bean
     * @param <T>   bean type
     * @return the bean instance, already typed as {@code T}
     * @throws IllegalStateException if the container has not injected its BeanFactory yet
     */
    public static <T> T getBean(Class<T> clazz) {
        return requireFactory().getBean(clazz);
    }

    /**
     * Returns the injected factory, or fails with a clear message if it is not available yet.
     */
    private static BeanFactory requireFactory() {
        BeanFactory current = factory;
        if (current == null) {
            throw new IllegalStateException(
                "BeanFactory not initialized: SpringBeanInstanceAccessor has not been registered by the Spring container yet");
        }
        return current;
    }
}

最後一個就是Filter了:掃描指定包下的Controller生成開放URL集合;請求到達時,獲取RequestBody內的引數,通過Spring容器中的Mapper查詢引數對應的key,組合成 `key + "," + Referer + "," + RequestBody`,再以token中攜帶的鹽值通過加鹽MD5生成摘要,與HttpHeader內的token比對驗證。

/**
 * Permit-checking filter for requests that target exposed Controller endpoints.
 * <p>
 * During {@link #init(FilterConfig)}, the packages in {@code EXPOSED_CONTROLLER_PACKAGES} are
 * scanned for classes annotated directly with {@link RequestMapping}. The URLs of their
 * {@code @RequestMapping} methods form the exposed-URL set.
 * <p>
 * In {@link #doFilter}, requests for any other URL pass through untouched. Requests for an exposed
 * URL must carry valid [token] and [timestamp] HTTP headers. If the check fails, an error message
 * is written to the response and the request is NOT forwarded.
 *
 * @create 2017-12-08 11:26
 */
public class CrosRequestPermitCheckingFilter implements Filter {

    private static final Logger LOGGER = LoggerFactory.getLogger(CrosRequestPermitCheckingFilter.class);

    private static final String RESOURCE_PATTERN = "**/*.class";

    /** Charset used to decode the request body and to encode digests and responses. */
    private static final java.nio.charset.Charset UTF_8 = java.nio.charset.StandardCharsets.UTF_8;

    /** A token is 32 MD5 hex chars interleaved with a 16-char salt (see {@code verify}). */
    private static final int TOKEN_LENGTH = 48;

    private final List<TypeFilter> includeFilters = new LinkedList<TypeFilter>();
    private final List<TypeFilter> excludeFilters = new LinkedList<TypeFilter>();

    /** Prefix prepended to every generated URL; empty here. */
    private static final String APPLICATION_NAMESPACE = "";
    private static final Integer PERMIT_VALIDITY_IN_MINUTE = 5;
    private static final List<String> EXPOSED_CONTROLLER_PACKAGES = Arrays.asList("com.bob.mvc.controller");

    /** Exposed request URIs without trailing slash; populated once in init(). */
    private final Set<String> exposedRequestUriSet = new LinkedHashSet<String>();

    /**
     * Scans the exposed Controller packages and builds the exposed-URL set.
     *
     * @param filterConfig filter configuration (not used)
     * @throws ServletException if scanning fails. The filter then refuses to start instead of
     *                          silently leaving every exposed endpoint unchecked.
     */
    @Override
    public void init(FilterConfig filterConfig) throws ServletException {
        includeFilters.add(new AnnotationTypeFilter(RequestMapping.class, false));
        List<String> controllers = new ArrayList<String>();
        try {
            for (String pkg : EXPOSED_CONTROLLER_PACKAGES) {
                controllers.addAll(this.findCandidateControllers(pkg));
            }
            if (controllers.isEmpty()) {
                if (LOGGER.isWarnEnabled()) {
                    LOGGER.warn("掃描指定包{}時未發現符合的開放Controller類", EXPOSED_CONTROLLER_PACKAGES.toString());
                }
                return;
            }
            generateExposedURL(this.transformToClass(controllers), APPLICATION_NAMESPACE);
            if (LOGGER.isDebugEnabled()) {
                LOGGER.debug("掃描指定Controller包,發現開放URL:{}", exposedRequestUriSet.toString());
            }
        } catch (Exception e) {
            LOGGER.error("掃描開放Controller出現異常", e);
            // Fail closed: with an empty URL set, every exposed endpoint would be reachable unchecked.
            throw new ServletException("掃描開放Controller出現異常", e);
        }
    }

    /**
     * Lets non-exposed requests through unchanged and checks the permit of exposed ones.
     * <p>
     * Exposed requests that fail the check receive the error message and are not forwarded.
     */
    @Override
    public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain) throws IOException, ServletException {
        if (!(servletRequest instanceof HttpServletRequest)) {
            filterChain.doFilter(servletRequest, servletResponse);
            return;
        }
        HttpServletRequest httpRequest = (HttpServletRequest) servletRequest;
        String path = httpRequest.getRequestURI();
        if (path.endsWith("/")) {
            path = path.substring(0, path.length() - 1);
        }
        if (!exposedRequestUriSet.contains(path)) {
            // Forward the original request; its body has not been consumed.
            filterChain.doFilter(servletRequest, servletResponse);
            return;
        }
        // Wrap only exposed requests: the wrapper drains the original body stream eagerly.
        HttpServletRequest request = new HttpRequestTwiceReadingWrapper(httpRequest);
        try {
            processCrosRequestPermitChecking(request);
        } catch (IllegalArgumentException | IllegalStateException e) {
            writeResult(servletResponse, e.getMessage());
            // Reject: a request that failed the permit check must not reach the Controller.
            return;
        }
        filterChain.doFilter(request, servletResponse);
    }

    /**
     * Validates the cross-application permit of a request to an exposed URL.
     *
     * @param request the (body-caching) request
     * @throws IllegalArgumentException or IllegalStateException (from Spring {@code Assert} or number
     *                                  parsing) when validation fails; the message is reported to
     *                                  the caller
     * @throws IOException              if the request body cannot be read
     */
    private void processCrosRequestPermitChecking(HttpServletRequest request) throws IOException {
        String timestamp = request.getHeader("timestamp");
        Assert.hasText(timestamp, "跨域請求未指定[timestamp]");
        Assert.state(isNumber(timestamp), "[timestamp]不是一個有效的時間戳");
        // NOTE(review): the timestamp is not part of the signed content below, so this freshness
        // check does not stop a replay with a new timestamp. Fixing that changes the token format.
        long validityMillis = 1000L * 60 * PERMIT_VALIDITY_IN_MINUTE;
        Assert.state(Long.parseLong(timestamp) >= System.currentTimeMillis() - validityMillis, "跨域請求許可已過期");
        // A missing Referer header is concatenated as the string "null".
        String referer = request.getHeader("Referer");
        String requestBody = getRequestBodyInString(request);
        String appcode = getTargetProperty(requestBody, "appcode");
        Assert.hasText(appcode, "跨域請求[appcode]不存在");
        String campusId = getTargetProperty(requestBody, "campusId");
        Assert.isTrue(isNumber(campusId), "跨域請求[campusId]不正確");
        String key = selectKey(appcode, campusId);
        Assert.notNull(key, "跨域請求[appcode]和[campusId]相應的key不存在");
        Assert.state(verify(key + "," + referer + "," + requestBody, request.getHeader("token")), "跨域請求[token]和引數不匹配");
    }

    /**
     * Writes the error message to the response as UTF-8 plain text.
     *
     * @param servletResponse response to write to
     * @param result          message to write; {@code null} writes an empty body
     * @throws IOException if writing fails
     */
    private void writeResult(ServletResponse servletResponse, String result) throws IOException {
        // NOTE(review): consider also setting HTTP 403; clients may currently rely on status 200.
        servletResponse.setContentType("text/plain;charset=UTF-8");
        servletResponse.getOutputStream().write((result == null ? "" : result).getBytes(UTF_8));
    }

    /**
     * Reads the whole request body as a UTF-8 string.
     *
     * @param request the request
     * @return the body, or "" if the body is empty
     * @throws IOException if reading fails
     */
    private String getRequestBodyInString(HttpServletRequest request) throws IOException {
        InputStream is = request.getInputStream();
        java.io.ByteArrayOutputStream buffer = new java.io.ByteArrayOutputStream();
        byte[] chunk = new byte[2048];
        int read;
        // Loop until end of stream: a single read() may return only part of the body.
        while ((read = is.read(chunk)) != -1) {
            buffer.write(chunk, 0, read);
        }
        return new String(buffer.toByteArray(), UTF_8);
    }

    /**
     * Extracts the value of the first {@code "property": value} pair found in a JSON string.
     * <p>
     * String values are returned without their surrounding quotes (escape sequences are left as
     * is). Other values (numbers, booleans, null) are returned as their raw token.
     *
     * @param json     JSON text
     * @param property exact key name
     * @return the trimmed value, or {@code null} if the json is empty or the key is absent
     */
    private String getTargetProperty(String json, String property) {
        if (StringUtils.isEmpty(json) || StringUtils.isEmpty(property)) {
            return null;
        }
        java.util.regex.Matcher matcher = java.util.regex.Pattern.compile(
            "\"" + java.util.regex.Pattern.quote(property)
                + "\"\\s*:\\s*(?:\"((?:[^\"\\\\]|\\\\.)*)\"|([^,}\\]\\s]+))")
            .matcher(json);
        if (!matcher.find()) {
            return null;
        }
        String value = matcher.group(1) != null ? matcher.group(1) : matcher.group(2);
        return value.trim();
    }

    /**
     * Checks a salted-MD5 token against the given plain text.
     * <p>
     * The 48-char token interleaves, in groups of 3, two digest chars with one salt char. The
     * result is the 32-char MD5 hex digest plus a 16-char salt. The digest must equal
     * MD5(plain text + salt).
     *
     * @param password plain text that was signed
     * @param md5      token from the request header
     * @return {@code true} if the token matches; {@code false} for empty input or malformed tokens
     */
    private boolean verify(String password, String md5) {
        if (StringUtils.isEmpty(password) || md5 == null || md5.length() != TOKEN_LENGTH) {
            return false;
        }
        char[] digestChars = new char[32];
        char[] saltChars = new char[16];
        for (int i = 0; i < TOKEN_LENGTH; i += 3) {
            digestChars[i / 3 * 2] = md5.charAt(i);
            digestChars[i / 3 * 2 + 1] = md5.charAt(i + 2);
            saltChars[i / 3] = md5.charAt(i + 1);
        }
        String expected = md5Hex(password + new String(saltChars));
        // Constant-time comparison so the response time does not leak how many chars match.
        return expected != null
            && MessageDigest.isEqual(new String(digestChars).getBytes(UTF_8), expected.getBytes(UTF_8));
    }

    /**
     * Returns the lowercase hex MD5 digest of the UTF-8 bytes of {@code src}.
     *
     * @return the digest, or {@code null} if MD5 is unavailable
     */
    private String md5Hex(String src) {
        try {
            MessageDigest md5 = MessageDigest.getInstance("MD5");
            byte[] bs = md5.digest(src.getBytes(UTF_8));
            return new String(new Hex().encode(bs), UTF_8);
        } catch (java.security.NoSuchAlgorithmException e) {
            LOGGER.error("MD5 algorithm not available", e);
            return null;
        }
    }

    /**
     * Returns whether the string is a non-empty run of ASCII digits.
     *
     * @param value string to test; {@code null} yields {@code false}
     */
    private boolean isNumber(String value) {
        if (value == null || value.isEmpty()) {
            return false;
        }
        for (int i = 0; i < value.length(); i++) {
            char c = value.charAt(i);
            if (c < '0' || c > '9') {
                return false;
            }
        }
        return true;
    }

    /**
     * Returns the names of the candidate Controller classes under a base package.
     * <p>
     * This uses the same classpath scanning that {@code @ComponentScan} relies on, with
     * {@link TypeFilter}s selecting the wanted classes. {@code @ComponentScan} itself adds an
     * {@code AnnotationTypeFilter(Component.class, false)}.
     *
     * @param basePackage package to scan, dot-separated
     * @return fully qualified names of matching classes
     * @throws IOException if classpath resources cannot be read
     */
    private List<String> findCandidateControllers(String basePackage) throws IOException {
        if (LOGGER.isDebugEnabled()) {
            LOGGER.debug("開始掃描包[{}]下的所有類", basePackage);
        }
        List<String> controllers = new ArrayList<String>();
        String packageSearchPath = CLASSPATH_ALL_URL_PREFIX + replaceDotByDelimiter(basePackage) + '/' + RESOURCE_PATTERN;
        ResourceLoader resourceLoader = new DefaultResourceLoader();
        MetadataReaderFactory readerFactory = new SimpleMetadataReaderFactory(resourceLoader);
        Resource[] resources = ResourcePatternUtils.getResourcePatternResolver(resourceLoader).getResources(packageSearchPath);
        for (Resource resource : resources) {
            MetadataReader reader = readerFactory.getMetadataReader(resource);
            if (isCandidateController(reader, readerFactory)) {
                String className = reader.getClassMetadata().getClassName();
                controllers.add(className);
                if (LOGGER.isDebugEnabled()) {
                    LOGGER.debug("掃描到符合要求開放Controller類:[{}]", className);
                }
            }
        }
        return controllers;
    }

    /**
     * Applies the exclude filters, then the include filters. With the include filter added in
     * init(), this selects classes annotated directly with {@code @RequestMapping}.
     *
     * @return {@code true} if no exclude filter matches and at least one include filter matches
     * @throws IOException if class metadata cannot be read
     */
    protected boolean isCandidateController(MetadataReader reader, MetadataReaderFactory readerFactory) throws IOException {
        for (TypeFilter tf : this.excludeFilters) {
            if (tf.match(reader, readerFactory)) {
                return false;
            }
        }
        for (TypeFilter tf : this.includeFilters) {
            if (tf.match(reader, readerFactory)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Loads the named classes with this filter's class loader.
     *
     * @throws ClassNotFoundException if a class cannot be found
     */
    private List<Class<?>> transformToClass(List<String> classNames) throws ClassNotFoundException {
        List<Class<?>> classes = new ArrayList<Class<?>>(classNames.size());
        for (String className : classNames) {
            classes.add(ClassUtils.forName(className, this.getClass().getClassLoader()));
        }
        return classes;
    }

    /**
     * Replaces "." in a package name with "/".
     */
    private String replaceDotByDelimiter(String path) {
        return StringUtils.replace(path, ".", "/");
    }

    /**
     * Introspects the Controllers and adds one exposed URL per class-mapping and method-mapping
     * combination.
     * <p>
     * Only literal {@code value()} paths are supported; patterns such as {@code {id}} or wildcards
     * will not match a concrete request URI.
     *
     * @param controllers Controller classes
     * @param prefix      prefix prepended to every URL
     */
    private void generateExposedURL(List<Class<?>> controllers, String prefix) {
        for (Class<?> controller : controllers) {
            RequestMapping classMapping = controller.getAnnotation(RequestMapping.class);
            final String[] classPaths = transformMappings(classMapping == null ? null : classMapping.value());
            ReflectionUtils.doWithMethods(controller,
                (method) -> {
                    String[] methodPaths = transformMappings(method.getAnnotation(RequestMapping.class).value());
                    for (String classPath : classPaths) {
                        for (String methodPath : methodPaths) {
                            exposedRequestUriSet.add(joinPath(prefix, classPath, methodPath));
                        }
                    }
                },
                (method) -> method.isAnnotationPresent(RequestMapping.class)
            );
        }
    }

    /**
     * Looks up the key for the request through the Mapper bean in the Spring container.
     *
     * @param appcode  application code from the request body (currently not used in the lookup)
     * @param campusId numeric campus id from the request body
     * @return the key, or {@code null} if none exists
     * @throws NumberFormatException if campusId does not fit in an int
     */
    public String selectKey(String appcode, String campusId) {
        CampusMapper campusMapper = (CampusMapper) SpringBeanInstanceAccessor.getBean(CampusMapper.class);
        // NOTE(review): appcode is validated but not used in the lookup; confirm that is intended.
        return campusMapper.selectKey(Integer.valueOf(campusId));
    }

    /**
     * Normalizes {@link RequestMapping#value()} to a non-empty array.
     * <p>
     * If the value is missing or empty on the method or class, a single "" is used instead.
     */
    private String[] transformMappings(String[] mappings) {
        return ObjectUtils.isEmpty(mappings) ? new String[] {""} : mappings;
    }

    /**
     * Joins path segments with single "/" separators, with a leading "/" and no trailing "/".
     * <p>
     * This gives the same shape as the normalized request URI in doFilter; all-empty input gives "".
     */
    private static String joinPath(String... segments) {
        StringBuilder sb = new StringBuilder();
        for (String segment : segments) {
            if (!StringUtils.isEmpty(segment)) {
                sb.append('/').append(segment);
            }
        }
        String path = sb.toString().replaceAll("/{2,}", "/");
        return path.endsWith("/") ? path.substring(0, path.length() - 1) : path;
    }

    @Override
    public void destroy() {

    }

}