Skip to content

Commit e934111

Browse files
committed
save
1 parent 04af6a4 commit e934111

File tree

1 file changed

+154
-19
lines changed

1 file changed

+154
-19
lines changed

EntityFrameworkExtras.Shared/DatabaseExtensions.EFCore3x.cs

Lines changed: 154 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
using System.Threading.Tasks;
99
using Microsoft.EntityFrameworkCore;
1010
using Microsoft.EntityFrameworkCore.Infrastructure;
11+
using Microsoft.EntityFrameworkCore.Metadata;
1112
using Microsoft.EntityFrameworkCore.Storage;
1213

1314
namespace EntityFrameworkExtras.EFCore
@@ -65,15 +66,22 @@ public static async Task ExecuteStoredProcedureAsync(this DatabaseFacade databas
6566
/// <param name="storedProcedure">The stored procedure to execute.</param>
6667
/// <returns></returns>
6768
public static IEnumerable<T> ExecuteStoredProcedure<T>(this DatabaseFacade database, object storedProcedure) where T : class
68-
{
69-
if (storedProcedure == null)
70-
throw new ArgumentNullException("storedProcedure");
69+
{
70+
var contextField = database.GetType().GetField("_context", BindingFlags.Instance | BindingFlags.NonPublic);
7171

72-
var info = StoredProcedureParser.BuildStoredProcedureInfo(storedProcedure);
72+
var context = (DbContext)contextField.GetValue(database);
7373

74-
var contextField = database.GetType().GetField("_context", BindingFlags.Instance | BindingFlags.NonPublic);
74+
var entityType = FindModelEntityType(context, typeof(T));
75+
76+
if (entityType == null)
77+
{
78+
return database.InternalExecuteStoredProcedure<T>(storedProcedure);
79+
}
80+
81+
if (storedProcedure == null)
82+
throw new ArgumentNullException("storedProcedure");
7583

76-
var context = (DbContext)contextField.GetValue(database);
84+
var info = StoredProcedureParser.BuildStoredProcedureInfo(storedProcedure);
7785

7886
List<T> result = context.Set<T>().FromSqlRaw(info.Sql, info.SqlParameters).AsNoTracking().ToList();
7987

@@ -82,7 +90,58 @@ public static IEnumerable<T> ExecuteStoredProcedure<T>(this DatabaseFacade datab
8290
return result;
8391
}
8492

85-
/// <summary>
93+
internal static IEnumerable<T> InternalExecuteStoredProcedure<T>(this DatabaseFacade database, object storedProcedure)
94+
{
95+
if (storedProcedure == null)
96+
throw new ArgumentNullException("storedProcedure");
97+
98+
99+
List<T> result = new List<T>();
100+
var info = StoredProcedureParser.BuildStoredProcedureInfo(storedProcedure);
101+
102+
103+
// from : https://github.com/Fodsuk/EntityFrameworkExtras/pull/23/commits/dce354304aa9a95750f7d2559d1b002444ac46f7
104+
using (var command = database.GetDbConnection().CreateCommand())
105+
{
106+
command.CommandText = info.Sql;
107+
int? commandTimeout = database.GetCommandTimeout();
108+
if (commandTimeout.HasValue)
109+
{
110+
command.CommandTimeout = commandTimeout.Value;
111+
}
112+
command.CommandType = CommandType.Text;
113+
command.Parameters.AddRange(info.SqlParameters);
114+
command.Transaction = database.CurrentTransaction?.GetDbTransaction();
115+
database.OpenConnection();
116+
117+
using (var resultReader = command.ExecuteReader())
118+
{
119+
T obj = default(T);
120+
121+
while (resultReader.Read())
122+
{
123+
obj = Activator.CreateInstance<T>();
124+
foreach (PropertyInfo prop in obj.GetType().GetProperties())
125+
{
126+
var val = GetValue(resultReader, prop.Name);
127+
if (!object.Equals(val, DBNull.Value))
128+
{
129+
prop.SetValue(obj, val, null);
130+
}
131+
}
132+
133+
result.Add(obj);
134+
}
135+
}
136+
137+
}
138+
139+
SetOutputParameterValues(info.SqlParameters, storedProcedure);
140+
141+
return result;
142+
}
143+
144+
/// <summary>
86145
/// Executes the specified stored procedure against a database asynchronously
87146
/// and returns an enumerable of T representing the data returned.
88147
/// </summary>
@@ -92,32 +151,90 @@ public static IEnumerable<T> ExecuteStoredProcedure<T>(this DatabaseFacade datab
92151
/// <param name="cancellationToken">The cancellation token.</param>
93152
/// <returns></returns>
94153
public static async Task<IEnumerable<T>> ExecuteStoredProcedureAsync<T>(this DatabaseFacade database, object storedProcedure, CancellationToken cancellationToken = default) where T : class
154+
{
155+
var contextField = database.GetType().GetField("_context", BindingFlags.Instance | BindingFlags.NonPublic);
156+
157+
var context = (DbContext)contextField.GetValue(database);
158+
159+
var entityType = FindModelEntityType(context, typeof(T));
160+
161+
if (entityType == null)
162+
{
163+
return await database.InternalExecuteStoredProcedureAsync<T>(storedProcedure).ConfigureAwait(false);
164+
}
165+
166+
if (storedProcedure == null)
167+
throw new ArgumentNullException("storedProcedure");
168+
169+
var info = StoredProcedureParser.BuildStoredProcedureInfo(storedProcedure);
170+
171+
List<T> result = await context.Set<T>().FromSqlRaw(info.Sql, info.SqlParameters).AsNoTracking().ToListAsync(cancellationToken).ConfigureAwait(false);
172+
173+
SetOutputParameterValues(info.SqlParameters, storedProcedure);
174+
175+
return result;
176+
}
177+
178+
internal static async Task<IEnumerable<T>> InternalExecuteStoredProcedureAsync<T>(this DatabaseFacade database, object storedProcedure, CancellationToken cancellationToken = default)
95179
{
96180
if (storedProcedure == null)
97181
throw new ArgumentNullException("storedProcedure");
98182

183+
184+
List<T> result = new List<T>();
99185
var info = StoredProcedureParser.BuildStoredProcedureInfo(storedProcedure);
100186

101-
var contextField = database.GetType().GetField("_context", BindingFlags.Instance | BindingFlags.NonPublic);
102187

103-
var context = (DbContext)contextField.GetValue(database);
188+
// from : https://github.com/Fodsuk/EntityFrameworkExtras/pull/23/commits/dce354304aa9a95750f7d2559d1b002444ac46f7
189+
using (var command = database.GetDbConnection().CreateCommand())
190+
{
191+
command.CommandText = info.Sql;
192+
int? commandTimeout = database.GetCommandTimeout();
193+
if (commandTimeout.HasValue)
194+
{
195+
command.CommandTimeout = commandTimeout.Value;
196+
}
197+
command.CommandType = CommandType.Text;
198+
command.Parameters.AddRange(info.SqlParameters);
199+
command.Transaction = database.CurrentTransaction?.GetDbTransaction();
200+
database.OpenConnection();
201+
202+
using (var resultReader = await command.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false))
203+
{
204+
T obj = default(T);
205+
206+
while (await resultReader.ReadAsync(cancellationToken).ConfigureAwait(false))
207+
{
208+
obj = Activator.CreateInstance<T>();
209+
foreach (PropertyInfo prop in obj.GetType().GetProperties())
210+
{
211+
var val = GetValue(resultReader, prop.Name);
212+
if (!object.Equals(val, DBNull.Value))
213+
{
214+
prop.SetValue(obj, val, null);
215+
}
216+
}
217+
218+
result.Add(obj);
219+
}
220+
}
104221

105-
List<T> result = await context.Set<T>().FromSqlRaw(info.Sql, info.SqlParameters).AsNoTracking().ToListAsync(cancellationToken);
222+
}
106223

107224
SetOutputParameterValues(info.SqlParameters, storedProcedure);
108225

109226
return result;
110227
}
111228

112-
/// <summary>
113-
/// Executes the specified stored procedure against a database
114-
/// and returns the first or default value
115-
/// </summary>
116-
/// <typeparam name="T">Type of the data returned from the stored procedure.</typeparam>
117-
/// <param name="database">The database to execute against.</param>
118-
/// <param name="storedProcedure">The stored procedure to execute.</param>
119-
/// <returns></returns>
120-
public static T ExecuteStoredProcedureFirstOrDefault<T>(this DatabaseFacade database, object storedProcedure) where T : class
229+
/// <summary>
230+
/// Executes the specified stored procedure against a database
231+
/// and returns the first or default value
232+
/// </summary>
233+
/// <typeparam name="T">Type of the data returned from the stored procedure.</typeparam>
234+
/// <param name="database">The database to execute against.</param>
235+
/// <param name="storedProcedure">The stored procedure to execute.</param>
236+
/// <returns></returns>
237+
public static T ExecuteStoredProcedureFirstOrDefault<T>(this DatabaseFacade database, object storedProcedure) where T : class
121238
{
122239
return database.ExecuteStoredProcedure<T>(storedProcedure).FirstOrDefault();
123240
}
@@ -137,6 +254,24 @@ public static async Task<T> ExecuteStoredProcedureFirstOrDefaultAsync<T>(this Da
137254

138255
return executed.FirstOrDefault();
139256
}
257+
258+
internal static IEntityType FindModelEntityType(this DbContext @this, Type type)
259+
{
260+
var entityType = @this.Model.FindEntityType(type);
261+
262+
if (entityType == null)
263+
{
264+
var baseType = type;
265+
266+
while (baseType != null && entityType == null)
267+
{
268+
entityType = @this.Model.FindEntityType(baseType);
269+
baseType = baseType.BaseType;
270+
}
271+
}
272+
273+
return entityType;
274+
}
140275
}
141276
}
142277
#endif

0 commit comments

Comments
 (0)