1+ import os
2+ import numpy as np
3+ from sentence_transformers import SentenceTransformer , util
4+
# Identifier handed to SentenceTransformer when no local copy is available.
MODEL_NAME = "all-MiniLM-L6-v2"
# Directory under which the model is stored (joined with MODEL_NAME).
MODEL_FOLDER = "model"
7+
def load_file(file_path):
    """Read a text file and return its non-blank lines, stripped.

    The file is decoded as UTF-8. A leading byte-order mark, if present, is
    discarded so that it does not end up attached to the first returned line
    (``str.strip`` does not remove it).

    Args:
        file_path: Path of the text file to read.

    Returns:
        A list with one entry per line that is non-empty after stripping
        surrounding whitespace, in file order.

    Raises:
        OSError: If the file cannot be opened.
        UnicodeDecodeError: If the file is not valid UTF-8.
    """
    # 'utf-8-sig' reads BOM-less UTF-8 unchanged and additionally drops a BOM.
    with open(file_path, 'r', encoding='utf-8-sig') as file:
        # Strip each line once, then filter on the stripped value.
        stripped_lines = (line.strip() for line in file)
        return [line for line in stripped_lines if line]
11+
def load_or_download_model():
    """Return the sentence-embedding model, preferring a copy on disk.

    The expected location is ``MODEL_FOLDER/MODEL_NAME``. If that path exists
    the model is constructed from it. Otherwise the model is constructed from
    ``MODEL_NAME`` (reported to the user as a download), saved to that
    location, and returned.

    Returns:
        A ``SentenceTransformer`` instance.
    """
    local_copy = os.path.join(MODEL_FOLDER, MODEL_NAME)

    # Guard clause: an existing local copy short-circuits everything below.
    if os.path.exists(local_copy):
        print(f"Loading model from {local_copy} ")
        return SentenceTransformer(local_copy)

    print(f"Downloading model {MODEL_NAME} ")
    fetched = SentenceTransformer(MODEL_NAME)
    os.makedirs(MODEL_FOLDER, exist_ok=True)
    fetched.save(local_copy)
    print(f"Model saved to {local_copy} ")
    return fetched
24+
def cosine_similarity(query_embedding, sentence_embeddings):
    """Score the sentence embeddings against the query by cosine similarity.

    Delegates to ``util.pytorch_cos_sim`` and returns the first row of its
    result (presumably the scores for the single query -- confirm against
    the caller).

    Args:
        query_embedding: Embedding(s) of the query.
        sentence_embeddings: Embeddings of the candidate sentences.

    Returns:
        Element ``[0]`` of the pairwise similarity result.
    """
    pairwise_scores = util.pytorch_cos_sim(query_embedding, sentence_embeddings)
    first_row = pairwise_scores[0]
    return first_row
27+
def euclidean_distance(query_embedding, sentence_embeddings):
    """Return the negated Euclidean distance from the query to each sentence.

    The sign is flipped so that a larger value means "closer": 0 is an exact
    match and more negative values are farther away. This lets the result be
    ranked in descending order like a similarity score.

    Args:
        query_embedding: Array-like query vector of shape ``(dim,)`` or
            ``(1, dim)``.
        sentence_embeddings: Array-like of shape ``(n, dim)``.

    Returns:
        A 1-D NumPy array of ``n`` non-positive scores.
    """
    # asarray makes plain nested lists work too; ndarrays pass through as-is.
    query = np.asarray(query_embedding)
    sentences = np.asarray(sentence_embeddings)
    # Broadcasting subtracts the query from every sentence row.
    return -np.linalg.norm(query - sentences, axis=1)
30+
def manhattan_distance(query_embedding, sentence_embeddings):
    """Return the negated Manhattan (L1) distance from the query to each sentence.

    The sign is flipped so that a larger value means "closer": 0 is an exact
    match and more negative values are farther away. This lets the result be
    ranked in descending order like a similarity score.

    Args:
        query_embedding: Array-like query vector of shape ``(dim,)`` or
            ``(1, dim)``.
        sentence_embeddings: Array-like of shape ``(n, dim)``.

    Returns:
        A 1-D NumPy array of ``n`` non-positive scores.
    """
    # asarray makes plain nested lists work too; ndarrays pass through as-is.
    query = np.asarray(query_embedding)
    sentences = np.asarray(sentence_embeddings)
    # Sum of absolute per-dimension differences, one total per sentence row.
    return -np.sum(np.abs(query - sentences), axis=1)
33+
def dot_product(query_embedding, sentence_embeddings):
    """Return the dot product of the query with each sentence embedding.

    Args:
        query_embedding: Array-like query vector of shape ``(dim,)`` or
            ``(1, dim)``.
        sentence_embeddings: Array-like of shape ``(n, dim)``.

    Returns:
        A flattened 1-D NumPy array of dot products (``n`` values for a
        single query).
    """
    # asarray makes plain nested lists work too (lists have no ``.T``).
    query = np.asarray(query_embedding)
    sentences = np.asarray(sentence_embeddings)
    # (n, dim) @ (dim, 1) -> (n, 1); flatten to a plain score vector.
    return np.dot(sentences, query.T).flatten()
36+
# Menu of scoring methods: maps the key the user types ('1', '2', ...) to a
# (display name, scoring function) pair. Keys are assigned in listing order.
similarity_functions = {
    str(position): entry
    for position, entry in enumerate(
        (
            ('Cosine Similarity', cosine_similarity),
            ('Euclidean Distance', euclidean_distance),
            ('Manhattan Distance', manhattan_distance),
            ('Dot Product', dot_product),
        ),
        start=1,
    )
}
43+
def find_similar_sentences(query, file_path, similarity_func, top_n=5):
    """Rank the sentences in a file by similarity to a query.

    Args:
        query: Query text to compare against.
        file_path: Path of a text file with one candidate sentence per line.
        similarity_func: Callable taking ``(query_embedding,
            sentence_embeddings)`` and returning one score per sentence,
            where a higher score means a better match.
        top_n: Maximum number of results to return. Values of zero or less
            yield an empty list.

    Returns:
        A list of ``(sentence, score)`` tuples, best match first, with at
        most ``top_n`` entries. Empty if the file has no non-blank lines.

    Raises:
        OSError: If the file cannot be read.
    """
    # Read the file first so a bad path fails before the model is loaded.
    sentences = load_file(file_path)

    # Nothing to rank; also avoids passing an empty batch to the scorers,
    # and stops a negative top_n from slicing off the end of the ranking.
    if not sentences or top_n <= 0:
        return []

    model = load_or_download_model()
    sentence_embeddings = model.encode(sentences)
    query_embedding = model.encode([query])

    similarity_scores = similarity_func(query_embedding, sentence_embeddings)
    ranked = sorted(
        zip(sentences, similarity_scores),
        key=lambda pair: pair[1],
        reverse=True,
    )
    return ranked[:top_n]
54+
def validate_file_path(file_path):
    """Normalise a user-supplied path to an existing ``.txt`` file.

    ``.txt`` is appended unless the path already ends with that extension
    (compared case-insensitively, so ``NOTES.TXT`` is left alone).

    Args:
        file_path: Path as typed by the user, with or without extension.

    Returns:
        The path, including the ``.txt`` extension.

    Raises:
        FileNotFoundError: If the resulting path does not exist or is not a
            regular file (for example, a directory).
    """
    if not file_path.lower().endswith('.txt'):
        file_path += '.txt'

    if not os.path.exists(file_path):
        raise FileNotFoundError(f"The file '{file_path}' does not exist.")
    # A directory would pass the existence check but cannot be opened as text.
    if not os.path.isfile(file_path):
        raise FileNotFoundError(f"The path '{file_path}' is not a regular file.")

    return file_path
61+
def main():
    """Run the interactive sentence-similarity search.

    Prompts for a query, a text file, and a similarity method, then prints
    the best-matching sentences with their scores. The query and file
    prompts repeat until a non-blank query and a valid file are given; the
    method prompt repeats until a listed key is entered.
    """
    print("Welcome to the Enhanced Sentence Similarity Search Tool!")

    # A blank query carries no information to match against; ask again.
    while True:
        query = input("Enter your query: ").strip()
        if query:
            break
        print("The query cannot be empty. Please try again.")

    while True:
        file_path = input("Enter the path to your text file without extension: ")
        try:
            file_path = validate_file_path(file_path)
            break
        except FileNotFoundError as e:
            print(f"Error: {e} Please try again.")

    print("\nChoose a similarity measurement method:")
    for key, (name, _) in similarity_functions.items():
        print(f"{key}. {name}")

    while True:
        choice = input("Enter the number of your choice: ")
        if choice in similarity_functions:
            similarity_name, similarity_func = similarity_functions[choice]
            break
        print("Invalid choice. Please try again.")

    # Top-level boundary: report any failure instead of dumping a traceback.
    try:
        results = find_similar_sentences(query, file_path, similarity_func)
        if not results:
            print(f"\nNo sentences found in '{file_path}'.")
            return
        # Report the actual count; the file may hold fewer sentences than
        # the requested maximum.
        print(
            f"\nTop {len(results)} similar sentences for query: "
            f"'{query}' using {similarity_name}\n"
        )
        for sentence, score in results:
            print(f"Similarity Score: {score:.4f}")
            print(f"Sentence: {sentence}\n")
    except Exception as e:
        print(f"An error occurred: {e}")


if __name__ == "__main__":
    main()
0 commit comments