Celery Source Code Analysis (7): Task Execution


Task Execution

In the previous article we eventually found that a task sent with apply_async is executed via the Request's execute_using_pool, which calls the pool's apply_async method. Before we get started, though, a small detour is in order: we have mentioned more than once that Celery improves performance internally through multiprocessing, and execute_using_pool is indeed handed a pool-like object, but so far we have no way to be sure that this pool is a process pool rather than a thread pool.
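
Which pool class a worker uses is controlled by the worker_pool setting (the -P/--pool option of the celery worker command maps to the same choice). A minimal configuration sketch, where the project name and broker URL are made up:

from celery import Celery

app = Celery('proj', broker='redis://localhost:6379/0')  # hypothetical broker URL

# worker_pool selects a key from the ALIASES table we will meet shortly;
# the same choice can be made per run with `celery -A proj worker -P solo`.
app.conf.worker_pool = 'solo'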

Let's go back to the initialization logic of the Pool sub-component of the Worker:

class Pool(bootsteps.StartStopStep):

    def create(self, w):
        semaphore = None
        max_restarts = None
        if w.app.conf.worker_pool in GREEN_POOLS:  # pragma: no cover
            warnings.warn(UserWarning(W_POOL_SETTING))
        threaded = not w.use_eventloop or IS_WINDOWS
        procs = w.min_concurrency
        w.process_task = w._process_task
        if not threaded:
            semaphore = w.semaphore = LaxBoundedSemaphore(procs)
            w._quick_acquire = w.semaphore.acquire
            w._quick_release = w.semaphore.release
            max_restarts = 100
            if w.pool_putlocks and w.pool_cls.uses_semaphore:
                w.process_task = w._process_task_sem
        allow_restart = w.pool_restarts
        pool = w.pool = self.instantiate(
            w.pool_cls, w.min_concurrency,
            initargs=(w.app, w.hostname),
            maxtasksperchild=w.max_tasks_per_child,
            max_memory_per_child=w.max_memory_per_child,
            timeout=w.time_limit,
            soft_timeout=w.soft_time_limit,
            putlocks=w.pool_putlocks and threaded,
            lost_worker_timeout=w.worker_lost_wait,
            threads=threaded,
            max_restarts=max_restarts,
            allow_restart=allow_restart,
            forking_enable=True,
            semaphore=semaphore,
            sched_strategy=self.optimization,
            app=w.app,
        )
        _set_task_join_will_block(pool.task_join_will_block)
        return pool

Pay attention to pool_cls here; tracing it down, it is resolved through the following alias table:

ALIASES = {
    'prefork': 'celery.concurrency.prefork:TaskPool',
    'eventlet': 'celery.concurrency.eventlet:TaskPool',
    'gevent': 'celery.concurrency.gevent:TaskPool',
    'solo': 'celery.concurrency.solo:TaskPool',
    'processes': 'celery.concurrency.prefork:TaskPool',  # XXX compat alias
}

In our setup this resolves to celery.concurrency.solo.TaskPool. (Note that Celery's documented default for worker_pool is actually prefork; the solo pool is just the simplest one to trace, so we will follow it here.)
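
We can verify this resolution by hand: celery.concurrency exposes a get_implementation helper that resolves an alias (or a full 'module:Class' path) through the ALIASES table above. A quick sketch, assuming Celery 4.x:

from celery.concurrency import get_implementation

print(get_implementation('solo'))     # <class 'celery.concurrency.solo.TaskPool'>
print(get_implementation('prefork'))  # <class 'celery.concurrency.prefork.TaskPool'>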

We then find that the apply_async being called is in fact pool.apply_async, which means our task is handed to the TaskPool's apply_async method for execution. Locating TaskPool's apply_async, we see that it actually just calls on_apply:

def apply_async(self, target, args=None, kwargs=None, **options):
    """Equivalent of the :func:`apply` built-in function.

    Callbacks should optimally return as soon as possible since
    otherwise the thread which handles the result will get blocked.
    """
    kwargs = {} if not kwargs else kwargs
    args = [] if not args else args
    if self._does_debug:
        logger.debug('TaskPool: Apply %s (args:%s kwargs:%s)',
                     target, truncate(safe_repr(args), 1024),
                     truncate(safe_repr(kwargs), 1024))

    return self.on_apply(target, args, kwargs,
                         waitforslot=self.putlocks,
                         callbacks_propagate=self.callbacks_propagate,
                         **options)

Next we discover that this solo TaskPool class overrides on_apply with apply_target:

class TaskPool(BasePool):
    """Solo task pool (blocking, inline, fast)."""

    body_can_be_buffer = True

    def __init__(self, *args, **kwargs):
        super(TaskPool, self).__init__(*args, **kwargs)
        self.on_apply = apply_target
        self.limit = 1
        signals.worker_process_init.send(sender=None)

    def _get_info(self):
        return {
            'max-concurrency': 1,
            'processes': [os.getpid()],
            'max-tasks-per-child': None,
            'put-guarded-by-semaphore': True,
            'timeouts': (),
        }
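
Since on_apply is apply_target (shown next), the solo pool executes tasks inline: apply_async blocks until the target returns and then fires the callback. A minimal sketch, assuming Celery 4.x is installed (add is our own toy function):

from celery.concurrency.solo import TaskPool

def add(x, y):
    return x + y

pool = TaskPool()
# runs add(1, 2) synchronously in this very process,
# then hands the return value to the callback
pool.apply_async(add, args=(1, 2), callback=print)  # prints 3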

The body of apply_target is:

def apply_target(target, args=(), kwargs=None, callback=None,
                 accept_callback=None, pid=None, getpid=os.getpid,
                 propagate=(), monotonic=monotonic, **_):
    """Apply function within pool context."""
    kwargs = {} if not kwargs else kwargs
    if accept_callback:
        accept_callback(pid or getpid(), monotonic())
    try:
        ret = target(*args, **kwargs)
    except propagate:
        raise
    except Exception:
        raise
    except (WorkerShutdown, WorkerTerminate):
        raise
    except BaseException as exc:
        try:
            reraise(WorkerLostError, WorkerLostError(repr(exc)),
                    sys.exc_info()[2])
        except WorkerLostError:
            callback(ExceptionInfo())
    else:
        callback(ret)
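
The exception ladder here deserves a closer look: exceptions listed in propagate and plain Exception subclasses are simply re-raised (the tracer layer above handles them), WorkerShutdown/WorkerTerminate bubble up to stop the worker, and only the remaining BaseExceptions, such as a bare SystemExit raised inside a task, are converted into a WorkerLostError delivered through the callback. A self-contained sketch of that last branch (the WorkerLostError here is a stand-in, not billiard's):

import sys

class WorkerLostError(Exception):
    """Stand-in for billiard.exceptions.WorkerLostError."""

def run_like_apply_target(target, callback):
    try:
        ret = target()
    except Exception:
        raise                     # ordinary task errors: the tracer deals with them
    except BaseException as exc:  # e.g. SystemExit raised inside the task body
        # wrap it, then immediately catch it so the error reaches the
        # callback (celery hands the callback an ExceptionInfo instead)
        try:
            raise WorkerLostError(repr(exc)).with_traceback(sys.exc_info()[2])
        except WorkerLostError as wle:
            callback(wle)
    else:
        callback(ret)

def bad_task():
    raise SystemExit('child told to exit')

run_like_apply_target(bad_task, callback=print)
# prints: SystemExit('child told to exit')  (wrapped as a WorkerLostError)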

The target here points to the _fast_trace_task function.

And the args hold our task's metadata:

 ('itsm.ticket.tasks.add', '9784662a-2afe-4fce-a320-9f537f1486eb', {'lang': 'py', 'task': 'itsm.ticket.tasks.add', 'id': '9784662a-2afe-4fce-a320-9f537f1486eb', 'shadow': None, 'eta': None, 'expires': None, 'group': None, 'retries': 0, 'timelimit': [None, None], 'root_id': '9784662a-2afe-4fce-a320-9f537f1486eb', 'parent_id': None, 'argsrepr': '(1, 2)', 'kwargsrepr': '{}', 'origin': 'gen62869@MARKHAN-MB0', 'reply_to': 'd6aa95a4-f8cd-3c4e-a04f-a6b3c80431e8', 'correlation_id': '9784662a-2afe-4fce-a320-9f537f1486eb', 'hostname': 'celery@MARKHAN-MB0', 'delivery_info': {'exchange': '', 'routing_key': 'default', 'priority': 0, 'redelivered': None}, 'args': (1, 2), 'kwargs': {}}, b'\x80\x02K\x01K\x02\x86q\x00}q\x01}q\x02(X\t\x00\x00\x00callbacksq\x03NX\x08\x00\x00\x00errbacksq\x04NX\x05\x00\x00\x00chainq\x05NX\x05\x00\x00\x00chordq\x06Nu\x87q\x07.', 'application/x-python-serialize', 'binary')
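
The last two elements tell us the body is pickled ('application/x-python-serialize'), so we can decode it by hand and confirm that it is exactly the (args, kwargs, embed) triple that _fast_trace_task unpacks below:

import pickle

body = (b'\x80\x02K\x01K\x02\x86q\x00}q\x01}q\x02(X\t\x00\x00\x00callbacksq\x03N'
        b'X\x08\x00\x00\x00errbacksq\x04NX\x05\x00\x00\x00chainq\x05N'
        b'X\x05\x00\x00\x00chordq\x06Nu\x87q\x07.')

print(pickle.loads(body))
# ((1, 2), {}, {'callbacks': None, 'errbacks': None, 'chain': None, 'chord': None})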

So the function that ultimately executes our task is _fast_trace_task, which appears to be the optimized fast path of the task tracer (installed when the worker's optimizations are set up). Let's find it and jump in for a look:

def _fast_trace_task(task, uuid, request, body, content_type,
                     content_encoding, loads=loads_message, _loc=None,
                     hostname=None, **_):
    _loc = _localized if not _loc else _loc
    embed = None
    tasks, accept, hostname = _loc
    if content_type:
        args, kwargs, embed = loads(
            body, content_type, content_encoding, accept=accept,
        )
    else:
        args, kwargs, embed = body
    request.update({
        'args': args, 'kwargs': kwargs,
        'hostname': hostname, 'is_eager': False,
    }, **embed or {})
    R, I, T, Rstr = tasks[task].__trace__(
        uuid, args, kwargs, request,
    )
    return (1, R, T) if I else (0, Rstr, T)

Those last few lines honestly left me baffled at first. After untangling them: tasks[task].__trace__ runs the task and returns a 4-tuple (R, I, T, Rstr), where R is the (possibly serialized) return value or exception info, I is an Info object that is non-None only when the task failed, T is the runtime, and Rstr is a truncated textual repr of the result. The leading 1/0 flag in the final return value therefore tells the pool's result handler whether the task failed (1) or succeeded (0).

The tasks here is the task registry we have been maintaining all along. It gets stashed into the module-level _localized list (together with the accepted content types and the hostname) when setup_worker_optimizations runs during worker startup, which is how it ends up available in this context.

 request = {'lang': 'py', 'task': 'itsm.ticket.tasks.add', 'id': '5c88a734-21e5-4a61-b137-d9e6d2c3827d', 'shadow': None, 'eta': None, 'expires': None, 'group': None, 'retries': 0, 'timelimit': [None, None], 'root_id': '5c88a734-21e5-4a61-b137-d9e6d2c3827d', 'parent_id': None, 'argsrepr': '(1, 2)', 'kwargsrepr': '{}', 'origin': 'gen62869@MARKHAN-MB0', 'reply_to': 'd6aa95a4-f8cd-3c4e-a04f-a6b3c80431e8', 'correlation_id': '5c88a734-21e5-4a61-b137-d9e6d2c3827d', 'hostname': 'celery@MARKHAN-MB0', 'delivery_info': {'exchange': '', 'routing_key': 'default', 'priority': 0, 'redelivered': None}, 'args': (1, 2), 'kwargs': {}, 'is_eager': False, 'callbacks': None, 'errbacks': None, 'chain': None, 'chord': None}

So what actually gets executed is the __trace__ attribute of our task object; chasing it down, it ultimately points to the trace_task closure built by build_tracer:

def build_tracer(name, task, loader=None, hostname=None, store_errors=True,
                 Info=TraceInfo, eager=False, propagate=False, app=None,
                 monotonic=monotonic, trace_ok_t=trace_ok_t,
                 IGNORE_STATES=IGNORE_STATES):
    # grab the callable: the task itself if it defines a custom
    # __call__, otherwise its run() method
    fun = task if task_has_custom(task, '__call__') else task.run

    def trace_task(uuid, args, kwargs, request=None):
        # R      - is the possibly prepared return value.
        # I      - is the Info object.
        # T      - runtime
        # Rstr   - textual representation of return value
        # retval - is the always unmodified return value.
        # state  - is the resulting task state.

        # This function is very long because we've unrolled all the calls
        # for performance reasons, and because the function is so long
        # we want the main variables (I, and R) to stand out visually from the
        # the rest of the variables, so breaking PEP8 is worth it ;)
        R = I = T = Rstr = retval = state = None
        task_request = None
        time_start = monotonic()
        try:
            try:
                kwargs.items
            except AttributeError:
                raise InvalidTaskError(
                    'Task keyword arguments is not a mapping')
            push_task(task)
            task_request = Context(request or {}, args=args,
                                   called_directly=False, kwargs=kwargs)
            root_id = task_request.root_id or uuid
            task_priority = task_request.delivery_info.get('priority') if \
                inherit_parent_priority else None
            push_request(task_request)
            try:
                # -*- PRE -*-
                if prerun_receivers:
                    send_prerun(sender=task, task_id=uuid, task=task,
                                args=args, kwargs=kwargs)
                loader_task_init(uuid, task)
                if track_started:
                    store_result(
                        uuid, {'pid': pid, 'hostname': hostname}, STARTED,
                        request=task_request,
                    )

                # -*- TRACE -*-
                try:
                    # the task function is actually executed right here
                    R = retval = fun(*args, **kwargs)
                    state = SUCCESS
                    
                    """
                    省略部分代码
                    """

    return trace_task
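
Stripped of signals, result backends and all the error handling, the closure that build_tracer returns boils down to roughly the following. This is my own heavily simplified sketch, not celery's actual code; store_result and send_prerun stand in for the backend write and the prerun signal:

from time import monotonic

def build_tracer_sketch(fun, store_result, send_prerun):
    # fun is the resolved callable: the task itself or task.run
    def trace_task(uuid, args, kwargs, request=None):
        time_start = monotonic()
        send_prerun(uuid, args, kwargs)       # -*- PRE -*-
        try:
            retval = fun(*args, **kwargs)     # -*- TRACE -*-: user code runs here
            state = 'SUCCESS'
        except Exception as exc:
            retval, state = exc, 'FAILURE'
        runtime = monotonic() - time_start
        store_result(uuid, retval, state)     # persist to the result backend
        return retval, state, runtime
    return trace_task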

At this point the task has been executed successfully. How its result is handled, and what fault-tolerance measures apply when execution fails, will be left for later articles.