 import os
+import numpy as np
 from sentence_transformers import SentenceTransformer, util
 
 MODEL_NAME = 'all-MiniLM-L6-v2'
@@ -21,23 +22,34 @@ def load_or_download_model():
     print(f"Model saved to {model_path}")
     return model
 
-def find_similar_sentences(query, file_path, top_n=5):
-    # Load the pre-trained model
-    model = load_or_download_model()
+def cosine_similarity(query_embedding, sentence_embeddings):
+    return util.pytorch_cos_sim(query_embedding, sentence_embeddings)[0]
 
-    # Load and encode the sentences from the file
-    sentences = load_file(file_path)
-    sentence_embeddings = model.encode(sentences)
+def euclidean_distance(query_embedding, sentence_embeddings):
+    return -np.linalg.norm(query_embedding - sentence_embeddings, axis=1)
 
-    # Encode the query
-    query_embedding = model.encode([query])
+def manhattan_distance(query_embedding, sentence_embeddings):
+    return -np.sum(np.abs(query_embedding - sentence_embeddings), axis=1)
 
-    # Calculate cosine similarities
-    cosine_scores = util.pytorch_cos_sim(query_embedding, sentence_embeddings)[0]
+def dot_product(query_embedding, sentence_embeddings):
+    return np.dot(sentence_embeddings, query_embedding.T).flatten()
 
-    # Get top N results
-    top_results = sorted(zip(sentences, cosine_scores), key=lambda x: x[1], reverse=True)[:top_n]
+similarity_functions = {
+    '1': ('Cosine Similarity', cosine_similarity),
+    '2': ('Euclidean Distance', euclidean_distance),
+    '3': ('Manhattan Distance', manhattan_distance),
+    '4': ('Dot Product', dot_product)
+}
 
+def find_similar_sentences(query, file_path, similarity_func, top_n=5):
+    model = load_or_download_model()
+    sentences = load_file(file_path)
+    sentence_embeddings = model.encode(sentences)
+    query_embedding = model.encode([query])
+
+    similarity_scores = similarity_func(query_embedding, sentence_embeddings)
+    top_results = sorted(zip(sentences, similarity_scores), key=lambda x: x[1], reverse=True)[:top_n]
+
     return top_results
 
 def validate_file_path(file_path):
@@ -48,26 +60,34 @@ def validate_file_path(file_path):
     return file_path
 
 def main():
-    print("Welcome to the Sentence Similarity Search Tool!")
+    print("Welcome to the Enhanced Sentence Similarity Search Tool!")
 
-    # Get user input for query
     query = input("Enter your query: ")
 
-    # Get user input for file path and validate it
     while True:
         file_path = input("Enter the path to your text file without extension: ")
         try:
             file_path = validate_file_path(file_path)
             break
         except FileNotFoundError as e:
             print(f"Error: {str(e)} Please try again.")
+
+    print("\nChoose a similarity measurement method:")
+    for key, (name, _) in similarity_functions.items():
+        print(f"{key}. {name}")
+
+    while True:
+        choice = input("Enter the number of your choice: ")
+        if choice in similarity_functions:
+            similarity_name, similarity_func = similarity_functions[choice]
+            break
+        print("Invalid choice. Please try again.")
 
     try:
-        results = find_similar_sentences(query, file_path)
-
-        print(f"\nTop 5 similar sentences for query: '{query}'\n")
+        results = find_similar_sentences(query, file_path, similarity_func)
+        print(f"\nTop 5 similar sentences for query: '{query}' using {similarity_name}\n")
         for sentence, score in results:
-            print(f"Similarity: {score:.4f}")
+            print(f"Similarity Score: {score:.4f}")
             print(f"Sentence: {sentence}\n")
     except Exception as e:
         print(f"An error occurred: {str(e)}")