Commit e5f65fac authored by Muniza
Browse files

Add option to save predictions to disk

parent e2a7f472
......@@ -47,11 +47,8 @@ def train(X: pd.DataFrame, y: pd.Series) -> XGBClassifier:
return model
def calculate_performance(
model: XGBClassifier, X: pd.DataFrame, y: pd.Series
) -> Performance:
logging.info(f"Calculating model performance metrics: {len(X)} examples")
y_pred = model.predict(X)
def calculate_performance(y: pd.Series, y_pred: pd.Series) -> Performance:
logging.info(f"Calculating model performance metrics: {len(y)} examples")
tn, fp, fn, tp = confusion_matrix(y, y_pred).ravel()
logging.info(f"tn: {tn}, fp: {fp}, fn: {fn}, tp: {tp}")
return Performance(
......@@ -83,10 +80,21 @@ def main() -> None:
required=True,
help="Directory to write the performance metrics for models",
)
# Optional output location: predictions are only written when this is given.
parser.add_argument(
    "--predictions",
    type=Path,
    required=False,
    help=(
        # Adjacent literals are concatenated verbatim, so the first piece
        # must end with a space to keep the words apart in --help output.
        "Directory to optionally write the predictions "
        "used for calculating performance"
    ),
)
args = parser.parse_args()
data_path = args.data
models_dir = args.models
performance_dir = args.performance
predictions_dir = args.predictions
data_df = pd.read_csv(data_path, sep="\t", compression="gzip")
data_df = data_df[~data_df[LABEL].isnull()]
......@@ -97,8 +105,9 @@ def main() -> None:
logging.info(f"Current language: {wiki_db}")
language_df, other_languages_df = split_by_language(data_df, wiki_db)
model = train(other_languages_df[FEATURES], other_languages_df[LABEL])
language_df["prediction"] = model.predict(language_df[FEATURES])
model_performance = calculate_performance(
model, language_df[FEATURES], language_df[LABEL]
language_df[LABEL], language_df["prediction"]
)
model_path = models_dir / f"{wiki_db}.json"
......@@ -110,6 +119,13 @@ def main() -> None:
with open(performance_path, "w") as f:
json.dump(asdict(model_performance), f)
if predictions_dir:
predictions_path = predictions_dir / f"{wiki_db}.tsv"
logging.info(f"Writing predictions to {predictions_path}")
language_df[["revision_id", "wiki_db", LABEL, "prediction"]].to_csv(
predictions_path, index=False, sep="\t"
)
if __name__ == "__main__":
main()
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment