Coverage for /home/runner/work/viur-core/viur-core/viur/src/viur/core/ratelimit.py: 0%
69 statements
« prev ^ index » next coverage.py v7.6.10, created at 2025-02-07 19:28 +0000
« prev ^ index » next coverage.py v7.6.10, created at 2025-02-07 19:28 +0000
1import datetime
3from viur.core import current, db, errors, utils
4from viur.core.tasks import PeriodicTask, DeleteEntitiesIter
5import typing as t
6from datetime import timedelta
class RateLimit:
    """
    Restricts an action to at most *maxRate* executions per endpoint within a sliding
    window of *minutes* minutes.

    The endpoint is either the client's IP address (IPv6 addresses are reduced to their
    /64 prefix) or the currently logged-in user. The window is split into ``min(minutes, 5)``
    steps, and each step has its own counter entity of kind :attr:`rateLimitKind`.

    Usage: Create an instance of this object in your module's ``__init__`` function. Then call
    :meth:`isQuotaAvailable` (or :meth:`assertQuotaIsAvailable`) before executing the action
    to check if there is quota available, and call :meth:`decrementQuota` after executing
    the action.
    """
    rateLimitKind = "viur-ratelimit"

    def __init__(self, resource: str, maxRate: int, minutes: int, method: t.Literal["ip", "user"]):
        """
        Initializes a new RateLimit gate.

        :param resource: Name of the resource to protect
        :param maxRate: Number of attempts allowed in the given time-span
        :param minutes: Length of the time-span in minutes
        :param method: Lock by IP or by the current user
        """
        super().__init__()
        self.resource = resource
        self.maxRate = maxRate
        self.minutes = minutes
        # The window is tracked by at most five counters ("steps"), each secondsPerStep long.
        self.steps = min(minutes, 5)
        self.secondsPerStep = 60 * (float(minutes) / float(self.steps))
        assert method in ["ip", "user"], "method must be 'ip' or 'user'"
        self.useUser = method == "user"

    def _getEndpointKey(self) -> db.Key | str:
        """
        :warning:
            It's invalid to call _getEndpointKey if method is set to user and there's no user logged in!

        :return: the key associated with the current endpoint. That is the key of the current
            user, the IPv4 address, or the first four groups (the /64 prefix) of an IPv6 address
            in exploded form. IPv4-mapped IPv6 addresses are treated as their IPv4 address.
        """
        import ipaddress  # only needed for IP-based limiting

        if self.useUser:
            user = current.user.get()
            assert user, "Cannot decrement usage from guest!"
            return user["key"]

        remoteAddr = current.request.get().request.remote_addr
        try:
            address = ipaddress.ip_address(remoteAddr)
        except ValueError:
            # Not a parseable IP address; fall back to using it verbatim.
            return remoteAddr

        if address.version == 4:
            # It's IPv4, simply return that address
            return remoteAddr

        if address.ipv4_mapped is not None:
            # "::ffff:a.b.c.d" is really an IPv4 client; limit it like one.
            return str(address.ipv4_mapped)

        # It's IPv6, so we drop the last 64 bits (interface id), as they are easily controlled
        # by the user. The exploded form guarantees that every textual spelling of the same
        # /64 prefix yields the same key.
        return ":".join(address.exploded.split(":")[:4])

    def _getTimeKey(self, dateTime: datetime.datetime, stepsBack: int = 0) -> str:
        """
        Computes the lock-period key for *dateTime*, shifted *stepsBack* steps into the past.

        The key has the form ``YYYY-MM-DD-<step>``, where ``<step>`` is the index of the step
        within that day. If going back crosses midnight, the key of the corresponding step on
        the previous day(s) is returned, so the window does not reset at day boundaries.

        :param dateTime: The point in time to compute the key for.
        :param stepsBack: Number of steps to go back from the step containing *dateTime*.
        :return: The lock-period key.
        """
        midnight = dateTime.replace(hour=0, minute=0, second=0, microsecond=0)
        step = int((dateTime - midnight).total_seconds() / self.secondsPerStep) - stepsBack

        if step < 0:
            # Number of steps in one day. It is computed with the same formula as above for the
            # last microsecond of a day, so it always matches the keys decrementQuota writes.
            lastMoment = timedelta(days=1) - timedelta(microseconds=1)
            stepsPerDay = int(lastMoment.total_seconds() / self.secondsPerStep) + 1
            while step < 0:
                midnight -= timedelta(days=1)
                step += stepsPerDay

        return midnight.strftime("%Y-%m-%d-") + str(step)

    def _getCurrentTimeKey(self) -> str:
        """
        :return: the current lockperiod used in second position of the memcache key
        """
        return self._getTimeKey(utils.utcNow())

    def decrementQuota(self) -> None:
        """
        Removes one attempt from the pool of available Quota for that user/ip
        """

        def updateTxn(cacheKey: str) -> None:
            # Increment the counter of the current step, creating it if necessary.
            key = db.Key(self.rateLimitKind, cacheKey)
            obj = db.Get(key)
            if obj is None:
                obj = db.Entity(key)
                obj["value"] = 0
            obj["value"] += 1
            # Keep the counter well beyond the window; cleanOldRateLocks purges it afterwards.
            obj["expires"] = utils.utcNow() + timedelta(minutes=2 * self.minutes)
            db.Put(obj)

        lockKey = f"{self.resource}-{self._getEndpointKey()}-{self._getCurrentTimeKey()}"
        db.RunInTransaction(updateTxn, lockKey)

    def isQuotaAvailable(self) -> bool:
        """
        Checks if there's currently quota available for the current user/ip.

        Sums the counters of the current step and the ``steps - 1`` preceding ones,
        ignoring counters that have already expired.

        :return: True if fewer than *maxRate* attempts were recorded in the window, False otherwise
        """
        endPoint = self._getEndpointKey()
        now = utils.utcNow()
        cacheKeys = [
            db.Key(self.rateLimitKind, f"{self.resource}-{endPoint}-{self._getTimeKey(now, stepsBack)}")
            for stepsBack in range(self.steps)
        ]
        used = sum(entry["value"] for entry in db.Get(cacheKeys) if entry and now < entry["expires"])
        # Strictly less: the caller is about to perform one more attempt.
        return used < self.maxRate

    def assertQuotaIsAvailable(self, setRetryAfterHeader: bool = True) -> bool:
        """Assert quota is available.

        If no quota is available, a :class:`viur.core.errors.TooManyRequests`
        exception will be raised.

        :param setRetryAfterHeader: Set the Retry-After header on the
            current request response, if the quota is exceeded.
        :return: True if quota is available.
        :raises: :exc:`viur.core.errors.TooManyRequests`, if no quota is available.
        """
        if self.isQuotaAvailable():
            return True
        if setRetryAfterHeader:
            # Counters drop out of the sliding window after at most `minutes` minutes.
            current.request.get().response.headers["Retry-After"] = str(self.minutes * 60)

        raise errors.TooManyRequests(
            f"{self.maxRate} requests allowed per {self.minutes} minute(s). Try again later."
        )
@PeriodicTask(interval=timedelta(hours=1))
def cleanOldRateLocks(*args, **kwargs) -> None:
    """
    Periodic task (interval: one hour) that starts a ``DeleteEntitiesIter`` over all
    rate-limit counter entities whose ``expires`` timestamp lies before now.
    """
    now = utils.utcNow()
    expiredCounters = db.Query(RateLimit.rateLimitKind).filter("expires <", now)
    DeleteEntitiesIter.startIterOnQuery(expiredCounters)