tfmri.metrics.FBetaScore

class FBetaScore(*args, **kwargs)[source]

Bases: tensorflow_mri.python.metrics.confusion_metrics.TverskyIndex

Computes F-beta score.

The F-beta score is the weighted harmonic mean of precision and recall.

\[F_{\beta} = (1 + \beta^2) \cdot \frac{\textrm{precision} \cdot \textrm{recall}}{(\beta^2 \cdot \textrm{precision}) + \textrm{recall}}\]
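As a quick sanity check of the formula, here is a minimal pure-Python sketch using made-up precision/recall values (not part of the library API):

```python
def f_beta(precision, recall, beta=1.0):
    """Weighted harmonic mean of precision and recall.

    ``recall`` is ``beta`` times as important as ``precision``.
    """
    b2 = beta ** 2
    return (1.0 + b2) * precision * recall / (b2 * precision + recall)

# With beta = 1 this reduces to the ordinary F1 score:
# 2 * 0.5 * 1.0 / (0.5 + 1.0) = 2/3.
print(f_beta(0.5, 1.0, beta=1.0))
```

Increasing beta shifts the score toward recall: with the same precision/recall pair, `f_beta(0.5, 1.0, beta=2.0)` is higher than the F1 value because recall (1.0) is weighted more heavily.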

Inputs y_true and y_pred are expected to have shape [..., num_classes], with channel i containing labels/predictions for class i. y_true[..., i] is 1 if the element represented by y_true[...] is a member of class i and 0 otherwise. y_pred[..., i] is the predicted probability, in the range [0.0, 1.0], that the element represented by y_pred[...] is a member of class i.

This metric works for binary, multiclass and multilabel classification. In multiclass/multilabel problems, this metric can be used to measure performance globally or for a specific class.

With the default configuration, this metric will:

  • If num_classes == 1, assume a binary classification problem with a threshold of 0.5 and return the confusion metric.

  • If num_classes >= 2, assume a multiclass classification problem where the class with the highest probability is selected as the prediction, compute the confusion metric for each class and return the unweighted mean.

See the Parameters and Notes for other configurations.
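The default multiclass behavior described above (argmax the probabilities, compute a per-class score, take the unweighted mean) can be sketched in plain NumPy. This is only an illustration of the arithmetic with hypothetical data, not the library's implementation, and it uses beta = 1 (the F1 case) for brevity:

```python
import numpy as np

# Hypothetical 3-class problem with 4 elements.
y_true = np.array([[1, 0, 0],
                   [0, 1, 0],
                   [0, 0, 1],
                   [1, 0, 0]])
y_pred = np.array([[0.8, 0.1, 0.1],
                   [0.2, 0.6, 0.2],
                   [0.1, 0.2, 0.7],
                   [0.3, 0.4, 0.3]])

# Default multiclass behavior: the class with the highest
# probability is selected as the prediction (argmax -> one-hot).
one_hot = np.eye(3)[y_pred.argmax(-1)].astype(int)

def f1_per_class(y_true, y_pred_bin):
    # Per-class confusion counts over all elements.
    tp = (y_true * y_pred_bin).sum(0)
    fp = ((1 - y_true) * y_pred_bin).sum(0)
    fn = (y_true * (1 - y_pred_bin)).sum(0)
    return 2 * tp / (2 * tp + fp + fn)

scores = f1_per_class(y_true, one_hot)
macro_f1 = scores.mean()  # unweighted mean over classes ('macro')
```

Here the fourth element is misclassified (true class 0, predicted class 1), giving per-class F1 scores of 2/3, 2/3 and 1, and a macro average of 7/9.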

Parameters
  • num_classes – Number of unique classes in the dataset. If this value is not specified, it will be inferred during the first call to update_state as y_pred.shape[-1].

  • class_id

    Integer class ID for which metrics should be reported. This must be in the half-open interval [0, num_classes). If None, a global average metric is returned as defined by average. Defaults to None.

  • average

    Type of averaging to be performed on data. Valid values are None, 'micro', 'macro' and 'weighted'. Defaults to 'macro'. See Notes for details on the different modes. This parameter is ignored if class_id is not None.

  • threshold

    Elements of y_pred above threshold are considered to be 1, and the rest 0. A list of length num_classes may be provided to specify a threshold for each class. If threshold is None, the argmax is converted to 1, and the rest 0. Defaults to None if num_classes >= 2 (multiclass classification) and 0.5 if num_classes == 1 (binary classification). This parameter is required for multilabel classification.

  • beta – A float. Determines the weight of precision and recall in the harmonic mean, such that recall is beta times as important as precision.

  • name – String name of the metric instance.

  • dtype – Data type of the metric result.

Notes

This metric works for binary, multiclass and multilabel classification.

  • For binary tasks, set num_classes to 1, and optionally, threshold to the desired value (the default is 0.5). The value of average is irrelevant.

  • For multiclass tasks, set num_classes to the number of possible labels and set average to the desired mode. threshold should be left as None.

  • For multilabel tasks, set num_classes to the number of possible labels, set threshold to the desired value in the range (0.0, 1.0) (or provide a list of length num_classes to specify a different threshold value for each class), and set average to the desired mode.
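The multilabel thresholding step can be illustrated with hypothetical data in NumPy. This sketch shows only the binarization the documentation describes (a scalar threshold versus a per-class list), not the library's internal code:

```python
import numpy as np

# Hypothetical multilabel problem: 2 elements, 3 labels.
y_pred = np.array([[0.9, 0.4, 0.7],
                   [0.2, 0.8, 0.6]])

# A single threshold applied to every label.
binary_all = (y_pred > 0.5).astype(int)

# Or a list of length num_classes, one threshold per label.
thresholds = np.array([0.5, 0.3, 0.8])
binary_per_label = (y_pred > thresholds).astype(int)
```

With the per-label thresholds, label 1 becomes easier to predict positive (threshold 0.3) and label 2 harder (threshold 0.8), so the two binarizations differ for the same probabilities.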

In multiclass/multilabel problems, this metric can be used to measure performance globally or for a specific class. For a specific class, set class_id to the desired value. For a global measure, set class_id to None and average to the desired averaging method. average can take the following values:

  • None: Scores for each class are returned.

  • 'micro': Calculate metrics globally by counting the total true positives, true negatives, false positives and false negatives.

  • 'macro': Calculate metrics for each label, and return their unweighted mean. This does not take label imbalance into account.

  • 'weighted': Calculate metrics for each label, and find their average weighted by support (the number of true instances for each label). This alters 'macro' to account for label imbalance.
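The difference between the averaging modes comes down to when the per-class confusion counts are combined. A small NumPy sketch with made-up per-class counts (beta = 1 for brevity; this illustrates the definitions above, not the library's code):

```python
import numpy as np

# Hypothetical per-class confusion counts for 3 labels.
tp = np.array([8, 2, 1])
fp = np.array([2, 1, 0])
fn = np.array([1, 3, 4])

def f1(tp, fp, fn):
    return 2 * tp / (2 * tp + fp + fn)

# average=None: one score per class.
per_class = f1(tp, fp, fn)

# 'micro': pool the counts globally first, then score once.
micro = f1(tp.sum(), fp.sum(), fn.sum())

# 'macro': score each class, then take the unweighted mean.
macro = per_class.mean()

# 'weighted': mean of per-class scores weighted by support
# (the number of true instances, tp + fn, for each label).
support = tp + fn
weighted = (per_class * support).sum() / support.sum()
```

Because class 0 dominates the support and scores well while classes 1 and 2 score poorly, `weighted` lands above `macro` here, and `micro` reflects the pooled counts rather than any per-class mean.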

get_config()[source]

Returns the serializable config of the metric.