Coverage for /home/runner/work/viur-core/viur-core/viur/src/viur/core/tasks.py: 21%
433 statements
« prev ^ index » next coverage.py v7.6.10, created at 2025-02-07 19:28 +0000
« prev ^ index » next coverage.py v7.6.10, created at 2025-02-07 19:28 +0000
1import abc
2import datetime
3import functools
4import json
5import logging
6import os
7import sys
8import time
9import traceback
10import typing as t
12import grpc
13import requests
14from google import protobuf
15from google.cloud import tasks_v2
17from viur.core import current, db, errors, utils
18from viur.core.config import conf
19from viur.core.decorators import exposed, skey
20from viur.core.module import Module
# Type variable for the payload exchanged between CustomEnvironmentHandler.serialize()
# and CustomEnvironmentHandler.restore().
CUSTOM_OBJ = t.TypeVar("CUSTOM_OBJ")  # A JSON serializable object
class CustomEnvironmentHandler(abc.ABC):
    """
    Interface for carrying project specific environment data into deferred tasks.

    Both methods are abstract and have to be provided by a subclass.
    """

    @abc.abstractmethod
    def serialize(self) -> CUSTOM_OBJ:
        """Collect the custom environment data.

        Takes no parameters; the returned object has to be
        JSON serializable and hold the desired information.
        """
        ...

    @abc.abstractmethod
    def restore(self, obj: CUSTOM_OBJ) -> None:
        """Re-apply custom environment data.

        :param obj: The object previously returned by :meth:`serialize`. Its content
            should be written to the environment of the deferred request.
        """
        ...
_gaeApp = os.environ.get("GAE_APPLICATION")

# Region of the Cloud Tasks queues. Stays None if it cannot be determined;
# in that case deferred tasks are run inline (see CallDeferred / QueryIter._requeueStep).
queueRegion = None
if _gaeApp:
    try:
        headers = {"Metadata-Flavor": "Google"}
        r = requests.get(
            "http://metadata.google.internal/computeMetadata/v1/instance/region",
            headers=headers,
            timeout=10,  # never block the module import forever on an unresponsive metadata server
        )
        # An error response must not be parsed as a region; fall back to the legacy lookup instead
        r.raise_for_status()
        # r.text should look like this "projects/(project-number)/region/(region)"
        # like so "projects/1234567890/region/europe-west3"
        queueRegion = r.text.split("/")[-1]
    except Exception as e:  # Something went wrong with the Google Metadata Server, we use the old way
        logging.warning(f"Can't obtain queueRegion from Google MetaData Server due exception {e=}")
        # Legacy lookup: derive the region from the prefix of GAE_APPLICATION ("<prefix>~<project>")
        regionPrefix = _gaeApp.split("~")[0]
        regionMap = {
            "h": "europe-west3",
            "e": "europe-west1"
        }
        queueRegion = regionMap.get(regionPrefix)

if not queueRegion and conf.instance.is_dev_server and os.getenv("TASKS_EMULATOR") is None:
    # Probably local development server
    logging.warning("Taskqueue disabled, tasks will run inline!")

if not conf.instance.is_dev_server or os.getenv("TASKS_EMULATOR") is None:
    # Regular client with default transport
    taskClient = tasks_v2.CloudTasksClient()
else:
    # Development server with TASKS_EMULATOR set: talk to the emulator over an insecure gRPC channel
    taskClient = tasks_v2.CloudTasksClient(
        transport=tasks_v2.services.cloud_tasks.transports.CloudTasksGrpcTransport(
            channel=grpc.insecure_channel(os.getenv("TASKS_EMULATOR"))
        )
    )
    queueRegion = "local"
# Periodic tasks: cron name -> {task callable -> minimum interval between two runs}
_periodicTasks: dict[str, dict[t.Callable, datetime.timedelta]] = {}
# User-callable task classes registered via @CallableTask, keyed by their `key` attribute
_callableTasks = {}
# Functions wrapped by @CallDeferred, keyed by "<function name>.<module name>"
_deferred_tasks = {}
# Functions registered via @StartupTask
_startupTasks = []
# Values of the "X-Appengine-User-Ip" header accepted by TaskHandler._validate_request
_appengineServiceIPs = {"10.0.0.1", "0.1.0.1", "0.1.0.2"}
class PermanentTaskFailure(Exception):
    """Raised to signal that a task has failed for good and must not be retried."""
def removePeriodicTask(task: t.Callable) -> None:
    """
    Unregister a periodic task from every cron queue it is registered on.

    Handy to get rid of a periodic task which was inherited from an overridden module.

    :param task: A callable which has been decorated with @PeriodicTask.
    """
    assert "periodicTaskName" in dir(task), "This is not a periodic task? "
    for registered in _periodicTasks.values():
        # No-op for queues the task is not part of
        registered.pop(task, None)
class CallableTaskBase:
    """
    Common base for tasks which can be triggered by a user.
    This class is meant to be subclassed.
    """
    key = None  # Unique identifier for this task
    name = None  # Human-Readable name
    descr = None  # Human-Readable description
    kindName = "server-task"

    def canCall(self) -> bool:
        """
        Tell whether the current user is allowed to run this task.

        :returns: Always False in this base implementation.
        """
        return False

    def dataSkel(self):
        """
        Hook for tasks requiring additional data: return a skeleton-instance here.
        Its values are handed over to *execute*.

        :returns: Always None in this base implementation.
        """
        return None

    def execute(self):
        """
        Place for the code which does the real work.

        :raises NotImplementedError: Always, unless overridden.
        """
        raise NotImplementedError
class TaskHandler(Module):
    """
    Task Handler.
    Handles calling of Tasks (queued and periodic), and performs updatechecks
    Do not Modify. Do not Subclass.
    """
    adminInfo = None
    retryCountWarningThreshold = 25  # Admins get an email when a deferred task hits exactly this retry count

    def findBoundTask(self, task: t.Callable, obj: object, depth: int = 0) -> t.Optional[tuple[t.Callable, object]]:
        """
        Tries to locate the instance, this function belongs to.
        If it succeeds in finding it, it returns the function and its instance (-> its "self").
        Otherwise, None is returned.

        :param task: A callable decorated with @PeriodicTask
        :param obj: Object, which will be scanned in the current iteration.
        :param depth: Current iteration depth.
        :returns: Tuple of (bound callable, owning object) or None.
        """
        if depth > 3 or "periodicTaskName" not in dir(task):  # Limit the maximum amount of recursions
            return None
        for attr in dir(obj):
            if attr.startswith("_"):
                continue
            try:
                v = getattr(obj, attr)
            except AttributeError:
                continue
            if callable(v) and "periodicTaskName" in dir(v) and str(v.periodicTaskName) == str(task.periodicTaskName):
                return v, obj
            # Descend into non-string, non-callable attributes
            if not isinstance(v, str) and not callable(v):
                res = self.findBoundTask(task, v, depth + 1)
                if res:
                    return res
        return None

    @exposed
    def queryIter(self, *args, **kwargs):
        """
        This processes one chunk of a queryIter (see :class:`QueryIter`).

        :raises errors.Forbidden: If the request fails validation.
        :raises errors.RequestTimeout: If the QueryIter class named in the request body is not known here.
        """
        req = current.request.get().request
        self._validate_request()
        data = utils.json.loads(req.body)
        if data["classID"] not in MetaQueryIter._classCache:
            logging.error(f"""Could not continue queryIter - {data["classID"]} not known on this instance""")
            # Fail the request explicitly instead of crashing with a KeyError below.
            raise errors.RequestTimeout()  # Task-API should retry
        MetaQueryIter._classCache[data["classID"]]._qryStep(data)

    @exposed
    def deferred(self, *args, **kwargs):
        """
        This catches one deferred call and routes it to its destination

        The request body is expected to be the JSON created by :func:`CallDeferred`:
        ``(cmd, (funcPath, args, kwargs, env))``.

        :raises errors.Forbidden: If the request fails validation.
        :raises errors.RequestTimeout: If the called function failed with anything
            other than :class:`PermanentTaskFailure`.
        """
        req = current.request.get().request
        self._validate_request()
        # Check if the retry count exceeds our warning threshold
        retryCount = req.headers.get("X-Appengine-Taskretrycount", None)
        if retryCount and int(retryCount) == self.retryCountWarningThreshold:
            from viur.core import email
            email.send_email_to_admins(
                "Deferred task retry counter exceeded warning threshold",
                f"""Task {req.headers.get("X-Appengine-Taskname", "")} is retried for the {retryCount}th time."""
            )

        cmd, data = utils.json.loads(req.body)
        funcPath, args, kwargs, env = data
        logging.debug(f"Call task {funcPath} with {cmd=} {args=} {kwargs=} {env=}")

        if env:
            # Restore the environment which was captured when the task was created
            if "user" in env and env["user"]:
                current.session.get()["user"] = env["user"]

                # Load current user into context variable if user module is there.
                if user_mod := getattr(conf.main_app.vi, "user", None):
                    current.user.set(user_mod.getCurrentUser())
            if "lang" in env and env["lang"]:
                current.language.set(env["lang"])
            if "transactionMarker" in env:
                marker = db.Get(db.Key("viur-transactionmarker", env["transactionMarker"]))
                if not marker:
                    logging.info(f"""Dropping task, transaction {env["transactionMarker"]} did not apply""")
                    return
                else:
                    logging.info(f"""Executing task, transaction {env["transactionMarker"]} did succeed""")
            if "custom" in env and conf.tasks_custom_environment_handler:
                # Check if we need to restore additional environmental data
                conf.tasks_custom_environment_handler.restore(env["custom"])
        if cmd == "rel":
            # "rel": funcPath is a "/"-separated attribute path, resolved starting at conf.main_app
            caller = conf.main_app
            pathlist = [x for x in funcPath.split("/") if x]
            for currpath in pathlist:
                if currpath not in dir(caller):
                    logging.error(f"Could not resolve {funcPath=} (failed part was {currpath!r})")
                    return
                caller = getattr(caller, currpath)
            try:
                caller(*args, **kwargs)
            except PermanentTaskFailure:
                logging.error("PermanentTaskFailure")
            except Exception as e:
                logging.exception(e)
                raise errors.RequestTimeout()  # Task-API should retry
        elif cmd == "unb":
            # "unb": funcPath is a key of the _deferred_tasks registry
            if funcPath not in _deferred_tasks:
                # The lookup below then fails inside the try block, so the task gets retried
                logging.error(f"Missed deferred task {funcPath=} ({args=},{kwargs=})")
            # We call the deferred function *directly* (without walking through the mkDeferred lambda), so ensure
            # that any hit to another deferred function will defer again

            current.request.get().DEFERRED_TASK_CALLED = True
            try:
                _deferred_tasks[funcPath](*args, **kwargs)
            except PermanentTaskFailure:
                logging.error("PermanentTaskFailure")
            except Exception as e:
                logging.exception(e)
                raise errors.RequestTimeout()  # Task-API should retry

    @exposed
    def cron(self, cronName="default", *args, **kwargs):
        """
        Runs all periodic tasks registered for `cronName`, then executes all entries
        found in the "viur-queued-tasks" kind.

        :param cronName: Name of the cron queue whose periodic tasks should be called.
        :raises errors.Forbidden: If the request fails validation (not checked on the development server).
        """
        req = current.request.get()
        if not conf.instance.is_dev_server:
            self._validate_request(require_cron=True, require_taskname=False)
        if cronName not in _periodicTasks:
            logging.warning(f"Cron request {cronName} doesn't have any tasks")
        # We must defer from cron, as tasks will interpret it as a call originating from task-queue - causing deferred
        # functions to be called directly, wich causes calls with _countdown etc set to fail.
        req.DEFERRED_TASK_CALLED = True
        # Use .get(): an unknown cronName has just been warned about and must not raise a KeyError,
        # otherwise the queued tasks below would never be processed for that cron request.
        for task, interval in _periodicTasks.get(cronName, {}).items():  # Call all periodic tasks bound to that queue
            periodicTaskName = task.periodicTaskName.lower()
            if interval:  # Ensure this task doesn't get called to often
                lastCall = db.Get(db.Key("viur-task-interval", periodicTaskName))
                if lastCall and utils.utcNow() - lastCall["date"] < interval:
                    logging.debug(f"Task {periodicTaskName!r} has already run recently - skipping.")
                    continue
            res = self.findBoundTask(task, conf.main_app)
            try:
                if res:  # Its bound, call it this way :)
                    res[0]()
                else:
                    task()  # It seems it wasn't bound - call it as a static method
            except Exception as e:
                logging.error(f"Error calling periodic task {periodicTaskName}")
                logging.exception(e)
            else:
                logging.debug(f"Successfully called task {periodicTaskName}")
            if interval:
                # Update its last-call timestamp
                entry = db.Entity(db.Key("viur-task-interval", periodicTaskName))
                entry["date"] = utils.utcNow()
                db.Put(entry)
        logging.debug("Periodic tasks complete")
        for currentTask in db.Query("viur-queued-tasks").iter():  # Look for queued tasks
            # NOTE(review): `key` is called as a method here, while DeleteEntitiesIter.handleEntry
            # accesses `entry.key` as an attribute - confirm which one matches db.Entity.
            db.Delete(currentTask.key())
            if currentTask["taskid"] in _callableTasks:
                task = _callableTasks[currentTask["taskid"]]()
                tmpDict = {}
                for k in currentTask.keys():
                    if k == "taskid":
                        continue
                    tmpDict[k] = json.loads(currentTask[k])
                try:
                    task.execute(**tmpDict)
                except Exception as e:
                    logging.error("Error executing Task")
                    logging.exception(e)
        logging.debug("Scheduled tasks complete")

    def _validate_request(
        self,
        *,
        require_cron: bool = False,
        require_taskname: bool = True,
    ) -> None:
        """
        Validate the header and metadata of a request

        If the request is valid, None will be returned.
        Otherwise, an exception will be raised.

        :param require_taskname: Require "X-AppEngine-TaskName" header
        :param require_cron: Require "X-Appengine-Cron" header
        :raises errors.Forbidden: If one of the checks fails.
        """
        req = current.request.get().request
        # The source IP check is only skipped on a development server which uses the tasks emulator
        if (
            req.environ.get("HTTP_X_APPENGINE_USER_IP") not in _appengineServiceIPs
            and (not conf.instance.is_dev_server or os.getenv("TASKS_EMULATOR") is None)
        ):
            logging.critical("Detected an attempted XSRF attack. This request did not originate from Task Queue.")
            raise errors.Forbidden()
        if require_cron and "X-Appengine-Cron" not in req.headers:
            logging.critical('Detected an attempted XSRF attack. The header "X-AppEngine-Cron" was not set.')
            raise errors.Forbidden()
        if require_taskname and "X-AppEngine-TaskName" not in req.headers:
            logging.critical('Detected an attempted XSRF attack. The header "X-AppEngine-Taskname" was not set.')
            raise errors.Forbidden()

    @exposed
    def list(self, *args, **kwargs):
        """Lists all user-callable tasks which are callable by this user"""
        tasks = db.SkelListRef()
        tasks.extend([{
            "key": x.key,
            "name": str(x.name),
            "descr": str(x.descr)
        } for x in _callableTasks.values() if x().canCall()
        ])

        return self.render.list(tasks)

    @exposed
    @skey(allow_empty=True)
    def execute(self, taskID, *args, **kwargs):
        """
        Runs a specific user-callable task with the data submitted by the client.

        Returns None if `taskID` is unknown. Renders the add-form as long as no data
        was submitted, the data is rejected by the skeleton or "bounce" is set.

        :param taskID: Key of the task, as registered through @CallableTask.
        :raises errors.Unauthorized: If the task's canCall() denies access.
        """
        if taskID in _callableTasks:
            task = _callableTasks[taskID]()
        else:
            return
        if not task.canCall():
            raise errors.Unauthorized()
        # NOTE(review): CallableTaskBase.dataSkel() returns None by default; with kwargs given,
        # `skel.fromClient` would then fail - confirm that all tasks provide a skeleton.
        skel = task.dataSkel()
        if not kwargs or not skel.fromClient(kwargs) or utils.parse.bool(kwargs.get("bounce")):
            return self.render.add(skel)
        task.execute(**skel.accessedValues)
        return self.render.addSuccess(skel)
# Set the admin/vi/html flags on the handler class after its definition.
# NOTE(review): presumably these enable the handler for the respective renderers - confirm against Module.
TaskHandler.admin = True
TaskHandler.vi = True
TaskHandler.html = True
368# Decorators
def retry_n_times(retries: int, email_recipients: None | str | list[str] = None,
                  tpl: None | str = None) -> t.Callable:
    """
    Wrapper for deferred tasks to limit the amount of retries

    As long as the retry count (read from the "X-Appengine-Taskretrycount" header) is below
    `retries`, an exception of the wrapped function is re-raised, so the task queue retries it.
    Afterwards an optional email is sent and :class:`PermanentTaskFailure` is raised.

    :param retries: Number of maximum allowed retries
    :param email_recipients: Email addresses to which a info should be sent
        when the retry limit is reached.
    :param tpl: Instead of the standard text, a custom template can be used.
        The name of an email template must be specified.
    :returns: The decorator to be applied to the task function.
    """
    # language=Jinja2
    string_template = \
        """Task {{func_name}} failed {{retries}} times
        This was the last attempt.<br>
        <pre>{{func_module|escape}}.{{func_name|escape}}({{signature|escape}})</pre>
        <pre>{{traceback|escape}}</pre>"""

    def outer_wrapper(func):
        @functools.wraps(func)
        def inner_wrapper(*args, **kwargs):
            try:
                retry_count = int(current.request.get().request.headers.get("X-Appengine-Taskretrycount", -1))
            except AttributeError:
                # During warmup current.request is None (at least on local devserver)
                retry_count = -1
            try:
                return func(*args, **kwargs)
            except Exception as exc:
                logging.exception(f"Task {func.__qualname__} failed: {exc}")
                logging.info(
                    f"This was the {retry_count}. retry. "
                    f"{retries - retry_count} retries remaining. (total = {retries})"
                )
                if retry_count < retries:
                    # Raise the exception to mark this task as failed, so the task queue can retry it.
                    # A bare raise keeps the original traceback untouched.
                    raise
                else:
                    if email_recipients:
                        args_repr = [repr(arg) for arg in args]
                        kwargs_repr = [f"{k!s}={v!r}" for k, v in kwargs.items()]
                        signature = ", ".join(args_repr + kwargs_repr)
                        try:
                            from viur.core import email
                            email.send_email(
                                dests=email_recipients,
                                tpl=tpl,
                                # Only fall back to the built-in text if no custom template was requested;
                                # handing over both would not let `tpl` replace the standard text.
                                stringTemplate=string_template if tpl is None else None,
                                # The following params provide information for the emails templates
                                func_name=func.__name__,
                                func_qualname=func.__qualname__,
                                func_module=func.__module__,
                                retries=retries,
                                args=args,
                                kwargs=kwargs,
                                signature=signature,
                                traceback=traceback.format_exc(),
                            )
                        except Exception:
                            # Best effort: a failing notification must not hide the permanent failure
                            logging.exception("Failed to send email to %r", email_recipients)
                    # Mark as permanently failed (could return nothing too)
                    raise PermanentTaskFailure() from exc

        return inner_wrapper

    return outer_wrapper
def noRetry(f):
    """
    Deprecated: keeps a deferred function from being called a second time.

    Logs a deprecation warning and wraps `f` with ``retry_n_times(0)``.
    """
    logging.warning("Use of `@noRetry` is deprecated; Use `@retry_n_times(0)` instead!", stacklevel=2)
    no_retries = retry_n_times(0)
    return no_retries(f)
def CallDeferred(func: t.Callable) -> t.Callable:
    """
    This is a decorator, which always calls the wrapped method deferred.

    The call will be packed and queued into a Cloud Tasks queue.
    The Task Queue calls the TaskHandler which executed the wrapped function
    with the originally arguments in a different request.

    If no queue region could be determined (`queueRegion` is not set), the call is not queued
    but appended to the pending tasks of the current request - or run immediately if
    there is no current request.

    In addition to the arguments for the wrapped methods you can set these:

    _queue: Specify the queue in which the task should be pushed.
        "default" is the default value. The queue must exist (use the queue.yaml).

    _countdown: Specify a time in seconds after which the task should be called.
        This time is relative to the moment where the wrapped method has been called.

    _eta: Instead of a relative _countdown value you can specify a `datetime`
        when the task is scheduled to be attempted or retried.

    _name: Specify a custom name for the cloud task. Must be unique and can
        contain only letters ([A-Za-z]), numbers ([0-9]), hyphens (-), colons (:), or periods (.).

    _target_version: Specify a version on which to run this task.
        By default, a task will be run on the same version where the wrapped method has been called.

    _call_deferred: Calls the @CallDeferred decorated method directly.
        This is for example necessary, to call a super method which is decorated with @CallDeferred.

        .. code-block:: python

            # Example for use of the _call_deferred-parameter
            class A(Module):
                @CallDeferred
                def task(self):
                    ...

            class B(A):
                @CallDeferred
                def task(self):
                    super().task(_call_deferred=False)  # avoid secondary deferred call
                    ...

    See also:
        https://cloud.google.com/python/docs/reference/cloudtasks/latest/google.cloud.tasks_v2.types.Task
        https://cloud.google.com/python/docs/reference/cloudtasks/latest/google.cloud.tasks_v2.types.CreateTaskRequest

    :raises ValueError: (on call) If both `_eta` and `_countdown` are given.
    """
    if "viur_doc_build" in dir(sys):
        # Hand back the undecorated function if `sys.viur_doc_build` is set
        return func

    # Sentinel to detect calls without any positional argument (-> no "self")
    __undefinedFlag_ = object()

    def make_deferred(
        func: t.Callable,
        self=__undefinedFlag_,
        *args,
        _queue: str = "default",
        _name: str | None = None,
        _call_deferred: bool = True,
        _target_version: str = conf.instance.app_version,
        _eta: datetime.datetime | None = None,
        _countdown: int = 0,
        **kwargs
    ):
        if _eta is not None and _countdown:
            raise ValueError("You cannot set the _countdown and _eta argument together!")

        logging.debug(
            f"make_deferred {func=}, {self=}, {args=}, {kwargs=}, "
            f"{_queue=}, {_name=}, {_call_deferred=}, {_target_version=}, {_eta=}, {_countdown=}"
        )

        try:
            req = current.request.get()
        except Exception:  # This will fail for warmup requests
            req = None

        if not queueRegion:
            # Run tasks inline
            logging.debug(f"{func=} will be executed inline")

            @functools.wraps(func)
            def task():
                if self is __undefinedFlag_:
                    return func(*args, **kwargs)
                else:
                    return func(self, *args, **kwargs)

            if req:
                req.pendingTasks.append(task)  # This property only exists on development server!
            else:
                # Warmup request or something - we have to call it now as we can't defer it :/
                task()

            return  # Ensure no result gets passed back

        # It's the deferred method which is called from the task queue, this has to be called directly
        _call_deferred &= not (req and req.request.headers.get("X-Appengine-Taskretrycount")
                               and "DEFERRED_TASK_CALLED" not in dir(req))

        if not _call_deferred:
            if self is __undefinedFlag_:
                return func(*args, **kwargs)

            if req:  # There is no request to flag inside a warmup request
                req.DEFERRED_TASK_CALLED = True
            return func(self, *args, **kwargs)

        try:
            # Methods of modules are addressed by their module path ("rel")...
            if self.__class__.__name__ == "index":
                funcPath = func.__name__
            else:
                funcPath = f"{self.modulePath}/{func.__name__}"
            command = "rel"
        except Exception:
            # ...anything else by its key in the _deferred_tasks registry ("unb")
            funcPath = f"{func.__name__}.{func.__module__}"
            if self is not __undefinedFlag_:
                args = (self,) + args  # Re-append self to args, as this function is (hopefully) unbound
            command = "unb"

        # Try to preserve the important data from the current environment
        try:  # We might get called inside a warmup request without session
            usr = current.session.get().get("user")
            # NOTE(review): this deletes the key from the object stored in the session itself - confirm intended
            if "password" in usr:
                del usr["password"]
        except Exception:
            usr = None

        env = {"user": usr}

        try:
            env["lang"] = current.language.get()
        except AttributeError:  # This isn't originating from a normal request
            pass

        if db.IsInTransaction():
            # We have to ensure transaction guarantees for that task also
            env["transactionMarker"] = db.acquireTransactionSuccessMarker()
            # We move that task at least 90 seconds into the future so the transaction has time to settle
            _countdown = max(90, _countdown or 0)  # Countdown can be set to None

        if conf.tasks_custom_environment_handler:
            # Check if this project relies on additional environmental variables and serialize them too
            env["custom"] = conf.tasks_custom_environment_handler.serialize()

        # Create task description
        task = tasks_v2.Task(
            app_engine_http_request=tasks_v2.AppEngineHttpRequest(
                body=utils.json.dumps((command, (funcPath, args, kwargs, env))).encode(),
                http_method=tasks_v2.HttpMethod.POST,
                relative_uri="/_tasks/deferred",
                app_engine_routing=tasks_v2.AppEngineRouting(
                    version=_target_version,
                ),
            ),
        )
        if _name is not None:
            task.name = taskClient.task_path(conf.instance.project_id, queueRegion, _queue, _name)

        # Set a schedule time in case eta (absolut) or countdown (relative) was set.
        if seconds := _countdown:
            _eta = utils.utcNow() + datetime.timedelta(seconds=seconds)
        if _eta:
            # We must send a Timestamp Protobuf instead of a date-string
            timestamp = protobuf.timestamp_pb2.Timestamp()
            timestamp.FromDatetime(_eta)
            task.schedule_time = timestamp

        # Use the client to build and send the task.
        parent = taskClient.queue_path(conf.instance.project_id, queueRegion, _queue)
        logging.debug(f"{parent=}, {task=}")
        taskClient.create_task(tasks_v2.CreateTaskRequest(parent=parent, task=task))

        logging.info(f"Created task {func.__name__}.{func.__module__} with {args=} {kwargs=} {env=}")

    # Register the undecorated function, so TaskHandler.deferred can resolve "unb" calls
    global _deferred_tasks
    _deferred_tasks[f"{func.__name__}.{func.__module__}"] = func

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        return make_deferred(func, *args, **kwargs)

    return wrapper
def callDeferred(func):
    """
    Deprecated alias of :func:`CallDeferred`.

    Emits the deprecation notice through both `logging` and `warnings`
    and then hands `func` over to CallDeferred.
    """
    import logging
    import warnings

    deprecation = "Use of @callDeferred is deprecated, use @CallDeferred instead."
    for emit in (logging.warning, warnings.warn):
        emit(deprecation, stacklevel=3)

    return CallDeferred(func)
def PeriodicTask(interval: datetime.timedelta | int | float = 0, cronName: str = "default") -> t.Callable:
    """
    Decorator to call a function periodically during cron job execution.

    Interval defines a lower bound for the call-frequency for the given task, specified as a timedelta.

    The true interval of how often cron jobs are being executed is defined in the project's cron.yaml file.
    This defaults to 4 hours (see https://github.com/viur-framework/viur-base/blob/main/deploy/cron.yaml).
    In case the interval defined here is lower than 4 hours, the task will be fired once every 4 hours anyway.

    :param interval: Call at most the given timedelta.
    :param cronName: Name of the cron queue the task is registered on.
    :returns: The decorator; it registers the function and returns it unchanged
        (apart from the added `periodicTaskName` attribute).
    :raises RuntimeError: (on decoration) If the function name starts with an underscore.
    """
    def make_decorator(fn):
        if fn.__name__.startswith("_"):
            raise RuntimeError("Periodic called methods cannot start with an underscore! "
                               f"Please rename {fn.__name__!r}")

        # Work on a local copy: the closure variable must stay untouched, otherwise applying
        # the same decorator instance to several functions would scale the interval again each time.
        task_interval = interval
        if isinstance(task_interval, (int, float)) and "tasks.periodic.useminutes" in conf.compatibility:
            logging.warning(
                f"PeriodicTask assuming {interval=} minutes here. This is changed into seconds in future. "
                f"Please use `datetime.timedelta(minutes={interval})` for clarification.",
                stacklevel=2,
            )
            task_interval *= 60

        _periodicTasks.setdefault(cronName, {})[fn] = utils.parse.timedelta(task_interval)
        fn.periodicTaskName = f"{fn.__module__}_{fn.__qualname__}".replace(".", "_").lower()
        return fn

    return make_decorator
def CallableTask(fn: t.Callable) -> t.Callable:
    """Class decorator registering a user-callable Task.

    The class *should* extend CallableTaskBase and *must* provide its API.
    It is stored in the task registry under its `key` attribute and returned unchanged.
    """
    task_key = fn.key
    _callableTasks[task_key] = fn
    return fn
def StartupTask(fn: t.Callable) -> t.Callable:
    """
    Decorator for functions which are called shortly at instance startup.

    There is *no* guarantee that they actually run on the instance that just started up!
    Decorated functions must not take any arguments.
    The function is added to the startup registry and returned unchanged.
    """
    _startupTasks.append(fn)
    return fn
@CallDeferred
def runStartupTasks():
    """
    Calls every function registered through @StartupTask, in registration order.
    Do not call directly!
    """
    for startup_task in _startupTasks:
        startup_task()
class MetaQueryIter(type):
    """
    Meta class of all QueryIters.

    Its sole purpose is to record every class created with it, so callbacks
    can be emitted on the correct class later on.
    """
    _classCache = {}  # Mapping className -> Class

    def __init__(cls, name, bases, dct):
        # str(cls) serves as identifier, both in the cache and on the class itself
        class_id = str(cls)
        MetaQueryIter._classCache[class_id] = cls
        cls.__classID__ = class_id
        super().__init__(name, bases, dct)
class QueryIter(object, metaclass=MetaQueryIter):
    """
    BaseClass to run a database Query and process each entry matched.
    This will run each step deferred, so it is possible to process an arbitrary number of entries
    without being limited by time or memory.

    To use this class create a subclass, override the classmethods handleEntry and handleFinish and then
    call startIterOnQuery with an instance of a database Query (and possible some custom data to pass along)
    """
    queueName = "default"  # Name of the taskqueue we will run on

    @classmethod
    def startIterOnQuery(cls, query: db.Query, customData: t.Any = None) -> None:
        """
        Starts iterating the given query on this class. Will return immediately, the first batch will already
        run deferred.

        Warning: Any custom data *must* be json-serializable and *must* be passed in customData. You cannot store
        any data on this class as each chunk may run on a different instance!

        :param query: The query to iterate.
        :param customData: Data handed over to handleEntry, handleError and handleFinish.
        """
        assert not (query._customMultiQueryMerge or query._calculateInternalMultiQueryLimit), \
            "Cannot iter a query with postprocessing"
        assert isinstance(query.queries, db.QueryDefinition), "Unsatisfiable query or query with an IN filter"
        # Everything needed to rebuild the query in _qryStep
        qryDict = {
            "kind": query.kind,
            "srcSkel": query.srcSkel.kindName if query.srcSkel else None,
            "filters": query.queries.filters,
            "orders": [(propName, sortOrder.value) for propName, sortOrder in query.queries.orders],
            "startCursor": query.queries.startCursor,
            "endCursor": query.queries.endCursor,
            "origKind": query.origKind,
            "distinct": query.queries.distinct,
            "classID": cls.__classID__,
            "customData": customData,
            "totalCount": 0
        }
        cls._requeueStep(qryDict)

    @classmethod
    def _requeueStep(cls, qryDict: dict[str, t.Any]) -> None:
        """
        Internal use only. Pushes a new step defined in qryDict to either the taskqueue or append it to
        the current request if we are on the local development server.
        """
        if not queueRegion:  # Run tasks inline - hopefully development server
            req = current.request.get()
            task = lambda *args, **kwargs: cls._qryStep(qryDict)
            # NOTE(review): without a current request the step is silently dropped - confirm intended
            if req:
                req.pendingTasks.append(task)  # < This property will be only exist on development server!
            return
        taskClient.create_task(tasks_v2.CreateTaskRequest(
            parent=taskClient.queue_path(conf.instance.project_id, queueRegion, cls.queueName),
            task=tasks_v2.Task(
                app_engine_http_request=tasks_v2.AppEngineHttpRequest(
                    body=utils.json.dumps(qryDict).encode(),
                    http_method=tasks_v2.HttpMethod.POST,
                    relative_uri="/_tasks/queryIter",
                    app_engine_routing=tasks_v2.AppEngineRouting(
                        version=conf.instance.app_version,
                    ),
                )
            ),
        ))

    @classmethod
    def _qryStep(cls, qryDict: dict[str, t.Any]) -> None:
        """
        Internal use only. Processes one block of five entries from the query defined in qryDict and
        reschedules the next block.

        A failing handleEntry is retried once after five seconds; if it fails again,
        handleError decides whether the iteration continues.
        """
        from viur.core.skeleton import skeletonByKind
        # Rebuild the query from its serialized definition
        qry = db.Query(qryDict["kind"])
        qry.srcSkel = skeletonByKind(qryDict["srcSkel"])() if qryDict["srcSkel"] else None
        qry.queries.filters = qryDict["filters"]
        qry.queries.orders = [(propName, db.SortOrder(sortOrder)) for propName, sortOrder in qryDict["orders"]]
        qry.setCursor(qryDict["startCursor"], qryDict["endCursor"])
        qry.origKind = qryDict["origKind"]
        qry.queries.distinct = qryDict["distinct"]
        if qry.srcSkel:
            qryIter = qry.fetch(5)
        else:
            qryIter = qry.run(5)
        for item in qryIter:
            try:
                cls.handleEntry(item, qryDict["customData"])
            # Not a bare except: KeyboardInterrupt/SystemExit must not be swallowed and retried
            except Exception:  # First exception - we'll try another time (probably/hopefully transaction collision)
                time.sleep(5)
                try:
                    cls.handleEntry(item, qryDict["customData"])
                except Exception as e:  # Second exception - call error_handler
                    try:
                        doCont = cls.handleError(item, qryDict["customData"], e)
                    except Exception as handler_exc:
                        logging.error(f"handleError failed on {item} - bailing out")
                        logging.exception(handler_exc)
                        doCont = False
                    if not doCont:
                        logging.error(f"Exiting queryIter on cursor {qry.getCursor()!r}")
                        return
            qryDict["totalCount"] += 1
        cursor = qry.getCursor()
        if cursor:
            # Continue behind the last entry of this block
            qryDict["startCursor"] = cursor
            cls._requeueStep(qryDict)
        else:
            cls.handleFinish(qryDict["totalCount"], qryDict["customData"])

    @classmethod
    def handleEntry(cls, entry, customData):
        """
        Overridable hook to process one entry. "entry" will be either an db.Entity or an
        SkeletonInstance (if that query has been created by skel.all())

        Warning: If your query has an sortOrder other than __key__ and you modify that property here
        it is possible to encounter that object later one *again* (as it may jump behind the current cursor).
        """
        logging.debug(f"handleEntry called on {cls} with {entry}.")

    @classmethod
    def handleFinish(cls, totalCount: int, customData):
        """
        Overridable hook that indicates the current run has been finished.

        :param totalCount: Number of entries counted during the run.
        :param customData: The custom data passed to startIterOnQuery.
        """
        logging.debug(f"handleFinish called on {cls} with {totalCount} total Entries processed")

    @classmethod
    def handleError(cls, entry, customData, exception) -> bool:
        """
        Handle a error occurred in handleEntry.
        If this function returns True, the queryIter continues, otherwise it breaks and prints the current cursor.
        """
        logging.debug(f"handleError called on {cls} with {entry}.")
        logging.exception(exception)
        return True
class DeleteEntitiesIter(QueryIter):
    """
    Minimal Query-Iter which deletes every entity it comes across.

    ..Warning: When iterating over skeletons, make sure that the
        query was created using `Skeleton().all()`.
        This way the `Skeleton.delete()` method can be used and
        the appropriate post-processing can be done.
    """

    @classmethod
    def handleEntry(cls, entry, customData):
        """Delete `entry`: skeleton instances through their own delete(), anything else by key."""
        from viur.core.skeleton import SkeletonInstance
        if not isinstance(entry, SkeletonInstance):
            db.Delete(entry.key)
            return
        entry.delete()