Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 76 additions & 19 deletions ml_metrics/_src/aggregates/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import collections
from collections.abc import Callable, Iterable
import dataclasses
import enum
import itertools
import math
from typing import Any, Generic, Self, TypeVar
Expand All @@ -35,6 +36,38 @@
_T = TypeVar('_T')


@enum.unique
class FeatureType(enum.Enum):
"""Feature types."""

INT = 'int'
FLOAT = 'float'
STRING = 'string'


def _get_feature_type(value: list[Any]) -> FeatureType | None:
"""Returns the feature type of the given value."""
if not value:
return None
if isinstance(value[0], str):
return FeatureType.STRING
elif isinstance(value[0], bytes):
try:
value[0].decode('utf-8')
return FeatureType.STRING
except UnicodeDecodeError as e:
raise ValueError(
'Unsupported bytes feature type. Feature could not be decoded as'
f' UTF-8 string: {e}'
) from e
elif isinstance(value[0], int):
return FeatureType.INT
elif isinstance(value[0], float):
return FeatureType.FLOAT
else:
raise ValueError(f'Unsupported feature type: {type(value[0])}')


@telemetry.class_monitor(api='ml_metrics', category=telemetry.CATEGORY.STATS)
@dataclasses.dataclass(kw_only=True)
class UnboundedSampler(chainable.CallableMetric, chainable.HasAsAggFn):
Expand Down Expand Up @@ -344,16 +377,18 @@ def __str__(self):
class FeatureStats:
"""Statistics for a single feature."""

feature_type: FeatureType | None = None
num_missing: int = 0
num_non_missing: int = 0
max_num_values: int = 0
min_num_values: int | None = None
tot_num_values: int = 0
avg_num_values: float = 0.0

def update(self, length: int):
def update(self, length: int, feature_type: FeatureType | None):
self.merge(
FeatureStats(
feature_type=feature_type,
num_non_missing=1,
max_num_values=length,
min_num_values=length,
Expand All @@ -363,6 +398,15 @@ def update(self, length: int):
)

def merge(self, other: Self):
"""Merges with other feature stats."""
if other.feature_type is not None:
if self.feature_type is None:
self.feature_type = other.feature_type
elif self.feature_type != other.feature_type:
raise ValueError(
f'Feature has conflicting types: {self.feature_type} vs'
f' {other.feature_type}'
)
self.num_non_missing += other.num_non_missing
self.max_num_values = max(self.max_num_values, other.max_num_values)
if other.min_num_values is None:
Expand Down Expand Up @@ -415,6 +459,12 @@ def to_proto(self):
feature_name_stats.num_stats.common_stats.tot_num_values = (
feature_stats.tot_num_values
)
if feature_stats.feature_type:
feature_name_stats.type = (
statistics_pb2.FeatureNameStatistics.Type.Value(
feature_stats.feature_type.name
)
)
return statistics_pb2.DatasetFeatureStatisticsList(
datasets=[feature_stats_proto]
)
Expand All @@ -426,10 +476,7 @@ class TfExampleStatsAgg(chainable.CallableMetric):
"""Computes statistics on features."""

batched_inputs: bool = True
_num_examples: int = 0
_feature_stats: dict[str, FeatureStats] = dataclasses.field(
default_factory=dict
)
_stats: TfExampleStats = dataclasses.field(default_factory=TfExampleStats)

def as_agg_fn(self) -> chainable.AggregateFn:
return chainable.as_agg_fn(
Expand All @@ -439,11 +486,11 @@ def as_agg_fn(self) -> chainable.AggregateFn:

@property
def num_examples(self) -> int:
return self._num_examples
return self._stats.num_examples

@property
def feature_stats(self) -> dict[str, FeatureStats]:
return self._feature_stats
return self._stats.feature_stats

def new(
self, inputs: dict[str, list[Any]] | list[dict[str, list[Any]]]
Expand All @@ -461,27 +508,37 @@ def new(
for example in inputs_list:
num_examples += 1
for key, value in example.items():
feature_stats[key].update(len(value))
try:
feature_stats[key].update(len(value), _get_feature_type(value))
except ValueError as e:
if 'conflicting types' in str(e):
raise ValueError(f'Feature {key} has conflicting types: {e}') from e
raise
return self.__class__(
batched_inputs=self.batched_inputs,
_num_examples=num_examples,
_feature_stats=dict(feature_stats),
_stats=TfExampleStats(
num_examples=num_examples,
feature_stats=dict(feature_stats),
),
)

def merge(self, other: Self) -> None:
self._num_examples += other.num_examples
self._stats.num_examples += other.num_examples
for key, value in other.feature_stats.items():
if key in self._feature_stats:
self._feature_stats[key].merge(value)
if key in self._stats.feature_stats:
try:
self._stats.feature_stats[key].merge(value)
except ValueError as e:
if 'conflicting types' in str(e):
raise ValueError(f'Feature {key} has conflicting types: {e}') from e
raise
else:
self._feature_stats[key] = value
self._stats.feature_stats[key] = value

def result(self) -> TfExampleStats:
for feature in self._feature_stats.values():
feature.num_missing = self._num_examples - feature.num_non_missing
return TfExampleStats(
num_examples=self._num_examples, feature_stats=self._feature_stats
)
for feature in self._stats.feature_stats.values():
feature.num_missing = self._stats.num_examples - feature.num_non_missing
return self._stats


@telemetry.class_monitor(api='ml_metrics', category=telemetry.CATEGORY.STATS)
Expand Down
81 changes: 79 additions & 2 deletions ml_metrics/_src/aggregates/stats_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import numpy as np

from absl.testing import absltest
from tensorflow_metadata.proto.v0 import statistics_pb2


class HistogramTest(parameterized.TestCase):
Expand Down Expand Up @@ -694,6 +695,7 @@ def test_single_example(self):
num_examples=1,
feature_stats={
'a': stats.FeatureStats(
feature_type=stats.FeatureType.INT,
num_missing=0,
num_non_missing=1,
max_num_values=1,
Expand All @@ -702,6 +704,7 @@ def test_single_example(self):
avg_num_values=1.0,
),
'b': stats.FeatureStats(
feature_type=stats.FeatureType.INT,
num_missing=0,
num_non_missing=1,
max_num_values=2,
Expand All @@ -723,6 +726,7 @@ def test_multiple_examples(self):
num_examples=2,
feature_stats={
'a': stats.FeatureStats(
feature_type=stats.FeatureType.INT,
num_missing=0,
num_non_missing=2,
max_num_values=3,
Expand All @@ -731,6 +735,7 @@ def test_multiple_examples(self):
avg_num_values=2.0,
),
'b': stats.FeatureStats(
feature_type=stats.FeatureType.INT,
num_missing=1,
num_non_missing=1,
max_num_values=2,
Expand All @@ -756,6 +761,7 @@ def test_merge(self):
num_examples=3,
feature_stats={
'a': stats.FeatureStats(
feature_type=stats.FeatureType.INT,
num_missing=1,
num_non_missing=2,
max_num_values=3,
Expand All @@ -764,6 +770,7 @@ def test_merge(self):
avg_num_values=2.0,
),
'b': stats.FeatureStats(
feature_type=stats.FeatureType.INT,
num_missing=1,
num_non_missing=2,
max_num_values=4,
Expand All @@ -772,6 +779,7 @@ def test_merge(self):
avg_num_values=3.0,
),
'c': stats.FeatureStats(
feature_type=stats.FeatureType.INT,
num_missing=2,
num_non_missing=1,
max_num_values=1,
Expand All @@ -794,6 +802,7 @@ def test_unbatched(self):
num_examples=2,
feature_stats={
'a': stats.FeatureStats(
feature_type=stats.FeatureType.INT,
num_missing=0,
num_non_missing=2,
max_num_values=3,
Expand All @@ -802,6 +811,7 @@ def test_unbatched(self):
avg_num_values=2.0,
),
'b': stats.FeatureStats(
feature_type=stats.FeatureType.INT,
num_missing=1,
num_non_missing=1,
max_num_values=2,
Expand All @@ -814,6 +824,72 @@ def test_unbatched(self):
agg.result(),
)

def test_feature_types(self):
examples = [{'a': [1], 'b': [1.0, 2.0], 'c': ['foo', 'bar']}]
agg = stats.TfExampleStatsAgg()
agg.add(examples)
self.assertEqual(
stats.TfExampleStats(
num_examples=1,
feature_stats={
'a': stats.FeatureStats(
feature_type=stats.FeatureType.INT,
num_missing=0,
num_non_missing=1,
max_num_values=1,
min_num_values=1,
tot_num_values=1,
avg_num_values=1.0,
),
'b': stats.FeatureStats(
feature_type=stats.FeatureType.FLOAT,
num_missing=0,
num_non_missing=1,
max_num_values=2,
min_num_values=2,
tot_num_values=2,
avg_num_values=2.0,
),
'c': stats.FeatureStats(
feature_type=stats.FeatureType.STRING,
num_missing=0,
num_non_missing=1,
max_num_values=2,
min_num_values=2,
tot_num_values=2,
avg_num_values=2.0,
),
},
),
agg.result(),
)

def test_conflicting_types_in_batch(self):
examples = [{'a': [1]}, {'a': ['foo']}]
agg = stats.TfExampleStatsAgg()
with self.assertRaisesRegex(ValueError, 'Feature a has conflicting types'):
agg.add(examples)

def test_conflicting_types_in_merge(self):
examples1 = [{'a': [1]}]
examples2 = [{'a': ['foo']}]
agg1 = stats.TfExampleStatsAgg()
agg1.add(examples1)
agg2 = stats.TfExampleStatsAgg()
agg2.add(examples2)
with self.assertRaisesRegex(ValueError, 'Feature a has conflicting types'):
agg1.merge(agg2)

def test_bytes_feature_type_invalid_utf8_error(self):
examples = [{'a': [b'\xff']}]
agg = stats.TfExampleStatsAgg()
with self.assertRaisesRegex(
ValueError,
'Unsupported bytes feature type. Feature could not be decoded as UTF-8'
' string',
):
agg.add(examples)


class MeanAndVarianceTest(parameterized.TestCase):

Expand Down Expand Up @@ -1900,8 +1976,8 @@ class TfdvTest(absltest.TestCase):

def test_to_proto(self):
feature_stats_instance = stats.FeatureStats()
feature_stats_instance.update(1)
feature_stats_instance.update(2)
feature_stats_instance.update(1, stats.FeatureType.INT)
feature_stats_instance.update(2, stats.FeatureType.INT)
data = stats.TfExampleStats(
num_examples=2,
feature_stats={
Expand All @@ -1919,6 +1995,7 @@ def test_to_proto(self):
self.assertEqual(feature.num_stats.common_stats.max_num_values, 2)
self.assertEqual(feature.num_stats.common_stats.avg_num_values, 1.5)
self.assertEqual(feature.num_stats.common_stats.tot_num_values, 3)
self.assertEqual(feature.type, statistics_pb2.FeatureNameStatistics.INT)


if __name__ == '__main__':
Expand Down
Loading