88using System . Threading . Tasks ;
99using Microsoft . EntityFrameworkCore ;
1010using Microsoft . EntityFrameworkCore . Infrastructure ;
11+ using Microsoft . EntityFrameworkCore . Metadata ;
1112using Microsoft . EntityFrameworkCore . Storage ;
1213
1314namespace EntityFrameworkExtras . EFCore
@@ -65,15 +66,22 @@ public static async Task ExecuteStoredProcedureAsync(this DatabaseFacade databas
6566 /// <param name="storedProcedure">The stored procedure to execute.</param>
6667 /// <returns></returns>
6768 public static IEnumerable < T > ExecuteStoredProcedure < T > ( this DatabaseFacade database , object storedProcedure ) where T : class
68- {
69- if ( storedProcedure == null )
70- throw new ArgumentNullException ( "storedProcedure" ) ;
69+ {
70+ var contextField = database . GetType ( ) . GetField ( "_context" , BindingFlags . Instance | BindingFlags . NonPublic ) ;
7171
72- var info = StoredProcedureParser . BuildStoredProcedureInfo ( storedProcedure ) ;
72+ var context = ( DbContext ) contextField . GetValue ( database ) ;
7373
74- var contextField = database . GetType ( ) . GetField ( "_context" , BindingFlags . Instance | BindingFlags . NonPublic ) ;
74+ var entityType = FindModelEntityType ( context , typeof ( T ) ) ;
75+
76+ if ( entityType == null )
77+ {
78+ return database . InternalExecuteStoredProcedure < T > ( storedProcedure ) ;
79+ }
80+
81+ if ( storedProcedure == null )
82+ throw new ArgumentNullException ( "storedProcedure" ) ;
7583
76- var context = ( DbContext ) contextField . GetValue ( database ) ;
84+ var info = StoredProcedureParser . BuildStoredProcedureInfo ( storedProcedure ) ;
7785
7886 List < T > result = context . Set < T > ( ) . FromSqlRaw ( info . Sql , info . SqlParameters ) . AsNoTracking ( ) . ToList ( ) ;
7987
@@ -82,7 +90,58 @@ public static IEnumerable<T> ExecuteStoredProcedure<T>(this DatabaseFacade datab
8290 return result ;
8391 }
8492
85- /// <summary>
93+ internal static IEnumerable < T > InternalExecuteStoredProcedure < T > ( this DatabaseFacade database , object storedProcedure )
94+ {
95+ if ( storedProcedure == null )
96+ throw new ArgumentNullException ( "storedProcedure" ) ;
97+
98+
99+ List < T > result = new List < T > ( ) ;
100+ var info = StoredProcedureParser . BuildStoredProcedureInfo ( storedProcedure ) ;
101+
102+
103+ // from : https://github.com/Fodsuk/EntityFrameworkExtras/pull/23/commits/dce354304aa9a95750f7d2559d1b002444ac46f7
104+ using ( var command = database . GetDbConnection ( ) . CreateCommand ( ) )
105+ {
106+ command . CommandText = info . Sql ;
107+ int ? commandTimeout = database . GetCommandTimeout ( ) ;
108+ if ( commandTimeout . HasValue )
109+ {
110+ command . CommandTimeout = commandTimeout . Value ;
111+ }
112+ command . CommandType = CommandType . Text ;
113+ command . Parameters . AddRange ( info . SqlParameters ) ;
114+ command . Transaction = database . CurrentTransaction ? . GetDbTransaction ( ) ;
115+ database . OpenConnection ( ) ;
116+
117+ using ( var resultReader = command . ExecuteReader ( ) )
118+ {
119+ T obj = default ( T ) ;
120+
121+ while ( resultReader . Read ( ) )
122+ {
123+ obj = Activator . CreateInstance < T > ( ) ;
124+ foreach ( PropertyInfo prop in obj . GetType ( ) . GetProperties ( ) )
125+ {
126+ var val = GetValue ( resultReader , prop . Name ) ;
127+ if ( ! object . Equals ( val , DBNull . Value ) )
128+ {
129+ prop . SetValue ( obj , val , null ) ;
130+ }
131+ }
132+
133+ result . Add ( obj ) ;
134+ }
135+ }
136+
137+ }
138+
139+ SetOutputParameterValues ( info . SqlParameters , storedProcedure ) ;
140+
141+ return result ;
142+ }
143+
144+ /// <summary>
86145 /// Executes the specified stored procedure against a database asynchronously
87146 /// and returns an enumerable of T representing the data returned.
88147 /// </summary>
@@ -92,32 +151,90 @@ public static IEnumerable<T> ExecuteStoredProcedure<T>(this DatabaseFacade datab
92151 /// <param name="cancellationToken">The cancellation token.</param>
93152 /// <returns></returns>
94153 public static async Task < IEnumerable < T > > ExecuteStoredProcedureAsync < T > ( this DatabaseFacade database , object storedProcedure , CancellationToken cancellationToken = default ) where T : class
154+ {
155+ var contextField = database . GetType ( ) . GetField ( "_context" , BindingFlags . Instance | BindingFlags . NonPublic ) ;
156+
157+ var context = ( DbContext ) contextField . GetValue ( database ) ;
158+
159+ var entityType = FindModelEntityType ( context , typeof ( T ) ) ;
160+
161+ if ( entityType == null )
162+ {
163+ return await database . InternalExecuteStoredProcedureAsync < T > ( storedProcedure ) . ConfigureAwait ( false ) ;
164+ }
165+
166+ if ( storedProcedure == null )
167+ throw new ArgumentNullException ( "storedProcedure" ) ;
168+
169+ var info = StoredProcedureParser . BuildStoredProcedureInfo ( storedProcedure ) ;
170+
171+ List < T > result = await context . Set < T > ( ) . FromSqlRaw ( info . Sql , info . SqlParameters ) . AsNoTracking ( ) . ToListAsync ( cancellationToken ) . ConfigureAwait ( false ) ;
172+
173+ SetOutputParameterValues ( info . SqlParameters , storedProcedure ) ;
174+
175+ return result ;
176+ }
177+
178+ internal static async Task < IEnumerable < T > > InternalExecuteStoredProcedureAsync < T > ( this DatabaseFacade database , object storedProcedure , CancellationToken cancellationToken = default )
95179 {
96180 if ( storedProcedure == null )
97181 throw new ArgumentNullException ( "storedProcedure" ) ;
98182
183+
184+ List < T > result = new List < T > ( ) ;
99185 var info = StoredProcedureParser . BuildStoredProcedureInfo ( storedProcedure ) ;
100186
101- var contextField = database . GetType ( ) . GetField ( "_context" , BindingFlags . Instance | BindingFlags . NonPublic ) ;
102187
103- var context = ( DbContext ) contextField . GetValue ( database ) ;
188+ // from : https://github.com/Fodsuk/EntityFrameworkExtras/pull/23/commits/dce354304aa9a95750f7d2559d1b002444ac46f7
189+ using ( var command = database . GetDbConnection ( ) . CreateCommand ( ) )
190+ {
191+ command . CommandText = info . Sql ;
192+ int ? commandTimeout = database . GetCommandTimeout ( ) ;
193+ if ( commandTimeout . HasValue )
194+ {
195+ command . CommandTimeout = commandTimeout . Value ;
196+ }
197+ command . CommandType = CommandType . Text ;
198+ command . Parameters . AddRange ( info . SqlParameters ) ;
199+ command . Transaction = database . CurrentTransaction ? . GetDbTransaction ( ) ;
200+ database . OpenConnection ( ) ;
201+
202+ using ( var resultReader = await command . ExecuteReaderAsync ( cancellationToken ) . ConfigureAwait ( false ) )
203+ {
204+ T obj = default ( T ) ;
205+
206+ while ( await resultReader . ReadAsync ( cancellationToken ) . ConfigureAwait ( false ) )
207+ {
208+ obj = Activator . CreateInstance < T > ( ) ;
209+ foreach ( PropertyInfo prop in obj . GetType ( ) . GetProperties ( ) )
210+ {
211+ var val = GetValue ( resultReader , prop . Name ) ;
212+ if ( ! object . Equals ( val , DBNull . Value ) )
213+ {
214+ prop . SetValue ( obj , val , null ) ;
215+ }
216+ }
217+
218+ result . Add ( obj ) ;
219+ }
220+ }
104221
105- List < T > result = await context . Set < T > ( ) . FromSqlRaw ( info . Sql , info . SqlParameters ) . AsNoTracking ( ) . ToListAsync ( cancellationToken ) ;
222+ }
106223
107224 SetOutputParameterValues ( info . SqlParameters , storedProcedure ) ;
108225
109226 return result ;
110227 }
111228
112- /// <summary>
113- /// Executes the specified stored procedure against a database
114- /// and returns the first or default value
115- /// </summary>
116- /// <typeparam name="T">Type of the data returned from the stored procedure.</typeparam>
117- /// <param name="database">The database to execute against.</param>
118- /// <param name="storedProcedure">The stored procedure to execute.</param>
119- /// <returns></returns>
120- public static T ExecuteStoredProcedureFirstOrDefault < T > ( this DatabaseFacade database , object storedProcedure ) where T : class
229+ /// <summary>
230+ /// Executes the specified stored procedure against a database
231+ /// and returns the first or default value
232+ /// </summary>
233+ /// <typeparam name="T">Type of the data returned from the stored procedure.</typeparam>
234+ /// <param name="database">The database to execute against.</param>
235+ /// <param name="storedProcedure">The stored procedure to execute.</param>
236+ /// <returns></returns>
237+ public static T ExecuteStoredProcedureFirstOrDefault < T > ( this DatabaseFacade database , object storedProcedure ) where T : class
121238 {
122239 return database . ExecuteStoredProcedure < T > ( storedProcedure ) . FirstOrDefault ( ) ;
123240 }
@@ -137,6 +254,24 @@ public static async Task<T> ExecuteStoredProcedureFirstOrDefaultAsync<T>(this Da
137254
138255 return executed . FirstOrDefault ( ) ;
139256 }
257+
258+ internal static IEntityType FindModelEntityType ( this DbContext @this , Type type )
259+ {
260+ var entityType = @this . Model . FindEntityType ( type ) ;
261+
262+ if ( entityType == null )
263+ {
264+ var baseType = type ;
265+
266+ while ( baseType != null && entityType == null )
267+ {
268+ entityType = @this . Model . FindEntityType ( baseType ) ;
269+ baseType = baseType . BaseType ;
270+ }
271+ }
272+
273+ return entityType ;
274+ }
140275 }
141276}
142277#endif
0 commit comments