Coverage for /home/runner/work/viur-core/viur-core/viur/src/viur/core/ratelimit.py: 0%

68 statements  

« prev     ^ index     » next       coverage.py v7.6.3, created at 2024-10-16 22:16 +0000

1from viur.core import current, db, errors, utils 

2from viur.core.tasks import PeriodicTask, DeleteEntitiesIter 

3import typing as t 

4from datetime import timedelta 

5 

6 

class RateLimit:
    """
    Restricts access to a resource to *maxRate* calls within a window of *minutes* minutes.

    Usage: create an instance of this class in your module's ``__init__`` function. Then call
    :meth:`isQuotaAvailable` (or :meth:`assertQuotaIsAvailable`) before executing the action to
    check whether quota is available, and call :meth:`decrementQuota` after executing the action.

    Usage counters are stored as datastore entities of kind :attr:`rateLimitKind`, one entity per
    resource, endpoint (IP or user) and time step. The window is split into at most five steps.
    :meth:`isQuotaAvailable` sums the counters of the current step and the preceding ones.
    """

    # Datastore kind used for the usage-counter entities (also queried by cleanOldRateLocks).
    rateLimitKind = "viur-ratelimit"

    def __init__(self, resource: str, maxRate: int, minutes: int, method: t.Literal["ip", "user"]):
        """
        Initializes a new RateLimit gate.

        :param resource: Name of the resource to protect
        :param maxRate: Amount of tries allowed in the given time-span
        :param minutes: Length of the time-span in minutes; must be at least 1
        :param method: Lock by IP or by the current user
        :raises ValueError: if *minutes* is smaller than 1 or *method* is neither "ip" nor "user"
        """
        super().__init__()
        if method not in ("ip", "user"):
            raise ValueError("method must be 'ip' or 'user'")
        if minutes < 1:
            # Otherwise steps would be <= 0 and secondsPerStep a division by zero.
            raise ValueError("minutes must be at least 1")
        self.resource = resource
        self.maxRate = maxRate
        self.minutes = minutes
        self.steps = min(minutes, 5)
        self.secondsPerStep = 60 * (float(minutes) / float(self.steps))
        self.useUser = method == "user"

    def _getEndpointKey(self) -> db.Key | str:
        """
        :warning:
            It's invalid to call _getEndpointKey if method is set to user and there's no user
            logged in!

        :return: the key associated with the current endpoint. This is the key of the current
            user, or the client's IP address. IPv4 addresses (including IPv4-mapped IPv6
            addresses) are returned as dotted quad. IPv6 addresses are reduced to their first
            64 bits in fully expanded notation, because the remaining interface id is easily
            controlled by the user. A remote address that cannot be parsed as an IP address is
            returned unchanged.
        """
        import ipaddress  # only needed for IP-based limits

        if self.useUser:
            user = current.user.get()
            assert user, "Cannot decrement usage from guest!"
            return user["key"]

        remoteAddr = current.request.get().request.remote_addr
        try:
            address = ipaddress.ip_address(remoteAddr)
        except ValueError:
            # Not a parseable IP address; fall back to the raw value.
            return remoteAddr

        if address.version == 4:
            return str(address)
        if address.ipv4_mapped is not None:
            # ::ffff:a.b.c.d is an IPv4 client; limit it by that address instead of
            # lumping all mapped clients into the shared ::ffff prefix.
            return str(address.ipv4_mapped)
        # Build the /64 prefix from the integer value so every textual representation
        # (compressed, expanded, with zone id) of the same network yields the same key.
        prefix = f"{int(address) >> 64:016x}"
        return ":".join(prefix[i:i + 4] for i in range(0, 16, 4))

    def _getTimeKey(self, dateTime) -> str:
        """
        :param dateTime: the (UTC) point in time to build the key for
        :return: the lock period containing *dateTime*, formatted as
            ``<YYYY-MM-DD>-<index of the step within that day>``
        """
        midnight = dateTime.replace(hour=0, minute=0, second=0, microsecond=0)
        step = int((dateTime - midnight).total_seconds() / self.secondsPerStep)
        return f"{dateTime.strftime('%Y-%m-%d')}-{step}"

    def _getCurrentTimeKey(self) -> str:
        """
        :return: the current lock period, used as the last part of the datastore key
        """
        return self._getTimeKey(utils.utcNow())

    def decrementQuota(self) -> None:
        """
        Removes one attempt from the pool of available quota for that user/ip.

        Increments (or creates) the counter entity for the current lock period inside a
        transaction, and sets its expiry to twice the window length from now.
        """

        def updateTxn(cacheKey: str) -> None:
            key = db.Key(self.rateLimitKind, cacheKey)
            obj = db.Get(key)
            if obj is None:
                obj = db.Entity(key)
                obj["value"] = 0
            obj["value"] += 1
            obj["expires"] = utils.utcNow() + timedelta(minutes=2 * self.minutes)
            db.Put(obj)

        lockKey = f"{self.resource}-{self._getEndpointKey()}-{self._getCurrentTimeKey()}"
        db.RunInTransaction(updateTxn, lockKey)

    def isQuotaAvailable(self) -> bool:
        """
        Checks if there's currently quota available for the current user/ip.

        Sums the non-expired counters of the current lock period and the preceding
        ``steps - 1`` periods. Periods that belong to the previous day are included.

        :return: True if fewer than *maxRate* attempts were recorded in the window, False otherwise
        """
        endPoint = self._getEndpointKey()
        currentDateTime = utils.utcNow()
        cacheKeys = [
            db.Key(
                self.rateLimitKind,
                # Step back in time instead of subtracting from the step index, so that
                # periods before midnight resolve to the previous day's keys.
                f"{self.resource}-{endPoint}-"
                f"{self._getTimeKey(currentDateTime - timedelta(seconds=x * self.secondsPerStep))}",
            )
            for x in range(self.steps)
        ]
        entries = db.Get(cacheKeys)
        used = sum(entry["value"] for entry in entries if entry and currentDateTime < entry["expires"])
        # decrementQuota() is called *after* the action, so `used` attempts already happened;
        # one more is only allowed while that count is still below maxRate.
        return used < self.maxRate

    def assertQuotaIsAvailable(self, setRetryAfterHeader: bool = True) -> bool:
        """
        Assert quota is available.

        If no quota is available, a :class:`viur.core.errors.TooManyRequests` exception
        will be raised.

        :param setRetryAfterHeader: Set the Retry-After header on the current request
            response, if the quota is exceeded.
        :return: True if quota is available.
        :raises: :exc:`viur.core.errors.TooManyRequests`, if no quota is available.
        """
        if self.isQuotaAvailable():
            return True

        if setRetryAfterHeader:
            # Retry-After is in seconds. After one full window (steps * secondsPerStep)
            # every previously counted step has left the window, so this is an upper bound.
            current.request.get().response.headers["Retry-After"] = str(self.minutes * 60)

        raise errors.TooManyRequests(
            f"{self.maxRate} requests allowed per {self.minutes} minute(s). Try again later."
        )

126 

127 

@PeriodicTask(60)
def cleanOldRateLocks(*args, **kwargs) -> None:
    """
    Periodic task that cleans up expired rate-limit counters.

    Builds a query over all entities of kind ``RateLimit.rateLimitKind`` whose ``expires``
    timestamp lies before the current UTC time. It then starts a ``DeleteEntitiesIter``
    over that query (presumably deleting the matches in the background; see tasks module).
    """
    now = utils.utcNow()
    expiredLocks = db.Query(RateLimit.rateLimitKind).filter("expires <", now)
    DeleteEntitiesIter.startIterOnQuery(expiredLocks)