|
| 1 | +"""SequenceLengthSimilarity module.""" |
| 2 | + |
| 3 | +import pandas as pd |
| 4 | + |
| 5 | +from sdmetrics.goal import Goal |
| 6 | +from sdmetrics.single_column.statistical.kscomplement import KSComplement |
| 7 | + |
| 8 | + |
class SequenceLengthSimilarity:
    """Metric that scores how closely synthetic sequence lengths match real ones.

    Attributes:
        name (str):
            Human-readable name shown when this metric appears in a report.
        goal (sdmetrics.goal.Goal):
            Optimization direction for this metric; higher scores are better.
        min_value (Union[float, tuple[float]]):
            Smallest score this metric can produce.
        max_value (Union[float, tuple[float]]):
            Largest score this metric can produce.
    """

    name = 'Sequence Length Similarity'
    goal = Goal.MAXIMIZE
    min_value = 0.0
    max_value = 1.0

    @staticmethod
    def compute(real_data: pd.Series, synthetic_data: pd.Series) -> float:
        """Score the similarity between real and synthetic sequence lengths.

        A sequence's length is the number of rows that share its sequence key:
        a key such as id_09231 that shows up 150 times describes a sequence of
        length 150.

        The score is obtained in three steps:
            - Count the occurrences of every key in the real data.
            - Count the occurrences of every key in the synthetic data.
            - Hand both sets of counts to the KSComplement metric and return
              the score it produces.

        Args:
            real_data (pd.Series):
                The sequence key values from the real dataset.
            synthetic_data (pd.Series):
                The sequence key values from the synthetic dataset.

        Returns:
            float:
                The KSComplement score of the two sequence-length distributions.
        """
        # ``value_counts`` yields one entry per distinct key whose value is the
        # number of rows carrying that key, i.e. the length of that sequence.
        real_lengths = real_data.value_counts()
        synthetic_lengths = synthetic_data.value_counts()

        return KSComplement.compute(real_lengths, synthetic_lengths)
0 commit comments