Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Muniza
Language Models
Commits
e5f65fac
Commit
e5f65fac
authored
Aug 23, 2022
by
Muniza
Browse files
Add option to save predictions to disk
parent
e2a7f472
Changes
1
Hide whitespace changes
Inline
Side-by-side
language_models/train.py
View file @
e5f65fac
...
...
@@ -47,11 +47,8 @@ def train(X: pd.DataFrame, y: pd.Series) -> XGBClassifier:
return
model
def
calculate_performance
(
model
:
XGBClassifier
,
X
:
pd
.
DataFrame
,
y
:
pd
.
Series
)
->
Performance
:
logging
.
info
(
f
"Calculating model performance metrics:
{
len
(
X
)
}
examples"
)
y_pred
=
model
.
predict
(
X
)
def
calculate_performance
(
y
:
pd
.
Series
,
y_pred
:
pd
.
Series
)
->
Performance
:
logging
.
info
(
f
"Calculating model performance metrics:
{
len
(
y
)
}
examples"
)
tn
,
fp
,
fn
,
tp
=
confusion_matrix
(
y
,
y_pred
).
ravel
()
logging
.
info
(
f
"tn:
{
tn
}
, fp:
{
fp
}
, fn:
{
fn
}
, tp:
{
tp
}
"
)
return
Performance
(
...
...
@@ -83,10 +80,21 @@ def main() -> None:
required
=
True
,
help
=
"Directory to write the performance metrics for models"
,
)
parser
.
add_argument
(
"--predictions"
,
type
=
Path
,
required
=
False
,
help
=
(
"Directory to optionally write the predictions "
"used for calculating performance"
),
)
args
=
parser
.
parse_args
()
data_path
=
args
.
data
models_dir
=
args
.
models
performance_dir
=
args
.
performance
predictions_dir
=
args
.
predictions
data_df
=
pd
.
read_csv
(
data_path
,
sep
=
"
\t
"
,
compression
=
"gzip"
)
data_df
=
data_df
[
~
data_df
[
LABEL
].
isnull
()]
...
...
@@ -97,8 +105,9 @@ def main() -> None:
logging
.
info
(
f
"Current language:
{
wiki_db
}
"
)
language_df
,
other_languages_df
=
split_by_language
(
data_df
,
wiki_db
)
model
=
train
(
other_languages_df
[
FEATURES
],
other_languages_df
[
LABEL
])
language_df
[
"prediction"
]
=
model
.
predict
(
language_df
[
FEATURES
])
model_performance
=
calculate_performance
(
model
,
language_df
[
FEATURES
],
language_df
[
LABEL
]
language_df
[
LABEL
],
language_df
[
"prediction"
]
)
model_path
=
models_dir
/
f
"
{
wiki_db
}
.json"
...
...
@@ -110,6 +119,13 @@ def main() -> None:
with
open
(
performance_path
,
"w"
)
as
f
:
json
.
dump
(
asdict
(
model_performance
),
f
)
if
predictions_dir
:
predictions_path
=
predictions_dir
/
f
"
{
wiki_db
}
.tsv"
logging
.
info
(
f
"Writing predictions to
{
predictions_path
}
"
)
language_df
[[
"revision_id"
,
"wiki_db"
,
LABEL
,
"prediction"
]].
to_csv
(
predictions_path
,
index
=
False
,
sep
=
"
\t
"
)
if
__name__
==
"__main__"
:
main
()
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment