1- from termcolor import colored
2- from copy import deepcopy
31import collections
2+ from copy import deepcopy
3+ from functools import partial
44
5+ from termcolor import colored
56
6- class FlatDictDiffer (object ):
7- def __init__ (self , ref , target ):
8- self .ref , self .target = ref , target
9- self .ref_set , self .target_set = set (ref .keys ()), set (target .keys ())
10- self .isect = self .ref_set .intersection (self .target_set )
7+
8+ class DiffResolver (object ):
9+ """Determines diffs between two dicts, where the remote copy is considered the baseline"""
10+ def __init__ (self , remote , local , force = False ):
11+ self .remote_flat , self .local_flat = self ._flatten (remote ), self ._flatten (local )
12+ self .remote_set , self .local_set = set (self .remote_flat .keys ()), set (self .local_flat .keys ())
13+ self .intersection = self .remote_set .intersection (self .local_set )
14+ self .force = force
1115
1216 if self .added () or self .removed () or self .changed ():
1317 self .differ = True
1418 else :
1519 self .differ = False
1620
21+ @classmethod
22+ def configure (cls , * args , ** kwargs ):
23+ return partial (cls , * args , ** kwargs )
24+
1725 def added (self ):
18- return self .target_set - self .isect
26+ """Returns a (flattened) dict of added leaves i.e. {"full/path": value, ...}"""
27+ return self .local_set - self .intersection
1928
2029 def removed (self ):
21- return self .ref_set - self .isect
30+ """Returns a (flattened) dict of removed leaves i.e. {"full/path": value, ...}"""
31+ return self .remote_set - self .intersection
2232
2333 def changed (self ):
24- return set (k for k in self .isect if self .ref [k ] != self .target [k ])
34+ """Returns a (flattened) dict of changed leaves i.e. {"full/path": value, ...}"""
35+ return set (k for k in self .intersection if self .remote_flat [k ] != self .local_flat [k ])
2536
2637 def unchanged (self ):
27- return set (k for k in self .isect if self .ref [k ] == self .target [k ])
38+ """Returns a (flattened) dict of unchanged leaves i.e. {"full/path": value, ...}"""
39+ return set (k for k in self .intersection if self .remote_flat [k ] == self .local_flat [k ])
2840
29- def print_state (self ):
41+ def describe_diff (self ):
42+ """Return a (multi-line) string describing all differences"""
43+ description = ""
3044 for k in self .added ():
31- print ( colored ("+" , 'green' ), "{} = {}" .format (k , self .target [k ]))
45+ description += colored ("+" , 'green' ), "{} = {}" .format (k , self .local_flat [k ]) + ' \n '
3246
3347 for k in self .removed ():
34- print ( colored ("-" , 'red' ), k )
48+ description += colored ("-" , 'red' ), k + ' \n '
3549
3650 for k in self .changed ():
37- print (colored ("~" , 'yellow' ), "{}:\n \t < {}\n \t > {}" .format (k , self .ref [k ], self .target [k ]))
38-
39-
40- def flatten (d , pkey = '' , sep = '/' ):
41- items = []
42- for k in d :
43- new = pkey + sep + k if pkey else k
44- if isinstance (d [k ], collections .MutableMapping ):
45- items .extend (flatten (d [k ], new , sep = sep ).items ())
51+ description += colored ("~" , 'yellow' ), "{}:\n \t < {}\n \t > {}" .format (k , self .remote_flat [k ], self .local_flat [k ]) + '\n '
52+
53+ return description
54+
55+ def _flatten (self , d , current_path = '' , sep = '/' ):
56+ """Convert a nested dict structure into a "flattened" dict i.e. {"full/path": "value", ...}"""
57+ items = []
58+ for k in d :
59+ new = current_path + sep + k if current_path else k
60+ if isinstance (d [k ], collections .MutableMapping ):
61+ items .extend (self ._flatten (d [k ], new , sep = sep ).items ())
62+ else :
63+ items .append ((sep + new , d [k ]))
64+ return dict (items )
65+
66+ def _unflatten (self , d , sep = '/' ):
67+ """Converts a "flattened" dict i.e. {"full/path": "value", ...} into a nested dict structure"""
68+ output = {}
69+ for k in d :
70+ add (
71+ obj = output ,
72+ path = k ,
73+ value = d [k ],
74+ sep = sep ,
75+ )
76+ return output
77+
78+ def merge (self ):
79+ """Generate a merge of the local and remote dicts, following configurations set during __init__"""
80+ dictfilter = lambda original , keep_keys : dict ([(i , original [i ]) for i in original if i in set (keep_keys )])
81+ if self .force :
82+ # Overwrite local changes (i.e. only preserve added keys)
83+ # NOTE: Currently the system cannot tell the difference between a remote delete and a local add
84+ prior_set = self .changed ().union (self .removed ()).union (self .unchanged ())
85+ current_set = self .added ()
4686 else :
47- items .append ((sep + new , d [k ]))
48- return dict (items )
49-
50-
51- def add (obj , path , value ):
52- parts = path .strip ("/" ).split ("/" )
87+ # Preserve added keys and changed keys
88+ # NOTE: Currently the system cannot tell the difference between a remote delete and a local add
89+ prior_set = self .unchanged ().union (self .removed ())
90+ current_set = self .added ().union (self .changed ())
91+ state = dictfilter (original = self .remote_flat , keep_keys = prior_set )
92+ state .update (dictfilter (original = self .local_flat , keep_keys = current_set ))
93+ return self ._unflatten (state )
94+
95+
96+ def add (obj , path , value , sep = '/' ):
97+ """Add value to the `obj` dict at the specified path"""
98+ parts = path .strip (sep ).split (sep )
5399 last = len (parts ) - 1
54100 for index , part in enumerate (parts ):
55101 if index == last :
@@ -61,7 +107,7 @@ def add(obj, path, value):
61107def search (state , path ):
62108 result = state
63109 for p in path .strip ("/" ).split ("/" ):
64- if result .get (p ):
110+ if result .clone (p ):
65111 result = result [p ]
66112 else :
67113 result = {}
@@ -71,16 +117,6 @@ def search(state, path):
71117 return output
72118
73119
74- def unflatten (d ):
75- output = {}
76- for k in d :
77- add (
78- obj = output ,
79- path = k ,
80- value = d [k ])
81- return output
82-
83-
84120def merge (a , b ):
85121 if not isinstance (b , dict ):
86122 return b
0 commit comments