Commit 523a497c authored by Bmansurov

Rename spark_session to spark for consistency

parent 57a18cab
@@ -3,7 +3,7 @@ from datetime import datetime, timedelta
 import pyspark.sql.functions as F
 import wmfdata as wmf
-spark_session = wmf.spark.get_session(
+spark = wmf.spark.get_session(
     type='local',
     app_name="knowledge-gaps"
 )
@@ -17,7 +17,7 @@ today = datetime.today()
 ninety_days_earlier = today - timedelta(days=90)
 pageviews_df = article_features.extract_pageviews(
-    spark_session,
+    spark,
     ninety_days_earlier,
     today,
     projects,
@@ -26,12 +26,12 @@ pageviews_df = article_features.extract_pageviews(
 print(pageviews_df.head(3))
-def get_pages_df(spark_session, table='aikochou.pages_20220310'):
+def get_pages_df(spark, table='aikochou.pages_20220310'):
     query = f"SELECT * FROM {table}"
-    return spark_session.sql(query)
+    return spark.sql(query)
-pages_df = get_pages_df(spark_session)
+pages_df = get_pages_df(spark)
 print(pages_df.head(3))
 # Here we can add content gaps to pages_df, etc.
......
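Note: after this rename the driver script reads the way most PySpark code does, with "spark" as the session handle. A minimal sketch of the same flow, assuming only the wmfdata calls already shown in this diff (the SELECT 1 query is illustrative, not from the commit):

import wmfdata as wmf

# Same session setup as above; "spark" is now the conventional name
# for a SparkSession, matching the pyspark shell and documentation.
spark = wmf.spark.get_session(
    type='local',
    app_name="knowledge-gaps"
)
# Any SparkSession method is called through the renamed handle.
print(spark.sql("SELECT 1 AS x").head(1))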
@@ -6,14 +6,14 @@ import pyspark.sql.functions as F
 import wmfdata as wmf  # type: ignore
-def get_pages_df(spark_session, table='bmansurov.wikipedia_pages_2022_01'):
+def get_pages_df(spark, table='bmansurov.wikipedia_pages_2022_01'):
     "Temporary function for easy pages retrieval during development."
     query = f"SELECT * FROM {table}"
-    return spark_session.sql(query)
+    return spark.sql(query)
 def main(args):
-    spark_session = wmf.spark.get_session(
+    spark = wmf.spark.get_session(
         type='yarn-large',
         app_name="knowledge-gaps"
     )
@@ -21,19 +21,20 @@ def main(args):
     from knowledge_gaps import article_features, func
     pageviews_df = article_features.extract_pageviews(
-        spark_session,
+        spark,
         args.pageviews_start_date,
         args.pageviews_end_date,
         args.projects,
         args.pageviews_table
     )
     if args.wikipedia_pages_table:
-        pages_df = get_pages_df(spark_session)
+        pages_df = get_pages_df(spark)
     else:
-        pages_df = func.wikipedia_pages_df(spark_session,
+        pages_df = func.wikipedia_pages_df(spark,
                                            args.mediawiki_snapshot,
                                            args.wikidata_snapshot,
                                            args.projects)
     # TODO: Move this to a separate function.
     pages_df = (pages_df.alias('p')
                 .join(pageviews_df.alias('pv'),
......
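The join that the TODO above introduces is truncated in this diff. For orientation only, a hedged sketch of the alias-join pattern being set up; the join keys and join type here are assumptions for illustration, not the commit's actual condition:

import pyspark.sql.functions as F

# Hypothetical condition: match pages to pageview counts per wiki and page.
# The real keys are cut off above, so treat these as placeholders.
joined_df = (pages_df.alias('p')
             .join(pageviews_df.alias('pv'),
                   on=(F.col('p.page_id') == F.col('pv.page_id'))
                      & (F.col('p.wiki_db') == F.col('pv.wiki_db')),
                   how='left'))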
@@ -106,7 +106,7 @@ def wikitext_df(spark, mediawiki_snapshot, projects=None):
     return df
-def extract_pageviews(spark_session, start_date, end_date,
+def extract_pageviews(spark, start_date, end_date,
                       projects=None, table="wmf.pageview_hourly"):
     """Extract the number of pageviews between START_DATE and END_DATE
     for PROJECTS. Supply a smaller TABLE for faster queries during
@@ -114,7 +114,7 @@ def extract_pageviews(spark_session, start_date, end_date,
     Parameters
     ----------
-    spark_session : SparkSession
+    spark : SparkSession
     start_date : datetime.datetime
         Start date for counting pageviews.
@@ -159,7 +159,7 @@ def extract_pageviews(spark_session, start_date, end_date,
     GROUP BY
         project, page_title, page_id, year, month, day
     """
-    df = spark_session.sql(query)
+    df = spark.sql(query)
     if projects:
         df = df.where(F.col('wiki_db').isin(projects))
     return df
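With the renamed parameter, a call site now reads as follows. This usage sketch mirrors the driver script earlier in the commit and the docstring above; the 'enwiki' project filter is an assumed wiki_db value for illustration, not taken from this diff:

from datetime import datetime, timedelta
import wmfdata as wmf
from knowledge_gaps import article_features

spark = wmf.spark.get_session(type='yarn-large', app_name="knowledge-gaps")
end_date = datetime.today()
start_date = end_date - timedelta(days=90)
pageviews_df = article_features.extract_pageviews(
    spark,                       # renamed first parameter
    start_date,
    end_date,
    projects=['enwiki'],         # assumed value, filters on wiki_db
    table="wmf.pageview_hourly"
)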