tfmri.metrics.ConfusionMetric

class ConfusionMetric(*args, **kwargs)[source]

Bases: keras.metrics.base_metric.Metric

Abstract base class for metrics derived from the confusion matrix.

This class maintains a confusion matrix in its state and updates it on every call to update_state. Subclasses must implement the _result method, which computes the desired metric from the confusion matrix; _result is called by result.
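To make this contract concrete, the following self-contained sketch keeps binary confusion counts in the metric state and derives a score from them when result is called. It is built directly on keras.metrics.Metric rather than on ConfusionMetric (whose internal state layout is not documented here), so it only illustrates the pattern that ConfusionMetric factors out for its subclasses.

import tensorflow as tf

# Illustrative sketch only: a binary F1 metric built directly on
# keras.metrics.Metric. Confusion counts live in the metric state and the
# final score is derived from them in `result`. This is not the tfmri
# implementation; `sample_weight` is ignored for brevity.
class BinaryF1(tf.keras.metrics.Metric):
  def __init__(self, threshold=0.5, name='binary_f1', **kwargs):
    super().__init__(name=name, **kwargs)
    self.threshold = threshold
    self.tp = self.add_weight(name='tp', initializer='zeros')
    self.fp = self.add_weight(name='fp', initializer='zeros')
    self.fn = self.add_weight(name='fn', initializer='zeros')

  def update_state(self, y_true, y_pred, sample_weight=None):
    y_true = tf.cast(y_true, tf.bool)
    y_pred = tf.cast(y_pred, tf.float32) > self.threshold
    self.tp.assign_add(tf.reduce_sum(tf.cast(y_true & y_pred, tf.float32)))
    self.fp.assign_add(tf.reduce_sum(tf.cast(~y_true & y_pred, tf.float32)))
    self.fn.assign_add(tf.reduce_sum(tf.cast(y_true & ~y_pred, tf.float32)))

  def result(self):
    return tf.math.divide_no_nan(
        2.0 * self.tp, 2.0 * self.tp + self.fp + self.fn)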

Inputs y_true and y_pred are expected to have shape [..., num_classes], with channel i containing labels/predictions for class i. y_true[..., i] is 1 if the element represented by y_true[...] is a member of class i and 0 otherwise. y_pred[..., i] is the predicted probability, in the range [0.0, 1.0], that the element represented by y_pred[...] is a member of class i.
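For instance, in a three-class problem the trailing dimension indexes the classes (the values below are illustrative):

import tensorflow as tf

# Three classes; channel i of the last dimension refers to class i.
y_true = tf.constant([[0., 1., 0.],     # first element belongs to class 1
                      [1., 0., 0.]])    # second element belongs to class 0
y_pred = tf.constant([[0.1, 0.7, 0.2],  # predicted per-class probabilities
                      [0.6, 0.3, 0.1]])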

This metric works for binary, multiclass and multilabel classification. In multiclass/multilabel problems, this metric can be used to measure performance globally or for a specific class.

With the default configuration, this metric will:

  • If num_classes == 1, assume a binary classification problem with a threshold of 0.5 and return the confusion metric.

  • If num_classes >= 2, assume a multiclass classification problem where the class with the highest probability is selected as the prediction, compute the confusion metric for each class and return the unweighted mean.

See the Parameters and Notes for other configurations.
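The default binarization of y_pred described above can be sketched in plain TensorFlow as follows; this mimics the documented behaviour and is not the internal implementation:

import tensorflow as tf

def binarize(y_pred, num_classes, threshold=None):
  # Sketch of the documented defaults, not the library's internal code.
  if threshold is not None:
    # Binary / multilabel: elements above the threshold become 1.
    return tf.cast(y_pred > threshold, tf.float32)
  if num_classes == 1:
    # Binary default: threshold of 0.5.
    return tf.cast(y_pred > 0.5, tf.float32)
  # Multiclass default: the argmax becomes 1, all other classes 0.
  return tf.one_hot(tf.argmax(y_pred, axis=-1), depth=num_classes)

binarize(tf.constant([[0.1, 0.7, 0.2]]), num_classes=3)  # [[0., 1., 0.]]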

Parameters
  • num_classes – Number of unique classes in the dataset. If this value is not specified, it will be inferred during the first call to update_state as y_pred.shape[-1].

  • class_id

    Integer class ID for which metrics should be reported. This must be in the half-open interval [0, num_classes). If None, a global average metric is returned as defined by average. Defaults to None.

  • average

    Type of averaging to be performed on data. Valid values are None, 'micro', 'macro' and 'weighted'. Defaults to 'macro'. See Notes for details on the different modes. This parameter is ignored if class_id is not None.

  • threshold

    Elements of y_pred above threshold are considered to be 1, and the rest 0. A list of length num_classes may be provided to specify a threshold for each class. If threshold is None, the class with the highest predicted probability (the argmax) is set to 1 and the rest to 0. Defaults to None if num_classes >= 2 (multiclass classification) and to 0.5 if num_classes == 1 (binary classification). This parameter is required for multilabel classification.

  • name – String name of the metric instance.

  • dtype – Data type of the metric result.

Notes

This metric works for binary, multiclass and multilabel classification.

  • For binary tasks, set num_classes to 1 and, optionally, threshold to the desired value (0.5 if unspecified). The value of average is irrelevant.

  • For multiclass tasks, set num_classes to the number of possible labels and set average to the desired mode. threshold should be left as None.

  • For multilabel tasks, set num_classes to the number of possible labels, set threshold to the desired value in the range (0.0, 1.0) (or provide a list of length num_classes to specify a different threshold value for each class), and set average to the desired mode.
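For example, the three task types above might be configured as follows. FBetaScore stands in for a concrete ConfusionMetric subclass; the class name is illustrative only, while the constructor arguments are those documented above.

import tensorflow_mri as tfmri

# `FBetaScore` is a placeholder for any concrete `ConfusionMetric` subclass.
binary     = tfmri.metrics.FBetaScore(num_classes=1, threshold=0.5)
multiclass = tfmri.metrics.FBetaScore(num_classes=4, average='macro')
multilabel = tfmri.metrics.FBetaScore(num_classes=4, threshold=0.3,
                                      average='weighted')
# Report the score for class 2 only (class_id takes precedence over average).
single_class = tfmri.metrics.FBetaScore(num_classes=4, class_id=2)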

In multiclass/multilabel problems, this metric can be used to measure performance globally or for a specific class. For a specific class, set class_id to the desired value. For a global measure, set class_id to None and average to the desired averaging method. average can take the following values:

  • None: Scores for each class are returned.

  • 'micro': Calculate metrics globally by counting the total true positives, true negatives, false positives and false negatives.

  • 'macro': Calculate metrics for each label, and return their unweighted mean. This does not take label imbalance into account.

  • 'weighted': Calculate metrics for each label, and find their average weighted by support (the number of true instances for each label). This alters 'macro' to account for label imbalance.
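As a purely numerical illustration of how the averaging modes differ (plain NumPy, not the metric's implementation), consider a precision-like score computed from per-class confusion counts:

import numpy as np

# Per-class confusion counts for a 3-class problem (illustrative numbers).
tp = np.array([50.,  5., 10.])   # true positives per class
fp = np.array([10.,  2.,  3.])   # false positives per class
fn = np.array([ 5., 20.,  5.])   # false negatives per class
support = tp + fn                # number of true instances per class

per_class = tp / (tp + fp)                          # average=None
macro     = per_class.mean()                        # unweighted mean
weighted  = np.average(per_class, weights=support)  # weighted by support
micro     = tp.sum() / (tp.sum() + fp.sum())        # pool counts globally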

get_config()[source]

Returns the serializable config of the metric.

reset_state()[source]

Resets all of the metric state variables.

This function is called between epochs/steps, when a metric is evaluated during training.

result()[source]

Computes and returns the scalar metric value tensor or a dict of scalars.

Result computation is an idempotent operation that simply calculates the metric value using the state variables.

Returns

A scalar tensor, or a dictionary of scalar tensors.

update_state(y_true, y_pred, sample_weight=None)[source]

Update confusion matrix entries.

Parameters
  • y_true – The ground truth labels. Must have shape [..., num_classes], where y_true[..., i] is 1 if the element represented by y_true[...] is a member of class i and 0 otherwise.

  • y_pred – The predictions. Must have shape [..., num_classes], where y_pred[..., i] is the predicted probability, in the range [0.0, 1.0], that the element represented by y_pred[...] is a member of class i.

  • sample_weight

    Optional weighting of each example. The confusion matrix update for each example is scaled by the corresponding value of sample_weight. If sample_weight is None, weights default to 1. Use a sample_weight of 0 to mask values.
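A typical standalone usage, assuming a concrete subclass (FBetaScore below is an illustrative name only), looks like this:

import tensorflow as tf
import tensorflow_mri as tfmri

# `FBetaScore` is a placeholder for any concrete `ConfusionMetric` subclass.
metric = tfmri.metrics.FBetaScore(num_classes=3, average='macro')

y_true = tf.constant([[0., 1., 0.],
                      [1., 0., 0.]])
y_pred = tf.constant([[0.1, 0.7, 0.2],
                      [0.2, 0.5, 0.3]])

# The zero weight masks the second example out of the confusion-matrix update.
metric.update_state(y_true, y_pred, sample_weight=tf.constant([1., 0.]))
value = metric.result()   # scalar tensor (or dict of scalars, per `average`)
metric.reset_state()      # clear the confusion matrix before the next epoch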