 import os
+import numpy as np
 from sentence_transformers import SentenceTransformer, util
 
 MODEL_NAME = 'all-MiniLM-L6-v2'
@@ -21,23 +22,34 @@ def load_or_download_model():
     print(f"Model saved to {model_path}")
     return model
 
-def find_similar_sentences(query, file_path, top_n=5):
-    # Load the pre-trained model
-    model = load_or_download_model()
+def cosine_similarity(query_embedding, sentence_embeddings):
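+    # Cosine similarity from sentence-transformers; [0] selects the scores for the single query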
+    return util.pytorch_cos_sim(query_embedding, sentence_embeddings)[0]
 
-    # Load and encode the sentences from the file
-    sentences = load_file(file_path)
-    sentence_embeddings = model.encode(sentences)
+def euclidean_distance(query_embedding, sentence_embeddings):
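+    # Negate the L2 distance so that a larger score means "more similar"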
+    return -np.linalg.norm(query_embedding - sentence_embeddings, axis=1)
 
-    # Encode the query
-    query_embedding = model.encode([query])
+def manhattan_distance(query_embedding, sentence_embeddings):
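+    # Negate the L1 (city-block) distance, same "larger is more similar" convention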
+    return -np.sum(np.abs(query_embedding - sentence_embeddings), axis=1)
 
-    # Calculate cosine similarities
-    cosine_scores = util.pytorch_cos_sim(query_embedding, sentence_embeddings)[0]
+def dot_product(query_embedding, sentence_embeddings):
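+    # Unnormalized dot product; unlike cosine similarity, it is affected by embedding magnitude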
+    return np.dot(sentence_embeddings, query_embedding.T).flatten()
 
-    # Get top N results
-    top_results = sorted(zip(sentences, cosine_scores), key=lambda x: x[1], reverse=True)[:top_n]
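+# Maps a menu key to a (display name, scoring function) pair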
+similarity_functions = {
+    '1': ('Cosine Similarity', cosine_similarity),
+    '2': ('Euclidean Distance', euclidean_distance),
+    '3': ('Manhattan Distance', manhattan_distance),
+    '4': ('Dot Product', dot_product)
+}
 
+def find_similar_sentences(query, file_path, similarity_func, top_n=5):
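+    # Embed the corpus and the query with the same model, then score with the chosen function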
+    model = load_or_download_model()
+    sentences = load_file(file_path)
+    sentence_embeddings = model.encode(sentences)
+    query_embedding = model.encode([query])
+
+    similarity_scores = similarity_func(query_embedding, sentence_embeddings)
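+    # Every scoring function returns "higher is better", so sort in descending order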
+    top_results = sorted(zip(sentences, similarity_scores), key=lambda x: x[1], reverse=True)[:top_n]
+
     return top_results
 
 def validate_file_path(file_path):
@@ -48,26 +60,34 @@ def validate_file_path(file_path):
     return file_path
 
 def main():
-    print("Welcome to the Sentence Similarity Search Tool!")
+    print("Welcome to the Enhanced Sentence Similarity Search Tool!")
 
-    # Get user input for query
     query = input("Enter your query: ")
 
-    # Get user input for file path and validate it
     while True:
         file_path = input("Enter the path to your text file without extension: ")
         try:
             file_path = validate_file_path(file_path)
             break
         except FileNotFoundError as e:
             print(f"Error: {str(e)} Please try again.")
+
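+    # Show the available similarity measures as a numbered menu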
+    print("\nChoose a similarity measurement method:")
+    for key, (name, _) in similarity_functions.items():
+        print(f"{key}. {name}")
+
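+    # Re-prompt until the user picks a valid menu key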
+    while True:
+        choice = input("Enter the number of your choice: ")
+        if choice in similarity_functions:
+            similarity_name, similarity_func = similarity_functions[choice]
+            break
+        print("Invalid choice. Please try again.")
 
     try:
-        results = find_similar_sentences(query, file_path)
-
-        print(f"\nTop 5 similar sentences for query: '{query}'\n")
+        results = find_similar_sentences(query, file_path, similarity_func)
+        print(f"\nTop 5 similar sentences for query: '{query}' using {similarity_name}\n")
         for sentence, score in results:
-            print(f"Similarity: {score:.4f}")
+            print(f"Similarity Score: {score:.4f}")
             print(f"Sentence: {sentence}\n")
     except Exception as e:
         print(f"An error occurred: {str(e)}")