Source code for daisy.evaluation.anomaly_detection_online_evaluation

# Copyright (C) 2024-2025 DAI-Labor and others
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at https://mozilla.org/MPL/2.0/.
"""Extensions to the tensorflow metric set, to measure the quality of online anomaly
detection approaches. For now, only a sliding-window-based solution is realized,
as this works for any kind of metric (not every metric can be computed online).

Author: Fabian Hofmann
Modified: 17.08.23
"""

from abc import ABC, abstractmethod
from collections import deque
from collections.abc import Iterable
from typing import Self

import tensorflow as tf
from tensorflow import keras, Tensor


class SlidingWindowEvaluation(keras.metrics.Metric, ABC):
    """Abstract evaluation metric class that extends the existing tensorflow metric
    base class, with a sliding window to collect the k most recent predicted labels
    to evaluate the model's recent performance on them in point-wise manner.

    Subclasses implement :meth:`_update`, :meth:`_reset` and :meth:`result`; the
    window bookkeeping itself (insertion, eviction, merging, resetting) is handled
    here. Note that depending on the metric, non-abstract methods must be extended
    with a new metric's own functionality.
    """

    # True labels currently inside the window, oldest first.
    true_labels: deque
    # Predicted labels currently inside the window, aligned with true_labels.
    pred_labels: deque
    # Maximum number of label pairs kept; None means the window never evicts.
    _window_size: int | None

    def __init__(
        self, name="ad_online_evaluation", window_size: int | None = None, **kwargs
    ):
        """Creates a new sliding window evaluation metric.

        :param name: Name of metric.
        :param window_size: Size of sliding window (expected to be a positive
            integer). If not provided, assume infinite window size.
        :param kwargs: Additional metric/layer keywords arguments.
        """
        super().__init__(name=name, **kwargs)

        self.true_labels = deque()
        self.pred_labels = deque()
        self._window_size = window_size

    def update_state(self, y_true, y_pred, *args, **kwargs):
        """Adds a mini-batch of inputs to the metric, removing old ones if the
        window is full, and adjusting statistics accordingly. Converts any tensors
        into numpy arrays as data points/pairs are processed in element-wise
        fashion anyway and this makes it easier for generic handling.

        :param y_true: Vector/Tensor containing true labels of inputs.
        :param y_pred: Vector/Tensor containing predicted labels of inputs.
        :param args: Not supported arguments.
        :param kwargs: Not supported keywords arguments.
        """
        # Convert each input on its own: a call mixing a tensor with a plain
        # array/list must not leave raw tensor scalars in the window.
        if tf.is_tensor(y_true):
            y_true = y_true.numpy()
        if tf.is_tensor(y_pred):
            y_pred = y_pred.numpy()

        for t_label, p_label in zip(y_true, y_pred):
            # With an unbounded window (_window_size is None) this comparison is
            # never true, so nothing is ever evicted.
            if len(self.true_labels) == self._window_size:
                old_t_label = self.true_labels.popleft()
                old_p_label = self.pred_labels.popleft()
                self._update(old_t_label, old_p_label, remove=True)

            self.true_labels.append(t_label)
            self.pred_labels.append(p_label)
            self._update(t_label, p_label)

    @abstractmethod
    def _update(self, t_label, p_label, remove: bool = False):
        """Update function that must be implemented for each metric individually,
        called during update_state(), able to update the state variables for a
        singular data point/pair, for both its addition and its removal from the
        window.

        :param t_label: True label of single input.
        :param p_label: Predicted label of single input.
        :param remove: Whether the input pair is to be removed from the sliding
            window or added.
        """
        raise NotImplementedError

    def merge_state(self, metrics: Iterable[Self]):
        """Merges the state from one or more metrics, by feeding the contents of
        their sliding windows into this metric's window via update_state(). If
        this window is too small to hold everything, the oldest entries are
        evicted as during any other update.

        :param metrics: An iterable of sliding window metrics of the same type.
        """
        for m in metrics:
            # Snapshot the source windows first: if this instance is among the
            # merged metrics, update_state() would otherwise append to the very
            # deques being iterated and raise a RuntimeError.
            self.update_state(list(m.true_labels), list(m.pred_labels))

    def reset_state(self):
        """Resets the sliding window and all the metric's state variables."""
        self.true_labels = deque()
        self.pred_labels = deque()
        self._reset()

    @abstractmethod
    def _reset(self):
        """Reset function that must be implemented for each metric individually,
        called during reset_state(), resets all underlying statistics to their
        original state.
        """
        raise NotImplementedError

    @abstractmethod
    def result(self):
        """Computes and returns the scalar value(s) of the metric. Idempotent
        operation based on the underlying state variables and the sliding window.

        :return: A scalar tensor, or a dictionary of scalar tensors.
        """
        raise NotImplementedError
class ConfMatrSlidingWindowEvaluation(SlidingWindowEvaluation):
    """Sliding window evaluation metric that computes the entire confusion matrix
    along with most of its derived metrics over the k most recent predicted binary
    labels to evaluate the model's recent performance on them in point-wise manner.
    """

    # Confusion matrix cell counts for the label pairs currently in the window.
    _fp: int
    _tp: int
    _fn: int
    _tn: int

    def __init__(
        self, name="conf_matrix_online_evaluation", window_size: int = None, **kwargs
    ):
        """Creates a new confusion matrix sliding window evaluation metric.

        :param name: Name of metric.
        :param window_size: Size of sliding window. If not provided, assume
            infinite window size.
        :param kwargs: Additional metric/layer keywords arguments.
        """
        super().__init__(name=name, window_size=window_size, **kwargs)

        self._fp = 0
        self._tp = 0
        self._fn = 0
        self._tn = 0

    def _update(self, t_label: bool, p_label: bool, remove: bool = False):
        """Updates the confusion matrix based on a single data point/pair, for both
        its addition and its removal from the window.

        :param t_label: True label of single input.
        :param p_label: Predicted label of single input.
        :param remove: Whether the input pair is to be removed from the sliding
            window confusion matrix or added.
        """
        # Adding a pair increments its cell, evicting a pair decrements it.
        mod = 1 if not remove else -1
        if t_label:
            if p_label:
                self._tp += mod
            else:
                self._fn += mod
        else:
            if not p_label:
                self._tn += mod
            else:
                self._fp += mod

    def _reset(self):
        """Zeroes the confusion matrix."""
        self._fp = 0
        self._tp = 0
        self._fn = 0
        self._tn = 0

    # noinspection DuplicatedCode
    def result(self) -> dict[str, Tensor]:
        """Based on the accumulated confusion matrix, computes its derived scalar
        metrics and returns them. Every ratio whose denominator is zero (e.g. all
        of them while the window is empty) is reported as 0 instead of raising.

        :return: Dictionary of all derived scalar (tensor) confusion matrix metrics.
        """
        # Accuracy goes through divide_no_nan like every other ratio: a plain
        # division raises ZeroDivisionError on an empty window and yields a Python
        # float rather than a tensor.
        accuracy = tf.math.divide_no_nan(
            self._tp + self._tn, len(self.true_labels)
        )
        recall = tf.math.divide_no_nan(self._tp, (self._tp + self._fn))
        tnr = tf.math.divide_no_nan(self._tn, (self._tn + self._fp))
        precision = tf.math.divide_no_nan(self._tp, (self._tp + self._fp))
        npv = tf.math.divide_no_nan(self._tn, (self._tn + self._fn))
        fnr = tf.math.divide_no_nan(self._fn, (self._fn + self._tp))
        fpr = tf.math.divide_no_nan(self._fp, (self._fp + self._tn))
        f1 = tf.math.divide_no_nan(2 * self._tp, (2 * self._tp + self._fp + self._fn))
        return {
            "accuracy": accuracy,
            "recall": recall,
            "true negative rate": tnr,
            "precision": precision,
            "negative predictive value": npv,
            "false negative rate": fnr,
            "false positive rate": fpr,
            "f1": f1,
        }
class TFMetricSlidingWindowEvaluation(SlidingWindowEvaluation):
    """Adapter that lets an arbitrary tensorflow metric operating on true/predicted
    label pairs be evaluated over a sliding window. The window only stores the
    label pairs; the wrapped metric is evaluated over the stored window contents
    whenever result() is called. This is not computationally efficient, but
    incremental aggregation cannot be offered generically since not every metric
    can be maintained in sliding window manner.
    """

    # The wrapped tensorflow metric that performs the actual computation.
    _tf_metric: keras.metrics.Metric

    def __init__(
        self, tf_metric: keras.metrics.Metric, window_size: int = None, **kwargs
    ):
        """Wraps a tensorflow metric into a sliding window evaluation metric. The
        wrapper's name is the wrapped metric's name with "_online_evaluation"
        appended.

        :param tf_metric: Tensorflow metric to be wrapped.
        :param window_size: Size of sliding window. If not provided, assume
            infinite window size.
        :param kwargs: Additional metric/layer keywords arguments.
        """
        wrapper_name = f"{tf_metric.name}_online_evaluation"
        super().__init__(name=wrapper_name, window_size=window_size, **kwargs)

        self._tf_metric = tf_metric

    def _update(self, t_label, p_label, remove: bool = False):
        """Intentionally a no-op: no incremental statistics are kept, since the
        wrapped metric is evaluated over the window on every result() call.
        """

    def _reset(self):
        """Intentionally a no-op: the wrapped metric is already reset at the end of
        every result() call, so there is no additional state to clear.
        """

    def result(self):
        """Evaluates the wrapped tensorflow metric on the current window contents
        and resets the wrapped metric afterwards.

        :return: Tensorflow metric result.
        """
        window_value = self._tf_metric(self.true_labels, self.pred_labels)
        # Clear the wrapped metric after each evaluation.
        self._tf_metric.reset_state()
        return window_value