class LossTrackerWrapper:
    """A wrapper for Keras Mean metrics to track multiple loss values.

    Trackers are created lazily: a ``keras.metrics.Mean`` is instantiated the
    first time a given loss name is passed through :meth:`update_state` and is
    reused on every later call.
    """

    def __init__(self, prefix: str) -> None:
        """Initialize the loss tracker wrapper.

        Args:
            prefix: Prefix to use for the loss name, for example
                ``"n_loss"`` or ``"i_loss"``.
        """
        self.prefix = prefix
        # Maps tracker name -> metric object; populated on demand.
        self.trackers = {}

    def _get_tracker(self, name: str):
        """Return the tracker stored under *name*, creating it if missing."""
        if name not in self.trackers:
            self.trackers[name] = keras.metrics.Mean(name=name)
        return self.trackers[name]

    def update_state(self, loss_value) -> None:
        """Update the tracker(s) with a loss value.

        If ``loss_value`` is a dict, a separate tracker is created for each
        key (if not already created) and updated with the matching value; the
        tracker's name is ``<prefix>_<key>``. Otherwise a single default
        tracker named ``<prefix>`` is updated.

        Args:
            loss_value: A tensor or a dictionary mapping field names to
                tensors.
        """
        if isinstance(loss_value, dict):
            for key, value in loss_value.items():
                self._get_tracker(f"{self.prefix}_{key}").update_state(value)
        else:
            self._get_tracker(self.prefix).update_state(loss_value)

    def result(self) -> dict:
        """Return a dictionary with the current average results.

        Returns:
            A dict mapping each tracker's name (e.g. ``"n_loss_a"``) to the
            value returned by that tracker's ``result()``.
        """
        # Prefer the tracker's own name; fall back to the key it is stored
        # under so a tracker object without a ``name`` attribute still reports.
        return {
            getattr(tracker, "name", key): tracker.result()
            for key, tracker in self.trackers.items()
        }

    def reset_state(self) -> None:
        """Reset all the internal trackers."""
        for tracker in self.trackers.values():
            tracker.reset_state()