Coverage for /home/runner/work/viur-core/viur-core/viur/src/viur/core/tasks.py: 21%
437 statements
« prev ^ index » next coverage.py v7.6.3, created at 2024-10-16 22:16 +0000
« prev ^ index » next coverage.py v7.6.3, created at 2024-10-16 22:16 +0000
1import abc
2import datetime
3import functools
4import json
5import logging
6import os
7import sys
8import time
9import traceback
10import typing as t
12import grpc
13import requests
14from google import protobuf
15from google.cloud import tasks_v2
17from viur.core import current, db, errors, utils
18from viur.core.config import conf
19from viur.core.decorators import exposed, skey
20from viur.core.module import Module
CUSTOM_OBJ = t.TypeVar("CUSTOM_OBJ")  # A JSON serializable object


class CustomEnvironmentHandler(abc.ABC):
    """
    Interface for carrying project-specific environment data into deferred tasks.

    An instance is configured as ``conf.tasks_custom_environment_handler``; its
    ``serialize`` result travels with the task and is handed to ``restore`` when
    the task runs.
    """

    @abc.abstractmethod
    def serialize(self) -> CUSTOM_OBJ:
        """Serialize custom environment data

        Takes no parameters and returns a JSON serializable object holding
        everything the deferred request will need.
        """

    @abc.abstractmethod
    def restore(self, obj: CUSTOM_OBJ) -> None:
        """Restore custom environment data

        Receives the object produced by :meth:`serialize` and writes the
        information it contains into the environment of the deferred request.
        """
_gaeApp = os.environ.get("GAE_APPLICATION")

queueRegion = None
if _gaeApp:

    try:
        headers = {"Metadata-Flavor": "Google"}
        # The timeout keeps instance startup from hanging when the metadata server does not answer;
        # raise_for_status() sends error responses to the fallback below instead of parsing an
        # error body as the region name.
        r = requests.get(
            "http://metadata.google.internal/computeMetadata/v1/instance/region",
            headers=headers,
            timeout=5,
        )
        r.raise_for_status()
        # r.text should be look like this "projects/(project-number)/region/(region)"
        # like so "projects/1234567890/region/europe-west3"
        queueRegion = r.text.split("/")[-1]
    except Exception as e:  # Something went wrong with the Google Metadata Sever we use the old way
        logging.warning(f"Can't obtain queueRegion from Google MetaData Server due exception {e=}")
        # Fallback: derive the region from the prefix of GAE_APPLICATION ("<prefix>~<app>")
        regionPrefix = _gaeApp.split("~")[0]
        regionMap = {
            "h": "europe-west3",
            "e": "europe-west1"
        }
        queueRegion = regionMap.get(regionPrefix)

if not queueRegion and conf.instance.is_dev_server and os.getenv("TASKS_EMULATOR") is None:
    # Probably local development server
    logging.warning("Taskqueue disabled, tasks will run inline!")

if not conf.instance.is_dev_server or os.getenv("TASKS_EMULATOR") is None:
    taskClient = tasks_v2.CloudTasksClient()
else:
    # Dev server with TASKS_EMULATOR set: talk to the emulator over an insecure gRPC channel
    taskClient = tasks_v2.CloudTasksClient(
        transport=tasks_v2.services.cloud_tasks.transports.CloudTasksGrpcTransport(
            channel=grpc.insecure_channel(os.getenv("TASKS_EMULATOR"))
        )
    )
    queueRegion = "local"

# cronName -> {periodic task function: minimum interval in minutes}
_periodicTasks: dict[str, dict[t.Callable, int]] = {}
# CallableTaskBase.key -> registered task class
_callableTasks = {}
# "<function name>.<module>" -> function wrapped by CallDeferred
_deferred_tasks = {}
# Functions registered via StartupTask
_startupTasks = []
# Caller IPs accepted by TaskHandler._validate_request
_appengineServiceIPs = {"10.0.0.1", "0.1.0.1", "0.1.0.2"}
class PermanentTaskFailure(Exception):
    """
    Raised by a task to signal that it failed and will never succeed.

    The deferred task handler logs this exception instead of asking the
    task queue for a retry.
    """
def removePeriodicTask(task: t.Callable) -> None:
    """
    Unregister a periodic task from every cron queue.

    Handy for dropping a task inherited from a module that has been overridden.

    :param task: A function previously decorated with :func:`PeriodicTask`.
    :raises AssertionError: if *task* was never decorated with :func:`PeriodicTask`.
    """
    global _periodicTasks
    assert "periodicTaskName" in dir(task), "This is not a periodic task? "
    for registered in _periodicTasks.values():
        # pop() with a default tolerates queues that never contained this task
        registered.pop(task, None)
class CallableTaskBase:
    """
    Base class for tasks that users can trigger themselves.

    Subclass it, register the subclass with :func:`CallableTask` and override
    :meth:`canCall`, :meth:`dataSkel` and :meth:`execute`.
    """
    key = None  # Unique identifier for this task
    name = None  # Human-readable name
    descr = None  # Human-readable description
    kindName = "server-task"

    def canCall(self) -> bool:
        """
        Decide whether the current user may execute this task.

        The base implementation denies everyone.
        :returns: bool
        """
        return False

    def dataSkel(self):
        """
        Return a skeleton instance describing additional input, or ``None``.

        Values collected through that skeleton are passed to :meth:`execute`.
        """
        return None

    def execute(self):
        """
        Run the task. Subclasses must override this.
        """
        raise NotImplementedError()
class TaskHandler(Module):
    """
    Task Handler.
    Handles calling of Tasks (queued and periodic), and performs updatechecks
    Do not Modify. Do not Subclass.

    Exposed endpoints:
      * ``queryIter`` - processes one chunk of a :class:`QueryIter` run.
      * ``deferred`` - executes one call queued by :func:`CallDeferred`.
      * ``cron`` - runs periodic tasks and stored user-callable tasks.
      * ``list`` / ``execute`` - list and run user-callable tasks.
    """
    adminInfo = None
    # Admins are e-mailed when a deferred task reaches exactly this retry count
    retryCountWarningThreshold = 25

    def findBoundTask(self, task: t.Callable, obj: object, depth: int = 0) -> t.Optional[tuple[t.Callable, object]]:
        """
        Tries to locate the instance, this function belongs to.
        If it succeeds in finding it, it returns the function and its instance (-> its "self").
        Otherwise, None is returned.
        :param task: A callable decorated with @PeriodicTask
        :param obj: Object, which will be scanned in the current iteration.
        :param depth: Current iteration depth.
        """
        if depth > 3 or "periodicTaskName" not in dir(task):  # Limit the maximum amount of recursions
            return None
        for attr in dir(obj):
            if attr.startswith("_"):
                continue
            try:
                v = getattr(obj, attr)
            except AttributeError:
                continue
            # Bound methods carry the periodicTaskName of their underlying function
            if callable(v) and "periodicTaskName" in dir(v) and str(v.periodicTaskName) == str(task.periodicTaskName):
                return v, obj
            if not isinstance(v, str) and not callable(v):
                res = self.findBoundTask(task, v, depth + 1)
                if res:
                    return res
        return None

    @exposed
    def queryIter(self, *args, **kwargs):
        """
        This processes one chunk of a queryIter (see :class:`QueryIter`).

        The request body is the JSON-encoded step description produced by
        :meth:`QueryIter._requeueStep`. Steps for a class that is unknown on this
        instance are logged and dropped.
        """
        req = current.request.get().request
        self._validate_request()
        data = utils.json.loads(req.body)
        if data["classID"] not in MetaQueryIter._classCache:
            logging.error(f"""Could not continue queryIter - {data["classID"]} not known on this instance""")
            return  # Previously fell through and raised a KeyError on the lookup below
        MetaQueryIter._classCache[data["classID"]]._qryStep(data)

    @exposed
    def deferred(self, *args, **kwargs):
        """
        This catches one deferred call and routes it to its destination.

        The request body is the JSON ``(cmd, (funcPath, args, kwargs, env))`` tuple built by
        :func:`CallDeferred`. For ``cmd == "rel"``, *funcPath* is resolved attribute by attribute
        starting at ``conf.main_app``; for ``cmd == "unb"``, it is looked up in ``_deferred_tasks``.
        Calls that cannot be resolved are logged and dropped. A :class:`PermanentTaskFailure` is
        logged without retry; any other exception raises :class:`errors.RequestTimeout` so the
        task queue retries the task.
        """
        req = current.request.get().request
        self._validate_request()
        # Check if the retry count exceeds our warning threshold
        retryCount = req.headers.get("X-Appengine-Taskretrycount", None)
        if retryCount and int(retryCount) == self.retryCountWarningThreshold:
            from viur.core import email
            email.sendEMailToAdmins(
                "Deferred task retry counter exceeded warning threshold",
                f"""Task {req.headers.get("X-Appengine-Taskname", "")} is retried for the {retryCount}th time."""
            )

        cmd, data = utils.json.loads(req.body)
        funcPath, args, kwargs, env = data
        logging.debug(f"Call task {funcPath} with {cmd=} {args=} {kwargs=} {env=}")

        if env:
            if env.get("user"):
                current.session.get()["user"] = env["user"]

                # Load current user into context variable if user module is there.
                if user_mod := getattr(conf.main_app.vi, "user", None):
                    current.user.set(user_mod.getCurrentUser())
            if env.get("lang"):
                current.language.set(env["lang"])
            if "transactionMarker" in env:
                # The marker entity only exists if the transaction that queued this task committed
                marker = db.Get(db.Key("viur-transactionmarker", env["transactionMarker"]))
                if not marker:
                    logging.info(f"""Dropping task, transaction {env["transactionMarker"]} did not apply""")
                    return
                else:
                    logging.info(f"""Executing task, transaction {env["transactionMarker"]} did succeed""")
            if "custom" in env and conf.tasks_custom_environment_handler:
                # Check if we need to restore additional environmental data
                conf.tasks_custom_environment_handler.restore(env["custom"])
        if cmd == "rel":
            caller = conf.main_app
            pathlist = [x for x in funcPath.split("/") if x]
            for currpath in pathlist:
                if currpath not in dir(caller):
                    logging.error(f"Could not resolve {funcPath=} (failed part was {currpath!r})")
                    return
                caller = getattr(caller, currpath)
            try:
                caller(*args, **kwargs)
            except PermanentTaskFailure:
                logging.error("PermanentTaskFailure")
            except Exception as e:
                logging.exception(e)
                raise errors.RequestTimeout()  # Task-API should retry
        elif cmd == "unb":
            if funcPath not in _deferred_tasks:
                logging.error(f"Missed deferred task {funcPath=} ({args=},{kwargs=})")
                # Drop it like the "rel" branch does; looking it up anyway raised a KeyError
                # that was turned into an endless retry loop.
                return
            # We call the deferred function *directly* (without walking through the mkDeferred lambda), so ensure
            # that any hit to another deferred function will defer again

            current.request.get().DEFERRED_TASK_CALLED = True
            try:
                _deferred_tasks[funcPath](*args, **kwargs)
            except PermanentTaskFailure:
                logging.error("PermanentTaskFailure")
            except Exception as e:
                logging.exception(e)
                raise errors.RequestTimeout()  # Task-API should retry

    @exposed
    def cron(self, cronName="default", *args, **kwargs):
        """
        Run all periodic tasks registered for *cronName*, then all stored user-callable tasks.

        Periodic tasks with an interval are skipped if their last recorded run
        (``viur-task-interval`` entity) is more recent than the interval in minutes.
        Errors in individual tasks are logged and do not stop the run.
        """
        req = current.request.get()
        if not conf.instance.is_dev_server:
            self._validate_request(require_cron=True, require_taskname=False)
        if cronName not in _periodicTasks:
            logging.warning(f"Cron request {cronName} doesn't have any tasks")
        # We must defer from cron, as tasks will interpret it as a call originating from task-queue - causing deferred
        # functions to be called directly, wich causes calls with _countdown etc set to fail.
        req.DEFERRED_TASK_CALLED = True
        # .get() instead of indexing: an unknown cronName used to raise a KeyError here, which also
        # prevented the queued tasks below from being processed.
        for task, interval in _periodicTasks.get(cronName, {}).items():  # Call all periodic tasks bound to that queue
            periodicTaskName = task.periodicTaskName.lower()
            if interval:  # Ensure this task doesn't get called to often
                lastCall = db.Get(db.Key("viur-task-interval", periodicTaskName))
                if lastCall and utils.utcNow() - lastCall["date"] < datetime.timedelta(minutes=interval):
                    logging.debug(f"Task {periodicTaskName!r} has already run recently - skipping.")
                    continue
            res = self.findBoundTask(task, conf.main_app)
            try:
                if res:  # Its bound, call it this way :)
                    res[0]()
                else:
                    task()  # It seems it wasn't bound - call it as a static method
            except Exception as e:
                logging.error(f"Error calling periodic task {periodicTaskName}")
                logging.exception(e)
            else:
                logging.debug(f"Successfully called task {periodicTaskName}")
            if interval:
                # Update its last-call timestamp
                entry = db.Entity(db.Key("viur-task-interval", periodicTaskName))
                entry["date"] = utils.utcNow()
                db.Put(entry)
        logging.debug("Periodic tasks complete")
        for currentTask in db.Query("viur-queued-tasks").iter():  # Look for queued tasks
            # The entity key is an attribute (see DeleteEntitiesIter), not a method
            db.Delete(currentTask.key)
            if currentTask["taskid"] in _callableTasks:
                task = _callableTasks[currentTask["taskid"]]()
                tmpDict = {}
                for k in currentTask.keys():
                    if k == "taskid":
                        continue
                    tmpDict[k] = json.loads(currentTask[k])
                try:
                    task.execute(**tmpDict)
                except Exception as e:
                    logging.error("Error executing Task")
                    logging.exception(e)
        logging.debug("Scheduled tasks complete")

    def _validate_request(
        self,
        *,
        require_cron: bool = False,
        require_taskname: bool = True,
    ) -> None:
        """
        Validate the header and metadata of a request

        If the request is valid, None will be returned.
        Otherwise, an exception will be raised.

        The caller IP check (``HTTP_X_APPENGINE_USER_IP`` in ``_appengineServiceIPs``)
        is skipped only on the dev server when ``TASKS_EMULATOR`` is set.

        :param require_taskname: Require "X-AppEngine-TaskName" header
        :param require_cron: Require "X-Appengine-Cron" header
        :raises errors.Forbidden: if any check fails.
        """
        req = current.request.get().request
        if (
            req.environ.get("HTTP_X_APPENGINE_USER_IP") not in _appengineServiceIPs
            and (not conf.instance.is_dev_server or os.getenv("TASKS_EMULATOR") is None)
        ):
            logging.critical("Detected an attempted XSRF attack. This request did not originate from Task Queue.")
            raise errors.Forbidden()
        if require_cron and "X-Appengine-Cron" not in req.headers:
            logging.critical('Detected an attempted XSRF attack. The header "X-AppEngine-Cron" was not set.')
            raise errors.Forbidden()
        if require_taskname and "X-AppEngine-TaskName" not in req.headers:
            logging.critical('Detected an attempted XSRF attack. The header "X-AppEngine-Taskname" was not set.')
            raise errors.Forbidden()

    @exposed
    def list(self, *args, **kwargs):
        """Lists all user-callable tasks which are callable by this user"""
        tasks = db.SkelListRef()
        tasks.extend([
            {
                "key": x.key,
                "name": str(x.name),
                "descr": str(x.descr)
            }
            for x in _callableTasks.values() if x().canCall()
        ])

        return self.render.list(tasks)

    @exposed
    @skey(allow_empty=True)
    def execute(self, taskID, *args, **kwargs):
        """
        Execute a user-callable task.

        Renders the task's data skeleton until the submitted values validate, then runs
        :meth:`CallableTaskBase.execute` immediately within this request.
        Unknown *taskID*\\ s return nothing.

        :raises errors.Unauthorized: if the current user may not call the task.
        """
        if taskID in _callableTasks:
            task = _callableTasks[taskID]()
        else:
            return
        if not task.canCall():
            raise errors.Unauthorized()
        # NOTE(review): CallableTaskBase.dataSkel() returns None by default; such tasks fail here.
        skel = task.dataSkel()
        if not kwargs or not skel.fromClient(kwargs) or utils.parse.bool(kwargs.get("bounce")):
            return self.render.add(skel)
        task.execute(**skel.accessedValues)
        return self.render.addSuccess(skel)
# Make the task handler reachable through the admin, vi and html renderers
TaskHandler.admin = TaskHandler.vi = TaskHandler.html = True

# Decorators
def retry_n_times(retries: int, email_recipients: None | str | list[str] = None,
                  tpl: None | str = None) -> t.Callable:
    """
    Wrapper for deferred tasks to limit the amount of retries

    The retry counter is read from the ``X-Appengine-Taskretrycount`` request header
    (``-1`` when the header or the request is missing). While the counter is below
    *retries*, a failing call re-raises its exception so the task queue retries it.
    Once the limit is reached, an info e-mail is sent (if *email_recipients* is set)
    and :class:`PermanentTaskFailure` is raised instead.

    :param retries: Number of maximum allowed retries
    :param email_recipients: Email addresses to which a info should be sent
        when the retry limit is reached.
    :param tpl: Instead of the standard text, a custom template can be used.
        The name of an email template must be specified.
    """
    # language=Jinja2
    string_template = \
        """Task {{func_name}} failed {{retries}} times
        This was the last attempt.<br>
        <pre>{{func_module|escape}}.{{func_name|escape}}({{signature|escape}})</pre>
        <pre>{{traceback|escape}}</pre>"""

    def outer_wrapper(func):
        @functools.wraps(func)
        def inner_wrapper(*args, **kwargs):
            try:
                retry_count = int(current.request.get().request.headers.get("X-Appengine-Taskretrycount", -1))
            except AttributeError:
                # During warmup current.request is None (at least on local devserver)
                retry_count = -1
            try:
                return func(*args, **kwargs)
            except Exception as exc:
                logging.exception(f"Task {func.__qualname__} failed: {exc}")
                logging.info(
                    f"This was the {retry_count}. retry. "
                    f"{retries - retry_count} retries remaining. (total = {retries})"
                )
                if retry_count < retries:
                    # Re-raise to mark this task as failed, so the task queue can retry it.
                    raise
                if email_recipients:
                    args_repr = [repr(arg) for arg in args]
                    kwargs_repr = [f"{k!s}={v!r}" for k, v in kwargs.items()]
                    signature = ", ".join(args_repr + kwargs_repr)
                    try:
                        from viur.core import email
                        email.sendEMail(
                            dests=email_recipients,
                            tpl=tpl,
                            # Only fall back to the built-in text when no custom template was given;
                            # previously both branches passed string_template.
                            stringTemplate=string_template if tpl is None else None,
                            # The following params provide information for the emails templates
                            func_name=func.__name__,
                            func_qualname=func.__qualname__,
                            func_module=func.__module__,
                            retries=retries,
                            args=args,
                            kwargs=kwargs,
                            signature=signature,
                            traceback=traceback.format_exc(),
                        )
                    except Exception:
                        logging.exception("Failed to send email to %r", email_recipients)
                # Mark as permanently failed (could return nothing too)
                raise PermanentTaskFailure()

        return inner_wrapper

    return outer_wrapper
def noRetry(f):
    """
    Deprecated: stop a deferred function from being called a second time.

    Equivalent to ``@retry_n_times(0)``, which should be used instead.
    """
    logging.warning("Use of `@noRetry` is deprecated; Use `@retry_n_times(0)` instead!", stacklevel=2)
    zero_retries = retry_n_times(0)
    return zero_retries(f)
def CallDeferred(func: t.Callable) -> t.Callable:
    """
    This is a decorator, which always calls the wrapped method deferred.

    The call will be packed and queued into a Cloud Tasks queue.
    The Task Queue calls the TaskHandler which executed the wrapped function
    with the originally arguments in a different request.

    In addition to the arguments for the wrapped methods you can set these:

    _queue: Specify the queue in which the task should be pushed.
        "default" is the default value. The queue must exist (use the queue.yaml).

    _countdown: Specify a time in seconds after which the task should be called.
        This time is relative to the moment where the wrapped method has been called.

    _eta: Instead of a relative _countdown value you can specify a `datetime`
        when the task is scheduled to be attempted or retried.

    _name: Specify a custom name for the cloud task. Must be unique and can
        contain only letters ([A-Za-z]), numbers ([0-9]), hyphens (-), colons (:), or periods (.).

    _target_version: Specify a version on which to run this task.
        By default, a task will be run on the same version where the wrapped method has been called.

    _call_deferred: Calls the @CallDeferred decorated method directly.
        This is for example necessary, to call a super method which is decorated with @CallDeferred.

    .. code-block:: python

        # Example for use of the _call_deferred-parameter
        class A(Module):
            @CallDeferred
            def task(self):
                ...

        class B(A):
            @CallDeferred
            def task(self):
                super().task(_call_deferred=False)  # avoid secondary deferred call
                ...

    When no queue region is configured (local development), the call is not sent to
    Cloud Tasks but appended to the request's ``pendingTasks`` or, without a request,
    executed immediately.

    See also:
        https://cloud.google.com/python/docs/reference/cloudtasks/latest/google.cloud.tasks_v2.types.Task
        https://cloud.google.com/python/docs/reference/cloudtasks/latest/google.cloud.tasks_v2.types.CreateTaskRequest
    """
    if "viur_doc_build" in dir(sys):
        return func

    # Sentinel meaning "no first positional argument (self) was passed"
    __undefinedFlag_ = object()

    def make_deferred(func, self=__undefinedFlag_, *args, **kwargs):
        # Extract possibly provided task flags from kwargs
        queue = kwargs.pop("_queue", "default")
        call_deferred = kwargs.pop("_call_deferred", True)
        target_version = kwargs.pop("_target_version", conf.instance.app_version)
        if "_eta" in kwargs and "_countdown" in kwargs:
            raise ValueError("You cannot set the _countdown and _eta argument together!")
        taskargs = {k: kwargs.pop(f"_{k}", None) for k in ("countdown", "eta", "name")}

        logging.debug(
            f"make_deferred {func=}, {self=}, {args=}, {kwargs=}, "
            f"{queue=}, {call_deferred=}, {target_version=}, {taskargs=}"
        )

        try:
            req = current.request.get()
        except Exception:  # This will fail for warmup requests
            req = None

        if not queueRegion:
            # Run tasks inline
            logging.debug(f"{func=} will be executed inline")

            @functools.wraps(func)
            def task():
                if self is __undefinedFlag_:
                    return func(*args, **kwargs)
                else:
                    return func(self, *args, **kwargs)

            if req:
                req.pendingTasks.append(task)  # This property only exists on development server!
            else:
                # Warmup request or something - we have to call it now as we can't defer it :/
                task()

            return  # Ensure no result gets passed back

        # It's the deferred method which is called from the task queue, this has to be called directly
        call_deferred &= not (req and req.request.headers.get("X-Appengine-Taskretrycount")
                              and "DEFERRED_TASK_CALLED" not in dir(req))

        if not call_deferred:
            if self is __undefinedFlag_:
                return func(*args, **kwargs)

            if req is not None:  # No request (e.g. warmup) - nothing to flag, just call it
                req.DEFERRED_TASK_CALLED = True
            return func(self, *args, **kwargs)

        else:
            try:
                # Module methods are addressed by path relative to conf.main_app
                if self.__class__.__name__ == "index":
                    funcPath = func.__name__
                else:
                    funcPath = f"{self.modulePath}/{func.__name__}"

                command = "rel"

            except Exception:
                # Not a module method: address it via the _deferred_tasks registry
                funcPath = f"{func.__name__}.{func.__module__}"

                if self is not __undefinedFlag_:
                    args = (self,) + args  # Re-append self to args, as this function is (hopefully) unbound

                command = "unb"

            taskargs["url"] = "/_tasks/deferred"
            taskargs["headers"] = {"Content-Type": "application/octet-stream"}

            # Try to preserve the important data from the current environment
            try:  # We might get called inside a warmup request without session
                usr = current.session.get().get("user")
                # NOTE(review): this deletes the password field from the session's user object in place
                if "password" in usr:
                    del usr["password"]

            except Exception:
                usr = None

            env = {"user": usr}

            try:
                env["lang"] = current.language.get()
            except AttributeError:  # This isn't originating from a normal request
                pass

            if db.IsInTransaction():
                # We have to ensure transaction guarantees for that task also
                env["transactionMarker"] = db.acquireTransactionSuccessMarker()
                # We move that task at least 90 seconds into the future so the transaction has time to settle
                # NOTE(review): a countdown set here takes precedence over a caller-supplied _eta below.
                taskargs["countdown"] = max(90, taskargs.get("countdown") or 0)  # Countdown can be set to None

            if conf.tasks_custom_environment_handler:
                # Check if this project relies on additional environmental variables and serialize them too
                env["custom"] = conf.tasks_custom_environment_handler.serialize()

            # Create task description
            task = tasks_v2.Task(
                app_engine_http_request=tasks_v2.AppEngineHttpRequest(
                    body=utils.json.dumps((command, (funcPath, args, kwargs, env))).encode(),
                    http_method=tasks_v2.HttpMethod.POST,
                    relative_uri=taskargs["url"],
                    app_engine_routing=tasks_v2.AppEngineRouting(
                        version=target_version,
                    ),
                ),
            )
            if taskargs.get("name"):
                task.name = taskClient.task_path(conf.instance.project_id, queueRegion, queue, taskargs["name"])

            # Set a schedule time in case eta (absolut) or countdown (relative) was set.
            eta = taskargs.get("eta")
            if seconds := taskargs.get("countdown"):
                eta = utils.utcNow() + datetime.timedelta(seconds=seconds)
            if eta:
                # We must send a Timestamp Protobuf instead of a date-string
                timestamp = protobuf.timestamp_pb2.Timestamp()
                timestamp.FromDatetime(eta)
                task.schedule_time = timestamp

            # Use the client to build and send the task.
            parent = taskClient.queue_path(conf.instance.project_id, queueRegion, queue)
            logging.debug(f"{parent=}, {task=}")
            taskClient.create_task(tasks_v2.CreateTaskRequest(parent=parent, task=task))

            logging.info(f"Created task {func.__name__}.{func.__module__} with {args=} {kwargs=} {env=}")

    # Register under the key used by the "unb" command in TaskHandler.deferred
    global _deferred_tasks
    _deferred_tasks[f"{func.__name__}.{func.__module__}"] = func

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        return make_deferred(func, *args, **kwargs)

    return wrapper
def callDeferred(func):
    """
    Deprecated version of CallDeferred

    Emits a deprecation message through :mod:`logging` and :mod:`warnings`,
    then delegates to :func:`CallDeferred`.
    """
    import logging
    import warnings

    msg = "Use of @callDeferred is deprecated, use @CallDeferred instead."
    # stacklevel=2 attributes the message to the line applying ``@callDeferred``
    # (stacklevel=1 would be this function); 3 pointed one frame too far. Matches noRetry.
    logging.warning(msg, stacklevel=2)
    warnings.warn(msg, stacklevel=2)

    return CallDeferred(func)
def PeriodicTask(interval: int = 0, cronName: str = "default") -> t.Callable:
    """
    Decorator to call a function periodic during maintenance.
    Interval defines a lower bound for the call-frequency for this task;
    it will not be called faster than each interval minutes.
    (Note that the actual delay between two sequent might be much larger)

    :param interval: Call at most every interval minutes. 0 means call as often as possible.
    :param cronName: Name of the cron queue the task is registered on.
    :raises RuntimeError: (when decorating) if the function name starts with an underscore.
    """

    def _register(fn):
        global _periodicTasks
        if fn.__name__.startswith("_"):
            raise RuntimeError("Periodic called methods cannot start with an underscore! "
                               f"Please rename {fn.__name__!r}")
        _periodicTasks.setdefault(cronName, {})[fn] = interval
        # Stable identifier, also used as the key of the "viur-task-interval" timestamp entity
        fn.periodicTaskName = f"{fn.__module__}_{fn.__qualname__}".replace(".", "_").lower()
        return fn

    return _register
def CallableTask(fn: t.Callable) -> t.Callable:
    """
    Class decorator registering a user-callable task under its ``key``.

    The class *should* extend :class:`CallableTaskBase` and *must* provide its API.
    Returns the class unchanged.
    """
    global _callableTasks
    task_key = fn.key
    _callableTasks[task_key] = fn
    return fn
def StartupTask(fn: t.Callable) -> t.Callable:
    """
    Register *fn* to be called shortly after instance startup.

    There is *no* guarantee the function runs on the instance that just started.
    The function must not take any arguments. Returns *fn* unchanged.
    """
    global _startupTasks
    _startupTasks.append(fn)
    return fn
@CallDeferred
def runStartupTasks():
    """
    Call every function registered with :func:`StartupTask`, in registration order.

    Runs deferred. Do not call directly!
    """
    for startup_fn in _startupTasks:
        startup_fn()
class MetaQueryIter(type):
    """
    Metaclass of :class:`QueryIter`.

    Records every class it creates under ``str(cls)`` so that a queued step can be
    dispatched back to the right class; that string is also stored on the class as
    ``__classID__``.
    """
    _classCache = {}  # Mapping className -> Class

    def __init__(cls, name, bases, dct):
        class_id = str(cls)
        MetaQueryIter._classCache[class_id] = cls
        cls.__classID__ = class_id
        super().__init__(name, bases, dct)
class QueryIter(object, metaclass=MetaQueryIter):
    """
    BaseClass to run a database Query and process each entry matched.
    This will run each step deferred, so it is possible to process an arbitrary number of entries
    without being limited by time or memory.

    To use this class create a subclass, override the classmethods handleEntry and handleFinish and then
    call startIterOnQuery with an instance of a database Query (and possible some custom data to pass along)
    """
    queueName = "default"  # Name of the taskqueue we will run on

    @classmethod
    def startIterOnQuery(cls, query: db.Query, customData: t.Any = None) -> None:
        """
        Starts iterating the given query on this class. Will return immediately, the first batch will already
        run deferred.

        Warning: Any custom data *must* be json-serializable and *must* be passed in customData. You cannot store
        any data on this class as each chunk may run on a different instance!
        """
        assert not (query._customMultiQueryMerge or query._calculateInternalMultiQueryLimit), \
            "Cannot iter a query with postprocessing"
        assert isinstance(query.queries, db.QueryDefinition), "Unsatisfiable query or query with an IN filter"
        # Everything needed to rebuild the query in another request, in JSON-serializable form
        qryDict = {
            "kind": query.kind,
            "srcSkel": query.srcSkel.kindName if query.srcSkel else None,
            "filters": query.queries.filters,
            "orders": [(propName, sortOrder.value) for propName, sortOrder in query.queries.orders],
            "startCursor": query.queries.startCursor,
            "endCursor": query.queries.endCursor,
            "origKind": query.origKind,
            "distinct": query.queries.distinct,
            "classID": cls.__classID__,
            "customData": customData,
            "totalCount": 0
        }
        cls._requeueStep(qryDict)

    @classmethod
    def _requeueStep(cls, qryDict: dict[str, t.Any]) -> None:
        """
        Internal use only. Pushes a new step defined in qryDict to either the taskqueue or append it to
        the current request if we are on the local development server.
        """
        if not queueRegion:  # Run tasks inline - hopefully development server
            req = current.request.get()
            task = lambda *args, **kwargs: cls._qryStep(qryDict)
            # NOTE(review): without a current request the step is silently dropped here.
            if req:
                req.pendingTasks.append(task)  # < This property will be only exist on development server!
            return
        taskClient.create_task(tasks_v2.CreateTaskRequest(
            parent=taskClient.queue_path(conf.instance.project_id, queueRegion, cls.queueName),
            task=tasks_v2.Task(
                app_engine_http_request=tasks_v2.AppEngineHttpRequest(
                    body=utils.json.dumps(qryDict).encode(),
                    http_method=tasks_v2.HttpMethod.POST,
                    relative_uri="/_tasks/queryIter",
                    app_engine_routing=tasks_v2.AppEngineRouting(
                        version=conf.instance.app_version,
                    ),
                )
            ),
        ))

    @classmethod
    def _qryStep(cls, qryDict: dict[str, t.Any]) -> None:
        """
        Internal use only. Processes one block of five entries from the query defined in qryDict and
        reschedules the next block.

        Each entry gets two attempts (5 seconds apart) in :meth:`handleEntry`; after the second
        failure :meth:`handleError` decides whether to continue. When the query has no further
        cursor, :meth:`handleFinish` is called.
        """
        from viur.core.skeleton import skeletonByKind
        qry = db.Query(qryDict["kind"])
        qry.srcSkel = skeletonByKind(qryDict["srcSkel"])() if qryDict["srcSkel"] else None
        qry.queries.filters = qryDict["filters"]
        qry.queries.orders = [(propName, db.SortOrder(sortOrder)) for propName, sortOrder in qryDict["orders"]]
        qry.setCursor(qryDict["startCursor"], qryDict["endCursor"])
        qry.origKind = qryDict["origKind"]
        qry.queries.distinct = qryDict["distinct"]
        if qry.srcSkel:
            qryIter = qry.fetch(5)
        else:
            qryIter = qry.run(5)
        for item in qryIter:
            try:
                cls.handleEntry(item, qryDict["customData"])
            except Exception:  # First exception - we'll try another time (probably/hopefully transaction collision)
                # `except Exception` (not a bare except) so KeyboardInterrupt/SystemExit propagate immediately
                time.sleep(5)
                try:
                    cls.handleEntry(item, qryDict["customData"])
                except Exception as entry_exc:  # Second exception - call error_handler
                    try:
                        doCont = cls.handleError(item, qryDict["customData"], entry_exc)
                    except Exception as handler_exc:
                        logging.error(f"handleError failed on {item} - bailing out")
                        logging.exception(handler_exc)
                        doCont = False
                    if not doCont:
                        logging.error(f"Exiting queryIter on cursor {qry.getCursor()!r}")
                        return
            qryDict["totalCount"] += 1
        cursor = qry.getCursor()
        if cursor:
            qryDict["startCursor"] = cursor
            cls._requeueStep(qryDict)
        else:
            cls.handleFinish(qryDict["totalCount"], qryDict["customData"])

    @classmethod
    def handleEntry(cls, entry, customData):
        """
        Overridable hook to process one entry. "entry" will be either an db.Entity or an
        SkeletonInstance (if that query has been created by skel.all())

        Warning: If your query has an sortOrder other than __key__ and you modify that property here
        it is possible to encounter that object later one *again* (as it may jump behind the current cursor).
        """
        logging.debug(f"handleEntry called on {cls} with {entry}.")

    @classmethod
    def handleFinish(cls, totalCount: int, customData):
        """
        Overridable hook that indicates the current run has been finished.
        """
        logging.debug(f"handleFinish called on {cls} with {totalCount} total Entries processed")

    @classmethod
    def handleError(cls, entry, customData, exception) -> bool:
        """
        Handle a error occurred in handleEntry.
        If this function returns True, the queryIter continues, otherwise it breaks and prints the current cursor.
        """
        logging.debug(f"handleError called on {cls} with {entry}.")
        logging.exception(exception)
        return True
class DeleteEntitiesIter(QueryIter):
    """
    Query-Iter that deletes every entity it encounters.

    ..Warning: When iterating over skeletons, make sure that the
        query was created using `Skeleton().all()`.
        This way the `Skeleton.delete()` method can be used and
        the appropriate post-processing can be done.
    """

    @classmethod
    def handleEntry(cls, entry, customData):
        from viur.core.skeleton import SkeletonInstance
        if not isinstance(entry, SkeletonInstance):
            # Raw entity: delete it straight from the datastore
            db.Delete(entry.key)
            return
        # Skeleton: let it run its own delete logic
        entry.delete()