Coverage for /home/runner/work/viur-core/viur-core/viur/src/viur/core/ratelimit.py: 0%
68 statements
« prev ^ index » next coverage.py v7.6.3, created at 2024-10-16 22:16 +0000
1from viur.core import current, db, errors, utils
2from viur.core.tasks import PeriodicTask, DeleteEntitiesIter
3import typing as t
4from datetime import timedelta
class RateLimit:
    """
    Restricts access to certain functions to *maxRate* calls per time-span of *minutes* minutes.

    Usage: Create an instance of this class in your module's ``__init__`` function. Then call
    :meth:`isQuotaAvailable` (or :meth:`assertQuotaIsAvailable`) before executing the action to check
    whether there is quota available, and call :meth:`decrementQuota` after executing the action.

    Usage is counted in datastore entities of kind :attr:`rateLimitKind`, one per resource, endpoint
    (client address or user key) and time step. The time-span is split into ``min(minutes, 5)`` steps;
    a quota check sums the counters of the current step and the preceding ones.
    """
    # Datastore kind of the counter entities (also queried by the cleanup task).
    rateLimitKind = "viur-ratelimit"

    def __init__(self, resource: str, maxRate: int, minutes: int, method: t.Literal["ip", "user"]):
        """
        Initializes a new RateLimit gate.

        :param resource: Name of the resource to protect
        :param maxRate: Amount of tries allowed in the given time-span
        :param minutes: Length of the time-span in minutes; must be at least 1
        :param method: Lock by IP or by the current user
        """
        super().__init__()
        # Without this check minutes == 0 would fail with a ZeroDivisionError below.
        assert minutes >= 1, "minutes must be at least 1"
        assert method in ["ip", "user"], "method must be 'ip' or 'user'"
        self.resource = resource
        self.maxRate = maxRate
        self.minutes = minutes
        self.steps = min(minutes, 5)
        self.secondsPerStep = 60 * (float(minutes) / float(self.steps))
        self.useUser = method == "user"

    @staticmethod
    def _normalizeRemoteAddr(remoteAddr: str) -> str:
        """
        Reduces a client address to the part the quota is accounted to.

        - IPv4 addresses are returned in their canonical form.
        - IPv4-mapped IPv6 addresses (``::ffff:a.b.c.d``) are returned as the embedded IPv4 address,
          so those clients don't all share one bucket.
        - Other IPv6 addresses are reduced to their first 64 bits (four zero-padded hex groups),
          because the lower 64 bits (interface id) are easily controlled by the user.
        - Values that are not valid IP addresses are returned unchanged.

        :param remoteAddr: The client address as reported by the request.
        :return: The normalized endpoint string.
        """
        import ipaddress

        try:
            addr = ipaddress.ip_address(remoteAddr)
        except ValueError:
            return remoteAddr

        if isinstance(addr, ipaddress.IPv6Address):
            if addr.ipv4_mapped is not None:
                return str(addr.ipv4_mapped)
            # Derived from the integer value rather than ``exploded`` so a zone id ("%eth0") can't interfere.
            hexDigits = f"{int(addr):032x}"
            return ":".join(hexDigits[i:i + 4] for i in range(0, 16, 4))

        return str(addr)

    def _getEndpointKey(self) -> "db.Key | str":
        """
        Determines the endpoint the quota is accounted to.

        :warning:
            It's invalid to call _getEndpointKey if method is set to user and there's no user logged in!

        :return: the key of the current user if method is "user", otherwise the normalized client
            address (IPv4 address, or the first 64 bits of an IPv6 address)
        """
        if self.useUser:
            user = current.user.get()
            assert user, "Cannot decrement usage from guest!"
            return user["key"]

        return self._normalizeRemoteAddr(current.request.get().request.remote_addr)

    def _getTimeKey(self, dateTime, stepsBack: int = 0) -> str:
        """
        Builds the time part of a rate-limit cache key.

        Each day is split into steps of ``secondsPerStep`` seconds, numbered from 0 at midnight.
        The key has the form ``YYYY-MM-DD-<step>``.

        :param dateTime: The point in time (a datetime) to build the key for.
        :param stepsBack: Number of steps to go back from the step containing *dateTime*. Going back
            past midnight continues with the last steps of the previous day(s).
        :return: The time key of the requested step.
        """
        midnight = dateTime.replace(hour=0, minute=0, second=0, microsecond=0)
        step = int((dateTime - midnight).total_seconds() / self.secondsPerStep) - stepsBack
        # Number of distinct step indices per day, computed like the step of the day's last microsecond
        # (the last step of a day may be shorter than the others).
        lastInstant = timedelta(days=1, microseconds=-1).total_seconds()
        stepsPerDay = int(lastInstant / self.secondsPerStep) + 1
        # divmod floors, so a negative step moves into the previous day(s) with a valid index there.
        dayOffset, step = divmod(step, stepsPerDay)
        day = midnight + timedelta(days=dayOffset)
        return f"{day:%Y-%m-%d}-{step}"

    def _getCurrentTimeKey(self) -> str:
        """
        :return: the time key of the current step, used as the last part of the counter entity key
        """
        return self._getTimeKey(utils.utcNow())

    def decrementQuota(self) -> None:
        """
        Removes one attempt from the pool of available quota for the current user/IP.

        Internally this increments the counter entity of the current time step inside a transaction
        and sets its expiry to twice the configured time-span from now.
        """

        def updateTxn(cacheKey: str) -> None:
            key = db.Key(self.rateLimitKind, cacheKey)
            obj = db.Get(key)
            if obj is None:
                obj = db.Entity(key)
                obj["value"] = 0
            obj["value"] += 1
            obj["expires"] = utils.utcNow() + timedelta(minutes=2 * self.minutes)
            db.Put(obj)

        lockKey = f"{self.resource}-{self._getEndpointKey()}-{self._getCurrentTimeKey()}"
        db.RunInTransaction(updateTxn, lockKey)

    def isQuotaAvailable(self) -> bool:
        """
        Checks if there's currently quota available for the current user/IP.

        Sums the non-expired counters of the current step and the ``steps - 1`` preceding ones
        (crossing midnight if necessary).

        :return: True if fewer than *maxRate* attempts were recorded in the time-span, False otherwise
        """
        endPoint = self._getEndpointKey()
        currentDateTime = utils.utcNow()
        cacheKeys = [
            db.Key(self.rateLimitKind, f"{self.resource}-{endPoint}-{self._getTimeKey(currentDateTime, x)}")
            for x in range(self.steps)
        ]
        entities = db.Get(cacheKeys)
        used = sum(
            entity["value"] for entity in entities
            if entity and currentDateTime < entity["expires"]
        )
        # Strictly less: after maxRate recorded attempts no further attempt is allowed.
        return used < self.maxRate

    def assertQuotaIsAvailable(self, setRetryAfterHeader: bool = True) -> bool:
        """Assert quota is available.

        If no quota is available a :class:`viur.core.errors.TooManyRequests`
        exception will be raised.

        :param setRetryAfterHeader: Set the Retry-After header on the
            current request response, if the quota is exceeded.
        :return: True if quota is available.
        :raises: :exc:`viur.core.errors.TooManyRequests`, if no quota is available.
        """
        if self.isQuotaAvailable():
            return True

        if setRetryAfterHeader:
            # Upper bound in seconds: after the full time-span every currently counted step has left the window.
            current.request.get().response.headers["Retry-After"] = str(int(self.minutes * 60))

        raise errors.TooManyRequests(
            f"{self.maxRate} requests allowed per {self.minutes} minute(s). Try again later."
        )
@PeriodicTask(60)
def cleanOldRateLocks(*args, **kwargs) -> None:
    """
    Periodic cleanup of rate-limit counters, registered via ``PeriodicTask(60)``.

    Starts a ``DeleteEntitiesIter`` over all entities of kind ``RateLimit.rateLimitKind``
    whose ``expires`` timestamp lies before the current time.
    """
    now = utils.utcNow()
    expiredLocks = db.Query(RateLimit.rateLimitKind).filter("expires <", now)
    DeleteEntitiesIter.startIterOnQuery(expiredLocks)