tfmri.metrics.ConfusionMetric
- class ConfusionMetric(*args, **kwargs)
Bases: keras.metrics.base_metric.Metric
Abstract base class for metrics derived from the confusion matrix.
This class maintains a confusion matrix in its state and updates it with every call to update_state. Subclasses must implement the _result method to compute the desired metric; _result is called from result.
Inputs y_true and y_pred are expected to have shape [..., num_classes], with channel i containing labels/predictions for class i. y_true[..., i] is 1 if the element represented by y_true[...] is a member of class i and 0 otherwise. y_pred[..., i] is the predicted probability, in the range [0.0, 1.0], that the element represented by y_pred[...] is a member of class i.
This metric works for binary, multiclass and multilabel classification. In multiclass/multilabel problems, this metric can be used to measure performance globally or for a specific class.
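The split of responsibilities described above, where the base class accumulates confusion-matrix counts and a subclass turns them into a score in _result, can be sketched in plain Python. This is a framework-free illustration under stated assumptions, not the tfmri implementation: ToyConfusionMetric and ToyRecall are hypothetical names, the state is assumed to be one TP/FP/FN/TN counter per class, _result is assumed to take no arguments, inputs are assumed to be already binarized 0/1 rows of length num_classes (the [..., num_classes] layout with leading dimensions flattened), and no averaging is applied.

```python
from typing import List, Sequence


class ToyConfusionMetric:
    """Hypothetical stand-in for a confusion-matrix-based metric.

    The state is one TP/FP/FN/TN counter per class. Subclasses implement
    _result(); result() only delegates to it.
    """

    def __init__(self, num_classes: int) -> None:
        self.num_classes = num_classes
        self.reset_state()

    def reset_state(self) -> None:
        # Zero every confusion-matrix counter, e.g. between epochs.
        self.tp = [0] * self.num_classes
        self.fp = [0] * self.num_classes
        self.fn = [0] * self.num_classes
        self.tn = [0] * self.num_classes

    def update_state(self, y_true: Sequence[Sequence[int]],
                     y_pred: Sequence[Sequence[int]]) -> None:
        # Each row is one element; column i holds its 0/1 value for class i.
        for true_row, pred_row in zip(y_true, y_pred):
            for i, (t, p) in enumerate(zip(true_row, pred_row)):
                if t and p:
                    self.tp[i] += 1
                elif p:
                    self.fp[i] += 1
                elif t:
                    self.fn[i] += 1
                else:
                    self.tn[i] += 1

    def result(self) -> List[float]:
        # Idempotent: reads the counters without modifying them.
        return self._result()

    def _result(self) -> List[float]:
        raise NotImplementedError("subclasses must implement _result")


class ToyRecall(ToyConfusionMetric):
    """Example subclass: per-class recall, TP / (TP + FN)."""

    def _result(self) -> List[float]:
        # Assumption: a class with no true instances scores 0.0.
        return [tp / (tp + fn) if tp + fn else 0.0
                for tp, fn in zip(self.tp, self.fn)]


metric = ToyRecall(num_classes=2)
metric.update_state(y_true=[[1, 0], [0, 1], [0, 1]],
                    y_pred=[[1, 0], [1, 0], [0, 1]])
print(metric.result())  # [1.0, 0.5]
```

Keeping result a thin wrapper around _result means the counting logic is written once, and each concrete metric only has to supply its formula.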
With the default configuration, this metric will:
- If num_classes == 1, assume a binary classification problem with a threshold of 0.5 and return the metric computed from the confusion matrix.
- If num_classes >= 2, assume a multiclass classification problem where the class with the highest probability is selected as the prediction, compute the metric for each class and return their unweighted mean.
See Parameters and Notes for other configurations.
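The two default rules above can be illustrated with a small pure-Python helper. It is a sketch, not tfmri code: binarize_default is a hypothetical name, predictions are assumed to be plain lists of rows of length num_classes, and "above the threshold" is read as strictly greater than 0.5.

```python
from typing import List, Sequence


def binarize_default(y_pred: Sequence[Sequence[float]],
                     num_classes: int) -> List[List[int]]:
    """Hypothetical helper mirroring the documented default behaviour.

    num_classes == 1: threshold the single channel at 0.5.
    num_classes >= 2: the class with the highest probability becomes 1.
    """
    out = []
    for row in y_pred:
        if num_classes == 1:
            # Binary case: strictly greater than 0.5 counts as positive.
            out.append([1 if row[0] > 0.5 else 0])
        else:
            # Multiclass case: one-hot encode the argmax of the row.
            best = max(range(len(row)), key=row.__getitem__)
            out.append([1 if i == best else 0 for i in range(len(row))])
    return out


print(binarize_default([[0.7], [0.2]], num_classes=1))
# [[1], [0]]
print(binarize_default([[0.1, 0.6, 0.3], [0.5, 0.2, 0.3]], num_classes=3))
# [[0, 1, 0], [1, 0, 0]]
```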
- Parameters
num_classes – Number of unique classes in the dataset. If this value is not specified, it will be inferred during the first call to update_state as y_pred.shape[-1].
class_id – Integer class ID for which metrics should be reported. This must be in the half-open interval [0, num_classes). If None, a global average metric is returned as defined by average. Defaults to None.
average – Type of averaging to be performed on data. Valid values are None, 'micro', 'macro' and 'weighted'. Defaults to 'macro'. See Notes for details on the different modes. This parameter is ignored if class_id is not None.
threshold – Elements of y_pred above threshold are considered to be 1, and the rest 0. A list of length num_classes may be provided to specify a threshold for each class. If threshold is None, the class with the highest predicted probability (the argmax) is set to 1, and the rest to 0. Defaults to None if num_classes >= 2 (multiclass classification) and 0.5 if num_classes == 1 (binary classification). This parameter is required for multilabel classification.
name – String name of the metric instance.
dtype – Data type of the metric result.
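The constraints listed for class_id, average and threshold can be summarized as a validation routine. This is a hypothetical sketch: check_confusion_args is not part of tfmri, num_classes is assumed to be already known (the real class may instead infer it from y_pred on the first update), and whether or where the real class raises errors for invalid arguments is not stated here.

```python
from typing import Optional, Sequence, Union

# Valid values of average, as listed in the parameter description.
VALID_AVERAGES = (None, "micro", "macro", "weighted")


def check_confusion_args(num_classes: int,
                         class_id: Optional[int] = None,
                         average: Optional[str] = "macro",
                         threshold: Union[None, float, Sequence[float]] = None) -> None:
    """Hypothetical checks that follow the documented parameter constraints."""
    if num_classes < 1:
        raise ValueError("num_classes must be at least 1")
    if class_id is not None and not 0 <= class_id < num_classes:
        # class_id must lie in the half-open interval [0, num_classes).
        raise ValueError(f"class_id must be in [0, {num_classes})")
    if class_id is None and average not in VALID_AVERAGES:
        # average is ignored when class_id is set, so only check it otherwise.
        raise ValueError(f"average must be one of {VALID_AVERAGES}")
    if isinstance(threshold, (list, tuple)) and len(threshold) != num_classes:
        # A per-class threshold list needs exactly one entry per class.
        raise ValueError("threshold list must have length num_classes")


check_confusion_args(num_classes=3, class_id=2)                   # passes
check_confusion_args(num_classes=3, threshold=[0.5, 0.4, 0.6])    # passes
# check_confusion_args(num_classes=3, class_id=3) would raise ValueError.
```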
Notes
This metric works for binary, multiclass and multilabel classification.
For binary tasks, set num_classes to 1 and, optionally, threshold to the desired value (default is 0.5 if unspecified). The value of average is irrelevant.
For multiclass tasks, set num_classes to the number of possible labels and set average to the desired mode. threshold should be left as None.
For multilabel tasks, set num_classes to the number of possible labels, set threshold to the desired value in the range (0.0, 1.0) (or provide a list of length num_classes to specify a different threshold value for each class), and set average to the desired mode.
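In the multilabel setting each class is thresholded independently, so an element can be assigned to several classes or to none, unlike the argmax rule used for multiclass tasks. A minimal pure-Python sketch of that rule, assuming a hypothetical helper binarize_multilabel, list inputs, and "above threshold" meaning strictly greater than:

```python
from typing import List, Sequence, Union


def binarize_multilabel(y_pred: Sequence[Sequence[float]],
                        threshold: Union[float, Sequence[float]]) -> List[List[int]]:
    """Hypothetical helper: independent per-class thresholding.

    A scalar threshold applies to every class; a list gives one per class.
    """
    out = []
    for row in y_pred:
        if isinstance(threshold, (list, tuple)):
            thresholds = list(threshold)
        else:
            thresholds = [threshold] * len(row)
        # Each class is decided on its own: strictly above its threshold -> 1.
        out.append([1 if p > t else 0 for p, t in zip(row, thresholds)])
    return out


probs = [[0.9, 0.6, 0.1], [0.2, 0.4, 0.8]]
print(binarize_multilabel(probs, threshold=0.5))
# [[1, 1, 0], [0, 0, 1]]
print(binarize_multilabel(probs, threshold=[0.5, 0.7, 0.05]))
# [[1, 0, 1], [0, 0, 1]]
```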
In multiclass/multilabel problems, this metric can be used to measure performance globally or for a specific class. For a specific class, set class_id to the desired value. For a global measure, set class_id to None and average to the desired averaging method. average can take the following values:
- None: Scores for each class are returned.
- 'micro': Calculate metrics globally by counting the total true positives, true negatives, false positives and false negatives.
- 'macro': Calculate metrics for each label, and return their unweighted mean. This does not take label imbalance into account.
- 'weighted': Calculate metrics for each label, and find their average weighted by support (the number of true instances for each label). This alters 'macro' to account for label imbalance.
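The four averaging modes can be compared on a single set of per-class counts. The sketch below uses precision, TP / (TP + FP), as a stand-in for whatever metric a subclass computes; averaged_precision is a hypothetical helper, and returning 0.0 for a zero denominator is an assumption rather than documented behaviour.

```python
from typing import List, Optional, Sequence, Union


def averaged_precision(tp: Sequence[int], fp: Sequence[int], fn: Sequence[int],
                       average: Optional[str] = "macro") -> Union[float, List[float]]:
    """Hypothetical example: precision, TP / (TP + FP), under each averaging mode.

    The support of class i is its number of true instances, tp[i] + fn[i].
    """
    def safe_div(num: float, den: float) -> float:
        # Assumption: a zero denominator yields 0.0.
        return num / den if den else 0.0

    per_class = [safe_div(t, t + f) for t, f in zip(tp, fp)]
    if average is None:
        return per_class  # one score per class
    if average == "micro":
        # Pool TP and FP over all classes, then compute one score.
        return safe_div(sum(tp), sum(tp) + sum(fp))
    if average == "macro":
        # Unweighted mean: each class counts equally, regardless of size.
        return sum(per_class) / len(per_class)
    if average == "weighted":
        # Mean of per-class scores weighted by support (TP + FN).
        support = [t + f for t, f in zip(tp, fn)]
        return safe_div(sum(s * p for s, p in zip(support, per_class)), sum(support))
    raise ValueError(f"unknown average: {average!r}")


# Class 0 has support 10, class 1 has support 6.
tp, fp, fn = [9, 1], [3, 3], [1, 5]
for mode in (None, "micro", "macro", "weighted"):
    print(mode, averaged_precision(tp, fp, fn, mode))
# None [0.75, 0.25]
# micro 0.625
# macro 0.5
# weighted 0.5625
```

Here 'weighted' lands closer to the score of the larger class than 'macro' does (0.5625 versus 0.5), which is the adjustment for label imbalance described above.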
- reset_state()
Resets all of the metric state variables.
This function is called between epochs/steps, when a metric is evaluated during training.
- result()
Computes and returns the scalar metric value tensor or a dict of scalars.
Result computation is an idempotent operation that simply calculates the metric value using the state variables.
- Returns
A scalar tensor, or a dictionary of scalar tensors.
- update_state(y_true, y_pred, sample_weight=None)
Update confusion matrix entries.
- Parameters
y_true – The ground truth labels. Must have shape [..., num_classes], where y_true[..., i] is 1 if the element represented by y_true[...] is a member of class i and 0 otherwise.
y_pred – The predictions. Must have shape [..., num_classes], where y_pred[..., i] is the predicted probability, in the range [0.0, 1.0], that the element represented by y_pred[...] is a member of class i.
sample_weight – The predictions are weighted by sample_weight. If sample_weight is None, weights default to 1. Use a sample_weight of 0 to mask values.
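One common way to apply sample_weight, used by Keras's built-in confusion-matrix metrics and assumed here for illustration, is to add each element's weight, rather than 1, to the confusion entry it falls into; a weight of 0 then removes the element from every count. The sketch assumes one weight per element (row), non-empty inputs that are already binarized, and a hypothetical helper weighted_counts; the real method may also broadcast weights of other shapes.

```python
from typing import Dict, List, Optional, Sequence


def weighted_counts(y_true: Sequence[Sequence[int]],
                    y_pred: Sequence[Sequence[int]],
                    sample_weight: Optional[Sequence[float]] = None) -> Dict[str, List[float]]:
    """Hypothetical sketch: per-class confusion counts with per-element weights."""
    num_classes = len(y_true[0])
    counts = {key: [0.0] * num_classes for key in ("tp", "fp", "fn", "tn")}
    if sample_weight is None:
        # Documented default: every element has weight 1.
        sample_weight = [1.0] * len(y_true)
    for true_row, pred_row, w in zip(y_true, y_pred, sample_weight):
        for i, (t, p) in enumerate(zip(true_row, pred_row)):
            key = "tp" if t and p else "fp" if p else "fn" if t else "tn"
            # Add the element's weight instead of 1; weight 0 masks it out.
            counts[key][i] += w
    return counts


y_true = [[1, 0], [0, 1], [0, 1]]
y_pred = [[1, 0], [1, 0], [0, 1]]
print(weighted_counts(y_true, y_pred, sample_weight=[1.0, 0.0, 2.0]))
# {'tp': [1.0, 2.0], 'fp': [0.0, 0.0], 'fn': [0.0, 0.0], 'tn': [2.0, 1.0]}
```

Masking the second element (weight 0) erases the false positive for class 0 and the false negative for class 1 that it would otherwise contribute.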