113 changes: 85 additions & 28 deletions src/vendetect/_cli.py
@@ -5,7 +5,10 @@
import sys
from collections.abc import Iterable
from pathlib import Path
from typing import TextIO
from typing import TYPE_CHECKING, TextIO

if TYPE_CHECKING:
from .metrics import ComparisonMetric

from rich import traceback
from rich.console import Console, ConsoleRenderable
@@ -19,6 +22,7 @@
from .detector import Detection, Status, VenDetector, get_lexer_for_filename
from .diffing import CollapsedDiffLine, Differ, DiffLineStatus, Document, normalized_edit_distance
from .errors import VendetectError
from .metrics import get_metric
from .repo import File, Repository

logger = logging.getLogger(__name__)
@@ -59,10 +63,16 @@ def update_compare_progress(self, file: File | None = None) -> None:
)


def output_csv(detections: Iterable[Detection], min_similarity: float = 0.5, output_file: TextIO | None = None) -> None:
def output_csv(
detections: Iterable[Detection],
min_similarity: float = 0.5,
output_file: TextIO | None = None,
metric: "ComparisonMetric | None" = None,
) -> None:
output = output_file if output_file else sys.stdout
csv_writer = csv.writer(output)
# Write header
# Write header - use "Token Overlap" for token_overlap metric
header_name = "Token Overlap" if metric and metric.name() == "token_overlap" else "Similarity"
csv_writer.writerow(
[
"Test File",
@@ -71,22 +81,36 @@ def output_csv(detections: Iterable[Detection], min_similarity: float = 0.5, out
"Test Slice End",
"Source Slice Start",
"Source Slice End",
"Similarity",
header_name,
]
)

for d in detections:
# Calculate overall similarity (average of both similarities)
avg_similarity = (d.comparison.similarity1 + d.comparison.similarity2) / 2

if avg_similarity < min_similarity:
break
# Calculate overall similarity (use metric if provided, otherwise average)
if metric is not None:
similarity_score = metric.score(d.comparison)
# For token_overlap, interpret min_similarity as minimum token count
if metric.name() == "token_overlap":
if similarity_score < min_similarity:
continue # Skip this detection if below threshold
elif similarity_score < min_similarity:
break # For other metrics, stop iteration
else:
similarity_score = (d.comparison.similarity1 + d.comparison.similarity2) / 2
if similarity_score < min_similarity:
break

# Get slices
test_slices = d.comparison.slices1
source_slices = d.comparison.slices2

for (test_start, test_end), (source_start, source_end) in zip(test_slices, source_slices, strict=False):
# Format the score appropriately
if metric and metric.name() == "token_overlap":
score_str = str(int(similarity_score))
else:
score_str = f"{similarity_score:.4f}"

# Write one row per matched slice
csv_writer.writerow(
[
@@ -96,23 +120,34 @@ def output_csv(detections: Iterable[Detection], min_similarity: float = 0.5, out
test_end,
source_start,
source_end,
f"{avg_similarity:.4f}",
score_str,
]
)


def output_json(
detections: Iterable[Detection], min_similarity: float = 0.5, output_file: TextIO | None = None
detections: Iterable[Detection],
min_similarity: float = 0.5,
output_file: TextIO | None = None,
metric: "ComparisonMetric | None" = None,
) -> None:
results = []
output = output_file if output_file else sys.stdout

for d in detections:
# Calculate overall similarity (average of both similarities)
avg_similarity = (d.comparison.similarity1 + d.comparison.similarity2) / 2

if avg_similarity < min_similarity:
break
# Calculate overall similarity (use metric if provided, otherwise average)
if metric is not None:
similarity_score = metric.score(d.comparison)
# For token_overlap, interpret min_similarity as minimum token count
if metric.name() == "token_overlap":
if similarity_score < min_similarity:
continue # Skip this detection if below threshold
elif similarity_score < min_similarity:
break # For other metrics, stop iteration
else:
similarity_score = (d.comparison.similarity1 + d.comparison.similarity2) / 2
if similarity_score < min_similarity:
break

# Get slices
test_slices = d.comparison.slices1
@@ -130,15 +165,25 @@ def output_json(
}
)

# Create detection data
detection_data = {
"test_file": f"{d.test.relative_path!s}",
"source_file": f"{d.source.relative_path!s}",
"similarity": round(avg_similarity, 4),
"similarity_test": round(d.comparison.similarity1, 4),
"similarity_source": round(d.comparison.similarity2, 4),
"slices": slices_data,
}
# Create detection data - use appropriate field name and value for token_overlap
if metric and metric.name() == "token_overlap":
detection_data = {
"test_file": f"{d.test.relative_path!s}",
"source_file": f"{d.source.relative_path!s}",
"token_overlap": int(similarity_score),
"similarity_test": round(d.comparison.similarity1, 4),
"similarity_source": round(d.comparison.similarity2, 4),
"slices": slices_data,
}
else:
detection_data = {
"test_file": f"{d.test.relative_path!s}",
"source_file": f"{d.source.relative_path!s}",
"similarity": round(similarity_score, 4),
"similarity_test": round(d.comparison.similarity1, 4),
"similarity_source": round(d.comparison.similarity2, 4),
"slices": slices_data,
}

results.append(detection_data)

@@ -321,7 +366,15 @@ def main() -> None: # noqa: C901, PLR0912, PLR0915
"--min-similarity",
type=float,
default=0.5,
help="the minimum similarity threshold to output a match (range: 0.0-1.0, default: 0.5)",
help="minimum threshold to output a match (range: 0.0-1.0 for similarity metrics, "
"or minimum token count for token_overlap metric, default: 0.5)",
)
parser.add_argument(
"--metric",
type=str,
default="sum",
choices=["sum", "average", "min", "max", "token_overlap", "weighted"],
help="comparison metric to use for ranking detections (default: sum)",
)

# Performance optimization options
@@ -402,12 +455,16 @@ def main() -> None: # noqa: C901, PLR0912, PLR0915
Repository.load(args.SOURCE_REPO, args.source_subdir) as source_repo,
RichStatus(console) as status,
):
# Get the comparison metric
metric = get_metric(args.metric)

# Initialize detector with optimization options
vend = VenDetector(
status=status,
incremental=args.incremental,
batch_size=args.batch_size,
max_history_depth=args.max_history_depth,
metric=metric,
)

# Get detections
@@ -430,9 +487,9 @@ def file_filter(file: File) -> bool:

# Output based on format
if args.format == "csv":
output_csv(detections, args.min_similarity, output_file)
output_csv(detections, args.min_similarity, output_file, metric)
elif args.format == "json":
output_json(detections, args.min_similarity, output_file)
output_json(detections, args.min_similarity, output_file, metric)
else: # rich format
output_rich(detections, console, args.min_similarity, output_file)
except VendetectError as e:
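A minimal usage sketch of the new option, assuming the console script is installed as vendetect and that the tool takes the test and source repositories as positional arguments (only SOURCE_REPO is visible by name in this diff):

    # Rank detections by the average of both similarities; keep matches scoring at least 0.6
    vendetect TEST_REPO SOURCE_REPO --metric average --min-similarity 0.6 --format csv

    # With token_overlap, --min-similarity is reinterpreted as a minimum token count
    vendetect TEST_REPO SOURCE_REPO --metric token_overlap --min-similarity 25 --format json

With --metric token_overlap the CSV header becomes "Token Overlap" and the JSON field becomes "token_overlap", both holding integer token counts instead of a 0.0-1.0 similarity; detections below the threshold are skipped with continue rather than ending the iteration.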
10 changes: 2 additions & 8 deletions src/vendetect/comparison.py
@@ -63,14 +63,8 @@ class Comparison:
slices2: tuple[Slice, ...]

def __lt__(self, other: "Comparison") -> bool:
# TODO([email protected]): Make the comparison metric user-specifiable (#14) # noqa: FIX002
# For now, disable comparison of token overlap, because that was causing too many false-positives
# (I believe due to whitespace overlap)
#
# if self.token_overlap > other.token_overlap:
# return True # noqa: ERA001
# if self.token_overlap < other.token_overlap:
# return False # noqa: ERA001
# Default behavior: sum of similarities
# Note: This is overridden when using custom metrics via Detection class
oursim = self.similarity1 + self.similarity2
theirsim = other.similarity1 + other.similarity2
return oursim > theirsim
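For reference, the ordering this comment describes: a Comparison with similarity1 = 0.9 and similarity2 = 0.8 has oursim = 1.7, while one with 0.5 and 0.6 has theirsim = 1.1; since 1.7 > 1.1 the first compares as "less than" the second and is therefore popped first from the detection heap in detector.py.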
14 changes: 11 additions & 3 deletions src/vendetect/detector.py
@@ -5,7 +5,7 @@
from functools import wraps
from heapq import heappop, heappush
from logging import getLogger
from typing import TypeVar
from typing import TYPE_CHECKING, TypeVar

from pygments import lexer, lexers
from pygments.util import ClassNotFound
@@ -14,6 +14,9 @@
from .copydetect import CopyDetectComparator
from .repo import File, Repository, Rounding

if TYPE_CHECKING:
from .metrics import ComparisonMetric

log = getLogger(__name__)
F = TypeVar("F")

@@ -93,8 +96,11 @@ class Detection:
test: File
source: File
comparison: Comparison
metric: "ComparisonMetric | None" = None

def __lt__(self, other: "Detection") -> bool:
if self.metric is not None:
return self.metric.score(self.comparison) > self.metric.score(other.comparison)
return self.comparison < other.comparison

@property
@@ -111,14 +117,15 @@ def source_repo(self) -> Repository:


class VenDetector:
def __init__(
def __init__( # noqa: PLR0913
self,
comparator: Comparator[F] | None = None,
status: Status | None = None,
batch_size: int = 100,
max_history_depth: int | None = None,
*,
incremental: bool = False,
metric: "ComparisonMetric | None" = None,
):
if comparator is None:
comparator = CopyDetectComparator()
@@ -133,6 +140,7 @@ def __init__(
max_history_depth if max_history_depth is not None and max_history_depth >= 0 else None
) # Limit history traversal depth
self._fingerprint_cache: dict[File, F] = {} # Cache fingerprints
self.metric = metric # Custom comparison metric

@staticmethod
def callback(func: Callable) -> Callable:
@@ -224,7 +232,7 @@ def compare( # noqa: C901, PLR0912
continue

cmp = self.comparator.compare(fp1, fp2) # type: ignore
d = Detection(test_file, source_file, cmp)
d = Detection(test_file, source_file, cmp, metric=self.metric)
heappush(detections, d)

# Process accumulated detections for this batch
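A minimal wiring sketch for the new metric parameter, mirroring what _cli.py does above; the import paths vendetect.detector and vendetect.metrics are assumptions based on the src/vendetect layout:

    from vendetect.detector import VenDetector
    from vendetect.metrics import get_metric

    # Rank detections with the conservative min-of-both-similarities score.
    detector = VenDetector(
        batch_size=100,
        max_history_depth=None,
        incremental=True,
        metric=get_metric("min"),
    )

When metric is None, Detection.__lt__ falls back to Comparison.__lt__, so existing callers keep the previous sum-of-similarities ordering.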
108 changes: 108 additions & 0 deletions src/vendetect/metrics.py
@@ -0,0 +1,108 @@
"""Comparison metrics for ranking detection results."""

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from .comparison import Comparison


class ComparisonMetric(ABC):
"""Abstract base class for comparison metrics."""

@abstractmethod
def score(self, comparison: "Comparison") -> float:
"""Calculate a score for the comparison. Higher scores indicate better matches."""
raise NotImplementedError

@abstractmethod
def name(self) -> str:
"""Return the name of this metric."""
raise NotImplementedError


class SumSimilarityMetric(ComparisonMetric):
"""Default metric: sum of both similarity scores."""

def score(self, comparison: "Comparison") -> float:
return comparison.similarity1 + comparison.similarity2

def name(self) -> str:
return "sum"


class AverageSimilarityMetric(ComparisonMetric):
"""Average of both similarity scores."""

def score(self, comparison: "Comparison") -> float:
return (comparison.similarity1 + comparison.similarity2) / 2

def name(self) -> str:
return "average"


class MinSimilarityMetric(ComparisonMetric):
"""Minimum of both similarity scores (most conservative)."""

def score(self, comparison: "Comparison") -> float:
return min(comparison.similarity1, comparison.similarity2)

def name(self) -> str:
return "min"


class MaxSimilarityMetric(ComparisonMetric):
"""Maximum of both similarity scores (most aggressive)."""

def score(self, comparison: "Comparison") -> float:
return max(comparison.similarity1, comparison.similarity2)

def name(self) -> str:
return "max"


class TokenOverlapMetric(ComparisonMetric):
"""Raw token overlap count."""

def score(self, comparison: "Comparison") -> float:
return float(comparison.token_overlap)

def name(self) -> str:
return "token_overlap"


class WeightedSimilarityMetric(ComparisonMetric):
"""Weighted combination of similarities and token overlap."""

def __init__(self, sim_weight: float = 0.8, token_weight: float = 0.2):
self.sim_weight = sim_weight
self.token_weight = token_weight

def score(self, comparison: "Comparison") -> float:
sim_score = (comparison.similarity1 + comparison.similarity2) / 2
# Normalize token overlap to 0-1 range (assuming max 1000 tokens for normalization)
normalized_tokens = min(comparison.token_overlap / 1000.0, 1.0)
return self.sim_weight * sim_score + self.token_weight * normalized_tokens

def name(self) -> str:
return "weighted"


# Registry of available metrics
METRICS = {
"sum": SumSimilarityMetric(),
"average": AverageSimilarityMetric(),
"min": MinSimilarityMetric(),
"max": MaxSimilarityMetric(),
"token_overlap": TokenOverlapMetric(),
"weighted": WeightedSimilarityMetric(),
}


def get_metric(name: str) -> ComparisonMetric:
"""Get a metric by name."""
if name not in METRICS:
available = ", ".join(METRICS.keys())
msg = f"Unknown metric: {name}. Available metrics: {available}"
raise ValueError(msg)
return METRICS[name]
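A short usage sketch of the registry, assuming the module is importable as vendetect.metrics:

    from vendetect.metrics import METRICS, get_metric

    metric = get_metric("weighted")
    print(metric.name())    # "weighted"
    print(sorted(METRICS))  # all registered metric names
    # get_metric("typo") would raise ValueError listing the available metrics

    # Worked example of the default weighted score (sim_weight=0.8, token_weight=0.2)
    # for similarity1=0.9, similarity2=0.7, token_overlap=250:
    #   sim_score         = (0.9 + 0.7) / 2      = 0.8
    #   normalized_tokens = min(250 / 1000, 1.0) = 0.25
    #   score             = 0.8*0.8 + 0.2*0.25   = 0.69

Because the registry holds shared singleton instances, a caller that wants non-default weights would construct WeightedSimilarityMetric(sim_weight=..., token_weight=...) directly rather than mutating the registry entry.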