Airflow 分支:一个有时依赖上游任务的任务
我有两个任务:task_a
和 task_b
。有两个参数 run_task_a
和 run_task_b
,用来决定这两个任务是否要执行。还有一个参数是 task_a
的输入。这里有个重点:
如果 task_a
被执行,那么 task_b
只能在 task_a
完成后才能开始。如果 task_a
没有被执行,那么 task_b
可以随时开始。
(动机:task_a
是主要任务。每次运行 task_a
可能会产生一些无用的东西,而 task_b
是用来清理这些东西的。不过,有时候我们也希望独立触发 task_b
。)
这是我目前写的代码:
from airflow.decorators import dag, task
from airflow.models.param import Param
from datetime import datetime
default_args = {
'owner': 'xyz',
'email_on_retry': False,
'email_on_failure': False,
'retries': 0,
'provide_context': True,
'depends_on_past': False
}
@dag(
default_args=default_args,
start_date=datetime(2024, 3, 7),
schedule_interval=None,
params={
'run_task_a': Param(
True,
type='boolean'),
'run_task_b': Param(
True,
type='boolean'),
'param_for_task_a': Param(
'foo',
enum=['foo','bar'],
type='string')
}
)
def my_dag():
@task
def get_context_values(**context):
context_values = dict()
context_values['params'] = context['params']
return context_values
@task.branch
def branching(context_values):
tasks_to_run = []
if context_values['params']['run_task_a']:
tasks_to_run.append('task_a')
if context_values['params']['run_task_b']:
tasks_to_run.append('task_b')
return tasks_to_run
@task
def task_a(context_values):
param_for_task_a = context_values['params']['param_for_task_a']
if param_for_task_a == 'foo':
# Do some stuff
pass
if param_for_task_a == 'bar':
# Do some different stuff
pass
return None
@task
def task_b():
# Do some more stuff
return None
# Taskflow
context_values = get_context_values()
branching(context_values) >> [task_a(context_values),task_b()]
my_dag()
问题是,当 run_task_a == True
和 run_task_b == True
时:两个任务都会运行,但 task_b
并不会等 task_a
完成就开始,因为它们之间没有依赖关系。我尝试通过将 task_b
设置为 task_a
的下游任务来添加这种依赖关系,但这样一来,如果 run_task_a == False
而 run_task_b == True
,task_b
就不会运行了。触发规则似乎也不是解决办法,因为如果 run_task_b == False
,task_b
就不应该运行。
3 个回答
一个简单的解决办法是让任务 task_b 依赖于任务 task_a,使用 >>
来连接它们,同时把 task_b 的 触发规则 改成 none_failed
。这样,如果因为分支的原因任务 task_a 被跳过,任务 task_b 仍然会运行,即使它的上游任务没有成功。
你可以调整这个分支函数,让它根据不同的参数来控制程序的执行流程。你需要确保只有在task_a
需要运行的时候,task_b
才会依赖于task_a
。
from airflow.models.param import Param
from datetime import datetime
default_args = {
'owner': 'xyz',
'email_on_retry': False,
'email_on_failure': False,
'retries': 0,
'provide_context': True,
'depends_on_past': False
}
@dag(
default_args=default_args,
start_date=datetime(2024, 3, 7),
schedule_interval=None,
params={
'run_task_a': Param(True, type='boolean'),
'run_task_b': Param(True, type='boolean'),
'param_for_task_a': Param('foo', enum=['foo', 'bar'], type='string')
}
)
def my_dag():
@task
def get_context_values(**context):
return context['params']
@task
def task_a(param_for_task_a):
if param_for_task_a == 'foo':
# Do some stuff
pass
elif param_for_task_a == 'bar':
# Do some different stuff
pass
return None
@task
def task_b():
# Do some more stuff
return None
# Taskflow
context_values = get_context_values()
# Branching logic
run_task_a = context_values['run_task_a']
run_task_b = context_values['run_task_b']
if run_task_a and run_task_b:
task_a_output = task_a(context_values['param_for_task_a'])
task_a_output >> task_b()
elif run_task_b:
task_b()
# Add an else clause if you want to handle cases where both are False
my_dag_instance = my_dag()
经过很多次的尝试和错误,我们终于成功地使用了短路操作来实现这个功能:
from airflow.decorators import dag, task
from airflow.models.param import Param
from airflow.utils.trigger_rule import TriggerRule
from datetime import datetime
default_args = {
'owner': 'xyz',
'email_on_retry': False,
'email_on_failure': False,
'retries': 0,
'provide_context': True,
'depends_on_past': False
}
@dag(
default_args=default_args,
start_date=datetime(2024, 3, 7),
schedule_interval=None,
params={
'run_task_a': Param(
True,
type='boolean'),
'run_task_b': Param(
True,
type='boolean'),
'param_for_task_a': Param(
'foo',
enum=['foo','bar'],
type='string')
}
)
def my_dag():
@task
def get_context_values(**context):
context_values = dict()
context_values['params'] = context['params']
return context_values
@task.short_circuit
def short_circuit(context_values,key):
return context_values['params'][key]
@task
def task_a(context_values):
param_for_task_a = context_values['params']['param_for_task_a']
if param_for_task_a == 'foo':
# Do some stuff
pass
if param_for_task_a == 'bar':
# Do some different stuff
pass
return None
@task(trigger_rule=TriggerRule.NONE_FAILED_MIN_ONE_SUCCESS)
def task_b():
# Do some more stuff
return None
# Taskflow
context_values = get_context_values()
short_circuit_a = short_circuit.override(
task_id='short_circuit_a',ignore_downstream_trigger_rules=False)(context_values,'run_task_a')
a = task_a(context_values)
short_circuit_b = short_circuit.override(
task_id='short_circuit_b')(context_values,'run_task_b')
b = task_b()
short_circuit_a >> a
short_circuit_b >> b
a >> b
my_dag()