using
System;
using
System.Collections.Generic;
using
System.Data;
using
System.Data.Common;
using
System.Data.Common.CommandTrees;
using
System.Data.Entity;
using
System.Data.EntityClient;
using
System.Data.Metadata.Edm;
using
System.Data.Objects.DataClasses;
using
System.Data.SqlClient;
using
System.Linq;
using
System.Linq.Dynamic;
using
System.Linq.Expressions;
using
System.Reflection;
using
System.Text;
using
System.Data.Objects;
using
EntityFramework.Mapping;
using
EntityFramework.Reflection;
namespace EntityFramework.Extensions
{
    /// <summary>
    /// An extensions class for batch queries.
    /// </summary>
    /// <remarks>
    /// Each public method issues a single set-based <c>DELETE</c> or <c>UPDATE</c> statement. The statement joins
    /// the mapped table to the key columns selected by a LINQ query. It uses the
    /// <c>DELETE/UPDATE ... FROM {table} AS j0 INNER JOIN (...) AS j1 ON (...)</c> form and <c>@</c>-prefixed
    /// parameter names (T-SQL style).
    /// </remarks>
    public static class BatchExtensions
    {
        /// <summary>
        /// Executes a delete statement using the query to filter the rows to be deleted.
        /// </summary>
        /// <typeparam name="TEntity">The type of the entity.</typeparam>
        /// <param name="source">The source used to determine the table to delete from.</param>
        /// <param name="query">The IQueryable used to generate the where clause for the delete statement.</param>
        /// <returns>The number of row deleted.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="source"/> or <paramref name="query"/> is null.</exception>
        /// <exception cref="ArgumentException">
        /// The source has no ObjectContext or mapping, or <paramref name="query"/> is not an ObjectQuery/DbQuery.
        /// </exception>
        /// <remarks>
        /// When executing this method, the statement is immediately executed on the database provider
        /// and is not part of the change tracking system. Also, changes will not be reflected on
        /// any entities that have already been materialized in the current context.
        /// </remarks>
        public static int Delete<TEntity>(this ObjectSet<TEntity> source, IQueryable<TEntity> query)
            where TEntity : class
        {
            if (source == null)
            {
                throw new ArgumentNullException("source");
            }

            if (query == null)
            {
                throw new ArgumentNullException("query");
            }

            ObjectContext objectContext = source.Context;
            if (objectContext == null)
            {
                throw new ArgumentException("The ObjectContext for the source query can not be null.", "source");
            }

            EntityMap entityMap = source.GetEntityMap<TEntity>();
            if (entityMap == null)
            {
                throw new ArgumentException("Could not load the entity mapping information for the source ObjectSet.", "source");
            }

            ObjectQuery<TEntity> objectQuery = query.ToObjectQuery();
            if (objectQuery == null)
            {
                throw new ArgumentException("The query must be of type ObjectQuery or DbQuery.", "query");
            }

            return Delete(objectContext, entityMap, objectQuery);
        }

        /// <summary>
        /// Executes a delete statement using an expression to filter the rows to be deleted.
        /// </summary>
        /// <typeparam name="TEntity">The type of the entity.</typeparam>
        /// <param name="source">The source used to determine the table to delete from.</param>
        /// <param name="filterExpression">The filter expression used to generate the where clause for the delete statement.</param>
        /// <returns>The number of row deleted.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="source"/> or <paramref name="filterExpression"/> is null.</exception>
        /// <example>Delete all users with email domain @test.com.
        /// <code><![CDATA[
        /// var db = new TrackerEntities();
        /// string emailDomain = "@test.com";
        /// int count = db.Users.Delete(u => u.Email.EndsWith(emailDomain));
        /// ]]></code>
        /// </example>
        /// <remarks>
        /// When executing this method, the statement is immediately executed on the database provider
        /// and is not part of the change tracking system. Also, changes will not be reflected on
        /// any entities that have already been materialized in the current context.
        /// </remarks>
        public static int Delete<TEntity>(this ObjectSet<TEntity> source, Expression<Func<TEntity, bool>> filterExpression)
            where TEntity : class
        {
            if (source == null)
            {
                throw new ArgumentNullException("source");
            }

            if (filterExpression == null)
            {
                throw new ArgumentNullException("filterExpression");
            }

            return source.Delete(source.Where(filterExpression));
        }

        /// <summary>
        /// Executes a delete statement using the query to filter the rows to be deleted.
        /// </summary>
        /// <typeparam name="TEntity">The type of the entity.</typeparam>
        /// <param name="source">The source used to determine the table to delete from.</param>
        /// <param name="query">The IQueryable used to generate the where clause for the delete statement.</param>
        /// <returns>The number of row deleted.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="source"/> or <paramref name="query"/> is null.</exception>
        /// <exception cref="ArgumentException">
        /// The source cannot be converted to an ObjectQuery, has no ObjectContext or mapping, or
        /// <paramref name="query"/> is not an ObjectQuery/DbQuery.
        /// </exception>
        /// <remarks>
        /// When executing this method, the statement is immediately executed on the database provider
        /// and is not part of the change tracking system. Also, changes will not be reflected on
        /// any entities that have already been materialized in the current context.
        /// </remarks>
        public static int Delete<TEntity>(this DbSet<TEntity> source, IQueryable<TEntity> query)
            where TEntity : class
        {
            if (source == null)
            {
                throw new ArgumentNullException("source");
            }

            if (query == null)
            {
                throw new ArgumentNullException("query");
            }

            // A DbSet does not expose its ObjectContext directly; go through its underlying ObjectQuery.
            ObjectQuery<TEntity> sourceQuery = source.ToObjectQuery();
            if (sourceQuery == null)
            {
                throw new ArgumentException("The query must be of type ObjectQuery or DbQuery.", "source");
            }

            ObjectContext objectContext = sourceQuery.Context;
            if (objectContext == null)
            {
                throw new ArgumentException("The ObjectContext for the source query can not be null.", "source");
            }

            EntityMap entityMap = sourceQuery.GetEntityMap<TEntity>();
            if (entityMap == null)
            {
                throw new ArgumentException("Could not load the entity mapping information for the source ObjectSet.", "source");
            }

            ObjectQuery<TEntity> objectQuery = query.ToObjectQuery();
            if (objectQuery == null)
            {
                throw new ArgumentException("The query must be of type ObjectQuery or DbQuery.", "query");
            }

            return Delete(objectContext, entityMap, objectQuery);
        }

        /// <summary>
        /// Executes a delete statement using an expression to filter the rows to be deleted.
        /// </summary>
        /// <typeparam name="TEntity">The type of the entity.</typeparam>
        /// <param name="source">The source used to determine the table to delete from.</param>
        /// <param name="filterExpression">The filter expression used to generate the where clause for the delete statement.</param>
        /// <returns>The number of row deleted.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="source"/> or <paramref name="filterExpression"/> is null.</exception>
        /// <example>Delete all users with email domain @test.com.
        /// <code><![CDATA[
        /// var db = new TrackerContext();
        /// string emailDomain = "@test.com";
        /// int count = db.Users.Delete(u => u.Email.EndsWith(emailDomain));
        /// ]]></code>
        /// </example>
        /// <remarks>
        /// When executing this method, the statement is immediately executed on the database provider
        /// and is not part of the change tracking system. Also, changes will not be reflected on
        /// any entities that have already been materialized in the current context.
        /// </remarks>
        public static int Delete<TEntity>(this DbSet<TEntity> source, Expression<Func<TEntity, bool>> filterExpression)
            where TEntity : class
        {
            if (source == null)
            {
                throw new ArgumentNullException("source");
            }

            if (filterExpression == null)
            {
                throw new ArgumentNullException("filterExpression");
            }

            return source.Delete(source.Where(filterExpression));
        }

        /// <summary>
        /// Executes an update statement using the query to filter the rows to be updated.
        /// </summary>
        /// <typeparam name="TEntity">The type of the entity.</typeparam>
        /// <param name="source">The source used to determine the table to update.</param>
        /// <param name="query">The query used to generate the where clause.</param>
        /// <param name="updateExpression">
        /// The MemberInitExpression used to indicate what is updated. Each assigned value is evaluated once on the
        /// client and sent as a parameter. Assignments that reference the lambda's entity parameter
        /// (e.g. <c>u.Count + 1</c>) are therefore not supported.
        /// </param>
        /// <returns>The number of row updated.</returns>
        /// <exception cref="ArgumentNullException">Any argument is null.</exception>
        /// <exception cref="ArgumentException">
        /// The source has no ObjectContext or mapping, <paramref name="query"/> is not an ObjectQuery/DbQuery, or
        /// <paramref name="updateExpression"/> is not a member initializer that assigns at least one mapped property.
        /// </exception>
        /// <remarks>
        /// When executing this method, the statement is immediately executed on the database provider
        /// and is not part of the change tracking system. Also, changes will not be reflected on
        /// any entities that have already been materialized in the current context.
        /// </remarks>
        public static int Update<TEntity>(this ObjectSet<TEntity> source, IQueryable<TEntity> query, Expression<Func<TEntity, TEntity>> updateExpression)
            where TEntity : class
        {
            if (source == null)
            {
                throw new ArgumentNullException("source");
            }

            if (query == null)
            {
                throw new ArgumentNullException("query");
            }

            if (updateExpression == null)
            {
                throw new ArgumentNullException("updateExpression");
            }

            ObjectContext objectContext = source.Context;
            if (objectContext == null)
            {
                throw new ArgumentException("The ObjectContext for the source query can not be null.", "source");
            }

            EntityMap entityMap = source.GetEntityMap<TEntity>();
            if (entityMap == null)
            {
                throw new ArgumentException("Could not load the entity mapping information for the source ObjectSet.", "source");
            }

            ObjectQuery<TEntity> objectQuery = query.ToObjectQuery();
            if (objectQuery == null)
            {
                throw new ArgumentException("The query must be of type ObjectQuery or DbQuery.", "query");
            }

            return Update(objectContext, entityMap, objectQuery, updateExpression);
        }

        /// <summary>
        /// Executes an update statement using an expression to filter the rows that are updated.
        /// </summary>
        /// <typeparam name="TEntity">The type of the entity.</typeparam>
        /// <param name="source">The source used to determine the table to update.</param>
        /// <param name="filterExpression">The filter expression used to generate the where clause.</param>
        /// <param name="updateExpression">The MemberInitExpression used to indicate what is updated.</param>
        /// <returns>The number of row updated.</returns>
        /// <exception cref="ArgumentNullException">Any argument is null.</exception>
        /// <example>Update all users in the test.com domain to be inactive.
        /// <code><![CDATA[
        /// var db = new TrackerEntities();
        /// string emailDomain = "@test.com";
        /// int count = db.Users.Update(
        ///   u => u.Email.EndsWith(emailDomain),
        ///   u => new User { IsApproved = false, LastActivityDate = DateTime.Now });
        /// ]]></code>
        /// </example>
        /// <remarks>
        /// When executing this method, the statement is immediately executed on the database provider
        /// and is not part of the change tracking system. Also, changes will not be reflected on
        /// any entities that have already been materialized in the current context.
        /// </remarks>
        public static int Update<TEntity>(this ObjectSet<TEntity> source, Expression<Func<TEntity, bool>> filterExpression, Expression<Func<TEntity, TEntity>> updateExpression)
            where TEntity : class
        {
            if (source == null)
            {
                throw new ArgumentNullException("source");
            }

            if (filterExpression == null)
            {
                throw new ArgumentNullException("filterExpression");
            }

            // Fail before composing the filtered query rather than inside the delegated overload.
            if (updateExpression == null)
            {
                throw new ArgumentNullException("updateExpression");
            }

            return source.Update(source.Where(filterExpression), updateExpression);
        }

        /// <summary>
        /// Executes an update statement using the query to filter the rows to be updated.
        /// </summary>
        /// <typeparam name="TEntity">The type of the entity.</typeparam>
        /// <param name="source">The source used to determine the table to update.</param>
        /// <param name="query">The query used to generate the where clause.</param>
        /// <param name="updateExpression">
        /// The MemberInitExpression used to indicate what is updated. Each assigned value is evaluated once on the
        /// client and sent as a parameter. Assignments that reference the lambda's entity parameter are not supported.
        /// </param>
        /// <returns>The number of row updated.</returns>
        /// <exception cref="ArgumentNullException">Any argument is null.</exception>
        /// <exception cref="ArgumentException">
        /// The source cannot be converted to an ObjectQuery, has no ObjectContext or mapping,
        /// <paramref name="query"/> is not an ObjectQuery/DbQuery, or <paramref name="updateExpression"/> is not a
        /// member initializer that assigns at least one mapped property.
        /// </exception>
        /// <remarks>
        /// When executing this method, the statement is immediately executed on the database provider
        /// and is not part of the change tracking system. Also, changes will not be reflected on
        /// any entities that have already been materialized in the current context.
        /// </remarks>
        public static int Update<TEntity>(this DbSet<TEntity> source, IQueryable<TEntity> query, Expression<Func<TEntity, TEntity>> updateExpression)
            where TEntity : class
        {
            if (source == null)
            {
                throw new ArgumentNullException("source");
            }

            if (query == null)
            {
                throw new ArgumentNullException("query");
            }

            if (updateExpression == null)
            {
                throw new ArgumentNullException("updateExpression");
            }

            // A DbSet does not expose its ObjectContext directly; go through its underlying ObjectQuery.
            ObjectQuery<TEntity> sourceQuery = source.ToObjectQuery();
            if (sourceQuery == null)
            {
                throw new ArgumentException("The query must be of type ObjectQuery or DbQuery.", "source");
            }

            ObjectContext objectContext = sourceQuery.Context;
            if (objectContext == null)
            {
                throw new ArgumentException("The ObjectContext for the source query can not be null.", "source");
            }

            EntityMap entityMap = sourceQuery.GetEntityMap<TEntity>();
            if (entityMap == null)
            {
                throw new ArgumentException("Could not load the entity mapping information for the source.", "source");
            }

            ObjectQuery<TEntity> objectQuery = query.ToObjectQuery();
            if (objectQuery == null)
            {
                throw new ArgumentException("The query must be of type ObjectQuery or DbQuery.", "query");
            }

            return Update(objectContext, entityMap, objectQuery, updateExpression);
        }

        /// <summary>
        /// Executes an update statement using an expression to filter the rows that are updated.
        /// </summary>
        /// <typeparam name="TEntity">The type of the entity.</typeparam>
        /// <param name="source">The source used to determine the table to update.</param>
        /// <param name="filterExpression">The filter expression used to generate the where clause.</param>
        /// <param name="updateExpression">The MemberInitExpression used to indicate what is updated.</param>
        /// <returns>The number of row updated.</returns>
        /// <exception cref="ArgumentNullException">Any argument is null.</exception>
        /// <example>Update all users in the test.com domain to be inactive.
        /// <code><![CDATA[
        /// var db = new TrackerContext();
        /// string emailDomain = "@test.com";
        /// int count = db.Users.Update(
        ///   u => u.Email.EndsWith(emailDomain),
        ///   u => new User { IsApproved = false, LastActivityDate = DateTime.Now });
        /// ]]></code>
        /// </example>
        /// <remarks>
        /// When executing this method, the statement is immediately executed on the database provider
        /// and is not part of the change tracking system. Also, changes will not be reflected on
        /// any entities that have already been materialized in the current context.
        /// </remarks>
        public static int Update<TEntity>(this DbSet<TEntity> source, Expression<Func<TEntity, bool>> filterExpression, Expression<Func<TEntity, TEntity>> updateExpression)
            where TEntity : class
        {
            if (source == null)
            {
                throw new ArgumentNullException("source");
            }

            if (filterExpression == null)
            {
                throw new ArgumentNullException("filterExpression");
            }

            // Fail before composing the filtered query rather than inside the delegated overload.
            if (updateExpression == null)
            {
                throw new ArgumentNullException("updateExpression");
            }

            return source.Update(source.Where(filterExpression), updateExpression);
        }

        /// <summary>
        /// Builds and executes
        /// <c>DELETE {table} FROM {table} AS j0 INNER JOIN ({key select}) AS j1 ON (j0.key = j1.key ...)</c>.
        /// </summary>
        /// <returns>The value returned by <see cref="DbCommand.ExecuteNonQuery"/>.</returns>
        private static int Delete<TEntity>(ObjectContext objectContext, EntityMap entityMap, ObjectQuery<TEntity> query)
            where TEntity : class
        {
            return ExecuteBatch(objectContext, command =>
            {
                // Adds the filter query's parameters to the command as a side effect.
                string innerSelect = GetSelectSql(query, entityMap, command);

                var sqlBuilder = new StringBuilder(innerSelect.Length * 2);
                sqlBuilder.Append("DELETE ");
                sqlBuilder.Append(entityMap.TableName);
                sqlBuilder.AppendLine();
                AppendJoinClause(sqlBuilder, entityMap, innerSelect);

                return sqlBuilder.ToString();
            });
        }

        /// <summary>
        /// Builds and executes
        /// <c>UPDATE {table} SET col = @p__update__N, ... FROM {table} AS j0 INNER JOIN ({key select}) AS j1 ON (...)</c>.
        /// </summary>
        /// <returns>The value returned by <see cref="DbCommand.ExecuteNonQuery"/>.</returns>
        /// <exception cref="ArgumentException">The update expression is not a usable member initializer.</exception>
        private static int Update<TEntity>(ObjectContext objectContext, EntityMap entityMap, ObjectQuery<TEntity> query, Expression<Func<TEntity, TEntity>> updateExpression)
            where TEntity : class
        {
            // Validate and evaluate the SET clause before opening a connection or starting a transaction.
            List<KeyValuePair<string, object>> assignments = GetUpdateAssignments(entityMap, updateExpression);

            return ExecuteBatch(objectContext, command =>
            {
                // Adds the filter query's parameters to the command as a side effect.
                string innerSelect = GetSelectSql(query, entityMap, command);

                var sqlBuilder = new StringBuilder(innerSelect.Length * 2);
                sqlBuilder.Append("UPDATE ");
                sqlBuilder.Append(entityMap.TableName);
                sqlBuilder.AppendLine(" SET ");

                for (int index = 0; index < assignments.Count; index++)
                {
                    if (index > 0)
                    {
                        sqlBuilder.AppendLine(", ");
                    }

                    string parameterName = "p__update__" + index;
                    AddParameter(command, parameterName, assignments[index].Value);
                    sqlBuilder.AppendFormat("{0} = @{1}", assignments[index].Key, parameterName);
                }

                sqlBuilder.AppendLine(" ");
                AppendJoinClause(sqlBuilder, entityMap, innerSelect);

                return sqlBuilder.ToString();
            });
        }

        /// <summary>
        /// Translates the bindings of a <c>new TEntity { Prop = value, ... }</c> expression into
        /// (column name, evaluated value) pairs, in binding order.
        /// </summary>
        /// <exception cref="ArgumentException">
        /// The body is not a <see cref="MemberInitExpression"/>, a binding is not a plain assignment, a bound
        /// property has no column mapping, or there are no bindings at all.
        /// </exception>
        private static List<KeyValuePair<string, object>> GetUpdateAssignments<TEntity>(EntityMap entityMap, Expression<Func<TEntity, TEntity>> updateExpression)
            where TEntity : class
        {
            var memberInitExpression = updateExpression.Body as MemberInitExpression;
            if (memberInitExpression == null)
            {
                throw new ArgumentException("The update expression must be of type MemberInitExpression.", "updateExpression");
            }

            var assignments = new List<KeyValuePair<string, object>>();
            foreach (MemberBinding binding in memberInitExpression.Bindings)
            {
                var memberAssignment = binding as MemberAssignment;
                if (memberAssignment == null)
                {
                    throw new ArgumentException("The update expression MemberBinding must only by type MemberAssignment.", "updateExpression");
                }

                string propertyName = binding.Member.Name;
                string columnName = entityMap.PropertyMaps
                    .Where(p => p.PropertyName == propertyName)
                    .Select(p => p.ColumnName)
                    .FirstOrDefault();

                // Without a column the generated SET clause would read " = @p__update__N", which is invalid SQL.
                if (columnName == null)
                {
                    throw new ArgumentException(
                        "The property '" + propertyName + "' in the update expression is not mapped to a column.",
                        "updateExpression");
                }

                assignments.Add(new KeyValuePair<string, object>(columnName, EvaluateAssignment(memberAssignment.Expression)));
            }

            // An empty initializer would emit "SET" followed directly by "FROM".
            if (assignments.Count == 0)
            {
                throw new ArgumentException("The update expression must assign at least one property.", "updateExpression");
            }

            return assignments;
        }

        /// <summary>
        /// Evaluates the right-hand side of a member assignment on the client. Constants are read directly.
        /// Anything else is compiled as a parameterless lambda and invoked.
        /// </summary>
        private static object EvaluateAssignment(Expression valueExpression)
        {
            var constantExpression = valueExpression as ConstantExpression;
            if (constantExpression != null)
            {
                return constantExpression.Value;
            }

            // Compiling with no parameters means an expression that references the update lambda's own
            // parameter cannot be evaluated here (Compile throws).
            LambdaExpression lambda = Expression.Lambda(valueExpression);
            return lambda.Compile().DynamicInvoke();
        }

        /// <summary>
        /// Runs a batch command on the context's store connection. The connection is opened, and a transaction
        /// started, only if none is already open or active. Whatever this method opened or started, it also
        /// commits, disposes or closes.
        /// </summary>
        /// <param name="objectContext">Supplies the connection, any active transaction and the command timeout.</param>
        /// <param name="buildCommandText">
        /// Receives the prepared command (it may add parameters) and returns the SQL text to execute.
        /// </param>
        /// <returns>The value returned by <see cref="DbCommand.ExecuteNonQuery"/>.</returns>
        private static int ExecuteBatch(ObjectContext objectContext, Func<DbCommand, string> buildCommandText)
        {
            DbConnection connection = null;
            DbTransaction transaction = null;
            DbCommand command = null;
            bool ownConnection = false;
            bool ownTransaction = false;

            try
            {
                var store = GetStore(objectContext);
                connection = store.Item1;
                transaction = store.Item2;

                if (connection.State != ConnectionState.Open)
                {
                    connection.Open();
                    ownConnection = true;
                }

                // Reuse the caller's transaction when one is active; otherwise make the statement atomic on its own.
                if (transaction == null)
                {
                    transaction = connection.BeginTransaction();
                    ownTransaction = true;
                }

                command = connection.CreateCommand();
                command.Transaction = transaction;
                if (objectContext.CommandTimeout.HasValue)
                {
                    command.CommandTimeout = objectContext.CommandTimeout.Value;
                }

                command.CommandText = buildCommandText(command);
                int result = command.ExecuteNonQuery();

                if (ownTransaction)
                {
                    transaction.Commit();
                }

                return result;
            }
            finally
            {
                // Disposing an uncommitted transaction rolls it back when an exception escaped above.
                if (command != null)
                {
                    command.Dispose();
                }

                if (transaction != null && ownTransaction)
                {
                    transaction.Dispose();
                }

                if (connection != null && ownConnection)
                {
                    connection.Close();
                }
            }
        }

        /// <summary>
        /// Appends <c>FROM {table} AS j0 INNER JOIN (\n{innerSelect}\n) AS j1 ON (j0.k = j1.k AND ...)</c>,
        /// joining on every mapped key column.
        /// </summary>
        private static void AppendJoinClause(StringBuilder sqlBuilder, EntityMap entityMap, string innerSelect)
        {
            sqlBuilder.AppendFormat("FROM {0} AS j0 INNER JOIN (", entityMap.TableName);
            sqlBuilder.AppendLine();
            sqlBuilder.AppendLine(innerSelect);
            sqlBuilder.Append(") AS j1 ON (");

            bool wroteKey = false;
            foreach (var keyMap in entityMap.KeyMaps)
            {
                if (wroteKey)
                {
                    sqlBuilder.Append(" AND ");
                }

                sqlBuilder.AppendFormat("j0.{0} = j1.{0}", keyMap.ColumnName);
                wroteKey = true;
            }

            sqlBuilder.Append(")");
        }

        /// <summary>
        /// Adds a named parameter to <paramref name="command"/>, mapping a null value to <see cref="DBNull.Value"/>.
        /// </summary>
        private static void AddParameter(DbCommand command, string parameterName, object value)
        {
            var parameter = command.CreateParameter();
            parameter.ParameterName = parameterName;
            // ADO.NET treats a null Value as "parameter not supplied"; DBNull.Value sends SQL NULL.
            parameter.Value = value ?? DBNull.Value;
            command.Parameters.Add(parameter);
        }

        /// <summary>
        /// Resolves the store-level connection, and the active store transaction if any, for the context.
        /// </summary>
        /// <returns>
        /// The context's connection and a null transaction when it is not an <see cref="EntityConnection"/>.
        /// Otherwise, the entity connection's <c>StoreConnection</c> and the <c>StoreTransaction</c> of its
        /// current transaction (null when there is none).
        /// </returns>
        private static Tuple<DbConnection, DbTransaction> GetStore(ObjectContext objectContext)
        {
            DbConnection dbConnection = objectContext.Connection;
            var entityConnection = dbConnection as EntityConnection;
            if (entityConnection == null)
            {
                return new Tuple<DbConnection, DbTransaction>(dbConnection, null);
            }

            DbConnection connection = entityConnection.StoreConnection;

            // CurrentTransaction is read through DynamicProxy, presumably because EntityConnection does not
            // expose it publicly.
            dynamic connectionProxy = new DynamicProxy(entityConnection);
            dynamic entityTransaction = connectionProxy.CurrentTransaction;
            if (entityTransaction == null)
            {
                return new Tuple<DbConnection, DbTransaction>(connection, null);
            }

            DbTransaction transaction = entityTransaction.StoreTransaction;
            return new Tuple<DbConnection, DbTransaction>(connection, transaction);
        }

        /// <summary>
        /// Projects <paramref name="query"/> onto the entity's key properties via Dynamic LINQ
        /// (e.g. <c>new(Id)</c>) and returns the store SQL for that projection. Copies the query's
        /// parameters onto <paramref name="command"/>.
        /// </summary>
        /// <exception cref="ArgumentException">The projected query is not an <see cref="ObjectQuery"/>.</exception>
        private static string GetSelectSql<TEntity>(ObjectQuery<TEntity> query, EntityMap entityMap, DbCommand command)
            where TEntity : class
        {
            var selector = new StringBuilder(50);
            selector.Append("new(");

            bool wroteKey = false;
            foreach (var keyMap in entityMap.KeyMaps)
            {
                if (wroteKey)
                {
                    selector.Append(", ");
                }

                selector.Append(keyMap.PropertyName);
                wroteKey = true;
            }

            selector.Append(")");

            var selectQuery = DynamicQueryable.Select(query, selector.ToString());
            var objectQuery = selectQuery as ObjectQuery;
            if (objectQuery == null)
            {
                throw new ArgumentException("The query must be of type ObjectQuery.", "query");
            }

            string innerJoinSql = objectQuery.ToTraceString();

            foreach (var objectParameter in objectQuery.Parameters)
            {
                AddParameter(command, objectParameter.Name, objectParameter.Value);
            }

            return innerJoinSql;
        }
    }
}