/* (C) 2023 YiRing, Inc. */
package com.yiring.common.aspect;

import cn.hutool.core.util.StrUtil;
import com.yiring.common.annotation.RateLimiter;
import com.yiring.common.core.Redis;
import com.yiring.common.core.Status;
import com.yiring.common.util.IpUtil;
import com.yiring.common.utils.Contexts;
import jakarta.servlet.http.Cookie;
import jakarta.servlet.http.HttpServletRequest;
import java.lang.reflect.Method;
import java.util.Arrays;
import java.util.Objects;
import java.util.concurrent.TimeUnit;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.aspectj.lang.JoinPoint;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.annotation.Before;
import org.aspectj.lang.reflect.MethodSignature;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.data.redis.core.ZSetOperations;
import org.springframework.stereotype.Component;

/**
 * 流控切面
 *
 * @author Jim
 * @version 0.1
 * 2023/12/19 15:17
 */
@Slf4j
@Aspect
@Component
@RequiredArgsConstructor
public class RateLimiterAspect {

    final Redis redis;

    @Value("${sa-token.token-name:}")
    String tokenName;

    /**
     * 带有注解的方法之前执行
     */
    @Before("@annotation(rateLimiter)")
    public void doBefore(JoinPoint point, RateLimiter rateLimiter) {
        int time = rateLimiter.time();
        int count = rateLimiter.count();

        // 将接口方法和用户IP构建Redis的key
        String key = getRateLimiterKey(rateLimiter.key(), point);

        // 使用 ZSet 的 score 设置成用户访问接口的时间戳
        ZSetOperations<String, Object> zSetOperations = redis.getTemplate().opsForZSet();

        // 当前时间戳
        long currentTime = System.currentTimeMillis();
        zSetOperations.add(key, currentTime, currentTime);

        // 设置过期时间防止 key 不消失
        redis.getTemplate().expire(key, time, TimeUnit.SECONDS);

        // 移除 time 秒之前的访问记录，动态时间段
        zSetOperations.removeRangeByScore(key, 0, currentTime - time * 1000L);

        // 获得当前时间窗口内的访问记录数
        Long currentCount = zSetOperations.zCard(key);
        // 限流判断
        if (Objects.nonNull(currentCount) && currentCount > count) {
            log.warn("[Request RateLimit] Key: {}, count: {}, currentCount: {}", key, count, currentCount);
            throw Status.TOO_MANY_REQUESTS.exception();
        }
    }

    /**
     * 组装 redis 的 key
     */
    private String getRateLimiterKey(String prefixKey, JoinPoint point) {
        StringBuilder sb = new StringBuilder(prefixKey);
        HttpServletRequest request = Contexts.getRequest();
        // 获取请求的 IP 地址
        sb.append(IpUtil.getClientIp(request));

        // 考虑登录用户的 token
        if (StrUtil.isNotBlank(tokenName)) {
            // 1. 优先从请求头中获取 token
            String token = request.getHeader(tokenName);
            // 2. 其次从 cookie 中获取 token
            Cookie[] cookies = request.getCookies();
            if (StrUtil.isBlank(token) && Objects.nonNull(cookies)) {
                token =
                    Arrays
                        .stream(cookies)
                        .filter(Objects::nonNull)
                        .filter(cookie -> tokenName.equals(cookie.getName()))
                        .map(Cookie::getValue)
                        .findFirst()
                        .orElse(null);
            }
            if (StrUtil.isBlank(token)) {
                // 3. 再其次从请求参数中获取 token
                token = request.getParameter(tokenName);
            }

            // 设置 token 作为 key 的一部分
            if (StrUtil.isNotBlank(token)) {
                sb.append("_").append(token);
            }
        }

        // 获取类名和方法名
        MethodSignature signature = (MethodSignature) point.getSignature();
        Method method = signature.getMethod();
        Class<?> targetClass = method.getDeclaringClass();
        return sb.append("_").append(targetClass.getName()).append("_").append(method.getName()).toString();
    }
}
