1. 程式人生 > >Django框架(二十二)—— Django rest_framework-頻率元件

Django框架(二十二)—— Django rest_framework-頻率元件

目錄

頻率元件

一、作用

為了控制使用者對某個url請求的頻率,比如,一分鐘以內,只能訪問三次

二、自定義頻率類

# 寫一個頻率認證類
class MyThrottle:
    visit_dic = {}
    visit_time = None

    def __init__(self):
        self.ctime = time.time()
    
    # 重寫allow_request()方法
    # request是request物件,view是檢視類,可以對檢視類進行操作
    def allow_request(self, request, view):
        '''
            (1)取出訪問者ip
            (2)判斷當前ip不在訪問字典裡,新增進去,並且直接返回True,表示第一次訪問,在字典裡,繼續往下走
            (3)迴圈判斷當前ip的列表,有值,並且當前時間減去列表的最後一個時間大於60s,把這種資料pop掉,這樣列表中只有60s以內的訪問時間,
            (4)判斷,當列表小於3,說明一分鐘以內訪問不足三次,把當前時間插入到列表第一個位置,返回True,順利通過
            (5)當大於等於3,說明一分鐘內訪問超過三次,返回False驗證失敗
            visit_dic = {ip1:[time2, time1, time0],
                         ip2:[time1, time0],
                        }
        '''
        
        # 取出訪問者ip,ip可以從請求頭中取出來
        ip = request.META.get('REMOTE_ADDR')
        # 判斷該次請求的ip是否在地點中
        if ip in self.visit_dic:
            # 當存在字典中時,取出這個ip訪問時間的列表
            visit_time = self.visit_dic[ip]
            self.visit_time = visit_time
            while visit_time:
                # 當訪問時間列表中有值,時間間隔超過60,就將那個歷史時間刪除
                if self.ctime - visit_time[-1] > 60:
                    visit_time.pop()
                else:
                    # 當pop到一定時,時間間隔不大於60了,退出迴圈,此時得到的是60s內訪問的時間記錄
                    break
                    
            # while迴圈等價於
            # while visit_time and ctime - visit_time[-1] > 60:
            #     visit_time.pop()
            
            # 列表長度可表示訪問次數,根據原始碼,可以得出,返回值是Boolean型別
            if len(visit_time) >= 3:
                return False
            else:
                # 如果60秒內訪問次數小於3次,將當前訪問的時間記錄下來
                visit_time.insert(0, self.ctime)
                return True
        else:
            # 如果字典中沒有當前訪問ip,將ip加到字典中
            self.visit_dic[ip] = [self.ctime, ]
            return True

    # 獲取下次距訪問的時間
    def wait(self):
        return 60 - (self.ctime - self.visit_time[-1])
    
# view層
from app01 import MyAuth
from rest_framework import exceptions

class Book(APIView):
    # 區域性使用頻率控制
    throttle_classes = [MyAuth.MyThrottle, ]
    
    def get(self,request):
        return HttpResponse('ok')
    
    # 重寫丟擲異常的方法 throttled
    def throttled(self, request, wait):
        class MyThrottled(exceptions.Throttled):
            default_detail = '下次訪問'
            extra_detail_singular = '還剩 {wait} 秒.'
            extra_detail_plural = '還剩 {wait} 秒'

        raise MyThrottled(wait)
        

三、內建的訪問頻率控制類

from rest_framework.throttling import SimpleRateThrottle

# 寫一個頻率控制類,繼承SimpleRateThrottle類
class MyThrottle(SimpleRateThrottle):
    # 配置scope,通過scope到setting中找到 3/m
    scope = 'ttt'

    def get_cache_key(self, request, view):
        # 返回ip,效果和 get_ident() 方法相似
        # ip = request.META.get('REMOTE_ADDR')
        # return ip

        # get_ident 返回的就是ip地址
        return self.get_ident(request)
# view層檢視類
class Book(APIView):
    throttle_classes = [MyAuth.MyThrottle, ]

    def get(self, request):
        return HttpResponse('ok')

    def throttled(self, request, wait):
        class MyThrottled(exceptions.Throttled):
            default_detail = '下次訪問'
            extra_detail_singular = '還剩 {wait} 秒.'
            extra_detail_plural = '還剩 {wait} 秒'

        raise MyThrottled(wait)
# setting中配置
REST_FRAMEWORK = {
    'DEFAULT_THROTTLE_RATES': {
        'ttt': '10/m'
    }
}
  • 因此,要實現10分鐘允許訪問六次,可以繼承SimpleRateThrottle類,然後重寫parse_rate()方法,將duration中key對應的值改為自己需要的值

四、全域性、區域性使用

1、全域性使用

在setting中配置

REST_FRAMEWORK = {
    'DEFAULT_THROTTLE_CLASSES': ['app01.MyAuth.MyThrottle', ],
}

2、區域性使用

在檢視類中重定義throttle_classes

throttle_classes = [MyAuth.MyThrottle, ]

3、區域性禁用

在檢視類中重定義throttle_classes為一個空列表

throttle_classes = []

五、原始碼分析

1、as_view -----> view ------> dispatch ------> initial ----> check_throttles 頻率控制

2、self.check_throttles(request)

    def check_throttles(self, request):
        """
        Check if request should be throttled.
        Raises an appropriate exception if the request is throttled.
        """
        # (2-----1) get_throttles 由頻率類產生的物件組成的列表
        for throttle in self.get_throttles():
            if not throttle.allow_request(request, self):
                # (4)異常資訊的處理
                self.throttled(request, throttle.wait())

(2-----1) self.get_throttles()

    def get_throttles(self):
        """
        Instantiates and returns the list of throttles that this view uses.
        """
        return [throttle() for throttle in self.throttle_classes]

3、allow_request()

自身、所在類找都沒有,去父類中找

class SimpleRateThrottle(BaseThrottle):
    
    cache = default_cache
    timer = time.time
    cache_format = 'throttle_%(scope)s_%(ident)s'
    scope = None
    THROTTLE_RATES = api_settings.DEFAULT_THROTTLE_RATES

    def __init__(self):
        if not getattr(self, 'rate', None):
            self.rate = self.get_rate()
        self.num_requests, self.duration = self.parse_rate(self.rate)
    
    
    def parse_rate(self, rate):
        if rate is None:
            return (None, None)
        num, period = rate.split('/')
        num_requests = int(num)
        duration = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}[period[0]]
        return (num_requests, duration)
    
    
    def allow_request(self, request, view):
        if self.rate is None:
            return True
        # (3-----1) get_cache_key就是要重寫的方法,若不重寫,會直接丟擲異常
        self.key = self.get_cache_key(request, view)
        if self.key is None:
            return True

        self.history = self.cache.get(self.key, [])
        self.now = self.timer()

        # Drop any requests from the history which have now passed the
        # throttle duration
        while self.history and self.history[-1] <= self.now - self.duration:
            self.history.pop()
        if len(self.history) >= self.num_requests:
            return self.throttle_failure()
        return self.throttle_success()

    
    # 返回距下一次能請求的時間
    def wait(self):
        """
        Returns the recommended next request time in seconds.
        """
        if self.history:
            remaining_duration = self.duration - (self.now - self.history[-1])
        else:
            remaining_duration = self.duration

(3-----1) self.get_cache_key(request, view)

    def get_cache_key(self, request, view):
        """
        Should return a unique cache-key which can be used for throttling.
        Must be overridden.

        May return `None` if the request should not be throttled.
        """
        raise NotImplementedError('.get_cache_key() must be overridden')

4、self.throttled(request, throttle.wait()) --------> 丟擲異常

    def throttled(self, request, wait):
        """
        If request is throttled, determine what kind of exception to raise.
        """
        raise exceptions.Throttled(wait)

(4------1)raise exceptions.Throttled(wait) -------> 異常資訊

class Throttled(APIException):
    status_code = status.HTTP_429_TOO_MANY_REQUESTS
    # 重寫下面三個變數就可以修改顯示的異常資訊,例如用中文顯示異常資訊
    default_detail = _('Request was throttled.')
    extra_detail_singular = 'Expected available in {wait} second.'
    extra_detail_plural = 'Expected available in {wait} seconds.'
    default_code = 'throttled'

    def __init__(self, wait=None, detail=None, code=None):
        if detail is None:
            detail = force_text(self.default_detail)
        if wait is not None:
            wait = math.ceil(wait)
            detail = ' '.join((
                detail,
                force_text(ungettext(self.extra_detail_singular.format(wait=wait),
                                     self.extra_detail_plural.format(wait=wait),
                                     wait))))
        self.wait = wait
        super(Throttled, self).__init__(detail, code)