Add documentation for AMRI and other LP eval metrics (#1075)
*Issue #, if available:*

*Description of changes:*

* Add rst docs for AMRI and other LP eval metrics

By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice.
thvasilo authored Oct 30, 2024
1 parent c44923a commit ee16e74
Showing 6 changed files with 117 additions and 12 deletions.
77 changes: 76 additions & 1 deletion docs/source/advanced/link-prediction.rst
@@ -269,7 +269,8 @@ In general, GraphStorm covers the following cases:
- **Case 2** ``num_train_hard_negatives`` is smaller than ``num_negative_edges``. GraphStorm will randomly sample ``num_train_hard_negatives`` hard negative nodes from the hard negative set and then randomly sample ``num_negative_edges - num_train_hard_negatives`` negative nodes.
- **Case 3** GraphStorm supports cases where some edges do not have enough hard negatives provided by the user. For example, the expected ``num_train_hard_negatives`` is 10, but an edge only has 5 hard negatives. In such cases, GraphStorm will use all the hard negatives first and then randomly sample negative nodes to fulfill the requirement of ``num_train_hard_negatives``. Then GraphStorm falls back to **Case 1** or **Case 2**.

**Preparing graph data for hard negative sampling**
Preparing graph data for hard negative sampling
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

The gconstruct pipeline of GraphStorm supports loading hard negative data from the raw input.
Hard destination negatives can be defined through the ``edge_dst_hard_negative`` transformation.
@@ -328,3 +329,77 @@ For example, the file storing hard negatives should look like the following:
"src_100"| "dst_41"| "dst0;dst_2"
GraphStorm will automatically translate the Raw Node IDs of hard negatives into Partition Node IDs in a DistDGL graph.

.. _link-prediction-evaluation-metrics:

Link Prediction Metrics
-----------------------

GraphStorm supports several metrics for link prediction, to give a well-rounded
view of model performance.
In general, link prediction evaluation works by constructing a set of negative
edges with one of the sampling methods described above and adding one positive
edge to that set, which we refer to as the `candidate set`. The model
assigns a score to each edge in the candidate set, and ideally the positive edge
is ranked at the top position when the edges are sorted by score.

We define the set of rankings produced during evaluation as :math:`\mathcal{I}`, with one ranking
per candidate set, and the number of rankings as :math:`|\mathcal{I}|`. We refer to the rank of a
positive edge within its candidate list as :math:`r`.

Mean Reciprocal Rank (MRR)
^^^^^^^^^^^^^^^^^^^^^^^^^^

Mean Reciprocal Rank (MRR) is a metric commonly used in link prediction evaluation
that captures how highly the model ranks the correct edge among a list of
candidate edges. It is defined as:

.. math::
\text{MRR} = \frac{1}{| \mathcal{I} |} \sum_{r \in \mathcal{I}}{\frac{1}{r}}

where :math:`\mathcal{I}` is the set of rankings, and :math:`r` is the rank of the positive edge
within its candidate set, as determined by the scores the model assigns to the edges.

The ideal MRR is 1.0, meaning that the positive edge is ranked first in every
candidate list. Because the positive edge is always included in the ranking, MRR
cannot take the value 0.0, so its range is :math:`(0, 1]`. MRR values are influenced
by the size of the candidate lists, so MRR should only be used to compare models
when the number of negative edges per positive edge is the same.
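
As a minimal illustration of the formula (not GraphStorm's internal implementation), MRR can be
computed from the ranks of the positive edges as follows:

.. code-block:: python

    # Rank of the positive edge in each candidate set (1 = top position).
    # The values below are made up for illustration.
    ranks = [1, 3, 2, 1, 10]

    # MRR is the mean of the reciprocal ranks.
    mrr = sum(1.0 / r for r in ranks) / len(ranks)
    print(f"MRR: {mrr:.3f}")  # ~0.587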

Hits@k
^^^^^^

The ``Hits@k`` metric measures the fraction of positive edges that the model ranks
within the top :math:`k` positions of the sorted score list:

.. math::
    \text{Hits@k} = \frac{|\{ r \in \mathcal{I} \mid r \leq k \}|}{|\mathcal{I}|}

This metric is easy to interpret, but it ignores any position beyond the top
:math:`k`, so it does not provide the holistic view needed for cross-model
comparison, and it is also sensitive to the number of negatives in the candidate set.
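
As a minimal illustration (again, not GraphStorm's internal implementation), ``Hits@k`` can be
computed from the same list of ranks:

.. code-block:: python

    # Rank of the positive edge in each candidate set (1 = top position).
    # The values below are made up for illustration.
    ranks = [1, 3, 2, 1, 10]
    k = 3

    # Hits@k is the fraction of positive edges ranked within the top k.
    hits_at_k = sum(1 for r in ranks if r <= k) / len(ranks)
    print(f"Hits@{k}: {hits_at_k:.2f}")  # 0.80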


Adjusted Mean Rank Index (AMRI)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

AMRI was proposed in the paper
`On the Ambiguity of Rank-Based Evaluation of EA or LP Methods <https://arxiv.org/abs/2002.06914>`_
as a metric that allows cross-model comparison by taking the entire score list into account,
while not being sensitive to the chosen number of negative edges per positive edge. It is defined as:


.. math::
\text{AMRI} = 1 - \frac{\text{MR}-1}{\mathbb{E}[\text{MR}-1]}

where :math:`\text{MR}` is the mean rank of the positive edges, and :math:`\mathbb{E}[\text{MR}-1]`
is the expected value of :math:`\text{MR}-1` under a model that assigns random scores, which is
used to adjust for chance. AMRI values lie in the :math:`[-1, 1]` range, where 1 corresponds to
optimal performance, in which every positive edge is ranked first. A value of 0 indicates
performance similar to that of a model assigning random scores, or the same score to every
candidate. The value is negative if the model performs worse than the all-equal-score baseline.
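
As a rough sketch of the computation (assuming, as in the paper, that a random-scoring model places
the positive edge uniformly within a candidate list of size :math:`n`, so its expected rank is
:math:`(n + 1) / 2`), AMRI could be computed from the ranks and candidate-set sizes as follows;
this is an illustration, not GraphStorm's internal implementation:

.. code-block:: python

    # Rank of the positive edge in each candidate set (1 = top position) and the
    # size of each candidate set (1 positive + negatives). Values are made up.
    ranks = [1, 3, 2, 1, 10]
    candidate_sizes = [101, 101, 101, 101, 101]

    mean_rank_minus_1 = sum(r - 1 for r in ranks) / len(ranks)
    # Under random scoring the expected rank in a list of size n is (n + 1) / 2,
    # so the expected value of (rank - 1) is (n - 1) / 2.
    expected_mr_minus_1 = sum((n - 1) / 2 for n in candidate_sizes) / len(candidate_sizes)

    amri = 1 - mean_rank_minus_1 / expected_mr_minus_1
    print(f"AMRI: {amri:.3f}")  # ~0.952, close to the optimal value of 1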
19 changes: 18 additions & 1 deletion docs/source/cli/model-training-inference/configuration-run.rst
@@ -343,7 +343,24 @@ General Configurations
- Default value: This parameter must be provided by the user.

.. _eval_metrics:
- **eval_metric**: Evaluation metric used during evaluation. The input can be a string specifying the evaluation metric to report or a list of strings specifying a list of evaluation metrics to report. The first evaluation metric is treated as the major metric and is used to choose the best trained model. The supported evaluation metrics of classification tasks include ``accuracy``, ``precision_recall``, ``roc_auc``, ``f1_score``, ``per_class_f1_score``, ``hit_at_k``. To be noted, ``hit_at_k`` only works with binary classification tasks. The ``k`` of ``hit_at_k`` can be any positive integer, for example ``hit_at_10`` or ``hit_at_100``. The term ``hit_at_k`` refers to the number of true positives among the top ``k`` predictions with the highest confidence scores. The supported evaluation metrics of regression tasks include ``rmse``, ``mse`` and ``mae``. The supported evaluation metrics of link prediction tasks include ``mrr`` and ``hit_at_k``.
- **eval_metric**: The evaluation metric or metrics to report during evaluation. The input can be a
string specifying a single evaluation metric or a list of strings specifying several evaluation
metrics. The first evaluation metric in the list is treated as the primary metric and is used to
choose the best trained model and to trigger early stopping. Each learning task supports different evaluation metrics:

- The supported evaluation metrics of classification tasks include ``accuracy``,
``precision_recall``, ``roc_auc``, ``f1_score``, ``per_class_f1_score``, ``hit_at_k``. Note that
``hit_at_k`` only works with binary classification tasks.

- The ``k`` of ``hit_at_k`` can be any positive integer, for example ``hit_at_10`` or
``hit_at_100``. The term ``hit_at_k`` refers to the number of true positives among the top ``k``
predictions with the highest confidence scores.
- The supported evaluation metrics of regression tasks include ``rmse``, ``mse`` and ``mae``.
- The supported evaluation metrics of link prediction tasks include ``mrr``, ``amri`` and
``hit_at_k``. MRR refers to the Mean Reciprocal Rank, with values between 0 (worst) and 1
(best), and AMRI refers to the Adjusted Mean Rank Index, with values ranging from -1 (worst) to 1
(best). An AMRI value of 0 is equivalent to random guessing or assigning the same score to all
edges in the candidate set. For more details on these metrics see :ref:`link-prediction-evaluation-metrics`.
An example configuration for link prediction is shown below.
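
For example, for a link prediction task, several metrics can be requested at once (a hypothetical
configuration; the first metric listed, ``mrr``, would be used to select the best model):

| ``eval_metric:``
| ``- mrr``
| ``- amri``
| ``- hit_at_10``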

- Yaml: ``eval_metric:``
| ``- accuracy``
2 changes: 2 additions & 0 deletions python/graphstorm/eval/evaluator.py
@@ -1018,6 +1018,8 @@ def compute_score(
# compute ranking value for each metric
metrics: Dict[str, th.Tensor] = {}
for metric in self.metric_list:
# NOTE: If other metrics need candidate list sizes, add them here.
# Avoid adding the sizes twice, to prevent possible errors.
if metric == "amri":
assert candidate_sizes, \
f"candidate_sizes needs to have a value for AMRI, got {candidate_sizes=}."
10 changes: 6 additions & 4 deletions python/graphstorm/run/gsgnn_lp/lp_infer_gnn.py
Expand Up @@ -33,7 +33,7 @@
get_lm_ntypes,
use_wholegraph,
)
from graphstorm.eval.eval_func import SUPPORTED_HIT_AT_METRICS
from graphstorm.eval.eval_func import SUPPORTED_HIT_AT_METRICS, SUPPORTED_LINK_PREDICTION_METRICS

def main(config_args):
""" main function
@@ -56,10 +56,12 @@ def main(config_args):
model_layer_to_load=config.restore_model_layers)
infer = GSgnnLinkPredictionInferrer(model)
infer.setup_device(device=get_device())
assert all((x.startswith(SUPPORTED_HIT_AT_METRICS) or x == 'mrr') for x in
config.eval_metric), (
assert all((x.startswith(SUPPORTED_HIT_AT_METRICS)
or x in SUPPORTED_LINK_PREDICTION_METRICS)
for x in config.eval_metric), (
"Invalid LP evaluation metrics. "
"GraphStorm only supports MRR and Hit@K metrics for link prediction.")
f"GraphStorm only supports {SUPPORTED_LINK_PREDICTION_METRICS} "
"and Hit@K metrics for link prediction.")
if not config.no_validation:
infer_idxs = infer_data.get_edge_test_set(config.eval_etype)
infer.setup_evaluator(GSgnnLPEvaluator(
10 changes: 6 additions & 4 deletions python/graphstorm/run/gsgnn_lp/lp_infer_lm.py
Expand Up @@ -31,7 +31,7 @@
from graphstorm.dataloading import BUILTIN_LP_UNIFORM_NEG_SAMPLER
from graphstorm.dataloading import BUILTIN_LP_JOINT_NEG_SAMPLER
from graphstorm.utils import get_device
from graphstorm.eval.eval_func import SUPPORTED_HIT_AT_METRICS
from graphstorm.eval.eval_func import SUPPORTED_HIT_AT_METRICS, SUPPORTED_LINK_PREDICTION_METRICS

def main(config_args):
""" main function
@@ -50,10 +50,12 @@ def main(config_args):
model_layer_to_load=config.restore_model_layers)
infer = GSgnnLinkPredictionInferrer(model)
infer.setup_device(device=get_device())
assert all((x.startswith(SUPPORTED_HIT_AT_METRICS) or x == 'mrr') for x in
config.eval_metric), (
assert all((x.startswith(SUPPORTED_HIT_AT_METRICS)
or x in SUPPORTED_LINK_PREDICTION_METRICS)
for x in config.eval_metric), (
"Invalid LP evaluation metrics. "
"GraphStorm only supports MRR and Hit@K metrics for link prediction.")
f"GraphStorm only supports {SUPPORTED_LINK_PREDICTION_METRICS} "
"and Hit@K metrics for link prediction.")
if not config.no_validation:
infer_idxs = infer_data.get_edge_test_set(config.eval_etype)
infer.setup_evaluator(GSgnnLPEvaluator(
11 changes: 9 additions & 2 deletions tests/end2end-tests/graphstorm-lp/mgpu_test.sh
@@ -378,8 +378,8 @@
exit -1
fi

echo "**************dataset: Movielens, do inference on saved model, decoder: dot"
python3 -m graphstorm.run.gs_link_prediction --inference --workspace $GS_HOME/inference_scripts/lp_infer --num-trainers $NUM_INFO_TRAINERS --num-servers 1 --num-samplers 0 --part-config /data/movielen_100k_lp_train_val_1p_4t/movie-lens-100k.json --ip-config ip_list.txt --ssh-port 2222 --cf ml_lp_infer.yaml --fanout '10,15' --num-layers 2 --use-mini-batch-infer false --use-node-embeddings true --eval-batch-size 1024 --save-embed-path /data/gsgnn_lp_ml_dot/infer-emb/ --restore-model-path /data/gsgnn_lp_ml_dot/epoch-$best_epoch_dot/ --logging-file /tmp/log.txt --preserve-input True
echo "**************dataset: Movielens, do inference on saved model, decoder: dot, metrics: mrr amri"
python3 -m graphstorm.run.gs_link_prediction --inference --workspace $GS_HOME/inference_scripts/lp_infer --num-trainers $NUM_INFO_TRAINERS --num-servers 1 --num-samplers 0 --part-config /data/movielen_100k_lp_train_val_1p_4t/movie-lens-100k.json --ip-config ip_list.txt --ssh-port 2222 --cf ml_lp_infer.yaml --fanout '10,15' --num-layers 2 --use-mini-batch-infer false --use-node-embeddings true --eval-batch-size 1024 --save-embed-path /data/gsgnn_lp_ml_dot/infer-emb/ --restore-model-path /data/gsgnn_lp_ml_dot/epoch-$best_epoch_dot/ --logging-file /tmp/log.txt --preserve-input True --eval-metric mrr amri

error_and_exit $?

@@ -390,6 +390,13 @@
exit -1
fi

cnt=$(grep -c "| Test amri" /tmp/log.txt)
if test $cnt -ne 1
then
echo "We do test, should have amri"
exit 1
fi

bst_cnt=$(grep "Best Test mrr" /tmp/log.txt | wc -l)
if test $bst_cnt -lt 1
then
