前言

rpc 的服务提供者,有时候希望可以统一执行参数校验,或者验签。

基本实现

基本实现

  • PegasusServerInterceptor.java
import cn.hutool.core.util.ArrayUtil;
import com.alibaba.fastjson.JSON;
import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.annotation.Around;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.annotation.Pointcut;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.context.annotation.EnableAspectJAutoProxy;
import org.springframework.stereotype.Component;
import org.springframework.web.multipart.MultipartFile;

import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.util.ArrayList;
import java.util.List;

/**
 *
 * 可以添加下列特性:
 *
 * 1. mdc
 * 2. 入参
 * 3. 出参
 * 4. 统一异常处理
 *
 * 暂时先处理 1/2
 * @author binbin.hou
 * @since 1.0.0
 */
@Aspect
@Component
@EnableAspectJAutoProxy
public class PegasusServerInterceptor {

    /**
     * 日志实例
     * @since 1.0.0
     */
    private static final Logger LOG = LoggerFactory.getLogger(PegasusServerInterceptor.class);

    /**
     * 拦截 services 下所有的 public方法
     */
    @Pointcut("execution(public * com.xxx.pegasus.server.services..*(..))")
    public void pointCut() {
        //
    }

    /**
     * 拦截处理
     *
     * @param point point 信息
     * @return result
     * @throws Throwable if any
     */
    @Around("pointCut()")
    public Object around(ProceedingJoinPoint point) throws Throwable {
        //1. 设置 MDC
        LogUtil.putMdcIfAbsent();
        try {
            // 获取当前拦截的方法签名
            String signatureShortStr = point.getSignature().toShortString();
            //2. 打印入参信息
            Object[] args = point.getArgs();
            List<Object> filterArgs = getFilterArgs(args);
            LOG.info("服务 {}, 入参: {}", signatureShortStr, JSON.toJSON(filterArgs));

            //1.1 参数校验
            Object firstParam = filterArgs.get(0);
            if(firstParam != null) {
                ValidateUtils.validate(firstParam);

                //1.2 签名校验
                BaseRpcRequest baseRpcRequest = (BaseRpcRequest) firstParam;
                // 签名校验
                Md5SignCheckUtils.checkSign(baseRpcRequest.getTraceId(), baseRpcRequest.getSign());
            }

            Object result = point.proceed();
            LOG.info("服务 {}, 出参: {}", signatureShortStr, JSON.toJSON(result));
            return result;
        } finally {
            LogUtil.removeMdc();
        }
    }

    /**
     * 避免 http 复杂参数异常
     * @param args 参数
     * @return 结果
     */
    private List<Object> getFilterArgs(Object[] args) {
        List<Object> list = new ArrayList<>();
        if(ArrayUtil.isEmpty(args)) {
            return list;
        }

        for(Object o : args) {
            if(o instanceof HttpServletRequest) {
                continue;
            }

            if(o instanceof HttpServletResponse) {
                continue;
            }

            if(o instanceof MultipartFile) {
                continue;
            }

            list.add(o);
        }

        return list;
    }

}

其中 ValidateUtils.validate(firstParam); 参数校验失败,会抛出参数异常。

Md5SignCheckUtils.checkSign(baseRpcRequest.getTraceId(), baseRpcRequest.getSign()); 会校验签名,抛出运行时异常。

预定所有的请求,都继承自 BaseRpcRequest 类。

使用者

所有的 rpc facade 实现如下:

import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;

import java.util.ArrayList;
import java.util.List;

/**
 * @author binbin.hou
 * @since 1.0.0
 */
@Service
@Slf4j
public class MyServiceImpl implements MyService {

    @Override
    public MyListResponse queryTodoList(MyListRequest request) {
        try {
            MyListResponse response = new MyListResponse();

            response.setRespCode(RespCode.SUCCESS.getCode());
            response.setRespMessage(RespCode.SUCCESS.getDesc());
            return response;
        } catch (BizException bizException) {
            log.error("业务异常", bizException);
            MyListResponse rpcResponse = new MyListResponse();
            rpcResponse.setRespCode(bizException.getCode());
            rpcResponse.setRespMessage(bizException.getMessage());
            return rpcResponse;
        } catch (Exception exception) {
            log.error("系统异常", exception);
            MyListResponse rpcResponse = new MyListResponse();
            rpcResponse.setRespCode(RespCode.FAIL.getCode());
            rpcResponse.setRespMessage(RespCode.FAIL.getDesc());
            return rpcResponse;
        }

    }

}

预期

预期是如果参数校验失败,则被实现子类的 try catch 捕获,然后返回给前端。

问题1:全部异常未被正确捕获

说明

AOP 的拦截,实际上在方法执行之前。

所以在 aop 中如果出现异常,没有捕获,方法中是无法捕获到的。

尝试修复问题1

我们可以直接给 aop 中添加对应的 try catch。

@Around("pointCut()")
public Object around(ProceedingJoinPoint point) throws Throwable {
    //1. 设置 MDC
    LogUtil.putMdcIfAbsent();
    try {
        // 获取当前拦截的方法签名
        String signatureShortStr = point.getSignature().toShortString();
        //2. 打印入参信息
        Object[] args = point.getArgs();
        List<Object> filterArgs = getFilterArgs(args);
        LOG.info("服务 {}, 入参: {}", signatureShortStr, JSON.toJSON(filterArgs));
        //1.1 参数校验
        Object firstParam = filterArgs.get(0);
        if(firstParam != null) {
            ValidateUtils.validate(firstParam);
            //1.2 签名校验
            BaseRpcRequest baseRpcRequest = (BaseRpcRequest) firstParam;
            // 签名校验
            Md5SignCheckUtils.checkSign(baseRpcRequest.getTraceId(), baseRpcRequest.getSign());
        }
        Object result = point.proceed();
        LOG.info("服务 {}, 出参: {}", signatureShortStr, JSON.toJSON(result));
        return result;
    } catch (BizException bizException) {
        log.error("业务异常", bizException);
        BaseRpcResponse rpcResponse = new BaseRpcResponse();
        rpcResponse.setRespCode(bizException.getCode());
        rpcResponse.setRespMessage(bizException.getMessage());
        return rpcResponse;
    } catch (Exception exception) {
        log.error("系统异常", exception);
        BaseRpcResponse rpcResponse = new BaseRpcResponse();
        rpcResponse.setRespCode(RespCode.FAIL.getCode());
        rpcResponse.setRespMessage(RespCode.FAIL.getDesc());
        return rpcResponse;
    } finally {
        LogUtil.removeMdc();
    }
}

BaseRpcResponse 是所有响应统一集继承的父类。

这样看起来问题不大。

但是实际上 rpc 调用的时候,会报错 CAST 类转换异常。

因为 BaseRpcResponse 是父类,会被转换为具体的子类,这个时候直接就崩溃了。

如何解决呢?

解决 CAST 转换异常

思路

如果我们的方法定义的足够规范,那么是可以通过反射获取 method 对应的响应类的。

反射初始化实例,然后反射设置对应的值即可。

实现

@Around("pointCut()")
    public Object around(ProceedingJoinPoint point) throws Throwable {
    try {
        Object result = point.proceed();
        return result;
    }catch (BizException bizException){
        Class<?> returnType = ((MethodSignature) point.getSignature()).getReturnType();
        Object returnObj = returnType.newInstance();
        BeanUtils.setProperty(returnObj, Constants.RESP_CODE, bizException.getCode());
        BeanUtils.setProperty(returnObj, Constants.RESP_MESS, bizException.getMessage());
        return returnObj;
    } catch (Exception e){
        Class<?> returnType = ((MethodSignature) point.getSignature()).getReturnType();
        Object returnObj = returnType.newInstance();
        BeanUtils.setProperty(returnObj, Constants.RESP_CODE, RespCode.FAIL.getCode());
        BeanUtils.setProperty(returnObj, Constants.RESP_MESS, RespCode.FAIL.getDesc());
        return returnObj;
    } finally {
        //...
    }
}

1)point.getSignature()).getReturnType(); 获取 method 对应的返回值类别,通过 newInstance() 创建实例。

2)通过 BeanUtils 设置对应的属性值。或者通过反序列化也行。

说明

总的来说,实现不是很难。

但是需要考虑清楚。

这个和 https 的拦截有一丝不同。

参考资料

实战