Commit 26473feb authored by Ottomata

CI Pipeline fixes

parent a68eaf2e
......@@ -7,10 +7,13 @@ include:
stages:
- test
variables:
# Add CI_PROJECT_DIR to PYTHONPATH so we can import from wmf_airflow_common
PYTHONPATH: ${CI_PROJECT_DIR}
before_script:
- !reference [.setup_conda, before_script]
# Need these for transitive dependencies of pyarrow, skein, spark, etc.
# Need these for building transitive dependencies like pyarrow, skein, spark, etc.
- apt install -y gcc g++ libkrb5-dev libsasl2-dev
- pip install tox
......
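The PYTHONPATH variable above is what lets the CI test jobs import the shared helpers straight from the checkout, without installing the repo as a package. A minimal sanity check that relies on it (illustrative only; this test is not part of the commit):

    # Runs inside the CI job, where PYTHONPATH=${CI_PROJECT_DIR}.
    def test_wmf_airflow_common_is_importable():
        import wmf_airflow_common  # resolves because the repo root is on PYTHONPATH
        assert wmf_airflow_common is not None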
......@@ -18,7 +18,7 @@ install_requires =
apache-airflow-providers-apache-spark
apache-airflow-providers-apache-hdfs
# TODO: Use gitlab pypi for this dependency
workflow_utils @ git+https://gitlab.wikimedia.org/repos/data-engineering/workflow_utils.git@v0.2.2
workflow_utils @ git+https://gitlab.wikimedia.org/repos/data-engineering/workflow_utils.git@main
isodate
mergedeep
skein ==0.8.1
......@@ -95,7 +95,7 @@ select = E9,F63,F7,F82
files = wmf_airflow_common
ignore_missing_imports = True
follow_imports = silent
; follow_imports = silent
# methods signature should be typed
disallow_untyped_defs = True
......
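The disallow_untyped_defs setting is what drives the type annotations added throughout the rest of this commit: with it enabled, mypy rejects any function whose signature is not fully typed. A minimal illustration (hypothetical function, not from this repo):

    def artifact_url(artifact_id):  # mypy error: function is missing a type annotation
        ...

    def artifact_url(artifact_id: str) -> str:  # accepted
        ...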
import os
import pytest
from unittest import mock
from airflow.models import DagBag
# This is needed because the DAG under test uses Variables.
# TODO: make this into a common fixture usable by all DAG tests.
@pytest.fixture(name='airflow', autouse=True, scope="session")
def fixture_airflow(tmp_path_factory):
"""
Sets up an airflow SQLite database and airflow home
fixture used by the entire test session.
"""
from airflow.utils import db
airflow_environ = {
'AIRFLOW__CORE__LOAD_DEFAULT_CONNECTIONS': "False",
'AIRFLOW__CORE__LOAD_EXAMPLES': 'False',
'AIRFLOW__CORE__UNIT_TEST_MODE': 'True',
'AIRFLOW_HOME': os.path.join(tmp_path_factory.mktemp('airflow_home'))
}
with mock.patch.dict(os.environ, airflow_environ, clear=True):
db.resetdb()
yield
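Because this fixture brings up a real SQLite-backed metadata database, tests can also seed Airflow Variables before a DAG file is parsed. A hedged sketch using the stock Variable API (the variable name and payload are hypothetical):

    from airflow.models import Variable

    def test_dag_with_overridden_variable(airflow):
        # Variable.set writes into the SQLite test DB created by fixture_airflow above.
        Variable.set('aqs_hourly_config', {'start_date': '2022-04-01T00:00:00'}, serialize_json=True)
        # ... then build a DagBag and assert on the resulting DAG, as in the tests below.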
# TODO: Make a common fixture that automatically loads and tests all dag validity.
# https://www.astronomer.io/events/recaps/testing-airflow-to-bulletproof-your-code/
@pytest.fixture(name='dagbag')
def fixture_dagbag():
dag_bag = DagBag(None, include_examples=False, read_dags_from_db=False)
dag_file = os.path.join(os.path.abspath(os.getcwd()),
'analytics_test', 'dags', 'aqs', 'hourly_dag.py')
'analytics', 'dags', 'aqs', 'hourly_dag.py')
dag_bag.process_file(dag_file)
return dag_bag
def test_aqs_hourly_loaded(dagbag):
def test_aqs_hourly_loaded(airflow, dagbag):
assert dagbag.import_errors == {}
dag = dagbag.get_dag(dag_id="aqs_hourly")
assert dag is not None
......@@ -28,4 +53,4 @@ def test_aqs_hourly_loaded(dagbag):
assert dag.default_args['do_xcom_push'] is False
# Tests that the defaults from dag_config are here.
assert dag.default_args['metastore_conn_id'] == 'analytics-test-hive'
assert dag.default_args['metastore_conn_id'] == 'analytics-hive'
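A sketch of the common validity test described by the TODO above, parametrized over every DAG file. The DagBag usage mirrors the fixture in this diff; the glob pattern and directory layout are assumptions:

    import glob
    import pytest
    from airflow.models import DagBag

    # One test per discovered DAG file (path pattern is hypothetical).
    dag_files = glob.glob('*/dags/**/*_dag.py', recursive=True)

    @pytest.mark.parametrize('dag_file', dag_files)
    def test_dag_file_is_valid(airflow, dag_file):
        dag_bag = DagBag(None, include_examples=False, read_dags_from_db=False)
        dag_bag.process_file(dag_file)
        assert dag_bag.import_errors == {}
        assert len(dag_bag.dags) > 0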
import os
import pytest
from unittest import mock
from airflow.models import DagBag
# This is needed because mediarequest_hourly_dag uses Variables.
# TODO: make this into a common fixture usable by all DAG tests.
@pytest.fixture(name='airflow', autouse=True, scope="session")
def fixture_airflow(tmp_path_factory):
"""
Sets up an airflow SQLite database and airflow home
fixture used by the entire test session.
"""
from airflow.utils import db
airflow_environ = {
'AIRFLOW__CORE__LOAD_DEFAULT_CONNECTIONS': "False",
'AIRFLOW__CORE__LOAD_EXAMPLES': 'False',
'AIRFLOW__CORE__UNIT_TEST_MODE': 'True',
'AIRFLOW_HOME': os.path.join(tmp_path_factory.mktemp('airflow_home'))
}
with mock.patch.dict(os.environ, airflow_environ, clear=True):
db.resetdb()
yield
# TODO: Make a common fixture that automatically loads and tests all dag validity.
# https://www.astronomer.io/events/recaps/testing-airflow-to-bulletproof-your-code/
@pytest.fixture(name='dagbag')
def fixture_dagbag():
dag_bag = DagBag(None, include_examples=False, read_dags_from_db=False)
......@@ -12,7 +37,7 @@ def fixture_dagbag():
return dag_bag
def test_mediarequest_hourly_dag_loaded(dagbag):
def test_mediarequest_hourly_dag_loaded(airflow, dagbag):
assert dagbag.import_errors == {}
dag = dagbag.get_dag(dag_id="mediarequest_hourly")
assert dag is not None
......
......@@ -146,7 +146,8 @@ Container: container_e33_1637058075222_419446_01_000001 on an-worker1109.eqiad.w
'env': {
'env1': 'val1',
},
'log_level': 'DEBUG',
'client_log_level': 'DEBUG',
'master_log_level': 'DEBUG',
'principal': 'user@domain',
'keytab': '/path/to/keytab',
'resources': {
......
......@@ -8,12 +8,13 @@ envlist = py37,py39
isolated_build = True
requires = tox-conda
# Main test job and defaults for other test jobs.
# This is run once for each env in the default tox envlist defined above.
[testenv]
deps =
.[test]
# Passed in from gitlab-ci to import from wmf_airflow_common
passenv = PYTHONPATH
commands=
pytest
......
......@@ -14,14 +14,14 @@ class ArtifactRegistry:
URLs where their artifacts have been synced.
"""
def __init__(self, artifact_config_files):
def __init__(self, artifact_config_files: List[str]):
"""
:param artifact_config_files:
List of artifact .yaml config files to load.
"""
self._artifacts = Artifact.load_artifacts_from_config(artifact_config_files)
def artifact(self, artifact_id):
def artifact(self, artifact_id: str) -> Artifact:
"""
Gets the registered Artifact by artifact_id.
Throws KeyError if artifact_id is not registered in config files.
......@@ -30,19 +30,19 @@ class ArtifactRegistry:
raise KeyError(f'Artifact id {artifact_id} not declared in configuration.')
return self._artifacts[artifact_id]
def artifact_url(self, artifact_id):
def artifact_url(self, artifact_id: str) -> str:
"""
Gets the first expected cached url for the artifact id.
The artifact is not checked for existence at this url.
"""
return self.artifact(artifact_id).cached_url()
return str(self.artifact(artifact_id).cached_url())
@classmethod
def for_wmf_airflow_instance(
cls,
airflow_instance_name: str,
other_files: Optional[List[str]] = None
):
) -> 'ArtifactRegistry':
"""
Uses WMF airflow instance conventions in the
data-engineering/airflow-dags repository to
......
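Taken together, the annotations give ArtifactRegistry a small, fully typed surface. A hedged usage sketch using only methods shown in this hunk (the instance name and artifact id are hypothetical):

    registry = ArtifactRegistry.for_wmf_airflow_instance('analytics')
    # artifact() raises KeyError if the id is not declared in the loaded config files.
    refinery_jar = registry.artifact('refinery-job-0.1.23-shaded.jar')
    refinery_jar_url: str = registry.artifact_url('refinery-job-0.1.23-shaded.jar')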
......@@ -28,7 +28,7 @@ from wmf_airflow_common.util import \
airflow_config_get
def get_base_spark_defaults_args():
def get_base_spark_defaults_args() -> dict:
"""
Gets values for some Spark operator defaults we might always want to set.
"""
......@@ -164,8 +164,8 @@ def get(
if extra_default_args is None:
extra_default_args = {}
return merge(
return dict(merge(
{},
common_default_args,
extra_default_args,
)
))
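Wrapping merge() in dict() gives mypy a concrete dict return value; the merge semantics are unchanged: later sources win on conflicting keys and nested dicts are merged deeply. A small illustration with mergedeep (values are hypothetical):

    from mergedeep import merge

    default_args = dict(merge(
        {},
        {'owner': 'analytics', 'retries': 3},  # common defaults
        {'retries': 5},                        # extra defaults override
    ))
    assert default_args == {'owner': 'analytics', 'retries': 5}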
......@@ -61,6 +61,8 @@ use the following getters instead of var_props.get():
The default value should be a python dict.
"""
from typing import cast, Any, Callable
from datetime import datetime, timedelta
from isodate import parse_duration
from json import JSONDecodeError
......@@ -68,9 +70,9 @@ from airflow.models import Variable
class VariableProperties:
def __init__(self, variable_name):
def __init__(self, variable_name: str):
try:
self.variable = Variable.get(
self.variable: Variable = Variable.get(
variable_name,
deserialize_json=True,
default_var={},
......@@ -80,20 +82,31 @@ class VariableProperties:
f'Variable {variable_name} can not be parsed as JSON.'
)
def get(self, property_name, default_value):
def get(self, property_name: str, default_value: Any) -> Any:
return self.variable.get(property_name, default_value)
def get_datetime(self, property_name, default_value):
def get_datetime(self, property_name: str, default_value: datetime) -> datetime:
if type(default_value) is not datetime:
raise ValueError('Default value is not a datetime.')
return self.get_parsed(property_name, datetime.fromisoformat, default_value)
return cast(
datetime,
self.get_parsed(property_name, datetime.fromisoformat, default_value)
)
def get_timedelta(self, property_name, default_value):
def get_timedelta(self, property_name: str, default_value: timedelta) -> timedelta:
if type(default_value) is not timedelta:
raise ValueError('Default value is not a timedelta.')
return self.get_parsed(property_name, parse_duration, default_value)
def get_parsed(self, property_name, parser, default_value):
return cast(
timedelta,
self.get_parsed(property_name, parse_duration, default_value)
)
def get_parsed(
self,
property_name: str,
parser: Callable[[Any], Any],
default_value: Any
) -> Any:
if property_name not in self.variable:
return default_value
try:
......@@ -101,7 +114,7 @@ class VariableProperties:
except ValueError:
raise ValueError(f'Property {property_name} can not be parsed.')
def get_merged(self, property_name, default_value):
def get_merged(self, property_name: str, default_value: dict) -> dict:
if type(default_value) is not dict:
raise ValueError('Default value is not a dict.')
if property_name not in self.variable:
......
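A hedged usage sketch of the typed getters above; the variable and property names are hypothetical, while the parsers are the ones wired up in this file (datetime.fromisoformat and isodate.parse_duration):

    var_props = VariableProperties('aqs_hourly_config')
    # Each getter falls back to its default if the Variable or property is absent.
    start_date = var_props.get_datetime('start_date', datetime(2022, 4, 1))
    delay = var_props.get_timedelta('delay', timedelta(hours=6))
    default_args = var_props.get_merged('default_args', {'retries': 3})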
from typing import List, Any, Optional
from datetime import datetime
import os
from airflow import DAG
from airflow.operators.python import BranchPythonOperator
......@@ -8,7 +10,7 @@ from wmf_airflow_common.operators.email import HdfsEmailOperator
from wmf_airflow_common.operators.spark import SparkSqlOperator, SparkSubmitOperator
from wmf_airflow_common.partitions_builder import daily_partitions
def should_alert(**kwargs):
def should_alert(**kwargs: Any) -> List[str]:
"""
Determines whether the last execution of the detect_anomalies task
has detected anomalies and whether an alert should be sent for them.
......@@ -36,18 +38,18 @@ class AnomalyDetectionDAG(DAG):
def __init__(
self,
dag_id,
start_date,
source_table,
source_granularity,
metrics_query,
destination_table,
hdfs_temp_directory,
anomaly_threshold,
anomaly_email,
hadoop_name_node,
refinery_job_shaded_jar,
**kwargs,
dag_id: str,
start_date: Optional[datetime],
source_table: str,
source_granularity: str,
metrics_query: str,
destination_table: str,
hdfs_temp_directory: str,
anomaly_threshold: str,
anomaly_email: str,
hadoop_name_node: str,
refinery_job_shaded_jar: str,
**kwargs: Any,
):
"""
Initializes the DAG (at interpretation time).
......@@ -143,6 +145,7 @@ class AnomalyDetectionDAG(DAG):
'Affected metrics and corresponding deviations:<br/>'
),
embedded_file=anomaly_output_path,
hadoop_name_node=hadoop_name_node,
dag=self,
)
......
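With the constructor now fully typed, a misconfigured instantiation is caught by mypy instead of at DAG parse time. A hedged sketch of an instantiation (every value is hypothetical; only the parameter names and types come from this diff):

    dag = AnomalyDetectionDAG(
        dag_id='aqs_hourly_anomaly_detection',
        start_date=datetime(2022, 4, 1),
        source_table='wmf.aqs_hourly',
        source_granularity='hourly',
        metrics_query='SELECT ...',
        destination_table='wmf.anomaly_metrics',
        hdfs_temp_directory='/tmp/anomaly_detection',
        anomaly_threshold='10',  # typed as str in this signature
        anomaly_email='alerts@example.org',
        hadoop_name_node='hdfs://analytics-hadoop',
        refinery_job_shaded_jar='hdfs:///path/to/refinery-job-shaded.jar',
    )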
from typing import Any, Tuple, List, Union, Optional
from typing import cast, Any, Tuple, List, Union, Optional
import os
import re
......@@ -56,24 +56,25 @@ class SkeinHook(BaseHook):
"""
if isinstance(application_spec, str):
self._application_spec = \
app_spec = \
skein.model.ApplicationSpec.from_file(application_spec)
else:
self._application_spec = application_spec
app_spec = cast(skein.model.ApplicationSpec, application_spec)
self._client_kwargs = client_kwargs or {}
self._app_log_collection_enabled = app_log_collection_enabled
self._application_spec: skein.model.ApplicationSpec = app_spec
self._client_kwargs: dict = client_kwargs or {}
self._app_log_collection_enabled: bool = app_log_collection_enabled
self._application_id = None
self._application_client = None
self._yarn_logs = None
self._application_id: Optional[str] = None
self._application_client: Optional[skein.ApplicationClient] = None
self._yarn_logs: Optional[str] = None
# finished will be set after submit() has been called and done.
# This is used to prevent submit() from being called more than once.
self._finished = False
self._finished: bool = False
super().__init__(**kwargs)
def get_conn(self):
def get_conn(self) -> None:
"""
Overridden for compatibility with Airflow BaseHook.
We may want to implement this if we want to store Hadoop
......@@ -94,14 +95,14 @@ class SkeinHook(BaseHook):
"""
The skein Master part of the ApplicationSpec as a dict.
"""
return self._application_spec.to_dict()['master']
return dict(self._application_spec.to_dict()['master'])
@cached_property
def _script(self) -> str:
"""
The script that will be run in the YARN AppMaster.
"""
return self._master['script']
return str(self._master['script'])
@cached_property
def _app_owner(self) -> str:
......@@ -178,7 +179,7 @@ class SkeinHook(BaseHook):
return self._yarn_logs
def find_yarn_app_ids(self) -> List[str]:
def find_yarn_app_ids(self) -> Optional[List[str]]:
"""
Searches Skein YARN application master logs for other YARN app ids and returns
a list of mentioned YARN app ids.
......@@ -244,17 +245,24 @@ class SkeinHook(BaseHook):
)
yarn_app_ids = self.find_yarn_app_ids()
self.log.info(
'%s - YARN logs mentioned the following YARN application ids: %s',
self,
','.join(yarn_app_ids)
)
self.log.info(
'%s - To view all mentioned YARN application logs, '
'run the following commands:\n\t%s',
self,
'\n\t'.join(yarn_logs_commands(yarn_app_ids, self._app_owner, sudo=True))
)
if yarn_app_ids is not None:
self.log.info(
'%s - YARN logs mentioned the following YARN application ids: %s',
self,
','.join(yarn_app_ids)
)
self.log.info(
'%s - To view all mentioned YARN application logs, '
'run the following commands:\n\t%s',
self,
'\n\t'.join(yarn_logs_commands(yarn_app_ids, self._app_owner, sudo=True))
)
else:
self.log.info(
'%s - YARN logs did not mention any YARN application ids.',
self
)
else:
self.log.info(
'%s - YARN application log collection is disabled. To view logs for the '
......@@ -263,7 +271,7 @@ class SkeinHook(BaseHook):
'you will need to look at these logs and run a similar command but with '
'the appropriate YARN application_id.',
self,
'\n\t'.join(yarn_logs_commands([self._application_id], self._app_owner, sudo=True))
'\n\t'.join(yarn_logs_commands([str(self._application_id)], self._app_owner, sudo=True))
)
if self.final_status() != 'SUCCEEDED':
......@@ -275,7 +283,7 @@ class SkeinHook(BaseHook):
return str(self.final_status())
def stop(self, final_status: str, diagnostics: str = None):
def stop(self, final_status: str, diagnostics: Optional[str] = None) -> None:
"""
Stops the YARN application (if running), closes the Skein ApplicationClient
and closes the Skein Client (and Driver, if this Client started it).
......@@ -306,11 +314,11 @@ class SkeinHook(BaseHook):
# submit will no longer do anything.
self._finished = True
def on_kill(self):
def on_kill(self) -> None:
return self.stop('KILLED', 'Killed by Airflow executor')
def __str__(self) -> str:
parts = [__class__.__name__, self._application_spec.name]
parts = [self.__class__.__name__, self._application_spec.name]
if self._application_id:
parts += [self._application_id]
return ' '.join(parts).strip()
......@@ -355,28 +363,28 @@ class SkeinHookBuilder():
self._app_log_collection_enabled = True
def name(self, name):
def name(self, name: str) -> 'SkeinHookBuilder':
self.app_spec_dict['name'] = name
return self
def queue(self, queue):
def queue(self, queue: str) -> 'SkeinHookBuilder':
self.app_spec_dict['queue'] = queue
return self
def script(self, script):
def script(self, script: str) -> 'SkeinHookBuilder':
self.app_spec_dict['master']['script'] = script
return self
def resources(self, resources):
def resources(self, resources: dict) -> 'SkeinHookBuilder':
self.app_spec_dict['master']['resources'] = resources
return self
def files(
self,
files: Union[str, list, dict],
append=False,
**file_properties
):
append: bool = False,
**file_properties: dict
) -> 'SkeinHookBuilder':
"""
Sets skein.Master files kwargs.
......@@ -424,7 +432,7 @@ class SkeinHookBuilder():
return self
def env(self, env_vars: dict, append: bool = False):
def env(self, env_vars: dict, append: bool = False) -> 'SkeinHookBuilder':
"""
Sets env vars for the YARN AppMaster.
......@@ -442,23 +450,23 @@ class SkeinHookBuilder():
self.app_spec_dict['master']['env'].update(env_vars)
return self
def master_log_level(self, log_level: str):
def master_log_level(self, log_level: str) -> 'SkeinHookBuilder':
self.app_spec_dict['master']['log_level'] = log_level
return self
def client_log_level(self, log_level: str):
def client_log_level(self, log_level: str) -> 'SkeinHookBuilder':
self.client_kwargs['log_level'] = log_level
return self
def app_log_collection_enabled(self, enabled: bool):
def app_log_collection_enabled(self, enabled: bool) -> 'SkeinHookBuilder':
self._app_log_collection_enabled = enabled
return self
def principal(self, principal: str):
def principal(self, principal: str) -> 'SkeinHookBuilder':
self.client_kwargs['principal'] = principal
return self
def keytab(self, keytab: str):
def keytab(self, keytab: str) -> 'SkeinHookBuilder':
self.client_kwargs['keytab'] = keytab
return self
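Because every annotated setter returns the builder, the calls chain. A hedged sketch exercising the same keys as the docstring example earlier in this diff (values are illustrative; constructing the builder with no arguments, and whatever finalizer it exposes, are assumptions, since neither is shown in this hunk):

    builder = (
        SkeinHookBuilder()
        .name('my_skein_app')
        .queue('production')
        .script('echo "hello from the YARN AppMaster"')
        .master_log_level('DEBUG')   # log level of the YARN AppMaster
        .client_log_level('DEBUG')   # log level of the local skein client
        .principal('user@domain')
        .keytab('/path/to/keytab')
        .app_log_collection_enabled(True)
    )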
......@@ -470,7 +478,7 @@ class SkeinHookBuilder():
)
YARN_APP_ID_REGEX = r'application_\d{13}_\d{4,}'
YARN_APP_ID_REGEX: re.Pattern = re.compile(r'application_\d{13}_\d{4,}')
"""
Regex used to match YARN application IDs out of logs.
"""
......@@ -521,7 +529,11 @@ def parse_file_source(
return (alias, file)
def yarn_logs_commands(yarn_app_ids: List[str], app_owner=None, sudo=True) -> List[str]:
def yarn_logs_commands(
yarn_app_ids: List[str],
app_owner: Optional[str] = None,
sudo: bool = True
) -> List[str]:
"""
Given YARN app ids, returns a list of yarn logs commands to run
to view aggregated YARN logs for those YARN apps.
......@@ -537,9 +549,9 @@ def yarn_logs_commands(yarn_app_ids: List[str], app_owner=None, sudo=True) -> Li
will be prefixed to the yarn logs command.
"""
sudo = f'sudo -u {app_owner} ' if app_owner and sudo else ''
sudo_cmd = f'sudo -u {app_owner} ' if app_owner and sudo else ''
app_owner = f'-appOwner {app_owner} ' if app_owner else ''
return [f'{sudo}yarn logs {app_owner}-applicationId {app_id}' for app_id in yarn_app_ids]
return [f'{sudo_cmd}yarn logs {app_owner}-applicationId {app_id}' for app_id in yarn_app_ids]
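Renaming the local to sudo_cmd avoids re-binding the boolean sudo parameter to a string, which the new typing checks would flag. A worked example of the resulting command strings (the application id and owner are illustrative):

    yarn_logs_commands(['application_1637058075222_419446'], app_owner='analytics', sudo=True)
    # -> ['sudo -u analytics yarn logs -appOwner analytics -applicationId application_1637058075222_419446']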
def find_yarn_app_ids(logs: str) -> List[str]:
......
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Union
import os
from airflow.providers.apache.spark.hooks.spark_submit \
......@@ -31,7 +31,7 @@ class SparkSubmitHook(AirflowSparkSubmitHook):
queue: Optional[str] = None,
spark_home: Optional[str] = None,
conn_id: Optional[str] = None, # overridden here to change default value
**kwargs,
**kwargs: Any,
):
"""
......@@ -80,16 +80,16 @@ class SparkSubmitHook(AirflowSparkSubmitHook):
kwargs to pass to parent Airflow SparkSubmitHook.
"""
self._application = application
self._master = master
self._deploy_mode = deploy_mode
self._queue = queue
self._application: str = application
self._master: Optional[str] = master
self._deploy_mode: Optional[str] = deploy_mode
self._queue: Optional[str] = queue
self._driver_cores: Optional[str] = None
if driver_cores:
self._driver_cores = str(driver_cores)
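A hedged construction sketch using only parameters visible in this file's hunks (all values are hypothetical; any parameters not shown here would go through **kwargs to the parent Airflow SparkSubmitHook):

    hook = SparkSubmitHook(
        application='hdfs:///path/to/job.py',
        master='yarn',
        deploy_mode='cluster',
        queue='production',
        driver_cores=2,  # normalized to the string '2' by the assignment above
    )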