88using System . Threading . Tasks ;
99using Microsoft . EntityFrameworkCore ;
1010using Microsoft . EntityFrameworkCore . Infrastructure ;
11+ using Microsoft . EntityFrameworkCore . Metadata ;
1112using Microsoft . EntityFrameworkCore . Storage ;
1213
1314namespace EntityFrameworkExtras . EFCore
@@ -64,58 +65,83 @@ public static async Task ExecuteStoredProcedureAsync(this DatabaseFacade databas
6465 /// <param name="database">The database to execute against.</param>
6566 /// <param name="storedProcedure">The stored procedure to execute.</param>
6667 /// <returns></returns>
67- public static IEnumerable < T > ExecuteStoredProcedure < T > ( this DatabaseFacade database , object storedProcedure )
68- {
69- if ( storedProcedure == null )
70- throw new ArgumentNullException ( "storedProcedure" ) ;
68+ public static IEnumerable < T > ExecuteStoredProcedure < T > ( this DatabaseFacade database , object storedProcedure ) where T : class
69+ {
70+ var contextField = database . GetType ( ) . GetField ( "_context" , BindingFlags . Instance | BindingFlags . NonPublic ) ;
71+
72+ var context = ( DbContext ) contextField . GetValue ( database ) ;
73+
74+ var entityType = FindModelEntityType ( context , typeof ( T ) ) ;
75+
76+ if ( entityType == null )
77+ {
78+ return database . InternalExecuteStoredProcedure < T > ( storedProcedure ) ;
79+ }
7180
81+ if ( storedProcedure == null )
82+ throw new ArgumentNullException ( "storedProcedure" ) ;
7283
73- List < T > result = new List < T > ( ) ;
7484 var info = StoredProcedureParser . BuildStoredProcedureInfo ( storedProcedure ) ;
7585
86+ List < T > result = context . Set < T > ( ) . FromSqlRaw ( info . Sql , info . SqlParameters ) . AsNoTracking ( ) . ToList ( ) ;
7687
77- // from : https://github.com/Fodsuk/EntityFrameworkExtras/pull/23/commits/dce354304aa9a95750f7d2559d1b002444ac46f7
78- using ( var command = database . GetDbConnection ( ) . CreateCommand ( ) )
79- {
80- command . CommandText = info . Sql ;
88+ SetOutputParameterValues ( info . SqlParameters , storedProcedure ) ;
89+
90+ return result ;
91+ }
92+
93+ internal static IEnumerable < T > InternalExecuteStoredProcedure < T > ( this DatabaseFacade database , object storedProcedure )
94+ {
95+ if ( storedProcedure == null )
96+ throw new ArgumentNullException ( "storedProcedure" ) ;
97+
98+
99+ List < T > result = new List < T > ( ) ;
100+ var info = StoredProcedureParser . BuildStoredProcedureInfo ( storedProcedure ) ;
101+
102+
103+ // from : https://github.com/Fodsuk/EntityFrameworkExtras/pull/23/commits/dce354304aa9a95750f7d2559d1b002444ac46f7
104+ using ( var command = database . GetDbConnection ( ) . CreateCommand ( ) )
105+ {
106+ command . CommandText = info . Sql ;
81107 int ? commandTimeout = database . GetCommandTimeout ( ) ;
82- if ( commandTimeout . HasValue )
83- {
84- command . CommandTimeout = commandTimeout . Value ;
85- }
86- command . CommandType = CommandType . Text ;
87- command . Parameters . AddRange ( info . SqlParameters ) ;
108+ if ( commandTimeout . HasValue )
109+ {
110+ command . CommandTimeout = commandTimeout . Value ;
111+ }
112+ command . CommandType = CommandType . Text ;
113+ command . Parameters . AddRange ( info . SqlParameters ) ;
88114 command . Transaction = database . CurrentTransaction ? . GetDbTransaction ( ) ;
89- database . OpenConnection ( ) ;
90-
91- using ( var resultReader = command . ExecuteReader ( ) )
92- {
93- T obj = default ( T ) ;
94-
95- while ( resultReader . Read ( ) )
96- {
97- obj = Activator . CreateInstance < T > ( ) ;
98- foreach ( PropertyInfo prop in obj . GetType ( ) . GetProperties ( ) )
99- {
100- var val = GetValue ( resultReader , prop . Name ) ;
101- if ( ! object . Equals ( val , DBNull . Value ) )
102- {
103- prop . SetValue ( obj , val , null ) ;
104- }
105- }
106-
107- result . Add ( obj ) ;
108- }
109- }
115+ database . OpenConnection ( ) ;
110116
111- }
117+ using ( var resultReader = command . ExecuteReader ( ) )
118+ {
119+ T obj = default ( T ) ;
112120
113- SetOutputParameterValues ( info . SqlParameters , storedProcedure ) ;
121+ while ( resultReader . Read ( ) )
122+ {
123+ obj = Activator . CreateInstance < T > ( ) ;
124+ foreach ( PropertyInfo prop in obj . GetType ( ) . GetProperties ( ) )
125+ {
126+ var val = GetValue ( resultReader , prop . Name ) ;
127+ if ( ! object . Equals ( val , DBNull . Value ) )
128+ {
129+ prop . SetValue ( obj , val , null ) ;
130+ }
131+ }
114132
115- return result ;
116- }
133+ result . Add ( obj ) ;
134+ }
135+ }
117136
118- /// <summary>
137+ }
138+
139+ SetOutputParameterValues ( info . SqlParameters , storedProcedure ) ;
140+
141+ return result ;
142+ }
143+
144+ /// <summary>
119145 /// Executes the specified stored procedure against a database asynchronously
120146 /// and returns an enumerable of T representing the data returned.
121147 /// </summary>
@@ -124,7 +150,32 @@ public static IEnumerable<T> ExecuteStoredProcedure<T>(this DatabaseFacade datab
124150 /// <param name="storedProcedure">The stored procedure to execute.</param>
125151 /// <param name="cancellationToken">The cancellation token.</param>
126152 /// <returns></returns>
127- public static async Task < IEnumerable < T > > ExecuteStoredProcedureAsync < T > ( this DatabaseFacade database , object storedProcedure , CancellationToken cancellationToken = default )
153+ public static async Task < IEnumerable < T > > ExecuteStoredProcedureAsync < T > ( this DatabaseFacade database , object storedProcedure , CancellationToken cancellationToken = default ) where T : class
154+ {
155+ var contextField = database . GetType ( ) . GetField ( "_context" , BindingFlags . Instance | BindingFlags . NonPublic ) ;
156+
157+ var context = ( DbContext ) contextField . GetValue ( database ) ;
158+
159+ var entityType = FindModelEntityType ( context , typeof ( T ) ) ;
160+
161+ if ( entityType == null )
162+ {
163+ return await database . InternalExecuteStoredProcedureAsync < T > ( storedProcedure ) . ConfigureAwait ( false ) ;
164+ }
165+
166+ if ( storedProcedure == null )
167+ throw new ArgumentNullException ( "storedProcedure" ) ;
168+
169+ var info = StoredProcedureParser . BuildStoredProcedureInfo ( storedProcedure ) ;
170+
171+ List < T > result = await context . Set < T > ( ) . FromSqlRaw ( info . Sql , info . SqlParameters ) . AsNoTracking ( ) . ToListAsync ( cancellationToken ) . ConfigureAwait ( false ) ;
172+
173+ SetOutputParameterValues ( info . SqlParameters , storedProcedure ) ;
174+
175+ return result ;
176+ }
177+
178+ internal static async Task < IEnumerable < T > > InternalExecuteStoredProcedureAsync < T > ( this DatabaseFacade database , object storedProcedure , CancellationToken cancellationToken = default )
128179 {
129180 if ( storedProcedure == null )
130181 throw new ArgumentNullException ( "storedProcedure" ) ;
@@ -138,11 +189,11 @@ public static async Task<IEnumerable<T>> ExecuteStoredProcedureAsync<T>(this Dat
138189 using ( var command = database . GetDbConnection ( ) . CreateCommand ( ) )
139190 {
140191 command . CommandText = info . Sql ;
141- int ? commandTimeout = database . GetCommandTimeout ( ) ;
142- if ( commandTimeout . HasValue )
143- {
144- command . CommandTimeout = commandTimeout . Value ;
145- }
192+ int ? commandTimeout = database . GetCommandTimeout ( ) ;
193+ if ( commandTimeout . HasValue )
194+ {
195+ command . CommandTimeout = commandTimeout . Value ;
196+ }
146197 command . CommandType = CommandType . Text ;
147198 command . Parameters . AddRange ( info . SqlParameters ) ;
148199 command . Transaction = database . CurrentTransaction ? . GetDbTransaction ( ) ;
@@ -175,15 +226,15 @@ public static async Task<IEnumerable<T>> ExecuteStoredProcedureAsync<T>(this Dat
175226 return result ;
176227 }
177228
178- /// <summary>
179- /// Executes the specified stored procedure against a database
180- /// and returns the first or default value
181- /// </summary>
182- /// <typeparam name="T">Type of the data returned from the stored procedure.</typeparam>
183- /// <param name="database">The database to execute against.</param>
184- /// <param name="storedProcedure">The stored procedure to execute.</param>
185- /// <returns></returns>
186- public static T ExecuteStoredProcedureFirstOrDefault < T > ( this DatabaseFacade database , object storedProcedure )
229+ /// <summary>
230+ /// Executes the specified stored procedure against a database
231+ /// and returns the first or default value
232+ /// </summary>
233+ /// <typeparam name="T">Type of the data returned from the stored procedure.</typeparam>
234+ /// <param name="database">The database to execute against.</param>
235+ /// <param name="storedProcedure">The stored procedure to execute.</param>
236+ /// <returns></returns>
237+ public static T ExecuteStoredProcedureFirstOrDefault < T > ( this DatabaseFacade database , object storedProcedure ) where T : class
187238 {
188239 return database . ExecuteStoredProcedure < T > ( storedProcedure ) . FirstOrDefault ( ) ;
189240 }
@@ -197,12 +248,30 @@ public static T ExecuteStoredProcedureFirstOrDefault<T>(this DatabaseFacade data
197248 /// <param name="storedProcedure">The stored procedure to execute.</param>
198249 /// <param name="cancellationToken">The cancellation token.</param>
199250 /// <returns></returns>
200- public static async Task < T > ExecuteStoredProcedureFirstOrDefaultAsync < T > ( this DatabaseFacade database , object storedProcedure , CancellationToken cancellationToken = default )
251+ public static async Task < T > ExecuteStoredProcedureFirstOrDefaultAsync < T > ( this DatabaseFacade database , object storedProcedure , CancellationToken cancellationToken = default ) where T : class
201252 {
202253 var executed = await database . ExecuteStoredProcedureAsync < T > ( storedProcedure , cancellationToken ) . ConfigureAwait ( false ) ;
203254
204255 return executed . FirstOrDefault ( ) ;
205256 }
257+
258+ internal static IEntityType FindModelEntityType ( this DbContext @this , Type type )
259+ {
260+ var entityType = @this . Model . FindEntityType ( type ) ;
261+
262+ if ( entityType == null )
263+ {
264+ var baseType = type ;
265+
266+ while ( baseType != null && entityType == null )
267+ {
268+ entityType = @this . Model . FindEntityType ( baseType ) ;
269+ baseType = baseType . BaseType ;
270+ }
271+ }
272+
273+ return entityType ;
274+ }
206275 }
207276}
208277#endif
0 commit comments