package cn.gtmap.hlw.core.aop;

import cn.gtmap.hlw.core.annotation.RequestLimit;
import cn.gtmap.hlw.core.enums.error.ErrorEnum;
import cn.gtmap.hlw.core.exception.BizException;
import cn.gtmap.hlw.core.util.session.SessionUtil;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.annotation.Around;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.annotation.Pointcut;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.core.ZSetOperations;
import org.springframework.stereotype.Component;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;

import javax.servlet.http.HttpServletRequest;
import java.util.UUID;
import java.util.concurrent.TimeUnit;

/**
 * @author <a href="mailto:dingweiwei@gtmap.cn">dingweiwei</a>
 * @version 1.0, 2024/11/21
 * @description
 */
@Aspect
@Component
@Slf4j
public class RequestLimitAspect {
    @Autowired
    private RedisTemplate redisTemplate;

    // 切点
    @Pointcut("@annotation(requestLimit)")
    public void controllerAspect(RequestLimit requestLimit) {
    }

    @Around("controllerAspect(requestLimit)")
    public Object doAround(ProceedingJoinPoint joinPoint, RequestLimit requestLimit) throws Throwable {
        // get parameter from annotation
        long period = requestLimit.period();
        long limitCount = requestLimit.count();
        String methodName = requestLimit.methodName();

        ServletRequestAttributes attributes = (ServletRequestAttributes) RequestContextHolder.getRequestAttributes();
        HttpServletRequest request = attributes.getRequest();
        // request info
        String uri = request.getRequestURI();
        String userId = SessionUtil.getUserId();
        //默认按接口地址过滤  配置方法名  按方法名过滤
        String key = "req_limit_".concat(uri).concat(userId);
        if (StringUtils.isNotBlank(methodName)) {
            key = "req_limit_".concat(methodName).concat(userId);
        }

        ZSetOperations zSetOperations = redisTemplate.opsForZSet();

        // add current timestamp
        long currentMs = System.currentTimeMillis();
        zSetOperations.add(key, currentMs, currentMs);

        // set the expiration time for the code user
        redisTemplate.expire(key, period, TimeUnit.SECONDS);

        // remove the value that out of current window
        zSetOperations.removeRangeByScore(key, 0, currentMs - period * 1000);

        // check all available count
        Long count = zSetOperations.zCard(key);

        if (count > limitCount) {
            log.error("接口拦截：{} 请求超过限制频率【{}次/{}s】,用户为{}", uri, limitCount, period, userId);
            throw new BizException(ErrorEnum.TOO_FREQUENT_VISIT);
        }
        // execute the user request
        return joinPoint.proceed();
    }
}
