Skip to content

Commit e71f978

Browse files
author
Leander Reascos
committed
Update: Add get_U and get_S, fixed bugs
1 parent a8cc216 commit e71f978

File tree

1 file changed

+184
-52
lines changed

1 file changed

+184
-52
lines changed

sympt/solver.py

Lines changed: 184 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@
5858
from tabulate import tabulate
5959
from sympy import (Rational as sp_Rational, factorial as sp_factorial,
6060
nsimplify as sp_nsimplify, simplify as sp_simplify,
61-
Add as sp_Add, Matrix as sp_Matrix)
61+
Add as sp_Add, Matrix as sp_Matrix, eye as sp_eye)
6262
from numpy import any as np_any
6363
# import deep copy
6464
from copy import copy
@@ -496,6 +496,9 @@ def __checks_and_prepare_solver(self, method, mask, max_order=2):
496496
self.__Hs[h_order] = new_hk
497497
self.__Vs[h_order] = new_vk
498498

499+
if self.__Hs.get(0) is None:
500+
raise ValueError("The provided Hamiltonian contains no diagonal zeroth order term.")
501+
499502
# Apply the commutation relations to the zeroth-order Hamiltonian
500503
H0_expr = apply_commutation_relations(
501504
self.__Hs.get(0) + Expression(), self.commutation_relations).simplify()
@@ -706,6 +709,8 @@ def __prepare_result(self, O_final, return_form='operator', disable=True):
706709
return O_final_projected
707710

708711
elif return_form == 'matrix':
712+
if len(O_final.expr) == 0:
713+
return S.Zero
709714
O_matrix_form = sp_zeros(
710715
O_final.expr[0].fn.shape[0], O_final.expr[0].fn.shape[1])
711716

@@ -743,6 +748,92 @@ def __prepare_result(self, O_final, return_form='operator', disable=True):
743748
raise ValueError(f'Invalid return form {return_form}. Please choose either: ' + ', '.join(
744749
['operator', 'matrix', 'dict', 'dict_operator', 'dict_matrix']))
745750

751+
def _convert_form(self, collection, ret_form, cache_attrs, result_attr, corrections_attr):
752+
"""
753+
Helper method that converts a dictionary of corrections into the desired form
754+
and stores the final result and corrections on the provided attributes.
755+
756+
Parameters
757+
----------
758+
collection : dict
759+
The dictionary (e.g. self.__Hs_final or U) whose values will be converted.
760+
ret_form : str
761+
The desired output form. Valid values are 'operator', 'matrix', 'dict',
762+
'dict_operator', or 'dict_matrix'.
763+
cache_attrs : dict
764+
A dictionary mapping conversion types to a tuple of attribute names for caching.
765+
For example, for H one might have:
766+
{
767+
'operator': ('_EffectiveFrame__H_operator_form', '_EffectiveFrame__H_operator_form_corrections'),
768+
'matrix': ('_EffectiveFrame__H_matrix_form', '_EffectiveFrame__H_matrix_form_corrections'),
769+
'dict': ('_EffectiveFrame__H_dict_form', '_EffectiveFrame__H_dict_form_corrections')
770+
}
771+
result_attr : str
772+
The attribute name on self to store the final result (e.g. "H" or "U").
773+
corrections_attr : str
774+
The attribute name on self to store the corrections (e.g. "H_corrections" or "U_corrections").
775+
776+
Returns
777+
-------
778+
The converted result in the requested form.
779+
"""
780+
if ret_form in ['operator', 'matrix']:
781+
# Choose the summing function depending on the form.
782+
sum_func = np_sum if ret_form == 'operator' else lambda x: sp_Add(*x)
783+
cache_attr, corr_cache_attr = cache_attrs[ret_form]
784+
# Return cached result if available.
785+
if hasattr(self, cache_attr):
786+
setattr(self, corrections_attr, getattr(self, corr_cache_attr))
787+
return getattr(self, cache_attr)
788+
# Otherwise, convert each element.
789+
corrections = {
790+
k: self.__prepare_result(v, ret_form)
791+
for k, v in tqdm(collection.items(),
792+
desc=f"Converting to {ret_form} form",
793+
disable=not self.verbose)
794+
}
795+
result = sum_func(list(corrections.values()))
796+
setattr(self, corr_cache_attr, corrections)
797+
setattr(self, cache_attr, result)
798+
setattr(self, corrections_attr, corrections)
799+
setattr(self, result_attr, result)
800+
return result
801+
802+
elif 'dict' in ret_form:
803+
# Determine the sub-type: e.g., "operator" or "matrix"
804+
extra = ret_form.split('_')[1] if '_' in ret_form else self.__return_form
805+
new_ret_form = 'dict_' + extra
806+
cache_attr, corr_cache_attr = cache_attrs['dict']
807+
# Return cached result if available.
808+
if hasattr(self, cache_attr) and getattr(self, cache_attr).get(extra) is not None:
809+
setattr(self, corrections_attr, getattr(self, corr_cache_attr)[extra])
810+
return getattr(self, cache_attr)[extra]
811+
if not hasattr(self, cache_attr):
812+
setattr(self, cache_attr, {})
813+
setattr(self, corr_cache_attr, {})
814+
815+
corrections = {
816+
k: self.__prepare_result(v, new_ret_form)
817+
for k, v in tqdm(collection.items(),
818+
desc=f"Converting to dictionary of {extra} form",
819+
disable=not self.verbose)
820+
}
821+
# Merge the sub-dictionaries.
822+
result_dict = {}
823+
for subdict in corrections.values():
824+
for key, value in subdict.items():
825+
result_dict[key] = result_dict.get(key, 0) + value
826+
getattr(self, corr_cache_attr)[extra] = corrections
827+
getattr(self, cache_attr)[extra] = result_dict
828+
setattr(self, corrections_attr, corrections)
829+
setattr(self, result_attr, result_dict)
830+
return result_dict
831+
832+
else:
833+
raise ValueError(
834+
"Invalid return form. Please choose either: operator, matrix, dict, dict_operator, or dict_matrix."
835+
)
836+
746837
def get_H(self, return_form=None):
747838
"""
748839
Returns the effective Hamiltonian.
@@ -760,68 +851,109 @@ def get_H(self, return_form=None):
760851
Expression or Matrix
761852
The effective Hamiltonian in the specified form.
762853
"""
763-
764854
return_form = self.__return_form if return_form is None else return_form
765-
self.__Hs_final = {k: v for k, v in self.__Hs_final.items() if v.expr.shape[0] != 0}
855+
# Filter out entries with empty expressions.
856+
self.__Hs_final = {
857+
k: v for k, v in self.__Hs_final.items() if v.expr.shape[0] != 0
858+
}
766859

767860
if not hasattr(self, '_EffectiveFrame__Hs_final'):
768861
raise AttributeError(
769-
'The Hamiltonian has not been solved yet. Please run the solver method first.')
862+
"The Hamiltonian has not been solved yet. Please run the solver method first."
863+
)
864+
865+
cache_attrs = {
866+
'operator': ('_EffectiveFrame__H_operator_form', '_EffectiveFrame__H_operator_form_corrections'),
867+
'matrix': ('_EffectiveFrame__H_matrix_form', '_EffectiveFrame__H_matrix_form_corrections'),
868+
'dict': ('_EffectiveFrame__H_dict_form', '_EffectiveFrame__H_dict_form_corrections'),
869+
}
870+
# The helper stores the final result in self.H and the corrections in self.H_corrections.
871+
return self._convert_form(self.__Hs_final, return_form, cache_attrs, 'H', 'H_corrections')
872+
873+
def get_U(self, return_form=None):
874+
"""
875+
Returns the effective frame transformation U.
770876
771-
if return_form == 'operator':
772-
if hasattr(self, '_EffectiveFrame__H_operator_form'):
773-
self.corrections = self.__H_operator_form_corrections
774-
self.H = self.__H_operator_form
775-
return self.__H_operator_form
877+
Parameters
878+
----------
879+
return_form : str, optional
880+
If 'operator', returns the result in operator form (default is 'operator').
881+
If 'matrix', returns the matrix form.
882+
If 'dict' or 'dict_operator', returns the dictionary form with projected finite subspaces.
883+
If 'dict_matrix', returns the dictionary form with the full matrix.
776884
777-
self.__H_operator_form_corrections = {k: self.__prepare_result(v, return_form) for k, v in tqdm(self.__Hs_final.items(), desc='Converting to operator form', disable= not self.verbose)}
778-
self.__H_operator_form = np_sum(list(self.__H_operator_form_corrections.values()))
779-
self.H = self.__H_operator_form
780-
self.corrections = self.__H_operator_form_corrections
885+
Returns
886+
-------
887+
Expression or Matrix
888+
The effective frame transformation U in the specified form.
889+
"""
890+
if not hasattr(self, '_EffectiveFrame__Up'):
891+
raise AttributeError(
892+
"The Hamiltonian has not been solved yet. Please run the solver method first."
893+
)
894+
895+
# Make a copy of U and replace U[0] with an identity operator.
896+
U = self.__Up.copy()
897+
idMulGroup = MulGroup(
898+
fn=sp_eye(U[1].expr[0].fn.shape[0]),
899+
inf=[1] * len(U[1].expr[0].inf),
900+
delta=[0] * len(U[1].expr[0].delta)
901+
)
902+
U[0] = Expression(np_array([idMulGroup]))
781903

782-
elif return_form == 'matrix':
783-
if hasattr(self, '_EffectiveFrame__H_matrix_form'):
784-
self.corrections = self.__H_matrix_form_corrections
785-
self.H = self.__H_matrix_form
786-
return self.__H_matrix_form
787-
788-
self.__H_matrix_form_corrections = {k: self.__prepare_result(v, return_form) for k, v in tqdm(self.__Hs_final.items(), desc='Converting to matrix form', disable= not self.verbose)}
789-
self.__H_matrix_form = sp_Add(*list(self.__H_matrix_form_corrections.values()))
790-
self.H = self.__H_matrix_form
791-
self.corrections = self.__H_matrix_form_corrections
904+
return_form = self.__return_form if return_form is None else return_form
792905

906+
cache_attrs = {
907+
'operator': ('_EffectiveFrame__U_operator_form', '_EffectiveFrame__U_operator_form_corrections'),
908+
'matrix': ('_EffectiveFrame__U_matrix_form', '_EffectiveFrame__U_matrix_form_corrections'),
909+
'dict': ('_EffectiveFrame__U_dict_form', '_EffectiveFrame__U_dict_form_corrections'),
910+
}
911+
# The helper stores the final result in self.U and the corrections in self.U_corrections.
912+
return self._convert_form(U, return_form, cache_attrs, 'U', 'U_corrections')
793913

794-
elif 'dict' in return_form:
795-
extra = return_form.split(
796-
'_')[1] if '_' in return_form else self.__return_form
797-
if hasattr(self, '_EffectiveFrame__H_dict_form') and self.__H_dict_form.get(extra) is not None:
798-
self.corrections = self.__H_dict_form_corrections[extra]
799-
self.H = self.__H_dict_form[extra]
800-
return self.__H_dict_form[extra]
801-
802-
if not hasattr(self, '_EffectiveFrame__H_dict_form'):
803-
self.__H_dict_form = {}
804-
self.__H_dict_form_corrections = {}
914+
def get_S(self, return_form=None):
915+
"""
916+
Returns the effective frame transformation S.
805917
806-
self.__H_dict_form_corrections[extra] = {k: self.__prepare_result(v, 'dict' + f'_{extra}') for k, v in tqdm(self.__Hs_final.items(), desc=f'Converting to dictionary of {extra} form', disable= not self.verbose)}
918+
Parameters
919+
----------
920+
return_form : str, optional
921+
If 'operator', returns the result in operator form (default is 'operator').
922+
If 'matrix', returns the matrix form.
923+
If 'dict' or 'dict_operator', returns the dictionary form with projected finite subspaces.
924+
If 'dict_matrix', returns the dictionary form with the full matrix.
807925
808-
self.__H_dict_form[extra] = {}
809-
810-
for _, v in self.__H_dict_form_corrections[extra].items():
811-
for k, v1 in v.items():
812-
if self.__H_dict_form.get(k):
813-
self.__H_dict_form[extra][k] += v1
814-
else:
815-
self.__H_dict_form[extra][k] = v1
816-
817-
self.H = self.__H_dict_form[extra]
818-
self.corrections = self.__H_dict_form_corrections[extra]
819-
820-
else:
821-
raise ValueError('Invalid return form. Please choose either: ' + ', '.join(
822-
['operator', 'matrix', 'dict', 'dict_operator', 'dict_matrix']))
823-
824-
return self.H
926+
Returns
927+
-------
928+
Expression or Matrix
929+
The effective frame transformation S in the specified form.
930+
"""
931+
if not hasattr(self, '_EffectiveFrame__Up'):
932+
raise AttributeError(
933+
"The Hamiltonian has not been solved yet. Please run the solver method first."
934+
)
935+
936+
self.__S = {0: Expression()}
937+
for order in range(1, self.__max_order + 1):
938+
self.__S[order] = - self.__Etas[order]
939+
for theta in P(order):
940+
if len(theta) % 2 == 0:
941+
continue
942+
if len(theta) == 1:
943+
continue
944+
self.__S[order] -= sp_Rational(1, sp_factorial(len(theta))) * np_prod([self.__S[k] for k in theta])
945+
946+
self.__S[order] = self.__S[order].simplify()
947+
948+
return_form = self.__return_form if return_form is None else return_form
949+
950+
cache_attrs = {
951+
'operator': ('_EffectiveFrame__S_operator_form', '_EffectiveFrame__S_operator_form_corrections'),
952+
'matrix': ('_EffectiveFrame__S_matrix_form', '_EffectiveFrame__S_matrix_form_corrections'),
953+
'dict': ('_EffectiveFrame__S_dict_form', '_EffectiveFrame__S_dict_form_corrections'),
954+
}
955+
# The helper stores the final result in self.S and the corrections in self.S_corrections.
956+
return self._convert_form(self.__S, return_form, cache_attrs, 'S', 'S_corrections')
825957

826958
def rotate(self, expr, max_order=None, return_form=None):
827959
"""

0 commit comments

Comments
 (0)