Commit ad50a6a5 authored by Gmodena

Add image matching tasks

parent 3798969a
# image-matching
Training and dataset publishing pipeline for the [Image Suggestion](https://phabricator.wikimedia.org/project/profile/5171/) service.
Airflow DAGs orchestrate model training and ETL.
## Content
- `dags` contains the Airflow DAGs for this workflow.
- `spark` contains Spark-based data processing tasks.
- `sbin` contains Python scripts for data processing tasks.
- `sql` contains SQL/HQL-based data processing tasks.
## Test
```
python3 -m venv venv
source venv/bin/activate
pip install -r requirements-test.txt
PYTHONPATH=spark python3 -m test
```
import pytest
from pyspark.sql import DataFrame
from pyspark.sql import SparkSession

from spark.transform import RawDataset


@pytest.fixture(scope="session")
def raw_data(spark_session):
    return spark_session.createDataFrame(
        [
            (
                "0",
                "Q1234",
                "44444",
                "Some page with suggestions",
                '[{"image": "image1.jpg", "rating": 2.0, "note": "image was found in the following Wikis: ruwiki"}]',
                None,
                "arwiki",
                "2020-12",
            ),
            (
                "1",
                "Q56789",
                "55555",
                "Some page with no suggestion",
                None,
                None,
                "arwiki",
                "2020-12",
            ),
            (
                "2",
                "Q66666",
                "523523",
                "Some page with 3 suggestions",
                "["
                '{"image": "image2.jpg", "rating": 2.0, "note": "image was found in the following Wikis: ruwiki,arwiki,enwiki"}, '
                '{"image": "image3.jpg", "rating": 1, "note": "image was in the Wikidata item"}, '
                '{"image": "image4.jpg", "rating": 3.0, "note": "image was found in the Commons category linked in '
                'the Wikidata item"} '
                "]",
                '{"entity-type":"item","numeric-id":577,"id":"Q577"}',
                "enwiki",
                "2020-12",
            ),
        ],
        RawDataset.schema,
    )


@pytest.fixture(scope="session")
def wikis(spark_session: SparkSession) -> DataFrame:
    return spark_session.createDataFrame(
        [
            ["image was found in the following Wikis: ruwiki, itwiki,enwiki"],
            ["image was found in the following Wikis: "],
            [None],
        ],
        ["note"],
    )


def assert_shallow_equals(ddf: DataFrame, other_ddf: DataFrame) -> None:
    assert len(set(ddf.columns).difference(set(other_ddf.columns))) == 0
    assert ddf.subtract(other_ddf).rdd.isEmpty()
    assert other_ddf.subtract(ddf).rdd.isEmpty()
pytest==6.2.2
pytest-spark==0.6.0
pytest-cov==2.10.1
flake8==3.8.4
import argparse
import os

import papermill as pm

# Todo: find a more accurate way to get dblist.
all_languages = ['enwiki', 'arwiki', 'kowiki', 'cswiki', 'viwiki', 'frwiki', 'fawiki', 'ptwiki',
                 'ruwiki', 'trwiki', 'plwiki', 'hewiki', 'svwiki', 'ukwiki', 'huwiki', 'hywiki',
                 'srwiki', 'euwiki', 'arzwiki', 'cebwiki', 'dewiki', 'bnwiki']


class AlgoRunner(object):
    def __init__(self, snapshot, languages, output_dir):
        """
        :param str snapshot: Snapshot date
        :param str languages: Comma-separated list of languages to run the algorithm against
        :param str output_dir: Directory to place output .ipynb and .tsv files
        """
        self.snapshot = snapshot
        self.languages = languages.split(',')
        self.output_dir = output_dir
        print(f'Initializing with snapshot={self.snapshot} languages={self.languages} output_dir={self.output_dir}')

    def run(self):
        if len(self.languages) == 1 and self.languages[0] == 'All':
            self.execute_papermill(all_languages)
        else:
            self.execute_papermill(self.languages)

    def execute_papermill(self, languages):
        """
        Executes the Jupyter notebook once per language.

        :param list languages: List of languages to run the algorithm against
        """
        print(f'Starting to execute the algorithm for the following languages: {languages}')
        if not os.path.exists(self.output_dir):
            os.makedirs(self.output_dir)
        for language in languages:
            pm.execute_notebook(
                'algorithm.ipynb',
                self.output_dir + '/' + language + '_' + self.snapshot + '.ipynb',
                parameters=dict(language=language, snapshot=self.snapshot, output_dir=self.output_dir)
            )


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Executes jupyter notebook with parameters. ' +
                                     'Ex: python3 algorunner.py 2020-12-28 hywiki Output')
    parser.add_argument('snapshot', help='Full snapshot date. Ex: 2020-12-28')
    parser.add_argument('languages', nargs='?', default='All',
                        help='Language(s) to execute. If more than one, separate with commas. Ex: enwiki,kowiki,arwiki')
    parser.add_argument('output_dir', nargs='?', default='Output',
                        help='Directory to place output .ipynb and .tsv files. Defaults to: Output')
    args = parser.parse_args()

    runner = AlgoRunner(args.snapshot, args.languages, args.output_dir)
    runner.run()
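For reference, a minimal programmatic usage sketch; the module name `algorunner` and the sample arguments are assumed from the CLI example in the parser description, not from the repo itself:

```
# Sketch only: equivalent to `python3 algorunner.py 2020-12-28 enwiki,kowiki Output`.
from algorunner import AlgoRunner

runner = AlgoRunner("2020-12-28", "enwiki,kowiki", "Output")
# Executes algorithm.ipynb once per wiki via papermill, writing
# Output/enwiki_2020-12-28.ipynb and Output/kowiki_2020-12-28.ipynb.
runner.run()
```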
from enum import Enum


class InstancesToFilter(Enum):
    YEAR = "Q577"
    CALENDARYEAR = "Q3186692"
    DISAMBIGUATION = "Q4167410"
    LIST = "Q13406463"

    @classmethod
    def list(cls):
        return [p.value for p in cls]
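A short sketch of what the helper returns; the transform step further down uses this list to flag non-article pages (years, calendar years, disambiguation and list pages):

```
# The QIDs below simply restate the enum values defined above.
from instances_to_filter import InstancesToFilter

assert InstancesToFilter.YEAR.value == "Q577"
assert InstancesToFilter.list() == ["Q577", "Q3186692", "Q4167410", "Q13406463"]
```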
import argparse

from pyspark.sql import SparkSession
from pyspark.sql import functions as F

from schema import CsvDataset

spark = SparkSession.builder.getOrCreate()


def parse_args():
    parser = argparse.ArgumentParser(
        description="Transform raw algo output to production datasets"
    )
    parser.add_argument("--snapshot", help="Monthly snapshot date (YYYY-MM)")
    parser.add_argument("--wiki", help="Wiki name")
    parser.add_argument("--source", help="Source dataset path")
    parser.add_argument("--destination", help="Destination path")
    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()
    snapshot = args.snapshot
    source = args.source
    destination = args.destination
    wiki = args.wiki

    csv_df = (
        (
            spark.read.options(delimiter="\t", header=False, escape='"')
            .schema(CsvDataset.schema)
            .csv(source)
        )
        .withColumn("wiki_db", F.lit(wiki))
        .withColumn("snapshot", F.lit(snapshot))
    )

    csv_df.coalesce(1).write.partitionBy("wiki_db", "snapshot").mode(
        "overwrite"
    ).parquet(
        destination
    )  # Requires dynamic partitioning enabled

    spark.stop()
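The trailing comment assumes dynamic partition overwrite is enabled on the Spark session; a hedged sketch of that setting (not part of this script, and dependent on how the job is submitted) is:

```
# Assumed session config for the "Requires dynamic partitioning enabled" note:
# only the (wiki_db, snapshot) partitions being written get overwritten,
# rather than truncating the whole destination path.
from pyspark.sql import SparkSession

spark = (
    SparkSession.builder
    .config("spark.sql.sources.partitionOverwriteMode", "dynamic")
    .getOrCreate()
)
```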
from pyspark.sql.types import StructType, StringType


class CsvDataset:
    schema = (
        StructType()
        .add("pandas_idx", StringType(), True)
        .add("item_id", StringType(), True)
        .add("page_id", StringType(), True)
        .add("page_title", StringType(), True)
        .add("top_candidates", StringType(), True)
        .add("instance_of", StringType(), True)
    )


class RawDataset(CsvDataset):
    # StructType.add() mutates in place, so extend a copy of the base fields
    # to avoid appending wiki_db/snapshot to CsvDataset.schema as well.
    schema = (
        StructType(list(CsvDataset.schema.fields))
        .add("wiki_db", StringType(), True)
        .add("snapshot", StringType(), True)
    )

    top_candidates_schema = "array<struct<image:string,note:string,rating:double>>"
    instance_of_schema = "struct<`entity-type`:string,`numeric-id`:bigint,id:string>"
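To illustrate how the string-typed `top_candidates` column relates to the DDL strings above, here is a small sketch; the sample JSON mirrors the test fixture rather than real production data:

```
# top_candidates is stored as a JSON string and parsed with the DDL schema
# declared on RawDataset (an array of <image, note, rating> structs).
from pyspark.sql import SparkSession, functions as F
from schema import RawDataset

spark = SparkSession.builder.getOrCreate()
sample = '[{"image": "image1.jpg", "rating": 2.0, "note": "image was found in the following Wikis: ruwiki"}]'
df = spark.createDataFrame([(sample,)], ["top_candidates"])
df.select(
    F.from_json("top_candidates", RawDataset.top_candidates_schema).alias("candidates")
).show(truncate=False)
```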
import argparse
import datetime
import uuid

from pyspark.sql import SparkSession
from pyspark.sql import Column, DataFrame
from pyspark.sql import functions as F
from pyspark.sql.types import IntegerType

from schema import RawDataset
from instances_to_filter import InstancesToFilter

spark = SparkSession.builder.getOrCreate()


class ImageRecommendation:
    confidence_rating: Column = (
        F.when(F.col("rating").cast(IntegerType()) == 1, F.lit("high"))
        .when(F.col("rating").cast(IntegerType()) == 2, F.lit("medium"))
        .when(F.col("rating").cast(IntegerType()) == 3, F.lit("low"))
    )
    source: Column = (
        F.when(
            F.col("note").like(r"image was in the Wikidata item%"), F.lit("wikidata")
        )
        .when(
            F.col("note").like(r"image was found in the following Wikis:%"),
            F.lit("wikipedia"),
        )
        .when(
            F.col("note").like(r"image was found in the Commons category%"),
            F.lit("commons"),
        )
    )
    instance_of: Column = F.when(F.col("instance_of").isNull(), F.lit(None)).otherwise(
        F.from_json("instance_of", RawDataset.instance_of_schema).getItem("id")
    )
    found_on: Column = F.when(F.col("note").isNull(), F.lit(None)).otherwise(
        F.split(
            F.regexp_replace(
                F.regexp_extract(F.col("note"), r"Wikis:\s+(.*)$", 1), r"\s+", ""
            ),
            ",",
        )
    )
    is_article_page: Column = (
        F.when(
            F.col("instance_of").isin(InstancesToFilter.list()),
            F.lit(False)
        )
        .otherwise(True)
    )

    def __init__(self, dataFrame: DataFrame):
        self.dataFrame = dataFrame
        if not dataFrame.schema == RawDataset.schema:
            raise AttributeError(
                f"Invalid schema. Expected '{RawDataset.schema}'. Got '{dataFrame.schema}'"
            )

    def transform(self) -> DataFrame:
        with_recommendations = (
            self.dataFrame.where(~F.col("top_candidates").isNull())
            .withColumn(
                "data",
                F.explode(
                    F.from_json("top_candidates", RawDataset.top_candidates_schema)
                ),
            )
            .select("*", "data.image", "data.rating", "data.note")
            .withColumnRenamed("wiki_db", "wiki")
            .withColumnRenamed("image", "image_id")
            .withColumn("confidence_rating", self.confidence_rating)
            .withColumn("source", self.source)
            .withColumn("found_on", self.found_on)
            .select(
                "wiki",
                "page_id",
                "page_title",
                "image_id",
                "confidence_rating",
                "source",
                "instance_of",
                "found_on",
            )
        )
        without_recommendations = (
            self.dataFrame.where(F.col("top_candidates").isNull())
            .withColumnRenamed("wiki_db", "wiki")
            .withColumn("image_id", F.lit(None))
            .withColumn("confidence_rating", F.lit(None))
            .withColumn("source", F.lit(None))
            .withColumn("found_on", F.lit(None))
            .select(
                "wiki",
                "page_id",
                "page_title",
                "image_id",
                "confidence_rating",
                "source",
                "instance_of",
                "found_on",
            )
        )
        return (
            with_recommendations.union(without_recommendations)
            .withColumn("instance_of", self.instance_of)
            .withColumn("is_article_page", self.is_article_page)
        )


def parse_args():
    parser = argparse.ArgumentParser(
        description="Transform raw algo output to production datasets"
    )
    parser.add_argument("--snapshot", help="Monthly snapshot date (YYYY-MM)")
    parser.add_argument("--source", help="Source dataset path")
    parser.add_argument("--destination", help="Destination path")
    parser.add_argument(
        "--dataset-id",
        help="Production dataset identifier (optional)",
        default=str(uuid.uuid4()),
        dest="dataset_id",
    )
    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()
    snapshot = args.snapshot
    source = args.source
    destination = args.destination
    dataset_id = args.dataset_id
    num_partitions = 1

    df = spark.read.schema(RawDataset.schema).parquet(source)
    insertion_ts = datetime.datetime.now().timestamp()
    (
        ImageRecommendation(df)
        .transform()
        .withColumn("dataset_id", F.lit(dataset_id))
        .withColumn("insertion_ts", F.lit(insertion_ts))
        .withColumn("snapshot", F.lit(snapshot))
        .sort(F.desc("page_title"))
        .coalesce(num_partitions)
        .write.partitionBy("wiki", "snapshot")
        .mode("overwrite")  # Requires dynamic partitioning enabled
        .parquet(destination)
    )
    spark.stop()
-- This script is used to export production datasets,
-- in a format consumable by the APIs.
--
-- Run with:
-- hive -hiveconf output_path=<output_path> -hiveconf username=${username} -hiveconf wiki=${wiki} -hiveconf snapshot=${monthly_snapshot} -f export_prod_data.hql
--
--
-- Format
-- * Include header: yes
-- * Field delimiter: "\t"
-- * Null value for missing recommendations
-- (image_id, confidence_rating, source fields): ""
-- * found_on: list of wikis delimited by ','
--
-- Changelog:
-- * 2021-03-31: creation.
--
--
use ${hiveconf:username};
set hivevar:null_value="";
set hivevar:found_on_delimiter=",";
set hive.cli.print.header=true;
insert overwrite local directory '${hiveconf:output_path}'
row format delimited fields terminated by '\t'
select page_id,
page_title,
nvl(image_id, ${null_value}) as image_id,
nvl(confidence_rating, ${null_value}) as confidence_rating,
nvl(source, ${null_value}) as source,
dataset_id,
insertion_ts,
wiki,
concat_ws(${found_on_delimiter}, found_on) as found_on
from imagerec_prod
where wiki = '${hiveconf:wiki}' and snapshot='${hiveconf:snapshot}' and is_article_page=true and image_id is not null;
-- This script is used to export production datasets,
-- in a format consumable by the APIs.
--
-- Data is collected locally, in TSV format, under <output_path>.
--
-- Run with:
-- hive -hiveconf output_path=<output_path> -hiveconf username=${username} -hiveconf wiki=${wiki} -hiveconf snapshot=${monthly_snapshot} -f export_prod_data.hql
--
--
-- Format
-- * Include header: yes
-- * Field delimiter: "\t"
-- * Null value for missing recommendations
-- (image_id, confidence_rating, source fields): ""
-- * found_on: list of wikis delimited by ','
--
-- Changelog:
-- * 2021-03-08: schema and format freeze.
-- * 2021-03-25: append found_on column
-- * 2021-03-25: add is_article_page to where clause
--
use ${hiveconf:username};
set hivevar:null_value="";
set hivevar:found_on_delimiter=",";
set hive.cli.print.header=true;
insert overwrite local directory '${hiveconf:output_path}'
row format delimited fields terminated by '\t'
select page_id,
page_title,
nvl(image_id, ${null_value}) as image_id,
nvl(confidence_rating, ${null_value}) as confidence_rating,
nvl(source, ${null_value}) as source,
dataset_id,
insertion_ts,
wiki,
concat_ws(${found_on_delimiter}, found_on) as found_on
from imagerec_prod
where wiki = '${hiveconf:wiki}' and snapshot='${hiveconf:snapshot}' and is_article_page=true;
-- DDL to create an external table that exposes samples of the
-- production dataset.
-- The default HDFS location and Hive database are relative to a developer's
-- username. Example: hdfs://analytics-hadoop/user/clarakosi/imagerec/data.
--
-- The dataset will be available at https://superset.wikimedia.org/superset/sqllab via the
-- `presto_analytics` database.
--
-- Execution
-- hive -hiveconf username=<username> -f external_imagerec.hql
USE ${hiveconf:username};
CREATE EXTERNAL TABLE IF NOT EXISTS `imagerec` (
`pandas_idx` string,
`item_id` string,
`page_id` string,
`page_title` string,
`top_candidates` string,
`instance_of` string)
PARTITIONED BY (
`wiki_db` string,
`snapshot` string)
STORED AS PARQUET
LOCATION
'hdfs://analytics-hadoop/user/${hiveconf:username}/imagerec';
-- Update partition metadata
MSCK REPAIR TABLE `imagerec`;
-- DDL to create an external table that exposes samples of the
-- production dataset.
-- The default HDFS location and Hive database are relative to a developer's
-- username. Example: hdfs://analytics-hadoop/user/gmodena/imagerec_prod/data.
--
-- The dataset will be available at https://superset.wikimedia.org/superset/sqllab via the
-- `presto_analytics` database.
--
-- Execution
-- hive -hiveconf username=<username> -f external_imagerec_prod.hql
USE ${hiveconf:username};
CREATE EXTERNAL TABLE IF NOT EXISTS `imagerec_prod`(
`page_id` string,
`page_title` string,
`image_id` string,
`confidence_rating` string,
`source` string,
`instance_of` string,
`is_article_page` boolean,
`dataset_id` string,
`insertion_ts` double,
`found_on` array<string>)
PARTITIONED BY (`wiki` string, `snapshot` string)
STORED AS PARQUET
LOCATION
'hdfs://analytics-hadoop/user/${hiveconf:username}/imagerec_prod';
-- Update partition metadata
MSCK REPAIR TABLE `imagerec_prod`;
from pyspark.sql import functions as F
from pyspark.sql import Row

from spark.transform import ImageRecommendation
from conftest import assert_shallow_equals


def test_etl(raw_data):
    assert raw_data.count() == 3

    ddf = ImageRecommendation(raw_data).transform()

    assert (
        len(
            set(ddf.columns).difference(
                {
                    "wiki",
                    "page_id",
                    "page_title",
                    "image_id",
                    "confidence_rating",
                    "instance_of",
                    "is_article_page",
                    "source",
                    "found_on",
                }
            )
        )
        == 0
    )

    expected_num_records = 5
    assert ddf.count() == expected_num_records

    expected_confidence = {"wikipedia": "medium", "commons": "low", "wikidata": "high"}
    for source in expected_confidence:
        rows = (
            ddf.where(F.col("source") == source)
            .select("confidence_rating")
            .distinct()
            .collect()
        )
        assert len(rows) == 1
        assert rows[0]["confidence_rating"] == expected_confidence[source]

    # Unillustrated articles with no recommendation have no confidence rating
    assert (
        ddf.where(F.col("source").isNull())
        .where(F.col("confidence_rating").isNotNull())
        .count()
        == 0
    )

    # Instance_of json is correctly parsed
    expected_instance_of = "Q577"
    rows = (
        ddf.where(F.col("instance_of").isNotNull())
        .select("instance_of")
        .distinct()
        .collect()
    )
    assert len(rows) == 1
    assert rows[0]["instance_of"] == expected_instance_of

    # Pages are correctly marked for filtering
    expected_page_id = "523523"
    filter_out_rows = (
        ddf.where(~F.col("is_article_page"))