小白也能看懂的Redis教學基礎篇——做一個時間窗限流就是這麼簡單

語言: CN / TW / HK

不知道ZSet(有序集合)的看官們,可以翻閱我的上一篇文章:

小白也能看懂的REDIS教學基礎篇——朋友面試被SKIPLIST跳躍表攔住了

書接上回,話說我朋友小A童鞋,終於面世通過加入了一家公司。這個公司待遇比較豐厚,而且離小A住的地方也比較近,最讓小A中意還是有個肯帶他的大佬。小A對這份工作非常滿意。時間一天一天過去,某個週末,小A來找我家吃蹭飯.在飯桌上小A給我分享了他上星期的一次事故經歷。

上個星期,他們公司出了比較嚴重的一個事故,一個匯出報表的後臺服務拖垮了報表資料服務,導致很多查詢該服務的業務都受到了牽連。主要原因是因為匯出的報表資料比較多,導致匯出時間比較漫長。前端也未針對匯出按鈕做防重試限制。多個運營人員多次點選了匯出按鈕。加上後臺服務配置的重試機制把這個流量放大了好幾倍,最後拖垮了整個報表資料服務。老大讓小A出個叢集限流的方案防止下次在出現這類問題。

這下小A慌了,單機限流好搞,使用 Hystrix 框架註解一加就完事。或者使用 Sentinel 在 Sentinel Dashboard 後臺配置一下就完事。叢集限流要怎麼弄?小A苦思冥想了一個上午也沒整出來,最後只能求助大佬幫助。

小A:大佬,老大讓我出一個叢集限流的方案,我以前對這個不熟,網站找的一堆,都是重複相同的感覺不靠譜,能教教我怎麼弄嗎?

大佬:莫慌莫慌,這件事情其實不難。我先考考你,限流的演算法有哪些?

小A:... 我想想

小A:有了,常見的限流演算法有以下三種:滑動視窗演算法,令牌桶演算法,漏桶演算法。

大佬:對頭,那你覺得單機限流和叢集限流有什麼區別呢?

小A:emmm...

大佬:你可以從叢集和單機程式本身的區別去想想。

小A:我知道了,單機限流限流資料存在單機上,只能一個機器用。而叢集是分散式多機器的,要讓多個機器共享同一份限流資料,才能保證多機器的限流。

大佬:很好,那你還記得面試時候我問你的Zset(Sorted Sets)嗎?用它就能很簡單的實現一個滑動時間窗哦。

小A:大佬快教我!

大佬:那話不多說,現在進入正題。

如果不知道Zset資料結構的,可以先去看看我的這篇文章

小白也能看懂的REDIS教學基礎篇——朋友面試被SKIPLIST跳躍表攔住了

首先來看一段實現滑動時間窗的lua程式碼(這是實現Redis滑動時間窗限流的核心程式碼)

-- 引數:
-- nowTime 當前時間
-- windowTime視窗時間
-- maxCount最大次數
-- expiredWindowTime 已經過期的視窗時間
-- value 請求標記
local nowTime = tonumber(ARGV[1]);
local windowTime = tonumber(ARGV[2]);
local maxCount = tonumber(ARGV[3]);
local expiredWindowTime = tonumber(ARGV[4])
local value = ARGV[5];
local key = KEYS[1];
-- 獲取當前視窗的請求標誌個數
local count = redis.call('ZCARD', key)
-- 比較當前已經請求的數量是否大於視窗最大請求數
if count >= maxCount then
    -- 如果大於最大請求數
    -- 刪除過期的請求標誌 釋放視窗空間 等同於滑動時間視窗向前滑動
    redis.call('ZREMRANGEBYSCORE', key, 0, expiredWindowTime)
    -- 再次獲取當前視窗的請求標誌個數
    local count = redis.call('ZCARD', key)
    -- 延長過期時間
    redis.call('PEXPIRE', key, windowTime + 1000)
    -- 比較釋放後的大小 是否小於視窗最大請求數
    if count < maxCount then
        -- 返回200代表成功
        return 200
    else
        -- 返回500代表失敗
        return 500
    end
else
     -- 插入當前訪問的訪問標記
    redis.call('ZADD', key, nowTime, value)
    -- 延長過期時間
    redis.call('PEXPIRE', key, windowTime + 1000)
    -- 返回200代表成功
    return 200
end

Redis接收到請求,開始執行 lua 指令碼。根據 Key 找到對應的Zset也就是時間窗。獲取當前時間窗內請求標記的數量,如果小於最大視窗允許訪問最大次數,直接插入最新的請求標記 設定標記的score = 1642403014820 ms 。延長該視窗的過期時間,並返回成功。如下圖所示:

如果獲取到當前時間窗內請求標記的數量,大於或者等於視窗最大允許請求數。如下圖所示,獲取到當前時間窗內請求標記的數量為6,大於視窗最大允許請求數5。

則先根據時間窗大小刪除視窗中已經過期的請求。當前請求的 score = 1642403014820 ms 時間視窗大小的 10000ms。那過期時間就是1642403014820 - 10000 = 1642403004820;那就刪除 score < 1642403004820 的節點。

刪除完成後,再次獲取當前視窗中請求標記數量,可以看到當前數量為1小於視窗最大請求數。插入最新的請求標記 score = 1642403014820 ms 。延長該視窗的過期時間,並返回成功。

大佬:現在是否明白了一些?

小A:明白了,簡單總結一下就是,利用Redis中現成的資料結構ZSet(有序集合)來做時間視窗。集合中的排序值為請求發生時的時間戳。在請求發生時,

統計時間視窗中總的請求數,如果總請求數小於視窗允許最大請求數,就插入一個請求標記,也就相當於視窗中請求數加一。如果總請求數大於或者等於視窗允許最大請求數,則需要刪除過期的統計,以便釋放足夠的空間。刪除的方式就是先計算出視窗的前邊界,也就是已經失效的最大時間。根據這個時間戳,然後利用Zset原生的刪除方法 ZREMRANGEBYSCORE key min max 刪除小於最大失效時間的請求標記,其實這裡的刪除過期資料也就等同於滑動時間視窗向前滑動。刪除完成後再次統計下視窗中剩餘的請求數是否大於或者等於視窗最大請求數,如果大於就直接返回失敗,告訴客戶端拒絕該請求。如果小於就插入當前請求的請求標記 score 為當前請求的請求時間戳。至此完成了一次限流請求。可是我不明白為什麼要使用 lua 指令碼呢?這不是增加了維護成本嗎?

大佬:不錯。基本原理就是這樣的,至於為什麼使用 lua 指令碼。那為了原子的執行多個命令和限流的判斷邏輯。防止在你執行刪除或者獲取總數的命令時,其他人也在執行導致資料不準確,從而使限流失敗。

小A:嗯嗯。明白了。

大佬:不過這個限流也存在不足。比如需要設定一個10秒內允許訪問100萬次的請求,它就不合適,因為這樣視窗中會有100萬個請求機率,會消耗大量的記憶體空間。切記!不要盲目使用,要根據自己業務量來綜合考量。

小A:好的,大佬,記住了!

大佬:瞭解完這個,接下來我們來看一下JAVA程式碼怎麼寫。

首先我們是通過 Spring AOP 和 標記註解 @CurrentLimiting 來實現限流方案的。

    @GetMapping("getId")
    @CurrentLimiting(value = "getId",
            // ErrorCallback 是錯誤回撥 callback 是 bean Name callbackClass 是實現類 Class
            errorCallback = @ErrorCallback(callback = "redisCurrentLimitingDegradeCallbackImpl", callbackClass = RedisCurrentLimitingErrorCallbackImpl.class)
,
            // DegradeCallback 是降級回撥 callback 是 bean Name callbackClass 是實現類 Class
            degradeCallback = @DegradeCallback(callback = "redisCurrentLimitingDegradeCallbackImpl", callbackClass = RedisCurrentLimitingDegradeCallbackImpl.class))
    public Integer getId(){
        return 1;
    }

下面介紹AOP切面類:

package com.raiden.redis.current.limiter.aop;
​
import com.raiden.redis.current.limiter.RedisCurrentLimiter;
import com.raiden.redis.current.limiter.annotation.CurrentLimiting;
import com.raiden.redis.current.limiter.annotation.DegradeCallback;
import com.raiden.redis.current.limiter.annotation.ErrorCallback;
import com.raiden.redis.current.limiter.callbock.RedisCurrentLimitingDegradeCallback;
import com.raiden.redis.current.limiter.chain.ErrorCallbackChain;
import com.raiden.redis.current.limiter.info.RedisCurrentLimiterInfo;
import com.raiden.redis.current.limiter.properties.RedisCurrentLimiterProperties;
import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.annotation.Around;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.annotation.Pointcut;
import org.aspectj.lang.reflect.MethodSignature;
import org.springframework.context.ApplicationContext;
import org.springframework.core.annotation.AnnotationUtils;
​
import java.lang.reflect.Method;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
​
​
/**
 * @建立人:Raiden
 * @Descriotion:
 * @Date:Created in 23:51 2020/8/27
 * @Modified By:
 * 限流AOP切面類
 */
@Aspect
public class RedisCurrentLimitingAspect {
​
    private Map<String, RedisCurrentLimiterInfo> config;
    private ApplicationContext context;
​
    private ConcurrentHashMap<Method, ErrorCallbackChain> errorCallbackChainCache;
​
    private ConcurrentHashMap<Method, RedisCurrentLimitingDegradeCallback> degradeCallbackCache;
​
    public RedisCurrentLimitingAspect(ApplicationContext context, RedisCurrentLimiterProperties properties){
        this.context = context;
        this.config = properties.getConfig();
        this.errorCallbackChainCache = new ConcurrentHashMap<>();
        this.degradeCallbackCache = new ConcurrentHashMap<>();
    }
​
    @Pointcut("@annotation(com.raiden.redis.current.limiter.annotation.CurrentLimiting) || @within(com.raiden.redis.current.limiter.annotation.CurrentLimiting)")
    public void intercept(){}
​
    @Around("intercept()")
    public Object currentLimitingHandle(ProceedingJoinPoint joinPoint) throws Throwable{
        MethodSignature signature = (MethodSignature) joinPoint.getSignature();
        Method method = signature.getMethod();
        CurrentLimiting annotation = AnnotationUtils.findAnnotation(method, CurrentLimiting.class);
        if (annotation == null){
            annotation = method.getDeclaringClass().getAnnotation(CurrentLimiting.class);
        }
        String path = annotation.value();
        //如果沒有配置 資源 直接 放過
        //如果沒有找到限流配置 也放過
        RedisCurrentLimiterInfo info;
        if (path != null && !path.isEmpty() && (info = config.get(path)) != null){
            try {
                //檢視是否需要限流
                boolean allowAccess = RedisCurrentLimiter.isAllowAccess(path, info.getWindowTime(), info.getWindowTimeUnit(), info.getMaxCount());
                if (allowAccess){
                    return joinPoint.proceed();
                }else {
                    //獲取降級處理器
                    RedisCurrentLimitingDegradeCallback currentLimitingDegradeCallback = degradeCallbackCache.get(method);
                    if (currentLimitingDegradeCallback == null){
                        degradeCallbackCache.putIfAbsent(method, getRedisCurrentLimitingDegradeCallback(annotation));
                    }
                    currentLimitingDegradeCallback = degradeCallbackCache.get(method);
                    //呼叫降級回撥
                    return currentLimitingDegradeCallback.callback();
                }
            }catch (Throwable e){
                //如果報錯 交給 錯誤回撥
                ErrorCallbackChain errorCallbackChain = errorCallbackChainCache.get(method);
                if (errorCallbackChain == null){
                    ErrorCallback[] errorCallbacks = annotation.errorCallback();
                    if (errorCallbacks.length == 0){
                        throw e;
                    }
                    //放入錯誤回撥快取
                    errorCallbackChainCache.putIfAbsent(method, new ErrorCallbackChain(errorCallbacks, context));
                }
                errorCallbackChain = errorCallbackChainCache.get(method);
                return errorCallbackChain.execute(e);
            }
        }
        return joinPoint.proceed();
    }
​
    private RedisCurrentLimitingDegradeCallback getRedisCurrentLimitingDegradeCallback(CurrentLimiting annotation) throws IllegalAccessException, InstantiationException {
        DegradeCallback degradeCallback = annotation.degradeCallback();
        String callback = degradeCallback.callback();
        if (callback == null || callback.isEmpty()){
            return degradeCallback.callbackClass().newInstance();
        }else {
            return context.getBean(degradeCallback.callback(), degradeCallback.callbackClass());
        }
    }
}
RedisCurrentLimiter Redis時間窗限流執行類:
package com.raiden.redis.current.limiter;
​
import org.springframework.core.io.ClassPathResource;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.core.script.DefaultRedisScript;
import org.springframework.scripting.support.ResourceScriptSource;
​
import java.net.Inet4Address;
import java.net.UnknownHostException;
import java.util.Arrays;
import java.util.concurrent.TimeUnit;
​
​
/**
 * @建立人:Raiden
 * @Descriotion:
 * @Date:Created in 23:51 2020/8/27
 * @Modified By:
 */
public final class RedisCurrentLimiter {
​
    private static final String CURRENT_LIMITER = "CurrentLimiter:";;
​
    private static String ip;;
​
    private static RedisTemplate redis;
​
    private static ResourceScriptSource resourceScriptSource;
​
    protected static void init(RedisTemplate redisTemplate){
        if (redisTemplate == null){
            throw new NullPointerException("The parameter cannot be null");
        }
        try {
            ip = Inet4Address.getLocalHost().getHostAddress().replaceAll("\\.", "");
        } catch (UnknownHostException e) {
            throw new RuntimeException(e);
        }
        redis = redisTemplate;
        //lua檔案存放在resources目錄下的redis資料夾內
        resourceScriptSource = new ResourceScriptSource(new ClassPathResource("redis/redis-current-limiter.lua"));
    }
​
    public static boolean isAllowAccess(String path, int windowTime, TimeUnit windowTimeUnit, int maxCount){
        if (redis == null){
            throw new NullPointerException("Redis is not initialized !");
        }
        if (path == null || path.isEmpty()){
            throw new IllegalArgumentException("The path parameter cannot be empty !");
        }
        //獲取 key
        final String key = new StringBuffer(CURRENT_LIMITER).append(path).toString();
        //獲取當前時間戳
        long now = System.currentTimeMillis();
        //獲取視窗時間 並轉換為 毫秒
        long windowTimeMillis = windowTimeUnit.toMillis(windowTime);
        //呼叫lua指令碼並執行
        DefaultRedisScript<Long> redisScript = new DefaultRedisScript<>();
        //設定返回型別是Long
        redisScript.setResultType(Long.class);
        //設定 lua 指令碼原始碼
        redisScript.setScriptSource(resourceScriptSource);
        //執行 lua 指令碼
        Long result = (Long) redis.execute(redisScript, Arrays.asList(key), now, windowTimeMillis, maxCount, now - windowTimeMillis, createValue(now));
        //獲取到返回值
        return result.intValue() == 200;
    }
​
    private static String createValue(long now){
        return new StringBuilder(ip).append(now).append(Math.random() * 100).toString();
    }
}
RedisCurrentLimiterConfiguration Redis滑動時間窗限流配置類:
package com.raiden.redis.current.limiter.config;
​
import com.fasterxml.jackson.annotation.JsonAutoDetect;
import com.fasterxml.jackson.annotation.PropertyAccessor;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.datatype.jdk8.Jdk8Module;
import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule;
import com.fasterxml.jackson.module.paramnames.ParameterNamesModule;
import com.raiden.redis.current.limiter.RedisCurrentLimiterInit;
import com.raiden.redis.current.limiter.aop.RedisCurrentLimitingAspect;
import com.raiden.redis.current.limiter.common.PublicString;
import com.raiden.redis.current.limiter.properties.RedisCurrentLimiterProperties;
import org.springframework.boot.autoconfigure.condition.ConditionalOnBean;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.context.ApplicationContext;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Primary;
import org.springframework.data.redis.connection.RedisConnectionFactory;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.serializer.Jackson2JsonRedisSerializer;
import org.springframework.data.redis.serializer.RedisSerializer;
import org.springframework.data.redis.serializer.StringRedisSerializer;
​
/**
 * @建立人:Raiden
 * @Descriotion:
 * @Date:Created in 23:51 2020/8/27
 * @Modified By:
 */
@EnableConfigurationProperties(RedisCurrentLimiterProperties.class)
// 判斷 配置中是否 有 redis.current-limiter.enabled = true 才載入
@ConditionalOnProperty(
        name = {"redis.current-limiter.enabled"}
)
public class RedisCurrentLimiterConfiguration {
​
    /**
     * AOP 切面配置
     * @param properties
     * @param context
     * @return
     */
    @Bean
    public RedisCurrentLimitingAspect redisCurrentLimitingAspect(RedisCurrentLimiterProperties properties, ApplicationContext context){
        return new RedisCurrentLimitingAspect(context, properties);
    }
​
    /**
     * RedisTemplate配置
     * @param redisConnectionFactory
     * @return
     */
    @Bean(PublicString.REDIS_CURRENT_LIMITER_REDIS_TEMPLATE)
    public RedisTemplate<String, Object> redisTemplate(RedisConnectionFactory redisConnectionFactory) {
        // 設定序列化
        Jackson2JsonRedisSerializer<Object> jackson2JsonRedisSerializer = new Jackson2JsonRedisSerializer<Object>(
                Object.class);
        ObjectMapper om = new ObjectMapper()
                .registerModule(new ParameterNamesModule())
                .registerModule(new Jdk8Module())
                .registerModule(new JavaTimeModule());
        om.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.ANY);
        om.enableDefaultTyping(ObjectMapper.DefaultTyping.NON_FINAL);
        jackson2JsonRedisSerializer.setObjectMapper(om);
        // 配置redisTemplate
        RedisTemplate<String, Object> redisTemplate = new RedisTemplate<String, Object>();
        redisTemplate.setConnectionFactory(redisConnectionFactory);
        RedisSerializer stringSerializer = new StringRedisSerializer();
        // key序列化
        redisTemplate.setKeySerializer(stringSerializer);
        // value序列化
        redisTemplate.setValueSerializer(jackson2JsonRedisSerializer);
        // Hash key序列化
        redisTemplate.setHashKeySerializer(stringSerializer);
        // Hash value序列化
        redisTemplate.setHashValueSerializer(jackson2JsonRedisSerializer);
        redisTemplate.afterPropertiesSet();
        return redisTemplate;
    }
​
    /**
     * 生成 Redis 滑動時間窗限流器 初始化類
     * @param redisTemplate
     * @return
     */
    @Bean
    @ConditionalOnBean(name = PublicString.REDIS_CURRENT_LIMITER_REDIS_TEMPLATE)
    public RedisCurrentLimiterInit redisCurrentLimiterInit(RedisTemplate<String, Object> redisTemplate){
        return new RedisCurrentLimiterInit(redisTemplate);
    }
}
大佬:好了,Java程式碼實戰也帶你看過了,現在你學廢了嗎?如果覺得好的話請點個贊喲。

程式碼github地址:

https://github.com/RaidenXin/redis-current-limiter.git

有需要的同學可以拉下來看看。